Source code for deepSTRF.datasets.audio.crcns_aa4

import os
import re
from typing import Optional, Sequence

import h5py
import numpy as np
import torch
import torchaudio

from deepSTRF.datasets.audio.audio_dataset import AudioNeuralDataset
from deepSTRF.datasets.audio._crcns_aa_loaders import time_binning
from deepSTRF.utils.audio_io import load_wav
from deepSTRF.utils.data_download import (
    crcns_download,
    default_cache_dir,
    untar,
)


def _get_subgroups(group):
    """List the names of the direct children of `group` that are h5 groups.

    Children that are not ``h5py.Group`` instances (e.g. datasets) are left
    out. Names come back in the order ``group.items()`` yields them.
    """
    subgroup_names = []
    for child_name, child in group.items():
        if isinstance(child, h5py.Group):
            subgroup_names.append(child_name)
    return subgroup_names


def _decode_attr(val) -> str:
    """Decode an h5 string attribute that may be bytes / array-of-bytes / str.

    Resolution order:

    1. ``val`` itself is text (``str``) or has a ``.decode()`` method
       (``bytes``, ``bytearray``, numpy bytes scalars) -> return it as ``str``.
    2. ``val`` is a 0-d array-like (has ``ndim == 0`` and ``.item()``) -> the
       wrapped scalar is unwrapped and step 1 is applied to it.
    3. ``val`` is indexable -> step 1 is applied to ``val[0]``. This covers
       arrays / lists of bytes *and* of already-decoded ``str`` elements.
    4. Anything else falls back to ``str(val)``.

    Parameters
    ----------
    val : object
        Raw attribute value as read from the h5 file.

    Returns
    -------
    str
        The decoded text, or ``str(val)`` when no text could be extracted.
    """

    def _text_or_none(item):
        # ``str`` first: in Python 3 it has no ``.decode()``, so without this
        # check an already-decoded element would be reported as "not text".
        if isinstance(item, str):
            return str(item)
        try:
            return item.decode()
        except AttributeError:
            return None

    text = _text_or_none(val)
    if text is not None:
        return text

    # 0-d arrays cannot be indexed with [0]; unwrap the scalar instead of
    # falling back to a repr such as "b'single'".
    if getattr(val, "ndim", None) == 0 and hasattr(val, "item"):
        text = _text_or_none(val.item())
        if text is not None:
            return text

    try:
        first = val[0]
    except (AttributeError, IndexError, TypeError):
        # not indexable, or empty: nothing to decode
        return str(val)

    text = _text_or_none(first)
    return str(val) if text is None else text


# AA4 cell files are named (per the AA4 PDF)
#   Site<S>_L<Lz>R<Rz>_e<elec>_s<online_sortid>[_ss<offline_sortid>].h5
# for instance "Site1_L1400R1400_e10_s0_ss1.h5". The trailing _ss<n> part is
# optional, so the pattern only matches when it closes the (extension-less)
# name and captures the offline sort id.
_AA4_SUBSORT_RE = re.compile(r"_ss(\d+)$")

# Misspelled sortType value -> canonical spelling. Some cells carry "singl"
# instead of "single"; mapping it here means downstream filters only ever
# see the canonical value.
_AA4_SORTTYPE_FIXES = dict(singl="single")


# Animal ids of the AA4 corpus; each one is also the name of its archive
# (``<id>.tar.gz``) and of the directory that archive extracts to.
AA4_ANIMAL_IDS = (
    'BlaBro09xxF',
    'GreBlu9508M',
    'LblBlu2028M',
    'WhiBlu5396M',
    'WhiWhi4522M',
    'YelBlu6903F',
)


def download_aa4(dest: Optional[str] = None,
                 animals: Sequence[str] = AA4_ANIMAL_IDS,
                 username: Optional[str] = None,
                 password: Optional[str] = None) -> str:
    """Download CRCNS-AA4 archives from the NERSC mirror into ``dest``.

    AA4 is split into one ``.tar.gz`` per animal (each is hundreds of MB);
    by default this fetches all 6, but ``animals`` can be narrowed to a
    subset. The CRCNSCode tutorial archive is also fetched (small, ~1 MB).

    Idempotent: skips an archive if its animal directory already exists,
    skips the CRCNSCode archive if ``CRCNSCode/`` already exists. An archive
    file that is already present in ``dest`` is not downloaded again, only
    extracted.

    Parameters
    ----------
    dest : str, optional
        Defaults to ``default_cache_dir('AA4')`` (``$DEEPSTRF_DATA_DIR``
        overrides).
    animals : sequence of str, default all 6
        Animals to download. Must be a subset of ``AA4_ANIMAL_IDS``. A single
        id passed as a bare string is accepted and treated as one animal.
    username, password : str, optional
        Default to ``$CRCNS_USERNAME`` / ``$CRCNS_PASSWORD``.

    Returns
    -------
    str
        The directory the data was placed in.

    Raises
    ------
    ValueError
        If ``animals`` contains an id that is not in ``AA4_ANIMAL_IDS``.
        Raised before any directory is created or any archive is fetched.
    """
    # A bare string is itself a sequence of str; wrap it so it is handled as
    # one animal id instead of being iterated character by character.
    if isinstance(animals, str):
        animals = (animals,)
    else:
        animals = tuple(animals)

    # Validate everything up front (with a real exception, not an ``assert``
    # that ``python -O`` strips) so a typo late in the list cannot surface
    # only after several large archives have already been downloaded.
    unknown = [animal for animal in animals if animal not in AA4_ANIMAL_IDS]
    if unknown:
        raise ValueError(
            f"Unknown AA4 animal {unknown[0]!r}. Valid: {AA4_ANIMAL_IDS}")

    dest_path = str(default_cache_dir("AA4") if dest is None else dest)
    os.makedirs(dest_path, exist_ok=True)

    for animal in animals:
        if os.path.isdir(os.path.join(dest_path, animal)):
            continue  # already extracted
        archive_name = f"{animal}.tar.gz"
        archive_path = os.path.join(dest_path, archive_name)
        if not os.path.exists(archive_path):
            crcns_download(f"aa-4/{archive_name}", archive_path,
                           username=username, password=password)
        untar(archive_path, dest_path)  # tarball already wraps in <animal>/

    # CRCNSCode tutorial — small, useful pointer to the original loaders
    code_dir = os.path.join(dest_path, "CRCNSCode")
    if not os.path.isdir(code_dir):
        archive_path = os.path.join(dest_path, "CRCNSCode.tar.gz")
        if not os.path.exists(archive_path):
            crcns_download("aa-4/CRCNSCode.tar.gz", archive_path,
                           username=username, password=password)
        untar(archive_path, dest_path)

    return dest_path
class CRCNSAA4Dataset(AudioNeuralDataset):
    """PyTorch dataset for the CRCNS-AA4 recordings.

    1401 extracellular, spike-sorted single and multi units of adult zebra
    finches (4 males, 2 females) in Field L, caudolateral and caudomedial
    mesopallium (CLM, CMM) and caudomedial nidopallium (NCM) — though units
    were not precisely assigned to one of these areas. Three stimulus classes
    (conspecific songs, calls, ripple noise), each a few seconds long and
    presented ~10 times. Population- and batch-compatible.

    Data are available at https://crcns.org/data-sets/aa/aa-4/about-aa-4
    (free CRCNS account).

    Notes
    -----
    Follows the standard deepSTRF data paradigm (see
    ``docs/_source/md/data_paradigm.md``). AA4-specific metadata:

    - ``stims`` are mel-spectrograms ``(1, F, T_s)`` (or raw waveforms
      ``(1, T_audio)`` when ``return_waveform=True``).
    - ``stim_meta`` dicts hold ``name`` (the stimulus md5 — the canonical
      identifier, since the wav filename is per-animal and not unique across
      the corpus), ``type``, ``class`` and ``duration_s`` (the
      ``stim_duration`` attr from the h5, in seconds).
    - ``nrn_meta`` dicts hold: ``cell_id`` (h5 basename, no extension),
      ``animal_id``, ``sex`` (``'M'`` / ``'F'``, the last character of the
      animal id), ``site`` (e.g. ``"Site1"``), ``electrode`` (int 1-32,
      channel index across both 16-channel arrays at a site), ``ldepth`` /
      ``rdepth`` (left / right array depth in µm), ``sort_type``
      (``'single'`` / ``'multi'``; ``'noise'`` / ``'tdt'`` are filtered out),
      ``sort_id`` (online-sort int) and ``subsort_id`` (offline-sort int
      parsed from the trailing ``_ss<N>``; ``None`` if absent).

    The dataset paper does not publish a per-cell brain-area assignment, so
    the depth + electrode-array geometry is the only anatomical proxy; nor
    does it document which electrode IDs (1-16 vs 17-32) map to the left vs
    right hemisphere — confirm with the dataset authors before deriving a
    hemisphere from ``electrode``.

    References
    ----------
    Elie & Theunissen (2015). "Meaning in the avian auditory cortex: Neural
    representation of communication calls." *European Journal of
    Neuroscience*.

    Elie & Theunissen (2019). "Invariant neural responses for sensory
    categories revealed by the time-varying information for communication
    calls." *PLoS Computational Biology*.
    """

    def __init__(self, path: Optional[str] = None, animals='all',
                 stimuli=('song', 'call', 'mlnoise'), dt_ms=1.0, smooth=True,
                 n_mels=32, compression='cubic', window_ms: float = 10.0,
                 return_waveform: bool = False, audio_fs: int = 24000,
                 download: bool = False, username: Optional[str] = None,
                 password: Optional[str] = None):
        """
        Initializes the AA4 Dataset.

        Parameters
        ----------
        path : str, optional
            Path to the ``CRCNS_AA4/data/`` folder containing one subfolder
            per animal (with ``.h5`` cell files + a ``wavfiles/`` directory of
            stimulus ``.wav`` files). Defaults to ``default_cache_dir("AA4")``.
        animals : 'all' or sequence of str
            Animals to load (any subset of ``AA4_ANIMAL_IDS``). A single
            animal id passed as a bare string is treated as one animal.
        stimuli : sequence of str
            Stimulus types to keep; subset of {'song', 'call', 'mlnoise'}.
            A single type passed as a bare string is treated as one type.
        dt_ms : float
            Time-bin width in ms.
        smooth : bool
            If True, smooth PSTHs in place with a 21 ms Hanning window
            (Hsu, Borst & Theunissen 2004).
        n_mels : int
            Number of mel frequency bands of the stimulus spectrogram.
        compression : {'cubic', 'log1p', 'none'}
            Compression applied to the spectrogram (saturation effect of hair
            cells). Any other value leaves the spectrogram uncompressed.
            Has no effect on the returned stimuli when
            ``return_waveform=True``.
        window_ms : float, default 10.0
            Nominal FFT analysis-window length in ms. ``n_fft`` is derived
            per stimulus from the integer hop as
            ``round(window_ms / dt_ms * hop)`` and floored at ``hop``, so it
            is **decoupled from the bin width**: a coarser ``dt_ms`` no
            longer widens the FFT window. Earlier versions of this dataset
            hardcoded ``n_fft = hop * 10`` — at ``dt_ms=50`` that gave a
            500 ms FFT window and over-smoothed every spec frame. With the
            default ``window_ms=10.0`` and ``dt_ms=1`` the formula reduces to
            the legacy ``hop * 10``. Has no effect on the returned stimuli
            when ``return_waveform=True``.
        return_waveform : bool, default False
            If True, ``self.stims[s]`` holds the raw audio waveform
            ``(1, T_audio)`` at ``audio_fs`` Hz (grid-locked to
            ``T_audio = T_neural * hop``) instead of the in-loader mel
            spectrogram. Pair with a model whose ``wav2spec`` slot is a
            waveform front-end; responses are unchanged.
        audio_fs : int, default 24000
            Sample rate for waveform mode. The AA4 wavs are 24414 Hz, which
            gives a non-integer hop at dt=1 ms; the default 24 kHz resamples
            to a clean ``hop = 24`` (exactly dt=1 ms bins, slightly better
            than the native spec's 0.983 ms). Other values must keep
            ``audio_fs * dt_ms / 1000`` an integer (this is not checked; the
            hop is rounded). Ignored unless ``return_waveform=True``.
        download : bool, default False
            If True and an animal's data is missing under ``path``, fetch its
            tarball (~hundreds of MB per animal) from the NERSC mirror and
            untar in place. Only the animals listed in ``animals`` are
            downloaded — useful for quick iteration on a subset.
        username, password : str, optional
            CRCNS credentials. Default to ``$CRCNS_USERNAME`` /
            ``$CRCNS_PASSWORD``. Prefer env vars over passing literals.
        """
        if path is None:
            path = str(default_cache_dir("AA4"))

        # A bare string other than 'all' names ONE animal; tuple('GreBlu9508M')
        # would otherwise explode it into single characters.
        if isinstance(animals, str):
            animals_to_load = AA4_ANIMAL_IDS if animals == 'all' else (animals,)
        else:
            animals_to_load = tuple(animals)

        # Same pitfall for stimuli: set('song') == {'s', 'o', 'n', 'g'} would
        # match no stimulus and silently yield an empty dataset.
        if isinstance(stimuli, str):
            stimuli = (stimuli,)

        if download:
            download_aa4(path, animals=animals_to_load,
                         username=username, password=password)

        super().__init__(path, dt_ms)

        # general
        self.species = 'zebra finch'
        # Informational hearing range (zebra finch ≈ 250 Hz – 8 kHz).
        self.hearing_range_hz = (250.0, 8000.0)
        self.F = n_mels
        self.compression = compression

        # Waveform-input mode: store raw audio (resampled to a single
        # audio_fs, so every stimulus shares one integer hop — 24 kHz gives
        # hop=24 at dt=1 ms) instead of the in-loader mel spec.
        self.return_waveform = bool(return_waveform)
        self.audio_fs = int(audio_fs) if return_waveform else None
        self._wav_hop = (int(round(audio_fs * self.dt / 1000))
                         if return_waveform else None)

        self.animals = animals_to_load
        self.stim_types = set(stimuli)

        ###########################################
        # 1. preload mel-spectrograms per animal
        ###########################################
        # hop_length (samples) | dt (ms) — at sr = stim wav's sr
        # The hop is derived from each wav's own sample rate:
        # hop = int(sr * dt_ms / 1000), at least 1.
        # ``n_fft`` is decoupled from the bin width: it is
        # ``round(window_ms / dt_ms * hop)``, floored at ``hop`` so the STFT
        # constraint ``n_fft >= hop_length`` is satisfied.
        self.window_ms = float(window_ms)

        # MelSpectrogram transforms are reused across wav files that share
        # (sr, n_fft, hop) instead of being rebuilt for every stimulus.
        mel_transforms = {}

        wav_specs_by_animal = {}
        wav_audio_by_animal = {}
        for animal in self.animals:
            wav_dir = os.path.join(path, animal, 'wavfiles')
            specs = {}
            wavs = {}
            for fname in sorted(os.listdir(wav_dir)):
                if not fname.endswith('.wav'):
                    continue
                sid = os.path.splitext(fname)[0]  # e.g. 'stim85'
                waveform, sr = load_wav(os.path.join(wav_dir, fname))

                hop = max(1, int(sr * self.dt / 1000))
                # Derive n_fft from the (already-truncated) hop via the
                # ratio ``window_ms / dt_ms``. At ``window_ms = 10 * dt_ms``
                # this collapses to the legacy ``hop * 10`` regardless of sr.
                # Floored at ``hop`` so MelSpectrogram's
                # ``n_fft >= hop_length`` constraint always holds.
                n_fft = max(
                    int(round((self.window_ms / float(self.dt)) * hop)), hop)

                tf_key = (sr, n_fft, hop)
                mel_tf = mel_transforms.get(tf_key)
                if mel_tf is None:
                    mel_tf = torchaudio.transforms.MelSpectrogram(
                        sample_rate=sr,
                        n_mels=self.F,
                        n_fft=n_fft,
                        hop_length=hop,
                    )
                    mel_transforms[tf_key] = mel_tf

                spec = mel_tf(waveform)  # (1, F, T) for mono wav
                if self.compression == 'cubic':
                    spec = torch.pow(spec, 1.0 / 3)
                elif self.compression == 'log1p':
                    spec = torch.log1p(spec)
                elif self.compression == 'none':
                    pass
                if spec.ndim == 2:
                    spec = spec.unsqueeze(0)
                specs[sid] = spec

                if return_waveform:
                    # store the raw audio, resampled to the dataset's single
                    # audio_fs and grid-locked to T_neural * hop samples so
                    # it aligns with the T_neural spec frames (= response
                    # bins). The spec is still kept (it sets T_neural / the
                    # response length).
                    T_audio = spec.shape[-1] * self._wav_hop
                    if sr == audio_fs:
                        w = waveform
                    else:
                        w = torchaudio.functional.resample(
                            waveform, sr, audio_fs)
                    if w.shape[0] > 1:
                        w = w.mean(dim=0, keepdim=True)  # downmix to mono
                    if w.shape[-1] < T_audio:
                        w = torch.nn.functional.pad(
                            w, (0, T_audio - w.shape[-1]))
                    else:
                        w = w[..., :T_audio]
                    wavs[sid] = w.contiguous().float()

            wav_specs_by_animal[animal] = specs
            wav_audio_by_animal[animal] = wavs

        ###########################################
        # 2. walk h5 cell files per animal
        ###########################################
        # ordered list of unique stim md5s (corpus-wide canonical id)
        stim_uids = []
        stim_meta_map = {}  # md5 -> {"name", "type", "class", "duration_s"}
        stim_spec_map = {}  # md5 -> spectrogram tensor (1, F, T)
        stim_wav_map = {}   # md5 -> waveform tensor (1, T_audio) [waveform mode]

        # per-neuron accumulator
        # list of dicts: {'meta': nrn_meta_dict, 'responses': {md5: (R, T) tensor}}
        units_data = []

        for animal in self.animals:
            sex = animal[-1]  # animal ids end in 'M' / 'F'
            animal_path = os.path.join(path, animal)
            for fname in sorted(os.listdir(animal_path)):
                if not fname.endswith('.h5'):
                    continue
                h5_path = os.path.join(animal_path, fname)
                cell_id = os.path.splitext(fname)[0]

                with h5py.File(h5_path, 'r') as celldata:
                    sort_type = _decode_attr(
                        celldata.attrs.get('sortType', b''))
                    if sort_type in ('tdt', 'noise'):
                        continue
                    sort_type = _AA4_SORTTYPE_FIXES.get(sort_type, sort_type)

                    subsort_match = _AA4_SUBSORT_RE.search(cell_id)
                    subsort_id = (int(subsort_match.group(1))
                                  if subsort_match else None)

                    nrn_meta = {
                        'cell_id': cell_id,
                        'animal_id': animal,
                        'sex': sex,
                        'site': _decode_attr(celldata.attrs.get('site', b'')),
                        'electrode': int(celldata.attrs.get('electrode', 0)),
                        'ldepth': float(celldata.attrs.get('ldepth', np.nan)),
                        'rdepth': float(celldata.attrs.get('rdepth', np.nan)),
                        'sort_type': sort_type,
                        'sort_id': int(celldata.attrs.get('sortid', 0)),
                        'subsort_id': subsort_id,
                    }

                    responses = {}
                    # iterate stim classes (skip metadata groups)
                    for cls in sorted(_get_subgroups(celldata)):
                        if cls in ('class_info', 'extra_info'):
                            continue
                        cls_grp = celldata[cls]
                        for stim_key in sorted(_get_subgroups(cls_grp)):
                            stim_grp = cls_grp[stim_key]
                            stim_type = _decode_attr(
                                stim_grp.attrs.get('stim_type', b''))
                            if stim_type not in self.stim_types:
                                continue
                            stim_md5 = _decode_attr(
                                stim_grp.attrs.get('stim_md5', b''))
                            stim_class = _decode_attr(
                                stim_grp.attrs.get('stim_class', b''))
                            stim_dur_s = float(
                                stim_grp.attrs.get('stim_duration', np.nan))

                            # register unique stimulus on first encounter
                            if stim_md5 not in stim_meta_map:
                                spec = wav_specs_by_animal[animal].get(
                                    f'stim{stim_key}')
                                if spec is None:
                                    # wav missing for this animal — skip the
                                    # stim altogether
                                    continue
                                stim_uids.append(stim_md5)
                                stim_meta_map[stim_md5] = {
                                    'name': stim_md5,
                                    'type': stim_type,
                                    'class': stim_class,
                                    'duration_s': stim_dur_s,
                                }
                                stim_spec_map[stim_md5] = spec
                                if return_waveform:
                                    stim_wav_map[stim_md5] = \
                                        wav_audio_by_animal[animal].get(
                                            f'stim{stim_key}')

                            T_stim = stim_spec_map[stim_md5].shape[-1]

                            # bin spike times (in seconds in h5) into
                            # (R, T_stim)
                            trial_tensors = []
                            for trial_key in sorted(_get_subgroups(stim_grp)):
                                raw_times = stim_grp[trial_key]['spike_times'][()]
                                raw_times = raw_times[raw_times >= 0]  # post-onset only
                                # NOTE(review): a trial without any post-onset
                                # spike is dropped rather than kept as an
                                # all-zero row, so R can be smaller than the
                                # number of presentations — confirm this is
                                # intended.
                                if raw_times.size == 0:
                                    continue
                                times_ms = (raw_times * 1000.0).tolist()
                                trial_tensors.append(
                                    time_binning(times_ms, dt_ms=self.dt))

                            if not trial_tensors:
                                continue

                            # align each trial to T_stim (right-pad with 0,
                            # or crop)
                            aligned = []
                            for t in trial_tensors:
                                if t.shape[-1] < T_stim:
                                    t = torch.nn.functional.pad(
                                        t, (0, T_stim - t.shape[-1]),
                                        mode='constant', value=0.0,
                                    )
                                elif t.shape[-1] > T_stim:
                                    t = t[..., :T_stim]
                                aligned.append(t)
                            counts = torch.stack(aligned, dim=0)  # (R, T_stim)

                            # nothing left inside the stimulus window
                            if torch.all(counts == 0):
                                continue
                            responses[stim_md5] = counts

                    # units without a single usable response are left out
                    if responses:
                        units_data.append(
                            {'meta': nrn_meta, 'responses': responses})

        ###########################################
        # 3. assemble core dataset attributes
        ###########################################
        self.N_neurons = len(units_data)
        _stim_map = stim_wav_map if self.return_waveform else stim_spec_map
        self.stims = [_stim_map[uid] for uid in stim_uids]
        self.stim_meta = [stim_meta_map[uid] for uid in stim_uids]
        self.nrn_meta = [u['meta'] for u in units_data]

        # responses[s][n] = (R, T) tensor or (1, 1) NaN sentinel
        self.responses = []
        for uid in stim_uids:
            row = []
            for u in units_data:
                if uid in u['responses']:
                    row.append(u['responses'][uid])
                else:
                    row.append(torch.full((1, 1), float('nan')))
            self.responses.append(row)

        # smooth PSTHs with a 21 ms Hanning window
        # (Hsu / Borst / Theunissen 2004)
        if smooth:
            self.smooth_responses(window_ms=21.0)

        # self.nrn_masks is a derived @property on the base class — no need
        # to populate it here
        self.validate()