Source code for snntorch.functional.reg

import torch


[docs] class l1_rate_sparsity: """L1 regularization using total spike count as the penalty term. Lambda is a scalar factor for regularization.""" def __init__(self, Lambda=1e-5): self.Lambda = Lambda self.__name__ = "l1_rate_sparsity" def __call__(self, spk_out): return self.Lambda * torch.sum(spk_out)
# # def l2_sparsity(mem_out, Lambda=1e-6): # # """L2 regularization using accumulated membrane potential # as the penalty term.""" # # return Lambda * (torch.sum(mem_out)**2)