Source code for deepSTRF.training.config

"""``deepSTRF.training.config`` — build wandb.config / mlflow.log_params
dicts from a deepSTRF model + Fitter kwargs.

The point: a wandb run is most useful when *every* knob that varies
across runs is captured in ``wandb.config`` so the run table can sort /
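filter / colour by it. Writing that dict by hand for each experiment is
tedious and error-prone; this module introspects the common deepSTRF
model layout (``F`` / ``T`` / ``O`` attributes and the ``prefiltering`` /
``core`` / ``readout`` sub-modules) and pulls out the obvious fields.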

See ``docs/_source/md/fitter.md`` §4.3.
"""

from __future__ import annotations

import numbers
from typing import Any, Mapping, Optional, Sequence

import torch.nn as nn


_JSON_FRIENDLY = (int, float, str, bool, type(None))

# Sentinel returned by ``_json_scalar`` when a value cannot be reduced to a
# plain JSON scalar. ``None`` is itself a valid scalar, so it can't be used.
_NOT_SCALAR = object()


def _json_scalar(value: Any) -> Any:
    """Return ``value`` as a plain JSON scalar, or ``_NOT_SCALAR`` if impossible.

    Values that are already ``int``/``float``/``str``/``bool``/``None`` are
    returned unchanged. Other numbers registered with the :mod:`numbers`
    ABCs (e.g. ``numpy.int64``, ``fractions.Fraction``) are converted to a
    builtin ``int`` / ``float``. Experiment trackers can then sort and filter
    on them instead of seeing an opaque ``repr`` string.
    """
    if isinstance(value, _JSON_FRIENDLY):
        return value
    if isinstance(value, numbers.Integral):
        return int(value)
    if isinstance(value, numbers.Real):
        return float(value)
    return _NOT_SCALAR


def auto_config(
    model: nn.Module,
    fitter_kwargs: Optional[Mapping[str, Any]] = None,
    dataset_name: Optional[str] = None,
    extra: Optional[Mapping[str, Any]] = None,
    *,
    max_repr_len: Optional[int] = 60,
) -> dict:
    """Build a JSON-friendly config dict from a deepSTRF model + fitter_kwargs.

    Pulled fields, when present:

    - ``model``               : class name of ``model``
    - ``dataset``             : ``dataset_name`` if given
    - ``n_frequency_bands``   : ``model.F``
    - ``temporal_window_size``: ``model.T``
    - ``out_neurons``         : ``model.O``
    - ``prefiltering``        : class name of ``model.prefiltering``
    - ``core``                : class name of ``model.core``
    - ``readout``             : class name of ``model.readout``
    - ``readout.kernel``      : class name of ``model.readout.kernel`` (if present)
    - ``readout.activation``  : class name of ``model.readout.activation`` (if present)
    - ``output_activation``   : class name of ``model.output_activation`` (if present)
    - Every ``fitter_kwargs`` entry. JSON scalars are kept as-is. Other real
      numbers (e.g. ``numpy.int64``) are converted to builtin ``int`` /
      ``float``. Anything else is stored as its ``repr()`` truncated to
      ``max_repr_len`` characters.
    - Every key/value in ``extra``, overwriting any of the above.

    Sub-module entries are recorded only when the attribute is an
    ``nn.Module``. Shape entries (``F`` / ``T`` / ``O``) are recorded only
    when the value reduces to a JSON scalar.

    Parameters
    ----------
    model
        The model to introspect.
    fitter_kwargs
        Keyword arguments given to the Fitter; may be ``None``.
    dataset_name
        Optional dataset label stored under ``"dataset"``.
    extra
        Entries merged in last. They are copied verbatim (not sanitized), so
        the caller is responsible for keeping them JSON-friendly.
    max_repr_len
        Maximum length of the ``repr()`` fallback used for non-scalar
        ``fitter_kwargs`` values (default 60). ``None`` disables truncation.

    Returns
    -------
    dict
        The config dict. It is safe to pass as ``wandb.config`` (no torch
        tensors / Modules) provided ``extra`` contains none either.

    Raises
    ------
    ValueError
        If ``max_repr_len`` is negative.
    """
    if max_repr_len is not None and max_repr_len < 0:
        raise ValueError(f"max_repr_len must be >= 0 or None, got {max_repr_len}")

    cfg: dict = {"model": type(model).__name__}
    if dataset_name is not None:
        cfg["dataset"] = dataset_name

    # Shape attributes: keep only values expressible as a JSON scalar
    # (a missing attribute maps to the sentinel and is skipped).
    for attr, key in (
        ("F", "n_frequency_bands"),
        ("T", "temporal_window_size"),
        ("O", "out_neurons"),
    ):
        value = _json_scalar(getattr(model, attr, _NOT_SCALAR))
        if value is not _NOT_SCALAR:
            cfg[key] = value

    # Top-level sub-modules are logged by class name only.
    for attr in ("prefiltering", "core", "readout", "output_activation"):
        sub = getattr(model, attr, None)
        if isinstance(sub, nn.Module):
            cfg[attr] = type(sub).__name__

    # One level deeper for the readout, whose kernel/activation often vary.
    readout = getattr(model, "readout", None)
    if isinstance(readout, nn.Module):
        for sub_attr in ("kernel", "activation"):
            sub = getattr(readout, sub_attr, None)
            if isinstance(sub, nn.Module):
                cfg[f"readout.{sub_attr}"] = type(sub).__name__

    if fitter_kwargs:
        for k, v in fitter_kwargs.items():
            scalar = _json_scalar(v)
            # Slicing with ``None`` keeps the full repr.
            cfg[k] = repr(v)[:max_repr_len] if scalar is _NOT_SCALAR else scalar

    if extra:
        cfg.update(extra)
    return cfg
# Default ``(key, formatter)`` pairs for ``slug_from_config``. The lower-casing
# formatters go through ``str()`` so non-string values (e.g. an int dataset id
# injected via ``extra``) don't raise ``TypeError`` the way ``str.lower`` would.
_DEFAULT_SLUG_FIELDS = (
    ("model", lambda v: str(v).lower()),
    ("dataset", lambda v: str(v).lower()),
    ("temporal_window_size", lambda v: f"T{v}"),
    ("n_frequency_bands", lambda v: f"F{v}"),
)


def slug_from_config(
    config: Mapping[str, Any],
    fields: Optional[Sequence] = None,
    *,
    sep: str = "-",
) -> str:
    """Build a short dash-separated slug like ``'linear-ns1-T9-F34'``.

    Parameters
    ----------
    config
        A config dict (typically from :func:`auto_config`).
    fields
        Sequence of ``(key, formatter)`` pairs. ``formatter`` is a callable
        ``value -> str``; its result is passed through ``str()``. Missing
        keys are silently skipped, as are formatters that produce an empty
        string (so no doubled separators appear). Default fields: ``model``
        (lowercased), ``dataset`` (lowercased), ``temporal_window_size``
        (as ``T{v}``), ``n_frequency_bands`` (as ``F{v}``).
    sep
        Separator placed between fragments (default ``"-"``).

    Returns
    -------
    str
        The joined slug; ``""`` when no field matched.
    """
    if fields is None:
        fields = _DEFAULT_SLUG_FIELDS

    parts = []
    for key, fmt in fields:
        if key not in config:
            continue
        part = str(fmt(config[key]))
        if part:
            parts.append(part)
    return sep.join(parts)