#pragma once

// Compile-time feature detection for the CUB library.
//
// Every macro below is function-like and takes no arguments, so it can be
// used directly in preprocessor conditionals, e.g.
//
//   #if CUB_SUPPORTS_SCAN_BY_KEY()
//   ...
//   #endif
//
// The version-gated macros are derived from CUB_VERSION. On ROCm builds
// (USE_ROCM defined) CUB_VERSION is forced to 0, so each of those macros
// reports the feature as unavailable.

#if !defined(USE_ROCM)
#include <cuda.h>           // for CUDA_VERSION
#include <cub/version.cuh>  // for CUB_VERSION
#else
#define CUB_VERSION 0
#endif

// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
// https://github.com/NVIDIA/cub/pull/306
#if CUB_VERSION >= 101300
#define CUB_SUPPORTS_NV_BFLOAT16() true
#else
#define CUB_SUPPORTS_NV_BFLOAT16() false
#endif

// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
// https://github.com/NVIDIA/cub/pull/326
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
// starting from CUDA 11.5.
// This check does not look at CUB_VERSION: it only tests whether either
// wrapped-namespace macro has been defined by the build.
#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE)
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true
#else
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif

// cub support for UniqueByKey is added to cub 1.16 in:
// https://github.com/NVIDIA/cub/pull/405
#if CUB_VERSION >= 101600
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
#else
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
#endif

// cub support for scan by key is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/376
// NOTE: unlike the sibling macros, this one expands to 1/0 rather than
// true/false. The expansion is kept as-is so existing users see the same
// tokens.
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif

// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_FUTURE_VALUE() true
#else
#define CUB_SUPPORTS_FUTURE_VALUE() false
#endif