"""CRCNS-AA1 dataset loader (source of ``deepSTRF.datasets.audio.crcns_aa1``)."""

import os
from typing import Optional

import torch
import torchaudio

from deepSTRF.datasets.audio.audio_dataset import AudioNeuralDataset
from deepSTRF.datasets.audio._crcns_aa_loaders import load_spike_file, parse_cell_name
from deepSTRF.utils.audio_io import load_wav
from deepSTRF.utils.data_download import crcns_download, default_cache_dir, unzip


# TODO (misc.):
#  - find a way to use pre-stimulus-onset activity?
#  - make the dataset concatenable?


# Remote location of the CRCNS-AA1 data on the NERSC mirror (the path handed
# to ``crcns_download``). The data come as one zip archive that unpacks to
# 'all_stims/', 'Field_L_cells/' and 'MLd_cells/' directly at the destination
# root: the archive is flat, without a wrapping top-level folder.
AA1_NERSC_PATH: str = "aa-1/crcns-aa1.zip"


def download_aa1(dest: Optional[str] = None, username: Optional[str] = None,
                 password: Optional[str] = None) -> str:
    """Download the CRCNS-AA1 archive from the NERSC mirror into ``dest``.

    Idempotent: when ``Field_L_cells/`` already exists in ``dest`` the data
    are considered present and nothing is downloaded or unzipped, so an
    already-extracted copy (e.g. one whose zip was deleted, or one extracted
    by hand) needs no credentials. Otherwise the archive is fetched unless it
    is already on disk, then unzipped in place.

    Parameters
    ----------
    dest : str, optional
        Defaults to the platformdirs cache (overridable via
        ``$DEEPSTRF_DATA_DIR``).
    username, password : str, optional
        Default to ``$CRCNS_USERNAME`` / ``$CRCNS_PASSWORD``. Free CRCNS
        account at https://crcns.org/register.

    Returns
    -------
    str
        The dataset directory (``dest``, or the default cache directory).
    """
    dest_path = str(default_cache_dir("AA1") if dest is None else dest)

    # Already extracted: nothing to fetch, even if the zip itself is gone.
    if os.path.isdir(os.path.join(dest_path, "Field_L_cells")):
        return dest_path

    os.makedirs(dest_path, exist_ok=True)
    zip_path = os.path.join(dest_path, "crcns-aa1.zip")
    if not os.path.exists(zip_path):
        crcns_download(AA1_NERSC_PATH, zip_path, username=username, password=password)
    # NOTE(review): the archive is described as flat (no wrapping folder);
    # strip_root=True is kept from the original call — confirm unzip() leaves
    # flat archives untouched.
    unzip(zip_path, dest_path, strip_root=True)
    return dest_path
def get_animals_ids(data_path):
    """Return the unique animal ids found in the AA1 cell folders.

    Goes through ``Field_L_cells/`` then ``MLd_cells/`` (each in sorted
    order) and collects the prefix before the first underscore of every cell
    folder name, e.g. ``'gg0304_4_B'`` --> ``'gg0304'``.

    Parameters
    ----------
    data_path : str
        AA1 data root, i.e. the folder containing ``Field_L_cells/`` and
        ``MLd_cells/``.

    Returns
    -------
    list of str
        Animal ids in first-seen order, without duplicates. Hidden entries
        and non-directories (e.g. ``.DS_Store``) are ignored.
    """
    # dict used as an insertion-ordered set (O(1) membership)
    animal_ids = {}
    for area_folder in ('Field_L_cells', 'MLd_cells'):
        area_dir = os.path.join(data_path, area_folder)
        for cell_name in sorted(os.listdir(area_dir)):
            if cell_name.startswith('.') or not os.path.isdir(os.path.join(area_dir, cell_name)):
                continue
            animal_ids.setdefault(cell_name.split('_')[0], None)
    return list(animal_ids)
def get_area_cells(data_path):
    """List the cell folders of each recording area.

    Parameters
    ----------
    data_path : str
        AA1 data root containing ``Field_L_cells/`` and ``MLd_cells/``.

    Returns
    -------
    dict
        ``{'Field_L': [...], 'MLd': [...]}`` mapping each area label to the
        sorted names of its cell sub-folders (e.g. ``'pupu2122_2_A'``).
        Hidden entries and non-directories (e.g. ``.DS_Store``) are skipped,
        since they are not cells.
    """
    cell_dict = {}
    for area in ('Field_L', 'MLd'):
        area_dir = os.path.join(data_path, f'{area}_cells')
        cell_dict[area] = [
            name
            for name in sorted(os.listdir(area_dir))
            if not name.startswith('.') and os.path.isdir(os.path.join(area_dir, name))
        ]
    return cell_dict
def get_stim_ids(data_path):
    """List the stimulus files of each stimulus type.

    Parameters
    ----------
    data_path : str
        AA1 data root containing ``all_stims/conspecific/`` and
        ``all_stims/flatrip/``.

    Returns
    -------
    dict
        ``{'conspecific': [...], 'flatrip': [...]}`` mapping each stimulus
        type to the sorted file names in its folder (e.g.
        ``'723792DF8CA8D0B99B8059503E5006BA.wav'``). Hidden entries (e.g.
        ``.DS_Store``) and sub-directories are skipped, since they are not
        stimuli.
    """
    stim_dict = {}
    for stim_type in ('conspecific', 'flatrip'):
        stim_dir = os.path.join(data_path, 'all_stims', stim_type)
        stim_dict[stim_type] = [
            name
            for name in sorted(os.listdir(stim_dir))
            if not name.startswith('.') and os.path.isfile(os.path.join(stim_dir, name))
        ]
    return stim_dict
class CRCNSAA1Dataset(AudioNeuralDataset):
    """PyTorch dataset for the CRCNS-AA1 recordings.

    Extracellular, spike-sorted single units of anesthetized male zebra
    finches: 50 cells in Field L and 50 in MLd, recorded in response to 10
    clips of conspecific vocalizations and 20 clips of flat ripples (up to
    5 s each, ~10 trials on average). Data are available at
    https://crcns.org/data-sets/aa/aa-1/about (free CRCNS account); see the
    AA1 README in the deepSTRF docs for the full notes.

    Notes
    -----
    Follows the standard deepSTRF data paradigm (see
    ``docs/_source/md/data_paradigm.md``). AA1-specific metadata:

    - ``stims`` are mel-spectrograms ``(1, F, T_s)`` (or raw waveforms
      ``(1, T_audio)`` when ``return_waveform=True``).
    - ``stim_meta`` dicts hold ``name``, ``type``, ``sample_rate``,
      ``n_samples`` and ``duration_s``.
    - ``nrn_meta`` dicts hold ``cell_id``, ``animal_id``, ``area``,
      ``cell_seq`` and ``rig``.

    ``cell_seq`` is the sequential cell index parsed from the cell folder
    name (the n-th cell recorded); ``rig`` is the single-letter rig label
    when present, else ``None`` (cells "4_A" and "4_B" were recorded
    simultaneously, possibly in different areas).

    Two cells lack ``conspecific`` responses: ``pipu1018_2_A`` (MLd) and
    ``pipu1018_2_B`` (Field_L).

    References
    ----------
    Woolley et al. (2005). "Tuning for Spectro-temporal Modulations: a
    Mechanism for Auditory Discrimination of Natural Sound."
    Hsu et al. (2004). "Modulation power and phase spectrum of natural sounds
    enhance neural encoding performed by single auditory neurons."
    Singh & Theunissen (2003). "Modulation spectra of natural sounds and
    ethological theories of auditory processing."
    """

    # Native sample rate of the CRCNS-AA1 wavs; every loaded wav is checked
    # against it.
    _SAMPLE_RATE = 32000
    # Valid values of the ``areas`` / ``stimuli`` selections.
    _AREAS = ('Field_L', 'MLd')
    _STIM_TYPES = ('conspecific', 'flatrip')

    def __init__(self, path: Optional[str] = None, areas=('Field_L', 'MLd'),
                 stimuli=('conspecific', 'flatrip'), animals='all', dt_ms=1,
                 smooth=True, n_mels=32, compression='cubic',
                 window_ms: float = 10.0, return_waveform: bool = False,
                 audio_fs: int = 32000, download: bool = False,
                 username: Optional[str] = None,
                 password: Optional[str] = None):
        """Initializes the AA1 Dataset.

        Parameters
        ----------
        path : str, optional
            Path to the AA1 data folder (containing ``Field_L_cells/``,
            ``MLd_cells/``, ``all_stims/``). Defaults to the platformdirs
            cache (``$DEEPSTRF_DATA_DIR`` overrides).
        areas : str or tuple of str, default ('Field_L', 'MLd')
            Recording sites: 'Field_L' and/or 'MLd'. A single name or
            ``'all'`` is also accepted.
        stimuli : str or tuple of str, default ('conspecific', 'flatrip')
            Stimulus types: 'conspecific' and/or 'flatrip'. A single name or
            ``'all'`` is also accepted.
        animals : str or sequence of str, default 'all'
            Animal ids to keep (the cell-name prefix before the first
            underscore, e.g. ``'gg0304'``), a single id, or ``'all'``.
        dt_ms : float
            Time step size in ms.
        smooth : bool, default True
            If True, smooth the responses with a 21 ms window via
            ``smooth_responses``.
        n_mels : int
            Number of mel frequency bands.
        compression : {'cubic', 'log1p', 'none'}
            Spectrogram compression. Ignored when ``return_waveform=True``.
        window_ms : float, default 10.0
            FFT analysis-window length in ms. With
            ``hop_length = int(32 * dt_ms)``, ``n_fft`` is
            ``max(round(window_ms / dt_ms * hop_length), hop_length)``, i.e.
            about ``window_ms * 1e-3 * sample_rate`` but never shorter than
            one hop. The window is therefore independent of ``dt_ms`` as long
            as ``window_ms >= dt_ms``; for coarser bins it equals one hop
            (e.g. a 50 ms window at ``dt_ms=50``). Earlier versions hardcoded
            ``n_fft = 10 * hop_length``, i.e. a 500 ms window at
            ``dt_ms=50``. The default ``window_ms=10.0`` reproduces that
            legacy value exactly at ``dt_ms=1``. Speech-pipeline users may
            prefer ``window_ms=25.0`` (Kaldi default). Ignored when
            ``return_waveform=True``.
        return_waveform : bool, default False
            If True, ``self.stims[s]`` holds the raw audio waveform
            ``(1, T_audio)`` at ``audio_fs`` Hz (grid-locked to
            ``T_audio = T_neural * hop``) instead of the in-loader mel
            spectrogram. Pair with a model whose ``wav2spec`` slot is a
            waveform front-end (see ``deepSTRF.models.wav2spec``); responses
            are unchanged.
        audio_fs : int, default 32000
            Sample rate for waveform mode. Default 32 kHz is the native rate
            of the AA1 wavs (so no resampling); other values resample and
            must keep ``audio_fs * dt_ms / 1000`` an integer. Ignored unless
            ``return_waveform=True``.
        download : bool, default False
            If True and the data is missing under ``path``, fetch the ~17 MB
            CRCNS-AA1 archive from the NERSC mirror (free CRCNS account
            required; see ``crcns_download``) and unzip in place.
        username, password : str, optional
            CRCNS credentials. Default to ``$CRCNS_USERNAME`` /
            ``$CRCNS_PASSWORD``. Prefer the env vars over passing literals —
            anything in source / a notebook ends up in history / logs / VCS.

        Raises
        ------
        ValueError
            On an unknown area, stimulus type or compression; on a
            non-integer ``audio_fs * dt_ms / 1000`` in waveform mode; on a
            wav whose sample rate is not 32 kHz; or when the selection leaves
            no cell.
        """
        # ---- validate options before any (slow) I/O ----
        areas = self._normalize_selection(areas, self._AREAS, 'area')
        stimuli = self._normalize_selection(stimuli, self._STIM_TYPES, 'stimulus type')

        compressions = {
            'cubic': lambda s: torch.pow(s, 1.0 / 3),
            'log1p': torch.log1p,
            'none': lambda s: s,
        }
        if compression in compressions:
            compress = compressions[compression]
        elif return_waveform:
            # documented as ignored in waveform mode: the spectrogram is then
            # only used for its number of frames.
            compress = compressions['none']
        else:
            raise ValueError(
                f"unknown compression {compression!r}; expected one of {tuple(compressions)}")

        # audio-samples-per-neural-bin for waveform mode (== hop_length at the
        # native 32 kHz). Grid-lock: audio_fs * dt_ms / 1000 must be an integer,
        # otherwise the audio drifts against the response bins.
        hop = None
        if return_waveform:
            hop_exact = audio_fs * dt_ms / 1000
            hop = int(round(hop_exact))
            if hop < 1 or abs(hop_exact - hop) > 1e-6:
                raise ValueError(
                    f"audio_fs * dt_ms / 1000 must be a positive integer, got {hop_exact}")

        if path is None:
            path = str(default_cache_dir("AA1"))
        if download:
            download_aa1(path, username=username, password=password)
        super().__init__(path, dt_ms)

        # ---- general ----
        self.species = 'zebra finch'
        # Informational hearing range (zebra finch behavioural audiogram ≈ 250 Hz – 8 kHz).
        self.hearing_range_hz = (250.0, 8000.0)
        # Waveform-input mode: store raw audio instead of the in-loader mel spec.
        self.return_waveform = bool(return_waveform)
        self.audio_fs = int(audio_fs) if return_waveform else None
        self.behavioral_state = 'anesthetized'

        # ---- spectrogram front-end ----
        # hop_length tracks dt_ms (32 samples per ms at 32 kHz). n_fft is
        # derived from the (int-truncated) hop via window_ms / dt_ms, so the
        # default window_ms=10 gives n_fft=320 == legacy 10 * hop at dt_ms=1.
        # It is floored at one hop (n_fft >= hop_length): whenever
        # window_ms < dt_ms the window equals one hop, e.g. n_fft=1600
        # (50 ms) at dt_ms=50, versus 16000 with the legacy formula.
        sample_rate = self._SAMPLE_RATE
        self.F = n_mels
        self.window_ms = float(window_ms)
        hl = int(dt_ms * 32)
        n_fft = max(int(round((self.window_ms / float(dt_ms)) * hl)), hl)
        transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hl,
            n_mels=self.F,
        )
        self.compression = compression

        # ---- 1. metadata ----
        animal_ids = get_animals_ids(path)
        area_cells = get_area_cells(path)
        stim_ids = get_stim_ids(path)

        # ---- 2. cell selection: area, animal, stimulus type ----
        if isinstance(animals, str):
            animals = animal_ids if animals == 'all' else [animals]
        self.animals = animals

        # list of N dicts; cell_seq + rig parsed from the documented AA1
        # cell-name format (<animal>_<cell_seq>[_<rig>], cf. AA1 readme PDF).
        # Each cell's id / animal / area are recorded together, so they can
        # never drift out of alignment.
        wanted_types = set(stimuli)
        self.nrn_meta = []
        for area in areas:
            for cell in area_cells[area]:
                animal_id = cell.split('_')[0]
                if animal_id not in self.animals:
                    continue
                cell_dir = os.path.join(path, f'{area}_cells', cell)
                # drop cells with no response folder of any requested type
                if wanted_types.isdisjoint(self._list_dir(cell_dir)):
                    continue
                _, cell_seq, rig = parse_cell_name(cell)
                self.nrn_meta.append({
                    "cell_id": cell,
                    "animal_id": animal_id,
                    "area": area,
                    "cell_seq": cell_seq,
                    "rig": rig,
                })
        self.N_neurons = len(self.nrn_meta)
        if self.N_neurons == 0:
            raise ValueError("no AA1 cell matches the requested areas / animals / stimuli")

        # ---- 3. index every cell's response files once ----
        # stim_index[n][stim_type][wav name] -> N of that cell's 'stimN' /
        # 'spikeN' pair, so the stimulus loop needs no directory listing.
        stim_index = []
        for nrn in self.nrn_meta:
            cell_dir = os.path.join(path, f"{nrn['area']}_cells", nrn['cell_id'])
            present = self._list_dir(cell_dir)
            stim_index.append({
                stim_type: self._index_stim_files(os.path.join(cell_dir, stim_type))
                for stim_type in stimuli
                if stim_type in present
            })

        # ---- 4. load stimuli & responses ----
        self.stims = []      # --> list of S tensors (1, F, T_s), or (1, T_audio) in waveform mode
        self.responses = []  # --> list of S lists of N tensors (R_{s,n}, T_s); (1, 1) NaN if missing
        self.stim_meta = []  # --> list of S dicts
        for stim_type in stimuli:
            stims_dir = os.path.join(path, 'all_stims', stim_type)
            for stim_name in stim_ids[stim_type]:
                wav, sr = load_wav(os.path.join(stims_dir, stim_name))
                if sr != sample_rate:
                    raise ValueError(f"found wav sr of {sr}, expected {sample_rate}")
                n_samples_wav = wav.shape[-1]
                # mel spectrogram; last dim is time (presumably (1, F, T) for
                # a mono wav — confirm load_wav's output shape)
                spec = compress(transform(wav))
                T = spec.shape[-1]  # number of time bins of this stimulus

                pop_resps = []
                n_missing = 0
                for nrn, index in zip(self.nrn_meta, stim_index):
                    stim_num = index.get(stim_type, {}).get(stim_name)
                    resp = None
                    if stim_num is not None:
                        spike_file = os.path.join(
                            path, f"{nrn['area']}_cells", nrn['cell_id'],
                            stim_type, f'spike{stim_num}')
                        resp = self._load_aligned_response(spike_file, T)
                    if resp is None:
                        # this cell did not hear this stimulus -> NaN placeholder
                        resp = torch.full((1, 1), fill_value=float('nan'))
                        n_missing += 1
                    pop_resps.append(resp)

                # if none of the neurons heard this stim at all, skip it
                if n_missing == self.N_neurons:
                    continue

                if return_waveform:
                    # raw audio grid-locked to T * hop samples so it aligns
                    # with the T response bins
                    T_audio = T * hop
                    w = wav if audio_fs == sr else torchaudio.functional.resample(wav, sr, audio_fs)
                    if w.shape[-1] < T_audio:
                        w = torch.nn.functional.pad(w, (0, T_audio - w.shape[-1]))
                    else:
                        w = w[..., :T_audio]
                    self.stims.append(w.contiguous().float())
                else:
                    self.stims.append(spec)
                self.responses.append(pop_resps)
                self.stim_meta.append({
                    "name": stim_name,
                    "type": stim_type,
                    "sample_rate": float(sr),
                    "n_samples": int(n_samples_wav),
                    "duration_s": n_samples_wav / float(sr),
                })

        # smooth PSTHs with a 21 ms Hanning window (Hsu / Borst / Theunissen 2004)
        if smooth:
            self.smooth_responses(window_ms=21.0)
        # self.nrn_masks is a derived @property on the base class — no need
        # to populate it here
        self.validate()

    @staticmethod
    def _normalize_selection(value, allowed, what):
        """Turn a selection argument (name, 'all', or iterable) into a checked tuple."""
        if isinstance(value, str):
            value = allowed if value == 'all' else (value,)
        value = tuple(value)
        unknown = [v for v in value if v not in allowed]
        if unknown:
            raise ValueError(f"unknown {what}(s) {unknown}; expected any of {allowed} or 'all'")
        return value

    @staticmethod
    def _list_dir(dir_path):
        """Sorted entries of ``dir_path``, or ``[]`` if it is not a directory."""
        return sorted(os.listdir(dir_path)) if os.path.isdir(dir_path) else []

    @staticmethod
    def _index_stim_files(spike_dir):
        """Map the wav name referenced by each ``stimN`` file in ``spike_dir`` to ``N``.

        The wav name is the first line of the ``stimN`` file. If several
        files name the same wav, the first one in sorted file-name order wins.
        """
        index = {}
        for file_name in sorted(os.listdir(spike_dir)):
            if not (file_name.startswith('stim') and file_name[4:].isdecimal()):
                continue
            with open(os.path.join(spike_dir, file_name), encoding='utf-8') as fh:
                # strip() instead of [:-1]: robust to a missing final newline / CRLF
                wav_name = fh.readline().strip()
            index.setdefault(wav_name, int(file_name[4:]))
        return index

    def _load_aligned_response(self, spike_file, T):
        """Load ``spike_file`` and fit it to ``T`` bins; ``None`` if the file is absent.

        Shorter responses are right-padded with zeros (no spikes); longer
        ones are cropped (post-stimulus spikes discarded).
        """
        try:
            resp = load_spike_file(spike_file, dt_ms=self.dt)  # (R, T_resp)
        except FileNotFoundError:
            # some neurons have a 'stimN' file but not the matching 'spikeN'
            return None
        # TODO: alternatively, keep post-stim spikes and pad the spectrogram to match
        if resp.shape[-1] <= T:
            return torch.nn.functional.pad(
                resp, pad=(0, T - resp.shape[-1]), mode='constant', value=0.)
        return resp[:, :T]