Tutorial 6 - Surrogate Gradient Descent in a Convolutional SNN

Tutorial written by Jason K. Eshraghian (www.ncg.ucsc.edu)


The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:




In this tutorial, you will:

  • Learn how to modify surrogate gradient descent to overcome the dead neuron problem

  • Construct and train a convolutional spiking neural network

  • Use a sequential container, nn.Sequential to simplify model construction

Part of this tutorial was inspired by Friedemann Zenke’s extensive work on SNNs. Check out his repo on surrogate gradients here, and a favourite paper of mine: E. O. Neftci, H. Mostafa, F. Zenke, Surrogate Gradient Learning in Spiking Neural Networks: Bringing the Power of Gradient-based optimization to spiking neural networks. IEEE Signal Processing Magazine 36, 51–63.

At the end of the tutorial, we will train a convolutional spiking neural network (CSNN) using the MNIST dataset to perform image classification. The background theory follows on from Tutorials 2, 4 and 5, so feel free to go back if you need to brush up.

Install the latest PyPI distribution of snnTorch:

$ pip install snntorch
# imports
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools

1. Surrogate Gradient Descent

Tutorial 5 raised the dead neuron problem. This arises because of the non-differentiability of spikes:

\[S[t] = \Theta(U[t] - U_{\rm thr}) \tag{1}\]
\[\frac{\partial S}{\partial U} = \delta(U - U_{\rm thr}) \in \{0, \infty\} \tag{2}\]

where \(\Theta(\cdot)\) is the Heaviside step function, and \(\delta(\cdot)\) is the Dirac-Delta function. We previously overcame this using the threshold-shifted ArcTangent function on the backward pass instead.

Other common smoothing functions include the sigmoid function, or the fast sigmoid function. The sigmoidal functions must also be shifted such that they are centered at the threshold \(U_{\rm thr}.\) Defining the overdrive of the membrane potential as \(U_{OD} = U - U_{\rm thr}\):

\[\tilde{S} = \frac{U_{OD}}{1+k|U_{OD}|} \tag{3}\]
\[\frac{\partial \tilde{S}}{\partial U} = \frac{1}{(k|U_{OD}|+1)^2}\tag{4}\]

where \(k\) modulates how smooth the surrogate function is, and is treated as a hyperparameter. As \(k\) increases, the approximation converges towards the original derivative in \((2)\):

\[\frac{\partial \tilde{S}}{\partial U} \Bigg|_{k \rightarrow \infty} = \delta(U-U_{\rm thr})\]
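As a quick sanity check, the sketch below compares the analytic gradient in \((4)\) against a finite-difference estimate of \((3)\), and shows how the gradient away from the threshold shrinks as \(k\) grows. It uses plain Python only; the function names are illustrative and not part of snnTorch.

```python
def fast_sigmoid(u_od, k):
    """Eq. (3): smoothed spike as a function of the overdrive U - U_thr."""
    return u_od / (1.0 + k * abs(u_od))


def fast_sigmoid_grad(u_od, k):
    """Eq. (4): analytic derivative of Eq. (3) with respect to U."""
    return 1.0 / (k * abs(u_od) + 1.0) ** 2


def central_difference(f, x, k, h=1e-6):
    """Numerical derivative of f(., k) at x."""
    return (f(x + h, k) - f(x - h, k)) / (2.0 * h)


k = 25.0
for u_od in (-0.2, 0.05, 0.5):
    analytic = fast_sigmoid_grad(u_od, k)
    numeric = central_difference(fast_sigmoid, u_od, k)
    print(f"U_OD={u_od:+.2f}  analytic={analytic:.6f}  numeric={numeric:.6f}")

# A larger k concentrates the surrogate gradient closer to the threshold
for slope in (1.0, 25.0, 100.0):
    print(f"k={slope:5.1f}  gradient at U_OD=0.1: {fast_sigmoid_grad(0.1, slope):.4f}")
```

The gradient at the threshold itself (\(U_{OD} = 0\)) stays at 1 for every \(k\); only the width of the region with a non-negligible gradient changes.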

To summarize:

  • Forward Pass

    • Determine \(S\) using the shifted Heaviside function in \((1)\)

    • Store \(U\) for later use during the backward pass

  • Backward Pass

    • Pass \(U\) into \((4)\) to calculate the derivative term
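The two passes summarized above can be sketched for a single scalar neuron in plain Python. This is only an illustration of the bookkeeping (the helper names forward and backward are hypothetical); in practice autograd handles it.

```python
def forward(u, u_thr=1.0):
    """Forward pass: Heaviside spike from Eq. (1); U is kept for the backward pass."""
    spike = 1.0 if u > u_thr else 0.0
    saved_u = u
    return spike, saved_u


def backward(saved_u, upstream_grad, u_thr=1.0, k=25.0):
    """Backward pass: use the fast sigmoid gradient, Eq. (4), in place of the Dirac delta."""
    u_od = saved_u - u_thr
    return upstream_grad / (k * abs(u_od) + 1.0) ** 2


# A neuron just below threshold does not spike...
spike, saved_u = forward(0.9)
# ...but still receives a non-zero gradient, so its weights can keep learning.
grad = backward(saved_u, upstream_grad=1.0)
print(spike, round(grad, 4))
```

With the true derivative in \((2)\), the same neuron would receive a gradient of exactly zero, which is the dead neuron problem.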

In the same way the ArcTangent approach was used in Tutorial 5, the gradient of the fast sigmoid function can override the Dirac-Delta function in a Leaky Integrate-and-Fire (LIF) neuron model:

# Leaky neuron model, overriding the backward pass with a custom function
class LeakySigmoidSurrogate(nn.Module):
  def __init__(self, beta, threshold=1.0, k=25):
      super().__init__()

      # initialize decay rate beta, threshold and surrogate slope k
      self.beta = beta
      self.threshold = threshold
      self.k = k
      self.surrogate_func = self.FastSigmoid.apply

  # the forward function is called each time we call LeakySigmoidSurrogate
  def forward(self, input_, mem):
    spk = self.surrogate_func(mem - self.threshold, self.k)  # call the Heaviside function
    reset = (spk * self.threshold).detach()  # subtract the threshold only when a spike occurs
    mem = self.beta * mem + input_ - reset
    return spk, mem

  # Forward pass: Heaviside function
  # Backward pass: Override Dirac Delta with gradient of fast sigmoid
  class FastSigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mem, k=25):
        ctx.save_for_backward(mem) # store the membrane potential for use in the backward pass
        ctx.k = k
        out = (mem > 0).float() # Heaviside on the forward pass: Eq(1)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (mem,) = ctx.saved_tensors  # retrieve membrane potential
        grad_input = grad_output.clone()
        grad = grad_input / (ctx.k * torch.abs(mem) + 1.0) ** 2  # gradient of fast sigmoid on backward pass: Eq(4)
        return grad, None  # no gradient is needed for k
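To see the membrane dynamics in isolation, here is a dependency-free sketch of the same update rule, assuming the reset subtracts the threshold whenever the neuron spikes. The function simulate_lif is hypothetical and only mirrors the order of operations in the forward method: spike from the previous membrane potential, then decay, integrate and reset.

```python
def simulate_lif(inputs, beta=0.5, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron over a list of input currents."""
    mem = 0.0
    spikes, mems = [], []
    for x in inputs:
        spk = 1.0 if mem > threshold else 0.0   # Heaviside on the current membrane potential
        reset = spk * threshold                 # reset by subtraction, only on a spike
        mem = beta * mem + x - reset            # decay, integrate, reset
        spikes.append(spk)
        mems.append(mem)
    return spikes, mems


spikes, mems = simulate_lif([0.75] * 6)
print(spikes)  # [0.0, 0.0, 1.0, 0.0, 0.0, 1.0]
print(mems)    # [0.75, 1.125, 0.3125, 0.90625, 1.203125, 0.3515625]
```

With a constant input of 0.75 the membrane potential crosses the threshold every third step, and each spike is followed by a drop of one threshold unit.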

Better yet, all of that can be condensed by using the built-in module snn.surrogate from snnTorch, where \(k\) from \((4)\) is denoted slope. The surrogate gradient is passed into spike_grad as an argument:

spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5

lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)

To explore the other surrogate gradient functions available, take a look at the documentation here.

2. Setting up the CSNN

2.1 DataLoaders

# dataloader arguments
batch_size = 128
data_path = '/tmp/data/mnist'  # assumed download location; change to suit your setup

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),  # convert PIL images to tensors before normalizing
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)
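Because drop_last=True discards any incomplete final batch, every batch has exactly batch_size samples, which matters for code that reshapes with a fixed batch size. The arithmetic below is a small sketch of how many batches result, using the standard MNIST split sizes (60,000 training and 10,000 test images); num_batches is a hypothetical helper, not a PyTorch function.

```python
def num_batches(num_samples, batch_size, drop_last):
    """Number of batches a loader yields for a dataset of num_samples items."""
    full, remainder = divmod(num_samples, batch_size)
    if drop_last or remainder == 0:
        return full
    return full + 1


print(num_batches(60000, 128, drop_last=True))   # 468 training batches
print(num_batches(10000, 128, drop_last=True))   # 78 test batches
print(60000 % 128, 10000 % 128)                  # 96 and 16 samples are dropped
```

One consequence is that accuracy computed over the test loader covers 9,984 of the 10,000 test images.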

2.2 Define the Network

The convolutional network architecture to be used is: 12C5-MP2-64C5-MP2-1024FC10

  • 12C5 is a 5 \(\times\) 5 convolutional kernel with 12 filters

  • MP2 is a 2 \(\times\) 2 max-pooling function

  • 1024FC10 is a fully-connected layer that maps 1,024 neurons to 10 outputs
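The 1,024 inputs of the fully-connected layer follow from the feature-map sizes. The sketch below traces a 28 \(\times\) 28 input through the layers, assuming the nn.Conv2d defaults (stride 1, no padding) and a pooling stride equal to the window size; the helper names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a square convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, window):
    """Spatial output size of non-overlapping max-pooling."""
    return size // window


size = 28
size = pool_out(conv_out(size, 5), 2)   # 12C5 then MP2: 28 -> 24 -> 12
size = pool_out(conv_out(size, 5), 2)   # 64C5 then MP2: 12 -> 8 -> 4
flattened = 64 * size * size            # 64 channels of 4 x 4 feature maps

print(size, flattened)  # 4 1024
```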

# neuron and simulation parameters
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5
num_steps = 50
# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.conv1 = nn.Conv2d(1, 12, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(12, 64, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(64*4*4, 10)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):

        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        cur1 = F.max_pool2d(self.conv1(x), 2)
        spk1, mem1 = self.lif1(cur1, mem1)

        cur2 = F.max_pool2d(self.conv2(spk1), 2)
        spk2, mem2 = self.lif2(cur2, mem2)

        cur3 = self.fc1(spk2.view(batch_size, -1))
        spk3, mem3 = self.lif3(cur3, mem3)

        return spk3, mem3

In the previous tutorial, the network was wrapped inside a class, as shown above. As network complexity grows, this adds boilerplate code that we might wish to avoid. Alternatively, the nn.Sequential container can be used instead.


The following code-block simulates over one single time-step, and requires a separate for-loop over time.

#  Initialize Network
net = nn.Sequential(nn.Conv2d(1, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 64, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(64*4*4, 10),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)

The init_hidden argument initializes the hidden states of the neuron (here, membrane potential). This takes place in the background as an instance variable. If init_hidden is activated, the membrane potential is not explicitly returned to the user, ensuring only the output spikes are sequentially passed through the layers wrapped in nn.Sequential.

To train a model using the final layer’s membrane potential, set the argument output=True. This enables the final layer to return both the spike and membrane potential response of the neuron.
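The toy class below illustrates the idea behind these two arguments in plain Python. It is a hypothetical stand-in, not snnTorch's implementation: with init_hidden=True the membrane potential lives on the object, so a call takes only the input and returns only the spike unless output=True.

```python
class ToyLeaky:
    """Toy stateful neuron (illustrative only; not the snn.Leaky implementation)."""

    def __init__(self, beta, threshold=1.0, init_hidden=False, output=False):
        self.beta = beta
        self.threshold = threshold
        self.init_hidden = init_hidden
        self.output = output
        self.mem = 0.0  # hidden state kept as an instance variable

    def __call__(self, x, mem=None):
        if self.init_hidden:
            mem = self.mem                      # state comes from the object, not the caller
        spk = 1.0 if mem > self.threshold else 0.0
        mem = self.beta * mem + x - spk * self.threshold
        if self.init_hidden:
            self.mem = mem
            return (spk, mem) if self.output else spk
        return spk, mem                         # explicit state: caller tracks mem


hidden = ToyLeaky(beta=0.5, init_hidden=True)
readout = ToyLeaky(beta=0.5, init_hidden=True, output=True)

for step in range(3):
    spk_hidden = hidden(0.75)                   # only the spike is passed on
    spk_out, mem_out = readout(spk_hidden)      # final layer also exposes its membrane
    print(step, spk_hidden, spk_out, mem_out)
```

Because each layer returns a single value, layers like this can be chained one after another, which is the property nn.Sequential relies on.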

2.3 Forward-Pass

A forward pass across a simulation duration of num_steps looks like this:

data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

for step in range(num_steps):
    spk_out, mem_out = net(data)

Wrap that in a function, recording the membrane potential and spike response over time:

def forward_pass(net, num_steps, data):
  mem_rec = []
  spk_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

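  for step in range(num_steps):
      spk_out, mem_out = net(data)
      spk_rec.append(spk_out)  # record output spikes at this time step
      mem_rec.append(mem_out)  # record output membrane potential at this time step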

  return torch.stack(spk_rec), torch.stack(mem_rec)
spk_rec, mem_rec = forward_pass(net, num_steps, data)

3. Training Loop

3.1 Loss Using snn.Functional

In the previous tutorial, the Cross Entropy Loss between the membrane potential of the output neurons and the target was used to train the network. This time, the total number of spikes from each neuron will be used to calculate the Cross Entropy instead.

A variety of loss functions are included in the snn.functional module, which is analogous to torch.nn.functional in PyTorch. These implement a mix of cross entropy and mean square error losses, which are applied to spikes and/or membrane potential to train a rate or latency-coded network.

The approach below applies the cross entropy loss to the output spike count in order to train a rate-coded network:

# already imported snntorch.functional as SF
loss_fn = SF.ce_rate_loss()

The spike recordings are passed as the first argument to loss_fn, and the target neuron index as the second argument, to generate a loss. The documentation provides further information and examples.

loss_val = loss_fn(spk_rec, targets)
>>> print(f"The loss from an untrained network is {loss_val.item():.3f}")
The loss from an untrained network is 2.303
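A value of 2.303 is what softmax cross entropy returns when all ten outputs are equal, since the loss is then \(\ln 10\). That this is why the untrained network lands on this value (for instance, because its output neurons are mostly silent) is an inference rather than something the output above confirms. The sketch below reproduces the number in plain Python; cross_entropy is an illustrative helper, not the snnTorch loss.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross entropy for a single sample, using the log-sum-exp trick."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target]


uniform_logits = [0.0] * 10          # ten output neurons with identical activity
loss = cross_entropy(uniform_logits, target=3)
print(f"{loss:.3f}")                 # 2.303, i.e. ln(10)
```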

3.2 Accuracy Using snn.Functional

The SF.accuracy_rate() function works similarly, in that the predicted output spikes and actual targets are supplied as arguments. accuracy_rate assumes a rate code is used to interpret the output by checking if the index of the neuron with the highest spike count matches the target index.

acc = SF.accuracy_rate(spk_rec, targets)
>>> print(f"The accuracy of a single batch using an untrained network is {acc*100:.3f}%")
The accuracy of a single batch using an untrained network is 10.938%
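The rate-code interpretation can be sketched without tensors. The hypothetical function below assumes spike recordings indexed as [time][batch][neuron], counts spikes per output neuron, and picks the most active one; ties go to the lowest index here, which may differ from how snnTorch resolves them.

```python
def accuracy_rate_sketch(spk_rec, targets):
    """Fraction of samples whose most active output neuron matches the target."""
    num_steps = len(spk_rec)
    correct = 0
    for b, target in enumerate(targets):
        num_outputs = len(spk_rec[0][b])
        # total spike count of each output neuron over the whole simulation
        counts = [sum(spk_rec[t][b][n] for t in range(num_steps))
                  for n in range(num_outputs)]
        predicted = counts.index(max(counts))
        correct += int(predicted == target)
    return correct / len(targets)


# 3 time steps, 2 samples, 3 output neurons
spk_rec_toy = [
    [[0, 1, 0], [1, 0, 0]],
    [[0, 1, 0], [0, 0, 1]],
    [[1, 1, 0], [0, 0, 1]],
]
targets_toy = [1, 0]

# sample 0 counts [1, 3, 0] -> predicts 1 (correct); sample 1 counts [1, 0, 2] -> predicts 2 (wrong)
print(accuracy_rate_sketch(spk_rec_toy, targets_toy))  # 0.5
```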

As the above function only returns the accuracy of a single batch of data, the following function returns the accuracy on the entire DataLoader object:

def batch_accuracy(train_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0

    train_loader = iter(train_loader)
    for data, targets in train_loader:
      data = data.to(device)
      targets = targets.to(device)
      spk_rec, _ = forward_pass(net, num_steps, data)

      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total
test_acc = batch_accuracy(test_loader, net, num_steps)
>>> print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")
The total accuracy on the test set is: 8.59%

3.3 Training Loop

The following training loop is qualitatively similar to the previous tutorial.

optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
num_epochs = 1
loss_hist = []
test_acc_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):

    # Training loop
    for data, targets in iter(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        spk_rec, _ = forward_pass(net, num_steps, data)

        # initialize the loss & sum over time
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        if counter % 50 == 0:
            with torch.no_grad():

                # Test set forward pass
                test_acc = batch_accuracy(test_loader, net, num_steps)
                print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
                test_acc_hist.append(test_acc.item())

        counter += 1

The output should look something like this:

Iteration 0, Test Acc: 9.82%

Iteration 50, Test Acc: 91.98%

Iteration 100, Test Acc: 94.90%

Iteration 150, Test Acc: 95.70%

Despite having selected some fairly generic values and architectures, the test set accuracy should be fairly competitive given the brief training run!

4. Results

4.1 Plot Test Accuracy

# Plot test accuracy
fig = plt.figure(facecolor="w")
plt.plot(test_acc_hist)
plt.title("Test Set Accuracy")
plt.xlabel("Evaluation (one every 50 iterations)")
plt.show()

4.2 Spike Counter

Run a forward pass on a batch of data to obtain spike and membrane readings.

spk_rec, mem_rec = forward_pass(net, num_steps, data)

Changing idx allows you to index into various samples from the simulated minibatch. Use splt.spike_count to explore the spiking behaviour of a few different samples!

Note: if you are running the notebook locally on your desktop, please uncomment the line below and modify the path to your ffmpeg.exe

from IPython.display import HTML

idx = 0

fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
labels=['0', '1', '2', '3', '4', '5', '6', '7', '8','9']

# plt.rcParams['animation.ffmpeg_path'] = 'C:\\path\\to\\your\\ffmpeg.exe'

#  Plot spike count histogram
anim = splt.spike_count(spk_rec[:, idx].detach().cpu(), fig, ax, labels=labels,
                        animate=True, interpolate=4)

HTML(anim.to_html5_video())  # render the animation inline in a notebook
# anim.save("spike_bar.mp4")
>>> print(f"The target label is: {targets[idx]}")
The target label is: 3


You should now have a grasp of the basic features of snnTorch and be able to start running your own experiments. In the next tutorial, we will train a network using a neuromorphic dataset.

A special thanks to Gianfrancesco Angelini for providing valuable feedback on the tutorial.

If you like this project, please consider starring ⭐ the repo on GitHub as it is the easiest and best way to support it.

Additional Resources