# Source code for deepSTRF.models.readouts

"""
Readout modules: per-neuron projections that take a feature representation
emitted by a model's ``core`` and produce predictions of shape
``(B, N, R=1, T)``.

Two flavours are shipped:

- :class:`STRFReadout` — a learnable STRF kernel applied causally via
  ``CausalSTRFConv``. Used when the readout itself is the model's main
  learnable apparatus (Linear / LinearNonlinear) or when an explicit
  STRF interpretation of the per-neuron weights is wanted.

- :class:`LinearReadout` — a per-timestep linear projection
  ``in_features -> N``, with an optional 1-hidden-layer MLP. Used by
  models whose ``core`` already produces a flat per-timestep feature
  vector (ConvNet2D / Transformer / StateNet).

See ``docs/_source/md/model_paradigm.md`` §7 for the full readout
contract.
"""
import torch
import torch.nn as nn

from deepSTRF.models.layers import CausalSTRFConv


class STRFReadout(nn.Module):
    """
    Per-neuron readout backed by a causal Spectro-Temporal Receptive Field
    kernel, with a per-neuron BatchNorm placed after the conv.

    Wraps a ``CausalSTRFConv`` of shape ``(N, C_in, F, T)`` and applies an
    output activation. The frequency axis is collapsed by the conv from
    ``F → 1``, so the readout naturally emits the canonical
    ``(B, N, R=1, T)`` rank.

    The per-neuron ``nn.BatchNorm1d(N)`` between the conv and the activation
    stabilises training and serves as the model's only normalisation layer in
    the Linear / LinearNonlinear cases — every learnable scalar in this
    readout (STRF kernel, conv bias, BN affine, BN running stats) has the
    neuron axis as leading dim, so the readout is strictly no-shared-params.
    Causality in eval mode is preserved: BN's running statistics are
    per-channel scalars, applied element-wise on the time axis at inference.

    Parameters
    ----------
    F : int
        Frequency bins of the input spectrogram.
    T : int
        STRF temporal extent.
    C_in : int
        Input channel count after prefiltering.
    out_neurons : int
        Number of output neurons ``N``.
    kernel : nn.Module, optional
        Pluggable kernel module. ``None`` (default) instantiates a vanilla
        ``nn.Conv2d``; pass ``ParametricSTRF`` (DCLS) or a separable
        ``nn.Sequential`` to swap parameterizations.
    activation : nn.Module, optional
        Pointwise output nonlinearity. Defaults to ``nn.Identity``.
    bias : bool, default True
        Used only when ``kernel is None``.
    """

    def __init__(self, F: int, T: int, C_in: int, out_neurons: int,
                 kernel: nn.Module = None, activation: nn.Module = None,
                 bias: bool = True):
        super().__init__()
        self.strf = CausalSTRFConv(F, T, C_in, out_neurons,
                                   kernel=kernel, bias=bias)
        self.bn = nn.BatchNorm1d(out_neurons)
        self.activation = activation if activation is not None else nn.Identity()

    def forward(self, x):
        """
        Apply STRF conv, per-neuron BatchNorm and the output activation.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(B, C_in, F, T)``.

        Returns
        -------
        torch.Tensor
            Predictions of shape ``(B, N, 1, T)``.
        """
        out = self.strf(x)            # (B, N, 1, T)
        out = out.squeeze(-2)         # (B, N, T)
        out = self.bn(out)            # (B, N, T) — per-neuron BN
        # Apply activation on (B, T, N) so per-neuron parametric activations
        # (ParametricSoftplus / ParametricSigmoid / ParametricDoubleExponential)
        # broadcast correctly with N as the last axis. Shape-invariant
        # activations (Identity, nn.Sigmoid, nn.ReLU, ...) are unaffected.
        out = out.transpose(-1, -2)   # (B, T, N)
        out = self.activation(out)    # (B, T, N)
        out = out.transpose(-1, -2).unsqueeze(-2)  # (B, N, 1, T)
        return out

    def STRF_weight(self, polarity: str = None):
        """
        Return the underlying STRF kernel as ``(N, C_in, F, T)``.

        Parameters
        ----------
        polarity : {'ON', 'OFF', 0, 1, None}, optional
            If the model's prefilter produces ``C_in == 2`` channels (e.g.
            AdapTrans's ON/OFF), select one. ``None`` (default) returns the
            full ``(N, C_in, F, T)`` tensor; ``'ON'`` (or ``0``) slices
            channel 0; ``'OFF'`` (or ``1``) slices channel 1. String
            selectors are matched case-insensitively.

        Returns
        -------
        The full kernel, or the slice ``kernel[:, c]`` for the selected
        channel ``c``.

        Raises
        ------
        ValueError
            If ``polarity`` is not one of the accepted selectors.
        """
        kernel = self.strf.STRF_weight()
        if polarity is None:
            return kernel

        # Normalise string selectors so any casing ('on', 'On', 'oN', ...)
        # maps onto the same key; non-string selectors (0 / 1) pass through.
        key = polarity.upper() if isinstance(polarity, str) else polarity
        if key in ('ON', 0):
            return kernel[:, 0]
        if key in ('OFF', 1):
            return kernel[:, 1]
        raise ValueError(f"polarity must be 'ON', 'OFF', or None — got {polarity!r}")
class LinearReadout(nn.Module):
    """
    Per-neuron readout: a per-timestep linear projection from a flat feature
    vector to ``N`` output neurons.

    Accepts either a 3D input ``(B, in_features, T)`` or a 4D input
    ``(B, in_features, 1, T)`` (the singleton spatial axis emitted by an
    STRF-style conv that collapsed ``F → 1``); both shapes route through the
    same projection. The output is the canonical ``(B, N, 1, T)``.

    Parameters
    ----------
    in_features : int
        Per-timestep feature dimension produced by the model's ``core``.
    out_neurons : int
        Number of output neurons ``N``.
    hidden : int, optional
        If given, inserts a 1-hidden-layer MLP ``in_features → hidden → N``
        with a LeakyReLU(0.1) between. ``None`` (default) gives a single
        linear projection.
    activation : nn.Module, optional
        Pointwise output nonlinearity applied to the per-timestep output
        before the rank is unsqueezed to ``(B, N, 1, T)``. Defaults to
        ``nn.Identity``.
    bias : bool, default True
        Whether the linear projection(s) include a bias term.
    """

    def __init__(self, in_features: int, out_neurons: int,
                 hidden: int = None, activation: nn.Module = None,
                 bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.N = out_neurons
        if hidden is None:
            self.fc = nn.Linear(in_features, out_neurons, bias=bias)
        else:
            self.fc = nn.Sequential(
                nn.Linear(in_features, hidden, bias=bias),
                nn.LeakyReLU(0.1),
                nn.Linear(hidden, out_neurons, bias=bias),
            )
        self.activation = activation if activation is not None else nn.Identity()

    def forward(self, x):
        """
        Project per-timestep features onto the ``N`` output neurons.

        Parameters
        ----------
        x : torch.Tensor
            ``(B, C, T)`` or ``(B, C, 1, T)`` feature tensor.

        Returns
        -------
        torch.Tensor
            Predictions of shape ``(B, N, 1, T)``.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D nor 4D, or is 4D with a non-singleton
            axis ``-2``.
        """
        if x.dim() == 4:
            # Raised explicitly rather than asserted: an ``assert`` is stripped
            # under ``python -O``, after which a non-singleton axis would
            # survive the squeeze and be projected along the wrong dimension.
            if x.shape[-2] != 1:
                raise ValueError(
                    f"4D input to LinearReadout must have a singleton dim -2 "
                    f"(emitted by an F→1 conv); got shape {tuple(x.shape)}"
                )
            x = x.squeeze(-2)          # (B, C, T)
        elif x.dim() != 3:
            raise ValueError(
                f"LinearReadout expects 3D (B, C, T) or 4D (B, C, 1, T) "
                f"input; got {x.dim()}D shape {tuple(x.shape)}"
            )

        # x: (B, C, T)
        y = x.transpose(-2, -1)        # (B, T, C)
        y = self.fc(y)                 # (B, T, N)
        y = self.activation(y)         # (B, T, N) — broadcasts over N
        y = y.transpose(-2, -1)        # (B, N, T)
        return y.unsqueeze(-2)         # (B, N, 1, T)