#pragma once

#include <cstdint>
#include <type_traits>

namespace at::native {
inline namespace CPU_CAPABILITY {

// Compile-time recursion that checks a strides array against the element
// sizes described by a function-traits type. Strides are compared directly
// against sizeof(...), i.e. they are expected to be expressed in bytes.
//
// n:            number of function arguments still to be checked (arity);
//               argument `n - 1` of `traits` is examined at this level
// stride_index: position in `strides` that belongs to argument `n - 1`
// traits:       function_traits (see FunctionTraits.h); must provide
//               `arg<I>::type`, `result_type` and `arity`
// s:            1-based index of the scalar argument, or -1 if there is none
template <int n, int stride_index, typename traits, int s = -1>
struct IsContiguous {
  static bool eval(const int64_t* strides) {
    using type = typename traits::template arg<n - 1>::type;
    // The scalar argument (s == n) must have stride 0; every other argument
    // must advance by exactly one element.
    constexpr int64_t expected_stride =
        (s == n) ? int64_t{0} : static_cast<int64_t>(sizeof(type));
    return strides[stride_index] == expected_stride &&
        IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
  }
};

// Recursion end when an output exists: all arguments have been consumed and
// strides[0] is the output, which must advance by one result element.
template <typename traits, int s>
struct IsContiguous<0, 0, traits, s> {
  static bool eval(const int64_t* strides) {
    return strides[0] ==
        static_cast<int64_t>(sizeof(typename traits::result_type));
  }
};

// Recursion end when there is no output: nothing is left to check.
template <typename traits, int s>
struct IsContiguous<0, -1, traits, s> {
  static bool eval(const int64_t* /*strides*/) {
    return true;
  }
};

// Returns true when the output (if any) and all inputs are contiguous.
//
// Overload for functions returning void: there is no output slot, so the
// inputs occupy strides[0 .. arity - 1].
template <
    typename traits,
    typename std::enable_if<
        std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
  return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
}

// Overload for functions returning a value: the output is at strides[0] and
// the inputs occupy strides[1 .. arity].
template <
    typename traits,
    typename std::enable_if<
        !std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
  return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
}

// Returns true when the input at `s` is a scalar (stride 0) and the output
// and all other inputs are contiguous.
// NB: output is typically at strides[0] so first input corresponds to s=1
//
// Overload for functions returning void (no output slot in `strides`).
template <
    typename traits,
    int s,
    typename std::enable_if<
        std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
  return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
}

// Overload for functions returning a value (output at strides[0]).
template <
    typename traits,
    int s,
    typename std::enable_if<
        !std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
  return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
}

} // inline namespace CPU_CAPABILITY
} // namespace at::native