Source code for calibrain.uncertainty_estimation


import os
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from scipy.stats import chi2, norm
from itertools import combinations, zip_longest
import mne
import logging
from typing import Optional, Dict, Any, Tuple


def lift_reduced_sources_to_3d(a_red: np.ndarray, Q_basis: np.ndarray) -> np.ndarray:
    """
    Map reduced per-source coefficients into the retained local 3D basis.

    Parameters
    ----------
    a_red : (N, k, T)
        Reduced source coefficients, typically with k=2 for free_meg.
    Q_basis : (N, 3, k)
        Per-source basis taking reduced coordinates to the local 3D basis.

    Returns
    -------
    x_3d : (N, 3, T)
        For every source n, ``Q_basis[n] @ a_red[n]``.

    Raises
    ------
    ValueError
        If either input is not 3-dimensional, or if the number of sources
        or the reduced dimension k differ between the two inputs.
    """
    coeffs = np.asarray(a_red, dtype=float)
    basis = np.asarray(Q_basis, dtype=float)

    if coeffs.ndim != 3:
        raise ValueError(f"Expected a_red with shape (N, k, T). Got {coeffs.shape}")
    if basis.ndim != 3:
        raise ValueError(f"Expected Q_basis with shape (N, 3, k). Got {basis.shape}")

    n_sources, k_reduced, _ = coeffs.shape
    if n_sources != basis.shape[0]:
        raise ValueError("Mismatch in number of sources between a_red and Q_basis")
    if k_reduced != basis.shape[2]:
        raise ValueError("Mismatch in reduced dimension k between a_red and Q_basis")

    # Batched matrix product over the source axis: (N,o,k) x (N,k,T) -> (N,o,T).
    return np.einsum("nok,nkt->not", basis, coeffs)
# ============================================================================= # Uncertainty estimator # =============================================================================
class UncertaintyEstimator:
    """
    Uncertainty estimator for fixed- and free-orientation source estimates.

    Supported settings
    ------------------
    1) Fixed orientation (EEG or MEG):
        - pointwise marginal credible interval membership
        - aggregated-mean marginal credible interval membership

    2) Free orientation EEG:
        - pointwise 3D credible ellipsoid membership
        - aggregated-mean 3D credible ellipsoid membership

    3) Free orientation MEG (reduced rank-2 model):
        - pointwise 2D credible ellipse membership in reduced coordinates
        - aggregated-mean 2D credible ellipse membership in reduced coordinates

    Modeling convention
    -------------------
    - posterior mean varies over time
    - posterior covariance is static over time
    - when aggregating by time-average, covariance is scaled by 1 / T

    Important MEG convention
    ------------------------
    For MEG, the inverse model is assumed to be built on a reduced
    2-orientation leadfield obtained source-wise via SVD or, preferably,
    taken directly from the extractor as `Q_basis`.

    Therefore:
    - truth can be represented in reduced coordinates a_true with shape (N,2,T)
    - if needed, truth can be lifted to local 3D via Q_basis -> shape (N,3,T)
    - posterior mean is in reduced 2D coordinates, shape (N,2,T)
    - posterior covariance is in reduced coordinates, shape (2N,2N)

    Recommended usage
    -----------------
    Use the extractor-provided basis:

        V_tan = lf_free_meg["Q_basis"]

    instead of recomputing a new basis inside the uncertainty step.
    """
def __init__(
    self,
    nominal_coverages: Optional[np.ndarray] = None,
    logger: Optional[logging.Logger] = None,
):
    """
    Store the nominal coverage grid, the logger and a z-score lookup table.

    Parameters
    ----------
    nominal_coverages : array-like, optional
        Coverage levels in [0, 1]. Duplicates are removed and the values
        are stored in ascending order. When omitted, a default 15-point
        grid from 0.0 to 1.0 is used.
    logger : logging.Logger, optional
        Logger to use; defaults to ``logging.getLogger(__name__)``.

    Raises
    ------
    ValueError
        If any coverage level is below 0 or above 1.
    """
    if nominal_coverages is None:
        levels = np.array(
            [0.0, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5,
             0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1.0],
            dtype=float,
        )
    else:
        levels = np.asarray(nominal_coverages, dtype=float)

    out_of_range = (levels < 0.0) | (levels > 1.0)
    if np.any(out_of_range):
        raise ValueError("Nominal coverages must be between 0 and 1.")

    # np.unique de-duplicates and already returns the values sorted ascending.
    self.nominal_coverages = np.unique(levels)
    self.logger = logger or logging.getLogger(__name__)

    # Two-sided standard-normal quantile per coverage level; the end points
    # 0 and 1 are special-cased to a zero-width and an unbounded interval.
    self.z_scores: Dict[float, float] = {}
    for level in map(float, self.nominal_coverages):
        if np.isclose(level, 0.0):
            z_value = 0.0
        elif np.isclose(level, 1.0):
            z_value = np.inf
        else:
            z_value = float(norm.ppf((1.0 + level) / 2.0))
        self.z_scores[level] = z_value
# def __getstate__(self):
#     state = self.__dict__.copy()
#     state['logger'] = None
#     return state

# def __setstate__(self, state):
#     self.__dict__.update(state)
#     if self.logger is None:
#         self.logger = logging.getLogger(self.__class__.__name__)

# ------------------------------------------------------------------
# Core helpers
# ------------------------------------------------------------------
def _get_z(self, nominal_coverage: float) -> float:
    """
    Return the two-sided standard-normal quantile for a coverage level.

    Coverage ~0 maps to 0.0 and coverage ~1 maps to +inf. Any other level
    is looked up in ``self.z_scores`` and computed and cached there on a
    miss.
    """
    level = float(nominal_coverage)
    if np.isclose(level, 0.0):
        return 0.0
    if np.isclose(level, 1.0):
        return np.inf

    cached = self.z_scores.get(level)
    if cached is None:
        cached = float(norm.ppf((1.0 + level) / 2.0))
        self.z_scores[level] = cached
    return cached


@staticmethod
def _get_chi2_threshold(nominal_coverage: float, df: int) -> float:
    """
    Return the chi-square quantile at ``nominal_coverage`` for ``df`` degrees
    of freedom, with coverage ~0 mapped to 0.0 and coverage ~1 to +inf.
    """
    level = float(nominal_coverage)
    if np.isclose(level, 0.0):
        return 0.0
    if np.isclose(level, 1.0):
        return np.inf
    return float(chi2.ppf(level, df=df))


@staticmethod
def _psd_clip_block(cov_block: np.ndarray, eps: float) -> np.ndarray:
    """
    Symmetrize a covariance block and floor its eigenvalues at ``eps``.

    The block is rebuilt from its eigen-decomposition with the clipped
    eigenvalues and symmetrized once more before being returned.
    """
    symmetric = (cov_block + cov_block.T) / 2.0
    eigvals, eigvecs = np.linalg.eigh(symmetric)
    clipped = np.maximum(eigvals, eps)
    rebuilt = eigvecs @ np.diag(clipped) @ eigvecs.T
    return (rebuilt + rebuilt.T) / 2.0


@staticmethod
def _canonicalize_basis_columns(Q: np.ndarray) -> np.ndarray:
    """
    Stabilize basis column signs so that the entry with largest absolute
    value in each column is nonnegative.

    Returns a new float array; the input is not modified.

    Raises
    ------
    ValueError
        If ``Q`` is not a 2D matrix.
    """
    basis = np.asarray(Q, dtype=float)
    if basis.ndim != 2:
        raise ValueError(f"Expected 2D basis matrix. Got {basis.shape}")

    result = np.array(basis, dtype=float, copy=True)
    # Rows of the transpose are views onto the columns of `result`, so the
    # in-place sign flip below writes straight into `result`.
    for column in result.T:
        pivot = column[int(np.argmax(np.abs(column)))]
        if pivot < 0:
            column *= -1.0
    return result
@staticmethod
def posterior_variance_from_cov(
    posterior_cov: np.ndarray,
    *,
    min_var: float = 0.0,
) -> np.ndarray:
    """
    Return the diagonal of a covariance matrix as floored marginal variances.

    Parameters
    ----------
    posterior_cov : 2D array-like
        Covariance matrix whose main diagonal is extracted.
    min_var : float, keyword-only
        Lower bound applied to every variance. Values <= 0 are treated as a
        floor of exactly 0, so negative diagonal entries are clipped to 0.

    Returns
    -------
    np.ndarray
        A new 1D float array of floored variances.
    """
    diagonal = np.diag(posterior_cov).astype(float, copy=True)
    floor = 0.0 if min_var <= 0.0 else float(min_var)
    return np.maximum(diagonal, floor)
@staticmethod
def _set_equal_3d_limits_centered(
    ax,
    center: np.ndarray,
    xyz_points: np.ndarray,
    margin: float = 0.08,
) -> None:
    """
    Give a 3D axes equal-length x/y/z ranges centered on ``center``.

    The half-width is the largest absolute per-coordinate deviation of any
    point from ``center``, enlarged by ``margin`` and never smaller than
    1e-6. An empty point set, or one containing only non-finite values,
    falls back to that minimal half-width instead of raising.

    Parameters
    ----------
    ax : axes object
        Must provide ``set_xlim``, ``set_ylim`` and ``set_zlim``;
        ``set_box_aspect`` is used when it works.
    center : array-like with 3 elements
    xyz_points : array-like reshaped to (-1, 3)
    margin : float
        Relative padding added around the largest deviation.
    """
    center = np.asarray(center, dtype=float).reshape(3)
    xyz_points = np.asarray(xyz_points, dtype=float).reshape(-1, 3)

    deviations = np.abs(xyz_points - center[None, :])
    # NaN/inf deviations would otherwise propagate into the axis limits.
    deviations = deviations[np.isfinite(deviations)]
    max_dev = float(np.max(deviations)) if deviations.size else 0.0
    half = max((1.0 + margin) * max_dev, 1e-6)

    ax.set_xlim(center[0] - half, center[0] + half)
    ax.set_ylim(center[1] - half, center[1] + half)
    ax.set_zlim(center[2] - half, center[2] + half)

    try:
        ax.set_box_aspect((1, 1, 1))
    except Exception:
        # Deliberately best-effort: the limits above are already equal, so a
        # failure to set the box aspect must not abort the plot.
        pass

# ------------------------------------------------------------------
# Fixed orientation
# ------------------------------------------------------------------
def credible_intervals_normal(
    self,
    mean: np.ndarray,
    variance: np.ndarray,
    nominal_coverage: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build symmetric normal credible intervals ``mean +/- z * sqrt(variance)``.

    Parameters
    ----------
    mean : array-like
        Interval centers. Non-floating input (lists, integer arrays,
        scalars) is converted to a float array so that the infinite bounds
        at coverage 1.0 are representable; floating arrays keep their dtype.
    variance : array-like
        Variances broadcastable against ``mean``; negative values are
        clipped to 0 before the square root.
    nominal_coverage : float
        Coverage level. ~0 yields the degenerate interval [mean, mean] and
        ~1 yields (-inf, +inf).

    Returns
    -------
    (lower, upper) : tuple of np.ndarray
    """
    mean = np.asarray(mean)
    if mean.dtype.kind not in "fc":
        # Integer/bool/object means cannot hold +/-inf and lists have no
        # .copy(); promote them once here.
        mean = mean.astype(float)
    variance = np.asarray(variance, dtype=float)

    c = float(nominal_coverage)
    if np.isclose(c, 0.0):
        return mean.copy(), mean.copy()
    if np.isclose(c, 1.0):
        return np.full_like(mean, -np.inf), np.full_like(mean, np.inf)

    z = self._get_z(c)
    std = np.sqrt(np.maximum(variance, 0.0))
    return mean - z * std, mean + z * std
def pointwise_interval_membership(
    self,
    x_true: np.ndarray,  # (N,T)
    x_hat: np.ndarray,  # (N,T)
    posterior_var: np.ndarray,  # (N,)
    nominal_coverage: float,
) -> Dict[str, Any]:
    """
    Check, per source and time sample, whether the truth lies inside the
    marginal normal credible interval around the estimate.

    The per-source variance is held constant over time, i.e. it is
    replicated along the time axis before the intervals are built.

    Returns
    -------
    dict
        Interval bounds, the boolean ``within`` mask of shape (N,T), the
        replicated variance, the z-score, membership counts, the empirical
        coverage (mean of ``within``) and the number of time samples.

    Raises
    ------
    ValueError
        If ``x_true``/``x_hat`` are not 2D with equal shapes, or if
        ``posterior_var`` is not 1D of length N.
    """
    truth = np.asarray(x_true, dtype=float)
    estimate = np.asarray(x_hat, dtype=float)
    variance = np.asarray(posterior_var, dtype=float)

    if truth.ndim != 2 or estimate.ndim != 2:
        raise ValueError("x_true and x_hat must have shape (N,T) for fixed orientation.")
    if truth.shape != estimate.shape:
        raise ValueError("x_true and x_hat must have the same shape.")
    if variance.ndim != 1 or variance.shape[0] != estimate.shape[0]:
        raise ValueError("posterior_var must have shape (N,).")

    n_times = estimate.shape[1]
    # Static covariance convention: the same variance at every time sample.
    variance_nt = np.repeat(variance[:, None], n_times, axis=1)
    lower, upper = self.credible_intervals_normal(estimate, variance_nt, nominal_coverage)
    inside = (truth >= lower) & (truth <= upper)

    return {
        "ci_lower": lower,
        "ci_upper": upper,
        "within": inside,
        "posterior_var": variance_nt,
        "z_score": float(self._get_z(nominal_coverage)),
        "count_within": int(np.sum(inside)),
        "total_count": int(inside.size),
        "empirical_coverage": float(np.mean(inside)),
        "n_times": int(n_times),
    }
def aggregated_interval_membership(
    self,
    x_true: np.ndarray,  # (N,T)
    x_hat: np.ndarray,  # (N,T)
    posterior_var: np.ndarray,  # (N,)
    nominal_coverage: float,
) -> Dict[str, Any]:
    """
    Check, per source, whether the time-averaged truth lies inside the
    marginal normal credible interval around the time-averaged estimate.

    Both signals are averaged over the time axis and the per-source
    variance is divided by the number of time samples T.

    Returns
    -------
    dict
        Aggregated truth/estimate, scaled variance, interval bounds, the
        boolean ``within`` mask of shape (N,), the z-score, membership
        counts, the empirical coverage and the number of time samples.

    Raises
    ------
    ValueError
        If ``x_true``/``x_hat`` are not 2D with equal shapes, or if
        ``posterior_var`` is not 1D of length N.
    """
    truth = np.asarray(x_true, dtype=float)
    estimate = np.asarray(x_hat, dtype=float)
    variance = np.asarray(posterior_var, dtype=float)

    if truth.ndim != 2 or estimate.ndim != 2:
        raise ValueError("Aggregated fixed-orientation interval expects x_true and x_hat with shape (N,T).")
    if truth.shape != estimate.shape:
        raise ValueError("x_true and x_hat must have the same shape.")
    if variance.ndim != 1 or variance.shape[0] != estimate.shape[0]:
        raise ValueError("posterior_var must have shape (N,).")

    n_times = estimate.shape[1]
    truth_mean = np.mean(truth, axis=1)
    estimate_mean = np.mean(estimate, axis=1)
    variance_of_mean = variance / float(n_times)

    lower, upper = self.credible_intervals_normal(estimate_mean, variance_of_mean, nominal_coverage)
    inside = (truth_mean >= lower) & (truth_mean <= upper)

    return {
        "x_true_agg": truth_mean,
        "x_hat_agg": estimate_mean,
        "posterior_var_agg": variance_of_mean,
        "ci_lower": lower,
        "ci_upper": upper,
        "within": inside,
        "z_score": float(self._get_z(nominal_coverage)),
        "count_within": int(np.sum(inside)),
        "total_count": int(inside.size),
        "empirical_coverage": float(np.mean(inside)),
        "n_times": int(n_times),
    }
def calibration_curve_intervals_pointwise(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_var: np.ndarray,
) -> Dict[str, Any]:
    """
    Evaluate pointwise marginal-interval coverage at every nominal level.

    Calls ``pointwise_interval_membership`` once per entry of
    ``self.nominal_coverages`` and collects the empirical coverage and the
    in-interval count of each call.

    Returns
    -------
    dict
        ``nominal_coverages``, ``empirical_coverages`` (float array),
        ``ci_counts`` (int array) and ``interval_type`` ("marginal").
    """
    per_level = [
        self.pointwise_interval_membership(
            x_true=x_true,
            x_hat=x_hat,
            posterior_var=posterior_var,
            nominal_coverage=float(level),
        )
        for level in self.nominal_coverages
    ]
    coverage_values = [res["empirical_coverage"] for res in per_level]
    hit_counts = [res["count_within"] for res in per_level]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": np.asarray(coverage_values, dtype=float),
        "ci_counts": np.asarray(hit_counts, dtype=int),
        "interval_type": "marginal",
    }
def calibration_curve_intervals_aggregated(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_var: np.ndarray,
) -> Dict[str, Any]:
    """
    Evaluate time-aggregated marginal-interval coverage at every nominal level.

    Calls ``aggregated_interval_membership`` once per entry of
    ``self.nominal_coverages`` and collects the empirical coverage and the
    in-interval count of each call.

    Returns
    -------
    dict
        ``nominal_coverages``, ``empirical_coverages`` (float array),
        ``ci_counts`` (int array) and ``interval_type`` ("marginal").
    """
    per_level = [
        self.aggregated_interval_membership(
            x_true=x_true,
            x_hat=x_hat,
            posterior_var=posterior_var,
            nominal_coverage=float(level),
        )
        for level in self.nominal_coverages
    ]
    coverage_values = [res["empirical_coverage"] for res in per_level]
    hit_counts = [res["count_within"] for res in per_level]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": np.asarray(coverage_values, dtype=float),
        "ci_counts": np.asarray(hit_counts, dtype=int),
        "interval_type": "marginal",
    }
# ------------------------------------------------------------------ # EEG free orientation # ------------------------------------------------------------------
def pointwise_ellipsoid_membership_eeg_free(
    self,
    x_true: np.ndarray,  # (N,3,T)
    x_hat: np.ndarray,  # (N,3,T)
    posterior_cov: np.ndarray,  # (3N,3N) or (N,3,3)
    nominal_coverage: float,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Pointwise 3D credible-ellipsoid membership for free-orientation EEG.

    For every source i and time t the quadratic form
    ``q = d^T Sigma_i^{-1} d`` with ``d = x_true - x_hat`` is compared with
    the chi-square threshold (3 degrees of freedom) of the requested
    coverage. ``Sigma_i`` is the symmetrized 3x3 block of source i, taken
    either from the block form or from the diagonal of the full matrix.
    Its inverse is formed from an eigen-decomposition whose eigenvalues
    are floored at ``block_epsilon``.

    Parameters
    ----------
    psd_repair_blocks : bool
        When True each block is first passed through
        ``self._psd_clip_block`` with ``block_epsilon``.
    block_epsilon : float
        Eigenvalue floor used for the inversion (and the optional repair).

    Returns
    -------
    dict
        ``q_values`` (N,T), ``threshold``, ``within`` (N,T), the per-source
        ``cov_blocks`` (N,3,3), membership counts, empirical coverage and
        the number of time samples.

    Raises
    ------
    ValueError
        On inconsistent shapes of ``x_true``, ``x_hat`` or ``posterior_cov``.
    """
    x_true = np.asarray(x_true, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    posterior_cov = np.asarray(posterior_cov, dtype=float)
    thresh = self._get_chi2_threshold(nominal_coverage, df=3)

    if x_true.ndim != 3 or x_true.shape[1] != 3:
        raise ValueError(f"x_true must be (N,3,T); got {x_true.shape}")
    if x_hat.shape != x_true.shape:
        raise ValueError("x_hat must match x_true shape (N,3,T).")

    n_sources, _, n_times = x_true.shape
    cov_is_blocks = posterior_cov.ndim == 3
    if cov_is_blocks and posterior_cov.shape != (n_sources, 3, 3):
        raise ValueError(
            f"posterior_cov block form must be (N,3,3); got {posterior_cov.shape}"
        )
    if not cov_is_blocks and posterior_cov.shape != (3 * n_sources, 3 * n_sources):
        raise ValueError(f"posterior_cov must be (3N,3N); got {posterior_cov.shape}")

    q_values = np.zeros((n_sources, n_times), dtype=float)
    cov_blocks = np.zeros((n_sources, 3, 3), dtype=float)

    for src in range(n_sources):
        if cov_is_blocks:
            block = posterior_cov[src]
        else:
            rows = slice(3 * src, 3 * src + 3)
            block = posterior_cov[rows, rows]
        block = (block + block.T) / 2.0
        if psd_repair_blocks:
            block = self._psd_clip_block(block, block_epsilon)
        cov_blocks[src] = block

        # Invert through the eigen-decomposition with floored eigenvalues so
        # that (near-)singular blocks stay invertible.
        eigvals, eigvecs = np.linalg.eigh(block)
        eigvals = np.maximum(eigvals, block_epsilon)
        precision = eigvecs @ np.diag(1.0 / eigvals) @ eigvecs.T

        residual = x_true[src] - x_hat[src]  # (3,T)
        q_values[src] = np.einsum("it,ij,jt->t", residual, precision, residual)

    within = q_values <= thresh

    return {
        "q_values": q_values,
        "threshold": float(thresh),
        "within": within,
        "cov_blocks": cov_blocks,
        "count_within": int(np.sum(within)),
        "total_count": int(within.size),
        "empirical_coverage": float(np.mean(within)),
        "n_times": int(n_times),
    }
def aggregated_ellipsoid_membership_eeg_free(
    self,
    x_true: np.ndarray,  # (N,3,T)
    x_hat: np.ndarray,  # (N,3,T)
    posterior_cov: np.ndarray,  # (3N,3N) or (N,3,3)
    nominal_coverage: float,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Time-aggregated 3D credible-ellipsoid membership for free-orientation EEG.

    Truth and estimate are averaged over time, each symmetrized 3x3 source
    block is divided by the number of time samples T, and the quadratic
    form ``q = d^T (Sigma_i / T)^{-1} d`` of the averaged residual is
    compared with the chi-square threshold (3 degrees of freedom). The
    inverse is formed from an eigen-decomposition whose eigenvalues are
    floored at ``block_epsilon``.

    Parameters
    ----------
    psd_repair_blocks : bool
        When True each block is passed through ``self._psd_clip_block``
        before it is scaled by 1/T.
    block_epsilon : float
        Eigenvalue floor used for the inversion (and the optional repair).

    Returns
    -------
    dict
        Aggregated truth/estimate (N,3), ``q_values`` (N,), ``threshold``,
        ``within`` (N,), the scaled ``cov_blocks`` (N,3,3), membership
        counts, empirical coverage and the number of time samples.

    Raises
    ------
    ValueError
        On inconsistent shapes of ``x_true``, ``x_hat`` or ``posterior_cov``.
    """
    x_true = np.asarray(x_true, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    posterior_cov = np.asarray(posterior_cov, dtype=float)

    if x_true.ndim != 3 or x_true.shape[1] != 3:
        raise ValueError(f"x_true must be (N,3,T); got {x_true.shape}")
    if x_hat.shape != x_true.shape:
        raise ValueError("x_hat must match x_true shape (N,3,T).")

    n_sources, _, n_times = x_true.shape
    cov_is_blocks = posterior_cov.ndim == 3
    if cov_is_blocks and posterior_cov.shape != (n_sources, 3, 3):
        raise ValueError(
            f"posterior_cov block form must be (N,3,3); got {posterior_cov.shape}"
        )
    if not cov_is_blocks and posterior_cov.shape != (3 * n_sources, 3 * n_sources):
        raise ValueError(f"posterior_cov must be (3N,3N); got {posterior_cov.shape}")

    x_true_agg = np.mean(x_true, axis=2)
    x_hat_agg = np.mean(x_hat, axis=2)
    thresh = self._get_chi2_threshold(nominal_coverage, df=3)

    q_values = np.zeros(n_sources, dtype=float)
    cov_blocks = np.zeros((n_sources, 3, 3), dtype=float)

    for src in range(n_sources):
        if cov_is_blocks:
            block = posterior_cov[src]
        else:
            rows = slice(3 * src, 3 * src + 3)
            block = posterior_cov[rows, rows]
        block = (block + block.T) / 2.0
        if psd_repair_blocks:
            block = self._psd_clip_block(block, block_epsilon)

        # Covariance of the time average under the static-covariance convention.
        block_of_mean = block / float(n_times)
        cov_blocks[src] = block_of_mean

        eigvals, eigvecs = np.linalg.eigh(block_of_mean)
        eigvals = np.maximum(eigvals, block_epsilon)
        precision = eigvecs @ np.diag(1.0 / eigvals) @ eigvecs.T

        residual = x_true_agg[src] - x_hat_agg[src]
        q_values[src] = float(residual.T @ precision @ residual)

    within = q_values <= thresh

    return {
        "x_true_agg": x_true_agg,
        "x_hat_agg": x_hat_agg,
        "q_values": q_values,
        "threshold": float(thresh),
        "within": within,
        "cov_blocks": cov_blocks,
        "count_within": int(np.sum(within)),
        "total_count": int(within.size),
        "empirical_coverage": float(np.mean(within)),
        "n_times": int(n_times),
    }
def calibration_curve_ellipsoid_eeg_free_pointwise(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_cov: np.ndarray,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Evaluate pointwise EEG ellipsoid coverage at every nominal level.

    Calls ``pointwise_ellipsoid_membership_eeg_free`` once per entry of
    ``self.nominal_coverages``, forwarding ``psd_repair_blocks`` and
    ``block_epsilon`` unchanged.

    Returns
    -------
    dict
        ``nominal_coverages``, ``empirical_coverages`` (float array),
        ``ci_counts`` (int array) and ``interval_type`` ("full_cov").
    """
    per_level = [
        self.pointwise_ellipsoid_membership_eeg_free(
            x_true=x_true,
            x_hat=x_hat,
            posterior_cov=posterior_cov,
            nominal_coverage=float(level),
            psd_repair_blocks=psd_repair_blocks,
            block_epsilon=block_epsilon,
        )
        for level in self.nominal_coverages
    ]
    coverage_values = [res["empirical_coverage"] for res in per_level]
    hit_counts = [res["count_within"] for res in per_level]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": np.asarray(coverage_values, dtype=float),
        "ci_counts": np.asarray(hit_counts, dtype=int),
        "interval_type": "full_cov",
    }
def calibration_curve_ellipsoid_eeg_free_aggregated(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_cov: np.ndarray,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Evaluate time-aggregated EEG ellipsoid coverage at every nominal level.

    Calls ``aggregated_ellipsoid_membership_eeg_free`` once per entry of
    ``self.nominal_coverages``, forwarding ``psd_repair_blocks`` and
    ``block_epsilon`` unchanged.

    Returns
    -------
    dict
        ``nominal_coverages``, ``empirical_coverages`` (float array),
        ``ci_counts`` (int array) and ``interval_type`` ("full_cov").
    """
    per_level = [
        self.aggregated_ellipsoid_membership_eeg_free(
            x_true=x_true,
            x_hat=x_hat,
            posterior_cov=posterior_cov,
            nominal_coverage=float(level),
            psd_repair_blocks=psd_repair_blocks,
            block_epsilon=block_epsilon,
        )
        for level in self.nominal_coverages
    ]
    coverage_values = [res["empirical_coverage"] for res in per_level]
    hit_counts = [res["count_within"] for res in per_level]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": np.asarray(coverage_values, dtype=float),
        "ci_counts": np.asarray(hit_counts, dtype=int),
        "interval_type": "full_cov",
    }
# ------------------------------------------------------------------
# Direction-aggregated marginal intervals (free orientation)
# ------------------------------------------------------------------
@staticmethod
def _as_source_major_free_mean(
    posterior_mean: np.ndarray,
    *,
    n_sources: int,
    n_orient: int,
    name: str = "posterior_mean",
) -> np.ndarray:
    """Convert a free-orientation posterior mean to source-major shape (N,K,T).

    Accepted input shapes
    ---------------------
    - (N,K,T): already source-major; returned as-is after float conversion.
    - (K*N,T): flattened source-major representation; reshaped to (N,K,T).

    Notes
    -----
    This assumes source-major flattening:
        (N,K,T) -> (N*K,T),
    consistent with lead-field reshaping conventions such as:
        L_block.reshape(M, N*K).

    Raises
    ------
    ValueError
        If the input is neither 2D nor 3D, or its leading dimension(s) do
        not match ``n_sources`` and ``n_orient``. ``name`` is used in the
        error message.
    """
    mean = np.asarray(posterior_mean, dtype=float)

    if mean.ndim not in (2, 3):
        raise ValueError(f"{name} must be either (N,K,T) or (K*N,T). Got {mean.shape}.")

    if mean.ndim == 3:
        if mean.shape[0] != n_sources or mean.shape[1] != n_orient:
            raise ValueError(
                f"{name} must have shape (N,K,T)=({n_sources},{n_orient},T). Got {mean.shape}."
            )
        return mean

    expected_first_dim = n_sources * n_orient
    if mean.shape[0] != expected_first_dim:
        raise ValueError(
            f"{name} has incompatible flattened shape. Expected first dimension {expected_first_dim}, got {mean.shape[0]}."
        )
    n_times = mean.shape[1]
    return mean.reshape(n_sources, n_orient, n_times)
@staticmethod
def componentwise_variance_from_uncert(
    posterior_uncert: np.ndarray,
    *,
    n_sources: int,
    n_orient: int,
    min_var: float = 1e-12,
) -> np.ndarray:
    """Extract per-source per-component marginal variances as an (N,K) array.

    Supports:
    - full covariance: (K*N, K*N), uses diag(...) and reshapes to (N,K)
    - block covariance: (N, K, K), uses the diagonal of each block

    The variances are floored at ``min_var``; a ``min_var`` <= 0 is treated
    as a floor of exactly 0.

    This utility underpins the `interval_type="marginal"` diagnostic for
    free orientation, where directions/components are not evaluated
    separately but pooled (see `calibration_curve_componentwise_intervals_*`).

    Raises
    ------
    ValueError
        If ``posterior_uncert`` is neither 2D nor 3D, or its shape does not
        match ``n_sources`` and ``n_orient``.
    """
    uncert = np.asarray(posterior_uncert, dtype=float)

    if uncert.ndim not in (2, 3):
        raise ValueError("posterior_uncert must be either (K*N,K*N) or (N,K,K).")

    if uncert.ndim == 2:
        expected = (n_sources * n_orient, n_sources * n_orient)
        if uncert.shape != expected:
            raise ValueError(f"posterior_uncert must have shape {expected}. Got {uncert.shape}.")
        variances = np.diag(uncert).astype(float, copy=True).reshape(n_sources, n_orient)
    else:
        expected = (n_sources, n_orient, n_orient)
        if uncert.shape != expected:
            raise ValueError(
                f"posterior_uncert block form must have shape {expected}. Got {uncert.shape}."
            )
        variances = np.diagonal(uncert, axis1=1, axis2=2).astype(float, copy=True)

    floor = 0.0 if min_var <= 0.0 else float(min_var)
    return np.maximum(variances, floor)
def pointwise_componentwise_interval_membership_free(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    nominal_coverage: float,
    *,
    n_orient: int,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Pointwise marginal (component-wise) CI membership for free orientation.

    This implements the "direction-aggregated marginal" diagnostic from
    `_temp/Component_CI_R1.py`:

    - Build scalar normal credible intervals per retained component using
      only marginal variances (diagonal of the covariance).
    - Pool interval membership over sources, components, and time.

    Important
    ---------
    Directions/components are *not* evaluated separately. They are pooled
    because local coordinate systems are arbitrary (EEG xyz axes depend on
    coordinate choice; MEG reduced components are voxel-dependent).

    Parameters
    ----------
    x_true, x_hat : (N, n_orient, T)
    posterior_uncert : array
        Passed to ``componentwise_variance_from_uncert`` together with
        ``n_orient`` and ``min_var`` to obtain (N, n_orient) variances.
    nominal_coverage : float

    Returns
    -------
    dict
        Interval bounds, ``within`` mask, variances (per component and
        replicated over time), z-score, counts, empirical coverage and
        size metadata. ``empirical_coverage`` is NaN when there is nothing
        to count (N = 0 or T = 0).

    Raises
    ------
    ValueError
        If ``x_true`` is not (N, n_orient, T) or ``x_hat`` has another shape.
    """
    x_true = np.asarray(x_true, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)

    if x_true.ndim != 3 or x_true.shape[1] != n_orient:
        raise ValueError(f"x_true must have shape (N,{n_orient},T). Got {x_true.shape}.")
    if x_hat.shape != x_true.shape:
        raise ValueError(f"x_hat must match x_true shape. Got {x_hat.shape} vs {x_true.shape}.")

    N, K, T = x_true.shape
    posterior_var = self.componentwise_variance_from_uncert(
        posterior_uncert,
        n_sources=N,
        n_orient=n_orient,
        min_var=min_var,
    )
    # Static covariance convention: same marginal variance at every sample.
    posterior_var_full = np.repeat(posterior_var[:, :, None], T, axis=2)

    ci_lower, ci_upper = self.credible_intervals_normal(
        mean=x_hat,
        variance=posterior_var_full,
        nominal_coverage=float(nominal_coverage),
    )
    within = (x_true >= ci_lower) & (x_true <= ci_upper)
    count_within = int(np.sum(within))
    total_count = int(within.size)
    # Empty input has no defined coverage; report NaN instead of dividing by 0.
    empirical_coverage = (
        float(count_within / total_count) if total_count else float("nan")
    )

    return {
        "ci_lower": ci_lower,
        "ci_upper": ci_upper,
        "within": within,
        "posterior_var": posterior_var,
        "posterior_var_full": posterior_var_full,
        "z_score": float(self._get_z(float(nominal_coverage))),
        "count_within": count_within,
        "total_count": total_count,
        "empirical_coverage": empirical_coverage,
        "aggregation_axes": "sources, directions, time",
        "n_sources": int(N),
        "n_orient": int(K),
        "n_times": int(T),
    }
def aggregated_componentwise_interval_membership_free(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    nominal_coverage: float,
    *,
    n_orient: int,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Temporally aggregated marginal (component-wise) CI membership.

    This corresponds to building the diagnostic on time-averaged signals:

    - Aggregate `x_true`/`x_hat` over time (mean over T).
    - Scale marginal variances by 1/T (variance of the mean).
    - Build scalar CIs per component and pool membership over sources and
      components.

    Parameters
    ----------
    x_true, x_hat : (N, n_orient, T)
    posterior_uncert : array
        Passed to ``componentwise_variance_from_uncert`` together with
        ``n_orient`` and ``min_var`` to obtain (N, n_orient) variances.
    nominal_coverage : float

    Returns
    -------
    dict
        Aggregated truth/estimate, interval bounds, ``within`` mask,
        variances (raw and scaled by 1/T), z-score, counts, empirical
        coverage and size metadata. ``empirical_coverage`` is NaN when
        there are no sources to count.

    Raises
    ------
    ValueError
        If ``x_true`` is not (N, n_orient, T) or ``x_hat`` has another shape.
    """
    x_true = np.asarray(x_true, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)

    if x_true.ndim != 3 or x_true.shape[1] != n_orient:
        raise ValueError(f"x_true must have shape (N,{n_orient},T). Got {x_true.shape}.")
    if x_hat.shape != x_true.shape:
        raise ValueError(f"x_hat must match x_true shape. Got {x_hat.shape} vs {x_true.shape}.")

    N, K, T = x_true.shape
    x_true_agg = np.mean(x_true, axis=2)
    x_hat_agg = np.mean(x_hat, axis=2)

    posterior_var = self.componentwise_variance_from_uncert(
        posterior_uncert,
        n_sources=N,
        n_orient=n_orient,
        min_var=min_var,
    )
    posterior_var_agg = posterior_var / float(T)

    ci_lower, ci_upper = self.credible_intervals_normal(
        mean=x_hat_agg,
        variance=posterior_var_agg,
        nominal_coverage=float(nominal_coverage),
    )
    within = (x_true_agg >= ci_lower) & (x_true_agg <= ci_upper)
    count_within = int(np.sum(within))
    total_count = int(within.size)
    # Empty input has no defined coverage; report NaN instead of dividing by 0.
    empirical_coverage = (
        float(count_within / total_count) if total_count else float("nan")
    )

    return {
        "x_true_agg": x_true_agg,
        "x_hat_agg": x_hat_agg,
        "ci_lower": ci_lower,
        "ci_upper": ci_upper,
        "within": within,
        "posterior_var": posterior_var,
        "posterior_var_agg": posterior_var_agg,
        "z_score": float(self._get_z(float(nominal_coverage))),
        "count_within": count_within,
        "total_count": total_count,
        "empirical_coverage": empirical_coverage,
        "aggregation_axes": "sources, directions",
        "n_sources": int(N),
        "n_orient": int(K),
        "n_times": int(T),
    }
def calibration_curve_componentwise_intervals_pointwise_free(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    *,
    n_orient: int,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Calibration curve for marginal (component-wise) intervals (pointwise).

    Summary
    -------
    Builds scalar normal credible intervals per component (using only
    marginal variances) and pools membership over sources, components, and
    time. This is the `interval_type="marginal"` diagnostic for free
    orientation. One call to
    ``pointwise_componentwise_interval_membership_free`` is made per entry
    of ``self.nominal_coverages``.

    Shapes
    ------
    - EEG free: `n_orient=3`, `x_true/x_hat` are (N,3,T)
    - Reduced MEG free: `n_orient=2`, `x_true/x_hat` are (N,2,T)
    - `posterior_uncert`: either full (K*N,K*N) or per-source blocks (N,K,K)

    Important
    ---------
    Components are not evaluated separately (pooled), because local
    coordinate labels are arbitrary.
    """
    nominal_coverages = np.asarray(self.nominal_coverages, dtype=float)

    per_level = [
        self.pointwise_componentwise_interval_membership_free(
            x_true=x_true,
            x_hat=x_hat,
            posterior_uncert=posterior_uncert,
            nominal_coverage=float(level),
            n_orient=n_orient,
            min_var=min_var,
        )
        for level in nominal_coverages
    ]

    return {
        "nominal_coverages": nominal_coverages,
        "empirical_coverages": np.asarray(
            [res["empirical_coverage"] for res in per_level], dtype=float
        ),
        "ci_counts": np.asarray([res["count_within"] for res in per_level], dtype=int),
        "total_counts": np.asarray([res["total_count"] for res in per_level], dtype=int),
        "n_orient": int(n_orient),
        "mode": "pointwise",
        "interval_type": "marginal",
        "direction_aggregation": "pooled_over_all_components",
    }
def calibration_curve_componentwise_intervals_aggregated_free(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    *,
    n_orient: int,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Calibration curve for marginal (component-wise) intervals (aggregated).

    Summary
    -------
    Temporally aggregated version of the `interval_type="marginal"` free
    diagnostic:
    1) average `x_true`/`x_hat` over time,
    2) scale marginal variances by 1/T,
    3) build scalar CIs per component,
    4) pool membership over sources and components.

    One call to ``aggregated_componentwise_interval_membership_free`` is
    made per entry of ``self.nominal_coverages``.

    This is the mode used by `workflows/calibration.py`.

    Important
    ---------
    Components are pooled (no direction-wise curves), because local
    coordinate systems are arbitrary.
    """
    nominal_coverages = np.asarray(self.nominal_coverages, dtype=float)

    per_level = [
        self.aggregated_componentwise_interval_membership_free(
            x_true=x_true,
            x_hat=x_hat,
            posterior_uncert=posterior_uncert,
            nominal_coverage=float(level),
            n_orient=n_orient,
            min_var=min_var,
        )
        for level in nominal_coverages
    ]

    return {
        "nominal_coverages": nominal_coverages,
        "empirical_coverages": np.asarray(
            [res["empirical_coverage"] for res in per_level], dtype=float
        ),
        "ci_counts": np.asarray([res["count_within"] for res in per_level], dtype=int),
        "total_counts": np.asarray([res["total_count"] for res in per_level], dtype=int),
        "n_orient": int(n_orient),
        "mode": "aggregated",
        "interval_type": "marginal",
        "direction_aggregation": "pooled_over_all_components",
    }
def calibration_curve_componentwise_eeg_free_pointwise(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    *,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Pointwise marginal calibration curve for free-orientation EEG (K=3).

    Convenience wrapper that forwards to
    ``calibration_curve_componentwise_intervals_pointwise_free`` with
    ``n_orient=3``.

    Expected shapes:
    - `x_true`, `x_hat`: (N,3,T)
    - `posterior_uncert`: (3N,3N) or per-source blocks (N,3,3)
    """
    eeg_components = 3
    curve = self.calibration_curve_componentwise_intervals_pointwise_free(
        x_true=x_true,
        x_hat=x_hat,
        posterior_uncert=posterior_uncert,
        n_orient=eeg_components,
        min_var=min_var,
    )
    return curve
def calibration_curve_componentwise_eeg_free_aggregated(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    *,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Aggregated marginal calibration curve for free-orientation EEG (K=3).

    Convenience wrapper that forwards to
    ``calibration_curve_componentwise_intervals_aggregated_free`` with
    ``n_orient=3``.

    Expected shapes:
    - `x_true`, `x_hat`: (N,3,T)
    - `posterior_uncert`: (3N,3N) or per-source blocks (N,3,3)
    """
    eeg_components = 3
    curve = self.calibration_curve_componentwise_intervals_aggregated_free(
        x_true=x_true,
        x_hat=x_hat,
        posterior_uncert=posterior_uncert,
        n_orient=eeg_components,
        min_var=min_var,
    )
    return curve
def calibration_curve_componentwise_meg_free_pointwise(
    self,
    x_true_2d: np.ndarray,
    x_hat_2d: np.ndarray,
    posterior_uncert_2d: np.ndarray,
    *,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Pointwise marginal calibration curve for reduced free-orientation MEG (K=2).

    Convenience wrapper that forwards to
    ``calibration_curve_componentwise_intervals_pointwise_free`` with
    ``n_orient=2``.

    Expected shapes:
    - `x_true_2d`, `x_hat_2d`: (N,2,T)
    - `posterior_uncert_2d`: (2N,2N) or per-source blocks (N,2,2)

    Notes
    -----
    This diagnostic operates directly in reduced 2D coefficient coordinates:
    no 2D->3D lifting and no 3D->2D projection is performed here.
    """
    meg_components = 2
    curve = self.calibration_curve_componentwise_intervals_pointwise_free(
        x_true=x_true_2d,
        x_hat=x_hat_2d,
        posterior_uncert=posterior_uncert_2d,
        n_orient=meg_components,
        min_var=min_var,
    )
    return curve
def calibration_curve_componentwise_meg_free_aggregated(
    self,
    x_true_2d: np.ndarray,
    x_hat_2d: np.ndarray,
    posterior_uncert_2d: np.ndarray,
    *,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Aggregated marginal calibration curve for reduced free-orientation MEG (K=2).

    Convenience wrapper that forwards to
    ``calibration_curve_componentwise_intervals_aggregated_free`` with
    ``n_orient=2``.

    Expected shapes:
    - `x_true_2d`, `x_hat_2d`: (N,2,T)
    - `posterior_uncert_2d`: (2N,2N) or per-source blocks (N,2,2)

    Notes
    -----
    This diagnostic operates directly in reduced 2D coefficient coordinates:
    no 2D->3D lifting and no 3D->2D projection is performed here.
    """
    meg_components = 2
    curve = self.calibration_curve_componentwise_intervals_aggregated_free(
        x_true=x_true_2d,
        x_hat=x_hat_2d,
        posterior_uncert=posterior_uncert_2d,
        n_orient=meg_components,
        min_var=min_var,
    )
    return curve
# ------------------------------------------------------------------ # MEG free orientation (reduced 2D model) # ------------------------------------------------------------------
def reduce_meg_leadfield_svd(
    self,
    L_free_Mx3N: np.ndarray,
) -> Dict[str, np.ndarray]:
    """
    Fallback-only helper: build a source-wise SVD-reduced MEG leadfield.

    For each source i:
        L_i = U_i S_i V_i^T,  with L_i in R^{M x 3}
    keep the first two right singular vectors V_i(:, :2), and define
        L_i_reduced = L_i @ V_i(:, :2)  in R^{M x 2}

    The sign of each retained basis column is normalized with
    ``self._canonicalize_basis_columns``.

    Recommended usage
    -----------------
    In the current pipeline, prefer the extractor-provided basis
        Q_basis = lf_free_meg["Q_basis"]
    rather than recomputing a new basis here.

    Returns
    -------
    dict with:
        - L_meg_Mx2N: reduced global leadfield, shape (M, 2N)
        - V_tan: source-wise 3D->2D basis, shape (N, 3, 2)
        - singular_values: shape (N, 2)

    Raises
    ------
    ValueError
        If the leadfield is not 2D, its column count is not a multiple of
        3, or it has fewer than 2 rows.
    """
    L_free_Mx3N = np.asarray(L_free_Mx3N, dtype=float)
    if L_free_Mx3N.ndim != 2:
        raise ValueError(f"L_free_Mx3N must be (M,3N). Got {L_free_Mx3N.shape}")

    M, threeN = L_free_Mx3N.shape
    if threeN % 3 != 0:
        raise ValueError(f"L_free_Mx3N must be (M,3N). Got {L_free_Mx3N.shape}")
    if M < 2:
        # With a single row the economy SVD yields only one right singular
        # vector, which would be silently broadcast into both reduced columns.
        raise ValueError(
            f"L_free_Mx3N needs at least 2 rows for a rank-2 reduction. Got {L_free_Mx3N.shape}"
        )

    N = threeN // 3
    L_meg_Mx2N = np.zeros((M, 2 * N), dtype=float)
    V_tan = np.zeros((N, 3, 2), dtype=float)
    singular_values = np.zeros((N, 2), dtype=float)

    for i in range(N):
        Li = L_free_Mx3N[:, 3 * i:3 * i + 3]  # (M,3)
        _, s, Vt = np.linalg.svd(Li, full_matrices=False)
        V2 = self._canonicalize_basis_columns(Vt.T[:, :2])  # (3,2)

        L_meg_Mx2N[:, 2 * i:2 * i + 2] = Li @ V2  # (M,2)
        V_tan[i] = V2
        singular_values[i] = s[:2]

    return {
        "L_meg_Mx2N": L_meg_Mx2N,
        "V_tan": V_tan,
        "singular_values": singular_values,
    }
def precompute_meg_tangent_bases_svd(self, L_free_Mx3N: np.ndarray) -> np.ndarray:
    """
    Backward-compatible wrapper around ``reduce_meg_leadfield_svd``.

    Returns only the ``"V_tan"`` entry of that result, i.e. the source-wise
    local 3x2 reduction bases derived from the original free leadfield.
    Prefer passing the extractor basis directly whenever possible.
    """
    reduction = self.reduce_meg_leadfield_svd(L_free_Mx3N)
    return reduction["V_tan"]
@staticmethod
def project_meg_3d_to_2d(
    x_3d: np.ndarray,
    V_tan: np.ndarray,
) -> np.ndarray:
    """
    Project 3D MEG truth/estimate onto the local reduced 2D coordinates.

    For every source n this computes ``V_tan[n].T @ x_3d[n]``.

    Parameters
    ----------
    x_3d : (N,3,T) or (N,3)
    V_tan : (N,3,2)

    Returns
    -------
    x_2d : (N,2,T) or (N,2)

    Raises
    ------
    ValueError
        If ``V_tan`` is not (N,3,2) or ``x_3d`` has an incompatible shape.
    """
    values = np.asarray(x_3d, dtype=float)
    bases = np.asarray(V_tan, dtype=float)

    if bases.ndim != 3 or bases.shape[1:] != (3, 2):
        raise ValueError(f"V_tan must be (N,3,2); got {bases.shape}")
    n_sources = bases.shape[0]

    if values.ndim not in (2, 3):
        raise ValueError("x_3d must have shape (N,3,T) or (N,3).")

    if values.ndim == 3:
        if values.shape[:2] != (n_sources, 3):
            raise ValueError(f"x_3d must be (N,3,T); got {values.shape}")
        # Contract the 3D component axis c against each source's basis.
        return np.einsum("nck,nct->nkt", bases, values)

    if values.shape != (n_sources, 3):
        raise ValueError(f"x_3d must be (N,3); got {values.shape}")
    return np.einsum("nck,nc->nk", bases, values)
@staticmethod
def lift_meg_2d_to_3d(
    x_2d: np.ndarray,
    V_tan: np.ndarray,
) -> np.ndarray:
    """
    Lift reduced 2D MEG coordinates back to the retained local 3D basis.

    For every source n this computes ``V_tan[n] @ x_2d[n]``.

    Parameters
    ----------
    x_2d : (N,2,T) or (N,2)
    V_tan : (N,3,2)

    Returns
    -------
    x_3d_lifted : (N,3,T) or (N,3)

    Raises
    ------
    ValueError
        If ``V_tan`` is not (N,3,2) or ``x_2d`` has an incompatible shape.
    """
    coeffs = np.asarray(x_2d, dtype=float)
    bases = np.asarray(V_tan, dtype=float)

    if bases.ndim != 3 or bases.shape[1:] != (3, 2):
        raise ValueError(f"V_tan must be (N,3,2); got {bases.shape}")
    n_sources = bases.shape[0]

    if coeffs.ndim not in (2, 3):
        raise ValueError("x_2d must have shape (N,2,T) or (N,2).")

    if coeffs.ndim == 3:
        if coeffs.shape[:2] != (n_sources, 2):
            raise ValueError(f"x_2d must be (N,2,T); got {coeffs.shape}")
        # Contract the reduced component axis k against each source's basis.
        return np.einsum("nck,nkt->nct", bases, coeffs)

    if coeffs.shape != (n_sources, 2):
        raise ValueError(f"x_2d must be (N,2); got {coeffs.shape}")
    return np.einsum("nck,nk->nc", bases, coeffs)
def pointwise_ellipse_membership_meg_free(
    self,
    x_true_3d: np.ndarray,  # (N,3,T)
    x_hat_2d: np.ndarray,  # (N,2,T)
    posterior_cov_2d: np.ndarray,  # (2N,2N) or (N,2,2)
    nominal_coverage: float,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Direct 2D MEG credible ellipse membership in reduced coordinates.

    Truth is provided in 3D and projected to reduced coordinates:

        x_true_2d = V_tan^T x_true_3d

    Posterior mean and covariance are assumed to already come from the
    reduced 2D MEG inverse model. For every source and time sample the
    squared Mahalanobis distance between projected truth and posterior
    mean is compared with the chi-square threshold (2 degrees of freedom)
    obtained from ``self._get_chi2_threshold``.

    Recommended usage
    -----------------
    Pass `V_tan=Q_basis` from the leadfield extractor to avoid recomputing
    a potentially sign-flipped basis.

    Parameters
    ----------
    x_true_3d : (N,3,T)
        Ground truth in the local 3D basis.
    x_hat_2d : (N,2,T)
        Posterior mean in reduced coordinates.
    posterior_cov_2d : (2N,2N) or (N,2,2)
        Full posterior covariance (only its 2x2 diagonal blocks are used)
        or the per-source blocks directly.
    nominal_coverage : float
        Coverage level forwarded to ``self._get_chi2_threshold``.
    V_tan : (N,3,2), optional
        Per-source projection basis. When omitted, it is derived from
        ``L_free_Mx3N`` via ``self.precompute_meg_tangent_bases_svd``.
    L_free_Mx3N : optional
        Free leadfield, only used when ``V_tan`` is not given.
    psd_repair_blocks : bool
        When true, each symmetrized block is passed through
        ``self._psd_clip_block`` before use.
    block_epsilon : float
        Lower bound applied to block eigenvalues before inversion.

    Returns
    -------
    dict
        Per-(source, time) ``q_values`` and ``within`` arrays, the
        threshold, projected truth/mean, the covariance blocks (under both
        ``cov_blocks`` and ``projected_cov_blocks``), the basis used, and
        summary counts.

    Raises
    ------
    ValueError
        On shape mismatches, or when neither ``V_tan`` nor ``L_free_Mx3N``
        is supplied.
    """
    truth_3d = np.asarray(x_true_3d, dtype=float)
    mean_2d = np.asarray(x_hat_2d, dtype=float)
    cov = np.asarray(posterior_cov_2d, dtype=float)

    # The threshold is requested before shape validation, as in the
    # original ordering, so an invalid coverage is reported first.
    chi2_cut = self._get_chi2_threshold(nominal_coverage, df=2)

    if not (truth_3d.ndim == 3 and truth_3d.shape[1] == 3):
        raise ValueError(f"x_true_3d must be (N,3,T); got {truth_3d.shape}")
    if not (mean_2d.ndim == 3 and mean_2d.shape[1] == 2):
        raise ValueError(f"x_hat_2d must be (N,2,T); got {mean_2d.shape}")

    n_src, _, n_times = truth_3d.shape
    if mean_2d.shape[0] != n_src or mean_2d.shape[2] != n_times:
        raise ValueError("x_hat_2d must have shape (N,2,T) matching x_true_3d.")

    block_form = cov.ndim == 3
    if block_form and cov.shape != (n_src, 2, 2):
        raise ValueError(
            f"posterior_cov_2d block form must be (N,2,2); got {cov.shape}"
        )
    if not block_form and cov.shape != (2 * n_src, 2 * n_src):
        raise ValueError(
            f"posterior_cov_2d must be (2N,2N); got {cov.shape}"
        )

    basis = V_tan
    if basis is None:
        if L_free_Mx3N is None:
            raise ValueError("Provide either V_tan or L_free_Mx3N for MEG 3D->2D projection.")
        basis = self.precompute_meg_tangent_bases_svd(L_free_Mx3N)
    basis = np.asarray(basis, dtype=float)
    if basis.shape != (n_src, 3, 2):
        raise ValueError(f"V_tan must be (N,3,2); got {basis.shape}")

    truth_2d = self.project_meg_3d_to_2d(truth_3d, basis)

    mahalanobis = np.zeros((n_src, n_times), dtype=float)
    inside = np.zeros((n_src, n_times), dtype=bool)
    blocks = np.zeros((n_src, 2, 2), dtype=float)

    for src in range(n_src):
        if block_form:
            block = cov[src]
        else:
            start = 2 * src
            block = cov[start:start + 2, start:start + 2]

        # Symmetrize to remove numerical asymmetry before eigendecomposition.
        block = (block + block.T) / 2.0
        if psd_repair_blocks:
            block = self._psd_clip_block(block, block_epsilon)
        blocks[src] = block

        # Invert through the eigendecomposition with floored eigenvalues.
        eigvals, eigvecs = np.linalg.eigh(block)
        eigvals = np.maximum(eigvals, block_epsilon)
        precision = eigvecs @ np.diag(1.0 / eigvals) @ eigvecs.T

        resid = truth_2d[src] - mean_2d[src]  # (2,T)
        stat = np.einsum("it,ij,jt->t", resid, precision, resid)

        mahalanobis[src] = stat
        inside[src] = stat <= chi2_cut

    summary: Dict[str, Any] = {
        "q_values": mahalanobis,
        "threshold": float(chi2_cut),
        "within": inside,
        "projected_true": truth_2d,
        "projected_mean": mean_2d,
        "cov_blocks": blocks,
        "projected_cov_blocks": blocks,
        "V_tan": basis,
        "count_within": int(np.sum(inside)),
        "total_count": int(inside.size),
        "empirical_coverage": float(np.mean(inside)),
        "n_times": int(n_times),
    }
    return summary
def aggregated_ellipse_membership_meg_free(
    self,
    x_true_3d: np.ndarray,  # (N,3,T)
    x_hat_2d: np.ndarray,  # (N,2,T)
    posterior_cov_2d: np.ndarray,  # (2N,2N) or (N,2,2)
    nominal_coverage: float,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Aggregated 2D MEG credible ellipse membership in reduced coordinates.

    Truth is projected to reduced coordinates, truth and posterior mean
    are averaged over time, and each per-source covariance block is
    divided by the number of time samples T. Membership is decided per
    source by comparing the squared Mahalanobis distance with the
    chi-square threshold (2 degrees of freedom) obtained from
    ``self._get_chi2_threshold``.

    Recommended usage
    -----------------
    Pass `V_tan=Q_basis` from the leadfield extractor.

    Parameters
    ----------
    x_true_3d : (N,3,T)
        Ground truth in the local 3D basis.
    x_hat_2d : (N,2,T)
        Posterior mean in reduced coordinates.
    posterior_cov_2d : (2N,2N) or (N,2,2)
        Full posterior covariance (only its 2x2 diagonal blocks are used)
        or the per-source blocks directly.
    nominal_coverage : float
        Coverage level forwarded to ``self._get_chi2_threshold``.
    V_tan : (N,3,2), optional
        Per-source projection basis. When omitted, it is derived from
        ``L_free_Mx3N`` via ``self.precompute_meg_tangent_bases_svd``.
    L_free_Mx3N : optional
        Free leadfield, only used when ``V_tan`` is not given.
    psd_repair_blocks : bool
        When true, each symmetrized block is passed through
        ``self._psd_clip_block`` before scaling.
    block_epsilon : float
        Lower bound applied to the eigenvalues of the scaled block.

    Returns
    -------
    dict
        Per-source ``q_values`` and ``within`` arrays, the threshold,
        time-averaged truth/mean (under two key pairs), the scaled
        covariance blocks, the basis used, and summary counts.

    Raises
    ------
    ValueError
        On shape mismatches, or when neither ``V_tan`` nor ``L_free_Mx3N``
        is supplied.
    """
    truth_3d = np.asarray(x_true_3d, dtype=float)
    mean_2d = np.asarray(x_hat_2d, dtype=float)
    cov = np.asarray(posterior_cov_2d, dtype=float)

    if not (truth_3d.ndim == 3 and truth_3d.shape[1] == 3):
        raise ValueError(f"x_true_3d must be (N,3,T); got {truth_3d.shape}")
    if not (mean_2d.ndim == 3 and mean_2d.shape[1] == 2):
        raise ValueError(f"x_hat_2d must be (N,2,T); got {mean_2d.shape}")

    n_src, _, n_times = truth_3d.shape
    if mean_2d.shape[0] != n_src or mean_2d.shape[2] != n_times:
        raise ValueError("x_hat_2d must have shape (N,2,T) matching x_true_3d.")

    block_form = cov.ndim == 3
    if block_form and cov.shape != (n_src, 2, 2):
        raise ValueError(
            f"posterior_cov_2d block form must be (N,2,2); got {cov.shape}"
        )
    if not block_form and cov.shape != (2 * n_src, 2 * n_src):
        raise ValueError(
            f"posterior_cov_2d must be (2N,2N); got {cov.shape}"
        )

    basis = V_tan
    if basis is None:
        if L_free_Mx3N is None:
            raise ValueError("Provide either V_tan or L_free_Mx3N for MEG 3D->2D projection.")
        basis = self.precompute_meg_tangent_bases_svd(L_free_Mx3N)
    basis = np.asarray(basis, dtype=float)
    if basis.shape != (n_src, 3, 2):
        raise ValueError(f"V_tan must be (N,3,2); got {basis.shape}")

    truth_2d = self.project_meg_3d_to_2d(truth_3d, basis)

    # Collapse the time axis: one 2D point per source.
    truth_avg = np.mean(truth_2d, axis=2)
    mean_avg = np.mean(mean_2d, axis=2)

    chi2_cut = self._get_chi2_threshold(nominal_coverage, df=2)

    mahalanobis = np.zeros(n_src, dtype=float)
    inside = np.zeros(n_src, dtype=bool)
    blocks = np.zeros((n_src, 2, 2), dtype=float)

    for src in range(n_src):
        if block_form:
            block = cov[src]
        else:
            start = 2 * src
            block = cov[start:start + 2, start:start + 2]

        block = (block + block.T) / 2.0
        if psd_repair_blocks:
            block = self._psd_clip_block(block, block_epsilon)

        # Covariance of the time average under the static-covariance convention.
        block_avg = block / float(n_times)
        blocks[src] = block_avg

        eigvals, eigvecs = np.linalg.eigh(block_avg)
        eigvals = np.maximum(eigvals, block_epsilon)
        precision = eigvecs @ np.diag(1.0 / eigvals) @ eigvecs.T

        resid = truth_avg[src] - mean_avg[src]
        stat = float(resid.T @ precision @ resid)

        mahalanobis[src] = stat
        inside[src] = stat <= chi2_cut

    summary: Dict[str, Any] = {
        "x_true_agg_2d": truth_avg,
        "x_hat_agg_2d": mean_avg,
        "q_values": mahalanobis,
        "threshold": float(chi2_cut),
        "within": inside,
        "projected_true": truth_avg,
        "projected_mean": mean_avg,
        "cov_blocks": blocks,
        "projected_cov_blocks": blocks,
        "V_tan": basis,
        "count_within": int(np.sum(inside)),
        "total_count": int(inside.size),
        "empirical_coverage": float(np.mean(inside)),
        "n_times": int(n_times),
    }
    return summary
def calibration_curve_ellipse_meg_free_pointwise(
    self,
    x_true_3d: np.ndarray,
    x_hat_2d: np.ndarray,
    posterior_cov_2d: np.ndarray,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Calibration curve for the pointwise reduced-2D MEG credible ellipse.

    Runs ``self.pointwise_ellipse_membership_meg_free`` once for every
    level in ``self.nominal_coverages`` and collects the empirical
    coverage and the in-ellipse count reported by each run.

    Parameters
    ----------
    x_true_3d, x_hat_2d, posterior_cov_2d, V_tan, L_free_Mx3N,
    psd_repair_blocks, block_epsilon
        Forwarded unchanged to ``pointwise_ellipse_membership_meg_free``.

    Returns
    -------
    dict
        ``nominal_coverages`` (as stored on the estimator),
        ``empirical_coverages`` (float array), ``ci_counts`` (int array)
        and ``interval_type`` (always ``"full_cov"``).
    """
    per_level = [
        self.pointwise_ellipse_membership_meg_free(
            x_true_3d=x_true_3d,
            x_hat_2d=x_hat_2d,
            posterior_cov_2d=posterior_cov_2d,
            nominal_coverage=float(level),
            V_tan=V_tan,
            L_free_Mx3N=L_free_Mx3N,
            psd_repair_blocks=psd_repair_blocks,
            block_epsilon=block_epsilon,
        )
        for level in self.nominal_coverages
    ]

    coverage_values = [result["empirical_coverage"] for result in per_level]
    inside_counts = [result["count_within"] for result in per_level]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": np.asarray(coverage_values, dtype=float),
        "ci_counts": np.asarray(inside_counts, dtype=int),
        "interval_type": "full_cov",
    }
def calibration_curve_ellipse_meg_free_aggregated(
    self,
    x_true_3d: np.ndarray,
    x_hat_2d: np.ndarray,
    posterior_cov_2d: np.ndarray,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Calibration curve for the time-aggregated reduced-2D MEG credible ellipse.

    Runs ``self.aggregated_ellipse_membership_meg_free`` once for every
    level in ``self.nominal_coverages`` and collects the empirical
    coverage and the in-ellipse count reported by each run.

    Parameters
    ----------
    x_true_3d, x_hat_2d, posterior_cov_2d, V_tan, L_free_Mx3N,
    psd_repair_blocks, block_epsilon
        Forwarded unchanged to ``aggregated_ellipse_membership_meg_free``.

    Returns
    -------
    dict
        ``nominal_coverages`` (as stored on the estimator),
        ``empirical_coverages`` (float array), ``ci_counts`` (int array)
        and ``interval_type`` (always ``"full_cov"``).
    """
    per_level = [
        self.aggregated_ellipse_membership_meg_free(
            x_true_3d=x_true_3d,
            x_hat_2d=x_hat_2d,
            posterior_cov_2d=posterior_cov_2d,
            nominal_coverage=float(level),
            V_tan=V_tan,
            L_free_Mx3N=L_free_Mx3N,
            psd_repair_blocks=psd_repair_blocks,
            block_epsilon=block_epsilon,
        )
        for level in self.nominal_coverages
    ]

    coverage_values = [result["empirical_coverage"] for result in per_level]
    inside_counts = [result["count_within"] for result in per_level]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": np.asarray(coverage_values, dtype=float),
        "ci_counts": np.asarray(inside_counts, dtype=int),
        "interval_type": "full_cov",
    }
# ------------------------------------------------------------------
# Visualization: fixed orientation
# ------------------------------------------------------------------
def plot_fixed_interval_membership_pointwise(
    self,
    x_true: np.ndarray,  # (N,T)
    x_hat: np.ndarray,  # (N,T)
    posterior_var: np.ndarray,  # (N,)
    src_idx: int,
    nominal_coverage: float = 0.95,
    *,
    times: Optional[np.ndarray] = None,
    figsize: Tuple[float, float] = (9.0, 4.0),
) -> None:
    """
    Plot the pointwise marginal credible band of one fixed-orientation source.

    The static per-source variance is repeated over time, the band is
    obtained from ``self.credible_intervals_normal`` and drawn together
    with the posterior mean and the ground truth. Time samples at which
    the truth leaves the band are marked with crosses.

    Parameters
    ----------
    x_true, x_hat : (N,T)
        Ground truth and posterior mean.
    posterior_var : (N,)
        Per-source posterior variance (static over time).
    src_idx : int
        Source row to display.
    nominal_coverage : float
        Coverage level forwarded to ``self.credible_intervals_normal`` and
        shown in the legend.
    times : (T,), optional
        Abscissa values; defaults to ``np.arange(T)``.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.

    Raises
    ------
    ValueError
        On inconsistent input shapes.
    """
    x_true = np.asarray(x_true, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    posterior_var = np.asarray(posterior_var, dtype=float)

    if x_true.ndim != 2 or x_hat.ndim != 2:
        raise ValueError("x_true and x_hat must have shape (N,T).")
    if x_true.shape != x_hat.shape:
        raise ValueError("x_true and x_hat must have the same shape.")
    if posterior_var.ndim != 1 or posterior_var.shape[0] != x_hat.shape[0]:
        raise ValueError("posterior_var must have shape (N,).")

    _, T = x_hat.shape
    if times is None:
        times = np.arange(T)
    else:
        times = np.asarray(times)
        if times.shape[0] != T:
            raise ValueError("times must have length T.")

    # Imported only once the inputs are known to be valid, so that shape
    # errors are reported without touching the plotting backend.
    import matplotlib.pyplot as plt

    # Static variance: the same value at every time sample of a source.
    var_full = np.repeat(posterior_var[:, None], T, axis=1)
    lo, hi = self.credible_intervals_normal(x_hat, var_full, nominal_coverage)

    within = (x_true >= lo) & (x_true <= hi)
    inside_all = bool(np.all(within[src_idx]))

    fig, ax = plt.subplots(figsize=figsize)
    ax.fill_between(
        times,
        lo[src_idx],
        hi[src_idx],
        alpha=0.25,
        # ":g" avoids the truncation of int(): 100 * 0.29 evaluates to
        # 28.999999999999996, which int() would render as "28%".
        label=f"{100 * nominal_coverage:g}% credible interval",
    )
    ax.plot(times, x_hat[src_idx], label="posterior mean")
    ax.plot(times, x_true[src_idx], label="x_true")

    outside = ~within[src_idx]
    if np.any(outside):
        ax.scatter(
            times[outside],
            x_true[src_idx, outside],
            marker="x",
            s=40,
            label="outside interval",
            zorder=5,
        )

    ax.set_title(f"Fixed source={src_idx}, all-times-inside={inside_all}")
    ax.set_xlabel("Time")
    ax.set_ylabel("Source value")
    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    plt.tight_layout()
    plt.show()
def plot_fixed_interval_membership_aggregated(
    self,
    diag_fixed: Dict[str, Any],
    src_idx: int,
    *,
    nominal_coverage: float = 0.95,
    figsize: Tuple[float, float] = (7.0, 2.2),
) -> None:
    """
    Plot the aggregated marginal credible interval of one fixed source.

    Draws the interval as a shaded span on a single horizontal axis,
    together with the aggregated posterior mean and aggregated truth.

    Parameters
    ----------
    diag_fixed : dict
        Diagnostics providing, per source, the entries ``"ci_lower"``,
        ``"ci_upper"``, ``"x_hat_agg"``, ``"x_true_agg"`` and ``"within"``.
    src_idx : int
        Source to display.
    nominal_coverage : float
        Used for the legend label only; the interval itself is read from
        ``diag_fixed``.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.
    """
    import matplotlib.pyplot as plt

    lo = float(diag_fixed["ci_lower"][src_idx])
    hi = float(diag_fixed["ci_upper"][src_idx])
    mu = float(diag_fixed["x_hat_agg"][src_idx])
    xt = float(diag_fixed["x_true_agg"][src_idx])
    inside = bool(diag_fixed["within"][src_idx])

    # Pad the view by 15% of the data range, with a floor for degenerate cases.
    vals = np.array([lo, hi, mu, xt], dtype=float)
    pad = 0.15 * max(np.ptp(vals), 1e-6)

    fig, ax = plt.subplots(figsize=figsize)
    # ":g" avoids the truncation of int(): 100 * 0.29 evaluates to
    # 28.999999999999996, which int() would render as "28%".
    ax.axvspan(lo, hi, alpha=0.25, label=f"{100 * nominal_coverage:g}% credible interval")
    ax.scatter(mu, 0.0, s=90, label="aggregated posterior mean", zorder=3)
    ax.scatter(xt, 0.0, s=110, marker="x", label="aggregated x_true", zorder=4)

    ax.set_xlim(np.min(vals) - pad, np.max(vals) + pad)
    ax.set_yticks([])
    ax.set_xlabel("Aggregated source value")
    ax.set_title(f"Fixed source={src_idx}, inside={inside}")

    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.grid(False)

    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    plt.tight_layout()
    plt.show()
# ------------------------------------------------------------------
# Visualization: EEG free orientation
# ------------------------------------------------------------------
def plot_eeg_ellipsoid_membership_pointwise(
    self,
    x_hat: np.ndarray,  # (N,3,T)
    x_true: np.ndarray,  # (N,3,T)
    posterior_cov: np.ndarray,  # (3N,3N)
    src_idx: int,
    time_idx: int,
    nominal_coverage: float = 0.95,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
    figsize: Tuple[float, float] = (8.0, 6.0),
    elev: float = 22.0,
    azim: float = -58.0,
) -> None:
    """
    Plot the 3D credible ellipsoid of one EEG source at one time sample.

    The 3x3 diagonal block of ``posterior_cov`` belonging to ``src_idx``
    is symmetrized (and optionally repaired), its eigendecomposition gives
    the ellipsoid axes, and the chi-square threshold (3 degrees of
    freedom) from ``self._get_chi2_threshold`` sets the radii. Posterior
    mean, ground truth and the connecting segment are drawn on top.

    Parameters
    ----------
    x_hat, x_true : (N,3,T)
        Posterior mean and ground truth. Array-likes are accepted.
    posterior_cov : (3N,3N)
        Full posterior covariance; only one diagonal block is used.
    src_idx, time_idx : int
        Source and time sample to display (must be non-negative and in range).
    nominal_coverage : float
        Coverage level forwarded to ``self._get_chi2_threshold``.
    psd_repair_blocks : bool
        When true, the block is passed through ``self._psd_clip_block``.
    block_epsilon : float
        Lower bound applied to the block eigenvalues.
    figsize : tuple of float
        Figure size.
    elev, azim : float
        Viewing angles passed to ``ax.view_init``.

    Raises
    ------
    ValueError
        On shape mismatches or out-of-range indices.
    """
    # Coerce like the sibling methods do, so lists and other array-likes
    # are validated instead of failing on a missing ``.ndim`` attribute.
    x_hat = np.asarray(x_hat, dtype=float)
    x_true = np.asarray(x_true, dtype=float)
    posterior_cov = np.asarray(posterior_cov, dtype=float)

    if x_true.ndim != 3 or x_true.shape[1] != 3:
        raise ValueError(f"x_true must be (N,3,T); got {x_true.shape}")
    if x_hat.shape != x_true.shape:
        raise ValueError("x_hat must match x_true shape.")

    N, _, T = x_true.shape
    if posterior_cov.shape != (3 * N, 3 * N):
        raise ValueError(f"posterior_cov must be (3N,3N); got {posterior_cov.shape}")
    if not (0 <= src_idx < N):
        raise ValueError("src_idx out of range.")
    if not (0 <= time_idx < T):
        raise ValueError("time_idx out of range.")

    # Imported only once the inputs are known to be valid, so that shape
    # errors are reported without touching the plotting backend.
    import matplotlib.pyplot as plt

    threshold = self._get_chi2_threshold(nominal_coverage, df=3)

    center = x_hat[src_idx, :, time_idx]
    truth = x_true[src_idx, :, time_idx]

    start = 3 * src_idx
    Sigma3 = posterior_cov[start:start + 3, start:start + 3]
    Sigma3 = (Sigma3 + Sigma3.T) / 2.0
    if psd_repair_blocks:
        Sigma3 = self._psd_clip_block(Sigma3, block_epsilon)

    evals, evecs = np.linalg.eigh(Sigma3)
    evals = np.maximum(evals, block_epsilon)
    Sigma3_inv = evecs @ np.diag(1.0 / evals) @ evecs.T

    d = truth - center
    q_val = float(d.T @ Sigma3_inv @ d)
    inside = bool(q_val <= threshold)

    # Semi-axis lengths of the level set {x : x^T Sigma^-1 x = threshold}.
    radii = np.sqrt(threshold * evals)

    # Unit sphere, then stretched and rotated into the ellipsoid.
    u = np.linspace(0.0, 2.0 * np.pi, 60)
    v = np.linspace(0.0, np.pi, 30)
    xs = np.outer(np.cos(u), np.sin(v))
    ys = np.outer(np.sin(u), np.sin(v))
    zs = np.outer(np.ones_like(u), np.cos(v))

    sphere = np.stack([xs, ys, zs], axis=0).reshape(3, -1)
    ell = (evecs @ np.diag(radii) @ sphere).reshape(3, xs.shape[0], xs.shape[1])

    X = center[0] + ell[0]
    Y = center[1] + ell[1]
    Z = center[2] + ell[2]

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_wireframe(X, Y, Z, rstride=2, cstride=2, alpha=0.35, linewidth=0.9)

    # Principal axes of the ellipsoid.
    for k in range(3):
        axis_vec = evecs[:, k] * radii[k]
        p1 = center - axis_vec
        p2 = center + axis_vec
        ax.plot(
            [p1[0], p2[0]],
            [p1[1], p2[1]],
            [p1[2], p2[2]],
            linewidth=1.2,
            alpha=0.9,
        )

    ax.scatter(center[0], center[1], center[2], s=80, label="posterior mean", zorder=5)
    ax.scatter(truth[0], truth[1], truth[2], s=100, marker="x", label="x_true", zorder=6)
    ax.plot(
        [center[0], truth[0]],
        [center[1], truth[1]],
        [center[2], truth[2]],
        linestyle="--",
        linewidth=1.0,
        alpha=0.8,
    )

    pts = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
    pts = np.vstack([pts, center[None, :], truth[None, :]])
    self._set_equal_3d_limits_centered(ax, center=center, xyz_points=pts, margin=0.08)

    ax.view_init(elev=elev, azim=azim)
    ax.set_title(f"EEG source={src_idx}, time={time_idx}, inside={inside}")
    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0), frameon=False)
    plt.tight_layout()
    plt.show()
def plot_eeg_ellipsoid_membership_aggregated(
    self,
    diag_eeg: Dict[str, Any],
    src_idx: int,
    *,
    figsize: Tuple[float, float] = (8.0, 6.0),
    elev: float = 22.0,
    azim: float = -58.0,
) -> None:
    """
    Plot the aggregated 3D credible ellipsoid of one EEG source.

    All quantities are read from ``diag_eeg``: the ellipsoid center
    (``"x_hat_agg"``), the truth (``"x_true_agg"``), the 3x3 covariance
    block (``"cov_blocks"``), the threshold (``"threshold"``) and the
    membership flag (``"within"``).

    Parameters
    ----------
    diag_eeg : dict
        Diagnostics holding the entries listed above.
    src_idx : int
        Source to display.
    figsize : tuple of float
        Figure size.
    elev, azim : float
        Viewing angles passed to ``ax.view_init``.
    """
    import matplotlib.pyplot as plt

    mean_pt = diag_eeg["x_hat_agg"][src_idx]
    true_pt = diag_eeg["x_true_agg"][src_idx]
    cov_block = diag_eeg["cov_blocks"][src_idx]
    chi2_cut = float(diag_eeg["threshold"])
    is_inside = bool(diag_eeg["within"][src_idx])

    # Symmetrize, decompose, and floor the spectrum before taking roots.
    cov_block = (cov_block + cov_block.T) / 2.0
    eigvals, eigvecs = np.linalg.eigh(cov_block)
    eigvals = np.maximum(eigvals, 1e-12)
    semi_axes = np.sqrt(chi2_cut * eigvals)

    # Parametrize the unit sphere, then map it onto the ellipsoid.
    azimuths = np.linspace(0.0, 2.0 * np.pi, 60)
    polars = np.linspace(0.0, np.pi, 30)
    sphere_x = np.outer(np.cos(azimuths), np.sin(polars))
    sphere_y = np.outer(np.sin(azimuths), np.sin(polars))
    sphere_z = np.outer(np.ones_like(azimuths), np.cos(polars))
    grid_shape = (3, sphere_x.shape[0], sphere_x.shape[1])

    unit_pts = np.stack([sphere_x, sphere_y, sphere_z], axis=0).reshape(3, -1)
    surface = (eigvecs @ np.diag(semi_axes) @ unit_pts).reshape(grid_shape)

    X = mean_pt[0] + surface[0]
    Y = mean_pt[1] + surface[1]
    Z = mean_pt[2] + surface[2]

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_wireframe(X, Y, Z, rstride=2, cstride=2, alpha=0.35, linewidth=0.9)

    # Principal axes: one segment per eigenvector, scaled by its semi-axis.
    for direction, length in zip(eigvecs.T, semi_axes):
        half = direction * length
        tail = mean_pt - half
        head = mean_pt + half
        ax.plot(
            [tail[0], head[0]],
            [tail[1], head[1]],
            [tail[2], head[2]],
            linewidth=1.2,
            alpha=0.9,
        )

    ax.scatter(mean_pt[0], mean_pt[1], mean_pt[2], s=80, label="aggregated posterior mean", zorder=5)
    ax.scatter(true_pt[0], true_pt[1], true_pt[2], s=100, marker="x", label="aggregated x_true", zorder=6)
    ax.plot(
        [mean_pt[0], true_pt[0]],
        [mean_pt[1], true_pt[1]],
        [mean_pt[2], true_pt[2]],
        linestyle="--",
        linewidth=1.0,
        alpha=0.8,
    )

    # Axis limits must contain the surface as well as both markers.
    cloud = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
    cloud = np.vstack([cloud, mean_pt[None, :], true_pt[None, :]])
    self._set_equal_3d_limits_centered(ax, center=mean_pt, xyz_points=cloud, margin=0.08)

    ax.view_init(elev=elev, azim=azim)
    ax.set_title(f"EEG source={src_idx}, inside={is_inside}")
    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0), frameon=False)
    plt.tight_layout()
    plt.show()
# ------------------------------------------------------------------
# Visualization: MEG free orientation (reduced 2D posterior)
# ------------------------------------------------------------------
def plot_meg_ellipse_membership_pointwise(
    self,
    x_true_3d: np.ndarray,  # (N,3,T)
    x_hat_2d: np.ndarray,  # (N,2,T)
    posterior_cov_2d: np.ndarray,  # (2N,2N)
    src_idx: int,
    time_idx: int,
    nominal_coverage: float = 0.95,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
    figsize: Tuple[float, float] = (6.5, 6.0),
) -> None:
    """
    Plot the reduced-2D credible ellipse of one MEG source at one time sample.

    Membership diagnostics are computed with
    ``self.pointwise_ellipse_membership_meg_free``; the ellipse, its
    principal axes, the posterior mean and the projected truth are then
    drawn in the reduced coordinate plane.

    Parameters
    ----------
    x_true_3d, x_hat_2d, posterior_cov_2d, nominal_coverage, V_tan,
    L_free_Mx3N, psd_repair_blocks, block_epsilon
        Forwarded to ``pointwise_ellipse_membership_meg_free``.
        ``block_epsilon`` is additionally used as the eigenvalue floor for
        the drawn ellipse so that it matches the membership test.
    src_idx, time_idx : int
        Source and time sample to display.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.
    """
    import matplotlib.pyplot as plt

    diag = self.pointwise_ellipse_membership_meg_free(
        x_true_3d=x_true_3d,
        x_hat_2d=x_hat_2d,
        posterior_cov_2d=posterior_cov_2d,
        nominal_coverage=nominal_coverage,
        V_tan=V_tan,
        L_free_Mx3N=L_free_Mx3N,
        psd_repair_blocks=psd_repair_blocks,
        block_epsilon=block_epsilon,
    )

    center2 = diag["projected_mean"][src_idx, :, time_idx]
    truth2 = diag["projected_true"][src_idx, :, time_idx]
    Sigma2 = diag["cov_blocks"][src_idx]
    threshold = float(diag["threshold"])
    inside = bool(diag["within"][src_idx, time_idx])

    Sigma2 = (Sigma2 + Sigma2.T) / 2.0
    evals, evecs = np.linalg.eigh(Sigma2)
    # Same floor as the membership computation above; a different floor
    # would draw an ellipse that can disagree with the ``inside`` flag.
    evals = np.maximum(evals, block_epsilon)
    radii = np.sqrt(threshold * evals)

    # Unit circle mapped onto the ellipse boundary.
    theta = np.linspace(0.0, 2.0 * np.pi, 361)
    circle = np.vstack([np.cos(theta), np.sin(theta)])
    ellipse = evecs @ np.diag(radii) @ circle

    ex = center2[0] + ellipse[0]
    ey = center2[1] + ellipse[1]

    fig, ax = plt.subplots(figsize=figsize)
    # ":g" avoids the truncation of int(): 100 * 0.29 evaluates to
    # 28.999999999999996, which int() would render as "28%".
    ax.plot(ex, ey, linewidth=1.4, label=f"{100 * nominal_coverage:g}% credible ellipse")

    for k in range(2):
        axis_vec = evecs[:, k] * radii[k]
        p1 = center2 - axis_vec
        p2 = center2 + axis_vec
        ax.plot(
            [p1[0], p2[0]],
            [p1[1], p2[1]],
            linewidth=1.2,
            alpha=0.9,
            label=f"Axis {k + 1}",
        )

    ax.scatter(center2[0], center2[1], s=90, label="posterior mean", zorder=3)
    ax.scatter(truth2[0], truth2[1], s=110, marker="x", label="projected x_true", zorder=4)
    ax.plot(
        [center2[0], truth2[0]],
        [center2[1], truth2[1]],
        linestyle="--",
        linewidth=1.0,
        alpha=0.8,
    )

    # View limits: ellipse plus both markers, padded by 15% of the range.
    x_all = np.concatenate([ex, [truth2[0], center2[0]]])
    y_all = np.concatenate([ey, [truth2[1], center2[1]]])
    pad_x = 0.15 * max(np.ptp(x_all), 1e-6)
    pad_y = 0.15 * max(np.ptp(y_all), 1e-6)

    ax.set_xlim(np.min(x_all) - pad_x, np.max(x_all) + pad_x)
    ax.set_ylim(np.min(y_all) - pad_y, np.max(y_all) + pad_y)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("Reduced MEG coordinate 1")
    ax.set_ylabel("Reduced MEG coordinate 2")
    ax.set_title(f"MEG source={src_idx}, time={time_idx}, inside={inside}")
    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    plt.tight_layout()
    plt.show()
def plot_meg_ellipse_membership_aggregated(
    self,
    diag_meg: Dict[str, Any],
    src_idx: int,
    *,
    nominal_coverage: float = 0.95,
    figsize: Tuple[float, float] = (6.5, 6.0),
) -> None:
    """
    Plot the aggregated reduced-2D credible ellipse of one MEG source.

    All quantities are read from ``diag_meg``: the ellipse center
    (``"projected_mean"``), the truth (``"projected_true"``), the 2x2
    covariance block (``"cov_blocks"``), the threshold (``"threshold"``)
    and the membership flag (``"within"``).

    Parameters
    ----------
    diag_meg : dict
        Diagnostics holding the entries listed above.
    src_idx : int
        Source to display.
    nominal_coverage : float
        Used for the legend label only; the ellipse size comes from the
        threshold stored in ``diag_meg``.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.
    """
    import matplotlib.pyplot as plt

    center2 = diag_meg["projected_mean"][src_idx]
    truth2 = diag_meg["projected_true"][src_idx]
    Sigma2 = diag_meg["cov_blocks"][src_idx]
    threshold = float(diag_meg["threshold"])
    inside = bool(diag_meg["within"][src_idx])

    Sigma2 = (Sigma2 + Sigma2.T) / 2.0
    evals, evecs = np.linalg.eigh(Sigma2)
    evals = np.maximum(evals, 1e-12)
    radii = np.sqrt(threshold * evals)

    # Unit circle mapped onto the ellipse boundary.
    theta = np.linspace(0.0, 2.0 * np.pi, 361)
    circle = np.vstack([np.cos(theta), np.sin(theta)])
    ellipse = evecs @ np.diag(radii) @ circle

    ex = center2[0] + ellipse[0]
    ey = center2[1] + ellipse[1]

    fig, ax = plt.subplots(figsize=figsize)
    # ":g" avoids the truncation of int(): 100 * 0.29 evaluates to
    # 28.999999999999996, which int() would render as "28%".
    ax.plot(ex, ey, linewidth=1.4, label=f"{100 * nominal_coverage:g}% credible ellipse")

    for k in range(2):
        axis_vec = evecs[:, k] * radii[k]
        p1 = center2 - axis_vec
        p2 = center2 + axis_vec
        ax.plot(
            [p1[0], p2[0]],
            [p1[1], p2[1]],
            linewidth=1.2,
            alpha=0.9,
            label=f"Axis {k + 1}",
        )

    ax.scatter(center2[0], center2[1], s=90, label="aggregated posterior mean", zorder=3)
    ax.scatter(truth2[0], truth2[1], s=110, marker="x", label="aggregated projected x_true", zorder=4)
    ax.plot(
        [center2[0], truth2[0]],
        [center2[1], truth2[1]],
        linestyle="--",
        linewidth=1.0,
        alpha=0.8,
    )

    # View limits: ellipse plus both markers, padded by 15% of the range.
    x_all = np.concatenate([ex, [truth2[0], center2[0]]])
    y_all = np.concatenate([ey, [truth2[1], center2[1]]])
    pad_x = 0.15 * max(np.ptp(x_all), 1e-6)
    pad_y = 0.15 * max(np.ptp(y_all), 1e-6)

    ax.set_xlim(np.min(x_all) - pad_x, np.max(x_all) + pad_x)
    ax.set_ylim(np.min(y_all) - pad_y, np.max(y_all) + pad_y)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("Reduced MEG coordinate 1")
    ax.set_ylabel("Reduced MEG coordinate 2")
    ax.set_title(f"MEG source={src_idx}, inside={inside}")
    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    plt.tight_layout()
    plt.show()
# # ============================================================================= # # Example usage # # # # This section covers BOTH uncertainty modes for the current MEG setup: # # # # 1) pointwise mode # # 2) aggregated mode # # # # Visualizations are called separately for the two modes. # # # # Active: # # - fixed MEG (BMN) # # - fixed MEG (BMN_joint) # # - free MEG reduced rank-2 (BMN) # # - free MEG reduced rank-2 (BMN_joint) # # # # Commented: # # - fixed EEG # # - free EEG # # ============================================================================= # # ------------------------------------------------------------------ # # 1) FIXED MEG -- BMN # # ------------------------------------------------------------------ # results_unc_fixed_meg = run_uncertainty_example_fixed( # x_true_fixed=x_fixed_meg, # out_fixed=out_fixed_bmn, # nominal_coverage=0.95, # src_idx=0, # do_plots=False, # plots called separately below # ) # print("\n================ FIXED MEG / BMN / POINTWISE ================") # print("Empirical coverage :", results_unc_fixed_meg["diag_fixed_point"]["empirical_coverage"]) # print("Count within :", results_unc_fixed_meg["diag_fixed_point"]["count_within"]) # print("Total count :", results_unc_fixed_meg["diag_fixed_point"]["total_count"]) # print("\n=============== FIXED MEG / BMN / AGGREGATED ===============") # print("Empirical coverage :", results_unc_fixed_meg["diag_fixed_agg"]["empirical_coverage"]) # print("Count within :", results_unc_fixed_meg["diag_fixed_agg"]["count_within"]) # print("Total count :", results_unc_fixed_meg["diag_fixed_agg"]["total_count"]) # # Pointwise visualization # results_unc_fixed_meg["ue"].plot_fixed_interval_membership_pointwise( # x_true=x_fixed_meg, # x_hat=out_fixed_bmn["posterior_mean"], # posterior_var=results_unc_fixed_meg["posterior_var_fixed"], # src_idx=0, # nominal_coverage=0.95, # ) # # Aggregated visualization # results_unc_fixed_meg["ue"].plot_fixed_interval_membership_aggregated( # 
diag_fixed=results_unc_fixed_meg["diag_fixed_agg"], # src_idx=0, # nominal_coverage=0.95, # ) # # ------------------------------------------------------------------ # # 2) FIXED MEG -- BMN_joint # # ------------------------------------------------------------------ # results_unc_fixed_meg_joint = run_uncertainty_example_fixed( # x_true_fixed=x_fixed_meg, # out_fixed=out_fixed_joint, # nominal_coverage=0.95, # src_idx=0, # do_plots=False, # plots called separately below # ) # print("\n============= FIXED MEG / BMN_joint / POINTWISE =============") # print("Empirical coverage :", results_unc_fixed_meg_joint["diag_fixed_point"]["empirical_coverage"]) # print("Count within :", results_unc_fixed_meg_joint["diag_fixed_point"]["count_within"]) # print("Total count :", results_unc_fixed_meg_joint["diag_fixed_point"]["total_count"]) # print("\n============ FIXED MEG / BMN_joint / AGGREGATED ============") # print("Empirical coverage :", results_unc_fixed_meg_joint["diag_fixed_agg"]["empirical_coverage"]) # print("Count within :", results_unc_fixed_meg_joint["diag_fixed_agg"]["count_within"]) # print("Total count :", results_unc_fixed_meg_joint["diag_fixed_agg"]["total_count"]) # # Pointwise visualization # results_unc_fixed_meg_joint["ue"].plot_fixed_interval_membership_pointwise( # x_true=x_fixed_meg, # x_hat=out_fixed_joint["posterior_mean"], # posterior_var=results_unc_fixed_meg_joint["posterior_var_fixed"], # src_idx=0, # nominal_coverage=0.95, # ) # # Aggregated visualization # results_unc_fixed_meg_joint["ue"].plot_fixed_interval_membership_aggregated( # diag_fixed=results_unc_fixed_meg_joint["diag_fixed_agg"], # src_idx=0, # nominal_coverage=0.95, # ) # # ------------------------------------------------------------------ # # 3) FREE MEG REDUCED RANK-2 -- BMN # # ------------------------------------------------------------------ # results_unc_meg = run_uncertainty_example_meg( # a_true_meg_2d=a_free_meg, # Q_basis=lf_free_meg["Q_basis"], # out_meg=out_free_meg_bmn, # 
nominal_coverage=0.95, # src_idx=0, # time_idx=0, # psd_repair_blocks=True, # block_epsilon=1e-12, # do_plots=False, # plots called separately below # ) # print("\n============= FREE MEG / BMN / POINTWISE ===================") # print("Empirical coverage :", results_unc_meg["diag_meg_point"]["empirical_coverage"]) # print("Count within :", results_unc_meg["diag_meg_point"]["count_within"]) # print("Total count :", results_unc_meg["diag_meg_point"]["total_count"]) # print("\n============ FREE MEG / BMN / AGGREGATED ==================") # print("Empirical coverage :", results_unc_meg["diag_meg_agg"]["empirical_coverage"]) # print("Count within :", results_unc_meg["diag_meg_agg"]["count_within"]) # print("Total count :", results_unc_meg["diag_meg_agg"]["total_count"]) # # Pointwise visualization # results_unc_meg["ue"].plot_meg_ellipse_membership_pointwise( # x_true_3d=results_unc_meg["x_true_meg_3d"], # x_hat_2d=results_unc_meg["x_hat_meg_2d"], # posterior_cov_2d=out_free_meg_bmn["posterior_cov"], # src_idx=0, # time_idx=0, # nominal_coverage=0.95, # V_tan=results_unc_meg["V_tan"], # psd_repair_blocks=True, # block_epsilon=1e-12, # ) # # Aggregated visualization # results_unc_meg["ue"].plot_meg_ellipse_membership_aggregated( # diag_meg=results_unc_meg["diag_meg_agg"], # src_idx=0, # nominal_coverage=0.95, # ) # # ------------------------------------------------------------------ # # 4) FREE MEG REDUCED RANK-2 -- BMN_joint # # ------------------------------------------------------------------ # results_unc_meg_joint = run_uncertainty_example_meg( # a_true_meg_2d=a_free_meg, # Q_basis=lf_free_meg["Q_basis"], # out_meg=out_free_meg_joint, # nominal_coverage=0.95, # src_idx=0, # time_idx=0, # psd_repair_blocks=True, # block_epsilon=1e-12, # do_plots=False, # plots called separately below # ) # print("\n=========== FREE MEG / BMN_joint / POINTWISE ==============") # print("Empirical coverage :", results_unc_meg_joint["diag_meg_point"]["empirical_coverage"]) # print("Count 
within :", results_unc_meg_joint["diag_meg_point"]["count_within"]) # print("Total count :", results_unc_meg_joint["diag_meg_point"]["total_count"]) # print("\n========== FREE MEG / BMN_joint / AGGREGATED =============") # print("Empirical coverage :", results_unc_meg_joint["diag_meg_agg"]["empirical_coverage"]) # print("Count within :", results_unc_meg_joint["diag_meg_agg"]["count_within"]) # print("Total count :", results_unc_meg_joint["diag_meg_agg"]["total_count"]) # # Pointwise visualization # results_unc_meg_joint["ue"].plot_meg_ellipse_membership_pointwise( # x_true_3d=results_unc_meg_joint["x_true_meg_3d"], # x_hat_2d=results_unc_meg_joint["x_hat_meg_2d"], # posterior_cov_2d=out_free_meg_joint["posterior_cov"], # src_idx=0, # time_idx=0, # nominal_coverage=0.95, # V_tan=results_unc_meg_joint["V_tan"], # psd_repair_blocks=True, # block_epsilon=1e-12, # ) # # Aggregated visualization # results_unc_meg_joint["ue"].plot_meg_ellipse_membership_aggregated( # diag_meg=results_unc_meg_joint["diag_meg_agg"], # src_idx=0, # nominal_coverage=0.95, # )