#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/ReductionType.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>

namespace at::native {
inline namespace CPU_CAPABILITY {

using namespace vec;

// Dispatch a runtime ReductionType to the compile-time constant `reduce`
// expected by the templated helpers below.
#define AT_DISPATCH_REDUCTION_TYPES(op, ...)                \
  [&] {                                                     \
    switch (op) {                                           \
      case ReductionType::SUM: {                            \
        static constexpr auto reduce = ReductionType::SUM;  \
        return __VA_ARGS__();                               \
      }                                                     \
      case ReductionType::MEAN: {                           \
        static constexpr auto reduce = ReductionType::MEAN; \
        return __VA_ARGS__();                               \
      }                                                     \
      case ReductionType::MIN: {                            \
        static constexpr auto reduce = ReductionType::MIN;  \
        return __VA_ARGS__();                               \
      }                                                     \
      case ReductionType::MAX: {                            \
        static constexpr auto reduce = ReductionType::MAX;  \
        return __VA_ARGS__();                               \
      }                                                     \
      case ReductionType::PROD: {                           \
        static constexpr auto reduce = ReductionType::PROD; \
        return __VA_ARGS__();                               \
      }                                                     \
    }                                                       \
  }()

// Identity element of the reduction (or +/-infinity for MIN/MAX).
template <typename scalar_t, ReductionType reduce>
inline vec_scalar_t<scalar_t> init_value() {
  using acc_t = vec_scalar_t<scalar_t>;
  acc_t val;
  if (reduce == ReductionType::SUM ||
      reduce == ReductionType::MEAN) {
    val = static_cast<acc_t>(0);
  } else if (reduce == ReductionType::PROD) {
    val = static_cast<acc_t>(1);
  } else if (reduce == ReductionType::MAX) {
    val = -std::numeric_limits<acc_t>::infinity();
  } else {
    TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
    val = std::numeric_limits<acc_t>::infinity();
  }
  return val;
}

template <typename scalar_t, ReductionType reduce>
inline vec_scalar_t<scalar_t> init_value(const std::optional<Scalar>& initial) {
  using acc_t = vec_scalar_t<scalar_t>;
  if (initial.has_value()) {
    return initial.value().to<acc_t>();
  } else {
    return init_value<scalar_t, reduce>();
  }
}

template <typename scalar_t>
inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
  using Vec = Vectorized<vec_scalar_t<scalar_t>>;
  map<scalar_t>(
      [val](Vec x) { return Vec(val); },
      out,
      out,
      size);
}

template <typename scalar_t, ReductionType reduce>
inline void init(scalar_t* out, int64_t size, const std::optional<Scalar>& initial) {
  using acc_t = vec_scalar_t<scalar_t>;
  acc_t val = init_value<scalar_t, reduce>(initial);
  init(out, size, val);
}

// overload with `include_self`, used by scatter_reduce
template <typename scalar_t, ReductionType reduce>
inline void init(scalar_t* out, int64_t size, bool include_self = false) {
  using acc_t = vec_scalar_t<scalar_t>;
  if (!include_self) {
    acc_t val = init_value<scalar_t, reduce>();
    init(out, size, val);
  }
}

// Initialize an opmath accumulation buffer: with the reduction identity when
// `self` is excluded, otherwise with the converted contents of `self`.
template <typename scalar_t, ReductionType reduce>
inline void _init(
    scalar_t* self_ptr,
    at::opmath_type<scalar_t>* buffer_ptr,
    int64_t size,
    bool include_self) {
  if (!include_self) {
    init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
  } else {
    vec::convert(self_ptr, buffer_ptr, size);
  }
}
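// Illustrative values only (not part of the upstream header): for float
// inputs, init_value<scalar_t, reduce>() yields the identity element of the
// chosen reduction, so folding it into the accumulator with update() leaves
// the result unchanged:
//
//   init_value<float, ReductionType::SUM>()   // 0.0f
//   init_value<float, ReductionType::PROD>()  // 1.0f
//   init_value<float, ReductionType::MAX>()   // -infinity
//   init_value<float, ReductionType::MIN>()   // +infinity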
template <typename scalar_t>
inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
_max(const scalar_t& x, const scalar_t& y) {
  // std::max alone would drop a NaN in `y`; propagate it explicitly
  return at::_isnan(y) ? y : std::max(x, y);
}

template <typename scalar_t>
inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  // vec::maximum propagates NaN
  return vec::maximum(x, y);
}

template <typename vec_t>
inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
_max(const vec_t& x, const vec_t& y) {
  // vec::maximum propagates NaN
  return maximum(x, y);
}

template <typename scalar_t>
inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
_min(const scalar_t& x, const scalar_t& y) {
  // std::min alone would drop a NaN in `y`; propagate it explicitly
  return at::_isnan(y) ? y : std::min(x, y);
}

template <typename scalar_t>
inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  // vec::minimum propagates NaN
  return vec::minimum(x, y);
}

template <typename vec_t>
inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
_min(const vec_t& x, const vec_t& y) {
  // vec::minimum propagates NaN
  return minimum(x, y);
}

// Fused map for reduced floating point types (BFloat16/Half): `input_data`
// and `output_data` hold the accumulator type `accumut` (float), while
// `input_data2` holds the reduced type; each scalar_t vector is converted to
// two float vectors on the fly before applying `vec_fun`.
template <typename scalar_t, typename accumut, typename Op,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map_acc(
    const Op& vec_fun,
    accumut* output_data,
    const accumut* input_data,
    const scalar_t* input_data2,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  using aVec = vec::Vectorized<accumut>;
  int64_t d = 0;
  constexpr int64_t kVecSize = Vec::size();
  constexpr int64_t kaVecSize = aVec::size();
  for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
    Vec data2_vec = Vec::loadu(input_data2 + d);
    auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
    aVec input_vec0 = aVec::loadu(input_data + d);
    aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
    vec_fun(input_vec0, data2_avec0).store(output_data + d);
    vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
  }
  if (size - d > 0) {
    int64_t tail_size = size - d;
    Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
    auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
    if (tail_size > kaVecSize) {
      aVec input_vec0 = aVec::loadu(input_data + d);
      aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
      vec_fun(input_vec0, data2_avec0).store(output_data + d);
      vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
    } else {
      aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
      vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
    }
  }
}

// for Max and Min, propagate NaN:
template <typename T, ReductionType reduce>
inline T update(const T& x, const T& y) {
  if (reduce == ReductionType::SUM ||
      reduce == ReductionType::MEAN) {
    return x + y;
  } else if (reduce == ReductionType::PROD) {
    return x * y;
  } else if (reduce == ReductionType::MAX) {
    return _max(x, y);
  } else {
    TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
    return _min(x, y);
  }
}

template <typename scalar_t, ReductionType reduce>
inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
  using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  map2<scalar_t>(
      [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
      out,
      out,
      data,
      K);
}

template <typename scalar_t, ReductionType reduce,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
  using opmath_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<opmath_t>;
  map_acc<scalar_t, opmath_t>(
      [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
      out,
      out,
      data,
      K);
}

// MEAN is finalized by dividing the accumulated sum by `count`;
// all other reductions are already in their final form.
template <typename scalar_t, ReductionType reduce>
inline void write(scalar_t* out, int64_t count, int64_t K) {
  using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  if (reduce == ReductionType::MEAN) {
    if (count > 0) {
      vec::map<scalar_t>(
          [count](Vec x) { return x / Vec(count); },
          out,
          out,
          K);
    }
  }
}

} // namespace CPU_CAPABILITY
} // namespace at::native
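// Usage sketch (hypothetical, for illustration only; `reduce_rows`, `out_ptr`,
// `in_ptr`, `M`, and `K` are not part of this header). A row-wise reduction
// kernel would typically seed the output with init(), fold each input row in
// with update(), and finalize a MEAN with write(); AT_DISPATCH_REDUCTION_TYPES
// turns the runtime ReductionType into the compile-time `reduce` constant the
// helpers expect:
//
//   template <typename scalar_t, ReductionType reduce>
//   void reduce_rows(scalar_t* out, const scalar_t* in, int64_t M, int64_t K) {
//     init<scalar_t, reduce>(out, K, /*include_self=*/false);
//     for (int64_t m = 0; m < M; ++m) {
//       update<scalar_t, reduce>(out, in + m * K, K);
//     }
//     write<scalar_t, reduce>(out, /*count=*/M, K);
//   }
//
//   AT_DISPATCH_REDUCTION_TYPES(op, [&]() {
//     reduce_rows<float, reduce>(out_ptr, in_ptr, M, K);
//   });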