import os
from typing import Optional
import numpy as np
import scipy.io as sio
import torch
from deepSTRF.datasets.audio.audio_dataset import AudioNeuralDataset
from deepSTRF.utils.data_download import (
default_cache_dir,
github_raw_download,
osf_download,
unzip,
)
# Each stimulus lasts 4995 ms, i.e. 999 bins at the 5 ms bin width used by the
# original authors. The provided spectrogram tensor (test_data_5ms.mat) has
# the same temporal resolution.
NS1_RAW_LEN_MS = 4995
NS1_NAT_SOUNDS = 20

# Public NS1 release on OSF (https://osf.io/ayw2p/): file name -> OSF storage
# GUID for the README, the per-neuron metadata and the raw spike/wav archive.
# Each entry is fetched with osf_download(<guid>, dest).
NS1_OSF_FILES = {
    "MetadataSHEnCneurons.mat": "gdwyd",
    "spikesandwav.zip": "5nxga",
    "ReadMe.rtf": "f3rtm",
}

# The precomputed mel-spectrogram tensor (X_nfht: S=20, F=34, hopdim=1, T=999)
# used by the original Harper/Rahman analyses is NOT on OSF; it lives in the
# DNet companion repo (Rahman et al. 2019 PLoS Comp Biol, doi: 10.1371/
# journal.pcbi.1006618 — github.com/monzilur/DNet). We use the 5 ms version,
# ``test_data_5ms.mat`` (5.2 MB); a 1 ms version (``test_data.mat``, 52 MB)
# exists in the same repo but is not used here. The 5 ms file also carries a
# ``y_nt`` variable: an (S, T) trace that matches neither a single neuron's
# PSTH nor any population mean of our 119-neuron data. It is ignored in favour
# of the trial-resolved (R=20, T=999) responses computed from the OSF spike
# .mat files.
NS1_DNET_REPO = "monzilur/DNet"
NS1_DNET_REF = "master"
# local file name -> path inside the DNet repository
NS1_DNET_FILES = {"test_data_5ms.mat": "test_data_5ms.mat"}

# 0-indexed positions of the 4 natural-speech stimuli, per Rahman et al. (2020)
# Fig. S2 (sounds 9, 10, 11 and 12 in the original 1-indexed numbering).
NS1_SPEECH_INDICES = tuple(range(8, 12))

# Known stimulus categories, keyed by 0-indexed stimulus position: index 0
# (sound 1) and index 19 (sound 20) are water sounds, index 3 is a ferret
# vocalization, index 6 is insects buzzing, and the speech indices above are
# human speech. Indices absent from this mapping have no published category
# (the dataset class labels them "unknown").
NS1_TYPE_OVERRIDES = {
    0: "water_sounds",
    3: "ferret_vocalization",
    6: "insects_buzzing",
    19: "water_sounds",
}
NS1_TYPE_OVERRIDES.update(dict.fromkeys(NS1_SPEECH_INDICES, "human_speech"))

# Train/test split used by Rahman et al. (2020). Exposed as module-level
# constants purely as a convenience for downstream notebooks; the dataset does
# NOT enforce this split (callers can sub-select by stim index). The
# train/validation set is every stimulus that is not held out for testing.
NS1_RAHMAN_TEST_INDICES = [3, 6, 9, 19]
NS1_RAHMAN_TRAINVAL_INDICES = [
    idx for idx in range(NS1_NAT_SOUNDS) if idx not in NS1_RAHMAN_TEST_INDICES
]

# Stim index (0..19) -> matching wav filename under ``spikesandwav/SH.En.C/``.
# Verified empirically (mel-spec correlation to X_nfht — see
# ``tests/test_ns1_waveform.py``): stim_idx 0..11 come from ``source.1`` with
# ``fw=2`` and frozen tokens 1..12; stim_idx 12..19 come from ``source.2``
# with ``fw=1`` and frozen tokens 1..8.
NS1_WAV_FILENAMES = tuple(
    f"source.{source}.sound.0.snr.0.token.0.fw.{fw}.frozen.{token}"
    for source, fw, n_tokens in ((1, 2, 12), (2, 1, 8))
    for token in range(1, n_tokens + 1)
)
assert len(NS1_WAV_FILENAMES) == NS1_NAT_SOUNDS

# Raw wavs are float-mono at 48828.125 Hz, 5.000 s long (244140 samples); the
# constant below stores the integer part of that rate. The precomputed spec
# covers 4.995 s (999 bins × 5 ms), so waveforms aligned to the response
# window are cropped / padded after resampling to exactly
# ``T_neural * audio_fs * dt_ms / 1000`` samples.
NS1_WAV_NATIVE_FS = 48828
NS1_WAV_DIR_NAME = "spikesandwav/SH.En.C"
# (Sphinx "[docs]" source-link marker — extraction residue, not code.)
def download_ns1(dest: Optional[str] = None) -> str:
    """Download all NS1 data assets into ``dest``.

    Sources:

    - **OSF** (https://osf.io/ayw2p/, no account): the dataset README, the
      per-neuron metadata (.mat), and the spike + wav zip (~155 MB total).
    - **DNet GitHub** (https://github.com/monzilur/DNet, master branch): the
      precomputed 5 ms mel-spectrogram tensor ``test_data_5ms.mat``
      (5.2 MB) accompanying Rahman et al. 2019 PLoS Comp Biol. NOT on OSF.

    Idempotent: files that already exist are not fetched again, and the zip is
    only unpacked when the ``spikesandwav`` directory is absent.

    Parameters
    ----------
    dest : str, optional
        Where to put the downloaded files. Defaults to the platformdirs cache
        (overridable via ``$DEEPSTRF_DATA_DIR``).

    Returns
    -------
    str
        Absolute path to the dataset directory.
    """
    # Normalise to an absolute path so the return value honours the documented
    # contract even when the caller passes a relative ``dest``. The on-disk
    # location is unchanged.
    dest_path = os.path.abspath(
        str(default_cache_dir("NS1") if dest is None else dest)
    )
    os.makedirs(dest_path, exist_ok=True)

    # OSF assets: README, per-neuron metadata, spike + wav archive.
    for fname, guid in NS1_OSF_FILES.items():
        target = os.path.join(dest_path, fname)
        if not os.path.exists(target):
            osf_download(guid, target)

    # Unpack spikesandwav.zip unless an unpacked directory is already there.
    spikes_dir = os.path.join(dest_path, "spikesandwav")
    zip_path = os.path.join(dest_path, "spikesandwav.zip")
    if os.path.isfile(zip_path) and not os.path.isdir(spikes_dir):
        unzip(zip_path, dest_path)

    # DNet GitHub asset(s): the precomputed spectrogram tensor.
    for dest_name, repo_path in NS1_DNET_FILES.items():
        target = os.path.join(dest_path, dest_name)
        if not os.path.exists(target):
            github_raw_download(
                NS1_DNET_REPO, repo_path, target, ref=NS1_DNET_REF
            )

    return dest_path
# (Sphinx "[docs]" source-link marker — extraction residue, not code.)
class NS1Dataset(AudioNeuralDataset):
    """PyTorch dataset for the NS1 (Harper et al. 2016, Rahman et al. 2020) data.

    119 multi/single units from primary auditory cortex (A1) of deeply
    anesthetized ferrets, recorded in response to 20 natural sound clips of
    4.995 s each, presented 20 times per neuron. Every neuron heard every
    clip, so the response grid is fully dense (no NaN sentinels).

    Of the 119 units, 73 pass the "single-unit at known depth" filter the
    original authors used (``single_t in {'Yes', 'Maybe'}`` and
    ``depth >= 0``); :meth:`~deepSTRF.datasets.neural_dataset.NeuralDataset.select_pop_by_nrn_attr`
    over ``single_t`` / ``depth_um`` reproduces this subset.

    The spectrogram tensor is precomputed at ``dt = 5 ms`` (``F = 34``
    frequency bands, ``T = 999`` bins); the ``dt_ms`` constructor argument is
    currently validated against this resolution. With ``return_waveform=True``,
    ``stims`` are instead raw mono waveforms ``(1, T_audio)`` at ``audio_fs``
    (aligned to ``T_audio = T_neural * audio_fs * dt_ms / 1000``) — feed them
    through a model's ``wav2spec`` front-end.

    Data are freely available (no account required) and auto-fetched by
    ``NS1Dataset(download=True)``:

    - https://osf.io/ayw2p/ — metadata, raw spike and wav data.
    - https://github.com/monzilur/DNet — precomputed 5 ms mel spectrogram.

    Notes
    -----
    Follows the standard deepSTRF data paradigm (see
    ``docs/_source/md/data_paradigm.md``). NS1-specific metadata:

    - ``stim_meta`` dicts hold ``name`` and ``type``.
    - ``nrn_meta`` dicts hold ``cell_id``, ``area``, ``depth_um``,
      ``noise_ratio``, ``single_n``, ``single_t``, ``n_electrodes`` and
      ``electrode_number``. ``noise_ratio`` is the Sahani-Linden normalised
      noise power (lower = cleaner; NOT an SNR despite the legacy ``.mat``
      field name). ``single_n`` is the single-unit flag from spike-snippet
      clustering (0/1); ``single_t`` is the manual triage label
      ('Yes'/'Maybe'/'No').

    References
    ----------
    Harper et al. (2016). "Network receptive field modeling reveals extensive
    integration and multi-feature selectivity in auditory cortical neurons."
    *PLoS Computational Biology*.

    Rahman et al. (2020). "Simple transformations capture auditory input to
    cortex." *PNAS*.
    """

    def __init__(self, path: Optional[str] = None, dt_ms: float = 5.0,
                 smooth: bool = True, download: bool = False,
                 return_waveform: bool = False, audio_fs: int = 48000):
        """
        Parameters
        ----------
        path : str, optional
            Path to the NS1 data folder containing ``test_data_5ms.mat``,
            ``MetadataSHEnCneurons.mat``, and ``spikesandwav/``. If ``None``,
            defaults to the platformdirs cache (``user_cache_dir('deepSTRF')
            / 'NS1'`` — overridable via ``$DEEPSTRF_DATA_DIR``).
        dt_ms : float, default 5.0
            Time-bin width in ms. Must equal 5.0 — the bundled spectrogram is
            precomputed at this resolution. Other values would require
            re-spectrogramming the wavs (not implemented).
        smooth : bool, default True
            If True, smooth PSTHs in place with a 21 ms Hanning window
            (Hsu, Borst & Theunissen 2004).
        download : bool, default False
            If True and the data assets are missing under ``path``, fetch
            them from their public sources (no account required) — OSF
            (https://osf.io/ayw2p/: metadata + spike data + wavs) and DNet
            GitHub (https://github.com/monzilur/DNet: the precomputed 5 ms
            mel-spectrogram tensor ``test_data_5ms.mat``). Total ~160 MB,
            ~16 s on a fast connection. See :func:`download_ns1`.
        return_waveform : bool, default False
            If True, ``self.stims`` holds raw audio waveforms instead of
            precomputed spectrograms. Each ``self.stims[s]`` is a
            ``(1, T_audio)`` float32 tensor at ``audio_fs`` Hz, downmixed to
            mono, resampled from the native 48 828.125 Hz, and right-cropped /
            zero-padded to exactly ``T_neural * audio_fs * dt_ms / 1000``
            samples so it aligns with the 4.995 s response window. Pair with a
            model that has a ``wav2spec`` front-end (see
            ``deepSTRF.models.wav2spec``).
        audio_fs : int, default 48000
            Sample rate (Hz) for waveform mode. Default 48 kHz gives a clean
            240 samples / 5-ms bin and a Nyquist of 24 kHz — enough to
            preserve the ~22.6 kHz content used in Rahman et al. 2019's
            cochleagram. (Native is 48 828.125 Hz; the small downsample
            keeps an integer sample-per-bin factor.) Ignored when
            ``return_waveform=False``.

        Raises
        ------
        AssertionError
            If ``dt_ms`` is not 5.0, or if the spectrogram file does not hold
            ``NS1_NAT_SOUNDS`` stimuli.
        FileNotFoundError
            If ``test_data_5ms.mat`` or ``MetadataSHEnCneurons.mat`` is
            missing under ``path``, or (waveform mode only) if the raw wav
            directory is missing.
        """
        # Validate before touching the network or disk. The check is raised
        # explicitly (not via ``assert``) so it survives ``python -O``;
        # AssertionError is kept so callers see the same exception type.
        if dt_ms != 5.0:
            raise AssertionError(
                f"NS1 spectrograms are precomputed at 5 ms; got dt_ms={dt_ms}. "
                "Re-binning the spike trains is straightforward but the spectrogram "
                "would need to be recomputed from raw wavs (TODO)."
            )

        if path is None:
            path = str(default_cache_dir("NS1"))
        if download:
            download_ns1(path)
        super().__init__(path, dt_ms)

        self.species = "ferret"
        # Informational only (not enforced): ferret behavioural audiogram spans
        # roughly 200 Hz – 40 kHz. Lets tooling/notebooks display the range and
        # users optionally clamp a wav2spec's frequency limits.
        self.hearing_range_hz = (200.0, 40000.0)
        # Raw-waveform mode flag (gates the base-class grid-lock validation).
        self.return_waveform = bool(return_waveform)

        # Bin geometry shared by the waveform alignment and the spike binning;
        # computed once instead of once per (neuron, stimulus) pair.
        bin_size = int(round(dt_ms))           # bin width in whole ms
        n_bins = NS1_RAW_LEN_MS // bin_size    # = 999 for 5 ms / 4995 ms

        # ----------- 1. load the precomputed spectrograms -----------
        # X_nfht: (S=20, F=34, 1, T=999) at dt=5 ms
        spec_path = os.path.join(path, "test_data_5ms.mat")
        if not os.path.isfile(spec_path):
            raise FileNotFoundError(
                f"NS1 expects 'test_data_5ms.mat' at {spec_path}. Pass\n"
                f"download=True to fetch it from the DNet companion repo\n"
                f"(https://github.com/monzilur/DNet), or place it manually."
            )
        spec_data = sio.loadmat(spec_path)
        X = spec_data["X_nfht"]
        S, F, _, _ = X.shape
        if S != NS1_NAT_SOUNDS:
            # Explicit raise for the same reason as the dt_ms check above.
            raise AssertionError(f"expected {NS1_NAT_SOUNDS} stims, got {S}")
        self.F = int(F)

        self.stim_meta = [
            {"name": f"nat{s + 1:02d}", "type": NS1_TYPE_OVERRIDES.get(s, "unknown")}
            for s in range(S)
        ]

        if return_waveform:
            # ---- raw-waveform mode: load the 20 OSF wavs in stim-index order ----
            from deepSTRF.utils.audio_io import load_resampled_mono_wav

            wav_root = os.path.join(path, NS1_WAV_DIR_NAME)
            if not os.path.isdir(wav_root):
                raise FileNotFoundError(
                    f"NS1 waveform mode expects raw wavs at {wav_root}; pass\n"
                    f"download=True to fetch them from OSF (https://osf.io/ayw2p/)."
                )
            T_neural = n_bins  # 999 at 5 ms
            T_audio = int(round(T_neural * audio_fs * dt_ms / 1000))
            self.audio_fs = int(audio_fs)
            self.stims = [
                load_resampled_mono_wav(
                    os.path.join(wav_root, NS1_WAV_FILENAMES[s]),
                    target_fs=audio_fs,
                    target_length=T_audio,
                )
                for s in range(S)
            ]
        else:
            # ---- spectrogram mode (default): use the precomputed tensor ----
            self.stims = [
                torch.from_numpy(X[s, :, 0, :]).float().unsqueeze(0)  # (1, F, T)
                for s in range(S)
            ]

        # ----------- 2. load per-neuron metadata + spike data -----------
        meta_path = os.path.join(path, "MetadataSHEnCneurons.mat")
        if not os.path.isfile(meta_path):
            # Same exception type loadmat would raise, with an actionable
            # hint that matches the spectrogram-file error above.
            raise FileNotFoundError(
                f"NS1 expects 'MetadataSHEnCneurons.mat' at {meta_path}. Pass\n"
                f"download=True to fetch it from OSF (https://osf.io/ayw2p/),\n"
                f"or place it manually."
            )
        meta = sio.loadmat(meta_path)
        neurons = meta["neuron"][0]

        # responses[s] collects one (R, T) tensor per neuron that was loaded.
        self.responses = [[] for _ in range(S)]
        self.nrn_meta = []
        spikes_root = os.path.join(path, "spikesandwav")

        for neuron in neurons:
            uid = str(neuron[0].item())[:-4]  # drop trailing '.mat' from filename-as-uid
            spike_path = os.path.join(spikes_root, str(neuron["path"].item()))
            try:
                temp = sio.loadmat(spike_path)
            except FileNotFoundError:
                # The OSF release has a few neurons with a "_1" suffix mismatch
                # between the metadata path and the actual file; try the fallback.
                fallback = spike_path[:-4] + "_1.mat"
                try:
                    temp = sio.loadmat(fallback)
                except FileNotFoundError:
                    # Best effort: skip this neuron before any metadata or
                    # responses are recorded for it, so the two stay aligned.
                    print(
                        f"NS1: skipping neuron {uid!r} "
                        f"(spike file not found at {spike_path} or {fallback})"
                    )
                    continue

            self.nrn_meta.append({
                "cell_id": uid,
                "area": "A1",
                "depth_um": int(neuron["depth"].item()),
                "noise_ratio": float(neuron["NoiseRatio"].item()),
                "single_n": int(neuron["singleN"].item()),
                "single_t": str(neuron["singleT"].item()),
                "n_electrodes": int(neuron["NrOfElectrodes"].item()),
                "electrode_number": int(neuron["ElectrodeNumber"].item()),
            })

            for s in range(S):
                repeats_grp = temp["data"]["set"][0, 0]["repeats"][0, s]
                R = int(repeats_grp.shape[1])
                # bin spike times to the requested dt
                spike_matrix = np.zeros((R, n_bins), dtype=np.float32)
                for r in range(R):
                    # Spike times are rounded to integers and used as indices
                    # into a NS1_RAW_LEN_MS-long train; times outside the
                    # stimulus window are discarded.
                    spiketimes = np.round(repeats_grp[0]["t"][r][0]).astype(np.int64)
                    spiketimes = spiketimes[(spiketimes >= 0) & (spiketimes < NS1_RAW_LEN_MS)]
                    # Binary train: spikes that round to the same index are
                    # marked once.
                    one_hot = np.zeros(NS1_RAW_LEN_MS, dtype=np.float32)
                    one_hot[spiketimes] = 1.0
                    # Sum consecutive groups of ``bin_size`` samples into bins.
                    spike_matrix[r] = one_hot.reshape(-1, bin_size).sum(axis=1)
                self.responses[s].append(torch.from_numpy(spike_matrix))

        self.N_neurons = len(self.nrn_meta)

        # smooth PSTHs with a 21 ms Hanning window (Hsu / Borst / Theunissen 2004)
        if smooth:
            self.smooth_responses(window_ms=21.0)

        # self.nrn_masks is a derived @property on the base class — no need
        # to populate it here
        self.validate()