snntorch.surrogate

By default, PyTorch's automatic differentiation tools are unable to calculate the analytical derivative of the spiking neuron graph: the discrete nature of spikes makes it difficult for torch.autograd to calculate a gradient that facilitates learning. snntorch overrides the non-differentiable spike gradient with a surrogate, using snntorch.surrogate.ATan by default.

Alternative gradients are also available in the snntorch.surrogate module. These represent either approximations of the backward pass or probabilistic models of firing as a function of the membrane potential. Custom, user-defined surrogate gradients can also be implemented.

At present, the surrogate gradient functions available include ATan, FastSigmoid, Sigmoid, SparseFastSigmoid, SpikeRateEscape, StochasticSpikeOperator, LeakySpikeOperator, StraightThroughEstimator, and Triangular, amongst several other options. Each is documented in the list of surrogate gradients below.
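
As a minimal illustrative sketch (not part of the original text), each option is constructed through its lower-case wrapper, documented at the end of this page, and handed to a neuron via the spike_grad argument; the parameter values below are arbitrary:

import snntorch as snn
from snntorch import surrogate

# each wrapper returns a closure around the corresponding autograd Function
spike_grad_atan = surrogate.atan(alpha=2.0)               # arctan surrogate
spike_grad_fs = surrogate.fast_sigmoid(slope=25)          # fast sigmoid surrogate
spike_grad_tri = surrogate.triangular(threshold=1)        # triangular surrogate
spike_grad_ste = surrogate.straight_through_estimator()   # identity gradient

# any of these can be passed to a spiking neuron model
lif = snn.Leaky(beta=0.9, spike_grad=spike_grad_fs)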

For further reading, see:

E. O. Neftci, H. Mostafa, F. Zenke (2019) Surrogate Gradient Learning in Spiking Neural Networks: Bringing the Power of Gradient-Based Optimization to Spiking Neural Networks. IEEE Signal Processing Magazine, pp. 51-63.

How to use surrogate

The surrogate gradient must be passed as the spike_grad argument to the neuron model. If spike_grad is left unspecified, it defaults to snntorch.surrogate.ATan. In the following example, we apply the fast sigmoid surrogate to snntorch.Synaptic.

Example:

import snntorch as snn
from snntorch import surrogate
import torch
import torch.nn as nn

alpha = 0.9
beta = 0.85

# Initialize surrogate gradient
spike_grad1 = surrogate.fast_sigmoid()  # passes default parameters from a closure
spike_grad2 = surrogate.FastSigmoid.apply  # passes default parameters, equivalent to above
spike_grad3 = surrogate.fast_sigmoid(slope=50)  # custom parameters from a closure

# Layer sizes and device are example values; they are not specified in the original snippet
num_inputs, num_hidden, num_outputs = 784, 1000, 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers, specify the ``spike_grad`` argument
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Synaptic(alpha=alpha, beta=beta, spike_grad=spike_grad1)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Synaptic(alpha=alpha, beta=beta, spike_grad=spike_grad3)

    def forward(self, x, syn1, mem1, spk1, syn2, mem2):
        cur1 = self.fc1(x)
        spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)
        cur2 = self.fc2(spk1)
        spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)
        return syn1, mem1, spk1, syn2, mem2, spk2

net = Net().to(device)
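
As a rough usage sketch (not part of the original example), the hidden states can be initialized and the network stepped through time. The random input tensor, batch size, and the init_synaptic() state-initialization calls are assumptions made for illustration:

num_steps = 25
batch_size = 1
data = torch.rand(num_steps, batch_size, num_inputs)  # random time-first input, illustration only

syn1, mem1 = net.lif1.init_synaptic()  # assumed state-initialization helpers
syn2, mem2 = net.lif2.init_synaptic()
spk1 = torch.zeros(batch_size, num_hidden, device=device)  # placeholder for the unused spk1 argument

spk2_rec = []
for step in range(num_steps):
    syn1, mem1, spk1, syn2, mem2, spk2 = net(
        data[step].to(device), syn1, mem1, spk1, syn2, mem2
    )
    spk2_rec.append(spk2)

spk2_rec = torch.stack(spk2_rec)  # shape: (num_steps, batch_size, num_outputs)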

Custom Surrogate Gradients

For flexibility, users can also define their own surrogate gradients with surrogate.custom_surrogate.

Example:

import snntorch as snn
from snntorch import surrogate
import torch
import torch.nn as nn

beta = 0.9

# Define custom surrogate gradient
def custom_fast_sigmoid(input_, grad_input, spikes):
    ## The hyperparameter slope is defined inside the function.
    slope = 25
    grad = grad_input / (slope * torch.abs(input_) + 1.0) ** 2
    return grad

spike_grad = surrogate.custom_surrogate(custom_fast_sigmoid)

# Layer sizes and device are example values; they are not specified in the original snippet
num_inputs, num_hidden, num_outputs = 784, 1000, 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x, mem1, spk1, mem2):
        cur1 = self.fc1(x)
        spk1, mem1 = self.lif1(cur1, mem1)
        cur2 = self.fc2(spk1)
        spk2, mem2 = self.lif2(cur2, mem2)

        return mem1, spk1, mem2, spk2

net = Net().to(device)
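
As with the previous example, a rough sketch of stepping this network through time is shown below; the random data tensor and the init_leaky() state-initialization calls are assumptions for illustration:

num_steps = 25
data = torch.rand(num_steps, 1, num_inputs)  # random time-first input, illustration only

mem1 = net.lif1.init_leaky()  # assumed state-initialization helpers
mem2 = net.lif2.init_leaky()
spk1 = torch.zeros(1, num_hidden, device=device)  # placeholder for the unused spk1 argument

spk2_rec = []
for step in range(num_steps):
    mem1, spk1, mem2, spk2 = net(data[step].to(device), mem1, spk1, mem2)
    spk2_rec.append(spk2)

spk2_rec = torch.stack(spk2_rec)  # shape: (num_steps, 1, num_outputs)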

List of surrogate gradients

class snntorch.surrogate.ATan(*args, **kwargs)[source]

Bases: Function

Surrogate gradient of the Heaviside step function.

Forward pass: Heaviside step function shifted.

\[\begin{split}S=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: Gradient of shifted arc-tan function.

\[\begin{split}S&≈\frac{1}{π}\text{arctan}(πU \frac{α}{2}) \\ \frac{∂S}{∂U}&=\frac{α}{2}\frac{1}{(1+(πU\frac{α}{2})^2)}\end{split}\]

α defaults to 2, and can be modified by calling surrogate.atan(alpha=2).

Adapted from:

W. Fang, Z. Yu, Y. Chen, T. Masquelier, T. Huang, Y. Tian (2021) Incorporating Learnable Membrane Time Constants to Enhance Learning of Spiking Neural Networks. Proc. IEEE/CVF Int. Conf. Computer Vision (ICCV), pp. 2661-2671.
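
As a brief illustrative sketch (the alpha value and neuron choice are arbitrary, not from the original text), a sharper arctan surrogate can be attached to a neuron as follows:

import snntorch as snn
from snntorch import surrogate

# a larger alpha concentrates the surrogate gradient more tightly around the threshold
spike_grad = surrogate.atan(alpha=4.0)
lif = snn.Leaky(beta=0.9, spike_grad=spike_grad)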

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, alpha)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class snntorch.surrogate.CustomSurrogate(*args, **kwargs)[source]

Bases: Function

Surrogate gradient of the Heaviside step function.

Forward pass: Spike operator function.

\[\begin{split}S=\begin{cases} \frac{U(t)}{U} & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: User-defined custom surrogate gradient function.

The user defines the custom surrogate gradient in a separate function. It is passed in the forward static method and used in the backward static method.

The arguments of the custom surrogate gradient function are always the input of the forward pass (input_), the gradient of the input (grad_input) and the output of the forward pass (out).

Important note: the hyperparameters of the custom surrogate gradient function must be defined inside the function itself.

Example:

import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate

def custom_fast_sigmoid(input_, grad_input, spikes):
    ## The hyperparameter slope is defined inside the function.
    slope = 25
    grad = grad_input / (slope * torch.abs(input_) + 1.0) ** 2
    return grad

spike_grad = surrogate.custom_surrogate(custom_fast_sigmoid)

# Membrane decay rate and device are example values; they are not specified in the original snippet
beta = 0.9
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net_seq = nn.Sequential(nn.Conv2d(1, 12, 5),
                        nn.MaxPool2d(2),
                        snn.Leaky(beta=beta,
                                  spike_grad=spike_grad,
                                  init_hidden=True),
                        nn.Conv2d(12, 64, 5),
                        nn.MaxPool2d(2),
                        snn.Leaky(beta=beta,
                                  spike_grad=spike_grad,
                                  init_hidden=True),
                        nn.Flatten(),
                        nn.Linear(64 * 4 * 4, 10),
                        snn.Leaky(beta=beta,
                                  spike_grad=spike_grad,
                                  init_hidden=True,
                                  output=True)
                        ).to(device)

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, custom_surrogate_function)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class snntorch.surrogate.FastSigmoid(*args, **kwargs)[source]

Bases: Function

Surrogate gradient of the Heaviside step function.

Forward pass: Heaviside step function shifted.

\[\begin{split}S=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: Gradient of fast sigmoid function.

\[\begin{split}S&≈\frac{U}{1 + k|U|} \\ \frac{∂S}{∂U}&=\frac{1}{(1+k|U|)^2}\end{split}\]

\(k\) defaults to 25, and can be modified by calling surrogate.fast_sigmoid(slope=25).

Adapted from:

F. Zenke, S. Ganguli (2018) SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks. Neural Computation, pp. 1514-1541.
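
As a quick sanity check (an illustrative sketch, not from the library documentation), one can differentiate through the closure returned by surrogate.fast_sigmoid and compare the result against the expression above, assuming the input is already threshold-shifted (U - U_thr) as it is inside the neuron models:

import torch
from snntorch import surrogate

slope = 50
spike_grad = surrogate.fast_sigmoid(slope=slope)

U = torch.linspace(-1.0, 1.0, steps=5, requires_grad=True)
spk = spike_grad(U)       # forward: shifted Heaviside step
spk.sum().backward()      # backward: fast sigmoid surrogate gradient

expected = 1.0 / (slope * U.abs() + 1.0) ** 2
print(torch.allclose(U.grad, expected.detach()))  # expected: True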

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, slope)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class snntorch.surrogate.Heaviside(*args, **kwargs)[source]

Default spiking function for neuron.

Forward pass: Heaviside step function shifted.

\[\begin{split}S=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: Heaviside step function shifted.

\[\begin{split}\frac{∂S}{∂U}=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Although the backward pass is clearly not the analytical solution of the forward pass, this assumption holds true on the basis that a reset necessarily occurs after a spike is generated when \(U ≥ U_{\rm thr}\).

snntorch.surrogate.LSO(slope=0.1)[source]

Leaky spike operator gradient enclosed with a parameterized slope.

class snntorch.surrogate.LeakySpikeOperator(*args, **kwargs)[source]

Bases: Function

Surrogate gradient of the Heaviside step function.

Forward pass: Spike operator function.

\[\begin{split}S=\begin{cases} \frac{U(t)}{U} & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: Leaky gradient of spike operator, where the subthreshold gradient is treated as a small constant slope.

\[\begin{split}S&≈\begin{cases} \frac{U(t)}{U}\Big{|}_{U(t)→U_{\rm thr}} & \text{if U ≥ U$_{\rm thr}$} \\ kU & \text{if U < U$_{\rm thr}$}\end{cases} \\ \frac{∂S}{∂U}&=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ k & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

\(k\) defaults to 0.1, and can be modified by calling surrogate.LSO(slope=0.1).

The gradient is identical to that of a threshold-shifted Leaky ReLU.

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, slope)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

snntorch.surrogate.SFS(slope=25, B=1)[source]

SparseFastSigmoid surrogate gradient enclosed with a parameterized slope and sparsity threshold.

snntorch.surrogate.SSO(mean=0, variance=0.2)[source]

Stochastic spike operator gradient enclosed with a parameterized mean and variance.

class snntorch.surrogate.Sigmoid(*args, **kwargs)[source]

Bases: Function

Surrogate gradient of the Heaviside step function.

Forward pass: Heaviside step function shifted.

\[\begin{split}S=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: Gradient of sigmoid function.

\[\begin{split}S&≈\frac{1}{1 + {\rm exp}(-kU)} \\ \frac{∂S}{∂U}&=\frac{k {\rm exp}(-kU)}{[{\rm exp}(-kU)+1]^2}\end{split}\]

\(k\) defaults to 25, and can be modified by calling surrogate.sigmoid(slope=25).

Adapted from:

F. Zenke, S. Ganguli (2018) SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks. Neural Computation, pp. 1514-1541.

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, slope)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class snntorch.surrogate.SparseFastSigmoid(*args, **kwargs)[source]

Bases: Function

Surrogate gradient of the Heaviside step function.

Forward pass: Heaviside step function shifted.

\[\begin{split}S=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: Gradient of fast sigmoid function clipped below B.

\[\begin{split}S&≈\frac{U}{1 + k|U|}H(U-B) \\ \frac{∂S}{∂U}&=\begin{cases} \frac{1}{(1+k|U|)^2} & \text{if U > B} \\ 0 & \text{otherwise} \end{cases}\end{split}\]

\(k\) defaults to 25, and can be modified by calling surrogate.SFS(slope=25). \(B\) defaults to 1, and can be modified by calling surrogate.SFS(B=1).

Adapted from:

N. Perez-Nieves and D.F.M. Goodman (2021) Sparse Spiking Gradient Descent. https://arxiv.org/pdf/2105.08810.pdf.
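
For instance (an illustrative sketch; the slope and B values are arbitrary), both parameters can be set jointly and the resulting surrogate attached to a neuron:

import snntorch as snn
from snntorch import surrogate

# inputs at or below the sparsity threshold B receive zero surrogate gradient
spike_grad = surrogate.SFS(slope=25, B=0.5)
lif = snn.Leaky(beta=0.9, spike_grad=spike_grad)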

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, slope, B)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class snntorch.surrogate.SpikeRateEscape(*args, **kwargs)[source]

Bases: Function

Surrogate gradient of the Heaviside step function.

Forward pass: Heaviside step function shifted.

\[\begin{split}S=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: Gradient of Boltzmann Distribution.

\[\frac{∂S}{∂U}=k{\rm exp}(-β|U-1|)\]

\(β\) defaults to 1, and can be modified by calling surrogate.spike_rate_escape(beta=1). \(k\) defaults to 25, and can be modified by calling surrogate.spike_rate_escape(slope=25).

Adapted from:

W. Gerstner and W. M. Kistler (2002) Spiking Neuron Models: Single Neurons, Populations, Plasticity. Cambridge University Press.

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, beta, slope)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class snntorch.surrogate.StochasticSpikeOperator(*args, **kwargs)[source]

Bases: Function

Surrogate gradient of the Heaviside step function.

Forward pass: Spike operator function.

\[\begin{split}S=\begin{cases} \frac{U(t)}{U} & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: Gradient of spike operator, where the subthreshold gradient is sampled from uniformly distributed noise on the interval \((𝒰\sim[-0.5, 0.5)+μ) σ^2\), where \(μ\) is the mean and \(σ^2\) is the variance.

\[\begin{split}S&≈\begin{cases} \frac{U(t)}{U}\Big{|}_{U(t)→U_{\rm thr}} & \text{if U ≥ U$_{\rm thr}$} \\ (𝒰\sim[-0.5, 0.5) + μ) σ^2 & \text{if U < U$_{\rm thr}$} \end{cases} \\ \frac{∂S}{∂U}&=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ (𝒰\sim[-0.5, 0.5) + μ) σ^2 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

\(μ\) defaults to 0, and can be modified by calling surrogate.SSO(mean=0).

\(σ^2\) defaults to 0.2, and can be modified by calling surrogate.SSO(variance=0.2).

The above defaults set the gradient to the following expression:

\[\begin{split}\frac{∂S}{∂U}=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 𝒰\sim[-0.1, 0.1) & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, mean, variance)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class snntorch.surrogate.StraightThroughEstimator(*args, **kwargs)[source]

Bases: Function

Straight Through Estimator.

Forward pass: Heaviside step function shifted.

\[\begin{split}S=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: The upstream gradient is passed through unchanged (identity).

\[\frac{∂S}{∂U}=1\]
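
As a short illustrative check (not from the original documentation), the straight-through estimator passes the upstream gradient through unchanged:

import torch
from snntorch import surrogate

spike_grad = surrogate.straight_through_estimator()

U = torch.tensor([-0.5, 0.1, 0.7], requires_grad=True)
spk = spike_grad(U)       # forward: shifted Heaviside step
spk.sum().backward()      # backward: identity

print(U.grad)             # expected: tensor([1., 1., 1.])
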
static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class snntorch.surrogate.Triangular(*args, **kwargs)[source]

Bases: Function

Triangular Surrogate Gradient.

Forward pass: Heaviside step function shifted.

\[\begin{split}S=\begin{cases} 1 & \text{if U ≥ U$_{\rm thr}$} \\ 0 & \text{if U < U$_{\rm thr}$} \end{cases}\end{split}\]

Backward pass: Gradient of the triangular function.

\[\begin{split}\frac{∂S}{∂U}=\begin{cases} U_{\rm thr} & \text{if U < U$_{\rm thr}$} \\ -U_{\rm thr} & \text{if U ≥ U$_{\rm thr}$} \end{cases}\end{split}\]
static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, threshold)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

snntorch.surrogate.atan(alpha=2.0)[source]

ArcTan surrogate gradient enclosed with a parameterized slope.

snntorch.surrogate.custom_surrogate(custom_surrogate_function)[source]

Custom surrogate gradient enclosed within a wrapper.

snntorch.surrogate.fast_sigmoid(slope=25)[source]

FastSigmoid surrogate gradient enclosed with a parameterized slope.

snntorch.surrogate.heaviside()[source]

Heaviside surrogate gradient wrapper.

snntorch.surrogate.sigmoid(slope=25)[source]

Sigmoid surrogate gradient enclosed with a parameterized slope.

snntorch.surrogate.spike_rate_escape(beta=1, slope=25)[source]

SpikeRateEscape surrogate gradient enclosed with a parameterized slope and β.

snntorch.surrogate.straight_through_estimator()[source]

Straight Through Estimator surrogate gradient wrapper.

snntorch.surrogate.triangular(threshold=1)[source]

Triangular surrogate gradient enclosed with a parameterized threshold.