Source code for deepSTRF.datasets.audio.audio_dataset

from typing import Optional

import torch

from deepSTRF.datasets.neural_dataset import NeuralDataset


class AudioNeuralDataset(NeuralDataset):
    """Base dataset for neural recordings driven by auditory stimuli.

    The layout of ``self.stims[s]`` depends on how the subclass loads audio:

    - **Spectrogram mode** (default): a ``(1, F, T)`` tensor, with
      ``F = self.F`` frequency bands and ``T`` neural time bins.
    - **Waveform mode** (opt-in, subclass-specific): a ``(1, T_audio)`` mono
      float32 tensor sampled at ``self.audio_fs``. Subclasses offering this
      mode expose a ``return_waveform=True`` constructor flag and set
      ``self.audio_fs`` to a positive int.

    The leading ``(1, ...)`` dimension is the mono-channel axis. It is kept
    for collate compatibility (``neural_collate`` zero-pads the last axis
    only).

    Subclasses must set ``self.F`` in their ``__init__`` before calling
    ``self.validate()``. ``F`` stays positive even in waveform mode, so that
    downstream models know the *target* spectrogram width a ``wav2spec``
    module should produce.

    Attributes
    ----------
    F : int
        Frequency-band count of the target spectrogram. Initialised to the
        placeholder ``-1`` here; the subclass must overwrite it.
    audio_fs : int or None
        Sample rate of the raw waveform in waveform mode. Subclasses without
        a waveform branch leave it ``None``.
    hearing_range_hz : tuple of float or None
        Optional ``(low, high)`` bound, in Hz, on the species' canonical
        hearing range (e.g. ``(200.0, 40000.0)`` for ferret). Purely
        advisory: it is only checked for well-formedness, and exists so
        notebooks / tooling can display the range or clamp a ``wav2spec``'s
        frequency limits. ``None`` when unknown.
    """

    def __init__(self, path: str, dt_ms: float):
        super().__init__(path, dt_ms)
        # Placeholder values; concrete subclasses fill these in.
        self.F = -1
        self.audio_fs: Optional[int] = None
        self.hearing_range_hz: Optional[tuple] = None

    def get_F(self):
        """Return the spectrogram frequency-band count.

        Returns
        -------
        int
            The current value of ``self.F``.
        """
        return self.F

    @property
    def hop(self) -> Optional[int]:
        """Number of audio samples per neural bin, or ``None``.

        Computed as ``round(audio_fs * dt / 1000)``, i.e. ``self.dt`` is
        used as the bin width in milliseconds. Returns ``None`` whenever
        ``self.audio_fs`` is ``None``. Because the result depends on
        ``audio_fs`` only, a spectrogram dataset that happens to set
        ``audio_fs`` also gets an integer here.

        In waveform mode the grid-lock contract (see :meth:`validate`)
        requires the un-rounded ratio to already be an integer. A
        ``wav2spec`` front-end must therefore use this same ``hop`` for the
        audio-to-neural resampling to stay aligned with the response bins.
        """
        fs = self.audio_fs
        if fs is None:
            return None
        samples_per_bin = fs * self.dt / 1000.0
        return int(round(samples_per_bin))

    def _concat_check_compat(self, other):
        """Assert that ``other`` can be concatenated with this dataset.

        Runs the parent-class checks first, then requires ``other`` to be an
        ``AudioNeuralDataset`` with the same ``F`` and ``audio_fs``.
        """
        super()._concat_check_compat(other)
        assert isinstance(other, AudioNeuralDataset), (
            f"Cannot concatenate AudioNeuralDataset with {type(other).__name__}"
        )
        assert self.F == other.F, (
            f"F mismatch: {self.F} vs {other.F}. Re-instantiate with matching n_mels."
        )
        assert self.audio_fs == other.audio_fs, (
            f"audio_fs mismatch: {self.audio_fs} vs {other.audio_fs}."
        )

    def _concat_copy_attrs(self, source):
        """Copy the audio-specific attributes from ``source`` onto ``self``.

        Runs after the parent-class copy and carries over ``F``,
        ``audio_fs`` and ``hearing_range_hz``.
        """
        super()._concat_copy_attrs(source)
        self.F = source.F
        self.audio_fs = source.audio_fs
        self.hearing_range_hz = source.hearing_range_hz

    def validate(self):
        """Check the audio-specific invariants on top of the parent checks.

        Verifies that ``self.F`` is a positive int and that
        ``self.hearing_range_hz``, when set, is a two-element tuple/list
        with ``0 < low < high``. In waveform mode it additionally requires a
        positive-int ``audio_fs`` and enforces the waveform grid-lock
        contract.

        Raises
        ------
        AssertionError
            If any of the invariants above is violated.
        """
        super().validate()

        assert isinstance(self.F, int) and self.F > 0, (
            f"self.F must be a positive int (got {self.F!r})"
        )

        bounds = self.hearing_range_hz
        if bounds is not None:
            looks_like_pair = isinstance(bounds, (tuple, list)) and len(bounds) == 2
            assert looks_like_pair and 0 < bounds[0] < bounds[1], (
                f"self.hearing_range_hz must be a (low, high) pair of positive "
                f"increasing numbers or None (got {bounds!r})"
            )

        # The grid-lock contract is specific to raw-waveform mode. A non-None
        # ``audio_fs`` does NOT imply that mode: some spectrogram datasets
        # (e.g. Downer2025) store their in-loader spec sample rate in
        # ``audio_fs`` while still serving (1, F, T) spectrograms. The
        # explicit ``return_waveform`` flag is therefore the switch, and a
        # missing flag means "spectrogram dataset".
        if not getattr(self, "return_waveform", False):
            return

        fs = self.audio_fs
        assert isinstance(fs, int) and fs > 0, (
            f"waveform mode requires a positive-int audio_fs (got {fs!r})"
        )
        self._validate_waveform_grid()

    def _validate_waveform_grid(self):
        """Enforce the waveform grid-lock contract (convention C1).

        In waveform mode each stim has to be a ``(1, T_audio)`` tensor whose
        length is an exact multiple of ``hop = audio_fs * dt_ms / 1000``,
        and ``T_audio // hop`` has to match that stim's neural response
        length. Audio sample ``j`` is thereby pinned to response bin
        ``j // hop``, which is what allows a strictly-causal ``wav2spec`` to
        remain causal with respect to the responses. See the "Waveform
        conventions" section of ``docs/_source/md/data_paradigm.md``.

        Stims for which no neuron has real data are only checked for shape
        and hop-divisibility, since there is no response to align against.

        Raises
        ------
        AssertionError
            If the samples-per-bin ratio is not a positive integer (within
            ``1e-6``), or if any stim breaks the shape / length rules.
        """
        samples_per_bin = self.audio_fs * self.dt / 1000.0
        hop = int(round(samples_per_bin))
        assert hop >= 1 and abs(samples_per_bin - hop) < 1e-6, (
            f"waveform grid lock: audio_fs * dt_ms / 1000 = {samples_per_bin} is not an "
            f"integer number of samples per bin. Choose an audio_fs such that "
            f"audio_fs * {self.dt} / 1000 is an integer (e.g. 48000 at dt=5 ms)."
        )

        for idx, wav in enumerate(self.stims):
            assert wav.dim() == 2 and wav.shape[0] == 1, (
                f"waveform stim {idx} must be (1, T_audio); got {tuple(wav.shape)}"
            )

            n_samples = wav.shape[1]
            n_frames, leftover = divmod(n_samples, hop)
            assert leftover == 0, (
                f"waveform stim {idx}: T_audio={n_samples} is not a multiple of "
                f"hop={hop}. Each waveform must be right-padded / cropped to an "
                f"exact multiple of hop so it aligns with the response bins."
            )

            n_bins = self._stim_response_length(idx)
            if n_bins is not None:
                assert n_frames == n_bins, (
                    f"waveform stim {idx}: T_audio // hop = {n_frames} neural "
                    f"frames but the response has {n_bins} bins. The waveform length "
                    f"must equal T_resp * hop = {n_bins * hop} samples."
                )
            # n_bins is None: no neuron heard this stim, nothing to align.

    def _stim_response_length(self, s: int) -> Optional[int]:
        """Return the neural response length (in bins) for stim ``s``.

        Walks ``self.responses[s]`` and returns ``shape[1]`` of the first
        entry that is not a missing-data sentinel. Sentinels are the
        ``(1, 1)`` all-NaN tensors of the deepSTRF response paradigm.

        Returns
        -------
        int or None
            The response length, or ``None`` if every entry is a sentinel
            (or there are no entries).
        """
        for resp in self.responses[s]:
            is_sentinel = (
                tuple(resp.shape) == (1, 1) and bool(torch.isnan(resp).all())
            )
            if not is_sentinel:
                return int(resp.shape[1])
        return None