deepSTRF.utils package
Submodules
deepSTRF.utils.data module
- class deepSTRF.utils.data.ResponseSmoothingTransform(dt_ms=1, window_size_ms=21, *args, **kwargs)[source]
Bases: Module
Temporally smooth responses with a Hanning window (typically ~20-40 ms).
- Parameters:
dt_ms (float, default 1) – Time-bin width of the responses, in ms.
window_size_ms (float, default 21) – Full width of the Hanning window in ms (rounded to an odd number of dt_ms bins).
References
Hsu, A., Borst, A., & Theunissen, F. E. (2004). Quantifying variability in neural responses and its application for the validation of model predictions. Network: Computation in Neural Systems, 15(2), 91-109. https://doi.org/10.1088/0954-898X_15_2_002
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(responses, dt=1)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- deepSTRF.utils.data.concat_neural_datasets(datasets: Sequence[NeuralDataset], names: Sequence[str] | None = None) NeuralDataset[source]
Concatenate neural datasets along BOTH the stim and neuron axes.
Given k datasets with (S_i, N_i) stimuli and neurons each, returns a single dataset with S = sum(S_i) stimuli and N = sum(N_i) neurons. The response grid is block-diagonal: real data where the stimulus and the neuron belong to the same source dataset, and (1, 1) NaN sentinels everywhere else. This cross-block missingness is paradigm-compliant: nrn_masks (the derived property) then reflects the block-diagonal coverage automatically.
Primary use case: building “chimeric” datasets that pool recordings across species / labs / preparations (e.g. CRCNS AA1 + AA2 + NS1 for auditory), so that a single model can be fit to the union.
- Parameters:
datasets (sequence of NeuralDataset) – Two or more instances. They must be of compatible types and share dt (bin width) and any modality-specific dimensions (F for audio, (H, W) for video). Compatibility is checked by each class’s _concat_check_compat hook; mismatches raise AssertionError. Resampling to align dt or F is the caller’s responsibility and must be done before concatenation.
names (sequence of str, optional) – One label per input dataset, written into stim_meta["dataset"] and nrn_meta["dataset"] on the output as a provenance tag. Defaults to [type(d).__name__ for d in datasets], i.e. the class name ("CRCNSAA1Dataset" etc.). Pass explicit names to disambiguate two instances of the same class, or to use a shorter human-readable label.
The tags enable post-hoc selection by source dataset via NeuralDataset.select_pop_by_nrn_attr() / select_stims_by_attr() (e.g. c.select_pop_by_nrn_attr("dataset", "CRCNSAA1Dataset")). Existing "dataset" entries in the input metadata are overwritten; nested-concat callers wanting to preserve inner provenance should pass names= explicitly.
- Returns:
A fresh instance. Its concrete type is the most-specific class that is a superclass of every input (type(datasets[0]) when all inputs share a type, otherwise found by walking the MRO). Neuron selection is reset (self.I = []).
- Return type:
NeuralDataset
Notes
Concatenation is eager: the output holds its own full (S, N) grid of response references in memory. At deepSTRF scales (S, N in the low hundreds) this is negligible; cross-block entries are single-element (1, 1) NaN tensors that cost ~8 bytes each. A lazy wrapper-class alternative exists but would complicate self.responses[s][n] access for uncertain benefit at this scale.
Metadata dicts on the output are shallow copies of the inputs’ dicts (the "dataset" tag is written into the copies, never into the sources). Tensors and other shared values inside those dicts are not deep-copied, so mutate at your own risk.
Neuron / stim UID uniqueness across inputs is not validated: deepSTRF trusts the caller to pass mutually exclusive sources, since that is the only semantically meaningful case (pooling a dataset’s subset with its superset is degenerate; use constructor arguments instead).
The single-dataset case (len(datasets) == 1) returns the input unchanged, with no "dataset" tagging applied, since provenance only becomes meaningful once there is more than one source.
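The block-diagonal layout described above can be illustrated with a toy sketch. It uses nested Python lists and a float NaN in place of the (1, 1) NaN sentinel tensors; block_diagonal_grid is a hypothetical helper written for illustration, not the library's implementation.

```python
import math

NAN = float("nan")  # stand-in for the (1, 1) NaN sentinel tensor


def block_diagonal_grid(blocks):
    """Merge per-dataset response grids into one block-diagonal grid.

    Each block is a list of S_i rows holding N_i entries each (toy data).
    """
    total_neurons = sum(len(block[0]) for block in blocks)
    grid = []
    col_offset = 0
    for block in blocks:
        n_i = len(block[0])
        for row in block:
            out_row = [NAN] * total_neurons          # cross-block entries
            out_row[col_offset:col_offset + n_i] = row  # same-source entries
            grid.append(out_row)
        col_offset += n_i
    return grid


ds_a = [[1.0, 2.0], [3.0, 4.0]]  # S=2 stimuli, N=2 neurons
ds_b = [[5.0]]                   # S=1 stimulus, N=1 neuron
grid = block_diagonal_grid([ds_a, ds_b])

# Coverage mask: True only where stimulus and neuron share a source.
coverage = [[not math.isnan(v) for v in row] for row in grid]
```

The resulting grid has S = 3 rows and N = 3 columns, and the coverage mask is True only inside the two diagonal blocks.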
- deepSTRF.utils.data.concatenate_datasets(ds1: NeuralDataset, ds2: NeuralDataset) NeuralDataset[source]
Deprecated: use concat_neural_datasets([ds1, ds2]) instead.
- deepSTRF.utils.data.fill_missing_data(stims: Sequence[Tensor], dims: int | Sequence[int], value: float = 0.0) Tensor[source]
Pad a list of tensors along one or more dimensions to match their maxima.
- Parameters:
stims (sequence of torch.Tensor) – S tensors, each of shape (D0, D1, ..., Dk-1). Shapes may differ at the dimensions in dims but must agree on all others.
dims (int or sequence of int) – Dimension index or indices (negatives allowed) along which to pad. These refer to the tensors’ 0-based axes.
value (float, default 0.0) – Fill value for padding.
- Returns:
A tensor of shape (S, D0', D1', ..., Dk-1') where, for each d in dims, Dd' = max_i stims[i].shape[d] and, for other axes, Dd' = stims[0].shape[d].
- Return type:
torch.Tensor
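As a rough illustration of the padding rule, the sketch below pads 2-D nested lists of shape (F, T_i) along their last axis only and stacks them. pad_last_axis is a hypothetical, dependency-free stand-in; the real function operates on torch tensors and supports several dims.

```python
def pad_last_axis(items, value=0.0):
    """Pad nested lists of shape (F, T_i) to a common T, then stack them."""
    t_max = max(len(row) for item in items for row in item)
    return [
        [row + [value] * (t_max - len(row)) for row in item]
        for item in items
    ]


stims = [
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # shape (2, 3)
    [[7.0], [8.0]],                      # shape (2, 1)
]
padded = pad_last_axis(stims)  # shape (S=2, F=2, T_max=3)
```

The shorter item is right-padded with the fill value, while the item that already has the maximal length is left unchanged.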
- deepSTRF.utils.data.hanning_smooth(response: Tensor, window_ms: float, dt_ms: float) Tensor[source]
Convolve response along its last (time) axis with a Hanning window.
- Parameters:
response (torch.Tensor) – Response tensor of any shape; the last axis is assumed to be time.
window_ms (float) – Full width of the Hanning window in ms. Converted to an odd number of dt_ms bins: window_ms / dt_ms is floored, then incremented by 1 if the result is even.
dt_ms (float) – Time-bin width of response, in ms.
- Returns:
Smoothed response, same shape as input.
- Return type:
torch.Tensor
Notes
Padded with zeros on both sides (F.pad(..., mode='constant')), so edge bins get attenuated. The kernel is the raw np.hanning(K), i.e. NOT sum-normalized, which matches the legacy behaviour used by the Hsu / Borst / Theunissen (2004) PSTH smoothing step in the CRCNS-AA datasets.
NaN-unsafe: NaN values propagate to neighbouring time bins under the window. Callers (e.g. NeuralDataset.smooth_responses) must filter fully-NaN responses before calling.
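The rounding rule, the zero padding and the un-normalized kernel can be sketched in plain Python. This is a minimal illustration on a 1-D list, assuming the kernel length is floor(window_ms / dt_ms) bumped to the next odd number; odd_bin_count, hanning and smooth are hypothetical names, and the real function works on torch tensors.

```python
import math


def odd_bin_count(window_ms, dt_ms):
    """Floor window_ms / dt_ms to whole bins, then add 1 if the count is even."""
    k = int(math.floor(window_ms / dt_ms))
    return k + 1 if k % 2 == 0 else k


def hanning(k):
    """Symmetric Hann window with zero endpoints (same formula as np.hanning)."""
    if k == 1:
        return [1.0]
    return [0.5 - 0.5 * math.cos(2.0 * math.pi * n / (k - 1)) for n in range(k)]


def smooth(response, window_ms, dt_ms):
    """Same-length convolution with zero padding and an un-normalized kernel."""
    kernel = hanning(odd_bin_count(window_ms, dt_ms))
    half = len(kernel) // 2
    padded = [0.0] * half + list(response) + [0.0] * half
    return [
        sum(w * padded[t + j] for j, w in enumerate(kernel))
        for t in range(len(response))
    ]


flat = smooth([1.0] * 5, window_ms=5, dt_ms=1)
```

For a constant input of ones and a 5-bin window, interior bins come out close to 2.0 (the sum of the kernel, since it is not normalized) while the first and last bins come out close to 1.5, which is the edge attenuation caused by the zero padding.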
- deepSTRF.utils.data.neural_collate(batch)[source]
Collate fn for any NeuralDataset.
Pads variable-duration stims with zeros along the last (time) axis and variable-duration / variable-repeat-count responses with NaN along both the repeat and the time axes. Derives a fine-grained valid_mask from the NaN sentinels so downstream loss code can use boolean indexing or multiplicative masking without re-scanning.
- Parameters:
batch (list of dict) – Each dict is one item as yielded by NeuralDataset.__getitem__, with keys:
  - 'stims': a stim tensor of shape (..., T_s) (modality-specific leading dims, e.g. (1, F, T_s) for audio).
  - 'responses': list of length N_selected; each element is a (R_{s,n}, T_s) spike-count tensor or a (1, 1) NaN sentinel.
  - 'valid_mask': (N_selected,) per-neuron bool tensor (ignored here; the fine-grained batch 'valid_mask' below subsumes it).
  - 'stim_meta': per-stim metadata dict.
Extra keys (e.g. 'behav') are passed through untouched: any key not handled explicitly is collected into a length-B list.
- Returns:
A dict with keys:
  - 'stims': (B, ..., T_stim_max) float tensor, zero-padded along the last axis. Contains no NaN.
  - 'responses': (B, N_selected, R_max, T_resp_max) float tensor. NaN-padded along both the repeat (R) and time (T) axes. Fully-NaN slabs mark (stim, neuron) pairs with no recorded data. The response-time axis is sized to T_resp_max independently of the stim-time axis: in spectrogram mode the two are equal (one bin per neural sample), but in waveform mode the stim axis runs at audio_fs Hz while responses stay at the dataset’s neural dt_ms rate.
  - 'valid_mask': (B, N_selected, R_max, T_resp_max) bool tensor, ~responses.isnan(), cached here so downstream loss code does not have to recompute.
  - 'stim_meta': length-B list of the per-item stim_meta dicts.
  - any extra per-item keys: length-B lists, passed through.
- Return type:
dict
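The NaN-padding and mask derivation can be sketched on toy data. The snippet below uses nested lists rather than tensors and handles only the response side; pad_response and collate_responses are hypothetical helpers, not the library's code.

```python
import math

NAN = float("nan")


def pad_response(resp, r_max, t_max):
    """NaN-pad one (R, T) nested list to (r_max, t_max)."""
    rows = [list(row) + [NAN] * (t_max - len(row)) for row in resp]
    rows += [[NAN] * t_max for _ in range(r_max - len(rows))]
    return rows


def collate_responses(responses):
    """Pad a list of (R_i, T_i) responses and derive the matching valid mask."""
    r_max = max(len(r) for r in responses)
    t_max = max(len(row) for r in responses for row in r)
    padded = [pad_response(r, r_max, t_max) for r in responses]
    # Equivalent of ~responses.isnan() on the padded grid.
    mask = [[[not math.isnan(v) for v in row] for row in r] for r in padded]
    return padded, mask


responses = [
    [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]],  # 2 repeats, 3 time bins
    [[3.0, 1.0]],                        # 1 repeat, 2 time bins
    [[NAN]],                             # (1, 1) sentinel: no recording
]
padded, valid_mask = collate_responses(responses)
```

A loss function can then index or multiply by valid_mask so that padded bins and unrecorded (stim, neuron) pairs contribute nothing.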
deepSTRF.utils.data_download module
Auto-download utilities for deepSTRF datasets.
- Public surface:
default_cache_dir(dataset_name) -> Path: platformdirs-based default
stream_download(url, dest_path): resumable streaming download
unzip(zip_path, dest_dir): flat unzip with overwrite
untar(tar_path, dest_dir, strip_components=): tar.gz / .tar / .tar.bz2 unpack
osf_download(guid, dest): public OSF storage files
github_raw_download(repo, path, dest, ref=): public GitHub raw files
zenodo_download(record_id, filename, dest): public Zenodo records
figshare_download(article_id, dest_dir, filename=): public figshare articles
crcns_download(file_path, dest, username=, password=): CRCNS (free account)
- deepSTRF.utils.data_download.crcns_download(file_path: str, dest_path: str | Path, *, username: str | None = None, password: str | None = None, chunk_size: int = 1048576, progress: bool = True) Path[source]
Download a single file from the CRCNS NERSC mirror with form auth.
The CRCNS download portal at https://portal.nersc.gov/project/crcns/download/<file_path> serves an HTML login form to anonymous GETs. To actually fetch the file, the form must be POSTed to the same URL with username / password / fn / submit fields. There is no persistent session cookie: auth is per-request, so the same pattern works equally well whether you fetch one file or many.
- Parameters:
file_path (str) – Path under /download/, e.g. "aa-1/crcns-aa1.zip" or "aa-4/BlaBro09xxF.tar.gz".
dest_path (path-like)
username, password (str, optional) – Default to $CRCNS_USERNAME / $CRCNS_PASSWORD. Account is free at https://crcns.org/register.
chunk_size, progress – As in stream_download.
- Returns:
Resolved destination.
- Return type:
Path
- Raises:
RuntimeError – If credentials are missing, or if the response body still looks like the login form (auth failed silently: the portal returns 200 OK with the login HTML rather than 401 on bad credentials).
Notes
Status: experimental. The auth + URL conventions were reverse-engineered from probing the public NERSC mirror; we do not have a contract from CRCNS that they’ll stay stable. If the portal layout changes, this helper breaks. Verified 2026-04-25 against the AA1 archive (aa-1/crcns-aa1.zip).
Example
>>> import os
>>> os.environ["CRCNS_USERNAME"] = "..."
>>> os.environ["CRCNS_PASSWORD"] = "..."
>>> crcns_download("aa-1/crcns-aa1.zip", "/tmp/aa1.zip")
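The credential fallback described above might look roughly like the following. This is a sketch of the documented behaviour only (explicit arguments first, then the two environment variables, RuntimeError when either is missing); resolve_crcns_credentials and its env parameter are hypothetical and exist here to keep the example testable without touching the real environment.

```python
import os


def resolve_crcns_credentials(username=None, password=None, env=None):
    """Return (username, password), falling back to the CRCNS_* variables."""
    env = os.environ if env is None else env
    username = username or env.get("CRCNS_USERNAME")
    password = password or env.get("CRCNS_PASSWORD")
    if not username or not password:
        raise RuntimeError(
            "CRCNS credentials missing: pass username=/password= or set "
            "CRCNS_USERNAME and CRCNS_PASSWORD"
        )
    return username, password


# Explicit mapping instead of os.environ, so nothing global is modified.
creds = resolve_crcns_credentials(
    env={"CRCNS_USERNAME": "alice", "CRCNS_PASSWORD": "s3cret"}
)
```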
- deepSTRF.utils.data_download.default_cache_dir(dataset_name: str) Path[source]
Return $DEEPSTRF_DATA_DIR/<dataset> if the env var is set, otherwise platformdirs.user_cache_dir('deepSTRF') / <dataset>.
The env-var override is the standard escape hatch for users on shared storage / scratch filesystems, and matches the convention in torchvision / huggingface_hub.
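A minimal sketch of the lookup order, assuming the environment variable takes precedence whenever it is non-empty. To stay dependency-free, the fallback uses ~/.cache/deepSTRF as a stand-in for platformdirs.user_cache_dir('deepSTRF'), which is an assumption and will differ by platform; cache_dir_for, env and fallback_root are hypothetical names.

```python
import os
from pathlib import Path


def cache_dir_for(dataset_name, env=None, fallback_root=None):
    """Resolve a per-dataset cache folder with an env-var override."""
    env = os.environ if env is None else env
    override = env.get("DEEPSTRF_DATA_DIR")
    if override:
        return Path(override) / dataset_name
    # Assumption: stand-in for platformdirs.user_cache_dir('deepSTRF').
    root = Path(fallback_root) if fallback_root else Path.home() / ".cache" / "deepSTRF"
    return root / dataset_name


on_scratch = cache_dir_for("ns1", env={"DEEPSTRF_DATA_DIR": "/scratch/data"})
default = cache_dir_for("ns1", env={}, fallback_root="/tmp/deepstrf-cache")
```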
- deepSTRF.utils.data_download.figshare_download(article_id, dest_dir, filename=None)
Download one file from a public figshare article.
Resolves the article’s file list via the public REST API (https://api.figshare.com/v2/articles/<id>) and streams the matching file into dest_dir. With filename=None, the article must contain exactly one file; pass an explicit name to disambiguate when there are several.
- Parameters:
article_id (int | str) – Numeric figshare article id (the trailing component of the DOI 10.6084/m9.figshare.<id>, e.g. 29203457).
dest_dir (path-like) – Directory the file is downloaded into. Created if missing.
filename (str, optional) – Name of the file to fetch. Required when the article carries more than one file. Matched case-sensitively against the file’s name field returned by the API.
- Returns:
Path to the downloaded file under dest_dir.
- Return type:
Path
Example
>>> figshare_download(29203457, "/tmp/le2025")
PosixPath('/tmp/le2025/zebf-auditory-restoration-1.zip')
- deepSTRF.utils.data_download.github_raw_download(repo: str, path_in_repo: str, dest_path: str | Path, *, ref: str = 'HEAD', **kwargs) Path[source]
Download a file from a GitHub repo’s raw content.
Useful for paper-companion repos that publish small datasets / model artefacts alongside the code (e.g. DNet hosts test_data_5ms.mat for the Rahman et al. 2018 NS1 reanalysis at https://github.com/monzilur/DNet).
- Parameters:
repo (str) – "<owner>/<name>", e.g. "monzilur/DNet".
path_in_repo (str) – Path of the file within the repo, e.g. "test_data_5ms.mat".
dest_path (path-like)
ref (str, default "HEAD") – Branch / tag / commit. "HEAD" resolves the default branch.
Notes
Uses the raw.githubusercontent.com CDN, which has no rate limit for anonymous reads (unlike the GitHub REST API).
- deepSTRF.utils.data_download.osf_download(file_guid: str, dest_path: str | Path, **kwargs) Path[source]
Download a single file from OSF by its short GUID.
Resolves to https://osf.io/download/<guid>/, which works for any public OSF storage file (the OSF API exposes this URL as the download link in each file’s metadata).
Example
>>> osf_download("gdwyd", "MetadataSHEnCneurons.mat")
- deepSTRF.utils.data_download.stream_download(url: str, dest_path: str | Path, *, chunk_size: int = 1048576, progress: bool = True) Path[source]
Stream-download url to dest_path. Atomic via a .part swap.
Already-existing destination paths are returned unchanged (no-op); the caller is responsible for cache invalidation.
- Parameters:
url (str)
dest_path (str | Path)
chunk_size (int, default 1 MiB)
progress (bool, default True) – Show a tqdm progress bar if available; falls back silently otherwise.
- Returns:
The destination path (resolved).
- Return type:
Path
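The ".part swap" idea can be sketched without any network access: write to a sibling temporary file, then rename it over the destination in one step. The helper below, write_atomically, is hypothetical and takes an iterable of byte chunks in place of an HTTP response; it illustrates the pattern and is not the library's code.

```python
import os
import tempfile
from pathlib import Path


def write_atomically(chunks, dest_path):
    """Write chunks to dest_path via a sibling .part file and a final rename."""
    dest = Path(dest_path)
    if dest.exists():
        return dest.resolve()  # no-op: an existing file counts as a cache hit
    dest.parent.mkdir(parents=True, exist_ok=True)
    part = dest.with_name(dest.name + ".part")
    with open(part, "wb") as fh:
        for chunk in chunks:
            fh.write(chunk)
    os.replace(part, dest)  # atomic rename within one filesystem
    return dest.resolve()


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "data" / "file.bin"
    first = write_atomically([b"abc", b"def"], target)
    first_bytes = first.read_bytes()
    second = write_atomically([b"ignored"], target)  # destination exists
    second_bytes = second.read_bytes()
    leftovers = sorted(p.name for p in target.parent.iterdir())
```

Because the final file only appears after the rename, an interrupted transfer leaves at most a .part file behind and never a truncated destination.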
- deepSTRF.utils.data_download.untar(tar_path: str | Path, dest_dir: str | Path, *, strip_components: int = 0) Path[source]
Extract a tar / tar.gz / tar.bz2 archive into dest_dir.
Mirrors GNU tar --strip-components=N: drops the first N path components from every member. Useful when an archive wraps everything in nested directories that aren’t part of the dataset’s own layout, e.g. CRCNS-AA2 archives all wrap content in crcns/aa2/ (strip 2).
- Parameters:
tar_path (path-like)
dest_dir (path-like)
strip_components (int, default 0) – How many leading path components to drop. Members that have fewer components than this are silently skipped.
- Returns:
The destination directory.
- Return type:
Path
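The path rewriting behind strip_components can be sketched as a pure function over member names. The helper below is hypothetical; it also skips a member with exactly N components, since stripping would leave an empty path, which is an assumption about the edge case rather than documented behaviour.

```python
from pathlib import PurePosixPath


def strip_components(member_name, n):
    """Drop the first n path components, or return None to skip the member."""
    parts = PurePosixPath(member_name).parts
    if len(parts) <= n:
        # Too shallow: nothing would remain after stripping (assumed skip).
        return None
    return str(PurePosixPath(*parts[n:]))


kept = strip_components("crcns/aa2/cell1/spikes.txt", 2)
skipped = strip_components("crcns", 2)
unchanged = strip_components("a/b.txt", 0)
```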
- deepSTRF.utils.data_download.unzip(zip_path: str | Path, dest_dir: str | Path, *, strip_root: bool = False) Path[source]
Unzip zip_path into dest_dir. Idempotent (overwrites existing files).
- Parameters:
zip_path (path-like)
dest_dir (path-like)
strip_root (bool, default False) – If True and the archive contains a single top-level directory, strip it from the extracted layout (so foo/a/b -> a/b). Mirrors the common --strip-components=1 tar idiom.
- Returns:
The destination directory.
- Return type:
Path
- deepSTRF.utils.data_download.zenodo_download(record_id: int | str, filename: str, dest_path: str | Path, **kwargs) Path[source]
Download a single file from a public Zenodo record.
Resolves to https://zenodo.org/api/records/<record_id>/files/<filename>/content, the canonical URL for fetching a file from a Zenodo record. Public records are accessible without auth.
Example
>>> zenodo_download(8044773, "A1_NAT4_ozgf.fs100.ch18.tgz", "/tmp/X.tgz")
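The URL templates used by the hosting-specific helpers in this module can be written out as small builders. The OSF and Zenodo patterns are taken from the descriptions above; the GitHub pattern (owner/name, then ref, then path under raw.githubusercontent.com) is an assumption based on the usual raw-content layout. The function names are hypothetical.

```python
def osf_url(guid):
    """Download URL for a public OSF storage file."""
    return f"https://osf.io/download/{guid}/"


def zenodo_url(record_id, filename):
    """Content URL for one file of a public Zenodo record."""
    return f"https://zenodo.org/api/records/{record_id}/files/{filename}/content"


def github_raw_url(repo, path_in_repo, ref="HEAD"):
    """Raw-content URL; assumes the usual <owner>/<name>/<ref>/<path> layout."""
    return f"https://raw.githubusercontent.com/{repo}/{ref}/{path_in_repo}"


urls = [
    osf_url("gdwyd"),
    zenodo_url(8044773, "A1_NAT4_ozgf.fs100.ch18.tgz"),
    github_raw_url("monzilur/DNet", "test_data_5ms.mat"),
]
```

Each resulting URL could then be handed to a streaming downloader such as stream_download.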
deepSTRF.utils.plotting module
Plotting helpers shared across the example notebooks.
Every function accepts either torch.Tensor or numpy.ndarray and normalises internally, in line with the rest of the deepSTRF public API.
Functions:
plot_stim_with_response(): spectrogram + optional spike raster + PSTH (with optional model prediction overlay), shared x-axis.
plot_psth_vs_pred(): single-panel target-vs-prediction overlay.
plot_strf_grid(): grid of STRF / gradmap kernels.
All return matplotlib objects (Figure and/or Axes); callers decide whether to plt.show(), save, or compose further. plt.show is never invoked inside.
- deepSTRF.utils.plotting.compare_wav2spec_to_groundtruth(ds, wav2spec, stim_idx: int = 0, *, ground_truth_stims: Sequence | None = None, z_score: bool = True, figsize: Tuple[float, float] | None = None, suptitle: str | None = None)[source]
Side-by-side visual comparison of a learned/hand-built wav2spec output against a dataset’s precomputed (ground-truth) spectrogram.
Useful for sanity-checking a new front-end on a dataset that ships both raw waveforms and a precomputed spectrogram (e.g. NS1, where the OSF release has raw wavs and the DNet companion repo provides the matching log-mel X_nfht).
- Parameters:
ds (NeuralDataset) – Dataset instance in waveform mode (ds.stims[s].shape == (1, T_audio)). The dataset’s regular spectrogram is treated as the ground truth, supplied via ground_truth_stims.
wav2spec (nn.Module) – Module to apply. Must accept (B, 1, T_audio) and return (B, 1, F, T_neural), i.e. the wav2spec slot contract.
stim_idx (int, default 0) – Which dataset stim to compare. Indexes ds.stims.
ground_truth_stims (sequence, optional) – Per-stim ground-truth spectrograms (each (F, T) or (1, F, T)). Required despite the None default: the waveform-mode dataset doesn’t carry them itself. For NS1, build them by re-instantiating NS1Dataset() (spec mode) and passing its stims.
z_score (bool, default True) – If True, both spectrograms are independently z-scored (mean 0, std 1) before plotting, so a constant offset / global scale mismatch does not visually dominate the comparison.
- Returns:
pred_spec (numpy.ndarray) – The wav2spec output, shape (F, T).
truth_spec (numpy.ndarray) – The ground-truth spectrogram, shape (F, T).
fig (matplotlib.figure.Figure) – 3-panel side-by-side figure (pred | truth | difference).
- deepSTRF.utils.plotting.plot_psth_vs_pred(target: ArrayLike, pred: ArrayLike, dt_ms: float | None = None, title: str | None = None, target_label: str = 'PSTH (target)', pred_label: str = 'model', ax: Axes | None = None, legend: bool = True) Axes[source]
Overlay a target PSTH and a model prediction on a single panel.
Designed to be called inside a per-cell loop (“best / median / worst”) that pre-builds a column of axes. This is the canonical val/test visualisation in fit_ns1_statenet.ipynb and load_pretrained_statenet_ns1.ipynb.
- Parameters:
target (array-like, shape (T,)) – Trial-averaged target PSTH.
pred (array-like, shape (T,)) – Model prediction.
dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in seconds; otherwise it is the bin index.
title (str, optional) – Axes title.
target_label, pred_label (str) – Legend labels.
ax (matplotlib.axes.Axes, optional) – Pre-made axes to draw into. If None, a new figure is created.
legend (bool, default True) – Whether to render the legend.
- Return type:
matplotlib.axes.Axes
- deepSTRF.utils.plotting.plot_stim_with_response(stim: ArrayLike, response: ArrayLike, pred: ArrayLike | None = None, dt_ms: float | None = None, title: str | None = None, spec_cmap: str = 'magma', raster_cmap: str = 'Greys', axes: Sequence[Axes] | None = None, figsize: Tuple[float, float] | None = None) Tuple[Figure, Sequence[Axes]][source]
Plot a stimulus spectrogram alongside its recorded (and optionally predicted) response.
Builds a vertically stacked figure with a shared time axis. The middle raster panel is omitted automatically when response is 1-D (already a PSTH).
- Parameters:
stim (array-like) – Stimulus spectrogram, shape (F, T) or (1, F, T).
response (array-like) – Per-trial responses (R, T) or pre-averaged PSTH (T,) / (1, T). With R > 1 the panels are spec / raster / PSTH; otherwise spec / PSTH.
pred (array-like, optional) – Model prediction, shape (T,). Overlaid on the PSTH panel.
dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in seconds; otherwise it is the bin index.
title (str, optional) – Suptitle for the whole figure.
spec_cmap, raster_cmap (str) – Colormaps for the spectrogram and raster panels.
axes (sequence of matplotlib.axes.Axes, optional) – Pre-made axes to draw into (2 or 3, matching the panel count). If None, a new figure is created.
figsize ((w, h), optional) – Figure size when axes is None.
- Returns:
fig (matplotlib.figure.Figure)
axes (list of matplotlib.axes.Axes) – [spec_ax, raster_ax, psth_ax] or [spec_ax, psth_ax].
- deepSTRF.utils.plotting.plot_strf_grid(strfs: ArrayLike | Sequence[ArrayLike], titles: Sequence[str] | None = None, dt_ms: float | None = None, ncols: int = 4, cmap: str = 'RdBu_r', shared_clim: bool = False, suptitle: str | None = None, figsize: Tuple[float, float] | None = None) Tuple[Figure, Sequence[Axes]][source]
Plot a grid of STRF / gradmap kernels.
The classical STRF visualisation: each panel is an (F, T) weight map, frequency on the y-axis (low→high), time on the x-axis (history; [0, T·dt_ms] if dt_ms is given, else bin index). Diverging colormap (RdBu_r by default) with per-panel symmetric vmax, i.e. each cell gets its own |max| so the spatial structure is comparable across cells of very different gradient magnitudes. Pass shared_clim=True for a global symmetric vmax if the kernels are intended to be compared on the same scale (e.g. same neuron under different model variants).
- Parameters:
strfs (array-like of shape (K, F, T), or sequence of (F, T) arrays) – Stack of kernels to plot. Numpy ndarrays or torch tensors.
titles (sequence of str, optional) – Length-K list of per-panel titles. None → unlabeled.
dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in ms (history extent [0, T·dt_ms]); otherwise it is the bin index.
ncols (int, default 4) – Number of columns in the grid; rows = ceil(K / ncols).
cmap (str, default "RdBu_r")
shared_clim (bool, default False) – If True, use one global symmetric vmax = max_k |strf_k| across all panels. Otherwise per-panel.
suptitle (str, optional) – Figure-level title.
figsize ((w, h), optional) – Figure size. Default scales with the grid shape.
- Returns:
fig (matplotlib.figure.Figure)
axes (flat list of matplotlib.axes.Axes) – Length K; unused grid cells (when K < nrows·ncols) are hidden via ax.axis('off') and not included in the return.
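The difference between per-panel and shared colour limits, and the grid-shape rule rows = ceil(K / ncols), can be sketched without matplotlib. The helpers below (abs_max, color_limits, grid_shape) are hypothetical and work on nested lists; they only compute the symmetric (vmin, vmax) pairs that would be passed to a plotting call.

```python
import math


def abs_max(kernel):
    """Largest absolute weight in one (F, T) kernel."""
    return max(abs(v) for row in kernel for v in row)


def color_limits(kernels, shared_clim=False):
    """Symmetric (vmin, vmax) per panel, or one global pair if shared_clim."""
    if shared_clim:
        vmax = max(abs_max(k) for k in kernels)
        return [(-vmax, vmax)] * len(kernels)
    return [(-abs_max(k), abs_max(k)) for k in kernels]


def grid_shape(n_panels, ncols=4):
    """Rows and columns of the panel grid."""
    return math.ceil(n_panels / ncols), ncols


kernels = [
    [[0.1, -0.4], [0.2, 0.0]],  # weak kernel
    [[3.0, -1.0], [0.5, 2.0]],  # strong kernel
]
per_panel = color_limits(kernels)
shared = color_limits(kernels, shared_clim=True)
```

With per-panel limits the weak kernel still uses the full colour range, whereas with shared limits it would appear nearly flat next to the strong one.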
deepSTRF.utils.hub module
HuggingFace Hub integration for pretrained deepSTRF models.
A deepSTRF “checkpoint” is a folder with the following files:
config.json: JSON-serialisable __init__ kwargs (auto-captured by NeuralModel) plus a _model_class sentinel for safety checks.
model.safetensors: the model’s state_dict.
README.md: optional model card (free-form markdown shown by HF Hub).
metadata.json: optional user metadata (test metrics, dataset name, training config, …). Preserved through round-trips, never used for instantiation.
This module exposes four helpers covering the local and Hub sides of that layout. The high-level methods on NeuralModel (save_pretrained / from_pretrained / push_to_hub) are thin wrappers over these.
Public surface
save_pretrained_to_dir(): write a model + config to a folder
load_pretrained_from_dir(): instantiate a class + load weights
download_pretrained(): fetch an HF Hub repo to local cache
upload_pretrained(): push a folder to an HF Hub repo
- deepSTRF.utils.hub.download_pretrained(repo_id: str, *, cache_dir: str | Path | None = None, revision: str | None = None, token: str | None = None) Path[source]
Fetch all files of an HF Hub model repo into a local cache folder.
Thin wrapper around huggingface_hub.snapshot_download(). The Hub client handles caching, resumability, and ETag validation, so repeated calls with the same repo_id are essentially free.
- Parameters:
repo_id (str) – E.g. "urancon/deepSTRF-statenet-gru-ns1".
cache_dir (path-like, optional) – Override the HF Hub cache root. Defaults to ~/.cache/huggingface/hub.
revision (str, optional) – Branch, tag, or commit SHA. Defaults to the repo’s default branch.
token (str, optional) – HF auth token for private repos. Public repos work anonymously. For convenience, also picks up $HF_TOKEN automatically via the Hub client.
- Returns:
Folder containing the snapshot.
- Return type:
Path
- deepSTRF.utils.hub.load_pretrained_from_dir(model_class: Type[Module], load_dir: str | Path, *, extra_kwargs: dict | None = None, strict: bool = True, map_location: str | device = 'cpu') Tuple[Module, dict | None][source]
Instantiate model_class from a deepSTRF checkpoint folder.
- Parameters:
model_class (type) – The concrete class to instantiate (e.g. StateNet).
load_dir (path-like) – Folder produced by save_pretrained_to_dir().
extra_kwargs (dict, optional) – Override / extend the JSON-loaded config. Useful for kwargs that couldn’t be JSON-serialised at save time (e.g. prefiltering or output_activation nn.Module instances). Merged on top of the loaded config; None is allowed.
strict (bool, default True) – Passed to load_state_dict. Set to False to tolerate missing / unexpected keys (e.g. when loading a population checkpoint into a single-cell model).
map_location (str or torch.device, default 'cpu') – Device to move tensors to. Mirrors torch.load’s argument.
- Returns:
(model, metadata) – model is an instance of model_class with weights loaded. metadata is the parsed metadata.json if present, else None.
- Return type:
tuple
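The config-handling half of this loader might be sketched as follows, using only the standard library. It assumes that _model_class stores the class name as a string and that a mismatch should raise; both are guesses about the sentinel's exact use. load_config, hidden_size and n_layers are hypothetical names for illustration, and weight loading is left out entirely.

```python
import json
import tempfile
from pathlib import Path


def load_config(load_dir, expected_class_name, extra_kwargs=None):
    """Read config.json, check the class sentinel, merge overrides on top."""
    config = json.loads((Path(load_dir) / "config.json").read_text())
    saved_class = config.pop("_model_class", None)  # not an __init__ kwarg
    if saved_class is not None and saved_class != expected_class_name:
        raise ValueError(
            f"checkpoint was saved from {saved_class}, not {expected_class_name}"
        )
    config.update(extra_kwargs or {})  # extra_kwargs win over saved values
    return config


with tempfile.TemporaryDirectory() as tmp:
    saved = {"_model_class": "StateNet", "hidden_size": 8, "n_layers": 1}
    (Path(tmp) / "config.json").write_text(json.dumps(saved))
    kwargs = load_config(tmp, "StateNet", extra_kwargs={"n_layers": 2})
```

The merged kwargs would then be passed to the model class constructor before the state dict is loaded.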
- deepSTRF.utils.hub.save_pretrained_to_dir(model: Module, save_dir: str | Path, *, metadata: dict | None = None, model_card: str | None = None) Path[source]
Write a deepSTRF checkpoint to save_dir.
The model must carry a _init_kwargs attribute (auto-populated by __init_subclass__). Everything else is optional.
- Parameters:
model (nn.Module) – Model instance with _init_kwargs set.
save_dir (path-like) – Destination folder. Created if missing; existing files are overwritten.
metadata (dict, optional) – Free-form, JSON-serialisable user metadata (test metrics, dataset name, …). Written to metadata.json.
model_card (str, optional) – Markdown content for README.md. If None, a minimal default card is generated from the class name and config.
- Returns:
The resolved save_dir.
- Return type:
Path
- Raises:
AttributeError – If model._init_kwargs is missing, i.e. the class isn’t a NeuralModel subclass, or its __init__ was not run through the auto-capture wrapper.
- deepSTRF.utils.hub.upload_pretrained(local_dir: str | Path, repo_id: str, *, private: bool = False, token: str | None = None, commit_message: str | None = None) str[source]
Push a checkpoint folder to an HF Hub model repo.
Creates the repo on the fly if it doesn’t exist (idempotent, so it is safe to call repeatedly). The user must be authenticated with write access to repo_id (run hf auth login once or pass token=).
- Parameters:
local_dir (path-like) – Folder produced by save_pretrained_to_dir().
repo_id (str) – "<owner>/<name>"; the owner must be your username or an org you can write to.
private (bool, default False) – If creating a new repo, mark it private. Ignored if the repo already exists.
token (str, optional) – HF auth token. Defaults to the cached login token / $HF_TOKEN.
commit_message (str, optional) – Git commit message on the Hub repo. Defaults to a short message with the timestamp.
- Returns:
URL of the resulting commit.
- Return type:
str
Module contents
deepSTRF.utils — cross-cutting helpers.
For dataset-side helpers (concat_neural_datasets, neural_collate,
hanning_smooth, …) see deepSTRF.utils.data. For training
utilities (Fitter, set_random_seed) see deepSTRF.training.
For shared notebook plotting (stim+response panels, PSTH-vs-prediction
overlays) see deepSTRF.utils.plotting.
- class deepSTRF.utils.ResponseSmoothingTransform(dt_ms=1, window_size_ms=21, *args, **kwargs)[source]
Same object as deepSTRF.utils.data.ResponseSmoothingTransform; see its entry under the deepSTRF.utils.data module above for parameters and references.
- deepSTRF.utils.compare_wav2spec_to_groundtruth(ds, wav2spec, stim_idx: int = 0, *, ground_truth_stims: Sequence | None = None, z_score: bool = True, figsize: Tuple[float, float] | None = None, suptitle: str | None = None)[source]
Same function as deepSTRF.utils.plotting.compare_wav2spec_to_groundtruth; see its entry under the deepSTRF.utils.plotting module above for parameters and return values.
- deepSTRF.utils.concat_neural_datasets(datasets: Sequence[NeuralDataset], names: Sequence[str] | None = None) NeuralDataset[source]
Concatenate neural datasets along BOTH the stim and neuron axes.
Given
kdatasets with(S_i, N_i)stimuli and neurons each, returns a single dataset withS = sum(S_i)stimuli andN = sum(N_i)neurons. The response grid is block-diagonal: real data where a stimulus belongs to a given source dataset and the neuron belongs to the same source,(1, 1)NaN sentinels everywhere else. This cross-block missingness is paradigm-compliant —nrn_masks(the derived property) then reflects the block-diagonal coverage automatically.Primary use case: building “chimeric” datasets that pool recordings across species / labs / preparations (e.g. CRCNS AA1 + AA2 + NS1 for auditory), so that a single model can be fit to the union.
- Parameters:
datasets (
sequenceofNeuralDataset) – Two or more instances. They must be of compatible types and sharedt(bin width) and any modality-specific dimensions (Ffor audio,(H, W)for video). Compatibility is checked by each class’s_concat_check_compathook; mismatches raiseAssertionError. Resampling to aligndtorFis the caller’s responsibility and must be done before concatenation.names (
sequenceofstr, optional) –One label per input dataset, written into
stim_meta["dataset"]andnrn_meta["dataset"]on the output as a provenance tag. Defaults to[type(d).__name__ for d in datasets]— i.e. the class name ("CRCNSAA1Dataset"etc.). Pass explicit names to disambiguate two instances of the same class, or to use a shorter human-readable label.The tags enable post-hoc selection by source dataset via
NeuralDataset.select_pop_by_nrn_attr()/select_stims_by_attr()(e.g.c.select_pop_by_nrn_attr("dataset", "CRCNSAA1Dataset")). Existing"dataset"entries in the input metadata are overwritten — nest-concat callers wanting to preserve inner provenance should passnames=explicitly.
- Returns:
A fresh instance. Its concrete type is the most specific class that is a superclass of every input (type(datasets[0]) when all inputs share a type; otherwise found by walking the MRO). Neuron selection is reset (self.I = []).
- Return type:
NeuralDataset
Notes
Concatenation is eager: the output holds its own full (S, N) grid of response references in memory. At deepSTRF scales (S, N in the low hundreds) this is negligible; cross-block entries are single-element (1, 1) NaN tensors that cost ~8 bytes each. A lazy wrapper-class alternative exists but would complicate self.responses[s][n] access for uncertain benefit at this scale.
Metadata dicts on the output are shallow copies of the inputs’ (the "dataset" tag is written into the copies, never into the sources). Tensors and other shared values inside those dicts are not deep-copied, so mutate at your own risk.
Neuron / stim UID uniqueness across inputs is not validated: deepSTRF trusts the caller to pass mutually exclusive sources, since that is the only semantically meaningful case (pooling a dataset’s subset with its superset is degenerate; use constructor arguments instead).
The single-dataset case (len(datasets) == 1) returns the input unchanged, with no "dataset" tagging applied, since provenance only becomes meaningful once there is more than one source.
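The block-diagonal layout described above can be illustrated with a minimal sketch. It uses plain NumPy arrays and nested lists as stand-ins for the library's tensors and datasets; the helper names (nan_sentinel, block_diagonal_grid, coverage) are hypothetical and are not part of deepSTRF.

```python
import numpy as np

def nan_sentinel():
    # (1, 1) NaN placeholder: "no recording for this (stim, neuron) pair".
    return np.full((1, 1), np.nan, dtype=np.float32)

def block_diagonal_grid(grids):
    """Combine per-dataset response grids into one block-diagonal grid.

    grids[i][s][n] is an (R, T) array for stimulus s and neuron n of source i.
    """
    S = sum(len(g) for g in grids)
    N = sum(len(g[0]) for g in grids)
    out = [[nan_sentinel() for _ in range(N)] for _ in range(S)]
    s0 = n0 = 0
    for g in grids:
        for s, row in enumerate(g):
            for n, resp in enumerate(row):
                out[s0 + s][n0 + n] = resp  # reference only, no copy
        s0 += len(g)
        n0 += len(g[0])
    return out

def coverage(grid):
    """Boolean (S, N) array, True where real (non-sentinel) data is present."""
    return np.array([[not np.isnan(r).all() for r in row] for row in grid])

# Source A: 2 stimuli x 1 neuron. Source B: 1 stimulus x 2 neurons.
grid_a = [[np.ones((3, 10), dtype=np.float32)],
          [np.ones((3, 12), dtype=np.float32)]]
grid_b = [[np.zeros((5, 8), dtype=np.float32),
           np.zeros((5, 8), dtype=np.float32)]]

combined = block_diagonal_grid([grid_a, grid_b])
mask = coverage(combined)
```

Here mask is a 3 x 3 array that is True only inside the two diagonal blocks, which is the kind of coverage pattern the nrn_masks property is described as reflecting.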
- deepSTRF.utils.hanning_smooth(response: Tensor, window_ms: float, dt_ms: float) → Tensor
Convolve response along its last (time) axis with a Hanning window.
- Parameters:
response (torch.Tensor) – Response tensor of any shape; the last axis is assumed to be time.
window_ms (float) – Full width of the Hanning window in ms. Rounded to an odd number of dt_ms bins (floored to whole dt_ms bins, then +1 if even).
dt_ms (float) – Time-bin width of response, in ms.
- Returns:
Smoothed response, same shape as input.
- Return type:
torch.Tensor
Notes
Padded with zeros on both sides (F.pad(..., mode='constant')), so edge bins get attenuated. The kernel is the raw np.hanning(K), i.e. NOT sum-normalized, which matches the legacy behaviour used by the Hsu / Borst / Theunissen (2004) PSTH smoothing step in the CRCNS-AA datasets.
NaN-unsafe: NaN values propagate to neighbouring time bins under the window. Callers (e.g. NeuralDataset.smooth_responses) must filter fully-NaN responses before calling.
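The behaviour described in these notes can be reproduced with a small NumPy sketch. It is an illustrative stand-in, not the library's torch implementation: the function names are hypothetical, the window-length rule follows the "floor, then +1 if even" description, and np.convolve(..., mode="same") plays the role of zero padding. It assumes the time axis is at least as long as the window.

```python
import numpy as np

def odd_window_bins(window_ms, dt_ms):
    """Window length in bins: floor to whole bins, then +1 if even."""
    k = int(window_ms // dt_ms)
    if k % 2 == 0:
        k += 1
    return k

def hanning_smooth_np(response, window_ms, dt_ms):
    """Smooth along the last axis with a raw (non-normalized) Hanning kernel."""
    response = np.asarray(response, dtype=float)
    kernel = np.hanning(odd_window_bins(window_ms, dt_ms))

    def smooth_1d(x):
        # mode="same" behaves like zero padding on both sides.
        return np.convolve(x, kernel, mode="same")

    return np.apply_along_axis(smooth_1d, -1, response)

# Constant input: interior bins are scaled by sum(kernel), edge bins are attenuated.
flat = hanning_smooth_np(np.ones((2, 10)), window_ms=5, dt_ms=1)

# A single NaN spreads to every bin whose window overlaps it.
x = np.ones(10)
x[5] = np.nan
n_nan = int(np.isnan(hanning_smooth_np(x, window_ms=5, dt_ms=1)).sum())
```

With a 5-bin window the kernel sums to 2, so interior bins of flat equal 2.0 while the first and last bins equal 1.5, and the single NaN contaminates 5 output bins. This is why fully-NaN responses need to be filtered out before smoothing.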
- deepSTRF.utils.neural_collate(batch)
Collate fn for any NeuralDataset.
Pads variable-duration stims with zeros along the last (time) axis and variable-duration / variable-repeat-count responses with NaN along both the repeat and the time axes. Derives a fine-grained valid_mask from the NaN sentinels so downstream loss code can use boolean indexing or multiplicative masking without re-scanning.
- Parameters:
batch (list of dict) – Each dict is one item as yielded by NeuralDataset.__getitem__, with keys:
  - 'stims': a stim tensor of shape (..., T_s) (modality-specific leading dims, e.g. (1, F, T_s) for audio).
  - 'responses': list of length N_selected; each element is a (R_{s,n}, T_s) spike-count tensor or a (1, 1) NaN sentinel.
  - 'valid_mask': (N_selected,) per-neuron bool tensor (ignored here; the fine-grained batch 'valid_mask' below subsumes it).
  - 'stim_meta': per-stim metadata dict.
Extra keys (e.g. 'behav') are passed through untouched: any key not handled explicitly is collected into a length-B list.
- Returns:
A dict with keys:
  - 'stims': (B, ..., T_stim_max) float tensor, zero-padded along the last axis. Contains no NaN.
  - 'responses': (B, N_selected, R_max, T_resp_max) float tensor. NaN-padded along both the repeat (R) and time (T) axes. Fully-NaN slabs mark (stim, neuron) pairs with no recorded data. The response-time axis is sized to T_resp_max independently of the stim-time axis: in spectrogram mode the two are equal (one bin per neural sample), but in waveform mode the stim axis runs at audio_fs Hz while responses stay at the dataset’s neural dt_ms rate.
  - 'valid_mask': (B, N_selected, R_max, T_resp_max) bool tensor, equal to ~responses.isnan(), cached here so downstream loss code does not have to recompute it.
  - 'stim_meta': length-B list of the per-item stim_meta dicts.
  - any extra per-item keys: length-B lists, passed through.
- Return type:
dict
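The NaN-padding and mask derivation described above can be sketched as follows. This is a NumPy illustration of the technique only, not the collate function itself: pad_responses and masked_mse are hypothetical names, and arrays stand in for torch tensors.

```python
import numpy as np

def pad_responses(batch_responses):
    """NaN-pad a ragged batch into (B, N, R_max, T_max) plus a validity mask.

    batch_responses[b][n] is an (R, T) array or a (1, 1) NaN sentinel.
    """
    B, N = len(batch_responses), len(batch_responses[0])
    R_max = max(r.shape[0] for item in batch_responses for r in item)
    T_max = max(r.shape[1] for item in batch_responses for r in item)
    out = np.full((B, N, R_max, T_max), np.nan, dtype=np.float32)
    for b, item in enumerate(batch_responses):
        for n, r in enumerate(item):
            # A sentinel simply writes one NaN into an already-NaN slab.
            out[b, n, : r.shape[0], : r.shape[1]] = r
    return out, ~np.isnan(out)

def masked_mse(pred, target, valid_mask):
    """Mean squared error over valid entries only (boolean indexing)."""
    diff = pred[valid_mask] - target[valid_mask]
    return float(np.mean(diff ** 2))

sentinel = np.full((1, 1), np.nan, dtype=np.float32)
batch = [
    [np.ones((3, 4), dtype=np.float32), sentinel],           # neuron 1 not recorded
    [np.ones((2, 6), dtype=np.float32), np.ones((1, 5), dtype=np.float32)],
]
responses, valid_mask = pad_responses(batch)
loss = masked_mse(np.zeros_like(responses), responses, valid_mask)
```

Because padding and missing pairs share the same NaN marker, one mask covers both cases, and the loss never touches padded or unrecorded entries.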
- deepSTRF.utils.plot_psth_vs_pred(target: ArrayLike, pred: ArrayLike, dt_ms: float | None = None, title: str | None = None, target_label: str = 'PSTH (target)', pred_label: str = 'model', ax: Axes | None = None, legend: bool = True) → Axes
Overlay a target PSTH and a model prediction on a single panel.
Designed to be called inside a per-cell loop (“best / median / worst”) that pre-builds a column of axes. This is the canonical val/test visualisation in fit_ns1_statenet.ipynb and load_pretrained_statenet_ns1.ipynb.
- Parameters:
target (array-like, shape (T,)) – Trial-averaged target PSTH.
pred (array-like, shape (T,)) – Model prediction.
dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in seconds; otherwise it is the bin index.
title (str, optional) – Axes title.
target_label (str) – Legend label for the target trace.
pred_label (str) – Legend label for the prediction trace.
ax (matplotlib.axes.Axes, optional) – Pre-made axes to draw into. If None, a new figure is created.
legend (bool, default True) – Whether to render the legend.
- Return type:
matplotlib.axes.Axes
- deepSTRF.utils.plot_stim_with_response(stim: ArrayLike, response: ArrayLike, pred: ArrayLike | None = None, dt_ms: float | None = None, title: str | None = None, spec_cmap: str = 'magma', raster_cmap: str = 'Greys', axes: Sequence[Axes] | None = None, figsize: Tuple[float, float] | None = None) → Tuple[Figure, Sequence[Axes]]
Plot a stimulus spectrogram alongside its recorded (and optionally predicted) response.
Builds a vertically stacked figure with a shared time axis. The middle raster panel is omitted automatically when response is 1-D (already a PSTH).
- Parameters:
stim (array-like) – Stimulus spectrogram, shape (F, T) or (1, F, T).
response (array-like) – Per-trial responses (R, T) or pre-averaged PSTH (T,) / (1, T). R > 1 → spec/raster/PSTH; otherwise spec/PSTH.
pred (array-like, optional) – Model prediction, shape (T,). Overlaid on the PSTH panel.
dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in seconds; otherwise it is the bin index.
title (str, optional) – Suptitle for the whole figure.
spec_cmap (str) – Colormap for the spectrogram panel.
raster_cmap (str) – Colormap for the raster panel.
axes (sequence of matplotlib.axes.Axes, optional) – Pre-made axes to draw into (2 or 3, matching the panel count). If None, a new figure is created.
figsize ((w, h), optional) – Figure size when axes is None.
- Returns:
fig (matplotlib.figure.Figure)
axes (list of matplotlib.axes.Axes) – [spec_ax, raster_ax, psth_ax] or [spec_ax, psth_ax].
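The panel-count rule and the time-axis convention described above can be sketched in NumPy. The helper names are hypothetical, and treating the PSTH as the mean across trials is an assumption made for illustration; the documentation does not state how the function averages trials.

```python
import numpy as np

def panel_layout(response):
    """Return (panel names, PSTH) following the R > 1 rule."""
    r = np.atleast_2d(np.asarray(response, dtype=float))  # (T,) -> (1, T)
    if r.shape[0] > 1:
        # Several trials: show a raster, and a trial mean as PSTH (assumed).
        return ["spec", "raster", "psth"], r.mean(axis=0)
    return ["spec", "psth"], r[0]

def time_axis(T, dt_ms=None):
    """Seconds if dt_ms is given, otherwise bin indices."""
    bins = np.arange(T)
    return bins * dt_ms / 1000.0 if dt_ms is not None else bins

trials = np.array([[0, 2, 0, 0], [2, 0, 0, 2]])
panels3, psth3 = panel_layout(trials)
panels2, psth2 = panel_layout(np.array([1, 0, 1]))
t = time_axis(4, dt_ms=5)
```

The number of names in the returned panel list (2 or 3) is also the number of axes a caller would need to supply through the axes argument.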
- deepSTRF.utils.plot_strf_grid(strfs: ArrayLike | Sequence[ArrayLike], titles: Sequence[str] | None = None, dt_ms: float | None = None, ncols: int = 4, cmap: str = 'RdBu_r', shared_clim: bool = False, suptitle: str | None = None, figsize: Tuple[float, float] | None = None) → Tuple[Figure, Sequence[Axes]]
Plot a grid of STRF / gradmap kernels.
The classical STRF visualisation: each panel is an (F, T) weight map, frequency on the y-axis (low→high), time on the x-axis (history; [0, T·dt_ms] if dt_ms is given, else bin index). Diverging colormap (RdBu_r by default) with per-panel symmetric vmax, i.e. each cell gets its own |max| so the spatial structure is comparable across cells of very different gradient magnitudes. Pass shared_clim=True for a global symmetric vmax if the kernels are intended to be compared on the same scale (e.g. same neuron under different model variants).
- Parameters:
strfs (array-like, shape (K, F, T), or sequence of (F, T)) – Stack of kernels to plot. Numpy ndarrays or torch tensors.
titles (sequence of str, optional) – Length-K list of per-panel titles. None → unlabeled.
dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in ms (history extent [0, T·dt_ms]); otherwise it is the bin index.
ncols (int, default 4) – Number of columns in the grid; rows = ceil(K / ncols).
cmap (str, default "RdBu_r")
shared_clim (bool, default False) – If True, use one global symmetric vmax = max_k |strf_k| across all panels. Otherwise per-panel.
suptitle (str, optional) – Figure-level title.
figsize ((w, h), optional) – Figure size. Default scales with the grid shape.
- Returns:
fig (matplotlib.figure.Figure)
axes (flat list of matplotlib.axes.Axes) – Length K; unused grid cells (when K < nrows·ncols) are hidden via ax.axis('off') and not included in the return.
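The difference between per-panel and shared colour limits, and the grid-shape rule, can be shown with a short NumPy sketch. The helpers symmetric_clims and grid_shape are hypothetical illustrations of the documented behaviour, not functions exposed by the package.

```python
import numpy as np

def symmetric_clims(strfs, shared_clim=False):
    """Return one symmetric (vmin, vmax) pair per kernel."""
    kernels = [np.asarray(k, dtype=float) for k in strfs]
    per_panel = [float(np.abs(k).max()) for k in kernels]
    if shared_clim:
        vmax = max(per_panel)  # vmax = max_k |strf_k|
        return [(-vmax, vmax)] * len(kernels)
    return [(-v, v) for v in per_panel]

def grid_shape(K, ncols=4):
    """Rows = ceil(K / ncols), as in the ncols parameter description."""
    return int(np.ceil(K / ncols)), ncols

kernels = [np.array([[0.1, -0.3]]), np.array([[2.0, -1.0]])]
per_panel = symmetric_clims(kernels)
shared = symmetric_clims(kernels, shared_clim=True)
rows, cols = grid_shape(5)
```

With per-panel limits the weak first kernel still uses the full colour range, whereas shared limits would render it nearly blank next to the second kernel. Symmetric limits keep zero at the centre of a diverging colormap in both cases.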