triton-windows 3.2.0.post12__cp39-cp39-win_amd64.whl → 3.3.0a0.post12__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (68) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1391 @@
1
+ /**
2
+ * MIT License
3
+ *
4
+ * Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. All rights reserved.
5
+ *
6
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ * of this software and associated documentation files (the "Software"), to deal
8
+ * in the Software without restriction, including without limitation the rights
9
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ * copies of the Software, and to permit persons to whom the Software is
11
+ * furnished to do so, subject to the following conditions:
12
+ *
13
+ * The above copyright notice and this permission notice shall be included in
14
+ * all copies or substantial portions of the Software.
15
+ *
16
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ * SOFTWARE.
23
+ */
24
+
25
+ /**
26
+ * \file
27
+ * \brief amd_hip_fp8.h header, for AMD fp8 data types
28
+ */
29
+
30
+ #ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
31
+ #define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
32
+
33
+ #if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && __HIP_DEVICE_COMPILE__
34
+ #define HIP_FP8_CVT_FAST_PATH 1
35
+ #else
36
+ #define HIP_FP8_CVT_FAST_PATH 0
37
+ #endif
38
+
39
+ #if !defined(__HIPCC_RTC__)
40
+ #include <hip/amd_detail/amd_hip_common.h>
41
+ #include <climits>
42
+
43
+ #include "host_defines.h" // __hip_internal::
44
+ #include "amd_hip_vector_types.h" // float2 etc
45
+ #include "amd_hip_fp16.h" // __half_raw
46
+ #include "amd_hip_bf16.h" // bf16
47
+ #include "math_fwd.h" // ocml device functions
48
+ #endif // !defined(__HIPCC_RTC__)
49
+
50
+ #if defined(__HIPCC_RTC__)
51
+ #define __FP8_HOST_DEVICE__ __device__
52
+ #define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static
53
+ #else
54
+ #define __FP8_HOST_DEVICE__ __host__ __device__
55
+ #define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static inline
56
+ #endif // __HIPCC_RTC__
57
+
58
+ #if !defined(__HIPCC_RTC__)
59
+ static_assert(CHAR_BIT == 8, "byte size should be of 8 bits");
60
+ #endif
61
+ static_assert(sizeof(unsigned char) == 1);
62
+ static_assert(sizeof(unsigned short int) == 2);
63
+ static_assert(sizeof(unsigned int) == 4);
64
+
65
/**
 * \brief Describes FP8 interpretation
 *
 * Selects which 8-bit floating-point encoding a raw byte is interpreted as.
 * Both are the "fnuz" (finite + NaN, unsigned zero) variants used by gfx94x.
 */
enum __hip_fp8_interpretation_t {
  __HIP_E4M3_FNUZ = 0, /**< Standard FP8: 4 exponent bits, 3 mantissa bits */
  __HIP_E5M2_FNUZ = 1, /**< BF8: 5 exponent bits, 2 mantissa bits */
};

/**
 * \brief Describes saturation behavior
 *
 * Controls what happens when a value is too large for the target fp8 format.
 */
enum __hip_saturation_t {
  __HIP_NOSAT = 0,     /**< No saturation: out-of-range values become NaN/Inf encoding */
  __HIP_SATFINITE = 1, /**< Saturate to the largest finite representable value */
};

/** \typedef __hip_fp8_storage_t
 *
 * \brief type to store single fp8 number
 */
typedef unsigned char __hip_fp8_storage_t;


/** \typedef __hip_fp8x2_storage_t
 *
 * \brief type to store two fp8 numbers (packed low byte = element 0)
 */
typedef unsigned short int __hip_fp8x2_storage_t;


/** \typedef __hip_fp8x4_storage_t
 *
 * \brief type to store four fp8 numbers
 */
typedef unsigned int __hip_fp8x4_storage_t;
101
+ namespace internal {
102
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
// This has been modified to add double types conversion as well
//
// Software round-to-nearest-even (or stochastic) conversion of a half/float/double
// into an 8-bit float with `we` exponent bits and `wm` mantissa bits.
//
// Template parameters:
//   T                 - source type: _Float16, float, or double (enforced below).
//   negative_zero_nan - true for the "fnuz" encodings: 0x80 is NaN, zero is
//                       unsigned, and Inf is not representable.
// Parameters:
//   _x    - value to convert.
//   wm/we - mantissa / exponent bit widths of the target fp8 format.
//   clip  - if true, out-of-range values saturate to the largest finite value
//           instead of returning the Inf/NaN encoding.
//   stoch - if true, use stochastic rounding driven by `rng`; otherwise RNE.
//   rng   - random bits used only when stoch is true.
// Returns the raw fp8 byte.
template <typename T, bool negative_zero_nan>
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false,
                                                          bool stoch = false,
                                                          unsigned int rng = 0) {
  constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
  constexpr bool is_float = __hip_internal::is_same<T, float>::value;
  constexpr bool is_double = __hip_internal::is_same<T, double>::value;
  static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8");

  // Mantissa width of the source format (double: 52, float: 23, half: 10).
  const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
  unsigned long long x;

  // Reinterpret the source value as raw bits of matching width.
  if (sizeof(T) == 8)
    x = reinterpret_cast<unsigned long long&>(_x);
  else if (sizeof(T) == 4)
    x = reinterpret_cast<unsigned int&>(_x);
  else
    x = reinterpret_cast<unsigned short int&>(_x);


  unsigned long long head, mantissa;
  int exponent, bias;
  unsigned int sign;

  // Split the source bits into sign/exponent/mantissa fields.
  if (sizeof(T) == 8) {
    head = x & 0xFFF0000000000000ull;
    mantissa = x & 0xFFFFFFFFFFFFFull;
    exponent = (head >> 52) & 0x7FF;
    sign = head >> 63;
    bias = 1023;
  } else if (sizeof(T) == 4) {
    head = x & 0xFF800000;
    mantissa = x & 0x7FFFFF;
    exponent = (head >> 23) & 0xFF;
    sign = head >> 31;
    bias = 127;
  } else {
    head = x & 0xFC00;
    mantissa = x & 0x3FF;
    exponent = (head >> 10) & 0x1F;
    sign = head >> 15;
    bias = 15;
  }

  // fp8 bit pattern with the sign of x and all exponent bits set (Inf encoding
  // for IEEE-style mode; NaN is this +1).
  unsigned int signed_inf = (sign << 7) + (((1 << we) - 1) << wm);

  // Deal with inf and NaNs
  if (negative_zero_nan) {
    // fnuz mode: all non-finite inputs collapse to the single NaN byte 0x80.
    if (sizeof(T) == 8) {
      if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull) return 0x80;
    } else if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) return 0x80;
    } else {
      if ((x & 0x7C00) == 0x7C00) return 0x80;
    }
  } else {
    // IEEE-style mode: preserve Inf vs NaN (NaN = Inf encoding + 1).
    if (sizeof(T) == 8) {
      if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull)
        return signed_inf + (mantissa != 0 ? 1 : 0);
    } else if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) return signed_inf + (mantissa != 0 ? 1 : 0);
    } else {
      if ((x & 0x7C00) == 0x7C00) return signed_inf + (mantissa != 0 ? 1 : 0);
    }
  }

  if (x == 0) {
    return 0;
  }

  // First need to check if it is normal or denorm as there is a difference of implicit 1
  // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
  // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
  // RNE, no need to add rng. Then probably need to check whether there is carry and adjust
  // exponent and mantissa again

  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
  const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
  const int f8_denormal_act_exponent = 1 - f8_bias;  // actual exponent of f8 denormal
  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
  // f8_exponent is the converted f8 exponent with bias encoding
  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
  // the difference needs to be adjusted and mantissa shifted
  int act_exponent, f8_exponent, exponent_diff;

  if (exponent == 0) {  // fp32/fp16 is in denormal.
    /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
    here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
    exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
    fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
    where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In
    this case, the fp16 mantissa should be shift left by 1 */
    act_exponent = exponent - bias + 1;
    exponent_diff = f8_denormal_act_exponent -
        act_exponent;  // actual exponent is exponent-bias+1 as it is denormal
  } else {             // fp32/fp16 is normal with implicit 1
    act_exponent = exponent - bias;
    if (act_exponent <= f8_denormal_act_exponent) {
      /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
      For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
      actual exponent is -7, it is actually larger due to the implicit 1,
      Therefore it needs to be adjust to -6 and mantissa shift right by 1.
      So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
      exponent_diff = f8_denormal_act_exponent - act_exponent;
    } else {             // both fp32/fp16 and f8 are in normal range
      exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
                         // act_exponent could be larger. Just that it does not need shift mantissa
    }
    mantissa += (1ull << mfmt);  // Add the implicit 1 into mantissa
  }

  // Record whether the bits to be dropped are exactly half an ULP (a tie).
  bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
      (1ull << (mfmt - wm + exponent_diff - 1));
  /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift
  right as shift right could rip off some residual part and make something not midpoint look like
  midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than midpoint, but
  after shift right by 4 bits, it would look like midpoint.
  */

  if (exponent_diff > 0)
    mantissa >>= exponent_diff;
  else if (exponent_diff == -1)
    // NOTE(review): only a diff of exactly -1 is handled here (a 1-bit left
    // shift). This matches the upstream rocBLAS implementation; presumably
    // larger negative diffs cannot occur for the supported wm/we — confirm.
    mantissa <<= -exponent_diff;
  bool implicit_one = mantissa & (1ull << mfmt);
  // if there is no implicit 1, it means the f8 is denormal and need to adjust to denorm exponent
  f8_exponent =
      (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);

  // Now we have the exponent and mantissa adjusted
  unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
  bool odd =
      mantissa & (1ull << (mfmt - wm));  // if the least significant bit that is not truncated is 1
  // Stochastic: add rng bits; RNE: round ties to even, otherwise to nearest.
  mantissa +=
      (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;

  // Now we deal with overflow
  if (f8_exponent == 0) {
    if ((1ull << mfmt) & mantissa) {
      f8_exponent = 1;  // denormal overflow to become normal, promote exponent
    }
  } else {
    if ((1ull << (mfmt + 1)) & mantissa) {
      mantissa >>= 1;
      f8_exponent++;
    }
  }

  mantissa >>= (mfmt - wm);

  // above range: quantize to maximum possible float of the same sign
  const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
  if (f8_exponent > max_exp) {
    if (clip) {
      mantissa = (1 << wm) - 1;
      f8_exponent = max_exp;
    } else {
      return signed_inf;
    }
  }

  // Zero result: fnuz has a single unsigned zero; IEEE-style keeps the sign.
  if (f8_exponent == 0 && mantissa == 0) return negative_zero_nan ? 0 : (sign << 7);
  mantissa &= (1 << wm) - 1;
  return (sign << 7) | (f8_exponent << wm) | mantissa;
}
269
+
270
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well
//
// Software widening conversion of a raw fp8 byte (we exponent / wm mantissa
// bits) into T = _Float16, float, or double.
//
// Template parameters:
//   T                 - destination type (enforced by static_assert below).
//   negative_zero_nan - true for "fnuz" encodings: 0x80 decodes to NaN;
//                       false for IEEE-style: 0x80 decodes to -0 and the
//                       all-ones exponent encodes Inf/NaN.
// Returns the decoded value.
template <typename T, bool negative_zero_nan>
__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we) {
  constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
  constexpr bool is_float = __hip_internal::is_same<T, float>::value;
  constexpr bool is_double = __hip_internal::is_same<T, double>::value;
  static_assert(is_half || is_float || is_double, "only half, float and double are supported");

  // Exponent / mantissa widths of the output format.
  constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);

  // Pre-built special values of the output type (bit patterns reinterpreted).
  T fInf, fNegInf, fNaN, fNeg0;
  if (is_half) {
    const unsigned short int ihInf = 0x7C00;
    const unsigned short int ihNegInf = 0xFC00;
    const unsigned short int ihNaN = 0x7C01;
    const unsigned short int ihNeg0 = 0x8000;
    fInf = reinterpret_cast<const _Float16&>(ihInf);
    fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
    fNaN = reinterpret_cast<const _Float16&>(ihNaN);
    fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
  } else if (is_float) {
    const unsigned int ifInf = 0x7F800000;
    const unsigned int ifNegInf = 0xFF800000;
    const unsigned int ifNaN = 0x7F800001;
    const unsigned int ifNeg0 = 0x80000000;
    fInf = reinterpret_cast<const float&>(ifInf);
    fNegInf = reinterpret_cast<const float&>(ifNegInf);
    fNaN = reinterpret_cast<const float&>(ifNaN);
    fNeg0 = reinterpret_cast<const float&>(ifNeg0);
  } else if (is_double) {
    const unsigned long long ifInf = 0x7FF0000000000000ull;
    const unsigned long long ifNegInf = 0xFFF0000000000000ull;
    const unsigned long long ifNaN = 0x7FF0000000000001ull;
    const unsigned long long ifNeg0 = 0x8000000000000000ull;
    fInf = reinterpret_cast<const double&>(ifInf);
    fNegInf = reinterpret_cast<const double&>(ifNegInf);
    fNaN = reinterpret_cast<const double&>(ifNaN);
    fNeg0 = reinterpret_cast<const double&>(ifNeg0);
  }

  if (x == 0) {
    return 0;
  }

  // Decompose the fp8 byte into sign / mantissa / biased exponent.
  unsigned long long sign = x >> 7;
  unsigned long long mantissa = x & ((1 << wm) - 1);
  int exponent = (x & 0x7F) >> wm;
  if (negative_zero_nan) {
    if (x == 0x80) return fNaN;
  } else {
    if (x == 0x80) return fNeg0;
    if (exponent == ((1 << we) - 1)) return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
  }

  // Unsigned integer type matching the width of T, used to assemble the result.
  typename __hip_internal::conditional<
      sizeof(T) == 2, unsigned short int,
      typename __hip_internal::conditional<sizeof(T) == 4, unsigned int,
                                           unsigned long long>::type>::type retval;

  // IEEE e5m2 -> half is a pure bit widening (same exponent layout).
  if (we == 5 && is_half && !negative_zero_nan) {
    retval = x << 8;
    return reinterpret_cast<const T&>(retval);
  }

  // Bias adjustment between the fp8 exponent and the output exponent.
  const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

  // subnormal input
  if (exponent == 0) {
#if __HIP_DEVICE_COMPILE__
    // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
    int sh = 1 + __clz(mantissa) - (32 - wm);
#else
    int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
#endif
    // Normalize: shift the leading 1 out of the stored mantissa bits.
    mantissa <<= sh;
    exponent += 1 - sh;
    mantissa &= ((1ull << wm) - 1);
  }
  exponent += exp_low_cutoff - 1;
  mantissa <<= wmo - wm;

  // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
  if (exponent <= 0) {
    mantissa |= 1 << wmo;
    mantissa >>= 1 - exponent;
    exponent = 0;
  }

  // Reassemble sign/exponent/mantissa in the output format's bit layout.
  if (sizeof(T) == 2)
    retval = (sign << 15) | (exponent << 10) | mantissa;
  else if (sizeof(T) == 4)
    retval = (sign << 31) | (exponent << 23) | mantissa;
  else
    retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
  return reinterpret_cast<const T&>(retval);
}
369
+
370
+ #if HIP_FP8_CVT_FAST_PATH
371
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
//
// Hardware fast-path conversion (gfx94x only) of a single float to one fp8
// byte using the v_cvt_* instructions.
//   v         - value to convert.
//   saturate  - if true, finite inputs are clamped to the largest finite value
//               of the target format before conversion (±240 for E4M3 FNUZ,
//               ±57344 for E5M2 FNUZ); NaN/Inf are passed through unclamped.
//   interpret - target fp8 format.
//   rng       - random bits for stochastic rounding (used only when the
//               stochastic_rounding template parameter is true).
template <bool stochastic_rounding = false>
static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate,
                                                          __hip_fp8_interpretation_t interpret,
                                                          unsigned int rng = 0) {
  __hip_fp8_storage_t i8data;
  union {
    float fval;
    unsigned int i32val;
    unsigned char i8val[4];  // NOTE: not endian independent
  } val;

  unsigned int ival = 0;
  val.fval = v;

  if (saturate) {
    if (interpret == __HIP_E4M3_FNUZ) {
      if ((val.i32val & 0x7F800000) != 0x7F800000) {  /// propagate NAN/INF, no clipping
        // fmed3(v, max, -max) clamps v into [-max, max].
        val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
      }
    } else {
      if ((val.i32val & 0x7F800000) != 0x7F800000) {  /// propagate NAN/INF, no clipping
        val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
      }
    }
  }

  if (stochastic_rounding) {
    // Stochastic-rounding variant: rng supplies the rounding randomness.
    ival = interpret == __HIP_E4M3_FNUZ
        ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
        : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);  // 0 pos
    val.i32val = ival;
    i8data = val.i8val[0];  // little endian
  } else {  // RNE CVT
    ival = interpret == __HIP_E4M3_FNUZ
        ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
        : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false);  // false -> WORD0
    val.i32val = ival;
    i8data = val.i8val[0];
  }
  return i8data;
}
414
+
415
+ static __device__ __hip_fp8x2_storage_t
416
+ cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) {
417
+ union {
418
+ static_assert(sizeof(float2) == sizeof(unsigned int[2]));
419
+ static_assert(sizeof(float2) == sizeof(unsigned short[4]));
420
+ float2 fval;
421
+ unsigned int i32val[2];
422
+ unsigned short i16val[4];
423
+ } f2val;
424
+
425
+ f2val.fval = v;
426
+
427
+ if (saturate) { /// propagate NAN/INF, no clipping
428
+ if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
429
+ f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0);
430
+ }
431
+ if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
432
+ f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0);
433
+ }
434
+ }
435
+
436
+ f2val.i32val[0] = interpret == __HIP_E4M3_FNUZ
437
+ ? __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, 0, false)
438
+ : __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.y, 0, false);
439
+
440
+ return static_cast<__hip_fp8x2_storage_t>(f2val.i16val[0]);
441
+ }
442
+
443
+ static __device__ float cast_to_f32_from_f8(__hip_fp8_storage_t v,
444
+ __hip_fp8_interpretation_t interpret) {
445
+ union {
446
+ unsigned int i32val;
447
+ unsigned char i8val[4];
448
+ } val;
449
+ val.i8val[0] = v;
450
+
451
+ float fval = interpret == __HIP_E4M3_FNUZ ? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0)
452
+ : __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
453
+ return fval;
454
+ }
455
+
456
+ static __device__ float2 cast_to_f32x2_from_f8x2(__hip_fp8x2_storage_t v,
457
+ __hip_fp8_interpretation_t interpret) {
458
+ union {
459
+ unsigned int i32val;
460
+ unsigned short i16val[2];
461
+ } val;
462
+ val.i16val[0] = v;
463
+
464
+ auto f2 = interpret == __HIP_E4M3_FNUZ ? __builtin_amdgcn_cvt_pk_f32_fp8(val.i32val, false)
465
+ : __builtin_amdgcn_cvt_pk_f32_bf8(val.i32val, false);
466
+ return float2{f2[0], f2[1]};
467
+ }
468
+ #endif // HIP_FP8_CVT_FAST_PATH
469
+
470
+ /* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
471
+ Inf are not supported. This gives us one additional number to represent.
472
+ NaN are represented by 1-0000-000 or 1-00000-00 */
473
+ __FP8_HOST_DEVICE_STATIC__ bool hip_fp8_fnuz_is_nan(__hip_fp8_storage_t a) {
474
+ return static_cast<unsigned char>(a) == 0x80;
475
+ }
476
+ } // namespace internal
477
+
478
/**
 * \brief convert float to @p __hip_fp8_storage_t
 *
 * \param f float number
 * \param sat saturation of fp8 (__HIP_SATFINITE clamps to the largest finite value)
 * \param type interpretation of fp8 (E4M3 or E5M2, both fnuz)
 * \return __hip_fp8_storage_t
 */
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(
    const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
#if HIP_FP8_CVT_FAST_PATH
  // gfx94x device compile: use the hardware conversion instructions (RNE).
  return internal::cast_to_f8_from_f32<false>(f, sat == __HIP_SATFINITE, type);
#else  // HIP_FP8_CVT_FAST_PATH
  // Software path: E4M3 -> 4 exponent / 3 mantissa bits, E5M2 -> 5 / 2.
  int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
  int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
  return internal::cast_to_f8<float, true>(f, wm, we, sat == __HIP_SATFINITE);
#endif  // HIP_FP8_CVT_FAST_PATH
}
496
+
497
/**
 * \brief convert float2 to @p __hip_fp8x2_storage_t
 *
 * \param f2 float2 number
 * \param sat saturation of fp8
 * \param type interpretation of fp8
 * \return __hip_fp8x2_storage_t (f2.x in the low byte, f2.y in the high byte)
 */
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2(
    const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
#if HIP_FP8_CVT_FAST_PATH
  // gfx94x device compile: single packed hardware conversion.
  return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, type);
#else
  // Software path: convert each lane independently and pack y into the high byte.
  return static_cast<__hip_fp8x2_storage_t>(
      static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.y, sat, type)) << 8 |
      static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.x, sat, type)));
#endif
}
515
+
516
+ /**
517
+ * \brief convert double to @p __hip_fp8_storage_t
518
+ *
519
+ * \param d double val
520
+ * \param sat saturation of fp8
521
+ * \param type interpretation of fp8
522
+ * \return __hip_fp8_storage_t
523
+ */
524
+ __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(
525
+ const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
526
+ int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
527
+ int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
528
+ return internal::cast_to_f8<double, true>(d, wm, we, sat == __HIP_SATFINITE);
529
+ }
530
+
531
+ /**
532
+ * \brief convert double2 to @p __hip_fp8x2_storage_t
533
+ *
534
+ * \param d2 double2 val
535
+ * \param sat saturation of fp8
536
+ * \param type interpretation of fp8
537
+ * \return __hip_fp8x2_storage_t
538
+ */
539
+ __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2(
540
+ const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
541
+ return static_cast<__hip_fp8x2_storage_t>(
542
+ static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.y, sat, type)) << 8 |
543
+ static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.x, sat, type)));
544
+ }
545
+
546
+ /**
547
+ * \brief convert __hip_bfloat16_raw to @p __hip_fp8_storage_t
548
+ *
549
+ * \param hr __hip_bfloat16_raw val
550
+ * \param sat saturation of fp8
551
+ * \param type interpretation of fp8
552
+ * \return __hip_fp8_storage_t
553
+ */
554
+ __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t
555
+ __hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat,
556
+ const __hip_fp8_interpretation_t type) {
557
+ float fval = __hip_bfloat16(hr);
558
+ return __hip_cvt_float_to_fp8(fval, sat, type);
559
+ }
560
+
561
+ /**
562
+ * \brief convert double2 to @p __hip_fp8x2_storage_t
563
+ *
564
+ * \param hr __hip_bfloat162_raw value
565
+ * \param sat saturation of fp8
566
+ * \param type interpretation of fp8
567
+ * \return __hip_fp8x2_storage_t
568
+ */
569
+ __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t
570
+ __hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat,
571
+ const __hip_fp8_interpretation_t type) {
572
+ float2 f2 = __hip_bfloat162(hr);
573
+ return __hip_cvt_float2_to_fp8x2(f2, sat, type);
574
+ }
575
+
576
+ /**
577
+ * \brief convert @p __hip_fp8_storage_t to __half_raw
578
+ *
579
+ * \param x __hip_fp8_storage_t val
580
+ * \param type interpretation of fp8
581
+ * \return __half_raw
582
+ */
583
+ __FP8_HOST_DEVICE_STATIC__ __half_raw
584
+ __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t type) {
585
+ unsigned int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
586
+ unsigned int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
587
+ return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)};
588
+ }
589
+
590
+ /**
591
+ * \brief convert @p __hip_fp8x2_storage_t to __half2_raw
592
+ *
593
+ * \param x __hip_fp8x2_storage_t val
594
+ * \param type interpretation of fp8
595
+ * \return __half2_raw
596
+ */
597
+ __FP8_HOST_DEVICE_STATIC__ __half2_raw
598
+ __hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t type) {
599
+ __half2 ret(static_cast<__half>(
600
+ __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x & 0xFF), type)),
601
+ static_cast<__half>(
602
+ __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x >> 8), type)));
603
+ return static_cast<__half2_raw>(ret);
604
+ }
605
+
606
+ /**
607
+ * \brief convert __half_raw to @p __hip_fp8_storage_t
608
+ *
609
+ * \param x __half_raw value
610
+ * \param sat saturation of fp8
611
+ * \param type interpretation of fp8
612
+ * \return __hip_fp8_storage_t
613
+ */
614
+ __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8(
615
+ const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
616
+ return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, type);
617
+ }
618
+
619
+ /**
620
+ * \brief convert __half2_raw to @p __hip_fp8x2_storage_t
621
+ *
622
+ * \param x __half2_raw value
623
+ * \param sat saturation of fp8
624
+ * \param type interpretation of fp8
625
+ * \return __hip_fp8x2_storage_t
626
+ */
627
+ __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(
628
+ const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
629
+ return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, type);
630
+ }
631
+
632
+ /**
633
+ * \brief struct representing single fp8 number with e4m3 interpretation
634
+ *
635
+ */
636
+ struct __hip_fp8_e4m3_fnuz {
637
+ __hip_fp8_storage_t __x; //! raw storage of fp8 number
638
+ constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
639
+ constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
640
+ constexpr static unsigned int __we = 4;
641
+ constexpr static unsigned int __wm = 3;
642
+
643
+ // TODO: SWDEV-452411
644
+ // Add cast from unsigned long long, long long to fp8
645
+
646
+ /*! create fp8 e4m3 from long */
647
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val)
648
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
649
+ __default_interpret)) {}
650
+
651
+ /*! create fp8 e4m3 from int */
652
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val)
653
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
654
+ __default_interpret)) {}
655
+
656
+ /*! create fp8 e4m3 from short int */
657
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val)
658
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
659
+ __default_interpret)) {}
660
+
661
+ /*! create fp8 e4m3 from unsigned long */
662
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
663
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
664
+ __default_interpret)) {}
665
+
666
+ /*! create fp8 e4m3 from unsigned int */
667
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val)
668
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
669
+ __default_interpret)) {}
670
+
671
+ /*! create fp8 e4m3 from unsigned short */
672
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
673
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
674
+ __default_interpret)) {}
675
+
676
+ /*! create fp8 e4m3 from double */
677
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f)
678
+ : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
679
+
680
+ /*! create fp8 e4m3 from float */
681
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f)
682
+ : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
683
+
684
+ /*! create fp8 e4m3 from __hip_bfloat16 */
685
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
686
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
687
+ __default_interpret)) {}
688
+
689
+ /*! create fp8 e4m3 from __half */
690
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f)
691
+ : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
692
+ __default_interpret)) {}
693
+
694
+ /*! default construct fp8 e4m3 */
695
+ __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz() = default;
696
+
697
+ /*! convert fp8 e4m3 to __half */
698
+ __FP8_HOST_DEVICE__ operator __half() const {
699
+ return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
700
+ }
701
+
702
+ /*! convert fp8 e4m3 to __hip_bfloat16 */
703
+ __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
704
+ float f = *this;
705
+ return __hip_bfloat16(f);
706
+ }
707
+
708
+ /*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */
709
+ __FP8_HOST_DEVICE__ operator bool() const {
710
+ // it can be 0x00 (+0.0) since 0x80 will be nan
711
+ return !(static_cast<unsigned short>(__x) == 0);
712
+ }
713
+
714
+ /*! convert fp8 e4m3 to char, clamp number to CHAR_MIN/CHAR_MAX if its out of range */
715
+ __FP8_HOST_DEVICE__ operator char() const {
716
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
717
+ return 0;
718
+ }
719
+
720
+ auto fval = internal::cast_from_f8<float, true>(__x, __wm, __we);
721
+ auto llval = static_cast<long long>(fval);
722
+ if (llval <= CHAR_MIN) {
723
+ return CHAR_MIN;
724
+ } else if (llval >= CHAR_MAX) {
725
+ return CHAR_MAX;
726
+ }
727
+ return static_cast<char>(fval);
728
+ }
729
+
730
+ /*! convert fp8 e4m3 to double */
731
+ __FP8_HOST_DEVICE__ operator double() const {
732
+ return internal::cast_from_f8<double, true>(__x, __wm, __we);
733
+ }
734
+
735
+ /*! convert fp8 e4m3 to float */
736
+ __FP8_HOST_DEVICE__ operator float() const {
737
+ #if HIP_FP8_CVT_FAST_PATH
738
+ return internal::cast_to_f32_from_f8(__x, __default_interpret);
739
+ #else
740
+ return internal::cast_from_f8<float, true>(__x, __wm, __we);
741
+ #endif
742
+ }
743
+
744
+ /*! convert fp8 e4m3 to int, return 0 if value is NaN */
745
+ __FP8_HOST_DEVICE__ operator int() const {
746
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
747
+ return 0;
748
+ }
749
+
750
+ float fval = *this;
751
+ return static_cast<int>(fval);
752
+ }
753
+
754
+ /*! convert fp8 e4m3 to long, return 0 if value is NaN */
755
+ __FP8_HOST_DEVICE__ operator long int() const {
756
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
757
+ return 0;
758
+ }
759
+
760
+ float fval = *this;
761
+ return static_cast<long>(fval);
762
+ }
763
+
764
+ /*! convert fp8 e4m3 to long long, return 0 if value is NaN */
765
+ __FP8_HOST_DEVICE__ operator long long int() const {
766
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
767
+ return 0;
768
+ }
769
+
770
+ float fval = *this;
771
+ return static_cast<long long>(fval);
772
+ }
773
+
774
+ /*! convert fp8 e4m3 to short int, clamp out of bound values, return 0 if value is NaN */
775
+ __FP8_HOST_DEVICE__ operator short int() const {
776
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
777
+ return 0;
778
+ }
779
+
780
+ float fval = *this;
781
+ auto llval = static_cast<long long>(fval);
782
+ if (llval <= SHRT_MIN) {
783
+ return SHRT_MIN;
784
+ } else if (llval >= SHRT_MAX) {
785
+ return SHRT_MAX;
786
+ }
787
+ return static_cast<short>(fval);
788
+ }
789
+
790
+ /*! convert fp8 e4m3 to signed char, clamp out of bound values, return 0 if value is NaN */
791
+ __FP8_HOST_DEVICE__ operator signed char() const {
792
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
793
+ return 0;
794
+ }
795
+
796
+ float fval = *this;
797
+ auto llval = static_cast<long long>(fval);
798
+ if (llval <= SCHAR_MIN) {
799
+ return SCHAR_MIN;
800
+ } else if (llval >= SCHAR_MAX) {
801
+ return SCHAR_MAX;
802
+ }
803
+ return static_cast<signed char>(fval);
804
+ }
805
+
806
+ /*! convert fp8 e4m3 to unsigned char, clamp out of bound values, return 0 if value is NaN */
807
+ __FP8_HOST_DEVICE__ operator unsigned char() const {
808
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
809
+ return 0;
810
+ }
811
+
812
+ float fval = *this;
813
+ auto llval = static_cast<long long>(fval);
814
+ if (llval <= 0) {
815
+ return 0;
816
+ } else if (llval >= UCHAR_MAX) {
817
+ return UCHAR_MAX;
818
+ }
819
+ return static_cast<unsigned char>(fval);
820
+ }
821
+
822
+ /*! convert fp8 e4m3 to unsigned int, return 0 if value is NaN */
823
+ __FP8_HOST_DEVICE__ operator unsigned int() const {
824
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
825
+ return 0;
826
+ }
827
+
828
+ float fval = *this;
829
+ auto llval = static_cast<long long>(fval);
830
+ if (llval <= 0) {
831
+ return 0;
832
+ }
833
+ return static_cast<unsigned int>(fval);
834
+ }
835
+
836
+ /*! convert fp8 e4m3 to unsigned long, return 0 if value is NaN */
837
+ __FP8_HOST_DEVICE__ operator unsigned long int() const {
838
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
839
+ return 0;
840
+ }
841
+
842
+ float fval = *this;
843
+ auto llval = static_cast<long long>(fval);
844
+ if (llval <= 0) {
845
+ return 0;
846
+ }
847
+ return static_cast<unsigned long>(fval);
848
+ }
849
+
850
+ /*! convert fp8 e4m3 to long long int, return 0 if value is NaN */
851
+ __FP8_HOST_DEVICE__ operator unsigned long long int() const {
852
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
853
+ return 0;
854
+ }
855
+
856
+ float fval = *this;
857
+ auto llval = static_cast<long long>(fval);
858
+ if (llval <= 0) {
859
+ return 0;
860
+ }
861
+ return static_cast<unsigned long long>(fval);
862
+ }
863
+
864
+ /*! convert fp8 e4m3 to unsigned short, return 0 if value is NaN */
865
+ __FP8_HOST_DEVICE__ operator unsigned short int() const {
866
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
867
+ return 0;
868
+ }
869
+
870
+ float fval = *this;
871
+ auto llval = static_cast<long long>(fval);
872
+ if (llval <= 0) {
873
+ return 0;
874
+ }
875
+ return static_cast<unsigned short>(fval);
876
+ }
877
+ };
878
+
879
+ /**
880
+ * \brief struct representing two fp8 numbers with e4m3 interpretation
881
+ *
882
+ */
883
+ struct __hip_fp8x2_e4m3_fnuz {
884
+ __hip_fp8x2_storage_t __x; //! raw storage of two fp8 numbers
885
+ static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
886
+ static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
887
+ static constexpr unsigned int __we = 4;
888
+ static constexpr unsigned int __wm = 3;
889
+
890
+ /*! create fp8x2 e4m3 type from double2 */
891
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val)
892
+ : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
893
+
894
+ /*! create fp8x2 e4m3 type from float2 */
895
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val)
896
+ : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
897
+
898
+ /*! create fp8x2 e4m3 type from __hip_bfloat162 */
899
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
900
+ : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
901
+
902
+ /*! create fp8x2 e4m3 type from __half2 */
903
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
904
+ : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
905
+
906
+ /*! Default construct of fp8x2 e4m3 */
907
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz() = default;
908
+
909
+ /*! convert fp8x2 e4m3 to __half2 */
910
+ __FP8_HOST_DEVICE__ operator __half2() const {
911
+ return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
912
+ }
913
+
914
+ /*! convert fp8x2 e4m3 to float2 */
915
+ __FP8_HOST_DEVICE__ operator float2() const {
916
+ #if HIP_FP8_CVT_FAST_PATH
917
+ return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
918
+ #else
919
+ return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
920
+ __wm, __we),
921
+ internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
922
+ __wm, __we));
923
+ #endif
924
+ }
925
+ };
926
+
927
+ /**
928
+ * \brief struct representing four fp8 numbers with e4m3 interpretation
929
+ *
930
+ */
931
+ struct __hip_fp8x4_e4m3_fnuz {
932
+ __hip_fp8x4_storage_t __x; //! raw storage of four fp8 numbers
933
+ static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
934
+ static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
935
+ static constexpr unsigned int __we = 4;
936
+ static constexpr unsigned int __wm = 3;
937
+
938
+ /*! create fp8x4 e4m3 type from double4 */
939
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val)
940
+ : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
941
+ static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
942
+ val.x, __default_saturation, __default_interpret)) |
943
+ reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
944
+ val.y, __default_saturation, __default_interpret))
945
+ << 8 |
946
+ reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
947
+ val.z, __default_saturation, __default_interpret))
948
+ << 16 |
949
+ reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
950
+ val.w, __default_saturation, __default_interpret))
951
+ << 24))} {}
952
+
953
+ /*! create fp8x4 e4m3 type from float4 */
954
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val)
955
+ : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
956
+ static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
957
+ val.x, __default_saturation, __default_interpret)) |
958
+ reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
959
+ val.y, __default_saturation, __default_interpret))
960
+ << 8 |
961
+ reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
962
+ val.z, __default_saturation, __default_interpret))
963
+ << 16 |
964
+ reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
965
+ val.w, __default_saturation, __default_interpret))
966
+ << 24))} {}
967
+
968
+ /*! create fp8x4 e4m3 type from two __hip_bfloat162 */
969
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
970
+ : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
971
+ reinterpret_cast<unsigned short>(
972
+ __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
973
+ reinterpret_cast<unsigned short>(
974
+ __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
975
+ << 16))) {}
976
+
977
+ /*! create fp8x4 e4m3 type from two __half2 */
978
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
979
+ : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
980
+ static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
981
+ high, __default_saturation, __default_interpret)) |
982
+ reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
983
+ low, __default_saturation, __default_interpret))
984
+ << 16))) {}
985
+
986
+ /*! Default construct fp8x4 e4m3 */
987
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz() = default;
988
+
989
+ /*! convert fp8x4 e4m3 to float4 */
990
+ __FP8_HOST_DEVICE__ operator float4() const {
991
+ auto x = __x; // bypass const
992
+ auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
993
+ auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
994
+ #if HIP_FP8_CVT_FAST_PATH
995
+ float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
996
+ float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
997
+ #else
998
+ float2 high = float2(internal::cast_from_f8<float, true>(
999
+ static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
1000
+ internal::cast_from_f8<float, true>(
1001
+ static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
1002
+ float2 low = float2(internal::cast_from_f8<float, true>(
1003
+ static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
1004
+ internal::cast_from_f8<float, true>(
1005
+ static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
1006
+ #endif
1007
+ return float4(low.x, low.y, high.x, high.y);
1008
+ }
1009
+ };
1010
+
1011
+ /**
1012
+ * \brief struct representing one fp8 number with e5m2 interpretation
1013
+ *
1014
+ */
1015
+ struct __hip_fp8_e5m2_fnuz {
1016
+ __hip_fp8_storage_t __x; //! raw storage of one fp8 numbers
1017
+ static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1018
+ static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1019
+ static constexpr unsigned int __we = 5;
1020
+ static constexpr unsigned int __wm = 2;
1021
+
1022
+
1023
+ // TODO: SWDEV-452411
1024
+ // Add cast from unsigned long long, long long to fp8
1025
+
1026
+ /*! create fp8 e5m2 type from long */
1027
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val)
1028
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1029
+ __default_interpret)) {}
1030
+
1031
+ /*! create fp8 e5m2 type from int */
1032
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val)
1033
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1034
+ __default_interpret)) {}
1035
+
1036
+ /*! create fp8 e5m2 type from short int */
1037
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val)
1038
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1039
+ __default_interpret)) {}
1040
+
1041
+ /*! create fp8 e5m2 type from unsigned long */
1042
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
1043
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1044
+ __default_interpret)) {}
1045
+
1046
+ /*! create fp8 e5m2 type from unsigned int */
1047
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val)
1048
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1049
+ __default_interpret)) {}
1050
+
1051
+ /*! create fp8 e5m2 type from unsigned short */
1052
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
1053
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1054
+ __default_interpret)) {}
1055
+
1056
+ /*! create fp8 e5m2 type from double */
1057
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f)
1058
+ : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
1059
+
1060
+ /*! create fp8 e5m2 type from float */
1061
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f)
1062
+ : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
1063
+
1064
+ /*! create fp8 e5m2 type from __hip_bfloat16 */
1065
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
1066
+ : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
1067
+ __default_interpret)) {}
1068
+
1069
+ /*! create fp8 e5m2 type from __hip_bfloat16 */
1070
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f)
1071
+ : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
1072
+ __default_interpret)) {}
1073
+
1074
+ /*! default construct fp8 e5m2 */
1075
+ __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz() = default;
1076
+
1077
+ /*! convert fp8 e5m2 to float */
1078
+ __FP8_HOST_DEVICE__ operator float() const {
1079
+ #if HIP_FP8_CVT_FAST_PATH
1080
+ return internal::cast_to_f32_from_f8(__x, __default_interpret);
1081
+ #else
1082
+ return internal::cast_from_f8<float, true>(__x, __wm, __we);
1083
+ #endif
1084
+ }
1085
+
1086
+ /*! convert fp8 e5m2 to __half */
1087
+ __FP8_HOST_DEVICE__ operator __half() const {
1088
+ return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
1089
+ }
1090
+
1091
+ /*! convert fp8 e5m2 to __hip_bfloat16 */
1092
+ __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
1093
+ float f = *this;
1094
+ return __hip_bfloat16(f);
1095
+ }
1096
+
1097
+ /*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */
1098
+ __FP8_HOST_DEVICE__ operator bool() const {
1099
+ // it can be 0x00 (+0.0) since 0x80 will be nan
1100
+ return !(static_cast<unsigned short>(__x) == 0);
1101
+ }
1102
+
1103
+ /*! convert fp8 e5m2 to char, clamp out of bound values, return 0 if value is NaN */
1104
+ __FP8_HOST_DEVICE__ operator char() const {
1105
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1106
+ return 0;
1107
+ }
1108
+
1109
+ float fval = *this;
1110
+ auto llval = static_cast<long long>(fval);
1111
+ if (llval <= CHAR_MIN) {
1112
+ return CHAR_MIN;
1113
+ } else if (llval >= CHAR_MAX) {
1114
+ return CHAR_MAX;
1115
+ }
1116
+ return static_cast<char>(fval);
1117
+ }
1118
+
1119
+ /*! convert fp8 e5m2 to double */
1120
+ __FP8_HOST_DEVICE__ operator double() const {
1121
+ return internal::cast_from_f8<double, true>(__x, __wm, __we);
1122
+ }
1123
+
1124
+ /*! convert fp8 e5m2 to int, return 0 if value is NaN */
1125
+ __FP8_HOST_DEVICE__ operator int() const {
1126
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1127
+ return 0;
1128
+ }
1129
+
1130
+ float fval = *this;
1131
+ return static_cast<int>(fval);
1132
+ }
1133
+
1134
+ /*! convert fp8 e5m2 to long, return 0 if value is NaN */
1135
+ __FP8_HOST_DEVICE__ operator long int() const {
1136
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1137
+ return 0;
1138
+ }
1139
+
1140
+ float fval = *this;
1141
+ return static_cast<long>(fval);
1142
+ }
1143
+
1144
+ /*! convert fp8 e5m2 to long long, return 0 if value is NaN */
1145
+ __FP8_HOST_DEVICE__ operator long long int() const {
1146
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1147
+ return 0;
1148
+ }
1149
+
1150
+ float fval = *this;
1151
+ return static_cast<long long>(fval);
1152
+ }
1153
+
1154
+ /*! convert fp8 e5m2 to short, clamp out of bound values, return 0 if value is NaN */
1155
+ __FP8_HOST_DEVICE__ operator short int() const {
1156
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1157
+ return 0;
1158
+ }
1159
+
1160
+ float fval = *this;
1161
+ auto llval = static_cast<long long>(fval);
1162
+ if (llval <= SHRT_MIN) {
1163
+ return SHRT_MIN;
1164
+ } else if (llval >= SHRT_MAX) {
1165
+ return SHRT_MAX;
1166
+ }
1167
+ return static_cast<short>(fval);
1168
+ }
1169
+
1170
+ /*! convert fp8 e5m2 to signed char, clamp out of bound values, return 0 if value is NaN */
1171
+ __FP8_HOST_DEVICE__ operator signed char() const {
1172
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1173
+ return 0;
1174
+ }
1175
+
1176
+ float fval = *this;
1177
+ auto llval = static_cast<long long>(fval);
1178
+ if (llval <= SCHAR_MIN) {
1179
+ return SCHAR_MIN;
1180
+ } else if (llval >= SCHAR_MAX) {
1181
+ return SCHAR_MAX;
1182
+ }
1183
+ return static_cast<signed char>(fval);
1184
+ }
1185
+
1186
+ /*! convert fp8 e5m2 to unsigned char, clamp out of bound values, return 0 if value is NaN */
1187
+ __FP8_HOST_DEVICE__ operator unsigned char() const {
1188
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1189
+ return 0;
1190
+ }
1191
+
1192
+ float fval = *this;
1193
+ auto llval = static_cast<long long>(fval);
1194
+ if (llval <= 0) {
1195
+ return 0;
1196
+ } else if (llval >= UCHAR_MAX) {
1197
+ return UCHAR_MAX;
1198
+ }
1199
+ return static_cast<unsigned char>(fval);
1200
+ }
1201
+
1202
+ /*! convert fp8 e5m2 to unsigned int, return 0 if value is NaN */
1203
+ __FP8_HOST_DEVICE__ operator unsigned int() const {
1204
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1205
+ return 0;
1206
+ }
1207
+
1208
+ float fval = *this;
1209
+ auto llval = static_cast<long long>(fval);
1210
+ if (llval <= 0) {
1211
+ return 0;
1212
+ }
1213
+ return static_cast<unsigned int>(fval);
1214
+ }
1215
+
1216
+ /*! convert fp8 e5m2 to unsigned long, return 0 if value is NaN */
1217
+ __FP8_HOST_DEVICE__ operator unsigned long int() const {
1218
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1219
+ return 0;
1220
+ }
1221
+
1222
+ float fval = *this;
1223
+ auto llval = static_cast<long long>(fval);
1224
+ if (llval <= 0) {
1225
+ return 0;
1226
+ }
1227
+ return static_cast<unsigned long>(fval);
1228
+ }
1229
+
1230
+ /*! convert fp8 e5m2 to unsigned long long, return 0 if value is NaN */
1231
+ __FP8_HOST_DEVICE__ operator unsigned long long int() const {
1232
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1233
+ return 0;
1234
+ }
1235
+
1236
+ float fval = *this;
1237
+ auto llval = static_cast<long long>(fval);
1238
+ if (llval <= 0) {
1239
+ return 0;
1240
+ }
1241
+ return static_cast<unsigned long long>(fval);
1242
+ }
1243
+
1244
+ /*! convert fp8 e5m2 to unsigned short, return 0 if value is NaN */
1245
+ __FP8_HOST_DEVICE__ operator unsigned short int() const {
1246
+ if (internal::hip_fp8_fnuz_is_nan(__x)) {
1247
+ return 0;
1248
+ }
1249
+
1250
+ float fval = *this;
1251
+ auto llval = static_cast<long long>(fval);
1252
+ if (llval <= 0) {
1253
+ return 0;
1254
+ }
1255
+ return static_cast<unsigned short>(fval);
1256
+ }
1257
+ };
1258
+
1259
+ /**
1260
+ * \brief struct representing two fp8 numbers with e5m2 interpretation
1261
+ *
1262
+ */
1263
+ struct __hip_fp8x2_e5m2_fnuz {
1264
+ __hip_fp8x2_storage_t __x; //! raw storage of two fp8 numbers
1265
+ static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1266
+ static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1267
+ static constexpr unsigned int __we = 5;
1268
+ static constexpr unsigned int __wm = 2;
1269
+
1270
+ /*! create fp8x2 e5m2 type from double2 */
1271
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val)
1272
+ : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1273
+
1274
+ /*! create fp8x2 e5m2 type from float2 */
1275
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val)
1276
+ : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1277
+
1278
+ /*! create fp8x2 e5m2 type from __hip_bfloat162 */
1279
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
1280
+ : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1281
+
1282
+ /*! create fp8x2 e5m2 type from __half2 */
1283
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
1284
+ : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1285
+
1286
+ /*! default construct fp8x2 e5m2 */
1287
+ __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz() = default;
1288
+
1289
+ /*! convert fp8x2 e5m2 to __half2 */
1290
+ __FP8_HOST_DEVICE__ operator __half2() const {
1291
+ return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
1292
+ }
1293
+
1294
+ /*! convert fp8x2 e5m2 to float2 */
1295
+ __FP8_HOST_DEVICE__ operator float2() const {
1296
+ #if HIP_FP8_CVT_FAST_PATH
1297
+ return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
1298
+ #else
1299
+ return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
1300
+ __wm, __we),
1301
+ internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
1302
+ __wm, __we));
1303
+ #endif
1304
+ }
1305
+ };
1306
+
1307
+ /**
1308
+ * \brief struct representing four fp8 numbers with e5m2 interpretation
1309
+ *
1310
+ */
1311
+ struct __hip_fp8x4_e5m2_fnuz {
1312
+ __hip_fp8x4_storage_t __x; //! raw storage of four fp8 numbers
1313
+ static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1314
+ static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1315
+ static constexpr unsigned int __we = 5;
1316
+ static constexpr unsigned int __wm = 2;
1317
+
1318
+ /*! create fp8x4 e5m2 type from double4 */
1319
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val)
1320
+ : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1321
+ static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1322
+ val.x, __default_saturation, __default_interpret)) |
1323
+ reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1324
+ val.y, __default_saturation, __default_interpret))
1325
+ << 8 |
1326
+ reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1327
+ val.z, __default_saturation, __default_interpret))
1328
+ << 16 |
1329
+ reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1330
+ val.w, __default_saturation, __default_interpret))
1331
+ << 24))) {}
1332
+
1333
+ /*! create fp8x4 e5m2 type from float4 */
1334
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val)
1335
+ : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1336
+ static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1337
+ val.x, __default_saturation, __default_interpret)) |
1338
+ reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1339
+ val.y, __default_saturation, __default_interpret))
1340
+ << 8 |
1341
+ reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1342
+ val.z, __default_saturation, __default_interpret))
1343
+ << 16 |
1344
+ reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1345
+ val.w, __default_saturation, __default_interpret))
1346
+ << 24))) {}
1347
+
1348
+ /*! create fp8x4 e5m2 type from two __hip_bfloat162 */
1349
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1350
+ : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
1351
+ reinterpret_cast<unsigned short>(
1352
+ __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
1353
+ reinterpret_cast<unsigned short>(
1354
+ __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
1355
+ << 16))) {}
1356
+
1357
+ /*! create fp8x4 e5m2 type from two __half2 */
1358
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
1359
+ : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1360
+ static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1361
+ high, __default_saturation, __default_interpret)) |
1362
+ reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1363
+ low, __default_saturation, __default_interpret))
1364
+ << 16))) {}
1365
+
1366
+ /* default construct fp8x4 e5m2 */
1367
+ __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz() = default;
1368
+
1369
+ /*! convert fp8x4 e5m2 to float4 */
1370
+ __FP8_HOST_DEVICE__ operator float4() const {
1371
+ auto x = __x; // bypass const
1372
+ auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
1373
+ auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
1374
+ #if HIP_FP8_CVT_FAST_PATH
1375
+ float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
1376
+ float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
1377
+ #else
1378
+ float2 high = float2(internal::cast_from_f8<float, true>(
1379
+ static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
1380
+ internal::cast_from_f8<float, true>(
1381
+ static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
1382
+ float2 low = float2(internal::cast_from_f8<float, true>(
1383
+ static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
1384
+ internal::cast_from_f8<float, true>(
1385
+ static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
1386
+ #endif
1387
+ return float4(low.x, low.y, high.x, high.y);
1388
+ }
1389
+ };
1390
+
1391
+ #endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_