snntorch.spikeplot

snntorch.spikeplot builds on matplotlib.pyplot and celluloid to reduce the boilerplate code needed to generate a variety of plots and animations of spiking data.

class snntorch.spikeplot.Camera(figure: Figure)

Bases: object

Make animations easier by capturing snapshots of a matplotlib figure and assembling them into an animation.

animate(*args, **kwargs) → ArtistAnimation

Animate the snapshots taken, using matplotlib.animation.ArtistAnimation.

Return type:

ArtistAnimation

snap() → List[Artist]

Capture current state of the figure.
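The pattern that Camera automates can be sketched with plain matplotlib: draw the artists for each frame, keep the list of artists created for that frame, and pass all of the lists to matplotlib.animation.ArtistAnimation. This is a minimal illustrative sketch, not snntorch's implementation; the sine-wave data is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import ArtistAnimation

fig, ax = plt.subplots()
frames = []  # one list of artists per snapshot

x = np.linspace(0, 2 * np.pi, 100)
for phase in np.linspace(0, 2 * np.pi, 10, endpoint=False):
    # ax.plot returns a list of the Line2D artists drawn for this frame
    lines = ax.plot(x, np.sin(x + phase), color="tab:blue")
    frames.append(lines)

# ArtistAnimation shows one entry of `frames` at a time, in order
anim = ArtistAnimation(fig, frames, interval=40)
```

With Camera, the bookkeeping of the frames list is handled for you: call snap() after drawing each frame and animate() once all frames are drawn.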

snntorch.spikeplot.animator(data, fig, ax, num_steps=False, interval=40, cmap='plasma')

Generate an animation by looping through the first dimension of a sample of spiking data. Time must be the first dimension of data.

Example:

import snntorch.spikeplot as splt
import matplotlib.pyplot as plt
from IPython.display import HTML

#  spike_data contains 128 samples, each of 100 time steps in duration
print(spike_data.size())
# torch.Size([100, 128, 1, 28, 28])

#  Index into a single sample from a minibatch
spike_data_sample = spike_data[:, 0, 0]
print(spike_data_sample.size())
# torch.Size([100, 28, 28])

#  Plot (to_html5_video() requires FFmpeg to be installed)
fig, ax = plt.subplots()
anim = splt.animator(spike_data_sample, fig, ax)
HTML(anim.to_html5_video())

#  Save as a gif
anim.save("spike_mnist.gif")

Parameters:
  • data (torch.Tensor) – Data tensor for a single sample across time steps of shape [num_steps x input_size]; each time step is drawn as one frame, e.g., [100 x 28 x 28] in the example above

  • fig (matplotlib.figure.Figure) – Top level container for all plot elements

  • ax (matplotlib.axes._subplots.AxesSubplot) –

    Contains additional figure elements and sets the coordinate system. E.g.:

    fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))

  • num_steps (int, optional) – Number of time steps to plot. If not specified, the number of entries in the first dimension of data will automatically be used, defaults to False

  • interval (int, optional) – Delay between frames in milliseconds, defaults to 40

  • cmap (string, optional) – Color map, defaults to 'plasma'

Returns:

animation to be displayed using matplotlib.pyplot.show()

Return type:

FuncAnimation
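The looping behavior described for animator can be illustrated with a minimal sketch that draws one imshow image per entry along the first (time) dimension and collects the images as animation frames. It uses plain matplotlib and NumPy, with a random binary array standing in for spike_data_sample; it is not claimed to be snntorch's exact code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import ArtistAnimation

rng = np.random.default_rng(0)
# Hypothetical stand-in for spike_data_sample: [num_steps x 28 x 28], 1 = spike
data = (rng.random((100, 28, 28)) < 0.1).astype(float)

fig, ax = plt.subplots()
ax.axis("off")

# Mirrors num_steps=False: fall back to the size of the first dimension
num_steps = data.shape[0]
frames = []
for step in range(num_steps):
    # One image per time step; each becomes a single animation frame
    im = ax.imshow(data[step], cmap="plasma", animated=True)
    frames.append([im])

anim = ArtistAnimation(fig, frames, interval=40)  # 40 ms between frames
```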

snntorch.spikeplot.raster(data, ax, **kwargs)

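Generate a raster plot using matplotlib.pyplot.scatter. Each spike in data is drawn as a point at (time step, neuron index), and additional keyword arguments such as s and c are passed on to the scatter call.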

Example:

import snntorch.spikeplot as splt
import matplotlib.pyplot as plt

#  spike_data contains 128 samples, each of 100 time steps in duration
print(spike_data.size())
# torch.Size([100, 128, 1, 28, 28])

#  Index into a single sample from a minibatch
spike_data_sample = spike_data[:, 0, 0]
print(spike_data_sample.size())
# torch.Size([100, 28, 28])

#  Flatten the 28x28 pixels into one neuron axis: [num_steps x num_neurons]
spike_data_sample = spike_data_sample.reshape((100, -1))
print(spike_data_sample.size())
# torch.Size([100, 784])

fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)

#  s: size of scatter points; c: color of scatter points
splt.raster(spike_data_sample, ax, s=1.5, c="black")
plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()
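Conceptually, a raster plot marks every nonzero entry of a [num_steps x num_neurons] array as a point at (time step, neuron index). The sketch below does this with NumPy and Axes.scatter on a hypothetical random spike array that is first flattened from 28 x 28 frames; it illustrates the idea rather than reproducing snntorch's source.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical spike frames: [num_steps x 28 x 28], 1 = spike
frames = (rng.random((100, 28, 28)) < 0.05).astype(int)

# Flatten the pixel dimensions into a single neuron axis: [100 x 784]
spikes = frames.reshape(frames.shape[0], -1)

# Row indices are time steps, column indices are neuron numbers
t_idx, n_idx = np.nonzero(spikes)

fig, ax = plt.subplots(figsize=(10, 5))
points = ax.scatter(t_idx, n_idx, s=1.5, c="black")
ax.set_xlabel("Time step")
ax.set_ylabel("Neuron Number")
```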
snntorch.spikeplot.spike_count(data, fig, ax, labels, num_steps=False, animate=False, interpolate=1, gridshader=True, interval=25, time_step=False)

Generate a horizontal bar plot of spike counts for a single forward pass, with the option to animate it.

Example:

import snntorch.spikeplot as splt
import matplotlib.pyplot as plt
from IPython.display import HTML

num_steps = 25

#  Use splt.spike_count to display the behavior of output neurons for a
#  single sample during the forward pass

#  spk_rec is a recording of output spikes across 25 time steps,
#  using batch_size=128
print(spk_rec.size())
# torch.Size([25, 128, 10])

#  We only need a single data sample
spk_results = spk_rec[:, 0, :].to('cpu')
print(spk_results.size())
# torch.Size([25, 10])

fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

#  Plot and save spike count histogram
#  (save before plt.show(), which may leave an empty figure behind)
splt.spike_count(spk_results, fig, ax, labels, num_steps=num_steps,
                 time_step=1e-3)
plt.savefig('hist2.png', dpi=300, bbox_inches='tight')
plt.show()

#  Animate and save spike count histogram on a fresh figure
fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
anim = splt.spike_count(spk_results, fig, ax, labels, animate=True,
                        interpolate=5, num_steps=num_steps, time_step=1e-3)
HTML(anim.to_html5_video())
anim.save("spike_bar.gif")

Parameters:
  • data (torch.Tensor) – Sample of spiking data across numerous time steps [num_steps x num_outputs]

  • fig (matplotlib.figure.Figure) – Top level container for all plot elements

  • ax (matplotlib.axes._subplots.AxesSubplot) –

    Contains additional figure elements and sets the coordinate system.

    E.g., fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))

  • labels (list) – List of strings of the names of the output labels. E.g., for MNIST, labels = ['0', '1', '2', ... , '9']

  • num_steps (int, optional) – Number of time steps to plot. If not specified, the number of entries in the first dimension of data will automatically be used, defaults to False

  • animate (bool, optional) – If True, return a matplotlib.animation.ArtistAnimation that sequentially scans across the range of time steps available in data. If False, display a plot of the final step once all spikes have been counted, defaults to False

  • interpolate (int, optional) – Interpolation factor that can be increased to smooth the animation of the vertical time bar. For example, interpolate=1 adds no interpolation, while interpolate=5 adds 4 frames for each time step, defaults to 1

  • gridshader (bool, optional) – Applies shading to the figure background to distinguish output classes, defaults to True

  • interval (int, optional) – Delay between frames in milliseconds, defaults to 25

  • time_step (float, optional) – Duration of each time step in seconds. If False, the time axis is in units of time steps (num_steps). Otherwise, the time axis is scaled by the value passed, defaults to False

Returns:

animation to be displayed using matplotlib.pyplot.show()

Return type:

ArtistAnimation (if animate is True)
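The quantity that spike_count visualizes can be sketched as a running sum of spikes along the time dimension, with the final row giving the bar lengths. The example uses NumPy and Axes.barh on a hypothetical random [num_steps x num_outputs] array; grid shading, interpolation, the time bar, and animation are omitted, so treat it as an illustration of the counting step only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
num_steps, num_outputs = 25, 10
# Hypothetical output spikes for one sample: [num_steps x num_outputs]
spk_results = (rng.random((num_steps, num_outputs)) < 0.3).astype(int)
labels = [str(i) for i in range(num_outputs)]

# Row t holds each output neuron's spike count after time step t
cumulative = np.cumsum(spk_results, axis=0)
final_counts = cumulative[-1]

fig, ax = plt.subplots(facecolor="w", figsize=(12, 7))
bars = ax.barh(labels, final_counts)  # one horizontal bar per output class
ax.set_xlabel("Spike count")
```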

snntorch.spikeplot.traces(data, spk=None, dim=(3, 3), spk_height=5, titles=None)

Plot an array of neuron traces (e.g., membrane potential or synaptic current), optionally with spikes superimposed on the traces. traces was originally written by Friedemann Zenke.

Example:

import snntorch.spikeplot as splt
import matplotlib.pyplot as plt

#  mem_rec contains the membrane potential traces of 9 neurons across
#  100 time steps
print(mem_rec.size())
# torch.Size([100, 9])

#  Plot
splt.traces(mem_rec, dim=(3, 3))
plt.show()

Parameters:
  • data (torch.Tensor) – Data tensor for neuron traces across time steps of shape [num_steps x num_neurons]

  • spk (torch.Tensor, optional) – Spike data tensor of shape [num_steps x num_neurons] to superimpose on the traces, defaults to None

  • dim (tuple, optional) – Grid dimensions (rows, columns) of the array of subplots, defaults to (3, 3)

  • spk_height (float, optional) – Height of plotted spikes, defaults to 5

  • titles (list of strings, optional) – Adds subplot titles, defaults to None
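One way to superimpose spikes on traces, as the spk and spk_height parameters describe, is to add spk_height to the trace wherever a spike occurs and plot each neuron in its own cell of a dim-shaped grid. This is a hedged sketch with random stand-in data in NumPy; snntorch's actual layout and styling may differ.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
num_steps, num_neurons = 100, 9
# Hypothetical traces and spikes, both [num_steps x num_neurons]
mem = rng.random((num_steps, num_neurons))
spk = (mem > 0.95).astype(float)

dim = (3, 3)      # grid of subplots: rows x columns
spk_height = 5    # value added to the trace at each spike

fig, axes = plt.subplots(*dim, sharex=True, sharey=True)
for i, ax in enumerate(axes.flat):
    # The trace jumps by spk_height wherever neuron i spiked
    trace = mem[:, i] + spk_height * spk[:, i]
    ax.plot(trace)
    ax.set_title(f"Neuron {i}", fontsize=8)
```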