# mypy: allow-untyped-defs import logging import warnings from typing import Any, Dict, Iterable, Optional, Tuple import torch import torch.export import torch.export._trace from torch._utils_internal import log_export_usage log = logging.getLogger(__name__) __all__ = ["report_exportability"] def _generate_inputs_for_submodules( model: torch.nn.Module, target_submodules: Iterable[str], args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Tuple[Any, Any]]: """ Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this function doesn't work. Args: model: root model. inputs: inputs to the root model. target_submodules: submodules that we want to generate inputs for. Returns: A dict that maps from submodule name to its inputs. """ kwargs = kwargs or {} handles = [] results = {} submodule_to_names = {mod: name for name, mod in model.named_modules()} def pre_forward(module, module_args, module_kwargs): results[submodule_to_names[module]] = (module_args, module_kwargs) try: for name, mod in model.named_modules(): if name in target_submodules: handles.append( mod.register_forward_pre_hook(pre_forward, with_kwargs=True) ) model(*args, **kwargs) except Exception as e: warnings.warn( f"Failed to generate submodule inputs because of the following error:\n{e}" ) finally: for h in handles: h.remove() return results def report_exportability( mod: torch.nn.Module, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None, *, strict: bool = True, pre_dispatch: bool = False, ) -> Dict[str, Optional[Exception]]: """ Report exportability issues for a module in one-shot. Args: mod: root module. args: args to the root module. kwargs: kwargs to the root module. Returns: A dict that maps from submodule name to the exception that was raised when trying to export it. `None` means the module is exportable without issue. Sample output: { '': UnsupportedOperatorException(func=), 'submod_1': UnsupportedOperatorException(func=), 'submod_2': None } """ log_export_usage(event="export.report_exportability") kwargs = kwargs or {} all_submod_names = [name for name, _ in mod.named_modules() if name != ""] submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs) tried_module_types = set() report: Dict[str, Optional[Exception]] = {} def try_export(module, module_name, args, kwargs): nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types if type(module) in tried_module_types: return tried_module_types.add(type(module)) if args is not None or kwargs is not None: try: torch.export._trace._export( module, args, kwargs, strict=strict, pre_dispatch=pre_dispatch, ) report[module_name] = None log.info("Successfully exported `%s`", module_name) return except Exception as e: short_msg = repr(e).split("\n")[0] log.warning( "Failed exporting `%s` with exception: %s", module_name, short_msg ) report[module_name] = e for name, submod in module.named_children(): sub_module_name = name if module_name == "" else f"{module_name}.{name}" submod_args, submod_kwargs = submod_inputs.get( sub_module_name, (None, None) ) try_export(submod, sub_module_name, submod_args, submod_kwargs) return try_export(mod, "", args, kwargs) unique_issues = set() for exception in report.values(): if exception is not None: key = repr(exception).split("\\n")[0] unique_issues.add(key) log.warning("Found %d export issues:", len(unique_issues)) for issue in unique_issues: log.warning(issue) return report