Fitting NS1 from raw waveform: causal mel, SincNet, ICNet
deepSTRF accepts raw audio waveforms in addition to precomputed spectrograms. This notebook walks through the three shipped wav2spec front-ends on NS1:
``CausalMelSpectrogram``: non-learnable causal log-mel with the Rahman 2019 cochleagram defaults (10 ms Hanning, 500–22 627 Hz, amplitude, threshold-clipped log). The pipeline-validation baseline.
``SincNet``: parametric bandpass filterbank (Ravanelli & Bengio) with ``envelope=True`` to make it a proper cochleagram. The learnable spectrogram.
``ICNet``: full encoder + decoder model from Drakopoulos et al. (Nat. Mach. Intell. 2025), ported to deepSTRF with auto-adapted strides for NS1’s 48 kHz / 5 ms binning. Paper-faithful single-branch Poisson-head variant.
The data side is just ``NS1Dataset(return_waveform=True)``: the dataset returns mono float tensors of shape ``(1, T_audio)``, with ``T_audio = 239 760`` samples at 48 kHz, aligned to the existing 999-bin neural response grid. See `wav2spec.md <../docs/_source/md/wav2spec.md>`__ for the slot contract.
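The waveform length follows from the binning arithmetic. The sketch below uses only numbers stated in this notebook (48 kHz audio, 5 ms bins, 999 response bins); the variable names are illustrative and are not part of the deepSTRF API.

```python
# Binning arithmetic for NS1 in waveform mode (values taken from the text above).
audio_fs = 48_000      # audio sampling rate, Hz
dt_ms = 5              # neural bin width, ms
n_bins = 999           # length of the neural response grid

samples_per_bin = audio_fs * dt_ms // 1000   # audio samples per neural bin
t_audio = n_bins * samples_per_bin           # waveform length in samples

print(samples_per_bin, t_audio)
```

A front-end that emits one frame per ``samples_per_bin`` audio samples therefore produces exactly one spectrogram column per neural bin.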
Setup — Google Colab
If you’re running on Google Colab, install deepSTRF from source. On a local install (pip install -e .) the cell is a no-op.
[ ]:
import sys
if 'google.colab' in sys.modules:
    !pip install -q git+https://github.com/urancon/deepSTRF.git
    print('deepSTRF installed from GitHub.')
else:
    print('Local environment — assuming deepSTRF is already importable.')
Imports
[ ]:
%matplotlib inline
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset
from deepSTRF.datasets.audio.ns1 import NS1Dataset
from deepSTRF.metrics import poisson_loss
from deepSTRF.models.audio import Linear, ICNet
from deepSTRF.models.wav2spec import CausalMelSpectrogram, SincNet
from deepSTRF.training import Fitter
from deepSTRF.utils import neural_collate, compare_wav2spec_to_groundtruth
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {DEVICE}')
1. Load NS1 in waveform mode
We instantiate the dataset twice — once in waveform mode (for the wav-input models) and once in default spectrogram mode (as the ground-truth spec for visual comparison + the spec-side baseline training arm). Responses are bit-identical between the two; only the self.stims representation differs.
[ ]:
ds_wav = NS1Dataset(return_waveform=True, download=True)
ds_spec = NS1Dataset(download=True)
N = ds_wav.N_neurons
samples_per_bin = ds_wav.hop # audio samples per neural bin = audio_fs * dt_ms / 1000
T_neural = ds_wav.stims[0].shape[-1] // samples_per_bin
print(f'NS1: N={N} cells | audio {ds_wav.audio_fs} Hz | T_audio={ds_wav.stims[0].shape[-1]} '
f'({samples_per_bin} samples per {ds_wav.dt:.0f} ms bin) | T_neural={T_neural} bins of {ds_wav.dt:.0f} ms')
print(f'wav stim 0 shape: {tuple(ds_wav.stims[0].shape)}')
print(f'spec stim 0 shape: {tuple(ds_spec.stims[0].shape)}')
2. Visual sanity: Rahman causal mel vs ground-truth
compare_wav2spec_to_groundtruth returns the wav2spec output, the precomputed Rahman cochleagram (X_nfht), and a 3-panel figure (pred | truth | difference, all z-scored). With the Rahman-tuned defaults the qualitative match is strong — typical stims correlate 0.7–0.85 with the ground truth (mean ≈ 0.66 across the 20 stims).
[ ]:
mel = CausalMelSpectrogram(audio_fs=ds_wav.audio_fs) # Rahman defaults
for stim_idx in (0, 6, 8, 12):
    pred, truth, fig = compare_wav2spec_to_groundtruth(
        ds_wav, mel, stim_idx=stim_idx,
        ground_truth_stims=ds_spec.stims,
        suptitle=f'NS1 stim {stim_idx} ({ds_spec.stim_meta[stim_idx]["type"]})'
    )
    r = np.corrcoef(pred.ravel(), truth.ravel())[0, 1]
    print(f'stim {stim_idx} ({ds_spec.stim_meta[stim_idx]["type"]}): pred-vs-truth r = {r:.3f}')
    plt.show()
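To make "causal" concrete, here is a minimal NumPy sketch of causal framing followed by a threshold-clipped log. It assumes a 10 ms window (480 samples at 48 kHz), a hop of one 5 ms bin (240 samples), and left zero-padding so that frame ``k`` ends exactly at the end of bin ``k``. The helpers ``causal_frames`` and ``clipped_log`` and the threshold value are hypothetical; the internals of ``CausalMelSpectrogram`` may differ, and the mel projection is left out.

```python
import numpy as np

def causal_frames(wav, win, hop):
    """Slice a 1-D waveform into frames; frame k ends at sample (k + 1) * hop."""
    # Left-pad so the first frame already ends at the end of bin 0.
    padded = np.concatenate([np.zeros(win - hop, dtype=wav.dtype), wav])
    n_frames = len(wav) // hop
    starts = hop * np.arange(n_frames)
    idx = starts[:, None] + np.arange(win)[None, :]
    return padded[idx]                      # shape (n_frames, win)

def clipped_log(x, threshold=1e-5):
    """Log with a floor, so silent bins do not map to -inf (threshold is an assumption)."""
    return np.log(np.maximum(x, threshold))

rng = np.random.default_rng(0)
wav = rng.standard_normal(239_760)          # stand-in for one NS1 waveform
frames = causal_frames(wav, win=480, hop=240)
amplitude = np.abs(np.fft.rfft(frames * np.hanning(480), axis=-1))
log_spec = clipped_log(amplitude)
print(frames.shape, log_spec.shape)
```

Because each frame only looks backwards from the end of its bin, changing the audio after bin ``k`` leaves frames ``0..k`` untouched, which is the property a causal front-end needs when predicting responses bin by bin.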
3. A small fit-and-report helper
Same split (14 / 3 / 3 by stim index) and patience (30) across arms. Default loss is MSE; ICNet passes poisson_loss to match its non-negative softplus output.
[ ]:
def fit_and_report(ds, model, label, *, max_epochs=100, lr=1e-3,
                   loss_fn=None, patience=30):
    train = DataLoader(Subset(ds, list(range(14))), batch_size=1, shuffle=True, collate_fn=neural_collate)
    val = DataLoader(Subset(ds, list(range(14, 17))), batch_size=1, shuffle=False, collate_fn=neural_collate)
    test = DataLoader(Subset(ds, list(range(17, 20))), batch_size=1, shuffle=False, collate_fn=neural_collate)
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)
    kwargs = {'loss_fn': loss_fn} if loss_fn is not None else {}
    fitter = Fitter(model, train, val, optimizer=optim, device=DEVICE,
                    max_epochs=max_epochs, patience=patience,
                    monitor='val_cc_norm', mode='max',
                    log_fn=lambda d: None, **kwargs)
    t0 = time.time()
    history = fitter.fit()
    elapsed = time.time() - t0
    cc_norm = fitter.evaluate(test)['cc_norm'].cpu()
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return dict(label=label, cc_norm=cc_norm,
                mean=float(cc_norm.mean()), median=float(cc_norm.median()),
                n_params=n_params, elapsed=elapsed, epochs=len(history))
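The 14 / 3 / 3 split and the patience rule can be illustrated without any deepSTRF code. The sketch below is a generic early-stopping rule written for this explanation; ``early_stopping_epoch`` is a hypothetical helper, and the actual ``Fitter`` logic may differ in its details (for example in how ties are handled).

```python
def early_stopping_epoch(val_scores, patience=30, mode='max'):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation scores."""
    best, best_epoch = None, -1
    for epoch, score in enumerate(val_scores):
        if mode == 'max':
            improved = best is None or score > best
        else:
            improved = best is None or score < best
        if improved:
            best, best_epoch = score, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` consecutive epochs: stop here.
            return epoch, best_epoch
    return len(val_scores) - 1, best_epoch

# Split by stimulus index, as in the helper above.
train_idx = list(range(14))
val_idx = list(range(14, 17))
test_idx = list(range(17, 20))

# Toy validation curve: peaks at epoch 2, then stalls.
stop, best = early_stopping_epoch([0.1, 0.2, 0.3, 0.25, 0.24, 0.26, 0.2], patience=3)
print(stop, best)
```

With ``patience=30`` and ``max_epochs=100``, a run whose validation ``cc_norm`` peaks early can stop well before the epoch cap, which is why the reported epoch counts differ across arms.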
4. Three Linear arms: spec, wav+mel, wav+sincnet
Each arm uses the same Linear(F=34, T_strf=9, N) readout — only the input front-end changes.
[ ]:
results = []
F_bands, T_strf = 34, 9
# A: spec input baseline (default wav2spec=Identity)
torch.manual_seed(0)
m_a = Linear(n_frequency_bands=F_bands, temporal_window_size=T_strf, out_neurons=N)
results.append(fit_and_report(ds_spec, m_a, 'spec (baseline)'))
# B: wav input + Rahman causal mel (defaults)
torch.manual_seed(0)
m_b = Linear(n_frequency_bands=F_bands, temporal_window_size=T_strf, out_neurons=N,
wav2spec=CausalMelSpectrogram(audio_fs=ds_wav.audio_fs))
results.append(fit_and_report(ds_wav, m_b, 'wav + causal mel'))
# C: wav input + SincNet (envelope, mel-init)
torch.manual_seed(0)
m_c = Linear(n_frequency_bands=F_bands, temporal_window_size=T_strf, out_neurons=N,
wav2spec=SincNet(audio_fs=ds_wav.audio_fs, n_filters=F_bands,
kernel_size=753, hop_ms=ds_wav.dt,
f_min=500.0, f_max=22627.0,
init='mel', activation='logabs',
envelope=True, env_window_ms=10.0))
results.append(fit_and_report(ds_wav, m_c, 'wav + sincnet (env)'))
for r in results:
    print(f" {r['label']:22s} mean cc_norm = {r['mean']:.4f} median = {r['median']:.4f} "
          f"params = {r['n_params']:>7,d} epochs = {r['epochs']:>3d} {r['elapsed']:.0f}s")
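For readers unfamiliar with the idea behind a parametric bandpass filterbank: each filter is the difference of two windowed sinc lowpass kernels, so only its two cutoffs need to be learned. The NumPy sketch below builds one such kernel with the same ``kernel_size=753`` and 48 kHz rate used in arm C. The function names, the example cutoffs (2 kHz and 4 kHz), and the choice of a Hamming window are assumptions for illustration, not the deepSTRF ``SincNet`` implementation.

```python
import numpy as np

def sinc_bandpass(f1, f2, fs, kernel_size):
    """Windowed-sinc bandpass kernel passing f1..f2 (Hz) at sampling rate fs."""
    n = np.arange(kernel_size) - (kernel_size - 1) / 2   # sample index centred on 0
    low = (2 * f1 / fs) * np.sinc(2 * f1 * n / fs)       # lowpass with cutoff f1
    high = (2 * f2 / fs) * np.sinc(2 * f2 * n / fs)      # lowpass with cutoff f2
    return (high - low) * np.hamming(kernel_size)        # bandpass = difference

def gain_at(h, freq, fs, n_fft=1 << 15):
    """Magnitude response of kernel h at the FFT bin nearest to freq (Hz)."""
    spectrum = np.abs(np.fft.rfft(h, n=n_fft))
    freqs = np.fft.rfftfreq(n_fft, d=1 / fs)
    return spectrum[np.argmin(np.abs(freqs - freq))]

fs = 48_000
h = sinc_bandpass(2_000.0, 4_000.0, fs, kernel_size=753)
for f in (500, 3_000, 10_000):
    print(f, round(float(gain_at(h, f, fs)), 3))
```

The gain is close to 1 inside the passband and close to 0 outside it. Taking the magnitude of the filtered signal and smoothing it over a short window (the ``envelope=True`` option above) then turns each band into an envelope, which is what makes the output resemble a cochleagram.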
5. ICNet (Drakopoulos et al. 2025) on NS1
ICNet is a much deeper model (5.1 M params) and was trained in the paper on midbrain (IC) data in gerbils — not cortex (A1) in ferrets like NS1. Two paper-faithful hyperparameters that matter on NS1:
``lr=4e-4`` (the paper’s value; ``1e-3`` makes training unstable here).
``poisson_loss`` (the model’s softplus output is a non-negative rate; Poisson NLL is the appropriate loss).
Training takes ~15 minutes on a small GPU.
[ ]:
torch.manual_seed(0)
m_icnet = ICNet(audio_fs=ds_wav.audio_fs, out_neurons=N, dt_ms=ds_wav.dt)
print(f'ICNet on NS1: strides={m_icnet.wav2spec.encoder_strides} '
f'params={sum(p.numel() for p in m_icnet.parameters()):,}')
results.append(fit_and_report(ds_wav, m_icnet, 'wav + ICNet',
max_epochs=100, lr=4e-4,
loss_fn=poisson_loss))
print()
print(f'{"front-end":<22s} {"mean cc_norm":>12s} {"median":>8s} {"params":>9s} {"epochs":>6s} {"time":>5s}')
for r in results:
    print(f'{r["label"]:<22s} {r["mean"]:>12.4f} {r["median"]:>8.4f} '
          f'{r["n_params"]:>9,d} {r["epochs"]:>6d} {r["elapsed"]:>4.0f}s')
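The pairing of a softplus output with a Poisson loss can be checked on toy numbers. The sketch below is plain Python and drops the ``log(y!)`` term, which does not depend on the prediction. ``softplus`` and ``poisson_nll`` are illustrative helpers, not deepSTRF's ``poisson_loss``, whose exact reduction and epsilon handling are not shown in this notebook.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x)); always strictly positive."""
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)

def poisson_nll(rates, counts, eps=1e-8):
    """Mean Poisson negative log-likelihood, up to a constant in the counts."""
    return sum(r - y * math.log(r + eps) for r, y in zip(rates, counts)) / len(rates)

counts = [2.0]
for rate in (1.0, 2.0, 3.0):
    print(rate, round(poisson_nll([rate], counts), 4))
```

The loss is lowest when the predicted rate equals the observed count, and it is only defined for positive rates, which is why a softplus (rather than an unconstrained linear) output is the natural match.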
6. What did SincNet learn?
Plot the SincNet cutoffs before vs after training. The cutoffs barely move during NS1 fitting: gradients into f1/f2 are tiny relative to the cutoff magnitudes, so SincNet effectively acts as a fixed mel-spaced bandpass filterbank and the downstream model does the representational work.
[ ]:
# Re-init a fresh SincNet for the initial-cutoff baseline
torch.manual_seed(0)
sn_init = SincNet(audio_fs=ds_wav.audio_fs, n_filters=F_bands, kernel_size=753, hop_ms=ds_wav.dt,
f_min=500.0, f_max=22627.0, init='mel', activation='logabs',
envelope=True, env_window_ms=10.0)
sn_trained = m_c.wav2spec
with torch.no_grad():
    f1_i, f2_i = sn_init.f1.cpu().numpy(), sn_init.f2.cpu().numpy()
    f1_t, f2_t = sn_trained.f1.cpu().numpy(), sn_trained.f2.cpu().numpy()
fig, ax = plt.subplots(figsize=(7, 4))
idx = np.arange(len(f1_i))
ax.fill_between(idx, f1_i, f2_i, alpha=0.3, label='init passband')
ax.plot(idx, (f1_i + f2_i) / 2, 'o--', ms=3, label='init centre')
ax.fill_between(idx, f1_t, f2_t, alpha=0.3, label='trained passband', color='C1')
ax.plot(idx, (f1_t + f2_t) / 2, 'x-', ms=4, label='trained centre', color='C1')
ax.set_yscale('log')
ax.set_xlabel('filter index')
ax.set_ylabel('frequency (Hz, log)')
ax.set_title('SincNet cutoffs on NS1: init vs trained')
ax.legend()
ax.grid(True, which='both', alpha=0.3)
plt.tight_layout()
plt.show()
rel_drift = np.mean(np.abs(f1_t - f1_i) / np.abs(f1_i + 1e-9))
print(f'mean relative drift of f1 cutoffs: {rel_drift:.2e} (typically < 1e-3 — cutoffs barely move)')
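For reference, a mel-spaced layout like the one the cutoffs stay close to can be approximated in a few lines. The sketch below assumes the HTK mel formula and contiguous, non-overlapping bands between 500 Hz and 22 627 Hz; how ``init='mel'`` actually places and overlaps its bands in deepSTRF is not specified here, so ``mel_band_edges`` is a hypothetical helper.

```python
import math

def hz_to_mel(f):
    """HTK-style mel scale (an assumption; other mel variants exist)."""
    return 2595.0 * math.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_band_edges(f_min, f_max, n_filters):
    """Lower and upper cutoffs of n_filters bands equally spaced on the mel scale."""
    m_lo, m_hi = hz_to_mel(f_min), hz_to_mel(f_max)
    edges = [mel_to_hz(m_lo + (m_hi - m_lo) * i / n_filters)
             for i in range(n_filters + 1)]
    return edges[:-1], edges[1:]

f1, f2 = mel_band_edges(500.0, 22_627.0, 34)
print(len(f1), round(f1[0], 1), round(f2[-1], 1))
```

Bands are narrow at low frequencies and widen towards the top of the range, which matches the shape of the passband plot on a log frequency axis.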
Takeaways
Numbers from a single-seed run on a GTX 1650 (your numbers may vary slightly):
| front-end | mean test cc_norm | params |
|---|---|---|
| spec (X_nfht baseline) | 0.548 | 37 k |
| wav + Rahman causal mel | 0.573 | 37 k |
| wav + SincNet (envelope) | 0.340 | 37 k |
| wav + ICNet (Poisson) | 0.659 | 5.1 M |
Causal mel from wav matches or slightly beats the precomputed-spec baseline (0.573 vs 0.548). This is consistent with the ``wav2spec`` slot mechanics and the Rahman defaults being correct.
SincNet underperforms fixed mel when paired with a thin Linear readout (0.340 vs 0.573). Inspecting the learned cutoffs (section 6) shows they barely move during NS1 training: gradients to f1/f2 are tiny relative to the cutoff magnitudes, so SincNet effectively acts as a fixed mel-spaced bandpass filterbank.
ICNet’s deep conv stack is what makes the difference. With the paper’s Poisson head and ``lr=4e-4`` it reaches test cc_norm 0.66 on NS1 cortex, even though it was designed for gerbil IC.
See `wav2spec.md <../docs/_source/md/wav2spec.md>`__ for the slot contract and how to write your own front-end.