# Source code for deepSTRF.training.fitter

"""``deepSTRF.training.Fitter`` — a thin, opt-in PyTorch training loop.

See ``docs/_source/md/fitter.md`` for the full design contract. This module
implements the canonical 3-line training step from
``metrics_paradigm.md`` §7, plus early stopping, checkpoint selection, and
cross-batch metric accumulation. The class is intentionally short — when
something doesn't fit (multi-GPU, mixed-precision, curricula, ...) the
recommended path is to write the loop, not to extend the Fitter.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from deepSTRF.metrics import corrcoef, mse_loss, normalized_corrcoef


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------


def _default_val_metrics() -> Dict[str, Callable]:
    """Build the default validation metrics as a fresh dict.

    ``"cc"`` maps to raw Pearson correlation and ``"cc_norm"`` to the
    noise-corrected correlation (``method="schoppe"``). Both are called with
    ``reduction="none"``, so they return one value per neuron instead of a
    population scalar.
    """

    def _raw_cc(pred, responses):
        return corrcoef(pred, responses, reduction="none")

    def _schoppe_cc(pred, responses):
        return normalized_corrcoef(
            pred, responses, method="schoppe", reduction="none"
        )

    metrics: Dict[str, Callable] = {}
    metrics["cc"] = _raw_cc
    metrics["cc_norm"] = _schoppe_cc
    return metrics


def _to_scalar(x: Any) -> float:
    """Collapse a metric value to a Python ``float``.

    Tensors are reduced with ``torch.nanmean``, so NaN entries (e.g. neurons
    whose metric is undefined) are ignored. An empty tensor yields ``nan``, and
    a one-element tensor is unwrapped directly. Non-tensor values go through
    ``float()`` unchanged.

    Integer and boolean tensors are promoted to ``float64`` before the
    reduction, because ``torch.nanmean`` rejects non-floating-point inputs.
    """
    if not isinstance(x, torch.Tensor):
        return float(x)
    x = x.detach()
    if x.numel() == 0:
        return float("nan")
    if x.numel() == 1:
        return float(x.item())
    if not (x.is_floating_point() or x.is_complex()):
        # nanmean only supports floating/complex dtypes; float64 keeps large
        # integer values exact enough for a logged average.
        x = x.to(torch.float64)
    return float(torch.nanmean(x).item())


def _format_epoch(epoch_dict: Mapping[str, Any]) -> None:
    """Print one epoch's summary line to stdout (the default ``log_fn``).

    The ``"epoch"`` entry is rendered as a 4-wide integer (``epoch    3``).
    Every other entry becomes ``key=value``, with the value collapsed by
    ``_to_scalar`` and shown to four decimals. Fields are joined with
    ``" | "``.
    """

    def _field(key: Any, value: Any) -> str:
        if key != "epoch":
            return f"{key}={_to_scalar(value):.4f}"
        return f"epoch {int(value):4d}"

    print(" | ".join(_field(k, v) for k, v in epoch_dict.items()))


def _pad_and_cat(tensors: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate 4-D batches along dim 0, NaN-padding ragged trailing axes.

    Every input is expected to be shaped ``(B_i, N, R_i, T_i)`` with a common
    ``N``. Batch sizes may differ (e.g. a short final batch). The third and
    fourth axes (``R`` and ``T``) are right-padded with NaN up to the largest
    ``R_i`` / ``T_i`` in the list before concatenation. Downstream
    ``deepSTRF.metrics`` functions treat those NaNs as masked entries
    (``metrics_paradigm.md`` §4).

    Raises
    ------
    ValueError
        If ``tensors`` is empty.
    """
    if not tensors:
        raise ValueError("cannot concatenate an empty list of tensors")

    target_r = max(t.shape[2] for t in tensors)
    target_t = max(t.shape[3] for t in tensors)
    fill = float("nan")

    def _grow(t: torch.Tensor) -> torch.Tensor:
        extra_r = target_r - t.shape[2]
        extra_t = target_t - t.shape[3]
        if extra_r == 0 and extra_t == 0:
            return t
        # F.pad lists widths starting from the LAST axis:
        # (T_before, T_after, R_before, R_after).
        return F.pad(t, (0, extra_t, 0, extra_r), value=fill)

    return torch.cat([_grow(t) for t in tensors], dim=0)


# -----------------------------------------------------------------------------
# Fitter
# -----------------------------------------------------------------------------


class Fitter:
    """Opt-in training loop for a deepSTRF :class:`NeuralModel`.

    See ``docs/_source/md/fitter.md`` for the full design.

    Parameters
    ----------
    model
        Any ``nn.Module`` whose ``forward`` emits ``(B, N, 1, T)`` predictions.
        Stateful models may implement ``model.detach()`` (no-op by default on
        ``deepSTRF.models.NeuralModel``); the Fitter calls it after every step.
    train_loader, val_loader
        ``DataLoader`` instances built with
        ``deepSTRF.utils.data.neural_collate``. A ``val_loader`` with
        ``batch_size=1`` is the simplest case, but any batch size works
        thanks to NaN-pad-and-cat (§6).
    loss_fn
        Callable ``(pred, responses) -> Tensor``. Default ``mse_loss``. The
        deepSTRF losses auto-collapse ``responses`` to PSTH internally
        (``metrics_paradigm.md`` §2), so no caller-side ``nanmean`` is needed.
    val_metrics
        Mapping ``name -> callable(pred, responses) -> per-neuron Tensor``.
        Default: the canonical ``{'cc', 'cc_norm'}`` pair. Stored under
        ``f'val_{name}'`` in the epoch dict.
    optimizer
        Any ``torch.optim.Optimizer``. Default:
        ``AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)``.
    device
        Where to place the model and per-batch tensors.
    max_epochs
        Hard cap on training epochs.
    patience
        Early-stop patience: number of epochs without improvement on
        ``monitor`` before the loop terminates.
    monitor
        Key in the per-epoch dict to track for early stopping. Default
        ``'val_cc_norm'``. Use ``'val_loss'``, ``'val_cc'``, or any custom key
        you added via ``val_metrics``.
    mode
        ``'max'`` or ``'min'``: the direction of improvement on ``monitor``.
        Default ``'max'`` (paired with ``'val_cc_norm'``).
    ckpt_path
        If given, save the best-on-``monitor`` ``state_dict`` to this path and
        restore it at the end of ``fit()``. Only a checkpoint written during
        the current ``fit()`` call is restored. A file already at this path
        from an earlier run is never loaded, even if this run never improves.
    log_fn
        Called as ``log_fn(epoch_dict)`` once per epoch. Default: a small
        formatter that prints ``epoch | k=v | ...``. Override to log to WandB,
        MLflow, a file, etc.
    track_train_metrics
        If ``True`` (default), recompute ``val_metrics`` over the training
        predictions accumulated this epoch and add them to the epoch dict as
        ``'train_<name>'``. This helps diagnose overfitting but is expensive
        on large datasets. Accumulating ``(B, N, R, T)`` responses across all
        train batches is the dominant per-epoch cost when ``N × R × T`` is in
        the millions (e.g. AA2's 494-cell population). Set to ``False`` to
        skip; ``train_loss`` is always reported.
    track_per_cell_best
        If ``True``, keep a per-cell best-on-``monitor`` snapshot of the
        readout's per-N parameter and buffer slices throughout training. At
        the end of ``fit()``, after the global ``ckpt_path`` restore, each
        cell's slice is overlaid with its individual-best snapshot. On
        models with no shared parameters this is, by construction, at least
        as good as the vanilla restore on the validation set, cell by cell:
        every cell ends up at its individual val peak. The training
        trajectory itself is unchanged (no gradient masking, no per-cell
        stopping); only the restored checkpoint differs. Requires
        ``val_metrics[monitor.removeprefix('val_')]`` to return an ``(N,)``
        per-cell tensor (the default :func:`_default_val_metrics` does this).
        Default ``False``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        *,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = mse_loss,
        val_metrics: Optional[Dict[str, Callable]] = None,
        optimizer: Optional[Optimizer] = None,
        device: Union[str, torch.device] = "cpu",
        max_epochs: int = 1000,
        patience: int = 10,
        monitor: str = "val_cc_norm",
        mode: str = "max",
        ckpt_path: Optional[Union[str, Path]] = None,
        log_fn: Callable[[Mapping[str, Any]], None] = _format_epoch,
        track_train_metrics: bool = True,
        track_per_cell_best: bool = False,
    ) -> None:
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        if max_epochs < 1:
            raise ValueError(f"max_epochs must be >= 1, got {max_epochs}")

        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.val_metrics = (
            val_metrics if val_metrics is not None else _default_val_metrics()
        )
        self.optimizer = (
            optimizer
            if optimizer is not None
            else AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
        )
        self.max_epochs = max_epochs
        self.patience = patience
        self.monitor = monitor
        self.mode = mode
        self.ckpt_path = Path(ckpt_path) if ckpt_path is not None else None
        self.log_fn = log_fn
        self.track_train_metrics = track_train_metrics
        self.track_per_cell_best = track_per_cell_best

    # ------------------------------------------------------------------
    # Hooks (subclass and override, or pass kwargs at construction time)
    # ------------------------------------------------------------------

    def compute_loss(
        self, pred: torch.Tensor, responses: torch.Tensor
    ) -> torch.Tensor:
        """Default: delegate to ``self.loss_fn(pred, responses)`` (auto-PSTH inside)."""
        return self.loss_fn(pred, responses)

    def on_epoch_end(self, epoch: int, epoch_dict: Dict[str, Any]) -> None:
        """Default: log the epoch dict via ``self.log_fn``."""
        self.log_fn(epoch_dict)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(self) -> List[Dict[str, Any]]:
        """Train until ``max_epochs`` or early-stop on ``monitor``.

        After the loop, if ``ckpt_path`` is set and this call saved at least
        one checkpoint, the best-on-``monitor`` weights are reloaded. Then, if
        ``track_per_cell_best`` is set, each cell's readout slice is
        overwritten with its individual-best snapshot.

        Returns
        -------
        history : list of dict
            One dict per completed epoch, with keys ``'epoch'``,
            ``'train_*'``, and ``'val_*'``.

        Raises
        ------
        KeyError
            If ``monitor`` is not a key of the epoch dict.
        ValueError
            If ``track_per_cell_best`` is set and the monitored value is not
            an ``(N,)`` per-cell tensor.
        """
        history: List[Dict[str, Any]] = []
        best_score = -float("inf") if self.mode == "max" else float("inf")
        better = (
            (lambda new, best: new > best)
            if self.mode == "max"
            else (lambda new, best: new < best)
        )
        epochs_no_improvement = 0
        # Restore only a checkpoint written by *this* call. Checking
        # ``ckpt_path.exists()`` alone would reload a stale file from an
        # earlier run whenever this run never improves (e.g. the monitored
        # score is NaN every epoch), silently swapping in unrelated weights.
        saved_ckpt = False

        if self.track_per_cell_best:
            N = self.model.O
            self._per_cell_best_score = torch.full(
                (N,),
                -float("inf") if self.mode == "max" else float("inf"),
                device=self.device,
            )
            self._per_cell_snapshots: Dict[int, List[torch.Tensor]] = {}

        for epoch in range(self.max_epochs):
            train = self._train_one_epoch()
            val = self._evaluate(self.val_loader)

            epoch_dict: Dict[str, Any] = {"epoch": epoch}
            epoch_dict.update({f"train_{k}": v for k, v in train.items()})
            epoch_dict.update({f"val_{k}": v for k, v in val.items()})
            history.append(epoch_dict)
            self.on_epoch_end(epoch, epoch_dict)

            if self.monitor not in epoch_dict:
                raise KeyError(
                    f"monitor key {self.monitor!r} not in epoch dict; "
                    f"available keys: {sorted(epoch_dict)}"
                )

            # Per-cell snapshot update happens BEFORE the global best/patience
            # update so that snapshots track each cell's individual best
            # regardless of population-level early-stop behaviour. Needs the
            # per-cell monitor tensor; raises if it isn't one.
            if self.track_per_cell_best:
                self._update_per_cell_best(epoch_dict[self.monitor])

            score = _to_scalar(epoch_dict[self.monitor])
            if better(score, best_score):
                best_score = score
                if self.ckpt_path is not None:
                    self.ckpt_path.parent.mkdir(parents=True, exist_ok=True)
                    torch.save(self.model.state_dict(), self.ckpt_path)
                    saved_ckpt = True
                epochs_no_improvement = 0
            else:
                epochs_no_improvement += 1
                if epochs_no_improvement >= self.patience:
                    break

        if saved_ckpt:
            self.model.load_state_dict(
                torch.load(self.ckpt_path, map_location=self.device)
            )

        # Overlay per-cell snapshots on top of the global-ckpt restore.
        # Order is intentional: the global restore resets every parameter
        # (including non-readout ones like the model's core) to the
        # population-best state; the per-cell overlay then replaces each
        # cell's readout slice with its individual-best.
        if self.track_per_cell_best:
            self._restore_per_cell_snapshots()

        return history

    def evaluate(self, loader: DataLoader) -> Dict[str, Any]:
        """Run loss + ``val_metrics`` on a loader (no backprop, no key prefix).

        Returns a dict with key ``'loss'`` plus one key per entry of
        ``self.val_metrics``. For test-set evaluation after training, call
        ``fitter.evaluate(test_loader)``.
        """
        return self._evaluate(loader)

    # ------------------------------------------------------------------
    # Per-cell snapshot bookkeeping (only used when track_per_cell_best=True)
    # ------------------------------------------------------------------

    def _per_cell_readout_tensors(self):
        """Yield every readout tensor whose leading axis is the neuron axis.

        Iterates both ``parameters()`` and ``buffers()`` under
        ``self.model.readout``, filtering for ``shape[0] == self.model.O``.
        On a no-shared-params readout (STRF kernel + per-neuron BN +
        per-neuron activation, the post-2026-05-19 audio convention) this
        yields every learnable scalar in the readout.
        """
        N = self.model.O
        for p in self.model.readout.parameters():
            if p.dim() >= 1 and p.shape[0] == N:
                yield p
        for b in self.model.readout.buffers():
            # skip 0-d scalar buffers (e.g. BN's num_batches_tracked)
            if b.dim() >= 1 and b.shape[0] == N:
                yield b

    def _update_per_cell_best(self, score: Any) -> None:
        """Snapshot the readout slice of every cell whose monitored score improved.

        ``score`` must be an ``(N,)`` tensor. NaN entries never count as an
        improvement. Raises ``ValueError`` for any other type or shape.
        """
        N = self.model.O
        if not isinstance(score, torch.Tensor) or score.shape != (N,):
            raise ValueError(
                f"track_per_cell_best=True requires the {self.monitor!r} "
                f"val metric to return a per-cell tensor of shape ({N},); "
                f"got {type(score).__name__} with shape "
                f"{tuple(score.shape) if isinstance(score, torch.Tensor) else None!r}. "
                f"Use val_metrics callables with reduction='none'."
            )
        # Metrics are computed on CPU-accumulated tensors; move to the device
        # where the running best scores live.
        score = score.to(self.device)
        better = (
            (lambda new, best: new > best)
            if self.mode == "max"
            else (lambda new, best: new < best)
        )
        improved = better(score, self._per_cell_best_score) & ~score.isnan()
        for n in torch.nonzero(improved, as_tuple=True)[0].tolist():
            self._per_cell_snapshots[n] = [
                p.data[n].detach().clone() for p in self._per_cell_readout_tensors()
            ]
        self._per_cell_best_score = torch.where(
            improved, score, self._per_cell_best_score
        )

    def _restore_per_cell_snapshots(self) -> None:
        """Write each cell's best snapshot back into its readout slices."""
        for n, snap in self._per_cell_snapshots.items():
            for p, s in zip(self._per_cell_readout_tensors(), snap):
                p.data[n] = s

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _train_one_epoch(self) -> Dict[str, Any]:
        """Run one optimisation pass over ``train_loader``.

        Returns ``{'loss': mean batch loss}`` plus, when
        ``track_train_metrics`` is set, one entry per ``val_metrics`` callable
        evaluated on the epoch's concatenated training predictions.
        """
        self.model.train()
        loss_sum = 0.0
        n_batches = 0
        preds_list: List[torch.Tensor] = []
        responses_list: List[torch.Tensor] = []

        for batch in self.train_loader:
            stims = batch['stims'].to(self.device)
            responses = batch['responses'].to(self.device)

            self.optimizer.zero_grad()
            pred = self.model(stims)
            loss = self.compute_loss(pred, responses)
            loss.backward()
            self.optimizer.step()
            if hasattr(self.model, "detach"):
                self.model.detach()

            loss_sum += float(loss.detach().item())
            n_batches += 1

            # Accumulate on CPU so the concatenated metrics tensor does not
            # have to fit in GPU memory — for large datasets (494 cells × 81
            # train stims × 20 trials × 511 frames on AA2) the GPU concat
            # exceeds typical visible memory by several gigabytes. Skip the
            # accumulation entirely when train-side metrics are disabled —
            # halves wall time on large datasets where users only care
            # about val metrics.
            if self.track_train_metrics:
                preds_list.append(pred.detach().cpu())
                responses_list.append(responses.detach().cpu())

        out: Dict[str, Any] = {"loss": loss_sum / max(n_batches, 1)}
        if self.track_train_metrics:
            with torch.no_grad():
                preds_cat = _pad_and_cat(preds_list)
                responses_cat = _pad_and_cat(responses_list)
                for name, fn in self.val_metrics.items():
                    out[name] = fn(preds_cat, responses_cat)
        return out

    def _evaluate(self, loader: DataLoader) -> Dict[str, Any]:
        """Compute loss and every ``val_metrics`` entry over ``loader`` without gradients."""
        self.model.eval()
        preds_list: List[torch.Tensor] = []
        responses_list: List[torch.Tensor] = []

        with torch.no_grad():
            for batch in loader:
                stims = batch['stims'].to(self.device)
                responses = batch['responses'].to(self.device)
                pred = self.model(stims)
                if hasattr(self.model, "detach"):
                    self.model.detach()
                # Same CPU-accumulation as in _train_one_epoch — see comment
                # there for the AA2-scale memory rationale.
                preds_list.append(pred.cpu())
                responses_list.append(responses.cpu())

        preds_cat = _pad_and_cat(preds_list)
        responses_cat = _pad_and_cat(responses_list)

        out: Dict[str, Any] = {
            "loss": float(self.compute_loss(preds_cat, responses_cat).item()),
        }
        for name, fn in self.val_metrics.items():
            out[name] = fn(preds_cat, responses_cat)
        return out