snn.SConv2dLSTM

class snntorch._neurons.sconv2dlstm.SConv2dLSTM(in_channels, out_channels, kernel_size, bias=True, max_pool=0, avg_pool=0, threshold=1.0, spike_grad=None, surrogate_disable=False, init_hidden=False, inhibition=False, learn_threshold=False, reset_mechanism='none', state_quant=False, output=False)[source]

Bases: SpikingNeuron

A spiking 2D convolutional long short-term memory cell. Hidden states are membrane potential and synaptic current \(mem, syn\), which correspond to the hidden and cell states \(h, c\) in the original LSTM formulation.

The input is expected to be of size \((N, C_{in}, H_{in}, W_{in})\) where \(N\) is the batch size.

Unlike the LSTM module in PyTorch, only one time step is simulated each time the cell is called.

\[\begin{split}\begin{array}{ll} \\
i_t = \sigma(W_{ii} ⋆ x_t + b_{ii} + W_{hi} ⋆ mem_{t-1} + b_{hi}) \\
f_t = \sigma(W_{if} ⋆ x_t + b_{if} + W_{hf} ⋆ mem_{t-1} + b_{hf}) \\
g_t = \tanh(W_{ig} ⋆ x_t + b_{ig} + W_{hg} ⋆ mem_{t-1} + b_{hg}) \\
o_t = \sigma(W_{io} ⋆ x_t + b_{io} + W_{ho} ⋆ mem_{t-1} + b_{ho}) \\
syn_t = f_t ∗ syn_{t-1} + i_t ∗ g_t \\
mem_t = o_t ∗ \tanh(syn_t) \\
\end{array}\end{split}\]

where \(\sigma\) is the sigmoid function, ⋆ is the 2D cross-correlation operator, and ∗ is the Hadamard product. The output state \(mem_t\) is thresholded to determine whether an output spike is generated. To conform to standard LSTM state behavior, the default reset mechanism is reset_mechanism=”none”, i.e., no reset is applied. If this is changed, the reset is only applied to \(mem_t\).

Options to apply max-pooling or average-pooling to the state \(mem_t\) are also enabled. Note that pooling should be applied to the state rather than to the spikes: pooling binary activations of 1s and 0s would force arbitrary tie-breaking between equal values.

Padding is automatically applied to ensure consistent sizes for hidden states from one time step to the next.

At the moment, stride != 1 is not supported.
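For instance, a minimal single-step sketch (the 28×28 input size is an assumption for illustration; shapes follow the Inputs/Outputs conventions listed further below):

import torch
import snntorch as snn

batch_size, in_channels, H, W = 4, 1, 28, 28
out_channels, kernel_size = 8, 3

sclstm = snn.SConv2dLSTM(in_channels, out_channels, kernel_size, max_pool=2)

x = torch.rand(batch_size, in_channels, H, W)
syn_0 = torch.zeros(batch_size, out_channels, H, W)  # (batch, out_channels, H, W)
mem_0 = torch.zeros(batch_size, out_channels, H, W)

# One call simulates exactly one time step
spk, syn_1, mem_1 = sclstm(x, syn_0, mem_0)
print(spk.shape)    # expected torch.Size([4, 8, 14, 14]): H, W halved by max_pool=2
print(mem_1.shape)  # expected torch.Size([4, 8, 28, 28]): states keep full resolution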

Example:

import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate


# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        in_channels = 1
        out_channels = 8
        kernel_size = 3
        max_pool = 2
        avg_pool = 2
        flattened_input = 49 * 8  # 8 channels * 7 * 7 for 28x28 inputs after two 2x2 pools
        num_outputs = 10
        beta = 0.5

        spike_grad_lstm = surrogate.straight_through_estimator()
        spike_grad_fc = surrogate.fast_sigmoid(slope=5)

        # initialize layers
        self.sclstm1 = snn.SConv2dLSTM(
            in_channels,
            out_channels,
            kernel_size,
            max_pool=max_pool,
            spike_grad=spike_grad_lstm,
        )
        self.sclstm2 = snn.SConv2dLSTM(
            out_channels,
            out_channels,
            kernel_size,
            avg_pool=avg_pool,
            spike_grad=spike_grad_lstm,
        )
        self.fc1 = nn.Linear(flattened_input, num_outputs)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad_fc)

    def forward(self, x):
        # Initialize hidden states and outputs at t=0
        syn1, mem1 = self.sclstm1.init_sconv2dlstm()
        syn2, mem2 = self.sclstm2.init_sconv2dlstm()
        mem3 = self.lif1.init_leaky()

        # Record the final layer
        spk3_rec = []
        mem3_rec = []

        # Number of steps assuming x is [N, T, C, H, W] with
        # N = Batches, T = Time steps, C = Channels,
        # H = Height, W = Width
        num_steps = x.size()[1]

        for step in range(num_steps):
            x_step = x[:, step, :, :, :]
            spk1, syn1, mem1 = self.sclstm1(x_step, syn1, mem1)
            spk2, syn2, mem2 = self.sclstm2(spk1, syn2, mem2)
            cur = self.fc1(spk2.flatten(1))
            spk3, mem3 = self.lif1(cur, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec), torch.stack(mem3_rec)
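A hypothetical way to exercise the network above on random data; 28x28 inputs (e.g., MNIST) match flattened_input = 49 * 8 after the two pooling stages:

net = Net()
x = torch.rand(2, 5, 1, 28, 28)  # [N, T, C, H, W]: 2 samples, 5 time steps
spk_rec, mem_rec = net(x)
print(spk_rec.shape)  # expected torch.Size([5, 2, 10]): [T, N, num_outputs]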
Parameters:
  • in_channels (int) – number of input channels

  • out_channels (int) – number of output channels

  • kernel_size (int, tuple, or list) – Size of the convolving kernel

  • bias (bool, optional) – If True, adds a learnable bias to the output. Defaults to True

  • max_pool (int, tuple, or list, optional) – Applies max-pooling to the hidden state \(mem\) prior to thresholding if specified. Defaults to 0

  • avg_pool (int, tuple, or list, optional) – Applies average-pooling to the hidden state \(mem\) prior to thresholding if specified. Defaults to 0

  • threshold (float, optional) – Threshold for \(mem\) to reach in order to generate a spike S=1. Defaults to 1

  • spike_grad (surrogate gradient function from snntorch.surrogate, optional) – Surrogate gradient for the term dS/dU. Defaults to ATan surrogate gradient

  • surrogate_disable (bool, optional) – Disables surrogate gradients regardless of the spike_grad argument. Useful for ONNX compatibility. Defaults to False

  • learn_threshold (bool, optional) – Option to enable learnable threshold. Defaults to False

  • init_hidden (bool, optional) – Instantiates state variables as instance variables. Defaults to False

  • inhibition (bool, optional) – If True, suppresses all spiking other than the neuron with the highest state. Defaults to False

  • reset_mechanism (str, optional) – Defines the reset mechanism applied to \(mem\) each time the threshold is met. Reset-by-subtraction: “subtract”, reset-to-zero: “zero”, no reset: “none”. Defaults to “none”

  • state_quant (quantization function from snntorch.quant, optional) – If specified, hidden states \(mem\) and \(syn\) are quantized to a valid state for the forward pass. Defaults to False

  • output (bool, optional) – If True as well as init_hidden=True, states are returned when the neuron is called; see the sketch after this list. Defaults to False
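As referenced above, a short sketch of the init_hidden/output pattern: with init_hidden=True the layer keeps syn and mem as instance variables, and output=True makes the call return them alongside the spike (the input size is illustrative):

import torch
import snntorch as snn

sclstm = snn.SConv2dLSTM(1, 8, 3, init_hidden=True, output=True)

x = torch.rand(4, 1, 28, 28)
spk, syn, mem = sclstm(x)  # no explicit state arguments needed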

Inputs: input_, syn_0, mem_0
  • input_ of shape (batch, in_channels, H, W): tensor containing input features

  • syn_0 of shape (batch, out_channels, H, W): tensor containing the initial synaptic current (or cell state) for each element in the batch.

  • mem_0 of shape (batch, out_channels, H, W): tensor containing the initial membrane potential (or hidden state) for each element in the batch.

Outputs: spk, syn_1, mem_1
  • spk of shape (batch, out_channels, H/pool, W/pool): tensor containing the output spike; H and W are scaled down by the pooling kernel size whenever avg_pool or max_pool is greater than 0

  • syn_1 of shape (batch, out_channels, H, W): tensor containing the next synaptic current (or cell state) for each element in the batch

  • mem_1 of shape (batch, out_channels, H, W): tensor containing the next membrane potential (or hidden state) for each element in the batch

Learnable Parameters:
  • SConv2dLSTM.conv.weight (torch.Tensor) - the learnable weights, of shape (4*out_channels, (in_channels + out_channels), kernel_size, kernel_size), following the PyTorch Conv2d weight layout (out_channels, in_channels, kH, kW).
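A quick sanity check of that shape (a sketch; the .conv attribute path follows the name given above):

import snntorch as snn

sclstm = snn.SConv2dLSTM(in_channels=1, out_channels=8, kernel_size=3)
# expected torch.Size([32, 9, 3, 3]) = (4*8, 1 + 8, 3, 3)
print(sclstm.conv.weight.shape)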

classmethod detach_hidden()[source]

Returns the hidden states, detached from the current graph. Intended for use in truncated backpropagation through time where hidden state variables are instance variables.
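A sketch of how this might fit a truncated-BPTT loop with init_hidden=True; truncation_len and the placeholder loss are illustrative, not part of the snnTorch API:

import torch
import snntorch as snn

sclstm = snn.SConv2dLSTM(1, 8, 3, init_hidden=True, output=True)
optimizer = torch.optim.Adam(sclstm.parameters(), lr=1e-3)

x = torch.rand(10, 4, 1, 28, 28)  # [T, N, C, H, W]
truncation_len = 5

for step in range(x.size(0)):
    spk, syn, mem = sclstm(x[step])
    loss = mem.sum()  # placeholder loss
    if (step + 1) % truncation_len == 0:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # detach syn and mem so gradients stop at the truncation boundary
        snn.SConv2dLSTM.detach_hidden()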

forward(input_, syn=None, mem=None)[source]

Performs a single time step of the SConv2dLSTM cell and returns spk, syn_1, mem_1 as described under Outputs above. As with any torch.nn.Module, call the module instance rather than invoking forward() directly, so that registered hooks are run.

init_sconv2dlstm()[source]

Deprecated; use SConv2dLSTM.reset_mem() instead.

classmethod reset_hidden()[source]

Used to clear hidden state variables to zero. Intended for use where hidden state variables are instance variables.

reset_mem()[source]

Used to clear the hidden states \(syn\) and \(mem\) to zero.
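A sketch of migrating from the deprecated initializer; that reset_mem() returns the (syn, mem) pair the way init_sconv2dlstm did is an assumption here:

import snntorch as snn

sclstm = snn.SConv2dLSTM(1, 8, 3)

# previously: syn, mem = sclstm.init_sconv2dlstm()
syn, mem = sclstm.reset_mem()  # assumed to return zeroed (syn, mem)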