snn.LeakyParallel
- class snntorch._neurons.leakyparallel.LeakyParallel(input_size, hidden_size, beta=None, bias=True, threshold=1.0, dropout=0.0, spike_grad=None, surrogate_disable=False, learn_beta=False, learn_threshold=False, graded_spikes_factor=1.0, learn_graded_spikes_factor=False, weight_hh_enable=False, device=None, dtype=None)[source]
Bases:
Module
A parallel implementation of the Leaky neuron with a fused input linear layer. All time steps are passed to the input at once. This implementation uses torch.nn.RNN to accelerate computation.
First-order leaky integrate-and-fire neuron model. Input is assumed to be a current injection. Membrane potential decays exponentially with rate beta. For \(U[T] > U_{\rm thr} \Rightarrow S[T+1] = 1\).

\[U[t+1] = \beta U[t] + I_{\rm in}[t+1]\]

\(I_{\rm in}\) - Input current

\(U\) - Membrane potential

\(U_{\rm thr}\) - Membrane threshold

\(\beta\) - Membrane potential decay rate
Several differences between LeakyParallel and Leaky include:
- Negative hidden states are clipped due to the forced ReLU operation in RNN.
- Linear weights are included in addition to recurrent weights.
- beta is clipped between [0, 1] and cloned to weight_hh_l only upon layer initialization. It is unused otherwise.
- There is no explicit reset mechanism.
- Several functions such as init_hidden, output, inhibition, and state_quant are unavailable in LeakyParallel.
- Only the output spike is returned. Membrane potential is not accessible by default.
- RNN uses a hidden matrix of size (num_hidden, num_hidden) to transform the hidden state vector. This would 'leak' the membrane potential between LIF neurons, and so the hidden matrix is forced to a diagonal matrix by default (see the sketch below). This can be disabled by setting weight_hh_enable=True.
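The effect of this diagonal constraint can be illustrated with a minimal sketch (the variable names below are illustrative only, not part of the snnTorch API):

    import torch

    hidden_size = 4
    beta = torch.rand(hidden_size).clamp(0, 1)  # one decay rate per neuron

    # Diagonal hidden matrix (weight_hh_enable=False, the default):
    # each neuron's potential decays only into itself, as in a LIF neuron.
    weight_hh_diag = torch.diag(beta)

    # Dense hidden matrix (weight_hh_enable=True): off-diagonal entries let
    # one neuron's membrane potential "leak" into every other neuron.
    weight_hh_dense = torch.rand(hidden_size, hidden_size)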
Example:
    import torch
    import torch.nn as nn
    import snntorch as snn

    beta = 0.5
    num_steps = 100  # number of time steps (needed to build x below)
    num_inputs = 784
    num_hidden = 128
    num_outputs = 10
    batch_size = 128

    x = torch.rand((num_steps, batch_size, num_inputs))

    # Define Network
    class Net(nn.Module):
        def __init__(self):
            super().__init__()

            # initialize layers
            # randomly initialized recurrent weights
            self.lif1 = snn.LeakyParallel(input_size=num_inputs,
                                          hidden_size=num_hidden)
            # learnable recurrent weights initialized at beta
            self.lif2 = snn.LeakyParallel(input_size=num_hidden,
                                          hidden_size=num_outputs,
                                          beta=beta,
                                          learn_beta=True)

        def forward(self, x):
            spk1 = self.lif1(x)
            spk2 = self.lif2(spk1)
            return spk2
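Assuming the definitions above, the network processes all time steps in a single call; a sketch of the expected usage:

    net = Net()
    spk2 = net(x)
    print(spk2.shape)  # expected: (num_steps, batch_size, num_outputs)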
- param input_size:
The number of expected features in the input x
- type input_size:
int
- param hidden_size:
The number of features in the hidden state h
- type hidden_size:
int
- param beta:
membrane potential decay rate. Clipped between 0 and 1 during the forward-pass. May be a single-valued tensor (i.e., equal decay rate for all neurons in a layer), or multi-valued (one weight per neuron). If left unspecified, then the decay rates will be randomly initialized based on PyTorch’s initialization for RNN. Defaults to None
- type beta:
float or torch.tensor, optional
- param bias:
If False, then the layer does not use bias weights b_ih and b_hh. Defaults to True
- type bias:
bool, optional
- param threshold:
Threshold for \(mem\) to reach in order to generate a spike S=1. Defaults to 1
- type threshold:
float, optional
- param dropout:
If non-zero, introduces a Dropout layer on the RNN output with dropout probability equal to dropout. Defaults to 0
- type dropout:
float, optional
- param spike_grad:
Surrogate gradient for the term dS/dU. Defaults to None (corresponds to ATan surrogate gradient. See snntorch.surrogate for more options)
- type spike_grad:
surrogate gradient function from snntorch.surrogate, optional
- param surrogate_disable:
Disables surrogate gradients regardless of spike_grad argument. Useful for ONNX compatibility. Defaults to False
- type surrogate_disable:
bool, optional
- param learn_beta:
Option to enable learnable beta. Defaults to False
- type learn_beta:
bool, optional
- param learn_threshold:
Option to enable learnable threshold. Defaults to False
- type learn_threshold:
bool, optional
- param weight_hh_enable:
Option to set the hidden matrix to be dense or diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works. Dense (True) would allow the membrane potential of one LIF neuron to influence all others, following the default RNN implementation. Defaults to False
- type weight_hh_enable:
bool, optional
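To illustrate how several of these arguments combine, a hedged sketch (all values are arbitrary):

    import torch
    import snntorch as snn

    num_hidden = 128
    lif = snn.LeakyParallel(
        input_size=784,
        hidden_size=num_hidden,
        beta=torch.rand(num_hidden),  # multi-valued beta: one decay rate per neuron
        learn_beta=True,              # make the decay rates trainable
        threshold=1.0,
        learn_threshold=True,         # make the spike threshold trainable
        dropout=0.2,                  # dropout applied to the RNN output
    )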
- Inputs: input_
- input_ of shape (L, H_{in}) for unbatched input, or (L, N, H_{in}) containing the features of the input sequence.
- Outputs: spk
- spk of shape (L, N, H_{out}): tensor containing the output spikes.
where:
`L = sequence length`
`N = batch size`
`H_{in} = input_size`
`H_{out} = hidden_size`
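A short shape check based on the documented shapes (values are arbitrary):

    import torch
    import snntorch as snn

    L, N, H_in, H_out = 20, 3, 8, 4
    lif = snn.LeakyParallel(input_size=H_in, hidden_size=H_out)

    x = torch.rand(L, N, H_in)  # batched input sequence of shape (L, N, H_in)
    spk = lif(x)                # output spikes, expected shape (L, N, H_out)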
- Learnable Parameters:
- rnn.weight_ih_l (torch.Tensor) - the learnable input-hidden
weights of shape (hidden_size, input_size).
- rnn.weight_hh_l (torch.Tensor) - the learnable hidden-hidden
weights of the k-th layer which are sampled from beta of shape (hidden_size, hidden_size).
- bias_ih_l - the learnable input-hidden bias of the k-th layer,
of shape (hidden_size).
- bias_hh_l - the learnable hidden-hidden bias of the k-th layer,
of shape (hidden_size).
- threshold (torch.Tensor) - optional learnable thresholds must be
manually passed in, of shape 1 or (input_size).
- graded_spikes_factor (torch.Tensor) - optional learnable graded
spike factor.
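These parameters are registered through the standard torch.nn.Module interface and can be inspected accordingly; a minimal sketch:

    import snntorch as snn

    lif = snn.LeakyParallel(input_size=8, hidden_size=4)

    # list the registered learnable parameters and their shapes
    for name, param in lif.named_parameters():
        print(name, tuple(param.shape))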
- class ATan(*args, **kwargs)[source]
Bases:
Function
Surrogate gradient of the Heaviside step function.
Forward pass: Heaviside step function shifted.
\[S=\begin{cases} 1 & \text{if } U \geq U_{\rm thr} \\ 0 & \text{if } U < U_{\rm thr} \end{cases}\]

Backward pass: Gradient of shifted arc-tan function.

\[S \approx \frac{1}{\pi}\arctan\left(\pi U \frac{\alpha}{2}\right)\]

\[\frac{\partial S}{\partial U} = \frac{1}{\pi}\frac{1}{1+\left(\pi U \frac{\alpha}{2}\right)^{2}}\]

\(\alpha\) defaults to 2, and can be modified by calling surrogate.atan(alpha=2).

Adapted from:
W. Fang, Z. Yu, Y. Chen, T. Masquelier, T. Huang, Y. Tian (2021) Incorporating Learnable Membrane Time Constants to Enhance Learning of Spiking Neural Networks. Proc. IEEE/CVF Int. Conf. Computer Vision (ICCV), pp. 2661-2671.
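The surrogate slope can be adjusted by passing an alternative function to the spike_grad argument; a sketch assuming the snntorch.surrogate module:

    import snntorch as snn
    from snntorch import surrogate

    # steeper arc-tan surrogate gradient (alpha defaults to 2)
    lif = snn.LeakyParallel(
        input_size=8,
        hidden_size=4,
        spike_grad=surrogate.atan(alpha=4.0),
    )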
- static backward(ctx, grad_output)[source]
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, input_, alpha=2.0)[source]
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        pass

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

See combining-forward-context for more details.

Usage 2 (Separate forward and ctx):

    @staticmethod
    def forward(*args: Any, **kwargs: Any) -> Any:
        pass

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        pass

The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

See extending-autograd for more details.

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
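As a concrete illustration of Usage 1, a minimal toy Function (not part of snnTorch) that saves a tensor for the backward pass:

    import torch

    class Square(torch.autograd.Function):
        """Toy example: y = x**2 with a hand-written gradient."""

        @staticmethod
        def forward(ctx, input_):
            # save tensors needed by backward via ctx.save_for_backward
            ctx.save_for_backward(input_)
            return input_ ** 2

        @staticmethod
        def backward(ctx, grad_output):
            # retrieve saved tensors; return one gradient per forward input
            (input_,) = ctx.saved_tensors
            return 2 * input_ * grad_output

    x = torch.randn(3, requires_grad=True)
    y = Square.apply(x).sum()
    y.backward()  # x.grad now holds 2 * x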
- forward(input_)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
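In practice this means calling the module instance rather than forward() directly; a brief sketch:

    import torch
    import snntorch as snn

    lif = snn.LeakyParallel(input_size=8, hidden_size=4)
    x = torch.rand(20, 3, 8)

    spk = lif(x)            # preferred: runs registered hooks
    # spk = lif.forward(x)  # works, but silently skips hooks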