snntorch.spikegen

snntorch.spikegen is a module that provides a variety of common spike generation and conversion methods, including spike-rate and latency coding.

How to use spikegen

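In general, tensors containing non-spiking data can simply be passed into one of the functions in snntorch.spikegen to convert them into discrete spikes. There are a variety of methods to achieve this conversion. At present, snntorch supports:

  • rate coding: spikegen.rate

  • latency coding: spikegen.latency

  • delta modulation: spikegen.delta

There are also options for converting targets into time-varying spikes:

  • rate-coded targets: spikegen.targets_rate

  • latency-coded targets: spikegen.targets_latency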

snntorch.spikegen.delta(data, threshold=0.1, padding=False, off_spike=False)

Generate a spike only when the difference between two consecutive time steps meets or exceeds the threshold. Optionally, also generate off-spikes (-1) for negative changes.

Example:

a = torch.Tensor([1, 2, 2.9, 3, 3.9])
spikegen.delta(a, threshold=1)
>>> tensor([1., 1., 0., 0., 0.])

spikegen.delta(a, threshold=1, padding=True)
>>> tensor([0., 1., 0., 0., 0.])

b = torch.Tensor([1, 2, 0, 2, 2.9])
spikegen.delta(b, threshold=1, off_spike=True)
>>> tensor([ 1.,  1., -1.,  1.,  0.])

spikegen.delta(b, threshold=1, padding=True, off_spike=True)
>>> tensor([ 0.,  1., -1.,  1.,  0.])
Parameters:
  • data (torch.Tensor) – Data tensor for a single batch of shape [num_steps x batch x input_size]

  • threshold (float, optional) – Input features whose change across one time step is greater than or equal to the threshold will generate a spike, defaults to 0.1

  • padding (bool, optional) – Determines how the first time step is handled. If True, the first time step is compared with itself, so the first output is always 0. If False, the first time step is compared with 0, defaults to False

  • off_spike (bool, optional) – If True, generate negative spikes (-1) for changes less than or equal to -threshold, defaults to False
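To make the comparison rule concrete, the following is a minimal, dependency-free sketch of delta modulation as described above. It operates on a plain Python list holding one feature over time instead of a tensor. The function name delta_sketch is hypothetical and not part of snntorch. It reproduces the documented examples but is not the library's implementation.

```python
from typing import List, Sequence


def delta_sketch(data: Sequence[float], threshold: float = 0.1,
                 padding: bool = False, off_spike: bool = False) -> List[int]:
    """Hypothetical delta-modulation sketch for a single feature over time."""
    if len(data) == 0:
        return []
    # Reference value for each step is the previous sample. The first sample
    # is compared with itself (padding=True) or with 0 (padding=False).
    first_ref = data[0] if padding else 0.0
    previous = [first_ref] + list(data[:-1])

    spikes = []
    for prev, cur in zip(previous, data):
        change = cur - prev
        if change >= threshold:
            spikes.append(1)          # on-spike for a large enough increase
        elif off_spike and change <= -threshold:
            spikes.append(-1)         # off-spike for a large enough decrease
        else:
            spikes.append(0)
    return spikes


print(delta_sketch([1, 2, 2.9, 3, 3.9], threshold=1))                # [1, 1, 0, 0, 0]
print(delta_sketch([1, 2, 0, 2, 2.9], threshold=1, off_spike=True))  # [1, 1, -1, 1, 0]
```

A tensor version would apply the same element-wise comparison along the time dimension.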

snntorch.spikegen.from_one_hot(one_hot_label)

Convert a one-hot encoding back into integer class indices.

Example:

one_hot_label = torch.tensor([[1., 0., 0., 0.],
                              [0., 1., 0., 0.],
                              [0., 0., 1., 0.],
                              [0., 0., 0., 1.]])
spikegen.from_one_hot(one_hot_label)
>>> tensor([0, 1, 2, 3])
Parameters:

one_hot_label (torch.Tensor) – one-hot label vector

Returns:

targets

Return type:

torch.Tensor

snntorch.spikegen.latency(data, num_steps=False, threshold=0.01, tau=1, first_spike_time=0, on_target=1, off_target=0, clip=False, normalize=False, linear=False, interpolate=False, bypass=False, epsilon=1e-07)

Latency encoding of input or target label data. The input features determine the time to first spike. Inputs are expected to lie between 0 and 1.

Assumes a LIF neuron model that charges up with time constant tau. Tensor dimensions use time first.

Example:

a = torch.Tensor([0.02, 0.5, 1])
spikegen.latency(a, num_steps=5, normalize=True, linear=True)
>>> tensor([[0., 0., 1.],
            [0., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.],
            [1., 0., 0.]])
Parameters:
  • data (torch.Tensor) – Data tensor for a single batch of shape [batch x input_size]

  • num_steps (int, optional) – Number of time steps. Explicitly needed if normalize=True, defaults to False (then changed to 1 if normalize=False)

  • threshold (float, optional) – Input features below the threshold will fire at the final time step unless clip=True, in which case they will not fire at all, defaults to 0.01

  • tau (float, optional) – RC Time constant for LIF model used to calculate firing time, defaults to 1

  • first_spike_time (int, optional) – Time to first spike, defaults to 0.

  • on_target (float, optional) – Target at spike times, defaults to 1

  • off_target (float, optional) – Target during refractory period, defaults to 0

  • clip (bool, optional) – Option to remove spikes from features that fall below the threshold, defaults to False

  • normalize (bool, optional) – Option to normalize the latency code such that the final spike(s) occur within num_steps, defaults to False

  • linear (bool, optional) – Apply a linear latency code rather than the default logarithmic code, defaults to False

  • interpolate (bool, optional) – Applies linear interpolation such that there is a gradually increasing target up to each spike, defaults to False

  • bypass (bool, optional) – Suppresses the errors raised when i) spike times exceed the bounds of num_steps, or ii) num_steps is not specified; in case ii), setting bypass=True lets the largest spike time set num_steps. Defaults to False

  • epsilon (float, optional) – A tiny positive value to avoid rounding errors when using torch.arange, defaults to 1e-7

Returns:

latency encoding spike train of features or labels

Return type:

torch.Tensor
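The last step of latency, turning spike times into a spike train, can be illustrated with concrete numbers. For the inputs [0.02, 0.5, 1] with num_steps=5, normalize=True and linear=True, the linear code gives spike times 3.92, 2.0 and 0.0. The sketch below places one spike per feature at its rounded spike time in a time-first raster, which reproduces the latency example above. This is an assumption about the mechanics, not snntorch's code. The function name spike_times_to_raster and its clipped mask (a stand-in for clip=True) are hypothetical.

```python
from typing import List, Optional, Sequence


def spike_times_to_raster(spike_times: Sequence[float], num_steps: int,
                          clipped: Optional[Sequence[bool]] = None) -> List[List[int]]:
    """Hypothetical helper: one spike per feature, time-first layout."""
    num_features = len(spike_times)
    raster = [[0] * num_features for _ in range(num_steps)]
    for feature, t in enumerate(spike_times):
        if clipped is not None and clipped[feature]:
            continue  # mimic clip=True: sub-threshold features never fire
        # Assumed: round to the nearest step and cap at the final step.
        step = min(int(round(t)), num_steps - 1)
        raster[step][feature] = 1
    return raster


for row in spike_times_to_raster([3.92, 2.0, 0.0], num_steps=5):
    print(row)
# [0, 0, 1]
# [0, 0, 0]
# [0, 1, 0]
# [0, 0, 0]
# [1, 0, 0]
```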

snntorch.spikegen.latency_code(data, num_steps=False, threshold=0.01, tau=1, first_spike_time=0, normalize=False, linear=False, epsilon=1e-07)

Latency encoding of input data. Convert input features or target labels to spike times. Assumes a LIF neuron model that charges up with time constant tau by default.

Example:

a = torch.Tensor([0.02, 0.5, 1])
spikegen.latency_code(a, num_steps=5, normalize=True, linear=True)
>>> (tensor([3.9200, 2.0000, 0.0000]), tensor([False, False, False]))
Parameters:
  • data (torch.Tensor) – Data tensor for a single batch of shape [batch x input_size]

  • num_steps (int, optional) – Number of time steps. Explicitly needed if normalize=True, defaults to False (then changed to 1 if normalize=False)

  • threshold (float, optional) – Input features below the threshold will fire at the final time step, defaults to 0.01

  • tau (float, optional) – RC Time constant for LIF model used to calculate firing time, defaults to 1

  • first_spike_time (int, optional) – Time to first spike, defaults to 0.

  • normalize (bool, optional) – Option to normalize the latency code such that the final spike(s) occur within num_steps, defaults to False

  • linear (bool, optional) – Apply a linear latency code rather than the default logarithmic code, defaults to False

  • epsilon (float, optional) – A tiny positive value to avoid rounding errors when using torch.arange, defaults to 1e-7

Returns:

a tuple of two tensors: the latency-encoded spike times of the features, and a Boolean tensor marking the elements that fall below the threshold (used by latency to clip saturated spikes when clip=True)

Return type:

tuple(torch.Tensor, torch.Tensor)

snntorch.spikegen.latency_code_linear(data, num_steps=False, threshold=0.01, tau=1, first_spike_time=0, normalize=False)

Linear latency encoding of input data. Convert input features or target labels to spike times.

Example:

a = torch.Tensor([0.02, 0.5, 1])
spikegen.latency_code_linear(a, num_steps=5, normalize=True)
>>> tensor([3.9200, 2.0000, 0.0000])
Parameters:
  • data (torch.Tensor) – Data tensor for a single batch of shape [batch x input_size]

  • num_steps (int, optional) – Number of time steps. Explicitly needed if normalize=True, defaults to False (then changed to 1 if normalize=False)

  • threshold (float, optional) – Input features below the threshold will fire at the final time step, defaults to 0.01

  • tau (float, optional) – Linear time constant used to calculate firing time, defaults to 1

  • first_spike_time (int, optional) – Time to first spike, defaults to 0.

  • normalize (bool, optional) – Option to normalize the latency code such that the final spike(s) occur within num_steps, defaults to False

Returns:

linear latency encoding spike times of features

Return type:

torch.Tensor
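The linear code can be checked by hand against the example above. Assume that normalize=True rescales tau so that an input of 0 lands on the final step, i.e. tau = num_steps - 1 - first_spike_time, which is 4 here. Each feature x then fires at tau * (1 - x) + first_spike_time, giving 3.92, 2.0 and 0.0. The plain-Python sketch below follows that reading. Capping sub-threshold inputs at tau * (1 - threshold), so they round to the final step, is an assumption consistent with the threshold description rather than a quote of the source. The function name is hypothetical.

```python
from typing import List, Optional, Sequence


def latency_code_linear_sketch(data: Sequence[float], num_steps: Optional[int] = None,
                               threshold: float = 0.01, tau: float = 1.0,
                               first_spike_time: int = 0,
                               normalize: bool = False) -> List[float]:
    """Hypothetical linear latency code: larger inputs fire earlier."""
    if normalize:
        if num_steps is None:
            raise ValueError("num_steps is required when normalize=True")
        # Assumption: rescale so that an input of 0 maps to the final step.
        tau = float(num_steps - 1 - first_spike_time)
    latest = tau * (1 - threshold)  # assumed cap for sub-threshold inputs
    return [min(tau * (1 - x), latest) + first_spike_time for x in data]


print(latency_code_linear_sketch([0.02, 0.5, 1], num_steps=5, normalize=True))
# approximately [3.92, 2.0, 0.0]
```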

snntorch.spikegen.latency_code_log(data, num_steps=False, threshold=0.01, tau=1, first_spike_time=0, normalize=False, epsilon=1e-07)

Logarithmic latency encoding of input data. Convert input features or target labels to spike times.

Example:

a = torch.Tensor([0.02, 0.5, 1])
spikegen.latency_code_log(a, num_steps=5, normalize=True)
>>> tensor([4.0000, 0.1166, 0.0580])
Parameters:
  • data (torch.Tensor) – Data tensor for a single batch of shape [batch x input_size]

  • num_steps (int, optional) – Number of time steps. Explicitly needed if normalize=True, defaults to False (then changed to 1 if normalize=False)

  • threshold (float, optional) – Input features below the threshold will fire at the final time step, defaults to 0.01

  • tau (float, optional) – Logarithmic time constant used to calculate firing time, defaults to 1

  • first_spike_time (int, optional) – Time to first spike, defaults to 0.

  • normalize (bool, optional) – Option to normalize the latency code such that the final spike(s) occur within num_steps, defaults to False

  • epsilon (float, optional) – A tiny positive value to avoid rounding errors when using torch.arange, defaults to 1e-7

Returns:

logarithmic latency encoding spike times of features

Return type:

torch.Tensor
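The logarithmic code follows from the RC charging model mentioned for latency. Suppose a membrane driven by a constant input x charges as U(t) = x(1 - exp(-t/tau)). It then reaches the threshold at t = tau * ln(x / (x - threshold)). Larger inputs therefore fire earlier, and inputs at or below the threshold never cross it. The sketch below evaluates that expression in plain Python without normalization. Clamping small inputs to threshold + epsilon is an assumption made here to keep the logarithm finite. The function name is hypothetical, and this is not snntorch's implementation.

```python
import math
from typing import List, Sequence


def latency_code_log_sketch(data: Sequence[float], threshold: float = 0.01,
                            tau: float = 1.0, first_spike_time: float = 0,
                            epsilon: float = 1e-7) -> List[float]:
    """Hypothetical RC-charging latency code (no normalization)."""
    times = []
    for x in data:
        # Assumption: keep x strictly above threshold so the log stays finite;
        # sub-threshold inputs therefore get a very late spike time.
        x = max(x, threshold + epsilon)
        # Solve x * (1 - exp(-t / tau)) = threshold for t.
        times.append(tau * math.log(x / (x - threshold)) + first_spike_time)
    return times


t_small, t_mid, t_large = latency_code_log_sketch([0.02, 0.5, 1.0])
print(t_small > t_mid > t_large)  # True: stronger inputs spike sooner
print(t_mid / t_large)            # about 2.01
```

The ratio of roughly 2 between the spike times for 0.5 and 1 agrees with the normalized values in the example above (0.1166 and 0.0580). This is consistent with normalization acting as a single scale factor.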

snntorch.spikegen.latency_interpolate(spike_time, num_steps, on_target=1, off_target=0)

Apply linear interpolation to a tensor of target spike times so that the target (e.g., the membrane potential) increases gradually up to each spike. Each spike time is assumed to belong to a separate neuron.

Example:

a = torch.Tensor([0, 4])
spikegen.latency_interpolate(a, num_steps=5)
>>> tensor([[1.0000, 0.0000],
            [0.0000, 0.2500],
            [0.0000, 0.5000],
            [0.0000, 0.7500],
            [0.0000, 1.0000]])

spikegen.latency_interpolate(a, num_steps=5, on_target=1.25,
    off_target=0.25)
>>> tensor([[1.2500, 0.2500],
            [0.2500, 0.5000],
            [0.2500, 0.7500],
            [0.2500, 1.0000],
            [0.2500, 1.2500]])
Parameters:
  • spike_time (torch.Tensor) – Target spike times, in units of time steps

  • num_steps (int) – Number of time steps

  • on_target (float, optional) – Target at spike times, defaults to 1

  • off_target (float, optional) – Target during refractory period, defaults to 0

Returns:

interpolated target of output neurons. Output tensor will use time-first dimensions.

Return type:

torch.Tensor
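In the two examples above, each column ramps linearly from off_target at step 0 to on_target at that neuron's spike step, then stays at off_target. The following plain-Python sketch reproduces that pattern. It is inferred from the examples rather than taken from snntorch, and the function name is hypothetical.

```python
from typing import List, Sequence


def latency_interpolate_sketch(spike_times: Sequence[float], num_steps: int,
                               on_target: float = 1.0,
                               off_target: float = 0.0) -> List[List[float]]:
    """Hypothetical per-neuron linear ramp, returned time-first."""
    out = [[off_target] * len(spike_times) for _ in range(num_steps)]
    for neuron, spike in enumerate(spike_times):
        s = int(spike)
        # Ramp from off_target at step 0 up to on_target at step s;
        # steps after the spike keep off_target (inferred from the examples).
        for t in range(min(s, num_steps - 1) + 1):
            frac = 1.0 if s == 0 else t / s
            out[t][neuron] = off_target + (on_target - off_target) * frac
    return out


for row in latency_interpolate_sketch([0, 4], num_steps=5):
    print(row)
# [1.0, 0.0]
# [0.0, 0.25]
# [0.0, 0.5]
# [0.0, 0.75]
# [0.0, 1.0]
```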

snntorch.spikegen.rate(data, num_steps=False, gain=1, offset=0, first_spike_time=0, time_var_input=False)

Spike rate encoding of input data. Convert tensor into Poisson spike trains using the features as the mean of a binomial distribution. If num_steps is specified, then the data will be first repeated in the first dimension before rate encoding.

If data is time-varying, tensor dimensions use time first.

Example:

# 100% chance of spike generation
a = torch.Tensor([1, 1, 1, 1])
spikegen.rate(a, num_steps=1)
>>> tensor([[1., 1., 1., 1.]])

# 0% chance of spike generation
b = torch.Tensor([0, 0, 0, 0])
spikegen.rate(b, num_steps=1)
>>> tensor([[0., 0., 0., 0.]])

# 50% chance of spike generation per time step (output is random)
c = torch.Tensor([0.5, 0.5, 0.5, 0.5])
spikegen.rate(c, num_steps=1)
>>> tensor([[0., 1., 0., 1.]])

# Increasing num_steps will increase the length of
# the first dimension (time-first)
print(spikegen.rate(c, num_steps=1).size())
>>> torch.Size([1, 4])

d = spikegen.rate(torch.Tensor([0.5, 0.5, 0.5, 0.5]), num_steps=2)
print(d.size())
>>> torch.Size([2, 4])
Parameters:
  • data (torch.Tensor) – Data tensor for a single batch of shape [batch x input_size]

  • num_steps (int, optional) – Number of time steps. Only specify if input data does not already have time dimension, defaults to False

  • gain (float, optional) – Scale input features by the gain, defaults to 1

  • offset (float, optional) – Shift input features by the offset, defaults to 0

  • first_spike_time (int, optional) – Time to first spike, defaults to 0.

  • time_var_input (bool, optional) – Set to True if input tensor is time-varying. Otherwise, first_spike_time!=0 will modify the wrong dimension. Defaults to False

Returns:

rate encoding spike train of input features of shape [num_steps x batch x input_size]

Return type:

torch.Tensor
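Here is a minimal sketch of the rate code as described above. Each feature is scaled by gain, shifted by offset, and clipped to [0, 1] as rate_conv does. It is then used as the probability of a spike at every step, which amounts to an independent Bernoulli trial per step and feature. This stdlib-only version uses Python's random module instead of torch. The function name rate_sketch and its seed argument are hypothetical additions for illustration and reproducibility.

```python
import random
from typing import List, Optional, Sequence


def rate_sketch(data: Sequence[float], num_steps: int, gain: float = 1.0,
                offset: float = 0.0, first_spike_time: int = 0,
                seed: Optional[int] = None) -> List[List[int]]:
    """Hypothetical rate code: output is [num_steps][num_features]."""
    rng = random.Random(seed)  # seed is a hypothetical extra for reproducibility
    # Scale, shift, then clip each feature so it is a valid probability.
    probs = [min(max(x * gain + offset, 0.0), 1.0) for x in data]
    spikes = []
    for t in range(num_steps):
        if t < first_spike_time:
            spikes.append([0] * len(probs))  # silent before first_spike_time
            continue
        spikes.append([1 if rng.random() < p else 0 for p in probs])
    return spikes


train = rate_sketch([0.3] * 10_000, num_steps=1, seed=0)[0]
print(sum(train) / len(train))  # close to 0.3
```

In tensor code, torch.bernoulli on the clipped probabilities plays the role of the per-element random draw.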

snntorch.spikegen.rate_conv(data)

Convert tensor into Poisson spike trains using the features as the mean of a binomial distribution. Values outside the range of [0, 1] are clipped so they can be treated as probabilities.

Example:

# 100% chance of spike generation
a = torch.Tensor([1, 1, 1, 1])
spikegen.rate_conv(a)
>>> tensor([1., 1., 1., 1.])

# 0% chance of spike generation
b = torch.Tensor([0, 0, 0, 0])
spikegen.rate_conv(b)
>>> tensor([0., 0., 0., 0.])

# 50% chance of spike generation per time step
c = torch.Tensor([0.5, 0.5, 0.5, 0.5])
spikegen.rate_conv(c)
>>> tensor([0., 1., 0., 1.])
Parameters:

data (torch.Tensor) – Data tensor for a single batch of shape [batch x input_size]

Returns:

rate-encoded spikes with the same shape as data

Return type:

torch.Tensor

snntorch.spikegen.rate_interpolate(spike_time, num_steps, on_target=1, off_target=0, epsilon=1e-07)

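Apply linear interpolation to a tensor of target spike times so that the target (e.g., the membrane potential) increases gradually up to each spike. All spike times are assumed to belong to a single neuron.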

Example:

a = torch.Tensor([0, 4])
spikegen.rate_interpolate(a, num_steps=5)
>>> tensor([1.0000, 0.0000, 0.3333, 0.6667, 1.0000])

spikegen.rate_interpolate(a, num_steps=5, on_target=1.25,
    off_target=0.25)
>>> tensor([1.2500, 0.2500, 0.5833, 0.9167, 1.2500])
Parameters:
  • spike_time (torch.Tensor) – Target spike times, in units of time steps

  • num_steps (int) – Number of time steps

  • on_target (float, optional) – Target at spike times, defaults to 1

  • off_target (float, optional) – Target during refractory period, defaults to 0

  • epsilon (float, optional) – A tiny positive value to avoid rounding errors when using torch.arange, defaults to 1e-7

Returns:

interpolated target of output neurons. Output tensor will use time-first dimensions.

Return type:

torch.Tensor
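Unlike latency_interpolate, the examples above treat all spike times as belonging to one neuron. Between two consecutive spikes, the target restarts at off_target on the step after the earlier spike and rises linearly to on_target at the later one. The sketch below reproduces both examples. Two behaviours are assumptions, since the examples do not show those cases: steps before the first spike ramp from step 0, and steps after the last spike hold off_target. The function name is hypothetical.

```python
from typing import List, Sequence


def rate_interpolate_sketch(spike_times: Sequence[float], num_steps: int,
                            on_target: float = 1.0,
                            off_target: float = 0.0) -> List[float]:
    """Hypothetical single-neuron ramp between consecutive spikes."""
    out = [off_target] * num_steps  # assumption: off_target after the last spike
    prev = -1  # assumption: the first ramp starts at step 0
    for s in sorted(int(t) for t in spike_times):
        if s >= num_steps:
            break
        start = prev + 1          # ramp restarts right after the previous spike
        span = s - start
        for t in range(start, s + 1):
            frac = 1.0 if span == 0 else (t - start) / span
            out[t] = off_target + (on_target - off_target) * frac
        prev = s
    return out


print([round(v, 4) for v in rate_interpolate_sketch([0, 4], num_steps=5)])
# [1.0, 0.0, 0.3333, 0.6667, 1.0]
```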

snntorch.spikegen.target_rate_code(num_steps, first_spike_time=0, rate=1, firing_pattern='regular')

Rate-code a single output neuron, returning a tensor of length num_steps containing spikes and another tensor containing the spike times.

Example:

spikegen.target_rate_code(num_steps=5, rate=1)
>>> (tensor([1., 1., 1., 1., 1.]), tensor([0, 1, 2, 3, 4]))

spikegen.target_rate_code(num_steps=5, first_spike_time=3, rate=1)
>>> (tensor([0., 0., 0., 1., 1.]), tensor([3, 4]))

spikegen.target_rate_code(num_steps=5, rate=0.3)
>>> (tensor([1., 0., 0., 1., 0.]), tensor([0, 3]))

spikegen.target_rate_code(
    num_steps=5, rate=0.3, firing_pattern="poisson")
>>> (tensor([0., 1., 0., 1., 0.]), tensor([1, 3]))
Parameters:
  • num_steps (int) – Number of time steps

  • first_spike_time (int, optional) – Time step for first spike to occur, defaults to 0

  • rate (float, optional) – Firing frequency as a ratio, e.g., 1 enables firing at every step; 0.5 enables firing at 50% of steps, 0 means no firing, defaults to 1

  • firing_pattern (string, optional) – Firing pattern of the neuron. 'regular' enables periodic firing, 'uniform' samples spike times from a uniform distribution (duplicates are removed), 'poisson' samples from a binomial distribution at each step where each probability is the firing frequency, defaults to 'regular'

Returns:

a tuple of two tensors: the rate-coded target of a single neuron class, of length num_steps, and the rate-coded spike times in units of time steps

Return type:

tuple(torch.Tensor, torch.Tensor)
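For firing_pattern='regular', the examples above are consistent with placing a spike every 1/rate steps starting at first_spike_time and truncating each time to an integer step. For instance, rate=0.3 gives times 0 and 3.33, hence steps 0 and 3. The sketch below is a plain-Python reading of that behaviour, not snntorch's code. The function names are hypothetical, and the 'uniform' and 'poisson' patterns are not covered.

```python
from typing import List, Tuple


def regular_spike_times(num_steps: int, first_spike_time: int = 0,
                        rate: float = 1.0) -> List[int]:
    """Hypothetical 'regular' pattern: one spike every 1/rate steps."""
    if rate <= 0:
        return []  # rate 0 means no firing
    period = 1.0 / rate
    times = []
    k = 0
    while first_spike_time + k * period < num_steps:
        # Assumption: truncate each spike time to an integer step.
        times.append(int(first_spike_time + k * period))
        k += 1
    return times


def target_rate_code_sketch(num_steps: int, first_spike_time: int = 0,
                            rate: float = 1.0) -> Tuple[List[float], List[int]]:
    """Return (spike train of length num_steps, spike times)."""
    times = regular_spike_times(num_steps, first_spike_time, rate)
    fired = set(times)
    train = [1.0 if t in fired else 0.0 for t in range(num_steps)]
    return train, times


print(target_rate_code_sketch(num_steps=5, rate=0.3))
# ([1.0, 0.0, 0.0, 1.0, 0.0], [0, 3])
```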

snntorch.spikegen.targets_convert(targets, num_classes, code='rate', num_steps=False, first_spike_time=0, correct_rate=1, incorrect_rate=0, on_target=1, off_target=0, firing_pattern='regular', interpolate=False, epsilon=1e-07, threshold=0.01, tau=1, clip=False, normalize=False, linear=False, bypass=False)

Spike encoding of targets. The expected input is a 1-D tensor of target class indices. If the output tensor is time-varying, the returned tensor will have time in the first dimension. If it is not time-varying, then the returned tensor will omit the time dimension and use batch first.

The following arguments will necessarily incur a time-varying output:

code='latency', first_spike_time!=0, correct_rate!=1, or incorrect_rate!=0

The target output may be applied to the internal state (e.g., membrane) of the neuron or to the spike. The following arguments will produce an output tensor that may sensibly be applied as a target to either the output spike or the membrane potential, as the output will consistently be either a 1 or 0:

on_target=1, off_target=0, and interpolate=False

If any of these three conditions does not hold, the target is better suited to the output membrane potential, as it will likely include values other than 1 and 0.
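The two sets of conditions above can be restated as small predicates. These helpers are hypothetical (they do not exist in snntorch) and only encode the rules as written in this section.

```python
def is_time_varying(code: str = "rate", first_spike_time: int = 0,
                    correct_rate: float = 1, incorrect_rate: float = 0) -> bool:
    """Hypothetical: True if the output would carry a time (first) dimension."""
    return (code == "latency" or first_spike_time != 0
            or correct_rate != 1 or incorrect_rate != 0)


def suits_spike_target(on_target: float = 1, off_target: float = 0,
                       interpolate: bool = False) -> bool:
    """Hypothetical: True if every target value is 0 or 1, so it can target spikes."""
    return on_target == 1 and off_target == 0 and not interpolate


print(is_time_varying())                                      # False
print(is_time_varying(correct_rate=0.8, incorrect_rate=0.2))  # True
print(suits_spike_target(on_target=1.2, off_target=0.5))      # False
```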

Example:

a = torch.Tensor([4])

# rate-coding
# one-hot
spikegen.targets_convert(a, num_classes=5, code="rate")
>>> tensor([[0., 0., 0., 0., 1.]])

# one-hot + time-first
spikegen.targets_convert(a, num_classes=5, code="rate",
    correct_rate=0.8, incorrect_rate=0.2, num_steps=5).size()
>>> torch.Size([5, 1, 5])

For more examples of rate coding, see help(snntorch.spikegen.targets_rate).

Parameters:
  • targets (torch.Tensor) – Target tensor for a single batch. The target should be a class index in the range [0, C-1] where C=number of classes.

  • num_classes (int) – Number of outputs.

  • code (string, optional) – Encoding scheme. Options of 'rate' or 'latency', defaults to 'rate'

  • num_steps (int, optional) – Number of time steps, defaults to False

  • first_spike_time (int, optional) – Time step for first spike to occur, defaults to 0

  • correct_rate (float, optional) – Firing frequency of correct class as a ratio, e.g., 1 enables firing at every step; 0.5 enables firing at 50% of steps, 0 means no firing, defaults to 1

  • incorrect_rate (float, optional) – Firing frequency of incorrect class(es), e.g., 1 enables firing at every step; 0.5 enables firing at 50% of steps, 0 means no firing, defaults to 0

  • on_target (float, optional) – Target at spike times, defaults to 1

  • off_target (float, optional) – Target during refractory period, defaults to 0

  • firing_pattern (string, optional) – Firing pattern of correct and incorrect classes. 'regular' enables periodic firing, 'uniform' samples spike times from a uniform distribution (duplicates are removed), 'poisson' samples from a binomial distribution at each step where each probability is the firing frequency, defaults to 'regular'

  • interpolate (bool, optional) – Applies linear interpolation such that there is a gradually increasing target up to each spike, defaults to False

  • epsilon (float, optional) – A tiny positive value to avoid rounding errors when using torch.arange, defaults to 1e-7

  • bypass (bool, optional) – Suppresses the errors raised when i) spike times exceed the bounds of num_steps, or ii) num_steps is not specified; in case ii), setting bypass=True lets the largest spike time set num_steps. Defaults to False

Returns:

spike coded target of output neurons. If targets are time-varying, the output tensor will use time-first dimensions. Otherwise, time is omitted from the tensor.

Return type:

torch.Tensor

snntorch.spikegen.targets_latency(targets, num_classes, num_steps=False, first_spike_time=0, on_target=1, off_target=0, interpolate=False, threshold=0.01, tau=1, clip=False, normalize=False, linear=False, epsilon=1e-07, bypass=False)

Latency encoding of target labels. The target labels determine the time to first spike. The expected input is the index of the correct class, which is one-hot encoded before being passed to spikegen.latency.

Assumes a LIF neuron model that charges up with time constant tau. Tensor dimensions use time first.

Example:

a = torch.Tensor([0, 3])
spikegen.targets_latency(a, num_classes=4, num_steps=5,
    normalize=True).size()
>>> torch.Size([5, 2, 4])

# time evolution of correct neuron class
spikegen.targets_latency(a, num_classes=4, num_steps=5,
    normalize=True)[:, 0, 0]
>>> tensor([1., 0., 0., 0., 0.])

# time evolution of incorrect neuron class
spikegen.targets_latency(a, num_classes=4, num_steps=5,
    normalize=True)[:, 0, 1]
>>> tensor([0., 0., 0., 0., 1.])

# correct class w/interpolation
spikegen.targets_latency(a, num_classes=4, num_steps=5,
    normalize=True, interpolate=True)[:, 0, 0]
>>> tensor([1., 0., 0., 0., 0.])

# incorrect class w/interpolation
spikegen.targets_latency(a, num_classes=4, num_steps=5,
    normalize=True, interpolate=True)[:, 0, 1]
>>> tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
Parameters:
  • targets (torch.Tensor) – Target tensor for a single batch. The target should be a class index in the range [0, C-1] where C=number of classes.

  • num_classes (int) – Number of outputs.

  • num_steps (int, optional) – Number of time steps. Explicitly needed if normalize=True, defaults to False (then changed to 1 if normalize=False)

  • first_spike_time (int, optional) – Time to first spike, defaults to 0.

  • on_target (float, optional) – Target at spike times, defaults to 1

  • off_target (float, optional) – Target during refractory period, defaults to 0

  • interpolate (bool, optional) – Applies linear interpolation such that there is a gradually increasing target up to each spike, defaults to False

  • threshold (float, optional) – Input features below the threshold will fire at the final time step unless clip=True, in which case they will not fire at all, defaults to 0.01

  • tau (float, optional) – RC Time constant for LIF model used to calculate firing time, defaults to 1

  • clip (bool, optional) – Option to remove spikes from features that fall below the threshold, defaults to False

  • normalize (bool, optional) – Option to normalize the latency code such that the final spike(s) occur within num_steps, defaults to False

  • linear (bool, optional) – Apply a linear latency code rather than the default logarithmic code, defaults to False

  • bypass (bool, optional) – Suppresses the errors raised when i) spike times exceed the bounds of num_steps, or ii) num_steps is not specified; in case ii), setting bypass=True lets the largest spike time set num_steps. Defaults to False

  • epsilon (float, optional) – A tiny positive value to avoid rounding errors when using torch.arange, defaults to 1e-7

Returns:

latency encoding spike train of features or labels

Return type:

torch.Tensor
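The composition described above can be sketched in plain Python: one-hot encode the labels, latency-code them, then lay the result out time-first. The examples above show the default code placing the correct class at step 0 and the other classes at the final step. For inputs of exactly 1 and 0, the normalized linear code gives the same placement, so this sketch uses the linear code for simplicity. The function name is hypothetical, and the sketch covers only first_spike_time=0 without interpolation.

```python
from typing import List, Sequence


def targets_latency_sketch(targets: Sequence[int], num_classes: int,
                           num_steps: int,
                           threshold: float = 0.01) -> List[List[List[int]]]:
    """Hypothetical [num_steps][batch][num_classes] latency-coded targets."""
    # Simplification: normalized linear code with first_spike_time=0.
    tau = float(num_steps - 1)
    raster = [[[0] * num_classes for _ in targets] for _ in range(num_steps)]
    for b, label in enumerate(targets):
        for c in range(num_classes):
            x = 1.0 if c == int(label) else 0.0            # one-hot value
            t = min(tau * (1 - x), tau * (1 - threshold))  # 0 lands near the end
            raster[min(int(round(t)), num_steps - 1)][b][c] = 1
    return raster


out = targets_latency_sketch([0, 3], num_classes=4, num_steps=5)
print([out[t][0][0] for t in range(5)])  # correct class:   [1, 0, 0, 0, 0]
print([out[t][0][1] for t in range(5)])  # incorrect class: [0, 0, 0, 0, 1]
```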

snntorch.spikegen.targets_rate(targets, num_classes, num_steps=False, first_spike_time=0, correct_rate=1, incorrect_rate=0, on_target=1, off_target=0, firing_pattern='regular', interpolate=False, epsilon=1e-07)

Spike rate encoding of targets. The input tensor must be one-dimensional and contain target indices. If the output tensor is time-varying, the returned tensor will have time in the first dimension. If it is not time-varying, then the returned tensor will omit the time dimension and use batch first. If first_spike_time!=0, correct_rate!=1, or incorrect_rate!=0, the output tensor will be time-varying.

If on_target=1, off_target=0, and interpolate=False, then the target may sensibly be applied as a target for the output spike. If any of these three conditions does not hold, the target is better suited to the output membrane potential.

Example:

a = torch.Tensor([4])

# one-hot
spikegen.targets_rate(a, num_classes=5)
>>> tensor([[0., 0., 0., 0., 1.]])

# first spike time delay, spike evolution over time
spikegen.targets_rate(a, num_classes=5, num_steps=5,
    first_spike_time=2).size()
>>> torch.Size([5, 1, 5])
spikegen.targets_rate(a, num_classes=5, num_steps=5,
    first_spike_time=2)[:, 0, 4]
>>> tensor([0., 0., 1., 1., 1.])

# note: time has not been repeated because every time step
# would be identical when first_spike_time defaults to 0
spikegen.targets_rate(a, num_classes=5, num_steps=5).size()
>>> torch.Size([1, 5])

# on/off targets - membrane evolution over time
spikegen.targets_rate(a, num_classes=5, num_steps=5,
    first_spike_time=2, on_target=1.2, off_target=0.5)[:, 0, 4]
>>> tensor([0.5000, 0.5000, 1.2000, 1.2000, 1.2000])

# correct rate at 25% + linear interpolation of membrane evolution
spikegen.targets_rate(a, num_classes=5, num_steps=5,
    correct_rate=0.25, on_target=1.2,
    off_target=0.5, interpolate=True)[:, 0, 4]
>>> tensor([1.2000, 0.5000, 0.7333, 0.9667, 1.2000])
Parameters:
  • targets (torch.Tensor) – Target tensor for a single batch. The target should be a class index in the range [0, C-1] where C=number of classes.

  • num_classes (int) – Number of outputs.

  • num_steps (int, optional) – Number of time steps, defaults to False

  • first_spike_time (int, optional) – Time step for first spike to occur, defaults to 0

  • correct_rate (float, optional) – Firing frequency of correct class as a ratio, e.g., 1 enables firing at every step; 0.5 enables firing at 50% of steps, 0 means no firing, defaults to 1

  • incorrect_rate (float, optional) – Firing frequency of incorrect class(es), e.g., 1 enables firing at every step; 0.5 enables firing at 50% of steps, 0 means no firing, defaults to 0

  • on_target (float, optional) – Target at spike times, defaults to 1

  • off_target (float, optional) – Target during refractory period, defaults to 0

  • firing_pattern (string, optional) – Firing pattern of correct and incorrect classes. 'regular' enables periodic firing, 'uniform' samples spike times from a uniform distribution (duplicates are removed), 'poisson' samples from a binomial distribution at each step where each probability is the firing frequency, defaults to 'regular'

  • interpolate (bool, optional) – Applies linear interpolation such that there is a gradually increasing target up to each spike, defaults to False

  • epsilon (float, optional) – A tiny positive value to avoid rounding errors when using torch.arange, defaults to 1e-7

Returns:

rate coded target of output neurons. If targets are time-varying, the output tensor will use time-first dimensions. Otherwise, time is omitted from the tensor.

Return type:

torch.Tensor

snntorch.spikegen.to_one_hot(targets, num_classes)

One-hot encoding of target labels.

Example:

targets = torch.tensor([0, 1, 2, 3])
spikegen.to_one_hot(targets, num_classes=4)
>>> tensor([[1., 0., 0., 0.],
            [0., 1., 0., 0.],
            [0., 0., 1., 0.],
            [0., 0., 0., 1.]])
Parameters:
  • targets (torch.Tensor) – Target tensor for a single batch

  • num_classes (int) – Number of classes

Returns:

one-hot encoding of targets of shape [batch x num_classes]

Return type:

torch.Tensor

snntorch.spikegen.to_one_hot_inverse(one_hot_targets)

Boolean inversion of a matrix of 1’s and 0’s. Used to merge the targets of correct and incorrect neuron classes in targets_rate.

Example:

a = torch.Tensor([0, 0, 0, 0, 1])
spikegen.to_one_hot_inverse(a)
>>> tensor([[1., 1., 1., 1., 0.]])
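The three one-hot helpers in this section (to_one_hot, from_one_hot and to_one_hot_inverse) can be mirrored with plain lists to show how they relate. The versions below are hypothetical sketches, not snntorch's code. from_one_hot_sketch uses a per-row argmax, which for a valid one-hot row is equivalent to locating its single 1. to_one_hot_inverse_sketch works on a flat list.

```python
from typing import List, Sequence


def to_one_hot_sketch(targets: Sequence[int], num_classes: int) -> List[List[float]]:
    """Hypothetical: row b has 1.0 in column targets[b] and 0.0 elsewhere."""
    return [[1.0 if c == int(t) else 0.0 for c in range(num_classes)]
            for t in targets]


def from_one_hot_sketch(one_hot: Sequence[Sequence[float]]) -> List[int]:
    """Hypothetical: recover class indices with a per-row argmax (assumption)."""
    return [max(range(len(row)), key=lambda i: row[i]) for row in one_hot]


def to_one_hot_inverse_sketch(one_hot_row: Sequence[float]) -> List[float]:
    """Hypothetical: swap 1s and 0s, e.g. to mark the incorrect classes."""
    return [0.0 if v != 0 else 1.0 for v in one_hot_row]


encoded = to_one_hot_sketch([0, 1, 2, 3], num_classes=4)
print(from_one_hot_sketch(encoded))                # [0, 1, 2, 3]
print(to_one_hot_inverse_sketch([0, 0, 0, 0, 1]))  # [1.0, 1.0, 1.0, 1.0, 0.0]
```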