# Source code for deepSTRF.metrics.performance

"""Performance metrics for deepSTRF: NaN-aware, eval-only, reduction over neurons.

See ``docs/_source/md/metrics_paradigm.md`` for the design rationale, shape
conventions, NaN handling, and per-metric formulas.
"""

from __future__ import annotations

import itertools
import math
from typing import Optional, Tuple

import numpy as np
import scipy.signal
import torch

from deepSTRF.metrics._masking import (
    collapse_to_psth_if_needed,
    reduce_over_neurons,
    resolve_mask,
)

# Shared numerical floor. Denominators / sums of squares below it are treated
# as zero (the metric returns NaN instead of dividing), and ``snr`` uses it as
# the lower clamp on the noise power.
_EPS: float = 1e-12


# -----------------------------------------------------------------------------
# Shape and helper utilities
# -----------------------------------------------------------------------------


def _check_pred_gt(pred: torch.Tensor, gt: torch.Tensor) -> None:
    """Require ``pred`` and ``gt`` to share one 4-D ``(B, N, R, T)`` shape.

    Raises
    ------
    ValueError
        If the two shapes differ (checked first), or if the shared shape is
        not 4-D.
    """
    pred_shape = tuple(pred.shape)
    gt_shape = tuple(gt.shape)
    if pred_shape != gt_shape:
        raise ValueError(
            f"pred shape {pred_shape} must equal gt shape {gt_shape}"
        )
    n_dims = pred.dim()
    if n_dims != 4:
        raise ValueError(
            f"expected 4-D tensors (B, N, R, T), got {n_dims}"
        )


def _check_responses(responses: torch.Tensor) -> None:
    """Raise ``ValueError`` unless ``responses`` is 4-D ``(B, N, R, T)``."""
    n_dims = responses.dim()
    if n_dims == 4:
        return
    raise ValueError(
        f"expected responses with 4 dims (B, N, R, T), got {n_dims}"
    )


def _pearson_1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pearson correlation of two 1-D tensors (same length expected).

    Returns a 0-d tensor. When the square root of the product of the centred
    sums of squares falls below ``_EPS`` -- i.e. either input is (near-)
    constant -- a NaN 0-d tensor with ``x``'s dtype/device is returned.
    """
    dx = x - x.mean()
    dy = y - y.mean()
    scale = torch.sqrt((dx * dx).sum() * (dy * dy).sum())
    if scale.item() < _EPS:
        return x.new_full((), float("nan"))
    covariation = (dx * dy).sum()
    return covariation / scale


# -----------------------------------------------------------------------------
# Public: prediction-vs-PSTH metrics
# -----------------------------------------------------------------------------


@torch.no_grad()
def corrcoef(
    pred: torch.Tensor,
    gt: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Per-neuron Pearson correlation between ``pred`` and the target PSTH.

    For each neuron, the valid ``(B, T)`` positions are flattened into a single
    series and correlated. See ``metrics_paradigm.md`` §6.3.

    Parameters
    ----------
    pred : torch.Tensor
        Model prediction, shape ``(B, N, 1, T)``.
    gt : torch.Tensor
        Either a PSTH ``(B, N, 1, T)`` or raw responses ``(B, N, R, T)`` with
        ``R > 1``; raw responses are first reduced to a PSTH via
        ``nanmean(dim=2, keepdim=True)``. NaN entries are allowed.
    mask : torch.Tensor, optional
        Boolean tensor broadcastable to ``(B, N, 1, T)``. When omitted, the
        mask is ``~gt.isnan()``; when given, it REPLACES that NaN-derived mask.
    reduction : {'none', 'mean', 'sum'}, default 'mean'
        How to reduce over the neuron axis.

    Returns
    -------
    torch.Tensor
        ``(N,)`` for ``reduction='none'``, otherwise a scalar. A neuron with
        fewer than 2 valid positions, or with (near-)zero variance in either
        series, gives NaN.
    """
    gt = collapse_to_psth_if_needed(gt)
    _check_pred_gt(pred, gt)
    valid = resolve_mask(gt, mask)
    nan = pred.new_full((), float("nan"))

    def _neuron_cc(idx: int) -> torch.Tensor:
        keep = valid[:, idx, ...]
        if int(keep.sum().item()) < 2:
            return nan
        return _pearson_1d(pred[:, idx, ...][keep], gt[:, idx, ...][keep])

    per_neuron = [_neuron_cc(idx) for idx in range(pred.shape[1])]
    return reduce_over_neurons(torch.stack(per_neuron), reduction)
@torch.no_grad()
def fve(
    pred: torch.Tensor,
    gt: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Per-neuron fraction of variance explained (R²) over flattened valid positions.

    ``FVE_n = 1 - SS_res / SS_tot`` with ``SS_res = sum((pred - gt) ** 2)`` and
    ``SS_tot = sum((gt - mean(gt)) ** 2)``; the denominator depends on ``gt``
    alone. The score is negative when ``pred`` does worse than a constant
    prediction at the mean of ``gt``.

    Parameters
    ----------
    pred : torch.Tensor
        Model prediction, shape ``(B, N, 1, T)``.
    gt : torch.Tensor
        Either a PSTH ``(B, N, 1, T)`` or raw responses ``(B, N, R, T)`` with
        ``R > 1``; raw responses are first reduced to a PSTH via
        ``nanmean(dim=2, keepdim=True)``. NaN entries are allowed.
    mask : torch.Tensor, optional
        Boolean tensor broadcastable to ``(B, N, 1, T)``. When omitted, the
        mask is ``~gt.isnan()``; when given, it REPLACES that NaN-derived mask.
    reduction : {'none', 'mean', 'sum'}, default 'mean'
        How to reduce over the neuron axis.

    Returns
    -------
    torch.Tensor
        ``(N,)`` for ``reduction='none'``, otherwise a scalar. A neuron with
        fewer than 2 valid positions, or with ``SS_tot < _EPS``, gives NaN.
    """
    gt = collapse_to_psth_if_needed(gt)
    _check_pred_gt(pred, gt)
    valid = resolve_mask(gt, mask)
    nan = pred.new_full((), float("nan"))

    def _neuron_fve(idx: int) -> torch.Tensor:
        keep = valid[:, idx, ...]
        if int(keep.sum().item()) < 2:
            return nan
        target = gt[:, idx, ...][keep]
        estimate = pred[:, idx, ...][keep]
        ss_res = ((estimate - target) ** 2).sum()
        ss_tot = ((target - target.mean()) ** 2).sum()
        if ss_tot.item() < _EPS:
            return nan
        return 1.0 - ss_res / ss_tot

    scores = [_neuron_fve(idx) for idx in range(pred.shape[1])]
    return reduce_over_neurons(torch.stack(scores), reduction)
# -----------------------------------------------------------------------------
# Public: Sahani–Linden signal/noise/SNR (responses-only)
# -----------------------------------------------------------------------------


def _per_stim_sp_np(
    responses_b: torch.Tensor,
    valid_b: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Per-stim SP, NP, and valid-time-count, all shape ``(N,)``.

    Stim-level building block of :func:`_sahani_linden_per_neuron`. It is
    exposed so that callers can accumulate length-weighted partial sums one
    stimulus at a time, without pre-building an ``(S, N, R_max, T_max)``
    padded tensor (see :meth:`NeuralDataset.compute_neuron_quality`).

    For each neuron, the repeats with at least one valid bin and the time
    bins valid in at least one repeat define an ``(R_v, T_v)`` sub-block. On
    that sub-block the code computes:

        TP = mean_r var_t(r)          (unbiased variances)
        SP = (R_v · var_t(PSTH) - TP) / (R_v - 1)
        NP = TP - SP

    Parameters
    ----------
    responses_b : torch.Tensor
        Per-stim responses, shape ``(N, R, T)``. Invalid entries are flagged via
        ``valid_b`` (NaN sentinels or an explicit mask).
    valid_b : torch.Tensor
        Bool mask of the same shape as ``responses_b``.

    Returns
    -------
    sp_b, np_b : torch.Tensor
        ``(N,)`` per-neuron SP / NP for this stim. ``NaN`` for neurons that do
        not qualify: ``R_v < 2``, ``T_v < 2``, or a non-finite estimate. The
        last case happens, for example, when a ragged NaN pattern leaves a
        NaN inside the sub-block.
    Tv_b : torch.Tensor
        ``(N,)`` per-neuron valid-time-bin count, used as the weight in
        length-weighted aggregation. ``0`` for every non-qualifying neuron, so
        a NaN estimate never carries weight.
    """
    if responses_b.dim() != 3:
        raise ValueError(
            f"_per_stim_sp_np expects (N, R, T), got {tuple(responses_b.shape)}"
        )
    N, _, _ = responses_b.shape
    nan = responses_b.new_full((), float("nan"))
    zero = responses_b.new_zeros(())
    sp_out = []
    np_out = []
    tv_out = []
    for n in range(N):
        v_n = valid_b[n]  # (R, T)
        valid_repeats = v_n.any(dim=-1)  # (R,) repeat has >= 1 valid bin
        valid_time = v_n.any(dim=0)  # (T,) bin valid in >= 1 repeat
        R_v = int(valid_repeats.sum().item())
        T_v = int(valid_time.sum().item())
        if R_v < 2 or T_v < 2:
            sp_out.append(nan)
            np_out.append(nan)
            tv_out.append(zero)
            continue
        sub = responses_b[n][valid_repeats][:, valid_time]  # (R_v, T_v)
        psth = sub.mean(dim=0)  # (T_v,)
        var_psth = psth.var(unbiased=True)  # scalar
        tp = sub.var(dim=-1, unbiased=True).mean()  # scalar
        sp = (R_v * var_psth - tp) / (R_v - 1)
        if not (bool(torch.isfinite(sp)) and bool(torch.isfinite(tp))):
            # The any/any selection only drops fully-invalid rows/columns, so a
            # ragged NaN can survive in `sub`. Treat such a stim as
            # non-qualifying (weight 0) rather than letting NaN * T_v poison
            # the length-weighted sums downstream.
            sp_out.append(nan)
            np_out.append(nan)
            tv_out.append(zero)
            continue
        sp_out.append(sp)
        np_out.append(tp - sp)
        tv_out.append(responses_b.new_tensor(float(T_v)))
    return torch.stack(sp_out), torch.stack(np_out), torch.stack(tv_out)


def _sahani_linden_per_neuron(
    responses: torch.Tensor,
    valid: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (SP, NP) per neuron, both shape ``(N,)``.

    Per-stimulus computation, then **length-weighted** average across stims:

        SP_n = sum_b ( T_b · SP_{b,n} ) / sum_b T_b

    where ``T_b`` is the number of valid time bins for stim ``b``. This matches
    the natural sample-count weighting of ``cov`` and ``var`` over the
    concatenated ``(b, t)`` time series in ``corrcoef`` / ``normalized_corrcoef``.
    Long stims therefore drive the SP estimate in proportion to their length
    (see ``metrics_paradigm.md`` §6.5 for the estimator rationale).

    A ``(b, n)`` cell is included iff it has ≥ 2 valid repeats, ≥ 2 valid time
    bins, and finite SP/NP estimates. Neurons with no qualifying stim get NaN.

    Thin wrapper over :func:`_per_stim_sp_np`: it loops over the batch axis and
    accumulates per-neuron length-weighted partial sums. Callers handling large
    datasets should stream through :func:`_per_stim_sp_np` directly, to avoid
    pre-padding ``(S, N, R_max, T_max)`` in memory.

    Parameters
    ----------
    responses : torch.Tensor
        Responses of shape ``(B, N, R, T)``.
    valid : torch.Tensor
        Bool mask indexed like ``responses`` (``valid[b]`` is ``(N, R, T)``).
    """
    B, N, _, _ = responses.shape
    nan = responses.new_full((), float("nan"))
    zero = responses.new_zeros(())
    sum_w = responses.new_zeros(N)
    sum_w_sp = responses.new_zeros(N)
    sum_w_np = responses.new_zeros(N)
    for b in range(B):
        sp_b, np_b, tv_b = _per_stim_sp_np(responses[b], valid[b])
        # Only finite, positively-weighted estimates contribute. The explicit
        # `where` keeps NaN placeholders out of the sums (0 * NaN = NaN).
        contrib = (tv_b > 0) & torch.isfinite(sp_b) & torch.isfinite(np_b)
        if not bool(contrib.any()):
            continue
        sum_w = sum_w + torch.where(contrib, tv_b, zero)
        sum_w_sp = sum_w_sp + torch.where(contrib, tv_b * sp_b, zero)
        sum_w_np = sum_w_np + torch.where(contrib, tv_b * np_b, zero)
    qualifying = sum_w > 0
    denom = sum_w.clamp(min=1)
    sp_out = torch.where(qualifying, sum_w_sp / denom, nan.expand(N))
    np_out = torch.where(qualifying, sum_w_np / denom, nan.expand(N))
    return sp_out, np_out
@torch.no_grad()
def signal_power(
    responses: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Per-neuron Sahani–Linden signal power (SP).

    SP is estimated for each stimulus using that stimulus's own valid repeat
    count, then averaged across stimuli with weights equal to the number of
    valid time bins. A neuron needs at least one stimulus with ≥ 2 valid
    repeats and ≥ 2 valid time bins; otherwise it is NaN.

    Parameters
    ----------
    responses : torch.Tensor
        Raw responses, shape ``(B, N, R, T)``; NaN marks missing repeats /
        time bins.
    mask : torch.Tensor, optional
        Boolean tensor broadcastable to ``responses``. When omitted, the mask
        is ``~responses.isnan()``; when given, it REPLACES that mask.
    reduction : {'none', 'mean', 'sum'}, default 'mean'
        How to reduce over the neuron axis.

    Returns
    -------
    torch.Tensor
        ``(N,)`` for ``reduction='none'``, otherwise a scalar.
    """
    _check_responses(responses)
    signal, _noise = _sahani_linden_per_neuron(
        responses, resolve_mask(responses, mask)
    )
    return reduce_over_neurons(signal, reduction)
@torch.no_grad()
def noise_power(
    responses: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Per-neuron Sahani–Linden noise power, ``NP = TP - SP``.

    Estimated per stimulus and length-weighted across stimuli, exactly like
    :func:`signal_power`.

    Parameters
    ----------
    responses : torch.Tensor
        Raw responses, shape ``(B, N, R, T)``; NaN marks missing repeats /
        time bins.
    mask : torch.Tensor, optional
        Boolean tensor broadcastable to ``responses``. When omitted, the mask
        is ``~responses.isnan()``; when given, it REPLACES that mask.
    reduction : {'none', 'mean', 'sum'}, default 'mean'
        How to reduce over the neuron axis.

    Returns
    -------
    torch.Tensor
        ``(N,)`` for ``reduction='none'``, otherwise a scalar.
    """
    _check_responses(responses)
    _signal, noise = _sahani_linden_per_neuron(
        responses, resolve_mask(responses, mask)
    )
    return reduce_over_neurons(noise, reduction)
@torch.no_grad()
def snr(
    responses: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Signal-to-noise ratio per neuron (``SNR = SP / NP``, Sahani–Linden).

    Parameters
    ----------
    responses : torch.Tensor
        Raw responses of shape ``(B, N, R, T)``. NaN marks missing repeats /
        time bins.
    mask : torch.Tensor, optional
        Bool tensor broadcastable to ``responses``. If None, defaults to
        ``~responses.isnan()``; if provided, REPLACES the NaN-derived mask.
    reduction : {'none', 'mean', 'sum'}, default 'mean'
        Reduction over the neuron axis.

    Returns
    -------
    torch.Tensor
        Shape ``(N,)`` if ``reduction='none'``, otherwise a scalar. The noise
        power is floored at ``_EPS`` (1e-12) before dividing.
        Noiseless cells (``NP <= 1e-12``) with positive ``SP`` yield
        ``+inf``; the caller decides whether to filter them. Noiseless cells
        with ``SP <= 0`` keep the floored ratio (``0`` when ``SP == 0``).
        Cells without a qualifying stimulus are NaN.
    """
    _check_responses(responses)
    valid = resolve_mask(responses, mask)
    sp, np_ = _sahani_linden_per_neuron(responses, valid)
    out = sp / np_.clamp(min=_EPS)
    # The clamp alone turns "noiseless" into a huge-but-finite SP / 1e-12;
    # report +inf as documented. NaN comparisons are False, so NaN cells stay NaN.
    noiseless = (np_ <= _EPS) & (sp > 0)
    out = torch.where(noiseless, torch.full_like(out, float("inf")), out)
    return reduce_over_neurons(out, reduction)
# -----------------------------------------------------------------------------
# Public: noise-corrected prediction correlation
# -----------------------------------------------------------------------------


def _per_stim_ccmax(
    responses_b: torch.Tensor,
    valid_b: torch.Tensor,
    max_iters: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-stim CCmax + valid-time-count, both shape ``(N,)``.

    Stim-level building block of :func:`_ccmax_per_neuron`. It returns
    ``T_v = 0`` for every neuron that does not contribute on this stim, so
    each skip below becomes zero weight in the streaming aggregator:

    * the structural skip (``R_v < 2`` or ``T_v < 2``);
    * no half-split evaluated (``max_iters <= 0``);
    * the "too noisy / undefined" skip: ``ρ_half ≤ 0``, or ``ρ_half``
      non-finite. The latter happens, e.g., when a half-average has zero
      variance because the neuron is silent on this stim.

    With an odd number of valid repeats, the last one is dropped. Half-splits
    are enumerated deterministically in lexicographic order (not sampled).
    Each half is paired with its complement, and at most
    ``min(C(R_v, R_v/2) / 2, max_iters)`` pairs are generated, lazily.

    Parameters
    ----------
    responses_b : torch.Tensor
        Per-stim responses, shape ``(N, R, T)``.
    valid_b : torch.Tensor
        Bool mask of the same shape as ``responses_b``.
    max_iters : int
        Cap on the number of complementary half-split pairs per neuron.

    Returns
    -------
    ccmax_b : torch.Tensor
        ``(N,)`` per-neuron CCmax = ``sqrt(2ρ / (1 + ρ))`` for this stim;
        ``NaN`` for non-contributing neurons.
    Tv_b : torch.Tensor
        ``(N,)`` per-neuron valid-time-bin count, used as the streaming weight.
        ``0`` for non-contributing neurons.
    """
    if responses_b.dim() != 3:
        raise ValueError(
            f"_per_stim_ccmax expects (N, R, T), got {tuple(responses_b.shape)}"
        )
    N, _, _ = responses_b.shape
    nan = responses_b.new_full((), float("nan"))
    zero = responses_b.new_zeros(())
    cc_out = []
    tv_out = []
    for n in range(N):
        v_n = valid_b[n]
        valid_repeats = v_n.any(dim=-1)
        valid_time = v_n.any(dim=0)
        R_v = int(valid_repeats.sum().item())
        T_v = int(valid_time.sum().item())
        if R_v < 2 or T_v < 2:
            cc_out.append(nan)
            tv_out.append(zero)
            continue
        sub = responses_b[n][valid_repeats][:, valid_time]
        if R_v % 2 == 1:
            R_v -= 1
            sub = sub[:R_v]
        half = R_v // 2
        n_iters = max(0, min(math.comb(R_v, half) // 2, max_iters))
        cc_halfs = []
        # In lexicographic order the i-th half and the (-1-i)-th half are
        # complements. Building the complement directly therefore reproduces
        # that pairing without materialising all C(R_v, R_v/2) combinations.
        for first_idx in itertools.islice(
            itertools.combinations(range(R_v), half), n_iters
        ):
            second_idx = [r for r in range(R_v) if r not in first_idx]
            first = sub[list(first_idx)].mean(dim=0)
            second = sub[second_idx].mean(dim=0)
            cc_halfs.append(_pearson_1d(first, second))
        if not cc_halfs:
            cc_out.append(nan)
            tv_out.append(zero)
            continue
        rho_half = torch.stack(cc_halfs).mean()
        rho = rho_half.item()
        if not math.isfinite(rho) or rho <= 0:
            # Too noisy (or undefined) to estimate a ceiling: drop this stim
            # from the weighted average instead of contributing NaN * T_v.
            cc_out.append(nan)
            tv_out.append(zero)
            continue
        cc_out.append(torch.sqrt(2 * rho_half / (1 + rho_half)))
        tv_out.append(responses_b.new_tensor(float(T_v)))
    return torch.stack(cc_out), torch.stack(tv_out)


def _ccmax_per_neuron(
    responses: torch.Tensor,
    valid: torch.Tensor,
    max_iters: int,
) -> torch.Tensor:
    """Per-neuron CCmax (Hsu / Spearman-Brown), length-weighted across stims.

    Same length-weighting convention as ``_sahani_linden_per_neuron``. Each
    per-stim CCmax_{b,n} is weighted by ``T_b``, its valid time count, so long
    stims drive the ceiling estimate more than short ones. This matches the
    natural sample weighting of ``corrcoef`` over the concatenated time axis.
    Per-stim cells with ``ρ_half ≤ 0`` or a non-finite ``ρ_half`` are dropped
    (zero weight). Neurons with no contributing stim get NaN.

    Thin wrapper over :func:`_per_stim_ccmax`; callers with large datasets can
    stream through that helper directly.
    """
    B, N, _, _ = responses.shape
    nan = responses.new_full((), float("nan"))
    zero = responses.new_zeros(())
    sum_w = responses.new_zeros(N)
    sum_w_cc = responses.new_zeros(N)
    for b in range(B):
        cc_b, tv_b = _per_stim_ccmax(responses[b], valid[b], max_iters=max_iters)
        # Explicit `where` keeps NaN placeholders out of the sums (0 * NaN = NaN).
        contrib = (tv_b > 0) & torch.isfinite(cc_b)
        if not bool(contrib.any()):
            continue
        sum_w = sum_w + torch.where(contrib, tv_b, zero)
        sum_w_cc = sum_w_cc + torch.where(contrib, tv_b * cc_b, zero)
    qualifying = sum_w > 0
    return torch.where(qualifying, sum_w_cc / sum_w.clamp(min=1), nan.expand(N))
@torch.no_grad()
def normalized_corrcoef(
    pred: torch.Tensor,
    responses: torch.Tensor,
    method: str = "schoppe",
    mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
    ccmax_iters: int = 126,
) -> torch.Tensor:
    """Noise-corrected correlation coefficient per neuron.

    See ``metrics_paradigm.md`` §6.4.

    Parameters
    ----------
    pred : torch.Tensor
        Prediction of shape ``(B, N, 1, T)`` (the model-output convention; the
        R-axis must be 1).
    responses : torch.Tensor
        Raw responses ``(B, N, R, T)``. The PSTH is formed internally via
        ``nanmean(dim=2, keepdim=True)``; the noise ceiling uses the full
        repeat axis. May contain NaN.
    method : {'schoppe', 'hsu'}, default 'schoppe'
        ``'schoppe'`` divides the (Bessel-corrected) covariance by
        ``sqrt(var(pred) · SP)`` (Schoppe et al. 2016). ``'hsu'`` divides the
        raw Pearson by ``CCmax`` (Hsu / Spearman-Brown).
    mask : torch.Tensor, optional
        Bool tensor broadcastable to the PSTH ``(B, N, 1, T)``. If None,
        defaults to ``~psth.isnan()``; if provided, REPLACES the NaN-derived
        mask. It gates the prediction/PSTH positions only. SP and CCmax are
        always estimated with the NaN-derived mask of ``responses``.
    reduction : {'none', 'mean', 'sum'}, default 'mean'
        Reduction over the neuron axis.
    ccmax_iters : int, default 126
        Cap on the number of complementary half-split pairs per
        ``(stim, neuron)`` for the ``'hsu'`` CCmax estimate. Splits are
        enumerated deterministically in lexicographic order, not sampled.
        The default 126 = C(10, 5) / 2 covers every split for up to 10 (or 11)
        valid repeats.

    Returns
    -------
    torch.Tensor
        Shape ``(N,)`` if ``reduction='none'``, otherwise a scalar. In the
        single-trial degenerate case (the R-axis of ``responses`` has size 1),
        the noise correction is undefined and the raw
        ``corrcoef(pred, psth)`` is returned.

    Raises
    ------
    ValueError
        If the input shapes are inconsistent, or if ``method`` is not
        ``'schoppe'`` or ``'hsu'``. The method check also applies in the
        single-trial case.
    """
    if pred.dim() != 4 or responses.dim() != 4:
        raise ValueError(
            f"pred and responses must be 4-D, got {pred.dim()} and "
            f"{responses.dim()}"
        )
    if pred.shape[2] != 1:
        raise ValueError(
            f"pred R-axis must be 1 (model output convention), got "
            f"{pred.shape[2]}"
        )
    if (
        pred.shape[0] != responses.shape[0]
        or pred.shape[1] != responses.shape[1]
        or pred.shape[3] != responses.shape[3]
    ):
        raise ValueError(
            f"pred and responses must agree on B, N, T axes; got "
            f"pred {tuple(pred.shape)} vs responses {tuple(responses.shape)}"
        )
    # Validate before the single-trial shortcut below, which previously
    # returned early and silently accepted any `method` string.
    if method not in ("schoppe", "hsu"):
        raise ValueError(
            f"method must be 'schoppe' or 'hsu', got {method!r}"
        )

    psth = torch.nanmean(responses, dim=2, keepdim=True)  # (B, N, 1, T)
    valid_psth = resolve_mask(psth, mask)
    valid_resp = resolve_mask(responses, None)

    # Single-trial degenerate case: the noise correction is undefined.
    if responses.shape[2] == 1:
        return corrcoef(pred, psth, mask=mask, reduction=reduction)

    if method == "schoppe":
        sp, _ = _sahani_linden_per_neuron(responses, valid_resp)  # (N,)
        N = pred.shape[1]
        nan = pred.new_full((), float("nan"))
        out = []
        for n in range(N):
            v_n = valid_psth[:, n, ...]
            if int(v_n.sum().item()) < 2:
                out.append(nan)
                continue
            p_n = pred[:, n, ...][v_n]
            g_n = psth[:, n, ...][v_n]
            T_eff = p_n.numel()
            # Bessel-correct cov to match var(unbiased=True) and SP.
            cov_pg = ((p_n - p_n.mean()) * (g_n - g_n.mean())).sum() / max(T_eff - 1, 1)
            var_p = p_n.var(unbiased=True)
            sp_n = sp[n]
            if not torch.isfinite(sp_n) or sp_n.item() <= 0 or var_p.item() < _EPS:
                out.append(nan)
                continue
            out.append(cov_pg / torch.sqrt(var_p * sp_n))
        return reduce_over_neurons(torch.stack(out), reduction)

    # method == "hsu" (validated above)
    ccmax = _ccmax_per_neuron(responses, valid_resp, ccmax_iters)  # (N,)
    cc_raw = corrcoef(pred, psth, mask=mask, reduction="none")  # (N,)
    cc_norm = cc_raw / ccmax
    return reduce_over_neurons(cc_norm, reduction)
# -----------------------------------------------------------------------------
# Public: coherence (eval-only, NaN-intolerant)
# -----------------------------------------------------------------------------


@torch.no_grad()
def coherence(
    pred: torch.Tensor,
    gt: torch.Tensor,
    dt_ms: float,
    reduction: str = "mean",
) -> torch.Tensor:
    """Magnitude-squared coherence per neuron (mean over batch and frequency bins).

    Eval-only: uses ``scipy.signal.coherence`` on detached CPU copies, so no
    gradient flows.

    Parameters
    ----------
    pred : torch.Tensor
        Prediction of shape ``(B, N, 1, T)``. May require grad; it is detached.
    gt : torch.Tensor
        Ground-truth PSTH of shape ``(B, N, 1, T)``.
    dt_ms : float
        Time-bin width in milliseconds (must be > 0). It sets the sampling
        rate ``fs = 1 / (dt_ms · 1e-3)`` passed to ``scipy.signal.coherence``.
    reduction : {'none', 'mean', 'sum'}, default 'mean'
        Reduction over the neuron axis.

    Returns
    -------
    torch.Tensor
        Shape ``(N,)`` if ``reduction='none'``, otherwise a scalar.

    Raises
    ------
    ValueError
        If the shapes differ or are not ``(B, N, 1, T)``, if ``dt_ms`` is not
        positive, or if any element of ``pred`` or ``gt`` is NaN. This metric
        does not tolerate NaN; pre-flatten to a NaN-free subset first.
    """
    _check_pred_gt(pred, gt)
    if pred.shape[2] != 1:
        # With R > 1 the squeeze below is a no-op and the result would be
        # (N, R) instead of the documented (N,).
        raise ValueError(
            f"coherence expects a singleton R-axis (B, N, 1, T), got "
            f"{tuple(pred.shape)}"
        )
    if not dt_ms > 0:
        raise ValueError(f"dt_ms must be positive, got {dt_ms!r}")
    if torch.isnan(pred).any() or torch.isnan(gt).any():
        raise ValueError(
            "coherence does not tolerate NaN; pre-flatten to a NaN-free subset."
        )
    fs = 1.0 / (dt_ms * 1e-3)
    # detach(): Tensor.numpy() refuses tensors that require grad, even inside
    # no_grad (e.g. a model output passed straight in).
    p = pred.detach().squeeze(2).cpu().numpy()  # (B, N, T)
    g = gt.detach().squeeze(2).cpu().numpy()  # (B, N, T)
    _, coh = scipy.signal.coherence(p, g, fs=fs)  # (B, N, F)
    coh_t = torch.from_numpy(coh).to(pred.device).to(pred.dtype)
    per_neuron = coh_t.mean(dim=(0, -1))  # (N,)
    return reduce_over_neurons(per_neuron, reduction)
# -----------------------------------------------------------------------------
# Internal helpers (importable from audio loaders)
# -----------------------------------------------------------------------------


@torch.no_grad()
def compute_CCmax(
    responses: torch.Tensor,
    max_iters: int = 126,
) -> torch.Tensor:
    """CCmax (Hsu / Spearman-Brown) per ``(B,)`` cell.

    Internal helper kept importable for the audio loaders (legacy CCmax / TTRC
    pipeline). Prefer :func:`compute_neuron_quality` on the dataset itself.

    For each cell, the valid repeats are split into complementary halves. The
    split-half correlation ``ρ_half`` is averaged over the evaluated splits,
    and ``CCmax = sqrt(2ρ / (1 + ρ))``. With an odd number of valid repeats,
    the last one is dropped. Splits are enumerated deterministically in
    lexicographic order (not sampled) and generated lazily.

    Parameters
    ----------
    responses : torch.Tensor
        Responses of shape ``(B, R, T)``. NaN marks invalid repeats / time bins
        (dropped per ``(b,)``).
    max_iters : int, default 126
        Cap on the number of complementary half-split pairs per cell. The
        default 126 = C(10, 5) / 2 covers every split for up to 10 (or 11)
        valid repeats.

    Returns
    -------
    torch.Tensor
        Shape ``(B,)``. ``1.0`` for cells with exactly one valid repeat. NaN
        for cells with no valid repeat, ≥ 2 repeats but < 2 valid time bins,
        no evaluated split, ``ρ_half ≤ 0``, or a NaN ``ρ_half`` (e.g. a
        zero-variance half).
    """
    if responses.dim() != 3:
        raise ValueError(
            f"compute_CCmax expects (B, R, T), got {tuple(responses.shape)}"
        )
    valid = ~responses.isnan()
    B = responses.shape[0]
    nan = responses.new_full((), float("nan"))
    out = []
    for b in range(B):
        v_b = valid[b]
        valid_repeats = v_b.any(dim=-1)
        valid_time = v_b.any(dim=0)
        R_v = int(valid_repeats.sum().item())
        T_v = int(valid_time.sum().item())
        if R_v < 2 or T_v < 2:
            out.append(responses.new_ones(()) if R_v == 1 else nan)
            continue
        sub = responses[b][valid_repeats][:, valid_time]
        if R_v % 2 == 1:
            R_v -= 1
            sub = sub[:R_v]
        half = R_v // 2
        n_iters = max(0, min(math.comb(R_v, half) // 2, max_iters))
        cc_halfs = []
        # The i-th and (-1-i)-th lexicographic halves are complements, so pair
        # each half with its complement instead of materialising every one of
        # the C(R_v, R_v/2) combinations (which explodes for large R).
        for first_idx in itertools.islice(
            itertools.combinations(range(R_v), half), n_iters
        ):
            second_idx = [r for r in range(R_v) if r not in first_idx]
            first = sub[list(first_idx)].mean(dim=0)
            second = sub[second_idx].mean(dim=0)
            cc_halfs.append(_pearson_1d(first, second))
        if not cc_halfs:
            out.append(nan)
            continue
        rho_half = torch.stack(cc_halfs).mean()
        if rho_half.item() <= 0:
            out.append(nan)
            continue
        # A NaN rho_half fails the check above and yields sqrt(NaN) = NaN here.
        out.append(torch.sqrt(2 * rho_half / (1 + rho_half)))
    return torch.stack(out)
@torch.no_grad()
def compute_TTRC(responses: torch.Tensor) -> torch.Tensor:
    """Trial-to-trial response correlation per ``(B,)`` cell.

    Internal helper kept importable for the audio loaders (legacy CCmax / TTRC
    pipeline). Prefer :func:`compute_neuron_quality` on the dataset itself.

    For each cell, keep the repeats with any valid bin and the time bins valid
    in any repeat. Then average the Pearson correlation over every unordered
    pair of kept repeats.

    Parameters
    ----------
    responses : torch.Tensor
        Responses of shape ``(B, R, T)``. NaN-aware.

    Returns
    -------
    torch.Tensor
        Shape ``(B,)``. ``1.0`` for cells with exactly one valid repeat; NaN
        for cells with no valid repeat, or with ≥ 2 repeats but fewer than 2
        valid time bins.
    """
    if responses.dim() != 3:
        raise ValueError(
            f"compute_TTRC expects (B, R, T), got {tuple(responses.shape)}"
        )
    is_valid = ~responses.isnan()
    nan = responses.new_full((), float("nan"))

    def _cell_ttrc(b: int) -> torch.Tensor:
        v_b = is_valid[b]
        rep_keep = v_b.any(dim=-1)
        time_keep = v_b.any(dim=0)
        n_rep = int(rep_keep.sum().item())
        n_time = int(time_keep.sum().item())
        if n_rep < 2 or n_time < 2:
            return responses.new_ones(()) if n_rep == 1 else nan
        trials = responses[b][rep_keep][:, time_keep]
        pair_ccs = [
            _pearson_1d(trials[i], trials[j])
            for i, j in itertools.combinations(range(n_rep), 2)
        ]
        return torch.stack(pair_ccs).mean()

    return torch.stack([_cell_ttrc(b) for b in range(responses.shape[0])])