import torch
import torch.nn as nn
from .neurons import LIF
class RSynaptic(LIF):
    """
    2nd order recurrent leaky integrate and fire neuron model accounting for
    synaptic conductance.

    The synaptic current jumps upon spike arrival, which causes a jump in
    membrane potential.
    Synaptic current and membrane potential decay exponentially with rates
    of alpha and beta, respectively.
    For :math:`U[T] > U_{\\rm thr} ⇒ S[T+1] = 1`.

    If `reset_mechanism = "subtract"`, then :math:`U[t+1]` will have
    `threshold` subtracted from it whenever the neuron emits a spike:

    .. math::

            I_{\\rm syn}[t+1] = αI_{\\rm syn}[t] + V(S_{\\rm out}[t])
            + I_{\\rm in}[t+1] \\\\
            U[t+1] = βU[t] + I_{\\rm syn}[t+1] - RU_{\\rm thr}

    Where :math:`V(\\cdot)` acts either as a linear layer, a convolutional
    operator, or elementwise product on :math:`S_{\\rm out}`.

    * If `all_to_all=True` and `linear_features` is specified, then \
    :math:`V(\\cdot)` acts as a recurrent linear layer of the same size \
    as :math:`S_{\\rm out}`.
    * If `all_to_all=True` and `conv2d_channels` and `kernel_size` are \
    specified, then :math:`V(\\cdot)` acts as a recurrent convolutional \
    layer with padding to ensure the output matches the size of the input.
    * If `all_to_all=False`, then :math:`V(\\cdot)` acts as an \
    elementwise multiplier with :math:`V`.

    If `reset_mechanism = "zero"`, then :math:`U[t+1]` will be set to `0`
    whenever the neuron emits a spike:

    .. math::

            I_{\\rm syn}[t+1] = αI_{\\rm syn}[t] + V(S_{\\rm out}[t])
            + I_{\\rm in}[t+1] \\\\
            U[t+1] = βU[t] + I_{\\rm syn}[t+1] - R(βU[t] + I_{\\rm syn}[t+1])

    * :math:`I_{\\rm syn}` - Synaptic current
    * :math:`I_{\\rm in}` - Input current
    * :math:`U` - Membrane potential
    * :math:`U_{\\rm thr}` - Membrane threshold
    * :math:`S_{\\rm out}` - Output spike
    * :math:`R` - Reset mechanism: if active, :math:`R = 1`, otherwise \
    :math:`R = 0`
    * :math:`α` - Synaptic current decay rate
    * :math:`β` - Membrane potential decay rate
    * :math:`V` - Explicit recurrent weight when `all_to_all=False`

    Example::

        import torch
        import torch.nn as nn
        import snntorch as snn

        alpha = 0.9  # synaptic current decay rate
        beta = 0.5  # membrane potential decay rate

        V1 = 0.5  # shared recurrent connection
        V2 = torch.rand(num_outputs)  # unshared recurrent connections

        # Define Network
        class Net(nn.Module):
            def __init__(self):
                super().__init__()

                # initialize layers
                self.fc1 = nn.Linear(num_inputs, num_hidden)

                # Default RSynaptic layer where recurrent connections
                # are initialized using PyTorch defaults in nn.Linear.
                self.lif1 = snn.RSynaptic(alpha=alpha, beta=beta,
                                          linear_features=num_hidden)

                self.fc2 = nn.Linear(num_hidden, num_outputs)

                # each neuron has a single connection back to itself
                # where the output spike is scaled by V.
                # For `all_to_all = False`, V can be shared between
                # neurons (e.g., V1) or unique / unshared between
                # neurons (e.g., V2).
                # V is learnable by default.
                self.lif2 = snn.RSynaptic(alpha=alpha, beta=beta,
                                          all_to_all=False, V=V1)

            def forward(self, x):
                # Initialize hidden states at t=0
                spk1, syn1, mem1 = self.lif1.reset_mem()
                spk2, syn2, mem2 = self.lif2.reset_mem()

                # Record output layer spikes and membrane
                spk2_rec = []
                mem2_rec = []

                # time-loop
                for step in range(num_steps):
                    cur1 = self.fc1(x)
                    spk1, syn1, mem1 = self.lif1(cur1, spk1, syn1, mem1)
                    cur2 = self.fc2(spk1)
                    spk2, syn2, mem2 = self.lif2(cur2, spk2, syn2, mem2)

                    spk2_rec.append(spk2)
                    mem2_rec.append(mem2)

                # convert lists to tensors
                spk2_rec = torch.stack(spk2_rec)
                mem2_rec = torch.stack(mem2_rec)

                return spk2_rec, mem2_rec

    :param alpha: synaptic current decay rate. Clipped between 0 and 1
        during the forward-pass. May be a single-valued tensor (i.e., equal
        decay rate for all neurons in a layer), or multi-valued (one weight
        per neuron).
    :type alpha: float or torch.tensor

    :param beta: membrane potential decay rate. Clipped between 0 and 1
        during the forward-pass. May be a single-valued tensor (i.e., equal
        decay rate for all neurons in a layer), or multi-valued (one weight
        per neuron).
    :type beta: float or torch.tensor

    :param V: Recurrent weights to scale output spikes, only used when
        `all_to_all=False`. Defaults to 1.
    :type V: float or torch.tensor

    :param all_to_all: Enables output spikes to be connected in dense or
        convolutional recurrent structures instead of 1-to-1 connections.
        Defaults to True.
    :type all_to_all: bool, optional

    :param linear_features: Size of each output sample. Must be specified if
        `all_to_all=True` and the input data is 1D. Defaults to None
    :type linear_features: int, optional

    :param conv2d_channels: Number of channels in each output sample. Must
        be specified if `all_to_all=True` and the input data is 3D. Defaults
        to None
    :type conv2d_channels: int, optional

    :param kernel_size: Size of the convolving kernel. Must be specified if
        `all_to_all=True` and the input data is 3D. Defaults to None
    :type kernel_size: int or tuple

    :param threshold: Threshold for :math:`mem` to reach in order to
        generate a spike `S=1`. Defaults to 1
    :type threshold: float, optional

    :param spike_grad: Surrogate gradient for the term dS/dU. Defaults to
        None (corresponds to ATan surrogate gradient. See
        `snntorch.surrogate` for more options)
    :type spike_grad: surrogate gradient function from snntorch.surrogate,
        optional

    :param surrogate_disable: Disables surrogate gradients regardless of
        `spike_grad` argument. Useful for ONNX compatibility. Defaults
        to False
    :type surrogate_disable: bool, Optional

    :param init_hidden: Instantiates state variables as instance variables.
        Defaults to False
    :type init_hidden: bool, optional

    :param inhibition: If `True`, suppresses all spiking other than the
        neuron with the highest state. Defaults to False
    :type inhibition: bool, optional

    :param learn_alpha: Option to enable learnable alpha. Defaults to False
    :type learn_alpha: bool, optional

    :param learn_beta: Option to enable learnable beta. Defaults to False
    :type learn_beta: bool, optional

    :param learn_recurrent: Option to enable learnable recurrent weights.
        Defaults to True
    :type learn_recurrent: bool, optional

    :param learn_threshold: Option to enable learnable threshold.
        Defaults to False
    :type learn_threshold: bool, optional

    :param reset_mechanism: Defines the reset mechanism applied to \
    :math:`mem` each time the threshold is met. Reset-by-subtraction: \
    "subtract", reset-to-zero: "zero", none: "none". Defaults to "subtract"
    :type reset_mechanism: str, optional

    :param state_quant: If specified, hidden states :math:`mem` and \
    :math:`syn` are quantized to a valid state for the forward pass. \
    Defaults to False
    :type state_quant: quantization function from snntorch.quant, optional

    :param output: If `True` as well as `init_hidden=True`, states are
        returned when neuron is called. Defaults to False
    :type output: bool, optional

    :param reset_delay: If `True`, the reset triggered by a spike is applied
        at the next time step (computed from the membrane potential carried
        over from the previous step). If `False`, the membrane potential is
        additionally reset in the same time step the spike is emitted.
        Defaults to True
    :type reset_delay: bool, optional

    Inputs: \\input_, spk_0, syn_0, mem_0
        - **input_** of shape `(batch, input_size)`: tensor containing input \
        features
        - **spk_0** of shape `(batch, input_size)`: tensor containing the \
        output spikes of the previous time step (fed back through the \
        recurrent connection)
        - **syn_0** of shape `(batch, input_size)`: tensor containing the \
        initial synaptic current for each element in the batch.
        - **mem_0** of shape `(batch, input_size)`: tensor containing the \
        initial membrane potential for each element in the batch.

    Outputs: spk_1, syn_1, mem_1
        - **spk_1** of shape `(batch, input_size)`: tensor containing the \
        output spikes.
        - **syn_1** of shape `(batch, input_size)`: tensor containing the \
        next synaptic current for each element in the batch
        - **mem_1** of shape `(batch, input_size)`: tensor containing the \
        next membrane potential for each element in the batch

    Learnable Parameters:
        - **RSynaptic.alpha** (torch.Tensor) - optional learnable weights \
        must be manually passed in, of shape `1` or (input_size).
        - **RSynaptic.beta** (torch.Tensor) - optional learnable weights \
        must be manually passed in, of shape `1` or (input_size).
        - **RSynaptic.recurrent.weight** (torch.Tensor) - optional learnable \
        weights are automatically generated if `all_to_all=True`. \
        `RSynaptic.recurrent` stores a `nn.Linear` or `nn.Conv2d` layer \
        depending on input arguments provided.
        - **RSynaptic.V** (torch.Tensor) - optional learnable weights must \
        be manually passed in, of shape `1` or (input_size). It is only used \
        where `all_to_all=False` for 1-to-1 recurrent connections.
        - **RSynaptic.threshold** (torch.Tensor) - optional learnable \
        thresholds must be manually passed in, of shape `1` or (input_size).
    """

    def __init__(
        self,
        alpha,
        beta,
        V=1.0,
        all_to_all=True,
        linear_features=None,
        conv2d_channels=None,
        kernel_size=None,
        threshold=1.0,
        spike_grad=None,
        surrogate_disable=False,
        init_hidden=False,
        inhibition=False,
        learn_alpha=False,
        learn_beta=False,
        learn_threshold=False,
        learn_recurrent=True,
        reset_mechanism="subtract",
        state_quant=False,
        output=False,
        reset_delay=True,
    ):
        super().__init__(
            beta,
            threshold,
            spike_grad,
            surrogate_disable,
            init_hidden,
            inhibition,
            learn_beta,
            learn_threshold,
            reset_mechanism,
            state_quant,
            output,
        )

        self.all_to_all = all_to_all
        self.learn_recurrent = learn_recurrent

        # linear params
        self.linear_features = linear_features

        # Conv2d params
        self.kernel_size = kernel_size
        self.conv2d_channels = conv2d_channels

        # reject inconsistent recurrent-layer arguments before building it
        self._rsynaptic_init_cases()

        # initialize recurrent connections
        if self.all_to_all:  # init all-all connections
            self._init_recurrent_net()
        else:  # initialize 1-1 connections
            self._V_register_buffer(V, learn_recurrent)
            self._init_recurrent_one_to_one()

        if not learn_recurrent:
            self._disable_recurrent_grad()

        self._alpha_register_buffer(alpha, learn_alpha)

        self._init_mem()

        if self.reset_mechanism_val == 0:  # reset by subtraction
            self.state_function = self._base_sub
        elif self.reset_mechanism_val == 1:  # reset to zero
            self.state_function = self._base_zero
        elif self.reset_mechanism_val == 2:  # no reset, pure integration
            self.state_function = self._base_int

        self.reset_delay = reset_delay

    def _init_mem(self):
        """Register empty, non-persistent buffers for spk, syn and mem.

        They are resized to the input's shape on the first forward call.
        """
        spk = torch.zeros(0)
        syn = torch.zeros(0)
        mem = torch.zeros(0)

        self.register_buffer("spk", spk, False)
        self.register_buffer("syn", syn, False)
        self.register_buffer("mem", mem, False)

    def reset_mem(self):
        """Zero the hidden states (keeping their shape and device).

        :return: the zeroed ``(spk, syn, mem)`` tensors
        """
        self.spk = torch.zeros_like(self.spk, device=self.spk.device)
        self.syn = torch.zeros_like(self.syn, device=self.syn.device)
        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
        return self.spk, self.syn, self.mem

    def init_rsynaptic(self):
        """Deprecated, use :class:`RSynaptic.reset_mem` instead"""
        return self.reset_mem()

    def forward(self, input_, spk=None, syn=None, mem=None):
        """Advance the neuron by one time step.

        :param input_: input current for this step
        :param spk: optional output spikes of the previous step
        :param syn: optional synaptic current of the previous step
        :param mem: optional membrane potential of the previous step
        :raises TypeError: if any state is passed while ``init_hidden=True``
        :return: ``spk`` only when ``init_hidden=True`` and ``output=False``;
            otherwise the tuple ``(spk, syn, mem)``
        """
        # Validate first so an invalid call leaves the stored state intact.
        if self.init_hidden and (
            spk is not None or syn is not None or mem is not None
        ):
            raise TypeError(
                "When `init_hidden=True`, RSynaptic expects 1 input argument."
            )

        if spk is not None:
            self.spk = spk
        if syn is not None:
            self.syn = syn
        if mem is not None:
            self.mem = mem

        # (Re)initialize any state whose shape does not match the input,
        # e.g. on the first call when the buffers are still empty.
        if not self.spk.shape == input_.shape:
            self.spk = torch.zeros_like(input_, device=self.spk.device)
        if not self.syn.shape == input_.shape:
            self.syn = torch.zeros_like(input_, device=self.syn.device)
        if not self.mem.shape == input_.shape:
            self.mem = torch.zeros_like(input_, device=self.mem.device)

        # reset is computed from the membrane potential of the previous step
        self.reset = self.mem_reset(self.mem)
        self.syn, self.mem = self.state_function(input_)

        if self.state_quant:
            self.syn = self.state_quant(self.syn)
            self.mem = self.state_quant(self.mem)

        if self.inhibition:
            self.spk = self.fire_inhibition(self.mem.size(0), self.mem)
        else:
            self.spk = self.fire(self.mem)

        if not self.reset_delay:
            # Reset the membrane potential _right_ after the spike, using the
            # spike just computed (not the caller's previous-step argument)
            # and storing the result back into the layer state.
            do_reset = (
                self.spk / self.graded_spikes_factor - self.reset
            )  # avoid double reset
            if self.reset_mechanism_val == 0:  # reset by subtraction
                self.mem = self.mem - do_reset * self.threshold
            elif self.reset_mechanism_val == 1:  # reset to zero
                self.mem = self.mem - do_reset * self.mem

        if self.init_hidden and not self.output:
            return self.spk
        return self.spk, self.syn, self.mem

    def _init_recurrent_net(self):
        """Build the all-to-all recurrent layer (linear or conv2d)."""
        if self.all_to_all:
            if self.linear_features:
                self._init_recurrent_linear()
            elif self.kernel_size is not None:
                self._init_recurrent_conv2d()
        else:
            self._init_recurrent_one_to_one()

    def _init_recurrent_linear(self):
        """Square linear layer mapping output spikes back to the input."""
        self.recurrent = nn.Linear(self.linear_features, self.linear_features)

    def _init_recurrent_conv2d(self):
        """Channel-preserving Conv2d used as the recurrent connection."""
        self._init_padding()
        self.recurrent = nn.Conv2d(
            in_channels=self.conv2d_channels,
            out_channels=self.conv2d_channels,
            kernel_size=self.kernel_size,
            padding=self.padding,
        )

    def _init_padding(self):
        """Set ``self.padding`` to half the kernel size along each axis.

        With the default stride of 1 this preserves the spatial size for odd
        kernel sizes.
        """
        if isinstance(self.kernel_size, int):
            self.padding = self.kernel_size // 2, self.kernel_size // 2
        else:
            self.padding = self.kernel_size[0] // 2, self.kernel_size[1] // 2

    def _init_recurrent_one_to_one(self):
        """Elementwise recurrent connection scaled by ``self.V``."""
        self.recurrent = RecurrentOneToOne(self.V)

    def _disable_recurrent_grad(self):
        """Freeze every parameter of the recurrent layer."""
        for param in self.recurrent.parameters():
            param.requires_grad = False

    def _base_state_function(self, input_):
        """Compute the un-reset synaptic current and membrane potential."""
        base_fn_syn = (
            self.alpha.clamp(0, 1) * self.syn
            + input_
            + self.recurrent(self.spk)
        )
        base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
        return base_fn_syn, base_fn_mem

    def _base_state_reset_zero(self, input_):
        """Return ``(0, mem)``: the term subtracted by a reset-to-zero.

        No longer used by :meth:`_base_zero`; kept for compatibility.
        """
        _, base_fn_mem = self._base_state_function(input_)
        return 0, base_fn_mem

    def _base_sub(self, input_):
        """State update with reset by subtraction of the threshold."""
        syn, mem = self._base_state_function(input_)
        mem = mem - self.reset * self.threshold
        return syn, mem

    def _base_zero(self, input_):
        """State update with reset to zero.

        The state function (including the recurrent layer) is evaluated only
        once; the synaptic current is not affected by the reset.
        """
        syn, mem = self._base_state_function(input_)
        mem = mem - self.reset * mem
        return syn, mem

    def _base_int(self, input_):
        """State update without any reset (pure integration)."""
        return self._base_state_function(input_)

    def _alpha_register_buffer(self, alpha, learn_alpha):
        """Store alpha as an ``nn.Parameter`` if learnable, else a buffer."""
        if not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha)
        if learn_alpha:
            self.alpha = nn.Parameter(alpha)
        else:
            self.register_buffer("alpha", alpha)

    def _rsynaptic_init_cases(self):
        """Validate the combination of recurrent-layer arguments.

        :raises TypeError: when the arguments do not describe exactly one
            recurrent structure
        """
        all_to_all_bool = bool(self.all_to_all)
        linear_features_bool = bool(self.linear_features)
        conv2d_channels_bool = bool(self.conv2d_channels)
        kernel_size_bool = bool(self.kernel_size)

        if all_to_all_bool:
            if not linear_features_bool:
                if not (conv2d_channels_bool or kernel_size_bool):
                    raise TypeError(
                        "When `all_to_all=True`, RSynaptic requires either "
                        "`linear_features` or (`conv2d_channels` and "
                        "`kernel_size`) to be specified. The shape should "
                        "match the shape of the output spike of the layer."
                    )
                elif conv2d_channels_bool ^ kernel_size_bool:
                    raise TypeError(
                        "`conv2d_channels` and `kernel_size` must both be "
                        "specified. The shape of `conv2d_channels` should "
                        "match the shape of the output spikes."
                    )
            elif linear_features_bool and (
                kernel_size_bool or conv2d_channels_bool
            ):
                raise TypeError(
                    "`linear_features` cannot be specified at the same time "
                    "as `conv2d_channels` or `kernel_size`. A linear layer "
                    "and conv2d layer cannot both be specified at the same "
                    "time."
                )
        else:
            if (
                linear_features_bool
                or conv2d_channels_bool
                or kernel_size_bool
            ):
                raise TypeError(
                    "When `all_to_all`=False, none of `linear_features`, "
                    "`conv2d_channels`, or `kernel_size` should be specified. "
                    "The weight `V` is used instead."
                )

    @classmethod
    def detach_hidden(cls):
        """Detach the hidden states of every RSynaptic instance in place.

        Intended for use in truncated backpropagation through time where
        hidden state variables are instance variables. ``spk`` is detached
        too because it is fed back through the recurrent connection at the
        next step. Returns nothing.
        """
        for layer in cls.instances:
            if isinstance(layer, RSynaptic):
                layer.spk.detach_()
                layer.syn.detach_()
                layer.mem.detach_()

    @classmethod
    def reset_hidden(cls):
        """Used to clear hidden state variables (spk, syn, mem) to zero.

        Intended for use where hidden state variables are instance
        variables.
        """
        for layer in cls.instances:
            if isinstance(layer, RSynaptic):
                layer.spk = torch.zeros_like(
                    layer.spk, device=layer.spk.device
                )
                layer.syn = torch.zeros_like(
                    layer.syn, device=layer.syn.device
                )
                layer.mem = torch.zeros_like(
                    layer.mem, device=layer.mem.device
                )
class RecurrentOneToOne(nn.Module):
    """One-to-one (elementwise) recurrent connection that scales spikes by V.

    ``V`` may be a Python scalar, a tensor, or an ``nn.Parameter``. A plain
    (non-parameter) tensor is registered as a non-persistent buffer so that
    ``Module.to()``, ``.cuda()`` and ``.double()`` move/cast it together with
    the rest of the model, while keeping it out of ``state_dict`` (so saved
    checkpoints keep the same keys as before).
    """

    def __init__(self, V):
        super().__init__()
        if isinstance(V, torch.Tensor) and not isinstance(V, nn.Parameter):
            self.register_buffer("V", V, persistent=False)
        else:
            # nn.Parameter values are registered automatically by
            # nn.Module.__setattr__; Python scalars are stored as-is.
            self.V = V

    def forward(self, x):
        """Return ``x * V`` (element-wise or global multiplication)."""
        return x * self.V