Source code for snntorch._neurons.slstm

import torch
import torch.nn as nn
from .neurons import SpikingNeuron


class SLSTM(SpikingNeuron):
    """
    A spiking long short-term memory cell.
    Hidden states are membrane potential and synaptic current
    :math:`mem, syn`, which correspond to the hidden and cell
    states :math:`h, c` in the original LSTM formulation.

    The input is expected to be of size :math:`(N, X)` where
    :math:`N` is the batch size.

    Unlike the LSTM module in PyTorch, only one time step is
    simulated each time the cell is called.

    .. math::
            \\begin{array}{ll} \\\\
            i_t = \\sigma(W_{ii} x_t + b_{ii} + W_{hi} mem_{t-1} + b_{hi}) \\\\
            f_t = \\sigma(W_{if} x_t + b_{if} + W_{hf} mem_{t-1} + b_{hf}) \\\\
            g_t = \\tanh(W_{ig} x_t + b_{ig} + W_{hg} mem_{t-1} + b_{hg}) \\\\
            o_t = \\sigma(W_{io} x_t + b_{io} + W_{ho} mem_{t-1} + b_{ho}) \\\\
            syn_t = f_t ∗ syn_{t-1} + i_t ∗ g_t \\\\
            mem_t = o_t ∗ \\tanh(syn_t) \\\\
            \\end{array}

    where :math:`\\sigma` is the sigmoid function and ∗ is the Hadamard
    product.
    The output state :math:`mem_t` is thresholded to determine whether
    an output spike is generated.
    To conform to standard LSTM state behavior, the default reset mechanism
    is set to `reset="none"`, i.e., no reset is applied. If this is changed,
    the reset is only applied to :math:`h_t`.

    Example::

        import torch
        import torch.nn as nn
        import snntorch as snn
        from snntorch import surrogate

        num_steps = 10

        # Define Network
        class Net(nn.Module):
            def __init__(self):
                super().__init__()

                num_inputs = 784
                num_hidden1 = 1000
                num_hidden2 = 10

                spike_grad_lstm = surrogate.straight_through_estimator()

                # initialize layers
                self.slstm1 = snn.SLSTM(
                    num_inputs, num_hidden1, spike_grad=spike_grad_lstm
                )
                self.slstm2 = snn.SLSTM(
                    num_hidden1, num_hidden2, spike_grad=spike_grad_lstm
                )

            def forward(self, x):
                # Initialize hidden states and outputs at t=0
                syn1, mem1 = self.slstm1.init_slstm()
                syn2, mem2 = self.slstm2.init_slstm()

                # Record the final layer
                spk2_rec = []
                mem2_rec = []

                for step in range(num_steps):
                    spk1, syn1, mem1 = self.slstm1(x.flatten(1), syn1, mem1)
                    spk2, syn2, mem2 = self.slstm2(spk1, syn2, mem2)

                    spk2_rec.append(spk2)
                    mem2_rec.append(mem2)

                return torch.stack(spk2_rec), torch.stack(mem2_rec)

    :param input_size: number of expected features in the input :math:`x`
    :type input_size: int

    :param hidden_size: the number of features in the hidden state :math:`mem`
    :type hidden_size: int

    :param bias: If `True`, adds a learnable bias to the output.
        Defaults to `True`
    :type bias: bool, optional

    :param threshold: Threshold for :math:`h` to reach in order to
        generate a spike `S=1`. Defaults to 1
    :type threshold: float, optional

    :param spike_grad: Surrogate gradient for the term dS/dU.
        Defaults to ATan surrogate gradient
    :type spike_grad: surrogate gradient function from snntorch.surrogate,
        optional

    :param surrogate_disable: Disables surrogate gradients regardless of
        `spike_grad` argument. Useful for ONNX compatibility.
        Defaults to False
    :type surrogate_disable: bool, optional

    :param learn_threshold: Option to enable learnable threshold.
        Defaults to False
    :type learn_threshold: bool, optional

    :param init_hidden: Instantiates state variables as instance variables.
        Defaults to False
    :type init_hidden: bool, optional

    :param inhibition: If `True`, suppresses all spiking other than the
        neuron with the highest state. Defaults to False
    :type inhibition: bool, optional

    :param reset_mechanism: Defines the reset mechanism applied to \
        :math:`mem` each time the threshold is met. Reset-by-subtraction: \
        "subtract", reset-to-zero: "zero", none: "none".
        Defaults to "none"
    :type reset_mechanism: str, optional

    :param state_quant: If specified, hidden states :math:`mem` and \
        :math:`syn` are quantized to a valid state for the forward pass. \
        Defaults to False
    :type state_quant: quantization function from snntorch.quant, optional

    :param output: If `True` as well as `init_hidden=True`, states are
        returned when neuron is called. Defaults to False
    :type output: bool, optional

    Inputs: \\input_, syn_0, mem_0
        - **input_** of shape `(batch, input_size)`: tensor containing \
        input features
        - **syn_0** of shape `(batch, hidden_size)`: tensor containing \
        the initial synaptic current (or cell state) for each element in \
        the batch.
        - **mem_0** of shape `(batch, hidden_size)`: tensor containing \
        the initial membrane potential (or hidden state) for each element \
        in the batch.

    Outputs: spk, syn_1, mem_1
        - **spk** of shape `(batch, hidden_size)`: tensor containing the \
        output spike
        - **syn_1** of shape `(batch, hidden_size)`: tensor containing the \
        next synaptic current (or cell state) for each element in the batch
        - **mem_1** of shape `(batch, hidden_size)`: tensor containing the \
        next membrane potential (or hidden state) for each element in the \
        batch

    Learnable Parameters:
        - **SLSTM.lstm_cell.weight_ih** (torch.Tensor) - the learnable \
        input-hidden weights, of shape (4*hidden_size, input_size)
        - **SLSTM.lstm_cell.weight_hh** (torch.Tensor) - the learnable \
        hidden-hidden weights, of shape (4*hidden_size, hidden_size)
        - **SLSTM.lstm_cell.bias_ih** - the learnable input-hidden bias, \
        of shape (4*hidden_size)
        - **SLSTM.lstm_cell.bias_hh** - the learnable hidden-hidden bias, \
        of shape (4*hidden_size)

    """

    def __init__(
        self,
        input_size,
        hidden_size,
        bias=True,
        threshold=1.0,
        spike_grad=None,
        surrogate_disable=False,
        init_hidden=False,
        inhibition=False,
        learn_threshold=False,
        reset_mechanism="none",
        state_quant=False,
        output=False,
    ):
        super().__init__(
            threshold,
            spike_grad,
            surrogate_disable,
            init_hidden,
            inhibition,
            learn_threshold,
            reset_mechanism,
            state_quant,
            output,
        )

        self._init_mem()

        if self.reset_mechanism_val == 0:  # reset by subtraction
            self.state_function = self._base_sub
        elif self.reset_mechanism_val == 1:  # reset to zero
            self.state_function = self._base_zero
        elif self.reset_mechanism_val == 2:  # no reset, pure integration
            self.state_function = self._base_int

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        self.lstm_cell = nn.LSTMCell(
            self.input_size, self.hidden_size, bias=self.bias
        )

    def _init_mem(self):
        syn = torch.zeros(0)
        mem = torch.zeros(0)

        self.register_buffer("syn", syn, False)
        self.register_buffer("mem", mem, False)

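    # Minimal usage sketch (illustrative; the batch size of 128 and the
    # layer sizes are assumed, not defaults). One call to the cell advances
    # a single time step, and the learnable parameters live in the wrapped
    # nn.LSTMCell:
    #
    #   slstm = SLSTM(input_size=784, hidden_size=1000)
    #   x = torch.rand(128, 784)
    #   spk, syn, mem = slstm(x)              # each of shape (128, 1000)
    #   slstm.lstm_cell.weight_ih.shape       # torch.Size([4000, 784])
    #   slstm.lstm_cell.weight_hh.shape       # torch.Size([4000, 1000])
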
    def reset_mem(self):
        self.syn = torch.zeros_like(self.syn, device=self.syn.device)
        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
        return self.syn, self.mem

    def init_slstm(self):
        """Deprecated, use :meth:`SLSTM.reset_mem` instead"""
        return self.reset_mem()

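    # Sketch of the non-deprecated pattern (illustrative; `x` is assumed to
    # be a (num_steps, batch, input_size) tensor): `reset_mem` zeroes the
    # stored states and returns them so they can be threaded through
    # `forward` when `init_hidden=False`:
    #
    #   syn, mem = slstm.reset_mem()
    #   for step in range(num_steps):
    #       spk, syn, mem = slstm(x[step], syn, mem)
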
    def forward(self, input_, syn=None, mem=None):
        if syn is not None:
            self.syn = syn

        if mem is not None:
            self.mem = mem

        if self.init_hidden and (mem is not None or syn is not None):
            raise TypeError(
                "`mem` or `syn` should not be passed as an argument while "
                "`init_hidden=True`"
            )

        size = input_.size()
        correct_shape = (size[0], self.hidden_size)

        # initialize (or re-initialize) states that do not match
        # the expected (batch, hidden_size) shape
        if not self.syn.shape == correct_shape:
            self.syn = torch.zeros(correct_shape, device=self.syn.device)

        if not self.mem.shape == correct_shape:
            self.mem = torch.zeros(correct_shape, device=self.mem.device)

        self.reset = self.mem_reset(self.mem)
        self.syn, self.mem = self.state_function(input_)

        if self.state_quant:
            self.syn = self.state_quant(self.syn)
            self.mem = self.state_quant(self.mem)

        self.spk = self.fire(self.mem)

        if self.output:
            return self.spk, self.syn, self.mem
        elif self.init_hidden:
            return self.spk
        else:
            return self.spk, self.syn, self.mem

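    # Sketch of the two calling conventions dispatched above (illustrative;
    # layer sizes are assumed):
    #
    #   # 1) explicit states (init_hidden=False, the default)
    #   spk, syn, mem = SLSTM(784, 1000)(x, syn, mem)
    #
    #   # 2) states held as instance variables
    #   slstm = SLSTM(784, 1000, init_hidden=True)
    #   spk = slstm(x)                        # spike tensor only
    #   slstm = SLSTM(784, 1000, init_hidden=True, output=True)
    #   spk, syn, mem = slstm(x)              # states returned as well
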
    def _base_state_function(self, input_):
        # nn.LSTMCell returns (h, c); here h is mem and c is syn,
        # so the pair is reordered to (syn, mem)
        base_fn_mem, base_fn_syn = self.lstm_cell(input_, (self.mem, self.syn))
        return base_fn_syn, base_fn_mem

    def _base_state_reset_zero(self, input_):
        # second pass used only to build the reset term for mem;
        # the syn placeholder is 0, so syn is left unchanged in _base_zero
        base_fn_mem, _ = self.lstm_cell(input_, (self.mem, self.syn))
        return 0, base_fn_mem

    def _base_sub(self, input_):
        # reset by subtraction: subtract the threshold from neurons
        # that spiked on the previous step
        syn, mem = self._base_state_function(input_)
        mem -= self.reset * self.threshold
        return syn, mem

    def _base_zero(self, input_):
        # reset to zero: zero out mem for neurons that spiked on the
        # previous step; syn is not reset
        syn, mem = self._base_state_function(input_)
        syn2, mem2 = self._base_state_reset_zero(input_)
        syn2 *= self.reset
        mem2 *= self.reset
        syn -= syn2
        mem -= mem2
        return syn, mem

    def _base_int(self, input_):
        # no reset, pure integration
        return self._base_state_function(input_)

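    # Worked example of the reset arithmetic above (illustrative numbers,
    # threshold=1.0): for a neuron that crossed threshold on the previous
    # step (reset=1) and whose newly computed membrane value is 1.3,
    #   _base_sub:  mem = 1.3 - 1 * 1.0 = 0.3   (reset by subtraction)
    #   _base_zero: mem = 1.3 - 1 * 1.3 = 0.0   (reset to zero; syn unchanged)
    #   _base_int:  mem = 1.3                   (no reset, the default)
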
    @classmethod
    def detach_hidden(cls):
        """Detaches the hidden states :math:`syn` and :math:`mem` from the
        current graph in place.
        Intended for use in truncated backpropagation through time where
        hidden state variables are instance variables."""
        for layer in range(len(cls.instances)):
            if isinstance(cls.instances[layer], SLSTM):
                cls.instances[layer].syn.detach_()
                cls.instances[layer].mem.detach_()

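    # Sketch of truncated BPTT using the class-level helper (illustrative;
    # `net`, `loss_fn`, `optimizer`, `data`, and `targets` are assumptions):
    #
    #   for step in range(num_steps):
    #       spk = net(data[step])             # SLSTM layers use init_hidden=True
    #       loss = loss_fn(spk, targets[step])
    #       loss.backward()
    #       optimizer.step()
    #       optimizer.zero_grad()
    #       SLSTM.detach_hidden()             # cut the graph between steps
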
    @classmethod
    def reset_hidden(cls):
        """Used to clear hidden state variables to zero.
        Intended for use where hidden state variables are
        instance variables."""
        for layer in range(len(cls.instances)):
            if isinstance(cls.instances[layer], SLSTM):
                cls.instances[layer].syn = torch.zeros_like(
                    cls.instances[layer].syn,
                    device=cls.instances[layer].syn.device,
                )
                cls.instances[layer].mem = torch.zeros_like(
                    cls.instances[layer].mem,
                    device=cls.instances[layer].mem.device,
                )

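# Sketch: with `init_hidden=True`, clear every SLSTM instance before
# presenting a new sequence (illustrative; `net` and `data` are assumptions):
#
#   SLSTM.reset_hidden()                      # zero syn and mem in all instances
#   for step in range(data.size(0)):
#       spk = net(data[step])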