The deepSTRF data paradigm
This note documents how deepSTRF represents stimulus/response data internally,
why it is done that way, and what invariants any user, contributor, or model
author must respect. It is the single reference for the “data contract” of the
library and should stay in sync with the NeuralDataset / AudioNeuralDataset
/ VideoNeuralDataset base classes.
1. Problem statement
Sensory neurophysiology datasets are triply ragged:
- Stimulus duration varies: audio clips of 1–5 s, video clips of different lengths. T_s is per-stimulus.
- Repeat count varies per (stim, neuron) pair: different neurons were recorded under different numbers of trials. R_{s,n} is per-pair.
- Stim × neuron coverage is sparse: not every neuron heard every stimulus (recording sessions differ, cohorts differ, some bonus datasets concatenate disjoint populations).
We need a single storage scheme that handles all three, supports both single-neuron and population training, and degrades gracefully when users forget an invariant.
2. Storage at the dataset level
A concrete NeuralDataset subclass populates six attributes:
| Attribute | Type | Shape / structure |
|---|---|---|
| | | Each element is a stimulus tensor of modality-specific shape. Audio spectrogram: |
| | | |
| | | Per-stim metadata dict: e.g. |
| | | Per-neuron metadata dict: e.g. |
| | | Total neurons; equals |
| | | Derived |
self.dt (time-bin width in ms) and self.path (data location) are set by
the base class constructor from the dt_ms and path arguments.
3. Encoding missingness
Missingness has three distinct sources. deepSTRF encodes them in a single channel (NaN in the response tensor) to keep one source of truth, with a derived boolean mask exposed for ergonomics.
3.1 Structural: neuron n never saw stim s
Convention: responses[s][n] is a (R=1, T=1) all-NaN tensor.
The nrn_masks property on the dataset derives this on the fly — it
sets nrn_masks[s, n] = False iff responses[s][n].isnan().any(). So
at any time:
dataset.nrn_masks[s, n] # True iff neuron n has real data for stim s
Subclasses must produce the (1, 1) NaN sentinel tensor for structurally
missing entries, not skip the entry or leave it unset — the mask is
derived from responses, so failing to record the sentinel loses the
missingness information entirely.
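The derivation rule above can be sketched on toy data. This is a minimal illustration, not deepSTRF's implementation: NumPy arrays stand in for torch tensors, and derive_nrn_masks and NAN_SENTINEL are hypothetical names.

```python
import numpy as np

# Structural "never recorded" marker: a (1, 1) all-NaN array.
NAN_SENTINEL = np.full((1, 1), np.nan)

# Hypothetical toy data: 2 stims x 3 neurons; responses[s][n] has shape (R, T).
responses = [
    [np.ones((2, 4)), NAN_SENTINEL, np.zeros((3, 4))],  # stim 0: neuron 1 missing
    [NAN_SENTINEL, np.ones((1, 6)), np.zeros((2, 6))],  # stim 1: neuron 0 missing
]

def derive_nrn_masks(responses):
    """Return an (S, N) bool array: True iff responses[s][n] contains no NaN."""
    return np.array([[not np.isnan(r).any() for r in per_stim]
                     for per_stim in responses])

nrn_masks = derive_nrn_masks(responses)
print(nrn_masks)
```

Because the mask is recomputed from the responses, an entry that was skipped instead of filled with the sentinel has nothing to derive from, which is why the sentinel is mandatory.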
3.2 Temporal padding: T_s varies across stims within a batch
Introduced at batch time by the collate function. Stims are zero-padded
on the right; responses are NaN-padded on the right. The batched output
shapes become (B, ..., T_stim_max) for stims and (B, N, R_max, T_resp_max)
for responses. The two T_* axes are sized independently — equal in
spectrogram mode (one bin per neural sample), unequal in the raw-waveform
mode where the stim runs at audio_fs and responses stay at the dataset’s
neural rate.
3.3 Repeat padding: R_{s,n} varies across neurons within a stim
Also introduced at batch time by the collate function. Responses are NaN-padded
along the repeat dimension up to R_max.
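Both kinds of batch-time padding can be sketched together. The function below is an illustration under stated assumptions, not the shipped neural_collate: pad_batch is a hypothetical name, NumPy stands in for torch, and stims are assumed to be (F, T_s) spectrograms.

```python
import numpy as np

def pad_batch(stims, responses):
    """Zero-pad stims and NaN-pad responses on the right (illustrative only).

    stims:     list of B arrays, each (F, T_s)
    responses: list of B lists of N arrays, each (R_sn, T_s) or a (1, 1) NaN sentinel
    """
    B, N = len(stims), len(responses[0])
    F = stims[0].shape[0]
    T_stim_max = max(s.shape[-1] for s in stims)
    R_max = max(r.shape[0] for per_stim in responses for r in per_stim)
    T_resp_max = max(r.shape[1] for per_stim in responses for r in per_stim)

    stim_batch = np.zeros((B, F, T_stim_max))                # zeros = stim padding
    resp_batch = np.full((B, N, R_max, T_resp_max), np.nan)  # NaN = response padding
    for b in range(B):
        stim_batch[b, :, :stims[b].shape[-1]] = stims[b]
        for n, r in enumerate(responses[b]):
            resp_batch[b, n, :r.shape[0], :r.shape[1]] = r
    return stim_batch, resp_batch

stims = [np.ones((3, 4)), np.ones((3, 6))]
responses = [
    [np.ones((2, 4)), np.full((1, 1), np.nan)],  # neuron 1 never heard stim 0
    [np.ones((1, 6)), np.ones((3, 6))],
]
stim_batch, resp_batch = pad_batch(stims, responses)
print(stim_batch.shape, resp_batch.shape)
```

Starting from an all-NaN response buffer means temporal padding, repeat padding, and the structural sentinel all end up in the same encoding without separate code paths.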
3.4 Raw-waveform inputs (audio only)
Audio datasets may expose an opt-in waveform-input mode (e.g.
NS1Dataset(return_waveform=True, audio_fs=48000)). The spectrogram transform
that would otherwise be baked into the dataset moves into the model, as the
first slot of the canonical pipeline (wav2spec → prefiltering → core → readout), so it can be learned (SincNet, ICNet) rather than fixed.
In this mode:
- self.stims[s] is a (1, T_audio) mono float32 tensor at self.audio_fs Hz.
- self.responses[s][n] is unchanged: still (R, T_neural) at the dataset’s dt_ms.
- The model owns the rate change: pair the waveform stim with a model whose wav2spec slot is a non-Identity() module (see deepSTRF.models.wav2spec). The slot maps (B, 1, T_audio) → (B, 1, F, T_neural) and the rest of the pipeline runs exactly as in spectrogram mode.
Strict causality factors into three composable pieces. deepSTRF’s hard
requirement — model output at response bin t depends only on stimulus up to
wall-clock time (t+1)·dt — is preserved because:
1. Grid lock (dataset). hop = audio_fs · dt_ms / 1000 is an integer and T_audio = T_neural · hop, pinning audio sample j to response bin j // hop.
2. Per-frame causality (wav2spec). Output frame t sees only audio samples [0, (t+1)·hop), enforced by the Jacobian probe in tests/test_wav2spec.py.
3. Downstream causality (rest of the pipeline). Once wav2spec emits a frame-aligned causal spectrogram at the neural rate, the existing spec-domain contract (tests/test_audio_models.py) takes over unchanged.
Compose (1) and (2): frame t sees exactly the audio of bins 0…t and nothing
later. The wav2spec is effectively a causal resampler from audio rate to neural
rate, and hop is the bridge between the two grids.
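The grid arithmetic behind this argument fits in a few lines. The helper names below are hypothetical, and dt_ms=5 is an assumed bin width chosen only to give round numbers; audio_fs=48000 is the value used in the example above.

```python
def hop_samples(audio_fs, dt_ms):
    """Audio samples per response bin; raises unless the grid lock holds."""
    hop, remainder = divmod(audio_fs * dt_ms, 1000)
    if hop <= 0 or remainder != 0:
        raise ValueError(
            f"audio_fs * dt_ms / 1000 = {audio_fs * dt_ms / 1000} "
            "is not a positive integer"
        )
    return int(hop)

def bin_of_sample(j, hop):
    """Response bin that audio sample j falls into."""
    return j // hop

def causal_window(t, hop):
    """Audio samples that output frame t is allowed to depend on."""
    return range(0, (t + 1) * hop)

hop = hop_samples(audio_fs=48000, dt_ms=5)   # dt_ms=5 is a hypothetical bin width
window = causal_window(2, hop)
print(hop, len(window), bin_of_sample(window[-1], hop))
```

The last sample of the window for frame t always lands in bin t, which is the statement that frame t sees the audio of bins 0…t and nothing later.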
Conventions for a dataset’s waveform branch:
- C1 (grid lock). audio_fs · dt_ms / 1000 must be a positive integer, and each waveform must be cropped / padded to exactly T_resp_s · hop samples. Enforced for every audio dataset by AudioNeuralDataset.validate() (and surfaced as self.hop).
- C2 (right-pad). Variable-length waveforms are zero-padded on the end (append), mirroring the response NaN-padding, so bin-0 alignment is preserved. neural_collate does this automatically: one unified collate serves both modes (stims zero-padded, responses NaN-padded), so users never swap collate.
- C3 (mono). Waveforms are (1, T_audio), downmixed to mono.
- C4 (opt-in). return_waveform=True is implemented only where genuine source audio exists. Synthetic-spectrogram-only stimuli (with no underlying waveform) stay spec-mode; a dataset with no source audio raises NotImplementedError.
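A dataset author implementing conventions C1 to C3 might prepare each clip roughly as follows. This is a sketch only: fit_to_grid is a hypothetical helper, NumPy stands in for torch, and the clip lengths and hop are made-up values.

```python
import numpy as np

def fit_to_grid(waveform, n_resp_bins, hop):
    """Downmix to mono, then crop or right-pad to n_resp_bins * hop samples."""
    mono = waveform.mean(axis=0, keepdims=True)    # (C, T) -> (1, T)
    target = n_resp_bins * hop
    out = np.zeros((1, target), dtype=np.float32)  # zeros are the right-padding
    n = min(mono.shape[-1], target)
    out[:, :n] = mono[:, :n]                       # data first, padding appended
    return out

# A stereo clip that is too short and a mono clip that is too long.
short_clip = fit_to_grid(np.ones((2, 1000)), n_resp_bins=5, hop=240)
long_clip = fit_to_grid(np.ones((1, 1500)), n_resp_bins=5, hop=240)
print(short_clip.shape, long_clip.shape)
```

Padding at the end rather than the start keeps sample 0 aligned with response bin 0, which is the point of C2.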
Matching a wav2spec to its dataset. A front-end’s audio_fs and hop must
agree with the dataset’s, or frames won’t align with response bins. The simple,
explicit path is to read them off the dataset:
w = SincNet(audio_fs=ds.audio_fs, hop_ms=ds.dt, n_filters=ds.F)
# or: w = make_wav2spec("sincnet", audio_fs=ds.audio_fs, dt_ms=ds.dt, ...)
A gross mismatch (non-hop-divisible T_audio) raises in the wav2spec’s
forward; a subtler one (matching audio_fs, wrong hop) surfaces as a
prediction-vs-response shape error at the loss. There is no dataset↔model
auto-binding by design — the model holds no dataset reference, and the explicit
path above keeps things simple.
hearing_range_hz (informational). Audio datasets may set an optional (low, high) Hz bound on the species’ canonical hearing range (e.g. (200.0, 40000.0) for ferret). Purely advisory: nothing is enforced against it; tooling can display it and users may choose to clamp a wav2spec’s frequency limits. None when unknown.
Datasets without a waveform branch leave self.audio_fs = None (and self.hop is then None).
4. Why NaN-as-sentinel
This is a deliberate design choice, weighed against the explicit-mask alternative used by HuggingFace / fairseq.
Chosen properties:
Single source of truth. The mask is a derived view of the data; it cannot go out of sync.
Loud failure mode. If downstream code forgets to respect missingness, NaN propagates through the loss and gradients, so training fails visibly on the first step. The alternative (multiply-by-mask) silently accepts zero-contamination on misuse.
No dtype overhead. Spike counts are already float, so NaN fits natively.
Accepted costs:
- Any response-side preprocessing (smoothing, normalization, etc.) must be NaN-aware. Use nanmean / nanstd / nanmax, or apply the mask before reducing. Specifically: NeuralDataset.smooth_responses() and normalize_responses() (currently unimplemented) must handle NaN when they are filled in.
- Float-only; int spike-count tensors cannot hold NaN. Responses are stored as float throughout.
- NaN vs numerical instability is ambiguous when debugging. When a loss is NaN, check the mask path first.
Invariants this does not compromise:
Stim tensors are the model’s input. They are zero-padded, never NaN-padded. This means BatchNorm, LayerNorm, softmax, and all standard model-side reductions operate on clean data. The only caveat is the small statistical bias from including zero-padded regions in normalization stats (see §9).
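The trade-off weighed in this section shows up on a toy trace. The numbers are made up and NumPy stands in for torch; the point is only the contrast between the three reductions.

```python
import numpy as np

# One neuron's trace: three recorded bins followed by two bins of padding.
recorded = np.array([2.0, 4.0, 6.0])
nan_padded = np.concatenate([recorded, [np.nan, np.nan]])
zero_padded = np.concatenate([recorded, [0.0, 0.0]])

naive_on_nan = nan_padded.mean()     # NaN: forgetting the mask is impossible to miss
naive_on_zero = zero_padded.mean()   # 2.4: looks plausible, but padding leaked in
nan_aware = np.nanmean(nan_padded)   # 4.0: the intended statistic

print(naive_on_nan, naive_on_zero, nan_aware)
```

With zero-filled padding the wrong answer is a perfectly ordinary number, which is the silent failure the NaN sentinel is meant to rule out.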
5. Batching: what the collate produces
Use neural_collate from deepSTRF.utils.data with a PyTorch DataLoader:
from torch.utils.data import DataLoader
from deepSTRF.utils.data import neural_collate
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=neural_collate)
One yielded batch is a dict with four keys (indexed by name in the §6 training loop):
| Name | Shape | Contents |
|---|---|---|
| | | Float tensor. Zero-padded on the right along |
| | | Float tensor. NaN-padded on the right along |
| | | |
| | | Per-stim metadata dicts, same as stored in |
If you need the coarser “did this neuron hear this stim” per batch-item
mask, recover it as valid_mask.any(dim=(-1, -2)) — a (B, N) bool
tensor.
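The reduction can be checked on a small hand-built batch. NumPy stands in for torch here (axis= plays the role of dim=), and deriving valid_mask as the complement of isnan is an assumption about how the mask relates to the responses.

```python
import numpy as np

# Hypothetical NaN-padded batch of shape (B=2, N=2, R_max=2, T_max=3).
responses = np.full((2, 2, 2, 3), np.nan)
responses[0, 0] = 1.0            # stim 0, neuron 0: fully recorded
responses[1, 1, :1, :2] = 1.0    # stim 1, neuron 1: one repeat, two bins

valid_mask = ~np.isnan(responses)         # element-wise validity
heard = valid_mask.any(axis=(-1, -2))     # (B, N): any real sample at all?
print(heard)
```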
6. Recommended training-loop loss pattern
NaN handling at the loss / metric site is the responsibility of
deepSTRF.metrics — every shipped function in that module is
NaN-aware by default. The training loop therefore stays simple:
from deepSTRF.metrics import mse_loss, corrcoef, normalized_corrcoef
# Each batch is a dict with keys 'stims', 'responses', 'valid_mask', 'stim_meta'.
for batch in loader:
    stims, responses = batch['stims'], batch['responses']
    optimizer.zero_grad()                              # clear gradients from the previous step
    pred = model(stims)                                # (B, N, 1, T_max)
    gt_psth = responses.nanmean(dim=2, keepdim=True)   # (B, N, 1, T_max)
    loss = mse_loss(pred, gt_psth)                     # scalar; handles NaN internally
    loss.backward()
    optimizer.step()
    if val_step:
        # validation metrics, left unreduced
        cc = corrcoef(pred, gt_psth, reduction='none')
        cc_norm = normalized_corrcoef(pred, responses, method='schoppe',
                                      reduction='none')
Notes:
- responses.nanmean(dim=2, keepdim=True) is the canonical PSTH: it carries NaN where all repeats are NaN and has the same shape as pred.
- The dataloader’s valid_mask is informational; redundant with ~gt_psth.isnan() for prediction-vs-PSTH metrics.
- For non-default masking (e.g. excluding stimulus onsets), every metric accepts a mask= override.
Detailed treatment (per-neuron flattening, mask= override
semantics, length-weighting across stims, single-trial degenerate
handling, eval-only metrics): see
metrics_paradigm.md.
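For readers who want to see what "NaN-aware" amounts to at the loss site, here is a minimal sketch. It is not the deepSTRF.metrics implementation: nan_aware_mse is a hypothetical function, NumPy stands in for torch, and the way the mask argument combines with NaN positions is an assumption.

```python
import numpy as np

def nan_aware_mse(pred, target, mask=None):
    """Mean squared error over positions where target is not NaN.

    mask: optional bool array (same shape as target) that further restricts
          the positions entering the average.
    """
    valid = ~np.isnan(target)
    if mask is not None:
        valid = valid & mask
    if not valid.any():
        return float("nan")          # nothing to score
    diff = pred[valid] - target[valid]
    return float(np.mean(diff ** 2))

pred = np.array([[1.0, 2.0, 3.0, 4.0]])
target = np.array([[1.0, 0.0, np.nan, np.nan]])          # last two bins are padding
skip_onset = np.array([[False, True, True, True]])       # hypothetical onset exclusion

loss_all = nan_aware_mse(pred, target)
loss_no_onset = nan_aware_mse(pred, target, mask=skip_onset)
print(loss_all, loss_no_onset)
```

Selecting valid positions before subtracting keeps NaN out of the arithmetic entirely, rather than computing on NaN and zeroing the result afterwards.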
7. Invariants for developers
- Stim tensors never contain NaN. Zero-padding only.
- Response tensors may contain NaN. Always respect valid_mask (at the loss) or nrn_masks (at the dataset) before reducing.
- Response-side preprocessing must be NaN-aware. This covers smooth_responses, normalize_responses, and any user-written response transform.
- Never feed responses to the model. They are targets, not inputs.
- Models emit predictions for every batched position, including zero-padded stim regions and unrecorded neuron-stim pairs. The loss, not the model, handles masking.
- Subclasses of NeuralDataset must call self.validate() as the last line of __init__ (the old self.compute_nrn_masks() call is no longer needed: nrn_masks is a @property derived from responses).
- __len__ and __getitem__ honour the current selection, bidirectionally. See §8 below; the contract deserves its own section.
- Waveform mode obeys the grid lock. When self.audio_fs is set, every stim is (1, T_audio) with T_audio = T_resp_s · hop and integer hop = audio_fs · dt_ms / 1000 (§3.4, conventions C1–C4). Enforced by AudioNeuralDataset.validate().
8. Iteration honours the current selection (bidirectional)
deepSTRF datasets carry two independent selection filters:
| Attribute | Type | “No restriction” sentinel | Default |
|---|---|---|---|
| | | | |
| | | | |
The two filters compose. __len__ and __getitem__ always agree on the
iterable subset of stimuli, defined as the intersection of:
- the stimuli with at least one valid response among the currently selected neurons (self.I-side filter), and
- the stimuli explicitly listed in self.S_sel, if it is set (self.S_sel-side filter).
The same intersection drives the neuron side: when self.S_sel is set,
neurons whose only valid responses lie outside the selected stim subset
are also hidden from the yielded responses. So selecting NAT4’s val
subset (18 stims) automatically drops the 33 A1 cells that have no val
data — they would otherwise surface as full-NaN responses.
This is what we mean by “bidirectional”: narrowing one side narrows the
other side too, when missingness in nrn_masks makes that the
information-preserving choice.
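The bidirectional rule can be written down as a small pure function over the coverage matrix. This is a sketch of the rule as described above, not the library's code: resolve_selection is a hypothetical name and the toy mask is invented.

```python
def resolve_selection(nrn_masks, I, S_sel):
    """Return (iterable_stims, visible_neurons) under both filters.

    nrn_masks: S x N nested list of bools (True = neuron has data for stim)
    I:         selected neuron indices; [] means "no restriction"
    S_sel:     selected stim indices; None means "no restriction", [] means none
    """
    n_stims, n_neurons = len(nrn_masks), len(nrn_masks[0])
    neurons = list(I) if I else list(range(n_neurons))
    candidates = range(n_stims) if S_sel is None else sorted(set(S_sel))
    # Stim side: keep stims with at least one valid response among the neurons.
    stims = [s for s in candidates if any(nrn_masks[s][n] for n in neurons)]
    if S_sel is None:
        return stims, neurons
    # Neuron side: with a stim filter set, hide neurons with no data inside it.
    visible = [n for n in neurons if any(nrn_masks[s][n] for s in stims)]
    return stims, visible

masks = [
    [True,  True,  False],   # stim 0
    [True,  False, False],   # stim 1
    [False, False, True],    # stim 2
    [False, True,  True],    # stim 3
]
print(resolve_selection(masks, I=[0], S_sel=None))    # neuron filter narrows stims
print(resolve_selection(masks, I=[], S_sel=[2, 3]))   # stim filter narrows neurons
```

The length of the first returned list is what __len__ would report under this reading, and indexing into it is what __getitem__ would do.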
Selection API
# neuron-side — exact-match (one of the following)
ds.select_neuron(i) # single
ds.select_population([0, 1, 5]) # explicit list
ds.select_pop_by_nrn_attr("area", "MLd") # by neuron metadata
ds.select_pop_by_stim_attr("subset", "val") # neurons with >=1 valid resp on val stims
# neuron-side — predicate (threshold, range, compound)
ds.select_pop_by_nrn_predicate(lambda n: n.get("snr", 0) > 0.5)
ds.select_pop_by_stim_predicate(lambda s: s["duration_s"] >= 2.0)
# stim-side (one of the following)
ds.select_stim(i) # single
ds.select_stims([0, 5, 10]) # explicit list
ds.select_stims_by_attr("subset", "val") # by stim metadata
ds.select_stims_by_predicate(lambda s: s["type"] in {"song", "call"})
ds.reset_stim_selection() # clear S_sel (back to None)
The *_predicate variants take any callable(dict) -> bool, so threshold
(snr > 0.5), range (200 <= depth_um <= 800), and compound conditions
are expressible. Forgiving missing-key semantics match the *_attr
siblings (KeyError / TypeError silently skip), so a single predicate
works on a concatenated dataset whose sources carry heterogeneous metadata
schemas.
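The forgiving semantics described here amount to a try/except around the predicate call. The sketch below assumes that reading; indices_matching is a hypothetical helper and the metadata values are invented.

```python
def indices_matching(meta_list, predicate):
    """Indices whose metadata dict satisfies predicate.

    Entries for which the predicate raises KeyError or TypeError
    (missing or incomparable fields) are skipped rather than failing.
    """
    keep = []
    for i, meta in enumerate(meta_list):
        try:
            if predicate(meta):
                keep.append(i)
        except (KeyError, TypeError):
            continue
    return keep

# Heterogeneous schemas, as in a concatenated dataset.
nrn_meta = [
    {"snr": 0.7, "depth_um": 300},
    {"snr": 0.2},
    {"area": "MLd"},                  # no 'snr' key at all
    {"snr": None, "depth_um": 500},   # present but not comparable
]

by_snr = indices_matching(nrn_meta, lambda n: n["snr"] > 0.5)
compound = indices_matching(
    nrn_meta, lambda n: n["snr"] > 0.1 and 200 <= n["depth_um"] <= 800
)
print(by_snr, compound)
```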
Opt-in per-neuron quality metrics
Call ds.compute_neuron_quality() once after construction to write two
scalars into each nrn_meta[i]:
ds.compute_neuron_quality()
ds.nrn_meta[0] # → {..., "snr": 0.52, "ccmax": 0.91}
ds.select_pop_by_nrn_predicate(lambda n: n["snr"] > 0.5)
- 'snr': Sahani-Linden signal-to-noise ratio \(\text{SP}_n / \text{NP}_n\), length-weighted across stims (weight = number of valid time bins per stim, matching the convention in metrics_paradigm.md §11). NaN when the neuron has no stim with \(R_{b,n} \ge 2\) and \(T_b \ge 2\).
- 'ccmax': Hsu/Spearman-Brown noise ceiling, length-weighted across stims with \(R_{b,n} \ge 2\). Falls back to 1.0 when the neuron has zero such stims (R=1 everywhere, so no normalization is possible and cc_norm = cc_raw).
Opt-in (not auto-called in __init__) because CCmax is
\(O(S \cdot N \cdot R^2 \cdot \text{max\_iters})\) in the worst case — on
big datasets (AA4, NAT4, Espejo NAT) this runs in tens of seconds.
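For orientation, the Sahani-Linden signal and noise power estimates for a single (stim, neuron) pair can be sketched as follows. This uses the estimator as it is commonly stated and covers one stim only; it does not reproduce the length-weighting across stims, and signal_noise_power is a hypothetical name rather than the routine compute_neuron_quality() calls.

```python
import numpy as np

def signal_noise_power(trials):
    """Sahani-Linden power estimates for one (stim, neuron) pair.

    trials: (R, T) array of single-trial responses with R >= 2 and no NaN.
    Power means variance over time (population variance, ddof=0).
    """
    R = trials.shape[0]
    power_of_mean = trials.mean(axis=0).var()   # power of the trial-averaged response
    mean_of_power = trials.var(axis=1).mean()   # average power of single trials
    signal = (R * power_of_mean - mean_of_power) / (R - 1)
    noise = mean_of_power - signal
    return signal, noise

trials = np.array([[0.0, 4.0, 0.0, 4.0],
                   [0.0, 2.0, 0.0, 2.0]])
sp, npow = signal_noise_power(trials)
print(sp, npow, sp / npow)
```

The R - 1 denominator is why at least two repeats are required before the ratio is defined.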
Why two selectors are sometimes both needed:
- select_pop_by_stim_attr("subset", "val") keeps the neurons with val data, but leaves all stims iterable (you might want est responses for those neurons too).
- select_stims_by_attr("subset", "val") keeps the val stims, and also drops the cells that lack val data via the bidirectional rule.
So the “I want to train on val only, with val-having cells only” recipe
is one call to select_stims_by_attr("subset", "val") — the bidirectional
rule does the other half.
Why this matters
- DataLoader compatibility. PyTorch iterates range(len(ds)) and calls ds[i] for each. The two have to agree, or the loader requests indices that map to fully-NaN items. The filter is what makes them agree.
- Concatenated datasets are safe. When you concatenate sources A and B and then select only A’s neurons, B’s stimuli are automatically hidden: their responses against A’s neurons are all NaN, so they don’t survive the filter. No special-case logic in the training loop.
- Selecting on stim attributes is implicit too. Selecting only the neurons that heard a particular stim type narrows iteration to those stims, by transitivity. This subsumes most one-off “iterate only the stims that match X” needs.
Concrete consequences
ds = concat_neural_datasets([aa1, aa2]) # 30 + 117 stims, 100 + 494 neurons
len(ds) # 147 — full population by default
ds.select_population(list(range(100))) # only AA1's neurons
len(ds) # 30 — AA2's stims now hidden
ds[0] # the FIRST iterable stim under selection
# i.e. AA1's stim 0; same as before selection
ds[29] # AA1's stim 29 — last iterable
ds[30] # IndexError, not a fully-NaN AA2 stim
ds.select_pop_by_nrn_attr("area", "MLd") # MLd neurons across A and B
len(ds) # however many stims any MLd neuron heard
# stim filter + bidirectional rule
nat4 = NAT4Dataset(area="A1") # 593 stims, 849 cells
nat4.select_stims_by_attr("subset", "val") # 18 val stims, 816 val-having cells
len(nat4) # 18; the 33 val-less cells are gone
S_sel = None vs S_sel = []
The two are intentionally distinct:
- S_sel = None means “no restriction”: iteration spans all stims. Returned by reset_stim_selection() and the default at construction.
- S_sel = [] means “explicit zero-stim selection”: iteration yields zero items. This is what select_stims_by_attr("foo", "bar") produces when no stim matches. The asymmetry with self.I (which uses [] for “no restriction”) is deliberate: selecting by an attribute that no stim carries should not silently disable the filter.
Raw access vs iteration
The stored attributes (self.stims, self.responses, self.stim_meta,
self.nrn_meta, self.nrn_masks) are not filtered. They keep
the dataset’s full structure regardless of self.I or self.S_sel. To
address a raw stim by its absolute index, read those attributes directly:
ds.stims[42] # raw stim 42, regardless of selection
ds.responses[42] # all neurons' responses to stim 42
ds.nrn_masks[42] # which neurons have data for stim 42
The selection filter applies only at the iteration interface (__len__,
__getitem__).
9. Gotchas
- Non-causal models on zero-padded stims. A bidirectional RNN, a Transformer without an attention mask, or a CNN with center-weighted temporal kernels will see “future silence” where the stim has been right-padded. The loss will ignore these positions on the output side (since valid_mask is False there), but the model’s internal representations at valid positions can still be affected by attending to / convolving over the padded zeros. For causal architectures, this is not an issue.
- Normalization bias from zero-padded stim regions. BatchNorm / LayerNorm over T include zero-padded steps in the stats. Acceptable when padding fraction is small (< ~20 %); otherwise use a masked normalization or bucket batches by length.
- GPU synchronization from boolean indexing. pred[valid] has a data-dependent output size, which triggers a small GPU sync. Imperceptible at deepSTRF scales (B ≤ 32, N ≤ 100, T ≤ 5000); worth knowing if scaling up.
- Spike-count dtype must be float. Storing responses as int would prevent NaN encoding.
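To judge whether a batch is under the roughly 20 % padding guideline, and what length bucketing buys, a quick estimate is enough. Both helpers are hypothetical and the stim lengths are invented.

```python
def padding_fraction(lengths):
    """Fraction of time steps in a right-padded batch that are padding."""
    t_max = max(lengths)
    return 1.0 - sum(lengths) / (len(lengths) * t_max)

def bucket_by_length(lengths, batch_size):
    """Group stim indices into batches of similar length (sort, then chunk)."""
    order = sorted(range(len(lengths)), key=lengths.__getitem__)
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

lengths = [200, 1000, 220, 950]                 # hypothetical T_s in bins
mixed = padding_fraction(lengths)               # all four stims in one batch
buckets = bucket_by_length(lengths, batch_size=2)
bucketed = [padding_fraction([lengths[i] for i in b]) for b in buckets]
print(round(mixed, 4), buckets, [round(f, 4) for f in bucketed])
```

In this toy case one mixed batch is about 41 % padding, while the two length-sorted batches stay under 5 % each. Strict sorting removes shuffling across lengths, so a real sampler would usually shuffle within and across buckets.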
10. Dataset concatenation
The NaN-sentinel convention enables a bonus feature: concatenating neural
datasets along both the stim and neuron axes is essentially free — the
cross-block (stim_A, neuron_B) entries reuse the same (1, 1) NaN tensor
used elsewhere for structural missingness, and nrn_masks reports the
block-diagonal coverage automatically.
from deepSTRF.utils.data import concat_neural_datasets
combined = concat_neural_datasets([aa1, aa2]) # or: aa1 + aa2
See dataset_concatenation.md for the full
feature page (rationale, compatibility rules, chimeric-model use case)
and examples/dataset_concatenation.ipynb
for a runnable demo.
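Why concatenation is nearly free under the sentinel convention can be seen in a toy version. This is an illustration, not concat_neural_datasets itself: concat_responses is a hypothetical helper and NumPy stands in for torch.

```python
import numpy as np

NAN_SENTINEL = np.full((1, 1), np.nan)   # shared marker for "never recorded"

def concat_responses(resp_a, resp_b):
    """Stack two response tables block-diagonally (illustrative only).

    resp_x[s][n] is an (R, T) array; cross-block entries get the NaN sentinel.
    """
    n_a, n_b = len(resp_a[0]), len(resp_b[0])
    top = [row + [NAN_SENTINEL] * n_b for row in resp_a]      # A's stims, B's neurons empty
    bottom = [[NAN_SENTINEL] * n_a + row for row in resp_b]   # B's stims, A's neurons empty
    return top + bottom

resp_a = [[np.ones((2, 3))]]                                  # 1 stim x 1 neuron
resp_b = [[np.ones((1, 4)), np.ones((1, 4))],
          [np.ones((1, 5)), np.ones((1, 5))]]                 # 2 stims x 2 neurons

combined = concat_responses(resp_a, resp_b)
coverage = [[not np.isnan(r).any() for r in row] for row in combined]
print(coverage)
```

The coverage matrix comes out block-diagonal without any bookkeeping beyond filling the cross-block entries with the sentinel.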
11. When the paradigm might need to evolve
- Migration to torch.nested tensors once the ecosystem matures: this would remove explicit padding, possibly with model-side support gaps.
- Length-bucketed batching for datasets with very heterogeneous T_s: an optimization, not a redesign.
- Per-repeat weighting (some repeats noisier than others): this would require either extending the mask from bool to float, or adding a separate weight tensor.
These are explicitly not on the current roadmap; flagged here so the discussion doesn’t have to be rediscovered.