Fitting Espejo ferret A1 responses with the NRF model
End-to-end demo: dataset → population filter → train / val / test split → model → fit → eval.
Data: Lopez-Espejo et al. (2019) ferret A1, natural-sound release (NAT, F=18, ~540 cells across 7 animals). Loaded from public Zenodo deposit 3445557.
Population: one animal (AMT, the largest cohort with 168 cells).
Splits: paper-faithful test set (high-rep stims). The estimation set is split 90/10 into train / val (stim-level holdout, fixed seed).
Model: Network Receptive Field (NRF, Harper et al. 2016) — a two-layer STRF network. This is the first deepSTRF example notebook for the NRF.
Loss / metrics: mse_loss (NaN-aware), corrcoef, normalized_corrcoef('schoppe'), the canonical deepSTRF triad.
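As a rough illustration of what the first two metrics compute, here is a minimal PyTorch sketch. It assumes predictions and targets shaped (batch, neurons, time) with NaN marking bins or cells without data; nan_mse and pearson_per_cell are hypothetical helpers, not deepSTRF's mse_loss or corrcoef, whose exact reductions may differ.

```python
import torch


def nan_mse(pred, target):
    """MSE over entries where the target is not NaN (hypothetical helper)."""
    valid = ~torch.isnan(target)
    return ((pred[valid] - target[valid]) ** 2).mean()


def pearson_per_cell(pred, target):
    """Pearson r per neuron over all valid (batch, time) samples.

    pred, target: (B, N, T). Returns an (N,) tensor, NaN where a cell has
    fewer than two valid samples or zero variance.
    """
    n_cells = pred.shape[1]
    # Put neurons first, then flatten batch and time into one sample axis.
    p = pred.transpose(0, 1).reshape(n_cells, -1)
    y = target.transpose(0, 1).reshape(n_cells, -1)
    out = torch.full((n_cells,), float('nan'))
    for i in range(n_cells):
        m = ~torch.isnan(y[i])
        if m.sum() < 2:
            continue
        pc = p[i, m] - p[i, m].mean()
        yc = y[i, m] - y[i, m].mean()
        denom = torch.sqrt((pc ** 2).sum() * (yc ** 2).sum())
        if denom > 0:
            out[i] = (pc * yc).sum() / denom
    return out


nan = float('nan')
pred = torch.tensor([[[1., 2., 3., 4.], [1., 2., 3., 4.]]])        # (B=1, N=2, T=4)
target = torch.tensor([[[1., 2., nan, 4.], [nan, nan, nan, nan]]])  # cell 1 has no data
print(nan_mse(pred, target))           # only the three valid bins of cell 0 count
print(pearson_per_cell(pred, target))  # r = 1 for cell 0, NaN for cell 1
```

Masking before reducing is what lets block-sparse (cell, stim) coverage pass through the loss without imputing zeros.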
Setup — Google Colab
On Colab, the next cell installs deepSTRF from source. On a local checkout installed with pip install -e . it is a no-op (it only prints a message).
[ ]:
import sys
if 'google.colab' in sys.modules:
    !pip install -q git+https://github.com/urancon/deepSTRF.git
    print('deepSTRF installed from GitHub.')
else:
    print('Local environment — assuming deepSTRF is already importable.')
Imports
[ ]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Subset
from deepSTRF.datasets.audio import EspejoDataset
from deepSTRF.models.audio import NetworkReceptiveField
from deepSTRF.metrics import corrcoef, normalized_corrcoef
from deepSTRF.training import Fitter, set_random_seed
from deepSTRF.utils.data import neural_collate
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
set_random_seed(0)
1. Load Espejo NAT
We instantiate the dataset twice: once filtered to the estimation set (for train + val) and once to the test set. Both share the same 18-band gammatone log-spectrograms at dt = 10 ms. The first run with download=True fetches the 638 MB NAT archive into the platformdirs cache (~/.cache/deepSTRF/Espejo by default on Linux); subsequent runs reuse the unpacked archives.
[ ]:
ds_est = EspejoDataset(stimuli='nat', subset='estimation', download=True)
ds_test = EspejoDataset(stimuli='nat', subset='test', download=True)
print(ds_est)
print(ds_test)
print(f'\nstim shape: {tuple(ds_est.stims[0].shape)} (1, F=18, T) at dt=10 ms')
print(f'sample est stim_meta: {ds_est.stim_meta[0]}')
print(f'sample test stim_meta: {ds_test.stim_meta[0]}')
2. Filter to one animal (AMT)
Espejo NAT pools 7 ferret cohorts. We focus on AMT (168 cells, the largest population). The neuron-side filter select_pop_by_nrn_attr triggers the bidirectional rule: selecting neurons also filters stims, so stims that no AMT cell heard are automatically hidden from __getitem__ (and no longer counted by len()).
[ ]:
ANIMAL = 'AMT'
ds_est.select_pop_by_nrn_attr('animal_id', ANIMAL)
ds_test.select_pop_by_nrn_attr('animal_id', ANIMAL)
N_AMT = len(ds_est.I)
assert N_AMT == len(ds_test.I), 'AMT cell count should match across est/test'
print(f'AMT cells: {N_AMT}')
print(f'est stims (visible after filter): {len(ds_est)}')
print(f'test stims (visible after filter): {len(ds_test)}')
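The bidirectional rule can be mimicked with plain NumPy on a toy mask. This sketch assumes nrn_masks is a boolean (stims × neurons) array marking which (stim, cell) pairs have data, which is how the next section indexes ds_test.nrn_masks; select_by_attr and visible_stims are hypothetical helpers, and 'XYZ' is a placeholder animal ID.

```python
import numpy as np


def select_by_attr(nrn_meta, key, value):
    """Indices of neurons whose metadata[key] == value (hypothetical helper)."""
    return np.array([i for i, m in enumerate(nrn_meta) if m[key] == value], dtype=int)


def visible_stims(nrn_masks, neuron_idx):
    """Stims heard by at least one selected neuron.

    nrn_masks: (S, N_total) boolean array, True where (stim, neuron) has data.
    """
    return np.flatnonzero(nrn_masks[:, neuron_idx].any(axis=1))


# Toy example: 4 stims, 3 neurons from two animals.
nrn_meta = [{'animal_id': 'AMT'}, {'animal_id': 'AMT'}, {'animal_id': 'XYZ'}]
nrn_masks = np.array([
    [True,  False, False],   # stim 0: heard by neuron 0 (AMT)
    [False, False, True],    # stim 1: heard only by neuron 2 (other animal)
    [False, True,  True],    # stim 2: heard by neuron 1 (AMT)
    [False, False, False],   # stim 3: heard by nobody
])
idx = select_by_attr(nrn_meta, 'animal_id', 'AMT')
print(idx, visible_stims(nrn_masks, idx))  # [0 1] [0 2]
```

Stim 1 disappears once the other animal's neuron is filtered out, which is why len(ds_est) and len(ds_test) shrink after the cell above.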
3. Quick look — one stim, one cell
Inspect one of the high-rep test stims for one AMT cell: spectrogram, raster, PSTH.
[ ]:
# pick the test stim with the most coverage across AMT cells
masks = ds_test.nrn_masks[:, ds_test.I] # (S, N_AMT)
stim_cov = masks.sum(dim=1)
stim_idx = int(stim_cov.argmax().item())
# inside that stim, pick a cell with valid data
cell_local_idx = int(masks[stim_idx].nonzero(as_tuple=True)[0][0].item())
cell_global_idx = ds_test.I[cell_local_idx]
spec = ds_test.stims[stim_idx][0].numpy() # (F, T)
resp = ds_test.responses[stim_idx][cell_global_idx].numpy() # (R, T)
psth = resp.mean(axis=0)
t = np.arange(spec.shape[1]) * 1e-2 # seconds (dt=10 ms)
fig, axs = plt.subplots(3, 1, figsize=(9, 5.5), sharex=True,
                        gridspec_kw={'height_ratios': [2, 2, 1]})
axs[0].imshow(spec, aspect='auto', origin='lower', cmap='magma',
              extent=[t[0], t[-1], 0, 18])
axs[0].set_ylabel('freq band')
axs[0].set_title(f'stim: {ds_test.stim_meta[stim_idx]["name"]} | cell: {ds_test.nrn_meta[cell_global_idx]["cell_id"]} | R={resp.shape[0]} reps')
yi, xi = np.where(resp > 0)
axs[1].scatter(t[xi], yi, s=4, c='k')
axs[1].set_ylabel('trial')
axs[1].set_ylim(-0.5, resp.shape[0] - 0.5)
axs[2].plot(t, psth, color='k', lw=1.2)
axs[2].set_xlabel('time (s)')
axs[2].set_ylabel('spikes/bin')
plt.tight_layout(); plt.show()
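The PSTH panel above is in spikes per 10 ms bin. Below is a small sketch for converting it to spikes/s with optional boxcar smoothing, assuming resp is an (R, T) array of spike counts as in the cell; psth_rate is a hypothetical helper, and NaN entries (missing trials) are skipped via np.nanmean.

```python
import numpy as np


def psth_rate(resp, dt=0.01, smooth_bins=1):
    """Trial-averaged rate in spikes/s from an (R, T) array of spike counts.

    NaN entries are ignored; smooth_bins > 1 applies a centred boxcar
    (edges are zero-padded by np.convolve).
    """
    rate = np.nanmean(np.asarray(resp, dtype=float), axis=0) / dt  # spikes/bin -> spikes/s
    if smooth_bins > 1:
        kernel = np.ones(smooth_bins) / smooth_bins
        rate = np.convolve(rate, kernel, mode='same')
    return rate


resp_demo = np.array([[0, 1, 0, 2],
                      [0, 1, 0, 0]])
print(psth_rate(resp_demo))                 # mean counts [0, 1, 0, 1] divided by 10 ms
print(psth_rate(resp_demo, smooth_bins=3))  # same rate, boxcar-smoothed over 30 ms
```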
4. Train / val / test split
The test set is fixed by the paper convention (high-rep stims). The estimation set is split 90/10 at the stim level for train / val with a fixed seed.
[ ]:
S_est = len(ds_est)
rng = np.random.RandomState(0)
shuffled = list(range(S_est))
rng.shuffle(shuffled)
n_val = max(1, S_est // 10)
val_idx = sorted(shuffled[:n_val])
train_idx = sorted(shuffled[n_val:])
print(f'train stims: {len(train_idx)}')
print(f'val stims: {len(val_idx)}')
print(f'test stims: {len(ds_test)}')
BS = 32
train_loader = DataLoader(Subset(ds_est, train_idx), batch_size=BS,
                          shuffle=True, collate_fn=neural_collate)
val_loader = DataLoader(Subset(ds_est, val_idx), batch_size=BS,
                        shuffle=False, collate_fn=neural_collate)
test_loader = DataLoader(ds_test, batch_size=BS,
                         shuffle=False, collate_fn=neural_collate)
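The split above can be wrapped in a small function that makes the leakage guarantee explicit: because whole stims go to one side, no time bin of a validation stim is ever seen in training. stim_level_split is a hypothetical helper; it performs the same RandomState(0) shuffle as the cell, so for the same S_est it returns the same indices.

```python
import numpy as np


def stim_level_split(n_stims, seed=0):
    """Stim-level 90/10 train/val split, same recipe as the cell above (hypothetical helper)."""
    rng = np.random.RandomState(seed)
    shuffled = list(range(n_stims))
    rng.shuffle(shuffled)
    n_val = max(1, n_stims // 10)  # at least one validation stim
    return sorted(shuffled[n_val:]), sorted(shuffled[:n_val])


train_ids, val_ids = stim_level_split(57)
# Whole stims go to one side only, so no time bin of a val stim is trained on.
assert not set(train_ids) & set(val_ids)
assert sorted(train_ids + val_ids) == list(range(57))
print(len(train_ids), len(val_ids))  # 52 5
```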
5. Model — NRF
NetworkReceptiveField (Harper, Schoppe, Willmore, Cui, Schnupp & King 2016) is a two-layer STRF network: an STRF kernel projects the input spectrogram into H hidden units, then a per-neuron 1×1 readout produces the population output. With H=20 hidden units and N_AMT output neurons, one model fits all AMT cells jointly through a shared bottleneck.
[ ]:
model = NetworkReceptiveField(
n_frequency_bands=18,
temporal_window_size=15, # 150 ms history at dt=10 ms
n_hidden=20,
out_neurons=N_AMT,
)
print(model)
print(f'\nTrainable params: {model.count_trainable_params():,}')
print(f' per neuron: {model.count_trainable_params() / N_AMT:,.0f}')
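To make the architecture concrete, here is a minimal PyTorch sketch of a two-layer NRF-style network with a shared hidden layer and a 1×1 readout, as described above. TinyNRF is hypothetical and is not deepSTRF's NetworkReceptiveField: the sigmoid hidden nonlinearity, causal left-padding, and linear output are assumptions of this sketch, so its parameter count need not match model.count_trainable_params().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNRF(nn.Module):
    """Two-layer NRF-style network (hypothetical sketch, not deepSTRF's class).

    Input (B, 1, n_freq, T) spectrogram -> output (B, n_out, T) responses.
    """

    def __init__(self, n_freq=18, window=15, n_hidden=20, n_out=168):
        super().__init__()
        self.window = window
        # One STRF per hidden unit, spanning all bands and `window` past bins.
        self.strf = nn.Conv2d(1, n_hidden, kernel_size=(n_freq, window))
        # 1x1 readout: each output neuron weights the shared hidden units.
        self.readout = nn.Conv1d(n_hidden, n_out, kernel_size=1)

    def forward(self, x):
        # Left-pad time so the output at bin t sees only bins t-window+1 .. t.
        x = F.pad(x, (self.window - 1, 0))
        h = torch.sigmoid(self.strf(x)).squeeze(2)  # (B, n_hidden, T)
        return self.readout(h)                      # (B, n_out, T), linear output


net = TinyNRF()
x = torch.randn(2, 1, 18, 100)  # 2 clips, 18 bands, 100 bins (1 s at 10 ms)
with torch.no_grad():
    y = net(x)
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(tuple(y.shape), n_params)  # (2, 168, 100) 8948
```

Under these assumptions the hidden layer costs 18·15·20 + 20 = 5,420 parameters shared by all cells, while each output neuron adds only 21 readout parameters (20 weights + bias), which is what makes the shared bottleneck cheap per neuron.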
6. Train with the Fitter
Fitter wires the model, the loaders, an optimizer, and the canonical validation metrics (cc, cc_norm) into a single .fit() call. The default loss is the NaN-aware mse_loss against the PSTH computed automatically from the response trials. We monitor val_cc_norm (Schoppe noise-corrected correlation) and stop early once it has not improved for 15 consecutive epochs.
[ ]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
fitter = Fitter(
    model, train_loader, val_loader,
    optimizer=optimizer,
    device=device,
    max_epochs=80,
    patience=15,
    monitor='val_cc_norm',
    mode='max',
    log_fn=lambda d: print(
        f"epoch {d['epoch']:3d} "
        f"train_loss={d['train_loss']:.4f} "
        f"val_loss={d['val_loss']:.4f} "
        f"val_cc_norm={torch.nanmean(d['val_cc_norm']):+.3f}",
        flush=True,
    ),
)
history = fitter.fit()
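The Fitter's source is not shown in this notebook, so the following is only a sketch of the standard patience rule that monitor='val_cc_norm', mode='max', patience=15 describes. The EarlyStopping class is hypothetical, and details such as whether the best weights are restored at the end are not covered.

```python
import math


class EarlyStopping:
    """Patience-based early stopping on one monitored value (hypothetical helper)."""

    def __init__(self, patience=15, mode='max'):
        if mode not in ('max', 'min'):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.sign = 1.0 if mode == 'max' else -1.0  # flip losses so larger is better
        self.best = -math.inf
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, value):
        """Record one epoch's value; return True when training should stop."""
        score = self.sign * value
        if score > self.best:  # a NaN value never counts as an improvement
            self.best, self.best_epoch, self.bad_epochs = score, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3, mode='max')
for epoch, val_cc_norm in enumerate([0.10, 0.20, 0.25, 0.24, 0.23, 0.22, 0.30]):
    if stopper.step(epoch, val_cc_norm):
        break
print(epoch, stopper.best_epoch)  # 5 2
```

Note that the late improvement at epoch 6 is never seen: with patience=3 the run stops at epoch 5, which is the trade-off a larger patience (15 above) is meant to soften.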
7. Training curves
[ ]:
epochs = [h['epoch'] for h in history]
train_loss = [h['train_loss'] for h in history]
val_loss = [h['val_loss'] for h in history]
val_cc = [torch.nanmean(h['val_cc']).item() for h in history]
val_ccn = [torch.nanmean(h['val_cc_norm']).item() for h in history]
fig, axs = plt.subplots(1, 2, figsize=(10, 3.5))
axs[0].plot(epochs, train_loss, label='train', lw=1.5)
axs[0].plot(epochs, val_loss, label='val', lw=1.5)
axs[0].set_xlabel('epoch'); axs[0].set_ylabel('MSE loss'); axs[0].legend()
axs[0].set_title('Loss')
axs[1].plot(epochs, val_cc, label='val cc', lw=1.5)
axs[1].plot(epochs, val_ccn, label='val cc_norm', lw=1.5)
axs[1].set_xlabel('epoch'); axs[1].set_ylabel('mean across cells')
axs[1].legend(); axs[1].set_title('Correlations')
plt.tight_layout(); plt.show()
8. Test-set evaluation
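Fitter.evaluate runs the same pipeline as validation on the held-out test stims (predictions and targets are concatenated across batches, then the metrics are computed once) and returns the metric dict with un-prefixed keys (loss, cc, cc_norm rather than val_loss, val_cc, val_cc_norm).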
[ ]:
test_metrics = fitter.evaluate(test_loader)
test_cc = test_metrics['cc'].cpu()
test_ccn = test_metrics['cc_norm'].cpu()
print(f"test loss: {test_metrics['loss']:.4f}")
print(f'test cc mean={torch.nanmean(test_cc):+.3f} median={torch.nanmedian(test_cc):+.3f}')
print(f'test cc_norm mean={torch.nanmean(test_ccn):+.3f} median={torch.nanmedian(test_ccn):+.3f}')
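Why concatenate before computing? A toy NumPy sketch, assuming "concat-then-compute" means that predictions and targets from all batches are joined along time before the per-cell correlation is taken: the mean of per-batch correlations and the correlation on concatenated data are different quantities. pearson here is a hypothetical helper and the data are synthetic.

```python
import numpy as np


def pearson(a, b):
    """Pearson correlation of two 1-D arrays."""
    a = a - a.mean()
    b = b - b.mean()
    return float((a * b).sum() / np.sqrt((a ** 2).sum() * (b ** 2).sum()))


rng = np.random.default_rng(0)
# One cell, two batches of stims with very different mean rates (toy data).
target = [rng.normal(0.0, 1.0, 200), rng.normal(5.0, 1.0, 200)]
pred = [t + rng.normal(0.0, 1.0, 200) for t in target]

per_batch_mean = float(np.mean([pearson(p, t) for p, t in zip(pred, target)]))
concat_r = pearson(np.concatenate(pred), np.concatenate(target))
print(f'mean of per-batch r:    {per_batch_mean:.3f}')
print(f'r on concatenated bins: {concat_r:.3f}')
```

Here the concatenated r is higher because it also credits the model for tracking the rate difference between batches. The per-batch average instead depends on how stims happen to be grouped into batches, which is one reason to concatenate first.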
9. Per-cell cc_norm distribution
[ ]:
fig, ax = plt.subplots(figsize=(8, 3.5))
valid = ~test_ccn.isnan()
ax.hist(test_ccn[valid].numpy(), bins=30, edgecolor='black', alpha=0.85)
ax.axvline(torch.nanmean(test_ccn).item(), color='red', lw=2,
           label=f'mean = {torch.nanmean(test_ccn):.3f}')
ax.set_xlabel('test cc_norm (Schoppe)')
ax.set_ylabel('# AMT cells')
ax.set_title('Per-cell noise-corrected correlation on Espejo NAT')
ax.legend()
plt.tight_layout(); plt.show()
Notes
Estimation stims in Espejo NAT have 1–3 repetitions per (cell, stim), so the PSTH target is essentially a single trial. cc_norm (Schoppe) corrects for the resulting noise ceiling.
The subset='estimation' filter keeps the full estimation stim bank that any AMT site presented; some stims have valid responses for only a subset of the 168 AMT cells (the nrn_masks array is block-sparse across recording sites). The Fitter's NaN-aware loss handles this transparently.
With ~150 epochs (longer than the max_epochs=80 used above) and a few seeds, NRF on Espejo AMT typically reaches cc_norm in the 0.35–0.45 range, in line with the LN baselines reported in Lopez-Espejo et al. (their LN baseline is ~0.5 prediction correlation on a held-out stim; cc_norm metrics are stricter).
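A sketch of the Schoppe et al. (2016) normalized correlation in plain NumPy makes the noise-ceiling correction concrete. It assumes one cell's responses as an (R, T) array of trials and a (T,) prediction; cc_norm_schoppe is a hypothetical helper, not deepSTRF's normalized_corrcoef. The signal power SP = [Var(Σ_n r_n) − Σ_n Var(r_n)] / (R(R − 1)) is undefined for a single trial, so the sketch returns NaN there (and, as its own choice, when the estimated SP is not positive), and cc_norm = Cov(PSTH, prediction) / sqrt(SP · Var(prediction)).

```python
import numpy as np


def cc_norm_schoppe(resp, pred):
    """Normalized correlation (Schoppe et al. 2016), minimal sketch.

    resp: (R, T) single-trial responses of one cell.
    pred: (T,) model prediction for the same time bins.
    """
    resp = np.asarray(resp, dtype=float)
    pred = np.asarray(pred, dtype=float)
    n_trials = resp.shape[0]
    if n_trials < 2:
        return float('nan')  # signal power needs at least two trials
    # Signal power: [Var(sum of trials) - sum of per-trial variances] / (R (R - 1))
    sp = (np.var(resp.sum(axis=0)) - np.var(resp, axis=1).sum()) / (n_trials * (n_trials - 1))
    if sp <= 0:
        return float('nan')  # sketch's choice: no measurable signal
    psth = resp.mean(axis=0)
    cov = np.mean((psth - psth.mean()) * (pred - pred.mean()))
    return float(cov / np.sqrt(sp * np.var(pred)))


rng = np.random.default_rng(0)
signal = rng.standard_normal(5000)
resp = signal + rng.standard_normal((30, 5000))      # 30 noisy trials of one cell
print(cc_norm_schoppe(resp, signal))                 # close to 1 for a perfect prediction
print(np.corrcoef(resp.mean(axis=0), signal)[0, 1])  # raw cc of the PSTH, below 1
print(cc_norm_schoppe(resp[:1], signal))             # nan: a single trial
```

With 30 trials a perfect prediction scores close to 1 even though the raw PSTH correlation stays below 1; with the 1–3 repetitions of the estimation stims, the SP estimate is either undefined or itself noisy, which in this sketch is one way a per-cell cc_norm ends up NaN and is then skipped by nanmean.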