Source code for snntorch.surrogate

import torch
import math

# Spike-gradient functions

# slope = 25
# """``snntorch.surrogate.slope``
# parameterizes the transition rate of the surrogate gradients."""


class StraightThroughEstimator(torch.autograd.Function):
    """
    Straight Through Estimator.

    **Forward pass:** Heaviside step function shifted.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** The gradient is passed through unchanged (identity).

    .. math::

        \\frac{∂S}{∂U}=1

    """
    @staticmethod
    def forward(ctx, input_):
        out = (input_ > 0).float()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input

def straight_through_estimator():
    """Straight Through Estimator surrogate gradient wrapper."""

    def inner(x):
        return StraightThroughEstimator.apply(x)

    return inner
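
# Usage sketch (illustrative, not part of the module): the wrapper is
# passed to a neuron's ``spike_grad`` argument, e.g. snnTorch's Leaky
# neuron.
# import snntorch as snn
# spike_grad = straight_through_estimator()
# lif = snn.Leaky(beta=0.9, spike_grad=spike_grad)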

class Triangular(torch.autograd.Function):
    """
    Triangular Surrogate Gradient.

    **Forward pass:** Heaviside step function shifted.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** Gradient of the triangular function.

    .. math::

        \\frac{∂S}{∂U}=\\begin{cases} U_{\\rm thr} & \\text{if U < U$_{\\rm thr}$} \\\\
        -U_{\\rm thr} & \\text{if U ≥ U$_{\\rm thr}$}
        \\end{cases}

    """
    @staticmethod
    def forward(ctx, input_, threshold):
        ctx.save_for_backward(input_)
        ctx.threshold = threshold
        out = (input_ > 0).float()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input * ctx.threshold
        grad[input_ >= 0] = -grad[input_ >= 0]
        return grad, None

def triangular(threshold=1):
    """Triangular surrogate gradient enclosed with a parameterized
    threshold."""

    def inner(x):
        return Triangular.apply(x, threshold)

    return inner
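
# Quick gradient check (illustrative, not part of the module): with the
# default threshold of 1, subthreshold inputs receive gradient +1 and
# suprathreshold inputs -1.
# x = torch.tensor([-0.5, 0.5], requires_grad=True)
# triangular()(x).sum().backward()
# print(x.grad)  # tensor([ 1., -1.])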

class FastSigmoid(torch.autograd.Function):
    """
    Surrogate gradient of the Heaviside step function.

    **Forward pass:** Heaviside step function shifted.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** Gradient of fast sigmoid function.

    .. math::

        S&≈\\frac{U}{1 + k|U|} \\\\
        \\frac{∂S}{∂U}&=\\frac{1}{(1+k|U|)^2}

    :math:`k` defaults to 25, and can be modified by calling
    ``surrogate.fast_sigmoid(slope=25)``.

    Adapted from:

    *F. Zenke, S. Ganguli (2018) SuperSpike: Supervised Learning in
    Multilayer Spiking Neural Networks. Neural Computation,
    pp. 1514-1541.*"""
    @staticmethod
    def forward(ctx, input_, slope):
        ctx.save_for_backward(input_)
        ctx.slope = slope
        out = (input_ > 0).float()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input / (ctx.slope * torch.abs(input_) + 1.0) ** 2
        return grad, None

def fast_sigmoid(slope=25):
    """FastSigmoid surrogate gradient enclosed with a parameterized
    slope."""

    def inner(x):
        return FastSigmoid.apply(x, slope)

    return inner
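
# Quick gradient check (illustrative, not part of the module): at U = 0
# the fast sigmoid gradient 1/(1 + k|U|)^2 evaluates to 1 for any slope k.
# x = torch.zeros(1, requires_grad=True)
# fast_sigmoid(slope=25)(x).sum().backward()
# print(x.grad)  # tensor([1.])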

class ATan(torch.autograd.Function):
    """
    Surrogate gradient of the Heaviside step function.

    **Forward pass:** Heaviside step function shifted.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** Gradient of shifted arc-tan function.

    .. math::

        S&≈\\frac{1}{π}\\text{arctan}(πU \\frac{α}{2}) \\\\
        \\frac{∂S}{∂U}&=\\frac{α}{2}\\frac{1}{(1+(πU\\frac{α}{2})^2)}

    α defaults to 2, and can be modified by calling
    ``surrogate.atan(alpha=2)``.

    Adapted from:

    *W. Fang, Z. Yu, Y. Chen, T. Masquelier, T. Huang, Y. Tian (2021)
    Incorporating Learnable Membrane Time Constants to Enhance Learning
    of Spiking Neural Networks. Proc. IEEE/CVF Int. Conf. Computer Vision
    (ICCV), pp. 2661-2671.*"""
    @staticmethod
    def forward(ctx, input_, alpha):
        ctx.save_for_backward(input_)
        ctx.alpha = alpha
        out = (input_ > 0).float()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = (
            ctx.alpha
            / 2
            / (1 + (torch.pi / 2 * ctx.alpha * input_).pow_(2))
            * grad_input
        )
        return grad, None

def atan(alpha=2.0):
    """ArcTan surrogate gradient enclosed with a parameterized slope."""

    def inner(x):
        return ATan.apply(x, alpha)

    return inner
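
# Quick gradient check (illustrative, not part of the module): at U = 0
# the ATan gradient reduces to α/2, i.e. 1.0 for the default α = 2.
# x = torch.zeros(1, requires_grad=True)
# atan()(x).sum().backward()
# print(x.grad)  # tensor([1.])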

class Heaviside(torch.autograd.Function):
    """Default spiking function for neuron.

    **Forward pass:** Heaviside step function shifted.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** Heaviside step function shifted.

    .. math::

        \\frac{∂S}{∂U}=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    Although the backward pass is clearly not the analytical solution of
    the forward pass, this assumption holds true on the basis that a
    reset necessarily occurs after a spike is generated when
    :math:`U ≥ U_{\\rm thr}`."""

    @staticmethod
    def forward(ctx, input_):
        out = (input_ > 0).float()
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        grad = grad_output * out
        return grad

def heaviside():
    """Heaviside surrogate gradient wrapper."""

    def inner(x):
        return Heaviside.apply(x)

    return inner

class Sigmoid(torch.autograd.Function):
    """
    Surrogate gradient of the Heaviside step function.

    **Forward pass:** Heaviside step function shifted.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** Gradient of sigmoid function.

    .. math::

        S&≈\\frac{1}{1 + {\\rm exp}(-kU)} \\\\
        \\frac{∂S}{∂U}&=\\frac{k {\\rm exp}(-kU)}{[{\\rm exp}(-kU)+1]^2}

    :math:`k` defaults to 25, and can be modified by calling
    ``surrogate.sigmoid(slope=25)``.

    Adapted from:

    *F. Zenke, S. Ganguli (2018) SuperSpike: Supervised Learning in
    Multilayer Spiking Neural Networks. Neural Computation,
    pp. 1514-1541.*"""
    @staticmethod
    def forward(ctx, input_, slope):
        ctx.save_for_backward(input_)
        ctx.slope = slope
        out = (input_ > 0).float()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = (
            grad_input
            * ctx.slope
            * torch.exp(-ctx.slope * input_)
            / ((torch.exp(-ctx.slope * input_) + 1) ** 2)
        )
        return grad, None

def sigmoid(slope=25):
    """Sigmoid surrogate gradient enclosed with a parameterized slope."""

    def inner(x):
        return Sigmoid.apply(x, slope)

    return inner
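
# Quick gradient check (illustrative, not part of the module): at U = 0
# the sigmoid gradient k·exp(0)/(exp(0) + 1)^2 reduces to k/4, i.e. 6.25
# for the default k = 25.
# x = torch.zeros(1, requires_grad=True)
# sigmoid()(x).sum().backward()
# print(x.grad)  # tensor([6.2500])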

class SpikeRateEscape(torch.autograd.Function):
    """
    Surrogate gradient of the Heaviside step function.

    **Forward pass:** Heaviside step function shifted.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** Gradient of Boltzmann Distribution.

    .. math::

        \\frac{∂S}{∂U}=k{\\rm exp}(-β|U-1|)

    :math:`β` defaults to 1, and can be modified by calling
    ``surrogate.spike_rate_escape(beta=1)``.

    :math:`k` defaults to 25, and can be modified by calling
    ``surrogate.spike_rate_escape(slope=25)``.

    Adapted from:

    *Wulfram Gerstner and Werner M. Kistler, Spiking neuron models:
    Single neurons, populations, plasticity. Cambridge University Press,
    2002.*"""
    @staticmethod
    def forward(ctx, input_, beta, slope):
        ctx.save_for_backward(input_)
        ctx.beta = beta
        ctx.slope = slope
        out = (input_ > 0).float()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = (
            grad_input
            * ctx.slope
            * torch.exp(-ctx.beta * torch.abs(input_ - 1))
        )
        return grad, None, None

def spike_rate_escape(beta=1, slope=25):
    """SpikeRateEscape surrogate gradient enclosed with a parameterized
    beta and slope."""

    def inner(x):
        return SpikeRateEscape.apply(x, beta, slope)

    return inner
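
# Quick gradient check (illustrative, not part of the module): at U = 1
# the exponent vanishes and the gradient reduces to the slope k.
# x = torch.ones(1, requires_grad=True)
# spike_rate_escape()(x).sum().backward()
# print(x.grad)  # tensor([25.])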

class StochasticSpikeOperator(torch.autograd.Function):
    """
    Surrogate gradient of the Heaviside step function.

    **Forward pass:** Spike operator function.

    .. math::

        S=\\begin{cases} \\frac{U(t)}{U} & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** Gradient of spike operator, where the subthreshold
    gradient is sampled from uniformly distributed noise on the interval
    :math:`(𝒰\\sim[-0.5, 0.5)+μ) σ^2`, where :math:`μ` is the mean and
    :math:`σ^2` is the variance.

    .. math::

        S&≈\\begin{cases} \\frac{U(t)}{U}\\Big{|}_{U(t)→U_{\\rm thr}}
        & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        (𝒰\\sim[-0.5, 0.5) + μ) σ^2 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases} \\\\
        \\frac{∂S}{∂U}&=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        (𝒰\\sim[-0.5, 0.5) + μ) σ^2 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    :math:`μ` defaults to 0, and can be modified by calling
    ``surrogate.SSO(mean=0)``.

    :math:`σ^2` defaults to 0.2, and can be modified by calling
    ``surrogate.SSO(variance=0.5)``.

    The above defaults set the gradient to the following expression:

    .. math::

        \\frac{∂S}{∂U}=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        𝒰\\sim[-0.1, 0.1) & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    """
    @staticmethod
    def forward(ctx, input_, mean, variance):
        out = (input_ > 0).float()
        ctx.save_for_backward(input_, out)
        ctx.mean = mean
        ctx.variance = variance
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input_, out) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input * out + (grad_input * (~out.bool()).float()) * (
            (torch.rand_like(input_) - 0.5 + ctx.mean) * ctx.variance
        )
        return grad, None, None

def SSO(mean=0, variance=0.2):
    """Stochastic spike operator gradient enclosed with a parameterized
    mean and variance."""

    def inner(x):
        return StochasticSpikeOperator.apply(x, mean, variance)

    return inner
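
# Quick gradient check (illustrative, not part of the module):
# subthreshold inputs receive a random gradient drawn from 𝒰[-0.1, 0.1)
# under the defaults, resampled on every backward pass.
# x = torch.tensor([-0.5], requires_grad=True)
# SSO()(x).sum().backward()
# print(x.grad)  # e.g. tensor([0.0314]); value varies per call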

class LeakySpikeOperator(torch.autograd.Function):
    """
    Surrogate gradient of the Heaviside step function.

    **Forward pass:** Spike operator function.

    .. math::

        S=\\begin{cases} \\frac{U(t)}{U} & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** Leaky gradient of spike operator, where the
    subthreshold gradient is treated as a small constant slope.

    .. math::

        S&≈\\begin{cases} \\frac{U(t)}{U}\\Big{|}_{U(t)→U_{\\rm thr}}
        & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        kU & \\text{if U < U$_{\\rm thr}$}
        \\end{cases} \\\\
        \\frac{∂S}{∂U}&=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        k & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    :math:`k` defaults to 0.1, and can be modified by calling
    ``surrogate.LSO(slope=0.1)``.

    The gradient is identical to that of a threshold-shifted
    Leaky ReLU."""
    @staticmethod
    def forward(ctx, input_, slope):
        out = (input_ > 0).float()
        ctx.save_for_backward(out)
        ctx.slope = slope
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = (
            grad_input * out
            + (~out.bool()).float() * ctx.slope * grad_input
        )
        return grad, None

def LSO(slope=0.1):
    """Leaky spike operator gradient enclosed with a parameterized
    slope."""

    def inner(x):
        return LeakySpikeOperator.apply(x, slope)

    return inner
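
# Quick gradient check (illustrative, not part of the module):
# suprathreshold inputs receive gradient 1 and subthreshold inputs the
# constant leak slope k.
# x = torch.tensor([-0.5, 0.5], requires_grad=True)
# LSO()(x).sum().backward()
# print(x.grad)  # tensor([0.1000, 1.0000])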

class SparseFastSigmoid(torch.autograd.Function):
    """
    Surrogate gradient of the Heaviside step function.

    **Forward pass:** Heaviside step function shifted.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** Gradient of fast sigmoid function clipped below B.

    .. math::

        S&≈\\frac{U}{1 + k|U|}H(U-B) \\\\
        \\frac{∂S}{∂U}&=\\begin{cases} \\frac{1}{(1+k|U|)^2} & \\text{\\rm if U > B} \\\\
        0 & \\text{\\rm otherwise}
        \\end{cases}

    :math:`k` defaults to 25, and can be modified by calling
    ``surrogate.SFS(slope=25)``.

    :math:`B` defaults to 1, and can be modified by calling
    ``surrogate.SFS(B=1)``.

    Adapted from:

    *N. Perez-Nieves and D.F.M. Goodman (2021) Sparse Spiking Gradient
    Descent. https://arxiv.org/pdf/2105.08810.pdf.*"""
    @staticmethod
    def forward(ctx, input_, slope, B):
        ctx.save_for_backward(input_)
        ctx.slope = slope
        ctx.B = B
        out = (input_ > 0).float()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = (
            grad_input
            / (ctx.slope * torch.abs(input_) + 1.0) ** 2
            * (input_ > ctx.B).float()
        )
        return grad, None, None

def SFS(slope=25, B=1):
    """SparseFastSigmoid surrogate gradient enclosed with a parameterized
    slope and sparsity threshold."""

    def inner(x):
        return SparseFastSigmoid.apply(x, slope, B)

    return inner
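
# Quick gradient check (illustrative, not part of the module): inputs at
# or below B receive zero gradient, enforcing sparsity in the backward
# pass; above B the fast sigmoid gradient applies.
# x = torch.tensor([0.5, 2.0], requires_grad=True)
# SFS(slope=25, B=1)(x).sum().backward()
# print(x.grad)  # tensor([0.0000e+00, 3.8447e-04])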

class CustomSurrogate(torch.autograd.Function):
    """
    Surrogate gradient of the Heaviside step function.

    **Forward pass:** Heaviside step function shifted.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** User-defined custom surrogate gradient function.

    The user defines the custom surrogate gradient in a separate
    function. It is passed in the forward static method and used in the
    backward static method.

    The arguments of the custom surrogate gradient function are always
    the input of the forward pass (input_), the gradient of the input
    (grad_input) and the output of the forward pass (out).

    **Important Note:** The hyperparameters of the custom surrogate
    gradient function have to be defined inside of the function itself.

    Example::

        import torch
        import torch.nn as nn
        import snntorch as snn
        from snntorch import surrogate

        beta = 0.9  # membrane decay rate
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def custom_fast_sigmoid(input_, grad_input, spikes):
            # The hyperparameter slope is defined inside the function.
            slope = 25
            grad = grad_input / (slope * torch.abs(input_) + 1.0) ** 2
            return grad

        spike_grad = surrogate.custom_surrogate(custom_fast_sigmoid)

        net_seq = nn.Sequential(
            nn.Conv2d(1, 12, 5),
            nn.MaxPool2d(2),
            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
            nn.Conv2d(12, 64, 5),
            nn.MaxPool2d(2),
            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 10),
            snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,
                      output=True),
        ).to(device)
    """
    @staticmethod
    def forward(ctx, input_, custom_surrogate_function):
        out = (input_ > 0).float()
        ctx.save_for_backward(input_, out)
        ctx.custom_surrogate_function = custom_surrogate_function
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input_, out = ctx.saved_tensors
        custom_surrogate_function = ctx.custom_surrogate_function
        grad_input = grad_output.clone()
        grad = custom_surrogate_function(input_, grad_input, out)
        return grad, None

def custom_surrogate(custom_surrogate_function):
    """Custom surrogate gradient enclosed within a wrapper."""
    func = custom_surrogate_function

    def inner(data):
        return CustomSurrogate.apply(data, func)

    return inner
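
# Sanity check (illustrative, not part of the module): a user-defined
# fast sigmoid passed through ``custom_surrogate`` reproduces the
# built-in ``fast_sigmoid`` gradient.
# def _custom_fs(input_, grad_input, out):
#     slope = 25  # hyperparameter defined inside the function
#     return grad_input / (slope * torch.abs(input_) + 1.0) ** 2
# x = torch.zeros(1, requires_grad=True)
# custom_surrogate(_custom_fs)(x).sum().backward()
# print(x.grad)  # tensor([1.]), matching fast_sigmoid(slope=25) at U = 0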

# class InverseSpikeOperator(torch.autograd.Function):
#     """
#     Surrogate gradient of the Heaviside step function.

#     **Forward pass:** Spike operator function.

#     .. math::

#         S=\\begin{cases} \\frac{U(t)}{U} & \\text{if U ≥ U$_{\\rm thr}$} \\\\
#         0 & \\text{if U < U$_{\\rm thr}$}
#         \\end{cases}

#     **Backward pass:** Gradient of spike operator.

#     .. math::

#         \\frac{∂S}{∂U}=\\begin{cases} \\frac{1}{U} & \\text{if U ≥ U$_{\\rm thr}$} \\\\
#         0 & \\text{if U < U$_{\\rm thr}$}
#         \\end{cases}

#     :math:`U_{\\rm thr}` defaults to 1, and can be modified by calling
#     ``surrogate.spike_operator(threshold=1)``.

#     .. warning:: ``threshold`` should match the threshold of the neuron,
#         which defaults to 1 as well.

#     """

#     @staticmethod
#     def forward(ctx, input_, threshold=1):
#         out = (input_ > 0).float()
#         ctx.save_for_backward(input_, out)
#         ctx.threshold = threshold
#         return out

#     @staticmethod
#     def backward(ctx, grad_output):
#         (input_, out) = ctx.saved_tensors
#         grad_input = grad_output.clone()
#         grad = (grad_input * out) / (input_ + ctx.threshold)
#         return grad, None


# def inverse_spike_operator(threshold=1):
#     """Spike operator gradient enclosed with a parameterized threshold."""

#     def inner(x):
#         return InverseSpikeOperator.apply(x, threshold)

#     return inner


# class InverseStochasticSpikeOperator(torch.autograd.Function):
#     """
#     Surrogate gradient of the Heaviside step function.

#     **Forward pass:** Spike operator function.

#     .. math::

#         S=\\begin{cases} \\frac{U(t)}{U} & \\text{if U ≥ U$_{\\rm thr}$} \\\\
#         0 & \\text{if U < U$_{\\rm thr}$}
#         \\end{cases}

#     **Backward pass:** Gradient of spike operator, where the
#     subthreshold gradient is sampled from uniformly distributed noise
#     on the interval :math:`(𝒰\\sim[-0.5, 0.5)+μ) σ^2`, where :math:`μ`
#     is the mean and :math:`σ^2` is the variance.

#     .. math::

#         S&≈\\begin{cases} \\frac{U(t)}{U} & \\text{if U ≥ U$_{\\rm thr}$} \\\\
#         (𝒰\\sim[-0.5, 0.5) + μ) σ^2 & \\text{if U < U$_{\\rm thr}$}
#         \\end{cases} \\\\
#         \\frac{∂S}{∂U}&=\\begin{cases} \\frac{1}{U} & \\text{if U ≥ U$_{\\rm thr}$} \\\\
#         (𝒰\\sim[-0.5, 0.5) + μ) σ^2 & \\text{if U < U$_{\\rm thr}$}
#         \\end{cases}

#     :math:`U_{\\rm thr}` defaults to 1, and can be modified by calling
#     ``surrogate.SSO(threshold=1)``.

#     :math:`μ` defaults to 0, and can be modified by calling
#     ``surrogate.SSO(mean=0)``.

#     :math:`σ^2` defaults to 0.2, and can be modified by calling
#     ``surrogate.SSO(variance=0.5)``.

#     The above defaults set the gradient to the following expression:

#     .. math::

#         \\frac{∂S}{∂U}=\\begin{cases} \\frac{1}{U} & \\text{if U ≥ U$_{\\rm thr}$} \\\\
#         𝒰\\sim[-0.1, 0.1) & \\text{if U < U$_{\\rm thr}$}
#         \\end{cases}

#     .. warning:: ``threshold`` should match the threshold of the neuron,
#         which defaults to 1 as well.

#     """

#     @staticmethod
#     def forward(ctx, input_, threshold=1, mean=0, variance=0.2):
#         out = (input_ > 0).float()
#         ctx.save_for_backward(input_, out)
#         ctx.threshold = threshold
#         ctx.mean = mean
#         ctx.variance = variance
#         return out

#     @staticmethod
#     def backward(ctx, grad_output):
#         (input_, out) = ctx.saved_tensors
#         grad_input = grad_output.clone()
#         grad = (grad_input * out) / (input_ + ctx.threshold) + (
#             grad_input * (~out.bool()).float()
#         ) * ((torch.rand_like(input_) - 0.5 + ctx.mean) * ctx.variance)
#         return grad, None, None, None


# def ISSO(threshold=1, mean=0, variance=0.2):
#     """Stochastic spike operator gradient enclosed with a parameterized
#     threshold, mean and variance."""

#     def inner(x):
#         return InverseStochasticSpikeOperator.apply(
#             x, threshold, mean, variance
#         )

#     return inner


# piecewise linear func
# tanh surrogate func