Fitting Espejo ferret A1 responses with the NRF model

End-to-end demo: dataset → population filter → train / val / test split → model → fit → eval.

  • Data: Lopez-Espejo et al. (2019) ferret A1 — natural-sound release (NAT, F=18, ~540 cells across 7 animals). Loaded from public Zenodo deposit 3445557.

  • Population: one animal (AMT, the largest cohort with 168 cells).

  • Splits: paper-faithful test set (high-rep stims). Estimation set split 90/10 into train / val (stim-level holdout, fixed seed).

  • Model: Network Receptive Field (NRF, Harper et al. 2016) — a two-layer STRF network. This is the first deepSTRF example notebook for the NRF.

  • Loss / metrics: mse_loss (NaN-aware), corrcoef, normalized_corrcoef('schoppe') — the canonical deepSTRF triad.

Setup — Google Colab

On Colab, the next cell installs deepSTRF from source. In a local checkout already installed with pip install -e . it is a no-op.

[ ]:
import sys
if 'google.colab' in sys.modules:
    !pip install -q git+https://github.com/urancon/deepSTRF.git
    print('deepSTRF installed from GitHub.')
else:
    print('Local environment — assuming deepSTRF is already importable.')

Imports

[ ]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Subset

from deepSTRF.datasets.audio import EspejoDataset
from deepSTRF.models.audio import NetworkReceptiveField
from deepSTRF.metrics import corrcoef, normalized_corrcoef
from deepSTRF.training import Fitter, set_random_seed
from deepSTRF.utils.data import neural_collate

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
set_random_seed(0)

1. Load Espejo NAT

We instantiate the dataset twice: once filtered to the estimation set (for train + val) and once to the test set. Both share the same 18-band gammatone log-spectrograms at dt = 10 ms. The first run with download=True fetches the 638 MB NAT archive into the platformdirs cache (~/.cache/deepSTRF/Espejo by default); subsequent runs reuse the unpacked archives.

[ ]:
ds_est = EspejoDataset(stimuli='nat', subset='estimation', download=True)
ds_test = EspejoDataset(stimuli='nat', subset='test', download=True)

print(ds_est)
print(ds_test)
print(f'\nstim shape: {tuple(ds_est.stims[0].shape)}  (1, F=18, T) at dt=10 ms')
print(f'sample est stim_meta:  {ds_est.stim_meta[0]}')
print(f'sample test stim_meta: {ds_test.stim_meta[0]}')

2. Filter to one animal (AMT)

Espejo NAT pools 7 ferret cohorts. We focus on AMT (168 cells, the largest population). The neuron-side filter select_pop_by_nrn_attr is bidirectional: besides restricting the neuron axis, it automatically hides from __getitem__ any stim that no AMT cell heard.

[ ]:
ANIMAL = 'AMT'

ds_est.select_pop_by_nrn_attr('animal_id', ANIMAL)
ds_test.select_pop_by_nrn_attr('animal_id', ANIMAL)

N_AMT = len(ds_est.I)
assert N_AMT == len(ds_test.I), 'AMT cell count should match across est/test'
print(f'AMT cells:        {N_AMT}')
print(f'est stims (visible after filter): {len(ds_est)}')
print(f'test stims (visible after filter): {len(ds_test)}')
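
The bidirectional rule can be illustrated on a toy coverage mask. This is only a sketch of the idea in plain NumPy, not deepSTRF's implementation: the array `coverage`, the labels in `animal_of_neuron`, and the helper `visible_stims` are all hypothetical names introduced for illustration.

```python
import numpy as np

# Toy coverage mask (hypothetical): rows = stims, columns = neurons.
# True means that neuron has a recorded response to that stim.
coverage = np.array([
    [True,  True,  False, False],
    [False, False, True,  True],
    [True,  False, True,  False],
])
# Hypothetical animal label for each neuron.
animal_of_neuron = np.array(['AMT', 'AMT', 'XYZ', 'XYZ'])


def visible_stims(coverage, selected_neurons):
    """Indices of stims heard by at least one of the selected neurons."""
    return np.flatnonzero(coverage[:, selected_neurons].any(axis=1))


selected = np.flatnonzero(animal_of_neuron == 'AMT')  # neuron-side filter
stims = visible_stims(coverage, selected)             # stim-side consequence
print('selected neurons:', selected.tolist())
print('visible stims:   ', stims.tolist())
```

Stim 1 drops out because only non-AMT neurons heard it, which is the behaviour the length printouts above are meant to reveal.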

3. Quick look — one stim, one cell

Inspect one of the high-rep test stims for one AMT cell: spectrogram, raster, PSTH.

[ ]:
# pick the test stim with the most coverage across AMT cells
masks = ds_test.nrn_masks[:, ds_test.I]   # (S, N_AMT)
stim_cov = masks.sum(dim=1)
stim_idx = int(stim_cov.argmax().item())
# inside that stim, pick a cell with valid data
cell_local_idx = int(masks[stim_idx].nonzero(as_tuple=True)[0][0].item())
cell_global_idx = ds_test.I[cell_local_idx]

spec = ds_test.stims[stim_idx][0].numpy()                   # (F, T)
resp = ds_test.responses[stim_idx][cell_global_idx].numpy() # (R, T)
psth = resp.mean(axis=0)
t = np.arange(spec.shape[1]) * 1e-2  # seconds (dt=10 ms)

fig, axs = plt.subplots(3, 1, figsize=(9, 5.5), sharex=True,
                        gridspec_kw={'height_ratios': [2, 2, 1]})
axs[0].imshow(spec, aspect='auto', origin='lower', cmap='magma',
              extent=[t[0], t[-1], 0, spec.shape[0]])  # y-extent follows the band count
axs[0].set_ylabel('freq band')
axs[0].set_title(f'stim: {ds_test.stim_meta[stim_idx]["name"]}  |  cell: {ds_test.nrn_meta[cell_global_idx]["cell_id"]}  |  R={resp.shape[0]} reps')

yi, xi = np.where(resp > 0)
axs[1].scatter(t[xi], yi, s=4, c='k')
axs[1].set_ylabel('trial')
axs[1].set_ylim(-0.5, resp.shape[0] - 0.5)

axs[2].plot(t, psth, color='k', lw=1.2)
axs[2].set_xlabel('time (s)')
axs[2].set_ylabel('spikes/bin')
plt.tight_layout(); plt.show()
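
The PSTH plotted above is in spikes per bin. A minimal sketch of the trial average and its conversion to a firing rate in Hz follows. It assumes a raster of shape (R, T) at dt = 10 ms and, as a labeled assumption, that missing trials are padded with NaN; the helper `psth_from_raster` is hypothetical and not part of deepSTRF.

```python
import numpy as np


def psth_from_raster(raster, dt=0.01):
    """Trial-averaged response of an (R, T) raster.

    Returns (spikes_per_bin, rate_hz). Rows that are entirely NaN are
    treated as missing trials (an assumption of this sketch) and ignored.
    """
    raster = np.asarray(raster, dtype=float)
    spikes_per_bin = np.nanmean(raster, axis=0)
    return spikes_per_bin, spikes_per_bin / dt


toy_raster = np.array([
    [0, 1, 0, 2],
    [1, 1, 0, 0],
    [np.nan, np.nan, np.nan, np.nan],   # missing third trial
])
per_bin, rate_hz = psth_from_raster(toy_raster)
print('spikes/bin:', per_bin)
print('rate (Hz): ', rate_hz)
```

Dividing by dt only rescales the axis, so it does not affect correlation-based metrics.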

4. Train / val / test split

The test set is fixed by the paper convention (high-rep stims). The estimation set is split 90/10 at the stim level for train / val with a fixed seed.

[ ]:
S_est = len(ds_est)
rng = np.random.RandomState(0)
shuffled = list(range(S_est))
rng.shuffle(shuffled)
n_val = max(1, S_est // 10)
val_idx = sorted(shuffled[:n_val])
train_idx = sorted(shuffled[n_val:])

print(f'train stims: {len(train_idx)}')
print(f'val stims:   {len(val_idx)}')
print(f'test stims:  {len(ds_test)}')

BS = 32
train_loader = DataLoader(Subset(ds_est, train_idx), batch_size=BS,
                          shuffle=True, collate_fn=neural_collate)
val_loader   = DataLoader(Subset(ds_est, val_idx),   batch_size=BS,
                          shuffle=False, collate_fn=neural_collate)
test_loader  = DataLoader(ds_test,                    batch_size=BS,
                          shuffle=False, collate_fn=neural_collate)
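
The stim-level holdout can be wrapped in a small helper that makes the two properties we rely on easy to check: the train and val index sets are disjoint, and the same seed always gives the same split. The function `split_stims` below is a hypothetical sketch, not a deepSTRF utility.

```python
import numpy as np


def split_stims(n_stims, val_fraction=0.1, seed=0):
    """Split stim indices 0..n_stims-1 into sorted (train, val) arrays."""
    rng = np.random.RandomState(seed)
    order = rng.permutation(n_stims)
    n_val = max(1, int(n_stims * val_fraction))   # keep at least one val stim
    return np.sort(order[n_val:]), np.sort(order[:n_val])


toy_train, toy_val = split_stims(50)
print(f'train: {len(toy_train)}  val: {len(toy_val)}')

# No stim may appear on both sides, and none may be lost.
assert np.intersect1d(toy_train, toy_val).size == 0
assert np.array_equal(np.union1d(toy_train, toy_val), np.arange(50))
```

Splitting by stim rather than by time bin keeps whole sounds out of training, so the val score measures generalization to unheard stimuli.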

5. Model — NRF

NetworkReceptiveField (Harper, Schoppe, Willmore, Cui, Schnupp & King 2016) is a two-layer STRF network: an STRF kernel projects the input spectrogram into H hidden units, then a per-neuron 1×1 readout produces the population output. With H=20 hidden units and N_AMT output neurons, one model fits all AMT cells jointly through a shared bottleneck.

[ ]:
model = NetworkReceptiveField(
    n_frequency_bands=18,
    temporal_window_size=15,    # 150 ms history at dt=10 ms
    n_hidden=20,
    out_neurons=N_AMT,
)
print(model)
print(f'\nTrainable params: {model.count_trainable_params():,}')
print(f'  per neuron:      {model.count_trainable_params() / N_AMT:,.0f}')
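
To make the two-layer structure concrete, here is a NumPy sketch of one NRF-style forward pass and the matching parameter arithmetic. It rests on stated assumptions: a sigmoid hidden nonlinearity, a linear readout, one bias per hidden unit and per output neuron, and causal zero-padding. deepSTRF's NetworkReceptiveField may differ in any of these, so the count below is not necessarily what count_trainable_params() reports.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def nrf_forward(spec, W_in, b_in, W_out, b_out):
    """spec: (F, T); W_in: (H, F, T_win); W_out: (N, H). Returns (N, T)."""
    F, T_total = spec.shape
    T_win = W_in.shape[2]
    # Causal zero-padding: the output at bin t sees bins t-T_win+1 .. t only.
    padded = np.concatenate([np.zeros((F, T_win - 1)), spec], axis=1)
    # windows[t] is the (F, T_win) patch of history ending at bin t.
    windows = np.stack([padded[:, t:t + T_win] for t in range(T_total)])
    hidden = sigmoid(np.einsum('tfw,hfw->th', windows, W_in) + b_in)  # (T, H)
    return (hidden @ W_out.T + b_out).T                               # (N, T)


F, T_win, H, N = 18, 15, 20, 168   # shapes used in this notebook
rng = np.random.default_rng(0)
W_in, b_in = 0.1 * rng.standard_normal((H, F, T_win)), np.zeros(H)
W_out, b_out = 0.1 * rng.standard_normal((N, H)), np.zeros(N)
toy_spec = rng.standard_normal((F, 100))

toy_out = nrf_forward(toy_spec, W_in, b_in, W_out, b_out)
n_params = W_in.size + b_in.size + W_out.size + b_out.size
print('output shape:', toy_out.shape)
print('parameters under these assumptions:', n_params)
```

Under these assumptions the hidden layer costs H·F·T_win + H weights regardless of population size, while each extra neuron adds only H + 1 readout weights, which is what makes the shared bottleneck cheap for a 168-cell population.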

6. Train with the Fitter

Fitter wires the model, the loaders, an optimizer, and the canonical val metrics (cc, cc_norm) into a single .fit() call. The default loss is a NaN-aware mse_loss against the auto-PSTH of responses. We monitor val_cc_norm (Schoppe noise-corrected correlation) and stop early after 15 epochs without improvement.

[ ]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

fitter = Fitter(
    model, train_loader, val_loader,
    optimizer=optimizer,
    device=device,
    max_epochs=80,
    patience=15,
    monitor='val_cc_norm',
    mode='max',
    log_fn=lambda d: print(
        f"epoch {d['epoch']:3d}  "
        f"train_loss={d['train_loss']:.4f}  "
        f"val_loss={d['val_loss']:.4f}  "
        f"val_cc_norm={torch.nanmean(d['val_cc_norm']):+.3f}",
        flush=True,
    ),
)
history = fitter.fit()
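
What "NaN-aware" means for the loss can be sketched in a few lines: entries whose target is NaN (for example, a cell that never heard a stim) are dropped before averaging. This is an illustrative NumPy version under that assumption. The helper `nan_aware_mse` is hypothetical, and deepSTRF's mse_loss may reduce differently (for example per neuron before averaging).

```python
import numpy as np


def nan_aware_mse(pred, target):
    """Mean squared error over the entries where the target is not NaN."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    valid = ~np.isnan(target)
    if not valid.any():
        return float('nan')          # nothing to score
    return float(np.mean((pred[valid] - target[valid]) ** 2))


toy_pred = np.array([[1.0, 2.0], [3.0, 4.0]])
toy_target = np.array([[1.0, np.nan], [1.0, 4.0]])   # one masked entry
toy_loss = nan_aware_mse(toy_pred, toy_target)
print(f'loss over 3 valid entries: {toy_loss:.4f}')
```

A plain mean over the same arrays would return NaN, which is why masking has to happen before the reduction.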

7. Training curves

[ ]:
epochs = [h['epoch'] for h in history]
train_loss = [h['train_loss'] for h in history]
val_loss = [h['val_loss'] for h in history]
val_cc = [torch.nanmean(h['val_cc']).item() for h in history]
val_ccn = [torch.nanmean(h['val_cc_norm']).item() for h in history]

fig, axs = plt.subplots(1, 2, figsize=(10, 3.5))
axs[0].plot(epochs, train_loss, label='train', lw=1.5)
axs[0].plot(epochs, val_loss, label='val', lw=1.5)
axs[0].set_xlabel('epoch'); axs[0].set_ylabel('MSE loss'); axs[0].legend()
axs[0].set_title('Loss')

axs[1].plot(epochs, val_cc, label='val cc', lw=1.5)
axs[1].plot(epochs, val_ccn, label='val cc_norm', lw=1.5)
axs[1].set_xlabel('epoch'); axs[1].set_ylabel('mean across cells')
axs[1].legend(); axs[1].set_title('Correlations')
plt.tight_layout(); plt.show()
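
When reading these curves it helps to know which epoch a patience rule would pick. The sketch below replays a generic early-stopping rule over a list of per-epoch scores. The helper `early_stop_epoch` is hypothetical, and the Fitter's actual rule may differ (for instance a minimum improvement threshold or restoring the best weights).

```python
def early_stop_epoch(scores, patience, mode='max'):
    """Return (best_epoch, last_epoch_run) for a list of per-epoch scores.

    Training stops once `patience` epochs have passed without a new best.
    """
    sign = 1.0 if mode == 'max' else -1.0
    best_score, best_epoch = None, None
    for epoch, score in enumerate(scores):
        if best_epoch is None or sign * score > best_score:
            best_score, best_epoch = sign * score, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(scores) - 1


toy_scores = [0.10, 0.20, 0.25, 0.24, 0.23, 0.22, 0.21]
best, last = early_stop_epoch(toy_scores, patience=3)
print(f'best epoch: {best}, stopped at epoch: {last}')
```

With the real history, the same function could be fed `val_ccn` from the cell above and `patience=15` to locate the epoch whose score the run is judged by.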

8. Test-set evaluation

Fitter.evaluate runs the same cross-batch concat-then-compute pipeline on the held-out test stims and returns a metric dict with un-prefixed keys (loss, cc, cc_norm, without the val_ prefix used in the training history).

[ ]:
test_metrics = fitter.evaluate(test_loader)
test_cc = test_metrics['cc'].cpu()
test_ccn = test_metrics['cc_norm'].cpu()
print(f"test loss:       {test_metrics['loss']:.4f}")
print(f'test cc      mean={torch.nanmean(test_cc):+.3f}  median={torch.nanmedian(test_cc):+.3f}')
print(f'test cc_norm mean={torch.nanmean(test_ccn):+.3f}  median={torch.nanmedian(test_ccn):+.3f}')
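
For reference, the normalized correlation of Schoppe et al. (2016) can be written as CC_norm = Cov(pred, PSTH) / sqrt(Var(pred) · SP), where the signal power is SP = (Var(Σ_r R_r) − Σ_r Var(R_r)) / (R(R − 1)) over R repeated trials. The NumPy sketch below implements that formula for a single cell and a single concatenated response. It is an illustration only: the helper `cc_abs_and_norm` is hypothetical, and deepSTRF's normalized_corrcoef may handle NaNs, batching, and edge cases differently.

```python
import numpy as np


def cc_abs_and_norm(pred, trials):
    """Raw and Schoppe-normalized correlation for one cell.

    pred: (T,) prediction; trials: (R, T) single-trial responses, R >= 2.
    """
    pred = np.asarray(pred, dtype=float)
    trials = np.asarray(trials, dtype=float)
    n_rep = trials.shape[0]
    psth = trials.mean(axis=0)
    cov = np.cov(pred, psth, ddof=1)[0, 1]
    var_pred, var_psth = pred.var(ddof=1), psth.var(ddof=1)
    # Signal power: the stimulus-driven part of the response variance.
    signal_power = (trials.sum(axis=0).var(ddof=1)
                    - trials.var(axis=1, ddof=1).sum()) / (n_rep * (n_rep - 1))
    cc_abs = cov / np.sqrt(var_pred * var_psth)
    cc_norm = cov / np.sqrt(var_pred * signal_power) if signal_power > 0 else np.nan
    return cc_abs, cc_norm


rng = np.random.default_rng(0)
signal = np.sin(np.linspace(0, 8 * np.pi, 200))
noisy_trials = signal + 0.5 * rng.standard_normal((5, 200))
toy_cc, toy_ccn = cc_abs_and_norm(signal, noisy_trials)
print(f'cc = {toy_cc:.3f}   cc_norm = {toy_ccn:.3f}')
```

Because the variance of the trial average never exceeds the average single-trial variance, SP is at most Var(PSTH), so whenever SP is positive the magnitude of cc_norm is at least that of the raw cc. The formula needs at least two repeats and is undefined when the estimated signal power is not positive.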

9. Per-cell cc_norm distribution

[ ]:
fig, ax = plt.subplots(figsize=(8, 3.5))
valid = ~test_ccn.isnan()
ax.hist(test_ccn[valid].numpy(), bins=30, edgecolor='black', alpha=0.85)
ax.axvline(torch.nanmean(test_ccn).item(), color='red', lw=2,
           label=f'mean = {torch.nanmean(test_ccn):.3f}')
ax.set_xlabel('test cc_norm (Schoppe)')
ax.set_ylabel('# AMT cells')
ax.set_title('Per-cell noise-corrected correlation on Espejo NAT')
ax.legend()
plt.tight_layout(); plt.show()

10. NRF hidden STRFs

The hidden STRF kernels are interpretable as the model’s learned auditory features. We plot the first 8 hidden units’ STRFs.

[ ]:
n_show = min(8, model.H)
fig, axs = plt.subplots(2, n_show // 2, figsize=(2.0 * (n_show // 2), 4))
for h, ax in enumerate(axs.flat):
    strf = model.STRFs(hidden_idx=h).detach().cpu().numpy()   # (F, T)
    vmax = float(np.abs(strf).max() + 1e-9)
    ax.imshow(strf, aspect='auto', origin='lower', cmap='RdBu_r',
              vmin=-vmax, vmax=vmax)
    ax.set_title(f'h={h}', fontsize=9)
    ax.set_xticks([]); ax.set_yticks([])
fig.suptitle(f'NRF hidden STRFs (first {n_show} of {model.H} units)', y=1.02)
plt.tight_layout(); plt.show()

Notes

  • Estimation stims in Espejo NAT have 1–3 repetitions per (cell, stim), so the PSTH target is close to a single trial. cc_norm (Schoppe) corrects for the resulting noise ceiling.

  • The subset='estimation' filter keeps the full estimation stim bank that any AMT site presented; some stims have valid responses for only a subset of the 168 AMT cells (nrn_masks is block-sparse across recording sites). The Fitter’s NaN-aware loss handles this transparently.

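  • With ~150 epochs and a few seeds, NRF on Espejo AMT typically reaches cc_norm in the 0.35–0.45 range. The run above caps training at max_epochs=80, so reaching ~150 epochs requires raising that limit. This is roughly in line with the LN baselines reported in Lopez-Espejo et al. (~0.5 prediction correlation on a held-out stim), although their prediction correlation and cc_norm are different metrics, so the comparison is only approximate.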