snn.Synaptic

class snntorch._neurons.synaptic.Synaptic(alpha, beta, threshold=1.0, spike_grad=None, surrogate_disable=False, init_hidden=False, inhibition=False, learn_alpha=False, learn_beta=False, learn_threshold=False, reset_mechanism='subtract', state_quant=False, output=False, reset_delay=True)

Bases: LIF

Second-order leaky integrate-and-fire neuron model accounting for synaptic conductance. The synaptic current jumps upon spike arrival, which in turn causes a jump in membrane potential. Synaptic current and membrane potential decay exponentially with rates alpha and beta, respectively. If \(U[T] > U_{\rm thr}\), then \(S[T+1] = 1\).

If reset_mechanism = “subtract”, then \(U[t+1]\) will have threshold subtracted from it whenever the neuron emits a spike:

\[\begin{split}I_{\rm syn}[t+1] = αI_{\rm syn}[t] + I_{\rm in}[t+1] \\ U[t+1] = βU[t] + I_{\rm syn}[t+1] - RU_{\rm thr}\end{split}\]

If reset_mechanism = “zero”, then \(U[t+1]\) will be set to 0 whenever the neuron emits a spike:

\[\begin{split}I_{\rm syn}[t+1] = αI_{\rm syn}[t] + I_{\rm in}[t+1] \\ U[t+1] = βU[t] + I_{\rm syn}[t+1] - R(βU[t] + I_{\rm syn}[t+1])\end{split}\]

  • \(I_{\rm syn}\) - Synaptic current

  • \(I_{\rm in}\) - Input current

  • \(U\) - Membrane potential

  • \(U_{\rm thr}\) - Membrane threshold

  • \(R\) - Reset mechanism: if active, \(R = 1\), otherwise \(R = 0\)

  • \(α\) - Synaptic current decay rate

  • \(β\) - Membrane potential decay rate
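The two update rules above can be traced by hand with a minimal pure-Python sketch for a single neuron. The names `synaptic_step` and `run` are hypothetical, and deriving \(R\) from the previous membrane potential is an assumption made for illustration; this is not snnTorch's implementation.

```python
def synaptic_step(i_in, syn, mem, alpha, beta, threshold=1.0,
                  reset_mechanism="subtract"):
    """Advance one scalar neuron by a single time step."""
    # Assumption: R is 1 when the previous membrane potential exceeded threshold.
    r = 1.0 if mem > threshold else 0.0

    # I_syn[t+1] = alpha * I_syn[t] + I_in[t+1]
    syn = alpha * syn + i_in

    # Membrane potential before any reset is applied.
    base = beta * mem + syn

    if reset_mechanism == "subtract":
        mem = base - r * threshold   # subtract the threshold after a spike
    elif reset_mechanism == "zero":
        mem = base - r * base        # force the potential to 0 after a spike
    elif reset_mechanism == "none":
        mem = base                   # never reset
    else:
        raise ValueError(f"unknown reset_mechanism: {reset_mechanism!r}")

    spk = 1.0 if mem > threshold else 0.0
    return spk, syn, mem


def run(reset_mechanism, num_steps=4, i_in=0.3, alpha=0.9, beta=0.5):
    """Drive the neuron with a constant input and record (spike, membrane)."""
    syn, mem = 0.0, 0.0
    trace = []
    for _ in range(num_steps):
        spk, syn, mem = synaptic_step(i_in, syn, mem, alpha, beta,
                                      reset_mechanism=reset_mechanism)
        trace.append((spk, mem))
    return trace


for mechanism in ("subtract", "zero"):
    print(mechanism, run(mechanism))
```

With these values the neuron first crosses threshold on the third step. On the fourth step the two mechanisms diverge: "subtract" leaves a residual potential, while "zero" returns the membrane to 0.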

Example:

import torch
import torch.nn as nn
import snntorch as snn

alpha = 0.9
beta = 0.5

# Define Network
class Net(nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()

        # initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Synaptic(alpha=alpha, beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Synaptic(alpha=alpha, beta=beta)

    def forward(self, x, syn1, mem1, syn2, mem2):
        # spk1 is computed below, so it does not need to be passed in
        cur1 = self.fc1(x)
        spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)
        cur2 = self.fc2(spk1)
        spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)
        return syn1, mem1, spk1, syn2, mem2, spk2
Parameters:
  • alpha (float or torch.Tensor) – synaptic current decay rate. Clipped between 0 and 1 during the forward pass. May be a single-valued tensor (i.e., equal decay rate for all neurons in a layer), or multi-valued (one decay rate per neuron).

  • beta (float or torch.Tensor) – membrane potential decay rate. Clipped between 0 and 1 during the forward pass. May be a single-valued tensor (i.e., equal decay rate for all neurons in a layer), or multi-valued (one decay rate per neuron).

  • threshold (float, optional) – Threshold that \(mem\) must exceed to generate a spike (S=1). Defaults to 1.0

  • spike_grad (surrogate gradient function from snntorch.surrogate, optional) – Surrogate gradient for the term dS/dU. Defaults to None (corresponds to Heaviside surrogate gradient; see snntorch.surrogate for more options)

  • surrogate_disable (bool, optional) – Disables surrogate gradients regardless of the spike_grad argument. Useful for ONNX compatibility. Defaults to False

  • init_hidden (bool, optional) – Instantiates state variables as instance variables. Defaults to False

  • inhibition (bool, optional) – If True, suppresses all spiking other than the neuron with the highest state. Defaults to False

  • learn_alpha (bool, optional) – Option to enable learnable alpha. Defaults to False

  • learn_beta (bool, optional) – Option to enable learnable beta. Defaults to False

  • learn_threshold (bool, optional) – Option to enable learnable threshold. Defaults to False

  • reset_mechanism (str, optional) – Defines the reset mechanism applied to \(mem\) each time the threshold is met. Reset-by-subtraction: “subtract”, reset-to-zero: “zero”, none: “none”. Defaults to “subtract”

  • state_quant (quantization function from snntorch.quant, optional) – If specified, hidden states \(mem\) and \(syn\) are quantized to a valid state for the forward pass. Defaults to False

  • output (bool, optional) – If True and init_hidden=True, states are returned when the neuron is called. Defaults to False
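The inhibition option described in the list above can be pictured as a winner-take-all mask over each row of the batch. The following pure-PyTorch sketch uses a hypothetical helper, `inhibited_spikes`, and assumes the "highest state" is picked per batch element along the neuron dimension; it illustrates the idea rather than reproducing snnTorch's code.

```python
import torch


def inhibited_spikes(mem, threshold=1.0):
    """Let only the neuron with the largest membrane potential spike.

    mem has shape (batch, num_neurons).
    """
    # Plain threshold crossing: 1.0 where the potential exceeds threshold.
    above = (mem > threshold).float()

    # Index of the highest-state neuron in each batch row.
    winner = torch.argmax(mem, dim=1)

    # One-hot mask that keeps only the winner of each row.
    mask = torch.zeros_like(mem)
    mask[torch.arange(mem.shape[0]), winner] = 1.0

    # A neuron spikes only if it crossed threshold AND it is the winner.
    return above * mask


mem = torch.tensor([[1.5, 2.0, 0.2],
                    [0.1, 0.3, 0.2]])
spk = inhibited_spikes(mem)
print(spk)
```

In the first row two neurons exceed the threshold but only the larger one fires; in the second row no neuron crosses threshold, so nothing fires.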

Inputs: input_, syn_0, mem_0
  • input_ of shape (batch, input_size): tensor containing input features

  • syn_0 of shape (batch, input_size): tensor containing the initial synaptic current for each element in the batch.

  • mem_0 of shape (batch, input_size): tensor containing the initial membrane potential for each element in the batch.

Outputs: spk, syn_1, mem_1
  • spk of shape (batch, input_size): tensor containing the output spikes.

  • syn_1 of shape (batch, input_size): tensor containing the next synaptic current for each element in the batch

  • mem_1 of shape (batch, input_size): tensor containing the next membrane potential for each element in the batch

Learnable Parameters:
  • Synaptic.alpha (torch.Tensor) - optional learnable decay rates, which must be passed in manually, of shape 1 or (input_size).

  • Synaptic.beta (torch.Tensor) - optional learnable decay rates, which must be passed in manually, of shape 1 or (input_size).

  • Synaptic.threshold (torch.Tensor) - optional learnable thresholds must be manually passed in, of shape 1 or (input_size).
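To make the per-neuron shape and the clipping behaviour concrete, here is a minimal pure-PyTorch sketch of a learnable decay rate of shape (input_size). It applies only the synaptic-current update with a hand-made `nn.Parameter`; it is an assumption-laden illustration of the idea, not snnTorch's internal code.

```python
import torch
import torch.nn as nn

num_neurons = 3

# One decay rate per neuron; the first value is out of range on purpose.
alpha = nn.Parameter(torch.tensor([1.2, 0.9, 0.5]))

syn = torch.ones(2, num_neurons)    # (batch, input_size)
i_in = torch.zeros(2, num_neurons)  # no new input in this step

# Clip to [0, 1] before use, then broadcast across the batch dimension.
syn_next = alpha.clamp(0, 1) * syn + i_in

# Backpropagate a dummy loss to show that gradients reach alpha.
syn_next.sum().backward()
print(syn_next)
print(alpha.grad)
```

In this sketch the out-of-range entry is clipped to 1 and receives zero gradient through `clamp`, while the in-range entries accumulate gradient from every element of the batch.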

classmethod detach_hidden()

Returns the hidden states, detached from the current graph. Intended for use in truncated backpropagation through time where hidden state variables are instance variables.
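The effect of detaching state between chunks in truncated backpropagation through time can be shown with a toy leaky state in plain PyTorch. The scalar weight `w`, the decay of 0.9, and the chunk length of 3 are all hypothetical; the sketch uses `Tensor.detach()` directly rather than the classmethod above.

```python
import torch

w = torch.nn.Parameter(torch.tensor(0.5))  # toy learnable input weight
mem = torch.zeros(1)                       # toy hidden state
chunk_len = 3

for step in range(6):
    # Leaky accumulation: the state depends on every earlier step.
    mem = 0.9 * mem + w

    if (step + 1) % chunk_len == 0:
        # Backpropagate through the current chunk only.
        mem.sum().backward()

        # Cut the graph so the next chunk does not reach back further.
        mem = mem.detach()

print(w.grad, mem.grad_fn)
```

Each chunk contributes the same gradient (1 + 0.9 + 0.81 = 2.71) because the detached state is treated as a constant at the start of the next chunk, so the accumulated gradient on `w` is twice that value.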

forward(input_, syn=None, mem=None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
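The difference described in the note can be checked with any `torch.nn.Module`. The sketch below uses an `nn.Linear` layer and a forward hook purely for illustration; the same behaviour applies to any module that is called as an instance rather than through `.forward()`.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
calls = []

# Record the output shape every time the hook fires.
handle = layer.register_forward_hook(
    lambda module, inputs, output: calls.append(tuple(output.shape))
)

x = torch.zeros(4, 3)

layer(x)          # calling the instance runs registered hooks
layer.forward(x)  # calling forward() directly skips them

handle.remove()
print(calls)
```

Only the call through the module instance is recorded, which is why the instance call is the recommended entry point.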

init_synaptic()

Deprecated; use Synaptic.reset_mem instead.

classmethod reset_hidden()

Used to clear hidden state variables to zero. Intended for use where hidden state variables are instance variables.

reset_mem()