import numpy as np
import pandas as pd
import logging # Import logging
import numpy as np
from mne.io.constants import FIFF
from scipy.spatial.distance import cdist
from mne import read_forward_solution, convert_forward_solution
from ot import emd2 # Earth Mover's Distance (Wasserstein-2)
from mne.inverse_sparse.mxne_inverse import _make_sparse_stc
from sklearn.metrics import jaccard_score, mean_squared_error, f1_score
[docs]
class MetricEvaluator:
    """Evaluate calibration and source-reconstruction metrics selected by name.

    The names listed in ``metrics`` are resolved as attributes of the instance
    when the metrics are evaluated.
    """

    def __init__(self, confidence_levels : np.ndarray = None, metrics : list[str] = None, logger : logging.Logger = None):
        """
        Initialize the MetricEvaluator with confidence levels, metrics, and a logger.

        Parameters
        ----------
        confidence_levels : np.ndarray, optional
            Array of confidence levels to evaluate metrics against. Stored
            unchanged (``None`` stays ``None``).
        metrics : list[str], optional
            List of metric names (method names) to evaluate. ``None`` is
            replaced by an empty list.
        logger : logging.Logger, optional
            Logger instance for logging debug and error messages. When omitted,
            a logger named after this module is used so that later logging
            calls never operate on ``None``.
        """
        self.confidence_levels = confidence_levels
        self.metrics = metrics if metrics is not None else []
        # A missing logger used to be stored as None, which made every later
        # ``self.logger.<level>(...)`` call raise AttributeError.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
# Calibration curve metrics
[docs]
def mean_calibration_error(self, empirical_coverage, **kwargs):
    """Return the spacing-weighted sum of absolute calibration deviations.

    Each absolute gap ``|empirical_coverage - self.confidence_levels|`` is
    weighted by the distance between its confidence level and the previous
    one. The first level has no predecessor and receives a weight of zero.

    Parameters
    ----------
    empirical_coverage : np.ndarray
        Empirical coverage values, compared element-wise with
        ``self.confidence_levels``.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        The weighted sum of absolute deviations (AUC deviation).
    """
    levels = self.confidence_levels
    # Prepending the first level itself makes the first spacing exactly zero.
    spacing = np.diff(levels, prepend=levels[0])
    gaps = np.abs(empirical_coverage - levels)
    return np.sum(gaps * spacing)
[docs]
def max_underconfidence_deviation(self, empirical_coverage, **kwargs):
    """Return the largest value of ``empirical_coverage - confidence_levels``.

    Parameters
    ----------
    empirical_coverage : np.ndarray
        Empirical coverage values, compared element-wise with
        ``self.confidence_levels``.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        The maximum signed deviation. It is positive when at least one
        coverage value exceeds its confidence level.
    """
    return np.max(empirical_coverage - self.confidence_levels)
[docs]
def max_overconfidence_deviation(self, empirical_coverage, **kwargs):
    """Return the negated minimum of ``empirical_coverage - confidence_levels``.

    Because of the sign flip, the result is positive when at least one
    coverage value falls below its confidence level.

    Parameters
    ----------
    empirical_coverage : np.ndarray
        Empirical coverage values, compared element-wise with
        ``self.confidence_levels``.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        Minus the smallest signed deviation.
    """
    signed_gap = empirical_coverage - self.confidence_levels
    lowest = np.min(signed_gap)
    # TODO: check whether we need the minus here!
    return -lowest
[docs]
def mean_absolute_deviation(self, empirical_coverage, **kwargs):
    """Return the mean of ``|empirical_coverage - confidence_levels|``.

    Parameters
    ----------
    empirical_coverage : np.ndarray
        Empirical coverage values, compared element-wise with
        ``self.confidence_levels``.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        The mean absolute deviation value.
    """
    gaps = np.absolute(empirical_coverage - self.confidence_levels)
    return np.mean(gaps)
[docs]
def mean_signed_deviation(self, empirical_coverage, **kwargs):
    """Return the mean of ``empirical_coverage - confidence_levels``.

    Positive and negative deviations cancel each other in this average.

    Parameters
    ----------
    empirical_coverage : np.ndarray
        Empirical coverage values, compared element-wise with
        ``self.confidence_levels``.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        The mean signed deviation value.
    """
    signed_gap = empirical_coverage - self.confidence_levels
    return np.mean(signed_gap)
[docs]
def mean_posterior_std(self, cov, **kwargs):
    """Return the mean of the square roots of the diagonal of ``cov``.

    Parameters
    ----------
    cov : np.ndarray
        Covariance matrix for uncertainty metrics; only its diagonal is used.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        The mean posterior standard deviation.
    """
    variances = np.diag(cov)
    # The average runs over every diagonal entry; no active-source mask is
    # applied here.
    return np.mean(np.sqrt(variances))
[docs]
def emd(self, x, x_hat, orientation_type, subject, fwd_path, **kwargs):
    """
    Compute Earth Mover's Distance (EMD) between true and estimated source activations.

    Adapted from BSI-ZOO.

    Per-source amplitudes are obtained with L2 norms, the non-zero ones are
    normalised to sum to one, and the transport cost between two sources is
    the Euclidean distance between their positions taken from the forward
    model's source space.

    Parameters
    ----------
    x : np.ndarray, (n_sources, n_times) or (n_sources, 3, n_times)
        Ground truth source time courses.
    x_hat : np.ndarray, same shape as x
        Estimated source time courses.
    orientation_type : str
        'fixed' or 'free' for orientation modeling.
    subject : str
        Subject ID used to locate the forward model.
    fwd_path : str
        Directory holding ``{subject}-fwd.fif``.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        Earth Mover's Distance between normalized source distributions.

    Raises
    ------
    ValueError
        If ``orientation_type`` is neither 'fixed' nor 'free'.
    """
    # Validate before touching the disk: an unsupported value used to fall
    # through both branches and fail later with an unbound-variable error.
    if orientation_type not in ("fixed", "free"):
        raise ValueError(f"Unknown orientation_type: {orientation_type}")

    def _source_amplitudes(activity):
        # For free orientation, axis 2 is reduced first, then axis 1.
        if orientation_type == "free":
            activity = np.linalg.norm(activity, axis=2)
        return np.linalg.norm(activity, axis=1)

    amplitudes_true = _source_amplitudes(x)
    a_mask = amplitudes_true != 0
    a = amplitudes_true[a_mask]

    amplitudes_est = _source_amplitudes(x_hat)
    b_mask = amplitudes_est != 0
    b = amplitudes_est[b_mask]

    # Load the forward solution and extract source locations. The forward
    # model is always converted to fixed orientation here.
    fwd = read_forward_solution(f"{fwd_path}/{subject}-fwd.fif")
    fwd = convert_forward_solution(fwd, force_fixed=True)
    src = fwd["src"]

    # The sparse source estimates are only used to recover which left/right
    # hemisphere vertices are active.
    stc_a = _make_sparse_stc(a[:, None], a_mask, fwd, tmin=1, tstep=1)
    stc_b = _make_sparse_stc(b[:, None], b_mask, fwd, tmin=1, tstep=1)

    rr_a = np.r_[src[0]["rr"][stc_a.lh_vertno], src[1]["rr"][stc_a.rh_vertno]]
    rr_b = np.r_[src[0]["rr"][stc_b.lh_vertno], src[1]["rr"][stc_b.rh_vertno]]
    M = cdist(rr_a, rr_b, metric="euclidean")

    # Normalize a and b as EMD is defined between probability distributions.
    # NOTE(review): an input without any non-zero source is not handled
    # specially here; confirm what callers expect in that case.
    a = a / a.sum()
    b = b / b.sum()
    return emd2(a, b, M)
[docs]
def jaccard_error(self, x, x_hat, orientation_type=None, **kwargs):
    """
    TODO: To be checked!
    Calculate Jaccard error between true and estimated active source sets.

    A source counts as active when its amplitude (L2 norm over the non-source
    axes) exceeds ``1e-10``.

    Parameters
    ----------
    x : np.ndarray
        True source activations
    x_hat : np.ndarray
        Estimated source activations
    orientation_type : str, optional
        'fixed' or 'free' for orientation modeling. When omitted, it is
        inferred from ``x``: a 3-D array is treated as 'free', anything else
        as 'fixed'.
    **kwargs : dict
        Additional arguments; accepted but not used by this metric.

    Returns
    -------
    float
        Jaccard error (1 - Jaccard index) between active source sets

    Raises
    ------
    ValueError
        If ``orientation_type`` is given but is neither 'fixed' nor 'free'.
    """
    # The default of None used to match neither branch below and crashed with
    # an unbound-variable error; derive the orientation from the input rank.
    if orientation_type is None:
        orientation_type = "free" if np.ndim(x) == 3 else "fixed"

    # Convert continuous activations to binary (active/inactive)
    if orientation_type == "fixed":
        x_binary = (np.linalg.norm(x, axis=1) > 1e-10).astype(int)
        x_hat_binary = (np.linalg.norm(x_hat, axis=1) > 1e-10).astype(int)
    elif orientation_type == "free":
        x_binary = (np.linalg.norm(np.linalg.norm(x, axis=2), axis=1) > 1e-10).astype(int)
        x_hat_binary = (np.linalg.norm(np.linalg.norm(x_hat, axis=2), axis=1) > 1e-10).astype(int)
    else:
        raise ValueError(f"Unknown orientation_type: {orientation_type}")

    # Compute Jaccard score for binary arrays
    jaccard_score_value = jaccard_score(x_binary, x_hat_binary, average='binary')
    return 1 - jaccard_score_value  # Convert to error (lower is better)
[docs]
def mse(self, x, x_hat, orientation_type, **kwargs):
    """Return the mean squared error between ``x`` and ``x_hat``.

    For ``orientation_type == "free"`` both inputs are first reduced with an
    L2 norm along axis 2; for any other value they are compared as given.

    Parameters
    ----------
    x : np.ndarray
        True source activations.
    x_hat : np.ndarray
        Estimated source activations.
    orientation_type : str
        Orientation model; only the value 'free' changes the computation.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        Value returned by ``mean_squared_error`` for the prepared arrays.
    """
    reduce_last_axis = orientation_type == "free"
    truth = np.linalg.norm(x, axis=2) if reduce_last_axis else x
    estimate = np.linalg.norm(x_hat, axis=2) if reduce_last_axis else x_hat
    return mean_squared_error(truth, estimate)
def _get_active_nnz(self, x, x_hat, orientation_type, subject, fwd_path, nnz):
    """Build sparse source estimates for the truth and the ``nnz`` strongest estimates.

    Adapted from BSI-ZOO.

    Parameters
    ----------
    x : np.ndarray
        True source activations.
    x_hat : np.ndarray
        Estimated source activations.
    orientation_type : str
        'fixed' or 'free' for orientation modeling.
    subject : str
        Subject ID used to locate the forward model.
    fwd_path : str
        Directory holding ``{subject}-fwd.fif``.
    nnz : int
        Number of largest estimated amplitudes to keep. If it is not smaller
        than the number of sources, every non-zero estimate is kept.

    Returns
    -------
    tuple
        ``(stc, stc_hat, active_set, active_set_hat, fwd)``: the sparse source
        estimates for truth and estimate, the boolean masks used to build
        them, and the converted forward solution.

    Raises
    ------
    ValueError
        If ``orientation_type`` is neither 'fixed' nor 'free'.
    """
    # Validate before reading the forward file: an unsupported value used to
    # reach the return statement with unbound names.
    if orientation_type not in ("fixed", "free"):
        raise ValueError(f"Unknown orientation_type: {orientation_type}")

    def _strongest_mask(amplitudes):
        # Mask of the sources whose amplitude equals one of the ``nnz``
        # largest non-zero amplitudes.
        if nnz >= amplitudes.size:
            # np.partition rejects an out-of-range kth; keep everything.
            strongest = amplitudes
        else:
            strongest = -np.partition(-amplitudes, nnz)[:nnz]
        # Drop zeros in case fewer than nnz sources were estimated.
        strongest = strongest[strongest != 0.0]
        return np.isin(amplitudes, strongest)

    fwd = read_forward_solution(f"{fwd_path}/{subject}-fwd.fif")

    if orientation_type == "fixed":
        fwd = convert_forward_solution(fwd, force_fixed=True)
        active_set = np.linalg.norm(x, axis=1) != 0

        amplitudes_hat = np.linalg.norm(x_hat, axis=1)
        # A single distinct amplitude means no source stands out.
        if len(np.unique(amplitudes_hat)) == 1:
            print("No vertices estimated!")
        active_set_hat = _strongest_mask(amplitudes_hat)
    else:
        fwd = convert_forward_solution(fwd)
        # Here the truth mask is 2-D: one flag per (source, orientation).
        active_set = np.linalg.norm(x, axis=2) != 0

        amplitudes_hat = np.linalg.norm(np.linalg.norm(x_hat, axis=2), axis=1)
        active_set_hat = _strongest_mask(amplitudes_hat)
        # Expand the per-source mask to three columns per source.
        active_set_hat = np.repeat(active_set_hat, 3).reshape(
            active_set_hat.shape[0], -1
        )

    stc = _make_sparse_stc(
        x[active_set], active_set, fwd, tmin=1, tstep=1
    )  # ground truth
    stc_hat = _make_sparse_stc(
        x_hat[active_set_hat], active_set_hat, fwd, tmin=1, tstep=1
    )  # estimate
    return stc, stc_hat, active_set, active_set_hat, fwd
[docs]
def euclidean_distance(self, x, x_hat, orientation_type, subject, nnz, fwd_path, **kwargs):
    """Return the mean Euclidean distance between true and estimated source positions.

    Adapted from BSI-ZOO.

    Positions are read from the forward model's source space (left hemisphere
    first, then right) for the vertices of the sparse source estimates
    returned by ``self._get_active_nnz``. True and estimated positions are
    paired by their order in those arrays.

    Parameters
    ----------
    x : np.ndarray
        True source activations.
    x_hat : np.ndarray
        Estimated source activations.
    orientation_type : str
        'fixed' or 'free' for orientation modeling.
    subject : str
        Subject ID used to locate the forward model.
    nnz : int
        Number of strongest estimated sources to consider.
    fwd_path : str
        Directory holding the forward model.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        Mean of the pairwise distances.
    """
    stc, stc_hat, _, _, fwd = self._get_active_nnz(x, x_hat, orientation_type, subject, fwd_path, nnz)
    left_rr = fwd["src"][0]["rr"]
    right_rr = fwd["src"][1]["rr"]

    true_positions = np.concatenate(
        [left_rr[stc.lh_vertno], right_rr[stc.rh_vertno]], axis=0
    )
    est_positions = np.concatenate(
        [left_rr[stc_hat.lh_vertno], right_rr[stc_hat.rh_vertno]], axis=0
    )

    # Only as many true positions as there are estimated ones are compared.
    # NOTE(review): when there are more estimated than true positions the
    # subtraction relies on NumPy broadcasting -- confirm that is intended.
    n_estimated = est_positions.shape[0]
    offsets = true_positions[:n_estimated, :] - est_positions
    return np.mean(np.linalg.norm(offsets, axis=1))
[docs]
def f1(self, x, x_hat, orientation_type, **kwargs):
    """Return the F1 score between true and estimated active source sets.

    Adapted from BSI-ZOO.

    A source is active when its amplitude (L2 norm over the non-source axes)
    is non-zero.

    Parameters
    ----------
    x : np.ndarray
        True source activations.
    x_hat : np.ndarray
        Estimated source activations.
    orientation_type : str
        'fixed' or 'free' for orientation modeling.
    kwargs : dict
        Extra keyword arguments; accepted but not used by this metric.

    Returns
    -------
    float
        Value returned by ``f1_score`` for the two boolean activity masks.

    Raises
    ------
    ValueError
        If ``orientation_type`` is neither 'fixed' nor 'free'.
    """
    if orientation_type == "fixed":
        active_set = np.linalg.norm(x, axis=1) != 0
        active_set_hat = np.linalg.norm(x_hat, axis=1) != 0
    elif orientation_type == "free":
        temp = np.linalg.norm(x, axis=2)
        active_set = np.linalg.norm(temp, axis=1) != 0
        temp = np.linalg.norm(x_hat, axis=2)
        active_set_hat = np.linalg.norm(temp, axis=1) != 0
    else:
        # Without this branch an unsupported value failed below with an
        # unbound-variable error instead of a clear message.
        raise ValueError(f"Unknown orientation_type: {orientation_type}")

    return f1_score(active_set, active_set_hat)
[docs]
def accuracy(self, x, x_hat, orientation_type, **kwargs):
    """
    Calculate accuracy between true and estimated active source sets.

    Accuracy = (TP + TN) / (TP + TN + FP + FN), where a source is active
    when its amplitude (L2 norm over the non-source axes) exceeds ``1e-10``:

    - TP: sources active in both truth and estimate
    - TN: sources inactive in both
    - FP: sources active only in the estimate
    - FN: sources active only in the truth

    Parameters
    ----------
    x : np.ndarray
        True source activations (n_sources, n_times) or (n_sources, 3, n_times)
    x_hat : np.ndarray
        Estimated source activations (same shape as x)
    orientation_type : str
        'fixed' or 'free' for orientation modeling
    **kwargs : dict
        Additional arguments; accepted but not used by this metric.

    Returns
    -------
    float
        Accuracy score between 0.0 and 1.0 (higher is better); 1.0 when
        there are no sources at all.

    Raises
    ------
    ValueError
        If ``orientation_type`` is neither 'fixed' nor 'free'.
    """
    if orientation_type not in ("fixed", "free"):
        raise ValueError(f"Unknown orientation_type: {orientation_type}")

    def _is_active(activity):
        # Boolean activity flag per source.
        if orientation_type == "free":
            activity = np.linalg.norm(activity, axis=2)
        return np.linalg.norm(activity, axis=1) > 1e-10

    truth = _is_active(x)
    guess = _is_active(x_hat)

    # Confusion-matrix counts from the boolean masks.
    hits = np.sum(truth & guess)
    correct_rejections = np.sum(~truth & ~guess)
    false_alarms = np.sum(~truth & guess)
    misses = np.sum(truth & ~guess)

    total = hits + correct_rejections + false_alarms + misses
    if total == 0:
        return 1.0  # Perfect accuracy when no sources exist
    return (hits + correct_rejections) / total
# Evaluate and store metrics
[docs]
def evaluate_and_store_metrics(self, current_results_dict : dict, metric_suffix="", **kwargs):
    """Evaluate metrics and update the results dictionary.

    Every name in ``self.metrics`` is looked up as an attribute of the
    instance and called with ``**kwargs``. The return value is stored under
    ``"<name><metric_suffix>"``. If the attribute is missing, not callable,
    or raises, a message is stored under ``"<name><metric_suffix>_error"``
    instead and evaluation continues with the next metric.

    Parameters
    ----------
    current_results_dict : dict
        Dictionary to store the results of the metrics; updated in place.
    metric_suffix : str, optional
        Suffix to add to metric keys (e.g., "_all_sources", "_active_indices").
    kwargs : dict
        Additional keyword arguments that may be needed for metric calculations:
        - 'empirical_coverage': np.ndarray, empirical coverage values.
        - 'cov': np.ndarray, covariance matrix for uncertainty metrics.

    Returns
    -------
    None
    """
    # The constructor default for the logger is None; fall back to a module
    # logger so a missing logger cannot abort the evaluation.
    logger = getattr(self, "logger", None)
    if logger is None:
        logger = logging.getLogger(__name__)

    if not self.metrics:  # Handles if self.metrics is an empty list
        logger.info(f"No metrics to call for suffix '{metric_suffix}' (self.metrics is empty).")
        return

    logger.debug(
        f"Evaluating metrics with suffix: '{metric_suffix}' from self.metrics: {self.metrics} "
        f"for instance of {type(self).__name__}"
    )

    for metric_name_str in self.metrics:
        metric_output = {}  # Initialize for each metric
        # The broad except is deliberate: one failing metric must not stop
        # the remaining ones; the failure is logged and recorded instead.
        try:
            if not hasattr(self, metric_name_str):
                logger.error(
                    f"Metric method '{metric_name_str}' not found in {type(self).__name__} "
                    f"(suffix: '{metric_suffix}'). Skipping."
                )
                metric_output = {f"{metric_name_str}{metric_suffix}_error": "Method not found"}
            else:
                method = getattr(self, metric_name_str)
                if callable(method):
                    logger.debug(f"Calling metric method: {metric_name_str} with suffix '{metric_suffix}'")
                    # Call the metric method with kwargs, which should contain necessary parameters
                    result = method(**kwargs)
                    # Wrap scalar outputs into a dict
                    metric_output = {f"{metric_name_str}{metric_suffix}": result}
                else:
                    logger.error(
                        f"Attribute '{metric_name_str}' found in {type(self).__name__} but it is not callable "
                        f"(suffix: '{metric_suffix}'). Skipping."
                    )
                    metric_output = {f"{metric_name_str}{metric_suffix}_error": "Attribute not callable"}
        except Exception as e:
            logger.error(
                f"Unexpected error evaluating metric method {metric_name_str} (suffix: '{metric_suffix}') "
                f"on '{type(self).__name__}': {e}", exc_info=True
            )
            metric_output = {f"{metric_name_str}{metric_suffix}_error": f"Execution error: {str(e)}"}
        current_results_dict.update(metric_output)