import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Optional, Union, Sequence
from deepSTRF.datasets import NeuralDataset
[docs]
def hanning_smooth(response: torch.Tensor, window_ms: float, dt_ms: float) -> torch.Tensor:
    """Convolve `response` along its last (time) axis with a Hanning window.

    Parameters
    ----------
    response : torch.Tensor
        Response tensor of any shape; the last axis is treated as time.
    window_ms : float
        Full width of the Hanning window in ms. It is converted to a bin
        count by flooring ``window_ms / dt_ms`` (clamped to at least 1) and
        then bumped by one if the result is even, so the kernel length is
        always odd.
    dt_ms : float
        Time-bin width of ``response``, in ms.

    Returns
    -------
    torch.Tensor
        Smoothed response, same shape as the input.

    Raises
    ------
    AssertionError
        If ``window_ms`` or ``dt_ms`` is not strictly positive.

    Notes
    -----
    The time axis is padded with zeros on both sides
    (``F.pad(..., mode='constant')``), so edge bins get attenuated. The
    kernel is the raw ``np.hanning(K)``, i.e. NOT sum-normalized — matches
    the legacy behaviour used by the Hsu / Borst / Theunissen (2004) PSTH
    smoothing step in the CRCNS-AA datasets.

    NaN-unsafe: NaN values propagate to neighbouring time bins under the
    window. Callers (e.g. ``NeuralDataset.smooth_responses``) must filter
    fully-NaN responses before calling.
    """
    assert window_ms > 0 and dt_ms > 0, "window_ms and dt_ms must be positive"

    # Floor the true quotient with a tiny tolerance instead of using float
    # floor-division: ``1.0 // 0.1`` evaluates to 9.0 (not 10.0) because 0.1
    # is not exactly representable, which silently shortened the window for
    # exactly divisible (window_ms, dt_ms) pairs.
    K = max(int(window_ms / dt_ms + 1e-9), 1)
    if K % 2 == 0:
        K += 1  # odd length keeps the kernel centred on the current bin

    kernel = torch.tensor(
        np.hanning(K), dtype=response.dtype, device=response.device
    ).view(1, 1, K)
    pad = (K - 1) // 2

    # Collapse every leading axis into the conv1d batch axis: (*, 1, T).
    orig_shape = response.shape
    flat = response.reshape(-1, 1, orig_shape[-1])
    flat = F.pad(flat, (pad, pad), mode='constant', value=0.0)
    smoothed = F.conv1d(flat, kernel)
    return smoothed.view(orig_shape)
[docs]
def neural_collate(batch):
    """Batch items from a :class:`~deepSTRF.datasets.neural_dataset.NeuralDataset`.

    Stimuli of differing duration are zero-padded along their last (time)
    axis. Responses, which may differ both in duration and in number of
    repeats, are laid out on a dense NaN-filled grid, so NaN marks every
    padded or missing entry. A boolean ``valid_mask`` is derived from those
    NaN sentinels once per batch, so downstream loss code can mask or index
    without scanning the responses again.

    Parameters
    ----------
    batch : list of dict
        One dict per item, as yielded by ``NeuralDataset.__getitem__``:

        * ``'stims'`` — stim tensor of shape ``(..., T_s)`` (leading dims are
          modality-specific, e.g. ``(1, F, T_s)`` for audio).
        * ``'responses'`` — list of length ``N_selected``; each entry is a
          ``(R_{s,n}, T_s)`` spike-count tensor or a ``(1, 1)`` NaN sentinel.
        * ``'valid_mask'`` — per-neuron ``(N_selected,)`` bool tensor. It is
          not read here; the batch-level mask returned below replaces it.
        * ``'stim_meta'`` — per-stim metadata dict.

        Any other key (e.g. ``'behav'``) is forwarded untouched as a
        length-``B`` list.

    Returns
    -------
    dict
        * ``'stims'`` — ``(B, ..., T_stim_max)`` tensor, zero-padded along
          the last axis.
        * ``'responses'`` — ``(B, N_selected, R_max, T_resp_max)`` tensor,
          NaN wherever no data was written (padding along the repeat and
          time axes, or a sentinel-only cell). The response time axis is
          sized from the responses themselves, independently of the stim
          time axis (the two differ in waveform mode, where stims run at
          the audio rate and responses at the neural rate).
        * ``'valid_mask'`` — bool tensor equal to ``~responses.isnan()``.
        * ``'stim_meta'`` — length-``B`` list of the per-item metadata dicts.
        * every extra per-item key — a length-``B`` list.
    """
    item_stims = [item['stims'] for item in batch]
    item_resps = [item['responses'] for item in batch]
    item_metas = [item['stim_meta'] for item in batch]

    B = len(item_stims)
    N = len(item_resps[0])

    # Zero-pad the stims along their last (time) axis only.
    stims = fill_missing_data(item_stims, dims=-1, value=0.0)

    # Size the response grid from the responses alone: each of the repeat
    # and time axes is padded to its own maximum over the whole batch.
    resp_shapes = [r.shape for resps in item_resps for r in resps]
    T_resp_max = max((shape[1] for shape in resp_shapes), default=0)
    R_max = max((shape[0] for shape in resp_shapes), default=0)

    # Start from an all-NaN (B, N, R, T_resp) grid and write each response
    # into its top-left corner; whatever is not overwritten stays NaN.
    responses = torch.full(
        (B, N, R_max, T_resp_max),
        float('nan'),
        dtype=stims.dtype,
        device=stims.device,
    )
    for b, resps in enumerate(item_resps):
        for n in range(N):
            r = resps[n]
            n_rep, n_bins = r.shape[0], r.shape[1]
            responses[b, n, :n_rep, :n_bins] = r

    out = {
        'stims': stims,
        'responses': responses,
        # fine-grained validity, computed once here for the training loop
        'valid_mask': ~responses.isnan(),
        'stim_meta': list(item_metas),
    }

    # Forward any additional per-item keys (e.g. 'behav') as plain lists so
    # datasets can attach covariates without changing this collate function.
    handled = {'stims', 'responses', 'valid_mask', 'stim_meta'}
    out.update({
        key: [item[key] for item in batch]
        for key in batch[0]
        if key not in handled
    })
    return out
[docs]
def concat_neural_datasets(datasets: Sequence[NeuralDataset],
                           names: Optional[Sequence[str]] = None,
                           ) -> NeuralDataset:
    """Pool several neural datasets along BOTH the stim and the neuron axis.

    For ``k`` inputs holding ``(S_i, N_i)`` stimuli / neurons, the result has
    ``S = sum(S_i)`` stimuli and ``N = sum(N_i)`` neurons. Its response grid
    is block-diagonal: a cell holds real data only when the stimulus and the
    neuron come from the same source dataset; every other cell is a
    ``(1, 1)`` NaN sentinel.

    Typical use: building "chimeric" datasets that pool recordings across
    species / labs / preparations (e.g. CRCNS AA1 + AA2 + NS1 for auditory),
    so that one model can be fit to the union.

    Parameters
    ----------
    datasets : sequence of NeuralDataset
        The datasets to pool. Every entry after the first is checked against
        the first with ``_concat_check_compat``; any alignment of bin width
        or stimulus dimensions has to be done by the caller beforehand.
        A single-element sequence is returned as-is (see Notes).
    names : sequence of str, optional
        One provenance label per input dataset, stored under the
        ``"dataset"`` key of every ``stim_meta`` and ``nrn_meta`` entry of
        the output. Defaults to each input's class name
        (``type(d).__name__``). Pass explicit labels to tell apart two
        instances of the same class or to get shorter tags. An existing
        ``"dataset"`` entry in the input metadata is overwritten in the
        output copy.

    Returns
    -------
    NeuralDataset
        A new instance. Its class is the shared class when all inputs have
        the same type; otherwise it is the first class in the first input's
        MRO (``object`` excluded) that every input is an instance of,
        falling back to ``NeuralDataset``.

    Raises
    ------
    AssertionError
        If ``datasets`` is empty, if an entry after the first is not a
        ``NeuralDataset``, or if ``names`` does not have one entry per
        dataset.

    Notes
    -----
    The output is created with ``__new__`` and initialised through
    ``NeuralDataset.__init__`` rather than the subclass constructor, so the
    subclass's own loading logic is not run again.

    Concatenation is eager: the output owns a full ``S x N`` list-of-lists
    of response references. All off-block cells share one ``(1, 1)`` NaN
    tensor object.

    Metadata dicts on the output are shallow copies of the inputs'; the
    ``"dataset"`` tag is written to the copies only. Values inside those
    dicts are shared with the sources, not deep-copied.

    Uniqueness of neuron / stim identifiers across inputs is not checked.

    With ``len(datasets) == 1`` the input object itself is returned, with
    no ``"dataset"`` tag applied.
    """
    assert len(datasets) >= 1, "concat_neural_datasets needs at least one dataset"
    if len(datasets) == 1:
        return datasets[0]

    # Every later dataset must be a NeuralDataset compatible with the first.
    first = datasets[0]
    for candidate in datasets[1:]:
        assert isinstance(candidate, NeuralDataset), \
            f"All entries must be NeuralDataset instances (got {type(candidate).__name__})"
        first._concat_check_compat(candidate)

    # One provenance label per source dataset.
    names = [type(d).__name__ for d in datasets] if names is None else list(names)
    assert len(names) == len(datasets), (
        f"names must have one entry per dataset "
        f"(got {len(names)} for {len(datasets)} datasets)"
    )

    # Output class: the common type, else the most specific shared ancestor
    # found along the first input's MRO (never plain ``object``).
    source_types = [type(d) for d in datasets]
    if len(set(source_types)) == 1:
        out_cls = source_types[0]
    else:
        out_cls = next(
            (cls for cls in source_types[0].__mro__
             if cls is not object and all(isinstance(d, cls) for d in datasets)),
            NeuralDataset,
        )

    # Running neuron offsets: offsets[k] is the first output column owned by
    # dataset k, and offsets[-1] is the total neuron count.
    offsets = [0]
    running = 0
    for d in datasets:
        running += d.N_neurons
        offsets.append(running)
    total_N = offsets[-1]

    # Bare instance + base-class init only; remaining attributes are set below.
    out = out_cls.__new__(out_cls)
    NeuralDataset.__init__(out, path="+".join(d.path for d in datasets), dt_ms=first.dt)
    out._concat_copy_attrs(first)

    def _tagged(attr):
        # Shallow-copy each metadata dict and stamp it with its source label,
        # leaving the input datasets' dicts untouched.
        return [{**meta, "dataset": label}
                for d, label in zip(datasets, names)
                for meta in getattr(d, attr)]

    pooled_stims = []
    for d in datasets:
        pooled_stims.extend(d.stims)
    out.stims = pooled_stims
    out.stim_meta = _tagged("stim_meta")
    out.nrn_meta = _tagged("nrn_meta")
    out.N_neurons = total_N

    # Block-diagonal response grid: each stim row has length total_N, with
    # real responses in its own dataset's column range and the shared NaN
    # sentinel everywhere else.
    nan = torch.full((1, 1), float('nan'))
    grid: List[list] = []
    for d, start, stop in zip(datasets, offsets, offsets[1:]):
        left = [nan] * start
        right = [nan] * (total_N - stop)
        grid.extend(
            left + list(d.responses[s]) + right
            for s in range(len(d.stim_meta))
        )
    out.responses = grid

    out.validate()
    return out
[docs]
def concatenate_datasets(ds1: NeuralDataset, ds2: NeuralDataset) -> NeuralDataset:
    """Deprecated two-dataset alias of ``concat_neural_datasets``.

    Parameters
    ----------
    ds1, ds2 : NeuralDataset
        The two datasets to pool, in this order.

    Returns
    -------
    NeuralDataset
        Whatever ``concat_neural_datasets([ds1, ds2])`` returns.

    Warns
    -----
    DeprecationWarning
        On every call; use ``concat_neural_datasets([ds1, ds2])`` instead.
    """
    # Imported locally so this block does not depend on the module header.
    import warnings

    # stacklevel=2 attributes the warning to the caller's line, not this one.
    warnings.warn(
        "concatenate_datasets is deprecated; "
        "use concat_neural_datasets([ds1, ds2]) instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return concat_neural_datasets([ds1, ds2])
[docs]
def fill_missing_data(stims: Sequence[torch.Tensor],
                      dims: Union[int, Sequence[int]],
                      value: float = 0.0
                      ) -> torch.Tensor:
    """Stack tensors into one batch, padding selected axes to their maxima.

    Parameters
    ----------
    stims : sequence of torch.Tensor
        ``S`` tensors, each of shape ``(D0, D1, ..., Dk-1)``. Shapes may
        differ along the axes listed in ``dims`` but must agree on all
        others. All tensors must share ``ndim``, dtype and device.
    dims : int or sequence of int
        Axis index or indices (negatives allowed) along which to pad. They
        refer to the input tensors' own 0-based axes, not to the output's.
    value : float, default 0.0
        Fill value used for the padding.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(S, D0', D1', ..., Dk-1')`` where, for each ``d``
        in ``dims``, ``Dd' = max_i stims[i].shape[d]`` and, for every other
        axis, ``Dd' = stims[0].shape[d]``. Each input occupies the leading
        corner of its slot; the remainder holds ``value``.

    Raises
    ------
    ValueError
        If ``stims`` is empty, or the tensors disagree on ``ndim``, dtype,
        device, or the size of an axis that is not being padded.
    TypeError
        If any element of ``stims`` is not a ``torch.Tensor``.
    IndexError
        If an entry of ``dims`` is out of range for the tensors' ``ndim``.
    """
    if not stims:
        raise ValueError("`stims` must be a non-empty sequence of tensors")

    # Type-check first: everything below reads tensor attributes (.ndim,
    # .dtype, ...), which would otherwise fail with an unrelated
    # AttributeError on non-tensor input.
    for t in stims:
        if not isinstance(t, torch.Tensor):
            raise TypeError("All elements of `stims` must be torch.Tensor")

    first = stims[0]
    k = first.ndim
    dtype, device = first.dtype, first.device

    # Normalize dims to a sorted list of unique non-negative axis indices.
    if isinstance(dims, int):
        dims = [dims]
    dims = sorted({d if d >= 0 else d + k for d in dims})
    for d in dims:
        if not (0 <= d < k):
            raise IndexError(f"Dimension {d} is out of bounds for tensors of ndim={k}")
    pad_axes = set(dims)

    for t in stims:
        if t.ndim != k:
            raise ValueError(f"All tensors must have the same number of dims; got {t.ndim} vs {k}")
        if t.dtype != dtype or t.device != device:
            raise ValueError("All tensors must share the same dtype and device")

    # Output size per axis: the max over inputs on padded axes, the common
    # size (verified) on all other axes.
    out_shape = []
    for axis in range(k):
        sizes = [t.shape[axis] for t in stims]
        if axis in pad_axes:
            out_shape.append(max(sizes))
            continue
        base = first.shape[axis]
        for size in sizes:
            if size != base:
                raise ValueError(
                    f"Dimension {axis} mismatch: got {size} vs {base}"
                )
        out_shape.append(base)

    # Preallocate (S, *out_shape) filled with the padding value, then copy
    # each tensor into the leading corner of its slot. Input axis j maps to
    # output axis j + 1; on non-padded axes the slice spans the full extent.
    out = torch.full((len(stims), *out_shape), fill_value=value, dtype=dtype, device=device)
    for i, t in enumerate(stims):
        out[(i, *(slice(0, n) for n in t.shape))] = t
    return out
if __name__ == "__main__":
    # Manual smoke test: load two CRCNS datasets from their relative data
    # directories and pool them into a single dataset.
    from deepSTRF.datasets import CRCNSAA1Dataset
    from deepSTRF.datasets import CRCNSAA2Dataset

    # NOTE(review): `stimuli` was previously the single string
    # 'conspecific, flatrip' (parentheses without a separating quote pair do
    # not make a tuple); it is now a 2-tuple like the AA2 call below —
    # confirm 'flatrip' is a stimulus name CRCNSAA1Dataset accepts.
    aa1 = CRCNSAA1Dataset('../deepSTRF/datasets/audio/CRCNS_AA1/data/',
                          areas=('MLd', 'Field_L'),
                          stimuli=('conspecific', 'flatrip'))
    aa2 = CRCNSAA2Dataset('../deepSTRF/datasets/audio/CRCNS_AA2/data/',
                          areas=('MLd', 'OV', 'CM'),
                          stimuli=('conspecific', 'songrip'))

    # Use the current API directly rather than the deprecated
    # `concatenate_datasets` wrapper.
    aa12 = concat_neural_datasets([aa1, aa2])