import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import chi2
from itertools import combinations
from matplotlib.patches import Ellipse
import mne
import logging
[docs]
class UncertaintyEstimator:
def __init__(self, orientation_type, x, x_hat, active_set, posterior_cov, experiment_dir=None, logger=None):
    """
    Initialize the uncertainty estimator.

    Parameters:
    - orientation_type (str): Orientation type ('fixed' or 'free').
    - x (np.ndarray): Ground truth source activity.
    - x_hat (np.ndarray): Estimated source activity.
    - active_set (np.ndarray): Indices of active sources.
    - posterior_cov (np.ndarray): Posterior covariance matrix.
    - experiment_dir (str, optional): Directory for experiment results.
    - logger (logging.Logger, optional): Logger instance for logging messages.
      When omitted, a logger named after this module is used, because several
      methods call ``self.logger`` without checking it for ``None``.
    """
    self.orientation_type = orientation_type
    self.x = x
    self.x_hat = x_hat
    self.active_set = active_set
    self.posterior_cov = posterior_cov
    self.experiment_dir = experiment_dir
    # Never store None here: the default would otherwise crash on the first
    # self.logger.debug()/warning() call made by another method.
    self.logger = logger if logger is not None else logging.getLogger(__name__)
[docs]
def reshape_source_data(self, n_times):
    """
    Flatten the stored source arrays for free orientation.

    For ``orientation_type == "free"`` both ``self.x`` and ``self.x_hat`` are
    replaced by their ``reshape(-1, n_times)`` form. For any other orientation
    type the stored arrays are left as they are.

    Parameters:
    - n_times (int): Number of time points (size of the last axis after reshaping).

    Returns:
    - x (np.ndarray): The (possibly reshaped) ground truth source activity.
    - x_hat (np.ndarray): The (possibly reshaped) estimated source activity.
    """
    if self.orientation_type != "free":
        return self.x, self.x_hat

    flat_shape = (-1, n_times)
    self.x = self.x.reshape(flat_shape)
    self.x_hat = self.x_hat.reshape(flat_shape)
    return self.x, self.x_hat
[docs]
def construct_full_covariance(self):
    """
    Create a full covariance matrix corresponding to all source components,
    embedding the posterior covariance of the active set.

    The total number of components is derived from ``self.orientation_type``
    and the shape of ``self.x``:
    - 'fixed': ``self.x`` must be 2D; the size is ``self.x.shape[0]``.
    - 'free': ``self.x`` is either 3D with a middle axis of length 3
      (size ``shape[0] * 3``) or already flattened to 2D
      (size ``shape[0]``, which must be a multiple of 3).

    Returns:
    - full_posterior_cov (np.ndarray): Full posterior covariance matrix.
      Shape: (n_total_components, n_total_components). Entries outside the
      active set are zero.

    Raises:
    - ValueError: If ``self.x`` has an unsupported shape or ``orientation_type`` is unknown.
    - IndexError: If an active_set index or a posterior_cov index is out of range while embedding.
    - AssertionError: If ``posterior_cov.size`` is not ``active_set.size ** 2``.
    """
    logger = self.logger if getattr(self, 'logger', None) else logging.getLogger(__name__)

    # Determine the total number of source components based on orientation type and x shape
    if self.orientation_type == 'fixed':
        if self.x.ndim != 2:
            raise ValueError(f"For fixed orientation, expected self.x to be 2D (n_sources, n_times), but got shape {self.x.shape}")
        n_total_components = self.x.shape[0]  # n_sources
    elif self.orientation_type == 'free':
        if self.x.ndim == 3:  # Original shape
            if self.x.shape[1] != 3:
                raise ValueError(f"For free orientation with 3D self.x, expected shape (n_sources, 3, n_times), but got {self.x.shape}")
            n_total_components = self.x.shape[0] * self.x.shape[1]  # n_sources * 3
        elif self.x.ndim == 2:  # Already reshaped to (n_sources*3, n_times)
            if self.x.shape[0] % 3 != 0:
                raise ValueError(f"For free orientation with 2D self.x, expected shape (n_sources*3, n_times), but got {self.x.shape}")
            n_total_components = self.x.shape[0]
        else:
            raise ValueError(f"For free orientation, expected self.x to be 2D or 3D, but got shape {self.x.shape}")
    else:
        raise ValueError(f"Unknown orientation_type {self.orientation_type!r}; expected 'fixed' or 'free'.")

    # Initialize the full covariance matrix
    full_posterior_cov = np.zeros((n_total_components, n_total_components), dtype=self.posterior_cov.dtype)
    logger.debug(f"Initialized full_posterior_cov with shape {full_posterior_cov.shape}")

    # Element-wise embedding keeps the original IndexError behaviour for
    # out-of-range indices on either side of the assignment.
    try:
        for i, idx_i in enumerate(self.active_set):
            for j, idx_j in enumerate(self.active_set):
                full_posterior_cov[idx_i, idx_j] = self.posterior_cov[i, j]
        logger.debug("Successfully embedded posterior_cov into full_posterior_cov.")
    except IndexError as e:
        logger.error(f"IndexError during covariance embedding: {e}. "
                     f"i={i}, j={j}, idx_i={idx_i}, idx_j={idx_j}, "
                     f"posterior_cov shape={self.posterior_cov.shape}, active_set size={self.active_set.size}")
        raise IndexError(f"Error accessing elements during covariance embedding. Check active_set indices ({idx_i}, {idx_j}) against posterior_cov shape {self.posterior_cov.shape} using indices ({i}, {j}).") from e

    # Explicit check instead of an `assert` statement so the validation is not
    # stripped when Python runs with -O. The exception type is unchanged.
    if self.posterior_cov.size != self.active_set.size ** 2:
        detail = (f"Size of posterior_cov ({self.posterior_cov.size}) should be square of active_set size "
                  f"({self.active_set.size}^2 = {self.active_set.size**2}).")
        logger.error(f"Assertion failed: {detail}")
        raise AssertionError(f"Validation failed: {detail}")

    logger.debug(f"Constructed full covariance matrix of shape {full_posterior_cov.shape}")
    return full_posterior_cov
# ------------------------------
# ------------------------------
[docs]
def plot_sorted_posterior_variances(self, top_k=None):
    """
    Save a bar chart of the posterior variances in descending order.

    The variances are the diagonal of ``self.posterior_cov``. When ``top_k``
    is given, the first ``top_k`` bars (the largest variances) are recoloured
    orange. The figure is written to ``sorted_variances.png`` inside
    ``self.experiment_dir`` and then closed.

    Parameters:
    - top_k (int, optional): Number of leading bars to highlight.
    """
    variances = np.diag(self.posterior_cov)
    descending = np.argsort(variances)[::-1]
    ranked = variances[descending]
    # A falsy top_k (None or 0) makes the title report the total number of bars.
    title_count = top_k if top_k else len(variances)

    plt.figure(figsize=(12, 6))
    bar_container = plt.bar(range(len(ranked)), ranked, color='skyblue', edgecolor='blue')
    if top_k is not None:
        for patch in bar_container[:top_k]:
            patch.set_color('orange')

    plt.xlabel("Source Index")
    plt.ylabel("Variance")
    plt.title(f"Sorted Posterior Variances (Top-{title_count} Highlighted)")
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout(rect=[0, 0.05, 1, 0.96])

    output_path = os.path.join(self.experiment_dir, 'sorted_variances.png')
    plt.savefig(output_path)
    plt.close()
# ------------------------------
def _compute_top_covariance_pairs(self, cov, top_k=None):
    """
    Rank every unordered pair of dimensions by the magnitude of its covariance.

    Parameters:
        cov (array-like): Covariance matrix of shape (n, n).
        top_k (int, optional): Number of leading pairs to return. If None, all pairs are returned.

    Returns:
        list: Tuples ``((i, j), magnitude)`` with ``i < j``, ordered by
        ``abs(cov[i, j])`` from largest to smallest. Pairs with equal
        magnitude keep the order in which ``combinations`` produced them.
    """
    matrix = np.asarray(cov)
    dimension = matrix.shape[0]

    ranked = [
        ((row, col), np.abs(matrix[row, col]))
        for row, col in combinations(range(dimension), 2)
    ]
    # In-place stable sort, largest magnitude first.
    ranked.sort(key=lambda entry: entry[1], reverse=True)

    return ranked if top_k is None else ranked[:top_k]
[docs]
def visualize_sorted_covariances(self, top_k=None):
    """
    Save a bar chart of pairwise covariance magnitudes, largest first.

    The pairs come from ``self._compute_top_covariance_pairs`` applied to
    ``self.posterior_cov``. The figure is written to
    ``sorted_covariances.png`` inside ``self.experiment_dir`` and then closed.

    Parameters:
    - top_k (int, optional): Number of leading pairs to show. If None, all pairs are shown.
    """
    ranked = self._compute_top_covariance_pairs(self.posterior_cov, top_k=top_k)

    tick_labels = []
    bar_heights = []
    for (row, col), value in ranked:
        tick_labels.append(f"({row},{col})")
        bar_heights.append(value)
    # A falsy top_k (None or 0) makes the title report the number of plotted pairs.
    title_count = top_k if top_k else len(bar_heights)

    plt.figure(figsize=(10, 6))
    plt.bar(tick_labels, bar_heights, color='skyblue')
    plt.xlabel('Pairs of Dimensions')
    plt.ylabel('Covariance Magnitude')
    plt.title(f"Top-{title_count} Sorted Covariance Magnitudes")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout(rect=[0, 0.05, 1, 0.96])

    output_path = os.path.join(self.experiment_dir, 'sorted_covariances.png')
    plt.savefig(output_path)
    plt.close()
# ------------------------------
def _make_psd(self, cov, epsilon=1e-6):
    """
    Return a regularized copy of ``cov`` whose eigenvalues are all non-negative.

    While any eigenvalue is negative, ``epsilon`` is added to the diagonal and
    ``epsilon`` is multiplied by 10 for the next round. The input matrix is
    never modified: the work is done on a floating-point copy, so integer
    input is accepted and callers keep their original array.

    Parameters:
    - cov (array-like): Square covariance matrix.
    - epsilon (float): Initial value added to the diagonal.

    Returns:
    - np.ndarray: The regularized matrix (a new array). If the iteration
      limit is exceeded, the last attempt is returned and a warning is logged.
    """
    source = np.asarray(cov)
    # Keep float/complex dtypes as they are; promote everything else (e.g. int)
    # to float64 so the in-place diagonal update below cannot fail on casting.
    work_dtype = source.dtype if np.issubdtype(source.dtype, np.inexact) else np.float64
    regularized = source.astype(work_dtype, copy=True)

    identity = np.eye(regularized.shape[0])
    max_iterations = 100
    iterations = 0
    while not np.all(np.linalg.eigvals(regularized) >= 0):
        regularized += identity * epsilon
        epsilon *= 10
        iterations += 1
        if iterations > max_iterations:
            logger = self.logger if getattr(self, 'logger', None) else logging.getLogger(__name__)
            logger.warning("Regularizing covariance matrix...")
            logger.warning("Covariance matrix could not be made positive semi-definite.")
            break
    return regularized
def _compute_confidence_ellipse(self, mean, cov, confidence_level=0.95):
    """
    Compute the parameters of a confidence ellipse for a 2x2 covariance matrix.

    Parameters:
    - mean: array-like, shape (2,). Not used in the computation; kept so the
      signature stays parallel to ``_plot_confidence_ellipse``.
    - cov (np.ndarray): Covariance matrix of the two dimensions.
    - confidence_level (float): Probability mass the ellipse should cover.

    Returns:
    - width (float): Full length of the axis along the largest-eigenvalue direction.
    - height (float): Full length of the axis along the smallest-eigenvalue direction.
    - angle (float): Rotation of the major axis in degrees.
    """
    logger = self.logger if getattr(self, 'logger', None) else logging.getLogger(__name__)

    # Validate covariance matrix
    if np.linalg.cond(cov) > 1e10:
        logger.warning("Covariance matrix is ill-conditioned")

    # Regularize only when needed, and only then report on the outcome.
    regularized = False
    if not np.all(np.linalg.eigvals(cov) > 0):
        cov = self._make_psd(cov, epsilon=1e-6)
        regularized = True

    chi2_val = chi2.ppf(confidence_level, df=2)
    eigenvals, eigenvecs = np.linalg.eigh(cov)
    if regularized:
        if np.all(eigenvals > 0):
            logger.debug("Covariance matrix is now positive definite.")
        else:
            logger.warning("Covariance matrix is still not positive definite.")

    # Largest eigenvalue first so that width refers to the major axis.
    order = np.argsort(eigenvals)[::-1]
    eigenvals = eigenvals[order]
    eigenvecs = eigenvecs[:, order]

    # eigh can return tiny negative eigenvalues for (near-)singular input;
    # clamp them so the square root yields 0 instead of NaN.
    eigenvals = np.clip(eigenvals, 0.0, None)
    width, height = 2 * np.sqrt(eigenvals * chi2_val)
    angle = np.degrees(np.arctan2(eigenvecs[1, 0], eigenvecs[0, 0]))
    return width, height, angle
def _plot_confidence_ellipse(self, mean, width, height, angle, ax=None, **kwargs):
    """
    Draw a confidence ellipse and its centre point on an axis.

    Parameters:
    - mean: array-like, shape (2,)
        Centre of the ellipse; also plotted as a blue scatter point.
    - width: float
        Full width of the ellipse.
    - height: float
        Full height of the ellipse.
    - angle: float
        Rotation angle of the ellipse in degrees.
    - ax: matplotlib.axes.Axes, optional
        Axis to draw on. If None, a new figure and axis are created.
    - **kwargs: forwarded unchanged to matplotlib.patches.Ellipse.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Ellipse patch first, then the centre marker on top of it.
    ax.add_patch(Ellipse(xy=mean, width=width, height=height, angle=angle, **kwargs))
    ax.scatter(*mean, color='blue', label='Mean')

    ax.set_xlabel("Principal Component 1 (Variance in Dim 1)")
    ax.set_ylabel("Principal Component 2 (Variance in Dim 2)")
    ax.set_title("Confidence Ellipse (Width and Height Indicate Variance)")

    ax.grid()
    ax.legend()
[docs]
def plot_top_relevant_CE_pairs(self, top_k=5, confidence_level=0.95):
    """
    Identify the top-k relevant pairs of dimensions (based on covariance magnitude)
    and plot their confidence ellipses.

    One subplot is drawn per selected pair, up to three per row. Unused
    subplots are removed. The figure is written to
    ``top_relevant_CE_pairs.png`` inside ``self.experiment_dir``.

    Parameters:
    - top_k (int): Number of pairs to plot; must be at least 1.
    - confidence_level (float): Confidence level passed to the ellipse computation.

    Raises:
    - ValueError: If ``top_k`` is smaller than 1.
    """
    # Without this guard top_k=0 fails later with a ZeroDivisionError while
    # computing the grid layout.
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k}")

    # NOTE(review): each mean entry is used as an ellipse centre coordinate,
    # which presumably requires one value per active component - confirm
    # against the shape of x_hat supplied by the caller.
    mean = self.x_hat[self.active_set]
    cov = self.posterior_cov
    n = len(mean)

    ranked = sorted(
        ((pair, np.abs(cov[pair[0], pair[1]])) for pair in combinations(range(n), 2)),
        key=lambda entry: entry[1],
        reverse=True,
    )
    top_pairs = [pair for pair, _ in ranked[:top_k]]

    n_cols = min(3, top_k)
    n_rows = (top_k + n_cols - 1) // n_cols
    # squeeze=False always yields a 2D array of axes; with the default,
    # a 1x1 grid returns a bare Axes object that has no flatten().
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows), squeeze=False)
    axes = axes.flatten()

    for idx, (i, j) in enumerate(top_pairs):
        mean_ij = mean[[i, j]]
        cov_ij = cov[np.ix_([i, j], [i, j])]
        width, height, angle = self._compute_confidence_ellipse(mean_ij, cov_ij, confidence_level)
        self._plot_confidence_ellipse(mean_ij, width, height, angle, ax=axes[idx], edgecolor='blue', alpha=0.5)
        axes[idx].set_title(f"Dimensions {i} & {j}")

    # Drop grid cells that received no pair.
    for ax in axes[len(top_pairs):]:
        fig.delaxes(ax)

    fig.suptitle("Top Relevant Dimensional Pairs with Confidence Ellipses", fontsize=16)
    fig.tight_layout(rect=[0, 0.05, 1, 0.96])
    fig.savefig(os.path.join(self.experiment_dir, 'top_relevant_CE_pairs.png'))
    plt.close(fig)
# ------------------------------
[docs]
def plot_active_sources_single_time_step(self, time_step=0):
    """
    Plot the active sources for a single time step, comparing ground truth and estimated sources.

    For 'free' orientation three stacked subplots (X, Y, Z) are drawn and both
    ``self.x`` and ``self.x_hat`` may be 3D or 2D. For any other orientation a
    single plot is drawn and both arrays must be 2D. The figure is written to
    ``active_sources_single_time_step_<time_step>.png`` inside ``self.experiment_dir``.

    Parameters:
    - time_step (int): Index of the time point to plot.

    Raises:
    - ValueError: If ``self.x`` or ``self.x_hat`` has an unsupported number of
      dimensions, or (fixed orientation) if their first dimensions differ.
    """
    logger = self.logger if getattr(self, 'logger', None) else logging.getLogger(__name__)

    if self.orientation_type == 'free':
        if self.x_hat.ndim == 3:
            n_sources = self.x_hat.shape[0]
            x_hat_is_3d = True
        elif self.x_hat.ndim == 2:
            n_components = self.x_hat.shape[0]
            if n_components % 3 != 0:
                logger.warning(f"Free orientation: self.x_hat is 2D, but first dim ({n_components}) not divisible by 3.")
            n_sources = n_components // 3
            x_hat_is_3d = False
        else:
            raise ValueError(f"Unexpected number of dimensions for self.x_hat: {self.x_hat.ndim}")

        # --- Handle self.x shape similarly ---
        if self.x.ndim == 3:
            x_is_3d = True
        elif self.x.ndim == 2:
            x_is_3d = False
        else:
            # Reject early; otherwise x_is_3d would be unbound further down.
            raise ValueError(f"Unexpected number of dimensions for self.x: {self.x.ndim}")

        fig, axes = plt.subplots(3, 1, figsize=(12, 18), sharex=True)  # Share x-axis
        orientations = ['X', 'Y', 'Z']
        for i, ax in enumerate(axes):  # i is the orientation index (0, 1, 2)
            # --- Ground Truth ---
            if x_is_3d:
                # Find non-zero elements for this orientation at this time step
                gt_source_indices_orient = np.where(self.x[:, i, time_step] != 0)[0]
                gt_amplitudes = self.x[gt_source_indices_orient, i, time_step]
            else:  # x is 2D; flat index modulo 3 selects the orientation
                gt_indices_all = np.where(self.x[:, time_step] != 0)[0]
                gt_indices_orient_flat = gt_indices_all[gt_indices_all % 3 == i]
                gt_source_indices_orient = gt_indices_orient_flat // 3  # Source indices
                gt_amplitudes = self.x[gt_indices_orient_flat, time_step]

            # --- Estimated ---
            # Flat indices from active_set corresponding to this orientation
            active_indices_orient_flat = self.active_set[self.active_set % 3 == i]
            # Source indices derived from the flat indices
            est_source_indices_orient = active_indices_orient_flat // 3

            if x_hat_is_3d:
                # Drop source indices beyond the first axis of the 3D array
                valid_source_indices = est_source_indices_orient[est_source_indices_orient < n_sources]
                if len(valid_source_indices) < len(est_source_indices_orient):
                    logger.warning(f"Orientation {i}: Some derived source indices from active_set were out of bounds for 3D x_hat. Filtering.")
                est_amplitudes = self.x_hat[valid_source_indices, i, time_step]
                plot_est_source_indices = valid_source_indices
            else:  # x_hat is 2D
                # Drop flat indices beyond the first axis of the 2D array
                max_idx_x_hat = self.x_hat.shape[0] - 1
                valid_flat_indices = active_indices_orient_flat[active_indices_orient_flat <= max_idx_x_hat]
                if len(valid_flat_indices) < len(active_indices_orient_flat):
                    logger.warning(f"Orientation {i}: Some flat indices from active_set were out of bounds for 2D x_hat. Filtering.")
                est_amplitudes = self.x_hat[valid_flat_indices, time_step]
                plot_est_source_indices = valid_flat_indices // 3

            # --- Plotting ---
            ax.scatter(gt_source_indices_orient, gt_amplitudes, color='blue', alpha=0.6, label='Ground Truth Active')
            ax.scatter(plot_est_source_indices, est_amplitudes, color='red', marker='x', alpha=0.6, label='Estimated Active')
            ax.set_xlabel('Source Index')
            ax.set_ylabel('Amplitude')
            ax.set_title(f'Active Sources Comparison ({orientations[i]} Orientation, Time Step {time_step})')
            ax.legend(loc='best')
            ax.grid(True, alpha=0.5)
            ax.axhline(0, color='grey', linestyle='--', linewidth=0.8)

        # Add shared x-label
        fig.text(0.5, 0.04, 'Source Index', ha='center', va='center')
        plt.tight_layout(rect=[0, 0.05, 1, 0.96])
        fig.suptitle(f"Active Sources Comparison (Free Orientation, Time Step {time_step})", fontsize=16)
        save_path = os.path.join(self.experiment_dir, f'active_sources_single_time_step_{time_step}.png')
        plt.savefig(save_path)
        logger.debug(f"Saved active sources plot to {save_path}")
        plt.close(fig)
    else:  # Fixed orientation: x and x_hat must both be 2D
        if self.x.ndim != 2 or self.x_hat.ndim != 2:
            raise ValueError(f"Fixed orientation plotting expects 2D x ({self.x.shape}) and x_hat ({self.x_hat.shape})")
        if self.x.shape[0] != self.x_hat.shape[0]:
            raise ValueError(f"Shape mismatch between x ({self.x.shape}) and x_hat ({self.x_hat.shape})")

        n_sources = self.x.shape[0]
        max_index_x_hat = n_sources - 1
        gt_active_sources = np.where(self.x[:, time_step] != 0)[0]
        gt_amplitudes = self.x[gt_active_sources, time_step]

        # Filter active_set for valid indices
        valid_mask = self.active_set <= max_index_x_hat
        active_set_plot = self.active_set[valid_mask]
        if not np.all(valid_mask):
            invalid_indices = self.active_set[~valid_mask]
            logger.warning(f"Fixed Orientation: Found indices in active_set {invalid_indices.tolist()} "
                           f"that are out of bounds for x_hat (max index: {max_index_x_hat}). Filtering them out for plotting.")
        est_amplitudes = self.x_hat[active_set_plot, time_step]

        plt.figure(figsize=(12, 6))
        plt.scatter(gt_active_sources, gt_amplitudes, color='blue', alpha=0.6, label='Ground Truth Active')
        plt.scatter(active_set_plot, est_amplitudes, color='red', marker='x', alpha=0.6, label='Estimated Active')
        plt.xlabel('Source Index')
        plt.ylabel('Amplitude')
        plt.title(f'Active Sources Comparison (Fixed Orientation, Time Step {time_step})')
        plt.legend(loc='best')
        plt.grid(True, alpha=0.5)
        plt.axhline(0, color='grey', linestyle='--', linewidth=0.8)
        plt.tight_layout(rect=[0, 0.05, 1, 0.96])
        save_path = os.path.join(self.experiment_dir, f'active_sources_single_time_step_{time_step}.png')
        plt.savefig(save_path)
        logger.debug(f"Saved active sources plot to {save_path}")
        plt.close()
[docs]
def plot_posterior_covariance_matrix(self):
    """
    Save a heat map of the posterior covariance matrix.

    - 'free' orientation with a first dimension divisible by 3: three stacked
      heat maps, one per strided sub-matrix ``posterior_cov[i::3, i::3]``,
      sharing one colour scale and one colour bar.
    - 'free' orientation otherwise: a warning is logged and the whole matrix is plotted.
    - any other orientation: the whole matrix is plotted.

    The figure is written to ``posterior_covariance_matrix.png`` inside
    ``self.experiment_dir``. A failure while saving is logged, not raised,
    and the figure is closed in every case.
    """
    heatmap_style = {'cmap': 'viridis', 'aspect': 'auto'}
    is_free = self.orientation_type == 'free'

    if is_free and self.posterior_cov.shape[0] % 3 != 0:
        self.logger.warning(f"Free orientation: posterior_cov shape {self.posterior_cov.shape}, first dimension is not divisible by 3.")
        # Strided slicing per orientation is not possible, so show everything.
        fig, ax = plt.subplots(figsize=(10, 8))
        image = ax.imshow(self.posterior_cov, **heatmap_style)
        fig.colorbar(image, ax=ax, label='Covariance Value')
        ax.set_title('Posterior Covariance Matrix (Free Orientation - Full)')
        ax.set_xlabel('Active Component Index')
        ax.set_ylabel('Active Component Index')
        plt.tight_layout(rect=[0, 0.05, 1, 0.96])
    elif is_free:
        fig, axes = plt.subplots(3, 1, figsize=(10, 18))
        # One colour scale for all three panels.
        colour_min = np.min(self.posterior_cov)
        colour_max = np.max(self.posterior_cov)
        drawn_images = []
        for offset, (ax, axis_name) in enumerate(zip(axes, ['X', 'Y', 'Z'])):
            # The strided slice assumes components are interleaved as
            # [src0_x, src0_y, src0_z, src1_x, ...] - this is not checked here.
            try:
                sub_matrix = self.posterior_cov[offset::3, offset::3]
                image = ax.imshow(sub_matrix, vmin=colour_min, vmax=colour_max, **heatmap_style)
                drawn_images.append(image)
                ax.set_title(f'Diagonal Block - Orientation {axis_name}')
                ax.set_xlabel('Source Index (within orientation)')
                ax.set_ylabel('Source Index (within orientation)')
            except IndexError:
                self.logger.warning(f"Could not extract block {offset}::3 for orientation {axis_name}. Skipping subplot.")
                ax.set_title(f'Orientation {axis_name} - Error')
        plt.tight_layout(rect=[0, 0.05, 1, 0.96])
        # Single colour bar spanning all panels, driven by the first image.
        if drawn_images:
            fig.colorbar(drawn_images[0], ax=axes.ravel().tolist(), orientation='vertical', fraction=0.02, pad=0.04, label='Covariance Value')
    else:  # Fixed orientation
        fig, ax = plt.subplots(figsize=(10, 8))
        image = ax.imshow(self.posterior_cov, **heatmap_style)
        plt.colorbar(image, label='Covariance Value')
        ax.set_title('Posterior Covariance Matrix (Fixed Orientation)')
        ax.set_xlabel('Active Source Index')
        ax.set_ylabel('Active Source Index')
        plt.tight_layout(rect=[0, 0.05, 1, 0.96])

    try:
        plt.savefig(os.path.join(self.experiment_dir, 'posterior_covariance_matrix.png'))
        self.logger.info(f"Posterior covariance matrix plot saved to {self.experiment_dir}/posterior_covariance_matrix.png")
    except Exception as e:
        self.logger.error(f"Failed to save posterior covariance matrix plot: {e}")
    finally:
        plt.close(fig)
# ------------------------------
def _compute_confidence_intervals(self, mean, cov, confidence_level=0.95):
    """
    Compute confidence intervals based on the diagonal of the covariance matrix.

    The interval for each component is ``mean +/- z * std``, where ``std`` is
    the square root of the matching diagonal entry of the regularized
    covariance and ``z`` is the two-sided standard-normal quantile for
    ``confidence_level`` (about 1.96 for 0.95).

    Parameters:
    - mean (np.ndarray): Mean array for active components, shape (n_active_components, n_times).
    - cov (np.ndarray): Covariance matrix for active components, shape (n_active_components, n_active_components).
    - confidence_level (float): Confidence level for the intervals (e.g., 0.95 for 95%).

    Returns:
    - ci_lower (np.ndarray): Lower bounds of confidence intervals, shape (n_active_components, n_times).
    - ci_upper (np.ndarray): Upper bounds of confidence intervals, shape (n_active_components, n_times).
    """
    # Local import: only chi2 is imported from scipy.stats at file level.
    from scipy.stats import norm

    logger = self.logger if getattr(self, 'logger', None) else logging.getLogger(__name__)

    # Exact quantile instead of an estimate from random samples, so results
    # are deterministic and identical from run to run.
    alpha = 1.0 - confidence_level
    z = norm.ppf(1.0 - alpha / 2.0)
    logger.debug(f"Z-score for confidence level {confidence_level}: {z:.4f}")

    # Regularize a copy so the caller's covariance is never touched.
    cov_psd = self._make_psd(cov.copy())

    # Standard deviation per component, as a column for broadcasting over time:
    # (n_active_components,) -> (n_active_components, 1)
    std_dev = np.sqrt(np.diag(cov_psd))[:, np.newaxis]

    ci_lower = mean - z * std_dev
    ci_upper = mean + z * std_dev
    logger.debug(f"Computed CI shapes: lower={ci_lower.shape}, upper={ci_upper.shape}")
    return ci_lower, ci_upper
def _count_values_within_ci(self, x, ci_lower, ci_upper):
    """
    Count, per time point, how many ground truth values fall inside their confidence interval.

    All three arrays must have one row per entry of ``self.active_set``.
    Interval bounds are inclusive.

    Parameters:
    - x (np.ndarray): Ground truth activity for active components, shape (n_active_components, n_times).
    - ci_lower (np.ndarray): Lower interval bounds, same shape as ``x``.
    - ci_upper (np.ndarray): Upper interval bounds, same shape as ``x``.

    Returns:
    - count_within_ci (np.ndarray):
        - For "fixed" orientation: 1D array (n_times,) with counts per time point.
        - For "free" orientation: 2D array (3, n_times); row ``i`` counts the
          components whose ``active_set`` index satisfies ``index % 3 == i``.

    Raises:
    - ValueError: If any input's first dimension differs from ``len(self.active_set)``.
    """
    expected_rows = len(self.active_set)
    if any(arr.shape[0] != expected_rows for arr in (x, ci_lower, ci_upper)):
        raise ValueError("Input array dimensions do not match the length of the active_set.")

    n_times = x.shape[1]

    if self.orientation_type == "fixed":
        # Column-wise count over all active sources.
        inside = (x >= ci_lower) & (x <= ci_upper)
        count_within_ci = inside.sum(axis=0)
        self.logger.debug(f"Fixed orientation counts shape: {count_within_ci.shape}")
    elif self.orientation_type == "free":
        count_within_ci = np.zeros((3, n_times))
        # Orientation of each row, taken from its flat index in active_set.
        row_orientation = self.active_set % 3
        for orient in range(3):
            rows = row_orientation == orient
            x_rows = x[rows, :]
            if x_rows.size == 0:
                # No component with this orientation: the row stays zero.
                continue
            inside = (x_rows >= ci_lower[rows, :]) & (x_rows <= ci_upper[rows, :])
            count_within_ci[orient, :] = inside.sum(axis=0)
        self.logger.debug(f"Free orientation counts shape: {count_within_ci.shape}")

    return count_within_ci
def _plot_ci_times(self, x, x_hat, active_set, ci_lower, ci_upper, confidence_level, figsize=(20, 15)):
    """
    Save one confidence-interval figure per time point for the active components.

    Figures are written to ``<experiment_dir>/CI/t<t>/ci_t<t>_clvl<level>.png``.
    For 'free' orientation each figure has three stacked subplots (X, Y, Z)
    and a component with flat index ``k`` is drawn at x-position ``k // 3`` on
    subplot ``k % 3``. For any other orientation a single axis is used and
    the x-position is the ``active_set`` entry itself.

    Parameters:
    - x (np.ndarray): Ground truth activity for active components, shape (n_active_components, n_times).
    - x_hat (np.ndarray): Estimated activity for active components, shape (n_active_components, n_times).
    - active_set (np.ndarray): Original indices (flattened for free orientation) of active components.
    - ci_lower (np.ndarray): Lower interval bounds, shape (n_active_components, n_times).
    - ci_upper (np.ndarray): Upper interval bounds, shape (n_active_components, n_times).
    - confidence_level (float): Confidence level shown in titles and file names.
    - figsize (tuple): Size of each figure.
    """
    logger = getattr(self, 'logger', None) or logging.getLogger(__name__)

    # Root output folder for all time points.
    ci_root = os.path.join(self.experiment_dir, 'CI')
    os.makedirs(ci_root, exist_ok=True)
    logger.debug(f"Saving CI plots to: {ci_root}")

    n_active_components, n_times = x.shape
    if n_active_components == 0:
        logger.warning("No active components to plot for CI times.")
        return

    if self.orientation_type == "free":
        orientation_names = ['X', 'Y', 'Z']
        # Flat component index -> (source index, orientation index)
        source_of = active_set // 3
        orient_of = active_set % 3

        for t in range(n_times):
            step_dir = os.path.join(ci_root, f't{t}')
            os.makedirs(step_dir, exist_ok=True)
            fig, axes = plt.subplots(3, 1, figsize=figsize, sharex=True, sharey=True)

            # Orientations whose subplot already carries legend entries.
            labelled = set()
            for comp in range(n_active_components):
                src = source_of[comp]
                orient = orient_of[comp]
                ax = axes[orient]
                # Only the first artist of each kind per subplot gets a legend label.
                first_on_axis = orient not in labelled

                ax.scatter(src, x_hat[comp, t], marker='x', s=50, color='red',
                           label='Posterior Mean' if first_on_axis else "")
                # Interval drawn as a filled band around the source position.
                ax.fill_between(
                    [src - 2, src + 2],
                    ci_lower[comp, t],
                    ci_upper[comp, t],
                    color='green',
                    alpha=0.8,
                    label='Confidence Interval' if first_on_axis else "",
                )
                ax.scatter(src, x[comp, t], s=30, color='blue', alpha=0.7,
                           label='Ground Truth' if first_on_axis else "")
                labelled.add(orient)

            # Axis cosmetics once everything for this time step is drawn.
            tick_positions = sorted(set(source_of))
            for axis_idx, (ax, axis_name) in enumerate(zip(axes, orientation_names)):
                ax.set_title(f'Orientation {axis_name}')
                ax.axhline(0, color='grey', lw=0.8, ls='--')
                # Distinct sources drawn on this particular subplot.
                sources_here = {
                    source_of[k] for k in range(n_active_components) if orient_of[k] == axis_idx
                }
                ax.legend(title=f"Total Sources: {len(sources_here)}", loc='best')
                ax.grid(False)
                ax.set_xticks(tick_positions)
                ax.set_xticklabels([str(idx) for idx in tick_positions], rotation=45, ha='right')
                if tick_positions:
                    ax.set_xlim(min(tick_positions) - 1, max(tick_positions) + 1)

            fig.text(0.5, 0.04, 'Original Source Index', ha='center', va='center')
            fig.text(0.04, 0.5, 'Activity', va='center', rotation='vertical')
            fig.suptitle(f'Confidence Intervals (Level={confidence_level:.2f}, Time={t})', fontsize=16)
            plt.tight_layout(rect=[0.05, 0.05, 1, 0.95])

            save_path = os.path.join(step_dir, f'ci_t{t}_clvl{round(confidence_level, 2)}.png')
            plt.savefig(save_path)
            logger.debug(f"Saved CI plot: {save_path}")
            plt.close(fig)
    else:  # Fixed orientation
        source_of = active_set  # entries are used directly as source indices

        for t in range(n_times):
            step_dir = os.path.join(ci_root, f't{t}')
            os.makedirs(step_dir, exist_ok=True)
            fig, ax = plt.subplots(figsize=figsize)

            for comp in range(n_active_components):
                src = source_of[comp]
                # Only the very first component contributes legend labels.
                first_point = comp == 0

                ax.scatter(src, x_hat[comp, t], marker='x', s=50, color='red',
                           label='Posterior Mean' if first_point else "")
                ax.fill_between(
                    [src - 0.4, src + 0.4],
                    ci_lower[comp, t],
                    ci_upper[comp, t],
                    color='red',
                    alpha=0.3,
                    label='Confidence Interval' if first_point else "",
                )
                ax.scatter(src, x[comp, t], s=30, color='blue', alpha=0.7,
                           label='Ground Truth' if first_point else "")

            tick_positions = sorted(set(source_of))
            ax.set_title(f'Confidence Intervals (Level={confidence_level:.2f}, Time={t})')
            ax.axhline(0, color='grey', lw=0.8, ls='--')
            ax.legend(title=f'Total Active Sources: {n_active_components}', loc='best')
            ax.grid(False)
            ax.set_xticks(tick_positions)
            ax.set_xticklabels([str(idx) for idx in tick_positions], rotation=45, ha='right')
            ax.set_xlabel('Original Source Index')
            ax.set_ylabel('Activity')
            if tick_positions:
                ax.set_xlim(min(tick_positions) - 1, max(tick_positions) + 1)
            plt.tight_layout(rect=[0.05, 0.05, 1, 0.96])

            save_path = os.path.join(step_dir, f'ci_t{t}_clvl{round(confidence_level, 2)}.png')
            plt.savefig(save_path)
            logger.debug(f"Saved CI plot: {save_path}")
            plt.close(fig)
def _plot_proportion_of_hits(
    self,
    confidence_levels,
    CI_count_per_confidence_level,
    total_sources,
    time_point=0,
    filename='proportion_of_hits',
):
    """
    Internal method to plot the proportion of hits within confidence intervals for a specific time point.
    The observed proportion (hits / total_sources) is drawn against the nominal
    confidence level, together with the y=x reference line, and the figure is
    saved as '<experiment_dir>/<filename>.png'.
    Parameters:
    - confidence_levels (list or np.ndarray): Confidence levels to plot.
    - CI_count_per_confidence_level (np.ndarray or nested list): Counts of values within confidence intervals.
    - For "fixed": indexed as [level, time].
    - For "free": indexed as [level, orientation, time]; orientations 0, 1, 2 are labelled X, Y, Z.
    - total_sources (int): Denominator used to turn hit counts into proportions.
    - time_point (int): The specific time point to plot (index into the last axis of the counts array).
    - filename (str): Name of the file to save the plot (without extension).
    Returns:
    - None. If time_point is out of bounds or total_sources is not positive,
    an error is logged and nothing is plotted or saved.
    """
    # The constructor accepts logger=None, so never dereference self.logger blindly.
    logger = self.logger if self.logger is not None else logging.getLogger(__name__)
    counts = np.asarray(CI_count_per_confidence_level)

    # Validate once, before any figure exists, and identically for both
    # orientation types (time is the last axis in either layout).
    n_times = counts.shape[-1]
    if not -n_times <= time_point < n_times:
        logger.error(f"time_point {time_point} is out of bounds for CI_count_per_confidence_level with shape {counts.shape}")
        return
    if total_sources <= 0:
        # Dividing by zero would only produce a NaN/inf curve.
        logger.error(f"total_sources must be positive to compute proportions, got {total_sources}")
        return

    save_path = os.path.join(self.experiment_dir, filename + '.png')

    def draw_calibration_curve(ax, hits, title):
        """Draw one proportion-vs-confidence curve on a square [0, 1] axis."""
        proportions = hits / total_sources  # Normalize hits to proportions
        ax.plot(confidence_levels, proportions, marker='o', linestyle='-', color='blue', label='Proportion of Hits')
        ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='y=x')
        ax.set_ylabel('Proportion of Hits')
        ax.set_title(title)
        ax.grid(True)
        ax.set_xticks(confidence_levels)
        ax.set_xticklabels([f'{cl:.0%}' for cl in confidence_levels])  # Use percentage format
        ax.legend(loc='lower right')
        # Ensure axes are square
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, 1.05)
        ax.set_aspect('equal', adjustable='box')

    is_free = self.orientation_type == 'free'
    if is_free:
        # One stacked subplot per orientation component (X, Y, Z)
        fig, axes = plt.subplots(3, 1, figsize=(6, 18), sharex=True, sharey=True)
    else:
        fig, single_ax = plt.subplots(figsize=(6, 6))  # Square figure

    try:
        if is_free:
            for i, (ax, orientation) in enumerate(zip(axes, ('X', 'Y', 'Z'))):
                draw_calibration_curve(
                    ax,
                    counts[:, i, time_point],
                    f'Orientation {orientation} (Time Point {time_point})',
                )
            # Only the bottom subplot carries the shared x-axis label
            axes[-1].set_xlabel('Confidence Level')
            fig.suptitle(f'Proportion of Hits at Time Point {time_point} (Free Orientation)', fontsize=14)
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave space for the title
        else:
            draw_calibration_curve(
                single_ax,
                counts[:, time_point],
                f'Proportion of Hits at Time Point {time_point} (Fixed Orientation)',
            )
            single_ax.set_xlabel('Confidence Level')
            fig.tight_layout(rect=[0.05, 0.05, 1, 0.96])
        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    logger.info(f"Proportion of hits plot saved to {save_path}")
[docs]
def visualize_confidence_intervals(self, confidence_levels=None, time_point=0):
    """
    Visualize confidence intervals and save the results. Handles both fixed and free orientation.
    For every confidence level the method computes the intervals for the active
    components, counts how many ground-truth values fall inside them, and plots
    the intervals per time point. Finally it plots the proportion of hits
    against the confidence level for the requested time point.
    Parameters:
    - confidence_levels (list, optional): List of confidence levels to visualize. If None, defaults to 10 levels from 0.1 to 0.99.
    - time_point (int): Time point to visualize for the proportion of hits plot.
    Raises:
    - ValueError: If orientation_type is neither 'free' nor 'fixed', or if
    self.x does not have a supported number of dimensions.
    """
    # Fail early with a clear message; otherwise neither branch below would
    # define the working arrays and the method would die with a NameError.
    if self.orientation_type not in ('free', 'fixed'):
        raise ValueError(
            f"Unsupported orientation_type {self.orientation_type!r}; expected 'free' or 'fixed'."
        )

    # The constructor accepts logger=None; fall back for this method's own messages.
    logger = self.logger if self.logger is not None else logging.getLogger(__name__)

    if confidence_levels is None:
        confidence_levels = np.linspace(0.1, 0.99, 10)

    x_all = np.asarray(self.x)
    x_hat_all = np.asarray(self.x_hat)
    active_idx = np.asarray(self.active_set)

    # --- Prepare data based on orientation type ---
    if self.orientation_type == 'free':
        # Accept (n_sources, n_orient, n_times) as well as the already
        # flattened (n_components, n_times) layout produced by
        # reshape_source_data; both flatten to the same 2D array.
        if x_all.ndim not in (2, 3):
            raise ValueError(
                f"Free orientation expects x with 2 or 3 dimensions, got shape {x_all.shape}."
            )
        n_times = x_all.shape[-1]
        x_proc = x_all.reshape(-1, n_times)
        x_hat_proc = x_hat_all.reshape(x_proc.shape)
        logger.debug(f"Free orientation: Reshaped x to {x_proc.shape}, x_hat to {x_hat_proc.shape}")
        # Denominator for proportion plot: unique sources in the active set
        # (three consecutive component indices belong to one source).
        total_sources_for_plot = len(np.unique(active_idx // 3))
    else:  # 'fixed'
        if x_all.ndim != 2:
            raise ValueError(
                f"Fixed orientation expects x with shape (n_sources, n_times), got shape {x_all.shape}."
            )
        x_proc = x_all
        x_hat_proc = x_hat_all
        logger.debug(f"Fixed orientation: Using x {x_proc.shape}, x_hat {x_hat_proc.shape}")
        # Denominator for proportion plot: number of active sources
        total_sources_for_plot = len(active_idx)

    # Restrict ground truth and estimate to the active components
    active_x = x_proc[active_idx]
    active_x_hat = x_hat_proc[active_idx]
    logger.debug(f"Indexed data using active_set. Shapes: "
                 f"active_x={active_x.shape}, active_x_hat={active_x_hat.shape}")
    logger.debug(f"Total sources for proportion plot denominator: {total_sources_for_plot}")

    # --- Loop through confidence levels ---
    logger.info("Computing and creating figures for confidence intervals; for each confidence level and time point. This may take a while...")
    CI_count_per_confidence_level = []
    for confidence_level in confidence_levels:
        # Intervals come from the active estimates and the active-set posterior covariance
        ci_lower, ci_upper = self._compute_confidence_intervals(
            active_x_hat,
            self.posterior_cov,
            confidence_level=confidence_level
        )
        # Count ground-truth values that fall inside the intervals
        count_within_ci = self._count_values_within_ci(
            active_x,
            ci_lower,
            ci_upper
        )
        # Plot the intervals for this confidence level
        self._plot_ci_times(
            active_x,
            active_x_hat,
            self.active_set,  # Original indices of the active components
            ci_lower,
            ci_upper,
            confidence_level,
        )
        CI_count_per_confidence_level.append(count_within_ci)

    # --- Plot Proportion of Hits ---
    CI_count_per_confidence_level = np.array(CI_count_per_confidence_level)
    logger.debug(f"Shape of CI_count_per_confidence_level array: {CI_count_per_confidence_level.shape}")
    self._plot_proportion_of_hits(
        confidence_levels=confidence_levels,
        CI_count_per_confidence_level=CI_count_per_confidence_level,
        total_sources=total_sources_for_plot,  # Use the calculated denominator
        time_point=time_point,
    )
# ------------------------------
[docs]
def plot_source_estimates(self, posterior_cov, orientations, subjects_dir=None):
    """
    Plot source estimates and save the visualizations.
    Four maps are rendered on the 'fsaverage' brain: the ground truth and the
    posterior mean (both taken from column 0 of self.x / self.x_hat), the
    posterior variance (diagonal of posterior_cov) and the z-score
    (posterior mean divided by the posterior standard deviation). For each
    requested view an image is written to
    '<experiment_dir>/brain/<orientation>/<title>_<orientation>.png'.
    Parameters:
    - posterior_cov (np.ndarray): Posterior covariance matrix. Its diagonal must
    have as many entries as self.x_hat has rows for the z-score division.
    - orientations (list): List of brain views (passed to brain.show_view) to save.
    - subjects_dir (str, optional): FreeSurfer subjects directory. If None, the
    previously hard-coded MNE sample-data location is used when it exists on
    this machine; otherwise None is passed on so MNE resolves its configured
    SUBJECTS_DIR.
    Notes:
    - NOTE(review): self.vertices is read here but is not assigned in the
    constructor; it presumably has to be set on the instance beforehand - confirm.
    - NOTE(review): indexing [:, 0] assumes x and x_hat are laid out as
    (n_sources, n_times) - confirm for free orientation.
    """
    # Keep the historical default working, but do not force a user-specific
    # absolute path onto machines where it does not exist.
    legacy_subjects_dir = "/Users/orabe/0.braindata/MNE-sample-data/subjects"
    if subjects_dir is None and os.path.isdir(legacy_subjects_dir):
        subjects_dir = legacy_subjects_dir

    posterior_var = np.diag(posterior_cov)
    # abs() guards against tiny negative variances; epsilon avoids division by zero
    z_score = self.x_hat[:, 0] / (np.sqrt(np.abs(posterior_var)) + 1e-10)

    def as_stc(values):
        """Wrap a per-vertex array in a single-sample SourceEstimate."""
        return mne.SourceEstimate(values, vertices=self.vertices, tmin=0, tstep=0)

    source_estimates = [
        (as_stc(self.x[:, 0]), 'Ground Truth'),
        (as_stc(self.x_hat[:, 0]), 'Posterior Mean'),
        (as_stc(posterior_var), 'Posterior Variance'),
        (as_stc(z_score), 'Z-Score'),
    ]

    # Output directories depend only on the view, so create them once.
    view_dirs = []
    for orientation in orientations:
        orientation_dir = os.path.join(self.experiment_dir, 'brain', orientation)
        os.makedirs(orientation_dir, exist_ok=True)
        view_dirs.append((orientation, orientation_dir))

    for stc, title in source_estimates:
        file_stem = title.replace(" ", "_").lower()
        brain = stc.plot(hemi="both", subject='fsaverage', subjects_dir=subjects_dir, spacing='ico4', title=title)
        try:
            for orientation, orientation_dir in view_dirs:
                brain.show_view(orientation)
                brain.save_image(os.path.join(orientation_dir, f'{file_stem}_{orientation}.png'))
        finally:
            # Release the 3D window even if rendering or saving a view fails.
            brain.close()