import torch
import torch.nn as nn
from .neurons import LIF
class Synaptic(LIF):
    """
    2nd order leaky integrate and fire neuron model accounting for synaptic
    conductance.

    The synaptic current jumps upon spike arrival, which causes a jump in
    membrane potential.
    Synaptic current and membrane potential decay exponentially with rates
    of alpha and beta, respectively.
    For :math:`U[T] > U_{\\rm thr} ⇒ S[T+1] = 1`.

    If `reset_mechanism = "subtract"`, then :math:`U[t+1]` will have
    `threshold` subtracted from it whenever the neuron emits a spike:

    .. math::

            I_{\\rm syn}[t+1] = αI_{\\rm syn}[t] + I_{\\rm in}[t+1] \\\\
            U[t+1] = βU[t] + I_{\\rm syn}[t+1] - RU_{\\rm thr}

    If `reset_mechanism = "zero"`, then :math:`U[t+1]` will be set to `0`
    whenever the neuron emits a spike:

    .. math::

            I_{\\rm syn}[t+1] = αI_{\\rm syn}[t] + I_{\\rm in}[t+1] \\\\
            U[t+1] = βU[t] + I_{\\rm syn}[t+1] - R(βU[t] + I_{\\rm syn}[t+1])

    * :math:`I_{\\rm syn}` - Synaptic current
    * :math:`I_{\\rm in}` - Input current
    * :math:`U` - Membrane potential
    * :math:`U_{\\rm thr}` - Membrane threshold
    * :math:`R` - Reset mechanism: if active, :math:`R = 1`, otherwise \
    :math:`R = 0`
    * :math:`α` - Synaptic current decay rate
    * :math:`β` - Membrane potential decay rate

    Example::

        import torch
        import torch.nn as nn
        import snntorch as snn

        alpha = 0.9
        beta = 0.5

        # Define Network
        class Net(nn.Module):
            def __init__(self, num_inputs, num_hidden, num_outputs):
                super().__init__()

                # initialize layers
                self.fc1 = nn.Linear(num_inputs, num_hidden)
                self.lif1 = snn.Synaptic(alpha=alpha, beta=beta)
                self.fc2 = nn.Linear(num_hidden, num_outputs)
                self.lif2 = snn.Synaptic(alpha=alpha, beta=beta)

            def forward(self, x, syn1, mem1, syn2, mem2):
                cur1 = self.fc1(x)
                spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)
                cur2 = self.fc2(spk1)
                spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)
                return syn1, mem1, spk1, syn2, mem2, spk2

    :param alpha: synaptic current decay rate. Clipped between 0 and 1
        during the forward-pass. May be a single-valued tensor (i.e.,
        equal decay rate for all neurons in a layer), or multi-valued
        (one weight per neuron).
    :type alpha: float or torch.tensor

    :param beta: membrane potential decay rate. Clipped between 0 and 1
        during the forward-pass. May be a single-valued tensor (i.e., equal
        decay rate for all neurons in a layer), or multi-valued (one weight
        per neuron).
    :type beta: float or torch.tensor

    :param threshold: Threshold for :math:`mem` to reach in order to generate
        a spike `S=1`. Defaults to 1
    :type threshold: float, optional

    :param spike_grad: Surrogate gradient for the term dS/dU. Defaults to None
        (corresponds to Heaviside surrogate gradient. See `snntorch.surrogate`
        for more options)
    :type spike_grad: surrogate gradient function from snntorch.surrogate,
        optional

    :param surrogate_disable: Disables surrogate gradients regardless of
        `spike_grad` argument. Useful for ONNX compatibility. Defaults
        to False
    :type surrogate_disable: bool, Optional

    :param init_hidden: Instantiates state variables as instance variables.
        Defaults to False
    :type init_hidden: bool, optional

    :param inhibition: If `True`, suppresses all spiking other than the
        neuron with the highest state. Defaults to False
    :type inhibition: bool, optional

    :param learn_alpha: Option to enable learnable alpha. Defaults to False
    :type learn_alpha: bool, optional

    :param learn_beta: Option to enable learnable beta. Defaults to False
    :type learn_beta: bool, optional

    :param learn_threshold: Option to enable learnable threshold. Defaults
        to False
    :type learn_threshold: bool, optional

    :param reset_mechanism: Defines the reset mechanism applied to :math:`mem`
        each time the threshold is met. Reset-by-subtraction: "subtract",
        reset-to-zero: "zero", none: "none". Defaults to "subtract"
    :type reset_mechanism: str, optional

    :param state_quant: If specified, hidden states :math:`mem` and \
    :math:`syn` are quantized to a valid state for the forward pass. \
    Defaults to False
    :type state_quant: quantization function from snntorch.quant, optional

    :param output: If `True` as well as `init_hidden=True`, states are
        returned when neuron is called. Defaults to False
    :type output: bool, optional

    :param reset_delay: If `True`, the reset triggered by a spike is
        applied on the following call to `forward`. If `False`, the
        membrane potential is reset in the same step the spike is
        emitted. Has no effect when `reset_mechanism = "none"`.
        Defaults to True
    :type reset_delay: bool, optional


    Inputs: \\input_, syn_0, mem_0
        - **input_** of shape `(batch, input_size)`: tensor containing \
        input features
        - **syn_0** of shape `(batch, input_size)`: tensor containing \
        the initial synaptic current for each element in the batch.
        - **mem_0** of shape `(batch, input_size)`: tensor containing \
        the initial membrane potential for each element in the batch.

        A state whose shape differs from that of `input_` is replaced by \
        zeros before the update.

    Outputs: spk, syn_1, mem_1
        - **spk** of shape `(batch, input_size)`: tensor containing the \
        output spikes.
        - **syn_1** of shape `(batch, input_size)`: tensor containing the \
        next synaptic current for each element in the batch
        - **mem_1** of shape `(batch, input_size)`: tensor containing the \
        next membrane potential for each element in the batch

    Learnable Parameters:
        - **Synaptic.alpha** (torch.Tensor) - optional learnable weights \
        must be manually passed in, of shape `1` or (input_size).
        - **Synaptic.beta** (torch.Tensor) - optional learnable weights must \
        be manually passed in, of shape `1` or (input_size).
        - **Synaptic.threshold** (torch.Tensor) - optional learnable \
        thresholds must be manually passed in, of shape `1` or (input_size).

    """

    def __init__(
        self,
        alpha,
        beta,
        threshold=1.0,
        spike_grad=None,
        surrogate_disable=False,
        init_hidden=False,
        inhibition=False,
        learn_alpha=False,
        learn_beta=False,
        learn_threshold=False,
        reset_mechanism="subtract",
        state_quant=False,
        output=False,
        reset_delay=True,
    ):
        super().__init__(
            beta,
            threshold,
            spike_grad,
            surrogate_disable,
            init_hidden,
            inhibition,
            learn_beta,
            learn_threshold,
            reset_mechanism,
            state_quant,
            output,
        )

        self._alpha_register_buffer(alpha, learn_alpha)
        self._init_mem()

        # Select the per-step state update once. `reset_mechanism_val` is
        # not set in this class -- presumably derived from
        # `reset_mechanism` by the parent class.
        if self.reset_mechanism_val == 0:  # reset by subtraction
            self.state_function = self._base_sub
        elif self.reset_mechanism_val == 1:  # reset to zero
            self.state_function = self._base_zero
        elif self.reset_mechanism_val == 2:  # no reset, pure integration
            self.state_function = self._base_int

        self.reset_delay = reset_delay

    def _init_mem(self):
        """Register empty, non-persistent `syn` and `mem` buffers.

        `forward` replaces them with zeros shaped like the input on first
        use, so they are not saved in the state dict.
        """
        syn = torch.zeros(0)
        mem = torch.zeros(0)
        self.register_buffer("syn", syn, False)
        self.register_buffer("mem", mem, False)

    def reset_mem(self):
        """Zero the synaptic current and membrane potential in place of the
        current buffers (shape and device preserved).

        :return: the zeroed ``(syn, mem)`` tensors
        """
        self.syn = torch.zeros_like(self.syn, device=self.syn.device)
        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
        return self.syn, self.mem

    def init_synaptic(self):
        """Deprecated, use :class:`Synaptic.reset_mem` instead"""
        return self.reset_mem()

    def forward(self, input_, syn=None, mem=None):
        """Advance the neuron by one time step.

        :param input_: input current for this step
        :param syn: optional synaptic current to use as the starting state;
            must not be given when ``init_hidden=True``
        :param mem: optional membrane potential to use as the starting
            state; must not be given when ``init_hidden=True``
        :raises TypeError: if ``syn`` or ``mem`` is passed while
            ``init_hidden=True``
        :return: ``spk`` alone when ``init_hidden=True`` and
            ``output=False``; otherwise ``(spk, syn, mem)``
        """
        # Validate before touching the stored state so a misuse does not
        # clobber the hidden variables right before raising.
        if self.init_hidden and (syn is not None or mem is not None):
            raise TypeError(
                "`mem` or `syn` should not be passed as an argument while `init_hidden=True`"
            )

        if syn is not None:
            self.syn = syn
        if mem is not None:
            self.mem = mem

        # States that do not match the input shape (e.g. the empty
        # buffers from `_init_mem`) are restarted from zero.
        if self.syn.shape != input_.shape:
            self.syn = torch.zeros_like(input_, device=self.syn.device)
        if self.mem.shape != input_.shape:
            self.mem = torch.zeros_like(input_, device=self.mem.device)

        # Reset flag is computed from the membrane potential *before* this
        # step's update, so a spike's reset lands on the next step.
        self.reset = self.mem_reset(self.mem)
        self.syn, self.mem = self.state_function(input_)

        if self.state_quant:
            self.mem = self.state_quant(self.mem)
            self.syn = self.state_quant(self.syn)

        if self.inhibition:
            spk = self.fire_inhibition(self.mem.size(0), self.mem)  # batch_size
        else:
            spk = self.fire(self.mem)

        if not self.reset_delay:
            # Reset the membrane potential right after the spike; subtract
            # the reset already applied this step to avoid a double reset.
            do_reset = spk / self.graded_spikes_factor - self.reset
            if self.reset_mechanism_val == 0:  # reset by subtraction
                self.mem = self.mem - do_reset * self.threshold
            elif self.reset_mechanism_val == 1:  # reset to zero
                self.mem = self.mem - do_reset * self.mem

        if self.init_hidden and not self.output:
            return spk
        return spk, self.syn, self.mem

    def _base_state_function(self, input_):
        """Un-reset update: decay both states and integrate the input.

        Decay rates are clamped to [0, 1] on every call.
        """
        base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
        base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
        return base_fn_syn, base_fn_mem

    def _base_state_reset_zero(self, input_):
        """Return the amounts to remove for reset-to-zero: ``0`` for the
        synaptic current and the full un-reset membrane potential.

        Kept for compatibility; `_base_zero` no longer calls it.
        """
        base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
        base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
        return 0, base_fn_mem

    def _base_sub(self, input_):
        """State update with reset by subtraction of the threshold."""
        syn, mem = self._base_state_function(input_)
        mem = mem - self.reset * self.threshold
        return syn, mem

    def _base_zero(self, input_):
        """State update with reset to zero.

        Only the membrane potential is zeroed where `self.reset` is set;
        the synaptic current is left untouched.
        """
        syn, mem = self._base_state_function(input_)
        mem = mem - mem * self.reset
        return syn, mem

    def _base_int(self, input_):
        """State update with no reset (pure integration)."""
        return self._base_state_function(input_)

    def _alpha_register_buffer(self, alpha, learn_alpha):
        """Store `alpha` as a learnable parameter or a (persistent) buffer.

        :param alpha: decay rate; non-tensors are converted with
            ``torch.as_tensor``
        :param learn_alpha: if true, wrap `alpha` in ``nn.Parameter``
        """
        if not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha)
        if learn_alpha:
            self.alpha = nn.Parameter(alpha)
        else:
            self.register_buffer("alpha", alpha)

    @classmethod
    def detach_hidden(cls):
        """Detach the hidden states of every `Synaptic` instance in place.

        Intended for use in truncated backpropagation through time where
        hidden state variables are instance variables. Returns nothing.
        """
        for layer in cls.instances:
            if isinstance(layer, Synaptic):
                layer.syn.detach_()
                layer.mem.detach_()

    @classmethod
    def reset_hidden(cls):
        """Used to clear hidden state variables to zero.

        Intended for use where hidden state variables are instance
        variables. Each `Synaptic` instance in ``cls.instances`` gets fresh
        zero tensors with the same shape and device as its current states.
        """
        for layer in cls.instances:
            if isinstance(layer, Synaptic):
                layer.syn = torch.zeros_like(layer.syn, device=layer.syn.device)
                layer.mem = torch.zeros_like(layer.mem, device=layer.mem.device)