snnTorch Neurons
snntorch is designed to be used intuitively with PyTorch, as though each spiking neuron were simply another activation in a sequence of layers.
A variety of spiking neuron classes are available, each of which can simply be treated as an activation unit with PyTorch. Each layer of spiking neurons is therefore agnostic to fully-connected layers, convolutional layers, residual connections, etc.
The neuron models are represented by recursive functions, which remove the need to store membrane potential traces in order to calculate the gradient.
The lean requirements of snntorch enable small and large networks to be viably trained on CPU, where needed.
Being deeply integrated with torch.autograd, snntorch is able to take advantage of GPU acceleration in the same way as PyTorch.
By default, PyTorch’s autodifferentiation mechanism in torch.autograd nulls the gradient signal of the spiking neuron graph due to the non-differentiable spiking threshold function. snntorch overrides the default gradient using snntorch.neurons.Heaviside. Alternative surrogate gradient options exist in snntorch.surrogate.
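For example, a minimal sketch of overriding the default gradient with a surrogate from snntorch.surrogate (the slope value is illustrative):

import snntorch as snn
from snntorch import surrogate

# replace the default Heaviside backward pass with a fast sigmoid surrogate
spike_grad = surrogate.fast_sigmoid(slope=25)
lif = snn.Leaky(beta=0.9, spike_grad=spike_grad)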
At present, the neurons available in snntorch are variants of the Leaky Integrate-and-Fire neuron model:
Leaky - 1st-Order Leaky Integrate-and-Fire Neuron
RLeaky - As above, with recurrent connections for output spikes
Synaptic - 2nd-Order Integrate-and-Fire Neuron (including synaptic conductance)
RSynaptic - As above, with recurrent connections for output spikes
Lapicque - Lapicque’s RC Neuron Model
Alpha - Alpha Membrane Model
Neuron models that accelerate training require passing data in parallel; a brief sketch follows the list. Available neurons include:
LeakyParallel - 1st-Order Leaky Integrate-and-Fire Neuron
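As a hedged sketch of the parallel API (assuming the (num_steps, batch, features) input layout of the underlying torch.nn.RNN; sizes are illustrative):

import torch
import snntorch as snn

# the whole sequence is processed in one call, with no Python loop over time
lif_par = snn.LeakyParallel(input_size=784, hidden_size=128)
spk = lif_par(torch.rand(100, 16, 784))  # (num_steps, batch, features)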
Additional models include spiking LSTMs and spiking ConvLSTMs (a brief sketch follows the list):
SLSTM - Spiking long short-term memory cell with state-thresholding
SConv2dLSTM - Spiking 2D convolutional long short-term memory cell with state-thresholding
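A brief sketch of instantiating a spiking LSTM cell with the explicit-state API (sizes are illustrative):

import torch
import snntorch as snn

slstm = snn.SLSTM(input_size=100, hidden_size=512)
syn, mem = slstm.init_slstm()

x = torch.rand(16, 100)  # one time step of input
spk, syn, mem = slstm(x, syn, mem)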
How to use snnTorch’s neuron models
The following arguments are common across most neuron models (a combined example follows the list):
threshold - firing threshold of the neuron
spike_grad - surrogate gradient function (see snntorch.surrogate)
init_hidden - setting to True hides all neuron states as instance variables to reduce code complexity
inhibition - setting to True enables only the neuron with the highest membrane potential to fire in a dense layer (not for use in convolutional layers, etc.)
learn_beta - setting to True enables the decay rate to be a learnable parameter
learn_threshold - setting to True enables the threshold to be a learnable parameter
reset_mechanism - options include subtract (reset-by-subtraction), zero (reset-to-zero), and none (no reset mechanism, i.e., a leaky integrator neuron)
output - if init_hidden=True, the spiking neuron will only return the output spikes. Setting output=True enables the hidden state(s) to be returned as well. Useful when using torch.nn.Sequential.
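A minimal sketch combining several of these arguments (the values are illustrative, not defaults):

import snntorch as snn
from snntorch import surrogate

lif = snn.Leaky(beta=0.9,                             # decay rate
                threshold=0.5,                        # firing threshold
                spike_grad=surrogate.fast_sigmoid(),  # surrogate gradient
                learn_beta=True,                      # decay rate is trainable
                reset_mechanism="zero")               # reset membrane to zero after a spike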
Leaky integrate-and-fire neuron models also include:
beta - decay rate of the membrane potential, clipped between 0 and 1 during the forward-pass. Can be a single-valued tensor (same decay rate for all neurons in a layer) or multi-valued (an individual decay rate per neuron in a layer). More complex neurons include additional parameters, such as alpha.
Recurrent spiking neuron models, such as snntorch.RLeaky and snntorch.RSynaptic, explicitly pass the output spike back to the input.
Such neurons include additional arguments (an example follows the list):
V - recurrent weight. Can be a single-valued tensor (same weight across all neurons in a layer) or a multi-valued tensor (an individual weight per neuron in a layer).
learn_V - defaults to True, which enables V to be a learnable parameter.
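For example, a hedged sketch of a one-to-one recurrent leaky neuron (the all_to_all flag and state initializer are assumed from the snntorch.RLeaky API; values are illustrative):

import snntorch as snn

# each neuron feeds its own output spike back, scaled by V (learnable by default)
rlif = snn.RLeaky(beta=0.9, V=0.5, all_to_all=False)
spk, mem = rlif.init_rleaky()
# per time step: spk, mem = rlif(cur, spk, mem)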
Spiking neural networks can be constructed using a combination of the snntorch and torch.nn packages.
Example:
import torch
import torch.nn as nn
import snntorch as snn

beta = 0.85
num_steps = 100
num_inputs = 784
num_hidden = 1000
num_outputs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        # manually initialize the membrane potential of each leaky layer
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        spk2_rec = []  # Record the output trace of spikes
        mem2_rec = []  # Record the output trace of membrane potential

        for step in range(num_steps):
            cur1 = self.fc1(x.flatten(1))
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

data = torch.rand(128, num_inputs)  # placeholder input batch
net = Net().to(device)
output, mem_rec = net(data.to(device))
In the above example, the hidden state mem must be manually initialized for each layer. This can be overcome by automatically instantiating neuron hidden states by invoking init_hidden=True.
In some cases (e.g., truncated backpropagation through time), it might be necessary to perform backward passes before all time steps have completed processing. This requires moving the time-step for-loop out of the network and into the training loop.
An example of this is shown below:
import torch
import torch.nn as nn
import snntorch as snn

num_steps = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.rand(128, 1, 28, 28)  # placeholder input batch

lif1 = snn.Leaky(beta=0.9, init_hidden=True)  # only returns spk
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)  # returns mem and spk if output=True

# Initialize Network
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 1000),
                    lif1,
                    nn.Linear(1000, 10),
                    lif2).to(device)

for step in range(num_steps):
    spk_out, mem_out = net(data.to(device))
Setting the hidden states to instance variables is necessary for calling PyTorch's nn.Sequential. Whenever a neuron is instantiated, it is added as a list item to the class variable LIF.instances. This allows you to keep track of which neurons are being used in the network, and to detach neurons from the computation graph; a sketch follows.
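A hedged sketch of a truncated-BPTT-style update using explicitly-held states (fc, lif, data_in, and targets are illustrative placeholders): weights are updated every trunc steps, and the membrane potential is detached at each boundary so that gradients do not flow further back in time.

import torch
import torch.nn as nn
import snntorch as snn

num_steps = 100
trunc = 25  # truncation window length

fc = nn.Linear(784, 10)
lif = snn.Leaky(beta=0.9)
mem = lif.init_leaky()

optimizer = torch.optim.Adam(fc.parameters(), lr=2e-3)
loss_fn = nn.CrossEntropyLoss()

data_in = torch.rand(128, 784)          # placeholder input batch
targets = torch.randint(0, 10, (128,))  # placeholder labels

for step in range(num_steps):
    spk, mem = lif(fc(data_in), mem)
    if (step + 1) % trunc == 0:
        loss = loss_fn(mem, targets)  # loss on this window only
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        mem = mem.detach()  # truncate the graph at the window boundary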
In the above examples, the decay rate of membrane potential beta is treated as a hyperparameter. But it can also be configured as a learnable parameter, as shown below:
import torch
import torch.nn as nn
import snntorch as snn

num_steps = 100
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.rand(batch_size, 784)  # placeholder input batch

lif1 = snn.Leaky(beta=0.9, learn_beta=True, init_hidden=True)  # only returns spk
lif2 = snn.Leaky(beta=0.5, learn_beta=True, init_hidden=True, output=True)  # returns mem and spk if output=True

# Initialize Network
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 1000),
                    lif1,
                    nn.Linear(1000, 10),
                    lif2).to(device)

for step in range(num_steps):
    spk_out, mem_out = net(data.view(batch_size, -1).to(device))
Here, beta is initialized to 0.9 for the first layer and 0.5 for the second layer. Each layer then treats it as a learnable parameter, just like all the other network weights.
In the event you wish to have a learnable decay rate for each neuron rather than each layer, the following example shows how:
import torch
import torch.nn as nn
import snntorch as snn

num_steps = 100
num_hidden = 1000
num_output = 10
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.rand(batch_size, 784)  # placeholder input batch

beta1 = torch.rand(num_hidden)  # randomly initialize beta as a vector
beta2 = torch.rand(num_output)

lif1 = snn.Leaky(beta=beta1, learn_beta=True, init_hidden=True)  # only returns spk
lif2 = snn.Leaky(beta=beta2, learn_beta=True, init_hidden=True, output=True)  # returns mem and spk if output=True

# Initialize Network
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, num_hidden),
                    lif1,
                    nn.Linear(num_hidden, num_output),
                    lif2).to(device)

for step in range(num_steps):
    spk_out, mem_out = net(data.view(batch_size, -1).to(device))
The same approach as above can be used to implement learnable thresholds, using learn_threshold=True.
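For instance, a minimal sketch of per-neuron learnable thresholds (assuming threshold accepts a multi-valued tensor in the same way beta does):

import torch
import snntorch as snn

num_hidden = 1000
lif1 = snn.Leaky(beta=0.9, threshold=torch.rand(num_hidden),
                 learn_threshold=True, init_hidden=True)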
Each neuron has the option to inhibit other neurons within the same dense layer from firing.
This can be invoked by setting inhibition=True when instantiating the neuron layer. It has not yet been implemented for networks other than fully-connected layers, so use with caution.
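A minimal sketch (the decay rate is illustrative):

import snntorch as snn

# at each time step, only the neuron with the highest membrane potential
# in this dense layer is permitted to fire
lif = snn.Leaky(beta=0.9, inhibition=True)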
Neuron List
snnTorch Layers
- snntorch._layers.bntt.BatchNormTT1d(input_features, time_steps, eps=1e-05, momentum=0.1, affine=True)[source]
Generates a torch.nn.ModuleList of 1D batch-normalization layers of length time_steps. The input to this layer is the same as for the vanilla torch.nn.BatchNorm1d layer.
Batch Normalization Through Time (BNTT) as presented in: ‘Revisiting Batch Normalization for Training Low-Latency Deep Spiking Neural Networks From Scratch’ by Youngeun Kim & Priyadarshini Panda, arXiv preprint arXiv:2010.01729.
Original GitHub repo: https://github.com/Intelligent-Computing-Lab-Yale/BNTT-Batch-Normalization-Through-Time
Using the LIF neuron as the neuron of choice for the math shown below.
Typically, for a single post-synaptic neuron i, we can represent its membrane potential \(U_{i}^{t}\) at time step t as:

\[U_{i}^{t} = \lambda U_{i}^{t-1} + \sum_j w_{ij}S_{j}^{t}\]

where:

λ - a leak factor which is less than one
j - the index of the pre-synaptic neuron
\(S_{j}\) - the binary spike activation
\(w_{ij}\) - the weight of the connection between the pre- and post-synaptic neurons
With Batch Normalization Through Time, the membrane potential can be modeled as:

\[\begin{aligned}U_{i}^{t} &= \lambda U_{i}^{t-1} + \mathrm{BNTT}_{\gamma^{t}}\\ &= \lambda U_{i}^{t-1} + \gamma_{i}^{t} \left(\frac{\sum_j w_{ij}S_{j}^{t} - \mu_{i}^{t}}{\sqrt{(\sigma_{i}^{t})^{2} + \epsilon}}\right)\end{aligned}\]

- Parameters:
input_features (int) – number of features of the input
time_steps (int) – number of time-steps of the SNN
eps (float) – a value added to the denominator for numerical stability
momentum (float) – the value used for the running_mean and running_var computation
affine (bool) – a boolean value that when set to True, the Batch Norm layer will have learnable affine parameters
- Inputs: input_features, time_steps
input_features: same number of features as the input
time_steps: the number of time-steps to unroll in the SNN
- Outputs: bntt
bntt of shape (time_steps): torch.nn.ModuleList of BatchNorm1d layers for the specified number of time-steps
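A hedged usage sketch, assuming the module path documented above (each time step indexes its own BatchNorm1d layer; sizes are illustrative):

import torch
import torch.nn as nn
from snntorch._layers.bntt import BatchNormTT1d

num_steps = 25
fc = nn.Linear(784, 128)
bntt = BatchNormTT1d(128, num_steps)  # one BatchNorm1d layer per time step

x = torch.rand(16, 784)
for step in range(num_steps):
    cur = bntt[step](fc(x))  # select the layer matching the current time step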
- snntorch._layers.bntt.BatchNormTT2d(input_features, time_steps, eps=1e-05, momentum=0.1, affine=True)[source]
Generates a torch.nn.ModuleList of 2D batch-normalization layers of length time_steps. The input to this layer is the same as for the vanilla torch.nn.BatchNorm2d layer.
Batch Normalization Through Time (BNTT) as presented in: ‘Revisiting Batch Normalization for Training Low-Latency Deep Spiking Neural Networks From Scratch’ by Youngeun Kim & Priyadarshini Panda, arXiv preprint arXiv:2010.01729.
Using the LIF neuron as the neuron of choice for the math shown below.
Typically, for a single post-synaptic neuron i, we can represent its membrane potential \(U_{i}^{t}\) at time step t as:

\[U_{i}^{t} = \lambda U_{i}^{t-1} + \sum_j w_{ij}S_{j}^{t}\]

where:

λ - a leak factor which is less than one
j - the index of the pre-synaptic neuron
\(S_{j}\) - the binary spike activation
\(w_{ij}\) - the weight of the connection between the pre- and post-synaptic neurons
With Batch Normalization Through Time, the membrane potential can be modeled as:

\[\begin{aligned}U_{i}^{t} &= \lambda U_{i}^{t-1} + \mathrm{BNTT}_{\gamma^{t}}\\ &= \lambda U_{i}^{t-1} + \gamma_{i}^{t} \left(\frac{\sum_j w_{ij}S_{j}^{t} - \mu_{i}^{t}}{\sqrt{(\sigma_{i}^{t})^{2} + \epsilon}}\right)\end{aligned}\]

- Parameters:
input_features (int) – number of channels of the input
time_steps (int) – number of time-steps of the SNN
eps (float) – a value added to the denominator for numerical stability
momentum (float) – the value used for the running_mean and running_var computation
affine (bool) – a boolean value that when set to True, the Batch Norm layer will have learnable affine parameters
- Inputs: input_features, time_steps
input_features: same number of channels as the input
time_steps: the number of time-steps to unroll in the SNN
- Outputs: bntt
bntt of shape (time_steps): torch.nn.ModuleList of BatchNorm2d layers for the specified number of time-steps
Neuron Parent Classes
- class snntorch._neurons.neurons.LIF(beta, threshold=1.0, spike_grad=None, surrogate_disable=False, init_hidden=False, inhibition=False, learn_beta=False, learn_threshold=False, reset_mechanism='subtract', state_quant=False, output=False, graded_spikes_factor=1.0, learn_graded_spikes_factor=False)[source]
Bases: SpikingNeuron
Parent class for leaky integrate and fire neuron models.
- class snntorch._neurons.neurons.SpikingNeuron(threshold=1.0, spike_grad=None, surrogate_disable=False, init_hidden=False, inhibition=False, learn_threshold=False, reset_mechanism='subtract', state_quant=False, output=False, graded_spikes_factor=1.0, learn_graded_spikes_factor=False)[source]
Bases: Module
Parent class for spiking neuron models.
- static detach(*args)[source]
Used to detach input arguments from the current graph. Intended for use in truncated backpropagation through time where hidden state variables are global variables.
- fire_inhibition(batch_size, mem)[source]
Generates a spike if mem > threshold, but only for the neuron with the largest membrane potential; all other neurons are inhibited for that time step. Returns spk.
- instances = []
Each snntorch.SpikingNeuron neuron (e.g., snntorch.Synaptic) will populate the snntorch.SpikingNeuron.instances list with a new entry. The list is used to initialize and clear neuron states when the argument init_hidden=True.
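A minimal sketch of inspecting the registry, using the module path documented above:

import snntorch as snn
from snntorch._neurons.neurons import SpikingNeuron

lif = snn.Leaky(beta=0.9)
syn = snn.Synaptic(alpha=0.9, beta=0.8)
print(len(SpikingNeuron.instances))  # one entry per instantiated neuron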
- reset_dict = {'none': 2, 'subtract': 0, 'zero': 1}
- property reset_mechanism
If reset_mechanism is modified, reset_mechanism_val is triggered to update. 0: subtract, 1: zero, 2: none.
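A minimal sketch of the mapping (per reset_dict above):

import snntorch as snn

lif = snn.Leaky(beta=0.9)     # default reset_mechanism="subtract" -> reset_mechanism_val 0
lif.reset_mechanism = "zero"  # the property setter updates reset_mechanism_val to 1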