triton-windows 3.2.0.post12__cp312-cp312-win_amd64.whl → 3.3.0a0.post12__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
/**
|
|
2
2
|
* MIT License
|
|
3
3
|
*
|
|
4
|
-
* Copyright (c) 2019 -
|
|
4
|
+
* Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. All rights reserved.
|
|
5
5
|
*
|
|
6
6
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
7
|
* of this software and associated documentation files (the "Software"), to deal
|
|
@@ -81,6 +81,17 @@
|
|
|
81
81
|
* To use these functions, include the header file \p hip_bf16.h in your program.
|
|
82
82
|
*/
|
|
83
83
|
|
|
84
|
+
/**
|
|
85
|
+
* \defgroup HIP_INTRINSIC_BFLOAT16_RAW Bfloat16 Raw Struct
|
|
86
|
+
* \ingroup HIP_INTRINSIC_BFLOAT16
|
|
87
|
+
* To use these functions, include the header file \p hip_bf16.h in your program.
|
|
88
|
+
*/
|
|
89
|
+
|
|
90
|
+
/**
|
|
91
|
+
* \defgroup HIP_INTRINSIC_BFLOAT162_RAW Bfloat162 Raw Struct
|
|
92
|
+
* \ingroup HIP_INTRINSIC_BFLOAT16
|
|
93
|
+
* To use these functions, include the header file \p hip_bf16.h in your program.
|
|
94
|
+
*/
|
|
84
95
|
|
|
85
96
|
#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_
|
|
86
97
|
#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_
|
|
@@ -93,13 +104,30 @@
|
|
|
93
104
|
#include "device_library_decls.h" // ocml conversion functions
|
|
94
105
|
#include "math_fwd.h" // ocml device functions
|
|
95
106
|
|
|
107
|
+
#define __BF16_DEVICE__ __device__
|
|
96
108
|
#if defined(__HIPCC_RTC__)
|
|
97
|
-
#define
|
|
109
|
+
#define __BF16_HOST_DEVICE__ __BF16_DEVICE__
|
|
98
110
|
#else
|
|
99
111
|
#include <algorithm>
|
|
100
112
|
#include <climits>
|
|
101
113
|
#include <cmath>
|
|
102
|
-
#define
|
|
114
|
+
#define __BF16_HOST_DEVICE__ __host__ __BF16_DEVICE__
|
|
115
|
+
#endif
|
|
116
|
+
#define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline
|
|
117
|
+
#define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline
|
|
118
|
+
|
|
119
|
+
#if defined(__AVX512VL__) and defined(__AVX512BF16__) and not defined(__HIP_DEVICE_COMPILE__)
|
|
120
|
+
// Enable with -mavx512vl -mavx512bf16
|
|
121
|
+
#if defined(__MINGW64__)
|
|
122
|
+
#include <intrin.h>
|
|
123
|
+
#else
|
|
124
|
+
#include <immintrin.h>
|
|
125
|
+
#endif
|
|
126
|
+
#define HIP_BF16_AVX512_OP 1
|
|
127
|
+
static_assert(sizeof(__bf16) == sizeof(unsigned short),
|
|
128
|
+
"sizeof __bf16 should match sizeof unsigned short");
|
|
129
|
+
#else
|
|
130
|
+
#define HIP_BF16_AVX512_OP 0
|
|
103
131
|
#endif
|
|
104
132
|
|
|
105
133
|
#define HIPRT_ONE_BF16 __float2bfloat16(1.0f)
|
|
@@ -118,72 +146,361 @@ static_assert(CHAR_BIT == 8, "byte size should be of 8 bits");
|
|
|
118
146
|
#endif
|
|
119
147
|
static_assert(sizeof(unsigned short) == 2, "size of unsigned short should be 2 bytes");
|
|
120
148
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
149
|
+
/**
|
|
150
|
+
* \ingroup HIP_INTRINSIC_BFLOAT16_RAW
|
|
151
|
+
* \brief represents raw bfloat16 type
|
|
152
|
+
*/
|
|
153
|
+
typedef struct __attribute__((aligned(2))) {
|
|
154
|
+
unsigned short x;
|
|
155
|
+
} __hip_bfloat16_raw;
|
|
156
|
+
|
|
157
|
+
/**
|
|
158
|
+
* \ingroup HIP_INTRINSIC_BFLOAT162_RAW
|
|
159
|
+
* \brief represents raw bfloat16x2 vector type
|
|
160
|
+
*/
|
|
161
|
+
typedef struct __attribute__((aligned(4))) {
|
|
162
|
+
unsigned short x;
|
|
163
|
+
unsigned short y;
|
|
164
|
+
} __hip_bfloat162_raw;
|
|
165
|
+
|
|
166
|
+
/**
|
|
167
|
+
* \defgroup HIP_INTRINSIC_BFLOAT16_STRUCT
|
|
168
|
+
* \ingroup HIP_INTRINSIC_BFLOAT16
|
|
169
|
+
* \brief Struct to represent a 16 bit brain floating point number.
|
|
170
|
+
* @{
|
|
171
|
+
*/
|
|
172
|
+
struct __attribute__((aligned(2))) __hip_bfloat16 {
|
|
173
|
+
private:
|
|
174
|
+
__BF16_HOST_DEVICE_STATIC__ float bfloatraw_2_float(unsigned short val) {
|
|
175
|
+
#if HIP_BF16_AVX512_OP
|
|
176
|
+
union {
|
|
177
|
+
unsigned short us;
|
|
178
|
+
__bf16 bf16;
|
|
179
|
+
} u = {val};
|
|
180
|
+
return _mm_cvtsbh_ss(u.bf16);
|
|
181
|
+
#else
|
|
182
|
+
unsigned int uval = val << 16;
|
|
183
|
+
union {
|
|
184
|
+
unsigned int u32;
|
|
185
|
+
float fp32;
|
|
186
|
+
} u = {uval};
|
|
187
|
+
return u.fp32;
|
|
188
|
+
#endif
|
|
189
|
+
}
|
|
190
|
+
__BF16_HOST_DEVICE_STATIC__ unsigned short float_2_bfloatraw(float f) {
|
|
191
|
+
#if HIP_BF16_AVX512_OP
|
|
192
|
+
union {
|
|
193
|
+
__bf16 bf16;
|
|
194
|
+
unsigned short us;
|
|
195
|
+
} u = {_mm_cvtness_sbh(f)};
|
|
196
|
+
return u.us;
|
|
197
|
+
#else
|
|
198
|
+
union {
|
|
199
|
+
float fp32;
|
|
200
|
+
unsigned int u32;
|
|
201
|
+
} u = {f};
|
|
202
|
+
if (~u.u32 & 0x7f800000) {
|
|
203
|
+
// When the exponent bits are not all 1s, then the value is zero, normal,
|
|
204
|
+
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
|
205
|
+
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
|
206
|
+
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
|
207
|
+
// least significant bits of the float mantissa are greater than 0x8000,
|
|
208
|
+
// or if they are equal to 0x8000 and the least significant bit of the
|
|
209
|
+
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
|
210
|
+
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
|
211
|
+
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
|
212
|
+
// the exponent is incremented by one, which is the next higher FP value
|
|
213
|
+
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
|
214
|
+
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
|
|
215
|
+
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
|
216
|
+
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
|
217
|
+
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
|
218
|
+
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
|
219
|
+
u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // Round to nearest, round to even
|
|
220
|
+
} else if (u.u32 & 0xffff) {
|
|
221
|
+
// When all of the exponent bits are 1, the value is Inf or NaN.
|
|
222
|
+
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
|
223
|
+
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
|
224
|
+
// bit being 1. Signaling NaN is indicated by the most significant
|
|
225
|
+
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
|
226
|
+
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
|
227
|
+
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
|
228
|
+
// the bloat16's mantissa bits are all 0.
|
|
229
|
+
u.u32 |= 0x10000; // Preserve signaling NaN
|
|
230
|
+
}
|
|
231
|
+
return static_cast<unsigned short>(u.u32 >> 16);
|
|
232
|
+
#endif
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
__BF16_HOST_DEVICE_STATIC__ unsigned short double_2_bfloatraw(double d_in) {
|
|
236
|
+
union {
|
|
237
|
+
float fp32;
|
|
238
|
+
unsigned int u32;
|
|
239
|
+
} u = {static_cast<float>(d_in)};
|
|
240
|
+
double d = u.fp32;
|
|
241
|
+
|
|
242
|
+
// Round to odd
|
|
243
|
+
if ((d_in > 0.0 && d > d_in) || (d_in < 0.0 && d < d_in)) {
|
|
244
|
+
u.u32--;
|
|
245
|
+
u.u32 |= 1;
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
return float_2_bfloatraw(u.fp32);
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
protected:
|
|
252
|
+
/*! \brief raw representation of bfloat16 */
|
|
253
|
+
unsigned short __x;
|
|
254
|
+
|
|
255
|
+
public:
|
|
256
|
+
// TODO: SWDEV-452411
|
|
257
|
+
// Need to add constructor of __hip_bfloat16 from
|
|
258
|
+
// unsigned long long
|
|
259
|
+
// long long
|
|
260
|
+
// long
|
|
261
|
+
// unsigned long
|
|
262
|
+
// Casting directly to double might lead to double rounding.
|
|
263
|
+
|
|
264
|
+
/*! \brief create __hip_bfloat16 from an unsigned int */
|
|
265
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16(unsigned int val)
|
|
266
|
+
: __x(double_2_bfloatraw(static_cast<double>(val))) {}
|
|
267
|
+
|
|
268
|
+
/*! \brief create __hip_bfloat16 from a int */
|
|
269
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16(int val)
|
|
270
|
+
: __x(double_2_bfloatraw(static_cast<double>(val))) {}
|
|
271
|
+
|
|
272
|
+
/*! \brief create __hip_bfloat16 from an unsigned short */
|
|
273
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16(unsigned short val)
|
|
274
|
+
: __x(float_2_bfloatraw(static_cast<float>(val))) {}
|
|
275
|
+
|
|
276
|
+
/*! \brief create __hip_bfloat16 from a short */
|
|
277
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16(short val)
|
|
278
|
+
: __x(float_2_bfloatraw(static_cast<float>(val))) {}
|
|
279
|
+
|
|
280
|
+
/*! \brief create __hip_bfloat16 from a double */
|
|
281
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16(const double val) : __x(double_2_bfloatraw(val)) {}
|
|
282
|
+
|
|
283
|
+
/*! \brief create __hip_bfloat16 from a float */
|
|
284
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16(const float val) : __x(float_2_bfloatraw(val)) {}
|
|
285
|
+
|
|
286
|
+
/*! \brief create __hip_bfloat16 from a __hip_bfloat16_raw */
|
|
287
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16(const __hip_bfloat16_raw& val) : __x(val.x) {}
|
|
288
|
+
|
|
289
|
+
/*! \brief default constructor */
|
|
290
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16() = default;
|
|
291
|
+
|
|
292
|
+
/*! \brief return a __hip_bfloat16_raw */
|
|
293
|
+
__BF16_HOST_DEVICE__ operator __hip_bfloat16_raw() const { return __hip_bfloat16_raw{__x}; }
|
|
294
|
+
|
|
295
|
+
/*! \brief return a __hip_bfloat16_raw cv qualifier */
|
|
296
|
+
__BF16_HOST_DEVICE__ operator __hip_bfloat16_raw() const volatile {
|
|
297
|
+
return __hip_bfloat16_raw{__x};
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
/*! \brief return false if bfloat value is +0.0 or -0.0, returns true otherwise */
|
|
301
|
+
__BF16_HOST_DEVICE__ operator bool() const {
|
|
302
|
+
auto val = bfloatraw_2_float(__x);
|
|
303
|
+
return val != 0.0f && val != -0.0f;
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
/*! \brief return a casted char from underlying float val */
|
|
307
|
+
__BF16_HOST_DEVICE__ operator char() const { return static_cast<char>(bfloatraw_2_float(__x)); }
|
|
308
|
+
|
|
309
|
+
/*! \brief return a float */
|
|
310
|
+
__BF16_HOST_DEVICE__ operator float() const { return bfloatraw_2_float(__x); }
|
|
311
|
+
|
|
312
|
+
/*! \brief return a casted int casted from float of underlying bfloat16 value */
|
|
313
|
+
__BF16_HOST_DEVICE__ operator int() const { return static_cast<int>(bfloatraw_2_float(__x)); }
|
|
314
|
+
|
|
315
|
+
/*! \brief return a casted long casted from float of underlying bfloat16 value */
|
|
316
|
+
__BF16_HOST_DEVICE__ operator long() const { return static_cast<long>(bfloatraw_2_float(__x)); }
|
|
317
|
+
|
|
318
|
+
/*! \brief return a casted long long casted from float of underlying bfloat16 value */
|
|
319
|
+
__BF16_HOST_DEVICE__ operator long long() const {
|
|
320
|
+
return static_cast<long long>(bfloatraw_2_float(__x));
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
/*! \brief return a casted short casted from float of underlying bfloat16 value */
|
|
324
|
+
__BF16_HOST_DEVICE__ operator short() const { return static_cast<short>(bfloatraw_2_float(__x)); }
|
|
325
|
+
|
|
326
|
+
/*! \brief return a casted signed char from float of underlying bfloat16 value */
|
|
327
|
+
__BF16_HOST_DEVICE__ operator signed char() const {
|
|
328
|
+
return static_cast<signed char>(bfloatraw_2_float(__x));
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
/*! \brief return a casted unsigned char casted from float of underlying bfloat16 value */
|
|
332
|
+
__BF16_HOST_DEVICE__ operator unsigned char() const {
|
|
333
|
+
return static_cast<unsigned char>(bfloatraw_2_float(__x));
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
/*! \brief return a casted unsigned int casted from float of underlying bfloat16 value */
|
|
337
|
+
__BF16_HOST_DEVICE__ operator unsigned int() const {
|
|
338
|
+
return static_cast<unsigned int>(bfloatraw_2_float(__x));
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
/*! \brief return a casted unsigned from float of underlying bfloat16 value */
|
|
342
|
+
__BF16_HOST_DEVICE__ operator unsigned long() const {
|
|
343
|
+
return static_cast<unsigned long>(bfloatraw_2_float(__x));
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
/*! \brief return a casted unsigned long long from float of underlying bfloat16 value */
|
|
347
|
+
__BF16_HOST_DEVICE__ operator unsigned long long() const {
|
|
348
|
+
return static_cast<unsigned long long>(bfloatraw_2_float(__x));
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
/*! \brief return a casted unsigned short from float of underlying bfloat16 value */
|
|
352
|
+
__BF16_HOST_DEVICE__ operator unsigned short() const {
|
|
353
|
+
return static_cast<unsigned short>(bfloatraw_2_float(__x));
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
// TODO: SWDEV-452411 add operator which converts unsigned long long and long long to bfloat
|
|
357
|
+
|
|
358
|
+
/*! \brief assign value from an unsigned int */
|
|
359
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned int val) {
|
|
360
|
+
__x = float_2_bfloatraw(static_cast<float>(val));
|
|
361
|
+
return *this;
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
/*! \brief assign value from a int */
|
|
365
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(int val) {
|
|
366
|
+
__x = float_2_bfloatraw(static_cast<float>(val));
|
|
367
|
+
return *this;
|
|
368
|
+
}
|
|
369
|
+
|
|
370
|
+
/*! \brief assign value from an unsigned short */
|
|
371
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned short val) {
|
|
372
|
+
__x = float_2_bfloatraw(static_cast<float>(val));
|
|
373
|
+
return *this;
|
|
374
|
+
}
|
|
375
|
+
|
|
376
|
+
/*! \brief assign value from a short int */
|
|
377
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(short val) {
|
|
378
|
+
__x = float_2_bfloatraw(static_cast<float>(val));
|
|
379
|
+
return *this;
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
/*! \brief assign value from a double */
|
|
383
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const double f) {
|
|
384
|
+
__x = float_2_bfloatraw(static_cast<float>(f));
|
|
385
|
+
return *this;
|
|
386
|
+
}
|
|
387
|
+
|
|
388
|
+
/*! \brief assign value from a float */
|
|
389
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const float f) {
|
|
390
|
+
__x = float_2_bfloatraw(f);
|
|
391
|
+
return *this;
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
/*! \brief assign value from a __hip_bfloat16_raw */
|
|
395
|
+
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const __hip_bfloat16_raw& hr) {
|
|
396
|
+
__x = hr.x;
|
|
397
|
+
return *this;
|
|
398
|
+
}
|
|
399
|
+
|
|
400
|
+
/*! \brief assign value from a __hip_bfloat16_raw volatile */
|
|
401
|
+
__BF16_HOST_DEVICE__ volatile __hip_bfloat16& operator=(const __hip_bfloat16_raw& hr) volatile {
|
|
402
|
+
__x = hr.x;
|
|
403
|
+
return *this;
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
/*! \brief assign value from a __hip_bfloat16_raw cv qualifier */
|
|
407
|
+
__BF16_HOST_DEVICE__ volatile __hip_bfloat16& operator=(
|
|
408
|
+
const volatile __hip_bfloat16_raw& hr) volatile {
|
|
409
|
+
__x = hr.x;
|
|
410
|
+
return *this;
|
|
411
|
+
}
|
|
124
412
|
};
|
|
413
|
+
/**@}*/
|
|
414
|
+
|
|
415
|
+
/**
|
|
416
|
+
* \defgroup HIP_INTRINSIC_BFLOAT162_STRUCT
|
|
417
|
+
* \ingroup HIP_INTRINSIC_BFLOAT16
|
|
418
|
+
* \brief Struct to represent a two 16 bit brain floating point number.
|
|
419
|
+
* @{
|
|
420
|
+
*/
|
|
421
|
+
struct __attribute__((aligned(4))) __hip_bfloat162 {
|
|
422
|
+
public:
|
|
423
|
+
__hip_bfloat16 x; /*! \brief raw representation of bfloat16 */
|
|
424
|
+
__hip_bfloat16 y; /*! \brief raw representation of bfloat16 */
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
public:
|
|
428
|
+
/*! \brief create __hip_bfloat162 from __hip_bfloat162_raw */
|
|
429
|
+
__BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162_raw& h2r)
|
|
430
|
+
: x(__hip_bfloat16(__hip_bfloat16_raw{h2r.x})),
|
|
431
|
+
y(__hip_bfloat16(__hip_bfloat16_raw{h2r.y})) {}
|
|
432
|
+
|
|
433
|
+
/*! \brief copy constructor of __hip_bfloat162 */
|
|
434
|
+
__BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162& val) {
|
|
435
|
+
__hip_bfloat162_raw hr = val;
|
|
436
|
+
x = __hip_bfloat16_raw{hr.x};
|
|
437
|
+
y = __hip_bfloat16_raw{hr.y};
|
|
438
|
+
}
|
|
439
|
+
|
|
440
|
+
/*! \brief create __hip_bfloat162 from two __hip_bfloat16 */
|
|
441
|
+
__BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat16& a, const __hip_bfloat16& b)
|
|
442
|
+
: x(a), y(b) {}
|
|
443
|
+
|
|
444
|
+
/*! \brief default constructor of __hip_bfloat162 */
|
|
445
|
+
__BF16_HOST_DEVICE__ __hip_bfloat162() = default;
|
|
446
|
+
|
|
447
|
+
/*! \brief return a __hip_bfloat162_raw */
|
|
448
|
+
__BF16_HOST_DEVICE__ operator __hip_bfloat162_raw() const {
|
|
449
|
+
__hip_bfloat16_raw l = x;
|
|
450
|
+
__hip_bfloat16_raw r = y;
|
|
451
|
+
return __hip_bfloat162_raw{l.x, r.x};
|
|
452
|
+
}
|
|
125
453
|
|
|
126
|
-
/*! \brief
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
454
|
+
/*! \brief return a float2 */
|
|
455
|
+
__BF16_HOST_DEVICE__ operator float2() const {
|
|
456
|
+
#if HIP_BF16_AVX512_OP
|
|
457
|
+
union {
|
|
458
|
+
__hip_bfloat162_raw raw2;
|
|
459
|
+
__bf16 bf162[2];
|
|
460
|
+
static_assert(sizeof(__bf16[2]) == sizeof(__hip_bfloat162_raw));
|
|
461
|
+
} u;
|
|
462
|
+
u.raw2 = *this;
|
|
463
|
+
__m128bh pbf16{u.bf162[0], u.bf162[1], 0, 0};
|
|
464
|
+
__m128 pf32 = _mm_cvtpbh_ps(pbf16);
|
|
465
|
+
float2 ret(pf32[0], pf32[1]);
|
|
466
|
+
#else
|
|
467
|
+
float2 ret(x, y);
|
|
468
|
+
#endif
|
|
469
|
+
return ret;
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
/*! \brief assign value from __hip_bfloat162_raw */
|
|
473
|
+
__BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162_raw& h2r) {
|
|
474
|
+
x = __hip_bfloat16(__hip_bfloat16_raw{h2r.x});
|
|
475
|
+
y = __hip_bfloat16(__hip_bfloat16_raw{h2r.y});
|
|
476
|
+
return *this;
|
|
477
|
+
}
|
|
478
|
+
|
|
479
|
+
/*! \brief assign value from __hip_bfloat162 */
|
|
480
|
+
__BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162& src) {
|
|
481
|
+
__hip_bfloat162_raw hr = src;
|
|
482
|
+
x = __hip_bfloat16(__hip_bfloat16_raw{hr.x});
|
|
483
|
+
y = __hip_bfloat16(__hip_bfloat16_raw{hr.y});
|
|
484
|
+
return *this;
|
|
485
|
+
}
|
|
130
486
|
};
|
|
487
|
+
/**@}*/
|
|
131
488
|
|
|
132
489
|
/**
|
|
133
490
|
* \ingroup HIP_INTRINSIC_BFLOAT16_CONV
|
|
134
491
|
* \brief Converts bfloat16 to float
|
|
135
492
|
*/
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
union {
|
|
140
|
-
unsigned int u32;
|
|
141
|
-
float fp32;
|
|
142
|
-
} u = {uval};
|
|
143
|
-
return u.fp32;
|
|
493
|
+
__BF16_HOST_DEVICE_STATIC__ float __bfloat162float(__hip_bfloat16 a) {
|
|
494
|
+
float ret = a;
|
|
495
|
+
return ret;
|
|
144
496
|
}
|
|
145
497
|
|
|
146
498
|
/**
|
|
147
499
|
* \ingroup HIP_INTRINSIC_BFLOAT16_CONV
|
|
148
500
|
* \brief Converts float to bfloat16
|
|
149
501
|
*/
|
|
150
|
-
|
|
151
|
-
__hip_bfloat16 ret;
|
|
152
|
-
union {
|
|
153
|
-
float fp32;
|
|
154
|
-
unsigned int u32;
|
|
155
|
-
} u = {f};
|
|
156
|
-
if (~u.u32 & 0x7f800000) {
|
|
157
|
-
// When the exponent bits are not all 1s, then the value is zero, normal,
|
|
158
|
-
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
|
159
|
-
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
|
160
|
-
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
|
161
|
-
// least significant bits of the float mantissa are greater than 0x8000,
|
|
162
|
-
// or if they are equal to 0x8000 and the least significant bit of the
|
|
163
|
-
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
|
164
|
-
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
|
165
|
-
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
|
166
|
-
// the exponent is incremented by one, which is the next higher FP value
|
|
167
|
-
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
|
168
|
-
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
|
|
169
|
-
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
|
170
|
-
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
|
171
|
-
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
|
172
|
-
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
|
173
|
-
u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // Round to nearest, round to even
|
|
174
|
-
} else if (u.u32 & 0xffff) {
|
|
175
|
-
// When all of the exponent bits are 1, the value is Inf or NaN.
|
|
176
|
-
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
|
177
|
-
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
|
178
|
-
// bit being 1. Signaling NaN is indicated by the most significant
|
|
179
|
-
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
|
180
|
-
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
|
181
|
-
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
|
182
|
-
// the bloat16's mantissa bits are all 0.
|
|
183
|
-
u.u32 |= 0x10000; // Preserve signaling NaN
|
|
184
|
-
}
|
|
185
|
-
|
|
186
|
-
ret.data = (u.u32 >> 16);
|
|
502
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __float2bfloat16(float f) {
|
|
503
|
+
__hip_bfloat16 ret{f};
|
|
187
504
|
return ret;
|
|
188
505
|
}
|
|
189
506
|
|
|
@@ -191,43 +508,51 @@ __HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) {
|
|
|
191
508
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
192
509
|
* \brief Converts and moves bfloat162 to float2
|
|
193
510
|
*/
|
|
194
|
-
|
|
195
|
-
|
|
511
|
+
__BF16_HOST_DEVICE_STATIC__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
|
|
512
|
+
float2 ret = a;
|
|
513
|
+
return ret;
|
|
196
514
|
}
|
|
197
515
|
|
|
198
516
|
/**
|
|
199
517
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
200
518
|
* \brief Moves bfloat16 value to bfloat162
|
|
201
519
|
*/
|
|
202
|
-
|
|
203
|
-
return __hip_bfloat162
|
|
520
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) {
|
|
521
|
+
return __hip_bfloat162(a, a);
|
|
204
522
|
}
|
|
205
523
|
|
|
206
524
|
/**
|
|
207
525
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
208
526
|
* \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer
|
|
209
527
|
*/
|
|
210
|
-
|
|
528
|
+
__BF16_HOST_DEVICE_STATIC__ short int __bfloat16_as_short(const __hip_bfloat16 h) {
|
|
529
|
+
short ret = h;
|
|
530
|
+
return ret;
|
|
531
|
+
}
|
|
211
532
|
|
|
212
533
|
/**
|
|
213
534
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
214
535
|
* \brief Reinterprets bits in a __hip_bfloat16 as an unsigned signed short integer
|
|
215
536
|
*/
|
|
216
|
-
|
|
537
|
+
__BF16_HOST_DEVICE_STATIC__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) {
|
|
538
|
+
unsigned short ret = h;
|
|
539
|
+
return ret;
|
|
540
|
+
}
|
|
217
541
|
|
|
218
542
|
/**
|
|
219
543
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
220
544
|
* \brief Convert double to __hip_bfloat16
|
|
221
545
|
*/
|
|
222
|
-
|
|
223
|
-
|
|
546
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __double2bfloat16(const double a) {
|
|
547
|
+
__hip_bfloat16 ret{a};
|
|
548
|
+
return ret;
|
|
224
549
|
}
|
|
225
550
|
|
|
226
551
|
/**
|
|
227
552
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
228
553
|
* \brief Convert float2 to __hip_bfloat162
|
|
229
554
|
*/
|
|
230
|
-
|
|
555
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
|
|
231
556
|
return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)};
|
|
232
557
|
}
|
|
233
558
|
|
|
@@ -235,97 +560,117 @@ __HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
|
|
|
235
560
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
236
561
|
* \brief Combine two __hip_bfloat16 to __hip_bfloat162
|
|
237
562
|
*/
|
|
238
|
-
|
|
239
|
-
|
|
563
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a,
|
|
564
|
+
const __hip_bfloat16 b) {
|
|
565
|
+
return __hip_bfloat162(a, b);
|
|
240
566
|
}
|
|
241
567
|
|
|
242
568
|
/**
|
|
243
569
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
244
570
|
* \brief Returns high 16 bits of __hip_bfloat162
|
|
245
571
|
*/
|
|
246
|
-
|
|
572
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) {
|
|
573
|
+
__hip_bfloat162_raw hr = a;
|
|
574
|
+
return __hip_bfloat16(__hip_bfloat16_raw{hr.y});
|
|
575
|
+
}
|
|
247
576
|
|
|
248
577
|
/**
|
|
249
578
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
250
579
|
* \brief Returns high 16 bits of __hip_bfloat162
|
|
251
580
|
*/
|
|
252
|
-
|
|
253
|
-
|
|
581
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) {
|
|
582
|
+
__hip_bfloat162_raw hr = a;
|
|
583
|
+
return __hip_bfloat162(__hip_bfloat16_raw{hr.y}, __hip_bfloat16_raw{hr.y});
|
|
254
584
|
}
|
|
255
585
|
|
|
256
586
|
/**
|
|
257
587
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
258
588
|
* \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result
|
|
259
589
|
*/
|
|
260
|
-
|
|
590
|
+
__BF16_HOST_DEVICE_STATIC__ float __high2float(const __hip_bfloat162 a) {
|
|
591
|
+
__hip_bfloat162_raw hr = a;
|
|
592
|
+
return __bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.y}));
|
|
593
|
+
}
|
|
261
594
|
|
|
262
595
|
/**
|
|
263
596
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
264
597
|
* \brief Extracts high 16 bits from each and combines them
|
|
265
598
|
*/
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
599
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a,
|
|
600
|
+
const __hip_bfloat162 b) {
|
|
601
|
+
__hip_bfloat162_raw hr_a = a;
|
|
602
|
+
__hip_bfloat162_raw hr_b = b;
|
|
603
|
+
return __hip_bfloat162(__hip_bfloat162_raw{hr_a.y, hr_b.y});
|
|
269
604
|
}
|
|
270
605
|
|
|
271
606
|
/**
|
|
272
607
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
273
608
|
* \brief Returns low 16 bits of __hip_bfloat162
|
|
274
609
|
*/
|
|
275
|
-
|
|
610
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) {
|
|
611
|
+
__hip_bfloat162_raw hr = a;
|
|
612
|
+
return __hip_bfloat16(hr.x);
|
|
613
|
+
}
|
|
276
614
|
|
|
277
615
|
/**
|
|
278
616
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
279
617
|
* \brief Returns low 16 bits of __hip_bfloat162
|
|
280
618
|
*/
|
|
281
|
-
|
|
282
|
-
|
|
619
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) {
|
|
620
|
+
__hip_bfloat162_raw hr = a;
|
|
621
|
+
return __hip_bfloat162(hr.x, hr.x);
|
|
283
622
|
}
|
|
284
623
|
|
|
285
624
|
/**
|
|
286
625
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
287
626
|
* \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result
|
|
288
627
|
*/
|
|
289
|
-
|
|
628
|
+
__BF16_HOST_DEVICE_STATIC__ float __low2float(const __hip_bfloat162 a) {
|
|
629
|
+
__hip_bfloat162_raw hr = a;
|
|
630
|
+
return __bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.x}));
|
|
631
|
+
}
|
|
290
632
|
|
|
291
633
|
/**
|
|
292
634
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
293
635
|
* \brief Swaps both halves
|
|
294
636
|
*/
|
|
295
|
-
|
|
296
|
-
|
|
637
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) {
|
|
638
|
+
__hip_bfloat162_raw hr = a;
|
|
639
|
+
return __hip_bfloat162(__hip_bfloat162_raw{hr.y, hr.x});
|
|
297
640
|
}
|
|
298
641
|
|
|
299
642
|
/**
|
|
300
643
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
301
644
|
* \brief Extracts low 16 bits from each and combines them
|
|
302
645
|
*/
|
|
303
|
-
|
|
304
|
-
|
|
646
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a,
                                                             const __hip_bfloat162 b) {
  // Result lane x takes the low half of a; lane y takes the low half of b.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const __hip_bfloat162_raw packed{lhs.x, rhs.x};
  return __hip_bfloat162(packed);
}
|
|
306
652
|
|
|
307
653
|
/**
|
|
308
654
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
309
655
|
* \brief Reinterprets short int into a bfloat16
|
|
310
656
|
*/
|
|
311
|
-
|
|
312
|
-
return __hip_bfloat16
|
|
657
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __short_as_bfloat16(const short int a) {
|
|
658
|
+
return __hip_bfloat16(a);
|
|
313
659
|
}
|
|
314
660
|
|
|
315
661
|
/**
|
|
316
662
|
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
|
317
663
|
* \brief Reinterprets unsigned short int into a bfloat16
|
|
318
664
|
*/
|
|
319
|
-
|
|
320
|
-
return __hip_bfloat16
|
|
665
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) {
|
|
666
|
+
return __hip_bfloat16(a);
|
|
321
667
|
}
|
|
322
668
|
|
|
323
|
-
|
|
324
669
|
/**
|
|
325
670
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
326
671
|
* \brief Adds two bfloat16 values
|
|
327
672
|
*/
|
|
328
|
-
|
|
673
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Widen both operands to float, add, and round the sum back to bfloat16.
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return __float2bfloat16(lhs + rhs);
}
|
|
331
676
|
|
|
@@ -333,7 +678,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat
|
|
|
333
678
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
334
679
|
* \brief Subtracts two bfloat16 values
|
|
335
680
|
*/
|
|
336
|
-
|
|
681
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Widen both operands to float, subtract, and round the difference back to bfloat16.
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return __float2bfloat16(lhs - rhs);
}
|
|
339
684
|
|
|
@@ -341,7 +686,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat
|
|
|
341
686
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
342
687
|
* \brief Divides two bfloat16 values
|
|
343
688
|
*/
|
|
344
|
-
|
|
689
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Widen both operands to float, divide, and round the quotient back to bfloat16.
  const float numerator = __bfloat162float(a);
  const float denominator = __bfloat162float(b);
  return __float2bfloat16(numerator / denominator);
}
|
|
347
692
|
|
|
@@ -349,8 +694,8 @@ __HOST_DEVICE__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat
|
|
|
349
694
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
350
695
|
* \brief Performs FMA of given bfloat16 values
|
|
351
696
|
*/
|
|
352
|
-
|
|
353
|
-
|
|
697
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b,
                                             const __hip_bfloat16 c) {
  // a * b + c computed by the OCML float FMA, then rounded once to bfloat16.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  const float fc = __bfloat162float(c);
  return __float2bfloat16(__ocml_fma_f32(fa, fb, fc));
}
|
|
@@ -359,7 +704,7 @@ __device__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b,
|
|
|
359
704
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
360
705
|
* \brief Multiplies two bfloat16 values
|
|
361
706
|
*/
|
|
362
|
-
|
|
707
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Widen both operands to float, multiply, and round the product back to bfloat16.
  const float lhs = __bfloat162float(a);
  const float rhs = __bfloat162float(b);
  return __float2bfloat16(lhs * rhs);
}
|
|
365
710
|
|
|
@@ -367,85 +712,110 @@ __HOST_DEVICE__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat
|
|
|
367
712
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
368
713
|
* \brief Negate a bfloat16 value
|
|
369
714
|
*/
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
return
|
|
715
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) {
|
|
716
|
+
__hip_bfloat16_raw hr = a;
|
|
717
|
+
hr.x ^= 0x8000;
|
|
718
|
+
return __hip_bfloat16(hr);
|
|
374
719
|
}
|
|
375
720
|
|
|
376
721
|
/**
|
|
377
722
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
378
723
|
* \brief Returns absolute of a bfloat16
|
|
379
724
|
*/
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
return
|
|
725
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __habs(const __hip_bfloat16 a) {
|
|
726
|
+
__hip_bfloat16_raw hr = a;
|
|
727
|
+
hr.x &= 0x7FFF;
|
|
728
|
+
return __hip_bfloat16(hr);
|
|
384
729
|
}
|
|
385
730
|
|
|
386
731
|
/**
|
|
387
732
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
388
733
|
* \brief Divides bfloat162 values
|
|
389
734
|
*/
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
735
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __h2div(const __hip_bfloat162 a,
                                                    const __hip_bfloat162 b) {
  // Lane-wise division: each lane is divided in float and rounded back to bfloat16.
  const __hip_bfloat162_raw num = a;
  const __hip_bfloat162_raw den = b;
  const float lo = __bfloat162float(__hip_bfloat16_raw{num.x}) /
                   __bfloat162float(__hip_bfloat16_raw{den.x});
  const float hi = __bfloat162float(__hip_bfloat16_raw{num.y}) /
                   __bfloat162float(__hip_bfloat16_raw{den.y});
  return __hip_bfloat162(__float2bfloat16(lo), __float2bfloat16(hi));
}
|
|
394
744
|
|
|
395
745
|
/**
|
|
396
746
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
397
747
|
* \brief Returns absolute of a bfloat162
|
|
398
748
|
*/
|
|
399
|
-
|
|
400
|
-
|
|
749
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) {
  // Apply __habs independently to each lane.
  const __hip_bfloat162_raw src = a;
  const __hip_bfloat16 lo = __habs(__hip_bfloat16_raw{src.x});
  const __hip_bfloat16 hi = __habs(__hip_bfloat16_raw{src.y});
  return __hip_bfloat162(lo, hi);
}
|
|
402
753
|
|
|
403
754
|
/**
|
|
404
755
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
405
756
|
* \brief Adds two bfloat162 values
|
|
406
757
|
*/
|
|
407
|
-
|
|
408
|
-
|
|
758
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a,
                                                    const __hip_bfloat162 b) {
  // Lane-wise addition built on the scalar __hadd.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const __hip_bfloat16 lo = __hadd(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const __hip_bfloat16 hi = __hadd(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162(lo, hi);
}
|
|
410
765
|
|
|
411
766
|
/**
|
|
412
767
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
413
768
|
* \brief Performs FMA of given bfloat162 values
|
|
414
769
|
*/
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
770
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __hip_bfloat162 b,
                                               const __hip_bfloat162 c) {
  // Lane-wise fused multiply-add (a * b + c) built on the scalar __hfma.
  const __hip_bfloat162_raw ra = a;
  const __hip_bfloat162_raw rb = b;
  const __hip_bfloat162_raw rc = c;
  const __hip_bfloat16 lo =
      __hfma(__hip_bfloat16_raw{ra.x}, __hip_bfloat16_raw{rb.x}, __hip_bfloat16_raw{rc.x});
  const __hip_bfloat16 hi =
      __hfma(__hip_bfloat16_raw{ra.y}, __hip_bfloat16_raw{rb.y}, __hip_bfloat16_raw{rc.y});
  return __hip_bfloat162(lo, hi);
}
|
|
419
779
|
|
|
420
780
|
/**
|
|
421
781
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
422
782
|
* \brief Multiplies two bfloat162 values
|
|
423
783
|
*/
|
|
424
|
-
|
|
425
|
-
|
|
784
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a,
                                                    const __hip_bfloat162 b) {
  // Lane-wise multiplication built on the scalar __hmul.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const __hip_bfloat16 lo = __hmul(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const __hip_bfloat16 hi = __hmul(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162(lo, hi);
}
|
|
427
791
|
|
|
428
792
|
/**
|
|
429
793
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
430
794
|
* \brief Negates both halves of a bfloat162
|
|
431
795
|
*/
|
|
432
|
-
|
|
433
|
-
|
|
796
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) {
  // Apply __hneg independently to each lane.
  const __hip_bfloat162_raw src = a;
  const __hip_bfloat16 lo = __hneg(__hip_bfloat16_raw{src.x});
  const __hip_bfloat16 hi = __hneg(__hip_bfloat16_raw{src.y});
  return __hip_bfloat162(lo, hi);
}
|
|
435
800
|
|
|
436
801
|
/**
|
|
437
802
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
438
803
|
* \brief Subtracts two bfloat162 values
|
|
439
804
|
*/
|
|
440
|
-
|
|
441
|
-
|
|
805
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a,
                                                    const __hip_bfloat162 b) {
  // Lane-wise subtraction built on the scalar __hsub.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const __hip_bfloat16 lo = __hsub(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const __hip_bfloat16 hi = __hsub(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162(lo, hi);
}
|
|
443
812
|
|
|
444
813
|
/**
|
|
445
814
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
446
815
|
* \brief Operator to multiply two __hip_bfloat16 numbers
|
|
447
816
|
*/
|
|
448
|
-
|
|
817
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator*(const __hip_bfloat16& l,
                                                     const __hip_bfloat16& r) {
  // Thin operator wrapper over the __hmul intrinsic.
  const __hip_bfloat16 product = __hmul(l, r);
  return product;
}
|
|
451
821
|
|
|
@@ -453,7 +823,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator*(const __hip_bfloat16& l, const __hip_bf
|
|
|
453
823
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
454
824
|
* \brief Operator to multiply-assign two __hip_bfloat16 numbers
|
|
455
825
|
*/
|
|
456
|
-
|
|
826
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat16& r) {
  // Multiply in place and return the updated left operand for chaining.
  const __hip_bfloat16 product = __hmul(l, r);
  l = product;
  return l;
}
|
|
@@ -462,13 +832,14 @@ __HOST_DEVICE__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat
|
|
|
462
832
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
463
833
|
* \brief Operator to unary+ on a __hip_bfloat16 number
|
|
464
834
|
*/
|
|
465
|
-
|
|
835
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16& l) {
  // Unary plus is the identity: hand back an unchanged copy.
  return l;
}
|
|
466
836
|
|
|
467
837
|
/**
|
|
468
838
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
469
839
|
* \brief Operator to add two __hip_bfloat16 numbers
|
|
470
840
|
*/
|
|
471
|
-
|
|
841
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16& l,
                                                     const __hip_bfloat16& r) {
  // Thin operator wrapper over the __hadd intrinsic.
  const __hip_bfloat16 sum = __hadd(l, r);
  return sum;
}
|
|
474
845
|
|
|
@@ -476,13 +847,14 @@ __HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l, const __hip_bf
|
|
|
476
847
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
477
848
|
* \brief Operator to negate a __hip_bfloat16 number
|
|
478
849
|
*/
|
|
479
|
-
|
|
850
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16& l) {
  // Unary minus: sign flip via __hneg.
  const __hip_bfloat16 negated = __hneg(l);
  return negated;
}
|
|
480
851
|
|
|
481
852
|
/**
|
|
482
853
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
483
854
|
* \brief Operator to subtract two __hip_bfloat16 numbers
|
|
484
855
|
*/
|
|
485
|
-
|
|
856
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16& l,
                                                     const __hip_bfloat16& r) {
  // Thin operator wrapper over the __hsub intrinsic.
  const __hip_bfloat16 difference = __hsub(l, r);
  return difference;
}
|
|
488
860
|
|
|
@@ -490,7 +862,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l, const __hip_bf
|
|
|
490
862
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
491
863
|
* \brief Operator to post increment a __hip_bfloat16 number
|
|
492
864
|
*/
|
|
493
|
-
|
|
865
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) {
|
|
494
866
|
auto ret = l;
|
|
495
867
|
l = __hadd(l, HIPRT_ONE_BF16);
|
|
496
868
|
return ret;
|
|
@@ -500,7 +872,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) {
|
|
|
500
872
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
501
873
|
* \brief Operator to pre increment a __hip_bfloat16 number
|
|
502
874
|
*/
|
|
503
|
-
|
|
875
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator++(__hip_bfloat16& l) {
  // Prefix increment: add one in place, then return the updated operand.
  const __hip_bfloat16 incremented = __hadd(l, HIPRT_ONE_BF16);
  l = incremented;
  return l;
}
|
|
@@ -509,7 +881,7 @@ __HOST_DEVICE__ __hip_bfloat16& operator++(__hip_bfloat16& l) {
|
|
|
509
881
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
510
882
|
* \brief Operator to post decrement a __hip_bfloat16 number
|
|
511
883
|
*/
|
|
512
|
-
|
|
884
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) {
|
|
513
885
|
auto ret = l;
|
|
514
886
|
l = __hsub(l, HIPRT_ONE_BF16);
|
|
515
887
|
return ret;
|
|
@@ -519,7 +891,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) {
|
|
|
519
891
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
520
892
|
* \brief Operator to pre decrement a __hip_bfloat16 number
|
|
521
893
|
*/
|
|
522
|
-
|
|
894
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator--(__hip_bfloat16& l) {
  // Prefix decrement: subtract one in place, then return the updated operand.
  const __hip_bfloat16 decremented = __hsub(l, HIPRT_ONE_BF16);
  l = decremented;
  return l;
}
|
|
@@ -528,7 +900,7 @@ __HOST_DEVICE__ __hip_bfloat16& operator--(__hip_bfloat16& l) {
|
|
|
528
900
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
529
901
|
* \brief Operator to add-assign two __hip_bfloat16 numbers
|
|
530
902
|
*/
|
|
531
|
-
|
|
903
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat16& r) {
  // Add in place and return the updated left operand for chaining.
  const __hip_bfloat16 sum = __hadd(l, r);
  l = sum;
  return l;
}
|
|
@@ -537,7 +909,7 @@ __HOST_DEVICE__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat
|
|
|
537
909
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
538
910
|
* \brief Operator to subtract-assign two __hip_bfloat16 numbers
|
|
539
911
|
*/
|
|
540
|
-
|
|
912
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat16& r) {
  // Subtract in place and return the updated left operand for chaining.
  const __hip_bfloat16 difference = __hsub(l, r);
  l = difference;
  return l;
}
|
|
@@ -546,7 +918,8 @@ __HOST_DEVICE__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat
|
|
|
546
918
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
547
919
|
* \brief Operator to divide two __hip_bfloat16 numbers
|
|
548
920
|
*/
|
|
549
|
-
|
|
921
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator/(const __hip_bfloat16& l,
                                                     const __hip_bfloat16& r) {
  // Thin operator wrapper over the __hdiv intrinsic.
  const __hip_bfloat16 quotient = __hdiv(l, r);
  return quotient;
}
|
|
552
925
|
|
|
@@ -554,7 +927,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator/(const __hip_bfloat16& l, const __hip_bf
|
|
|
554
927
|
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
|
555
928
|
* \brief Operator to divide-assign two __hip_bfloat16 numbers
|
|
556
929
|
*/
|
|
557
|
-
|
|
930
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat16& r) {
  // Divide in place and return the updated left operand for chaining.
  const __hip_bfloat16 quotient = __hdiv(l, r);
  l = quotient;
  return l;
}
|
|
@@ -563,7 +936,8 @@ __HOST_DEVICE__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat
|
|
|
563
936
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
564
937
|
* \brief Operator to multiply two __hip_bfloat162 numbers
|
|
565
938
|
*/
|
|
566
|
-
|
|
939
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator*(const __hip_bfloat162& l,
                                                      const __hip_bfloat162& r) {
  // Lane-wise product via __hmul2.
  const __hip_bfloat162 product = __hmul2(l, r);
  return product;
}
|
|
569
943
|
|
|
@@ -571,7 +945,8 @@ __HOST_DEVICE__ __hip_bfloat162 operator*(const __hip_bfloat162& l, const __hip_
|
|
|
571
945
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
572
946
|
* \brief Operator to multiply-assign two __hip_bfloat162 numbers
|
|
573
947
|
*/
|
|
574
|
-
|
|
948
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator*=(__hip_bfloat162& l,
                                                        const __hip_bfloat162& r) {
  // Lane-wise multiply in place; returns the updated left operand.
  const __hip_bfloat162 product = __hmul2(l, r);
  l = product;
  return l;
}
|
|
@@ -580,13 +955,14 @@ __HOST_DEVICE__ __hip_bfloat162& operator*=(__hip_bfloat162& l, const __hip_bflo
|
|
|
580
955
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
581
956
|
* \brief Operator to unary+ on a __hip_bfloat162 number
|
|
582
957
|
*/
|
|
583
|
-
|
|
958
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator+(const __hip_bfloat162& l) {
  // Unary plus is the identity for both lanes.
  return l;
}
|
|
584
959
|
|
|
585
960
|
/**
|
|
586
961
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
587
962
|
* \brief Operator to add two __hip_bfloat162 numbers
|
|
588
963
|
*/
|
|
589
|
-
|
|
964
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator+(const __hip_bfloat162& l,
                                                      const __hip_bfloat162& r) {
  // Lane-wise sum via __hadd2.
  const __hip_bfloat162 sum = __hadd2(l, r);
  return sum;
}
|
|
592
968
|
|
|
@@ -594,13 +970,16 @@ __HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l, const __hip_
|
|
|
594
970
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
595
971
|
* \brief Operator to negate a __hip_bfloat162 number
|
|
596
972
|
*/
|
|
597
|
-
|
|
973
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator-(const __hip_bfloat162& l) {
  // Unary minus: flip the sign of both lanes via __hneg2.
  const __hip_bfloat162 negated = __hneg2(l);
  return negated;
}
|
|
598
976
|
|
|
599
977
|
/**
|
|
600
978
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
601
979
|
* \brief Operator to subtract two __hip_bfloat162 numbers
|
|
602
980
|
*/
|
|
603
|
-
|
|
981
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator-(const __hip_bfloat162& l,
                                                      const __hip_bfloat162& r) {
  // Lane-wise difference via __hsub2.
  const __hip_bfloat162 difference = __hsub2(l, r);
  return difference;
}
|
|
606
985
|
|
|
@@ -608,7 +987,7 @@ __HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l, const __hip_
|
|
|
608
987
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
609
988
|
* \brief Operator to post increment a __hip_bfloat162 number
|
|
610
989
|
*/
|
|
611
|
-
|
|
990
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) {
|
|
612
991
|
auto ret = l;
|
|
613
992
|
l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
|
|
614
993
|
return ret;
|
|
@@ -618,7 +997,7 @@ __HOST_DEVICE__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) {
|
|
|
618
997
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
619
998
|
* \brief Operator to pre increment a __hip_bfloat162 number
|
|
620
999
|
*/
|
|
621
|
-
|
|
1000
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator++(__hip_bfloat162& l) {
  // Prefix increment: add one to each lane in place, then return the operand.
  const __hip_bfloat162 ones = {HIPRT_ONE_BF16, HIPRT_ONE_BF16};
  l = __hadd2(l, ones);
  return l;
}
|
|
@@ -627,7 +1006,7 @@ __HOST_DEVICE__ __hip_bfloat162& operator++(__hip_bfloat162& l) {
|
|
|
627
1006
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
628
1007
|
* \brief Operator to post decrement a __hip_bfloat162 number
|
|
629
1008
|
*/
|
|
630
|
-
|
|
1009
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) {
|
|
631
1010
|
auto ret = l;
|
|
632
1011
|
l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16});
|
|
633
1012
|
return ret;
|
|
@@ -637,7 +1016,7 @@ __HOST_DEVICE__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) {
|
|
|
637
1016
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
638
1017
|
* \brief Operator to pre decrement a __hip_bfloat162 number
|
|
639
1018
|
*/
|
|
640
|
-
|
|
1019
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator--(__hip_bfloat162& l) {
  // Prefix decrement: subtract one from each lane in place, then return the operand.
  const __hip_bfloat162 ones = {HIPRT_ONE_BF16, HIPRT_ONE_BF16};
  l = __hsub2(l, ones);
  return l;
}
|
|
@@ -646,7 +1025,8 @@ __HOST_DEVICE__ __hip_bfloat162& operator--(__hip_bfloat162& l) {
|
|
|
646
1025
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
647
1026
|
* \brief Operator to add-assign two __hip_bfloat162 numbers
|
|
648
1027
|
*/
|
|
649
|
-
|
|
1028
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator+=(__hip_bfloat162& l,
                                                        const __hip_bfloat162& r) {
  // Lane-wise add in place; returns the updated left operand.
  const __hip_bfloat162 sum = __hadd2(l, r);
  l = sum;
  return l;
}
|
|
@@ -655,7 +1035,8 @@ __HOST_DEVICE__ __hip_bfloat162& operator+=(__hip_bfloat162& l, const __hip_bflo
|
|
|
655
1035
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
656
1036
|
* \brief Operator to subtract-assign two __hip_bfloat162 numbers
|
|
657
1037
|
*/
|
|
658
|
-
|
|
1038
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator-=(__hip_bfloat162& l,
                                                        const __hip_bfloat162& r) {
  // Lane-wise subtract in place; returns the updated left operand.
  const __hip_bfloat162 difference = __hsub2(l, r);
  l = difference;
  return l;
}
|
|
@@ -664,7 +1045,8 @@ __HOST_DEVICE__ __hip_bfloat162& operator-=(__hip_bfloat162& l, const __hip_bflo
|
|
|
664
1045
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
665
1046
|
* \brief Operator to divide two __hip_bfloat162 numbers
|
|
666
1047
|
*/
|
|
667
|
-
|
|
1048
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator/(const __hip_bfloat162& l,
                                                      const __hip_bfloat162& r) {
  // Lane-wise quotient via __h2div.
  const __hip_bfloat162 quotient = __h2div(l, r);
  return quotient;
}
|
|
670
1052
|
|
|
@@ -672,7 +1054,8 @@ __HOST_DEVICE__ __hip_bfloat162 operator/(const __hip_bfloat162& l, const __hip_
|
|
|
672
1054
|
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
|
673
1055
|
* \brief Operator to divide-assign two __hip_bfloat162 numbers
|
|
674
1056
|
*/
|
|
675
|
-
|
|
1057
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator/=(__hip_bfloat162& l,
                                                        const __hip_bfloat162& r) {
  // Lane-wise divide in place; returns the updated left operand.
  const __hip_bfloat162 quotient = __h2div(l, r);
  l = quotient;
  return l;
}
|
|
@@ -681,7 +1064,7 @@ __HOST_DEVICE__ __hip_bfloat162& operator/=(__hip_bfloat162& l, const __hip_bflo
|
|
|
681
1064
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
682
1065
|
* \brief Compare two bfloat16 values - equal
|
|
683
1066
|
*/
|
|
684
|
-
|
|
1067
|
+
__BF16_HOST_DEVICE_STATIC__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Ordered equality of the widened values (false when either input is NaN).
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  return fa == fb;
}
|
|
687
1070
|
|
|
@@ -689,7 +1072,7 @@ __HOST_DEVICE__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
689
1072
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
690
1073
|
* \brief Compare two bfloat16 values - unordered equal
|
|
691
1074
|
*/
|
|
692
|
-
|
|
1075
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Unordered equality: true when neither a < b nor a > b holds, which includes NaN inputs.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  return !(fa < fb || fa > fb);
}
|
|
@@ -698,7 +1081,7 @@ __HOST_DEVICE__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
698
1081
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
699
1082
|
* \brief Compare two bfloat16 values - greater than
|
|
700
1083
|
*/
|
|
701
|
-
|
|
1084
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Ordered a > b on the widened values.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  return fa > fb;
}
|
|
704
1087
|
|
|
@@ -706,7 +1089,7 @@ __HOST_DEVICE__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
706
1089
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
707
1090
|
* \brief Compare two bfloat16 values - unordered greater than
|
|
708
1091
|
*/
|
|
709
|
-
|
|
1092
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Unordered a > b: the negation of a <= b, so NaN inputs yield true.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  const bool ordered_le = fa <= fb;
  return !ordered_le;
}
|
|
712
1095
|
|
|
@@ -714,7 +1097,7 @@ __HOST_DEVICE__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
714
1097
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
715
1098
|
* \brief Compare two bfloat16 values - greater than equal
|
|
716
1099
|
*/
|
|
717
|
-
|
|
1100
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Ordered a >= b on the widened values.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  return fa >= fb;
}
|
|
720
1103
|
|
|
@@ -722,7 +1105,7 @@ __HOST_DEVICE__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
722
1105
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
723
1106
|
* \brief Compare two bfloat16 values - unordered greater than equal
|
|
724
1107
|
*/
|
|
725
|
-
|
|
1108
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Unordered a >= b: the negation of a < b, so NaN inputs yield true.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  const bool ordered_lt = fa < fb;
  return !ordered_lt;
}
|
|
728
1111
|
|
|
@@ -730,7 +1113,7 @@ __HOST_DEVICE__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
730
1113
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
731
1114
|
* \brief Compare two bfloat16 values - not equal
|
|
732
1115
|
*/
|
|
733
|
-
|
|
1116
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Ordered not-equal: true only when the values compare strictly less or greater.
  // A plain `!=` would also be true when either input is NaN, which made this identical
  // to the unordered __hneu; ordered comparisons must be false for NaN inputs.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  return (fa < fb) || (fa > fb);
}
|
|
736
1119
|
|
|
@@ -738,7 +1121,7 @@ __HOST_DEVICE__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
738
1121
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
739
1122
|
* \brief Compare two bfloat16 values - unordered not equal
|
|
740
1123
|
*/
|
|
741
|
-
|
|
1124
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Unordered not-equal: the negation of ordered equality, so NaN inputs yield true.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  const bool ordered_eq = fa == fb;
  return !ordered_eq;
}
|
|
744
1127
|
|
|
@@ -746,7 +1129,7 @@ __HOST_DEVICE__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
746
1129
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
747
1130
|
* \brief Compare two bfloat16 values - return max
|
|
748
1131
|
*/
|
|
749
|
-
|
|
1132
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
750
1133
|
#if __HIP_DEVICE_COMPILE__
|
|
751
1134
|
return __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b)));
|
|
752
1135
|
#else
|
|
@@ -758,7 +1141,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat
|
|
|
758
1141
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
759
1142
|
* \brief Compare two bfloat16 values - return min
|
|
760
1143
|
*/
|
|
761
|
-
|
|
1144
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
762
1145
|
#if __HIP_DEVICE_COMPILE__
|
|
763
1146
|
return __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b)));
|
|
764
1147
|
#else
|
|
@@ -770,7 +1153,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat
|
|
|
770
1153
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
771
1154
|
* \brief Compare two bfloat16 values - less than
|
|
772
1155
|
*/
|
|
773
|
-
|
|
1156
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Ordered a < b on the widened values.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  return fa < fb;
}
|
|
776
1159
|
|
|
@@ -778,7 +1161,7 @@ __HOST_DEVICE__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
778
1161
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
779
1162
|
* \brief Compare two bfloat16 values - unordered less than
|
|
780
1163
|
*/
|
|
781
|
-
|
|
1164
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Unordered a < b: the negation of a >= b, so NaN inputs yield true.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  const bool ordered_ge = fa >= fb;
  return !ordered_ge;
}
|
|
784
1167
|
|
|
@@ -786,7 +1169,7 @@ __HOST_DEVICE__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
786
1169
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
787
1170
|
* \brief Compare two bfloat16 values - less than equal
|
|
788
1171
|
*/
|
|
789
|
-
|
|
1172
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Ordered a <= b on the widened values.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  return fa <= fb;
}
|
|
792
1175
|
|
|
@@ -794,7 +1177,7 @@ __HOST_DEVICE__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
794
1177
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
795
1178
|
* \brief Compare two bfloat16 values - unordered less than equal
|
|
796
1179
|
*/
|
|
797
|
-
|
|
1180
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
  // Unordered a <= b: the negation of a > b, so NaN inputs yield true.
  const float fa = __bfloat162float(a);
  const float fb = __bfloat162float(b);
  const bool ordered_gt = fa > fb;
  return !ordered_gt;
}
|
|
800
1183
|
|
|
@@ -802,208 +1185,282 @@ __HOST_DEVICE__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
|
|
|
802
1185
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
803
1186
|
* \brief Checks if number is inf
|
|
804
1187
|
*/
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
int res = __ocml_isinf_f32(__bfloat162float(a));
|
|
809
|
-
#else
|
|
810
|
-
int res = std::isinf(__bfloat162float(a)) ? 1 : 0;
|
|
811
|
-
#endif
|
|
812
|
-
return (res == 0) ? res : ((sign != 0U) ? -res : res);
|
|
1188
|
+
__BF16_HOST_DEVICE_STATIC__ int __hisinf(const __hip_bfloat16 a) {
|
|
1189
|
+
__hip_bfloat16_raw hr = a;
|
|
1190
|
+
return !(~hr.x & 0x7f80) && !(hr.x & 0x7f);
|
|
813
1191
|
}
|
|
814
1192
|
|
|
815
1193
|
/**
|
|
816
1194
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
817
1195
|
* \brief Checks if number is nan
|
|
818
1196
|
*/
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
return
|
|
822
|
-
#else
|
|
823
|
-
return std::isnan(__bfloat162float(a));
|
|
824
|
-
#endif
|
|
1197
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hisnan(const __hip_bfloat16 a) {
|
|
1198
|
+
__hip_bfloat16_raw hr = a;
|
|
1199
|
+
return !(~hr.x & 0x7f80) && +(hr.x & 0x7f);
|
|
825
1200
|
}
|
|
826
1201
|
|
|
827
1202
|
/**
|
|
828
1203
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
829
1204
|
* \brief Checks if two numbers are equal
|
|
830
1205
|
*/
|
|
831
|
-
|
|
832
|
-
|
|
1206
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if both lanes satisfy __heq; the high lane is checked only when the low one passes.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__heq(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __heq(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
834
1212
|
|
|
835
1213
|
/**
|
|
836
1214
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
837
1215
|
* \brief Checks if two numbers are equal - unordered
|
|
838
1216
|
*/
|
|
839
|
-
|
|
840
|
-
|
|
1217
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if both lanes satisfy the unordered-equal test __hequ.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__hequ(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __hequ(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
842
1223
|
|
|
843
1224
|
/**
|
|
844
1225
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
845
1226
|
* \brief Check for a >= b
|
|
846
1227
|
*/
|
|
847
|
-
|
|
848
|
-
|
|
1228
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if a >= b (ordered) holds in both lanes.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__hge(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __hge(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
850
1234
|
|
|
851
1235
|
/**
|
|
852
1236
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
853
1237
|
* \brief Check for a >= b - unordered
|
|
854
1238
|
*/
|
|
855
|
-
|
|
856
|
-
|
|
1239
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if unordered a >= b holds in both lanes.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__hgeu(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __hgeu(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
858
1245
|
|
|
859
1246
|
/**
|
|
860
1247
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
861
1248
|
* \brief Check for a > b
|
|
862
1249
|
*/
|
|
863
|
-
|
|
864
|
-
|
|
1250
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if a > b (ordered) holds in both lanes.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__hgt(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __hgt(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
866
1256
|
|
|
867
1257
|
/**
|
|
868
1258
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
869
1259
|
* \brief Check for a > b - unordered
|
|
870
1260
|
*/
|
|
871
|
-
|
|
872
|
-
|
|
1261
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if unordered a > b holds in both lanes.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__hgtu(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __hgtu(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
874
1267
|
|
|
875
1268
|
/**
|
|
876
1269
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
877
1270
|
* \brief Check for a <= b
|
|
878
1271
|
*/
|
|
879
|
-
|
|
880
|
-
|
|
1272
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if a <= b (ordered) holds in both lanes.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__hle(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __hle(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
882
1278
|
|
|
883
1279
|
/**
|
|
884
1280
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
885
1281
|
* \brief Check for a <= b - unordered
|
|
886
1282
|
*/
|
|
887
|
-
|
|
888
|
-
|
|
1283
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if unordered a <= b holds in both lanes.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__hleu(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __hleu(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
890
1289
|
|
|
891
1290
|
/**
|
|
892
1291
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
893
1292
|
* \brief Check for a < b
|
|
894
1293
|
*/
|
|
895
|
-
|
|
896
|
-
|
|
1294
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if a < b (ordered) holds in both lanes.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__hlt(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __hlt(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
898
1300
|
|
|
899
1301
|
/**
|
|
900
1302
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
901
1303
|
* \brief Check for a < b - unordered
|
|
902
1304
|
*/
|
|
903
|
-
|
|
904
|
-
|
|
1305
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if unordered a < b holds in both lanes.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  if (!__hltu(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x})) {
    return false;
  }
  return __hltu(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
}
|
|
906
1311
|
|
|
907
1312
|
/**
|
|
908
1313
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
909
1314
|
* \brief Check for a != b
|
|
910
1315
|
*/
|
|
911
|
-
|
|
912
|
-
|
|
1316
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  // True only if both lanes satisfy __hne.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const __hip_bfloat16 lhs_lo = __hip_bfloat16(__hip_bfloat16_raw{lhs.x});
  const __hip_bfloat16 rhs_lo = __hip_bfloat16(__hip_bfloat16_raw{rhs.x});
  if (!__hne(lhs_lo, rhs_lo)) {
    return false;
  }
  const __hip_bfloat16 lhs_hi = __hip_bfloat16(__hip_bfloat16_raw{lhs.y});
  const __hip_bfloat16 rhs_hi = __hip_bfloat16(__hip_bfloat16_raw{rhs.y});
  return __hne(lhs_hi, rhs_hi);
}
|
|
914
1323
|
|
|
915
1324
|
/**
|
|
916
1325
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
917
1326
|
* \brief Check for a != b - unordered
|
|
918
1327
|
*/
|
|
919
|
-
|
|
920
|
-
|
|
1328
|
+
__BF16_HOST_DEVICE_STATIC__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) {
  __hip_bfloat162_raw hr_a = a;
  __hip_bfloat162_raw hr_b = b;
  // Like every other __hb*2 predicate (including __hbne2), the result is true only when
  // BOTH lanes satisfy the scalar test; the previous `||` returned true if either lane did.
  return __hneu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) &&
      __hneu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y});
}
|
|
922
1334
|
|
|
923
1335
|
/**
|
|
924
1336
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
925
1337
|
* \brief Check for a == b, returns 1.0 if equal, otherwise 0.0
|
|
926
1338
|
*/
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
1339
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __heq2(const __hip_bfloat162 a,
                                                   const __hip_bfloat162 b) {
  // Lane-wise mask: 1.0 where the lanes compare equal (ordered), 0.0 otherwise.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const bool lo = __heq(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const bool hi = __heq(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162{{lo ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
                         {hi ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
}
|
|
931
1349
|
|
|
932
1350
|
/**
|
|
933
1351
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
934
1352
|
* \brief Check for a >= b, returns 1.0 if greater than or equal, otherwise 0.0
|
|
935
1353
|
*/
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
1354
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hge2(const __hip_bfloat162 a,
                                                   const __hip_bfloat162 b) {
  // Lane-wise a >= b; each result lane is 1.0 when it holds and 0.0 otherwise.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const bool ge_x = __hge(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const bool ge_y = __hge(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162{{ge_x ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
                         {ge_y ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
}
|
|
940
1364
|
|
|
941
1365
|
/**
|
|
942
1366
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
943
1367
|
* \brief Check for a > b, returns 1.0 if greater than, otherwise 0.0
|
|
944
1368
|
*/
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
1369
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a,
                                                   const __hip_bfloat162 b) {
  // Lane-wise a > b: each result lane is 1.0 when the comparison holds and
  // 0.0 otherwise. Both lanes must fall back to HIPRT_ZERO_BF16. A y lane of
  // "ONE : ONE" would report a > b for every input.
  __hip_bfloat162_raw hr_a = a;
  __hip_bfloat162_raw hr_b = b;
  return __hip_bfloat162{
      {__hgt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16
                                                                       : HIPRT_ZERO_BF16},
      {__hgt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16
                                                                       : HIPRT_ZERO_BF16}};
}
|
|
949
1379
|
|
|
950
1380
|
/**
|
|
951
1381
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
952
1382
|
* \brief Check if a is NaN, returns 1.0 if NaN, otherwise 0.0
|
|
953
1383
|
*/
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
1384
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) {
  // Lane-wise NaN test: each result lane is 1.0 if the matching lane of a is
  // NaN and 0.0 otherwise. The y lane must also use HIPRT_ZERO_BF16 in its
  // false branch, otherwise every input is reported as NaN in y.
  __hip_bfloat162_raw hr_a = a;
  return __hip_bfloat162{{__hisnan(__hip_bfloat16_raw{hr_a.x}) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
                         {__hisnan(__hip_bfloat16_raw{hr_a.y}) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
}
|
|
958
1389
|
|
|
959
1390
|
/**
|
|
960
1391
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
961
1392
|
* \brief Check for a <= b, returns 1.0 if less than or equal, otherwise 0.0
|
|
962
1393
|
*/
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
1394
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hle2(const __hip_bfloat162 a,
                                                   const __hip_bfloat162 b) {
  // Lane-wise a <= b; each result lane is 1.0 when it holds and 0.0 otherwise.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const bool le_x = __hle(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const bool le_y = __hle(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162{{le_x ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
                         {le_y ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
}
|
|
967
1404
|
|
|
968
1405
|
/**
|
|
969
1406
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
970
1407
|
* \brief Check for a < b, returns 1.0 if less than, otherwise 0.0
|
|
971
1408
|
*/
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
1409
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a,
                                                   const __hip_bfloat162 b) {
  // Lane-wise a < b; each result lane is 1.0 when it holds and 0.0 otherwise.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const bool lt_x = __hlt(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const bool lt_y = __hlt(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162{{lt_x ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
                         {lt_y ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
}
|
|
976
1419
|
|
|
977
1420
|
/**
|
|
978
1421
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
979
1422
|
* \brief Returns max of two elements
|
|
980
1423
|
*/
|
|
981
|
-
|
|
982
|
-
|
|
1424
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a,
                                                    const __hip_bfloat162 b) {
  // Lane-wise maximum built from the scalar __hmax.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const __hip_bfloat16 max_x = __hmax(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const __hip_bfloat16 max_y = __hmax(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162(max_x, max_y);
}
|
|
984
1431
|
|
|
985
1432
|
/**
|
|
986
1433
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
987
1434
|
* \brief Returns min of two elements
|
|
988
1435
|
*/
|
|
989
|
-
|
|
990
|
-
|
|
1436
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a,
                                                    const __hip_bfloat162 b) {
  // Lane-wise minimum built from the scalar __hmin.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const __hip_bfloat16 min_x = __hmin(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const __hip_bfloat16 min_y = __hmin(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162(min_x, min_y);
}
|
|
992
1443
|
|
|
993
1444
|
/**
|
|
994
1445
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
995
1446
|
* \brief Check for a != b, returns 1.0 if not equal, otherwise 0.0
|
|
996
1447
|
*/
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1448
|
+
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hne2(const __hip_bfloat162 a,
                                                   const __hip_bfloat162 b) {
  // Lane-wise a != b; each result lane is 1.0 when it holds and 0.0 otherwise.
  const __hip_bfloat162_raw lhs = a;
  const __hip_bfloat162_raw rhs = b;
  const bool ne_x = __hne(__hip_bfloat16_raw{lhs.x}, __hip_bfloat16_raw{rhs.x});
  const bool ne_y = __hne(__hip_bfloat16_raw{lhs.y}, __hip_bfloat16_raw{rhs.y});
  return __hip_bfloat162{{ne_x ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16},
                         {ne_y ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}};
}
|
|
1001
1458
|
|
|
1002
1459
|
/**
|
|
1003
1460
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
1004
1461
|
* \brief Operator to perform an equal compare on two __hip_bfloat16 numbers
|
|
1005
1462
|
*/
|
|
1006
|
-
|
|
1463
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r) {
  // Forwards to the scalar equality intrinsic __heq.
  const bool result = __heq(l, r);
  return result;
}
|
|
1009
1466
|
|
|
@@ -1011,7 +1468,7 @@ __HOST_DEVICE__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r
|
|
|
1011
1468
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
1012
1469
|
* \brief Operator to perform a not equal on two __hip_bfloat16 numbers
|
|
1013
1470
|
*/
|
|
1014
|
-
|
|
1471
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
  // Forwards to the scalar not-equal intrinsic __hne.
  const bool result = __hne(l, r);
  return result;
}
|
|
1017
1474
|
|
|
@@ -1019,7 +1476,7 @@ __HOST_DEVICE__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r
|
|
|
1019
1476
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
1020
1477
|
* \brief Operator to perform a less than on two __hip_bfloat16 numbers
|
|
1021
1478
|
*/
|
|
1022
|
-
|
|
1479
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) {
  // Forwards to the scalar less-than intrinsic __hlt.
  const bool result = __hlt(l, r);
  return result;
}
|
|
1025
1482
|
|
|
@@ -1027,7 +1484,7 @@ __HOST_DEVICE__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r)
|
|
|
1027
1484
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
1028
1485
|
* \brief Operator to perform a less than equal on two __hip_bfloat16 numbers
|
|
1029
1486
|
*/
|
|
1030
|
-
|
|
1487
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
  // Forwards to the scalar less-or-equal intrinsic __hle.
  const bool result = __hle(l, r);
  return result;
}
|
|
1033
1490
|
|
|
@@ -1035,7 +1492,7 @@ __HOST_DEVICE__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r
|
|
|
1035
1492
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
1036
1493
|
* \brief Operator to perform a greater than on two __hip_bfloat16 numbers
|
|
1037
1494
|
*/
|
|
1038
|
-
|
|
1495
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) {
  // Forwards to the scalar greater-than intrinsic __hgt.
  const bool result = __hgt(l, r);
  return result;
}
|
|
1041
1498
|
|
|
@@ -1043,7 +1500,7 @@ __HOST_DEVICE__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r)
|
|
|
1043
1500
|
* \ingroup HIP_INTRINSIC_BFLOAT16_COMP
|
|
1044
1501
|
* \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers
|
|
1045
1502
|
*/
|
|
1046
|
-
|
|
1503
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r) {
  // Forwards to the scalar greater-or-equal intrinsic __hge.
  const bool result = __hge(l, r);
  return result;
}
|
|
1049
1506
|
|
|
@@ -1051,55 +1508,60 @@ __HOST_DEVICE__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r
|
|
|
1051
1508
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
1052
1509
|
* \brief Operator to perform an equal compare on two __hip_bfloat162 numbers
|
|
1053
1510
|
*/
|
|
1054
|
-
|
|
1055
|
-
|
|
1511
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat162& l, const __hip_bfloat162& r) {
  // Equal only when __heq2 marks both lanes with a non-zero result.
  const float2 lanes = __heq2(l, r);
  if (lanes.x == 0.0f) {
    return false;
  }
  return lanes.y != 0.0f;
}
|
|
1057
1515
|
|
|
1058
1516
|
/**
|
|
1059
1517
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
1060
1518
|
* \brief Operator to perform a not equal on two __hip_bfloat162 numbers
|
|
1061
1519
|
*/
|
|
1062
|
-
|
|
1063
|
-
return
|
|
1520
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
  // Defined as the negation of the __hip_bfloat162 operator==.
  const bool equal = (l == r);
  return !equal;
}
|
|
1065
1523
|
|
|
1066
1524
|
/**
|
|
1067
1525
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
1068
1526
|
* \brief Operator to perform a less than on two __hip_bfloat162 numbers
|
|
1069
1527
|
*/
|
|
1070
|
-
|
|
1071
|
-
|
|
1528
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat162& l, const __hip_bfloat162& r) {
  // True only when every lane satisfies l < r. Lanes are compared x-to-x and
  // y-to-y; comparing l.x against r.y would mix lanes and ignore l.y entirely.
  float2 fl = l, fr = r;
  return fl.x < fr.x && fl.y < fr.y;
}
|
|
1073
1532
|
|
|
1074
1533
|
/**
|
|
1075
1534
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
1076
1535
|
* \brief Operator to perform a less than equal on two __hip_bfloat162 numbers
|
|
1077
1536
|
*/
|
|
1078
|
-
|
|
1079
|
-
|
|
1537
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
  // True only when every lane satisfies l <= r. Lanes are compared x-to-x
  // and y-to-y; comparing l.x against r.y would mix lanes and ignore l.y.
  float2 fl = l, fr = r;
  return fl.x <= fr.x && fl.y <= fr.y;
}
|
|
1081
1541
|
|
|
1082
1542
|
/**
|
|
1083
1543
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
1084
1544
|
* \brief Operator to perform a greater than on two __hip_bfloat162 numbers
|
|
1085
1545
|
*/
|
|
1086
|
-
|
|
1087
|
-
|
|
1546
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat162& l, const __hip_bfloat162& r) {
  // True only when every lane satisfies l > r. Lanes are compared x-to-x and
  // y-to-y; comparing l.x against r.y would mix lanes and ignore l.y entirely.
  float2 fl = l, fr = r;
  return fl.x > fr.x && fl.y > fr.y;
}
|
|
1089
1550
|
|
|
1090
1551
|
/**
|
|
1091
1552
|
* \ingroup HIP_INTRINSIC_BFLOAT162_COMP
|
|
1092
1553
|
* \brief Operator to perform a greater than equal on two __hip_bfloat162 numbers
|
|
1093
1554
|
*/
|
|
1094
|
-
|
|
1095
|
-
|
|
1555
|
+
__BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat162& l, const __hip_bfloat162& r) {
  // True only when every lane satisfies l >= r. Lanes are compared x-to-x
  // and y-to-y; comparing l.x against r.y would mix lanes and ignore l.y.
  float2 fl = l, fr = r;
  return fl.x >= fr.x && fl.y >= fr.y;
}
|
|
1097
1559
|
|
|
1098
1560
|
/**
|
|
1099
1561
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1100
1562
|
* \brief Calculate ceil of bfloat16
|
|
1101
1563
|
*/
|
|
1102
|
-
|
|
1564
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hceil(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_ceil_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_ceil_f32(widened));
}
|
|
1105
1567
|
|
|
@@ -1107,7 +1569,7 @@ __device__ __hip_bfloat16 hceil(const __hip_bfloat16 h) {
|
|
|
1107
1569
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1108
1570
|
* \brief Calculate cosine of bfloat16
|
|
1109
1571
|
*/
|
|
1110
|
-
|
|
1572
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hcos(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_cos_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_cos_f32(widened));
}
|
|
1113
1575
|
|
|
@@ -1115,7 +1577,7 @@ __device__ __hip_bfloat16 hcos(const __hip_bfloat16 h) {
|
|
|
1115
1577
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1116
1578
|
* \brief Calculate exponential of bfloat16
|
|
1117
1579
|
*/
|
|
1118
|
-
|
|
1580
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hexp(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_exp_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_exp_f32(widened));
}
|
|
1121
1583
|
|
|
@@ -1123,7 +1585,7 @@ __device__ __hip_bfloat16 hexp(const __hip_bfloat16 h) {
|
|
|
1123
1585
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1124
1586
|
* \brief Calculate exponential 10 of bfloat16
|
|
1125
1587
|
*/
|
|
1126
|
-
|
|
1588
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hexp10(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_exp10_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_exp10_f32(widened));
}
|
|
1129
1591
|
|
|
@@ -1131,7 +1593,7 @@ __device__ __hip_bfloat16 hexp10(const __hip_bfloat16 h) {
|
|
|
1131
1593
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1132
1594
|
* \brief Calculate exponential 2 of bfloat16
|
|
1133
1595
|
*/
|
|
1134
|
-
|
|
1596
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hexp2(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_exp2_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_exp2_f32(widened));
}
|
|
1137
1599
|
|
|
@@ -1139,7 +1601,7 @@ __device__ __hip_bfloat16 hexp2(const __hip_bfloat16 h) {
|
|
|
1139
1601
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1140
1602
|
* \brief Calculate floor of bfloat16
|
|
1141
1603
|
*/
|
|
1142
|
-
|
|
1604
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hfloor(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_floor_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_floor_f32(widened));
}
|
|
1145
1607
|
|
|
@@ -1147,7 +1609,7 @@ __device__ __hip_bfloat16 hfloor(const __hip_bfloat16 h) {
|
|
|
1147
1609
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1148
1610
|
* \brief Calculate natural log of bfloat16
|
|
1149
1611
|
*/
|
|
1150
|
-
|
|
1612
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hlog(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_log_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_log_f32(widened));
}
|
|
1153
1615
|
|
|
@@ -1155,7 +1617,7 @@ __device__ __hip_bfloat16 hlog(const __hip_bfloat16 h) {
|
|
|
1155
1617
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1156
1618
|
* \brief Calculate log 10 of bfloat16
|
|
1157
1619
|
*/
|
|
1158
|
-
|
|
1620
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hlog10(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_log10_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_log10_f32(widened));
}
|
|
1161
1623
|
|
|
@@ -1163,7 +1625,7 @@ __device__ __hip_bfloat16 hlog10(const __hip_bfloat16 h) {
|
|
|
1163
1625
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1164
1626
|
* \brief Calculate log 2 of bfloat16
|
|
1165
1627
|
*/
|
|
1166
|
-
|
|
1628
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hlog2(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_log2_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_log2_f32(widened));
}
|
|
1169
1631
|
|
|
@@ -1171,7 +1633,7 @@ __device__ __hip_bfloat16 hlog2(const __hip_bfloat16 h) {
|
|
|
1171
1633
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1172
1634
|
* \brief Calculate reciprocal
|
|
1173
1635
|
*/
|
|
1174
|
-
|
|
1636
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hrcp(const __hip_bfloat16 h) {
  // Reciprocal computed as a float division, then narrowed back to bfloat16.
  const float denom = __bfloat162float(h);
  return __float2bfloat16(1.0f / denom);
}
|
|
1177
1639
|
|
|
@@ -1179,7 +1641,7 @@ __device__ __hip_bfloat16 hrcp(const __hip_bfloat16 h) {
|
|
|
1179
1641
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1180
1642
|
* \brief Round to nearest int
|
|
1181
1643
|
*/
|
|
1182
|
-
|
|
1644
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hrint(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_rint_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_rint_f32(widened));
}
|
|
1185
1647
|
|
|
@@ -1187,7 +1649,7 @@ __device__ __hip_bfloat16 hrint(const __hip_bfloat16 h) {
|
|
|
1187
1649
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1188
1650
|
* \brief Reciprocal square root
|
|
1189
1651
|
*/
|
|
1190
|
-
|
|
1652
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hrsqrt(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_rsqrt_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_rsqrt_f32(widened));
}
|
|
1193
1655
|
|
|
@@ -1195,7 +1657,7 @@ __device__ __hip_bfloat16 hrsqrt(const __hip_bfloat16 h) {
|
|
|
1195
1657
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1196
1658
|
* \brief Calculate sin of bfloat16
|
|
1197
1659
|
*/
|
|
1198
|
-
|
|
1660
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hsin(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_sin_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_sin_f32(widened));
}
|
|
1201
1663
|
|
|
@@ -1203,7 +1665,7 @@ __device__ __hip_bfloat16 hsin(const __hip_bfloat16 h) {
|
|
|
1203
1665
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1204
1666
|
* \brief Calculate sqrt of bfloat16
|
|
1205
1667
|
*/
|
|
1206
|
-
|
|
1668
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_sqrt_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_sqrt_f32(widened));
}
|
|
1209
1671
|
|
|
@@ -1211,7 +1673,7 @@ __device__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) {
|
|
|
1211
1673
|
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
|
|
1212
1674
|
* \brief Calculate truncate of bfloat16
|
|
1213
1675
|
*/
|
|
1214
|
-
|
|
1676
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat16 htrunc(const __hip_bfloat16 h) {
  // Widen to float, apply __ocml_trunc_f32, and narrow back to bfloat16.
  const float widened = __bfloat162float(h);
  return __float2bfloat16(__ocml_trunc_f32(widened));
}
|
|
1217
1679
|
|
|
@@ -1219,119 +1681,134 @@ __device__ __hip_bfloat16 htrunc(const __hip_bfloat16 h) {
|
|
|
1219
1681
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1220
1682
|
* \brief Calculate ceil of bfloat162
|
|
1221
1683
|
*/
|
|
1222
|
-
|
|
1223
|
-
|
|
1684
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2ceil(const __hip_bfloat162 h) {
  // Apply the scalar hceil to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hceil(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hceil(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1225
1688
|
|
|
1226
1689
|
/**
|
|
1227
1690
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1228
1691
|
* \brief Calculate cosine of bfloat162
|
|
1229
1692
|
*/
|
|
1230
|
-
|
|
1231
|
-
|
|
1693
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2cos(const __hip_bfloat162 h) {
  // Apply the scalar hcos to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hcos(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hcos(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1233
1697
|
|
|
1234
1698
|
/**
|
|
1235
1699
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1236
1700
|
* \brief Calculate exponential of bfloat162
|
|
1237
1701
|
*/
|
|
1238
|
-
|
|
1239
|
-
|
|
1702
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp(const __hip_bfloat162 h) {
  // Apply the scalar hexp to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hexp(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hexp(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1241
1706
|
|
|
1242
1707
|
/**
|
|
1243
1708
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1244
1709
|
* \brief Calculate exponential 10 of bfloat162
|
|
1245
1710
|
*/
|
|
1246
|
-
|
|
1247
|
-
|
|
1711
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp10(const __hip_bfloat162 h) {
  // Apply the scalar hexp10 to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hexp10(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hexp10(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1249
1715
|
|
|
1250
1716
|
/**
|
|
1251
1717
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1252
1718
|
* \brief Calculate exponential 2 of bfloat162
|
|
1253
1719
|
*/
|
|
1254
|
-
|
|
1255
|
-
|
|
1720
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp2(const __hip_bfloat162 h) {
  // Apply the scalar hexp2 to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hexp2(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hexp2(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1257
1724
|
|
|
1258
1725
|
/**
|
|
1259
1726
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1260
1727
|
* \brief Calculate floor of bfloat162
|
|
1261
1728
|
*/
|
|
1262
|
-
|
|
1263
|
-
|
|
1729
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2floor(const __hip_bfloat162 h) {
  // Apply the scalar hfloor to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hfloor(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hfloor(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1265
1733
|
|
|
1266
1734
|
/**
|
|
1267
1735
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1268
1736
|
* \brief Calculate natural log of bfloat162
|
|
1269
1737
|
*/
|
|
1270
|
-
|
|
1271
|
-
|
|
1738
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2log(const __hip_bfloat162 h) {
  // Apply the scalar hlog to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hlog(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hlog(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1273
1742
|
|
|
1274
1743
|
/**
|
|
1275
1744
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1276
1745
|
* \brief Calculate log 10 of bfloat162
|
|
1277
1746
|
*/
|
|
1278
|
-
|
|
1279
|
-
|
|
1747
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2log10(const __hip_bfloat162 h) {
  // Apply the scalar hlog10 to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hlog10(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hlog10(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1281
1751
|
|
|
1282
1752
|
/**
|
|
1283
1753
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1284
1754
|
* \brief Calculate log 2 of bfloat162
|
|
1285
1755
|
*/
|
|
1286
|
-
|
|
1287
|
-
|
|
1756
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2log2(const __hip_bfloat162 h) {
  // Apply the scalar hlog2 to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hlog2(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hlog2(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1289
1760
|
|
|
1290
1761
|
/**
|
|
1291
1762
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1292
1763
|
* \brief Calculate vector reciprocal
|
|
1293
1764
|
*/
|
|
1294
|
-
|
|
1295
|
-
|
|
1765
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2rcp(const __hip_bfloat162 h) {
  // Apply the scalar hrcp to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hrcp(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hrcp(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1297
1769
|
|
|
1298
1770
|
/**
|
|
1299
1771
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1300
1772
|
* \brief Calculate vector round to nearest int
|
|
1301
1773
|
*/
|
|
1302
|
-
|
|
1303
|
-
|
|
1774
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2rint(const __hip_bfloat162 h) {
  // Apply the scalar hrint to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hrint(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hrint(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1305
1778
|
|
|
1306
1779
|
/**
|
|
1307
1780
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1308
1781
|
* \brief Calculate vector reciprocal square root
|
|
1309
1782
|
*/
|
|
1310
|
-
|
|
1311
|
-
|
|
1783
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2rsqrt(const __hip_bfloat162 h) {
  // Apply the scalar hrsqrt to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hrsqrt(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hrsqrt(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1313
1787
|
|
|
1314
1788
|
/**
|
|
1315
1789
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1316
1790
|
* \brief Calculate sin of bfloat162
|
|
1317
1791
|
*/
|
|
1318
|
-
|
|
1319
|
-
|
|
1792
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2sin(const __hip_bfloat162 h) {
  // Apply the scalar hsin to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hsin(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hsin(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1321
1796
|
|
|
1322
1797
|
/**
|
|
1323
1798
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1324
1799
|
* \brief Calculate sqrt of bfloat162
|
|
1325
1800
|
*/
|
|
1326
|
-
|
|
1327
|
-
|
|
1801
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2sqrt(const __hip_bfloat162 h) {
  // Apply the scalar hsqrt to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = hsqrt(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = hsqrt(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1329
1805
|
|
|
1330
1806
|
/**
|
|
1331
1807
|
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
|
|
1332
1808
|
* \brief Calculate truncate of bfloat162
|
|
1333
1809
|
*/
|
|
1334
|
-
|
|
1335
|
-
|
|
1810
|
+
__BF16_DEVICE_STATIC__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) {
  // Apply the scalar htrunc to each lane independently.
  const __hip_bfloat162_raw lanes = h;
  const __hip_bfloat16 rx = htrunc(__hip_bfloat16_raw{lanes.x});
  const __hip_bfloat16 ry = htrunc(__hip_bfloat16_raw{lanes.y});
  return __hip_bfloat162(rx, ry);
}
|
|
1337
1814
|
#endif
|