# Source code for snntorch.utils

# Note: need NumPy 1.17 or later for RNG functions
import numpy as np
import snntorch as snn


def data_subset(dataset, subset, idx=0):
    """Shrink ``dataset`` in place to one ``1/subset`` slice of its samples.

    The ``data`` and ``targets`` attributes are replaced by the contiguous
    slice number ``idx``, each slice holding ``len(dataset.data) // subset``
    samples. Because the attributes are reassigned rather than wrapped, the
    dataset keeps exposing ``data`` and ``targets`` afterwards. Samples left
    over when the length is not divisible by ``subset`` belong to no slice.
    If ``subset`` is not greater than 1 the dataset is returned untouched.

    Example::

        from snntorch import utils
        from torchvision import datasets

        data_path = "path/to/data"
        subset = 10

        # Download MNIST training set
        mnist_train = datasets.MNIST(data_path, train=True, download=True)
        print(len(mnist_train))
        >>> 60000

        # Reduce size of MNIST training set
        utils.data_subset(mnist_train, subset)
        print(len(mnist_train))
        >>> 6000

    :param dataset: Dataset exposing indexable ``data`` and ``targets``
        attributes
    :type dataset: torchvision dataset
    :param subset: Factor to reduce dataset by
    :type subset: int
    :param idx: Which ``1/subset`` slice to keep, defaults to ``0``
    :type idx: int, optional
    :return: The same ``dataset`` object, modified in place
    :rtype: torchvision dataset
    """
    # Guard clause: nothing to partition.
    if not subset > 1:
        return dataset

    total = len(dataset.data)
    chunk = total // subset
    start, stop = chunk * idx, chunk * (idx + 1)

    # Integer index array (not a plain slice) so data/targets are copied
    # via advanced indexing.
    selected = np.arange(total, dtype="int")[start:stop]

    dataset.data, dataset.targets = (
        dataset.data[selected],
        dataset.targets[selected],
    )
    return dataset
def valid_split(ds_train, ds_val, split, seed=0):
    """Randomly split a training set into disjoint train and validation sets.

    ``int(len(ds_train) * split)`` randomly chosen samples of ``ds_train``
    are moved into ``ds_val``, and the rest stay in ``ds_train``. Both
    objects are modified in place: their ``data`` and ``targets``
    attributes are reassigned, so the attributes remain accessible
    (unlike ``torch.utils.data.random_split``, which returns wrappers).
    Any samples ``ds_val`` held before the call are discarded.
    The generator is seeded with ``seed``, so a given seed always
    produces the same split.

    Example::

        from snntorch import utils
        from torchvision import datasets

        data_path = "path/to/data"
        val_split = 0.1

        # Download MNIST training set into mnist_val and mnist_train
        mnist_train = datasets.MNIST(data_path, train=True, download=True)
        mnist_val = datasets.MNIST(data_path, train=True, download=True)

        print(len(mnist_train))
        >>> 60000

        print(len(mnist_val))
        >>> 60000

        # Validation split
        mnist_train, mnist_val = utils.valid_split(mnist_train, mnist_val, val_split)

        print(len(mnist_train))
        >>> 54000

        print(len(mnist_val))
        >>> 6000

    :param ds_train: Training set; must support ``len`` and expose
        indexable ``data`` and ``targets`` attributes
    :type ds_train: torchvision dataset
    :param ds_val: Validation set; its ``data`` and ``targets`` are
        overwritten
    :type ds_val: torchvision dataset
    :param split: Proportion of samples assigned to the validation set
        from the training set
    :type split: Float
    :param seed: Fix to generate reproducible results, defaults to ``0``
    :type seed: int, optional
    :return: ``(ds_train, ds_val)``, the same objects that were passed in
    :rtype: tuple
    """
    n = len(ds_train)
    n_val = int(n * split)
    n_train = n - n_val

    # n_train distinct indices from [0, n), in random order. The call is
    # unchanged so that a given seed reproduces previously generated splits.
    rng = np.random.default_rng(seed=seed)
    train_idx = rng.choice(n, size=n_train, replace=False)

    # Validation indices are the complement of train_idx, in ascending
    # order. A boolean mask makes this O(n); testing ``i not in train_idx``
    # for every i scanned the whole array each time (O(n * n_train)).
    is_val = np.ones(n, dtype=bool)
    is_val[train_idx] = False
    val_idx = np.flatnonzero(is_val).tolist()

    # Build the validation set from the ORIGINAL training data before
    # ds_train is overwritten below.
    ds_val.data = ds_train.data[val_idx]
    ds_val.targets = ds_train.targets[val_idx]

    # Recreate ds_train by indexing into the previous ds_train.
    ds_train.data = ds_train.data[train_idx]
    ds_train.targets = ds_train.targets[train_idx]

    return ds_train, ds_val
def reset(net):
    """Reset and detach the hidden state of every neuron type used in ``net``.

    All module-level ``is_*`` neuron-type flags are cleared first. Next,
    ``_layer_check`` sets the flags for the neuron classes found among
    ``net``'s submodules. Finally, ``_layer_reset`` calls ``reset_hidden()``
    and ``detach_hidden()`` on each flagged class.

    :param net: Network whose neuron layers should be reset
    :type net: torch.nn.Module
    """
    global is_alpha, is_leaky, is_lapicque, is_rleaky
    global is_synaptic, is_rsynaptic, is_sconv2dlstm, is_slstm

    # _layer_check only ever sets flags to True, so clear them all here to
    # stop neuron types from an earlier call carrying over into this one.
    is_alpha = is_leaky = is_rleaky = is_synaptic = False
    is_rsynaptic = is_lapicque = is_sconv2dlstm = is_slstm = False

    _layer_check(net=net)
    _layer_reset()
def _layer_check(net):
    """Flag which snnTorch neuron types appear among ``net``'s direct submodules.

    For each module in ``net._modules`` that is an instance of a known
    neuron class, set the matching module-level ``is_*`` flag to ``True``.
    Flags are never cleared here (``reset`` clears them before calling
    this). Only the immediate entries of ``net._modules`` are inspected;
    modules nested deeper are not examined.

    :param net: Network to inspect
    :type net: torch.nn.Module
    """
    global is_leaky, is_lapicque, is_synaptic, is_alpha
    global is_rleaky, is_rsynaptic, is_sconv2dlstm, is_slstm

    # Walk the module dict once. The previous version rebuilt
    # ``list(net._modules.values())`` for every isinstance test.
    for module in net._modules.values():
        if isinstance(module, snn.Lapicque):
            is_lapicque = True
        if isinstance(module, snn.Synaptic):
            is_synaptic = True
        if isinstance(module, snn.Leaky):
            is_leaky = True
        if isinstance(module, snn.Alpha):
            is_alpha = True
        if isinstance(module, snn.RLeaky):
            is_rleaky = True
        if isinstance(module, snn.RSynaptic):
            is_rsynaptic = True
        if isinstance(module, snn.SConv2dLSTM):
            is_sconv2dlstm = True
        if isinstance(module, snn.SLSTM):
            is_slstm = True


def _layer_reset():
    """Reset and detach hidden state for every neuron class whose flag is set.

    The ``is_*`` flags are module-level globals; they must already be defined
    (``reset`` defines them) or a NameError is raised. The calls are made on
    the neuron *classes*, not on individual layer instances. Presumably this
    affects every instance of that class; confirm against the snnTorch
    neuron implementations.
    """
    if is_lapicque:
        snn.Lapicque.reset_hidden()  # reset hidden state to 0's
        snn.Lapicque.detach_hidden()
    if is_synaptic:
        snn.Synaptic.reset_hidden()  # reset hidden state to 0's
        snn.Synaptic.detach_hidden()
    if is_leaky:
        snn.Leaky.reset_hidden()  # reset hidden state to 0's
        snn.Leaky.detach_hidden()
    if is_alpha:
        snn.Alpha.reset_hidden()  # reset hidden state to 0's
        snn.Alpha.detach_hidden()
    if is_rleaky:
        snn.RLeaky.reset_hidden()  # reset hidden state to 0's
        snn.RLeaky.detach_hidden()
    if is_rsynaptic:
        snn.RSynaptic.reset_hidden()  # reset hidden state to 0's
        snn.RSynaptic.detach_hidden()
    if is_sconv2dlstm:
        snn.SConv2dLSTM.reset_hidden()  # reset hidden state to 0's
        snn.SConv2dLSTM.detach_hidden()
    if is_slstm:
        snn.SLSTM.reset_hidden()  # reset hidden state to 0's
        snn.SLSTM.detach_hidden()


def _final_layer_check(net):
    """Return how many outputs the last registered submodule of ``net`` yields.

    The last module must be one of the listed snnTorch neuron classes, which
    map to 2, 3 or 4 outputs. Anything else is assumed to be a plain
    ``torch.nn`` layer with a single output. The checks run in a fixed order
    and the first match wins.

    :param net: Network to inspect
    :type net: torch.nn.Module
    :return: Number of outputs of the final layer
    :rtype: int
    :raises IndexError: if ``net`` has no registered submodules
    """
    # Materialize the module list once instead of once per check.
    final_layer = list(net._modules.values())[-1]

    if isinstance(final_layer, snn.Lapicque):
        return 2
    if isinstance(final_layer, snn.Synaptic):
        return 3
    if isinstance(final_layer, snn.RSynaptic):
        return 3
    if isinstance(final_layer, snn.Leaky):
        return 2
    if isinstance(final_layer, snn.RLeaky):
        return 2
    if isinstance(final_layer, snn.SConv2dLSTM):
        return 3
    if isinstance(final_layer, snn.SLSTM):
        return 3
    if isinstance(final_layer, snn.Alpha):
        return 4
    # if not from snn, assume from nn with 1 return
    return 1