deepSTRF.models package
Subpackages
Submodules
deepSTRF.models.layers module
- class deepSTRF.models.layers.CausalLayerNorm(normalized_shape, dim: int = 1, eps: float = 1e-05, elementwise_affine: bool = True)[source]
Bases: Module
LayerNorm applied to a non-trailing axis of its input; equivalently, LayerNorm computed independently at every position of every other axis.
Strictly causal: never pools statistics across time. Drop-in replacement for nn.BatchNorm{1,2,3}d in models that need to stay causal.
- Parameters:
normalized_shape (int) – Size of the axis being normalized.
dim (int, default 1) – Index of the axis to normalize. dim=1 (default) targets the channel axis of a (B, C, ...) tensor; use dim=-2 to normalize the frequency axis of a (B, C, F, T) audio spectrogram (the axis just before time).
eps (float, default 1e-5) – Numerical stability term forwarded to nn.LayerNorm.
elementwise_affine (bool, default True) – Whether to learn per-element scale and shift.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
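The causality property described for CausalLayerNorm can be illustrated without PyTorch. The following is a minimal sketch, assuming a (C, T) nested list and normalization over the channel axis; the helper name layer_norm_over_channels is hypothetical and this is not the library's implementation (no affine parameters, no batch axis).

```python
import math


def layer_norm_over_channels(x, eps=1e-5):
    """Normalize a (C, T) nested list over C, independently at each time step.

    Hypothetical helper for illustration only; it mirrors the arithmetic of
    LayerNorm (biased variance, eps inside the square root).
    """
    C, T = len(x), len(x[0])
    out = [[0.0] * T for _ in range(C)]
    for t in range(T):
        column = [x[c][t] for c in range(C)]
        mean = sum(column) / C
        var = sum((v - mean) ** 2 for v in column) / C
        for c in range(C):
            out[c][t] = (column[c] - mean) / math.sqrt(var + eps)
    return out


x = [[1.0, 2.0, 3.0], [3.0, 2.0, 9.0]]
y = layer_norm_over_channels(x)

# Change only the last (future) time step and check that earlier outputs are untouched.
x_future = [[1.0, 2.0, -50.0], [3.0, 2.0, 70.0]]
y_future = layer_norm_over_channels(x_future)
causal_ok = all(y[c][t] == y_future[c][t] for c in range(2) for t in range(2))
```

Because the statistics at time t use only the values at time t, no information flows across the time axis, which is the property a batch-norm layer trained on whole sequences does not have.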
- class deepSTRF.models.layers.CausalSTRFConv(F: int, T: int, C_in: int, C_out: int, kernel: Module = None, bias: bool = True)[source]
Bases: Module
A causal Spectro-Temporal Receptive Field convolution: T-1 zeros are prepended along the time axis, then a 2D STRF kernel of shape (C_out, C_in, F, T) is applied with valid padding. Output time length matches input time length.
The actual STRF kernel module is pluggable via the kernel kwarg. The default is a plain nn.Conv2d; passing ParametricSTRF or a separable-kernel nn.Sequential swaps in alternative parameterizations without changing the model that holds this layer.
- Parameters:
F (int) – Spectrogram frequency bins (the height of the kernel).
T (int) – Temporal extent of the STRF in frames (the width of the kernel).
C_in (int) – Input channel count (typically 1, or 2 with AdapTrans prefiltering).
C_out (int) – Output channel count. Hidden width when used inside a core, N when used inside an STRF readout.
kernel (nn.Module, optional) – Pre-built kernel module. Must produce (B, C_out, 1, T_in) from (B, C_in, F, T_in + T - 1). None (default) instantiates a vanilla nn.Conv2d(C_in, C_out, kernel_size=(F, T)).
bias (bool, default True) – Used only when kernel is None.
- STRF_weight()[source]
Return the effective STRF kernel as a (C_out, C_in, F, T) tensor, detached and on CPU.
Works across kernel types: vanilla nn.Conv2d, ParametricSTRF (DCLS), and the frequency-time separable nn.Sequential variant.
- forward(x)[source]
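The padding scheme described for CausalSTRFConv (prepend T-1 zeros in time, then apply the kernel with valid padding) can be sketched for a single input and output channel. This is a dependency-free illustration under that assumption; causal_strf_conv is a hypothetical name, not a library function.

```python
def causal_strf_conv(spec, kernel):
    """Apply an F x T kernel causally to an F x T_in spectrogram.

    spec and kernel are nested lists. Returns a list of T_in outputs: the
    frequency axis is collapsed and the time length is preserved.
    Illustrative single-channel sketch only.
    """
    F, T_in = len(spec), len(spec[0])
    T = len(kernel[0])
    # Prepend T-1 zeros along time so each window ends at the current frame.
    padded = [[0.0] * (T - 1) + list(row) for row in spec]
    out = []
    for t in range(T_in):
        acc = 0.0
        for f in range(F):
            for k in range(T):
                # Cross-correlation, as in a 2D convolution layer: the last
                # kernel column multiplies the current frame.
                acc += kernel[f][k] * padded[f][t + k]
        out.append(acc)
    return out


# An impulse at t=2 in band 0 must not influence outputs before t=2.
spec = [[0.0, 0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0]]
kernel = [[0.1, 0.2, 0.3],
          [0.0, 0.0, 0.0]]
response = causal_strf_conv(spec, kernel)
```

The response is zero before the impulse and then replays the kernel row from its last column to its first, which is what "causal" means for this layer.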
- class deepSTRF.models.layers.LearnableExponentialDecay(input_size: int, kernel_size: int, init_tau: float = 2.0, decay_input: bool = True)[source]
Bases: Module
Per-band learnable exponential-decay low-pass filter.
Convolves each frequency band of a (B, 1, F, T) spectrogram with a causal exponential kernel whose time constant is learned per band, and returns a low-pass version of the same shape. The decay parameterization follows Rahman et al. (DNet) and Fang et al. (PLIF).
- Parameters:
input_size (int) – Number of frequency bands F (one learnable time constant each).
kernel_size (int) – Temporal extent K of the decay kernel in frames.
init_tau (float, default 2.0) – Mean initial time constant (frames) for the decay parameters.
decay_input (bool, default True) – If True, scale the kernel so the filtered input keeps unit DC gain.
Notes
Currently single input / output channel only, and processes a 2-D (B, 1, F, T) input as a stack of 1-D temporal convolutions.
- build_kernel(device='cpu')[source]
Build the per-band decay kernel of shape (input_size, 1, kernel_size).
The kernel is convolved with the last (temporal) axis of the input.
- Parameters:
device (torch.device or str, default 'cpu') – Device on which to build the kernel.
- Returns:
Kernel of shape (input_size, 1, kernel_size).
- Return type:
torch.Tensor
- class deepSTRF.models.layers.LocallyConnected1d(input_size, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True)[source]
Bases: Module
Locally connected (LC) layer for 1-D tensors, a trade-off between Linear and Conv1d.
Like a convolution, but the kernel weights are not shared across positions: every output position has its own filter.
Notes
nn.Unfold only accepts images, so 1-D inputs are first unsqueezed to 2-D, the unfold / conv operations are performed, and the result is squeezed back to 1-D.
- forward(x)[source]
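The difference between a locally connected layer and a convolution is that each output position owns its filter. A minimal single-channel sketch, assuming stride 1, no padding and no dilation (locally_connected_1d is a hypothetical name, not the library class):

```python
def locally_connected_1d(x, weights, bias=None):
    """Locally connected 1-D layer on a plain list.

    x: input of length L.
    weights: one filter of length K per output position (L - K + 1 filters),
             so weights are NOT shared across positions.
    Illustrative sketch only.
    """
    K = len(weights[0])
    n_out = len(x) - K + 1
    assert len(weights) == n_out, "need exactly one filter per output position"
    out = []
    for i in range(n_out):
        y = sum(weights[i][k] * x[i + k] for k in range(K))
        if bias is not None:
            y += bias[i]
        out.append(y)
    return out


x = [1.0, 2.0, 3.0, 4.0]
weights = [[1.0, 0.0],   # position 0 copies its left input
           [0.0, 1.0],   # position 1 copies its right input
           [1.0, 1.0]]   # position 2 sums both inputs
y = locally_connected_1d(x, weights)
```

A Conv1d would use the same row of weights at every position; a Linear layer would connect every output to every input. The locally connected layer sits between the two.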
- class deepSTRF.models.layers.ParametricSTRF(F: int, T: int, C_in: int, C_out: int, num_gaussians: int = 1, bias: bool = True)[source]
Bases: Module
Spectro-Temporal Receptive Field kernel parameterized as a sum of learnable 2D Gaussians on the (F, T) grid.
A direct PyTorch reimplementation of the DCLS Gaussian-mixture parameterization (Khalfaoui-Hassani et al. 2023, ICLR), free from the upstream library’s silent asymmetric-kernel bug. Each of the num_gaussians Gaussians has:
- a 2D position (f, t) in the kernel grid coordinates [0, F-1] × [0, T-1],
- per-axis standard deviations (sigma_f, sigma_t),
- a per-(C_out, C_in) weight.
The effective (C_out, C_in, F, T) kernel is the weighted sum of Gaussians, normalized so each Gaussian has unit mass on the grid (DCLS convention).
- Parameters:
F (int) – Frequency bins of the kernel.
T (int) – Temporal extent of the kernel in frames.
C_in (int) – Input channel count.
C_out (int) – Output channel count.
num_gaussians (int, default 1) – Number of Gaussians per (C_out, C_in) slot.
bias (bool, default True) – Whether to add a per-output-channel bias (conv2d convention).
References
Khalfaoui-Hassani, Pellegrini & Masquelier (2023). “Dilated Convolution with Learnable Spacings.” ICLR.
Notes
The upstream DCLS library’s ConstructKernel2d silently mishandles asymmetric kernels: for dilated_kernel_size=(F, T) with F != T, its position-offset step +lim//2 adds the F-half-width to the T-axis position parameter and vice versa, concentrating Gaussians near the centre of one axis and outside the grid on the other. This is the cause of the observed “Gaussians don’t populate the entire STRF window” behavior on auditory STRF shapes like (34, 9). The deepSTRF reimplementation parametrizes positions in absolute grid coordinates [0, F-1] × [0, T-1], avoids the offset entirely, and removes the optional DCLS dependency.
- build_kernel(device=None)[source]
Build the effective (C_out, C_in, F, T) kernel from the num_gaussians Gaussians.
- forward(x)[source]
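The Gaussian-mixture construction described for ParametricSTRF can be sketched for one (C_out, C_in) slot. This is a minimal illustration, assuming axis-aligned Gaussians evaluated at integer grid points and normalized to unit mass on the grid; the function names and the example component values are hypothetical, and the library's actual parameter storage may differ.

```python
import math


def gaussian_bump(F, T, f0, t0, sigma_f, sigma_t):
    """Axis-aligned 2D Gaussian on the F x T grid, normalized to sum to 1."""
    g = [[math.exp(-0.5 * (((f - f0) / sigma_f) ** 2 + ((t - t0) / sigma_t) ** 2))
          for t in range(T)]
         for f in range(F)]
    mass = sum(sum(row) for row in g)
    return [[v / mass for v in row] for row in g]


def mixture_kernel(F, T, components):
    """Weighted sum of unit-mass Gaussians.

    components: list of (weight, f0, t0, sigma_f, sigma_t) tuples, with
    positions given in absolute grid coordinates [0, F-1] x [0, T-1].
    """
    kernel = [[0.0] * T for _ in range(F)]
    for weight, f0, t0, sigma_f, sigma_t in components:
        g = gaussian_bump(F, T, f0, t0, sigma_f, sigma_t)
        for f in range(F):
            for t in range(T):
                kernel[f][t] += weight * g[f][t]
    return kernel


# Asymmetric (34, 9) grid: one excitatory and one inhibitory Gaussian.
kernel = mixture_kernel(34, 9, [(1.0, 30.0, 7.5, 2.0, 1.0),
                                (-0.5, 5.0, 1.0, 3.0, 1.5)])
total_mass = sum(sum(row) for row in kernel)
```

Because each Gaussian carries unit mass, the total mass of the kernel equals the sum of the weights, and positions near either edge of an asymmetric grid are reachable on both axes.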
- class deepSTRF.models.layers.SeparableSTRF(F: int, T: int, C_in, C_out, bias: bool = True)[source]
Bases: Module
Frequency-time separable Spectro-Temporal Receptive Field (2D) kernel.
The effective (C_out, C_in, F, T) kernel is the rank-1 outer product w_F(f) · w_T(t) of two per-(C_out, C_in) factors. This drastically reduces the parameter count compared to a vanilla nn.Conv2d STRF (C_out·C_in·(F + T) vs C_out·C_in·F·T) while preserving the conv2d call signature, so it drops in as a kernel= argument on any model that uses STRFReadout.
- Parameters:
F (int) – Frequency bins of the kernel.
T (int) – Temporal extent of the kernel in frames.
C_in (int) – Input channel count.
C_out (int) – Output channel count.
bias (bool, default True) – Whether to add a per-output-channel bias (conv2d convention).
- forward(x)[source]
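The rank-1 structure and the parameter saving quoted for SeparableSTRF can be checked with a few lines of arithmetic. A minimal sketch for a single (C_out, C_in) slot, with hypothetical helper names and with bias terms left out of the counts:

```python
def separable_kernel(w_f, w_t):
    """Outer product of a frequency factor (length F) and a time factor (length T)."""
    return [[wf * wt for wt in w_t] for wf in w_f]


def param_counts(F, T, C_in, C_out):
    """Kernel parameter counts (bias excluded) for full vs separable STRFs."""
    full = C_out * C_in * F * T
    separable = C_out * C_in * (F + T)
    return full, separable


kernel = separable_kernel([1.0, 2.0, 3.0], [0.5, -1.0])
full, separable = param_counts(F=34, T=9, C_in=2, C_out=8)
```

For a (34, 9) STRF with 2 input and 8 output channels, the full kernel has 4896 weights against 688 for the separable one. Every row of a separable kernel is a scaled copy of the same temporal profile, which is the modelling restriction paid for that saving.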
- class deepSTRF.models.layers.SinusoidalPositionalEncoding(d_model: int)[source]
Bases: Module
Sinusoidal positional encoding (Vaswani et al. 2017, “Attention Is All You Need”). Computed on the fly per forward pass, so the same module generalizes to arbitrary sequence lengths.
- Parameters:
d_model (int) – Embedding dimension. Must be even (the encoding alternates sin and cos along the channel axis).
Notes
Adds the encoding to the input rather than returning it separately. Input shape (B, L, d_model), output shape (B, L, d_model).
Per-dimension frequencies follow the standard 1 / 10000^(2i / d_model) schedule and are stored as a non-trained buffer so they follow .to(device).
- forward(x: Tensor) Tensor[source]
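The frequency schedule above can be written out explicitly. The following is a sketch of the standard Vaswani et al. encoding table for a sequence of length L, assuming sin on even channels and cos on odd channels; sinusoidal_encoding is a hypothetical name, and the library module adds this table to its input rather than returning it.

```python
import math


def sinusoidal_encoding(L, d_model):
    """Return an L x d_model table of sinusoidal positional encodings."""
    assert d_model % 2 == 0, "d_model must be even"
    pe = [[0.0] * d_model for _ in range(L)]
    for pos in range(L):
        for i in range(d_model // 2):
            # Frequency 1 / 10000^(2i / d_model) for channel pair (2i, 2i+1).
            angle = pos / (10000 ** (2 * i / d_model))
            pe[pos][2 * i] = math.sin(angle)
            pe[pos][2 * i + 1] = math.cos(angle)
    return pe


pe = sinusoidal_encoding(L=4, d_model=6)
```

Since each row depends only on its position index, the table can be built for any L at call time, which is why no maximum sequence length needs to be fixed in advance.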
- deepSTRF.models.layers.build_causal_window_mask(L: int, window: int = None, device=None) Tensor[source]
Build a (L, L) attention mask for causal (optionally windowed) self-attention.
Position i attends to position j iff j <= i (causal) and, when window is set, i - j < window (otherwise the past is unlimited).
- Parameters:
L (int) – Sequence length.
window (int, optional) – Number of most recent positions, current position included, that may be attended to (i - j < window). None (default) allows attending to the entire past.
device (torch.device, optional) – Device for the returned mask.
- Returns:
A (L, L) bool tensor where True means “mask out / forbid attention”, matching the boolean-mask convention of nn.TransformerEncoderLayer. F.scaled_dot_product_attention uses the opposite boolean convention (True means the position takes part in attention), so invert the mask before passing it there.
- Return type:
torch.Tensor
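The masking rule can be written as a plain nested list to make the convention explicit. A minimal sketch, with True meaning "forbidden" as documented above; causal_window_mask is a hypothetical stand-in for the tensor-returning library function.

```python
def causal_window_mask(L, window=None):
    """Return an L x L nested list of bools; True means attention is forbidden.

    Position i may attend to position j iff j <= i and, when window is set,
    i - j < window.
    """
    return [
        [not (j <= i and (window is None or i - j < window)) for j in range(L)]
        for i in range(L)
    ]


full_past = causal_window_mask(4)
windowed = causal_window_mask(4, window=2)
```

With window=2, row 2 of the mask allows positions 1 and 2 only: position 0 is too far in the past and position 3 is in the future.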
deepSTRF.models.neural_model module
- class deepSTRF.models.neural_model.NeuralModel(out_neurons: int = 1, *args, **kwargs)[source]
Bases: Module, ABC
Base class for encoding models of sensory neural responses.
A four-slot template defines the canonical forward pipeline:
    forward(x):
        x = self.wav2spec(x)       # raw-waveform front-end (future)
        x = self.prefiltering(x)   # AdapTrans / ICAdaptation / Identity
        f = self.core(x)           # shared feature backbone
        return self.readout(f)     # per-neuron projection (B, N, 1, T)
Concrete subclasses populate the slots in their __init__. The defaults for wav2spec, prefiltering and core are nn.Identity, so a minimal model only needs to provide a readout. Subclasses may override forward() for architectures that don’t fit the four-slot pipeline (e.g. StateNet’s recurrent reshape, Transformer’s per-frame attention).
See docs/_source/md/model_paradigm.md for the full contract.
- Parameters:
out_neurons (int, default 1) – Number of output neurons N the model predicts. Stored on self.O and used by validate() and STRF_gradmap.
Notes
Subclasses inherit a Hugging Face Hub interface for pretrained checkpoints:
- save_pretrained(): write config + weights to a folder.
- push_to_hub(): upload to the HF Hub (auth required).
- from_pretrained(): instantiate and load weights.
The init kwargs needed to rebuild the architecture are auto-captured on construction (see __init_subclass__()), so end-users never have to supply a config dict; StateNet.from_pretrained("urancon/...") just works.
- count_trainable_params()[source]
Count the model’s trainable parameters.
- Returns:
Total number of parameters with requires_grad=True.
- Return type:
int
- detach()[source]
Detach stateful variables from the computational graph.
No-op by default; recurrent subclasses override this to truncate backpropagation-through-time between chunks (cf. spikingjelly).
- forward(stimulus)[source]
Run the default template pipeline.
Applies wav2spec → prefiltering → core → readout in sequence.
- Parameters:
stimulus (torch.Tensor) – Input stimulus batch (shape is modality-dependent; for audio models a spectrogram (B, F, T)).
- Returns:
Predicted response of shape (B, N, 1, T).
- Return type:
torch.Tensor
- classmethod from_pretrained(repo_id_or_path: str | Path, *, extra_kwargs: dict | None = None, strict: bool = True, map_location: str | device = 'cpu', cache_dir: str | Path | None = None, revision: str | None = None, token: str | None = None, return_metadata: bool = False)[source]
Instantiate cls from an HF Hub repo or a local checkpoint folder.
- Parameters:
repo_id_or_path (str or Path) – "<owner>/<name>" (HF Hub) or a path to a local folder produced by save_pretrained(). The decision is made by pathlib.Path.is_dir(): if the path exists locally it wins, otherwise the string is treated as a Hub repo id.
extra_kwargs (dict, optional) – Override / extend the saved config. Use this to re-supply any __init__ argument that wasn’t JSON-serialisable at save time (e.g. a custom prefiltering module).
strict (bool, default True) – state_dict strictness.
map_location (str or torch.device, default 'cpu')
cache_dir – Forwarded to huggingface_hub.snapshot_download().
revision – Forwarded to huggingface_hub.snapshot_download().
token – Forwarded to huggingface_hub.snapshot_download().
return_metadata (bool, default False) – If True, return (model, metadata) instead of just model.
- Returns:
Instance of cls with weights loaded.
- Return type:
model
- push_to_hub(repo_id: str, *, metadata: dict | None = None, model_card: str | None = None, private: bool = False, token: str | None = None, commit_message: str | None = None) str[source]
Push this model to repo_id on the HF Hub.
Saves the checkpoint to a temporary folder and uploads it via deepSTRF.utils.hub.upload_pretrained(). Creates the repo on the fly if it doesn’t exist; the user must be authenticated with write access (hf auth login or token=).
- Returns:
URL of the resulting commit.
- Return type:
str
- save_pretrained(save_dir: str | Path, *, metadata: dict | None = None, model_card: str | None = None) Path[source]
Write a checkpoint folder (config.json + model.safetensors).
See deepSTRF.utils.hub.save_pretrained_to_dir() for the full contract.
- Parameters:
save_dir (str or pathlib.Path) – Destination folder; created if it does not exist.
metadata (dict, optional) – Extra JSON-serialisable metadata stored alongside the config (e.g. training dataset, val/test scores).
model_card (str, optional) – Markdown content written to README.md in the folder.
- Returns:
Path to the written checkpoint folder.
- Return type:
pathlib.Path
- validate()[source]
Check that the instance is deepSTRF-compatible.
Subclasses should call super().validate() and then add their own checks (e.g. AudioEncodingModel checks F, T > 0).
- Raises:
AssertionError – If self.O is not a positive int, if readout is unset or not a torch.nn.Module, or if any of the wav2spec / prefiltering / core slots is not a torch.nn.Module.
deepSTRF.models.prefiltering module
- class deepSTRF.models.prefiltering.AdapTrans(init_a_vals, init_w_vals, kernel_size: int = 2, learnable: bool = True)[source]
Bases: Module
Adaptive ON/OFF spectrogram prefilter: the learnable extension of the inferior-colliculus adaptation prefilter.
Computes ON and OFF spectrograms through high-pass exponential filters with frequency-dependent, learnable time constants. Each frequency band is independently filtered along the temporal dimension with a parameterized exponential kernel:
kernel = […; -Cwa²; -Cwa; -Cw; +1] with C = 1/(… + a² + a + 1)
where the negative terms sum to -w. The filter computes the difference between the current value of the signal in each frequency band and an exponential average of its recent past, with separate (a, w) pairs giving rise to ON and OFF polarities.
- Parameters:
init_a_vals (1D Tensor of length F) – Per-frequency a parameters (related to the time constant of the kernel’s exponential: the higher a, the longer the time constant).
init_w_vals (1D Tensor of length F) – Per-frequency w parameters (weight given to the exponential average of the signal’s recent past, relative to the present sample).
kernel_size (int, default 2) – Length of the temporal kernel (in frames).
learnable (bool, default True) – If True, a and w are learnable nn.Parameters; if False, they are frozen buffers (still follow .to(device)).
References
Rançon, Masquelier & Cottereau (2024). “A general model unifying the adaptive, transient and sustained properties of ON and OFF auditory neural responses.” PLOS Computational Biology 20(8):e1012288. https://doi.org/10.1371/journal.pcbi.1012288
Notes
Input shape: (B, 1, F, T). Output shape: (B, 2, F, T); channel 0 is ON, channel 1 is OFF.
- OFF_kernel(d, p)[source]
Creates the OFF kernel — the ON kernel flipped about zero, then renormalized so its tail equals -w.
- build_kernels()[source]
Creates two parametrized kernels:
- one for the ON response, highlighting onsets in the signal
- one for the OFF response (offsets), which is the flipped version of the ON kernel
- forward(spectro_in)[source]
Compute the ON and OFF high-pass-filtered spectrograms.
- Parameters:
spectro_in (torch.Tensor) – Input spectrogram of shape (B, 1, F, T) (B=batch, 1 channel, F frequency bands, T timesteps).
- Returns:
Tensor of shape (B, 2, F, T): channel 0 is the ON response, channel 1 is the OFF response (both half-wave rectified).
- Return type:
torch.Tensor
- out_channels: int = 2
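The ON kernel formula quoted above, kernel = […; -Cwa²; -Cwa; -Cw; +1] with C = 1/(… + a² + a + 1), can be built explicitly for one frequency band. A minimal sketch, assuming the kernel is stored oldest-tap-first with the current sample last; on_kernel is a hypothetical helper, and the OFF kernel and rectification are not covered.

```python
def on_kernel(a, w, kernel_size):
    """Build the ON kernel for one band as a list, oldest tap first.

    Past taps are -C*w*a^k (k = 0 for the most recent past sample), with
    C = 1 / (1 + a + a^2 + ...) so that the negative taps sum to -w.
    The current sample has weight +1.
    """
    n_past = kernel_size - 1
    powers = [a ** k for k in range(n_past)]   # 1, a, a^2, ... (most recent first)
    C = 1.0 / sum(powers)
    past = [-C * w * p for p in powers]
    return past[::-1] + [1.0]                  # reverse so the oldest tap comes first


k4 = on_kernel(a=0.5, w=0.75, kernel_size=4)
k2 = on_kernel(a=0.5, w=0.75, kernel_size=2)   # default kernel_size: [-w, +1]

# Response to a constant input of 1.0: current sample minus w times its average past.
dc_response = sum(k4)
```

For a constant input the filter returns 1 - w, so w = 1 removes the steady-state level entirely while smaller values of w let part of the sustained response through.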
- class deepSTRF.models.prefiltering.ICAdaptation(init_a_vals, kernel_size: int = 2)[source]
Bases: Module
High-pass exponential filter with frequency-dependent time constants: a paper-faithful re-implementation of the inferior-colliculus adaptation prefilter described by Willmore et al. (2016).
Independently filters each frequency band of an input spectrogram along the temporal dimension with a parameterized exponential kernel:
kernel = […; -Cwa²; -Cwa; -Cw; +1] with C = 1/(… + a² + a + 1)
where the negative terms sum to -w. The filter effectively computes the difference between the current value of the signal in each frequency band and an exponential average of its recent past, then applies a half-wave rectification.
- Parameters:
init_a_vals (1D Tensor of length F) – Per-frequency a parameters (related to the exponential time constant: higher a → longer time constant).
kernel_size (int, default 2) – Length of the temporal kernel (in frames).
References
Willmore, Schoppe, King, Schnupp, Harper (2016). “Incorporating Midbrain Adaptation to Mean Sound Level Improves Models of Auditory Cortical Processing.” J. Neurosci. 36(2): 280–289. https://doi.org/10.1523/JNEUROSCI.2441-15.2016
Notes
Intentionally non-learnable: the time constants are derived analytically from the cochlear frequency map (see freq_to_tau) and are paper-faithful. For a learnable extension, use AdapTrans.
Input shape: (B, 1, F, T). Output shape: (B, 1, F, T).
- build_kernels()[source]
Creates a parametrized kernel: a high-pass exponential filter that highlights onsets in the signal.
- forward(spectro_in)[source]
High-pass filter and half-wave rectify each frequency band.
- Parameters:
spectro_in (torch.Tensor) – Input spectrogram of shape (B, 1, F, T) (B=batch, 1 channel, F frequency bands, T timesteps).
- Returns:
High-pass-filtered, half-wave-rectified spectrogram of shape (B, 1, F, T).
- Return type:
torch.Tensor
- out_channels: int = 1
- deepSTRF.models.prefiltering.Willmore_Adaptation
alias of
ICAdaptation
- deepSTRF.models.prefiltering.a_to_tau(a, dt: float = 1)[source]
Inverse of tau_to_a(): converts a parameters back to time constants (ms).
- Parameters:
a (torch.Tensor) – Dimensionless a parameters.
dt (float, default 1) – Time-step width in ms.
- Returns:
Time constants in ms.
- Return type:
torch.Tensor
- deepSTRF.models.prefiltering.freq_to_tau(freqs)[source]
Map frequencies (Hz) to midbrain-neuron time constants (ms).
- Parameters:
freqs (torch.Tensor) – Frequencies in Hz.
- Returns:
Associated time constants in ms.
- Return type:
torch.Tensor
References
Willmore et al. (2016). “Incorporating Midbrain Adaptation to Mean Sound Level Improves Models of Auditory Cortical Processing.”
- deepSTRF.models.prefiltering.get_CFs(min_freq, max_freq, n_freqs, scale)[source]
Return n_freqs cochlear/center frequencies (CFs) on a warped scale.
- Parameters:
min_freq (float) – Lower bound of the frequency range in Hz.
max_freq (float) – Upper bound of the frequency range in Hz.
n_freqs (int) – Number of CFs to return.
scale ({'mel', 'greenwood'}) – Frequency-axis warping used to space the CFs.
- Returns:
The n_freqs center frequencies in Hz.
- Return type:
torch.Tensor
- Raises:
NotImplementedError – If scale is not 'mel' or 'greenwood'.
- deepSTRF.models.prefiltering.make_prefiltering(kind: str, n_frequency_bands: int, dt: float, min_freq: float = 500.0, max_freq: float = 20000.0, scale: str = 'mel', learnable: bool = True, init_w: float = 0.75) Module[source]
Factory for constructing a prefilter module from compact arguments.
Convenience wrapper that derives per-frequency a (and w) initial values from the cochlear frequency map, then instantiates the requested prefilter class. Equivalent to building the prefilter by hand; the factory exists so that user code does not need to repeat the get_CFs / freq_to_tau / tau_to_a pipeline.
- Parameters:
kind ({'adaptrans', 'icadaptation', 'willmore'}) – Which prefilter to build. 'willmore' is an alias for 'icadaptation'.
n_frequency_bands (int) – Number of input frequency bands F of the spectrogram.
dt (float) – Time bin width in milliseconds (matches dataset.dt_ms).
min_freq (float, default 500.0) – Lower bound (in Hz) of the frequency range spanned by the cochlear filterbank that produced the spectrogram.
max_freq (float, default 20000.0) – Upper bound (in Hz) of the frequency range spanned by the cochlear filterbank that produced the spectrogram.
scale ({'mel', 'greenwood'}, default 'mel') – Frequency-axis scaling used to derive per-band time constants.
learnable (bool, default True) – Only relevant for 'adaptrans'. ICAdaptation is always frozen (paper-faithful).
init_w (float, default 0.75) – Only relevant for 'adaptrans': initial value of the past-vs-present weight w.
- Returns:
Configured prefilter instance with an out_channels attribute.
- Return type:
nn.Module
- deepSTRF.models.prefiltering.tau_to_a(time_constants, dt: float = 1)[source]
Convert physical time constants (ms) to dimensionless a parameters.
- Parameters:
time_constants (torch.Tensor) – Time constants in ms.
dt (float, default 1) – Time-step width in ms.
- Returns:
The corresponding a = exp(-dt / tau) parameters.
- Return type:
torch.Tensor
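The relation a = exp(-dt / tau) and its inverse tau = -dt / ln(a) can be checked with scalars. The helpers below are hypothetical scalar stand-ins for the tensor-based tau_to_a() and a_to_tau(), written with the standard library only.

```python
import math


def tau_to_a_scalar(tau_ms, dt=1.0):
    """Time constant (ms) -> dimensionless decay factor a = exp(-dt / tau)."""
    return math.exp(-dt / tau_ms)


def a_to_tau_scalar(a, dt=1.0):
    """Inverse mapping: a -> time constant tau = -dt / ln(a), in ms."""
    return -dt / math.log(a)


# A 20 ms time constant sampled in 5 ms bins.
a = tau_to_a_scalar(20.0, dt=5.0)        # exp(-0.25)
tau_back = a_to_tau_scalar(a, dt=5.0)    # recovers the original time constant
```

Longer time constants give values of a closer to 1 (slower decay per bin), and the same tau maps to a smaller a when the bin width dt grows.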
deepSTRF.models.readouts module
Readout modules: per-neuron projections that take a feature representation emitted by a model’s core and produce predictions of shape (B, N, R=1, T).
Two flavours are shipped:
- STRFReadout: a learnable STRF kernel applied causally via CausalSTRFConv. Used when the readout itself is the model’s main learnable apparatus (Linear / LinearNonlinear) or when an explicit STRF interpretation of the per-neuron weights is wanted.
- LinearReadout: a per-timestep linear projection in_features -> N, with an optional 1-hidden-layer MLP. Used by models whose core already produces a flat per-timestep feature vector (ConvNet2D / Transformer / StateNet).
See docs/_source/md/model_paradigm.md §7 for the full readout contract.
- class deepSTRF.models.readouts.LinearReadout(in_features: int, out_neurons: int, hidden: int = None, activation: Module = None, bias: bool = True)[source]
Bases: Module
Per-neuron readout: a per-timestep linear projection from a flat feature vector to N output neurons.
Accepts either a 3D input (B, in_features, T) or a 4D input (B, in_features, 1, T) (the singleton spatial axis emitted by an STRF-style conv that collapsed F → 1); both shapes route through the same projection. The output is the canonical (B, N, 1, T).
- Parameters:
in_features (int) – Per-timestep feature dimension produced by the model’s core.
out_neurons (int) – Number of output neurons N.
hidden (int, optional) – If given, inserts a 1-hidden-layer MLP in_features → hidden → N with a LeakyReLU(0.1) between. None (default) gives a single linear projection.
activation (nn.Module, optional) – Pointwise output nonlinearity applied to the per-timestep output before the rank is unsqueezed to (B, N, 1, T). Defaults to nn.Identity.
bias (bool, default True) – Whether the linear projection(s) include a bias term.
- forward(x)[source]
- class deepSTRF.models.readouts.STRFReadout(F: int, T: int, C_in: int, out_neurons: int, kernel: Module = None, activation: Module = None, bias: bool = True)[source]
Bases: Module
Per-neuron readout backed by a causal Spectro-Temporal Receptive Field kernel, with a per-neuron BatchNorm placed after the conv.
Wraps a CausalSTRFConv of shape (N, C_in, F, T) and applies an output activation. The frequency axis is collapsed by the conv from F → 1, so the readout naturally emits the canonical (B, N, R=1, T) rank.
The per-neuron nn.BatchNorm1d(N) between the conv and the activation stabilises training and serves as the model’s only normalisation layer in the Linear / LinearNonlinear cases. Every learnable scalar in this readout (STRF kernel, conv bias, BN affine, BN running stats) has the neuron axis as leading dim, so the readout shares no parameters across neurons. Causality in eval mode is preserved: BN’s running statistics are per-channel scalars, applied element-wise on the time axis at inference.
- Parameters:
F (int) – Frequency bins of the input spectrogram.
T (int) – STRF temporal extent.
C_in (int) – Input channel count after prefiltering.
out_neurons (int) – Number of output neurons N.
kernel (nn.Module, optional) – Pluggable kernel module. None (default) instantiates a vanilla nn.Conv2d; pass ParametricSTRF (DCLS) or a separable nn.Sequential to swap parameterizations.
activation (nn.Module, optional) – Pointwise output nonlinearity. Defaults to nn.Identity.
bias (bool, default True) – Used only when kernel is None.
- STRF_weight(polarity: str = None)[source]
Return the underlying STRF kernel as (N, C_in, F, T).
- Parameters:
polarity ({'ON', 'OFF', None}, optional) – If the model’s prefilter produces C_in == 2 channels (e.g. AdapTrans’s ON/OFF), select one. None (default) returns the full (N, C_in, F, T) tensor; 'ON' slices channel 0; 'OFF' slices channel 1.
- forward(x)[source]
deepSTRF.models.activations module
Output activations for deepSTRF readouts.
Three parametric activations, each with opt-out non-negativity reparameterisation that pairs naturally with poisson_loss(log_input=False) (see metrics_paradigm.md §6.2):
- ParametricSigmoid: saturating, bounded above (Willmore 2016).
- ParametricDoubleExponential: saturating, bounded above (Thorson 2015).
- ParametricSoftplus: non-saturating, unbounded above. Natural default for spike-count regression where the response can take any non-negative magnitude.
All three follow the same input-shape contract: parameters are (N,)-shaped per-neuron tensors that broadcast against the last axis of the input. Use these inside readouts that emit tensors with N as the last axis, i.e. shape (..., N) (e.g. LinearReadout).
- class deepSTRF.models.activations.ParametricDoubleExponential(num_features: int, bias: bool = True, non_negative_output: bool = True)[source]
Bases: Module
4-parameter parametric double-exponential (Thorson et al. 2015).
Per-neuron output:
\[f(x) = a \cdot \exp(-\exp(k \cdot x - s)) + b \quad\text{(bias=True)}\]
where a is the saturated firing rate, b the baseline, s the firing threshold, and k the gain.
- Parameters:
num_features (int) – N: number of independent per-neuron parameter sets.
bias (bool, default True) – Whether to include the additive baseline b.
non_negative_output (bool, default True) – When True, a (and b, if bias=True) are stored as raw parameters and softplus-mapped to the strictly-positive half-line at every forward pass. The inner exp(-exp(k·x − s)) term is always in (0, 1], so this fully guarantees f(x) ≥ 0. When False, parameters are direct (signed-output mode).
References
Thorson, I. L., Liénard, J. & David, S. V. (2015). “The essential complexity of auditory receptive fields.” PLOS Computational Biology, 11(3), e1004228.
- property a: Tensor
- property b: Tensor | None
- extra_repr() str[source]
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x: Tensor) Tensor[source]
- class deepSTRF.models.activations.ParametricSigmoid(num_features: int, bias: bool = True, non_negative_output: bool = True)[source]
Bases:
Module4-parameter parametric sigmoid (Willmore et al. 2016).
Per-neuron output:
\[f(x) = b \cdot \sigma((x - c) / d) + a \quad\text{(bias=True)}\]where
bis the dynamic range,athe baseline (minimum firing rate),cthe input inflection point, anddthe reciprocal of the gain.- Parameters:
num_features (
int) –N: number of independent per-neuron parameter sets.bias (
bool, defaultTrue) – Whether to include the additive baselinea.non_negative_output (
bool, defaultTrue) – When True,b(anda, ifbias=True) are stored as raw parameters and softplus-mapped to the strictly-positive half-line at every forward pass. This guarantees a non-negative output curve, suitable for spike-count targets andpoisson_loss. When False, parameters are direct (signed-output mode) — useful for LFP / EEG / centred PSTH targets where outputs may legitimately be negative.
Notes
The shipped behaviour replaces an earlier closure-based implementation that built forward inside __init__. The current version uses a standard forward() method and exposes b and a via @property so that softplus is re-applied on the live parameter values at every step (this ensures that state_dict round-trips and parameter replacement work correctly).
References
Willmore, B. D. B., Schoppe, O., King, A. J., Schnupp, J. W. H. & Harper, N. S. (2016). “Incorporating midbrain adaptation to mean sound level improves models of auditory cortical processing.” Journal of Neuroscience, 36(2), 280–289.
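The parametric sigmoid and its softplus reparameterisation can be sketched for a single scalar input in plain Python. This is only an illustration under stated assumptions: the function names and the raw_a / raw_b arguments are hypothetical, and the library class operates on tensors with one parameter set per neuron rather than on floats.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable softplus, log(1 + exp(z)); always > 0."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def sigmoid(z: float) -> float:
    """Logistic function, written to avoid overflow for large |z|."""
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def parametric_sigmoid(x, raw_a, raw_b, c, d, bias=True, non_negative_output=True):
    """Scalar sketch of f(x) = b * sigmoid((x - c) / d) + a (hypothetical helper)."""
    # In non-negative mode the raw values are mapped through softplus on every
    # call, so the effective a and b always reflect the current raw parameters.
    b = softplus(raw_b) if non_negative_output else raw_b
    a = 0.0
    if bias:
        a = softplus(raw_a) if non_negative_output else raw_a
    return b * sigmoid((x - c) / d) + a


# At the inflection point x = c the sigmoid equals 0.5, so f(c) = b / 2 + a.
print(parametric_sigmoid(0.0, raw_a=0.0, raw_b=0.0, c=0.0, d=1.0))
```

In non-negative mode the output stays above zero however negative the raw parameters are, while signed mode passes a and b through unchanged.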
- property a: Tensor | None
Baseline parameter, post-reparameterisation. None if bias=False.
- property b: Tensor
Dynamic-range parameter, post-reparameterisation.
- extra_repr() str[source]
- forward(x: Tensor) Tensor[source]
- class deepSTRF.models.activations.ParametricSoftplus(num_features: int, non_negative_output: bool = True)[source]
Bases: Module
Per-neuron Softplus with learnable sharpness β and additive baseline.
Output:
\[f(x) = \frac{1}{\beta} \log\!\bigl(1 + \exp(\beta x)\bigr) + b\]
where β > 0 is the per-neuron sharpness (β → ∞ approaches ReLU, β → 0 approaches a soft linear with slope ½) and b is the per-neuron additive baseline. Both are always learnable per-neuron.
Unlike ParametricSigmoid and ParametricDoubleExponential, the underlying curve is unbounded above, which is natural for spike-count regression on smoothed PSTHs that can take any non-negative magnitude (e.g. NS1: peaks ~3 spikes/bin after Hsu/Borst/Theunissen 21 ms Hanning smoothing).
- Parameters:
num_features (int) – N: number of independent per-neuron parameter sets.
non_negative_output (bool, default True) – When True, b is stored as a raw parameter and softplus-mapped to the strictly positive half-line at every forward pass. Combined with the always-non-negative softplus core, this guarantees f(x) ≥ 0 for every x. When False, b is unconstrained; pair with poisson_loss(log_input=True) for signed-output targets (LFP / EEG / centred PSTH).
Notes
β is always reparameterised to be strictly positive (β = softplus(_raw_beta)) regardless of non_negative_output, because a non-positive sharpness would flip the curve and is never physically meaningful.
The implementation expects an input whose last axis is N (matching the ParametricSigmoid / ParametricDoubleExponential contract); that is what LinearReadout and STRFReadout emit at the activation step.
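A scalar sketch of the parametric softplus may help make the role of β concrete. It assumes plain Python floats and hypothetical argument names (raw_beta, raw_b); the library class itself works on tensors whose last axis is N.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable softplus, log(1 + exp(z)); always > 0."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def parametric_softplus(x, raw_beta, raw_b, non_negative_output=True):
    """Scalar sketch of f(x) = softplus(beta * x) / beta + b (hypothetical helper)."""
    beta = softplus(raw_beta)  # sharpness is kept strictly positive in both modes
    b = softplus(raw_b) if non_negative_output else raw_b
    return softplus(beta * x) / beta + b


# A large sharpness makes the curve behave like ReLU (plus the baseline b).
for x in (-2.0, 0.0, 2.0):
    print(x, parametric_softplus(x, raw_beta=50.0, raw_b=0.0, non_negative_output=False))
```

With raw_beta = 50 the curve is numerically indistinguishable from max(x, 0) away from the origin, which matches the β → ∞ limit described above.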
- property b: Tensor
Per-neuron additive baseline, post-reparameterisation.
- property beta: Tensor
Per-neuron sharpness, always > 0.
- extra_repr() str[source]
- forward(x: Tensor) Tensor[source]
deepSTRF.models.scales module
- deepSTRF.models.scales.ERB_bandwidth(f)[source]
Equivalent rectangular bandwidth (Hz) at centre frequency f (Hz), Glasberg & Moore (1990): ERB(f) = 24.7 · (0.00437 f + 1).
- deepSTRF.models.scales.ERB_to_Hz(erb)[source]
Inverse of Hz_to_ERB().
- deepSTRF.models.scales.Greenwood(x, animal='human')[source]
Greenwood cochlear position-to-frequency map.
- Parameters:
x (torch.Tensor or float) – Normalised position along the cochlea in [0, 1].
animal ({'human', 'mouse'}, default 'human') – Species whose Greenwood coefficients to use.
- Returns:
Characteristic frequency in Hz.
- Return type:
torch.Tensor or float
- Raises:
NotImplementedError – If animal is not 'human' or 'mouse'.
- deepSTRF.models.scales.Hz_to_ERB(f)[source]
ERB-rate (ERB-number) scale of Glasberg & Moore (1990).
E(f) = 21.4 · log10(0.00437 f + 1) for f in Hz. Equal steps on this scale are equal numbers of equivalent rectangular bandwidths apart; this is the standard cochlear frequency axis for gammatone filterbanks.
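The two Glasberg & Moore formulas quoted in this module, and the inverse obtained by solving E(f) for f, can be checked with a few lines of plain Python. The lower-case function names below are hypothetical stand-ins used for illustration; the library functions also accept tensors.

```python
import math


def erb_bandwidth(f_hz: float) -> float:
    """ERB(f) = 24.7 * (0.00437 f + 1), bandwidth in Hz."""
    return 24.7 * (0.00437 * f_hz + 1.0)


def hz_to_erb(f_hz: float) -> float:
    """E(f) = 21.4 * log10(0.00437 f + 1), ERB-number scale."""
    return 21.4 * math.log10(0.00437 * f_hz + 1.0)


def erb_to_hz(erb: float) -> float:
    """Inverse of hz_to_erb: f = (10 ** (E / 21.4) - 1) / 0.00437."""
    return (10.0 ** (erb / 21.4) - 1.0) / 0.00437


f = 1000.0
print(erb_bandwidth(f))         # bandwidth in Hz at 1 kHz
print(hz_to_erb(f))             # position of 1 kHz on the ERB-number scale
print(erb_to_hz(hz_to_erb(f)))  # round trip back to Hz
```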
- deepSTRF.models.scales.Hz_to_mel(f)[source]
Convert a frequency (Hz) to the mel scale.
- Parameters:
f (torch.Tensor or float) – Frequency in Hz.
- Returns:
Frequency on the mel scale.
- Return type:
torch.Tensor or float
- deepSTRF.models.scales.inverse_Greenwood(f, animal='human')[source]
Inverse Greenwood map: frequency (Hz) to normalised cochlear position.
- Parameters:
f (torch.Tensor or float) – Characteristic frequency in Hz.
animal ({'human', 'mouse'}, default 'human') – Species whose Greenwood coefficients to use.
- Returns:
Normalised position along the cochlea in [0, 1].
- Return type:
torch.Tensor or float
- Raises:
NotImplementedError – If animal is not 'human' or 'mouse'.
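For orientation, the Greenwood map and its inverse can be sketched with the commonly cited human coefficients from Greenwood (1990): F = A · (10^(a·x) − k) with A = 165.4, a = 2.1, k = 0.88. That the library uses exactly these values, and that x = 0 corresponds to the low-frequency end, are assumptions here; the mouse coefficients are not reproduced because they are not stated in this reference.

```python
import math

# Commonly cited human coefficients (assumed, not taken from the library source).
A, ALPHA, K = 165.4, 2.1, 0.88


def greenwood_human(x: float) -> float:
    """Normalised cochlear position in [0, 1] -> characteristic frequency (Hz)."""
    return A * (10.0 ** (ALPHA * x) - K)


def inverse_greenwood_human(f_hz: float) -> float:
    """Characteristic frequency (Hz) -> normalised cochlear position."""
    return math.log10(f_hz / A + K) / ALPHA


print(greenwood_human(0.0))  # low-frequency end, about 20 Hz
print(greenwood_human(1.0))  # high-frequency end, about 20.7 kHz
print(inverse_greenwood_human(greenwood_human(0.5)))  # round trip
```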
- deepSTRF.models.scales.mel_to_Hz(mel_freqs)[source]
Convert mel-scale frequencies back to Hz (inverse of Hz_to_mel()).
- Parameters:
mel_freqs (torch.Tensor or float) – Frequency on the mel scale.
- Returns:
Frequency in Hz.
- Return type:
torch.Tensor or float
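This reference does not state which mel formula the module uses. As an illustration only, the sketch below assumes the widely used HTK-style definition, mel = 2595 · log10(1 + f / 700); other variants (for example the Slaney scale) give different values, so the library's output may differ. The function names are hypothetical.

```python
import math


def hz_to_mel_htk(f_hz: float) -> float:
    """HTK-style mel scale (assumed variant): 2595 * log10(1 + f / 700)."""
    return 2595.0 * math.log10(1.0 + f_hz / 700.0)


def mel_to_hz_htk(mel: float) -> float:
    """Inverse of hz_to_mel_htk: 700 * (10 ** (mel / 2595) - 1)."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


print(hz_to_mel_htk(700.0))                   # equals 2595 * log10(2)
print(mel_to_hz_htk(hz_to_mel_htk(4000.0)))   # round trip back to Hz
```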
Module contents
- class deepSTRF.models.NeuralModel(out_neurons: int = 1, *args, **kwargs)[source]
Bases: Module, ABC
Base class for encoding models of sensory neural responses.
A four-slot template defines the canonical forward pipeline:
forward(x):
    x = self.wav2spec(x)       # raw-waveform front-end (future)
    x = self.prefiltering(x)   # AdapTrans / ICAdaptation / Identity
    f = self.core(x)           # shared feature backbone
    return self.readout(f)     # per-neuron projection (B, N, 1, T)
Concrete subclasses populate the slots in their __init__. The defaults for wav2spec, prefiltering and core are nn.Identity, so a minimal model only needs to provide a readout. Subclasses may override forward() for architectures that don’t fit the four-slot pipeline (e.g. StateNet’s recurrent reshape, Transformer’s per-frame attention).
See docs/_source/md/model_paradigm.md for the full contract.
- Parameters:
out_neurons (int, default 1) – Number of output neurons N the model predicts. Stored on self.O and used by validate() and STRF_gradmap.
Notes
Subclasses inherit a Hugging Face Hub interface for pretrained checkpoints:
- save_pretrained(): write config + weights to a folder.
- push_to_hub(): upload to the HF Hub (auth required).
- from_pretrained(): instantiate and load weights.
The init kwargs needed to rebuild the architecture are auto-captured on construction (see __init_subclass__()), so end-users never have to supply a config dict: StateNet.from_pretrained("urancon/...") just works.
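One way constructor arguments can be captured automatically through __init_subclass__ is sketched below. This is not the library's implementation: the base class, the _init_config attribute, and the Toy subclass are all hypothetical, and the real class must additionally cope with nn.Module specifics and with arguments that are not JSON-serialisable.

```python
import functools
import inspect


class ConfigCapturingBase:
    """Hypothetical base class that records the kwargs used to build an instance."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        original_init = cls.__dict__.get("__init__")
        if original_init is None:
            return  # the subclass inherits an __init__ that is already wrapped
        signature = inspect.signature(original_init)

        @functools.wraps(original_init)
        def wrapped_init(self, *args, **kw):
            bound = signature.bind(self, *args, **kw)
            bound.apply_defaults()  # also record arguments left at their defaults
            config = dict(bound.arguments)
            config.pop("self", None)
            # Only the outermost (most derived) constructor records the config.
            if not hasattr(self, "_init_config"):
                self._init_config = config
            original_init(self, *args, **kw)

        cls.__init__ = wrapped_init


class Toy(ConfigCapturingBase):
    def __init__(self, hidden=8, dropout=0.1):
        self.hidden = hidden
        self.dropout = dropout


toy = Toy(16)
print(toy._init_config)           # captured positional and default arguments
rebuilt = Toy(**toy._init_config)  # the captured config is enough to rebuild
```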
- count_trainable_params()[source]
Count the model’s trainable parameters.
- Returns:
Total number of parameters with requires_grad=True.
- Return type:
int
- detach()[source]
Detach stateful variables from the computational graph.
No-op by default; recurrent subclasses override this to truncate backpropagation-through-time between chunks (cf. spikingjelly).
- forward(stimulus)[source]
Run the default template pipeline.
Applies wav2spec → prefiltering → core → readout in sequence.
- Parameters:
stimulus (torch.Tensor) – Input stimulus batch (shape is modality-dependent; for audio models a spectrogram (B, F, T)).
- Returns:
Predicted response of shape (B, N, 1, T).
- Return type:
torch.Tensor
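The slot sequence can be illustrated with a framework-free sketch in which each slot is an ordinary callable and the three optional slots default to the identity. The class below is hypothetical and deliberately avoids PyTorch; in the library the slots are nn.Module instances and the defaults are nn.Identity.

```python
def identity(x):
    """Stand-in for nn.Identity: return the input unchanged."""
    return x


class FourSlotSketch:
    """Hypothetical, torch-free illustration of the four-slot template."""

    def __init__(self, readout, wav2spec=identity, prefiltering=identity, core=identity):
        self.wav2spec = wav2spec
        self.prefiltering = prefiltering
        self.core = core
        self.readout = readout

    def forward(self, stimulus):
        x = self.wav2spec(stimulus)   # front-end
        x = self.prefiltering(x)      # optional prefiltering stage
        features = self.core(x)       # shared feature backbone
        return self.readout(features)  # per-neuron projection

    __call__ = forward


# A minimal model only supplies a readout; the other slots pass data through.
model = FourSlotSketch(readout=lambda feats: [2 * v for v in feats])
print(model([1, 2, 3]))

# Filling another slot changes the pipeline without touching forward().
shifted = FourSlotSketch(readout=sum, core=lambda xs: [v + 1 for v in xs])
print(shifted([1, 2, 3]))
```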
- classmethod from_pretrained(repo_id_or_path: str | Path, *, extra_kwargs: dict | None = None, strict: bool = True, map_location: str | device = 'cpu', cache_dir: str | Path | None = None, revision: str | None = None, token: str | None = None, return_metadata: bool = False)[source]
Instantiate cls from an HF Hub repo or a local checkpoint folder.
- Parameters:
repo_id_or_path (str or Path) – "<owner>/<name>" (HF Hub) or a path to a local folder produced by save_pretrained(). The decision is made by pathlib.Path.is_dir(): if the path exists locally it wins, otherwise we treat the string as a Hub repo id.
extra_kwargs (dict, optional) – Override / extend the saved config. Use this to re-supply any __init__ argument that wasn’t JSON-serialisable at save time (e.g. a custom prefiltering module).
strict (bool, default True) – state_dict strictness.
map_location (str or torch.device, default 'cpu')
cache_dir – Forwarded to huggingface_hub.snapshot_download().
revision – Forwarded to huggingface_hub.snapshot_download().
token – Forwarded to huggingface_hub.snapshot_download().
return_metadata (bool, default False) – If True, return (model, metadata) instead of just model.
- Returns:
Instance of cls with weights loaded.
- Return type:
model
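The local-versus-Hub decision described for repo_id_or_path can be sketched with pathlib alone. The helper name and its return values are hypothetical; the sketch only mirrors the documented rule that an existing local directory takes precedence, and it does not download or load anything.

```python
import tempfile
from pathlib import Path


def resolve_checkpoint_source(repo_id_or_path):
    """Return ("local", path) for an existing folder, else ("hub", repo_id)."""
    candidate = Path(repo_id_or_path)
    if candidate.is_dir():
        return "local", str(candidate)
    # Anything that is not an existing directory is treated as a Hub repo id.
    return "hub", str(repo_id_or_path)


local_dir = tempfile.mkdtemp()  # stands in for a folder written by save_pretrained()
print(resolve_checkpoint_source(local_dir))
print(resolve_checkpoint_source("some-owner/some-model-that-is-not-a-folder"))
```

A consequence of this rule is that a local folder whose relative path happens to look like "<owner>/<name>" shadows the Hub repo of the same name.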
- push_to_hub(repo_id: str, *, metadata: dict | None = None, model_card: str | None = None, private: bool = False, token: str | None = None, commit_message: str | None = None) str[source]
Push this model to repo_id on the HF Hub.
Saves the checkpoint to a temporary folder and uploads it via deepSTRF.utils.hub.upload_pretrained(). Creates the repo on the fly if it doesn’t exist; the user must be authenticated with write access (hf auth login or token=).
- Returns:
URL of the resulting commit.
- Return type:
str
- save_pretrained(save_dir: str | Path, *, metadata: dict | None = None, model_card: str | None = None) Path[source]
Write a checkpoint folder (config.json + model.safetensors).
See deepSTRF.utils.hub.save_pretrained_to_dir() for the full contract.
- Parameters:
save_dir (str or pathlib.Path) – Destination folder; created if it does not exist.
metadata (dict, optional) – Extra JSON-serialisable metadata stored alongside the config (e.g. training dataset, val/test scores).
model_card (str, optional) – Markdown content written to README.md in the folder.
- Returns:
Path to the written checkpoint folder.
- Return type:
pathlib.Path
- validate()[source]
Check that the instance is deepSTRF-compatible.
Subclasses should call super().validate() and then add their own checks (e.g. AudioEncodingModel checks F, T > 0).
- Raises:
AssertionError – If self.O is not a positive int, if readout is unset or not a torch.nn.Module, or if any of the wav2spec / prefiltering / core slots is not a torch.nn.Module.