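"""Aggregation workflow (module ``calibrain.workflows.aggregation``).

Turns posterior summary files listed in a run manifest into per-run
calibration datasets whose posterior uncertainty has been reduced.
"""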

from datetime import datetime, timezone
import json
import os
from pathlib import Path
from typing import Any, Dict, Union

import numpy as np

from calibrain.calibration_dataset import (
    concatenate_summaries,
    filter_summaries_by_metadata,
    EEG_COIL_TYPES,
    MEG_COIL_TYPES,
)
from calibrain.run_manifest import summaries_from_manifest
from calibrain.utils import load_python_config

# Config file used when this module is executed as a script.
DEFAULT_CONFIG_PATH: Path = Path("configs/aggregate_default.py")
# Version tag stored in every reduced dataset as ``uncertainty_schema_version``.
UNCERTAINTY_SCHEMA_VERSION: int = 2


def _serialize_filter(criteria: Dict[str, Any] | None) -> Dict[str, Any]:
    if not criteria:
        return {}

    def _convert(value: Any) -> Any:
        if callable(value):
            doc = getattr(value, "__doc__", None)
            return doc.strip() if isinstance(doc, str) else repr(value)
        if isinstance(value, dict):
            return {k: _convert(v) for k, v in value.items()}
        if isinstance(value, (list, tuple, set, frozenset)):
            return [_convert(v) for v in value]
        return value

    return {key: _convert(val) for key, val in criteria.items()}


def _reduce_posterior_uncertainty(dataset: Dict[str, Any]) -> Dict[str, Any]:
    """
    Reduce stored uncertainty so aggregated datasets do not persist full
    posterior covariance matrices.

    Output schema (UNCERTAINTY_SCHEMA_VERSION=2)
    - fixed: store `posterior_var` with shape (N,)
    - free EEG/MEG: store `posterior_cov_blocks` with shape (N,K,K), K=3 (EEG) or 2 (MEG)

    Parameters
    ----------
    dataset : dict
        Must contain ``posterior_cov``. ``orientation_type`` (default
        ``"fixed"``, case-insensitive) selects the reduction. Free-orientation
        datasets additionally need a ``coil_type`` found in ``EEG_COIL_TYPES``
        or ``MEG_COIL_TYPES``, and either a positive ``n_sources`` or a
        non-empty, non-scalar ``x_true`` whose first axis length is used as N.

    Returns
    -------
    dict
        A shallow copy of ``dataset`` without ``posterior_cov``, with
        ``uncertainty_schema_version`` set and the reduced field added.
        The caller's dict is not modified.

    Raises
    ------
    ValueError
        If ``posterior_cov`` is missing or mis-shaped, the orientation or coil
        type is unsupported, or N cannot be inferred.
    """
    if "posterior_cov" not in dataset:
        raise ValueError("Expected dataset to contain 'posterior_cov' for reduction.")

    orientation = str(dataset.get("orientation_type") or "fixed").lower()
    coil_type = dataset.get("coil_type")
    cov = np.asarray(dataset["posterior_cov"], dtype=float)

    # Always drop the full covariance from aggregated outputs. Work on a
    # shallow copy so the caller's dict keeps its covariance.
    dataset = dict(dataset)
    dataset.pop("posterior_cov", None)
    dataset["uncertainty_schema_version"] = int(UNCERTAINTY_SCHEMA_VERSION)

    if orientation == "fixed":
        if cov.ndim != 2 or cov.shape[0] != cov.shape[1]:
            raise ValueError(f"Fixed posterior_cov must be square; got {cov.shape}")
        # Clip negative diagonal entries (numerical noise) to zero.
        dataset["posterior_var"] = np.maximum(np.diag(cov), 0.0)
        return dataset

    if orientation != "free":
        raise ValueError(f"Unsupported orientation_type '{orientation}' for uncertainty reduction.")

    if coil_type in EEG_COIL_TYPES:
        block_dim = 3
    elif coil_type in MEG_COIL_TYPES:
        block_dim = 2
    else:
        raise ValueError(
            f"Free-orientation datasets must specify a supported coil_type; got {coil_type}"
        )

    n_sources = int(dataset.get("n_sources") or 0)
    if n_sources <= 0:
        # Fall back to the leading axis of x_true. A missing or scalar x_true
        # becomes a 0-d array with no axis to read, so guard on ndim; otherwise
        # shape[0] would raise IndexError instead of the ValueError below.
        x_true = np.asarray(dataset.get("x_true"))
        n_sources = int(x_true.shape[0]) if x_true.ndim >= 1 and x_true.size else 0
    if n_sources <= 0:
        raise ValueError("Unable to infer n_sources for free-orientation uncertainty reduction.")

    expected = (block_dim * n_sources, block_dim * n_sources)
    if cov.shape != expected:
        raise ValueError(
            f"Free posterior_cov must be {expected} for K={block_dim}; got {cov.shape}"
        )

    # View the (K*N, K*N) matrix as (N, K, N, K), so that [i, a, j, b] is
    # cov[K*i + a, K*j + b]. Indexing both source axes with the same index
    # array picks the N diagonal K x K blocks. The advanced indices are
    # separated by a slice, so the N axis comes first: result shape (N, K, K).
    # Advanced indexing returns a copy.
    source_idx = np.arange(n_sources)
    blocks = cov.reshape(n_sources, block_dim, n_sources, block_dim)[
        source_idx, :, source_idx, :
    ]
    dataset["posterior_cov_blocks"] = blocks
    return dataset


def _build_output_path(base_dir: Path, metadata: Dict[str, Any]) -> Path:
    """Return ``base_dir / "<stem>.npz"`` for one aggregated summary.

    Creates ``base_dir`` (with parents) if needed. The stem joins, in order and
    with ``_``:
    - ``subject`` (if truthy);
    - ``solver`` (if truthy);
    - ``run<id:08d>``;
    - ``seed<seed>``.

    Missing parts are skipped, and the stem falls back to ``"summary"`` when
    none are present. The run id is ``run_id``, or ``global_run_id`` when
    ``run_id`` is None or an empty string. A legitimate id of ``0`` is kept.
    """
    base_dir.mkdir(parents=True, exist_ok=True)

    def _is_missing(value: Any) -> bool:
        return value is None or (isinstance(value, str) and not value)

    # An explicit check rather than ``or``: with ``or``, run_id == 0 fell back
    # to global_run_id. A run (run_id=0, global_run_id=5) then got the same
    # name as a run with run_id=5, so one file overwrote the other.
    run_id = metadata.get("run_id")
    if _is_missing(run_id):
        run_id = metadata.get("global_run_id")
    subject = metadata.get("subject")
    solver = metadata.get("solver")
    seed = metadata.get("seed")

    parts = [str(value) for value in (subject, solver) if value]
    if not _is_missing(run_id):
        try:
            parts.append(f"run{int(run_id):08d}")
        except (TypeError, ValueError):
            # Non-numeric ids are used verbatim instead of aborting aggregation.
            parts.append(f"run{run_id}")
    if seed is not None:
        parts.append(f"seed{seed}")

    stem = "_".join(parts) if parts else "summary"
    return base_dir / f"{stem}.npz"


def _write_dataset(
    *,
    dataset: Dict[str, Any],
    output_path: Path,
    summaries,
    summaries_root: Path,
    split_name: str,
    criteria: Dict[str, Any] | None,
) -> None:
    """Write ``dataset`` to ``output_path`` (NPZ) plus a sidecar ``.json`` file.

    Storage rules:
    - NumPy arrays in ``dataset`` are stored unchanged.
    - Other non-``None`` values (except ``metadata``) are wrapped with
      ``np.array``.
    - Values identical across all dict entries of ``dataset["metadata"]``
      (subject, solver, noise_type, alpha_SNR, nnz, seed, run_id,
      global_run_id) are stored as well and take precedence over dataset
      values with the same key.

    The JSON sidecar records the creation time, summary root and paths, the
    split name, the serialized filter criteria and basic dataset info.

    Both files are first written to a temporary path and then moved into place
    with ``os.replace``, so an interrupted job never leaves a truncated file at
    the final path. The temporary file is removed if writing fails.

    Parameters
    ----------
    summaries
        Sequence of summary objects; each must expose a ``path`` attribute.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Non-dict metadata entries are ignored everywhere below.
    metadata_entries = [
        entry for entry in (dataset.get("metadata") or []) if isinstance(entry, dict)
    ]

    def _unique_or_none(values):
        uniq = {v for v in values if v is not None}
        return uniq.pop() if len(uniq) == 1 else None

    def _count(value: Any) -> int:
        # An explicit None count is recorded as 0 instead of crashing int().
        return 0 if value is None else int(value)

    # Subject/solver only consider truthy values (so "" is ignored). The other
    # per-run keys collected further down only ignore None.
    primary_subject = _unique_or_none(
        entry.get("subject") for entry in metadata_entries if entry.get("subject")
    )
    primary_solver = _unique_or_none(
        entry.get("solver") for entry in metadata_entries if entry.get("solver")
    )

    arrays: Dict[str, np.ndarray] = {}
    scalars: Dict[str, Any] = {}
    for key, value in dataset.items():
        if key == "metadata" or value is None:
            continue
        if isinstance(value, np.ndarray):
            arrays[key] = value
        else:
            scalars[key] = value

    if primary_subject is not None:
        scalars["subject"] = primary_subject
    if primary_solver is not None:
        scalars["solver"] = primary_solver

    # Persist additional per-run metadata that is useful for downstream grouping
    # (plots, pooled filtering, pairing fixed/free runs). When aggregating a
    # single posterior summary (the default), these are unambiguous.
    for key in ("noise_type", "alpha_SNR", "nnz", "seed", "run_id", "global_run_id"):
        value = _unique_or_none(entry.get(key) for entry in metadata_entries)
        if value is not None:
            scalars[key] = value

    # Merge into ONE mapping. Unpacking two dicts as separate ``**`` arguments
    # raises TypeError ("multiple values for keyword argument") when a key
    # exists both as a dataset array and as a metadata-derived scalar.
    payload = {**arrays, **{key: np.array(val) for key, val in scalars.items()}}

    dataset_meta = {
        "orientation_type": dataset.get("orientation_type"),
        "coil_type": dataset.get("coil_type"),
        "sensor_kind": dataset.get("sensor_kind"),
        "n_sources": _count(dataset.get("n_sources", 0)),
        "n_times": _count(dataset.get("n_times", 0)),
        "subject": primary_subject,
        "solver": primary_solver,
    }
    meta_payload = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "summary_root": str(summaries_root),
        "summary_count": len(summaries),
        "n_sources": dataset_meta["n_sources"],
        "n_times": dataset_meta["n_times"],
        "split": split_name,
        "criteria": _serialize_filter(criteria),
        "dataset_info": dataset_meta,
        "solver": primary_solver,
        "summaries": [str(summary.path) for summary in summaries],
    }
    # Serialize before touching the filesystem. A non-JSON-serializable value
    # then fails before the NPZ is written, rather than leaving an NPZ
    # without its sidecar.
    meta_text = json.dumps(meta_payload, indent=2)

    def _commit(tmp_path: Path, final_path: Path, write) -> None:
        """Run ``write(tmp_path)``, then atomically move it onto ``final_path``."""
        try:
            write(tmp_path)
            os.replace(tmp_path, final_path)
        except BaseException:
            # Do not leave stray temporary files behind on failure/interrupt.
            tmp_path.unlink(missing_ok=True)
            raise

    # NOTE: numpy appends ".npz" when the provided filename does not end with it.
    # Use a tmp path that ends with ".npz" so the file lands exactly there and
    # can be os.replace'd safely.
    tmp_npz = output_path.with_suffix(".tmp.npz")
    _commit(tmp_npz, output_path, lambda path: np.savez_compressed(path, **payload))

    meta_path = output_path.with_suffix(".json")
    tmp_meta = meta_path.with_name(f"{meta_path.name}.tmp")
    _commit(tmp_meta, meta_path, lambda path: path.write_text(meta_text, encoding="utf-8"))


def _aggregate_single(
    config: Union[str, Path, Dict[str, Any]], tag: str | None = None
) -> None:
    """Aggregate every posterior summary listed in a manifest, one output per summary.

    Parameters
    ----------
    config
        Config dict, or a path loaded with ``load_python_config``. Keys:
        - ``manifest_path`` (required): the CSV manifest.
        - ``summaries_root`` (optional): defaults to the manifest's directory
          and is recorded in the JSON sidecar.
        - ``output_dir`` (optional): defaults to
          ``results/calibration_datasets``.
        - ``filter`` (optional): metadata criteria passed to
          ``filter_summaries_by_metadata``.
    tag
        Split name. When given, a completion line is printed.

    Raises
    ------
    KeyError
        If ``manifest_path`` is missing from the config.
    FileNotFoundError
        If the manifest does not exist or lists no summaries.
    ValueError
        If ``filter`` is given and matches no summaries.
    """
    import warnings

    if isinstance(config, (str, Path)):
        config = load_python_config(config)

    if "manifest_path" not in config:
        raise KeyError(
            "Aggregation config must specify 'manifest_path' (CSV) so discovery does not "
            "scan the filesystem."
        )
    manifest_path = Path(config["manifest_path"])
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest CSV does not exist: {manifest_path}")

    summaries_root = Path(config.get("summaries_root") or manifest_path.parent)
    # ``or`` also covers an explicit ``output_dir: None``, which used to crash Path().
    output_dir = Path(config.get("output_dir") or "results/calibration_datasets")
    criteria = config.get("filter")

    summaries = summaries_from_manifest(manifest_path)
    if not summaries:
        raise FileNotFoundError(
            f"No posterior summary files found in manifest {manifest_path}"
        )

    if criteria:
        summaries = filter_summaries_by_metadata(summaries, criteria)
        if not summaries:
            raise ValueError("No summaries matched the provided metadata filter.")

    written_from: Dict[Path, Any] = {}
    for summary in summaries:
        dataset = _reduce_posterior_uncertainty(concatenate_summaries([summary]))
        meta = (dataset.get("metadata") or [{}])[0]
        output_path = _build_output_path(output_dir, meta or {})
        if output_path in written_from:
            # Output names come only from subject/solver/run id/seed. When those
            # cannot tell two summaries apart, the later one overwrites the
            # earlier file, so warn instead of losing data silently.
            warnings.warn(
                f"Aggregated output {output_path} was already written from "
                f"{written_from[output_path]}; overwriting with {summary.path}.",
                RuntimeWarning,
                stacklevel=2,
            )
        written_from[output_path] = summary.path
        _write_dataset(
            dataset=dataset,
            output_path=output_path,
            summaries=[summary],
            summaries_root=summaries_root,
            split_name="single",
            criteria=criteria,
        )
    if tag:
        print(f"[aggregation] Completed split '{tag}' with {len(summaries)} summaries.")


def aggregate_posteriors(config: Union[str, Path, Dict[str, Any]]) -> None:
    """Aggregate posterior summaries described by ``config``.

    ``config`` is a dict, or a path loaded with ``load_python_config``.
    Without a ``splits`` key, the whole config is processed as a single split.

    With ``splits`` (a dict of name -> sub-config, or a sequence whose index
    becomes the name), each sub-config is processed in turn:
    - Path-valued sub-configs are loaded first.
    - A dict sub-config missing ``manifest_path`` or ``summaries_root``
      inherits the top-level value.
    - Other top-level keys (e.g. ``output_dir``, ``filter``) are not inherited.
    """
    if isinstance(config, (str, Path)):
        config = load_python_config(config)

    if "splits" not in config:
        _aggregate_single(config)
        return

    splits = config["splits"]
    items = splits.items() if isinstance(splits, dict) else enumerate(splits)
    for name, sub_config in items:
        print(f"[aggregation] Running split '{name}'")
        if isinstance(sub_config, (str, Path)):
            # Load path-valued splits here so they can inherit shared keys too.
            # Previously they were loaded later and never saw the top-level
            # manifest_path/summaries_root.
            sub_config = load_python_config(sub_config)
        if isinstance(sub_config, dict):
            sub_config = dict(sub_config)  # never mutate the caller's config
            for key in ("manifest_path", "summaries_root"):
                if key not in sub_config and key in config:
                    sub_config[key] = config[key]
        _aggregate_single(sub_config, tag=str(name))
if __name__ == "__main__":
    # An optional first command-line argument overrides the default config
    # path; with no argument the behavior is unchanged.
    import sys

    aggregate_posteriors(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CONFIG_PATH)