Source code for calibrain.benchmark

import datetime
import os
import logging
from typing import Optional
from zipfile import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import ParameterGrid
from sklearn.utils import check_random_state
from itertools import product
import mne
from itertools import combinations
from scipy.stats import wishart
from matplotlib.patches import Ellipse
from scipy.stats import chi2

from calibrain import MetricEvaluator
from calibrain import SourceEstimator, SensorSimulator, SourceSimulator, UncertaintyEstimator, Visualizer, LeadfieldBuilder, gamma_map, eloreta
# from calibrain.leadfield_builder import LeadfieldBuilder
from calibrain.utils import inspect_object
from mne.io.constants import FIFF

class Benchmark:
    """Benchmark a source-estimation solver over solver and data parameter grids.

    For every combination of solver parameters, data parameters and random
    seed, ``run`` loads a leadfield, simulates source and sensor trials, fits
    the solver on one trial, derives confidence intervals for the sources
    that are active in both the ground truth and the estimate, evaluates
    metrics, and produces figures in a per-run experiment directory.
    """

    # Default nesting order of parameter keys in the experiment directory tree.
    _DEFAULT_DIR_ORDER = (
        "subject",
        "solver",
        "init_gamma",
        "orientation_type",
        "alpha_SNR",
        "noise_type",
        "nnz",
        "seed",
    )
    # Parameter keys that never become part of the directory path.
    _PATH_EXCLUDED_KEYS = ("cov", "run_id")
    # Offset that decorrelates the noise seed stream from the source seed stream.
    _NOISE_SEED_OFFSET = 123456
    # ``RandomState`` only accepts seeds in ``[0, 2**32 - 1]``.
    _SEED_MODULUS = 2 ** 32

    def __init__(
        self,
        solver: callable,
        solver_param_grid: dict,
        data_param_grid: dict,
        ERP_config: dict,
        source_simulator: "SourceSimulator",
        leadfield_builder: "LeadfieldBuilder",
        sensor_simulator: "SensorSimulator",
        uncertainty_estimator: "UncertaintyEstimator",
        metric_evaluator: "MetricEvaluator",
        random_state=42,
        logger=None,
    ):
        """
        Initialize the Benchmark class.

        Parameters
        ----------
        solver : callable
            The solver function (e.g., gamma_map, eloreta).
        solver_param_grid : dict
            Grid of solver hyperparameters (e.g. noise_type, init_gamma).
        data_param_grid : dict
            Grid of data generation hyperparameters. ``run`` reads the keys
            ``subject``, ``nnz``, ``alpha_SNR`` and ``orientation_type``.
        ERP_config : dict
            Configuration for ERP simulation (forwarded to the visualizer).
        source_simulator : SourceSimulator
            Instance of SourceSimulator for generating source data.
        leadfield_builder : LeadfieldBuilder
            Instance of LeadfieldBuilder for generating leadfields.
        sensor_simulator : SensorSimulator
            Instance of SensorSimulator for generating data.
        uncertainty_estimator : UncertaintyEstimator
            Instance of UncertaintyEstimator for uncertainty estimation.
        metric_evaluator : MetricEvaluator
            Instance of MetricEvaluator for evaluating metrics.
        random_state : int, optional
            Random seed for reproducibility.
        logger : logging.Logger, optional
            Logger instance for logging messages. Falls back to the module
            logger when omitted.
        """
        self.solver = solver
        self.solver_param_grid = solver_param_grid
        self.data_param_grid = data_param_grid
        self.ERP_config = ERP_config
        self.source_simulator = source_simulator
        self.leadfield_builder = leadfield_builder
        self.sensor_simulator = sensor_simulator
        self.uncertainty_estimator = uncertainty_estimator
        self.metric_evaluator = metric_evaluator
        self.random_state = random_state
        self.logger = logger if logger else logging.getLogger(__name__)

    def create_experiment_directory(self, base_dir, params, desired_order=None):
        """
        Create a nested ``key=value`` directory tree for one experiment.

        Parameters
        ----------
        base_dir : str
            Base directory for the experiment.
        params : dict
            Dictionary of parameters. The keys ``cov`` and ``run_id`` are
            never used as path components.
        desired_order : iterable of str, optional
            Parameter keys in the order they should appear in the path. Keys
            missing from ``params`` are skipped. Defaults to
            ``subject, solver, init_gamma, orientation_type, alpha_SNR,
            noise_type, nnz, seed``.

        Returns
        -------
        experiment_dir : str
            Path to the experiment directory.

        Raises
        ------
        OSError
            If the directory tree cannot be created (logged, then re-raised).
        """
        # Values become directory names, so neutralize separators and spaces.
        remaining = {
            key: str(value).replace("/", "_").replace("\\", "_").replace(" ", "_")
            for key, value in params.items()
            if key not in self._PATH_EXCLUDED_KEYS
        }

        if desired_order is None:
            desired_order = self._DEFAULT_DIR_ORDER

        # Explicitly ordered keys first; popping guarantees each key is used once.
        path_components = []
        for key in desired_order:
            if key in remaining:
                path_components.append(f"{key}={remaining.pop(key)}")

        # Everything else follows, sorted by key for a deterministic layout.
        path_components.extend(
            f"{key}={value}" for key, value in sorted(remaining.items())
        )

        experiment_dir = os.path.join(base_dir, *path_components)

        try:
            os.makedirs(experiment_dir, exist_ok=True)
        except OSError as e:
            self.logger.error(
                f"Failed to create experiment directory: {experiment_dir}. Error: {e}"
            )
            raise

        self.logger.info(f"Experiment directory created: {experiment_dir}")
        return experiment_dir

    @classmethod
    def _derive_trial_seeds(cls, seed, n_trials):
        """Return ``(source_seeds, noise_seeds)``, one seed per trial, for ``seed``."""
        seed = int(seed)
        upper = cls._SEED_MODULUS - 1

        # dtype is explicit because the default C-long dtype cannot hold the
        # upper bound on platforms where C long is 32-bit.
        source_rng = np.random.RandomState(seed)
        source_seeds = source_rng.randint(0, upper, size=n_trials, dtype=np.int64)

        # Wrap the shifted seed: RandomState rejects values above 2**32 - 1.
        noise_rng = np.random.RandomState(
            (seed + cls._NOISE_SEED_OFFSET) % cls._SEED_MODULUS
        )
        noise_seeds = noise_rng.randint(0, upper, size=n_trials, dtype=np.int64)
        return source_seeds, noise_seeds

    def _intervals_for_matched_sources(
        self,
        x_avg_time,
        x_hat_avg_time,
        x_active_indices,
        x_hat_active_indices,
        posterior_cov,
        orientation_type,
        n_orient,
    ):
        """Return ``(ci_lower, ci_upper, empirical_coverage)`` for sources active in both sets."""
        n_levels = len(self.uncertainty_estimator.confidence_levels)

        # Boolean mask over the estimated active set: True where the source is
        # also active in the ground truth.
        matched_mask = np.isin(x_hat_active_indices, x_active_indices)

        if not np.any(matched_mask):
            self.logger.warning(
                "No intersection between true active sources and estimated active sources"
            )
            # Zero placeholders so downstream metrics and plots still receive arrays.
            ci_lower = np.zeros((n_levels, n_orient, 1))
            ci_upper = np.zeros((n_levels, n_orient, 1))
            empirical_coverage = np.zeros(n_levels)
            return ci_lower, ci_upper, empirical_coverage

        # Positions inside the estimated active set (used to slice posterior_cov).
        matched_relative_indices = np.where(matched_mask)[0]
        # Source indices themselves (used to slice x and x_hat).
        matched_absolute_indices = x_hat_active_indices[matched_mask]

        x_matched = x_avg_time[matched_absolute_indices]
        x_hat_matched = x_hat_avg_time[matched_absolute_indices]
        posterior_cov_matched = posterior_cov[
            np.ix_(matched_relative_indices, matched_relative_indices)
        ]

        self.logger.debug("x_matched shape: %s", x_matched.shape)
        self.logger.debug("x_hat_matched shape: %s", x_hat_matched.shape)
        self.logger.debug(
            "posterior_cov_matched shape: %s", posterior_cov_matched.shape
        )

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not (
            x_matched.shape[0]
            == x_hat_matched.shape[0]
            == posterior_cov_matched.shape[0]
        ):
            raise ValueError(
                "Inconsistent matched dimensions: "
                f"x={x_matched.shape}, x_hat={x_hat_matched.shape}, "
                f"posterior_cov={posterior_cov_matched.shape}"
            )

        ci_lower, ci_upper, _counts_within_ci, empirical_coverage = (
            self.uncertainty_estimator.get_confidence_intervals_data(
                x=x_matched,
                x_hat=x_hat_matched,
                posterior_cov=posterior_cov_matched,
                orientation_type=orientation_type,
            )
        )
        return ci_lower, ci_upper, empirical_coverage

    def run(
        self,
        nruns: int = 2,
        fig_path: str = "results/figures/uncertainty_analysis_figures",
        n_trials: int = 5,
    ):
        """
        Run benchmarking by iterating over combinations of solver and data parameters.

        Parameters
        ----------
        nruns : int, optional
            Number of random seeds drawn from ``random_state``; every
            solver/data parameter combination is run once per seed.
        fig_path : str, optional
            Base directory under which per-run experiment directories are created.
        n_trials : int, optional
            Number of trials simulated per run. Only the first trial is
            passed to the solver.

        Returns
        -------
        results : pd.DataFrame
            One row per run with ``run_id``, ``seed``, ``solver``, the solver
            and data parameters, whatever the metric evaluator stored, and
            ``active_indices_size``. A run that raised carries an
            ``error_message`` column instead of aborting the benchmark.
        """
        # -------------------------------------------------------------
        # 1. Generate seeds and initialize bookkeeping
        # -------------------------------------------------------------
        rng = check_random_state(self.random_state)
        # Explicit 64-bit dtype: the default C-long dtype cannot represent
        # ``high=2**32`` where C long is 32-bit.
        seeds = rng.randint(
            low=0, high=self._SEED_MODULUS, size=nruns, dtype=np.int64
        )

        results_list = []
        param_combinations = list(product(
            ParameterGrid(self.solver_param_grid),
            ParameterGrid(self.data_param_grid),
            seeds,
        ))
        n_combinations = len(param_combinations)
        solver_name = getattr(self.solver, "__name__", str(self.solver))

        # -------------------------------------------------------------
        # 2. Iterate over solver and data parameter combinations
        # -------------------------------------------------------------
        self.logger.info(f"Starting benchmark with {n_combinations} runs...")
        for run_id, (solver_params, data_params, seed) in enumerate(
            param_combinations, start=1
        ):
            seed = int(seed)
            orientation_type = data_params.get("orientation_type")
            n_orient = 3 if orientation_type == "free" else 1

            self.logger.info(f"[Run {run_id}/{n_combinations}] Seed: {seed}")
            self.logger.info(f"Solver: {solver_name} | Params: {solver_params}")
            self.logger.info(f"Data Params: {data_params}")

            this_result = {
                'run_id': run_id,
                "seed": seed,
                "solver": solver_name,
                **solver_params,
                **data_params,
            }

            try:
                # Seed derivation lives inside the try so that a bad seed
                # fails only this run, not the whole benchmark.
                global_source_seeds, global_noise_seeds = self._derive_trial_seeds(
                    seed, n_trials
                )

                # -------------------------------------------------------------
                # 3. Create directory
                # -------------------------------------------------------------
                experiment_dir = self.create_experiment_directory(
                    base_dir=fig_path,
                    params=this_result,
                    desired_order=self._DEFAULT_DIR_ORDER,
                )

                # -------------------------------------------------------------
                # 4. Get leadfield matrix
                # -------------------------------------------------------------
                self.logger.info("Building leadfield matrix...")
                L = self.leadfield_builder.get_leadfield(
                    subject=data_params['subject'],
                    orientation_type=orientation_type,
                    retrieve_mode="load"
                )
                n_sensors, n_sources = L.shape
                sensor_units = self.leadfield_builder.sensor_units

                # -------------------------------------------------------------
                # 5. Simulate source data
                # -------------------------------------------------------------
                self.logger.info("Simulating source trials...")
                x_trials, x_active_indices_trials = self.source_simulator.simulate(
                    orientation_type=orientation_type,
                    n_sources=n_sources,
                    nnz=data_params['nnz'],
                    n_trials=n_trials,
                    global_seed=global_source_seeds,
                )
                source_units = self.source_simulator.source_units

                # -------------------------------------------------------------
                # 6. Simulate sensor data
                # -------------------------------------------------------------
                self.logger.info("Simulating sensor trials...")
                y_clean_trials, y_noisy_trials, _noise_trials, noise_var_trials = \
                    self.sensor_simulator.simulate(
                        x_trials=x_trials,
                        L=L,
                        orientation_type=orientation_type,
                        alpha_SNR=data_params['alpha_SNR'],
                        n_trials=n_trials,
                        global_seed=global_noise_seeds,
                    )
                # Set units based on leadfield
                self.sensor_simulator.sensor_units = sensor_units

                # -------------------------------------------------------------
                # 7. Fit the source estimator and predict posterior mean & covariance
                # -------------------------------------------------------------
                # Only the first trial is processed.
                trial_idx = 0
                x_one_trial = x_trials[trial_idx]
                x_active_indices_one_trial = x_active_indices_trials[trial_idx]
                y_noisy_one_trial = y_noisy_trials[trial_idx]
                noise_var_one_trial = noise_var_trials[trial_idx]

                self.logger.info("Fitting source estimator...")
                source_estimator = SourceEstimator(
                    solver=self.solver,
                    solver_params=solver_params,
                    n_orient=n_orient,
                    logger=self.logger
                )
                source_estimator.fit(L, y_noisy_one_trial)
                x_hat_one_trial, x_hat_active_indices_one_trial, posterior_cov = \
                    source_estimator.predict(
                        y=y_noisy_one_trial, noise_var=noise_var_one_trial
                    )

                # -------------------------------------------------------------
                # 8. Estimate uncertainty (-> credible intervals)
                # -------------------------------------------------------------
                self.logger.info("Estimating uncertainty...")
                # TODO: check whether we still need to set keepdims=True.
                x_one_trial_avg_time = np.mean(x_one_trial, axis=1, keepdims=True)
                x_hat_one_trial_avg_time = np.mean(
                    x_hat_one_trial, axis=1, keepdims=True
                )

                ci_lower_active, ci_upper_active, empirical_coverage_active = \
                    self._intervals_for_matched_sources(
                        x_avg_time=x_one_trial_avg_time,
                        x_hat_avg_time=x_hat_one_trial_avg_time,
                        x_active_indices=x_active_indices_one_trial,
                        x_hat_active_indices=x_hat_active_indices_one_trial,
                        posterior_cov=posterior_cov,
                        orientation_type=orientation_type,
                        n_orient=n_orient,
                    )

                # -------------------------------------------------------------
                # 9. Evaluate metrics and store results
                # -------------------------------------------------------------
                self.metric_evaluator.evaluate_and_store_metrics(
                    current_results_dict=this_result,
                    metric_suffix="_active_indices",
                    empirical_coverage=empirical_coverage_active,
                    cov=posterior_cov,  # TODO: should this be the matched covariance?
                    x=x_one_trial_avg_time,
                    x_hat=x_hat_one_trial_avg_time,
                    orientation_type=orientation_type,
                    nnz=data_params.get("nnz"),
                    subject=data_params.get("subject"),
                    fwd_path=self.leadfield_builder.leadfield_dir
                )

                # --------------------------------------------------------------
                # 10. Visualization
                # --------------------------------------------------------------
                viz = Visualizer(base_save_path=experiment_dir, logger=self.logger)
                viz.plot_all(
                    x_trials=x_trials,
                    x_active_indices_trials=x_active_indices_trials,
                    x_hat_one_trial=x_hat_one_trial,
                    x_hat_active_indices_one_trial=x_hat_active_indices_one_trial,
                    y_clean_trials=y_clean_trials,
                    y_noisy_trials=y_noisy_trials,
                    trial_idx=trial_idx,
                    n_sources=n_sources,
                    subject=data_params.get("subject"),
                    # TODO: include MEG anatomical data
                    subjects_dir=mne.datasets.sample.data_path() / 'subjects',
                    fwd_path=self.leadfield_builder.leadfield_dir,
                    nnz=data_params.get("nnz"),
                    ERP_config=self.ERP_config,
                    sample_idx=200,
                    source_units=source_units,
                    sensor_units=sensor_units,
                    confidence_levels=self.uncertainty_estimator.confidence_levels,
                    empirical_coverages={
                        'active_indices': empirical_coverage_active
                    },
                    ci_lower=ci_lower_active,
                    ci_upper=ci_upper_active,
                    orientation_type=orientation_type,
                    result=this_result,
                    experiment_dir=experiment_dir,
                )

                this_result['active_indices_size'] = len(
                    x_hat_active_indices_one_trial
                )

            except Exception as e:
                # Per-run boundary: record the failure and move on to the next
                # parameter combination.
                self.logger.error(
                    f"Error during benchmarking run_id {run_id}: {e}",
                    exc_info=True,
                )
                this_result["error_message"] = str(e)

            results_list.append(this_result)
            self.logger.info(f"Completed run_id {run_id}")

        self.logger.info("Benchmarking completed.")
        return pd.DataFrame(results_list)