deepSTRF.datasets package

Subpackages

Submodules

deepSTRF.datasets.neural_dataset module

class deepSTRF.datasets.neural_dataset.NeuralDataset(path: str, dt_ms: float)[source]

Bases: Dataset, ABC

General base class for datasets of sensory neural responses.

deepSTRF datasets are triply ragged (variable stim duration, variable repeat count per (stim, neuron), sparse stim/neuron coverage). They are stored as Python lists of tensors, with NaN used as the single channel for encoding missingness. See docs/_source/md/data_paradigm.md for the full rationale, collate behaviour, and recommended loss pattern.

Subclass contract

A concrete subclass must populate the following attributes in its __init__ and then call self.validate() as its last line:

  • self.stims: list of length S, each element a stimulus tensor of modality-specific shape (audio spectrogram: (1, F, T_stim), audio waveform: (1, T_stim), video: (1, H, W, T_stim)). T_stim may vary across stimuli and may differ from the response time axis when the stim sampling rate is finer than the neural rate (e.g. raw waveforms vs spike counts).

  • self.responses: list of length S, each element itself a list of length N. responses[s][n] is a (R_{s,n}, T_resp_s) float tensor of spike counts per repeat × time bin at the dataset’s neural dt_ms, or a (1, 1) NaN tensor if neuron n did not hear stim s. T_resp_s is the per-stim response length; in spectrogram mode it equals the last axis of self.stims[s].

  • self.stim_meta: list of length S, per-stim metadata dicts.

  • self.nrn_meta: list of length N, per-neuron metadata dicts.

  • self.N_neurons: int, must equal len(self.nrn_meta).

Derived attributes (no explicit population needed):

  • self.nrn_masks: (S, N) bool tensor, derived from the NaN sentinels in self.responses. nrn_masks[s, n] is True iff neuron n has real data for stim s. Implemented as a @property so that it stays consistent with the current self.responses instead of being populated by hand.
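The layout and the sentinel convention described above can be sketched without any tensor library. This is a minimal illustration, assuming nested Python lists stand in for tensors; the helper names is_sentinel and derive_nrn_masks are hypothetical and not part of deepSTRF.

```python
import math

NAN = float("nan")

# Hypothetical toy dataset: S = 2 stimuli, N = 3 neurons.
# responses[s][n] is a list of repeats, each a list of spike counts per bin.
# A (1, 1) NaN entry marks "neuron n did not hear stim s".
SENTINEL = [[NAN]]

responses = [
    # stim 0: T_resp = 4 bins
    [
        [[0.0, 1.0, 0.0, 2.0], [1.0, 1.0, 0.0, 1.0]],  # neuron 0, 2 repeats
        SENTINEL,                                       # neuron 1, no data
        [[0.0, 0.0, 1.0, 0.0]],                         # neuron 2, 1 repeat
    ],
    # stim 1: T_resp = 2 bins
    [
        SENTINEL,                                       # neuron 0, no data
        [[3.0, 0.0], [2.0, 1.0], [2.0, 0.0]],           # neuron 1, 3 repeats
        [[1.0, 1.0]],                                   # neuron 2, 1 repeat
    ],
]


def is_sentinel(resp):
    """True for a (1, 1) response whose single entry is NaN."""
    return len(resp) == 1 and len(resp[0]) == 1 and math.isnan(resp[0][0])


def derive_nrn_masks(responses):
    """Rebuild the (S, N) boolean mask from the NaN sentinels alone."""
    return [[not is_sentinel(r) for r in per_stim] for per_stim in responses]


nrn_masks = derive_nrn_masks(responses)
```

Because the mask is recomputed from the responses, there is no second data structure to keep in step with them.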

Opt-in per-neuron quality metrics

Call self.compute_neuron_quality() after construction to populate each nrn_meta[i] with two scalars derived from the responses themselves: Sahani-Linden 'snr' and Hsu/Spearman-Brown 'ccmax'. These are useful as predicate-filter inputs (e.g. ds.select_pop_by_nrn_predicate(lambda n: n['snr'] > 0.5)). The call is opt-in rather than automatic because CCmax is \(O(S \cdot N \cdot R^2 \cdot \text{max\_iters})\) in the worst case and takes from a few seconds to tens of seconds on the big datasets.

Key invariants

  • Stim tensors never contain NaN. Batch-level collate zero-pads them on the right along T.

  • Response tensors may contain NaN. Use self.nrn_masks (dataset level) or the derived valid_mask from collate (batch level).

  • Response-side preprocessing (smooth_responses, normalize_responses, any user-written transform) must be NaN-aware — either use nanmean / nanstd / etc., or apply the mask before reducing.
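The last invariant is easy to violate, because a single NaN poisons an ordinary reduction. The sketch below contrasts a naive mean squared error with a NaN-aware one on plain Python lists; nan_aware_mse is a hypothetical helper, not the loss recommended in data_paradigm.md.

```python
import math


def nan_aware_mse(pred, target):
    """Mean squared error over the entries where target is not NaN."""
    total, count = 0.0, 0
    for p, t in zip(pred, target):
        if math.isnan(t):  # missing response: contributes nothing
            continue
        total += (p - t) ** 2
        count += 1
    return total / count if count else float("nan")


target = [1.0, float("nan"), 3.0, float("nan")]
pred = [0.0, 5.0, 3.0, -2.0]

# Naive reduction: the NaN targets propagate into the result.
naive = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)

# Masked reduction: only the two valid entries are averaged.
masked = nan_aware_mse(pred, target)
```

Here naive evaluates to NaN, while masked averages the two valid squared errors (1.0 and 0.0).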

compute_neuron_quality(max_ccmax_iters: int = 126) → dict[source]

Write per-neuron SNR and CCmax scalars into self.nrn_meta.

For each neuron, adds two keys to self.nrn_meta[i]:

  • 'snr' (float) — Sahani-Linden signal-to-noise ratio \(\text{SP}_n / \text{NP}_n\), with the two terms length-weighted across stims (weight = number of valid time bins per stim). NaN when the neuron has no stim with \(R_{b,n} \ge 2\) and \(T_b \ge 2\).

  • 'ccmax' (float) — Hsu/Spearman-Brown noise ceiling (capped at max_ccmax_iters random half-splits per (stim, neuron)), length-weighted across stims with \(R_{b,n} \ge 2\). Falls back to 1.0 when the neuron has zero such stims (R=1 everywhere — no normalization possible, so cc_norm = cc_raw). NaN if every contributing stim has \(\rho_{\text{half}} \le 0\) (signal too weak to estimate the ceiling).

Both scalars use the length-weighted aggregation convention from metrics_paradigm.md §11, matching what corrcoef / normalized_corrcoef would compute on the concatenated-over-stims time axis.
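As a rough illustration of the 'snr' entry, the sketch below applies the standard Sahani-Linden signal-power and noise-power estimators per stim and then weights both terms by the number of time bins before taking the ratio. The weighting scheme is an assumption based on the description above, and the function names are hypothetical; the library's own implementation may differ in detail.

```python
from statistics import fmean, pvariance


def signal_noise_power(repeats):
    """Sahani-Linden signal and noise power for one (stim, neuron) pair.

    repeats: list of R >= 2 trials, each a list of T >= 2 spike counts.
    """
    R = len(repeats)
    psth = [fmean(col) for col in zip(*repeats)]            # average over repeats
    power_of_mean = pvariance(psth)                         # P(mean response)
    mean_of_power = fmean([pvariance(r) for r in repeats])  # mean of P(single trial)
    signal = (R * power_of_mean - mean_of_power) / (R - 1)
    noise = mean_of_power - signal
    return signal, noise


def length_weighted_snr(per_stim_repeats):
    """Weight SP and NP by T_s across stims, then take the ratio (assumed scheme)."""
    sp_sum = np_sum = 0.0
    used = False
    for repeats in per_stim_repeats:
        if len(repeats) < 2 or len(repeats[0]) < 2:
            continue                      # needs R >= 2 and T >= 2
        T = len(repeats[0])
        sp, noise = signal_noise_power(repeats)
        sp_sum += T * sp
        np_sum += T * noise
        used = True
    if not used:
        return float("nan")               # no stim qualifies
    return sp_sum / np_sum if np_sum else float("inf")


stim_a = [[2.0, 0.0, 2.0, 0.0], [2.0, 0.0, 2.0, 2.0]]  # R = 2, T = 4
stim_b = [[1.0, 0.0], [0.0, 1.0]]                      # R = 2, T = 2
stim_c = [[1.0, 2.0, 3.0]]                             # R = 1, skipped

sp_a, np_a = signal_noise_power(stim_a)
snr = length_weighted_snr([stim_a, stim_b, stim_c])
```

Note that a single-stim signal-power estimate can be negative (stim_b above), which is one reason to aggregate the two terms before dividing rather than averaging per-stim ratios.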

Parameters:

max_ccmax_iters (int, default 126) – Maximum number of random half-splits per (stim, neuron) for the CCmax estimate. C(R, R/2) blows up for R > 10; 126 matches the default of compute_CCmax().
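To see how quickly the number of possible half-splits grows, the count C(R, R // 2) can be tabulated directly. This sketch assumes one half receives R // 2 repeats; whether the library counts splits in exactly this way is not stated here.

```python
import math


def n_half_splits(R):
    """Number of ways to choose R // 2 of the R repeats for one half."""
    return math.comb(R, R // 2)


counts = {R: n_half_splits(R) for R in (4, 8, 9, 10, 12, 20)}
# counts == {4: 6, 8: 70, 9: 126, 10: 252, 12: 924, 20: 184756}
```

Under this counting, C(9, 4) = 126, so a cap of 126 starts to bind at R = 10.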

Returns:

{'snr': Tensor[N], 'ccmax': Tensor[N]} for inspection. The same values are written into self.nrn_meta as plain Python floats.

Return type:

dict

Notes

Opt-in (not auto-called in __init__). CCmax is \(O(S \cdot N \cdot R^2 \cdot \text{max\_iters})\) in the worst case — on big datasets (AA4, NAT4, Espejo NAT) this can take tens of seconds. Call once after dataset construction; results live on nrn_meta for subsequent filter-API calls.

Memory profile: streams one stim at a time, peaking at the largest single-stim (N, R_s, T_s) slab. The earlier implementation pre-built a global (S, N, R_max, T_max) padded tensor and OOMed on Downer 2025 TIMIT (~54 GB); the per-stim streaming variant lands the same numbers bit-identically but with a much smaller working set.

Examples

>>> ds = CRCNSAA1Dataset(...)
>>> ds.compute_neuron_quality()
>>> ds.select_pop_by_nrn_predicate(lambda n: n['snr'] > 0.5)

get_N()[source]

Return the total number of selectable neurons.

Returns:

self.N_neurons (the full population size, not the current selection).

Return type:

int

get_S()[source]

Return the total number of stimuli in the dataset.

Returns:

Number of stimuli presented to the whole neural population.

Return type:

int

get_nrn_meta()[source]

Return metadata for each currently selected neuron.

Returns:

The nrn_meta dicts for the neurons in the current selection self.I.

Return type:

list of dict

normalize_responses(method: str = 'max', stim_indices: Sequence[int] | None = None, eps: float = 1e-08) → dict[source]

Normalize self.responses in place, per neuron.

Statistics are computed on a chosen stim subset (typically train + val) and applied to all stims, mirroring standardize_stims. (1, 1) NaN sentinels for structurally missing (stim, neuron) pairs are preserved unchanged.

Parameters:
  • method ({'max', 'zscore'}, default 'max') –

    'max': divide each neuron’s responses by their max across all (s, r, t) in stim_indices. Preserves non-negativity; the range [0, max] over stim_indices is mapped to [0, 1]. Natural for rate-coded / spike-count targets where 0 is meaningful.

    'zscore': subtract the per-neuron mean and divide by the per-neuron std, both computed NaN-aware over the same flat (s, r, t) samples. Gives signed continuous targets (EEG, LFP) zero mean and unit variance over stim_indices.

  • stim_indices (sequence of int, optional) – Stims used to compute statistics. If None, all stims are used.

  • eps (float, default 1e-8) – Floor on the divisor.

Returns:

{'method': str, 'scale': Tensor[N], 'offset': Tensor[N], 'stim_indices': list | None}, also stored on self.response_normalization. scale is the divisor (max or std); offset is the subtracted location (0 for 'max', mean for 'zscore').

Return type:

dict

Notes

Not idempotent — calling twice double-normalizes.
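The two methods can be illustrated on nested lists. This is a sketch under stated assumptions: it returns a normalized copy instead of mutating in place, uses the population standard deviation, and leaves a neuron untouched when it has no data in the chosen subset. None of these choices is confirmed by the documentation above, and normalize_responses_sketch is a hypothetical name.

```python
import math

NAN = float("nan")


def neuron_samples(responses, n, stim_indices):
    """Flat list of neuron n's non-NaN samples over the chosen stims."""
    return [x
            for s in stim_indices
            for repeat in responses[s][n]
            for x in repeat
            if not math.isnan(x)]


def normalize_responses_sketch(responses, method="max", stim_indices=None, eps=1e-8):
    """Return (normalized copy, scale, offset); statistics come from stim_indices."""
    S, N = len(responses), len(responses[0])
    idx = list(range(S)) if stim_indices is None else list(stim_indices)
    scale, offset = [], []
    for n in range(N):
        vals = neuron_samples(responses, n, idx)
        if not vals:                      # assumption: no data, leave untouched
            off, sc = 0.0, 1.0
        elif method == "max":
            off, sc = 0.0, max(vals)
        elif method == "zscore":
            off = sum(vals) / len(vals)
            sc = math.sqrt(sum((v - off) ** 2 for v in vals) / len(vals))
        else:
            raise ValueError(f"unknown method: {method!r}")
        offset.append(off)
        scale.append(max(sc, eps))        # eps floors the divisor
    # Apply to ALL stims. NaN sentinels stay NaN because NaN propagates.
    out = [[[[(x - offset[n]) / scale[n] for x in repeat]
             for repeat in responses[s][n]]
            for n in range(N)]
           for s in range(S)]
    return out, scale, offset


responses = [
    [[[0.0, 2.0], [4.0, 2.0]], [[NAN]]],       # stim 0: neuron 1 has no data
    [[[8.0, 0.0]], [[1.0, 3.0]]],              # stim 1
]
out_max, scale, offset = normalize_responses_sketch(responses, "max", stim_indices=[0])
out_z, z_scale, z_offset = normalize_responses_sketch(responses, "zscore")
```

With statistics taken from stim 0 only, neuron 0 is divided by 4, so its response to the held-out stim 1 reaches 2.0: the [0, 1] range is only guaranteed on the stims used for the statistics.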

property nrn_masks: Tensor

True iff neuron n has real data for stim s.

Derived from the NaN sentinels in self.responses, which remain the single source of truth. The mask is lazily cached on first access, so subsequent reads are O(1). Because of the cache, callers that structurally mutate the response list (replacing a real tensor with a (1, 1) NaN sentinel, or vice versa) should call self._invalidate_nrn_masks() afterwards. Shape-preserving mutations (smooth_responses, normalization, etc.) leave the mask unchanged and do not require invalidation.

For a bare dataset (no populated responses), returns an empty (0, N_neurons) tensor.

Type:

Derived (S, N) bool tensor

reset_pop_selection()[source]

Clear the population selection so all neurons are eligible again.

Mirror of reset_stim_selection. Restores self.I to its empty default (interpreted as “no neuron-side restriction” by _selected()).

reset_stim_selection()[source]

Clear self.S_sel so all stimuli are eligible again.

select_neuron(neuron_index: int)[source]

Restrict the selection to a single neuron.

Parameters:

neuron_index (int) – Index into the full population, in [0, N_neurons).

select_pop_by_nrn_attr(attribute_name: str, value)[source]

Select neurons whose nrn_meta[attribute_name] == value.

Neurons whose metadata dict does not contain attribute_name are silently skipped — this lets a single filter call work against a concatenated dataset that pools sources with different metadata schemas (e.g. AA1’s area is not present on AA4 neurons; calling select_pop_by_nrn_attr("area", "Field_L") on the concatenation keeps only AA1 neurons in Field L, with no KeyError).

Parameters:
  • attribute_name (str) – Key looked up in each nrn_meta dict.

  • value – Required value for an exact (==) match.

Returns:

Indices of the selected neurons. Also stored in self.I.

Return type:

list of int

See also

select_pop_by_nrn_predicate(): threshold / range / compound queries.

select_pop_by_nrn_predicate(predicate)[source]

Select neurons whose nrn_meta dict satisfies predicate.

More flexible than select_pop_by_nrn_attr(): takes any callable that maps a single nrn_meta dict to a truthy / falsy value, so threshold queries on continuous attributes ("snr > 0.5"), range queries ("200 <= depth_um <= 800") and compound conditions ("area in {'Field_L', 'MLd'} and auditory") are expressible. The *_attr siblings remain available for the equality-only case.

Neurons whose predicate raises KeyError or TypeError are silently skipped — same convention as select_pop_by_nrn_attr() so a single predicate works on a concatenated dataset whose sources carry heterogeneous metadata schemas. Note: this means a typo in the predicate (referencing a wrong key) will silently select no neurons rather than raising; use nrn.get(key, default) in the predicate for explicit-default semantics.
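The skip-on-error convention can be mimicked in a few lines. The helper below is a hypothetical stand-in operating on a plain list of metadata dicts; it shows both the forgiving behaviour and the typo pitfall mentioned above.

```python
def select_by_predicate(meta_list, predicate):
    """Indices whose metadata satisfies predicate; KeyError/TypeError are skipped."""
    selected = []
    for i, meta in enumerate(meta_list):
        try:
            if predicate(meta):
                selected.append(i)
        except (KeyError, TypeError):
            continue  # heterogeneous schema: this entry is silently skipped
    return selected


nrn_meta = [
    {"area": "Field_L", "snr": 0.8},
    {"area": "MLd", "snr": 0.2},
    {"snr": 0.9},                      # no 'area' key
    {"area": "Field_L", "snr": None},  # comparing None > 0.5 raises TypeError
]

by_snr = select_by_predicate(nrn_meta, lambda n: n["snr"] > 0.5)
by_area = select_by_predicate(nrn_meta, lambda n: n["area"] == "Field_L")
typo = select_by_predicate(nrn_meta, lambda n: n["snrr"] > 0.5)  # misspelled key
with_default = select_by_predicate(nrn_meta, lambda n: (n.get("snr") or 0.0) > 0.5)
```

The misspelled key yields an empty selection instead of an exception, which is why checking the length of the returned list is a cheap sanity test after any predicate call.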

Parameters:

predicate (callable(dict) -> bool) – Tested on each nrn_meta[i] dict.

Returns:

Indices of selected neurons. Also stored in self.I.

Return type:

list[int]

Examples

>>> ds.select_pop_by_nrn_predicate(lambda n: n.get("snr", 0) > 0.5)
>>> ds.select_pop_by_nrn_predicate(lambda n: n["area"] in {"Field_L", "MLd"})
>>> ds.select_pop_by_nrn_predicate(
...     lambda n: 200 <= n.get("depth_um", -1) <= 800
... )

select_pop_by_stim_attr(attribute_name: str, value)[source]

Select neurons with >=1 non-null response to stimuli matching a given attribute.

Looks up stimuli whose stim_meta[attribute_name] == value and keeps only neurons whose nrn_masks is True for at least one of them. Stims missing the key are silently skipped (same convention as select_pop_by_nrn_attr()).
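The two-step logic (find matching stims, then keep neurons with at least one valid response among them) can be sketched on plain lists. The function below is a hypothetical stand-in that takes the metadata and the (S, N) mask as arguments instead of reading them from a dataset.

```python
def select_pop_by_stim_attr_sketch(stim_meta, nrn_masks, attribute_name, value):
    """Neurons with at least one valid response to a stim whose metadata matches."""
    # Step 1: stims that carry the key and match the value exactly.
    matching = [s for s, meta in enumerate(stim_meta)
                if attribute_name in meta and meta[attribute_name] == value]
    # Step 2: neurons whose mask is True for any matching stim.
    n_neurons = len(nrn_masks[0]) if nrn_masks else 0
    return [n for n in range(n_neurons)
            if any(nrn_masks[s][n] for s in matching)]


stim_meta = [{"type": "song"}, {"type": "noise"}, {}]  # last stim lacks the key
nrn_masks = [
    [True, False, False],
    [False, True, False],
    [False, False, True],
]

song_neurons = select_pop_by_stim_attr_sketch(stim_meta, nrn_masks, "type", "song")
noise_neurons = select_pop_by_stim_attr_sketch(stim_meta, nrn_masks, "type", "noise")
call_neurons = select_pop_by_stim_attr_sketch(stim_meta, nrn_masks, "type", "call")
```

Neuron 2 only heard the stim without a "type" key, so it is never selected by a "type" query.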

Parameters:
  • attribute_name (str) – Key looked up in each stim_meta dict.

  • value – Required value for an exact (==) match.

Returns:

Indices of the selected neurons. Also stored in self.I.

Return type:

list of int

select_pop_by_stim_predicate(predicate)[source]

Select neurons with >=1 non-null response to stimuli matching predicate.

Predicate variant of select_pop_by_stim_attr(). Looks up stimuli whose stim_meta dict satisfies predicate and keeps only neurons whose nrn_masks is True for at least one of them. Stims whose predicate raises KeyError or TypeError are silently skipped — same forgiving convention as select_pop_by_nrn_predicate().

Parameters:

predicate (callable(dict) -> bool) – Tested on each stim_meta[s] dict.

Returns:

Indices of selected neurons. Also stored in self.I.

Return type:

list[int]

select_population(neuron_indices)[source]

Restrict the selection to the listed neurons.

Parameters:

neuron_indices (sequence of int) – Indices into the full population, each in [0, N_neurons).

select_stim(stim_index: int)[source]

Restrict iteration to a single stimulus index.

Pairs with the bidirectional rule in _selected(): neurons whose only valid responses lie outside the selected stim are auto-hidden from __getitem__.

Parameters:

stim_index (int) – Index into the stim space, in [0, S).

select_stims(stim_indices)[source]

Restrict iteration to the listed stimulus indices.

Parameters:

stim_indices (sequence of int) – Indices into the stim space, each in [0, S).

select_stims_by_attr(attribute_name: str, value)[source]

Restrict iteration to stimuli matching stim_meta[attribute_name] == value.

Stims whose metadata dict does not contain attribute_name are silently skipped — same convention as select_pop_by_nrn_attr(), so a single call works on a concatenated dataset whose sources have heterogeneous stim metadata schemas.

Parameters:
  • attribute_name (str) – Key looked up in each stim_meta dict.

  • value – Required value for an exact (==) match.

Returns:

Indices of the selected stims. Also stored in self.S_sel.

Return type:

list of int

select_stims_by_predicate(predicate)[source]

Restrict iteration to stimuli whose stim_meta satisfies predicate.

Predicate variant of select_stims_by_attr(). Takes any callable mapping a single stim_meta dict to truthy / falsy, so threshold and compound queries on continuous attributes ("duration_s > 2.0", "sample_rate >= 24000") become expressible.

Stims whose predicate raises KeyError or TypeError are silently skipped — same forgiving convention as select_pop_by_nrn_predicate(). Use sm.get(key, default) in the predicate for explicit-default semantics.

Parameters:

predicate (callable(dict) -> bool) – Tested on each stim_meta[s] dict.

Returns:

Indices of selected stims. Also stored in self.S_sel.

Return type:

list[int]

Examples

>>> ds.select_stims_by_predicate(lambda s: s.get("duration_s", 0) >= 2.0)
>>> ds.select_stims_by_predicate(lambda s: s["type"] in {"song", "call"})

smooth_responses(window_ms: float = 21.0) → None[source]

Temporally smooth each non-NaN response in place with a Hanning window.

Parameters:

window_ms (float, default 21.0) – Full width of the Hanning window in ms. Rounded to the nearest odd number of self.dt bins.

Notes

Follows Hsu, Borst & Theunissen (2004) for reducing PSTH estimator variance — a common preprocessing step across spike-count datasets. (1, 1) NaN-sentinel responses (neurons that did not hear a given stim) are preserved unchanged.
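A minimal sketch of the smoothing step, under several assumptions that the text above does not confirm: ties in the odd-bin rounding go upward, the Hann weights exclude the zero endpoints and are normalised to sum to 1, and the edges are zero-padded. All function names are hypothetical.

```python
import math


def odd_window_bins(window_ms, dt_ms):
    """Nearest odd number of bins for window_ms at bin width dt_ms (ties round up)."""
    return 2 * math.floor(window_ms / dt_ms / 2) + 1


def hann_kernel(n_bins):
    """Hann weights without the zero endpoints, normalised to sum to 1."""
    w = [0.5 * (1.0 - math.cos(2.0 * math.pi * (k + 1) / (n_bins + 1)))
         for k in range(n_bins)]
    total = sum(w)
    return [x / total for x in w]


def smooth(trace, kernel):
    """Same-length convolution with zero padding at both ends."""
    half = len(kernel) // 2
    out = []
    for t in range(len(trace)):
        acc = 0.0
        for k, w in enumerate(kernel):
            j = t + k - half
            if 0 <= j < len(trace):
                acc += w * trace[j]
        out.append(acc)
    return out


kernel = hann_kernel(odd_window_bins(21.0, 5.0))     # 21 ms at 5 ms bins
trace = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]          # single spike in the middle
smoothed = smooth(trace, kernel)
```

With a unit-sum kernel, a spike far from the edges is spread over neighbouring bins without changing the total spike count.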

standardize_stims(stim_indices: Sequence[int] | None = None, per_band: bool = True, eps: float = 1e-08) → dict[source]

Standardize self.stims in place: (x - mean) / std.

Statistics are computed over the stims selected by stim_indices (typically train + validation indices) and applied to all stims in the dataset. Held-out test stims are therefore transformed with the train + val statistics: no test-set first-order moments leak into the standardization, yet train, val and test all live in the same standardized space.

Parameters:
  • stim_indices (sequence of int, optional) – Indices of stims to compute statistics from. If None (default), statistics are computed over all stims — equivalent to “no held-out test set”; useful for single-split exploratory analysis but introduces a tiny (first-order) leakage if a test set is held out downstream.

  • per_band (bool, default True) – If True, statistics are per-frequency-band (axis -2): mean / std are tensors of shape broadcastable to (C, F, 1). If False, a single scalar mean and std are computed over the whole concatenated stim tensor.

  • eps (float, default 1e-8) – Floor on std to avoid division by zero on constant bands.

Returns:

{'mean': Tensor, 'std': Tensor, 'per_band': bool, 'stim_indices': list | None} — also stored on self.stim_normalization for inspection (e.g. to fold into a model kernel for STRF visualisation).

Return type:

dict

Notes

Not idempotent: calling twice double-standardizes. To re-do with different statistics, rebuild the dataset.
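The per-band case can be sketched on nested lists, with each stim stored as F bands of T samples (the channel axis is dropped for brevity). The sketch assumes the population standard deviation and treats eps as a lower clamp on std; band_stats and apply_stats are hypothetical names.

```python
import math


def band_stats(stims, stim_indices, eps=1e-8):
    """Per-band mean and std pooled over time and over the chosen stims."""
    n_bands = len(stims[0])
    means, stds = [], []
    for f in range(n_bands):
        vals = [x for s in stim_indices for x in stims[s][f]]
        m = sum(vals) / len(vals)
        var = sum((x - m) ** 2 for x in vals) / len(vals)
        means.append(m)
        stds.append(max(math.sqrt(var), eps))  # eps floors the std
    return means, stds


def apply_stats(stims, means, stds):
    """Apply (x - mean) / std band by band to every stim, including held-out ones."""
    return [[[(x - means[f]) / stds[f] for x in band]
             for f, band in enumerate(stim)]
            for stim in stims]


stims = [
    [[0.0, 2.0, 4.0], [1.0, 3.0, 1.0]],  # stim 0: F = 2, T = 3
    [[2.0, 2.0], [1.0, 3.0]],            # stim 1: F = 2, T = 2
    [[10.0], [5.0]],                     # stim 2: held-out test stim, T = 1
]
train_val = [0, 1]

means, stds = band_stats(stims, train_val)
standardized = apply_stats(stims, means, stds)
```

The test stim is transformed with the train + val statistics only, so its values are not guaranteed to have zero mean or unit variance.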

validate()[source]

Check that the instance is deepSTRF-compatible.

Subclasses should call super().validate() and then add their own checks (e.g. AudioNeuralDataset checks self.F > 0).

Module contents

class deepSTRF.datasets.NeuralDataset(path: str, dt_ms: float)[source]

Bases: Dataset, ABC

General base class for datasets of sensory neural responses.
