# Source code for snntorch._neurons.lapicque

import torch
from .neurons import LIF


class Lapicque(LIF):
    """
    An extension of Lapicque's experimental comparison between
    extracellular nerve fibers and an RC circuit.

    It is qualitatively equivalent to :code:`Leaky` but defined using
    RC circuit parameters.
    Input stimulus is integrated by membrane potential which decays
    exponentially with a rate of beta.
    For :math:`U[T] > U_{\\rm thr} ⇒ S[T+1] = 1`.

    If `reset_mechanism = "subtract"`, then :math:`U[t+1]` will have
    `threshold` subtracted from it whenever the neuron emits a spike:

    .. math::

            U[t+1] = I_{\\rm in}[t+1] (\\frac{T}{C}) +
            (1- \\frac{T}{\\tau})U[t] - RU_{\\rm thr}

    If `reset_mechanism = "zero"`, then :math:`U[t+1]` will be set to `0`
    whenever the neuron emits a spike:

    .. math::

            U[t+1] = I_{\\rm in}[t+1] (\\frac{T}{C}) +
            (1- \\frac{T}{\\tau})U[t] -
            R(I_{\\rm in}[t+1] (\\frac{T}{C}) + (1- \\frac{T}{\\tau})U[t])

    * :math:`I_{\\rm in}` - Input current
    * :math:`U` - Membrane potential
    * :math:`U_{\\rm thr}` - Membrane threshold
    * :math:`T` - duration of each time step
    * :math:`R` - Reset mechanism: if active, :math:`R = 1`, otherwise
      :math:`R = 0`
    * :math:`β` - Membrane potential decay rate.

    Alternatively, the membrane potential decay rate β can be specified
    instead:

    .. math::

            β = e^{-1/RC}

    * :math:`R` - Parallel resistance of passive membrane
      (note: distinct from the reset :math:`R`)
    * :math:`C` - Parallel capacitance of passive membrane

    Notes:

    * If only β is defined, then R will default to 1, and C will be
      inferred.
    * If RC is defined, β will be automatically calculated.
    * If (β and R) or (β and C) are defined, the missing variable will be
      automatically calculated.
    * If β, R and C are all defined, R and C are stored as given; the
      state update itself only uses R, C and the time step.
    * Note that β, R and C are treated as 'hard-wired' physically
      plausible parameters, and are therefore not learnable. For a
      single-state neuron with a learnable decay rate β,
      use :code:`snn.Leaky` instead.

    Example::

        import torch
        import torch.nn as nn
        import snntorch as snn

        beta = 0.5

        R = 1
        C = 1.44

        # Define Network
        class Net(nn.Module):
            def __init__(self):
                super().__init__()

                # initialize layers
                self.fc1 = nn.Linear(num_inputs, num_hidden)
                self.lif1 = snn.Lapicque(beta=beta)
                self.fc2 = nn.Linear(num_hidden, num_outputs)
                self.lif2 = snn.Lapicque(R=R, C=C)
                # lif1 and lif2 are approximately equivalent

            def forward(self, x, mem1, spk1, mem2):
                cur1 = self.fc1(x)
                spk1, mem1 = self.lif1(cur1, mem1)
                cur2 = self.fc2(spk1)
                spk2, mem2 = self.lif2(cur2, mem2)
                return mem1, spk1, mem2, spk2

    For further reading, see:

    *L. Lapicque (1907) Recherches quantitatives sur l'excitation
    électrique des nerfs traitée comme une polarisation. J. Physiol.
    Pathol. Gen. 9, pp. 620-635. (French)*

    *N. Brunel and M. C. Van Rossum (2007) Lapicque's 1907 paper: From
    frogs to integrate-and-fire. Biol. Cybern. 97, pp. 337-339.
    (English)*

    Although Lapicque did not formally introduce this as an
    integrate-and-fire neuron model, we pay homage to his discovery of an
    RC circuit mimicking the dynamics of synaptic current.

    :param beta: RC potential decay rate. May be a single-valued tensor
        (i.e., equal decay rate for all neurons in a layer), or
        multi-valued (one weight per neuron).
    :type beta: float or torch.tensor, Optional

    :param R: Resistance of RC circuit
    :type R: int or torch.tensor, Optional

    :param C: Capacitance of RC circuit
    :type C: int or torch.tensor, Optional

    :param time_step: time step precision. Defaults to 1
    :type time_step: float, Optional

    :param threshold: Threshold for :math:`mem` to reach in order to
        generate a spike `S=1`. Defaults to 1
    :type threshold: float, optional

    :param spike_grad: Surrogate gradient for the term dS/dU. Defaults to
        None (corresponds to ATan surrogate gradient. See
        `snntorch.surrogate` for more options)
    :type spike_grad: surrogate gradient function from snntorch.surrogate,
        optional

    :param surrogate_disable: Disables surrogate gradients regardless of
        `spike_grad` argument. Useful for ONNX compatibility. Defaults
        to False
    :type surrogate_disable: bool, Optional

    :param init_hidden: Instantiates state variables as instance
        variables. Defaults to False
    :type init_hidden: bool, optional

    :param inhibition: If `True`, suppresses all spiking other than the
        neuron with the highest state. Defaults to False
    :type inhibition: bool, optional

    :param learn_beta: Option to enable learnable beta. Defaults to False
    :type learn_beta: bool, optional

    :param learn_threshold: Option to enable learnable threshold.
        Defaults to False
    :type learn_threshold: bool, optional

    :param reset_mechanism: Defines the reset mechanism applied to
        :math:`mem` each time the threshold is met.
        Reset-by-subtraction: "subtract", reset-to-zero: "zero",
        none: "none". Defaults to "subtract"
    :type reset_mechanism: str, optional

    :param state_quant: If specified, hidden state :math:`mem` is
        quantized to a valid state for the forward pass. Defaults to
        False
    :type state_quant: quantization function from snntorch.quant,
        optional

    :param output: If `True` as well as `init_hidden=True`, states are
        returned when neuron is called. Defaults to False
    :type output: bool, optional

    Inputs: \\input_, mem_0
        - **input_** of shape `(batch, input_size)`: tensor containing
          input features
        - **mem_0** of shape `(batch, input_size)`: tensor containing the
          initial membrane potential for each element in the batch.

    Outputs: spk, mem_1
        - **spk** of shape `(batch, input_size)`: tensor containing the
          output spikes.
        - **mem_1** of shape `(batch, input_size)`: tensor containing the
          next membrane potential for each element in the batch

    Learnable Parameters:
        - **Lapicque.beta** (torch.Tensor) - optional learnable weights
          must be manually passed in, of shape `1` or (input_size).
        - **Lapicque.threshold** (torch.Tensor) - optional learnable
          thresholds must be manually passed in, of shape `1` or
          (input_size).
    """

    def __init__(
        self,
        beta=False,
        R=False,
        C=False,
        time_step=1,
        threshold=1.0,
        spike_grad=None,
        surrogate_disable=False,
        init_hidden=False,
        inhibition=False,
        learn_beta=False,
        learn_threshold=False,
        reset_mechanism="subtract",
        state_quant=False,
        output=False,
    ):
        super().__init__(
            beta,
            threshold,
            spike_grad,
            surrogate_disable,
            init_hidden,
            inhibition,
            learn_beta,
            learn_threshold,
            reset_mechanism,
            state_quant,
            output,
        )

        # Resolve/register time_step, R, C (and beta where derived).
        self._lapicque_cases(time_step, beta, R, C)
        self._init_mem()

        # Pick the state update that matches the configured reset mode.
        # An if/elif chain is kept (rather than a dict lookup) because
        # `reset_mechanism_val` is set by the parent class and only needs
        # to support `==` here.
        if self.reset_mechanism_val == 0:  # reset by subtraction
            self.state_function = self._base_sub
        elif self.reset_mechanism_val == 1:  # reset to zero
            self.state_function = self._base_zero
        elif self.reset_mechanism_val == 2:  # no reset, pure integration
            self.state_function = self._base_int

    def _init_mem(self):
        """Register an empty, non-persistent ``mem`` buffer.

        The buffer starts with zero elements; :meth:`forward` replaces it
        with a correctly shaped tensor on first use.
        """
        mem = torch.zeros(0)
        # Third positional argument is `persistent`: the state is not
        # saved in the module's state_dict.
        self.register_buffer("mem", mem, False)

    def reset_mem(self):
        """Zero the membrane potential in place of the current one.

        :return: the new all-zero ``mem`` tensor (same shape and device
            as the previous one).
        """
        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
        return self.mem

    def init_lapicque(self):
        """Deprecated, use :class:`Lapicque.reset_mem` instead"""
        return self.reset_mem()

    def forward(self, input_, mem=None):
        """Advance the neuron by one time step.

        :param input_: input current for this step.
        :param mem: optional membrane potential to start from. Must not
            be given when ``init_hidden=True``.
        :return: ``spk`` alone when ``init_hidden`` is set and ``output``
            is not; otherwise the tuple ``(spk, mem)``.
        :raises TypeError: if ``mem`` is passed while
            ``init_hidden=True``.
        """
        if mem is not None:
            # Validate before touching state so a rejected call leaves
            # the stored membrane potential untouched.
            if self.init_hidden:
                raise TypeError(
                    "`mem` should not be passed as an argument while `init_hidden=True`"
                )
            self.mem = mem

        # (Re)initialise the state whenever its shape does not match the
        # input, e.g. on the first call when `mem` is still empty.
        if self.mem.shape != input_.shape:
            self.mem = torch.zeros_like(input_, device=self.mem.device)

        # The reset signal is derived from the state *before* the update.
        self.reset = self.mem_reset(self.mem)
        self.mem = self.state_function(input_)

        if self.state_quant:
            self.mem = self.state_quant(self.mem)

        if self.inhibition:
            # First argument is the batch size.
            spk = self.fire_inhibition(self.mem.size(0), self.mem)
        else:
            spk = self.fire(self.mem)

        if self.init_hidden and not self.output:
            return spk
        return spk, self.mem

    def _base_state_function(self, input_):
        """Return the RC update of ``mem`` for ``input_`` without reset."""
        base_fn = (
            input_ * self.R * (1 / (self.R * self.C)) * self.time_step
            + (1 - (self.time_step / (self.R * self.C))) * self.mem
        )
        return base_fn

    def _base_sub(self, input_):
        """State update with reset-by-subtraction of ``threshold``."""
        return self._base_state_function(input_) - self.reset * self.threshold

    def _base_zero(self, input_):
        """State update with reset-to-zero."""
        # Evaluate the update once and reuse it for the reset term.
        updated = self._base_state_function(input_)
        return updated - self.reset * updated

    def _base_int(self, input_):
        """State update without any reset (pure leaky integration)."""
        return self._base_state_function(input_)

    def _lapicque_cases(self, time_step, beta, R, C):
        """Register ``time_step``, ``R`` and ``C`` (and ``beta`` where
        applicable) as buffers, deriving whichever value is missing.

        Uses ``beta = exp(-time_step / (R * C))`` and its inverses.

        :raises ValueError: if neither ``beta`` nor both of ``R`` and
            ``C`` are provided.
        """
        # `as_tensor` returns an existing tensor unchanged, so no
        # isinstance guard is required.
        self.register_buffer("time_step", torch.as_tensor(time_step))

        if not self.beta:
            if not (R and C):
                raise ValueError(
                    "Either beta or 2 of beta, R and C must be specified as an "
                    "input argument."
                )
            # Only R and C given: derive beta.
            beta = torch.exp(torch.ones(1) * (-self.time_step / (R * C)))
            self.register_buffer("beta", beta)
            self.register_buffer("R", torch.as_tensor(R))
            self.register_buffer("C", torch.as_tensor(C))

        elif not (R or C):
            # Only beta given: R defaults to 1 and C is inferred. beta is
            # left exactly as the parent class stored it.
            R = torch.as_tensor(1)
            self.register_buffer("R", R)
            C = self.time_step / (R * torch.log(1 / self.beta))
            self.register_buffer("C", C)

        elif R and not C:
            # beta and R given: infer C.
            C = self.time_step / (R * torch.log(1 / self.beta))
            self.register_buffer("C", C)
            self.register_buffer("R", torch.as_tensor(R))
            self.register_buffer("beta", self.beta)

        elif C and not R:
            # beta and C given: infer R.
            C = torch.as_tensor(C)
            self.register_buffer("C", C)
            self.register_buffer("beta", self.beta)
            R = self.time_step / (C * torch.log(1 / self.beta))
            self.register_buffer("R", R)

        else:
            # beta, R and C all given. Previously nothing was registered
            # here, so `forward` failed on the missing `self.R`/`self.C`.
            # Store R and C as provided; beta stays as set by the parent.
            self.register_buffer("R", torch.as_tensor(R))
            self.register_buffer("C", torch.as_tensor(C))

    @classmethod
    def detach_hidden(cls):
        """Returns the hidden states, detached from the current graph.
        Intended for use in truncated backpropagation through time where
        hidden state variables are instance variables."""
        for layer in cls.instances:
            if isinstance(layer, Lapicque):
                layer.mem.detach_()

    @classmethod
    def reset_hidden(cls):
        """Used to clear hidden state variables to zero.
        Intended for use where hidden state variables are instance
        variables."""
        for layer in cls.instances:
            if isinstance(layer, Lapicque):
                layer.mem = torch.zeros_like(
                    layer.mem,
                    device=layer.mem.device,
                )