# Source code for snntorch.export_nir

#from typing import Optional
from typing import Optional, Tuple, Union
import torch
import os
import sys
import nir
import numpy as np
import nirtorch
import snntorch as snn



# Time step (in seconds) assumed when translating snnTorch's per-step decay
# factors (alpha, beta) into the continuous-time constants used by NIR.
_NIR_DT = 1e-4


def _pair(value):
    """Return ``value`` as a 2-tuple, repeating it when it is a single number."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)


def _decay_to_tau(decay, dt=_NIR_DT):
    """Convert a discrete per-step decay factor into a time constant.

    Computes ``tau = dt / (1 - decay)``; works element-wise on numpy arrays.
    """
    return dt / (1 - decay)


def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]:
    """Convert a single snnTorch/torch module to the equivalent object in the
    Neuromorphic Intermediate Representation (NIR). This function is used
    internally by :func:`export_to_nir` to convert each submodule/layer of the
    network to the NIR.

    Currently supported modules: ``torch.nn.Conv2d``, ``torch.nn.AvgPool2d``,
    ``torch.nn.Linear``, ``torch.nn.Flatten``, ``snn.Leaky``, ``snn.Synaptic``
    and ``snn.RSynaptic``. ``snn.RLeaky`` is recognised but not yet supported.

    The recurrent layer ``snn.RSynaptic`` is converted to a NIR graph, which is
    then embedded as a subgraph into the main NIR graph.

    :param module: snnTorch or torch module
    :type module: torch.nn.Module

    :raises NotImplementedError: if ``module`` is an ``snn.RLeaky``
    :raises ValueError: if ``module`` is a one-to-one ``snn.RSynaptic`` whose
        recurrent weight ``V`` is a scalar, so the layer size cannot be inferred

    :return: the NIR node, or ``None`` (after printing a warning) if the module
        type is not supported
    :rtype: Optional[nir.NIRNode]
    """
    if isinstance(module, torch.nn.Conv2d):
        # nn.Conv2d(bias=False) stores ``bias=None``; NIR always receives a bias
        # vector here, so substitute zeros with one entry per output channel.
        bias = (
            module.bias.detach()
            if isinstance(module.bias, torch.Tensor)
            else torch.zeros((module.weight.shape[0]))
        )
        # NOTE(review): weight/bias are passed as torch tensors whereas the
        # Linear branch passes numpy arrays — confirm NIR accepts both.
        return nir.Conv2d(
            input_shape=None,
            weight=module.weight.detach(),
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=bias,
        )

    elif isinstance(module, torch.nn.AvgPool2d):
        return nir.AvgPool2d(
            kernel_size=module.kernel_size,  # (height, width) or a single int
            # torch falls back to kernel_size when no stride was given.
            stride=module.kernel_size if module.stride is None else module.stride,
            # Forward the layer's own padding rather than assuming none.
            padding=_pair(module.padding),  # (height, width)
        )

    elif isinstance(module, snn.Leaky):
        beta = module.beta.detach().numpy()
        tau_mem = _decay_to_tau(beta)

        return nir.LIF(
            tau=tau_mem,
            v_threshold=module.threshold.detach().numpy(),
            v_leak=np.zeros_like(beta),
            r=tau_mem / _NIR_DT,
        )

    elif isinstance(module, torch.nn.Linear):
        weight = module.weight.data.detach().numpy()
        if module.bias is None:
            return nir.Linear(weight=weight)
        return nir.Affine(
            weight=weight,
            bias=module.bias.data.detach().numpy(),
        )

    elif isinstance(module, snn.Synaptic):
        # TODO: assert that size of the current layer is correct
        alpha = module.alpha.detach().numpy()
        beta = module.beta.detach().numpy()

        tau_syn = _decay_to_tau(alpha)
        tau_mem = _decay_to_tau(beta)

        return nir.CubaLIF(
            tau_syn=tau_syn,
            tau_mem=tau_mem,
            v_threshold=module.threshold.detach().numpy(),
            v_leak=np.zeros_like(beta),
            r=tau_mem / _NIR_DT,
            w_in=tau_syn / _NIR_DT,
        )

    elif isinstance(module, snn.RLeaky):
        # TODO(stevenabreu7): implement RLeaky
        raise NotImplementedError("RLeaky not supported")

    elif isinstance(module, snn.RSynaptic):
        if module.all_to_all:
            # Dense recurrence: reuse the Linear/Affine conversion above.
            # NOTE(review): assumes module.recurrent converts to a node with a
            # 2-D ``weight`` (i.e. nn.Linear) — confirm for conv-based layers.
            w_rec = _extract_snntorch_module(module.recurrent)
            n_neurons = w_rec.weight.shape[0]
        else:
            # One-to-one recurrence: V holds one self-connection weight per
            # neuron, which becomes a diagonal weight matrix.
            if len(module.recurrent.V.shape) == 0:
                # TODO: handle this better - if V is a scalar, then the weight has wrong shape
                raise ValueError(
                    "V must be a vector, cannot infer layer size for scalar V"
                )
            n_neurons = module.recurrent.V.shape[0]
            w_rec = nir.Linear(
                weight=np.diag(module.recurrent.V.data.detach().numpy())
            )

        # Broadcast (possibly scalar) neuron parameters to one value per neuron.
        alpha = np.ones(n_neurons) * module.alpha.detach().numpy()
        beta = np.ones(n_neurons) * module.beta.detach().numpy()
        vthr = np.ones(n_neurons) * module.threshold.detach().numpy()

        tau_syn = _decay_to_tau(alpha)
        tau_mem = _decay_to_tau(beta)

        return nir.NIRGraph(
            nodes={
                "input": nir.Input(input_type=[n_neurons]),
                "lif": nir.CubaLIF(
                    v_threshold=vthr,
                    tau_mem=tau_mem,
                    tau_syn=tau_syn,
                    r=tau_mem / _NIR_DT,
                    v_leak=np.zeros_like(beta),
                    w_in=tau_syn / _NIR_DT,
                ),
                "w_rec": w_rec,
                "output": nir.Output(output_type=[n_neurons]),
            },
            # The lif -> w_rec -> lif cycle expresses the recurrent connection.
            edges=[
                ("input", "lif"),
                ("lif", "w_rec"),
                ("w_rec", "lif"),
                ("lif", "output"),
            ],
        )

    elif isinstance(module, torch.nn.Flatten):
        # Getting rid of the batch dimension for NIR: positive dims shift down
        # by one; non-positive dims (counted from the end, or 0) are kept as-is.
        start_dim = (
            module.start_dim - 1 if module.start_dim > 0 else module.start_dim
        )
        end_dim = module.end_dim - 1 if module.end_dim > 0 else module.end_dim
        return nir.Flatten(
            input_type=None,
            start_dim=start_dim,
            end_dim=end_dim,
        )

    print(f"[WARNING] module not implemented: {module.__class__.__name__}")
    return None


def export_to_nir(
    module: torch.nn.Module,
    sample_data: torch.Tensor,
    model_name: str = "snntorch",
    model_fwd_args: Optional[list] = None,
    ignore_dims: Optional[list] = None,
) -> nir.NIRNode:
    """Convert an snnTorch module to the Neuromorphic Intermediate
    Representation (NIR).

    This function uses nirtorch to extract the computational graph of the torch
    module, and the _extract_snntorch_module method is used to convert each
    module in the graph to the corresponding NIR module.

    The NIR is a graph-based representation of a spiking neural network, which
    can be used to port the network to different neuromorphic hardware and
    software platforms.

    Missing features:

    - RLeaky

    Example::

        import snntorch as snn
        import torch
        from snntorch.export_nir import export_to_nir

        lif1 = snn.Leaky(beta=0.9, init_hidden=True)
        lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

        net = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(784, 500),
            lif1,
            torch.nn.Linear(500, 10),
            lif2
        )

        sample_data = torch.randn(1, 784)
        nir_graph = export_to_nir(net, sample_data)

    :param module: Network model (either wrapped in Sequential container or as
        a class)
    :type module: torch.nn.Module

    :param sample_data: Sample input data to the network
    :type sample_data: torch.Tensor

    :param model_name: Name of the model
    :type model_name: str, optional

    :param model_fwd_args: Arguments to pass to the forward function of the
        model; ``None`` (the default) means no extra arguments
    :type model_fwd_args: list, optional

    :param ignore_dims: List of dimensions to ignore when extracting the NIR;
        ``None`` (the default) means none are ignored
    :type ignore_dims: list, optional

    :return: return the NIR graph
    :rtype: nir.NIRNode
    """
    # Use fresh lists on every call instead of shared mutable defaults, so no
    # state can leak between exports if the list is mutated downstream.
    if model_fwd_args is None:
        model_fwd_args = []
    if ignore_dims is None:
        ignore_dims = []

    # Recurrent snnTorch layers are converted as a whole by
    # _extract_snntorch_module, so nirtorch must not descend into them.
    nir_graph = nirtorch.extract_nir_graph(
        module,
        _extract_snntorch_module,
        sample_data,
        model_name=model_name,
        ignore_submodules_of=[snn.RLeaky, snn.RSynaptic],
        model_fwd_args=model_fwd_args,
        ignore_dims=ignore_dims,
    )
    return nir_graph