// Copyright © 2022 Apple Inc. #pragma once #include #include #include #ifdef __OBJC__ #include #include #include typedef id MTLDevice_t; typedef id MTLLibrary_t; typedef id MTLComputePipelineState_t; typedef id MTLLibrary_t; #else typedef void* MTLDevice; typedef void* MTLDevice_t; typedef void* MTLLibrary_t; typedef void* MTLComputePipelineState_t; typedef void* MTLLibrary_t; #endif namespace at::mps { // Helper enum to check if a MPSGraph op is supported in a given macOS version enum class MacOSVersion : uint32_t { MACOS_VER_13_1_PLUS = 0, MACOS_VER_13_2_PLUS, MACOS_VER_13_3_PLUS, MACOS_VER_14_0_PLUS, MACOS_VER_14_4_PLUS, MACOS_VER_15_0_PLUS, }; //----------------------------------------------------------------- // MPSDevice // // MPSDevice is a singleton class that returns the default device //----------------------------------------------------------------- class TORCH_API MPSDevice { public: /** * MPSDevice should not be cloneable. */ MPSDevice(MPSDevice& other) = delete; /** * MPSDevice should not be assignable. */ void operator=(const MPSDevice&) = delete; /** * Gets single instance of the Device. */ static MPSDevice* getInstance(); /** * Returns the single device. */ MTLDevice_t device() { return _mtl_device; } /** * Returns whether running on Ventura or newer */ bool isMacOS13Plus(MacOSVersion version) const; MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel); MTLLibrary_t getMetalIndexingLibrary(); ~MPSDevice(); private: static MPSDevice* _device; MTLDevice_t _mtl_device; MTLLibrary_t _mtl_indexing_library; MPSDevice(); }; TORCH_API bool is_available(); TORCH_API bool is_macos_13_or_newer(MacOSVersion version); TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false); } // namespace at::mps