deepSTRF.models package

Subpackages

Submodules

deepSTRF.models.layers module

class deepSTRF.models.layers.CausalLayerNorm(normalized_shape, dim: int = 1, eps: float = 1e-05, elementwise_affine: bool = True)[source]

Bases: Module

LayerNorm applied to a non-trailing axis of its input — equivalently, LayerNorm computed independently at every position of every other axis.

Strictly causal: never pools statistics across time. Drop-in replacement for nn.BatchNorm{1,2,3}d in models that need to stay causal.

Parameters:
  • normalized_shape (int) – Size of the axis being normalized.

  • dim (int, default 1) – Index of the axis to normalize. dim=1 (default) targets the channel axis of a (B, C, ...) tensor; use dim=-2 to normalize the frequency axis of a (B, C, F, T) audio spectrogram (the axis just before time).

  • eps (float, default 1e-5) – Numerical stability term forwarded to nn.LayerNorm.

  • elementwise_affine (bool, default True) – Whether to learn per-element scale and shift.
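A minimal sketch of the idea, assuming the layer moves the normalized axis to the end, applies nn.LayerNorm, and moves it back. CausalLayerNormSketch is a hypothetical name for illustration, not the deepSTRF source:

```python
import torch
import torch.nn as nn


class CausalLayerNormSketch(nn.Module):
    """LayerNorm over one (possibly non-trailing) axis, computed per position."""

    def __init__(self, normalized_shape: int, dim: int = 1, eps: float = 1e-5,
                 elementwise_affine: bool = True):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(normalized_shape, eps=eps,
                                 elementwise_affine=elementwise_affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move the normalized axis last, apply LayerNorm, then restore the layout.
        # Statistics are taken over that single axis only, so nothing is ever
        # pooled across time steps.
        return self.norm(x.movedim(self.dim, -1)).movedim(-1, self.dim)


# Usage: normalize the frequency axis of a (B, C, F, T) spectrogram.
x = torch.randn(2, 3, 34, 50)
norm = CausalLayerNormSketch(34, dim=-2)
y = norm(x)
```

Because each time step is normalized on its own, changing future frames leaves earlier outputs untouched, which is the property that makes it usable where BatchNorm would break causality.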

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class deepSTRF.models.layers.CausalSTRFConv(F: int, T: int, C_in: int, C_out: int, kernel: Module = None, bias: bool = True)[source]

Bases: Module

A causal Spectro-Temporal Receptive Field convolution: T-1 zeros are prepended along the time axis, then a 2D STRF kernel of shape (C_out, C_in, F, T) is applied with valid padding. Output time length matches input time length.

The actual STRF kernel module is pluggable via the kernel kwarg. The default is a plain nn.Conv2d; passing ParametricSTRF or a separable-kernel nn.Sequential swaps in alternative parameterizations without changing the model that holds this layer.

Parameters:
  • F (int) – Spectrogram frequency bins (the height of the kernel).

  • T (int) – Temporal extent of the STRF in frames (the width of the kernel).

  • C_in (int) – Input channel count (typically 1, or 2 with AdapTrans prefiltering).

  • C_out (int) – Output channel count. Hidden width when used inside a core, N when used inside an STRF readout.

  • kernel (nn.Module, optional) – Pre-built kernel module. Must produce (B, C_out, 1, T_in) from (B, C_in, F, T_in + T - 1). None (default) instantiates a vanilla nn.Conv2d(C_in, C_out, kernel_size=(F, T)).

  • bias (bool, default True) – Used only when kernel is None.
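The padding arithmetic can be sketched with plain PyTorch ops, assuming the default vanilla nn.Conv2d kernel; causal_strf_conv is a hypothetical helper name, not part of deepSTRF:

```python
import torch
import torch.nn as nn
import torch.nn.functional as nnf


def causal_strf_conv(x: torch.Tensor, conv: nn.Conv2d) -> torch.Tensor:
    """x: (B, C_in, F, T_in) -> (B, C_out, 1, T_in) for an (F, T) kernel."""
    T = conv.kernel_size[1]
    # nnf.pad takes (left, right) for the last axis first: prepend T-1 zeros
    # along time and leave the frequency axis unpadded.
    x = nnf.pad(x, (T - 1, 0))
    # Valid convolution: F collapses to 1, and T_in + T - 1 frames shrink back to T_in.
    return conv(x)


B, C_in, C_out, F, T, T_in = 4, 2, 8, 34, 9, 100
conv = nn.Conv2d(C_in, C_out, kernel_size=(F, T))
x = torch.randn(B, C_in, F, T_in)
y = causal_strf_conv(x, conv)
print(y.shape)  # torch.Size([4, 8, 1, 100])
```

Padding only on the left is what keeps the layer causal: output frame t only sees input frames t-T+1 through t.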

STRF_weight()[source]

Return the effective STRF kernel as a (C_out, C_in, F, T) tensor, detached and on CPU.

Works across kernel types: vanilla nn.Conv2d, ParametricSTRF (DCLS), and the frequency-time separable nn.Sequential variant.

forward(x)[source]

Left-pad x with T-1 zeros along the time axis, then apply the STRF kernel with valid padding. Returns (B, C_out, 1, T_in): the time length of the input is preserved.

class deepSTRF.models.layers.LearnableExponentialDecay(input_size: int, kernel_size: int, init_tau: float = 2.0, decay_input: bool = True)[source]

Bases: Module

Per-band learnable exponential-decay low-pass filter.

Convolves each frequency band of a (B, 1, F, T) spectrogram with a causal exponential kernel whose time constant is learned per band, and returns a low-pass version of the same shape. The decay parameterization follows Rahman et al. (DNet) and Fang et al. (PLIF).

Parameters:
  • input_size (int) – Number of frequency bands F (one learnable time constant each).

  • kernel_size (int) – Temporal extent K of the decay kernel in frames.

  • init_tau (float, default 2.0) – Mean initial time constant (frames) for the decay parameters.

  • decay_input (bool, default True) – If True, scale the kernel so the filtered input keeps unit DC gain.

Notes

Currently single input / output channel only, and processes a 2-D (B, 1, F, T) input as a stack of 1-D temporal convolutions.
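As a rough sketch of the stack-of-1-D-convolutions approach, assume kernel weights exp(-lag / tau) over the last kernel_size frames, normalized to unit sum to mirror the decay_input=True description. exp_decay_lowpass is a hypothetical name, tau would be a learnable nn.Parameter in practice, and the DNet/PLIF parameterization used by deepSTRF may differ (for instance by learning tau through a positivity transform):

```python
import torch
import torch.nn.functional as nnf


def exp_decay_lowpass(x: torch.Tensor, tau: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Causally low-pass each band of a (B, 1, F, T) input; tau: (F,) in frames."""
    B, _, F, T = x.shape
    # conv1d is a cross-correlation, so taps run oldest -> newest and the last
    # tap (lag 0) multiplies the current frame.
    lags = torch.arange(kernel_size - 1, -1, -1, dtype=x.dtype)
    kernel = torch.exp(-lags[None, :] / tau[:, None])     # (F, K)
    kernel = kernel / kernel.sum(dim=1, keepdim=True)     # unit DC gain
    # One causal 1-D convolution per band: zero-pad the past, groups=F.
    z = nnf.pad(x.reshape(B, F, T), (kernel_size - 1, 0))
    y = nnf.conv1d(z, kernel.unsqueeze(1), groups=F)      # (B, F, T)
    return y.reshape(B, 1, F, T)


x = torch.ones(2, 1, 34, 40)
y = exp_decay_lowpass(x, tau=torch.full((34,), 2.0), kernel_size=8)
```

With a constant input, the output settles at exactly the input level once the kernel is fully inside the signal, which is what unit DC gain means here.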

build_kernel(device='cpu')[source]

Build the per-band decay kernel of shape (input_size, 1, kernel_size).

The kernel is convolved with the last (temporal) axis of the input.

Parameters:

device (torch.device or str, default 'cpu') – Device on which to build the kernel.

Returns:

Kernel of shape (input_size, 1, kernel_size).

Return type:

torch.Tensor

forward(x)[source]

Low-pass each frequency band of a 1-channel spectrogram.

Parameters:

x (torch.Tensor) – Input spectrogram of shape (B, 1, F, T) (B=batch, 1 channel, F frequency bands, T timesteps).

Returns:

Low-pass-filtered spectrogram of the same shape.

Return type:

torch.Tensor

tau()[source]
class deepSTRF.models.layers.LocallyConnected1d(input_size, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True)[source]

Bases: Module

Locally connected (LC) layer for 1-D tensors — a trade-off between Linear and Conv1d.

Like a convolution, but the kernel weights are not shared across positions: every output position has its own filter.

Notes

nn.Unfold only supports image-like (spatially 2-D) inputs, so 1-D inputs are first unsqueezed to 2-D, the unfold and per-position weighting are performed, and the result is squeezed back to 1-D.
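The unsqueeze / unfold route can be sketched as follows, assuming one weight tensor per output position with shape (L_out, out_channels, in_channels * kernel_size). locally_connected_1d is a hypothetical helper, and deepSTRF's actual weight layout may differ:

```python
import torch
import torch.nn.functional as nnf


def locally_connected_1d(x, weight, kernel_size, stride=1, padding=0, dilation=1):
    """x: (B, C_in, L); weight: (L_out, C_out, C_in * kernel_size) -> (B, C_out, L_out)."""
    # F.unfold only handles spatial 2-D inputs, so add a dummy height axis of size 1.
    patches = nnf.unfold(
        x.unsqueeze(2),
        kernel_size=(1, kernel_size), stride=(1, stride),
        padding=(0, padding), dilation=(1, dilation),
    )  # (B, C_in * kernel_size, L_out), channel-major like Conv1d weights
    # Every output position l is multiplied by its own filter weight[l].
    return torch.einsum("bkl,lok->bol", patches, weight)


# Sanity check: giving every position the same filter recovers a Conv1d.
B, C_in, C_out, L, k = 2, 3, 4, 20, 5
x = torch.randn(B, C_in, L)
w_shared = torch.randn(C_out, C_in, k)
L_out = L - k + 1
weight = w_shared.reshape(C_out, C_in * k).expand(L_out, -1, -1)
y = locally_connected_1d(x, weight, k)
```

The sanity check makes the Linear/Conv1d trade-off concrete: tying the L_out filters together gives a convolution, while leaving them free multiplies the parameter count by L_out.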

forward(x)[source]

class deepSTRF.models.layers.ParametricSTRF(F: int, T: int, C_in: int, C_out: int, num_gaussians: int = 1, bias: bool = True)[source]

Bases: Module

Spectro-Temporal Receptive Field kernel parameterized as a sum of learnable 2D Gaussians on the (F, T) grid.

A direct PyTorch reimplementation of the DCLS Gaussian-mixture parameterization (Khalfaoui-Hassani et al. 2023, ICLR), free from the upstream library’s silent asymmetric-kernel bug. Each of the num_gaussians Gaussians has:

  • a 2D position (f, t) in the kernel grid coordinates [0, F-1] × [0, T-1],

  • per-axis standard deviations (sigma_f, sigma_t),

  • per-(C_out, C_in) weight.

The effective (C_out, C_in, F, T) kernel is the weighted sum of Gaussians, normalized so each Gaussian has unit mass on the grid (DCLS convention).

Parameters:
  • F (int) – Frequency bins of the kernel.

  • T (int) – Temporal extent of the kernel in frames.

  • C_in (int) – Input channel count.

  • C_out (int) – Output channel count.

  • num_gaussians (int, default 1) – Number of Gaussians per (C_out, C_in) slot.

  • bias (bool, default True) – Whether to add a per-output-channel bias (conv2d convention).
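A compact sketch of the Gaussian-mixture construction, assuming positions and standard deviations are stored directly in grid units as (C_out, C_in, K) tensors, where K = num_gaussians. gaussian_mixture_kernel is a hypothetical helper; the real module may, for instance, constrain sigma through a positivity transform:

```python
import torch


def gaussian_mixture_kernel(pos_f, pos_t, sigma_f, sigma_t, weight, F, T):
    """All parameters: (C_out, C_in, K). Returns the (C_out, C_in, F, T) kernel."""
    # Absolute grid coordinates [0, F-1] x [0, T-1]; no centre offset involved.
    f = torch.arange(F, dtype=pos_f.dtype).view(F, 1)
    t = torch.arange(T, dtype=pos_t.dtype).view(1, T)
    pf, pt = pos_f[..., None, None], pos_t[..., None, None]
    sf, st = sigma_f[..., None, None], sigma_t[..., None, None]
    g = torch.exp(-0.5 * (((f - pf) / sf) ** 2 + ((t - pt) / st) ** 2))
    g = g / g.sum(dim=(-2, -1), keepdim=True)         # unit mass on the grid
    return (weight[..., None, None] * g).sum(dim=2)   # weighted sum over the K Gaussians


C_out, C_in, K, F, T = 3, 2, 4, 34, 9
pos_f = torch.rand(C_out, C_in, K) * (F - 1)
pos_t = torch.rand(C_out, C_in, K) * (T - 1)
sigma = torch.ones(C_out, C_in, K)
kernel = gaussian_mixture_kernel(pos_f, pos_t, sigma, sigma, torch.ones(C_out, C_in, K), F, T)
```

Because positions live in absolute grid coordinates, an asymmetric shape such as (34, 9) needs no per-axis offset, and with unit weights each (C_out, C_in) slice sums to K.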

References

Khalfaoui-Hassani, Pellegrini & Masquelier (2023). “Dilated Convolution with Learnable Spacings.” ICLR.

Notes

The upstream DCLS library’s ConstructKernel2d silently mishandles asymmetric kernels: for dilated_kernel_size=(F, T) with F != T, its position-offset step +lim//2 adds the F-half-width to the T-axis position parameter and vice versa, concentrating Gaussians near the centre of one axis and outside the grid on the other. This is the cause of the observed “Gaussians don’t populate the entire STRF window” behavior on auditory STRF shapes like (34, 9). The deepSTRF reimplementation parametrizes positions in absolute grid coordinates [0, F-1] × [0, T-1], avoids the offset entirely, and removes the optional DCLS dependency.

build_kernel(device=None)[source]

Build the effective (C_out, C_in, F, T) kernel from the K = num_gaussians Gaussians.

forward(x)[source]

class deepSTRF.models.layers.SeparableSTRF(F: int, T: int, C_in, C_out, bias: bool = True)[source]

Bases: Module

Frequency-time separable Spectro-Temporal Receptive Field (2D) kernel.

The effective (C_out, C_in, F, T) kernel is the rank-1 outer product w_F(f) · w_T(t) of two per-(C_out, C_in) factors. This cuts the parameter count from C_out·C_in·F·T (vanilla nn.Conv2d STRF) to C_out·C_in·(F + T), while preserving the conv2d call signature, so it can be dropped in as the kernel= argument of any model that uses STRFReadout.

Parameters:
  • F (int) – Frequency bins of the kernel.

  • T (int) – Temporal extent of the kernel in frames.

  • C_in (int) – Input channel count.

  • C_out (int) – Output channel count.

  • bias (bool, default True) – Whether to add a per-output-channel bias (conv2d convention).
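The rank-1 construction and the parameter saving can be checked with a few tensor ops. This sketch stores the two factors as plain tensors named w_F and w_T, which is an assumption about the internal layout rather than deepSTRF's code:

```python
import torch

C_out, C_in, F, T = 5, 2, 34, 9
w_F = torch.randn(C_out, C_in, F)   # frequency factor per (C_out, C_in) slot
w_T = torch.randn(C_out, C_in, T)   # temporal factor per (C_out, C_in) slot

# Outer product per slot: kernel[o, i, f, t] = w_F[o, i, f] * w_T[o, i, t]
kernel = w_F[..., :, None] * w_T[..., None, :]    # (C_out, C_in, F, T)

n_separable = w_F.numel() + w_T.numel()   # C_out*C_in*(F + T) = 430
n_full = kernel.numel()                   # C_out*C_in*F*T    = 3060
```

For an auditory STRF shape like (34, 9) the separable form needs about 7 times fewer kernel parameters than the full one.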

build_kernel(device='cpu')[source]
forward(x)[source]

class deepSTRF.models.layers.SinusoidalPositionalEncoding(d_model: int)[source]

Bases: Module

Sinusoidal positional encoding (Vaswani et al. 2017, “Attention Is All You Need”). Computed on the fly per forward pass, so the same module generalizes to arbitrary sequence lengths.

Parameters:

d_model (int) – Embedding dimension. Must be even (the encoding alternates sin and cos along the channel axis).

Notes

Adds the encoding to the input rather than returning it separately. Input shape (B, L, d_model), output shape (B, L, d_model).

Per-dimension frequencies follow the standard 1 / 10000^(2i / d_model) schedule and are stored as a non-trainable buffer so they follow .to(device).
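A minimal sketch of the encoding table, assuming the usual interleaved layout (sin on even channels, cos on odd ones). sinusoidal_encoding is a hypothetical helper; the module described above keeps the frequencies in a registered buffer instead of recomputing them:

```python
import math
import torch


def sinusoidal_encoding(L: int, d_model: int) -> torch.Tensor:
    """(L, d_model) table: sin on even channels, cos on odd channels."""
    assert d_model % 2 == 0, "d_model must be even"
    pos = torch.arange(L, dtype=torch.float32)[:, None]      # (L, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)  # 0, 2, 4, ...
    freqs = torch.exp(-math.log(10000.0) * two_i / d_model)   # 1 / 10000^(2i / d_model)
    pe = torch.zeros(L, d_model)
    pe[:, 0::2] = torch.sin(pos * freqs)
    pe[:, 1::2] = torch.cos(pos * freqs)
    return pe


x = torch.zeros(2, 7, 16)                               # (B, L, d_model)
out = x + sinusoidal_encoding(x.shape[1], x.shape[2])   # broadcast over the batch
```

Since the table depends only on L and d_model, rebuilding it per call is what lets one module handle any sequence length.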

forward(x: Tensor) Tensor[source]

Add the sinusoidal encoding to x of shape (B, L, d_model) and return a tensor of the same shape.

deepSTRF.models.layers.build_causal_window_mask(L: int, window: int = None, device=None) Tensor[source]

Build a (L, L) attention mask for causal (optionally windowed) self-attention.

Position i attends to position j iff j <= i (causal) and, when window is set, i - j < window (otherwise the past is unlimited).

Parameters:
  • L (int) – Sequence length.

  • window (int, optional) – Maximum look-back distance. None (default) allows attending to the entire past.

  • device (torch.device, optional) – Device for the returned mask.

Returns:

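A (L, L) bool tensor where True means “mask out / forbid attention”, matching the boolean-mask convention of nn.TransformerEncoderLayer and nn.MultiheadAttention. Note that F.scaled_dot_product_attention uses the opposite boolean convention (True means “take part in attention”), so pass ~mask there.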

Return type:

torch.Tensor
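The masking rule can be sketched in a few lines. causal_window_mask is a hypothetical re-implementation for illustration that uses the True = forbid convention:

```python
import torch


def causal_window_mask(L: int, window=None, device=None) -> torch.Tensor:
    """(L, L) bool mask; True = forbid attention (nn.MultiheadAttention convention)."""
    i = torch.arange(L, device=device)[:, None]   # query positions
    j = torch.arange(L, device=device)[None, :]   # key positions
    allowed = j <= i                              # causal: no future keys
    if window is not None:
        allowed &= (i - j) < window               # bounded look-back
    return ~allowed


mask = causal_window_mask(4, window=2)
# nn.TransformerEncoderLayer / nn.MultiheadAttention take `mask` as-is;
# F.scaled_dot_product_attention expects True = attend, so it needs `~mask`.
```

With window=2, each row allows the current position and the one before it, so every position sees exactly window frames once the sequence is long enough.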

deepSTRF.models.neural_model module

class deepSTRF.models.neural_model.NeuralModel(out_neurons: int = 1, *args, **kwargs)[source]

Bases: Module, ABC

Base class for encoding models of sensory neural responses.

A four-slot template defines the canonical forward pipeline:

def forward(self, x):
    x = self.wav2spec(x)        # raw-waveform front-end (future)
    x = self.prefiltering(x)    # AdapTrans / ICAdaptation / Identity
    f = self.core(x)            # shared feature backbone
    return self.readout(f)      # per-neuron projection (B, N, 1, T)

Concrete subclasses populate the slots in their __init__. The defaults for wav2spec, prefiltering and core are nn.Identity, so a minimal model only needs to provide a readout. Subclasses may override forward() for architectures that don’t fit the four-slot pipeline (e.g. StateNet’s recurrent reshape, Transformer’s per-frame attention).

See docs/_source/md/model_paradigm.md for the full contract.

Parameters:

out_neurons (int, default 1) – Number of output neurons N the model predicts. Stored on self.O and used by validate() and STRF_gradmap.

Notes

Subclasses inherit a Hugging Face Hub interface for pretrained checkpoints: save_pretrained(), push_to_hub() and from_pretrained().

The init kwargs needed to rebuild the architecture are auto-captured on construction (see __init_subclass__()), so end-users never have to supply a config dict — StateNet.from_pretrained("urancon/...") just works.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

count_trainable_params()[source]

Count the model’s trainable parameters.

Returns:

Total number of parameters with requires_grad=True.

Return type:

int

detach()[source]

Detach stateful variables from the computational graph.

No-op by default; recurrent subclasses override this to truncate backpropagation-through-time between chunks (cf. spikingjelly).

forward(stimulus)[source]

Run the default template pipeline.

Applies wav2spec → prefiltering → core → readout in sequence.

Parameters:

stimulus (torch.Tensor) – Input stimulus batch (shape is modality-dependent; for audio models a spectrogram (B, F, T)).

Returns:

Predicted response of shape (B, N, 1, T).

Return type:

torch.Tensor

classmethod from_pretrained(repo_id_or_path: str | Path, *, extra_kwargs: dict | None = None, strict: bool = True, map_location: str | device = 'cpu', cache_dir: str | Path | None = None, revision: str | None = None, token: str | None = None, return_metadata: bool = False)[source]

Instantiate cls from an HF Hub repo or a local checkpoint folder.

Parameters:
  • repo_id_or_path (str or Path) – "<owner>/<name>" (HF Hub) or a path to a local folder produced by save_pretrained(). The decision is made with pathlib.Path.is_dir(): if the argument names an existing local directory, that folder is loaded; otherwise the string is treated as a Hub repo id.

  • extra_kwargs (dict, optional) – Override / extend the saved config. Use this to re-supply any __init__ argument that wasn’t JSON-serialisable at save time (e.g. a custom prefiltering module).

  • strict (bool, default True) – state_dict strictness.

  • map_location (str or torch.device, default 'cpu') – Device onto which the loaded weights are mapped.

  • cache_dir – Forwarded to huggingface_hub.snapshot_download().

  • revision – Forwarded to huggingface_hub.snapshot_download().

  • token – Forwarded to huggingface_hub.snapshot_download().

  • return_metadata (bool, default False) – If True, return (model, metadata) instead of just model.

Returns:

Instance of cls with weights loaded, or a (model, metadata) tuple if return_metadata=True.

Return type:

NeuralModel (an instance of cls), or tuple

push_to_hub(repo_id: str, *, metadata: dict | None = None, model_card: str | None = None, private: bool = False, token: str | None = None, commit_message: str | None = None) str[source]

Push this model to repo_id on the HF Hub.

Saves the checkpoint to a temporary folder and uploads it via deepSTRF.utils.hub.upload_pretrained(). Creates the repo on the fly if it doesn’t exist; the user must be authenticated with write access (hf auth login or token=).

Returns:

URL of the resulting commit.

Return type:

str

save_pretrained(save_dir: str | Path, *, metadata: dict | None = None, model_card: str | None = None) Path[source]

Write a checkpoint folder (config.json + model.safetensors).

See deepSTRF.utils.hub.save_pretrained_to_dir() for the full contract.

Parameters:
  • save_dir (str or pathlib.Path) – Destination folder; created if it does not exist.

  • metadata (dict, optional) – Extra JSON-serialisable metadata stored alongside the config (e.g. training dataset, val/test scores).

  • model_card (str, optional) – Markdown content written to README.md in the folder.

Returns:

Path to the written checkpoint folder.

Return type:

pathlib.Path

validate()[source]

Check that the instance is deepSTRF-compatible.

Subclasses should call super().validate() and then add their own checks (e.g. AudioEncodingModel checks F, T > 0).

Raises:

AssertionError – If self.O is not a positive int, if readout is unset or not a torch.nn.Module, or if any of the wav2spec / prefiltering / core slots is not a torch.nn.Module.

deepSTRF.models.prefiltering module

class deepSTRF.models.prefiltering.AdapTrans(init_a_vals, init_w_vals, kernel_size: int = 2, learnable: bool = True)[source]

Bases: Module

Adaptive ON/OFF spectrogram prefilter — the learnable extension of the inferior-colliculus adaptation prefilter.

Computes ON and OFF spectrograms through high-pass exponential filters with frequency-dependent, learnable time constants. Each frequency band is independently filtered along the temporal dimension with a parameterized exponential kernel:

kernel = […; -Cwa²; -Cwa; -Cw; +1] with C = 1/(… + a² + a + 1)

where the negative terms sum to -w. The filter computes the difference between the current value of the signal in each frequency band and an exponential average of its recent past, with separate (a, w) pairs giving rise to ON and OFF polarities.
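The ellipsis in the kernel formula can be made concrete with a short sketch. It assumes the kernel holds kernel_size - 1 past taps whose magnitudes shrink by a factor a per frame of lag, followed by the +1 tap for the current frame. adaptation_kernel is a hypothetical helper that builds a single band's ON-style kernel in plain Python, without learnable parameters or the OFF variant:

```python
def adaptation_kernel(a: float, w: float, kernel_size: int) -> list:
    """Taps ordered oldest -> current: [..., -C*w*a**2, -C*w*a, -C*w, +1]."""
    n_past = kernel_size - 1
    C = 1.0 / sum(a ** k for k in range(n_past))   # C = 1 / (... + a^2 + a + 1)
    past = [-C * w * a ** k for k in reversed(range(n_past))]
    return past + [1.0]


k = adaptation_kernel(a=0.5, w=0.75, kernel_size=4)
# C = 1 / (0.25 + 0.5 + 1); the three past taps sum to -0.75 = -w.
```

With the default kernel_size=2 this reduces to [-w, +1]: the current frame minus w times the previous one.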

Parameters:
  • init_a_vals (1D Tensor of length F) – Per-frequency a parameters (related to the time constant).

  • init_w_vals (1D Tensor of length F) – Per-frequency w parameters (relative weight of the past average vs the present sample).

  • kernel_size (int, default 2) – Length of the temporal kernel (in frames).

  • learnable (bool, default True) – If True, a and w are learnable nn.Parameters; if False, they are frozen buffers (still follow .to(device)).

References

Rançon, Masquelier & Cottereau (2024). “A general model unifying the adaptive, transient and sustained properties of ON and OFF auditory neural responses.” PLOS Computational Biology 20(8):e1012288. https://doi.org/10.1371/journal.pcbi.1012288

Notes

Input shape: (B, 1, F, T). Output shape: (B, 2, F, T) — channel 0 is ON, channel 1 is OFF.

init_a_vals: a 1D vector of ‘a’ parameters (related to the time constant of the kernel’s exponential). The higher the ‘a’, the higher the corresponding time constant of the exponential.

init_w_vals: a 1D vector of ‘w’ parameters (representing the weight given to the exponential average of the signal in its recent past).

OFF_kernel(d, p)[source]

Creates the OFF kernel — the ON kernel flipped about zero, then renormalized so its tail equals -w.

ON_kernel(d, p)[source]

Creates the ON kernel, which highlights onsets in the signal.

build_kernels()[source]
Creates two parametrized kernels:
  • one for the ON response, highlighting onsets in the signal

  • one for the OFF response (offsets), which is the flipped version of the ON kernel

forward(spectro_in)[source]

Compute the ON and OFF high-pass-filtered spectrograms.

Parameters:

spectro_in (torch.Tensor) – Input spectrogram of shape (B, 1, F, T) (B=batch, 1 channel, F frequency bands, T timesteps).

Returns:

Tensor of shape (B, 2, F, T): channel 0 is the ON response, channel 1 is the OFF response (both half-wave rectified).

Return type:

torch.Tensor

get_a()[source]
get_w()[source]
out_channels: int = 2
plot_kernels(frequency_bin=0)[source]
class deepSTRF.models.prefiltering.ICAdaptation(init_a_vals, kernel_size: int = 2)[source]

Bases: Module

High-pass exponential filter with frequency-dependent time constants — a paper-faithful re-implementation of the inferior-colliculus adaptation prefilter described by Willmore et al. (2016).

Independently filters each frequency band of an input spectrogram along the temporal dimension with a parameterized exponential kernel:

kernel = […; -Cwa²; -Cwa; -Cw; +1] with C = 1/(… + a² + a + 1)

where the negative terms sum to -w. The filter effectively computes the difference between the current value of the signal in each frequency band and an exponential average of its recent past, then applies a half-wave rectification.

Parameters:
  • init_a_vals (1D Tensor of length F) – Per-frequency a parameters (related to the exponential time constant: higher a → longer time constant).

  • kernel_size (int, default 2) – Length of the temporal kernel (in frames).

References

Willmore, Schoppe, King, Schnupp, Harper (2016). “Incorporating Midbrain Adaptation to Mean Sound Level Improves Models of Auditory Cortical Processing.” J. Neurosci. 36(2): 280–289. https://doi.org/10.1523/JNEUROSCI.2441-15.2016

Notes

Intentionally non-learnable: the time constants are derived analytically from the cochlear frequency map (see freq_to_tau) and are paper-faithful. For a learnable extension, use AdapTrans.

Input shape: (B, 1, F, T). Output shape: (B, 1, F, T).

build_kernels()[source]

Creates a parametrized kernel: a high-pass exponential filter that highlights onsets in the signal.

forward(spectro_in)[source]

High-pass filter and half-wave rectify each frequency band.

Parameters:

spectro_in (torch.Tensor) – Input spectrogram of shape (B, 1, F, T) (B=batch, 1 channel, F frequency bands, T timesteps).

Returns:

High-pass-filtered, half-wave-rectified spectrogram of shape (B, 1, F, T).

Return type:

torch.Tensor

out_channels: int = 1
plot_kernels(frequency_bin=0)[source]
deepSTRF.models.prefiltering.Willmore_Adaptation

alias of ICAdaptation

deepSTRF.models.prefiltering.a_to_tau(a, dt: float = 1)[source]

Inverse of tau_to_a(): converts a parameters back to time constants (ms) via tau = -dt / ln(a).

Parameters:
  • a (torch.Tensor) – Dimensionless a parameters.

  • dt (float, default 1) – Time-step width in ms.

Returns:

Time constants in ms.

Return type:

torch.Tensor

deepSTRF.models.prefiltering.freq_to_tau(freqs)[source]

Map frequencies (Hz) to midbrain-neuron time constants (ms).

Parameters:

freqs (torch.Tensor) – Frequencies in Hz.

Returns:

Associated time constants in ms.

Return type:

torch.Tensor

References

Willmore et al. (2016). “Incorporating Midbrain Adaptation to Mean Sound Level Improves Models of Auditory Cortical Processing.”

deepSTRF.models.prefiltering.get_CFs(min_freq, max_freq, n_freqs, scale)[source]

Return n_freqs cochlear/center frequencies (CFs) on a warped scale.

Parameters:
  • min_freq (float) – Lower bound of the frequency range, in Hz.

  • max_freq (float) – Upper bound of the frequency range, in Hz.

  • n_freqs (int) – Number of CFs to return.

  • scale ({'mel', 'greenwood'}) – Frequency-axis warping used to space the CFs.

Returns:

The n_freqs center frequencies in Hz.

Return type:

torch.Tensor

Raises:

NotImplementedError – If scale is not 'mel' or 'greenwood'.
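For the 'mel' case, a sketch assuming the common HTK formula m = 2595 log10(1 + f / 700) and CFs evenly spaced in mel with both endpoints included. Whether get_CFs uses this mel variant and includes the endpoints is an assumption, the helper names are hypothetical, and the Greenwood case is not shown:

```python
import math


def hz_to_mel(f: float) -> float:
    return 2595.0 * math.log10(1.0 + f / 700.0)


def mel_to_hz(m: float) -> float:
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def mel_spaced_cfs(min_freq: float, max_freq: float, n_freqs: int) -> list:
    """n_freqs CFs evenly spaced on the (HTK) mel scale, endpoints included."""
    lo, hi = hz_to_mel(min_freq), hz_to_mel(max_freq)
    step = (hi - lo) / (n_freqs - 1)
    return [mel_to_hz(lo + k * step) for k in range(n_freqs)]


cfs = mel_spaced_cfs(500.0, 20000.0, 34)   # denser at low frequencies
```

Equal steps in mel translate into Hz gaps that widen with frequency, mimicking the cochlea's finer resolution at low CFs.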

deepSTRF.models.prefiltering.make_prefiltering(kind: str, n_frequency_bands: int, dt: float, min_freq: float = 500.0, max_freq: float = 20000.0, scale: str = 'mel', learnable: bool = True, init_w: float = 0.75) Module[source]

Factory for constructing a prefilter module from compact arguments.

Convenience wrapper that derives per-frequency a (and w) initial values from the cochlear frequency map, then instantiates the requested prefilter class. Equivalent to building the prefilter by hand; the factory exists so that user code does not need to repeat the get_CFs / freq_to_tau / tau_to_a pipeline.

Parameters:
  • kind ({'adaptrans', 'icadaptation', 'willmore'}) – Which prefilter to build. 'willmore' is an alias for 'icadaptation'.

  • n_frequency_bands (int) – Number of input frequency bands F of the spectrogram.

  • dt (float) – Time bin width in milliseconds (matches dataset.dt_ms).

  • min_freq (float, default 500.0) – Lower edge (in Hz) of the frequency range spanned by the cochlear filterbank that produced the spectrogram.

  • max_freq (float, default 20000.0) – Upper edge (in Hz) of that frequency range.

  • scale ({'mel', 'greenwood'}, default 'mel') – Frequency-axis scaling used to derive per-band time constants.

  • learnable (bool, default True) – Only relevant for 'adaptrans'. ICAdaptation is always frozen (paper-faithful).

  • init_w (float, default 0.75) – Only relevant for 'adaptrans': initial value of the past-vs-present weight w.

Returns:

Configured prefilter instance with an out_channels attribute.

Return type:

nn.Module

deepSTRF.models.prefiltering.tau_to_a(time_constants, dt: float = 1)[source]

Convert physical time constants (ms) to dimensionless a parameters.

Parameters:
  • time_constants (torch.Tensor) – Time constants in ms.

  • dt (float, default 1) – Time-step width in ms.

Returns:

The corresponding a = exp(-dt / tau) parameters.

Return type:

torch.Tensor
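Both conversions follow from a = exp(-dt / tau); inverting gives tau = -dt / ln(a). A scalar sketch with the standard library, using hypothetical names tau_to_a_scalar and a_to_tau_scalar (the deepSTRF functions operate on tensors):

```python
import math


def tau_to_a_scalar(tau_ms: float, dt: float = 1.0) -> float:
    """a = exp(-dt / tau): closer to 1 for longer time constants."""
    return math.exp(-dt / tau_ms)


def a_to_tau_scalar(a: float, dt: float = 1.0) -> float:
    """Inverse mapping: tau = -dt / ln(a), valid for 0 < a < 1."""
    return -dt / math.log(a)


a = tau_to_a_scalar(50.0, dt=5.0)    # exp(-0.1) ≈ 0.9048
tau = a_to_tau_scalar(a, dt=5.0)     # back to ≈ 50.0 ms
```

Because a depends on dt / tau, the same physical time constant maps to a different a when the dataset's time bin width changes, which is why make_prefiltering takes dt explicitly.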

deepSTRF.models.readouts module

Readout modules: per-neuron projections that take a feature representation emitted by a model’s core and produce predictions of shape (B, N, R=1, T).

Two flavours are shipped:

  • STRFReadout — a learnable STRF kernel applied causally via CausalSTRFConv. Used when the readout itself is the model’s main learnable apparatus (Linear / LinearNonlinear) or when an explicit STRF interpretation of the per-neuron weights is wanted.

  • LinearReadout — a per-timestep linear projection in_features -> N, with an optional 1-hidden-layer MLP. Used by models whose core already produces a flat per-timestep feature vector (ConvNet2D / Transformer / StateNet).

See docs/_source/md/model_paradigm.md §7 for the full readout contract.

class deepSTRF.models.readouts.LinearReadout(in_features: int, out_neurons: int, hidden: int = None, activation: Module = None, bias: bool = True)[source]

Bases: Module

Per-neuron readout: a per-timestep linear projection from a flat feature vector to N output neurons.

Accepts either a 3D input (B, in_features, T) or a 4D input (B, in_features, 1, T) (the singleton spatial axis emitted by an STRF-style conv that collapsed F → 1); both shapes route through the same projection. The output is the canonical (B, N, 1, T).

Parameters:
  • in_features (int) – Per-timestep feature dimension produced by the model’s core.

  • out_neurons (int) – Number of output neurons N.

  • hidden (int, optional) – If given, inserts a 1-hidden-layer MLP in_features → hidden → N with a LeakyReLU(0.1) between the two layers. None (default) gives a single linear projection.

  • activation (nn.Module, optional) – Pointwise output nonlinearity applied to the per-timestep output before the rank is unsqueezed to (B, N, 1, T). Defaults to nn.Identity.

  • bias (bool, default True) – Whether the linear projection(s) include a bias term.
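A minimal sketch of the shape handling, assuming the projection is an nn.Linear applied after moving time ahead of the feature axis. LinearReadoutSketch is hypothetical and omits the activation and bias options:

```python
import torch
import torch.nn as nn


class LinearReadoutSketch(nn.Module):
    """Per-timestep projection in_features -> N, emitted as (B, N, 1, T)."""

    def __init__(self, in_features: int, out_neurons: int, hidden: int = None):
        super().__init__()
        if hidden is None:
            self.proj = nn.Linear(in_features, out_neurons)
        else:  # in_features -> hidden -> N with LeakyReLU(0.1) in between
            self.proj = nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.LeakyReLU(0.1),
                nn.Linear(hidden, out_neurons),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 4:                  # (B, C, 1, T): drop the singleton axis
            x = x.squeeze(2)
        y = self.proj(x.transpose(1, 2))  # (B, T, N): nn.Linear acts on the last axis
        # An output activation with (N,)-shaped parameters would be applied here.
        return y.transpose(1, 2).unsqueeze(2)  # (B, N, 1, T)


readout = LinearReadoutSketch(64, 10, hidden=32)
x = torch.randn(2, 64, 50)
y3, y4 = readout(x), readout(x.unsqueeze(2))
```

The intermediate (B, T, N) layout is also the (..., N)-last-axis form that the per-neuron output activations broadcast against.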

forward(x)[source]

Project x of shape (B, in_features, T) or (B, in_features, 1, T) to the canonical (B, N, 1, T) prediction, applying the optional hidden layer and output activation.

class deepSTRF.models.readouts.STRFReadout(F: int, T: int, C_in: int, out_neurons: int, kernel: Module = None, activation: Module = None, bias: bool = True)[source]

Bases: Module

Per-neuron readout backed by a causal Spectro-Temporal Receptive Field kernel, with a per-neuron BatchNorm placed after the conv.

Wraps a CausalSTRFConv of shape (N, C_in, F, T) and applies an output activation. The conv collapses the frequency axis from F to 1, so the readout naturally emits the canonical (B, N, R=1, T) rank.

The per-neuron nn.BatchNorm1d(N) between the conv and the activation stabilises training and serves as the model’s only normalisation layer in the Linear / LinearNonlinear cases. Every parameter and buffer in this readout (STRF kernel, conv bias, BN affine parameters, BN running statistics) has the neuron axis as its leading dimension, so no parameters are shared across neurons. Causality is preserved in eval mode: BN’s running statistics are per-channel scalars, applied element-wise along the time axis at inference (in training mode, BatchNorm instead computes batch statistics pooled over time).

Parameters:
  • F (int) – Frequency bins of the input spectrogram.

  • T (int) – STRF temporal extent.

  • C_in (int) – Input channel count after prefiltering.

  • out_neurons (int) – Number of output neurons N.

  • kernel (nn.Module, optional) – Pluggable kernel module. None (default) instantiates a vanilla nn.Conv2d; pass ParametricSTRF (DCLS) or a separable nn.Sequential to swap parameterizations.

  • activation (nn.Module, optional) – Pointwise output nonlinearity. Defaults to nn.Identity.

  • bias (bool, default True) – Used only when kernel is None.

STRF_weight(polarity: str = None)[source]

Return the underlying STRF kernel as (N, C_in, F, T).

Parameters:

polarity ({'ON', 'OFF', None}, optional) – If the model’s prefilter produces C_in == 2 channels (e.g. AdapTrans’s ON/OFF), select one. None (default) returns the full (N, C_in, F, T) tensor; 'ON' slices channel 0; 'OFF' slices channel 1.

forward(x)[source]

Apply the causal STRF convolution, the per-neuron BatchNorm and the output activation to x; returns predictions of shape (B, N, 1, T).

deepSTRF.models.activations module

Output activations for deepSTRF readouts.

Three parametric activations, each with an opt-out non-negativity reparameterisation that pairs naturally with poisson_loss(log_input=False) (see metrics_paradigm.md §6.2).

All three follow the same input-shape contract: parameters are (N,)-shaped per-neuron tensors that broadcast against the last axis of the input. Use these inside readouts that emit (..., N)-last-axis tensors (e.g. LinearReadout).

class deepSTRF.models.activations.ParametricDoubleExponential(num_features: int, bias: bool = True, non_negative_output: bool = True)[source]

Bases: Module

4-parameter parametric double-exponential (Thorson et al. 2015).

Per-neuron output:

\[f(x) = a \cdot \exp(-\exp(k \cdot x - s)) + b \quad\text{(bias=True)}\]

where a is the saturated firing rate, b the baseline, s the firing threshold, and k the gain.

Parameters:
  • num_features (int) – N: number of independent per-neuron parameter sets.

  • bias (bool, default True) – Whether to include the additive baseline b.

  • non_negative_output (bool, default True) – When True, a (and b, if bias=True) are stored as raw parameters and softplus-mapped to the strictly-positive half-line at every forward pass. The inner exp(-exp(k·x - s)) term is always in (0, 1], so this fully guarantees f(x) ≥ 0. When False, parameters are direct (signed-output mode).
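A sketch of the non_negative_output=True, bias=True case, assuming raw parameters initialized to simple constants. DoubleExpSketch, its parameter names and initial values are hypothetical, not deepSTRF's defaults:

```python
import torch
import torch.nn as nn
import torch.nn.functional as nnf


class DoubleExpSketch(nn.Module):
    """f(x) = a * exp(-exp(k * x - s)) + b with softplus-mapped a and b."""

    def __init__(self, num_features: int):
        super().__init__()
        self.a_raw = nn.Parameter(torch.zeros(num_features))
        self.b_raw = nn.Parameter(torch.zeros(num_features))
        self.k = nn.Parameter(torch.ones(num_features))
        self.s = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N,) parameters broadcast against the last axis of x.
        a, b = nnf.softplus(self.a_raw), nnf.softplus(self.b_raw)
        return a * torch.exp(-torch.exp(self.k * x - self.s)) + b


act = DoubleExpSketch(num_features=10)
y = act(torch.randn(2, 50, 10) * 5)   # (B, T, N) input; output is non-negative
```

Even when exp(k·x - s) overflows to infinity for large inputs, the outer exp returns 0 rather than NaN, so the output stays at the baseline b.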

References

Thorson, I. L., Liénard, J. & David, S. V. (2015). “The essential complexity of auditory receptive fields.” PLOS Computational Biology, 11(3), e1004228.

property a: Tensor
property b: Tensor | None
extra_repr() str[source]

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x: Tensor) Tensor[source]

class deepSTRF.models.activations.ParametricSigmoid(num_features: int, bias: bool = True, non_negative_output: bool = True)[source]

Bases: Module

4-parameter parametric sigmoid (Willmore et al. 2016).

Per-neuron output:

\[f(x) = b \cdot \sigma((x - c) / d) + a \quad\text{(bias=True)}\]

where b is the dynamic range, a the baseline (minimum firing rate), c the input inflection point, and d the reciprocal of the gain.

Parameters:
  • num_features (int) – N: number of independent per-neuron parameter sets.

  • bias (bool, default True) – Whether to include the additive baseline a.

  • non_negative_output (bool, default True) – When True, b (and a, if bias=True) are stored as raw parameters and softplus-mapped to the strictly-positive half-line at every forward pass. This guarantees a non-negative output curve, suitable for spike-count targets and poisson_loss. When False, parameters are direct (signed-output mode) — useful for LFP / EEG / centred PSTH targets where outputs may legitimately be negative.

Notes

The shipped behaviour replaces an earlier closure-based implementation that built forward inside __init__. The current version uses a standard forward() method and exposes b and a via @property, so softplus is re-applied to the live parameter values at every step; this keeps state_dict round-trips and parameter replacement working correctly.
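The property-based design can be illustrated without PyTorch. In this sketch, SigmoidSketch and its _raw_b / _raw_a attributes are hypothetical stand-ins for the module and its raw parameters, and scalars replace the (N,)-shaped tensors.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


class SigmoidSketch:
    """Hypothetical, framework-free stand-in for the property-based design.

    Only the raw values are stored; the constrained values b and a are
    recomputed on every read, so overwriting a raw value (as loading a
    saved state would) is reflected immediately in the output.
    """

    def __init__(self, raw_b: float, raw_a: float, c: float, d: float):
        self._raw_b, self._raw_a, self.c, self.d = raw_b, raw_a, c, d

    @property
    def b(self) -> float:  # dynamic range, always > 0
        return softplus(self._raw_b)

    @property
    def a(self) -> float:  # baseline, always > 0
        return softplus(self._raw_a)

    def forward(self, x: float) -> float:
        # f(x) = b * sigmoid((x - c) / d) + a
        z = (x - self.c) / self.d
        return self.b / (1.0 + math.exp(-z)) + self.a


m = SigmoidSketch(raw_b=0.0, raw_a=0.0, c=0.0, d=1.0)
print(m.forward(0.0))  # 0.5 * log(2) + log(2)
m._raw_b = 2.0         # "parameter replacement": b changes on the next read
print(m.b)
```

A value computed once in __init__ and cached would not follow the second assignment; reading it through a property does.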

References

Willmore, B. D. B., Schoppe, O., King, A. J., Schnupp, J. W. H. & Harper, N. S. (2016). “Incorporating midbrain adaptation to mean sound level improves models of auditory cortical processing.” Journal of Neuroscience, 36(2), 280–289.

property a: Tensor | None

Baseline parameter, post-reparameterisation. None if bias=False.

property b: Tensor

Dynamic-range parameter, post-reparameterisation.

extra_repr() → str[source]

Return the extra representation of the module.

forward(x: Tensor) → Tensor[source]

Apply f to every element of x, using neuron n's parameters at index n of the last axis. The last axis of x must have size num_features (N); the output has the same shape as x.

class deepSTRF.models.activations.ParametricSoftplus(num_features: int, non_negative_output: bool = True)[source]

Bases: Module

Per-neuron Softplus with learnable sharpness β and additive baseline.

Output:

\[f(x) = \frac{1}{\beta} \log\!\bigl(1 + \exp(\beta x)\bigr) + b\]

where β > 0 is the per-neuron sharpness and b is the per-neuron additive baseline; both are always learnable per-neuron. As β → ∞ the curve approaches ReLU; as β → 0 it approaches a straight line of slope ½, shifted upward by log(2)/β.

Unlike ParametricSigmoid and ParametricDoubleExponential, the underlying curve is unbounded above — natural for spike-count regression on smoothed PSTHs that can take any non-negative magnitude (e.g. NS1: peaks ~3 spikes/bin after Hsu/Borst/Theunissen 21 ms Hanning smoothing).

Parameters:
  • num_features (int) – N: number of independent per-neuron parameter sets.

  • non_negative_output (bool, default True) – When True, b is stored as a raw parameter and softplus-mapped to the strictly-positive half-line at every forward pass. Combined with the always-non-negative softplus core, this guarantees f(x) ≥ 0 for every x. When False, b is unconstrained; pair it with poisson_loss(log_input=True) for signed-output targets (LFP / EEG / centred PSTH).

Notes

β is always non-negativity-reparameterised (β = softplus(_raw_beta)) regardless of non_negative_output, because a non-positive sharpness would flip the curve and is never physically meaningful.

The implementation expects an input whose last axis is N (matching the ParametricSigmoid / ParametricDoubleExponential contract); that is what LinearReadout and STRFReadout emit at the activation step.
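A single-neuron sketch of the forward computation, assuming a numerically stable softplus, is shown below. It shows β being softplus-mapped unconditionally while b is mapped only in non-negative mode. The names parametric_softplus, raw_beta and raw_b are illustrative, not deepSTRF's API.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def parametric_softplus(x: float, raw_beta: float, raw_b: float,
                        non_negative_output: bool = True) -> float:
    """Single-neuron sketch of f(x) = (1/beta) * log(1 + exp(beta * x)) + b."""
    beta = softplus(raw_beta)  # always reparameterised, so beta > 0
    b = softplus(raw_b) if non_negative_output else raw_b  # b > 0 only in this mode
    return softplus(beta * x) / beta + b


# Large sharpness and a near-zero baseline: close to ReLU.
print(parametric_softplus(1.0, raw_beta=50.0, raw_b=-30.0))   # about 1.0
print(parametric_softplus(-1.0, raw_beta=50.0, raw_b=-30.0))  # about 0.0
```

In signed-output mode a negative raw_b passes through unchanged, so the output can drop below zero for very negative inputs.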

property b: Tensor

Per-neuron additive baseline, post-reparameterisation.

property beta: Tensor

Per-neuron sharpness, always > 0.

extra_repr() → str[source]

Return the extra representation of the module.

forward(x: Tensor) → Tensor[source]

Apply f to every element of x, using neuron n's parameters at index n of the last axis. The last axis of x must have size num_features (N); the output has the same shape as x.

deepSTRF.models.scales module

deepSTRF.models.scales.ERB_bandwidth(f)[source]

Equivalent rectangular bandwidth (Hz) at centre frequency f (Hz), Glasberg & Moore (1990): ERB(f) = 24.7 · (0.00437 f + 1).
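The bandwidth formula is simple enough to check by hand. This standalone re-implementation (erb_bandwidth is a hypothetical name, not the deepSTRF function) evaluates it at 1 kHz.

```python
def erb_bandwidth(f_hz: float) -> float:
    """ERB(f) = 24.7 * (0.00437 f + 1) in Hz, with f in Hz (Glasberg & Moore, 1990)."""
    return 24.7 * (0.00437 * f_hz + 1.0)


print(erb_bandwidth(1000.0))  # about 132.6 Hz: 24.7 * (4.37 + 1)
```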

deepSTRF.models.scales.ERB_to_Hz(erb)[source]

Inverse of Hz_to_ERB().

deepSTRF.models.scales.Greenwood(x, animal='human')[source]

Greenwood cochlear position-to-frequency map.

Parameters:
  • x (torch.Tensor or float) – Normalised position along the cochlea in [0, 1].

  • animal ({'human', 'mouse'}, default 'human') – Species whose Greenwood coefficients to use.

Returns:

Characteristic frequency in Hz.

Return type:

torch.Tensor or float

Raises:

NotImplementedError – If animal is not 'human' or 'mouse'.
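A sketch of the human map only is given below. It assumes the commonly cited Greenwood (1990) human coefficients A = 165.4, a = 2.1, k = 0.88, with x measured from the apex (x = 0) to the base (x = 1). The coefficients deepSTRF actually uses, including its mouse values, are not stated here, so greenwood_human is an illustrative stand-in.

```python
# Commonly cited human Greenwood (1990) coefficients; assumed, not read from deepSTRF.
A, ALPHA, K = 165.4, 2.1, 0.88


def greenwood_human(x: float) -> float:
    """Characteristic frequency (Hz) at normalised position x (0 = apex, 1 = base)."""
    return A * (10.0 ** (ALPHA * x) - K)


print(greenwood_human(0.0))  # about 19.8 Hz at the apex
print(greenwood_human(1.0))  # about 20.7 kHz at the base
```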

deepSTRF.models.scales.Hz_to_ERB(f)[source]

ERB-rate (ERB-number) scale of Glasberg & Moore (1990).

E(f) = 21.4 · log10(0.00437 f + 1) for f in Hz. Equal steps on this scale are equal numbers of equivalent-rectangular-bandwidths apart — the standard cochlear frequency axis for gammatone filterbanks.
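Because E(f) is the logarithm of an affine function of f, it has a closed-form inverse, f = (10^(E/21.4) − 1) / 0.00437. The sketch below uses both directions to place centre frequencies at equal ERB-number steps, as a gammatone filterbank would. The function names are hypothetical; deepSTRF's own Hz_to_ERB() and ERB_to_Hz() are not reproduced here.

```python
import math


def hz_to_erb_rate(f_hz: float) -> float:
    """ERB-number E(f) = 21.4 * log10(0.00437 f + 1), Glasberg & Moore (1990)."""
    return 21.4 * math.log10(0.00437 * f_hz + 1.0)


def erb_rate_to_hz(e: float) -> float:
    """Closed-form inverse: f = (10**(E / 21.4) - 1) / 0.00437."""
    return (10.0 ** (e / 21.4) - 1.0) / 0.00437


def erb_spaced_centres(f_min: float, f_max: float, n: int) -> list:
    """n centre frequencies (Hz) at equal steps on the ERB-number scale."""
    lo, hi = hz_to_erb_rate(f_min), hz_to_erb_rate(f_max)
    step = (hi - lo) / (n - 1)
    return [erb_rate_to_hz(lo + i * step) for i in range(n)]


print([round(fc) for fc in erb_spaced_centres(100.0, 8000.0, 8)])
```

The printed centres are closely packed at low frequencies and spread out at high frequencies, mirroring the cochlea's frequency resolution.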

deepSTRF.models.scales.Hz_to_mel(f)[source]

Convert a frequency (Hz) to the mel scale.

Parameters:

f (torch.Tensor or float) – Frequency in Hz.

Returns:

Frequency on the mel scale.

Return type:

torch.Tensor or float

deepSTRF.models.scales.inverse_Greenwood(f, animal='human')[source]

Inverse Greenwood map: frequency (Hz) to normalised cochlear position.

Parameters:
  • f (torch.Tensor or float) – Characteristic frequency in Hz.

  • animal ({'human', 'mouse'}, default 'human') – Species whose Greenwood coefficients to use.

Returns:

Normalised position along the cochlea in [0, 1].

Return type:

torch.Tensor or float

Raises:

NotImplementedError – If animal is not 'human' or 'mouse'.
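Solving f = A(10^(a·x) − k) for x gives x = log10(f/A + k) / a. The round-trip sketch below assumes the commonly cited human Greenwood (1990) coefficients (A = 165.4, a = 2.1, k = 0.88) and x measured from the apex; the function names are illustrative and deepSTRF's own coefficients may differ.

```python
import math

# Commonly cited human Greenwood (1990) coefficients; assumed, not read from deepSTRF.
A, ALPHA, K = 165.4, 2.1, 0.88


def greenwood_human(x: float) -> float:
    """Position (0 = apex, 1 = base) -> characteristic frequency in Hz."""
    return A * (10.0 ** (ALPHA * x) - K)


def inverse_greenwood_human(f_hz: float) -> float:
    """Invert f = A * (10**(ALPHA*x) - K):  x = log10(f / A + K) / ALPHA."""
    return math.log10(f_hz / A + K) / ALPHA


print(inverse_greenwood_human(1000.0))  # about 0.40 of the way from apex to base
```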

deepSTRF.models.scales.mel_to_Hz(mel_freqs)[source]

Convert mel-scale frequencies back to Hz (inverse of Hz_to_mel()).

Parameters:

mel_freqs (torch.Tensor or float) – Frequency on the mel scale.

Returns:

Frequency in Hz.

Return type:

torch.Tensor or float
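The source does not say which mel formula deepSTRF uses. Assuming the common HTK variant, mel = 2595 · log10(1 + f/700), the forward map and its inverse look like this; other variants (e.g. Slaney's) give different values. Both function names are hypothetical.

```python
import math


def hz_to_mel_htk(f_hz: float) -> float:
    """HTK-style mel scale (assumed variant): 2595 * log10(1 + f / 700)."""
    return 2595.0 * math.log10(1.0 + f_hz / 700.0)


def mel_to_hz_htk(mel: float) -> float:
    """Inverse of hz_to_mel_htk: 700 * (10**(mel / 2595) - 1)."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


print(hz_to_mel_htk(1000.0))  # close to 1000 mel by construction
```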

Module contents

class deepSTRF.models.NeuralModel(out_neurons: int = 1, *args, **kwargs)[source]

Bases: Module, ABC

Base class for encoding models of sensory neural responses.

A four-slot template defines the canonical forward pipeline:

def forward(self, x):
    x = self.wav2spec(x)        # raw-waveform front-end (future)
    x = self.prefiltering(x)    # AdapTrans / ICAdaptation / Identity
    f = self.core(x)            # shared feature backbone
    return self.readout(f)      # per-neuron projection (B, N, 1, T)

Concrete subclasses populate the slots in their __init__. The defaults for wav2spec, prefiltering and core are nn.Identity, so a minimal model only needs to provide a readout. Subclasses may override forward() for architectures that don’t fit the four-slot pipeline (e.g. StateNet’s recurrent reshape, Transformer’s per-frame attention).

See docs/_source/md/model_paradigm.md for the full contract.
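The control flow of the four-slot template can be mirrored without PyTorch. In this hypothetical sketch, plain callables stand in for the nn.Module slots and identity stands in for nn.Identity; FourSlotSketch is not a deepSTRF class.

```python
from typing import Callable, Optional


def identity(x):
    """Stand-in for nn.Identity."""
    return x


class FourSlotSketch:
    """Hypothetical mirror of the template: only the readout is mandatory."""

    def __init__(self, readout: Callable,
                 wav2spec: Optional[Callable] = None,
                 prefiltering: Optional[Callable] = None,
                 core: Optional[Callable] = None):
        # Unset slots default to the identity, like the nn.Identity defaults.
        self.wav2spec = identity if wav2spec is None else wav2spec
        self.prefiltering = identity if prefiltering is None else prefiltering
        self.core = identity if core is None else core
        self.readout = readout

    def forward(self, x):
        x = self.wav2spec(x)
        x = self.prefiltering(x)
        f = self.core(x)
        return self.readout(f)


# Minimal model: only a readout (here, a toy per-frame gain of 2).
model = FourSlotSketch(readout=lambda frames: [2 * v for v in frames])
print(model.forward([1, 2, 3]))  # [2, 4, 6]
```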

Parameters:

out_neurons (int, default 1) – Number of output neurons N the model predicts. Stored on self.O and used by validate() and STRF_gradmap.

Notes

Subclasses inherit a Hugging Face Hub interface for pretrained checkpoints: from_pretrained(), save_pretrained() and push_to_hub().

The init kwargs needed to rebuild the architecture are captured automatically on construction (see __init_subclass__()), so end users never have to supply a config dict: StateNet.from_pretrained("urancon/...") is sufficient.
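One way such automatic capture can be built is sketched below: __init_subclass__ wraps each subclass's own __init__ and records its bound arguments. CapturesInitKwargs, ToyModel and the _init_kwargs attribute are hypothetical; deepSTRF's actual implementation may differ, for example in how it handles values that are not JSON-serialisable.

```python
import functools
import inspect


class CapturesInitKwargs:
    """Hypothetical base class that records each instance's constructor arguments."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if "__init__" not in cls.__dict__:
            return  # no own __init__: the inherited (already wrapped) one is used
        original_init = cls.__init__
        signature = inspect.signature(original_init)

        @functools.wraps(original_init)
        def wrapped_init(self, *args, **kw):
            if not hasattr(self, "_init_kwargs"):  # record the outermost call only
                bound = signature.bind(self, *args, **kw)
                bound.apply_defaults()
                self._init_kwargs = {
                    name: value for name, value in bound.arguments.items() if name != "self"
                }
            original_init(self, *args, **kw)

        cls.__init__ = wrapped_init


class ToyModel(CapturesInitKwargs):
    def __init__(self, out_neurons: int = 1, hidden: int = 8):
        self.out_neurons, self.hidden = out_neurons, hidden


m = ToyModel(3)
print(m._init_kwargs)  # {'out_neurons': 3, 'hidden': 8}
clone = ToyModel(**m._init_kwargs)  # rebuilds the same architecture
```

Recording only the outermost call matters when a subclass calls super().__init__() with different arguments: the stored kwargs are the ones needed to rebuild the most-derived class.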

count_trainable_params()[source]

Count the model’s trainable parameters.

Returns:

Total number of parameters with requires_grad=True.

Return type:

int

detach()[source]

Detach stateful variables from the computational graph.

No-op by default; recurrent subclasses override this to truncate backpropagation-through-time between chunks (cf. spikingjelly).

forward(stimulus)[source]

Run the default template pipeline.

Applies wav2spec → prefiltering → core → readout in sequence.

Parameters:

stimulus (torch.Tensor) – Input stimulus batch (shape is modality-dependent; for audio models a spectrogram (B, F, T)).

Returns:

Predicted response of shape (B, N, 1, T).

Return type:

torch.Tensor

classmethod from_pretrained(repo_id_or_path: str | Path, *, extra_kwargs: dict | None = None, strict: bool = True, map_location: str | device = 'cpu', cache_dir: str | Path | None = None, revision: str | None = None, token: str | None = None, return_metadata: bool = False)[source]

Instantiate cls from an HF Hub repo or a local checkpoint folder.

Parameters:
  • repo_id_or_path (str or Path) – "<owner>/<name>" (HF Hub) or a path to a local folder produced by save_pretrained(). The choice is made with pathlib.Path.is_dir(): if the path is an existing local directory, it is loaded from disk; otherwise the string is treated as a Hub repo id.

  • extra_kwargs (dict, optional) – Override / extend the saved config. Use this to re-supply any __init__ argument that wasn’t JSON-serialisable at save time (e.g. a custom prefiltering module).

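  • strict (bool, default True) – Whether state_dict loading is strict, i.e. whether the checkpoint keys must match the model's keys exactly (as in torch.nn.Module.load_state_dict()).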

  • map_location (str or torch.device, default 'cpu') – Device onto which the checkpoint weights are loaded.

  • cache_dir – Forwarded to huggingface_hub.snapshot_download().

  • revision – Forwarded to huggingface_hub.snapshot_download().

  • token – Forwarded to huggingface_hub.snapshot_download().

  • return_metadata (bool, default False) – If True, return (model, metadata) instead of just model.
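The local-versus-Hub rule for repo_id_or_path fits in a few lines of pathlib. resolve_checkpoint_source is a hypothetical helper, not part of deepSTRF; it only illustrates the documented Path.is_dir() decision.

```python
from pathlib import Path
from typing import Tuple, Union


def resolve_checkpoint_source(repo_id_or_path: Union[str, Path]) -> Tuple[str, Union[Path, str]]:
    """Return ("local", Path) for an existing directory, else ("hub", repo_id)."""
    candidate = Path(repo_id_or_path)
    if candidate.is_dir():
        return "local", candidate  # an existing local folder takes precedence
    return "hub", str(repo_id_or_path)  # otherwise treat it as "<owner>/<name>"


print(resolve_checkpoint_source("."))
print(resolve_checkpoint_source("some-owner/some-model"))
```

One consequence of this rule is that a local directory whose relative path happens to match a repo id shadows the Hub repository of the same name.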

Returns:

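Instance of cls with weights loaded, or a (model, metadata) tuple if return_metadata=True.

Return type:

NeuralModel (an instance of cls), or tuple if return_metadata=True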

push_to_hub(repo_id: str, *, metadata: dict | None = None, model_card: str | None = None, private: bool = False, token: str | None = None, commit_message: str | None = None) → str[source]

Push this model to repo_id on the HF Hub.

Saves the checkpoint to a temporary folder and uploads it via deepSTRF.utils.hub.upload_pretrained(). The repo is created if it does not already exist; the user must be authenticated with write access (via hf auth login or the token= argument).

Returns:

URL of the resulting commit.

Return type:

str

save_pretrained(save_dir: str | Path, *, metadata: dict | None = None, model_card: str | None = None) → Path[source]

Write a checkpoint folder (config.json + model.safetensors).

See deepSTRF.utils.hub.save_pretrained_to_dir() for the full contract.

Parameters:
  • save_dir (str or pathlib.Path) – Destination folder; created if it does not exist.

  • metadata (dict, optional) – Extra JSON-serialisable metadata stored alongside the config (e.g. training dataset, val/test scores).

  • model_card (str, optional) – Markdown content written to README.md in the folder.

Returns:

Path to the written checkpoint folder.

Return type:

pathlib.Path

validate()[source]

Check that the instance is deepSTRF-compatible.

Subclasses should call super().validate() and then add their own checks (e.g. AudioEncodingModel checks F, T > 0).

Raises:

AssertionError – If self.O is not a positive int, if readout is unset or not a torch.nn.Module, or if any of the wav2spec / prefiltering / core slots is not a torch.nn.Module.
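The super().validate() chaining convention looks like this in a minimal, framework-free form. BaseSketch and AudioSketch are hypothetical stand-ins for NeuralModel and AudioEncodingModel that reproduce only the O > 0 and F, T > 0 checks.

```python
class BaseSketch:
    """Hypothetical stand-in for NeuralModel: checks only the output size O."""

    def __init__(self, out_neurons: int):
        self.O = out_neurons

    def validate(self) -> None:
        assert isinstance(self.O, int) and self.O > 0, "O must be a positive int"


class AudioSketch(BaseSketch):
    """Hypothetical audio subclass adding its own F, T > 0 check."""

    def __init__(self, out_neurons: int, F: int, T: int):
        super().__init__(out_neurons)
        self.F, self.T = F, T

    def validate(self) -> None:
        super().validate()  # base checks run first, then the subclass's own
        assert self.F > 0 and self.T > 0, "F and T must be positive"


AudioSketch(out_neurons=4, F=64, T=100).validate()  # passes silently
```

Because these checks use assert, they are skipped when Python runs with the -O flag.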