import os
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from scipy.stats import chi2, norm
from itertools import combinations, zip_longest
import mne
import logging
from typing import Optional, Dict, Any, Tuple
[docs]
def lift_reduced_sources_to_3d(a_red: np.ndarray, Q_basis: np.ndarray) -> np.ndarray:
    """
    Map reduced per-source coefficients into the retained local 3D basis.

    Parameters
    ----------
    a_red : (N, k, T)
        Reduced source coefficients (the original notes k=2 as the typical
        case for free_meg).
    Q_basis : (N, 3, k)
        Per-source basis whose columns span the reduced coordinates.

    Returns
    -------
    x_3d : (N, 3, T)
        For every source n, ``Q_basis[n] @ a_red[n]``.

    Raises
    ------
    ValueError
        If either input is not 3-dimensional, or if the source count or the
        reduced dimension k disagree between the two inputs.
    """
    coeffs = np.asarray(a_red, dtype=float)
    basis = np.asarray(Q_basis, dtype=float)

    if coeffs.ndim != 3:
        raise ValueError(f"Expected a_red with shape (N, k, T). Got {coeffs.shape}")
    if basis.ndim != 3:
        raise ValueError(f"Expected Q_basis with shape (N, 3, k). Got {basis.shape}")

    n_sources, k_reduced, _ = coeffs.shape
    if n_sources != basis.shape[0]:
        raise ValueError("Mismatch in number of sources between a_red and Q_basis")
    if k_reduced != basis.shape[2]:
        raise ValueError("Mismatch in reduced dimension k between a_red and Q_basis")

    # Batched matrix product over the source axis, contracting the reduced axis.
    return np.einsum("sck,skt->sct", basis, coeffs)
# =============================================================================
# Uncertainty estimator
# =============================================================================
[docs]
class UncertaintyEstimator:
"""
Uncertainty estimator for fixed- and free-orientation source estimates.
Supported settings
------------------
1) Fixed orientation (EEG or MEG):
- pointwise marginal credible interval membership
- aggregated-mean marginal credible interval membership
2) Free orientation EEG:
- pointwise 3D credible ellipsoid membership
- aggregated-mean 3D credible ellipsoid membership
3) Free orientation MEG (reduced rank-2 model):
- pointwise 2D credible ellipse membership in reduced coordinates
- aggregated-mean 2D credible ellipse membership in reduced coordinates
Modeling convention
-------------------
- posterior mean varies over time
- posterior covariance is static over time
- when aggregating by time-average, covariance is scaled by 1 / T
Important MEG convention
------------------------
For MEG, the inverse model is assumed to be built on a reduced 2-orientation
leadfield obtained source-wise via SVD or, preferably, taken directly from
the extractor as `Q_basis`. Therefore:
- truth can be represented in reduced coordinates a_true with shape (N,2,T)
- if needed, truth can be lifted to local 3D via Q_basis -> shape (N,3,T)
- posterior mean is in reduced 2D coordinates, shape (N,2,T)
- posterior covariance is in reduced coordinates, shape (2N,2N)
Recommended usage
-----------------
Use the extractor-provided basis:
V_tan = lf_free_meg["Q_basis"]
instead of recomputing a new basis inside the uncertainty step.
"""
[docs]
def __init__(
    self,
    nominal_coverages: Optional[np.ndarray] = None,
    logger: Optional[logging.Logger] = None,
):
    """
    Store the nominal coverage grid and pre-compute its normal z-scores.

    Parameters
    ----------
    nominal_coverages : array-like, optional
        Coverage levels in [0, 1]. When given, they are converted to float,
        validated, de-duplicated and sorted ascending. When omitted, a
        built-in 15-point grid from 0.0 to 1.0 is used.
    logger : logging.Logger, optional
        Logger to attach; defaults to ``logging.getLogger(__name__)``.

    Raises
    ------
    ValueError
        If any supplied coverage is NaN or lies outside [0, 1].
    """
    if nominal_coverages is None:
        levels = np.array(
            [0.0, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5,
             0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1.0],
            dtype=float,
        )
    else:
        levels = np.asarray(nominal_coverages, dtype=float)
        # NaN compares False against both bounds, so it must be tested
        # explicitly or it would slip through and yield NaN z-scores.
        invalid = np.isnan(levels) | (levels < 0.0) | (levels > 1.0)
        if np.any(invalid):
            raise ValueError("Nominal coverages must be between 0 and 1.")
        # np.unique already returns the values sorted ascending.
        levels = np.unique(levels)

    self.nominal_coverages = levels
    self.logger = logger or logging.getLogger(__name__)

    # Two-sided normal quantiles keyed by coverage; the endpoints are
    # special-cased to exactly 0 and +inf.
    self.z_scores: Dict[float, float] = {}
    for level in self.nominal_coverages:
        level = float(level)
        if np.isclose(level, 0.0):
            self.z_scores[level] = 0.0
        elif np.isclose(level, 1.0):
            self.z_scores[level] = np.inf
        else:
            self.z_scores[level] = float(norm.ppf((1.0 + level) / 2.0))
# def __getstate__(self):
# state = self.__dict__.copy()
# state['logger'] = None
# return state
# def __setstate__(self, state):
# self.__dict__.update(state)
# if self.logger is None:
# self.logger = logging.getLogger(self.__class__.__name__)
# ------------------------------------------------------------------
# Core helpers
# ------------------------------------------------------------------
def _get_z(self, nominal_coverage: float) -> float:
    """
    Return the two-sided standard-normal z-score for a coverage level.

    Coverage ~0 maps to 0.0 and coverage ~1 maps to +inf without touching
    the cache. Any other level is looked up in ``self.z_scores`` and, when
    missing, computed as ``norm.ppf((1 + c) / 2)`` and cached there.
    """
    level = float(nominal_coverage)
    for endpoint, z_value in ((0.0, 0.0), (1.0, np.inf)):
        if np.isclose(level, endpoint):
            return z_value
    try:
        return self.z_scores[level]
    except KeyError:
        z_value = float(norm.ppf((1.0 + level) / 2.0))
        self.z_scores[level] = z_value
        return z_value
@staticmethod
def _get_chi2_threshold(nominal_coverage: float, df: int) -> float:
    """
    Return the chi-square quantile used as a squared-distance threshold.

    Coverage ~0 gives 0.0 and coverage ~1 gives +inf; otherwise the result
    is ``chi2.ppf(coverage, df=df)``.
    """
    level = float(nominal_coverage)
    for endpoint, threshold in ((0.0, 0.0), (1.0, np.inf)):
        if np.isclose(level, endpoint):
            return threshold
    quantile = chi2.ppf(level, df=df)
    return float(quantile)
@staticmethod
def _psd_clip_block(cov_block: np.ndarray, eps: float) -> np.ndarray:
    """
    Symmetrize a covariance block and floor its eigenvalues at ``eps``.

    The block is averaged with its transpose, eigendecomposed with
    ``np.linalg.eigh``, rebuilt with every eigenvalue raised to at least
    ``eps``, and symmetrized once more before being returned.
    """
    symmetric = (cov_block + cov_block.T) / 2.0
    eigenvalues, eigenvectors = np.linalg.eigh(symmetric)
    floored = np.maximum(eigenvalues, eps)
    rebuilt = eigenvectors @ np.diag(floored) @ eigenvectors.T
    # Re-symmetrize to remove round-off asymmetry from the reconstruction.
    return (rebuilt + rebuilt.T) / 2.0
@staticmethod
def _canonicalize_basis_columns(Q: np.ndarray) -> np.ndarray:
    """
    Fix the sign of each basis column.

    A column is negated when its largest-magnitude entry is negative, so
    that entry ends up nonnegative. The input is not modified; a float copy
    is returned.

    Raises
    ------
    ValueError
        If ``Q`` is not 2-dimensional.
    """
    basis = np.asarray(Q, dtype=float)
    if basis.ndim != 2:
        raise ValueError(f"Expected 2D basis matrix. Got {basis.shape}")

    stabilized = np.array(basis, dtype=float, copy=True)
    # Rows of the transpose are writable views onto the columns of the copy.
    for column in stabilized.T:
        pivot = int(np.argmax(np.abs(column)))
        if column[pivot] < 0:
            column *= -1.0
    return stabilized
[docs]
@staticmethod
def posterior_variance_from_cov(
    posterior_cov: np.ndarray,
    *,
    min_var: float = 0.0,
) -> np.ndarray:
    """
    Extract the diagonal of a covariance matrix as floored variances.

    Entries are clipped from below at ``min_var`` when it is positive and
    at 0.0 otherwise. A new float array is returned.
    """
    variances = np.array(np.diag(posterior_cov), dtype=float)
    floor = 0.0 if min_var <= 0.0 else float(min_var)
    return np.maximum(variances, floor)
@staticmethod
def _set_equal_3d_limits_centered(
    ax,
    center: np.ndarray,
    xyz_points: np.ndarray,
    margin: float = 0.08,
) -> None:
    """
    Give a 3D axes object equal-width x/y/z limits centred on ``center``.

    The half-width is the largest absolute coordinate offset of any point
    in ``xyz_points`` from ``center``, enlarged by ``margin`` and never
    smaller than 1e-6. ``ax.set_box_aspect((1, 1, 1))`` is then attempted
    and any exception it raises is ignored.
    """
    origin = np.asarray(center, dtype=float).reshape(3)
    cloud = np.asarray(xyz_points, dtype=float).reshape(-1, 3)

    largest_offset = float(np.max(np.abs(cloud - origin[None, :])))
    half_width = max((1.0 + margin) * largest_offset, 1e-6)

    for axis_letter, midpoint in zip("xyz", origin):
        set_limits = getattr(ax, f"set_{axis_letter}lim")
        set_limits(midpoint - half_width, midpoint + half_width)

    # Best effort only: failures of set_box_aspect are deliberately swallowed.
    try:
        ax.set_box_aspect((1, 1, 1))
    except Exception:
        pass
# ------------------------------------------------------------------
# Fixed orientation
# ------------------------------------------------------------------
[docs]
def credible_intervals_normal(
    self,
    mean: np.ndarray,
    variance: np.ndarray,
    nominal_coverage: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build symmetric normal credible intervals ``mean ± z * sqrt(variance)``.

    Parameters
    ----------
    mean, variance : array-like
        Interval centres and variances; they must be broadcast-compatible.
        Both are converted to float arrays, so lists and integer arrays are
        accepted. Negative variances are treated as zero.
    nominal_coverage : float
        Coverage level. ~0 returns degenerate intervals equal to the mean;
        ~1 returns (-inf, +inf).

    Returns
    -------
    (lower, upper) : tuple of float ndarrays
        Both have the broadcast shape of ``mean`` and ``variance`` for
        every coverage level.
    """
    level = float(nominal_coverage)
    centre = np.asarray(mean, dtype=float)
    var = np.asarray(variance, dtype=float)
    # Use one output shape in all branches so the endpoint cases agree
    # with the general ``mean - z * std`` broadcasting result.
    out_shape = np.broadcast(centre, var).shape

    if np.isclose(level, 0.0):
        expanded = np.broadcast_to(centre, out_shape)
        return expanded.copy(), expanded.copy()
    if np.isclose(level, 1.0):
        # Built as float arrays: filling an integer-typed template with
        # infinities is not representable.
        return np.full(out_shape, -np.inf), np.full(out_shape, np.inf)

    z = self._get_z(level)
    half_width = z * np.sqrt(np.maximum(var, 0.0))
    return centre - half_width, centre + half_width
[docs]
def pointwise_interval_membership(
    self,
    x_true: np.ndarray,  # (N,T)
    x_hat: np.ndarray,  # (N,T)
    posterior_var: np.ndarray,  # (N,)
    nominal_coverage: float,
) -> Dict[str, Any]:
    """
    Test, per source and time point, whether truth lies in the marginal CI.

    The per-source variance is held constant over time. Intervals come from
    ``credible_intervals_normal`` centred on ``x_hat``.

    Returns
    -------
    dict
        Interval bounds, boolean membership mask, the time-expanded
        variance, z-score, hit/total counts, empirical coverage (mean of
        the mask) and the number of time points.

    Raises
    ------
    ValueError
        If ``x_true``/``x_hat`` are not matching (N,T) arrays or
        ``posterior_var`` is not of shape (N,).
    """
    truth = np.asarray(x_true, dtype=float)
    estimate = np.asarray(x_hat, dtype=float)
    source_var = np.asarray(posterior_var, dtype=float)

    if not (truth.ndim == 2 and estimate.ndim == 2):
        raise ValueError("x_true and x_hat must have shape (N,T) for fixed orientation.")
    if truth.shape != estimate.shape:
        raise ValueError("x_true and x_hat must have the same shape.")
    n_sources, n_times = estimate.shape
    if source_var.shape != (n_sources,):
        raise ValueError("posterior_var must have shape (N,).")

    # Static covariance: copy each source variance across all time points.
    var_full = np.repeat(source_var, n_times).reshape(n_sources, n_times)
    lower, upper = self.credible_intervals_normal(estimate, var_full, nominal_coverage)
    inside = (lower <= truth) & (truth <= upper)

    result: Dict[str, Any] = {}
    result["ci_lower"] = lower
    result["ci_upper"] = upper
    result["within"] = inside
    result["posterior_var"] = var_full
    result["z_score"] = float(self._get_z(nominal_coverage))
    result["count_within"] = int(np.sum(inside))
    result["total_count"] = int(inside.size)
    result["empirical_coverage"] = float(np.mean(inside))
    result["n_times"] = int(n_times)
    return result
[docs]
def aggregated_interval_membership(
    self,
    x_true: np.ndarray,  # (N,T)
    x_hat: np.ndarray,  # (N,T)
    posterior_var: np.ndarray,  # (N,)
    nominal_coverage: float,
) -> Dict[str, Any]:
    """
    Test whether the time-averaged truth lies in the CI of the averaged estimate.

    Truth and estimate are averaged over the time axis, and the per-source
    variance is divided by T before intervals are built with
    ``credible_intervals_normal``.

    Returns
    -------
    dict
        Averaged truth/estimate, scaled variance, interval bounds, boolean
        membership per source, z-score, hit/total counts, empirical
        coverage and the number of time points.

    Raises
    ------
    ValueError
        If ``x_true``/``x_hat`` are not matching (N,T) arrays or
        ``posterior_var`` is not of shape (N,).
    """
    truth = np.asarray(x_true, dtype=float)
    estimate = np.asarray(x_hat, dtype=float)
    source_var = np.asarray(posterior_var, dtype=float)

    if not (truth.ndim == 2 and estimate.ndim == 2):
        raise ValueError("Aggregated fixed-orientation interval expects x_true and x_hat with shape (N,T).")
    if truth.shape != estimate.shape:
        raise ValueError("x_true and x_hat must have the same shape.")
    n_sources, n_times = estimate.shape
    if source_var.shape != (n_sources,):
        raise ValueError("posterior_var must have shape (N,).")

    truth_mean = truth.mean(axis=1)
    estimate_mean = estimate.mean(axis=1)
    # Variance of a mean over T samples under the static-covariance convention.
    mean_var = source_var / float(n_times)

    lower, upper = self.credible_intervals_normal(estimate_mean, mean_var, nominal_coverage)
    inside = (lower <= truth_mean) & (truth_mean <= upper)

    result: Dict[str, Any] = {}
    result["x_true_agg"] = truth_mean
    result["x_hat_agg"] = estimate_mean
    result["posterior_var_agg"] = mean_var
    result["ci_lower"] = lower
    result["ci_upper"] = upper
    result["within"] = inside
    result["z_score"] = float(self._get_z(nominal_coverage))
    result["count_within"] = int(np.sum(inside))
    result["total_count"] = int(inside.size)
    result["empirical_coverage"] = float(np.mean(inside))
    result["n_times"] = int(n_times)
    return result
[docs]
def calibration_curve_intervals_pointwise(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_var: np.ndarray,
) -> Dict[str, Any]:
    """
    Evaluate pointwise marginal-interval coverage at every nominal level.

    Calls ``pointwise_interval_membership`` once per entry of
    ``self.nominal_coverages`` and collects the empirical coverage and hit
    count of each call.

    Returns
    -------
    dict
        ``nominal_coverages``, ``empirical_coverages``, ``ci_counts`` and
        ``interval_type`` (always ``"marginal"``).
    """
    levels = self.nominal_coverages
    coverage_curve = np.empty(len(levels), dtype=float)
    hit_counts = np.empty(len(levels), dtype=int)

    for position, level in enumerate(levels):
        membership = self.pointwise_interval_membership(
            x_true=x_true,
            x_hat=x_hat,
            posterior_var=posterior_var,
            nominal_coverage=float(level),
        )
        coverage_curve[position] = membership["empirical_coverage"]
        hit_counts[position] = membership["count_within"]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": coverage_curve,
        "ci_counts": hit_counts,
        "interval_type": "marginal",
    }
[docs]
def calibration_curve_intervals_aggregated(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_var: np.ndarray,
) -> Dict[str, Any]:
    """
    Evaluate time-aggregated marginal-interval coverage at every nominal level.

    Calls ``aggregated_interval_membership`` once per entry of
    ``self.nominal_coverages`` and collects the empirical coverage and hit
    count of each call.

    Returns
    -------
    dict
        ``nominal_coverages``, ``empirical_coverages``, ``ci_counts`` and
        ``interval_type`` (always ``"marginal"``).
    """
    levels = self.nominal_coverages
    coverage_curve = np.empty(len(levels), dtype=float)
    hit_counts = np.empty(len(levels), dtype=int)

    for position, level in enumerate(levels):
        membership = self.aggregated_interval_membership(
            x_true=x_true,
            x_hat=x_hat,
            posterior_var=posterior_var,
            nominal_coverage=float(level),
        )
        coverage_curve[position] = membership["empirical_coverage"]
        hit_counts[position] = membership["count_within"]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": coverage_curve,
        "ci_counts": hit_counts,
        "interval_type": "marginal",
    }
# ------------------------------------------------------------------
# EEG free orientation
# ------------------------------------------------------------------
[docs]
def pointwise_ellipsoid_membership_eeg_free(
    self,
    x_true: np.ndarray,  # (N,3,T)
    x_hat: np.ndarray,  # (N,3,T)
    posterior_cov: np.ndarray,  # (3N,3N) or (N,3,3)
    nominal_coverage: float,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Test, per source and time point, membership in a 3D credible ellipsoid.

    For each source the 3x3 covariance block is symmetrized (optionally
    PSD-repaired), inverted through an eigendecomposition whose eigenvalues
    are floored at ``block_epsilon``, and the squared Mahalanobis distance
    of ``x_true - x_hat`` is compared with the chi-square threshold for
    3 degrees of freedom.

    Returns
    -------
    dict
        ``q_values`` (N,T), ``threshold``, boolean ``within`` (N,T), the
        symmetrized ``cov_blocks`` (N,3,3), hit/total counts, empirical
        coverage and the number of time points.

    Raises
    ------
    ValueError
        If ``x_true`` is not (N,3,T), ``x_hat`` differs in shape, or
        ``posterior_cov`` is neither (3N,3N) nor (N,3,3).
    """
    truth = np.asarray(x_true, dtype=float)
    estimate = np.asarray(x_hat, dtype=float)
    cov = np.asarray(posterior_cov, dtype=float)
    threshold = self._get_chi2_threshold(nominal_coverage, df=3)

    if truth.ndim != 3 or truth.shape[1] != 3:
        raise ValueError(f"x_true must be (N,3,T); got {truth.shape}")
    if estimate.shape != truth.shape:
        raise ValueError("x_hat must match x_true shape (N,3,T).")

    n_sources, _, n_times = truth.shape
    blockwise = cov.ndim == 3
    if blockwise and cov.shape != (n_sources, 3, 3):
        raise ValueError(
            f"posterior_cov block form must be (N,3,3); got {cov.shape}"
        )
    if not blockwise and cov.shape != (3 * n_sources, 3 * n_sources):
        raise ValueError(f"posterior_cov must be (3N,3N); got {cov.shape}")

    q_values = np.zeros((n_sources, n_times), dtype=float)
    cov_blocks = np.zeros((n_sources, 3, 3), dtype=float)

    for src in range(n_sources):
        if blockwise:
            block = cov[src]
        else:
            start = 3 * src
            block = cov[start:start + 3, start:start + 3]
        block = (block + block.T) / 2.0
        if psd_repair_blocks:
            block = self._psd_clip_block(block, block_epsilon)
        cov_blocks[src] = block

        # Regularized inverse: eigenvalues are floored before inversion.
        eigenvalues, eigenvectors = np.linalg.eigh(block)
        eigenvalues = np.maximum(eigenvalues, block_epsilon)
        precision = eigenvectors @ np.diag(1.0 / eigenvalues) @ eigenvectors.T

        residual = truth[src] - estimate[src]  # (3,T)
        q_values[src] = np.einsum("at,ab,bt->t", residual, precision, residual)

    inside = q_values <= threshold

    return {
        "q_values": q_values,
        "threshold": float(threshold),
        "within": inside,
        "cov_blocks": cov_blocks,
        "count_within": int(np.sum(inside)),
        "total_count": int(inside.size),
        "empirical_coverage": float(np.mean(inside)),
        "n_times": int(n_times),
    }
[docs]
def aggregated_ellipsoid_membership_eeg_free(
    self,
    x_true: np.ndarray,  # (N,3,T)
    x_hat: np.ndarray,  # (N,3,T)
    posterior_cov: np.ndarray,  # (3N,3N) or (N,3,3)
    nominal_coverage: float,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Test, per source, membership of the time-averaged truth in a 3D ellipsoid.

    Truth and estimate are averaged over time. Each symmetrized (optionally
    PSD-repaired) 3x3 covariance block is divided by T, inverted through an
    eigendecomposition with eigenvalues floored at ``block_epsilon``, and
    the squared Mahalanobis distance of the averaged residual is compared
    with the chi-square threshold for 3 degrees of freedom.

    Returns
    -------
    dict
        Averaged truth/estimate (N,3), ``q_values`` (N,), ``threshold``,
        boolean ``within`` (N,), the 1/T-scaled ``cov_blocks`` (N,3,3),
        hit/total counts, empirical coverage and the number of time points.

    Raises
    ------
    ValueError
        If ``x_true`` is not (N,3,T), ``x_hat`` differs in shape, or
        ``posterior_cov`` is neither (3N,3N) nor (N,3,3).
    """
    truth = np.asarray(x_true, dtype=float)
    estimate = np.asarray(x_hat, dtype=float)
    cov = np.asarray(posterior_cov, dtype=float)

    if truth.ndim != 3 or truth.shape[1] != 3:
        raise ValueError(f"x_true must be (N,3,T); got {truth.shape}")
    if estimate.shape != truth.shape:
        raise ValueError("x_hat must match x_true shape (N,3,T).")

    n_sources, _, n_times = truth.shape
    blockwise = cov.ndim == 3
    if blockwise and cov.shape != (n_sources, 3, 3):
        raise ValueError(
            f"posterior_cov block form must be (N,3,3); got {cov.shape}"
        )
    if not blockwise and cov.shape != (3 * n_sources, 3 * n_sources):
        raise ValueError(f"posterior_cov must be (3N,3N); got {cov.shape}")

    truth_mean = np.mean(truth, axis=2)
    estimate_mean = np.mean(estimate, axis=2)
    threshold = self._get_chi2_threshold(nominal_coverage, df=3)

    q_values = np.zeros(n_sources, dtype=float)
    cov_blocks = np.zeros((n_sources, 3, 3), dtype=float)

    for src in range(n_sources):
        if blockwise:
            block = cov[src]
        else:
            start = 3 * src
            block = cov[start:start + 3, start:start + 3]
        block = (block + block.T) / 2.0
        if psd_repair_blocks:
            block = self._psd_clip_block(block, block_epsilon)

        # Covariance of the time average under the static-covariance model.
        block_of_mean = block / float(n_times)
        cov_blocks[src] = block_of_mean

        eigenvalues, eigenvectors = np.linalg.eigh(block_of_mean)
        eigenvalues = np.maximum(eigenvalues, block_epsilon)
        precision = eigenvectors @ np.diag(1.0 / eigenvalues) @ eigenvectors.T

        residual = truth_mean[src] - estimate_mean[src]
        q_values[src] = float(residual.T @ precision @ residual)

    inside = q_values <= threshold

    return {
        "x_true_agg": truth_mean,
        "x_hat_agg": estimate_mean,
        "q_values": q_values,
        "threshold": float(threshold),
        "within": inside,
        "cov_blocks": cov_blocks,
        "count_within": int(np.sum(inside)),
        "total_count": int(inside.size),
        "empirical_coverage": float(np.mean(inside)),
        "n_times": int(n_times),
    }
[docs]
def calibration_curve_ellipsoid_eeg_free_pointwise(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_cov: np.ndarray,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Evaluate pointwise 3D-ellipsoid coverage at every nominal level.

    Calls ``pointwise_ellipsoid_membership_eeg_free`` once per entry of
    ``self.nominal_coverages``, forwarding the PSD-repair options.

    Returns
    -------
    dict
        ``nominal_coverages``, ``empirical_coverages``, ``ci_counts`` and
        ``interval_type`` (always ``"full_cov"``).
    """
    levels = self.nominal_coverages
    coverage_curve = np.empty(len(levels), dtype=float)
    hit_counts = np.empty(len(levels), dtype=int)

    for position, level in enumerate(levels):
        membership = self.pointwise_ellipsoid_membership_eeg_free(
            x_true=x_true,
            x_hat=x_hat,
            posterior_cov=posterior_cov,
            nominal_coverage=float(level),
            psd_repair_blocks=psd_repair_blocks,
            block_epsilon=block_epsilon,
        )
        coverage_curve[position] = membership["empirical_coverage"]
        hit_counts[position] = membership["count_within"]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": coverage_curve,
        "ci_counts": hit_counts,
        "interval_type": "full_cov",
    }
[docs]
def calibration_curve_ellipsoid_eeg_free_aggregated(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_cov: np.ndarray,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Evaluate time-aggregated 3D-ellipsoid coverage at every nominal level.

    Calls ``aggregated_ellipsoid_membership_eeg_free`` once per entry of
    ``self.nominal_coverages``, forwarding the PSD-repair options.

    Returns
    -------
    dict
        ``nominal_coverages``, ``empirical_coverages``, ``ci_counts`` and
        ``interval_type`` (always ``"full_cov"``).
    """
    levels = self.nominal_coverages
    coverage_curve = np.empty(len(levels), dtype=float)
    hit_counts = np.empty(len(levels), dtype=int)

    for position, level in enumerate(levels):
        membership = self.aggregated_ellipsoid_membership_eeg_free(
            x_true=x_true,
            x_hat=x_hat,
            posterior_cov=posterior_cov,
            nominal_coverage=float(level),
            psd_repair_blocks=psd_repair_blocks,
            block_epsilon=block_epsilon,
        )
        coverage_curve[position] = membership["empirical_coverage"]
        hit_counts[position] = membership["count_within"]

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": coverage_curve,
        "ci_counts": hit_counts,
        "interval_type": "full_cov",
    }
# ------------------------------------------------------------------
# Direction-aggregated marginal intervals (free orientation)
# ------------------------------------------------------------------
@staticmethod
def _as_source_major_free_mean(
    posterior_mean: np.ndarray,
    *,
    n_sources: int,
    n_orient: int,
    name: str = "posterior_mean",
) -> np.ndarray:
    """Return a free-orientation posterior mean in source-major shape (N,K,T).

    Accepted input shapes
    ---------------------
    - (N,K,T): validated and returned as a float array.
    - (K*N,T): reshaped to (N,K,T), i.e. rows are assumed to be ordered
      source by source with the K components of a source adjacent.

    Raises
    ------
    ValueError
        If the array is neither 2D nor 3D, or its leading dimensions do not
        match ``n_sources`` and ``n_orient``. ``name`` is used in the
        error message.
    """
    mean = np.asarray(posterior_mean, dtype=float)

    if mean.ndim not in (2, 3):
        raise ValueError(f"{name} must be either (N,K,T) or (K*N,T). Got {mean.shape}.")

    if mean.ndim == 3:
        if mean.shape[0] != n_sources or mean.shape[1] != n_orient:
            raise ValueError(
                f"{name} must have shape (N,K,T)=({n_sources},{n_orient},T). Got {mean.shape}."
            )
        return mean

    # Flattened case: first axis must hold N*K rows.
    n_rows_expected = n_sources * n_orient
    n_rows, n_times = mean.shape
    if n_rows != n_rows_expected:
        raise ValueError(
            f"{name} has incompatible flattened shape. Expected first dimension {n_rows_expected}, got {n_rows}."
        )
    return mean.reshape(n_sources, n_orient, n_times)
[docs]
@staticmethod
def componentwise_variance_from_uncert(
    posterior_uncert: np.ndarray,
    *,
    n_sources: int,
    n_orient: int,
    min_var: float = 1e-12,
) -> np.ndarray:
    """Return per-source, per-component marginal variances as an (N,K) array.

    Supported inputs
    ----------------
    - full covariance (K*N, K*N): the diagonal is reshaped to (N,K).
    - block covariance (N, K, K): the diagonal of every block is taken.

    The variances are clipped from below at ``min_var`` when it is
    positive and at 0.0 otherwise.

    Raises
    ------
    ValueError
        If the input is neither 2D nor 3D, or its shape does not match
        ``n_sources`` and ``n_orient``.
    """
    uncert = np.asarray(posterior_uncert, dtype=float)
    n_coeff = n_sources * n_orient

    if uncert.ndim == 2:
        expected = (n_coeff, n_coeff)
        if uncert.shape != expected:
            raise ValueError(f"posterior_uncert must have shape {expected}. Got {uncert.shape}.")
        variances = np.array(np.diag(uncert), dtype=float).reshape(n_sources, n_orient)
    elif uncert.ndim == 3:
        expected = (n_sources, n_orient, n_orient)
        if uncert.shape != expected:
            raise ValueError(
                f"posterior_uncert block form must have shape {expected}. Got {uncert.shape}."
            )
        variances = np.array(np.diagonal(uncert, axis1=1, axis2=2), dtype=float)
    else:
        raise ValueError("posterior_uncert must be either (K*N,K*N) or (N,K,K).")

    floor = 0.0 if min_var <= 0.0 else float(min_var)
    return np.maximum(variances, floor)
[docs]
def pointwise_componentwise_interval_membership_free(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    nominal_coverage: float,
    *,
    n_orient: int,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Pointwise marginal (component-wise) CI membership for free orientation.

    This implements the "direction-aggregated marginal" diagnostic from
    `_temp/Component_CI_R1.py`:
    - Build scalar normal credible intervals per retained component using
      only marginal variances (diagonal of the covariance).
    - Pool interval membership over sources, components, and time.

    Parameters
    ----------
    x_true, x_hat : (N, n_orient, T)
        Ground truth and posterior mean.
    posterior_uncert : (K*N, K*N) or (N, K, K)
        Posterior covariance, passed to
        ``componentwise_variance_from_uncert``.
    nominal_coverage : float
        Target coverage level.
    n_orient : int
        Number of components K per source.
    min_var : float
        Lower floor applied to the marginal variances.

    Returns
    -------
    dict
        Interval bounds, membership mask, variances (per component and
        expanded over time), z-score, counts, empirical coverage and the
        problem dimensions. ``empirical_coverage`` is NaN when there is
        nothing to evaluate (zero sources or zero time points).

    Raises
    ------
    ValueError
        If ``x_true`` is not (N, n_orient, T) or ``x_hat`` differs in shape.

    Important
    ---------
    Directions/components are *not* evaluated separately. They are pooled
    because local coordinate systems are arbitrary (EEG xyz axes depend on
    coordinate choice; MEG reduced components are voxel-dependent).
    """
    x_true = np.asarray(x_true, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    if x_true.ndim != 3 or x_true.shape[1] != n_orient:
        raise ValueError(f"x_true must have shape (N,{n_orient},T). Got {x_true.shape}.")
    if x_hat.shape != x_true.shape:
        raise ValueError(f"x_hat must match x_true shape. Got {x_hat.shape} vs {x_true.shape}.")

    N, K, T = x_true.shape
    level = float(nominal_coverage)

    posterior_var = self.componentwise_variance_from_uncert(
        posterior_uncert,
        n_sources=N,
        n_orient=n_orient,
        min_var=min_var,
    )
    # Static covariance: the same variance applies at every time point.
    posterior_var_full = np.repeat(posterior_var[:, :, None], T, axis=2)

    ci_lower, ci_upper = self.credible_intervals_normal(
        mean=x_hat,
        variance=posterior_var_full,
        nominal_coverage=level,
    )
    within = (x_true >= ci_lower) & (x_true <= ci_upper)

    count_within = int(np.sum(within))
    total_count = int(within.size)
    # An empty problem has no defined coverage; report NaN (as the other
    # membership diagnostics do) instead of dividing by zero.
    if total_count > 0:
        empirical_coverage = float(count_within / total_count)
    else:
        empirical_coverage = float("nan")

    return {
        "ci_lower": ci_lower,
        "ci_upper": ci_upper,
        "within": within,
        "posterior_var": posterior_var,
        "posterior_var_full": posterior_var_full,
        "z_score": float(self._get_z(level)),
        "count_within": count_within,
        "total_count": total_count,
        "empirical_coverage": empirical_coverage,
        "aggregation_axes": "sources, directions, time",
        "n_sources": int(N),
        "n_orient": int(K),
        "n_times": int(T),
    }
[docs]
def aggregated_componentwise_interval_membership_free(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    nominal_coverage: float,
    *,
    n_orient: int,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Temporally aggregated marginal (component-wise) CI membership.

    This corresponds to building the diagnostic on time-averaged signals:
    - Aggregate `x_true`/`x_hat` over time (mean over T).
    - Scale marginal variances by 1/T (variance of the mean).
    - Build scalar CIs per component and pool membership over sources and
      components.

    Parameters
    ----------
    x_true, x_hat : (N, n_orient, T)
        Ground truth and posterior mean.
    posterior_uncert : (K*N, K*N) or (N, K, K)
        Posterior covariance, passed to
        ``componentwise_variance_from_uncert``.
    nominal_coverage : float
        Target coverage level.
    n_orient : int
        Number of components K per source.
    min_var : float
        Lower floor applied to the marginal variances before 1/T scaling.

    Returns
    -------
    dict
        Averaged truth/estimate, interval bounds, membership mask,
        variances (raw and 1/T-scaled), z-score, counts, empirical coverage
        and the problem dimensions. ``empirical_coverage`` is NaN when
        there are no source components to evaluate.

    Raises
    ------
    ValueError
        If ``x_true`` is not (N, n_orient, T) or ``x_hat`` differs in shape.
    """
    x_true = np.asarray(x_true, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    if x_true.ndim != 3 or x_true.shape[1] != n_orient:
        raise ValueError(f"x_true must have shape (N,{n_orient},T). Got {x_true.shape}.")
    if x_hat.shape != x_true.shape:
        raise ValueError(f"x_hat must match x_true shape. Got {x_hat.shape} vs {x_true.shape}.")

    N, K, T = x_true.shape
    level = float(nominal_coverage)

    x_true_agg = np.mean(x_true, axis=2)
    x_hat_agg = np.mean(x_hat, axis=2)

    posterior_var = self.componentwise_variance_from_uncert(
        posterior_uncert,
        n_sources=N,
        n_orient=n_orient,
        min_var=min_var,
    )
    posterior_var_agg = posterior_var / float(T)

    ci_lower, ci_upper = self.credible_intervals_normal(
        mean=x_hat_agg,
        variance=posterior_var_agg,
        nominal_coverage=level,
    )
    within = (x_true_agg >= ci_lower) & (x_true_agg <= ci_upper)

    count_within = int(np.sum(within))
    total_count = int(within.size)
    # An empty problem has no defined coverage; report NaN (as the other
    # membership diagnostics do) instead of dividing by zero.
    if total_count > 0:
        empirical_coverage = float(count_within / total_count)
    else:
        empirical_coverage = float("nan")

    return {
        "x_true_agg": x_true_agg,
        "x_hat_agg": x_hat_agg,
        "ci_lower": ci_lower,
        "ci_upper": ci_upper,
        "within": within,
        "posterior_var": posterior_var,
        "posterior_var_agg": posterior_var_agg,
        "z_score": float(self._get_z(level)),
        "count_within": count_within,
        "total_count": total_count,
        "empirical_coverage": empirical_coverage,
        "aggregation_axes": "sources, directions",
        "n_sources": int(N),
        "n_orient": int(K),
        "n_times": int(T),
    }
[docs]
def calibration_curve_componentwise_intervals_pointwise_free(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    *,
    n_orient: int,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Calibration curve for marginal (component-wise) intervals (pointwise).

    Summary
    -------
    For every level in ``self.nominal_coverages`` this calls
    ``pointwise_componentwise_interval_membership_free`` and records its
    empirical coverage, hit count and total count. Membership is pooled
    over sources, components and time; this is the
    `interval_type="marginal"` diagnostic for free orientation.

    Shapes
    ------
    - EEG free: `n_orient=3`, `x_true/x_hat` are (N,3,T)
    - Reduced MEG free: `n_orient=2`, `x_true/x_hat` are (N,2,T)
    - `posterior_uncert`: either full (K*N,K*N) or per-source blocks (N,K,K)

    Important
    ---------
    Components are pooled rather than evaluated one by one, because local
    coordinate labels are arbitrary.
    """
    levels = np.asarray(self.nominal_coverages, dtype=float)
    n_levels = len(levels)
    coverage_curve = np.empty(n_levels, dtype=float)
    hit_counts = np.empty(n_levels, dtype=int)
    pool_sizes = np.empty(n_levels, dtype=int)

    for position, level in enumerate(levels):
        membership = self.pointwise_componentwise_interval_membership_free(
            x_true=x_true,
            x_hat=x_hat,
            posterior_uncert=posterior_uncert,
            nominal_coverage=float(level),
            n_orient=n_orient,
            min_var=min_var,
        )
        coverage_curve[position] = membership["empirical_coverage"]
        hit_counts[position] = membership["count_within"]
        pool_sizes[position] = membership["total_count"]

    return {
        "nominal_coverages": levels,
        "empirical_coverages": coverage_curve,
        "ci_counts": hit_counts,
        "total_counts": pool_sizes,
        "n_orient": int(n_orient),
        "mode": "pointwise",
        "interval_type": "marginal",
        "direction_aggregation": "pooled_over_all_components",
    }
[docs]
def calibration_curve_componentwise_intervals_aggregated_free(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    *,
    n_orient: int,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Calibration curve for marginal (component-wise) intervals (aggregated).

    Summary
    -------
    For every level in ``self.nominal_coverages`` this calls
    ``aggregated_componentwise_interval_membership_free`` and records its
    empirical coverage, hit count and total count. That method:
    1) averages `x_true`/`x_hat` over time,
    2) scales marginal variances by 1/T,
    3) builds scalar CIs per component,
    4) pools membership over sources and components.

    The original docstring notes that this is the mode used by
    `workflows/calibration.py`.

    Important
    ---------
    Components are pooled (no direction-wise curves), because local
    coordinate systems are arbitrary.
    """
    levels = np.asarray(self.nominal_coverages, dtype=float)
    n_levels = len(levels)
    coverage_curve = np.empty(n_levels, dtype=float)
    hit_counts = np.empty(n_levels, dtype=int)
    pool_sizes = np.empty(n_levels, dtype=int)

    for position, level in enumerate(levels):
        membership = self.aggregated_componentwise_interval_membership_free(
            x_true=x_true,
            x_hat=x_hat,
            posterior_uncert=posterior_uncert,
            nominal_coverage=float(level),
            n_orient=n_orient,
            min_var=min_var,
        )
        coverage_curve[position] = membership["empirical_coverage"]
        hit_counts[position] = membership["count_within"]
        pool_sizes[position] = membership["total_count"]

    return {
        "nominal_coverages": levels,
        "empirical_coverages": coverage_curve,
        "ci_counts": hit_counts,
        "total_counts": pool_sizes,
        "n_orient": int(n_orient),
        "mode": "aggregated",
        "interval_type": "marginal",
        "direction_aggregation": "pooled_over_all_components",
    }
[docs]
def calibration_curve_componentwise_eeg_free_pointwise(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    *,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Pointwise marginal calibration curve for free-orientation EEG (K=3).

    Thin wrapper that forwards to
    ``calibration_curve_componentwise_intervals_pointwise_free`` with
    ``n_orient=3``.

    Expected shapes:
    - `x_true`, `x_hat`: (N,3,T)
    - `posterior_uncert`: (3N,3N) or per-source blocks (N,3,3)
    """
    forwarded = dict(
        x_true=x_true,
        x_hat=x_hat,
        posterior_uncert=posterior_uncert,
        min_var=min_var,
    )
    curve = self.calibration_curve_componentwise_intervals_pointwise_free(
        n_orient=3, **forwarded
    )
    return curve
[docs]
def calibration_curve_componentwise_eeg_free_aggregated(
    self,
    x_true: np.ndarray,
    x_hat: np.ndarray,
    posterior_uncert: np.ndarray,
    *,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Aggregated marginal calibration curve for free-orientation EEG (K=3).

    Thin wrapper that forwards to
    ``calibration_curve_componentwise_intervals_aggregated_free`` with
    ``n_orient=3``.

    Expected shapes:
    - `x_true`, `x_hat`: (N,3,T)
    - `posterior_uncert`: (3N,3N) or per-source blocks (N,3,3)
    """
    forwarded = dict(
        x_true=x_true,
        x_hat=x_hat,
        posterior_uncert=posterior_uncert,
        min_var=min_var,
    )
    curve = self.calibration_curve_componentwise_intervals_aggregated_free(
        n_orient=3, **forwarded
    )
    return curve
[docs]
def calibration_curve_componentwise_meg_free_pointwise(
    self,
    x_true_2d: np.ndarray,
    x_hat_2d: np.ndarray,
    posterior_uncert_2d: np.ndarray,
    *,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Pointwise marginal calibration curve for reduced free-orientation MEG (K=2).

    Thin wrapper that forwards to
    ``calibration_curve_componentwise_intervals_pointwise_free`` with
    ``n_orient=2``.

    Expected shapes:
    - `x_true_2d`, `x_hat_2d`: (N,2,T)
    - `posterior_uncert_2d`: (2N,2N) or per-source blocks (N,2,2)

    Notes
    -----
    The inputs are used as given in reduced 2D coefficient coordinates;
    this wrapper does no 2D->3D lifting and no 3D->2D projection.
    """
    forwarded = dict(
        x_true=x_true_2d,
        x_hat=x_hat_2d,
        posterior_uncert=posterior_uncert_2d,
        min_var=min_var,
    )
    curve = self.calibration_curve_componentwise_intervals_pointwise_free(
        n_orient=2, **forwarded
    )
    return curve
[docs]
def calibration_curve_componentwise_meg_free_aggregated(
    self,
    x_true_2d: np.ndarray,
    x_hat_2d: np.ndarray,
    posterior_uncert_2d: np.ndarray,
    *,
    min_var: float = 1e-12,
) -> Dict[str, Any]:
    """Aggregated marginal calibration curve for reduced free-orientation MEG (K=2).

    Thin wrapper that forwards to
    ``calibration_curve_componentwise_intervals_aggregated_free`` with
    ``n_orient=2``.

    Expected shapes:
    - `x_true_2d`, `x_hat_2d`: (N,2,T)
    - `posterior_uncert_2d`: (2N,2N) or per-source blocks (N,2,2)

    Notes
    -----
    The inputs are used as given in reduced 2D coefficient coordinates;
    this wrapper does no 2D->3D lifting and no 3D->2D projection.
    """
    forwarded = dict(
        x_true=x_true_2d,
        x_hat=x_hat_2d,
        posterior_uncert=posterior_uncert_2d,
        min_var=min_var,
    )
    curve = self.calibration_curve_componentwise_intervals_aggregated_free(
        n_orient=2, **forwarded
    )
    return curve
# ------------------------------------------------------------------
# MEG free orientation (reduced 2D model)
# ------------------------------------------------------------------
[docs]
def reduce_meg_leadfield_svd(
    self,
    L_free_Mx3N: np.ndarray,
) -> Dict[str, np.ndarray]:
    """
    Fallback-only helper: build a source-wise SVD-reduced MEG leadfield.

    For each source i:
        L_i = U_i S_i V_i^T, with L_i in R^{M x 3}
    keep the first two right singular vectors V_i(:, :2) (sign-stabilized
    with ``_canonicalize_basis_columns``), and define
        L_i_reduced = L_i @ V_i(:, :2) in R^{M x 2}

    Recommended usage
    -----------------
    In the current pipeline, prefer the extractor-provided basis
        Q_basis = lf_free_meg["Q_basis"]
    rather than recomputing a new basis here.

    Parameters
    ----------
    L_free_Mx3N : (M, 3N)
        Free-orientation leadfield with the three columns of each source
        adjacent. At least two sensor rows are required.

    Returns
    -------
    dict with:
        - L_meg_Mx2N: reduced global leadfield, shape (M, 2N)
        - V_tan: source-wise 3D->2D basis, shape (N, 3, 2)
        - singular_values: shape (N, 2)

    Raises
    ------
    ValueError
        If the input is not 2D, its column count is not a multiple of 3,
        or it has fewer than two rows.
    """
    L_free_Mx3N = np.asarray(L_free_Mx3N, dtype=float)
    if L_free_Mx3N.ndim != 2:
        raise ValueError(f"L_free_Mx3N must be (M,3N). Got {L_free_Mx3N.shape}")
    M, threeN = L_free_Mx3N.shape
    if threeN % 3 != 0:
        raise ValueError(f"L_free_Mx3N must be (M,3N). Got {L_free_Mx3N.shape}")
    if M < 2:
        # With a single row the thin SVD returns only one right singular
        # vector, which would otherwise be broadcast into both basis
        # columns without any error.
        raise ValueError(
            "L_free_Mx3N needs at least 2 sensor rows to retain two "
            f"orientations per source. Got {L_free_Mx3N.shape}"
        )

    N = threeN // 3
    L_meg_Mx2N = np.zeros((M, 2 * N), dtype=float)
    V_tan = np.zeros((N, 3, 2), dtype=float)
    singular_values = np.zeros((N, 2), dtype=float)

    for i in range(N):
        Li = L_free_Mx3N[:, 3 * i:3 * i + 3]  # (M,3)
        # Only the singular values and right singular vectors are needed.
        _, s, Vt = np.linalg.svd(Li, full_matrices=False)
        V2 = self._canonicalize_basis_columns(Vt.T[:, :2])  # (3,2)
        L_meg_Mx2N[:, 2 * i:2 * i + 2] = Li @ V2  # (M,2)
        V_tan[i] = V2
        singular_values[i] = s[:2]

    return {
        "L_meg_Mx2N": L_meg_Mx2N,
        "V_tan": V_tan,
        "singular_values": singular_values,
    }
[docs]
def precompute_meg_tangent_bases_svd(self, L_free_Mx3N: np.ndarray) -> np.ndarray:
    """
    Backward-compatible wrapper around ``reduce_meg_leadfield_svd``.

    Runs the SVD reduction on the free leadfield and returns only its
    ``"V_tan"`` entry (the source-wise local 3x2 bases). Prefer passing the
    extractor basis directly whenever possible.
    """
    reduction = self.reduce_meg_leadfield_svd(L_free_Mx3N)
    tangent_bases = reduction["V_tan"]
    return tangent_bases
[docs]
@staticmethod
def project_meg_3d_to_2d(
    x_3d: np.ndarray,
    V_tan: np.ndarray,
) -> np.ndarray:
    """
    Project 3D MEG quantities onto the local reduced 2D coordinates.

    For every source n the result is ``V_tan[n].T @ x_3d[n]``.

    Parameters
    ----------
    x_3d : (N,3,T) or (N,3)
    V_tan : (N,3,2)

    Returns
    -------
    x_2d : (N,2,T) or (N,2)

    Raises
    ------
    ValueError
        If ``V_tan`` is not (N,3,2) or ``x_3d`` has an incompatible shape.
    """
    values = np.asarray(x_3d, dtype=float)
    basis = np.asarray(V_tan, dtype=float)

    if basis.ndim != 3 or basis.shape[1:] != (3, 2):
        raise ValueError(f"V_tan must be (N,3,2); got {basis.shape}")
    n_sources = basis.shape[0]

    if values.ndim not in (2, 3):
        raise ValueError("x_3d must have shape (N,3,T) or (N,3).")

    if values.ndim == 3:
        if values.shape[0] != n_sources or values.shape[1] != 3:
            raise ValueError(f"x_3d must be (N,3,T); got {values.shape}")
        return np.einsum("sck,sct->skt", basis, values)

    if values.shape != (n_sources, 3):
        raise ValueError(f"x_3d must be (N,3); got {values.shape}")
    return np.einsum("sck,sc->sk", basis, values)
[docs]
@staticmethod
def lift_meg_2d_to_3d(
    x_2d: np.ndarray,
    V_tan: np.ndarray,
) -> np.ndarray:
    """
    Lift reduced 2D MEG coordinates back to the retained local 3D basis.

    For every source n the result is ``V_tan[n] @ x_2d[n]``.

    Parameters
    ----------
    x_2d : (N,2,T) or (N,2)
    V_tan : (N,3,2)

    Returns
    -------
    x_3d_lifted : (N,3,T) or (N,3)

    Raises
    ------
    ValueError
        If ``V_tan`` is not (N,3,2) or ``x_2d`` has an incompatible shape.
    """
    coeffs = np.asarray(x_2d, dtype=float)
    basis = np.asarray(V_tan, dtype=float)

    if basis.ndim != 3 or basis.shape[1:] != (3, 2):
        raise ValueError(f"V_tan must be (N,3,2); got {basis.shape}")
    n_sources = basis.shape[0]

    if coeffs.ndim not in (2, 3):
        raise ValueError("x_2d must have shape (N,2,T) or (N,2).")

    if coeffs.ndim == 3:
        if coeffs.shape[0] != n_sources or coeffs.shape[1] != 2:
            raise ValueError(f"x_2d must be (N,2,T); got {coeffs.shape}")
        return np.einsum("sck,skt->sct", basis, coeffs)

    if coeffs.shape != (n_sources, 2):
        raise ValueError(f"x_2d must be (N,2); got {coeffs.shape}")
    return np.einsum("sck,sk->sc", basis, coeffs)
[docs]
def pointwise_ellipse_membership_meg_free(
    self,
    x_true_3d: np.ndarray,  # (N,3,T)
    x_hat_2d: np.ndarray,  # (N,2,T)
    posterior_cov_2d: np.ndarray,  # (2N,2N) or (N,2,2)
    nominal_coverage: float,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Pointwise 2D MEG credible-ellipse membership in reduced coordinates.

    The 3D truth is first projected to reduced coordinates,
    ``x_true_2d = V_tan^T x_true_3d``; the posterior mean and covariance
    are taken as already belonging to the reduced 2D MEG inverse model.
    For each source, the 2x2 covariance block is symmetrized (and
    optionally repaired with ``self._psd_clip_block``), and the quadratic
    form ``q = d^T Sigma^{-1} d`` of the residual ``d`` is compared, at
    every time sample, against ``self._get_chi2_threshold(nominal_coverage,
    df=2)``.

    Recommended usage
    -----------------
    Pass `V_tan=Q_basis` from the leadfield extractor to avoid recomputing
    a potentially sign-flipped basis.

    Parameters
    ----------
    x_true_3d : (N,3,T)
        Ground-truth source coefficients in the local 3D basis.
    x_hat_2d : (N,2,T)
        Posterior mean in reduced coordinates.
    posterior_cov_2d : (2N,2N) or (N,2,2)
        Full posterior covariance (only its 2x2 diagonal blocks are read)
        or the per-source blocks directly.
    nominal_coverage : float
        Coverage level passed to ``self._get_chi2_threshold``.
    V_tan : (N,3,2), optional
        Per-source reduction bases. If omitted, they are obtained from
        ``self.precompute_meg_tangent_bases_svd(L_free_Mx3N)``.
    L_free_Mx3N : np.ndarray, optional
        Free leadfield; only used when ``V_tan`` is not given.
    psd_repair_blocks : bool
        If True, each block goes through ``self._psd_clip_block``.
    block_epsilon : float
        Lower bound applied to block eigenvalues before inversion (also
        forwarded to ``self._psd_clip_block``).

    Returns
    -------
    dict
        ``q_values`` (N,T), ``threshold``, ``within`` (N,T) bool,
        ``projected_true``, ``projected_mean``, ``cov_blocks`` and
        ``projected_cov_blocks`` (the same (N,2,2) array), ``V_tan``,
        ``count_within``, ``total_count``, ``empirical_coverage``,
        ``n_times``.

    Raises
    ------
    ValueError
        On any shape mismatch, or when neither ``V_tan`` nor
        ``L_free_Mx3N`` is supplied.
    """
    truth_3d = np.asarray(x_true_3d, dtype=float)
    mean_2d = np.asarray(x_hat_2d, dtype=float)
    cov_in = np.asarray(posterior_cov_2d, dtype=float)
    chi2_cut = self._get_chi2_threshold(nominal_coverage, df=2)

    # --- input validation -------------------------------------------------
    if not (truth_3d.ndim == 3 and truth_3d.shape[1] == 3):
        raise ValueError(f"x_true_3d must be (N,3,T); got {truth_3d.shape}")
    if not (mean_2d.ndim == 3 and mean_2d.shape[1] == 2):
        raise ValueError(f"x_hat_2d must be (N,2,T); got {mean_2d.shape}")

    n_src, _, n_t = truth_3d.shape
    if (mean_2d.shape[0], mean_2d.shape[2]) != (n_src, n_t):
        raise ValueError("x_hat_2d must have shape (N,2,T) matching x_true_3d.")

    blockwise = cov_in.ndim == 3
    if blockwise and cov_in.shape != (n_src, 2, 2):
        raise ValueError(
            f"posterior_cov_2d block form must be (N,2,2); got {cov_in.shape}"
        )
    if not blockwise and cov_in.shape != (2 * n_src, 2 * n_src):
        raise ValueError(
            f"posterior_cov_2d must be (2N,2N); got {cov_in.shape}"
        )

    # --- reduction basis and projected truth ------------------------------
    if V_tan is None:
        if L_free_Mx3N is None:
            raise ValueError("Provide either V_tan or L_free_Mx3N for MEG 3D->2D projection.")
        V_tan = self.precompute_meg_tangent_bases_svd(L_free_Mx3N)
    V_tan = np.asarray(V_tan, dtype=float)
    if V_tan.shape != (n_src, 3, 2):
        raise ValueError(f"V_tan must be (N,3,2); got {V_tan.shape}")

    truth_2d = self.project_meg_3d_to_2d(truth_3d, V_tan)

    # --- per-source Mahalanobis test --------------------------------------
    q_values = np.zeros((n_src, n_t), dtype=float)
    within = np.zeros((n_src, n_t), dtype=bool)
    cov_blocks = np.zeros((n_src, 2, 2), dtype=float)

    for src in range(n_src):
        if blockwise:
            block = cov_in[src]
        else:
            start = 2 * src
            block = cov_in[start:start + 2, start:start + 2]
        block = (block + block.T) / 2.0
        if psd_repair_blocks:
            block = self._psd_clip_block(block, block_epsilon)
        cov_blocks[src] = block

        # Invert through the eigen-decomposition with floored eigenvalues.
        eigvals, eigvecs = np.linalg.eigh(block)
        eigvals = np.maximum(eigvals, block_epsilon)
        precision = eigvecs @ np.diag(1.0 / eigvals) @ eigvecs.T

        resid = truth_2d[src] - mean_2d[src]  # (2,T)
        q_src = np.einsum("it,ij,jt->t", resid, precision, resid)
        q_values[src] = q_src
        within[src] = q_src <= chi2_cut

    result: Dict[str, Any] = {
        "q_values": q_values,
        "threshold": float(chi2_cut),
        "within": within,
        "projected_true": truth_2d,
        "projected_mean": mean_2d,
        "cov_blocks": cov_blocks,
        "projected_cov_blocks": cov_blocks,
        "V_tan": V_tan,
        "count_within": int(np.sum(within)),
        "total_count": int(within.size),
        "empirical_coverage": float(np.mean(within)),
        "n_times": int(n_t),
    }
    return result
[docs]
def aggregated_ellipse_membership_meg_free(
    self,
    x_true_3d: np.ndarray,  # (N,3,T)
    x_hat_2d: np.ndarray,  # (N,2,T)
    posterior_cov_2d: np.ndarray,  # (2N,2N) or (N,2,2)
    nominal_coverage: float,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Aggregated 2D MEG credible ellipse membership in reduced coordinates.

    The 3D truth is projected with ``self.project_meg_3d_to_2d``; truth and
    posterior mean are then averaged over the time axis. Each source's 2x2
    covariance block is symmetrized, optionally repaired with
    ``self._psd_clip_block``, and divided by ``T`` before the quadratic
    form of the averaged residual is compared against
    ``self._get_chi2_threshold(nominal_coverage, df=2)``.

    Recommended usage
    -----------------
    Pass `V_tan=Q_basis` from the leadfield extractor.

    Parameters
    ----------
    x_true_3d : (N,3,T)
        Ground-truth source coefficients in the local 3D basis.
    x_hat_2d : (N,2,T)
        Posterior mean in reduced coordinates.
    posterior_cov_2d : (2N,2N) or (N,2,2)
        Full posterior covariance (only its 2x2 diagonal blocks are read)
        or the per-source blocks directly.
    nominal_coverage : float
        Coverage level passed to ``self._get_chi2_threshold``.
    V_tan : (N,3,2), optional
        Per-source reduction bases. If omitted, they are obtained from
        ``self.precompute_meg_tangent_bases_svd(L_free_Mx3N)``.
    L_free_Mx3N : np.ndarray, optional
        Free leadfield; only used when ``V_tan`` is not given.
    psd_repair_blocks : bool
        If True, each block goes through ``self._psd_clip_block``.
    block_epsilon : float
        Lower bound applied to the eigenvalues of the aggregated block
        before inversion (also forwarded to ``self._psd_clip_block``).

    Returns
    -------
    dict
        ``x_true_agg_2d`` / ``projected_true`` (N,2), ``x_hat_agg_2d`` /
        ``projected_mean`` (N,2), ``q_values`` (N,), ``threshold``,
        ``within`` (N,) bool, ``cov_blocks`` and ``projected_cov_blocks``
        (the same (N,2,2) array of blocks already divided by T), ``V_tan``,
        ``count_within``, ``total_count``, ``empirical_coverage``,
        ``n_times``.

    Raises
    ------
    ValueError
        On any shape mismatch, when there are no time samples to average,
        or when neither ``V_tan`` nor ``L_free_Mx3N`` is supplied.
    """
    x_true_3d = np.asarray(x_true_3d, dtype=float)
    x_hat_2d = np.asarray(x_hat_2d, dtype=float)
    posterior_cov_2d = np.asarray(posterior_cov_2d, dtype=float)

    if x_true_3d.ndim != 3 or x_true_3d.shape[1] != 3:
        raise ValueError(f"x_true_3d must be (N,3,T); got {x_true_3d.shape}")
    if x_hat_2d.ndim != 3 or x_hat_2d.shape[1] != 2:
        raise ValueError(f"x_hat_2d must be (N,2,T); got {x_hat_2d.shape}")

    N, _, T = x_true_3d.shape
    if x_hat_2d.shape[0] != N or x_hat_2d.shape[2] != T:
        raise ValueError("x_hat_2d must have shape (N,2,T) matching x_true_3d.")
    if T == 0:
        # Averaging over an empty time axis and dividing the covariance by
        # T would only produce NaN/inf values; fail with a clear message.
        raise ValueError(
            "x_true_3d must contain at least one time sample (T >= 1) for aggregation."
        )

    cov_is_blocks = posterior_cov_2d.ndim == 3
    if cov_is_blocks:
        if posterior_cov_2d.shape != (N, 2, 2):
            raise ValueError(
                f"posterior_cov_2d block form must be (N,2,2); got {posterior_cov_2d.shape}"
            )
    else:
        if posterior_cov_2d.shape != (2 * N, 2 * N):
            raise ValueError(
                f"posterior_cov_2d must be (2N,2N); got {posterior_cov_2d.shape}"
            )

    if V_tan is None:
        if L_free_Mx3N is None:
            raise ValueError("Provide either V_tan or L_free_Mx3N for MEG 3D->2D projection.")
        V_tan = self.precompute_meg_tangent_bases_svd(L_free_Mx3N)
    V_tan = np.asarray(V_tan, dtype=float)
    if V_tan.shape != (N, 3, 2):
        raise ValueError(f"V_tan must be (N,3,2); got {V_tan.shape}")

    # Time-averaged truth (projected) and posterior mean, both (N,2).
    x_true_2d = self.project_meg_3d_to_2d(x_true_3d, V_tan)
    x_true_agg = np.mean(x_true_2d, axis=2)
    x_hat_agg = np.mean(x_hat_2d, axis=2)

    thresh = self._get_chi2_threshold(nominal_coverage, df=2)

    q_values = np.zeros(N, dtype=float)
    within = np.zeros(N, dtype=bool)
    cov_blocks = np.zeros((N, 2, 2), dtype=float)

    for i in range(N):
        if cov_is_blocks:
            Sigma2 = posterior_cov_2d[i]
        else:
            Sigma2 = posterior_cov_2d[2 * i:2 * i + 2, 2 * i:2 * i + 2]
        Sigma2 = (Sigma2 + Sigma2.T) / 2.0
        if psd_repair_blocks:
            Sigma2 = self._psd_clip_block(Sigma2, block_epsilon)

        # Covariance used for the time-averaged quantity.
        Sigma2_agg = Sigma2 / float(T)
        cov_blocks[i] = Sigma2_agg

        # Invert through the eigen-decomposition with floored eigenvalues.
        evals, evecs = np.linalg.eigh(Sigma2_agg)
        evals = np.maximum(evals, block_epsilon)
        Sigma2_inv = evecs @ np.diag(1.0 / evals) @ evecs.T

        d2 = x_true_agg[i] - x_hat_agg[i]
        q = float(d2.T @ Sigma2_inv @ d2)
        q_values[i] = q
        within[i] = (q <= thresh)

    return {
        "x_true_agg_2d": x_true_agg,
        "x_hat_agg_2d": x_hat_agg,
        "q_values": q_values,
        "threshold": float(thresh),
        "within": within,
        "projected_true": x_true_agg,
        "projected_mean": x_hat_agg,
        "cov_blocks": cov_blocks,
        "projected_cov_blocks": cov_blocks,
        "V_tan": V_tan,
        "count_within": int(np.sum(within)),
        "total_count": int(within.size),
        "empirical_coverage": float(np.mean(within)),
        "n_times": int(T),
    }
[docs]
def calibration_curve_ellipse_meg_free_pointwise(
    self,
    x_true_3d: np.ndarray,
    x_hat_2d: np.ndarray,
    posterior_cov_2d: np.ndarray,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Calibration curve for pointwise MEG credible ellipses.

    Calls ``self.pointwise_ellipse_membership_meg_free`` once for every
    level in ``self.nominal_coverages`` and collects the empirical coverage
    and the number of (source, time) points inside the ellipse.

    Parameters
    ----------
    x_true_3d, x_hat_2d, posterior_cov_2d
        Forwarded unchanged to ``pointwise_ellipse_membership_meg_free``.
    V_tan : np.ndarray, optional
        Per-source reduction bases, forwarded to the membership test.
    L_free_Mx3N : np.ndarray, optional
        Free leadfield, forwarded to the membership test; it is only
        needed there when ``V_tan`` is not given.
    psd_repair_blocks, block_epsilon
        Forwarded unchanged to the membership test.

    Returns
    -------
    dict
        ``nominal_coverages`` (``self.nominal_coverages`` as stored),
        ``empirical_coverages`` (float array), ``ci_counts`` (int array),
        and ``interval_type`` set to ``"full_cov"``.
    """
    empirical_coverages = []
    counts = []
    basis = V_tan
    for level in self.nominal_coverages:
        out = self.pointwise_ellipse_membership_meg_free(
            x_true_3d=x_true_3d,
            x_hat_2d=x_hat_2d,
            posterior_cov_2d=posterior_cov_2d,
            nominal_coverage=float(level),
            V_tan=basis,
            L_free_Mx3N=L_free_Mx3N,
            psd_repair_blocks=psd_repair_blocks,
            block_epsilon=block_epsilon,
        )
        # The basis does not depend on the coverage level: reuse the one the
        # membership test resolved so a leadfield-derived basis is computed
        # at most once instead of once per level.
        basis = out.get("V_tan", basis)
        empirical_coverages.append(out["empirical_coverage"])
        counts.append(out["count_within"])

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": np.asarray(empirical_coverages, dtype=float),
        "ci_counts": np.asarray(counts, dtype=int),
        "interval_type": "full_cov",
    }
[docs]
def calibration_curve_ellipse_meg_free_aggregated(
    self,
    x_true_3d: np.ndarray,
    x_hat_2d: np.ndarray,
    posterior_cov_2d: np.ndarray,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
) -> Dict[str, Any]:
    """
    Calibration curve for aggregated (time-averaged) MEG credible ellipses.

    Calls ``self.aggregated_ellipse_membership_meg_free`` once for every
    level in ``self.nominal_coverages`` and collects the empirical coverage
    and the number of sources inside the ellipse.

    Parameters
    ----------
    x_true_3d, x_hat_2d, posterior_cov_2d
        Forwarded unchanged to ``aggregated_ellipse_membership_meg_free``.
    V_tan : np.ndarray, optional
        Per-source reduction bases, forwarded to the membership test.
    L_free_Mx3N : np.ndarray, optional
        Free leadfield, forwarded to the membership test; it is only
        needed there when ``V_tan`` is not given.
    psd_repair_blocks, block_epsilon
        Forwarded unchanged to the membership test.

    Returns
    -------
    dict
        ``nominal_coverages`` (``self.nominal_coverages`` as stored),
        ``empirical_coverages`` (float array), ``ci_counts`` (int array),
        and ``interval_type`` set to ``"full_cov"``.
    """
    empirical_coverages = []
    counts = []
    basis = V_tan
    for level in self.nominal_coverages:
        out = self.aggregated_ellipse_membership_meg_free(
            x_true_3d=x_true_3d,
            x_hat_2d=x_hat_2d,
            posterior_cov_2d=posterior_cov_2d,
            nominal_coverage=float(level),
            V_tan=basis,
            L_free_Mx3N=L_free_Mx3N,
            psd_repair_blocks=psd_repair_blocks,
            block_epsilon=block_epsilon,
        )
        # The basis does not depend on the coverage level: reuse the one the
        # membership test resolved so a leadfield-derived basis is computed
        # at most once instead of once per level.
        basis = out.get("V_tan", basis)
        empirical_coverages.append(out["empirical_coverage"])
        counts.append(out["count_within"])

    return {
        "nominal_coverages": self.nominal_coverages,
        "empirical_coverages": np.asarray(empirical_coverages, dtype=float),
        "ci_counts": np.asarray(counts, dtype=int),
        "interval_type": "full_cov",
    }
# ------------------------------------------------------------------
# Visualization: fixed orientation
# ------------------------------------------------------------------
[docs]
def plot_fixed_interval_membership_pointwise(
    self,
    x_true: np.ndarray,  # (N,T)
    x_hat: np.ndarray,  # (N,T)
    posterior_var: np.ndarray,  # (N,)
    src_idx: int,
    nominal_coverage: float = 0.95,
    *,
    times: Optional[np.ndarray] = None,
    figsize: Tuple[float, float] = (9.0, 4.0),
) -> None:
    """
    Plot one fixed-orientation source against its pointwise credible band.

    The per-source posterior variance is repeated over all time samples,
    interval bounds are obtained from ``self.credible_intervals_normal``,
    and the band, the posterior mean and the truth of source ``src_idx``
    are drawn over time. Truth samples outside the band are marked with
    "x". The figure is displayed with ``plt.show()``.

    Parameters
    ----------
    x_true, x_hat : (N,T)
        Ground truth and posterior mean.
    posterior_var : (N,)
        Posterior variance per source (shared by all time samples).
    src_idx : int
        Index of the source to display.
    nominal_coverage : float
        Coverage level of the interval; also shown in the legend.
    times : array-like of length T, optional
        Time axis; defaults to ``np.arange(T)``.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.

    Raises
    ------
    ValueError
        If the input shapes are inconsistent.
    """
    import matplotlib.pyplot as plt

    x_true = np.asarray(x_true, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    posterior_var = np.asarray(posterior_var, dtype=float)

    if x_true.ndim != 2 or x_hat.ndim != 2:
        raise ValueError("x_true and x_hat must have shape (N,T).")
    if x_true.shape != x_hat.shape:
        raise ValueError("x_true and x_hat must have the same shape.")
    if posterior_var.ndim != 1 or posterior_var.shape[0] != x_hat.shape[0]:
        raise ValueError("posterior_var must have shape (N,).")

    _, T = x_hat.shape
    if times is None:
        times = np.arange(T)
    else:
        times = np.asarray(times)
        if times.shape[0] != T:
            raise ValueError("times must have length T.")

    # Same variance at every time sample of a source.
    var_full = np.repeat(posterior_var[:, None], T, axis=1)
    lo, hi = self.credible_intervals_normal(x_hat, var_full, nominal_coverage)
    within = (x_true >= lo) & (x_true <= hi)
    inside_all = bool(np.all(within[src_idx]))

    fig, ax = plt.subplots(figsize=figsize)
    ax.fill_between(
        times,
        lo[src_idx],
        hi[src_idx],
        alpha=0.25,
        # ":g" instead of int(): int(100 * 0.29) is 28 because of binary
        # rounding, and levels such as 0.995 would be truncated to 99.
        label=f"{100 * nominal_coverage:g}% credible interval",
    )
    ax.plot(times, x_hat[src_idx], label="posterior mean")
    ax.plot(times, x_true[src_idx], label="x_true")

    outside = ~within[src_idx]
    if np.any(outside):
        ax.scatter(
            times[outside],
            x_true[src_idx, outside],
            marker="x",
            s=40,
            label="outside interval",
            zorder=5,
        )

    ax.set_title(f"Fixed source={src_idx}, all-times-inside={inside_all}")
    ax.set_xlabel("Time")
    ax.set_ylabel("Source value")
    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    plt.tight_layout()
    plt.show()
[docs]
def plot_fixed_interval_membership_aggregated(
    self,
    diag_fixed: Dict[str, Any],
    src_idx: int,
    *,
    nominal_coverage: float = 0.95,
    figsize: Tuple[float, float] = (7.0, 2.2),
) -> None:
    """
    Plot the aggregated credible interval of one fixed-orientation source.

    Draws the interval as a shaded span on a single axis together with the
    aggregated posterior mean and the aggregated truth, then displays the
    figure with ``plt.show()``.

    Parameters
    ----------
    diag_fixed : dict
        Diagnostics providing per-source entries under the keys
        ``"ci_lower"``, ``"ci_upper"``, ``"x_hat_agg"``, ``"x_true_agg"``
        and ``"within"``.
    src_idx : int
        Index of the source to display.
    nominal_coverage : float
        Only used for the legend label; the interval itself is read from
        ``diag_fixed`` and is not recomputed here.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.
    """
    import matplotlib.pyplot as plt

    lo = float(diag_fixed["ci_lower"][src_idx])
    hi = float(diag_fixed["ci_upper"][src_idx])
    mu = float(diag_fixed["x_hat_agg"][src_idx])
    xt = float(diag_fixed["x_true_agg"][src_idx])
    inside = bool(diag_fixed["within"][src_idx])

    # Pad the x-range so that nothing sits on the frame; the floor keeps
    # the range non-degenerate when all four values coincide.
    vals = np.array([lo, hi, mu, xt], dtype=float)
    pad = 0.15 * max(np.ptp(vals), 1e-6)

    fig, ax = plt.subplots(figsize=figsize)
    # ":g" instead of int(): int(100 * 0.29) is 28 because of binary
    # rounding, and levels such as 0.995 would be truncated to 99.
    ax.axvspan(lo, hi, alpha=0.25, label=f"{100 * nominal_coverage:g}% credible interval")
    ax.scatter(mu, 0.0, s=90, label="aggregated posterior mean", zorder=3)
    ax.scatter(xt, 0.0, s=110, marker="x", label="aggregated x_true", zorder=4)

    ax.set_xlim(np.min(vals) - pad, np.max(vals) + pad)
    ax.set_yticks([])
    ax.set_xlabel("Aggregated source value")
    ax.set_title(f"Fixed source={src_idx}, inside={inside}")
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.grid(False)
    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    plt.tight_layout()
    plt.show()
# ------------------------------------------------------------------
# Visualization: EEG free orientation
# ------------------------------------------------------------------
[docs]
def plot_eeg_ellipsoid_membership_pointwise(
    self,
    x_hat: np.ndarray,  # (N,3,T)
    x_true: np.ndarray,  # (N,3,T)
    posterior_cov: np.ndarray,  # (3N,3N)
    src_idx: int,
    time_idx: int,
    nominal_coverage: float = 0.95,
    *,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
    figsize: Tuple[float, float] = (8.0, 6.0),
    elev: float = 22.0,
    azim: float = -58.0,
) -> None:
    """
    Plot the 3D credible ellipsoid of one EEG source at one time sample.

    The 3x3 diagonal block of ``posterior_cov`` belonging to ``src_idx`` is
    symmetrized (and optionally repaired with ``self._psd_clip_block``).
    Its eigen-decomposition, with eigenvalues floored at ``block_epsilon``,
    is used both for the membership test against
    ``self._get_chi2_threshold(nominal_coverage, df=3)`` and for drawing
    the ellipsoid wireframe and its principal axes. The figure is
    displayed with ``plt.show()``.

    Parameters
    ----------
    x_hat, x_true : (N,3,T)
        Posterior mean and ground truth; array-likes are accepted.
    posterior_cov : (3N,3N)
        Full posterior covariance; only one 3x3 diagonal block is read.
    src_idx, time_idx : int
        Source and time sample to display.
    nominal_coverage : float
        Coverage level of the ellipsoid.
    psd_repair_blocks : bool
        If True, the block goes through ``self._psd_clip_block``.
    block_epsilon : float
        Lower bound applied to the block eigenvalues.
    figsize : tuple of float
        Figure size passed to ``plt.figure``.
    elev, azim : float
        View angles passed to ``ax.view_init``.

    Raises
    ------
    ValueError
        If shapes are inconsistent or an index is out of range.
    """
    import matplotlib.pyplot as plt

    # Coerce like the sibling methods do, so lists/tuples do not fail on
    # the attribute-based shape checks below.
    x_hat = np.asarray(x_hat, dtype=float)
    x_true = np.asarray(x_true, dtype=float)
    posterior_cov = np.asarray(posterior_cov, dtype=float)

    if x_true.ndim != 3 or x_true.shape[1] != 3:
        raise ValueError(f"x_true must be (N,3,T); got {x_true.shape}")
    if x_hat.shape != x_true.shape:
        raise ValueError("x_hat must match x_true shape.")

    N, _, T = x_true.shape
    if posterior_cov.shape != (3 * N, 3 * N):
        raise ValueError(f"posterior_cov must be (3N,3N); got {posterior_cov.shape}")
    if not (0 <= src_idx < N):
        raise ValueError("src_idx out of range.")
    if not (0 <= time_idx < T):
        raise ValueError("time_idx out of range.")

    threshold = self._get_chi2_threshold(nominal_coverage, df=3)
    center = x_hat[src_idx, :, time_idx]
    truth = x_true[src_idx, :, time_idx]

    Sigma3 = posterior_cov[3 * src_idx:3 * src_idx + 3, 3 * src_idx:3 * src_idx + 3]
    Sigma3 = (Sigma3 + Sigma3.T) / 2.0
    if psd_repair_blocks:
        Sigma3 = self._psd_clip_block(Sigma3, block_epsilon)

    evals, evecs = np.linalg.eigh(Sigma3)
    evals = np.maximum(evals, block_epsilon)
    Sigma3_inv = evecs @ np.diag(1.0 / evals) @ evecs.T

    d = truth - center
    q_val = float(d.T @ Sigma3_inv @ d)
    inside = bool(q_val <= threshold)

    # Semi-axis lengths of the level set d^T Sigma^{-1} d = threshold.
    radii = np.sqrt(threshold * evals)

    # Unit sphere on a (60, 30) angular grid, then mapped to the ellipsoid.
    u = np.linspace(0.0, 2.0 * np.pi, 60)
    v = np.linspace(0.0, np.pi, 30)
    xs = np.outer(np.cos(u), np.sin(v))
    ys = np.outer(np.sin(u), np.sin(v))
    zs = np.outer(np.ones_like(u), np.cos(v))
    sphere = np.stack([xs, ys, zs], axis=0).reshape(3, -1)
    ell = (evecs @ np.diag(radii) @ sphere).reshape(3, xs.shape[0], xs.shape[1])

    X = center[0] + ell[0]
    Y = center[1] + ell[1]
    Z = center[2] + ell[2]

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_wireframe(X, Y, Z, rstride=2, cstride=2, alpha=0.35, linewidth=0.9)

    # Principal axes of the ellipsoid.
    for k in range(3):
        axis_vec = evecs[:, k] * radii[k]
        p1 = center - axis_vec
        p2 = center + axis_vec
        ax.plot(
            [p1[0], p2[0]],
            [p1[1], p2[1]],
            [p1[2], p2[2]],
            linewidth=1.2,
            alpha=0.9,
        )

    ax.scatter(center[0], center[1], center[2], s=80, label="posterior mean", zorder=5)
    ax.scatter(truth[0], truth[1], truth[2], s=100, marker="x", label="x_true", zorder=6)
    # Dashed segment from the posterior mean to the truth.
    ax.plot(
        [center[0], truth[0]],
        [center[1], truth[1]],
        [center[2], truth[2]],
        linestyle="--",
        linewidth=1.0,
        alpha=0.8,
    )

    pts = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
    pts = np.vstack([pts, center[None, :], truth[None, :]])
    self._set_equal_3d_limits_centered(ax, center=center, xyz_points=pts, margin=0.08)

    ax.view_init(elev=elev, azim=azim)
    ax.set_title(f"EEG source={src_idx}, time={time_idx}, inside={inside}")
    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0), frameon=False)
    plt.tight_layout()
    plt.show()
[docs]
def plot_eeg_ellipsoid_membership_aggregated(
    self,
    diag_eeg: Dict[str, Any],
    src_idx: int,
    *,
    figsize: Tuple[float, float] = (8.0, 6.0),
    elev: float = 22.0,
    azim: float = -58.0,
) -> None:
    """
    Plot the aggregated 3D credible ellipsoid of one EEG source.

    Everything is read from ``diag_eeg``: the ellipsoid is centred on
    ``diag_eeg["x_hat_agg"][src_idx]``, shaped by
    ``diag_eeg["cov_blocks"][src_idx]`` and scaled by
    ``diag_eeg["threshold"]``. The membership flag shown in the title comes
    from ``diag_eeg["within"][src_idx]`` and is not recomputed. The figure
    is displayed with ``plt.show()``.

    Parameters
    ----------
    diag_eeg : dict
        Diagnostics with the keys ``"x_hat_agg"``, ``"x_true_agg"``,
        ``"cov_blocks"``, ``"threshold"`` and ``"within"``.
    src_idx : int
        Index of the source to display.
    figsize : tuple of float
        Figure size passed to ``plt.figure``.
    elev, azim : float
        View angles passed to ``ax.view_init``.
    """
    import matplotlib.pyplot as plt

    mean_vec = diag_eeg["x_hat_agg"][src_idx]
    true_vec = diag_eeg["x_true_agg"][src_idx]
    cov3 = diag_eeg["cov_blocks"][src_idx]
    chi2_cut = float(diag_eeg["threshold"])
    is_inside = bool(diag_eeg["within"][src_idx])

    # Principal directions and semi-axis lengths of the ellipsoid.
    eigvals, eigvecs = np.linalg.eigh((cov3 + cov3.T) / 2.0)
    eigvals = np.maximum(eigvals, 1e-12)
    semi_axes = np.sqrt(chi2_cut * eigvals)

    # Unit sphere on a (60, 30) angular grid, then mapped to the ellipsoid.
    azimuth = np.linspace(0.0, 2.0 * np.pi, 60)
    polar = np.linspace(0.0, np.pi, 30)
    sphere_x = np.outer(np.cos(azimuth), np.sin(polar))
    sphere_y = np.outer(np.sin(azimuth), np.sin(polar))
    sphere_z = np.outer(np.ones_like(azimuth), np.cos(polar))
    n_az, n_pol = sphere_x.shape
    unit_points = np.stack([sphere_x, sphere_y, sphere_z], axis=0).reshape(3, -1)
    surface = (eigvecs @ np.diag(semi_axes) @ unit_points).reshape(3, n_az, n_pol)

    X = mean_vec[0] + surface[0]
    Y = mean_vec[1] + surface[1]
    Z = mean_vec[2] + surface[2]

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_wireframe(X, Y, Z, rstride=2, cstride=2, alpha=0.35, linewidth=0.9)

    # One segment per principal axis, through the centre.
    for direction, half_length in zip(eigvecs.T, semi_axes):
        offset = direction * half_length
        start = mean_vec - offset
        stop = mean_vec + offset
        ax.plot(
            [start[0], stop[0]],
            [start[1], stop[1]],
            [start[2], stop[2]],
            linewidth=1.2,
            alpha=0.9,
        )

    ax.scatter(mean_vec[0], mean_vec[1], mean_vec[2], s=80, label="aggregated posterior mean", zorder=5)
    ax.scatter(true_vec[0], true_vec[1], true_vec[2], s=100, marker="x", label="aggregated x_true", zorder=6)
    # Dashed segment from the aggregated mean to the aggregated truth.
    ax.plot(
        [mean_vec[0], true_vec[0]],
        [mean_vec[1], true_vec[1]],
        [mean_vec[2], true_vec[2]],
        linestyle="--",
        linewidth=1.0,
        alpha=0.8,
    )

    # All drawn points, used to set equal axis limits around the centre.
    cloud = np.vstack(
        [
            np.column_stack([X.ravel(), Y.ravel(), Z.ravel()]),
            mean_vec[None, :],
            true_vec[None, :],
        ]
    )
    self._set_equal_3d_limits_centered(ax, center=mean_vec, xyz_points=cloud, margin=0.08)

    ax.view_init(elev=elev, azim=azim)
    ax.set_title(f"EEG source={src_idx}, inside={is_inside}")
    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0), frameon=False)
    plt.tight_layout()
    plt.show()
# ------------------------------------------------------------------
# Visualization: MEG free orientation (reduced 2D posterior)
# ------------------------------------------------------------------
[docs]
def plot_meg_ellipse_membership_pointwise(
    self,
    x_true_3d: np.ndarray,  # (N,3,T)
    x_hat_2d: np.ndarray,  # (N,2,T)
    posterior_cov_2d: np.ndarray,  # (2N,2N)
    src_idx: int,
    time_idx: int,
    nominal_coverage: float = 0.95,
    *,
    V_tan: Optional[np.ndarray] = None,
    L_free_Mx3N: Optional[np.ndarray] = None,
    psd_repair_blocks: bool = False,
    block_epsilon: float = 1e-12,
    figsize: Tuple[float, float] = (6.5, 6.0),
) -> None:
    """
    Plot the 2D credible ellipse of one MEG source at one time sample.

    Runs ``self.pointwise_ellipse_membership_meg_free`` with the given
    arguments, then draws, in reduced coordinates, the credible ellipse of
    source ``src_idx``, its principal axes, the posterior mean and the
    projected truth at ``time_idx``. The membership flag in the title is
    the one returned by the membership test. The figure is displayed with
    ``plt.show()``.

    Parameters
    ----------
    x_true_3d, x_hat_2d, posterior_cov_2d
        Forwarded unchanged to ``pointwise_ellipse_membership_meg_free``.
    src_idx, time_idx : int
        Source and time sample to display.
    nominal_coverage : float
        Coverage level of the ellipse; also shown in the legend.
    V_tan, L_free_Mx3N, psd_repair_blocks
        Forwarded unchanged to the membership test.
    block_epsilon : float
        Forwarded to the membership test and also used as the eigenvalue
        floor when drawing, so the drawn ellipse matches the region the
        ``inside`` flag refers to.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.
    """
    import matplotlib.pyplot as plt

    diag = self.pointwise_ellipse_membership_meg_free(
        x_true_3d=x_true_3d,
        x_hat_2d=x_hat_2d,
        posterior_cov_2d=posterior_cov_2d,
        nominal_coverage=nominal_coverage,
        V_tan=V_tan,
        L_free_Mx3N=L_free_Mx3N,
        psd_repair_blocks=psd_repair_blocks,
        block_epsilon=block_epsilon,
    )

    center2 = diag["projected_mean"][src_idx, :, time_idx]
    truth2 = diag["projected_true"][src_idx, :, time_idx]
    Sigma2 = diag["cov_blocks"][src_idx]
    threshold = float(diag["threshold"])
    inside = bool(diag["within"][src_idx, time_idx])

    Sigma2 = (Sigma2 + Sigma2.T) / 2.0
    evals, evecs = np.linalg.eigh(Sigma2)
    # Same floor as the membership test (previously a hard-coded 1e-12,
    # which could make the drawn ellipse disagree with `inside`).
    evals = np.maximum(evals, block_epsilon)
    radii = np.sqrt(threshold * evals)

    # Unit circle mapped to the ellipse boundary.
    theta = np.linspace(0.0, 2.0 * np.pi, 361)
    circle = np.vstack([np.cos(theta), np.sin(theta)])
    ellipse = evecs @ np.diag(radii) @ circle
    ex = center2[0] + ellipse[0]
    ey = center2[1] + ellipse[1]

    fig, ax = plt.subplots(figsize=figsize)
    # ":g" instead of int(): int(100 * 0.29) is 28 because of binary
    # rounding, and levels such as 0.995 would be truncated to 99.
    ax.plot(ex, ey, linewidth=1.4, label=f"{100 * nominal_coverage:g}% credible ellipse")

    # Principal axes of the ellipse.
    for k in range(2):
        axis_vec = evecs[:, k] * radii[k]
        p1 = center2 - axis_vec
        p2 = center2 + axis_vec
        ax.plot(
            [p1[0], p2[0]],
            [p1[1], p2[1]],
            linewidth=1.2,
            alpha=0.9,
            label=f"Axis {k + 1}",
        )

    ax.scatter(center2[0], center2[1], s=90, label="posterior mean", zorder=3)
    ax.scatter(truth2[0], truth2[1], s=110, marker="x", label="projected x_true", zorder=4)
    # Dashed segment from the posterior mean to the projected truth.
    ax.plot(
        [center2[0], truth2[0]],
        [center2[1], truth2[1]],
        linestyle="--",
        linewidth=1.0,
        alpha=0.8,
    )

    # Axis limits enclosing the ellipse and both points, with padding.
    x_all = np.concatenate([ex, [truth2[0], center2[0]]])
    y_all = np.concatenate([ey, [truth2[1], center2[1]]])
    pad_x = 0.15 * max(np.ptp(x_all), 1e-6)
    pad_y = 0.15 * max(np.ptp(y_all), 1e-6)
    ax.set_xlim(np.min(x_all) - pad_x, np.max(x_all) + pad_x)
    ax.set_ylim(np.min(y_all) - pad_y, np.max(y_all) + pad_y)
    ax.set_aspect("equal", adjustable="box")

    ax.set_xlabel("Reduced MEG coordinate 1")
    ax.set_ylabel("Reduced MEG coordinate 2")
    ax.set_title(f"MEG source={src_idx}, time={time_idx}, inside={inside}")
    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    plt.tight_layout()
    plt.show()
[docs]
def plot_meg_ellipse_membership_aggregated(
    self,
    diag_meg: Dict[str, Any],
    src_idx: int,
    *,
    nominal_coverage: float = 0.95,
    figsize: Tuple[float, float] = (6.5, 6.0),
) -> None:
    """
    Plot the aggregated 2D credible ellipse of one MEG source.

    Everything geometric is read from ``diag_meg``: the ellipse is centred
    on ``diag_meg["projected_mean"][src_idx]``, shaped by
    ``diag_meg["cov_blocks"][src_idx]`` and scaled by
    ``diag_meg["threshold"]``. The membership flag in the title comes from
    ``diag_meg["within"][src_idx]`` and is not recomputed. The figure is
    displayed with ``plt.show()``.

    Parameters
    ----------
    diag_meg : dict
        Diagnostics with the keys ``"projected_mean"``,
        ``"projected_true"``, ``"cov_blocks"``, ``"threshold"`` and
        ``"within"``.
    src_idx : int
        Index of the source to display.
    nominal_coverage : float
        Only used for the legend label; the ellipse size comes from
        ``diag_meg["threshold"]``, so pass the level the diagnostics were
        computed with.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.
    """
    import matplotlib.pyplot as plt

    center2 = diag_meg["projected_mean"][src_idx]
    truth2 = diag_meg["projected_true"][src_idx]
    Sigma2 = diag_meg["cov_blocks"][src_idx]
    threshold = float(diag_meg["threshold"])
    inside = bool(diag_meg["within"][src_idx])

    Sigma2 = (Sigma2 + Sigma2.T) / 2.0
    evals, evecs = np.linalg.eigh(Sigma2)
    evals = np.maximum(evals, 1e-12)
    radii = np.sqrt(threshold * evals)

    # Unit circle mapped to the ellipse boundary.
    theta = np.linspace(0.0, 2.0 * np.pi, 361)
    circle = np.vstack([np.cos(theta), np.sin(theta)])
    ellipse = evecs @ np.diag(radii) @ circle
    ex = center2[0] + ellipse[0]
    ey = center2[1] + ellipse[1]

    fig, ax = plt.subplots(figsize=figsize)
    # ":g" instead of int(): int(100 * 0.29) is 28 because of binary
    # rounding, and levels such as 0.995 would be truncated to 99.
    ax.plot(ex, ey, linewidth=1.4, label=f"{100 * nominal_coverage:g}% credible ellipse")

    # Principal axes of the ellipse.
    for k in range(2):
        axis_vec = evecs[:, k] * radii[k]
        p1 = center2 - axis_vec
        p2 = center2 + axis_vec
        ax.plot(
            [p1[0], p2[0]],
            [p1[1], p2[1]],
            linewidth=1.2,
            alpha=0.9,
            label=f"Axis {k + 1}",
        )

    ax.scatter(center2[0], center2[1], s=90, label="aggregated posterior mean", zorder=3)
    ax.scatter(truth2[0], truth2[1], s=110, marker="x", label="aggregated projected x_true", zorder=4)
    # Dashed segment from the aggregated mean to the aggregated truth.
    ax.plot(
        [center2[0], truth2[0]],
        [center2[1], truth2[1]],
        linestyle="--",
        linewidth=1.0,
        alpha=0.8,
    )

    # Axis limits enclosing the ellipse and both points, with padding.
    x_all = np.concatenate([ex, [truth2[0], center2[0]]])
    y_all = np.concatenate([ey, [truth2[1], center2[1]]])
    pad_x = 0.15 * max(np.ptp(x_all), 1e-6)
    pad_y = 0.15 * max(np.ptp(y_all), 1e-6)
    ax.set_xlim(np.min(x_all) - pad_x, np.max(x_all) + pad_x)
    ax.set_ylim(np.min(y_all) - pad_y, np.max(y_all) + pad_y)
    ax.set_aspect("equal", adjustable="box")

    ax.set_xlabel("Reduced MEG coordinate 1")
    ax.set_ylabel("Reduced MEG coordinate 2")
    ax.set_title(f"MEG source={src_idx}, inside={inside}")
    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    plt.tight_layout()
    plt.show()
# # =============================================================================
# # Example usage
# #
# # This section covers BOTH uncertainty modes for the current MEG setup:
# #
# # 1) pointwise mode
# # 2) aggregated mode
# #
# # Visualizations are called separately for the two modes.
# #
# # Active:
# # - fixed MEG (BMN)
# # - fixed MEG (BMN_joint)
# # - free MEG reduced rank-2 (BMN)
# # - free MEG reduced rank-2 (BMN_joint)
# #
# # Commented:
# # - fixed EEG
# # - free EEG
# # =============================================================================
# # ------------------------------------------------------------------
# # 1) FIXED MEG -- BMN
# # ------------------------------------------------------------------
# results_unc_fixed_meg = run_uncertainty_example_fixed(
# x_true_fixed=x_fixed_meg,
# out_fixed=out_fixed_bmn,
# nominal_coverage=0.95,
# src_idx=0,
# do_plots=False, # plots called separately below
# )
# print("\n================ FIXED MEG / BMN / POINTWISE ================")
# print("Empirical coverage :", results_unc_fixed_meg["diag_fixed_point"]["empirical_coverage"])
# print("Count within :", results_unc_fixed_meg["diag_fixed_point"]["count_within"])
# print("Total count :", results_unc_fixed_meg["diag_fixed_point"]["total_count"])
# print("\n=============== FIXED MEG / BMN / AGGREGATED ===============")
# print("Empirical coverage :", results_unc_fixed_meg["diag_fixed_agg"]["empirical_coverage"])
# print("Count within :", results_unc_fixed_meg["diag_fixed_agg"]["count_within"])
# print("Total count :", results_unc_fixed_meg["diag_fixed_agg"]["total_count"])
# # Pointwise visualization
# results_unc_fixed_meg["ue"].plot_fixed_interval_membership_pointwise(
# x_true=x_fixed_meg,
# x_hat=out_fixed_bmn["posterior_mean"],
# posterior_var=results_unc_fixed_meg["posterior_var_fixed"],
# src_idx=0,
# nominal_coverage=0.95,
# )
# # Aggregated visualization
# results_unc_fixed_meg["ue"].plot_fixed_interval_membership_aggregated(
# diag_fixed=results_unc_fixed_meg["diag_fixed_agg"],
# src_idx=0,
# nominal_coverage=0.95,
# )
# # ------------------------------------------------------------------
# # 2) FIXED MEG -- BMN_joint
# # ------------------------------------------------------------------
# results_unc_fixed_meg_joint = run_uncertainty_example_fixed(
# x_true_fixed=x_fixed_meg,
# out_fixed=out_fixed_joint,
# nominal_coverage=0.95,
# src_idx=0,
# do_plots=False, # plots called separately below
# )
# print("\n============= FIXED MEG / BMN_joint / POINTWISE =============")
# print("Empirical coverage :", results_unc_fixed_meg_joint["diag_fixed_point"]["empirical_coverage"])
# print("Count within :", results_unc_fixed_meg_joint["diag_fixed_point"]["count_within"])
# print("Total count :", results_unc_fixed_meg_joint["diag_fixed_point"]["total_count"])
# print("\n============ FIXED MEG / BMN_joint / AGGREGATED ============")
# print("Empirical coverage :", results_unc_fixed_meg_joint["diag_fixed_agg"]["empirical_coverage"])
# print("Count within :", results_unc_fixed_meg_joint["diag_fixed_agg"]["count_within"])
# print("Total count :", results_unc_fixed_meg_joint["diag_fixed_agg"]["total_count"])
# # Pointwise visualization
# results_unc_fixed_meg_joint["ue"].plot_fixed_interval_membership_pointwise(
# x_true=x_fixed_meg,
# x_hat=out_fixed_joint["posterior_mean"],
# posterior_var=results_unc_fixed_meg_joint["posterior_var_fixed"],
# src_idx=0,
# nominal_coverage=0.95,
# )
# # Aggregated visualization
# results_unc_fixed_meg_joint["ue"].plot_fixed_interval_membership_aggregated(
# diag_fixed=results_unc_fixed_meg_joint["diag_fixed_agg"],
# src_idx=0,
# nominal_coverage=0.95,
# )
# # ------------------------------------------------------------------
# # 3) FREE MEG REDUCED RANK-2 -- BMN
# # ------------------------------------------------------------------
# results_unc_meg = run_uncertainty_example_meg(
# a_true_meg_2d=a_free_meg,
# Q_basis=lf_free_meg["Q_basis"],
# out_meg=out_free_meg_bmn,
# nominal_coverage=0.95,
# src_idx=0,
# time_idx=0,
# psd_repair_blocks=True,
# block_epsilon=1e-12,
# do_plots=False, # plots called separately below
# )
# print("\n============= FREE MEG / BMN / POINTWISE ===================")
# print("Empirical coverage :", results_unc_meg["diag_meg_point"]["empirical_coverage"])
# print("Count within :", results_unc_meg["diag_meg_point"]["count_within"])
# print("Total count :", results_unc_meg["diag_meg_point"]["total_count"])
# print("\n============ FREE MEG / BMN / AGGREGATED ==================")
# print("Empirical coverage :", results_unc_meg["diag_meg_agg"]["empirical_coverage"])
# print("Count within :", results_unc_meg["diag_meg_agg"]["count_within"])
# print("Total count :", results_unc_meg["diag_meg_agg"]["total_count"])
# # Pointwise visualization
# results_unc_meg["ue"].plot_meg_ellipse_membership_pointwise(
# x_true_3d=results_unc_meg["x_true_meg_3d"],
# x_hat_2d=results_unc_meg["x_hat_meg_2d"],
# posterior_cov_2d=out_free_meg_bmn["posterior_cov"],
# src_idx=0,
# time_idx=0,
# nominal_coverage=0.95,
# V_tan=results_unc_meg["V_tan"],
# psd_repair_blocks=True,
# block_epsilon=1e-12,
# )
# # Aggregated visualization
# results_unc_meg["ue"].plot_meg_ellipse_membership_aggregated(
# diag_meg=results_unc_meg["diag_meg_agg"],
# src_idx=0,
# nominal_coverage=0.95,
# )
# # ------------------------------------------------------------------
# # 4) FREE MEG REDUCED RANK-2 -- BMN_joint
# # ------------------------------------------------------------------
# results_unc_meg_joint = run_uncertainty_example_meg(
# a_true_meg_2d=a_free_meg,
# Q_basis=lf_free_meg["Q_basis"],
# out_meg=out_free_meg_joint,
# nominal_coverage=0.95,
# src_idx=0,
# time_idx=0,
# psd_repair_blocks=True,
# block_epsilon=1e-12,
# do_plots=False, # plots called separately below
# )
# print("\n=========== FREE MEG / BMN_joint / POINTWISE ==============")
# print("Empirical coverage :", results_unc_meg_joint["diag_meg_point"]["empirical_coverage"])
# print("Count within :", results_unc_meg_joint["diag_meg_point"]["count_within"])
# print("Total count :", results_unc_meg_joint["diag_meg_point"]["total_count"])
# print("\n========== FREE MEG / BMN_joint / AGGREGATED =============")
# print("Empirical coverage :", results_unc_meg_joint["diag_meg_agg"]["empirical_coverage"])
# print("Count within :", results_unc_meg_joint["diag_meg_agg"]["count_within"])
# print("Total count :", results_unc_meg_joint["diag_meg_agg"]["total_count"])
# # Pointwise visualization
# results_unc_meg_joint["ue"].plot_meg_ellipse_membership_pointwise(
# x_true_3d=results_unc_meg_joint["x_true_meg_3d"],
# x_hat_2d=results_unc_meg_joint["x_hat_meg_2d"],
# posterior_cov_2d=out_free_meg_joint["posterior_cov"],
# src_idx=0,
# time_idx=0,
# nominal_coverage=0.95,
# V_tan=results_unc_meg_joint["V_tan"],
# psd_repair_blocks=True,
# block_epsilon=1e-12,
# )
# # Aggregated visualization
# results_unc_meg_joint["ue"].plot_meg_ellipse_membership_aggregated(
# diag_meg=results_unc_meg_joint["diag_meg_agg"],
# src_idx=0,
# nominal_coverage=0.95,
# )