# mypy: allow-untyped-defs
import operator
from typing import Any, Callable, Dict, Tuple, Optional

import torch
import torch.fx
import torch.fx as fx
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx.operator_schemas import (
    normalize_module,
    normalize_function,
    create_type_hint,
)

from .schema_type_annotation import AnnotateTypesWithSchema


class NormalizeArgs(Transformer):
    """
    Normalize arguments to Python targets. This means that
    `args/kwargs` will be matched up to the module/functional's
    signature and rewritten to exclusively kwargs in positional order
    if `normalize_to_only_use_kwargs` is true. Also populates default
    values. Does not support positional-only parameters or varargs
    parameters (*args, **kwargs).

    If the nodes have 'type' metadata, it will use it to disambiguate
    overloads. Otherwise, it will throw an error.

    Example usage:
        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()
    """

    def __init__(
        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
    ):
        """
        Args:
            module: The graph module whose call sites are normalized.
            normalize_to_only_use_kwargs: Forwarded to `normalize_function` /
                `normalize_module` for every call site.
        """
        super().__init__(module)
        # Maps each proxy produced by `run_node` back to the node it came from.
        self.node_map: Dict[Proxy, Node] = {}
        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
        """
        Run a single node, routing `call_function` nodes through the
        type-aware `call_function` of this class.

        For every node except `output`, the result is recorded in
        `self.node_map` and the `meta` and `type` of `n` are copied onto the
        newly created node.
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg):
            if isinstance(arg, fx.Node):
                # NOTE(review): the 'type' metadata is read from `n` (the node
                # being run), not from `arg` -- confirm this is intended.
                return n.meta.get("type")
            return type(arg)

        arg_types = map_aggregate(n.args, get_type)
        assert isinstance(arg_types, tuple)
        arg_types = tuple(create_type_hint(i) for i in arg_types)
        kwarg_types = {k: get_type(v) for k, v in kwargs.items()}

        if n.op == "call_function":
            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
        else:
            out = super().run_node(n)

        if n.op != "output":
            self.node_map[out] = n
            out.node.meta = n.meta
            out.node.type = n.type
        return out

    def call_function(
        self,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
        arg_types: Optional[Tuple[Any, ...]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
    ):
        """
        Emit a `call_function` node with normalized arguments.

        `arg_types` / `kwarg_types` are passed through to
        `normalize_function`. When `normalize_function` returns a falsy
        result, the call is emitted unchanged via the parent class.
        """
        assert callable(target)
        new_args_and_kwargs = normalize_function(
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            arg_types,  # type: ignore[arg-type]
            kwarg_types,
            self.normalize_to_only_use_kwargs,
        )
        if not new_args_and_kwargs:
            return super().call_function(target, args, kwargs)

        new_args, new_kwargs = new_args_and_kwargs
        return self.tracer.create_proxy(
            "call_function", target, new_args, new_kwargs
        )

    def call_module(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Emit a `call_module` node with normalized arguments.

        When `normalize_module` returns a falsy result, the call is emitted
        with its original `args` / `kwargs`.
        """
        assert isinstance(target, str)
        new_args_and_kwargs = normalize_module(
            self.module,
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            self.normalize_to_only_use_kwargs,
        )
        if not new_args_and_kwargs:
            return super().call_module(target, args, kwargs)

        new_args, new_kwargs = new_args_and_kwargs
        return super().call_module(target, new_args, new_kwargs)


class NormalizeOperators(AnnotateTypesWithSchema):
    """
    Normalize callsites that are different ways of "spelling" the same
    invocation into a single, canonical call. Currently supports:

    1. Normalize binary `torch` ops (e.g. torch.add) to the Python operators
       that spell the same invocation (e.g. operator.add), when the call has
       exactly two positional arguments and no keyword arguments.

    Example usage:

        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeOperators(traced).transform()
    """

    # torch function -> Python operator it is rewritten to.
    binary_magic_method_remap: Dict[
        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
    ] = {
        torch.add: operator.add,
        torch.mul: operator.mul,
        torch.sub: operator.sub,
        torch.div: operator.truediv,
        torch.floor_divide: operator.floordiv,
        torch.remainder: operator.mod,
        torch.eq: operator.eq,
        torch.ne: operator.ne,
        torch.lt: operator.lt,
        torch.le: operator.le,
        torch.gt: operator.gt,
        torch.ge: operator.ge,
    }

    def call_function(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Rewrite a call to a remappable `torch` op into the matching Python
        operator; every other call is forwarded unchanged.

        A call is only rewritten when it has exactly two positional arguments
        and no keyword arguments.
        """
        # Normalize operators according to the magic methods implemented on tensors here:
        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950

        assert callable(target)

        remapped = self.binary_magic_method_remap.get(target)

        # Keyword arguments (e.g. `alpha=` on torch.add or `rounding_mode=` on
        # torch.div) cannot be expressed through the Python operator, so
        # rewriting such a call would silently drop them and change the
        # result. Leave those calls, and calls of unexpected arity, untouched.
        if remapped is None or len(args) != 2 or kwargs:
            return super().call_function(target, args, kwargs)

        lhs, rhs = args
        return super().call_function(
            target=remapped,
            args=(lhs, rhs),
            kwargs={},
        )