snntorch

snnTorch Neurons

snntorch is designed to be intuitively used with PyTorch, as though each spiking neuron were simply another activation in a sequence of layers.

A variety of spiking neuron classes are available, each of which can be treated as an activation unit in PyTorch. Each layer of spiking neurons is therefore agnostic to fully-connected layers, convolutional layers, residual connections, etc.

The neuron models are represented by recursive functions, which removes the need to store membrane potential traces in order to calculate the gradient. The lean requirements of snntorch enable small and large networks to be viably trained on CPU, where needed. Being deeply integrated with torch.autograd, snntorch can take advantage of GPU acceleration in the same way as PyTorch.

By default, PyTorch’s autodifferentiation mechanism in torch.autograd nulls the gradient signal of the spiking neuron graph due to non-differentiable spiking threshold functions. snntorch overrides the default gradient by using snntorch.neurons.Heaviside. Alternative options exist in snntorch.surrogate.
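The effect of overriding the gradient can be sketched in plain PyTorch. The class below is a hypothetical illustration, not snnTorch's own implementation: the forward pass applies the Heaviside step, and the backward pass substitutes a fast-sigmoid-shaped slope. The name HeavisideFastSigmoid and the slope value of 25 are assumptions made for this example.

```python
import torch


class HeavisideFastSigmoid(torch.autograd.Function):
    """Heaviside step in the forward pass, smooth surrogate in the backward pass."""

    @staticmethod
    def forward(ctx, mem_shift, slope):
        # mem_shift is (membrane potential - threshold)
        ctx.save_for_backward(mem_shift)
        ctx.slope = slope
        return (mem_shift > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (mem_shift,) = ctx.saved_tensors
        # Smooth stand-in for the derivative of the step function
        surrogate = 1.0 / (ctx.slope * mem_shift.abs() + 1.0) ** 2
        return grad_output * surrogate, None  # no gradient for slope


threshold = 1.0
mem = torch.tensor([0.2, 0.9, 1.4], requires_grad=True)

spk = HeavisideFastSigmoid.apply(mem - threshold, 25.0)
spk.sum().backward()  # mem.grad is now non-zero despite the step function
```

Only the last neuron spikes, yet every entry of mem.grad is positive, with the largest value for the neuron closest to the threshold. In snnTorch itself, a surrogate is selected through the spike_grad argument rather than by writing such a class by hand.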

At present, the neurons available in snntorch are variants of the Leaky Integrate-and-Fire neuron model:

  • Leaky - 1st-Order Leaky Integrate-and-Fire Neuron

  • RLeaky - As above, with recurrent connections for output spikes

  • Synaptic - 2nd-Order Integrate-and-Fire Neuron (including synaptic conductance)

  • RSynaptic - As above, with recurrent connections for output spikes

  • Lapicque - Lapicque’s RC Neuron Model

  • Alpha - Alpha Membrane Model

Neuron models that accelerate training require passing data in parallel. Available neurons include:

  • LeakyParallel - 1st-Order Leaky Integrate-and-Fire Neuron

Additional models include spiking-LSTMs and spiking-ConvLSTMs:

  • SLSTM - Spiking long short-term memory cell with state-thresholding

  • SConv2dLSTM - Spiking 2d convolutional long short-term memory cell with state-thresholding

How to use snnTorch’s neuron models

The following arguments are common across most neuron models:

  • threshold - firing threshold of the neuron

  • spike_grad - surrogate gradient function (see snntorch.surrogate)

  • init_hidden - setting to True hides all neuron states as instance variables to reduce code complexity

  • inhibition - setting to True enables only the neuron with the highest membrane potential to fire in a dense layer (not for use in convs etc.)

  • learn_beta - setting to True enables the decay rate to be a learnable parameter

  • learn_threshold - setting to True enables the threshold to be a learnable parameter

  • reset_mechanism - options include subtract (reset-by-subtraction), zero (reset-to-zero), and none (no reset mechanism: i.e., leaky integrator neuron)

  • output - if init_hidden=True, the spiking neuron will only return the output spikes. Setting output=True enables the hidden state(s) to be returned as well. Useful when using torch.nn.Sequential.

Leaky integrate-and-fire neuron models also include:

  • beta - decay rate of membrane potential, clipped between 0 and 1 during the forward-pass. Can be a single-valued tensor (same decay for all neurons in a layer), or multi-valued (individual decay rates per neuron in a layer). More complex neurons include additional parameters, such as alpha.
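As a rough illustration of how beta, threshold and reset_mechanism interact, the following pure-Python sketch performs the update of a single 1st-order leaky neuron. The function leaky_step and its update rule are assumptions written for this example; they mirror the argument descriptions above rather than reproduce snnTorch's code.

```python
def leaky_step(cur, mem, beta=0.85, threshold=1.0, reset_mechanism="subtract"):
    """One time step of a hypothetical 1st-order leaky integrate-and-fire neuron."""
    beta = min(max(beta, 0.0), 1.0)          # clip the decay rate to [0, 1]
    reset = 1.0 if mem > threshold else 0.0  # did the previous state cross threshold?

    if reset_mechanism == "subtract":        # reset-by-subtraction
        mem = beta * mem + cur - reset * threshold
    elif reset_mechanism == "zero":          # reset-to-zero
        mem = (beta * mem + cur) * (1.0 - reset)
    elif reset_mechanism == "none":          # leaky integrator, no reset
        mem = beta * mem + cur
    else:
        raise ValueError(f"unknown reset_mechanism: {reset_mechanism}")

    spk = 1.0 if mem > threshold else 0.0
    return spk, mem


# Drive the neuron with a constant input current for a few steps
mem = 0.0
spikes = []
for _ in range(6):
    spk, mem = leaky_step(0.3, mem)
    spikes.append(spk)
```

With these values the membrane potential charges for four steps, crosses the threshold on the fifth, and is pulled back down by the subtractive reset on the sixth.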

Recurrent spiking neuron models, such as snntorch.RLeaky and snntorch.RSynaptic, explicitly pass the output spike back to the input. Such neurons include additional arguments:

  • V - Recurrent weight. Can be a single-valued tensor (same weight across all neurons in a layer), or a multi-valued tensor (individual weights per neuron in a layer).

  • learn_V - defaults to True, which enables V to be a learnable parameter.
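The role of the recurrent weight can be sketched for a single neuron in pure Python. The function rleaky_step below is a hypothetical simplification that assumes one-to-one feedback, where the previous output spike is scaled by V and added to the input current; it is not snnTorch's implementation.

```python
def rleaky_step(cur, spk_prev, mem, beta=0.85, V=0.5, threshold=1.0):
    """One step of a hypothetical leaky neuron with one-to-one spike feedback."""
    reset = 1.0 if mem > threshold else 0.0
    # The previous output spike re-enters the neuron, weighted by V
    mem = beta * mem + cur + V * spk_prev - reset * threshold
    spk = 1.0 if mem > threshold else 0.0
    return spk, mem


def run(V, num_steps=3, cur=0.6):
    """Simulate a few steps and return the list of output spikes."""
    spk, mem, spikes = 0.0, 0.0, []
    for _ in range(num_steps):
        spk, mem = rleaky_step(cur, spk, mem, V=V)
        spikes.append(spk)
    return spikes


with_feedback = run(V=0.5)
without_feedback = run(V=0.0)
```

In this toy setting the positive feedback is enough to make the neuron fire again on the step after its first spike, which does not happen when V is zero.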

Spiking neural networks can be constructed using a combination of the snntorch and torch.nn packages.

Example:

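import torch
import torch.nn as nn
import snntorch as snn

alpha = 0.9  # unused by Leaky; 2nd-order neurons such as Synaptic take it
beta = 0.85  # membrane potential decay rate

num_steps = 100

# example sizes and device; adjust to your data and hardware
num_inputs = 784
num_hidden = 1000
num_outputs = 10
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define Network
class Net(nn.Module):
   def __init__(self):
      super().__init__()

      # initialize layers
      self.fc1 = nn.Linear(num_inputs, num_hidden)
      self.lif1 = snn.Leaky(beta=beta)
      self.fc2 = nn.Linear(num_hidden, num_outputs)
      self.lif2 = snn.Leaky(beta=beta)

   def forward(self, x):
      # initialize the hidden state of each layer
      mem1 = self.lif1.init_leaky()
      mem2 = self.lif2.init_leaky()

      spk2_rec = []  # Record the output trace of spikes
      mem2_rec = []  # Record the output trace of membrane potential

      for step in range(num_steps):
         cur1 = self.fc1(x.flatten(1))
         spk1, mem1 = self.lif1(cur1, mem1)
         cur2 = self.fc2(spk1)
         spk2, mem2 = self.lif2(cur2, mem2)

         spk2_rec.append(spk2)
         mem2_rec.append(mem2)

      return torch.stack(spk2_rec), torch.stack(mem2_rec)

net = Net().to(device)

# placeholder input; substitute a real batch of data
data = torch.rand(batch_size, num_inputs, device=device)

output, mem_rec = net(data)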

In the above example, the hidden state mem must be manually initialized for each layer. This can be avoided by passing init_hidden=True, which instantiates the neuron hidden states automatically.

In some cases (e.g., truncated backprop through time), it might be necessary to perform backward passes before all time steps have completed processing. This requires moving the time step for-loop out of the network and into the training-loop.

An example of this is shown below:

import torch
import torch.nn as nn
import snntorch as snn

num_steps = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lif1 = snn.Leaky(beta=0.9, init_hidden=True) # only returns spk
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) # returns spk and mem if output=True


#  Initialize Network
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 1000),
                    lif1,
                    nn.Linear(1000, 10),
                    lif2).to(device)

# placeholder input; substitute a real batch of data
data = torch.rand(128, 1, 28, 28, device=device)

for step in range(num_steps):
   spk_out, mem_out = net(data)

Setting the hidden states to instance variables is necessary for calling nn.Sequential from PyTorch.

Whenever a neuron is instantiated, it is added as a list item to the class variable SpikingNeuron.instances (documented under Neuron Parent Classes). This allows you to keep track of which neurons are being used in the network, and to detach neurons from the computation graph.
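The idea of detaching hidden states between backward passes can be sketched in plain PyTorch, without snnTorch. Everything below is a hypothetical setup: the layer sizes, the window length k, the random data and the mean-squared-error objective on the membrane potential are assumptions, and spiking is omitted for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

fc = nn.Linear(4, 2)
optimizer = torch.optim.SGD(fc.parameters(), lr=0.1)

beta = 0.9
num_steps, k = 20, 5          # run a backward pass every k time steps

data = torch.rand(8, 4)       # placeholder batch
target = torch.ones(8, 2)     # placeholder target for the membrane potential
mem = torch.zeros(8, 2)       # hidden state carried across time steps

for step in range(num_steps):
    mem = beta * mem + fc(data)          # leaky integration only

    if (step + 1) % k == 0:
        loss = ((mem - target) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()                  # gradient flows through the last k steps only
        optimizer.step()
        mem = mem.detach()               # keep the value, drop the graph
```

Without the detach call, the second backward pass would try to traverse the graph of the first window, which has already been freed.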

In the above examples, the decay rate of membrane potential beta is treated as a hyperparameter. But it can also be configured as a learnable parameter, as shown below:

import torch
import torch.nn as nn
import snntorch as snn

num_steps = 100
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lif1 = snn.Leaky(beta=0.9, learn_beta=True, init_hidden=True) # only returns spk
lif2 = snn.Leaky(beta=0.5, learn_beta=True, init_hidden=True, output=True) # returns spk and mem if output=True


#  Initialize Network
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 1000),
                    lif1,
                    nn.Linear(1000, 10),
                    lif2).to(device)

# placeholder input; substitute a real batch of data
data = torch.rand(batch_size, 1, 28, 28, device=device)

for step in range(num_steps):
   spk_out, mem_out = net(data.view(batch_size, -1))

Here, beta is initialized to 0.9 for the first layer, and 0.5 for the second layer. Each layer then treats it as a learnable parameter, just like all the other network weights. In the event you wish to have a learnable decay rate for each neuron rather than each layer, the following example shows how:

import torch
import torch.nn as nn
import snntorch as snn

num_steps = 100
num_hidden = 1000
num_output = 10
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

beta1 = torch.rand(num_hidden)  # randomly initialize beta as a vector
beta2 = torch.rand(num_output)

lif1 = snn.Leaky(beta=beta1, learn_beta=True, init_hidden=True) # only returns spk
lif2 = snn.Leaky(beta=beta2, learn_beta=True, init_hidden=True, output=True) # returns spk and mem if output=True

#  Initialize Network
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, num_hidden),
                    lif1,
                    nn.Linear(num_hidden, num_output),
                    lif2).to(device)

# placeholder input; substitute a real batch of data
data = torch.rand(batch_size, 1, 28, 28, device=device)

for step in range(num_steps):
   spk_out, mem_out = net(data.view(batch_size, -1))

The same approach as above can be used for implementing learnable thresholds, using learn_threshold=True.

Each neuron has the option to inhibit other neurons within the same dense layer from firing. This can be invoked by setting inhibition=True when instantiating the neuron layer. It has not yet been implemented for networks other than fully-connected layers, so use with caution.
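The behaviour described above amounts to a winner-take-all rule within a layer. A minimal pure-Python sketch for one sample is given below; the function fire_with_inhibition is a hypothetical helper written for this example, not part of snnTorch.

```python
def fire_with_inhibition(mem, threshold=1.0):
    """Let only the neuron with the largest membrane potential spike."""
    winner = max(range(len(mem)), key=lambda i: mem[i])
    return [
        1.0 if (i == winner and m > threshold) else 0.0
        for i, m in enumerate(mem)
    ]


# Two neurons exceed the threshold, but only the larger one fires
spk = fire_with_inhibition([0.4, 1.3, 1.1])

# No neuron exceeds the threshold, so no neuron fires
silent = fire_with_inhibition([0.2, 0.5])
```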

Neuron List

snnTorch Layers

snntorch._layers.bntt.BatchNormTT1d(input_features, time_steps, eps=1e-05, momentum=0.1, affine=True)[source]

Generate a torch.nn.ModuleList of 1D Batch Normalization Layers with length time_steps. Input to this layer is the same as the vanilla torch.nn.BatchNorm1d layer.

Batch Normalisation Through Time (BNTT) as presented in: ‘Revisiting Batch Normalization for Training Low-Latency Deep Spiking Neural Networks From Scratch’ by Youngeun Kim & Priyadarshini Panda, arXiv preprint arXiv:2010.01729.

Original GitHub repo: https://github.com/Intelligent-Computing-Lab-Yale/BNTT-Batch-Normalization-Through-Time

The math shown below uses the LIF neuron as the neuron of choice.

Typically, for a single post-synaptic neuron i, we can represent its membrane potential \(U_{i}^{t}\) at time-step t as:

\[U_{i}^{t} = \lambda U_{i}^{t-1} + \sum_j w_{ij}S_{j}^{t}\]

where:

  • λ - a leak factor which is less than one

  • j - the index of the pre-synaptic neuron

  • \(S_{j}\) - the binary spike activation

  • \(w_{ij}\) - the weight of the connection between the pre & post neurons.

With Batch Normalization Through Time, the membrane potential can be modeled as:

\[\begin{aligned}U_{i}^{t} &= \lambda U_{i}^{t-1} + \mathrm{BNTT}_{\gamma^{t}}\left(\sum_j w_{ij}S_{j}^{t}\right)\\ &= \lambda U_{i}^{t-1} + \gamma_{i}^{t}\left(\frac{\sum_j w_{ij}S_{j}^{t} - \mu_{i}^{t}}{\sqrt{(\sigma_{i}^{t})^{2} + \epsilon}}\right)\end{aligned}\]
Parameters:
  • input_features (int) – number of features of the input

  • time_steps (int) – number of time-steps of the SNN

  • eps (float) – a value added to the denominator for numerical stability

  • momentum (float) – the value used for the running_mean and running_var computation

  • affine (bool) – a boolean value that when set to True, the Batch Norm layer will have learnable affine parameters

Inputs: input_features, time_steps
  • input_features: same number of features as the input

  • time_steps: the number of time-steps to unroll in the SNN

Outputs: bntt
  • bntt of shape (time_steps): torch.nn.ModuleList of BatchNorm1d layers for the specified number of time-steps
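The structure described above, one batch-normalization layer per time step, can be sketched with plain PyTorch as follows. This is an illustration under assumptions rather than the snnTorch source: the list is built by hand, the sizes and the random input are placeholders, and the layout [time, batch, features] is chosen for this example.

```python
import torch
import torch.nn as nn

input_features, time_steps, batch_size = 16, 5, 4

# One independent BatchNorm1d layer per time step
bntt = nn.ModuleList(
    [
        nn.BatchNorm1d(input_features, eps=1e-05, momentum=0.1, affine=True)
        for _ in range(time_steps)
    ]
)

fc = nn.Linear(32, input_features)
data = torch.rand(time_steps, batch_size, 32)  # placeholder input

outputs = []
for step in range(time_steps):
    cur = fc(data[step])
    # Normalize the input current with the statistics of this time step only
    outputs.append(bntt[step](cur))

out = torch.stack(outputs)
```

Because each time step indexes its own layer, the learnable scale and the running statistics are kept separate across time. The normalized current would then be passed to a spiking neuron at each step.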

snntorch._layers.bntt.BatchNormTT2d(input_features, time_steps, eps=1e-05, momentum=0.1, affine=True)[source]

Generate a torch.nn.ModuleList of 2D Batch Normalization Layers with length time_steps. Input to this layer is the same as the vanilla torch.nn.BatchNorm2d layer.

Batch Normalisation Through Time (BNTT) as presented in: ‘Revisiting Batch Normalization for Training Low-Latency Deep Spiking Neural Networks From Scratch’ by Youngeun Kim & Priyadarshini Panda, arXiv preprint arXiv:2010.01729.

The math shown below uses the LIF neuron as the neuron of choice.

Typically, for a single post-synaptic neuron i, we can represent its membrane potential \(U_{i}^{t}\) at time-step t as:

\[U_{i}^{t} = \lambda U_{i}^{t-1} + \sum_j w_{ij}S_{j}^{t}\]

where:

  • λ - a leak factor which is less than one

  • j - the index of the pre-synaptic neuron

  • \(S_{j}\) - the binary spike activation

  • \(w_{ij}\) - the weight of the connection between the pre & post neurons.

With Batch Normalization Through Time, the membrane potential can be modeled as:

\[\begin{aligned}U_{i}^{t} &= \lambda U_{i}^{t-1} + \mathrm{BNTT}_{\gamma^{t}}\left(\sum_j w_{ij}S_{j}^{t}\right)\\ &= \lambda U_{i}^{t-1} + \gamma_{i}^{t}\left(\frac{\sum_j w_{ij}S_{j}^{t} - \mu_{i}^{t}}{\sqrt{(\sigma_{i}^{t})^{2} + \epsilon}}\right)\end{aligned}\]
Parameters:
  • input_features (int) – number of channels of the input

  • time_steps (int) – number of time-steps of the SNN

  • eps (float) – a value added to the denominator for numerical stability

  • momentum (float) – the value used for the running_mean and running_var computation

  • affine (bool) – a boolean value that when set to True, the Batch Norm layer will have learnable affine parameters

Inputs: input_features, time_steps
  • input_features: same number of channels as the input

  • time_steps: the number of time-steps to unroll in the SNN

Outputs: bntt
  • bntt of shape (time_steps): torch.nn.ModuleList of BatchNorm2d layers for the specified number of time-steps

Neuron Parent Classes

class snntorch._neurons.neurons.LIF(beta, threshold=1.0, spike_grad=None, surrogate_disable=False, init_hidden=False, inhibition=False, learn_beta=False, learn_threshold=False, reset_mechanism='subtract', state_quant=False, output=False, graded_spikes_factor=1.0, learn_graded_spikes_factor=False)[source]

Bases: SpikingNeuron

Parent class for leaky integrate and fire neuron models.

class snntorch._neurons.neurons.SpikingNeuron(threshold=1.0, spike_grad=None, surrogate_disable=False, init_hidden=False, inhibition=False, learn_threshold=False, reset_mechanism='subtract', state_quant=False, output=False, graded_spikes_factor=1.0, learn_graded_spikes_factor=False)[source]

Bases: Module

Parent class for spiking neuron models.

static detach(*args)[source]

Used to detach input arguments from the current graph. Intended for use in truncated backpropagation through time where hidden state variables are global variables.

fire(mem)[source]

Generates spike if mem > threshold. Returns spk.

fire_inhibition(batch_size, mem)[source]

Generates spike if mem > threshold, only for the neuron with the largest membrane potential. All other neurons will be inhibited for that time step. Returns spk.

classmethod init()[source]

Removes all items from snntorch.SpikingNeuron.instances when called.

instances = []

Each snntorch.SpikingNeuron neuron (e.g., snntorch.Synaptic) will populate the snntorch.SpikingNeuron.instances list with a new entry. The list is used to initialize and clear neuron states when the argument init_hidden=True.
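The registry pattern described here can be illustrated with a small stand-alone class. The class Neuron below is a hypothetical stand-in written for this example; it only shows how a class-level list collects instances and how a classmethod clears it.

```python
class Neuron:
    """Hypothetical stand-in for a neuron class with an instance registry."""

    instances = []  # class variable shared by every instance

    def __init__(self, name):
        self.name = name
        Neuron.instances.append(self)  # register on construction

    @classmethod
    def init(cls):
        """Remove all registered neurons."""
        cls.instances.clear()


lif1 = Neuron("lif1")
lif2 = Neuron("lif2")
registered = [n.name for n in Neuron.instances]

Neuron.init()
remaining = len(Neuron.instances)
```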

mem_reset(mem)[source]

Generates detached reset signal if mem > threshold. Returns reset.

reset_dict = {'none': 2, 'subtract': 0, 'zero': 1}
property reset_mechanism

If reset_mechanism is modified, reset_mechanism_val is triggered to update. 0: subtract, 1: zero, 2: none.
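One way to keep a string option and its integer code in sync is a property with a setter. The class ResetConfig below is a hypothetical sketch of that pattern, reusing the mapping from reset_dict; it is not the snnTorch implementation.

```python
class ResetConfig:
    """Hypothetical holder that maps a reset mechanism name to its code."""

    reset_dict = {"subtract": 0, "zero": 1, "none": 2}

    def __init__(self, reset_mechanism="subtract"):
        self.reset_mechanism = reset_mechanism  # routed through the setter

    @property
    def reset_mechanism(self):
        return self._reset_mechanism

    @reset_mechanism.setter
    def reset_mechanism(self, name):
        if name not in self.reset_dict:
            raise ValueError(f"reset_mechanism must be one of {list(self.reset_dict)}")
        self._reset_mechanism = name
        # Update the integer code whenever the name changes
        self.reset_mechanism_val = self.reset_dict[name]


cfg = ResetConfig()
initial_val = cfg.reset_mechanism_val

cfg.reset_mechanism = "zero"
updated_val = cfg.reset_mechanism_val
```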

static zeros(*args)[source]

Used to clear hidden state variables to zero. Intended for use where hidden state variables are global variables.