Regression with SNNs: Part II

Regression-based Classification with Recurrent Leaky Integrate-and-Fire Neurons

Tutorial written by Alexander Henkes (ORCID) and Jason K. Eshraghian (ncg.ucsc.edu)


This tutorial is based on the following papers on nonlinear regression and spiking neural networks. If you find these resources or code useful in your work, please consider citing the following sources:

Note

This tutorial is a static non-editable version. Interactive, editable versions are available via the following links:

In the regression tutorial series, you will learn how to use snnTorch to perform regression using a variety of spiking neuron models, including:

  • Leaky Integrate-and-Fire (LIF) Neurons

  • Recurrent LIF Neurons

  • Spiking LSTMs

An overview of the regression tutorial series:

  • Part I will train the membrane potential of a LIF neuron to follow a given trajectory over time.

  • Part II (this tutorial) will use LIF neurons with recurrent feedback to perform classification using regression-based loss functions

  • Part III will use a more complex spiking LSTM network instead to train the firing time of a neuron.

!pip install snntorch --quiet
# imports
import snntorch as snn
from snntorch import functional as SF

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools
import tqdm

1. Classification as Regression

In conventional deep learning, we often calculate the Cross Entropy Loss to train a network to do classification. The output neuron with the highest activation is thought of as the predicted class.

In spiking neural networks, this can be interpreted as the class that fires the most spikes, i.e., cross entropy is applied to the total spike count (or firing frequency). The effect is that the spiking of the predicted class is driven up, while all other classes are suppressed.

The brain does not quite work like this. SNNs are sparsely activated, and while approaching SNNs with this deep learning attitude may lead to optimal accuracy, it’s important not to ‘overfit’ too much to what the deep learning folk are doing. After all, we use spikes to achieve better power efficiency. Good power efficiency relies on sparse spiking activity.

In other words, training bio-inspired SNNs using deep learning tricks does not lead to brain-like activity.

So what can we do?

We will focus on recasting classification problems into regression tasks. This is done by training the correct class's neuron to fire a given number of times, while the incorrect neurons are also trained to fire, just less frequently.

This contrasts with cross-entropy which would try to drive the correct class to fire at all time steps, and incorrect classes to not fire at all.

As with the previous tutorial, we can use the mean-square error to achieve this. Recall the form of the mean-square error loss:

\[\mathcal{L}_{MSE} = \frac{1}{n}\sum_{i=1}^n(y_i-\hat{y_i})^2\]

where \(y\) is the target and \(\hat{y}\) is the predicted value.

To apply MSE to the spike count, assume we have \(n\) output neurons in a classification problem, where \(n\) is the number of possible classes. \(\hat{y}_i\) is now the total number of spikes the \(i^{th}\) output neuron emits over the full simulation runtime.

Given that we have \(n\) neurons, this means that \(y\) and \(\hat{y}\) must be vectors with \(n\) elements, and our loss will sum the independent MSE losses of each neuron.

1.1 A Theoretical Example

Consider a simulation over 10 time steps. Say we wish for the correct class to fire 8 times and each incorrect class to fire 2 times. Assuming the first neuron corresponds to the correct class, the target and prediction vectors are:

\[\begin{split}y = \begin{bmatrix} 8 \\ 2 \\ \vdots \\ 2 \end{bmatrix}, \hat{y} = \begin{bmatrix} \hat{y}_1 \\ \hat{y}_2 \\ \vdots \\ \hat{y}_n \end{bmatrix}\end{split}\]

The element-wise MSE is taken to generate \(n\) loss components, which are all summed together to generate a final loss.
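
As a quick sanity check, here is a minimal sketch of that calculation in PyTorch (the values in spike_count are made up for illustration):

target = torch.tensor([8., 2., 2.])       # desired spike counts; the first neuron is the correct class
spike_count = torch.tensor([7., 3., 1.])  # hypothetical spike counts emitted by the network

loss = F.mse_loss(spike_count, target, reduction="sum")
print(loss)  # (7-8)^2 + (3-2)^2 + (1-2)^2 = 3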

2. Recurrent Leaky Integrate-and-Fire Neurons

Neurons in the brain have a ton of feedback connections, so the SNN community has been exploring the dynamics of networks that feed output spikes back to the input. This is in addition to the recurrent dynamics of the membrane potential.

There are a few ways to construct recurrent leaky integrate-and-fire (RLeaky) neurons in snnTorch. Refer to the docs for an exhaustive description of the neuron’s hyperparameters. Let’s see a few examples.

2.1 RLIF Neurons with 1-to-1 connections

https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/regression2/reg2-1.jpg?raw=true

This assumes each neuron feeds back its output spikes into itself, and only itself. There are no cross-coupled connections between neurons in the same layer.

beta = 0.9 # membrane potential decay rate
num_steps = 10 # 10 time steps

rlif = snn.RLeaky(beta=beta, all_to_all=False) # initialize RLeaky Neuron
spk, mem = rlif.init_rleaky() # initialize state variables
x = torch.rand(1) # generate random input

spk_recording = []
mem_recording = []

# run simulation
for step in range(num_steps):
  spk, mem = rlif(x, spk, mem)
  spk_recording.append(spk)
  mem_recording.append(mem)
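
The recorded lists can be stacked into tensors for plotting or further analysis:

spk_recording = torch.stack(spk_recording)  # shape: [num_steps, 1]
mem_recording = torch.stack(mem_recording)  # shape: [num_steps, 1]
print(spk_recording.shape, mem_recording.shape)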

By default, V is a learnable parameter that initializes to \(1\) and will be updated during the training process. If you wish to disable learning, or use your own initialization variables, then you may do so as follows:

rlif = snn.RLeaky(beta=beta, all_to_all=False, learn_recurrent=False) # disable learning of recurrent connection
rlif.V = torch.rand(1) # set V manually; the size should match the layer size
print(f"The recurrent weight is: {rlif.V.item()}")

2.2 RLIF Neurons with all-to-all connections

2.2.1 Linear feedback

https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/regression2/reg2-2.jpg?raw=true

By default, RLeaky assumes feedback connections where all spikes from a given layer are first weighted by a feedback layer before being passed back to the input of all neurons. This introduces more parameters, but is thought to help with learning time-varying features in the data.

beta = 0.9 # membrane potential decay rate
num_steps = 10 # 10 time steps

rlif = snn.RLeaky(beta=beta, linear_features=10)  # initialize RLeaky Neuron
spk, mem = rlif.init_rleaky() # initialize state variables
x = torch.rand(10) # generate random input

spk_recording = []
mem_recording = []

# run simulation
for step in range(num_steps):
  spk, mem = rlif(x, spk, mem)
  spk_recording.append(spk)
  mem_recording.append(mem)

You can disable learning in the feedback layer with learn_recurrent=False.
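
For example, the following constructs the same neuron with a fixed (non-learnable) feedback layer:

rlif = snn.RLeaky(beta=beta, linear_features=10, learn_recurrent=False) # frozen feedback weights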

2.2.2 Convolutional feedback

If you are using convolutional layers, linear feedback will throw an error, because it does not make sense for the 3-dimensional output spikes to be projected into one dimension by an nn.Linear feedback layer.

To address this, you must specify that you are using a convolutional feedback layer:

beta = 0.9 # membrane potential decay rate
num_steps = 10 # 10 time steps

rlif = snn.RLeaky(beta=beta, conv2d_channels=3, kernel_size=(5,5))  # initialize RLeaky Neuron
spk, mem = rlif.init_rleaky() # initialize state variables
x = torch.rand(3, 32, 32) # generate random 3D input

spk_recording = []
mem_recording = []

# run simulation
for step in range(num_steps):
  spk, mem = rlif(x, spk, mem)
  spk_recording.append(spk)
  mem_recording.append(mem)

To ensure the output spike dimension matches the input dimensions, padding is automatically applied.

If you have exotically shaped data, you will need to construct your own feedback layers manually.
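
One way to do this, sketched below under the assumption that you manage the feedback yourself, is to use a standard (non-recurrent) snn.Leaky neuron and add your own feedback transformation to the input at each time step. The nn.Linear feedback layer here is purely illustrative; swap in whatever layer suits your data shape:

beta = 0.9 # membrane potential decay rate
num_steps = 10 # 10 time steps
num_neurons = 10

lif = snn.Leaky(beta=beta) # standard LIF neuron (no built-in feedback)
feedback = torch.nn.Linear(num_neurons, num_neurons) # custom feedback layer (illustrative)

mem = lif.init_leaky() # initialize membrane potential
spk = torch.zeros(num_neurons) # no spikes before the first time step
x = torch.rand(num_steps, num_neurons) # random input sequence

spk_recording = []

# run simulation with explicit feedback
for step in range(num_steps):
  cur = x[step] + feedback(spk) # inject weighted output spikes back into the input
  spk, mem = lif(cur, mem)
  spk_recording.append(spk)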

3. Construct Model

Let’s train a couple of models using RLeaky layers. For speed, we will train a model with linear feedback.

class Net(torch.nn.Module):
    """Simple spiking neural network in snntorch."""

    def __init__(self, timesteps, hidden, beta):
        super().__init__()

        self.timesteps = timesteps
        self.hidden = hidden
        self.beta = beta

        # layer 1
        self.fc1 = torch.nn.Linear(in_features=784, out_features=self.hidden)
        self.rlif1 = snn.RLeaky(beta=self.beta, linear_features=self.hidden)

        # layer 2
        self.fc2 = torch.nn.Linear(in_features=self.hidden, out_features=10)
        self.rlif2 = snn.RLeaky(beta=self.beta, linear_features=10)

    def forward(self, x):
        """Forward pass for several time steps."""

        # Initialize membrane potentials and output spikes
        spk1, mem1 = self.rlif1.init_rleaky()
        spk2, mem2 = self.rlif2.init_rleaky()

        # Empty lists to record outputs
        spk_recording = []

        for step in range(self.timesteps):
            spk1, mem1 = self.rlif1(self.fc1(x), spk1, mem1)
            spk2, mem2 = self.rlif2(self.fc2(spk1), spk2, mem2)
            spk_recording.append(spk2)

        return torch.stack(spk_recording)

Instantiate the network below:

hidden = 128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model = Net(timesteps=num_steps, hidden=hidden, beta=0.9).to(device)
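
As a quick sanity check, a dummy batch can be passed through the untrained model to confirm the output shape is [num_steps, batch_size, 10] (the batch size of 4 below is arbitrary):

with torch.no_grad():
  dummy_input = torch.rand(4, 784).to(device) # batch of 4 flattened 28x28 images
  print(model(dummy_input).shape) # expected: torch.Size([10, 4, 10])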

4. Construct Training Loop

4.1 Mean Square Error Loss in snntorch.functional

From snntorch.functional, we call mse_count_loss to set the target neuron to fire 80% of the time, and incorrect neurons to fire 20% of the time. What it took 10 paragraphs to explain is achieved in one line of code:

loss_function = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
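
The resulting loss function is called with the recorded output spikes (time steps along the leading dimension) and integer class labels, exactly as in the training loop below. An illustrative call with made-up data:

dummy_spk = (torch.rand(num_steps, 4, 10) > 0.5).float() # fake spike record: [time steps x batch x classes]
dummy_labels = torch.tensor([3, 1, 4, 1]) # integer class labels for a batch of 4
print(loss_function(dummy_spk, dummy_labels))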

4.2 DataLoader

DataLoader boilerplate. Let's just use MNIST; testing this on temporal data is left as an exercise for the reader.

batch_size = 128
data_path='/tmp/data/mnist'

# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

4.3 Train Network

num_epochs = 5
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_hist = []

with tqdm.trange(num_epochs) as pbar:
    for _ in pbar:
        train_batch = iter(train_loader)
        minibatch_counter = 0
        loss_epoch = []

        for feature, label in train_batch:
            feature = feature.to(device)
            label = label.to(device)

            spk = model(feature.flatten(1)) # forward-pass
            loss_val = loss_function(spk, label) # apply loss
            optimizer.zero_grad() # zero out gradients
            loss_val.backward() # calculate gradients
            optimizer.step() # update weights

            loss_hist.append(loss_val.item()) # full loss history across epochs
            loss_epoch.append(loss_val.item()) # loss history for the current epoch
            minibatch_counter += 1

            avg_batch_loss = sum(loss_epoch) / minibatch_counter # running average loss for this epoch
            pbar.set_postfix(loss="%.3e" % avg_batch_loss)
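
Since matplotlib is already imported, the loss history can be plotted to confirm the network is converging:

plt.figure()
plt.plot(loss_hist)
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.show()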

5. Evaluation

test_batch = iter(test_loader)

model.eval()
with torch.no_grad():
  total = 0
  acc = 0
  for feature, label in test_batch:
      feature = feature.to(device)
      label = label.to(device)

      spk = model(feature.flatten(1)) # forward-pass
      acc += SF.accuracy_rate(spk, label) * spk.size(1) # weight accuracy by batch size
      total += spk.size(1) # spk.size(1) is the batch size

print(f"The total accuracy on the test set is: {(acc/total) * 100:.2f}%")

In the previous tutorial, we tested membrane potential learning. We can do the same here by setting the target neuron to reach a membrane potential greater than the firing threshold, and incorrect neurons to reach a membrane potential below the firing threshold:

loss_function = SF.mse_membrane_loss(on_target=1.05, off_target=0.2)

In the above case, we are trying to get the correct neuron to constantly sit above the firing threshold.

Try updating the network and the training loop to make this work. A sketch of one possible modification follows the hints below.

Hints:

  • You will need to return the output membrane potential instead of spikes.

  • Pass the membrane potential to the loss function instead of spikes.
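
A minimal sketch of one possible modification, assuming the rest of the network stays the same (the subclass name NetMem is illustrative):

class NetMem(Net):
    """Variant of Net that records membrane potential instead of spikes."""

    def forward(self, x):
        spk1, mem1 = self.rlif1.init_rleaky()
        spk2, mem2 = self.rlif2.init_rleaky()

        mem_recording = []

        for step in range(self.timesteps):
            spk1, mem1 = self.rlif1(self.fc1(x), spk1, mem1)
            spk2, mem2 = self.rlif2(self.fc2(spk1), spk2, mem2)
            mem_recording.append(mem2) # record membrane potential, not spikes

        return torch.stack(mem_recording)

model = NetMem(timesteps=num_steps, hidden=hidden, beta=0.9).to(device)
loss_function = SF.mse_membrane_loss(on_target=1.05, off_target=0.2)

# in the training loop, the membrane trace is passed to the loss:
# mem = model(feature.flatten(1))
# loss_val = loss_function(mem, label)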

Conclusion

The next regression tutorial will introduce spiking LSTMs to achieve precise spike time learning.

If you like this project, please consider starring ⭐ the repo on GitHub as it is the easiest and best way to support it.

Additional Resources