import torch
import torch.nn as nn
import torch.nn.functional
from .audio_model import AudioEncodingModel
import deepSTRF.models.layers as layers
from deepSTRF.models.activations import ParametricSoftplus
from deepSTRF.models.dependencies.lmu import LMU
from mambapy.mamba import MambaBlock, MambaConfig
from deepSTRF.models.prefiltering import AdapTrans
from deepSTRF.models.readouts import STRFReadout, LinearReadout
# S4Block is imported lazily inside StateNet — its module emits noisy stderr
# warnings about missing CUDA extensions and pulls a heavy dependency
# graph; users who don't pick rnn_type='S4' shouldn't pay either cost.
[docs]
class Linear(AudioEncodingModel):
    """
    Canonical Linear (L) STRF model: one SpectroTemporal Receptive Field
    convolved with the (optionally prefiltered) input spectrogram.

    The ``core`` slot is ``nn.Identity``; every learnable parameter sits in
    the readout (``STRFReadout``). That readout holds a kernel of shape
    ``(N, C_in, F, T)``, applies it causally via left-padding, and follows
    it with a per-neuron ``nn.BatchNorm1d(N)`` before the output
    activation. Because each learnable scalar carries the neuron axis as
    its leading dimension, no parameter is shared between neurons.

    Parameters
    ----------
    n_frequency_bands : int, default 34
        Number of input spectrogram frequency bands ``F``.
    temporal_window_size : int, default 9
        STRF temporal extent ``T`` in frames.
    out_neurons : int, default 1
        Number of output neurons ``N``.
    output_activation : nn.Module, optional
        Pointwise nonlinearity applied at the readout output. ``None``
        (default) selects ``nn.Identity``, i.e. a true linear model.
    prefiltering : nn.Module, optional
        Optional spectrogram prefilter (``AdapTrans``, ``ICAdaptation``,
        or any ``nn.Module`` exposing ``out_channels``). ``None``
        (default) gives ``nn.Identity`` and ``C_in = 1``.
    kernel : nn.Module, optional
        Pluggable STRF kernel for the readout. ``None`` (default) gives
        a vanilla ``nn.Conv2d``; pass ``ParametricSTRF(...)`` for DCLS,
        or a separable ``nn.Sequential`` for a rank-1 factorization.
        See ``deepSTRF.models.layers`` for the kernel module catalogue.
    wav2spec : nn.Module, optional
        Optional raw-waveform front-end (``deepSTRF.models.wav2spec.*``).
        When given, the model takes raw audio ``(B, 1, T_audio)`` instead
        of a spectrogram. ``None`` (default) keeps the slot as
        ``nn.Identity()`` and the model expects ``(B, 1, F, T)`` input.

    References
    ----------
    The L model is a folklore baseline; canonical formulations appear in:

    Theunissen, Sen & Doupe (2000). "Spectral-Temporal Receptive Fields of
    Nonlinear Auditory Neurons Obtained Using Natural Sounds."
    J. Neurosci. 20(6): 2315–2331.
    https://doi.org/10.1523/JNEUROSCI.20-06-02315.2000

    Sahani & Linden (2003). "How Linear are Auditory Cortical Responses?"
    NIPS.

    Notes
    -----
    At inference the readout's per-neuron BatchNorm reduces to frozen
    per-channel scalars and so folds into the kernel: the model stays a
    strict linear-affine map of its input at eval time, and the learned
    STRF is interpretable up to a per-neuron affine rescaling.
    """

    def __init__(self, n_frequency_bands: int = 34, temporal_window_size: int = 9,
                 out_neurons: int = 1,
                 output_activation: nn.Module = None,
                 prefiltering: nn.Module = None,
                 kernel: nn.Module = None,
                 wav2spec: nn.Module = None):
        super().__init__(
            n_frequency_bands=n_frequency_bands,
            temporal_window_size=temporal_window_size,
            out_neurons=out_neurons,
            prefiltering=prefiltering,
            wav2spec=wav2spec,
        )

        # Nothing to learn in the core: normalisation is done by the
        # per-neuron BatchNorm1d that lives inside STRFReadout.
        self.core = nn.Identity()

        # Fall back to a pass-through activation (purely linear model).
        if output_activation is None:
            readout_activation = nn.Identity()
        else:
            readout_activation = output_activation

        # Readout = pluggable STRF kernel + per-neuron BN + activation.
        self.readout = STRFReadout(
            F=self.F, T=self.T, C_in=self.C_in, out_neurons=self.O,
            kernel=kernel,
            activation=readout_activation,
        )

    # forward is inherited from NeuralModel — wav2spec → prefiltering → core → readout

    def STRF_weight(self, polarity: str = 'ON'):
        """
        Return the readout's STRF kernel as a ``(N, F, T)`` tensor.

        Parameters
        ----------
        polarity : {'ON', 'OFF'}, default 'ON'
            Which input channel of the kernel to return when the model is
            prefiltered with ``AdapTrans`` (``C_in == 2``). The spellings
            ``'On'``/``'on'``/``0`` and ``'Off'``/``'off'``/``1`` are also
            accepted. Ignored for any other prefilter, where channel 0 is
            always returned.

        Raises
        ------
        ValueError
            If the model is AdapTrans-prefiltered and ``polarity`` is not
            one of the accepted spellings.
        """
        weights = self.readout.STRF_weight()  # (N, C_in, F, T)

        # Without AdapTrans, polarity is irrelevant: return channel 0.
        if not isinstance(self.prefiltering, AdapTrans):
            return weights[:, 0]

        # Channel index -> accepted spellings of that polarity.
        accepted = (('ON', 'On', 'on', 0), ('OFF', 'Off', 'off', 1))
        for channel, aliases in enumerate(accepted):
            if polarity in aliases:
                return weights[:, channel]

        raise ValueError(
            f"polarity must be 'ON' or 'OFF' for an AdapTrans-prefiltered "
            f"Linear model — got {polarity!r}"
        )
[docs]
class LinearNonlinear(Linear):
    """
    Linear-Nonlinear (LN) STRF model — the Linear model followed by a
    pointwise output nonlinearity.

    Inherits everything from ``Linear`` and only changes the default
    output activation from ``nn.Identity`` to a per-neuron
    :class:`ParametricSoftplus`. Pass any ``nn.Module`` to
    ``output_activation`` to override.

    Parameters
    ----------
    n_frequency_bands : int, default 34
        Number of input spectrogram frequency bands ``F``.
    temporal_window_size : int, default 9
        STRF temporal extent ``T`` in frames.
    out_neurons : int, default 1
        Number of output neurons ``N``.
    output_activation : nn.Module, default ParametricSoftplus(out_neurons)
        Pointwise nonlinearity applied at the readout output. The default
        is unbounded above and non-negative — natural for spike-count
        regression on smoothed PSTHs that exceed 1. See
        ``deepSTRF.models.activations`` for other parametric variants
        (``ParametricSigmoid``, ``ParametricDoubleExponential``).
    prefiltering : nn.Module, optional
        Optional spectrogram prefilter, forwarded unchanged to ``Linear``.
    kernel : nn.Module, optional
        Pluggable STRF kernel, forwarded unchanged to ``Linear``.
    wav2spec : nn.Module, optional
        Optional raw-waveform front-end, forwarded unchanged to
        ``Linear``. ``None`` (default) keeps spectrogram input.

    See Also
    --------
    Linear : Same architecture without the output nonlinearity.
    """

    def __init__(self, n_frequency_bands: int = 34, temporal_window_size: int = 9,
                 out_neurons: int = 1,
                 output_activation: nn.Module = None,
                 prefiltering: nn.Module = None,
                 kernel: nn.Module = None,
                 wav2spec: nn.Module = None):
        # Only the default activation differs from Linear; an explicit
        # module supplied by the caller always wins.
        if output_activation is None:
            output_activation = ParametricSoftplus(out_neurons)

        super().__init__(
            n_frequency_bands=n_frequency_bands,
            temporal_window_size=temporal_window_size,
            out_neurons=out_neurons,
            output_activation=output_activation,
            prefiltering=prefiltering,
            kernel=kernel,
            # Previously dropped, which made the parent's raw-waveform
            # mode unreachable from the LN model.
            wav2spec=wav2spec,
        )
[docs]
class NetworkReceptiveField(AudioEncodingModel):
    """
    Network Receptive Field (NRF) model — a two-layer feedforward STRF
    network.

    Architecture: a STRF kernel projects the input spectrogram into a
    hidden layer of ``H`` units; a 1×1 projection reads out the ``N``
    output neurons from the hidden activations. With L1 regularization
    the paper finds typically 1–7 effective hidden units per neuron.

    Parameters
    ----------
    n_frequency_bands : int, default 34
        Number of input frequency bands ``F``.
    temporal_window_size : int, default 9
        STRF temporal extent ``T``.
    n_hidden : int, default 20
        Hidden layer width ``H``.
    out_neurons : int, default 1
        Number of output neurons ``N``.
    output_activation : nn.Module, optional
        Pointwise nonlinearity at the readout output. ``None`` (default)
        selects ``ParametricSoftplus(out_neurons)``.
    prefiltering : nn.Module, optional
        Optional spectrogram prefilter (``AdapTrans``, ``ICAdaptation``,
        any module exposing ``out_channels``). ``None`` (default) gives
        ``nn.Identity`` and ``C_in = 1``.
    kernel : nn.Module, optional
        Pluggable hidden-layer STRF kernel. ``None`` (default) gives a
        vanilla ``nn.Conv2d``; pass ``ParametricSTRF(...)`` for DCLS.
    wav2spec : nn.Module, optional
        Optional raw-waveform front-end, forwarded to the base class.
        ``None`` (default) keeps spectrogram input.

    References
    ----------
    Harper, Schoppe, Willmore, Cui, Schnupp & King (2016).
    "Network Receptive Field Modeling Reveals Extensive Integration and
    Multi-feature Selectivity in Auditory Cortical Neurons."
    PLOS Comp. Biol. 12(11): e1005113.
    https://doi.org/10.1371/journal.pcbi.1005113

    Notes
    -----
    Differences from the original paper:

    - A causal LayerNorm over the hidden channel axis follows the STRF
      projection. The original assumes preprocessing-time input
      normalization and uses no internal norm.
    - The hidden activation is the standard ``nn.Tanh``. The paper uses
      a scaled tanh (ρ₁ ≈ 1.7159, ρ₂ = 2/3); the unscaled form is
      equivalent up to a learned rescaling absorbed into the readout.
    - Causal left-padding extends the model to arbitrary input lengths;
      the paper uses fixed-window slicing.
    - The hidden STRF kernel can be parameterized (DCLS); the paper
      uses a vanilla full kernel.
    """

    def __init__(self, n_frequency_bands: int = 34, temporal_window_size: int = 9,
                 n_hidden: int = 20,
                 out_neurons: int = 1,
                 output_activation: nn.Module = None,
                 prefiltering: nn.Module = None,
                 kernel: nn.Module = None,
                 wav2spec: nn.Module = None):
        super().__init__(
            n_frequency_bands=n_frequency_bands,
            temporal_window_size=temporal_window_size,
            out_neurons=out_neurons,
            prefiltering=prefiltering,
            wav2spec=wav2spec,
        )
        self.H = n_hidden

        # core: hidden STRF projection → channel norm → tanh.
        # The hidden STRF projection emits (B, H, 1, T); LinearReadout
        # downstream squeezes the singleton spatial axis automatically.
        self.core = nn.Sequential(
            layers.CausalSTRFConv(self.F, self.T, self.C_in, self.H, kernel=kernel),
            layers.CausalLayerNorm(self.H, dim=1),
            nn.Tanh(),
        )

        # readout: per-neuron 1×1 projection from H hidden units.
        if output_activation is None:
            output_activation = ParametricSoftplus(self.O)
        self.readout = LinearReadout(
            in_features=self.H, out_neurons=self.O,
            activation=output_activation,
        )

    # forward inherited from NeuralModel — wav2spec → prefiltering → core → readout

    def STRFs(self, hidden_idx: int = 0, polarity: str = 'ON'):
        """
        Return the hidden-layer STRF kernel for one hidden unit as ``(F, T)``.

        Parameters
        ----------
        hidden_idx : int, default 0
            Which of the ``H`` hidden units to return the STRF for.
        polarity : {'ON', 'OFF'}, default 'ON'
            Only relevant when the prefilter is ``AdapTrans``
            (``C_in == 2``): selects the ON or OFF channel of the kernel.
            ``'On'``/``'on'``/``0`` and ``'Off'``/``'off'``/``1`` are also
            accepted. Ignored otherwise (channel 0 is returned).

        Raises
        ------
        ValueError
            If the model is AdapTrans-prefiltered and ``polarity`` is not
            one of the accepted spellings.
        """
        # core[0] is the CausalSTRFConv whose .STRF_weight() returns (H, C_in, F, T)
        full = self.core[0].STRF_weight()
        if isinstance(self.prefiltering, AdapTrans):
            if polarity in ('ON', 'On', 'on', 0):
                return full[hidden_idx, 0]
            if polarity in ('OFF', 'Off', 'off', 1):
                return full[hidden_idx, 1]
            raise ValueError(
                f"polarity must be 'ON' or 'OFF' for an AdapTrans-prefiltered NRF — got {polarity!r}"
            )
        return full[hidden_idx, 0]
[docs]
class DNet(AudioEncodingModel):
    """
    Dynamic Network (DNet) — an NRF whose hidden units are stateful with
    a learnable temporal decay.

    Architecture: STRF projection → channel-norm → sigmoid → learnable
    exponential decay (one time constant per hidden unit) → 1×1 readout
    → output activation. The exponential decay is causal: each unit's
    output at time ``t`` is a convolution of its instantaneous input
    with a learned one-sided exponential kernel. In this implementation
    only the hidden layer carries a decay; the readout is a plain
    ``LinearReadout``.

    Parameters
    ----------
    n_frequency_bands : int, default 34
        Number of input frequency bands ``F``.
    temporal_window_size : int, default 9
        STRF temporal extent ``T``.
    n_hidden : int, default 20
        Hidden layer width ``H``.
    init_tau : float, default 2.0
        Initial time constant (in frames) for the hidden-unit
        exponential decay. Must be strictly positive. The decay kernel
        spans ``round(7 * init_tau)`` frames, with a minimum of 1.
    decay_input : bool, default True
        If True, the exponential decay also weights its instantaneous
        input by ``1/(1+d²)`` (paper convention); if False, the
        instantaneous input passes through unscaled.
    out_neurons : int, default 1
        Number of output neurons ``N``.
    output_activation : nn.Module, optional
        Pointwise nonlinearity at the readout output. Default
        ``nn.Identity`` (paper-faithful linear readout).
    prefiltering : nn.Module, optional
        Optional spectrogram prefilter.
    kernel : nn.Module, optional
        Pluggable hidden-layer STRF kernel.
    wav2spec : nn.Module, optional
        Optional raw-waveform front-end, forwarded to the base class.
        ``None`` (default) keeps spectrogram input.

    Raises
    ------
    ValueError
        If ``init_tau`` is not strictly positive.

    References
    ----------
    Rahman, Willmore, King & Harper (2019).
    "A dynamic network model of temporal receptive fields in primary
    auditory cortex." PLOS Comp. Biol. 15(5): e1006618.
    https://doi.org/10.1371/journal.pcbi.1006618

    Notes
    -----
    Differences from the original paper:

    - Causal LayerNorm replaces the missing internal normalization
      (paper assumes preprocessing-time input normalization).
    - Causal left-padding extends the model to arbitrary input lengths;
      the paper uses fixed-window slicing.
    - The hidden STRF kernel can be parameterized (DCLS); the paper
      uses a vanilla full kernel.
    """

    def __init__(self, n_frequency_bands: int = 34, temporal_window_size: int = 9,
                 n_hidden: int = 20, init_tau: float = 2.0, decay_input: bool = True,
                 out_neurons: int = 1,
                 output_activation: nn.Module = None,
                 prefiltering: nn.Module = None,
                 kernel: nn.Module = None,
                 wav2spec: nn.Module = None):
        # A non-positive time constant would produce a zero or negative
        # kernel length below; fail early with a clear message instead.
        if init_tau <= 0:
            raise ValueError(f"init_tau must be strictly positive, got {init_tau!r}")

        super().__init__(
            n_frequency_bands=n_frequency_bands,
            temporal_window_size=temporal_window_size,
            out_neurons=out_neurons,
            prefiltering=prefiltering,
            wav2spec=wav2spec,
        )
        self.H = n_hidden

        # The decay kernel covers 7 time constants. Small init_tau values
        # (below roughly 0.07) would round to 0, so floor the length at 1.
        decay_kernel = max(1, int(round(init_tau * 7)))

        # core: hidden STRF projection → channel norm → sigmoid
        #       → per-hidden-unit causal exponential decay
        self.core = nn.Sequential(
            layers.CausalSTRFConv(self.F, self.T, self.C_in, self.H, kernel=kernel),
            layers.CausalLayerNorm(self.H, dim=1),
            nn.Sigmoid(),
            layers.LearnableExponentialDecay(self.H, kernel_size=decay_kernel,
                                             init_tau=init_tau, decay_input=decay_input),
        )

        # readout: per-neuron 1×1 projection from H decayed hidden units;
        # the hidden-side decay already provides temporal smoothing.
        if output_activation is None:
            output_activation = nn.Identity()
        self.readout = LinearReadout(
            in_features=self.H, out_neurons=self.O,
            activation=output_activation,
        )

    # forward inherited from NeuralModel — wav2spec → prefiltering → core → readout

    def STRFs(self, hidden_idx: int = 0, polarity: str = 'ON'):
        """
        Return the hidden-layer STRF kernel for one hidden unit as ``(F, T)``.

        Parameters
        ----------
        hidden_idx : int, default 0
            Which of the ``H`` hidden units to inspect.
        polarity : {'ON', 'OFF'}, default 'ON'
            Only relevant for AdapTrans-prefiltered models (``C_in == 2``):
            selects the ON or OFF channel of the kernel.
            ``'On'``/``'on'``/``0`` and ``'Off'``/``'off'``/``1`` are also
            accepted. Ignored otherwise (channel 0 is returned).

        Raises
        ------
        ValueError
            If the model is AdapTrans-prefiltered and ``polarity`` is not
            one of the accepted spellings.
        """
        # core[0] is the CausalSTRFConv whose .STRF_weight() returns (H, C_in, F, T)
        full = self.core[0].STRF_weight()
        if isinstance(self.prefiltering, AdapTrans):
            if polarity in ('ON', 'On', 'on', 0):
                return full[hidden_idx, 0]
            if polarity in ('OFF', 'Off', 'off', 1):
                return full[hidden_idx, 1]
            raise ValueError(
                f"polarity must be 'ON' or 'OFF' for an AdapTrans-prefiltered DNet — got {polarity!r}"
            )
        return full[hidden_idx, 0]
[docs]
class ConvNet2D(AudioEncodingModel):
    """
    Convolutional STRF model with three sequential 2D convs and a 2-layer
    fully-connected readout — adapted from the '2D-CNN' of Pennington
    & David (2023).

    Architecture: three Conv2d → CausalLayerNorm → LeakyReLU blocks
    extract a stack of feature maps; the per-time-step features are
    flattened over the (channel × downsampled-frequency) axes and a
    2-layer FC reads out ``N`` output neurons.

    Parameters
    ----------
    n_frequency_bands : int, default 34
        Number of input frequency bands ``F``.
    kernel_size : tuple of int, default (3, 9)
        Conv2d kernel ``(K_F, K_T)`` shared across the three conv blocks.
    c_hidden : int, default 10
        Number of channels in each conv block.
    n_hidden : int, default 20
        Width of the FC hidden layer.
    out_neurons : int, default 1
        Number of output neurons ``N``.
    output_activation : nn.Module, default ``ParametricSoftplus(out_neurons)``
        Pointwise nonlinearity at the output. The default is unbounded
        above and non-negative — natural for spike-count regression.
    prefiltering : nn.Module, optional
        Optional spectrogram prefilter (``AdapTrans``, ``ICAdaptation``, or
        any module exposing ``out_channels``). ``None`` (default) gives
        ``nn.Identity`` and ``C_in = 1``.
    wav2spec : nn.Module, optional
        Optional raw-waveform front-end, forwarded to the base class.
        ``None`` (default) keeps spectrogram input.

    Raises
    ------
    ValueError
        If the three valid-padding convs would shrink the frequency axis
        to zero or fewer bins, i.e. ``F - 3*(K_F - 1) < 1``.

    References
    ----------
    Pennington & David (2023). "A convolutional neural network provides
    a generalizable model of natural sound coding by neural populations
    in auditory cortex." PLOS Comp. Biol. 19(5): e1011110.
    https://doi.org/10.1371/journal.pcbi.1011110

    Notes
    -----
    Differences from the original paper:

    - Causal LayerNorm replaces the missing internal normalization
      (paper uses none).
    - Hidden activation is ``LeakyReLU(0.1)`` rather than ReLU. Empirical
      preference, very small architectural difference.
    - 2D convs over ``(F, T)`` rather than 1D convs over ``T`` (the
      paper applies 1D convolutions with implicit spectral pooling).
    - Frequency downsampling is implicit via valid-padding shrinkage:
      three convs each shrink ``F`` by ``K_F - 1``, giving
      ``F_down = F - 3*(K_F - 1)``.
    - Causal left-padding extends the model to arbitrary input lengths;
      the paper also uses explicit causal padding.
    - The output activation is configurable; the paper uses a 4-parameter
      double-exponential — see
      ``deepSTRF.models.activations.ParametricDoubleExponential``.
    """

    def __init__(self, n_frequency_bands: int = 34, kernel_size: tuple = (3, 9),
                 c_hidden: int = 10, n_hidden: int = 20,
                 out_neurons: int = 1,
                 output_activation: nn.Module = None,
                 prefiltering: nn.Module = None,
                 wav2spec: nn.Module = None):
        n_conv_blocks = 3

        # Total number of past frames consumed by the stacked convs; this
        # is also the amount of causal left-padding applied below.
        # NOTE(review): the stacked receptive field spans
        # 3*(K_T - 1) + 1 frames including the current one, so T is one
        # short of it — confirm this is the intended convention.
        temporal_window_size = n_conv_blocks * (kernel_size[1] - 1)

        super().__init__(
            n_frequency_bands=n_frequency_bands,
            temporal_window_size=temporal_window_size,
            out_neurons=out_neurons,
            prefiltering=prefiltering,
            wav2spec=wav2spec,
        )
        self.K = kernel_size
        self.C = c_hidden
        self.H = n_hidden

        # Frequency dim left after the valid-padding convs. Reject
        # configurations that leave nothing for the readout to consume.
        F_down = self.F - n_conv_blocks * (self.K[0] - 1)
        if F_down < 1:
            raise ValueError(
                f"kernel_size[0]={self.K[0]} is too large for {self.F} frequency bands: "
                f"{n_conv_blocks} convs leave F_down={F_down} (must be >= 1)"
            )

        # core: causal left-pad → 3× (Conv2d → LN → LeakyReLU)
        #       → flatten (C, F_down) into a single feature axis.
        # Each conv shrinks time by K_T-1 (and frequency by K_F-1); the
        # explicit left-pad of 3*(K_T-1) zeros restores the time length.
        stages = [nn.ZeroPad2d((temporal_window_size, 0, 0, 0))]
        in_channels = self.C_in
        for _ in range(n_conv_blocks):
            stages.append(nn.Conv2d(in_channels, self.C, kernel_size=self.K, stride=1))
            stages.append(layers.CausalLayerNorm(self.C, dim=1))
            stages.append(nn.LeakyReLU(0.1))
            in_channels = self.C  # every block after the first is C → C
        # (B, C, F_down, T) → (B, C*F_down, T)
        stages.append(nn.Flatten(start_dim=1, end_dim=2))
        self.core = nn.Sequential(*stages)

        # readout: per-timestep MLP (in → hidden → N) with output activation.
        if output_activation is None:
            output_activation = ParametricSoftplus(self.O)
        self.readout = LinearReadout(
            in_features=self.C * F_down,
            out_neurons=self.O,
            hidden=self.H,
            activation=output_activation,
        )

    # forward inherited from NeuralModel
[docs]
class StateNet(AudioEncodingModel):
    """
    Fully stateful STRF model — relies entirely on temporal recurrence to
    extract information from stimulus sequences, with no explicit STRF
    delay window.

    Architecture: a stateless per-timestep spectral encoder maps each
    spectrogram column ``(C_in, F)`` to a hidden representation
    ``(C, F_down)``. The flattened hidden representation is fed
    timestep-by-timestep to a recurrent (or state-space) model that
    accumulates context implicitly through its hidden state. A linear
    readout projects the recurrent hidden state to ``N`` output neurons.

    Causality is inherent to the recurrent backbone (RNN/GRU/LSTM/LMU/
    Mamba/S4). The spectral encoder operates on a single timestep at a
    time so it does not couple frames temporally.

    Parameters
    ----------
    n_frequency_bands : int, default 34
        Number of input frequency bands ``F``.
    temporal_window_size : int, default 1
        Unused by StateNet (kept for ``AudioEncodingModel`` API
        compatibility); recurrence handles temporal context.
    kernel_size : int, default 7
        Frequency kernel size for the spectral encoder.
    stride : int, default 3
        Frequency stride for the spectral encoder.
    hidden_channels : int, default 7
        Channel count of the spectral encoder ``C``.
    connectivity : {'LC', 'FC', 'CONV'}, default 'LC'
        Spectral encoder connectivity. ``'LC'``: locally-connected 1D
        layer (frequency-position-specific weights). ``'FC'``: dense
        linear projection with reshape to ``(C, F_down)``. ``'CONV'``:
        weight-shared 1D convolution.
    rnn_type : {'GRU', 'LSTM', 'RNN', 'vanilla', 'LMU', 'Mamba', 'S4'}, default 'GRU'
        Recurrent / state-space backbone.
    out_neurons : int, default 1
        Number of output neurons ``N``.
    output_activation : nn.Module, default ``ParametricSoftplus(out_neurons)``
        Pointwise nonlinearity at the output. The default is unbounded
        above and non-negative — natural for spike-count regression.
    prefiltering : nn.Module, optional
        Optional spectrogram prefilter (``AdapTrans``, ``ICAdaptation``, or
        any module exposing ``out_channels``). ``None`` (default) gives
        ``nn.Identity`` and ``C_in = 1``.
    wav2spec : nn.Module, optional
        Optional raw-waveform front-end applied first in ``forward``.
        ``None`` (default) keeps spectrogram input.

    Raises
    ------
    ValueError
        If ``kernel_size`` / ``stride`` leave no frequency positions
        (``F_down < 1``).
    NotImplementedError
        If ``connectivity`` or ``rnn_type`` is not one of the supported
        options.

    References
    ----------
    Rançon, Masquelier & Cottereau (2025). "Temporal recurrence as a
    general mechanism to explain neural responses in the auditory
    system." Communications Biology 8:1456.
    https://doi.org/10.1038/s42003-025-08858-3

    Notes
    -----
    - The spectral encoder uses a CausalLayerNorm over the channel
      axis (``C``); the original implementation used BatchNorm1d
      which pools statistics over the (T*B, F_down) axis, making it
      non-causal.
    - The S4 backbone is imported lazily — its module emits CUDA-extension
      warnings on import that other backends would not see.
    """

    def __init__(self, n_frequency_bands: int = 34, temporal_window_size: int = 1,
                 kernel_size: int = 7, stride: int = 3, hidden_channels: int = 7,
                 connectivity: str = 'LC', rnn_type: str = 'GRU',
                 out_neurons: int = 1,
                 output_activation: nn.Module = None,
                 prefiltering: nn.Module = None,
                 wav2spec: nn.Module = None):
        super().__init__(
            n_frequency_bands=n_frequency_bands,
            temporal_window_size=temporal_window_size,
            out_neurons=out_neurons,
            prefiltering=prefiltering,
            wav2spec=wav2spec,
        )
        self.K = kernel_size
        self.S = stride
        self.C = hidden_channels
        self.rnn_type = rnn_type

        # F_down after the spectral encoder. Computed once, from the same
        # self.F the encoder layers are built with, so the 'FC' projection
        # width, the reshape in forward() and the RNN input size can
        # never disagree.
        self.L = (self.F - self.K) // self.S + 1
        if self.L < 1:
            raise ValueError(
                f"kernel_size={self.K} with stride={self.S} leaves F_down={self.L} "
                f"for {self.F} frequency bands (must be >= 1)"
            )
        self.H = self.L * self.C  # flattened encoder width = RNN feature size

        # stateless per-timestep spectral encoder: (B, C_in, F) -> (B, C, F_down).
        # Three connectivity options share the same input/output shape contract.
        if connectivity == 'LC':
            self.encoder_layers = nn.Sequential(
                layers.LocallyConnected1d(input_size=self.F, in_channels=self.C_in,
                                          out_channels=self.C, kernel_size=self.K,
                                          stride=self.S),
                layers.CausalLayerNorm(self.C, dim=1),
                nn.Sigmoid(),
            )
        elif connectivity == 'FC':
            self.encoder_layers = nn.Sequential(
                nn.Flatten(start_dim=-2, end_dim=-1),  # (B, C_in, F) -> (B, C_in*F)
                nn.Linear(self.C_in * self.F, self.C * self.L),
                nn.Unflatten(dim=-1, unflattened_size=(self.C, self.L)),  # (B, C, F_down)
                layers.CausalLayerNorm(self.C, dim=1),
                nn.Sigmoid(),
            )
        elif connectivity == 'CONV':
            self.encoder_layers = nn.Sequential(
                nn.Conv1d(self.C_in, self.C, kernel_size=self.K, stride=self.S),
                layers.CausalLayerNorm(self.C, dim=1),
                nn.Sigmoid(),
            )
        else:
            raise NotImplementedError(
                f"connectivity must be 'LC', 'FC', or 'CONV', got {connectivity!r}"
            )

        # recurrent / state-space backbone
        if self.rnn_type == 'GRU':
            self.rnn = nn.GRU(input_size=self.H, hidden_size=self.H,
                              num_layers=1, batch_first=True)
        elif self.rnn_type == 'LSTM':
            self.rnn = nn.LSTM(input_size=self.H, hidden_size=self.H,
                               num_layers=1, batch_first=True)
        elif self.rnn_type in ('vanilla', 'RNN'):
            self.rnn = nn.RNN(input_size=self.H, hidden_size=self.H,
                              num_layers=1, batch_first=True)
        elif self.rnn_type == 'LMU':
            self.rnn = LMU(input_size=self.H, hidden_size=self.H, memory_size=128,
                           theta=99, learn_a=False, learn_b=False)
        elif self.rnn_type == 'Mamba':
            self.rnn = MambaBlock(MambaConfig(d_model=self.H, n_layers=1))
        elif self.rnn_type == 'S4':
            # Lazy import: only users who pick S4 pay for its import-time
            # warnings and dependency graph.
            from deepSTRF.models.dependencies.s4 import S4Block
            self.rnn = S4Block(d_model=self.H, transposed=False)
        else:
            raise NotImplementedError(
                f"unknown rnn_type {rnn_type!r}: choose 'GRU', 'LSTM', 'RNN', 'vanilla', "
                f"'LMU', 'Mamba', or 'S4'"
            )

        # per-neuron readout from the recurrent hidden state.
        if output_activation is None:
            output_activation = ParametricSoftplus(self.O)
        self.readout = LinearReadout(
            in_features=self.H, out_neurons=self.O,
            activation=output_activation,
        )

    def forward(self, x):
        """
        Per-timestep spectral encoder feeding a recurrent backbone.

        Overrides the base template because the encoder runs in a flattened
        (T*B, C_in, F) batch (so every timestep is independent in the
        spectral pass) and the RNN expects a batch-first ``(B, T, H)``
        layout — these reshapes don't decompose into the canonical core /
        readout slots.

        Parameters
        ----------
        x : torch.Tensor
            Spectrogram ``(B, 1, F, T)``, or raw audio ``(B, 1, T_audio)``
            when a ``wav2spec`` front-end was supplied.

        Returns
        -------
        torch.Tensor
            The readout output, ``(B, N, 1, T)``.
        """
        # Raw audio -> (B, 1, F, T) in wav mode; Identity in spec mode.
        x = self.wav2spec(x)
        y = self.prefiltering(x)                  # (B, C_in, F, T)

        # Fold time into the batch axis so the encoder sees one
        # spectrogram column per sample.
        y = y.permute(3, 0, 1, 2)                 # (T, B, C_in, F)
        T_, B = y.shape[:2]
        y = y.flatten(0, 1)                       # (T*B, C_in, F)
        y = self.encoder_layers(y)                # (T*B, C, F_down)

        # reshape (not view): does not depend on the encoder output
        # happening to be contiguous.
        y = y.reshape(T_, B, self.C, self.L)      # (T, B, C, F_down)
        y = y.flatten(start_dim=2, end_dim=3).permute(1, 0, 2)  # (B, T, C*F_down=H)

        # RNN/SSM backbone — most modules return (output, hidden); Mamba is
        # the exception (returns a single tensor).
        if self.rnn_type == 'Mamba':
            y = self.rnn(y)
        else:
            y, _ = self.rnn(y)                    # (B, T, H)

        y = y.transpose(-2, -1)                   # (B, H, T) — readout convention
        return self.readout(y)                    # (B, N, 1, T)