# Source code for snntorch.import_nir

import copy

import numpy as np
import nir
import nirtorch
import torch

import snntorch as snn

def _create_rnn_subgraph(
    graph: nir.NIRGraph, lif_nk: str, w_nk: str
) -> nir.NIRGraph:
    """Fold a recurrent LIF <-> W_rec pair of ``graph`` into one NIRGraph node.

    The LIF node ``lif_nk`` and the recurrent weight node ``w_nk`` are removed
    from ``graph`` and placed in a new ``nir.NIRGraph``. That subgraph is stored
    under ``sg_key``, the part of ``lif_nk`` before the first ``"."``. Every
    outer edge that entered or left the LIF node is rerouted to the subgraph
    node. The subgraph has the following structure::

        Input -> LIF -> Output
                 ^ |
                 | v
                W_rec

    The subgraph's ``Input``/``Output`` nodes take the LIF node's input/output
    types.

    Note: ``graph`` is modified in place (its ``nodes`` and ``edges``
    attributes are reassigned and a node is inserted) and is also returned.

    :param graph: NIRGraph
    :type graph: nir.NIRGraph

    :param lif_nk: key for the LIF node
    :type lif_nk: str

    :param w_nk: key for the W_rec node
    :type w_nk: str

    :return: NIRGraph with the RNN subgraph replaced with a single NIRGraph node
    :rtype: nir.NIRGraph
    """
    # NOTE: assuming that the LIF and W_rec have keys of form xyz.abc; a key
    # without "." is used unchanged as the subgraph key.
    sg_key = lif_nk.split(".")[0]  # TODO: make this more general?
    input_key = f"{sg_key}.input"
    output_key = f"{sg_key}.output"
    lif_node = graph.nodes[lif_nk]

    # Create the RNN subgraph. The external input drives the LIF, and the LIF
    # output is fed back through W_rec and also exposed as the subgraph output.
    sg_edges = [
        (lif_nk, w_nk),
        (w_nk, lif_nk),
        (lif_nk, output_key),
        (input_key, lif_nk),
    ]
    sg_nodes = {
        lif_nk: lif_node,
        w_nk: graph.nodes[w_nk],
        input_key: nir.Input(lif_node.input_type),
        output_key: nir.Output(lif_node.output_type),
    }
    sg = nir.NIRGraph(nodes=sg_nodes, edges=sg_edges)

    # Drop the recurrent edges and reroute every remaining edge that touches
    # the LIF node to the subgraph node instead.
    recurrent_edges = {(lif_nk, w_nk), (w_nk, lif_nk)}
    graph.edges = [
        (
            sg_key if src == lif_nk else src,
            sg_key if dst == lif_nk else dst,
        )
        for src, dst in graph.edges
        if (src, dst) not in recurrent_edges
    ]

    # Remove the folded nodes, then insert the subgraph in their place.
    graph.nodes = {
        key: value
        for key, value in graph.nodes.items()
        if key not in (lif_nk, w_nk)
    }
    graph.nodes[sg_key] = sg
    return graph


def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph:
    """Replace every LIF <-> Dense recurrent cycle with a single NIRGraph node.

    A pair of edges ``(a, b)`` and ``(b, a)`` is folded via
    ``_create_rnn_subgraph`` when all of the following hold:

    - ``a`` is a ``nir.LIF`` or ``nir.CubaLIF`` node;
    - ``b`` is a ``nir.Affine`` or ``nir.Linear`` node;
    - ``b`` has exactly one incoming and one outgoing edge, i.e. it only
      connects to the LIF.

    Duplicate edges are removed first, with a warning, keeping the first
    occurrence so the edge order stays deterministic. ``graph`` is modified in
    place and also returned.

    :param graph: NIRGraph
    :type graph: nir.NIRGraph

    :return: NIRGraph with RNN subgraphs replaced with a single NIRGraph node
    :rtype: nir.NIRGraph
    """
    print("replace rnn subgraph with nirgraph")

    if len(set(graph.edges)) != len(graph.edges):
        print("[WARNING] duplicate edges found, removing")
        # dict.fromkeys keeps the first occurrence of each edge in order.
        # list(set(...)) would make the order depend on string hash
        # randomisation and thus vary between runs.
        graph.edges = list(dict.fromkeys(graph.edges))

    # Iterate over a snapshot, because _create_rnn_subgraph reassigns
    # graph.edges / graph.nodes. The set gives O(1) reverse-edge lookups and
    # is rebuilt after every replacement.
    edge_set = set(graph.edges)
    for lif_nk, w_nk in list(graph.edges):
        # A self-loop is not a LIF <> Dense cycle, and without the reverse
        # edge there is no cycle at all.
        if lif_nk == w_nk or (w_nk, lif_nk) not in edge_set:
            continue
        is_lif = isinstance(graph.nodes[lif_nk], (nir.LIF, nir.CubaLIF))
        is_dense = isinstance(graph.nodes[w_nk], (nir.Affine, nir.Linear))
        # Check that the dense node only connects to the LIF.
        w_out_degree = sum(1 for src, _ in graph.edges if src == w_nk)
        w_in_degree = sum(1 for _, dst in graph.edges if dst == w_nk)
        is_rnn = w_out_degree == 1 and w_in_degree == 1
        # If we found an RNN, fold it into a subgraph node.
        if is_rnn and is_lif and is_dense:
            graph = _create_rnn_subgraph(graph, lif_nk, w_nk)
            edge_set = set(graph.edges)
    return graph


def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int):  # type: ignore
    """Extract the neuron node and recurrent-weight node from an RNN subgraph.

    The subgraph must contain exactly four nodes:

    - one ``nir.Input``;
    - one ``nir.Output``;
    - one ``nir.LIF`` or ``nir.CubaLIF``;
    - one ``nir.Affine`` or ``nir.Linear`` holding the recurrent weights.

    The expected layout is::

        Input -> LIF | CubaLIF -> Output
                    ^
                    |
                    v
                 Affine | Linear

    Only node types and sizes are validated; the edges are not inspected.

    :param graph: NIRGraph
    :type graph: nir.NIRGraph

    :return: ``(lif_node, wrec_node, lif_size)``. These are the LIF/CubaLIF
        node, the Affine/Linear node, and the number of neurons, taken as the
        first dimension of the Input node's type.
    :rtype: tuple

    :raises AssertionError: if the graph does not hold 4 nodes, or if the
        sizes of Input, Output, ``v_threshold`` and the square recurrent
        weight matrix disagree
    :raises ValueError: if one of the four required node kinds is missing
    """
    members = list(graph.nodes.values())
    assert len(members) == 4, "only 4-node RNN allowed in subgraph"

    def of_kind(kinds):
        # all subgraph nodes that are instances of ``kinds``
        return [member for member in members if isinstance(member, kinds)]

    try:
        input_node = of_kind(nir.Input)[0]
        output_node = of_kind(nir.Output)[0]
        lif_node = of_kind((nir.LIF, nir.CubaLIF))[0]
        wrec_node = of_kind((nir.Affine, nir.Linear))[0]
    except IndexError:
        raise ValueError(
            "invalid RNN subgraph - could not find all required nodes"
        )

    # The neuron count is the first dimension of the first input type entry.
    lif_size = [*input_node.input_type.values()][0][0]
    assert (
        lif_size == [*output_node.output_type.values()][0][0]
    ), "output size mismatch"
    assert (
        lif_size == lif_node.v_threshold.size
    ), "lif size mismatch (v_threshold)"
    # The recurrent weight matrix must be square, lif_size x lif_size.
    assert lif_size == wrec_node.weight.shape[0], "w_rec shape mismatch"
    assert lif_size == wrec_node.weight.shape[1], "w_rec shape mismatch"

    return lif_node, wrec_node, lif_size


def _cubalif_snntorch_params(node, dt, hack_w_scale):
    """Translate a ``nir.CubaLIF`` node into snnTorch ``(alpha, beta, threshold)``.

    Shared by the plain CubaLIF branch and the RNN-subgraph branch of
    ``_nir_to_snntorch_module``.

    :param node: CubaLIF node
    :type node: nir.CubaLIF

    :param dt: time step used to discretise ``tau_syn`` / ``tau_mem``
    :type dt: float

    :param hack_w_scale: if True, fold an input scale
        ``w_in * dt / tau_syn`` that differs from 1 into the threshold
        (``v_threshold / w_scale``) instead of failing
    :type hack_w_scale: bool

    :return: ``(alpha, beta, threshold)``. ``alpha`` and ``beta`` are Python
        floats when identical across neurons, otherwise per-neuron NumPy
        arrays. ``threshold`` is always a Python float.
    :rtype: tuple

    :raises AssertionError: if ``v_leak != 0``, ``r != tau_mem / dt``, or the
        (scaled) threshold differs between neurons
    :raises NotImplementedError: if ``w_scale != 1`` and ``hack_w_scale`` is
        False
    """
    assert np.allclose(node.v_leak, 0), "v_leak not supported"
    assert np.allclose(
        node.r, node.tau_mem / dt
    ), "r not supported in CubaLIF"

    alpha = 1 - (dt / node.tau_syn)
    beta = 1 - (dt / node.tau_mem)
    vthr = node.v_threshold
    w_scale = node.w_in * (dt / node.tau_syn)

    if not np.allclose(w_scale, 1.0):
        if not hack_w_scale:
            raise NotImplementedError(
                "w_scale must be 1, or the same for all neurons"
            )
        # Scaling the threshold down stands in for scaling the inputs up.
        vthr = vthr / w_scale
        print("[warning] scaling weights to avoid scaling inputs")
        print(
            f"w_scale: {w_scale}, w_in: {node.w_in}, dt: {dt}, tau_syn: {node.tau_syn}"
        )

    assert (
        np.unique(vthr).size == 1
    ), "CubaLIF v_thr must be same for all neurons"

    if np.unique(alpha).size == 1:
        alpha = float(np.unique(alpha)[0])
    if np.unique(beta).size == 1:
        beta = float(np.unique(beta)[0])

    return alpha, beta, float(np.unique(vthr)[0])


def _nir_to_snntorch_module(
    node: nir.NIRNode, hack_w_scale=True, init_hidden=False
) -> torch.nn.Module:
    """Convert a NIR node to a snnTorch module. This function is used by the import_from_nir function.

    Handled node types:

    - ``Input``/``Output``: no module, returns ``None``.
    - ``Affine``, ``Linear``, ``Conv2d``, ``Flatten``, ``AvgPool2d``.
    - ``IF``, ``LIF``, ``CubaLIF``.
    - A 4-node RNN ``NIRGraph`` with a CubaLIF neuron, as built by
      ``_replace_rnn_subgraph_with_nirgraph``.

    Neuron time constants are discretised with a fixed step ``dt = 1e-4``.

    :param node: NIR node
    :type node: nir.NIRNode

    :param hack_w_scale: if True, an input scale factor ``w_scale`` that
        differs from 1 is folded into the firing threshold
        (``v_threshold / w_scale``) instead of scaling the inputs. If False,
        such nodes raise ``NotImplementedError``.
    :type hack_w_scale: bool

    :param init_hidden: the init_hidden flag of the snnTorch neuron.
    :type init_hidden: bool

    :return: snnTorch module. Returns ``torch.nn.Identity`` for a ``None``
        node, and ``None`` for Input/Output nodes and unsupported node types
        (the latter with a printed warning).
    :rtype: torch.nn.Module

    :raises AssertionError: if a neuron's parameters are not supported (see
        the individual asserts)
    :raises NotImplementedError: for a LIF inside an RNN subgraph, or when
        ``w_scale != 1`` and ``hack_w_scale`` is False
    """
    # Fixed simulation time step used to discretise continuous-time neuron
    # models (e.g. beta = 1 - dt / tau).
    dt = 1e-4

    if isinstance(node, (nir.Input, nir.Output)):
        return None

    elif isinstance(node, nir.Affine):
        assert node.bias is not None, "bias must be specified for Affine layer"

        mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0])
        mod.weight.data = torch.Tensor(node.weight)
        mod.bias.data = torch.Tensor(node.bias)

        return mod

    elif isinstance(node, nir.Linear):
        mod = torch.nn.Linear(
            node.weight.shape[1], node.weight.shape[0], bias=False
        )
        mod.weight.data = torch.Tensor(node.weight)

        return mod

    elif isinstance(node, nir.Conv2d):
        # The weight is copied verbatim into a torch Conv2d, whose weight
        # layout is (out_channels, in_channels / groups, kH, kW). So the
        # number of input channels is shape[1] * groups, not shape[1]; the
        # latter breaks grouped and depthwise convolutions.
        mod = torch.nn.Conv2d(
            node.weight.shape[1] * node.groups,
            node.weight.shape[0],
            kernel_size=[*node.weight.shape[-2:]],
            stride=node.stride,
            padding=node.padding,
            dilation=node.dilation,
            groups=node.groups,
        )
        mod.bias.data = torch.Tensor(node.bias)
        mod.weight.data = torch.Tensor(node.weight)
        return mod

    elif isinstance(node, nir.Flatten):
        return torch.nn.Flatten(node.start_dim, node.end_dim)

    elif isinstance(node, nir.AvgPool2d):
        return torch.nn.AvgPool2d(
            kernel_size=tuple(node.kernel_size),
            stride=tuple(node.stride),
            padding=tuple(node.padding),
        )

    elif isinstance(node, nir.IF):
        assert (
            np.unique(node.v_threshold).size == 1
        ), "v_threshold must be same for all neurons"
        assert np.unique(node.r).size == 1, "r must be same for all neurons"
        vthr = np.unique(node.v_threshold)[0]
        r = np.unique(node.r)[0]
        assert r == 1, "r != 1 not supported"
        # An integrate-and-fire neuron has no leak, so the membrane decay
        # factor is 1. The init_hidden argument is honoured as in the LIF and
        # CubaLIF branches.
        return snn.Leaky(
            beta=1.0,
            threshold=float(vthr * r),
            init_hidden=init_hidden,
            reset_delay=False,
        )

    elif isinstance(node, nir.LIF):
        assert np.allclose(node.v_leak, 0.0), "v_leak not supported"
        assert (
            np.unique(node.v_threshold).size == 1
        ), "v_threshold must be same for all neurons"

        beta = 1 - (dt / node.tau)
        vthr = node.v_threshold
        w_scale = node.r * dt / node.tau

        if not np.allclose(w_scale, 1.0):
            if not hack_w_scale:
                raise NotImplementedError(
                    "w_scale must be 1, or the same for all neurons"
                )
            # Divide element-wise, as in the CubaLIF branch. A w_scale that
            # differs between neurons then trips the assert below instead of
            # being silently replaced by its smallest value.
            vthr = vthr / w_scale
            print("[warning] scaling weights to avoid scaling inputs")
            print(
                f"w_scale: {w_scale}, r: {node.r}, dt: {dt}, tau: {node.tau}"
            )

        assert (
            np.unique(vthr).size == 1
        ), "LIF v_thr must be same for all neurons"

        # Pass a uniform beta as a plain Python float, matching the CubaLIF
        # branch, rather than a NumPy array whose dtype (possibly float64)
        # snnTorch may carry over.
        if np.unique(beta).size == 1:
            beta = float(np.unique(beta)[0])

        return snn.Leaky(
            beta=beta,
            threshold=float(np.unique(vthr)[0]),
            reset_mechanism="zero",
            init_hidden=init_hidden,
            reset_delay=False,
        )

    elif isinstance(node, nir.CubaLIF):
        alpha, beta, vthr = _cubalif_snntorch_params(node, dt, hack_w_scale)

        return snn.Synaptic(
            alpha=alpha,
            beta=beta,
            threshold=vthr,
            reset_mechanism="zero",
            init_hidden=init_hidden,
            reset_delay=False,
        )

    elif isinstance(node, nir.NIRGraph):
        lif_node, wrec_node, lif_size = _parse_rnn_subgraph(node)

        if isinstance(lif_node, nir.LIF):
            raise NotImplementedError("LIF in subgraph not supported")

        elif isinstance(lif_node, nir.CubaLIF):
            alpha, beta, vthr = _cubalif_snntorch_params(
                lif_node, dt, hack_w_scale
            )

            # A diagonal recurrent matrix maps onto RSynaptic's element-wise
            # recurrence (V); anything else needs an all-to-all Linear layer.
            diagonal = np.array_equal(
                wrec_node.weight, np.diag(np.diag(wrec_node.weight))
            )

            if (
                diagonal
                and isinstance(wrec_node, nir.Affine)
                and wrec_node.bias is not None
                and not np.allclose(wrec_node.bias, 0)
            ):
                # The element-wise recurrence below never receives a bias.
                print(
                    "[WARNING] bias of diagonal recurrent Affine is ignored"
                )

            # torch.tensor copies the data, so the read-only view returned
            # by np.diag is not handed to torch.
            V = (
                torch.tensor(np.diag(wrec_node.weight), dtype=torch.float32)
                if diagonal
                else None
            )

            rsynaptic = snn.RSynaptic(
                alpha=alpha,
                beta=beta,
                threshold=vthr,
                reset_mechanism="zero",
                init_hidden=init_hidden,
                all_to_all=not diagonal,
                linear_features=lif_size if not diagonal else None,
                V=V,
                reset_delay=False,
            )

            if isinstance(rsynaptic.recurrent, torch.nn.Linear):
                rsynaptic.recurrent.weight.data = torch.Tensor(
                    wrec_node.weight
                )
                if isinstance(wrec_node, nir.Affine):
                    rsynaptic.recurrent.bias.data = torch.Tensor(
                        wrec_node.bias
                    )
                else:
                    # nir.Linear has no bias, so zero out torch's default one.
                    rsynaptic.recurrent.bias.data = torch.zeros_like(
                        rsynaptic.recurrent.bias
                    )
            else:
                rsynaptic.recurrent.V.data = torch.diagonal(
                    torch.Tensor(wrec_node.weight)
                )

            return rsynaptic

    elif node is None:
        return torch.nn.Identity()

    else:
        print(
            "[WARNING] could not parse node of type:", node.__class__.__name__
        )
        return None


def import_from_nir(graph: nir.NIRGraph) -> torch.nn.Module:
    """Convert a NIRGraph to a snnTorch module.

    This function is the inverse of export_to_nir. It proceeds as follows:

    1. Wrap any recurrent connections into NIR sub-graphs.
    2. Convert each NIR module into the equivalent snnTorch module.
    3. Wrap them into a torch.nn.Module using the generic GraphExecutor from
       NIRTorch, which executes all modules in the right order.

    The recurrent-connection rewrite is applied to a deep copy, so the
    ``graph`` passed in is left unchanged.

    Missing features:
        - RLeaky (LIF inside RNN)

    Example::

        import snntorch as snn
        import torch
        from snntorch.export_nir import export_to_nir
        from snntorch.import_nir import import_from_nir

        lif1 = snn.Leaky(beta=0.9, init_hidden=True)
        lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

        net = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(784, 500),
            lif1,
            torch.nn.Linear(500, 10),
            lif2
        )

        sample_data = torch.randn(1, 784)
        nir_graph = export_to_nir(net, sample_data, model_name="snntorch")

        net2 = import_from_nir(nir_graph)

    :param graph: NIR graph
    :type graph: nir.NIRGraph

    :return: snnTorch network
    :rtype: torch.nn.Module
    """
    # The RNN rewrite below reassigns graph.nodes / graph.edges. Work on a
    # copy so that the caller's graph is not altered as a side effect.
    graph = copy.deepcopy(graph)
    # find valid RNN subgraphs, and replace them with a single NIRGraph node
    graph = _replace_rnn_subgraph_with_nirgraph(graph)
    # convert the NIR graph into a torch.nn.Module using snnTorch modules
    return nirtorch.load(graph, _nir_to_snntorch_module)