CRCNS AA4: inspection of the zebra-finch auditory pallium dataset

This notebook is a visual smoke test of the deepSTRF loader for CRCNS AA4 (Elie & Theunissen, 2019) — extracellular spike trains from the avian auditory pallium of zebra finches (Field-L, CLM, CMM, NCM), recorded in response to a large corpus of conspecific songs, calls, and ripple-noise stimuli. It complements crcns_aa_tutorial.ipynb, which covers the sibling AA1 / AA2 datasets, by focusing on the AA4-specific quirks:

  • Sparse coverage: not every cell heard every stim. Missing (stim, neuron) pairs are stored as (1, 1) NaN sentinels, a convention that propagates through the dataset API, and the nrn_masks property is the canonical (stim, neuron) availability query. See data_paradigm.md (../docs/_source/md/data_paradigm.md), §4.

  • Per-cell electrode metadata: AA4 cells carry electrode (1-32) and subsort_id fields, so we can re-render the population raster ordered by physical recording channel and check whether the bands visible under the default ordering persist.

  • The filter API in full: select_pop_by_nrn_attr (neurons by metadata value), select_pop_by_stim_attr (neurons whose responses cover a given stim type), and select_stims_by_attr (restrict iteration to a stim subset).
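As a rough picture of the NaN-sentinel convention, the sketch below builds a toy nested list responses[s][n] in which missing (stim, neuron) pairs are (1, 1) NaN arrays, then derives the boolean (S, N) availability mask that nrn_masks is described as returning. The helper mask_from_sentinels is hypothetical and uses NumPy instead of torch; it illustrates what the mask encodes, not how deepSTRF computes it.

```python
import numpy as np

def mask_from_sentinels(responses):
    """Return a bool (S, N) mask: True where responses[s][n] holds real data.

    Hypothetical helper: assumes a missing pair is stored as a (1, 1) array of NaN
    and a real pair as an (R, T) array of trial responses.
    """
    S, N = len(responses), len(responses[0])
    mask = np.zeros((S, N), dtype=bool)
    for s in range(S):
        for n in range(N):
            r = responses[s][n]
            is_sentinel = r.shape == (1, 1) and bool(np.isnan(r).all())
            mask[s, n] = not is_sentinel
    return mask

sentinel = np.full((1, 1), np.nan)
rng = np.random.default_rng(0)
# toy data: 3 stims x 2 neurons, two of the six pairs are missing
responses = [
    [rng.poisson(1.0, (10, 40)).astype(float), rng.poisson(1.0, (8, 40)).astype(float)],
    [rng.poisson(1.0, (10, 60)).astype(float), sentinel],
    [sentinel, rng.poisson(1.0, (9, 25)).astype(float)],
]
mask = mask_from_sentinels(responses)
print(mask)
print(f'coverage = {mask.mean():.2%}')
```

Note that real pairs may have different shapes per stim (different durations, different trial counts), which is why a nested list rather than one dense array is used here.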

Setup — Google Colab

If you’re running on Google Colab, the cell below installs deepSTRF from source. On a local install (pip install -e .) it’s a no-op.

Note on data: AA4 is an authenticated CRCNS dataset. To auto-download it, set $CRCNS_USERNAME and $CRCNS_PASSWORD (free account at https://crcns.org/) before running the dataset cell. On a local machine that already has the data extracted, it’s picked up from the platformdirs cache automatically.
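A small pre-flight check can save a failed download. The helper below is hypothetical (not part of deepSTRF): it reports which of the two variables named above are unset or empty, without printing their values. It only inspects the environment; how the loader itself reacts to missing credentials is not shown here.

```python
import os

def missing_crcns_credentials(env=None):
    """Return the names of CRCNS credential variables that are unset or empty."""
    env = os.environ if env is None else env
    required = ('CRCNS_USERNAME', 'CRCNS_PASSWORD')
    return [name for name in required if not env.get(name)]

missing = missing_crcns_credentials()
if missing:
    print(f'Not set: {", ".join(missing)}; auto-download will likely fail '
          f'unless the data is already cached.')
else:
    print('CRCNS credentials found in the environment.')
```

In Colab, one option is to set the variables for the current session only, e.g. os.environ['CRCNS_USERNAME'] = '...' and os.environ['CRCNS_PASSWORD'] = getpass.getpass(), so the password is never written into the saved notebook.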

[ ]:
import sys
if 'google.colab' in sys.modules:
    !pip install -q git+https://github.com/urancon/deepSTRF.git
    print("deepSTRF installed from GitHub.")
else:
    print("Local environment — assuming deepSTRF is already importable.")

Imports

[ ]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch

from deepSTRF.datasets.audio.crcns_aa4 import CRCNSAA4Dataset

# Bin width in ms. Typical choices: 1, 5, 10.
DT_MS = 5

# Smallest animal, so the first download / load is fastest. Drop the
# animals=(ANIMAL,) restriction in the next cell to pull the full six-bird corpus.
ANIMAL = 'LblBlu2028M'

1. Instantiate the dataset

The constructor downloads the per-animal tarball from the CRCNS NERSC mirror on first run and extracts it into the platformdirs cache; later runs reuse the cached copy. We restrict to a single animal and keep the default smoothing (a 21 ms Hanning window; Hsu, Borst & Theunissen 2004).
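To make the smoothing step concrete, here is a minimal NumPy sketch of Hanning smoothing applied to binned spike trains. The helpers hanning_kernel and smooth_trials are hypothetical; in particular, how the loader rounds 21 ms to a whole number of bins, and whether it normalises the window to unit area, are assumptions made here for illustration.

```python
import numpy as np

def hanning_kernel(width_ms: float, dt_ms: float) -> np.ndarray:
    """Unit-area Hanning kernel spanning roughly width_ms at dt_ms bins (hypothetical helper)."""
    n_taps = max(1, int(round(width_ms / dt_ms)))
    # np.hanning(M) has zero-valued endpoints; pad by 2 and trim them so every tap is > 0
    w = np.hanning(n_taps + 2)[1:-1]
    return w / w.sum()

def smooth_trials(spikes: np.ndarray, width_ms: float = 21.0, dt_ms: float = 5.0) -> np.ndarray:
    """Smooth each row (trial) of an (R, T) spike-count array; output keeps the (R, T) shape."""
    k = hanning_kernel(width_ms, dt_ms)
    return np.stack([np.convolve(row, k, mode='same') for row in spikes])

# toy raster: 2 trials x 50 bins of 5 ms
spikes = np.zeros((2, 50))
spikes[0, 25] = 1
spikes[1, [10, 30]] = 1
smoothed = smooth_trials(spikes, width_ms=21.0, dt_ms=5.0)
print(hanning_kernel(21.0, 5.0).round(3))
print(smoothed.shape, smoothed.sum(axis=1))
```

Because the kernel has unit area, spike counts away from the edges are preserved. At DT_MS = 5 the 21 ms window spans only about four bins, so the smoothing is light; at 1 ms bins it spans about 21.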

[ ]:
ds = CRCNSAA4Dataset(
    download=True,
    animals=(ANIMAL,),
    dt_ms=DT_MS,
    smooth=True,
)
print(ds)
print(f'first stim_meta:        {ds.stim_meta[0]}')
print(f'first nrn_meta:         {ds.nrn_meta[0]}')

m = ds.nrn_masks
print(f'nrn_masks shape {tuple(m.shape)}  '
      f'valid={int(m.sum())}/{m.numel()}  '
      f'coverage={m.float().mean().item():.2%}')
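The two reductions used later in this notebook read the same mask along different axes. Assuming, as the code in sections 2 and 3 does, that nrn_masks is indexed as (stim, neuron), summing over dim 0 gives per-cell coverage and summing over dim 1 gives per-stim coverage. A toy NumPy version with a made-up mask:

```python
import numpy as np

# toy (S=4 stims, N=3 neurons) availability mask, indexed as mask[stim, neuron]
mask = np.array([
    [1, 1, 0],
    [1, 0, 0],
    [1, 1, 1],
    [0, 1, 0],
], dtype=bool)

per_cell_cov = mask.sum(axis=0)   # (N,): number of stims each neuron heard
per_stim_cov = mask.sum(axis=1)   # (S,): number of neurons that heard each stim
print('per-cell coverage:', per_cell_cov)
print('per-stim coverage:', per_stim_cov)
print(f'overall coverage: {mask.mean():.2%}')
```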

2. Single-cell view: spectrogram + raster + PSTH

Pick the cell with the broadest stim coverage, then sample four stims it heard, evenly spaced across its list of valid stims. For each stim we show the mel-spectrogram (top), the per-trial response raster (middle; smoothed, since the dataset was built with smooth=True), and the trial-averaged PSTH (bottom).

[ ]:
per_cell_cov = m.sum(dim=0)                           # (N,)
n_idx = int(per_cell_cov.argmax().item())
print(f'picked cell {n_idx}: {ds.nrn_meta[n_idx]}')
print(f'  covers {int(per_cell_cov[n_idx].item())}/{m.shape[0]} stims')

valid_s = [s for s in range(m.shape[0]) if m[s, n_idx]]
picks = [valid_s[i] for i in np.linspace(0, len(valid_s) - 1, 4).astype(int)]
print(f'picked stims: {picks}')
for s in picks:
    sm = ds.stim_meta[s]
    print(f"  s={s}: type={sm['type']}/{sm['class']}  "
          f"name={sm['name'][:10]}...  "
          f"resp shape={tuple(ds.responses[s][n_idx].shape)}")
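One edge case in the pick above: if the chosen cell heard fewer than four stims, np.linspace(...).astype(int) repeats indices and the same stim is plotted more than once. A hedged variant (the helper name evenly_spaced_picks is ours, not deepSTRF's) asks for at most as many positions as there are valid stims and drops any duplicates while keeping order:

```python
import numpy as np

def evenly_spaced_picks(valid, k=4):
    """Pick up to k items spread evenly across `valid`, without repeats."""
    if not valid:
        return []
    idx = np.linspace(0, len(valid) - 1, min(k, len(valid))).round().astype(int)
    # defensive de-duplication that preserves order (dict keys keep insertion order)
    return [valid[i] for i in dict.fromkeys(idx.tolist())]

print(evenly_spaced_picks([3, 8, 15, 21, 40, 41, 57], k=4))
print(evenly_spaced_picks([5, 9], k=4))
```

In the cell above this would read picks = evenly_spaced_picks(valid_s, k=4). It rounds rather than truncates positions, so the picks can differ slightly from the original np.linspace(...).astype(int) choice.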

[ ]:
# squeeze=False keeps `axes` 2-D even when only one stim was picked
fig, axes = plt.subplots(3, len(picks), figsize=(4*len(picks), 8), sharex='col', squeeze=False)
for col, s in enumerate(picks):
    spec = ds.stims[s][0].numpy()                         # (F, T)
    resp = ds.responses[s][n_idx].numpy()                 # (R, T)
    psth = resp.mean(axis=0)                              # (T,)
    R, T = resp.shape
    t_axis = np.arange(T) * DT_MS / 1000.0                # seconds

    ax = axes[0, col]
    ax.imshow(spec, origin='lower', aspect='auto',
              extent=[t_axis[0], t_axis[-1], 0, spec.shape[0]])
    ax.set_title(f"s={s}  {ds.stim_meta[s]['type']}/{ds.stim_meta[s]['class']}")
    if col == 0:
        ax.set_ylabel('mel band')

    ax = axes[1, col]
    ax.imshow(resp, origin='lower', aspect='auto', cmap='gray_r',
              extent=[t_axis[0], t_axis[-1], 0, R])
    if col == 0:
        ax.set_ylabel('trial')

    ax = axes[2, col]
    ax.plot(t_axis, psth)
    ax.set_xlabel('time (s)')
    if col == 0:
        ax.set_ylabel('PSTH (smoothed)')

plt.suptitle(f"cell {n_idx}: {ds.nrn_meta[n_idx]['cell_id']}", y=1.02)
plt.tight_layout()
plt.show()

3. Population PSTH raster for one stim

For one stim, plot the mean-across-trials PSTH for every cell as a row in an (N, T) matrix. Cells that did not hear this stim — i.e. the (1, 1) NaN sentinels under the deepSTRF data paradigm — are rendered as a distinct grey, so they are visually separable from cells with valid PSTHs that happened to fire little or nothing.

[ ]:
per_stim_cov = m.sum(dim=1)
s_idx = int(per_stim_cov.argmax().item())
print(f"picked stim {s_idx}: {ds.stim_meta[s_idx]}  covered by "
      f"{int(per_stim_cov[s_idx].item())}/{m.shape[1]} cells")

T = ds.stims[s_idx].shape[-1]
N = ds.N_neurons
psth_pop = np.full((N, T), np.nan, dtype=np.float32)
for n in range(N):
    if m[s_idx, n]:
        psth_pop[n] = ds.responses[s_idx][n].numpy().mean(axis=0)

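# normalise each *valid* row to its own max so rasters across cells are comparable.
# np.nanmax emits an "All-NaN slice encountered" RuntimeWarning on NaN-sentinel rows;
# np.errstate does not silence it, so suppress that warning explicitly.
import warnings
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    row_max = np.nanmax(psth_pop, axis=1, keepdims=True)
row_max[row_max == 0] = 1.0          # silent-but-valid cells: avoid 0/0
psth_pop_norm = psth_pop / row_max   # NaN rows stay NaN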

[ ]:
fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True,
                         gridspec_kw={'height_ratios': [1, 4]})

spec = ds.stims[s_idx][0].numpy()
t_axis = np.arange(T) * DT_MS / 1000.0
axes[0].imshow(spec, origin='lower', aspect='auto',
               extent=[t_axis[0], t_axis[-1], 0, spec.shape[0]])
axes[0].set_title(f"stim s={s_idx}  type={ds.stim_meta[s_idx]['type']}  "
                  f"class={ds.stim_meta[s_idx]['class']}")
axes[0].set_ylabel('mel band')

# NaN cells get a distinctive grey via cmap.set_bad
cmap = plt.get_cmap('viridis').copy()
cmap.set_bad(color='lightgrey')
im = axes[1].imshow(np.ma.masked_invalid(psth_pop_norm),
                    origin='lower', aspect='auto', cmap=cmap,
                    extent=[t_axis[0], t_axis[-1], 0, N],
                    interpolation='nearest')
axes[1].set_xlabel('time (s)')
axes[1].set_ylabel('neuron index')
fig.colorbar(im, ax=axes[1], label='normalised PSTH')

n_invalid = int((~m[s_idx]).sum())
axes[1].text(0.01, 0.99,
             f'grey rows = {n_invalid} cell(s) that did NOT hear this stim '
             f'(NaN sentinel)',
             transform=axes[1].transAxes, va='top', fontsize=9,
             bbox=dict(facecolor='white', alpha=0.85, edgecolor='lightgrey'))
plt.tight_layout()
plt.show()

4. Demo of the filter API

NeuralDataset exposes a small selection API. Each call mutates self.I (the neuron selection) or self.S_sel (the stim selection); afterwards, __len__ and __getitem__ iterate only over the still-active indices. A bidirectional rule also auto-hides cells that have no valid responses left in the active stim set.

  • select_neuron(i) / select_population([i, j, ...]) — manual indices

  • select_pop_by_nrn_attr(key, value) — by neuron-metadata key

  • select_pop_by_stim_attr(key, value) — keep cells with at least one valid response to a stim where stim_meta[key] == value

  • select_stims_by_attr(key, value) — restrict the active stim space

  • reset_pop_selection() / reset_stim_selection() — clear the respective selection
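The bidirectional rule can be pictured on a toy mask. The sketch below is a hypothetical reimplementation, not deepSTRF's code: given a neuron selection and a stim selection, it keeps only neurons with at least one valid response among the selected stims, then only stims heard by at least one remaining neuron. Whether the library prunes stims in exactly this way, and what __len__ counts, are assumptions not covered here.

```python
import numpy as np

def active_sets(mask, neurons, stims):
    """Return (active_neurons, active_stims) under a bidirectional availability rule.

    mask: bool (S, N) array indexed as mask[stim, neuron]
    neurons, stims: current neuron / stim selections (lists of indices)
    """
    sub = mask[np.ix_(stims, neurons)]                       # (len(stims), len(neurons))
    keep_n = [n for n, ok in zip(neurons, sub.any(axis=0)) if ok]
    sub = mask[np.ix_(stims, keep_n)]                        # re-check stims against kept cells
    keep_s = [s for s, ok in zip(stims, sub.any(axis=1)) if ok]
    return keep_n, keep_s

# toy mask: 4 stims x 3 neurons
mask = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [1, 1, 0],
    [0, 0, 1],
], dtype=bool)
print(active_sets(mask, [0, 1, 2], [0, 2]))      # neuron 2 heard neither stim 0 nor 2
print(active_sets(mask, [2], [0, 1, 2, 3]))      # only stim 3 reaches neuron 2
print(active_sets(mask, [2], [0, 2]))            # nothing left on either side
```

One pass suffices here: every kept neuron has a valid response to some selected stim, and that stim is therefore kept too, so pruning stims cannot strand a kept neuron.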

[ ]:
sel = ds.select_pop_by_nrn_attr('animal_id', ANIMAL)
print(f"select_pop_by_nrn_attr('animal_id', '{ANIMAL}') -> {len(sel)} neurons")
print(f"  len(ds) under this selection: {len(ds)}")

sel = ds.select_pop_by_stim_attr('type', 'song')
print(f"select_pop_by_stim_attr('type', 'song')        -> {len(sel)} neurons "
      f"(at least one song response)")

ds.reset_pop_selection()
print(f"reset_pop_selection() -> len(ds) = {len(ds)}  (back to all neurons)")

s_sel = ds.select_stims_by_attr('type', 'song')
print(f"select_stims_by_attr('type', 'song')           -> {len(s_sel)} stims, "
      f"len(ds) = {len(ds)}")

ds.reset_stim_selection()
print(f"reset_stim_selection() -> len(ds) = {len(ds)}")
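Under the same toy conventions, the two attribute filters reduce to a metadata lookup followed by a mask reduction. The helpers below are hypothetical and only illustrate the semantics described in this section: select_stims_by_attr keeps stims with stim_meta[key] == value, and select_pop_by_stim_attr keeps cells with at least one valid response to such a stim. The metadata values are made up.

```python
import numpy as np

def stims_by_attr(stim_meta, key, value):
    """Indices s with stim_meta[s][key] == value (a missing key never matches)."""
    return [s for s, meta in enumerate(stim_meta) if meta.get(key) == value]

def pop_by_stim_attr(mask, stim_meta, key, value):
    """Neurons with >= 1 valid response (mask[s, n]) to a stim matching key == value."""
    s_idx = stims_by_attr(stim_meta, key, value)
    if not s_idx:
        return []
    return np.flatnonzero(mask[s_idx].any(axis=0)).tolist()

# toy metadata and (S=4, N=4) mask, indexed as mask[stim, neuron]
stim_meta = [{'type': 'song'}, {'type': 'call'}, {'type': 'song'}, {'type': 'noise'}]
mask = np.array([
    [1, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
], dtype=bool)
print(stims_by_attr(stim_meta, 'type', 'song'))
print(pop_by_stim_attr(mask, stim_meta, 'type', 'song'))
```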

5. Cells by electrode # (sanity check on spatial structure)

The AA4 loader stores per-cell electrode (1-32) and subsort_id in nrn_meta. With that, we can re-render the same single-stim population PSTH raster with cells sorted numerically by (electrode, subsort_id) instead of the default lexicographic sort on filenames. If bands persist under this re-sort, they reflect real spatial/functional structure rather than a quirk of filename ordering.

Per the dataset PDF, each recording site uses two 16-electrode arrays placed bilaterally — so an e1-e16 vs. e17-e32 split is the natural anatomical hypothesis to test, once the hemisphere convention is confirmed.
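The reordering and the boundary search in the next cell amount to a sort on a tuple key plus a linear scan. A toy version on made-up metadata is below; the array labels e1-e16 / e17-e32 encode only the split hypothesised above, not a confirmed hemisphere assignment, and the helper names are hypothetical.

```python
from collections import Counter

def electrode_order(nrn_meta):
    """Row order sorted by (electrode, subsort_id); a missing or None subsort_id sorts as 0."""
    return sorted(range(len(nrn_meta)),
                  key=lambda i: (nrn_meta[i]['electrode'], nrn_meta[i].get('subsort_id') or 0))

def array_label(electrode, last_of_first_array=16):
    """Hypothesised array label (hemisphere mapping unconfirmed)."""
    return 'e1-e16' if electrode <= last_of_first_array else 'e17-e32'

def split_row(electrodes, last_of_first_array=16):
    """Index of the first sorted row past the e16/e17 boundary, or None if there is none."""
    return next((row for row, e in enumerate(electrodes) if e > last_of_first_array), None)

# toy metadata (made-up values)
nrn_meta = [
    {'electrode': 20, 'subsort_id': 1},
    {'electrode': 3, 'subsort_id': 2},
    {'electrode': 3, 'subsort_id': 1},
    {'electrode': 17},
    {'electrode': 9, 'subsort_id': None},
]
order = electrode_order(nrn_meta)
electrodes = [nrn_meta[i]['electrode'] for i in order]
print('order:', order)
print('electrodes:', electrodes)
print('cells per array:', dict(Counter(array_label(e) for e in electrodes)))
print('boundary row:', split_row(electrodes))
```

Counting cells per array is a quick check that both halves are populated before reading anything into a split in the raster.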

[ ]:
order = sorted(range(N), key=lambda i: (
    ds.nrn_meta[i]['electrode'],
    ds.nrn_meta[i].get('subsort_id') or 0,
))
psth_pop_norm_reord = psth_pop_norm[order]
electrodes = [ds.nrn_meta[i]['electrode'] for i in order]

fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(np.ma.masked_invalid(psth_pop_norm_reord),
               origin='lower', aspect='auto', cmap=cmap,
               extent=[t_axis[0], t_axis[-1], 0, N],
               interpolation='nearest')
ax.set_xlabel('time (s)')
ax.set_ylabel('neuron (re-sorted by electrode, subsort_id)')
ax.set_title(f"stim s={s_idx} — cells sorted by electrode #")
fig.colorbar(im, ax=ax, label='normalised PSTH')

# overlay electrode-number ticks on the right y-axis
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
tick_y = np.arange(N) + 0.5
ax2.set_yticks(tick_y[::4])
ax2.set_yticklabels([f'e{electrodes[i]}' for i in range(0, N, 4)], fontsize=7)
ax2.set_ylabel('electrode')

# horizontal line at the e16 -> e17 boundary (candidate hemisphere split)
boundary = next((y for y, e in enumerate(electrodes) if e > 16), None)
if boundary is not None:
    ax.axhline(boundary, color='red', linestyle='--', linewidth=1)
    ax.text(t_axis[-1] * 0.99, boundary + 0.5, ' e16/e17 boundary',
            color='red', ha='right', va='bottom', fontsize=8)

plt.tight_layout()
plt.show()
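Judging bands by eye is subjective. One rough way to quantify them, offered here as a hypothetical check that is not part of deepSTRF or the original analysis, is the mean Pearson correlation between PSTHs in neighbouring rows, compared against random shuffles of the same rows. A clearly higher value for the electrode ordering than for the shuffles would suggest that nearby electrodes respond alike. NaN-sentinel rows and constant rows are skipped, which makes the rows on either side of a gap count as neighbours.

```python
import numpy as np

def adjacent_row_corr(mat):
    """Mean Pearson r between consecutive non-NaN, non-constant rows of an (N, T) array."""
    rows = [r for r in mat if not np.isnan(r).any() and r.std() > 0]
    rs = [np.corrcoef(a, b)[0, 1] for a, b in zip(rows[:-1], rows[1:])]
    return float(np.mean(rs)) if rs else float('nan')

def shuffle_baseline(mat, n_shuffles=200, seed=0):
    """Adjacent-row correlation under random row orders (same rows, ordering destroyed)."""
    rng = np.random.default_rng(seed)
    return np.array([adjacent_row_corr(mat[rng.permutation(len(mat))])
                     for _ in range(n_shuffles)])

# toy data: two groups of rows with different temporal profiles, stacked as two bands
t = np.linspace(0, 1, 100)
rng = np.random.default_rng(1)
band_a = np.sin(2 * np.pi * 3 * t) + 0.3 * rng.standard_normal((10, 100))
band_b = np.cos(2 * np.pi * 5 * t) + 0.3 * rng.standard_normal((10, 100))
banded = np.vstack([band_a, band_b])
observed = adjacent_row_corr(banded)
null = shuffle_baseline(banded)
print(f'observed r = {observed:.2f}; shuffled mean r = {null.mean():.2f}')
```

On the notebook's data, the analogous comparison would be adjacent_row_corr(psth_pop_norm_reord) against shuffle_baseline(psth_pop_norm_reord), and likewise for the filename-ordered psth_pop_norm.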

Recap

  • AA4 is loaded through the same NeuralDataset interface as the rest of the audio zoo — stims, responses, stim_meta, nrn_meta, plus the derived nrn_masks property.

  • Coverage is sparse: only a fraction of (stim, neuron) pairs have data; the rest are encoded as (1, 1) NaN sentinels. The nrn_masks property is the canonical availability query.

  • Stim and neuron filters interact bidirectionally: narrowing the stim space auto-hides cells that have no responses left in it, and vice versa.

  • Cell-level metadata (electrode, subsort_id, animal_id, …) lets you re-order the population raster against real recording geometry rather than filename order.

Next stop: pick a model from deepSTRF.models.audio and fit it on the selection of your choice. See the strf_gradmap_aa2.ipynb notebook for a worked example of training + interpretability on AA2 — the same recipe transfers to AA4.