Source code for deepSTRF.training.tb_log

"""Opt-in TensorBoard logger for ``fit_multi_seed`` (and any Fitter user).

Same protocol as :mod:`deepSTRF.training.wandb_log` — implements the
``SeedLogger`` duck-typed interface ``__call__`` / ``finalize`` /
``close`` against ``torch.utils.tensorboard.SummaryWriter`` instead of
``wandb.Run``.

Use TensorBoard when:
- You want a local-only, no-account, single-process viewer.
- You're on a machine with no outbound network.

Use WandB when:
- You want a cross-run table view with sortable columns.
- You want cloud-hosted persistence + sharing.

Browse the logs with ``tensorboard --logdir=<your log_dir>``; the URL
defaults to http://localhost:6006.

Examples
--------
>>> from deepSTRF.training import fit_multi_seed
>>> from deepSTRF.training.tb_log import make_tensorboard_logger_factory
>>> results = fit_multi_seed(
...     model_factory=..., loader_factory=..., n_seeds=3,
...     logger_factory=make_tensorboard_logger_factory(
...         log_dir="tb_logs", group="linear-ns1-T9",
...         config={"model": "Linear", "T": 9, "F": 34},
...     ),
... )
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Callable, Mapping, Optional, Union

import torch

from deepSTRF.training.fitter import _to_scalar


# -----------------------------------------------------------------------------
# Per-neuron helpers
# -----------------------------------------------------------------------------


def _log_per_neuron(writer, name: str, t: torch.Tensor, step: int) -> None:
    """Write {mean, p10, p50, p90, hist} summaries for a per-neuron tensor.

    Mean uses ``nanmean`` (so a few NaN cells don't poison the value).
    Quantiles and histograms are computed over non-NaN cells only; when
    every cell is NaN only the (NaN) mean is written.

    Parameters
    ----------
    writer
        Object exposing ``add_scalar(tag, value, step)`` and
        ``add_histogram(tag, values, step)``.
    name
        Base tag. The mean is written under ``name``; the percentiles and
        the histogram under ``name/p10``, ``name/p50``, ``name/p90`` and
        ``name/hist``.
    t
        Tensor of per-neuron values. It is detached and moved to CPU; any
        dtype other than float32 / float64 (integer, bool, float16,
        bfloat16) is promoted to float32 first.
    step
        Global step attached to every entry.
    """
    t = t.detach().cpu()
    # ``torch.nanmean`` rejects integer tensors and ``torch.quantile`` only
    # accepts float32 / float64, so promote everything else up front instead
    # of crashing on half-precision or integer metrics. Complex input is left
    # untouched rather than silently dropping its imaginary part.
    if t.dtype not in (torch.float32, torch.float64) and not t.is_complex():
        t = t.to(torch.float32)

    valid = t[~t.isnan()]
    writer.add_scalar(name, float(torch.nanmean(t).item()), step)
    if valid.numel():
        # A single quantile call handles the data once for all three levels.
        levels = torch.tensor([0.10, 0.50, 0.90], dtype=valid.dtype)
        p10, p50, p90 = torch.quantile(valid, levels).tolist()
        writer.add_scalar(f"{name}/p10", float(p10), step)
        writer.add_scalar(f"{name}/p50", float(p50), step)
        writer.add_scalar(f"{name}/p90", float(p90), step)
        writer.add_histogram(f"{name}/hist", valid, step)


# -----------------------------------------------------------------------------
# TensorBoardSeedLogger
# -----------------------------------------------------------------------------


class TensorBoardSeedLogger:
    """One-seed logger backed by a ``torch.utils.tensorboard.SummaryWriter``.

    Implements the ``SeedLogger`` protocol expected by
    :func:`fit_multi_seed`. Each seed writes to
    ``{log_dir}/{group or 'default'}/{run_name}/`` so a typical
    ``tensorboard --logdir <log_dir>`` invocation surfaces every
    group × seed combination side by side.

    Instances can also be used as context managers; leaving the ``with``
    block calls :meth:`close`, even when the body raised.

    Parameters
    ----------
    seed
        Seed value. Used in the run name and the ``hparams/seed`` scalar.
    log_dir
        Root directory for event files. Subdirectories per group / seed
        are created automatically.
    group
        Optional group name. Determines the second-level subdirectory.
    name
        Optional user-supplied run name. Final run dir is
        ``{name}-seed{seed}``; falls back to ``{group}-seed{seed}`` then
        ``seed{seed}``.
    config
        Optional dict of hparams. Every entry is written as
        ``hparams/<key>``: ``int`` / ``float`` values (excluding ``bool``)
        as scalars, everything else as ``str(value)`` via ``add_text``.
        The full config is also dumped as a single ``config`` text entry.

    Raises
    ------
    ImportError
        If ``torch.utils.tensorboard`` cannot be imported.
    """

    def __init__(
        self,
        seed: int,
        *,
        log_dir: Union[str, Path] = "tb_logs",
        group: Optional[str] = None,
        name: Optional[str] = None,
        config: Optional[Mapping[str, Any]] = None,
    ) -> None:
        # Imported lazily so the module stays importable without tensorboard.
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError as e:
            raise ImportError(
                "TensorBoard support requires `torch.utils.tensorboard` — "
                "install via `pip install tensorboard`."
            ) from e

        if name is not None:
            run_name = f"{name}-seed{seed}"
        elif group is not None:
            run_name = f"{group}-seed{seed}"
        else:
            run_name = f"seed{seed}"

        full_dir = Path(log_dir) / (group if group is not None else "default") / run_name
        full_dir.mkdir(parents=True, exist_ok=True)

        self.seed = seed
        self.run_dir = full_dir
        self.writer = SummaryWriter(log_dir=str(full_dir))

        # Surface the full config as one text dump + per-key scalars/text
        if config:
            self.writer.add_text("config", _format_config(config), 0)
            for k, v in config.items():
                # bool is a subclass of int; keep flags readable as text.
                if isinstance(v, (int, float)) and not isinstance(v, bool):
                    self.writer.add_scalar(f"hparams/{k}", float(v), 0)
                else:
                    self.writer.add_text(f"hparams/{k}", str(v), 0)
        self.writer.add_scalar("hparams/seed", float(seed), 0)

    # Context-manager support ------------------------------------------------

    def __enter__(self) -> "TensorBoardSeedLogger":
        """Return the logger itself so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the writer; returns ``None`` so exceptions propagate."""
        self.close()

    # Internal helpers -------------------------------------------------------

    def _log_value(self, tag: str, value: Any, step: int) -> None:
        """Write one metric under ``tag`` at ``step``.

        Tensors with at least one dimension and more than one element get
        the per-neuron summary (mean / percentiles / histogram); everything
        else is passed through ``_to_scalar`` and written as one scalar.
        """
        if isinstance(value, torch.Tensor) and value.dim() >= 1 and value.numel() > 1:
            _log_per_neuron(self.writer, tag, value, step)
        else:
            self.writer.add_scalar(tag, _to_scalar(value), step)

    # SeedLogger protocol --------------------------------------------------

    def __call__(self, epoch_dict: Mapping[str, Any]) -> None:
        """Log every entry of ``epoch_dict`` at step ``epoch_dict['epoch']``.

        The ``"epoch"`` key only provides the step (0 when absent) and is
        not logged as a metric itself.
        """
        step = int(epoch_dict.get("epoch", 0))
        for k, v in epoch_dict.items():
            if k == "epoch":
                continue
            self._log_value(k, v, step)

    def finalize(self, final_metrics: Mapping[str, Mapping[str, Any]]) -> None:
        """Write final ``{prefix}_{metric}`` scalars + percentile / histogram
        tags at step 0 (TensorBoard's "Summary" doesn't exist as a separate
        concept — we use a fixed step so they appear as a single point on
        the chart).

        Every tag is namespaced as ``final/{prefix}_{metric}``.
        """
        for prefix, metrics in final_metrics.items():
            for k, v in metrics.items():
                self._log_value(f"final/{prefix}_{k}", v, 0)

    def close(self) -> None:
        """Flush pending events and close the underlying writer."""
        self.writer.flush()
        self.writer.close()
def _format_config(config: Mapping[str, Any]) -> str:
    """Pretty-format a config dict as a markdown table for ``add_text``.

    Keys are rendered with ``str`` formatting and values with ``repr``.
    Line breaks inside a cell are collapsed to single spaces and ``|`` is
    backslash-escaped, so a key or value containing either one cannot
    split or truncate its table row.

    Parameters
    ----------
    config
        Mapping of hyper-parameter names to values.

    Returns
    -------
    str
        Markdown table with a ``key`` / ``value`` header and one row per
        entry, joined by newlines.
    """

    def _cell(text: str) -> str:
        # One table row must stay on one line and contain only the three
        # structural pipes.
        return " ".join(text.splitlines()).replace("|", "\\|")

    lines = ["| key | value |", "| --- | --- |"]
    lines.extend(f"| {_cell(f'{k}')} | {_cell(repr(v))} |" for k, v in config.items())
    return "\n".join(lines)


# -----------------------------------------------------------------------------
# Factory helper
# -----------------------------------------------------------------------------
def make_tensorboard_logger_factory(
    **tb_kwargs: Any,
) -> Callable[[int], TensorBoardSeedLogger]:
    """Build a ``logger_factory`` that returns one ``TensorBoardSeedLogger`` per seed.

    All kwargs are forwarded to :class:`TensorBoardSeedLogger`.

    Parameters
    ----------
    **tb_kwargs
        Keyword arguments passed unchanged to every
        :class:`TensorBoardSeedLogger` the factory creates. ``seed`` is not
        allowed here because the factory supplies it on each call.

    Returns
    -------
    Callable[[int], TensorBoardSeedLogger]
        Function mapping a seed to a new logger for that seed.

    Raises
    ------
    TypeError
        If ``seed`` is passed as a keyword argument.
    """
    # Fail at construction time: otherwise the clash only surfaces when the
    # factory is first called with a seed.
    if "seed" in tb_kwargs:
        raise TypeError(
            "make_tensorboard_logger_factory() got an unexpected keyword "
            "argument 'seed'; the seed is supplied when the factory is called"
        )

    def _factory(seed: int) -> TensorBoardSeedLogger:
        return TensorBoardSeedLogger(seed=seed, **tb_kwargs)

    return _factory