# Source code for deepSTRF.training.multi_seed

"""``deepSTRF.training.fit_multi_seed`` — multi-seed init-variance training.

See ``docs/_source/md/fitter.md`` §4.2 for the full design.

This wrapper runs the same Fitter configuration K times under different
seeds and aggregates per-neuron val and test metrics across seeds. It
addresses *initialization variance* — the same data split, the same
hyperparameters, but different initial weights and shuffle order. It is
NOT k-fold or leave-one-stim-out cross-validation; those are separate
TODOs that need a split-factory rather than a seed sweep.
"""

from __future__ import annotations

import copy
import json
import math
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from deepSTRF.training.fitter import Fitter, _to_scalar
from deepSTRF.training.seed import set_random_seed


# The ``(train_loader, val_loader, test_loader)`` 3-tuple that a
# ``loader_factory`` must return for each seed (unpacked in that order by
# ``fit_multi_seed``).
LoaderTriple = Tuple[DataLoader, DataLoader, DataLoader]


# -----------------------------------------------------------------------------
# Aggregation helpers
# -----------------------------------------------------------------------------


def _nanstd(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """NaN-aware population standard deviation (ddof = 0) of ``x`` over ``dim``.

    NaN entries are excluded both from the mean and from the mean squared
    deviation. ``dim`` is reduced away in the result; a slice that is
    entirely NaN yields NaN.
    """
    centred = x - torch.nanmean(x, dim=dim, keepdim=True)
    variance = torch.nanmean(centred.pow(2), dim=dim)
    return variance.sqrt()


def _summarise_tensor(t: torch.Tensor) -> Dict[str, float]:
    """Return ``{mean, n_valid, p10, p50, p90}`` for a 1-d per-neuron tensor.

    ``mean`` ignores NaN entries; NaN entries are dropped before the
    quantiles are computed. Tensors that are not float32/float64 (integer
    counts, boolean masks, half precision) are cast to float32 first, because
    ``torch.nanmean`` rejects non-floating dtypes and ``torch.quantile``
    accepts only float32/float64. If no entry is valid, ``mean`` and all
    quantiles are NaN and ``n_valid`` is 0. All values are Python
    ``float``/``int``, so the dict is JSON-serializable.
    """
    t = t.detach().cpu()
    if t.dtype not in (torch.float32, torch.float64):
        t = t.float()
    valid = t[~t.isnan()]
    out: Dict[str, float] = {
        "mean": float(torch.nanmean(t).item()),
        "n_valid": int(valid.numel()),
    }
    for name, q in (("p10", 0.10), ("p50", 0.50), ("p90", 0.90)):
        # torch.quantile raises on an empty tensor, so report NaN instead.
        out[name] = float(torch.quantile(valid, q).item()) if valid.numel() else float("nan")
    return out


def _jsonify(value: Any) -> Any:
    """Convert torch tensors in a value tree to JSON-friendly summaries.

    - Tensors with a single element (0-d or length-1) become Python floats.
    - 1-d per-neuron tensors become ``_summarise_tensor`` dicts.
    - Higher-rank tensors become ``{"mean": nanmean, "shape": [...]}``;
      non-floating dtypes are cast to float first so that ``nanmean`` does
      not raise.
    - Dicts are walked recursively, keeping their keys. Lists and tuples are
      walked recursively and both come back as lists.
    - Everything else passes through unchanged.
    """
    if isinstance(value, torch.Tensor):
        if value.dim() == 0 or value.numel() == 1:
            return float(value.detach().item())
        if value.dim() == 1:
            return _summarise_tensor(value)
        # Higher-rank tensors are unexpected in the epoch dict, but odd custom
        # metrics must not crash the dump. torch.nanmean only accepts
        # floating dtypes, so integer/bool tensors are cast first.
        flat = value.detach()
        if not flat.is_floating_point():
            flat = flat.float()
        return {"mean": float(torch.nanmean(flat).item()),
                "shape": list(value.shape)}
    if isinstance(value, dict):
        return {k: _jsonify(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_jsonify(v) for v in value]
    return value


def _save_seed_outputs(
    seed_dir: Path,
    history: List[Dict[str, Any]],
    val_post: Mapping[str, Any],
    test_post: Mapping[str, Any],
    state_dict: Mapping[str, torch.Tensor],
) -> None:
    """Persist one seed's artefacts under ``seed_dir``, creating it if needed.

    Files written, in this order:

    - ``history.json``: each per-epoch dict passed through ``_jsonify``.
    - ``final.json``: ``{"val": ..., "test": ...}``, the post-fit metric
      dicts passed through ``_jsonify``.
    - ``final_neurons.pt``: CPU copies of every tensor-valued metric with
      ``dim >= 1`` and more than one element, grouped under ``"val"`` and
      ``"test"``.
    - ``best.pt``: ``state_dict`` saved as given. No copy is made here; the
      caller is expected to pass a detached copy.
    """
    seed_dir.mkdir(parents=True, exist_ok=True)

    # Each payload is built lazily, after its file is opened, mirroring the
    # write-then-serialise order of a plain ``with open(...)`` block.
    json_payloads = (
        ("history.json", lambda: [_jsonify(epoch) for epoch in history]),
        ("final.json", lambda: {"val": _jsonify(dict(val_post)),
                                "test": _jsonify(dict(test_post))}),
    )
    for filename, build in json_payloads:
        with open(seed_dir / filename, "w") as fh:
            json.dump(build(), fh, indent=2)

    def _multi_element_tensors(metrics: Mapping[str, Any]) -> Dict[str, torch.Tensor]:
        # Keep only genuine per-neuron tensors; scalars live in final.json.
        return {
            key: val.detach().cpu()
            for key, val in metrics.items()
            if isinstance(val, torch.Tensor) and val.dim() >= 1 and val.numel() > 1
        }

    neuron_payload = {
        "val": _multi_element_tensors(val_post),
        "test": _multi_element_tensors(test_post),
    }
    torch.save(neuron_payload, seed_dir / "final_neurons.pt")
    torch.save(state_dict, seed_dir / "best.pt")


def _save_summary(
    output_dir: Path,
    results: Mapping[str, Any],
    monitor: str,
    mode: str,
) -> None:
    """Write the across-seed summary files into ``output_dir``.

    ``summary.json`` holds ``seeds``, ``best_seed``, ``monitor`` and ``mode``,
    plus one entry for every tensor-valued ``mean_*`` / ``std_*`` /
    ``per_seed_val_*`` / ``per_seed_test_*`` key of ``results``:

    - 1-d tensors with more than one element -> ``_summarise_tensor`` dict;
    - 0-d tensors and 1-d tensors with at most one element -> their
      ``nanmean`` as a float (NaN when empty);
    - tensors of rank >= 2 (e.g. ``(n_seeds, N)`` per-seed stacks) ->
      ``{"shape": [...], "mean": float}``.

    ``summary_neurons.pt`` stores CPU copies of the multi-element 1-d
    tensors and of the rank >= 2 tensors, keyed like ``results``. All other
    keys (e.g. ``per_seed_histories``, ``best_state_dict``) are ignored.
    """
    tracked_prefixes = ("mean_", "std_", "per_seed_val_", "per_seed_test_")
    summary: Dict[str, Any] = {
        "seeds": list(results["seeds"]),
        "best_seed": int(results["best_seed"]),
        "monitor": monitor,
        "mode": mode,
    }
    tensors_out: Dict[str, torch.Tensor] = {}

    for key, value in results.items():
        if not key.startswith(tracked_prefixes) or not isinstance(value, torch.Tensor):
            continue
        rank = value.dim()
        if rank >= 2:
            # per_seed_*: (n_seeds, N) stacks
            summary[key] = {
                "shape": list(value.shape),
                "mean": float(torch.nanmean(value).item()),
            }
            tensors_out[key] = value.detach().cpu()
        elif rank == 1 and value.numel() > 1:
            summary[key] = _summarise_tensor(value)
            tensors_out[key] = value.detach().cpu()
        elif value.numel() > 0:
            summary[key] = float(torch.nanmean(value).item())
        else:
            summary[key] = float("nan")

    with open(output_dir / "summary.json", "w") as fh:
        json.dump(summary, fh, indent=2)
    torch.save(tensors_out, output_dir / "summary_neurons.pt")


def _stack_metric(values: Sequence[Any]) -> torch.Tensor:
    """Stack per-seed metric values into one float32 ``(n_seeds, ...)`` tensor.

    Each entry is either a tensor (per-neuron ``(N,)`` metric, or a 0-d
    already-reduced metric such as ``loss``) or a Python scalar. Tensors are
    detached, moved to CPU and cast to float32; 0-d tensors and Python
    scalars become length-1 1-d tensors. Scalars are created explicitly as
    float32, so the result dtype does not depend on
    ``torch.get_default_dtype()``.

    Raises
    ------
    ValueError
        If ``values`` is empty, or if the entries do not all share one shape.
    """
    rows: List[torch.Tensor] = []
    for v in values:
        if isinstance(v, torch.Tensor):
            t = v.detach().cpu().float()
            if t.dim() == 0:
                t = t.unsqueeze(0)
        else:
            # Match the tensor branch's float32 regardless of default dtype.
            t = torch.tensor([float(v)], dtype=torch.float32)
        rows.append(t)
    if not rows:
        raise ValueError("cannot stack a metric with no per-seed values")
    shapes = {tuple(t.shape) for t in rows}
    if len(shapes) > 1:
        raise ValueError(
            f"metric tensors across seeds have inconsistent shapes: {shapes}"
        )
    return torch.stack(rows, dim=0)


# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------


def _noop_log(_epoch: Mapping[str, Any]) -> None:
    """Silent default ``log_fn`` used when neither a seed logger nor a user ``log_fn`` is given."""


def _resolve_log_fn(
    logger: Any,
    user_log_fn: Optional[Callable[[Mapping[str, Any]], None]],
) -> Callable[[Mapping[str, Any]], None]:
    """Pick the per-seed ``log_fn``: the seed logger, else the user's ``log_fn``, else a no-op."""
    if logger is not None:
        return logger
    if user_log_fn is not None:
        return user_log_fn
    return _noop_log


def _rank_score(value: Any, mode: str) -> float:
    """Map a monitor value to a "higher is better" score.

    ``mode == "max"`` keeps the value; any other mode negates it (lower is
    better). NaN maps to ``-inf``, so a diverged seed ranks below every
    finite seed instead of being picked as best.
    """
    value = float(value)
    if math.isnan(value):
        return -math.inf
    return value if mode == "max" else -value


def _aggregate_seed_metrics(
    val_per_seed: Sequence[Mapping[str, Any]],
    test_per_seed: Sequence[Mapping[str, Any]],
) -> Dict[str, torch.Tensor]:
    """Build ``per_seed_*`` / ``mean_*`` / ``std_*`` entries for every metric of the first seed's val dict."""
    out: Dict[str, torch.Tensor] = {}
    for name in val_per_seed[0]:
        for prefix, per_seed in (("val", val_per_seed), ("test", test_per_seed)):
            stacked = _stack_metric([p[name] for p in per_seed])
            out[f"per_seed_{prefix}_{name}"] = stacked
            out[f"mean_{prefix}_{name}"] = torch.nanmean(stacked, dim=0)
            out[f"std_{prefix}_{name}"] = _nanstd(stacked, dim=0)
    return out


def fit_multi_seed(
    model_factory: Callable[[int], nn.Module],
    loader_factory: Callable[[int], LoaderTriple],
    n_seeds: int = 5,
    *,
    seeds: Optional[Sequence[int]] = None,
    fitter_kwargs: Optional[Mapping[str, Any]] = None,
    logger_factory: Optional[Callable[[int], Any]] = None,
    output_dir: Optional[Union[str, Path]] = None,
    set_seed_strict: bool = False,
) -> Dict[str, Any]:
    """Run the same Fitter configuration ``n_seeds`` times under different seeds.

    For each seed: call ``set_random_seed(seed)``, instantiate model and
    ``(train, val, test)`` loaders via the factories, fit a Fitter with
    ``fitter_kwargs``, then re-evaluate post-restore on val and test.

    This is **multi-seed init variance** — same split, different init. NOT
    k-fold CV (that needs a split-factory, separate TODO).

    Parameters
    ----------
    model_factory
        Callable ``seed -> nn.Module``. Called fresh each seed, after
        ``set_random_seed(seed)``, so weight init draws from the seeded RNG.
    loader_factory
        Callable ``seed -> (train_loader, val_loader, test_loader)``.
        Required 3-tuple. Called fresh each seed so each run gets a
        deterministic shuffle generator.
    n_seeds
        Number of seeds to run. Ignored if ``seeds`` is given. Default 5.
    seeds
        Explicit seed list. Overrides ``n_seeds``. Default
        ``[0, 1, ..., n_seeds-1]``.
    fitter_kwargs
        Forwarded to ``Fitter(...)``. May not include ``'model'``,
        ``'train_loader'``, or ``'val_loader'``. If ``'ckpt_path'`` is set,
        each seed's checkpoint is saved to ``<stem>_seed{seed}<suffix>`` so
        seeds don't overwrite each other.
    logger_factory
        Optional ``Callable[[int], SeedLogger]``. If given, the returned
        object is used as the Fitter's ``log_fn`` for that seed (so it is
        invoked once per epoch with the epoch dict). After ``fit()``, if the
        logger exposes ``finalize(final_metrics)``, it is called with
        ``{'val': val_post, 'test': test_post}``. ``close()`` is called at
        the end of the seed if present, even when fitting or evaluation
        raises. See :class:`deepSTRF.training.wandb_log.WandbSeedLogger`
        for the reference implementation. Default ``None`` (silent
        multi-seed sweep unless ``fitter_kwargs['log_fn']`` is set).
    output_dir
        If given, write per-seed ``history.json`` (JSON-summarised epoch
        log), ``final.json`` (post-fit val + test population summaries),
        ``final_neurons.pt`` (per-neuron tensors for downstream analysis),
        and ``best.pt`` (state_dict) under ``output_dir/seed{seed}/``. Also
        writes ``output_dir/summary.json`` (across-seed mean / std +
        best_seed) and ``output_dir/summary_neurons.pt`` (per-neuron mean /
        std / per-seed tensors). Logger-agnostic: this happens regardless of
        ``logger_factory``. Default ``None`` (in-memory results only).
    set_seed_strict
        Forwarded to ``set_random_seed(strict=...)``. Default ``False``.

    Returns
    -------
    results : dict
        Keys:

        - ``'seeds'`` : ``list[int]``
        - ``'per_seed_histories'`` : ``list[list[dict]]`` — raw Fitter
          histories, one list per seed.
        - For each metric ``m`` (default ``cc``, ``cc_norm``, ``loss``):

          - ``'per_seed_val_<m>'`` : ``(n_seeds, N)`` for per-neuron,
            ``(n_seeds, 1)`` for scalar (``loss``)
          - ``'mean_val_<m>'`` : ``(N,)`` nanmean across seeds
          - ``'std_val_<m>'`` : ``(N,)`` population nanstd across seeds
          - same triple for ``test``.

        - ``'best_seed'`` : ``int`` — seed whose post-fit val ``monitor``
          metric was best (max or min depending on ``mode``).
        - ``'best_state_dict'`` : ``OrderedDict[str, Tensor]`` — deep copy
          of the best seed's post-fit model state.

    Raises
    ------
    ValueError
        If no seeds are requested, or ``fitter_kwargs`` contains a key this
        function manages (``model``, ``train_loader``, ``val_loader``).
    KeyError
        If a seed's post-fit val metrics lack the monitor key. This is
        checked as each seed finishes, so a misconfigured ``monitor`` fails
        after the first seed instead of after the whole sweep.

    Notes
    -----
    "Best" uses the same ``(monitor, mode)`` pair as the per-seed Fitters
    (default ``val_cc_norm`` / ``max``), reduced to a scalar via ``nanmean``
    over the neuron axis. A seed whose reduced monitor is NaN (e.g. a
    diverged run) ranks below every finite seed; ties go to the earliest
    seed. Only the current best seed's state_dict is held in memory during
    the sweep.
    """
    seeds = list(range(n_seeds)) if seeds is None else list(seeds)
    n_seeds = len(seeds)
    if n_seeds < 1:
        raise ValueError("n_seeds must be >= 1")

    fitter_kwargs = dict(fitter_kwargs or {})
    for key in ("model", "train_loader", "val_loader"):
        if key in fitter_kwargs:
            raise ValueError(
                f"fitter_kwargs[{key!r}] is managed by fit_multi_seed; remove it"
            )
    # These two are rewritten per seed, so pull them out of the shared kwargs.
    user_log_fn = fitter_kwargs.pop("log_fn", None)
    base_ckpt = fitter_kwargs.pop("ckpt_path", None)

    output_path = Path(output_dir) if output_dir is not None else None
    if output_path is not None:
        output_path.mkdir(parents=True, exist_ok=True)

    monitor = fitter_kwargs.get("monitor", "val_cc_norm")
    mode = fitter_kwargs.get("mode", "max")
    # Post-fit metric dicts from Fitter.evaluate() are looked up without the
    # "val_" prefix that the Fitter's monitor name carries.
    monitor_key = monitor.removeprefix("val_")

    histories: List[List[Dict[str, Any]]] = []
    val_per_seed: List[Dict[str, Any]] = []
    test_per_seed: List[Dict[str, Any]] = []
    best_idx: Optional[int] = None
    best_score = -math.inf
    best_state_dict: Optional[Dict[str, torch.Tensor]] = None

    for idx, seed in enumerate(seeds):
        # Seed first: model init and loader shuffles must draw from this RNG.
        set_random_seed(seed, strict=set_seed_strict)
        model = model_factory(seed)
        train_loader, val_loader, test_loader = loader_factory(seed)
        logger = logger_factory(seed) if logger_factory is not None else None

        fk = dict(fitter_kwargs)
        fk["log_fn"] = _resolve_log_fn(logger, user_log_fn)
        if base_ckpt is not None:
            p = Path(base_ckpt)
            fk["ckpt_path"] = p.with_name(f"{p.stem}_seed{seed}{p.suffix}")

        try:
            fitter = Fitter(model, train_loader, val_loader, **fk)
            history = fitter.fit()
            val_post = fitter.evaluate(val_loader)
            test_post = fitter.evaluate(test_loader)
            if logger is not None and hasattr(logger, "finalize"):
                logger.finalize({"val": val_post, "test": test_post})
        finally:
            # Release logger resources (e.g. an open remote run) even on failure.
            if logger is not None and hasattr(logger, "close"):
                logger.close()

        histories.append(history)
        val_per_seed.append(val_post)
        test_per_seed.append(test_post)
        sd = copy.deepcopy(fitter.model.state_dict())
        if output_path is not None:
            _save_seed_outputs(
                output_path / f"seed{seed}",
                history=history,
                val_post=val_post,
                test_post=test_post,
                state_dict=sd,
            )

        if monitor_key not in val_post:
            raise KeyError(
                f"monitor {monitor!r} (key {monitor_key!r}) not produced by val "
                f"metrics; available: {sorted(val_post.keys())}"
            )
        # Running best-of-N with strict improvement: same result as argmax
        # (first index wins ties), but only one state_dict copy is retained.
        score = _rank_score(_to_scalar(val_post[monitor_key]), mode)
        if best_idx is None or score > best_score:
            best_idx, best_score, best_state_dict = idx, score, sd

    results: Dict[str, Any] = {
        "seeds": list(seeds),
        "per_seed_histories": histories,
    }
    results.update(_aggregate_seed_metrics(val_per_seed, test_per_seed))
    results["best_seed"] = seeds[best_idx]
    results["best_state_dict"] = best_state_dict

    if output_path is not None:
        _save_summary(output_path, results, monitor=monitor, mode=mode)
    return results