// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once

#include <ATen/cuda/tunable/GemmCommon.h>
#ifdef USE_ROCM
#include <ATen/cuda/tunable/GemmHipblaslt.h>
#include <ATen/cuda/tunable/GemmRocblas.h>
#endif
#include <ATen/cuda/tunable/StreamTimer.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/StringUtil.h>

namespace at::cuda::tunable {

template <typename T>
class DefaultGemmOp : public Callable<GemmParams<T>> {
 public:
  TuningStatus Call(const GemmParams<T>* params) override {
    at::cuda::blas::gemm_internal<T>(
        params->transa, params->transb,
        params->m, params->n, params->k,
        params->alpha,
        params->a, params->lda,
        params->b, params->ldb,
        params->beta,
        params->c, params->ldc);
    return OK;
  }
};

static bool _transposeBoolFromChar(char op) {
  return op == 't' || op == 'T';
}

template <typename T>
class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
 public:
  TuningStatus Call(const GemmAndBiasParams<T>* params) override {
    at::cuda::blas::gemm_and_bias<T>(
        _transposeBoolFromChar(params->transa),
        _transposeBoolFromChar(params->transb),
        params->m, params->n, params->k,
        params->alpha,
        params->a, params->lda,
        params->b, params->ldb,
        params->bias,
        params->c, params->ldc,
        params->activation);
    return OK;
  }
};

template <typename T>
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
 public:
  TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
    at::cuda::blas::bgemm_internal<T>(
        params->transa, params->transb,
        params->m, params->n, params->k,
        params->alpha,
        params->a, params->lda, params->stride_a,
        params->b, params->ldb, params->stride_b,
        params->beta,
        params->c, params->ldc, params->stride_c,
        params->batch);
    return OK;
  }
};

template <typename T>
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
 public:
  TuningStatus Call(const ScaledGemmParams<T>* params) override {
    at::cuda::blas::scaled_gemm(
        params->transa,
        params->transb,
        params->m,
        params->n,
        params->k,
        params->a,
        params->a_scale_ptr,
        params->lda,
        params->a_dtype,
        params->b,
        params->b_scale_ptr,
        params->ldb,
        params->b_dtype,
        params->bias_ptr,
        params->bias_dtype,
        params->c,
        params->c_scale_ptr,
        params->ldc,
        params->c_dtype,
        params->amax_ptr,
        params->use_fast_accum);
    return OK;
  }
};

template <typename T>
inline bool IsZero(T v) {
  return v == 0.0f;
}

template <>
inline bool IsZero(BFloat16 v) {
  return v.x == 0;
}

template <>
inline bool IsZero(Half v) {
  return float(v) == 0.0f;
}

template <>
inline bool IsZero(c10::complex<double> v) {
  return v == 0.0;
}

template <>
inline bool IsZero(c10::complex<float> v) {
  return v == 0.0f;
}

template <typename T>
inline std::string TypeName(T v) {
  return "unknown";
}

template <>
inline std::string TypeName(float v) {
  return "float";
}

template <>
inline std::string TypeName(double v) {
  return "double";
}

template <>
inline std::string TypeName(BFloat16 v) {
  return "BFloat16";
}

template <>
inline std::string TypeName(Half v) {
  return "Half";
}

template <>
inline std::string TypeName(Float8_e4m3fn v) {
  return "Float8_e4m3fn";
}

template <>
inline std::string TypeName(Float8_e5m2 v) {
  return "Float8_e5m2";
}

template <>
inline std::string TypeName(Float8_e4m3fnuz v) {
  return "Float8_e4m3fnuz";
}

template <>
inline std::string TypeName(Float8_e5m2fnuz v) {
  return "Float8_e5m2fnuz";
}

template <>
inline std::string TypeName(c10::complex<double> v) {
  return "c10::complex<double>";
}

template <>
inline std::string TypeName(c10::complex<float> v) {
  return "c10::complex<float>";
}

template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
 public:
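  // The constructor below builds the candidate list that the TunableOp
  // machinery times against each other. The "Default" op (plain
  // at::cuda::blas::gemm_internal) is always registered; on ROCm builds,
  // rocBLAS and hipBLASLt candidates are also registered unless disabled via
  // the PYTORCH_TUNABLEOP_ROCBLAS_ENABLED / PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED
  // environment variables (unset or "1" means enabled).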
  GemmTunableOp() {
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());

#ifdef USE_ROCM
    static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
    if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
      for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
        this->RegisterOp(std::move(name), std::move(op));
      }
    }

    static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
    if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
      // disallow tuning of hipblaslt with c10::complex
      if constexpr (
          !std::is_same_v<T, c10::complex<float>> &&
          !std::is_same_v<T, c10::complex<double>>) {
        for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }
    }
#endif
  }

  std::string Signature() override {
    return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
  }
};

template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
 public:
  GemmAndBiasTunableOp() {
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());

#ifdef USE_ROCM
    static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
    if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
      // disallow tuning of hipblaslt with c10::complex
      if constexpr (
          !std::is_same_v<T, c10::complex<float>> &&
          !std::is_same_v<T, c10::complex<double>>) {
        for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }
    }
#endif
  }

  std::string Signature() override {
    return c10::str("GemmAndBiasTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
  }
};

template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
 public:
  GemmStridedBatchedTunableOp() {
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());

#ifdef USE_ROCM
    static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
    if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
      for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
        this->RegisterOp(std::move(name), std::move(op));
      }
    }

    static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
    if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
      // disallow tuning of hipblaslt with c10::complex
      if constexpr (
          !std::is_same_v<T, c10::complex<float>> &&
          !std::is_same_v<T, c10::complex<double>>) {
        for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }
    }
#endif
  }

  std::string Signature() override {
    return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
  }
};

template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
 public:
  ScaledGemmTunableOp() {
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());

#ifdef USE_ROCM
    for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
      this->RegisterOp(std::move(name), std::move(op));
    }
#endif
  }

  std::string Signature() override {
    return c10::str("ScaledGemmTunableOp",
            "_", TypeName<AT>(AT{}),
            "_", TypeName<BT>(BT{}),
            "_", TypeName<CT>(CT{}),
            "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
  }
};

} // namespace at::cuda::tunable
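// Usage sketch (illustrative, not part of this header): call sites are expected
// to fill in a params struct and invoke a cached TunableOp instance. The
// operator() entry point is assumed from TunableOp.h; the field names mirror
// the GemmParams members referenced above, and BlasOp::T / BlasOp::N are the
// layout tags stringified by BlasOpToString.
//
//   at::cuda::tunable::GemmParams<float> params;
//   params.transa = 'T'; params.transb = 'N';
//   params.m = m; params.n = n; params.k = k;
//   params.alpha = 1.0f; params.beta = 0.0f;
//   params.a = a; params.lda = lda;
//   params.b = b; params.ldb = ldb;
//   params.c = c; params.ldc = ldc;
//   static at::cuda::tunable::GemmTunableOp<float, BlasOp::T, BlasOp::N> gemm{};
//   gemm(&params);  // tunes on first use, then replays the fastest registered op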