# Source code for calibrain.data_simulation

import numpy as np
from numpy.random import Generator
from scipy.stats import wishart
from pathlib import Path
import logging
import os
from typing import Optional, Tuple, Union, Dict, List, Any

import mne
import matplotlib.pyplot as plt

from calibrain import LeadfieldSimulator
from calibrain.utils import load_config

"""
Module for simulating brain activity data.

Provides the `DataSimulator` class to generate synthetic brain activity,
including source time courses, leadfield matrices (via loading, simulation,
or random generation), and sensor-level measurements with controllable noise (SNR).
Supports both fixed and free source orientations and includes visualization tools.
"""

class DataSimulator:
    """
    Simulates brain activity data including source activity, leadfield, and sensor measurements.

    Handles different leadfield generation modes (random, load, simulate) and
    noise addition based on a specified SNR. Supports fixed and free source
    orientations.

    Attributes
    ----------
    n_sensors : int
        Number of sensors. Updated based on the obtained leadfield.
    n_sources : int
        Number of sources. Updated based on the obtained leadfield.
    n_times : int
        Number of time points.
    nnz : int
        Number of non-zero (active) sources to simulate.
    orientation_type : str
        Orientation type ('fixed' or 'free').
    alpha_snr : float
        Target SNR defined as signal_norm / (signal_norm + noise_norm).
        Must lie in the interval (0, 1].
    noise_type : str
        Label for the noise generation method. It is stored but not read by
        any method of this class.
    seed : Optional[int]
        Random seed for reproducibility.
    logger : logging.Logger
        Logger instance.
    rng : np.random.Generator
        Random number generator instance.
    leadfield_mode : str
        Mode for leadfield generation ('load', 'simulate', 'random').
    leadfield_path : Optional[Path]
        Path to the leadfield file (used if mode='load').
    leadfield_config_path : Optional[Path]
        Path to config file (used if mode='simulate', and to obtain the MNE
        info used for visualization).
    """

    def __init__(
        self,
        n_sensors: int = 60,
        n_sources: int = 3000,
        n_times: int = 100,
        nnz: int = 3,
        orientation_type: str = "fixed",
        alpha_snr: float = 0.99,
        noise_type: str = "oracle",
        seed: Optional[int] = None,
        logger: Optional[logging.Logger] = None,
        rng: Optional[Generator] = None,
        leadfield_mode: str = "random",
        leadfield_path: Optional[Union[str, Path]] = None,
        leadfield_config_path: Optional[Union[str, Path]] = None,
    ):
        """
        Initialize the DataSimulator.

        Parameters
        ----------
        n_sensors : int, optional
            Initial number of sensors, by default 60. May be updated by leadfield.
        n_sources : int, optional
            Initial number of sources, by default 3000. May be updated by leadfield.
        n_times : int, optional
            Number of time points, by default 100.
        nnz : int, optional
            Number of non-zero sources, by default 3.
        orientation_type : str, optional
            Orientation type ('fixed' or 'free'), by default "fixed".
        alpha_snr : float, optional
            Target SNR = signal_norm / (signal_norm + noise_norm), by default 0.99.
        noise_type : str, optional
            Label for the noise model, by default "oracle". Stored only; the
            noise actually generated always uses a scaled-identity covariance.
        seed : Optional[int], optional
            Random seed for reproducibility. If None, uses default seeding.
        logger : Optional[logging.Logger], optional
            Logger instance. If None, a module-level logger is used.
        rng : Optional[Generator], optional
            NumPy random number generator. If None, one is created from `seed`.
        leadfield_mode : str, optional
            'load', 'simulate', or 'random', by default "random".
        leadfield_path : Optional[Union[str, Path]], optional
            Path to leadfield file (if mode='load'), by default None.
        leadfield_config_path : Optional[Union[str, Path]], optional
            Path to config file (if mode='simulate'), by default None.
        """
        self.n_sensors = n_sensors
        self.n_sources = n_sources
        self.n_times = n_times
        self.nnz = nnz
        self.orientation_type = orientation_type
        self.alpha_snr = alpha_snr
        self.noise_type = noise_type  # Stored for reference; not read elsewhere in this class
        self.seed = seed
        # Compare against None explicitly so a caller-supplied object is never
        # discarded because of its truthiness.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.rng = rng if rng is not None else np.random.default_rng(seed)
        self.leadfield_mode = leadfield_mode
        self.leadfield_path = Path(leadfield_path) if leadfield_path else None
        self.leadfield_config_path = Path(leadfield_config_path) if leadfield_config_path else None

        self.logger.info(
            f"DataSimulator initialized with orientation: {self.orientation_type}, "
            f"leadfield mode: {self.leadfield_mode}"
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _check_leadfield_shape(self, leadfield: np.ndarray, origin: str) -> None:
        """Raise ValueError if `leadfield` does not match `self.orientation_type`.

        `origin` ("Loaded" / "Simulated") only prefixes the error message.
        """
        expected_dimensions = 3 if self.orientation_type == "free" else 2
        if leadfield.ndim != expected_dimensions:
            raise ValueError(
                f"{origin} leadfield matrix dimension mismatch for orientation '{self.orientation_type}': "
                f"expected {expected_dimensions} dimensions, but got {leadfield.ndim}."
            )
        # A free-orientation leadfield must carry exactly three components per
        # source, otherwise the projection in _project_to_sensor_space fails.
        if self.orientation_type == "free" and leadfield.shape[-1] != 3:
            raise ValueError(
                f"{origin} leadfield matrix for orientation 'free' must have a last axis of size 3, "
                f"but got shape {leadfield.shape}."
            )

    @staticmethod
    def _find_active_sources(x: np.ndarray, orientation_type: str) -> np.ndarray:
        """Return indices of sources with at least one non-zero sample in `x`."""
        if orientation_type == "fixed":
            return np.where(np.any(x != 0, axis=-1))[0]
        if orientation_type == "free":
            # Active if any of the 3 components is non-zero at any time point.
            return np.where(np.any(x != 0, axis=(1, 2)))[0]
        raise ValueError(f"Unsupported orientation type: {orientation_type}")

    def _save_and_show(self, fig, save_path: Optional[str], show: bool, description: str) -> None:
        """Save `fig` to `save_path` (creating parent directories) and optionally show it."""
        if save_path:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            # Save through the figure object rather than pyplot's implicit
            # "current figure", which may have changed in the meantime.
            fig.savefig(save_path, bbox_inches="tight")
            self.logger.info(f"{description} saved to {save_path}")
        if show:
            plt.show()

    # ------------------------------------------------------------------
    # Simulation steps
    # ------------------------------------------------------------------
    def _get_leadfield(self) -> np.ndarray:
        """
        Get or generate the leadfield matrix based on the specified mode.

        Updates self.n_sensors and self.n_sources based on the obtained leadfield.

        Returns
        -------
        np.ndarray
            The leadfield matrix (L). Shape depends on orientation_type:
            - 'fixed': (n_sensors, n_sources)
            - 'free': (n_sensors, n_sources, 3)

        Raises
        ------
        ValueError
            If leadfield_mode is invalid, required paths are missing, or the
            loaded/simulated leadfield has unexpected dimensions/format.
        FileNotFoundError
            If leadfield_path does not exist when mode='load'.
        """
        leadfield: np.ndarray

        if self.leadfield_mode == "load":
            if not self.leadfield_path:
                raise ValueError(
                    "Path to the leadfield file (leadfield_path) must be provided when leadfield_mode='load'."
                )
            try:
                if not self.leadfield_path.exists():
                    raise FileNotFoundError(f"Leadfield file does not exist: {self.leadfield_path}")
                self.logger.info(f"Loading leadfield matrix from file: {self.leadfield_path}")
                with np.load(self.leadfield_path) as data:
                    if "leadfield" not in data:
                        raise ValueError(f"File {self.leadfield_path} does not contain 'leadfield' key.")
                    leadfield = data["leadfield"]
                self._check_leadfield_shape(leadfield, "Loaded")
                self.logger.info(f"Leadfield loaded with shape {leadfield.shape}")
            except (FileNotFoundError, ValueError) as e:
                self.logger.error(f"Failed to load leadfield matrix: {e}")
                raise

        elif self.leadfield_mode == "simulate":
            if not self.leadfield_config_path:
                raise ValueError(
                    "Path to the configuration file (leadfield_config_path) must be provided "
                    "when leadfield_mode='simulate'."
                )
            self.logger.info(
                f"Simulating leadfield matrix using LeadfieldSimulator with config: {self.leadfield_config_path}"
            )
            try:
                config = load_config(Path(self.leadfield_config_path))
                L_simulator = LeadfieldSimulator(config=config, logger=self.logger)
                leadfield = L_simulator.simulate()
                self.logger.info(f"Simulated leadfield matrix with shape {leadfield.shape}")
                self._check_leadfield_shape(leadfield, "Simulated")
            except Exception as e:
                # Log for context, then propagate unchanged to the caller.
                self.logger.error(f"Failed to simulate leadfield matrix: {e}")
                raise

        elif self.leadfield_mode == "random":
            self.logger.info(
                f"Generating a random leadfield matrix (n_sensors={self.n_sensors}, n_sources={self.n_sources})."
            )
            if self.orientation_type == "fixed":
                leadfield = self.rng.standard_normal((self.n_sensors, self.n_sources))
            else:
                leadfield = self.rng.standard_normal((self.n_sensors, self.n_sources, 3))
            self.logger.info(f"Random leadfield generated with shape {leadfield.shape}")

        else:
            raise ValueError(
                f"Invalid leadfield mode '{self.leadfield_mode}'. Options are 'load', 'simulate', or 'random'."
            )

        # Update n_sensors and n_sources based on the actual leadfield dimensions
        if leadfield.ndim == 2:  # Fixed
            self.n_sensors, self.n_sources = leadfield.shape
        elif leadfield.ndim == 3:  # Free
            self.n_sensors, self.n_sources, _ = leadfield.shape

        self.logger.info(f"Leadfield obtained. Updated n_sensors={self.n_sensors}, n_sources={self.n_sources}")
        return leadfield

    def _generate_source_time_courses(self) -> np.ndarray:
        """
        Generate synthetic source time courses in the source space.

        Only `self.nnz` sources, chosen at random without replacement, receive
        standard-normal activity; all other sources are zero.

        Returns
        -------
        np.ndarray
            Source time courses (x). Shape depends on orientation_type:
            - 'fixed': (n_sources, n_times)
            - 'free': (n_sources, 3, n_times)

        Raises
        ------
        ValueError
            If `self.orientation_type` is unsupported, or if `self.nnz` is
            negative or exceeds `self.n_sources`.
        """
        if self.orientation_type == "fixed":
            per_source_shape = (self.n_times,)
        elif self.orientation_type == "free":
            per_source_shape = (3, self.n_times)
        else:
            raise ValueError(f"Unsupported orientation type: {self.orientation_type}")

        # Sampling without replacement cannot draw more sources than exist;
        # fail with an explicit message instead of an opaque NumPy error.
        if not 0 <= self.nnz <= self.n_sources:
            raise ValueError(
                f"nnz must be between 0 and n_sources ({self.n_sources}), got {self.nnz}."
            )

        x = np.zeros((self.n_sources, *per_source_shape))
        idx = self.rng.choice(self.n_sources, size=self.nnz, replace=False)
        x[idx] = self.rng.standard_normal((self.nnz, *per_source_shape))
        return x

    def _project_to_sensor_space(self, L: np.ndarray, x: np.ndarray) -> np.ndarray:
        """
        Project the source activity to the sensor space using the leadfield matrix.

        Parameters
        ----------
        L : np.ndarray
            Leadfield matrix (µV / nAm).
            - 'fixed': Shape (n_sensors, n_sources).
            - 'free': Shape (n_sensors, n_sources, 3).
        x : np.ndarray
            Source activity (nAm).
            - 'fixed': Shape (n_sources, n_times).
            - 'free': Shape (n_sources, 3, n_times).

        Returns
        -------
        np.ndarray
            Sensor measurements (y_clean). Shape: (n_sensors, n_times).
            (µV / nAm) * nAm = µV

        Raises
        ------
        ValueError
            If `self.orientation_type` is unsupported.
        """
        if self.orientation_type == "fixed":
            # (n_sensors, n_sources) @ (n_sources, n_times) -> (n_sensors, n_times)
            return L @ x
        if self.orientation_type == "free":
            # Sum over source index 'm' and orientation index 'r':
            # (n_sensors, n_sources, 3) x (n_sources, 3, n_times) -> (n_sensors, n_times)
            return np.einsum("nmr,mrt->nt", L, x)
        raise ValueError(f"Unsupported orientation type: {self.orientation_type}")

    def _add_noise(self, y_clean: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Add scaled Gaussian noise to clean sensor measurements based on alpha_snr.

        The noise is drawn with covariance 1e-2 * I and then rescaled so that
        signal_norm / (signal_norm + noise_norm) equals `self.alpha_snr`
        (Frobenius norms).

        Parameters
        ----------
        y_clean : np.ndarray
            Clean sensor measurements, shape (n_sensors, n_times).

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            - y_noisy : Noisy sensor measurements, shape (n_sensors, n_times).
            - cov_scaled : Scaled noise covariance matrix, shape (n_sensors, n_sensors).
            - noise_scaled : Scaled noise added, shape (n_sensors, n_times).

        Raises
        ------
        ValueError
            If `self.alpha_snr` is not in the interval (0, 1].
        """
        # alpha_snr == 0 would divide by zero below, and values outside (0, 1]
        # have no meaning for a ratio signal / (signal + noise).
        if not 0.0 < self.alpha_snr <= 1.0:
            raise ValueError(f"alpha_snr must be in the interval (0, 1], got {self.alpha_snr}.")

        # Oracle noise covariance: scaled identity matrix
        oracle_cov = 1e-2 * np.eye(self.n_sensors)
        noise = self.rng.multivariate_normal(
            np.zeros(self.n_sensors), oracle_cov, size=self.n_times
        ).T

        signal_norm = np.linalg.norm(y_clean, "fro")
        noise_norm = np.linalg.norm(noise, "fro")

        if noise_norm == 0:
            self.logger.warning(
                "Initial noise norm is zero. Cannot scale noise based on SNR. Returning clean signal."
            )
            return y_clean, oracle_cov, noise  # Unscaled noise and covariance

        if signal_norm == 0:
            # No signal to calibrate against: keep the noise unscaled.
            self.logger.warning("Clean signal norm is zero. Noise scaling might be arbitrary.")
            snr_scaling_factor = 1.0
        else:
            snr_scaling_factor = ((1 - self.alpha_snr) / self.alpha_snr) * (signal_norm / noise_norm)

        noise_scaled = noise * snr_scaling_factor
        # Scaling the noise by c scales its covariance by c**2.
        cov_scaled = oracle_cov * snr_scaling_factor ** 2
        y_noisy = y_clean + noise_scaled

        total_norm = signal_norm + np.linalg.norm(noise_scaled, "fro")
        actual_alpha = signal_norm / total_norm if total_norm > 0 else 0
        self.logger.info(f"Target alpha_snr: {self.alpha_snr:.4f}, Actual alpha_snr: {actual_alpha:.4f}")

        return y_noisy, cov_scaled, noise_scaled

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def simulate(
        self,
        visualize: bool = True,
        save_path: str = "results/figures/data_sim/",
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Run the full data simulation pipeline.

        Steps:
        1. Get leadfield matrix (load, simulate, or random).
        2. Generate source time courses.
        3. Project sources to sensor space (clean measurements).
        4. Add noise based on SNR.
        5. Optionally visualize results.

        Parameters
        ----------
        visualize : bool, optional
            Whether to generate and save visualization plots, by default True.
            The signal and leadfield-matrix figures are always produced. The
            topomap figure additionally needs MNE info, which is obtained from
            `leadfield_config_path`; if that path is missing or loading fails,
            the topomap is skipped and a default sampling frequency of 100 Hz
            is used for the time axis.
        save_path : str, optional
            Base directory to save visualization figures, by default
            "results/figures/data_sim/".

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
            - y_noisy : Noisy sensor measurements (n_sensors, n_times).
            - L : Leadfield matrix. (n_sensors, n_sources) for 'fixed';
              reshaped to (n_sensors, n_sources * 3) for 'free'.
            - x : Ground truth source activity. Shape depends on orientation type.
            - cov_scaled : Scaled noise covariance matrix (n_sensors, n_sensors).
            - noise_scaled : Scaled noise added (n_sensors, n_times).
        """
        L_orig = self._get_leadfield()
        x = self._generate_source_time_courses()
        y_clean = self._project_to_sensor_space(L_orig, x)
        y_noisy, cov_scaled, noise_scaled = self._add_noise(y_clean)

        if visualize:
            info = None
            if self.leadfield_config_path is not None:
                try:
                    config = load_config(Path(self.leadfield_config_path))
                    info = LeadfieldSimulator(config=config, logger=self.logger).handle_info()
                except Exception as e:
                    # Visualization is best-effort: fall back to defaults.
                    self.logger.warning(
                        f"Could not load info from config {self.leadfield_config_path} for visualization: {e}"
                    )
            else:
                self.logger.info(
                    "No leadfield_config_path provided; MNE info is unavailable for visualization."
                )

            figure_dir = Path(save_path)
            figure_dir.mkdir(parents=True, exist_ok=True)

            self.visualize_signals(
                x=x,
                y_clean=y_clean,
                y_noisy=y_noisy,
                nnz_to_plot=self.nnz,
                sfreq=info["sfreq"] if info else 100.0,  # Default sfreq if info missing
                max_sensors=3,
                plot_sensors_together=False,
                show=False,
                save_path=str(figure_dir / "data_sim.png"),
            )
            self.visualize_leadfield(
                L_orig,
                orientation_type=self.orientation_type,
                save_path=str(figure_dir / "leadfield_matrix.png"),
                show=False,
            )
            if info:
                self.visualize_leadfield_topomap(
                    leadfield_matrix=L_orig,
                    info=info,
                    x=x,
                    orientation_type=self.orientation_type,
                    title="Leadfield Topomap for Active (Nonzero) Sources",
                    save_path=str(figure_dir / "leadfield_topomap.png"),
                    show=False,
                )
            else:
                self.logger.info("Skipping leadfield topomap visualization due to missing MNE info.")

        # Flatten the free-orientation leadfield to 2-D for the returned value.
        L = L_orig
        if self.orientation_type == "free":
            self.logger.debug(
                "Reshaping free orientation leadfield from (sensors, sources, 3) to (sensors, sources*3)"
            )
            L = L_orig.reshape(L_orig.shape[0], -1)

        return y_noisy, L, x, cov_scaled, noise_scaled

    def visualize_signals(
        self,
        x: np.ndarray,
        y_clean: np.ndarray,
        y_noisy: np.ndarray,
        active_sources: Optional[np.ndarray] = None,
        nnz_to_plot: int = -1,
        sfreq: float = 100.0,
        max_sensors: int = 3,
        plot_sensors_together: bool = False,
        shift: float = 20.0,
        figsize: Tuple[float, float] = (14, 10),
        save_path: Optional[str] = 'results/figures/data_sim.png',
        show: bool = False,
    ) -> None:
        """
        Visualize source activity and sensor measurements.

        Plots active source time courses and compares clean vs. noisy sensor signals.

        Parameters
        ----------
        x : np.ndarray
            Source activity. Shape depends on orientation type.
        y_clean : np.ndarray
            Clean sensor measurements (n_sensors, n_times).
        y_noisy : np.ndarray
            Noisy sensor measurements (n_sensors, n_times).
        active_sources : Optional[np.ndarray], optional
            Indices of non-zero (active) sources. If None, they are determined
            from x, by default None.
        nnz_to_plot : int, optional
            Number of non-zero sources to plot. If -1, plot all non-zero
            sources found, by default -1. When fewer are requested than are
            active, a random subset is drawn from `self.rng`.
        sfreq : float, optional
            Sampling frequency in Hz, by default 100.0.
        max_sensors : int, optional
            Maximum number of sensors to plot, by default 3.
        plot_sensors_together : bool, optional
            If True, plot all sensors on the same subplot. If False, stack
            plots vertically, by default False.
        shift : float, optional
            Vertical shift between sensors when plotting together, by default 20.0.
        figsize : Tuple[float, float], optional
            Figure size for the plot, by default (14, 10).
        save_path : Optional[str], optional
            Path to save the figure. If None, the figure is not saved,
            by default 'results/figures/data_sim.png'.
        show : bool, optional
            If True, display the plot, by default False.

        Raises
        ------
        ValueError
            If `active_sources` is None and `self.orientation_type` is unsupported.
        """
        n_times = y_clean.shape[1]
        times = np.linspace(0, (n_times - 1) / sfreq, n_times) if n_times > 1 else np.array([0])

        if active_sources is None:
            active_sources = self._find_active_sources(x, self.orientation_type)

        if nnz_to_plot != -1 and len(active_sources) > nnz_to_plot:
            plot_indices = self.rng.choice(active_sources, nnz_to_plot, replace=False)
            self.logger.info(
                f"Plotting {nnz_to_plot} randomly selected active sources out of {len(active_sources)}."
            )
        else:
            plot_indices = active_sources
            nnz_to_plot = len(plot_indices)  # Update actual number plotted

        # Common y-limits so stacked sensor panels are directly comparable.
        y_min = min(y_clean.min(), y_noisy.min())
        y_max = max(y_clean.max(), y_noisy.max())
        y_range = y_max - y_min if y_max > y_min else 1.0  # Avoid zero range

        num_sensors_to_plot = min(max_sensors, y_clean.shape[0])
        total_plots = 1 + (1 if plot_sensors_together else num_sensors_to_plot)

        fig, axes = plt.subplots(
            total_plots,
            1,
            figsize=figsize,
            gridspec_kw={"height_ratios": [1] * total_plots},
            sharex=True,
        )
        try:
            # plt.subplots returns a bare Axes when only one panel is requested.
            axes = np.atleast_1d(axes)

            ax_sources = axes[0]
            if self.orientation_type == "fixed":
                for i in plot_indices:
                    ax_sources.plot(times, x[i], label=f"Source {i}")
            elif self.orientation_type == "free":
                for i in plot_indices:
                    # Plot the norm across the 3 orientation components.
                    source_norm = np.linalg.norm(x[i], axis=0)
                    ax_sources.plot(times, source_norm, label=f"Source {i} (Norm)")
            ax_sources.set_title(f"{nnz_to_plot} Active Simulated Source Activity")
            ax_sources.set_ylabel("Amplitude (a.u.)")
            ax_sources.grid(True)
            if len(plot_indices) > 0:
                # Skip the legend when nothing was drawn to avoid an empty box.
                ax_sources.legend(loc='center left', bbox_to_anchor=(1, 0.5))

            sensor_axes = axes[1:]
            if plot_sensors_together:
                ax_sensors = sensor_axes[0]
                for i in range(num_sensors_to_plot):
                    offset = i * shift  # Stack traces vertically by `shift`
                    ax_sensors.plot(times, y_clean[i] + offset, label=f"Clean (Sensor {i})", linewidth=1.5)
                    ax_sensors.plot(
                        times, y_noisy[i] + offset, label=f"Noisy (Sensor {i})", alpha=0.8, linewidth=1
                    )
                ax_sensors.set_title("Sensor Measurements")
                ax_sensors.set_ylabel("Amplitude (a.u.)")
                ax_sensors.grid(True)
                ax_sensors.legend(loc='center left', bbox_to_anchor=(1, 0.5))
            else:
                for idx, ax_sens in enumerate(sensor_axes):
                    ax_sens.plot(times, y_clean[idx], label="Clean", linewidth=1.5)
                    ax_sens.plot(times, y_noisy[idx], label="Noisy", alpha=0.8, linewidth=1)
                    ax_sens.set_title(f"Sensor {idx}")
                    ax_sens.set_ylabel("Amplitude (a.u.)")
                    ax_sens.set_ylim(y_min - 0.1 * y_range, y_max + 0.1 * y_range)
                    ax_sens.grid(True)
                    ax_sens.legend(loc='center left', bbox_to_anchor=(1, 0.5))

            axes[-1].set_xlabel("Time (s)")
            fig.tight_layout(rect=[0, 0, 0.85, 1])  # Leave room for the legends

            self._save_and_show(fig, save_path, show, "Signals visualization")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

    def visualize_leadfield(
        self,
        leadfield_matrix: np.ndarray,
        orientation_type: str = "fixed",
        save_path: Optional[str] = None,
        show: bool = False,
    ) -> None:
        """
        Visualize the leadfield matrix as a heatmap.

        This method is best-effort: invalid input and any error raised while
        plotting are logged via `self.logger.error` and not propagated.

        Parameters
        ----------
        leadfield_matrix : np.ndarray
            The leadfield matrix.
            - 'fixed': Shape (n_sensors, n_sources).
            - 'free': Shape (n_sensors, n_sources, 3).
        orientation_type : str, optional
            Orientation type ('fixed' or 'free'), by default "fixed".
        save_path : Optional[str], optional
            Path to save the figure. If None, not saved, by default None.
        show : bool, optional
            If True, display the plot, by default False.
        """
        if (
            leadfield_matrix is None
            or not isinstance(leadfield_matrix, np.ndarray)
            or leadfield_matrix.size == 0
        ):
            self.logger.error("Invalid leadfield matrix provided for visualization.")
            return

        fig = None
        try:
            if orientation_type == "fixed":
                if leadfield_matrix.ndim != 2:
                    raise ValueError(
                        f"Expected 2D leadfield for fixed orientation, got {leadfield_matrix.ndim}D"
                    )
                fig, ax = plt.subplots(figsize=(10, 8))
                im = ax.imshow(leadfield_matrix, aspect='auto', cmap='viridis', interpolation='nearest')
                fig.colorbar(im, ax=ax, label="Amplitude")
                ax.set_title("Leadfield Matrix (Fixed Orientation)")
                ax.set_xlabel("Sources")
                ax.set_ylabel("Sensors")

            elif orientation_type == "free":
                if leadfield_matrix.ndim != 3 or leadfield_matrix.shape[-1] != 3:
                    raise ValueError(
                        f"Expected 3D leadfield (..., 3) for free orientation, got shape {leadfield_matrix.shape}"
                    )
                # One colorbar serves all three panels, so they must share
                # the same colour limits for the bar to be valid for each.
                vmin = leadfield_matrix.min()
                vmax = leadfield_matrix.max()
                fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)
                im = None
                for i, orient in enumerate(["X", "Y", "Z"]):
                    im = axes[i].imshow(
                        leadfield_matrix[:, :, i],
                        aspect='auto',
                        cmap='viridis',
                        interpolation='nearest',
                        vmin=vmin,
                        vmax=vmax,
                    )
                    axes[i].set_title(f"Leadfield Matrix ({orient})")
                    axes[i].set_xlabel("Sources")
                axes[0].set_ylabel("Sensors")
                fig.colorbar(im, ax=axes, location="right", label="Amplitude", fraction=0.05, pad=0.04)

            else:
                raise ValueError("Invalid orientation type. Must be 'fixed' or 'free'.")

            fig.tight_layout()
            self._save_and_show(fig, save_path, show, "Leadfield matrix visualization")

        except Exception as e:
            self.logger.error(f"Failed during leadfield visualization: {e}")
        finally:
            if fig:
                plt.close(fig)

    def visualize_leadfield_topomap(
        self,
        leadfield_matrix: np.ndarray,
        info: "mne.Info",
        x: np.ndarray,
        orientation_type: str = "fixed",
        save_path: Optional[str] = None,
        title: Optional[str] = None,
        show: bool = False,
    ) -> None:
        """
        Visualize leadfield patterns as topomaps for active sources.

        This method is best-effort: invalid input and any error raised while
        plotting are logged via `self.logger.error` and not propagated.

        Parameters
        ----------
        leadfield_matrix : np.ndarray
            The leadfield matrix.
            - 'fixed': Shape (n_sensors, n_sources).
            - 'free': Shape (n_sensors, n_sources, 3).
        info : mne.Info
            MNE info object containing sensor locations.
        x : np.ndarray
            Source activity matrix to determine active sources.
            - 'fixed': Shape (n_sources, n_times).
            - 'free': Shape (n_sources, 3, n_times).
        orientation_type : str, optional
            Orientation type ('fixed' or 'free'), by default "fixed".
        save_path : Optional[str], optional
            Path to save the figure. If None, not saved, by default None.
        title : Optional[str], optional
            Title for the entire figure, by default None.
        show : bool, optional
            If True, display the plot, by default False.
        """
        if (
            leadfield_matrix is None
            or not isinstance(leadfield_matrix, np.ndarray)
            or leadfield_matrix.size == 0
        ):
            self.logger.error("Invalid leadfield matrix provided for topomap visualization.")
            return
        if x is None or not isinstance(x, np.ndarray) or x.size == 0:
            self.logger.error("Invalid source activity matrix provided for topomap visualization.")
            return
        if info is None or not isinstance(info, mne.Info):
            self.logger.error("Invalid MNE info object provided for topomap visualization.")
            return

        fig = None
        try:
            if orientation_type == "fixed":
                if leadfield_matrix.ndim != 2:
                    raise ValueError(
                        f"Expected 2D leadfield for fixed orientation, got {leadfield_matrix.ndim}D"
                    )
            elif orientation_type == "free":
                if leadfield_matrix.ndim != 3 or leadfield_matrix.shape[-1] != 3:
                    raise ValueError(
                        f"Expected 3D leadfield (..., 3) for free orientation, got shape {leadfield_matrix.shape}"
                    )
            else:
                raise ValueError("Invalid orientation type. Must be 'fixed' or 'free'.")

            active_sources = self._find_active_sources(x, orientation_type)
            if len(active_sources) == 0:
                self.logger.warning("No active sources found to visualize topomaps.")
                return

            # One sensor-space pattern per active source.
            if orientation_type == "fixed":
                patterns = [leadfield_matrix[:, src] for src in active_sources]
            else:
                # Collapse the 3 orientation components into their norm.
                patterns = [
                    np.linalg.norm(leadfield_matrix[:, src, :], axis=-1) for src in active_sources
                ]

            # Symmetric global colour limits so all topomaps are comparable.
            vmax = np.max(np.abs(patterns))
            vmin = -vmax

            n_active = len(active_sources)
            n_cols = min(5, n_active)  # Max 5 columns
            n_rows = int(np.ceil(n_active / n_cols))
            fig, axes = plt.subplots(
                n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3), squeeze=False
            )
            axes_flat = axes.flatten()

            im = None
            for ax, source_idx, pattern in zip(axes_flat, active_sources, patterns):
                im, _ = mne.viz.plot_topomap(
                    pattern,
                    info,
                    axes=ax,
                    cmap="RdBu_r",  # Diverging colormap
                    vlim=(vmin, vmax),
                    show=False,
                    contours=6,
                )
                ax.set_title(f"Source {source_idx}")

            # All maps share vlim, so any image can drive the single colorbar.
            fig.colorbar(im, ax=axes.ravel().tolist(), label='Leadfield Amplitude', shrink=0.6, aspect=10)

            # Hide unused subplots
            for unused_ax in axes_flat[n_active:]:
                unused_ax.axis("off")

            if title:
                fig.suptitle(title, fontsize=16)

            # Reserve headroom for the suptitle when present.
            fig.tight_layout(rect=[0, 0, 1, 0.95] if title else [0, 0, 1, 1])

            self._save_and_show(fig, save_path, show, "Leadfield topomap visualization")

        except Exception as e:
            self.logger.error(f"Failed during leadfield topomap visualization: {e}")
        finally:
            if fig:
                plt.close(fig)