import torch
import torch.nn as nn
import torch.nn.functional as F
from .neurons import SpikingNeuron
class SConv2dLSTM(SpikingNeuron):
"""
A spiking 2d convolutional long short-term memory cell.
Hidden states are membrane potential and synaptic current
:math:`mem, syn`, which correspond to the hidden and cell states
:math:`h, c` in the original LSTM formulation.
The input is expected to be of size :math:`(N, C_{in}, H_{in}, W_{in})`
where :math:`N` is the batch size.
Unlike the LSTM module in PyTorch, only one time step is simulated each
time the cell is called.
.. math::
\\begin{array}{ll} \\\\
i_t = \\sigma(W_{ii} ⋆ x_t + b_{ii} + W_{hi} ⋆ mem_{t-1} + b_{hi})
\\\\
f_t = \\sigma(W_{if} ⋆ x_t + b_{if} + W_{hf} ⋆ mem_{t-1} + b_{hf})
\\\\
g_t = \\tanh(W_{ig} ⋆ x_t + b_{ig} + W_{hg} ⋆ mem_{t-1} + b_{hg})
\\\\
o_t = \\sigma(W_{io} ⋆ x_t + b_{io} + W_{ho} ⋆ mem_{t-1} + b_{ho})
\\\\
syn_t = f_t ∗ syn_{t-1} + i_t ∗ g_t \\\\
mem_t = o_t ∗ \\tanh(syn_t) \\\\
\\end{array}
where :math:`\\sigma` is the sigmoid function, ⋆ is the 2D
cross-correlation operator and ∗ is the Hadamard product.
The output state :math:`mem_t` is thresholded to determine whether
an output spike is generated.
To conform to standard LSTM state behavior, the default reset mechanism
is set to `reset_mechanism="none"`, i.e., no reset is applied. If this
is changed, the reset is only applied to :math:`mem_t`.
Options to apply max-pooling or average-pooling to the state
:math:`mem_t` are also enabled. Note that pooling is applied to the
state rather than to the spikes, as pooling binary activations of
1's and 0's can result in arbitrary tie-breaking.
Padding is automatically applied to ensure consistent sizes for
hidden states from one time step to the next.
At the moment, stride != 1 is not supported.
Example::
import torch
import torch.nn as nn
import snntorch as snn
# Define Network
class Net(nn.Module):
def __init__(self):
super().__init__()
in_channels = 1
out_channels = 8
kernel_size = 3
max_pool = 2
avg_pool = 2
flattened_input = 49 * 8  # 8 channels * 7 * 7 spatial, assuming 28x28 inputs
num_outputs = 10
beta = 0.5
spike_grad_lstm = snn.surrogate.straight_through_estimator()
spike_grad_fc = snn.surrogate.fast_sigmoid(slope=5)
# initialize layers
self.sclstm1 = snn.SConv2dLSTM(
in_channels,
out_channels,
kernel_size,
max_pool=max_pool,
spike_grad=spike_grad_lstm,
)
self.sclstm2 = snn.SConv2dLSTM(
out_channels,
out_channels,
kernel_size,
avg_pool=avg_pool,
spike_grad=spike_grad_lstm,
)
self.fc1 = nn.Linear(flattened_input, num_outputs)
self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad_fc)
def forward(self, x):
# Initialize hidden states and outputs at t=0
syn1, mem1 = self.sclstm1.init_sconv2dlstm()
syn2, mem2 = self.sclstm2.init_sconv2dlstm()
mem3 = self.lif1.init_leaky()
# Record the final layer
spk3_rec = []
mem3_rec = []
# Number of steps assuming x is [N, T, C, H, W] with
# N = Batches, T = Time steps, C = Channels,
# H = Height, W = Width
num_steps = x.size()[1]
for step in range(num_steps):
x_step = x[:, step, :, :, :]
spk1, syn1, mem1 = self.sclstm1(x_step, syn1, mem1)
spk2, syn2, mem2 = self.sclstm2(spk1, syn2, mem2)
cur = self.fc1(spk2.flatten(1))
spk3, mem3 = self.lif1(cur, mem3)
spk3_rec.append(spk3)
mem3_rec.append(mem3)
return torch.stack(spk3_rec), torch.stack(mem3_rec)
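# A minimal usage sketch with hypothetical shapes: feed random data of
# shape [N, T, C, H, W] through the network above; 28x28 inputs are
# assumed so that flattened_input matches 7*7*out_channels after the
# two 2x2 pooling stages
net = Net()
x = torch.rand(2, 5, 1, 28, 28)
spk_rec, mem_rec = net(x)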
:param in_channels: number of input channels
:type in_channels: int
:param out_channels: number of output channels
:type out_channels: int
:param kernel_size: Size of the convolving kernel
:type kernel_size: int, tuple, or list
:param bias: If `True`, adds a learnable bias to the output. Defaults to
`True`
:type bias: bool, optional
:param max_pool: Applies max-pooling to the hidden state :math:`mem`
prior to thresholding if specified. Defaults to 0
:type max_pool: int, tuple, or list, optional
:param avg_pool: Applies average-pooling to the hidden state :math:`mem`
prior to thresholding if specified. Defaults to 0
:type avg_pool: int, tuple, or list, optional
:param threshold: Threshold for :math:`mem` to reach in order to
generate a spike `S=1`. Defaults to 1
:type threshold: float, optional
:param spike_grad: Surrogate gradient for the term dS/dU. Defaults to
ATan surrogate gradient
:type spike_grad: surrogate gradient function from snntorch.surrogate,
optional
:param surrogate_disable: Disables surrogate gradients regardless of
`spike_grad` argument. Useful for ONNX compatibility. Defaults
to False
:type surrogate_disable: bool, optional
:param learn_threshold: Option to enable learnable threshold. Defaults
to False
:type learn_threshold: bool, optional
:param init_hidden: Instantiates state variables as instance variables.
Defaults to False
:type init_hidden: bool, optional
:param inhibition: If `True`, suppresses all spiking other than the
neuron with the highest state. Defaults to False
:type inhibition: bool, optional
:param reset_mechanism: Defines the reset mechanism applied to \
:math:`mem` each time the threshold is met. Reset-by-subtraction: \
"subtract", reset-to-zero: "zero", none: "none". Defaults to "none"
:type reset_mechanism: str, optional
:param state_quant: If specified, hidden states :math:`mem` and \
:math:`syn` are quantized to a valid state for the forward pass. \
Defaults to False
:type state_quant: quantization function from snntorch.quant, optional
:param output: If `True` as well as `init_hidden=True`, states are
returned when neuron is called. Defaults to False
:type output: bool, optional
Inputs: input_, syn_0, mem_0
- **input_** of shape `(batch, in_channels, H, W)`: tensor \
containing input features
- **syn_0** of shape `(batch, out_channels, H, W)`: tensor \
containing the initial synaptic current (or cell state) for each \
element in the batch.
- **mem_0** of shape `(batch, out_channels, H, W)`: tensor \
containing the initial membrane potential (or hidden state) for each \
element in the batch.
Outputs: spk, syn_1, mem_1
- **spk** of shape `(batch, out_channels, H/pool, W/pool)`: tensor \
containing the output spike; the spatial dimensions are scaled down \
by max_pool or avg_pool if either is greater than 0
- **syn_1** of shape `(batch, out_channels, H, W)`: tensor \
containing the next synaptic current (or cell state) for each element \
in the batch
- **mem_1** of shape `(batch, out_channels, H, W)`: tensor \
containing the next membrane potential (or hidden state) for each \
element in the batch
Learnable Parameters:
- **SConv2dLSTM.conv.weight** (torch.Tensor) - the learnable \
weights, of shape (4*out_channels, in_channels + out_channels, \
kernel_size, kernel_size).
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
bias=True,
max_pool=0,
avg_pool=0,
threshold=1.0,
spike_grad=None,
surrogate_disable=False,
init_hidden=False,
inhibition=False,
learn_threshold=False,
reset_mechanism="none",
state_quant=False,
output=False,
):
super().__init__(
threshold,
spike_grad,
surrogate_disable,
init_hidden,
inhibition,
learn_threshold,
reset_mechanism,
state_quant,
output,
)
self._init_mem()
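# select the state-update function according to the configured reset mechanism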
if self.reset_mechanism_val == 0: # reset by subtraction
self.state_function = self._base_sub
elif self.reset_mechanism_val == 1: # reset to zero
self.state_function = self._base_zero
elif self.reset_mechanism_val == 2: # no reset, pure integration
self.state_function = self._base_int
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.max_pool = max_pool
self.avg_pool = avg_pool
self.bias = bias
self._sconv2dlstm_cases()
# padding is essential to keep same shape for next step
if isinstance(self.kernel_size, int):
self.padding = kernel_size // 2, kernel_size // 2
else:
self.padding = kernel_size[0] // 2, kernel_size[1] // 2
# Note: this applies the same Conv2d to all 4 gates
# Regular LSTMs apply a separate dense layer to each of the 4 gates
# Consider: a separate nn.Conv2d instance per gate?
self.conv = nn.Conv2d(
in_channels=self.in_channels + self.out_channels,
out_channels=4 * self.out_channels,
kernel_size=self.kernel_size,
padding=self.padding,
bias=self.bias,
)
def _init_mem(self):
syn = torch.zeros(0)
mem = torch.zeros(0)
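# register states as non-persistent buffers: they follow .to()/.cuda()
# but are excluded from the state_dict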
self.register_buffer("syn", syn, False)
self.register_buffer("mem", mem, False)
def reset_mem(self):
self.syn = torch.zeros_like(self.syn, device=self.syn.device)
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
return self.syn, self.mem
def init_sconv2dlstm(self):
"""Deprecated, use :class:`SConv2dLSTM.reset_mem` instead"""
return self.reset_mem()
def forward(self, input_, syn=None, mem=None):
if syn is not None:
self.syn = syn
if mem is not None:
self.mem = mem
if self.init_hidden and (mem is not None or syn is not None):
raise TypeError(
"`mem` or `syn` should not be passed as an argument while `init_hidden=True`"
)
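# lazily (re)initialize hidden states as zeros if their shape does not
# match the current input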
size = input_.size()
correct_shape = (size[0], self.out_channels, size[2], size[3])
if not self.syn.shape == correct_shape:
self.syn = torch.zeros(correct_shape, device=self.syn.device)
if not self.mem.shape == correct_shape:
self.mem = torch.zeros(correct_shape, device=self.mem.device)
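# compute the reset signal from the membrane potential carried over
# from the previous step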
self.reset = self.mem_reset(self.mem)
self.syn, self.mem = self.state_function(input_)
if self.state_quant:
self.syn = self.state_quant(self.syn)
self.mem = self.state_quant(self.mem)
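# optionally pool the membrane potential before thresholding to
# generate the output spike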
if self.max_pool:
self.spk = self.fire(F.max_pool2d(self.mem, self.max_pool))
elif self.avg_pool:
self.spk = self.fire(F.avg_pool2d(self.mem, self.avg_pool))
else:
self.spk = self.fire(self.mem)
if self.output:
return self.spk, self.syn, self.mem
elif self.init_hidden:
return self.spk
else:
return self.spk, self.syn, self.mem
def _base_state_function(self, input_):
combined = torch.cat(
[input_, self.mem], dim=1
) # concatenate along channel axis
combined_conv = self.conv(combined)
cc_i, cc_f, cc_o, cc_g = torch.split(
combined_conv, self.out_channels, dim=1
)
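# i: input gate, f: forget gate, o: output gate, g: cell candidate,
# all produced by a single convolution and split along the channel axis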
i = torch.sigmoid(cc_i)
f = torch.sigmoid(cc_f)
o = torch.sigmoid(cc_o)
g = torch.tanh(cc_g)
base_fn_syn = f * self.syn + i * g
base_fn_mem = o * torch.tanh(base_fn_syn)
return base_fn_syn, base_fn_mem
def _base_state_reset_zero(self, input_):
combined = torch.cat(
[input_, self.mem], dim=1
) # concatenate along channel axis
combined_conv = self.conv(combined)
cc_i, cc_f, cc_o, cc_g = torch.split(
combined_conv, self.out_channels, dim=1
)
i = torch.sigmoid(cc_i)
f = torch.sigmoid(cc_f)
o = torch.sigmoid(cc_o)
g = torch.tanh(cc_g)
base_fn_syn = f * self.syn + i * g
base_fn_mem = o * torch.tanh(base_fn_syn)
return 0, base_fn_mem
def _base_sub(self, input_):
syn, mem = self._base_state_function(input_)
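# reset-by-subtraction: subtract the threshold wherever the previous
# membrane potential crossed it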
mem -= self.reset * self.threshold
return syn, mem
def _base_zero(self, input_):
syn, mem = self._base_state_function(input_)
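# reset-to-zero: subtract the recomputed state wherever the previous
# membrane crossed threshold, which zeroes mem (syn is left unchanged
# since the recomputed syn term is 0)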
syn2, mem2 = self._base_state_reset_zero(input_)
syn2 *= self.reset
mem2 *= self.reset
syn -= syn2
mem -= mem2
return syn, mem
def _base_int(self, input_):
return self._base_state_function(input_)
def _sconv2dlstm_cases(self):
if self.max_pool and self.avg_pool:
raise ValueError(
"Only one of either `max_pool` or `avg_pool` may be "
"specified, not both."
)
@classmethod
def detach_hidden(cls):
"""Returns the hidden states, detached from the current graph.
Intended for use in truncated backpropagation through time where
hidden state variables are instance variables."""
for layer in range(len(cls.instances)):
if isinstance(cls.instances[layer], SConv2dLSTM):
cls.instances[layer].syn.detach_()
cls.instances[layer].mem.detach_()
@classmethod
def reset_hidden(cls):
"""Used to clear hidden state variables to zero.
Intended for use where hidden state variables are
instance variables."""
for layer in range(len(cls.instances)):
if isinstance(cls.instances[layer], SConv2dLSTM):
cls.instances[layer].syn = torch.zeros_like(
cls.instances[layer].syn,
device=cls.instances[layer].syn.device,
)
cls.instances[layer].mem = torch.zeros_like(
cls.instances[layer].mem,
device=cls.instances[layer].mem.device,
)