# mypy: allow-untyped-defs """Utilities for manipulating the torch.Graph object and the torchscript.""" # TODO(justinchuby): Move more of the symbolic helper functions here and expose # them to the user. from __future__ import annotations import dataclasses import re import typing from typing import Any, Iterable, Sequence import torch from torch import _C from torch.onnx._globals import GLOBALS from torch.onnx._internal import registration _ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$") _SKIP_NODE_ATTRIBUTES = {"inplace", "aten"} @dataclasses.dataclass class GraphContext: """Extra context for symbolic functions with all methods from torch.Graph. NOTE: This class is not meant for external consumption. Please do not depend on it outside of torch.onnx as the interface may evolve. Attributes: graph: The _C.Graph being constructed. block: The current _C.Block being constructed. opset: The opset version. original_node: Current node that is being converted from. params_dict: Mapping from graph initializer name to IValue. env: Mapping from Torch domain graph Value to ONNX domain graph Value. values_in_env: Set of all values in env, for constant-time lookups. new_nodes: List that tracks all new nodes that are added (used to make sure metadata is propagated to all new nodes). """ graph: _C.Graph block: _C.Block opset: int original_node: _C.Node params_dict: dict[str, _C.IValue] env: dict[_C.Value, _C.Value] values_in_env: set[_C.Value] new_nodes: list[_C.Node] = dataclasses.field(default_factory=list) # Relay methods from _C.Graph for compatibility with symbolic functions that expect # a _C.Graph def __getattr__(self, name: str) -> Any: return getattr(self.graph, name) def op( self, opname: str, *raw_args: torch.Tensor | _C.Value, outputs: int = 1, **kwargs, ): """Creates an ONNX operator "opname", taking "raw_args" as inputs and "kwargs" as attributes. The set of operators and the inputs/attributes they take is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md Args: opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified with a namespace, e.g., `aten::add`. raw_args: The inputs to the operator; usually provided as arguments to the `symbolic` definition. outputs: The number of outputs this operator returns. By default an operator is assumed to return a single output. If `outputs` is greater than one, this functions returns a tuple of output `Value`, representing each output of the ONNX operator in order. kwargs: The attributes of the ONNX operator, whose keys are named according to the following convention: `alpha_f` indicates the `alpha` attribute with type `f`. The valid type specifiers are `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute specified with type float accepts either a single float, or a list of floats (e.g., you would say `dims_i` for a `dims` attribute that takes a list of integers). Returns: The value representing the single output of this operator (see the `outputs` keyword argument for multi-return nodes). """ # FIXME(justinchuby): Add the return type back once we know how to handle mypy return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs) def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs): """Generates an ONNX ATen op node. This function is for backward compatibility with the old symbolic functions. """ return self.op( "aten::ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs, ) # NOTE: For backward compatibility with the old symbolic functions. # We are probably going to remove this only after the fx exporter is established. at = aten_op def onnxscript_op( self, onnx_fn, *raw_args: torch.Tensor | _C.Value, outputs: int = 1, **kwargs, ): """Creates an ONNX operator from onnx-script function, taking "raw_args" as inputs and "kwargs" as attributes. onnx-script repository: https://github.com/microsoft/onnx-script Args: onnx_fn: ONNXFunction from onnx-script; An example can be found at https://github.com/microsoft/onnx-script#example raw_args: The inputs to the operator; usually provided as arguments to the `symbolic` definition. outputs: The number of outputs this operator returns. By default an operator is assumed to return a single output. If `outputs` is greater than one, this functions returns a tuple of output `Value`, representing each output of the ONNX operator in order. kwargs: The attributes of the ONNX operator, whose keys are named according to the following convention: `alpha_f` indicates the `alpha` attribute with type `f`. The valid type specifiers are `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute specified with type float accepts either a single float, or a list of floats (e.g., you would say `dims_i` for a `dims` attribute that takes a list of integers). Returns: The value representing the single output of this operator (see the `outputs` keyword argument for multi-return nodes). """ # NOTE(titaiwang): This is using class attributes, and it needs to be updated # if onnx-script makes any change on these. symbolic_name = f"{onnx_fn.opset.domain}::{onnx_fn.name}" opset_version = onnx_fn.opset.version registration.custom_onnx_symbolic(symbolic_name, opset_version)(onnx_fn) return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs) def add_op_with_blocks( graph_context: GraphContext, opname: str, *inputs: _C.Value, outputs: int = 1, n_blocks: int = 1, **attributes, ) -> tuple[Any, tuple[GraphContext, ...], _C.Node]: """Creates an ONNX operator "opname", taking inputs and attributes. Args: graph_context: The context for the current graph. opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified with a namespace, e.g., `aten::add`. inputs: The inputs to the operator. outputs: The number of outputs this operator returns. By default an operator is assumed to return a single output. If `outputs` is greater than one, this functions returns a tuple of output `Value`, representing each output of the ONNX operator in order. n_blocks: The number of sub-blocks to create in the node. attributes: The attributes of the ONNX operator. Returns: A tuple of (output_values, new_contexts, node) where: output_values: One or more output value of this operator (see the `outputs` keyword argument for multi-return nodes). new_contexts: A tuple of new graph contexts for each sub-block. node: The node representing the operator. """ output_values = graph_context.op(opname, *inputs, outputs=outputs, **attributes) if isinstance(output_values, Sequence): node = output_values[0].node() else: node = output_values.node() new_contexts = [] for _ in range(n_blocks): new_block = node.addBlock() # Create shallow copy of the graph context and update the block new_context = dataclasses.replace(graph_context, block=new_block) new_contexts.append(new_context) return output_values, tuple(new_contexts), node def _add_op( graph_context: GraphContext, opname: str, *args: torch.Tensor | _C.Value, outputs: int = 1, **kwargs, ): """Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs". The set of operators and the inputs/attributes they take is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md This function is monkey-patched onto Graph. Args: graph_context: The Torch Graph or Block. opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified with a namespace, e.g., `aten::add`. args: The inputs to the operator; usually provided as arguments to the `symbolic` definition. outputs: The number of outputs this operator returns. By default an operator is assumed to return a single output. If `outputs` is greater than one, this functions returns a tuple of output `Value`, representing each output of the ONNX operator in order. kwargs: The attributes of the ONNX operator, whose keys are named according to the following convention: `alpha_f` indicates the `alpha` attribute with type `f`. The valid type specifiers are `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute specified with type float accepts either a single float, or a list of floats (e.g., you would say `dims_i` for a `dims` attribute that takes a list of integers). Returns: (Union[_C.Value, Tuple[_C.Value, ...]]) The value representing the single output of this operator (see the `outputs` keyword argument for multi-return nodes). """ inputs = [_const_if_tensor(graph_context, arg) for arg in args] # Filter out None attributes, this can be convenient client side because # now they can pass through None attributes, and have them not show up attributes = {k: v for k, v in kwargs.items() if v is not None} if "::" not in opname: opname = "onnx::" + opname node = _create_node( graph_context.block, opname, inputs, attributes, params_dict=graph_context.params_dict, opset_version=graph_context.opset, n_outputs=outputs, shape_inference=GLOBALS.onnx_shape_inference, ) graph_context.new_nodes.append(node) if outputs == 1: return node.output() return tuple(node.outputs()) def _const_if_tensor(graph_context: GraphContext, arg): if arg is None: return arg if isinstance(arg, _C.Value): return arg return _add_op(graph_context, "onnx::Constant", value_z=arg) def _create_node( graph_or_block: _C.Graph | _C.Block, domain_op: str, inputs: Sequence, attributes: dict, params_dict: dict, opset_version: int, n_outputs: int, shape_inference: bool = True, ) -> _C.Node: """Creates an node 'domain_op', taking inputs and attributes.""" if isinstance(graph_or_block, _C.Graph): graph = graph_or_block node = graph.create(domain_op, inputs, n_outputs) node = graph.insertNode(node) elif isinstance(graph_or_block, _C.Block): block = graph_or_block node = block.addNode(domain_op, inputs) # Block does not have create defined, so we need to add outputs manually if n_outputs > 1: for _ in range(1, n_outputs): node.addOutput() node_outputs = tuple(node.outputs()) # type: ignore[possibly-undefined] assert len(node_outputs) == n_outputs aten = domain_op.startswith("aten::") # Add all attributes for key, value in sorted(attributes.items()): if key in _SKIP_NODE_ATTRIBUTES: continue _add_attribute(node, key, value, aten=aten) if shape_inference: _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version) return node def _is_onnx_list(value): return isinstance(value, Iterable) and not isinstance( value, (str, bytes, torch.Tensor) ) def _scalar(x: torch.Tensor): """Convert a scalar tensor into a Python value.""" assert x.numel() == 1 return x[0] def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool): r"""Initializes the right attribute based on type of value.""" m = _ATTR_PATTERN.match(key) if m is None: raise ValueError( f"Invalid attribute specifier '{key}' names " "must be suffixed with type, e.g. 'dim_i' or 'dims_i'" ) name, kind = m.group(1), m.group(2) if _is_onnx_list(value): kind += "s" return getattr(node, f"{kind}_")(name, value) # TODO: Expose this to user when migrating symbolic helper functions to here. def _is_tensor(x: _C.Value) -> bool: return x.type().isSubtypeOf(_C.TensorType.get()) def get_device_from_value(value: _C.Value) -> torch.device | None: if not _is_tensor(value): return None tensor_type = typing.cast(_C.TensorType, value.type()) return tensor_type.device() def parse_node_kind(kind: str) -> tuple[str, str]: """Parse node kind into domain and Op name.""" if "::" not in kind: raise ValueError(f"Node kind: {kind} is invalid. '::' is not in node kind.") domain, opname = kind.split("::", 1) if "::" in opname: raise ValueError(f"Node kind: {kind} is invalid. '::' should only apear once.") return domain, opname def is_aten(domain: str) -> bool: """Check if the domain is official.""" return domain == "aten" def is_prim(domain: str) -> bool: """Check if the domain is official.""" return domain == "prim" def is_onnx(domain: str) -> bool: """Check if the domain is official.""" return domain == "onnx"