snntorch.functional
snntorch.functional implements common arithmetic operations applied to spiking neurons, such as loss and regularization functions, state quantization, etc.
How to use functional
To use snntorch.functional, instantiate the function (e.g., a loss function) and assign it to a variable, then call that variable.
Example:
import torch
import snntorch as snn
import snntorch.functional as SF

# Net, device, lr, betas, input_data and targets are defined elsewhere
net = Net().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=betas)
loss_fn = SF.ce_count_loss()  # apply cross-entropy to spike count

spk_rec, mem_rec = net(input_data)
loss = loss_fn(spk_rec, targets)

# Gradient calculation
optimizer.zero_grad()
loss.backward()

# Weight update
optimizer.step()
Accuracy Functions
- snntorch.functional.acc.accuracy_rate(spk_out, targets, population_code=False, num_classes=False)
Use spike count to measure accuracy.
- Parameters:
spk_out (torch.Tensor) – Output spikes of shape [num_steps x batch_size x num_outputs]
targets (torch.Tensor) – Target tensor (without one-hot-encoding) of shape [batch_size]
- Returns:
accuracy
- Return type:
numpy.float64
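The idea behind spike-count accuracy can be sketched in plain Python, with nested lists standing in for the [num_steps x batch_size x num_outputs] tensor. This is a minimal illustration under stated assumptions, not snntorch's implementation: accuracy_rate_sketch is a hypothetical name, the population-code options are ignored, and ties go to the lower neuron index, which may not match the library.

```python
def accuracy_rate_sketch(spk_out, targets):
    """Spike-count accuracy; spk_out is indexed as spk_out[step][sample][neuron]."""
    num_steps = len(spk_out)
    batch_size = len(spk_out[0])
    num_outputs = len(spk_out[0][0])
    correct = 0
    for b in range(batch_size):
        # Accumulate each output neuron's spikes over all time steps
        counts = [sum(spk_out[t][b][n] for t in range(num_steps)) for n in range(num_outputs)]
        # Predicted class = neuron with the highest spike count (first index wins ties)
        pred = max(range(num_outputs), key=lambda n: counts[n])
        correct += int(pred == targets[b])
    return correct / batch_size


# 3 time steps, batch of 2, 3 output neurons
spk_out = [
    [[1, 0, 0], [0, 1, 0]],  # t = 0
    [[1, 0, 1], [0, 1, 0]],  # t = 1
    [[1, 0, 1], [1, 0, 0]],  # t = 2
]
print(accuracy_rate_sketch(spk_out, [0, 2]))  # prints 0.5: sample 0 correct, sample 1 wrong
```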
- snntorch.functional.acc.accuracy_temporal(spk_out, targets)
Use spike timing to measure accuracy.
- Parameters:
spk_out (torch.Tensor) – Output spikes of shape [num_steps x batch_size x num_outputs]
targets (torch.Tensor) – Target tensor (without one-hot-encoding) of shape [batch_size]
- Returns:
accuracy
- Return type:
numpy.float64
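A plain-Python sketch of timing-based accuracy follows, again with nested lists in place of tensors. It assumes the predicted class is the neuron that fires first, and that a neuron which never fires is treated as firing at the final time step; the helper names and the handling of silent neurons are assumptions, not documented behavior.

```python
def first_spike_times(spk_out, b):
    """First spike step of every output neuron for sample b.

    Assumption: a neuron that never spikes is treated as firing at the final step.
    """
    num_steps = len(spk_out)
    num_outputs = len(spk_out[0][0])
    return [
        next((t for t in range(num_steps) if spk_out[t][b][n]), num_steps - 1)
        for n in range(num_outputs)
    ]


def accuracy_temporal_sketch(spk_out, targets):
    """Spike-timing accuracy: the earliest-firing neuron is the predicted class."""
    batch_size = len(spk_out[0])
    correct = 0
    for b in range(batch_size):
        times = first_spike_times(spk_out, b)
        pred = min(range(len(times)), key=lambda n: times[n])
        correct += int(pred == targets[b])
    return correct / batch_size


# 4 time steps, batch of 2, 3 output neurons
spk_out = [
    [[0, 0, 0], [0, 0, 0]],  # t = 0
    [[0, 1, 0], [0, 0, 0]],  # t = 1
    [[1, 1, 0], [0, 0, 1]],  # t = 2
    [[0, 0, 0], [1, 0, 1]],  # t = 3
]
print(first_spike_times(spk_out, 0))              # [2, 1, 3]
print(accuracy_temporal_sketch(spk_out, [1, 0]))  # 0.5
```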
Loss Functions
- class snntorch.functional.loss.SpikeTime(target_is_time=False, on_target=0, off_target=-1, tolerance=0, multi_spike=False)
Bases: Module
Used by ce_temporal_loss and mse_temporal_loss to convert spike outputs into spike times.
- class FirstSpike(*args, **kwargs)
Bases: Function
Convert spk_rec of 1s and 0s [T x B x N] to the first spike time [B x N]. The gradient is linearized: df/dS = -1 if a spike occurred, 0 if no spike occurred.
- static backward(ctx, grad_output)
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.) It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- class MultiSpike(*args, **kwargs)
Bases: Function
Convert spk_rec of 1s and 0s [T x B x N] to the first F spike times [F x B x N]. The gradient is linearized: df/dS = -1 if a spike occurred, 0 if no spike occurred.
- static backward(ctx, grad_output)
- static forward(ctx, spk_rec, spk_count, device='cpu')
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass
@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
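The [T x B x N] to [F x B x N] conversion that MultiSpike describes can be pictured with the plain-Python sketch below. It covers only the forward conversion, not the linearized gradient, and first_f_spike_times is a hypothetical name. How snntorch handles a neuron that spikes fewer than F times is not stated here, so the sketch assumes those slots are filled with the final time step.

```python
def first_f_spike_times(spk_rec, num_spikes):
    """Return out[f][b][n] = time step of the (f+1)-th spike of neuron n in sample b."""
    num_steps = len(spk_rec)
    batch_size = len(spk_rec[0])
    num_outputs = len(spk_rec[0][0])
    # Assumption: missing spikes default to the final time step
    out = [[[num_steps - 1] * num_outputs for _ in range(batch_size)] for _ in range(num_spikes)]
    for b in range(batch_size):
        for n in range(num_outputs):
            times = [t for t in range(num_steps) if spk_rec[t][b][n]]
            for f, t in enumerate(times[:num_spikes]):
                out[f][b][n] = t
    return out


# 5 time steps, batch of 1, 2 neurons: neuron 0 fires at t = 0, 2, 4; neuron 1 only at t = 3
spk_rec = [[[1, 0]], [[0, 0]], [[1, 0]], [[0, 1]], [[1, 0]]]
print(first_f_spike_times(spk_rec, num_spikes=2))  # [[[0, 3]], [[2, 4]]]
```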
- class Tolerance(*args, **kwargs)
Bases: Function
If a spike time is within the tolerance of the target spike time, set it to the target time for the loss calculation only.
- static backward(ctx, grad_output)
- static forward(ctx, spk_time, target, tolerance)
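The snapping rule that Tolerance describes can be sketched as follows. The sketch covers the forward rule only (not the gradient), uses the hypothetical name apply_tolerance, and reads "within tolerance" as a distance strictly less than tolerance, matching the tolerance parameter description of mse_temporal_loss.

```python
def apply_tolerance(spk_time, target, tolerance):
    """Snap spike times lying strictly within `tolerance` of their target onto the target."""
    return [tgt if abs(t - tgt) < tolerance else t for t, tgt in zip(spk_time, target)]


spk_time = [3, 7, 12]
target = [5, 5, 5]
print(apply_tolerance(spk_time, target, tolerance=3))  # [5, 5, 12]
print(apply_tolerance(spk_time, target, tolerance=0))  # [3, 7, 12]: nothing snaps
```

Because snapped times equal their targets, they contribute zero error to any distance-based loss computed afterwards.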
- forward(spk_out, targets)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- label_to_multi_spike(targets, num_outputs)
Convert labels from neuron index (dim: B) to multiple spike times (dim: F x B x N). F is the number of spikes per neuron. Assumes target is iterable along F.
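The label expansion can be illustrated with a plain-Python sketch. In the method itself the per-spike on/off times presumably come from the SpikeTime object's on_target and off_target settings; here they are passed explicitly as hypothetical arguments, and label_to_multi_spike_sketch is a hypothetical name.

```python
def label_to_multi_spike_sketch(targets, num_outputs, on_target, off_target):
    """Expand index labels [B] into spike-time targets out[f][b][n] of shape [F x B x N].

    on_target[f] / off_target[f] are the f-th target spike times for the correct
    and incorrect neurons (explicit arguments in this sketch only).
    """
    num_spikes = len(on_target)
    return [
        [[on_target[f] if n == label else off_target[f] for n in range(num_outputs)] for label in targets]
        for f in range(num_spikes)
    ]


out = label_to_multi_spike_sketch(targets=[1, 0], num_outputs=3, on_target=[0, 4], off_target=[8, 9])
print(out)  # [[[8, 0, 8], [0, 8, 8]], [[9, 4, 9], [4, 9, 9]]]
```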
- class snntorch.functional.loss.ce_count_loss(population_code=False, num_classes=False, reduction='mean', weight=None)
Bases: LossFunctions
Cross Entropy Spike Count Loss.
The spikes at each time step [num_steps x batch_size x num_outputs] are accumulated and then passed through the Cross Entropy Loss function. This criterion combines log_softmax and NLLLoss in a single function. The Cross Entropy Loss encourages the correct class to fire at all time steps, and aims to suppress incorrect classes from firing.
The Cross Entropy Count Loss accumulates spikes first, and applies Cross Entropy Loss only once. In contrast, the Cross Entropy Rate Loss applies the Cross Entropy function at every time step.
Example:
import snntorch.functional as SF

# if not using population codes (i.e., one output neuron per class)
loss_fn = SF.ce_count_loss()
loss = loss_fn(spk_out, targets)

# if using population codes; e.g., 200 output neurons, 10 output classes --> 20 output neurons per class
loss_fn = SF.ce_count_loss(population_code=True, num_classes=10)
loss = loss_fn(spk_out, targets)
- Parameters:
population_code (bool, optional) – Specify if a population code is applied, i.e., the number of outputs is greater than the number of classes. Defaults to False
num_classes (int, optional) – Number of output classes; must be specified if population_code=True. Must be a factor of the number of output neurons if population code is enabled. Defaults to False
- Returns:
Loss
- Return type:
torch.Tensor (single element)
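The count-then-cross-entropy computation can be sketched in plain Python. This is an illustration, not the library code: the helper names are hypothetical, weight is ignored, the batch is averaged as with reduction='mean', and the population-code branch assumes each class owns a contiguous block of num_outputs / num_classes neurons.

```python
import math


def cross_entropy(logits, target):
    """log_softmax followed by negative log-likelihood for one sample."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def ce_count_loss_sketch(spk_out, targets, num_classes=None):
    """Cross entropy on spike counts accumulated over time (mean over the batch)."""
    num_steps = len(spk_out)
    batch_size = len(spk_out[0])
    num_outputs = len(spk_out[0][0])
    total = 0.0
    for b in range(batch_size):
        counts = [sum(spk_out[t][b][n] for t in range(num_steps)) for n in range(num_outputs)]
        if num_classes is not None:
            # Population code (assumed layout): contiguous groups of neurons per class
            group = num_outputs // num_classes
            counts = [sum(counts[c * group:(c + 1) * group]) for c in range(num_classes)]
        total += cross_entropy(counts, targets[b])
    return total / batch_size


# 2 time steps, 1 sample, 2 neurons: the target neuron 0 fires twice
print(ce_count_loss_sketch([[[1, 0]], [[1, 0]]], targets=[0]))  # log(1 + e^-2) ≈ 0.1269
# Population code: 4 neurons, 2 classes, so neurons 0-1 vote for class 0
print(ce_count_loss_sketch([[[1, 1, 0, 0]]], targets=[0], num_classes=2))  # same value
```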
- class snntorch.functional.loss.ce_max_membrane_loss(reduction='mean', weight=None)
Bases: LossFunctions
Cross Entropy Max Membrane Loss. When called, the maximum membrane potential value for each output neuron is sampled and passed through the Cross Entropy Loss Function. This criterion combines log_softmax and NLLLoss in a single function. The Cross Entropy Loss encourages the maximum membrane potential of the correct class to increase, while suppressing the maximum membrane potential of incorrect classes. This function is adapted from SpyTorch by Friedemann Zenke.
Example:
import snntorch.functional as SF

loss_fn = SF.ce_max_membrane_loss()
loss = loss_fn(outputs, targets)
- Parameters:
mem_out (torch.Tensor) – The output tensor of the SNN’s membrane potential, of the dimension timestep * batch_size * num_output_neurons
targets (torch.Tensor) – The tensor containing the targets of the current mini-batch, of the dimension batch_size
- Returns:
Loss
- Return type:
torch.Tensor (single element)
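The max-membrane variant differs from the count loss only in what is fed to the cross entropy: each neuron's peak membrane potential over time. A plain-Python sketch with hypothetical helper names, averaging over the batch and ignoring weight:

```python
import math


def cross_entropy(logits, target):
    """log_softmax followed by negative log-likelihood for one sample."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[target]


def ce_max_membrane_loss_sketch(mem_out, targets):
    """Cross entropy on each neuron's maximum membrane potential over time."""
    num_steps = len(mem_out)
    batch_size = len(mem_out[0])
    num_outputs = len(mem_out[0][0])
    total = 0.0
    for b in range(batch_size):
        # Sample the peak membrane potential of every output neuron
        peaks = [max(mem_out[t][b][n] for t in range(num_steps)) for n in range(num_outputs)]
        total += cross_entropy(peaks, targets[b])
    return total / batch_size


# 3 time steps, 1 sample, 2 neurons; peaks are 0.7 (neuron 0) and 0.9 (neuron 1)
mem_out = [[[0.1, 0.5]], [[0.7, 0.2]], [[0.3, 0.9]]]
print(ce_max_membrane_loss_sketch(mem_out, targets=[1]))  # log(1 + e^-0.2) ≈ 0.5981
```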
- class snntorch.functional.loss.ce_rate_loss(reduction='mean', weight=None)
Bases: LossFunctions
Cross Entropy Spike Rate Loss. When called, the spikes at each time step are sequentially passed through the Cross Entropy Loss function. This criterion combines log_softmax and NLLLoss in a single function. The losses are accumulated over time steps to give the final loss. The Cross Entropy Loss encourages the correct class to fire at all time steps, and aims to suppress incorrect classes from firing.
The Cross Entropy Rate Loss applies the Cross Entropy function at every time step. In contrast, the Cross Entropy Count Loss accumulates spikes first, and applies Cross Entropy Loss only once.
Example:
import snntorch.functional as SF

loss_fn = SF.ce_rate_loss()
loss = loss_fn(outputs, targets)
- Returns:
Loss
- Return type:
torch.Tensor (single element)
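To contrast with the count loss, the plain-Python sketch below applies cross entropy to the raw spikes at every time step and accumulates the per-step losses by summation, following the description above. Helper names are hypothetical, and whether the library additionally normalizes by the number of steps is not stated here.

```python
import math


def cross_entropy(logits, target):
    """log_softmax followed by negative log-likelihood for one sample."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[target]


def ce_rate_loss_sketch(spk_out, targets):
    """Apply cross entropy at every time step and sum the per-step losses."""
    batch_size = len(spk_out[0])
    loss = 0.0
    for step_spikes in spk_out:
        # Mean cross entropy over the batch at this time step
        loss += sum(cross_entropy(step_spikes[b], targets[b]) for b in range(batch_size)) / batch_size
    return loss


# 2 time steps, 1 sample, 2 neurons: the target neuron fires only at t = 0
spk_out = [[[1, 0]], [[0, 0]]]
print(ce_rate_loss_sketch(spk_out, targets=[0]))  # log(1 + e^-1) + log(2) ≈ 1.0064
```

Note that the silent step still costs log(2): the rate loss penalizes every step on which the correct class fails to dominate, whereas the count loss only sees the total.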
- class snntorch.functional.loss.ce_temporal_loss(inverse='negate', reduction='mean', weight=None)
Bases: object
Cross Entropy Temporal Loss.
The cross entropy loss of an ‘inverted’ first spike time of each output neuron [batch_size x num_outputs] is calculated. The ‘inversion’ is applied such that maximizing the value of the correct class decreases the first spike time (i.e., earlier spike).
Options for inversion include inverse='negate', which applies (-1 * output), or inverse='reciprocal', which takes (1/output).
Note that each spike time is non-differentiable with respect to the spike for most neuron classes, so its derivative (df/dU) is set to a sign estimator of -1; i.e., increasing the membrane potential causes a proportionately earlier firing time.
Index labels are passed as the target. To specify the exact spike time, use mse_temporal_loss instead.
Note: After spike times with specified targets, no penalty is applied for subsequent spiking.
Example:
import torch
import snntorch.functional as SF

# correct classes aimed to fire by default at t=0, incorrect at final step
loss_fn = SF.ce_temporal_loss()
loss = loss_fn(spk_out, targets)
- Parameters:
inverse (str, optional) – Specify how to invert output before taking cross entropy. Either scale by (-1 * x) with inverse='negate' or take the reciprocal (1/x) with inverse='reciprocal'. Defaults to negate
- Returns:
Loss
- Return type:
torch.Tensor (single element)
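The effect of inverse='negate' can be sketched in plain Python: first spike times are negated so that the earliest-firing neuron receives the largest logit before cross entropy is applied. Only the negate option is shown; helper names are hypothetical, and treating a silent neuron as firing at the final step is an assumption.

```python
import math


def cross_entropy(logits, target):
    """log_softmax followed by negative log-likelihood for one sample."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[target]


def ce_temporal_loss_sketch(spk_out, targets):
    """Cross entropy on negated first spike times (inverse='negate')."""
    num_steps = len(spk_out)
    batch_size = len(spk_out[0])
    num_outputs = len(spk_out[0][0])
    total = 0.0
    for b in range(batch_size):
        # First spike time per neuron; silent neurons assumed to fire at the final step
        times = [next((t for t in range(num_steps) if spk_out[t][b][n]), num_steps - 1)
                 for n in range(num_outputs)]
        # Negation: an earlier spike gives a larger logit
        total += cross_entropy([-t for t in times], targets[b])
    return total / batch_size


# 3 time steps, 1 sample, 2 neurons: neuron 0 fires at t = 0, neuron 1 at t = 2
spk_out = [[[1, 0]], [[0, 0]], [[0, 1]]]
print(ce_temporal_loss_sketch(spk_out, targets=[0]))  # log(1 + e^-2) ≈ 0.1269
```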
- class snntorch.functional.loss.mse_count_loss(correct_rate=1, incorrect_rate=0, population_code=False, num_classes=False, reduction='mean', weight=None)
Bases: LossFunctions
Mean Square Error Spike Count Loss. When called, the total spike count is accumulated over time for each neuron. The target spike count for correct classes is set to (num_steps * correct_rate), and for incorrect classes to (num_steps * incorrect_rate). The spike counts and target spike counts are then applied to a Mean Square Error Loss Function.
This function is adapted from SLAYER by Sumit Bam Shrestha and Garrick Orchard.
Example:
import snntorch.functional as SF

loss_fn = SF.mse_count_loss(correct_rate=0.75, incorrect_rate=0.25)
loss = loss_fn(outputs, targets)
- Parameters:
correct_rate (float, optional) – Firing frequency of correct class as a ratio, e.g., 1 promotes firing at every step; 0.5 promotes firing at 50% of steps; 0 discourages any firing, defaults to 1
incorrect_rate (float, optional) – Firing frequency of incorrect class(es) as a ratio, e.g., 1 promotes firing at every step; 0.5 promotes firing at 50% of steps; 0 discourages any firing, defaults to 0
population_code (bool, optional) – Specify if a population code is applied, i.e., the number of outputs is greater than the number of classes. Defaults to False
num_classes (int, optional) – Number of output classes; must be specified if population_code=True. Must be a factor of the number of output neurons if population code is enabled. Defaults to False
- Returns:
Loss
- Return type:
torch.Tensor (single element)
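The target-count arithmetic can be checked with a plain-Python sketch: a correct neuron should emit num_steps * correct_rate spikes and every other neuron num_steps * incorrect_rate. The name is hypothetical, population codes and weight are ignored, and the squared error is averaged over all batch_size * num_outputs entries, which is an assumption about the normalization.

```python
def mse_count_loss_sketch(spk_out, targets, correct_rate=1.0, incorrect_rate=0.0):
    """MSE between spike counts and target counts num_steps * rate."""
    num_steps = len(spk_out)
    batch_size = len(spk_out[0])
    num_outputs = len(spk_out[0][0])
    squared_error = 0.0
    for b in range(batch_size):
        for n in range(num_outputs):
            count = sum(spk_out[t][b][n] for t in range(num_steps))
            rate = correct_rate if n == targets[b] else incorrect_rate
            squared_error += (count - num_steps * rate) ** 2
    # Mean over all batch_size * num_outputs elements
    return squared_error / (batch_size * num_outputs)


# 4 time steps, 1 sample, 2 neurons: neuron 0 spikes 3 times, neuron 1 once
spk_out = [[[1, 0]], [[1, 1]], [[1, 0]], [[0, 0]]]
print(mse_count_loss_sketch(spk_out, targets=[0]))  # target counts 4 and 0 -> (1 + 1) / 2 = 1.0
print(mse_count_loss_sketch(spk_out, targets=[0], correct_rate=0.75, incorrect_rate=0.25))  # 0.0
```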
- class snntorch.functional.loss.mse_membrane_loss(time_var_targets=False, on_target=1, off_target=0, reduction='mean', weight=None)
Bases: LossFunctions
Mean Square Error Membrane Loss. When called, pass the output membrane of shape [num_steps x batch_size x num_outputs] and the target tensor of membrane potential. The membrane potential and target are then applied to a Mean Square Error Loss Function. This function is adapted from Spike-Op by Jason K. Eshraghian.
Example:
import snntorch.functional as SF

# if targets are the same at each time-step
loss_fn = SF.mse_membrane_loss(time_var_targets=False)
loss = loss_fn(outputs, targets)

# if targets are time-varying
loss_fn = SF.mse_membrane_loss(time_var_targets=True)
loss = loss_fn(outputs, targets)
- Parameters:
time_var_targets (bool, optional) – Specifies whether the targets are time-varying, defaults to False
on_target (float, optional) – Specify target membrane potential for correct class, defaults to 1
off_target (float, optional) – Specify target membrane potential for incorrect class, defaults to 0
- Returns:
Loss
- Return type:
torch.Tensor (single element)
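For the time_var_targets=False case, a plain-Python sketch follows. It assumes, as the on_target and off_target parameters suggest, that targets are class indices expanded into on_target for the correct neuron and off_target for the rest, with the same target at every step. The name is hypothetical, and averaging over all steps, samples and neurons is an assumption; the library's normalization may differ.

```python
def mse_membrane_loss_sketch(mem_out, targets, on_target=1.0, off_target=0.0):
    """MSE between membrane potentials and a constant on/off target (time_var_targets=False)."""
    num_steps = len(mem_out)
    batch_size = len(mem_out[0])
    num_outputs = len(mem_out[0][0])
    squared_error = 0.0
    for t in range(num_steps):
        for b in range(batch_size):
            for n in range(num_outputs):
                # The same target applies at every time step
                goal = on_target if n == targets[b] else off_target
                squared_error += (mem_out[t][b][n] - goal) ** 2
    return squared_error / (num_steps * batch_size * num_outputs)


# 2 time steps, 1 sample, 2 neurons; neuron 0 is the correct class
mem_out = [[[1.0, 0.5]], [[0.5, 0.0]]]
print(mse_membrane_loss_sketch(mem_out, targets=[0]))  # (0 + 0.25 + 0.25 + 0) / 4 = 0.125
```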
- class snntorch.functional.loss.mse_temporal_loss(target_is_time=False, on_target=0, off_target=-1, tolerance=0, multi_spike=False, reduction='mean', weight=None)
Bases: object
Mean Square Error Temporal Loss.
The first spike time of each output neuron [batch_size x num_outputs] is measured against the desired spike time with the Mean Square Error Loss Function. Note that each spike time is non-differentiable with respect to the spike for most neuron classes, so its derivative (df/dU) is set to a sign estimator of -1; i.e., increasing the membrane potential causes a proportionately earlier firing time.
The Mean Square Error Temporal Loss can account for multiple spikes by setting multi_spike=True. If the actual spike time is close enough to the target spike time within a given tolerance, e.g., tolerance = 5 time steps, then it does not contribute to the loss.
Index labels are passed as the target by default. To enable passing in the spike time(s) for output neuron(s), set target_is_time=True.
Note: After spike times with specified targets, no penalty is applied for subsequent spiking. To eliminate later spikes, an additional target should be applied.
Example:
import torch
import snntorch.functional as SF

# default takes in idx labels as targets
# correct classes aimed to fire by default at t=0, incorrect at t=-1 (final time step)
loss_fn = SF.mse_temporal_loss()
loss = loss_fn(spk_out, targets)

# as above, but correct class fires at t=5, incorrect at t=100 with a tolerance of 2 steps
loss_fn = SF.mse_temporal_loss(on_target=5, off_target=100, tolerance=2)
loss = loss_fn(spk_out, targets)

# as above with multiple spike time targets
on_target = torch.tensor([5, 10])
off_target = torch.tensor([100, 105])
loss_fn = SF.mse_temporal_loss(on_target=on_target, off_target=off_target, tolerance=2, multi_spike=True)
loss = loss_fn(spk_out, targets)

# specify first spike time for 5 neurons individually, zero tolerance
target = torch.tensor([5, 10, 15, 20, 25])
loss_fn = SF.mse_temporal_loss(target_is_time=True)
loss = loss_fn(spk_out, target)
- Parameters:
target_is_time (bool, optional) – Specify if target is specified as spike times (True) or as neuron indexes (False). Defaults to False
on_target (int (or iterable over multiple int if multi_spike=True), optional) – Spike time for correct classes (only if target_is_time=False). Defaults to 0
off_target (int (or iterable over multiple int if multi_spike=True), optional) – Spike time for incorrect classes (only if target_is_time=False). Defaults to -1, i.e., final time step
tolerance (int, optional) – If the distance between the spike time and target is less than the specified tolerance, then it does not contribute to the loss. Defaults to 0
multi_spike (bool, optional) – Specify if multiple spikes in target. Defaults to False
- Returns:
Loss
- Return type:
torch.Tensor (single element)
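The default single-spike, index-label path can be sketched in plain Python: correct neurons are aimed at on_target, the rest at off_target (negative values count back from the end, so -1 is the final step), and any time within tolerance counts as a hit. The name is hypothetical, silent neurons are assumed to fire at the final step, and averaging over batch_size * num_outputs is an assumption.

```python
def mse_temporal_loss_sketch(spk_out, targets, on_target=0, off_target=-1, tolerance=0):
    """MSE between first spike times and per-class target times, with index-label targets."""
    num_steps = len(spk_out)
    batch_size = len(spk_out[0])
    num_outputs = len(spk_out[0][0])
    # Negative targets count back from the end, so -1 is the final time step
    on_t = on_target if on_target >= 0 else num_steps + on_target
    off_t = off_target if off_target >= 0 else num_steps + off_target
    squared_error = 0
    for b in range(batch_size):
        for n in range(num_outputs):
            # First spike time; silent neurons are assumed to fire at the final step
            t_first = next((t for t in range(num_steps) if spk_out[t][b][n]), num_steps - 1)
            goal = on_t if n == targets[b] else off_t
            if abs(t_first - goal) < tolerance:
                t_first = goal  # within tolerance: treated as a perfect hit
            squared_error += (t_first - goal) ** 2
    return squared_error / (batch_size * num_outputs)


# 5 time steps, 1 sample, 2 neurons: neuron 0 (correct class) fires at t = 1, neuron 1 at t = 4
spk_out = [[[0, 0]], [[1, 0]], [[0, 0]], [[0, 0]], [[0, 1]]]
print(mse_temporal_loss_sketch(spk_out, targets=[0]))               # (1 + 0) / 2 = 0.5
print(mse_temporal_loss_sketch(spk_out, targets=[0], tolerance=2))  # 0.0
```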
Regularization Functions
State Quantization
- class snntorch.functional.quant.StateQuant(*args, **kwargs)
Bases: Function
Wrapper function for state_quant
- static backward(ctx, grad_output)
- static forward(ctx, input_, levels)
- snntorch.functional.quant.state_quant(num_bits=8, uniform=True, thr_centered=True, threshold=1, lower_limit=0, upper_limit=0.2, multiplier=None)
Quantization-Aware Training with spiking neuron states.
Note: for weight quantization, we recommend using Brevitas or another pre-existing PyTorch-friendly library.
Uniform or non-uniform quantization is selected with the uniform argument; uniform=True applies uniform quantization.
For non-uniform quantization, valid quantization levels can be centered about 0 or about the threshold; thr_centered=True centers them about the threshold.
upper_limit and lower_limit specify, as a proportion of the threshold, how far valid levels extend above the positive threshold and below the negative threshold, respectively. E.g., upper_limit=0.2 means the maximum valid state is 20% higher than the value specified in threshold.
Example:
import torch
import snntorch as snn
from snntorch.functional import quant

beta = 0.5
thr = 5

# set the quantization parameters
q_lif = quant.state_quant(num_bits=4, uniform=True, threshold=thr)

# specifying state_quant applies state-quantization to the
# hidden state(s) automatically
lif = snn.Leaky(beta=beta, threshold=thr, state_quant=q_lif)

rand_input = torch.rand(1)
mem = lif.init_leaky()

# forward-pass for one step
spk, mem = lif(rand_input, mem)
Note: Quantization-aware training is focused on modelling a reduced-precision network, but does not in itself accelerate low-precision models. Hidden states are still represented as full-precision values for compatibility with PyTorch. For accelerated performance or constrained memory, the model should be exported to a downstream backend.
- Parameters:
num_bits (int, optional) – Number of bits to quantize state variables to, defaults to 8
uniform (bool, optional) – Applies uniform quantization if True, non-uniform quantization if False, defaults to True
thr_centered (bool, optional) – For non-uniform quantization, specifies if valid states should be centered (densely clustered) around the threshold rather than at 0, defaults to True
threshold (float, optional) – Specifies the threshold, defaults to 1
lower_limit (float, optional) – Specifies how far below (-threshold) the lowest valid state can be, i.e., (-threshold - threshold*lower_limit), defaults to 0
upper_limit (float, optional) – Specifies how far above (threshold) the highest valid state can be, i.e., (threshold + threshold*upper_limit), defaults to 0.2
multiplier (float, optional) – For non-uniform distributions, specify the base of the exponential. If None, an appropriate value is set internally based on num_bits, defaults to None
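The uniform case can be sketched in plain Python, assuming 2**num_bits evenly spaced levels running from (-threshold - threshold*lower_limit) to (threshold + threshold*upper_limit), as the parameter descriptions suggest, and assuming each state is snapped to its nearest level. The helper names are hypothetical, the non-uniform (multiplier-based) spacing is not covered, and the exact level placement in snntorch may differ.

```python
def uniform_levels(num_bits=8, threshold=1.0, lower_limit=0.0, upper_limit=0.2):
    """2**num_bits evenly spaced levels from -threshold*(1+lower_limit) to threshold*(1+upper_limit)."""
    num_levels = 2 ** num_bits
    low = -threshold - threshold * lower_limit
    high = threshold + threshold * upper_limit
    step = (high - low) / (num_levels - 1)
    return [low + i * step for i in range(num_levels)]


def quantize(value, levels):
    """Snap a state value to the nearest valid level (values outside the range are clipped)."""
    return min(levels, key=lambda level: abs(level - value))


levels = uniform_levels(num_bits=2)  # about [-1.0, -0.267, 0.467, 1.2]
print(quantize(0.5, levels))   # about 0.467
print(quantize(2.0, levels))   # about 1.2, the highest valid state
print(quantize(-5.0, levels))  # -1.0, the lowest valid state
```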
Probe
- class snntorch.functional.probe.AttributeMonitor(attribute_name: str, pre_forward: bool, net: ~torch.nn.modules.module.Module, instance: ~typing.Any = None, function_on_attribute: ~typing.Callable = <function AttributeMonitor.<lambda>>)
Bases: BaseMonitor
A monitor to record the attribute (e.g., membrane potential) of a specific neuron layer (e.g., Leaky) in a network. The attribute name is specified as the first argument. All attribute data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.
Example:
import snntorch as snn
from snntorch.functional import probe
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()
# instance is the module type to monitor
monitor = probe.AttributeMonitor('mem', False, net, instance=snn.Leaky)
with torch.no_grad():
    y = net(torch.rand([1, 8]))
print(f'monitor.records={monitor.records}')
print(f'monitor[0]={monitor[0]}')
print(f'monitor.monitored_layers={monitor.monitored_layers}')
print(f"monitor['lif1']={monitor['lif1']}")
- Parameters:
attribute_name (str) – Name of the attribute of the probed neuron layer (e.g., mem, syn, etc.)
pre_forward (bool) – If True, record the attribute value before the forward pass; otherwise record the value after the forward pass.
net (nn.Module) – Network model (either wrapped in Sequential container or as a class)
instance (Any or tuple) – Type (or tuple of types) of the modules to be monitored. If None, defaults to type(net)
function_on_attribute (Callable, optional) – Function that is applied to the monitored modules’ attribute
- class snntorch.functional.probe.GradInputMonitor(net: ~torch.nn.modules.module.Module, instance: ~typing.Any = None, function_on_grad_input: ~typing.Callable = <function GradInputMonitor.<lambda>>)
Bases: BaseMonitor
A monitor to record the input gradient of each neuron layer (e.g., Leaky) in a network. All input gradient data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.
Example:
import snntorch as snn
from snntorch.functional import probe
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()
monitor = probe.GradInputMonitor(net, instance=snn.Leaky)
# gradients are only recorded during a backward pass, so autograd stays enabled
y = net(torch.rand([1, 8]))
y.sum().backward()
print(f'monitor.records={monitor.records}')
print(f'monitor[0]={monitor[0]}')
print(f'monitor.monitored_layers={monitor.monitored_layers}')
print(f"monitor['lif1']={monitor['lif1']}")
- Parameters:
net (nn.Module) – Network model (either wrapped in Sequential container or as a class)
instance (Any or tuple) – Type (or tuple of types) of the modules to be monitored. If None, defaults to type(net)
function_on_grad_input (Callable, optional) – Function that is applied to the monitored modules’ gradients
- class snntorch.functional.probe.GradOutputMonitor(net: ~torch.nn.modules.module.Module, instance: ~typing.Any = None, function_on_grad_output: ~typing.Callable = <function GradOutputMonitor.<lambda>>)
Bases: BaseMonitor
A monitor to record the output gradient of each specific neuron layer (e.g., Leaky) in a network. All output gradient data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.
Example:
import snntorch as snn
from snntorch.functional import probe
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()
mtor = probe.GradOutputMonitor(net, instance=snn.Leaky)
# gradients are only recorded during a backward pass, so autograd stays enabled
y = net(torch.rand([1, 8]))
y.sum().backward()
print(f'mtor.records={mtor.records}')
print(f'mtor[0]={mtor[0]}')
print(f'mtor.monitored_layers={mtor.monitored_layers}')
print(f"mtor['lif1']={mtor['lif1']}")
- Parameters:
net (nn.Module) – Network model (either wrapped in Sequential container or as a class)
instance (Any or tuple) – Type (or tuple of types) of the modules to be monitored. If None, defaults to type(net)
function_on_grad_output (Callable, optional) – Function that is applied to the monitored modules’ gradients
- class snntorch.functional.probe.InputMonitor(net: ~torch.nn.modules.module.Module, instance: ~typing.Any = None, function_on_input: ~typing.Callable = <function InputMonitor.<lambda>>)
Bases: BaseMonitor
A monitor to record the input of each neuron layer (e.g., Leaky) in a network. All input data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.
Example:
import snntorch as snn
from snntorch.functional import probe
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()
monitor = probe.InputMonitor(net, instance=snn.Leaky)
with torch.no_grad():
    y = net(torch.rand([1, 8]))
print(f'monitor.records={monitor.records}')
print(f'monitor[0]={monitor[0]}')
print(f'monitor.monitored_layers={monitor.monitored_layers}')
print(f"monitor['lif1']={monitor['lif1']}")
- Parameters:
net (nn.Module) – Network model (either wrapped in Sequential container or as a class)
instance (Any or tuple) – Type (or tuple of types) of the modules to be monitored. If None, defaults to type(net)
function_on_input (Callable, optional) – Function that is applied to the monitored modules’ input
- class snntorch.functional.probe.OutputMonitor(net: ~torch.nn.modules.module.Module, instance: ~typing.Any = None, function_on_output: ~typing.Callable = <function OutputMonitor.<lambda>>)
Bases: BaseMonitor
A monitor to record the output spikes of each specific neuron layer (e.g., Leaky) in a network. All output data is recorded in self.records as a list. Call self.enable() or self.disable() to enable or disable the monitor. Call self.clear_recorded_data() to clear recorded data.
Example:
import snntorch as snn
from snntorch.functional import probe
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)
        self.fc2 = nn.Linear(4, 2)
        self.lif2 = snn.Leaky(beta=0.9, init_hidden=True)

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.lif1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.lif2(x_seq)
        return x_seq

net = Net()
monitor = probe.OutputMonitor(net, instance=snn.Leaky)
with torch.no_grad():
    y = net(torch.rand([1, 8]))
print(f'monitor.records={monitor.records}')
print(f'monitor[0]={monitor[0]}')
print(f'monitor.monitored_layers={monitor.monitored_layers}')
print(f"monitor['lif1']={monitor['lif1']}")
- Parameters:
net (nn.Module) – Network model (either wrapped in Sequential container or as a class)
instance (Any or tuple) – Type (or tuple of types) of the modules to be monitored. If None, defaults to type(net)
function_on_output (Callable, optional) – Function that is applied to the monitored modules’ outputs