STRF gradmaps: 2D-CNN on CRCNS AA2 (OV / conspecific)


This notebook fits the 2D-CNN of Pennington & David (2023) on a slice of the CRCNS AA2 auditory dataset (zebra finch ovoidalis, OV — a thalamic nucleus relaying auditory information to forebrain Field L) and extracts a per-neuron STRF gradient map from the trained model — the canonical interpretable “what does this cell respond to?” figure for nonlinear receptive-field models.

The 2D-CNN doesn’t have an explicit STRF parameter to plot. Instead, we ask: at the neuron’s last-timestep prediction, what change to a null (all-zero) spectrogram most increases the response? Autodiff computes that gradient exactly, giving one (F, T) map per neuron. The deepSTRF base class implements this in AudioEncodingModel.STRF_gradmap(); the batch dimension parallelises the calculation across all N output neurons in a single forward / backward pass.
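To make the batching trick concrete, here is a minimal NumPy sketch on a toy linear-nonlinear population rather than ConvNet2D. It replaces autodiff with central finite differences (so it runs without PyTorch) and assumes a softplus output nonlinearity; toy_population and gradmap_fd are hypothetical names, not deepSTRF functions. Batch element n is a null spectrogram reserved for neuron n, and because batch elements never interact, a single sweep over the (F, T) entries produces all N maps at once.

```python
import numpy as np


def toy_population(x, W, b):
    """Toy LN population (hypothetical): every neuron reads every batch element.

    x: (B, F, T) spectrogram windows; W: (N, F, T) kernels; b: (N,) biases.
    Returns (B, N) last-timestep responses after a softplus nonlinearity.
    """
    drive = np.einsum("bft,nft->bn", x, W) + b   # linear stage, shape (B, N)
    return np.logaddexp(0.0, drive)              # softplus(z) = log(1 + e^z)


def gradmap_fd(W, b, eps=1e-4):
    """Finite-difference gradmaps g_n = d(-y_n)/dx at x = 0, shape (N, F, T).

    Batch element n is a null spectrogram reserved for neuron n. Batch elements
    never interact, so perturbing entry (f, t) of all of them at once yields
    that entry of every neuron's map from just two model evaluations.
    """
    N, F, T = W.shape
    x0 = np.zeros((N, F, T))          # one null spectrogram per neuron
    diag = np.arange(N)
    g = np.zeros((N, F, T))
    for f in range(F):
        for t in range(T):
            dx = np.zeros((F, T))
            dx[f, t] = eps
            # Loss L_n = -y_n, read from batch element n only (the diagonal).
            loss_plus = -toy_population(x0 + dx, W, b)[diag, diag]
            loss_minus = -toy_population(x0 - dx, W, b)[diag, diag]
            g[:, f, t] = (loss_plus - loss_minus) / (2 * eps)
    return g


rng = np.random.default_rng(0)
W = rng.normal(size=(5, 4, 6))       # 5 toy neurons, 4 bands, 6 frames
b = rng.normal(size=5)
g = gradmap_fd(W, b)

# For an LN neuron, -g_n is its kernel times softplus'(b_n) = sigmoid(b_n).
expected = W / (1.0 + np.exp(-b))[:, None, None]
print("max |(-g) - expected| =", np.abs(-g - expected).max())
```

The final check shows the sign convention used later in this notebook: the returned map is the gradient of the negated response, so its negation is the excitatory-positive picture. Autodiff performs the same per-neuron bookkeeping in one backward pass instead of 2·F·T forward passes.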

Scope: 59 OV cells × 20 conspecific stimuli at dt_ms=5. This subset has 100% valid (s, n) coverage, so no NaN-aware bookkeeping is exercised here. Running the same notebook on the full 494-cell, 117-stim AA2 dataset would be the next step (see “What’s next”).
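When coverage is below 100%, as on the full AA2 set, scores for unrecorded (s, n) pairs must be masked out before averaging. Below is a minimal NumPy sketch of that bookkeeping with made-up numbers. The mask only plays a role similar to ds.nrn_masks, and deepSTRF’s real per-neuron metrics may be computed differently (for example on concatenated responses rather than per stimulus), so treat this as a simplified illustration.

```python
import numpy as np

# Hypothetical validity mask: valid[s, n] is True when neuron n was recorded
# for stimulus s.
valid = np.array([[True, True, False],
                  [True, False, False],
                  [True, True, True]])
print(f"valid (s, n) pairs: {100 * valid.mean():.1f}%")

# Made-up per-(stimulus, neuron) scores; unrecorded pairs become NaN.
raw_cc = np.array([[0.5, 0.4, 0.9],
                   [0.7, 0.8, 0.9],
                   [0.3, 0.6, 0.2]])
cc = np.where(valid, raw_cc, np.nan)

# Average over stimuli per neuron, ignoring missing pairs...
per_neuron = np.nanmean(cc, axis=0)
# ...then over neurons, skipping any neuron whose score is NaN.
population = np.nanmean(per_neuron)
print(per_neuron, population)
```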

Setup — Google Colab

If you’re running on Google Colab, the cell below installs deepSTRF from source on GitHub. On a local install (pip install -e .), the cell does nothing.

Note on data: AA2 is an authenticated CRCNS dataset. To auto-download it, set $CRCNS_USERNAME and $CRCNS_PASSWORD (free account at https://crcns.org/). On a local machine that already has the data extracted, it’s picked up from the cache automatically.
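A small pre-flight check can save a failed download. This sketch assumes only what is stated above, namely that the auto-download reads the two environment variables; crcns_credentials_available is a hypothetical helper, not part of deepSTRF.

```python
import os


def crcns_credentials_available(env=None):
    """Return True if both CRCNS credential variables are set and non-empty.

    Hypothetical convenience check; pass a dict as `env` to test it without
    touching the real environment.
    """
    env = os.environ if env is None else env
    return all(env.get(key) for key in ("CRCNS_USERNAME", "CRCNS_PASSWORD"))


if not crcns_credentials_available():
    print("Set CRCNS_USERNAME and CRCNS_PASSWORD before downloading AA2, "
          "or make sure the extracted data is already in the local cache.")
```

On Colab, one way to set the password without echoing it is os.environ["CRCNS_PASSWORD"] = getpass.getpass() (after import getpass), run before the dataset is constructed.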

[1]:
import sys
if 'google.colab' in sys.modules:
    !pip install -q git+https://github.com/urancon/deepSTRF.git
    print("deepSTRF installed from GitHub.")
else:
    print("Local environment — assuming deepSTRF is already importable.")

Local environment — assuming deepSTRF is already importable.

Imports

[2]:
%matplotlib inline
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

from deepSTRF.datasets.audio.crcns_aa2 import CRCNSAA2Dataset
from deepSTRF.models.audio import ConvNet2D
from deepSTRF.training import Fitter, set_random_seed
from deepSTRF.utils import neural_collate, plot_strf_grid

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cuda

1. Load CRCNS AA2 — OV cells, conspecific stimuli only

We narrow the dataset at construction time via the areas= and stimuli= arguments. A separate notebook covers the select_*_by_attr runtime filtering API and the bidirectional rule from data_paradigm.md §8; here we just want a clean small subset to showcase gradmaps.

[3]:
ds = CRCNSAA2Dataset(download=True, dt_ms=5, areas=('OV',), stimuli=('conspecific',))
print(f"OV + conspecific: N={ds.N_neurons} cells, S={len(ds.stims)} stims, "
      f"F={ds.F} bands, dt={ds.dt} ms")

T_min = min(s.shape[-1] for s in ds.stims)
T_max = max(s.shape[-1] for s in ds.stims)
print(f"stim T range: {T_min}..{T_max} frames "
      f"({T_min*ds.dt/1000:.2f} s .. {T_max*ds.dt/1000:.2f} s)")

valid_frac = ds.nrn_masks.float().mean().item()
print(f"valid (s, n) pairs: {100*valid_frac:.1f}%")

OV + conspecific: N=59 cells, S=20 stims, F=32 bands, dt=5 ms
stim T range: 234..511 frames (1.17 s .. 2.56 s)
valid (s, n) pairs: 100.0%

2. Train / val / test split + per-band standardisation

A seeded random 14 / 3 / 3 train / val / test split, all conspecific. Per-band standardisation statistics are computed on train+val only and then applied to all stimuli, so the held-out test stimuli do not influence the mean and std but are transformed with them.
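The statistics step can be sketched in NumPy as follows. standardize_per_band is a hypothetical re-implementation for illustration, not the code behind ds.standardize_stims. In particular, the eps guard is an assumption, added because the output below reports per-band std values as small as 0.001. Frames from all fitting stimuli are pooled per band, so longer stimuli carry more weight.

```python
import numpy as np


def standardize_per_band(stims, fit_idx, eps=1e-8):
    """Hypothetical per-band standardisation with statistics from fit_idx only.

    stims: list of (F, T_s) arrays with varying T_s.
    fit_idx: indices whose frames define the statistics (train + val).
    Returns the transformed list plus per-band mean and std, each of shape (F,).
    """
    frames = np.concatenate([stims[i] for i in fit_idx], axis=1)  # (F, sum T_s)
    mean = frames.mean(axis=1)
    std = frames.std(axis=1)
    # eps keeps near-constant bands from blowing up (assumed, see lead-in).
    out = [(s - mean[:, None]) / (std[:, None] + eps) for s in stims]
    return out, mean, std


rng = np.random.default_rng(0)
stims = [rng.gamma(2.0, size=(32, T)) for T in (234, 300, 511, 400)]
out, mean, std = standardize_per_band(stims, fit_idx=[0, 1, 2])  # stim 3 held out

fit_frames = np.concatenate([out[i] for i in (0, 1, 2)], axis=1)
print(mean.shape, std.min())
```

Only the pooled fitting frames end up with exactly zero mean and unit std per band; the held-out stimulus is shifted and scaled by the same numbers but is not itself exactly standardised, which is the point of excluding it.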

[4]:
set_random_seed(0)
S = len(ds)
g = torch.Generator().manual_seed(0)
perm = torch.randperm(S, generator=g).tolist()
train_idx = perm[:14]
val_idx   = perm[14:17]
test_idx  = perm[17:20]
print(f"split: train {len(train_idx)} | val {len(val_idx)} | test {len(test_idx)}")

stats = ds.standardize_stims(stim_indices=train_idx + val_idx, per_band=True)
print(f"per-band mean range: [{stats['mean'].min():.3f}, {stats['mean'].max():.3f}]")
print(f"per-band std range:  [{stats['std'].min():.3f}, {stats['std'].max():.3f}]")

split: train 14 | val 3 | test 3
per-band mean range: [0.007, 1.557]
per-band std range:  [0.001, 1.433]

3. Build the 2D-CNN and fit

ConvNet2D is the deepSTRF port of Pennington & David’s 2D-CNN: three (Conv2d → CausalLayerNorm → LeakyReLU) blocks over the spectrogram, followed by a 2-layer fully connected readout to the population. The default temporal_window_size = 3·(K_T − 1) = 24 frames (with temporal kernel size K_T = 9, i.e. 120 ms at our 5 ms resolution) sets the gradmap window we’ll plot in section 5.
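The window arithmetic, written out as a tiny helper. It simply evaluates the default formula quoted above, taking K_T as the second entry of kernel_size=(3, 9); gradmap_window is a hypothetical name, not a deepSTRF function.

```python
def gradmap_window(n_conv_layers: int, k_t: int, dt_ms: float):
    """Default gradmap window n_layers * (K_T - 1), as (frames, milliseconds)."""
    frames = n_conv_layers * (k_t - 1)
    return frames, frames * dt_ms


print(gradmap_window(n_conv_layers=3, k_t=9, dt_ms=5))
```

If the default is derived the same way for other settings, kernel_size=(3, 5) would give a 12-frame (60 ms) window at the same resolution.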

Hyperparameters:

  • batch_size=1 for training: with only 14 train stims, large batches approach full-batch gradient descent (a batch of 14 or more gives a single step per epoch) and lose SGD noise. On a bigger AA2 subset (more stims) you can dial this up; on this particular GPU, up to ~32.

  • patience=50 on val cc_norm — the demo set is tiny so each epoch is quick and we can afford a long patience.

  • track_train_metrics=False: skip the per-epoch recomputation of cc / cc_norm on the training accumulator. It is not useful at this scale, and skipping it saves time on every epoch.
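Patience-based early stopping on a maximised metric can be sketched as follows. This is a generic version, assumed rather than taken from Fitter (whose tie-breaking and best-weight handling are not shown in this notebook), and early_stop_epoch is a hypothetical helper. Applied to the val_cc_norm column of the training log that follows, it gives best epoch 81 (0.6766) and a stop at epoch 131, consistent with the 132 logged epochs.

```python
def early_stop_epoch(scores, patience):
    """Return (best_epoch, stop_epoch) for 'max'-mode early stopping.

    Training halts once `patience` epochs have passed without a strict
    improvement over the best score so far.
    """
    best_epoch, best = 0, float("-inf")
    for epoch, score in enumerate(scores):
        if score > best:
            best_epoch, best = epoch, score
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(scores) - 1   # ran out of epochs before stopping


# Toy curve: peaks at epoch 1, then 3 epochs without improvement.
print(early_stop_epoch([0.1, 0.3, 0.2, 0.25, 0.28, 0.29], patience=3))
```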

[5]:
N = ds.N_neurons
train_loader = DataLoader(Subset(ds, train_idx), batch_size=1,
                          shuffle=True, collate_fn=neural_collate)
val_loader   = DataLoader(Subset(ds, val_idx),   batch_size=1,
                          shuffle=False, collate_fn=neural_collate)
test_loader  = DataLoader(Subset(ds, test_idx),  batch_size=1,
                          shuffle=False, collate_fn=neural_collate)

set_random_seed(0)
model = ConvNet2D(
    n_frequency_bands=ds.F, kernel_size=(3, 9),
    c_hidden=10, n_hidden=20, out_neurons=N,
)
print(f"ConvNet2D params: {model.count_trainable_params():,}, "
      f"gradmap window: T={model.T} frames ({model.T * ds.dt} ms)")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)
fitter = Fitter(
    model, train_loader, val_loader,
    optimizer=optimizer, device=device,
    max_epochs=2000, patience=50,
    monitor="val_cc_norm", mode="max",
    track_train_metrics=False,
    # default log_fn prints per-epoch lines
)

t0 = time.time()
history = fitter.fit()
elapsed = time.time() - t0
print(f"\nfit completed: {len(history)} epochs in {elapsed/60:.1f} min")

ConvNet2D params: 12,401, gradmap window: T=24 frames (120 ms)
epoch    0 | train_loss=0.0554 | val_loss=0.0456 | val_cc=0.0113 | val_cc_norm=0.0055
epoch    1 | train_loss=0.0450 | val_loss=0.0388 | val_cc=0.0360 | val_cc_norm=0.0248
epoch    2 | train_loss=0.0385 | val_loss=0.0352 | val_cc=0.0702 | val_cc_norm=0.0615
epoch    3 | train_loss=0.0351 | val_loss=0.0329 | val_cc=0.1782 | val_cc_norm=0.1957
epoch    4 | train_loss=0.0328 | val_loss=0.0314 | val_cc=0.2422 | val_cc_norm=0.2773
epoch    5 | train_loss=0.0314 | val_loss=0.0303 | val_cc=0.2746 | val_cc_norm=0.3216
epoch    6 | train_loss=0.0299 | val_loss=0.0288 | val_cc=0.3145 | val_cc_norm=0.3786
epoch    7 | train_loss=0.0286 | val_loss=0.0280 | val_cc=0.3576 | val_cc_norm=0.4388
epoch    8 | train_loss=0.0278 | val_loss=0.0276 | val_cc=0.3767 | val_cc_norm=0.4663
epoch    9 | train_loss=0.0273 | val_loss=0.0270 | val_cc=0.3874 | val_cc_norm=0.4808
epoch   10 | train_loss=0.0271 | val_loss=0.0274 | val_cc=0.3989 | val_cc_norm=0.4969
epoch   11 | train_loss=0.0269 | val_loss=0.0269 | val_cc=0.4001 | val_cc_norm=0.5006
epoch   12 | train_loss=0.0268 | val_loss=0.0262 | val_cc=0.4083 | val_cc_norm=0.5127
epoch   13 | train_loss=0.0264 | val_loss=0.0258 | val_cc=0.4225 | val_cc_norm=0.5291
epoch   14 | train_loss=0.0264 | val_loss=0.0257 | val_cc=0.4285 | val_cc_norm=0.5398
epoch   15 | train_loss=0.0255 | val_loss=0.0253 | val_cc=0.4345 | val_cc_norm=0.5471
epoch   16 | train_loss=0.0251 | val_loss=0.0248 | val_cc=0.4418 | val_cc_norm=0.5560
epoch   17 | train_loss=0.0247 | val_loss=0.0252 | val_cc=0.4480 | val_cc_norm=0.5643
epoch   18 | train_loss=0.0245 | val_loss=0.0247 | val_cc=0.4579 | val_cc_norm=0.5766
epoch   19 | train_loss=0.0245 | val_loss=0.0242 | val_cc=0.4589 | val_cc_norm=0.5792
epoch   20 | train_loss=0.0243 | val_loss=0.0240 | val_cc=0.4653 | val_cc_norm=0.5872
epoch   21 | train_loss=0.0239 | val_loss=0.0249 | val_cc=0.4637 | val_cc_norm=0.5853
epoch   22 | train_loss=0.0235 | val_loss=0.0235 | val_cc=0.4756 | val_cc_norm=0.6003
epoch   23 | train_loss=0.0236 | val_loss=0.0247 | val_cc=0.4789 | val_cc_norm=0.6048
epoch   24 | train_loss=0.0234 | val_loss=0.0238 | val_cc=0.4794 | val_cc_norm=0.6061
epoch   25 | train_loss=0.0231 | val_loss=0.0240 | val_cc=0.4835 | val_cc_norm=0.6111
epoch   26 | train_loss=0.0231 | val_loss=0.0252 | val_cc=0.4706 | val_cc_norm=0.5954
epoch   27 | train_loss=0.0233 | val_loss=0.0246 | val_cc=0.4768 | val_cc_norm=0.6035
epoch   28 | train_loss=0.0231 | val_loss=0.0240 | val_cc=0.4882 | val_cc_norm=0.6179
epoch   29 | train_loss=0.0228 | val_loss=0.0233 | val_cc=0.4898 | val_cc_norm=0.6204
epoch   30 | train_loss=0.0225 | val_loss=0.0244 | val_cc=0.4885 | val_cc_norm=0.6191
epoch   31 | train_loss=0.0223 | val_loss=0.0235 | val_cc=0.4887 | val_cc_norm=0.6197
epoch   32 | train_loss=0.0225 | val_loss=0.0233 | val_cc=0.4909 | val_cc_norm=0.6230
epoch   33 | train_loss=0.0222 | val_loss=0.0230 | val_cc=0.4942 | val_cc_norm=0.6272
epoch   34 | train_loss=0.0219 | val_loss=0.0234 | val_cc=0.4888 | val_cc_norm=0.6206
epoch   35 | train_loss=0.0221 | val_loss=0.0234 | val_cc=0.4882 | val_cc_norm=0.6206
epoch   36 | train_loss=0.0220 | val_loss=0.0231 | val_cc=0.4929 | val_cc_norm=0.6270
epoch   37 | train_loss=0.0217 | val_loss=0.0231 | val_cc=0.4993 | val_cc_norm=0.6340
epoch   38 | train_loss=0.0216 | val_loss=0.0229 | val_cc=0.5045 | val_cc_norm=0.6422
epoch   39 | train_loss=0.0215 | val_loss=0.0232 | val_cc=0.4990 | val_cc_norm=0.6357
epoch   40 | train_loss=0.0213 | val_loss=0.0224 | val_cc=0.5076 | val_cc_norm=0.6465
epoch   41 | train_loss=0.0213 | val_loss=0.0221 | val_cc=0.5104 | val_cc_norm=0.6505
epoch   42 | train_loss=0.0211 | val_loss=0.0222 | val_cc=0.5133 | val_cc_norm=0.6543
epoch   43 | train_loss=0.0212 | val_loss=0.0226 | val_cc=0.5038 | val_cc_norm=0.6425
epoch   44 | train_loss=0.0208 | val_loss=0.0228 | val_cc=0.5093 | val_cc_norm=0.6499
epoch   45 | train_loss=0.0208 | val_loss=0.0239 | val_cc=0.5044 | val_cc_norm=0.6431
epoch   46 | train_loss=0.0211 | val_loss=0.0225 | val_cc=0.5121 | val_cc_norm=0.6528
epoch   47 | train_loss=0.0210 | val_loss=0.0220 | val_cc=0.5161 | val_cc_norm=0.6585
epoch   48 | train_loss=0.0209 | val_loss=0.0232 | val_cc=0.5103 | val_cc_norm=0.6509
epoch   49 | train_loss=0.0205 | val_loss=0.0222 | val_cc=0.5137 | val_cc_norm=0.6553
epoch   50 | train_loss=0.0207 | val_loss=0.0231 | val_cc=0.5099 | val_cc_norm=0.6503
epoch   51 | train_loss=0.0203 | val_loss=0.0225 | val_cc=0.5163 | val_cc_norm=0.6593
epoch   52 | train_loss=0.0203 | val_loss=0.0219 | val_cc=0.5173 | val_cc_norm=0.6607
epoch   53 | train_loss=0.0201 | val_loss=0.0222 | val_cc=0.5179 | val_cc_norm=0.6615
epoch   54 | train_loss=0.0200 | val_loss=0.0218 | val_cc=0.5195 | val_cc_norm=0.6635
epoch   55 | train_loss=0.0198 | val_loss=0.0231 | val_cc=0.5113 | val_cc_norm=0.6529
epoch   56 | train_loss=0.0198 | val_loss=0.0219 | val_cc=0.5205 | val_cc_norm=0.6650
epoch   57 | train_loss=0.0199 | val_loss=0.0224 | val_cc=0.5208 | val_cc_norm=0.6655
epoch   58 | train_loss=0.0198 | val_loss=0.0221 | val_cc=0.5236 | val_cc_norm=0.6693
epoch   59 | train_loss=0.0197 | val_loss=0.0217 | val_cc=0.5209 | val_cc_norm=0.6661
epoch   60 | train_loss=0.0198 | val_loss=0.0219 | val_cc=0.5186 | val_cc_norm=0.6628
epoch   61 | train_loss=0.0198 | val_loss=0.0219 | val_cc=0.5222 | val_cc_norm=0.6678
epoch   62 | train_loss=0.0196 | val_loss=0.0231 | val_cc=0.5086 | val_cc_norm=0.6500
epoch   63 | train_loss=0.0196 | val_loss=0.0223 | val_cc=0.5221 | val_cc_norm=0.6670
epoch   64 | train_loss=0.0195 | val_loss=0.0216 | val_cc=0.5263 | val_cc_norm=0.6730
epoch   65 | train_loss=0.0193 | val_loss=0.0218 | val_cc=0.5245 | val_cc_norm=0.6709
epoch   66 | train_loss=0.0194 | val_loss=0.0225 | val_cc=0.5199 | val_cc_norm=0.6647
epoch   67 | train_loss=0.0194 | val_loss=0.0226 | val_cc=0.5181 | val_cc_norm=0.6625
epoch   68 | train_loss=0.0192 | val_loss=0.0223 | val_cc=0.5268 | val_cc_norm=0.6735
epoch   69 | train_loss=0.0191 | val_loss=0.0220 | val_cc=0.5247 | val_cc_norm=0.6711
epoch   70 | train_loss=0.0189 | val_loss=0.0225 | val_cc=0.5242 | val_cc_norm=0.6706
epoch   71 | train_loss=0.0188 | val_loss=0.0223 | val_cc=0.5169 | val_cc_norm=0.6608
epoch   72 | train_loss=0.0188 | val_loss=0.0217 | val_cc=0.5240 | val_cc_norm=0.6702
epoch   73 | train_loss=0.0188 | val_loss=0.0222 | val_cc=0.5252 | val_cc_norm=0.6714
epoch   74 | train_loss=0.0185 | val_loss=0.0227 | val_cc=0.5195 | val_cc_norm=0.6641
epoch   75 | train_loss=0.0185 | val_loss=0.0220 | val_cc=0.5242 | val_cc_norm=0.6701
epoch   76 | train_loss=0.0184 | val_loss=0.0224 | val_cc=0.5198 | val_cc_norm=0.6643
epoch   77 | train_loss=0.0183 | val_loss=0.0218 | val_cc=0.5263 | val_cc_norm=0.6731
epoch   78 | train_loss=0.0183 | val_loss=0.0224 | val_cc=0.5164 | val_cc_norm=0.6599
epoch   79 | train_loss=0.0182 | val_loss=0.0221 | val_cc=0.5209 | val_cc_norm=0.6658
epoch   80 | train_loss=0.0183 | val_loss=0.0234 | val_cc=0.5122 | val_cc_norm=0.6545
epoch   81 | train_loss=0.0182 | val_loss=0.0216 | val_cc=0.5288 | val_cc_norm=0.6766
epoch   82 | train_loss=0.0181 | val_loss=0.0220 | val_cc=0.5198 | val_cc_norm=0.6645
epoch   83 | train_loss=0.0182 | val_loss=0.0220 | val_cc=0.5241 | val_cc_norm=0.6703
epoch   84 | train_loss=0.0180 | val_loss=0.0225 | val_cc=0.5182 | val_cc_norm=0.6625
epoch   85 | train_loss=0.0180 | val_loss=0.0219 | val_cc=0.5263 | val_cc_norm=0.6733
epoch   86 | train_loss=0.0179 | val_loss=0.0222 | val_cc=0.5186 | val_cc_norm=0.6633
epoch   87 | train_loss=0.0179 | val_loss=0.0218 | val_cc=0.5239 | val_cc_norm=0.6701
epoch   88 | train_loss=0.0178 | val_loss=0.0225 | val_cc=0.5215 | val_cc_norm=0.6664
epoch   89 | train_loss=0.0177 | val_loss=0.0227 | val_cc=0.5176 | val_cc_norm=0.6622
epoch   90 | train_loss=0.0178 | val_loss=0.0224 | val_cc=0.5188 | val_cc_norm=0.6635
epoch   91 | train_loss=0.0176 | val_loss=0.0227 | val_cc=0.5179 | val_cc_norm=0.6622
epoch   92 | train_loss=0.0175 | val_loss=0.0220 | val_cc=0.5217 | val_cc_norm=0.6672
epoch   93 | train_loss=0.0175 | val_loss=0.0228 | val_cc=0.5224 | val_cc_norm=0.6683
epoch   94 | train_loss=0.0175 | val_loss=0.0228 | val_cc=0.5158 | val_cc_norm=0.6594
epoch   95 | train_loss=0.0173 | val_loss=0.0221 | val_cc=0.5232 | val_cc_norm=0.6694
epoch   96 | train_loss=0.0173 | val_loss=0.0223 | val_cc=0.5193 | val_cc_norm=0.6643
epoch   97 | train_loss=0.0173 | val_loss=0.0224 | val_cc=0.5180 | val_cc_norm=0.6628
epoch   98 | train_loss=0.0172 | val_loss=0.0228 | val_cc=0.5176 | val_cc_norm=0.6616
epoch   99 | train_loss=0.0173 | val_loss=0.0220 | val_cc=0.5239 | val_cc_norm=0.6702
epoch  100 | train_loss=0.0172 | val_loss=0.0223 | val_cc=0.5184 | val_cc_norm=0.6633
epoch  101 | train_loss=0.0172 | val_loss=0.0229 | val_cc=0.5195 | val_cc_norm=0.6642
epoch  102 | train_loss=0.0171 | val_loss=0.0227 | val_cc=0.5163 | val_cc_norm=0.6603
epoch  103 | train_loss=0.0171 | val_loss=0.0224 | val_cc=0.5181 | val_cc_norm=0.6628
epoch  104 | train_loss=0.0170 | val_loss=0.0223 | val_cc=0.5165 | val_cc_norm=0.6605
epoch  105 | train_loss=0.0171 | val_loss=0.0229 | val_cc=0.5151 | val_cc_norm=0.6582
epoch  106 | train_loss=0.0169 | val_loss=0.0225 | val_cc=0.5195 | val_cc_norm=0.6644
epoch  107 | train_loss=0.0169 | val_loss=0.0221 | val_cc=0.5199 | val_cc_norm=0.6651
epoch  108 | train_loss=0.0169 | val_loss=0.0231 | val_cc=0.5175 | val_cc_norm=0.6620
epoch  109 | train_loss=0.0169 | val_loss=0.0231 | val_cc=0.5116 | val_cc_norm=0.6543
epoch  110 | train_loss=0.0171 | val_loss=0.0225 | val_cc=0.5161 | val_cc_norm=0.6604
epoch  111 | train_loss=0.0170 | val_loss=0.0232 | val_cc=0.5139 | val_cc_norm=0.6574
epoch  112 | train_loss=0.0168 | val_loss=0.0225 | val_cc=0.5193 | val_cc_norm=0.6644
epoch  113 | train_loss=0.0167 | val_loss=0.0223 | val_cc=0.5220 | val_cc_norm=0.6679
epoch  114 | train_loss=0.0166 | val_loss=0.0227 | val_cc=0.5168 | val_cc_norm=0.6609
epoch  115 | train_loss=0.0167 | val_loss=0.0225 | val_cc=0.5204 | val_cc_norm=0.6659
epoch  116 | train_loss=0.0166 | val_loss=0.0222 | val_cc=0.5256 | val_cc_norm=0.6727
epoch  117 | train_loss=0.0166 | val_loss=0.0231 | val_cc=0.5200 | val_cc_norm=0.6653
epoch  118 | train_loss=0.0165 | val_loss=0.0226 | val_cc=0.5158 | val_cc_norm=0.6595
epoch  119 | train_loss=0.0165 | val_loss=0.0229 | val_cc=0.5165 | val_cc_norm=0.6607
epoch  120 | train_loss=0.0166 | val_loss=0.0228 | val_cc=0.5185 | val_cc_norm=0.6631
epoch  121 | train_loss=0.0165 | val_loss=0.0223 | val_cc=0.5216 | val_cc_norm=0.6675
epoch  122 | train_loss=0.0165 | val_loss=0.0232 | val_cc=0.5208 | val_cc_norm=0.6667
epoch  123 | train_loss=0.0164 | val_loss=0.0229 | val_cc=0.5192 | val_cc_norm=0.6641
epoch  124 | train_loss=0.0163 | val_loss=0.0225 | val_cc=0.5250 | val_cc_norm=0.6720
epoch  125 | train_loss=0.0164 | val_loss=0.0238 | val_cc=0.5144 | val_cc_norm=0.6582
epoch  126 | train_loss=0.0165 | val_loss=0.0231 | val_cc=0.5100 | val_cc_norm=0.6519
epoch  127 | train_loss=0.0165 | val_loss=0.0226 | val_cc=0.5192 | val_cc_norm=0.6642
epoch  128 | train_loss=0.0164 | val_loss=0.0230 | val_cc=0.5174 | val_cc_norm=0.6618
epoch  129 | train_loss=0.0162 | val_loss=0.0230 | val_cc=0.5169 | val_cc_norm=0.6613
epoch  130 | train_loss=0.0161 | val_loss=0.0226 | val_cc=0.5217 | val_cc_norm=0.6681
epoch  131 | train_loss=0.0161 | val_loss=0.0230 | val_cc=0.5150 | val_cc_norm=0.6587

fit completed: 132 epochs in 0.7 min

4. Test-set summary

[6]:
test = fitter.evaluate(test_loader)
test_cc_norm = test["cc_norm"].cpu()
test_cc      = test["cc"].cpu()
print(f"test cc:      {torch.nanmean(test_cc):+.3f}")
print(f"test cc_norm: {torch.nanmean(test_cc_norm):+.3f}")

# Training curve
fig, ax = plt.subplots(figsize=(8, 3.5))
ep = [h["epoch"] for h in history]
vC = [torch.nanmean(h["val_cc_norm"]).item() for h in history]
ax.plot(ep, vC, lw=1.6, color='tab:purple')
ax.set_xlabel("epoch"); ax.set_ylabel("val cc_norm (mean over cells)")
ax.set_title("ConvNet2D on AA2 OV / conspecific — training curve")
plt.tight_layout(); plt.show()

test cc:      +0.497
test cc_norm: +0.653
[Figure: training curve, val cc_norm (mean over cells) vs. epoch]

5. STRF gradmaps for the best-predicted cells

model.STRF_gradmap() returns a (N, 1, F, T) tensor: for each output neuron n, the gradient \(\mathbf{g}_n = \partial \mathcal{L}_n / \partial \mathbf{x}_0\) of the loss \(\mathcal{L}_n\), defined as the negative last-timestep activity of neuron n, with respect to a null (all-zero) spectrogram \(\mathbf{x}_0\). Under this sign convention, positive entries are inhibitory features: adding mass there would decrease the neuron’s response (see README_gradmap_strf.md §4 for the full derivation). To get the more intuitive excitatory-as-positive plot (red = features the cell prefers), we plot \(-\mathbf{g}_n\), i.e. we negate the returned gradmaps before passing them to plot_strf_grid.
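The sign convention follows from a first-order expansion around the null input. Writing \(y_n(\mathbf{x}_0)\) for neuron n’s last-timestep prediction, so that \(\mathcal{L}_n = -y_n(\mathbf{x}_0)\):

\[
\mathbf{g}_n = \left.\frac{\partial \mathcal{L}_n}{\partial \mathbf{x}_0}\right|_{\mathbf{x}_0=\mathbf{0}}
= -\left.\frac{\partial y_n}{\partial \mathbf{x}_0}\right|_{\mathbf{x}_0=\mathbf{0}},
\qquad
y_n(\boldsymbol{\delta}) \approx y_n(\mathbf{0}) + \langle -\mathbf{g}_n, \boldsymbol{\delta} \rangle ,
\]

where \(\boldsymbol{\delta}\) is a small (F, T) perturbation and \(\langle \cdot, \cdot \rangle\) sums over all bands and frames. Adding mass where \(-\mathbf{g}_n > 0\) therefore raises the predicted response to first order, which is why the gradmaps are plotted as \(-\mathbf{g}_n\).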

This is the closest thing a nonlinear model can give you to a classical reverse-correlation STRF, and it’s the right diagnostic for asking “did the network learn something biologically sensible?” — particularly when the model has no explicit STRF layer.

We pick the 12 best-predicted OV cells (highest test cc_norm). Time runs left to right, in milliseconds of stimulus history; frequency runs bottom to top (low → high band). We use the diverging colormap RdBu_r with a per-cell symmetric colour scale, so the spectrotemporal structure stays readable across cells whose gradient magnitudes differ.
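A per-cell symmetric scale can be computed as in the hypothetical NumPy helper below. plot_strf_grid’s internals are not shown in this notebook, so this illustrates the idea rather than its code. Centring each cell’s range on zero keeps white at exactly zero gradient, so red and blue always mean opposite signs.

```python
import numpy as np


def symmetric_limits(maps):
    """Per-cell symmetric colour limits for a diverging colormap.

    maps: (N, F, T). Returns vmax of shape (N,); cell n is drawn with
    vmin=-vmax[n] and vmax=vmax[n], so zero always sits at the colormap centre.
    """
    vmax = np.abs(maps).reshape(len(maps), -1).max(axis=1)
    return np.where(vmax > 0, vmax, 1.0)   # avoid a degenerate 0..0 range


maps = np.array([[[0.0, 2.0], [-1.0, 0.5]],
                 [[0.01, -0.03], [0.02, 0.0]]])
vmax = symmetric_limits(maps)
scaled = maps / vmax[:, None, None]        # each cell now spans at most [-1, 1]
print(vmax)
```

Each map can then be drawn with matplotlib’s imshow(maps[n], cmap="RdBu_r", vmin=-vmax[n], vmax=vmax[n], origin="lower"), where origin="lower" puts the lowest frequency band at the bottom.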

[7]:
model.eval()
gradmaps = model.STRF_gradmap()                        # (N, 1, F, T)
gradmaps = gradmaps.squeeze(1).detach().cpu().numpy()  # (N, F, T)
print(f"gradmaps shape: {gradmaps.shape}")

K = 12
test_arr = test_cc_norm.numpy()
order = np.argsort(np.where(np.isnan(test_arr), -np.inf, test_arr))[::-1]
top_idx = order[:K]
print(f"top {K} cells, cc_norm range: "
      f"{test_arr[top_idx[0]]:.3f} ... {test_arr[top_idx[-1]]:.3f}")

# Negate so that red (positive) = excitatory under RdBu_r — see the
# markdown above and README_gradmap_strf.md §4 for the sign convention.
titles = [f"{ds.nrn_meta[n]['cell_id'][:14]} ({ds.nrn_meta[n]['area']})\n"
          f"cc_norm={test_arr[n]:.3f}"
          for n in top_idx]
plot_strf_grid(
    -gradmaps[top_idx], titles=titles, dt_ms=ds.dt, ncols=4,
    figsize=(4 * 3.2, 3 * 2.6),
)
plt.show()
gradmaps shape: (59, 32, 24)
top 12 cells, cc_norm range: 0.967 ... 0.792
[Figure: negated STRF gradmaps (red = excitatory) for the 12 best-predicted OV cells, 3 × 4 grid, each titled with cell ID, area, and test cc_norm]

What’s next

  • Scale up to the full AA2 population. Drop areas= and stimuli= from the dataset constructor to fit all 494 cells × 117 stims. Each epoch then takes ~80 s on a single GPU instead of well under a second here (132 epochs in 0.7 min), but the gradmap pipeline is identical. This makes it possible to compare receptive-field structure across MLd / Field L / OV / CM at the population level.

  • Sustained-loss gradmaps. The current STRF_gradmap uses the last-timestep activity as the loss target, a view analogous to a spike-triggered average. A complementary diagnostic is the gradient of the time-averaged activity, which highlights features that drive sustained rather than transient responses.

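  • Cross-model comparison. Run the same gradmap extraction on a Linear / LinearNonlinear model fit to the same split, and overlay the explicit STRF kernel on the gradient map. For a purely linear model the two should coincide up to sign convention and time-axis orientation; for an LN model (linear filter followed by a static nonlinearity) the gradmap is the kernel scaled by the nonlinearity’s slope at the null input. The remaining difference between ConvNet2D’s gradmap and an LN model’s STRF then largely reflects the nonlinear contribution, although the two models are fit separately, so fitting differences contribute too.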
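The sustained-loss idea from the list above can be previewed on a toy single-band linear causal filter, where both gradients are exact and do not depend on the input. This is a hypothetical NumPy sketch (the kernel w and causal_filter_matrix are made up for illustration, not deepSTRF code), and the gradients are shown excitatory-positive, i.e. as \(\partial y / \partial \mathbf{x}\) rather than the notebook’s \(\mathbf{g}_n\).

```python
import numpy as np


def causal_filter_matrix(w, T):
    """Matrix A with y = A @ x for y[t] = sum_k w[k] * x[t - k] (zero-padded)."""
    A = np.zeros((T, T))
    for t in range(T):
        for k, wk in enumerate(w):
            if t - k >= 0:
                A[t, t - k] = wk
    return A


w = np.array([1.0, -0.5, 0.25])   # toy 3-tap kernel (hypothetical)
T = 6
A = causal_filter_matrix(w, T)

grad_last = A[-1]                 # d y[T-1] / d x: last-timestep target
grad_mean = A.mean(axis=0)        # d mean_t y[t] / d x: time-averaged target
print(grad_last)
print(grad_mean)
```

The last-timestep gradient is the kernel itself, time-reversed and placed at the end of the window. Away from the window’s end, each entry of the time-averaged gradient equals sum(w) / T (0.75 / 6 = 0.125 here), so a balanced biphasic kernel would contribute almost nothing there: the averaged target emphasises net, sustained drive over transient structure. For a nonlinear model such as ConvNet2D, both gradients also depend on the input at which they are evaluated.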