#pragma once

// Device-side square-root helper: at::native::device_sqrt(val).
//
// Headers are included at file scope, NOT inside namespace at::native.
// Including a C library header inside a namespace would place its
// declarations in that namespace (when not already included elsewhere),
// breaking the global-namespace lookups ::sqrtf / ::sqrt used below.
#if defined(USE_ROCM)
// take these out when ROCm implements std:: math functions
#include <math.h>
#else
#include <cmath>
#endif

namespace at { namespace native {

#if defined(USE_ROCM)
// ROCm path: only the primary template is declared here; definitions exist
// solely for the float and double specializations below, so other scalar_t
// types have no definition in this header.
template <typename scalar_t>
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);

// Single-precision square root via the global C math function ::sqrtf.
template <>
__forceinline__ __device__ float device_sqrt(float val) {
  return ::sqrtf(val);
}

// Double-precision square root via the global C math function ::sqrt.
template <>
__forceinline__ __device__ double device_sqrt(double val) {
  return ::sqrt(val);
}
#else
// CUDA path: forwards to std::sqrt and returns the result as double for any
// scalar_t (a float result from std::sqrt(float) is widened to double).
// NOTE(review): this return type differs from the ROCm path above, which
// returns scalar_t; kept as-is because callers may depend on it — confirm
// before unifying.
template <typename scalar_t>
__forceinline__ __device__ double device_sqrt(scalar_t val) {
  return std::sqrt(val);
}
#endif

}} // namespace at::native