Source code for calibrain.calibration_storage

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Iterator, Mapping, Optional, Sequence, Tuple, Union

import numpy as np

# Version number written under the "schema_version" key of every record
# produced by save_calibration_record.
SCHEMA_VERSION = 1
# Filename pattern that iter_calibration_records searches for (recursively)
# when it is given a directory. NOTE(review): save_calibration_record names
# files "<record_name>.json", so only record names starting with
# "calibration_run-" are found by that search -- confirm callers follow this.
_RECORD_GLOB = "calibration_run-*.json"

# Recursive alias for values built only from JSON-compatible types; used as
# the return annotation of _to_serializable.
Serializable = Union[str, int, float, None, bool, Sequence["Serializable"], Mapping[str, "Serializable"]]


def _to_serializable(value: Any) -> Serializable:
    if isinstance(value, Path):
        return value.as_posix()
    if isinstance(value, (np.generic,)):
        return value.item()
    if isinstance(value, np.ndarray):
        return np.asarray(value).tolist()
    if isinstance(value, dict):
        return {str(k): _to_serializable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_serializable(v) for v in value]
    return value  # type: ignore[arg-type]


def _infer_n_observations(stage_data: Mapping[str, Any]) -> Optional[int]:
    if "n_observations" in stage_data and stage_data["n_observations"] is not None:
        return int(stage_data["n_observations"])
    ci_lowers = stage_data.get("ci_lowers")
    if ci_lowers is not None:
        return int(np.asarray(ci_lowers).shape[-1])
    ci_counts = stage_data.get("ci_counts")
    empirical = stage_data.get("empirical_coverages")
    if ci_counts is not None and empirical is not None:
        counts = np.asarray(ci_counts, dtype=float)
        empirical = np.asarray(empirical, dtype=float)
        valid = empirical > 0
        if np.any(valid):
            denom = counts[valid] / empirical[valid]
            return int(round(float(np.median(denom))))
    return None


def _compact_stage(stage_data: Optional[Mapping[str, Any]], stage: str) -> Dict[str, Any]:
    if not stage_data:
        return {}
    payload: Dict[str, Any] = {}
    for key in ("nominal_coverages", "empirical_coverages"):
        value = stage_data.get(key)
        if value is not None:
            payload[key] = value
    metrics = stage_data.get("calibration_metrics")
    if metrics:
        payload["calibration_metrics"] = metrics
    if stage == "pre":
        ci_counts = stage_data.get("ci_counts")
        if ci_counts is not None:
            payload["ci_counts"] = ci_counts
        n_obs = _infer_n_observations(stage_data)
        if n_obs is not None:
            payload["n_observations"] = n_obs
    if stage == "post":
        recal = stage_data.get("recalibrated_nominal_coverages")
        if recal is not None:
            payload["recalibrated_nominal_coverages"] = recal
    return payload


def save_calibration_record(
    output_dir: Union[str, Path],
    record_name: str,
    metadata: Mapping[str, Any],
    pre_calibration: Mapping[str, Any],
    post_calibration: Mapping[str, Any],
) -> Path:
    """Persist calibration payload alongside metadata in JSON format.

    The record holds the schema version, the metadata, and the compacted
    pre- and post-calibration stages, and is written to
    ``<output_dir>/<record_name>.json``.

    Parameters
    ----------
    output_dir
        Destination directory; created (with parents) if it does not exist.
    record_name
        File name without the ``.json`` suffix.
    metadata
        Descriptive key/value pairs stored under ``"metadata"``.
    pre_calibration
        Raw pre-calibration stage data; compacted before storage.
    post_calibration
        Raw post-calibration stage data; compacted before storage.

    Returns
    -------
    Path
        Path of the file that was written.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    payload = {
        "schema_version": SCHEMA_VERSION,
        "metadata": _to_serializable(dict(metadata)),
        "pre_calibration": _to_serializable(_compact_stage(pre_calibration, "pre")),
        "post_calibration": _to_serializable(_compact_stage(post_calibration, "post")),
    }
    # Serialize before opening the file: opening in "w" mode truncates it, so
    # a serialization error raised mid-write would otherwise leave a partial
    # record behind and destroy any previous record of the same name.
    text = json.dumps(payload, indent=2)
    record_path = output_dir / f"{record_name}.json"
    with record_path.open("w", encoding="utf-8") as f:
        f.write(text)
    return record_path
def load_calibration_record(path: Union[str, Path]) -> Dict[str, Any]:
    """Read one calibration record from a JSON file.

    Parameters
    ----------
    path
        Location of the JSON file.

    Returns
    -------
    Dict[str, Any]
        The parsed document, with an added ``"path"`` entry holding the
        POSIX-style form of ``path``.

    Raises
    ------
    TypeError
        If the top-level JSON value is not an object.
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    # Adding the "path" entry below needs a dict. Fail with a message that
    # names the file rather than an opaque item-assignment TypeError.
    if not isinstance(data, dict):
        raise TypeError(
            f"Calibration record {path.as_posix()} must contain a JSON object, "
            f"got {type(data).__name__}"
        )
    data["path"] = path.as_posix()
    return data
def iter_calibration_records(
    root: Union[str, Path],
    predicate: Optional[Callable[[Mapping[str, Any]], bool]] = None,
) -> Iterator[Dict[str, Any]]:
    """Yield stored calibration records, optionally filtered by metadata.

    Parameters
    ----------
    root
        Either a single record file, or a directory that is searched
        recursively for files matching ``_RECORD_GLOB``. Directory results
        are visited in sorted path order.
    predicate
        Optional callable that receives a record's metadata mapping; records
        for which it returns a falsy value are skipped.

    Yields
    ------
    Dict[str, Any]
        Each loaded record, as returned by ``load_calibration_record``.
    """
    root_path = Path(root)
    if root_path.is_file():
        paths = [root_path]
    else:
        paths = sorted(root_path.rglob(_RECORD_GLOB))
    for path in paths:
        record = load_calibration_record(path)
        # `or {}` also covers a "metadata" key stored as null, so the
        # predicate always receives a mapping.
        meta = record.get("metadata") or {}
        # Compare against None explicitly so that a callable object which
        # happens to be falsy is still applied.
        if predicate is not None and not predicate(meta):
            continue
        yield record
def metadata_matcher(**expected: Any) -> Callable[[Mapping[str, Any]], bool]:
    """Build a predicate that checks metadata against expected values.

    Parameters
    ----------
    **expected
        Key/value pairs that the metadata must contain.

    Returns
    -------
    Callable[[Mapping[str, Any]], bool]
        Function returning ``True`` when ``meta.get(key) == value`` holds for
        every expected pair. A key missing from the metadata is looked up as
        ``None``, so it only matches an expected value of ``None``. With no
        expected pairs the predicate accepts everything.
    """

    def _match(meta: Mapping[str, Any]) -> bool:
        return all(meta.get(key) == value for key, value in expected.items())

    return _match
def stack_empirical_curves(
    records: Iterable[Mapping[str, Any]],
    stage: str = "pre",
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Collect the empirical coverage curves of several records into arrays.

    Parameters
    ----------
    records
        Calibration records; each is expected to keep its stage data under
        the ``"<stage>_calibration"`` key. Records without
        ``"empirical_coverages"`` for the requested stage are skipped.
    stage
        Stage prefix, ``"pre"`` by default. For ``"pre"`` the per-record
        weight is read from ``"n_observations"``, otherwise from
        ``"weight"``.

    Returns
    -------
    nominal : np.ndarray
        Nominal coverages of the first record that provides them, or
        ``np.arange(n)`` (``n`` = length of the first curve) if none does.
    curves : np.ndarray
        The empirical curves stacked row-wise with ``np.vstack``.
    weights : Optional[np.ndarray]
        One weight per stacked curve, or ``None`` when every weight is zero.

    Raises
    ------
    ValueError
        If no record contains a curve for ``stage``.
    """
    weight_key = "n_observations" if stage == "pre" else "weight"
    curves = []
    weights = []
    nominal = None
    for record in records:
        stage_payload = record.get(f"{stage}_calibration") or {}
        empirical = stage_payload.get("empirical_coverages")
        if empirical is None:
            continue
        curves.append(np.asarray(empirical, dtype=float))
        if nominal is None and stage_payload.get("nominal_coverages") is not None:
            nominal = np.asarray(stage_payload["nominal_coverages"], dtype=float)
        # A weight stored as None must fall back to 1.0 just like a missing
        # one; otherwise it would turn into NaN in the float array below.
        weight = stage_payload.get(weight_key)
        weights.append(1.0 if weight is None else weight)
    if not curves:
        raise ValueError(f"No calibration curves found for stage '{stage}'")
    if nominal is None:
        nominal = np.arange(curves[0].shape[-1])
    curves_arr = np.vstack(curves)
    weights_arr = np.asarray(weights, dtype=float)
    if not np.any(weights_arr):
        weights_arr = None
    return nominal, curves_arr, weights_arr