# mypy: allow-untyped-defs
import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union

import torch
import torch._inductor.config as inductor_config
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
    _detect_infra_mode,
    _disable_infra_mode,
    return_and_correct_aliasing,
    TorchDispatchMode,
)


not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")


# NOTE: Some special handling for tensor conversion during export is needed.
# Normally, when tracing through the model with tensor.to(), the maybe-aliasing
# relationship between input and output tensors will be baked into the graph.
# For example, if we get a tensor on device cpu and call tensor.to("cpu"), it
# becomes a no-op in the graph. For whole-graph capture this is not sound, so
# we need to do something different. Instead, in export we preserve the tensor
# conversion by forcing a non-semantic-breaking aten::_to_copy operator to be
# traced in the graph, and subsequently ban mutations on all such converted
# tensors.
# In addition to patching the .to() method call in functionalization, we also
# have to patch similar methods like float() and cpu(), because they
# intentionally don't fall back to .to(), even though they have the same
# behavior as .to() according to the PyTorch documentation:
# https://pytorch.org/docs/stable/generated/torch.Tensor.float.html
# Thus we simply force them to go through the .to() call.
def _conversion_method_template(**extra_kwargs):
    def _(self, *args, **kwargs):
        return self.to(*args, **{**kwargs, **extra_kwargs})

    return _


class FunctionalTensor(torch.Tensor):
    """
    Functional tensors represent tensors that will remove mutations
    from a program. If you perform a mutable operation on a functional tensor,
    it will re-dispatch to the functional variant of that operation.

    Historically, functionalization is implemented in C++ in the dispatcher.
    This class is a lightweight python shim around the C++ functionalization logic.

    FunctionalTensor is required to be used with a corresponding
    FunctionalTensorMode active, because it relies on using the mode for dispatch
    (which can properly handle factory functions).
    """

    elem: torch.Tensor
    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.
    _mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL

    # Note: The reason we add these extra keys to our FunctionalTensor subclass
    # is to mirror the behavior of C++ functionalization (we can choose to change
    # this later, as long as it doesn't break anything).
    # FunctionalTensorWrapper copies **all** dispatch keys from the inner tensor
    # to the wrapper, excluding functorch and python dispatch keys.
    # Here I'm trying to re-use the keyset the functorch wrapper subclasses copy,
    # except that they don't include ZeroTensor so I'm manually adding it in.
    _extra_dispatch_keys = torch._C._additional_keys_to_prop_for_wrapper_tensors.add(
        torch._C.DispatchKey.ZeroTensor
    )

    # These are all aten ops that correspond to metadata queries.
    # We want FunctionalTensor to be able to handle them directly.
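    # For example (an illustrative sketch): under an active FunctionalTensorMode,
    # a metadata query like `x.dim()` on a FunctionalTensor dispatches
    # aten.dim.default here, and we answer it from the inner `elem` tensor
    # rather than from the wrapper.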
    metadata_fns = [
        torch.ops.aten.is_contiguous.default,  # type: ignore[has-type]
        torch.ops.aten.is_contiguous.memory_format,  # type: ignore[has-type]
        torch.ops.aten.is_strides_like_format.default,  # type: ignore[has-type]
        torch.ops.aten.is_non_overlapping_and_dense.default,  # type: ignore[has-type]
        torch.ops.aten.size.default,  # type: ignore[has-type]
        torch.ops.aten.sym_size.default,  # type: ignore[has-type]
        torch.ops.aten.stride.default,  # type: ignore[has-type]
        torch.ops.aten.sym_stride.default,  # type: ignore[has-type]
        torch.ops.aten.storage_offset.default,  # type: ignore[has-type]
        torch.ops.aten.sym_storage_offset.default,  # type: ignore[has-type]
        torch.ops.aten.numel.default,  # type: ignore[has-type]
        torch.ops.aten.sym_numel.default,  # type: ignore[has-type]
        torch.ops.aten.dim.default,  # type: ignore[has-type]
        torch.ops.prim.device.default,  # type: ignore[has-type]
    ]

    # These are ops that claim to be functional, but are actually
    # maybe-mutating/maybe-aliasing.
    # TODO (tmanlaibaatar) make it a tag
    maybe_aliasing_or_mutating_ops = [
        torch.ops.aten.dropout.default,  # type: ignore[has-type]
        torch.ops.aten.batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.native_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten._batch_norm_impl_index.default,  # type: ignore[has-type]
        torch.ops.aten.cudnn_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.miopen_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.atleast_1d.default,  # type: ignore[has-type]
        torch.ops.aten.atleast_2d.default,  # type: ignore[has-type]
        torch.ops.aten.atleast_3d.default,  # type: ignore[has-type]
        torch.ops.aten.cartesian_prod.default,  # type: ignore[has-type]
        torch.ops.aten.conj_physical.default,  # type: ignore[has-type]
        torch.ops.aten.alpha_dropout.default,  # type: ignore[has-type]
        torch.ops.aten.feature_dropout.default,  # type: ignore[has-type]
        torch.ops.aten.feature_alpha_dropout.default,  # type: ignore[has-type]
        torch.ops.aten.unsafe_chunk.default,  # type: ignore[has-type]
    ]

    # Used by auto_functionalize to determine the base of tensors during
    # inference mode.
    _inference_mode_base: Optional["FunctionalTensor"] = None

    def __new__(cls, elem, mode):
        assert torch._is_functional_tensor(elem)

        # In general, we'd like our functional tensor subclass to only be in charge
        # of functionalization, and defer to the inner subclass for all other
        # functionality.
        # Example: If our inner tensor is a ZeroTensor, we would want to defer
        # running the ZeroTensor fallback until after we redispatch to our inner
        # ZeroTensor.
        # However, there are a few keys that we need to mirror between the inner
        # and outer tensors:
        #   Conjugate
        #   Negative
        # Why? These keys are used to test metadata queries, like `.is_conj()` and
        # `.is_neg()`. We **need** calls to is_conj() to return the same thing on
        # the outer and inner tensors, because user code / framework code that
        # branches like so needs to do the same thing when it sees the outer
        # FunctionalTensor:
        #     if (x.is_conj()) {
        #         return at::view_as_real(x.resolve_conj());
        #     } else {
        #         return at::view_as_real(x);
        #     }
        extra_dispatch_keys = (
            FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem)
        )

        out = torch.Tensor._make_wrapper_subclass(  # type: ignore[arg-type, attr-defined]
            # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
            # Calling the overload that has kwargs causes us to go down the first overload path,
            # which will **always** specialize sizes.
            # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
            cls,
            elem.shape,  # sizes
            elem.stride() if not is_sparse_any(elem) else None,  # strides
            (
                elem.storage_offset() if not is_sparse_any(elem) else None
            ),  # storage_offset
            None,  # memory_format
            elem.dtype,  # dtype
            elem.layout,  # layout
            elem.device,  # device
            False,  # pin_memory
            elem.requires_grad,  # requires_grad
            None,  # dispatch_sizes_strides_policy
            False,  # dispatch_device
            False,  # dispatch_layout
            extra_dispatch_keys,  # _extra_dispatch_keys
        )
        torch._C._set_throw_on_mutable_data_ptr(out)
        out.elem = elem

        if (
            torch.is_inference_mode_enabled()
            and torch._inductor.config.enable_auto_functionalized_v2
        ):
            if out.is_base_tensor():
                out._inference_mode_base = None
                # This assumes that the FunctionalTensor.elem does not change its
                # storage after this point. Otherwise this would be invalid.
                mode._storage_to_base[out.elem.untyped_storage()] = out
            else:
                out._inference_mode_base = mode._storage_to_base[
                    out.elem.untyped_storage()
                ]
                assert out._inference_mode_base is not None
        return out

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        unrecognized_types = [
            t
            for t in types
            if t not in [torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor]
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        if kwargs is None:
            kwargs = {}

        # FunctionalTensor needs to plumb all metadata requests to the inner tensor.
        # In theory we don't have to do this - but if we want to service metadata
        # requests here, we need to carefully make sure all metadata is accurate
        # (including metadata mutations)
        if func in FunctionalTensor.metadata_fns:
            # All metadata accesses should be plumbed to the inner tensor, that way
            # we don't have to worry about the problem of keeping metadata in sync
            # between the wrapper and inner tensor.
            # This also alleviates us from having to manually handle metadata
            # mutations on the wrapper.
            assert len(kwargs) == 0
            if func in [
                torch.ops.aten.is_strides_like_format.default,
                torch.ops.aten.is_contiguous.memory_format,
            ]:
                assert len(args) == 2 and isinstance(args[0], FunctionalTensor)
                return func(torch._from_functional_tensor(args[0].elem), args[1])
            assert len(args) == 1 and isinstance(args[0], FunctionalTensor)
            return func(torch._from_functional_tensor(args[0].elem))

        # Originally I tried to implement my subclass without giving it a torch_dispatch, but I gave up:
        # - _make_wrapper_subclass requires a __torch_dispatch__
        # - If we want to use _make_subclass(), we have a problem: the subclass will share a TensorImpl with the inner tensor,
        #   which is of type FunctionalTensorWrapper! We explicitly do not want our wrapper to be a FunctionalTensorWrapper.
        # - If we use the default tensor.__new__(), we have another problem: it returns inner_tensor.alias(),
        #   which causes every subclass created above autograd to have autograd view metadata
        #   (in addition to also being a FunctionalTensorWrapper).
        raise RuntimeError(
            "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()"
        )

    def __repr__(self):
        return f"FunctionalTensor({repr(self.elem)})"

    @staticmethod
    def to_functional(x):
        # We will do the wrapping for the user.
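        # A sketch of the flow below: `FunctionalTensor.to_functional(t)` wraps a
        # plain tensor `t` in a C++ FunctionalTensorWrapper (via
        # torch._to_functional_tensor), then wraps that in this python subclass.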
        assert not torch._is_functional_tensor(x)
        # The only autograd metadata we care about on the FunctionalTensor is:
        # - requires_grad (so autograd runs)
        # - is_leaf (so that mutations on graph inputs that are not leaves are
        #   allowed by the autograd engine); this is handled by
        #   FunctionalTensor.to_functional
        x_functional = torch._to_functional_tensor(x)
        # Technically the FunctionalTensorMode here is unnecessary,
        # but it avoids spurious NotImplemented logs during `ProxyTorchDispatchMode`
        # tracing. _mirror_autograd_meta_to queries tensor sizes, and otherwise the
        # sym_size() call will go to the proxy mode before hitting
        # FunctionalTensor.__torch_dispatch__
        functional_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        assert functional_mode is not None

        with functional_mode:
            torch._mirror_autograd_meta_to(x, x_functional)  # type: ignore[attr-defined]
            out = FunctionalTensor(x_functional, functional_mode)
            torch._mirror_autograd_meta_to(x_functional, out)  # type: ignore[attr-defined]
        return out

    def from_functional(self):
        torch._sync(self)
        return torch._from_functional_tensor(self.elem)

    def is_base_tensor(self) -> bool:
        return torch._is_functional_tensor_base(self.elem)

    def replace_(self, output) -> None:
        torch._functionalize_replace(self.elem, output)

    def commit_update(self) -> None:
        torch._functionalize_commit_update(self.elem)

    def sync(self) -> None:
        torch._functionalize_sync(self.elem)

    def mark_mutation_hidden_from_autograd(self) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(self.elem)

    def tolist(self) -> Any:
        if self.elem.dim() == 0:
            return self.elem.item()
        elif self.elem.dim() == 1:
            return [elem.item() for elem in self.elem]
        else:
            return [elem.tolist() for elem in self.elem]

    def to(self, *args, **kwargs):
        if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export:
            # If `copy` is specified as a positional arg, it's always the second one.
            if len([arg for arg in args if isinstance(arg, bool)]) <= 1:
                return super().to(*args, **{**kwargs, "copy": True})
        return super().to(*args, **kwargs)

    def cuda(self, device=None, *args, **kwargs):
        device = device or torch.cuda.current_device()
        if len(args) > 0:
            return self.to(device, *args, **kwargs)
        else:
            return self.to(device=device, **kwargs)

    char = _conversion_method_template(dtype=torch.int8)
    cpu = _conversion_method_template(device=torch.device("cpu"))
    bfloat16 = _conversion_method_template(dtype=torch.bfloat16)
    byte = _conversion_method_template(dtype=torch.uint8)
    double = _conversion_method_template(dtype=torch.float64)
    float = _conversion_method_template(dtype=torch.float32)
    bool = _conversion_method_template(dtype=torch.bool)
    half = _conversion_method_template(dtype=torch.float16)
    int = _conversion_method_template(dtype=torch.int32)
    long = _conversion_method_template(dtype=torch.int64)

    # TODO(sparse-team): fixes #133174 but can we do without the relay?
    def to_dense(self):
        return self.elem.to_dense()

    @property
    def layout(self):
        return self.elem.layout


class FunctionalTensorMode(TorchDispatchMode):
    def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
        super().__init__()
        self.export = export
        self.is_on_stack = False
        self.enter_stack = []
        # Indicates to our torch_dispatch dispatching infra that
        # this is an "infra" mode with lower dispatching precedence.
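        # (Other infra modes, e.g. FakeTensorMode and ProxyTorchDispatchMode,
        # follow the same convention.)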
        self._mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL
        self.pre_dispatch = pre_dispatch
        # This will be turned off later for pre-dispatch functionalization
        self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None  # type: ignore[attr-defined]
        # Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help
        # keep track of the ordering between side-effectful operations.
        self._tokens: Dict[Any, torch.Tensor] = {}
        # Filled after forward tracing.
        self._tokens_forward_output: Dict[Any, torch.Tensor] = {}

        # Functionalization runs twice in AOTAutograd: once in
        # `run_functionalized_fw_and_collect_metadata`, to collect metadata on
        # which tensors need to be functionalized and to discover how many tokens
        # we need, and a second time in `make_fx`, which does the actual tracing
        # to replace ops with their functional variants and to handle
        # side-effectful ops. In the second stage there should be no token
        # discovery. This flag distinguishes between the two stages.
        self._allow_token_discovery = _allow_token_discovery

        self._storage_to_base: weakref.WeakKeyDictionary[
            torch.storage.UntypedStorage, Optional[FunctionalTensor]
        ] = weakref.WeakKeyDictionary()

    # No-op if FunctionalTensorMode is already in use
    def __enter__(self):
        def _get_prev_mode():
            if self._dispatch_key == torch._C.DispatchKey.PreDispatch:
                return _get_dispatch_mode_pre_dispatch(
                    torch._C._TorchDispatchModeKey.FUNCTIONAL
                )
            return torch._C._get_dispatch_mode(
                torch._C._TorchDispatchModeKey.FUNCTIONAL
            )

        if _get_prev_mode() is None:
            self.enter_stack.append(True)
            return super().__enter__()
        else:
            self.enter_stack.append(False)
            return self

    def __exit__(self, a, b, c):
        is_on_stack = self.enter_stack.pop()
        if is_on_stack:
            super().__exit__(a, b, c)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if self.export:
            # We need to make sure that we don't decompose to() as usual in export
            # mode, because it can get optimized away. Instead we always replace
            # it with _to_copy().
            if func == torch.ops.aten.to.dtype_layout:
                kwargs.pop("copy", None)
                return self.__torch_dispatch__(
                    torch.ops.aten._to_copy.default, types, args, kwargs
                )
            if func == torch.ops.aten.to.dtype:
                schema = tuple(arg.name for arg in func._schema.arguments)
                for arg, name in zip(args[1:], schema[1:]):
                    kwargs[name] = arg
                kwargs.pop("copy", None)
                return self.__torch_dispatch__(
                    torch.ops.aten._to_copy.default, types, args[:1], kwargs
                )

        unrecognized_types = [
            t
            for t in types
            if not issubclass(t, torch._subclasses.FakeTensor)
            and t not in [torch.Tensor, FunctionalTensor]
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        def _can_decompose(func):
            # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832
            # Never decompose dropout in export.
            if self.export and func == torch.ops.aten.dropout.default:
                return False

            # We unconditionally decompose ops that are maybe-aliasing or maybe-mutating.
            if func in FunctionalTensor.maybe_aliasing_or_mutating_ops:
                return True

            # (1) We unconditionally decompose maybe-aliasing or maybe-mutating ops,
            #     because we must know statically whether an op mutates or aliases
            #     in order to functionalize it properly.
            # (2) For mutating ops that have CompositeImplicit decomps, we choose to
            #     decompose them today. In theory, we could walk this back and avoid
            #     decomposing them later if we need to.
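            # For example, an op whose schema contains `Tensor(a!)` annotations
            # carries alias_info on those arguments and is marked mutable, so the
            # check below sends it down the decompose path.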
            alias_info_present = any(arg.alias_info for arg in func._schema.arguments)
            if alias_info_present or func._schema.is_mutable:
                return True

            # If we get here, we are seeing a functional composite op.
            # For pre-dispatch IR or export inference IR, we won't decompose them.
            if (self.export or self.pre_dispatch) and func._can_decompose():
                if func.namespace not in ["aten", "prim"]:
                    # TODO (tmanlaibaatar) check if the op is PT2 compliant
                    warnings.warn(
                        f"At pre-dispatch tracing, we assume that any custom op marked with "
                        f"CompositeImplicitAutograd that has a functional schema is safe to not decompose. "
                        f"Found {func} to be one such op."
                    )
                return False

            # In normal torch.compile IR, we decompose functional composite ops.
            return True

        if (
            func not in FunctionalTensor.metadata_fns
            and _can_decompose(func)
            # Not all funcs from __torch_dispatch__ are actual dispatcher ops,
            # e.g. prim.device
            and torch._C._dispatch_has_kernel(func.name())
        ):
            with self:
                r = func.decompose(*args, **kwargs)
                if r is not NotImplemented:
                    return r

        def wrap(x):
            # Only wrap our outputs in subclasses if the inner functionalization call
            # also wrapped outputs into FunctionalTensorWrappers.
            # When can this happen? e.g. `torch.div(2, 2)`
            assert not isinstance(x, FunctionalTensor)
            if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
                return FunctionalTensor(x, self)
            return x

        def unwrap(x):
            return x.elem

        from torch._higher_order_ops.auto_functionalize import (
            can_auto_functionalize,
            do_auto_functionalize,
            do_auto_functionalize_v2,
        )

        if can_auto_functionalize(
            func
        ) and not torch._C._dispatch_has_kernel_for_dispatch_key(
            func.name(), torch._C.DispatchKey.Functionalize
        ):
            # It doesn't matter what mode we use here, because the implementation
            # of do_auto_functionalize doesn't interact with FunctionalTensorMode
            # at all.
            if self.export or not inductor_config.enable_auto_functionalized_v2:
                return do_auto_functionalize(func, args, kwargs)
            else:
                return do_auto_functionalize_v2(func, args, kwargs)

        from torch._higher_order_ops.effects import handle_effects, has_effects

        if has_effects(func, args, kwargs):
            assert not torch._C._dispatch_has_kernel_for_dispatch_key(
                func.name(), torch._C.DispatchKey.Functionalize
            )
            return handle_effects(
                self._allow_token_discovery, self._tokens, func, args, kwargs
            )

        args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
            FunctionalTensor, unwrap, (args, kwargs)
        )

        # Expectation: functionalization should not **already** be enabled above
        # our mode. Why would that be bad? When we return a FunctionalTensor here,
        # we don't want functionalization to run above this mode and further wrap
        # that output in **another** C++ FunctionalTensorWrapper.
        is_included = torch._C._dispatch_tls_is_dispatch_key_included(
            torch._C.DispatchKey.Functionalize
        )
        is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.Functionalize
        )
        assert is_excluded or not is_included
        include_to_set = (
            torch._C._dispatch_tls_local_include_set()
            | torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )
        exclude_to_set = (
            torch._C._dispatch_tls_local_exclude_set().remove(
                torch._C.DispatchKey.Functionalize
            )
            - FunctionalTensor._extra_dispatch_keys
        )

        # All we want to do here is re-use the existing C++ functionalization logic.
        # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
            try:
                # By default for python functionalization (for AOTAutograd), we reapply views.
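                # ("Reapply views" means the trace will contain view ops, e.g.
                # aten.view, rather than their view_copy variants.)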
                old_apply_views = torch._functionalize_enable_reapply_views(True)  # type: ignore[attr-defined]

                # Sometimes these functions cannot be directly dispatched to the
                # Functionalize key, because the args are sometimes not functional
                # tensors for some reason?
                if func in FunctionalTensor.metadata_fns:
                    outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
                    outs_wrapped = pytree.tree_map_only(
                        torch.Tensor, wrap, outs_unwrapped
                    )
                else:
                    # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
                    # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
                    # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
                    # from the TLS in order to avoid infinite looping, but this would prevent us from coming
                    # back to PreDispatch later.
                    outs_unwrapped = func._op_dk(
                        torch._C.DispatchKey.Functionalize,
                        *args_unwrapped,
                        **kwargs_unwrapped,
                    )

                    # We don't allow any mutation on the result of dropout or _to_copy.
                    if self.export:
                        if func in (
                            torch.ops.aten.dropout.default,
                            torch.ops.aten._to_copy.default,
                        ):
                            torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
                    outs_wrapped = pytree.tree_map_only(
                        torch.Tensor, wrap, outs_unwrapped
                    )
            finally:
                torch._disable_functionalization()
                torch._functionalize_enable_reapply_views(old_apply_views)  # type: ignore[attr-defined]

        is_included = torch._C._dispatch_tls_is_dispatch_key_included(
            torch._C.DispatchKey.Functionalize
        )
        is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.Functionalize
        )
        assert is_excluded or not is_included

        if (
            # If no outputs are our functional subclass, then don't try to fix up aliasing
            not any(
                isinstance(x, FunctionalTensor)
                for x in pytree.tree_leaves(outs_wrapped)
            )
            # Since lift_fresh lifts its argument into a functional tensor, we can skip the
            # aliasing correction step. Otherwise, we would be setting the storage of a
            # lifted tensor to that of an unlifted tensor.
            # Ref: https://github.com/pytorch/pytorch/issues/111506
            or func == torch.ops.aten.lift_fresh.default
        ):
            return outs_wrapped
        # For metadata mutations, we need to manually mutate the metadata of the
        # FunctionalTensor wrapper.
        if (
            torch.Tag.inplace_view in func.tags
            and func is not torch.ops.aten.set_.source_Tensor
        ):
            with torch.utils._mode_utils.no_dispatch():
                func(*args, **kwargs)
        # Wrapper tensor subclasses do not have correct aliasing info! Use this util
        # to manually correct the output aliasing.
        # Inplace ops like `aten.add_()` are expected to return their inputs
        # **directly**, instead of creating fresh tensor objects.
        # Use this util to figure out the right thing to return.
        # If none of our inputs were wrapped, then we have no FunctionalTensor
        # outputs that we need to fix up storages for.
        return return_and_correct_aliasing(func, args, kwargs, outs_wrapped)

    @classmethod
    def is_infra_mode(cls) -> bool:
        return True


@contextlib.contextmanager
def disable_functional_mode():
    return _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)


# This is similar to torch.func.functionalize, but:
# - It uses FunctionalTensorMode, and FunctionalTensor (a python subclass).
#   One important advantage of using this mode is that it will let us
#   run functionalization underneath __torch_dispatch__,
#   which we need in AOTAutograd.
# - Doing so means that it does not automatically compose with other
#   functorch transforms, since these transforms always run above __torch_dispatch__.
# That's why this util lives here, and not in functorch.
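#
# A minimal usage sketch (the function `f` and its input here are hypothetical):
#
#     def f(x):
#         x.mul_(2)  # in-place mutation
#         return x + 1
#
#     g = dispatch_functionalize(f)
#     out = g(torch.ones(3))  # runs f with mutations replaced by functional ops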
def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()):
    # TODO: pull these from aot autograd
    def to_fun(t):
        if isinstance(t, torch.Tensor):
            return FunctionalTensor.to_functional(t)
        return t

    def from_fun(t):
        if not isinstance(t, FunctionalTensor):
            # quick sanity assert
            if isinstance(t, torch.Tensor):
                assert not torch._is_functional_tensor(t)
            return t
        torch._sync(t)
        return torch._from_functional_tensor(t.elem)

    def inner(*args, **kwargs):
        disable_above = torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )
        with disable_above, mode:
            func_args = pytree.tree_map_only(torch.Tensor, to_fun, args)
            func_kwargs = pytree.tree_map_only(torch.Tensor, to_fun, kwargs)
            func_outputs = func(*func_args, **func_kwargs)
            outputs = pytree.tree_map_only(FunctionalTensor, from_fun, func_outputs)

            return outputs

    return inner


class BaseFunctionalizeAPI(ABC):
    @abstractmethod
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        pass

    @abstractmethod
    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Any:
        pass

    @abstractmethod
    def functionalize(self, inner_f: Callable) -> Callable:
        pass

    @abstractmethod
    def redispatch_to_next(self) -> ContextManager:
        pass

    @abstractmethod
    def replace(self, input_tensor, output_tensor) -> None:
        pass

    @abstractmethod
    def commit_update(self, tensor) -> None:
        pass

    @abstractmethod
    def sync(self, tensor) -> None:
        pass

    @abstractmethod
    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        pass


class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
    def __init__(
        self, mode: Optional[FunctionalTensorMode] = None, pre_dispatch: bool = False
    ) -> None:
        super().__init__()
        self.mode = mode if mode else FunctionalTensorMode()
        self.pre_dispatch = pre_dispatch

    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        with self.mode:
            return torch.utils._pytree.tree_map_only(
                torch.Tensor, FunctionalTensor.to_functional, args
            )

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]
    ) -> Any:
        return torch.utils._pytree.tree_map_only(
            FunctionalTensor, FunctionalTensor.from_functional, args
        )

    def functionalize(self, inner_f: Callable) -> Callable:
        return dispatch_functionalize(inner_f, self.mode)

    def redispatch_to_next(self) -> ContextManager:
        # [NOTE] We don't do anything here, because by the time we exercise
        # this path, we would have already popped the FunctionalTensorMode from
        # the mode stack. Since FunctionalTensorMode is now stateful, it is
        # better to explicitly pass in the correct mode directly instead of
        # globally setting it.
        return contextlib.nullcontext()

    def replace(self, input_tensor, output_tensor) -> None:
        assert isinstance(input_tensor, FunctionalTensor)
        assert not isinstance(output_tensor, FunctionalTensor)
        input_tensor.replace_(output_tensor)

    def commit_update(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.commit_update()

    def sync(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.sync()

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.mark_mutation_hidden_from_autograd()


class CppFunctionalizeAPI(BaseFunctionalizeAPI):
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

        return _wrap_all_tensors_to_functional(args, level=0)

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        from torch._functorch.eager_transforms import (
            _unwrap_all_tensors_from_functional,
        )

        return _unwrap_all_tensors_from_functional(args, reapply_views=_reapply_views())

    def functionalize(self, inner_f: Callable) -> Callable:
        return torch.func.functionalize(inner_f)

    def redispatch_to_next(self) -> ContextManager:
        return torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(tensor)


class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
    def __init__(self, interpreter):
        self.interpreter = interpreter

    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

        return _wrap_all_tensors_to_functional(args, level=self.interpreter.level())

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        from torch._functorch.eager_transforms import (
            _unwrap_all_tensors_from_functional,
        )

        return _unwrap_all_tensors_from_functional(
            args, reapply_views=self.interpreter.functionalize_add_back_views()
        )

    def functionalize(self, inner_f: Callable) -> Callable:
        return torch.func.functionalize(
            inner_f,
            remove=(
                "mutations_and_views"
                if self.interpreter.functionalize_add_back_views()
                else "mutations"
            ),
        )

    def redispatch_to_next(self) -> ContextManager:
        return self.interpreter.lower()

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(tensor)


def mb_unwrap_functional_tensor(tensor: torch.Tensor):
    if isinstance(tensor, FunctionalTensor):
        return torch._from_functional_tensor(tensor.elem)
    return tensor
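

# A hypothetical usage sketch for the helper above: plain tensors pass through
# untouched, while FunctionalTensors are unwrapped to their inner tensor.
#
#     inner = mb_unwrap_functional_tensor(some_tensor)  # some_tensor: any tensor
#     assert not isinstance(inner, FunctionalTensor)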