=========================================================== 教程(五) - 使用snnTorch训练脉冲神经网络 =========================================================== 本教程出自 Jason K. Eshraghian (`www.ncg.ucsc.edu `_) `English `_ .. image:: https://colab.research.google.com/assets/colab-badge.svg :alt: Open In Colab :target: https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_5_FCN.ipynb snnTorch 教程系列基于以下论文。如果您发现这些资源或代码对您的工作有用,请考虑引用以下来源: `Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. “Training Spiking Neural Networks Using Lessons From Deep Learning”. Proceedings of the IEEE, 111(9) September 2023. `_ .. note:: 本教程是不可编辑的静态版本。交互式可编辑版本可通过以下链接获取: * `Google Colab `_ * `Local Notebook (download via GitHub) `_ 简介 --------------- 在本教程中,你将: * 了解脉冲神经元如何作为递归网络实现 * 通过时间了解反向传播,以及 SNN 中的相关挑战,如脉冲的不可微分性 * 在静态 MNIST 数据集上训练全连接网络 .. 本教程的部分灵感来自 Friedemann Zenke 在 SNN 方面的大量工作。 请在 `这里 `_ 查看他关于替代梯度的资料库, 以及我最喜欢的一篇论文: E. O. Neftci, H. Mostafa, F. Zenke, `SNN中的替代梯度学习: 将基于梯度的优化功能引入SNN。 `_ IEEE Signal Processing Magazine 36, 51-63. 在教程的最后,我们将实施一种基本的监督学习算法。 我们将使用原始静态 MNIST 数据集,并使用梯度下降法训练 多层 全连接 脉冲神经网络 来执行图像分类。 安装 snnTorch 的最新 PyPi 发行版: :: $ pip install snntorch :: # imports import snntorch as snn from snntorch import spikeplot as splt from snntorch import spikegen import torch import torch.nn as nn from torch.utils.data import DataLoader from torchvision import datasets, transforms import matplotlib.pyplot as plt import numpy as np import itertools 1. 脉冲神经网络的递归表示 ---------------------------------------- 在 `教程(三) `_ 中, 我们推导出了泄漏整合-发射(LIF)神经元的递归表示: .. math:: U[t+1] = \underbrace{\beta U[t]}_\text{decay} + \underbrace{WX[t+1]}_\text{input} - \underbrace{R[t]}_\text{reset} \tag{1} 其中,输入突触电流解释为 :math:`I_{\rm in}[t] = WX[t]`, 而 :math:`X[t]` 可以是任意输入的脉冲、 阶跃/时变电压或非加权阶跃/时变电流。 脉冲用下式表示,如果膜电位超过阈值,就会发出一个脉冲: .. math:: S[t] = \begin{cases} 1, &\text{if}~U[t] > U_{\rm thr} \\ 0, &\text{otherwise}\end{cases} .. math:: \tag{2} 这种离散递归形式的脉冲神经元表述几乎可以完美利用训练递归神经网络(RNN) 和基于序列模型的发展。我们使用一个*隐式*递归连接来说明膜电位的衰减, 并将其与*显式*递归区分开来,在*显式*递归中, 输出脉冲 :math:`S_{\rm out}`被反馈回输入。 在下图中, 权重为 :math:`U_{\rm thr}`的连接代表着复位机制:math:`R[t]`。 .. image:: https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial5/unrolled_2.png?raw=true :align: center :width: 600 展开图的好处在于,它明确描述了计算是如何进行的。 展开过程说明了信息流在时间上的前向(从左到右),以计算输出和损失, 以及在时间上的后向,以计算梯度。模拟的时间步数越多,图形就越深。 传统的 RNN 将 :math:`\beta` 视为可学习的参数。 这对 SNN 也是可行的, 不过默认情况下, 它们被视为超参数(hyperparameters)。 这就用超参数搜索取代了梯度消失和梯度爆炸问题。 未来的教程将介绍如何使 :math:`\beta` 成为可学习参数。 2. 脉冲的不可微分性 ----------------------------------------- 2.1 使用反向传播算法进行训练 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 表示 :math:`S` 和 :math:`U` 之间关系的另一种方法是: .. math:: S[t] = \Theta(U[t] - U_{\rm thr}) \tag{3} 其中 :math:`\Theta(\cdot)` 是 Heaviside 阶跃函数(其实就是在原点发生阶跃的函数): .. image:: https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial3/3_2_spike_descrip.png?raw=true :align: center :width: 600 以这种形式训练网络会带来一些严峻的挑战。 考虑上图中题为 *"脉冲神经元的递归表示"* 的计算图的一个单独的时间步, 如下图 *前向传递* 所示: .. image:: https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial5/non-diff.png?raw=true :align: center :width: 400 我们的目标是利用损失相对于权重的梯度来训练网络,从而更新权重,使损失最小化。 反向传播算法利用链式规则实现了这一目标: .. math:: \frac{\partial \mathcal{L}}{\partial W} = \frac{\partial \mathcal{L}}{\partial S} \underbrace{\frac{\partial S}{\partial U}}_{\{0, \infty\}} \frac{\partial U}{\partial I}\ \frac{\partial I}{\partial W}\ \tag{4} 从 :math:`(1)`, :math:`/partial I//partial W=X`, 以及 :math:`partial U//partial I=1`。 虽然没定义损失函数, 我们还是可以假设 :math:`\partial \mathcal{L}/\partial S` 有一个解析解,有一个类似于交叉熵或均方误差损失(稍后会详细介绍)的解析解。 我们真正要处理的项是 :math:`\partial S/\partial U`。 (3)中的Heaviside阶跃函数的导数是狄拉克-德尔塔函数, 它在任何地方都求值为 :math:`0`, 但在阈值处除外 :math:`U_{\rm thr} = \theta`, 在这里它趋于无穷大。这意味着 梯度几乎总是归零 (如果 :math:`U` 恰好位于阈值处,则为饱和而不是归零), 无法进行学习。这被称为 **死神经元问题** 。 2.2 克服死神经元问题 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 解决死神经元问题的最常见方法是在前向传递过程中保持Heaviside函数的原样, 但将导数项 :math:`\partial S/\partial U` 换成在后向传递过程中不会扼杀学习过程的导数项, 即 :math:`\partial \tilde{S}/\partial U`。这听起来可能有些奇怪, 但事实证明,神经网络对这种近似是相当稳健的。这就是通常所说的 *替代梯度* 方法。 使用替代梯度有多种选择, 我们将在 `教程(六) `_" 中详细介绍这些方法。 snnTorch 的默认方法(截至 v0.6.0)是用反正切函数平滑 Heaviside 函数。 使用的后向导数为 .. math:: \frac{\partial \tilde{S}}{\partial U} \leftarrow \frac{1}{\pi}\frac{1}{(1+[U\pi]^2)} 其中左箭头表示替换。 下面用 PyTorch 实现了 :math:`(1)-(2)` 中描述的同一个神经元模型 (就是教程(三)中的 `snn.Leaky` 神经元)。如果您不理解,请不要担心。 稍后我们将使用 snnTorch 将其浓缩为一行代码: :: # Leaky neuron model, overriding the backward pass with a custom function class LeakySurrogate(nn.Module): def __init__(self, beta, threshold=1.0): super(LeakySurrogate, self).__init__() # initialize decay rate beta and threshold self.beta = beta self.threshold = threshold self.spike_gradient = self.ATan.apply # the forward function is called each time we call Leaky def forward(self, input_, mem): spk = self.spike_gradient((mem-self.threshold)) # call the Heaviside function reset = (self.beta * spk * self.threshold).detach() # remove reset from computational graph mem = self.beta * mem + input_ - reset # Eq (1) return spk, mem # Forward pass: Heaviside function # Backward pass: Override Dirac Delta with the derivative of the ArcTan function @staticmethod class ATan(torch.autograd.Function): @staticmethod def forward(ctx, mem): spk = (mem > 0).float() # Heaviside on the forward pass: Eq(2) ctx.save_for_backward(mem) # store the membrane for use in the backward pass return spk @staticmethod def backward(ctx, grad_output): (spk,) = ctx.saved_tensors # retrieve the membrane potential grad = 1 / (1 + (np.pi * mem).pow_(2)) * grad_output # Eqn 5 return grad 请注意,重置机制是与计算图分离的,因为替代梯度只应用于 :math:`\partial S/\partial U` 而不是 :math:`\partial R/\partial U`。 以上神经元可以这样实现: :: lif1 = LeakySurrogate(beta=0.9) 这个神经元可以用 for 循环来模拟,就像之前的教程一样。 PyTorch 的自动差异化(autodiff)机制会在后台跟踪梯度。 调用 ``snn.Leaky`` 神经元也能实现同样的效果。 事实上,每次从 snnTorch 调用任何神经元模型时, *ATan* 替代梯度都会默认应用于该神经元: :: lif1 = snn.Leaky(beta=0.9) 如果您想了解该神经元的行为,请参阅 `教程(三) `__. 3. 通过时间反向传播(BPTT) ---------------------- 方程 :math:`(4)` 仅计算一个单一时间步的梯度(在下图中称为 *即时影响*), 但是通过时间反向传播(BPTT)算法计算 从损失到 *所有* 后代(descendants)的梯度并将它们相加。 权重 :math:`W` 在每个时间步都应用,因此可以想象在每个时间步也计算了损失。 权重对当前和历史损失的影响必须相加在一起以定义全局梯度: .. math:: \frac{\partial \mathcal{L}}{\partial W}=\sum_t \frac{\partial\mathcal{L}[t]}{\partial W} = \sum_t \sum_{s\leq t} \frac{\partial\mathcal{L}[t]}{\partial W[s]}\frac{\partial W[s]}{\partial W} \tag{5} 方程 :math:`(5)` 的目的是确保因果关系: 通过限制 :math:`s\leq t`,我们只考虑了权重 :math:`W` 对损失的即时和先前影响的贡献。 循环系统将权重限制为在所有步骤中共享::math:`W[0]=W[1] =~... ~ = W`。 因此,对于所有的 :math:`W`,改变 :math:`W[s]` 将对所有 :math:`W` 产生相同的影响, 这意味着 :math:`\partial W[s]/\partial W=1`: .. math:: \frac{\partial \mathcal{L}}{\partial W}= \sum_t \sum_{s\leq t} \frac{\partial\mathcal{L}[t]}{\partial W[s]} \tag{6} 举个例子,隔离由于 :math:`s = t-1` *仅* 导致的先前影响; 这意味着反向传递必须回溯一步。可以将 :math:`W[t-1]` 对损失的影响写成: .. math:: \frac{\partial \mathcal{L}[t]}{\partial W[t-1]} = \frac{\partial \mathcal{L}[t]}{\partial S[t]} \underbrace{\frac{\partial \tilde{S}[t]}{\partial U[t]}}_{方程~(5)} \underbrace{\frac{\partial U[t]}{\partial U[t-1]}}_\beta \underbrace{\frac{\partial U[t-1]}{\partial I[t-1]}}_1 \underbrace{\frac{\partial I[t-1]}{\partial W[t-1]}}_{X[t-1]} \tag{7} 我们已经处理了来自方程 :math:`(4)` 的所有这些项, 除了 :math:`\partial U[t]/\partial U[t-1]`。 根据方程 :math:`(1)`,这个时间导数项简单地等于 :math:`\beta`。 因此,如果我们真的想,我们现在已经知道足够的信息来手动(且痛苦地) 计算每个时间步的每个权重的导数,对于单个神经元,它会看起来像这样: .. image:: https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial5/bptt.png?raw=true :align: center :width: 600 但幸运的是,PyTorch 的自动微分在后台为我们处理这些。 .. note:: 以上图中省略了重置机制。在 snnTorch 中,重置包含在前向传递中,但与反向传递分离。 4. 设置损失函数 / 输出解码 ------------------------------------------ 在传统的非脉冲神经网络中,有监督的多类分类问题会选取 激活度最高的神经元,并将其作为预测类别。 在脉冲神经网络中,有多种解释输出脉冲的方式。最常见的方法包括: * **脉冲率编码:** 选择具有最高脉冲率(或脉冲计数)的神经元作为预测类别 * **延迟编码:** 选择首先发放脉冲的神经元作为预测类别 这可能会让你联想到关于 `教程(一)神经编码 `__。不同之处在于,在这里,我们是在解释(解码)输出脉冲,而不是将原始输入数据编码/转换成脉冲。 让我们专注于脉冲率编码。当输入数据传递到网络时, 我们希望正确的神经元类别在仿真运行的过程中发射最多的脉冲。 这对应于最高的平均脉冲频率。实现这一目标的一种方法是增加正确类别的膜电位至 :math:`U>U_{\rm thr}`, 并将不正确类别的膜电位设置为 :math:`U