deepSTRF.models.audio package

Submodules

deepSTRF.models.audio.audio_model module

class deepSTRF.models.audio.audio_model.AudioEncodingModel(n_frequency_bands: int, temporal_window_size: int, out_neurons: int = 1, prefiltering: Module = None, wav2spec: Module = None, *args, **kwargs)

Bases: NeuralModel

Base class for encoding models of audio neural responses.

Forward signature: input (B, 1, F, T) spectrogram → output (B, N, R=1, T) neural activity. Concrete subclasses populate the four canonical slots wav2spec / prefiltering / core / readout (see NeuralModel).

Parameters:
  • n_frequency_bands (int) – Number of input spectrogram frequency bands F.

  • temporal_window_size (int) – STRF temporal extent T in frames. Used by STRF_gradmap to size the null stimulus.

  • out_neurons (int, default 1) – Number of output neurons N.

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (e.g. AdapTrans, ICAdaptation). Must expose an out_channels integer attribute so the model can size C_in automatically. None (default) gives nn.Identity() and C_in = 1.

  • wav2spec (nn.Module, optional) – Optional raw-waveform front-end. Maps a mono waveform (B, 1, T_audio) to a spectrogram (B, 1, F, T_neural) and slots in at the top of the canonical forward pipeline (see NeuralModel). Must expose an out_channels: int attribute equal to n_frequency_bands. Pair with a dataset in waveform mode (e.g. NS1Dataset(return_waveform=True)). None (default) keeps the slot as nn.Identity() — the model then expects spectrogram input (B, 1, F, T) as before.

STRF_gradmap(T: int = None)

Compute one STRF gradient map per output neuron in parallel.

For each of the N output neurons, finds the changes in a null spectrogram that elicit an increase in that neuron’s activity at the last timestep (a Spike-Triggered-Average-like readout, computed by autodiff). The batch dimension is used to parallelize across neurons in a single forward / backward pass.

Parameters:

T (int, optional) – Time-axis length of the null stimulus. Defaults to self.T.

Returns:

Per-neuron gradient map.

Return type:

Tensor of shape (N, 1, F, T)

References

Rançon et al. (2025), “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.”

Notes

Future work:

  • Handle multi-channel inputs (the gradient is currently shaped (N, 1, F, T) regardless of C_in; an AdapTrans-prefiltered model has C_in == 2 and the per-channel gradients differ).

  • Allow custom losses (e.g. sustained activity rather than last-timestep-only).
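
As an illustration of what the gradient map measures, the sketch below uses a toy causal linear STRF and replaces autodiff with finite differences. It is not the library's implementation: the function names and kernel values are hypothetical, and the "oldest delay first" kernel layout is an assumption. For a purely linear model, the gradient of the last-timestep response with respect to a null stimulus is the kernel itself.

```python
def causal_strf_last_response(kernel, stimulus):
    """Response at the last timestep of a causal linear STRF.

    kernel:   F x K nested list (frequency x delay), oldest delay first
              (assumed layout).
    stimulus: F x T nested list. Frames before t=0 count as zeros, which
              is the effect of causal left-padding.
    """
    n_freq, n_delay = len(kernel), len(kernel[0])
    n_time = len(stimulus[0])
    total = 0.0
    for f in range(n_freq):
        for k in range(n_delay):
            t = n_time - n_delay + k  # stimulus frame seen by kernel tap k
            if t >= 0:
                total += kernel[f][k] * stimulus[f][t]
    return total


def gradmap_finite_difference(kernel, n_time, eps=1.0):
    """Gradient of the last-timestep response around an all-zero stimulus."""
    n_freq = len(kernel)
    null = [[0.0] * n_time for _ in range(n_freq)]
    base = causal_strf_last_response(kernel, null)
    grad = [[0.0] * n_time for _ in range(n_freq)]
    for f in range(n_freq):
        for t in range(n_time):
            null[f][t] = eps  # perturb a single spectrogram bin
            grad[f][t] = (causal_strf_last_response(kernel, null) - base) / eps
            null[f][t] = 0.0  # restore the null stimulus
    return grad


toy_kernel = [[0.0, 0.5, 1.0], [-1.0, 0.0, 0.25]]  # F=2, K=3, hypothetical values
gradmap = gradmap_finite_difference(toy_kernel, n_time=3)
```

With n_time equal to the kernel length, gradmap matches toy_kernel; with a longer null stimulus, the extra leading frames receive zero gradient. Nonlinear models do not share this property, which is why the method differentiates the full network.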

validate()

Check that the instance is deepSTRF-compatible.

Subclasses should call super().validate() and then add their own checks (e.g. AudioEncodingModel checks F, T > 0).

Raises:

AssertionError – If self.O is not a positive int, if readout is unset or not a torch.nn.Module, or if any of the wav2spec / prefiltering / core slots is not a torch.nn.Module.
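
The chaining pattern described above (call super().validate(), then add checks) might look like the following sketch. Both classes are hypothetical stand-ins written for illustration, not the deepSTRF classes, and only the integer checks are reproduced; the nn.Module slot checks are omitted.

```python
class NeuralModelSketch:
    """Hypothetical stand-in for the base class, not the deepSTRF code."""

    def __init__(self, out_neurons):
        self.O = out_neurons

    def validate(self):
        # Base-level check: the number of outputs must be a positive int.
        assert isinstance(self.O, int) and self.O > 0, "O must be a positive int"


class AudioModelSketch(NeuralModelSketch):
    """Hypothetical subclass that adds its own checks after the base ones."""

    def __init__(self, n_frequency_bands, temporal_window_size, out_neurons=1):
        super().__init__(out_neurons)
        self.F = n_frequency_bands
        self.T = temporal_window_size

    def validate(self):
        super().validate()  # run the base checks first
        assert self.F > 0 and self.T > 0, "F and T must be positive"


def is_valid(model):
    """Convenience wrapper turning the AssertionError into a boolean."""
    try:
        model.validate()
        return True
    except AssertionError:
        return False
```

Because the checks are plain assert statements, they are skipped when Python runs with the -O flag.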

waveform_gradmap(stimulus, neuron=None, reduce='last')

Gradient of a neuron’s response w.r.t. the input waveform.

The waveform-domain analogue of STRF_gradmap(): instead of probing the spectrogram input, it backpropagates a neuron’s response all the way through the (learnable) wav2spec front-end to the raw audio samples. The returned gradient is itself a waveform, i.e. a listenable, time-domain receptive field. It is only defined for wav-native models (those with a non-Identity wav2spec).

Parameters:
  • stimulus (array-like, shape (T_audio,), (1, T_audio) or (1, 1, T_audio)) – Audio to compute the gradmap around (e.g. a real stimulus). A real stimulus is recommended over silence — adaptive front-ends (PCEN) are ill-conditioned at zero energy.

  • neuron (int, optional) – Which output neuron. None (default) sums over all neurons (a population gradmap).

  • reduce ({'last', 'peak', 'sum'}, default 'last') – How to reduce the neuron’s response over time before backprop. 'last' (default, matching STRF_gradmap()) maximizes the activation at the last timestep, so the gradient is supported only within the receptive field before it — it reveals the RF and decays to ~zero further into the past. 'peak' does the same at the peak-response timestep. 'sum' time-integrates over all output timesteps, which makes the gradient non-zero almost everywhere by construction (a whole-stimulus saliency map, not a receptive field).

Returns:

Per-audio-sample gradient — the waveform-domain receptive field. Computed in eval mode (the strictly-causal inference regime). With reduce='last' the support is the RF length (e.g. ~45 ms for a T=9 STRF on a mel front-end; longer for adaptive front-ends like LEAF, whose PCEN smoother adds a decaying temporal memory).

Return type:

torch.Tensor, shape (T_audio,)
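
The three reduce modes differ only in which output timesteps feed the scalar that is backpropagated. The sketch below shows that selection on a plain list of response values; the function name is hypothetical and the real method operates on tensors inside the autodiff graph.

```python
def reduction_target(response, reduce="last"):
    """Return (scalar objective, timesteps that contribute to it)."""
    if not response:
        raise ValueError("response must contain at least one timestep")
    if reduce == "last":
        # Only the final timestep drives the gradient.
        return response[-1], [len(response) - 1]
    if reduce == "peak":
        # Only the timestep with the largest response drives the gradient.
        peak_t = max(range(len(response)), key=lambda t: response[t])
        return response[peak_t], [peak_t]
    if reduce == "sum":
        # Every timestep contributes, so the gradient spreads over the stimulus.
        return sum(response), list(range(len(response)))
    raise ValueError(f"unknown reduce mode: {reduce!r}")
```

For a response of [0.1, 0.9, 0.3], 'last' selects timestep 2, 'peak' selects timestep 1, and 'sum' selects all three.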

deepSTRF.models.audio.audio_zoo module

class deepSTRF.models.audio.audio_zoo.ConvNet2D(n_frequency_bands: int = 34, kernel_size: tuple = (3, 9), c_hidden: int = 10, n_hidden: int = 20, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None)

Bases: AudioEncodingModel

Convolutional STRF model with three sequential 2D convs and a 2-layer fully-connected readout — adapted from the ‘2D-CNN’ of Pennington & David (2023).

Architecture: three Conv2d → CausalLayerNorm → LeakyReLU blocks extract a stack of feature maps; the per-time-step features are flattened over the (channel × downsampled-frequency) axes and a 2-layer FC reads out N output neurons.

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • kernel_size (tuple of int, default (3, 9)) – Conv2d kernel (K_F, K_T) shared across the three conv blocks.

  • c_hidden (int, default 10) – Number of channels in each conv block.

  • n_hidden (int, default 20) – Width of the FC hidden layer.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity at the output. The default is unbounded above and non-negative — natural for spike-count regression.

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.

References

Pennington & David (2023). “A convolutional neural network provides a generalizable model of natural sound coding by neural populations in auditory cortex.” PLOS Comp. Biol. 19(5): e1011110. https://doi.org/10.1371/journal.pcbi.1011110

Notes

Differences from the original paper:

  • A causal LayerNorm is added after each convolution; the paper uses no internal normalization.

  • Hidden activation is LeakyReLU(0.1) rather than ReLU. Empirical preference, very small architectural difference.

  • 2D convs over (F, T) rather than 1D convs over T (the paper applies 1D convolutions with implicit spectral pooling).

  • Frequency downsampling is implicit via valid-padding shrinkage: three convs each shrink F by K_F - 1, giving F_down = F - 3*(K_F - 1).

  • Causal left-padding extends the model to arbitrary input lengths; the paper also uses explicit causal padding.

  • The output activation is configurable; the paper uses a 4-parameter double-exponential — see deepSTRF.models.activations.ParametricDoubleExponential.
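
The shape bookkeeping in the notes above (valid padding along frequency, causal left-padding along time) can be checked with a few lines of arithmetic. This is a sketch with a hypothetical helper name; it tracks shapes only and does not build any layers.

```python
def convnet2d_feature_shape(n_freq, n_time, kernel_size=(3, 9), n_blocks=3):
    """Return (F_down, L) after n_blocks conv blocks."""
    k_f, k_t = kernel_size
    for _ in range(n_blocks):
        n_freq = n_freq - (k_f - 1)        # valid padding shrinks frequency
        padded_time = n_time + (k_t - 1)   # causal left-pad along time
        n_time = padded_time - (k_t - 1)   # valid conv restores the length
        if n_freq <= 0:
            raise ValueError("kernel too tall for this number of frequency bands")
    return n_freq, n_time


f_down, n_time = convnet2d_feature_shape(34, 100)  # defaults: F=34, kernel (3, 9)
flat_features = 10 * f_down                         # c_hidden=10 channels per timestep
```

With the default arguments this gives F_down = 34 - 3*(3 - 1) = 28, an unchanged time length, and 280 flattened features per timestep entering the fully-connected readout.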

class deepSTRF.models.audio.audio_zoo.DNet(n_frequency_bands: int = 34, temporal_window_size: int = 9, n_hidden: int = 20, init_tau: float = 2.0, decay_input: bool = True, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)

Bases: AudioEncodingModel

Dynamic Network (DNet) — an NRF whose hidden and output units are stateful with learnable temporal decay.

Architecture: STRF projection → channel-norm → sigmoid → learnable exponential decay (one time constant per hidden unit) → 1×1 readout → output activation. The exponential decay is causal: each unit’s output at time t is a convolution of its instantaneous input with a learned one-sided exponential kernel.
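
The equivalence between a stateful decaying unit and a one-sided exponential convolution can be sketched as follows. The parameterization is an assumption made for illustration (decay factor exp(-1/tau) per frame, no input scaling); the library's exact decay rule, including the decay_input weighting, is not reproduced here.

```python
import math


def leaky_integrate(inputs, tau):
    """Stateful form: y[t] = decay * y[t-1] + x[t] (assumed parameterization)."""
    decay = math.exp(-1.0 / tau)
    state, outputs = 0.0, []
    for x in inputs:
        state = decay * state + x
        outputs.append(state)
    return outputs


def exp_kernel_convolve(inputs, tau):
    """Convolution form: y[t] = sum_k decay**k * x[t-k], past samples only."""
    decay = math.exp(-1.0 / tau)
    outputs = []
    for t in range(len(inputs)):
        outputs.append(sum(decay ** k * inputs[t - k] for k in range(t + 1)))
    return outputs


signal = [0.2, 1.0, -0.5, 0.0, 0.3]
recursive = leaky_integrate(signal, tau=2.0)
convolved = exp_kernel_convolve(signal, tau=2.0)
```

Both forms give the same output up to floating-point error. Neither reads future samples, which is what makes the decay causal.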

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • temporal_window_size (int, default 9) – STRF temporal extent T.

  • n_hidden (int, default 20) – Hidden layer width H.

  • init_tau (float, default 2.0) – Initial time constant (in frames) for the hidden-unit exponential decay.

  • decay_input (bool, default True) – If True, the exponential decay also weights its instantaneous input by 1/(1+d²) (paper convention); if False, the instantaneous input passes through unscaled.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout output. Default nn.Identity (paper-faithful linear readout).

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter.

  • kernel (nn.Module, optional) – Pluggable hidden-layer STRF kernel.

References

Rahman, Willmore, King & Harper (2019). “A dynamic network model of temporal receptive fields in primary auditory cortex.” PLOS Comp. Biol. 15(5): e1006618. https://doi.org/10.1371/journal.pcbi.1006618

Notes

Differences from the original paper:

  • A causal LayerNorm is added as internal normalization; the paper has none and instead assumes the input is normalized during preprocessing.

  • Causal left-padding extends the model to arbitrary input lengths; the paper uses fixed-window slicing.

  • The hidden STRF kernel can be parameterized (DCLS); the paper uses a vanilla full kernel.

STRFs(hidden_idx: int = 0, polarity: str = 'ON')

Return the hidden-layer STRF kernel for one hidden unit as (F, T).

Parameters:
  • hidden_idx (int, default 0) – Which of the H hidden units to inspect.

  • polarity ({'ON', 'OFF'}, default 'ON') – Only relevant for AdapTrans-prefiltered models (C_in == 2).

class deepSTRF.models.audio.audio_zoo.Linear(n_frequency_bands: int = 34, temporal_window_size: int = 9, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None, wav2spec: Module = None)

Bases: AudioEncodingModel

The canonical Linear (L) STRF model — a single SpectroTemporal Receptive Field convolved with the (optionally prefiltered) input spectrogram.

All learnable parameters live in the readout (STRFReadout), which holds the kernel of shape (N, C_in, F, T), applies it causally via left-padding, and follows it with a per-neuron nn.BatchNorm1d(N) before the output activation. The model’s core is nn.Identity — every learnable scalar has the neuron axis as leading dim, so the model is strictly no-shared-params (each neuron’s parameters are independent of every other neuron’s).

Parameters:
  • n_frequency_bands (int, default 34) – Number of input spectrogram frequency bands F.

  • temporal_window_size (int, default 9) – STRF temporal extent T in frames.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, optional) – Pointwise nonlinearity applied at the readout output. Default nn.Identity (true linear model).

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any nn.Module exposing out_channels). None (default) gives nn.Identity and C_in = 1.

  • kernel (nn.Module, optional) – Pluggable STRF kernel for the readout. None (default) gives a vanilla nn.Conv2d; pass ParametricSTRF(...) for DCLS, or a separable nn.Sequential for a rank-1 factorization. See deepSTRF.models.layers for the kernel module catalogue.

  • wav2spec (nn.Module, optional) – Optional raw-waveform front-end (deepSTRF.models.wav2spec.*). When provided, the model accepts raw audio (B, 1, T_audio) instead of a spectrogram. None (default) keeps the slot as nn.Identity() and the model expects (B, 1, F, T) spec input.

References

The L model is a folklore baseline; canonical formulations appear in:

Theunissen, Sen & Doupe (2000). “Spectral-Temporal Receptive Fields of Nonlinear Auditory Neurons Obtained Using Natural Sounds.” J. Neurosci. 20(6): 2315–2331. https://doi.org/10.1523/JNEUROSCI.20-06-02315.2000

Sahani & Linden (2003). “How Linear are Auditory Cortical Responses?” NIPS.

Notes

The per-neuron BatchNorm inside the readout can be folded into the kernel at inference (its running statistics are frozen per-channel scalars), so at eval time the model remains a strictly affine map of the input and the learned STRF kernel is directly interpretable up to a per-neuron affine rescaling.
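
The folding of an eval-mode BatchNorm into a linear kernel can be written out for a single neuron. This is a generic sketch, not the readout's code: the helper names and numeric values are hypothetical, the kernel is flattened to a weight list, and eps=1e-5 is the default of torch.nn.BatchNorm1d.

```python
import math


def linear_response(weights, bias, x):
    """Flattened STRF kernel applied to a flattened stimulus window."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias


def batchnorm_eval(y, running_mean, running_var, gamma, beta, eps=1e-5):
    """Eval-mode BatchNorm for one channel: a fixed affine map."""
    return gamma * (y - running_mean) / math.sqrt(running_var + eps) + beta


def fold_batchnorm(weights, bias, running_mean, running_var, gamma, beta, eps=1e-5):
    """Return (weights, bias) of the single affine map equal to linear + BN."""
    scale = gamma / math.sqrt(running_var + eps)
    return [scale * w for w in weights], scale * (bias - running_mean) + beta


weights, bias = [0.5, -1.0, 2.0], 0.1  # hypothetical kernel and bias
stats = dict(running_mean=0.3, running_var=4.0, gamma=1.5, beta=-0.2)
folded_w, folded_b = fold_batchnorm(weights, bias, **stats)

x = [1.0, 2.0, -0.5]
two_step = batchnorm_eval(linear_response(weights, bias, x), **stats)
one_step = linear_response(folded_w, folded_b, x)
```

The two results agree up to floating-point error. The folded weights are the original kernel multiplied by one positive or negative scalar, which is the per-neuron rescaling mentioned in the note.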

STRF_weight(polarity: str = 'ON')

Return the readout’s STRF kernel as a (N, F, T) tensor.

For models prefiltered with AdapTrans (C_in == 2), polarity selects the ON or OFF channel of the kernel. For single-channel inputs the parameter is ignored.

class deepSTRF.models.audio.audio_zoo.LinearNonlinear(n_frequency_bands: int = 34, temporal_window_size: int = 9, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)

Bases: Linear

Linear-Nonlinear (LN) STRF model — the Linear model followed by a pointwise output nonlinearity.

Inherits everything from Linear and only changes the default output activation from nn.Identity to a per-neuron ParametricSoftplus. Pass any nn.Module to output_activation to override.

Parameters:

output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity applied at the readout output. The default is unbounded above and non-negative — natural for spike-count regression on smoothed PSTHs that exceed 1. See deepSTRF.models.activations for other parametric variants (ParametricSigmoid, ParametricDoubleExponential).

See also

Linear

Same architecture without the output nonlinearity.

class deepSTRF.models.audio.audio_zoo.NetworkReceptiveField(n_frequency_bands: int = 34, temporal_window_size: int = 9, n_hidden: int = 20, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)

Bases: AudioEncodingModel

Network Receptive Field (NRF) model — a two-layer feedforward STRF network.

Architecture: a STRF kernel projects the input spectrogram into a hidden layer of H units; a 1×1 conv reads out the N output neurons from the hidden activations. With L1 regularization the paper finds typically 1–7 effective hidden units per neuron.

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • temporal_window_size (int, default 9) – STRF temporal extent T.

  • n_hidden (int, default 20) – Hidden layer width H.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout output. Default nn.Sigmoid.

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.

  • kernel (nn.Module, optional) – Pluggable hidden-layer STRF kernel. None (default) gives a vanilla nn.Conv2d; pass ParametricSTRF(...) for DCLS.

References

Harper, Schoppe, Willmore, Cui, Schnupp & King (2016). “Network Receptive Field Modeling Reveals Extensive Integration and Multi-feature Selectivity in Auditory Cortical Neurons.” PLOS Comp. Biol. 12(11): e1005113. https://doi.org/10.1371/journal.pcbi.1005113

Notes

Differences from the original paper:

  • We add a causal LayerNorm over input frequencies and over the hidden channel axis. The original assumes preprocessing-time input normalization and uses no internal norm.

  • The hidden activation is the standard unscaled nn.Tanh. The paper uses a scaled tanh (ρ₁ ≈ 1.7159, ρ₂ = 2/3); the two are equivalent up to a learned rescaling absorbed into the readout.

  • Causal left-padding extends the model to arbitrary input lengths; the paper uses fixed-window slicing.

  • The hidden STRF kernel can be parameterized (DCLS); the paper uses a vanilla full kernel.
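
The tanh note above can be checked numerically on a toy single hidden unit. The sketch assumes no normalization between the input weight and the activation, which differs from the actual model; the function names and weight values are hypothetical.

```python
import math

RHO1, RHO2 = 1.7159, 2.0 / 3.0  # scaled-tanh constants quoted in the notes


def unit_scaled_tanh(x, w_in, w_out):
    """Toy hidden unit with the scaled tanh rho1 * tanh(rho2 * u)."""
    return w_out * RHO1 * math.tanh(RHO2 * w_in * x)


def unit_plain_tanh(x, w_in, w_out):
    """Same toy unit with the standard unscaled tanh."""
    return w_out * math.tanh(w_in * x)


w_in, w_out = 0.8, -1.3  # hypothetical weights
inputs = [-2.0, -0.5, 0.0, 0.7, 3.0]
scaled = [unit_scaled_tanh(x, w_in, w_out) for x in inputs]
# Move the two gains into the weights on either side of a plain tanh.
plain = [unit_plain_tanh(x, RHO2 * w_in, RHO1 * w_out) for x in inputs]
```

In this toy setting the output gain moves into the readout weight and the input gain into the weight feeding the unit, so both variants describe the same family of functions.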

STRFs(hidden_idx: int = 0, polarity: str = 'ON')

Return the hidden-layer STRF kernel for one hidden unit as (F, T).

Parameters:
  • hidden_idx (int, default 0) – Which of the H hidden units to return the STRF for.

  • polarity ({'ON', 'OFF'}, default 'ON') – Only relevant when the prefilter has C_in == 2 (e.g. AdapTrans). Selects the ON or OFF channel of the kernel.

class deepSTRF.models.audio.audio_zoo.StateNet(n_frequency_bands: int = 34, temporal_window_size: int = 1, kernel_size: int = 7, stride: int = 3, hidden_channels: int = 7, connectivity: str = 'LC', rnn_type: str = 'GRU', out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, wav2spec: Module = None)

Bases: AudioEncodingModel

Fully stateful STRF model — relies entirely on temporal recurrence to extract information from stimulus sequences, with no explicit STRF delay window.

Architecture: a stateless per-timestep spectral encoder maps each spectrogram column (C_in, F) to a hidden representation (C, F_down). The flattened hidden representation is fed timestep-by-timestep to a recurrent (or state-space) model that accumulates context implicitly through its hidden state. A linear readout projects the recurrent hidden state to N output neurons.

Causality is inherent to the recurrent backbone (RNN/GRU/LSTM/LMU/Mamba/S4). The spectral encoder operates on a single timestep at a time so it does not couple frames temporally.

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • temporal_window_size (int, default 1) – Unused by StateNet (kept for AudioEncodingModel API compatibility); recurrence handles temporal context.

  • kernel_size (int, default 7) – Frequency kernel size for the spectral encoder.

  • stride (int, default 3) – Frequency stride for the spectral encoder.

  • hidden_channels (int, default 7) – Channel count of the spectral encoder C.

  • connectivity ({'LC', 'FC', 'CONV'}, default 'LC') – Spectral encoder connectivity. 'LC': locally-connected 1D layer (frequency-position-specific weights). 'FC': dense linear projection with reshape to (C, F_down). 'CONV': weight-shared 1D convolution.

  • rnn_type ({'GRU', 'LSTM', 'RNN', 'vanilla', 'LMU', 'Mamba', 'S4'}, default 'GRU') – Recurrent / state-space backbone.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity at the output. The default is unbounded above and non-negative — natural for spike-count regression.

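  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.

  • wav2spec (nn.Module, optional) – Optional raw-waveform front-end; see AudioEncodingModel. None (default) keeps the slot as nn.Identity() and the model expects spectrogram input.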

References

Rançon, Masquelier & Cottereau (2025). “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.” Communications Biology 8:1456. https://doi.org/10.1038/s42003-025-08858-3

Notes

  • The spectral encoder uses a CausalLayerNorm over the channel axis (C); the original implementation used BatchNorm1d which pools statistics over the (T*B, F_down) axis, making it non-causal.

  • The S4 backbone is imported lazily — its module emits CUDA-extension warnings on import that other backends would not see.

forward(x)

Per-timestep spectral encoder feeding a recurrent backbone.

Overrides the base template because the encoder runs in a flattened (T*B, C_in, F) batch (so every timestep is independent in the spectral pass) and the RNN expects a batch-first (B, T, H) layout — these reshapes don’t decompose into the canonical core / readout slots.
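
The reshapes described above can be followed as shape bookkeeping. The sketch below uses a hypothetical helper and assumes the spectral encoder applies no frequency padding, so that F_down = (F - kernel_size) // stride + 1; the actual encoder may pad differently.

```python
def statenet_shapes(batch, c_in, n_freq, n_time,
                    kernel_size=7, stride=3, hidden_channels=7):
    """Tensor shapes at each stage of the forward pass (shapes only)."""
    # Assumption: no padding along frequency in the spectral encoder.
    f_down = (n_freq - kernel_size) // stride + 1
    return {
        "input": (batch, c_in, n_freq, n_time),
        # Time is folded into the batch so each frame is encoded independently.
        "encoder_in": (n_time * batch, c_in, n_freq),
        "encoder_out": (n_time * batch, hidden_channels, f_down),
        # Channels and frequency are flattened; the RNN sees (B, T, H).
        "rnn_in": (batch, n_time, hidden_channels * f_down),
    }


shapes = statenet_shapes(batch=4, c_in=1, n_freq=34, n_time=50)
```

Under these assumptions the default configuration gives F_down = 10 and a recurrent input width of 70 features per timestep.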

class deepSTRF.models.audio.audio_zoo.Transformer(n_frequency_bands: int = 34, freq_patch_size: int = None, time_patch_size: int = 1, context_window: int = None, embedding_dim: int = 48, n_heads: int = 1, n_layers: int = 1, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, wav2spec: Module = None)

Bases: AudioEncodingModel

Attention-based STRF model — a Transformer encoder runs causal self-attention over a per-timestep token sequence extracted from the spectrogram.

Architecture:

input               (B, 1, F, L)
  ↓ prefilter       (B, C_in, F, L)
  ↓ time pad        (B, C_in, F, L + K_T - 1)        left-only causal
  ↓ patchify        (B, embedding_dim, F_p, L)       Conv2d, stride=(K_F, 1)
  ↓ flatten/permute (B, L, embedding_dim * F_p)      one token per timestep
  ↓ + sinusoidal positional encoding
  ↓ TransformerEncoder with causal (+ optional window) mask
  ↓ readout         (B, N, 1, L)

The patchifier is a strided nn.Conv2d with kernel (K_F, K_T) and stride (K_F, 1). Frequency patches are non-overlapping (F_p = F // K_F); time stride is 1 so there is one token per timestep, and K_T > 1 lets each token aggregate time_patch_size recent frames. A (K_T - 1)-zero left pad along time keeps the patchifier strictly causal.
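
The patchifier arithmetic can be verified without building the model. This sketch uses a hypothetical helper name and only tracks shapes; it also applies the two divisibility constraints stated in the parameter list (freq_patch_size must divide F, n_heads must divide embedding_dim * F_p).

```python
def transformer_token_shape(seq_len, n_freq=34, freq_patch_size=None,
                            time_patch_size=1, embedding_dim=48, n_heads=1):
    """Return (number of tokens, token dimension) after patchify + flatten."""
    k_f = n_freq if freq_patch_size is None else freq_patch_size
    if n_freq % k_f != 0:
        raise ValueError("freq_patch_size must divide n_frequency_bands")
    f_p = n_freq // k_f                            # non-overlapping frequency patches
    padded_len = seq_len + (time_patch_size - 1)   # causal left pad
    n_tokens = padded_len - time_patch_size + 1    # time stride 1: one token per frame
    token_dim = embedding_dim * f_p
    if token_dim % n_heads != 0:
        raise ValueError("n_heads must divide embedding_dim * F_p")
    return n_tokens, token_dim
```

With the defaults, a 100-frame input yields 100 tokens of dimension 48. Setting freq_patch_size=17 gives F_p = 2 and tokens of dimension 96, while time_patch_size only changes the padding, never the token count.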

The attention mask is constructed per forward pass at the actual sequence length, so the model generalizes to any input length L. Setting context_window to a positive int restricts attention to the most recent context_window past frames (band-causal), recovering a Sahani-style fixed STRF context window when wanted — the model can be evaluated with or without the bound at inference without retraining.
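
A band-causal mask of this kind can be sketched with nested lists. The helper name is hypothetical, and the sketch assumes that context_window counts the current frame as part of the window; the library may define the boundary differently.

```python
def band_causal_allowed(seq_len, context_window=None):
    """allowed[q][k] is True when query frame q may attend to key frame k."""
    allowed = []
    for query in range(seq_len):
        row = []
        for key in range(seq_len):
            causal = key <= query  # never attend to future frames
            # Assumption: the window includes the current frame.
            in_window = context_window is None or query - key < context_window
            row.append(causal and in_window)
        allowed.append(row)
    return allowed


full_causal = band_causal_allowed(4)
banded = band_causal_allowed(4, context_window=2)
```

Because the mask depends only on seq_len, it can be rebuilt for every input length. Note that PyTorch's boolean attention masks use the opposite convention (True marks positions that are not allowed), so a mask like this would be inverted before being passed to the encoder.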

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • freq_patch_size (int, optional) – Frequency-axis patch size for the patchifier. None (default) uses F itself — one token spans the full frequency axis at each timestep. Must divide F.

  • time_patch_size (int, default 1) – Temporal extent of each patch in frames. 1 gives one token = one timestep slice; larger values let each token aggregate across time_patch_size recent frames via the patchifier.

  • context_window (int, optional) – If set, restrict attention to the most recent context_window past frames (still causal — band-causal mask). None (default) gives unlimited causal context.

  • embedding_dim (int, default 48) – Per-patch embedding dimension after the patchifier.

  • n_heads (int, default 1) – Number of attention heads. Must divide embedding_dim * F_p.

  • n_layers (int, default 1) – Number of TransformerEncoderLayer blocks.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout. Default nn.Identity.

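  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter.

  • wav2spec (nn.Module, optional) – Optional raw-waveform front-end; see AudioEncodingModel. None (default) keeps the slot as nn.Identity() and the model expects spectrogram input.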

References

Rançon, Masquelier & Cottereau (2025). “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.” Communications Biology 8:1456. https://doi.org/10.1038/s42003-025-08858-3

Vaswani et al. (2017). “Attention Is All You Need.” NeurIPS.

Notes

Sinusoidal positional encoding (Vaswani 2017) is used by default; it generalizes to arbitrary sequence lengths at inference. RoPE (Rotary Position Embedding, Su et al. 2021) is a planned alternative; it is omitted here because it requires a custom TransformerEncoderLayer (PyTorch’s stock module hides Q and K).

forward(x)

Causal-attention forward.

Overrides the base template because the attention mask must be rebuilt at the actual sequence length L of each input.

Module contents

class deepSTRF.models.audio.AudioEncodingModel(n_frequency_bands: int, temporal_window_size: int, out_neurons: int = 1, prefiltering: Module = None, wav2spec: Module = None, *args, **kwargs)

Bases: NeuralModel

Base class for encoding models of audio neural responses.

Forward signature: input (B, 1, F, T) spectrogram → output (B, N, R=1, T) neural activity. Concrete subclasses populate the four canonical slots wav2spec / prefiltering / core / readout (see NeuralModel).

Parameters:
  • n_frequency_bands (int) – Number of input spectrogram frequency bands F.

  • temporal_window_size (int) – STRF temporal extent T in frames. Used by STRF_gradmap to size the null stimulus.

  • out_neurons (int, default 1) – Number of output neurons N.

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (e.g. AdapTrans, ICAdaptation). Must expose an out_channels integer attribute so the model can size C_in automatically. None (default) gives nn.Identity() and C_in = 1.

  • wav2spec (nn.Module, optional) – Optional raw-waveform front-end. Maps a mono waveform (B, 1, T_audio) to a spectrogram (B, 1, F, T_neural) and slots in at the top of the canonical forward pipeline (see NeuralModel). Must expose an out_channels: int attribute equal to n_frequency_bands. Pair with a dataset in waveform mode (e.g. NS1Dataset(return_waveform=True)). None (default) keeps the slot as nn.Identity() — the model then expects spectrogram input (B, 1, F, T) as before.

STRF_gradmap(T: int = None)

Compute one STRF gradient map per output neuron in parallel.

For each of the N output neurons, finds the changes in a null spectrogram that elicit an increase in that neuron’s activity at the last timestep (a Spike-Triggered-Average-like readout, computed by autodiff). The batch dimension is used to parallelize across neurons in a single forward / backward pass.

Parameters:

T (int, optional) – Time-axis length of the null stimulus. Defaults to self.T.

Returns:

Per-neuron gradient map.

Return type:

Tensor of shape (N, 1, F, T)

References

Rançon et al. (2025), “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.”

Notes

Future work:

  • Handle multi-channel inputs (the gradient is currently shaped (N, 1, F, T) regardless of C_in; an AdapTrans-prefiltered model has C_in == 2 and the per-channel gradients differ).

  • Allow custom losses (e.g. sustained activity rather than last-timestep-only).

validate()

Check that the instance is deepSTRF-compatible.

Subclasses should call super().validate() and then add their own checks (e.g. AudioEncodingModel checks F, T > 0).

Raises:

AssertionError – If self.O is not a positive int, if readout is unset or not a torch.nn.Module, or if any of the wav2spec / prefiltering / core slots is not a torch.nn.Module.

waveform_gradmap(stimulus, neuron=None, reduce='last')

Gradient of a neuron’s response w.r.t. the input waveform.

The waveform-domain analogue of STRF_gradmap(): instead of probing the spectrogram input, it backpropagates a neuron’s response all the way through the (learnable) wav2spec front-end to the raw audio samples. The returned gradient is itself a waveform, i.e. a listenable, time-domain receptive field. It is only defined for wav-native models (those with a non-Identity wav2spec).

Parameters:
  • stimulus (array-like, shape (T_audio,), (1, T_audio) or (1, 1, T_audio)) – Audio to compute the gradmap around (e.g. a real stimulus). A real stimulus is recommended over silence — adaptive front-ends (PCEN) are ill-conditioned at zero energy.

  • neuron (int, optional) – Which output neuron. None (default) sums over all neurons (a population gradmap).

  • reduce ({'last', 'peak', 'sum'}, default 'last') – How to reduce the neuron’s response over time before backprop. 'last' (default, matching STRF_gradmap()) maximizes the activation at the last timestep, so the gradient is supported only within the receptive field before it — it reveals the RF and decays to ~zero further into the past. 'peak' does the same at the peak-response timestep. 'sum' time-integrates over all output timesteps, which makes the gradient non-zero almost everywhere by construction (a whole-stimulus saliency map, not a receptive field).

Returns:

Per-audio-sample gradient — the waveform-domain receptive field. Computed in eval mode (the strictly-causal inference regime). With reduce='last' the support is the RF length (e.g. ~45 ms for a T=9 STRF on a mel front-end; longer for adaptive front-ends like LEAF, whose PCEN smoother adds a decaying temporal memory).

Return type:

torch.Tensor, shape (T_audio,)

class deepSTRF.models.audio.ConvNet2D(n_frequency_bands: int = 34, kernel_size: tuple = (3, 9), c_hidden: int = 10, n_hidden: int = 20, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None)

Bases: AudioEncodingModel

Convolutional STRF model with three sequential 2D convs and a 2-layer fully-connected readout — adapted from the ‘2D-CNN’ of Pennington & David (2023).

Architecture: three Conv2d → CausalLayerNorm → LeakyReLU blocks extract a stack of feature maps; the per-time-step features are flattened over the (channel × downsampled-frequency) axes and a 2-layer FC reads out N output neurons.

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • kernel_size (tuple of int, default (3, 9)) – Conv2d kernel (K_F, K_T) shared across the three conv blocks.

  • c_hidden (int, default 10) – Number of channels in each conv block.

  • n_hidden (int, default 20) – Width of the FC hidden layer.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity at the output. The default is unbounded above and non-negative — natural for spike-count regression.

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.

References

Pennington & David (2023). “A convolutional neural network provides a generalizable model of natural sound coding by neural populations in auditory cortex.” PLOS Comp. Biol. 19(5): e1011110. https://doi.org/10.1371/journal.pcbi.1011110

Notes

Differences from the original paper:

  • A causal LayerNorm is added after each convolution; the paper uses no internal normalization.

  • Hidden activation is LeakyReLU(0.1) rather than ReLU. Empirical preference, very small architectural difference.

  • 2D convs over (F, T) rather than 1D convs over T (the paper applies 1D convolutions with implicit spectral pooling).

  • Frequency downsampling is implicit via valid-padding shrinkage: three convs each shrink F by K_F - 1, giving F_down = F - 3*(K_F - 1).

  • Causal left-padding extends the model to arbitrary input lengths; the paper also uses explicit causal padding.

  • The output activation is configurable; the paper uses a 4-parameter double-exponential — see deepSTRF.models.activations.ParametricDoubleExponential.

class deepSTRF.models.audio.DNet(n_frequency_bands: int = 34, temporal_window_size: int = 9, n_hidden: int = 20, init_tau: float = 2.0, decay_input: bool = True, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)

Bases: AudioEncodingModel

Dynamic Network (DNet) — an NRF whose hidden and output units are stateful with learnable temporal decay.

Architecture: STRF projection → channel-norm → sigmoid → learnable exponential decay (one time constant per hidden unit) → 1×1 readout → output activation. The exponential decay is causal: each unit’s output at time t is a convolution of its instantaneous input with a learned one-sided exponential kernel.

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • temporal_window_size (int, default 9) – STRF temporal extent T.

  • n_hidden (int, default 20) – Hidden layer width H.

  • init_tau (float, default 2.0) – Initial time constant (in frames) for the hidden-unit exponential decay.

  • decay_input (bool, default True) – If True, the exponential decay also weights its instantaneous input by 1/(1+d²) (paper convention); if False, the instantaneous input passes through unscaled.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout output. Default nn.Identity (paper-faithful linear readout).

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter.

  • kernel (nn.Module, optional) – Pluggable hidden-layer STRF kernel.

References

Rahman, Willmore, King & Harper (2019). “A dynamic network model of temporal receptive fields in primary auditory cortex.” PLOS Comp. Biol. 15(5): e1006618. https://doi.org/10.1371/journal.pcbi.1006618

Notes

Differences from the original paper:

  • A causal LayerNorm is added as internal normalization; the paper assumes preprocessing-time input normalization.

  • Causal left-padding extends the model to arbitrary input lengths; the paper uses fixed-window slicing.

  • The hidden STRF kernel can be parameterized (DCLS); the paper uses a vanilla full kernel.

STRFs(hidden_idx: int = 0, polarity: str = 'ON')[source]

Return the hidden-layer STRF kernel for one hidden unit as (F, T).

Parameters:
  • hidden_idx (int, default 0) – Which of the H hidden units to inspect.

  • polarity ({'ON', 'OFF'}, default 'ON') – Only relevant for AdapTrans-prefiltered models (C_in == 2).

class deepSTRF.models.audio.ICNet(audio_fs: int, out_neurons: int, dt_ms: float = 5.0, n_filters: int = 48, sincnet_kernel_size: int = 64, encoder_channels: int = 128, encoder_kernel_size: int = 64, n_encoder_layers: int = 5, bottleneck_channels: int = 64, encoder_strides: Sequence[int] | None = None)[source]

Bases: AudioEncodingModel

End-to-end ICNet (Drakopoulos et al. 2025) ported to deepSTRF.

Architecture: SincNet (48 filters, K=64, stride 1, symlog) → 5× causal Conv1d(128 ch, K=64, PReLU) at strides that multiply to audio_fs · dt_ms / 1000 → bottleneck Conv1d(64 ch, K=64, stride 1, PReLU) → Linear(64 → N) → softplus (Poisson head, N_c = 1 in paper notation).

Cross-dataset configuration

The paper trains on 24 414 Hz gerbil-IC audio binned at ~1.31 ms (32 samples per bin, 5 stride-2 conv layers). To use the same architecture on a dataset at a different (audio_fs, dt_ms), the encoder strides are auto-factored so they multiply to audio_fs · dt_ms / 1000 (the number of audio samples per neural bin). For NS1 (48 kHz / 5 ms) that’s 240 samples / bin and the default factorisation is [2, 2, 2, 2, 15]. Pass an explicit encoder_strides list to override. The layer structure (kernel sizes, channel counts, activations) stays paper-faithful; only the strides scale with the dataset, per the deepSTRF policy of adapting hyperparameters to each dataset’s temporal resolution.
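
Example (illustrative): one way to factor the samples-per-bin count into per-layer strides. The greedy rule used here (smallest prime factor for each of the first n_layers - 1 layers, remainder in the last layer) is an assumption for this sketch. It reproduces the two factorisations quoted above, but the source does not show the rule deepSTRF actually applies. The function name factor_strides is hypothetical.

```python
def factor_strides(audio_fs: int, dt_ms: float, n_layers: int = 5) -> list:
    """Split audio_fs * dt_ms / 1000 into n_layers integer strides."""
    samples_per_bin = audio_fs * dt_ms / 1000
    remaining = round(samples_per_bin)
    if abs(samples_per_bin - remaining) > 1e-6 or remaining < 1:
        raise ValueError("audio_fs * dt_ms / 1000 must be a positive integer")
    strides = []
    for _ in range(n_layers - 1):
        # Smallest prime factor of what is left, or 1 once nothing is left.
        factor = next((p for p in range(2, remaining + 1) if remaining % p == 0), 1)
        strides.append(factor)
        remaining //= factor
    strides.append(remaining)  # the last layer absorbs the remainder
    return strides


print(factor_strides(48000, 5.0))  # [2, 2, 2, 2, 15]
```

The product of the returned strides always equals the samples-per-bin count, which is the constraint that matters for matching the dataset's temporal resolution.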

The decoder is intentionally simple, following the paper: “the simple linear decoders in ICNet … ensure that the latent representation in the bottleneck is constrained to directly reflect the dynamics that underlie neural activity”. The expressivity lives in the shared encoder.

Differences from the paper

  • Single-branch / time-invariant only. The paper’s multi-branch and time-variant heads (animal-specific decoders, timestamp-input modulation) are out of scope for the deepSTRF v1 port.

  • Poisson head only. The paper’s main result uses a categorical cross-entropy head with N_c = 5 classes for spike counts in {0, 1, 2, 3, ≥4}. The deepSTRF training stack centres on rate-based losses; cross-entropy can be added later.

  • No left-context crop. The paper feeds 10 240 audio samples in and crops the leftmost 64 frames from the bottleneck output to suppress edge effects. deepSTRF’s convention is to keep T_neural output frames matching the dataset’s response window; causal convs leave the first few frames noisier but downstream losses handle that.

Parameters:
  • audio_fs (int) – Audio sample rate (Hz). Determines the total encoder downsampling.

  • out_neurons (int) – Number of output neurons N.

  • dt_ms (float, default 5.0) – Target neural bin width in ms. Encoder strides are factored so the total downsampling matches audio_fs · dt_ms / 1000.

  • n_filters (int, default 48) – SincNet filter count.

  • sincnet_kernel_size (int, default 64)

  • encoder_channels (int, default 128)

  • encoder_kernel_size (int, default 64)

  • n_encoder_layers (int, default 5)

  • bottleneck_channels (int, default 64) – Output channel count of the bottleneck conv.

  • encoder_strides (sequence of int, optional) – Per-layer encoder strides. Default: auto-factor.

References

Drakopoulos, Pellatt, Sabesan, Xia, Fragner & Lesica (2025). “Modelling neural coding in the auditory midbrain with high resolution and accuracy.” Nature Machine Intelligence 7:1478-1493. https://doi.org/10.1038/s42256-025-01104-9

forward(x: Tensor) → Tensor[source]

Forward pass.

Overrides the base template because the bottleneck latent is shaped (B, 1, 64, T) (an explicit C_in axis on top of the latent dim) and the paper’s decoder is a per-timestep linear map — the canonical STRFReadout slot doesn’t fit cleanly.

Parameters:

x (torch.Tensor) – Mono waveform, shape (B, 1, T_audio).

Returns:

Predicted spike rate, shape (B, N, 1, T_neural). Non-negative (softplus output) — pair with poisson_loss().

Return type:

torch.Tensor
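
Example (illustrative): a plain-Python sketch of why a softplus output pairs naturally with a Poisson objective. The helper poisson_nll below is a generic Poisson negative log-likelihood written for this example; it is not deepSTRF's poisson_loss(), whose exact form the source does not show.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z)); never negative."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def poisson_nll(rates, counts, eps: float = 1e-8) -> float:
    """Mean Poisson negative log-likelihood, dropping the log(count!) constant."""
    terms = [r - k * math.log(r + eps) for r, k in zip(rates, counts)]
    return sum(terms) / len(terms)


# Hypothetical pre-activations for three time bins and the observed spike counts.
rates = [softplus(z) for z in (-1.0, 0.0, 2.0)]
counts = [0, 1, 2]
print(poisson_nll(rates, counts))
```

The log term requires a non-negative rate, which is what the softplus output guarantees.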

class deepSTRF.models.audio.Linear(n_frequency_bands: int = 34, temporal_window_size: int = 9, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None, wav2spec: Module = None)[source]

Bases: AudioEncodingModel

The canonical Linear (L) STRF model — a single SpectroTemporal Receptive Field convolved with the (optionally prefiltered) input spectrogram.

All learnable parameters live in the readout (STRFReadout), which holds the kernel of shape (N, C_in, F, T), applies it causally via left-padding, and follows it with a per-neuron nn.BatchNorm1d(N) before the output activation. The model’s core is nn.Identity — every learnable scalar has the neuron axis as leading dim, so the model is strictly no-shared-params (each neuron’s parameters are independent of every other neuron’s).
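
Example (illustrative): a plain-Python sketch of applying a single (F, T) STRF kernel causally via left-padding, for one neuron and one input channel. It mirrors the idea described above but is not the STRFReadout implementation; the function name and the toy values are assumptions.

```python
def causal_strf_response(spec, kernel):
    """Apply one (F, T) kernel to an (F, L) spectrogram; returns L outputs.

    spec and kernel are lists of F rows. The spectrogram is left-padded with
    T - 1 zero frames, so output[t] only sees frames t - T + 1 .. t.
    """
    n_freq, n_lags = len(kernel), len(kernel[0])
    n_frames = len(spec[0])
    padded = [[0.0] * (n_lags - 1) + list(row) for row in spec]
    return [
        sum(kernel[f][k] * padded[f][t + k] for f in range(n_freq) for k in range(n_lags))
        for t in range(n_frames)
    ]


# Toy input: F = 2 bands, L = 4 frames, T = 2 lags.
spec = [[1.0, 2.0, 3.0, 4.0],
        [0.0, 1.0, 0.0, 1.0]]
kernel = [[0.0, 1.0],   # band 0: weight on the current frame
          [1.0, 0.0]]   # band 1: weight on the previous frame
print(causal_strf_response(spec, kernel))  # [1.0, 2.0, 4.0, 4.0]
```

The last kernel column multiplies the current frame, matching the cross-correlation convention of nn.Conv2d, and the output has the same length as the input.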

Parameters:
  • n_frequency_bands (int, default 34) – Number of input spectrogram frequency bands F.

  • temporal_window_size (int, default 9) – STRF temporal extent T in frames.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, optional) – Pointwise nonlinearity applied at the readout output. Default nn.Identity (true linear model).

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any nn.Module exposing out_channels). None (default) gives nn.Identity and C_in = 1.

  • kernel (nn.Module, optional) – Pluggable STRF kernel for the readout. None (default) gives a vanilla nn.Conv2d; pass ParametricSTRF(...) for DCLS, or a separable nn.Sequential for a rank-1 factorization. See deepSTRF.models.layers for the kernel module catalogue.

  • wav2spec (nn.Module, optional) – Optional raw-waveform front-end (deepSTRF.models.wav2spec.*). When provided, the model accepts raw audio (B, 1, T_audio) instead of a spectrogram. None (default) keeps the slot as nn.Identity() and the model expects (B, 1, F, T) spec input.

References

The L model is a folklore baseline; canonical formulations appear in:

Theunissen, Sen & Doupe (2000). “Spectral-Temporal Receptive Fields of Nonlinear Auditory Neurons Obtained Using Natural Sounds.” J. Neurosci. 20(6): 2315–2331. https://doi.org/10.1523/JNEUROSCI.20-06-02315.2000

Sahani & Linden (2003). “How Linear are Auditory Cortical Responses?” NIPS.

Notes

The per-neuron BatchNorm inside the readout absorbs into the kernel at inference (its running stats are frozen per-channel scalars), so the model remains a strict linear-affine map of the input at eval time and the learned STRF kernel is directly interpretable up to a per-neuron affine rescaling.
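
Example (illustrative): the folding argument in the note above, written out for a single neuron. In eval mode, batch norm is an affine map with frozen statistics, so it reduces to one scale and one shift. The function names are hypothetical; eps = 1e-5 is the nn.BatchNorm1d default.

```python
import math


def batchnorm_eval(z, gamma, beta, running_mean, running_var, eps=1e-5):
    """Eval-mode batch norm for one channel (one neuron)."""
    return gamma * (z - running_mean) / math.sqrt(running_var + eps) + beta


def fold_batchnorm(gamma, beta, running_mean, running_var, eps=1e-5):
    """Return (scale, shift) such that batchnorm_eval(z) == scale * z + shift."""
    scale = gamma / math.sqrt(running_var + eps)
    shift = beta - running_mean * scale
    return scale, shift


# Hypothetical frozen statistics for one neuron.
params = dict(gamma=1.5, beta=-0.2, running_mean=0.3, running_var=4.0)
scale, shift = fold_batchnorm(**params)
z = 0.7  # raw STRF output at one timestep
print(abs(batchnorm_eval(z, **params) - (scale * z + shift)) < 1e-12)  # True
```

Multiplying the neuron's kernel by scale and adding shift to its bias gives the same eval-time output without the normalization layer, which is why the kernel remains interpretable up to a per-neuron affine rescaling.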

STRF_weight(polarity: str = 'ON')[source]

Return the readout’s STRF kernel as a (N, F, T) tensor.

For models prefiltered with AdapTrans (C_in == 2), polarity selects the ON or OFF channel of the kernel. For single-channel inputs the parameter is ignored.

class deepSTRF.models.audio.LinearNonlinear(n_frequency_bands: int = 34, temporal_window_size: int = 9, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)[source]

Bases: Linear

Linear-Nonlinear (LN) STRF model — the Linear model followed by a pointwise output nonlinearity.

Inherits everything from Linear and only changes the default output activation from nn.Identity to a per-neuron ParametricSoftplus. Pass any nn.Module to output_activation to override.

Parameters:

output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity applied at the readout output. The default is unbounded above and non-negative — natural for spike-count regression on smoothed PSTHs that exceed 1. See deepSTRF.models.activations for other parametric variants (ParametricSigmoid, ParametricDoubleExponential).

See also

Linear

Same architecture without the output nonlinearity.

class deepSTRF.models.audio.NetworkReceptiveField(n_frequency_bands: int = 34, temporal_window_size: int = 9, n_hidden: int = 20, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)[source]

Bases: AudioEncodingModel

Network Receptive Field (NRF) model — a two-layer feedforward STRF network.

Architecture: a STRF kernel projects the input spectrogram into a hidden layer of H units; a 1×1 conv reads out the N output neurons from the hidden activations. With L1 regularization, the paper typically finds 1–7 effective hidden units per neuron.

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • temporal_window_size (int, default 9) – STRF temporal extent T.

  • n_hidden (int, default 20) – Hidden layer width H.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout output. Default nn.Sigmoid.

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.

  • kernel (nn.Module, optional) – Pluggable hidden-layer STRF kernel. None (default) gives a vanilla nn.Conv2d; pass ParametricSTRF(...) for DCLS.

References

Harper, Schoppe, Willmore, Cui, Schnupp & King (2016). “Network Receptive Field Modeling Reveals Extensive Integration and Multi-feature Selectivity in Auditory Cortical Neurons.” PLOS Comp. Biol. 12(11): e1005113. https://doi.org/10.1371/journal.pcbi.1005113

Notes

Differences from the original paper:

  • We add a causal LayerNorm over input frequencies and over the hidden channel axis. The original assumes preprocessing-time input normalization and uses no internal norm.

  • The hidden activation is the standard unscaled nn.Tanh. The paper uses a scaled tanh (ρ₁ ≈ 1.7159, ρ₂ = 2/3); the two are equivalent up to a learned rescaling absorbed into the readout.

  • Causal left-padding extends the model to arbitrary input lengths; the paper uses fixed-window slicing.

  • The hidden STRF kernel can be parameterized (DCLS); the paper uses a vanilla full kernel.

STRFs(hidden_idx: int = 0, polarity: str = 'ON')[source]

Return the hidden-layer STRF kernel for one hidden unit as (F, T).

Parameters:
  • hidden_idx (int, default 0) – Which of the H hidden units to return the STRF for.

  • polarity ({'ON', 'OFF'}, default 'ON') – Only relevant when the prefilter has C_in == 2 (e.g. AdapTrans). Selects the ON or OFF channel of the kernel.

class deepSTRF.models.audio.StateNet(n_frequency_bands: int = 34, temporal_window_size: int = 1, kernel_size: int = 7, stride: int = 3, hidden_channels: int = 7, connectivity: str = 'LC', rnn_type: str = 'GRU', out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, wav2spec: Module = None)[source]

Bases: AudioEncodingModel

Fully stateful STRF model — relies entirely on temporal recurrence to extract information from stimulus sequences, with no explicit STRF delay window.

Architecture: a stateless per-timestep spectral encoder maps each spectrogram column (C_in, F) to a hidden representation (C, F_down). The flattened hidden representation is fed timestep-by-timestep to a recurrent (or state-space) model that accumulates context implicitly through its hidden state. A linear readout projects the recurrent hidden state to N output neurons.

Causality is inherent to the recurrent backbone (RNN/GRU/LSTM/LMU/Mamba/S4). The spectral encoder operates on a single timestep at a time so it does not couple frames temporally.

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • temporal_window_size (int, default 1) – Unused by StateNet (kept for AudioEncodingModel API compatibility); recurrence handles temporal context.

  • kernel_size (int, default 7) – Frequency kernel size for the spectral encoder.

  • stride (int, default 3) – Frequency stride for the spectral encoder.

  • hidden_channels (int, default 7) – Channel count of the spectral encoder C.

  • connectivity ({'LC', 'FC', 'CONV'}, default 'LC') – Spectral encoder connectivity. 'LC': locally-connected 1D layer (frequency-position-specific weights). 'FC': dense linear projection with reshape to (C, F_down). 'CONV': weight-shared 1D convolution.

  • rnn_type ({'GRU', 'LSTM', 'RNN', 'vanilla', 'LMU', 'Mamba', 'S4'}, default 'GRU') – Recurrent / state-space backbone.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity at the output. The default is unbounded above and non-negative — natural for spike-count regression.

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.
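
Example (illustrative): a rough size check for the spectral encoder, assuming it behaves like an unpadded strided 1D layer along frequency. That assumption (no padding, standard output-size formula) is not confirmed by the source, so treat the numbers as a sketch of the bookkeeping rather than StateNet's actual sizes.

```python
def encoder_output_size(n_freq=34, kernel_size=7, stride=3, hidden_channels=7):
    """Return (F_down, flattened feature count) for an unpadded strided layer."""
    if kernel_size > n_freq:
        raise ValueError("kernel_size cannot exceed the number of frequency bands")
    # Standard output length of an unpadded strided convolution.
    f_down = (n_freq - kernel_size) // stride + 1
    return f_down, hidden_channels * f_down


f_down, n_features = encoder_output_size()
print(f_down, n_features)  # 10 70
```

Under these assumptions, the documented defaults would hand a 70-dimensional vector per timestep to the recurrent backbone.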

References

Rançon, Masquelier & Cottereau (2025). “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.” Communications Biology 8:1456. https://doi.org/10.1038/s42003-025-08858-3

Notes

  • The spectral encoder uses a CausalLayerNorm over the channel axis (C); the original implementation used BatchNorm1d, which pools statistics over the (T*B, F_down) axes and is therefore non-causal.

  • The S4 backbone is imported lazily — its module emits CUDA-extension warnings on import that other backends would not see.

forward(x)[source]

Per-timestep spectral encoder feeding a recurrent backbone.

Overrides the base template because the encoder runs in a flattened (T*B, C_in, F) batch (so every timestep is independent in the spectral pass) and the RNN expects a batch-first (B, T, H) layout — these reshapes don’t decompose into the canonical core / readout slots.

class deepSTRF.models.audio.Transformer(n_frequency_bands: int = 34, freq_patch_size: int = None, time_patch_size: int = 1, context_window: int = None, embedding_dim: int = 48, n_heads: int = 1, n_layers: int = 1, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, wav2spec: Module = None)[source]

Bases: AudioEncodingModel

Attention-based STRF model — a Transformer encoder runs causal self-attention over a per-timestep token sequence extracted from the spectrogram.

Architecture:

input              (B, 1, F, L)
  ↓ prefilter      (B, C_in, F, L)
  ↓ time pad       (B, C_in, F, L + K_T - 1)         left-only causal
  ↓ patchify       (B, embedding_dim, F_p, L)        Conv2d, stride=(K_F, 1)
  ↓ flatten/permute (B, L, embedding_dim * F_p)      one token per timestep
  ↓ + sinusoidal positional encoding
  ↓ TransformerEncoder with causal (+optional window) mask
  ↓ readout        (B, N, 1, L)

The patchifier is a strided nn.Conv2d with kernel (K_F, K_T) and stride (K_F, 1). Frequency patches are non-overlapping (F_p = F // K_F); time stride is 1 so there is one token per timestep, and K_T > 1 lets each token aggregate time_patch_size recent frames. A (K_T - 1)-zero left pad along time keeps the patchifier strictly causal.

The attention mask is constructed per forward pass at the actual sequence length, so the model generalizes to any input length L. Setting context_window to a positive int restricts attention to the most recent context_window past frames (band-causal), recovering a Sahani-style fixed STRF context window when wanted — the model can be evaluated with or without the bound at inference without retraining.
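
Example (illustrative): a plain-Python sketch of the causal and band-causal masks described above, built at whatever sequence length is passed in. Whether the window counts the current frame is an assumption here (it does in this sketch); the function name is hypothetical.

```python
def causal_band_allowed(seq_len, context_window=None):
    """allowed[i][j] is True when query frame i may attend to key frame j."""
    return [
        [
            j <= i and (context_window is None or i - j < context_window)
            for j in range(seq_len)
        ]
        for i in range(seq_len)
    ]


# Unlimited causal context versus a window of 2 frames, at length 4.
for row in causal_band_allowed(4, context_window=2):
    print(["x" if ok else "." for ok in row])
```

Note that PyTorch's nn.TransformerEncoder expects the complement when given a boolean mask: True marks positions that may not be attended to.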

Parameters:
  • n_frequency_bands (int, default 34) – Number of input frequency bands F.

  • freq_patch_size (int, optional) – Frequency-axis patch size for the patchifier. None (default) uses F itself — one token spans the full frequency axis at each timestep. Must divide F.

  • time_patch_size (int, default 1) – Temporal extent of each patch in frames. 1 gives one token = one timestep slice; larger values let each token aggregate across time_patch_size recent frames via the patchifier.

  • context_window (int, optional) – If set, restrict attention to the most recent context_window past frames (still causal — band-causal mask). None (default) gives unlimited causal context.

  • embedding_dim (int, default 48) – Per-patch embedding dimension after the patchifier.

  • n_heads (int, default 1) – Number of attention heads. Must divide embedding_dim * F_p.

  • n_layers (int, default 1) – Number of TransformerEncoderLayer blocks.

  • out_neurons (int, default 1) – Number of output neurons N.

  • output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout. Default nn.Identity.

  • prefiltering (nn.Module, optional) – Optional spectrogram prefilter.
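
Example (illustrative): the shape bookkeeping implied by the parameters above, namely F_p = F // K_F, a token width of embedding_dim * F_p, and the divisibility constraints on freq_patch_size and n_heads. The function name is hypothetical and the checks are a sketch, not deepSTRF's own validation code.

```python
def token_layout(n_freq=34, freq_patch_size=None, embedding_dim=48, n_heads=1):
    """Return (F_p, token width) for the patchifier, checking the constraints."""
    k_f = n_freq if freq_patch_size is None else freq_patch_size
    if n_freq % k_f != 0:
        raise ValueError("freq_patch_size must divide n_frequency_bands")
    f_p = n_freq // k_f                 # non-overlapping frequency patches
    token_dim = embedding_dim * f_p     # one flattened token per timestep
    if token_dim % n_heads != 0:
        raise ValueError("n_heads must divide embedding_dim * F_p")
    return f_p, token_dim


print(token_layout())                                # (1, 48)
print(token_layout(freq_patch_size=17, n_heads=4))   # (2, 96)
```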

References

Rançon, Masquelier & Cottereau (2025). “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.” Communications Biology 8:1456. https://doi.org/10.1038/s42003-025-08858-3

Vaswani et al. (2017). “Attention Is All You Need.” NeurIPS.

Notes

Sinusoidal positional encoding (Vaswani 2017) is used by default; it generalizes to arbitrary sequence lengths at inference. RoPE (Rotary Position Embedding, Su et al. 2021) is a planned alternative; it is omitted here because it requires a custom TransformerEncoderLayer (PyTorch’s stock module hides Q and K).
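
Example (illustrative): the sinusoidal encoding of Vaswani et al. (2017) in plain Python. Because each row depends only on its position, the table can be generated for any sequence length at inference. This is a sketch of the published formula, not deepSTRF's implementation; the function name is hypothetical.

```python
import math


def sinusoidal_encoding(seq_len, d_model):
    """Return a seq_len x d_model table: sin on even indices, cos on odd ones."""
    pe = [[0.0] * d_model for _ in range(seq_len)]
    for pos in range(seq_len):
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


table = sinusoidal_encoding(seq_len=3, d_model=4)
print(table[0])  # [0.0, 1.0, 0.0, 1.0]
```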

forward(x)[source]

Causal-attention forward.

Overrides the base template because the attention mask must be rebuilt at the actual sequence length L of each input.