from typing import Any, Callable, Dict, Optional from torch.utils._pytree import Context, TreeSpec def reorder_kwargs(user_kwargs: Dict[str, Any], spec: TreeSpec) -> Dict[str, Any]: """Reorder user-provided kwargs to match the order in `spec`. `spec` is expected to be the in_spec of an exported program, i.e. the spec that results from flattening `(args, kwargs)`. We need this to provide consistent input ordering, such so that users can pass in foo(a=a, b=b) OR foo(b=b, a=a) and receive the same result. """ # Make sure that the spec is actually shaped like (args, kwargs) assert spec.type is tuple assert spec.num_children == 2 kwargs_spec = spec.children_specs[1] assert kwargs_spec.type is dict if set(user_kwargs) != set(kwargs_spec.context): raise ValueError( f"kwarg key mismatch: " f"Got {list(user_kwargs)} but expected {kwargs_spec.context}" ) reordered_kwargs = {} for kw in kwargs_spec.context: reordered_kwargs[kw] = user_kwargs[kw] return reordered_kwargs def is_equivalent( spec1: TreeSpec, spec2: TreeSpec, equivalence_fn: Callable[[Optional[type], Context, Optional[type], Context], bool], ) -> bool: """Customizable equivalence check for two TreeSpecs. Arguments: spec1: The first TreeSpec to compare spec2: The second TreeSpec to compare equivalence_fn: A function to determine the equivalence of two TreeSpecs by examining their types and contexts. It will be called like: equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context) This function will be applied recursively to all children. Returns: True if the two TreeSpecs are equivalent, False otherwise. """ if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context): return False # Recurse on children if len(spec1.children_specs) != len(spec2.children_specs): return False for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs): if not is_equivalent(child_spec1, child_spec2, equivalence_fn): return False return True