#pragma once /// Defines the Float8_e5m2fnuz type (8-bit floating-point) including /// conversions to standard C types and basic arithmetic operations. Note that /// arithmetic operations are implemented by converting to floating point and /// performing the operation in float32. /// Binary configuration remains the same as e5m2: /// s eeeee mm /// 1 sign bit /// 5 exponent bits /// 2 mantissa bits /// The key differences that e5m2fnuz brings are: /// bias = 16 /// no infinities or negative zero /// NaN only when sign bit is 1, rest all 0s /// /// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and /// the existing Float8_e4m3fn implementation. #include #include #include #if defined(__cplusplus) #include #elif !defined(__OPENCL_VERSION__) #include #include #endif #include #include namespace c10 { namespace detail { /* * Convert a 32-bit floating-point number in IEEE single-precision format to a * 8-bit floating-point number in fp8 E5M2 format, in bit representation. */ inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) { /* * Binary representation of 65536.0f, which is the first value not * representable (i.e. the first value which would overflow in to the sign * bit, resulting in a NaN) in fp8e4m3fnuz range: * 1 00000 00 - fp8e5m2fnuz * 0 10001111 00000000000000000000000 - fp32 */ constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23; /* * A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range * into denormalized representation. * magic number: ((127 - 16) + (23 - 2) + 1) */ constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23; uint32_t f_bits = fp32_to_bits(f); uint32_t result = 0u; /* * Extract the sign of the input number into the high bit of the 32-bit word: * * +---+----------------------------------+ * | S |0000000 00000000 00000000 00000000| * +---+----------------------------------+ * Bits 31 0-31 */ const uint32_t sign = f_bits & UINT32_C(0x80000000); /* * Set sign bit to 0 */ f_bits ^= sign; if (f_bits >= fnuz_max) { // NaN -- sign bit set to 1, rest 0s return 0x80; } if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) { // Input exponent is less than -15, the smallest e5m2fnuz exponent, so the // number will become subnormal. f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); result = static_cast(f_bits - denorm_mask); if (result == 0) { // fnuz types don't have negative zero. return 0; } } else { // resulting mantissa is odd uint8_t mant_odd = (f_bits >> 21) & 1; // update exponent, rounding bias part 1 f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF; // rounding bias part 2 f_bits += mant_odd; // take the bits! result = static_cast(f_bits >> 21); } result |= sign >> 24; return result; } } // namespace detail struct alignas(1) Float8_e5m2fnuz { uint8_t x; struct from_bits_t {}; C10_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } Float8_e5m2fnuz() = default; constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t) : x(bits) {} inline C10_HOST_DEVICE Float8_e5m2fnuz(float value); inline C10_HOST_DEVICE operator float() const; inline C10_HOST_DEVICE bool isnan() const; inline C10_HOST_DEVICE bool isinf() const; }; C10_API inline std::ostream& operator<<( std::ostream& out, const Float8_e5m2fnuz& value) { out << (float)value; return out; } } // namespace c10 #include // IWYU pragma: keep