Source code for snntorch.functional.quant

import torch
import math


class StateQuant(torch.autograd.Function):
    """Autograd function that snaps a tensor to its nearest quantization level.

    The forward pass replaces every element of ``input_`` with the closest
    entry of ``levels``. The backward pass is a straight-through estimator
    (STE): the incoming gradient is passed through unchanged and no gradient
    is produced for ``levels``. ``state_quant`` returns a closure that applies
    this function with a fixed table of levels.
    """

    @staticmethod
    def forward(ctx, input_, levels):
        """Quantize ``input_`` to the nearest value in ``levels``.

        :param ctx: Autograd context (unused).
        :param input_: Tensor of any shape to quantize.
        :param levels: Tensor of candidate levels; ``state_quant`` builds it
            as a 1-D tensor.
        :return: Tensor with the same shape as ``input_`` whose elements are
            drawn from ``levels``. When ``input_`` is floating point, the
            result has the same dtype as ``input_``. When several levels are
            equally close, the one with the lowest index is chosen.
        """
        levels = levels.to(input_.device)
        # Broadcasting (*input_.shape, 1) against (num_levels,) gives the
        # distance from every element to every level without first building
        # a repeated copy of the input.
        distance = torch.abs(levels - input_.unsqueeze(-1))
        nearest = torch.min(distance, dim=-1).indices
        quant_tensor = levels[nearest]
        # The level table may use a different floating dtype than the state
        # (e.g. float32 levels for a float64/float16 state). Hand back the
        # caller's dtype instead of silently changing it.
        if input_.is_floating_point():
            quant_tensor = quant_tensor.to(input_.dtype)
        return quant_tensor

    # STE
    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass ``grad_output`` through unchanged.

        :return: ``(grad_input, None)``. ``levels`` receives no gradient.
        """
        return grad_output.clone(), None
# Built-in exponential bases for non-uniform quantization, keyed by num_bits.
# Widths above 7 bits all use 0.95.
_DEFAULT_MULTIPLIERS = {1: 0.05, 2: 0.1, 3: 0.3, 4: 0.5, 5: 0.7, 6: 0.9, 7: 0.925}


def _default_multiplier(num_bits):
    """Return the exponential base used when ``multiplier`` is ``None``."""
    if num_bits > 7:
        return 0.95
    return _DEFAULT_MULTIPLIERS.get(num_bits)


def _geometric_total(multiplier, count):
    """Return the sum of ``multiplier ** j`` for j = count-1, ..., 0."""
    # Accumulated left to right in this exact order (rather than with sum(),
    # which uses compensated summation for floats on Python >= 3.12) so the
    # generated levels stay bit-for-bit stable.
    total = 0
    for j in reversed(range(count)):
        total += multiplier ** j
    return total


def _nonuniform_levels(min_val, max_val, pivot, num_levels, multiplier):
    """Build ``num_levels`` exponentially spaced levels clustered at ``pivot``.

    The levels are split between ``[min_val, pivot)`` and ``(pivot, max_val]``
    in proportion to the widths of the two intervals. Within each interval,
    consecutive spacings follow a geometric progression with ratio
    ``multiplier``. For ``multiplier < 1`` (all built-in defaults), the levels
    become denser toward ``pivot``. ``min_val`` is always the first level and
    the last level lands on ``max_val``; ``pivot`` itself is not emitted.
    """
    lower_range = pivot - min_val  # span below the pivot
    upper_range = max_val - pivot  # span above the pivot
    overall_range = max_val - min_val
    lower_bits = math.floor(num_levels * (lower_range / overall_range))
    upper_bits = num_levels - lower_bits

    levels = []
    if lower_bits != 0:
        # Scale the geometric steps so they add up to exactly lower_range.
        room = _geometric_total(multiplier, lower_bits) / lower_range
        value = min_val
        for j in range(lower_bits):  # largest step first, shrinking toward pivot
            levels.append(value)
            value += multiplier ** j / room
    if upper_bits != 0:
        # Scale the geometric steps so they add up to exactly upper_range.
        room = _geometric_total(multiplier, upper_bits) / upper_range
        value = pivot
        for j in reversed(range(upper_bits)):  # smallest step first, growing
            value += multiplier ** j / room
            levels.append(value)
    return torch.tensor(levels)


def state_quant(
    num_bits=8,
    uniform=True,
    thr_centered=True,
    threshold=1,
    lower_limit=0,
    upper_limit=0.2,
    multiplier=None,
):
    """Quantization-Aware Training with spiking neuron states.

    **Note: for weight quantization, we recommend using Brevitas or another
    pre-existing PyTorch-friendly library.**

    Uniform quantization is applied when ``uniform=True``. Non-uniform
    (exponentially spaced) quantization is applied when ``uniform=False``.
    For non-uniform quantization, valid levels are clustered around the
    threshold when ``thr_centered=True``, and around 0 otherwise.

    ``upper_limit`` and ``lower_limit`` specify how far, as a proportion of
    the threshold, valid levels extend above the positive threshold and below
    the negative threshold. E.g., ``upper_limit=0.2`` means the maximum valid
    state is 20% higher than the value specified in ``threshold``.

    Example::

        import torch
        import snntorch as snn
        from snntorch.functional import quant

        beta = 0.5
        thr = 5

        # set the quantization parameters
        q_lif = quant.state_quant(num_bits=4, uniform=True, threshold=thr)

        # specifying state_quant applies state-quantization to the
        # hidden state(s) automatically
        lif = snn.Leaky(beta=beta, threshold=thr, state_quant=q_lif)

        rand_input = torch.rand(1)
        mem = lif.init_leaky()

        # forward-pass for one step
        spk, mem = lif(rand_input, mem)

    Note: Quantization-Aware training is focused on modelling a reduced
    precision network, but does not in and of itself accelerate
    low-precision models. Hidden states are still represented as full
    precision values for compatibility with PyTorch. For accelerated
    performance or constrained-memory, the model should be exported to a
    downstream backend.

    :param num_bits: Number of bits to quantize state variables to; yields
        ``2 ** num_bits`` levels. Must be at least 1. Defaults to ``8``
    :type num_bits: int, optional

    :param uniform: Applies uniform quantization if ``True``, non-uniform
        otherwise, defaults to ``True``
    :type uniform: Bool, optional

    :param thr_centered: For non-uniform quantization, specifies if valid
        states should be centered (densely clustered) around the threshold
        rather than at 0, defaults to ``True``
    :type thr_centered: Bool, optional

    :param threshold: Specifies the threshold, defaults to ``1``
    :type threshold: float, optional

    :param lower_limit: Specifies how far below (-threshold) the lowest valid
        state can be, i.e., (-threshold - threshold*lower_limit), defaults to
        ``0``
    :type lower_limit: float, optional

    :param upper_limit: Specifies how far above (threshold) the highest valid
        state can be, i.e., (threshold + threshold*upper_limit), defaults to
        ``0.2``
    :type upper_limit: float, optional

    :param multiplier: For non-uniform distributions, specify the base of the
        exponential. If ``None``, an appropriate value is set internally based
        on ``num_bits``, defaults to ``None``
    :type multiplier: float, optional

    :raises ValueError: If ``num_bits`` is less than 1.

    :return: A function that maps a tensor to its nearest valid state.
    """
    if num_bits < 1:
        raise ValueError(f"num_bits must be at least 1, got {num_bits!r}")
    num_levels = 1 << num_bits  # 2 ** num_bits quantization levels

    if uniform:
        # linear / uniform quantization - ignores thr_centered
        levels = torch.linspace(
            -threshold - threshold * lower_limit,
            threshold + threshold * upper_limit,
            num_levels,
        )
    else:
        # exponential / non-uniform quantization
        if multiplier is None:
            multiplier = _default_multiplier(num_bits)
        max_val = threshold + (threshold * upper_limit)  # highest level
        min_val = -(threshold + (threshold * lower_limit))  # lowest level
        # Cluster around the threshold (asymmetric) or around zero.
        pivot = threshold if thr_centered else 0
        levels = _nonuniform_levels(
            min_val, max_val, pivot, num_levels, multiplier
        )

    def inner(x):
        return StateQuant.apply(x, levels)

    return inner