Source code for calibrain.metric_evaluation

import logging
from typing import Optional, Dict, Any, Iterable, Tuple

import numpy as np
from calibrain import UncertaintyEstimator

try:
    from scipy.spatial.distance import cdist
except Exception:
    cdist = None

try:
    from ot import emd2
except Exception:
    emd2 = None


# =============================================================================
# Helpers
# =============================================================================
def lift_reduced_sources_to_3d(a_red: np.ndarray, Q_basis: np.ndarray) -> np.ndarray:
    """
    Lift reduced coordinates back to the retained local 3D source basis.

    Parameters
    ----------
    a_red : (N, k, T)
    Q_basis : (N, 3, k)

    Returns
    -------
    x_3d : (N, 3, T)
    """
    a_red = np.asarray(a_red, dtype=float)
    Q_basis = np.asarray(Q_basis, dtype=float)

    if a_red.ndim != 3:
        raise ValueError(f"Expected a_red with shape (N, k, T). Got {a_red.shape}")
    if Q_basis.ndim != 3:
        raise ValueError(f"Expected Q_basis with shape (N, 3, k). Got {Q_basis.shape}")
    if a_red.shape[0] != Q_basis.shape[0]:
        raise ValueError("Mismatch in number of sources between a_red and Q_basis")
    if a_red.shape[1] != Q_basis.shape[2]:
        raise ValueError("Mismatch in reduced dimension k between a_red and Q_basis")

    return np.einsum("nok,nkt->not", Q_basis, a_red)


def _reshape_free_mean(arr: np.ndarray, n_sources: int, n_components: int) -> np.ndarray:
    arr = np.asarray(arr, dtype=float)
    if arr.ndim == 3:
        if arr.shape[0] != n_sources or arr.shape[1] != n_components:
            raise ValueError(
                f"Expected shape ({n_sources},{n_components},T); got {arr.shape}"
            )
        return arr
    if arr.ndim != 2:
        raise ValueError("Posterior means must be 2D or 3D arrays.")
    n_times = arr.shape[1]
    expected = n_sources * n_components
    if arr.shape[0] != expected:
        raise ValueError(
            f"First dimension must equal {expected}; got {arr.shape[0]}"
        )
    return arr.reshape(n_components, n_sources, n_times).transpose(1, 0, 2)


[docs] def get_subset_source_rr( lf_dict: Dict[str, Any], *, to_mm: bool = False, ) -> np.ndarray: """ Get subset-aligned source coordinates for EMD and other source-space metrics. Parameters ---------- lf_dict : dict Output of extract_subset_leadfield(...), containing: - "fwd" - "subset_idx" to_mm : bool If True, return coordinates in millimeters. Otherwise return coordinates in meters (MNE default). Returns ------- coords_subset : ndarray, shape (N, 3) Source coordinates aligned exactly with lf_dict["subset_idx"] and therefore with the source ordering used in L_flat and L_block. Notes ----- - This helper uses the used-source ordering in the forward solution: [left hemisphere used sources, right hemisphere used sources] - This matches how subset_idx is constructed inside extract_subset_leadfield(...). - The same helper works for fixed, free_eeg, and free_meg, because subset_idx is source-level, not coefficient-level. """ if "fwd" not in lf_dict: raise ValueError('lf_dict must contain key "fwd".') if "subset_idx" not in lf_dict: raise ValueError('lf_dict must contain key "subset_idx".') fwd = lf_dict["fwd"] subset_idx = np.asarray(lf_dict["subset_idx"], dtype=int) src = fwd["src"] if len(src) != 2: raise ValueError("Expected a two-hemisphere source space in fwd['src'].") rr_lh = np.asarray(src[0]["rr"][src[0]["vertno"]], dtype=float) rr_rh = np.asarray(src[1]["rr"][src[1]["vertno"]], dtype=float) rr_used = np.vstack([rr_lh, rr_rh]) if np.any(subset_idx < 0) or np.any(subset_idx >= rr_used.shape[0]): raise ValueError( f"subset_idx contains out-of-range entries for used-source coordinates. " f"Valid range is [0, {rr_used.shape[0] - 1}]." ) coords_subset = rr_used[subset_idx] if to_mm: coords_subset = 1000.0 * coords_subset return coords_subset
# ============================================================================= # MetricEvaluator # ============================================================================= DEFAULT_CALIBRATION_METRICS = ( "mean_signed_deviation", "mean_absolute_deviation", "max_underconfidence_deviation", "max_overconfidence_deviation", )
[docs] class MetricEvaluator: """ Evaluator aligned with the updated UncertaintyEstimator. Supported settings ------------------ - fixed - eeg_free - meg_free Supported modes --------------- - pointwise - aggregated Conventions ----------- 1) fixed: x_true, x_hat have shape (N,T) 2) eeg_free: x_true, x_hat have shape (N,3,T) error metrics use amplitude representation ||x_i(t)||_2 3) meg_free: x_true is 3D truth in the retained local 3D basis, shape (N,3,T) x_hat is reduced 2D posterior mean, shape (N,2,T) or flat (2N,T) posterior covariance is reduced 2D, shape (2N,2N) error metrics use reduced-coordinate amplitude representation after projecting truth to 2D via the same V_tan basis used by the inverse Notes ----- - Aggregated mode always means time-average. - EMD expects coords already aligned with the source subset being evaluated. - Calibration metrics are: * max_underconfidence_deviation * max_overconfidence_deviation * mean_absolute_deviation * mean_signed_deviation - For MEG, pass V_tan = lf_free_meg["Q_basis"] whenever possible to avoid basis mismatches. """
[docs] def __init__( self, ue: UncertaintyEstimator, *, nominal_coverages: Optional[Iterable[float]] = None, evaluation_metrics: Optional[Iterable[str]] = None, calibration_metrics: Optional[Iterable[str]] = None, logger: Optional[logging.Logger] = None, ): self.ue = ue self.logger = logger or logging.getLogger(__name__) if nominal_coverages is not None: cov = np.asarray(list(nominal_coverages), dtype=float) if cov.ndim != 1 or cov.size == 0: raise ValueError("nominal_coverages must be a 1-D array with at least one entry.") self.ue.nominal_coverages = cov self.nominal_coverages = np.asarray(self.ue.nominal_coverages, dtype=float) self.evaluation_metrics = ( tuple(evaluation_metrics) if evaluation_metrics is not None else tuple() ) self.calibration_metrics = ( tuple(calibration_metrics) if calibration_metrics is not None else DEFAULT_CALIBRATION_METRICS )
# ------------------------------------------------------------------ # Validation helpers # ------------------------------------------------------------------ @staticmethod def _check_setting(setting: str) -> str: setting = (setting or "").lower().strip() if setting not in {"fixed", "eeg_free", "meg_free"}: raise ValueError("setting must be one of: 'fixed', 'eeg_free', 'meg_free'.") return setting @staticmethod def _check_mode(mode: str) -> str: mode = (mode or "").lower().strip() if mode not in {"pointwise", "aggregated"}: raise ValueError("mode must be one of: 'pointwise', 'aggregated'.") return mode def __getstate__(self): state = self.__dict__.copy() state['logger'] = None return state def __setstate__(self, state): self.__dict__.update(state) if self.logger is None: self.logger = logging.getLogger(self.__class__.__name__) if not hasattr(self, "evaluation_metrics"): legacy = state.get("metrics") if legacy is None: legacy = () self.evaluation_metrics = tuple(legacy) if not hasattr(self, "calibration_metrics") or self.calibration_metrics is None: self.calibration_metrics = DEFAULT_CALIBRATION_METRICS @staticmethod def _reshape_meg_mean_if_needed(x_hat_meg: np.ndarray, N: int, T: int) -> np.ndarray: x_hat_meg = np.asarray(x_hat_meg, dtype=float) if x_hat_meg.ndim == 3: if x_hat_meg.shape != (N, 2, T): raise ValueError(f"x_hat for meg_free must be (N,2,T); got {x_hat_meg.shape}") return x_hat_meg if x_hat_meg.ndim == 2: if x_hat_meg.shape != (2 * N, T): raise ValueError(f"x_hat for meg_free flat form must be (2N,T); got {x_hat_meg.shape}") return x_hat_meg.reshape(N, 2, T) raise ValueError("x_hat for meg_free must have shape (N,2,T) or (2N,T).") @staticmethod def _cov_blocks_from_full(posterior_uncert: np.ndarray, block_dim: int) -> np.ndarray: posterior_uncert = np.asarray(posterior_uncert, dtype=float) if posterior_uncert.ndim == 3: if posterior_uncert.shape[1:] != (block_dim, block_dim): raise ValueError( f"Block covariance form must be (N,{block_dim},{block_dim}); " f"got {posterior_uncert.shape}" ) return posterior_uncert if posterior_uncert.ndim == 2: M, K = posterior_uncert.shape if M != K or M % block_dim != 0: raise ValueError( f"Full covariance must be square with size multiple of {block_dim}; " f"got {posterior_uncert.shape}" ) N = M // block_dim blocks = np.zeros((N, block_dim, block_dim), dtype=float) for i in range(N): blocks[i] = posterior_uncert[ block_dim * i:block_dim * i + block_dim, block_dim * i:block_dim * i + block_dim, ] return blocks raise ValueError("posterior_uncert must be either full covariance or block covariance.") @staticmethod def _fixed_variance_from_uncert(posterior_uncert: np.ndarray) -> np.ndarray: posterior_uncert = np.asarray(posterior_uncert, dtype=float) if posterior_uncert.ndim == 1: return np.maximum(posterior_uncert, 0.0) if posterior_uncert.ndim == 2: if posterior_uncert.shape[0] != posterior_uncert.shape[1]: raise ValueError("Fixed posterior covariance must be square.") return np.maximum(np.diag(posterior_uncert), 0.0) raise ValueError("For fixed setting, posterior_uncert must be (N,) or (N,N).") def _get_meg_truth_2d( self, x_true_meg_3d: np.ndarray, *, V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, ) -> np.ndarray: x_true_meg_3d = np.asarray(x_true_meg_3d, dtype=float) if x_true_meg_3d.ndim != 3 or x_true_meg_3d.shape[1] != 3: raise ValueError(f"x_true for meg_free must be (N,3,T); got {x_true_meg_3d.shape}") N = x_true_meg_3d.shape[0] if V_tan is None: if L_free_Mx3N is None: raise ValueError("Provide either V_tan or L_free_Mx3N for meg_free.") V_tan = self.ue.precompute_meg_tangent_bases_svd(L_free_Mx3N) V_tan = np.asarray(V_tan, dtype=float) if V_tan.shape != (N, 3, 2): raise ValueError(f"V_tan must be (N,3,2); got {V_tan.shape}") return self.ue.project_meg_3d_to_2d(x_true_meg_3d, V_tan)
[docs] @staticmethod def calibration_metrics_4(nominal: np.ndarray, empirical: np.ndarray) -> Dict[str, float]: nominal = np.asarray(nominal, dtype=float) empirical = np.asarray(empirical, dtype=float) dev = empirical - nominal under = np.maximum(nominal - empirical, 0.0) over = np.maximum(empirical - nominal, 0.0) return { "max_underconfidence_deviation": float(np.max(under)), "max_overconfidence_deviation": float(np.max(over)), "mean_absolute_deviation": float(np.mean(np.abs(dev))), "mean_signed_deviation": float(np.mean(dev)), }
# ------------------------------------------------------------------ # Signal extraction for error metrics # ------------------------------------------------------------------ def _signals_for_error_metrics( self, *, x_true: np.ndarray, x_hat: np.ndarray, setting: str, mode: str, V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, ) -> Dict[str, np.ndarray]: setting = self._check_setting(setting) mode = self._check_mode(mode) if setting == "fixed": x_true = np.asarray(x_true, dtype=float) x_hat = np.asarray(x_hat, dtype=float) if x_true.ndim != 2 or x_hat.ndim != 2 or x_true.shape != x_hat.shape: raise ValueError("For fixed, x_true and x_hat must both be (N,T).") if mode == "pointwise": return {"truth_signal": x_true, "est_signal": x_hat} x_true_agg = np.mean(x_true, axis=1) x_hat_agg = np.mean(x_hat, axis=1) return {"truth_signal": x_true_agg, "est_signal": x_hat_agg} if setting == "eeg_free": x_true = np.asarray(x_true, dtype=float) x_hat = np.asarray(x_hat, dtype=float) if x_true.ndim != 3 or x_true.shape[1] != 3: raise ValueError(f"For eeg_free, x_true must be (N,3,T); got {x_true.shape}") if x_hat.shape != x_true.shape: raise ValueError("For eeg_free, x_hat must match x_true shape (N,3,T).") if mode == "pointwise": amp_true = np.linalg.norm(x_true, axis=1) amp_hat = np.linalg.norm(x_hat, axis=1) return {"truth_signal": amp_true, "est_signal": amp_hat} x_true_agg = np.mean(x_true, axis=2) x_hat_agg = np.mean(x_hat, axis=2) amp_true_agg = np.linalg.norm(x_true_agg, axis=1) amp_hat_agg = np.linalg.norm(x_hat_agg, axis=1) return {"truth_signal": amp_true_agg, "est_signal": amp_hat_agg} # meg_free x_true = np.asarray(x_true, dtype=float) if x_true.ndim != 3 or x_true.shape[1] != 3: raise ValueError(f"For meg_free, x_true must be (N,3,T); got {x_true.shape}") N, _, T = x_true.shape x_hat_2d = self._reshape_meg_mean_if_needed(x_hat, N, T) x_true_2d = self._get_meg_truth_2d( x_true, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ) if mode == "pointwise": amp_true = np.linalg.norm(x_true_2d, axis=1) amp_hat = np.linalg.norm(x_hat_2d, axis=1) return {"truth_signal": amp_true, "est_signal": amp_hat} x_true_agg = np.mean(x_true_2d, axis=2) x_hat_agg = np.mean(x_hat_2d, axis=2) amp_true_agg = np.linalg.norm(x_true_agg, axis=1) amp_hat_agg = np.linalg.norm(x_hat_agg, axis=1) return {"truth_signal": amp_true_agg, "est_signal": amp_hat_agg} # ------------------------------------------------------------------ # Error metrics # ------------------------------------------------------------------
[docs] def mse( self, *, x_true: np.ndarray, x_hat: np.ndarray, setting: str, mode: str = "pointwise", V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, ) -> float: sig = self._signals_for_error_metrics( x_true=x_true, x_hat=x_hat, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ) d = sig["truth_signal"] - sig["est_signal"] return float(np.mean(d * d))
[docs] def mae( self, *, x_true: np.ndarray, x_hat: np.ndarray, setting: str, mode: str = "pointwise", V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, ) -> float: sig = self._signals_for_error_metrics( x_true=x_true, x_hat=x_hat, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ) return float(np.mean(np.abs(sig["truth_signal"] - sig["est_signal"])))
[docs] def rmse( self, *, x_true: np.ndarray, x_hat: np.ndarray, setting: str, mode: str = "pointwise", V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, ) -> float: return float(np.sqrt(self.mse( x_true=x_true, x_hat=x_hat, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, )))
[docs] def rmae( self, *, x_true: np.ndarray, x_hat: np.ndarray, setting: str, mode: str = "pointwise", V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, ) -> float: return float(np.sqrt(self.mae( x_true=x_true, x_hat=x_hat, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, )))
# ------------------------------------------------------------------ # Posterior uncertainty summary # ------------------------------------------------------------------
[docs] def mean_posterior_std( self, *, posterior_uncert: np.ndarray, setting: str, mode: str = "pointwise", n_times: Optional[int] = None, ) -> float: setting = self._check_setting(setting) mode = self._check_mode(mode) if setting == "fixed": var = self._fixed_variance_from_uncert(posterior_uncert) std = np.sqrt(np.maximum(var, 0.0)) if mode == "aggregated": if n_times is None: raise ValueError("For aggregated mode, n_times must be provided.") std = std / np.sqrt(float(n_times)) return float(np.mean(std)) if setting == "eeg_free": blocks = self._cov_blocks_from_full(posterior_uncert, block_dim=3) source_std = np.sqrt(np.maximum(np.trace(blocks, axis1=1, axis2=2) / 3.0, 0.0)) if mode == "aggregated": if n_times is None: raise ValueError("For aggregated mode, n_times must be provided.") source_std = source_std / np.sqrt(float(n_times)) return float(np.mean(source_std)) # meg_free blocks = self._cov_blocks_from_full(posterior_uncert, block_dim=2) source_std = np.sqrt(np.maximum(np.trace(blocks, axis1=1, axis2=2) / 2.0, 0.0)) if mode == "aggregated": if n_times is None: raise ValueError("For aggregated mode, n_times must be provided.") source_std = source_std / np.sqrt(float(n_times)) return float(np.mean(source_std))
# ------------------------------------------------------------------ # EMD # ------------------------------------------------------------------ def _source_mass_for_emd( self, *, x: np.ndarray, setting: str, mode: str, V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, is_truth: bool = True, ) -> np.ndarray: setting = self._check_setting(setting) mode = self._check_mode(mode) if setting == "fixed": x = np.asarray(x, dtype=float) if x.ndim != 2: raise ValueError("For fixed, x must be (N,T).") if mode == "pointwise": return np.linalg.norm(x, axis=1) return np.abs(np.mean(x, axis=1)) if setting == "eeg_free": x = np.asarray(x, dtype=float) if x.ndim != 3 or x.shape[1] != 3: raise ValueError("For eeg_free, x must be (N,3,T).") if mode == "pointwise": amp = np.linalg.norm(x, axis=1) return np.linalg.norm(amp, axis=1) x_agg = np.mean(x, axis=2) return np.linalg.norm(x_agg, axis=1) # meg_free if is_truth: x_true_2d = self._get_meg_truth_2d( x, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ) if mode == "pointwise": amp = np.linalg.norm(x_true_2d, axis=1) return np.linalg.norm(amp, axis=1) x_agg = np.mean(x_true_2d, axis=2) return np.linalg.norm(x_agg, axis=1) x = np.asarray(x, dtype=float) if x.ndim != 3 or x.shape[1] != 2: raise ValueError("For meg_free estimate, x must be (N,2,T).") if mode == "pointwise": amp = np.linalg.norm(x, axis=1) return np.linalg.norm(amp, axis=1) x_agg = np.mean(x, axis=2) return np.linalg.norm(x_agg, axis=1)
[docs] def emd( self, *, x_true: np.ndarray, x_hat: np.ndarray, coords: np.ndarray, setting: str, mode: str = "pointwise", V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, eps: float = 1e-12, ) -> float: if cdist is None or emd2 is None: raise ImportError("EMD requires scipy.spatial.distance.cdist and POT (ot.emd2).") setting = self._check_setting(setting) mode = self._check_mode(mode) coords = np.asarray(coords, dtype=float) if coords.ndim != 2 or coords.shape[1] != 3: raise ValueError("coords must be aligned subset coordinates with shape (N,3).") if setting == "meg_free": x_true = np.asarray(x_true, dtype=float) if x_true.ndim != 3 or x_true.shape[1] != 3: raise ValueError("For meg_free, x_true must be (N,3,T).") N, _, T = x_true.shape x_hat_2d = self._reshape_meg_mean_if_needed(x_hat, N, T) a = self._source_mass_for_emd( x=x_true, setting="meg_free", mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, is_truth=True, ) b = self._source_mass_for_emd( x=x_hat_2d, setting="meg_free", mode=mode, is_truth=False, ) else: a = self._source_mass_for_emd( x=x_true, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, is_truth=True, ) b = self._source_mass_for_emd( x=x_hat, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, is_truth=False, ) if coords.shape[0] != a.shape[0] or coords.shape[0] != b.shape[0]: raise ValueError( f"coords must align with evaluated subset. " f"Got coords.shape[0]={coords.shape[0]}, masses {a.shape[0]} and {b.shape[0]}" ) a_mask = a > eps b_mask = b > eps if not np.any(a_mask) or not np.any(b_mask): self.logger.warning("EMD: empty active set in true or estimate -> returning inf.") return float(np.inf) a_w = a[a_mask] b_w = b[b_mask] rr_a = coords[a_mask] rr_b = coords[b_mask] M = cdist(rr_a, rr_b, metric="euclidean") a_norm = a_w / np.sum(a_w) b_norm = b_w / np.sum(b_w) return float(emd2(a_norm, b_norm, M))
# ------------------------------------------------------------------ # Calibration # ------------------------------------------------------------------
[docs] def calibration_curve( self, *, x_true: np.ndarray, x_hat: np.ndarray, posterior_uncert: np.ndarray, setting: str, mode: str = "aggregated", V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, free_interval_type: str = "full_cov", ) -> Dict[str, Any]: setting = self._check_setting(setting) mode = self._check_mode(mode) if free_interval_type not in {"full_cov", "marginal"}: raise ValueError( "free_interval_type must be 'full_cov' or 'marginal'. " f"Got {free_interval_type!r}." ) if setting == "fixed": posterior_var = self._fixed_variance_from_uncert(posterior_uncert) if mode == "pointwise": curve = self.ue.calibration_curve_intervals_pointwise( x_true=x_true, x_hat=x_hat, posterior_var=posterior_var, ) else: curve = self.ue.calibration_curve_intervals_aggregated( x_true=x_true, x_hat=x_hat, posterior_var=posterior_var, ) elif setting == "eeg_free": if free_interval_type == "marginal": if mode == "pointwise": curve = self.ue.calibration_curve_componentwise_eeg_free_pointwise( x_true=x_true, x_hat=x_hat, posterior_uncert=posterior_uncert, ) else: curve = self.ue.calibration_curve_componentwise_eeg_free_aggregated( x_true=x_true, x_hat=x_hat, posterior_uncert=posterior_uncert, ) else: if mode == "pointwise": curve = self.ue.calibration_curve_ellipsoid_eeg_free_pointwise( x_true=x_true, x_hat=x_hat, posterior_cov=posterior_uncert, ) else: curve = self.ue.calibration_curve_ellipsoid_eeg_free_aggregated( x_true=x_true, x_hat=x_hat, posterior_cov=posterior_uncert, ) else: # meg_free x_true = np.asarray(x_true, dtype=float) if x_true.ndim != 3: raise ValueError(f"For meg_free, x_true must be 3D (N,K,T); got {x_true.shape}") N, K, T = x_true.shape if free_interval_type == "marginal": # Work in reduced 2D tangent coordinates. if K == 2: x_true_2d = x_true elif K == 3: if V_tan is None: raise ValueError("meg_free marginal calibration needs V_tan when x_true is 3D.") V = np.asarray(V_tan, dtype=float) if V.shape != (N, 3, 2): raise ValueError(f"V_tan must have shape (N,3,2); got {V.shape}") x_true_2d = np.einsum('nck,nct->nkt', V, x_true) else: raise ValueError(f"For meg_free marginal, x_true must have K=2 or K=3; got {K}") x_hat_arr = np.asarray(x_hat, dtype=float) if x_hat_arr.ndim == 3 and x_hat_arr.shape[1] == 3: if V_tan is None: raise ValueError("meg_free marginal calibration needs V_tan when x_hat is 3D.") V = np.asarray(V_tan, dtype=float) if V.shape != (N, 3, 2): raise ValueError(f"V_tan must have shape (N,3,2); got {V.shape}") x_hat_2d = np.einsum('nck,nct->nkt', V, x_hat_arr) else: x_hat_2d = self._reshape_meg_mean_if_needed(x_hat, N, T) if mode == "pointwise": curve = self.ue.calibration_curve_componentwise_meg_free_pointwise( x_true_2d=x_true_2d, x_hat_2d=x_hat_2d, posterior_uncert_2d=posterior_uncert, ) else: curve = self.ue.calibration_curve_componentwise_meg_free_aggregated( x_true_2d=x_true_2d, x_hat_2d=x_hat_2d, posterior_uncert_2d=posterior_uncert, ) else: # Full-covariance 3D ellipses: x_true may be reduced (K=2) or lifted (K=3). if K == 2: if V_tan is None: raise ValueError("meg_free full_cov calibration needs V_tan when x_true is reduced 2D.") q_basis = np.asarray(V_tan, dtype=float) if q_basis.shape != (N, 3, 2): raise ValueError(f"V_tan must have shape (N,3,2); got {q_basis.shape}") x_true_3d = lift_reduced_sources_to_3d(x_true, q_basis) elif K == 3: x_true_3d = x_true else: raise ValueError(f"For meg_free, x_true must have K=2 (reduced) or K=3 (lifted); got {K}") x_hat_2d = self._reshape_meg_mean_if_needed(x_hat, N, T) if mode == "pointwise": curve = self.ue.calibration_curve_ellipse_meg_free_pointwise( x_true_3d=x_true_3d, x_hat_2d=x_hat_2d, posterior_cov_2d=posterior_uncert, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ) else: curve = self.ue.calibration_curve_ellipse_meg_free_aggregated( x_true_3d=x_true_3d, x_hat_2d=x_hat_2d, posterior_cov_2d=posterior_uncert, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ) nominal = np.asarray(curve["nominal_coverages"], dtype=float) empirical = np.asarray(curve["empirical_coverages"], dtype=float) return { "nominal": nominal, "empirical": empirical, "metrics_4": self.calibration_metrics_4(nominal, empirical), }
# ------------------------------------------------------------------ # All metrics together # ------------------------------------------------------------------
[docs] def evaluate_all( self, *, x_true: np.ndarray, x_hat: np.ndarray, posterior_uncert: np.ndarray, setting: str, mode: str = "aggregated", coords: Optional[np.ndarray] = None, V_tan: Optional[np.ndarray] = None, L_free_Mx3N: Optional[np.ndarray] = None, compute_emd: bool = False, free_interval_type: str = "full_cov", ) -> Dict[str, Any]: setting = self._check_setting(setting) mode = self._check_mode(mode) if mode == "aggregated": if setting == "fixed": n_times = np.asarray(x_true).shape[1] else: n_times = np.asarray(x_true).shape[2] else: n_times = None out = { "mse": self.mse( x_true=x_true, x_hat=x_hat, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ), "mae": self.mae( x_true=x_true, x_hat=x_hat, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ), "rmse": self.rmse( x_true=x_true, x_hat=x_hat, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ), "rmae": self.rmae( x_true=x_true, x_hat=x_hat, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ), "mean_posterior_std": self.mean_posterior_std( posterior_uncert=posterior_uncert, setting=setting, mode=mode, n_times=n_times, ), "calibration": self.calibration_curve( x_true=x_true, x_hat=x_hat, posterior_uncert=posterior_uncert, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, free_interval_type=free_interval_type, ), } if compute_emd: if coords is None: raise ValueError("compute_emd=True requires subset-aligned coords with shape (N,3).") out["emd"] = self.emd( x_true=x_true, x_hat=x_hat, coords=coords, setting=setting, mode=mode, V_tan=V_tan, L_free_Mx3N=L_free_Mx3N, ) return out
def _prepare_meg_sources_for_emd_dataset(eval_data: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: n_sources = int(eval_data.get("n_sources") or 0) if n_sources <= 0: raise ValueError("MEG dataset missing n_sources metadata required for EMD.") q_basis = eval_data.get("Q_basis") if q_basis is None: raise ValueError("MEG dataset missing Q_basis required for EMD.") q_basis = np.asarray(q_basis, dtype=float) x_hat = _reshape_free_mean(np.asarray(eval_data["x_hat"], dtype=float), n_sources, 2) x_true_raw = np.asarray(eval_data["x_true"], dtype=float) if x_true_raw.ndim == 3 and x_true_raw.shape[1] == 2: x_true_2d = x_true_raw x_true_3d = lift_reduced_sources_to_3d(x_true_2d, q_basis) elif x_true_raw.ndim == 3 and x_true_raw.shape[1] == 3: x_true_3d = x_true_raw basis_T = np.transpose(q_basis, (0, 2, 1)) x_true_2d = np.einsum("nki,nit->nkt", basis_T, x_true_3d) else: x_true_2d = _reshape_free_mean(x_true_raw, n_sources, 2) x_true_3d = lift_reduced_sources_to_3d(x_true_2d, q_basis) return x_true_2d, x_true_3d, x_hat, q_basis def _prepare_eeg_sources_for_emd_dataset(eval_data: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]: n_sources = int(eval_data.get("n_sources") or 0) if n_sources <= 0: raise ValueError("EEG dataset missing n_sources metadata required for EMD.") x_true = np.asarray(eval_data["x_true"], dtype=float) if not (x_true.ndim == 3 and x_true.shape[1] == 3): x_true = _reshape_free_mean(x_true, n_sources, 3) x_hat = np.asarray(eval_data["x_hat"], dtype=float) if not (x_hat.ndim == 3 and x_hat.shape[1] == 3): x_hat = _reshape_free_mean(x_hat, n_sources, 3) return x_true, x_hat def _compute_dataset_emd( *, metric_evaluator: MetricEvaluator, eval_data: Dict[str, Any], coords: Optional[np.ndarray], setting: Optional[str], emd_mode: str = "reduced", ) -> Optional[float]: if coords is None or setting is None: return None logger = getattr(metric_evaluator, "logger", None) or logging.getLogger(__name__) mode = (emd_mode or "reduced").lower() if mode not in {"reduced", "lifted"}: raise ValueError("emd_mode must be 'reduced' or 'lifted'.") try: if setting == "meg_free": x_true_2d, x_true_3d, x_hat_2d, q_basis = _prepare_meg_sources_for_emd_dataset(eval_data) if mode == "lifted": x_hat_3d = lift_reduced_sources_to_3d(x_hat_2d, q_basis) return metric_evaluator.emd( x_true=x_true_3d, x_hat=x_hat_3d, coords=coords, setting="eeg_free", mode="aggregated", ) return metric_evaluator.emd( x_true=x_true_2d, x_hat=x_hat_2d, coords=coords, setting="meg_free", mode="aggregated", V_tan=q_basis, ) if setting == "eeg_free": x_true_3d, x_hat_3d = _prepare_eeg_sources_for_emd_dataset(eval_data) return metric_evaluator.emd( x_true=x_true_3d, x_hat=x_hat_3d, coords=coords, setting="eeg_free", mode="aggregated", ) return metric_evaluator.emd( x_true=eval_data["x_true"], x_hat=eval_data["x_hat"], coords=coords, setting=setting, mode="aggregated", V_tan=eval_data.get("Q_basis"), ) except Exception as exc: logger.warning("EMD computation failed (%s mode): %s", setting, exc) return None