Alice EEG with deepSTRF — modality generalisation tutorial

This notebook adapts deepSTRF — built for single-unit recordings of the auditory pathway — to a scalp EEG dataset, the Alice corpus released by Bhattasali et al. (2020) and preprocessed by Brodbeck et al. (2023, eLife) for their Eelbrain toolkit paper.

What this notebook is: a deepSTRF-on-EEG demonstration. It loads the Alice EEG corpus (33 subjects available; one is used here) into the standard deepSTRF layout (B, N, R=1, T), applies the canonical preprocessing pipeline, and fits two deepSTRF models end-to-end. The models produce per-channel predictions that correlate weakly but positively with held-out EEG segments.

What this notebook is not: a multi-subject group analysis. We fit a single subject on a single train/val/test split. The per-channel prediction accuracy we obtain (r ≈ 0.09) sits at the published single-subject envelope-TRF ceiling for this dataset, which §6 confirms by running Brodbeck’s own eelbrain.boosting pipeline on the same data and obtaining essentially the same number. Mind the unit in Brodbeck’s figures: “% variability explained” is 100·r², not Pearson r. Their Fig 4B colorbar maxes at 1 %, i.e. r ≈ 0.10, a tiny absolute accuracy that is normal for single-trial scalp-EEG envelope tracking (see §6).
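The unit conversion is easy to get wrong when comparing against the paper, so here it is as two helpers. This is plain Python; the helper names are ours and not part of deepSTRF or eelbrain.

```python
import math


def pct_var_from_r(r: float) -> float:
    """'% variability explained' in the 100 * r**2 convention."""
    return 100.0 * r ** 2


def r_from_pct_var(pct: float) -> float:
    """Inverse mapping, assuming a non-negative correlation."""
    if pct < 0:
        raise ValueError("percent variability explained must be >= 0")
    return math.sqrt(pct / 100.0)


# Fig 4B colorbar maximum (1 %) expressed as a Pearson r
print(round(r_from_pct_var(1.0), 3))   # 0.1
# the r of about 0.09 obtained in this notebook, in Brodbeck's unit
print(round(pct_var_from_r(0.09), 2))  # 0.81
```

Because the mapping is quadratic, small differences in r near 0.1 look even smaller in % variability explained.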

Steps

  1. Load Alice EEG with the canonical 0.5–20 Hz bandpass (matches the eelbrain analysis pipeline).

  2. Visualise stimulus and EEG alignment.

  3. Standardise both predictor and response (using train-set statistics).

  4. Fit a Linear STRF (TRF analog) and a StateNet GRU (recurrent, fewer per-channel parameters).

  5. Report per-channel test correlation and compare it with Brodbeck’s reference.

[ ]:
# Colab-friendly install. Skipped on local installations.
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip install -q 'deepSTRF[eeg] @ git+https://github.com/urancon/deepSTRF.git@develop'
[ ]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset

from deepSTRF.datasets.audio import AliceEEGDataset
from deepSTRF.models.audio import Linear, StateNet
from deepSTRF.utils.data import neural_collate
from deepSTRF.metrics import corrcoef, fve, mse_loss

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

1. Load the dataset

Single subject by default. download=True pulls Brodbeck’s ~2.5 GiB restructured release from the UMD DRUM repository into the platform cache; pass path=... instead if you already have the data.

Defaults set by the dataset class match Brodbeck’s analysis pipeline:

  • dt_ms=10 → 100 Hz analysis rate

  • hp_freq_hz=1.0 (override to 0.5 to match the eelbrain convert-all.py exactly)

  • lp_freq_hz=None (we set 20 Hz below for the cortical-tracking band)

  • n_frequency_bands=8 (ERB-band gammatone approximation)
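For intuition about what “ERB-band” spacing means, here is a sketch that places n bands equally on the Glasberg & Moore (1990) ERB-number scale. The lower edge fmin = 50 Hz is an assumption for illustration only (the dataset class’s actual band layout may differ); fmax = 8000 Hz mirrors the fmax argument passed in the next cell.

```python
import numpy as np


def hz_to_erb_number(f_hz):
    """Glasberg & Moore (1990) ERB-number scale (Cams)."""
    return 21.4 * np.log10(1.0 + 0.00437 * np.asarray(f_hz, dtype=float))


def erb_number_to_hz(erb):
    """Inverse of hz_to_erb_number."""
    return (10.0 ** (np.asarray(erb, dtype=float) / 21.4) - 1.0) / 0.00437


def erb_spaced_centres(fmin_hz, fmax_hz, n_bands):
    """n_bands centre frequencies equally spaced in ERB number, endpoints included."""
    e = np.linspace(hz_to_erb_number(fmin_hz), hz_to_erb_number(fmax_hz), n_bands)
    return erb_number_to_hz(e)


# fmin = 50 Hz is a hypothetical choice; fmax = 8000 Hz as in the dataset call
centres = erb_spaced_centres(50.0, 8000.0, 8)
print(np.round(centres))
```

The bands are narrow at low frequencies and wide at high ones, which is why 8 bands can still resolve the speech-relevant low-frequency structure.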

[ ]:
SUBJECT = "S20"

ds = AliceEEGDataset(
    download=True,           # set False + path=... if data is local
    subjects=[SUBJECT],
    dt_ms=10.0,
    n_frequency_bands=8,
    spec_backend='heeris',   # paper-faithful Heeris time-domain gammatone (needs deepSTRF[eeg])
    fmax=8000.0,             # speech-relevant band edge (drops inaudible-for-speech bands)
    hp_freq_hz=0.5,          # matches eelbrain convert-all.py exactly
    lp_freq_hz=20.0,         # cortical-tracking band; suppresses beta/gamma
)
print(f"S={len(ds.stims)} segments, N={ds.N_neurons} channels, F={ds.F} bands, dt={ds.dt} ms")
print(f"first segment: spectrogram {ds.stims[0].shape}, EEG {ds.responses[0][0].shape}")
total_min = sum(m['duration_s'] for m in ds.stim_meta) / 60
print(f"total audio: {total_min:.2f} min")

2. Visualise the stimulus / response (Brodbeck Fig 8 analog)

First 6 seconds of segment 1: the ERB-band log spectrogram, its summed envelope, and one EEG channel.

[ ]:
spec = ds.stims[0][0]                       # (F, T)
# first channel whose segment-1 response is not entirely NaN
ch_idx = next(
    i for i in range(len(ds.nrn_meta))
    if not ds.responses[0][i].isnan().all()
)
eeg = ds.responses[0][ch_idx][0]            # (T,)
envelope = spec.exp().sum(dim=0)            # broadband acoustic energy
T = min(600, spec.shape[-1])                # 6 s at 100 Hz
t_s = np.arange(T) * ds.dt / 1000

fig, axes = plt.subplots(3, 1, figsize=(9, 5), sharex=True)
axes[0].imshow(spec[:, :T], aspect='auto', origin='lower',
               extent=[0, t_s[-1], 0, ds.F], cmap='magma')
axes[0].set_ylabel('ERB band'); axes[0].set_title(f'Subject {SUBJECT}, segment 1')
axes[1].plot(t_s, envelope[:T]); axes[1].set_ylabel('Envelope')
axes[2].plot(t_s, eeg[:T]); axes[2].set_ylabel(f'EEG ch {ds.nrn_meta[ch_idx]["channel_id"]}')
axes[2].set_xlabel('Time (s)')
plt.tight_layout()
plt.show()

3. Standardise predictor and response, split train / val / test

deepSTRF ships two base-class helpers that handle the NaN-sentinel-aware normalisation:

  • standardize_stims(stim_indices, per_band=True) — per-band z-score using the given subset.

  • normalize_responses(method='zscore', stim_indices=...) — per-channel z-score for signed EEG targets.

Both compute their statistics on the train+val segments and apply them to every segment, so test segments are transformed with the same statistics (their post-transform mean and std therefore need not be exactly 0 and 1). This is the same convention Brodbeck’s boosting(scale_data=True) uses.
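To make the convention concrete, here is a NumPy sketch of NaN-aware per-band z-scoring: statistics come from the train+val segments only (segments 0–10 in the split below) and are then applied to every segment. The function names are hypothetical; this is an illustration of the idea, not the body of standardize_stims.

```python
import numpy as np


def fit_band_stats(segments, stat_indices):
    """Per-band mean/std over time for the chosen segments, ignoring NaNs.

    segments: list of arrays shaped (F, T_s); segment lengths may differ.
    """
    pooled = np.concatenate([segments[i] for i in stat_indices], axis=1)  # (F, sum of T_s)
    mean = np.nanmean(pooled, axis=1, keepdims=True)
    std = np.nanstd(pooled, axis=1, keepdims=True)
    std[std == 0] = 1.0  # guard against constant bands
    return mean, std


def apply_band_stats(segments, mean, std):
    """Transform every segment (train, val and test) with the same statistics."""
    return [(seg - mean) / std for seg in segments]


rng = np.random.default_rng(0)
segs = [rng.normal(3.0, 2.0, size=(8, 500 + 10 * s)) for s in range(12)]
segs[11][:, :50] = np.nan          # NaN sentinel, e.g. padding in the test segment
stats_idx = list(range(11))        # train + val = segments 0..10
mean, std = fit_band_stats(segs, stats_idx)
z = apply_band_stats(segs, mean, std)
```

After the transform the pooled train+val data have mean 0 and std 1 in every band; segment 11 is only approximately standardised, and its NaN sentinels stay NaN.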

We hold out segment 11 as test, segment 9 as val, and train on segments 0-8 + segment 10.

[ ]:
TRAIN_IDX = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10]
VAL_IDX   = [9]
TEST_IDX  = [11]
STATS_IDX = TRAIN_IDX + VAL_IDX

ds.standardize_stims(stim_indices=STATS_IDX, per_band=True)
ds.normalize_responses(method='zscore', stim_indices=STATS_IDX)

train_loader = DataLoader(Subset(ds, TRAIN_IDX), batch_size=1, shuffle=True, collate_fn=neural_collate)
val_loader   = DataLoader(Subset(ds, VAL_IDX),   batch_size=1, collate_fn=neural_collate)
test_loader  = DataLoader(Subset(ds, TEST_IDX),  batch_size=1, collate_fn=neural_collate)

4. Fit two models

  • Linear (deepSTRF’s Linear): a strictly causal STRF with a 1-second window (100 lags at 10 ms) over the 8-band Heeris gammatone spectrogram. It is the closest deepSTRF analog of a multivariate TRF, but without basis-function smoothing or boosting’s L1 sparsity.

  • StateNet GRU C=14: a recurrent backbone with a per-(subject, channel) linear readout. It has roughly 7× fewer parameters than Linear while matching it on test cc, which makes it the sample-efficient choice given the small amount of per-subject data.

Both use Identity output activation and MSE loss against the z-scored EEG target.
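A strictly causal STRF with a 1-second window is a linear map from a lagged copy of the spectrogram. The NumPy sketch below is not deepSTRF’s Linear implementation; it assumes one bias per output channel, and with F = 8, 100 lags and N = 61 it gives the parameter count such a model would have.

```python
import numpy as np


def lagged_design(stim, n_lags):
    """Causal lag matrix for a (F, T) spectrogram.

    Row t holds stim[:, t - k] for k = 0..n_lags-1, with zeros before the
    segment starts, so the prediction at time t sees only present and past.
    Returns an array of shape (T, F * n_lags).
    """
    F, T = stim.shape
    if n_lags > T:
        raise ValueError("window longer than the segment")
    X = np.zeros((T, F, n_lags))
    for k in range(n_lags):
        X[k:, :, k] = stim[:, :T - k].T
    return X.reshape(T, F * n_lags)


def predict_strf(stim, kernel, bias):
    """kernel: (F, n_lags, N); bias: (N,). Returns predictions of shape (T, N)."""
    F, n_lags, N = kernel.shape
    X = lagged_design(stim, n_lags)
    return X @ kernel.reshape(F * n_lags, N) + bias


F, LAGS, N = 8, 100, 61          # 8 bands, 1 s of lags at 10 ms, 61 channels
n_params = F * LAGS * N + N      # weights + one bias per channel (assumed)
print(n_params)                  # 48861

# Causality check: a single click in band 2 at t = 10 reproduces that band's kernel
rng = np.random.default_rng(0)
stim = np.zeros((F, 300))
stim[2, 10] = 1.0
kernel = rng.normal(size=(F, LAGS, 3))
pred = predict_strf(stim, kernel, np.zeros(3))
```

The prediction is exactly zero before the click and for lags beyond the 1-second window, which is what “strictly causal” with a finite window means.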

AdamW + weight decay 1e-3 + patience-based early stopping on validation correlation. The training schedule is intentionally long (max 500 epochs, patience 100) because the val cc rises slowly; aggressive early stopping under-fits.
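The stopping rule can be replayed offline on a list of validation scores. This pure-Python sketch mirrors the logic in fit_and_evaluate below; the helper name and the toy score sequence are hypothetical.

```python
def early_stopping_trace(val_scores, patience):
    """Replay patience-based early stopping on a sequence of validation scores.

    Higher is better (validation correlation). Returns
    (best_epoch, best_score, epochs_run).
    """
    best_epoch, best_score, waited = -1, float("-inf"), 0
    epochs_run = 0
    for epoch, score in enumerate(val_scores):
        epochs_run = epoch + 1
        if score > best_score:
            best_epoch, best_score, waited = epoch, score, 0
        else:
            waited += 1
            if waited >= patience:
                break
    return best_epoch, best_score, epochs_run


# toy validation-cc curve with a plateau followed by a late improvement
scores = [0.01, 0.03, 0.05, 0.04, 0.04, 0.06, 0.02, 0.02, 0.02]
print(early_stopping_trace(scores, patience=3))  # (5, 0.06, 9)
print(early_stopping_trace(scores, patience=2))  # (2, 0.05, 5)
```

With patience 3 the run survives the plateau and keeps the epoch-5 weights; with patience 2 it stops at epoch 4 and keeps the weaker epoch-2 weights, the under-fitting failure mode that motivates the long schedule.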

[ ]:
def fit_and_evaluate(model, max_epochs=500, patience=100, lr=1e-3, wd=1e-3):
    model = model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    best_val, best_state, waited = -float('inf'), None, 0
    history = {'train_loss': [], 'val_cc': []}
    for ep in range(max_epochs):
        model.train()
        ep_losses = []
        for batch in train_loader:
            stims, responses = batch['stims'].to(device), batch['responses'].to(device)
            pred = model(stims)
            loss = mse_loss(pred, responses)
            opt.zero_grad(); loss.backward(); opt.step()
            ep_losses.append(loss.item())
        model.eval()
        with torch.no_grad():
            v = []
            for batch in val_loader:
                stims, responses = batch['stims'].to(device), batch['responses'].to(device)
                pred = model(stims); gt = responses.nanmean(dim=2, keepdim=True)
                v.append(corrcoef(pred, gt, reduction='none').cpu())
            val_cc = torch.stack(v).nanmean().item()
        history['train_loss'].append(float(np.mean(ep_losses)))
        history['val_cc'].append(val_cc)
        if val_cc > best_val:
            best_val, best_state, waited = val_cc, {k: w.detach().clone() for k, w in model.state_dict().items()}, 0
        else:
            waited += 1
            if waited >= patience: break
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        cc, ve = [], []
        for batch in test_loader:
            stims, responses = batch['stims'].to(device), batch['responses'].to(device)
            pred = model(stims); gt = responses.nanmean(dim=2, keepdim=True)
            cc.append(corrcoef(pred, gt, reduction='none').cpu())
            ve.append(fve(pred, gt, reduction='none').cpu())
        cc = torch.stack(cc).nanmean(dim=0)
        ve = torch.stack(ve).nanmean(dim=0)
    return {'val_cc': best_val, 'cc': cc, 'fve': ve, 'history': history, 'epochs': ep + 1}

F, N = ds.F, ds.N_neurons

results = {}
for name, builder in [
    ('Linear (T=100)',
        lambda: Linear(F, temporal_window_size=100, out_neurons=N,
                       output_activation=nn.Identity())),
    ('StateNet GRU (C=14)',
        lambda: StateNet(F, temporal_window_size=1, kernel_size=5, stride=2,
                         hidden_channels=14, rnn_type='GRU', out_neurons=N,
                         output_activation=nn.Identity())),
]:
    print(f'fitting {name}…', flush=True)
    results[name] = fit_and_evaluate(builder())
    cc_mean = results[name]['cc'].nanmean().item()
    cc_max  = results[name]['cc'][~results[name]['cc'].isnan()].max().item()
    fve_mean = results[name]['fve'].nanmean().item()
    print(f'  {name}: val={results[name]["val_cc"]:+.3f}, test_cc mean={cc_mean:+.3f} '
          f'max={cc_max:+.3f}, test_fve mean={fve_mean:+.4f}, epochs={results[name]["epochs"]}')

5. Per-channel results

Two views:

  • A bar chart of mean and best-channel test cc per model, with Brodbeck’s published reference value plotted as a horizontal line.

  • The distribution of per-channel test cc across the 61 EEG channels, which shows which channels carry the predictable signal and which are noise-dominated.

The reference line marks Brodbeck’s single-subject envelope-TRF ceiling, r ≈ 0.10 (Fig 4B colorbar max of 1 % variability explained, i.e. r² = 0.01). Our per-channel test cc sits right at this line (see §6).
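Both panels show per-channel Pearson correlations between prediction and held-out EEG. For readers who want to check them outside deepSTRF, here is a plain NumPy sketch that masks NaN sentinels channel by channel. We assume, without having checked its source, that corrcoef(..., reduction='none') computes an equivalent quantity.

```python
import numpy as np


def per_channel_pearson(pred, target):
    """Pearson r per channel for arrays shaped (N, T).

    Time points where either array is NaN (e.g. sentinel padding) are dropped
    channel by channel; channels with fewer than 2 valid samples or zero
    variance get NaN.
    """
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    out = np.full(pred.shape[0], np.nan)
    for n in range(pred.shape[0]):
        ok = ~np.isnan(target[n]) & ~np.isnan(pred[n])
        if ok.sum() < 2:
            continue
        p = pred[n, ok] - pred[n, ok].mean()
        t = target[n, ok] - target[n, ok].mean()
        denom = np.sqrt((p ** 2).sum() * (t ** 2).sum())
        if denom > 0:
            out[n] = (p * t).sum() / denom
    return out


# Simulated weak tracking: expected r = 0.1 / sqrt(1.01), about 0.0995
rng = np.random.default_rng(0)
signal = rng.normal(size=(61, 6000))
eeg = 0.1 * signal + rng.normal(size=(61, 6000))
r = per_channel_pearson(signal, eeg)
print(np.nanmean(r), 100 * np.nanmean(r ** 2))  # mean r, mean % variability explained
```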

[ ]:
fig, axes = plt.subplots(1, 2, figsize=(11, 4))
x = np.arange(len(results))
labels = list(results.keys())
means = [r['cc'].nanmean().item() for r in results.values()]
maxes = [r['cc'][~r['cc'].isnan()].max().item() for r in results.values()]

axes[0].bar(x - 0.18, means, width=0.36, label='mean across channels', color='#4c72b0')
axes[0].bar(x + 0.18, maxes, width=0.36, label='best channel', color='#dd8452')
axes[0].axhline(0.10, ls='--', lw=1, color='black',
                label='Brodbeck ceiling r≈0.10 (Fig 4B = 1% var)')
axes[0].set_xticks(x); axes[0].set_xticklabels(labels, rotation=12, ha='right')
axes[0].set_ylabel('Test correlation'); axes[0].legend(fontsize=8)
axes[0].set_title(f'Subject {SUBJECT}, held-out segment')

for name, r in results.items():
    vals = r['cc'][~r['cc'].isnan()].numpy()
    axes[1].hist(vals, bins=20, alpha=0.6, label=name)
axes[1].axvline(0, ls=':', color='gray')
axes[1].set_xlabel('Per-channel test cc'); axes[1].set_ylabel('# channels')
axes[1].legend(fontsize=8); axes[1].set_title('Distribution across 61 EEG channels')
plt.tight_layout(); plt.show()

6. What we got: at the data ceiling, not below it

On a single subject, with two audio segments held out (one for validation, one for test), the paper-faithful Heeris spectrogram and the 0.5–20 Hz band, expect roughly:

  • Linear: test cc mean ≈ 0.08 (best channel ≈ 0.16).

  • StateNet GRU C=14: test cc mean ≈ 0.09 (best channel ≈ 0.17), with 7× fewer parameters.

These look tiny if you come from single-unit or ECoG work — but they are the published ceiling for this dataset. We checked this directly by running Brodbeck’s own pipeline (eelbrain.boosting, L1 + 50 ms Hamming basis, the exact method behind their Fig 4) on the same subject and the same train/val/test split:

Estimator                        mean test r   best-channel r   mean % var explained
deepSTRF Linear STRF             0.084         0.16             0.86 %
sklearn Ridge (α-grid)           0.084         0.16             0.97 %
deepSTRF StateNet GRU            0.086         0.17             0.74 %
eelbrain boosting (Brodbeck)     0.090         0.17             1.05 %

The four estimators agree to within Δr ≈ 0.006 (0.084 to 0.090). There is no regularisation gap and no pipeline gap: a mean per-channel r ≈ 0.09 is simply what single-trial scalp EEG supports. Brodbeck’s Fig 4B colorbar maxes at 1 % variability explained, which corresponds to this level (% var = 100·r², so 1 % is r ≈ 0.10).

Why such low numbers are normal — and what EEG people report instead

Single-trial scalp-EEG envelope-TRF accuracy is typically r ≈ 0.05–0.15 across the literature (Lalor & Foxe 2010; Ding & Simon 2012; Di Liberto 2015; Crosse 2016; Broderick 2018). The cortical envelope-tracking response is ~1 µV against 30–50 µV of background activity, so even r² = 1 % is a real, replicable signal. Values of r ≳ 0.3 are typically reached only with intracranial recordings or unit-level data.

Because absolute accuracy is low, the field’s deliverables are not the prediction r:

  1. The TRF/STRF kernel shape — interpretable P1/N1/P2 peaks at ~50/100/200 ms; estimated precisely even when prediction is poor. (deepSTRF: AudioEncodingModel.STRF_gradmap.)

  2. Predictive power as a significance test: with n=33, even 0.5 % variability explained reaches p < 0.001; the question is whether adding a predictor significantly increases predictive power (Brodbeck Fig 4C/4D).

  3. Nested-model comparison — “does a lexical/surprisal predictor explain variance beyond acoustics?” This is what TRFs are for.

  4. Backward (decoding) models: reconstruct the envelope from EEG (pools channels, r ≈ 0.1–0.3); attention decoding then reports classification accuracy (80–90 %), not r.
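Items 2 and 3 come down to a group-level test on per-subject differences. As a hedged illustration, here is a generic one-sided sign-flip permutation test on per-subject improvements (e.g. % variability explained by a full model minus a reduced one). This is not Brodbeck’s exact procedure, which relies on eelbrain’s own test functions; the function name is ours.

```python
import numpy as np


def sign_flip_test(deltas, n_perm=10000, seed=0):
    """One-sided sign-flip permutation test that mean(deltas) > 0.

    deltas: one improvement value per subject. Under the null hypothesis
    each subject's improvement is symmetric around 0, so its sign can be
    flipped at random without changing the distribution.
    """
    d = np.asarray(deltas, dtype=float)
    rng = np.random.default_rng(seed)
    observed = d.mean()
    signs = rng.choice([-1.0, 1.0], size=(n_perm, d.size))
    null = (signs * d).mean(axis=1)
    # +1 in numerator and denominator: the observed labelling counts as one permutation
    p = (np.sum(null >= observed) + 1) / (n_perm + 1)
    return observed, p


# 33 subjects that all improve by 0.5 percentage points
obs, p = sign_flip_test(np.full(33, 0.5))
print(obs, p)
```

With 33 consistently positive subjects the p-value is bounded only by the number of permutations, which is why a small but consistent gain can be highly significant.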

Concrete follow-ups (extend the analysis, not chase r)

  • Word-onset / surprisal predictors from stimuli/AliceChapterOne-EEG.csv → reproduces Brodbeck Fig 5–6 (the nested-model-comparison story; the main reason to use this dataset).

  • Hamming-basis STRF kernel (BasisKernel, eelbrain’s basis=0.050) → smoother, more interpretable TRF kernels (won’t move r — we’re at the ceiling).

  • Subject embeddings + shared StateNet backbone → true multi-subject pooling; the one route to a meaningfully higher number.

  • eelbrain.boosting wrapper as a deepSTRF Fitter → reusable apples-to-apples cross-check (the validation used here lives at untracked/alice_eeg_eelbrain_compare.py).
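The Hamming-basis idea can be illustrated without eelbrain: each sparse lag coefficient is spread over a 50 ms Hamming window, i.e. the coefficient sequence is convolved with the window. This NumPy sketch assumes the 10 ms analysis step used here (so the window spans 5 samples); eelbrain’s own basis handling (normalisation, edge treatment) may differ, and BasisKernel’s API is not shown.

```python
import numpy as np


def hamming_basis_kernel(coeffs, dt_s=0.010, basis_s=0.050):
    """Expand sparse lag coefficients into a smooth kernel with a Hamming basis.

    coeffs: (n_lags,) coefficients for one band/channel pair.
    Each coefficient is replaced by a Hamming window of width basis_s,
    centred on its lag (convolution with mode="same").
    """
    width = max(int(round(basis_s / dt_s)), 1)
    window = np.hamming(width)
    return np.convolve(coeffs, window, mode="same")


# One non-zero coefficient at lag 10 (100 ms) becomes a 50 ms bump
coeffs = np.zeros(100)
coeffs[10] = 1.0
kernel = hamming_basis_kernel(coeffs)
```

Sparse coefficients (as boosting produces) thus yield smooth kernels, which is why the basis mainly improves interpretability rather than prediction accuracy.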

Optional: subjects-as-repeats mode

deepSTRF exposes an alternate view: treat each subject as a repeat of a canonical EEG response per channel. This enables normalized_corrcoef(method='schoppe') for inter-subject reliability bounds. See docs/_source/md/README_Alice_EEG.md for the interpretive caveat.

ds_isc = AliceEEGDataset(download=True, treat_subjects_as='repeats')
# N = 61 channels, R = 33 subjects per (channel, segment)
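In the subjects-as-repeats view, a normalised correlation divides the raw prediction correlation by the ceiling implied by inter-repeat consistency. Below is a NumPy sketch of the CCnorm estimator of Schoppe et al. (2016), which uses the signal-power estimate of Sahani & Linden (2003); here the “repeats” would be the 33 subjects. We assume normalized_corrcoef(method='schoppe') computes this quantity but have not checked its source.

```python
import numpy as np


def cc_norm(repeats, pred):
    """Normalised correlation coefficient (Schoppe et al., 2016).

    repeats: (R, T) responses to the same stimulus (one row per subject here);
    pred:    (T,) model prediction.
    Returns NaN when the estimated signal power is not positive.
    """
    repeats = np.asarray(repeats, dtype=float)
    pred = np.asarray(pred, dtype=float)
    R = repeats.shape[0]
    # Signal power: variance shared across repeats (Sahani & Linden estimator)
    sp = (np.var(repeats.sum(axis=0)) - np.var(repeats, axis=1).sum()) / (R * (R - 1))
    if sp <= 0:
        return np.nan
    mean_resp = repeats.mean(axis=0)
    cov = np.mean((mean_resp - mean_resp.mean()) * (pred - pred.mean()))
    return cov / np.sqrt(np.var(pred) * sp)


# 33 "subjects" sharing one signal plus independent noise; the prediction is the signal
rng = np.random.default_rng(1)
signal = rng.normal(size=20000)
subjects = signal + 3.0 * rng.normal(size=(33, 20000))
print(np.corrcoef(subjects.mean(axis=0), signal)[0, 1], cc_norm(subjects, signal))
```

The raw correlation with the subject-averaged response stays below 1 because of residual noise, while CCnorm comes out close to 1, since the prediction captures all of the shared signal. The caveat from the README applies: treating subjects as repeats assumes they share one underlying response.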