Source code for classes.torch_profiler

from abc import ABC, abstractmethod
from collections import OrderedDict
from deep_data_profiler.classes.profile import Profile
from deep_data_profiler.utils import model_graph, TorchHook
import inspect
import networkx as nx
import torch
from torch.nn.modules import activation, dropout, batchnorm
from typing import Dict, List, Optional, Tuple, Union, Any

class TorchProfiler(ABC):

    """
    A TorchProfiler wraps a PyTorch model into a TorchHook model which can
    register activations as it evaluates data. Subclasses of TorchProfiler
    profile model inputs by tracing influential neurons through the layers
    of the network during classification. Each subclass implementation
    defines neurons, influence, and contribution differently.

    Attributes
    ----------
    activation_classes : list
        List of activation classes in PyTorch
    dropout_classes : list
        List of dropout classes in PyTorch
    batchnorm_classes : list
        List of batch normalization classes in PyTorch
    implemented_classes : list
        List of classes in PyTorch which have contrib functions implemented
    contrib_functions : list
        List of contrib functions ordered by their corresponding classes
        in implemented_classes (does not include "contrib_resnetadd")
    model : torch.nn.Module
        Model to be profiled
    """

    def __init__(self, model: torch.nn.Module, device: str = "cpu") -> None:
        super().__init__()
        self.activation_classes = [
            m[1]
            for m in inspect.getmembers(activation, inspect.isclass)
            if m[1].__module__ == "torch.nn.modules.activation"
        ]
        self.dropout_classes = [
            m[1]
            for m in inspect.getmembers(dropout, inspect.isclass)
            if m[1].__module__ == "torch.nn.modules.dropout"
        ]
        self.batchnorm_classes = [
            m[1]
            for m in inspect.getmembers(batchnorm, inspect.isclass)
            if m[1].__module__ == "torch.nn.modules.batchnorm"
        ]

        # At present the index of each implemented class must match the index
        # of the contrib function used for that class. Once there are enough
        # many-to-one relationships between the two lists, this will be
        # changed to a dictionary.
        self.implemented_classes = [
            torch.nn.Linear,
            torch.nn.MaxPool2d,
            torch.nn.AdaptiveAvgPool2d,
            torch.nn.Conv2d,
        ]
        self.contrib_functions = [
            "contrib_linear",
            "contrib_max2d",
            "contrib_adaptive_avg_pool2d",
            "contrib_conv2d",
        ]

        self.model = TorchHook(model, device)
        self.device = device
        self.hooks = self.model.available_modules()
        supernodes, SG, pos = self.super_nodes_graph()
        self.supernodes = supernodes
        self.SG = SG
        self.pos = pos
        self.sght = self.sgheadtail()
        self.pred_dict = {
            nd[0]: sorted([preds[0] for preds in self.SG.predecessors(nd)])
            for nd in sorted(self.SG)
        }
        self.layerdict = self.create_layers()
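
    # Usage sketch (illustrative only): TorchProfiler is abstract, so it is
    # constructed through a concrete subclass. ``MyProfiler`` below is a
    # hypothetical subclass implementing the abstract methods at the end of
    # this class.
    #
    #     model = torchvision.models.vgg16(pretrained=True).eval()
    #     profiler = MyProfiler(model, device="cpu")
    #     profiler.SG.number_of_nodes()  # nodes in the supernode graph
    #     profiler.layerdict[1]          # [supernode modules, contrib function]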

    def super_nodes_graph(self) -> Tuple[OrderedDict, nx.DiGraph, Dict]:
        """
        Returns
        -------
        supernodes : OrderedDict
        SG : nx.DiGraph
        sgpos : dict
        """
        model = self.model.model  # grab back the original model
        G, ndorder, pos = model_graph(model)  # group nodes into layers
        pmas = self.model.available_modules()  # links the graph nodes to PyTorch modules
        nodes = G.nodes

        # A layer starts with an implemented class. Some of these correspond
        # to PyTorch modules; others were custom made for the application,
        # like 'add'.
        def chk_type(n):
            try:
                return n.endswith("add") or type(pmas[n]) in self.implemented_classes
            except KeyError:
                return False

        # An implemented node is the first function applied to the input of a
        # DDP layer; its successors which do not change the shape are added
        # to the same DDP layer.
        implemented_nodes = [n for n in nodes if chk_type(n)]
        supernodes = OrderedDict({"x_in": ["x_in"]})
        for n in implemented_nodes:
            temp = [n]
            while True:
                scs = list(G.successors(temp[-1]))
                if (len(scs) == 0 or len(scs) > 1) or (
                    len(scs) == 1 and chk_type(scs[0])
                ):
                    break
                else:
                    temp.append(scs[0])
            supernodes[n] = temp

        SGnodes = list(supernodes.keys())
        SG = nx.DiGraph()  # the DiGraph of the DDP network
        for nd in SGnodes:
            tail = supernodes[nd][-1]
            try:
                scs = list(G.successors(tail))
                snds = [(nd, snd) for snd in scs if snd in SGnodes]
                SG.add_edges_from(snds)
            except Exception as ex:
                print(ex)
                continue

        # Rename nodes according to the sorted node order so that SG
        # sequences the layers in the graph.
        def sortorder(x):
            return ndorder.index(x)

        ordG = {
            v: (k, v) for k, v in enumerate(sorted(list(SG.nodes()), key=sortorder))
        }
        nx.relabel_nodes(SG, ordG, copy=False)

        sgpos = dict()
        for nd in SG.nodes():
            sgpos[nd] = pos[nd[1]]
        return supernodes, SG, sgpos
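
    # Illustrative shapes of the returned structures (the actual node names
    # come from model_graph and depend on the model):
    #
    #     supernodes  # OrderedDict, e.g.
    #                 # {'x_in': ['x_in'],
    #                 #  'features.0': ['features.0', 'features.1'],  # Conv2d + ReLU
    #                 #  ...}
    #     SG          # nx.DiGraph whose nodes are (index, name) pairs, e.g.
    #                 # with an edge ((0, 'x_in'), (1, 'features.0'))
    #     sgpos       # {(index, name): position}, for plotting the graph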

    def sgheadtail(self) -> Dict[int, Tuple[Optional[str], Optional[str]]]:
        """
        Returns
        -------
        sght : dict
            Dictionary of the names of the head and tail modules at each layer
        """
        namelist = list()
        sght = dict()
        pmas = self.model.available_modules()
        for nd in self.SG.nodes:
            if nd[0] == 0:
                sght[0] = ("x_in", "x_in")
                continue
            arr = self.supernodes[nd[1]]
            # head: the first module in the supernode available in the model
            for k in arr:
                if k in pmas:
                    head = k
                    namelist.append(k)
                    break
            else:
                head = None
            # tail: the last available module, found by iterating in reverse
            for k in reversed(arr):
                if k in pmas:
                    tail = k
                    namelist.append(k)
                    break
            else:
                tail = None
            sght[nd[0]] = (head, tail)
        self.model.add_hooks(namelist)
        return sght
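
    # Illustrative output: layer index -> (head, tail) module names, where
    # head is the first and tail the last hooked module in the supernode:
    #
    #     {0: ('x_in', 'x_in'), 1: ('features.0', 'features.1'), ...}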

    def create_layers(self) -> Dict[int, Tuple[List[str], str]]:
        """
        Create a dictionary of layers to profile.

        Returns
        -------
        layerdict : OrderedDict
            A dictionary of layers keyed by their position according to the
            order PyTorch uses to list the layers of the model.

        Note
        ----
        Because a weighted module is often followed by activation and/or
        normalization modules, those modules are combined into a single
        layer to be profiled. As a consequence, the number of layers
        generated will typically be less than the number of available
        modules in the model.
        """
        # Identify layers to profile.
        layerdict = OrderedDict()  # ordered dictionary to feed to the profiler

        # Layer 0 is the tensor input for the image.
        layerdict[0] = [["x_in"], "contrib_identity"]

        SGnodes = sorted(self.SG.nodes, key=lambda nd: nd[0])
        for ndx in range(1, len(SGnodes)):
            nd = SGnodes[ndx]
            if nd[1].endswith("resnetadd"):
                layerdict[nd[0]] = [self.supernodes[nd[1]], "contrib_resnetadd"]
            else:
                this_type = type(self.model.available_modules()[nd[1]])
                layerdict[nd[0]] = [
                    self.supernodes[nd[1]],
                    self.contrib_functions[self.implemented_classes.index(this_type)],
                ]
        return layerdict
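
    # Illustrative layerdict for a small convolutional model (hypothetical
    # module names):
    #
    #     {0: [['x_in'], 'contrib_identity'],
    #      1: [['features.0', 'features.1'], 'contrib_conv2d'],
    #      2: [['features.2'], 'contrib_max2d'],
    #      ...}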

    def build_dicts(
        self,
        x: torch.Tensor,
        infl_threshold: float,
        contrib_threshold: float,
        layers_to_profile: Optional[Union[List[int], Tuple[int, int]]] = None,
        **kwargs,
    ) -> Tuple[dict, dict, dict, dict, dict]:
        """
        Generates neuron/synapse counts/weights for a given input.

        Parameters
        ----------
        x : torch.Tensor
            Input to the model being profiled
        infl_threshold : float
            Parameter for influence
        contrib_threshold : float
            Parameter for contribution
        layers_to_profile : list or tuple
            List of specific layers to profile, or tuple with first, last
            layers (exclusive of last) to profile and all layers in between

        Returns
        -------
        neuron_counts
        neuron_weights
        synapse_counts
        synapse_weights
        activation_shapes
        """
        # Create empty dictionaries for storing neurons and synapses
        neuron_counts = dict()
        neuron_weights = dict()
        synapse_counts = dict()
        synapse_weights = dict()

        ######## Step 1: Identify the Influential Neurons (Nodes) in the Activation Graph ########
        y, activations = self.model.forward(x)
        activations["x_in"] = x
        activation_shapes = {}
        for ldx, modules in self.layerdict.items():
            if "resnetadd" in modules[1]:
                activation_shapes[ldx] = activations[modules[0][1]].shape
            else:
                activation_shapes[ldx] = activations[modules[0][0]].shape

        # Create an x-dependent function to identify the highest valued
        # neurons at each layer.
        influential_neurons = self.influence_generator(activations, **kwargs)

        # Layers to profile:
        n = len(self.layerdict)
        if layers_to_profile is None:
            ltp = range(1, n)
        elif isinstance(layers_to_profile, list):
            ltp = [lyr for lyr in layers_to_profile if 1 <= lyr <= n - 1]
        else:  # a tuple is expected
            start = max(1, layers_to_profile[0])
            end = min(n, layers_to_profile[1])
            ltp = range(start, end)

        # Fill dictionaries
        for k in ltp:
            neuron_counts[k], neuron_weights[k] = influential_neurons(k, infl_threshold)

        ######## Step 2: Identify Contributing Neurons (Nodes) and Synaptic connections (Edges) ########
        with torch.no_grad():
            y, activations = self.model.forward(x)
            activations["x_in"] = x

            # Assign activations to x_in and y_out for each layer.
            # Note that if f is the first available module in a layer, then
            # y_out = f(x_in).
            y_out = dict()
            x_in = dict()
            for k in ltp:
                # output from the first available module in the layer
                y_out[k] = {k: activations[self.sght[k][0]]}
                nd = (k, self.layerdict[k][0][0])
                pred_layers = list(self.SG.predecessors(nd))
                # output from the last available module of each predecessor layer
                x_in[k] = {
                    pd[0]: activations[self.sght[pd[0]][1]] for pd in pred_layers
                }

            results = [
                self.single_profile(
                    x_in[ldx],
                    y_out[ldx],
                    neuron_counts[ldx],
                    ldx,
                    contrib_threshold,
                )
                for ldx in ltp
            ]

            for i, f in enumerate(results):
                ldx = ltp[i]
                x_ldx = list(x_in[ldx].keys())
                ncs, scs, sws = f
                synapse_counts[ldx] = scs
                synapse_weights[ldx] = sws
                if len(x_ldx) == 1:
                    ncs = (ncs,)
                for nc, xdx in zip(ncs, x_ldx):
                    if xdx in neuron_counts:
                        neuron_counts[xdx] += nc
                        neuron_counts[xdx] = neuron_counts[xdx].tocoo()
                    else:
                        neuron_counts[xdx] = nc

        return (
            neuron_counts,
            neuron_weights,
            synapse_counts,
            synapse_weights,
            activation_shapes,
        )
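
    # Usage sketch, continuing the construction sketch above (hypothetical
    # subclass and threshold values):
    #
    #     x = torch.randn(1, 3, 224, 224)  # one image-shaped input
    #     (neuron_counts, neuron_weights, synapse_counts,
    #      synapse_weights, activation_shapes) = profiler.build_dicts(
    #         x, infl_threshold=0.97, contrib_threshold=0.1,
    #         layers_to_profile=(1, 5),  # layers 1-4; the upper bound is exclusive
    #     )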

    @abstractmethod
    def influence_generator(self):
        pass

    @abstractmethod
    def single_profile(self):
        pass

    @abstractmethod
    def create_profile(self):
        pass
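
    # A minimal subclass skeleton (a sketch only; the signatures for
    # influence_generator and single_profile are inferred from how
    # build_dicts calls them above, and create_profile is expected to
    # assemble a Profile from the build_dicts output):
    #
    #     class MyProfiler(TorchProfiler):
    #         def influence_generator(self, activations, **kwargs):
    #             def influential_neurons(layer, infl_threshold):
    #                 ...  # return (neuron_counts, neuron_weights) for layer
    #             return influential_neurons
    #
    #         def single_profile(self, x_in, y_out, neuron_counts, ldx, threshold):
    #             ...  # return (neuron_counts, synapse_counts, synapse_weights)
    #
    #         def create_profile(self, x, *args, **kwargs):
    #             ...  # return a Profile built from build_dicts(x, ...)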