deepSTRF.training package
Submodules
deepSTRF.training.fitter module
deepSTRF.training.Fitter — a thin, opt-in PyTorch training loop.
See docs/_source/md/fitter.md for the full design contract. This module
implements the canonical 3-line training step from
metrics_paradigm.md §7, plus early stopping, checkpoint selection, and
cross-batch metric accumulation. The class is intentionally short — when
something doesn’t fit (multi-GPU, mixed-precision, curricula, …) the
recommended path is to write the loop, not to extend the Fitter.
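The control flow described above can be sketched in plain Python: train for one epoch, run a validation pass, track the best value of `monitor`, stop after `patience` epochs without improvement, then restore the best state. This is a minimal sketch under stated assumptions, not the Fitter's source. `run_epoch`, `evaluate`, `get_state` and `set_state` are hypothetical callables standing in for the 3-line training step, the validation pass and `state_dict` save/restore. `monitor`, `mode` and `patience` mirror the constructor arguments documented below.

```python
from typing import Any, Callable, Dict, List


def fit_sketch(
    run_epoch: Callable[[], float],            # hypothetical: one pass over train_loader, returns train loss
    evaluate: Callable[[], Dict[str, float]],  # hypothetical: val pass, unprefixed keys ('loss', 'cc', ...)
    get_state: Callable[[], Any],              # hypothetical: snapshot of model.state_dict()
    set_state: Callable[[Any], None],          # hypothetical: model.load_state_dict(...)
    *,
    max_epochs: int = 1000,
    patience: int = 10,
    monitor: str = "val_cc_norm",
    mode: str = "max",
) -> List[Dict[str, Any]]:
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    sign = 1.0 if mode == "max" else -1.0  # flip 'min' so that larger is always better
    best_score = float("-inf")
    best_state: Any = None
    stale_epochs = 0
    history: List[Dict[str, Any]] = []

    for epoch in range(max_epochs):
        epoch_dict: Dict[str, Any] = {"epoch": epoch, "train_loss": run_epoch()}
        epoch_dict.update({f"val_{k}": v for k, v in evaluate().items()})
        history.append(epoch_dict)

        score = sign * epoch_dict[monitor]
        if score > best_score:  # a NaN score never counts as an improvement
            best_score, best_state, stale_epochs = score, get_state(), 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break  # early stop: `patience` epochs without improvement

    if best_state is not None:
        set_state(best_state)  # restore the best-on-monitor snapshot
    return history
```

The sign flip lets one comparison serve both `mode='max'` (e.g. `val_cc_norm`) and `mode='min'` (e.g. `val_loss`).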
- class deepSTRF.training.fitter.Fitter(model: Module, train_loader: DataLoader, val_loader: DataLoader, *, loss_fn: Callable[[Tensor, Tensor], Tensor] = <function mse_loss>, val_metrics: Dict[str, Callable] | None = None, optimizer: Optimizer | None = None, device: str | torch.device = 'cpu', max_epochs: int = 1000, patience: int = 10, monitor: str = 'val_cc_norm', mode: str = 'max', ckpt_path: str | Path | None = None, log_fn: Callable[[Mapping[str, Any]], None] = <function _format_epoch>, track_train_metrics: bool = True, track_per_cell_best: bool = False)
Bases: object
Opt-in training loop for a deepSTRF NeuralModel.
See docs/_source/md/fitter.md for the full design.
- Parameters:
model – Any nn.Module whose forward emits (B, N, 1, T) predictions. Stateful models may implement model.detach() (a no-op by default on deepSTRF.models.NeuralModel); the Fitter calls it after every step.
train_loader, val_loader – DataLoader instances built with deepSTRF.utils.data.neural_collate. A val_loader with batch_size=1 is the simplest case, but any batch size works thanks to NaN-pad-and-cat (§6).
loss_fn – Callable (pred, responses) -> Tensor. Default mse_loss. The deepSTRF losses auto-collapse responses to the PSTH internally (metrics_paradigm.md §2), so no caller-side nanmean is needed.
val_metrics – Mapping name -> callable(pred, responses) -> per-neuron Tensor. Default: the canonical {'cc', 'cc_norm'} pair. Stored under f'val_{name}' in the epoch dict.
optimizer – Any torch.optim.Optimizer. Default: AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4).
device – Where to place the model and per-batch tensors.
max_epochs – Hard cap on training epochs.
patience – Early-stop patience: number of epochs without improvement on monitor before the loop terminates.
monitor – Key in the per-epoch dict to track for early stopping. Default 'val_cc_norm'. Use 'val_loss', 'val_cc', or any custom key you added via val_metrics.
mode – 'max' or 'min': direction of improvement on monitor. Default 'max' (paired with 'val_cc_norm').
ckpt_path – If given, save the best-on-monitor state_dict to this path and restore it at the end of fit().
log_fn – Called as log_fn(epoch_dict) once per epoch. Default: a small formatter that prints epoch | k=v | .... Override to log to WandB, MLflow, a file, etc.
track_train_metrics – If True (default), recompute val_metrics over the training predictions accumulated this epoch and add them to the epoch dict as 'train_<name>'. Useful for diagnosing overfitting, but expensive on large datasets: accumulating (B, N, R, T) responses across all train batches is the dominant per-epoch cost when N × R × T is in the millions (e.g. AA2’s 494-cell population). Set to False to skip; train_loss is always reported.
track_per_cell_best – If True, maintain a per-cell best-on-monitor snapshot of the readout’s per-N parameter and buffer slices throughout training. At the end of fit, after the global ckpt_path restore, each cell’s slice is overlaid with its individual-best snapshot. On models with no shared parameters this is, by construction, at least as good as the vanilla restore on the validation set, cell by cell: every cell ends up at its individual validation peak. The training trajectory itself is unchanged (no gradient masking, no per-cell stopping); the only difference is which checkpoint is restored at the end. Requires val_metrics[monitor.removeprefix('val_')] to return an (N,) per-cell tensor (the default _default_val_metrics() does this). Default False.
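The guarantee behind track_per_cell_best can be checked on toy numbers. If every cell's validation score depends only on its own readout slice (the no-shared-params case), picking each cell's best epoch independently can never score below the single epoch chosen by a population-level monitor. This is a minimal sketch on plain lists, assuming a history of per-cell validation scores indexed `[epoch][cell]` with `mode='max'`. The helpers are hypothetical, not deepSTRF functions.

```python
from statistics import fmean
from typing import List, Sequence

Scores = Sequence[Sequence[float]]  # scores[epoch][cell], higher is better (mode='max')


def global_best_epoch(scores: Scores) -> int:
    """Epoch picked by a population-level monitor: highest mean over cells."""
    means = [fmean(row) for row in scores]
    return max(range(len(means)), key=means.__getitem__)


def per_cell_best_epochs(scores: Scores) -> List[int]:
    """For each cell, the epoch at which that cell's own score peaked."""
    n_epochs, n_cells = len(scores), len(scores[0])
    return [max(range(n_epochs), key=lambda e: scores[e][c]) for c in range(n_cells)]


def final_scores(scores: Scores, epoch_per_cell: Sequence[int]) -> List[float]:
    """Score of each cell after restoring its slice from the given epoch."""
    return [scores[e][c] for c, e in enumerate(epoch_per_cell)]


scores = [
    [0.2, 0.5, 0.1],
    [0.4, 0.3, 0.2],
    [0.3, 0.1, 0.6],
]
g = global_best_epoch(scores)
vanilla = final_scores(scores, [g] * len(scores[0]))
overlaid = final_scores(scores, per_cell_best_epochs(scores))
```

Here the global monitor restores epoch 2 for every cell, while the overlay takes cell 0 from epoch 1 and cell 1 from epoch 0. Every cell ends at least as high as under the vanilla restore.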
- compute_loss(pred: Tensor, responses: Tensor) → Tensor
Default: delegate to self.loss_fn(pred, responses) (auto-PSTH inside).
- evaluate(loader: DataLoader) → Dict[str, Any]
Run loss + val_metrics on a loader (no backprop, no key prefix).
Returns a dict with keys 'loss' plus each entry of self.val_metrics. For test-set evaluation after training: fitter.evaluate(test_loader).
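The module docstring mentions cross-batch metric accumulation, and the loader parameters mention NaN-pad-and-cat. The sketch below illustrates why both matter for a correlation metric, under simplifying assumptions: a single neuron, plain Python lists instead of (B, N, R, T) tensors, and hypothetical helpers `nan_pad` and `nan_pearson`. Batches of unequal duration are right-padded with NaN to a common length and concatenated. The correlation is then computed once over all finite bins rather than averaged per batch, because a mean of per-batch correlations is not, in general, the pooled correlation.

```python
import math
from typing import List, Sequence


def nan_pad(seq: Sequence[float], length: int) -> List[float]:
    """Right-pad a time series with NaN up to `length` bins."""
    return list(seq) + [math.nan] * (length - len(seq))


def nan_pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation over time bins where both x and y are finite."""
    pairs = [(a, b) for a, b in zip(x, y) if not (math.isnan(a) or math.isnan(b))]
    n = len(pairs)
    if n < 2:
        return math.nan
    mx = sum(a for a, _ in pairs) / n
    my = sum(b for _, b in pairs) / n
    sxy = sum((a - mx) * (b - my) for a, b in pairs)
    sxx = sum((a - mx) ** 2 for a, _ in pairs)
    syy = sum((b - my) ** 2 for _, b in pairs)
    if sxx == 0 or syy == 0:
        return math.nan  # correlation undefined for a constant series
    return sxy / math.sqrt(sxx * syy)


# Two "batches" of different duration for a single neuron.
pred_batches = [[0.0, 1.0, 2.0], [1.0, 3.0]]
resp_batches = [[0.1, 0.9, 2.2], [1.2, 2.8]]
T_max = max(len(b) for b in pred_batches)

# Pad every batch to T_max, then concatenate along the batch axis.
pred = [v for b in pred_batches for v in nan_pad(b, T_max)]
resp = [v for b in resp_batches for v in nan_pad(b, T_max)]
cc = nan_pearson(pred, resp)  # one correlation over all 5 finite bins
```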
deepSTRF.training.multi_seed module
deepSTRF.training.fit_multi_seed — multi-seed init-variance training.
See docs/_source/md/fitter.md §4.2 for the full design.
This wrapper runs the same Fitter configuration K times under different seeds and aggregates per-neuron val and test metrics across seeds. It addresses initialization variance — the same data split, the same hyperparameters, but different initial weights and shuffle order. It is NOT k-fold or leave-one-stim-out cross-validation; those are separate TODOs that need a split-factory rather than a seed sweep.
- deepSTRF.training.multi_seed.fit_multi_seed(model_factory: Callable[[int], Module], loader_factory: Callable[[int], Tuple[DataLoader, DataLoader, DataLoader]], n_seeds: int = 5, *, seeds: Sequence[int] | None = None, fitter_kwargs: Mapping[str, Any] | None = None, logger_factory: Callable[[int], Any] | None = None, output_dir: str | Path | None = None, set_seed_strict: bool = False) → Dict[str, Any]
Run the same Fitter configuration n_seeds times under different seeds.
For each seed: call set_random_seed(seed), instantiate the model and the (train, val, test) loaders via the factories, fit a Fitter with fitter_kwargs, then re-evaluate the restored model on val and test.
This measures multi-seed init variance (same split, different init). It is NOT k-fold CV, which needs a split-factory (separate TODO).
- Parameters:
model_factory – Callable seed -> nn.Module. Called fresh each seed, after set_random_seed(seed), so weight init draws from the seeded RNG.
loader_factory – Callable seed -> (train_loader, val_loader, test_loader). Required 3-tuple. Called fresh each seed so each run gets a deterministic shuffle generator.
n_seeds – Number of seeds to run. Ignored if seeds is given. Default 5.
seeds – Explicit seed list. Overrides n_seeds. Default [0, 1, ..., n_seeds-1].
fitter_kwargs – Forwarded to Fitter(...). May not include 'model', 'train_loader', or 'val_loader'. If 'ckpt_path' is set, each seed’s checkpoint is saved to <stem>_seed{seed}<suffix> so seeds don’t overwrite each other.
logger_factory – Optional Callable[[int], SeedLogger]. If given, the returned object is used as the Fitter’s log_fn for that seed (so it is invoked once per epoch with the epoch dict). After fit(), if the logger exposes finalize(final_metrics), it is called with {'val': val_post, 'test': test_post}. close() is called at the end of the seed if present. See deepSTRF.training.wandb_log.WandbSeedLogger for the reference implementation. Default None (a silent multi-seed sweep unless fitter_kwargs['log_fn'] is set).
output_dir – If given, write per-seed history.json (JSON-summarised epoch log), final.json (post-fit val + test population summaries), final_neurons.pt (per-neuron tensors for downstream analysis), and best.pt (state_dict) under output_dir/seed{seed}/. Also writes output_dir/summary.json (across-seed mean / std + best_seed) and output_dir/summary_neurons.pt (per-neuron mean / std / per-seed tensors). Logger-agnostic: this happens regardless of logger_factory. Default None (in-memory results only).
set_seed_strict – Forwarded to set_random_seed(strict=...). Default False.
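Two of the parameter rules above are easy to pin down in code: which seeds run, and where each seed's checkpoint lands when fitter_kwargs contains 'ckpt_path'. This is a minimal sketch using only pathlib. `resolve_seeds` and `seed_ckpt_path` are hypothetical helper names, not deepSTRF functions.

```python
from pathlib import Path
from typing import List, Optional, Sequence, Union


def resolve_seeds(n_seeds: int = 5, seeds: Optional[Sequence[int]] = None) -> List[int]:
    """An explicit seed list wins; otherwise 0, 1, ..., n_seeds - 1."""
    return list(seeds) if seeds is not None else list(range(n_seeds))


def seed_ckpt_path(ckpt_path: Union[str, Path], seed: int) -> Path:
    """<stem>_seed{seed}<suffix> in the same directory, e.g. runs/best.pt -> runs/best_seed3.pt."""
    p = Path(ckpt_path)
    return p.with_name(f"{p.stem}_seed{seed}{p.suffix}")
```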
- Returns:
results – dict with keys:
  - 'seeds' (list[int])
  - 'per_seed_histories' (list[list[dict]]): raw Fitter histories, one list per seed.
  - For each metric m (default cc, cc_norm, loss):
    - 'per_seed_val_<m>': (n_seeds, N) for per-neuron metrics, (n_seeds, 1) for scalars (loss)
    - 'mean_val_<m>': (N,) nanmean across seeds
    - 'std_val_<m>': (N,) population nanstd across seeds
    - the same triple for test.
  - 'best_seed' (int): the seed whose post-fit val monitor metric was best (max or min depending on mode).
  - 'best_state_dict' (OrderedDict[str, Tensor]): deep copy of the best seed’s post-fit model state.
- Return type:
dict
Notes
“Best” uses the same (monitor, mode) pair as the per-seed Fitters (default val_cc_norm / max), reduced to a scalar via nanmean over the neuron axis.
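The across-seed reductions and the best-seed rule can be sketched on an (n_seeds, N) table held as nested lists. This is a minimal sketch, not deepSTRF's implementation. It assumes NaN marks a neuron whose metric is undefined for a seed, and that seeds with an all-NaN monitor are skipped when picking the best. `aggregate_across_seeds` and `pick_best_seed` are hypothetical names.

```python
import math
from statistics import pstdev
from typing import List, Sequence, Tuple


def _finite(values: Sequence[float]) -> List[float]:
    return [v for v in values if not math.isnan(v)]


def aggregate_across_seeds(per_seed: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-neuron nanmean and population nanstd over the seed axis of an (n_seeds, N) table."""
    means, stds = [], []
    for j in range(len(per_seed[0])):
        column = _finite([row[j] for row in per_seed])
        means.append(sum(column) / len(column) if column else math.nan)
        stds.append(pstdev(column) if column else math.nan)  # ddof=0 ("population")
    return means, stds


def pick_best_seed(seeds: Sequence[int], per_seed_monitor: Sequence[Sequence[float]], mode: str = "max") -> int:
    """nanmean each seed's per-neuron monitor values, then take the max (or min)."""
    scored = []
    for seed, row in zip(seeds, per_seed_monitor):
        finite = _finite(row)
        if finite:  # skip seeds whose monitor is all-NaN
            scored.append((sum(finite) / len(finite), seed))
    if not scored:
        raise ValueError("no seed has a finite monitor value")
    choose = max if mode == "max" else min
    return choose(scored, key=lambda pair: pair[0])[1]
```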
deepSTRF.training.config module
deepSTRF.training.auto_config — build wandb.config / mlflow.log_params
dicts from a deepSTRF model + Fitter kwargs.
The point: a wandb run is most useful when every knob that varies
across runs is captured in wandb.config, so the run table can sort /
filter / colour by it. Writing that dict by hand for each experiment is
tedious; this module introspects the attributes that deepSTRF audio-encoding
models share (their dimensions and their prefiltering / core / readout
components) and pulls out the obvious fields.
See docs/_source/md/fitter.md §4.3.
- deepSTRF.training.config.auto_config(model: Module, fitter_kwargs: Mapping[str, Any] | None = None, dataset_name: str | None = None, extra: Mapping[str, Any] | None = None) → dict
Build a JSON-friendly config dict from a deepSTRF model + fitter_kwargs.
- Pulled fields, when present:
  - model: class name of model
  - dataset: dataset_name if given
  - n_frequency_bands: model.F
  - temporal_window_size: model.T
  - out_neurons: model.O
  - prefiltering: class name of model.prefiltering
  - core: class name of model.core
  - readout: class name of model.readout
  - readout.kernel: class name of model.readout.kernel (if present)
  - readout.activation: class name of model.readout.activation (if present)
  - output_activation: class name of model.output_activation (if present)
  - Every JSON-friendly fitter_kwargs entry; non-friendly values are stored as their repr() truncated to 60 chars.
  - Every key/value in extra, overwriting any of the above.
The result is safe to pass as wandb.config (no torch tensors / Modules).
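The field list above can be mirrored by a short introspection function. This is a sketch under assumptions, not deepSTRF's auto_config. "JSON-friendly" is taken to mean "json.dumps accepts it", sub-modules are recorded only when the attribute exists and is not None, and `json_friendly` / `auto_config_sketch` are hypothetical names.

```python
import json
from typing import Any, Dict, Mapping, Optional


def json_friendly(value: Any, max_repr: int = 60) -> Any:
    """Keep values that json.dumps accepts; otherwise store a truncated repr()."""
    try:
        json.dumps(value)
        return value
    except (TypeError, ValueError):
        return repr(value)[:max_repr]


def auto_config_sketch(model: Any, fitter_kwargs: Optional[Mapping[str, Any]] = None,
                       dataset_name: Optional[str] = None,
                       extra: Optional[Mapping[str, Any]] = None) -> Dict[str, Any]:
    cfg: Dict[str, Any] = {"model": type(model).__name__}
    if dataset_name is not None:
        cfg["dataset"] = dataset_name
    # Scalar dimensions stored on the model as F, T, O.
    for key, attr in (("n_frequency_bands", "F"), ("temporal_window_size", "T"), ("out_neurons", "O")):
        if hasattr(model, attr):
            cfg[key] = getattr(model, attr)
    # Sub-module class names, recorded only when the attribute exists and is not None.
    for attr in ("prefiltering", "core", "readout", "output_activation"):
        sub = getattr(model, attr, None)
        if sub is not None:
            cfg[attr] = type(sub).__name__
    readout = getattr(model, "readout", None)
    for attr in ("kernel", "activation"):
        sub = getattr(readout, attr, None)
        if sub is not None:
            cfg[f"readout.{attr}"] = type(sub).__name__
    for key, value in (fitter_kwargs or {}).items():
        cfg[key] = json_friendly(value)
    cfg.update(extra or {})  # extra overrides everything pulled above
    return cfg
```

Because every value is either JSON-serialisable or a short string, the resulting dict can go straight into json.dumps or a run tracker's config.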
- deepSTRF.training.config.slug_from_config(config: Mapping[str, Any], fields: Sequence | None = None) → str
Build a short dash-separated slug like 'linear-ns1-T9-F34'.
- Parameters:
config – A config dict (typically from auto_config()).
fields – Sequence of (key, formatter) pairs, where formatter is a callable value -> str. Missing keys are silently skipped. Default fields: model (lowercased), dataset (lowercased), temporal_window_size (as T{v}), n_frequency_bands (as F{v}).
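The default field list is enough to reproduce the 'linear-ns1-T9-F34' example. This is a minimal sketch of the documented behaviour, not deepSTRF's source. `slug_sketch` and `DEFAULT_SLUG_FIELDS` are hypothetical names.

```python
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple

Formatter = Callable[[Any], str]

DEFAULT_SLUG_FIELDS: Sequence[Tuple[str, Formatter]] = (
    ("model", lambda v: str(v).lower()),
    ("dataset", lambda v: str(v).lower()),
    ("temporal_window_size", lambda v: f"T{v}"),
    ("n_frequency_bands", lambda v: f"F{v}"),
)


def slug_sketch(config: Mapping[str, Any],
                fields: Optional[Sequence[Tuple[str, Formatter]]] = None) -> str:
    """Join formatted values with '-', silently skipping keys absent from config."""
    fields = DEFAULT_SLUG_FIELDS if fields is None else fields
    return "-".join(fmt(config[key]) for key, fmt in fields if key in config)
```

For example, a config with model 'Linear', dataset 'NS1', temporal_window_size 9 and n_frequency_bands 34 yields 'linear-ns1-T9-F34'.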
deepSTRF.training.seed module
Reproducibility seeding for deepSTRF.training.
See docs/_source/md/fitter.md §7.
- deepSTRF.training.seed.set_random_seed(seed: int, *, strict: bool = False) → None
Seed Python random, NumPy, PyTorch CPU, and CUDA.
- Parameters:
seed – Integer seed value, broadcast to all four RNGs.
strict – If True, additionally enables torch.use_deterministic_algorithms(True) (with the CUBLAS_WORKSPACE_CONFIG env var that cuBLAS requires) so that any non-deterministic CUDA op raises rather than silently breaking reproducibility. Trades speed for bit-exactness; mainly useful for debugging. Default False.
Notes
Even with strict=False, cudnn.deterministic is set to True and cudnn.benchmark to False. That is the legacy deepSTRF behaviour, and it is enough for run-to-run reproducibility on the same hardware as long as no opt-in non-deterministic kernel is invoked.
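Here is a minimal sketch of the seeding behaviour described above. It is not deepSTRF's source, and the function name is hypothetical. It seeds Python's random, NumPy and PyTorch when those are importable, always applies the cuDNN flags from the Notes, and in strict mode sets CUBLAS_WORKSPACE_CONFIG before enabling torch.use_deterministic_algorithms. ':4096:8' is one of the two values PyTorch's reproducibility notes suggest.

```python
import os
import random


def set_random_seed_sketch(seed: int, *, strict: bool = False) -> None:
    random.seed(seed)  # Python's built-in RNG
    try:
        import numpy as np
        np.random.seed(seed)  # NumPy's global RNG
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)           # seeds the CPU generator (and, per the docs, all devices)
    torch.cuda.manual_seed_all(seed)  # explicit CUDA seeding; a safe no-op without a GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if strict:
        # cuBLAS reads this before its first use; set it before any CUDA matmul runs.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)  # non-deterministic ops now raise
```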
deepSTRF.training.wandb_log module
Opt-in WandB logger for fit_multi_seed (and any Fitter user).
This module is a thin convenience wrapper. The library does not require
WandB for training — fit_multi_seed only knows about a generic
logger_factory(seed) -> SeedLogger protocol. WandbSeedLogger
implements that protocol with one wandb.init per seed.
Build a factory with make_wandb_logger_factory() and pass it as
logger_factory= to deepSTRF.training.fit_multi_seed().
Examples
>>> from deepSTRF.training import fit_multi_seed
>>> from deepSTRF.training.wandb_log import make_wandb_logger_factory
>>> results = fit_multi_seed(
... model_factory=..., loader_factory=..., n_seeds=3,
... logger_factory=make_wandb_logger_factory(
... project="deepstrf", entity="urancon",
... group="ns1-linear", mode="offline",
... ),
... )
If WandB is not installed, importing this module raises ImportError
with a clear hint; the rest of deepSTRF.training keeps working.
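fit_multi_seed relies only on a duck-typed protocol: a callable invoked once per epoch, plus optional finalize and close. A logger for another backend therefore needs no base class. The sketch below writes JSON lines with the standard library. JsonlSeedLogger and make_jsonl_logger_factory are hypothetical names, not part of deepSTRF. The factory mirrors the documented make_wandb_logger_factory pattern of forwarding kwargs and returning one logger per seed.

```python
import json
from pathlib import Path
from typing import Any, Callable, Mapping


class JsonlSeedLogger:
    """Hypothetical SeedLogger that appends one JSON line per event to <log_dir>/seed<seed>.jsonl."""

    def __init__(self, seed: int, log_dir: str = "jsonl_logs") -> None:
        self.seed = seed
        self.path = Path(log_dir) / f"seed{seed}.jsonl"
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._fh = self.path.open("w", encoding="utf-8")

    def _write(self, record: Mapping[str, Any]) -> None:
        # default=repr turns tensors and other non-JSON values into strings
        self._fh.write(json.dumps(record, default=repr) + "\n")

    def __call__(self, epoch_dict: Mapping[str, Any]) -> None:
        """Used as the Fitter's log_fn: called once per epoch."""
        self._write({"seed": self.seed, **epoch_dict})

    def finalize(self, final_metrics: Mapping[str, Any]) -> None:
        """Called once after fit() with {'val': ..., 'test': ...}."""
        self._write({"seed": self.seed, "final": final_metrics})

    def close(self) -> None:
        self._fh.close()


def make_jsonl_logger_factory(**kwargs: Any) -> Callable[[int], JsonlSeedLogger]:
    """Forward kwargs and build one logger per seed."""
    return lambda seed: JsonlSeedLogger(seed, **kwargs)
```

It would be passed in the same way as the WandB factory: logger_factory=make_jsonl_logger_factory(log_dir="jsonl_logs").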
- class deepSTRF.training.wandb_log.WandbSeedLogger(seed: int, *, project: str | None = None, entity: str | None = None, group: str | None = None, name: str | None = None, dir: str | None = None, mode: str | None = None, config: Mapping[str, Any] | None = None, **wandb_init_kwargs: Any)
Bases: object
One-seed logger backed by a single wandb.Run.
Implements the duck-typed SeedLogger protocol expected by fit_multi_seed():
- __call__(epoch_dict) is invoked once per epoch; per-neuron tensors are summarised as {mean, p10, p50, p90} scalars plus a per-epoch wandb.Histogram.
- finalize(final_metrics) is invoked once at the end of the seed, after Fitter.fit() and the post-fit val + test evaluate passes. It writes test metrics + per-neuron summaries to run.summary, so they show up as sortable run-level columns.
- close() calls run.finish().
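The per-epoch summary that __call__ logs for each per-neuron tensor can be sketched as follows. Several details are assumptions: the quantile method (linear interpolation between ranks, the NumPy / torch.quantile default), dropping NaN cells, and the '<name>/p50' key layout. summarise_per_neuron is a hypothetical helper.

```python
import math
from typing import Dict, Sequence


def quantile_linear(sorted_vals: Sequence[float], q: float) -> float:
    """Quantile of already-sorted data with linear interpolation between ranks."""
    pos = q * (len(sorted_vals) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (pos - lo) * (sorted_vals[hi] - sorted_vals[lo])


def summarise_per_neuron(name: str, values: Sequence[float]) -> Dict[str, float]:
    """Collapse a per-neuron vector into mean / p10 / p50 / p90 scalars, ignoring NaN cells."""
    finite = sorted(v for v in values if not math.isnan(v))
    keys = ("mean", "p10", "p50", "p90")
    if not finite:
        return {f"{name}/{k}": math.nan for k in keys}
    stats = (
        sum(finite) / len(finite),
        quantile_linear(finite, 0.10),
        quantile_linear(finite, 0.50),
        quantile_linear(finite, 0.90),
    )
    return {f"{name}/{k}": v for k, v in zip(keys, stats)}
```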
- Parameters:
seed – The seed value. Added to wandb.config and used to derive the run name.
project, entity, group – Forwarded to wandb.init. The run name is auto-derived as <name>-seed{seed} if name= is set, else <group>-seed{seed} if group= is set, else seed{seed}.
name, dir, mode, config, **wandb_init_kwargs – Forwarded to wandb.init verbatim. WANDB_MODE defaults to offline if the env var is unset and mode= is not given, so logging works with no account / network.
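The run-name and mode rules above amount to two small fallback chains. This is a minimal sketch. derive_run_name and resolve_wandb_mode are hypothetical helper names, and the real logger presumably hands their results to wandb.init.

```python
import os
from typing import Optional


def derive_run_name(seed: int, name: Optional[str] = None, group: Optional[str] = None) -> str:
    """<name>-seed{seed}, else <group>-seed{seed}, else seed{seed}."""
    if name is not None:
        return f"{name}-seed{seed}"
    if group is not None:
        return f"{group}-seed{seed}"
    return f"seed{seed}"


def resolve_wandb_mode(mode: Optional[str] = None) -> str:
    """An explicit mode= wins, then the WANDB_MODE env var, then 'offline'."""
    if mode is not None:
        return mode
    return os.environ.get("WANDB_MODE", "offline")
```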
- deepSTRF.training.wandb_log.make_wandb_logger_factory(**wandb_kwargs: Any) → Callable[[int], WandbSeedLogger]
Build a logger_factory that returns one WandbSeedLogger per seed.
All kwargs are forwarded to WandbSeedLogger. Use as fit_multi_seed(..., logger_factory=make_wandb_logger_factory(project="...", ...)).
deepSTRF.training.tb_log module
Opt-in TensorBoard logger for fit_multi_seed (and any Fitter user).
Same protocol as deepSTRF.training.wandb_log — implements the
SeedLogger duck-typed interface __call__ / finalize /
close against torch.utils.tensorboard.SummaryWriter instead of
wandb.Run.
Use TensorBoard when:
- You want a local-only, no-account, single-process viewer.
- You’re on a machine with no outbound network.
Use WandB when:
- You want a cross-run table view with sortable columns.
- You want cloud-hosted persistence + sharing.
Browse the logs with tensorboard --logdir=<your log_dir>; the URL
defaults to http://localhost:6006.
Examples
>>> from deepSTRF.training import fit_multi_seed
>>> from deepSTRF.training.tb_log import make_tensorboard_logger_factory
>>> results = fit_multi_seed(
... model_factory=..., loader_factory=..., n_seeds=3,
... logger_factory=make_tensorboard_logger_factory(
... log_dir="tb_logs", group="linear-ns1-T9",
... config={"model": "Linear", "T": 9, "F": 34},
... ),
... )
- class deepSTRF.training.tb_log.TensorBoardSeedLogger(seed: int, *, log_dir: str | Path = 'tb_logs', group: str | None = None, name: str | None = None, config: Mapping[str, Any] | None = None)
Bases: object
One-seed logger backed by a torch.utils.tensorboard.SummaryWriter.
Implements the SeedLogger protocol expected by fit_multi_seed(). Each seed writes to {log_dir}/{group or 'default'}/{run_name}/, so a typical tensorboard --logdir <log_dir> invocation surfaces every group × seed combination side by side.
- Parameters:
seed – Seed value. Used in the run name and the hparams/seed scalar.
log_dir – Root directory for event files. Subdirectories per group / seed are created automatically.
group – Optional group name. Determines the subdirectory directly under log_dir ('default' if omitted).
name – Optional user-supplied run name. The final run dir is {name}-seed{seed}; it falls back to {group}-seed{seed}, then seed{seed}.
config – Optional dict of hparams. Each JSON-friendly entry is written as hparams/<key> (numeric values as scalars, others via add_text). The full config is also dumped as a single config text entry.
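The directory layout and the hparams routing described above can be sketched with pathlib and json. Assumptions: "JSON-friendly" means json.dumps accepts the value, bools are logged as text rather than numbers, and tb_run_dir / split_hparams are hypothetical helper names rather than methods of TensorBoardSeedLogger.

```python
import json
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Tuple, Union


def tb_run_dir(log_dir: Union[str, Path], seed: int, group: Optional[str] = None,
               name: Optional[str] = None) -> Path:
    """{log_dir}/{group or 'default'}/{run_name}/ with the documented naming fallbacks."""
    if name is not None:
        run_name = f"{name}-seed{seed}"
    elif group is not None:
        run_name = f"{group}-seed{seed}"
    else:
        run_name = f"seed{seed}"
    return Path(log_dir) / (group or "default") / run_name


def split_hparams(config: Mapping[str, Any]) -> Tuple[Dict[str, float], Dict[str, str]]:
    """Route JSON-friendly entries to hparams/<key>: numbers as scalars, the rest as text."""
    scalars: Dict[str, float] = {}
    texts: Dict[str, str] = {}
    for key, value in config.items():
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            continue  # not JSON-friendly: only appears in the full config dump
        tag = f"hparams/{key}"
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            scalars[tag] = float(value)
        else:
            texts[tag] = str(value)
    return scalars, texts
```

In a real logger, each scalar would then go to SummaryWriter.add_scalar(tag, value, 0) and each text entry to SummaryWriter.add_text(tag, text).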
- deepSTRF.training.tb_log.make_tensorboard_logger_factory(**tb_kwargs: Any) → Callable[[int], TensorBoardSeedLogger]
Build a logger_factory that returns one TensorBoardSeedLogger per seed.
All kwargs are forwarded to TensorBoardSeedLogger.
Module contents
deepSTRF.training — opt-in training utilities.
See docs/_source/md/fitter.md.
The package namespace re-exports Fitter, auto_config, fit_multi_seed, set_random_seed, and slug_from_config (e.g. deepSTRF.training.Fitter); their documentation is identical to the submodule entries above.