The deepSTRF model paradigm
This note documents how deepSTRF structures encoding models, what the
forward contract is, and what invariants any model author must respect. It
is the single reference for the “model contract” of the library and should
stay in sync with the NeuralModel / AudioEncodingModel /
VideoEncodingModel base classes.
It is the encoding-model counterpart to data_paradigm.md. The two
documents together describe how data and models meet at the loss.
1. Scope: encoding models
All current models in deepSTRF.models.audio and deepSTRF.models.video
are encoding models: they take a sensory stimulus and predict the
neural response (PSTH or per-trial activity).
A future extension will introduce decoding models (response → stim).
The dataloader contract does not change — only the model class hierarchy.
Encoding models live under AudioEncodingModel / VideoEncodingModel;
decoding models will get parallel *DecodingModel base classes when a
concrete use case drives their design.
2. The forward template
Every encoding model in deepSTRF is a four-stage pipeline:
```text
   stimulus
       │
       ▼
┌──────────────┐
│   wav2spec   │  future slot: Identity by default. Learnable
│              │  audio front-end (e.g. LEAF) when raw waveforms
│              │  become first-class inputs.
└──────────────┘
       │
       ▼
┌──────────────┐
│ prefiltering │  AdapTrans / ICAdaptation / Identity. Stim-side
│              │  non-trainable or lightly-trainable preprocessing.
│              │  Declares `out_channels: int` so the downstream
│              │  core can size its inputs.
└──────────────┘
       │
       ▼
┌──────────────┐
│     core     │  Stim-shared feature backbone: ConvNet / RNN / SSM
│              │  / Transformer / etc. For LN models the core is a
│              │  single `CausalLayerNorm` over frequency: no shared
│              │  features, the readout does all the work.
└──────────────┘
       │
       ▼
┌──────────────┐
│   readout    │  Per-neuron projection, `out_channels = N`. Owns
│              │  the output activation. STRFReadout for L/LN/NRF/
│              │  DNet; LinearReadout for ConvNet / RNN / Transformer
│              │  models.
└──────────────┘
       │
       ▼
   response
```
The base class implements forward once:
```python
def forward(self, x):
    x = self.wav2spec(x)
    x = self.prefiltering(x)
    f = self.core(x)
    return self.readout(f)
```
Concrete models populate the slots in their __init__. The base
defines reasonable defaults (nn.Identity() for wav2spec,
prefiltering, and core); the readout slot has no default and must
be set by every concrete model — validate() enforces this.
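A minimal sketch of how a base class could provide the slot defaults, the template `forward`, and the readout check, in plain PyTorch. `NeuralModelSketch`, `ToyModel`, and the body of `validate()` are hypothetical illustrations, not deepSTRF's `NeuralModel`; the toy readout is a causal full-band convolution chosen so that its output already has the `(B, N, 1, T)` layout.

```python
import torch
import torch.nn as nn


class NeuralModelSketch(nn.Module):
    """Hypothetical stand-in for the four-slot base class (not deepSTRF's code)."""

    def __init__(self, out_neurons: int):
        super().__init__()
        self.O = out_neurons
        # Defaults: the first three slots are no-ops until a subclass replaces them.
        self.wav2spec = nn.Identity()
        self.prefiltering = nn.Identity()
        self.core = nn.Identity()
        self.readout = None  # no default: every concrete model must assign one

    def validate(self) -> None:
        for name in ("wav2spec", "prefiltering", "core", "readout"):
            if not isinstance(getattr(self, name), nn.Module):
                raise TypeError(f"slot '{name}' must be an nn.Module")
        if self.O <= 0:
            raise ValueError("out_neurons must be > 0")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.wav2spec(x)
        x = self.prefiltering(x)
        f = self.core(x)
        return self.readout(f)


class ToyModel(NeuralModelSketch):
    """Hypothetical model that only sets the readout (causal conv over the full F axis)."""

    def __init__(self, F_bands: int = 34, T_win: int = 9, out_neurons: int = 3):
        super().__init__(out_neurons)
        self.readout = nn.Sequential(
            nn.ZeroPad2d((T_win - 1, 0, 0, 0)),           # left-pad the time axis only
            nn.Conv2d(1, out_neurons, (F_bands, T_win)),  # collapses F to 1
        )


model = ToyModel()
model.validate()
y = model(torch.randn(2, 1, 34, 50))  # (B, N, 1, T) = (2, 3, 1, 50)
```

Assigning a module to `readout` in the subclass replaces the plain `None` attribute with a registered submodule, so `validate()` passes only once a concrete readout exists.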
Here is the actual implementation of Linear in audio_zoo.py, the
minimal reference example:
```python
import torch.nn as nn

from deepSTRF.models.layers import CausalLayerNorm
# AudioEncodingModel and STRFReadout are also deepSTRF classes (import paths not shown here).


class Linear(AudioEncodingModel):
    def __init__(self, n_frequency_bands=34, temporal_window_size=9, out_neurons=1,
                 output_activation=None, prefiltering=None, kernel=None):
        super().__init__(
            n_frequency_bands=n_frequency_bands,
            temporal_window_size=temporal_window_size,
            out_neurons=out_neurons,
            prefiltering=prefiltering,
        )
        # core: causal per-timestep frequency normalization
        self.core = CausalLayerNorm(self.F, dim=-2)
        # readout: pluggable STRF kernel + output activation
        self.readout = STRFReadout(
            F=self.F, T=self.T, C_in=self.C_in, out_neurons=self.O,
            kernel=kernel,
            activation=output_activation or nn.Identity(),
        )
        # forward inherited from NeuralModel
```
Subclasses may override forward when the architecture genuinely cannot
fit the template — StateNet’s RNN reshape and Transformer’s per-forward
attention-mask construction are current examples. These overrides are
exceptions and should be called out in the model docstring.
3. The forward signature
Input shape (per modality)
| Modality | Stimulus shape | Notes |
|---|---|---|
| audio | `(B, 1, F, T_max)` | float32, zero-padded on the right along T. Never NaN. |
| video | `(B, 1, H, W, T_max)` | float32, zero-padded on the right along T. Never NaN. |
These match the output of neural_collate exactly. Models should never
need to reshape the input; if they do, the reshape is internal to one of
the four slots (typically the core).
Output shape
All models output (B, N, R=1, T_max).
The R-axis is degenerate today (always 1) but kept for forward-compatibility with probabilistic models that emit one or more samples (or distribution parameters) per time bin. Single-prediction models unsqueeze; probabilistic models populate the axis.
Where the R=1 axis comes from
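| Slot | Output shape (audio) |
|---|---|
| `wav2spec` | `(B, 1, F, T)` |
| `prefiltering` | `(B, C_in, F, T)` |
| `core` | model-specific |
| `readout` | `(B, N, 1, T)` |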
The readout owns the responsibility of projecting to N output channels and
unsqueezing the R axis. This is by design: it keeps the rank-normalization
in one place, and the STRF_weight / STRF_gradmap introspection helpers
have a single object to query.
4. Causality
deepSTRF encoding models are strictly causal by convention: the
prediction at time t depends only on stim[..., :t+1] (time bins up to
and including t), never on stim[..., t+1:].
This is a hard contract, not an aspiration. It matters for:
- Real-time inference: closed-loop neuroscience experiments.
- Generalization across context lengths: a non-causal model trained on
  T=100 clips may behave very differently on T=1000, because its effective
  receptive field shifts.
- Comparability with biological neurons: they are causal too.
Causal building blocks
| Operation | Causal? | Notes |
|---|---|---|
|  | ✓ | left-pad explicitly |
|  | ✓ | left-only temporal padding |
|  | ✓ | causal extrapolation |
| `CausalLayerNorm(C, dim=1)` | ✓ | LayerNorm over the channel axis at every (B, F, T) position |
| `CausalLayerNorm(F, dim=-2)` | ✓ | LayerNorm over the frequency axis at every (B, C, T) position |
| RNNs | ✓ | inherently causal |
| RNN-style SSMs (S4, Mamba) | ✓ | inherently causal |
| Transformer self-attention | ✓ | mask required; without it, attention is bidirectional |
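The difference between left-only and symmetric padding is easy to check numerically. The sketch below (plain PyTorch, arbitrary sizes, a box-filter kernel for determinism) perturbs only the second half of a sequence and measures how much the first half of each output changes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
K = 5                                   # temporal kernel size
conv = nn.Conv1d(1, 1, K, bias=False)   # no implicit padding
nn.init.ones_(conv.weight)              # box filter, for a deterministic demo


def causal(x):
    # Left-pad K-1 zeros on the time axis only: output[t] sees x[t-K+1 .. t].
    return conv(F.pad(x, (K - 1, 0)))


def same(x):
    # Symmetric padding: for K=5, output[t] sees x[t-2 .. t+2], i.e. two future bins.
    return F.conv1d(x, conv.weight, padding="same")


x = torch.randn(1, 1, 20)
x_future = x.clone()
x_future[..., 10:] += 1.0               # change only time bins t >= 10

with torch.no_grad():
    causal_leak = (causal(x)[..., :10] - causal(x_future)[..., :10]).abs().max().item()
    same_leak = (same(x)[..., :10] - same(x_future)[..., :10]).abs().max().item()
```

Both versions keep the sequence length, so a shape test alone cannot tell them apart; only the perturbation check exposes that `padding="same"` lets bins 8 and 9 see the future.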
Non-causal traps to avoid
| Operation | Why it breaks causality |
|---|---|
| `nn.BatchNorm*` | Pools statistics across batch and time |
|  | Includes future timesteps |
|  | Pools over the full clip |
|  | Pools statistics over time within each sample |
| Unmasked self-attention | Attention sees all positions |
Use CausalLayerNorm instead of BatchNorm — it normalizes across
a single axis (channel or frequency, by dim argument) at each
position of every other axis, never pooling across time. Defined in
deepSTRF.models.layers; thin wrapper around nn.LayerNorm with a
movedim to reach a non-trailing target axis.
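Following that description, a thin `nn.LayerNorm` wrapper with `movedim` could look like this. `CausalLayerNormSketch` is a hypothetical reimplementation, not the code in `deepSTRF.models.layers`; its `(num_features, dim)` signature mirrors the calls used elsewhere in this note.

```python
import torch
import torch.nn as nn


class CausalLayerNormSketch(nn.Module):
    """LayerNorm over one axis (`dim`) at every position of all other axes."""

    def __init__(self, num_features: int, dim: int = 1, eps: float = 1e-5):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(num_features, eps=eps)  # gamma, beta: 2 * num_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.LayerNorm normalizes the trailing axis: move the target axis last,
        # normalize it, then move it back. No statistic is pooled across time.
        x = x.movedim(self.dim, -1)
        x = self.norm(x)
        return x.movedim(-1, self.dim)


x = torch.randn(2, 3, 34, 50)                   # (B, C, F, T)
freq_norm = CausalLayerNormSketch(34, dim=-2)   # normalize F at every (B, C, T)
chan_norm = CausalLayerNormSketch(3, dim=1)     # normalize C at every (B, F, T)
y = freq_norm(x)
```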
5. Where the weights live
Heuristic for placing learnable parameters:
| Parameter shape / role | Slot |
|---|---|
| Per-neuron | `readout` |
| Stim-shared | `core` |
| Stim-side, often biology-derived | `prefiltering` |
| Audio front-end, raw waveform to spectrogram | `wav2spec` (future) |
Worked example: LN model
For a Linear-Nonlinear model the entire trainable apparatus that
distinguishes one neuron from another is the STRF kernel
(N, C_in, F, T) — N independent filters, one per neuron — held by the
readout. The core does only one thing: per-timestep frequency
normalization.
```python
self.wav2spec = nn.Identity()
self.prefiltering = AdapTrans(...)          # or nn.Identity()
self.core = CausalLayerNorm(F, dim=-2)      # input freq norm
self.readout = STRFReadout(F, T, C_in, N,
                           kernel=...,
                           activation=ParametricSoftplus(N))
```
The core’s LayerNorm has 2*F parameters (γ, β per frequency band) but
is stim-shared across all N neurons. The per-neuron weights all live
in the readout. The template handles deeper models the same way — they
just populate core with more layers.
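To make the split concrete, here is the parameter count for a single-channel LN model with the `Linear` defaults (F=34, T=9) and N=10 neurons. Plain `nn.LayerNorm` and `nn.Conv2d` stand in for the core's norm and the readout's full STRF kernel; this assumes the vanilla full kernel plus one bias per neuron, and the shipped `STRFReadout` may store its kernel differently.

```python
import torch.nn as nn

F_bands, T_win, C_in, N = 34, 9, 1, 10

core_norm = nn.LayerNorm(F_bands)                         # gamma, beta per band
strf = nn.Conv2d(C_in, N, kernel_size=(F_bands, T_win))   # weight (N, C_in, F, T) + bias (N,)


def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())


shared = n_params(core_norm)       # 2 * F = 68, identical for every neuron
per_neuron = n_params(strf) // N   # C_in * F * T + 1 = 307 per neuron
```

Adding neurons grows only the readout (307 parameters each here); the 68 core parameters stay fixed.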
Worked example: ConvNet2D
```python
self.prefiltering = ...
self.core = nn.Sequential(
    CausalLayerNorm(F, dim=-2),                   # input freq norm
    nn.ZeroPad2d((3 * (K[1] - 1), 0, 0, 0)),      # explicit causal left-pad (time axis only)
    nn.Conv2d(C_in, C, K), CausalLayerNorm(C, dim=1), nn.LeakyReLU(),
    nn.Conv2d(C, C, K), CausalLayerNorm(C, dim=1), nn.LeakyReLU(),
    nn.Conv2d(C, C, K), CausalLayerNorm(C, dim=1), nn.LeakyReLU(),
    nn.Flatten(start_dim=1, end_dim=2),           # (B, C, F_down, T) → (B, C*F_down, T)
)
# F_down = F - 3 * (K[0] - 1): the convolutions are unpadded along frequency
self.readout = LinearReadout(C * F_down, N,
                             hidden=H,
                             activation=ParametricSoftplus(N))
```
The convolutional features are stim-shared (every neuron sees the same
backbone output); the per-neuron projection happens in the readout’s
final nn.Linear(H, N).
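A quick shape check of the padding arithmetic above, with the `CausalLayerNorm` and activation layers omitted because they do not change shapes. K=(3, 3), C=8 and F=34 are arbitrary illustration values. Each unpadded convolution trims K[0]-1 frequency bins and K[1]-1 time bins, so the single left-pad of 3*(K[1]-1) time bins restores the original length while frequency shrinks to F_down = F - 3*(K[0]-1).

```python
import torch
import torch.nn as nn

F_bands, T_len, C_in, C = 34, 100, 1, 8
K = (3, 3)                                   # (frequency, time) kernel size
F_down = F_bands - 3 * (K[0] - 1)            # three unpadded convs along frequency

backbone = nn.Sequential(
    nn.ZeroPad2d((3 * (K[1] - 1), 0, 0, 0)),  # (left, right, top, bottom): time axis only
    nn.Conv2d(C_in, C, K),
    nn.Conv2d(C, C, K),
    nn.Conv2d(C, C, K),
    nn.Flatten(start_dim=1, end_dim=2),       # (B, C, F_down, T) -> (B, C*F_down, T)
)

x = torch.randn(2, C_in, F_bands, T_len)
f = backbone(x)                               # (2, 8 * 28, 100)
```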
6. Prefiltering
Prefiltering is a stim-side preprocessing slot, often biologically motivated (cochlear adaptation, retinal lateral inhibition).
Contract
A prefilter is any nn.Module that:
- Takes a stimulus tensor of shape `(B, 1, F, T)` (audio) or `(B, 1, H, W, T)` (video).
- Returns a tensor of shape `(B, C_out, F, T)` or `(B, C_out, H, W, T)`, with
  the same spatial / spectral dimensions and an arbitrary number of channels.
- Exposes an integer attribute `out_channels`.
The contract is duck-typed (no formal ABC) — the base model reads
prefiltering.out_channels to size the downstream core.
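Because the contract is duck-typed, any module exposing `out_channels` qualifies. The hypothetical `OnOffSplit` below (not a shipped deepSTRF class) splits a spectrogram into rectified ON and OFF channels, turning `(B, 1, F, T)` into `(B, 2, F, T)`; it is pointwise, so it is trivially causal.

```python
import torch
import torch.nn as nn


class OnOffSplit(nn.Module):
    """Hypothetical prefilter: ON = relu(x), OFF = relu(-x), stacked as channels."""

    out_channels = 2  # read by the base model to size the core (C_in)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, 1, F, T) -> (B, 2, F, T); also works for (B, 1, H, W, T).
        return torch.cat([torch.relu(x), torch.relu(-x)], dim=1)


prefilt = OnOffSplit()
x = torch.randn(4, 1, 34, 200)
y = prefilt(x)
C_in = prefilt.out_channels
```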
Currently shipped
| Class | Module | `out_channels` | Learnable? | Reference |
|---|---|---|---|---|
| `nn.Identity` | `torch.nn` | 1 | n/a | default no-op |
| `ICAdaptation` |  | 1 | No | Willmore et al. 2016, J. Neurosci. |
| `AdapTrans` |  | 2 | Yes (by default) | Rançon et al. 2024, PLOS CB |
ICAdaptation is intentionally frozen — paper-faithful to Willmore
et al. who derive the time constants analytically from the cochlear
frequency map (freq_to_tau formula).
Ergonomics: make_prefiltering
For users who don’t want to construct a prefilter by hand:
```python
from deepSTRF.models.prefiltering import make_prefiltering

prefilt = make_prefiltering("adaptrans", n_frequency_bands=34, dt=5.0,
                            min_freq=500, max_freq=20000, scale="mel")
```
This is equivalent to constructing the AdapTrans prefilter by hand. The
factory is a convenience, not part of the contract: passing an nn.Module
instance directly is also fine.
7. Readouts
A readout takes the core’s output and emits (B, N, 1, T). Two are
shipped; more can be added per modality or per parameterization scheme.
STRFReadout
A learnable (N, C_in, F, T) STRF kernel applied via causal Conv2d. The
"linear" half of an LN model.
```python
STRFReadout(F, T, C_in, out_neurons,
            kernel: nn.Module = None,               # default: vanilla full kernel
            activation: nn.Module = nn.Identity(),
            bias: bool = True)
```
The kernel slot is itself pluggable, which is how parameterization
schemes ride into the model:
| Kernel class | What it parameterizes |
|---|---|
|  | Vanilla full kernel (default) |
|  | Sum of K Gaussians with learnable positions/sigmas (DCLS, Khalfaoui-Hassani 2023, ICLR) |
|  | Frequency × time outer product (rank-1 by construction) |
| Future: | Rank-r factorization |
Internally STRFReadout left-pads its input by T-1 zeros on the time
axis (causality) and applies Conv2d with the kernel. Output is
(B, N, 1, T) — the singleton frequency axis after (F → 1) reduction
serves naturally as the R axis after a transpose.
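A minimal sketch of that computation, assuming a plain full kernel stored as an `nn.Conv2d` weight of shape `(N, C_in, F, T)`. `STRFReadoutSketch` is hypothetical and omits the pluggable `kernel` slot and the `STRF_weight` helper. Because the kernel spans the whole frequency axis, F collapses to 1, and the T-1 left-pad keeps the time length; in this sketch the Conv2d output already has the `(B, N, 1, T_len)` layout.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class STRFReadoutSketch(nn.Module):
    """Hypothetical full-kernel STRF readout: causal Conv2d + output activation."""

    def __init__(self, F_bands, T_win, C_in, N, activation=None, bias=True):
        super().__init__()
        self.T_win = T_win
        self.conv = nn.Conv2d(C_in, N, kernel_size=(F_bands, T_win), bias=bias)
        self.activation = activation if activation is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C_in, F, T_len). Pad T_win-1 zeros on the left of the time axis only.
        x = F.pad(x, (self.T_win - 1, 0))
        # Full-band kernel: frequency collapses to 1, giving (B, N, 1, T_len).
        return self.activation(self.conv(x))


readout = STRFReadoutSketch(F_bands=34, T_win=9, C_in=2, N=5)
y = readout(torch.randn(3, 2, 34, 120))
```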
STRFReadout.STRF_weight(polarity='ON' | 'OFF') returns the underlying
kernel for visualization, regardless of parameterization. Models with an
STRFReadout proxy this method: model.STRF_weight() → readout’s.
LinearReadout
A per-neuron projection from a flat feature vector — the “fc” at the end of ConvNet / RNN / Transformer cores.
```python
LinearReadout(in_features, out_neurons,
              hidden: int = None,                   # optional 1-hidden-layer MLP
              activation: nn.Module = nn.Identity(),
              bias: bool = True)
```
Accepts (B, in_features, T) (channel-as-dim-1 convention, matching
Conv2d outputs) or (B, in_features, 1, T) (the singleton-spatial-axis
shape produced by an STRF-style conv that collapsed F → 1); both shapes
route through the same projection. Always emits (B, N, 1, T).
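A sketch of how the two accepted input layouts and the optional hidden layer could be handled, assuming the projection is applied independently at every time bin (which keeps it causal). `LinearReadoutSketch` and its internals, including the `LeakyReLU` hidden nonlinearity, are hypothetical; only the input and output shapes follow the contract above.

```python
import torch
import torch.nn as nn


class LinearReadoutSketch(nn.Module):
    def __init__(self, in_features, out_neurons, hidden=None, activation=None, bias=True):
        super().__init__()
        if hidden is None:
            self.proj = nn.Linear(in_features, out_neurons, bias=bias)
        else:
            # Optional 1-hidden-layer MLP (LeakyReLU is an assumption);
            # the final Linear(hidden, N) is the per-neuron part.
            self.proj = nn.Sequential(
                nn.Linear(in_features, hidden), nn.LeakyReLU(),
                nn.Linear(hidden, out_neurons, bias=bias),
            )
        self.activation = activation if activation is not None else nn.Identity()

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        if f.dim() == 4:                     # (B, in_features, 1, T): drop the singleton axis
            f = f.squeeze(2)
        f = f.transpose(1, 2)                # (B, T, in_features): Linear acts on the last dim
        y = self.proj(f)                     # (B, T, N), one projection per time bin
        y = y.transpose(1, 2).unsqueeze(2)   # (B, N, 1, T): add the R axis
        return self.activation(y)


readout = LinearReadoutSketch(224, 5, hidden=16)
f3 = torch.randn(2, 224, 100)                # (B, in_features, T)
f4 = f3.unsqueeze(2)                         # (B, in_features, 1, T)
y3, y4 = readout(f3), readout(f4)
```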
8. Output activations
Activations live inside the readout, not on the base class. Each
readout takes an activation: nn.Module kwarg.
Shipped
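| Class | Use | Reference |
|---|---|---|
| `nn.Identity` | Centered targets (z-scored PSTH) | default for `Linear` |
| `nn.Sigmoid` | Bounded to (0, 1) | n/a |
| `nn.Softplus` | Non-negative spike-rate targets, smooth | n/a |
| `ParametricSoftplus` | 2-param softplus (baseline b, sharpness β) | deepSTRF (default for the rest of the zoo) |
| `ParametricSigmoid` | 4-param sigmoid | Willmore et al. 2016, J. Neurosci. |
| `ParametricDoubleExponential` | 4-param double exponential | Thorson et al. 2015, PLOS CB |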
The parametric activations have one set of learnable parameters per
output neuron (N instances each). ParametricSoftplus is the default
for LinearNonlinear, NetworkReceptiveField, ConvNet2D, and StateNet
because it is unbounded above (smoothed PSTHs routinely peak above 1) and
non-negative by default — natural for spike-count regression. Pre-2026-05
defaults were nn.Sigmoid() for those four models, which empirically
caused mean-collapse on NS1 because the [0, 1] cap fights the gradient.
non_negative_output flag
All three parametric activations expose a non_negative_output: bool
constructor kwarg. When True (the default), the parameters that gate
non-negativity are stored as raw parameters and softplus-mapped to the
strictly-positive half-line at every forward pass. The output is then
guaranteed non-negative by construction — suitable for spike-count
targets paired with poisson_loss(log_input=False) (see
metrics_paradigm.md §6.2).
For ParametricSoftplus, the gated parameter is the additive baseline
b; the sharpness β is always softplus-reparameterised regardless
(non-positive sharpness would flip the curve and is never physically
meaningful).
When False, parameters are direct (signed-output mode). Use this for
LFP / EEG / centred PSTH targets where the output may legitimately be
negative; pair with poisson_loss(log_input=True) if you still want
a Poisson NLL.
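To illustrate the reparameterisation, here is a sketch of a two-parameter softplus with one baseline b and one sharpness β per neuron. The functional form f(x) = softplus(βx)/β + b and the class name `ParametricSoftplusSketch` are assumptions for illustration; the shipped `ParametricSoftplus` may differ in form and initialization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ParametricSoftplusSketch(nn.Module):
    """Assumed form: f(x) = softplus(beta * x) / beta + b, one (b, beta) per neuron."""

    def __init__(self, n_neurons: int, non_negative_output: bool = True):
        super().__init__()
        self.non_negative_output = non_negative_output
        self.raw_beta = nn.Parameter(torch.zeros(n_neurons))
        self.raw_b = nn.Parameter(torch.zeros(n_neurons))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, 1, T); per-neuron parameters broadcast along axis 1.
        beta = F.softplus(self.raw_beta).view(1, -1, 1, 1)  # always strictly positive
        b = self.raw_b.view(1, -1, 1, 1)
        if self.non_negative_output:
            b = F.softplus(b)  # gate: baseline mapped to the positive half-line
        return F.softplus(beta * x) / beta + b


x = torch.linspace(-50.0, 50.0, 500).view(1, 5, 1, 100)   # (B, N, 1, T)
act = ParametricSoftplusSketch(5)                          # non-negative by construction
signed = ParametricSoftplusSketch(5, non_negative_output=False)
with torch.no_grad():
    act.raw_b.fill_(-5.0)     # a very negative raw baseline stays >= 0 after softplus
    signed.raw_b.fill_(-5.0)  # but is used directly in signed-output mode
```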
Pairing with poisson_loss
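| Activation | Recommended `poisson_loss` setting |
|---|---|
| `nn.Identity` (signed output) | `log_input=True` |
| `nn.Sigmoid`, `nn.Softplus` | `log_input=False` |
| `ParametricSoftplus`, `ParametricSigmoid`, `ParametricDoubleExponential` with `non_negative_output=True` (default) | `log_input=False` |
| Any with `non_negative_output=False` | `log_input=True` |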
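The two settings describe the same loss under a change of variables. Assuming deepSTRF's `poisson_loss` follows the convention of PyTorch's `torch.nn.functional.poisson_nll_loss` (an assumption; metrics_paradigm.md has the actual definition), a signed output z paired with `log_input=True` scores the same as the non-negative rate exp(z) paired with `log_input=False`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(2, 5, 1, 100)                    # signed output, e.g. an nn.Identity readout
counts = torch.poisson(torch.full_like(z, 2.0))  # spike-count targets

# log_input=True: output read as a log-rate, loss = exp(z) - counts * z
loss_log = F.poisson_nll_loss(z, counts, log_input=True)
# log_input=False: output read as a rate, loss = rate - counts * log(rate + eps)
loss_rate = F.poisson_nll_loss(z.exp(), counts, log_input=False)
```

Feeding a signed output to `log_input=False` instead would put negative values inside the log, which is why the pairing matters.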
Empirical note (2026-05)
The original closure-based forward was replaced by a regular
forward() method, and the default reparameterisation guarantees
f(x) ≥ 0. End-to-end on NS1 + StateNet, the new
ParametricSoftplus(N) default reaches the published cc_norm range
(~0.7-0.8) in 50 epochs — competitive with hand-tuned nn.Identity().
The previous nn.Sigmoid() default for the same four models silently
mean-collapsed (val cc_norm slowly dropping toward 0 while loss
decreased) because the [0, 1] output cap fights the gradient on
spike-count targets that exceed 1.
ParametricSoftplus was chosen as the new default over the saturating
parametric activations (Sigmoid / DoubleExp) precisely because the
unbounded-above shape removes that failure mode. ParametricSigmoid and
ParametricDoubleExponential remain available for users with bounded
targets or explicit saturation modelling needs.
9. STRF introspection
Two complementary helpers:
| Method | What it returns | Where | Generality |
|---|---|---|---|
| `STRF_weight` | The underlying kernel | `STRFReadout` (proxied as `model.STRF_weight()`) | L / LN / NRF / DNet (closed-form readout) |
| `STRF_gradmap` | Gradient of output activity w.r.t. a null input | model | Any model (architecture-agnostic) |
STRF_gradmap computes one gradmap per output neuron in a single
forward/backward pass by using the batch dimension (B = self.O). This is
worth preserving: the batch grows linearly with the number of neurons but
stays a single GPU pass, and it avoids per-neuron re-instantiation.
For the full math, sign convention, and caveats see
README_gradmap_strf.md; the worked example
lives in examples/strf_gradmap_aa2.ipynb.
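The batching trick can be sketched on a toy linear STRF model: feed N copies of an all-zero stimulus, take neuron i's output at the last time bin from batch element i, and backpropagate once. Using an all-zero stimulus as the null input and reading the last time bin are assumptions of this sketch; `ToyLinearSTRF` and `gradmap_sketch` are hypothetical names. For a purely linear causal readout the gradient map recovers the kernel exactly, which makes a convenient check.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C_in, F_bands, T_win = 4, 1, 34, 9


class ToyLinearSTRF(nn.Module):
    """Causal full-kernel STRF model: left-pad time, then Conv2d to (B, N, 1, T)."""

    def __init__(self):
        super().__init__()
        self.O = N
        self.conv = nn.Conv2d(C_in, N, kernel_size=(F_bands, T_win))

    def forward(self, x):
        return self.conv(F.pad(x, (T_win - 1, 0)))


def gradmap_sketch(model, n_channels, n_bands, T_len):
    # One batch element per output neuron (B = model.O), each starting from a null input.
    x = torch.zeros(model.O, n_channels, n_bands, T_len, requires_grad=True)
    y = model(x)                          # (O, O, 1, T_len)
    idx = torch.arange(model.O)
    # Batch element i contributes only neuron i's output at the last time bin.
    y[idx, idx, 0, -1].sum().backward()
    return x.grad                         # (O, C_in, F, T_len): one map per neuron


model = ToyLinearSTRF()
gmap = gradmap_sketch(model, C_in, F_bands, T_win)
```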
10. The validate() contract
Every concrete model must satisfy:
- Slots populated. `wav2spec`, `prefiltering`, `core`, `readout` are all
  `nn.Module` instances. The first three default to `nn.Identity` if the
  subclass doesn't set them; `readout` has no default and must be assigned
  before `validate()` is called. `out_neurons > 0` (`self.O > 0`).
- Output rank. `forward(x)` emits `(B, N, R=1, T)`. The contract tests in
  `tests/test_audio_models.py` enforce this for every concrete audio model.
- Causality. `forward(x)[..., :t]` is independent of `x[..., t:]`. Enforced by
  the bit-causality test in
  `tests/test_audio_models.py::test_bitwise_causality_in_eval_mode`.
Subclasses extend validate() with modality-specific invariants —
AudioEncodingModel.validate() checks F, T > 0 and C_in >= 1;
VideoEncodingModel.validate() checks H, W, T > 0.
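A minimal causality check in the spirit of `test_bitwise_causality_in_eval_mode` can be sketched as follows; the helper name `is_bitwise_causal`, the random perturbation scheme and the toy `LeftPadConv` model are assumptions, not the repository's test code. Every input bin from t onward is replaced, and the outputs before t must stay bit-identical in eval mode.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def is_bitwise_causal(model: nn.Module, x: torch.Tensor, t: int) -> bool:
    """True if forward(x)[..., :t] does not change when x[..., t:] is perturbed."""
    model.eval()                                   # freeze dropout / norm statistics
    x_pert = x.clone()
    x_pert[..., t:] = torch.randn_like(x_pert[..., t:])
    return torch.equal(model(x)[..., :t], model(x_pert)[..., :t])


class LeftPadConv(nn.Module):
    """Toy model: full-band conv with either left-only or symmetric time padding."""

    def __init__(self, k: int, left_only: bool):
        super().__init__()
        self.k, self.left_only = k, left_only
        self.conv = nn.Conv2d(1, 3, kernel_size=(34, k))

    def forward(self, x):
        pad = (self.k - 1, 0) if self.left_only else ((self.k - 1) // 2, self.k // 2)
        return self.conv(F.pad(x, pad))


torch.manual_seed(0)
x = torch.randn(2, 1, 34, 64)
```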
12. Planned extensions
- Decoding models. Parallel `*DecodingModel` hierarchy when the first
  concrete decoder lands.
- Raw-waveform front-end. The `wav2spec` slot becomes an active position;
  LEAF (Zeghidour et al. 2021) is the first candidate.
- Transformer RoPE option. Sinusoidal positional encoding ships today; RoPE
  (Su et al. 2021) is a planned alternative once it has empirical support.
- … `gradmap.md` and `examples/strf_gradmap.ipynb`.
- Output activations recheck. Verify `ParametricSigmoid` and
  `ParametricDoubleExponential` against their original papers; add
  `Softplus` for non-negative-rate targets.
- SSM dependency cleanup. The Mamba backbone now uses the upstream `mambapy`
  package (a default dependency), replacing the vendored `mamba.py` /
  `pscan.py`. S4 and LMU remain vendored under `deepSTRF.models.dependencies`:
  there is no maintained PyTorch package on PyPI for either (`state-spaces/s4`
  ships `s4.py` as a copy-into-repo standalone, and `pytorch-lmu` is
  GitHub-only). Both rely solely on already-present dependencies
  (numpy/scipy/einops), so they add no install burden.