"""Loss functions for deepSTRF: gradient-bearing, NaN-aware, reduction over neurons.
See ``docs/_source/md/metrics_paradigm.md`` for shape conventions, NaN handling,
and reduction semantics.
"""
from __future__ import annotations
from typing import Optional
import torch
from deepSTRF.metrics._masking import (
collapse_to_psth_if_needed,
per_neuron_mean,
reduce_over_neurons,
resolve_mask,
)
[docs]
def mse_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Boolean-masked mean squared error, reduced over the neuron axis.

    Parameters
    ----------
    pred : torch.Tensor
        Prediction tensor of shape ``(B, N, 1, T)``. Values at masked-out
        positions are ignored: they contribute neither to the loss nor to
        the gradient, even if they are NaN or infinite.
    gt : torch.Tensor
        Ground-truth tensor of shape ``(B, N, 1, T)`` (PSTH or single-trial
        target) **or** ``(B, N, R, T)`` with ``R > 1`` (raw responses), in
        which case it is collapsed to PSTH via ``nanmean(dim=2, keepdim=True)``
        before the loss is computed. ``gt`` may contain NaN; positions where
        the resulting PSTH is NaN are dropped from the per-neuron mean.
    mask : torch.Tensor, optional
        Bool tensor broadcastable to the post-collapse ``gt`` shape
        ``(B, N, 1, T)``. If None, defaults to ``~gt.isnan()``. If provided,
        REPLACES (does not augment) the NaN-derived mask.
    reduction : {'none', 'mean', 'sum'}, default 'mean'
        Reduction over the neuron axis. ``'none'`` returns the per-neuron
        vector; ``'mean'`` / ``'sum'`` reduce it via nanmean / nansum.

    Returns
    -------
    torch.Tensor
        Shape ``(N,)`` if ``reduction='none'``, otherwise a scalar.

    Raises
    ------
    ValueError
        If ``pred`` and the (collapsed) ``gt`` differ in shape, or if they
        are not 4-dimensional.
    """
    gt = collapse_to_psth_if_needed(gt)
    if pred.shape != gt.shape:
        raise ValueError(
            f"pred shape {tuple(pred.shape)} must equal gt shape "
            f"{tuple(gt.shape)}"
        )
    if pred.dim() != 4:
        raise ValueError(
            f"expected pred and gt with 4 dims (B, N, R, T), got {pred.dim()}"
        )
    valid = resolve_mask(gt, mask)
    # Zero out masked-out positions in BOTH operands before the residual.
    # The forward value is unchanged because per_neuron_mean drops these
    # positions, but autograd still differentiates through them: with a
    # zero upstream gradient, ``2*(pred - gt) * 0`` is NaN under IEEE 754
    # whenever either operand is NaN/inf there, and that NaN contaminates
    # upstream parameters. Sanitising only gt (NaN targets) would leave the
    # same leak open for non-finite predictions at masked-out positions.
    pred_safe = torch.where(valid, pred, torch.zeros_like(pred))
    gt_safe = torch.where(valid, gt, torch.zeros_like(gt))
    diff_sq = (pred_safe - gt_safe) ** 2
    per_neuron = per_neuron_mean(diff_sq, valid)
    return reduce_over_neurons(per_neuron, reduction)
[docs]
def poisson_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    reduction: str = "mean",
    log_input: bool = False,
    validate_input: bool = False,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Negative Poisson log-likelihood (without the ``log(gt!)`` constant).

    Parameters
    ----------
    pred : torch.Tensor
        Prediction of shape ``(B, N, 1, T)``. Interpreted as the rate ``λ``
        when ``log_input=False`` and as the log-rate ``η = log(λ)`` when
        ``log_input=True`` (see Notes). Values at masked-out positions are
        ignored for both the loss and its gradient.
    gt : torch.Tensor
        Ground-truth target. A pre-computed PSTH ``(B, N, 1, T)`` or a raw
        responses tensor ``(B, N, R, T)`` with ``R > 1``; in the latter case
        it is collapsed to PSTH via ``nanmean(dim=2, keepdim=True)`` first.
        May contain NaN.
    mask : torch.Tensor, optional
        Bool tensor broadcastable to ``(B, N, 1, T)``. If None, defaults to
        ``~gt.isnan()``. If provided, REPLACES (does not augment) the
        NaN-derived mask.
    reduction : {'none', 'mean', 'sum'}, default 'mean'
        Reduction over the neuron axis (``'none'`` keeps the ``(N,)`` vector).
    log_input : bool, default False
        Selects the prediction parameterisation (see Notes).
    validate_input : bool, default False
        When ``log_input=False``, raise on negative ``pred`` at masked-in
        positions instead of silently clamping inside the log. Costs a
        per-step CPU sync.
    eps : float, default 1e-8
        Floor applied to ``pred`` inside the log when ``log_input=False``.

    Returns
    -------
    torch.Tensor
        Shape ``(N,)`` if ``reduction='none'``, otherwise a scalar.

    Raises
    ------
    ValueError
        If ``pred`` and the (collapsed) ``gt`` differ in shape or are not
        4-dimensional, or if ``validate_input=True`` (with
        ``log_input=False``) and ``pred`` is negative at a masked-in position.

    Notes
    -----
    Two parameterisations of the prediction are supported, matching the
    canonical-link logic of generalised linear models:

    - ``log_input=False`` (default): the loss is ``pred − gt · log(pred + eps)``
      per element. ``pred`` must be non-negative for the log to be meaningful;
      the implementation *silently clamps* ``pred`` to ``≥ eps`` inside the
      ``log`` to avoid NaN propagation (the linear term keeps its sign). Pass
      ``validate_input=True`` for a loud failure on negative predictions.
    - ``log_input=True``: the loss becomes ``exp(pred) − gt · pred``, which is
      well-defined for any real-valued ``pred``. This is the standard trick
      for pairing a Poisson NLL with an unbounded readout (Linear,
      sign-permitting ParametricSigmoid, etc.). See ``metrics_paradigm.md``
      §6.2 for the GLM-canonical-link derivation.

    Positions excluded by the mask are replaced by 0 in both ``pred`` and
    ``gt`` before the per-element loss is formed. Non-finite predictions
    there therefore cannot leak NaN into the gradient. This covers NaN
    values, and values large enough for ``exp(pred)`` to overflow under
    ``log_input=True``.

    The ``log(gt!)`` Stirling term is *not* added — for non-integer ``gt``
    (e.g. trial-averaged PSTH binned counts) it is meaningless, and for
    integer ``gt`` it is constant in ``pred`` so it does not affect
    optimisation. Users who want the full likelihood for AIC/BIC can add it
    themselves.
    """
    gt = collapse_to_psth_if_needed(gt)
    if pred.shape != gt.shape:
        raise ValueError(
            f"pred shape {tuple(pred.shape)} must equal gt shape "
            f"{tuple(gt.shape)}"
        )
    if pred.dim() != 4:
        raise ValueError(
            f"expected pred and gt with 4 dims (B, N, R, T), got {pred.dim()}"
        )
    valid = resolve_mask(gt, mask)
    # Zero out masked-out positions in BOTH operands before the loss formula.
    # per_neuron_mean drops these positions from the forward value, but
    # autograd still differentiates through them with a zero upstream
    # gradient. A NaN gt, a NaN pred, or an overflowing exp(pred) there would
    # turn that ``x * 0`` into NaN and contaminate upstream parameters.
    pred_safe = torch.where(valid, pred, torch.zeros_like(pred))
    gt_safe = torch.where(valid, gt, torch.zeros_like(gt))
    if log_input:
        elem = torch.exp(pred_safe) - gt_safe * pred_safe
    else:
        # pred_safe is already 0 at masked-out positions, so this flags
        # exactly the negative rates at masked-in positions. Evaluating the
        # tensor in a Python ``if`` forces a device->host sync.
        if validate_input and (pred_safe < 0).any():
            raise ValueError(
                "poisson_loss(log_input=False): pred has negative values at "
                "masked-in positions. Either use a non-negative output "
                "activation, set log_input=True (interpret pred as log-rate), "
                "or drop validate_input=True to silently clamp inside log."
            )
        # clamp(min=eps) keeps the log finite for non-positive rates. Its
        # gradient is zero below eps, so such entries are driven only by the
        # linear term.
        elem = pred_safe - gt_safe * torch.log(pred_safe.clamp(min=eps) + eps)
    per_neuron = per_neuron_mean(elem, valid)
    return reduce_over_neurons(per_neuron, reduction)