Source code for snntorch.spikevision.spikedata.dvs_gesture

# DVS Gesture citation:
# A. Amir, B. Taba, D. Berg, T. Melano,
# J. McKinstry, C. Di Nolfo, T. Nayak, A. Andreopoulos,
# G. Garreau, M. Mendoza, J. Kusnitz, M. Debole, S. Esser,
# T. Delbruck, M. Flickner, and D. Modha, "A Low Power,
# Fully Event-Based Gesture Recognition System,"
# 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR),
# Honolulu, HI, 2017.
# Dataloader adapted from
# https://github.com/nmi-lab/torchneuromorphic
# by Emre Neftci and Clemens Schaefer

import numpy as np
import h5py

from ..neuromorphic_dataset import NeuromorphicDataset
from ..events_timeslices import get_tmad_slice, get_slice
from .._transforms import (
    dvs_permute,
    Compose,
    Repeat,
    toOneHot,
    ToTensor,
    ToCountFrame,
    Downsample,
)
from .._utils import aedat_to_events
import os
from tqdm import tqdm
import glob


mapping = {
    0: "Hand Clapping",
    1: "Right Hand Wave",
    2: "Left Hand Wave",
    3: "Right Arm CW",
    4: "Right Arm CCW",
    5: "Left Arm CW",
    6: "Left Arm CCW",
    7: "Arm Roll",
    8: "Air Drums",
    9: "Air Guitar",
    10: "Other",
}
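

# Illustrative sketch (not part of the original module): ``mapping`` turns an
# integer class index into a human-readable gesture name. With the default
# ``toOneHot(11)`` target transform, one row of a target can be decoded as:
#
#     import numpy as np
#     one_hot_row = np.eye(11)[1]                  # hypothetical one-hot row
#     print(mapping[int(one_hot_row.argmax())])    # -> "Right Hand Wave"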


class DVSGesture(NeuromorphicDataset):
    """`DVS Gesture <https://www.research.ibm.com/dvsgesture/>`_ Dataset.

    The data was recorded using a DVS128. The dataset contains 11 hand
    gestures from 29 subjects under 3 illumination conditions.

    **Number of classes:** 11

    **Number of train samples:** 1176

    **Number of test samples:** 288

    **Dimensions:** ``[num_steps x 2 x 128 x 128]``

    * **num_steps:** time-dimension of event-based footage
    * **2:** number of channels (on-spikes for luminance increasing;
      off-spikes for luminance decreasing)
    * **128x128:** W x H spatial dimensions of event-based footage

    For further reading, see:

        *A. Amir, B. Taba, D. Berg, T. Melano, J. McKinstry, C. Di Nolfo,
        T. Nayak, A. Andreopoulos, G. Garreau, M. Mendoza, J. Kusnitz,
        M. Debole, S. Esser, T. Delbruck, M. Flickner, and D. Modha,
        "A Low Power, Fully Event-Based Gesture Recognition System,"
        2017 IEEE Conference on Computer Vision and Pattern Recognition
        (CVPR), Honolulu, HI, 2017.*


    Example::

        from snntorch.spikevision import spikedata

        train_ds = spikedata.DVSGesture("data/dvsgesture", train=True, num_steps=500, dt=1000)
        test_ds = spikedata.DVSGesture("data/dvsgesture", train=False, num_steps=1800, dt=1000)

        # by default, each time step is integrated over 1ms, or dt=1000 microseconds
        # dt can be changed to integrate events over a varying window
        # Note that num_steps should be scaled inversely by the same factor
        train_ds = spikedata.DVSGesture("data/dvsgesture", train=True, num_steps=250, dt=2000)
        test_ds = spikedata.DVSGesture("data/dvsgesture", train=False, num_steps=900, dt=2000)

    The dataset can also be manually downloaded, extracted, and placed
    into ``root``, which lets the dataloader skip the download and go
    straight to generating the HDF5 file.

    **Direct Download Links:**

        `IBM Box Link
        <https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794>`_

        `Dropbox Link
        <https://www.dropbox.com/s/cct5kyilhtsliup/DvsGesture.tar.gz?dl=0>`_

    :param root: Root directory of dataset.
    :type root: string

    :param train: If True, creates dataset from the training set of
        DVS Gesture; otherwise from the test set.
    :type train: bool, optional

    :param transform: A function/transform that takes in a sample and
        returns a transformed version. By default, a pre-defined set of
        transforms is applied to all samples to convert them into a
        time-first tensor with correct orientation.
    :type transform: callable, optional

    :param target_transform: A function/transform that takes in the
        target and transforms it.
    :type target_transform: callable, optional

    :param download_and_create: If True, downloads the dataset from the
        internet and puts it in the root directory. If the dataset is
        already downloaded, it is not downloaded again.
    :type download_and_create: bool, optional

    :param num_steps: Number of time steps, defaults to ``500`` for the
        train set, or ``1800`` for the test set.
    :type num_steps: int, optional

    :param dt: The integration window of each time step, in
        microseconds, defaults to ``1000``.
    :type dt: int, optional

    :param ds: Rescaling factor, defaults to ``1``.
    :type ds: int, optional

    :param return_meta: Option to return metadata, defaults to ``False``.
    :type return_meta: bool, optional

    :param time_shuffle: Option to randomize the start time of the
        sample, defaults to ``False``.
    :type time_shuffle: bool, optional

    Dataloader adapted from `torchneuromorphic
    <https://github.com/nmi-lab/torchneuromorphic>`_ originally by
    Emre Neftci and Clemens Schaefer.

    The dataset is released under a Creative Commons Attribution 4.0
    license. All rights remain with the original authors.
""" # _resources_url = [['Manually Download dataset here: # https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/ # file/211521748942?sb=/details and place under {0}'.format(directory), # None, 'DvsGesture.tar.gz']] _resources_url = [ [ "https://www.dropbox.com/s/cct5kyilhtsliup/DvsGesture.tar.gz?dl=1", None, "DvsGesture.tar.gz", ] ] # directory = 'data/dvsgesture/' def __init__( self, root, train=True, transform=None, target_transform=None, download_and_create=True, num_steps=None, dt=1000, ds=None, return_meta=False, time_shuffle=False, ): self.n = 0 self.download_and_create = download_and_create self.root = root self.train = train self.dt = dt self.return_meta = return_meta self.time_shuffle = time_shuffle self.hdf5_name = "dvs_gesture.hdf5" self.directory = root.split(self.hdf5_name)[0] self.resources_local = [self.directory + "/DvsGesture.tar.gz"] self.resources_local_extracted = [self.directory + "/DvsGesture/"] if ds is None: ds = 1 if isinstance(ds, int): ds = [ds, ds] size = [2, 128 // ds[0], 128 // ds[1]] # 128//ds[0], 128//ds[1] if num_steps is None: if self.train: self.num_steps = 500 else: self.num_steps = 1800 else: self.num_steps = num_steps if transform is None: transform = Compose( [ Downsample(factor=[self.dt, 1, ds[0], ds[1]]), ToCountFrame(T=self.num_steps, size=size), ToTensor(), dvs_permute(), ] ) if target_transform is not None: target_transform = Compose([Repeat(num_steps), toOneHot(11)]) super().__init__( root=root + "/" + self.hdf5_name, transform=transform, target_transform_train=target_transform, ) with h5py.File(self.root, "r", swmr=True, libver="latest") as f: if train: self.n = f["extra"].attrs["Ntrain"] self.keys = f["extra"]["train_keys"][()] else: self.n = f["extra"].attrs["Ntest"] self.keys = f["extra"]["test_keys"][()] def _download(self): try: _ = super(DVSGesture, self)._download() except Exception as e: print(e.message, e.args) def _create_hdf5(self): create_events_hdf5( self.directory, self.resources_local_extracted[0], self.directory + "/" + self.hdf5_name, ) def __len__(self): return self.n def __getitem__(self, key): # Important to open and close in getitem to enable num_workers>0 with h5py.File(self.root, "r", swmr=True, libver="latest") as f: if not self.train: key = key + f["extra"].attrs["Ntrain"] assert key in self.keys data, target, meta_info_light, meta_info_user = sample( f, key, T=self.num_steps, shuffle=self.time_shuffle, train=self.train, ) if self.transform is not None: data = self.transform(data) if self.target_transform is not None: target = self.target_transform(target) if self.return_meta is True: return data, target, meta_info_light, meta_info_user else: return data, target
def sample(hdf5_file, key, T=500, shuffle=False, train=True):
    # the slice length is fixed to the split's default duration;
    # T only bounds the random start time when shuffling
    T_default = 500 if train else 1800

    dset = hdf5_file["data"][str(key)]
    label = dset["labels"][()]
    tbegin = dset["times"][0]
    tend = np.maximum(0, dset["times"][-1] - 2 * T * 1000)

    # pick a random start time within the recording when shuffling
    start_time = np.random.randint(tbegin, tend + 1) if shuffle else 0

    tmad = get_tmad_slice(
        dset["times"][()], dset["addrs"][()], start_time, T_default * 1000
    )
    tmad[:, 0] -= tmad[0, 0]  # re-reference timestamps to the slice start
    meta = eval(dset.attrs["meta_info"])
    # reorder event columns from [t, x, y, p] to [t, p, x, y]
    return (
        tmad[:, [0, 3, 1, 2]],
        label,
        meta["light condition"],
        meta["subject"],
    )


def create_events_hdf5(directory, extracted_directory, hdf5_filename):
    # users 1-23 form the train set; users 24-29 form the test set
    fns_train = gather_aedat(directory, extracted_directory, 1, 24)
    fns_test = gather_aedat(directory, extracted_directory, 24, 30)
    test_keys = []
    train_keys = []

    assert len(fns_train) == 98

    with h5py.File(hdf5_filename, "w") as f:
        f.clear()
        key = 0
        metas = []
        data_grp = f.create_group("data")
        extra_grp = f.create_group("extra")
        print("\nCreating dvs_gesture.hdf5...")
        for file_d in tqdm(fns_train + fns_test):
            istrain = file_d in fns_train
            data, labels_starttime = aedat_to_events(file_d)
            tms = data[:, 0]
            ads = data[:, 1:]
            lbls = labels_starttime[:, 0]
            start_tms = labels_starttime[:, 1]
            end_tms = labels_starttime[:, 2]

            for i, v in enumerate(lbls):
                if istrain:
                    train_keys.append(key)
                else:
                    test_keys.append(key)
                # slice out the events belonging to gesture i
                s_ = get_slice(tms, ads, start_tms[i], end_tms[i])
                times = s_[0]
                addrs = s_[1]
                # filenames look like user01_fluorescent.aedat, so the
                # stem encodes subject and light condition
                subj, light = (
                    file_d.split("/")[-1].split(".")[0].split("_")[:2]
                )
                metas.append(
                    {
                        "key": str(key),
                        "subject": subj,
                        "light condition": light,
                        "training sample": istrain,
                    }
                )
                subgrp = data_grp.create_group(str(key))
                subgrp.create_dataset(
                    "times", data=times, dtype=np.uint32
                )
                subgrp.create_dataset(
                    "addrs", data=addrs, dtype=np.uint8
                )
                subgrp.create_dataset(
                    "labels", data=lbls[i] - 1, dtype=np.uint8
                )
                subgrp.attrs["meta_info"] = str(metas[-1])
                assert lbls[i] - 1 in range(11)
                key += 1

        extra_grp.create_dataset("train_keys", data=train_keys)
        extra_grp.create_dataset("test_keys", data=test_keys)
        extra_grp.attrs["N"] = len(train_keys) + len(test_keys)
        extra_grp.attrs["Ntrain"] = len(train_keys)
        extra_grp.attrs["Ntest"] = len(test_keys)
        print("dvs_gesture.hdf5 was created successfully.")


def gather_aedat(
    directory, extracted_directory, start_id, end_id, filename_prefix="user"
):
    if not os.path.isdir(directory):
        raise FileNotFoundError(
            "DVS Gestures Dataset not found, looked at: {}".format(directory)
        )

    fns = []
    for i in range(start_id, end_id):
        search_mask = (
            extracted_directory
            + "/"
            + filename_prefix
            + "{0:02d}".format(i)
            + "*.aedat"
        )
        glob_out = glob.glob(search_mask)
        if len(glob_out) > 0:
            fns += glob_out
    return fns
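
# Data-flow sketch (illustrative only; the HDF5 path below assumes the
# default ``root`` from the class docstring). ``sample`` returns one row
# per event, with columns ordered [t, p, x, y]; the default transform
# pipeline then bins these events into a dense count tensor:
#
#     with h5py.File("data/dvsgesture/dvs_gesture.hdf5", "r") as f:
#         tmad, label, light, subject = sample(f, key=0, T=500, train=True)
#     # tmad[:, 0]   -> microsecond timestamps, zeroed at the slice start
#     # tmad[:, 1]   -> polarity (on/off channel index)
#     # tmad[:, 2:4] -> x, y pixel addresses on the 128x128 sensor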