# mypy: allow-untyped-defs import operator from collections import defaultdict from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type import sympy import torch import torch.fx from torch.fx.experimental.symbolic_shapes import ( compute_unbacked_bindings, rebind_unbacked, statically_known_true, sym_eq, ) from torch.utils import _pytree as pytree from torch.utils._pytree import tree_map from .virtualized import V # Check the pattern: (nn.module, F.function/torch.Tensor.method) matched. # Works for length 2 patterns with 1 module and 1 function/method. def matches_module_function_pattern( pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]], node: torch.fx.node.Node, modules: Dict[str, torch.nn.modules.Module], ) -> bool: if len(node.args) == 0: return False if not isinstance(node.args[0], torch.fx.Node) or not isinstance( node, torch.fx.Node ): return False # the first node is call_module if node.args[0].op != "call_module": return False if not isinstance(node.args[0].target, str): return False if node.args[0].target not in modules: return False if type(modules[node.args[0].target]) is not pattern[0]: return False # the second node is call_function or call_method if node.op != "call_function" and node.op != "call_method": return False if node.target != pattern[1]: return False # make sure node.args[0] output is only used by current node. if len(node.args[0].users) > 1: return False return True class FakeTensorUpdater: """ The main idea here is that it's difficult to maintain accurate fake tensors (our primary form of metadata) for each node in our graph as we transform it. The most reliable way to obtain this information is by rerunning faketensor propagation. However, in general, faketensor propagation is fairly expensive. So, instead we'd like to only rerun faketensor propagation on nodes that have changed. In order to detect which nodes have changed, we first hash its node, target, and argument lists (which are immutable in FX). Then, whenever we call incremental_update, we check which FX nodes have a new hash, and recompute the faketensor metadata for that node. Then, we continue to recursively compute the faketensors for all users until the fake tensors stop changing. """ def __init__(self, graph: torch.fx.Graph) -> None: self.processed_hashes = set() self.graph = graph for node in self.graph.nodes: self.processed_hashes.add(self.hash_node(node)) def hash_node(self, node: torch.fx.Node): # todo(chilli): Not a great hash function return (node, node.target, id(node.args), id(node.kwargs)) def incremental_update(self): processed = set() existing_storages: DefaultDict[Optional[int], int] = defaultdict(int) for node in self.graph.nodes: existing_storages[get_node_storage(node)] += 1 def is_intlist_same(new, old): return statically_known_true(sym_eq(new, old)) def is_fake_tensor_same(new, old): if type(new) != type(old): return False if isinstance(new, (list, tuple)): if len(new) != len(old): return False return all( is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old) ) if new is None: return old is None if not isinstance(new, torch.Tensor): assert isinstance( new, (torch.SymInt, torch.SymBool, torch.SymFloat) ), f"Unknown type {type(new)} in {self.graph}" return ( new.node.shape_env._maybe_evaluate_static( sympy.Eq(new.node.expr, old.node.expr) ) == sympy.true ) if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout: return False if new.layout == torch.strided and ( not is_intlist_same(new.stride(), old.stride()) or not statically_known_true( new.storage_offset() == old.storage_offset() ) ): return False if new.device != old.device: return False if get_storage(new) == get_storage(old): return True # This is the case where it returns a completely fresh storage that's used nowhere else. if ( existing_storages[get_storage(old)] == 1 and get_storage(new) not in existing_storages ): return True return False def should_process_node(node): # node.target for nodes returning true from this function # are called under fake mode and does not work for inductor # lowerings. We check if the node.target is an aten operator # or operator.getitem which is used when returning multiple # tensors from an op. return node.op == "call_function" and ( isinstance(node.target, torch._ops.OpOverload) or node.target == operator.getitem ) to_process = set() for node in self.graph.nodes: if ( self.hash_node(node) in self.processed_hashes and id(node) not in to_process ): continue if not should_process_node(node): continue is_valid, args, kwargs = get_fake_args_kwargs(node) if not is_valid: continue with V.fake_mode: new_fake_tensor = node.target(*args, **kwargs) if "val" in node.meta and is_fake_tensor_same( new_fake_tensor, node.meta["val"] ): continue rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor) node.meta["val"] = new_fake_tensor if (shape_env := V.fake_mode.shape_env) and ( symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor) ): # Refresh the bindings to the new symbols node.meta["unbacked_bindings"] = symbol_to_path existing_storages[get_node_storage(node)] += 1 to_process.update([id(user) for user in node.users]) self.processed_hashes.add(self.hash_node(node)) def get_storage(t: torch.Tensor) -> int: return t.untyped_storage()._cdata def get_node_storage(node: torch.fx.Node) -> Optional[int]: if "val" not in node.meta: return None if not isinstance(node.meta["val"], torch.Tensor): return None if not torch._C._has_storage(node.meta["val"]): return None return get_storage(node.meta["val"]) def get_fake(x): if isinstance(x, torch.fx.Node): if "val" not in x.meta: return x return x.meta["val"] return x def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]: """ First value returns a boolean if any of the input nodes don't have a faketensor. """ args, kwargs = tree_map(get_fake, (x.args, x.kwargs)) if any( isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs) ): return False, args, kwargs return True, args, kwargs def is_node_realized(node: torch.fx.Node) -> bool: """Returns true if a node is always realized when lowered to inductor IR. NOTE: This may return some false negatives. e.g. it doesn't handle buffers realized heuristically during lowering, or buffers realized indirectly through view ops. """ from torch._inductor.lowering import fallbacks, needs_realized_inputs def is_buffer(node: torch.fx.Node) -> bool: if node.op == "call_function" and node.target is operator.getitem: # For nodes with multiple outputs, we get the fx graph: # foo = torch.ops.aten.foo(...) # getitem = foo[0] # getitem_1 = foo[1] # where we need to check if foo is a fallback kernel return is_buffer(node.args[0]) # type: ignore[arg-type] return node.op in ("placeholder", "output") or node.target in fallbacks if is_buffer(node): return True def realizes_inputs(node: torch.fx.Node) -> bool: return node.op == "output" or node.target in needs_realized_inputs if any(realizes_inputs(user) for user in node.users): return True # Otherwise, assume node isn't realized return False