Quickstart

Tutorial written by Jason K. Eshraghian (www.jasoneshraghian.com)

For a comprehensive overview of how SNNs work and what is going on under the hood, you might be interested in the snnTorch tutorial series available here. The tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:

Note

This tutorial is a static, non-editable version. Interactive, editable versions are available to run online (e.g., in Google Colab).

Install snnTorch from PyPI, then import the required packages:

pip install snntorch

import torch, torch.nn as nn
import snntorch as snn

Data Loading

Define variables for dataloading.

batch_size = 128
data_path='/tmp/data/mnist'
# use a GPU (CUDA or Apple MPS) if available, otherwise fall back to the CPU
device = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cpu")
)

Load MNIST dataset.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)
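
As a quick sanity check (not part of the original workflow), you can pull a single batch from the loader and confirm its dimensions:

# grab one batch and inspect its shape
data, targets = next(iter(train_loader))
print(data.shape)     # expected: torch.Size([128, 1, 28, 28])
print(targets.shape)  # expected: torch.Size([128])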

Define Network with snnTorch.

  • snn.Leaky() instantiates a simple leaky integrate-and-fire neuron.

  • spike_grad optionally defines the surrogate gradient. If left undefined, the relevant gradient term is simply set to the output spike itself (1/0) by default.

By default, each LIF neuron returns two values: the spike and the hidden state. But layers chained together in nn.Sequential only pass a single tensor from one module to the next. To handle this:

  • init_hidden initializes the hidden states (e.g., membrane potential) as instance variables to be processed in the background.

The final layer is not bound by this constraint, and can return multiple tensors:

  • output=True enables the final layer to return the hidden state in addition to the spike (see the short sketch after this list).
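
As a minimal sketch of this return-value behavior (standalone toy neurons with dummy inputs, not part of the network defined below):

# a standalone Leaky neuron returns both the spike and the membrane potential
lif_plain = snn.Leaky(beta=0.9)
mem = lif_plain.init_leaky()                     # explicit hidden state
spk, mem = lif_plain(torch.rand(1, 10), mem)     # two return values

# init_hidden=True: the hidden state is tracked internally and only the
# spike is returned, so the neuron can sit inside nn.Sequential
lif_hidden = snn.Leaky(beta=0.9, init_hidden=True)
spk = lif_hidden(torch.rand(1, 10))              # one return value

# output=True: the final layer returns the hidden state as well as the spike
lif_out = snn.Leaky(beta=0.9, init_hidden=True, output=True)
spk, mem = lif_out(torch.rand(1, 10))            # two return values again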

from snntorch import surrogate

beta = 0.9  # neuron decay rate
spike_grad = surrogate.fast_sigmoid() # fast sigmoid surrogate gradient

#  Initialize Convolutional SNN
net = nn.Sequential(nn.Conv2d(1, 8, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(8, 16, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(16*4*4, 10),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)

Refer to the snnTorch documentation to see more neuron types and surrogate gradient options.
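
For example, a different surrogate gradient and neuron model can be dropped in the same way. The snippet below is illustrative only and is not used in the rest of this tutorial:

spike_grad_atan = surrogate.atan()  # arctangent surrogate gradient

# Synaptic neuron: alpha sets the synaptic current decay, beta the membrane decay
syn_lif = snn.Synaptic(alpha=0.9, beta=0.8, spike_grad=spike_grad_atan, init_hidden=True)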

Define the Forward Pass

Now define the forward pass over multiple time steps of simulation.

from snntorch import utils

def forward_pass(net, data, num_steps):
  spk_rec = [] # record spikes over time
  utils.reset(net)  # reset/initialize hidden states for all LIF neurons in net

  for step in range(num_steps): # loop over time
      spk_out, mem_out = net(data) # one time step of the forward-pass
      spk_rec.append(spk_out) # record spikes

  return torch.stack(spk_rec)
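
For reference, running this function on one batch (a quick check, not required for training) returns the spikes stacked over time with shape [num_steps, batch_size, num_outputs]:

data, _ = next(iter(train_loader))
spk_rec = forward_pass(net, data.to(device), num_steps=25)
print(spk_rec.shape)  # expected: torch.Size([25, 128, 10])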

Define the optimizer and loss function. Here, we use the MSE Count Loss, which counts the total number of output spikes at the end of the simulation run and applies a mean-square-error loss to each class's spike count. The correct class has a target firing rate of 80% of all time steps, while incorrect classes are set to 20%.

import snntorch.functional as SF

optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
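
For instance, with num_steps = 25 (as used in the training loop below), the correct class is nudged toward a total of roughly 0.8 × 25 = 20 spikes, while each incorrect class is nudged toward 0.2 × 25 = 5 spikes.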

Objective functions do not have to be applied to the spike count. They may be applied to the membrane potential (hidden state), or to spike-timing targets rather than rate-based methods. A non-exhaustive list of the available objective functions includes:

Apply the objective directly to spikes:

  • MSE Spike Count Loss: mse_count_loss()

  • Cross Entropy Spike Count Loss: ce_count_loss()

  • Cross Entropy Spike Rate Loss: ce_rate_loss()

Apply the objective to the hidden state:

  • Cross Entropy Maximum Membrane Potential Loss: ce_max_membrane_loss()

  • MSE Membrane Potential Loss: mse_membrane_loss()

For alternative objective functions, refer to the snntorch.functional documentation here.
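
As an illustration (not used in the rest of this tutorial), a spike-count objective can be swapped for a spike-rate cross-entropy objective in a single line; the membrane-based losses would instead be applied to the recorded membrane potentials rather than the spike record:

# alternative objective applied to the output spike rate
alt_loss_fn = SF.ce_rate_loss()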

Training Loop

Now for the training loop. The predicted class is taken to be the neuron with the highest firing rate, i.e., a rate-coded output. For simplicity, we will only measure accuracy on the training set here. The training loop follows the same syntax as in standard PyTorch.

num_epochs = 1 # run for 1 epoch - each data sample is seen only once
num_steps = 25  # run for 25 time steps

loss_hist = [] # record loss over iterations
acc_hist = [] # record accuracy over iterations

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(iter(train_loader)):
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        spk_rec = forward_pass(net, data, num_steps) # forward-pass
        loss_val = loss_fn(spk_rec, targets) # loss calculation
        optimizer.zero_grad() # null gradients
        loss_val.backward() # calculate gradients
        optimizer.step() # update weights
        loss_hist.append(loss_val.item()) # store loss

        # print every 25 iterations
        if i % 25 == 0:
          print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")

          # check accuracy on a single batch
          acc = SF.accuracy_rate(spk_rec, targets)
          acc_hist.append(acc)
          print(f"Accuracy: {acc * 100:.2f}%\n")

        # uncomment for faster termination
        # if i == 150:
        #     break
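
To visualize training progress, a simple sketch along the following lines (assuming matplotlib is installed) plots the recorded loss:

import matplotlib.pyplot as plt

plt.figure()
plt.plot(loss_hist)  # loss recorded at every iteration
plt.title("Train Loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()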

More control over your model

If you are simulating more complex architectures, such as residual nets, then your best bet is to wrap the network up in a class, as shown below. This time, we will explicitly use the membrane potential, mem, and let init_hidden default to False.

For the sake of speed, we’ll just simulate a fully-connected SNN, but this can be generalized to other network types (e.g., Convs).

In addition, let’s make the neuron decay rate, beta, a learnable parameter. The first layer will have a single decay rate shared across all of its neurons, while each neuron in the second layer will have an independent decay rate. The decay rate is clipped to the range [0, 1].

import torch.nn.functional as F

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        num_inputs = 784 # number of inputs
        num_hidden = 300 # number of hidden neurons
        num_outputs = 10 # number of classes (i.e., output neurons)

        beta1 = 0.9 # global decay rate for all leaky neurons in layer 1
        beta2 = torch.rand(num_outputs, dtype=torch.float) # independent decay rate for each leaky neuron in layer 2: [0, 1)

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta1) # not a learnable decay rate
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta2, learn_beta=True) # learnable decay rate

    def forward(self, x):
        mem1 = self.lif1.init_leaky() # reset/init hidden states at t=0
        mem2 = self.lif2.init_leaky() # reset/init hidden states at t=0
        spk2_rec = [] # record output spikes
        mem2_rec = [] # record output hidden states

        for step in range(num_steps): # loop over time
            cur1 = self.fc1(x.flatten(1))
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2) # record spikes
            mem2_rec.append(mem2) # record membrane

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

# Load the network onto CUDA if available
net = Net().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

num_epochs = 1 # run for 1 epoch - each data sample is seen only once
num_steps = 25  # run for 25 time steps

loss_hist = [] # record loss over iterations
acc_hist = [] # record accuracy over iterations

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(iter(train_loader)):
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        spk_rec, _ = net(data) # forward-pass
        loss_val = loss_fn(spk_rec, targets) # loss calculation
        optimizer.zero_grad() # null gradients
        loss_val.backward() # calculate gradients
        optimizer.step() # update weights
        loss_hist.append(loss_val.item()) # store loss

        # print every 25 iterations
        if i % 25 == 0:
          net.eval()
          print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")

          # check accuracy on a single batch
          acc = SF.accuracy_rate(spk_rec, targets)
          acc_hist.append(acc)
          print(f"Accuracy: {acc * 100:.2f}%\n")

        # uncomment for faster termination
        # if i == 150:
        #     break
print(f"Trained decay rate of the first layer: {net.lif1.beta:.3f}\n")

print(f"Trained decay rates of the second layer: {net.lif2.beta}")
# function to measure accuracy on full test set
def test_accuracy(data_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    data_loader = iter(data_loader)
    for data, targets in data_loader:
      data = data.to(device)
      targets = targets.to(device)
      spk_rec, _ = net(data)

      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total

print(f"Test set accuracy: {test_accuracy(test_loader, net, num_steps)*100:.3f}%")

Conclusion

That’s it for the quick intro to snnTorch!