deepSTRF.training package

Submodules

deepSTRF.training.fitter module

deepSTRF.training.Fitter — a thin, opt-in PyTorch training loop.

See docs/_source/md/fitter.md for the full design contract. This module implements the canonical 3-line training step from metrics_paradigm.md §7, plus early stopping, checkpoint selection, and cross-batch metric accumulation. The class is intentionally short — when something doesn’t fit (multi-GPU, mixed-precision, curricula, …) the recommended path is to write the loop, not to extend the Fitter.

class deepSTRF.training.fitter.Fitter(model: Module, train_loader: DataLoader, val_loader: DataLoader, *, loss_fn: Callable[[Tensor, Tensor], Tensor] = mse_loss, val_metrics: Dict[str, Callable] | None = None, optimizer: Optimizer | None = None, device: str | torch.device = 'cpu', max_epochs: int = 1000, patience: int = 10, monitor: str = 'val_cc_norm', mode: str = 'max', ckpt_path: str | Path | None = None, log_fn: Callable[[Mapping[str, Any]], None] = _format_epoch, track_train_metrics: bool = True, track_per_cell_best: bool = False)

Bases: object

Opt-in training loop for a deepSTRF NeuralModel.

See docs/_source/md/fitter.md for the full design.

Parameters:
  • model – Any nn.Module whose forward emits (B, N, 1, T) predictions. Stateful models may implement model.detach() (no-op by default on deepSTRF.models.NeuralModel); the Fitter calls it after every step.

  • train_loader – DataLoader built with deepSTRF.utils.data.neural_collate.

  • val_loader – DataLoader built with deepSTRF.utils.data.neural_collate. batch_size=1 is the simplest case, but any batch size works thanks to NaN-pad-and-cat (§6).

  • loss_fn – Callable (pred, responses) -> Tensor. Default mse_loss. The deepSTRF losses auto-collapse responses to PSTH internally (metrics_paradigm.md §2), so no caller-side nanmean is needed.

  • val_metrics – Mapping name -> callable(pred, responses) -> per-neuron Tensor. Default: the canonical {'cc', 'cc_norm'} pair. Stored under f'val_{name}' in the epoch dict.

  • optimizer – Any torch.optim.Optimizer. Default: AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4).

  • device – Where to place the model and per-batch tensors.

  • max_epochs – Hard cap on training epochs.

  • patience – Early-stop patience: number of epochs without improvement on monitor before the loop terminates.

  • monitor – Key in the per-epoch dict to track for early stopping. Default 'val_cc_norm'. Use 'val_loss', 'val_cc', or any custom key you added via val_metrics.

  • mode – 'max' or 'min': direction of improvement on monitor. Default 'max' (paired with 'val_cc_norm').

  • ckpt_path – If given, save the best-on-monitor state_dict to this path and restore it at the end of fit().

  • log_fn – Called as log_fn(epoch_dict) once per epoch. Default: a small formatter that prints epoch | k=v | .... Override to log to WandB, MLflow, a file, etc.

  • track_train_metrics – If True (default), recompute val_metrics over the training predictions accumulated this epoch and add them to the epoch dict as 'train_<name>'. Useful for diagnosing overfitting but expensive on large datasets — accumulating (B, N, R, T) responses across all train batches is the dominant per-epoch cost when N × R × T is in the millions (e.g. AA2’s 494-cell population). Set to False to skip; train_loss is always reported.

  • track_per_cell_best – If True, maintain a per-cell best-on-monitor snapshot of the readout’s per-N parameter and buffer slices throughout training. At the end of fit(), after the global ckpt_path restore, each cell’s slice is overlaid with its individual-best snapshot. On models with no parameters shared across cells, this is at least as good as the vanilla restore on the validation set, cell by cell, by construction: every cell ends up at its individual validation peak. The training trajectory itself is unchanged (no gradient masking, no per-cell stopping); the only difference is which checkpoint is restored at the end. Requires val_metrics[monitor.removeprefix('val_')] to return an (N,) per-cell tensor (the default _default_val_metrics() does this). Default False.
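The per-cell bookkeeping behind track_per_cell_best can be sketched in plain Python. This is a hypothetical illustration, not the Fitter’s code: PerCellBest is an invented name, each cell’s readout slice is modelled as one list element instead of real per-N parameter and buffer slices, and NaN scores are assumed never to count as an improvement.

```python
import copy
import math
from typing import List, Optional


class PerCellBest:
    """Hypothetical tracker that keeps each cell's best-on-monitor readout slice.

    The readout is modelled as a list with one entry per cell, standing in for
    the per-N slices of the real readout's parameters and buffers.
    """

    def __init__(self, n_cells: int, mode: str = "max") -> None:
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.mode = mode
        self.best_score: List[Optional[float]] = [None] * n_cells
        self.best_slice: List[object] = [None] * n_cells

    def _better(self, new: float, old: Optional[float]) -> bool:
        if math.isnan(new):
            return False  # assumption: a NaN score never counts as an improvement
        if old is None:
            return True
        return new > old if self.mode == "max" else new < old

    def update(self, per_cell_score: List[float], readout: List[object]) -> None:
        # Called once per epoch with the (N,) validation metric for every cell.
        for i, score in enumerate(per_cell_score):
            if self._better(score, self.best_score[i]):
                self.best_score[i] = score
                self.best_slice[i] = copy.deepcopy(readout[i])

    def overlay(self, readout: List[object]) -> None:
        # Applied after the global best-checkpoint restore: each cell with a
        # snapshot is overwritten by its individual-best slice.
        for i, snap in enumerate(self.best_slice):
            if snap is not None:
                readout[i] = copy.deepcopy(snap)
```

Because the overlay only swaps which snapshot ends up in each cell’s slice, it never changes the gradients seen during training, which matches the "trajectory unchanged" guarantee above.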

compute_loss(pred: Tensor, responses: Tensor) → Tensor

Default: delegate to self.loss_fn(pred, responses) (auto-PSTH inside).

evaluate(loader: DataLoader) → Dict[str, Any]

Run loss + val_metrics on a loader (no backprop, no key prefix).

Returns a dict with keys 'loss' plus each entry of self.val_metrics. For test-set evaluation after training: fitter.evaluate(test_loader).

fit() → List[Dict[str, Any]]

Train until max_epochs or early-stop on monitor.

Returns:

history – One dict per completed epoch, with keys 'epoch', 'train_*', and 'val_*'.

Return type:

list of dict
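The early-stopping rule that fit() applies can be sketched as a pure function over the sequence of monitor values. This illustration rests on assumptions: the function name is hypothetical, improvements are taken to be strict, and the loop is assumed to stop once patience consecutive epochs fail to improve. The real Fitter may differ on ties and off-by-one details.

```python
import math
from typing import List, Tuple


def early_stop_trajectory(
    scores: List[float], patience: int = 10, mode: str = "max"
) -> Tuple[int, int]:
    """Return (best_epoch, last_epoch) for a sequence of per-epoch monitor values.

    len(scores) plays the role of max_epochs; a NaN score never improves.
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    best = -math.inf if mode == "max" else math.inf
    best_epoch, bad_epochs = -1, 0
    for epoch, score in enumerate(scores):
        improved = score > best if mode == "max" else score < best
        if improved:
            best, best_epoch, bad_epochs = score, epoch, 0
            # With ckpt_path set, this is where the best state_dict would be saved.
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, epoch
    return best_epoch, len(scores) - 1
```

For example, with patience=3 and mode='max', the trajectory [0.1, 0.3, 0.2, 0.25, 0.28] peaks at epoch 1 and stops at epoch 4; with ckpt_path set, the epoch-1 weights would be the ones restored.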

on_epoch_end(epoch: int, epoch_dict: Dict[str, Any]) → None

Default: log the epoch dict via self.log_fn.
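Since log_fn only has to accept the epoch dict and return None, any callable of that shape can replace the default formatter. Below is a minimal sketch that writes one JSON object per epoch, assuming JSON lines are the desired format; JsonLinesLog is a hypothetical name, not part of deepSTRF. It would be passed as Fitter(..., log_fn=JsonLinesLog(open("epochs.jsonl", "a"))).

```python
import json
import sys
from typing import Any, Mapping, Optional, TextIO


class JsonLinesLog:
    """Hypothetical log_fn: write each epoch dict as one JSON line."""

    def __init__(self, stream: Optional[TextIO] = None) -> None:
        # Resolve sys.stdout lazily so a later redirect is respected.
        self.stream = stream if stream is not None else sys.stdout

    def __call__(self, epoch_dict: Mapping[str, Any]) -> None:
        row = {}
        for key, value in epoch_dict.items():
            # Tensors and arrays are not JSON-serialisable; .tolist() turns them
            # (including 0-d scalars) into plain Python numbers or lists.
            row[key] = value.tolist() if hasattr(value, "tolist") else value
        self.stream.write(json.dumps(row) + "\n")
        self.stream.flush()
```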

deepSTRF.training.multi_seed module

deepSTRF.training.fit_multi_seed — multi-seed init-variance training.

See docs/_source/md/fitter.md §4.2 for the full design.

This wrapper runs the same Fitter configuration K times under different seeds and aggregates per-neuron val and test metrics across seeds. It addresses initialization variance — the same data split, the same hyperparameters, but different initial weights and shuffle order. It is NOT k-fold or leave-one-stim-out cross-validation; those are separate TODOs that need a split-factory rather than a seed sweep.

deepSTRF.training.multi_seed.fit_multi_seed(model_factory: Callable[[int], Module], loader_factory: Callable[[int], Tuple[DataLoader, DataLoader, DataLoader]], n_seeds: int = 5, *, seeds: Sequence[int] | None = None, fitter_kwargs: Mapping[str, Any] | None = None, logger_factory: Callable[[int], Any] | None = None, output_dir: str | Path | None = None, set_seed_strict: bool = False) → Dict[str, Any]

Run the same Fitter configuration n_seeds times under different seeds.

For each seed: call set_random_seed(seed), instantiate model and (train, val, test) loaders via the factories, fit a Fitter with fitter_kwargs, then re-evaluate post-restore on val and test.

This is multi-seed init variance — same split, different init. NOT k-fold CV (that needs a split-factory, separate TODO).

Parameters:
  • model_factory – Callable seed -> nn.Module. Called fresh each seed, after set_random_seed(seed), so weight init draws from the seeded RNG.

  • loader_factory – Callable seed -> (train_loader, val_loader, test_loader). Required 3-tuple. Called fresh each seed so each run gets a deterministic shuffle generator.

  • n_seeds – Number of seeds to run. Ignored if seeds is given. Default 5.

  • seeds – Explicit seed list. Overrides n_seeds. Default [0, 1, ..., n_seeds-1].

  • fitter_kwargs – Forwarded to Fitter(...). May not include 'model', 'train_loader', or 'val_loader'. If 'ckpt_path' is set, each seed’s checkpoint is saved to <stem>_seed{seed}<suffix> so seeds don’t overwrite each other.

  • logger_factory – Optional Callable[[int], SeedLogger]. If given, the returned object is used as the Fitter’s log_fn for that seed (so it is invoked once per epoch with the epoch dict). After fit(), if the logger exposes finalize(final_metrics), it is called with {'val': val_post, 'test': test_post}. close() is called at end of seed if present. See deepSTRF.training.wandb_log.WandbSeedLogger for the reference implementation. Default None (silent multi-seed sweep unless fitter_kwargs['log_fn'] is set).

  • output_dir – If given, write per-seed history.json (JSON-summarised epoch log), final.json (post-fit val + test population summaries), final_neurons.pt (per-neuron tensors for downstream analysis), and best.pt (state_dict) under output_dir/seed{seed}/. Also writes output_dir/summary.json (across-seed mean / std + best_seed) and output_dir/summary_neurons.pt (per-neuron mean / std / per-seed tensors). Logger-agnostic: this happens regardless of logger_factory. Default None (in-memory results only).

  • set_seed_strict – Forwarded to set_random_seed(strict=...). Default False.
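The per-seed file naming above can be expressed with pathlib. This sketch assumes that <stem> and <suffix> follow pathlib.Path.stem / Path.suffix semantics, so 'model.tar.gz' would become 'model.tar_seed0.gz'. The helper names are hypothetical.

```python
from pathlib import Path
from typing import Union


def seed_ckpt_path(ckpt_path: Union[str, Path], seed: int) -> Path:
    """<stem>_seed{seed}<suffix>, so per-seed checkpoints never overwrite each other."""
    p = Path(ckpt_path)
    return p.with_name(f"{p.stem}_seed{seed}{p.suffix}")


def seed_output_dir(output_dir: Union[str, Path], seed: int) -> Path:
    """Per-seed artefact directory holding history.json, final.json, and so on."""
    return Path(output_dir) / f"seed{seed}"
```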

Returns:

results

Keys:
  • 'seeds' : list[int]

  • 'per_seed_histories' : list[list[dict]] — raw Fitter histories, one list per seed.

  • For each metric m (default cc, cc_norm, loss):
    • 'per_seed_val_<m>' : (n_seeds, N) for per-neuron, (n_seeds, 1) for scalar (loss)

    • 'mean_val_<m>' : (N,) nanmean across seeds

    • 'std_val_<m>' : (N,) population nanstd across seeds

    • same triple for test.

  • 'best_seed' : int — seed whose post-fit val monitor metric was best (max or min depending on mode).

  • 'best_state_dict' : OrderedDict[str, Tensor] — deep copy of the best seed’s post-fit model state.

Return type:

dict

Notes

“Best” uses the same (monitor, mode) pair as the per-seed Fitters (default val_cc_norm / max), reduced to a scalar via nanmean over the neuron axis.
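The across-seed reductions and the best-seed rule can be illustrated in plain Python on nested lists shaped (n_seeds, N). The real function works on tensors. The helper names here are hypothetical, and skipping seeds whose monitor reduces to NaN is an assumption.

```python
import math
from typing import Dict, List, Sequence


def nanmean(xs: Sequence[float]) -> float:
    vals = [x for x in xs if not math.isnan(x)]
    return sum(vals) / len(vals) if vals else math.nan


def nanstd_population(xs: Sequence[float]) -> float:
    # Population (ddof=0) standard deviation, ignoring NaNs.
    vals = [x for x in xs if not math.isnan(x)]
    if not vals:
        return math.nan
    mu = sum(vals) / len(vals)
    return math.sqrt(sum((v - mu) ** 2 for v in vals) / len(vals))


def aggregate(per_seed: List[List[float]]) -> Dict[str, List[float]]:
    """per_seed has shape (n_seeds, N); reduce over the seed axis, neuron by neuron."""
    columns = list(zip(*per_seed))  # N tuples, each of length n_seeds
    return {
        "mean": [nanmean(c) for c in columns],
        "std": [nanstd_population(c) for c in columns],
    }


def best_seed(seeds: Sequence[int], per_seed: List[List[float]], mode: str = "max") -> int:
    # Reduce each seed's (N,) monitor metric to a scalar via nanmean over neurons,
    # then pick the max or min; seeds whose score is NaN are skipped.
    scored = [(s, nanmean(row)) for s, row in zip(seeds, per_seed)]
    scored = [(s, v) for s, v in scored if not math.isnan(v)]
    pick = max if mode == "max" else min
    return pick(scored, key=lambda item: item[1])[0]
```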

deepSTRF.training.config module

deepSTRF.training.auto_config — build wandb.config / mlflow.log_params dicts from a deepSTRF model + Fitter kwargs.

The point: a wandb run is most useful when every knob that varies across runs is captured in wandb.config so the run table can sort / filter / colour by it. Writing that dict by hand for each experiment is boring; this module introspects the common deepSTRF audio-encoding shape and pulls out the obvious fields.

See docs/_source/md/fitter.md §4.3.

deepSTRF.training.config.auto_config(model: Module, fitter_kwargs: Mapping[str, Any] | None = None, dataset_name: str | None = None, extra: Mapping[str, Any] | None = None) → dict

Build a JSON-friendly config dict from a deepSTRF model + fitter_kwargs.

Pulled fields, when present:
  • model : class name of model

  • dataset : dataset_name if given

  • n_frequency_bands : model.F

  • temporal_window_size : model.T

  • out_neurons : model.O

  • prefiltering : class name of model.prefiltering

  • core : class name of model.core

  • readout : class name of model.readout

  • readout.kernel : class name of model.readout.kernel (if present)

  • readout.activation : class name of model.readout.activation (if present)

  • output_activation : class name of model.output_activation (if present)

  • Every JSON-friendly fitter_kwargs entry; non-friendly values are stored as their repr() truncated to 60 chars.

  • Every key/value in extra, overwriting any of the above.

The result is safe to pass as wandb.config (no torch tensors / Modules).
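One plausible way to implement the JSON-friendly test for fitter_kwargs is to attempt json.dumps and fall back to a truncated repr(). This sketches the documented behaviour and is not auto_config’s actual code; sanitize_kwargs is a hypothetical name.

```python
import json
from typing import Any, Dict, Mapping


def _json_friendly(value: Any) -> bool:
    try:
        json.dumps(value)
    except (TypeError, ValueError):  # unserialisable type, or a circular reference
        return False
    return True


def sanitize_kwargs(fitter_kwargs: Mapping[str, Any], max_len: int = 60) -> Dict[str, Any]:
    """Keep JSON-serialisable values; store everything else as repr() cut to max_len chars."""
    out: Dict[str, Any] = {}
    for key, value in fitter_kwargs.items():
        out[key] = value if _json_friendly(value) else repr(value)[:max_len]
    return out
```

A loss function or device object then shows up in the run table as a readable string rather than breaking wandb.config.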

deepSTRF.training.config.slug_from_config(config: Mapping[str, Any], fields: Sequence | None = None) → str

Build a short dash-separated slug like 'linear-ns1-T9-F34'.

Parameters:
  • config – A config dict (typically from auto_config()).

  • fields – Sequence of (key, formatter) pairs. formatter is a callable value -> str. Missing keys are silently skipped. Default fields: model (lowercased), dataset (lowercased), temporal_window_size (as T{v}), n_frequency_bands (as F{v}).
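The documented default fields can be written out as (key, formatter) pairs. This minimal re-implementation assumes the slug is simply the formatted values joined with '-' in field order; slug_sketch and DEFAULT_FIELDS are hypothetical names. For a hypothetical config with model='Linear', dataset='NS1', temporal_window_size=9, and n_frequency_bands=34, it reproduces the 'linear-ns1-T9-F34' example.

```python
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple

Field = Tuple[str, Callable[[Any], str]]

# Assumed to mirror the documented default fields, in the documented order.
DEFAULT_FIELDS: Sequence[Field] = (
    ("model", lambda v: str(v).lower()),
    ("dataset", lambda v: str(v).lower()),
    ("temporal_window_size", lambda v: f"T{v}"),
    ("n_frequency_bands", lambda v: f"F{v}"),
)


def slug_sketch(config: Mapping[str, Any], fields: Optional[Sequence[Field]] = None) -> str:
    fields = DEFAULT_FIELDS if fields is None else fields
    # Keys missing from the config are skipped silently rather than raising.
    parts = [fmt(config[key]) for key, fmt in fields if key in config]
    return "-".join(parts)
```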

deepSTRF.training.seed module

Reproducibility seeding for deepSTRF.training.

See docs/_source/md/fitter.md §7.

deepSTRF.training.seed.set_random_seed(seed: int, *, strict: bool = False) → None

Seed Python random, NumPy, PyTorch CPU, and CUDA.

Parameters:
  • seed – Integer seed value, broadcast to all four RNGs.

  • strict – If True, additionally enables torch.use_deterministic_algorithms(True) (with the CUBLAS_WORKSPACE_CONFIG env var required by CUBLAS) so that any non-deterministic CUDA op raises rather than silently breaking reproducibility. Trades speed for bit-exactness; mainly useful for debugging. Default False.

Notes

Even with strict=False, cudnn.deterministic is set to True and cudnn.benchmark to False — that is the legacy deepSTRF behaviour and is enough for run-to-run reproducibility on the same hardware so long as no opt-in non-deterministic kernel is invoked.
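The seeding behaviour described above corresponds roughly to the following calls. This is a sketch, not deepSTRF’s implementation. The NumPy and PyTorch imports are guarded so it also runs without them. ':4096:8' is one of the CUBLAS_WORKSPACE_CONFIG values recommended in PyTorch’s reproducibility notes; the value deepSTRF itself uses is not stated here.

```python
import os
import random


def set_random_seed_sketch(seed: int, *, strict: bool = False) -> None:
    """Hypothetical re-creation of the documented seeding behaviour."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:  # keep the sketch runnable without NumPy
        pass
    try:
        import torch
    except ImportError:  # ... and without PyTorch
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    # Applied even when strict=False, as the Notes above describe.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if strict:
        # Deterministic cuBLAS needs this set before cuBLAS is first used.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)
```

Note that strict=True flips a process-wide flag: afterwards, PyTorch ops without a deterministic implementation raise a RuntimeError instead of running.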

deepSTRF.training.wandb_log module

Opt-in WandB logger for fit_multi_seed (and any Fitter user).

This module is a thin convenience wrapper. The library does not require WandB for training — fit_multi_seed only knows about a generic logger_factory(seed) -> SeedLogger protocol. WandbSeedLogger implements that protocol with one wandb.init per seed.

Build a factory with make_wandb_logger_factory() and pass it as logger_factory= to deepSTRF.training.fit_multi_seed().

Examples

>>> from deepSTRF.training import fit_multi_seed
>>> from deepSTRF.training.wandb_log import make_wandb_logger_factory
>>> results = fit_multi_seed(
...     model_factory=..., loader_factory=..., n_seeds=3,
...     logger_factory=make_wandb_logger_factory(
...         project="deepstrf", entity="urancon",
...         group="ns1-linear", mode="offline",
...     ),
... )

If WandB is not installed, importing this module raises ImportError with a clear hint; the rest of deepSTRF.training keeps working.

class deepSTRF.training.wandb_log.WandbSeedLogger(seed: int, *, project: str | None = None, entity: str | None = None, group: str | None = None, name: str | None = None, dir: str | None = None, mode: str | None = None, config: Mapping[str, Any] | None = None, **wandb_init_kwargs: Any)

Bases: object

One-seed logger backed by a single wandb.Run.

Implements the duck-typed SeedLogger protocol expected by fit_multi_seed():

  • __call__(epoch_dict) is invoked once per epoch; per-neuron tensors are summarised as {mean, p10, p50, p90} scalars plus a per-epoch wandb.Histogram.

  • finalize(final_metrics) is invoked once at end of seed, after Fitter.fit() and the post-fit val + test evaluate passes. It writes test metrics + per-neuron summaries to run.summary (so they show up as sortable run-level columns).

  • close() calls run.finish().
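Because the protocol is duck-typed, any callable mapping a seed to an object with these methods works as a logger_factory. Per fit_multi_seed’s contract, finalize and close are only called if present. Below is a minimal in-memory sketch (InMemorySeedLogger is hypothetical). Since the class itself maps a seed to a logger, it can be passed directly as logger_factory=InMemorySeedLogger.

```python
from typing import Any, Dict, List, Mapping


class InMemorySeedLogger:
    """Hypothetical SeedLogger that keeps everything in Python lists and dicts."""

    def __init__(self, seed: int) -> None:
        self.seed = seed
        self.epochs: List[Dict[str, Any]] = []
        self.final: Dict[str, Dict[str, Any]] = {}
        self.closed = False

    def __call__(self, epoch_dict: Mapping[str, Any]) -> None:
        # Used as the Fitter's log_fn: invoked once per epoch.
        self.epochs.append(dict(epoch_dict))

    def finalize(self, final_metrics: Mapping[str, Mapping[str, Any]]) -> None:
        # Invoked once per seed with {'val': {...}, 'test': {...}}.
        self.final = {split: dict(m) for split, m in final_metrics.items()}

    def close(self) -> None:
        self.closed = True
```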

Parameters:
  • seed – The seed value. Added to wandb.config and used to derive the run name.

  • project, entity, group – Forwarded to wandb.init. The run name is auto-derived as <name>-seed{seed} if name= is set, else <group>-seed{seed} if group= is set, else seed{seed}.

  • name, dir, mode, config, **wandb_init_kwargs – Forwarded to wandb.init verbatim. WANDB_MODE defaults to offline if the env var is unset and mode= is not given, so logging works with no account / network.
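The run-name fallback chain and the offline default can be expressed as two small functions. These hypothetical helpers restate the documented rules and are not part of deepSTRF.training.wandb_log. That an explicit mode= takes precedence over WANDB_MODE is an assumption.

```python
import os
from typing import Optional


def derive_run_name(seed: int, name: Optional[str] = None, group: Optional[str] = None) -> str:
    # <name>-seed{seed}, else <group>-seed{seed}, else seed{seed}.
    if name is not None:
        return f"{name}-seed{seed}"
    if group is not None:
        return f"{group}-seed{seed}"
    return f"seed{seed}"


def effective_wandb_mode(mode: Optional[str] = None) -> str:
    # Assumed precedence: explicit mode=, then the WANDB_MODE env var, then "offline".
    if mode is not None:
        return mode
    return os.environ.get("WANDB_MODE", "offline")
```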

close() → None

finalize(final_metrics: Mapping[str, Mapping[str, Any]]) → None

Write {prefix}_{metric} summary rows + percentile scalars.

final_metrics keys are split prefixes (e.g. "val", "test") each holding a dict of metric_name -> scalar-or-per-neuron-tensor.
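Below is a sketch of how final_metrics might be flattened into {prefix}_{metric} rows, with per-neuron values reduced to a mean plus p10 / p50 / p90. Several details are assumptions: the _p{q} key suffixes, dropping NaN neurons, and the linear-interpolation percentiles. WandbSeedLogger’s exact tag names may differ.

```python
import math
from typing import Any, Dict, List, Mapping


def _percentile(sorted_vals: List[float], q: float) -> float:
    # Linear interpolation between closest ranks (NumPy's default method).
    pos = (len(sorted_vals) - 1) * q / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)


def flatten_final_metrics(
    final_metrics: Mapping[str, Mapping[str, Any]],
) -> Dict[str, float]:
    """Hypothetical flattening of {'val': {...}, 'test': {...}} into summary rows."""
    rows: Dict[str, float] = {}
    for prefix, metrics in final_metrics.items():
        for metric, value in metrics.items():
            key = f"{prefix}_{metric}"
            if hasattr(value, "tolist"):  # tensors / arrays -> Python objects
                value = value.tolist()
            if isinstance(value, (int, float)):  # scalar metric, e.g. loss
                rows[key] = float(value)
                continue
            vals = sorted(float(v) for v in value if not math.isnan(v))
            if not vals:  # every neuron was NaN
                rows[key] = math.nan
                continue
            rows[key] = sum(vals) / len(vals)
            for q in (10, 50, 90):
                rows[f"{key}_p{q}"] = _percentile(vals, q)
    return rows
```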

deepSTRF.training.wandb_log.make_wandb_logger_factory(**wandb_kwargs: Any) → Callable[[int], WandbSeedLogger]

Build a logger_factory that returns one WandbSeedLogger per seed.

All kwargs are forwarded to WandbSeedLogger. Use as fit_multi_seed(..., logger_factory=make_wandb_logger_factory(project="...", ...)).
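The factory pattern amounts to binding the shared keyword arguments once and deferring the seed. Here is a generic sketch; make_logger_factory is hypothetical, and functools.partial(logger_cls, **kwargs) would behave the same way.

```python
from typing import Any, Callable, TypeVar

L = TypeVar("L")


def make_logger_factory(logger_cls: Callable[..., L], **kwargs: Any) -> Callable[[int], L]:
    """Bind the shared kwargs now; build one fresh logger per seed later."""

    def factory(seed: int) -> L:
        return logger_cls(seed, **kwargs)

    return factory
```

Each call to the returned factory builds a fresh logger, so per-seed state (for example one wandb.Run per seed) never leaks between seeds.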

deepSTRF.training.tb_log module

Opt-in TensorBoard logger for fit_multi_seed (and any Fitter user).

Same protocol as deepSTRF.training.wandb_log — implements the SeedLogger duck-typed interface __call__ / finalize / close against torch.utils.tensorboard.SummaryWriter instead of wandb.Run.

Use TensorBoard when:

  • You want a local-only, no-account, single-process viewer.

  • You’re on a machine with no outbound network.

Use WandB when:

  • You want a cross-run table view with sortable columns.

  • You want cloud-hosted persistence + sharing.

Browse the logs with tensorboard --logdir=<your log_dir>; the URL defaults to http://localhost:6006.

Examples

>>> from deepSTRF.training import fit_multi_seed
>>> from deepSTRF.training.tb_log import make_tensorboard_logger_factory
>>> results = fit_multi_seed(
...     model_factory=..., loader_factory=..., n_seeds=3,
...     logger_factory=make_tensorboard_logger_factory(
...         log_dir="tb_logs", group="linear-ns1-T9",
...         config={"model": "Linear", "T": 9, "F": 34},
...     ),
... )

class deepSTRF.training.tb_log.TensorBoardSeedLogger(seed: int, *, log_dir: str | Path = 'tb_logs', group: str | None = None, name: str | None = None, config: Mapping[str, Any] | None = None)

Bases: object

One-seed logger backed by a torch.utils.tensorboard.SummaryWriter.

Implements the SeedLogger protocol expected by fit_multi_seed(). Each seed writes to {log_dir}/{group or 'default'}/{run_name}/ so a typical tensorboard --logdir <log_dir> invocation surfaces every group × seed combination side by side.

Parameters:
  • seed – Seed value. Used in the run name and the hparams/seed scalar.

  • log_dir – Root directory for event files. Subdirectories per group / seed are created automatically.

  • group – Optional group name. Determines the second-level subdirectory.

  • name – Optional user-supplied run name. Final run dir is {name}-seed{seed}; falls back to {group}-seed{seed} then seed{seed}.

  • config – Optional dict of hparams. Each JSON-friendly entry is written as hparams/<key> (numeric values as scalars, others via add_text). The full config is also dumped as a single config text entry.
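The resulting directory layout can be sketched with pathlib. tb_run_dir is a hypothetical helper that restates the documented rules: the run-name fallback from the name parameter, and the 'default' subdirectory when no group is given.

```python
from pathlib import Path
from typing import Optional, Union


def tb_run_dir(
    log_dir: Union[str, Path], seed: int, group: Optional[str] = None, name: Optional[str] = None
) -> Path:
    # Run name falls back {name}-seed{seed} -> {group}-seed{seed} -> seed{seed}.
    if name is not None:
        run_name = f"{name}-seed{seed}"
    elif group is not None:
        run_name = f"{group}-seed{seed}"
    else:
        run_name = f"seed{seed}"
    # Second level is the group, or 'default' when no group is given.
    return Path(log_dir) / (group or "default") / run_name
```

Pointing tensorboard --logdir at the root then shows every group × seed run side by side.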

close() → None

finalize(final_metrics: Mapping[str, Mapping[str, Any]]) → None

Write final {prefix}_{metric} scalars plus percentile / histogram tags at step 0. TensorBoard has no separate “summary” concept, so a fixed step is used and each value appears as a single point on its chart.

deepSTRF.training.tb_log.make_tensorboard_logger_factory(**tb_kwargs: Any) → Callable[[int], TensorBoardSeedLogger]

Build a logger_factory that returns one TensorBoardSeedLogger per seed.

All kwargs are forwarded to TensorBoardSeedLogger.

Module contents

deepSTRF.training — opt-in training utilities.

See docs/_source/md/fitter.md.

The package re-exports the following names at the top level (e.g. from deepSTRF.training import fit_multi_seed). Each is documented identically to its submodule entry above:

  • deepSTRF.training.Fitter – see deepSTRF.training.fitter.Fitter.

  • deepSTRF.training.auto_config – see deepSTRF.training.config.auto_config.

  • deepSTRF.training.fit_multi_seed – see deepSTRF.training.multi_seed.fit_multi_seed.

  • deepSTRF.training.set_random_seed – see deepSTRF.training.seed.set_random_seed.

  • deepSTRF.training.slug_from_config – see deepSTRF.training.config.slug_from_config.