Fitting NS1 from raw waveform: causal mel, SincNet, ICNet


deepSTRF accepts raw audio waveforms in addition to precomputed spectrograms. This notebook walks through the three shipped wav2spec front-ends on NS1:

  1. ``CausalMelSpectrogram`` — non-learnable causal log-mel with the Rahman 2019 cochleagram defaults (10 ms Hanning, 500–22 627 Hz, amplitude, threshold-clipped log). The pipeline-validation baseline.

  2. ``SincNet`` — parametric bandpass filterbank (Ravanelli & Bengio 2018) with envelope=True to make it a proper cochleagram. The learnable spectrogram.

  3. ``ICNet`` — full encoder + decoder model from Drakopoulos et al. (Nat. Mach. Intell. 2025), ported to deepSTRF with auto-adapted strides for NS1’s 48 kHz / 5 ms binning. Paper-faithful single-branch Poisson-head variant.

The data side is just NS1Dataset(return_waveform=True) — the dataset returns (1, T_audio=239 760) mono float tensors at 48 kHz, aligned to the existing 999-bin neural response grid. See `wav2spec.md <../docs/_source/md/wav2spec.md>`__ for the slot contract.
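The alignment between audio samples and response bins can be checked by hand. A minimal sketch using only the figures quoted above (48 kHz audio, 5 ms bins, 999 bins); the variable names are illustrative and not part of the deepSTRF API:

```python
audio_fs = 48_000      # Hz, NS1 audio sampling rate
dt_ms = 5.0            # ms, width of one neural response bin
n_bins = 999           # bins on the neural response grid

samples_per_bin = round(audio_fs * dt_ms / 1000)   # audio samples per neural bin
t_audio = n_bins * samples_per_bin                 # waveform length in samples
duration_s = t_audio / audio_fs                    # stimulus duration in seconds

print(samples_per_bin, t_audio, duration_s)   # 240 239760 4.995
```

Any front-end plugged into the wav2spec slot therefore has to reduce the time axis by a factor of 240 to land on the response grid.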

Setup — Google Colab

If you’re running on Google Colab, install deepSTRF from source. On a local install (pip install -e .) the cell is a no-op.

[ ]:
import sys
if 'google.colab' in sys.modules:
    !pip install -q git+https://github.com/urancon/deepSTRF.git
    print('deepSTRF installed from GitHub.')
else:
    print('Local environment — assuming deepSTRF is already importable.')

Imports

[ ]:
%matplotlib inline
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset

from deepSTRF.datasets.audio.ns1 import NS1Dataset
from deepSTRF.metrics import poisson_loss
from deepSTRF.models.audio import Linear, ICNet
from deepSTRF.models.wav2spec import CausalMelSpectrogram, SincNet
from deepSTRF.training import Fitter
from deepSTRF.utils import neural_collate, compare_wav2spec_to_groundtruth

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {DEVICE}')

1. Load NS1 in waveform mode

We instantiate the dataset twice: once in waveform mode (for the wav-input models) and once in default spectrogram mode (as the ground-truth spec for the visual comparison and for the spec-side baseline training arm). Responses are bit-identical between the two; only the self.stims representation differs.

[ ]:
ds_wav = NS1Dataset(return_waveform=True, download=True)
ds_spec = NS1Dataset(download=True)

N = ds_wav.N_neurons
samples_per_bin = ds_wav.hop   # audio samples per neural bin = audio_fs * dt_ms / 1000
T_neural = ds_wav.stims[0].shape[-1] // samples_per_bin
print(f'NS1: N={N} cells | audio {ds_wav.audio_fs} Hz | T_audio={ds_wav.stims[0].shape[-1]} '
      f'({samples_per_bin} samples per {ds_wav.dt:.0f} ms bin) | T_neural={T_neural} bins of {ds_wav.dt:.0f} ms')
print(f'wav stim 0 shape: {tuple(ds_wav.stims[0].shape)}')
print(f'spec stim 0 shape: {tuple(ds_spec.stims[0].shape)}')

2. Visual sanity: Rahman causal mel vs ground-truth

compare_wav2spec_to_groundtruth returns the wav2spec output, the precomputed Rahman cochleagram (X_nfht), and a 3-panel figure (pred | truth | difference, all z-scored). With the Rahman-tuned defaults the qualitative match is strong — typical stims correlate 0.7–0.85 with the ground truth (mean ≈ 0.66 across the 20 stims).

[ ]:
mel = CausalMelSpectrogram(audio_fs=ds_wav.audio_fs)   # Rahman defaults
for stim_idx in (0, 6, 8, 12):
    pred, truth, fig = compare_wav2spec_to_groundtruth(
        ds_wav, mel, stim_idx=stim_idx,
        ground_truth_stims=ds_spec.stims,
        suptitle=f'NS1 stim {stim_idx} ({ds_spec.stim_meta[stim_idx]["type"]})'
    )
    r = np.corrcoef(pred.ravel(), truth.ravel())[0, 1]
    print(f'stim {stim_idx} ({ds_spec.stim_meta[stim_idx]["type"]}): pred-vs-truth r = {r:.3f}')
    plt.show()
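To make "causal" concrete: each output frame may only depend on audio samples up to the end of its own bin. The NumPy sketch below illustrates that framing with a 10 ms Hanning window, a 5 ms hop and a floor-clipped log amplitude. It is an illustration under stated assumptions, not the CausalMelSpectrogram implementation: the function name, the floor value and the omission of the mel projection are all choices made here for brevity.

```python
import numpy as np

def causal_log_spectrogram(wav, fs=48_000, win_ms=10.0, hop_ms=5.0, floor=1e-3):
    """Causal log-amplitude STFT (hypothetical helper, no mel projection)."""
    win = int(round(fs * win_ms / 1000))   # 480 samples at 48 kHz
    hop = int(round(fs * hop_ms / 1000))   # 240 samples at 48 kHz
    n_frames = len(wav) // hop
    # Left-pad so that frame t ends exactly at sample (t + 1) * hop:
    # the window looks back in time, never ahead.
    padded = np.concatenate([np.zeros(win - hop), wav])
    window = np.hanning(win)
    frames = np.stack([padded[t * hop: t * hop + win] for t in range(n_frames)])
    amp = np.abs(np.fft.rfft(frames * window, axis=1))   # amplitude, not power
    return np.log(np.maximum(amp, floor)).T              # (n_freq, n_frames)

# Causality check: changing the second half of the signal
# must leave the frames covering the first half untouched.
rng = np.random.default_rng(0)
wav = rng.standard_normal(2400)          # 10 bins of 240 samples
wav_cut = wav.copy()
wav_cut[1200:] = 0.0
spec_full = causal_log_spectrogram(wav)
spec_cut = causal_log_spectrogram(wav_cut)
print(np.allclose(spec_full[:, :5], spec_cut[:, :5]))
```

A centred (non-causal) STFT would fail this check, because the frame for bin 4 would already include samples from bin 5.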

3. A small fit-and-report helper

Same split (14 / 3 / 3 by stim index) and patience (30) across arms. Default loss is MSE; ICNet passes poisson_loss to match its non-negative softplus output.

[ ]:
def fit_and_report(ds, model, label, *, max_epochs=100, lr=1e-3,
                    loss_fn=None, patience=30):
    train = DataLoader(Subset(ds, list(range(14))),     batch_size=1, shuffle=True,  collate_fn=neural_collate)
    val   = DataLoader(Subset(ds, list(range(14, 17))), batch_size=1, shuffle=False, collate_fn=neural_collate)
    test  = DataLoader(Subset(ds, list(range(17, 20))), batch_size=1, shuffle=False, collate_fn=neural_collate)
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)
    kwargs = {'loss_fn': loss_fn} if loss_fn is not None else {}
    fitter = Fitter(model, train, val, optimizer=optim, device=DEVICE,
                    max_epochs=max_epochs, patience=patience,
                    monitor='val_cc_norm', mode='max',
                    log_fn=lambda d: None, **kwargs)
    t0 = time.time()
    history = fitter.fit()
    elapsed = time.time() - t0
    cc_norm = fitter.evaluate(test)['cc_norm'].cpu()
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return dict(label=label, cc_norm=cc_norm,
                mean=float(cc_norm.mean()), median=float(cc_norm.median()),
                n_params=n_params, elapsed=elapsed, epochs=len(history))
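Patience-based early stopping on a maximised metric, as configured above (patience=30, monitor='val_cc_norm', mode='max'), can be sketched in plain Python. This is a generic illustration and not the Fitter implementation; early_stop_epoch is a hypothetical helper, and the exact stopping rule in deepSTRF may differ.

```python
def early_stop_epoch(val_scores, patience=30):
    """Return (best_epoch, last_epoch_run) for a metric that should increase."""
    best, best_epoch = float('-inf'), -1
    for epoch, score in enumerate(val_scores):
        if score > best:
            best, best_epoch = score, epoch      # new best: reset the counter
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch             # no improvement for `patience` epochs
    return best_epoch, len(val_scores) - 1       # ran out of epochs first

# Toy validation curve with a short patience for readability
scores = [0.10, 0.30, 0.20, 0.25, 0.28, 0.10]
print(early_stop_epoch(scores, patience=3))      # (1, 4)
```

Under this rule the 'epochs' column reported by the helper would count epochs actually run, not the epoch of the best checkpoint.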

4. Three Linear arms: spec, wav+mel, wav+sincnet

Each arm uses the same Linear(F=34, T_strf=9, N) readout — only the input front-end changes.

[ ]:
results = []
F_bands, T_strf = 34, 9

# A: spec input baseline (default wav2spec=Identity)
torch.manual_seed(0)
m_a = Linear(n_frequency_bands=F_bands, temporal_window_size=T_strf, out_neurons=N)
results.append(fit_and_report(ds_spec, m_a, 'spec (baseline)'))

# B: wav input + Rahman causal mel (defaults)
torch.manual_seed(0)
m_b = Linear(n_frequency_bands=F_bands, temporal_window_size=T_strf, out_neurons=N,
             wav2spec=CausalMelSpectrogram(audio_fs=ds_wav.audio_fs))
results.append(fit_and_report(ds_wav, m_b, 'wav + causal mel'))

# C: wav input + SincNet (envelope, mel-init)
torch.manual_seed(0)
m_c = Linear(n_frequency_bands=F_bands, temporal_window_size=T_strf, out_neurons=N,
             wav2spec=SincNet(audio_fs=ds_wav.audio_fs, n_filters=F_bands,
                              kernel_size=753, hop_ms=ds_wav.dt,
                              f_min=500.0, f_max=22627.0,
                              init='mel', activation='logabs',
                              envelope=True, env_window_ms=10.0))
results.append(fit_and_report(ds_wav, m_c, 'wav + sincnet (env)'))

for r in results:
    print(f"  {r['label']:22s}  mean cc_norm = {r['mean']:.4f}  median = {r['median']:.4f}  "
          f"params = {r['n_params']:>7,d}  epochs = {r['epochs']:>3d}  {r['elapsed']:.0f}s")
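For readers unfamiliar with SincNet, each filter is a windowed band-pass kernel fully determined by two cutoffs, f1 and f2, following Ravanelli & Bengio: the difference of two ideal low-pass (sinc) responses, tapered by a Hamming window. The NumPy sketch below builds one such kernel with the same kernel_size as arm C. It is a stand-alone illustration; sinc_bandpass is a hypothetical helper, and deepSTRF's SincNet may parametrise or normalise its filters differently.

```python
import numpy as np

def sinc_bandpass(f1, f2, fs=48_000, kernel_size=753):
    """Windowed-sinc band-pass kernel with passband [f1, f2] in Hz."""
    n = np.arange(kernel_size) - (kernel_size - 1) / 2   # centred time index
    # Ideal low-pass at cutoff fc: (2 fc / fs) * sinc(2 fc n / fs),
    # with np.sinc(x) = sin(pi x) / (pi x).
    low = 2 * f1 / fs * np.sinc(2 * f1 / fs * n)
    high = 2 * f2 / fs * np.sinc(2 * f2 / fs * n)
    return (high - low) * np.hamming(kernel_size)

kernel = sinc_bandpass(1000.0, 2000.0)
n_fft = 8192
gain = np.abs(np.fft.rfft(kernel, n=n_fft))
freqs = np.fft.rfftfreq(n_fft, d=1 / 48_000)

def gain_at(f_hz):
    """Magnitude response at the FFT bin closest to f_hz."""
    return gain[np.argmin(np.abs(freqs - f_hz))]

for f_hz in (200.0, 1500.0, 4000.0):
    print(f'{f_hz:6.0f} Hz  gain = {gain_at(f_hz):.3f}')
```

Only f1 and f2 are free per filter, which is why a 34-filter bank adds so few parameters compared with a free-form convolution of the same length.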

5. ICNet (Drakopoulos et al. 2025) on NS1

ICNet is a much deeper model (5.1 M params) and was trained in the paper on midbrain (IC) data in gerbils — not cortex (A1) in ferrets like NS1. Two paper-faithful hyperparameters that matter on NS1:

  • lr=4e-4 (the paper’s value; 1e-3 makes training unstable here).

  • poisson_loss (the model’s softplus output is a non-negative rate; Poisson NLL is the appropriate loss).

Training takes ~15 minutes on a small GPU.

[ ]:
torch.manual_seed(0)
m_icnet = ICNet(audio_fs=ds_wav.audio_fs, out_neurons=N, dt_ms=ds_wav.dt)
print(f'ICNet on NS1: strides={m_icnet.wav2spec.encoder_strides}  '
      f'params={sum(p.numel() for p in m_icnet.parameters()):,}')

results.append(fit_and_report(ds_wav, m_icnet, 'wav + ICNet',
                               max_epochs=100, lr=4e-4,
                               loss_fn=poisson_loss))

print()
print(f'{"front-end":<22s}  {"mean cc_norm":>12s}  {"median":>8s}  {"params":>9s}  {"epochs":>6s}  {"time":>5s}')
for r in results:
    print(f'{r["label"]:<22s}  {r["mean"]:>12.4f}  {r["median"]:>8.4f}  '
          f'{r["n_params"]:>9,d}  {r["epochs"]:>6d}  {r["elapsed"]:>4.0f}s')
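The pairing of a softplus output with a Poisson loss can be written out in a few lines. The sketch below uses NumPy and the standard Poisson negative log-likelihood with the constant log(y!) term dropped; the function names, the eps value and the mean reduction are assumptions made here, and deepSTRF's poisson_loss may differ in those details.

```python
import numpy as np

def softplus(x):
    """Numerically stable log(1 + exp(x)); always strictly positive."""
    return np.logaddexp(0.0, x)

def poisson_nll(rate, counts, eps=1e-8):
    """Mean Poisson negative log-likelihood, up to the constant log(counts!)."""
    return float(np.mean(rate - counts * np.log(rate + eps)))

counts = np.array([1.0, 2.0, 3.0])               # toy spike counts per bin
for scale in (0.5, 1.0, 1.5):
    print(scale, round(poisson_nll(scale * counts, counts), 4))
```

The loss is lowest when the predicted rate equals the observed counts, and it is only defined for positive rates, which is what the softplus guarantees.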

6. What did SincNet learn?

Plot the SincNet cutoffs before vs after training. The cutoffs barely move during NS1 fitting: gradients into f1/f2 are tiny relative to the cutoff magnitudes, so SincNet effectively acts as a fixed mel-spaced bandpass filterbank, and the downstream Linear readout has to do all the representational work.

[ ]:
# Re-init a fresh SincNet for the initial-cutoff baseline
torch.manual_seed(0)
sn_init = SincNet(audio_fs=ds_wav.audio_fs, n_filters=F_bands, kernel_size=753, hop_ms=ds_wav.dt,
                   f_min=500.0, f_max=22627.0, init='mel', activation='logabs',
                   envelope=True, env_window_ms=10.0)

sn_trained = m_c.wav2spec
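# detach() keeps .numpy() valid even if f1/f2 are tensors that require grad
with torch.no_grad():
    f1_i, f2_i = sn_init.f1.detach().cpu().numpy(), sn_init.f2.detach().cpu().numpy()
    f1_t, f2_t = sn_trained.f1.detach().cpu().numpy(), sn_trained.f2.detach().cpu().numpy()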

fig, ax = plt.subplots(figsize=(7, 4))
idx = np.arange(len(f1_i))
ax.fill_between(idx, f1_i, f2_i, alpha=0.3, label='init passband')
ax.plot(idx, (f1_i + f2_i) / 2, 'o--', ms=3, label='init centre')
ax.fill_between(idx, f1_t, f2_t, alpha=0.3, label='trained passband', color='C1')
ax.plot(idx, (f1_t + f2_t) / 2, 'x-',  ms=4, label='trained centre', color='C1')
ax.set_yscale('log')
ax.set_xlabel('filter index')
ax.set_ylabel('frequency (Hz, log)')
ax.set_title('SincNet cutoffs on NS1: init vs trained')
ax.legend()
ax.grid(True, which='both', alpha=0.3)
plt.tight_layout()
plt.show()

rel_drift = np.mean(np.abs(f1_t - f1_i) / (np.abs(f1_i) + 1e-9))   # epsilon guards the denominator
print(f'mean relative drift of f1 cutoffs: {rel_drift:.2e}  (typically < 1e-3 — cutoffs barely move)')
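One possible back-of-envelope reading of why the drift is so small, offered as a hypothesis and not as a confirmed explanation. It rests on three assumptions that this notebook does not verify: the cutoffs are optimised directly in Hz, the Fitter takes one optimiser step per batch, and each Adam-family step moves a parameter by roughly lr in its own units (a common rule of thumb, not a strict bound).

```python
lr = 1e-3                 # learning rate used for the Linear arms
steps = 14 * 100          # 14 training stims at batch_size=1, at most 100 epochs
max_shift_hz = lr * steps # rough ceiling on total cutoff movement, in Hz

def relative_bound(f_hz):
    """Rough ceiling on relative drift for a cutoff at f_hz (assumes Hz units)."""
    return max_shift_hz / f_hz

for f_hz in (500.0, 22627.0):   # f_min and f_max of the filterbank
    print(f'{f_hz:8.0f} Hz  relative drift <= {relative_bound(f_hz):.1e}')
```

If those assumptions hold, the cutoffs could move by only about a hertz over the whole run, so a larger learning rate for the front-end or a rescaled parametrisation would be needed before SincNet could adapt its passbands.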

Takeaways

Numbers from a single-seed run on a GTX 1650 (your numbers may vary slightly):

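========================  =================  ======
front-end                 mean test cc_norm  params
========================  =================  ======
spec (X_nfht baseline)    0.548              37 k
wav + Rahman causal mel   0.573              37 k
wav + SincNet (envelope)  0.340              37 k
wav + ICNet (Poisson)     0.659              5.1 M
========================  =================  ======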

  • Causal mel from wav slightly beats the precomputed-spec baseline (0.573 vs 0.548). This confirms that the wav2spec slot mechanics and the Rahman defaults behave as intended.

  • SincNet underperforms fixed mel when paired with a thin Linear readout. Inspecting the learned cutoffs (section 6) shows they barely move during NS1 training: gradients to f1/f2 are tiny relative to the cutoff magnitudes, so SincNet effectively acts as a fixed mel-spaced bandpass filterbank.

  • ICNet’s deep conv stack is what makes the difference. With the paper’s Poisson head + lr=4e-4 it reaches test cc_norm 0.66 on NS1 cortex, even though it was designed for gerbil IC.

See `wav2spec.md <../docs/_source/md/wav2spec.md>`__ for the slot contract and how to write your own front-end.