"""
source_simulation.py
Module for simulating synthetic brain activity data for source-level measurements.
Specifically, it simulates event-related potential (ERP)-like signals for use in
neuroimaging research (e.g., MEG/EEG source simulation). It supports flexible
configuration of ERP waveform properties, source orientation (fixed or free),
and trial-based simulation with reproducible randomization.
"""
import os
from pathlib import Path
import logging
from typing import Optional, Tuple, Union, Dict, List, Any
import numpy as np
from numpy.random import Generator
from scipy.stats import wishart
from scipy.signal import butter, filtfilt
import mne
from mne.io.constants import FIFF
import matplotlib.pyplot as plt
import matplotlib.cm as cm # Import colormap functionality
from matplotlib.lines import Line2D # Import for custom legend
import matplotlib.gridspec as gridspec
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable # For better colorbar placement
from calibrain.utils import load_config
# (stray "[docs]" link label from the rendered HTML source view; not part of the code)
class SourceSimulator:
    """Simulate synthetic, ERP-like source-level brain activity.

    Each active source receives a band-pass-filtered, Hanning-windowed noise
    burst placed after the stimulus onset. Sources can have a single time
    course ("fixed" orientation) or three orientation components ("free").
    All randomness is driven by explicit integer seeds, so repeated calls with
    the same seeds and configuration produce the same arrays.
    """

    # Values used for every key that a user-supplied ``ERP_config`` omits.
    _DEFAULT_ERP_CONFIG: Dict[str, Any] = {
        "tmin": -0.5,
        "tmax": 0.5,
        "stim_onset": 0.0,
        "sfreq": 250,
        "fmin": 1,
        "fmax": 5,
        "amplitude": 1.0,
        "random_erp_timing": True,
        "erp_min_length": None,
    }

    # Shortest ERP segment (in samples) used when ``erp_min_length`` is None;
    # chosen for filter stability (order-4 filtfilt) and a meaningful window.
    _DEFAULT_MIN_ERP_LEN: int = 82

    # Order passed to ``scipy.signal.butter`` for the band-pass design.
    _FILTER_ORDER: int = 4

    # Number of orientation components per source in "free" orientation.
    _N_ORIENT_FREE: int = 3

    def __init__(
        self,
        ERP_config: Optional[Dict[str, Any]] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """Initialize the SourceSimulator with parameters for simulating dipole sources.

        Parameters
        ----------
        ERP_config : Optional[Dict[str, Any]]
            Configuration dictionary for the ERP simulation parameters. Keys
            that are missing fall back to their defaults; if None (or empty),
            all defaults are used. The dictionary is copied, not stored by
            reference. Recognised keys and defaults:

            - tmin: -0.5 (start time of the simulated segment in seconds)
            - tmax: 0.5 (end time of the simulated segment in seconds, exclusive)
            - stim_onset: 0.0 (stimulus onset in seconds, on the same time
              axis as tmin/tmax; must lie within [tmin, tmax])
            - sfreq: 250 (sampling frequency in Hz)
            - fmin: 1 (lower edge of the band-pass filter in Hz)
            - fmax: 5 (upper edge of the band-pass filter in Hz)
            - amplitude: 1.0 (peak absolute amplitude of the ERP waveform
              before the final 1e-9 scaling, i.e. expressed in nAm)
            - random_erp_timing: True (if True, the start offset and duration
              of the ERP within the post-stimulus window are randomized)
            - erp_min_length: None (minimum ERP length in samples; if None,
              82 samples are used)
        logger : Optional[logging.Logger], optional
            Logger instance, by default None (a module-named logger is used).
        """
        self.ERP_config = self._merge_erp_config(ERP_config)
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        # Unit of the simulated source activity, taken from MNE's FIFF constants.
        # NOTE(review): FIFF unit constants look like integer codes rather than
        # strings -- confirm what downstream consumers of `source_units` expect.
        self.source_units = FIFF.FIFF_UNIT_AM

    @classmethod
    def _merge_erp_config(cls, ERP_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return a new dict: the defaults overridden by the entries of ``ERP_config``."""
        merged = dict(cls._DEFAULT_ERP_CONFIG)
        if ERP_config:
            merged.update(ERP_config)
        return merged

    def _time_axis(self) -> np.ndarray:
        """Return the sample times from tmin (inclusive) to tmax (exclusive)."""
        cfg = self.ERP_config
        return np.arange(cfg["tmin"], cfg["tmax"], 1.0 / cfg["sfreq"])

    def _simulate_erp_waveform(
        self,
        source_seed: int = 512,
    ) -> np.ndarray:
        """
        Generate a smoothed ERP-like waveform for a single source.

        White noise is band-pass filtered (zero-phase Butterworth), tapered
        with a Hanning window, normalized to unit peak, scaled by `amplitude`
        and by 1e-9 (nAm -> Am), and written into an otherwise all-zero time
        course at or after the stimulus onset.

        Parameters
        ----------
        source_seed : int
            Seed for the random number generator to ensure reproducibility of
            the ERP waveform generation. Default is 512.

        Returns
        -------
        np.ndarray
            The generated ERP signal of length n_times. It is all zeros when no
            ERP can be produced: the post-stimulus window is shorter than the
            minimum ERP length, the frequency band is invalid, the filter
            cannot be designed, the segment is too short to be filtered, or
            the filtered segment is flat.

        Raises
        ------
        ValueError
            If `stim_onset` lies outside [tmin, tmax].

        Notes
        -----
        - Samples before `stim_onset` are always zero (pre-stimulus baseline).
        - If `random_erp_timing` is True, the ERP duration and its start offset
          from `stim_onset` are randomized, while the ERP stays entirely inside
          the `stim_onset` to `n_times` interval.
        - If `random_erp_timing` is False, the ERP spans the entire interval
          from `stim_onset` to `n_times`.
        """
        cfg = self.ERP_config
        tmin, tmax = cfg["tmin"], cfg["tmax"]
        stim_onset = cfg["stim_onset"]
        sfreq = cfg["sfreq"]
        amplitude = cfg["amplitude"]

        if stim_onset < tmin or stim_onset > tmax:
            raise ValueError(f"stim_onset ({stim_onset}) is outside the time range [{tmin}, {tmax}]")

        rng = np.random.RandomState(source_seed)
        times = self._time_axis()
        n_times = len(times)
        waveform = np.zeros(n_times)

        # First sample at or after the stimulus; n_times means "no such sample".
        post_stim = np.flatnonzero(times >= stim_onset)
        onset_sample = int(post_stim[0]) if post_stim.size else n_times

        min_len = cfg["erp_min_length"]
        if min_len is None:
            min_len = self._DEFAULT_MIN_ERP_LEN

        available = n_times - onset_sample
        if available < min_len:
            # Not enough post-stimulus samples for a meaningful ERP.
            return waveform

        if cfg["random_erp_timing"]:
            # Duration in [min_len, available]; randint's upper bound is exclusive.
            duration = rng.randint(low=min_len, high=available + 1)
            self.logger.debug(f"Randomized ERP duration: {duration} samples")
            # Largest offset that still keeps the whole ERP inside the window.
            offset = rng.randint(0, available - duration + 1)
            self.logger.debug(f"Randomized ERP start offset from onset: {offset} samples")
            start = onset_sample + offset
        else:
            duration = available
            start = onset_sample

        # Normalized band edges, clipped strictly inside (0, 1) as butter requires.
        epsilon = 1e-9
        low = max(epsilon, cfg["fmin"] / (sfreq / 2))
        high = min(1.0 - epsilon, cfg["fmax"] / (sfreq / 2))
        if low >= high:
            self.logger.warning(
                "Invalid ERP frequency band (fmin=%s, fmax=%s, sfreq=%s); returning a silent waveform.",
                cfg["fmin"], cfg["fmax"], sfreq,
            )
            return waveform
        try:
            b, a = butter(self._FILTER_ORDER, [low, high], btype='band')
        except ValueError as err:
            self.logger.warning(
                "Band-pass filter design failed (%s); returning a silent waveform.", err
            )
            return waveform

        # filtfilt pads each edge with 3 * max(len(a), len(b)) samples by default
        # and raises ValueError unless the input is strictly longer than that.
        padlen = 3 * max(len(a), len(b))
        if duration <= padlen:
            self.logger.warning(
                "ERP segment of %d samples is too short to filter (needs more than %d); "
                "returning a silent waveform.",
                duration, padlen,
            )
            return waveform

        noise = rng.randn(duration)
        self.logger.debug(f"Generated white noise for ERP with {duration} samples.")

        segment = filtfilt(b, a, noise)
        segment *= np.hanning(duration)

        peak = np.max(np.abs(segment))
        if peak < 1e-9:
            # Flat segment: avoid dividing by (almost) zero.
            return waveform
        # Kept as three separate in-place steps so results stay bit-identical.
        segment /= peak
        segment *= amplitude
        segment *= 1e-9  # nAm -> Am

        end = start + len(segment)
        # Placement is expected to be in bounds by construction; checked defensively.
        if start < n_times and end <= n_times:
            waveform[start:end] = segment
        self.logger.debug(f"ERP waveform generated with shape: {waveform.shape}")
        return waveform

    def _simulate_source_time_courses(
        self,
        orientation_type: str = "fixed",
        n_sources: int = 100,
        nnz: int = 5,
        trial_seed: int = 256,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate simulated source time courses for a single trial.

        `nnz` distinct sources are drawn at random; each gets its own ERP
        waveform from `_simulate_erp_waveform`, seeded from the trial's random
        stream. All other sources stay zero.

        Parameters
        ----------
        orientation_type : str
            Orientation of the sources, either "fixed" or "free". Default is "fixed".
        n_sources : int
            Total number of sources to simulate. Default is 100.
        nnz : int
            Number of non-zero (active) sources in the trial. Must be between
            0 and `n_sources`. Default is 5.
        trial_seed : int
            Seed for the random number generator to ensure reproducibility of
            the source activity. Default is 256.

        Returns
        -------
        x : np.ndarray
            Simulated source activity array.
            - Shape (n_sources, n_times) for "fixed" orientation.
            - Shape (n_sources, 3, n_times) for "free" orientation.
        active_indices : np.ndarray
            Indices of the sources that were activated in this trial
            (in draw order, not sorted).

        Raises
        ------
        ValueError
            If `orientation_type` is not "fixed" or "free", or if `nnz` is
            negative or larger than `n_sources`.

        Notes
        -----
        - For "fixed" orientation, each active source has a single time course.
        - For "free" orientation, the waveform is multiplied by a random
          unit-norm vector of three orientation coefficients.
        """
        if orientation_type not in ("fixed", "free"):
            raise ValueError("Invalid orientation_type. Choose 'fixed' or 'free'.")
        if not 0 <= nnz <= n_sources:
            raise ValueError(
                f"nnz ({nnz}) must be between 0 and n_sources ({n_sources})."
            )

        rng = np.random.RandomState(trial_seed)
        n_times = len(self._time_axis())
        is_free = orientation_type == "free"
        n_orient = self._N_ORIENT_FREE

        active_indices = rng.choice(n_sources, size=nnz, replace=False)
        shape = (n_sources, n_orient, n_times) if is_free else (n_sources, n_times)
        x = np.zeros(shape)

        for src_idx in active_indices:
            # Derive an independent seed for this source from the trial stream.
            # NOTE(review): where NumPy's default integer is 32-bit, this upper
            # bound is presumably out of range for randint -- confirm if the
            # code must run on such a platform.
            source_seed = rng.randint(low=0, high=2**32 - 1)
            self.logger.debug(f"Generating ERP for source index {src_idx} with seed {source_seed}")
            erp_waveform = self._simulate_erp_waveform(source_seed=source_seed)

            if not is_free:
                x[src_idx, :] = erp_waveform
                continue

            # Random orientation, normalized to unit length.
            orient_coeffs = rng.randn(n_orient)
            norm_orient = np.linalg.norm(orient_coeffs)
            if norm_orient < 1e-9:
                orient_coeffs = np.array([1.0, 0.0, 0.0])  # degenerate draw: default orientation
            else:
                orient_coeffs /= norm_orient
            # Row j of the result is orient_coeffs[j] * erp_waveform.
            x[src_idx] = orient_coeffs[:, np.newaxis] * erp_waveform

        self.logger.debug(f"Simulated source time courses with shape: {x.shape}")
        self.logger.debug(f"Active source indices: {active_indices}")
        return x, active_indices

    def simulate(
        self,
        orientation_type: str = "fixed",
        n_sources: int = 100,
        nnz: int = 5,
        n_trials: int = 1,
        global_seed: int = 42,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Simulate multiple trials of source time courses.

        One seed per trial is derived from `global_seed`, and each trial is
        generated independently from its own seed, so the output is
        reproducible for a given `global_seed` and configuration.

        Parameters
        ----------
        orientation_type : str
            Orientation of the sources, either "fixed" or "free". Default is "fixed".
        n_sources : int
            Total number of sources to simulate. Default is 100.
        nnz : int
            Number of non-zero (active) sources in each trial. Must be less
            than or equal to `n_sources`. Default is 5.
        n_trials : int
            Number of trials to simulate. Default is 1.
        global_seed : int
            Seed for the random number generator to ensure reproducibility
            across trials. Default is 42.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            - x_all_trials : np.ndarray
                Simulated source time courses, stacked over trials:
                - fixed: (n_trials, n_sources, n_times)
                - free: (n_trials, n_sources, 3, n_times)
            - active_indices_all_trials : np.ndarray
                Array of shape (n_trials, nnz) containing indices of active
                sources per trial.

        Raises
        ------
        ValueError
            Propagated from the per-trial simulation for an invalid
            `orientation_type`, `nnz`, or `stim_onset`.
        """
        source_rng = np.random.RandomState(global_seed)
        trial_seeds = source_rng.randint(0, 2**32 - 1, size=n_trials)

        x_per_trial = []
        active_per_trial = []
        for i, seed in enumerate(trial_seeds):
            self.logger.debug(f"Simulating trial {i + 1}/{n_trials} with seed {seed}")
            x, active_indices = self._simulate_source_time_courses(
                orientation_type=orientation_type,
                n_sources=n_sources,
                nnz=nnz,
                trial_seed=seed,
            )
            x_per_trial.append(x)
            active_per_trial.append(active_indices)

        x_all_trials = np.array(x_per_trial)
        active_indices_all_trials = np.array(active_per_trial)

        self.logger.info(f"Completed simulating source time courses for {n_trials} trials.")
        self.logger.info(f"Shape of source time courses of all trials {n_trials} trials: {x_all_trials.shape}")
        self.logger.info(f"Shape of active indices for all {n_trials} trials: {active_indices_all_trials.shape}")
        # One log line per trial listing its active sources.
        self.logger.info("Active indices for all trials:")
        for i, indices in enumerate(active_indices_all_trials):
            self.logger.info(f"  Trial {i+1}: {indices}")
        return x_all_trials, active_indices_all_trials
# (stray "[docs]" link label from the rendered HTML source view; not part of the code)
def main():
    """Run a small demo: simulate a few fixed-orientation trials and plot them.

    Logging goes to the console and to ``simulation.log``. Four trials with
    10 sources (5 active) are simulated, then the first trial and all trials
    are plotted through ``calibrain.Visualizer`` without showing the figures.
    """
    # Imported here so the plotting dependency is only needed for the demo.
    from calibrain import Visualizer

    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler("simulation.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        handlers=[console_handler, file_handler],
    )
    log = logging.getLogger("SourceSimulator")

    erp_settings = dict(
        tmin=-0.5,
        tmax=0.5,
        stim_onset=0.0,
        sfreq=250,
        fmin=1,
        fmax=5,
        amplitude=10.0,
        random_erp_timing=True,
        erp_min_length=None,
    )
    shown_trial = 0

    simulator = SourceSimulator(ERP_config=erp_settings, logger=log)
    print(f"Default units for source activity: {simulator.source_units}")

    sources, active = simulator.simulate(
        orientation_type="fixed",
        n_sources=10,
        nnz=5,
        n_trials=4,
        global_seed=42,
    )
    log.info("Simulation complete.")

    plotter = Visualizer(base_save_path="testViz", logger=log)

    # Arguments common to both figures.
    shared_kwargs = dict(
        ERP_config=erp_settings,
        x=sources,
        active_indices=active,
        units=simulator.source_units,
        save_dir="data_simulation",
        show=False,
    )
    # (trial selector, figure title, output file stem): one trial, then all trials.
    figures = (
        (shown_trial, f"Source Trial {shown_trial+1}", f"src_trial_{shown_trial+1}"),
        (None, "Source Trials (All)", "src_trials_all"),
    )
    for which_trial, heading, stem in figures:
        plotter.plot_source_signals(
            trial_idx=which_trial,
            title=heading,
            file_name=stem,
            **shared_kwargs,
        )


# Run the demo only when the module is executed as a script.
if __name__ == "__main__":
    main()