# mypy: allow-untyped-defs
"""
This is a simple interpreter for Sympy expressions that dispatches to
classes following the torch._inductor.virtualized calling convention.
For directness, the interpreter takes the handler directly rather than
consulting the TLS.  It does not use most of the methods on the full
handler; only those with corresponding Sympy expressions.  To see an example
of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""

import functools
import logging
from typing import Any, Dict, Union

import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom

import torch

from .functions import (
    CeilToInt,
    CleanDiv,
    FloatPow,
    FloatTrueDiv,
    FloorDiv,
    FloorToInt,
    Identity,
    IntTrueDiv,
    IsNonOverlappingAndDenseIndicator,
    Max,
    Min,
    Mod,
    ModularIndexing,
    PowByNatural,
    PythonMod,
    RoundDecimal,
    RoundToInt,
    ToFloat,
    TruncToFloat,
    TruncToInt,
    Where,
)


log = logging.getLogger(__name__)


# TODO: Dedupe this with SYMPY_INTERP


@functools.lru_cache(None)
def handlers():
    """Return the mapping from Sympy expression classes to handler method names.

    The result is computed once and cached, so every call returns the same
    dict object.  Each value is the name of the method looked up on the
    ``analysis`` object by ``_run_sympy_handler``.
    """
    # TODO add CeilDiv (it doesn't appear in the index_expr)

    # TODO default to some decompositions if the interpreter doesn't have them
    # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)

    HANDLERS = {
        sympy.Or: "or_",
        sympy.And: "and_",
        sympy.Eq: "eq",
        sympy.Ne: "ne",
        sympy.Lt: "lt",
        sympy.Gt: "gt",
        sympy.Le: "le",
        sympy.Ge: "ge",
        sympy.Not: "not_",
        IntTrueDiv: "int_truediv",
        FloatTrueDiv: "truediv",
        FloorDiv: "floordiv",
        CleanDiv: "floordiv",  # TODO: hmm?
        TruncToFloat: "trunc",
        Where: "where",
        sympy.Add: "add",
        sympy.Mul: "mul",
        FloatPow: "pow",
        PowByNatural: "pow_by_natural",
        # sympy simplifies x * x into Pow(x, 2), so we need to handle this.
        # Do NOT use builtin Pow for floats
        # TODO: There is a hazard here, if we have float * float it will
        # also get turned into Pow(float, 2) but we don't want this because
        # pow_by_natural is assumed to only be integers.  Probably the fix is
        # to add a FloatMul to impede this optimization
        sympy.Pow: "pow_by_natural",
        Mod: "mod",
        PythonMod: "mod",  # TODO: this is wrong
        # TODO: Inductor can generate these, but it's ill-specified which
        # semantics were intended here.  Needs to be cleaned up along with
        # FloorDiv in a bigger cleanup
        sympy.Mod: "mod",
        sympy.Abs: "abs",
        sympy.log: "log",
        sympy.exp: "exp",
        sympy.Min: "minimum",
        sympy.Max: "maximum",
        Min: "minimum",
        Max: "maximum",
        ModularIndexing: "modular_indexing",
        sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
        sympy.Piecewise: "piecewise",
        Identity: "identity",
        IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
        RoundDecimal: "round_decimal",
    }
    # Trig / hyperbolic functions whose handler name equals the sympy name.
    for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
        HANDLERS[getattr(sympy, name)] = name
    return HANDLERS


# Handler names that are applied as a left fold over more than two arguments,
# e.g. Add(a, b, c) -> add(add(a, b), c).
ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}


# These handlers are special because they take an extra dtype argument
# specifying what they should convert to, and we need to appropriately set
# this up when we convert from Sympy.  A reasonable default when you
# are translating is to conservatively do int64, and then narrow these
# arguments later when you discover you can narrow the index range.  But
# if you already know that 32-bit indexing is OK, you can directly do the
# sympy translation with index_dtype=torch.int32
# Built once at import time rather than on every node visit.
_INDEX_DTYPE_HANDLERS = {
    TruncToInt: "trunc_to_int",
    sympy.floor: "floor_to_int",
    sympy.ceiling: "ceil_to_int",
    FloorToInt: "floor_to_int",
    CeilToInt: "ceil_to_int",
    RoundToInt: "round_to_int",
}


def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
    """Apply the ``analysis`` method that corresponds to ``expr.func``.

    Args:
        analysis: handler object; the method invoked is looked up on it by
            name.
        args: already-interpreted arguments of ``expr``, in ``expr.args``
            order.
        expr: the Sympy expression whose top-level operator is dispatched.
        index_dtype: dtype appended as the last argument to the
            integer-conversion handlers (``trunc_to_int``, ``floor_to_int``,
            ``ceil_to_int``, ``round_to_int``).

    Returns:
        Whatever the selected handler returns.

    Raises:
        KeyError: if ``expr.func`` has no ``_torch_handler_name`` and is not
            in ``handlers()``.
        AttributeError: if ``analysis`` lacks the selected method.
        Any exception raised by the handler itself is logged and re-raised.
    """
    # Special cases
    # x ** (1/2) is routed to sqrt instead of a pow handler.
    if isinstance(expr, sympy.Pow) and isinstance(
        expr.args[1], sympy.core.numbers.Half
    ):
        return analysis.sqrt(args[0])
    if isinstance(expr, ToFloat):
        return analysis.to_dtype(args[0], torch.float64)

    if (handler_name := _INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
        return getattr(analysis, handler_name)(*args, index_dtype)

    # A function class may declare its own handler name; otherwise fall
    # back to the static table.
    if hasattr(expr.func, "_torch_handler_name"):
        handler_name = expr.func._torch_handler_name
    else:
        handler_name = handlers()[expr.func]
    handler = getattr(analysis, handler_name)
    try:
        if handler_name in ASSOCIATIVE_OPS:
            assert len(args) > 1
            # Handlers are binary, so fold n-ary Sympy nodes left to right.
            acc = handler(args[0], args[1])
            for arg in args[2:]:
                acc = handler(acc, arg)
            log.debug("%s(%s) -> %s", handler_name, args, acc)
            return acc
        else:
            r = handler(*args)
            log.debug("%s(%s) -> %s", handler_name, args, r)
            return r
    except Exception:
        log.warning("failed while executing %s(%s)", handler_name, args)
        raise


def sympy_interp(
    analysis,
    env: Dict[sympy.Symbol, Any],
    expr: Union[sympy.Expr, SympyBoolean],
    *,
    index_dtype=torch.int64,
):
    """Interpret ``expr`` bottom-up by dispatching each node to ``analysis``.

    Leaves are handled directly:

    * Boolean atoms become ``analysis.constant(expr, torch.bool)``.
    * Integers become ``analysis.constant(expr, torch.int64)``.
    * Other numbers become ``analysis.constant(expr, torch.double)``.
    * Symbols are looked up in ``env``.

    Every other node has its arguments interpreted recursively and is then
    passed to ``_run_sympy_handler``.

    Args:
        analysis: handler object whose methods are invoked by name.
        env: value to use for each Symbol that occurs in ``expr``; a missing
            Symbol raises ``KeyError``.
        expr: Sympy expression or boolean to interpret.
        index_dtype: dtype forwarded to the integer-conversion handlers.  It
            is applied at every depth of ``expr``, not only at the root.

    Returns:
        The value produced by ``analysis`` for ``expr``.
    """
    # Handle base cases
    dtype = None
    if isinstance(expr, BooleanAtom):
        dtype = torch.bool
    elif isinstance(expr, sympy.Integer):
        dtype = torch.int64
    elif isinstance(expr, sympy.Number):
        dtype = torch.double

    if dtype is not None:
        return analysis.constant(expr, dtype)
    elif isinstance(expr, sympy.Symbol):
        return env[expr]

    # Recursive case.  index_dtype must be forwarded to the children as well;
    # otherwise nested FloorToInt/CeilToInt/... would silently fall back to
    # the int64 default.
    return _run_sympy_handler(
        analysis,
        [
            sympy_interp(analysis, env, arg, index_dtype=index_dtype)
            for arg in expr.args
        ],  # type: ignore[arg-type]
        expr,
        index_dtype=index_dtype,
    )  # type: ignore[arg-type]