import logging
import numpy as np
from numpy.linalg import inv
from matplotlib import cm
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from scipy import linalg
from scipy.spatial.distance import pdist, squareform
from scipy.stats import chi2, norm
from scipy.linalg import sqrtm
from scipy.sparse import coo_matrix, csr_matrix, diags, eye, kron, issparse, block_diag, identity
from functools import partial
import mne
from mne.utils import sqrtm_sym, eigh
from mne.io.constants import FIFF
from typing import Optional, Dict, Any, Tuple
from calibrain.utils import get_data_path
# ===================
# GAMMA-MAP Functions
# ===================
def _validate_gamma_map_inputs(
    L: np.ndarray,
    y: np.ndarray,
    n_orient: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Coerce the Gamma-MAP inputs to float64 arrays and check their shapes.

    Parameters
    ----------
    L : array-like
        Leadfield; must be 2D after conversion.
    y : array-like
        Sensor data; a 1D vector is promoted to a single-column matrix.
    n_orient : int
        Orientations per source; must be 1, 2 or 3 and divide the number
        of leadfield columns.

    Returns
    -------
    (L, y) : tuple of ndarray
        The float64 leadfield and the 2D float64 data.

    Raises
    ------
    ValueError
        If any of the shape / orientation constraints is violated.
    """
    leadfield = np.asarray(L, dtype=np.float64)
    data = np.asarray(y, dtype=np.float64)

    # A single time sample is accepted as a flat vector.
    if data.ndim == 1:
        data = data[:, np.newaxis]

    if leadfield.ndim != 2:
        raise ValueError("L must be 2D.")
    if data.ndim != 2:
        raise ValueError("y must have shape (M,T) or (M,).")

    n_rows_leadfield = leadfield.shape[0]
    n_rows_data = data.shape[0]
    if n_rows_leadfield != n_rows_data:
        raise ValueError("L and y must have the same number of sensor rows.")

    if n_orient not in (1, 2, 3):
        raise ValueError("n_orient must be 1, 2, or 3.")

    leftover_columns = leadfield.shape[1] % n_orient
    if leftover_columns != 0:
        raise ValueError(
            f"For n_orient={n_orient}, L must have k*N columns with k={n_orient}."
        )

    return leadfield, data
def _prepare_init_gamma(
    n_coeffs: int,
    n_orient: int,
    init_gamma=None,
) -> np.ndarray:
    """Build the initial per-coefficient variance vector for Gamma-MAP.

    Parameters
    ----------
    n_coeffs : int
        Total number of coefficients (leadfield columns).
    n_orient : int
        Orientations per source; coefficients are grouped in consecutive
        runs of this length.
    init_gamma : None | scalar | 2-tuple | array-like
        ``None`` gives all ones, a scalar is broadcast, a 2-tuple is read
        as ``(start, stop)`` of a linear ramp over the coefficients, and
        anything else is flattened and must hold one value per coefficient
        or one value per group.

    Returns
    -------
    ndarray, shape (n_coeffs,)
        Non-negative float64 variances; for ``n_orient > 1`` every group
        shares the mean of its members.

    Raises
    ------
    ValueError
        If an array-like ``init_gamma`` has neither ``n_coeffs`` nor
        ``n_coeffs // n_orient`` entries.
    """
    n_groups = n_coeffs // n_orient
    scalar_kinds = (float, np.floating, int, np.integer)

    if init_gamma is None:
        start = np.ones(n_coeffs, dtype=np.float64)
    elif isinstance(init_gamma, scalar_kinds):
        start = np.full(n_coeffs, float(init_gamma), dtype=np.float64)
    elif isinstance(init_gamma, tuple) and len(init_gamma) == 2:
        low, high = init_gamma[0], init_gamma[1]
        start = np.linspace(low, high, num=n_coeffs).astype(np.float64)
    else:
        start = np.asarray(init_gamma, dtype=np.float64).ravel()
        if start.size == n_groups:
            # One value per source: spread it over the source's orientations.
            start = np.repeat(start, n_orient)
        elif start.size != n_coeffs:
            raise ValueError(
                f"init_gamma must have length {n_coeffs} or {n_groups}; got {start.size}."
            )

    # Variances cannot be negative.
    start = np.maximum(start, 0.0)

    if n_orient > 1:
        # Tie the orientations of each source to a common (mean) variance.
        per_group = start.reshape(n_groups, n_orient).mean(axis=1)
        start = np.repeat(per_group, n_orient)
    return start
def _gamma_map_opt(
    M: np.ndarray,
    G: np.ndarray,
    sigma_squared: float,
    *,
    maxit: int = 300,
    tol: float = 1e-6,
    update_mode: int = 2,
    group_size: int = 1,
    init_gamma: Optional[np.ndarray] = None,
    verbose: bool = False,
    logger: Optional[logging.Logger] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
    """Run the iterative Gamma-MAP hyperparameter updates.

    Parameters
    ----------
    M : ndarray, shape (n_sensors, n_times)
        Sensor data (copied, never modified in place).
    G : ndarray, shape (n_sensors, n_coeffs)
        Gain matrix (copied, never modified in place).
    sigma_squared : float
        Noise variance added to the diagonal of the model covariance.
    maxit : int
        Maximum number of iterations.
    tol : float
        Stop once the relative L1 change of the full gamma vector drops
        below this value.
    update_mode : int
        Selects one of the three update rules (1, 2 or 3).
    group_size : int
        Number of consecutive coefficients sharing one gamma.
    init_gamma : ndarray | None
        Starting gammas, shape ``(n_coeffs,)``; all ones when ``None``.
    verbose : bool
        Emit a debug record whenever the active-set size changes or the
        loop terminates.
    logger : logging.Logger | None
        Logger for the debug records; the module logger when ``None``.

    Returns
    -------
    x_active : ndarray, shape (n_active, n_times)
        Estimates for the coefficients still active at the last update.
    active_indices : ndarray of int
        Column indices (into ``G``) of those coefficients.
    posterior_cov_active : ndarray, shape (n_active, n_active)
        Symmetrised covariance of the active coefficients.
    gammas_full : ndarray, shape (n_coeffs,)
        Gammas scattered back to full size (zero where pruned).
    it_used : int
        Number of loop iterations started.

    Raises
    ------
    ValueError
        On a non-divisible ``group_size``, a wrongly shaped ``init_gamma``,
        all-zero ``M`` or ``G``, or (inside the first iteration) an unknown
        ``update_mode``.
    """
    log = logger if logger is not None else logging.getLogger(__name__)

    gain = np.asarray(G, dtype=np.float64).copy()
    data = np.asarray(M, dtype=np.float64).copy()
    n_coeffs_total = gain.shape[1]
    n_sensors, n_times = data.shape
    eps = np.finfo(float).eps

    if n_coeffs_total % group_size != 0:
        raise ValueError("Number of coefficients must be divisible by group_size.")

    if init_gamma is None:
        gamma = np.ones(n_coeffs_total, dtype=np.float64)
    else:
        gamma = np.asarray(init_gamma, dtype=np.float64).copy()
        if gamma.shape != (n_coeffs_total,):
            raise ValueError(
                f"init_gamma must have shape ({n_coeffs_total},), got {gamma.shape}"
            )

    # Rescale data and gain so the numbers stay in a sane range; the
    # scaling is undone on the outputs at the very end.
    data_scale = np.linalg.norm(data @ data.T, ord="fro")
    if data_scale <= 0:
        raise ValueError("Degenerate data: M has zero norm.")
    data /= np.sqrt(data_scale)
    noise_level = sigma_squared / data_scale

    gain_scale = np.linalg.norm(gain, ord=np.inf)
    if gain_scale <= 0:
        raise ValueError("Degenerate leadfield: G has zero norm.")
    gain /= gain_scale

    def _apply_ratio(num, den):
        # Turn numerator/denominator into new gammas; mode 2 uses the
        # square root of the (floored) denominator, mode 3 has none.
        if den is None:
            return num
        floored = np.maximum(den, eps)
        if update_mode == 2:
            return num / np.sqrt(floored)
        return num / floored

    active = np.arange(n_coeffs_total, dtype=int)
    previous_full = gamma.copy()
    snapshot = None  # state of the most recent completed update
    last_size = -1
    it_used = 0

    for itno in range(maxit):
        it_used = itno + 1

        # Sanitise, then prune groups whose mean gamma collapsed to zero.
        gamma = np.nan_to_num(gamma, nan=0.0, posinf=0.0, neginf=0.0)
        gamma = np.maximum(gamma, 0.0)
        group_means = gamma.reshape(gamma.size // group_size, group_size).mean(axis=1)
        keep = np.repeat(np.abs(group_means) > eps, group_size)
        if not np.all(keep):
            active = active[keep]
            gamma = gamma[keep]
            gain = gain[:, keep]
        if gamma.size == 0:
            break

        # Model covariance and its SVD-based inverse.
        model_cov = (gain * gamma[np.newaxis, :]) @ gain.T
        model_cov.flat[:: n_sensors + 1] += noise_level
        U, S, _ = linalg.svd(model_cov, full_matrices=False)
        model_cov_inv = (U / (S[np.newaxis, :] + eps)) @ U.T
        inv_gain = model_cov_inv @ gain
        A = inv_gain.T @ data

        power = np.mean((A * A.conj()).real, axis=1)
        projection = np.sum(gain * inv_gain, axis=0)
        if update_mode == 1:
            numer = gamma**2 * power
            denom = gamma * projection
        elif update_mode == 2:
            numer = gamma * np.sqrt(power)
            denom = projection
        elif update_mode == 3:
            numer = gamma**2 * power + gamma * (1.0 - gamma * projection)
            denom = None
        else:
            raise ValueError("update_mode must be 1, 2, or 3.")

        if group_size == 1:
            gamma = _apply_ratio(numer, denom)
        else:
            numer_group = np.sum(numer.reshape(-1, group_size), axis=1)
            denom_group = (
                None if denom is None
                else np.sum(denom.reshape(-1, group_size), axis=1)
            )
            per_group = _apply_ratio(numer_group, denom_group)
            gamma = np.repeat(per_group / group_size, group_size)
        gamma = np.maximum(gamma, 0.0)

        # Relative change measured on the full-size gamma vector.
        current_full = np.zeros(n_coeffs_total, dtype=np.float64)
        current_full[active] = gamma
        err = np.sum(np.abs(current_full - previous_full)) / np.sum(
            np.abs(previous_full) + eps
        )
        previous_full = current_full
        snapshot = (A, model_cov_inv, gain, gamma, active)

        breaking = (err < tol) or (gamma.size == 0)
        if (gamma.size != last_size) or breaking:
            if verbose:
                log.debug(f"it={itno:4d} active={gamma.size:4d} err={err:0.3e}")
            last_size = gamma.size
        if breaking:
            break

    if snapshot is None:
        # No update was ever completed: nothing is active.
        return (
            np.zeros((0, n_times), dtype=np.float64),
            np.zeros((0,), dtype=int),
            np.zeros((0, 0), dtype=np.float64),
            np.zeros(n_coeffs_total, dtype=np.float64),
            it_used,
        )

    A_last, cov_inv_last, gain_last, gamma_last, active_last = snapshot

    # Undo the data / gain normalisation.
    n_const = np.sqrt(data_scale) / gain_scale
    x_active = n_const * gamma_last[:, None] * A_last
    posterior_cov_active = (
        np.diag(gamma_last)
        - gamma_last[:, None] * (gain_last.T @ cov_inv_last @ gain_last) * gamma_last
    )
    posterior_cov_active = (n_const**2) * _symmetrize(posterior_cov_active)

    gammas_full = np.zeros(n_coeffs_total, dtype=np.float64)
    gammas_full[active_last] = gamma_last
    return x_active, active_last, posterior_cov_active, gammas_full, it_used
def gamma_map(
    L: np.ndarray,
    y: np.ndarray,
    noise_var: float,
    n_orient: int = 1,
    max_iter: int = 300,
    tol: float = 1e-6,
    update_mode: int = 2,
    init_gamma=None,
    verbose: bool = False,
    logger: Optional[logging.Logger] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Gamma-MAP source estimate for a known scalar noise variance.

    The inputs are validated, whitened by ``1 / sqrt(noise_var)``, handed
    to ``_gamma_map_opt`` with unit noise variance, and the active-set
    results are scattered back into full-size arrays.

    Parameters
    ----------
    L : ndarray, shape (n_sensors, n_coeffs)
        Leadfield.
    y : ndarray, shape (n_sensors, n_times) or (n_sensors,)
        Sensor data.
    noise_var : float
        Positive noise variance.
    n_orient : int
        Orientations per source (1, 2 or 3); used as the gamma group size.
    max_iter, tol, update_mode, verbose, logger
        Forwarded to ``_gamma_map_opt``.
    init_gamma : None | scalar | tuple | array-like
        Forwarded to ``_prepare_init_gamma``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    dict
        ``posterior_mean``, ``posterior_cov``, ``posterior_cov_active``,
        ``noise_var``, ``gamma`` (mean of all gammas), ``gammas_full``,
        ``active_indices``, ``active_source_indices``,
        ``coefficient_indices``, ``source_indices``, ``n_orient``,
        ``n_iter`` and, for ``n_orient > 1``, ``posterior_mean_reshaped``.

    Raises
    ------
    ValueError
        If ``noise_var`` is not positive or the inputs fail validation.
    """
    log = logging.getLogger(__name__) if logger is None else logger

    L, y = _validate_gamma_map_inputs(L=L, y=y, n_orient=n_orient)
    noise_var = float(noise_var)
    if noise_var <= 0:
        raise ValueError("noise_var must be positive.")

    n_sensors, n_times = y.shape
    n_coeffs = L.shape[1]
    n_sources = n_coeffs // n_orient

    start_gamma = _prepare_init_gamma(
        n_coeffs=n_coeffs,
        n_orient=n_orient,
        init_gamma=init_gamma,
    )

    # Isotropic whitening: afterwards the noise covariance is the identity.
    whitener = (1.0 / np.sqrt(noise_var)) * np.eye(n_sensors, dtype=np.float64)
    data_white = whitener @ y
    gain_white = whitener @ L

    solver_result = _gamma_map_opt(
        M=data_white,
        G=gain_white,
        sigma_squared=1.0,
        maxit=max_iter,
        tol=tol,
        update_mode=update_mode,
        group_size=n_orient,
        init_gamma=start_gamma,
        verbose=verbose,
        logger=log,
    )
    x_active, active_indices, posterior_cov_active, gammas_full, n_iter = solver_result

    # Scatter the active-set quantities into full-size containers.
    x_hat = np.zeros((n_coeffs, n_times), dtype=np.float64)
    posterior_cov = np.zeros((n_coeffs, n_coeffs), dtype=np.float64)
    if active_indices.size > 0:
        x_hat[active_indices] = x_active
        posterior_cov[np.ix_(active_indices, active_indices)] = posterior_cov_active
        posterior_cov = _symmetrize(posterior_cov)

    result: Dict[str, Any] = {
        "posterior_mean": x_hat,
        "posterior_cov": posterior_cov,
        "posterior_cov_active": posterior_cov_active,
        "noise_var": float(noise_var),
        "gamma": float(np.mean(gammas_full)),
        "gammas_full": gammas_full,
        "active_indices": active_indices,
        "active_source_indices": np.unique(active_indices // n_orient),
        "coefficient_indices": np.arange(n_coeffs),
        "source_indices": np.arange(n_sources),
        "n_orient": int(n_orient),
        "n_iter": int(n_iter),
    }
    if n_orient > 1:
        result["posterior_mean_reshaped"] = x_hat.reshape(n_sources, n_orient, n_times)
    return result
# ==================
# sFlex Functions
# ==================
# def get_subset_source_rr_from_extract(lf_dict: Dict[str, Any]) -> np.ndarray:
# src = lf_dict["fwd"]["src"]
# rr_lh = src[0]["rr"][src[0]["vertno"]]
# rr_rh = src[1]["rr"][src[1]["vertno"]]
# rr_full = np.vstack([rr_lh, rr_rh])
# subset_idx = np.asarray(lf_dict["subset_idx"], dtype=int)
# return rr_full[subset_idx]
def compute_B(
    sigma: float,
    threshold_factor: float = 3.0,
    normalize: Optional[str] = "sym",
    eps: float = 1e-12,
    src_coords: Optional[np.ndarray] = None,
):
    """Build a sparse, truncated Gaussian spatial basis over source positions.

    Entry ``(i, j)`` is ``exp(-d_ij**2 / (2 * sigma**2))`` when the squared
    distance ``d_ij**2`` is at most ``(threshold_factor * sigma)**2`` and
    zero otherwise.

    Parameters
    ----------
    sigma : float
        Positive width of the Gaussian, in the units of ``src_coords``.
    threshold_factor : float
        Truncation radius expressed in multiples of ``sigma``.
    normalize : None | "row" | "sym"
        ``None`` returns the raw kernel, ``"row"`` divides every row by its
        sum, ``"sym"`` applies ``D^{-1/2} B D^{-1/2}`` with ``D`` the row
        sums.
    eps : float
        Floor applied to the row sums before they are inverted.
    src_coords : array-like, shape (N, 3)
        Source positions. Required even though it is a keyword with a
        ``None`` default (kept last for backward compatibility).

    Returns
    -------
    scipy.sparse matrix, shape (N, N)
        The (optionally normalised) basis.

    Raises
    ------
    ValueError
        If ``src_coords`` is missing or not ``(N, 3)``, ``sigma`` is not
        positive, or ``normalize`` is not one of the accepted values.
    """
    if src_coords is None:
        raise ValueError("src_coords must be provided with shape (N,3).")
    src_coords = np.asarray(src_coords, dtype=np.float64)
    if src_coords.ndim != 2 or src_coords.shape[1] != 3:
        raise ValueError("src_coords must have shape (N,3).")
    if sigma <= 0:
        raise ValueError("sigma must be positive.")
    # Reject a bad mode before doing the O(N^2) distance computation.
    if normalize not in (None, "row", "sym"):
        raise ValueError("normalize must be None, 'row', or 'sym'.")

    n_points = src_coords.shape[0]
    dist2 = squareform(pdist(src_coords, metric="sqeuclidean"))
    radius2 = (threshold_factor * sigma) ** 2
    rows, cols = np.nonzero(dist2 <= radius2)
    weights = np.exp(-dist2[rows, cols] / (2.0 * sigma**2))

    B = coo_matrix((weights, (rows, cols)), shape=(n_points, n_points)).tocsr()
    B = 0.5 * (B + B.T)  # remove any round-off asymmetry
    if normalize is None:
        return B

    row_sums = np.asarray(B.sum(axis=1)).ravel()
    if normalize == "row":
        inv_row_sums = 1.0 / np.maximum(row_sums, eps)
        return diags(inv_row_sums) @ B

    # normalize == "sym"
    inv_sqrt = 1.0 / np.sqrt(np.maximum(row_sums, eps))
    scaling = diags(inv_sqrt)
    B = scaling @ B @ scaling
    return 0.5 * (B + B.T)
def _expand_spatial_basis(B, n_sources: int, n_orient: int):
    """Lift a source-level basis to coefficient level via ``B ⊗ I_k``.

    Parameters
    ----------
    B : scipy.sparse matrix | array-like, shape (n_sources, n_sources)
        Spatial basis. Sparse input is converted to CSR, anything else to
        a float64 ndarray.
    n_sources : int
        Expected size of ``B``.
    n_orient : int
        Orientations per source ``k``; for ``k == 1`` the converted ``B``
        is returned unchanged.

    Returns
    -------
    CSR matrix or ndarray, shape (n_sources * n_orient, n_sources * n_orient)

    Raises
    ------
    ValueError
        If ``B`` does not have shape ``(n_sources, n_sources)``.
    """
    sparse_input = issparse(B)
    basis = B.tocsr() if sparse_input else np.asarray(B, dtype=np.float64)

    expected_shape = (n_sources, n_sources)
    if basis.shape != expected_shape:
        raise ValueError(
            f"B must have shape ({n_sources},{n_sources}); got {basis.shape}."
        )

    if n_orient == 1:
        return basis
    if sparse_input:
        return kron(basis, eye(n_orient, format="csr"), format="csr")
    return np.kron(basis, np.eye(n_orient, dtype=np.float64))
def _right_multiply_dense_by_sparse(A: np.ndarray, S) -> np.ndarray:
    """Return ``A @ S`` as a float64 ndarray, evaluated as ``(S.T @ A.T).T``.

    The product is written with ``S`` on the left so that the sparse
    operand drives the multiplication.
    """
    transposed_product = S.T @ A.T
    return np.asarray(transposed_product.T, dtype=np.float64)
# [docs]
def gamma_map_sflex(
    L: np.ndarray,
    y: np.ndarray,
    noise_var: float,
    n_orient: int = 1,
    max_iter: int = 300,
    tol: float = 1e-6,
    update_mode: int = 2,
    init_gamma=None,
    sigma: float = 0.01,
    threshold_factor: float = 3.0,
    normalize: Optional[str] = "sym",
    eps: float = 1e-12,
    verbose: bool = False,
    logger: Optional[logging.Logger] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Gamma-MAP in a smooth spatial basis (sFlex).

    The leadfield is projected onto a Gaussian spatial basis ``B`` built
    from the source positions (``G = L @ B_big``), ``gamma_map`` is run on
    the basis coefficients, and the result is mapped back to source space
    (``x = B_big @ c``, ``Sigma_x = B_big @ Sigma_c @ B_big.T``).

    Parameters
    ----------
    L, y, noise_var, n_orient, max_iter, tol, update_mode, init_gamma,
    verbose, logger
        As for ``gamma_map``.
    sigma, threshold_factor, normalize, eps
        Forwarded to ``compute_B``.
    **kwargs
        Must contain ``src_coords``, an ``(n_sources, 3)`` array of source
        positions. Other entries are ignored.

    Returns
    -------
    dict
        Source-space ``posterior_mean`` / ``posterior_cov``, the
        coefficient-space ``posterior_mean_coeff`` / ``posterior_cov_coeff``,
        the bookkeeping entries copied from ``gamma_map`` and the basis
        under ``B_spatial``.

    Raises
    ------
    ValueError
        If ``src_coords`` is missing or its row count differs from the
        number of sources, or if any downstream validation fails.
    """
    # Fail fast with an explicit message instead of an AttributeError
    # raised deep inside compute_B when the positions are missing.
    src_coords = kwargs.get("src_coords")
    if src_coords is None:
        raise ValueError(
            "gamma_map_sflex requires the keyword argument `src_coords` "
            "with shape (n_sources, 3)."
        )
    src_coords = np.asarray(src_coords, dtype=np.float64)

    if logger is None:
        logger = logging.getLogger(__name__)

    L, y = _validate_gamma_map_inputs(L=L, y=y, n_orient=n_orient)
    n_coeffs = L.shape[1]
    n_sources = n_coeffs // n_orient
    n_times = y.shape[1]

    # Check the pairing before building the (N, N) basis.
    if src_coords.ndim == 2 and src_coords.shape[0] != n_sources:
        raise ValueError(
            f"src_coords has {src_coords.shape[0]} rows but the leadfield "
            f"describes {n_sources} sources."
        )

    B = compute_B(
        sigma=sigma,
        threshold_factor=threshold_factor,
        normalize=normalize,
        eps=eps,
        src_coords=src_coords,
    )
    B_big = _expand_spatial_basis(B=B, n_sources=n_sources, n_orient=n_orient)

    # Leadfield expressed in the basis coefficients.
    if issparse(B_big):
        G = _right_multiply_dense_by_sparse(L, B_big)
    else:
        G = L @ B_big

    res_c = gamma_map(
        L=G,
        y=y,
        noise_var=noise_var,
        n_orient=n_orient,
        max_iter=max_iter,
        tol=tol,
        update_mode=update_mode,
        init_gamma=init_gamma,
        verbose=verbose,
        logger=logger,
    )
    c_hat = np.asarray(res_c["posterior_mean"], dtype=np.float64)
    Sigma_c = np.asarray(res_c["posterior_cov"], dtype=np.float64)

    # Map coefficients and their covariance back to source space.
    if issparse(B_big):
        x_hat = np.asarray(B_big @ c_hat, dtype=np.float64)
        B_big_dense = B_big.toarray()
    else:
        B_big_dense = np.asarray(B_big, dtype=np.float64)
        x_hat = B_big_dense @ c_hat
    posterior_cov = _symmetrize(B_big_dense @ Sigma_c @ B_big_dense.T)

    out = {
        "posterior_mean": x_hat,
        "posterior_cov": posterior_cov,
        "posterior_mean_coeff": c_hat,
        "posterior_cov_coeff": Sigma_c,
        "noise_var": float(res_c["noise_var"]),
        "gamma": float(res_c["gamma"]),
        "gammas_full": np.asarray(res_c["gammas_full"], dtype=np.float64),
        "active_indices": np.asarray(res_c["active_indices"], dtype=int),
        "active_source_indices": np.asarray(res_c["active_source_indices"], dtype=int),
        "coefficient_indices": np.arange(n_coeffs),
        "source_indices": np.arange(n_sources),
        "n_orient": int(n_orient),
        "n_iter": int(res_c["n_iter"]),
        "B_spatial": B,
    }
    if n_orient > 1:
        out["posterior_mean_reshaped"] = x_hat.reshape(n_sources, n_orient, n_times)
    return out
# ===================
# eLORETA Functions
# ===================
def sqrtm_sym(M, inv=False):
    """
    Square root (or inverse square root) of symmetric matrices.

    A 2D input is treated as one matrix; a 3D input of shape
    (n_blocks, n, n) is processed block by block.

    Parameters
    ----------
    M : ndarray, shape (n, n) or (n_blocks, n, n)
        Symmetric matrix or stack of symmetric matrices.
    inv : bool
        If True compute the inverse square root (eigenvalues are offset by
        machine epsilon before inversion).

    Returns
    -------
    S : ndarray, same shape as M
        The (inverse) square root(s).
    s : ndarray, shape (n,) or (n_blocks, n)
        The transformed eigenvalues that were used to build ``S``.
    """

    def _root_of(block):
        # Eigen-decompose, clip negative round-off, transform the spectrum.
        evals, evecs = eigh(block)
        evals = np.clip(evals, 0, None)
        if inv:
            transformed = 1.0 / np.sqrt(evals + np.finfo(float).eps)
        else:
            transformed = np.sqrt(evals)
        return (evecs * transformed) @ evecs.T, transformed

    if M.ndim != 3:
        return _root_of(M)

    n_blocks, n, _ = M.shape
    roots = np.zeros_like(M)
    spectra = np.zeros((n_blocks, n))
    for idx in range(n_blocks):
        roots[idx], spectra[idx] = _root_of(M[idx])
    return roots, spectra
def normalize_R(G, R, G_3, n_nzero, force_equal, n_src, n_orient):
    """
    Normalize the source covariance R so that trace(G @ R @ G.T) == n_nzero.

    ``R`` is rescaled **in place** and the equally rescaled sensor-level
    product is returned.

    Parameters
    ----------
    G : ndarray, shape (n_chan, n_src * n_orient)
        The lead-field or forward matrix after applying whitening and source scaling.
    R : ndarray of float
        The source covariance; a 1D vector of length n_src * n_orient
        (diagonal case) or an array of shape (n_src, n_orient, n_orient)
        (block-diagonal case). Modified in place.
    G_3 : ndarray or None
        ``G`` regrouped as (n_src, n_orient, n_chan); only used in the
        block-diagonal case.
    n_nzero : int
        Target trace of the normalized product.
    force_equal : bool
        If True, treat ``R`` as a diagonal vector even for n_orient > 1.
    n_src : int
        Number of sources.
    n_orient : int
        Number of orientations per source.

    Returns
    -------
    G_R_Gt : ndarray, shape (n_chan, n_chan)
        The normalized product G @ R @ G.T.
    """
    if n_orient == 1 or force_equal:
        # Diagonal R: scale the rows of G.T.
        R_Gt = R[:, np.newaxis] * G.T
    else:
        # Block-diagonal R: one (n_orient x n_orient) block per source.
        # The row count follows n_orient rather than a hard-coded 3 so
        # that any block size is laid out correctly.
        R_Gt = np.matmul(R, G_3).reshape(n_src * n_orient, -1)

    G_R_Gt = G @ R_Gt

    # Local name chosen so the module-level `norm` import is not shadowed.
    scale = np.trace(G_R_Gt) / n_nzero
    G_R_Gt /= scale
    R /= scale
    return G_R_Gt
def get_G_3(G, n_orient):
    """
    Regroup the columns of G into one block per source.

    Parameters
    ----------
    G : ndarray, shape (n_chan, n_src * n_orient)
        Lead-field matrix whose columns are ordered source by source.
    n_orient : int
        Number of orientations per source.

    Returns
    -------
    ndarray or None
        ``None`` when ``n_orient == 1``. Otherwise an array of shape
        (n_src, n_orient, n_chan) with
        ``out[s, k, c] == G[c, s * n_orient + k]``.
    """
    if n_orient == 1:
        return None

    n_chan = G.shape[0]
    # Split the column axis into (source, orientation), then move the
    # channel axis to the end so each source owns one leading slice.
    per_source = G.reshape(n_chan, -1, n_orient)
    return per_source.transpose(1, 2, 0)
def R_sqrt_mult(other, R_sqrt):
    """
    Multiply ``other`` by a diagonal or block-diagonal ``R_sqrt``.

    Parameters
    ----------
    other : ndarray, shape (n_rows, n_src * n_orient)
        Left operand.
    R_sqrt : ndarray
        Either a 1D vector (the diagonal of a diagonal matrix) or an array
        of shape (n_src, 3, 3) holding one 3x3 block per source.

    Returns
    -------
    out : ndarray, shape (n_rows, n_src * n_orient)
        For 1D ``R_sqrt`` the columns of ``other`` scaled element-wise.
        For the block form, column group ``s`` of the result equals
        ``other[:, 3s:3s+3] @ R_sqrt[s].T``.
    """
    if R_sqrt.ndim == 1:
        # Diagonal case: plain column scaling.
        assert other.shape[1] == R_sqrt.size
        return R_sqrt * other

    # Block-diagonal case: one 3x3 block per source.
    assert R_sqrt.shape[1:3] == (3, 3)
    assert other.shape[1] == np.prod(R_sqrt.shape[:2])
    assert other.ndim == 2

    n_src = R_sqrt.shape[0]
    n_rows = other.shape[0]
    # (n_rows, n_src * 3) -> (n_src, 3, n_rows) so matmul acts per source.
    per_source = other.reshape(n_rows, n_src, 3).transpose(1, 2, 0)
    mixed = np.matmul(R_sqrt, per_source)
    return mixed.reshape(n_src * 3, n_rows).T
def compute_reginv2(sing, n_nzero, lambda2):
    """
    Tikhonov-regularized inverse of singular values.

    Parameters
    ----------
    sing : array-like
        Singular values.
    n_nzero : int
        Only the first ``n_nzero`` entries are inverted; the remainder of
        the output stays zero.
    lambda2 : float
        Regularization parameter.

    Returns
    -------
    reginv : ndarray, same length as ``sing``
        ``s / (s**2 + lambda2)`` for positive leading entries, 0 elsewhere.
    """
    values = np.array(sing, dtype=np.float64)
    reginv = np.zeros_like(values)

    leading = values[:n_nzero]
    with np.errstate(invalid="ignore"):
        damped = leading / (leading ** 2 + lambda2)
        reginv[:n_nzero] = np.where(leading > 0, damped, 0)
    return reginv
def compute_orient_prior(G, n_orient, loose=0.9):
    """
    Per-column orientation prior weights.

    Parameters
    ----------
    G : ndarray, shape (n_chan, n_columns)
        Lead-field matrix; only its column count is used.
    n_orient : int
        Number of orientations per source. For 1 the weights are all ones.
    loose : float
        Factor applied to the first and second column of every triplet
        when ``n_orient != 1``; the third column keeps weight 1.

    Returns
    -------
    orient_prior : ndarray, shape (n_columns,)
        The prior weight of each lead-field column.
    """
    n_columns = G.shape[1]
    weights = np.ones(n_columns, dtype=np.float64)
    if n_orient != 1:
        # Columns come in (x, y, z) triplets: down-weight x and y only.
        for offset in (0, 1):
            weights[offset::3] *= loose
    return weights
def safe_svd(A, full_matrices=False):
    """
    Singular value decomposition of ``A`` via ``numpy.linalg.svd``.

    Parameters
    ----------
    A : ndarray
        Matrix to decompose.
    full_matrices : bool
        Passed straight through to ``numpy.linalg.svd``; selects the full
        or the reduced decomposition.

    Returns
    -------
    U, S, Vh : ndarrays
        Left singular vectors, singular values and (transposed) right
        singular vectors, exactly as returned by NumPy.
    """
    decomposition = np.linalg.svd(A, full_matrices=full_matrices)
    return decomposition
def compute_eloreta_kernel(L, *, lambda2, n_orient, whitener, loose=1.0, max_iter=20, logger=None):
    """
    Compute the eLORETA kernel and the posterior source covariance.

    Steps:
    1. Whiten the lead-field matrix L.
    2. Apply the orientation prior.
    3. Iteratively fit the source covariance R (eLORETA weights).
    4. Undo the prior scaling of the gain and re-normalize R.
    5. SVD of ``G @ R^{1/2}`` with Tikhonov-regularized singular values.
    6. Assemble the inverse operator K and the posterior covariance.

    Parameters
    ----------
    L : ndarray, shape (n_chan, n_src*n_orient)
        The *un-whitened* lead-field matrix (whitening is applied here).
    lambda2 : float
        Regularization parameter.
    n_orient : int
        Orientations per source: 1 (fixed) or 3 (free).
    whitener : ndarray, shape (n_chan, n_chan)
        Whitening matrix; it is also folded into ``K`` so that ``K`` acts
        on un-whitened data.
    loose : float
        Orientation-prior factor passed to ``compute_orient_prior``.
    max_iter : int
        Maximum number of weight-fitting iterations.
    logger : logging.Logger | None
        Receives a debug record if the weights do not converge; the module
        logger is used when None.

    Returns
    -------
    K : ndarray, shape (n_src*n_orient, n_chan)
        The eLORETA kernel.
    Sigma : ndarray, shape (n_src*n_orient, n_src*n_orient)
        Posterior source covariance
        ``R - R^{1/2} V diag(s^2 / (s^2 + lambda2)) V^T R^{1/2}``, with R
        expanded to a full (block-)diagonal matrix.

    Raises
    ------
    ValueError
        If ``n_orient`` is neither 1 nor 3.
    """
    if n_orient not in (1, 3):
        raise ValueError("n_orient must be 1 or 3.")
    if logger is None:
        # The non-convergence branch logs; never dereference None there.
        logger = logging.getLogger(__name__)

    eps = 1e-6  # convergence threshold on the relative weight change (from mne)
    force_equal = False  # equal-orientation weighting is not used here

    G = whitener @ L
    n_nzero = G.shape[0]

    # Apply the orientation prior to the gain.
    source_std = np.ones(G.shape[1])
    orient_prior = compute_orient_prior(G, n_orient, loose=loose)
    source_std *= np.sqrt(orient_prior)
    G *= source_std
    # No depth prior: eLORETA should compensate for depth bias itself.
    n_src = G.shape[1] // n_orient

    # (n_src, n_orient, n_chan) view of the gain, or None for n_orient == 1
    G_3 = get_G_3(G, n_orient)
    diagonal_R = force_equal or n_orient == 1
    if diagonal_R:
        R_prior = source_std ** 2
    else:
        # Outer product of the per-orientation standard deviations
        R_prior = (
            source_std.reshape(n_src, 1, n_orient)
            * source_std.reshape(n_src, n_orient, 1)
        )

    # The following was adapted under BSD license by permission of Guido Nolte
    if diagonal_R:
        R_shape = (n_src * n_orient,)
        R = np.ones(R_shape)
    else:
        R_shape = (n_src, n_orient, n_orient)
        R = np.empty(R_shape)
        R[:] = np.eye(n_orient)[np.newaxis]
    R *= R_prior

    this_normalize_R = partial(
        normalize_R,
        n_nzero=n_nzero,
        force_equal=force_equal,
        n_src=n_src,
        n_orient=n_orient,
    )
    G_R_Gt = this_normalize_R(G, R, G_3)

    for _ in range(max_iter):
        # 1. Stabilized inverse of (G R G^T + lambda2 I)
        s, u = eigh(G_R_Gt)
        s = abs(s)
        sidx = np.argsort(s)[::-1][:n_nzero]
        s, u = s[sidx], u[:, sidx]
        with np.errstate(invalid="ignore"):
            s = np.where(s > 0, 1 / (s + lambda2), 0)
        N = np.dot(u * s, u.T)
        del s

        # 2. Update the weights
        R_last = R.copy()
        if n_orient == 1:
            R[:] = 1.0 / np.sqrt((np.dot(N, G) * G).sum(0))
        else:
            M = np.matmul(np.matmul(G_3, N[np.newaxis]), G_3.swapaxes(-2, -1))
            if force_equal:
                _, s = sqrtm_sym(M, inv=True)
                R[:] = np.repeat(1.0 / np.mean(s, axis=-1), 3)
            else:
                R[:], _ = sqrtm_sym(M, inv=True)
        R *= R_prior  # reapply our prior, eLORETA undoes it
        G_R_Gt = this_normalize_R(G, R, G_3)

        # 3. Check for weight convergence
        delta = np.linalg.norm(R.ravel() - R_last.ravel()) / np.linalg.norm(
            R_last.ravel()
        )
        if delta < eps:
            break
    else:
        logger.debug("eLORETA weight fitting did not converge (>= %s)", eps)
    del G_R_Gt

    G /= source_std  # undo our biasing
    G_3 = get_G_3(G, n_orient)
    this_normalize_R(G, R, G_3)  # rescales R in place for the un-biased gain
    del G_3

    if diagonal_R:
        R_sqrt = np.sqrt(R)
    else:
        R_sqrt = sqrtm_sym(R)[0]
    assert R_sqrt.shape == R_shape

    A = R_sqrt_mult(G, R_sqrt)
    eigen_fields, sing, eigen_leads = safe_svd(A, full_matrices=False)

    reginv_k = compute_reginv2(sing, n_nzero, lambda2)  # s / (s^2 + lambda2)
    reginv_s = sing * reginv_k  # s^2 / (s^2 + lambda2)

    # K = R^{1/2} V diag(reginv_k) U^T W
    eigen_leads = R_sqrt_mult(eigen_leads, R_sqrt).T  # R^{1/2} V
    trans = np.dot(eigen_fields.T, whitener)
    trans *= reginv_k[:, None]
    K = np.dot(eigen_leads, trans)

    # Sigma = R - (R^{1/2} V) diag(reginv_s) (R^{1/2} V)^T.
    # The scaled factor is a *new* array: scaling `eigen_leads.T` in place
    # would also scale `eigen_leads` (they share memory) and square the
    # shrinkage factors.
    shrinkage = np.dot(eigen_leads * reginv_s, eigen_leads.T)
    # Expand R to a full matrix; subtracting from the compact 1D / 3D
    # representation would broadcast incorrectly or not at all.
    if R.ndim == 1:
        R_full = np.diag(R)
    else:
        R_full = linalg.block_diag(*R)
    Sigma = R_full - shrinkage
    return K, Sigma
def eloreta(L, y, noise_var, n_orient=1, verbose=True, logger=None, **kwargs):
    """
    Compute the eLORETA solution for EEG/MEG inverse modeling.

    A whitener is derived from the isotropic noise covariance
    ``noise_var * I`` and handed, together with the *un-whitened*
    lead-field, to ``compute_eloreta_kernel``. The returned kernel already
    contains the whitener, so it is applied to the un-whitened data.

    Parameters
    ----------
    L : ndarray, shape (n_chan, n_src*n_orient)
        The lead-field (forward) matrix mapping sources to sensors.
    y : ndarray, shape (n_chan, n_times) or (n_chan,)
        The sensor data to be inverted.
    noise_var : float
        Noise variance used to build the noise covariance.
    n_orient : int
        Number of orientations per source (1 for fixed or 3 for free orientation).
    verbose : bool
        Accepted for interface compatibility; not used.
    logger : logging.Logger | None
        Forwarded to ``compute_eloreta_kernel``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    dict
        ``posterior_mean`` (source estimates; reshaped to
        (n_src, n_orient, ...) when ``n_orient > 1``), ``posterior_cov``,
        ``noise_var`` and ``active_indices`` (all coefficients).
    """
    L = np.asarray(L)
    y = np.asarray(y)

    # TODO: check if this work for all noise types
    noise_cov = noise_var * np.eye(L.shape[0])
    # Whitener = inverse matrix square root of the noise covariance.
    whitener = linalg.inv(linalg.sqrtm(noise_cov))

    # Do NOT whiten L or y here: the kernel routine whitens the lead-field
    # itself and folds the whitener into K. Whitening beforehand as well
    # applies the whitener twice, which mis-scales the posterior covariance
    # whenever noise_var != 1.
    # After whitening the noise covariance is the identity, hence lambda2 = 1.
    K, Sigma = compute_eloreta_kernel(
        L,
        lambda2=1.0,
        n_orient=n_orient,
        whitener=whitener,
        logger=logger,
    )

    # Source time courses.
    x = K @ y

    # Free orientation: split the coefficient axis into (source, orientation).
    # Written via shape concatenation so a 1D `y` is handled as well.
    if n_orient > 1:
        x = x.reshape((-1, n_orient) + x.shape[1:])

    active_indices = np.arange(Sigma.shape[0])  # All sources are active in eLORETA
    return {
        "posterior_mean": x,
        "posterior_cov": Sigma,
        "noise_var": noise_var,
        "active_indices": active_indices,
    }
# ==========================================
# BMN (with sLORETA normalization) Functions
# ==========================================
def _symmetrize(A: np.ndarray) -> np.ndarray:
    """Average a square matrix with its transpose, i.e. its symmetric part."""
    transposed = A.T
    return 0.5 * (transposed + A)
def _svd_inverse(A: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """
    Stable matrix inverse using SVD.

    With ``A = U diag(S) Vh`` the inverse is ``Vh^H diag(1/S) U^H``.
    Singular values below ``eps`` are raised to ``eps`` before inversion,
    so a (nearly) singular input yields a large but finite result.

    Parameters
    ----------
    A : ndarray, shape (n, n)
        Matrix to invert.
    eps : float
        Absolute floor applied to the singular values.

    Returns
    -------
    ndarray, shape (n, n)
        The regularized inverse.
    """
    U, S, Vt = np.linalg.svd(A, full_matrices=False)
    S_inv = 1.0 / np.maximum(S, eps)
    # V diag(1/S) U^H. The operands must be swapped relative to the
    # forward factorization, otherwise the result is the transpose of the
    # inverse (which only coincides with it for symmetric A).
    return (Vt.conj().T * S_inv) @ U.conj().T
def _validate_bmn_inputs(
    L: np.ndarray,
    y: np.ndarray,
    n_orient: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Shared input validation for the BMN estimators.

    Both inputs are converted to float arrays and a 1D ``y`` is promoted
    to a single column.

    Expected leadfield shapes
    -------------------------
    n_orient = 1: L is (M, N)
    n_orient = 2: L is (M, 2N)
    n_orient = 3: L is (M, 3N)

    Returns
    -------
    (L, y) : tuple of ndarray
        The converted leadfield and the 2D data.

    Raises
    ------
    ValueError
        If a dimensionality, row-count or orientation constraint fails.
    """
    leadfield = np.asarray(L, dtype=float)
    data = np.asarray(y, dtype=float)
    if data.ndim == 1:
        data = data[:, np.newaxis]

    if leadfield.ndim != 2:
        raise ValueError("L must be 2D.")
    if data.ndim != 2:
        raise ValueError("y must have shape (M, T) or (M,).")
    if leadfield.shape[0] != data.shape[0]:
        raise ValueError("L and y must have the same number of sensor rows.")
    if n_orient not in (1, 2, 3):
        raise ValueError("n_orient must be 1, 2, or 3.")

    if n_orient > 1:
        n_columns = leadfield.shape[1]
        if n_columns % n_orient != 0:
            raise ValueError(
                f"For n_orient={n_orient}, L must have k*N columns with k={n_orient}."
            )
    return leadfield, data
# sLORETA normalization
def compute_W(L: np.ndarray, n_orient: int = 1, beta: float = 1e-6) -> np.ndarray:
    """
    Compute the sLORETA-type normalization matrix W.

    With ``S = L.T @ inv(L @ L.T + beta * I) @ L``:

    * ``n_orient == 1``: W is diagonal with entries ``1 / sqrt(S_nn)``.
    * ``n_orient`` in {2, 3}: W is block-diagonal, the block of source n
      being the inverse matrix square root of the matching k x k diagonal
      block of ``S``.

    Parameters
    ----------
    L : array-like, shape (M, N) or (M, k*N)
        Leadfield.
    n_orient : int
        Orientations per source (1, 2 or 3).
    beta : float
        Ridge added to ``L @ L.T`` before inversion.

    Returns
    -------
    ndarray, shape (n_columns, n_columns)
        Dense normalization matrix.

    Raises
    ------
    ValueError
        If ``L`` is not 2D, ``n_orient`` is invalid, or the column count
        is not a multiple of ``n_orient``.
    """
    gain = np.asarray(L, dtype=float)
    if gain.ndim != 2:
        raise ValueError("L must be 2D.")
    if n_orient not in (1, 2, 3):
        raise ValueError("n_orient must be 1, 2, or 3.")

    n_sensors, n_columns = gain.shape
    floor = 1e-12  # lower bound for singular values / eigenvalues

    gram = _symmetrize(gain @ gain.T)
    gram_reg = _symmetrize(gram + beta * np.eye(n_sensors))
    gram_inv = _svd_inverse(gram_reg, eps=floor)

    # Fixed orientation: one scalar per source.
    if n_orient == 1:
        projected = gram_inv @ gain
        resolution_diag = np.sum(gain * projected, axis=0)
        return np.diag(1.0 / np.sqrt(np.maximum(resolution_diag, floor)))

    # Reduced / free orientation: one k x k block per source.
    if n_columns % n_orient != 0:
        raise ValueError(
            f"Lead-field L must have {n_orient}N columns for n_orient={n_orient}."
        )

    resolution = _symmetrize(gain.T @ gram_inv @ gain)

    def _inverse_sqrt_block(start):
        # S_n^{-1/2} from the eigen-decomposition of the local block.
        local = slice(start, start + n_orient)
        S_n = _symmetrize(resolution[local, local])
        evals, evecs = np.linalg.eigh(S_n)
        evals = np.maximum(evals, floor)
        W_n = evecs @ np.diag(1.0 / np.sqrt(evals)) @ evecs.T
        return _symmetrize(W_n)

    blocks = [_inverse_sqrt_block(start) for start in range(0, n_columns, n_orient)]
    return block_diag(blocks, format="csr").toarray()
# BMN with fixed known noise variance
def BMN_opt(
    y: np.ndarray,
    L: np.ndarray,
    alpha: float,
    maxit: int = 1000,
    tol: float = 1e-6,
    init_gamma: Optional[float] = None,
    logger: Optional[logging.Logger] = None,
    verbose: bool = False,
) -> Tuple[np.ndarray, np.ndarray, float]:
    """
    BMN optimization using Bayesian evidence maximization for one common
    source variance.

    Model
    -----
    y_t ~ N(L x_t, alpha I)
    x_t ~ N(0, gamma I)

    With ``C = gamma L L^T + alpha I`` the evidence is stationary in gamma
    when ``tr(y^T C^-1 L L^T C^-1 y) / T == tr(C^-1 L L^T)``. The fixed
    point iteration used here is
    ``gamma <- gamma * numerator / denominator``, whose fixed points are
    exactly those stationary points.

    Parameters
    ----------
    y : ndarray, shape (M, T) or (M,)
        Sensor data.
    L : ndarray, shape (M, N)
        Leadfield.
    alpha : float
        Positive, known noise variance.
    maxit : int
        Maximum number of iterations.
    tol : float
        Relative change of gamma below which the iteration stops.
    init_gamma : float | None
        Positive starting value; 1.0 when None.
    logger : logging.Logger | None
        Receives per-iteration debug records when ``verbose`` is True.
    verbose : bool
        Enable the per-iteration debug records.

    Returns
    -------
    x_hat : ndarray, shape (N, T)
        Posterior mean in the original coefficient scale.
    posterior_cov : ndarray, shape (N, N)
        Posterior covariance in the original coefficient scale.
    gamma : float
        Learned variance in the internal (scale-normalized) coordinates.

    Raises
    ------
    ValueError
        On non-positive ``init_gamma`` or ``alpha``, inconsistent shapes,
        or (near) zero-norm ``y`` / ``L``.

    Notes
    -----
    - This is the normal BMN without adaptive noise learning.
    - Gamma is a single common scalar variance in the internal optimization
      coordinate system.
    """
    L = np.asarray(L, dtype=float).copy()
    y = np.asarray(y, dtype=float).copy()
    gamma = 1.0 if init_gamma is None else float(init_gamma)
    eps = np.finfo(float).eps

    if gamma <= 0:
        raise ValueError("init_gamma must be positive.")
    if y.ndim == 1:
        y = y[:, np.newaxis]
    if y.ndim != 2:
        raise ValueError("y must have shape (M, T) or (M,).")
    if L.ndim != 2:
        raise ValueError("L must be 2D.")
    if L.shape[0] != y.shape[0]:
        raise ValueError("L and y must have the same number of sensor rows.")

    M, T = y.shape
    y_original_scale = np.linalg.norm(y, ord="fro")
    L_original_scale = np.linalg.norm(L, ord=2)
    if y_original_scale < eps or L_original_scale < eps:
        raise ValueError("Degenerate input: y or L has (near) zero norm.")

    # Scale-normalized optimization
    y = y / y_original_scale
    L = L / L_original_scale
    alpha = float(alpha) / (y_original_scale ** 2)
    if alpha <= 0:
        raise ValueError("alpha must be positive.")

    LLt = _symmetrize(L @ L.T)
    noise_cov = alpha * np.eye(M)

    for it in range(maxit):
        gamma_old = gamma
        model_cov = _symmetrize(gamma_old * LLt + noise_cov)
        model_cov_inv = _svd_inverse(model_cov, eps=eps)
        model_cov_inv_y = model_cov_inv @ y

        # tr(y^T C^-1 LL^T C^-1 y) / T, evaluated as an element-wise sum so
        # that no (T, T) matrix is formed.
        numerator = np.sum((model_cov_inv.T @ y) * (LLt @ model_cov_inv_y)) / T
        # tr(C^-1 LL^T)
        denominator = np.sum(model_cov_inv * LLt.T)

        # Multiplicative fixed-point step. Without the gamma_old factor the
        # iteration does not converge to the evidence maximum (in the
        # noiseless invertible case it just alternates between two values).
        gamma = gamma_old * numerator / max(denominator, eps)
        gamma = max(float(gamma), eps)

        err = np.abs(gamma - gamma_old) / (np.abs(gamma_old) + eps)
        if verbose and logger is not None:
            logger.debug(f"BMN iter {it:4d}: gamma={gamma:.6e}, err={err:.3e}")
        if err < tol:
            break

    # Posterior moments for the final gamma.
    model_cov = _symmetrize(gamma * LLt + noise_cov)
    model_cov_inv = _svd_inverse(model_cov, eps=eps)
    x_est = gamma * (L.T @ model_cov_inv @ y)
    posterior_cov = gamma * np.eye(L.shape[1]) - gamma**2 * (L.T @ model_cov_inv @ L)
    posterior_cov = _symmetrize(posterior_cov)

    # Map posterior outputs back to original coefficient scale
    scale_factor = y_original_scale / L_original_scale
    x_hat = scale_factor * x_est
    posterior_cov = _symmetrize((scale_factor ** 2) * posterior_cov)

    # gamma is returned as the internal common scalar hyperparameter
    return x_hat, posterior_cov, float(gamma)
# [docs]
def BMN(
    L: np.ndarray,
    y: np.ndarray,
    noise_var: float,
    n_orient: int = 1,
    max_iter: int = 1000,
    tol: float = 1e-6,
    init_gamma: Optional[float] = None,
    verbose: bool = False,
    normalization: bool = False,
    logger: Optional[logging.Logger] = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    BMN estimate with optional sLORETA normalization.

    Supports
    --------
    n_orient = 1 -> fixed (EEG or MEG)
    n_orient = 2 -> reduced free MEG
    n_orient = 3 -> free EEG

    Parameters
    ----------
    L : np.ndarray
        Leadfield; checked (together with `y` and `n_orient`) by
        `_validate_bmn_inputs`.
    y : np.ndarray
        Sensor data.
    noise_var : float
        Known scalar sensor-noise variance; must be strictly positive. Data and
        leadfield are whitened by ``1 / sqrt(noise_var)`` before optimization.
    n_orient : int
        Number of orientation components per source.
    max_iter : int
        Forwarded to `BMN_opt` as `maxit`.
    tol : float
        Forwarded to `BMN_opt`.
    init_gamma : float, optional
        Forwarded to `BMN_opt`.
    verbose : bool
        Forwarded to `BMN_opt`.
    normalization : bool
        When True, the leadfield is right-multiplied by the matrix returned by
        ``compute_W(L, n_orient=n_orient, beta=1e-6)`` and the estimates are
        mapped back with the same matrix afterwards.
    logger : logging.Logger, optional
        Forwarded to `BMN_opt`.
    **kwargs
        Accepted and ignored (never read in this function).

    Returns
    -------
    dict
        Keys: ``posterior_mean``, ``posterior_cov``, ``noise_var``, ``gamma``,
        ``coefficient_indices``, ``source_indices``, ``active_indices`` and,
        for ``n_orient > 1``, ``posterior_mean_reshaped`` with shape
        ``(n_sources, n_orient, n_times)``.

    Raises
    ------
    ValueError
        If `noise_var` is not strictly positive.

    Notes
    -----
    - `posterior_mean` and `posterior_cov` are returned in the original
      coefficient space.
    - `gamma` is the learned common scalar hyperparameter in the internal
      optimization parameterization, so it should be treated mainly as a
      diagnostic quantity, especially when `normalization=True`.
    """
    L, y = _validate_bmn_inputs(L=L, y=y, n_orient=n_orient)
    noise_var = float(noise_var)
    if noise_var <= 0:
        raise ValueError("noise_var must be positive.")

    # Optional sLORETA normalization. W stays None when it is not requested so
    # that no dense identity matrix has to be built or multiplied through.
    W = compute_W(L, n_orient=n_orient, beta=1e-6) if normalization else None
    L_normal = L if W is None else L @ W

    # Fixed known-noise whitening. The whitener is a scalar multiple of the
    # identity, so a scalar product is equivalent to the matrix product.
    inv_noise_std = 1.0 / np.sqrt(noise_var)
    y_white = inv_noise_std * y
    L_white = inv_noise_std * L_normal

    x_hat_normal, posterior_cov_normal, gamma = BMN_opt(
        y=y_white,
        L=L_white,
        alpha=1.0,  # after whitening, noise covariance is I
        maxit=max_iter,
        tol=tol,
        init_gamma=init_gamma,
        logger=logger,
        verbose=verbose,
    )

    # Undo normalization (a no-op when none was applied)
    if W is None:
        x_hat = x_hat_normal
        posterior_cov = posterior_cov_normal
    else:
        x_hat = W @ x_hat_normal
        posterior_cov = W @ posterior_cov_normal @ W.T
    posterior_cov = _symmetrize(posterior_cov)

    n_coeff = posterior_cov.shape[0]
    n_sources = L.shape[1] // n_orient
    out = {
        "posterior_mean": x_hat,
        "posterior_cov": posterior_cov,
        "noise_var": float(noise_var),
        "gamma": float(gamma),
        "coefficient_indices": np.arange(n_coeff),
        "source_indices": np.arange(n_sources),
        "active_indices": np.arange(n_coeff),  # backward-compat alias (coefficient-level)
    }

    # Generic reshape for n_orient = 2 or 3
    if n_orient > 1:
        out["posterior_mean_reshaped"] = x_hat.reshape(n_sources, n_orient, x_hat.shape[1])
    return out
# =============================================================================
# BMN with noise learning API
# =============================================================================
# Convex-bounding update rule for common scalar noise variance parameter
def update_common_lambda_convex(
    y: np.ndarray,
    L: np.ndarray,
    posterior_mean: np.ndarray,
    C_inv: np.ndarray,
    eps: float = 1e-12,
) -> float:
    """
    Convex-bounding update for a single common scalar noise variance
    (normalized scale).

    The returned value is alpha_new, where

        alpha = lambda / ||Y||_F^2

    Formula
    -------
    alpha = sqrt( ( ||Y - L X||_F^2 / T ) / tr(C^{-1}) )

    Parameters
    ----------
    y : np.ndarray
        2-D data array; its second dimension is T.
    L : np.ndarray
        Leadfield used to form the residual ``y - L @ posterior_mean``.
    posterior_mean : np.ndarray
        Current source estimate X.
    C_inv : np.ndarray
        Inverse model covariance; only its trace is used.
    eps : float
        Lower bound applied to the trace and to the returned value.

    Returns
    -------
    float
        Updated alpha, never smaller than `eps`.
    """
    _n_rows, n_times = y.shape
    misfit = y - L @ posterior_mean
    misfit_power = (np.linalg.norm(misfit, ord="fro") ** 2) / n_times
    trace_c_inv = np.trace(C_inv)
    ratio = max(misfit_power, 0.0) / max(trace_c_inv, eps)
    updated = np.sqrt(ratio)
    return max(float(updated), eps)
# BMN optimization with optional adaptive noise learning
def BMN_joint_opt(
    y: np.ndarray,
    L: np.ndarray,
    noise_var: Optional[float] = None,
    maxit: int = 10000,
    tol: float = 1e-6,
    init_gamma: Optional[float] = None,
    init_lambda: Optional[float] = None,
    learn_noise: bool = False,
    logger: Optional[logging.Logger] = None,
    verbose: bool = False,
    track_history: bool = True,
) -> Tuple[np.ndarray, np.ndarray, float, float, Dict[str, list]]:
    """
    BMN optimization with one common source variance gamma and, optionally,
    one common scalar sensor noise variance lambda.

    Model
    -----
    y_t ~ N(L x_t, lambda I)
    x_t ~ N(0, gamma I)

    Parameters
    ----------
    y, L : np.ndarray
        Data (1-D or 2-D) and 2-D leadfield with matching row counts. Both are
        copied and rescaled internally (Frobenius norm of y, spectral norm of L).
    noise_var : float, optional
        Fixed noise variance; required when `learn_noise` is False and
        forbidden when it is True.
    maxit, tol : int, float
        Iteration cap and relative-change threshold on gamma.
    init_gamma : float, optional
        Starting gamma (1.0 when None); floored at machine epsilon.
    init_lambda : float, optional
        Starting lambda when `learn_noise` is True (1.0 when None).
    learn_noise : bool
        Whether lambda is updated each iteration.
    logger, verbose
        A debug line is logged per iteration when `verbose` is True and a
        logger is given.
    track_history : bool
        Whether per-iteration traces are collected.

    Returns
    -------
    x_hat, posterior_cov, gamma, lambda_var, hist
        Posterior mean and covariance mapped back to the original coefficient
        scale, the final gamma, the final lambda in original units, and the
        history dict (empty when `track_history` is False).

    Notes
    -----
    - Gamma update is the same scalar BMN update rule as in BMN_bayesian_opt.
    - Lambda update is convex-bounding only.
    - Convergence is checked on gamma.
    - Returned `gamma` is the internal common scalar hyperparameter in the
      optimization coordinate system.
    """
    L = np.asarray(L, dtype=float).copy()
    y = np.asarray(y, dtype=float).copy()
    eps = np.finfo(float).eps

    gamma = float(init_gamma) if init_gamma is not None else 1.0
    gamma = max(gamma, eps)

    # --- shape checks ---------------------------------------------------
    if y.ndim == 1:
        y = y[:, np.newaxis]
    if y.ndim != 2:
        raise ValueError("y must have shape (M, T) or (M,).")
    if L.ndim != 2:
        raise ValueError("L must be 2D.")
    if L.shape[0] != y.shape[0]:
        raise ValueError("L and y must have the same number of sensor rows.")
    M, T = y.shape

    # --- scale-normalized optimization ----------------------------------
    y_original_scale = np.linalg.norm(y, ord="fro")
    L_original_scale = np.linalg.norm(L, ord=2)
    if y_original_scale < eps or L_original_scale < eps:
        raise ValueError("Degenerate input: y or L has (near) zero norm.")
    y = y / y_original_scale
    L = L / L_original_scale

    # --- resolve the starting / fixed noise variance --------------------
    if learn_noise:
        if noise_var is not None:
            raise ValueError(
                "When learn_noise=True, noise_var must be None. "
                "Use init_lambda for initialization."
            )
        lambda_var = 1.0 if init_lambda is None else float(init_lambda)
        if lambda_var <= 0:
            raise ValueError("init_lambda must be positive when learn_noise=True.")
    elif noise_var is None:
        raise ValueError("When learn_noise=False, noise_var must be provided.")
    else:
        lambda_var = float(noise_var)
        if lambda_var <= 0:
            raise ValueError("noise_var must be positive.")

    alpha = lambda_var / (y_original_scale ** 2)
    LLt = _symmetrize(L @ L.T)
    sensor_eye = np.eye(M)

    def _inverse_model_cov(gamma_val, alpha_val):
        # Inverse of gamma * L L^T + alpha * I in the normalized scale.
        return _svd_inverse(_symmetrize(gamma_val * LLt + alpha_val * sensor_eye), eps)

    hist: Dict[str, list] = {}
    if track_history:
        for hist_key in ("gamma_hist", "lambda_hist", "noise_var_hist", "err_gamma_hist"):
            hist[hist_key] = []

    for it in range(maxit):
        gamma_old, alpha_old = gamma, alpha
        model_cov_inv = _inverse_model_cov(gamma_old, alpha_old)

        # Same scalar gamma update as normal BMN
        model_cov_inv_y = model_cov_inv @ y
        numerator = np.trace(y.T @ model_cov_inv @ LLt @ model_cov_inv_y) / T
        denominator = np.trace(model_cov_inv @ LLt)
        gamma = max(float(numerator / max(denominator, eps)), eps)

        # Optional common lambda update (uses the fresh gamma, previous alpha)
        if learn_noise:
            model_cov_inv = _inverse_model_cov(gamma, alpha_old)
            A = L.T @ model_cov_inv @ y
            alpha = update_common_lambda_convex(
                y=y,
                L=L,
                posterior_mean=gamma * A,
                C_inv=model_cov_inv,
                eps=eps,
            )
        else:
            alpha = alpha_old

        err_gamma = np.abs(gamma - gamma_old) / (np.abs(gamma_old) + eps)
        lambda_curr = alpha * (y_original_scale ** 2)

        if track_history:
            hist["gamma_hist"].append(float(gamma))
            hist["lambda_hist"].append(float(lambda_curr))
            hist["noise_var_hist"].append(float(lambda_curr))
            hist["err_gamma_hist"].append(float(err_gamma))

        if verbose and logger is not None:
            logger.debug(
                f"BMN iter {it:4d}: gamma={gamma:.6e}, "
                f"lambda={lambda_curr:.6e}, err_gamma={err_gamma:.3e}"
            )

        if err_gamma < tol:
            break

    # --- posterior at the final hyperparameters --------------------------
    model_cov_inv = _inverse_model_cov(gamma, alpha)
    A = L.T @ model_cov_inv @ y
    x_est = gamma * A
    posterior_cov = _symmetrize(
        gamma * np.eye(L.shape[1]) - gamma**2 * (L.T @ model_cov_inv @ L)
    )

    # Map posterior outputs back to original coefficient scale
    scale_factor = y_original_scale / L_original_scale
    x_hat = scale_factor * x_est
    posterior_cov = _symmetrize((scale_factor ** 2) * posterior_cov)

    lambda_var = alpha * (y_original_scale ** 2)
    return x_hat, posterior_cov, float(gamma), float(lambda_var), hist
[docs]
def BMN_joint(
    L: np.ndarray,
    y: np.ndarray,
    noise_var: Optional[float] = None,
    n_orient: int = 1,
    max_iter: int = 1000,
    tol: float = 1e-6,
    init_gamma: Optional[float] = None,
    init_lambda: Optional[float] = None,
    learn_noise: bool = True,
    verbose: bool = False,
    normalization: bool = False,
    track_history: bool = True,
    logger: Optional[logging.Logger] = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    BMN estimate with optional sLORETA normalization and optional adaptive
    common-noise learning.

    Supports
    --------
    n_orient = 1 -> fixed (EEG or MEG)
    n_orient = 2 -> reduced free MEG
    n_orient = 3 -> free EEG

    Parameters
    ----------
    L, y : np.ndarray
        Leadfield and sensor data; checked by `_validate_bmn_inputs`.
    noise_var : float, optional
        Fixed noise variance. Must be None when `learn_noise` is True and a
        strictly positive value when it is False.
    n_orient : int
        Number of orientation components per source.
    max_iter, tol, init_gamma, init_lambda, learn_noise, verbose,
    track_history, logger
        Forwarded to `BMN_joint_opt` (`max_iter` as `maxit`).
    normalization : bool
        When True, the leadfield is right-multiplied by the matrix returned by
        ``compute_W(L, n_orient=n_orient, beta=1e-6)`` and the estimates are
        mapped back with the same matrix afterwards.
    **kwargs
        Accepted and ignored (never read in this function).

    Returns
    -------
    dict
        Keys: ``posterior_mean``, ``posterior_cov``, ``gamma``, ``lambda``,
        ``noise_var`` (alias of ``lambda``), ``coefficient_indices``,
        ``source_indices``, ``active_indices``; ``posterior_mean_reshaped`` for
        ``n_orient > 1``; plus the optimizer history entries when
        `track_history` is True.

    Raises
    ------
    ValueError
        On an inconsistent `learn_noise` / `noise_var` combination or a
        non-positive `noise_var`.

    Notes
    -----
    - `posterior_mean` and `posterior_cov` are returned in the original
      coefficient space.
    - `gamma` is the learned common scalar hyperparameter in the internal
      optimization parameterization, so it should be treated mainly as a
      diagnostic quantity, especially when `normalization=True`.
    """
    L, y = _validate_bmn_inputs(L=L, y=y, n_orient=n_orient)

    if learn_noise:
        if noise_var is not None:
            raise ValueError(
                "When learn_noise=True, noise_var must be None. "
                "Use init_lambda for initialization."
            )
    else:
        if noise_var is None:
            raise ValueError("When learn_noise=False, noise_var must be provided.")
        if float(noise_var) <= 0:
            raise ValueError("noise_var must be positive.")

    # Optional sLORETA normalization. W stays None when it is not requested so
    # that no dense identity matrix has to be built or multiplied through.
    W = compute_W(L, n_orient=n_orient, beta=1e-6) if normalization else None
    L_normal = L if W is None else L @ W

    x_hat_normal, posterior_cov_normal, gamma, lambda_var, hist = BMN_joint_opt(
        y=y,
        L=L_normal,
        noise_var=noise_var,
        maxit=max_iter,
        tol=tol,
        init_gamma=init_gamma,
        init_lambda=init_lambda,
        learn_noise=learn_noise,
        logger=logger,
        verbose=verbose,
        track_history=track_history,
    )

    # Undo normalization (a no-op when none was applied)
    if W is None:
        x_hat = x_hat_normal
        posterior_cov = posterior_cov_normal
    else:
        x_hat = W @ x_hat_normal
        posterior_cov = W @ posterior_cov_normal @ W.T
    posterior_cov = _symmetrize(posterior_cov)

    n_coeff = posterior_cov.shape[0]
    n_sources = L.shape[1] // n_orient
    out = {
        "posterior_mean": x_hat,
        "posterior_cov": posterior_cov,
        "gamma": float(gamma),
        "lambda": float(lambda_var),
        "noise_var": float(lambda_var),  # compatibility alias
        "coefficient_indices": np.arange(n_coeff),
        "source_indices": np.arange(n_sources),
        "active_indices": np.arange(n_coeff),  # backward-compat alias (coefficient-level)
    }

    # Generic reshape for n_orient = 2 or 3
    if n_orient > 1:
        out["posterior_mean_reshaped"] = x_hat.reshape(n_sources, n_orient, x_hat.shape[1])

    if track_history:
        out.update(hist)
    return out
# ==================
# Main Solver Class
# ==================
[docs]
class SourceEstimator(BaseEstimator, ClassifierMixin):
    """
    Thin estimator wrapper that stores a leadfield and data in `fit` and
    delegates the actual source estimation to a user-supplied solver callable.
    """

    def __init__(self, solver, solver_params=None, noise_var=None, n_orient=1, logger=None):
        """
        Initialize the SourceEstimator class.

        Parameters
        ----------
        solver : callable
            The inverse solver function (e.g., gamma_map_sflex, BMN).
        solver_params : dict, optional
            Parameters for the solver function.
        noise_var : float, optional
            Noise variance for the solver.
        logger : logging.Logger, optional
            Logger instance for logging messages.
        n_orient : int, optional
            Number of orientations for the sources. Default is 1 (for fixed
            orientation) or 3 (for free orientation).
        """
        # Follow sklearn convention: __init__ should *only* assign the passed
        # parameters to attributes without mutating them. Keep `solver_params`
        # as provided (None is allowed) so that clone can reconstruct the
        # estimator exactly. Downstream code should treat None as an empty
        # dict when invoking the solver.
        self.solver = solver
        self.solver_params = solver_params
        self.noise_var = noise_var
        self.logger = logger or logging.getLogger(__name__)
        self.n_orient = n_orient

    def _format_leadfield(self, L):
        """
        Ensure the leadfield matches the solver expectation of (n_sensors, n_sources * n_orient).

        Parameters
        ----------
        L : np.ndarray
            Leadfield array with shape (n_sensors, n_sources) for fixed-orientation
            or (n_sensors, n_sources, n_orient) for free-orientation setups.

        Returns
        -------
        np.ndarray
            A 2-D leadfield with shape (n_sensors, n_sources * n_orient).

        Raises
        ------
        ValueError
            If `L` is neither 2-D nor 3-D.
        """
        if L.ndim == 2:
            return L
        if L.ndim == 3:
            n_sensors, n_sources, n_vec = L.shape
            if self.n_orient not in (None, n_vec):
                # The format string has three placeholders, so three arguments
                # must be supplied; otherwise the record fails to format.
                self.logger.debug(
                    "Updating n_orient from %s to %s based on leadfield shape. Setting n_orient to %s.",
                    self.n_orient,
                    n_vec,
                    n_vec,
                )
            # Always adopt the orientation count of the leadfield; this is a
            # no-op when it already matches and fills in n_orient=None.
            self.n_orient = n_vec
            if n_vec not in (1, 3):
                self.logger.warning(
                    "Leadfield last dimension is %s; expected orientation components "
                    "of size 1 or 3.",
                    n_vec,
                )
            return L.reshape(n_sensors, n_sources * n_vec)
        raise ValueError(f"Leadfield must be 2-D or 3-D, got shape {L.shape}")

    def fit(self, L, y):
        """
        Fit the inverse solver to the data.

        No estimation happens here: the (reshaped) leadfield and the data are
        stored on the instance as `L_` and `y_`.

        Parameters
        ----------
        L : np.ndarray
            Leadfield matrix of shape (n_sensors, n_sources) for fixed
            orientation or (n_sensors, n_sources, n_orient) for free orientation.
        y : np.ndarray
            Observed EEG/MEG signals of shape (n_sensors, n_times).

        Returns
        -------
        self
            The fitted estimator.
        """
        self.logger.debug("Fitting the solver...")
        self.L_ = self._format_leadfield(L)
        self.y_ = y
        return self

    def _get_coef(self, y):
        """
        Internal method to compute the source estimates.

        Parameters
        ----------
        y : np.ndarray or None
            Observed EEG/MEG signals of shape (n_sensors, n_times). When None,
            the data stored by `fit` is used.

        Returns
        -------
        object
            Whatever the configured solver returns.

        Raises
        ------
        ValueError
            If `y` is None and no data was stored by `fit`.
        """
        if y is None:
            if not hasattr(self, "y_"):
                raise ValueError("No data available to compute source estimates. Fit the estimator or pass y.")
            y = self.y_

        solver_name = getattr(self.solver, "__name__", self.solver.__class__.__name__)
        self.logger.debug(f"Estimating sources using {solver_name}...")

        # Copy so the user-supplied solver_params dict is never mutated.
        solver_kwargs = dict(self.solver_params or {})
        solver_kwargs.update(
            {
                "L": self.L_,
                "y": y,
                "n_orient": self.n_orient,
                "logger": self.logger,
            }
        )
        try:
            # Try passing noise_var if the solver accepts it
            return self.solver(noise_var=self.noise_var, **solver_kwargs)
        except TypeError as err:
            if "noise_var" not in str(err):
                raise  # re-raise unexpected TypeErrors
            # fallback for solvers that do not accept noise_var argument
            # (e.g. joint learning with gamma_lambda_map_sflex())
            return self.solver(**solver_kwargs)

    def predict(self, y=None):
        """
        Run the solver on `y`, or on the data stored by `fit` when `y` is None.

        Raises
        ------
        ValueError
            If `y` is None and the estimator has not been fitted.
        """
        if y is None:
            if not hasattr(self, "y_"):
                raise ValueError("Estimator has not been fitted and no data was provided to predict().")
            y = self.y_
        return self._get_coef(y)
# =================
# Gamma-MAP with Joint Learning
# =================
def _as_2d_y(y: np.ndarray) -> np.ndarray:
y = np.asarray(y, dtype=float)
if y.ndim == 1:
y = y[:, None]
if y.ndim != 2:
raise ValueError("y must have shape (M,T) or (M,).")
return y
def _validate_inverse_inputs(
    L: np.ndarray,
    y: np.ndarray,
    n_orient: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Coerce and sanity-check the leadfield / data pair used by the inverse solvers.

    Returns
    -------
    (L, y)
        `L` as a float array and `y` as a 2-D float array.

    Raises
    ------
    ValueError
        If `L` is not 2-D, the sensor (row) counts differ, `n_orient` is not
        1, 2 or 3, or the column count of `L` is not a multiple of `n_orient`.
    """
    leadfield = np.asarray(L, dtype=float)
    data = _as_2d_y(y)

    if leadfield.ndim != 2:
        raise ValueError("L must be 2D.")

    n_rows, n_cols = leadfield.shape
    if n_rows != data.shape[0]:
        raise ValueError("L and y must have the same number of sensor rows.")

    if n_orient not in (1, 2, 3):
        raise ValueError("n_orient must be 1, 2, or 3.")

    if n_cols % n_orient != 0:
        raise ValueError(
            f"For n_orient={n_orient}, L must have k*N columns with k={n_orient}."
        )
    return leadfield, data
def _expand_grouped_parameter(
value,
n_coeff: int,
group_size: int,
name: str,
) -> np.ndarray:
"""
Accepts:
- None
- scalar
- length n_coeff
- length n_groups
and returns length n_coeff.
"""
if n_coeff % group_size != 0:
raise ValueError("n_coeff must be divisible by group_size.")
n_groups = n_coeff // group_size
if value is None:
return np.ones(n_coeff, dtype=float)
if np.isscalar(value):
return np.full(n_coeff, float(value), dtype=float)
value = np.asarray(value, dtype=float).ravel()
if value.size == n_coeff:
return value.copy()
if value.size == n_groups:
return np.repeat(value, group_size).astype(float)
if isinstance(value, tuple) and len(value) == 2:
return np.linspace(value[0], value[1], num=n_coeff).astype(float)
raise ValueError(
f"{name} must be None, scalar, length {n_coeff}, or length {n_groups}."
)
def _build_sflex_operator(B, n_orient: int) -> csr_matrix:
"""
Build coefficient-space sFLEX operator.
For flattened coefficient ordering
[source1 comp1..k, source2 comp1..k, ...],
the correct operator is:
B ⊗ I_k
"""
if issparse(B):
B = B.tocsr()
else:
B = csr_matrix(np.asarray(B, dtype=float))
if B.shape[0] != B.shape[1]:
raise ValueError("B must be square.")
if n_orient == 1:
return B
Ik = eye(n_orient, format="csr")
return kron(B, Ik, format="csr")
def _lambda_opt(
M: np.ndarray,
G_active: np.ndarray,
x_active_norm: np.ndarray,
posterior_cov_active_norm: np.ndarray,
current_lambda_norm: np.ndarray,
CMinv: np.ndarray,
update_mode_noise: int,
) -> np.ndarray:
"""
Update diagonal lambda in NORMALIZED scale.
update_mode_noise
-----------------
1 : EM-style variance update
2 : Convex-bounding style update
"""
M = np.asarray(M, dtype=float)
G_active = np.asarray(G_active, dtype=float)
x_active_norm = np.asarray(x_active_norm, dtype=float)
posterior_cov_active_norm = np.asarray(posterior_cov_active_norm, dtype=float)
current_lambda_norm = np.asarray(current_lambda_norm, dtype=float)
CMinv = np.asarray(CMinv, dtype=float)
n_sensors, T = M.shape
eps = 1e-16
lam_new = np.zeros(n_sensors, dtype=float)
if update_mode_noise == 1:
for m in range(n_sensors):
residual = M[m, :] - (G_active[m, :] @ x_active_norm)
residual_term = float(np.mean(residual**2))
g_m = G_active[m, :]
cov_term = float(g_m @ posterior_cov_active_norm @ g_m)
lam_new[m] = residual_term + cov_term
elif update_mode_noise == 2:
for m in range(n_sensors):
residual = M[m, :] - (G_active[m, :] @ x_active_norm)
numerator = float(np.mean(residual**2))
denom = float(CMinv[m, m])
if denom > eps:
lam_new[m] = np.sqrt(max(numerator, 0.0) / denom)
else:
lam_new[m] = current_lambda_norm[m]
else:
raise ValueError("update_mode_noise must be 1 or 2.")
return np.maximum(lam_new, eps)
def _gamma_lambda_map_opt(
    M: np.ndarray,
    G: np.ndarray,
    *,
    maxit: int = 300,
    tol: float = 1e-6,
    update_mode: int = 2,
    group_size: int = 1,
    init_gamma=None,
    init_lambda=None,
    learn_lambda: bool = True,
    update_mode_noise: int = 2,
    lambda_damping: float = 1.0,
    track_history: bool = True,
    verbose: bool = False,
    logger=None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict[str, list]]:
    """
    Internal optimizer for grouped gamma MAP with diagonal adaptive lambda.

    Parameters
    ----------
    M : np.ndarray
        Sensor data, (n_sensors, n_times) or (n_sensors,).
    G : np.ndarray
        Gain matrix, (n_sensors, n_coeff).
    maxit, tol
        Iteration cap and threshold on the relative L1 change of gamma.
    update_mode : int
        Gamma update rule, 1, 2 or 3.
    group_size : int
        Number of consecutive coefficients sharing one gamma.
    init_gamma
        Initial gamma; see `_expand_grouped_parameter`.
    init_lambda
        Initial per-sensor lambda (scalar or length n_sensors).
    learn_lambda : bool
        Whether lambda is updated; if False, `init_lambda` is required.
    update_mode_noise : int
        Lambda update rule passed to `_lambda_opt`.
    lambda_damping : float
        Mixing weight for the lambda update, clipped to [0, 1].
    track_history, verbose, logger
        History collection and progress logging switches.

    Returns
    -------
    x_active, active_indices, cov_out, gammas_full, lambda_final, hist
        Active-coefficient mean and covariance in the original coefficient
        scale, the indices of the active coefficients, the full gamma vector,
        lambda in original units, and the history dict. When nothing is active
        (or no iteration ran) the mean/covariance are empty and
        `active_indices` is an empty integer array.

    Important conventions
    ---------------------
    - grouped gamma updates are preserved through group_size
    - init_lambda=None means ones in the INTERNAL normalized scale
    - user-supplied init_lambda is interpreted in ORIGINAL sensor-variance units
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    G = np.asarray(G, dtype=float).copy()
    M = _as_2d_y(M).copy()
    n_coeff = G.shape[1]
    n_sensors, n_times = M.shape
    eps = np.finfo(float).eps

    if n_coeff % group_size != 0:
        raise ValueError("Number of coefficients must be divisible by group_size.")

    # ------------------------------------------------------------
    # Normalize M and G for numerical stability
    # ------------------------------------------------------------
    M_norm_c = float(np.linalg.norm(M @ M.T, ord="fro"))
    if M_norm_c <= 0:
        raise ValueError("Degenerate M.")
    G_norm_c = float(np.linalg.norm(G, ord=np.inf))
    if G_norm_c <= 0:
        raise ValueError("Degenerate G.")
    # M is divided by sqrt(M_norm_c), so variances (lambda) scale by M_norm_c.
    M /= (np.sqrt(M_norm_c) + eps)
    G /= (G_norm_c + eps)

    # ------------------------------------------------------------
    # Init gamma
    # ------------------------------------------------------------
    gammas_full_old = _expand_grouped_parameter(
        init_gamma,
        n_coeff=n_coeff,
        group_size=group_size,
        name="init_gamma",
    ).astype(float)
    gammas_full_old = np.maximum(gammas_full_old, 0.0)

    # ------------------------------------------------------------
    # Init lambda
    # ------------------------------------------------------------
    # init_lambda=None          -> ones in INTERNAL NORMALIZED scale
    # user-supplied init_lambda -> ORIGINAL scale, then normalized
    if init_lambda is None:
        if learn_lambda:
            current_lambda = np.ones(n_sensors, dtype=float)
        else:
            raise ValueError("learn_lambda=False requires init_lambda.")
    else:
        if np.isscalar(init_lambda):
            lambda_orig = np.full(n_sensors, float(init_lambda), dtype=float)
        else:
            lambda_orig = np.asarray(init_lambda, dtype=float).ravel()
            if lambda_orig.size != n_sensors:
                raise ValueError(
                    f"init_lambda must be scalar or length {n_sensors}."
                )
        current_lambda = np.maximum(lambda_orig, eps) / (M_norm_c + eps)

    # update_mode 2 divides by the square root of the denominator
    denom_fun = np.sqrt if update_mode == 2 else (lambda x: x)

    hist: Dict[str, list] = {}
    if track_history:
        hist["n_active_hist"] = []
        hist["err_gamma_hist"] = []
        hist["lambda_mean_hist"] = []

    active_indices = np.arange(n_coeff, dtype=int)
    gammas_active_new = None
    posterior_cov_active_norm = None
    A = None
    G_CMinvG = None
    last_size = -1

    for itno in range(int(maxit)):
        # Drop coefficients whose gamma is non-finite or numerically zero
        gammas_active = gammas_full_old[active_indices]
        gammas_active = np.nan_to_num(gammas_active, nan=0.0, posinf=0.0, neginf=0.0)
        keep = np.abs(gammas_active) > eps
        active_indices = active_indices[keep]
        gammas_active = gammas_active[keep]

        if active_indices.size == 0:
            break
        if active_indices.size % group_size != 0:
            raise RuntimeError(
                "Active coefficient count is not divisible by group_size. "
                "Grouped coefficients must remain together."
            )

        G_active = G[:, active_indices]

        # CM = G diag(gamma) G^T + diag(lambda)
        CM = (G_active * gammas_active[None, :]) @ G_active.T
        np.fill_diagonal(CM, CM.diagonal() + current_lambda)
        try:
            U, S, _ = linalg.svd(CM, full_matrices=False)
            CMinv = (U / (S[None, :] + eps)) @ U.T
        except linalg.LinAlgError:
            CMinv = linalg.pinv(CM)

        CMinvG = CMinv @ G_active
        A = CMinvG.T @ M  # (K_active, T)

        # --------------------------------------------------------
        # Gamma update
        # --------------------------------------------------------
        if update_mode == 1:
            numer = gammas_active**2 * np.mean((A * A.conj()).real, axis=1)
            denom = gammas_active * np.sum(G_active * CMinvG, axis=0)
        elif update_mode == 2:
            numer = gammas_active * np.sqrt(np.mean((A * A.conj()).real, axis=1))
            denom = np.sum(G_active * CMinvG, axis=0)
        elif update_mode == 3:
            denom = None
            numer = gammas_active**2 * np.mean((A * A.conj()).real, axis=1) + gammas_active * (
                1.0 - gammas_active * np.sum(G_active * CMinvG, axis=0)
            )
        else:
            raise ValueError("Invalid update_mode. Use 1, 2, or 3.")

        if group_size == 1:
            if denom is None:
                gammas_active_new = numer
            else:
                gammas_active_new = numer / np.maximum(denom_fun(denom), eps)
        else:
            # Pool numerator / denominator over each group, then share the
            # resulting gamma across the group's coefficients.
            numer_comb = np.sum(numer.reshape(-1, group_size), axis=1)
            if denom is None:
                gammas_comb = numer_comb
            else:
                denom_comb = np.sum(denom.reshape(-1, group_size), axis=1)
                gammas_comb = numer_comb / np.maximum(denom_fun(denom_comb), eps)
            gammas_active_new = np.repeat(gammas_comb / group_size, group_size)

        gammas_active_new = np.maximum(gammas_active_new, 0.0)

        # --------------------------------------------------------
        # Posterior covariance in normalized coefficient space
        # --------------------------------------------------------
        G_CMinvG = G_active.T @ CMinvG
        posterior_cov_active_norm = (
            np.diag(gammas_active_new)
            - gammas_active_new[:, None] * G_CMinvG * gammas_active_new[None, :]
        )
        posterior_cov_active_norm = _symmetrize(posterior_cov_active_norm)

        # --------------------------------------------------------
        # Lambda update
        # --------------------------------------------------------
        if learn_lambda:
            x_active_norm = gammas_active_new[:, None] * A
            lam_new = _lambda_opt(
                M=M,
                G_active=G_active,
                x_active_norm=x_active_norm,
                posterior_cov_active_norm=posterior_cov_active_norm,
                current_lambda_norm=current_lambda,
                CMinv=CMinv,
                update_mode_noise=update_mode_noise,
            )
            d = float(np.clip(lambda_damping, 0.0, 1.0))
            current_lambda = (1.0 - d) * current_lambda + d * lam_new

        gammas_full = np.zeros(n_coeff, dtype=float)
        gammas_full[active_indices] = gammas_active_new

        # Relative L1 change of the full gamma vector
        err = np.sum(np.abs(gammas_full - gammas_full_old)) / (
            np.sum(np.abs(gammas_full_old)) + eps
        )
        gammas_full_old = gammas_full

        if track_history:
            hist["n_active_hist"].append(int(active_indices.size))
            hist["err_gamma_hist"].append(float(err))
            hist["lambda_mean_hist"].append(float(np.mean(current_lambda) * M_norm_c))

        breaking = (err < tol) or (active_indices.size == 0)
        if verbose and ((active_indices.size != last_size) or breaking):
            logger.info(
                f"it={itno:3d} active={active_indices.size:4d} "
                f"err_gamma={err:.3e} "
                f"lambda_mean(orig)={np.mean(current_lambda) * M_norm_c:.3e}"
            )
            last_size = active_indices.size
        if breaking:
            break

    # ------------------------------------------------------------
    # Empty solution
    # ------------------------------------------------------------
    if gammas_active_new is None or active_indices.size == 0:
        x_active = np.zeros((0, n_times), dtype=float)
        cov_out = np.zeros((0, 0), dtype=float)
        gammas_full = np.zeros(n_coeff, dtype=float)
        lambda_final = current_lambda * M_norm_c
        # Return an empty index array so it always matches the empty
        # x_active / cov_out. When no iteration ran (maxit <= 0),
        # `active_indices` still holds every coefficient and would not match.
        no_active = np.zeros(0, dtype=int)
        return x_active, no_active, cov_out, gammas_full, lambda_final, hist

    # ------------------------------------------------------------
    # Undo normalization back to original coefficient scale
    # ------------------------------------------------------------
    n_const = np.sqrt(M_norm_c) / (G_norm_c + eps)
    x_active = n_const * gammas_active_new[:, None] * A
    cov_out = (
        np.diag(gammas_active_new)
        - gammas_active_new[:, None] * G_CMinvG * gammas_active_new[None, :]
    )
    cov_out = (n_const**2) * cov_out
    cov_out = _symmetrize(cov_out)
    lambda_final = current_lambda * M_norm_c
    return x_active, active_indices, cov_out, gammas_full_old, lambda_final, hist
def gamma_lambda_map(
    L: np.ndarray,
    y: np.ndarray,
    n_orient: int = 1,
    init_gamma=None,
    init_lambda=None,
    max_iter: int = 300,
    tol: float = 1e-6,
    update_mode: int = 2,
    learn_lambda: bool = True,
    update_mode_noise: int = 2,
    lambda_damping: float = 1.0,
    track_history: bool = True,
    verbose: bool = False,
    logger=None,
) -> Dict[str, Any]:
    """
    Grouped gamma-lambda MAP with diagonal adaptive lambda.

    Validates the inputs, runs `_gamma_lambda_map_opt` with
    ``group_size=n_orient`` and scatters the active-set results back into
    full-size arrays.

    Supports
    --------
    n_orient = 1 : fixed
    n_orient = 2 : reduced free MEG
    n_orient = 3 : free EEG

    Returns
    -------
    dict with keys:
        posterior_mean : (N,T) if n_orient=1 else (kN,T)
        posterior_mean_reshaped : (N,k,T) for k>1
        posterior_cov : full coefficient covariance, shape (kN,kN)
        posterior_cov_active : active-only covariance
        active_indices : active coefficient indices
        gammas_full : length kN
        gamma : mean(gammas_full)
        lambdas : diagonal lambda vector, original scale
        lambda_mean : mean diagonal lambda
        noise_var : alias of lambda_mean
        coefficient_indices, source_indices : index ranges
    plus the optimizer history entries when `track_history` is True.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    L, y = _validate_inverse_inputs(L=L, y=y, n_orient=n_orient)
    n_coeff = L.shape[1]
    n_sources = n_coeff // n_orient
    n_times = y.shape[1]

    optimizer_options = dict(
        maxit=max_iter,
        tol=tol,
        update_mode=update_mode,
        group_size=n_orient,
        init_gamma=init_gamma,
        init_lambda=init_lambda,
        learn_lambda=learn_lambda,
        update_mode_noise=update_mode_noise,
        lambda_damping=lambda_damping,
        track_history=track_history,
        verbose=verbose,
        logger=logger,
    )
    x_active, active_idx, cov_active, gammas_full, lambdas, hist = _gamma_lambda_map_opt(
        M=y,
        G=L,
        **optimizer_options,
    )

    # Scatter the active-set estimates into zero-initialized full-size arrays.
    x_hat = np.zeros((n_coeff, n_times), dtype=float)
    posterior_cov = np.zeros((n_coeff, n_coeff), dtype=float)
    if active_idx.size > 0:
        x_hat[active_idx] = x_active
        posterior_cov[np.ix_(active_idx, active_idx)] = cov_active
    posterior_cov = _symmetrize(posterior_cov)

    mean_gamma = float(np.mean(gammas_full)) if gammas_full.size else 0.0
    mean_lambda = float(np.mean(lambdas)) if lambdas.size else 0.0

    out = {
        "posterior_mean": x_hat if n_orient > 1 else x_hat.reshape(n_sources, y.shape[1]),
        "posterior_cov": posterior_cov,
        "posterior_cov_active": cov_active,
        "active_indices": active_idx,
        "gamma": mean_gamma,
        "gammas_full": gammas_full,
        "lambdas": np.asarray(lambdas, dtype=float),
        "lambda_mean": mean_lambda,
        "noise_var": mean_lambda,  # compatibility alias
        "coefficient_indices": np.arange(n_coeff),
        "source_indices": np.arange(n_sources),
    }

    if n_orient > 1:
        out["posterior_mean_reshaped"] = x_hat.reshape(n_sources, n_orient, y.shape[1])

    if track_history:
        out.update(hist)
    return out
[docs]
def gamma_lambda_map_sflex(
    L: np.ndarray,
    y: np.ndarray,
    n_orient: int = 1,
    init_gamma=None,
    init_lambda=None,
    learn_lambda: bool = True,
    update_mode_noise: int = 2,
    lambda_damping: float = 1.0,
    max_iter: int = 300,
    tol: float = 1e-6,
    update_mode: int = 2,
    track_history: bool = True,
    sigma: float = 0.01,
    threshold_factor: float = 3.0,
    normalize: Optional[str] = "sym",
    eps: float = 1e-12,
    verbose: bool = False,
    logger=None,
    **kwargs,
) -> Dict[str, Any]:
    """
    Unified sFLEX + gamma-lambda MAP.

    The basis B is obtained from `compute_B` (using `sigma`,
    `threshold_factor`, `normalize`, `eps` and ``kwargs.get("src_coords")``),
    the problem is solved in coefficient space with `gamma_lambda_map` on
    ``G = L (B ⊗ I_k)``, and the result is mapped back to source space.

    Source model
    ------------
    x = (B ⊗ I_k) c

    Supports
    --------
    n_orient = 1 : fixed
    n_orient = 2 : reduced free MEG
    n_orient = 3 : free EEG

    Returns
    -------
    dict with posterior quantities in SOURCE space, plus coefficient-space
    auxiliaries for debugging / analysis.

    Raises
    ------
    ValueError
        If the basis returned by `compute_B` is not (n_sources, n_sources).
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    L, y = _validate_inverse_inputs(L=L, y=y, n_orient=n_orient)
    n_coeff = L.shape[1]
    n_sources = n_coeff // n_orient
    T = y.shape[1]

    B = compute_B(
        sigma=sigma,
        threshold_factor=threshold_factor,
        normalize=normalize,
        eps=eps,
        src_coords=kwargs.get("src_coords"),
    )
    B = B.tocsr() if issparse(B) else csr_matrix(np.asarray(B, dtype=float))

    expected_shape = (n_sources, n_sources)
    if B.shape != expected_shape:
        raise ValueError(
            f"B must have shape ({n_sources},{n_sources}); got {B.shape}."
        )

    B_op = _build_sflex_operator(B, n_orient=n_orient)  # (kN, kN)
    G = L @ B_op

    res_coeff = gamma_lambda_map(
        L=G,
        y=y,
        n_orient=n_orient,
        init_gamma=init_gamma,
        init_lambda=init_lambda,
        max_iter=max_iter,
        tol=tol,
        update_mode=update_mode,
        learn_lambda=learn_lambda,
        update_mode_noise=update_mode_noise,
        lambda_damping=lambda_damping,
        track_history=track_history,
        verbose=verbose,
        logger=logger,
    )

    # Coefficient posterior mean c_hat as a flat (kN, T) array; for k == 1
    # the coefficient count equals the source count.
    c_hat_flat = np.asarray(res_coeff["posterior_mean"], dtype=float).reshape(n_coeff, T)

    # Map mean to source space: x = (B ⊗ I_k) c
    x_hat_flat = np.asarray(B_op @ c_hat_flat, dtype=float)

    # Map the active-set covariance to source space
    active = np.asarray(res_coeff["active_indices"], dtype=int)
    Sigma_c_active = np.asarray(res_coeff["posterior_cov_active"], dtype=float)
    has_active = active.size > 0
    if has_active:
        B_active_dense = B_op[:, active].toarray()
        posterior_cov_x = B_active_dense @ Sigma_c_active @ B_active_dense.T
        posterior_cov_x = _symmetrize(np.asarray(posterior_cov_x, dtype=float))
        cov_x_active = posterior_cov_x[np.ix_(active, active)]
    else:
        posterior_cov_x = np.zeros((n_coeff, n_coeff), dtype=float)
        cov_x_active = np.zeros((0, 0))

    multi_orient = n_orient > 1
    out = {
        "posterior_mean": x_hat_flat if multi_orient else x_hat_flat.reshape(n_sources, T),
        "posterior_cov": posterior_cov_x,
        "posterior_cov_active": cov_x_active,
        "active_indices": active,
        "gamma": float(res_coeff["gamma"]),
        "gammas_full": np.asarray(res_coeff["gammas_full"], dtype=float),
        "lambdas": np.asarray(res_coeff["lambdas"], dtype=float),
        "lambda_mean": float(res_coeff["lambda_mean"]),
        "noise_var": float(res_coeff["lambda_mean"]),  # compatibility alias
        # coefficient-space extras
        "posterior_mean_coeff": c_hat_flat if multi_orient else c_hat_flat.reshape(n_sources, T),
        "posterior_cov_coeff": np.asarray(res_coeff["posterior_cov"], dtype=float),
        "posterior_cov_active_coeff": np.asarray(res_coeff["posterior_cov_active"], dtype=float),
        "B_operator": B_op,
        "coefficient_indices": np.arange(n_coeff),
        "source_indices": np.arange(n_sources),
    }

    if multi_orient:
        out["posterior_mean_reshaped"] = x_hat_flat.reshape(n_sources, n_orient, T)
        out["posterior_mean_coeff_reshaped"] = c_hat_flat.reshape(n_sources, n_orient, T)

    if track_history:
        for key in ("n_active_hist", "err_gamma_hist", "lambda_mean_hist"):
            if key in res_coeff:
                out[key] = res_coeff[key]
    return out