===========================================================
Tutorial 5 - Training Spiking Neural Networks with snntorch
===========================================================

Tutorial written by Jason K. Eshraghian (`www.ncg.ucsc.edu `_)

.. image:: https://colab.research.google.com/assets/colab-badge.svg
   :alt: Open In Colab
   :target: https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_5_FCN.ipynb

The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:

`Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. “Training Spiking Neural Networks Using Lessons From Deep Learning”. arXiv preprint arXiv:2109.12894, September 2021. `_

.. note::
   This tutorial is a static non-editable version. Interactive, editable versions are available via the following links:

   * `Google Colab `_
   * `Local Notebook (download via GitHub) `_

Introduction
---------------

In this tutorial, you will:

* Learn how spiking neurons are implemented as a recurrent network
* Understand backpropagation through time, and the associated challenges in SNNs such as the non-differentiability of spikes
* Train a fully-connected network on the static MNIST dataset

..

   Part of this tutorial was inspired by Friedemann Zenke’s extensive work on SNNs. Check out his repo on surrogate gradients `here `__, and a favourite paper of mine: E. O. Neftci, H. Mostafa, F. Zenke, `Surrogate Gradient Learning in Spiking Neural Networks: Bringing the Power of Gradient-based optimization to spiking neural networks. `__ IEEE Signal Processing Magazine 36, 51–63.

At the end of the tutorial, a basic supervised learning algorithm will be implemented.
We will use the original static MNIST dataset and train a multi-layer fully-connected spiking neural network using gradient descent to perform image classification.

Install the latest PyPI distribution of snnTorch:

::

    $ pip install snntorch

::

    # imports
    import snntorch as snn
    from snntorch import spikeplot as splt
    from snntorch import spikegen

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    import matplotlib.pyplot as plt
    import numpy as np
    import itertools

1. A Recurrent Representation of SNNs
----------------------------------------

In `Tutorial 3 `_, we derived a recursive representation of a leaky integrate-and-fire (LIF) neuron:

.. math::

   U[t+1] = \underbrace{\beta U[t]}_\text{decay} + \underbrace{WX[t+1]}_\text{input} - \underbrace{R[t]}_\text{reset} \tag{1}

where input synaptic current is interpreted as :math:`I_{\rm in}[t] = WX[t]`, and :math:`X[t]` may be some arbitrary input of spikes, a step/time-varying voltage, or unweighted step/time-varying current. Spiking is represented with the following equation, where a spike is emitted if the membrane potential exceeds the threshold:

.. math::

   S[t] = \begin{cases} 1, &\text{if}~U[t] > U_{\rm thr} \\ 0, &\text{otherwise}\end{cases} \tag{2}

This formulation of a spiking neuron in a discrete, recursive form is almost perfectly poised to take advantage of the developments in training recurrent neural networks (RNNs) and sequence-based models. This is illustrated using an *implicit* recurrent connection for the decay of the membrane potential, which is distinguished from *explicit* recurrence, where the output spike :math:`S_{\rm out}` is fed back to the input. In the figure below, the connection weighted by :math:`-U_{\rm thr}` represents the reset mechanism :math:`R[t]`.

.. image:: https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial5/unrolled_2.png?raw=true
   :align: center
   :width: 600

The benefit of an unrolled graph is that it provides an explicit description of how computations are performed. The process of unfolding illustrates the flow of information forward in time (from left to right) to compute outputs and losses, and backward in time to compute gradients. The more time steps that are simulated, the deeper the graph becomes.

Conventional RNNs treat :math:`\beta` as a learnable parameter. This is also possible for SNNs, though by default it is treated as a hyperparameter. This replaces the vanishing and exploding gradient problems with a hyperparameter search. A future tutorial will describe how to make :math:`\beta` a learnable parameter.

2. The Non-Differentiability of Spikes
-----------------------------------------

2.1 Training Using the Backprop Algorithm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

An alternative way to represent the relationship between :math:`S` and :math:`U` in :math:`(2)` is:

.. math::

   S[t] = \Theta(U[t] - U_{\rm thr}) \tag{3}

where :math:`\Theta(\cdot)` is the Heaviside step function:

.. image:: https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial3/3_2_spike_descrip.png?raw=true
   :align: center
   :width: 600

Training a network in this form poses some serious challenges. Consider a single, isolated time step of the computational graph from the previous figure titled *"Recurrent representation of spiking neurons"*, as shown in the *forward pass* below:

.. image:: https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial5/non-diff.png?raw=true
   :align: center
   :width: 400

The goal is to train the network using the gradient of the loss with respect to the weights, such that the weights are updated to minimize the loss. The backpropagation algorithm achieves this using the chain rule:

.. math::

   \frac{\partial \mathcal{L}}{\partial W} = \frac{\partial \mathcal{L}}{\partial S} \underbrace{\frac{\partial S}{\partial U}}_{\{0, \infty\}} \frac{\partial U}{\partial I} \frac{\partial I}{\partial W} \tag{4}

From :math:`(1)`, :math:`\partial I/\partial W=X`, and :math:`\partial U/\partial I=1`. While a loss function is yet to be defined, we can assume :math:`\partial \mathcal{L}/\partial S` has an analytical solution, in a similar form to the cross-entropy or mean-square error loss (more on that shortly).

However, the term that we are going to grapple with is :math:`\partial S/\partial U`. The derivative of the Heaviside step function from :math:`(3)` is the Dirac Delta function, which evaluates to :math:`0` everywhere except at the threshold :math:`U = U_{\rm thr}`, where it tends to infinity. This means the gradient will almost always be nulled to zero (or saturated if :math:`U` sits precisely at the threshold), and no learning can take place. This is known as the **dead neuron problem**.

2.2 Overcoming the Dead Neuron Problem
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The most common way to address the dead neuron problem is to keep the Heaviside function as it is during the forward pass, but, during the backward pass, swap the derivative term :math:`\partial S/\partial U` for something that does not kill the learning process, denoted :math:`\partial \tilde{S}/\partial U`. This might sound odd, but it turns out that neural networks are quite robust to such approximations. This is commonly known as the *surrogate gradient* approach.

A variety of options exist for using surrogate gradients, and we will dive into more detail on these methods in `Tutorial 6 `_. The default method in snnTorch (as of v0.6.0) is to smooth the Heaviside function with the arctangent function. The backward-pass derivative used is:

.. math::

   \frac{\partial \tilde{S}}{\partial U} \leftarrow \frac{1}{1+[\pi(U - U_{\rm thr})]^2}

where the left arrow denotes substitution.

The same neuron model described in :math:`(1)-(2)` (a.k.a., the ``snn.Leaky`` neuron from Tutorial 3) is implemented in PyTorch below. Don’t worry if you don’t understand this. This will be condensed into one line of code using snnTorch in a moment:

::

    # Leaky neuron model, overriding the backward pass with a custom function
    class LeakySurrogate(nn.Module):
        def __init__(self, beta, threshold=1.0):
            super().__init__()

            # initialize decay rate beta and threshold
            self.beta = beta
            self.threshold = threshold
            self.spike_gradient = self.ATan.apply

        # the forward function is called each time we call the neuron
        def forward(self, input_, mem):
            spk = self.spike_gradient(mem - self.threshold)  # call the Heaviside function
            reset = (spk * self.threshold).detach()  # R[t] = S[t] * U_thr, removed from the computational graph
            mem = self.beta * mem + input_ - reset  # Eq (1)
            return spk, mem

        # Forward pass: Heaviside function
        # Backward pass: Override Dirac Delta with the derivative of the ArcTan function
        class ATan(torch.autograd.Function):
            @staticmethod
            def forward(ctx, mem):
                spk = (mem > 0).float()  # Heaviside on the forward pass: Eq (2)
                ctx.save_for_backward(mem)  # store the membrane for use in the backward pass
                return spk

            @staticmethod
            def backward(ctx, grad_output):
                (mem,) = ctx.saved_tensors  # retrieve the (shifted) membrane potential
                grad = 1 / (1 + (np.pi * mem).pow(2)) * grad_output  # ArcTan surrogate gradient
                return grad

Note that the reset mechanism is detached from the computational graph, as the surrogate gradient should only be applied to :math:`\partial S/\partial U`, and not :math:`\partial R/\partial U`.

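
The surrogate substitution can also be checked in isolation. The following is a minimal sketch, not snnTorch's implementation: ``ATanSpike`` and ``surrogate_grad_at`` are hypothetical names, the threshold of 1.0 is an assumed value, and the backward pass uses :math:`1/(1+[\pi(U-U_{\rm thr})]^2)`, the same expression as in ``LeakySurrogate.ATan.backward``. It shows that a plain comparison gives autograd nothing to differentiate, while the custom function passes back a nonzero gradient that peaks at the threshold:

::

    import math

    import torch


    class ATanSpike(torch.autograd.Function):
        """Hypothetical sketch: Heaviside forward pass, ArcTan-derivative backward pass."""

        @staticmethod
        def forward(ctx, mem_shift):
            # mem_shift = U - U_thr; emit a spike wherever it is positive (Eq. 3)
            ctx.save_for_backward(mem_shift)
            return (mem_shift > 0).float()

        @staticmethod
        def backward(ctx, grad_output):
            (mem_shift,) = ctx.saved_tensors
            # substitute dS/dU with 1 / (1 + (pi * (U - U_thr))^2)
            return grad_output / (1 + (math.pi * mem_shift) ** 2)


    def surrogate_grad_at(u, threshold=1.0):
        """Hypothetical helper: return spikes and dS~/dU for a list of membrane potentials."""
        mem = torch.tensor(u, requires_grad=True)
        spk = ATanSpike.apply(mem - threshold)
        spk.sum().backward()  # every neuron contributes dL/dS = 1
        return spk.detach(), mem.grad


    # A plain comparison is not tracked by autograd, so no gradient can reach U
    u_demo = torch.tensor([0.5, 1.5], requires_grad=True)
    hard_spk = (u_demo - 1.0 > 0).float()
    print(hard_spk.requires_grad)  # False

    spk, grad = surrogate_grad_at([0.0, 0.9, 1.0, 1.1, 2.0])
    print(spk)   # tensor([0., 0., 0., 1., 1.])
    print(grad)  # largest at U = U_thr, decaying symmetrically on either side

Because ``spk.sum()`` sets :math:`\partial \mathcal{L}/\partial S = 1` for every neuron, ``mem.grad`` holds the surrogate derivative itself, while the spikes are identical to those of a hard threshold.
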
The ``LeakySurrogate`` neuron above is instantiated using:

::

    lif1 = LeakySurrogate(beta=0.9)

This neuron can be simulated using a for-loop, just as in previous tutorials, while PyTorch’s automatic differentiation (autodiff) mechanism keeps track of the gradient in the background.

The same thing can be accomplished by calling the ``snn.Leaky`` neuron. In fact, every time you call any neuron model from snnTorch, the *ATan* surrogate gradient is applied to it by default:

::

    lif1 = snn.Leaky(beta=0.9)

If you would like to explore how this neuron behaves, then refer to `Tutorial 3 `__.

3. Backprop Through Time
--------------------------

Equation :math:`(4)` only calculates the gradient for one single time step (referred to as the *immediate influence* in the figure below), but the backpropagation through time (BPTT) algorithm calculates the gradient from the loss to *all* descendants and sums them together.

The weight :math:`W` is applied at every time step, and so imagine a loss is also calculated at every time step. The influence of the weight on present and historical losses must be summed together to define the global gradient:

.. math::

   \frac{\partial \mathcal{L}}{\partial W}=\sum_t \frac{\partial\mathcal{L}[t]}{\partial W} = \sum_t \sum_{s\leq t} \frac{\partial\mathcal{L}[t]}{\partial W[s]}\frac{\partial W[s]}{\partial W} \tag{5}

The point of :math:`(5)` is to ensure causality: by constraining :math:`s\leq t`, we only account for the contribution of immediate and prior influences of :math:`W` on the loss. A recurrent system constrains the weight to be shared across all steps: :math:`W[0]=W[1] =~... ~ = W`. Therefore, a change in :math:`W[s]` will have the same effect on all :math:`W`, which implies that :math:`\partial W[s]/\partial W=1`:

.. math::

   \frac{\partial \mathcal{L}}{\partial W}= \sum_t \sum_{s\leq t} \frac{\partial\mathcal{L}[t]}{\partial W[s]} \tag{6}

As an example, isolate the prior influence due to :math:`s = t-1` *only*; this means the backward pass must track back in time by one step. The influence of :math:`W[t-1]` on the loss can be written as:

.. math::

   \frac{\partial \mathcal{L}[t]}{\partial W[t-1]} = \frac{\partial \mathcal{L}[t]}{\partial S[t]} \underbrace{\frac{\partial \tilde{S}[t]}{\partial U[t]}}_\text{surrogate} \underbrace{\frac{\partial U[t]}{\partial U[t-1]}}_\beta \underbrace{\frac{\partial U[t-1]}{\partial I[t-1]}}_1 \underbrace{\frac{\partial I[t-1]}{\partial W[t-1]}}_{X[t-1]} \tag{7}

We have already dealt with all of these terms from :math:`(4)`, except for :math:`\partial U[t]/\partial U[t-1]`. From :math:`(1)`, this temporal derivative term simply evaluates to :math:`\beta`. So if we really wanted to, we now know enough to painstakingly calculate the derivative of every weight at every time step by hand, and it’d look something like this for a single neuron:

.. image:: https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial5/bptt.png?raw=true
   :align: center
   :width: 600

But thankfully, PyTorch’s autodiff takes care of that in the background for us.

.. note::
   The reset mechanism has been omitted from the above figure. In snnTorch, reset is included in the forward pass, but detached from the backward pass.

4. Setting up the Loss / Output Decoding
-------------------------------------------

In a conventional, non-spiking neural network, a supervised, multi-class classification problem takes the neuron with the highest activation and treats that as the predicted class. In a spiking neural net, there are several options for interpreting the output spikes.
The most common approaches are:

* **Rate coding:** Take the neuron with the highest firing rate (or spike count) as the predicted class
* **Latency coding:** Take the neuron that fires *first* as the predicted class

This might feel familiar from `Tutorial 1 on neural encoding `__. The difference is that, here, we are interpreting (decoding) the output spikes, rather than encoding/converting raw input data into spikes.

Let’s focus on a rate code. When input data is passed to the network, we want the correct neuron class to emit the most spikes over the course of the simulation run. This corresponds to the highest average firing frequency. One way to achieve this is to increase the membrane potential of the correct class to :math:`U>U_{\rm thr}`, and that of incorrect classes to :math:`U