deepSTRF.utils package

Submodules

deepSTRF.utils.data module

class deepSTRF.utils.data.ResponseSmoothingTransform(dt_ms=1, window_size_ms=21, *args, **kwargs)

Bases: Module

Temporally smooth responses with a Hanning window (typically ~20-40 ms).

Parameters:
  • dt_ms (float, default 1) – Time-bin width of the responses, in ms.

  • window_size_ms (float, default 21) – Full width of the Hanning window in ms (rounded to an odd number of dt_ms bins).

References

Hsu, A., Borst, A., & Theunissen, F. E. (2004). Quantifying variability in neural responses and its application for the validation of model predictions. Network: Computation in Neural Systems, 15(2), 91-109. https://doi.org/10.1088/0954-898X_15_2_002


forward(responses, dt=1)


deepSTRF.utils.data.concat_neural_datasets(datasets: Sequence[NeuralDataset], names: Sequence[str] | None = None) → NeuralDataset

Concatenate neural datasets along BOTH the stim and neuron axes.

Given k datasets with (S_i, N_i) stimuli and neurons each, returns a single dataset with S = sum(S_i) stimuli and N = sum(N_i) neurons. The response grid is block-diagonal. Entry (s, n) holds real data when stimulus s and neuron n come from the same source dataset, and a (1, 1) NaN sentinel otherwise. These cross-block gaps use the same NaN-sentinel convention as any other unrecorded (stim, neuron) pair, so nrn_masks (the derived property) reflects the block-diagonal coverage automatically.

Primary use case: building “chimeric” datasets that pool recordings across species / labs / preparations (e.g. CRCNS AA1 + AA2 + NS1 for auditory), so that a single model can be fit to the union.

Parameters:
  • datasets (sequence of NeuralDataset) – Typically two or more instances (a single instance is returned unchanged; see Notes). They must be of compatible types and share dt (bin width) and any modality-specific dimensions (F for audio, (H, W) for video). Compatibility is checked by each class’s _concat_check_compat hook; mismatches raise AssertionError. Resampling to align dt or F is the caller’s responsibility and must be done before concatenation.

  • names (sequence of str, optional) –

    One label per input dataset, written into stim_meta["dataset"] and nrn_meta["dataset"] on the output as a provenance tag. Defaults to [type(d).__name__ for d in datasets] — i.e. the class name ("CRCNSAA1Dataset" etc.). Pass explicit names to disambiguate two instances of the same class, or to use a shorter human-readable label.

    The tags enable post-hoc selection by source dataset via NeuralDataset.select_pop_by_nrn_attr() / select_stims_by_attr() (e.g. c.select_pop_by_nrn_attr("dataset", "CRCNSAA1Dataset")). Existing "dataset" entries in the input metadata are overwritten. Callers that nest concatenations and want to preserve the inner provenance should therefore pass names= explicitly.

Returns:

A fresh instance. Its concrete type is the most specific class that is a superclass of every input. This is type(datasets[0]) when all inputs share a type; otherwise it is the first common ancestor found by walking the MRO. Neuron selection is reset (self.I = []).

Return type:

NeuralDataset
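The return-type rule above can be sketched in a few lines of plain Python. The sketch assumes that the MRO being walked is that of the first input. most_specific_common_type and the *Stub classes are hypothetical names, not part of deepSTRF.

```python
from typing import Sequence


def most_specific_common_type(objs: Sequence[object]) -> type:
    """Hypothetical helper: first class in type(objs[0]).__mro__ shared by all inputs."""
    if not objs:
        raise ValueError("need at least one object")
    # The MRO lists classes from most to least specific and ends with `object`,
    # so the first class that every input is an instance of is the most specific one.
    for cls in type(objs[0]).__mro__:
        if all(isinstance(o, cls) for o in objs):
            return cls
    return object  # unreachable in practice: every MRO ends with object


# Toy stand-ins for NeuralDataset subclasses (hypothetical names).
class NeuralDatasetStub: ...
class AudioDatasetStub(NeuralDatasetStub): ...
class AA1Stub(AudioDatasetStub): ...
class NS1Stub(AudioDatasetStub): ...
```

Identical input types return that type. Mixed inputs fall back to their nearest shared ancestor.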

Notes

Concatenation is eager: the output holds its own full (S, N) grid of response references in memory. At deepSTRF scales (S, N in the low hundreds) this is negligible. Each cross-block entry is a (1, 1) NaN tensor holding a single element, which costs a few bytes of data plus the fixed per-tensor object overhead. A lazy wrapper-class alternative exists but would complicate self.responses[s][n] access for uncertain benefit at this scale.

Metadata dicts on the output are shallow copies of the inputs’ (the "dataset" tag is written into the copies, never into the sources). Tensors and other shared values inside those dicts are not deep-copied — mutate at your own risk.

Neuron / stim UID uniqueness across inputs is not validated — deepSTRF trusts the caller to pass mutually exclusive sources, since that is the only semantically meaningful case (pooling a dataset’s subset with its superset is degenerate — use constructor arguments instead).

The single-dataset case (len(datasets) == 1) returns the input unchanged, with no "dataset" tagging applied — provenance only becomes meaningful once there is more than one source.
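Below is a minimal sketch of the block-diagonal layout described above. It uses NumPy arrays in place of torch tensors and nested lists in place of NeuralDataset objects. block_diagonal_grid, nan_sentinel and coverage_mask are hypothetical helpers. coverage_mask only mimics what nrn_masks is documented to reflect; it is not its implementation.

```python
import numpy as np


def nan_sentinel() -> np.ndarray:
    """(1, 1) NaN placeholder for a (stim, neuron) pair with no data."""
    return np.full((1, 1), np.nan)


def block_diagonal_grid(blocks, names):
    """Hypothetical sketch of the concatenated response grid.

    blocks[i] is an S_i x N_i nested list of response arrays (S_i >= 1);
    names[i] is the provenance tag for that source.
    """
    n_stims = [len(b) for b in blocks]
    n_nrns = [len(b[0]) for b in blocks]
    S, N = sum(n_stims), sum(n_nrns)
    # Start with sentinels everywhere, then fill each source's diagonal block.
    grid = [[nan_sentinel() for _ in range(N)] for _ in range(S)]
    stim_tags, nrn_tags = [], []
    s0 = n0 = 0
    for block, name, s_i, n_i in zip(blocks, names, n_stims, n_nrns):
        for s in range(s_i):
            for n in range(n_i):
                grid[s0 + s][n0 + n] = block[s][n]
        stim_tags += [name] * s_i  # like stim_meta["dataset"]
        nrn_tags += [name] * n_i   # like nrn_meta["dataset"]
        s0, n0 = s0 + s_i, n0 + n_i
    return grid, stim_tags, nrn_tags


def coverage_mask(grid):
    """True where a (stim, neuron) entry is not entirely NaN."""
    return np.array([[not np.isnan(r).all() for r in row] for row in grid])
```

With one 2-stim x 1-neuron source and one 1-stim x 2-neuron source, the result is a 3 x 3 grid whose coverage is two diagonal blocks.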

deepSTRF.utils.data.concatenate_datasets(ds1: NeuralDataset, ds2: NeuralDataset) → NeuralDataset

Deprecated — use concat_neural_datasets([ds1, ds2]) instead.

deepSTRF.utils.data.fill_missing_data(stims: Sequence[Tensor], dims: int | Sequence[int], value: float = 0.0) → Tensor

Pad a list of tensors along one or more dimensions to match their maxima.

Parameters:
  • stims (sequence of torch.Tensor) – S tensors, each of shape (D0, D1, ..., Dk-1). Shapes may differ at the dimensions in dims but must agree on all others.

  • dims (int or sequence of int) – Dimension index or indices (negatives allowed) along which to pad. Indices refer to the axes of each input tensor, not to the stacked output, which has an extra leading S axis.

  • value (float, default 0.0) – Fill value for padding.

Returns:

A tensor of shape (S, D0', D1', ..., Dk-1') where, for each d in dims, Dd' = max_i stims[i].shape[d] and, for other axes, Dd' = stims[0].shape[d].

Return type:

torch.Tensor
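The padding rule can be sketched in NumPy as follows. The sketch assumes padding is appended at the end of each padded axis, since the documentation does not say which side is padded. It always returns float64. fill_missing_data_sketch is a hypothetical name; the real function operates on torch tensors.

```python
from typing import Sequence, Union

import numpy as np


def fill_missing_data_sketch(
    stims: Sequence[np.ndarray],
    dims: Union[int, Sequence[int]],
    value: float = 0.0,
) -> np.ndarray:
    """Hypothetical NumPy sketch: pad each array at the END of the `dims` axes."""
    ndim = stims[0].ndim
    dims = [dims] if isinstance(dims, int) else list(dims)
    dims = [d % ndim for d in dims]  # resolve negative indices against input axes
    target = list(stims[0].shape)
    for d in range(ndim):
        sizes = [s.shape[d] for s in stims]
        if d in dims:
            target[d] = max(sizes)  # padded axes grow to the maximum
        elif len(set(sizes)) != 1:
            raise ValueError(f"axis {d} differs across inputs but is not in dims")
    out = np.full((len(stims), *target), value, dtype=float)
    for i, s in enumerate(stims):
        # Copy each input into the leading corner of its slot; the rest keeps `value`.
        out[(i, *(slice(0, n) for n in s.shape))] = s
    return out
```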

deepSTRF.utils.data.hanning_smooth(response: Tensor, window_ms: float, dt_ms: float) → Tensor

Convolve response along its last (time) axis with a Hanning window.

Parameters:
  • response (torch.Tensor) – Response tensor of any shape; the last axis is assumed to be time.

  • window_ms (float) – Full width of the Hanning window in ms. Converted to K = floor(window_ms / dt_ms) bins, then incremented by 1 if K is even, so the kernel always has an odd length.

  • dt_ms (float) – Time-bin width of response, in ms.

Returns:

Smoothed response, same shape as input.

Return type:

torch.Tensor

Notes

Padded with zeros on both sides (F.pad(..., mode='constant')), so edge bins are attenuated. The kernel is the raw np.hanning(K), i.e. NOT sum-normalized, so interior values are scaled by sum(np.hanning(K)) = (K - 1) / 2 for K ≥ 3. This matches the legacy behaviour used by the Hsu / Borst / Theunissen (2004) PSTH smoothing step in the CRCNS-AA datasets.

NaN-unsafe: NaN values propagate to neighbouring time bins under the window. Callers (e.g. NeuralDataset.smooth_responses) must filter fully-NaN responses before calling.
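The NumPy sketch below makes the window-length rule and the un-normalised kernel concrete. It follows the documented behaviour: floor then bump to odd, raw np.hanning(K), and zero padding on both sides. hanning_smooth_sketch and n_window_bins are hypothetical names; the real function works on torch tensors via F.pad.

```python
import math

import numpy as np


def n_window_bins(window_ms: float, dt_ms: float) -> int:
    """K = floor(window_ms / dt_ms), bumped by one when even so K is odd."""
    k = math.floor(window_ms / dt_ms)
    return k + 1 if k % 2 == 0 else k


def hanning_smooth_sketch(response: np.ndarray, window_ms: float, dt_ms: float) -> np.ndarray:
    """NumPy sketch: zero-padded, un-normalised Hanning smoothing along the last axis."""
    k = n_window_bins(window_ms, dt_ms)
    kernel = np.hanning(k)  # raw window, NOT rescaled to sum to 1
    half = k // 2
    # Zero-pad only the time axis, half a kernel on each side.
    pad = [(0, 0)] * (response.ndim - 1) + [(half, half)]
    padded = np.pad(response, pad, mode="constant", constant_values=0.0)
    # A 'valid' convolution over T + k - 1 samples returns exactly T samples.
    return np.apply_along_axis(np.convolve, -1, padded, kernel, mode="valid")
```

On a constant signal of ones with a 21-bin window, interior bins come out at 10.0, which is (21 - 1) / 2. Edge bins come out lower. A single NaN contaminates every output bin whose window covers it.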

deepSTRF.utils.data.neural_collate(batch)

Collate function for any NeuralDataset, intended to be passed as collate_fn to torch.utils.data.DataLoader.

Pads variable-duration stims with zeros along the last (time) axis and variable-duration / variable-repeat-count responses with NaN along both the repeat and the time axes. Derives a fine-grained valid_mask from the NaN sentinels so downstream loss code can use boolean indexing or multiplicative masking without re-scanning.

Parameters:

batch (list of dict) –

Each dict is one item as yielded by NeuralDataset.__getitem__, with keys:

  • 'stims': a stim tensor of shape (..., T_s) (modality-specific leading dims, e.g. (1, F, T_s) for audio).

  • 'responses': list of length N_selected; each element is a (R_{s,n}, T_s) spike-count tensor or a (1, 1) NaN sentinel.

  • 'valid_mask': (N_selected,) per-neuron bool tensor (ignored here; the fine-grained batch 'valid_mask' below subsumes it).

  • 'stim_meta': per-stim metadata dict.

Extra keys (e.g. 'behav') are passed through untouched: any key not handled explicitly is collected into a length-B list.

Returns:

A dict with keys:

  • 'stims': (B, ..., T_stim_max) float tensor, zero-padded along the last axis. Contains no NaN.

  • 'responses': (B, N_selected, R_max, T_resp_max) float tensor. NaN-padded along both the repeat (R) and time (T) axes. Fully-NaN slabs mark (stim, neuron) pairs with no recorded data. The response-time axis is sized to T_resp_max independently of the stim-time axis: in spectrogram mode the two are equal (one bin per neural sample), but in waveform mode the stim axis runs at audio_fs Hz while responses stay at the dataset’s neural dt_ms rate.

  • 'valid_mask': (B, N_selected, R_max, T_resp_max) bool tensor, ~responses.isnan(), cached here so downstream loss code does not have to recompute it.

  • 'stim_meta': length-B list of the per-item stim_meta dicts.

  • any extra per-item keys: length-B lists, passed through.

Return type:

dict
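The padding rules above can be sketched with NumPy arrays standing in for torch tensors. neural_collate_sketch is a hypothetical re-implementation for illustration only, not the library code. It assumes every item in the batch selects the same N_selected neurons.

```python
import numpy as np

HANDLED = {"stims", "responses", "valid_mask", "stim_meta"}


def neural_collate_sketch(batch):
    """NumPy sketch of the documented padding rules (the real function returns torch tensors)."""
    B = len(batch)
    n_sel = len(batch[0]["responses"])

    # Stims: zero-pad along the last (time) axis only.
    t_stim = max(item["stims"].shape[-1] for item in batch)
    stims = np.zeros((B, *batch[0]["stims"].shape[:-1], t_stim))
    for b, item in enumerate(batch):
        stims[b, ..., : item["stims"].shape[-1]] = item["stims"]

    # Responses: NaN-pad along repeats (R) and time (T), sized independently of
    # the stim time axis. (1, 1) NaN sentinels end up as all-NaN slabs.
    resp = [r for item in batch for r in item["responses"]]
    r_max = max(r.shape[0] for r in resp)
    t_resp = max(r.shape[1] for r in resp)
    responses = np.full((B, n_sel, r_max, t_resp), np.nan)
    for b, item in enumerate(batch):
        for n, r in enumerate(item["responses"]):
            responses[b, n, : r.shape[0], : r.shape[1]] = r

    out = {
        "stims": stims,
        "responses": responses,
        "valid_mask": ~np.isnan(responses),  # fine-grained; per-item masks are ignored
        "stim_meta": [item["stim_meta"] for item in batch],
    }
    # Any other key is passed through as a length-B list.
    for key in batch[0]:
        if key not in HANDLED:
            out[key] = [item[key] for item in batch]
    return out
```

With the real function, the usual pattern would be DataLoader(ds, batch_size=8, collate_fn=neural_collate). The batch size here is only an example.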

deepSTRF.utils.data_download module

Auto-download utilities for deepSTRF datasets.

Public surface:
  • default_cache_dir(dataset_name) -> Path — platformdirs-based default

  • stream_download(url, dest_path) — resumable streaming download

  • unzip(zip_path, dest_dir) — flat unzip with overwrite

  • untar(tar_path, dest_dir, strip_components=) — tar.gz / .tar / .tar.bz2 unpack

  • osf_download(guid, dest) — public OSF storage files

  • github_raw_download(repo, path, dest, ref=) — public GitHub raw files

  • zenodo_download(record_id, filename, dest) — public Zenodo records

  • figshare_download(article_id, dest_dir, filename=) — public figshare articles

  • crcns_download(file_path, dest, username=, password=) — CRCNS (free account)

deepSTRF.utils.data_download.crcns_download(file_path: str, dest_path: str | Path, *, username: str | None = None, password: str | None = None, chunk_size: int = 1048576, progress: bool = True) → Path

Download a single file from the CRCNS NERSC mirror with form auth.

The CRCNS download portal at https://portal.nersc.gov/project/crcns/download/<file_path> serves an HTML login form to anonymous GETs. To actually fetch the file, the form must be POSTed to the same URL with username / password / fn / submit fields. There is no persistent session cookie: auth is per-request, so the same pattern works equally well whether you fetch one file or many.

Parameters:
  • file_path (str) – Path under /download/, e.g. "aa-1/crcns-aa1.zip" or "aa-4/BlaBro09xxF.tar.gz".

  • dest_path (path-like) – Destination file path.

  • username (str, optional) – CRCNS account name. Defaults to $CRCNS_USERNAME. Accounts are free at https://crcns.org/register.

  • password (str, optional) – CRCNS account password. Defaults to $CRCNS_PASSWORD.

  • chunk_size – As stream_download.

  • progress – As stream_download.

Returns:

Resolved destination.

Return type:

Path

Raises:

RuntimeError – If credentials are missing, or if the response body still looks like the login form (auth failed silently — the portal returns 200 OK with the login HTML rather than 401 on bad credentials).

Notes

Status: experimental. The auth + URL conventions were reverse-engineered from probing the public NERSC mirror; we do not have a contract from CRCNS that they’ll stay stable. If the portal layout changes, this helper breaks. Verified 2026-04-25 against the AA1 archive (aa-1/crcns-aa1.zip).

Example

>>> import os
>>> os.environ["CRCNS_USERNAME"] = "..."
>>> os.environ["CRCNS_PASSWORD"] = "..."
>>> crcns_download("aa-1/crcns-aa1.zip", "/tmp/aa1.zip")
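The login flow described above can be sketched with the standard library. The field names username / password / fn / submit come from the description above. The values sent for fn and submit are assumptions, and resolve_credentials and build_login_post are hypothetical helper names. The sketch only builds the request; it does not contact the portal.

```python
import os
from typing import Optional, Tuple
from urllib.parse import urlencode
from urllib.request import Request

PORTAL = "https://portal.nersc.gov/project/crcns/download/"


def resolve_credentials(username: Optional[str] = None,
                        password: Optional[str] = None) -> Tuple[str, str]:
    """Explicit arguments win; otherwise fall back to the documented env vars."""
    username = username or os.environ.get("CRCNS_USERNAME")
    password = password or os.environ.get("CRCNS_PASSWORD")
    if not username or not password:
        raise RuntimeError("CRCNS credentials missing (free account: https://crcns.org/register)")
    return username, password


def build_login_post(file_path: str, username: str, password: str) -> Request:
    """Build (but do not send) a per-request form POST to the file's own URL."""
    form = {
        "username": username,
        "password": password,
        "fn": file_path,    # ASSUMPTION: 'fn' carries the requested file path
        "submit": "Login",  # ASSUMPTION: value of the submit button
    }
    return Request(PORTAL + file_path, data=urlencode(form).encode(), method="POST")
```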

deepSTRF.utils.data_download.default_cache_dir(dataset_name: str) → Path

Return $DEEPSTRF_DATA_DIR/<dataset> if the env var is set, otherwise platformdirs.user_cache_dir('deepSTRF') / <dataset>.

The env-var override is the standard escape hatch for users on shared storage / scratch filesystems, and matches the convention in torchvision / huggingface_hub.
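Here is a sketch of the lookup order. It assumes an empty $DEEPSTRF_DATA_DIR counts as unset, and default_cache_dir_sketch is a hypothetical name. platformdirs is imported lazily, so the override path works even where platformdirs is not installed.

```python
import os
from pathlib import Path


def default_cache_dir_sketch(dataset_name: str) -> Path:
    """$DEEPSTRF_DATA_DIR/<dataset> if set, else the platformdirs user cache dir."""
    override = os.environ.get("DEEPSTRF_DATA_DIR")
    if override:  # ASSUMPTION: an empty string is treated as unset
        return Path(override) / dataset_name
    from platformdirs import user_cache_dir  # third-party; only needed without the override
    return Path(user_cache_dir("deepSTRF")) / dataset_name
```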

deepSTRF.utils.data_download.figshare_download(article_id: int | str, dest_dir: str | Path, *, filename: str | None = None, **kwargs) → Path

Download one file from a public figshare article.

Resolves the article’s file list via the public REST API (https://api.figshare.com/v2/articles/<id>) and streams the matching file into dest_dir. With filename=None, the article must contain exactly one file; pass an explicit name to disambiguate when there are several.

Parameters:
  • article_id (int | str) – Numeric figshare article id (the trailing component of the DOI 10.6084/m9.figshare.<id>, e.g. 29203457).

  • dest_dir (path-like) – Directory the file is downloaded into. Created if missing.

  • filename (str, optional) – Name of the file to fetch. Required when the article carries more than one file. Matched case-sensitively against the file’s name field returned by the API.

Returns:

Path to the downloaded file under dest_dir.

Return type:

Path

Example

>>> figshare_download(29203457, "/tmp/le2025")
PosixPath('/tmp/le2025/zebf-auditory-restoration-1.zip')
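The file-selection rule can be sketched as below. The sketch assumes the article JSON returned by the REST API carries a "files" list whose entries have "name" and "download_url" keys. pick_figshare_file and the chosen exception types are hypothetical, for illustration only.

```python
from typing import Any, Dict, List, Optional

API = "https://api.figshare.com/v2/articles/{}"


def pick_figshare_file(files: List[Dict[str, Any]],
                       filename: Optional[str] = None) -> Dict[str, Any]:
    """Select one entry from the article's 'files' list, per the rules above."""
    if filename is None:
        if len(files) != 1:
            names = [f["name"] for f in files]
            raise ValueError(f"article has {len(files)} files; pass filename= (one of {names})")
        return files[0]
    matches = [f for f in files if f["name"] == filename]  # case-sensitive match
    if not matches:
        raise FileNotFoundError(filename)
    return matches[0]
```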

deepSTRF.utils.data_download.github_raw_download(repo: str, path_in_repo: str, dest_path: str | Path, *, ref: str = 'HEAD', **kwargs) → Path

Download a file from a GitHub repo’s raw content.

Useful for paper-companion repos that publish small datasets / model artefacts alongside the code (e.g. DNet hosts test_data_5ms.mat for the Rahman et al. 2018 NS1 reanalysis at https://github.com/monzilur/DNet).

Parameters:
  • repo (str) – "<owner>/<name>", e.g. "monzilur/DNet".

  • path_in_repo (str) – Path of the file within the repo, e.g. "test_data_5ms.mat".

  • dest_path (path-like) – Destination file path.

  • ref (str, default "HEAD") – Branch / tag / commit. "HEAD" resolves the default branch.

Notes

Uses the raw.githubusercontent.com CDN. Anonymous reads there do not count against the GitHub REST API rate limit, although GitHub may still throttle very heavy anonymous traffic.
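The URL pattern can be sketched as follows, assuming the standard raw.githubusercontent.com/<owner>/<name>/<ref>/<path> layout. github_raw_url is a hypothetical helper. Percent-encoding the ref and path is an assumption about how special characters are handled.

```python
from urllib.parse import quote


def github_raw_url(repo: str, path_in_repo: str, ref: str = "HEAD") -> str:
    """Sketch of the raw.githubusercontent.com URL for a file at a given ref."""
    owner, name = repo.split("/")  # expects "<owner>/<name>"
    # quote() keeps '/' by default, so nested paths stay nested.
    return f"https://raw.githubusercontent.com/{owner}/{name}/{quote(ref)}/{quote(path_in_repo)}"
```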

deepSTRF.utils.data_download.osf_download(file_guid: str, dest_path: str | Path, **kwargs) → Path

Download a single file from OSF by its short GUID.

Resolves to https://osf.io/download/<guid>/ — works for any public OSF storage file (the OSF API exposes this URL as the download link in each file’s metadata).

Example

>>> osf_download("gdwyd", "MetadataSHEnCneurons.mat")

deepSTRF.utils.data_download.stream_download(url: str, dest_path: str | Path, *, chunk_size: int = 1048576, progress: bool = True) → Path

Stream-download url to dest_path. The write is atomic: data goes to a temporary .part file that replaces dest_path only once the transfer completes.

Already-existing destination paths are returned unchanged (no-op) — the caller is responsible for cache invalidation.

Parameters:
  • url (str) – URL to fetch.

  • dest_path (str | Path) – Destination file path.

  • chunk_size (int, default 1 MiB)

  • progress (bool, default True) – Show a tqdm progress bar if available; falls back silently otherwise.

Returns:

The destination path (resolved).

Return type:

Path
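Here is a standard-library sketch of the .part swap. It uses urllib in place of whatever HTTP client the real function uses and omits the tqdm progress bar. stream_download_sketch is a hypothetical name. Because urllib also handles file:// URLs, the sketch can be exercised without network access.

```python
import os
import urllib.request
from pathlib import Path


def stream_download_sketch(url: str, dest_path, chunk_size: int = 1 << 20) -> Path:
    """Download url to dest_path via a temporary .part file (no progress bar)."""
    dest = Path(dest_path).resolve()
    if dest.exists():
        return dest  # no-op: cache invalidation is the caller's job
    dest.parent.mkdir(parents=True, exist_ok=True)
    part = dest.with_name(dest.name + ".part")
    with urllib.request.urlopen(url) as resp, open(part, "wb") as fh:
        while True:
            chunk = resp.read(chunk_size)
            if not chunk:
                break
            fh.write(chunk)
    # os.replace is atomic when source and target are on the same filesystem,
    # so dest never exists in a half-written state.
    os.replace(part, dest)
    return dest
```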

deepSTRF.utils.data_download.untar(tar_path: str | Path, dest_dir: str | Path, *, strip_components: int = 0) → Path

Extract a tar / tar.gz / tar.bz2 archive into dest_dir.

Mirrors GNU tar --strip-components=N: drops the first N path components from every member. Useful when an archive wraps everything in nested directories that aren’t part of the dataset’s own layout — e.g. CRCNS-AA2 archives all wrap content in crcns/aa2/ (strip 2).

Parameters:
  • tar_path (path-like) – Archive to extract.

  • dest_dir (path-like) – Destination directory.

  • strip_components (int, default 0) – How many leading path components to drop. Members that have fewer components than this are silently skipped.

Returns:

The destination directory.

Return type:

Path
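The --strip-components semantics can be sketched with the standard tarfile module. untar_sketch is hypothetical and extracts only regular files and directories. It skips unsafe paths containing "..". It also skips members that would be left with an empty name, i.e. those with exactly N components.

```python
import tarfile
from pathlib import Path, PurePosixPath


def untar_sketch(tar_path, dest_dir, *, strip_components: int = 0) -> Path:
    """Extract regular files/dirs, dropping the first `strip_components` path parts."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r:*") as tf:  # "r:*" auto-detects gz / bz2 / plain
        for member in tf.getmembers():
            parts = PurePosixPath(member.name).parts
            if len(parts) <= strip_components:
                continue  # nothing left once the leading components are dropped
            rel = parts[strip_components:]
            if ".." in rel or rel[0] == "/":
                continue  # refuse paths that could escape dest_dir
            target = dest.joinpath(*rel)
            if member.isdir():
                target.mkdir(parents=True, exist_ok=True)
            elif member.isfile():
                target.parent.mkdir(parents=True, exist_ok=True)
                with tf.extractfile(member) as src:
                    target.write_bytes(src.read())
    return dest
```

For a CRCNS-AA2-style archive, strip_components=2 maps crcns/aa2/data/x.txt to data/x.txt.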

deepSTRF.utils.data_download.unzip(zip_path: str | Path, dest_dir: str | Path, *, strip_root: bool = False) → Path

Unzip zip_path into dest_dir. Idempotent (overwrites existing files).

Parameters:
  • zip_path (path-like) – Archive to extract.

  • dest_dir (path-like) – Destination directory.

  • strip_root (bool, default False) – If True and the archive contains a single top-level directory, strip it from the extracted layout (so foo/a/b -> a/b). Mirrors the common --strip-components=1 tar idiom.

Returns:

The destination directory.

Return type:

Path
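Here is a sketch of the strip_root behaviour using the standard zipfile module. unzip_sketch is hypothetical and skips unsafe paths. It assumes "a single top-level directory" means every member is either that directory or lives inside it.

```python
import zipfile
from pathlib import Path, PurePosixPath


def unzip_sketch(zip_path, dest_dir, *, strip_root: bool = False) -> Path:
    """Sketch: extract, overwriting existing files; optionally drop a lone root folder."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        infos = zf.infolist()
        split = [PurePosixPath(i.filename).parts for i in infos]
        roots = {p[0] for p in split if p}
        # Strip only if there is one top-level entry and it is a directory,
        # i.e. every member is that directory itself or lives inside it.
        strip = strip_root and len(roots) == 1 and all(
            i.is_dir() or len(p) > 1 for i, p in zip(infos, split)
        )
        for info, parts in zip(infos, split):
            parts = parts[1:] if strip else parts
            if not parts or ".." in parts or parts[0] == "/":
                continue  # nothing left after stripping, or an unsafe path
            target = dest.joinpath(*parts)
            if info.is_dir():
                target.mkdir(parents=True, exist_ok=True)
            else:
                target.parent.mkdir(parents=True, exist_ok=True)
                target.write_bytes(zf.read(info))  # overwrite, so reruns are idempotent
    return dest
```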

deepSTRF.utils.data_download.zenodo_download(record_id: int | str, filename: str, dest_path: str | Path, **kwargs) → Path

Download a single file from a public Zenodo record.

Resolves to https://zenodo.org/api/records/<record_id>/files/<filename>/content — the canonical URL for fetching a file from a Zenodo record. Public records are accessible without auth.

Example

>>> zenodo_download(8044773, "A1_NAT4_ozgf.fs100.ch18.tgz", "/tmp/X.tgz")
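The documented URL can be built as below. zenodo_file_url is a hypothetical helper. Percent-encoding the filename is an assumption that only matters for names containing spaces or other special characters.

```python
from urllib.parse import quote


def zenodo_file_url(record_id, filename: str) -> str:
    """Sketch of the documented Zenodo content URL (filename is percent-encoded)."""
    return f"https://zenodo.org/api/records/{record_id}/files/{quote(filename)}/content"
```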

deepSTRF.utils.plotting module

Plotting helpers shared across the example notebooks.

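Every function accepts either torch.Tensor or numpy.ndarray inputs and converts them internally, in line with the rest of the deepSTRF public API.

Functions:

  • plot_stim_with_response

  • plot_psth_vs_pred

  • plot_strf_grid

  • compare_wav2spec_to_groundtruth

All of them return matplotlib objects (Figure and/or Axes), so callers decide whether to plt.show(), save, or compose further. compare_wav2spec_to_groundtruth additionally returns the two spectrograms as numpy arrays. None of them calls plt.show() internally.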
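Here is a sketch of the tensor/array normalisation. It is duck-typed so that it does not import torch. to_numpy is a hypothetical helper, not the library's internal function.

```python
import numpy as np


def to_numpy(x) -> np.ndarray:
    """Accept a torch.Tensor or anything array-like and return a NumPy array."""
    # Tensors expose .detach() / .cpu() / .numpy(); everything else goes via np.asarray,
    # which returns an existing ndarray unchanged.
    if hasattr(x, "detach") and hasattr(x, "cpu"):
        return x.detach().cpu().numpy()
    return np.asarray(x)
```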

deepSTRF.utils.plotting.compare_wav2spec_to_groundtruth(ds, wav2spec, stim_idx: int = 0, *, ground_truth_stims: Sequence | None = None, z_score: bool = True, figsize: Tuple[float, float] | None = None, suptitle: str | None = None)

Side-by-side visual comparison of a learned/hand-built wav2spec output against a dataset’s precomputed (ground-truth) spectrogram.

Useful for sanity-checking a new front-end on a dataset that ships both raw waveforms and a precomputed spectrogram (e.g. NS1, where the OSF release has raw wavs and the DNet companion repo provides the matching log-mel X_nfht).

Parameters:
  • ds (NeuralDataset) – Dataset instance in waveform mode (ds.stims[s].shape == (1, T_audio)). The ground-truth spectrograms are not read from ds; they are supplied via ground_truth_stims.

  • wav2spec (nn.Module) – Module to apply. Must accept (B, 1, T_audio) and return (B, 1, F, T_neural) — i.e. the wav2spec slot contract.

  • stim_idx (int, default 0) – Which dataset stim to compare. Indexes ds.stims.

  • ground_truth_stims (sequence, optional) – Per-stim ground-truth spectrograms (each (F, T) or (1, F, T)). Required in practice, since the waveform-mode dataset doesn’t carry them itself. For NS1, build them by re-instantiating NS1Dataset() (spec mode) and passing its stims.

  • z_score (bool, default True) – If True, both spectrograms are independently z-scored (mean 0, std 1) before plotting, so a constant offset / global scale mismatch does not visually dominate the comparison.

  • figsize ((w, h), optional) – Figure size.

  • suptitle (str, optional) – Figure-level title.

Returns:

  • pred_spec (numpy.ndarray) – The wav2spec output, shape (F, T).

  • truth_spec (numpy.ndarray) – The ground-truth spectrogram, shape (F, T).

  • fig (matplotlib.figure.Figure) – 3-panel side-by-side figure (pred | truth | difference).
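The z-scoring and difference panel can be sketched in NumPy to show why a pure offset or scale mismatch vanishes after z-scoring. compare_specs and zscore are hypothetical helpers. The sketch assumes the difference panel is pred - truth and that both inputs already share the same (F, T) shape.

```python
import numpy as np


def zscore(spec: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Independently standardise one spectrogram to mean 0, std 1."""
    return (spec - spec.mean()) / (spec.std() + eps)


def compare_specs(pred, truth, z_score: bool = True):
    """Sketch of the three panels: (pred, truth, pred - truth), each (F, T)."""
    pred = np.asarray(pred, dtype=float)
    truth = np.asarray(truth, dtype=float)
    pred = pred.reshape(pred.shape[-2:])  # (1, F, T) or (F, T) -> (F, T)
    truth = truth.reshape(truth.shape[-2:])
    if pred.shape != truth.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {truth.shape}")
    if z_score:
        pred, truth = zscore(pred), zscore(truth)
    return pred, truth, pred - truth
```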

deepSTRF.utils.plotting.plot_psth_vs_pred(target: ArrayLike, pred: ArrayLike, dt_ms: float | None = None, title: str | None = None, target_label: str = 'PSTH (target)', pred_label: str = 'model', ax: Axes | None = None, legend: bool = True) → Axes

Overlay a target PSTH and a model prediction on a single panel.

Designed to be called inside a per-cell loop (e.g. over the “best / median / worst” cells) that draws into a pre-built column of axes. This is the canonical val/test visualisation in fit_ns1_statenet.ipynb and load_pretrained_statenet_ns1.ipynb.

Parameters:
  • target (array-like, shape (T,)) – Trial-averaged target PSTH.

  • pred (array-like, shape (T,)) – Model prediction.

  • dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in seconds; otherwise it is the bin index.

  • title (str, optional) – Axes title.

  • target_label (str, default 'PSTH (target)') – Legend label for the target trace.

  • pred_label (str, default 'model') – Legend label for the prediction trace.

  • ax (matplotlib.axes.Axes, optional) – Pre-made axes to draw into. If None, a new figure is created.

  • legend (bool, default True) – Whether to render the legend.

Return type:

matplotlib.axes.Axes
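Here is a sketch of the per-cell selection loop the function is designed for, plus the documented x-axis convention. best_median_worst, time_axis and the per-cell scores are hypothetical. The commented-out lines show how plot_psth_vs_pred would then be called with its documented ax= and dt_ms= arguments; dt_ms=5.0 is only an example value.

```python
from typing import Dict, Optional, Sequence

import numpy as np


def best_median_worst(scores: Sequence[float]) -> Dict[str, int]:
    """Indices of the highest-, median- and lowest-scoring cells."""
    order = np.argsort(scores)  # ascending
    return {"best": int(order[-1]), "median": int(order[len(order) // 2]), "worst": int(order[0])}


def time_axis(n_bins: int, dt_ms: Optional[float] = None) -> np.ndarray:
    """x-axis convention above: seconds if dt_ms is given, else bin index."""
    idx = np.arange(n_bins)
    return idx * dt_ms / 1000.0 if dt_ms is not None else idx


# import matplotlib.pyplot as plt
# from deepSTRF.utils.plotting import plot_psth_vs_pred
# fig, axes = plt.subplots(3, 1, sharex=True)
# for ax, (label, i) in zip(axes, best_median_worst(scores).items()):
#     plot_psth_vs_pred(targets[i], preds[i], dt_ms=5.0, title=label, ax=ax)
```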

deepSTRF.utils.plotting.plot_stim_with_response(stim: ArrayLike, response: ArrayLike, pred: ArrayLike | None = None, dt_ms: float | None = None, title: str | None = None, spec_cmap: str = 'magma', raster_cmap: str = 'Greys', axes: Sequence[Axes] | None = None, figsize: Tuple[float, float] | None = None) → Tuple[Figure, Sequence[Axes]]

Plot a stimulus spectrogram alongside its recorded (and optionally predicted) response.

Builds a vertically stacked figure with a shared time axis. The middle raster panel is omitted automatically when response holds a single trial, i.e. when it is already a PSTH of shape (T,) or (1, T).

Parameters:
  • stim (array-like) – Stimulus spectrogram, shape (F, T) or (1, F, T).

  • response (array-like) – Per-trial responses (R, T) or pre-averaged PSTH (T,) / (1, T). R > 1 → spec/raster/PSTH; otherwise spec/PSTH.

  • pred (array-like, optional) – Model prediction, shape (T,). Overlaid on the PSTH panel.

  • dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in seconds; otherwise it is the bin index.

  • title (str, optional) – Suptitle for the whole figure.

  • spec_cmap (str, default 'magma') – Colormap for the spectrogram panel.

  • raster_cmap (str, default 'Greys') – Colormap for the raster panel.

  • axes (sequence of matplotlib.axes.Axes, optional) – Pre-made axes to draw into (2 or 3, matching the panel count). If None, a new figure is created.

  • figsize ((w, h), optional) – Figure size when axes is None.

Returns:

  • fig (matplotlib.figure.Figure)

  • axes (list of matplotlib.axes.Axes) – [spec_ax, raster_ax, psth_ax] or [spec_ax, psth_ax].
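Here is a sketch of the documented panel rule. panel_layout is hypothetical. The NaN-aware trial mean used for the PSTH panel is an assumption about how the average is computed.

```python
import numpy as np


def panel_layout(response):
    """Return (panel names, per-trial matrix, PSTH) following the documented rule."""
    r = np.asarray(response, dtype=float)
    if r.ndim == 1:
        r = r[None, :]  # (T,) -> (1, T)
    psth = np.nanmean(r, axis=0)  # trial average; a single row is returned as-is
    panels = ["spec", "raster", "psth"] if r.shape[0] > 1 else ["spec", "psth"]
    return panels, r, psth
```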

deepSTRF.utils.plotting.plot_strf_grid(strfs: ArrayLike | Sequence[ArrayLike], titles: Sequence[str] | None = None, dt_ms: float | None = None, ncols: int = 4, cmap: str = 'RdBu_r', shared_clim: bool = False, suptitle: str | None = None, figsize: Tuple[float, float] | None = None) → Tuple[Figure, Sequence[Axes]]

Plot a grid of STRF / gradmap kernels.

The classical STRF visualisation: each panel is an (F, T) weight map. Frequency runs along the y-axis (low to high) and time along the x-axis (history; [0, T·dt_ms] if dt_ms is given, else bin index). A diverging colormap (RdBu_r by default) is used with a per-panel symmetric color range. Each cell gets its own ±|max|, so the structure of every kernel stays visible even when cells differ greatly in gradient magnitude. Pass shared_clim=True for a single global symmetric range when the kernels are meant to be compared on the same scale (e.g. the same neuron under different model variants).

Parameters:
  • strfs (array-like of shape (K, F, T), or sequence of (F, T) arrays) – Stack of kernels to plot. Numpy ndarrays or torch tensors.

  • titles (sequence of str, optional) – Length-K list of per-panel titles. None → unlabeled.

  • dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in ms (history extent [0, T·dt_ms]); otherwise it is the bin index.

  • ncols (int, default 4) – Number of columns in the grid; rows = ceil(K / ncols).

  • cmap (str, default 'RdBu_r') – Diverging colormap name.

  • shared_clim (bool, default False) – If True, use one global symmetric vmax = max_k |strf_k| across all panels. Otherwise per-panel.

  • suptitle (str, optional) – Figure-level title.

  • figsize ((w, h), optional) – Figure size. Default scales with the grid shape.

Returns:

  • fig (matplotlib.figure.Figure)

  • axes (flat list of matplotlib.axes.Axes) – Length K; unused grid cells (when K < nrows·ncols) are hidden via ax.axis('off') and not included in the return.
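The color-limit and grid rules can be sketched as below. strf_color_limits and grid_shape are hypothetical helpers. The commented imshow call is one plausible way to draw a panel, not necessarily what plot_strf_grid does; origin="lower" puts low frequencies at the bottom.

```python
import math
from typing import List, Sequence, Tuple

import numpy as np


def strf_color_limits(strfs: Sequence[np.ndarray],
                      shared_clim: bool = False) -> List[Tuple[float, float]]:
    """Symmetric (vmin, vmax) per panel: own |max| by default, global |max| if shared."""
    vmaxes = [float(np.abs(np.asarray(k)).max()) for k in strfs]
    if shared_clim:
        vmaxes = [max(vmaxes)] * len(vmaxes)
    return [(-v, v) for v in vmaxes]


def grid_shape(n_kernels: int, ncols: int = 4) -> Tuple[int, int, int]:
    """(nrows, ncols, n_hidden) with nrows = ceil(K / ncols)."""
    nrows = math.ceil(n_kernels / ncols)
    return nrows, ncols, nrows * ncols - n_kernels


# for ax, k, (lo, hi) in zip(axes, strfs, strf_color_limits(strfs)):
#     ax.imshow(k, cmap="RdBu_r", vmin=lo, vmax=hi, origin="lower", aspect="auto")
```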

deepSTRF.utils.hub module

HuggingFace Hub integration for pretrained deepSTRF models.

A deepSTRF “checkpoint” is a folder with up to four files (two required, two optional):

  • config.json: JSON-serialisable __init__ kwargs (auto-captured by NeuralModel) plus a _model_class sentinel for safety checks.

  • model.safetensors: the model’s state_dict.

  • README.md: optional model card (free-form markdown shown by HF Hub).

  • metadata.json: optional user metadata (test metrics, dataset name, training config, …). Preserved through round-trips, never used for instantiation.

This module exposes four helpers covering the local and Hub sides of that layout. The high-level methods on NeuralModel (save_pretrained / from_pretrained / push_to_hub) are thin wrappers over these.
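Here is a sketch of the JSON and markdown parts of this layout using only the standard library. write_checkpoint_files is hypothetical, and the hidden_size kwarg used in testing is illustrative. Storing the kwargs flat with the class name under "_model_class" is an assumption about the exact config.json schema. Writing model.safetensors is left as a comment.

```python
import json
from pathlib import Path
from typing import Any, Dict, Optional


def write_checkpoint_files(save_dir, model_class_name: str, init_kwargs: Dict[str, Any],
                           metadata: Optional[Dict[str, Any]] = None,
                           model_card: Optional[str] = None) -> Path:
    """Hypothetical sketch of the JSON/markdown parts of the layout (weights omitted)."""
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    # ASSUMPTION: kwargs stored flat, with the class name under '_model_class'.
    config = {**init_kwargs, "_model_class": model_class_name}
    (save_dir / "config.json").write_text(json.dumps(config, indent=2))
    # model.safetensors would hold the state_dict (e.g. via safetensors.torch.save_file).
    if metadata is not None:
        (save_dir / "metadata.json").write_text(json.dumps(metadata, indent=2))
    card = model_card if model_card is not None else f"# {model_class_name}\n"
    (save_dir / "README.md").write_text(card)
    return save_dir
```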

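Public surface:

  • download_pretrained(repo_id): fetch a Hub repo snapshot into the local cache.

  • load_pretrained_from_dir(model_class, load_dir): instantiate a model from a checkpoint folder.

  • save_pretrained_to_dir(model, save_dir): write a checkpoint folder.

  • upload_pretrained(local_dir, repo_id): push a checkpoint folder to the Hub.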

deepSTRF.utils.hub.download_pretrained(repo_id: str, *, cache_dir: str | Path | None = None, revision: str | None = None, token: str | None = None) → Path

Fetch all files of an HF Hub model repo into a local cache folder.

Thin wrapper around huggingface_hub.snapshot_download(). The Hub client handles caching, resumability, and ETag validation, so repeated calls with the same repo_id reuse the cached files rather than downloading them again.

Parameters:
  • repo_id (str) – E.g. "urancon/deepSTRF-statenet-gru-ns1".

  • cache_dir (path-like, optional) – Override the HF Hub cache root. Defaults to ~/.cache/huggingface/hub.

  • revision (str, optional) – Branch, tag, or commit SHA. Defaults to the repo’s default branch.

  • token (str, optional) – HF auth token for private repos. Public repos work anonymously. For convenience, also picks up $HF_TOKEN automatically via the Hub client.

Returns:

Folder containing the snapshot.

Return type:

Path

deepSTRF.utils.hub.load_pretrained_from_dir(model_class: Type[Module], load_dir: str | Path, *, extra_kwargs: dict | None = None, strict: bool = True, map_location: str | device = 'cpu') → Tuple[Module, dict | None]

Instantiate model_class from a deepSTRF checkpoint folder.

Parameters:
  • model_class (type) – The concrete class to instantiate (e.g. StateNet).

  • load_dir (path-like) – Folder produced by save_pretrained_to_dir().

  • extra_kwargs (dict, optional) – Override / extend the JSON-loaded config. Useful for kwargs that couldn’t be JSON-serialised at save time (e.g. prefiltering or output_activation nn.Module instances). Merged on top of the loaded config; None is allowed.

  • strict (bool, default True) – Passed to load_state_dict. Set to False to tolerate missing / unexpected keys (e.g. when loading a population checkpoint into a single-cell model).

  • map_location (str or torch.device, default 'cpu') – Device to move tensors to. Mirrors the map_location argument of torch.load.

Returns:

(model, metadata), where model is an instance of model_class with weights loaded and metadata is the parsed metadata.json if present, else None.

Return type:

tuple
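Here is a sketch of how the loaded config and extra_kwargs might be combined before instantiation. resolve_init_kwargs and the ValueError on a class mismatch are hypothetical. The only documented behaviours illustrated are the _model_class safety check and extra_kwargs being merged on top of the loaded config.

```python
from typing import Any, Dict, Optional


def resolve_init_kwargs(config: Dict[str, Any], model_class_name: str,
                        extra_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical sketch: check the class sentinel, then layer extra_kwargs on top."""
    saved = config.get("_model_class")
    if saved is not None and saved != model_class_name:
        raise ValueError(f"checkpoint was saved from {saved}, not {model_class_name}")
    kwargs = {k: v for k, v in config.items() if k != "_model_class"}
    kwargs.update(extra_kwargs or {})  # caller-supplied values win (extra_kwargs may be None)
    return kwargs
```

The resulting kwargs would then go to model_class(**kwargs). That would be followed by load_state_dict(..., strict=strict) on weights read from model.safetensors, for example with safetensors.torch.load_file(path, device="cpu").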

deepSTRF.utils.hub.save_pretrained_to_dir(model: Module, save_dir: str | Path, *, metadata: dict | None = None, model_card: str | None = None) → Path

Write a deepSTRF checkpoint to save_dir.

The model must carry a _init_kwargs attribute (auto-populated by __init_subclass__). Everything else is optional.

Parameters:
  • model (nn.Module) – Model instance with _init_kwargs set.

  • save_dir (path-like) – Destination folder. Created if missing; existing files are overwritten.

  • metadata (dict, optional) – Free-form, JSON-serialisable user metadata (test metrics, dataset name, …). Written to metadata.json.

  • model_card (str, optional) – Markdown content for README.md. If None, a minimal default card is generated from the class name and config.

Returns:

The resolved save_dir.

Return type:

Path

Raises:

AttributeError – If model._init_kwargs is missing — i.e. the class isn’t a NeuralModel subclass, or its __init__ was not run through the auto-capture wrapper.

deepSTRF.utils.hub.upload_pretrained(local_dir: str | Path, repo_id: str, *, private: bool = False, token: str | None = None, commit_message: str | None = None) → str

Push a checkpoint folder to an HF Hub model repo.

Creates the repo on the fly if it doesn’t exist (idempotent — safe to call repeatedly). The user must be authenticated with write access to repo_id (run hf auth login once or pass token=).

Parameters:
  • local_dir (path-like) – Folder produced by save_pretrained_to_dir().

  • repo_id (str) – "<owner>/<name>" — the owner must be your username or an org you can write to.

  • private (bool, default False) – If creating a new repo, mark it private. Ignored if the repo already exists.

  • token (str, optional) – HF auth token. Defaults to the cached login token / $HF_TOKEN.

  • commit_message (str, optional) – Git commit message on the Hub repo. Defaults to a short message with the timestamp.

Returns:

URL of the resulting commit.

Return type:

str

Module contents

deepSTRF.utils — cross-cutting helpers.

For dataset-side helpers (concat_neural_datasets, neural_collate, hanning_smooth, …) see deepSTRF.utils.data. For training utilities (Fitter, set_random_seed) see deepSTRF.training. For shared notebook plotting (stim+response panels, PSTH-vs-prediction overlays) see deepSTRF.utils.plotting.

class deepSTRF.utils.ResponseSmoothingTransform(dt_ms=1, window_size_ms=21, *args, **kwargs)[source]

Bases: Module

Temporally smooth responses with a Hanning window (typically ~20-40 ms).

Parameters:
  • dt_ms (float, default 1) – Time-bin width of the responses, in ms.

  • window_size_ms (float, default 21) – Full width of the Hanning window in ms (rounded to an odd number of dt_ms bins).

References

Hsu, A., Borst, A., & Theunissen, F. E. (2004). Quantifying variability in neural responses and its application for the validation of model predictions. Network: Computation in Neural Systems, 15(2), 91-109. https://doi.org/10.1088/0954-898X_15_2_002

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(responses, dt=1)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

deepSTRF.utils.compare_wav2spec_to_groundtruth(ds, wav2spec, stim_idx: int = 0, *, ground_truth_stims: Sequence | None = None, z_score: bool = True, figsize: Tuple[float, float] | None = None, suptitle: str | None = None)[source]

Side-by-side visual comparison of a learned/hand-built wav2spec output against a dataset’s precomputed (ground-truth) spectrogram.

Useful for sanity-checking a new front-end on a dataset that ships both raw waveforms and a precomputed spectrogram (e.g. NS1, where the OSF release has raw wavs and the DNet companion repo provides the matching log-mel X_nfht).

Parameters:
  • ds (NeuralDataset) – Dataset instance in waveform mode (ds.stims[s].shape == (1, T_audio)). The dataset’s regular spectrogram is treated as the ground truth, supplied via ground_truth_stims.

  • wav2spec (nn.Module) – Module to apply. Must accept (B, 1, T_audio) and return (B, 1, F, T_neural) — i.e. the wav2spec slot contract.

  • stim_idx (int, default 0) – Which dataset stim to compare. Indexes ds.stims.

  • ground_truth_stims (sequence, optional) – Per-stim ground-truth spectrograms (each (F, T) or (1, F, T)). Required: the waveform-mode dataset doesn’t carry them itself. For NS1 build them by re-instantiating NS1Dataset() (spec mode) and passing its stims.

  • z_score (bool, default True) – If True, both spectrograms are independently z-scored (mean 0, std 1) before plotting, so a constant offset / global scale mismatch does not visually dominate the comparison.

Returns:

  • pred_spec (numpy.ndarray) – The wav2spec output, shape (F, T).

  • truth_spec (numpy.ndarray) – The ground-truth spectrogram, shape (F, T).

  • fig (matplotlib.figure.Figure) – 3-panel side-by-side figure (pred | truth | difference).

deepSTRF.utils.concat_neural_datasets(datasets: Sequence[NeuralDataset], names: Sequence[str] | None = None) NeuralDataset[source]

Concatenate neural datasets along BOTH the stim and neuron axes.

Given k datasets with (S_i, N_i) stimuli and neurons each, returns a single dataset with S = sum(S_i) stimuli and N = sum(N_i) neurons. The response grid is block-diagonal: real data where a stimulus belongs to a given source dataset and the neuron belongs to the same source, (1, 1) NaN sentinels everywhere else. This cross-block missingness is paradigm-compliant — nrn_masks (the derived property) then reflects the block-diagonal coverage automatically.

Primary use case: building “chimeric” datasets that pool recordings across species / labs / preparations (e.g. CRCNS AA1 + AA2 + NS1 for auditory), so that a single model can be fit to the union.

Parameters:
  • datasets (sequence of NeuralDataset) – Two or more instances. They must be of compatible types and share dt (bin width) and any modality-specific dimensions (F for audio, (H, W) for video). Compatibility is checked by each class’s _concat_check_compat hook; mismatches raise AssertionError. Resampling to align dt or F is the caller’s responsibility and must be done before concatenation.

  • names (sequence of str, optional) –

    One label per input dataset, written into stim_meta["dataset"] and nrn_meta["dataset"] on the output as a provenance tag. Defaults to [type(d).__name__ for d in datasets] — i.e. the class name ("CRCNSAA1Dataset" etc.). Pass explicit names to disambiguate two instances of the same class, or to use a shorter human-readable label.

    The tags enable post-hoc selection by source dataset via NeuralDataset.select_pop_by_nrn_attr() / select_stims_by_attr() (e.g. c.select_pop_by_nrn_attr("dataset", "CRCNSAA1Dataset")). Existing "dataset" entries in the input metadata are overwritten in the output, so callers performing nested concatenation who want to preserve the inner provenance should pass names= explicitly.
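The tagging rule can be sketched on plain metadata dicts. This assumes each source's per-neuron metadata is a list of dicts; tag_provenance, the stand-in classes FakeAA1 / FakeAA2 and the "cf_hz" key are all hypothetical. The sketch writes the tag into shallow copies, leaves the sources untouched, and overwrites any inner "dataset" tag.

```python
def tag_provenance(metas_per_source, names):
    """Hypothetical: flatten per-source metadata dicts, tagging shallow copies."""
    tagged = []
    for metas, name in zip(metas_per_source, names):
        for meta in metas:
            meta_copy = dict(meta)          # shallow copy: values are shared, not cloned
            meta_copy["dataset"] = name     # overwrites any inner "dataset" tag
            tagged.append(meta_copy)
    return tagged


class FakeAA1: ...   # stand-ins for real NeuralDataset subclasses
class FakeAA2: ...


sources = [FakeAA1(), FakeAA2()]
names = [type(d).__name__ for d in sources]   # the documented default labels
nrn_metas = [
    [{"cf_hz": 1000.0, "dataset": "inner"}, {"cf_hz": 2000.0}],  # FakeAA1's neurons
    [{"cf_hz": 4000.0}],                                         # FakeAA2's neurons
]
tagged = tag_provenance(nrn_metas, names)
# Post-hoc selection by source, analogous to filtering on the "dataset" attribute.
aa1_neurons = [i for i, m in enumerate(tagged) if m["dataset"] == "FakeAA1"]
```

The first FakeAA1 neuron keeps "inner" in the source metadata but carries "FakeAA1" in the tagged copy, which is why nested concatenation loses inner provenance unless names= is passed.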

Returns:

A fresh instance. Its concrete type is the most-specific class that is a superclass of every input: type(datasets[0]) when all inputs share a type, otherwise the class found by walking the MRO. Neuron selection is reset (self.I = []).

Return type:

NeuralDataset
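The return-type rule can be sketched with Python's method resolution order. How deepSTRF performs the walk is an assumption here (scan the first input's MRO for the first class every input is an instance of); common_base and the Toy* class hierarchy are hypothetical.

```python
def common_base(objs):
    """Most-specific class on type(objs[0]).__mro__ that every object is an instance of."""
    for cls in type(objs[0]).__mro__:
        if all(isinstance(o, cls) for o in objs):
            return cls
    return object  # never reached: every MRO ends in object


# Toy hierarchy standing in for deepSTRF's dataset classes (hypothetical names).
class ToyNeuralDataset: ...
class ToyAuditoryDataset(ToyNeuralDataset): ...
class ToyAA1(ToyAuditoryDataset): ...
class ToyNS1(ToyAuditoryDataset): ...
```

For example, common_base([ToyAA1(), ToyNS1()]) returns ToyAuditoryDataset, while two ToyAA1 instances return ToyAA1 itself.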

Notes

Concatenation is eager: the output holds its own full (S, N) grid of response references in memory. At deepSTRF scales (S, N in the low hundreds) this is negligible; cross-block entries are single-element (1, 1) NaN tensors that cost ~8 bytes each. A lazy wrapper class would avoid materialising the grid, but it would complicate self.responses[s][n] access for uncertain benefit at this scale.

Metadata dicts on the output are shallow copies of the inputs’ (the "dataset" tag is written into the copies, never into the sources). Tensors and other mutable values inside those dicts are not deep-copied, so mutating them in place on the output also mutates the corresponding source dataset's metadata.

Neuron / stim UID uniqueness across inputs is not validated — deepSTRF trusts the caller to pass mutually exclusive sources, since that is the only semantically meaningful case (pooling a dataset’s subset with its superset is degenerate — use constructor arguments instead).

The single-dataset case (len(datasets) == 1) returns the input unchanged, with no "dataset" tagging applied — provenance only becomes meaningful once there is more than one source.

deepSTRF.utils.hanning_smooth(response: Tensor, window_ms: float, dt_ms: float) Tensor[source]

Convolve response along its last (time) axis with a Hanning window.

Parameters:
  • response (torch.Tensor) – Response tensor of any shape; the last axis is assumed to be time.

  • window_ms (float) – Full width of the Hanning window in ms. Converted to an odd number of bins K: floor(window_ms / dt_ms), plus 1 if that count is even.

  • dt_ms (float) – Time-bin width of response, in ms.

Returns:

Smoothed response, same shape as input.

Return type:

torch.Tensor

Notes

Padded with zeros on both sides (F.pad(..., mode='constant')), so edge bins are attenuated. The kernel is the raw np.hanning(K), i.e. NOT sum-normalized: for K ≥ 3 its taps sum to (K - 1) / 2, so away from the edges a constant input is scaled by that factor. This matches the legacy behaviour used by the Hsu / Borst / Theunissen (2004) PSTH smoothing step in the CRCNS-AA datasets.

NaN-unsafe: NaN values propagate to neighbouring time bins under the window. Callers (e.g. NeuralDataset.smooth_responses) must filter fully-NaN responses before calling.
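The behaviour documented above (odd kernel length, zero padding, un-normalised taps) can be sketched with torch.nn.functional. This is a hypothetical re-implementation for illustration, not deepSTRF's source; hanning_smooth_sketch and window_bins are made-up names, and integer inputs are cast to float because conv1d requires a floating dtype.

```python
import numpy as np
import torch
import torch.nn.functional as F


def window_bins(window_ms, dt_ms):
    """K = floor(window_ms / dt_ms), bumped by 1 if even so K is always odd."""
    k = int(window_ms // dt_ms)
    return k + 1 if k % 2 == 0 else k


def hanning_smooth_sketch(response, window_ms, dt_ms):
    """Hypothetical re-implementation: zero-padded, un-normalised Hanning smoothing."""
    if not response.is_floating_point():
        response = response.float()      # conv1d needs a floating dtype
    k = window_bins(window_ms, dt_ms)
    # Raw np.hanning(K) taps; deliberately NOT divided by their sum.
    kernel = torch.as_tensor(np.hanning(k), dtype=response.dtype, device=response.device)
    shape = response.shape
    x = response.reshape(-1, 1, shape[-1])            # (batch, 1 channel, T)
    x = F.pad(x, (k // 2, k // 2), mode="constant")   # zeros on both ends keep length T
    y = F.conv1d(x, kernel.view(1, 1, k))             # kernel is symmetric, so no flip needed
    return y.reshape(shape)
```

On a constant input of ones with window_ms=20 and dt_ms=1 (K = 21), interior bins come out at 10 rather than 1, showing the kernel is not sum-normalised; dividing the kernel by kernel.sum() would give a unit-gain smoother instead.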

deepSTRF.utils.neural_collate(batch)[source]

Collate fn for any NeuralDataset.

Pads variable-duration stims with zeros along the last (time) axis and variable-duration / variable-repeat-count responses with NaN along both the repeat and the time axes. Derives a fine-grained valid_mask from the NaN sentinels so downstream loss code can use boolean indexing or multiplicative masking without re-scanning.

Parameters:

batch (list of dict) –

Each dict is one item as yielded by NeuralDataset.__getitem__, with keys:

  • 'stims' – a stim tensor of shape (..., T_s) (modality-specific leading dims, e.g. (1, F, T_s) for audio).

  • 'responses' – list of length N_selected; each element is a (R_{s,n}, T_s) spike-count tensor or a (1, 1) NaN sentinel.

  • 'valid_mask' – (N_selected,) per-neuron bool tensor (ignored here; the fine-grained batch 'valid_mask' below subsumes it).

  • 'stim_meta' – per-stim metadata dict.

Extra keys (e.g. 'behav') are passed through untouched: any key not handled explicitly is collected into a length-B list.

Returns:

A dict with keys:

  • 'stims' – (B, ..., T_stim_max) float tensor, zero-padded along the last axis. Contains no NaN.

  • 'responses' – (B, N_selected, R_max, T_resp_max) float tensor, NaN-padded along both the repeat (R) and time (T) axes. Fully-NaN slabs mark (stim, neuron) pairs with no recorded data. The response-time axis is sized to T_resp_max independently of the stim-time axis: in spectrogram mode the two are equal (one bin per neural sample), but in waveform mode the stim axis runs at audio_fs Hz while responses stay at the dataset’s neural bin width dt_ms.

  • 'valid_mask' – (B, N_selected, R_max, T_resp_max) bool tensor, equal to ~responses.isnan(), cached here so downstream loss code does not have to recompute it.

  • 'stim_meta' – length-B list of the per-item stim_meta dicts.

  • any extra per-item keys – length-B lists, passed through.

Return type:

dict
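The padding rules above can be sketched in a few lines of torch. This is a hypothetical re-implementation (neural_collate_sketch is not deepSTRF's function) that assumes every item in the batch has the same N_selected and the same leading stim dims; only the time axis of stims, and the repeat and time axes of responses, vary.

```python
import torch


def neural_collate_sketch(batch):
    """Hypothetical re-implementation of the padding rules described above."""
    B = len(batch)
    N = len(batch[0]["responses"])

    # Stims: zero-pad along the last (time) axis; leading dims must match.
    T_stim = max(item["stims"].shape[-1] for item in batch)
    lead = tuple(batch[0]["stims"].shape[:-1])
    stims = torch.zeros((B, *lead, T_stim))
    for b, item in enumerate(batch):
        s = item["stims"]
        stims[b, ..., : s.shape[-1]] = s

    # Responses: NaN-pad along both the repeat (R) and time (T) axes.
    every = [r for item in batch for r in item["responses"]]
    R_max = max(r.shape[0] for r in every)
    T_resp = max(r.shape[1] for r in every)
    responses = torch.full((B, N, R_max, T_resp), float("nan"))
    for b, item in enumerate(batch):
        for n, r in enumerate(item["responses"]):
            responses[b, n, : r.shape[0], : r.shape[1]] = r

    out = {
        "stims": stims,
        "responses": responses,
        "valid_mask": ~responses.isnan(),  # fine-grained mask; per-item masks are dropped
        "stim_meta": [item["stim_meta"] for item in batch],
    }
    # Any other key (e.g. 'behav') is passed through as a length-B list.
    handled = {"stims", "responses", "valid_mask", "stim_meta"}
    for key in batch[0]:
        if key not in handled:
            out[key] = [item[key] for item in batch]
    return out
```

Downstream, a masked loss can then use boolean indexing directly, e.g. ((pred - out["responses"])[out["valid_mask"]] ** 2).mean() for a prediction pred of shape (B, N_selected, 1, T_resp_max) that broadcasts over repeats.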

deepSTRF.utils.plot_psth_vs_pred(target: ArrayLike, pred: ArrayLike, dt_ms: float | None = None, title: str | None = None, target_label: str = 'PSTH (target)', pred_label: str = 'model', ax: Axes | None = None, legend: bool = True) Axes[source]

Overlay a target PSTH and a model prediction on a single panel.

Designed to be called inside a per-cell loop (e.g. over the best, median, and worst cells) that pre-builds a column of axes; this is the canonical val/test visualisation in fit_ns1_statenet.ipynb and load_pretrained_statenet_ns1.ipynb.

Parameters:
  • target (array-like, shape (T,)) – Trial-averaged target PSTH.

  • pred (array-like, shape (T,)) – Model prediction.

  • dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in seconds; otherwise it is the bin index.

  • title (str, optional) – Axes title.

  • target_label (str, default 'PSTH (target)') – Legend label for the target trace.

  • pred_label (str, default 'model') – Legend label for the prediction trace.

  • ax (matplotlib.axes.Axes, optional) – Pre-made axes to draw into. If None, a new figure is created.

  • legend (bool, default True) – Whether to render the legend.

Return type:

matplotlib.axes.Axes
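The "best / median / worst" loop can be sketched with plain matplotlib. The scoring (Pearson r per cell) and the helpers best_median_worst and overlay are hypothetical; overlay only mimics the documented axis convention (seconds when dt_ms is given). With deepSTRF installed, the inner call would instead be plot_psth_vs_pred(targets[i], preds[i], dt_ms=5.0, title=..., ax=ax).

```python
import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def best_median_worst(scores):
    """Indices of the best, median and worst cells (higher score = better)."""
    order = np.argsort(scores)  # ascending
    return [int(order[-1]), int(order[len(order) // 2]), int(order[0])]


def overlay(ax, target, pred, dt_ms=None, title=None):
    """Hypothetical stand-in mirroring plot_psth_vs_pred's documented x-axis rule."""
    t = np.arange(len(target), dtype=float)
    if dt_ms is not None:
        t = t * dt_ms / 1000.0  # bin index -> seconds
    ax.plot(t, target, label="PSTH (target)")
    ax.plot(t, pred, label="model")
    ax.set_xlabel("time (s)" if dt_ms is not None else "bin")
    if title:
        ax.set_title(title)
    ax.legend()
    return ax


rng = np.random.default_rng(0)
targets = rng.random((5, 200))                        # toy PSTHs, 5 cells x 200 bins
preds = targets + 0.1 * rng.standard_normal((5, 200))
scores = np.array([np.corrcoef(t, p)[0, 1] for t, p in zip(targets, preds)])

picked = best_median_worst(scores)
fig, axes = plt.subplots(len(picked), 1, figsize=(6, 6), sharex=True)
for ax, label, i in zip(axes, ["best", "median", "worst"], picked):
    overlay(ax, targets[i], preds[i], dt_ms=5.0, title=f"{label}: cell {i}, r={scores[i]:.2f}")
```

Pre-building the column of axes and passing each one in is what lets the function draw into a shared figure rather than opening a new one per cell.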

deepSTRF.utils.plot_stim_with_response(stim: ArrayLike, response: ArrayLike, pred: ArrayLike | None = None, dt_ms: float | None = None, title: str | None = None, spec_cmap: str = 'magma', raster_cmap: str = 'Greys', axes: Sequence[Axes] | None = None, figsize: Tuple[float, float] | None = None) Tuple[Figure, Sequence[Axes]][source]

Plot a stimulus spectrogram alongside its recorded (and optionally predicted) response.

Builds a vertically stacked figure with a shared time axis. The middle raster panel is omitted automatically when response is 1-D (already a PSTH).

Parameters:
  • stim (array-like) – Stimulus spectrogram, shape (F, T) or (1, F, T).

  • response (array-like) – Per-trial responses (R, T) or pre-averaged PSTH (T,) / (1, T). R > 1 → spec/raster/PSTH; otherwise spec/PSTH.

  • pred (array-like, optional) – Model prediction, shape (T,). Overlaid on the PSTH panel.

  • dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in seconds; otherwise it is the bin index.

  • title (str, optional) – Suptitle for the whole figure.

  • spec_cmap (str, default 'magma') – Colormap for the spectrogram panel.

  • raster_cmap (str, default 'Greys') – Colormap for the raster panel.

  • axes (sequence of matplotlib.axes.Axes, optional) – Pre-made axes to draw into (2 or 3, matching the panel count). If None, a new figure is created.

  • figsize ((w, h), optional) – Figure size when axes is None.

Returns:

  • fig (matplotlib.figure.Figure)

  • axes (list of matplotlib.axes.Axes) – [spec_ax, raster_ax, psth_ax] or [spec_ax, psth_ax].
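The panel-count rule and the accepted input shapes can be sketched in NumPy. These helpers (panel_layout, to_psth, to_spec) are hypothetical and only encode the documented dispatch: (R, T) with R > 1 gives spec/raster/PSTH, while (T,) or (1, T) gives spec/PSTH.

```python
import numpy as np


def panel_layout(response):
    """Panels drawn for a given response array, following the documented rule."""
    r = np.asarray(response)
    if r.ndim == 2 and r.shape[0] > 1:
        return ["spec", "raster", "psth"]  # per-trial (R, T) with R > 1
    return ["spec", "psth"]                # (T,) or (1, T): already a PSTH


def to_psth(response):
    """Trial-average (R, T) responses; (T,) or (1, T) come back as (T,)."""
    r = np.asarray(response, dtype=float)
    return r.reshape(-1, r.shape[-1]).mean(axis=0)


def to_spec(stim):
    """Accept (F, T) or (1, F, T) spectrograms, returning (F, T)."""
    s = np.asarray(stim, dtype=float)
    return s[0] if s.ndim == 3 else s
```

This is also why a caller passing pre-made axes must supply 3 of them for multi-trial responses and 2 otherwise.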

deepSTRF.utils.plot_strf_grid(strfs: ArrayLike | Sequence[ArrayLike], titles: Sequence[str] | None = None, dt_ms: float | None = None, ncols: int = 4, cmap: str = 'RdBu_r', shared_clim: bool = False, suptitle: str | None = None, figsize: Tuple[float, float] | None = None) Tuple[Figure, Sequence[Axes]][source]

Plot a grid of STRF / gradmap kernels.

The classical STRF visualisation: each panel is an (F, T) weight map, frequency on the y-axis (low→high), time on the x-axis (history; [0, T·dt_ms] if dt_ms is given, else bin index). Diverging colormap (RdBu_r by default) with per-panel symmetric vmax — i.e. each cell gets its own |max| so the spatial structure is comparable across cells of very different gradient magnitudes. Pass shared_clim=True for a global symmetric vmax if the kernels are intended to be compared on the same scale (e.g. same neuron under different model variants).

Parameters:
  • strfs (array-like of shape (K, F, T), or sequence of (F, T) arrays) – Stack of kernels to plot, as NumPy ndarrays or torch tensors.

  • titles (sequence of str, optional) – Length-K list of per-panel titles. None → unlabeled.

  • dt_ms (float, optional) – Bin width in milliseconds. If given, the x-axis is in ms (history extent [0, T·dt_ms]); otherwise it is the bin index.

  • ncols (int, default 4) – Number of columns in the grid; rows = ceil(K / ncols).

  • cmap (str, default 'RdBu_r') – Diverging colormap used for every panel.

  • shared_clim (bool, default False) – If True, use one global symmetric vmax = max_k |strf_k| across all panels. Otherwise per-panel.

  • suptitle (str, optional) – Figure-level title.

  • figsize ((w, h), optional) – Figure size. Default scales with the grid shape.

Returns:

  • fig (matplotlib.figure.Figure)

  • axes (flat list of matplotlib.axes.Axes) – Length K; unused grid cells (when K < nrows·ncols) are hidden via ax.axis('off') and not included in the return.
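The colour-limit and grid-size rules can be sketched in NumPy. grid_shape and symmetric_clims are hypothetical helpers that follow the description above (rows = ceil(K / ncols); symmetric vmax either per panel or shared); CPU torch tensors also work since np.asarray converts them.

```python
import math
import numpy as np


def grid_shape(k, ncols=4):
    """(nrows, ncols) with nrows = ceil(K / ncols)."""
    return math.ceil(k / ncols), ncols


def symmetric_clims(strfs, shared_clim=False):
    """Per-panel (vmin, vmax) pairs, symmetric about zero for a diverging colormap."""
    peaks = [float(np.abs(np.asarray(s, dtype=float)).max()) for s in strfs]
    if shared_clim:
        vmax = max(peaks)                    # one global |max| across all kernels
        return [(-vmax, vmax) for _ in peaks]
    return [(-p, p) for p in peaks]          # each kernel scaled to its own |max|
```

Each pair would then typically be passed to matplotlib as imshow(strf, cmap="RdBu_r", vmin=vmin, vmax=vmax, origin="lower", aspect="auto"), so that zero maps to the colormap's neutral midpoint and low frequencies sit at the bottom; whether deepSTRF uses exactly these imshow arguments is an assumption. With K = 5 and ncols = 4 the grid is 2 x 4, leaving 3 hidden cells.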