Source code for snntorch._layers.bntt

import torch.nn as nn


def BatchNormTT1d(
    input_features, time_steps, eps=1e-5, momentum=0.1, affine=True
):
    """
    Generate a torch.nn.ModuleList of 1D Batch Normalization layers of
    length time_steps.
    Input to this layer is the same as the vanilla torch.nn.BatchNorm1d layer.

    Batch Normalisation Through Time (BNTT) as presented in:
    'Revisiting Batch Normalization for Training Low-Latency Deep Spiking
    Neural Networks From Scratch'
    By Youngeun Kim & Priyadarshini Panda
    arXiv preprint arXiv:2010.01729

    Original GitHub repo:
    https://github.com/Intelligent-Computing-Lab-Yale/BNTT-Batch-Normalization-Through-Time

    The math shown below uses the LIF neuron as the neuron of choice.
    Typically, for a single post-synaptic neuron i, we can represent its
    membrane potential :math:`U_{i}^{t}` at time-step t as:

    .. math::

        U_{i}^{t} = λ U_{i}^{t-1} + \\sum_j w_{ij}S_{j}^{t}

    where:

    * λ - a leak factor which is less than one
    * j - the index of the pre-synaptic neuron
    * :math:`S_{j}` - the binary spike activation
    * :math:`w_{ij}` - the weight of the connection between the pre- and
      post-synaptic neurons

    With Batch Normalization Through Time, the membrane potential can be
    modeled as:

    .. math::

        U_{i}^{t} = λ U_{i}^{t-1} + BNTT_{γ^{t}}\\left(\\sum_j w_{ij}S_{j}^{t}\\right)
        = λ U_{i}^{t-1} + γ_{i}^{t} \\left(\\frac{\\sum_j w_{ij}S_{j}^{t}
        - µ_{i}^{t}}{\\sqrt{(σ_{i}^{t})^{2} + ε}}\\right)

    where :math:`µ_{i}^{t}` and :math:`(σ_{i}^{t})^{2}` are the mean and
    variance of the input to neuron i over the mini-batch at time-step t,
    :math:`γ_{i}^{t}` is the scale (weight) of the Batch Norm layer assigned
    to time-step t, and ε is a small constant for numerical stability. Each
    time-step has its own Batch Norm layer, and the bias (β) of every layer
    is disabled.

    :param input_features: number of features of the input
    :type input_features: int

    :param time_steps: number of time-steps of the SNN
    :type time_steps: int

    :param eps: a value added to the denominator for numerical stability
    :type eps: float

    :param momentum: the value used for the running_mean and running_var
        computation
    :type momentum: float

    :param affine: if True, each Batch Norm layer has a learnable weight
        (γ); the bias (β) is always disabled
    :type affine: bool

    Inputs: input_features, time_steps
        - **input_features**: same number of features as the input
        - **time_steps**: the number of time-steps to unroll in the SNN

    Outputs: bntt
        - **bntt** of length `time_steps`: torch.nn.ModuleList of
          BatchNorm1d layers, one per time-step
    """
    bntt = nn.ModuleList(
        [
            nn.BatchNorm1d(
                input_features, eps=eps, momentum=momentum, affine=affine
            )
            for _ in range(time_steps)
        ]
    )

    # Disable bias/beta of Batch Norm
    for bn in bntt:
        bn.bias = None

    return bntt

def BatchNormTT2d(
    input_features, time_steps, eps=1e-5, momentum=0.1, affine=True
):
    """
    Generate a torch.nn.ModuleList of 2D Batch Normalization layers of
    length time_steps.
    Input to this layer is the same as the vanilla torch.nn.BatchNorm2d layer.

    Batch Normalisation Through Time (BNTT) as presented in:
    'Revisiting Batch Normalization for Training Low-Latency Deep Spiking
    Neural Networks From Scratch'
    By Youngeun Kim & Priyadarshini Panda
    arXiv preprint arXiv:2010.01729

    The math shown below uses the LIF neuron as the neuron of choice.
    Typically, for a single post-synaptic neuron i, we can represent its
    membrane potential :math:`U_{i}^{t}` at time-step t as:

    .. math::

        U_{i}^{t} = λ U_{i}^{t-1} + \\sum_j w_{ij}S_{j}^{t}

    where:

    * λ - a leak factor which is less than one
    * j - the index of the pre-synaptic neuron
    * :math:`S_{j}` - the binary spike activation
    * :math:`w_{ij}` - the weight of the connection between the pre- and
      post-synaptic neurons

    With Batch Normalization Through Time, the membrane potential can be
    modeled as:

    .. math::

        U_{i}^{t} = λ U_{i}^{t-1} + BNTT_{γ^{t}}\\left(\\sum_j w_{ij}S_{j}^{t}\\right)
        = λ U_{i}^{t-1} + γ_{i}^{t} \\left(\\frac{\\sum_j w_{ij}S_{j}^{t}
        - µ_{i}^{t}}{\\sqrt{(σ_{i}^{t})^{2} + ε}}\\right)

    where, for this 2D variant, :math:`µ_{i}^{t}` and :math:`(σ_{i}^{t})^{2}`
    are the mean and variance of channel i over the mini-batch and spatial
    dimensions at time-step t, :math:`γ_{i}^{t}` is the scale (weight) of the
    Batch Norm layer assigned to time-step t, and ε is a small constant for
    numerical stability. Each time-step has its own Batch Norm layer, and the
    bias (β) of every layer is disabled.

    :param input_features: number of channels of the input
    :type input_features: int

    :param time_steps: number of time-steps of the SNN
    :type time_steps: int

    :param eps: a value added to the denominator for numerical stability
    :type eps: float

    :param momentum: the value used for the running_mean and running_var
        computation
    :type momentum: float

    :param affine: if True, each Batch Norm layer has a learnable weight
        (γ); the bias (β) is always disabled
    :type affine: bool

    Inputs: input_features, time_steps
        - **input_features**: same number of channels as the input
        - **time_steps**: the number of time-steps to unroll in the SNN

    Outputs: bntt
        - **bntt** of length `time_steps`: torch.nn.ModuleList of
          BatchNorm2d layers, one per time-step
    """
    bntt = nn.ModuleList(
        [
            nn.BatchNorm2d(
                input_features, eps=eps, momentum=momentum, affine=affine
            )
            for _ in range(time_steps)
        ]
    )

    # Disable bias/beta of Batch Norm
    for bn in bntt:
        bn.bias = None

    return bntt