.. container:: cell markdown

   .. rubric:: snnTorch - Spiking Autoencoder (SAE) using Convolutional Spiking Neural Networks
      :name: snntorch---spiking-autoencoder-sae-using-convolutional-spiking-neural-networks

   .. rubric:: Tutorial by Alon Loeffler (www.alonloeffler.com)
      :name: tutorial-by-alon-loeffler-wwwalonloefflercom

   \*This tutorial is adapted from my original article published on Medium.com

.. container:: cell markdown

   For a comprehensive overview of how SNNs work and what is going on under the hood, you might be interested in the snnTorch tutorial series.

   The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:

      Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. “Training Spiking Neural Networks Using Lessons From Deep Learning”. Proceedings of the IEEE, 111(9) September 2023.

.. container:: cell markdown

   In this tutorial, you will learn how to use snnTorch to:

   - Create a spiking autoencoder
   - Reconstruct MNIST images

   If running in Google Colab:

   - You may connect to GPU by checking ``Runtime`` > ``Change runtime type`` > ``Hardware accelerator: GPU``

.. container:: cell markdown

   .. rubric:: 1. Autoencoders
      :name: 1-autoencoders

   An autoencoder is a neural network that is trained to reconstruct its input data. It consists of two main components:

   1) An encoder
   2) A decoder

   The encoder takes in input data (e.g. an image) and maps it to a lower-dimensional latent space. For example, an encoder might take as input a 28 x 28 pixel MNIST image (784 pixels total), and extract the important features from the image while compressing it to a smaller dimensionality (e.g. 32 features). This compressed representation of the image is called the *latent representation*.
The decoder maps the latent representation back to the original input space (i.e. from 32 features back to 784 pixels), and tries to reconstruct the original image from a small number of key features.
Example of a simple Autoencoder where x is the input data, z is the encoded latent space, and x' is the reconstructed inputs once z is decoded (source: Wikipedia).
The goal of the autoencoder is to minimize the reconstruction error between the input data and the output of the decoder. This is achieved by training the model to minimize the reconstruction loss, which is typically defined as the mean squared error (MSE) between the input and the reconstructed output.
MSE loss equation. Here, :math:`y` represents the original image (y true) and :math:`\hat{y}` represents the reconstructed output (y pred) (source: Towards Data Science).
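For reference, the MSE loss described in the caption can be written out as below. This is the standard textbook definition rather than a formula recovered from the original figure, and the symbol :math:`n` (the number of pixels being compared) is an assumption introduced here.

.. math::

   \mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2

Each term is the squared difference between one pixel of the original image and the same pixel of the reconstruction, so a perfect reconstruction gives a loss of zero.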
Autoencoders are excellent tools for reducing noise in data: they keep only the important parts of the data and discard everything else during the reconstruction process. This makes them an effective dimensionality reduction tool.

.. container:: cell markdown

   In this tutorial (similar to tutorial 1), we will assume we have some non-spiking input data (i.e., the MNIST dataset) and that we want to encode it and reconstruct it. So let's get started!

.. container:: cell markdown

   .. rubric:: 2. Setting Up
      :name: 2-setting-up

.. container:: cell markdown

   .. rubric:: 2.1 Install/Import packages and set up environment
      :name: 21-installimport-packages-and-set-up-environment

.. container:: cell markdown

   To start, we need to install snnTorch and its dependencies (note that this tutorial assumes you already have PyTorch and torchvision installed; both come preinstalled in Colab). You can do this by running the following command:

.. container:: cell code

   .. code:: python

      !pip install snntorch

.. container:: cell markdown

   Next, let’s import the necessary modules and set up the SAE model.

   We can use PyTorch to define the encoder and decoder networks, and snnTorch to convert the neurons in the networks into leaky integrate-and-fire (LIF) neurons, which read in and output spikes.

   We will be using convolutional neural networks (CNN), covered in tutorial 6, as the basis of our encoder and decoder.

.. container:: cell code

   .. code:: python

      import os

      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      from torchvision import datasets, transforms
      from torch.utils.data import DataLoader
      from torchvision import utils as utls

      import snntorch as snn
      from snntorch import utils
      from snntorch import surrogate

      import numpy as np

      #Define the SAE model:
      class SAE(nn.Module):
          def __init__(self,latent_dim):
              super().__init__()
              self.latent_dim = latent_dim #dimensions of the encoded z-space data

.. container:: cell markdown

   .. rubric:: 3. Building the Autoencoder
      :name: 3-building-the-autoencoder

.. container:: cell markdown

   .. rubric:: 3.1 DataLoaders
      :name: 31-dataloaders

   We will be using the MNIST dataset.

.. container:: cell code

   .. code:: python

      # dataloader arguments
      batch_size = 250
      data_path='/tmp/data/mnist'

      dtype = torch.float
      device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

.. container:: cell code

   .. code:: python

      # Define a transform
      input_size = 32 #for the sake of this tutorial, we will be resizing the original MNIST from 28 to 32

      transform = transforms.Compose([
          transforms.Resize((input_size, input_size)),
          transforms.Grayscale(),
          transforms.ToTensor(),
          transforms.Normalize((0,), (1,))])

      # Load MNIST

      # Training data
      train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transform, download=True)
      train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

      # Testing data
      test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transform, download=True)
      test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

.. container:: cell markdown

   .. rubric:: 3.2 The Encoder
      :name: 32-the-encoder

   Let's build the autoencoder section by section, gradually combining the pieces into the SAE model we defined above:

.. container:: cell markdown

   First, let's add an encoder with three convolutional layers (``nn.Conv2d``), and one fully-connected linear output layer.

   - We will use a kernel of size 3, with padding of 1 and stride of 2 for the CNN hyperparameters.
   - We also add a Batch Norm layer between convolutional layers. Since we will be using the neuron membrane potential as the output of each neuron, normalization will help our training process.

.. container:: cell code .. 
code:: python

      #Define the SAE model:
      class SAE(nn.Module):
          def __init__(self,latent_dim):
              super().__init__()
              self.latent_dim = latent_dim #dimensions of the encoded z-space data

              # Encoder
              self.encoder = nn.Sequential(
                  nn.Conv2d(1, 32, 3,padding = 1,stride=2), # Conv Layer 1
                  nn.BatchNorm2d(32),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh), #SNN TORCH LIF NEURON
                  nn.Conv2d(32, 64, 3,padding = 1,stride=2), # Conv Layer 2
                  nn.BatchNorm2d(64),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.Conv2d(64, 128, 3,padding = 1,stride=2), # Conv Layer 3
                  nn.BatchNorm2d(128),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.Flatten(start_dim = 1, end_dim = 3), #Flatten convolutional output
                  nn.Linear(128*4*4, latent_dim), # Fully connected linear layer
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)
                  )

.. container:: cell markdown

   .. rubric:: 3.3 The Decoder
      :name: 33-the-decoder

.. container:: cell markdown

   Before we write the decoder, there is one more small step required. When decoding the latent z-space data, we need to move from the flattened encoded representation (``latent_dim``) back to a tensor representation to use in transposed convolution.

   To do so, we need to run an additional fully-connected linear layer that expands the data back to 128 x 4 x 4 = 2048 features.

.. container:: cell code .. 
code:: python

      #Define the SAE model:
      class SAE(nn.Module):
          def __init__(self,latent_dim):
              super().__init__()
              self.latent_dim = latent_dim #dimensions of the encoded z-space data

              # Encoder
              self.encoder = nn.Sequential(
                  nn.Conv2d(1, 32, 3,padding = 1,stride=2), # Conv Layer 1
                  nn.BatchNorm2d(32),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh), #SNN TORCH LIF NEURON
                  nn.Conv2d(32, 64, 3,padding = 1,stride=2), # Conv Layer 2
                  nn.BatchNorm2d(64),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.Conv2d(64, 128, 3,padding = 1,stride=2), # Conv Layer 3
                  nn.BatchNorm2d(128),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.Flatten(start_dim = 1, end_dim = 3), #Flatten convolutional output
                  nn.Linear(128*4*4, latent_dim), # Fully connected linear layer
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)
                  )

              # From latent back to tensor for convolution
              self.linearNet = nn.Sequential(
                  nn.Linear(latent_dim,128*4*4),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh))

.. container:: cell markdown

   Now we can write the decoder, with three transposed convolutional (``nn.ConvTranspose2d``) layers that follow the linear layer we just added.

   Although the linear layer expands the latent data back to the size needed for convolution, its output is still 1-dimensional, so we need to unflatten it to a tensor of 128 x 4 x 4. This is done using ``nn.Unflatten`` in the first line of the decoder.

.. container:: cell code .. 
code:: python

      #Define the SAE model:
      class SAE(nn.Module):
          def __init__(self,latent_dim):
              super().__init__()
              self.latent_dim = latent_dim #dimensions of the encoded z-space data

              # Encoder
              self.encoder = nn.Sequential(
                  nn.Conv2d(1, 32, 3,padding = 1,stride=2), # Conv Layer 1
                  nn.BatchNorm2d(32),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh), #SNN TORCH LIF NEURON
                  nn.Conv2d(32, 64, 3,padding = 1,stride=2), # Conv Layer 2
                  nn.BatchNorm2d(64),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.Conv2d(64, 128, 3,padding = 1,stride=2), # Conv Layer 3
                  nn.BatchNorm2d(128),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.Flatten(start_dim = 1, end_dim = 3), #Flatten convolutional output
                  nn.Linear(128*4*4, latent_dim), # Fully connected linear layer
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)
                  )

              # From latent back to tensor for convolution
              self.linearNet = nn.Sequential(
                  nn.Linear(latent_dim,128*4*4),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh))

              # Decoder
              self.decoder = nn.Sequential(
                  nn.Unflatten(1,(128,4,4)), #Unflatten data from 1 dim to tensor of 128 x 4 x 4
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.ConvTranspose2d(128, 64, 3,padding = 1,stride=(2,2),output_padding=1),
                  nn.BatchNorm2d(64),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.ConvTranspose2d(64, 32, 3,padding = 1,stride=(2,2),output_padding=1),
                  nn.BatchNorm2d(32),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.ConvTranspose2d(32, 1, 3,padding = 1,stride=(2,2),output_padding=1),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,output=True,threshold=20000) #make large so membrane can be trained
                  )

.. 
container:: cell markdown

   One important thing to note is that in the final Leaky layer, the spiking threshold (``threshold``) is set extremely high (20000). This is a neat trick in snnTorch, which allows the neuron membrane in the final layer to be updated continuously without ever reaching the spiking threshold.

   The output of each Leaky neuron consists of a tensor of spikes (0 or 1) and a tensor of neuron membrane potentials (negative or positive real numbers). snnTorch allows us to use either the spikes or the membrane potential of each neuron in training. We will be using the membrane potential output from the final layer for the image reconstruction.

.. container:: cell markdown

   .. rubric:: 3.4 Forward Function
      :name: 34-forward-function

   Finally, let’s write the forward, encode and decode functions, before putting it all together.

.. container:: cell code

   .. code:: python

      def forward(self, x):
          utils.reset(self.encoder) #need to reset the hidden states of LIF
          utils.reset(self.decoder)
          utils.reset(self.linearNet)

          #encode
          spk_mem=[];spk_rec=[];encoded_x=[]
          for step in range(num_steps): #for t in time
              spk_x,mem_x=self.encode(x) #Output spike trains and neuron membrane states
              spk_rec.append(spk_x)
              spk_mem.append(mem_x)
          spk_rec=torch.stack(spk_rec,dim=2) # stack spikes along a new time dimension: [Batch,latent_dim,Time]
          spk_mem=torch.stack(spk_mem,dim=2) # stack membranes along a new time dimension: [Batch,latent_dim,Time]

          #decode
          spk_mem2=[];spk_rec2=[];decoded_x=[]
          for step in range(num_steps): #for t in time
              x_recon,x_mem_recon=self.decode(spk_rec[...,step])
              spk_rec2.append(x_recon)
              spk_mem2.append(x_mem_recon)
          spk_rec2=torch.stack(spk_rec2,dim=4)
          spk_mem2=torch.stack(spk_mem2,dim=4)
          out = spk_mem2[:,:,:,:,-1] #return the membrane potential of the output neuron at t = -1 (last t)
          return out

      def encode(self,x):
          spk_latent_x,mem_latent_x=self.encoder(x)
          return spk_latent_x,mem_latent_x

      def decode(self,x):
          spk_x,mem_x = self.linearNet(x) #convert latent dimension back to total size of features in encoder final layer
          spk_x2,mem_x2=self.decoder(spk_x)
          return spk_x2,mem_x2

.. container:: cell markdown

   There are a couple of key things to notice here:

   1) At the beginning of each call of our forward function, we need to reset the hidden states of each LIF neuron. If we do not do this, PyTorch will raise gradient errors when we try to backpropagate. To do so we use ``utils.reset``.

   2) In the forward function, we call the encode and decode functions in a loop. This is because we are converting static images into spike trains. A spike train needs a time axis, :math:`t`, along which a spike can either occur or not occur at each step. Therefore, we encode and decode the original image :math:`t` (or ``num_steps``) times, to create a latent representation, :math:`z`.

.. container:: cell markdown

   For example, converting a sample digit 7 from the MNIST dataset into a spike train with a latent dimension of 32 and t = 50 might look like this:

   Spike train of a sample MNIST digit 7 after encoding. Other instances of 7 will have slightly different spike trains, and different digits will have even more different spike trains.

.. container:: cell markdown

   .. rubric:: 3.5 Putting it all together:
      :name: 35-putting-it-all-together

   Our final, complete SAE class should look like this:

.. container:: cell code .. 
code:: python

      class SAE(nn.Module):
          def __init__(self):
              super().__init__()

              #Encoder
              self.encoder = nn.Sequential(
                  nn.Conv2d(1, 32, 3,padding = 1,stride=2),
                  nn.BatchNorm2d(32),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.Conv2d(32, 64, 3,padding = 1,stride=2),
                  nn.BatchNorm2d(64),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.Conv2d(64, 128, 3,padding = 1,stride=2),
                  nn.BatchNorm2d(128),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.Flatten(start_dim = 1, end_dim = 3),
                  nn.Linear(2048, latent_dim), #this needs to be the final layer output size (channels * pixels * pixels)
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)
                  )

              # From latent back to tensor for convolution
              self.linearNet = nn.Sequential(
                  nn.Linear(latent_dim,128*4*4),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh))

              #Decoder
              self.decoder = nn.Sequential(
                  nn.Unflatten(1,(128,4,4)),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.ConvTranspose2d(128, 64, 3,padding = 1,stride=(2,2),output_padding=1),
                  nn.BatchNorm2d(64),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.ConvTranspose2d(64, 32, 3,padding = 1,stride=(2,2),output_padding=1),
                  nn.BatchNorm2d(32),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh),
                  nn.ConvTranspose2d(32, 1, 3,padding = 1,stride=(2,2),output_padding=1),
                  snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,output=True,threshold=20000) #make large so membrane can be trained
                  )

          def forward(self, x): #Dimensions: [Batch,Channels,Width,Length]
              utils.reset(self.encoder) #need to reset the hidden states of LIF
              utils.reset(self.decoder)
              utils.reset(self.linearNet)

              #encode
              spk_mem=[];spk_rec=[];encoded_x=[]
              for step in range(num_steps): #for t in time
                  spk_x,mem_x=self.encode(x) #Output spike trains and neuron membrane states
                  spk_rec.append(spk_x)
                  spk_mem.append(mem_x)
              spk_rec=torch.stack(spk_rec,dim=2)
              spk_mem=torch.stack(spk_mem,dim=2) #Dimensions:[Batch,latent_dim,Time]

              #decode
              spk_mem2=[];spk_rec2=[];decoded_x=[]
              for step in range(num_steps): #for t in time
                  x_recon,x_mem_recon=self.decode(spk_rec[...,step])
                  spk_rec2.append(x_recon)
                  spk_mem2.append(x_mem_recon)
              spk_rec2=torch.stack(spk_rec2,dim=4)
              spk_mem2=torch.stack(spk_mem2,dim=4) #Dimensions:[Batch,Channels,Width,Length,Time]
              out = spk_mem2[:,:,:,:,-1] #return the membrane potential of the output neuron at t = -1 (last t)
              return out #Dimensions:[Batch,Channels,Width,Length]

          def encode(self,x):
              spk_latent_x,mem_latent_x=self.encoder(x)
              return spk_latent_x,mem_latent_x

          def decode(self,x):
              spk_x,mem_x = self.linearNet(x) #convert latent dimension back to total size of features in encoder final layer
              spk_x2,mem_x2=self.decoder(spk_x)
              return spk_x2,mem_x2

.. container:: cell markdown

   .. rubric:: 4. Training and Testing
      :name: 4-training-and-testing

   Finally, we can move on to training our SAE, and testing its usefulness. We have already loaded the MNIST dataset, and split it into training and testing sets.

.. container:: cell markdown

   .. rubric:: 4.1 Training Function
      :name: 41-training-function

   We define our training function, which takes in the network model, training dataloader, optimizer and epoch number as inputs, and returns the loss value of the final batch after running all batches of the current epoch.

   As discussed at the beginning, we will be using MSE loss to compare the reconstructed image (``x_recon``) with the original image (``real_img``).

   As always, we clear the gradients with ``opti.zero_grad()``, and then call ``loss_val.backward()`` and ``opti.step()`` to perform backprop and update the weights.

.. container:: cell code .. 
code:: python

      #Training
      def train(network, trainloader, opti, epoch):

          network=network.train()
          train_loss_hist=[]
          for batch_idx, (real_img, labels) in enumerate(trainloader):
              opti.zero_grad()
              real_img = real_img.to(device)
              labels = labels.to(device)

              #Pass data into network, and return reconstructed image from Membrane Potential at t = -1
              x_recon = network(real_img) #Dimensions passed in: [Batch_size,Channels,Image_Width,Image_Length]

              #Calculate loss
              loss_val = F.mse_loss(x_recon, real_img)

              print(f'Train[{epoch}/{max_epoch}][{batch_idx}/{len(trainloader)}] Loss: {loss_val.item()}')

              loss_val.backward()
              opti.step()

              #Save reconstructed images at the end of each epoch
              if batch_idx == len(trainloader)-1:
                  # NOTE: you need to create training/ and testing/ folders in your chosen path
                  utls.save_image((real_img+1)/2, f'figures/training/epoch{epoch}_finalbatch_inputs.png')
                  utls.save_image((x_recon+1)/2, f'figures/training/epoch{epoch}_finalbatch_recon.png')
          return loss_val

.. container:: cell markdown

   .. rubric:: 4.2 Testing Function
      :name: 42-testing-function

   The testing function is nearly identical to the training function, except that we do not backpropagate. No gradients are required, so we use ``torch.no_grad()``.

.. container:: cell code

   .. code:: python

      #Testing
      def test(network, testloader, opti, epoch):
          network=network.eval()
          test_loss_hist=[]
          with torch.no_grad(): #no gradient this time
              for batch_idx, (real_img, labels) in enumerate(testloader):
                  real_img = real_img.to(device)
                  labels = labels.to(device)
                  x_recon = network(real_img)

                  loss_val = F.mse_loss(x_recon, real_img)

                  print(f'Test[{epoch}/{max_epoch}][{batch_idx}/{len(testloader)}] Loss: {loss_val.item()}')

                  if batch_idx == len(testloader)-1:
                      utls.save_image((real_img+1)/2, f'figures/testing/epoch{epoch}_finalbatch_inputs.png')
                      utls.save_image((x_recon+1)/2, f'figures/testing/epoch{epoch}_finalbatch_recons.png')
          return loss_val

.. 
container:: cell markdown

   There are a couple of ways to calculate loss with spiking neural networks. Here, we are simply taking the membrane potential of the final layer of neurons at the last time step (:math:`t = 5`).

   Therefore, we only need to compare each original image with its corresponding decoded, reconstructed image once per batch. We can also return the membrane potentials at each time step, create :math:`t` different versions of the reconstructed image, compare each of them with the original image, and take the average loss. If you are interested in this, you can replace the loss function above with something like this:

   (*note this will fail to run as we have not defined any of the variables yet, it is just here for illustrative purposes*)

.. container:: cell code

   .. code:: python

      # Illustrative only: assumes the forward pass is changed to return the membrane
      # potential at every time step, i.e. x_recon is [Batch,Channels,Width,Length,Time]
      train_loss_hist=[]
      loss_val = torch.zeros((1), dtype=dtype, device=device)
      for step in range(num_steps):
          loss_val += F.mse_loss(x_recon[..., step], real_img)
      train_loss_hist.append(loss_val.item())
      avg_loss=loss_val/num_steps

   .. container:: output error

      ::

         NameError: name 'x_recon' is not defined

.. container:: cell markdown

   .. rubric:: 5. Conclusion: Running the SAE
      :name: 5-conclusion-running-the-sae

   Now, finally, we can run our SAE model. Let’s define some parameters, and run training and testing.

.. container:: cell markdown

   Let's create directories where we can save our original and reconstructed images for training and testing:

.. container:: cell code .. 
code:: python

      # create training/ and testing/ folders in your chosen path
      if not os.path.isdir('figures/training'):
          os.makedirs('figures/training')

      if not os.path.isdir('figures/testing'):
          os.makedirs('figures/testing')

.. container:: cell code

   .. code:: python

      # dataloader arguments
      batch_size = 250
      input_size = 32 #resize of mnist data (optional)

      #setup GPU
      dtype = torch.float
      device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

      # neuron and simulation parameters
      spike_grad = surrogate.atan(alpha=2.0) # alternate surrogate gradient fast_sigmoid(slope=25)
      beta = 0.5 #decay rate of neurons
      num_steps=5
      latent_dim = 32 #dimension of latent layer (how compressed we want the information)
      thresh=1 #spiking threshold (lower = more spikes are let through)
      epochs=10
      max_epoch=epochs

      #Define Network and optimizer
      net=SAE()
      net = net.to(device)

      optimizer = torch.optim.AdamW(net.parameters(),
                                    lr=0.0001,
                                    betas=(0.9, 0.999),
                                    weight_decay=0.001)

      #Run training and testing
      for e in range(epochs):
          train_loss = train(net, train_loader, optimizer, e)
          test_loss = test(net,test_loader,optimizer,e)

   .. container:: output stream stdout

      ::

         Train[0/10][0/240] Loss: 0.10109379142522812
         Train[0/10][1/240] Loss: 0.10465191304683685

.. container:: cell markdown

   After only 10 epochs, our training and testing reconstruction losses should be around 0.05. The input images and their reconstructions are saved in ``figures/training`` and ``figures/testing``, so you can compare them side by side.

.. container:: cell markdown

   Yes, the reconstructed images are a bit blurry, and the loss isn’t perfect, but from only 10 epochs, and only using the final membrane potential at :math:`t = 5` for our reconstruction loss, it’s a pretty decent start!

.. container:: cell markdown

   Try increasing the number of epochs, or playing around with ``thresh``, ``num_steps`` and ``batch_size`` to see if you can get a lower loss!
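To build some intuition for what ``thresh`` and ``beta`` do before changing them, the sketch below simulates a single leaky integrate-and-fire neuron in plain Python. It is a simplified illustration that assumes reset-by-subtraction, not snnTorch's actual ``snn.Leaky`` implementation; the helper ``simulate_lif`` and the constant input current of 0.3 are hypothetical and chosen only for this example.

```python
def simulate_lif(inputs, beta=0.5, thresh=1.0):
    """Simulate one simplified LIF neuron (hypothetical helper, not part of snnTorch).

    Returns the spike train (0 or 1 per step) and the membrane potential per step.
    """
    mem = 0.0
    spikes, mems = [], []
    for current in inputs:
        # A spike on the previous step subtracts the threshold from the membrane
        reset = 1.0 if mem > thresh else 0.0
        # Leaky integration: decay the old membrane potential, then add the new input
        mem = beta * mem + current - reset * thresh
        spikes.append(1 if mem > thresh else 0)
        mems.append(mem)
    return spikes, mems


# Constant input current over 20 time steps (assumed values, for illustration only)
inputs = [0.3] * 20

spk_low, _ = simulate_lif(inputs, beta=0.5, thresh=0.5)
spk_default, _ = simulate_lif(inputs, beta=0.5, thresh=1.0)
spk_huge, mem_huge = simulate_lif(inputs, beta=0.5, thresh=20000)

print("spikes with thresh=0.5:  ", sum(spk_low))
print("spikes with thresh=1.0:  ", sum(spk_default))
print("spikes with thresh=20000:", sum(spk_huge))
print("final membrane with thresh=20000:", round(mem_huge[-1], 4))
```

With this input, the lower threshold lets more spikes through, while a very large threshold produces no spikes and lets the membrane potential settle freely. That second case mirrors why the final decoder layer uses ``threshold=20000`` and reads out the membrane potential instead of spikes.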