#pragma once #include #include #include namespace at::cuda::detail { TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t); using at::native::canUse32BitIndexMath; template TensorInfo getTensorInfo(const at::TensorBase &t) { IndexType sz[MAX_TENSORINFO_DIMS]; IndexType st[MAX_TENSORINFO_DIMS]; int dims = t.dim(); for (int i = 0; i < dims; ++i) { sz[i] = t.size(i); st[i] = t.stride(i); } scalar* data_ptr = nullptr; if constexpr (std::is_const::value) { data_ptr = t.const_data_ptr(); } else { data_ptr = t.mutable_data_ptr(); } return TensorInfo( data_ptr, dims, sz, st); } } // namespace at::cuda::detail