snntorch.spikeplot
snntorch.spikeplot is deeply integrated with matplotlib.pyplot and celluloid. It reduces the boilerplate code needed to generate a variety of animations and plots.
- class snntorch.spikeplot.Camera(figure: Figure)
Bases: object
Make animations easier.
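A camera of this kind takes a snapshot after each drawing step. Each snapshot records the artists added to the axes since the previous one, and each group becomes one frame of the final animation. The pure-Python sketch below models only that bookkeeping. ToyAxes and ToyCamera are hypothetical stand-ins, not part of snntorch or matplotlib, and nothing is actually drawn.

```python
from typing import List


class ToyAxes:
    """Hypothetical stand-in for a matplotlib Axes: it only stores artist labels."""

    def __init__(self) -> None:
        self.artists: List[str] = []

    def draw(self, label: str) -> None:
        # Drawing appends a new artist; earlier artists stay on the axes.
        self.artists.append(label)


class ToyCamera:
    """Hypothetical sketch of snapshot bookkeeping: each snap() captures new artists."""

    def __init__(self, axes: ToyAxes) -> None:
        self._axes = axes
        self._offset = 0  # number of artists already assigned to earlier frames
        self.frames: List[List[str]] = []

    def snap(self) -> List[str]:
        # Everything added since the previous snapshot forms the next frame.
        new_artists = self._axes.artists[self._offset:]
        self.frames.append(list(new_artists))
        self._offset = len(self._axes.artists)
        return new_artists


ax = ToyAxes()
camera = ToyCamera(ax)
ax.draw("image at t=0")
camera.snap()
ax.draw("image at t=1")
ax.draw("title at t=1")
camera.snap()
print(camera.frames)  # [['image at t=0'], ['image at t=1', 'title at t=1']]
```

In a celluloid-style camera, the frames would hold real matplotlib artists and be passed to matplotlib.animation.ArtistAnimation.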
- snntorch.spikeplot.animator(data, fig, ax, num_steps=False, interval=40, cmap='plasma')
Generate an animation by looping through the first dimension of a sample of spiking data. Time must be the first dimension of data.
Example:
```python
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

# spike_data contains 128 samples, each 100 time steps in duration
print(spike_data.size())
# torch.Size([100, 128, 1, 28, 28])

# Index into a single sample from the minibatch (and drop the channel dimension)
spike_data_sample = spike_data[:, 0, 0]
print(spike_data_sample.size())
# torch.Size([100, 28, 28])

# Plot
fig, ax = plt.subplots()
anim = splt.animator(spike_data_sample, fig, ax)
HTML(anim.to_html5_video())  # display inline in a Jupyter notebook

# Save as a gif
anim.save("spike_mnist.gif")
```
- Parameters:
data (torch.Tensor) – Data tensor for a single sample across time steps of shape [num_steps x input_size]
fig (matplotlib.figure.Figure) – Top level container for all plot elements
ax (matplotlib.axes._subplots.AxesSubplot) –
Contains additional figure elements and sets the coordinate system. E.g.:
fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
num_steps (int, optional) –
Number of time steps to plot. If not specified, the number of entries in the first dimension of data will automatically be used, defaults to False
interval (int, optional) – Delay between frames in milliseconds, defaults to 40
cmap (str, optional) – Matplotlib colormap name, defaults to 'plasma'
- Returns:
Animation to be displayed using matplotlib.pyplot.show()
- Return type:
ArtistAnimation
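The num_steps default and the interval parameter can be illustrated without any plotting. The sketch below assumes a time-first sequence stored as a plain Python list. It selects the frames that would be drawn and estimates playback length as num_steps × interval. animation_frames is a hypothetical helper, not part of snntorch.

```python
from typing import List, Sequence, Tuple


def animation_frames(data: Sequence, num_steps=False, interval: int = 40) -> Tuple[List, int]:
    """Hypothetical helper: pick one frame per time step from time-first data.

    Returns the frames to draw and an approximate playback length in milliseconds.
    """
    # Mirror the documented default: a falsy num_steps means "use the first dimension".
    if not num_steps:
        num_steps = len(data)
    # Time is the first dimension, so data[t] is the whole frame at step t.
    frames = [data[t] for t in range(num_steps)]
    # Each frame is shown for roughly `interval` milliseconds.
    return frames, num_steps * interval


# 100 time steps of a tiny 2x2 "image" (all zeros, for illustration)
sample = [[[0, 0], [0, 0]] for _ in range(100)]

frames, duration_ms = animation_frames(sample)
print(len(frames), duration_ms)  # 100 4000

frames, duration_ms = animation_frames(sample, num_steps=25, interval=40)
print(len(frames), duration_ms)  # 25 1000
```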
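- snntorch.spikeplot.raster(data, ax, **kwargs)
Generate a raster plot using matplotlib.pyplot.scatter. Each nonzero entry of a [num_steps x num_neurons] data tensor is drawn as a point at (time step, neuron index). Extra keyword arguments are passed through to scatter.
Example: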
```python
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt

# spike_data contains 128 samples, each 100 time steps in duration
print(spike_data.size())
# torch.Size([100, 128, 1, 28, 28])

# Index into a single sample from the minibatch
spike_data_sample = spike_data[:, 0, 0]
print(spike_data_sample.size())
# torch.Size([100, 28, 28])

# raster expects [num_steps x num_neurons], so flatten each 28x28 frame
spike_data_sample = spike_data_sample.reshape(spike_data_sample.size(0), -1)
print(spike_data_sample.size())
# torch.Size([100, 784])

fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)

# s: size of scatter points; c: color of scatter points
splt.raster(spike_data_sample, ax, s=1.5, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()
```
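The coordinates a raster plot needs are simply the indices of every nonzero entry. The sketch below assumes spikes stored as nested Python lists. It flattens each [H x W] frame into one row and then collects (time step, neuron index) pairs in row-major order. For a 2-D tensor, torch.where(spikes) returns index tensors in the same order. flatten_time_major and raster_points are hypothetical helpers.

```python
from typing import List, Tuple


def flatten_time_major(frames: List[List[List[int]]]) -> List[List[int]]:
    """Flatten each [H x W] frame into one row, giving [num_steps x (H*W)]."""
    return [[pixel for row in frame for pixel in row] for frame in frames]


def raster_points(spikes: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Return (time_steps, neuron_ids) for every nonzero entry, in row-major order."""
    times, neurons = [], []
    for t, row in enumerate(spikes):
        for n, value in enumerate(row):
            if value:
                times.append(t)    # x-coordinate: time step
                neurons.append(n)  # y-coordinate: neuron index
    return times, neurons


# Two time steps of a 2x2 "image"
frames = [[[1, 0], [0, 1]],
          [[0, 0], [1, 0]]]
flat = flatten_time_major(frames)  # [[1, 0, 0, 1], [0, 0, 1, 0]]
times, neurons = raster_points(flat)
print(times, neurons)  # [0, 0, 1] [0, 3, 2]
```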
- snntorch.spikeplot.spike_count(data, fig, ax, labels, num_steps=False, animate=False, interpolate=1, gridshader=True, interval=25, time_step=False)
Generate a horizontal bar plot of the spike count of each output neuron for a single forward pass. An animated version is also available.
Example:
```python
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

num_steps = 25

# Use splt.spike_count to display the behavior of output neurons for a single sample during the forward pass
# spk_rec is a recording of output spikes across 25 time steps, using batch_size=128
# (if spk_rec is a list of per-step tensors, stack it first: spk_rec = torch.stack(spk_rec, dim=0))
print(spk_rec.size())
# torch.Size([25, 128, 10])

# We only need a single data sample
spk_results = spk_rec[:, 0, :].to("cpu")
print(spk_results.size())
# torch.Size([25, 10])

fig, ax = plt.subplots(facecolor="w", figsize=(12, 7))
labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

# Plot and save the spike count histogram (save before show, which may close the figure)
splt.spike_count(spk_results, fig, ax, labels, num_steps=num_steps, time_step=1e-3)
plt.savefig("hist2.png", dpi=300, bbox_inches="tight")
plt.show()

# Animate and save the spike count histogram on a fresh figure
fig, ax = plt.subplots(facecolor="w", figsize=(12, 7))
anim = splt.spike_count(spk_results, fig, ax, labels, animate=True, interpolate=5, num_steps=num_steps, time_step=1e-3)
HTML(anim.to_html5_video())
anim.save("spike_bar.gif")
```
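The bar lengths are spike counts accumulated over time. The sketch below shows that bookkeeping, assuming spikes given as a [num_steps x num_outputs] list of 0/1 values. The snapshot after the final step corresponds to the static (animate=False) plot. The intermediate snapshots roughly correspond to animation frames. cumulative_spike_counts is a hypothetical helper. Reading the predicted class as the label with the most spikes assumes a rate-coded readout.

```python
from typing import List


def cumulative_spike_counts(spikes: List[List[int]]) -> List[List[int]]:
    """Hypothetical helper: running spike count per output neuron after each step."""
    counts = [0] * len(spikes[0])
    history = []
    for step in spikes:
        # Add this step's spikes to the running total of each output neuron.
        counts = [c + s for c, s in zip(counts, step)]
        history.append(counts)
    return history


# 4 time steps, 3 output neurons (e.g., classes "0", "1", "2")
spk = [
    [0, 1, 0],
    [0, 1, 1],
    [1, 0, 0],
    [0, 1, 0],
]
history = cumulative_spike_counts(spk)
print(history[-1])  # [1, 3, 1]

# Rate-coded readout: the label whose neuron fired most often
labels = ["0", "1", "2"]
predicted = labels[history[-1].index(max(history[-1]))]
print(predicted)  # 1
```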
- Parameters:
data (torch.Tensor) – Sample of spiking data across numerous time steps [num_steps x num_outputs]
fig (matplotlib.figure.Figure) – Top level container for all plot elements
ax (matplotlib.axes._subplots.AxesSubplot) –
Contains additional figure elements and sets the coordinate system.
E.g., fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
labels (list) – List of strings naming the output labels. E.g., for MNIST,
labels = ['0', '1', '2', ... , '9']
num_steps (int, optional) –
Number of time steps to plot. If not specified, the number of entries in the first dimension of data will automatically be used, defaults to False
animate (bool, optional) –
If True, return a matplotlib.animation.ArtistAnimation that scans sequentially across the range of time steps available in data. If False, display a plot of the final step once all spikes have been counted, defaults to False
interpolate (int, optional) –
Can be increased to smooth the animation of the vertical time bar. The value passed is the interpolation factor: e.g., interpolate=1 results in no additional interpolation, while interpolate=5 results in 4 additional frames for each time step, defaults to 1
gridshader (bool, optional) – Applies shading to the figure background to distinguish output classes, defaults to True
interval (int, optional) – Delay between frames in milliseconds, defaults to 25
time_step (float, optional) – Duration of each time step in seconds. If False, the time axis is in units of time steps (num_steps). Otherwise, the time axis is scaled by the value passed, defaults to False
- Returns:
Animation to be displayed using matplotlib.pyplot.show() (if animate is True)
- Return type:
ArtistAnimation (if animate is True)
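The interpolate and time_step parameters only change where the vertical time bar is drawn and how the time axis is labelled. The sketch below is one plausible reading of the parameter descriptions above, not snntorch's code. Each time step is split into interpolate evenly spaced positions, so interpolate=5 adds 4 frames per step. Positions are multiplied by time_step when one is given. time_bar_positions is a hypothetical helper.

```python
from typing import List


def time_bar_positions(num_steps: int, interpolate: int = 1, time_step=False) -> List[float]:
    """Hypothetical helper: x-positions of the time bar, one per animation frame."""
    positions = []
    for t in range(num_steps):
        # interpolate=1 gives one frame per step; interpolate=k adds k-1 in-between frames.
        for k in range(interpolate):
            positions.append(t + k / interpolate)
    # With time_step=False the axis stays in units of time steps; otherwise scale to seconds.
    if time_step:
        positions = [p * time_step for p in positions]
    return positions


print(len(time_bar_positions(25)))                 # 25
print(len(time_bar_positions(25, interpolate=5)))  # 125
print(time_bar_positions(2, interpolate=2))        # [0.0, 0.5, 1.0, 1.5]
```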
- snntorch.spikeplot.traces(data, spk=None, dim=(3, 3), spk_height=5, titles=None)
Plot an array of neuron traces (e.g., membrane potential or synaptic current). Optionally overlay spikes so that they ride on the traces. traces was originally written by Friedemann Zenke.
Example:
```python
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt

# mem_rec contains the membrane potential traces of 9 neurons across 100 time steps
print(mem_rec.size())
# torch.Size([100, 9])

# Plot the 9 traces on a 3x3 grid
splt.traces(mem_rec, dim=(3, 3))
plt.show()
```
- Parameters:
data (torch.Tensor) – Data tensor for neuron traces across time steps of shape [num_steps x num_neurons]
spk (torch.Tensor, optional) – Spike tensor of the same shape as data, [num_steps x num_neurons], whose spikes are drawn on top of the traces, defaults to None
dim (tuple, optional) – Dimensions of the subplot grid as (rows, columns), defaults to (3, 3)
spk_height (float, optional) – Height of the plotted spikes, defaults to 5
titles (list of str, optional) – Subplot titles, defaults to None
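The sketch below illustrates the two ideas behind traces, assuming plain Python lists. First, spikes are overlaid by adding spk_height to the trace wherever a spike occurs, so the spike appears to ride on the trace. Second, neuron i is placed in cell (i // columns, i % columns) of a dim = (rows, columns) grid. overlay_spikes and grid_cell are hypothetical helpers, and snntorch's actual implementation may differ.

```python
from typing import List, Tuple


def overlay_spikes(trace: List[float], spikes: List[int], spk_height: float = 5) -> List[float]:
    """Hypothetical helper: lift the trace by spk_height at every spike."""
    return [v + spk_height * s for v, s in zip(trace, spikes)]


def grid_cell(neuron: int, dim: Tuple[int, int] = (3, 3)) -> Tuple[int, int]:
    """Hypothetical helper: (row, column) of a neuron's subplot in a row-major grid."""
    rows, cols = dim
    if not 0 <= neuron < rows * cols:
        raise ValueError(f"neuron {neuron} does not fit in a {rows}x{cols} grid")
    return neuron // cols, neuron % cols


mem = [0.2, 0.6, 1.0, 0.0]
spk = [0, 0, 1, 0]
print(overlay_spikes(mem, spk))  # [0.2, 0.6, 6.0, 0.0]
print(grid_cell(4))              # (1, 1)
print(grid_cell(8))              # (2, 2)
```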