import functools
import inspect
import json
from abc import ABC
from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
class NeuralModel(nn.Module, ABC):
    """Base class for encoding models of sensory neural responses.

    A four-slot template defines the canonical forward pipeline::

        forward(x):
            x = self.wav2spec(x)       # raw-waveform front-end (future)
            x = self.prefiltering(x)   # AdapTrans / ICAdaptation / Identity
            f = self.core(x)           # shared feature backbone
            return self.readout(f)     # per-neuron projection (B, N, 1, T)

    Concrete subclasses populate the slots in their ``__init__``. The
    defaults for ``wav2spec``, ``prefiltering`` and ``core`` are
    ``nn.Identity``, so a minimal model only needs to provide a
    ``readout``. Subclasses may override :meth:`forward` for architectures
    that don't fit the four-slot pipeline (e.g. StateNet's recurrent
    reshape, Transformer's per-frame attention).

    See ``docs/_source/md/model_paradigm.md`` for the full contract.

    Parameters
    ----------
    out_neurons : int, default 1
        Number of output neurons ``N`` the model predicts. Stored on
        ``self.O`` and used by :meth:`validate` and ``STRF_gradmap``.

    Notes
    -----
    Subclasses inherit a Hugging Face Hub interface for pretrained
    checkpoints:

    - :meth:`save_pretrained` — write config + weights to a folder.
    - :meth:`push_to_hub` — upload to the HF Hub (auth required).
    - :meth:`from_pretrained` — instantiate and load weights.

    The init kwargs needed to rebuild the architecture are auto-captured
    on construction (see :meth:`__init_subclass__`), so end-users never
    have to supply a config dict — ``StateNet.from_pretrained("urancon/...")``
    just works.
    """

    def __init__(self, out_neurons: int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # number of output neurons N (used by validate() and STRF_gradmap)
        self.O = out_neurons
        # canonical pipeline slots — defaults are no-ops; subclasses override
        self.wav2spec = nn.Identity()
        self.prefiltering = nn.Identity()
        self.core = nn.Identity()
        # readout has no sensible default; subclasses must set it before
        # forward() is called. validate() enforces this.

    def __init_subclass__(cls, **kwargs):
        """Wrap each subclass's ``__init__`` to auto-capture JSON-serialisable
        kwargs into ``self._init_kwargs``.

        Capture happens once per construction, in the wrapper that
        ``type(self).__init__`` resolves to, i.e. the ``__init__`` the caller
        actually invoked. ``super().__init__`` calls cascading into ancestor
        wrappers are no-ops. A subclass that inherits its ``__init__`` from
        another NeuralModel subclass (without overriding it) is captured
        through that inherited wrapper.

        Non-JSON values (e.g. an ``nn.Module`` passed for ``prefiltering``)
        are silently dropped; they have to be re-supplied via
        ``from_pretrained(..., extra_kwargs=...)``.
        """
        super().__init_subclass__(**kwargs)
        if "__init__" not in cls.__dict__:
            # Nothing new to wrap: the inherited __init__ is either an
            # ancestor's wrapper or NeuralModel's own __init__.
            return
        original_init = cls.__init__
        sig = inspect.signature(original_init)

        @functools.wraps(original_init)
        def wrapped_init(self, *args, **kw):
            # Capture only in the outermost call: the wrapper the instance's
            # class resolves ``__init__`` to. Comparing ``type(self) is cls``
            # instead would miss subclasses that inherit this __init__
            # without overriding it, leaving them with no ``_init_kwargs``.
            if type(self).__init__ is wrapped_init:
                self._init_kwargs = _capture_kwargs(sig, self, args, kw)
            original_init(self, *args, **kw)

        cls.__init__ = wrapped_init

    def forward(self, stimulus):
        """Run the default template pipeline.

        Applies ``wav2spec`` → ``prefiltering`` → ``core`` → ``readout`` in
        sequence.

        Parameters
        ----------
        stimulus : torch.Tensor
            Input stimulus batch (shape is modality-dependent; for audio
            models a spectrogram ``(B, F, T)``).

        Returns
        -------
        torch.Tensor
            Predicted response of shape ``(B, N, 1, T)``.
        """
        x = self.wav2spec(stimulus)
        x = self.prefiltering(x)
        f = self.core(x)
        return self.readout(f)

    def detach(self):
        """Detach stateful variables from the computational graph.

        No-op by default; recurrent subclasses override this to truncate
        backpropagation-through-time between chunks (cf. spikingjelly).
        """
        pass

    def count_trainable_params(self) -> int:
        """Count the model's trainable parameters.

        Returns
        -------
        int
            Total number of parameters with ``requires_grad=True``.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def validate(self):
        """Check that the instance is deepSTRF-compatible.

        Subclasses should call ``super().validate()`` and then add their own
        checks (e.g. :class:`~deepSTRF.models.audio.audio_model.AudioEncodingModel`
        checks ``F, T > 0``).

        The checks raise explicitly instead of using ``assert`` statements,
        so they still run when Python is started with ``-O``.

        Raises
        ------
        AssertionError
            If ``self.O`` is not a positive int (``bool`` is rejected), if
            ``readout`` is unset or not an :class:`torch.nn.Module`, or if any
            of the ``wav2spec`` / ``prefiltering`` / ``core`` slots is not an
            :class:`torch.nn.Module`.
        """
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(self.O, bool) or not isinstance(self.O, int) or self.O <= 0:
            raise AssertionError(f"self.O must be a positive int (got {self.O!r})")
        if not (hasattr(self, 'readout') and isinstance(self.readout, nn.Module)):
            raise AssertionError(
                f"{type(self).__name__}.readout must be set to an nn.Module before validate()"
            )
        for slot in ('wav2spec', 'prefiltering', 'core'):
            value = getattr(self, slot)
            if not isinstance(value, nn.Module):
                raise AssertionError(
                    f"self.{slot} must be an nn.Module (got {type(value).__name__})"
                )

    # ------------------------------------------------------------------
    # Pretrained-weights API (HuggingFace Hub)
    # ------------------------------------------------------------------

    def save_pretrained(self, save_dir: Union[str, Path], *,
                        metadata: Optional[dict] = None,
                        model_card: Optional[str] = None) -> Path:
        """Write a checkpoint folder (``config.json`` + ``model.safetensors``).

        See :func:`deepSTRF.utils.hub.save_pretrained_to_dir` for the
        full contract.

        Parameters
        ----------
        save_dir : str or pathlib.Path
            Destination folder; created if it does not exist.
        metadata : dict, optional
            Extra JSON-serialisable metadata stored alongside the config
            (e.g. training dataset, val/test scores).
        model_card : str, optional
            Markdown content written to ``README.md`` in the folder.

        Returns
        -------
        pathlib.Path
            Path to the written checkpoint folder.
        """
        from deepSTRF.utils.hub import save_pretrained_to_dir
        return save_pretrained_to_dir(self, save_dir, metadata=metadata,
                                      model_card=model_card)

    @classmethod
    def from_pretrained(cls, repo_id_or_path: Union[str, Path], *,
                        extra_kwargs: Optional[dict] = None,
                        strict: bool = True,
                        map_location: Union[str, torch.device] = "cpu",
                        cache_dir: Optional[Union[str, Path]] = None,
                        revision: Optional[str] = None,
                        token: Optional[str] = None,
                        return_metadata: bool = False):
        """Instantiate ``cls`` from an HF Hub repo or a local checkpoint folder.

        Parameters
        ----------
        repo_id_or_path : str or Path
            ``"<owner>/<name>"`` (HF Hub) or a path to a local folder
            produced by :meth:`save_pretrained`. The decision is made by
            :func:`pathlib.Path.is_dir` (after ``~`` expansion): an existing
            local directory wins, otherwise the string is treated as a Hub
            repo id.
        extra_kwargs : dict, optional
            Override / extend the saved config. Use this to re-supply any
            ``__init__`` argument that wasn't JSON-serialisable at save
            time (e.g. a custom ``prefiltering`` module).
        strict : bool, default True
            ``state_dict`` strictness; forwarded to
            :func:`deepSTRF.utils.hub.load_pretrained_from_dir`.
        map_location : str or torch.device, default ``'cpu'``
            Where to load the weights; forwarded to
            :func:`deepSTRF.utils.hub.load_pretrained_from_dir`.
        cache_dir, revision, token
            Forwarded to :func:`huggingface_hub.snapshot_download`.
        return_metadata : bool, default False
            If True, return ``(model, metadata)`` instead of just ``model``.

        Returns
        -------
        model
            Instance of ``cls`` with weights loaded (or ``(model, metadata)``
            when ``return_metadata`` is True).
        """
        from deepSTRF.utils.hub import download_pretrained, load_pretrained_from_dir
        path = Path(str(repo_id_or_path)).expanduser()
        if path.is_dir():
            local_dir = path
        else:
            local_dir = download_pretrained(str(repo_id_or_path),
                                            cache_dir=cache_dir,
                                            revision=revision, token=token)
        model, metadata = load_pretrained_from_dir(
            cls, local_dir, extra_kwargs=extra_kwargs,
            strict=strict, map_location=map_location,
        )
        if return_metadata:
            return model, metadata
        return model

    def push_to_hub(self, repo_id: str, *,
                    metadata: Optional[dict] = None,
                    model_card: Optional[str] = None,
                    private: bool = False,
                    token: Optional[str] = None,
                    commit_message: Optional[str] = None) -> str:
        """Push this model to ``repo_id`` on the HF Hub.

        Saves the checkpoint to a temporary folder and uploads it via
        :func:`deepSTRF.utils.hub.upload_pretrained`. Creates the repo on
        the fly if it doesn't exist; user must be authenticated with
        write access (``hf auth login`` or ``token=``).

        Parameters
        ----------
        repo_id : str
            Target repository, ``"<owner>/<name>"``.
        metadata : dict, optional
            Extra JSON-serialisable metadata, as in :meth:`save_pretrained`.
        model_card : str, optional
            Markdown content for ``README.md``, as in :meth:`save_pretrained`.
        private, token, commit_message
            Forwarded to :func:`deepSTRF.utils.hub.upload_pretrained`.

        Returns
        -------
        str
            URL of the resulting commit.
        """
        import tempfile
        from deepSTRF.utils.hub import save_pretrained_to_dir, upload_pretrained
        # The upload runs inside the ``with`` so the folder still exists.
        with tempfile.TemporaryDirectory() as tmp:
            save_pretrained_to_dir(self, tmp, metadata=metadata, model_card=model_card)
            return upload_pretrained(tmp, repo_id, private=private, token=token,
                                     commit_message=commit_message)
def _capture_kwargs(sig: inspect.Signature, self, args: tuple, kw: dict) -> dict:
"""Bind ``args`` / ``kw`` to ``sig`` and keep only JSON-serialisable
named parameters (skipping ``self`` / ``*args`` / ``**kwargs``)."""
try:
bound = sig.bind_partial(self, *args, **kw)
except TypeError:
return {}
bound.apply_defaults()
captured = {}
for name, val in bound.arguments.items():
if name == "self":
continue
param = sig.parameters.get(name)
if param is not None and param.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
continue
try:
json.dumps(val)
except (TypeError, ValueError):
continue
captured[name] = val
return captured