import torch
dtype = torch.float
def rate(
data,
num_steps=False,
gain=1,
offset=0,
first_spike_time=0,
time_var_input=False,
):
"""Spike rate encoding of input data. Convert tensor into Poisson spike
trains using the features as the mean of a
binomial distribution. If `num_steps` is specified, the data will first be
repeated along a new leading (time) dimension
before rate encoding.
If data is time-varying, tensor dimensions use time first.
Example::
# 100% chance of spike generation
a = torch.Tensor([1, 1, 1, 1])
spikegen.rate(a, num_steps=1)
>>> tensor([1., 1., 1., 1.])
# 0% chance of spike generation
b = torch.Tensor([0, 0, 0, 0])
spikegen.rate(b, num_steps=1)
>>> tensor([0., 0., 0., 0.])
# 50% chance of spike generation per time step
c = torch.Tensor([0.5, 0.5, 0.5, 0.5])
spikegen.rate(c, num_steps=1)
>>> tensor([0., 1., 0., 1.])
# Increasing num_steps will increase the length of
# the first dimension (time-first)
print(spikegen.rate(c, num_steps=1).size())
>>> torch.Size([1, 4])
d = spikegen.rate(torch.Tensor([0.5, 0.5, 0.5, 0.5]), num_steps = 2)
print(d.size())
>>> torch.Size([2, 4])
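# time-varying input (illustrative sketch): the tensor is treated as
# time-first and converted directly; ``first_spike_time`` prepends that
# many all-zero steps, so only the output size is deterministic here
e = torch.rand(3, 4)
spikegen.rate(e, time_var_input=True, first_spike_time=2).size()
>>> torch.Size([5, 4])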
:param data: Data tensor for a single batch of shape [batch x input_size]
:type data: torch.Tensor
:param num_steps: Number of time steps. Only specify if input data
does not already have time dimension, defaults to ``False``
:type num_steps: int, optional
:param gain: Scale input features by the gain, defaults to ``1``
:type gain: float, optional
:param offset: Shift input features by the offset, defaults to ``0``
:type offset: float, optional
:param first_spike_time: Time to first spike, defaults to ``0``.
:type first_spike_time: int, optional
:param time_var_input: Set to ``True`` if input tensor is time-varying.
Otherwise, `first_spike_time!=0` will modify the wrong dimension.
Defaults to ``False``
:type time_var_input: bool, optional
:return: rate encoding spike train of input features of shape
[num_steps x batch x input_size]
:rtype: torch.Tensor
"""
if first_spike_time < 0 or num_steps < 0:
raise Exception(
"``first_spike_time`` and ``num_steps`` cannot be negative."
)
if first_spike_time > (num_steps - 1):
if num_steps:
raise Exception(
f"first_spike_time ({first_spike_time}) must be equal to "
f"or less than num_steps-1 ({num_steps-1})."
)
if not time_var_input:
raise Exception(
"If the input data is time-varying, set "
"``time_var_input=True``.\n If the input data is not "
"time-varying, ensure ``num_steps > 0``."
)
if first_spike_time > 0 and not time_var_input and not num_steps:
raise Exception(
"``num_steps`` must be specified if both the input is not "
"time-varying and ``first_spike_time`` is greater than 0."
)
if time_var_input and num_steps:
raise Exception(
"``num_steps`` should not be specified if input is "
"time-varying, i.e., ``time_var_input=True``.\n "
"The first dimension of the input data + ``first_spike_time`` "
"will determine ``num_steps``."
)
device = data.device
# intended for time-varying input data
if time_var_input:
spike_data = rate_conv(data)
# zeros are added directly to the start of 0th (time) dimension
if first_spike_time > 0:
spike_data = torch.cat(
(
torch.zeros(
tuple([first_spike_time] + list(spike_data[0].size())),
device=device,
dtype=dtype,
),
spike_data,
)
)
# intended for time-static input data
else:
# Generate a tuple: (num_steps, 1..., 1) where the number of 1's
# = number of dimensions in the original data.
# Multiply by gain and add offset.
time_data = (
data.repeat(
tuple(
[num_steps]
+ torch.ones(len(data.size()), dtype=int).tolist()
)
)
* gain
+ offset
)
spike_data = rate_conv(time_data)
# the first ``first_spike_time`` steps of the 0th (time) dimension are zeroed out
if first_spike_time > 0:
spike_data[0:first_spike_time] = 0
return spike_data
def latency(
data,
num_steps=False,
threshold=0.01,
tau=1,
first_spike_time=0,
on_target=1,
off_target=0,
clip=False,
normalize=False,
linear=False,
interpolate=False,
bypass=False,
epsilon=1e-7,
):
"""Latency encoding of input or target label data. Use input features
to determine time-to-first spike. Expected inputs should be
between 0 and 1.
Assume a LIF neuron model that charges up with time constant tau.
Tensor dimensions use time first.
Example::
a = torch.Tensor([0.02, 0.5, 1])
spikegen.latency(a, num_steps=5, normalize=True, linear=True)
>>> tensor([[0., 0., 1.],
[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.],
[1., 0., 0.]])
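# raising ``threshold`` with ``clip=True`` removes spikes from features
# that fall below it (sketch; values assume the same linear,
# normalized code as above)
spikegen.latency(a, num_steps=5, normalize=True, linear=True,
threshold=0.1, clip=True)
>>> tensor([[0., 0., 1.],
[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.],
[0., 0., 0.]])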
:param data: Data tensor for a single batch of shape [batch x input_size]
:type data: torch.Tensor
:param num_steps: Number of time steps. Explicitly needed if
``normalize=True``, defaults to ``False`` (then changed to ``1``
if ``normalize=False``)
:type num_steps: int, optional
:param threshold: Input features below the threshold will fire at the
final time step unless ``clip=True`` in which case they will not
fire at all, defaults to ``0.01``
:type threshold: float, optional
:param tau: RC Time constant for LIF model used to calculate
firing time, defaults to ``1``
:type tau: float, optional
:param first_spike_time: Time to first spike, defaults to ``0``.
:type first_spike_time: int, optional
:param on_target: Target at spike times, defaults to ``1``
:type on_target: float, optional
:param off_target: Target during refractory period, defaults to ``0``
:type off_target: float, optional
:param clip: Option to remove spikes from features that fall
below the threshold, defaults to ``False``
:type clip: Bool, optional
:param normalize: Option to normalize the latency code such
that the final spike(s) occur within num_steps, defaults to ``False``
:type normalize: Bool, optional
:param linear: Apply a linear latency code rather than the
default logarithmic code, defaults to ``False``
:type linear: Bool, optional
:param interpolate: Applies linear interpolation such that
there is a gradually increasing target up to each spike, defaults to
``False``
:type interpolate: Bool, optional
:param bypass: Used to block error messages that occur from either: i)
spike times exceeding the bounds of ``num_steps``, or ii) if
``num_steps`` is not specified, setting ``bypass=True``
allows the largest spike time to set ``num_steps``. Defaults to
``False``
:type bypass: bool, optional
:param epsilon: A tiny positive value to avoid rounding errors
when using torch.arange, defaults to ``1e-7``
:type epsilon: float, optional
:return: latency encoding spike train of features or labels
:rtype: torch.Tensor
"""
if torch.min(data) < 0 or torch.max(data) > 1:
raise Exception(
f"Elements of ``data`` must be between [0, 1], but input "
f"is [{torch.min(data)}, {torch.max(data)}]"
)
if threshold < 0 or threshold > 1:
raise Exception(f"``threshold`` [{threshold}] must be between [0, 1]")
if not num_steps and not bypass:
raise Exception(
"``num_steps`` must be specified. Alternatively, setting "
"``bypass=True`` will automatically set ``num_steps`` "
"to the last spike time. This may lead to uneven tensor "
"sizes when used in a loop."
)
device = data.device
spike_time, idx = latency_code(
data,
num_steps=num_steps,
threshold=threshold,
tau=tau,
first_spike_time=first_spike_time,
normalize=normalize,
linear=linear,
epsilon=epsilon,
)
# automatically set num_steps using max element in spike_time
if not num_steps and bypass:
num_steps = int(torch.round(torch.max(spike_time)).long() + 1)
if num_steps <= 0:
raise Exception(
f"``num_steps`` [{num_steps}] must be positive. "
f"Either specify ``num_steps`` directly, normalize the input "
f"data to larger values, or set ``threshold`` to a "
f"smaller value."
)
if (
torch.round(torch.max(spike_time)).long() > (num_steps - 1)
and not bypass
):
raise Exception(
f"The maximum value in ``spike_time`` "
f"[{torch.round(torch.max(spike_time)).long()}] is out of "
f"bounds for ``num_steps`` [{num_steps}-1].\n To bypass "
f"this error, set ``bypass=True``.\n Alternatively, constrain "
f"``spike_time`` within the range of ``num_steps`` "
f"by either decreasing ``tau`` or setting ``normalize=True``."
)
if not interpolate:
spike_data = torch.zeros(
(tuple([num_steps] + list(spike_time.size()))),
dtype=dtype,
device=device,
)
# use rm_idx to remove spikes beyond the range of num_steps
rm_idx = torch.round(spike_time).long() > num_steps - 1
spike_data = (
spike_data.scatter(
0,
torch.round(torch.clamp_max(spike_time, num_steps - 1))
.long()
.unsqueeze(0),
1,
)
* ~rm_idx
)
# Use idx to remove spikes below the threshold
if clip:
spike_data = spike_data * ~idx # idx is broadcast in T direction
return torch.clamp(spike_data * on_target, off_target)
elif interpolate:
return latency_interpolate(
spike_time, num_steps, on_target=on_target, off_target=off_target
)
def delta(
data,
threshold=0.1,
padding=False,
off_spike=False,
):
"""Generate spike only when the difference between two subsequent
time steps meets a threshold.
Optionally include off_spikes for negative changes.
Example::
a = torch.Tensor([1, 2, 2.9, 3, 3.9])
spikegen.delta(a, threshold=1)
>>> tensor([1., 1., 0., 0., 0.])
spikegen.delta(a, threshold=1, padding=True)
>>> tensor([0., 1., 0., 0., 0.])
b = torch.Tensor([1, 2, 0, 2, 2.9])
spikegen.delta(b, threshold=1, off_spike=True)
>>> tensor([ 1., 1., -1., 1., 0.])
spikegen.delta(b, threshold=1, padding=True, off_spike=True)
>>> tensor([ 0., 1., -1., 1., 0.])
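# delta is applied along the leading (time) dimension, so multi-channel
# input keeps its shape (illustrative sketch)
c = torch.Tensor([[1, 1], [2, 1.05]])
spikegen.delta(c, threshold=0.5)
>>> tensor([[1., 1.],
[1., 0.]])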
:param data: Data tensor for a single batch of shape [num_steps x batch
x input_size]
:type data: torch.Tensor
:param threshold: Input features with a change greater than the threshold
across one time step will generate a spike, defaults to ``0.1``
:type threshold: float, optional
:param padding: Used to change how the first time step of spikes is
measured. If ``True``, the first time step will be repeated with itself
resulting in ``0``'s for the output spikes.
If ``False``, the first time step will be padded with ``0``'s, defaults
to ``False``
:type padding: bool, optional
:param off_spike: If ``True``, negative spikes for changes less than
``-threshold``, defaults to ``False``
:type off_spike: bool, optional
"""
if padding:
data_offset = torch.cat((data[0].unsqueeze(0), data))[
:-1
] # duplicate first time step, remove final step
else:
data_offset = torch.cat(
(torch.zeros_like(data[0]).unsqueeze(0), data)
)[
:-1
] # add 0's to first step, remove final step
if not off_spike:
return torch.ones_like(data) * ((data - data_offset) >= threshold)
else:
on_spk = torch.ones_like(data) * ((data - data_offset) >= threshold)
off_spk = -torch.ones_like(data) * ((data - data_offset) <= -threshold)
return on_spk + off_spk
def rate_conv(data):
"""Convert tensor into Poisson spike trains using the features as
the mean of a binomial distribution.
Values outside the range of [0, 1] are clipped so they can be
treated as probabilities.
Example::
# 100% chance of spike generation
a = torch.Tensor([1, 1, 1, 1])
spikegen.rate_conv(a)
>>> tensor([1., 1., 1., 1.])
# 0% chance of spike generation
b = torch.Tensor([0, 0, 0, 0])
spikegen.rate_conv(b)
>>> tensor([0., 0., 0., 0.])
# 50% chance of spike generation per time step
c = torch.Tensor([0.5, 0.5, 0.5, 0.5])
spikegen.rate_conv(c)
>>> tensor([0., 1., 0., 1.])
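# values outside [0, 1] are clipped before sampling, so out-of-range
# features become certain spikes or certain silence (sketch)
d = torch.Tensor([-0.3, 1.5])
spikegen.rate_conv(d)
>>> tensor([0., 1.])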
:param data: Data tensor for a single batch of shape [batch x input_size]
:type data: torch.Tensor
:return: rate encoding spike train of input features of shape
[num_steps x batch x input_size]
:rtype: torch.Tensor
"""
# Clip all features between 0 and 1 so they can be used as
# probabilities.
clipped_data = torch.clamp(data, min=0, max=1)
# pass time_data matrix into bernoulli function.
spike_data = torch.bernoulli(clipped_data)
return spike_data
def latency_code(
data,
num_steps=False,
threshold=0.01,
tau=1,
first_spike_time=0,
normalize=False,
linear=False,
epsilon=1e-7,
):
"""Latency encoding of input data. Convert input features or
target labels to spike times. Assumes a LIF neuron model
that charges up with time constant tau by default.
Example::
a = torch.Tensor([0.02, 0.5, 1])
spikegen.latency_code(a, num_steps=5, normalize=True, linear=True)
>>> (tensor([3.9200, 2.0000, 0.0000]), tensor([False, False, False]))
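# the second return value flags features below ``threshold``; ``latency``
# uses it for clipping (sketch with a raised threshold)
spikegen.latency_code(a, num_steps=5, normalize=True, linear=True,
threshold=0.1)[1]
>>> tensor([ True, False, False])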
:param data: Data tensor for a single batch of shape [batch x input_size]
:type data: torch.Tensor
:param num_steps: Number of time steps. Explicitly needed if
``normalize=True``, defaults to ``False`` (then changed to ``1``
if ``normalize=False``)
:type num_steps: int, optional
:param threshold: Input features below the threshold will fire at
the final time step unless ``clip=True`` in which case they will
not fire at all, defaults to ``0.01``
:type threshold: float, optional
:param tau: RC Time constant for LIF model used to calculate
firing time, defaults to ``1``
:type tau: float, optional
:param first_spike_time: Time to first spike, defaults to ``0``.
:type first_spike_time: int, optional
:param normalize: Option to normalize the latency code such
that the final spike(s) occur within num_steps, defaults to ``False``
:type normalize: Bool, optional
:param linear: Apply a linear latency code rather than the
default logarithmic code, defaults to ``False``
:type linear: Bool, optional
:param epsilon: A tiny positive value to avoid rounding errors
when using torch.arange, defaults to ``1e-7``
:type epsilon: float, optional
:return: latency encoding spike times of features
:rtype: torch.Tensor
:return: Tensor of Boolean values which correspond to the
latency encoding elements that fall below the threshold.
Used in ``latency_conv`` to clip saturated spikes.
:rtype: torch.Tensor
"""
idx = data < threshold
if not linear:
spike_time = latency_code_log(
data,
num_steps=num_steps,
threshold=threshold,
tau=tau,
first_spike_time=first_spike_time,
normalize=normalize,
epsilon=epsilon,
)
elif linear:
spike_time = latency_code_linear(
data,
num_steps=num_steps,
threshold=threshold,
tau=tau,
first_spike_time=first_spike_time,
normalize=normalize,
)
return spike_time, idx
def latency_code_linear(
data,
num_steps=False,
threshold=0.01,
tau=1,
first_spike_time=0,
normalize=False,
):
"""Linear latency encoding of input data. Convert input features
or target labels to spike times.
Example::
a = torch.Tensor([0.02, 0.5, 1])
spikegen.latency_code(a, num_steps=5, normalize=True, linear=True)
>>> (tensor([3.9200, 2.0000, 0.0000]), tensor([False, False, False]))
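# calling the linear code directly returns spike times only
# (sketch using the default ``tau=1`` and ``threshold=0.01``)
spikegen.latency_code_linear(a)
>>> tensor([0.9800, 0.5000, 0.0000])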
:param data: Data tensor for a single batch of shape [batch x input_size]
:type data: torch.Tensor
:param num_steps: Number of time steps. Explicitly needed if
``normalize=True``, defaults to ``False``
(then changed to ``1`` if ``normalize=False``)
:type num_steps: int, optional
:param threshold: Input features below the threshold will
fire at the final time step, defaults to ``0.01``
:type threshold: float, optional
:param tau: Linear time constant used to calculate firing time,
defaults to ``1``
:type tau: float, optional
:param first_spike_time: Time to first spike, defaults to ``0``.
:type first_spike_time: int, optional
:param normalize: Option to normalize the latency code such that
the final spike(s) occur within num_steps, defaults to ``False``
:type normalize: Bool, optional
:return: linear latency encoding spike times of features
:rtype: torch.Tensor
"""
_latency_errors(
data, num_steps, threshold, tau, first_spike_time, normalize
) # error checks
if normalize:
tau = num_steps - 1 - first_spike_time
spike_time = (
torch.clamp_max((-tau * (data - 1)), -tau * (threshold - 1))
) + first_spike_time
# the following code is intended for negative input data.
# it is more broadly caught in latency code by ensuring 0 < data < 1.
# Consider disabling ~(0<data<1) input.
if torch.min(spike_time) < 0 and normalize:
spike_time = (
(spike_time - torch.min(spike_time))
* (1 / (torch.max(spike_time) - torch.min(spike_time)))
* (num_steps - 1)
)
return spike_time
def latency_code_log(
data,
num_steps=False,
threshold=0.01,
tau=1,
first_spike_time=0,
normalize=False,
epsilon=1e-7,
):
"""Logarithmic latency encoding of input data. Convert input features
or target labels to spike times.
Example::
a = torch.Tensor([0.02, 0.5, 1])
spikegen.latency_code(a, num_steps=5, normalize=True)
>>> (tensor([4.0000, 0.1166, 0.0580]), tensor([False, False, False]))
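Without normalization, the firing time follows
``spike_time = tau * log(data / (data - threshold))``, so larger
features fire earlier; with the defaults above, ``data=1`` fires at
roughly ``0.01`` and ``data=0.02`` at roughly ``0.69`` (approximate values).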
:param data: Data tensor for a single batch of shape [batch x input_size]
:type data: torch.Tensor
:param num_steps: Number of time steps. Explicitly needed if
``normalize=True``, defaults to ``False`` (then changed to ``1`` if
``normalize=False``)
:type num_steps: int, optional
:param threshold: Input features below the threshold will fire at the
final time step, defaults to ``0.01``
:type threshold: float, optional
:param tau: Logarithmic time constant used to calculate firing time,
defaults to ``1``
:type tau: float, optional
:param first_spike_time: Time to first spike, defaults to ``0``.
:type first_spike_time: int, optional
:param normalize: Option to normalize the latency code such that
the final spike(s) occur within num_steps, defaults to ``False``
:type normalize: Bool, optional
:param epsilon: A tiny positive value to avoid rounding errors when
using torch.arange, defaults to ``1e-7``
:type epsilon: float, optional
:return: logarithmic latency encoding spike times of features
:rtype: torch.Tensor
"""
_latency_errors(
data, num_steps, threshold, tau, first_spike_time, normalize
) # error checks
data = torch.clamp(
data, threshold + epsilon
) # saturates all values below threshold.
spike_time = tau * torch.log(data / (data - threshold))
if first_spike_time > 0:
spike_time += first_spike_time
if normalize:
spike_time = (spike_time - first_spike_time) * (
num_steps - first_spike_time - 1
) / torch.max(spike_time - first_spike_time) + first_spike_time
return spike_time
def _latency_errors(
data, num_steps, threshold, tau, first_spike_time, normalize
):
"""Catch errors for spike time encoding latency functions
``latency_code_linear`` and ``latency_code_log``"""
if (
threshold <= 0 or threshold >= 1
): # double check if this can just be threshold < 0 instead.
raise Exception("Threshold must be between 0 and 1.")
if tau <= 0: # double check if this can just be tau < 0 instead.
raise Exception("``tau`` must be greater than 0.")
if first_spike_time and num_steps and first_spike_time > (num_steps - 1):
raise Exception(
f"first_spike_time ({first_spike_time}) must be equal to "
f"or less than num_steps-1 ({num_steps-1})."
)
# this condition is more broadly caught in latency code by ensuring 0
# < data < 1
if first_spike_time and (torch.max(data) > 1 or torch.min(data) < 0):
raise Exception(
"`first_spike_time` can only be applied to data between "
"`0` and `1`."
)
if first_spike_time < 0:
raise Exception(
f"``first_spike_time`` [{first_spike_time}] cannot be negative."
)
if num_steps < 0:
raise Exception(f"``num_steps`` [{num_steps}] cannot be negative.")
if normalize and not num_steps:
raise Exception(
"`num_steps` should not be empty if normalize is set to True."
)
def targets_convert(
targets,
num_classes,
code="rate",
num_steps=False,
first_spike_time=0,
correct_rate=1,
incorrect_rate=0,
on_target=1,
off_target=0,
firing_pattern="regular",
interpolate=False,
epsilon=1e-7,
threshold=0.01,
tau=1,
clip=False,
normalize=False,
linear=False,
bypass=False,
):
"""Spike encoding of targets. Expected input is a 1-D tensor with
index of targets.
If the output tensor is time-varying, the returned tensor will
have time in the first dimension.
If it is not time-varying, then the returned tensor will omit
the time dimension and use batch first.
The following arguments will necessarily incur a time-varying output:
``code='latency'``, ``first_spike_time!=0``, ``correct_rate!=1``,
or ``incorrect_rate!=0``
The target output may be applied to the internal state (e.g., membrane)
of the neuron or to the spike.
The following arguments will produce an output tensor that may
sensibly be applied as a target to either the output spike or the
membrane potential, as the output will consistently be either a
`1` or `0`:
``on_target=1``, ``off_target=0``, and ``interpolate=False``
If any of the above 3 conditions do not hold, then the target is
better suited for the output membrane potential, as it will likely
include values other than `1` and `0`.
Example::
a = torch.Tensor([4])
# rate-coding
# one-hot
spikegen.targets_convert(a, num_classes=5, code="rate")
>>> (tensor([[0., 0., 0., 0., 1.]]), )
# one-hot + time-first
spikegen.targets_convert(a, num_classes=5, code="rate",
correct_rate=0.8, incorrect_rate=0.2, num_steps=5).size()
>>> torch.Size([5, 1, 5])
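# latency-coding: the target index is one-hot encoded and passed to
# spikegen.latency (sketch; only the output size is shown here)
spikegen.targets_convert(a, num_classes=5, code="latency",
num_steps=5, normalize=True).size()
>>> torch.Size([5, 1, 5])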
For more examples of rate-coding, see
``help(snntorch.spikegen.targets_rate)``.
:param targets: Target tensor for a single batch.
The target should be a class index in the range [0, C-1]
where C=number of classes.
:type targets: torch.Tensor
:param num_classes: Number of outputs.
:type num_classes: int
:param code: Encoding scheme. Options of ``'rate'`` or
``'latency'``, defaults to ``'rate'``
:type code: string, optional
:param num_steps: Number of time steps, defaults to ``False``
:type num_steps: int, optional
:param first_spike_time: Time step for first spike to occur,
defaults to ``0``
:type first_spike_time: int, optional
:param correct_rate: Firing frequency of correct class as a ratio,
e.g., ``1`` enables firing at every step; ``0.5`` enables firing at
50% of steps, ``0`` means no firing, defaults to ``1``
:type correct_rate: float, optional
:param incorrect_rate: Firing frequency of incorrect class(es),
e.g., ``1`` enables firing at every step; ``0.5`` enables firing at
50% of steps, ``0`` means no firing, defaults to ``0``
:type incorrect_rate: float, optional
:param on_target: Target at spike times, defaults to ``1``
:type on_target: float, optional
:param off_target: Target during refractory period, defaults to ``0``
:type off_target: float, optional
:param firing_pattern: Firing pattern of correct and incorrect classes.
``'regular'`` enables periodic firing, ``'uniform'`` samples spike
times from a uniform distributions (duplicates are removed),
``'poisson'`` samples from a binomial distribution at each step where
each probability is the firing frequency, defaults to ``'regular'``
:type firing_pattern: string, optional
:param interpolate: Applies linear interpolation such that there is
a gradually increasing target up to each spike, defaults to ``False``
:type interpolate: Bool, optional
:param epsilon: A tiny positive value to avoid rounding errors when
using torch.arange, defaults to ``1e-7``
:type epsilon: float, optional
:param bypass: Used to block error messages that occur from either: i)
spike times exceeding the bounds of ``num_steps``, or ii) if
``num_steps`` is not specified, setting ``bypass=True`` allows
the largest spike time to set ``num_steps``. Defaults to ``False``
:type bypass: bool, optional
:return: spike coded target of output neurons. If targets are
time-varying, the output tensor will use time-first dimensions.
Otherwise, time is omitted from the tensor.
:rtype: torch.Tensor
"""
# raise exceptions if num_steps is not supplied, and rates have
# been specified, or if latency is specified.
if code == "rate":
return targets_rate(
targets=targets,
num_classes=num_classes,
num_steps=num_steps,
first_spike_time=first_spike_time,
correct_rate=correct_rate,
incorrect_rate=incorrect_rate,
on_target=on_target,
off_target=off_target,
firing_pattern=firing_pattern,
interpolate=interpolate,
epsilon=epsilon,
)
# do we need num_steps
elif code == "latency":
return targets_latency(
targets,
num_classes,
num_steps=num_steps,
threshold=threshold,
tau=tau,
first_spike_time=first_spike_time,
clip=clip,
normalize=normalize,
linear=linear,
bypass=bypass,
)
else:
raise Exception(f"code ['{code}'] must be either 'rate' or 'latency'")
def targets_rate(
targets,
num_classes,
num_steps=False,
first_spike_time=0,
correct_rate=1,
incorrect_rate=0,
on_target=1,
off_target=0,
firing_pattern="regular",
interpolate=False,
epsilon=1e-7,
):
"""Spike rate encoding of targets. Input tensor must be one-dimensional
with target indexes.
If the output tensor is time-varying, the returned tensor will have
time in the first dimension.
If it is not time-varying, then the returned tensor will omit the time
dimension and use batch first.
If ``first_spike_time!=0``, ``correct_rate!=1``, or ``incorrect_rate!=0``,
the output tensor will be time-varying.
If ``on_target=1``, ``off_target=0``, and ``interpolate=False``,
then the target may sensibly be applied as a target for the output spike.
If any of the above 3 conditions do not hold, then the target would
be better suited for the output membrane potential.
Example::
a = torch.Tensor([4])
# one-hot
spikegen.targets_rate(a, num_classes=5)
>>> (tensor([[0., 0., 0., 0., 1.]]), )
# first spike time delay, spike evolution over time
spikegen.targets_rate(a, num_classes=5, num_steps=5,
first_spike_time=2).size()
>>> torch.Size([5, 1, 5])
spikegen.targets_rate(a, num_classes=5, num_steps=5,
first_spike_time=2)[:, 0, 4]
>>> (tensor([0., 0., 1., 1., 1.]))
# note: time has not been repeated because every time step
# would be identical where first_spike_time defaults to 0
spikegen.targets_rate(a, num_classes=5, num_steps=5).size()
>>> torch.Size([1, 5])
# on/off targets - membrane evolution over time
spikegen.targets_rate(a, num_classes=5, num_steps=5,
first_spike_time=2, on_target=1.2, off_target=0.5)[:, 0, 4]
>>> (tensor([0.5000, 0.5000, 1.2000, 1.2000, 1.2000]))
# correct rate at 25% + linear interpolation of membrane evolution
spikegen.targets_rate(a, num_classes=5, num_steps=5,
correct_rate=0.25, on_target=1.2,
off_target=0.5, interpolate=True)[:, 0, 4]
>>> tensor([1.2000, 0.5000, 0.7333, 0.9667, 1.2000])
:param targets: Target tensor for a single batch. The target
should be a class index in the range [0, C-1]
where C=number of classes.
:type targets: torch.Tensor
:param num_classes: Number of outputs.
:type num_classes: int
:param num_steps: Number of time steps, defaults to ``False``
:type num_steps: int, optional
:param first_spike_time: Time step for first spike to occur,
defaults to ``0``
:type first_spike_time: int, optional
:param correct_rate: Firing frequency of correct class as a
ratio, e.g., ``1`` enables firing at every step; ``0.5``
enables firing at 50% of steps, ``0`` means no firing,
defaults to ``1``
:type correct_rate: float, optional
:param incorrect_rate: Firing frequency of incorrect class(es), e.g.,
``1`` enables firing at every step; ``0.5``
enables firing at 50% of steps,
``0`` means no firing, defaults to ``0``
:type incorrect_rate: float, optional
:param on_target: Target at spike times, defaults to ``1``
:type on_target: float, optional
:param off_target: Target during refractory period, defaults to ``0``
:type off_target: float, optional
:param firing_pattern: Firing pattern of correct and incorrect classes.
``'regular'`` enables periodic firing, ``'uniform'`` samples spike
times from a uniform distributions (duplicates are removed),
``'poisson'`` samples from a binomial distribution at each step
where each probability is the firing frequency,
defaults to ``'regular'``
:type firing_pattern: string, optional
:param interpolate: Applies linear interpolation such that there
is a gradually increasing target
up to each spike, defaults to ``False``
:type interpolate: Bool, optional
:param epsilon: A tiny positive value to avoid rounding errors when
using torch.arange, defaults to ``1e-7``
:type epsilon: float, optional
:return: rate coded target of output neurons. If targets are
time-varying, the output tensor will use time-first dimensions.
Otherwise, time is omitted from the tensor.
:rtype: torch.Tensor
"""
if not 0 <= correct_rate <= 1 or not 0 <= incorrect_rate <= 1:
raise Exception(
f"``correct_rate``{correct_rate} and "
f"``incorrect_rate``{incorrect_rate} must be between 0 and 1."
)
if not num_steps and (correct_rate != 1 or incorrect_rate != 0):
raise Exception(
"``num_steps`` must be passed if correct_rate is not 1 or "
"incorrect_rate is not 0."
)
if incorrect_rate > correct_rate:
raise Exception(
"``correct_rate`` must be greater than ``incorrect_rate``."
)
if firing_pattern.lower() not in ["regular", "uniform", "poisson"]:
raise Exception(
"``firing_pattern`` must be either 'regular', 'uniform' or "
"'poisson'."
)
device = targets.device
# return a non time-varying tensor
if correct_rate == 1 and incorrect_rate == 0:
if first_spike_time == 0:
if on_target > off_target:
return torch.clamp(
to_one_hot(targets, num_classes) * on_target, off_target
)
else:
return (
to_one_hot(targets, num_classes) * on_target
+ ~(to_one_hot(targets, num_classes)).bool() * off_target
)
# return time-varying tensor: off up to first_spike_time,
# then correct classes are on after
if first_spike_time > 0:
spike_targets = torch.clamp(
to_one_hot(targets, num_classes) * on_target, off_target
)
spike_targets = spike_targets.repeat(
tuple(
[num_steps]
+ torch.ones(len(spike_targets.size()), dtype=int).tolist()
)
)
spike_targets[0:first_spike_time] = off_target
return spike_targets
# executes if on/off firing rates are not 100% / 0%
else:
one_hot_targets = to_one_hot(targets, num_classes)
one_hot_inverse = to_one_hot_inverse(one_hot_targets)
# project one-hot-encodings along the time-axis (0th dim)
one_hot_targets = one_hot_targets.repeat(
tuple(
[num_steps]
+ torch.ones(len(one_hot_targets.size()), dtype=int).tolist()
)
)
one_hot_inverse = one_hot_inverse.repeat(
tuple(
[num_steps]
+ torch.ones(len(one_hot_inverse.size()), dtype=int).tolist()
)
)
# create tensor of spike_targets for correct class
correct_spike_targets, correct_spike_times = target_rate_code(
num_steps=num_steps,
first_spike_time=first_spike_time,
rate=correct_rate,
firing_pattern=firing_pattern,
)
correct_spikes_one_hot = one_hot_targets * correct_spike_targets.to(
device
).unsqueeze(-1).unsqueeze(
-1
) # the two unsqueezes make the dims of correct_spikes
# num_steps x 1 x 1, s.t. time is broadcast in every other direction
# create tensor of spike targets for incorrect class
incorrect_spike_targets, incorrect_spike_times = target_rate_code(
num_steps=num_steps,
first_spike_time=first_spike_time,
rate=incorrect_rate,
firing_pattern=firing_pattern,
)
incorrect_spikes_one_hot = (
(one_hot_inverse * incorrect_spike_targets)
.to(device)
.unsqueeze(-1)
.unsqueeze(-1)
) # the two unsqueezes make the dims of incorrect_spikes
# num_steps x 1 x 1, s.t. time is broadcast in every other direction
# merge the incorrect and correct tensors
if not interpolate:
return torch.clamp(
(
incorrect_spikes_one_hot.to(device)
+ correct_spikes_one_hot.to(device)
)
* on_target,
off_target,
)
# interpolate values between spikes
else:
correct_spike_targets = one_hot_targets * (
rate_interpolate(
correct_spike_times,
num_steps=num_steps,
on_target=on_target,
off_target=off_target,
epsilon=epsilon,
)
.to(device)
.unsqueeze(-1)
.unsqueeze(-1)
) # the two unsqueezes make the dims of correct_spikes
# num_steps x 1 x 1, s.t. time is broadcast in every
# other direction
incorrect_spike_targets = one_hot_inverse * (
rate_interpolate(
incorrect_spike_times,
num_steps=num_steps,
on_target=on_target,
off_target=off_target,
epsilon=epsilon,
)
.to(device)
.unsqueeze(-1)
.unsqueeze(-1)
)
return correct_spike_targets + incorrect_spike_targets
def target_rate_code(
num_steps, first_spike_time=0, rate=1, firing_pattern="regular"
):
"""
Rate code a single output neuron. Returns a tensor of length ``num_steps``
containing spikes, and another tensor containing the spike times.
Example::
spikegen.target_rate_code(num_steps=5, rate=1)
>>> (tensor([1., 1., 1., 1., 1.]), tensor([0, 1, 2, 3, 4]))
spikegen.target_rate_code(num_steps=5, first_spike_time=3, rate=1)
>>> (tensor([0., 0., 0., 1., 1.]), tensor([3, 4]))
spikegen.target_rate_code(num_steps=5, rate=0.3)
>>> (tensor([1., 0., 0., 1., 0.]), tensor([0, 3]))
spikegen.target_rate_code(
num_steps=5, rate=0.3, firing_pattern="poisson")
>>> (tensor([0., 1., 0., 1., 0.]), tensor([1, 3]))
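# the default 'regular' pattern spaces spikes evenly at intervals of
# 1/rate steps (deterministic sketch)
spikegen.target_rate_code(num_steps=6, rate=0.5)
>>> (tensor([1., 0., 1., 0., 1., 0.]), tensor([0, 2, 4]))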
:param num_steps: Number of time steps
:type num_steps: int
:param first_spike_time: Time step for first spike to occur,
defaults to ``0``
:type first_spike_time: int, optional
:param rate: Firing frequency as a ratio, e.g., ``1``
enables firing at every step; ``0.5`` enables firing at 50% of steps,
``0`` means no firing, defaults to ``1``
:type rate: float, optional
:param firing_pattern: Firing pattern of correct and
incorrect classes. ``'regular'`` enables periodic firing,
``'uniform'`` samples spike times from a uniform distributions
(duplicates are removed), ``'poisson'`` samples from a binomial
distribution at each step where each probability
is the firing frequency,
defaults to ``'regular'``
:type firing_pattern: string, optional
:return: rate coded target of single neuron class of length ``num_steps``
:rtype: torch.Tensor
:return: rate coded spike times in terms of steps
:rtype: torch.Tensor
"""
if not 0 <= rate <= 1:
raise Exception(f"``rate``{rate} must be between 0 and 1.")
if first_spike_time > num_steps:
raise Exception(
f"``first_spike_time`` ({first_spike_time}) must be less "
f"than ``num_steps`` ({num_steps})."
)
if rate == 0:
return torch.zeros(num_steps), torch.Tensor()
if firing_pattern.lower() == "regular":
spike_times = torch.arange(first_spike_time, num_steps, 1 / rate)
return (
torch.zeros(num_steps).scatter(0, spike_times.long(), 1),
spike_times.long(),
)
elif firing_pattern.lower() == "uniform":
spike_times = (
torch.rand(
len(torch.arange(first_spike_time, num_steps, 1 / rate))
)
* (num_steps - first_spike_time)
+ first_spike_time
)
return (
torch.zeros(num_steps).scatter(0, spike_times.long(), 1),
spike_times.long(),
)
elif firing_pattern.lower() == "poisson":
spike_targets = torch.bernoulli(
torch.cat(
(
# torch.zeros((first_spike_time), device=device),
# torch.ones((num_steps - first_spike_time),
# device=device) * rate,
torch.zeros((first_spike_time)),
torch.ones((num_steps - first_spike_time)) * rate,
)
)
)
return spike_targets, torch.where(spike_targets == 1)[0]
def rate_interpolate(
spike_time, num_steps, on_target=1, off_target=0, epsilon=1e-7
):
"""Apply linear interpolation to a tensor of target spike times to
enable a gradually increasing membrane target.
Example::
a = torch.Tensor([0, 4])
spikegen.rate_interpolate(a, num_steps=5)
>>> tensor([1.0000, 0.0000, 0.3333, 0.6667, 1.0000])
spikegen.rate_interpolate(a, num_steps=5, on_target=1.25,
off_target=0.25)
>>> tensor([1.2500, 0.2500, 0.5833, 0.9167, 1.2500])
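# spikes at non-adjacent steps produce a linear ramp from ``off_target``
# up to ``on_target`` ahead of each spike (sketch)
b = torch.Tensor([2, 4])
spikegen.rate_interpolate(b, num_steps=5)
>>> tensor([0.0000, 0.5000, 1.0000, 0.0000, 1.0000])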
:param spike_time: spike time targets in terms of steps
:type spike_time: torch.Tensor
:param num_steps: Number of time steps
:type num_steps: int
:param on_target: Target at spike times, defaults to ``1``
:type on_target: float, optional
:param off_target: Target during refractory period, defaults to ``0``
:type off_target: float, optional
:param epsilon: A tiny positive value to avoid rounding errors when
using torch.arange, defaults to ``1e-7``
:type epsilon: float, optional
:return: interpolated target of output neurons. Output tensor will
use time-first dimensions.
:rtype: torch.Tensor
"""
# if no spikes
if not spike_time.numel():
return torch.ones((num_steps)) * off_target
current_time = -1
interpolated_targets = torch.Tensor([])
for step in range(num_steps):
if step in spike_time:
if step == (current_time + 1):
interpolated_targets = torch.cat(
(interpolated_targets, torch.Tensor([on_target]))
)
else:
interpolated_targets = torch.cat(
(
interpolated_targets,
torch.arange(
off_target,
on_target + epsilon,
(on_target - off_target)
/ (step - current_time - 1),
),
)
)
current_time = step
if torch.max(spike_time) < num_steps - 1:
for step in range(int(torch.max(spike_time).item()), num_steps - 1):
interpolated_targets = torch.cat(
(interpolated_targets, torch.Tensor([off_target]))
)
return interpolated_targets
def latency_interpolate(spike_time, num_steps, on_target=1, off_target=0):
"""Apply linear interpolation to a tensor of target spike times to
enable a gradually increasing membrane target.
Each spike is assumed to occur from a separate neuron.
Example::
a = torch.Tensor([0, 4])
spikegen.latency_interpolate(a, num_steps=5)
>>> tensor([[1.0000, 0.0000],
[0.0000, 0.2500],
[0.0000, 0.5000],
[0.0000, 0.7500],
[0.0000, 1.0000]])
spikegen.latency_interpolate(a, num_steps=5, on_target=1.25,
off_target=0.25)
>>> tensor([[1.2500, 0.2500],
[0.2500, 0.5000],
[0.2500, 0.7500],
[0.2500, 1.0000],
[0.2500, 1.2500]])
:param spike_time: spike time targets in terms of steps
:type spike_time: torch.Tensor
:param num_steps: Number of time steps
:type num_steps: int
:param on_target: Target at spike times, defaults to ``1``
:type on_target: float, optional
:param off_target: Target during refractory period, defaults to ``0``
:type off_target: float, optional
:return: interpolated target of output neurons. Output tensor will use
time-first dimensions.
:rtype: torch.Tensor
"""
if on_target < off_target:
raise Exception(
f"``on_target`` [{on_target}] must be greater than "
f"``off_target`` [{off_target}]."
)
device = spike_time.device
spike_time = torch.round(
spike_time
).float() # Needs to be float as 0s and out-of-bounds spikes
# are set to 0.5
spike_time[
spike_time > num_steps
] = 0.5 # avoid div by 0. instead setting spike time to < 1
# --> (step/spike_time) > 1, which gets clipped.
interpolated_targets = torch.ones(
(tuple([num_steps] + list(spike_time.size()))),
dtype=dtype,
device=device,
)
# offset skips first step if a 0 spike occurs. must be handled
# separately to avoid div by zero.
offset = 0
# index into first step
if 0 in spike_time:
interpolated_targets[0] = torch.where(
spike_time == 0,
interpolated_targets[0],
interpolated_targets[0] * 0,
) # replace 0's with ones for first spike time, others with 0s
spike_time[spike_time == 0] = 0.5
offset = 1
# i.e., when step/spike_time=1
for step in range(num_steps - offset):
interpolated_targets[step + offset] = (step + offset) / spike_time
# next we clamp those that exceed 1, and rescale
interpolated_targets = (
interpolated_targets * (on_target - off_target) + off_target
)
interpolated_targets[interpolated_targets > on_target] = off_target
return interpolated_targets
def targets_latency(
targets,
num_classes,
num_steps=False,
first_spike_time=0,
on_target=1,
off_target=0,
interpolate=False,
threshold=0.01,
tau=1,
clip=False,
normalize=False,
linear=False,
epsilon=1e-7,
bypass=False,
):
"""Latency encoding of target labels. Use target labels to determine
time-to-first spike. Expected input is index of correct class.
The index is one-hot-encoded before being passed to ``spikegen.latency``.
Assume a LIF neuron model that charges up with time constant tau.
Tensor dimensions use time first.
Example::
a = torch.Tensor([0, 3])
spikegen.targets_latency(a, num_classes=4, num_steps=5,
normalize=True).size()
>>> torch.Size([5, 2, 4])
# time evolution of correct neuron class
spikegen.targets_latency(a, num_classes=4, num_steps=5,
normalize=True)[:, 0, 0]
>>> tensor([1., 0., 0., 0., 0.])
# time evolution of incorrect neuron class
spikegen.targets_latency(a, num_classes=4, num_steps=5,
normalize=True)[:, 0, 1]
>>> tensor([0., 0., 0., 0., 1.])
# correct class w/interpolation
spikegen.targets_latency(a, num_classes=4, num_steps=5,
normalize=True, interpolate=True)[:, 0, 0]
>>> tensor([1., 0., 0., 0., 0.])
# incorrect class w/interpolation
spikegen.targets_latency(a, num_classes=4, num_steps=5,
normalize=True, interpolate=True)[:, 0, 1]
>>> tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
:param targets: Target tensor for a single batch. The target
should be a class index in the range [0, C-1]
where C=number of classes.
:type targets: torch.Tensor
:param num_classes: Number of outputs.
:type num_classes: int
:param num_steps: Number of time steps. Explicitly needed if
``normalize=True``, defaults to ``False``
(then changed to ``1`` if ``normalize=False``)
:type num_steps: int, optional
:param first_spike_time: Time to first spike, defaults to ``0``.
:type first_spike_time: int, optional
:param on_target: Target at spike times, defaults to ``1``
:type on_target: float, optional
:param off_target: Target during refractory period, defaults to ``0``
:type off_target: float, optional
:param interpolate: Applies linear interpolation such that there is
a gradually increasing target up to each spike, defaults to ``False``
:type interpolate: Bool, optional
:param threshold: Input features below the threhold will fire at the
final time step unless ``clip=True`` in which case they will not fire
at all, defaults to ``0.01``
:type threshold: float, optional
:param tau: RC Time constant for LIF model used to calculate firing
time, defaults to ``1``
:type tau: float, optional
:param clip: Option to remove spikes from features that fall below
the threshold, defaults to ``False``
:type clip: Bool, optional
:param normalize: Option to normalize the latency code such that the
final spike(s) occur within num_steps, defaults to ``False``
:type normalize: Bool, optional
:param linear: Apply a linear latency code rather than the default
logarithmic code, defaults to ``False``
:type linear: Bool, optional
:param bypass: Used to block error messages that occur from either: i)
spike times exceeding the bounds of ``num_steps``, or ii) if
``num_steps`` is not specified, setting ``bypass=True``
allows the largest spike time to set ``num_steps``.
Defaults to ``False``
:type bypass: bool, optional
:param epsilon: A tiny positive value to avoid rounding errors when
using torch.arange, defaults to ``1e-7``
:type epsilon: float, optional
:return: latency encoding spike train of features or labels
:rtype: torch.Tensor
"""
return latency(
to_one_hot(targets, num_classes),
num_steps=num_steps,
first_spike_time=first_spike_time,
on_target=on_target,
off_target=off_target,
interpolate=interpolate,
threshold=threshold,
tau=tau,
clip=clip,
normalize=normalize,
linear=linear,
bypass=bypass,
epsilon=epsilon,
)
def to_one_hot_inverse(one_hot_targets):
"""Boolean inversion of a matrix of 1's and 0's.
Used to merge the targets of correct and incorrect neuron classes in
``targets_rate``.
Example::
a = torch.Tensor([0, 0, 0, 0, 1])
spikegen.to_one_hot_inverse(a)
>>> tensor([[1., 1., 1., 1., 0.]])
"""
one_hot_inverse = one_hot_targets.clone()
one_hot_inverse[one_hot_targets == 0] = 1
one_hot_inverse[one_hot_targets != 0] = 0
return one_hot_inverse
def to_one_hot(targets, num_classes):
"""One hot encoding of target labels.
Example::
targets = torch.tensor([0, 1, 2, 3])
spikegen.to_one_hot(targets, num_classes=4)
>>> tensor([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]])
:param targets: Target tensor for a single batch
:type targets: torch.Tensor
:param num_classes: Number of classes
:type num_classes: int
:return: one-hot encoding of targets of shape [batch x num_classes]
:rtype: torch.Tensor
"""
if torch.max(targets > num_classes - 1):
raise Exception(
f"target [{torch.max(targets)}] is out of bounds for "
f"``num_classes`` [{num_classes}]"
)
device = targets.device
# Initialize zeros. E.g, for MNIST: (batch_size, 10).
one_hot = torch.zeros(
[len(targets), num_classes], device=device, dtype=dtype
)
# Unsqueeze converts dims of [100] to [100, 1]
one_hot = one_hot.scatter(1, targets.type(torch.int64).unsqueeze(-1), 1)
return one_hot
def from_one_hot(one_hot_label):
"""Convert one-hot encoding back into an integer
Example::
one_hot_label = torch.tensor([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]])
spikegen.from_one_hot(one_hot_label)
>>> tensor([0, 1, 2, 3])
:param targets: one-hot label vector
:type targets: torch.Tensor
:return: targets
:rtype: torch.Tensor
"""
# one_hot_label = torch.where(one_hot_label == 1)[0][0]
# return int(one_hot_label)
one_hot_label = torch.where(one_hot_label == 1)[0]
return one_hot_label