#pragma once

#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/DispatchStub.h>
#include <c10/core/Scalar.h>

namespace at::native {

template <typename scalar_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
  return std::abs(weight) < scalar_t(0.5);
}

template <typename scalar_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(
    c10::complex<scalar_t> weight) {
  // Avoid the sqrt in abs(weight): compare |weight|^2 against 0.5^2 = 0.25.
  return (weight.real() * weight.real() + weight.imag() * weight.imag()) <
      scalar_t(0.25);
}

template <typename scalar_t, typename weight_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(
    scalar_t self_,
    scalar_t end_,
    weight_t weight_) {
  using opmath_t = at::opmath_type<scalar_t>;
  using opmath_weight_t = at::opmath_type<weight_t>;

  opmath_t self = self_;
  opmath_t end = end_;
  opmath_weight_t weight = weight_;

  // Conditional for better numerics. This has been discussed in
  // https://github.com/pytorch/pytorch/pull/18871
  return is_lerp_weight_small(weight)
      ? self + weight * (end - self)
      : end - (end - self) * (opmath_t(1) - weight);
}

using lerp_fn_scalar = void (*)(
    at::TensorIteratorBase& iter,
    const Scalar& weight);

using lerp_fn_tensor = void (*)(
    at::TensorIteratorBase& iter);

DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight);
DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight);

} // namespace at::native
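
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the original header): why `lerp` switches
// formulas based on the weight's magnitude. The single formula
// `self + weight * (end - self)` is exact at weight == 0 but can miss `end`
// at weight == 1, because `end - self` may round. The second formula,
// `end - (end - self) * (1 - weight)`, is exact at weight == 1. Branching on
// |weight| < 0.5 therefore makes both endpoints exact. A minimal check,
// assuming this header is on the include path:
//
//   #include <ATen/native/Lerp.h>
//   #include <cassert>
//
//   void lerp_endpoint_check() {
//     double a = 1e16; // large enough that (1.0 - a) rounds to -1e16
//     double b = 1.0;
//     // weight == 0 takes the small-weight branch: a + 0 * (b - a) == a.
//     assert(at::native::lerp(a, b, 0.0) == a);
//     // weight == 1 takes the other branch: b - (b - a) * 0 == b. The
//     // single formula a + 1.0 * (b - a) would instead yield 0.0 here.
//     assert(at::native::lerp(a, b, 1.0) == b);
//   }
// ---------------------------------------------------------------------------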
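
// ---------------------------------------------------------------------------
// How the two DispatchStub declarations above are typically consumed: one
// translation unit defines the stub and each backend registers a kernel for
// it. A rough sketch of the usual ATen pattern (the file locations and the
// kernel name `lerp_scalar_kernel` are assumptions, not copied from the real
// sources):
//
//   // In a .cpp alongside this header:
//   DEFINE_DISPATCH(lerp_kernel_scalar_weight);
//
//   // In a CPU kernel file:
//   static void lerp_scalar_kernel(
//       at::TensorIteratorBase& iter, const Scalar& weight) {
//     // e.g. iterate over `iter` and apply at::native::lerp element-wise
//   }
//   REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_scalar_kernel);
//
//   // Call sites then invoke the stub with a device type:
//   //   lerp_kernel_scalar_weight(iter.device_type(), iter, weight);
// ---------------------------------------------------------------------------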