import os
import re
from typing import Optional
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
import torchaudio
from deepSTRF.datasets.audio.audio_dataset import AudioNeuralDataset
from deepSTRF.datasets.audio._nat4_native import (
epoch_names_matching,
extract_epoch,
load_per_site_recording,
load_pop_recording,
normalize_log1p_minmax_inplace,
normalize_minmax_inplace,
)
from deepSTRF.utils.audio_io import load_wav
from deepSTRF.utils.data_download import (
default_cache_dir,
unzip,
zenodo_download,
)
# NAT4 Zenodo record (https://doi.org/10.5281/zenodo.8044773), public.
# Every file fetched by ``download_nat4`` (population .tgz, per-cell CSV,
# per-site .zip, optional wav.zip) is requested from this record id.
NAT4_ZENODO_RECORD: int = 8044773
# NEMS cell ids in NAT4 follow the convention <site>-<electrode>-<unit>, e.g.
# 'ARM029a-01-1'. The site itself is <3-letter animal code><digits><session>,
# e.g. 'ARM029a' (animal 'ARM', recording 029, session 'a'). The session
# letter is optional in the pattern below.
_CELL_ID_RE = re.compile(r"^([A-Za-z]{3})\d+[a-z]?-(\d+)-(\d+)$")


def _parse_nat4_cell_id(cell_id: str) -> dict:
    """Best-effort decomposition of a NAT4 cell id.

    Parameters
    ----------
    cell_id : str
        Raw NEMS cell id, e.g. ``'ARM029a-01-1'``. Non-string inputs are
        tolerated and produce an all-``None`` result.

    Returns
    -------
    dict
        Keys ``site``, ``animal``, ``electrode`` (int) and
        ``unit_in_electrode`` (int). Any field whose source is missing or
        unparseable is set to ``None``. ``site`` is the text before the first
        ``'-'``; it is reported even when the remainder does not follow the
        standard scheme, but an empty leading segment counts as missing.
    """
    out = {"site": None, "animal": None, "electrode": None, "unit_in_electrode": None}
    if not isinstance(cell_id, str):
        return out
    site, sep, _ = cell_id.partition("-")
    if not sep or not site:
        # No '-' at all, or an id starting with '-': there is no site to report.
        return out
    out["site"] = site
    # fullmatch rather than match: a bare '$' would also accept a trailing
    # newline after the unit number.
    m = _CELL_ID_RE.fullmatch(cell_id)
    if m is None:
        # The site looked sensible (first segment), but electrode/unit didn't parse.
        return out
    out["animal"] = m.group(1)
    out["electrode"] = int(m.group(2))
    out["unit_in_electrode"] = int(m.group(3))
    return out
[docs]
def download_nat4(area: str, dest: Optional[str] = None, wav: bool = False) -> str:
    """Fetch the NAT4 release for one cortical area from Zenodo.

    The following end up under ``dest``:

    * ``<area>_NAT4_ozgf.fs100.ch18.tgz`` — the population recording;
    * ``<area>_pred_correlation.csv`` — the per-cell auditory table;
    * ``<area>_single_sites.zip`` — unpacked so that the per-site .tgzs sit
      in ``<dest>/<area>_single_sites/``, where the loader expects them;
    * optionally ``wav.zip`` — unpacked into ``<dest>/wav/``.

    Anything already on disk is left alone (an archive is not even fetched
    when its unpacked directory exists), so repeated calls are cheap.

    Parameters
    ----------
    area : {'A1', 'PEG'}
        Cortical area whose files are fetched.
    dest : str, optional
        Target directory, created if absent. Defaults to
        ``default_cache_dir('NAT4')`` (overridable via ``$DEEPSTRF_DATA_DIR``).
    wav : bool, default False
        Also fetch and unpack ``wav.zip`` (the 593 source waveforms,
        44.1 kHz / 1 s each) for the raw-waveform branch
        (``NAT4Dataset(return_waveform=True)``). The spectrogram-mode loader
        does not need it.

    Returns
    -------
    str
        The directory the data were placed in.
    """
    assert area in ("A1", "PEG"), f"area must be 'A1' or 'PEG' (got {area!r})"
    root = str(dest if dest is not None else default_cache_dir("NAT4"))
    os.makedirs(root, exist_ok=True)

    def _fetch(filename):
        # Download ``filename`` from the record unless it already exists here.
        target = os.path.join(root, filename)
        if not os.path.exists(target):
            zenodo_download(NAT4_ZENODO_RECORD, filename, target)
        return target

    def _fetch_and_unpack(zip_name, unpacked_dir):
        # The archive only matters while its unpacked directory is missing.
        if os.path.isdir(os.path.join(root, unpacked_dir)):
            return
        unzip(_fetch(zip_name), root)

    _fetch(f"{area}_NAT4_ozgf.fs100.ch18.tgz")
    _fetch(f"{area}_pred_correlation.csv")
    _fetch_and_unpack(f"{area}_single_sites.zip", f"{area}_single_sites")
    if wav:
        _fetch_and_unpack("wav.zip", "wav")
    return root
[docs]
class NAT4Dataset(AudioNeuralDataset):
    """PyTorch dataset for NAT4 (Pennington & David, 2022 / 2023).

    Two cortical areas: ``A1`` (primary, 849 cells of which 777 auditory)
    and ``PEG`` (secondary, 398 of which 339 auditory). Pass ``area=...``;
    one instance covers one area. To pool both, instantiate twice and
    ``concat_neural_datasets([a1, peg])``.

    Stimuli are 18 high-rep clips (``val``, 20 trials) plus the low-rep
    clips (``est``, 1 trial): 575 est / 593 in total, matching the 593
    source .wav files (NOTE(review): a count of 577 est / 595 total has also
    been quoted for NAT4 — confirm against the release). Each trial is a
    1.5 s window. The default time bin is ``dt_ms = 10`` (the population
    recording is precomputed at fs=100 with ``val`` pre-averaged over 20
    reps; per-site spike trains are at fs=1000 and downsampled to 10 ms by
    summing). The spectrogram has ``F = 18`` ozgf bands and ``T = 150``
    frames per stim.

    The loader reads the published NAT4 archive directly with native CSV /
    JSON / HDF5 parsers — no NEMS0 dependency. Data are freely available at
    https://doi.org/10.5281/zenodo.8044773 (no account required) and
    auto-fetched by ``NAT4Dataset(download=True)``.

    Notes
    -----
    Follows the standard deepSTRF data paradigm (see
    ``docs/_source/md/data_paradigm.md``). NAT4-specific metadata:

    - ``stim_meta`` dicts hold ``name`` and ``subset`` (``'est'`` or
      ``'val'``); the ``subset='all'|'est'|'val'`` constructor argument
      filters this list at load time.
    - ``nrn_meta`` dicts hold ``cell_id`` (raw NEMS id, e.g.
      ``'ARM029a-01-1'``), ``area``, ``auditory`` (flag from the dataset's
      ``<area>_pred_correlation.csv``), and the parsed components ``site``
      (e.g. ``'ARM029a'``), ``animal`` (3-char site prefix, e.g. ``'ARM'``),
      ``electrode`` (int) and ``unit_in_electrode`` (int). Components default
      to ``None`` for any cell whose id does not match the standard
      ``<site>-<elec>-<unit>`` scheme.

    ``est`` responses have shape ``(R=1, T=150)`` and ``val`` responses
    ``(R=20, T=150)``; the ``(1, 1)`` NaN sentinel marks ``(stim, neuron)``
    pairs where the cell was not recorded for that stim (or its trace is
    entirely NaN).

    ``est`` responses come from the population recording after its response
    min-max normalisation, whereas ``val`` responses are summed per-site
    spike trains with no normalisation applied in this class.
    NOTE(review): the two subsets may therefore differ in scale — confirm
    this is intended.

    With ``return_waveform=True``, ``stims`` are instead the raw mono
    waveforms ``(1, T_audio = T * hop)`` at ``audio_fs`` (hop=441 at
    44.1 kHz / 10 ms) — feed them through a model's ``wav2spec`` slot.

    References
    ----------
    Pennington & David (2022, preprint). "Can deep learning provide a
    generalizable model for dynamic sound encoding in auditory cortex?"
    Pennington & David (2023). "A convolutional neural network provides a
    generalizable model of natural sound coding by neural populations in
    auditory cortex." *PLOS Computational Biology*.
    """

    def __init__(self, path: Optional[str] = None, area: str = 'A1',
                 dt_ms: float = 10.0, smooth: bool = False,
                 download: bool = False, subset: str = 'all',
                 return_waveform: bool = False, audio_fs: int = 44100):
        """
        Parameters
        ----------
        path : str, optional
            Path to the NAT4 data folder. Defaults to the platformdirs cache.
        area : {'A1', 'PEG'}
            Cortical area.
        dt_ms : float, default 10.0
            Time-bin width in ms. Currently must equal 10.0; the population
            recording is precomputed at fs=100 and the per-site downsampling
            assumes a fixed 10x ratio from fs=1000.
        smooth : bool, default False
            If True, smooth PSTHs with a 21 ms Hanning window. Off by default
            here because NAT4 trials are typically used as-is for STRF
            fitting (unlike CRCNS-AA where smoothing is the published norm).
        download : bool, default False
            If True and the data is missing under ``path``, fetch it from
            Zenodo (record 8044773).
        subset : {'all', 'est', 'val'}, default 'all'
            If 'est' or 'val', only that stimulus subset is loaded —
            ``stim_meta`` / ``stims`` / ``responses`` shrink accordingly,
            and the (more expensive) per-site spike-time pass is skipped
            entirely under ``subset='est'``. The two subsets correspond to
            Pennington & David's published estimation set (575 stims, R=1,
            from the population recording) and validation set (18 stims,
            R=20, from the per-site recordings) respectively. Note that 33
            of the 849 A1 cells have no val data — under ``subset='val'``
            their responses are full NaN sentinels; pair the constructor
            arg with ``ds.select_pop_by_stim_attr('subset', 'val')`` to
            drop them automatically (idiomatic alternative:
            ``ds.select_stims_by_attr('subset', 'val')`` — which leaves the
            full stim bank loaded but applies the bidirectional rule, so
            cells without val data are hidden from ``__getitem__``).
        return_waveform : bool, default False
            If True, each stimulus is the raw mono waveform ``(1, T_audio)`` at
            ``audio_fs`` Hz instead of the precomputed ozgf cochleagram. The
            593 source .wav files (44.1 kHz, 1 s of sound) are read from
            ``<path>/wav/`` and embedded in the 1.5 s trial window at the
            recording's pre-silence offset, then grid-locked to
            ``T_audio = T_neural * hop`` (``hop = audio_fs * dt_ms / 1000``).
            Feed it through a model's ``wav2spec`` slot (e.g.
            ``CausalGammatone`` to reproduce the native ozgf front-end). Pass
            ``download=True`` to also fetch ``wav.zip`` from Zenodo.
        audio_fs : int, default 44100
            Audio sample rate for ``return_waveform=True``. The default 44.1 kHz
            is the native rate of the NAT4 wavs and gives an exact integer
            ``hop = 441`` at ``dt_ms = 10`` (no resampling). Choose any rate
            making ``audio_fs * dt_ms / 1000`` an integer. Ignored unless
            ``return_waveform=True``.

        Raises
        ------
        AssertionError
            On an unknown ``area`` / ``subset`` or ``dt_ms != 10``.
        FileNotFoundError
            When the population recording, the wav folder (waveform mode) or
            the per-site recordings (val pass) are missing.
        ValueError
            When a per-site recording's 1 ms epoch length is not a multiple
            of 10 and so cannot be binned to 10 ms.
        """
        assert area in ("A1", "PEG"), \
            f"Unexpected area {area!r}, choose between 'A1' or 'PEG'"
        assert subset in ("all", "est", "val"), \
            f"Unexpected subset {subset!r}, choose between 'all', 'est', 'val'"
        assert dt_ms == 10.0, (
            f"NAT4 spectrograms are precomputed at dt=10 ms; got dt_ms={dt_ms}. "
            f"Re-rasterizing the responses is straightforward but the "
            f"spectrogram .tgz would also need re-binning (TODO)."
        )
        if path is None:
            path = str(default_cache_dir("NAT4"))
        if download:
            download_nat4(area, path, wav=return_waveform)
        super().__init__(path, dt_ms)
        self.area = area
        self.species = 'ferret'
        self.F = 18
        self.hearing_range_hz = (200.0, 40000.0)  # ferret (informational)
        # Raw-waveform input mode (opt-in). The native stim is the precomputed
        # ozgf cochleagram; here we instead hand out the source waveform and let
        # a model's wav2spec slot build the spectrogram.
        self.return_waveform = bool(return_waveform)
        self.audio_fs = int(audio_fs) if return_waveform else None

        # ========= LOAD THE POPULATION RECORDING (used for est, R=1) ===========
        # Accept either the .tgz archive OR an already-extracted directory.
        tgz_path = os.path.join(path, f'{area}_NAT4_ozgf.fs100.ch18.tgz')
        dir_path = os.path.join(path, f'{area}_NAT4_ozgf.fs100.ch18')
        if os.path.exists(tgz_path):
            datafile = tgz_path
        elif os.path.isdir(dir_path):
            datafile = dir_path
        else:
            raise FileNotFoundError(
                f"NAT4 expects either {tgz_path} or {dir_path}/. "
                f"Pass download=True to fetch the .tgz from Zenodo, or "
                f"place the data manually."
            )
        rec = load_pop_recording(datafile)
        # log1p + minmax for the spectrogram, plain minmax for the response —
        # matches the preprocessing baked into the published Pennington &
        # David models. Both operations are global (single (min, max) per
        # signal across all of T x K).
        normalize_log1p_minmax_inplace(rec)
        normalize_minmax_inplace(rec)
        cells = rec.chans
        val_sounds = epoch_names_matching(rec.epochs, "^STIM_00cat")
        est_sounds = epoch_names_matching(rec.epochs, "^STIM_cat")

        # In waveform mode we read source .wav files from <path>/wav/ and inset
        # each at the trial's pre-stimulus-silence offset (NAT4 trials are a
        # 1.5 s window = pre-silence + 1 s sound + post-silence; the wav holds
        # only the sound). Derive the offset from the epoch table once.
        if self.return_waveform:
            self._wav_dir = os.path.join(path, "wav")
            if not os.path.isdir(self._wav_dir):
                raise FileNotFoundError(
                    f"NAT4 waveform mode needs the source wavs at {self._wav_dir!r}. "
                    f"Pass download=True to fetch wav.zip from Zenodo, or unpack it "
                    f"manually."
                )
            # Case-insensitive index: a few epoch names disagree with the on-disk
            # filename only in letter case (e.g. 'True' vs 'true' inside the name).
            self._wav_index = {fn.lower(): fn for fn in os.listdir(self._wav_dir)
                               if fn.lower().endswith(".wav")}
            sil = rec.epochs[rec.epochs['name'] == 'PreStimSilence']
            prestim_s = float((sil['end'] - sil['start']).iloc[0]) if len(sil) else 0.0
            self._pre_samples = int(round(prestim_s * self.audio_fs))

        # ========= STIM SPECTROGRAMS / WAVEFORMS (est first, then val) ===========
        load_est = subset in ("all", "est")
        load_val = subset in ("all", "val")
        self.stim_meta = []
        self.stims = []

        def _append_stim(name, subset_label):
            spec = extract_epoch(rec, 'stim', name)  # (R=1, F, T)
            if self.return_waveform:
                self.stims.append(self._load_stim_waveform(name, spec.shape[-1]))
            else:
                self.stims.append(torch.from_numpy(spec[0]).unsqueeze(0).float())  # (1, F, T)
            self.stim_meta.append({'name': name, 'subset': subset_label})

        if load_est:
            for est_sound in est_sounds:
                _append_stim(est_sound, 'est')
        if load_val:
            for val_sound in val_sounds:
                _append_stim(val_sound, 'val')

        # ========= NEURON METADATA (auditory flag + parsed cell_id) ===========
        list_neurons = pd.read_csv(os.path.join(path, f'{area}_pred_correlation.csv'))
        cell_to_aud = dict(zip(list_neurons['cellid'], list_neurons['sig_auditory']))
        self.nrn_meta = [
            {
                'cell_id': cell,
                'area': area,
                'auditory': bool(cell_to_aud.get(cell, False)),
                **_parse_nat4_cell_id(cell),
            }
            for cell in cells
        ]
        self.N_neurons = len(self.nrn_meta)

        # ========= EST RESPONSES (1 trial per stim, full population) ===========
        # Cells that didn't see a given est stim get a (1, 1) NaN sentinel
        # rather than a (1, T) trace of NaNs / zeros.
        est_responses_per_stim = []  # S_est lists of N (1, T) tensors / sentinels
        if load_est:
            for est_sound in est_sounds:
                arr = extract_epoch(rec, 'resp', est_sound)  # (R=1, N, T)
                stim_resps = []
                for n in range(self.N_neurons):
                    trace = arr[:, n, :]  # (1, T)
                    if np.isnan(trace).all():
                        stim_resps.append(self._nan_sentinel())
                    else:
                        stim_resps.append(torch.from_numpy(trace).float())
                est_responses_per_stim.append(stim_resps)
        del rec

        # ========= VAL RESPONSES (20 trials per stim, per-site stitching) ===========
        # The pop rec averages val over 20 reps and only keeps R=1; for trial-
        # resolved data we go to the per-site recordings at fs=1000.
        val_responses_per_stim = (
            self._load_val_responses(path, area, val_sounds, cells) if load_val else []
        )

        # est first, then val — matches the order of self.stims / self.stim_meta
        self.responses = est_responses_per_stim + val_responses_per_stim
        if smooth:
            self.smooth_responses(window_ms=21.0)
        self.validate()

    @staticmethod
    def _nan_sentinel() -> torch.Tensor:
        """Return the ``(1, 1)`` NaN placeholder for an unrecorded (stim, neuron) pair."""
        return torch.full((1, 1), float('nan'))

    @staticmethod
    def _load_val_responses(path: str, area: str, val_sounds, cells) -> list:
        """Build trial-resolved val responses from the per-site recordings.

        Each per-site recording under ``<path>/<area>_single_sites/`` is read
        at fs=1000, cut into the ``val_sounds`` epochs, binned to 10 ms by
        summing, stitched across sites and reindexed onto ``cells``.

        Returns
        -------
        list
            One entry per ``val_sounds`` item (same order), each a list aligned
            with ``cells`` holding an ``(R, T)`` tensor, or a ``(1, 1)`` NaN
            sentinel when the cell is absent from every per-site recording or
            its trace is entirely NaN.

        Raises
        ------
        FileNotFoundError
            If the per-site directory is missing or holds no recordings.
        ValueError
            If an epoch's 1 ms length is not a multiple of 10.
        """
        if len(val_sounds) == 0:
            # Nothing to extract; avoids stacking an empty list below.
            return []
        sites_dir = os.path.join(path, f'{area}_single_sites')
        if not os.path.isdir(sites_dir):
            raise FileNotFoundError(
                f"NAT4 val responses need the per-site recordings under "
                f"{sites_dir}/. Pass download=True to fetch them from Zenodo, "
                f"or use subset='est' to skip the val pass."
            )
        # Skip hidden entries (e.g. macOS '.DS_Store'); they are not recordings.
        site_files = sorted(fn for fn in os.listdir(sites_dir) if not fn.startswith('.'))
        if not site_files:
            raise FileNotFoundError(
                f"No NAT4 per-site recordings found in {sites_dir}/."
            )

        per_site_val = []        # list of (S_val, R, N_subpop, T) tensors
        val_cells_in_order = []  # cell-id order across stitched per-site recs
        for filename in tqdm(site_files, desc=f'NAT4 {area} val sites'):
            site_rec = load_per_site_recording(os.path.join(sites_dir, filename))
            val_cells_in_order += list(site_rec.chans)
            site_responses = []
            for val_sound in val_sounds:
                # extract at fs=1000 -> (R=20, N_subpop, T_ms=1500)
                arr = extract_epoch(site_rec, 'resp', val_sound)
                R, N_subpop, T_ms = arr.shape
                if T_ms % 10:
                    raise ValueError(
                        f"NAT4 per-site recording {filename!r} has {T_ms} 1-ms bins "
                        f"for {val_sound!r}, not a multiple of 10; cannot "
                        f"downsample to 10 ms bins."
                    )
                # downsample 1 ms -> 10 ms by summing
                arr = arr.reshape(R, N_subpop, T_ms // 10, 10).sum(axis=-1)
                site_responses.append(torch.from_numpy(arr).float())
            per_site_val.append(torch.stack(site_responses))  # (S_val, R, N_subpop, T)
            del site_rec

        # (S_val, R, N_total_in_val_order, T) -> (S_val, N, R, T)
        val_full = torch.cat(per_site_val, dim=2).permute(0, 2, 1, 3)

        # Per-site stitching order differs from the population's `cells` order,
        # and either side may contain cells the other lacks. Reindex onto
        # `cells`; anything missing (or all-NaN) gets the NaN sentinel, the
        # same rule the est path applies.
        index_map = {u: i for i, u in enumerate(val_cells_in_order)}
        responses = []
        for stim_block in val_full:  # (N_total, R, T) per val stim
            stim_resps = []
            for cell in cells:
                i = index_map.get(cell)
                trace = None if i is None else stim_block[i]  # (R, T)
                if trace is None or bool(torch.isnan(trace).all()):
                    stim_resps.append(torch.full((1, 1), float('nan')))
                else:
                    stim_resps.append(trace)
            responses.append(stim_resps)
        return responses

    def _load_stim_waveform(self, epoch_name: str, T_neural: int) -> torch.Tensor:
        """Reconstruct a stim's ``(1, T_audio)`` waveform from its source .wav.

        The .wav holds only the 1 s sound; we zero-pad it to the trial's
        pre-silence offset so the audio lines up with the spectrogram frames
        (= response bins), then crop / pad to exactly ``T_neural * hop`` samples
        (grid lock). The epoch name is ``STIM_<wavfile>``; the file is looked
        up case-insensitively, downmixed to mono and resampled to
        ``self.audio_fs`` when needed.

        Raises
        ------
        FileNotFoundError
            If no source wav matches ``epoch_name``.
        """
        fname = epoch_name[len("STIM_"):] if epoch_name.startswith("STIM_") else epoch_name
        resolved = self._wav_index.get(fname.lower())
        if resolved is None:
            raise FileNotFoundError(
                f"NAT4 waveform mode: no source wav for epoch {epoch_name!r} in "
                f"{self._wav_dir!r}. Pass download=True to fetch wav.zip from Zenodo."
            )
        w, sr = load_wav(os.path.join(self._wav_dir, resolved))  # (C, T)
        if w.shape[0] > 1:
            w = w.mean(dim=0, keepdim=True)  # downmix to mono
        if sr != self.audio_fs:
            w = torchaudio.functional.resample(w, sr, self.audio_fs)
        T_audio = T_neural * self.hop
        full = torch.zeros(1, T_audio)  # float32, contiguous
        # Whatever does not fit after the pre-silence offset is cropped.
        seg = w[0, : max(0, T_audio - self._pre_samples)]
        full[0, self._pre_samples: self._pre_samples + seg.shape[0]] = seg
        return full