Source code for deepSTRF.training.wandb_log

"""Opt-in WandB logger for ``fit_multi_seed`` (and any Fitter user).

This module is a thin convenience wrapper. The library does not require
WandB for training — ``fit_multi_seed`` only knows about a generic
``logger_factory(seed) -> SeedLogger`` protocol. ``WandbSeedLogger``
implements that protocol with one ``wandb.init`` per seed.

Build a factory with :func:`make_wandb_logger_factory` and pass it as
``logger_factory=`` to :func:`deepSTRF.training.fit_multi_seed`.

Examples
--------
>>> from deepSTRF.training import fit_multi_seed
>>> from deepSTRF.training.wandb_log import make_wandb_logger_factory
>>> results = fit_multi_seed(
...     model_factory=..., loader_factory=..., n_seeds=3,
...     logger_factory=make_wandb_logger_factory(
...         project="deepstrf", entity="urancon",
...         group="ns1-linear", mode="offline",
...     ),
... )

If WandB is not installed, constructing a ``WandbSeedLogger`` raises
``ImportError`` with a clear hint; importing this module and the rest of
``deepSTRF.training`` keeps working.
"""

from __future__ import annotations

import os
from typing import Any, Callable, Mapping, Optional

import torch

from deepSTRF.training.fitter import _to_scalar


# -----------------------------------------------------------------------------
# Per-neuron summarisation
# -----------------------------------------------------------------------------


def _per_neuron_scalars(name: str, t: torch.Tensor) -> dict:
    """Reduce a per-neuron tensor to a {mean, p10, p50, p90} scalar dict.

    Parameters
    ----------
    name
        Base key. The result is keyed ``{name}`` (mean), ``{name}/p10``,
        ``{name}/p50`` and ``{name}/p90``.
    t
        Tensor of per-neuron values. It is detached and moved to CPU; the
        NaN mask flattens it, so quantiles are taken over all cells.

    Returns
    -------
    dict
        Four Python floats. NaN cells are dropped before the quantiles and
        ``nanmean`` is used for the mean; every entry is ``nan`` when no
        non-NaN cell remains.

    Notes
    -----
    ``torch.nanmean`` rejects integer/bool inputs and ``torch.quantile`` only
    accepts float32/float64, so any other real dtype (int, bool, float16,
    bfloat16) is promoted to float32 first instead of raising mid-training.
    """
    t = t.detach().cpu()
    if t.dtype not in (torch.float32, torch.float64) and not t.is_complex():
        t = t.to(torch.float32)

    valid = t[~t.isnan()]
    has_valid = valid.numel() > 0

    out: dict = {name: float(torch.nanmean(t).item())}
    for label, q in (("p10", 0.10), ("p50", 0.50), ("p90", 0.90)):
        if has_valid:
            out[f"{name}/{label}"] = float(torch.quantile(valid, q).item())
        else:
            # Placeholder keeps the key set stable across epochs.
            out[f"{name}/{label}"] = float("nan")
    return out


def _make_histogram(t: torch.Tensor):
    """Return a ``wandb.Histogram`` of the finite cells, or ``None`` if empty.

    Parameters
    ----------
    t
        Tensor of values; it is detached and moved to CPU before filtering.

    Returns
    -------
    wandb.Histogram or None
        ``None`` when no finite cell remains, so callers can skip the key.

    Notes
    -----
    - NaN *and* +/-inf cells are dropped: a single infinite value makes the
      histogram range non-finite and the NumPy binning step raise.
    - bfloat16 has no NumPy equivalent, so it is widened to float32 before
      the ``.numpy()`` conversion.
    - ``wandb`` is imported only once there is something to plot.
    """
    values = t.detach().cpu()
    if values.dtype == torch.bfloat16:
        values = values.to(torch.float32)
    values = values[torch.isfinite(values)]
    if values.numel() == 0:
        return None

    import wandb

    return wandb.Histogram(values.numpy())


# -----------------------------------------------------------------------------
# WandbSeedLogger — one wandb run per seed
# -----------------------------------------------------------------------------


class WandbSeedLogger:
    """One-seed logger backed by a single ``wandb.Run``.

    Implements the duck-typed ``SeedLogger`` protocol expected by
    :func:`fit_multi_seed`:

    - ``__call__(epoch_dict)`` is invoked once per epoch; per-neuron tensors
      are summarised as ``{mean, p10, p50, p90}`` scalars plus a per-epoch
      ``wandb.Histogram``.
    - ``finalize(final_metrics)`` is invoked once at end of seed, after
      ``Fitter.fit()`` and the post-fit val + test ``evaluate`` passes. It
      writes test metrics + per-neuron summaries to ``run.summary`` (so they
      show up as sortable run-level columns).
    - ``close()`` calls ``run.finish()``.

    Parameters
    ----------
    seed
        The seed value. Added to ``wandb.config`` and used to derive the run
        name.
    project, entity, group
        Forwarded to ``wandb.init``.
    name
        Run-name stem. The run name is ``<name>-seed{seed}`` if ``name=`` is
        set, else ``<group>-seed{seed}`` if ``group=`` is set, else
        ``seed{seed}``.
    dir, mode
        Forwarded to ``wandb.init`` only when not ``None``. If ``mode`` is
        ``None``, ``WANDB_MODE`` is set to ``offline`` unless the env var is
        already defined, so logging works with no account / network.
    config
        Copied, extended with ``"seed"``, and passed as the run config.
    **wandb_init_kwargs
        Forwarded to ``wandb.init`` verbatim. ``reinit`` defaults to ``True``
        but may be overridden here.

    Raises
    ------
    ImportError
        If ``wandb`` is not installed.
    """

    def __init__(
        self,
        seed: int,
        *,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        dir: Optional[str] = None,
        mode: Optional[str] = None,
        config: Optional[Mapping[str, Any]] = None,
        **wandb_init_kwargs: Any,
    ) -> None:
        try:
            import wandb
        except ImportError as e:
            raise ImportError(
                "wandb is required for WandbSeedLogger; "
                "`pip install wandb`."
            ) from e

        if mode is None:
            # Respect an existing WANDB_MODE; only fill in the offline default.
            os.environ.setdefault("WANDB_MODE", "offline")

        cfg = dict(config or {})
        cfg["seed"] = seed

        init_kwargs = dict(wandb_init_kwargs)
        if mode is not None:
            init_kwargs["mode"] = mode
        if dir is not None:
            init_kwargs["dir"] = dir
        # A default rather than a hard-coded keyword: passing ``reinit=`` both
        # explicitly and through ``**init_kwargs`` would raise TypeError.
        init_kwargs.setdefault("reinit", True)

        self.seed = seed
        self.run = wandb.init(
            project=project,
            entity=entity,
            group=group,
            name=self._derive_run_name(seed, name, group),
            config=cfg,
            **init_kwargs,
        )

    # Private helpers --------------------------------------------------------

    @staticmethod
    def _derive_run_name(seed: int, name: Optional[str], group: Optional[str]) -> str:
        """Return the run name: ``name`` stem, else ``group`` stem, else bare seed."""
        if name is not None:
            return f"{name}-seed{seed}"
        if group is not None:
            return f"{group}-seed{seed}"
        return f"seed{seed}"

    @staticmethod
    def _add_metric(out: dict, key: str, value: Any) -> None:
        """Store one metric in ``out``: tensor summaries + histogram, or a scalar."""
        is_multi_cell = (
            isinstance(value, torch.Tensor)
            and value.dim() >= 1
            and value.numel() > 1
        )
        if not is_multi_cell:
            out[key] = _to_scalar(value)
            return
        out.update(_per_neuron_scalars(key, value))
        hist = _make_histogram(value)
        if hist is not None:
            out[f"{key}/hist"] = hist

    # SeedLogger protocol ----------------------------------------------------

    def __call__(self, epoch_dict: Mapping[str, Any]) -> None:
        """Log one epoch's metrics to the run.

        ``epoch_dict["epoch"]`` (default ``0``) is used as the wandb step and
        is not logged as a metric itself; every other entry is logged under
        its own key.
        """
        step = int(epoch_dict.get("epoch", 0))
        flat: dict = {}
        for k, v in epoch_dict.items():
            if k == "epoch":
                continue
            self._add_metric(flat, k, v)
        self.run.log(flat, step=step)

    def finalize(self, final_metrics: Mapping[str, Mapping[str, Any]]) -> None:
        """Write ``{prefix}_{metric}`` summary rows + percentile scalars.

        ``final_metrics`` keys are split prefixes (e.g. ``"val"``, ``"test"``)
        each holding a dict of metric_name -> scalar-or-per-neuron-tensor.
        """
        summary: dict = {}
        for prefix, metrics in final_metrics.items():
            for k, v in metrics.items():
                self._add_metric(summary, f"{prefix}_{k}", v)
        self.run.summary.update(summary)

    def close(self) -> None:
        """Finish the underlying wandb run."""
        self.run.finish()
# -----------------------------------------------------------------------------
# Factory helper
# -----------------------------------------------------------------------------
def make_wandb_logger_factory(**wandb_kwargs: Any) -> Callable[[int], WandbSeedLogger]:
    """Build a ``logger_factory`` that returns one ``WandbSeedLogger`` per seed.

    Every keyword argument is captured once and forwarded unchanged to
    :class:`WandbSeedLogger` each time the returned callable is invoked with
    a seed. Use as
    ``fit_multi_seed(..., logger_factory=make_wandb_logger_factory(project="...", ...))``.
    """

    def build_logger(seed: int) -> WandbSeedLogger:
        # One fresh logger (and therefore one wandb run) per call.
        logger = WandbSeedLogger(seed=seed, **wandb_kwargs)
        return logger

    return build_logger