import datetime
import json
import os
import logging
import h5py
from typing import Optional
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import ParameterGrid
from sklearn.utils import check_random_state
from itertools import product
from joblib import Parallel, delayed
import mne
from itertools import combinations
from scipy.stats import wishart
from matplotlib.patches import Ellipse
from scipy.stats import chi2
from calibrain import MetricEvaluator
from calibrain import SourceEstimator, SensorSimulator, SourceSimulator, UncertaintyEstimator, Visualizer, LeadfieldBuilder, UncertaintyCalibrator
from calibrain.utils import get_data_path, inspect_object
from mne.io.constants import FIFF
from calibrain.calibration_storage import save_calibration_record
from calibrain.calibration_dataset import EEG_COIL_TYPES, MEG_COIL_TYPES
# Quiet MNE's console chatter. This runs at import time, so it also applies
# inside joblib/loky worker processes, which are fresh interpreters that
# re-import this module before executing any task.
for _noisy_logger_name in ("mne", "mne.utils"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.ERROR)
del _noisy_logger_name
class DataGenerator:
    """Sweep parameter grids, simulate data, fit a solver and store the results.

    For every combination of ``solver_param_grid`` x ``data_param_grid`` x
    ``noise_param_grid`` the generator performs ``nruns`` experiments, each
    with its own random seed. An experiment loads a leadfield, simulates
    source and sensor data, fits ``solver`` through :class:`SourceEstimator`
    and, optionally, writes an HDF5 posterior summary. :meth:`run` returns one
    result row per experiment as a :class:`pandas.DataFrame`.
    """

    # ``noise_type`` values understood by ``_execute_single_run``:
    #   "oracle"                  -> variance of the simulated sensor noise
    #   "baseline"                -> variance of the pre-stimulus sensor data
    #   "adaptive_joint_learning" -> passed to the solver as None; the solver
    #                                reports its own ``noise_var`` estimate
    _ALLOWED_NOISE_TYPES = frozenset({"adaptive_joint_learning", "oracle", "baseline"})

    def __init__(
        self,
        solver: callable,
        solver_param_grid: dict,
        data_param_grid: dict,
        noise_param_grid: dict,
        ERP_config: dict,
        source_simulator: SourceSimulator,
        leadfield_builder: LeadfieldBuilder,
        sensor_simulator: SensorSimulator,
        save_posterior_stats: bool = True,
        posterior_dir: str | Path | None = None,
        random_state=42,
        logger=None,
    ):
        """
        Initialize the DataGenerator class.

        Parameters
        ----------
        solver : callable
            The solver function (e.g., ``gamma_map_sflex`` or ``BMN``).
        solver_param_grid : dict
            Grid of solver hyperparameters (e.g. noise_type, init_gamma).
        data_param_grid : dict
            Grid of data generation hyperparameters.
        noise_param_grid : dict
            Grid of noise-related hyperparameters. Each point must contain a
            ``noise_type`` key (see ``_ALLOWED_NOISE_TYPES``).
        ERP_config : dict
            Configuration for ERP simulation.
        source_simulator : SourceSimulator
            Instance of SourceSimulator for generating source data.
        leadfield_builder : LeadfieldBuilder
            Instance of LeadfieldBuilder for generating leadfields.
        sensor_simulator : SensorSimulator
            Instance of SensorSimulator for generating data.
        save_posterior_stats : bool, optional
            If True, persist per-run posterior summaries for later aggregation.
        posterior_dir : str or Path, optional
            Directory where posterior summary files should be stored. It is
            created immediately. Defaults to the per-run experiment directory
            when omitted.
        random_state : int, optional
            Random seed for reproducibility (used to draw per-run seeds).
        logger : logging.Logger, optional
            Logger instance for logging messages. Defaults to the module logger.
        """
        self.solver = solver
        self.solver_param_grid = solver_param_grid
        self.data_param_grid = data_param_grid
        self.noise_param_grid = noise_param_grid
        self.ERP_config = ERP_config
        self.source_simulator = source_simulator
        self.leadfield_builder = leadfield_builder
        self.sensor_simulator = sensor_simulator
        self.save_posterior_stats = save_posterior_stats
        self.posterior_dir = Path(posterior_dir) if posterior_dir else None
        if self.posterior_dir is not None:
            self.posterior_dir.mkdir(parents=True, exist_ok=True)
        self.random_state = random_state
        self.logger = logger if logger else logging.getLogger(__name__)

    def __getstate__(self):
        # Drop the logger before pickling (e.g. when joblib ships ``self`` to a
        # worker process); ``__setstate__`` substitutes the module logger.
        state = self.__dict__.copy()
        state['logger'] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        if self.logger is None:
            self.logger = logging.getLogger(__name__)

    @staticmethod
    def _ensure_worker_logging() -> None:
        """Install INFO-level root logging if no root handlers exist yet.

        joblib/loky workers are fresh interpreters whose root logger has no
        handlers (effective level WARNING), so INFO progress messages would be
        dropped. A console handler, plus a file handler when the
        ``CALIBRAIN_LOG_FILE`` environment variable is set, is installed in
        that case. When handlers already exist (e.g. in a parent process whose
        caller configured logging) the existing configuration, including its
        level, is left untouched.
        """
        root_logger = logging.getLogger()
        if root_logger.handlers:
            return
        log_handlers = [logging.StreamHandler()]
        log_file_env = os.environ.get("CALIBRAIN_LOG_FILE")
        if log_file_env:
            try:
                log_handlers.insert(0, logging.FileHandler(log_file_env, mode="a"))
            except OSError:
                # Best effort: fall back to console-only logging when the log
                # file cannot be opened (bad path, permissions, ...).
                pass
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=log_handlers,
        )

    def _log_run_progress(
        self,
        run_in_config: int,
        nruns_local: int,
        config_index: int,
        num_configs: int,
        run_id: int,
        total_runs: int,
        global_run_id: Optional[int],
        global_total_runs: Optional[int],
        solver_name: str,
        noise_type: Optional[str],
        nnz: Optional[int],
        alpha_snr: Optional[float],
    ) -> None:
        """Log a one-line INFO progress message for the current run.

        The "total" counter uses ``global_run_id`` / ``global_total_runs`` when
        given and falls back to ``run_id`` / ``total_runs`` otherwise.
        Parameters mirror the values computed in ``_execute_single_run``.
        Logging failures never propagate to the caller.
        """
        try:
            self.logger.info(
                "[run: %d/%d | config: %d/%d | total: %d/%d] %s | %s | %s NNZ | %s SNR",
                run_in_config,
                nruns_local,
                config_index,
                num_configs,
                global_run_id if global_run_id is not None else run_id,
                global_total_runs if global_total_runs is not None else total_runs,
                solver_name,
                noise_type,
                nnz,
                alpha_snr,
            )
        except Exception:
            # Tolerate logging errors; they must not break data generation.
            try:
                self.logger.debug("Progress logging failed for run %s", run_id)
            except Exception:
                pass

    def _create_experiment_directory(self, base_dir, params, desired_order=None):
        """
        Create (if needed) a nested directory whose path encodes ``params``.

        Each parameter becomes one path component ``key=value``. Keys listed
        in ``desired_order`` come first, in that order. The remaining keys
        follow, sorted alphabetically. ``run_id`` and ``global_run_id`` are
        never part of the path.

        Parameters
        ----------
        base_dir : str or path-like
            Base directory for the experiment.
        params : dict
            Dictionary of parameters. Values are converted with ``str`` and
            ``/``, ``\\`` and spaces are replaced by ``_``.
        desired_order : list of str, optional
            Parameter keys to place first in the directory structure. Defaults
            to ``["solver", "noise_type", "orientation_type", "alpha_SNR",
            "subject", "nnz", "seed"]``.

        Returns
        -------
        experiment_dir : str
            Path to the experiment directory.

        Raises
        ------
        OSError
            If the directory cannot be created (logged before re-raising).
        """
        excluded_keys = ("run_id", "global_run_id")
        # Path separators (and spaces) inside values would otherwise create
        # extra directory levels or awkward names.
        sanitized_params_for_path = {
            key: str(value).replace("/", "_").replace("\\", "_").replace(" ", "_")
            for key, value in params.items()
            if key not in excluded_keys
        }
        if desired_order is None:
            desired_order = ["solver", "noise_type", "orientation_type", "alpha_SNR", "subject", "nnz", "seed"]

        path_components = []
        for key in desired_order:
            if key in sanitized_params_for_path:
                # pop() so the key is not repeated among the remaining ones.
                path_components.append(f"{key}={sanitized_params_for_path.pop(key)}")
        path_components.extend(
            f"{key}={value}" for key, value in sorted(sanitized_params_for_path.items())
        )

        experiment_dir = os.path.join(base_dir, *path_components)
        try:
            os.makedirs(experiment_dir, exist_ok=True)
        except OSError as e:
            self.logger.error(f"Failed to create experiment directory: {experiment_dir}. Error: {e}")
            raise
        self.logger.debug(f"Experiment directory created: {experiment_dir}")
        return experiment_dir

    def _build_calibration_metadata(
        self,
        *,
        solver_name: str,
        solver_params: dict,
        data_params: dict,
        noise_params: dict,
        seed: int,
        experiment_dir: str | Path,
        global_run_id: int,
        config_index: int,
        run_in_config: int,
        nruns_local: int,
    ) -> dict:
        """Assemble the metadata dict stored alongside a calibration record.

        ``timestamp_utc`` is a naive ISO-8601 string expressed in UTC (no
        ``+00:00`` suffix), matching the format previously produced by
        ``datetime.utcnow()``.
        """
        # datetime.utcnow() is deprecated since Python 3.12. Build the same
        # naive-UTC value from an aware timestamp to keep the output format.
        timestamp = (
            datetime.datetime.now(datetime.timezone.utc)
            .replace(tzinfo=None)
            .isoformat()
        )
        return {
            "global_run_id": global_run_id,
            "config_index": config_index,
            "run_in_config": run_in_config,
            "nruns_per_config": nruns_local,
            "seed": seed,
            "timestamp_utc": timestamp,
            "solver": solver_name,
            "solver_params": dict(solver_params),
            "data_params": dict(data_params),
            "noise_params": dict(noise_params),
            "noise_type": noise_params.get("noise_type"),
            "subject": data_params.get("subject"),
            "nnz": data_params.get("nnz"),
            "alpha_SNR": data_params.get("alpha_SNR"),
            "orientation_type": data_params.get("orientation_type"),
            "experiment_dir": Path(experiment_dir).as_posix(),
        }

    def _persist_calibration_results(
        self,
        *,
        experiment_dir: str | Path,
        record_dir: str | Path,
        solver_name: str,
        solver_params: dict,
        data_params: dict,
        noise_params: dict,
        seed: int,
        global_run_id: int,
        config_index: int,
        run_in_config: int,
        nruns_local: int,
        pre_calibration: dict,
        post_calibration: dict,
    ) -> Optional[Path]:
        """Store a calibration record named ``calibration_run-<global_run_id>``.

        Delegates to ``save_calibration_record``. This is best effort: any
        failure is logged as a warning and ``None`` is returned instead of
        the record path.
        """
        record_dir = Path(record_dir)
        record_name = f"calibration_run-{global_run_id:05d}"
        metadata = self._build_calibration_metadata(
            solver_name=solver_name,
            solver_params=solver_params,
            data_params=data_params,
            noise_params=noise_params,
            seed=seed,
            experiment_dir=experiment_dir,
            global_run_id=global_run_id,
            config_index=config_index,
            run_in_config=run_in_config,
            nruns_local=nruns_local,
        )
        try:
            return save_calibration_record(
                output_dir=record_dir,
                record_name=record_name,
                metadata=metadata,
                pre_calibration=pre_calibration,
                post_calibration=post_calibration,
            )
        except Exception as exc:
            self.logger.warning(
                "Failed to store calibration record for global_run_id %s: %s",
                global_run_id,
                exc,
            )
            return None

    @staticmethod
    def _sanitize_metadata(metadata: dict | None) -> dict:
        """Recursively convert NumPy scalars/arrays, Paths and tuples to JSON-friendly values.

        Dicts are converted value-by-value (keys untouched), lists and tuples
        become lists, and any other value is returned unchanged. An empty or
        ``None`` input yields ``{}``.
        """
        if not metadata:
            return {}

        def _convert(value):
            if isinstance(value, np.generic):
                return value.item()
            if isinstance(value, np.ndarray):
                return value.tolist()
            if isinstance(value, Path):
                return value.as_posix()
            if isinstance(value, dict):
                return {k: _convert(v) for k, v in value.items()}
            if isinstance(value, (list, tuple)):
                return [_convert(v) for v in value]
            return value

        return {k: _convert(v) for k, v in metadata.items()}

    def _persist_posterior_summary(
        self,
        *,
        experiment_dir: str | Path,
        datasets: dict,
        metadata: dict | None = None,
        posterior_dir: str | Path | None = None,
        filename: str | None = None,
    ) -> Optional[Path]:
        """Write ``datasets`` (and optional metadata) to an HDF5 file.

        Parameters
        ----------
        experiment_dir : str or Path
            Fallback output directory when ``posterior_dir`` is not given.
        datasets : dict
            Mapping ``name -> array-like``. ``None`` values are skipped so a
            solver that does not return e.g. a posterior covariance does not
            prevent the rest of the summary from being written. Non-scalar
            arrays are gzip-compressed.
        metadata : dict, optional
            Stored as JSON in the ``metadata_json`` file attribute after
            ``_sanitize_metadata``.
        posterior_dir : str or Path, optional
            Output directory; created if missing.
        filename : str, optional
            Output file name, ``"posterior_summary.h5"`` by default.

        Returns
        -------
        Path or None
            Path of the written file, or ``None`` if writing failed (the
            failure is logged as a warning; this method never raises).
        """
        base_dir = Path(posterior_dir) if posterior_dir else Path(experiment_dir)
        summary_path = base_dir / (filename or "posterior_summary.h5")
        try:
            base_dir.mkdir(parents=True, exist_ok=True)
            with h5py.File(summary_path, "w") as handle:
                for name, value in datasets.items():
                    if value is None:
                        self.logger.debug(
                            "Skipping posterior dataset %r for %s: value is None",
                            name,
                            summary_path,
                        )
                        continue
                    data = value if isinstance(value, np.ndarray) else np.asarray(value)
                    if data.ndim == 0:
                        # Scalar datasets cannot carry compression filters.
                        handle.create_dataset(name, data=data)
                    else:
                        handle.create_dataset(name, data=data, compression="gzip")
                safe_metadata = self._sanitize_metadata(metadata)
                if safe_metadata:
                    handle.attrs["metadata_json"] = json.dumps(safe_metadata)
            return summary_path
        except Exception as exc:
            self.logger.warning(
                "Failed to store posterior summary at %s: %s", summary_path, exc
            )
            return None

    def _prepare_run_data(self, data_params: dict, seed: int) -> dict:
        """Load the leadfield and simulate sources and sensor data for one run.

        The source simulation uses ``seed`` directly. The sensor simulation
        uses a seed derived from it, so the two random streams differ.

        Returns
        -------
        dict
            Leadfield, source coordinates, sensor metadata, ``Q_basis``,
            simulated sources ``x`` with their active indices, clean/noisy
            sensor data, the noise and ``noise_eta`` returned by the sensor
            simulator, and the source units.
        """
        data_params = dict(data_params)
        orientation_type = data_params.get("orientation_type")
        leadfield_data = self.leadfield_builder.get_leadfield(
            subject=data_params['subject'],
            orientation_type=orientation_type,
            retrieve_mode="load",
            return_metadata=True,
        )
        L = leadfield_data.leadfield
        # Fixed orientation: (n_sensors, n_sources). Otherwise the first two
        # axes are (n_sensors, n_sources).
        if orientation_type == "fixed":
            _, n_sources = L.shape
        else:
            _, n_sources = L.shape[:2]

        self.sensor_simulator.set_sensor_metadata(
            kind=leadfield_data.sensor_kind,
            coil_type=leadfield_data.coil_type,
            units=leadfield_data.sensor_units,
            unitmult=leadfield_data.sensor_unitmult,
        )

        source_seed = seed
        # Compute on a Python int so the product cannot overflow or be
        # rejected by a fixed-width NumPy integer type.
        sensor_seed = (int(seed) * 26544) % (2**32)

        x, x_active_indices = self.source_simulator.simulate(
            n_sources=n_sources,
            nnz=data_params['nnz'],
            orientation_type=orientation_type,
            coil_type=leadfield_data.coil_type,
            seed=source_seed,
        )
        source_units = self.source_simulator.units
        source_unitmult = self.source_simulator.unitmult

        y_clean, y_noisy, noise, noise_eta = self.sensor_simulator.simulate(
            x=x,
            L=L,
            alpha_SNR=data_params['alpha_SNR'],
            sensor_white_noise_std=data_params['sensor_white_noise_std'],
            seed=sensor_seed,
        )
        return {
            "leadfield": L,
            "src_coords": leadfield_data.src_coords,
            "sensor_metadata": {
                "kind": leadfield_data.sensor_kind,
                "coil_type": leadfield_data.coil_type,
                "units": leadfield_data.sensor_units,
                "unitmult": leadfield_data.sensor_unitmult,
            },
            "Q_basis": leadfield_data.Q_basis,
            "x": x,
            "x_active_indices": x_active_indices,
            "y_clean": y_clean,
            "y_noisy": y_noisy,
            "noise": noise,
            "noise_eta": noise_eta,
            "source_units": source_units,
            "source_unitmult": source_unitmult,
        }

    def _estimate_baseline_noise_var(self, y_noisy) -> float:
        """Estimate the sensor noise variance from the pre-stimulus interval.

        The pre-stimulus length is ``int((stim_onset - tmin) * sfreq)``
        samples, taken from ``self.source_simulator.ERP_config``. If that is
        not positive, the whole trial is used (with a warning). The estimate
        is the mean over sensors of the per-sensor variance along time. It
        falls back to ``1.0`` when it is zero, non-finite or cannot be
        computed.
        """
        erp_config = self.source_simulator.ERP_config
        tmin = erp_config['tmin']
        stim_onset = erp_config['stim_onset']
        sfreq = erp_config['sfreq']
        pre_stimulus_onset = int((stim_onset - tmin) * sfreq)
        if pre_stimulus_onset <= 0:
            self.logger.warning(
                "Computed pre_stimulus_onset <= 0; using full trial for baseline estimation"
            )
            y_pre = y_noisy
        else:
            y_pre = y_noisy[:, :pre_stimulus_onset]

        try:
            baseline_noise_var = float(np.mean(np.std(y_pre, axis=1) ** 2))
        except Exception as exc:
            self.logger.debug("Baseline noise variance estimation failed: %s", exc)
            baseline_noise_var = None
        if not baseline_noise_var or not np.isfinite(baseline_noise_var):
            baseline_noise_var = 1.0
        return baseline_noise_var

    def _execute_single_run(
        self,
        run_id: int,
        nruns: int,
        total_runs: int,
        solver_params: dict,
        data_params: dict,
        noise_params: dict,
        seed: int,
        fig_path: str,
        global_run_id: Optional[int] = None,
        global_total_runs: Optional[int] = None,
    ) -> dict:
        """Run one simulate -> solve -> persist experiment.

        Parameters
        ----------
        run_id : int
            1-based index of this run among the ``total_runs`` of the current
            :meth:`run` call.
        nruns : int
            Number of repetitions per parameter configuration.
        total_runs : int
            Number of runs in the current :meth:`run` call.
        solver_params, data_params, noise_params : dict
            One point of each parameter grid. Copied; the caller's dicts are
            not mutated.
        seed : int
            Seed for the simulation of this run.
        fig_path : str
            Base directory under which the per-run experiment directory is
            created.
        global_run_id : int, optional
            Run index across all estimators; defaults to ``run_id``.
        global_total_runs : int, optional
            Total number of runs across all estimators (progress logging only).

        Returns
        -------
        dict
            Flat result row. Errors raised during the experiment are logged
            and stored under ``error_message`` instead of propagating, so one
            bad run does not abort a (parallel) sweep.
        """
        self._ensure_worker_logging()

        solver_params = dict(solver_params)
        data_params = dict(data_params)
        noise_params = dict(noise_params)
        solver_name = getattr(self.solver, "__name__", str(self.solver))
        orientation_type = data_params.get("orientation_type")
        solver_params['fwd_path'] = get_data_path() / '1284src_fwd' / data_params['subject']

        # Human-friendly, 1-based progress counters. max(1, ...) avoids a
        # division by zero for nruns <= 0.
        nruns_local = max(1, int(nruns))
        num_configs = max(1, (total_runs + nruns_local - 1) // nruns_local)  # ceil division
        config_index = (run_id - 1) // nruns_local + 1
        run_in_config = (run_id - 1) % nruns_local + 1

        noise_type = noise_params.get("noise_type")
        nnz = data_params.get("nnz")
        alpha_snr = data_params.get("alpha_SNR")
        if global_run_id is None:
            global_run_id = run_id
        self._log_run_progress(
            run_in_config=run_in_config,
            nruns_local=nruns_local,
            config_index=config_index,
            num_configs=num_configs,
            run_id=run_id,
            total_runs=total_runs,
            global_run_id=global_run_id,
            global_total_runs=global_total_runs,
            solver_name=solver_name,
            noise_type=noise_type,
            nnz=nnz,
            alpha_snr=alpha_snr,
        )
        self.logger.debug("Solver params: %s", solver_params)
        self.logger.debug("Noise params: %s", noise_params)
        self.logger.debug("Data params: %s", data_params)

        this_result = {
            'run_id': run_in_config,
            'global_run_id': global_run_id,
            "seed": seed,
            "solver": solver_name,
            # Use .get(): a missing noise_type is reported by the validation
            # inside the try block instead of crashing the whole sweep here.
            'noise_type': noise_type,
            **{k: v for k, v in solver_params.items() if k != "fwd_path"},
            **data_params,
        }

        try:
            # Fail fast, before any directory creation or expensive simulation.
            if noise_type not in self._ALLOWED_NOISE_TYPES:
                raise ValueError(
                    f"Invalid noise_type: {noise_type!r}. Allowed: {sorted(self._ALLOWED_NOISE_TYPES)}"
                )

            experiment_dir = self._create_experiment_directory(
                base_dir=fig_path,
                params=this_result,
                desired_order=[
                    "orientation_type", "solver", "noise_type", "nnz", "subject", "alpha_SNR", "seed"
                ],
            )

            prepared_data = self._prepare_run_data(data_params, seed)
            L = prepared_data["leadfield"]
            src_coords = prepared_data["src_coords"]
            sensor_meta = prepared_data["sensor_metadata"]
            sensor_kind = sensor_meta.get("kind")
            sensor_coil_type = sensor_meta.get("coil_type")
            sensor_units = sensor_meta.get("units")
            sensor_unitmult = sensor_meta.get("unitmult")

            # Persist sensor metadata in the per-run results so the CSV
            # manifest can be filtered without opening the H5 summaries.
            this_result["sensor_kind"] = sensor_kind
            this_result["coil_type"] = sensor_coil_type

            self.sensor_simulator.set_sensor_metadata(
                kind=sensor_kind,
                coil_type=sensor_coil_type,
                units=sensor_units,
                unitmult=sensor_unitmult,
            )

            if orientation_type == "fixed":
                _, n_sources = L.shape
            else:
                _, n_sources = L.shape[:2]

            x = prepared_data["x"]
            y_noisy = prepared_data["y_noisy"]
            noise = prepared_data["noise"]
            noise_eta = prepared_data["noise_eta"]
            # Only consumed by the disabled calibration/visualization step below.
            x_active_indices = prepared_data["x_active_indices"]
            y_clean = prepared_data["y_clean"]
            source_units = prepared_data["source_units"]
            source_unitmult = prepared_data["source_unitmult"]

            n_times = x.shape[-1]
            this_result["n_sources"] = int(n_sources)
            this_result["n_times"] = int(n_times)

            if orientation_type == "fixed":
                n_orient = 1
            else:
                if x.ndim != 3:
                    raise ValueError(
                        f"Free-orientation simulations must have shape (N, comps, T); got {x.shape}"
                    )
                n_orient = x.shape[1]

            solver_params['src_coords'] = src_coords

            if noise_type == 'oracle':
                self.logger.debug("Using oracle noise variance estimate")
                noise_var = float(np.var(noise))
            elif noise_type == 'baseline':
                self.logger.debug("Using baseline noise variance estimate")
                noise_var = self._estimate_baseline_noise_var(y_noisy)
                # Lazy %-formatting: a non-float noise_eta must not crash the run.
                self.logger.debug(
                    "Baseline noise variance (global run %s, config run %s): %.3e, eta: %.3e",
                    global_run_id,
                    run_in_config,
                    noise_var,
                    noise_eta,
                )
            else:  # 'adaptive_joint_learning': the solver estimates it jointly
                noise_var = None

            source_estimator = SourceEstimator(
                solver=self.solver,
                solver_params=solver_params,
                noise_var=noise_var,
                n_orient=n_orient,
                logger=self.logger
            )
            self.logger.debug("Fitting source estimator %s", solver_name)
            source_estimator.fit(L, y_noisy)
            solver_output = source_estimator.predict(y=y_noisy)

            x_hat = solver_output.get("posterior_mean_reshaped")
            if x_hat is None:
                x_hat = solver_output.get("posterior_mean")
            x_hat_active_indices = solver_output.get("active_indices")
            posterior_cov = solver_output.get("posterior_cov")
            noise_var = solver_output.get("noise_var")
            gamma = solver_output.get("gamma")

            # TODO: remove temporary plotting code
            # Diagnostic only: a missing history or a plotting failure must not
            # discard an otherwise successful run.
            err_gamma_hist = solver_output.get("err_gamma_hist")
            if noise_type == 'adaptive_joint_learning' and err_gamma_hist is not None:
                try:
                    plot_error_curves(
                        err_gamma=err_gamma_hist,
                        # err_lambda=solver_output["err_lambda_hist"],
                        title="Gamma errors (joint learning)",
                        save_path=os.path.join(experiment_dir, "gamma_lambda_errors.png"),
                    )
                except Exception as plot_exc:
                    self.logger.warning("Failed to plot gamma error curves: %s", plot_exc)

            this_result['gamma'] = gamma
            this_result["noise_var"] = noise_var

            # posterior_var = self.uncertainty_estimator.get_posterior_variance(
            #     posterior_cov=posterior_cov,
            #     orientation_type=orientation_type
            # )
            # x_avg_time = np.mean(x, axis=-1, keepdims=True)
            # x_hat_avg_time = np.mean(x_hat, axis=-1, keepdims=True)
            # n_times = x.shape[-1]
            # posterior_var_avg_time = posterior_var / n_times
            # posterior_std_avg_time = np.sqrt(np.maximum(posterior_var_avg_time, 0.0))
            # TODO: this is a temporary workaround to allow calibration of free orientation solvers using only the norm of the source estimates and their uncertainty.
            # for free orientation, reshape posterior covariance from (3N, 3N) to (N, N, 3, 3)
            # if orientation_type == "free":
            #     x_avg_time=np.linalg.norm(x_avg_time, axis=1, keepdims=False)
            #     x_hat_avg_time=np.linalg.norm(x_hat_avg_time, axis=1, keepdims=False)
            #     x = np.linalg.norm(x, axis=1, keepdims=False)
            #     x_hat = np.linalg.norm(x_hat, axis=1, keepdims=False)
            #     x_hat_active_indices = x_hat_active_indices[:x_hat_avg_time.shape[0]]

            this_result['active_indices_size'] = (
                len(x_hat_active_indices) if x_hat_active_indices is not None else 0
            )

            summary_path = None
            if self.save_posterior_stats:
                posterior_base = self.posterior_dir if self.posterior_dir is not None else Path(experiment_dir)
                summary_filename = f"posterior_summary_run{global_run_id:08d}_seed{seed}.h5"
                summary_arrays = {
                    "x_true": x,
                    "x_hat": x_hat,
                    "posterior_cov": posterior_cov,
                    "Q_basis": np.asarray(prepared_data.get("Q_basis"), dtype=float),
                }
                summary_metadata = {
                    "global_run_id": global_run_id,
                    "run_id": run_in_config,
                    "nruns": nruns,
                    "solver": solver_name,
                    "noise_type": noise_type,
                    "subject": data_params.get("subject"),
                    "orientation_type": orientation_type,
                    "nnz": data_params.get("nnz"),
                    "alpha_SNR": data_params.get("alpha_SNR"),
                    "n_sources": n_sources,
                    "n_times": n_times,
                    "seed": seed,
                    "experiment_dir": Path(experiment_dir).as_posix(),
                    "posterior_dir": Path(posterior_base).as_posix(),
                    "posterior_filename": summary_filename,
                    "sensor_kind": sensor_kind,
                    "coil_type": sensor_coil_type,
                    "sensor_units": sensor_units,
                    "sensor_unitmult": sensor_unitmult,
                }
                summary_path = self._persist_posterior_summary(
                    experiment_dir=experiment_dir,
                    datasets=summary_arrays,
                    metadata=summary_metadata,
                    posterior_dir=posterior_base,
                    filename=summary_filename,
                )
            if summary_path is not None:
                this_result["posterior_summary"] = summary_path.as_posix()

            # Calibration, metric evaluation and visualization are currently
            # disabled; the steps are kept below for re-enabling.
            # calibrator = UncertaintyCalibrator(
            #     uncertainty_estimator=self.uncertainty_estimator,
            #     metric_evaluator=self.metric_evaluator,
            # )
            # calibration_results = calibrator.calibrate(
            #     x_true=x_avg_time,
            #     x_hat=x_hat_avg_time,
            #     posterior_std=posterior_std_avg_time,
            # )
            # pre_calibration = calibration_results['pre_calibration']
            # post_calibration = calibration_results['post_calibration']
            # calibration_record_dir = Path("results") / "calibration_records"
            # record_path = self._persist_calibration_results(
            #     experiment_dir=experiment_dir,
            #     record_dir=calibration_record_dir,
            #     solver_name=solver_name,
            #     solver_params=solver_params,
            #     data_params=data_params,
            #     noise_params=noise_params,
            #     seed=seed,
            #     global_run_id=global_run_id,
            #     config_index=config_index,
            #     run_in_config=run_in_config,
            #     nruns_local=nruns_local,
            #     pre_calibration=pre_calibration,
            #     post_calibration=post_calibration,
            # )
            # if record_path is not None:
            #     this_result["calibration_record"] = record_path.as_posix()
            # metric_kwargs = dict(
            #     x=x_avg_time,
            #     x_hat=x_hat_avg_time,
            #     posterior_var=posterior_var_avg_time,
            #     orientation_type="fixed", # TODO: remove hardcoding
            #     nnz=data_params.get("nnz"),
            #     subject=data_params.get("subject"),
            #     fwd_path=solver_params['fwd_path'],
            # )
            # try:
            #     evaluation_metrics = self.metric_evaluator.evaluate_metrics(
            #         which="evaluation",
            #         empirical_coverages=pre_calibration['empirical_coverages'],
            #         **metric_kwargs,
            #     )
            #     this_result.update(evaluation_metrics)
            # except Exception as e:
            #     self.logger.error(f"Error while evaluating evaluation metrics: {e}", exc_info=True)
            #     this_result.update({"metric_evaluation_error": str(e)})
            # calibration_metric_names = tuple(
            #     getattr(self.metric_evaluator, "calibration_metrics", tuple())
            # )
            # pre_cal_metrics = pre_calibration.get('calibration_metrics', {})
            # post_cal_metrics = post_calibration.get('calibration_metrics', {})
            # for metric_name in calibration_metric_names:
            #     pre_value = pre_cal_metrics.get(metric_name)
            #     post_value = post_cal_metrics.get(metric_name)
            #     if pre_value is not None:
            #         this_result[f"pre_cal_{metric_name}"] = pre_value
            #     if post_value is not None:
            #         this_result[f"post_cal_{metric_name}"] = post_value
            #     improvement_key = f"improvement_{metric_name}"
            #     if (
            #         pre_value is None
            #         or post_value is None
            #         or (isinstance(pre_value, (int, float, np.floating)) and np.isclose(pre_value, 0.0))
            #     ):
            #         this_result[improvement_key] = None
            #     else:
            #         this_result[improvement_key] = (pre_value - post_value) / pre_value * 100
            # viz = Visualizer(base_save_path=experiment_dir, logger=self.logger)
            # viz.plot_all(
            #     x=x,
            #     x_active_indices=x_active_indices,
            #     x_hat=x_hat,
            #     x_hat_active_indices=x_hat_active_indices,
            #     y_clean=y_clean,
            #     y_noisy=y_noisy,
            #     n_sources=n_sources,
            #     subject=data_params.get("subject"),
            #     fwd_path=solver_params['fwd_path'],
            #     nnz=data_params.get("nnz"),
            #     ERP_config=self.ERP_config,
            #     sample_idx=200,
            #     source_units=source_units,
            #     source_unitmult=source_unitmult,
            #     sensor_units=sensor_units,
            #     sensor_unitmult=sensor_unitmult,
            #     confidence_levels=self.uncertainty_estimator.nominal_coverages,
            #     nominal_coverages=pre_calibration['nominal_coverages'],
            #     empirical_coverages=pre_calibration['empirical_coverages'],
            #     empirical_coverages_post_cal=post_calibration['empirical_coverages'],
            #     ci_lower=pre_calibration.get('ci_lowers'),
            #     ci_upper=pre_calibration.get('ci_uppers'),
            #     orientation_type="fixed", # TODO: remove
            #     result=this_result,
            #     experiment_dir=experiment_dir,
            # )
        except Exception as e:
            self.logger.error(
                f"Error during data generation global_run_id {this_result.get('global_run_id', 'N/A')} (config run {this_result.get('run_id', 'N/A')}): {e}",
                exc_info=True,
            )
            this_result["error_message"] = str(e)

        # Reached on success and on failure alike.
        self.logger.debug(
            f"Completed global_run_id {this_result.get('global_run_id', 'N/A')} (config run {this_result.get('run_id', 'N/A')})"
        )
        return this_result

    def run(
        self,
        nruns: int = 2,
        fig_path: str = "results/figures",
        n_jobs: int = 1,
        run_offset: int = 0,
        global_total_runs: Optional[int] = None,
    ):
        """
        Run data generation by iterating over combinations of solver and data parameters.

        Parameters
        ----------
        nruns : int
            Number of seeds to evaluate for each parameter combination. Must
            be non-negative.
        fig_path : str
            Base directory where per-run visualizations will be saved.
        n_jobs : int
            Number of parallel workers to use. ``1`` (default) keeps the
            sequential behaviour; other values use joblib's ``loky`` backend.
        run_offset : int
            Number of experiments completed prior to this generator call. Used
            for global progress tracking when multiple estimators are run
            sequentially.
        global_total_runs : int, optional
            Total number of experiments planned across estimators. If provided,
            the logger also reports this aggregate figure.

        Returns
        -------
        pd.DataFrame
            One row per experiment (empty if there is nothing to run).

        Raises
        ------
        ValueError
            If ``nruns`` is negative.
        """
        if nruns < 0:
            raise ValueError(f"nruns must be non-negative, got {nruns}")
        rng = check_random_state(self.random_state)

        param_grids = list(product(
            ParameterGrid(self.solver_param_grid),
            ParameterGrid(self.data_param_grid),
            ParameterGrid(self.noise_param_grid),
        ))
        num_configs = len(param_grids)
        total_runs = num_configs * nruns

        # One independent seed per experiment (not just per repetition), so
        # every parameter combination gets different random data. dtype is
        # explicit because the legacy default integer type is 32-bit on
        # Windows with NumPy < 2, where high=2**32 would be out of bounds.
        all_seeds = rng.randint(low=0, high=2 ** 32, size=total_runs, dtype=np.int64)
        # Seeds are laid out configuration-major: config i, repetition r -> i * nruns + r.
        param_combinations = [
            (solver_params, data_params, noise_params, all_seeds[config_idx * nruns + repeat])
            for config_idx, (solver_params, data_params, noise_params) in enumerate(param_grids)
            for repeat in range(nruns)
        ]

        self.logger.info(
            "%s\nStarting data generation for estimator %s with %d experiments (%d nruns x %d configurations)",
            "-" * 50,
            getattr(self.solver, "__name__", str(self.solver)),
            total_runs,
            nruns,
            num_configs,
        )
        if total_runs == 0:
            return pd.DataFrame()

        worker_args = [
            (
                run_id,
                nruns,
                total_runs,
                solver_params,
                data_params,
                noise_params,
                seed,
                fig_path,
                run_offset + run_id,
                global_total_runs,
            )
            for run_id, (solver_params, data_params, noise_params, seed) in enumerate(param_combinations, start=1)
        ]

        if n_jobs == 1:
            results_list = [self._execute_single_run(*args) for args in worker_args]
        else:
            self.logger.debug(f"Running data generation in parallel with n_jobs={n_jobs}")
            parallel = Parallel(n_jobs=n_jobs, backend="loky", verbose=0)
            results_list = parallel(
                delayed(self._execute_single_run)(*args)
                for args in worker_args
            )

        self.logger.debug("Data generation completed.")
        return pd.DataFrame(results_list)
# TODO: move plotting functions to Visualizer class
def plot_error_curves(err_gamma, title="Gamma/Lambda errors", save_path=None):
    """
    Plot the per-iteration relative gamma error on a logarithmic y-axis.

    Parameters
    ----------
    err_gamma : sequence of float
        Relative gamma errors, one value per iteration.
    title : str
        Plot title.
    save_path : str or path-like, optional
        File to write the figure to. When omitted nothing is saved. The
        figure is always closed before returning.
    """
    iters = np.arange(len(err_gamma))
    fig, ax = plt.subplots()
    try:
        ax.semilogy(iters, err_gamma, label="err_gamma")
        # ax.semilogy(iters, err_lambda, label="err_lambda")
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Relative error")
        ax.set_title(title)
        ax.grid(True, which="both", linestyle="--", alpha=0.3)
        ax.legend()
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Close even when saving fails, so repeated calls (one per run) do not
        # accumulate open figures in pyplot's registry.
        plt.close(fig)
# TODO: remove temporary plotting code
def plot_error_curves_comparison(err_gamma_1, err_lambda_1,
                                 err_gamma_2, err_lambda_2,
                                 labels=("Method 1", "Method 2"),
                                 save_path=None):
    """
    Compare error curves of two methods (e.g. with sFLEX and without sFLEX).

    Parameters
    ----------
    err_gamma_1, err_lambda_1 : sequence of float
        Errors for method 1. Each curve is plotted against its own iteration
        index, so the two histories may have different lengths.
    err_gamma_2, err_lambda_2 : sequence of float
        Errors for method 2.
    labels : tuple of str
        Labels for the two methods.
    save_path : str or path-like, optional
        If given, the figure is written there and then closed. Otherwise it
        is left open (as before) so the caller can show or save it.

    Returns
    -------
    matplotlib.figure.Figure
        The comparison figure.
    """
    fig, ax = plt.subplots()
    ax.semilogy(np.arange(len(err_gamma_1)), err_gamma_1, label=f"{labels[0]}: err_gamma")
    ax.semilogy(np.arange(len(err_lambda_1)), err_lambda_1, label=f"{labels[0]}: err_lambda")
    ax.semilogy(np.arange(len(err_gamma_2)), err_gamma_2, label=f"{labels[1]}: err_gamma", linestyle="--")
    ax.semilogy(np.arange(len(err_lambda_2)), err_lambda_2, label=f"{labels[1]}: err_lambda", linestyle="--")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Relative error")
    ax.set_title("Gamma/Lambda Error Comparison")
    ax.grid(True, which="both", linestyle="--", alpha=0.3)
    ax.legend()
    fig.tight_layout()
    if save_path:
        try:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)
    return fig
# TODO: remove temporary plotting code
def plot_alphas_cv(alphas, grid_factors, baseline_noise_var, experiment_dir):
    """
    Plot a cross-validation alpha grid against the baseline noise variance.

    The figure is written to ``<experiment_dir>/alphas_vs_baseline.png`` and
    always closed, even if saving fails.

    Parameters
    ----------
    alphas : sequence of float
        Alpha (noise variance) values, one per grid factor.
    grid_factors : sequence of float
        Grid factors, shown on a logarithmic x-axis.
    baseline_noise_var : float
        Reference noise variance drawn as a dashed horizontal line.
    experiment_dir : str or path-like
        Directory where the PNG is saved.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(grid_factors, alphas, marker='o', linestyle='-', label=f'alphas (n={len(alphas)})')
        ax.axhline(
            baseline_noise_var,
            color='red',
            linestyle='--',
            label=f'baseline_noise_var = {baseline_noise_var:.3e}',
        )
        ax.set_xscale('log')
        ax.set_xlabel('grid factor (log scale)')
        ax.set_ylabel('alpha (noise variance)')
        ax.set_title('Spatial CV: Alpha grid vs baseline noise variance')
        ax.legend()
        ax.grid(True, which="both", ls="--", linewidth=0.5)
        fig.tight_layout()
        save_path = os.path.join(experiment_dir, "alphas_vs_baseline.png")
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)