snn.RSynaptic
- class snntorch._neurons.rsynaptic.RSynaptic(alpha, beta, V=1.0, all_to_all=True, linear_features=None, conv2d_channels=None, kernel_size=None, threshold=1.0, spike_grad=None, surrogate_disable=False, init_hidden=False, inhibition=False, learn_alpha=False, learn_beta=False, learn_threshold=False, learn_recurrent=True, reset_mechanism='subtract', state_quant=False, output=False, reset_delay=True)[source]
Bases: LIF
Second-order recurrent leaky integrate-and-fire neuron model accounting for synaptic conductance. The synaptic current jumps upon spike arrival, which causes a jump in membrane potential. Synaptic current and membrane potential decay exponentially with rates of alpha and beta, respectively. For \(U[T] > U_{\rm thr} \Rightarrow S[T+1] = 1\).
If reset_mechanism = “subtract”, then \(U[t+1]\) will have the threshold subtracted from it whenever the neuron emits a spike:
\[\begin{split}I_{\rm syn}[t+1] = \alpha I_{\rm syn}[t] + V(S_{\rm out}[t]) + I_{\rm in}[t+1] \\ U[t+1] = \beta U[t] + I_{\rm syn}[t+1] - RU_{\rm thr}\end{split}\]
where \(V(\cdot)\) acts either as a linear layer, a convolutional operator, or an elementwise product on \(S_{\rm out}\).
If all_to_all = True and linear_features is specified, then \(V(\cdot)\) acts as a recurrent linear layer of the same size as \(S_{\rm out}\).
If all_to_all = True and conv2d_channels and kernel_size are specified, then \(V(\cdot)\) acts as a recurrent convolutional layer, padded so that the output matches the size of the input.
If all_to_all = False, then \(V(\cdot)\) acts as an elementwise multiplier with \(V\). The three modes are sketched below.
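For illustration, a minimal sketch of the three connectivity modes (the decay rates and layer sizes below are placeholder values chosen for this example, not defaults):

import snntorch as snn

# 1. Dense all-to-all recurrence over a 1D layer of 128 neurons
rlif_dense = snn.RSynaptic(alpha=0.9, beta=0.5, all_to_all=True, linear_features=128)

# 2. Convolutional recurrence over 8 feature channels with a 5x5 recurrent kernel
rlif_conv = snn.RSynaptic(alpha=0.9, beta=0.5, all_to_all=True, conv2d_channels=8, kernel_size=5)

# 3. One-to-one recurrence: each neuron's own output spike is scaled by V
rlif_one2one = snn.RSynaptic(alpha=0.9, beta=0.5, all_to_all=False, V=0.5)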
If reset_mechanism = “zero”, then \(U[t+1]\) will be set to 0 whenever the neuron emits a spike:
\[\begin{split}I_{\rm syn}[t+1] = \alpha I_{\rm syn}[t] + V(S_{\rm out}[t]) + I_{\rm in}[t+1] \\ U[t+1] = \beta U[t] + I_{\rm syn}[t+1] - R(\beta U[t] + I_{\rm syn}[t+1])\end{split}\]
\(I_{\rm syn}\) - Synaptic current
\(I_{\rm in}\) - Input current
\(U\) - Membrane potential
\(U_{\rm thr}\) - Membrane threshold
\(S_{\rm out}\) - Output spike
\(R\) - Reset mechanism: if active, \(R = 1\), otherwise \(R = 0\)
\(\alpha\) - Synaptic current decay rate
\(\beta\) - Membrane potential decay rate
\(V\) - Explicit recurrent weight when all_to_all=False
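To make the update concrete, here is a minimal single-step sketch of the subtract-reset dynamics using plain tensor arithmetic; all values are arbitrary, and the 1-to-1 case where \(V(\cdot)\) is a scalar multiply is assumed:

import torch

alpha, beta, V, U_thr = 0.9, 0.5, 0.25, 1.0
syn, mem = torch.tensor(0.2), torch.tensor(0.8)    # I_syn[t], U[t]
spk = torch.tensor(1.0)                            # S_out[t]: the neuron fired at t
cur_in = torch.tensor(0.4)                         # I_in[t+1]

R = spk                                            # reset is active because S_out[t] = 1
syn_next = alpha * syn + V * spk + cur_in          # synaptic current jumps on spike arrival
mem_next = beta * mem + syn_next - R * U_thr       # membrane integrates, then threshold is subtracted
print(syn_next.item(), mem_next.item())            # approx. 0.83 and 0.23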
Example:
import torch
import torch.nn as nn
import snntorch as snn

alpha = 0.9 # synaptic current decay rate
beta = 0.5 # membrane potential decay rate
V1 = 0.5 # shared recurrent connection
V2 = torch.rand(num_outputs) # unshared recurrent connections

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)

        # Default RSynaptic layer where recurrent connections
        # are initialized using PyTorch defaults in nn.Linear.
        self.lif1 = snn.RSynaptic(alpha=alpha, beta=beta,
                                  linear_features=num_hidden)

        self.fc2 = nn.Linear(num_hidden, num_outputs)

        # each neuron has a single connection back to itself
        # where the output spike is scaled by V.
        # For `all_to_all = False`, V can be shared between
        # neurons (e.g., V1) or unique / unshared between
        # neurons (e.g., V2).
        # V is learnable by default.
        self.lif2 = snn.RSynaptic(alpha=alpha, beta=beta,
                                  all_to_all=False, V=V1)

    def forward(self, x):
        # Initialize hidden states at t=0
        spk1, syn1, mem1 = self.lif1.init_rsynaptic()
        spk2, syn2, mem2 = self.lif2.init_rsynaptic()

        # Record output layer spikes and membrane
        spk2_rec = []
        mem2_rec = []

        # time-loop
        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, syn1, mem1 = self.lif1(cur1, spk1, syn1, mem1)
            cur2 = self.fc2(spk1)
            spk2, syn2, mem2 = self.lif2(cur2, spk2, syn2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        # convert lists to tensors
        spk2_rec = torch.stack(spk2_rec)
        mem2_rec = torch.stack(mem2_rec)

        return spk2_rec, mem2_rec
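The network above can then be driven as follows. This is a hypothetical usage sketch; the placeholder sizes (num_inputs, num_hidden, num_outputs, num_steps) must be defined before the class and its recurrent weights are constructed:

num_inputs, num_hidden, num_outputs = 784, 100, 10
num_steps = 25

net = Net()
x = torch.rand(1, num_inputs)      # a single random input sample, presented at every step
spk_rec, mem_rec = net(x)
print(spk_rec.shape)               # torch.Size([25, 1, 10]), i.e., (num_steps, batch, num_outputs)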
- Parameters:
alpha (float or torch.tensor) – synaptic current decay rate. Clipped between 0 and 1 during the forward-pass. May be a single-valued tensor (i.e., equal decay rate for all neurons in a layer), or multi-valued (one weight per neuron).
beta (float or torch.tensor) – membrane potential decay rate. Clipped between 0 and 1 during the forward-pass. May be a single-valued tensor (i.e., equal decay rate for all neurons in a layer), or multi-valued (one weight per neuron).
V (float or torch.tensor) – Recurrent weights to scale output spikes, only used when all_to_all=False. Defaults to 1.
all_to_all (bool, optional) – Enables output spikes to be connected in dense or convolutional recurrent structures instead of 1-to-1 connections. Defaults to True.
linear_features (int, optional) – Size of each output sample. Must be specified if all_to_all=True and the input data is 1D. Defaults to None
conv2d_channels (int, optional) – Number of channels in each output sample. Must be specified if all_to_all=True and the input data is 3D. Defaults to None
kernel_size (int or tuple) – Size of the convolving kernel. Must be specified if all_to_all=True and the input data is 3D. Defaults to None
threshold (float, optional) – Threshold for \(mem\) to reach in order to generate a spike S=1. Defaults to 1
spike_grad (surrogate gradient function from snntorch.surrogate, optional) – Surrogate gradient for the term dS/dU. Defaults to None, which corresponds to the ATan surrogate gradient (see snntorch.surrogate for more options)
surrogate_disable (bool, Optional) – Disables surrogate gradients regardless of spike_grad argument. Useful for ONNX compatibility. Defaults to False
init_hidden (bool, optional) – Instantiates state variables as instance variables. Defaults to False
inhibition (bool, optional) – If True, suppresses all spiking other than the neuron with the highest state. Defaults to False
learn_alpha (bool, optional) – Option to enable learnable alpha. Defaults to False
learn_beta (bool, optional) – Option to enable learnable beta. Defaults to False
learn_recurrent (bool, optional) – Option to enable learnable recurrent weights. Defaults to True
learn_threshold (bool, optional) – Option to enable learnable threshold. Defaults to False
reset_mechanism (str, optional) – Defines the reset mechanism applied to \(mem\) each time the threshold is met. Reset-by-subtraction: “subtract”, reset-to-zero: “zero”, none: “none”. Defaults to “subtract”
state_quant (quantization function from snntorch.quant, optional) – If specified, hidden states \(mem\) and \(syn\) are quantized to a valid state for the forward pass. Defaults to False
output (bool, optional) – If True as well as init_hidden=True, states are returned when neuron is called. Defaults to False
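As a hedged sketch of the learnable options: passing learn_alpha=True, learn_beta=True, or learn_threshold=True exposes the corresponding values as trainable tensors, while learn_recurrent=True (the default) already trains the recurrent weights:

import snntorch as snn

lif = snn.RSynaptic(alpha=0.9, beta=0.5, linear_features=128,
                    learn_alpha=True, learn_beta=True, learn_threshold=True)

# the decay rates and threshold now participate in gradient descent
print(lif.alpha.requires_grad, lif.beta.requires_grad, lif.threshold.requires_grad)  # True True True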
- Inputs: input_, spk_0, syn_0, mem_0
input_ of shape (batch, input_size): tensor containing input features
spk_0 of shape (batch, input_size): tensor containing the initial output spikes for each element in the batch
syn_0 of shape (batch, input_size): tensor containing the initial synaptic current for each element in the batch
mem_0 of shape (batch, input_size): tensor containing the initial membrane potential for each element in the batch.
- Outputs: spk_1, syn_1, mem_1
spk_1 of shape (batch, input_size): tensor containing the output spikes.
syn_1 of shape (batch, input_size): tensor containing the next synaptic current for each element in the batch
mem_1 of shape (batch, input_size): tensor containing the next membrane potential for each element in the batch
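A minimal sketch of one explicit forward call matching these shapes (a batch of 4 and 128 features are assumed for illustration):

import torch
import snntorch as snn

lif = snn.RSynaptic(alpha=0.9, beta=0.5, linear_features=128)
spk, syn, mem = lif.init_rsynaptic()        # zero-initialized hidden states
input_ = torch.rand(4, 128)
spk, syn, mem = lif(input_, spk, syn, mem)  # advance one time step
print(spk.shape, syn.shape, mem.shape)      # each torch.Size([4, 128])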
- Learnable Parameters:
RSynaptic.alpha (torch.Tensor) - optional learnable weights; must be manually passed in, of shape 1 or (input_size).
RSynaptic.beta (torch.Tensor) - optional learnable weights; must be manually passed in, of shape 1 or (input_size).
RSynaptic.recurrent.weight (torch.Tensor) - optional learnable weights, automatically generated if all_to_all=True. RSynaptic.recurrent stores an nn.Linear or nn.Conv2d layer depending on the input arguments provided.
RSynaptic.V (torch.Tensor) - optional learnable weights; must be manually passed in, of shape 1 or (input_size). Only used when all_to_all=False for 1-to-1 recurrent connections.
RSynaptic.threshold (torch.Tensor) - optional learnable thresholds; must be manually passed in, of shape 1 or (input_size).
- classmethod detach_hidden()[source]
Returns the hidden states, detached from the current graph. Intended for use in truncated backpropagation through time where hidden state variables are instance variables.
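A hedged sketch of how this fits truncated BPTT when init_hidden=True; the window length, sizes, and loss computation are placeholders:

import torch
import snntorch as snn

lif = snn.RSynaptic(alpha=0.9, beta=0.5, linear_features=128,
                    init_hidden=True, output=True)
window = 10

for step in range(100):
    spk, syn, mem = lif(torch.rand(4, 128))  # states are stored on the module
    if (step + 1) % window == 0:
        # ... compute the loss over this window and call loss.backward() here ...
        snn.RSynaptic.detach_hidden()        # sever the graph at the truncation boundary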
- forward(input_, spk=None, syn=None, mem=None)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- init_rsynaptic()[source]
Deprecated, use RSynaptic.reset_mem instead.
Used to clear hidden state variables to zero. Intended for use where hidden state variables are instance variables.
- class snntorch._neurons.rsynaptic.RecurrentOneToOne(V)[source]
Bases: Module
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
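Based on the description earlier on this page (where \(V(\cdot)\) with all_to_all=False acts as an elementwise multiplier), forward presumably returns the input scaled elementwise by V. A small illustrative sketch under that assumption:

import torch
from snntorch._neurons.rsynaptic import RecurrentOneToOne

one_to_one = RecurrentOneToOne(torch.tensor([0.5, 1.0, 2.0]))
spk = torch.tensor([1.0, 0.0, 1.0])
print(one_to_one(spk))  # assumed output: tensor([0.5000, 0.0000, 2.0000])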