import numpy as np
import pandas as pd
from scipy import linalg
from warnings import warn
from sklearn.base import BaseEstimator, RegressorMixin
from scipy import linalg
[docs]
def gamma_map(
L,
y,
noise_type="oracle",
cov=None, # covariance matrix of the noise
n_orient=1,
max_iter=1000,
tol=1e-15,
update_mode=2,
# threshold=1e-5,
gammas=None,
verbose=True,
logger=None,
):
if noise_type == "oracle":
sigma_squared = np.diag(cov)[0] # sigma_squared: noise variance = diagonal of the covariance matrix, where all diagonal elements are equal.
# NOTE - TODO: override and hardcode for now, but should be changed to use the computed covariance matrix and sigma_squared
sigma_squared = 0.01
cov = sigma_squared * np.eye(L.shape[0])
# whiten the data
whitener = linalg.inv(linalg.sqrtm(cov))
y = whitener @ y
L = whitener @ L # Note: L is already shaped into (n_sensors, n_sources * n_orient)
if gammas is None:
gammas = np.ones(L.shape[1], dtype=np.float64)
elif isinstance(gammas, (float, np.float64, int, np.int64)):
gammas = np.full((L.shape[1],), gammas, dtype=np.float64)
elif len(gammas) == 2 and isinstance(gammas, tuple):
gammas = np.linspace(gammas[0], gammas[1], num=L.shape[1])
else:
raise ValueError("gammas should be a float, a tuple of two floats, or a list of floats.")
x_hat_, active_set, posterior_cov = _gamma_map_opt(
y,
L,
sigma_squared=sigma_squared,
tol=tol,
maxit=max_iter,
gammas=gammas,
update_mode=update_mode,
group_size=n_orient,
verbose=verbose,
logger=logger,
)
x_hat = np.zeros((L.shape[1], y.shape[1]))
x_hat[active_set] = x_hat_
if n_orient > 1:
x_hat = x_hat.reshape((-1, n_orient, x_hat.shape[1]))
return x_hat, active_set, posterior_cov
def _gamma_map_opt(
M,
G,
sigma_squared,
maxit=10000,
tol=1e-6,
update_mode=2,
group_size=1,
gammas=None,
verbose=None,
logger=None,
):
"""Hierarchical Bayes (Gamma-MAP).
Parameters
----------
M : array, shape=(n_sensors, n_times)
Observation.
G : array, shape=(n_sensors, n_sources)
Forward operator.
sigma_squared : float
Regularization parameter (noise variance).
maxit : int
Maximum number of iterations.
tol : float
Tolerance parameter for convergence.
group_size : int
Number of consecutive sources which use the same gamma.
update_mode : int
Update mode, 1: MacKay update (default), 3: Modified MacKay update.
gammas : array, shape=(n_sources,)
Initial values for posterior variances (gammas). If None, a
variance of 1.0 is used.
%(verbose)s
Returns
-------
X : array, shape=(n_active, n_times)
Estimated source time courses.
active_set : array, shape=(n_active,)
Indices of active sources.
posterior_cov: array, shape=(n_active, n_active)
Posterior coveriance matrix of estimated active sources
"""
G = G.copy()
M = M.copy()
n_sources = G.shape[1]
n_sensors, n_times = M.shape
eps = np.finfo(float).eps
# apply normalization so the numerical values are sane
M_normalize_constant = np.linalg.norm(np.dot(M, M.T), ord="fro")
M /= np.sqrt(M_normalize_constant)
sigma_squared /= M_normalize_constant
G_normalize_constant = np.linalg.norm(G, ord=np.inf)
G /= G_normalize_constant
if n_sources % group_size != 0:
raise ValueError(
"Number of sources has to be evenly dividable by the " "group size"
)
n_active = n_sources
active_set = np.arange(n_sources)
gammas_full_old = gammas.copy()
if update_mode == 2:
denom_fun = np.sqrt
else:
# do nothing
def denom_fun(x):
return x
last_size = -1
for itno in range(maxit):
gammas[np.isnan(gammas)] = 0.0
gidx = np.abs(gammas) > eps
active_set = active_set[gidx]
gammas = gammas[gidx]
# update only active gammas (once set to zero it stays at zero)
if n_active > len(active_set):
n_active = active_set.size
G = G[:, gidx]
CM = np.dot(G * gammas[np.newaxis, :], G.T)
CM.flat[:: n_sensors + 1] += sigma_squared
# Invert CM keeping symmetry
U, S, _ = linalg.svd(CM, full_matrices=False)
S = S[np.newaxis, :]
del CM
CMinv = np.dot(U / (S + eps), U.T)
CMinvG = np.dot(CMinv, G)
A = np.dot(CMinvG.T, M) # mult. w. Diag(gamma) in gamma update
# G_CMinvG = G.T @ CMinvG
if update_mode == 1:
# MacKay fixed point update (10) in [1]
numer = gammas ** 2 * np.mean((A * A.conj()).real, axis=1)
denom = gammas * np.sum(G * CMinvG, axis=0)
elif update_mode == 2:
# modified MacKay fixed point update (11) in [1]
numer = gammas * np.sqrt(np.mean((A * A.conj()).real, axis=1))
denom = np.sum(G * CMinvG, axis=0) # sqrt is applied below
elif update_mode == 3:
# Expectation Maximization (EM) update
denom = None
numer = gammas ** 2 * np.mean((A * A.conj()).real, axis=1) + gammas * (
1 - gammas * np.sum(G * CMinvG, axis=0)
)
else:
raise ValueError("Invalid value for update_mode")
if group_size == 1:
if denom is None:
gammas = numer
else:
gammas = numer / np.maximum(denom_fun(denom), np.finfo("float").eps)
else:
numer_comb = np.sum(numer.reshape(-1, group_size), axis=1)
if denom is None:
gammas_comb = numer_comb
else:
denom_comb = np.sum(denom.reshape(-1, group_size), axis=1)
gammas_comb = numer_comb / denom_fun(denom_comb)
gammas = np.repeat(gammas_comb / group_size, group_size)
# compute convergence criterion
gammas_full = np.zeros(n_sources, dtype=np.float64)
gammas_full[active_set] = gammas
err = np.sum(np.abs(gammas_full - gammas_full_old)) / np.sum(
np.abs(gammas_full_old)
)
gammas_full_old = gammas_full
breaking = err < tol or n_active == 0
if len(gammas) != last_size or breaking:
# logger.info(
# "Iteration: %d\t active set size: %d\t convergence: "
# "%0.3e" % (itno, len(gammas), err)
# )
last_size = len(gammas)
if breaking:
break
if itno < maxit - 1:
logger.info(
"Iteration: %d\t active set size: %d\t convergence: "
"%0.3e" % (itno, len(gammas), err)
)
logger.info("\nConvergence reached !\n")
else:
logger.info(
"Iteration: %d\t active set size: %d\t convergence: "
"%0.3e" % (itno, len(gammas), err)
)
warn("\nConvergence NOT reached !\n")
# undo normalization and compute final posterior mean
n_const = np.sqrt(M_normalize_constant) / G_normalize_constant
x_active = n_const * gammas[:, None] * A
# Compute the posterior convariance matrix as in eq. (2.10) in Hashemi, Ali. "Advances in hierarchical Bayesian learning with applications to neuroimaging." (2023).
# pos_cov = np.diag(gammas) - gammas[:, np.newaxis] * G_CMinvG * gammas
posterior_cov = np.diag(gammas) - gammas[:, np.newaxis] * G.T @ CMinv @ G * gammas
# A similar approach can be implmented (as Large_gamma is interpreted as adiagonal matrix with small_gammas:
# posterior_cov = np.diag(gammas) - np.diag(gammas) @ G.T @ CMinv @ G @ np.diag(gammas)
return x_active, active_set, posterior_cov
[docs]
def eloreta(L, y, **kwargs):
raise NotImplementedError("The eloreta solver is not yet implemented.")
[docs]
class SourceEstimator(BaseEstimator, RegressorMixin):
def __init__(self, solver, solver_params=None, cov=None, n_orient=1, logger=None):
"""
Initialize the SourceEstimator class.
Parameters:
- solver (callable): The inverse solver function (e.g., gamma_map, eloreta).
- solver_params (dict, optional): Parameters for the solver function.
- logger (logging.Logger, optional): Logger instance for logging messages.
- cov (np.ndarray, optional): Covariance matrix of the noise.
- n_orient (int, optional): Number of orientations for the sources.
Default is 1 (for fixed orientation) or 3 (for free orientation).
"""
self.solver = solver
self.solver_params = solver_params if solver_params else {}
self.logger = logger
self.cov = cov
self.n_orient = n_orient
[docs]
def fit(self, L, y):
"""
Fit the inverse solver to the data.
Parameters:
- L (np.ndarray): Leadfield matrix of shape (n_sensors, n_sources).
- y (np.ndarray): Observed EEG/MEG signals of shape (n_sensors, n_times).
Returns:
- self: The fitted estimator.
"""
self.L_ = L
self.y_ = y
return self
[docs]
def predict(self, y=None):
"""
Predict the source activity given the observed signals.
Parameters:
- y (np.ndarray, optional): Observed EEG/MEG signals of shape (n_sensors, n_times).
If None, uses the signals provided during `fit`.
Returns:
- x_hat (np.ndarray): Estimated source activity of shape (n_sources, n_times).
- active_set (np.ndarray): Indices of active sources.
- posterior_cov (np.ndarray): Posterior covariance matrix of estimated sources.
"""
if not hasattr(self, "L_") or not hasattr(self, "y_"):
raise ValueError("The estimator must be fitted with `fit(L, y)` before calling `predict()`.")
# enable the use to pass y for inference
if y is None:
y = self.y_
# Apply the solver
x_hat, active_set, posterior_cov = self.solver(self.L_, y, logger=self.logger, **self.solver_params, cov=self.cov, n_orient=self.n_orient)
return x_hat, active_set, posterior_cov