# Source code for snntorch._neurons.alpha

import torch
import torch.nn as nn

from .neurons import LIF


class Alpha(LIF):
    """
    A variant of the leaky integrate and fire neuron where membrane
    potential follows an alpha function.
    The time course of the membrane potential response depends on a
    combination of exponentials.
    In general, this causes the change in membrane potential to experience
    a delay with respect to an input spike.
    For :math:`U[T] > U_{\\rm thr} ⇒ S[T+1] = 1`.

    .. warning:: For a positive input current to induce a positive membrane
        response, ensure :math:`α > β`.

    If `reset_mechanism = "zero"`, then :math:`I_{\\rm exc}, I_{\\rm inh}`
    will both be set to `0` whenever the neuron emits a spike:

    .. math::

            I_{\\rm exc}[t+1] = (αI_{\\rm exc}[t] + I_{\\rm in}[t+1]) - R(αI_{\\rm exc}[t] + I_{\\rm in}[t+1]) \\\\
            I_{\\rm inh}[t+1] = (βI_{\\rm inh}[t] - I_{\\rm in}[t+1]) - R(βI_{\\rm inh}[t] - I_{\\rm in}[t+1]) \\\\
            U[t+1] = τ_{\\rm α}(I_{\\rm exc}[t+1] + I_{\\rm inh}[t+1])

    * :math:`I_{\\rm exc}` - Excitatory current
    * :math:`I_{\\rm inh}` - Inhibitory current
    * :math:`I_{\\rm in}` - Input current
    * :math:`U` - Membrane potential
    * :math:`U_{\\rm thr}` - Membrane threshold
    * :math:`R` - Reset mechanism, :math:`R = 1` if spike occurs, otherwise
      :math:`R = 0`
    * :math:`α` - Excitatory current decay rate
    * :math:`β` - Inhibitory current decay rate
    * :math:`τ_{\\rm α} = \\frac{log(α)}{log(β) - log(α)} + 1`

    Example::

        import torch
        import torch.nn as nn
        import snntorch as snn

        alpha = 0.9
        beta = 0.8

        # Define Network
        class Net(nn.Module):
            def __init__(self):
                super().__init__()

                # initialize layers
                self.fc1 = nn.Linear(num_inputs, num_hidden)
                self.lif1 = snn.Alpha(alpha=alpha, beta=beta)
                self.fc2 = nn.Linear(num_hidden, num_outputs)
                self.lif2 = snn.Alpha(alpha=alpha, beta=beta)

            def forward(self, x, syn_exc1, syn_inh1, mem1, spk1, syn_exc2,
                        syn_inh2, mem2):
                cur1 = self.fc1(x)
                spk1, syn_exc1, syn_inh1, mem1 = self.lif1(
                    cur1, syn_exc1, syn_inh1, mem1
                )
                cur2 = self.fc2(spk1)
                spk2, syn_exc2, syn_inh2, mem2 = self.lif2(
                    cur2, syn_exc2, syn_inh2, mem2
                )
                return (syn_exc1, syn_inh1, mem1, spk1,
                        syn_exc2, syn_inh2, mem2, spk2)

        # Too many state variables which becomes cumbersome, so the
        # following is also an option:

        alpha = 0.9
        beta = 0.8

        net = nn.Sequential(nn.Linear(num_inputs, num_hidden),
                            snn.Alpha(alpha=alpha, beta=beta,
                            init_hidden=True),
                            nn.Linear(num_hidden, num_outputs),
                            snn.Alpha(alpha=alpha, beta=beta,
                            init_hidden=True, output=True))

    :param alpha: excitatory current decay rate; must be greater than
        ``beta``. It is clamped to [0, 1] each time it is used.
    :type alpha: float or torch.tensor

    :param beta: inhibitory current decay rate, forwarded to :class:`LIF`.
        It is clamped to [0, 1] each time it is used.
    :type beta: float or torch.tensor

    :param learn_alpha: if ``True``, ``alpha`` is registered as a learnable
        :class:`torch.nn.Parameter`; otherwise it is registered as a
        buffer. Defaults to False
    :type learn_alpha: bool, optional

    The remaining constructor arguments (``threshold``, ``spike_grad``,
    ``surrogate_disable``, ``init_hidden``, ``inhibition``, ``learn_beta``,
    ``learn_threshold``, ``reset_mechanism``, ``state_quant``, ``output``)
    are forwarded unchanged to :class:`LIF`.

    :raises ValueError: if any element of ``alpha`` is not greater than
        ``beta``, or if any element of ``beta`` equals 1.
    """

    def __init__(
        self,
        alpha,
        beta,
        threshold=1.0,
        spike_grad=None,
        surrogate_disable=False,
        init_hidden=False,
        inhibition=False,
        learn_alpha=False,
        learn_beta=False,
        learn_threshold=False,
        reset_mechanism="zero",
        state_quant=False,
        output=False,
    ):
        super().__init__(
            beta,
            threshold,
            spike_grad,
            surrogate_disable,
            init_hidden,
            inhibition,
            learn_beta,
            learn_threshold,
            reset_mechanism,
            state_quant,
            output,
        )

        self._alpha_register_buffer(alpha, learn_alpha)
        self._alpha_cases()

        self._init_mem()

        # Pick the per-step update rule from the reset mode resolved by
        # the parent class (``reset_mechanism_val``).
        if self.reset_mechanism_val == 0:  # reset by subtraction
            self.state_function = self._base_sub
        elif self.reset_mechanism_val == 1:  # reset to zero
            self.state_function = self._base_zero
        elif self.reset_mechanism_val == 2:  # no reset, pure integration
            self.state_function = self._base_int

    def _init_mem(self):
        """Register empty, non-persistent state buffers.

        The buffers start with zero elements; ``forward`` re-creates them
        with the input's shape on first use.
        """
        self.register_buffer("syn_exc", torch.zeros(0), persistent=False)
        self.register_buffer("syn_inh", torch.zeros(0), persistent=False)
        self.register_buffer("mem", torch.zeros(0), persistent=False)

    def reset_mem(self):
        """Zero the excitatory current, inhibitory current and membrane
        potential, keeping each state's current shape and device.

        :return: ``(syn_exc, syn_inh, mem)`` after resetting
        """
        # zeros_like already preserves the source tensor's device and dtype.
        self.syn_exc = torch.zeros_like(self.syn_exc)
        self.syn_inh = torch.zeros_like(self.syn_inh)
        self.mem = torch.zeros_like(self.mem)
        return self.syn_exc, self.syn_inh, self.mem

    def init_alpha(self):
        """Deprecated, use :class:`Alpha.reset_mem` instead"""
        return self.reset_mem()

    def forward(self, input_, syn_exc=None, syn_inh=None, mem=None):
        """Advance the neuron by one time step.

        :param input_: input current for this step
        :param syn_exc: optional excitatory current to use instead of the
            stored state
        :param syn_inh: optional inhibitory current to use instead of the
            stored state
        :param mem: optional membrane potential to use instead of the
            stored state
        :return: ``spk`` alone when ``init_hidden=True`` and
            ``output=False``; otherwise ``(spk, syn_exc, syn_inh, mem)``
        :raises TypeError: if any explicit state is passed while
            ``init_hidden=True``
        """
        if syn_exc is not None:
            self.syn_exc = syn_exc

        if syn_inh is not None:
            self.syn_inh = syn_inh

        if mem is not None:
            self.mem = mem

        if self.init_hidden and (
            mem is not None or syn_exc is not None or syn_inh is not None
        ):
            raise TypeError(
                "When `init_hidden=True`, Alpha expects 1 input argument."
            )

        # Lazily (re)size each state to match the input, keeping the
        # state's own device.
        if not self.syn_exc.shape == input_.shape:
            self.syn_exc = torch.zeros_like(input_, device=self.syn_exc.device)

        if not self.syn_inh.shape == input_.shape:
            self.syn_inh = torch.zeros_like(input_, device=self.syn_inh.device)

        if not self.mem.shape == input_.shape:
            self.mem = torch.zeros_like(input_, device=self.mem.device)

        # The reset signal is derived from the membrane potential *before*
        # this step's update.
        self.reset = self.mem_reset(self.mem)
        self.syn_exc, self.syn_inh, self.mem = self.state_function(input_)

        if self.state_quant:
            self.syn_exc = self.state_quant(self.syn_exc)
            self.syn_inh = self.state_quant(self.syn_inh)
            self.mem = self.state_quant(self.mem)

        if self.inhibition:
            spk = self.fire_inhibition(self.mem.size(0), self.mem)
        else:
            spk = self.fire(self.mem)

        if self.init_hidden and not self.output:
            return spk
        return spk, self.syn_exc, self.syn_inh, self.mem

    def _base_state_function(self, input_):
        """Compute the un-reset next-step currents and membrane potential.

        :return: ``(syn_exc, syn_inh, mem)`` for the next step
        """
        alpha = self.alpha.clamp(0, 1)
        beta = self.beta.clamp(0, 1)

        base_fn_syn_exc = alpha * self.syn_exc + input_
        base_fn_syn_inh = beta * self.syn_inh - input_

        # tau_alpha = log(alpha) / (log(beta) - log(alpha)) + 1
        log_alpha = torch.log(alpha)
        tau_alpha = log_alpha / (torch.log(beta) - log_alpha) + 1

        base_fn_mem = tau_alpha * (base_fn_syn_exc + base_fn_syn_inh)
        return base_fn_syn_exc, base_fn_syn_inh, base_fn_mem

    def _base_state_reset_sub_function(self, input_):
        """Amounts that the subtraction reset removes, scaled by
        ``self.reset``, from each state."""
        syn_exc_reset = self.threshold
        syn_inh_reset = self.beta.clamp(0, 1) * self.syn_inh - input_
        mem_reset = -self.syn_inh
        return syn_exc_reset, syn_inh_reset, mem_reset

    def _base_sub(self, input_):
        """State update with reset by subtraction."""
        syn_exc, syn_inh, mem = self._base_state_function(input_)
        syn_exc2, syn_inh2, mem2 = self._base_state_reset_sub_function(input_)

        syn_exc -= syn_exc2 * self.reset
        syn_inh -= syn_inh2 * self.reset
        mem -= mem2 * self.reset

        return syn_exc, syn_inh, mem

    def _base_zero(self, input_):
        """State update with reset to zero.

        Subtracting ``reset`` times the un-reset update from itself zeroes
        every state wherever ``reset`` is 1.
        """
        syn_exc, syn_inh, mem = self._base_state_function(input_)
        syn_exc2, syn_inh2, mem2 = self._base_state_function(input_)

        syn_exc -= syn_exc2 * self.reset
        syn_inh -= syn_inh2 * self.reset
        mem -= mem2 * self.reset

        return syn_exc, syn_inh, mem

    def _base_int(self, input_):
        """State update without any reset (pure integration)."""
        return self._base_state_function(input_)

    def _alpha_register_buffer(self, alpha, learn_alpha):
        """Store ``alpha`` as an ``nn.Parameter`` (learnable) or a buffer.

        The value is stored unclamped. It is clamped to [0, 1] at every use
        in ``_base_state_function``. Re-assigning a clamped plain tensor here
        would raise ``TypeError`` for a registered ``nn.Parameter``.
        """
        if not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha)
        if learn_alpha:
            self.alpha = nn.Parameter(alpha)
        else:
            self.register_buffer("alpha", alpha)

    def _alpha_cases(self):
        """Validate ``alpha``/``beta``; raises ``ValueError`` if invalid."""
        if (self.alpha <= self.beta).any():
            raise ValueError("alpha must be greater than beta.")

        if (self.beta == 1).any():
            raise ValueError(
                "beta cannot be '1' otherwise ZeroDivisionError occurs: "
                "tau_alpha = log(alpha)/log(beta) - log(alpha) + 1"
            )

    @classmethod
    def detach_hidden(cls):
        """Used to detach hidden states from the current graph.
        Intended for use in truncated backpropagation through
        time where hidden state variables are instance variables."""
        for instance in cls.instances:
            if isinstance(instance, Alpha):
                instance.syn_exc.detach_()
                instance.syn_inh.detach_()
                instance.mem.detach_()

    @classmethod
    def reset_hidden(cls):
        """Used to clear hidden state variables to zero.
        Intended for use where hidden state variables are instance
        variables."""
        for instance in cls.instances:
            if isinstance(instance, Alpha):
                # zeros_like keeps each state's shape, dtype and device.
                instance.syn_exc = torch.zeros_like(instance.syn_exc)
                instance.syn_inh = torch.zeros_like(instance.syn_inh)
                instance.mem = torch.zeros_like(instance.mem)