import torch
import numpy as np
[docs]
def accuracy_rate(spk_out, targets, population_code=False, num_classes=False):
"""Use spike count to measure accuracy.
:param spk_out: Output spikes of shape \
[num_steps x batch_size x num_outputs]
:type spk_out: torch.Tensor
:param targets: Target tensor (without one-hot-encoding) of shape \
[batch_size]
:type targets: torch.Tensor
:return: accuracy
:rtype: numpy.float64
"""
if population_code:
_, _, num_outputs = _prediction_check(spk_out)
_, idx = _population_code(spk_out, num_classes, num_outputs).max(1)
else:
_, idx = spk_out.sum(dim=0).max(1)
accuracy = np.mean((targets == idx).detach().cpu().numpy())
return accuracy
[docs]
def accuracy_temporal(spk_out, targets):
"""Use spike timing to measure accuracy.
:param spk_out: Output spikes of shape \
[num_steps x batch_size x num_outputs]
:type spk_out: torch.Tensor
:param targets: Target tensor (without one-hot-encoding) of shape \
[batch_size]
:type targets: torch.Tensor
:return: accuracy
:rtype: numpy.float64
"""
device, _, _ = _prediction_check(spk_out)
# convert spk_out into first spike
spk_time = (
spk_out.transpose(0, -1)
* (torch.arange(0, spk_out.size(0)).detach().to(device) + 1)
).transpose(0, -1)
# Extact first spike time.
# Used to pass into loss function.
first_spike_time = torch.zeros_like(spk_time[0])
for step in range(spk_time.size(0)):
first_spike_time += (
spk_time[step] * ~first_spike_time.bool()
) # mask out subsequent spikes
# Override element 0 (no spike) with shadow spike at final time step,
# then offset by -1 s.t. first_spike is at t=0
first_spike_time += ~first_spike_time.bool() * (spk_time.size(0))
first_spike_time -= 1 # fix offset
# take idx of torch.min, see if it matches targets
_, idx = first_spike_time.min(1)
accuracy = np.mean((targets == idx).detach().cpu().numpy())
return accuracy
def _prediction_check(spk_out):
device = spk_out.device
num_steps = spk_out.size(0)
num_outputs = spk_out.size(-1)
return device, num_steps, num_outputs
def _population_code(spk_out, num_classes, num_outputs):
"""Count up spikes sequentially from output classes."""
if not num_classes:
raise Exception(
"``num_classes`` must be specified if ``population_code=True``."
)
if num_outputs % num_classes:
raise Exception(
f"``num_outputs {num_outputs} must be a factor of num_classes "
f"{num_classes}."
)
# device = "cpu"
# if spk_out.is_cuda:
# device = "cuda"
device = spk_out.device
pop_code = torch.zeros(tuple([spk_out.size(1)] + [num_classes])).to(device)
for idx in range(num_classes):
pop_code[:, idx] = (
spk_out[
:,
:,
int(num_outputs * idx / num_classes) : int(
num_outputs * (idx + 1) / num_classes
),
]
.sum(-1)
.sum(0)
)
return pop_code