import os
import pandas as pd
import csv
import torch
import torchaudio
from typing import Optional
from deepSTRF.datasets.audio.audio_dataset import AudioNeuralDataset
from deepSTRF.datasets.audio._crcns_aa_loaders import load_spike_file, parse_cell_name
from deepSTRF.utils.audio_io import load_wav
from deepSTRF.utils.data_download import (
crcns_download,
default_cache_dir,
untar,
)
# CRCNS-AA2 ships as 3 tar.gz archives on the NERSC mirror, all wrapping
# their content in ``crcns/aa2/`` (strip 2 components on extract). After
# extraction the layout is the flat one CRCNSAA2Dataset expects:
#
#     <dest>/all_cells/<cell>/<stim_type>/{spike*, stim*}
#     <dest>/all_stims/*.wav
#     <dest>/{cell_regions.csv, cell_stim_classes.csv, stim_data.csv,
#             crcns-aa2-README.txt}
#
# Each entry below is handed unchanged to ``crcns_download``; its basename is
# the name under which the archive is stored locally before extraction.
AA2_NERSC_FILES = (
    "aa-2/crcns-aa2-docs.tar.gz",
    "aa-2/crcns-aa2-all_cells.tar.gz",
    "aa-2/crcns-aa2-all_stims.tar.gz",
)
[docs]
def download_aa2(dest: Optional[str] = None,
                 username: Optional[str] = None,
                 password: Optional[str] = None) -> str:
    """Fetch and unpack the CRCNS-AA2 archives from the NERSC mirror.

    The call is idempotent. For every archive in ``AA2_NERSC_FILES`` an
    "anchor" (``stim_data.csv``, ``all_cells/`` or ``all_stims/``) is looked
    up under ``dest``: when the anchor exists the archive is considered
    already extracted and nothing is done for it. Otherwise the archive is
    downloaded (unless the ``.tar.gz`` is already on disk) and extracted with
    the two leading path components stripped.

    Parameters
    ----------
    dest : str, optional
        Target directory, created if missing. Defaults to
        ``default_cache_dir('AA2')`` (overridable via ``$DEEPSTRF_DATA_DIR``).
    username, password : str, optional
        CRCNS credentials, forwarded to ``crcns_download``. Default to
        ``$CRCNS_USERNAME`` / ``$CRCNS_PASSWORD``.

    Returns
    -------
    str
        The directory the data lives in.
    """
    if dest is None:
        dest = default_cache_dir("AA2")
    target_dir = str(dest)
    os.makedirs(target_dir, exist_ok=True)

    # One marker per archive: its presence means "already extracted".
    anchor_by_archive = {
        "crcns-aa2-docs.tar.gz": "stim_data.csv",
        "crcns-aa2-all_cells.tar.gz": "all_cells",
        "crcns-aa2-all_stims.tar.gz": "all_stims",
    }

    for remote_path in AA2_NERSC_FILES:
        archive_name = os.path.basename(remote_path)
        marker = os.path.join(target_dir, anchor_by_archive[archive_name])
        if os.path.exists(marker):
            # Extracted on a previous run: skip download and extraction.
            continue

        local_archive = os.path.join(target_dir, archive_name)
        already_downloaded = os.path.exists(local_archive)
        if not already_downloaded:
            crcns_download(remote_path, local_archive,
                           username=username, password=password)
        untar(local_archive, target_dir, strip_components=2)

    return target_dir
# TODO (misc.):
# - some PSTHs have very high peaks (> 20) on some stims (songrips) --> double check
# - make concatenable to AA1 Dataset ?
[docs]
def get_animals_ids(file_path):
    """Return the unique animal identifiers listed in 'cell_stim_classes.csv'.

    An identifier is the part of a first-column entry that precedes its first
    underscore. Identifiers are returned as a list, in order of first
    appearance.
    """
    table = pd.read_csv(file_path)
    cell_names = table.iloc[:, 0]
    # Only the prefix is needed, so a single split per name is enough.
    animal_prefixes = cell_names.str.split('_', n=1).str[0]
    return list(animal_prefixes.unique())
[docs]
def get_stims_ids_from_csv(file_path):
    """Group the .wav names of 'stim_data.csv' by stimulus category.

    The first column of each row is taken as the .wav file name and the last
    column (stripped, lower-cased) as its category. Rows with fewer than two
    columns are ignored; rows whose category is not one of the known ones are
    reported on stdout and skipped.

    Returns a dictionary with the categories as keys and lists of .wav file
    names as values.
    """
    known_categories = ("songrip", "flatrip", "conspecific", "unknown", "bengalese")
    stim_dict = {name: [] for name in known_categories}

    with open(file_path, 'r') as handle:
        for fields in csv.reader(handle):
            if len(fields) < 2:
                # not enough columns to hold both a name and a category
                continue
            category = fields[-1].strip().lower()
            bucket = stim_dict.get(category)
            if bucket is None:
                print(f"Unknown category '{category}' found in the file. Skipping...")
                continue
            bucket.append(fields[0].strip())

    return stim_dict
[docs]
def load_stim_data_csv(file_path):
    """Parse CRCNS-AA2 ``stim_data.csv`` into a ``{wav_filename: {...}}`` dict.

    The first four columns of a row are read as wav name, sample rate, bit
    depth and number of samples; rows with fewer than five columns are
    skipped. Each value of the returned dict is
    ``{"sample_rate": float, "bit_depth": int, "n_samples": int,
    "duration_s": float}``, where ``duration_s`` is ``n_samples /
    sample_rate`` (NaN when the sample rate is not strictly positive).

    Every row is returned regardless of its class column, so stims labelled
    e.g. "unknown" or "bengalese" are included as well.
    """
    stim_info = {}
    with open(file_path, 'r') as handle:
        for fields in csv.reader(handle):
            if len(fields) < 5:
                continue
            wav_name, rate_txt, depth_txt, count_txt = (
                fields[col].strip() for col in range(4)
            )
            rate = float(rate_txt)
            count = int(count_txt)
            depth = int(depth_txt)
            if rate > 0:
                duration = count / rate
            else:
                duration = float("nan")
            stim_info[wav_name] = {
                "sample_rate": rate,
                "bit_depth": depth,
                "n_samples": count,
                "duration_s": duration,
            }
    return stim_info
[docs]
def get_area_cells(file_path):
    """
    Group the cells of 'cell_regions.csv' by recording area.

    Each row contributes its first column (cell name, e.g. 'pupu2122_2_A') to
    the list of the area named in its last column; rows with fewer than two
    columns are ignored. The returned dictionary always has the keys
    'L', 'L1', 'L2a', 'L2b', 'L3', 'mld', 'OV', 'CM' and 'None'.

    Raises KeyError if a row names an area that is not one of those keys.
    """
    area_labels = ('L', 'L1', 'L2a', 'L2b', 'L3', 'mld', 'OV', 'CM', 'None')
    cells_by_area = {label: [] for label in area_labels}

    with open(file_path, 'r') as handle:
        for fields in csv.reader(handle):
            if len(fields) < 2:
                continue
            area = fields[-1].strip()
            cell_name = fields[0].strip()
            cells_by_area[area].append(cell_name)

    return cells_by_area
[docs]
def get_stim_ids_from_folders(cells_path, verbose=False):
    """
    Collect, per stimulus type, the wav names referenced under 'all_cells/'.

    ``cells_path`` is the 'all_cells/' directory. For each of the three main
    stimulus types ('conspecific', 'songrip', 'flatrip') every cell folder is
    visited in sorted order; inside ``<cell>/<stim_type>/`` the first line of
    each file whose name contains 'stim' (minus its last character, i.e. the
    line break) is taken as a wav name.

    Returns a dictionary with the stimulus types as keys and, for each, the
    list of unique wav names in order of first appearance. With
    ``verbose=True`` a message is printed for every cell lacking a type.
    """
    wavs_by_type = {
        "songrip": [],
        "flatrip": [],
        "conspecific": [],
    }
    for stim_type, known_wavs in wavs_by_type.items():
        for cell in sorted(os.listdir(cells_path)):
            cell_dir = os.path.join(cells_path, cell)
            if stim_type not in os.listdir(cell_dir):
                if verbose:
                    print(f"no {stim_type} stim for cell {cell}, skipping...")
                continue

            type_dir = os.path.join(cell_dir, stim_type)
            for entry in sorted(os.listdir(type_dir)):
                if 'stim' not in entry:
                    continue
                with open(os.path.join(type_dir, entry)) as handle:
                    first_line = handle.readlines()[0]
                # drop the final character (the line break) to get the name
                wav_name = first_line[:-1]
                if wav_name not in known_wavs:
                    known_wavs.append(wav_name)
    return wavs_by_type
[docs]
class CRCNSAA2Dataset(AudioNeuralDataset):
    """PyTorch dataset for the CRCNS-AA2 recordings (OV, MLd, Field L, CM).

    494 extracellular, spike-sorted single units of male zebra finches,
    identified in OV, MLd, Field L, L1, L2a, L2b, L3 (and some with
    unidentified area, ``None``). Three stimulus classes — conspecific songs
    (72 stims), flat ripples (20) and song ripples (25) — each presented
    10-20 times, with low trial-to-trial variability. Almost all cells saw
    conspecific and songrip stimuli; about half saw flatrip. Population
    fitting-compatible. Data are available at
    https://crcns.org/data-sets/aa/aa-2/about (free CRCNS account).

    Notes
    -----
    Follows the standard deepSTRF data paradigm (see
    ``docs/_source/md/data_paradigm.md``). AA2-specific metadata:

    - ``stims`` are mel-spectrograms ``(1, F, T_s)`` (or raw waveforms when
      ``return_waveform=True``).
    - ``stim_meta`` dicts hold ``name``, ``type``, ``sample_rate``,
      ``n_samples`` and ``duration_s`` (the last three read from
      ``stim_data.csv`` in the dataset folder).
    - ``nrn_meta`` dicts hold ``cell_id``, ``animal_id``, ``area``,
      ``cell_seq`` and ``rig`` (see
      :class:`~deepSTRF.datasets.audio.crcns_aa1.CRCNSAA1Dataset`
      for the cell-name format; ``rig`` is often ``None`` in AA2).
    - A neuron without data for a given stimulus gets a ``(1, 1)`` NaN tensor
      as its response to that stimulus.

    References
    ----------
    Gill et al. (2006). "Sound representation methods for spectro-temporal
    receptive field estimation."
    Amin et al. (2010). "Role of the Zebra Finch Auditory Thalamus in
    Generating Complex Representations for Natural Sounds."
    """

    def __init__(self, path: Optional[str] = None,
                 areas=('Field_L', 'mld', 'OV', 'CM', 'None'),
                 stimuli=('conspecific', 'flatrip', 'songrip'),
                 animals='all', dt_ms=1, smooth=True, n_mels=32,
                 compression='cubic',
                 window_ms: float = 10.0,
                 return_waveform: bool = False, audio_fs: int = 32000,
                 download: bool = False,
                 username: Optional[str] = None,
                 password: Optional[str] = None):
        """
        Initializes the AA2 Dataset.

        Parameters
        ----------
        path : str, optional
            Path to the AA2 data folder. Defaults to the platformdirs cache.
        areas : tuple of str
            Recording sites of interest: 'Field_L', 'L1', 'L2a', 'L2b',
            'L3', 'mld', 'OV', 'CM', or 'None'. 'Field_L' (or 'L') selects
            'L', 'L1', 'L2a', 'L2b' and 'L3' together. The string 'all' is
            shorthand for ``('Field_L', 'OV', 'CM', 'mld', 'None')``.
        stimuli : tuple of str
            Stimulus types of interest: 'conspecific', 'flatrip', 'songrip'.
            Cells that have none of the requested types are dropped.
        animals : 'all' or collection of str
            Animal ids to keep (the part of the cell name before the first
            underscore). 'all' keeps every animal listed in
            ``cell_stim_classes.csv``.
        dt_ms : float
            Time step size in ms.
        smooth : bool, default True
            If True, smooth the PSTHs with a 21 ms Hanning window
            (Hsu / Borst / Theunissen 2004).
        n_mels : int
            Number of mel frequency bands.
        compression : str
            Spectrogram compression ('cubic', 'log1p', 'none'). Ignored when
            ``return_waveform=True``.
        window_ms : float, default 10.0
            FFT analysis-window length in ms. ``n_fft`` is computed as
            ``round(window_ms * 1e-3 * sample_rate)`` and is **decoupled
            from ``hop_length``** so phonemic detail is preserved at any
            ``dt_ms``. Earlier versions of this dataset hardcoded
            ``n_fft = 10 * hop_length``, which gave a benign 10 ms FFT
            window at ``dt_ms=1`` but a 500 ms window at ``dt_ms=50``.
            Default ``window_ms=10.0`` preserves bit-identical behaviour
            at ``dt_ms=1`` while fixing the scaling bug at coarser bins.
            Ignored when ``return_waveform=True``.
        return_waveform : bool, default False
            If True, ``self.stims[s]`` holds the raw audio waveform
            ``(1, T_audio)`` at ``audio_fs`` Hz (grid-locked to ``T_audio =
            T_neural * hop``) instead of the in-loader mel spectrogram. Pair
            with a model whose ``wav2spec`` slot is a waveform front-end;
            responses are unchanged.
        audio_fs : int, default 32000
            Sample rate for waveform mode. Default 32 kHz is the native rate of
            the AA2 wavs (no resampling); other values resample and must keep
            ``audio_fs * dt_ms / 1000`` an integer. Ignored unless
            ``return_waveform=True``.
        download : bool, default False
            If True and the data is missing under ``path``, fetch the
            ~30 MB worth of CRCNS-AA2 archives from the NERSC mirror
            (free CRCNS account required) and extract in place.
        username, password : str, optional
            CRCNS credentials. Default to ``$CRCNS_USERNAME`` /
            ``$CRCNS_PASSWORD``. Prefer env vars over passing literals.
        """
        if path is None:
            path = str(default_cache_dir("AA2"))
        if download:
            download_aa2(path, username=username, password=password)
        super().__init__(path, dt_ms)
        self.species = 'zebra finch'
        # Informational hearing range (zebra finch behavioural audiogram ≈ 250 Hz – 8 kHz).
        self.hearing_range_hz = (250.0, 8000.0)

        # Waveform-input mode: store raw audio instead of the in-loader mel spec.
        self.return_waveform = bool(return_waveform)
        self.audio_fs = int(audio_fs) if return_waveform else None

        # sr = 32 kHz (CRCNS-AA2 wavs). hop_length tracks dt_ms; n_fft is
        # **independent of dt_ms** and pinned to ``window_ms``. At the
        # default (window_ms=10, dt_ms=1) this gives n_fft=320, which
        # equals the legacy ``10 * hl`` value bit-for-bit. At coarser
        # dt_ms the legacy formula scaled the window with the hop and
        # over-smoothed the spec; the new formula caps n_fft at
        # ``window_ms * sr`` (or ``hop_length`` if that's larger — STFT
        # constraint).
        sample_rate = 32000
        self.F = n_mels
        self.window_ms = float(window_ms)
        hl = int(dt_ms * 32)
        # n_fft is derived from the truncated hop via the ratio
        # ``window_ms / dt_ms`` so that the legacy ``window_ms = 10 * dt_ms``
        # contract yields exactly ``10 * hl``.
        n_fft = max(int(round((self.window_ms / float(dt_ms)) * hl)), hl)
        transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, hop_length=hl, n_mels=self.F,
        )
        self.compression = compression
        # audio-samples-per-neural-bin for waveform mode (== hl at native 32 kHz).
        hop = int(round(audio_fs * dt_ms / 1000)) if return_waveform else None

        #######################
        # 1. get metadata
        #######################
        cells_dir = os.path.join(path, 'all_cells')
        # list of animal ids
        ANIMAL_IDs = get_animals_ids(os.path.join(path, 'cell_stim_classes.csv'))
        # dict with areas as keys and lists of cell names as values
        AREA_CELLs = get_area_cells(os.path.join(path, 'cell_regions.csv'))
        # dict with stim types as keys and lists of wav names as values
        STIM_IDs = get_stim_ids_from_folders(os.path.join(path, 'all_cells/'), verbose=False)
        # per-wav metadata (sample_rate, n_samples, duration_s, bit_depth)
        STIM_INFO = load_stim_data_csv(os.path.join(path, 'stim_data.csv'))

        ################################
        # 2. cells & stims selection
        ################################
        # Cells are carried as (cell_name, area_label) pairs through every
        # filter below, so a cell can never end up with another cell's area.
        if areas == 'all':
            areas = ('Field_L', 'OV', 'CM', 'mld', 'None')
        selected = []
        for area in areas:
            if area in ('Field_L', 'L'):
                sub_areas = ('L', 'L1', 'L2a', 'L2b', 'L3')
            else:
                sub_areas = (area,)
            for sub_area in sub_areas:
                selected.extend((cell, sub_area) for cell in AREA_CELLs[sub_area])

        # filter cells by animal
        self.animals = ANIMAL_IDs if animals == 'all' else animals
        selected = [(cell, area) for cell, area in selected
                    if cell.split('_')[0] in self.animals]

        # Scan each remaining cell's folder once: which requested stim types
        # it has, and which 'stimXX' index points to which wav.
        stim_index = {cell: self._index_cell_stims(cells_dir, cell, stimuli)
                      for cell, _ in selected}

        # filter cells by stimulus type: drop cells that have none of the
        # requested types (built as a new list — no removal while iterating)
        selected = [(cell, area) for cell, area in selected if stim_index[cell]]

        cells = [cell for cell, _ in selected]
        cell_areas = [area for _, area in selected]
        cell_animals = [cell.split('_')[0] for cell in cells]
        self.N_neurons = len(cells)

        # candidate stimuli, in the order of the requested stimulus types
        stim_catalog = [(wav_name, stim_type)
                        for stim_type in stimuli
                        for wav_name in STIM_IDs[stim_type]]

        ####################################
        # 3. load stim spectros and resps
        ####################################
        stims_dir = os.path.join(path, "all_stims/")
        self.stims = []  # --> list of S tensors of shape (1, F, T_s)
        self.responses = []  # --> list of S lists of N tensors of shape (R_{s,n}, T_s)
        self.stim_meta = []  # --> list of S dicts {name, type, sample_rate, n_samples, duration_s}

        # list of N dicts; cell_seq + rig parsed from the documented AA1/AA2
        # cell-name format (<animal>_<cell_seq>[_<rig>], cf. AA1 readme PDF —
        # AA2 inherits the convention).
        self.nrn_meta = []
        for cell, animal, area in zip(cells, cell_animals, cell_areas):
            _, cell_seq, rig = parse_cell_name(cell)
            self.nrn_meta.append({
                "cell_id": cell,
                "animal_id": animal,
                "area": area,
                "cell_seq": cell_seq,
                "rig": rig,
            })

        for stim_name, stim_type in stim_catalog:
            # =========== load the stim ============
            wav, sr = load_wav(os.path.join(stims_dir, stim_name))  # sample rate: 32 kHz (mono)
            assert sr == 32000, f"found wav sr of {sr}, expected 32000"
            spec = transform(wav)  # (T,) --> (1, F, T-)
            if self.compression == 'cubic':
                spec = torch.pow(spec, 1.0 / 3)
            elif self.compression == 'log1p':
                spec = torch.log1p(spec)
            elif self.compression == 'none':
                pass
            T = spec.shape[-1]  # nbr of timesteps of current stim, in spectrogram form

            # =========== load the resps ============
            pop_resps = []
            no_data_nrn_idces = []
            for n, cell in enumerate(cells):
                resp = None
                # None when the cell lacks this stim type altogether, or has
                # the type but never heard this particular wav
                trial_idx = stim_index[cell].get(stim_type, {}).get(stim_name)
                if trial_idx is not None:
                    spike_file = os.path.join(cells_dir, cell, stim_type, f'spike{trial_idx}')
                    try:
                        resp = load_spike_file(spike_file, dt_ms=self.dt)  # (R, T)
                    except FileNotFoundError:
                        # some neurons have a 'stimXX' file, but not the
                        # corresponding 'spikeXX' file
                        resp = None

                if resp is None:
                    # null response --> (1, 1) tensor of NaN
                    resp = torch.full((1, 1), fill_value=float('nan'))
                    no_data_nrn_idces.append(n)
                elif resp.shape[-1] <= T:
                    # align response time dim to the stimulus duration T:
                    # shorter: right-pad with zeros (no spikes)
                    # TODO: alternatively, keep post-stim spikes and pad the spectrogram to match
                    Pt = T - resp.shape[-1]
                    resp = torch.nn.functional.pad(resp, pad=(0, Pt), mode='constant', value=0.)
                else:
                    # longer: crop (post-stimulus spikes discarded)
                    resp = resp[:, :T]

                # add the cell's response, whether it is null or not, to the
                # population activity for this stim
                pop_resps.append(resp)

            # assert population response is well-formed
            assert len(pop_resps) == self.N_neurons

            # if none of the neurons heard this stim at all, skip it
            if len(no_data_nrn_idces) == self.N_neurons:
                continue

            # otherwise keep the stim and its per-neuron responses. In waveform
            # mode store the raw audio (grid-locked to T * hop samples so it
            # aligns with the T response bins) instead of the in-loader mel spec.
            if return_waveform:
                T_audio = T * hop
                w = wav if audio_fs == sr else torchaudio.functional.resample(wav, sr, audio_fs)
                if w.shape[-1] < T_audio:
                    w = torch.nn.functional.pad(w, (0, T_audio - w.shape[-1]))
                else:
                    w = w[..., :T_audio]
                self.stims.append(w.contiguous().float())
            else:
                self.stims.append(spec)
            self.responses.append(pop_resps)
            wav_info = STIM_INFO.get(stim_name, {})
            self.stim_meta.append({
                "name": stim_name,
                "type": stim_type,
                "sample_rate": wav_info.get("sample_rate"),
                "n_samples": wav_info.get("n_samples"),
                "duration_s": wav_info.get("duration_s"),
            })

        # smooth PSTHs with a 21 ms Hanning window (Hsu / Borst / Theunissen 2004)
        if smooth:
            self.smooth_responses(window_ms=21.0)

        # self.nrn_masks is a derived @property on the base class — no need
        # to populate it here
        self.validate()

    @staticmethod
    def _index_cell_stims(cells_dir, cell, stim_types):
        """Index one cell folder as ``{stim_type: {wav_name: trial_index}}``.

        Only the types of ``stim_types`` that appear in the cell folder get a
        key (possibly mapping to an empty dict), so an empty result means the
        cell has none of the requested stimulus types. ``trial_index`` is the
        integer suffix of the 'stimXX' file naming the wav; when several
        files name the same wav, the first one in sorted filename order wins.
        """
        cell_dir = os.path.join(cells_dir, cell)
        present = os.listdir(cell_dir)
        index = {}
        for stim_type in stim_types:
            if stim_type not in present:
                continue
            type_dir = os.path.join(cell_dir, stim_type)
            stim_files = sorted(name for name in os.listdir(type_dir) if 'stim' in name)
            wav_to_trial = {}
            for stim_file in stim_files:
                trial_idx = int(stim_file[4:])  # e.g., 'stim20' --> 20
                with open(os.path.join(type_dir, stim_file)) as handle:
                    # first line minus its final character (the line break);
                    # same parsing as get_stim_ids_from_folders, so the names
                    # compare equal to the catalogued stim names
                    wav_name = handle.readlines()[0][:-1]
                wav_to_trial.setdefault(wav_name, trial_idx)
            index[stim_type] = wav_to_trial
        return index