snntorch.backprop
snntorch.backprop is a module implementing various time-variant backpropagation algorithms. Each method performs the forward pass, backward pass, and parameter update across all time steps in a single line of code.
How to use backprop
To use snntorch.backprop, you must first construct a network, determine a loss criterion, and select an optimizer. When initializing neurons, set init_hidden=True. This enables the methods in snntorch.backprop to automatically clear the hidden state variables, as well as detach them from the computational graph when necessary.
Note
The first dimension of input data is assumed to be time. The built-in backprop functions iterate through the first dimension of data by default. For time-invariant inputs, set time_var=False.
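For illustration, the two input layouts look like this (a minimal sketch; the tensor names and sizes here are arbitrary, not part of the API):

import torch

num_steps = 25    # T: number of time steps
batch_size = 128  # B

# Time-varying input: iterated along dim 0, one slice per time step
x_time_var = torch.rand(num_steps, batch_size, 784)  # [T x B x dims]

# Time-static input: the same frame is presented at every time step;
# pass time_var=False and specify num_steps explicitly
x_static = torch.rand(batch_size, 784)               # [B x dims]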
Example:

net = Net().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=betas)
criterion = nn.CrossEntropyLoss()

# Time-variant input data
loss = backprop.BPTT(net, train_loader, optimizer=optimizer,
                     criterion=criterion, device=device)

# Time-invariant input data
loss = backprop.BPTT(net, train_loader, optimizer=optimizer,
                     criterion=criterion, num_steps=num_steps,
                     time_var=False, device=device)
- snntorch.backprop.BPTT(net, dataloader, optimizer, criterion, num_steps=False, time_var=True, time_first=True, regularization=False, device='cpu')[source]
Backpropagation through time. LIF layers require the parameter init_hidden=True. A forward pass is applied at each time step while the loss accumulates. The backward pass and parameter update are applied only once, at the end of the sequence of time steps. BPTT is equivalent to TBPTT for the case where num_steps = K.
Example:
import snntorch as snn
import snntorch.functional as SF
from snntorch import utils
from snntorch import backprop
import torch
import torch.nn as nn

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

lif1 = snn.Leaky(beta=0.9, init_hidden=True)
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 500),
                    lif1,
                    nn.Linear(500, 10),
                    lif2).to(device)

num_steps = 100
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss()
reg_fn = SF.l1_rate_sparsity()

# train_loader is of type torch.utils.data.DataLoader

# if input data is time-static, set time_var=False and specify num_steps.
# if input data is time-varying, set time_var=True and do not specify num_steps.
for epoch in range(5):
    loss = backprop.BPTT(net, train_loader, optimizer=optimizer,
                         criterion=loss_fn, num_steps=num_steps,
                         time_var=False, regularization=reg_fn, device=device)
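For intuition, one epoch of backprop.BPTT over time-static data behaves roughly like the hand-written loop below (a minimal sketch, not the library's implementation; it assumes the final layer was built with output=True so the network returns (spk, mem), and that the loss function accepts stacked spike recordings):

# forward at every step, accumulate spikes, then one backward pass
# and one weight update at the end of the sequence
for data, targets in train_loader:
    data, targets = data.to(device), targets.to(device)
    utils.reset(net)              # clear hidden states between batches

    spk_rec = []
    for _ in range(num_steps):
        spk, _ = net(data)        # same static frame at each step
        spk_rec.append(spk)

    loss = loss_fn(torch.stack(spk_rec), targets)
    optimizer.zero_grad()
    loss.backward()               # gradients flow through all steps
    optimizer.step()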
- Parameters:
net (torch.nn.modules.container.Sequential) – Network model (either wrapped in a Sequential container or as a class)
dataloader (torch.utils.data.DataLoader) – DataLoader containing data and targets
optimizer (torch.optim) – Optimizer used, e.g., torch.optim.Adam
criterion (snn.functional.LossFunctions) – Loss criterion from snntorch.functional, e.g., snn.functional.mse_count_loss()
num_steps (int, optional) – Number of time steps. Does not need to be specified if time_var=True
time_var (Bool, optional) – Set to True if input data is time-varying [T x B x dims]; set to False if input data is time-static [B x dims], defaults to True
time_first (Bool, optional) – Set to False if the first dimension of the data is not time [B x T x dims] and the data must be permuted to [T x B x dims], defaults to True
regularization (snn.functional regularization function, optional) – Option to add a regularization term to the loss function
device (string, optional) – Specify either “cuda” or “cpu”, defaults to “cpu”
- Returns:
Average loss for one epoch
- Return type:
torch.Tensor
- snntorch.backprop.RTRL(net, dataloader, optimizer, criterion, num_steps=False, time_var=True, time_first=True, regularization=False, device='cpu')[source]
Real-time Recurrent Learning. LIF layers require the parameter init_hidden=True. A forward pass, backward pass, and parameter update are applied at each time step. RTRL is equivalent to TBPTT for the case where K = 1.
Example:
import snntorch as snn
import snntorch.functional as SF
from snntorch import utils
from snntorch import backprop
import torch
import torch.nn as nn

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

lif1 = snn.Leaky(beta=0.9, init_hidden=True)
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 500),
                    lif1,
                    nn.Linear(500, 10),
                    lif2).to(device)

num_steps = 100
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss()
reg_fn = SF.l1_rate_sparsity()

# train_loader is of type torch.utils.data.DataLoader

# if input data is time-static, set time_var=False and specify num_steps.
for epoch in range(5):
    loss = backprop.RTRL(net, train_loader, optimizer=optimizer,
                         criterion=loss_fn, num_steps=num_steps,
                         time_var=False, regularization=reg_fn, device=device)
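To make the per-step schedule concrete, the sketch below spells out what "a forward pass, backward pass and parameter update at each time step" means for one time-static batch. It is illustrative only; detach_hidden() is assumed here to be the classmethod snntorch's neuron classes expose for cutting hidden states out of the old graph, something backprop.RTRL otherwise handles internally:

# TBPTT with K=1: backward pass and weight update at every time step
for data, targets in train_loader:
    data, targets = data.to(device), targets.to(device)
    utils.reset(net)                   # zero hidden states

    for _ in range(num_steps):
        spk, _ = net(data)             # forward: one time step
        loss = loss_fn(spk.unsqueeze(0), targets)

        optimizer.zero_grad()
        loss.backward()                # backward: this step only
        optimizer.step()               # update: every step
        snn.Leaky.detach_hidden()      # assumed API: truncate graph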
- Parameters:
net (torch.nn.modules.container.Sequential) – Network model (either wrapped in a Sequential container or as a class)
dataloader (torch.utils.data.DataLoader) – DataLoader containing data and targets
optimizer (torch.optim) – Optimizer used, e.g., torch.optim.Adam
criterion (snn.functional.LossFunctions) – Loss criterion from snntorch.functional, e.g., snn.functional.mse_count_loss()
num_steps (int, optional) – Number of time steps. Does not need to be specified if time_var=True
time_var (Bool, optional) – Set to True if input data is time-varying [T x B x dims]; set to False if input data is time-static [B x dims], defaults to True
time_first (Bool, optional) – Set to False if the first dimension of the data is not time [B x T x dims] and the data must be permuted to [T x B x dims], defaults to True
regularization (snn.functional regularization function, optional) – Option to add a regularization term to the loss function
device (string, optional) – Specify either “cuda” or “cpu”, defaults to “cpu”
- Returns:
Average loss for one epoch
- Return type:
torch.Tensor
- snntorch.backprop.TBPTT(net, dataloader, optimizer, criterion, num_steps=False, time_var=True, time_first=True, regularization=False, device='cpu', K=1)[source]
Truncated backpropagation through time. LIF layers require the parameter init_hidden=True. Weight updates are performed every K time steps.
Example:
import snntorch as snn
import snntorch.functional as SF
from snntorch import utils
from snntorch import backprop
import torch
import torch.nn as nn

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

lif1 = snn.Leaky(beta=0.9, init_hidden=True)
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 500),
                    lif1,
                    nn.Linear(500, 10),
                    lif2).to(device)

num_steps = 100
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss()
reg_fn = SF.l1_rate_sparsity()

# train_loader is of type torch.utils.data.DataLoader

# if input data is time-static, set time_var=False and specify num_steps.
# if input data is time-varying, set time_var=True and do not specify num_steps.

# backprop is automatically applied every K=40 time steps
for epoch in range(5):
    loss = backprop.TBPTT(net, train_loader, optimizer=optimizer,
                          criterion=loss_fn, num_steps=num_steps,
                          time_var=False, regularization=reg_fn,
                          device=device, K=40)
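The truncation schedule itself can be pictured as follows (a minimal sketch of the every-K-steps pattern for one time-static batch, not the library's implementation; detach_hidden() is again assumed to be the classmethod that cuts hidden states from the old graph):

# accumulate spikes for K steps, then backward + update + truncate
K = 40
for data, targets in train_loader:
    data, targets = data.to(device), targets.to(device)
    utils.reset(net)
    spk_rec = []

    for step in range(num_steps):
        spk, _ = net(data)
        spk_rec.append(spk)

        if (step + 1) % K == 0 or step == num_steps - 1:
            loss = loss_fn(torch.stack(spk_rec), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            snn.Leaky.detach_hidden()  # assumed API: truncate here
            spk_rec = []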
- Parameters:
net (torch.nn.modules.container.Sequential) – Network model (either wrapped in a Sequential container or as a class)
dataloader (torch.utils.data.DataLoader) – DataLoader containing data and targets
optimizer (torch.optim) – Optimizer used, e.g., torch.optim.Adam
criterion (snn.functional.LossFunctions) – Loss criterion from snntorch.functional, e.g., snn.functional.mse_count_loss()
num_steps (int, optional) – Number of time steps. Does not need to be specified if time_var=True
time_var (Bool, optional) – Set to True if input data is time-varying [T x B x dims]; set to False if input data is time-static [B x dims], defaults to True
time_first (Bool, optional) – Set to False if the first dimension of the data is not time [B x T x dims] and the data must be permuted to [T x B x dims], defaults to True
regularization (snn.functional regularization function, optional) – Option to add a regularization term to the loss function
device (string, optional) – Specify either “cuda” or “cpu”, defaults to “cpu”
K (int, optional) – Number of time steps to process per weight update, defaults to 1
- Returns:
Average loss for one epoch
- Return type:
torch.Tensor