{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fitting NS1 from raw waveform: causal mel, SincNet, ICNet\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/urancon/deepSTRF/blob/develop/examples/fit_ns1_linear_from_waveform.ipynb)\n", "\n", "deepSTRF accepts raw audio waveforms in addition to precomputed\n", "spectrograms. This notebook walks through the three shipped\n", "`wav2spec` front-ends on NS1:\n", "\n", "1. **`CausalMelSpectrogram`** — non-learnable causal log-mel with the\n", " Rahman 2019 cochleagram defaults (10 ms Hanning, 500–22 627 Hz,\n", " amplitude, threshold-clipped log). The pipeline-validation baseline.\n", "2. **`SincNet`** — parametric bandpass filterbank (Ravanelli & Bengio\n", " 2018) with `envelope=True` to make it a proper cochleagram. The\n", " learnable spectrogram.\n", "3. **`ICNet`** — full encoder + decoder model from Drakopoulos et al.\n", " (Nat. Mach. Intell. 2025), ported to deepSTRF with auto-adapted\n", " strides for NS1's 48 kHz / 5 ms binning. Paper-faithful single-branch\n", " Poisson-head variant.\n", "\n", "The data side is just `NS1Dataset(return_waveform=True)` — the dataset\n", "returns `(1, T_audio=239 760)` mono float tensors at 48 kHz, aligned to\n", "the existing 999-bin neural response grid. See\n", "[`wav2spec.md`](../docs/_source/md/wav2spec.md) for the slot contract." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup — Google Colab\n", "\n", "If you're running on Google Colab, install deepSTRF from source. On a\n", "local install (`pip install -e .`) the cell is a no-op." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "if 'google.colab' in sys.modules:\n", " !pip install -q git+https://github.com/urancon/deepSTRF.git\n", " print('deepSTRF installed from GitHub.')\n", "else:\n", " print('Local environment — assuming deepSTRF is already importable.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import time\n", "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", "from torch.utils.data import DataLoader, Subset\n", "\n", "from deepSTRF.datasets.audio.ns1 import NS1Dataset\n", "from deepSTRF.metrics import poisson_loss\n", "from deepSTRF.models.audio import Linear, ICNet\n", "from deepSTRF.models.wav2spec import CausalMelSpectrogram, SincNet\n", "from deepSTRF.training import Fitter\n", "from deepSTRF.utils import neural_collate, compare_wav2spec_to_groundtruth\n", "\n", "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f'Using device: {DEVICE}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Load NS1 in waveform mode\n", "\n", "We instantiate the dataset twice — once in waveform mode (for the\n", "wav-input models) and once in default spectrogram mode (as the\n", "ground-truth spec for visual comparison + the spec-side baseline\n", "training arm). Responses are bit-identical between the two; only the\n", "`self.stims` representation differs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds_wav = NS1Dataset(return_waveform=True, download=True)\n", "ds_spec = NS1Dataset(download=True)\n", "\n", "N = ds_wav.N_neurons\n", "samples_per_bin = ds_wav.hop # audio samples per neural bin = audio_fs * dt_ms / 1000\n", "T_neural = ds_wav.stims[0].shape[-1] // samples_per_bin\n", "print(f'NS1: N={N} cells | audio {ds_wav.audio_fs} Hz | T_audio={ds_wav.stims[0].shape[-1]} '\n", " f'({samples_per_bin} samples per {ds_wav.dt:.0f} ms bin) | T_neural={T_neural} bins of {ds_wav.dt:.0f} ms')\n", "print(f'wav stim 0 shape: {tuple(ds_wav.stims[0].shape)}')\n", "print(f'spec stim 0 shape: {tuple(ds_spec.stims[0].shape)}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Visual sanity: Rahman causal mel vs ground-truth\n", "\n", "`compare_wav2spec_to_groundtruth` returns the wav2spec output, the\n", "precomputed Rahman cochleagram (`X_nfht`), and a 3-panel figure\n", "(pred | truth | difference, all z-scored). With the Rahman-tuned\n", "defaults the qualitative match is strong — typical stims correlate\n", "0.7–0.85 with the ground truth (mean ≈ 0.66 across the 20 stims)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mel = CausalMelSpectrogram(audio_fs=ds_wav.audio_fs) # Rahman defaults\n", "for stim_idx in (0, 6, 8, 12):\n", " pred, truth, fig = compare_wav2spec_to_groundtruth(\n", " ds_wav, mel, stim_idx=stim_idx,\n", " ground_truth_stims=ds_spec.stims,\n", " suptitle=f'NS1 stim {stim_idx} ({ds_spec.stim_meta[stim_idx][\"type\"]})'\n", " )\n", " r = np.corrcoef(pred.ravel(), truth.ravel())[0, 1]\n", " print(f'stim {stim_idx} ({ds_spec.stim_meta[stim_idx][\"type\"]}): pred-vs-truth r = {r:.3f}')\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
A small fit-and-report helper\n", "\n", "Same split (14 / 3 / 3 by stim index) and patience (30) across arms.\n", "Default loss is MSE; ICNet passes `poisson_loss` to match its\n", "non-negative softplus output." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit_and_report(ds, model, label, *, max_epochs=100, lr=1e-3,\n", " loss_fn=None, patience=30):\n", " train = DataLoader(Subset(ds, list(range(14))), batch_size=1, shuffle=True, collate_fn=neural_collate)\n", " val = DataLoader(Subset(ds, list(range(14, 17))), batch_size=1, shuffle=False, collate_fn=neural_collate)\n", " test = DataLoader(Subset(ds, list(range(17, 20))), batch_size=1, shuffle=False, collate_fn=neural_collate)\n", " optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)\n", " kwargs = {'loss_fn': loss_fn} if loss_fn is not None else {}\n", " fitter = Fitter(model, train, val, optimizer=optim, device=DEVICE,\n", " max_epochs=max_epochs, patience=patience,\n", " monitor='val_cc_norm', mode='max',\n", " log_fn=lambda d: None, **kwargs)\n", " t0 = time.time()\n", " history = fitter.fit()\n", " elapsed = time.time() - t0\n", " cc_norm = fitter.evaluate(test)['cc_norm'].cpu()\n", " n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", " return dict(label=label, cc_norm=cc_norm,\n", " mean=float(cc_norm.mean()), median=float(cc_norm.median()),\n", " n_params=n_params, elapsed=elapsed, epochs=len(history))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Three Linear arms: spec, wav+mel, wav+sincnet\n", "\n", "Each arm uses the same `Linear(F=34, T_strf=9, N)` readout — only the\n", "input front-end changes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "results = []\n", "F_bands, T_strf = 34, 9\n", "\n", "# A: spec input baseline (default wav2spec=Identity)\n", "torch.manual_seed(0)\n", "m_a = Linear(n_frequency_bands=F_bands, temporal_window_size=T_strf, out_neurons=N)\n", "results.append(fit_and_report(ds_spec, m_a, 'spec (baseline)'))\n", "\n", "# B: wav input + Rahman causal mel (defaults)\n", "torch.manual_seed(0)\n", "m_b = Linear(n_frequency_bands=F_bands, temporal_window_size=T_strf, out_neurons=N,\n", " wav2spec=CausalMelSpectrogram(audio_fs=ds_wav.audio_fs))\n", "results.append(fit_and_report(ds_wav, m_b, 'wav + causal mel'))\n", "\n", "# C: wav input + SincNet (envelope, mel-init)\n", "torch.manual_seed(0)\n", "m_c = Linear(n_frequency_bands=F_bands, temporal_window_size=T_strf, out_neurons=N,\n", " wav2spec=SincNet(audio_fs=ds_wav.audio_fs, n_filters=F_bands,\n", " kernel_size=753, hop_ms=ds_wav.dt,\n", " f_min=500.0, f_max=22627.0,\n", " init='mel', activation='logabs',\n", " envelope=True, env_window_ms=10.0))\n", "results.append(fit_and_report(ds_wav, m_c, 'wav + sincnet (env)'))\n", "\n", "for r in results:\n", " print(f\" {r['label']:22s} mean cc_norm = {r['mean']:.4f} median = {r['median']:.4f} \"\n", " f\"params = {r['n_params']:>7,d} epochs = {r['epochs']:>3d} {r['elapsed']:.0f}s\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. ICNet (Drakopoulos et al. 2025) on NS1\n", "\n", "ICNet is a much deeper model (5.1 M params) and was trained in the\n", "paper on **midbrain** (IC) data in **gerbils** — not cortex (A1) in\n", "ferrets like NS1. Two paper-faithful hyperparameters that matter on\n", "NS1:\n", "\n", "- `lr=4e-4` (the paper's value; `1e-3` makes training unstable here).\n", "- `poisson_loss` (the model's softplus output is a non-negative rate;\n", " Poisson NLL is the appropriate loss).\n", "\n", "Training takes ~15 minutes on a small GPU." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(0)\n", "m_icnet = ICNet(audio_fs=ds_wav.audio_fs, out_neurons=N, dt_ms=ds_wav.dt)\n", "print(f'ICNet on NS1: strides={m_icnet.wav2spec.encoder_strides} '\n", " f'params={sum(p.numel() for p in m_icnet.parameters()):,}')\n", "\n", "results.append(fit_and_report(ds_wav, m_icnet, 'wav + ICNet',\n", " max_epochs=100, lr=4e-4,\n", " loss_fn=poisson_loss))\n", "\n", "print()\n", "print(f'{\"front-end\":<22s} {\"mean cc_norm\":>12s} {\"median\":>8s} {\"params\":>9s} {\"epochs\":>6s} {\"time\":>5s}')\n", "for r in results:\n", " print(f'{r[\"label\"]:<22s} {r[\"mean\"]:>12.4f} {r[\"median\"]:>8.4f} '\n", " f'{r[\"n_params\"]:>9,d} {r[\"epochs\"]:>6d} {r[\"elapsed\"]:>4.0f}s')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. What did SincNet learn?\n", "\n", "Plot the SincNet cutoffs before vs after training. The cutoffs barely\n", "move during NS1 fitting — gradients into f1/f2 are tiny relative to\n", "the cutoff magnitudes, so SincNet effectively acts as a fixed\n", "mel-spaced bandpass filterbank. In arm C the fitting is therefore done\n", "almost entirely by the downstream `Linear` readout."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Re-init a fresh SincNet for the initial-cutoff baseline\n", "torch.manual_seed(0)\n", "sn_init = SincNet(audio_fs=ds_wav.audio_fs, n_filters=F_bands, kernel_size=753, hop_ms=ds_wav.dt,\n", " f_min=500.0, f_max=22627.0, init='mel', activation='logabs',\n", " envelope=True, env_window_ms=10.0)\n", "\n", "sn_trained = m_c.wav2spec\n", "# detach() keeps .numpy() valid even if f1/f2 are tensors that require grad\n", "with torch.no_grad():\n", " f1_i, f2_i = sn_init.f1.detach().cpu().numpy(), sn_init.f2.detach().cpu().numpy()\n", " f1_t, f2_t = sn_trained.f1.detach().cpu().numpy(), sn_trained.f2.detach().cpu().numpy()\n", "\n", "fig, ax = plt.subplots(figsize=(7, 4))\n", "idx = np.arange(len(f1_i))\n", "ax.fill_between(idx, f1_i, f2_i, alpha=0.3, label='init passband')\n", "ax.plot(idx, (f1_i + f2_i) / 2, 'o--', ms=3, label='init centre')\n", "ax.fill_between(idx, f1_t, f2_t, alpha=0.3, label='trained passband', color='C1')\n", "ax.plot(idx, (f1_t + f2_t) / 2, 'x-', ms=4, label='trained centre', color='C1')\n", "ax.set_yscale('log')\n", "ax.set_xlabel('filter index')\n", "ax.set_ylabel('frequency (Hz, log)')\n", "ax.set_title('SincNet cutoffs on NS1: init vs trained')\n", "ax.legend()\n", "ax.grid(True, which='both', alpha=0.3)\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "# Mean |f1_trained - f1_init| / |f1_init|; the epsilon guards against division by zero\n", "rel_drift = np.mean(np.abs(f1_t - f1_i) / (np.abs(f1_i) + 1e-9))\n", "print(f'mean relative drift of f1 cutoffs: {rel_drift:.2e} (typically < 1e-3 — cutoffs barely move)')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Takeaways\n", "\n", "Numbers from a single-seed run on a GTX 1650 (your numbers may vary\n", "slightly):\n", "\n", "| front-end | mean test cc_norm | params |\n", "|---|---|---|\n", "| spec (X_nfht baseline) | 0.548 | 37 k |\n", "| wav + Rahman causal mel | 0.573 | 37 k |\n", "| wav + SincNet (envelope) | 0.340 | 37 k |\n", "| wav + **ICNet** (Poisson) | **0.659** | 5.1 M |\n", "\n", "- **Causal mel from wav matches/beats the precomputed-spec baseline.**\n", " Confirms the `wav2spec` 
slot mechanics and Rahman defaults are correct.\n", "- **SincNet underperforms fixed mel** when paired with a thin Linear\n", " readout. Inspecting the learned cutoffs (section 6) shows they barely\n", " move during NS1 training — gradients to f1/f2 are tiny relative to\n", " the cutoff magnitudes. SincNet effectively acts as a fixed\n", " mel-spaced bandpass filterbank.\n", "- **ICNet's deep conv stack is what makes the difference.** With the\n", " paper's Poisson head and `lr=4e-4` it reaches test cc_norm 0.66 on NS1\n", " cortex, even though it was designed for gerbil IC.\n", "\n", "See [`wav2spec.md`](../docs/_source/md/wav2spec.md) for the slot\n", "contract and how to write your own front-end." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 4 }