snntorch.functional

snntorch.functional implements common operations applied to spiking neurons, including loss and regularization functions, accuracy metrics, state quantization, and probes for monitoring network activity.

How to use functional

To use snntorch.functional, assign the function to a variable and then call that variable as you would a standard PyTorch criterion.

Example:

import torch
import snntorch as snn
import snntorch.functional as SF

net = Net().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=betas)
loss_fn = SF.ce_count_loss()  # apply cross-entropy to spike count

spk_rec, mem_rec = net(input_data)
loss = loss_fn(spk_rec, targets)

optimizer.zero_grad()
loss.backward()

# Weight Update
optimizer.step()

Accuracy Functions

snntorch.functional.acc.accuracy_rate(spk_out, targets, population_code=False, num_classes=False)[source]

Use spike count to measure accuracy.

Parameters:
  • spk_out (torch.Tensor) – Output spikes of shape [num_steps x batch_size x num_outputs]

  • targets (torch.Tensor) – Target tensor (without one-hot-encoding) of shape [batch_size]

Returns:

accuracy

Return type:

numpy.float64
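
Example (a brief usage sketch; net, input_data, and targets are assumed to come from a typical training or test loop, as in the snippet at the top of this page):

import snntorch.functional as SF

spk_rec, mem_rec = net(input_data)

# rate-coded accuracy: the predicted class is the output neuron
# with the highest spike count
acc = SF.accuracy_rate(spk_rec, targets)

# if a population code is used, e.g., 100 output neurons shared
# across 10 classes
acc = SF.accuracy_rate(spk_rec, targets, population_code=True, num_classes=10)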

snntorch.functional.acc.accuracy_temporal(spk_out, targets)[source]

Use spike timing to measure accuracy.

Parameters:
  • spk_out (torch.Tensor) – Output spikes of shape [num_steps x batch_size x num_outputs]

  • targets (torch.Tensor) – Target tensor (without one-hot-encoding) of shape [batch_size]

Returns:

accuracy

Return type:

numpy.float64
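
Example (a brief sketch; the predicted class is taken to be the output neuron that fires first, with net, input_data, and targets assumed from the same loop as above):

import snntorch.functional as SF

spk_rec, mem_rec = net(input_data)
acc = SF.accuracy_temporal(spk_rec, targets)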

Loss Functions

class snntorch.functional.loss.LossFunctions(reduction, weight)[source]

Bases: object

class snntorch.functional.loss.SpikeTime(target_is_time=False, on_target=0, off_target=-1, tolerance=0, multi_spike=False)[source]

Bases: Module

Used by ce_temporal_loss and mse_temporal_loss to convert spike outputs into spike times.

class FirstSpike(*args, **kwargs)[source]

Bases: Function

Convert spk_rec of 1/0s [TxBxN] --> first spike time [BxN]. Linearize df/dS=-1 if spike, 0 if no spike.

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, spk_rec, device='cpu')[source]

Convert spk_rec of 1/0s [TxBxN] --> spk_time [TxBxN]. 0's indicate no spike --> +1 is first time step. Transpose accounts for broadcasting along final dimension (i.e., multiply along T).

class MultiSpike(*args, **kwargs)[source]

Bases: Function

Convert spk_rec of 1/0s [TxBxN] --> first F spike times [FxBxN]. Linearize df/dS=-1 if spike, 0 if no spike.

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, spk_rec, spk_count, device='cpu')[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class Tolerance(*args, **kwargs)[source]

Bases: Function

If spike time is ‘close enough’ to target spike within tolerance, set the time to target for loss calc only.

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, spk_time, target, tolerance)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

forward(spk_out, targets)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

label_to_multi_spike(targets, num_outputs)[source]

Convert labels from neuron index (dim: B) to multiple spike times (dim: F x B x N). F is the number of spikes per neuron. Assumes target is iterable along F.

label_to_single_spike(targets, num_outputs)[source]

Convert labels from neuron index (dim: B) to first spike time (dim: B x N).

labels_to_spike_times(targets, num_outputs)[source]

Convert index labels [B] into spike times.

class snntorch.functional.loss.ce_count_loss(population_code=False, num_classes=False, reduction='mean', weight=None)[source]

Bases: LossFunctions

Cross Entropy Spike Count Loss.

The spikes at each time step [num_steps x batch_size x num_outputs] are accumulated and then passed through the Cross Entropy Loss function. This criterion combines log_softmax and NLLLoss in a single function. The Cross Entropy Loss encourages the correct class to fire at all time steps, and aims to suppress incorrect classes from firing.

The Cross Entropy Count Loss accumulates spikes first, and applies Cross Entropy Loss only once. In contrast, the Cross Entropy Rate Loss applies the Cross Entropy function at every time step.

Example:

import snntorch.functional as SF

# if not using population codes
loss_fn = SF.ce_count_loss()
loss = loss_fn(spk_out, targets)

# if using population codes (i.e., more output neurons than there are classes);
# e.g., 200 output neurons, 10 output classes --> 20 output neurons per class
loss_fn = SF.ce_count_loss(population_code=True, num_classes=10)
loss = loss_fn(spk_out, targets)
Parameters:
  • population_code (bool, optional) – Specify if a population code is applied, i.e., the number of outputs is greater than the number of classes. Defaults to False

  • num_classes (int, optional) – Number of output classes must be specified if population_code=True. Must be a factor of the number of output neurons if population code is enabled. Defaults to False

Returns:

Loss

Return type:

torch.Tensor (single element)

class snntorch.functional.loss.ce_max_membrane_loss(reduction='mean', weight=None)[source]

Bases: LossFunctions

Cross Entropy Max Membrane Loss. When called, the maximum membrane potential value for each output neuron is sampled and passed through the Cross Entropy Loss Function. This criterion combines log_softmax and NLLLoss in a single function. The Cross Entropy Loss encourages the maximum membrane potential of the correct class to increase, while suppressing the maximum membrane potential of incorrect classes. This function is adapted from SpyTorch by Friedemann Zenke.

Example:

import snntorch.functional as SF

loss_fn = SF.ce_max_membrane_loss()
loss = loss_fn(outputs, targets)
Parameters:
  • mem_out (torch.Tensor) – Output tensor of the SNN’s membrane potential, of shape [num_steps x batch_size x num_outputs]

  • targets (torch.Tensor) – Target tensor of the current mini-batch (without one-hot-encoding), of shape [batch_size]

Returns:

Loss

Return type:

torch.Tensor (single element)

class snntorch.functional.loss.ce_rate_loss(reduction='mean', weight=None)[source]

Bases: LossFunctions

Cross Entropy Spike Rate Loss. When called, the spikes at each time step are sequentially passed through the Cross Entropy Loss function. This criterion combines log_softmax and NLLLoss in a single function. The losses are accumulated over time steps to give the final loss. The Cross Entropy Loss encourages the correct class to fire at all time steps, and aims to suppress incorrect classes from firing.

The Cross Entropy Rate Loss applies the Cross Entropy function at every time step. In contrast, the Cross Entropy Count Loss accumulates spikes first, and applies Cross Entropy Loss only once.

Example:

import snntorch.functional as SF

loss_fn = SF.ce_rate_loss()
loss = loss_fn(outputs, targets)
Returns:

Loss

Return type:

torch.Tensor (single element)

class snntorch.functional.loss.ce_temporal_loss(inverse='negate', reduction='mean', weight=None)[source]

Bases: object

Cross Entropy Temporal Loss.

The cross entropy loss of an ‘inverted’ first spike time of each output neuron [batch_size x num_outputs] is calculated. The ‘inversion’ is applied such that maximizing the value of the correct class decreases the first spike time (i.e., earlier spike).

Options for inversion include: inverse='negate' which applies (-1 * output), or inverse='reciprocal' which takes (1/output).

Note that the derivative of each spike time with respect to the spike df/dU is non-differentiable for most neuron classes, and is set to a sign estimator of -1. I.e., increasing membrane potential causes a proportionately earlier firing time.

Index labels are passed as the target. To specify the exact spike time, use mse_temporal_loss instead.

Note: no penalty is applied to spikes that occur after the target spike time.

Example:

import torch
import snntorch.functional as SF

# correct classes aim to fire at t=0 by default, incorrect classes at the final step
loss_fn = SF.ce_temporal_loss()
loss = loss_fn(spk_out, targets)
Parameters:

inverse (str, optional) – Specify how to invert output before taking cross entropy. Either scale by (-1 * x) with inverse='negate' or take the reciprocal (1/x) with inverse='reciprocal'. Defaults to negate

Returns:

Loss

Return type:

torch.Tensor (single element)

class snntorch.functional.loss.mse_count_loss(correct_rate=1, incorrect_rate=0, population_code=False, num_classes=False, reduction='mean', weight=None)[source]

Bases: LossFunctions

Mean Square Error Spike Count Loss. When called, the total spike count is accumulated over time for each neuron. The target spike count for correct classes is set to (num_steps * correct_rate), and for incorrect classes to (num_steps * incorrect_rate). The spike counts and target spike counts are then applied to a Mean Square Error Loss Function.

This function is adapted from SLAYER by Sumit Bam Shrestha and Garrick Orchard.

Example:

import snntorch.functional as SF

loss_fn = SF.mse_count_loss(correct_rate=0.75, incorrect_rate=0.25)
loss = loss_fn(outputs, targets)
Parameters:
  • correct_rate (float, optional) – Firing frequency of correct class as a ratio, e.g., 1 promotes firing at every step; 0.5 promotes firing at 50% of steps, 0 discourages any firing, defaults to 1

  • incorrect_rate (float, optional) – Firing frequency of incorrect class(es) as a ratio, e.g., 1 promotes firing at every step; 0.5 promotes firing at 50% of steps, 0 discourages any firing, defaults to 0

  • population_code (bool, optional) – Specify if a population code is applied, i.e., the number of outputs is greater than the number of classes. Defaults to False

  • num_classes (int, optional) – Number of output classes must be specified if population_code=True. Must be a factor of the number of output neurons if population code is enabled. Defaults to False

Returns:

Loss

Return type:

torch.Tensor (single element)

class snntorch.functional.loss.mse_membrane_loss(time_var_targets=False, on_target=1, off_target=0, reduction='mean', weight=None)[source]

Bases: LossFunctions

Mean Square Error Membrane Loss. When called, pass the output membrane potential of shape [num_steps x batch_size x num_outputs] and the target membrane potential tensor. The membrane potential and target are then applied to a Mean Square Error Loss Function. This function is adapted from Spike-Op by Jason K. Eshraghian.

Example:

import snntorch.functional as SF

# if targets are the same at each time-step
loss_fn = SF.mse_membrane_loss(time_var_targets=False)
loss = loss_fn(outputs, targets)

# if targets are time-varying
loss_fn = SF.mse_membrane_loss(time_var_targets=True)
loss = loss_fn(outputs, targets)
Parameters:
  • time_var_targets (bool, optional) – Specifies whether the targets are time-varying, defaults to False

  • on_target (float, optional) – Specify target membrane potential for correct class, defaults to 1

  • off_target (float, optional) – Specify target membrane potential for incorrect class, defaults to 0

Returns:

Loss

Return type:

torch.Tensor (single element)

class snntorch.functional.loss.mse_temporal_loss(target_is_time=False, on_target=0, off_target=-1, tolerance=0, multi_spike=False, reduction='mean', weight=None)[source]

Bases: object

Mean Square Error Temporal Loss.

The first spike time of each output neuron [batch_size x num_outputs] is measured against the desired spike time with the Mean Square Error Loss Function. Note that the derivative of each spike time with respect to the spike df/dU is non-differentiable for most neuron classes, and is set to a sign estimator of -1. I.e., increasing membrane potential causes a proportionately earlier firing time.

The Mean Square Error Temporal Loss can account for multiple spikes by setting multi_spike=True. If the actual spike time is close enough to the target spike time within a given tolerance, e.g., tolerance = 5 time steps, then it does not contribute to the loss.

Index labels are passed as the target by default. To enable passing in the spike time(s) for output neuron(s), set target_is_time=True.

Note: no penalty is applied to spikes that occur after the target spike time. To eliminate later spikes, an additional spike-time target should be applied.

Example:

import torch
import snntorch.functional as SF

# default takes in index labels as targets
# correct classes aim to fire at t=0 by default, incorrect classes at t=-1 (final time step)
loss_fn = SF.mse_temporal_loss()
loss = loss_fn(spk_out, targets)

# as above, but correct classes fire at t=5, incorrect classes at t=100, with a tolerance of 2 steps
loss_fn = SF.mse_temporal_loss(on_target=5, off_target=100, tolerance=2)
loss = loss_fn(spk_out, targets)

# as above, with multiple spike time targets per neuron
on_target = torch.tensor([5, 10])
off_target = torch.tensor([100, 105])
loss_fn = SF.mse_temporal_loss(on_target=on_target, off_target=off_target,
                               tolerance=2, multi_spike=True)
loss = loss_fn(spk_out, targets)

# specify the first spike time for 5 output neurons individually, zero tolerance
target = torch.tensor([5, 10, 15, 20, 25])
loss_fn = SF.mse_temporal_loss(target_is_time=True)
loss = loss_fn(spk_out, target)
Parameters:
  • target_is_time (bool, optional) – Specify if target is specified as spike times (True) or as neuron indexes (False). Defaults to False

  • on_target (int (or iterable over multiple int if multi_spike=True), optional) – Spike time for correct classes (only if target_is_time=False). Defaults to 0

  • off_target (int (or iterable over multiple int if multi_spike=True), optional) – Spike time for incorrect classes (only if target_is_time=False). Defaults to -1, i.e., final time step

  • tolerance (int, optional) – If the distance between the spike time and target is less than the specified tolerance, then it does not contribute to the loss. Defaults to 0.

  • multi_spike (bool, optional) – Specify if multiple spikes in target. Defaults to False

Returns:

Loss

Return type:

torch.Tensor (single element)

Regularization Functions

class snntorch.functional.reg.l1_rate_sparsity(Lambda=1e-05)[source]

Bases: object

L1 regularization using total spike count as the penalty term. Lambda is a scalar factor for regularization.
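
Example (a minimal sketch, assuming the regularizer is called directly on a recorded spike tensor and the resulting penalty is added to the task loss; net, input_data, and targets are placeholders from a typical training loop):

import snntorch.functional as SF
from snntorch.functional import reg

loss_fn = SF.ce_count_loss()
reg_fn = reg.l1_rate_sparsity(Lambda=1e-5)

spk_rec, mem_rec = net(input_data)

# total loss = task loss + L1 penalty on the total spike count
loss = loss_fn(spk_rec, targets) + reg_fn(spk_rec)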

State Quantization

class snntorch.functional.quant.StateQuant(*args, **kwargs)[source]

Bases: Function

Wrapper function for state_quant

static backward(ctx, grad_output)[source]

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, input_, levels)[source]

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

snntorch.functional.quant.state_quant(num_bits=8, uniform=True, thr_centered=True, threshold=1, lower_limit=0, upper_limit=0.2, multiplier=None)[source]

Quantization-Aware Training with spiking neuron states.

Note: for weight quantization, we recommend using Brevitas or another pre-existing PyTorch-friendly library.

Uniform and non-uniform quantization can be applied by setting uniform=True or uniform=False, respectively.

For non-uniform quantization, valid levels can be centered about 0 or about the threshold by setting thr_centered=False or thr_centered=True, respectively.

upper_limit and lower_limit specify how far valid levels extend above and below the positive and negative threshold, as a proportion of the threshold. E.g., upper_limit=0.2 means the maximum valid state is 20% higher than the value specified in threshold.

Example:

import torch
import snntorch as snn
from snntorch.functional import quant

beta = 0.5
thr = 5

# set the quantization parameters
q_lif = quant.state_quant(num_bits=4, uniform=True, threshold=thr)

# specifying state_quant applies state-quantization to the
# hidden state(s) automatically
lif = snn.Leaky(beta=beta, threshold=thr, state_quant=q_lif)

rand_input = torch.rand(1)
mem = lif.init_leaky()

# forward-pass for one step
spk, mem = lif(rand_input, mem)

Note: Quantization-Aware Training is focused on modelling a reduced-precision network, but does not in and of itself accelerate low-precision models. Hidden states are still represented as full-precision values for compatibility with PyTorch. For accelerated performance or constrained memory, the model should be exported to a downstream backend.

Parameters:
  • num_bits (int, optional) – Number of bits to quantize state variables to, defaults to 8

  • uniform (Bool, optional) – Applies uniform quantization if True, non-uniform quantization if False, defaults to True

  • thr_centered (Bool, optional) – For non-uniform quantization, specifies if valid states should be centered (densely clustered) around the threshold rather than at 0, defaults to True

  • threshold (float, optional) – Specifies the threshold, defaults to 1

  • lower_limit (float, optional) – Specifies how far below (-threshold) the lowest valid state can be, i.e., (-threshold - threshold*lower_limit), defaults to 0

  • upper_limit (float, optional) – Specifies how far above (threshold) the highest valid state can be, i.e., (threshold + threshold*upper_limit), defaults to 0.2

  • multiplier (float, optional) – For non-uniform distributions, specify the base of the exponential. If None, an appropriate value is set internally based on num_bits, defaults to None

Probe

class snntorch.functional.probe.AttributeMonitor(attribute_name: str, pre_forward: bool, net: torch.nn.modules.module.Module, instance: typing.Any = None, function_on_attribute: typing.Callable = <function AttributeMonitor.<lambda>>)[source]

Bases: BaseMonitor

A monitor to record the attribute (e.g., membrane potential) of a specific neuron layer (e.g., Leaky) in a network. The attribute name is specified as the first argument of this function. All attribute data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.

Example:

import snntorch as snn
from snntorch.functional import probe

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky()
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky()

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()

monitor = probe.AttributeMonitor('mem', False, net, instance=snn.Leaky())

with torch.no_grad():
    y = net(torch.rand([1, 8]))
    print(f'monitor.records={monitor.records}')
    print(f'monitor[0]={monitor[0]}')
    print(f'monitor.monitored_layers={monitor.monitored_layers}')
    print(f"monitor['lif1']={monitor['lif1']}")
Parameters:
  • attribute_name (str) – Name of the attribute of the probed neuron layer (e.g., mem, syn, etc.)

  • pre_forward (bool) – If True, record the attribute value before the forward pass; otherwise, record the value after the forward pass.

  • net (nn.Module) – Network model (either wrapped in Sequential container or as a class)

  • instance (Any or tuple) – Instance of modules to be monitored. If None, defaults to type(net)

  • function_on_attribute (Callable, optional) – Function that is applied to the monitored modules’ attribute

create_hook(name)[source]
class snntorch.functional.probe.BaseMonitor[source]

Bases: object

clear_recorded_data()[source]
disable()[source]
enable()[source]
is_enable()[source]
remove_hooks()[source]
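
The methods above are inherited by every monitor on this page. A brief usage sketch (monitor is assumed to be any monitor instance created as in the surrounding examples):

monitor.disable()              # temporarily stop recording
monitor.enable()               # resume recording
monitor.clear_recorded_data()  # empty monitor.records
monitor.remove_hooks()         # detach the registered hooks when finished
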
class snntorch.functional.probe.GradInputMonitor(net: torch.nn.modules.module.Module, instance: typing.Any = None, function_on_grad_input: typing.Callable = <function GradInputMonitor.<lambda>>)[source]

Bases: BaseMonitor

A monitor to record the input gradient of each neuron layer (e.g., Leaky) in a network. All input gradient data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.

Example:

import snntorch as snn
from snntorch.functional import probe

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky()
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky()

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()

monitor = probe.GradInputMonitor(net, instance=snn.Leaky())

with torch.no_grad():
    y = net(torch.rand([1, 8]))
    print(f'monitor.records={monitor.records}')
    print(f'monitor[0]={monitor[0]}')
    print(f'monitor.monitored_layers={monitor.monitored_layers}')
    print(f"monitor['lif1']={monitor['lif1']}")
Parameters:
  • net (nn.Module) – Network model (either wrapped in Sequential container or as a class)

  • instance (Any or tuple) – Instance of modules to be monitored. If None, defaults to type(net)

  • function_on_grad_input (Callable, optional) – Function that is applied to the monitored modules’ gradients

create_hook(name)[source]
class snntorch.functional.probe.GradOutputMonitor(net: torch.nn.modules.module.Module, instance: typing.Any = None, function_on_grad_output: typing.Callable = <function GradOutputMonitor.<lambda>>)[source]

Bases: BaseMonitor

A monitor to record the output gradient of each specified neuron layer (e.g., Leaky) in a network. All output gradient data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.

Example:

import snntorch as snn
from snntorch.functional import probe

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky()
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky()

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()

mtor = probe.GradOutputMonitor(net, instance=snn.Leaky())

with torch.no_grad():
    y = net(torch.rand([1, 8]))
    print(f'mtor.records={mtor.records}')
    print(f'mtor[0]={mtor[0]}')
    print(f'mtor.monitored_layers={mtor.monitored_layers}')
    print(f"mtor['lif1']={mtor['lif1']}")
Parameters:
  • net (nn.Module) – Network model (either wrapped in Sequential container or as a class)

  • instance (Any or tuple) – Instance of modules to be monitored. If None, defaults to type(net)

  • function_on_grad_output (Callable, optional) – Function that is applied to the monitored modules’ gradients

create_hook(name)[source]
class snntorch.functional.probe.InputMonitor(net: torch.nn.modules.module.Module, instance: typing.Any = None, function_on_input: typing.Callable = <function InputMonitor.<lambda>>)[source]

Bases: BaseMonitor

A monitor to record the input of each neuron layer (e.g., Leaky) in a network. All input data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.

Example:

import snntorch as snn
from snntorch.functional import probe

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky()
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky()

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()

monitor = probe.InputMonitor(net, instance=snn.Leaky())

with torch.no_grad():
    y = net(torch.rand([1, 8]))
    print(f'monitor.records={monitor.records}')
    print(f'monitor[0]={monitor[0]}')
    print(f'monitor.monitored_layers={monitor.monitored_layers}')
    print(f"monitor['lif1']={monitor['lif1']}")
Parameters:
  • net (nn.Module) – Network model (either wrapped in Sequential container or as a class)

  • instance (Any or tuple) – Instance of modules to be monitored. If None, defaults to type(net)

  • function_on_input (Callable, optional) – Function that is applied to the monitored modules’ input

create_hook(name)[source]
class snntorch.functional.probe.OutputMonitor(net: torch.nn.modules.module.Module, instance: typing.Any = None, function_on_output: typing.Callable = <function OutputMonitor.<lambda>>)[source]

Bases: BaseMonitor

A monitor to record the output spikes of each specified neuron layer (e.g., Leaky) in a network. All output data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.

Example:

import snntorch as snn
from snntorch.functional import probe

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky()
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky()

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()

monitor = probe.OutputMonitor(net, instance=snn.Leaky())

with torch.no_grad():
    y = net(torch.rand([1, 8]))
    print(f'monitor.records={monitor.records}')
    print(f'monitor[0]={monitor[0]}')
    print(f'monitor.monitored_layers={monitor.monitored_layers}')
    print(f"monitor['lif1']={monitor['lif1']}")
Parameters:
  • net (nn.Module) – Network model (either wrapped in Sequential container or as a class)

  • instance (Any or tuple) – Instance of modules to be monitored. If None, defaults to type(net)

  • function_on_output (Callable, optional) – Function that is applied to the monitored modules’ outputs

create_hook(name)[source]
snntorch.functional.probe.unpack_len1_tuple(x: tuple)[source]