{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fitting Espejo ferret A1 responses with the NRF model\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/urancon/deepSTRF/blob/develop/examples/espejo_nat_nrf.ipynb)\n", "\n", "End-to-end demo: dataset \u2192 population filter \u2192 train / val / test split \u2192 model \u2192 fit \u2192 eval.\n", "\n", "- **Data**: Lopez-Espejo et al. (2019) ferret A1 \u2014 natural-sound release (NAT, F=18, ~540 cells across 7 animals). Loaded from public Zenodo deposit `3445557`.\n", "- **Population**: one animal (`AMT`, the largest cohort with 168 cells).\n", "- **Splits**: paper-faithful test set (high-rep stims). Estimation set split 90/10 into train / val (stim-level holdout, fixed seed).\n", "- **Model**: Network Receptive Field (NRF, Harper et al. 2016) \u2014 a two-layer STRF network. This is the first deepSTRF example notebook for the NRF.\n", "- **Loss / metrics**: `mse_loss` (NaN-aware), `corrcoef`, `normalized_corrcoef('schoppe')` \u2014 the canonical deepSTRF triad." ], "id": "79dc2727" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup \u2014 Google Colab\n", "\n", "On Colab, the next cell installs deepSTRF from source. On a local `pip install -e .` checkout it's a no-op." 
], "id": "a9c2700d" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "if 'google.colab' in sys.modules:\n", " !pip install -q git+https://github.com/urancon/deepSTRF.git\n", " print('deepSTRF installed from GitHub.')\n", "else:\n", " print('Local environment \u2014 assuming deepSTRF is already importable.')" ], "id": "054cc5b4" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ], "id": "66af606b" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "from torch.utils.data import DataLoader, Subset\n", "\n", "from deepSTRF.datasets.audio import EspejoDataset\n", "from deepSTRF.models.audio import NetworkReceptiveField\n", "from deepSTRF.metrics import corrcoef, normalized_corrcoef\n", "from deepSTRF.training import Fitter, set_random_seed\n", "from deepSTRF.utils.data import neural_collate\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f'Using device: {device}')\n", "set_random_seed(0)" ], "id": "b9c9b238" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Load Espejo NAT\n", "\n", "We instantiate the dataset twice: once filtered to the estimation set (for train + val) and once filtered to the test set. Both share the same 18-band gammatone log-spectrograms at `dt = 10 ms`. The first run with `download=True` fetches the 638 MB NAT archive into the platformdirs user cache directory (`~/.cache/deepSTRF/Espejo` on Linux; the location is OS-dependent). Subsequent runs reuse the unpacked archives." 
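, "\n", "\n", "As a quick check on units, assume bin $k$ sits at time $k\\,\\Delta t$ from stimulus onset, which is how the quick-look plot in section 3 builds its time axis with `np.arange(T) * 1e-2`. Under that assumption, the bin width converts directly into the durations used later, such as the NRF's 15-bin temporal window:\n", "\n", "$$t_k = k\\,\\Delta t \\quad (k = 0, \\dots, T-1), \\qquad \\Delta t = 10\\ \\mathrm{ms}, \\qquad 15 \\times \\Delta t = 150\\ \\mathrm{ms}.$$"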
], "id": "717566ed" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds_est = EspejoDataset(stimuli='nat', subset='estimation', download=True)\n", "ds_test = EspejoDataset(stimuli='nat', subset='test', download=True)\n", "\n", "print(ds_est)\n", "print(ds_test)\n", "print(f'\\nstim shape: {tuple(ds_est.stims[0].shape)} (1, F=18, T) at dt=10 ms')\n", "print(f'sample est stim_meta: {ds_est.stim_meta[0]}')\n", "print(f'sample test stim_meta: {ds_test.stim_meta[0]}')" ], "id": "371cd942" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Filter to one animal (`AMT`)\n", "\n", "Espejo NAT pools 7 ferret cohorts. We focus on `AMT` (168 cells, the largest population). The neuron-side filter `select_pop_by_nrn_attr` triggers the [bidirectional rule](https://github.com/urancon/deepSTRF/blob/develop/docs/_source/md/data_paradigm.md#8-iteration-honours-the-current-selection-bidirectional): stims that no AMT cell heard are automatically hidden from `__getitem__`." ], "id": "5a0f2b24" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ANIMAL = 'AMT'\n", "\n", "ds_est.select_pop_by_nrn_attr('animal_id', ANIMAL)\n", "ds_test.select_pop_by_nrn_attr('animal_id', ANIMAL)\n", "\n", "N_AMT = len(ds_est.I)\n", "assert N_AMT == len(ds_test.I), 'AMT cell count should match across est/test'\n", "print(f'AMT cells: {N_AMT}')\n", "print(f'est stims (visible after filter): {len(ds_est)}')\n", "print(f'test stims (visible after filter): {len(ds_test)}')\n" ], "id": "d6dd43f5" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Quick look \u2014 one stim, one cell\n", "\n", "Inspect one of the high-rep test stims for one AMT cell: spectrogram, raster, PSTH." 
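, "\n", "\n", "For reference, the PSTH in the bottom panel is the plain trial average computed by `resp.mean(axis=0)`. Assume `responses` stores binned spike counts, which is consistent with the raster marking any bin with `resp > 0` as a spike. Let $x_r(t)$ be the count of repetition $r$ in bin $t$ and $R$ the number of repetitions. Then:\n", "\n", "$$\\mathrm{PSTH}(t) = \\frac{1}{R} \\sum_{r=1}^{R} x_r(t) \\quad [\\text{spikes/bin}].$$"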
], "id": "dc3dcb42" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# pick the test stim with the most coverage across AMT cells\n", "masks = ds_test.nrn_masks[:, ds_test.I] # (S, N_AMT)\n", "stim_cov = masks.sum(dim=1)\n", "stim_idx = int(stim_cov.argmax().item())\n", "# inside that stim, pick a cell with valid data\n", "cell_local_idx = int(masks[stim_idx].nonzero(as_tuple=True)[0][0].item())\n", "cell_global_idx = ds_test.I[cell_local_idx]\n", "\n", "spec = ds_test.stims[stim_idx][0].numpy() # (F, T)\n", "resp = ds_test.responses[stim_idx][cell_global_idx].numpy() # (R, T)\n", "psth = resp.mean(axis=0)\n", "t = np.arange(spec.shape[1]) * 1e-2 # seconds (dt=10 ms)\n", "\n", "fig, axs = plt.subplots(3, 1, figsize=(9, 5.5), sharex=True,\n", " gridspec_kw={'height_ratios': [2, 2, 1]})\n", "axs[0].imshow(spec, aspect='auto', origin='lower', cmap='magma',\n", " extent=[t[0], t[-1], 0, 18])\n", "axs[0].set_ylabel('freq band')\n", "axs[0].set_title(f'stim: {ds_test.stim_meta[stim_idx][\"name\"]} | cell: {ds_test.nrn_meta[cell_global_idx][\"cell_id\"]} | R={resp.shape[0]} reps')\n", "\n", "yi, xi = np.where(resp > 0)\n", "axs[1].scatter(t[xi], yi, s=4, c='k')\n", "axs[1].set_ylabel('trial')\n", "axs[1].set_ylim(-0.5, resp.shape[0] - 0.5)\n", "\n", "axs[2].plot(t, psth, color='k', lw=1.2)\n", "axs[2].set_xlabel('time (s)')\n", "axs[2].set_ylabel('spikes/bin')\n", "plt.tight_layout(); plt.show()" ], "id": "b04b1754" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Train / val / test split\n", "\n", "The test set is fixed by the paper convention (high-rep stims). The estimation set is split 90/10 at the stim level for train / val with a fixed seed." 
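, "\n", "\n", "Concretely, let $S_{\\mathrm{est}}$ be the number of estimation stims visible after the AMT filter. The next cell shuffles the stim indices with `np.random.RandomState(0)` and assigns\n", "\n", "$$n_{\\mathrm{val}} = \\max\\left(1, \\left\\lfloor S_{\\mathrm{est}} / 10 \\right\\rfloor\\right), \\qquad n_{\\mathrm{train}} = S_{\\mathrm{est}} - n_{\\mathrm{val}},$$\n", "\n", "so the validation set is never empty and any rounding remainder goes to training. Because whole stims are held out, the model never sees any time bin of a validation stim during training."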
], "id": "457c170f" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "S_est = len(ds_est)\n", "rng = np.random.RandomState(0)\n", "shuffled = list(range(S_est))\n", "rng.shuffle(shuffled)\n", "n_val = max(1, S_est // 10)\n", "val_idx = sorted(shuffled[:n_val])\n", "train_idx = sorted(shuffled[n_val:])\n", "\n", "print(f'train stims: {len(train_idx)}')\n", "print(f'val stims: {len(val_idx)}')\n", "print(f'test stims: {len(ds_test)}')\n", "\n", "BS = 32\n", "train_loader = DataLoader(Subset(ds_est, train_idx), batch_size=BS,\n", " shuffle=True, collate_fn=neural_collate)\n", "val_loader = DataLoader(Subset(ds_est, val_idx), batch_size=BS,\n", " shuffle=False, collate_fn=neural_collate)\n", "test_loader = DataLoader(ds_test, batch_size=BS,\n", " shuffle=False, collate_fn=neural_collate)\n" ], "id": "54775818" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Model \u2014 NRF\n", "\n", "`NetworkReceptiveField` (Harper, Schoppe, Willmore, Cui, Schnupp & King 2016) is a two-layer STRF network. A bank of `H` hidden STRF kernels (one per hidden unit) filters the input spectrogram. A per-neuron 1\u00d71 readout then combines the hidden units into each cell's predicted response. With `H=20` hidden units and `N_AMT` output neurons, a single model fits all AMT cells jointly through this shared 20-unit bottleneck." ], "id": "40ee02e9" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = NetworkReceptiveField(\n", " n_frequency_bands=18,\n", " temporal_window_size=15, # 150 ms history at dt=10 ms\n", " n_hidden=20,\n", " out_neurons=N_AMT,\n", ")\n", "print(model)\n", "print(f'\\nTrainable params: {model.count_trainable_params():,}')\n", "print(f' per neuron: {model.count_trainable_params() / N_AMT:,.0f}')" ], "id": "e24a7fd5" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. 
Train with the `Fitter`\n", "\n", "`Fitter` wires the model, the loaders, an optimizer, and the canonical validation metrics (`cc`, `cc_norm`) into a single `.fit()` call. The default loss is the NaN-aware `mse_loss` between the model output and the PSTH that the Fitter computes automatically from `responses`. We monitor `val_cc_norm` (Schoppe noise-corrected correlation; higher is better, hence `mode='max'`). Training stops early once `val_cc_norm` has not improved for 15 consecutive epochs (`patience=15`), with a hard cap of 80 epochs (`max_epochs=80`)." ], "id": "0f7eab15" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n", "\n", "fitter = Fitter(\n", " model, train_loader, val_loader,\n", " optimizer=optimizer,\n", " device=device,\n", " max_epochs=80,\n", " patience=15,\n", " monitor='val_cc_norm',\n", " mode='max',\n", " log_fn=lambda d: print(\n", " f\"epoch {d['epoch']:3d} \"\n", " f\"train_loss={d['train_loss']:.4f} \"\n", " f\"val_loss={d['val_loss']:.4f} \"\n", " f\"val_cc_norm={torch.nanmean(d['val_cc_norm']):+.3f}\",\n", " flush=True,\n", " ),\n", ")\n", "history = fitter.fit()" ], "id": "774b2c4d" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. 
Training curves" ], "id": "a9122f52" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "epochs = [h['epoch'] for h in history]\n", "train_loss = [h['train_loss'] for h in history]\n", "val_loss = [h['val_loss'] for h in history]\n", "val_cc = [torch.nanmean(h['val_cc']).item() for h in history]\n", "val_ccn = [torch.nanmean(h['val_cc_norm']).item() for h in history]\n", "\n", "fig, axs = plt.subplots(1, 2, figsize=(10, 3.5))\n", "axs[0].plot(epochs, train_loss, label='train', lw=1.5)\n", "axs[0].plot(epochs, val_loss, label='val', lw=1.5)\n", "axs[0].set_xlabel('epoch'); axs[0].set_ylabel('MSE loss'); axs[0].legend()\n", "axs[0].set_title('Loss')\n", "\n", "axs[1].plot(epochs, val_cc, label='val cc', lw=1.5)\n", "axs[1].plot(epochs, val_ccn, label='val cc_norm', lw=1.5)\n", "axs[1].set_xlabel('epoch'); axs[1].set_ylabel('mean across cells')\n", "axs[1].legend(); axs[1].set_title('Correlations')\n", "plt.tight_layout(); plt.show()" ], "id": "ee04e357" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. Test-set evaluation\n", "\n", "`Fitter.evaluate` runs the same pipeline as validation on the held-out test stims: predictions and targets are concatenated across batches before the metrics are computed. It returns a metric dict without the `val_` prefix (keys `loss`, `cc`, `cc_norm`)." ], "id": "399feebb" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_metrics = fitter.evaluate(test_loader)\n", "test_cc = test_metrics['cc'].cpu()\n", "test_ccn = test_metrics['cc_norm'].cpu()\n", "print(f\"test loss: {test_metrics['loss']:.4f}\")\n", "print(f'test cc mean={torch.nanmean(test_cc):+.3f} median={torch.nanmedian(test_cc):+.3f}')\n", "print(f'test cc_norm mean={torch.nanmean(test_ccn):+.3f} median={torch.nanmedian(test_ccn):+.3f}')" ], "id": "8daed454" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. 
Per-cell `cc_norm` distribution" ], "id": "a7fa2d7f" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(8, 3.5))\n", "valid = ~test_ccn.isnan()\n", "ax.hist(test_ccn[valid].numpy(), bins=30, edgecolor='black', alpha=0.85)\n", "ax.axvline(torch.nanmean(test_ccn).item(), color='red', lw=2,\n", " label=f'mean = {torch.nanmean(test_ccn):.3f}')\n", "ax.set_xlabel('test cc_norm (Schoppe)')\n", "ax.set_ylabel('# AMT cells')\n", "ax.set_title('Per-cell noise-corrected correlation on Espejo NAT')\n", "ax.legend()\n", "plt.tight_layout(); plt.show()" ], "id": "e20ffbc4" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 10. NRF hidden STRFs\n", "\n", "Each hidden unit's STRF kernel can be read as a spectrotemporal feature learned by the model. We plot the STRFs of the first 8 hidden units (frequency band on the y-axis, time lag on the x-axis; red is positive weight, blue negative)." ], "id": "55000cfc" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_show = min(8, model.H)\n", "n_cols = (n_show + 1) // 2 # 2 rows; round up so an odd n_show still fits\n", "fig, axs = plt.subplots(2, n_cols, figsize=(2.0 * n_cols, 4), squeeze=False)\n", "for ax in axs.flat[n_show:]:\n", " ax.axis('off') # hide the spare panel when n_show is odd\n", "for h, ax in enumerate(axs.flat[:n_show]):\n", " strf = model.STRFs(hidden_idx=h).detach().cpu().numpy() # (F, T)\n", " vmax = float(np.abs(strf).max() + 1e-9)\n", " ax.imshow(strf, aspect='auto', origin='lower', cmap='RdBu_r',\n", " vmin=-vmax, vmax=vmax)\n", " ax.set_title(f'h={h}', fontsize=9)\n", " ax.set_xticks([]); ax.set_yticks([])\n", "fig.suptitle(f'NRF hidden STRFs (first {n_show} of {model.H} units)', y=1.02)\n", "plt.tight_layout(); plt.show()" ], "id": "81c055ab" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Notes\n", "\n", "- Estimation stims in Espejo NAT have 1\u20133 repetitions per (cell, stim), so the PSTH target is close to a single-trial response. 
`cc_norm` (Schoppe) corrects for the resulting noise ceiling.\n", "- The `subset='estimation'` filter keeps every estimation stim that any AMT recording site presented. As a result, some stims have valid responses for only a subset of the 168 AMT cells (`nrn_masks` is block-sparse across recording sites). The Fitter's NaN-aware loss ignores the missing (cell, stim) entries, so no extra handling is needed.\n", "- With longer training than in this demo (~150 epochs, versus `max_epochs=80` above) and a few seeds, NRF on Espejo AMT typically reaches `cc_norm` in the 0.35\u20130.45 range. This is broadly in line with the LN baselines reported in Lopez-Espejo et al. Their LN baseline reaches ~0.5 prediction correlation on a held-out stim, but that metric is not computed the same way as Schoppe's `cc_norm`, so the two numbers are not directly comparable." ], "id": "bd893b70" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }