import torch
import torch.nn as nn
from snntorch import spikegen
#############################################################
# When adding new loss/reg functions, update criterion_dict #
# and reg_dict in backprop.py                               #
#############################################################
dtype = torch.float
class LossFunctions:
def __init__(self, reduction, weight):
self.reduction = reduction
self.weight = weight
def __call__(self, spk_out, targets):
loss = self._compute_loss(spk_out, targets)
return self._reduce(loss)
def _prediction_check(self, spk_out):
device = spk_out.device
num_steps = spk_out.size(0)
num_outputs = spk_out.size(-1)
return device, num_steps, num_outputs
def _population_code(self, spk_out, num_classes, num_outputs):
"""Count up spikes sequentially from output classes."""
if not num_classes:
            raise ValueError(
                "``num_classes`` must be specified if "
                "``population_code=True``."
            )
        if num_outputs % num_classes:
            raise ValueError(
                f"``num_classes`` ({num_classes}) must be a factor of "
                f"``num_outputs`` ({num_outputs})."
            )
device = spk_out.device
        pop_code = torch.zeros(
            (spk_out.size(1), num_classes), device=device
        )
for idx in range(num_classes):
pop_code[:, idx] = (
spk_out[
:,
:,
int(num_outputs * idx / num_classes) : int(
num_outputs * (idx + 1) / num_classes
),
]
.sum(-1)
.sum(0)
)
return pop_code
def _intermediate_reduction(self):
return self.reduction if self.weight is None else 'none'
def _reduce(self, loss):
# if reduction was delayed due to weight
requires_reduction = self.weight is not None and self.reduction == 'mean'
return loss.mean() if requires_reduction else loss
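
# Editor's sketch (illustrative, not part of the library): it shows the
# contract of ``LossFunctions`` -- a subclass only needs to define
# ``_compute_loss``; ``__call__`` invokes it and ``_reduce`` applies the final
# reduction. The class name and tensor shapes below are assumptions.
def _example_custom_loss():
    class _ToyCountLoss(LossFunctions):
        """Hypothetical subclass, for illustration only."""

        def _compute_loss(self, spk_out, targets):
            spike_count = torch.sum(spk_out, 0)  # B x C spike counts
            one_hot = nn.functional.one_hot(
                targets, spike_count.size(-1)
            ).float()
            return nn.functional.mse_loss(
                spike_count, one_hot, reduction=self._intermediate_reduction()
            )

    spk_out = (torch.rand(25, 4, 10) < 0.3).float()  # T x B x C spike record
    targets = torch.randint(0, 10, (4,))  # B integer class labels
    return _ToyCountLoss(reduction="mean", weight=None)(spk_out, targets)
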
class ce_rate_loss(LossFunctions):
"""Cross Entropy Spike Rate Loss.
When called, the spikes at each time step are sequentially passed
through the Cross Entropy Loss function.
This criterion combines log_softmax and NLLLoss in a single function.
    The losses are accumulated over time steps and averaged to give the
    final loss.
The Cross Entropy Loss encourages the correct class to fire at all
time steps, and aims to suppress incorrect classes from firing.
The Cross Entropy Rate Loss applies the Cross Entropy function at
every time step. In contrast, the Cross Entropy Count Loss accumulates
spikes first, and applies Cross Entropy Loss only once.
Example::
import snntorch.functional as SF
loss_fn = SF.ce_rate_loss()
loss = loss_fn(outputs, targets)
:return: Loss
:rtype: torch.Tensor (single element)
"""
def __init__(self, population_code=False, num_classes=False, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.population_code = population_code
self.num_classes = num_classes
self.__name__ = "ce_rate_loss"
def _compute_loss(self, spk_out, targets):
device, num_steps, num_outputs = self._prediction_check(spk_out)
log_softmax_fn = nn.LogSoftmax(dim=-1)
        if self.population_code:
            # expand the pair of class weights into per-neuron weights:
            # the first num_outputs // num_classes neurons take weight[0],
            # the remaining neurons take weight[1]
            if self.weight is not None:
                weights = torch.Tensor(
                    [
                        self.weight[0]
                        if i < int(num_outputs / self.num_classes)
                        else self.weight[1]
                        for i in range(num_outputs)
                    ]
                ).to(device)
            else:
                weights = None
            loss_fn = nn.NLLLoss(
                reduction=self._intermediate_reduction(), weight=weights
            )
else:
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)
log_p_y = log_softmax_fn(spk_out)
loss_shape = (spk_out.size(1)) if self._intermediate_reduction() == 'none' else (1)
loss = torch.zeros(loss_shape, dtype=dtype, device=device)
for step in range(num_steps):
loss += loss_fn(log_p_y[step], targets)
return loss / num_steps
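
# Editor's sketch (illustrative, not part of the library): applying
# ``ce_rate_loss`` to a synthetic spike recording. The shapes (25 time steps,
# batch of 8, 10 output neurons) are assumptions chosen for the example.
def _example_ce_rate_loss():
    spk_out = (torch.rand(25, 8, 10) < 0.3).float()  # T x B x N spike record
    targets = torch.randint(0, 10, (8,))  # B integer class labels
    loss_fn = ce_rate_loss()
    # cross entropy is applied at every step, then averaged over time
    return loss_fn(spk_out, targets)
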
class ce_count_loss(LossFunctions):
"""Cross Entropy Spike Count Loss.
The spikes at each time step [num_steps x batch_size x num_outputs]
are accumulated and then passed through the Cross Entropy Loss function.
This criterion combines log_softmax and NLLLoss in a single function.
The Cross Entropy Loss encourages the correct class to fire at all
time steps, and aims to suppress incorrect classes from firing.
The Cross Entropy Count Loss accumulates spikes first, and applies
Cross Entropy Loss only once.
In contrast, the Cross Entropy Rate Loss applies the Cross Entropy
function at every time step.
Example::
import snntorch.functional as SF
# if not using population codes (i.e., more output neurons than
        # there are classes)
loss_fn = ce_count_loss()
loss = loss_fn(spk_out, targets)
# if using population codes; e.g., 200 output neurons, 10 output
        # classes --> 20 output neurons per class
loss_fn = ce_count_loss(population_code=True, num_classes=10)
loss = loss_fn(spk_out, targets)
:param population_code: Specify if a population code is applied, i.e.,
the number of outputs is greater than the number of classes. Defaults
to ``False``
:type population_code: bool, optional
:param num_classes: Number of output classes must be specified if
``population_code=True``. Must be a factor of the number of output
neurons if population code is enabled. Defaults to ``False``
:type num_classes: int, optional
:return: Loss
:rtype: torch.Tensor (single element)
"""
def __init__(self, population_code=False, num_classes=False, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.population_code = population_code
self.num_classes = num_classes
self.__name__ = "ce_count_loss"
def _compute_loss(self, spk_out, targets):
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)
if self.population_code:
_, _, num_outputs = self._prediction_check(spk_out)
spike_count = self._population_code(
spk_out, self.num_classes, num_outputs
)
else:
spike_count = torch.sum(spk_out, 0) # B x C
log_p_y = log_softmax_fn(spike_count)
loss = loss_fn(log_p_y, targets)
return loss
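
# Editor's sketch (illustrative, not part of the library): ``ce_count_loss``
# with a population code. The sizes are assumptions: 50 output neurons
# grouped into 10 classes, i.e., 5 neurons per class.
def _example_ce_count_loss_population():
    spk_out = (torch.rand(25, 8, 50) < 0.3).float()  # T x B x (N=50)
    targets = torch.randint(0, 10, (8,))
    loss_fn = ce_count_loss(population_code=True, num_classes=10)
    # spikes are summed per class block before a single cross entropy call
    return loss_fn(spk_out, targets)
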
class ce_max_membrane_loss(LossFunctions):
"""Cross Entropy Max Membrane Loss.
When called, the maximum membrane potential value for each output
neuron is sampled and passed through the Cross Entropy Loss Function.
This criterion combines log_softmax and NLLLoss in a single function.
The Cross Entropy Loss encourages the maximum membrane potential of
the correct class to increase, while suppressing the maximum membrane
potential of incorrect classes.
    This function is adapted from SpyTorch by Friedemann Zenke.
Example::
import snntorch.functional as SF
loss_fn = SF.ce_max_membrane_loss()
loss = loss_fn(outputs, targets)
:param mem_out: The output tensor of the SNN's membrane potential,
of the dimension timestep * batch_size * num_output_neurons
:type mem_out: torch.Tensor
:param targets: The tensor containing the targets of the current
mini-batch, of the dimension batch_size
:type targets: torch.Tensor
:return: Loss
:rtype: torch.Tensor (single element)
"""
def __init__(self, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.__name__ = "ce_max_membrane_loss"
def _compute_loss(self, mem_out, targets):
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)
max_mem_out, _ = torch.max(mem_out, 0)
log_p_y = log_softmax_fn(max_mem_out)
loss = loss_fn(log_p_y, targets)
return loss
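
# Editor's sketch (illustrative, not part of the library): the loss samples
# the maximum membrane potential over time for each output neuron. The
# membrane trace below is a random assumption of shape [T x B x N].
def _example_ce_max_membrane_loss():
    mem_out = torch.randn(25, 8, 10)  # T x B x N membrane potential
    targets = torch.randint(0, 10, (8,))
    loss_fn = ce_max_membrane_loss()
    return loss_fn(mem_out, targets)
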
class mse_count_loss(LossFunctions):
"""Mean Square Error Spike Count Loss.
When called, the total spike count is accumulated over time for
each neuron.
The target spike count for correct classes is set to
(num_steps * correct_rate), and for incorrect classes
(num_steps * incorrect_rate).
The spike counts and target spike counts are then applied to a
Mean Square Error Loss Function.
    This function is adapted from SLAYER by Sumit Bam Shrestha and
Garrick Orchard.
Example::
import snntorch.functional as SF
loss_fn = SF.mse_count_loss(correct_rate=0.75, incorrect_rate=0.25)
loss = loss_fn(outputs, targets)
:param correct_rate: Firing frequency of correct class as a ratio, e.g.,
``1`` promotes firing at every step; ``0.5`` promotes firing at 50% of
steps, ``0`` discourages any firing, defaults to ``1``
:type correct_rate: float, optional
:param incorrect_rate: Firing frequency of incorrect class(es) as a
ratio, e.g., ``1`` promotes firing at every step; ``0.5`` promotes
    firing at 50% of steps, ``0`` discourages any firing, defaults to ``0``
:type incorrect_rate: float, optional
:param population_code: Specify if a population code is applied, i.e., the
number of outputs is greater than the number of classes. Defaults to
``False``
:type population_code: bool, optional
:param num_classes: Number of output classes must be specified if
``population_code=True``. Must be a factor of the number of output
neurons if population code is enabled. Defaults to ``False``
:type num_classes: int, optional
:return: Loss
:rtype: torch.Tensor (single element)
"""
def __init__(
self,
correct_rate=1,
incorrect_rate=0,
population_code=False,
num_classes=False,
reduction='mean',
weight=None
):
super().__init__(reduction=reduction, weight=weight)
self.correct_rate = correct_rate
self.incorrect_rate = incorrect_rate
self.population_code = population_code
self.num_classes = num_classes
self.__name__ = "mse_count_loss"
def _compute_loss(self, spk_out, targets):
_, num_steps, num_outputs = self._prediction_check(spk_out)
loss_fn = nn.MSELoss(reduction=self._intermediate_reduction())
if not self.population_code:
# generate ideal spike-count in C sized vector
on_target = int(num_steps * self.correct_rate)
off_target = int(num_steps * self.incorrect_rate)
spike_count_target = spikegen.targets_convert(
targets,
num_classes=num_outputs,
on_target=on_target,
off_target=off_target,
)
spike_count = torch.sum(spk_out, 0) # B x C
else:
on_target = int(
num_steps
* self.correct_rate
* (num_outputs / self.num_classes)
)
off_target = int(
num_steps
* self.incorrect_rate
* (num_outputs / self.num_classes)
)
spike_count_target = spikegen.targets_convert(
targets,
num_classes=self.num_classes,
on_target=on_target,
off_target=off_target,
)
spike_count = self._population_code(
spk_out, self.num_classes, num_outputs
)
loss = loss_fn(spike_count, spike_count_target)
if self.weight is not None:
loss = loss * self.weight[targets]
return loss / num_steps
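
# Editor's sketch (illustrative, not part of the library): with
# ``correct_rate=0.75`` and ``incorrect_rate=0.25`` over 25 steps, the target
# spike counts become 18 and 6 respectively (integer truncation). Shapes and
# rates are assumptions chosen for the example.
def _example_mse_count_loss():
    spk_out = (torch.rand(25, 8, 10) < 0.3).float()  # T x B x N spike record
    targets = torch.randint(0, 10, (8,))
    loss_fn = mse_count_loss(correct_rate=0.75, incorrect_rate=0.25)
    return loss_fn(spk_out, targets)
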
class mse_membrane_loss(LossFunctions):
"""Mean Square Error Membrane Loss.
When called, pass the output membrane of shape [num_steps x batch_size x
num_outputs] and the target tensor of membrane potential.
The membrane potential and target are then applied to a Mean Square Error
Loss Function.
    This function is adapted from Spike-Op by Jason K. Eshraghian.
Example::
import snntorch.functional as SF
# if targets are the same at each time-step
loss_fn = mse_membrane_loss(time_var_targets=False)
loss = loss_fn(outputs, targets)
# if targets are time-varying
loss_fn = mse_membrane_loss(time_var_targets=True)
loss = loss_fn(outputs, targets)
:param time_var_targets: Specifies whether the targets are time-varying,
defaults to ``False``
    :type time_var_targets: bool, optional
:param on_target: Specify target membrane potential for correct class,
defaults to ``1``
:type on_target: float, optional
:param off_target: Specify target membrane potential for incorrect class,
defaults to ``0``
:type off_target: float, optional
:return: Loss
:rtype: torch.Tensor (single element)
"""
# to-do: add **kwargs to modify other keyword args in
# spikegen.targets_convert
def __init__(self, time_var_targets=False, on_target=1, off_target=0, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.time_var_targets = time_var_targets
self.on_target = on_target
self.off_target = off_target
self.__name__ = "mse_membrane_loss"
def _compute_loss(self, mem_out, targets):
device, num_steps, num_outputs = self._prediction_check(mem_out)
targets_spikes = spikegen.targets_convert(
targets,
num_classes=num_outputs,
on_target=self.on_target,
off_target=self.off_target,
)
loss_shape = mem_out[0].shape if self._intermediate_reduction() == 'none' else (1)
loss = torch.zeros(loss_shape, dtype=dtype, device=device)
loss_fn = nn.MSELoss(reduction=self._intermediate_reduction())
if self.time_var_targets:
for step in range(num_steps):
loss += loss_fn(mem_out[step], targets_spikes[step])
else:
for step in range(num_steps):
loss += loss_fn(mem_out[step], targets_spikes)
if self.weight is not None:
loss = loss * self.weight[targets]
return loss / num_steps
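
# Editor's sketch (illustrative, not part of the library): fixed
# (non-time-varying) membrane targets of 1 for the correct class and 0
# otherwise. The membrane trace below is a random assumption.
def _example_mse_membrane_loss():
    mem_out = torch.randn(25, 8, 10)  # T x B x N membrane potential
    targets = torch.randint(0, 10, (8,))
    loss_fn = mse_membrane_loss(on_target=1.0, off_target=0.0)
    return loss_fn(mem_out, targets)
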
# Uses a sign estimator: the derivative of the spike time is undefined for
# most neuron classes (e.g., Leaky), so it is approximated.
# For neurons with defined gradients, this still leads to an approximation.
# Labels are used as targets by default unless target_is_time=True.
class SpikeTime(nn.Module):
"""Used by ce_temporal_loss and mse_temporal_loss to convert spike
outputs into spike times."""
def __init__(
self,
target_is_time=False,
on_target=0,
off_target=-1,
tolerance=0,
multi_spike=False,
):
super().__init__()
self.target_is_time = target_is_time
self.tolerance = tolerance
self.tolerance_fn = self.Tolerance.apply
self.multi_spike = multi_spike
if not self.target_is_time:
self.on_target = on_target
self.off_target = off_target # override this with final step
# function used to extract the first F spike times. If
# multi_spike=False, F=1.
if self.multi_spike:
self.first_spike_fn = self.MultiSpike.apply
else:
self.first_spike_fn = self.FirstSpike.apply
# spiking output from final layer is a recording: T x B x N
# targets can either be labels or spike times
def forward(self, spk_out, targets):
self.device, num_steps, num_outputs = self._prediction_check(spk_out)
# convert labels to spike times
if not self.target_is_time:
targets = self.labels_to_spike_times(targets, num_outputs)
# convert negative spike times to time steps: -1 -->
# ( num_steps+ (-1) )
targets[targets < 0] = spk_out.size(0) + targets[targets < 0]
# now operating in the spike-time domain rather than with labels
# Consider merging multi-spike and single-spike?
# single-spike is faster, so keep them separate for now.
if self.multi_spike:
self.spike_count = targets.size(0)
spk_time_final = self.first_spike_fn(
spk_out, self.spike_count, self.device
) # spk_time_final here means the first spike time
else:
spk_time_final = self.first_spike_fn(spk_out, self.device)
# next need to check how tolerance copes with multi-spikes
if self.tolerance:
spk_time_final = self.tolerance_fn(
spk_time_final, targets, self.tolerance
)
return spk_time_final, targets
def _prediction_check(self, spk_out):
# device = "cpu"
# if spk_out.is_cuda:
# device = "cuda"
device = spk_out.device
num_steps = spk_out.size(0)
num_outputs = spk_out.size(-1)
return device, num_steps, num_outputs
class FirstSpike(torch.autograd.Function):
"""Convert spk_rec of 1/0s [TxBxN] --> first spike time [BxN].
Linearize df/dS=-1 if spike, 0 if no spike."""
@staticmethod
def forward(ctx, spk_rec, device="cpu"):
"""Convert spk_rec of 1/0s [TxBxN] --> spk_time [TxBxN].
0's indicate no spike --> +1 is first time step.
Transpose accounts for broadcasting along final dimension
(i.e., multiply along T)."""
spk_time = (
spk_rec.transpose(0, -1)
* (torch.arange(0, spk_rec.size(0)).detach().to(device) + 1)
).transpose(0, -1)
"""extact first spike time. Will be used to pass into loss
function."""
first_spike_time = torch.zeros_like(spk_time[0])
for step in range(spk_time.size(0)):
first_spike_time += (
spk_time[step] * ~first_spike_time.bool()
) # mask out subsequent spikes
"""override element 0 (no spike) with shadow spike @ final time
step, then offset by -1
s.t. first_spike is at t=0."""
first_spike_time += ~first_spike_time.bool() * (spk_time.size(0))
first_spike_time -= 1 # fix offset
ctx.save_for_backward(first_spike_time, spk_rec)
return first_spike_time
@staticmethod
def backward(ctx, grad_output):
(first_spike_time, spk_rec) = ctx.saved_tensors
spk_time_grad = torch.zeros_like(spk_rec) # T x B x N
"""spike extraction step/indexing @ each step is
non-differentiable.
Apply sign estimator by substituting gradient for -1 ONLY at
first spike time."""
for i in range(first_spike_time.size(0)):
for j in range(first_spike_time.size(1)):
spk_time_grad[first_spike_time[i, j].long(), i, j] = 1.0
grad = -grad_output * spk_time_grad
return grad, None
class MultiSpike(torch.autograd.Function):
"""Convert spk_rec of 1/0s [TxBxN] --> first F spike times [FxBxN].
Linearize df/dS=-1 if spike, 0 if no spike."""
@staticmethod
def forward(ctx, spk_rec, spk_count, device="cpu"):
spk_rec_tmp = spk_rec.clone()
spk_time_rec = []
for step in range(spk_count):
"""Convert spk_rec of 1/0s [TxBxN] --> spk_time [TxBxN].
0's indicate no spike --> +1 is first time step.
Transpose accounts for broadcasting along final dimension
(i.e., multiply along T)."""
spk_time = (
spk_rec_tmp.transpose(0, -1)
* (
torch.arange(0, spk_rec_tmp.size(0))
.detach()
.to(device)
+ 1
)
).transpose(0, -1)
"""extact n-th spike time (n=step) up to F."""
nth_spike_time = torch.zeros_like(spk_time[0])
                for t in range(spk_time.size(0)):
                    nth_spike_time += (
                        spk_time[t] * ~nth_spike_time.bool()
                    )  # mask out subsequent spikes
"""override element 0 (no spike) with shadow spike @ final
time step, then offset by -1
s.t. first_spike is at t=0."""
nth_spike_time += ~nth_spike_time.bool() * (
spk_time.size(0)
) # populate non-spiking with total size
nth_spike_time -= 1 # fix offset
spk_time_rec.append(nth_spike_time)
"""before looping, eliminate n-th spike. this avoids double
counting spikes."""
spk_rec_tmp[nth_spike_time.long()] = 0
"""Pass this into loss function."""
spk_time_rec = torch.stack(spk_time_rec)
ctx.save_for_backward(spk_time_rec, spk_rec)
return spk_time_rec
@staticmethod
def backward(ctx, grad_output):
(spk_time_final, spk_rec) = ctx.saved_tensors
spk_time_grad = torch.zeros_like(spk_rec) # T x B x N
"""spike extraction step/indexing @ each step is
non-differentiable.
Apply sign estimator by substituting gradient for -1 ONLY at
F-th spike time."""
for i in range(spk_time_final.size(0)):
for j in range(spk_time_final.size(1)):
for k in range(spk_time_final.size(2)):
spk_time_grad[
spk_time_final[i, j, k].long(), j, k
] = -grad_output[i, j, k]
grad = spk_time_grad
return grad, None, None
class Tolerance(torch.autograd.Function):
"""If spike time is 'close enough' to target spike within tolerance,
set the time to target for loss calc only."""
# TO-DO: remove ctx?
@staticmethod
def forward(ctx, spk_time, target, tolerance):
spk_time_clone = (
spk_time.clone()
) # spk_time_clone: BxN (FxBxN for multi-spike); target: TxBxN
spk_time_clone[torch.abs(spk_time - target) < tolerance] = (
torch.ones_like(spk_time) * target
)[torch.abs(spk_time - target) < tolerance]
return spk_time_clone
@staticmethod
def backward(ctx, grad_output):
grad = grad_output
return grad, None, None
def labels_to_spike_times(self, targets, num_outputs):
"""Convert index labels [B] into spike times."""
if not self.multi_spike:
targets = self.label_to_single_spike(targets, num_outputs)
# pass in labels --> output multiple spikes
# assumes on_target & off_target are iterable
else:
targets = self.label_to_multi_spike(targets, num_outputs)
return targets
def label_to_single_spike(self, targets, num_outputs):
"""Convert labels from neuron index (dim: B) to first spike time
(dim: B x N)."""
        # note: this was designed with on_target >> off_target in mind
targets = spikegen.targets_convert(
targets,
num_classes=num_outputs,
on_target=self.on_target,
off_target=self.off_target,
)
return targets
def label_to_multi_spike(self, targets, num_outputs):
"""Convert labels from neuron index (dim: B) to multiple spike times
(dim: F x B x N).
F is the number of spikes per neuron. Assumes target is iterable
along F."""
num_spikes_on = len(self.on_target)
num_spikes_off = len(self.off_target)
if num_spikes_on != num_spikes_off:
raise IndexError(
f"`on_target` (length: {num_spikes_on}) must have the same "
f"length as `off_target` (length: {num_spikes_off}."
)
# iterate through each spike
targets_rec = []
for step in range(num_spikes_on):
target_step = spikegen.targets_convert(
targets,
num_classes=num_outputs,
on_target=self.on_target[step],
off_target=self.off_target[step],
)
targets_rec.append(target_step)
targets_rec = torch.stack(targets_rec)
return targets_rec
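
# Editor's sketch (illustrative, not part of the library): ``SpikeTime``
# converts a [T x B x N] spike recording plus index labels into first-spike
# times and encoded spike-time targets. Shapes and targets are assumptions.
def _example_spike_time():
    spk_rec = (torch.rand(25, 8, 10) < 0.3).float()  # T x B x N spike record
    targets = torch.randint(0, 10, (8,))  # B integer class labels
    spk_time_fn = SpikeTime(target_is_time=False, on_target=0, off_target=-1)
    spk_time, target_time = spk_time_fn(spk_rec, targets)
    return spk_time, target_time  # both B x N spike times
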
class mse_temporal_loss:
"""Mean Square Error Temporal Loss.
The first spike time of each output neuron [batch_size x num_outputs] is
measured against the desired spike time with the Mean Square Error Loss
Function.
Note that the derivative of each spike time with respect to the spike
df/dU is non-differentiable for most neuron classes, and is set to a sign
estimator of -1.
I.e., increasing membrane potential causes a proportionately earlier
firing time.
The Mean Square Error Temporal Loss can account for multiple spikes by
setting ``multi_spike=True``.
If the actual spike time is close enough to the target spike time within
a given tolerance, e.g., ``tolerance = 5`` time steps, then it does not
contribute to the loss.
Index labels are passed as the target by default.
To enable passing in the spike time(s) for output neuron(s), set
``target_is_time=True``.
    Note: after the specified target spike time(s), no penalty is applied
    to subsequent spikes.
    To suppress later spikes, specify an additional spike-time target.
Example::
import torch
import snntorch.functional as SF
# default takes in idx labels as targets
# correct classes aimed to fire by default at t=0, incorrect at t=-1
        # (final time step)
loss_fn = mse_temporal_loss()
loss = loss_fn(spk_out, targets)
        # as above, but the correct class fires at t=5, incorrect at t=100,
        # with a tolerance of 2 steps
loss_fn = mse_temporal_loss(on_target=5, off_target=100, tolerance=2)
loss = loss_fn(spk_out, targets)
# as above with multiple spike time targets
        on_target = torch.tensor([5, 10])
        off_target = torch.tensor([100, 105])
        loss_fn = mse_temporal_loss(on_target=on_target,
        off_target=off_target, tolerance=2, multi_spike=True)
loss = loss_fn(spk_out, targets)
# specify first spike time for 5 neurons individually, zero tolerance
        target = torch.tensor([5, 10, 15, 20, 25])
loss_fn = mse_temporal_loss(target_is_time=True)
loss = loss_fn(spk_out, target)
:param target_is_time: Specify if target is specified as spike times
(True) or as neuron indexes (False). Defaults to ``False``
:type target_is_time: bool, optional
:param on_target: Spike time for correct classes
(only if target_is_time=False). Defaults to ``0``
:type on_target: int
    (or iterable over multiple int if ``multi_spike=True``), optional
:param off_target: Spike time for incorrect classes
(only if target_is_time=False).
Defaults to ``-1``, i.e., final time step
    :type off_target: int (or iterable over multiple int if
``multi_spike=True``), optional
:param tolerance: If the distance between the spike time and target is
less than the specified tolerance, then it does not contribute to the
loss. Defaults to ``0``.
:type tolerance: int, optional
:param multi_spike: Specify if multiple spikes in target. Defaults to
``False``
:type multi_spike: bool, optional
:return: Loss
:rtype: torch.Tensor (single element)
"""
def __init__(
self,
target_is_time=False,
on_target=0,
off_target=-1,
tolerance=0,
multi_spike=False,
reduction='mean',
weight=None
):
super().__init__()
self.reduction = reduction
self.weight = weight
self.loss_fn = nn.MSELoss(reduction=('none' if self.weight is not None else self.reduction))
self.spk_time_fn = SpikeTime(
target_is_time, on_target, off_target, tolerance, multi_spike
)
self.__name__ = "mse_temporal_loss"
def __call__(self, spk_rec, targets):
spk_time, target_time = self.spk_time_fn(
spk_rec, targets
) # return encoded targets
loss = self.loss_fn(
spk_time / spk_rec.size(0), target_time / spk_rec.size(0)
) # spk_time_final: num_spikes x B x Nc. # Same with targets.
if self.weight is not None:
loss = loss * self.weight[targets]
if self.reduction == 'mean':
loss = loss.mean()
return loss
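
# Editor's sketch (illustrative, not part of the library): the correct class
# is pushed to fire at t=0 and incorrect classes at the final step, with a
# tolerance of 2 steps. Shapes and settings are assumptions for the example.
def _example_mse_temporal_loss():
    spk_rec = (torch.rand(25, 8, 10) < 0.3).float()  # T x B x N spike record
    targets = torch.randint(0, 10, (8,))
    loss_fn = mse_temporal_loss(on_target=0, off_target=-1, tolerance=2)
    return loss_fn(spk_rec, targets)
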
class ce_temporal_loss:
"""Cross Entropy Temporal Loss.
The cross entropy loss of an 'inverted' first spike time of each output
neuron [batch_size x num_outputs] is calculated.
The 'inversion' is applied such that maximizing the value of the correct
class decreases the first spike time (i.e., earlier spike).
Options for inversion include: ``inverse='negate'`` which applies
(-1 * output), or ``inverse='reciprocal'`` which takes (1/output).
Note that the derivative of each spike time with respect to the spike
df/dU is non-differentiable for most neuron classes, and is set to a
sign estimator of -1.
I.e., increasing membrane potential causes a proportionately earlier
firing time.
Index labels are passed as the target. To specify the exact spike time,
use ``mse_temporal_loss`` instead.
    Note: after the specified target spike time(s), no penalty is applied
    to subsequent spikes.
Example::
import torch
import snntorch.functional as SF
# correct classes aimed to fire by default at t=0, incorrect at
        # the final step
loss_fn = ce_temporal_loss()
loss = loss_fn(spk_out, targets)
:param inverse: Specify how to invert output before taking cross
entropy. Either scale by (-1 * x) with ``inverse='negate'`` or take the
reciprocal (1/x) with ``inverse='reciprocal'``. Defaults to ``negate``
:type inverse: str, optional
:return: Loss
:rtype: torch.Tensor (single element)
"""
def __init__(self, inverse="negate", reduction='mean', weight=None):
super().__init__()
self.reduction = reduction
self.weight = weight
self.loss_fn = nn.CrossEntropyLoss(reduction=self.reduction, weight=self.weight)
self.spk_time_fn = SpikeTime(target_is_time=False)
self.inverse = inverse
self._ce_temporal_cases()
self.__name__ = "ce_temporal_loss"
def __call__(self, spk_rec, targets):
spk_time, _ = self.spk_time_fn(
spk_rec, targets
) # return encoded targets
if self.inverse == "negate":
spk_time = -spk_time
if self.inverse == "reciprocal":
spk_time = 1 / (spk_time + 1)
# loss = self.loss_fn(
# spk_time / spk_rec.size(0), targets / spk_rec.size(0)
# ) # spk_time_final: num_spikes x B x Nc. # Same with targets.
loss = self.loss_fn(
spk_time, targets
) # spk_time_final: num_spikes x B x Nc. # Same with targets.
return loss
def _ce_temporal_cases(self):
if self.inverse != "negate" and self.inverse != "reciprocal":
raise ValueError(
'`inverse` must be of type string containing either "negate" '
'or "reciprocal".'
)
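
# Editor's sketch (illustrative, not part of the library): with
# ``inverse='negate'`` the first-spike times are multiplied by -1 before
# cross entropy, so an earlier spike of the correct class lowers the loss.
# Shapes below are assumptions chosen for the example.
def _example_ce_temporal_loss():
    spk_rec = (torch.rand(25, 8, 10) < 0.3).float()  # T x B x N spike record
    targets = torch.randint(0, 10, (8,))
    loss_fn = ce_temporal_loss(inverse="negate")
    return loss_fn(spk_rec, targets)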