# mypy: allow-untyped-defs import uuid from collections import OrderedDict from functools import wraps from typing import Callable, Dict, List, Optional, Sequence, Type, Union import torch import torch.nn as nn from torch.distributed._composable_state import _State from torch.distributed.utils import _get_root_modules def generate_state_key(string="__composable_api_state_key"): return f"{string}_{str(uuid.uuid4())}" STATE_KEY = generate_state_key() REGISTRY_KEY = generate_state_key() # TODO: we can add additional info to RegistryItem to share across APIs. E.g., # we can add args and kwargs here, and then we can detect whether fully_shard # is combined with reentrant activation checkpointing and error out with a clear # message. class RegistryItem: pass def contract(state_cls: Type[_State] = _State): r""" Decorate a function as a composable distributed API, where the first argument of the function must be an :class:`nn.Module` instance or sequence of :class:`nn.Module` instances. The decorator verifies that the decorated function does not modify fully-qualified names (FQNs) for parameters, buffers, or modules. The decorated function can return different module instances than the input modules; the FQN invariant will be enforced following the input order. When a function ``func`` is decorated by ``@contract()``, a ``.state(module: nn.Module)`` method will be installed to the decorated function. Then you can retrieve and modify the state on a module by calling ``func.state(module)``. Example:: >>> # xdoctest: +SKIP >>> import torch.nn as nn >>> >>> class MyModel(nn.Module): >>> def __init__(self) -> None: >>> super().__init__() >>> self.l1 = nn.Linear(10, 10) >>> self.l2 = nn.Linear(10, 10) >>> >>> def forward(self, x): >>> return self.l2(self.l1(x)) >>> >>> @contract() >>> def my_feature(module: nn.Module) -> nn.Module: >>> my_feature.state(module).some_state = "any value" >>> return module >>> >>> model = MyModel() >>> my_feature(model.l1) >>> assert my_feature.state(model.l1).some_state == "any value" >>> my_feature(model.l2) >>> model(torch.randn(2, 10)).sum().backward() """ # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package @wraps(state_cls) def inner(func): @wraps(func) def wrapper( module: Union[nn.Module, Sequence[nn.Module]], *args, **kwargs ) -> Optional[nn.Module]: inp_module = module if isinstance(module, nn.Module): modules = [module] else: # If the user passes a sequence of modules, then we assume that # we only need to insert the state object on the root modules # (i.e. those without a parent) among the passed-in modules. modules = _get_root_modules(list(module)) state = state_cls() # shared across all modules registry_item = RegistryItem() # shared across all modules # `func` is allowed to return different module instances than the # input modules as long as FQNs are preserved following the input # module order all_orig_named_params: List[Dict[str, nn.Parameter]] = [] all_orig_named_buffers: List[Dict[str, torch.Tensor]] = [] all_orig_named_modules: List[Dict[str, nn.Module]] = [] for module in modules: default_all_state: Dict[Callable, _State] = OrderedDict() default_registry: Dict[str, RegistryItem] = OrderedDict() all_state: Dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload] STATE_KEY, default_all_state ) if not isinstance(all_state, dict): raise AssertionError( f"Distributed composable API states corrupted: {all_state}" ) registry: Dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload] REGISTRY_KEY, default_registry ) if not isinstance(registry, dict): raise AssertionError( f"Distributed composable API registry corrupted: {registry}" ) if func in all_state or func.__name__ in registry: raise AssertionError( "Each distinct composable distributed API can only be applied to a " f"module once. {func.__name__} has already been applied to the " f"following module:\n{module}" ) all_state.setdefault(func, state) registry.setdefault(func.__name__, registry_item) all_orig_named_params.append(OrderedDict(module.named_parameters())) all_orig_named_buffers.append(OrderedDict(module.named_buffers())) all_orig_named_modules.append(OrderedDict(module.named_modules())) updated = func(inp_module, *args, **kwargs) if updated is None: updated = inp_module if isinstance(updated, nn.Module): updated_modules = [updated] else: updated_modules = _get_root_modules(list(inp_module)) all_new_named_params: List[Dict[str, nn.Parameter]] = [] all_new_named_buffers: List[Dict[str, torch.Tensor]] = [] all_new_named_modules: List[Dict[str, nn.Module]] = [] for module in updated_modules: all_new_named_params.append(OrderedDict(module.named_parameters())) all_new_named_buffers.append(OrderedDict(module.named_buffers())) all_new_named_modules.append(OrderedDict(module.named_modules())) num_orig_modules = len(all_orig_named_modules) num_new_modules = len(all_new_named_modules) if num_orig_modules != num_new_modules: raise AssertionError( f"{func.__name__} should return the same number of modules as input modules" f"Inputs: {num_orig_modules} modules\n" f"Outputs: {num_new_modules} modules" ) def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str): if orig_fqns == new_fqns: return orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns) orig_only = orig_fqn_set - new_fqn_set new_only = new_fqn_set - orig_fqn_set if len(orig_only) or len(new_only): raise RuntimeError( f"{check_key}" "Composable distributed API implementations cannot modify FQNs.\n" f"FQNs only in original: {orig_only}\n" f"FQNs only in new: {new_only}" ) else: raise RuntimeError( f"{check_key}" "Composable distributed API implementations cannot modify " "the order of FQNs.\n" f"Original FQNs: {orig_only}\n" f"New FQNs: {new_only}" ) for orig_named_params, new_named_params in zip( all_orig_named_params, all_new_named_params ): check_fqn( list(orig_named_params.keys()), list(new_named_params.keys()), "Checking parameters: ", ) for orig_named_buffers, new_named_buffers in zip( all_orig_named_buffers, all_new_named_buffers ): check_fqn( list(orig_named_buffers.keys()), list(new_named_buffers.keys()), "Checking buffers: ", ) for orig_named_modules, new_named_modules in zip( all_orig_named_modules, all_new_named_modules ): check_fqn( list(orig_named_modules.keys()), list(new_named_modules.keys()), "Checking modules: ", ) # TODO: verify that installed distributed paradigms are compatible with # each other. return updated def get_state(module: nn.Module) -> Optional[_State]: return module.__dict__.setdefault( # type: ignore[call-overload] STATE_KEY, {}, # TODO(@yhcharles): this is a temporary fix, need a better way ).get( func ) # type: ignore[call-overload] wrapper.state = get_state # type: ignore[attr-defined] return wrapper return inner def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]: r""" Get an ``OrderedDict`` of composable APIs that have been applied to the ``module``, indexed by the API name. If no API has been applied, then this returns ``None``. """ return getattr(module, REGISTRY_KEY, None)