import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt
from deepSTRF.models.scales import mel_to_Hz, Hz_to_mel, Greenwood, inverse_Greenwood
def get_CFs(min_freq, max_freq, n_freqs, scale):
    """Return ``n_freqs`` cochlear/center frequencies (CFs) on a warped scale.

    The CFs are spaced uniformly on the chosen warped axis (mel or
    Greenwood) between ``min_freq`` and ``max_freq``. They are then mapped
    back to Hz with the matching inverse warping.

    Parameters
    ----------
    min_freq, max_freq : float
        Frequency range in Hz.
    n_freqs : int
        Number of CFs to return.
    scale : {'mel', 'greenwood'}
        Frequency-axis warping used to space the CFs.

    Returns
    -------
    torch.Tensor
        The ``n_freqs`` center frequencies in Hz.

    Raises
    ------
    NotImplementedError
        If ``scale`` is not ``'mel'`` or ``'greenwood'``.
    """
    if scale not in ('mel', 'greenwood'):
        raise NotImplementedError(f"'scale' argument should be 'mel' or 'greenwood', not '{scale}'")

    # (Hz -> warped, warped -> Hz) pair for the requested scale
    if scale == 'mel':
        to_warped, from_warped = Hz_to_mel, mel_to_Hz
    else:
        to_warped, from_warped = Greenwood, inverse_Greenwood

    lower = to_warped(torch.tensor(min_freq))
    upper = to_warped(torch.tensor(max_freq))
    warped_grid = torch.linspace(lower, upper, n_freqs)
    return from_warped(warped_grid)
def freq_to_tau(freqs):
    """Map frequencies (Hz) to midbrain-neuron time constants (ms).

    Implements the empirical mapping ``tau = 500 - 105 * log10(f)``.

    Parameters
    ----------
    freqs : torch.Tensor, array-like or float
        Frequencies in Hz (should be > 0). Non-tensor inputs are converted
        with ``torch.as_tensor``. Integer inputs are promoted to the default
        floating-point dtype. Floating-point tensors are used as-is (no
        copy), so any autograd history is preserved.

    Returns
    -------
    torch.Tensor
        Associated time constants in ms, with the same shape as ``freqs``.

    Notes
    -----
    The mapping decreases with frequency. It becomes non-positive above
    ``10 ** (500 / 105)`` ≈ 57.8 kHz, where it no longer describes a valid
    time constant.

    References
    ----------
    Willmore et al. (2016). "Incorporating Midbrain Adaptation to Mean Sound
    Level Improves Models of Auditory Cortical Processing."
    """
    freqs = torch.as_tensor(freqs)
    if not torch.is_floating_point(freqs):
        # log10 needs a floating-point input; promote ints explicitly
        freqs = freqs.to(torch.get_default_dtype())
    return 500. - 105. * torch.log10(freqs)
def tau_to_a(time_constants, dt: float = 1):
    """Convert physical time constants (ms) to dimensionless ``a`` parameters.

    ``a = exp(-dt / tau)`` is the per-time-step decay factor of an
    exponential with time constant ``tau``, sampled every ``dt`` ms.

    Parameters
    ----------
    time_constants : torch.Tensor, array-like or float
        Time constants in ms. Non-tensor inputs are converted with
        ``torch.as_tensor``. Integer inputs are promoted to the default
        floating-point dtype.
    dt : float, default 1
        Time-step width in ms.

    Returns
    -------
    torch.Tensor
        The corresponding ``a = exp(-dt / tau)`` parameters.
    """
    tau = torch.as_tensor(time_constants)
    if not torch.is_floating_point(tau):
        tau = tau.to(torch.get_default_dtype())
    return torch.exp(- dt / tau)
def a_to_tau(a, dt: float = 1):
    """Inverse of :func:`tau_to_a`: ``a`` parameter back to a time constant (ms).

    Computes ``tau = -dt / log(a)``.

    Parameters
    ----------
    a : torch.Tensor, array-like or float
        Dimensionless ``a`` parameters. Non-tensor inputs are converted with
        ``torch.as_tensor``. Integer inputs are promoted to the default
        floating-point dtype.
    dt : float, default 1
        Time-step width in ms.

    Returns
    -------
    torch.Tensor
        Time constants in ms.
    """
    a = torch.as_tensor(a)
    if not torch.is_floating_point(a):
        a = a.to(torch.get_default_dtype())
    return - dt / torch.log(a)
class ICAdaptation(nn.Module):
    """
    High-pass exponential filter with frequency-dependent time constants —
    a paper-faithful re-implementation of the inferior-colliculus
    adaptation prefilter described by Willmore et al. (2016).

    Independently filters each frequency band of an input spectrogram
    along the temporal dimension with a parameterized exponential kernel:

        kernel = [...; -Ca²; -Ca; -C; +1]   with   C = 1/(... + a² + a + 1)

    The normalization makes the negative terms sum to exactly -1. The
    filter therefore computes the difference between the current value
    of the signal in each frequency band and a truncated, normalized
    exponential average of its recent past. It then applies a half-wave
    rectification (ReLU).

    Parameters
    ----------
    init_a_vals : 1D Tensor (or array-like) of length ``F``
        Per-frequency ``a`` parameters (related to the exponential time
        constant: higher ``a`` → longer time constant). A detached copy
        is stored, so later changes to the caller's tensor do not affect
        the module.
    kernel_size : int, default 2
        Length of the temporal kernel (in frames). Must be >= 1.

    Raises
    ------
    ValueError
        If ``init_a_vals`` is not one-dimensional or ``kernel_size < 1``.

    References
    ----------
    Willmore, Schoppe, King, Schnupp, Harper (2016). "Incorporating
    Midbrain Adaptation to Mean Sound Level Improves Models of
    Auditory Cortical Processing." J. Neurosci. 36(2): 280–289.
    https://doi.org/10.1523/JNEUROSCI.2441-15.2016

    Notes
    -----
    Intentionally non-learnable: the time constants are derived
    analytically from the cochlear frequency map (see
    ``freq_to_tau``) and are paper-faithful. For a learnable
    extension, use :class:`AdapTrans`.

    Input shape: ``(B, 1, F, T)``.
    Output shape: ``(B, 1, F, T)``.
    """

    out_channels: int = 1

    def __init__(self, init_a_vals, kernel_size: int = 2):
        super().__init__()
        a = torch.as_tensor(init_a_vals)
        if a.dim() != 1:
            raise ValueError(
                f"'init_a_vals' must be one-dimensional, got shape {tuple(a.shape)}"
            )
        if kernel_size < 1:
            raise ValueError(f"'kernel_size' must be >= 1, got {kernel_size}")
        if not torch.is_floating_point(a):
            a = a.to(torch.get_default_dtype())

        # general attributes
        self.F = a.shape[0]
        self.K = kernel_size
        # Frozen, paper-faithful (Willmore et al. 2016). Registered as a
        # buffer so it follows .to(device)/.double() but isn't trained.
        # Detached copy: never alias the caller's tensor.
        self.register_buffer('a', a.detach().clone())

    def build_kernels(self):
        """
        Create the per-band high-pass exponential kernels that highlight
        onsets in the signal.

        Returns
        -------
        torch.Tensor
            Kernels of shape ``(F, 1, K)``, in the dtype/device of ``self.a``.
            The last tap is +1; the preceding ``K-1`` taps are
            ``-C * a**k`` (most recent past first), summing to -1.
        """
        a = self.a
        # a**0 .. a**(K-2) for every band: shape (F, K-1)
        powers = a.unsqueeze(1) ** torch.arange(self.K - 1, device=a.device)
        # normalization so the exponential tail sums to exactly 1
        C = 1 / powers.sum(dim=1, keepdim=True)
        kernel = torch.ones(self.F, 1, self.K, device=a.device, dtype=a.dtype)
        kernel[:, 0, 1:] = -C * powers
        # flip so the +1 tap is aligned with the current (latest) sample
        return torch.flip(kernel, dims=(2,))

    def forward(self, spectro_in):
        """High-pass filter and half-wave rectify each frequency band.

        Parameters
        ----------
        spectro_in : torch.Tensor
            Input spectrogram of shape ``(B, 1, F, T)`` (B=batch, 1 channel,
            F frequency bands, T timesteps).

        Returns
        -------
        torch.Tensor
            High-pass-filtered, half-wave-rectified spectrogram of shape
            ``(B, 1, F, T)``, in the dtype of ``spectro_in``.
        """
        # single-channel 2D representation -> multi-channel 1D: (B, 1, F, T) -> (B, F, T)
        signal = spectro_in.squeeze(1)
        # match the input dtype so e.g. float64 inputs work after module.double()
        kernel = self.build_kernels().to(dtype=signal.dtype)
        # causal filtering: replicate the first frame K-1 times on the left
        signal = F.pad(signal, pad=(self.K - 1, 0), mode='replicate')  # (B, F, T+K-1)
        out = F.conv1d(signal, kernel, stride=1, groups=self.F)        # (B, F, T)
        # half-wave rectification, then back to (B, 1, F, T)
        return torch.relu(out).unsqueeze(1)

    def plot_kernels(self, frequency_bin=0):
        """Stem-plot the temporal kernel of one frequency band (uses matplotlib)."""
        taps = self.build_kernels()[frequency_bin, 0].detach().cpu().numpy()
        plt.figure()
        plt.stem(torch.arange(self.K).numpy(), taps, 'r', markerfmt='ro', label='ICAdaptation')
        plt.legend()
        plt.show()
# Deprecated alias: older notebooks refer to ICAdaptation by its original
# name, so that name is kept for backward compatibility. It is scheduled for
# removal in a future release; new code should use ICAdaptation directly.
Willmore_Adaptation = ICAdaptation
class AdapTrans(nn.Module):
    """
    Adaptive ON/OFF spectrogram prefilter — the learnable extension of the
    inferior-colliculus adaptation prefilter.

    Computes ON and OFF spectrograms through high-pass exponential
    filters with frequency-dependent, learnable time constants. Each
    frequency band is independently filtered along the temporal
    dimension with parameterized exponential kernels:

        kernel_ON  = [...; -Cwa²; -Cwa; -Cw; +1]   with   C = 1/(... + a² + a + 1)
        kernel_OFF = [...;  +Ca²;  +Ca;  +C; -w]

    In the ON kernel the negative terms sum to ``-w``. The ON filter
    computes the difference between the current value of the signal in
    each frequency band and a ``w``-weighted exponential average of its
    recent past. The OFF filter has the opposite polarity. ON and OFF
    each have their own ``a`` (time-constant) parameter per band, but
    share the same ``w``.

    Internally the parameters are stored as ``d`` and ``p`` with
    ``a = 1/(1+d²)`` and ``w = 1/(1+p²)``, which keeps ``a`` and ``w``
    in (0, 1] during training.

    Parameters
    ----------
    init_a_vals : 1D Tensor (or array-like) of length ``F``
        Per-frequency ``a`` parameters (related to the time constant: the
        higher ``a``, the longer the time constant). Must lie in (0, 1].
        Used to initialize both the ON and the OFF time constants.
    init_w_vals : 1D Tensor (or array-like) of length ``F``
        Per-frequency ``w`` parameters (relative weight of the past
        average vs the present sample). Must lie in (0, 1].
    kernel_size : int, default 2
        Length of the temporal kernel (in frames). Must be >= 1.
    learnable : bool, default True
        If True, ``a`` and ``w`` are learnable nn.Parameters; if False,
        they are frozen buffers (still follow ``.to(device)``).

    Raises
    ------
    ValueError
        If the initial values have mismatched or non-1D shapes, lie
        outside (0, 1], or if ``kernel_size < 1``.

    References
    ----------
    Rançon, Masquelier & Cottereau (2024). "A general model unifying
    the adaptive, transient and sustained properties of ON and OFF
    auditory neural responses." PLOS Computational Biology
    20(8):e1012288. https://doi.org/10.1371/journal.pcbi.1012288

    Notes
    -----
    Input shape: ``(B, 1, F, T)``.
    Output shape: ``(B, 2, F, T)`` — channel 0 is ON, channel 1 is OFF.
    """

    out_channels: int = 2

    def __init__(self, init_a_vals, init_w_vals, kernel_size: int = 2, learnable: bool = True):
        super().__init__()
        a = torch.as_tensor(init_a_vals)
        w = torch.as_tensor(init_w_vals)
        if a.shape != w.shape:
            raise ValueError(
                f"'init_a_vals' and 'init_w_vals' must have the same shape, "
                f"got {tuple(a.shape)} and {tuple(w.shape)}"
            )
        if a.dim() != 1:
            raise ValueError(
                f"'init_a_vals' and 'init_w_vals' must be one-dimensional, got shape {tuple(a.shape)}"
            )
        if kernel_size < 1:
            raise ValueError(f"'kernel_size' must be >= 1, got {kernel_size}")
        if not torch.is_floating_point(a):
            a = a.to(torch.get_default_dtype())
        if not torch.is_floating_point(w):
            w = w.to(torch.get_default_dtype())

        # general attributes
        self.F = a.shape[0]
        self.K = kernel_size

        # Inverse of the reparametrization a = 1/(1+d²), w = 1/(1+p²).
        # Values outside (0, 1] give inf/NaN here, so reject them up front.
        init_d_vals = torch.sqrt(1 / a - 1).detach()
        init_p_vals = torch.sqrt(1 / w - 1).detach()
        if not torch.isfinite(init_d_vals).all():
            raise ValueError("'init_a_vals' must lie in (0, 1]")
        if not torch.isfinite(init_p_vals).all():
            raise ValueError("'init_w_vals' must lie in (0, 1]")

        # Parameters (one per freq.): Parameter when learnable, buffer otherwise.
        # Both follow the module's .to(device); raw tensor attributes do not.
        if learnable:
            self.d_on = Parameter(init_d_vals.clone())
            self.d_off = Parameter(init_d_vals.clone())
            self.p = Parameter(init_p_vals.clone())
        else:
            self.register_buffer('d_on', init_d_vals.clone())
            self.register_buffer('d_off', init_d_vals.clone())
            self.register_buffer('p', init_p_vals.clone())

    def build_kernels(self):
        """
        Create the two parametrized kernels, each of shape ``(F, 1, K)``:

        - the ON kernel, highlighting onsets in the signal;
        - the OFF kernel (offsets): the ON kernel negated, with its tail
          rescaled by ``1/w`` and its last tap set to ``-w``.

        Returns
        -------
        tuple of torch.Tensor
            ``(kernel_ON, kernel_OFF)``.
        """
        kernel_ON = self.ON_kernel(self.d_on, self.p)
        kernel_OFF = self.OFF_kernel(self.d_off, self.p)
        return kernel_ON, kernel_OFF

    def ON_kernel(self, d, p):
        """
        Create the ON kernel ``[...; -Cwa²; -Cwa; -Cw; +1]`` of shape ``(F, 1, K)``
        from the reparametrized ``d`` (→ ``a``) and ``p`` (→ ``w``).
        """
        a = 1 / (1 + d.pow(2))
        w = 1 / (1 + p.pow(2))
        dtype = torch.promote_types(a.dtype, w.dtype)
        # a**0 .. a**(K-2) for every band: shape (F, K-1)
        powers = a.unsqueeze(1) ** torch.arange(self.K - 1, device=d.device)
        # normalization so the exponential tail sums to exactly 1 (before weighting by w)
        C = 1 / powers.sum(dim=1, keepdim=True)
        kernel_ON = torch.ones(self.F, 1, self.K, device=d.device, dtype=dtype)
        kernel_ON[:, 0, 1:] = -(C * w.unsqueeze(1)) * powers
        # flip so the +1 tap is aligned with the current (latest) sample
        return torch.flip(kernel_ON, dims=(2,))

    def OFF_kernel(self, d, p):
        """
        Create the OFF kernel ``[...; +Ca²; +Ca; +C; -w]`` of shape ``(F, 1, K)``:
        the ON kernel negated and divided by ``w`` (so its exponential part
        sums to +1), with the final tap replaced by ``-w``.
        """
        kernel_ON = self.ON_kernel(d, p)
        w = (1 / (1 + p.pow(2))).to(kernel_ON.dtype).view(-1, 1, 1)
        tail = -kernel_ON[..., :-1] / w
        return torch.cat([tail, -w], dim=-1)

    def forward(self, spectro_in):
        """Compute the ON and OFF high-pass-filtered spectrograms.

        Parameters
        ----------
        spectro_in : torch.Tensor
            Input spectrogram of shape ``(B, 1, F, T)`` (B=batch, 1 channel,
            F frequency bands, T timesteps).

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(B, 2, F, T)`` in the dtype of ``spectro_in``:
            channel 0 is the ON response, channel 1 is the OFF response (both
            half-wave rectified).
        """
        # single-channel 2D representation -> multi-channel 1D: (B, 1, F, T) -> (B, F, T)
        signal = spectro_in.squeeze(1)
        # kernels live on the parameters' device; cast to the input dtype so
        # that e.g. float64 inputs work after module.double()
        kernel_ON, kernel_OFF = self.build_kernels()
        kernel_ON = kernel_ON.to(dtype=signal.dtype)
        kernel_OFF = kernel_OFF.to(dtype=signal.dtype)
        # causal filtering: replicate the first frame K-1 times on the left
        signal = F.pad(signal, pad=(self.K - 1, 0), mode='replicate')             # (B, F, T+K-1)
        out_ON = F.conv1d(signal, kernel_ON, stride=1, groups=self.F)            # (B, F, T)
        out_OFF = F.conv1d(signal, kernel_OFF, stride=1, groups=self.F)          # (B, F, T)
        # back to a 2D representation with ON/OFF as channels, then half-wave rectify
        return torch.relu(torch.stack([out_ON, out_OFF], dim=1))                 # (B, 2, F, T)

    def get_a(self):
        """Return ``(a_on, a_off)`` as detached CPU tensors of length ``F``."""
        a_on = 1 / (1 + self.d_on.cpu().detach().pow(2))
        a_off = 1 / (1 + self.d_off.cpu().detach().pow(2))
        return a_on, a_off

    def get_w(self):
        """Return the shared ``w`` parameters as a detached CPU tensor of length ``F``."""
        return 1 / (1 + self.p.cpu().detach().pow(2))

    def plot_kernels(self, frequency_bin=0):
        """Stem-plot the ON and OFF kernels of one frequency band (uses matplotlib)."""
        kernel_ON, kernel_OFF = self.build_kernels()
        on_taps = kernel_ON[frequency_bin, 0].detach().cpu().numpy()
        off_taps = kernel_OFF[frequency_bin, 0].detach().cpu().numpy()
        lags = torch.arange(self.K).numpy()
        plt.figure()
        plt.stem(lags, on_taps, 'r', markerfmt='ro', label='ON')
        plt.stem(lags, off_taps, 'b', markerfmt='bo', label='OFF')
        plt.legend()
        plt.show()
def make_prefiltering(kind: str, n_frequency_bands: int, dt: float,
                      min_freq: float = 500.0, max_freq: float = 20000.0,
                      scale: str = 'mel', learnable: bool = True,
                      init_w: float = 0.75) -> nn.Module:
    """
    Factory for constructing a prefilter module from compact arguments.

    Convenience wrapper that derives per-frequency ``a`` (and ``w``)
    initial values from the cochlear frequency map, then instantiates
    the requested prefilter class. Equivalent to building the prefilter
    by hand; the factory exists so that user code does not need to
    repeat the ``get_CFs`` / ``freq_to_tau`` / ``tau_to_a`` pipeline.

    The temporal kernel length is chosen to span about three of the
    longest time constants. Because the time constants are in ms and the
    kernel is indexed in time bins of width ``dt`` ms, this is
    ``K = round(3 * max(tau) / dt) + 1`` frames.

    Parameters
    ----------
    kind : {'adaptrans', 'icadaptation', 'willmore'}
        Which prefilter to build (case-insensitive). ``'willmore'`` is an
        alias for ``'icadaptation'``.
    n_frequency_bands : int
        Number of input frequency bands ``F`` of the spectrogram (>= 1).
    dt : float
        Time bin width in milliseconds (matches ``dataset.dt_ms``); > 0.
    min_freq, max_freq : float
        Frequency range (in Hz) spanned by the cochlear filterbank that
        produced the spectrogram. Defaults: 500 / 20 000 Hz.
    scale : {'mel', 'greenwood'}, default 'mel'
        Frequency-axis scaling used to derive per-band time constants.
    learnable : bool, default True
        Only relevant for ``'adaptrans'``. ``ICAdaptation`` is always
        frozen (paper-faithful).
    init_w : float, default 0.75
        Only relevant for ``'adaptrans'``: initial value of the past-vs-
        present weight ``w``.

    Returns
    -------
    nn.Module
        Configured prefilter instance with an ``out_channels`` attribute.

    Raises
    ------
    ValueError
        If ``kind`` is unknown, ``n_frequency_bands < 1``, ``dt <= 0``, or
        the frequency range yields non-positive time constants.
    """
    # validate cheap arguments before doing any work
    kind = kind.lower()
    if kind not in ('adaptrans', 'icadaptation', 'willmore'):
        raise ValueError(
            f"Unknown prefilter kind {kind!r}. Currently supported: "
            f"'adaptrans', 'icadaptation' (alias 'willmore')."
        )
    if n_frequency_bands < 1:
        raise ValueError(f"'n_frequency_bands' must be >= 1, got {n_frequency_bands}")
    if dt <= 0:
        raise ValueError(f"'dt' must be a positive time-bin width in ms, got {dt}")

    cf = get_CFs(min_freq, max_freq, n_frequency_bands, scale)
    tau = freq_to_tau(cf)  # ms
    if (tau <= 0).any():
        # tau = 500 - 105*log10(f) reaches 0 near 57.8 kHz; beyond that a > 1 (unstable)
        raise ValueError(
            "The requested frequency range yields non-positive time constants; "
            "lower 'max_freq'."
        )
    a = tau_to_a(tau, dt=dt)
    # ~3 time constants of the slowest band, converted from ms to frames
    K = round(3 * tau.max().item() / dt) + 1

    if kind == 'adaptrans':
        w = torch.full_like(a, init_w)
        return AdapTrans(init_a_vals=a, init_w_vals=w, kernel_size=K, learnable=learnable)
    # 'icadaptation' or its alias 'willmore'
    return ICAdaptation(init_a_vals=a, kernel_size=K)