{ "cells": [ { "cell_type": "markdown", "id": "24d9a595-2e18-4a5a-a3b9-910e8fabaa43", "metadata": {}, "source": "# CRCNS AA4: inspection of the zebra-finch auditory pallium dataset\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/urancon/deepSTRF/blob/develop/examples/aa4_inspection.ipynb)\n\nThis notebook is a **visual smoke test** of the deepSTRF loader for\n**[CRCNS AA4](https://crcns.org/data-sets/aa/aa-4/about-aa-4)** (Elie &\nTheunissen, 2019) \u2014 extracellular spike trains from the avian auditory\npallium of zebra finches (Field-L, CLM, CMM, NCM), recorded in response\nto a large corpus of conspecific songs, calls, and ripple-noise\nstimuli. It complements `crcns_aa_tutorial.ipynb`, which covers the\nsibling AA1 / AA2 datasets, by focusing on the AA4-specific quirks:\n\n- **Sparse coverage** \u2014 not every cell heard every stim. The\n `nrn_masks` property is the canonical (stim, neuron) availability\n query, and the NaN-sentinel response convention propagates through\n the dataset API. See [`data_paradigm.md`](../docs/_source/md/data_paradigm.md) \u00a74.\n- **Per-cell electrode metadata** \u2014 AA4 cells carry `electrode` (1-32)\n and `subsort_id` fields, so we can re-render the population raster\n ordered by physical recording channel and check whether obvious\n bands persist.\n- **The filter API** at full reach \u2014 `select_pop_by_nrn_attr`,\n `select_pop_by_stim_attr` (neurons whose responses *cover* a given\n stim type), and `select_stims_by_attr` (restrict iteration to a stim\n subset).\n" }, { "cell_type": "markdown", "id": "e22e30af-74ba-4eb6-b222-fd3157f6eda4", "metadata": {}, "source": "## Setup \u2014 Google Colab\n\nIf you're running on Google Colab, the cell below installs deepSTRF\nfrom source. On a local install (`pip install -e .`) it's a no-op.\n\n**Note on data**: AA4 is an authenticated CRCNS dataset. 
To\nauto-download it, set `$CRCNS_USERNAME` and `$CRCNS_PASSWORD` (free\naccount at https://crcns.org/) before running the dataset cell. On a\nlocal machine that already has the data extracted, it's picked up from\nthe platformdirs cache automatically.\n" }, { "cell_type": "code", "execution_count": null, "id": "05bb46fe-d148-4bd2-981c-5c147cab0aa6", "metadata": {}, "outputs": [], "source": "import sys\nif 'google.colab' in sys.modules:\n    !pip install -q git+https://github.com/urancon/deepSTRF.git\n    print(\"deepSTRF installed from GitHub.\")\nelse:\n    print(\"Local environment \u2014 assuming deepSTRF is already importable.\")\n" }, { "cell_type": "markdown", "id": "fee08d91-90a2-4ec8-90a0-2503dfa979a0", "metadata": {}, "source": "## Imports\n" }, { "cell_type": "code", "execution_count": null, "id": "13a8a3d5-72d1-4095-b5cd-efcd46aa816e", "metadata": {}, "outputs": [], "source": "%matplotlib inline\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport torch\n\nfrom deepSTRF.datasets.audio.crcns_aa4 import CRCNSAA4Dataset\n\n# Bin width in ms. Typical choices: 1, 5, 10.\nDT_MS = 5\n\n# Smallest animal \u2014 fastest first download / load. Drop this to pull the\n# full six-bird corpus.\nANIMAL = 'LblBlu2028M'\n" }, { "cell_type": "markdown", "id": "f10ab5c3-23c0-4b55-91e5-a2ab5ee67b5f", "metadata": {}, "source": "## 1. Instantiate the dataset\n\nThe constructor pulls the per-animal tarball from the CRCNS NERSC mirror\non first run and extracts it into the platformdirs cache; subsequent\nruns reuse the cached copy. We restrict to a single animal and keep the default\nsmoothing (21 ms Hanning window; Hsu, Borst & Theunissen, 2004).\n" }, { "cell_type": "code", "execution_count": null, "id": "fae25532-7b6a-4657-ae43-a51e48104582", "metadata": {}, "outputs": [], "source": "ds = CRCNSAA4Dataset(\n    download=True,\n    animals=(ANIMAL,),\n    dt_ms=DT_MS,\n    smooth=True,\n)\nprint(ds)\nprint(f'first stim_meta: {ds.stim_meta[0]}')\nprint(f'first nrn_meta: {ds.nrn_meta[0]}')\n\nm = ds.nrn_masks\nprint(f'nrn_masks shape {tuple(m.shape)} '\n      f'valid={int(m.sum())}/{m.numel()} '\n      f'coverage={m.float().mean().item():.2%}')\n" }, { "cell_type": "markdown", "id": "eb890040-d37a-4a84-8d38-c1c7c4d4cdf0", "metadata": {}, "source": "## 2. Single-cell view: spectrogram + raster + PSTH\n\nPick the cell with the broadest stim coverage, then sample four stims\nit heard, evenly spaced across the stim list. For each stim we show\nthe mel-spectrogram (top), per-trial spike raster (middle), and\nsmoothed PSTH (bottom).\n" }, { "cell_type": "code", "execution_count": null, "id": "90f57749-fff2-4984-aff7-eba7441a0e61", "metadata": {}, "outputs": [], "source": "per_cell_cov = m.sum(dim=0)  # (N,)\nn_idx = int(per_cell_cov.argmax().item())\nprint(f'picked cell {n_idx}: {ds.nrn_meta[n_idx]}')\nprint(f' covers {int(per_cell_cov[n_idx].item())}/{m.shape[0]} stims')\n\nvalid_s = [s for s in range(m.shape[0]) if m[s, n_idx]]\npicks = [valid_s[i] for i in np.linspace(0, len(valid_s) - 1, 4).astype(int)]\nprint(f'picked stims: {picks}')\nfor s in picks:\n    sm = ds.stim_meta[s]\n    print(f\" s={s}: type={sm['type']}/{sm['class']} \"\n          f\"name={sm['name'][:10]}... \"\n          f\"resp shape={tuple(ds.responses[s][n_idx].shape)}\")\n" }, { "cell_type": "code", "execution_count": null, "id": "81d8f0e7-de1b-45c2-8f9f-fb4463c8412e", "metadata": {}, "outputs": [], "source": "fig, axes = plt.subplots(3, len(picks), figsize=(4*len(picks), 8), sharex='col')\nfor col, s in enumerate(picks):\n    spec = ds.stims[s][0].numpy()  # (F, T)\n    resp = ds.responses[s][n_idx].numpy()  # (R, T)\n    psth = resp.mean(axis=0)  # (T,)\n    R, T = resp.shape\n    t_axis = np.arange(T) * DT_MS / 1000.0  # seconds\n\n    ax = axes[0, col]\n    ax.imshow(spec, origin='lower', aspect='auto',\n              extent=[t_axis[0], t_axis[-1], 0, spec.shape[0]])\n    ax.set_title(f\"s={s} {ds.stim_meta[s]['type']}/{ds.stim_meta[s]['class']}\")\n    if col == 0:\n        ax.set_ylabel('mel band')\n\n    ax = axes[1, col]\n    ax.imshow(resp, origin='lower', aspect='auto', cmap='gray_r',\n              extent=[t_axis[0], t_axis[-1], 0, R])\n    if col == 0:\n        ax.set_ylabel('trial')\n\n    ax = axes[2, col]\n    ax.plot(t_axis, psth)\n    ax.set_xlabel('time (s)')\n    if col == 0:\n        ax.set_ylabel('PSTH (smoothed)')\n\nplt.suptitle(f\"cell {n_idx}: {ds.nrn_meta[n_idx]['cell_id']}\", y=1.02)\nplt.tight_layout()\nplt.show()\n" }, { "cell_type": "markdown", "id": "c453ab7b-3209-4b3b-ab5a-21a47a6d0c8e", "metadata": {}, "source": "## 3. Population PSTH raster for one stim\n\nFor one stim, plot the mean-across-trials PSTH for every cell as a row\nin an `(N, T)` matrix. Cells that did not hear this stim \u2014 i.e. the\n`(1, 1)` NaN sentinels under the deepSTRF data paradigm \u2014 are rendered\nas a distinct grey, so they are visually separable from cells with\nvalid PSTHs that happened to fire little or nothing.\n" }, { "cell_type": "code", "execution_count": null, "id": "074a81c7-cd65-4af9-87a7-0b291f2b724d", "metadata": {}, "outputs": [], "source": "import warnings\n\nper_stim_cov = m.sum(dim=1)\ns_idx = int(per_stim_cov.argmax().item())\nprint(f\"picked stim {s_idx}: {ds.stim_meta[s_idx]} covered by \"\n      f\"{int(per_stim_cov[s_idx].item())}/{m.shape[1]} cells\")\n\nT = ds.stims[s_idx].shape[-1]\nN = ds.N_neurons\npsth_pop = np.full((N, T), np.nan, dtype=np.float32)\nfor n in range(N):\n    if m[s_idx, n]:\n        psth_pop[n] = ds.responses[s_idx][n].numpy().mean(axis=0)\n\n# normalise each *valid* row to its own max so rasters across cells are comparable.\n# np.nanmax emits an 'All-NaN slice' RuntimeWarning for cells that did not hear\n# this stim; np.errstate does not silence it, so filter it explicitly.\nwith warnings.catch_warnings():\n    warnings.simplefilter('ignore', category=RuntimeWarning)\n    row_max = np.nanmax(psth_pop, axis=1, keepdims=True)\nrow_max[row_max == 0] = 1.0\npsth_pop_norm = psth_pop / row_max\n" }, { "cell_type": "code", "execution_count": null, "id": "ff2e3c6d-2f4f-41a9-b449-7c81346760db", "metadata": {}, "outputs": [], "source": "fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True,\n                         gridspec_kw={'height_ratios': [1, 4]})\n\nspec = ds.stims[s_idx][0].numpy()\nt_axis = np.arange(T) * DT_MS / 1000.0\naxes[0].imshow(spec, origin='lower', aspect='auto',\n               extent=[t_axis[0], t_axis[-1], 0, spec.shape[0]])\naxes[0].set_title(f\"stim s={s_idx} type={ds.stim_meta[s_idx]['type']} \"\n                  f\"class={ds.stim_meta[s_idx]['class']}\")\naxes[0].set_ylabel('mel band')\n\n# NaN cells get a distinctive grey via cmap.set_bad\ncmap = plt.get_cmap('viridis').copy()\ncmap.set_bad(color='lightgrey')\nim = axes[1].imshow(np.ma.masked_invalid(psth_pop_norm),\n                    origin='lower', aspect='auto', cmap=cmap,\n                    extent=[t_axis[0], t_axis[-1], 0, N],\n                    interpolation='nearest')\naxes[1].set_xlabel('time (s)')\naxes[1].set_ylabel('neuron index')\nfig.colorbar(im, ax=axes[1], label='normalised PSTH')\n\nn_invalid = int((~m[s_idx]).sum())\naxes[1].text(0.01, 0.99,\n             f'grey rows = {n_invalid} cell(s) that did NOT hear this stim '\n             f'(NaN sentinel)',\n             transform=axes[1].transAxes, va='top', fontsize=9,\n             bbox=dict(facecolor='white', alpha=0.85, edgecolor='lightgrey'))\nplt.tight_layout()\nplt.show()\n" }, { "cell_type": "markdown", "id": "8a42eb63-5d38-48b5-808f-77111b121441", "metadata": {}, "source": "## 4. Demo of the filter API\n\n`NeuralDataset` exposes a small selection API; calls mutate `self.I`\n(neuron selection) or `self.S_sel` (stim selection), after which\n`__len__` and `__getitem__` only iterate over the still-active\nindices, with the\n[bidirectional rule](../docs/_source/md/data_paradigm.md#8-iteration-honours-the-current-selection-bidirectional)\nauto-hiding cells with no valid responses left in the active stim set.\n\n- `select_neuron(i)` / `select_population([i, j, ...])` \u2014 manual indices\n- `select_pop_by_nrn_attr(key, value)` \u2014 by neuron-metadata key\n- `select_pop_by_stim_attr(key, value)` \u2014 keep cells with \u2265 1 response\n to a stim matching `stim_meta[key] == value`\n- `select_stims_by_attr(key, value)` \u2014 restrict the active stim space\n- `reset_pop_selection()` / `reset_stim_selection()` \u2014 clear the\n respective selection\n" }, { "cell_type": "code", "execution_count": null, "id": "e45be5fe-2edc-4486-b428-f49337aa5cc7", "metadata": {}, "outputs": [], "source": "sel = ds.select_pop_by_nrn_attr('animal_id', ANIMAL)\nprint(f\"select_pop_by_nrn_attr('animal_id', '{ANIMAL}') -> {len(sel)} neurons\")\nprint(f\" len(ds) under this selection: {len(ds)}\")\n\nsel = ds.select_pop_by_stim_attr('type', 'song')\nprint(f\"select_pop_by_stim_attr('type', 'song') -> {len(sel)} neurons \"\n      f\"(at least one song response)\")\n\nds.reset_pop_selection()\nprint(f\"reset_pop_selection() -> len(ds) = {len(ds)} (back to all neurons)\")\n\ns_sel = ds.select_stims_by_attr('type', 'song')\nprint(f\"select_stims_by_attr('type', 
'song') -> {len(s_sel)} stims, \"\n      f\"len(ds) = {len(ds)}\")\n\nds.reset_stim_selection()\nprint(f\"reset_stim_selection() -> len(ds) = {len(ds)}\")\n" }, { "cell_type": "markdown", "id": "4a42a715-58f3-4a05-8379-c676d0fda835", "metadata": {}, "source": "## 5. Cells by electrode # (sanity check on spatial structure)\n\nThe AA4 loader stores per-cell `electrode` (1-32) and `subsort_id` in\n`nrn_meta`. With that, we can re-render the same single-stim\npopulation PSTH raster but with cells **sorted numerically by\n`(electrode, subsort_id)`** instead of the default lex-sort on\nfilenames. If bands persist under this re-sort, they are more likely to\nreflect real spatial/functional structure than a quirk of filename ordering.\n\nPer the dataset PDF, each recording site uses two 16-electrode arrays\nplaced bilaterally \u2014 so an e1-e16 vs. e17-e32 split is the natural\nanatomical hypothesis to test, once the hemisphere convention is\nconfirmed.\n" }, { "cell_type": "code", "execution_count": null, "id": "5926ad0e-d209-45d8-929c-3f3fc6434fb9", "metadata": {}, "outputs": [], "source": "order = sorted(range(N), key=lambda i: (\n    ds.nrn_meta[i]['electrode'],\n    ds.nrn_meta[i].get('subsort_id') or 0,\n))\npsth_pop_norm_reord = psth_pop_norm[order]\nelectrodes = [ds.nrn_meta[i]['electrode'] for i in order]\n\nfig, ax = plt.subplots(figsize=(10, 6))\nim = ax.imshow(np.ma.masked_invalid(psth_pop_norm_reord),\n               origin='lower', aspect='auto', cmap=cmap,\n               extent=[t_axis[0], t_axis[-1], 0, N],\n               interpolation='nearest')\nax.set_xlabel('time (s)')\nax.set_ylabel('neuron (re-sorted by electrode, subsort_id)')\nax.set_title(f\"stim s={s_idx} \u2014 cells sorted by electrode #\")\nfig.colorbar(im, ax=ax, label='normalised PSTH')\n\n# overlay electrode-number ticks on the right y-axis\nax2 = ax.twinx()\nax2.set_ylim(ax.get_ylim())\ntick_y = np.arange(N) + 0.5\nax2.set_yticks(tick_y[::4])\nax2.set_yticklabels([f'e{electrodes[i]}' for i in range(0, N, 4)], fontsize=7)\nax2.set_ylabel('electrode')\n\n# horizontal line at the e16 -> e17 boundary (candidate hemisphere split)\nboundary = next((y for y, e in enumerate(electrodes) if e > 16), None)\nif boundary is not None:\n    ax.axhline(boundary, color='red', linestyle='--', linewidth=1)\n    ax.text(t_axis[-1] * 0.99, boundary + 0.5, ' e16/e17 boundary',\n            color='red', ha='right', va='bottom', fontsize=8)\n\nplt.tight_layout()\nplt.show()\n" }, { "cell_type": "markdown", "id": "8a384c09-715e-495d-aa62-9af9aa663fa2", "metadata": {}, "source": "## Recap\n\n- AA4 is loaded through the same `NeuralDataset` interface as the rest\n of the audio zoo \u2014 `stims`, `responses`, `stim_meta`, `nrn_meta`,\n plus the derived `nrn_masks` property.\n- Coverage is **sparse**: only a fraction of (stim, neuron) pairs have\n data; the rest are encoded as `(1, 1)` NaN sentinels. The\n `nrn_masks` property is the canonical availability query.\n- The filter API combines bidirectionally \u2014 narrowing the stim space\n auto-hides cells that have no responses left in it, and vice versa.\n- Cell-level metadata (`electrode`, `subsort_id`, `animal_id`, \u2026) lets\n you re-order the population raster against real recording geometry\n rather than filename order.\n\nNext stop: pick a model from `deepSTRF.models.audio` and fit it on the\nselection of your choice. See the `strf_gradmap_aa2.ipynb` notebook\nfor a worked example of training + interpretability on AA2 \u2014 the same\nrecipe transfers to AA4.\n" } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 5 }