snntorch.backprop

snntorch.backprop is a module implementing various time-variant backpropagation algorithms. Each method performs the forward pass, backward pass, and parameter update across all time steps in a single line of code.

How to use backprop

To use snntorch.backprop you must first construct a network, determine a loss criterion, and select an optimizer. When initializing neurons, set init_hidden=True. This enables the methods in snntorch.backprop to automatically clear the hidden state variables, as well as detach them from the computational graph when necessary.

Note

The first dimension of input data is assumed to be time. The built-in backprop functions iterate through the first dimension of data by default. For time-invariant inputs, set time_var=False.
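
For example, spike trains generated with snntorch.spikegen carry time as the leading dimension, whereas an ordinary image batch does not. A brief sketch (the shapes and the use of spikegen.rate below are illustrative assumptions, not requirements of this module):

import torch
from snntorch import spikegen

data = torch.rand(128, 1, 28, 28)                # time-invariant batch: [B x dims]
spike_data = spikegen.rate(data, num_steps=100)  # time-variant: [T x B x dims]

print(data.shape)        # torch.Size([128, 1, 28, 28])
print(spike_data.shape)  # torch.Size([100, 128, 1, 28, 28])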

Example:

import snntorch.functional as SF
from snntorch import backprop

net = Net().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=betas)
criterion = SF.mse_count_loss()

# train_loader is a torch.utils.data.DataLoader

# Time-variant input data
loss = backprop.BPTT(net, train_loader, optimizer=optimizer,
                     criterion=criterion, time_var=True, device=device)

# Time-invariant input data
loss = backprop.BPTT(net, train_loader, optimizer=optimizer,
                     criterion=criterion, num_steps=num_steps,
                     time_var=False, device=device)
snntorch.backprop.BPTT(net, dataloader, optimizer, criterion, num_steps=False, time_var=True, time_first=True, regularization=False, device='cpu')[source]

Backpropagation through time. LIF layers require the parameter init_hidden=True. A forward pass is applied at each time step while the loss accumulates. The backward pass and parameter update are applied only at the end of the time-step sequence. BPTT is equivalent to TBPTT for the case where K = num_steps.
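
The update pattern can be illustrated with a minimal sketch that manages membrane potentials explicitly rather than using init_hidden=True. The layers, data, and loss below are placeholder assumptions; this is not the module's internal implementation:

import torch
import torch.nn as nn
import snntorch as snn
import snntorch.functional as SF

# placeholder two-layer network with explicitly managed membrane potentials
fc1, fc2 = nn.Linear(784, 500), nn.Linear(500, 10)
lif1, lif2 = snn.Leaky(beta=0.9), snn.Leaky(beta=0.9)
optimizer = torch.optim.Adam(list(fc1.parameters()) + list(fc2.parameters()), lr=5e-4)
loss_fn = SF.mse_count_loss()

data = torch.rand(128, 784)             # time-static input, re-presented at every step
targets = torch.randint(0, 10, (128,))
num_steps = 100

mem1, mem2 = lif1.init_leaky(), lif2.init_leaky()
spk_rec = []

# forward pass at every time step; the loss accumulates over the full sequence
for step in range(num_steps):
    spk1, mem1 = lif1(fc1(data), mem1)
    spk2, mem2 = lif2(fc2(spk1), mem2)
    spk_rec.append(spk2)

loss = loss_fn(torch.stack(spk_rec), targets)

# backward pass and parameter update applied once, at the end of the sequence
optimizer.zero_grad()
loss.backward()
optimizer.step()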

Example:

import snntorch as snn
import snntorch.functional as SF
from snntorch import utils
from snntorch import backprop
import torch
import torch.nn as nn

lif1 = snn.Leaky(beta=0.9, init_hidden=True)
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 500),
                    lif1,
                    nn.Linear(500, 10),
                    lif2).to(device)

num_steps = 100

optimizer = torch.optim.Adam(net.parameters(), lr=5e-4,
                             betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss()
reg_fn = SF.l1_rate_sparsity()


# train_loader is of type torch.utils.data.DataLoader
# if input data is time-static, set time_var=False, and specify
# num_steps.
# if input data is time-varying, set time_var=True and do not
# specify num_steps.

for epoch in range(5):
    loss = backprop.BPTT(net, train_loader, optimizer=optimizer,
                         criterion=loss_fn, num_steps=num_steps,
                         time_var=False, regularization=reg_fn,
                         device=device)
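
After training, the network can be evaluated with an explicit time loop. A possible follow-up sketch continuing the example above (test_loader is an assumed DataLoader of MNIST-like images; utils.reset clears hidden states and SF.accuracy_rate computes rate-coded accuracy):

# evaluation sketch: test_loader is an assumed torch.utils.data.DataLoader
accs = []
with torch.no_grad():
    net.eval()
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        utils.reset(net)  # clear hidden states before each batch

        spk_rec = []
        for step in range(num_steps):
            spk_out, _ = net(data)  # final layer has output=True
            spk_rec.append(spk_out)

        accs.append(SF.accuracy_rate(torch.stack(spk_rec), targets))

print(f"Average test accuracy: {sum(accs) / len(accs):.3f}")
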
Parameters:
  • net (torch.nn.modules.container.Sequential) – Network model (either wrapped in Sequential container or as a class)

  • dataloader (torch.utils.data.DataLoader) – DataLoader containing data and targets

  • optimizer (torch.optim) – Optimizer used, e.g., torch.optim.adam.Adam

  • criterion (snn.functional.LossFunctions) – Loss criterion from snntorch.functional, e.g., snn.functional.mse_count_loss()

  • num_steps (int, optional) – Number of time steps. Does not need to be specified if time_var=True.

  • time_var (Bool, optional) – Set to True if input data is time-varying [T x B x dims]; set to False if input data is time-static [B x dims]. Defaults to True

  • time_first (Bool, optional) – Set to False if the first dimension of the data is not time, i.e., the input is [B x T x dims] and must be permuted to [T x B x dims]. Defaults to True

  • regularization (snn.functional regularization function, optional) – Option to add a regularization term to the loss function

  • device (string, optional) – Specify either “cuda” or “cpu”, defaults to “cpu”

Returns:

Average loss for one epoch

Return type:

torch.Tensor

snntorch.backprop.RTRL(net, dataloader, optimizer, criterion, num_steps=False, time_var=True, time_first=True, regularization=False, device='cpu')[source]

Real-time Recurrent Learning. LIF layers require the parameter init_hidden=True. A forward pass, backward pass, and parameter update are applied at each time step. RTRL is equivalent to TBPTT for the case where K = 1.

Example:

import snntorch as snn
import snntorch.functional as SF
from snntorch import utils
from snntorch import backprop
import torch
import torch.nn as nn

lif1 = snn.Leaky(beta=0.9, init_hidden=True)
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 500),
                    lif1,
                    nn.Linear(500, 10),
                    lif2).to(device)

num_steps = 100

optimizer = torch.optim.Adam(net.parameters(), lr=5e-4,
                             betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss()
reg_fn = SF.l1_rate_sparsity()

# train_loader is of type torch.utils.data.DataLoader
# if input data is time-static, set time_var=False, and specify
# num_steps.

for epoch in range(5):
    loss = backprop.RTRL(net, train_loader, optimizer=optimizer,
                         criterion=loss_fn, num_steps=num_steps,
                         time_var=False, regularization=reg_fn,
                         device=device)
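
For time-varying inputs, num_steps is left unspecified and time_var remains True; if the DataLoader batches samples along the first dimension so that each batch is [B x T x dims], setting time_first=False permutes it to [T x B x dims]. A brief sketch (spike_loader is an assumed DataLoader of spike trains):

for epoch in range(5):
    loss = backprop.RTRL(net, spike_loader, optimizer=optimizer,
                         criterion=loss_fn, time_var=True, time_first=False,
                         regularization=reg_fn, device=device)
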
Parameters:
  • net (torch.nn.modules.container.Sequential) – Network model (either wrapped in Sequential container or as a class)

  • dataloader (torch.utils.data.DataLoader) – DataLoader containing data and targets

  • optimizer (torch.optim) – Optimizer used, e.g., torch.optim.adam.Adam

  • criterion (snn.functional.LossFunctions) – Loss criterion from snntorch.functional, e.g., snn.functional.mse_count_loss()

  • num_steps (int, optional) – Number of time steps. Does not need to be specified if time_var=True.

  • time_var (Bool, optional) – Set to True if input data is time-varying [T x B x dims]; set to False if input data is time-static [B x dims]. Defaults to True

  • time_first (Bool, optional) – Set to False if the first dimension of the data is not time, i.e., the input is [B x T x dims] and must be permuted to [T x B x dims]. Defaults to True

  • regularization (snn.functional regularization function, optional) – Option to add a regularization term to the loss function

  • device (string, optional) – Specify either “cuda” or “cpu”, defaults to “cpu”

Returns:

Average loss for one epoch

Return type:

torch.Tensor

snntorch.backprop.TBPTT(net, dataloader, optimizer, criterion, num_steps=False, time_var=True, time_first=True, regularization=False, device='cpu', K=1)[source]

Truncated backpropagation through time. LIF layers require the parameter init_hidden=True. The loss accumulates over K time steps, and a backward pass and weight update are performed every K time steps.
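
The truncation pattern can be illustrated with a minimal sketch that manages membrane potentials explicitly rather than using init_hidden=True. The layers, data, and loss below are placeholder assumptions; this is not the module's internal implementation. With K = 1 this reduces to the RTRL scheme above, and with K = num_steps to BPTT:

import torch
import torch.nn as nn
import snntorch as snn
import snntorch.functional as SF

fc1, fc2 = nn.Linear(784, 500), nn.Linear(500, 10)
lif1, lif2 = snn.Leaky(beta=0.9), snn.Leaky(beta=0.9)
optimizer = torch.optim.Adam(list(fc1.parameters()) + list(fc2.parameters()), lr=5e-4)
loss_fn = SF.mse_count_loss()

data = torch.rand(128, 784)             # time-static input, re-presented at every step
targets = torch.randint(0, 10, (128,))
num_steps, K = 100, 40

mem1, mem2 = lif1.init_leaky(), lif2.init_leaky()
spk_rec = []

for step in range(num_steps):
    spk1, mem1 = lif1(fc1(data), mem1)
    spk2, mem2 = lif2(fc2(spk1), mem2)
    spk_rec.append(spk2)

    # every K steps (and at the end of the sequence): loss, backward pass, update
    if (step + 1) % K == 0 or step == num_steps - 1:
        loss = loss_fn(torch.stack(spk_rec), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # detach the states so the next window starts a fresh graph
        mem1, mem2 = mem1.detach(), mem2.detach()
        spk_rec = []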

Example:

import snntorch as snn
import snntorch.functional as SF
from snntorch import utils
from snntorch import backprop
import torch
import torch.nn as nn

lif1 = snn.Leaky(beta=0.9, init_hidden=True)
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 500),
                    lif1,
                    nn.Linear(500, 10),
                    lif2).to(device)

num_steps = 100

optimizer = torch.optim.Adam(net.parameters(), lr=5e-4,
                             betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss()
reg_fn = SF.l1_rate_sparsity()

# train_loader is of type torch.utils.data.DataLoader
# if input data is time-static, set time_var=False, and specify
# num_steps.
# if input data is time-varying, set time_var=True and do not
# specify num_steps.
# backprop is automatically applied every K=40 time steps

for epoch in range(5):
    loss = backprop.TBPTT(net, train_loader, optimizer=optimizer,
                          criterion=loss_fn, num_steps=num_steps,
                          time_var=False, regularization=reg_fn,
                          device=device, K=40)
Parameters:
  • net (torch.nn.modules.container.Sequential) – Network model (either wrapped in Sequential container or as a class)

  • dataloader (torch.utils.data.DataLoader) – DataLoader containing data and targets

  • optimizer (torch.optim) – Optimizer used, e.g., torch.optim.adam.Adam

  • criterion (snn.functional.LossFunctions) – Loss criterion from snntorch.functional, e.g., snn.functional.mse_count_loss()

  • num_steps (int, optional) – Number of time steps. Does not need to be specified if time_var=True.

  • time_var (Bool, optional) – Set to True if input data is time-varying [T x B x dims]; set to False if input data is time-static [B x dims]. Defaults to True

  • time_first (Bool, optional) – Set to False if the first dimension of the data is not time, i.e., the input is [B x T x dims] and must be permuted to [T x B x dims]. Defaults to True

  • regularization (snn.functional regularization function, optional) – Option to add a regularization term to the loss function

  • device (string, optional) – Specify either “cuda” or “cpu”, defaults to “cpu”

  • K (int, optional) – Number of time steps to process per weight update, defaults to 1

Returns:

Average loss for one epoch

Return type:

torch.Tensor