Source code for snntorch._neurons.rleaky

import torch
import torch.nn as nn

# from torch import functional as F
from .neurons import LIF


class RLeaky(LIF):
    """
    First-order recurrent leaky integrate-and-fire neuron model.
    Input is assumed to be a current injection appended to the voltage
    spike output.
    Membrane potential decays exponentially with rate beta.
    For :math:`U[T] > U_{\\rm thr} ⇒ S[T+1] = 1`.

    If `reset_mechanism = "subtract"`, then :math:`U[t+1]` will have
    `threshold` subtracted from it whenever the neuron emits a spike:

    .. math::

            U[t+1] = βU[t] + I_{\\rm in}[t+1] + V(S_{\\rm out}[t])
            - RU_{\\rm thr}

    Where :math:`V(\\cdot)` acts either as a linear layer, a convolutional
    operator, or elementwise product on :math:`S_{\\rm out}`.

    * If `all_to_all = True` and `linear_features` is specified, then
      :math:`V(\\cdot)` acts as a recurrent linear layer of the same size
      as :math:`S_{\\rm out}`.
    * If `all_to_all = True` and `conv2d_channels` and `kernel_size` are
      specified, then :math:`V(\\cdot)` acts as a recurrent convolutional
      layer with padding to ensure the output matches the size of the
      input.
    * If `all_to_all = False`, then :math:`V(\\cdot)` acts as an
      elementwise multiplier with :math:`V`.

    If `reset_mechanism = "zero"`, then :math:`U[t+1]` will be set to `0`
    whenever the neuron emits a spike:

    .. math::

            U[t+1] = βU[t] + I_{\\rm in}[t+1] + V(S_{\\rm out}[t])
            - R(βU[t] + I_{\\rm in}[t+1] + V(S_{\\rm out}[t]))

    * :math:`I_{\\rm in}` - Input current
    * :math:`U` - Membrane potential
    * :math:`U_{\\rm thr}` - Membrane threshold
    * :math:`S_{\\rm out}` - Output spike
    * :math:`R` - Reset mechanism: if active, :math:`R = 1`, otherwise
      :math:`R = 0`
    * :math:`β` - Membrane potential decay rate
    * :math:`V` - Explicit recurrent weight when `all_to_all=False`

    Example::

        import torch
        import torch.nn as nn
        import snntorch as snn

        beta = 0.5 # decay rate
        V1 = 0.5 # shared recurrent connection
        V2 = torch.rand(num_outputs) # unshared recurrent connections

        # Define Network
        class Net(nn.Module):
            def __init__(self):
                super().__init__()

                # initialize layers
                self.fc1 = nn.Linear(num_inputs, num_hidden)

                # Default RLeaky Layer where recurrent connections
                # are initialized using PyTorch defaults in nn.Linear.
                self.lif1 = snn.RLeaky(beta=beta,
                    linear_features=num_hidden)

                self.fc2 = nn.Linear(num_hidden, num_outputs)

                # each neuron has a single connection back to itself
                # where the output spike is scaled by V.
                # For `all_to_all = False`, V can be shared between
                # neurons (e.g., V1) or unique / unshared between
                # neurons (e.g., V2).
                # V is learnable by default.
                self.lif2 = snn.RLeaky(beta=beta, all_to_all=False, V=V1)

            def forward(self, x):
                # Initialize hidden states at t=0
                spk1, mem1 = self.lif1.init_rleaky()
                spk2, mem2 = self.lif2.init_rleaky()

                # Record output layer spikes and membrane
                spk2_rec = []
                mem2_rec = []

                # time-loop
                for step in range(num_steps):
                    cur1 = self.fc1(x)
                    spk1, mem1 = self.lif1(cur1, spk1, mem1)
                    cur2 = self.fc2(spk1)
                    spk2, mem2 = self.lif2(cur2, spk2, mem2)

                    spk2_rec.append(spk2)
                    mem2_rec.append(mem2)

                # convert lists to tensors
                spk2_rec = torch.stack(spk2_rec)
                mem2_rec = torch.stack(mem2_rec)

                return spk2_rec, mem2_rec


    :param beta: membrane potential decay rate. Clipped between 0 and 1
        during the forward-pass. May be a single-valued tensor (i.e., equal
        decay rate for all neurons in a layer), or multi-valued (one weight
        per neuron).
    :type beta: float or torch.tensor

    :param V: Recurrent weights to scale output spikes, only used when
        `all_to_all=False`. Defaults to 1.
    :type V: float or torch.tensor

    :param all_to_all: Enables output spikes to be connected in dense or
        convolutional recurrent structures instead of 1-to-1 connections.
        Defaults to True.
    :type all_to_all: bool, optional

    :param linear_features: Size of each output sample. Must be specified
        if `all_to_all=True` and the input data is 1D. Defaults to None
    :type linear_features: int, optional

    :param conv2d_channels: Number of channels in each output sample. Must
        be specified if `all_to_all=True` and the input data is 3D.
        Defaults to None
    :type conv2d_channels: int, optional

    :param kernel_size: Size of the convolving kernel. Must be specified
        if `all_to_all=True` and the input data is 3D. Defaults to None
    :type kernel_size: int or tuple

    :param threshold: Threshold for :math:`mem` to reach in order to
        generate a spike `S=1`. Defaults to 1
    :type threshold: float, optional

    :param spike_grad: Surrogate gradient for the term dS/dU. Defaults to
        None (corresponds to ATan surrogate gradient. See
        `snntorch.surrogate` for more options)
    :type spike_grad: surrogate gradient function from snntorch.surrogate,
        optional

    :param surrogate_disable: Disables surrogate gradients regardless of
        `spike_grad` argument. Useful for ONNX compatibility. Defaults
        to False
    :type surrogate_disable: bool, optional

    :param init_hidden: Instantiates state variables as instance variables.
        Defaults to False
    :type init_hidden: bool, optional

    :param inhibition: If `True`, suppresses all spiking other than the
        neuron with the highest state. Defaults to False
    :type inhibition: bool, optional

    :param learn_beta: Option to enable learnable beta. Defaults to False
    :type learn_beta: bool, optional

    :param learn_recurrent: Option to enable learnable recurrent weights.
        Defaults to True
    :type learn_recurrent: bool, optional

    :param learn_threshold: Option to enable learnable threshold. Defaults
        to False
    :type learn_threshold: bool, optional

    :param reset_mechanism: Defines the reset mechanism applied to
        :math:`mem` each time the threshold is met. Reset-by-subtraction:
        "subtract", reset-to-zero: "zero", none: "none". Defaults to
        "subtract"
    :type reset_mechanism: str, optional

    :param state_quant: If specified, hidden state :math:`mem` is quantized
        to a valid state for the forward pass. Defaults to False
    :type state_quant: quantization function from snntorch.quant, optional

    :param output: If `True` as well as `init_hidden=True`, states are
        returned when neuron is called. Defaults to False
    :type output: bool, optional


    Inputs: \\input_, spk_0, mem_0
        - **input_** of shape `(batch, input_size)`: tensor containing
          input features
        - **spk_0** of shape `(batch, input_size)`: tensor containing
          output spike features
        - **mem_0** of shape `(batch, input_size)`: tensor containing the
          initial membrane potential for each element in the batch.

    Outputs: spk_1, mem_1
        - **spk_1** of shape `(batch, input_size)`: tensor containing the
          output spikes.
        - **mem_1** of shape `(batch, input_size)`: tensor containing the
          next membrane potential for each element in the batch

    Learnable Parameters:
        - **RLeaky.beta** (torch.Tensor) - optional learnable weights must
          be manually passed in, of shape `1` or (input_size).
        - **RLeaky.recurrent.weight** (torch.Tensor) - optional learnable
          weights are automatically generated if `all_to_all=True`.
          `RLeaky.recurrent` stores a `nn.Linear` or `nn.Conv2d` layer
          depending on input arguments provided.
        - **RLeaky.V** (torch.Tensor) - optional learnable weights must be
          manually passed in, of shape `1` or (input_size). It is only
          used where `all_to_all=False` for 1-to-1 recurrent connections.
        - **RLeaky.threshold** (torch.Tensor) - optional learnable
          thresholds must be manually passed in, of shape `1` or
          (input_size).

    """

    def __init__(
        self,
        beta,
        V=1.0,
        all_to_all=True,
        linear_features=None,
        conv2d_channels=None,
        kernel_size=None,
        threshold=1.0,
        spike_grad=None,
        surrogate_disable=False,
        init_hidden=False,
        inhibition=False,
        learn_beta=False,
        learn_threshold=False,
        learn_recurrent=True,  # changed learn_V
        reset_mechanism="subtract",
        state_quant=False,
        output=False,
        reset_delay=True,
    ):
        super().__init__(
            beta,
            threshold,
            spike_grad,
            surrogate_disable,
            init_hidden,
            inhibition,
            learn_beta,
            learn_threshold,
            reset_mechanism,
            state_quant,
            output,
        )

        self.all_to_all = all_to_all
        self.learn_recurrent = learn_recurrent

        # linear params
        self.linear_features = linear_features

        # Conv2d params
        self.kernel_size = kernel_size
        self.conv2d_channels = conv2d_channels

        # catch cases
        self._rleaky_init_cases()

        # initialize recurrent connections
        if self.all_to_all:  # init all-all connections
            self._init_recurrent_net()
        else:  # initialize 1-1 connections
            self._V_register_buffer(V, learn_recurrent)
            self._init_recurrent_one_to_one()

        if not learn_recurrent:
            self._disable_recurrent_grad()

        self._init_mem()

        if self.reset_mechanism_val == 0:  # reset by subtraction
            self.state_function = self._base_sub
        elif self.reset_mechanism_val == 1:  # reset to zero
            self.state_function = self._base_zero
        elif self.reset_mechanism_val == 2:  # no reset, pure integration
            self.state_function = self._base_int

        self.reset_delay = reset_delay

    def _init_mem(self):
        spk = torch.zeros(0)
        mem = torch.zeros(0)

        self.register_buffer("spk", spk, False)
        self.register_buffer("mem", mem, False)

    def reset_mem(self):
        self.spk = torch.zeros_like(self.spk, device=self.spk.device)
        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
        return self.spk, self.mem

    def init_rleaky(self):
        """Deprecated, use :meth:`RLeaky.reset_mem` instead"""
        return self.reset_mem()

    def forward(self, input_, spk=None, mem=None):
        if spk is not None:
            self.spk = spk

        if mem is not None:
            self.mem = mem

        if self.init_hidden and (spk is not None or mem is not None):
            raise TypeError(
                "When `init_hidden=True`, "
                "RLeaky expects 1 input argument."
            )

        if self.spk.shape != input_.shape:
            self.spk = torch.zeros_like(input_, device=self.spk.device)

        if self.mem.shape != input_.shape:
            self.mem = torch.zeros_like(input_, device=self.mem.device)

        # TO-DO: alternatively, we could do torch.exp(-1 /
        # self.beta.clamp_min(0)), giving actual time constants instead of
        # values in [0, 1] as initial beta
        beta = self.beta.clamp(0, 1)

        self.reset = self.mem_reset(self.mem)
        self.mem = self.state_function(input_)

        if self.state_quant:
            self.mem = self.state_quant(self.mem)

        if self.inhibition:
            self.spk = self.fire_inhibition(self.mem.size(0), self.mem)
        else:
            self.spk = self.fire(self.mem)

        if not self.reset_delay:
            do_reset = (
                self.spk / self.graded_spikes_factor - self.reset
            )  # avoid double reset
            if self.reset_mechanism_val == 0:  # reset by subtraction
                self.mem = self.mem - do_reset * self.threshold
            elif self.reset_mechanism_val == 1:  # reset to zero
                self.mem = self.mem - do_reset * self.mem

        if self.output:
            return self.spk, self.mem
        elif self.init_hidden:
            return self.spk
        else:
            return self.spk, self.mem

    def _init_recurrent_net(self):
        if self.all_to_all:
            if self.linear_features:
                self._init_recurrent_linear()
            elif self.kernel_size is not None:
                self._init_recurrent_conv2d()
        else:
            self._init_recurrent_one_to_one()

    def _init_recurrent_linear(self):
        self.recurrent = nn.Linear(self.linear_features, self.linear_features)

    def _init_recurrent_conv2d(self):
        self._init_padding()
        self.recurrent = nn.Conv2d(
            in_channels=self.conv2d_channels,
            out_channels=self.conv2d_channels,
            kernel_size=self.kernel_size,
            padding=self.padding,
        )

    def _init_padding(self):
        if type(self.kernel_size) is int:
            self.padding = self.kernel_size // 2, self.kernel_size // 2
        else:
            self.padding = self.kernel_size[0] // 2, self.kernel_size[1] // 2

    def _init_recurrent_one_to_one(self):
        self.recurrent = RecurrentOneToOne(self.V)

    def _disable_recurrent_grad(self):
        for param in self.recurrent.parameters():
            param.requires_grad = False

    def _base_state_function(self, input_):
        base_fn = (
            self.beta.clamp(0, 1) * self.mem
            + input_
            + self.recurrent(self.spk)
        )
        return base_fn

    def _base_sub(self, input_):
        return self._base_state_function(input_) - self.reset * self.threshold

    def _base_zero(self, input_):
        return self._base_state_function(
            input_
        ) - self.reset * self._base_state_function(input_)

    def _base_int(self, input_):
        return self._base_state_function(input_)

    def _rleaky_init_cases(self):
        all_to_all_bool = bool(self.all_to_all)
        linear_features_bool = self.linear_features
        conv2d_channels_bool = bool(self.conv2d_channels)
        kernel_size_bool = bool(self.kernel_size)

        if all_to_all_bool:
            if not linear_features_bool:
                if not (conv2d_channels_bool or kernel_size_bool):
                    raise TypeError(
                        "When `all_to_all=True`, RLeaky requires either "
                        "`linear_features` or (`conv2d_channels` and "
                        "`kernel_size`) to be specified. The shape should "
                        "match the shape of the output spike of the layer."
                    )
                elif conv2d_channels_bool ^ kernel_size_bool:
                    raise TypeError(
                        "`conv2d_channels` and `kernel_size` must both be "
                        "specified. The shape of `conv2d_channels` should "
                        "match the shape of the output spikes."
                    )
            elif (linear_features_bool and kernel_size_bool) or (
                linear_features_bool and conv2d_channels_bool
            ):
                raise TypeError(
                    "`linear_features` cannot be specified at the same time "
                    "as `conv2d_channels` or `kernel_size`. A linear layer "
                    "and conv2d layer cannot both be specified at the same "
                    "time."
                )

        else:
            if (
                linear_features_bool
                or conv2d_channels_bool
                or kernel_size_bool
            ):
                raise TypeError(
                    "When `all_to_all=False`, none of `linear_features`, "
                    "`conv2d_channels`, or `kernel_size` should be "
                    "specified. The weight `V` is used instead."
                )

    @classmethod
    def detach_hidden(cls):
        """Returns the hidden states, detached from the current graph.
        Intended for use in truncated backpropagation through time where
        hidden state variables are instance variables."""
        for layer in range(len(cls.instances)):
            if isinstance(cls.instances[layer], RLeaky):
                cls.instances[layer].mem.detach_()
                cls.instances[layer].spk.detach_()

    @classmethod
    def reset_hidden(cls):
        """Used to clear hidden state variables to zero.
        Intended for use where hidden state variables are instance variables.
        Assumes hidden states have a batch dimension already."""
        for layer in range(len(cls.instances)):
            if isinstance(cls.instances[layer], RLeaky):
                (
                    cls.instances[layer].spk,
                    cls.instances[layer].mem,
                ) = cls.instances[layer].init_rleaky()

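
# Helper module implementing V(·) for `all_to_all=False`: each neuron feeds
# its own output spike back to itself, scaled elementwise (or globally, if V
# is a scalar) by the recurrent weight V.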
class RecurrentOneToOne(nn.Module):
    def __init__(self, V):
        super().__init__()
        self.V = V

    def forward(self, x):
        return x * self.V  # element-wise or global multiplication
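

# -----------------------------------------------------------------------------
# Usage sketch (not part of the snnTorch API). The docstring example above
# covers the dense (`linear_features`) recurrent case, so this minimal sketch
# shows the convolutional recurrent variant instead. The batch size, channel
# count, spatial size, and decay rate below are illustrative assumptions, not
# values required by RLeaky.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import snntorch as snn

    batch_size = 2
    channels = 8
    num_steps = 5

    # Output spikes are fed back through an 8-channel recurrent Conv2d;
    # padding is derived from kernel_size so the recurrent output matches
    # the input size.
    rlif = snn.RLeaky(beta=0.9, conv2d_channels=channels, kernel_size=3)

    spk, mem = rlif.reset_mem()
    cur = torch.rand(batch_size, channels, 16, 16)  # dummy current injection

    for _ in range(num_steps):
        spk, mem = rlif(cur, spk, mem)

    print(spk.shape, mem.shape)  # both: torch.Size([2, 8, 16, 16])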