Tutorial 7 - Neuromorphic Datasets with Tonic + snnTorch

Tutorial written by Gregor Lenz (https://lenzgregor.com) and Jason K. Eshraghian (www.ncg.ucsc.edu)


The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:




In this tutorial, you will:

  • Learn how to load neuromorphic datasets using Tonic

  • Make use of caching to speed up dataloading

  • Train a CSNN with the Neuromorphic-MNIST Dataset

Install the latest PyPI distributions of Tonic and snnTorch:

pip install tonic
pip install snntorch

1. Using Tonic to Load Neuromorphic Datasets

Loading datasets from neuromorphic sensors is made super simple thanks to Tonic, which works much like torchvision.

Let’s start by loading the neuromorphic version of the MNIST dataset, called N-MNIST. We can have a look at some raw events to get a feel for what we’re working with.

import tonic

dataset = tonic.datasets.NMNIST(save_to='./data', train=True)
events, target = dataset[0]
>>> print(events)
[(10, 30, 937, 1) (33, 20, 1030, 1) (12, 27, 1052, 1) ...
( 7, 15, 302706, 1) (26, 11, 303852, 1) (11, 17, 305341, 1)]

Each row corresponds to a single event, which consists of four parameters: (x-coordinate, y-coordinate, timestamp, polarity).

  • x & y co-ordinates correspond to an address in a \(34 \times 34\) grid.

  • The timestamp of the event is recorded in microseconds.

  • The polarity refers to whether an on-spike (+1) or an off-spike (-1) occurred; i.e., an increase in brightness or a decrease in brightness.
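To get a feel for the four fields, here is a small sketch that builds a toy event array by hand and reads its columns. The field names (x, y, t, p), the integer dtype, and the toy values are assumptions for illustration; the off-polarity is written as -1 to follow the text above, and the value actually stored by the dataset may differ.

```python
import numpy as np

# Assumed layout: four integer fields named x, y, t, p (toy data, not real N-MNIST events)
event_dtype = np.dtype([("x", np.int64), ("y", np.int64), ("t", np.int64), ("p", np.int64)])

toy_events = np.array(
    [(10, 30, 937, 1), (33, 20, 1030, 1), (12, 27, 1052, -1), (7, 15, 302706, 1)],
    dtype=event_dtype,
)

# Each field can be read as a plain column
duration_us = int(toy_events["t"].max() - toy_events["t"].min())
n_on = int((toy_events["p"] == 1).sum())   # on-spikes
n_off = len(toy_events) - n_on             # everything else is treated as an off-spike

print(f"{len(toy_events)} events spanning {duration_us / 1000:.1f} ms")
print(f"on-spikes: {n_on}, off-spikes: {n_off}")
```

Because timestamps are in microseconds, a recording of about 300 ms spans roughly 300,000 timestamp units.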

If we were to accumulate those events over time and plot the bins as images, they would look like this:

>>> tonic.utils.plot_event_grid(events)

1.1 Transformations

However, neural nets don’t take lists of events as input. The raw data must be converted into a suitable representation, such as a tensor. We can choose a set of transforms to apply to our data before feeding it to our network. The neuromorphic camera sensor has a temporal resolution of microseconds, which when converted into a dense representation, ends up as a very large tensor. That is why we bin events into a smaller number of frames using the ToFrame transformation, which reduces temporal precision but also allows us to work with it in a dense format.

  • time_window=1000 integrates events into 1000\(~\mu\)s bins

  • Denoise removes isolated, one-off events. If no event occurs within a neighbourhood of 1 pixel across filter_time microseconds, the event is filtered. Smaller filter_time will filter more events.

import tonic.transforms as transforms

sensor_size = tonic.datasets.NMNIST.sensor_size

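# Denoise removes isolated, one-off events
# ToFrame with time_window=1000 integrates events into 1000 us bins
frame_transform = transforms.Compose([transforms.Denoise(filter_time=10000),
                                      transforms.ToFrame(sensor_size=sensor_size,
                                                         time_window=1000)
                                     ])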

trainset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=True)
testset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=False)
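To make the binning idea concrete, here is a rough NumPy sketch of what accumulating events into fixed time windows looks like. It is not Tonic's implementation: the function name bin_events is hypothetical, polarity is assumed to be an index (0 or 1), the sensor size is assumed to be ordered (width, height, polarities), and bins are assumed to start at the first event.

```python
import numpy as np

def bin_events(x, y, t, p, sensor_size=(34, 34, 2), time_window=1000):
    """Accumulate events into frames of shape (time bins, polarity, height, width)."""
    width, height, n_polarities = sensor_size
    t = t - t.min()                    # assumption: the first bin starts at the first event
    bin_idx = t // time_window         # which time bin each event falls into
    n_bins = int(bin_idx.max()) + 1
    frames = np.zeros((n_bins, n_polarities, height, width), dtype=np.int16)
    np.add.at(frames, (bin_idx, p, y, x), 1)   # count events per (bin, polarity, pixel)
    return frames

# Toy events: three within the first 1000 us, one in the second window
x = np.array([10, 33, 12, 10])
y = np.array([30, 20, 27, 30])
t = np.array([937, 1030, 1052, 2500])
p = np.array([1, 1, 0, 1])

frames = bin_events(x, y, t, p)
print(frames.shape)  # (2, 2, 34, 34)
```

With 1000 µs windows, a recording of roughly 300 ms ends up as roughly 300 frames, which is far more manageable than one slice per microsecond.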

1.2 Fast DataLoading

The original data is stored in a format that is slow to read. To speed up dataloading, we can make use of disk caching and batching. That means that once files are loaded from the original dataset, they are written to the disk.

Because event recordings have different lengths, we are going to provide a collation function tonic.collation.PadTensors() that will pad out shorter recordings to ensure all samples in a batch have the same dimensions.

from torch.utils.data import DataLoader
from tonic import DiskCachedDataset

cached_trainset = DiskCachedDataset(trainset, cache_path='./cache/nmnist/train')
cached_dataloader = DataLoader(cached_trainset)

batch_size = 128
trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors())

def load_sample_batched():
    events, target = next(iter(cached_dataloader))

>>> %timeit -o -r 10 load_sample_batched()
4.2 ms ± 119 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)

By using disk caching and a PyTorch dataloader with multithreading and batching support, we have significantly reduced loading times.

If you have a large amount of RAM available, you can speed up dataloading further by caching to main memory instead of to disk:

from tonic import MemoryCachedDataset

cached_trainset = MemoryCachedDataset(trainset)
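If you are curious what caching buys us, the idea can be sketched in a few lines of plain Python. The wrapper below is hypothetical (TinyDiskCache and SlowDataset are made-up names, and pickle is just a convenient stand-in for a file format); it is not how Tonic implements its cached datasets, only an illustration of the load-once, read-many pattern.

```python
import os
import pickle
import tempfile

class TinyDiskCache:
    """Hypothetical wrapper: read each sample from the slow source once, then from disk."""

    def __init__(self, dataset, cache_path):
        self.dataset = dataset
        self.cache_path = cache_path
        os.makedirs(cache_path, exist_ok=True)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        file_path = os.path.join(self.cache_path, f"{index}.pkl")
        if os.path.exists(file_path):          # cache hit: skip the slow source
            with open(file_path, "rb") as f:
                return pickle.load(f)
        sample = self.dataset[index]           # cache miss: load and store
        with open(file_path, "wb") as f:
            pickle.dump(sample, f)
        return sample

class SlowDataset:
    """Stand-in for a slow dataset that counts how often it is actually read."""

    def __init__(self):
        self.reads = 0

    def __len__(self):
        return 3

    def __getitem__(self, index):
        self.reads += 1
        return ([index, index + 1], index % 10)

slow = SlowDataset()
cached = TinyDiskCache(slow, tempfile.mkdtemp())
first = cached[0]    # read from the slow dataset and written to disk
second = cached[0]   # served from disk
print(slow.reads)    # 1
```

A consequence of this pattern is that anything computed before the cache is frozen on first access, so random augmentations need to be applied to the sample after it comes out of the cache.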

2. Training our network using frames created from events

Now let’s actually train a network on the N-MNIST classification task. We start by defining our caching wrappers and dataloaders. While doing that, we’re also going to apply some augmentations to the training data. The samples we receive from the cached dataset are frames, so we can make use of torchvision to apply whatever random transform we would like.

import torch
import torchvision

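# example augmentation: random rotation of up to 10 degrees in either direction
transform = tonic.transforms.Compose([torch.from_numpy,
                                      torchvision.transforms.RandomRotation([-10, 10])])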

cached_trainset = DiskCachedDataset(trainset, transform=transform, cache_path='./cache/nmnist/train')

# no augmentations for the testset
cached_testset = DiskCachedDataset(testset, cache_path='./cache/nmnist/test')

batch_size = 128
trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)
testloader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False))

A mini-batch now has the dimensions (time steps, batch size, channels, height, width). The number of time steps will be set to that of the longest recording in the mini-batch, and all other samples will be padded with zeros to match it.

>>> event_tensor, target = next(iter(trainloader))
>>> print(event_tensor.shape)
torch.Size([311, 128, 2, 34, 34])
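The padding step can be sketched in NumPy as follows. This is only an illustration of the idea, not the PadTensors implementation: the function name pad_time_first is hypothetical, and the zeros are assumed to be appended at the end of the shorter recordings.

```python
import numpy as np

def pad_time_first(samples):
    """Zero-pad (T, C, H, W) arrays to the longest T and stack them to (T, B, C, H, W)."""
    max_steps = max(s.shape[0] for s in samples)
    padded = []
    for s in samples:
        # pad only the time axis, and only at the end
        pad_width = [(0, max_steps - s.shape[0])] + [(0, 0)] * (s.ndim - 1)
        padded.append(np.pad(s, pad_width))
    return np.stack(padded, axis=1)   # batch becomes the second axis

short = np.ones((3, 2, 34, 34), dtype=np.int16)   # 3 time steps
long = np.ones((5, 2, 34, 34), dtype=np.int16)    # 5 time steps

batch = pad_time_first([short, long])
print(batch.shape)  # (5, 2, 2, 34, 34)
```

The shorter sample contributes all-zero frames for its last two time steps, so it simply produces no input there.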

2.1 Defining our network

We will use snnTorch + PyTorch to construct a CSNN, just as in the previous tutorial. The convolutional network architecture to be used is: 12C5-MP2-32C5-MP2-800FC10

  • 12C5 is a 5 \(\times\) 5 convolutional kernel with 12 filters

  • MP2 is a 2 \(\times\) 2 max-pooling function

  • 800FC10 is a fully-connected layer that maps 800 neurons to 10 outputs
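The 800 in 800FC10 follows from the input size and the layer sequence. A quick sanity check, assuming unpadded convolutions with stride 1 and pooling that rounds down (the helper names conv_out and pool_out are ours):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel):
    """Output width/height of non-overlapping pooling that rounds down."""
    return size // kernel

size = 34                    # N-MNIST frames are 34 x 34
size = conv_out(size, 5)     # 12C5 -> 30
size = pool_out(size, 2)     # MP2  -> 15
size = conv_out(size, 5)     # 32C5 -> 11
size = pool_out(size, 2)     # MP2  -> 5

flat_features = 32 * size * size
print(size, flat_features)   # 5 800
```

So the final feature map is 32 channels of 5 × 5, which flattens to the 800 inputs of the fully-connected layer.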

import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils
import torch.nn as nn
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# neuron and simulation parameters
spike_grad = surrogate.atan()
beta = 0.5

#  Initialize Network
net = nn.Sequential(nn.Conv2d(2, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 32, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(32*5*5, 10),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)

# this time, we won't return membrane as we don't need it

def forward_pass(net, data):
  spk_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  for step in range(data.size(0)):  # data.size(0) = number of time steps
      spk_out, mem_out = net(data[step])
      spk_rec.append(spk_out)  # record the output spikes at this time step

  return torch.stack(spk_rec)
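The reason we reset the network and then step through time is that each leaky neuron carries a membrane potential from one step to the next. Here is a simplified, scalar sketch of that behaviour. It assumes a threshold of 1.0 and reset-by-subtraction, and the function name leaky_step is hypothetical; treat it as an illustration rather than the snn.Leaky source.

```python
def leaky_step(input_current, mem, beta=0.5, threshold=1.0):
    """One time step of a simplified leaky integrate-and-fire neuron."""
    reset = 1.0 if mem > threshold else 0.0               # did the previous step spike?
    mem = beta * mem + input_current - reset * threshold  # decay, integrate, reset
    spk = 1.0 if mem > threshold else 0.0                 # spike if threshold is crossed
    return spk, mem

mem = 0.0        # hidden state starts from zero, as after a reset
spikes = []
for current in [0.6, 0.6, 0.6, 0.6, 0.6]:
    spk, mem = leaky_step(current, mem)
    spikes.append(spk)

print(spikes)    # [0.0, 0.0, 1.0, 0.0, 0.0]
```

With beta = 0.5, half of the membrane potential decays away at every step, so a constant input of 0.6 needs three steps to push the neuron over threshold.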

2.2 Training

In the previous tutorial, Cross Entropy Loss was applied to the total spike count to maximize the number of spikes from the correct class.

Another option from the snntorch.functional module is to specify the target number of spikes from correct and incorrect classes. The approach below uses the Mean Square Error Spike Count Loss, which aims to elicit spikes from the correct class 80% of the time, and 20% of the time from incorrect classes. Encouraging incorrect neurons to fire can help to avoid dead neurons.

optimizer = torch.optim.Adam(net.parameters(), lr=2e-2, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
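To see what "80% of the time" means in terms of spike counts, here is a simplified NumPy sketch of a spike-count MSE. It only illustrates the idea: the function name mse_count_sketch is hypothetical, and the library's own loss may differ in details such as rounding of the targets or normalisation.

```python
import numpy as np

def mse_count_sketch(spk_rec, targets, correct_rate=0.8, incorrect_rate=0.2):
    """spk_rec: (time steps, batch, classes) array of 0/1; targets: (batch,) class indices."""
    num_steps, batch_size, num_classes = spk_rec.shape
    spike_count = spk_rec.sum(axis=0)                                  # (batch, classes)
    # incorrect classes should fire in 20% of steps, the correct class in 80%
    target_count = np.full((batch_size, num_classes), incorrect_rate * num_steps)
    target_count[np.arange(batch_size), targets] = correct_rate * num_steps
    return float(np.mean((spike_count - target_count) ** 2))

# Toy recording: 10 time steps, 1 sample, 2 classes, correct class is 0
targets = np.array([0])

ideal = np.zeros((10, 1, 2))
ideal[:8, 0, 0] = 1      # correct class fires 8 times out of 10
ideal[:2, 0, 1] = 1      # incorrect class fires 2 times out of 10

greedy = np.zeros((10, 1, 2))
greedy[:, 0, 0] = 1      # correct class fires every step, incorrect class never

loss_ideal = mse_count_sketch(ideal, targets)
loss_greedy = mse_count_sketch(greedy, targets)
print(loss_ideal, loss_greedy)
```

Note that firing at every single step is penalised too: the loss is lowest when the counts match the targets, not when the correct neuron fires as much as possible.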

Training on neuromorphic data is expensive as it requires sequentially iterating through many time steps (approximately 300 time steps in the N-MNIST dataset). The following simulation will take some time, so we will just stick to training across 50 iterations (which is roughly 1/10th of a full epoch). Feel free to change num_iters if you have more time to kill. As we are printing results at each iteration, the results will be quite noisy, and it will take some time before we start to see any sort of improvement.
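The "roughly 1/10th of an epoch" figure is just arithmetic on the dataset and batch sizes, using the 60,000 training samples of N-MNIST and the batch size of 128 defined earlier:

```python
import math

num_train_samples = 60_000   # size of the N-MNIST training split
batch_size = 128
num_iters = 50

iters_per_epoch = math.ceil(num_train_samples / batch_size)
fraction_of_epoch = num_iters / iters_per_epoch

print(iters_per_epoch, round(fraction_of_epoch, 3))   # 469 0.107
```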

In our own experiments, it took about 20 iterations before we saw any improvement, and after 50 iterations, the network managed to crack ~60% accuracy.

Warning: the following simulation will take a while. Go make yourself a coffee, or ten.

num_epochs = 1
num_iters = 50

loss_hist = []
acc_hist = []

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(iter(trainloader)):
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        spk_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")

        # training loop breaks once iteration num_iters has been reached
        if i == num_iters:
            break

The output should look something like this:

Epoch 0, Iteration 0
Train Loss: 31.00
Accuracy: 10.16%

Epoch 0, Iteration 1
Train Loss: 30.58
Accuracy: 13.28%

And after some more time:

Epoch 0, Iteration 49
Train Loss: 8.78
Accuracy: 47.66%

Epoch 0, Iteration 50
Train Loss: 8.43
Accuracy: 56.25%
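The accuracy printed above is based on spike counts: the predicted class is the output neuron that fired the most over the whole recording. A simplified NumPy sketch of that rule, with a hypothetical function name and toy data (not the SF.accuracy_rate source):

```python
import numpy as np

def accuracy_from_spikes(spk_rec, targets):
    """spk_rec: (time steps, batch, classes) array of 0/1; targets: (batch,) class indices."""
    spike_count = spk_rec.sum(axis=0)          # total spikes per sample and class
    predicted = spike_count.argmax(axis=1)     # the most active neuron wins
    return float((predicted == targets).mean())

# Toy recording: 4 time steps, 2 samples, 3 classes
spk = np.zeros((4, 2, 3))
spk[:3, 0, 2] = 1    # sample 0: class 2 fires 3 times
spk[:1, 0, 0] = 1    # sample 0: class 0 fires once
spk[:2, 1, 1] = 1    # sample 1: class 1 fires twice
spk[:3, 1, 0] = 1    # sample 1: class 0 fires 3 times

toy_targets = np.array([2, 1])
toy_acc = accuracy_from_spikes(spk, toy_targets)
print(toy_acc)       # 0.5
```

Since each value is computed on a single mini-batch of 128 samples, the accuracy moves in steps of 1/128 (56.25% is 72 correct out of 128), which is part of why the per-iteration numbers look noisy.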

3. Results

3.1 Plot Train Accuracy

import matplotlib.pyplot as plt

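# Plot accuracy recorded during training
fig = plt.figure(facecolor="w")
plt.plot(acc_hist)
plt.title("Train Set Accuracy")
plt.xlabel("Iteration")
plt.ylabel("Accuracy")
plt.show()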

3.2 Spike Counter

Run a forward pass on a batch of data to obtain spike recordings.

spk_rec = forward_pass(net, data)

Changing idx allows you to index into various samples from the simulated minibatch. Use splt.spike_count to explore the spiking behaviour of a few different samples. Generating the following animation will take some time.

Note: if you are running the notebook locally on your desktop, please uncomment the line below and modify the path to your ffmpeg.exe

from IPython.display import HTML

idx = 0

fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
labels=['0', '1', '2', '3', '4', '5', '6', '7', '8','9']
print(f"The target label is: {targets[idx]}")

# plt.rcParams['animation.ffmpeg_path'] = 'C:\\path\\to\\your\\ffmpeg.exe'

#  Plot spike count histogram
anim = splt.spike_count(spk_rec[:, idx].detach().cpu(), fig, ax, labels=labels,
                        animate=True, interpolate=1)

HTML(anim.to_html5_video())
# anim.save("spike_bar.mp4")

The target label is: 3


If you made it this far, then congratulations - you have the patience of a monk. You should now also understand how to load neuromorphic datasets using Tonic and then train a network using snnTorch.

This concludes the deep-dive tutorial series. Check out the advanced tutorials to learn more advanced techniques, such as introducing long-term temporal dynamics into our SNNs, population coding, or accelerating on Intelligence Processing Units.

If you like this project, please consider starring ⭐ the repo on GitHub as it is the easiest and best way to support it.

Additional Resources