// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/GemmCommon.h>
#include <c10/util/StringUtil.h>

#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>

#define TORCH_ROCBLAS_CHECK(EXPR)                 \
  do {                                            \
    rocblas_status __err = EXPR;                  \
    TORCH_CHECK(__err == rocblas_status_success,  \
                "rocblas error: ",                \
                rocblas_status_to_string(__err),  \
                " when calling `" #EXPR "`");     \
  } while (0)

namespace at::cuda::tunable {

template <typename T>
constexpr rocblas_datatype RocBlasDataTypeFor();

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
  return rocblas_datatype_f16_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
  return rocblas_datatype_bf16_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}

template <typename T>
constexpr rocblas_datatype RocBlasComputeTypeFor();

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
  // Note that we're returning the _compute_ type for a given datatype.
  // As of 12/2022, using compute type FP16 for 16-bit floats was much
  // slower than using compute type FP32. So we use FP32 compute even for
  // FP16 datatypes. This is how GEMM is implemented even in the function
  // rocblasGemmHelper (see fpgeneric.h)
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
  // Note that we're returning the _compute_ type for a given datatype.
  // As of 12/2022, using compute type FP16 for 16-bit floats was much
  // slower than using compute type FP32. So we use FP32 compute even for
  // BF16 datatypes. This is how GEMM is implemented even in the function
  // rocblasGemmHelper (see fpgeneric.h)
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}

template <typename T>
auto DoCastForHalfOrBfloat16(const T fp) {
  return fp;
}

template <>
inline auto DoCastForHalfOrBfloat16(const Half fp) {
  // alpha and beta should be the same as compute_type, in Half case it is float.
  float h = fp;
  return h;
}

template <>
inline auto DoCastForHalfOrBfloat16(const BFloat16 fp) {
  // alpha and beta should be the same as compute_type, in bfloat16 case it is float.
  float h = fp;
  return h;
}
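
// A minimal compile-time sanity-check sketch of the mappings above: 16-bit floating
// point inputs keep a 16-bit storage type but are paired with an FP32 compute type,
// and the casts above widen alpha/beta to float to match that compute type.
static_assert(RocBlasDataTypeFor<Half>() == rocblas_datatype_f16_r,
              "FP16 GEMM inputs/outputs stay FP16");
static_assert(RocBlasDataTypeFor<BFloat16>() == rocblas_datatype_bf16_r,
              "BF16 GEMM inputs/outputs stay BF16");
static_assert(RocBlasComputeTypeFor<Half>() == rocblas_datatype_f32_r &&
              RocBlasComputeTypeFor<BFloat16>() == rocblas_datatype_f32_r,
              "16-bit float GEMMs are expected to use FP32 compute");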
static rocblas_operation _rocblasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return rocblas_operation_none;
    case 't':
    case 'T':
      return rocblas_operation_transpose;
    case 'c':
    case 'C':
      return rocblas_operation_conjugate_transpose;
  }
  AT_ERROR(
      "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

template <typename T>
class RocblasGemmOp : public Callable<GemmParams<T>> {
 public:
  RocblasGemmOp(int solution) : solution_{solution} {}

  TuningStatus Call(const GemmParams<T>* params) override {
    auto input_output_type = RocBlasDataTypeFor<T>();
    auto compute_type = RocBlasComputeTypeFor<T>();
    auto h_a = DoCastForHalfOrBfloat16(params->alpha);
    auto h_b = DoCastForHalfOrBfloat16(params->beta);
    auto status = rocblas_gemm_ex(
        (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
        _rocblasOpFromChar(params->transa),
        _rocblasOpFromChar(params->transb),
        params->m, params->n, params->k,
        &h_a,
        params->a, input_output_type, params->lda,
        params->b, input_output_type, params->ldb,
        &h_b,
        params->c, input_output_type, params->ldc,
        params->c, input_output_type, params->ldc,
        compute_type,
        rocblas_gemm_algo_solution_index,
        solution_,
        rocblas_gemm_flags_none);
    if (status != rocblas_status_success) {
      return FAIL;
    }
    return OK;
  }

 private:
  int solution_;
};

template <typename T>
auto GetRocBlasGemmTypeStringAndOps() {
  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
  int solution_size;
  auto input_output_type = RocBlasDataTypeFor<T>();
  auto compute_type = RocBlasComputeTypeFor<T>();
  // Get the number of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            nullptr,
                                                            &solution_size));
  std::vector<int> solutions(solution_size);
  // Get the list of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            solutions.data(),
                                                            &solution_size));
  // Sort the solutions in ascending order to make the solution vector deterministic across runs
  std::sort(solutions.begin(), solutions.end());

  std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
  for (size_t i = 0; i < solutions.size(); ++i) {
    auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
  }
  return ret;
}
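
// Usage sketch (illustrative; assumes a TunableOp-style registry exposing a
// RegisterOp(name, op) method): each (name, callable) pair returned above wraps a
// single rocBLAS solution index, so a tuner can time every candidate and keep the
// fastest one for a given problem shape.
//
//   for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<float>()) {
//     registry.RegisterOp(std::move(name), std::move(op));
//   }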
template <typename T>
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
 public:
  RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}

  TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
    auto input_output_type = RocBlasDataTypeFor<T>();
    auto compute_type = RocBlasComputeTypeFor<T>();
    auto h_a = DoCastForHalfOrBfloat16(params->alpha);
    auto h_b = DoCastForHalfOrBfloat16(params->beta);
    auto status = rocblas_gemm_strided_batched_ex(
        (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
        _rocblasOpFromChar(params->transa),
        _rocblasOpFromChar(params->transb),
        params->m, params->n, params->k,
        &h_a,
        params->a, input_output_type, params->lda, params->stride_a,
        params->b, input_output_type, params->ldb, params->stride_b,
        &h_b,
        params->c, input_output_type, params->ldc, params->stride_c,
        params->c, input_output_type, params->ldc, params->stride_c,
        params->batch,
        compute_type,
        rocblas_gemm_algo_solution_index,
        solution_,
        rocblas_gemm_flags_none);
    if (status != rocblas_status_success) {
      return FAIL;
    }
    return OK;
  }

 private:
  int solution_;
};

template <typename T>
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
  int solution_size;
  auto input_output_type = RocBlasDataTypeFor<T>();
  auto compute_type = RocBlasComputeTypeFor<T>();
  // Get the number of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            nullptr,
                                                            &solution_size));
  std::vector<int> solutions(solution_size);
  // Get the list of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            solutions.data(),
                                                            &solution_size));
  // Sort the solutions in ascending order to make the solution vector deterministic across runs
  std::sort(solutions.begin(), solutions.end());

  std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
  for (size_t i = 0; i < solutions.size(); ++i) {
    auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
  }
  return ret;
}

} // namespace at::cuda::tunable
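
// Invocation sketch (illustrative; field names follow the GemmParams usage above, see
// GemmCommon.h for the full definition; a, b, c are device pointers valid on the
// current HIP stream):
//
//   auto ops = at::cuda::tunable::GetRocBlasGemmTypeStringAndOps<float>();
//   at::cuda::tunable::GemmParams<float> params;
//   params.transa = 'n'; params.transb = 'n';
//   params.m = m; params.n = n; params.k = k;
//   params.alpha = 1.0f; params.beta = 0.0f;
//   params.a = a; params.lda = m;
//   params.b = b; params.ldb = k;
//   params.c = c; params.ldc = m;
//   auto status = ops.front().second->Call(&params);  // OK on success, FAIL otherwise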