triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (68) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * MIT License
3
3
  *
4
- * Copyright (c) 2019 - 2023 Advanced Micro Devices, Inc. All rights reserved.
4
+ * Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. All rights reserved.
5
5
  *
6
6
  * Permission is hereby granted, free of charge, to any person obtaining a copy
7
7
  * of this software and associated documentation files (the "Software"), to deal
@@ -81,6 +81,17 @@
81
81
  * To use these functions, include the header file \p hip_bf16.h in your program.
82
82
  */
83
83
 
84
+ /**
85
+ * \defgroup HIP_INTRINSIC_BFLOAT16_RAW Bfloat16 Raw Struct
86
+ * \ingroup HIP_INTRINSIC_BFLOAT16
87
+ * To use these functions, include the header file \p hip_bf16.h in your program.
88
+ */
89
+
90
+ /**
91
+ * \defgroup HIP_INTRINSIC_BFLOAT162_RAW Bfloat162 Raw Struct
92
+ * \ingroup HIP_INTRINSIC_BFLOAT16
93
+ * To use these functions, include the header file \p hip_bf16.h in your program.
94
+ */
84
95
 
85
96
  #ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_
86
97
  #define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_
@@ -93,13 +104,30 @@
93
104
  #include "device_library_decls.h" // ocml conversion functions
94
105
  #include "math_fwd.h" // ocml device functions
95
106
 
107
+ #define __BF16_DEVICE__ __device__
96
108
  #if defined(__HIPCC_RTC__)
97
- #define __HOST_DEVICE__ __device__ static
109
+ #define __BF16_HOST_DEVICE__ __BF16_DEVICE__
98
110
  #else
99
111
  #include <algorithm>
100
112
  #include <climits>
101
113
  #include <cmath>
102
- #define __HOST_DEVICE__ __host__ __device__ static inline
114
+ #define __BF16_HOST_DEVICE__ __host__ __BF16_DEVICE__
115
+ #endif
116
+ #define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline
117
+ #define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline
118
+
119
+ #if defined(__AVX512VL__) and defined(__AVX512BF16__) and not defined(__HIP_DEVICE_COMPILE__)
120
+ // Enable with -mavx512vl -mavx512bf16
121
+ #if defined(__MINGW64__)
122
+ #include <intrin.h>
123
+ #else
124
+ #include <immintrin.h>
125
+ #endif
126
+ #define HIP_BF16_AVX512_OP 1
127
+ static_assert(sizeof(__bf16) == sizeof(unsigned short),
128
+ "sizeof __bf16 should match sizeof unsigned short");
129
+ #else
130
+ #define HIP_BF16_AVX512_OP 0
103
131
  #endif
104
132
 
105
133
  #define HIPRT_ONE_BF16 __float2bfloat16(1.0f)
@@ -118,72 +146,361 @@ static_assert(CHAR_BIT == 8, "byte size should be of 8 bits");
118
146
  #endif
119
147
  static_assert(sizeof(unsigned short) == 2, "size of unsigned short should be 2 bytes");
120
148
 
121
- /*! \brief Struct to represent a 16 bit brain floating point number. */
122
- struct __hip_bfloat16 {
123
- unsigned short data;
149
+ /**
150
+ * \ingroup HIP_INTRINSIC_BFLOAT16_RAW
151
+ * \brief represents raw bfloat16 type
152
+ */
153
+ typedef struct __attribute__((aligned(2))) {
154
+ unsigned short x;
155
+ } __hip_bfloat16_raw;
156
+
157
+ /**
158
+ * \ingroup HIP_INTRINSIC_BFLOAT162_RAW
159
+ * \brief represents raw bfloat16x2 vector type
160
+ */
161
+ typedef struct __attribute__((aligned(4))) {
162
+ unsigned short x;
163
+ unsigned short y;
164
+ } __hip_bfloat162_raw;
165
+
166
+ /**
167
+ * \defgroup HIP_INTRINSIC_BFLOAT16_STRUCT
168
+ * \ingroup HIP_INTRINSIC_BFLOAT16
169
+ * \brief Struct to represent a 16 bit brain floating point number.
170
+ * @{
171
+ */
172
+ struct __attribute__((aligned(2))) __hip_bfloat16 {
173
+ private:
174
+ __BF16_HOST_DEVICE_STATIC__ float bfloatraw_2_float(unsigned short val) {
175
+ #if HIP_BF16_AVX512_OP
176
+ union {
177
+ unsigned short us;
178
+ __bf16 bf16;
179
+ } u = {val};
180
+ return _mm_cvtsbh_ss(u.bf16);
181
+ #else
182
+ unsigned int uval = val << 16;
183
+ union {
184
+ unsigned int u32;
185
+ float fp32;
186
+ } u = {uval};
187
+ return u.fp32;
188
+ #endif
189
+ }
190
+ __BF16_HOST_DEVICE_STATIC__ unsigned short float_2_bfloatraw(float f) {
191
+ #if HIP_BF16_AVX512_OP
192
+ union {
193
+ __bf16 bf16;
194
+ unsigned short us;
195
+ } u = {_mm_cvtness_sbh(f)};
196
+ return u.us;
197
+ #else
198
+ union {
199
+ float fp32;
200
+ unsigned int u32;
201
+ } u = {f};
202
+ if (~u.u32 & 0x7f800000) {
203
+ // When the exponent bits are not all 1s, then the value is zero, normal,
204
+ // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
205
+ // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
206
+ // This causes the bfloat16's mantissa to be incremented by 1 if the 16
207
+ // least significant bits of the float mantissa are greater than 0x8000,
208
+ // or if they are equal to 0x8000 and the least significant bit of the
209
+ // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
210
+ // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
211
+ // has the value 0x7f, then incrementing it causes it to become 0x00 and
212
+ // the exponent is incremented by one, which is the next higher FP value
213
+ // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
214
+ // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
215
+ // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
216
+ // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
217
+ // incrementing it causes it to become an exponent of 0xFF and a mantissa
218
+ // of 0x00, which is Inf, the next higher value to the unrounded value.
219
+ u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // Round to nearest, round to even
220
+ } else if (u.u32 & 0xffff) {
221
+ // When all of the exponent bits are 1, the value is Inf or NaN.
222
+ // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
223
+ // mantissa bit. Quiet NaN is indicated by the most significant mantissa
224
+ // bit being 1. Signaling NaN is indicated by the most significant
225
+ // mantissa bit being 0 but some other bit(s) being 1. If any of the
226
+ // lower 16 bits of the mantissa are 1, we set the least significant bit
227
+ // of the bfloat16 mantissa, in order to preserve signaling NaN in case
228
+ // the bloat16's mantissa bits are all 0.
229
+ u.u32 |= 0x10000; // Preserve signaling NaN
230
+ }
231
+ return static_cast<unsigned short>(u.u32 >> 16);
232
+ #endif
233
+ }
234
+
235
+ __BF16_HOST_DEVICE_STATIC__ unsigned short double_2_bfloatraw(double d_in) {
236
+ union {
237
+ float fp32;
238
+ unsigned int u32;
239
+ } u = {static_cast<float>(d_in)};
240
+ double d = u.fp32;
241
+
242
+ // Round to odd
243
+ if ((d_in > 0.0 && d > d_in) || (d_in < 0.0 && d < d_in)) {
244
+ u.u32--;
245
+ u.u32 |= 1;
246
+ }
247
+
248
+ return float_2_bfloatraw(u.fp32);
249
+ }
250
+
251
+ protected:
252
+ /*! \brief raw representation of bfloat16 */
253
+ unsigned short __x;
254
+
255
+ public:
256
+ // TODO: SWDEV-452411
257
+ // Need to add constructor of __hip_bfloat16 from
258
+ // unsigned long long
259
+ // long long
260
+ // long
261
+ // unsigned long
262
+ // Casting directly to double might lead to double rounding.
263
+
264
+ /*! \brief create __hip_bfloat16 from an unsigned int */
265
+ __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned int val)
266
+ : __x(double_2_bfloatraw(static_cast<double>(val))) {}
267
+
268
+ /*! \brief create __hip_bfloat16 from a int */
269
+ __BF16_HOST_DEVICE__ __hip_bfloat16(int val)
270
+ : __x(double_2_bfloatraw(static_cast<double>(val))) {}
271
+
272
+ /*! \brief create __hip_bfloat16 from an unsigned short */
273
+ __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned short val)
274
+ : __x(float_2_bfloatraw(static_cast<float>(val))) {}
275
+
276
+ /*! \brief create __hip_bfloat16 from a short */
277
+ __BF16_HOST_DEVICE__ __hip_bfloat16(short val)
278
+ : __x(float_2_bfloatraw(static_cast<float>(val))) {}
279
+
280
+ /*! \brief create __hip_bfloat16 from a double */
281
+ __BF16_HOST_DEVICE__ __hip_bfloat16(const double val) : __x(double_2_bfloatraw(val)) {}
282
+
283
+ /*! \brief create __hip_bfloat16 from a float */
284
+ __BF16_HOST_DEVICE__ __hip_bfloat16(const float val) : __x(float_2_bfloatraw(val)) {}
285
+
286
+ /*! \brief create __hip_bfloat16 from a __hip_bfloat16_raw */
287
+ __BF16_HOST_DEVICE__ __hip_bfloat16(const __hip_bfloat16_raw& val) : __x(val.x) {}
288
+
289
+ /*! \brief default constructor */
290
+ __BF16_HOST_DEVICE__ __hip_bfloat16() = default;
291
+
292
+ /*! \brief return a __hip_bfloat16_raw */
293
+ __BF16_HOST_DEVICE__ operator __hip_bfloat16_raw() const { return __hip_bfloat16_raw{__x}; }
294
+
295
+ /*! \brief return a __hip_bfloat16_raw cv qualifier */
296
+ __BF16_HOST_DEVICE__ operator __hip_bfloat16_raw() const volatile {
297
+ return __hip_bfloat16_raw{__x};
298
+ }
299
+
300
+ /*! \brief return false if bfloat value is +0.0 or -0.0, returns true otherwise */
301
+ __BF16_HOST_DEVICE__ operator bool() const {
302
+ auto val = bfloatraw_2_float(__x);
303
+ return val != 0.0f && val != -0.0f;
304
+ }
305
+
306
+ /*! \brief return a casted char from underlying float val */
307
+ __BF16_HOST_DEVICE__ operator char() const { return static_cast<char>(bfloatraw_2_float(__x)); }
308
+
309
+ /*! \brief return a float */
310
+ __BF16_HOST_DEVICE__ operator float() const { return bfloatraw_2_float(__x); }
311
+
312
+ /*! \brief return a casted int casted from float of underlying bfloat16 value */
313
+ __BF16_HOST_DEVICE__ operator int() const { return static_cast<int>(bfloatraw_2_float(__x)); }
314
+
315
+ /*! \brief return a casted long casted from float of underlying bfloat16 value */
316
+ __BF16_HOST_DEVICE__ operator long() const { return static_cast<long>(bfloatraw_2_float(__x)); }
317
+
318
+ /*! \brief return a casted long long casted from float of underlying bfloat16 value */
319
+ __BF16_HOST_DEVICE__ operator long long() const {
320
+ return static_cast<long long>(bfloatraw_2_float(__x));
321
+ }
322
+
323
+ /*! \brief return a casted short casted from float of underlying bfloat16 value */
324
+ __BF16_HOST_DEVICE__ operator short() const { return static_cast<short>(bfloatraw_2_float(__x)); }
325
+
326
+ /*! \brief return a casted signed char from float of underlying bfloat16 value */
327
+ __BF16_HOST_DEVICE__ operator signed char() const {
328
+ return static_cast<signed char>(bfloatraw_2_float(__x));
329
+ }
330
+
331
+ /*! \brief return a casted unsigned char casted from float of underlying bfloat16 value */
332
+ __BF16_HOST_DEVICE__ operator unsigned char() const {
333
+ return static_cast<unsigned char>(bfloatraw_2_float(__x));
334
+ }
335
+
336
+ /*! \brief return a casted unsigned int casted from float of underlying bfloat16 value */
337
+ __BF16_HOST_DEVICE__ operator unsigned int() const {
338
+ return static_cast<unsigned int>(bfloatraw_2_float(__x));
339
+ }
340
+
341
+ /*! \brief return a casted unsigned from float of underlying bfloat16 value */
342
+ __BF16_HOST_DEVICE__ operator unsigned long() const {
343
+ return static_cast<unsigned long>(bfloatraw_2_float(__x));
344
+ }
345
+
346
+ /*! \brief return a casted unsigned long long from float of underlying bfloat16 value */
347
+ __BF16_HOST_DEVICE__ operator unsigned long long() const {
348
+ return static_cast<unsigned long long>(bfloatraw_2_float(__x));
349
+ }
350
+
351
+ /*! \brief return a casted unsigned short from float of underlying bfloat16 value */
352
+ __BF16_HOST_DEVICE__ operator unsigned short() const {
353
+ return static_cast<unsigned short>(bfloatraw_2_float(__x));
354
+ }
355
+
356
+ // TODO: SWDEV-452411 add operator which converts unsigned long long and long long to bfloat
357
+
358
+ /*! \brief assign value from an unsigned int */
359
+ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned int val) {
360
+ __x = float_2_bfloatraw(static_cast<float>(val));
361
+ return *this;
362
+ }
363
+
364
+ /*! \brief assign value from a int */
365
+ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(int val) {
366
+ __x = float_2_bfloatraw(static_cast<float>(val));
367
+ return *this;
368
+ }
369
+
370
+ /*! \brief assign value from an unsigned short */
371
+ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned short val) {
372
+ __x = float_2_bfloatraw(static_cast<float>(val));
373
+ return *this;
374
+ }
375
+
376
+ /*! \brief assign value from a short int */
377
+ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(short val) {
378
+ __x = float_2_bfloatraw(static_cast<float>(val));
379
+ return *this;
380
+ }
381
+
382
+ /*! \brief assign value from a double */
383
+ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const double f) {
384
+ __x = float_2_bfloatraw(static_cast<float>(f));
385
+ return *this;
386
+ }
387
+
388
+ /*! \brief assign value from a float */
389
+ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const float f) {
390
+ __x = float_2_bfloatraw(f);
391
+ return *this;
392
+ }
393
+
394
+ /*! \brief assign value from a __hip_bfloat16_raw */
395
+ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const __hip_bfloat16_raw& hr) {
396
+ __x = hr.x;
397
+ return *this;
398
+ }
399
+
400
+ /*! \brief assign value from a __hip_bfloat16_raw volatile */
401
+ __BF16_HOST_DEVICE__ volatile __hip_bfloat16& operator=(const __hip_bfloat16_raw& hr) volatile {
402
+ __x = hr.x;
403
+ return *this;
404
+ }
405
+
406
+ /*! \brief assign value from a __hip_bfloat16_raw cv qualifier */
407
+ __BF16_HOST_DEVICE__ volatile __hip_bfloat16& operator=(
408
+ const volatile __hip_bfloat16_raw& hr) volatile {
409
+ __x = hr.x;
410
+ return *this;
411
+ }
124
412
  };
413
+ /**@}*/
414
+
415
+ /**
416
+ * \defgroup HIP_INTRINSIC_BFLOAT162_STRUCT
417
+ * \ingroup HIP_INTRINSIC_BFLOAT16
418
+ * \brief Struct to represent a two 16 bit brain floating point number.
419
+ * @{
420
+ */
421
+ struct __attribute__((aligned(4))) __hip_bfloat162 {
422
+ public:
423
+ __hip_bfloat16 x; /*! \brief raw representation of bfloat16 */
424
+ __hip_bfloat16 y; /*! \brief raw representation of bfloat16 */
425
+
426
+
427
+ public:
428
+ /*! \brief create __hip_bfloat162 from __hip_bfloat162_raw */
429
+ __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162_raw& h2r)
430
+ : x(__hip_bfloat16(__hip_bfloat16_raw{h2r.x})),
431
+ y(__hip_bfloat16(__hip_bfloat16_raw{h2r.y})) {}
432
+
433
+ /*! \brief copy constructor of __hip_bfloat162 */
434
+ __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162& val) {
435
+ __hip_bfloat162_raw hr = val;
436
+ x = __hip_bfloat16_raw{hr.x};
437
+ y = __hip_bfloat16_raw{hr.y};
438
+ }
439
+
440
+ /*! \brief create __hip_bfloat162 from two __hip_bfloat16 */
441
+ __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat16& a, const __hip_bfloat16& b)
442
+ : x(a), y(b) {}
443
+
444
+ /*! \brief default constructor of __hip_bfloat162 */
445
+ __BF16_HOST_DEVICE__ __hip_bfloat162() = default;
446
+
447
+ /*! \brief return a __hip_bfloat162_raw */
448
+ __BF16_HOST_DEVICE__ operator __hip_bfloat162_raw() const {
449
+ __hip_bfloat16_raw l = x;
450
+ __hip_bfloat16_raw r = y;
451
+ return __hip_bfloat162_raw{l.x, r.x};
452
+ }
125
453
 
126
- /*! \brief Struct to represent two 16 bit brain floating point numbers. */
127
- struct __hip_bfloat162 {
128
- __hip_bfloat16 x;
129
- __hip_bfloat16 y;
454
+ /*! \brief return a float2 */
455
+ __BF16_HOST_DEVICE__ operator float2() const {
456
+ #if HIP_BF16_AVX512_OP
457
+ union {
458
+ __hip_bfloat162_raw raw2;
459
+ __bf16 bf162[2];
460
+ static_assert(sizeof(__bf16[2]) == sizeof(__hip_bfloat162_raw));
461
+ } u;
462
+ u.raw2 = *this;
463
+ __m128bh pbf16{u.bf162[0], u.bf162[1], 0, 0};
464
+ __m128 pf32 = _mm_cvtpbh_ps(pbf16);
465
+ float2 ret(pf32[0], pf32[1]);
466
+ #else
467
+ float2 ret(x, y);
468
+ #endif
469
+ return ret;
470
+ }
471
+
472
+ /*! \brief assign value from __hip_bfloat162_raw */
473
+ __BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162_raw& h2r) {
474
+ x = __hip_bfloat16(__hip_bfloat16_raw{h2r.x});
475
+ y = __hip_bfloat16(__hip_bfloat16_raw{h2r.y});
476
+ return *this;
477
+ }
478
+
479
+ /*! \brief assign value from __hip_bfloat162 */
480
+ __BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162& src) {
481
+ __hip_bfloat162_raw hr = src;
482
+ x = __hip_bfloat16(__hip_bfloat16_raw{hr.x});
483
+ y = __hip_bfloat16(__hip_bfloat16_raw{hr.y});
484
+ return *this;
485
+ }
130
486
  };
487
+ /**@}*/
131
488
 
132
489
  /**
133
490
  * \ingroup HIP_INTRINSIC_BFLOAT16_CONV
134
491
  * \brief Converts bfloat16 to float
135
492
  */
136
- __HOST_DEVICE__ inline float __bfloat162float(__hip_bfloat16 a) {
137
- unsigned int uval = 0;
138
- uval = a.data << 16;
139
- union {
140
- unsigned int u32;
141
- float fp32;
142
- } u = {uval};
143
- return u.fp32;
493
+ __BF16_HOST_DEVICE_STATIC__ float __bfloat162float(__hip_bfloat16 a) {
494
+ float ret = a;
495
+ return ret;
144
496
  }
145
497
 
146
498
  /**
147
499
  * \ingroup HIP_INTRINSIC_BFLOAT16_CONV
148
500
  * \brief Converts float to bfloat16
149
501
  */
150
- __HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) {
151
- __hip_bfloat16 ret;
152
- union {
153
- float fp32;
154
- unsigned int u32;
155
- } u = {f};
156
- if (~u.u32 & 0x7f800000) {
157
- // When the exponent bits are not all 1s, then the value is zero, normal,
158
- // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
159
- // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
160
- // This causes the bfloat16's mantissa to be incremented by 1 if the 16
161
- // least significant bits of the float mantissa are greater than 0x8000,
162
- // or if they are equal to 0x8000 and the least significant bit of the
163
- // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
164
- // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
165
- // has the value 0x7f, then incrementing it causes it to become 0x00 and
166
- // the exponent is incremented by one, which is the next higher FP value
167
- // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
168
- // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
169
- // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
170
- // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
171
- // incrementing it causes it to become an exponent of 0xFF and a mantissa
172
- // of 0x00, which is Inf, the next higher value to the unrounded value.
173
- u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // Round to nearest, round to even
174
- } else if (u.u32 & 0xffff) {
175
- // When all of the exponent bits are 1, the value is Inf or NaN.
176
- // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
177
- // mantissa bit. Quiet NaN is indicated by the most significant mantissa
178
- // bit being 1. Signaling NaN is indicated by the most significant
179
- // mantissa bit being 0 but some other bit(s) being 1. If any of the
180
- // lower 16 bits of the mantissa are 1, we set the least significant bit
181
- // of the bfloat16 mantissa, in order to preserve signaling NaN in case
182
- // the bloat16's mantissa bits are all 0.
183
- u.u32 |= 0x10000; // Preserve signaling NaN
184
- }
185
-
186
- ret.data = (u.u32 >> 16);
502
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __float2bfloat16(float f) {
503
+ __hip_bfloat16 ret{f};
187
504
  return ret;
188
505
  }
189
506
 
@@ -191,43 +508,51 @@ __HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) {
191
508
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
192
509
  * \brief Converts and moves bfloat162 to float2
193
510
  */
194
- __HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
195
- return float2{__bfloat162float(a.x), __bfloat162float(a.y)};
511
+ __BF16_HOST_DEVICE_STATIC__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
512
+ float2 ret = a;
513
+ return ret;
196
514
  }
197
515
 
198
516
  /**
199
517
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
200
518
  * \brief Moves bfloat16 value to bfloat162
201
519
  */
202
- __HOST_DEVICE__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) {
203
- return __hip_bfloat162{a, a};
520
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) {
521
+ return __hip_bfloat162(a, a);
204
522
  }
205
523
 
206
524
  /**
207
525
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
208
526
  * \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer
209
527
  */
210
- __HOST_DEVICE__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; }
528
+ __BF16_HOST_DEVICE_STATIC__ short int __bfloat16_as_short(const __hip_bfloat16 h) {
529
+ short ret = h;
530
+ return ret;
531
+ }
211
532
 
212
533
  /**
213
534
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
214
535
  * \brief Reinterprets bits in a __hip_bfloat16 as an unsigned signed short integer
215
536
  */
216
- __HOST_DEVICE__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; }
537
+ __BF16_HOST_DEVICE_STATIC__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) {
538
+ unsigned short ret = h;
539
+ return ret;
540
+ }
217
541
 
218
542
  /**
219
543
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
220
544
  * \brief Convert double to __hip_bfloat16
221
545
  */
222
- __HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) {
223
- return __float2bfloat16((float)a);
546
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __double2bfloat16(const double a) {
547
+ __hip_bfloat16 ret{a};
548
+ return ret;
224
549
  }
225
550
 
226
551
  /**
227
552
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
228
553
  * \brief Convert float2 to __hip_bfloat162
229
554
  */
230
- __HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
555
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
231
556
  return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)};
232
557
  }
233
558
 
@@ -235,97 +560,117 @@ __HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
235
560
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
236
561
  * \brief Combine two __hip_bfloat16 to __hip_bfloat162
237
562
  */
238
- __HOST_DEVICE__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) {
239
- return __hip_bfloat162{a, b};
563
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a,
564
+ const __hip_bfloat16 b) {
565
+ return __hip_bfloat162(a, b);
240
566
  }
241
567
 
242
568
  /**
243
569
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
244
570
  * \brief Returns high 16 bits of __hip_bfloat162
245
571
  */
246
- __HOST_DEVICE__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; }
572
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) {
573
+ __hip_bfloat162_raw hr = a;
574
+ return __hip_bfloat16(__hip_bfloat16_raw{hr.y});
575
+ }
247
576
 
248
577
  /**
249
578
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
250
579
  * \brief Returns high 16 bits of __hip_bfloat162
251
580
  */
252
- __HOST_DEVICE__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) {
253
- return __hip_bfloat162{a.y, a.y};
581
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) {
582
+ __hip_bfloat162_raw hr = a;
583
+ return __hip_bfloat162(__hip_bfloat16_raw{hr.y}, __hip_bfloat16_raw{hr.y});
254
584
  }
255
585
 
256
586
  /**
257
587
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
258
588
  * \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result
259
589
  */
260
- __HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
590
+ __BF16_HOST_DEVICE_STATIC__ float __high2float(const __hip_bfloat162 a) {
591
+ __hip_bfloat162_raw hr = a;
592
+ return __bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.y}));
593
+ }
261
594
 
262
595
  /**
263
596
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
264
597
  * \brief Extracts high 16 bits from each and combines them
265
598
  */
266
- __HOST_DEVICE__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a,
267
- const __hip_bfloat162 b) {
268
- return __hip_bfloat162{a.y, b.y};
599
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a,
600
+ const __hip_bfloat162 b) {
601
+ __hip_bfloat162_raw hr_a = a;
602
+ __hip_bfloat162_raw hr_b = b;
603
+ return __hip_bfloat162(__hip_bfloat162_raw{hr_a.y, hr_b.y});
269
604
  }
270
605
 
271
606
  /**
272
607
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
273
608
  * \brief Returns low 16 bits of __hip_bfloat162
274
609
  */
275
- __HOST_DEVICE__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; }
610
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) {
611
+ __hip_bfloat162_raw hr = a;
612
+ return __hip_bfloat16(hr.x);
613
+ }
276
614
 
277
615
  /**
278
616
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
279
617
  * \brief Returns low 16 bits of __hip_bfloat162
280
618
  */
281
- __HOST_DEVICE__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) {
282
- return __hip_bfloat162{a.x, a.x};
619
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) {
620
+ __hip_bfloat162_raw hr = a;
621
+ return __hip_bfloat162(hr.x, hr.x);
283
622
  }
284
623
 
285
624
  /**
286
625
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
287
626
  * \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result
288
627
  */
289
- __HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
628
+ __BF16_HOST_DEVICE_STATIC__ float __low2float(const __hip_bfloat162 a) {
629
+ __hip_bfloat162_raw hr = a;
630
+ return __bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.x}));
631
+ }
290
632
 
291
633
  /**
292
634
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
293
635
  * \brief Swaps both halves
294
636
  */
295
- __HOST_DEVICE__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) {
296
- return __hip_bfloat162{a.y, a.x};
637
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) {
638
+ __hip_bfloat162_raw hr = a;
639
+ return __hip_bfloat162(__hip_bfloat162_raw{hr.y, hr.x});
297
640
  }
298
641
 
299
642
  /**
300
643
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
301
644
  * \brief Extracts low 16 bits from each and combines them
302
645
  */
303
- __HOST_DEVICE__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) {
304
- return __hip_bfloat162{a.x, b.x};
646
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a,
647
+ const __hip_bfloat162 b) {
648
+ __hip_bfloat162_raw hr_a = a;
649
+ __hip_bfloat162_raw hr_b = b;
650
+ return __hip_bfloat162(__hip_bfloat162_raw{hr_a.x, hr_b.x});
305
651
  }
306
652
 
307
653
  /**
308
654
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
309
655
  * \brief Reinterprets short int into a bfloat16
310
656
  */
311
- __HOST_DEVICE__ __hip_bfloat16 __short_as_bfloat16(const short int a) {
312
- return __hip_bfloat16{(unsigned short)a};
657
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __short_as_bfloat16(const short int a) {
658
+ return __hip_bfloat16(a);
313
659
  }
314
660
 
315
661
  /**
316
662
  * \ingroup HIP_INTRINSIC_BFLOAT162_CONV
317
663
  * \brief Reinterprets unsigned short int into a bfloat16
318
664
  */
319
- __HOST_DEVICE__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) {
320
- return __hip_bfloat16{a};
665
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) {
666
+ return __hip_bfloat16(a);
321
667
  }
322
668
 
323
-
324
669
  /**
325
670
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
326
671
  * \brief Adds two bfloat16 values
327
672
  */
328
- __HOST_DEVICE__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) {
673
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) {
329
674
  return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
330
675
  }
331
676
 
@@ -333,7 +678,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat
333
678
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
334
679
  * \brief Subtracts two bfloat16 values
335
680
  */
336
- __HOST_DEVICE__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) {
681
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) {
337
682
  return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b));
338
683
  }
339
684
 
@@ -341,7 +686,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat
341
686
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
342
687
  * \brief Divides two bfloat16 values
343
688
  */
344
- __HOST_DEVICE__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) {
689
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) {
345
690
  return __float2bfloat16(__bfloat162float(a) / __bfloat162float(b));
346
691
  }
347
692
 
@@ -349,8 +694,8 @@ __HOST_DEVICE__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat
349
694
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
350
695
  * \brief Performs FMA of given bfloat16 values
351
696
  */
352
- __device__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b,
353
- const __hip_bfloat16 c) {
697
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b,
698
+ const __hip_bfloat16 c) {
354
699
  return __float2bfloat16(
355
700
  __ocml_fma_f32(__bfloat162float(a), __bfloat162float(b), __bfloat162float(c)));
356
701
  }
@@ -359,7 +704,7 @@ __device__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b,
359
704
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
360
705
  * \brief Multiplies two bfloat16 values
361
706
  */
362
- __HOST_DEVICE__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) {
707
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) {
363
708
  return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b));
364
709
  }
365
710
 
@@ -367,85 +712,110 @@ __HOST_DEVICE__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat
367
712
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
368
713
  * \brief Negate a bfloat16 value
369
714
  */
370
- __HOST_DEVICE__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) {
371
- auto ret = a;
372
- ret.data ^= 0x8000;
373
- return ret;
715
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) {
716
+ __hip_bfloat16_raw hr = a;
717
+ hr.x ^= 0x8000;
718
+ return __hip_bfloat16(hr);
374
719
  }
375
720
 
376
721
  /**
377
722
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
378
723
  * \brief Returns absolute of a bfloat16
379
724
  */
380
- __HOST_DEVICE__ __hip_bfloat16 __habs(const __hip_bfloat16 a) {
381
- auto ret = a;
382
- ret.data &= 0x7FFF;
383
- return ret;
725
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __habs(const __hip_bfloat16 a) {
726
+ __hip_bfloat16_raw hr = a;
727
+ hr.x &= 0x7FFF;
728
+ return __hip_bfloat16(hr);
384
729
  }
385
730
 
386
731
  /**
387
732
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
388
733
  * \brief Divides bfloat162 values
389
734
  */
390
- __HOST_DEVICE__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) {
391
- return __hip_bfloat162{__float2bfloat16(__bfloat162float(a.x) / __bfloat162float(b.x)),
392
- __float2bfloat16(__bfloat162float(a.y) / __bfloat162float(b.y))};
735
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __h2div(const __hip_bfloat162 a,
736
+ const __hip_bfloat162 b) {
737
+ __hip_bfloat162_raw hr_a = a;
738
+ __hip_bfloat162_raw hr_b = b;
739
+ return __hip_bfloat162(__float2bfloat16(__bfloat162float(__hip_bfloat16_raw{hr_a.x}) /
740
+ __bfloat162float(__hip_bfloat16_raw{hr_b.x})),
741
+ __float2bfloat16(__bfloat162float(__hip_bfloat16_raw{hr_a.y}) /
742
+ __bfloat162float(__hip_bfloat16_raw{hr_b.y})));
393
743
  }
394
744
 
395
745
  /**
396
746
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
397
747
  * \brief Returns absolute of a bfloat162
398
748
  */
399
- __HOST_DEVICE__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) {
400
- return __hip_bfloat162{__habs(a.x), __habs(a.y)};
749
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) {
750
+ __hip_bfloat162_raw hr_a = a;
751
+ return __hip_bfloat162(__habs(__hip_bfloat16_raw{hr_a.x}), __habs(__hip_bfloat16_raw{hr_a.y}));
401
752
  }
402
753
 
403
754
  /**
404
755
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
405
756
  * \brief Adds two bfloat162 values
406
757
  */
407
- __HOST_DEVICE__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
408
- return __hip_bfloat162{__hadd(a.x, b.x), __hadd(a.y, b.y)};
758
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a,
759
+ const __hip_bfloat162 b) {
760
+ __hip_bfloat162_raw hr_a = a;
761
+ __hip_bfloat162_raw hr_b = b;
762
+ return __hip_bfloat162(__hadd(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}),
763
+ __hadd(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}));
409
764
  }
410
765
 
411
766
  /**
412
767
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
413
768
  * \brief Performs FMA of given bfloat162 values
414
769
  */
415
- __device__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __hip_bfloat162 b,
416
- const __hip_bfloat162 c) {
417
- return __hip_bfloat162{__hfma(a.x, b.x, c.x), __hfma(a.y, b.y, c.y)};
770
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __hip_bfloat162 b,
771
+ const __hip_bfloat162 c) {
772
+ __hip_bfloat162_raw hr_a = a;
773
+ __hip_bfloat162_raw hr_b = b;
774
+ __hip_bfloat162_raw hr_c = c;
775
+ return __hip_bfloat162(
776
+ __hfma(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}, __hip_bfloat16_raw{hr_c.x}),
777
+ __hfma(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}, __hip_bfloat16_raw{hr_c.y}));
418
778
  }
419
779
 
420
780
  /**
421
781
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
422
782
  * \brief Multiplies two bfloat162 values
423
783
  */
424
- __HOST_DEVICE__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
425
- return __hip_bfloat162{__hmul(a.x, b.x), __hmul(a.y, b.y)};
784
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a,
785
+ const __hip_bfloat162 b) {
786
+ __hip_bfloat162_raw hr_a = a;
787
+ __hip_bfloat162_raw hr_b = b;
788
+ return __hip_bfloat162(__hmul(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}),
789
+ __hmul(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}));
426
790
  }
427
791
 
428
792
  /**
429
793
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
430
794
  * \brief Converts a bfloat162 into negative
431
795
  */
432
- __HOST_DEVICE__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) {
433
- return __hip_bfloat162{__hneg(a.x), __hneg(a.y)};
796
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) {
797
+ __hip_bfloat162_raw hr_a = a;
798
+ return __hip_bfloat162(__hneg(__hip_bfloat16_raw{hr_a.x}), __hneg(__hip_bfloat16_raw{hr_a.y}));
434
799
  }
435
800
 
436
801
  /**
437
802
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
438
803
  * \brief Subtracts two bfloat162 values
439
804
  */
440
- __HOST_DEVICE__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
441
- return __hip_bfloat162{__hsub(a.x, b.x), __hsub(a.y, b.y)};
805
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a,
806
+ const __hip_bfloat162 b) {
807
+ __hip_bfloat162_raw hr_a = a;
808
+ __hip_bfloat162_raw hr_b = b;
809
+ return __hip_bfloat162(__hsub(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}),
810
+ __hsub(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}));
442
811
  }
443
812
 
444
813
  /**
445
814
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
446
815
  * \brief Operator to multiply two __hip_bfloat16 numbers
447
816
  */
448
- __HOST_DEVICE__ __hip_bfloat16 operator*(const __hip_bfloat16& l, const __hip_bfloat16& r) {
817
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator*(const __hip_bfloat16& l,
818
+ const __hip_bfloat16& r) {
449
819
  return __hmul(l, r);
450
820
  }
451
821
 
@@ -453,7 +823,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator*(const __hip_bfloat16& l, const __hip_bf
453
823
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
454
824
  * \brief Operator to multiply-assign two __hip_bfloat16 numbers
455
825
  */
456
- __HOST_DEVICE__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat16& r) {
826
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat16& r) {
457
827
  l = __hmul(l, r);
458
828
  return l;
459
829
  }
@@ -462,13 +832,14 @@ __HOST_DEVICE__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat
462
832
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
463
833
  * \brief Operator to unary+ on a __hip_bfloat16 number
464
834
  */
465
- __HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l) { return l; }
835
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16& l) { return l; }
466
836
 
467
837
  /**
468
838
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
469
839
  * \brief Operator to add two __hip_bfloat16 numbers
470
840
  */
471
- __HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l, const __hip_bfloat16& r) {
841
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16& l,
842
+ const __hip_bfloat16& r) {
472
843
  return __hadd(l, r);
473
844
  }
474
845
 
@@ -476,13 +847,14 @@ __HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l, const __hip_bf
476
847
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
477
848
  * \brief Operator to negate a __hip_bfloat16 number
478
849
  */
479
- __HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l) { return __hneg(l); }
850
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16& l) { return __hneg(l); }
480
851
 
481
852
  /**
482
853
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
483
854
  * \brief Operator to subtract two __hip_bfloat16 numbers
484
855
  */
485
- __HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l, const __hip_bfloat16& r) {
856
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16& l,
857
+ const __hip_bfloat16& r) {
486
858
  return __hsub(l, r);
487
859
  }
488
860
 
@@ -490,7 +862,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l, const __hip_bf
490
862
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
491
863
  * \brief Operator to post increment a __hip_bfloat16 number
492
864
  */
493
- __HOST_DEVICE__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) {
865
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) {
494
866
  auto ret = l;
495
867
  l = __hadd(l, HIPRT_ONE_BF16);
496
868
  return ret;
@@ -500,7 +872,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) {
500
872
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
501
873
  * \brief Operator to pre increment a __hip_bfloat16 number
502
874
  */
503
- __HOST_DEVICE__ __hip_bfloat16& operator++(__hip_bfloat16& l) {
875
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator++(__hip_bfloat16& l) {
504
876
  l = __hadd(l, HIPRT_ONE_BF16);
505
877
  return l;
506
878
  }
@@ -509,7 +881,7 @@ __HOST_DEVICE__ __hip_bfloat16& operator++(__hip_bfloat16& l) {
509
881
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
510
882
  * \brief Operator to post decrement a __hip_bfloat16 number
511
883
  */
512
- __HOST_DEVICE__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) {
884
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) {
513
885
  auto ret = l;
514
886
  l = __hsub(l, HIPRT_ONE_BF16);
515
887
  return ret;
@@ -519,7 +891,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) {
519
891
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
520
892
  * \brief Operator to pre decrement a __hip_bfloat16 number
521
893
  */
522
- __HOST_DEVICE__ __hip_bfloat16& operator--(__hip_bfloat16& l) {
894
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator--(__hip_bfloat16& l) {
523
895
  l = __hsub(l, HIPRT_ONE_BF16);
524
896
  return l;
525
897
  }
@@ -528,7 +900,7 @@ __HOST_DEVICE__ __hip_bfloat16& operator--(__hip_bfloat16& l) {
528
900
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
529
901
  * \brief Operator to add-assign two __hip_bfloat16 numbers
530
902
  */
531
- __HOST_DEVICE__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat16& r) {
903
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat16& r) {
532
904
  l = __hadd(l, r);
533
905
  return l;
534
906
  }
@@ -537,7 +909,7 @@ __HOST_DEVICE__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat
537
909
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
538
910
  * \brief Operator to subtract-assign two __hip_bfloat16 numbers
539
911
  */
540
- __HOST_DEVICE__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat16& r) {
912
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat16& r) {
541
913
  l = __hsub(l, r);
542
914
  return l;
543
915
  }
@@ -546,7 +918,8 @@ __HOST_DEVICE__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat
546
918
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
547
919
  * \brief Operator to divide two __hip_bfloat16 numbers
548
920
  */
549
- __HOST_DEVICE__ __hip_bfloat16 operator/(const __hip_bfloat16& l, const __hip_bfloat16& r) {
921
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator/(const __hip_bfloat16& l,
922
+ const __hip_bfloat16& r) {
550
923
  return __hdiv(l, r);
551
924
  }
552
925
 
@@ -554,7 +927,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator/(const __hip_bfloat16& l, const __hip_bf
554
927
  * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
555
928
  * \brief Operator to divide-assign two __hip_bfloat16 numbers
556
929
  */
557
- __HOST_DEVICE__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat16& r) {
930
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat16& r) {
558
931
  l = __hdiv(l, r);
559
932
  return l;
560
933
  }
@@ -563,7 +936,8 @@ __HOST_DEVICE__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat
563
936
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
564
937
  * \brief Operator to multiply two __hip_bfloat162 numbers
565
938
  */
566
- __HOST_DEVICE__ __hip_bfloat162 operator*(const __hip_bfloat162& l, const __hip_bfloat162& r) {
939
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator*(const __hip_bfloat162& l,
940
+ const __hip_bfloat162& r) {
567
941
  return __hmul2(l, r);
568
942
  }
569
943
 
@@ -571,7 +945,8 @@ __HOST_DEVICE__ __hip_bfloat162 operator*(const __hip_bfloat162& l, const __hip_
571
945
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
572
946
  * \brief Operator to multiply-assign two __hip_bfloat162 numbers
573
947
  */
574
- __HOST_DEVICE__ __hip_bfloat162& operator*=(__hip_bfloat162& l, const __hip_bfloat162& r) {
948
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator*=(__hip_bfloat162& l,
949
+ const __hip_bfloat162& r) {
575
950
  l = __hmul2(l, r);
576
951
  return l;
577
952
  }
@@ -580,13 +955,14 @@ __HOST_DEVICE__ __hip_bfloat162& operator*=(__hip_bfloat162& l, const __hip_bflo
580
955
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
581
956
  * \brief Operator to unary+ on a __hip_bfloat162 number
582
957
  */
583
- __HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l) { return l; }
958
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator+(const __hip_bfloat162& l) { return l; }
584
959
 
585
960
  /**
586
961
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
587
962
  * \brief Operator to add two __hip_bfloat162 numbers
588
963
  */
589
- __HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l, const __hip_bfloat162& r) {
964
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator+(const __hip_bfloat162& l,
965
+ const __hip_bfloat162& r) {
590
966
  return __hadd2(l, r);
591
967
  }
592
968
 
@@ -594,13 +970,16 @@ __HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l, const __hip_
594
970
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
595
971
  * \brief Operator to negate a __hip_bfloat162 number
596
972
  */
597
- __HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l) { return __hneg2(l); }
973
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator-(const __hip_bfloat162& l) {
974
+ return __hneg2(l);
975
+ }
598
976
 
599
977
  /**
600
978
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
601
979
  * \brief Operator to subtract two __hip_bfloat162 numbers
602
980
  */
603
- __HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l, const __hip_bfloat162& r) {
981
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator-(const __hip_bfloat162& l,
982
+ const __hip_bfloat162& r) {
604
983
  return __hsub2(l, r);
605
984
  }
606
985
 
@@ -608,7 +987,7 @@ __HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l, const __hip_
608
987
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
609
988
  * \brief Operator to post increment a __hip_bfloat162 number
610
989
  */
611
- __HOST_DEVICE__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) {
990
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) {
612
991
  auto ret = l;
613
992
  l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
614
993
  return ret;
@@ -618,7 +997,7 @@ __HOST_DEVICE__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) {
618
997
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
619
998
  * \brief Operator to pre increment a __hip_bfloat162 number
620
999
  */
621
- __HOST_DEVICE__ __hip_bfloat162& operator++(__hip_bfloat162& l) {
1000
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator++(__hip_bfloat162& l) {
622
1001
  l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
623
1002
  return l;
624
1003
  }
@@ -627,7 +1006,7 @@ __HOST_DEVICE__ __hip_bfloat162& operator++(__hip_bfloat162& l) {
627
1006
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
628
1007
  * \brief Operator to post decrement a __hip_bfloat162 number
629
1008
  */
630
- __HOST_DEVICE__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) {
1009
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) {
631
1010
  auto ret = l;
632
1011
  l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
633
1012
  return ret;
@@ -637,7 +1016,7 @@ __HOST_DEVICE__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) {
637
1016
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
638
1017
  * \brief Operator to pre decrement a __hip_bfloat162 number
639
1018
  */
640
- __HOST_DEVICE__ __hip_bfloat162& operator--(__hip_bfloat162& l) {
1019
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator--(__hip_bfloat162& l) {
641
1020
  l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
642
1021
  return l;
643
1022
  }
@@ -646,7 +1025,8 @@ __HOST_DEVICE__ __hip_bfloat162& operator--(__hip_bfloat162& l) {
646
1025
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
647
1026
  * \brief Operator to add-assign two __hip_bfloat162 numbers
648
1027
  */
649
- __HOST_DEVICE__ __hip_bfloat162& operator+=(__hip_bfloat162& l, const __hip_bfloat162& r) {
1028
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator+=(__hip_bfloat162& l,
1029
+ const __hip_bfloat162& r) {
650
1030
  l = __hadd2(l, r);
651
1031
  return l;
652
1032
  }
@@ -655,7 +1035,8 @@ __HOST_DEVICE__ __hip_bfloat162& operator+=(__hip_bfloat162& l, const __hip_bflo
655
1035
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
656
1036
  * \brief Operator to subtract-assign two __hip_bfloat162 numbers
657
1037
  */
658
- __HOST_DEVICE__ __hip_bfloat162& operator-=(__hip_bfloat162& l, const __hip_bfloat162& r) {
1038
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator-=(__hip_bfloat162& l,
1039
+ const __hip_bfloat162& r) {
659
1040
  l = __hsub2(l, r);
660
1041
  return l;
661
1042
  }
@@ -664,7 +1045,8 @@ __HOST_DEVICE__ __hip_bfloat162& operator-=(__hip_bfloat162& l, const __hip_bflo
664
1045
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
665
1046
  * \brief Operator to divide two __hip_bfloat162 numbers
666
1047
  */
667
- __HOST_DEVICE__ __hip_bfloat162 operator/(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1048
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator/(const __hip_bfloat162& l,
1049
+ const __hip_bfloat162& r) {
668
1050
  return __h2div(l, r);
669
1051
  }
670
1052
 
@@ -672,7 +1054,8 @@ __HOST_DEVICE__ __hip_bfloat162 operator/(const __hip_bfloat162& l, const __hip_
672
1054
  * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
673
1055
  * \brief Operator to divide-assign two __hip_bfloat162 numbers
674
1056
  */
675
- __HOST_DEVICE__ __hip_bfloat162& operator/=(__hip_bfloat162& l, const __hip_bfloat162& r) {
1057
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator/=(__hip_bfloat162& l,
1058
+ const __hip_bfloat162& r) {
676
1059
  l = __h2div(l, r);
677
1060
  return l;
678
1061
  }
@@ -681,7 +1064,7 @@ __HOST_DEVICE__ __hip_bfloat162& operator/=(__hip_bfloat162& l, const __hip_bflo
681
1064
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
682
1065
  * \brief Compare two bfloat162 values
683
1066
  */
684
- __HOST_DEVICE__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1067
+ __BF16_HOST_DEVICE_STATIC__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) {
685
1068
  return __bfloat162float(a) == __bfloat162float(b);
686
1069
  }
687
1070
 
@@ -689,7 +1072,7 @@ __HOST_DEVICE__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) {
689
1072
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
690
1073
  * \brief Compare two bfloat162 values - unordered equal
691
1074
  */
692
- __HOST_DEVICE__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1075
+ __BF16_HOST_DEVICE_STATIC__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) {
693
1076
  return !(__bfloat162float(a) < __bfloat162float(b)) &&
694
1077
  !(__bfloat162float(a) > __bfloat162float(b));
695
1078
  }
@@ -698,7 +1081,7 @@ __HOST_DEVICE__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) {
698
1081
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
699
1082
  * \brief Compare two bfloat162 values - greater than
700
1083
  */
701
- __HOST_DEVICE__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1084
+ __BF16_HOST_DEVICE_STATIC__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
702
1085
  return __bfloat162float(a) > __bfloat162float(b);
703
1086
  }
704
1087
 
@@ -706,7 +1089,7 @@ __HOST_DEVICE__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
706
1089
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
707
1090
  * \brief Compare two bfloat162 values - unordered greater than
708
1091
  */
709
- __HOST_DEVICE__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1092
+ __BF16_HOST_DEVICE_STATIC__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
710
1093
  return !(__bfloat162float(a) <= __bfloat162float(b));
711
1094
  }
712
1095
 
@@ -714,7 +1097,7 @@ __HOST_DEVICE__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
714
1097
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
715
1098
  * \brief Compare two bfloat162 values - greater than equal
716
1099
  */
717
- __HOST_DEVICE__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1100
+ __BF16_HOST_DEVICE_STATIC__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) {
718
1101
  return __bfloat162float(a) >= __bfloat162float(b);
719
1102
  }
720
1103
 
@@ -722,7 +1105,7 @@ __HOST_DEVICE__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) {
722
1105
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
723
1106
  * \brief Compare two bfloat162 values - unordered greater than equal
724
1107
  */
725
- __HOST_DEVICE__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1108
+ __BF16_HOST_DEVICE_STATIC__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
726
1109
  return !(__bfloat162float(a) < __bfloat162float(b));
727
1110
  }
728
1111
 
@@ -730,7 +1113,7 @@ __HOST_DEVICE__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
730
1113
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
731
1114
  * \brief Compare two bfloat162 values - not equal
732
1115
  */
733
- __HOST_DEVICE__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1116
+ __BF16_HOST_DEVICE_STATIC__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) {
734
1117
  return __bfloat162float(a) != __bfloat162float(b);
735
1118
  }
736
1119
 
@@ -738,7 +1121,7 @@ __HOST_DEVICE__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) {
738
1121
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
739
1122
  * \brief Compare two bfloat162 values - unordered not equal
740
1123
  */
741
- __HOST_DEVICE__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1124
+ __BF16_HOST_DEVICE_STATIC__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
742
1125
  return !(__bfloat162float(a) == __bfloat162float(b));
743
1126
  }
744
1127
 
@@ -746,7 +1129,7 @@ __HOST_DEVICE__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
746
1129
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
747
1130
  * \brief Compare two bfloat162 values - return max
748
1131
  */
749
- __HOST_DEVICE__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1132
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) {
750
1133
  #if __HIP_DEVICE_COMPILE__
751
1134
  return __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b)));
752
1135
  #else
@@ -758,7 +1141,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat
758
1141
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
759
1142
  * \brief Compare two bfloat162 values - return min
760
1143
  */
761
- __HOST_DEVICE__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1144
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) {
762
1145
  #if __HIP_DEVICE_COMPILE__
763
1146
  return __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b)));
764
1147
  #else
@@ -770,7 +1153,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat
770
1153
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
771
1154
  * \brief Compare two bfloat162 values - less than operator
772
1155
  */
773
- __HOST_DEVICE__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1156
+ __BF16_HOST_DEVICE_STATIC__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
774
1157
  return __bfloat162float(a) < __bfloat162float(b);
775
1158
  }
776
1159
 
@@ -778,7 +1161,7 @@ __HOST_DEVICE__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
778
1161
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
779
1162
  * \brief Compare two bfloat162 values - unordered less than
780
1163
  */
781
- __HOST_DEVICE__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1164
+ __BF16_HOST_DEVICE_STATIC__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
782
1165
  return !(__bfloat162float(a) >= __bfloat162float(b));
783
1166
  }
784
1167
 
@@ -786,7 +1169,7 @@ __HOST_DEVICE__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
786
1169
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
787
1170
  * \brief Compare two bfloat162 values - less than equal
788
1171
  */
789
- __HOST_DEVICE__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1172
+ __BF16_HOST_DEVICE_STATIC__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) {
790
1173
  return __bfloat162float(a) <= __bfloat162float(b);
791
1174
  }
792
1175
 
@@ -794,7 +1177,7 @@ __HOST_DEVICE__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) {
794
1177
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
795
1178
  * \brief Compare two bfloat162 values - unordered less than equal
796
1179
  */
797
- __HOST_DEVICE__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
1180
+ __BF16_HOST_DEVICE_STATIC__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
798
1181
  return !(__bfloat162float(a) > __bfloat162float(b));
799
1182
  }
800
1183
 
@@ -802,208 +1185,282 @@ __HOST_DEVICE__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
802
1185
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
803
1186
  * \brief Checks if number is inf
804
1187
  */
805
- __HOST_DEVICE__ int __hisinf(const __hip_bfloat16 a) {
806
- unsigned short sign = a.data & 0x8000U;
807
- #if __HIP_DEVICE_COMPILE__
808
- int res = __ocml_isinf_f32(__bfloat162float(a));
809
- #else
810
- int res = std::isinf(__bfloat162float(a)) ? 1 : 0;
811
- #endif
812
- return (res == 0) ? res : ((sign != 0U) ? -res : res);
1188
+ __BF16_HOST_DEVICE_STATIC__ int __hisinf(const __hip_bfloat16 a) {
1189
+ __hip_bfloat16_raw hr = a;
1190
+ return !(~hr.x & 0x7f80) && !(hr.x & 0x7f);
813
1191
  }
814
1192
 
815
1193
  /**
816
1194
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
817
1195
  * \brief Checks if number is nan
818
1196
  */
819
- __HOST_DEVICE__ bool __hisnan(const __hip_bfloat16 a) {
820
- #if __HIP_DEVICE_COMPILE__
821
- return __ocml_isnan_f32(__bfloat162float(a));
822
- #else
823
- return std::isnan(__bfloat162float(a));
824
- #endif
1197
+ __BF16_HOST_DEVICE_STATIC__ bool __hisnan(const __hip_bfloat16 a) {
1198
+ __hip_bfloat16_raw hr = a;
1199
+ return !(~hr.x & 0x7f80) && +(hr.x & 0x7f);
825
1200
  }
826
1201
 
827
1202
  /**
828
1203
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
829
1204
  * \brief Checks if two numbers are equal
830
1205
  */
831
- __HOST_DEVICE__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
832
- return __heq(a.x, b.x) && __heq(a.y, b.y);
1206
+ __BF16_HOST_DEVICE_STATIC__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1207
+ __hip_bfloat162_raw hr_a = a;
1208
+ __hip_bfloat162_raw hr_b = b;
1209
+ return __heq(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1210
+ __heq(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
833
1211
  }
834
1212
 
835
1213
  /**
836
1214
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
837
1215
  * \brief Checks if two numbers are equal - unordered
838
1216
  */
839
- __HOST_DEVICE__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
840
- return __hequ(a.x, b.x) && __hequ(a.y, b.y);
1217
+ __BF16_HOST_DEVICE_STATIC__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1218
+ __hip_bfloat162_raw hr_a = a;
1219
+ __hip_bfloat162_raw hr_b = b;
1220
+ return __hequ(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1221
+ __hequ(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
841
1222
  }
842
1223
 
843
1224
  /**
844
1225
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
845
1226
  * \brief Check for a >= b
846
1227
  */
847
- __HOST_DEVICE__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
848
- return __hge(a.x, b.x) && __hge(a.y, b.y);
1228
+ __BF16_HOST_DEVICE_STATIC__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1229
+ __hip_bfloat162_raw hr_a = a;
1230
+ __hip_bfloat162_raw hr_b = b;
1231
+ return __hge(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1232
+ __hge(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
849
1233
  }
850
1234
 
851
1235
  /**
852
1236
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
853
1237
  * \brief Check for a >= b - unordered
854
1238
  */
855
- __HOST_DEVICE__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
856
- return __hgeu(a.x, b.x) && __hgeu(a.y, b.y);
1239
+ __BF16_HOST_DEVICE_STATIC__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1240
+ __hip_bfloat162_raw hr_a = a;
1241
+ __hip_bfloat162_raw hr_b = b;
1242
+ return __hgeu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1243
+ __hgeu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
857
1244
  }
858
1245
 
859
1246
  /**
860
1247
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
861
1248
  * \brief Check for a > b
862
1249
  */
863
- __HOST_DEVICE__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
864
- return __hgt(a.x, b.x) && __hgt(a.y, b.y);
1250
+ __BF16_HOST_DEVICE_STATIC__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1251
+ __hip_bfloat162_raw hr_a = a;
1252
+ __hip_bfloat162_raw hr_b = b;
1253
+ return __hgt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1254
+ __hgt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
865
1255
  }
866
1256
 
867
1257
  /**
868
1258
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
869
1259
  * \brief Check for a > b - unordered
870
1260
  */
871
- __HOST_DEVICE__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
872
- return __hgtu(a.x, b.x) && __hgtu(a.y, b.y);
1261
+ __BF16_HOST_DEVICE_STATIC__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1262
+ __hip_bfloat162_raw hr_a = a;
1263
+ __hip_bfloat162_raw hr_b = b;
1264
+ return __hgtu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1265
+ __hgtu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
873
1266
  }
874
1267
 
875
1268
  /**
876
1269
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
877
1270
  * \brief Check for a <= b
878
1271
  */
879
- __HOST_DEVICE__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
880
- return __hle(a.x, b.x) && __hle(a.y, b.y);
1272
+ __BF16_HOST_DEVICE_STATIC__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1273
+ __hip_bfloat162_raw hr_a = a;
1274
+ __hip_bfloat162_raw hr_b = b;
1275
+ return __hle(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1276
+ __hle(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
881
1277
  }
882
1278
 
883
1279
  /**
884
1280
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
885
1281
  * \brief Check for a <= b - unordered
886
1282
  */
887
- __HOST_DEVICE__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
888
- return __hleu(a.x, b.x) && __hleu(a.y, b.y);
1283
+ __BF16_HOST_DEVICE_STATIC__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1284
+ __hip_bfloat162_raw hr_a = a;
1285
+ __hip_bfloat162_raw hr_b = b;
1286
+ return __hleu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1287
+ __hleu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
889
1288
  }
890
1289
 
891
1290
  /**
892
1291
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
893
1292
  * \brief Check for a < b
894
1293
  */
895
- __HOST_DEVICE__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
896
- return __hlt(a.x, b.x) && __hlt(a.y, b.y);
1294
+ __BF16_HOST_DEVICE_STATIC__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1295
+ __hip_bfloat162_raw hr_a = a;
1296
+ __hip_bfloat162_raw hr_b = b;
1297
+ return __hlt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1298
+ __hlt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
897
1299
  }
898
1300
 
899
1301
  /**
900
1302
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
901
1303
  * \brief Check for a < b - unordered
902
1304
  */
903
- __HOST_DEVICE__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
904
- return __hltu(a.x, b.x) && __hltu(a.y, b.y);
1305
+ __BF16_HOST_DEVICE_STATIC__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1306
+ __hip_bfloat162_raw hr_a = a;
1307
+ __hip_bfloat162_raw hr_b = b;
1308
+ return __hltu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
1309
+ __hltu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
905
1310
  }
906
1311
 
907
1312
  /**
908
1313
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
909
1314
  * \brief Check for a != b
910
1315
  */
911
- __HOST_DEVICE__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
912
- return __hne(a.x, b.x) && __hne(a.y, b.y);
1316
+ __BF16_HOST_DEVICE_STATIC__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1317
+ __hip_bfloat162_raw hr_a = a;
1318
+ __hip_bfloat162_raw hr_b = b;
1319
+ return __hne(__hip_bfloat16(__hip_bfloat16_raw{hr_a.x}),
1320
+ __hip_bfloat16(__hip_bfloat16_raw{hr_b.x})) &&
1321
+ __hne(__hip_bfloat16(__hip_bfloat16_raw{hr_a.y}), __hip_bfloat16(__hip_bfloat16_raw{hr_b.y}));
913
1322
  }
914
1323
 
915
1324
  /**
916
1325
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
917
1326
  * \brief Check for a != b
918
1327
  */
919
- __HOST_DEVICE__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
920
- return __hneu(a.x, b.x) && __hneu(a.y, b.y);
1328
+ __BF16_HOST_DEVICE_STATIC__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
1329
+ __hip_bfloat162_raw hr_a = a;
1330
+ __hip_bfloat162_raw hr_b = b;
1331
+ return __hneu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ||
1332
+ __hneu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
921
1333
  }
922
1334
 
923
1335
  /**
924
1336
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
925
1337
  * \brief Check for a != b, returns 1.0 if equal, otherwise 0.0
926
1338
  */
927
- __HOST_DEVICE__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
928
- return __hip_bfloat162{{__heq(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
929
- {__heq(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
1339
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __heq2(const __hip_bfloat162 a,
1340
+ const __hip_bfloat162 b) {
1341
+ __hip_bfloat162_raw hr_a = a;
1342
+ __hip_bfloat162_raw hr_b = b;
1343
+ return __hip_bfloat162{
1344
+ {__heq(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16
1345
+ : HIPRT_ZERO_BF16},
1346
+ {__heq(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16
1347
+ : HIPRT_ZERO_BF16}};
930
1348
  }
931
1349
 
932
1350
  /**
933
1351
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
934
1352
  * \brief Check for a >= b, returns 1.0 if greater than equal, otherwise 0.0
935
1353
  */
936
- __HOST_DEVICE__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
937
- return __hip_bfloat162{{__hge(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
938
- {__hge(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
1354
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hge2(const __hip_bfloat162 a,
1355
+ const __hip_bfloat162 b) {
1356
+ __hip_bfloat162_raw hr_a = a;
1357
+ __hip_bfloat162_raw hr_b = b;
1358
+ return __hip_bfloat162{
1359
+ {__hge(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16
1360
+ : HIPRT_ZERO_BF16},
1361
+ {__hge(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16
1362
+ : HIPRT_ZERO_BF16}};
939
1363
  }
940
1364
 
941
1365
  /**
942
1366
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
943
1367
  * \brief Check for a > b, returns 1.0 if greater than equal, otherwise 0.0
944
1368
  */
945
- __HOST_DEVICE__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
946
- return __hip_bfloat162{{__hgt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
947
- {__hgt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}};
1369
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a,
1370
+ const __hip_bfloat162 b) {
1371
+ __hip_bfloat162_raw hr_a = a;
1372
+ __hip_bfloat162_raw hr_b = b;
1373
+ return __hip_bfloat162{
1374
+ {__hgt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16
1375
+ : HIPRT_ZERO_BF16},
1376
+ {__hgt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16
1377
+ : HIPRT_ONE_BF16}};
948
1378
  }
949
1379
 
950
1380
  /**
951
1381
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
952
1382
  * \brief Check for a is NaN, returns 1.0 if NaN, otherwise 0.0
953
1383
  */
954
- __HOST_DEVICE__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) {
955
- return __hip_bfloat162{{__hisnan(a.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
956
- {__hisnan(a.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}};
1384
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) {
1385
+ __hip_bfloat162_raw hr_a = a;
1386
+ return __hip_bfloat162{{__hisnan(__hip_bfloat16_raw{hr_a.x}) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
1387
+ {__hisnan(__hip_bfloat16_raw{hr_a.y}) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}};
957
1388
  }
958
1389
 
959
1390
  /**
960
1391
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
961
1392
  * \brief Check for a <= b, returns 1.0 if greater than equal, otherwise 0.0
962
1393
  */
963
- __HOST_DEVICE__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
964
- return __hip_bfloat162{{__hle(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
965
- {__hle(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
1394
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hle2(const __hip_bfloat162 a,
1395
+ const __hip_bfloat162 b) {
1396
+ __hip_bfloat162_raw hr_a = a;
1397
+ __hip_bfloat162_raw hr_b = b;
1398
+ return __hip_bfloat162{
1399
+ {__hle(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16
1400
+ : HIPRT_ZERO_BF16},
1401
+ {__hle(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16
1402
+ : HIPRT_ZERO_BF16}};
966
1403
  }
967
1404
 
968
1405
  /**
969
1406
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
970
1407
  * \brief Check for a < b, returns 1.0 if greater than equal, otherwise 0.0
971
1408
  */
972
- __HOST_DEVICE__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
973
- return __hip_bfloat162{{__hlt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
974
- {__hlt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
1409
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a,
1410
+ const __hip_bfloat162 b) {
1411
+ __hip_bfloat162_raw hr_a = a;
1412
+ __hip_bfloat162_raw hr_b = b;
1413
+ return __hip_bfloat162{
1414
+ {__hlt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16
1415
+ : HIPRT_ZERO_BF16},
1416
+ {__hlt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16
1417
+ : HIPRT_ZERO_BF16}};
975
1418
  }
976
1419
 
977
1420
  /**
978
1421
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
979
1422
  * \brief Returns max of two elements
980
1423
  */
981
- __HOST_DEVICE__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
982
- return __hip_bfloat162{__hmax(a.x, b.x), __hmax(a.y, b.y)};
1424
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a,
1425
+ const __hip_bfloat162 b) {
1426
+ __hip_bfloat162_raw hr_a = a;
1427
+ __hip_bfloat162_raw hr_b = b;
1428
+ return __hip_bfloat162(__hmax(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}),
1429
+ __hmax(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}));
983
1430
  }
984
1431
 
985
1432
  /**
986
1433
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
987
1434
  * \brief Returns min of two elements
988
1435
  */
989
- __HOST_DEVICE__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
990
- return __hip_bfloat162{__hmin(a.x, b.x), __hmin(a.y, b.y)};
1436
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a,
1437
+ const __hip_bfloat162 b) {
1438
+ __hip_bfloat162_raw hr_a = a;
1439
+ __hip_bfloat162_raw hr_b = b;
1440
+ return __hip_bfloat162(__hmin(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}),
1441
+ __hmin(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}));
991
1442
  }
992
1443
 
993
1444
  /**
994
1445
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
995
1446
  * \brief Checks for not equal to
996
1447
  */
997
- __HOST_DEVICE__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
998
- return __hip_bfloat162{{__hne(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
999
- {__hne(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
1448
+ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hne2(const __hip_bfloat162 a,
1449
+ const __hip_bfloat162 b) {
1450
+ __hip_bfloat162_raw hr_a = a;
1451
+ __hip_bfloat162_raw hr_b = b;
1452
+ return __hip_bfloat162{
1453
+ {__hne(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16
1454
+ : HIPRT_ZERO_BF16},
1455
+ {__hne(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16
1456
+ : HIPRT_ZERO_BF16}};
1000
1457
  }
1001
1458
 
1002
1459
  /**
1003
1460
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
1004
1461
  * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers
1005
1462
  */
1006
- __HOST_DEVICE__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1463
+ __BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1007
1464
  return __heq(l, r);
1008
1465
  }
1009
1466
 
@@ -1011,7 +1468,7 @@ __HOST_DEVICE__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r
1011
1468
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
1012
1469
  * \brief Operator to perform a not equal on two __hip_bfloat16 numbers
1013
1470
  */
1014
- __HOST_DEVICE__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1471
+ __BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1015
1472
  return __hne(l, r);
1016
1473
  }
1017
1474
 
@@ -1019,7 +1476,7 @@ __HOST_DEVICE__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r
1019
1476
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
1020
1477
  * \brief Operator to perform a less than on two __hip_bfloat16 numbers
1021
1478
  */
1022
- __HOST_DEVICE__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1479
+ __BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1023
1480
  return __hlt(l, r);
1024
1481
  }
1025
1482
 
@@ -1027,7 +1484,7 @@ __HOST_DEVICE__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r)
1027
1484
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
1028
1485
  * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers
1029
1486
  */
1030
- __HOST_DEVICE__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1487
+ __BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1031
1488
  return __hle(l, r);
1032
1489
  }
1033
1490
 
@@ -1035,7 +1492,7 @@ __HOST_DEVICE__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r
1035
1492
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
1036
1493
  * \brief Operator to perform a greater than on two __hip_bfloat16 numbers
1037
1494
  */
1038
- __HOST_DEVICE__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1495
+ __BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1039
1496
  return __hgt(l, r);
1040
1497
  }
1041
1498
 
@@ -1043,7 +1500,7 @@ __HOST_DEVICE__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r)
1043
1500
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
1044
1501
  * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers
1045
1502
  */
1046
- __HOST_DEVICE__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1503
+ __BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
1047
1504
  return __hge(l, r);
1048
1505
  }
1049
1506
 
@@ -1051,55 +1508,60 @@ __HOST_DEVICE__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r
1051
1508
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
1052
1509
  * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers
1053
1510
  */
1054
- __HOST_DEVICE__ bool operator==(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1055
- return __heq(l.x, r.x) && __heq(l.y, r.y);
1511
+ __BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1512
+ float2 ret = __heq2(l, r);
1513
+ return ret.x != 0.0f && ret.y != 0.0f;
1056
1514
  }
1057
1515
 
1058
1516
  /**
1059
1517
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
1060
1518
  * \brief Operator to perform a not equal on two __hip_bfloat16 numbers
1061
1519
  */
1062
- __HOST_DEVICE__ bool operator!=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1063
- return __hne(l.x, r.x) || __hne(l.y, r.y);
1520
+ __BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1521
+ return !(l == r);
1064
1522
  }
1065
1523
 
1066
1524
  /**
1067
1525
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
1068
1526
  * \brief Operator to perform a less than on two __hip_bfloat16 numbers
1069
1527
  */
1070
- __HOST_DEVICE__ bool operator<(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1071
- return __hlt(l.x, r.x) && __hlt(l.y, r.y);
1528
+ __BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1529
+ float2 fl = l, fr = r;
1530
+ return fl.x < fr.x && fl.x < fr.y;
1072
1531
  }
1073
1532
 
1074
1533
  /**
1075
1534
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
1076
1535
  * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers
1077
1536
  */
1078
- __HOST_DEVICE__ bool operator<=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1079
- return __hle(l.x, r.x) && __hle(l.y, r.y);
1537
+ __BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1538
+ float2 fl = l, fr = r;
1539
+ return fl.x <= fr.x && fl.x <= fr.y;
1080
1540
  }
1081
1541
 
1082
1542
  /**
1083
1543
  * \ingroup HIP_INTRINSIC_BFLOAT162_COMP
1084
1544
  * \brief Operator to perform a greater than on two __hip_bfloat16 numbers
1085
1545
  */
1086
- __HOST_DEVICE__ bool operator>(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1087
- return __hgt(l.x, r.x) && __hgt(l.y, r.y);
1546
+ __BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1547
+ float2 fl = l, fr = r;
1548
+ return fl.x > fr.x && fl.x > fr.y;
1088
1549
  }
1089
1550
 
1090
1551
  /**
1091
1552
  * \ingroup HIP_INTRINSIC_BFLOAT16_COMP
1092
1553
  * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers
1093
1554
  */
1094
- __HOST_DEVICE__ bool operator>=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1095
- return __hge(l.x, r.x) && __hge(l.y, r.y);
1555
+ __BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
1556
+ float2 fl = l, fr = r;
1557
+ return fl.x >= fr.x && fl.x >= fr.y;
1096
1558
  }
1097
1559
 
1098
1560
  /**
1099
1561
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1100
1562
  * \brief Calculate ceil of bfloat16
1101
1563
  */
1102
- __device__ __hip_bfloat16 hceil(const __hip_bfloat16 h) {
1564
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hceil(const __hip_bfloat16 h) {
1103
1565
  return __float2bfloat16(__ocml_ceil_f32(__bfloat162float(h)));
1104
1566
  }
1105
1567
 
@@ -1107,7 +1569,7 @@ __device__ __hip_bfloat16 hceil(const __hip_bfloat16 h) {
1107
1569
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1108
1570
  * \brief Calculate cosine of bfloat16
1109
1571
  */
1110
- __device__ __hip_bfloat16 hcos(const __hip_bfloat16 h) {
1572
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hcos(const __hip_bfloat16 h) {
1111
1573
  return __float2bfloat16(__ocml_cos_f32(__bfloat162float(h)));
1112
1574
  }
1113
1575
 
@@ -1115,7 +1577,7 @@ __device__ __hip_bfloat16 hcos(const __hip_bfloat16 h) {
1115
1577
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1116
1578
  * \brief Calculate exponential of bfloat16
1117
1579
  */
1118
- __device__ __hip_bfloat16 hexp(const __hip_bfloat16 h) {
1580
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hexp(const __hip_bfloat16 h) {
1119
1581
  return __float2bfloat16(__ocml_exp_f32(__bfloat162float(h)));
1120
1582
  }
1121
1583
 
@@ -1123,7 +1585,7 @@ __device__ __hip_bfloat16 hexp(const __hip_bfloat16 h) {
1123
1585
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1124
1586
  * \brief Calculate exponential 10 of bfloat16
1125
1587
  */
1126
- __device__ __hip_bfloat16 hexp10(const __hip_bfloat16 h) {
1588
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hexp10(const __hip_bfloat16 h) {
1127
1589
  return __float2bfloat16(__ocml_exp10_f32(__bfloat162float(h)));
1128
1590
  }
1129
1591
 
@@ -1131,7 +1593,7 @@ __device__ __hip_bfloat16 hexp10(const __hip_bfloat16 h) {
1131
1593
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1132
1594
  * \brief Calculate exponential 2 of bfloat16
1133
1595
  */
1134
- __device__ __hip_bfloat16 hexp2(const __hip_bfloat16 h) {
1596
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hexp2(const __hip_bfloat16 h) {
1135
1597
  return __float2bfloat16(__ocml_exp2_f32(__bfloat162float(h)));
1136
1598
  }
1137
1599
 
@@ -1139,7 +1601,7 @@ __device__ __hip_bfloat16 hexp2(const __hip_bfloat16 h) {
1139
1601
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1140
1602
  * \brief Calculate floor of bfloat16
1141
1603
  */
1142
- __device__ __hip_bfloat16 hfloor(const __hip_bfloat16 h) {
1604
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hfloor(const __hip_bfloat16 h) {
1143
1605
  return __float2bfloat16(__ocml_floor_f32(__bfloat162float(h)));
1144
1606
  }
1145
1607
 
@@ -1147,7 +1609,7 @@ __device__ __hip_bfloat16 hfloor(const __hip_bfloat16 h) {
1147
1609
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1148
1610
  * \brief Calculate natural log of bfloat16
1149
1611
  */
1150
- __device__ __hip_bfloat16 hlog(const __hip_bfloat16 h) {
1612
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hlog(const __hip_bfloat16 h) {
1151
1613
  return __float2bfloat16(__ocml_log_f32(__bfloat162float(h)));
1152
1614
  }
1153
1615
 
@@ -1155,7 +1617,7 @@ __device__ __hip_bfloat16 hlog(const __hip_bfloat16 h) {
1155
1617
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1156
1618
  * \brief Calculate log 10 of bfloat16
1157
1619
  */
1158
- __device__ __hip_bfloat16 hlog10(const __hip_bfloat16 h) {
1620
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hlog10(const __hip_bfloat16 h) {
1159
1621
  return __float2bfloat16(__ocml_log10_f32(__bfloat162float(h)));
1160
1622
  }
1161
1623
 
@@ -1163,7 +1625,7 @@ __device__ __hip_bfloat16 hlog10(const __hip_bfloat16 h) {
1163
1625
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1164
1626
  * \brief Calculate log 2 of bfloat16
1165
1627
  */
1166
- __device__ __hip_bfloat16 hlog2(const __hip_bfloat16 h) {
1628
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hlog2(const __hip_bfloat16 h) {
1167
1629
  return __float2bfloat16(__ocml_log2_f32(__bfloat162float(h)));
1168
1630
  }
1169
1631
 
@@ -1171,7 +1633,7 @@ __device__ __hip_bfloat16 hlog2(const __hip_bfloat16 h) {
1171
1633
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1172
1634
  * \brief Calculate reciprocal
1173
1635
  */
1174
- __device__ __hip_bfloat16 hrcp(const __hip_bfloat16 h) {
1636
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hrcp(const __hip_bfloat16 h) {
1175
1637
  return __float2bfloat16(1.0f / (__bfloat162float(h)));
1176
1638
  }
1177
1639
 
@@ -1179,7 +1641,7 @@ __device__ __hip_bfloat16 hrcp(const __hip_bfloat16 h) {
1179
1641
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1180
1642
  * \brief Round to nearest int
1181
1643
  */
1182
- __device__ __hip_bfloat16 hrint(const __hip_bfloat16 h) {
1644
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hrint(const __hip_bfloat16 h) {
1183
1645
  return __float2bfloat16(__ocml_rint_f32(__bfloat162float(h)));
1184
1646
  }
1185
1647
 
@@ -1187,7 +1649,7 @@ __device__ __hip_bfloat16 hrint(const __hip_bfloat16 h) {
1187
1649
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1188
1650
  * \brief Reciprocal square root
1189
1651
  */
1190
- __device__ __hip_bfloat16 hrsqrt(const __hip_bfloat16 h) {
1652
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hrsqrt(const __hip_bfloat16 h) {
1191
1653
  return __float2bfloat16(__ocml_rsqrt_f32(__bfloat162float(h)));
1192
1654
  }
1193
1655
 
@@ -1195,7 +1657,7 @@ __device__ __hip_bfloat16 hrsqrt(const __hip_bfloat16 h) {
1195
1657
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1196
1658
  * \brief Calculate sin of bfloat16
1197
1659
  */
1198
- __device__ __hip_bfloat16 hsin(const __hip_bfloat16 h) {
1660
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hsin(const __hip_bfloat16 h) {
1199
1661
  return __float2bfloat16(__ocml_sin_f32(__bfloat162float(h)));
1200
1662
  }
1201
1663
 
@@ -1203,7 +1665,7 @@ __device__ __hip_bfloat16 hsin(const __hip_bfloat16 h) {
1203
1665
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1204
1666
  * \brief Calculate sqrt of bfloat16
1205
1667
  */
1206
- __device__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) {
1668
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) {
1207
1669
  return __float2bfloat16(__ocml_sqrt_f32(__bfloat162float(h)));
1208
1670
  }
1209
1671
 
@@ -1211,7 +1673,7 @@ __device__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) {
1211
1673
  * \ingroup HIP_INTRINSIC_BFLOAT16_MATH
1212
1674
  * \brief Calculate truncate of bfloat16
1213
1675
  */
1214
- __device__ __hip_bfloat16 htrunc(const __hip_bfloat16 h) {
1676
+ __BF16_DEVICE_STATIC__ __hip_bfloat16 htrunc(const __hip_bfloat16 h) {
1215
1677
  return __float2bfloat16(__ocml_trunc_f32(__bfloat162float(h)));
1216
1678
  }
1217
1679
 
@@ -1219,119 +1681,134 @@ __device__ __hip_bfloat16 htrunc(const __hip_bfloat16 h) {
1219
1681
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1220
1682
  * \brief Calculate ceil of bfloat162
1221
1683
  */
1222
- __device__ __hip_bfloat162 h2ceil(const __hip_bfloat162 h) {
1223
- return __hip_bfloat162{hceil(h.x), hceil(h.y)};
1684
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2ceil(const __hip_bfloat162 h) {
1685
+ __hip_bfloat162_raw hr = h;
1686
+ return __hip_bfloat162(hceil(__hip_bfloat16_raw{hr.x}), hceil(__hip_bfloat16_raw{hr.y}));
1224
1687
  }
1225
1688
 
1226
1689
  /**
1227
1690
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1228
1691
  * \brief Calculate cosine of bfloat162
1229
1692
  */
1230
- __device__ __hip_bfloat162 h2cos(const __hip_bfloat162 h) {
1231
- return __hip_bfloat162{hcos(h.x), hcos(h.y)};
1693
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2cos(const __hip_bfloat162 h) {
1694
+ __hip_bfloat162_raw hr = h;
1695
+ return __hip_bfloat162(hcos(__hip_bfloat16_raw{hr.x}), hcos(__hip_bfloat16_raw{hr.y}));
1232
1696
  }
1233
1697
 
1234
1698
  /**
1235
1699
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1236
1700
  * \brief Calculate exponential of bfloat162
1237
1701
  */
1238
- __device__ __hip_bfloat162 h2exp(const __hip_bfloat162 h) {
1239
- return __hip_bfloat162{hexp(h.x), hexp(h.y)};
1702
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp(const __hip_bfloat162 h) {
1703
+ __hip_bfloat162_raw hr = h;
1704
+ return __hip_bfloat162(hexp(__hip_bfloat16_raw{hr.x}), hexp(__hip_bfloat16_raw{hr.y}));
1240
1705
  }
1241
1706
 
1242
1707
  /**
1243
1708
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1244
1709
  * \brief Calculate exponential 10 of bfloat162
1245
1710
  */
1246
- __device__ __hip_bfloat162 h2exp10(const __hip_bfloat162 h) {
1247
- return __hip_bfloat162{hexp10(h.x), hexp10(h.y)};
1711
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp10(const __hip_bfloat162 h) {
1712
+ __hip_bfloat162_raw hr = h;
1713
+ return __hip_bfloat162(hexp10(__hip_bfloat16_raw{hr.x}), hexp10(__hip_bfloat16_raw{hr.y}));
1248
1714
  }
1249
1715
 
1250
1716
  /**
1251
1717
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1252
1718
  * \brief Calculate exponential 2 of bfloat162
1253
1719
  */
1254
- __device__ __hip_bfloat162 h2exp2(const __hip_bfloat162 h) {
1255
- return __hip_bfloat162{hexp2(h.x), hexp2(h.y)};
1720
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp2(const __hip_bfloat162 h) {
1721
+ __hip_bfloat162_raw hr = h;
1722
+ return __hip_bfloat162(hexp2(__hip_bfloat16_raw{hr.x}), hexp2(__hip_bfloat16_raw{hr.y}));
1256
1723
  }
1257
1724
 
1258
1725
  /**
1259
1726
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1260
1727
  * \brief Calculate floor of bfloat162
1261
1728
  */
1262
- __device__ __hip_bfloat162 h2floor(const __hip_bfloat162 h) {
1263
- return __hip_bfloat162{hfloor(h.x), hfloor(h.y)};
1729
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2floor(const __hip_bfloat162 h) {
1730
+ __hip_bfloat162_raw hr = h;
1731
+ return __hip_bfloat162(hfloor(__hip_bfloat16_raw{hr.x}), hfloor(__hip_bfloat16_raw{hr.y}));
1264
1732
  }
1265
1733
 
1266
1734
  /**
1267
1735
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1268
1736
  * \brief Calculate natural log of bfloat162
1269
1737
  */
1270
- __device__ __hip_bfloat162 h2log(const __hip_bfloat162 h) {
1271
- return __hip_bfloat162{hlog(h.x), hlog(h.y)};
1738
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2log(const __hip_bfloat162 h) {
1739
+ __hip_bfloat162_raw hr = h;
1740
+ return __hip_bfloat162(hlog(__hip_bfloat16_raw{hr.x}), hlog(__hip_bfloat16_raw{hr.y}));
1272
1741
  }
1273
1742
 
1274
1743
  /**
1275
1744
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1276
1745
  * \brief Calculate log 10 of bfloat162
1277
1746
  */
1278
- __device__ __hip_bfloat162 h2log10(const __hip_bfloat162 h) {
1279
- return __hip_bfloat162{hlog10(h.x), hlog10(h.y)};
1747
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2log10(const __hip_bfloat162 h) {
1748
+ __hip_bfloat162_raw hr = h;
1749
+ return __hip_bfloat162(hlog10(__hip_bfloat16_raw{hr.x}), hlog10(__hip_bfloat16_raw{hr.y}));
1280
1750
  }
1281
1751
 
1282
1752
  /**
1283
1753
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1284
1754
  * \brief Calculate log 2 of bfloat162
1285
1755
  */
1286
- __device__ __hip_bfloat162 h2log2(const __hip_bfloat162 h) {
1287
- return __hip_bfloat162{hlog2(h.x), hlog2(h.y)};
1756
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2log2(const __hip_bfloat162 h) {
1757
+ __hip_bfloat162_raw hr = h;
1758
+ return __hip_bfloat162(hlog2(__hip_bfloat16_raw{hr.x}), hlog2(__hip_bfloat16_raw{hr.y}));
1288
1759
  }
1289
1760
 
1290
1761
  /**
1291
1762
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1292
1763
  * \brief Calculate vector reciprocal
1293
1764
  */
1294
- __device__ __hip_bfloat162 h2rcp(const __hip_bfloat162 h) {
1295
- return __hip_bfloat162{hrcp(h.x), hrcp(h.y)};
1765
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2rcp(const __hip_bfloat162 h) {
1766
+ __hip_bfloat162_raw hr = h;
1767
+ return __hip_bfloat162(hrcp(__hip_bfloat16_raw{hr.x}), hrcp(__hip_bfloat16_raw{hr.y}));
1296
1768
  }
1297
1769
 
1298
1770
  /**
1299
1771
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1300
1772
  * \brief Calculate vector round to nearest int
1301
1773
  */
1302
- __device__ __hip_bfloat162 h2rint(const __hip_bfloat162 h) {
1303
- return __hip_bfloat162{hrint(h.x), hrint(h.y)};
1774
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2rint(const __hip_bfloat162 h) {
1775
+ __hip_bfloat162_raw hr = h;
1776
+ return __hip_bfloat162(hrint(__hip_bfloat16_raw{hr.x}), hrint(__hip_bfloat16_raw{hr.y}));
1304
1777
  }
1305
1778
 
1306
1779
  /**
1307
1780
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1308
1781
  * \brief Calculate vector reciprocal square root
1309
1782
  */
1310
- __device__ __hip_bfloat162 h2rsqrt(const __hip_bfloat162 h) {
1311
- return __hip_bfloat162{hrsqrt(h.x), hrsqrt(h.y)};
1783
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2rsqrt(const __hip_bfloat162 h) {
1784
+ __hip_bfloat162_raw hr = h;
1785
+ return __hip_bfloat162(hrsqrt(__hip_bfloat16_raw{hr.x}), hrsqrt(__hip_bfloat16_raw{hr.y}));
1312
1786
  }
1313
1787
 
1314
1788
  /**
1315
1789
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1316
1790
  * \brief Calculate sin of bfloat162
1317
1791
  */
1318
- __device__ __hip_bfloat162 h2sin(const __hip_bfloat162 h) {
1319
- return __hip_bfloat162{hsin(h.x), hsin(h.y)};
1792
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2sin(const __hip_bfloat162 h) {
1793
+ __hip_bfloat162_raw hr = h;
1794
+ return __hip_bfloat162(hsin(__hip_bfloat16_raw{hr.x}), hsin(__hip_bfloat16_raw{hr.y}));
1320
1795
  }
1321
1796
 
1322
1797
  /**
1323
1798
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1324
1799
  * \brief Calculate sqrt of bfloat162
1325
1800
  */
1326
- __device__ __hip_bfloat162 h2sqrt(const __hip_bfloat162 h) {
1327
- return __hip_bfloat162{hsqrt(h.x), hsqrt(h.y)};
1801
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2sqrt(const __hip_bfloat162 h) {
1802
+ __hip_bfloat162_raw hr = h;
1803
+ return __hip_bfloat162(hsqrt(__hip_bfloat16_raw{hr.x}), hsqrt(__hip_bfloat16_raw{hr.y}));
1328
1804
  }
1329
1805
 
1330
1806
  /**
1331
1807
  * \ingroup HIP_INTRINSIC_BFLOAT162_MATH
1332
1808
  * \brief Calculate truncate of bfloat162
1333
1809
  */
1334
- __device__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) {
1335
- return __hip_bfloat162{htrunc(h.x), htrunc(h.y)};
1810
+ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) {
1811
+ __hip_bfloat162_raw hr = h;
1812
+ return __hip_bfloat162(htrunc(__hip_bfloat16_raw{hr.x}), htrunc(__hip_bfloat16_raw{hr.y}));
1336
1813
  }
1337
1814
  #endif