triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic; see the registry's advisory page for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,1391 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* MIT License
|
|
3
|
+
*
|
|
4
|
+
* Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. All rights reserved.
|
|
5
|
+
*
|
|
6
|
+
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
* of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
* in the Software without restriction, including without limitation the rights
|
|
9
|
+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
* copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
* furnished to do so, subject to the following conditions:
|
|
12
|
+
*
|
|
13
|
+
* The above copyright notice and this permission notice shall be included in
|
|
14
|
+
* all copies or substantial portions of the Software.
|
|
15
|
+
*
|
|
16
|
+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
* SOFTWARE.
|
|
23
|
+
*/
|
|
24
|
+
|
|
25
|
+
/**
|
|
26
|
+
* \file
|
|
27
|
+
* \brief amd_hip_fp8.h header, for AMD fp8 data types
|
|
28
|
+
*/
|
|
29
|
+
|
|
30
|
+
#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
|
|
31
|
+
#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
|
|
32
|
+
|
|
33
|
+
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && __HIP_DEVICE_COMPILE__
|
|
34
|
+
#define HIP_FP8_CVT_FAST_PATH 1
|
|
35
|
+
#else
|
|
36
|
+
#define HIP_FP8_CVT_FAST_PATH 0
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
#if !defined(__HIPCC_RTC__)
|
|
40
|
+
#include <hip/amd_detail/amd_hip_common.h>
|
|
41
|
+
#include <climits>
|
|
42
|
+
|
|
43
|
+
#include "host_defines.h" // __hip_internal::
|
|
44
|
+
#include "amd_hip_vector_types.h" // float2 etc
|
|
45
|
+
#include "amd_hip_fp16.h" // __half_raw
|
|
46
|
+
#include "amd_hip_bf16.h" // bf16
|
|
47
|
+
#include "math_fwd.h" // ocml device functions
|
|
48
|
+
#endif // !defined(__HIPCC_RTC__)
|
|
49
|
+
|
|
50
|
+
#if defined(__HIPCC_RTC__)
|
|
51
|
+
#define __FP8_HOST_DEVICE__ __device__
|
|
52
|
+
#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static
|
|
53
|
+
#else
|
|
54
|
+
#define __FP8_HOST_DEVICE__ __host__ __device__
|
|
55
|
+
#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static inline
|
|
56
|
+
#endif // __HIPCC_RTC__
|
|
57
|
+
|
|
58
|
+
#if !defined(__HIPCC_RTC__)
|
|
59
|
+
static_assert(CHAR_BIT == 8, "byte size should be of 8 bits");
|
|
60
|
+
#endif
|
|
61
|
+
static_assert(sizeof(unsigned char) == 1);
|
|
62
|
+
static_assert(sizeof(unsigned short int) == 2);
|
|
63
|
+
static_assert(sizeof(unsigned int) == 4);
|
|
64
|
+
|
|
65
|
+
/**
 * \brief Describes FP8 interpretation
 */
enum __hip_fp8_interpretation_t {
  __HIP_E4M3_FNUZ = 0, /**< Standard FP8 (4 exponent bits, 3 mantissa bits) */
  __HIP_E5M2_FNUZ = 1, /**< BF8 (5 exponent bits, 2 mantissa bits) */
};

/**
 * \brief Describes saturation behavior
 */
enum __hip_saturation_t {
  __HIP_NOSAT = 0, /**< No saturation */
  __HIP_SATFINITE = 1, /**< Saturate to finite */
};

/** \typedef __hip_fp8_storage_t
 *
 * \brief type to store single fp8 number
 */
typedef unsigned char __hip_fp8_storage_t;


/** \typedef __hip_fp8x2_storage_t
 *
 * \brief type to store two fp8 numbers (low byte = first element)
 */
typedef unsigned short int __hip_fp8x2_storage_t;


/** \typedef __hip_fp8x4_storage_t
 *
 * \brief type to store four fp8 numbers
 */
typedef unsigned int __hip_fp8x4_storage_t;
|
|
100
|
+
|
|
101
|
+
namespace internal {
|
|
102
|
+
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
// This has been modified to add double types conversion as well
//
// Rounds a half/float/double `_x` to an fp8 encoding with `wm` mantissa bits
// and `we` exponent bits.  `negative_zero_nan` selects the FNUZ flavour
// (no infinities; 0x80 is the single NaN).  `clip` saturates out-of-range
// values to the largest finite fp8 instead of returning inf; `stoch` + `rng`
// select stochastic rounding instead of round-to-nearest-even.
template <typename T, bool negative_zero_nan>
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false,
                                                          bool stoch = false,
                                                          unsigned int rng = 0) {
  constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
  constexpr bool is_float = __hip_internal::is_same<T, float>::value;
  constexpr bool is_double = __hip_internal::is_same<T, double>::value;
  static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8");

  // Mantissa width of the source format (double: 52, float: 23, half: 10).
  const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
  unsigned long long x;

  // Reinterpret the source value as raw bits of the matching width.
  if (sizeof(T) == 8)
    x = reinterpret_cast<unsigned long long&>(_x);
  else if (sizeof(T) == 4)
    x = reinterpret_cast<unsigned int&>(_x);
  else
    x = reinterpret_cast<unsigned short int&>(_x);


  unsigned long long head, mantissa;
  int exponent, bias;
  unsigned int sign;

  // Split the source into sign / biased exponent / mantissa fields.
  if (sizeof(T) == 8) {
    head = x & 0xFFF0000000000000ull;
    mantissa = x & 0xFFFFFFFFFFFFFull;
    exponent = (head >> 52) & 0x7FF;
    sign = head >> 63;
    bias = 1023;
  } else if (sizeof(T) == 4) {
    head = x & 0xFF800000;
    mantissa = x & 0x7FFFFF;
    exponent = (head >> 23) & 0xFF;
    sign = head >> 31;
    bias = 127;
  } else {
    head = x & 0xFC00;
    mantissa = x & 0x3FF;
    exponent = (head >> 10) & 0x1F;
    sign = head >> 15;
    bias = 15;
  }

  // fp8 bit pattern for +/-inf: all exponent bits set, zero mantissa.
  unsigned int signed_inf = (sign << 7) + (((1 << we) - 1) << wm);

  // Deal with inf and NaNs
  if (negative_zero_nan) {
    // FNUZ: every inf/NaN input maps to the single NaN encoding 0x80.
    if (sizeof(T) == 8) {
      if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull) return 0x80;
    } else if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) return 0x80;
    } else {
      if ((x & 0x7C00) == 0x7C00) return 0x80;
    }
  } else {
    // IEEE-like: inf stays inf; NaN gets the mantissa LSB set on top of inf.
    if (sizeof(T) == 8) {
      if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull)
        return signed_inf + (mantissa != 0 ? 1 : 0);
    } else if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) return signed_inf + (mantissa != 0 ? 1 : 0);
    } else {
      if ((x & 0x7C00) == 0x7C00) return signed_inf + (mantissa != 0 ? 1 : 0);
    }
  }

  if (x == 0) {
    return 0;
  }

  // First need to check if it is normal or denorm as there is a difference of implict 1
  // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
  // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
  // RNE, no need to add rng. Then probably need to check whether there is carry and adjust
  // exponent and mantissa again

  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
  const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
  const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
  // f8_exponent is the converted f8 exponent with bias encoding
  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
  // the difference needs to be adjusted and mantissa shifted
  int act_exponent, f8_exponent, exponent_diff;

  if (exponent == 0) { // fp32/fp16 is in denormal.
    /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In
this case, the fp16 mantissa should be shift left by 1 */
    act_exponent = exponent - bias + 1;
    exponent_diff = f8_denormal_act_exponent -
        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
  } else { // fp32/fp16 is normal with implicit 1
    act_exponent = exponent - bias;
    if (act_exponent <= f8_denormal_act_exponent) {
      /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
      exponent_diff = f8_denormal_act_exponent - act_exponent;
    } else { // both fp32/fp16 and f8 are in normal range
      exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
                         // act_exponent could be larger. Just that it does not need shift mantissa
    }
    mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
  }

  // A tie (exact halfway point) must be detected BEFORE the shift below.
  bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
      (1ull << (mfmt - wm + exponent_diff - 1));
  /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift
right as shift right could rip off some residual part and make something not midpoint look like
midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than midpoint, but
after shift right by 4 bits, it would look like midpoint.
*/

  if (exponent_diff > 0)
    mantissa >>= exponent_diff;
  else if (exponent_diff == -1)
    mantissa <<= -exponent_diff;
  bool implicit_one = mantissa & (1ull << mfmt);
  // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
  f8_exponent =
      (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);

  // Now we have the exponent and mantissa adjusted
  unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
  bool odd =
      mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
  // RNE: round half to even; stochastic: add rng then truncate.
  mantissa +=
      (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;

  // Now we deal with overflow
  if (f8_exponent == 0) {
    if ((1ull << mfmt) & mantissa) {
      f8_exponent = 1; // denormal overflow to become normal, promote exponent
    }
  } else {
    if ((1ull << (mfmt + 1)) & mantissa) {
      mantissa >>= 1;
      f8_exponent++;
    }
  }

  mantissa >>= (mfmt - wm);

  // above range: quantize to maximum possible float of the same sign
  const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
  if (f8_exponent > max_exp) {
    if (clip) {
      mantissa = (1 << wm) - 1;
      f8_exponent = max_exp;
    } else {
      return signed_inf;
    }
  }

  if (f8_exponent == 0 && mantissa == 0) return negative_zero_nan ? 0 : (sign << 7);
  mantissa &= (1 << wm) - 1;
  return (sign << 7) | (f8_exponent << wm) | mantissa;
}
|
|
269
|
+
|
|
270
|
+
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well
//
// Decodes one fp8 value (`wm` mantissa bits, `we` exponent bits) to
// half/float/double.  `negative_zero_nan` selects the FNUZ flavour where 0x80
// is the single NaN and no infinities exist.
template <typename T, bool negative_zero_nan>
__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we) {
  constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
  constexpr bool is_float = __hip_internal::is_same<T, float>::value;
  constexpr bool is_double = __hip_internal::is_same<T, double>::value;
  static_assert(is_half || is_float || is_double, "only half, float and double are supported");

  // Exponent / mantissa widths of the output format.
  constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);

  // Special values of the output format, built from raw bit patterns.
  T fInf, fNegInf, fNaN, fNeg0;
  if (is_half) {
    const unsigned short int ihInf = 0x7C00;
    const unsigned short int ihNegInf = 0xFC00;
    const unsigned short int ihNaN = 0x7C01;
    const unsigned short int ihNeg0 = 0x8000;
    fInf = reinterpret_cast<const _Float16&>(ihInf);
    fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
    fNaN = reinterpret_cast<const _Float16&>(ihNaN);
    fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
  } else if (is_float) {
    const unsigned int ifInf = 0x7F800000;
    const unsigned int ifNegInf = 0xFF800000;
    const unsigned int ifNaN = 0x7F800001;
    const unsigned int ifNeg0 = 0x80000000;
    fInf = reinterpret_cast<const float&>(ifInf);
    fNegInf = reinterpret_cast<const float&>(ifNegInf);
    fNaN = reinterpret_cast<const float&>(ifNaN);
    fNeg0 = reinterpret_cast<const float&>(ifNeg0);
  } else if (is_double) {
    const unsigned long long ifInf = 0x7FF0000000000000ull;
    const unsigned long long ifNegInf = 0xFFF0000000000000ull;
    const unsigned long long ifNaN = 0x7FF0000000000001ull;
    const unsigned long long ifNeg0 = 0x8000000000000000ull;
    fInf = reinterpret_cast<const double&>(ifInf);
    fNegInf = reinterpret_cast<const double&>(ifNegInf);
    fNaN = reinterpret_cast<const double&>(ifNaN);
    fNeg0 = reinterpret_cast<const double&>(ifNeg0);
  }

  if (x == 0) {
    return 0;
  }

  // Split the fp8 byte into sign / biased exponent / mantissa fields.
  unsigned long long sign = x >> 7;
  unsigned long long mantissa = x & ((1 << wm) - 1);
  int exponent = (x & 0x7F) >> wm;
  if (negative_zero_nan) {
    if (x == 0x80) return fNaN;  // FNUZ: 0x80 is the only NaN
  } else {
    if (x == 0x80) return fNeg0;  // IEEE-like: 0x80 is negative zero
    if (exponent == ((1 << we) - 1)) return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
  }

  // Unsigned integer type with the same size as the output float type.
  typename __hip_internal::conditional<
      sizeof(T) == 2, unsigned short int,
      typename __hip_internal::conditional<sizeof(T) == 4, unsigned int,
                                           unsigned long long>::type>::type retval;

  // bf8 (e5m2) -> half with IEEE semantics is a pure left shift.
  if (we == 5 && is_half && !negative_zero_nan) {
    retval = x << 8;
    return reinterpret_cast<const T&>(retval);
  }

  // Bias adjustment between the fp8 exponent and the output exponent.
  const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

  // subnormal input
  if (exponent == 0) {
#if __HIP_DEVICE_COMPILE__
    // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
    int sh = 1 + __clz(mantissa) - (32 - wm);
#else
    int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
#endif
    // Normalize: shift the leading 1 out into the implicit position.
    mantissa <<= sh;
    exponent += 1 - sh;
    mantissa &= ((1ull << wm) - 1);
  }
  exponent += exp_low_cutoff - 1;
  mantissa <<= wmo - wm;

  // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
  if (exponent <= 0) {
    mantissa |= 1 << wmo;
    mantissa >>= 1 - exponent;
    exponent = 0;
  }

  // Reassemble sign / exponent / mantissa at the output widths.
  if (sizeof(T) == 2)
    retval = (sign << 15) | (exponent << 10) | mantissa;
  else if (sizeof(T) == 4)
    retval = (sign << 31) | (exponent << 23) | mantissa;
  else
    retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
  return reinterpret_cast<const T&>(retval);
}
|
|
369
|
+
|
|
370
|
+
#if HIP_FP8_CVT_FAST_PATH
|
|
371
|
+
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
//
// gfx94x fast path: convert one float to fp8/bf8 with the hardware converter.
// `saturate` clamps finite out-of-range inputs to the largest finite value;
// `stochastic_rounding` uses `rng` as the randomness source, otherwise the
// hardware rounds to nearest even.
template <bool stochastic_rounding = false>
static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate,
                                                          __hip_fp8_interpretation_t interpret,
                                                          unsigned int rng = 0) {
  __hip_fp8_storage_t i8data;
  union {
    float fval;
    unsigned int i32val;
    unsigned char i8val[4]; // NOTE: not endian independent
  } val;

  unsigned int ival = 0;
  val.fval = v;

  if (saturate) {
    if (interpret == __HIP_E4M3_FNUZ) {
      if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
        // 240.0 is the largest finite e4m3-fnuz value.
        val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
      }
    } else {
      if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
        // 57344.0 is the largest finite e5m2-fnuz value.
        val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
      }
    }
  }

  if (stochastic_rounding) {
    ival = interpret == __HIP_E4M3_FNUZ
        ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
        : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val = ival;
    i8data = val.i8val[0]; // little endian
  } else { // RNE CVT
    ival = interpret == __HIP_E4M3_FNUZ
        ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
        : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val = ival;
    i8data = val.i8val[0];
  }
  return i8data;
}
|
|
414
|
+
|
|
415
|
+
// gfx94x fast path: convert a float2 to two packed fp8/bf8 values using the
// hardware pack instruction (round-to-nearest-even).  Returns the low 16 bits
// of the packed word: x in the low byte, y in the high byte.
static __device__ __hip_fp8x2_storage_t
cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) {
  union {
    static_assert(sizeof(float2) == sizeof(unsigned int[2]));
    static_assert(sizeof(float2) == sizeof(unsigned short[4]));
    float2 fval;
    unsigned int i32val[2];
    unsigned short i16val[4];
  } f2val;

  f2val.fval = v;

  if (saturate) { /// propagate NAN/INF, no clipping
    if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
      f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0);
    }
    if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
      // BUGFIX: clamp the y component (the original clamped .x again here,
      // a copy/paste error that left .y unsaturated).
      f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.y, 240.0, -240.0);
    }
    // NOTE(review): 240.0 is the e4m3-fnuz max; the scalar path clamps bf8 to
    // 57344.0 — confirm whether the packed path should distinguish interpret.
  }

  // BUGFIX: feed the (possibly clamped) values to the converter.  The original
  // passed the raw v.x/v.y, which made the `saturate` clamping above dead code.
  f2val.i32val[0] = interpret == __HIP_E4M3_FNUZ
      ? __builtin_amdgcn_cvt_pk_fp8_f32(f2val.fval.x, f2val.fval.y, 0, false)
      : __builtin_amdgcn_cvt_pk_bf8_f32(f2val.fval.x, f2val.fval.y, 0, false);

  return static_cast<__hip_fp8x2_storage_t>(f2val.i16val[0]);
}
|
|
442
|
+
|
|
443
|
+
// gfx94x fast path: widen one fp8/bf8 byte to float via the hardware converter.
static __device__ float cast_to_f32_from_f8(__hip_fp8_storage_t v,
                                            __hip_fp8_interpretation_t interpret) {
  union {
    unsigned int i32val;
    unsigned char i8val[4];
  } bits;
  bits.i8val[0] = v;  // byte of interest goes in lane 0

  if (interpret == __HIP_E4M3_FNUZ) {
    return __builtin_amdgcn_cvt_f32_fp8(bits.i32val, 0);
  }
  return __builtin_amdgcn_cvt_f32_bf8(bits.i32val, 0);
}
|
|
455
|
+
|
|
456
|
+
// gfx94x fast path: widen two packed fp8/bf8 values to a float2.
static __device__ float2 cast_to_f32x2_from_f8x2(__hip_fp8x2_storage_t v,
                                                 __hip_fp8_interpretation_t interpret) {
  union {
    unsigned int i32val;
    unsigned short i16val[2];
  } bits;
  bits.i16val[0] = v;  // packed pair goes in the low half-word

  const bool is_fp8 = (interpret == __HIP_E4M3_FNUZ);
  auto pair = is_fp8 ? __builtin_amdgcn_cvt_pk_f32_fp8(bits.i32val, false)
                     : __builtin_amdgcn_cvt_pk_f32_bf8(bits.i32val, false);
  return float2{pair[0], pair[1]};
}
|
|
468
|
+
#endif // HIP_FP8_CVT_FAST_PATH
|
|
469
|
+
|
|
470
|
+
/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
|
|
471
|
+
Inf are not supported. This gives us one additional number to represent.
|
|
472
|
+
NaN are represented by 1-0000-000 or 1-00000-00 */
|
|
473
|
+
__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_fnuz_is_nan(__hip_fp8_storage_t a) {
|
|
474
|
+
return static_cast<unsigned char>(a) == 0x80;
|
|
475
|
+
}
|
|
476
|
+
} // namespace internal
|
|
477
|
+
|
|
478
|
+
/**
 * \brief convert float to @p __hip_fp8_storage_t
 *
 * \param f float number
 * \param sat saturation of fp8
 * \param type interpretation of fp8
 * \return __hip_fp8_storage_t
 */
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(
    const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
#if HIP_FP8_CVT_FAST_PATH
  // gfx94x device build: use the hardware converter (round-to-nearest-even).
  return internal::cast_to_f8_from_f32<false>(f, sat == __HIP_SATFINITE, type);
#else  // HIP_FP8_CVT_FAST_PATH
  // Software fallback: e4m3 is 4 exponent / 3 mantissa bits; e5m2 is 5 / 2.
  int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
  int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
  return internal::cast_to_f8<float, true>(f, wm, we, sat == __HIP_SATFINITE);
#endif  // HIP_FP8_CVT_FAST_PATH
}
|
|
496
|
+
|
|
497
|
+
/**
 * \brief convert float2 to @p __hip_fp8x2_storage_t
 *
 * \param f2 float2 number
 * \param sat saturation of fp8
 * \param type interpretation of fp8
 * \return __hip_fp8x2_storage_t (f2.x in the low byte, f2.y in the high byte)
 */
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2(
    const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
#if HIP_FP8_CVT_FAST_PATH
  // gfx94x device build: one hardware pack instruction converts the pair.
  return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, type);
#else
  // Software fallback: convert each lane and pack (y high, x low).
  return static_cast<__hip_fp8x2_storage_t>(
      static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.y, sat, type)) << 8 |
      static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.x, sat, type)));
#endif
}
|
|
515
|
+
|
|
516
|
+
/**
|
|
517
|
+
* \brief convert double to @p __hip_fp8_storage_t
|
|
518
|
+
*
|
|
519
|
+
* \param d double val
|
|
520
|
+
* \param sat saturation of fp8
|
|
521
|
+
* \param type interpretation of fp8
|
|
522
|
+
* \return __hip_fp8_storage_t
|
|
523
|
+
*/
|
|
524
|
+
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(
|
|
525
|
+
const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
|
|
526
|
+
int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
|
|
527
|
+
int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
|
|
528
|
+
return internal::cast_to_f8<double, true>(d, wm, we, sat == __HIP_SATFINITE);
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
/**
|
|
532
|
+
* \brief convert double2 to @p __hip_fp8x2_storage_t
|
|
533
|
+
*
|
|
534
|
+
* \param d2 double2 val
|
|
535
|
+
* \param sat saturation of fp8
|
|
536
|
+
* \param type interpretation of fp8
|
|
537
|
+
* \return __hip_fp8x2_storage_t
|
|
538
|
+
*/
|
|
539
|
+
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2(
|
|
540
|
+
const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
|
|
541
|
+
return static_cast<__hip_fp8x2_storage_t>(
|
|
542
|
+
static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.y, sat, type)) << 8 |
|
|
543
|
+
static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.x, sat, type)));
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
/**
|
|
547
|
+
* \brief convert __hip_bfloat16_raw to @p __hip_fp8_storage_t
|
|
548
|
+
*
|
|
549
|
+
* \param hr __hip_bfloat16_raw val
|
|
550
|
+
* \param sat saturation of fp8
|
|
551
|
+
* \param type interpretation of fp8
|
|
552
|
+
* \return __hip_fp8_storage_t
|
|
553
|
+
*/
|
|
554
|
+
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t
|
|
555
|
+
__hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat,
|
|
556
|
+
const __hip_fp8_interpretation_t type) {
|
|
557
|
+
float fval = __hip_bfloat16(hr);
|
|
558
|
+
return __hip_cvt_float_to_fp8(fval, sat, type);
|
|
559
|
+
}
|
|
560
|
+
|
|
561
|
+
/**
 * \brief convert __hip_bfloat162_raw to @p __hip_fp8x2_storage_t
 *
 * \param hr __hip_bfloat162_raw value
 * \param sat saturation of fp8
 * \param type interpretation of fp8
 * \return __hip_fp8x2_storage_t
 */
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t
__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat,
                                const __hip_fp8_interpretation_t type) {
  // Widen the bf16 pair to float2, then reuse the float2 narrowing path.
  float2 f2 = __hip_bfloat162(hr);
  return __hip_cvt_float2_to_fp8x2(f2, sat, type);
}
|
|
575
|
+
|
|
576
|
+
/**
|
|
577
|
+
* \brief convert @p __hip_fp8_storage_t to __half_raw
|
|
578
|
+
*
|
|
579
|
+
* \param x __hip_fp8_storage_t val
|
|
580
|
+
* \param type interpretation of fp8
|
|
581
|
+
* \return __half_raw
|
|
582
|
+
*/
|
|
583
|
+
__FP8_HOST_DEVICE_STATIC__ __half_raw
|
|
584
|
+
__hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t type) {
|
|
585
|
+
unsigned int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
|
|
586
|
+
unsigned int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
|
|
587
|
+
return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)};
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
/**
|
|
591
|
+
* \brief convert @p __hip_fp8x2_storage_t to __half2_raw
|
|
592
|
+
*
|
|
593
|
+
* \param x __hip_fp8x2_storage_t val
|
|
594
|
+
* \param type interpretation of fp8
|
|
595
|
+
* \return __half2_raw
|
|
596
|
+
*/
|
|
597
|
+
__FP8_HOST_DEVICE_STATIC__ __half2_raw
|
|
598
|
+
__hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t type) {
|
|
599
|
+
__half2 ret(static_cast<__half>(
|
|
600
|
+
__hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x & 0xFF), type)),
|
|
601
|
+
static_cast<__half>(
|
|
602
|
+
__hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x >> 8), type)));
|
|
603
|
+
return static_cast<__half2_raw>(ret);
|
|
604
|
+
}
|
|
605
|
+
|
|
606
|
+
/**
|
|
607
|
+
* \brief convert __half_raw to @p __hip_fp8_storage_t
|
|
608
|
+
*
|
|
609
|
+
* \param x __half_raw value
|
|
610
|
+
* \param sat saturation of fp8
|
|
611
|
+
* \param type interpretation of fp8
|
|
612
|
+
* \return __hip_fp8_storage_t
|
|
613
|
+
*/
|
|
614
|
+
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8(
|
|
615
|
+
const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
|
|
616
|
+
return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, type);
|
|
617
|
+
}
|
|
618
|
+
|
|
619
|
+
/**
|
|
620
|
+
* \brief convert __half2_raw to @p __hip_fp8x2_storage_t
|
|
621
|
+
*
|
|
622
|
+
* \param x __half2_raw value
|
|
623
|
+
* \param sat saturation of fp8
|
|
624
|
+
* \param type interpretation of fp8
|
|
625
|
+
* \return __hip_fp8x2_storage_t
|
|
626
|
+
*/
|
|
627
|
+
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(
|
|
628
|
+
const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
|
|
629
|
+
return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, type);
|
|
630
|
+
}
|
|
631
|
+
|
|
632
|
+
/**
|
|
633
|
+
* \brief struct representing single fp8 number with e4m3 interpretation
|
|
634
|
+
*
|
|
635
|
+
*/
|
|
636
|
+
struct __hip_fp8_e4m3_fnuz {
|
|
637
|
+
__hip_fp8_storage_t __x; //! raw storage of fp8 number
|
|
638
|
+
constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
|
|
639
|
+
constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
|
|
640
|
+
constexpr static unsigned int __we = 4;
|
|
641
|
+
constexpr static unsigned int __wm = 3;
|
|
642
|
+
|
|
643
|
+
// TODO: SWDEV-452411
|
|
644
|
+
// Add cast from unsigned long long, long long to fp8
|
|
645
|
+
|
|
646
|
+
/*! create fp8 e4m3 from long */
|
|
647
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val)
|
|
648
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
649
|
+
__default_interpret)) {}
|
|
650
|
+
|
|
651
|
+
/*! create fp8 e4m3 from int */
|
|
652
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val)
|
|
653
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
654
|
+
__default_interpret)) {}
|
|
655
|
+
|
|
656
|
+
/*! create fp8 e4m3 from short int */
|
|
657
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val)
|
|
658
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
659
|
+
__default_interpret)) {}
|
|
660
|
+
|
|
661
|
+
/*! create fp8 e4m3 from unsigned long */
|
|
662
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
|
|
663
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
664
|
+
__default_interpret)) {}
|
|
665
|
+
|
|
666
|
+
/*! create fp8 e4m3 from unsigned int */
|
|
667
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val)
|
|
668
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
669
|
+
__default_interpret)) {}
|
|
670
|
+
|
|
671
|
+
/*! create fp8 e4m3 from unsigned short */
|
|
672
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
|
|
673
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
674
|
+
__default_interpret)) {}
|
|
675
|
+
|
|
676
|
+
/*! create fp8 e4m3 from double */
|
|
677
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f)
|
|
678
|
+
: __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
|
|
679
|
+
|
|
680
|
+
/*! create fp8 e4m3 from float */
|
|
681
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f)
|
|
682
|
+
: __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
|
|
683
|
+
|
|
684
|
+
/*! create fp8 e4m3 from __hip_bfloat16 */
|
|
685
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
|
|
686
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
|
|
687
|
+
__default_interpret)) {}
|
|
688
|
+
|
|
689
|
+
/*! create fp8 e4m3 from __half */
|
|
690
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f)
|
|
691
|
+
: __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
|
|
692
|
+
__default_interpret)) {}
|
|
693
|
+
|
|
694
|
+
/*! default construct fp8 e4m3 */
|
|
695
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz() = default;
|
|
696
|
+
|
|
697
|
+
/*! convert fp8 e4m3 to __half */
|
|
698
|
+
__FP8_HOST_DEVICE__ operator __half() const {
|
|
699
|
+
return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
|
|
700
|
+
}
|
|
701
|
+
|
|
702
|
+
/*! convert fp8 e4m3 to __hip_bfloat16 */
|
|
703
|
+
__FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
|
|
704
|
+
float f = *this;
|
|
705
|
+
return __hip_bfloat16(f);
|
|
706
|
+
}
|
|
707
|
+
|
|
708
|
+
/*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */
|
|
709
|
+
__FP8_HOST_DEVICE__ operator bool() const {
|
|
710
|
+
// it can be 0x00 (+0.0) since 0x80 will be nan
|
|
711
|
+
return !(static_cast<unsigned short>(__x) == 0);
|
|
712
|
+
}
|
|
713
|
+
|
|
714
|
+
/*! convert fp8 e4m3 to char, clamp number to CHAR_MIN/CHAR_MAX if its out of range */
|
|
715
|
+
__FP8_HOST_DEVICE__ operator char() const {
|
|
716
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
717
|
+
return 0;
|
|
718
|
+
}
|
|
719
|
+
|
|
720
|
+
auto fval = internal::cast_from_f8<float, true>(__x, __wm, __we);
|
|
721
|
+
auto llval = static_cast<long long>(fval);
|
|
722
|
+
if (llval <= CHAR_MIN) {
|
|
723
|
+
return CHAR_MIN;
|
|
724
|
+
} else if (llval >= CHAR_MAX) {
|
|
725
|
+
return CHAR_MAX;
|
|
726
|
+
}
|
|
727
|
+
return static_cast<char>(fval);
|
|
728
|
+
}
|
|
729
|
+
|
|
730
|
+
/*! convert fp8 e4m3 to double */
|
|
731
|
+
__FP8_HOST_DEVICE__ operator double() const {
|
|
732
|
+
return internal::cast_from_f8<double, true>(__x, __wm, __we);
|
|
733
|
+
}
|
|
734
|
+
|
|
735
|
+
/*! convert fp8 e4m3 to float */
|
|
736
|
+
__FP8_HOST_DEVICE__ operator float() const {
|
|
737
|
+
#if HIP_FP8_CVT_FAST_PATH
|
|
738
|
+
return internal::cast_to_f32_from_f8(__x, __default_interpret);
|
|
739
|
+
#else
|
|
740
|
+
return internal::cast_from_f8<float, true>(__x, __wm, __we);
|
|
741
|
+
#endif
|
|
742
|
+
}
|
|
743
|
+
|
|
744
|
+
/*! convert fp8 e4m3 to int, return 0 if value is NaN */
|
|
745
|
+
__FP8_HOST_DEVICE__ operator int() const {
|
|
746
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
747
|
+
return 0;
|
|
748
|
+
}
|
|
749
|
+
|
|
750
|
+
float fval = *this;
|
|
751
|
+
return static_cast<int>(fval);
|
|
752
|
+
}
|
|
753
|
+
|
|
754
|
+
/*! convert fp8 e4m3 to long, return 0 if value is NaN */
|
|
755
|
+
__FP8_HOST_DEVICE__ operator long int() const {
|
|
756
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
757
|
+
return 0;
|
|
758
|
+
}
|
|
759
|
+
|
|
760
|
+
float fval = *this;
|
|
761
|
+
return static_cast<long>(fval);
|
|
762
|
+
}
|
|
763
|
+
|
|
764
|
+
/*! convert fp8 e4m3 to long long, return 0 if value is NaN */
|
|
765
|
+
__FP8_HOST_DEVICE__ operator long long int() const {
|
|
766
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
767
|
+
return 0;
|
|
768
|
+
}
|
|
769
|
+
|
|
770
|
+
float fval = *this;
|
|
771
|
+
return static_cast<long long>(fval);
|
|
772
|
+
}
|
|
773
|
+
|
|
774
|
+
/*! convert fp8 e4m3 to short int, clamp out of bound values, return 0 if value is NaN */
|
|
775
|
+
__FP8_HOST_DEVICE__ operator short int() const {
|
|
776
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
777
|
+
return 0;
|
|
778
|
+
}
|
|
779
|
+
|
|
780
|
+
float fval = *this;
|
|
781
|
+
auto llval = static_cast<long long>(fval);
|
|
782
|
+
if (llval <= SHRT_MIN) {
|
|
783
|
+
return SHRT_MIN;
|
|
784
|
+
} else if (llval >= SHRT_MAX) {
|
|
785
|
+
return SHRT_MAX;
|
|
786
|
+
}
|
|
787
|
+
return static_cast<short>(fval);
|
|
788
|
+
}
|
|
789
|
+
|
|
790
|
+
/*! convert fp8 e4m3 to signed char, clamp out of bound values, return 0 if value is NaN */
|
|
791
|
+
__FP8_HOST_DEVICE__ operator signed char() const {
|
|
792
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
793
|
+
return 0;
|
|
794
|
+
}
|
|
795
|
+
|
|
796
|
+
float fval = *this;
|
|
797
|
+
auto llval = static_cast<long long>(fval);
|
|
798
|
+
if (llval <= SCHAR_MIN) {
|
|
799
|
+
return SCHAR_MIN;
|
|
800
|
+
} else if (llval >= SCHAR_MAX) {
|
|
801
|
+
return SCHAR_MAX;
|
|
802
|
+
}
|
|
803
|
+
return static_cast<signed char>(fval);
|
|
804
|
+
}
|
|
805
|
+
|
|
806
|
+
/*! convert fp8 e4m3 to unsigned char, clamp out of bound values, return 0 if value is NaN */
|
|
807
|
+
__FP8_HOST_DEVICE__ operator unsigned char() const {
|
|
808
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
809
|
+
return 0;
|
|
810
|
+
}
|
|
811
|
+
|
|
812
|
+
float fval = *this;
|
|
813
|
+
auto llval = static_cast<long long>(fval);
|
|
814
|
+
if (llval <= 0) {
|
|
815
|
+
return 0;
|
|
816
|
+
} else if (llval >= UCHAR_MAX) {
|
|
817
|
+
return UCHAR_MAX;
|
|
818
|
+
}
|
|
819
|
+
return static_cast<unsigned char>(fval);
|
|
820
|
+
}
|
|
821
|
+
|
|
822
|
+
/*! convert fp8 e4m3 to unsigned int, return 0 if value is NaN */
|
|
823
|
+
__FP8_HOST_DEVICE__ operator unsigned int() const {
|
|
824
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
825
|
+
return 0;
|
|
826
|
+
}
|
|
827
|
+
|
|
828
|
+
float fval = *this;
|
|
829
|
+
auto llval = static_cast<long long>(fval);
|
|
830
|
+
if (llval <= 0) {
|
|
831
|
+
return 0;
|
|
832
|
+
}
|
|
833
|
+
return static_cast<unsigned int>(fval);
|
|
834
|
+
}
|
|
835
|
+
|
|
836
|
+
/*! convert fp8 e4m3 to unsigned long, return 0 if value is NaN */
|
|
837
|
+
__FP8_HOST_DEVICE__ operator unsigned long int() const {
|
|
838
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
839
|
+
return 0;
|
|
840
|
+
}
|
|
841
|
+
|
|
842
|
+
float fval = *this;
|
|
843
|
+
auto llval = static_cast<long long>(fval);
|
|
844
|
+
if (llval <= 0) {
|
|
845
|
+
return 0;
|
|
846
|
+
}
|
|
847
|
+
return static_cast<unsigned long>(fval);
|
|
848
|
+
}
|
|
849
|
+
|
|
850
|
+
/*! convert fp8 e4m3 to long long int, return 0 if value is NaN */
|
|
851
|
+
__FP8_HOST_DEVICE__ operator unsigned long long int() const {
|
|
852
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
853
|
+
return 0;
|
|
854
|
+
}
|
|
855
|
+
|
|
856
|
+
float fval = *this;
|
|
857
|
+
auto llval = static_cast<long long>(fval);
|
|
858
|
+
if (llval <= 0) {
|
|
859
|
+
return 0;
|
|
860
|
+
}
|
|
861
|
+
return static_cast<unsigned long long>(fval);
|
|
862
|
+
}
|
|
863
|
+
|
|
864
|
+
/*! convert fp8 e4m3 to unsigned short, return 0 if value is NaN */
|
|
865
|
+
__FP8_HOST_DEVICE__ operator unsigned short int() const {
|
|
866
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
867
|
+
return 0;
|
|
868
|
+
}
|
|
869
|
+
|
|
870
|
+
float fval = *this;
|
|
871
|
+
auto llval = static_cast<long long>(fval);
|
|
872
|
+
if (llval <= 0) {
|
|
873
|
+
return 0;
|
|
874
|
+
}
|
|
875
|
+
return static_cast<unsigned short>(fval);
|
|
876
|
+
}
|
|
877
|
+
};
|
|
878
|
+
|
|
879
|
+
/**
|
|
880
|
+
* \brief struct representing two fp8 numbers with e4m3 interpretation
|
|
881
|
+
*
|
|
882
|
+
*/
|
|
883
|
+
struct __hip_fp8x2_e4m3_fnuz {
|
|
884
|
+
__hip_fp8x2_storage_t __x; //! raw storage of two fp8 numbers
|
|
885
|
+
static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
|
|
886
|
+
static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
|
|
887
|
+
static constexpr unsigned int __we = 4;
|
|
888
|
+
static constexpr unsigned int __wm = 3;
|
|
889
|
+
|
|
890
|
+
/*! create fp8x2 e4m3 type from double2 */
|
|
891
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val)
|
|
892
|
+
: __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
|
|
893
|
+
|
|
894
|
+
/*! create fp8x2 e4m3 type from float2 */
|
|
895
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val)
|
|
896
|
+
: __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
|
|
897
|
+
|
|
898
|
+
/*! create fp8x2 e4m3 type from __hip_bfloat162 */
|
|
899
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
|
|
900
|
+
: __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
|
|
901
|
+
|
|
902
|
+
/*! create fp8x2 e4m3 type from __half2 */
|
|
903
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
|
|
904
|
+
: __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
|
|
905
|
+
|
|
906
|
+
/*! Default construct of fp8x2 e4m3 */
|
|
907
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz() = default;
|
|
908
|
+
|
|
909
|
+
/*! convert fp8x2 e4m3 to __half2 */
|
|
910
|
+
__FP8_HOST_DEVICE__ operator __half2() const {
|
|
911
|
+
return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
|
|
912
|
+
}
|
|
913
|
+
|
|
914
|
+
/*! convert fp8x2 e4m3 to float2 */
|
|
915
|
+
__FP8_HOST_DEVICE__ operator float2() const {
|
|
916
|
+
#if HIP_FP8_CVT_FAST_PATH
|
|
917
|
+
return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
|
|
918
|
+
#else
|
|
919
|
+
return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
|
|
920
|
+
__wm, __we),
|
|
921
|
+
internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
|
|
922
|
+
__wm, __we));
|
|
923
|
+
#endif
|
|
924
|
+
}
|
|
925
|
+
};
|
|
926
|
+
|
|
927
|
+
/**
|
|
928
|
+
* \brief struct representing four fp8 numbers with e4m3 interpretation
|
|
929
|
+
*
|
|
930
|
+
*/
|
|
931
|
+
struct __hip_fp8x4_e4m3_fnuz {
|
|
932
|
+
__hip_fp8x4_storage_t __x; //! raw storage of four fp8 numbers
|
|
933
|
+
static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
|
|
934
|
+
static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
|
|
935
|
+
static constexpr unsigned int __we = 4;
|
|
936
|
+
static constexpr unsigned int __wm = 3;
|
|
937
|
+
|
|
938
|
+
/*! create fp8x4 e4m3 type from double4 */
|
|
939
|
+
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val)
|
|
940
|
+
: __x{reinterpret_cast<__hip_fp8x4_storage_t>(
|
|
941
|
+
static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
|
|
942
|
+
val.x, __default_saturation, __default_interpret)) |
|
|
943
|
+
reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
|
|
944
|
+
val.y, __default_saturation, __default_interpret))
|
|
945
|
+
<< 8 |
|
|
946
|
+
reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
|
|
947
|
+
val.z, __default_saturation, __default_interpret))
|
|
948
|
+
<< 16 |
|
|
949
|
+
reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
|
|
950
|
+
val.w, __default_saturation, __default_interpret))
|
|
951
|
+
<< 24))} {}
|
|
952
|
+
|
|
953
|
+
/*! create fp8x4 e4m3 type from float4 */
|
|
954
|
+
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val)
|
|
955
|
+
: __x{reinterpret_cast<__hip_fp8x4_storage_t>(
|
|
956
|
+
static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
|
|
957
|
+
val.x, __default_saturation, __default_interpret)) |
|
|
958
|
+
reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
|
|
959
|
+
val.y, __default_saturation, __default_interpret))
|
|
960
|
+
<< 8 |
|
|
961
|
+
reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
|
|
962
|
+
val.z, __default_saturation, __default_interpret))
|
|
963
|
+
<< 16 |
|
|
964
|
+
reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
|
|
965
|
+
val.w, __default_saturation, __default_interpret))
|
|
966
|
+
<< 24))} {}
|
|
967
|
+
|
|
968
|
+
/*! create fp8x4 e4m3 type from two __hip_bfloat162 */
|
|
969
|
+
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
|
|
970
|
+
: __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
|
|
971
|
+
reinterpret_cast<unsigned short>(
|
|
972
|
+
__hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
|
|
973
|
+
reinterpret_cast<unsigned short>(
|
|
974
|
+
__hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
|
|
975
|
+
<< 16))) {}
|
|
976
|
+
|
|
977
|
+
/*! create fp8x4 e4m3 type from two __half2 */
|
|
978
|
+
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
|
|
979
|
+
: __x(reinterpret_cast<__hip_fp8x4_storage_t>(
|
|
980
|
+
static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
|
|
981
|
+
high, __default_saturation, __default_interpret)) |
|
|
982
|
+
reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
|
|
983
|
+
low, __default_saturation, __default_interpret))
|
|
984
|
+
<< 16))) {}
|
|
985
|
+
|
|
986
|
+
/*! Default construct fp8x4 e4m3 */
|
|
987
|
+
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz() = default;
|
|
988
|
+
|
|
989
|
+
/*! convert fp8x4 e4m3 to float4 */
|
|
990
|
+
__FP8_HOST_DEVICE__ operator float4() const {
|
|
991
|
+
auto x = __x; // bypass const
|
|
992
|
+
auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
|
|
993
|
+
auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
|
|
994
|
+
#if HIP_FP8_CVT_FAST_PATH
|
|
995
|
+
float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
|
|
996
|
+
float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
|
|
997
|
+
#else
|
|
998
|
+
float2 high = float2(internal::cast_from_f8<float, true>(
|
|
999
|
+
static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
|
|
1000
|
+
internal::cast_from_f8<float, true>(
|
|
1001
|
+
static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
|
|
1002
|
+
float2 low = float2(internal::cast_from_f8<float, true>(
|
|
1003
|
+
static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
|
|
1004
|
+
internal::cast_from_f8<float, true>(
|
|
1005
|
+
static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
|
|
1006
|
+
#endif
|
|
1007
|
+
return float4(low.x, low.y, high.x, high.y);
|
|
1008
|
+
}
|
|
1009
|
+
};
|
|
1010
|
+
|
|
1011
|
+
/**
|
|
1012
|
+
* \brief struct representing one fp8 number with e5m2 interpretation
|
|
1013
|
+
*
|
|
1014
|
+
*/
|
|
1015
|
+
struct __hip_fp8_e5m2_fnuz {
|
|
1016
|
+
__hip_fp8_storage_t __x; //! raw storage of one fp8 numbers
|
|
1017
|
+
static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
|
|
1018
|
+
static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
|
|
1019
|
+
static constexpr unsigned int __we = 5;
|
|
1020
|
+
static constexpr unsigned int __wm = 2;
|
|
1021
|
+
|
|
1022
|
+
|
|
1023
|
+
// TODO: SWDEV-452411
|
|
1024
|
+
// Add cast from unsigned long long, long long to fp8
|
|
1025
|
+
|
|
1026
|
+
/*! create fp8 e5m2 type from long */
|
|
1027
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val)
|
|
1028
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
1029
|
+
__default_interpret)) {}
|
|
1030
|
+
|
|
1031
|
+
/*! create fp8 e5m2 type from int */
|
|
1032
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val)
|
|
1033
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
1034
|
+
__default_interpret)) {}
|
|
1035
|
+
|
|
1036
|
+
/*! create fp8 e5m2 type from short int */
|
|
1037
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val)
|
|
1038
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
1039
|
+
__default_interpret)) {}
|
|
1040
|
+
|
|
1041
|
+
/*! create fp8 e5m2 type from unsigned long */
|
|
1042
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
|
|
1043
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
1044
|
+
__default_interpret)) {}
|
|
1045
|
+
|
|
1046
|
+
/*! create fp8 e5m2 type from unsigned int */
|
|
1047
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val)
|
|
1048
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
1049
|
+
__default_interpret)) {}
|
|
1050
|
+
|
|
1051
|
+
/*! create fp8 e5m2 type from unsigned short */
|
|
1052
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
|
|
1053
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
|
|
1054
|
+
__default_interpret)) {}
|
|
1055
|
+
|
|
1056
|
+
/*! create fp8 e5m2 type from double */
|
|
1057
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f)
|
|
1058
|
+
: __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
|
|
1059
|
+
|
|
1060
|
+
/*! create fp8 e5m2 type from float */
|
|
1061
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f)
|
|
1062
|
+
: __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
|
|
1063
|
+
|
|
1064
|
+
/*! create fp8 e5m2 type from __hip_bfloat16 */
|
|
1065
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
|
|
1066
|
+
: __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
|
|
1067
|
+
__default_interpret)) {}
|
|
1068
|
+
|
|
1069
|
+
/*! create fp8 e5m2 type from __hip_bfloat16 */
|
|
1070
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f)
|
|
1071
|
+
: __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
|
|
1072
|
+
__default_interpret)) {}
|
|
1073
|
+
|
|
1074
|
+
/*! default construct fp8 e5m2 */
|
|
1075
|
+
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz() = default;
|
|
1076
|
+
|
|
1077
|
+
/*! convert fp8 e5m2 to float */
|
|
1078
|
+
__FP8_HOST_DEVICE__ operator float() const {
|
|
1079
|
+
#if HIP_FP8_CVT_FAST_PATH
|
|
1080
|
+
return internal::cast_to_f32_from_f8(__x, __default_interpret);
|
|
1081
|
+
#else
|
|
1082
|
+
return internal::cast_from_f8<float, true>(__x, __wm, __we);
|
|
1083
|
+
#endif
|
|
1084
|
+
}
|
|
1085
|
+
|
|
1086
|
+
/*! convert fp8 e5m2 to __half */
|
|
1087
|
+
__FP8_HOST_DEVICE__ operator __half() const {
|
|
1088
|
+
return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
|
|
1089
|
+
}
|
|
1090
|
+
|
|
1091
|
+
/*! convert fp8 e5m2 to __hip_bfloat16 */
|
|
1092
|
+
__FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
|
|
1093
|
+
float f = *this;
|
|
1094
|
+
return __hip_bfloat16(f);
|
|
1095
|
+
}
|
|
1096
|
+
|
|
1097
|
+
/*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */
|
|
1098
|
+
__FP8_HOST_DEVICE__ operator bool() const {
|
|
1099
|
+
// it can be 0x00 (+0.0) since 0x80 will be nan
|
|
1100
|
+
return !(static_cast<unsigned short>(__x) == 0);
|
|
1101
|
+
}
|
|
1102
|
+
|
|
1103
|
+
/*! convert fp8 e5m2 to char, clamp out of bound values, return 0 if value is NaN */
|
|
1104
|
+
__FP8_HOST_DEVICE__ operator char() const {
|
|
1105
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1106
|
+
return 0;
|
|
1107
|
+
}
|
|
1108
|
+
|
|
1109
|
+
float fval = *this;
|
|
1110
|
+
auto llval = static_cast<long long>(fval);
|
|
1111
|
+
if (llval <= CHAR_MIN) {
|
|
1112
|
+
return CHAR_MIN;
|
|
1113
|
+
} else if (llval >= CHAR_MAX) {
|
|
1114
|
+
return CHAR_MAX;
|
|
1115
|
+
}
|
|
1116
|
+
return static_cast<char>(fval);
|
|
1117
|
+
}
|
|
1118
|
+
|
|
1119
|
+
/*! convert fp8 e5m2 to double */
|
|
1120
|
+
__FP8_HOST_DEVICE__ operator double() const {
|
|
1121
|
+
return internal::cast_from_f8<double, true>(__x, __wm, __we);
|
|
1122
|
+
}
|
|
1123
|
+
|
|
1124
|
+
/*! convert fp8 e5m2 to int, return 0 if value is NaN */
|
|
1125
|
+
__FP8_HOST_DEVICE__ operator int() const {
|
|
1126
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1127
|
+
return 0;
|
|
1128
|
+
}
|
|
1129
|
+
|
|
1130
|
+
float fval = *this;
|
|
1131
|
+
return static_cast<int>(fval);
|
|
1132
|
+
}
|
|
1133
|
+
|
|
1134
|
+
/*! convert fp8 e5m2 to long, return 0 if value is NaN */
|
|
1135
|
+
__FP8_HOST_DEVICE__ operator long int() const {
|
|
1136
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1137
|
+
return 0;
|
|
1138
|
+
}
|
|
1139
|
+
|
|
1140
|
+
float fval = *this;
|
|
1141
|
+
return static_cast<long>(fval);
|
|
1142
|
+
}
|
|
1143
|
+
|
|
1144
|
+
/*! convert fp8 e5m2 to long long, return 0 if value is NaN */
|
|
1145
|
+
__FP8_HOST_DEVICE__ operator long long int() const {
|
|
1146
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1147
|
+
return 0;
|
|
1148
|
+
}
|
|
1149
|
+
|
|
1150
|
+
float fval = *this;
|
|
1151
|
+
return static_cast<long long>(fval);
|
|
1152
|
+
}
|
|
1153
|
+
|
|
1154
|
+
/*! convert fp8 e5m2 to short, clamp out of bound values, return 0 if value is NaN */
|
|
1155
|
+
__FP8_HOST_DEVICE__ operator short int() const {
|
|
1156
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1157
|
+
return 0;
|
|
1158
|
+
}
|
|
1159
|
+
|
|
1160
|
+
float fval = *this;
|
|
1161
|
+
auto llval = static_cast<long long>(fval);
|
|
1162
|
+
if (llval <= SHRT_MIN) {
|
|
1163
|
+
return SHRT_MIN;
|
|
1164
|
+
} else if (llval >= SHRT_MAX) {
|
|
1165
|
+
return SHRT_MAX;
|
|
1166
|
+
}
|
|
1167
|
+
return static_cast<short>(fval);
|
|
1168
|
+
}
|
|
1169
|
+
|
|
1170
|
+
/*! convert fp8 e5m2 to signed char, clamp out of bound values, return 0 if value is NaN */
|
|
1171
|
+
__FP8_HOST_DEVICE__ operator signed char() const {
|
|
1172
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1173
|
+
return 0;
|
|
1174
|
+
}
|
|
1175
|
+
|
|
1176
|
+
float fval = *this;
|
|
1177
|
+
auto llval = static_cast<long long>(fval);
|
|
1178
|
+
if (llval <= SCHAR_MIN) {
|
|
1179
|
+
return SCHAR_MIN;
|
|
1180
|
+
} else if (llval >= SCHAR_MAX) {
|
|
1181
|
+
return SCHAR_MAX;
|
|
1182
|
+
}
|
|
1183
|
+
return static_cast<signed char>(fval);
|
|
1184
|
+
}
|
|
1185
|
+
|
|
1186
|
+
/*! convert fp8 e5m2 to unsigned char, clamp out of bound values, return 0 if value is NaN */
|
|
1187
|
+
__FP8_HOST_DEVICE__ operator unsigned char() const {
|
|
1188
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1189
|
+
return 0;
|
|
1190
|
+
}
|
|
1191
|
+
|
|
1192
|
+
float fval = *this;
|
|
1193
|
+
auto llval = static_cast<long long>(fval);
|
|
1194
|
+
if (llval <= 0) {
|
|
1195
|
+
return 0;
|
|
1196
|
+
} else if (llval >= UCHAR_MAX) {
|
|
1197
|
+
return UCHAR_MAX;
|
|
1198
|
+
}
|
|
1199
|
+
return static_cast<unsigned char>(fval);
|
|
1200
|
+
}
|
|
1201
|
+
|
|
1202
|
+
/*! convert fp8 e5m2 to unsigned int, return 0 if value is NaN */
|
|
1203
|
+
__FP8_HOST_DEVICE__ operator unsigned int() const {
|
|
1204
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1205
|
+
return 0;
|
|
1206
|
+
}
|
|
1207
|
+
|
|
1208
|
+
float fval = *this;
|
|
1209
|
+
auto llval = static_cast<long long>(fval);
|
|
1210
|
+
if (llval <= 0) {
|
|
1211
|
+
return 0;
|
|
1212
|
+
}
|
|
1213
|
+
return static_cast<unsigned int>(fval);
|
|
1214
|
+
}
|
|
1215
|
+
|
|
1216
|
+
/*! convert fp8 e5m2 to unsigned long, return 0 if value is NaN */
|
|
1217
|
+
__FP8_HOST_DEVICE__ operator unsigned long int() const {
|
|
1218
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1219
|
+
return 0;
|
|
1220
|
+
}
|
|
1221
|
+
|
|
1222
|
+
float fval = *this;
|
|
1223
|
+
auto llval = static_cast<long long>(fval);
|
|
1224
|
+
if (llval <= 0) {
|
|
1225
|
+
return 0;
|
|
1226
|
+
}
|
|
1227
|
+
return static_cast<unsigned long>(fval);
|
|
1228
|
+
}
|
|
1229
|
+
|
|
1230
|
+
/*! convert fp8 e5m2 to unsigned long long, return 0 if value is NaN */
|
|
1231
|
+
__FP8_HOST_DEVICE__ operator unsigned long long int() const {
|
|
1232
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1233
|
+
return 0;
|
|
1234
|
+
}
|
|
1235
|
+
|
|
1236
|
+
float fval = *this;
|
|
1237
|
+
auto llval = static_cast<long long>(fval);
|
|
1238
|
+
if (llval <= 0) {
|
|
1239
|
+
return 0;
|
|
1240
|
+
}
|
|
1241
|
+
return static_cast<unsigned long long>(fval);
|
|
1242
|
+
}
|
|
1243
|
+
|
|
1244
|
+
/*! convert fp8 e5m2 to unsigned short, return 0 if value is NaN */
|
|
1245
|
+
__FP8_HOST_DEVICE__ operator unsigned short int() const {
|
|
1246
|
+
if (internal::hip_fp8_fnuz_is_nan(__x)) {
|
|
1247
|
+
return 0;
|
|
1248
|
+
}
|
|
1249
|
+
|
|
1250
|
+
float fval = *this;
|
|
1251
|
+
auto llval = static_cast<long long>(fval);
|
|
1252
|
+
if (llval <= 0) {
|
|
1253
|
+
return 0;
|
|
1254
|
+
}
|
|
1255
|
+
return static_cast<unsigned short>(fval);
|
|
1256
|
+
}
|
|
1257
|
+
};
|
|
1258
|
+
|
|
1259
|
+
/**
|
|
1260
|
+
* \brief struct representing two fp8 numbers with e5m2 interpretation
|
|
1261
|
+
*
|
|
1262
|
+
*/
|
|
1263
|
+
struct __hip_fp8x2_e5m2_fnuz {
|
|
1264
|
+
__hip_fp8x2_storage_t __x; //! raw storage of two fp8 numbers
|
|
1265
|
+
static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
|
|
1266
|
+
static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
|
|
1267
|
+
static constexpr unsigned int __we = 5;
|
|
1268
|
+
static constexpr unsigned int __wm = 2;
|
|
1269
|
+
|
|
1270
|
+
/*! create fp8x2 e5m2 type from double2 */
|
|
1271
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val)
|
|
1272
|
+
: __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
|
|
1273
|
+
|
|
1274
|
+
/*! create fp8x2 e5m2 type from float2 */
|
|
1275
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val)
|
|
1276
|
+
: __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
|
|
1277
|
+
|
|
1278
|
+
/*! create fp8x2 e5m2 type from __hip_bfloat162 */
|
|
1279
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
|
|
1280
|
+
: __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
|
|
1281
|
+
|
|
1282
|
+
/*! create fp8x2 e5m2 type from __half2 */
|
|
1283
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
|
|
1284
|
+
: __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
|
|
1285
|
+
|
|
1286
|
+
/*! default construct fp8x2 e5m2 */
|
|
1287
|
+
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz() = default;
|
|
1288
|
+
|
|
1289
|
+
/*! convert fp8x2 e5m2 to __half2 */
|
|
1290
|
+
__FP8_HOST_DEVICE__ operator __half2() const {
|
|
1291
|
+
return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
|
|
1292
|
+
}
|
|
1293
|
+
|
|
1294
|
+
/*! convert fp8x2 e5m2 to float2 */
|
|
1295
|
+
__FP8_HOST_DEVICE__ operator float2() const {
|
|
1296
|
+
#if HIP_FP8_CVT_FAST_PATH
|
|
1297
|
+
return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
|
|
1298
|
+
#else
|
|
1299
|
+
return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
|
|
1300
|
+
__wm, __we),
|
|
1301
|
+
internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
|
|
1302
|
+
__wm, __we));
|
|
1303
|
+
#endif
|
|
1304
|
+
}
|
|
1305
|
+
};
|
|
1306
|
+
|
|
1307
|
+
/**
 * \brief struct representing four fp8 numbers with e5m2 interpretation
 *
 * The four fp8 lanes are OR-packed into a single 32-bit word; the shift
 * amounts below place lane x in bits 0-7, y in bits 8-15, z in bits 16-23,
 * and w in bits 24-31.
 */
struct __hip_fp8x4_e5m2_fnuz {
  __hip_fp8x4_storage_t __x; //! raw storage of four fp8 numbers
  // Conversion defaults shared by every constructor/operator of this struct.
  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
  static constexpr unsigned int __we = 5;  // e5m2: 5 exponent bits
  static constexpr unsigned int __wm = 2;  // e5m2: 2 mantissa bits

  /*! create fp8x4 e5m2 type from double4 */
  // Each double lane is converted to one fp8 byte, then the four bytes are
  // OR-packed (x lowest byte ... w highest byte).
  // NOTE(review): value reinterpret_casts here only compile if
  // __hip_fp8_storage_t / __hip_fp8x4_storage_t are byte/word typedefs —
  // confirm against the typedefs earlier in this header.
  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val)
      : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
            static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
                                          val.x, __default_saturation, __default_interpret)) |
                                      reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
                                          val.y, __default_saturation, __default_interpret))
                                          << 8 |
                                      reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
                                          val.z, __default_saturation, __default_interpret))
                                          << 16 |
                                      reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
                                          val.w, __default_saturation, __default_interpret))
                                          << 24))) {}

  /*! create fp8x4 e5m2 type from float4 */
  // Same byte-packing scheme as the double4 constructor, using the
  // single-precision conversion intrinsic.
  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val)
      : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
            static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
                                          val.x, __default_saturation, __default_interpret)) |
                                      reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
                                          val.y, __default_saturation, __default_interpret))
                                          << 8 |
                                      reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
                                          val.z, __default_saturation, __default_interpret))
                                          << 16 |
                                      reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
                                          val.w, __default_saturation, __default_interpret))
                                          << 24))) {}

  /*! create fp8x4 e5m2 type from two __hip_bfloat162 */
  // NOTE(review): `high` is packed into the LOW 16 bits and `low` into the
  // HIGH 16 bits, yet operator float4() below reads the low 16 bits as
  // lanes 0/1 — the pair looks swapped relative to the parameter names.
  // Confirm against the upstream ROCm header before relying on lane order.
  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
      : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
            reinterpret_cast<unsigned short>(
                __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
            reinterpret_cast<unsigned short>(
                __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
                << 16))) {}

  /*! create fp8x4 e5m2 type from two __half2 */
  // NOTE(review): same apparent low/high swap as the __hip_bfloat162 pair
  // constructor above — verify intended lane order.
  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
      : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
            static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
                                          high, __default_saturation, __default_interpret)) |
                                      reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
                                          low, __default_saturation, __default_interpret))
                                          << 16))) {}

  /* default construct fp8x4 e5m2; defaulted, so __x is presumably left
     unspecified until assigned */
  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz() = default;

  /*! convert fp8x4 e5m2 to float4 */
  __FP8_HOST_DEVICE__ operator float4() const {
    auto x = __x; // bypass const
    // Reinterpret the 32-bit storage as two 16-bit fp8x2 halves; per the
    // original "Little E" note this assumes little-endian layout, so the
    // first half holds lanes 0/1.
    auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
    auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
#if HIP_FP8_CVT_FAST_PATH
    // Hardware-assisted path: each helper call converts a packed fp8 pair.
    float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
    float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
#else
    // Portable path. (v << 8) >> 8 on the int-promoted 16-bit value keeps
    // the low byte (high byte after the final 8-bit truncation is dropped);
    // v >> 8 selects the high byte.
    float2 high = float2(internal::cast_from_f8<float, true>(
                             static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
                         internal::cast_from_f8<float, true>(
                             static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
    float2 low = float2(internal::cast_from_f8<float, true>(
                            static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
                        internal::cast_from_f8<float, true>(
                            static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
#endif
    // Lanes 0/1 come from the low half, lanes 2/3 from the high half.
    return float4(low.x, low.y, high.x, high.y);
  }
};
|
|
1390
|
+
|
|
1391
|
+
#endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
|