deepSTRF.models.audio package
Submodules
deepSTRF.models.audio.audio_model module
- class deepSTRF.models.audio.audio_model.AudioEncodingModel(n_frequency_bands: int, temporal_window_size: int, out_neurons: int = 1, prefiltering: Module = None, wav2spec: Module = None, *args, **kwargs)
Bases: NeuralModel
Base class for encoding models of audio neural responses.
Forward signature: input (B, 1, F, T) spectrogram → output (B, N, R=1, T) neural activity. Concrete subclasses populate the four canonical slots wav2spec / prefiltering / core / readout (see NeuralModel).
- Parameters:
  - n_frequency_bands (int) – Number of input spectrogram frequency bands F.
  - temporal_window_size (int) – STRF temporal extent T in frames. Used by STRF_gradmap to size the null stimulus.
  - out_neurons (int, default 1) – Number of output neurons N.
  - prefiltering (nn.Module, optional) – Optional spectrogram prefilter (e.g. AdapTrans, ICAdaptation). Must expose an integer out_channels attribute so the model can size C_in automatically. None (default) gives nn.Identity() and C_in = 1.
  - wav2spec (nn.Module, optional) – Optional raw-waveform front-end. Maps a mono waveform (B, 1, T_audio) to a spectrogram (B, 1, F, T_neural) and slots in at the top of the canonical forward pipeline (see NeuralModel). Must expose an out_channels: int attribute equal to n_frequency_bands. Pair it with a dataset in waveform mode (e.g. NS1Dataset(return_waveform=True)). None (default) keeps the slot as nn.Identity(), in which case the model expects spectrogram input (B, 1, F, T) as before.
- STRF_gradmap(T: int = None)
Compute one STRF gradient map per output neuron in parallel.
For each of the N output neurons, computes the gradient of that neuron's activity at the last timestep with respect to a null spectrogram, i.e. the input change that most increases that activity (a spike-triggered-average-like readout, computed by autodiff). The batch dimension is used to parallelize across neurons in a single forward/backward pass.
- Parameters:
  - T (int, optional) – Time-axis length of the null stimulus. Defaults to self.T.
- Returns:
  Per-neuron gradient map.
- Return type:
  Tensor of shape (N, 1, F, T)
References
Rançon et al. (2025), “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.”
Notes
Future work:
  - Handle multi-channel inputs. The gradient is currently shaped (N, 1, F, T) regardless of C_in; an AdapTrans-prefiltered model has C_in == 2, and the per-channel gradients differ.
  - Allow custom losses (e.g. sustained activity rather than last-timestep-only).
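To see why the last-timestep gradient around a null stimulus is "STRF-like", here is a minimal sketch on a toy linear STRF. It makes two assumptions. First, it uses central finite differences in place of the autodiff that STRF_gradmap uses. Second, it convolves the STRF causally, with T-1 frames of zero left-padding and cross-correlation indexing as in nn.Conv2d. Both function names are hypothetical. For a null stimulus of length T, the gradient of y[T-1] with respect to spec[f][k] is exactly kernel[f][k].

```python
def causal_strf(spec, kernel):
    """Causal STRF response y[t] = sum_{f,k} kernel[f][k] * spec[f][t - (T-1) + k].

    spec is an F x L nested list, kernel is F x T. Zero left-padding of T-1 frames
    keeps the output length at L, and kernel[f][T-1] multiplies the current frame.
    """
    n_t = len(kernel[0])
    padded = [[0.0] * (n_t - 1) + list(row) for row in spec]
    return [sum(kernel[f][k] * padded[f][t + k]
                for f in range(len(spec)) for k in range(n_t))
            for t in range(len(spec[0]))]


def gradmap_last_step(model, n_freq, n_time, eps=1e-3):
    """Central finite-difference gradient of model(spec)[-1] around an all-zero spec."""
    grad = [[0.0] * n_time for _ in range(n_freq)]
    for f in range(n_freq):
        for t in range(n_time):
            plus = [[0.0] * n_time for _ in range(n_freq)]
            minus = [[0.0] * n_time for _ in range(n_freq)]
            plus[f][t], minus[f][t] = eps, -eps
            grad[f][t] = (model(plus)[-1] - model(minus)[-1]) / (2 * eps)
    return grad


kernel = [[0.1, -0.2, 0.7], [0.0, 0.4, -0.3]]  # F=2, T=3
g = gradmap_last_step(lambda s: causal_strf(s, kernel), n_freq=2, n_time=3)
print([[round(v, 6) for v in row] for row in g])  # recovers the kernel (up to rounding)
```

For a nonlinear model the gradient depends on the point it is evaluated around. That is why the choice of a null stimulus matters beyond this linear case.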
- validate()
Check that the instance is deepSTRF-compatible.
Subclasses should call super().validate() and then add their own checks (e.g. AudioEncodingModel checks F, T > 0).
- Raises:
  AssertionError – If self.O is not a positive int, if readout is unset or not a torch.nn.Module, or if any of the wav2spec / prefiltering / core slots is not a torch.nn.Module.
- waveform_gradmap(stimulus, neuron=None, reduce='last')
Gradient of a neuron's response with respect to the input waveform.
The waveform-domain analogue of STRF_gradmap(). Instead of probing the spectrogram input, it backpropagates a neuron's response all the way through the (learnable) wav2spec front-end to the raw audio samples. The returned gradient is itself a waveform, i.e. a listenable, time-domain receptive field. It is only defined for wav-native models (a non-Identity wav2spec).
- Parameters:
  - stimulus (array-like, shape (T_audio,), (1, T_audio) or (1, 1, T_audio)) – Audio around which to compute the gradmap. A real stimulus is recommended over silence, because adaptive front-ends (PCEN) are ill-conditioned at zero energy.
  - neuron (int, optional) – Which output neuron. None (default) sums over all neurons (a population gradmap).
  - reduce ({'last', 'peak', 'sum'}, default 'last') – How to reduce the neuron's response over time before backprop. 'last' (default, matching STRF_gradmap()) maximizes the activation at the last timestep, so the gradient is supported only within the receptive field preceding it: it reveals the RF and decays to ~zero further into the past. 'peak' does the same at the peak-response timestep. 'sum' integrates the response over all output timesteps, which by construction makes the gradient non-zero almost everywhere (a whole-stimulus saliency map, not a receptive field).
- Returns:
  Per-audio-sample gradient, i.e. the waveform-domain receptive field. Computed in eval mode (the strictly causal inference regime). With reduce='last', the support is the RF length (e.g. ~45 ms for a T=9 STRF on a mel front-end; longer for adaptive front-ends like LEAF, whose PCEN smoother adds a decaying temporal memory).
- Return type:
  torch.Tensor, shape (T_audio,)
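The contrast between the three reduce modes can be illustrated without a real front-end. This minimal sketch uses an assumption: a causal FIR filter on raw samples stands in for a wav-native model. Because the filter is linear, its input gradient can be written analytically. The names fir_response and input_gradient are hypothetical. With 'last' and 'peak' the gradient is confined to the filter length before the chosen timestep. With 'sum' it is non-zero at every sample.

```python
def fir_response(h, x):
    """Causal FIR filter y[t] = sum_k h[k] * x[t - k], a stand-in for a linear wav-native model."""
    return [sum(h[k] * x[t - k] for k in range(len(h)) if t - k >= 0)
            for t in range(len(x))]


def input_gradient(h, n_samples, reduce="last", peak_t=None):
    """Analytic d(objective)/dx[s] for fir_response under the three reductions.

    'last': objective = y[-1]; 'peak': objective = y[peak_t]; 'sum': objective = sum_t y[t].
    For this linear model dy[t]/dx[s] = h[t - s] if 0 <= t - s < len(h), else 0.
    """
    if reduce == "last":
        targets = [n_samples - 1]
    elif reduce == "peak":
        targets = [peak_t]
    elif reduce == "sum":
        targets = range(n_samples)
    else:
        raise ValueError(f"unknown reduce={reduce!r}")
    return [sum(h[t - s] for t in targets if 0 <= t - s < len(h))
            for s in range(n_samples)]


h = [0.5, 0.3, 0.2]                    # 3-sample receptive field
print(input_gradient(h, 8, "last"))    # [0, 0, 0, 0, 0, 0.2, 0.3, 0.5]
print(input_gradient(h, 8, "sum"))     # non-zero at every sample
```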
deepSTRF.models.audio.audio_zoo module
- class deepSTRF.models.audio.audio_zoo.ConvNet2D(n_frequency_bands: int = 34, kernel_size: tuple = (3, 9), c_hidden: int = 10, n_hidden: int = 20, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None)
Bases: AudioEncodingModel
Convolutional STRF model with three sequential 2D convs and a 2-layer fully connected readout, adapted from the ‘2D-CNN’ of Pennington & David (2023).
Architecture: three Conv2d → CausalLayerNorm → LeakyReLU blocks extract a stack of feature maps. The per-timestep features are flattened over the (channel × downsampled-frequency) axes, and a 2-layer FC reads out N output neurons.
- Parameters:
  - n_frequency_bands (int, default 34) – Number of input frequency bands F.
  - kernel_size (tuple of int, default (3, 9)) – Conv2d kernel (K_F, K_T) shared across the three conv blocks.
  - c_hidden (int, default 10) – Number of channels in each conv block.
  - n_hidden (int, default 20) – Width of the FC hidden layer.
  - out_neurons (int, default 1) – Number of output neurons N.
  - output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity at the output. The default is non-negative and unbounded above, which suits spike-count regression.
  - prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.
References
Pennington & David (2023). “A convolutional neural network provides a generalizable model of natural sound coding by neural populations in auditory cortex.” PLOS Comp. Biol. 19(5): e1011110. https://doi.org/10.1371/journal.pcbi.1011110
Notes
Differences from the original paper:
  - Causal LayerNorm is added as internal normalization (the paper uses none).
  - The hidden activation is LeakyReLU(0.1) rather than ReLU. This is an empirical preference and a very small architectural difference.
  - 2D convs over (F, T) rather than 1D convs over T (the paper applies 1D convolutions with implicit spectral pooling).
  - Frequency downsampling is implicit via valid-padding shrinkage: each of the three convs shrinks F by K_F - 1, giving F_down = F - 3*(K_F - 1).
  - Causal left-padding extends the model to arbitrary input lengths; the paper also uses explicit causal padding.
  - The output activation is configurable; the paper uses a 4-parameter double exponential (see deepSTRF.models.activations.ParametricDoubleExponential).
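The frequency bookkeeping above determines the input width of the FC readout. Here is a small arithmetic sketch. It assumes, as the note states, valid padding along frequency and length-preserving causal padding along time, so only F shrinks. The helper convnet2d_feature_size is hypothetical and not part of deepSTRF.

```python
def convnet2d_feature_size(n_frequency_bands: int = 34,
                           kernel_size: tuple = (3, 9),
                           c_hidden: int = 10,
                           n_convs: int = 3) -> tuple:
    """Return (F_down, flattened per-timestep feature size) for a ConvNet2D-style stack.

    Each valid conv along frequency removes K_F - 1 bands; time is causally padded,
    so it does not shrink. Features are flattened over (channel x F_down).
    """
    k_f, _k_t = kernel_size
    f_down = n_frequency_bands - n_convs * (k_f - 1)
    if f_down < 1:
        raise ValueError(
            f"kernel too tall: F_down = {f_down} for F={n_frequency_bands}, K_F={k_f}")
    return f_down, c_hidden * f_down


print(convnet2d_feature_size())  # (28, 280): defaults F=34, K_F=3, c_hidden=10
```

With the defaults, the 2-layer FC readout would therefore map 280 features per timestep to n_hidden = 20 units and then to N neurons.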
- class deepSTRF.models.audio.audio_zoo.DNet(n_frequency_bands: int = 34, temporal_window_size: int = 9, n_hidden: int = 20, init_tau: float = 2.0, decay_input: bool = True, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)
Bases: AudioEncodingModel
Dynamic Network (DNet): an NRF whose hidden and output units are stateful, with learnable temporal decay.
Architecture: STRF projection → channel norm → sigmoid → learnable exponential decay (one time constant per hidden unit) → 1×1 readout → output activation. The exponential decay is causal: each unit's output at time t is the convolution of its instantaneous input with a learned one-sided exponential kernel.
- Parameters:
  - n_frequency_bands (int, default 34) – Number of input frequency bands F.
  - temporal_window_size (int, default 9) – STRF temporal extent T.
  - n_hidden (int, default 20) – Hidden layer width H.
  - init_tau (float, default 2.0) – Initial time constant (in frames) of the hidden-unit exponential decay.
  - decay_input (bool, default True) – If True, the exponential decay also weights its instantaneous input by 1/(1+d²) (paper convention); if False, the instantaneous input passes through unscaled.
  - out_neurons (int, default 1) – Number of output neurons N.
  - output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout output. Default nn.Identity (paper-faithful linear readout).
  - prefiltering (nn.Module, optional) – Optional spectrogram prefilter.
  - kernel (nn.Module, optional) – Pluggable hidden-layer STRF kernel.
References
Rahman, Willmore, King & Harper (2019). “A dynamic network model of temporal receptive fields in primary auditory cortex.” PLOS Comp. Biol. 15(5): e1006618. https://doi.org/10.1371/journal.pcbi.1006618
Notes
Differences from the original paper:
  - Causal LayerNorm is added as internal normalization (the paper assumes input normalization at preprocessing time).
  - Causal left-padding extends the model to arbitrary input lengths; the paper uses fixed-window slicing.
  - The hidden STRF kernel can be parameterized (DCLS); the paper uses a vanilla full kernel.
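The causal exponential decay described for DNet can be written either as a leaky recursion or as a convolution with a one-sided exponential kernel, and the two forms are equivalent. This is a minimal sketch of a generic first-order filter with a = exp(-1/tau) and tau in frames, like init_tau. It is an illustration under that assumption, not deepSTRF's exact DNet parameterization. In particular, it omits the 1/(1+d²) input weighting applied when decay_input=True. The function names are hypothetical.

```python
import math


def exp_decay_recursive(x, tau):
    """Causal leaky integration y[t] = a * y[t-1] + x[t], with a = exp(-1/tau)."""
    a = math.exp(-1.0 / tau)
    y, state = [], 0.0
    for x_t in x:
        state = a * state + x_t
        y.append(state)
    return y


def exp_decay_convolution(x, tau):
    """The same filter as a causal convolution with the one-sided kernel a**k, k >= 0."""
    a = math.exp(-1.0 / tau)
    return [sum(a ** k * x[t - k] for k in range(t + 1)) for t in range(len(x))]


impulse = [1.0] + [0.0] * 5
print([round(v, 3) for v in exp_decay_recursive(impulse, tau=2.0)])
# [1.0, 0.607, 0.368, 0.223, 0.135, 0.082]
```

The recursive form is what makes the unit stateful. It needs O(1) memory per step and still has an unbounded, exponentially fading memory of the past.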
- STRFs(hidden_idx: int = 0, polarity: str = 'ON')
Return the hidden-layer STRF kernel of one hidden unit, with shape (F, T).
- Parameters:
  - hidden_idx (int, default 0) – Which of the H hidden units to inspect.
  - polarity ({'ON', 'OFF'}, default 'ON') – Only relevant for AdapTrans-prefiltered models (C_in == 2).
- class deepSTRF.models.audio.audio_zoo.Linear(n_frequency_bands: int = 34, temporal_window_size: int = 9, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None, wav2spec: Module = None)
Bases: AudioEncodingModel
The canonical Linear (L) STRF model: a single SpectroTemporal Receptive Field convolved with the (optionally prefiltered) input spectrogram.
All learnable parameters live in the readout (STRFReadout). It holds the kernel of shape (N, C_in, F, T), applies it causally via left-padding, and follows it with a per-neuron nn.BatchNorm1d(N) before the output activation. The model's core is nn.Identity. Every learnable scalar has the neuron axis as its leading dimension, so the model has strictly no shared parameters: each neuron's parameters are independent of every other neuron's.
- Parameters:
  - n_frequency_bands (int, default 34) – Number of input spectrogram frequency bands F.
  - temporal_window_size (int, default 9) – STRF temporal extent T in frames.
  - out_neurons (int, default 1) – Number of output neurons N.
  - output_activation (nn.Module, optional) – Pointwise nonlinearity applied at the readout output. Default nn.Identity (a truly linear model).
  - prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any nn.Module exposing out_channels). None (default) gives nn.Identity and C_in = 1.
  - kernel (nn.Module, optional) – Pluggable STRF kernel for the readout. None (default) gives a vanilla nn.Conv2d; pass ParametricSTRF(...) for DCLS, or a separable nn.Sequential for a rank-1 factorization. See deepSTRF.models.layers for the kernel module catalogue.
  - wav2spec (nn.Module, optional) – Optional raw-waveform front-end (deepSTRF.models.wav2spec.*). When provided, the model accepts raw audio (B, 1, T_audio) instead of a spectrogram. None (default) keeps the slot as nn.Identity(), and the model expects (B, 1, F, T) spectrogram input.
References
The L model is a folklore baseline; canonical formulations appear in:
Theunissen, Sen & Doupe (2000). “Spectral-Temporal Receptive Fields of Nonlinear Auditory Neurons Obtained Using Natural Sounds.” J. Neurosci. 20(6): 2315–2331. https://doi.org/10.1523/JNEUROSCI.20-06-02315.2000
Sahani & Linden (2003). “How Linear are Auditory Cortical Responses?” NIPS.
Notes
The per-neuron BatchNorm inside the readout can be folded into the kernel and bias at inference, because in eval mode its running statistics and affine parameters are fixed per-neuron scalars. With the default identity output activation, the model therefore remains a strictly affine map of the input at eval time, and the learned STRF kernel is directly interpretable up to a per-neuron affine rescaling.
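The folding argument can be checked with a few lines of arithmetic. The sketch applies the standard eval-mode BatchNorm formula, gamma * (z - mean) / sqrt(var + eps) + beta, with eps = 1e-5 as in nn.BatchNorm1d's default. It works on one neuron's flattened STRF weights. The helper names are hypothetical. Folding multiplies the kernel by s = gamma / sqrt(var + eps) and shifts the bias. This is the "per-neuron affine rescaling" mentioned in the note.

```python
import math


def fold_batchnorm(weights, bias, running_mean, running_var, gamma, beta, eps=1e-5):
    """Fold an eval-mode BatchNorm into a preceding affine map z = w . x + b.

    BN(z) = gamma * (z - mean) / sqrt(var + eps) + beta
          = (s * w) . x + (s * (b - mean) + beta),  with s = gamma / sqrt(var + eps).
    """
    s = gamma / math.sqrt(running_var + eps)
    return [s * w for w in weights], s * (bias - running_mean) + beta


def affine(weights, bias, x):
    return sum(w * xi for w, xi in zip(weights, x)) + bias


w, b = [0.5, -1.0, 2.0], 0.1          # one neuron's (flattened) STRF and bias
mean, var, gamma, beta = 0.3, 4.0, 1.5, -0.2
x = [1.0, 2.0, 3.0]                   # one (flattened) spectrogram window

bn_out = gamma * (affine(w, b, x) - mean) / math.sqrt(var + 1e-5) + beta
w_f, b_f = fold_batchnorm(w, b, mean, var, gamma, beta)
assert math.isclose(affine(w_f, b_f, x), bn_out)
```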
- class deepSTRF.models.audio.audio_zoo.LinearNonlinear(n_frequency_bands: int = 34, temporal_window_size: int = 9, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)
Bases: Linear
Linear-Nonlinear (LN) STRF model: the Linear model followed by a pointwise output nonlinearity.
Inherits everything from Linear and only changes the default output activation from nn.Identity to a per-neuron ParametricSoftplus. Pass any nn.Module to output_activation to override it.
- Parameters:
  - output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity applied at the readout output. The default is non-negative and unbounded above, which suits spike-count regression on smoothed PSTHs that exceed 1. See deepSTRF.models.activations for other parametric variants (ParametricSigmoid, ParametricDoubleExponential).
See also
Linear: same architecture without the output nonlinearity.
- class deepSTRF.models.audio.audio_zoo.NetworkReceptiveField(n_frequency_bands: int = 34, temporal_window_size: int = 9, n_hidden: int = 20, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)
Bases: AudioEncodingModel
Network Receptive Field (NRF) model: a two-layer feedforward STRF network.
Architecture: a STRF kernel projects the input spectrogram into a hidden layer of H units, and a 1×1 conv reads out the N output neurons from the hidden activations. With L1 regularization, the paper typically finds 1–7 effective hidden units per neuron.
- Parameters:
  - n_frequency_bands (int, default 34) – Number of input frequency bands F.
  - temporal_window_size (int, default 9) – STRF temporal extent T.
  - n_hidden (int, default 20) – Hidden layer width H.
  - out_neurons (int, default 1) – Number of output neurons N.
  - output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout output. Default nn.Sigmoid.
  - prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.
  - kernel (nn.Module, optional) – Pluggable hidden-layer STRF kernel. None (default) gives a vanilla nn.Conv2d; pass ParametricSTRF(...) for DCLS.
References
Harper, Schoppe, Willmore, Cui, Schnupp & King (2016). “Network Receptive Field Modeling Reveals Extensive Integration and Multi-feature Selectivity in Auditory Cortical Neurons.” PLOS Comp. Biol. 12(11): e1005113. https://doi.org/10.1371/journal.pcbi.1005113
Notes
Differences from the original paper:
  - We add a causal LayerNorm over input frequencies and over the hidden channel axis. The original assumes input normalization at preprocessing time and uses no internal norm.
  - The hidden activation is nn.Tanh. The paper uses a scaled tanh, ρ₁·tanh(ρ₂·x) with ρ₁ ≈ 1.7159 and ρ₂ = 2/3; we use the unscaled standard tanh, which is equivalent up to learned rescalings of the pre-activation and of the readout weights.
  - Causal left-padding extends the model to arbitrary input lengths; the paper uses fixed-window slicing.
  - The hidden STRF kernel can be parameterized (DCLS); the paper uses a vanilla full kernel.
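The NRF forward pass can be traced for a single output neuron: H causal STRFs feed tanh hidden units, a 1×1 readout combines them, and the default sigmoid is applied at the output. This pure-Python sketch is a simplification. It leaves out the causal LayerNorms listed above and assumes zero left-padding of T-1 frames with cross-correlation indexing. The function nrf_response is hypothetical, not the deepSTRF implementation.

```python
import math


def nrf_response(spec, hidden_kernels, hidden_bias, readout_w, readout_b):
    """Toy single-neuron NRF: causal STRFs -> tanh hidden units -> sigmoid readout.

    spec: F x L nested list; hidden_kernels: H kernels, each F x T; readout_w: H weights.
    Zero left-padding of T-1 frames keeps the output length at L, and
    kernel[f][T-1] multiplies the current frame, so the model is causal.
    """
    n_freq, length = len(spec), len(spec[0])
    t_win = len(hidden_kernels[0][0])
    padded = [[0.0] * (t_win - 1) + list(row) for row in spec]
    out = []
    for t in range(length):
        hidden = [
            math.tanh(b + sum(kern[f][k] * padded[f][t + k]
                              for f in range(n_freq) for k in range(t_win)))
            for kern, b in zip(hidden_kernels, hidden_bias)
        ]
        z = readout_b + sum(v * h for v, h in zip(readout_w, hidden))  # 1x1 readout
        out.append(1.0 / (1.0 + math.exp(-z)))  # default output_activation: sigmoid
    return out


# Two hidden units (H=2) on a 2-band, 4-frame spectrogram with T=2.
spec = [[0.0, 1.0, 0.0, 0.5], [1.0, 0.0, 0.0, 0.0]]
kernels = [[[0.2, 0.8], [0.0, -0.5]], [[-0.3, 0.1], [0.6, 0.4]]]
print(nrf_response(spec, kernels, [0.0, 0.1], [1.2, -0.7], -0.2))
```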
- STRFs(hidden_idx: int = 0, polarity: str = 'ON')
Return the hidden-layer STRF kernel of one hidden unit, with shape (F, T).
- Parameters:
  - hidden_idx (int, default 0) – Which of the H hidden units to return the STRF for.
  - polarity ({'ON', 'OFF'}, default 'ON') – Only relevant when the prefilter yields C_in == 2 (e.g. AdapTrans). Selects the ON or OFF channel of the kernel.
- class deepSTRF.models.audio.audio_zoo.StateNet(n_frequency_bands: int = 34, temporal_window_size: int = 1, kernel_size: int = 7, stride: int = 3, hidden_channels: int = 7, connectivity: str = 'LC', rnn_type: str = 'GRU', out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, wav2spec: Module = None)
Bases: AudioEncodingModel
Fully stateful STRF model. It relies entirely on temporal recurrence to extract information from stimulus sequences, with no explicit STRF delay window.
Architecture: a stateless per-timestep spectral encoder maps each spectrogram column (C_in, F) to a hidden representation (C, F_down). The flattened hidden representation is fed timestep by timestep to a recurrent (or state-space) model that accumulates context implicitly through its hidden state. A linear readout projects the recurrent hidden state to N output neurons.
Causality is inherent to the recurrent backbone (RNN/GRU/LSTM/LMU/Mamba/S4). The spectral encoder operates on one timestep at a time, so it does not couple frames temporally.
- Parameters:
  - n_frequency_bands (int, default 34) – Number of input frequency bands F.
  - temporal_window_size (int, default 1) – Unused by StateNet (kept for AudioEncodingModel API compatibility); recurrence handles temporal context.
  - kernel_size (int, default 7) – Frequency kernel size of the spectral encoder.
  - stride (int, default 3) – Frequency stride of the spectral encoder.
  - hidden_channels (int, default 7) – Channel count C of the spectral encoder.
  - connectivity ({'LC', 'FC', 'CONV'}, default 'LC') – Spectral encoder connectivity. 'LC': locally connected 1D layer (frequency-position-specific weights). 'FC': dense linear projection reshaped to (C, F_down). 'CONV': weight-shared 1D convolution.
  - rnn_type ({'GRU', 'LSTM', 'RNN', 'vanilla', 'LMU', 'Mamba', 'S4'}, default 'GRU') – Recurrent / state-space backbone.
  - out_neurons (int, default 1) – Number of output neurons N.
  - output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity at the output. The default is non-negative and unbounded above, which suits spike-count regression.
  - prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.
References
Rançon, Masquelier & Cottereau (2025). “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.” Communications Biology 8:1456. https://doi.org/10.1038/s42003-025-08858-3
Notes
The spectral encoder uses a CausalLayerNorm over the channel axis (C). The original implementation used BatchNorm1d, which pools statistics over the (T*B, F_down) axes and is therefore non-causal.
The S4 backbone is imported lazily, because its module emits CUDA-extension warnings on import that other backends would not trigger.
- forward(x)
Per-timestep spectral encoder feeding a recurrent backbone.
Overrides the base template for two reasons. The encoder runs on a flattened (T*B, C_in, F) batch, so every timestep is independent in the spectral pass. The RNN expects a batch-first (B, T, H) layout. These reshapes do not decompose into the canonical core / readout slots.
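The shapes involved in this override can be sketched in pure Python under two stated assumptions. The first is that the 'LC'/'CONV' encoder uses no frequency padding, so F_down = (F - kernel_size) // stride + 1. The second is that the (T*B) flattening is time-major, with flat index t * B + b, as the name suggests. deepSTRF's actual padding and ordering may differ, and all helper names here are hypothetical.

```python
def statenet_shapes(n_frequency_bands=34, kernel_size=7, stride=3, hidden_channels=7):
    """(F_down, per-timestep RNN input width C * F_down), assuming an unpadded encoder."""
    f_down = (n_frequency_bands - kernel_size) // stride + 1
    return f_down, hidden_channels * f_down


def flatten_time_major(x_bt):
    """(B, T) nested list -> flat list indexed by t * B + b, i.e. a (T*B,) batch."""
    n_b, n_t = len(x_bt), len(x_bt[0])
    return [x_bt[b][t] for t in range(n_t) for b in range(n_b)]


def unflatten_batch_first(flat, n_b):
    """Inverse mapping: flat (T*B,) -> batch-first (B, T), the layout the RNN expects."""
    n_t = len(flat) // n_b
    return [[flat[t * n_b + b] for t in range(n_t)] for b in range(n_b)]


print(statenet_shapes())  # (10, 70) with the defaults, under the no-padding assumption
x = [["b0t0", "b0t1", "b0t2"], ["b1t0", "b1t1", "b1t2"]]
flat = flatten_time_major(x)
print(flat)  # ['b0t0', 'b1t0', 'b0t1', 'b1t1', 'b0t2', 'b1t2']
assert unflatten_batch_first(flat, n_b=2) == x
```

Running the encoder on the flat batch treats every (t, b) column independently. Only the final batch-first reshape reintroduces time for the recurrent backbone.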
- class deepSTRF.models.audio.audio_zoo.Transformer(n_frequency_bands: int = 34, freq_patch_size: int = None, time_patch_size: int = 1, context_window: int = None, embedding_dim: int = 48, n_heads: int = 1, n_layers: int = 1, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, wav2spec: Module = None)
Bases: AudioEncodingModel
Attention-based STRF model: a Transformer encoder runs causal self-attention over a per-timestep token sequence extracted from the spectrogram.
Architecture:
    input               (B, 1, F, L)
      ↓ prefilter       (B, C_in, F, L)
      ↓ time pad        (B, C_in, F, L + K_T - 1)     left-only, causal
      ↓ patchify        (B, embedding_dim, F_p, L)    Conv2d, stride=(K_F, 1)
      ↓ flatten/permute (B, L, embedding_dim * F_p)   one token per timestep
      ↓ + sinusoidal positional encoding
      ↓ TransformerEncoder with causal (+ optional window) mask
      ↓ readout         (B, N, 1, L)
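The shape rules in this pipeline reduce to a few integer checks. K_F defaults to F and must divide it, F_p = F // K_F, the token width is d_model = embedding_dim * F_p, and n_heads must divide d_model. The time axis is preserved because (L + K_T - 1) - K_T + 1 = L. This is a minimal sketch of those checks, not deepSTRF code. The helper transformer_token_shape is hypothetical.

```python
def transformer_token_shape(n_frequency_bands=34, freq_patch_size=None,
                            embedding_dim=48, n_heads=1):
    """Return (F_p, d_model) for one-token-per-timestep patchification."""
    k_f = n_frequency_bands if freq_patch_size is None else freq_patch_size
    if n_frequency_bands % k_f:
        raise ValueError(f"freq_patch_size={k_f} must divide F={n_frequency_bands}")
    f_p = n_frequency_bands // k_f           # non-overlapping frequency patches
    d_model = embedding_dim * f_p            # width of each per-timestep token
    if d_model % n_heads:
        raise ValueError(f"n_heads={n_heads} must divide d_model={d_model}")
    return f_p, d_model


print(transformer_token_shape())                                # (1, 48)
print(transformer_token_shape(freq_patch_size=17, n_heads=4))   # (2, 96)
```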
The patchifier is a strided nn.Conv2d with kernel (K_F, K_T) and stride (K_F, 1). Frequency patches are non-overlapping (F_p = F // K_F). The time stride is 1, so there is one token per timestep, and K_T > 1 lets each token aggregate time_patch_size recent frames. A (K_T - 1)-frame zero left pad along time keeps the patchifier strictly causal.
The attention mask is constructed in each forward pass at the actual sequence length, so the model generalizes to any input length L. Setting context_window to a positive int restricts attention to the most recent context_window past frames (band-causal), recovering a Sahani-style fixed STRF context window when wanted. The model can be evaluated with or without the bound at inference without retraining.
- Parameters:
  - n_frequency_bands (int, default 34) – Number of input frequency bands F.
  - freq_patch_size (int, optional) – Frequency-axis patch size K_F. None (default) uses F itself, so one token spans the full frequency axis at each timestep. Must divide F.
  - time_patch_size (int, default 1) – Temporal extent K_T of each patch in frames. 1 gives one token per timestep slice; larger values let each token aggregate across time_patch_size recent frames via the patchifier.
  - context_window (int, optional) – If set, restricts attention to the most recent context_window past frames (still causal: a band-causal mask). None (default) gives unlimited causal context.
  - embedding_dim (int, default 48) – Per-patch embedding dimension after the patchifier.
  - n_heads (int, default 1) – Number of attention heads. Must divide embedding_dim * F_p.
  - n_layers (int, default 1) – Number of TransformerEncoderLayer blocks.
  - out_neurons (int, default 1) – Number of output neurons N.
  - output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout. Default nn.Identity.
  - prefiltering (nn.Module, optional) – Optional spectrogram prefilter.
References
Rançon, Masquelier & Cottereau (2025). “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.” Communications Biology 8:1456. https://doi.org/10.1038/s42003-025-08858-3
Vaswani et al. (2017). “Attention Is All You Need.” NeurIPS.
Notes
Sinusoidal positional encoding (Vaswani 2017) is used by default; it generalizes to arbitrary sequence lengths at inference. RoPE (Rotary Position Embedding, Su et al. 2021) is a planned alternative; it is omitted here because it requires a custom TransformerEncoderLayer (PyTorch’s stock module hides Q and K).
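The causal and band-causal masks described for context_window can be built at any sequence length from two comparisons. This minimal sketch rests on an assumption about the band edge: it takes "the most recent context_window past frames" to include the current step, i.e. i - j < context_window. PyTorch's boolean attention masks use the opposite polarity (True marks a blocked position), so a conversion helper is included. Both function names are hypothetical.

```python
def attention_allowed(length, context_window=None):
    """allowed[i][j] is True when query step i may attend to key step j.

    Causal: j <= i. Band-causal (assumed band edge): additionally i - j < context_window.
    Built per call at the actual length, so any L works without retraining.
    """
    return [[j <= i and (context_window is None or i - j < context_window)
             for j in range(length)] for i in range(length)]


def to_pytorch_bool_mask(allowed):
    """PyTorch boolean attention masks use True for positions that are *blocked*."""
    return [[not a for a in row] for row in allowed]


for row in attention_allowed(4, context_window=2):
    print("".join("x" if a else "." for a in row))
# x...
# xx..
# .xx.
# ..xx
```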
Module contents
- class deepSTRF.models.audio.AudioEncodingModel(n_frequency_bands: int, temporal_window_size: int, out_neurons: int = 1, prefiltering: Module = None, wav2spec: Module = None, *args, **kwargs)
Package-level alias of deepSTRF.models.audio.audio_model.AudioEncodingModel; see that entry for the full documentation.
- class deepSTRF.models.audio.ConvNet2D(n_frequency_bands: int = 34, kernel_size: tuple = (3, 9), c_hidden: int = 10, n_hidden: int = 20, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None)
Package-level alias of deepSTRF.models.audio.audio_zoo.ConvNet2D; see that entry for the full documentation.
- class deepSTRF.models.audio.DNet(n_frequency_bands: int = 34, temporal_window_size: int = 9, n_hidden: int = 20, init_tau: float = 2.0, decay_input: bool = True, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)
Package-level alias of deepSTRF.models.audio.audio_zoo.DNet; see that entry for the full documentation.
- class deepSTRF.models.audio.ICNet(audio_fs: int, out_neurons: int, dt_ms: float = 5.0, n_filters: int = 48, sincnet_kernel_size: int = 64, encoder_channels: int = 128, encoder_kernel_size: int = 64, n_encoder_layers: int = 5, bottleneck_channels: int = 64, encoder_strides: Sequence[int] | None = None)[source]
Bases:
AudioEncodingModel
End-to-end ICNet (Drakopoulos et al. 2025) ported to deepSTRF.
Architecture: SincNet (48 filters, K=64, stride 1, symlog) → 5× causal Conv1d(128 ch, K=64, PReLU) at strides that multiply to audio_fs · dt_ms / 1000 → bottleneck Conv1d(64 ch, K=64, stride 1, PReLU) → Linear(64 → N) → softplus (Poisson head, N_c = 1 in paper notation).
Cross-dataset configuration
The paper trains on 24 414 Hz gerbil-IC audio binned at ~1.31 ms (32 samples per bin, 5 stride-2 conv layers). To use the same architecture on a dataset with a different (audio_fs, dt_ms), the encoder strides are auto-factored so that they multiply to audio_fs · dt_ms / 1000, the number of audio samples per neural bin. For NS1 (48 kHz / 5 ms) that is 240 samples per bin, and the default factorisation is [2, 2, 2, 2, 15]. Pass an explicit encoder_strides list to override it. The layer structure (kernel sizes, channel counts, activations) stays paper-faithful; only the strides scale with the dataset, following the deepSTRF policy of adapting hyperparameters to each dataset's temporal resolution.
The decoder is intentionally simple and paper-faithful. In the paper's words, "the simple linear decoders in ICNet … ensure that the latent representation in the bottleneck is constrained to directly reflect the dynamics that underlie neural activity". The expressivity lives in the shared encoder.
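One way to auto-factor the encoder strides is sketched below in plain Python. The helper names (samples_per_bin, prime_factors, factor_strides) and the rule itself (give the smallest prime factors of the samples-per-bin count to the early layers, one per layer, and put the product of the rest in the last layer) are assumptions, chosen because they reproduce the two configurations quoted above; deepSTRF's own factorisation routine may follow a different rule.

```python
import math
from typing import List


def samples_per_bin(audio_fs: int, dt_ms: float) -> int:
    """Audio samples per neural bin, audio_fs * dt_ms / 1000 (must be an integer)."""
    exact = audio_fs * dt_ms / 1000.0
    n = round(exact)
    if n < 1 or not math.isclose(exact, n, rel_tol=1e-6):
        raise ValueError(f"audio_fs * dt_ms / 1000 = {exact} is not a positive integer")
    return n


def prime_factors(n: int) -> List[int]:
    """Prime factors of n in ascending order (trial division)."""
    factors, p = [], 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)
    return factors


def factor_strides(total: int, n_layers: int = 5) -> List[int]:
    """Hypothetical rule: smallest primes to the early layers, remainder to the last."""
    primes = prime_factors(total)
    if len(primes) < n_layers:
        primes = [1] * (n_layers - len(primes)) + primes  # pad with stride-1 layers
    strides = primes[: n_layers - 1] + [math.prod(primes[n_layers - 1 :])]
    assert math.prod(strides) == total  # total downsampling is preserved
    return strides


print(factor_strides(samples_per_bin(48_000, 5.0)))  # NS1: [2, 2, 2, 2, 15]
print(factor_strides(32))                            # paper: [2, 2, 2, 2, 2]
```

Whatever rule is used, the invariant that matters is that the product of the strides equals audio_fs · dt_ms / 1000, so that one encoder output frame lines up with one neural bin.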
Differences from the paper
Single-branch / time-invariant only. The paper’s multi-branch and time-variant heads (animal-specific decoders, timestamp-input modulation) are out of scope for the deepSTRF v1 port.
Poisson head only. The paper's main result uses a categorical cross-entropy head with N_c = 5 classes for spike counts in {0, 1, 2, 3, ≥4}. The deepSTRF training stack centres on rate-based losses; a cross-entropy head can be added later.
No left-context crop. The paper feeds 10 240 audio samples in and crops the leftmost 64 frames of the bottleneck output to suppress edge effects. deepSTRF's convention is to keep T_neural output frames matching the dataset's response window; the causal convs leave the first few frames noisier, but downstream losses handle that.
- Parameters:
audio_fs (int) – Audio sample rate (Hz). Determines the total encoder downsampling.
out_neurons (int) – Number of output neurons N.
dt_ms (float, default 5.0) – Target neural bin width in ms. Encoder strides are factored so the total downsampling matches audio_fs · dt_ms / 1000.
n_filters (int, default 48) – SincNet filter count.
sincnet_kernel_size (int, default 64) – SincNet filter length K.
encoder_channels (int, default 128) – Channel count of the strided encoder Conv1d layers.
encoder_kernel_size (int, default 64) – Kernel size of the strided encoder Conv1d layers.
n_encoder_layers (int, default 5) – Number of strided encoder Conv1d layers.
bottleneck_channels (int, default 64) – Output channel count of the bottleneck conv.
encoder_strides (sequence of int, optional) – Per-layer encoder strides. Default: auto-factor.
References
Drakopoulos, Pellatt, Sabesan, Xia, Fragner & Lesica (2025). “Modelling neural coding in the auditory midbrain with high resolution and accuracy.” Nature Machine Intelligence 7:1478-1493. https://doi.org/10.1038/s42256-025-01104-9
- forward(x: Tensor) → Tensor[source]
Forward pass.
Overrides the base template because the bottleneck latent is shaped (B, 1, 64, T) (an explicit C_in axis on top of the latent dim) and the paper's decoder is a per-timestep linear map, which does not fit cleanly into the canonical STRFReadout slot.
- Parameters:
x (torch.Tensor) – Mono waveform, shape (B, 1, T_audio).
- Returns:
Predicted spike rate, shape (B, N, 1, T_neural). Non-negative (softplus output); pair with poisson_loss().
- Return type:
torch.Tensor
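The per-timestep decoder described above can be sketched in a few lines of PyTorch. This only illustrates the shape bookkeeping from the (B, 1, 64, T) latent to the (B, N, 1, T) output layout: the tensor names are hypothetical, and nn.Linear plus F.softplus stand in for whatever layers ICNet uses internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C_lat, T_neural, N = 2, 64, 100, 5
latent = torch.randn(B, 1, C_lat, T_neural)  # bottleneck output (B, 1, 64, T)

decoder = nn.Linear(C_lat, N)                # one linear map, shared across time

z = latent.squeeze(1).transpose(1, 2)        # (B, T, 64): one vector per timestep
rate = F.softplus(decoder(z))                # (B, T, N), non-negative rates
rate = rate.transpose(1, 2).unsqueeze(2)     # (B, N, 1, T): deepSTRF output layout

print(rate.shape)  # torch.Size([2, 5, 1, 100])
```

Because nn.Linear acts on the last axis only, each output frame depends on the latent at the same frame and nothing else, which is what "per-timestep linear map" means here.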
- class deepSTRF.models.audio.Linear(n_frequency_bands: int = 34, temporal_window_size: int = 9, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None, wav2spec: Module = None)[source]
Bases:
AudioEncodingModel
The canonical Linear (L) STRF model: a single Spectro-Temporal Receptive Field convolved with the (optionally prefiltered) input spectrogram.
All learnable parameters live in the readout (STRFReadout), which holds a kernel of shape (N, C_in, F, T), applies it causally via left-padding, and follows it with a per-neuron nn.BatchNorm1d(N) before the output activation. The model's core is nn.Identity. Every learnable scalar has the neuron axis as its leading dim, so the model shares no parameters across neurons: each neuron's parameters are independent of every other neuron's.
- Parameters:
n_frequency_bands (int, default 34) – Number of input spectrogram frequency bands F.
temporal_window_size (int, default 9) – STRF temporal extent T in frames.
out_neurons (int, default 1) – Number of output neurons N.
output_activation (nn.Module, optional) – Pointwise nonlinearity applied at the readout output. Default nn.Identity (true linear model).
prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any nn.Module exposing out_channels). None (default) gives nn.Identity and C_in = 1.
kernel (nn.Module, optional) – Pluggable STRF kernel for the readout. None (default) gives a vanilla nn.Conv2d; pass ParametricSTRF(...) for DCLS, or a separable nn.Sequential for a rank-1 factorization. See deepSTRF.models.layers for the kernel module catalogue.
wav2spec (nn.Module, optional) – Optional raw-waveform front-end (deepSTRF.models.wav2spec.*). When provided, the model accepts raw audio (B, 1, T_audio) instead of a spectrogram. None (default) keeps the slot as nn.Identity() and the model expects (B, 1, F, T) spectrogram input.
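The causal readout described above (a kernel of shape (N, C_in, F, T) applied with left-padding, then a per-neuron BatchNorm1d) can be sketched with plain PyTorch ops. This is an illustrative stand-in, not deepSTRF's STRFReadout; the function name causal_strf is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C_in, n_freq, T_win, L, N = 2, 1, 34, 9, 200, 3

weight = torch.randn(N, C_in, n_freq, T_win) * 0.1  # one STRF per neuron
bn = nn.BatchNorm1d(N)                              # per-neuron scale and shift


def causal_strf(spec: torch.Tensor) -> torch.Tensor:
    """spec: (B, C_in, F, L) -> (B, N, 1, L); output at t sees frames t-T_win+1..t."""
    x = F.pad(spec, (T_win - 1, 0))    # zero-pad only the past side of the time axis
    y = F.conv2d(x, weight)            # (B, N, 1, L): the kernel spans all F bands
    y = bn(y.squeeze(2)).unsqueeze(2)  # BatchNorm1d over the neuron axis
    return y


bn.eval()                              # use running statistics, as at inference
spec = torch.randn(B, C_in, n_freq, L)
out = causal_strf(spec)

# Perturbing frames >= 100 must leave the outputs before frame 100 untouched.
spec2 = spec.clone()
spec2[..., 100:] += 1.0
print(out.shape, torch.allclose(causal_strf(spec2)[..., :100], out[..., :100]))
# torch.Size([2, 3, 1, 200]) True
```

Padding T_win - 1 zero frames on the left keeps the output length equal to the input length L, so no fixed-window slicing is needed.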
References
The L model is a folklore baseline; canonical formulations appear in:
Theunissen, Sen & Doupe (2000). “Spectral-Temporal Receptive Fields of Nonlinear Auditory Neurons Obtained Using Natural Sounds.” J. Neurosci. 20(6): 2315–2331. https://doi.org/10.1523/JNEUROSCI.20-06-02315.2000
Sahani & Linden (2003). “How Linear are Auditory Cortical Responses?” NIPS.
Notes
The per-neuron BatchNorm inside the readout can be folded into the kernel at inference (in eval mode its running statistics are frozen per-channel scalars), so the model remains a strict linear-affine map of the input at eval time and the learned STRF kernel is directly interpretable up to a per-neuron affine rescaling.
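The folding argument above can be checked numerically. The sketch assumes an eval-mode nn.BatchNorm1d with affine parameters placed after a bias-free causal conv, matching the readout description; it uses the standard batch-norm folding formulas, w' = w · γ / sqrt(σ² + ε) and b' = β - μ · γ / sqrt(σ² + ε), and the variable names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C_in, n_freq, T_win, L = 3, 1, 34, 9, 120
weight = torch.randn(N, C_in, n_freq, T_win)

bn = nn.BatchNorm1d(N)
with torch.no_grad():                  # give BN non-trivial frozen statistics
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()

spec = torch.randn(4, C_in, n_freq, L)
x = F.pad(spec, (T_win - 1, 0))        # causal left pad on the time axis

# Reference: conv followed by eval-mode BatchNorm.
ref = bn(F.conv2d(x, weight).squeeze(2))  # (B, N, L)

# Fold BN into a per-neuron rescaled kernel plus a bias.
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # (N,)
w_folded = weight * scale.view(N, 1, 1, 1)
b_folded = bn.bias - bn.running_mean * scale
folded = F.conv2d(x, w_folded, b_folded).squeeze(2)

print(torch.allclose(ref, folded, atol=1e-4))  # True
```

The folded kernel w_folded is the per-neuron rescaled STRF, which is why the learned kernel stays interpretable up to an affine rescaling.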
- class deepSTRF.models.audio.LinearNonlinear(n_frequency_bands: int = 34, temporal_window_size: int = 9, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)[source]
Bases:
Linear
Linear-Nonlinear (LN) STRF model: the Linear model followed by a pointwise output nonlinearity.
Inherits everything from Linear and only changes the default output activation from nn.Identity to a per-neuron ParametricSoftplus. Pass any nn.Module to output_activation to override it.
- Parameters:
output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity applied at the readout output. The default is unbounded above and non-negative, a natural fit for spike-count regression on smoothed PSTHs that exceed 1. See deepSTRF.models.activations for other parametric variants (ParametricSigmoid, ParametricDoubleExponential).
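To make the two properties quoted above concrete, a per-neuron softplus can be sketched as below. The parameterization y = α · softplus(β · x + γ), with one (α, β, γ) per neuron and α, β kept positive through a log parameterization, and the class name SketchParametricSoftplus are assumptions for illustration; they are not taken from deepSTRF's ParametricSoftplus.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchParametricSoftplus(nn.Module):
    """Hypothetical per-neuron softplus: y = alpha * softplus(beta * x + gamma)."""

    def __init__(self, n_neurons: int):
        super().__init__()
        # Log-parameterized gain and slope keep alpha, beta > 0.
        self.log_alpha = nn.Parameter(torch.zeros(n_neurons))
        self.log_beta = nn.Parameter(torch.zeros(n_neurons))
        self.gamma = nn.Parameter(torch.zeros(n_neurons))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, 1, T); one parameter set per neuron, broadcast along axis 1.
        shape = (1, -1, 1, 1)
        alpha = self.log_alpha.exp().view(shape)
        beta = self.log_beta.exp().view(shape)
        gamma = self.gamma.view(shape)
        return alpha * F.softplus(beta * x + gamma)


act = SketchParametricSoftplus(n_neurons=4)
x = torch.linspace(-20.0, 20.0, 9).expand(2, 4, 1, 9)
y = act(x)
# Non-negative everywhere; for large inputs it grows roughly linearly (no ceiling).
print(bool((y >= 0).all()), y[0, 0, 0, -1].item())
```

Unlike a sigmoid, this output has no upper bound, so targets above 1 (for example smoothed PSTH values) stay reachable.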
See also
Linear: same architecture without the output nonlinearity.
- class deepSTRF.models.audio.NetworkReceptiveField(n_frequency_bands: int = 34, temporal_window_size: int = 9, n_hidden: int = 20, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, kernel: Module = None)[source]
Bases:
AudioEncodingModel
Network Receptive Field (NRF) model: a two-layer feedforward STRF network.
Architecture: a STRF kernel projects the input spectrogram into a hidden layer of H units; a 1×1 conv reads out the N output neurons from the hidden activations. With L1 regularization, the paper typically finds 1–7 effective hidden units per neuron.
- Parameters:
n_frequency_bands (int, default 34) – Number of input frequency bands F.
temporal_window_size (int, default 9) – STRF temporal extent T.
n_hidden (int, default 20) – Hidden layer width H.
out_neurons (int, default 1) – Number of output neurons N.
output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout output. Default nn.Sigmoid.
prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.
kernel (nn.Module, optional) – Pluggable hidden-layer STRF kernel. None (default) gives a vanilla nn.Conv2d; pass ParametricSTRF(...) for DCLS.
References
Harper, Schoppe, Willmore, Cui, Schnupp & King (2016). “Network Receptive Field Modeling Reveals Extensive Integration and Multi-feature Selectivity in Auditory Cortical Neurons.” PLOS Comp. Biol. 12(11): e1005113. https://doi.org/10.1371/journal.pcbi.1005113
Notes
Differences from the original paper:
We add a causal LayerNorm over input frequencies and over the hidden channel axis. The original assumes preprocessing-time input normalization and uses no internal norm.
The hidden activation is nn.Tanh. The paper uses a scaled tanh, ρ₁·tanh(ρ₂·x) with ρ₁ ≈ 1.7159 and ρ₂ = 2/3; we use the unscaled standard tanh and rely on the learned weights to absorb the rescaling (ρ₁ folds directly into the 1×1 readout).
Causal left-padding extends the model to arbitrary input lengths; the paper uses fixed-window slicing.
The hidden STRF kernel can be parameterized (DCLS); the paper uses a vanilla full kernel.
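The NRF forward pass can be sketched in plain PyTorch. Assumptions: the causal LayerNorm is modelled as an nn.LayerNorm applied to each timestep on its own (over F for the input, over H for the hidden layer), which is causal because no statistics cross time; deepSTRF's own CausalLayerNorm may compute its statistics differently. The class name TinyNRF is hypothetical, and the default nn.Sigmoid output activation is used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNRF(nn.Module):
    """Hypothetical two-layer NRF: STRF -> H tanh hidden units -> 1x1 readout."""

    def __init__(self, n_freq=34, t_win=9, n_hidden=20, n_out=1):
        super().__init__()
        self.t_win = t_win
        self.in_norm = nn.LayerNorm(n_freq)     # per timestep, over frequency
        self.strf = nn.Conv2d(1, n_hidden, kernel_size=(n_freq, t_win))
        self.hid_norm = nn.LayerNorm(n_hidden)  # per timestep, over hidden units
        self.readout = nn.Conv2d(n_hidden, n_out, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, 1, F, L)
        x = self.in_norm(x.transpose(2, 3)).transpose(2, 3)
        x = F.pad(x, (self.t_win - 1, 0))       # causal: pad the past only
        h = self.strf(x)                         # (B, H, 1, L)
        h = self.hid_norm(h.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        h = torch.tanh(h)                        # hidden activation (nn.Tanh)
        return torch.sigmoid(self.readout(h))    # (B, N, 1, L), default nn.Sigmoid


torch.manual_seed(0)
model = TinyNRF(n_out=3)
spec = torch.randn(2, 1, 34, 100)
out = model(spec)

spec2 = spec.clone()
spec2[..., 60:] = torch.randn(2, 1, 34, 40)  # change only the future
print(out.shape, torch.allclose(model(spec2)[..., :60], out[..., :60]))
# torch.Size([2, 3, 1, 100]) True
```

The hidden layer carries all the shared structure: each of the H units has its own (F, T) STRF, and the 1×1 readout mixes them per output neuron.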
- STRFs(hidden_idx: int = 0, polarity: str = 'ON')[source]
Return the hidden-layer STRF kernel of one hidden unit, shaped (F, T).
- Parameters:
hidden_idx (int, default 0) – Which of the H hidden units to return the STRF for.
polarity ({'ON', 'OFF'}, default 'ON') – Only relevant when the prefilter has C_in == 2 (e.g. AdapTrans). Selects the ON or OFF channel of the kernel.
- class deepSTRF.models.audio.StateNet(n_frequency_bands: int = 34, temporal_window_size: int = 1, kernel_size: int = 7, stride: int = 3, hidden_channels: int = 7, connectivity: str = 'LC', rnn_type: str = 'GRU', out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, wav2spec: Module = None)[source]
Bases:
AudioEncodingModel
Fully stateful STRF model: it relies entirely on temporal recurrence to extract information from stimulus sequences, with no explicit STRF delay window.
Architecture: a stateless per-timestep spectral encoder maps each spectrogram column (C_in, F) to a hidden representation (C, F_down). The flattened hidden representation is fed timestep by timestep to a recurrent (or state-space) model that accumulates context implicitly through its hidden state. A linear readout projects the recurrent hidden state to N output neurons.
Causality is inherent to the recurrent backbone (RNN/GRU/LSTM/LMU/Mamba/S4). The spectral encoder processes one timestep at a time, so it does not couple frames temporally.
- Parameters:
n_frequency_bands (int, default 34) – Number of input frequency bands F.
temporal_window_size (int, default 1) – Unused by StateNet (kept for AudioEncodingModel API compatibility); recurrence handles temporal context.
kernel_size (int, default 7) – Frequency kernel size of the spectral encoder.
stride (int, default 3) – Frequency stride of the spectral encoder.
hidden_channels (int, default 7) – Channel count C of the spectral encoder.
connectivity ({'LC', 'FC', 'CONV'}, default 'LC') – Spectral encoder connectivity. 'LC': locally connected 1D layer (frequency-position-specific weights). 'FC': dense linear projection reshaped to (C, F_down). 'CONV': weight-shared 1D convolution.
rnn_type ({'GRU', 'LSTM', 'RNN', 'vanilla', 'LMU', 'Mamba', 'S4'}, default 'GRU') – Recurrent / state-space backbone.
out_neurons (int, default 1) – Number of output neurons N.
output_activation (nn.Module, default ParametricSoftplus(out_neurons)) – Pointwise nonlinearity at the output. The default is unbounded above and non-negative, a natural fit for spike-count regression.
prefiltering (nn.Module, optional) – Optional spectrogram prefilter (AdapTrans, ICAdaptation, or any module exposing out_channels). None (default) gives nn.Identity and C_in = 1.
References
Rançon, Masquelier & Cottereau (2025). “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.” Communications Biology 8:1456. https://doi.org/10.1038/s42003-025-08858-3
Notes
The spectral encoder uses a CausalLayerNorm over the channel axis (C); the original implementation used BatchNorm1d, which pools statistics over the (T*B, F_down) axis and is therefore non-causal.
The S4 backbone is imported lazily, because its module emits CUDA-extension warnings on import that users of other backends should not have to see.
- forward(x)[source]
Per-timestep spectral encoder feeding a recurrent backbone.
Overrides the base template because the encoder runs on a flattened (T*B, C_in, F) batch (so every timestep is independent in the spectral pass) and the RNN expects a batch-first (B, T, H) layout; these reshapes do not decompose into the canonical core / readout slots.
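The reshapes described above can be sketched with the 'CONV' connectivity and a GRU backbone. Assumptions: the spectral encoder is a single nn.Conv1d followed by a ReLU (the real encoder also includes a CausalLayerNorm and may use a different nonlinearity), the GRU hidden size of 32 is arbitrary, and F.softplus stands in for the ParametricSoftplus output activation; TinyStateNet is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyStateNet(nn.Module):
    def __init__(self, n_freq=34, c_in=1, hidden_channels=7, kernel_size=7,
                 stride=3, rnn_hidden=32, n_out=1):
        super().__init__()
        # Weight-shared 1D conv over frequency ('CONV' connectivity).
        self.encoder = nn.Conv1d(c_in, hidden_channels, kernel_size, stride=stride)
        f_down = (n_freq - kernel_size) // stride + 1  # 10 for the defaults
        self.rnn = nn.GRU(hidden_channels * f_down, rnn_hidden, batch_first=True)
        self.readout = nn.Linear(rnn_hidden, n_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C_in, n_freq, T = x.shape
        # (B, C_in, F, T) -> (T*B, C_in, F): every timestep is an independent sample.
        cols = x.permute(3, 0, 1, 2).reshape(T * B, C_in, n_freq)
        z = F.relu(self.encoder(cols))                       # (T*B, C, F_down)
        z = z.reshape(T, B, -1).transpose(0, 1).contiguous()  # (B, T, C*F_down)
        h, _ = self.rnn(z)                                   # (B, T, H), causal in time
        y = F.softplus(self.readout(h))                      # (B, T, N)
        return y.transpose(1, 2).unsqueeze(2)                # (B, N, 1, T)


torch.manual_seed(0)
model = TinyStateNet(n_out=4)
spec = torch.randn(2, 1, 34, 80)
out = model(spec)

spec2 = spec.clone()
spec2[..., 40:] = torch.randn(2, 1, 34, 40)  # change only the future
print(out.shape, torch.allclose(model(spec2)[..., :40], out[..., :40]))
# torch.Size([2, 4, 1, 80]) True
```

All temporal context lives in the GRU state; the encoder never sees more than one frame, which is why temporal_window_size is unused.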
- class deepSTRF.models.audio.Transformer(n_frequency_bands: int = 34, freq_patch_size: int = None, time_patch_size: int = 1, context_window: int = None, embedding_dim: int = 48, n_heads: int = 1, n_layers: int = 1, out_neurons: int = 1, output_activation: Module = None, prefiltering: Module = None, wav2spec: Module = None)[source]
Bases:
AudioEncodingModel
Attention-based STRF model: a Transformer encoder runs causal self-attention over a per-timestep token sequence extracted from the spectrogram.
Architecture:
input              (B, 1, F, L)
↓ prefilter        (B, C_in, F, L)
↓ time pad         (B, C_in, F, L + K_T - 1)    left-only, causal
↓ patchify         (B, embedding_dim, F_p, L)   Conv2d, stride=(K_F, 1)
↓ flatten/permute  (B, L, embedding_dim * F_p)  one token per timestep
↓ + sinusoidal positional encoding
↓ TransformerEncoder with causal (+ optional window) mask
↓ readout          (B, N, 1, L)
The patchifier is a strided nn.Conv2d with kernel (K_F, K_T) and stride (K_F, 1), where K_F = freq_patch_size and K_T = time_patch_size. Frequency patches are non-overlapping (F_p = F // K_F); the time stride is 1, so there is one token per timestep, and K_T > 1 lets each token aggregate the time_patch_size most recent frames. A (K_T - 1)-frame zero pad on the left of the time axis keeps the patchifier strictly causal.
The attention mask is built in every forward pass at the actual sequence length, so the model generalizes to any input length L. Setting context_window to a positive int restricts attention to the context_window most recent past frames (band-causal), recovering a Sahani-style fixed STRF context window when wanted; the model can be evaluated with or without the bound at inference without retraining.
- Parameters:
n_frequency_bands (int, default 34) – Number of input frequency bands F.
freq_patch_size (int, optional) – Frequency-axis patch size K_F of the patchifier. None (default) uses F itself, so one token spans the full frequency axis at each timestep. Must divide F.
time_patch_size (int, default 1) – Temporal extent K_T of each patch in frames. 1 gives one token = one timestep slice; larger values let each token aggregate the time_patch_size most recent frames via the patchifier.
context_window (int, optional) – If set, restrict attention to the context_window most recent past frames (still causal: band-causal mask). None (default) gives unlimited causal context.
embedding_dim (int, default 48) – Per-patch embedding dimension after the patchifier.
n_heads (int, default 1) – Number of attention heads. Must divide embedding_dim * F_p.
n_layers (int, default 1) – Number of TransformerEncoderLayer blocks.
out_neurons (int, default 1) – Number of output neurons N.
output_activation (nn.Module, optional) – Pointwise nonlinearity at the readout. Default nn.Identity.
prefiltering (nn.Module, optional) – Optional spectrogram prefilter.
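The patchifier and the band-causal mask can be sketched with stock PyTorch modules. Assumptions: the mask follows PyTorch's boolean convention (True blocks attention); context_window = w lets frame t attend to frames t-w+1 through t (the description above does not pin down whether the current frame counts toward the window); positional encoding and the readout are omitted; and the helper names attention_mask and encode are hypothetical.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def attention_mask(L: int, context_window: Optional[int] = None) -> torch.Tensor:
    """Boolean (L, L) mask, True = blocked. Causal, optionally band-limited."""
    t = torch.arange(L).unsqueeze(1)  # query index
    s = torch.arange(L).unsqueeze(0)  # key index
    blocked = s > t                   # never attend to the future
    if context_window is not None:
        blocked |= (t - s) >= context_window  # nor further back than the window
    return blocked


torch.manual_seed(0)
B, C_in, n_freq, L = 2, 1, 34, 50
K_F, K_T, emb = 17, 3, 48  # freq_patch_size, time_patch_size, embedding_dim
F_p = n_freq // K_F        # 2 frequency patches

patchify = nn.Conv2d(C_in, emb, kernel_size=(K_F, K_T), stride=(K_F, 1))
layer = nn.TransformerEncoderLayer(d_model=emb * F_p, nhead=1, dropout=0.0,
                                   batch_first=True)


def encode(spec: torch.Tensor, context_window: Optional[int] = None) -> torch.Tensor:
    b, _, _, length = spec.shape
    x = F.pad(spec, (K_T - 1, 0))                              # causal time pad
    x = patchify(x)                                            # (B, emb, F_p, L)
    tokens = x.permute(0, 3, 1, 2).reshape(b, length, emb * F_p)  # one token per frame
    return layer(tokens, src_mask=attention_mask(length, context_window))


spec = torch.randn(B, C_in, n_freq, L)
out = encode(spec, context_window=10)                          # (B, L, emb * F_p)

spec2 = spec.clone()
spec2[..., 30:] = torch.randn(B, C_in, n_freq, L - 30)         # change only the future
print(out.shape, torch.allclose(encode(spec2, 10)[:, :30], out[:, :30], atol=1e-6))
# torch.Size([2, 50, 96]) True
```

Because the mask is rebuilt from the runtime length, the same weights can be run with context_window=None or with a finite window, matching the "with or without the bound" behaviour described above.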
References
Rançon, Masquelier & Cottereau (2025). “Temporal recurrence as a general mechanism to explain neural responses in the auditory system.” Communications Biology 8:1456. https://doi.org/10.1038/s42003-025-08858-3
Vaswani et al. (2017). “Attention Is All You Need.” NeurIPS.
Notes
Sinusoidal positional encoding (Vaswani 2017) is used by default; it generalizes to arbitrary sequence lengths at inference. RoPE (Rotary Position Embedding, Su et al. 2021) is a planned alternative; it is omitted here because it requires a custom TransformerEncoderLayer (PyTorch’s stock module hides Q and K).
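For reference, the sinusoidal encoding of Vaswani et al. (2017) sets PE[pos, 2i] = sin(pos / 10000^(2i/d)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/d)). Below is a minimal sketch for a token dimension d (in this model, d = embedding_dim * F_p), assuming d is even; the function name sinusoidal_encoding is hypothetical. Because the table is a closed-form function of position, it can be computed for any sequence length L, which is what lets it generalize to arbitrary lengths at inference.

```python
import math

import torch


def sinusoidal_encoding(L: int, d: int) -> torch.Tensor:
    """(L, d) table: PE[p, 2i] = sin(p / 10000**(2i/d)), PE[p, 2i+1] = cos(...)."""
    assert d % 2 == 0, "this sketch assumes an even token dimension"
    pos = torch.arange(L, dtype=torch.float32).unsqueeze(1)  # (L, 1)
    two_i = torch.arange(0, d, 2, dtype=torch.float32)       # (d/2,)
    inv_freq = torch.exp(-math.log(10000.0) * two_i / d)     # 10000**(-2i/d)
    pe = torch.zeros(L, d)
    pe[:, 0::2] = torch.sin(pos * inv_freq)
    pe[:, 1::2] = torch.cos(pos * inv_freq)
    return pe


pe = sinusoidal_encoding(L=200, d=48)
tokens = torch.randn(2, 200, 48)
tokens = tokens + pe  # broadcast over the batch axis
print(pe[0, :4])      # position 0: sin(0) = 0 and cos(0) = 1, i.e. values 0, 1, 0, 1
```

A shorter table is simply a prefix of a longer one, so nothing learned is tied to the training sequence length.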