snntorch.spikeplot
snntorch.spikeplot is deeply integrated with matplotlib.pyplot and celluloid.
It serves to reduce the amount of boilerplate code required to generate a variety of animations and plots.
- class snntorch.spikeplot.Camera(figure: Figure)
Bases: object
Make animations easier.
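A celluloid-style camera works by taking a "snapshot" of the artists drawn since the previous snapshot, then handing the list of snapshots to matplotlib's ArtistAnimation, which shows one snapshot per frame. The sketch below is a simplified, hypothetical stand-in written against matplotlib only. MiniCamera and its snap()/animate() methods are illustrative names and are not claimed to match snntorch's Camera line for line.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation


class MiniCamera:
    """Hypothetical, simplified celluloid-style camera (illustration only)."""

    def __init__(self, figure):
        self._figure = figure
        self._frames = []   # one list of artists per snapped frame
        self._seen = set()  # ids of artists already assigned to an earlier frame

    def snap(self):
        # Collect every artist added to the figure's axes since the last snap().
        frame = []
        for ax in self._figure.axes:
            for artist in [*ax.lines, *ax.images, *ax.collections,
                           *ax.patches, *ax.texts]:
                if id(artist) not in self._seen:
                    self._seen.add(id(artist))
                    frame.append(artist)
        self._frames.append(frame)
        return frame

    def animate(self, **kwargs):
        # ArtistAnimation displays each stored list of artists as one frame.
        return ArtistAnimation(self._figure, self._frames, **kwargs)


# Usage: draw one line per "time step" and snap after each one.
fig, ax = plt.subplots()
camera = MiniCamera(fig)
for t in range(5):
    ax.plot([0, 1], [0, t], color="black")
    camera.snap()
anim = camera.animate(interval=40)  # 40 ms between frames
```

Because every artist stays on the axes, the animation relies on ArtistAnimation to hide all artists except those of the current frame; this is what lets plotting code stay a plain loop of draw-then-snap.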
- snntorch.spikeplot.animator(data, fig, ax, num_steps=False, interval=40, cmap='plasma')
Generate an animation by looping through the first dimension of a sample of spiking data. Time must be the first dimension of data.
Example:
    import snntorch.spikeplot as splt
    import matplotlib.pyplot as plt
    from IPython.display import HTML

    # spike_data contains 128 samples, each of 100 time steps in duration
    print(spike_data.size())
    >>> torch.Size([100, 128, 1, 28, 28])

    # Index into a single sample from a minibatch
    spike_data_sample = spike_data[:, 0, 0]
    print(spike_data_sample.size())
    >>> torch.Size([100, 28, 28])

    # Plot
    fig, ax = plt.subplots()
    anim = splt.animator(spike_data_sample, fig, ax)
    HTML(anim.to_html5_video())

    # Save as a gif
    anim.save("spike_mnist.gif")
- Parameters:
data (torch.Tensor) – Data tensor for a single sample across time steps of shape [num_steps x input_size]
fig (matplotlib.figure.Figure) – Top level container for all plot elements
ax (matplotlib.axes._subplots.AxesSubplot) – Contains additional figure elements and sets the coordinate system. E.g.:
fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
num_steps (int, optional) – Number of time steps to plot. If not specified, the number of entries in the first dimension of data will automatically be used, defaults to False
interval (int, optional) – Delay between frames in milliseconds, defaults to 40
cmap (string, optional) – Color map, defaults to 'plasma'
- Returns:
Animation to be displayed using matplotlib.pyplot.show()
- Return type:
FuncAnimation
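The indexing in the animator example and the num_steps default can be checked without any plotting. The NumPy sketch below uses a random stand-in for spike_data (with a batch of 4 rather than 128, to keep it light) and a hypothetical helper, frames_from_sample, that mirrors the documented behaviour: one frame per entry along the first (time) dimension unless num_steps says otherwise. It is an illustration, not snntorch's implementation.

```python
import numpy as np


def frames_from_sample(data, num_steps=False):
    """Hypothetical helper: split a [num_steps x ...] array into per-step frames."""
    # Documented default: num_steps=False means "use every entry along the
    # first (time) dimension of data".
    if not num_steps:
        num_steps = data.shape[0]
    return [data[t] for t in range(num_steps)]


rng = np.random.default_rng(0)
# Stand-in for spike_data: 100 time steps, batch of 4, 1 channel, 28x28 pixels.
spike_data = (rng.random((100, 4, 1, 28, 28)) < 0.1).astype(np.float32)

# Same indexing as the example: first sample of the batch, first channel.
spike_data_sample = spike_data[:, 0, 0]  # shape (100, 28, 28)

frames = frames_from_sample(spike_data_sample)
interval = 40                        # ms between frames (animator's default)
total_ms = len(frames) * interval    # playback length of the animation
```

With 100 time steps and the default 40 ms interval, each 28 x 28 frame is shown in turn and the animation lasts 4 seconds.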
- snntorch.spikeplot.raster(data, ax, **kwargs)
Generate a raster plot using matplotlib.pyplot.scatter.
Example:
    import snntorch.spikeplot as splt
    import matplotlib.pyplot as plt

    # spike_data contains 128 samples, each of 100 time steps in duration
    print(spike_data.size())
    >>> torch.Size([100, 128, 1, 28, 28])

    # Index into a single sample from a minibatch
    spike_data_sample = spike_data[:, 0, 0]
    print(spike_data_sample.size())
    >>> torch.Size([100, 28, 28])

    # Flatten the 28 x 28 pixels so each time step is a row of 784 neurons
    spike_data_sample = spike_data_sample.reshape((100, -1))

    fig = plt.figure(facecolor="w", figsize=(10, 5))
    ax = fig.add_subplot(111)

    # s: size of scatter points; c: color of scatter points
    splt.raster(spike_data_sample, ax, s=1.5, c="black")

    plt.title("Input Layer")
    plt.xlabel("Time step")
    plt.ylabel("Neuron Number")
    plt.show()
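A raster plot is a scatter of (time step, neuron index) pairs at which a spike occurred. The NumPy sketch below shows that conversion for a [num_steps x num_neurons] array, plus the flattening of a [num_steps, 28, 28] image sample into 784 "neurons" per step. raster_coordinates is a hypothetical helper for illustration; it is not claimed to be how snntorch implements raster.

```python
import numpy as np


def raster_coordinates(spikes):
    """Return (time_steps, neuron_ids) of every spike in a 2-D spike array."""
    if spikes.ndim != 2:
        raise ValueError("expected a 2-D [num_steps x num_neurons] array")
    # np.nonzero returns the row (time) and column (neuron) index of each spike.
    t_idx, n_idx = np.nonzero(spikes)
    return t_idx, n_idx


# 3 time steps x 4 neurons
spikes = np.array([[0, 1, 0, 0],
                   [1, 0, 0, 1],
                   [0, 0, 0, 1]])
t_idx, n_idx = raster_coordinates(spikes)

# An image-shaped sample must be flattened first: [5, 28, 28] -> [5, 784].
image_sample = np.zeros((5, 28, 28))
flat = image_sample.reshape(image_sample.shape[0], -1)
```

Passing t_idx and n_idx to ax.scatter together with keyword arguments such as s=1.5 and c="black" would then draw one dot per spike, with time on the x-axis and neuron number on the y-axis.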
- snntorch.spikeplot.spike_count(data, fig, ax, labels, num_steps=False, animate=False, interpolate=1, gridshader=True, interval=25, time_step=False)
Generate a horizontal bar plot for a single forward pass. Options to animate are also available.
Example:
    import snntorch.spikeplot as splt
    import matplotlib.pyplot as plt
    from IPython.display import HTML

    num_steps = 25

    # Use splt.spike_count to display behavior of output neurons for a single sample during feedforward
    # spk_rec is a recording of output spikes across 25 time steps, using batch_size=128
    print(spk_rec.size())
    >>> torch.Size([25, 128, 10])

    # We only need a single data sample; spk_rec is already a tensor, so index it directly
    spk_results = spk_rec[:, 0, :].to('cpu')
    print(spk_results.size())
    >>> torch.Size([25, 10])

    fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
    labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    # Plot and save spike count histogram (save before show, which may leave the figure empty)
    splt.spike_count(spk_results, fig, ax, labels, num_steps=num_steps, time_step=1e-3)
    plt.savefig('hist2.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Animate and save spike count histogram
    anim = splt.spike_count(spk_results, fig, ax, labels, animate=True, interpolate=5, num_steps=num_steps, time_step=1e-3)
    HTML(anim.to_html5_video())
    anim.save("spike_bar.gif")
- Parameters:
data (torch.Tensor) – Sample of spiking data across numerous time steps [num_steps x num_outputs]
fig (matplotlib.figure.Figure) – Top level container for all plot elements
ax (matplotlib.axes._subplots.AxesSubplot) –
Contains additional figure elements and sets the coordinate system.
E.g., fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
labels (list) – List of strings of the names of the output labels. E.g., for MNIST, labels = ['0', '1', '2', ... , '9']
num_steps (int, optional) – Number of time steps to plot. If not specified, the number of entries in the first dimension of data will automatically be used, defaults to False
animate (bool, optional) – If True, return a matplotlib.animation.ArtistAnimation that sequentially scans across the range of time steps available in data. If False, display a plot of the final step once all spikes have been counted, defaults to False
interpolate (int, optional) – Can be increased to smooth the animation of the vertical time bar. The value passed is the interpolation factor: e.g., interpolate=1 results in no additional interpolation, while interpolate=5 results in 4 additional frames for each time step, defaults to 1
gridshader (bool, optional) – Applies shading to the figure background to distinguish output classes, defaults to True
interval (int, optional) – Delay between frames in milliseconds, defaults to 25
time_step (float, optional) – Duration of each time step in seconds. If False, the time axis will be in terms of num_steps. Otherwise, the time axis is scaled by the argument passed, defaults to False
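The interpolate and time_step parameters together decide where the vertical time bar sits in each animation frame. The sketch below is a hypothetical reconstruction of the documented behaviour only (time_bar_positions is an illustrative name, not part of snntorch): each time step contributes interpolate frames, i.e. the step itself plus interpolate - 1 evenly spaced in-between positions, and a truthy time_step rescales the positions from steps to seconds.

```python
def time_bar_positions(num_steps, interpolate=1, time_step=False):
    """Hypothetical sketch of the time-bar position in each animation frame."""
    if interpolate < 1:
        raise ValueError("interpolate must be >= 1")
    # interpolate=1 -> one frame per step; interpolate=5 -> 4 extra frames per step.
    positions = [t + j / interpolate
                 for t in range(num_steps)
                 for j in range(interpolate)]
    # time_step=False keeps the axis in units of time steps; otherwise scale to seconds.
    if time_step:
        positions = [p * time_step for p in positions]
    return positions


smooth = time_bar_positions(25, interpolate=5)                  # 125 frames
seconds = time_bar_positions(2, interpolate=2, time_step=1e-3)  # axis in seconds
```

Under these assumptions, the animated example (num_steps=25, interpolate=5) has 125 frames, and at the default 25 ms interval it would play for about 3.1 seconds.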
- Returns:
Animation to be displayed using matplotlib.pyplot.show()
- Return type:
FuncAnimation (if animate is True)
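With animate=False, each bar shows the total number of spikes one output neuron emitted over the sample; an animation steps through the running totals instead. The NumPy sketch below computes both from a [num_steps x num_outputs] spike array. running_spike_counts is a hypothetical helper, and reading the most active neuron as the predicted class is a common spike-count decoding convention rather than something spike_count itself is documented to compute.

```python
import numpy as np


def running_spike_counts(spk):
    """Cumulative spike count per output neuron after each time step.

    spk: [num_steps x num_outputs] array of 0/1 spikes.
    Row t holds the counts accumulated over steps 0..t.
    """
    return np.cumsum(spk, axis=0)


spk_results = np.array([[1, 0, 0],
                        [1, 1, 0],
                        [0, 0, 1],
                        [1, 0, 0]])  # 4 time steps, 3 output neurons

counts = running_spike_counts(spk_results)
final_counts = counts[-1]                 # bar heights once all spikes are counted
predicted = int(np.argmax(final_counts))  # neuron with the most spikes
```

Each row of counts corresponds to one time step of the animated bar plot, and the last row matches the static plot.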
- snntorch.spikeplot.traces(data, spk=None, dim=(3, 3), spk_height=5, titles=None)
Plot an array of neuron traces (e.g., membrane potential or synaptic current). Optionally, overlay spikes so that they ride on the traces. The traces function was originally written by Friedemann Zenke.
Example:
    import snntorch.spikeplot as splt

    # mem_rec contains the traces of 9 neuron membrane potentials across 100 time steps in duration
    print(mem_rec.size())
    >>> torch.Size([100, 9])

    # Plot
    splt.traces(mem_rec, dim=(3, 3))
- Parameters:
data (torch.Tensor) – Data tensor for neuron traces across time steps of shape [num_steps x num_neurons]
spk (torch.Tensor, optional) – Data tensor of neuron spikes across time steps of shape [num_steps x num_neurons], defaults to None
dim (tuple, optional) – Dimensions of the subplot grid, defaults to (3, 3)
spk_height (float, optional) – Height of spike to plot, defaults to 5
titles (list of strings, optional) – Adds subplot titles, defaults to None
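The traces example plots 9 neurons on a dim=(3, 3) grid, optionally with spikes raised to spk_height. The NumPy sketch below illustrates both ideas under stated assumptions. overlay_spikes assumes spikes are added on top of the trace (the actual function may instead set or clamp the value at spike times), and subplot_position assumes the grid is filled row by row. Both helper names are hypothetical.

```python
import numpy as np


def overlay_spikes(data, spk=None, spk_height=5.0):
    """Hypothetical sketch: raise a trace by spk_height wherever a spike occurs."""
    data = np.asarray(data, dtype=float)
    if spk is None:
        return data.copy()
    # Assumption: additive overlay, so a spike appears as a tall peak on the trace.
    return data + np.asarray(spk, dtype=float) * spk_height


def subplot_position(neuron, dim=(3, 3)):
    """Grid cell (row, col) for a neuron index, assuming row-major filling."""
    rows, cols = dim
    if not 0 <= neuron < rows * cols:
        raise IndexError("a dim grid only has rows * cols panels")
    return divmod(neuron, cols)


mem = np.array([[0.2, 0.1],
                [0.9, 0.3],
                [0.0, 0.5]])  # 3 time steps x 2 neurons
spk = np.array([[0, 0],
                [1, 0],
                [0, 0]])      # neuron 0 spikes at step 1
shown = overlay_spikes(mem, spk)
```

Each column of data (one neuron's trace over time) would then be drawn in the panel given by subplot_position, so a (3, 3) grid holds exactly the 9 traces of the example.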