Source code for deepSTRF.datasets.audio.ns1

import os
from typing import Optional

import numpy as np
import scipy.io as sio
import torch

from deepSTRF.datasets.audio.audio_dataset import AudioNeuralDataset
from deepSTRF.utils.data_download import (
    default_cache_dir,
    github_raw_download,
    osf_download,
    unzip,
)


# Each stimulus lasts 4995 ms. At the 5 ms bin width used by the original
# authors that is 999 bins, which is also the temporal resolution of the
# provided spectrogram tensor (test_data_5ms.mat).
NS1_RAW_LEN_MS = 4995
NS1_NAT_SOUNDS = 20

# Public NS1 release on OSF (https://osf.io/ayw2p/): metadata + raw spike
# data. Maps local filename -> OSF storage GUID; each entry is resolved with
# osf_download(<guid>, dest).
NS1_OSF_FILES = {
    "MetadataSHEnCneurons.mat": "gdwyd",
    "spikesandwav.zip": "5nxga",
    "ReadMe.rtf": "f3rtm",
}

# The precomputed mel-spectrogram tensor (X_nfht: S=20, F=34, hopdim=1, T=999)
# used by the original Harper/Rahman analyses is NOT hosted on OSF. It lives
# in the DNet companion repo (Rahman et al. 2019 PLoS Comp Biol, doi: 10.1371/
# journal.pcbi.1006618 — github.com/monzilur/DNet). We use the 5 ms version,
# ``test_data_5ms.mat`` (5.2 MB); a 1 ms version (``test_data.mat``, 52 MB)
# exists there too but is not used. That file also carries a ``y_nt``
# variable, an (S, T) trace that cleanly matches neither a single neuron's
# PSTH nor any population mean of our 119-neuron data — so it is ignored in
# favour of our trial-resolved (R=20, T=999) responses computed from the OSF
# spike .mat files.
NS1_DNET_REPO = "monzilur/DNet"
NS1_DNET_REF = "master"
# Maps local filename -> path inside the DNet repo.
NS1_DNET_FILES = {"test_data_5ms.mat": "test_data_5ms.mat"}

# The 4 natural-speech stimuli, per Rahman et al. (2020) Fig. S2. Indices are
# 0-based; the original (1-based) sound numbers are 9, 10, 11, 12.
NS1_SPEECH_INDICES = tuple(range(8, 12))
# Index 0 (sound 1) and index 19 (sound 20) are water sounds; index 3 is a
# ferret vocalization; index 6 is insects buzzing. Every other index has no
# published category and is left as "unknown" by consumers of this table.
NS1_TYPE_OVERRIDES = {
    0: "water_sounds",
    3: "ferret_vocalization",
    6: "insects_buzzing",
    19: "water_sounds",
}
NS1_TYPE_OVERRIDES.update(dict.fromkeys(NS1_SPEECH_INDICES, "human_speech"))

# Train/test split used by Rahman et al. (2020). Exposed at module level as a
# convenience for downstream notebooks; the dataset itself does NOT enforce
# this split (callers can sub-select by stim index). The train/val set is
# exactly the stimuli that are not held out for testing.
NS1_RAHMAN_TEST_INDICES = [3, 6, 9, 19]
NS1_RAHMAN_TRAINVAL_INDICES = [
    idx for idx in range(NS1_NAT_SOUNDS) if idx not in NS1_RAHMAN_TEST_INDICES
]

# Stim index (0..19) -> matching wav filename under ``spikesandwav/SH.En.C/``.
# Verified empirically (mel-spec correlation to X_nfht — see
# ``tests/test_ns1_waveform.py``): stim_idx 0..11 come from ``source.1`` with
# ``fw=2`` and frozen tokens 1..12; stim_idx 12..19 come from ``source.2``
# with ``fw=1`` and frozen tokens 1..8.
NS1_WAV_FILENAMES = tuple(
    [f"source.1.sound.0.snr.0.token.0.fw.2.frozen.{token}" for token in range(1, 13)]
    + [f"source.2.sound.0.snr.0.token.0.fw.1.frozen.{token}" for token in range(1, 9)]
)
assert len(NS1_WAV_FILENAMES) == NS1_NAT_SOUNDS

# Raw wavs are float-mono at 48828.125 Hz, 5.000 s long (244140 samples); the
# constant below stores that rate truncated to an integer. The precomputed
# spec covers 4.995 s (999 bins × 5 ms), so waveforms aligned to the response
# window are cropped / padded after resampling to exactly
# ``T_neural * audio_fs * dt_ms / 1000`` samples.
NS1_WAV_NATIVE_FS = 48828
NS1_WAV_DIR_NAME = "spikesandwav/SH.En.C"


def download_ns1(dest: Optional[str] = None) -> str:
    """Fetch every NS1 data asset into ``dest`` and return that directory.

    Sources:

    - **OSF** (https://osf.io/ayw2p/, no account): the dataset README, the
      per-neuron metadata (.mat), and the spike + wav zip (~155 MB total).
    - **DNet GitHub** (https://github.com/monzilur/DNet, master branch): the
      precomputed 5 ms mel-spectrogram tensor ``test_data_5ms.mat`` (5.2 MB)
      accompanying Rahman et al. 2019 PLoS Comp Biol. NOT on OSF.

    The call is idempotent at file granularity: any target that already
    exists on disk is left alone, and the zip is only unpacked when no
    ``spikesandwav`` directory is present yet.

    Parameters
    ----------
    dest : str, optional
        Where to put the downloaded files. Defaults to the platformdirs cache
        (overridable via ``$DEEPSTRF_DATA_DIR``).

    Returns
    -------
    str
        The dataset directory, as a string (``dest`` itself when one was
        given; it is not normalised).
    """
    if dest is None:
        root = str(default_cache_dir("NS1"))
    else:
        root = str(dest)
    os.makedirs(root, exist_ok=True)

    # Step 1 — OSF assets, addressed by storage GUID.
    for filename, osf_guid in NS1_OSF_FILES.items():
        local_file = os.path.join(root, filename)
        if not os.path.exists(local_file):
            osf_download(osf_guid, local_file)

    # Step 2 — unpack the spike/wav archive unless it was unpacked before.
    archive = os.path.join(root, "spikesandwav.zip")
    unpacked_dir = os.path.join(root, "spikesandwav")
    if not os.path.isdir(unpacked_dir) and os.path.isfile(archive):
        unzip(archive, root)

    # Step 3 — DNet GitHub asset(s): the precomputed spectrogram tensor.
    for filename, path_in_repo in NS1_DNET_FILES.items():
        local_file = os.path.join(root, filename)
        if not os.path.exists(local_file):
            github_raw_download(
                NS1_DNET_REPO, path_in_repo, local_file, ref=NS1_DNET_REF
            )

    return root
class NS1Dataset(AudioNeuralDataset):
    """PyTorch dataset for the NS1 (Harper et al. 2016, Rahman et al. 2020) data.

    119 multi/single units from primary auditory cortex (A1) of deeply
    anesthetized ferrets, recorded in response to 20 natural sound clips of
    4.995 s each, presented 20 times per neuron. Every neuron heard every
    clip, so the response grid is fully dense (no NaN sentinels).

    Of the 119 units, 73 pass the "single-unit at known depth" filter the
    original authors used (``single_t in {'Yes', 'Maybe'}`` and
    ``depth >= 0``);
    :meth:`~deepSTRF.datasets.neural_dataset.NeuralDataset.select_pop_by_nrn_attr`
    over ``single_t`` / ``depth_um`` reproduces this subset.

    The spectrogram tensor is precomputed at ``dt = 5 ms`` (``F = 34``
    frequency bands, ``T = 999`` bins); the ``dt_ms`` constructor argument is
    currently validated against this resolution. With
    ``return_waveform=True``, ``stims`` are instead raw mono waveforms
    ``(1, T_audio)`` at ``audio_fs`` (aligned to
    ``T_audio = T_neural * audio_fs * dt_ms / 1000``) — feed them through a
    model's ``wav2spec`` front-end.

    Data are freely available (no account required) and auto-fetched by
    ``NS1Dataset(download=True)``:

    - https://osf.io/ayw2p/ — metadata, raw spike and wav data.
    - https://github.com/monzilur/DNet — precomputed 5 ms mel spectrogram.

    Notes
    -----
    Follows the standard deepSTRF data paradigm (see
    ``docs/_source/md/data_paradigm.md``). NS1-specific metadata:

    - ``stim_meta`` dicts hold ``name`` and ``type``.
    - ``nrn_meta`` dicts hold ``cell_id``, ``area``, ``depth_um``,
      ``noise_ratio``, ``single_n``, ``single_t``, ``n_electrodes`` and
      ``electrode_number``. ``noise_ratio`` is the Sahani-Linden normalised
      noise power (lower = cleaner; NOT an SNR despite the legacy ``.mat``
      field name). ``single_n`` is the single-unit flag from spike-snippet
      clustering (0/1); ``single_t`` is the manual triage label
      ('Yes'/'Maybe'/'No').

    References
    ----------
    Harper et al. (2016). "Network receptive field modeling reveals extensive
    integration and multi-feature selectivity in auditory cortical neurons."
    *PLoS Computational Biology*.

    Rahman et al. (2020). "Simple transformations capture auditory input to
    cortex." *PNAS*.
    """

    def __init__(self, path: Optional[str] = None, dt_ms: float = 5.0,
                 smooth: bool = True, download: bool = False,
                 return_waveform: bool = False, audio_fs: int = 48000):
        """
        Parameters
        ----------
        path : str, optional
            Path to the NS1 data folder containing ``test_data_5ms.mat``,
            ``MetadataSHEnCneurons.mat``, and ``spikesandwav/``. If ``None``,
            defaults to the platformdirs cache
            (``user_cache_dir('deepSTRF') / 'NS1'`` — overridable via
            ``$DEEPSTRF_DATA_DIR``).
        dt_ms : float, default 5.0
            Time-bin width in ms. Must equal 5.0 — the bundled spectrogram is
            precomputed at this resolution. Other values would require
            re-spectrogramming the wavs (not implemented).
        smooth : bool, default True
            If True, smooth PSTHs in place with a 21 ms Hanning window (Hsu,
            Borst & Theunissen 2004).
        download : bool, default False
            If True and the data assets are missing under ``path``, fetch
            them from their public sources (no account required) — OSF
            (https://osf.io/ayw2p/: metadata + spike data + wavs) and DNet
            GitHub (https://github.com/monzilur/DNet: the precomputed 5 ms
            mel-spectrogram tensor ``test_data_5ms.mat``). Total ~160 MB,
            ~16 s on a fast connection. See :func:`download_ns1`.
        return_waveform : bool, default False
            If True, ``self.stims`` holds raw audio waveforms instead of
            precomputed spectrograms. Each ``self.stims[s]`` is a
            ``(1, T_audio)`` float32 tensor at ``audio_fs`` Hz, downmixed to
            mono, resampled from the native 48 828.125 Hz, and right-cropped /
            zero-padded to exactly ``T_neural * audio_fs * dt_ms / 1000``
            samples so it aligns with the 4.995 s response window. Pair with
            a model that has a ``wav2spec`` front-end (see
            ``deepSTRF.models.wav2spec``).
        audio_fs : int, default 48000
            Sample rate (Hz) for waveform mode. Default 48 kHz gives a clean
            240 samples / 5-ms bin and a Nyquist of 24 kHz — enough to
            preserve the ~22.6 kHz content used in Rahman et al. 2019's
            cochleagram. (Native is 48 828.125 Hz; the small downsample keeps
            an integer sample-per-bin factor.) Ignored when
            ``return_waveform=False``.

        Raises
        ------
        AssertionError
            If ``dt_ms`` is not 5.0, or if the spectrogram file does not hold
            the expected number of stimuli.
        FileNotFoundError
            If ``test_data_5ms.mat``, ``MetadataSHEnCneurons.mat`` or (in
            waveform mode) the raw wav directory is missing under ``path``.
        """
        # Reject an unsupported bin width up front: before the (potentially
        # ~160 MB) download and before any state is built. This is an explicit
        # raise rather than an ``assert`` so it is not stripped under
        # ``python -O``; the exception type stays AssertionError for backward
        # compatibility with existing callers.
        if dt_ms != 5.0:
            raise AssertionError(
                f"NS1 spectrograms are precomputed at 5 ms; got dt_ms={dt_ms}. "
                f"Re-binning the spike trains is straightforward but the spectrogram "
                f"would need to be recomputed from raw wavs (TODO)."
            )

        if path is None:
            path = str(default_cache_dir("NS1"))
        if download:
            download_ns1(path)
        super().__init__(path, dt_ms)

        self.species = "ferret"
        # Informational only (not enforced): ferret behavioural audiogram spans
        # roughly 200 Hz – 40 kHz. Lets tooling/notebooks display the range and
        # users optionally clamp a wav2spec's frequency limits.
        self.hearing_range_hz = (200.0, 40000.0)
        # Raw-waveform mode flag (gates the base-class grid-lock validation).
        self.return_waveform = bool(return_waveform)

        # Bin geometry shared by the waveform length and the spike binning;
        # computed once instead of once per neuron x stimulus.
        bin_size = int(round(dt_ms))
        n_bins = NS1_RAW_LEN_MS // bin_size  # = 999 for 5 ms / 4995 ms

        # ----------- 1. load the precomputed spectrograms -----------
        # X_nfht: (S=20, F=34, 1, T=999) at dt=5 ms
        spec_path = os.path.join(path, "test_data_5ms.mat")
        if not os.path.isfile(spec_path):
            raise FileNotFoundError(
                f"NS1 expects 'test_data_5ms.mat' at {spec_path}. Pass\n"
                f"download=True to fetch it from the DNet companion repo\n"
                f"(https://github.com/monzilur/DNet), or place it manually."
            )
        spec_data = sio.loadmat(spec_path)
        X = spec_data["X_nfht"]
        # Four-way unpacking doubles as a check that the tensor is 4-D.
        S, F, _, _ = X.shape
        assert S == NS1_NAT_SOUNDS, f"expected {NS1_NAT_SOUNDS} stims, got {S}"
        self.F = int(F)

        self.stim_meta = [
            {"name": f"nat{s + 1:02d}", "type": NS1_TYPE_OVERRIDES.get(s, "unknown")}
            for s in range(S)
        ]

        if return_waveform:
            # ---- raw-waveform mode: load the 20 OSF wavs in stim-index order ----
            from deepSTRF.utils.audio_io import load_resampled_mono_wav

            wav_root = os.path.join(path, NS1_WAV_DIR_NAME)
            if not os.path.isdir(wav_root):
                raise FileNotFoundError(
                    f"NS1 waveform mode expects raw wavs at {wav_root}; pass\n"
                    f"download=True to fetch them from OSF (https://osf.io/ayw2p/)."
                )
            # Number of audio samples spanning exactly the response window.
            T_audio = int(round(n_bins * audio_fs * dt_ms / 1000))
            self.audio_fs = int(audio_fs)
            self.stims = [
                load_resampled_mono_wav(
                    os.path.join(wav_root, NS1_WAV_FILENAMES[s]),
                    target_fs=audio_fs,
                    target_length=T_audio,
                )
                for s in range(S)
            ]
        else:
            # ---- spectrogram mode (default): use the precomputed tensor ----
            self.stims = [
                torch.from_numpy(X[s, :, 0, :]).float().unsqueeze(0)  # (1, F, T)
                for s in range(S)
            ]

        # ----------- 2. load per-neuron metadata + spike data -----------
        meta_path = os.path.join(path, "MetadataSHEnCneurons.mat")
        if not os.path.isfile(meta_path):
            # Same exception type loadmat would raise, with an actionable hint.
            raise FileNotFoundError(
                f"NS1 expects 'MetadataSHEnCneurons.mat' at {meta_path}. Pass\n"
                f"download=True to fetch it from OSF (https://osf.io/ayw2p/),\n"
                f"or place it manually."
            )
        meta = sio.loadmat(meta_path)
        neurons = meta["neuron"][0]

        # responses[s] is the list of per-neuron (R, T) tensors for stimulus s.
        self.responses = [[] for _ in range(S)]
        self.nrn_meta = []

        spikes_root = os.path.join(path, "spikesandwav")
        for neuron in neurons:
            uid = str(neuron[0].item())[:-4]  # drop trailing '.mat' from filename-as-uid
            spike_path = os.path.join(spikes_root, str(neuron["path"].item()))
            try:
                temp = sio.loadmat(spike_path)
            except FileNotFoundError:
                # The OSF release has a few neurons with a "_1" suffix mismatch
                # between the metadata path and the actual file; try the fallback.
                fallback = spike_path[:-4] + "_1.mat"
                try:
                    temp = sio.loadmat(fallback)
                except FileNotFoundError:
                    print(f"NS1: skipping neuron {uid!r} (spike file not found at {spike_path} or {fallback})")
                    continue

            # Metadata and responses are appended together, so a skipped
            # neuron leaves no entry in either structure.
            self.nrn_meta.append({
                "cell_id": uid,
                "area": "A1",
                "depth_um": int(neuron["depth"].item()),
                "noise_ratio": float(neuron["NoiseRatio"].item()),
                "single_n": int(neuron["singleN"].item()),
                "single_t": str(neuron["singleT"].item()),
                "n_electrodes": int(neuron["NrOfElectrodes"].item()),
                "electrode_number": int(neuron["ElectrodeNumber"].item()),
            })

            for s in range(S):
                repeats_grp = temp["data"]["set"][0, 0]["repeats"][0, s]
                spike_matrix = self._ns1_bin_repeats(repeats_grp, bin_size, n_bins)
                self.responses[s].append(torch.from_numpy(spike_matrix))

        self.N_neurons = len(self.nrn_meta)

        # smooth PSTHs with a 21 ms Hanning window (Hsu / Borst / Theunissen 2004)
        if smooth:
            self.smooth_responses(window_ms=21.0)

        # self.nrn_masks is a derived @property on the base class — no need
        # to populate it here
        self.validate()

    @staticmethod
    def _ns1_bin_repeats(repeats_grp, bin_size, n_bins):
        """Bin one neuron's spike times for one stimulus into an (R, n_bins) array.

        Parameters
        ----------
        repeats_grp
            The ``repeats`` entry for one stimulus as loaded by
            ``scipy.io.loadmat``; its second dimension indexes repeats and its
            ``"t"`` field holds each repeat's spike times (in ms, judging by
            the comparison against ``NS1_RAW_LEN_MS`` — TODO confirm units
            against the dataset README).
        bin_size : int
            Bin width in ms; must divide ``NS1_RAW_LEN_MS``.
        n_bins : int
            Number of output bins (``NS1_RAW_LEN_MS // bin_size``).

        Returns
        -------
        numpy.ndarray
            float32 array of shape ``(R, n_bins)`` with per-bin counts.
        """
        n_repeats = int(repeats_grp.shape[1])
        spike_matrix = np.zeros((n_repeats, n_bins), dtype=np.float32)
        trial_times = repeats_grp[0]["t"]
        for r in range(n_repeats):
            spiketimes = np.round(trial_times[r][0]).astype(np.int64)
            # Discard spikes that fall outside the stimulus window.
            in_window = (spiketimes >= 0) & (spiketimes < NS1_RAW_LEN_MS)
            spiketimes = spiketimes[in_window]
            # 1 ms occupancy vector: a millisecond containing one or more
            # spikes is marked once, then occupancy is summed per bin.
            one_hot = np.zeros(NS1_RAW_LEN_MS, dtype=np.float32)
            one_hot[spiketimes] = 1.0
            spike_matrix[r] = one_hot.reshape(-1, bin_size).sum(axis=1)
        return spike_matrix