Source code for deepSTRF.models.video.video_model

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

from deepSTRF.models.neural_model import NeuralModel


class VideoEncodingModel(NeuralModel):
    """
    General mother class for ENCODING models of VIDEO sensory neural responses.

    The forward() method takes as input a video stimulus tensor of shape
    (B, C, H, W, T): batch, channels, frame height, frame width, time.
    STRF_gradmap() currently builds single-channel (C = 1) probe stimuli.

    Parameters
    ----------
    spatial_resol : pair of int
        Frame resolution, unpacked as (H, W) into ``self.H`` and ``self.W``.
    temporal_window_size : int
        Stored as ``self.T``; the default probe length used by STRF_gradmap().
    out_neurons : int, default 1
        Forwarded to ``NeuralModel.__init__``.
    output_activation : nn.Module, optional
        Stored as ``self.output_activation``; ``nn.Identity()`` when None.
    *args, **kwargs
        Forwarded unchanged to ``NeuralModel.__init__``.

    TODO: prefiltering_dict example
    """

    def __init__(self, spatial_resol, temporal_window_size: int, out_neurons: int = 1,
                 output_activation: nn.Module = None, *args, **kwargs):
        # Unpacked positional args always bind before keyword args in a call, so list
        # them first to make that binding order explicit (same semantics as before).
        super().__init__(*args, out_neurons=out_neurons, **kwargs)
        # general attributes for VIDEO neural response models
        self.H, self.W = spatial_resol
        self.T = temporal_window_size
        self.output_activation = (
            output_activation if output_activation is not None else nn.Identity()
        )

    def validate(self):
        """
        Run NeuralModel's checks, then check the video-specific attributes.

        Raises
        ------
        AssertionError
            If ``self.H``, ``self.W`` or ``self.T`` is not a positive int.
            ``bool`` values are rejected even though ``bool`` subclasses ``int``.
        """
        super().validate()
        # Kept as ``assert`` so the exception type callers see is unchanged;
        # note these checks are stripped when Python runs with -O.
        for attr in ("H", "W", "T"):
            value = getattr(self, attr)
            assert isinstance(value, int) and not isinstance(value, bool) and value > 0, \
                f"self.{attr} must be a positive int (got {value!r})"

    def STRF_gradmap(self, T=None):
        """
        Get the SPATIO-Temporal Receptive Field (STRF) of the OUTPUT neurons, with a history
        of T timesteps, as the changes in the stimulus that elicit an increase in output
        activity.
        cf. Rançon et al. (2025), "Temporal recurrence as a general mechanism to explain
        neural responses in the auditory system", BioRxiv

        One null (all-zero) stimulus is built per output neuron, using the batch dimension
        to probe all neurons in parallel: batch element i is scored only on neuron i's
        activity at the last timestep.

        Parameters
        ----------
        T : int, optional
            Number of timesteps of the probe stimulus; defaults to ``self.T``.

        Returns
        -------
        torch.Tensor
            Gradient map of shape (self.O, 1, self.H, self.W, T), allocated on the device
            of the model's first parameter (default device if the model has none).
            ``self.O`` comes from NeuralModel -- presumably the number of output neurons.

        Notes
        -----
        The gradient is obtained with ``torch.autograd.grad`` so the model parameters'
        ``.grad`` fields are NOT modified. It is computed under ``torch.enable_grad()``
        so the method also works when called inside a ``torch.no_grad()`` block.

        TODO:
            - handle multiple input channels (on & off) because of adaptrans ?
            - allow custom losses ? (e.g. sustained activity rather than last spike ?)
        """
        T = self.T if T is None else T
        n_probes = self.O  # use the batch dimension to parallelize over output neurons

        # Allocate the probe where the model lives so GPU-resident models work.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else None

        # initial stim = null stimulus = absence of bias / absence of information / no entropy
        stim = torch.zeros(n_probes, 1, self.H, self.W, T, device=device, requires_grad=True)

        with torch.enable_grad():
            # forward pass -- expected to return (B=N, N, T); TODO confirm for all subclasses
            response = self.forward(stim)
            # Spike-Triggered Average (STA) loss = activation at the last timestep.
            # The trace keeps only the diagonal: probe i is scored on neuron i alone.
            loss = -torch.trace(response[:, :, -1])  # scalar
            # NOTE(review): this is the gradient of the *negated* activity, i.e.
            # -d(activity)/d(stim); the direction that increases activity is its opposite.
            # Sign convention kept from the original implementation -- confirm intent.
            # allow_unused=True mirrors the old `.grad` behavior (None if stim is unused).
            (strf_gradmap,) = torch.autograd.grad(loss, stim, allow_unused=True)

        return strf_gradmap