import datetime
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
import mne
import pandas as pd
from sklearn.model_selection import ParameterGrid
from calibrain import (
DataGenerator,
LeadfieldBuilder,
SourceSimulator,
SensorSimulator,
BMN,
BMN_joint,
gamma_map_sflex,
gamma_lambda_map_sflex,
)
from calibrain.utils import get_data_path
from calibrain.utils import load_python_config
# Configuration file used when this module is executed directly as a script.
DEFAULT_CONFIG_PATH = Path("configs") / "data_generation_default.py"

# Lookup table from the solver names accepted in estimator config entries to
# the solver objects imported from calibrain.
_SOLVER_REGISTRY = dict(
    BMN=BMN,
    BMN_joint=BMN_joint,
    gamma_map_sflex=gamma_map_sflex,
    gamma_lambda_map_sflex=gamma_lambda_map_sflex,
)
def _resolve_solver(name: str):
    """Look up a solver in ``_SOLVER_REGISTRY`` by its configured name.

    Args:
        name: Key identifying the solver in ``_SOLVER_REGISTRY``.

    Returns:
        The registry entry stored under ``name``.

    Raises:
        ValueError: If ``name`` is not a registered solver. The message lists
            the available names and the original ``KeyError`` is chained.
    """
    try:
        solver = _SOLVER_REGISTRY[name]
    except KeyError as missing:
        known_names = sorted(_SOLVER_REGISTRY)
        message = f"Unknown solver '{name}'. Available: {known_names}"
        raise ValueError(message) from missing
    return solver
# Public entry point of this module.
def _count_estimator_runs(
    nruns: int, solver_params: Any, data_grid: Any, noise_grid: Any
) -> int:
    """Return the number of runs a single estimator entry expands to.

    The result is ``nruns`` multiplied by the product of the three
    ``ParameterGrid`` sizes, with that product clamped to at least 1.
    """
    num_configs = (
        len(ParameterGrid(solver_params))
        * len(ParameterGrid(data_grid))
        * len(ParameterGrid(noise_grid))
    )
    return nruns * max(1, num_configs)


def run_data_generation(config: Union[str, Path, Dict[str, Any]]) -> pd.DataFrame:
    """Run the data-generation pipeline for every configured estimator.

    For each entry in ``config["estimators"]`` a ``DataGenerator`` is built and
    run; the per-estimator result frames are concatenated, sorted, reordered
    and appended to a CSV run manifest.

    Args:
        config: Either a mapping of settings, or a ``str``/``Path`` that is
            loaded with ``load_python_config``. Keys read here (with defaults):
            ``log_dir`` ("results/logs"), ``nruns`` (1),
            ``generation_n_jobs`` (1), ``random_state`` (42),
            ``ERP_config`` ({}), ``leadfield_dir``
            (``get_data_path() / "1284src_leadfield"``), ``estimators``
            (required, non-empty list), ``save_posterior_stats`` (True),
            ``posterior_dir`` ("results/posterior_summaries") and
            ``manifest_path`` ("results/run_manifest.csv"). Each estimator
            entry must provide ``solver`` (a name known to
            ``_resolve_solver``) and may provide ``solver_params``,
            ``data_param_grid`` and ``noise_param_grid``.

    Returns:
        The combined results ``DataFrame`` with an added ``nruns`` column and
        the identifying columns moved to the front.

    Raises:
        ValueError: If no estimator is configured, a solver name is unknown,
            or the results lack the ``posterior_summary`` column.
        KeyError: If an estimator entry has no ``solver`` key.
        RuntimeError: If no results were produced or the manifest CSV could
            not be read/written.

    Side effects:
        Creates the log, posterior and manifest directories, configures root
        logging, sets the ``CALIBRAIN_LOG_FILE`` environment variable, lowers
        MNE's log level to ERROR, and writes the manifest CSV.
    """
    if isinstance(config, (str, Path)):
        config = load_python_config(config)

    # --- Logging -----------------------------------------------------------
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = Path(config.get("log_dir", "results/logs"))
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"data_generation_log_{timestamp}.log"
    # NOTE(review): logging.basicConfig does nothing when the root logger
    # already has handlers (e.g. on a second call in the same process), in
    # which case this process will not write to ``log_file`` itself. Consider
    # ``force=True`` if per-call log files are required — confirm with callers.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.FileHandler(log_file, mode="w"), logging.StreamHandler()],
    )
    # Export the log path through the environment for other code to pick up.
    os.environ["CALIBRAIN_LOG_FILE"] = str(log_file)
    mne.set_log_level("ERROR")
    logging.getLogger("mne").setLevel(logging.ERROR)
    logger = logging.getLogger(__name__)

    # --- Scalar settings and shared simulation components -------------------
    nruns = int(config.get("nruns", 1))
    generation_n_jobs = int(config.get("generation_n_jobs", 1))
    random_state = int(config.get("random_state", 42))
    ERP_config = config.get("ERP_config", {})
    leadfield_dir = Path(
        config.get("leadfield_dir", get_data_path() / "1284src_leadfield")
    )
    leadfield_builder = LeadfieldBuilder(leadfield_dir=leadfield_dir, logger=logger)
    sensor_simulator = SensorSimulator(logger=logger)
    source_simulator = SourceSimulator(ERP_config=ERP_config, logger=logger)

    estimators_cfg: List[Dict[str, Any]] = config.get("estimators", [])
    if not estimators_cfg:
        raise ValueError("Data generation config must define at least one estimator entry.")

    save_posterior_stats = bool(config.get("save_posterior_stats", True))
    if not save_posterior_stats:
        # The manifest step below insists on a 'posterior_summary' column, so
        # surface the likely failure before the expensive generation starts.
        logger.warning(
            "save_posterior_stats is disabled; the run manifest step requires a "
            "'posterior_summary' column in the results and will raise if it is absent."
        )
    posterior_dir = Path(config.get("posterior_dir", "results/posterior_summaries"))
    posterior_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = Path(config.get("manifest_path", "results/run_manifest.csv"))
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    # --- Resolve every estimator up front so bad solver names fail early ----
    estimator_plan: List[Tuple[Any, Any, Any, Any]] = [
        (
            _resolve_solver(estimator["solver"]),
            estimator.get("solver_params", {}),
            estimator.get("data_param_grid", {}),
            estimator.get("noise_param_grid", {}),
        )
        for estimator in estimators_cfg
    ]

    # Run counts are computed once and reused for both the global total and
    # the per-estimator run offsets.
    run_counts = [
        _count_estimator_runs(nruns, solver_params, data_grid, noise_grid)
        for _, solver_params, data_grid, noise_grid in estimator_plan
    ]
    total_experiments = sum(run_counts)

    # --- Run each estimator --------------------------------------------------
    df_list = []
    run_offset = 0
    for (solver, solver_params, data_grid, noise_grid), local_runs in zip(
        estimator_plan, run_counts
    ):
        generator = DataGenerator(
            solver=solver,
            solver_param_grid=solver_params,
            data_param_grid=data_grid,
            noise_param_grid=noise_grid,
            ERP_config=ERP_config,
            source_simulator=source_simulator,
            leadfield_builder=leadfield_builder,
            sensor_simulator=sensor_simulator,
            save_posterior_stats=save_posterior_stats,
            posterior_dir=posterior_dir,
            random_state=random_state,
            logger=logger,
        )
        results_df = generator.run(
            nruns=nruns,
            n_jobs=generation_n_jobs,
            run_offset=run_offset,
            global_total_runs=total_experiments,
        )
        df_list.append(results_df)
        # Advance the offset by this estimator's share of the global runs.
        run_offset += local_runs

    if not df_list:
        raise RuntimeError("Data generation produced no results data.")

    # --- Assemble the combined results frame ---------------------------------
    final_df = pd.concat(df_list)
    sort_columns = [
        "run_id",
        "subject",
        "orientation_type",
        "nnz",
        "solver",
        "noise_type",
        "alpha_SNR",
    ]
    # Sort only by the columns that exist, mirroring the tolerant column
    # reordering below instead of failing with a KeyError.
    available_sort_columns = [c for c in sort_columns if c in final_df.columns]
    missing_sort_columns = [c for c in sort_columns if c not in final_df.columns]
    if missing_sort_columns:
        logger.warning(
            "Results are missing sort columns %s; sorting by %s only",
            missing_sort_columns,
            available_sort_columns,
        )
    if available_sort_columns:
        final_df.sort_values(by=available_sort_columns, inplace=True, ascending=True)

    final_df["nruns"] = nruns
    desired_cols = [
        "global_run_id",
        "run_id",
        "subject",
        "orientation_type",
        "nnz",
        "solver",
        "noise_type",
        "alpha_SNR",
        "gamma",
        "nruns",
    ]
    # Identifying columns first (when present), everything else after them.
    leading_cols = [c for c in desired_cols if c in final_df.columns]
    other_cols = [c for c in final_df.columns if c not in desired_cols]
    final_df = final_df[leading_cols + other_cols]
    logger.info("Data generation results DataFrame assembled with %d rows", len(final_df))

    if "posterior_summary" not in final_df.columns:
        raise ValueError(
            "Data generation results are missing the 'posterior_summary' column. "
            "Make sure `save_posterior_stats=True` in the data generation config."
        )

    # Append to (or create) a CSV manifest so downstream steps can discover runs
    # without scanning the filesystem.
    try:
        if manifest_path.exists():
            existing = pd.read_csv(manifest_path)
            combined = pd.concat([existing, final_df], ignore_index=True)
            # 'posterior_summary' is guaranteed present here (checked above);
            # on duplicates the newest row wins.
            combined = combined.drop_duplicates(subset=["posterior_summary"], keep="last")
        else:
            combined = final_df
        combined.to_csv(manifest_path, index=False)
        logger.info("Updated run manifest CSV: %s (%d rows)", manifest_path, len(combined))
    except Exception as exc:
        raise RuntimeError(f"Failed to write run manifest CSV at {manifest_path}: {exc}") from exc

    return final_df
if __name__ == "__main__":
    # Script entry point: run the pipeline with the default configuration file.
    run_data_generation(config=DEFAULT_CONFIG_PATH)