triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (225) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
@@ -1,1809 +0,0 @@
1
- /*
2
- Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved.
3
-
4
- Permission is hereby granted, free of charge, to any person obtaining a copy
5
- of this software and associated documentation files (the "Software"), to deal
6
- in the Software without restriction, including without limitation the rights
7
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8
- copies of the Software, and to permit persons to whom the Software is
9
- furnished to do so, subject to the following conditions:
10
-
11
- The above copyright notice and this permission notice shall be included in
12
- all copies or substantial portions of the Software.
13
-
14
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20
- THE SOFTWARE.
21
- */
22
-
23
- #pragma once
24
- #ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H
25
- #define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H
26
-
27
- #if defined(__HIPCC_RTC__)
28
- #define __HOST_DEVICE__ __device__
29
- #else
30
- #define __HOST_DEVICE__ __host__ __device__
31
- #include <hip/amd_detail/amd_hip_common.h>
32
- #include "hip/amd_detail/host_defines.h"
33
- #include <assert.h>
34
- #if defined(__cplusplus)
35
- #include <algorithm>
36
- #include <type_traits>
37
- #include <utility>
38
- #endif
39
- #endif // !defined(__HIPCC_RTC__)
40
-
41
- #if defined(__clang__) && defined(__HIP__)
42
- typedef _Float16 _Float16_2 __attribute__((ext_vector_type(2)));
43
-
44
- struct __half_raw {
45
- union {
46
- static_assert(sizeof(_Float16) == sizeof(unsigned short), "");
47
-
48
- _Float16 data;
49
- unsigned short x;
50
- };
51
- };
52
-
53
- struct __half2_raw {
54
- union {
55
- static_assert(sizeof(_Float16_2) == sizeof(unsigned short[2]), "");
56
-
57
- struct {
58
- __half_raw x;
59
- __half_raw y;
60
- };
61
- _Float16_2 data;
62
- };
63
- };
64
-
65
- #if defined(__cplusplus)
66
- #if !defined(__HIPCC_RTC__)
67
- #include "hip_fp16_math_fwd.h"
68
- #include "amd_hip_vector_types.h"
69
- #include "host_defines.h"
70
- #include "amd_device_functions.h"
71
- #include "amd_warp_functions.h"
72
- #endif
73
- namespace std
74
- {
75
- template<> struct is_floating_point<_Float16> : std::true_type {};
76
- }
77
-
78
- template<bool cond, typename T = void>
79
- using Enable_if_t = typename std::enable_if<cond, T>::type;
80
-
81
- // BEGIN STRUCT __HALF
82
- struct __half {
83
- protected:
84
- union {
85
- static_assert(sizeof(_Float16) == sizeof(unsigned short), "");
86
-
87
- _Float16 data;
88
- unsigned short __x;
89
- };
90
- public:
91
- // CREATORS
92
- __HOST_DEVICE__
93
- __half() = default;
94
- __HOST_DEVICE__
95
- __half(const __half_raw& x) : data{x.data} {}
96
- #if !defined(__HIP_NO_HALF_CONVERSIONS__)
97
- __HOST_DEVICE__
98
- __half(decltype(data) x) : data{x} {}
99
- template<
100
- typename T,
101
- Enable_if_t<std::is_floating_point<T>{}>* = nullptr>
102
- __HOST_DEVICE__
103
- __half(T x) : data{static_cast<_Float16>(x)} {}
104
- #endif
105
- __HOST_DEVICE__
106
- __half(const __half&) = default;
107
- __HOST_DEVICE__
108
- __half(__half&&) = default;
109
- __HOST_DEVICE__
110
- ~__half() = default;
111
-
112
- // CREATORS - DEVICE ONLY
113
- #if !defined(__HIP_NO_HALF_CONVERSIONS__)
114
- template<
115
- typename T, Enable_if_t<std::is_integral<T>{}>* = nullptr>
116
- __HOST_DEVICE__
117
- __half(T x) : data{static_cast<_Float16>(x)} {}
118
- #endif
119
-
120
- // MANIPULATORS
121
- __HOST_DEVICE__
122
- __half& operator=(const __half&) = default;
123
- __HOST_DEVICE__
124
- __half& operator=(__half&&) = default;
125
- __HOST_DEVICE__
126
- __half& operator=(const __half_raw& x)
127
- {
128
- data = x.data;
129
- return *this;
130
- }
131
- __HOST_DEVICE__
132
- volatile __half& operator=(const __half_raw& x) volatile
133
- {
134
- data = x.data;
135
- return *this;
136
- }
137
- volatile __half& operator=(const volatile __half_raw& x) volatile
138
- {
139
- data = x.data;
140
- return *this;
141
- }
142
- __half& operator=(__half_raw&& x)
143
- {
144
- data = x.data;
145
- return *this;
146
- }
147
- volatile __half& operator=(__half_raw&& x) volatile
148
- {
149
- data = x.data;
150
- return *this;
151
- }
152
- volatile __half& operator=(volatile __half_raw&& x) volatile
153
- {
154
- data = x.data;
155
- return *this;
156
- }
157
- #if !defined(__HIP_NO_HALF_CONVERSIONS__)
158
- template<
159
- typename T,
160
- Enable_if_t<std::is_floating_point<T>{}>* = nullptr>
161
- __HOST_DEVICE__
162
- __half& operator=(T x)
163
- {
164
- data = static_cast<_Float16>(x);
165
- return *this;
166
- }
167
- #endif
168
-
169
- // MANIPULATORS - DEVICE ONLY
170
- #if !defined(__HIP_NO_HALF_CONVERSIONS__)
171
- template<
172
- typename T, Enable_if_t<std::is_integral<T>{}>* = nullptr>
173
- __device__
174
- __half& operator=(T x)
175
- {
176
- data = static_cast<_Float16>(x);
177
- return *this;
178
- }
179
- #endif
180
-
181
- #if !defined(__HIP_NO_HALF_OPERATORS__)
182
- __device__
183
- __half& operator+=(const __half& x)
184
- {
185
- data += x.data;
186
- return *this;
187
- }
188
- __device__
189
- __half& operator-=(const __half& x)
190
- {
191
- data -= x.data;
192
- return *this;
193
- }
194
- __device__
195
- __half& operator*=(const __half& x)
196
- {
197
- data *= x.data;
198
- return *this;
199
- }
200
- __device__
201
- __half& operator/=(const __half& x)
202
- {
203
- data /= x.data;
204
- return *this;
205
- }
206
- __device__
207
- __half& operator++() { ++data; return *this; }
208
- __device__
209
- __half operator++(int)
210
- {
211
- __half tmp{*this};
212
- ++*this;
213
- return tmp;
214
- }
215
- __device__
216
- __half& operator--() { --data; return *this; }
217
- __device__
218
- __half operator--(int)
219
- {
220
- __half tmp{*this};
221
- --*this;
222
- return tmp;
223
- }
224
- #endif
225
-
226
- // ACCESSORS
227
- #if !defined(__HIP_NO_HALF_CONVERSIONS__)
228
- template<
229
- typename T,
230
- Enable_if_t<std::is_floating_point<T>{}>* = nullptr>
231
- __HOST_DEVICE__
232
- operator T() const { return data; }
233
- #endif
234
- __HOST_DEVICE__
235
- operator __half_raw() const { return __half_raw{data}; }
236
- __HOST_DEVICE__
237
- operator __half_raw() const volatile
238
- {
239
- return __half_raw{data};
240
- }
241
-
242
- #if !defined(__HIP_NO_HALF_CONVERSIONS__)
243
- template<
244
- typename T, Enable_if_t<std::is_integral<T>{}>* = nullptr>
245
- __HOST_DEVICE__
246
- operator T() const { return data; }
247
- #endif
248
-
249
- #if !defined(__HIP_NO_HALF_OPERATORS__)
250
- __device__
251
- __half operator+() const { return *this; }
252
- __device__
253
- __half operator-() const
254
- {
255
- __half tmp{*this};
256
- tmp.data = -tmp.data;
257
- return tmp;
258
- }
259
- #endif
260
-
261
- // FRIENDS
262
- #if !defined(__HIP_NO_HALF_OPERATORS__)
263
- friend
264
- inline
265
- __device__
266
- __half operator+(const __half& x, const __half& y)
267
- {
268
- return __half{x} += y;
269
- }
270
- friend
271
- inline
272
- __device__
273
- __half operator-(const __half& x, const __half& y)
274
- {
275
- return __half{x} -= y;
276
- }
277
- friend
278
- inline
279
- __device__
280
- __half operator*(const __half& x, const __half& y)
281
- {
282
- return __half{x} *= y;
283
- }
284
- friend
285
- inline
286
- __device__
287
- __half operator/(const __half& x, const __half& y)
288
- {
289
- return __half{x} /= y;
290
- }
291
- friend
292
- inline
293
- __device__
294
- bool operator==(const __half& x, const __half& y)
295
- {
296
- return x.data == y.data;
297
- }
298
- friend
299
- inline
300
- __device__
301
- bool operator!=(const __half& x, const __half& y)
302
- {
303
- return !(x == y);
304
- }
305
- friend
306
- inline
307
- __device__
308
- bool operator<(const __half& x, const __half& y)
309
- {
310
- return x.data < y.data;
311
- }
312
- friend
313
- inline
314
- __device__
315
- bool operator>(const __half& x, const __half& y)
316
- {
317
- return y.data < x.data;
318
- }
319
- friend
320
- inline
321
- __device__
322
- bool operator<=(const __half& x, const __half& y)
323
- {
324
- return !(y < x);
325
- }
326
- friend
327
- inline
328
- __device__
329
- bool operator>=(const __half& x, const __half& y)
330
- {
331
- return !(x < y);
332
- }
333
- #endif // !defined(__HIP_NO_HALF_OPERATORS__)
334
- };
335
- // END STRUCT __HALF
336
-
337
- // BEGIN STRUCT __HALF2
338
- struct __half2 {
339
- public:
340
- union {
341
- static_assert(
342
- sizeof(_Float16_2) == sizeof(unsigned short[2]), "");
343
-
344
- struct {
345
- __half x;
346
- __half y;
347
- };
348
- _Float16_2 data;
349
- };
350
-
351
- // CREATORS
352
- __HOST_DEVICE__
353
- __half2() = default;
354
- __HOST_DEVICE__
355
- __half2(const __half2_raw& xx) : data{xx.data} {}
356
- __HOST_DEVICE__
357
- __half2(decltype(data) xx) : data{xx} {}
358
- __HOST_DEVICE__
359
- __half2(const __half& xx, const __half& yy)
360
- :
361
- data{static_cast<__half_raw>(xx).data,
362
- static_cast<__half_raw>(yy).data}
363
- {}
364
- __HOST_DEVICE__
365
- __half2(const __half2&) = default;
366
- __HOST_DEVICE__
367
- __half2(__half2&&) = default;
368
- __HOST_DEVICE__
369
- ~__half2() = default;
370
-
371
- // MANIPULATORS
372
- __HOST_DEVICE__
373
- __half2& operator=(const __half2&) = default;
374
- __HOST_DEVICE__
375
- __half2& operator=(__half2&&) = default;
376
- __HOST_DEVICE__
377
- __half2& operator=(const __half2_raw& xx)
378
- {
379
- data = xx.data;
380
- return *this;
381
- }
382
-
383
- // MANIPULATORS - DEVICE ONLY
384
- #if !defined(__HIP_NO_HALF_OPERATORS__)
385
- __device__
386
- __half2& operator+=(const __half2& xx)
387
- {
388
- data += xx.data;
389
- return *this;
390
- }
391
- __device__
392
- __half2& operator-=(const __half2& xx)
393
- {
394
- data -= xx.data;
395
- return *this;
396
- }
397
- __device__
398
- __half2& operator*=(const __half2& xx)
399
- {
400
- data *= xx.data;
401
- return *this;
402
- }
403
- __device__
404
- __half2& operator/=(const __half2& xx)
405
- {
406
- data /= xx.data;
407
- return *this;
408
- }
409
- __device__
410
- __half2& operator++() { return *this += _Float16_2{1, 1}; }
411
- __device__
412
- __half2 operator++(int)
413
- {
414
- __half2 tmp{*this};
415
- ++*this;
416
- return tmp;
417
- }
418
- __device__
419
- __half2& operator--() { return *this -= _Float16_2{1, 1}; }
420
- __device__
421
- __half2 operator--(int)
422
- {
423
- __half2 tmp{*this};
424
- --*this;
425
- return tmp;
426
- }
427
- #endif
428
-
429
- // ACCESSORS
430
- __HOST_DEVICE__
431
- operator decltype(data)() const { return data; }
432
- __HOST_DEVICE__
433
- operator __half2_raw() const {
434
- __half2_raw r;
435
- r.data = data;
436
- return r;
437
- }
438
-
439
- // ACCESSORS - DEVICE ONLY
440
- #if !defined(__HIP_NO_HALF_OPERATORS__)
441
- __device__
442
- __half2 operator+() const { return *this; }
443
- __device__
444
- __half2 operator-() const
445
- {
446
- __half2 tmp{*this};
447
- tmp.data = -tmp.data;
448
- return tmp;
449
- }
450
- #endif
451
-
452
- // FRIENDS
453
- #if !defined(__HIP_NO_HALF_OPERATORS__)
454
- friend
455
- inline
456
- __device__
457
- __half2 operator+(const __half2& xx, const __half2& yy)
458
- {
459
- return __half2{xx} += yy;
460
- }
461
- friend
462
- inline
463
- __device__
464
- __half2 operator-(const __half2& xx, const __half2& yy)
465
- {
466
- return __half2{xx} -= yy;
467
- }
468
- friend
469
- inline
470
- __device__
471
- __half2 operator*(const __half2& xx, const __half2& yy)
472
- {
473
- return __half2{xx} *= yy;
474
- }
475
- friend
476
- inline
477
- __device__
478
- __half2 operator/(const __half2& xx, const __half2& yy)
479
- {
480
- return __half2{xx} /= yy;
481
- }
482
- friend
483
- inline
484
- __device__
485
- bool operator==(const __half2& xx, const __half2& yy)
486
- {
487
- auto r = xx.data == yy.data;
488
- return r.x != 0 && r.y != 0;
489
- }
490
- friend
491
- inline
492
- __device__
493
- bool operator!=(const __half2& xx, const __half2& yy)
494
- {
495
- return !(xx == yy);
496
- }
497
- friend
498
- inline
499
- __device__
500
- bool operator<(const __half2& xx, const __half2& yy)
501
- {
502
- auto r = xx.data < yy.data;
503
- return r.x != 0 && r.y != 0;
504
- }
505
- friend
506
- inline
507
- __device__
508
- bool operator>(const __half2& xx, const __half2& yy)
509
- {
510
- return yy < xx;
511
- }
512
- friend
513
- inline
514
- __device__
515
- bool operator<=(const __half2& xx, const __half2& yy)
516
- {
517
- return !(yy < xx);
518
- }
519
- friend
520
- inline
521
- __device__
522
- bool operator>=(const __half2& xx, const __half2& yy)
523
- {
524
- return !(xx < yy);
525
- }
526
- #endif // !defined(__HIP_NO_HALF_OPERATORS__)
527
- };
528
- // END STRUCT __HALF2
529
-
530
- namespace
531
- {
532
- inline
533
- __HOST_DEVICE__
534
- __half2 make_half2(__half x, __half y)
535
- {
536
- return __half2{x, y};
537
- }
538
-
539
- inline
540
- __HOST_DEVICE__
541
- __half __low2half(__half2 x)
542
- {
543
- return __half{__half_raw{static_cast<__half2_raw>(x).data.x}};
544
- }
545
-
546
- inline
547
- __HOST_DEVICE__
548
- __half __high2half(__half2 x)
549
- {
550
- return __half{__half_raw{static_cast<__half2_raw>(x).data.y}};
551
- }
552
-
553
- inline
554
- __HOST_DEVICE__
555
- __half2 __half2half2(__half x)
556
- {
557
- return __half2{x, x};
558
- }
559
-
560
- inline
561
- __HOST_DEVICE__
562
- __half2 __halves2half2(__half x, __half y)
563
- {
564
- return __half2{x, y};
565
- }
566
-
567
- inline
568
- __HOST_DEVICE__
569
- __half2 __low2half2(__half2 x)
570
- {
571
- return __half2{
572
- _Float16_2{
573
- static_cast<__half2_raw>(x).data.x,
574
- static_cast<__half2_raw>(x).data.x}};
575
- }
576
-
577
- inline
578
- __HOST_DEVICE__
579
- __half2 __high2half2(__half2 x)
580
- {
581
- return __half2{
582
- _Float16_2{
583
- static_cast<__half2_raw>(x).data.y,
584
- static_cast<__half2_raw>(x).data.y}};
585
- }
586
-
587
- inline
588
- __HOST_DEVICE__
589
- __half2 __lows2half2(__half2 x, __half2 y)
590
- {
591
- return __half2{
592
- _Float16_2{
593
- static_cast<__half2_raw>(x).data.x,
594
- static_cast<__half2_raw>(y).data.x}};
595
- }
596
-
597
- inline
598
- __HOST_DEVICE__
599
- __half2 __highs2half2(__half2 x, __half2 y)
600
- {
601
- return __half2{
602
- _Float16_2{
603
- static_cast<__half2_raw>(x).data.y,
604
- static_cast<__half2_raw>(y).data.y}};
605
- }
606
-
607
- inline
608
- __HOST_DEVICE__
609
- __half2 __lowhigh2highlow(__half2 x)
610
- {
611
- return __half2{
612
- _Float16_2{
613
- static_cast<__half2_raw>(x).data.y,
614
- static_cast<__half2_raw>(x).data.x}};
615
- }
616
-
617
- // Bitcasts
618
- inline
619
- __device__
620
- short __half_as_short(__half x)
621
- {
622
- return static_cast<__half_raw>(x).x;
623
- }
624
-
625
- inline
626
- __device__
627
- unsigned short __half_as_ushort(__half x)
628
- {
629
- return static_cast<__half_raw>(x).x;
630
- }
631
-
632
- inline
633
- __device__
634
- __half __short_as_half(short x)
635
- {
636
- __half_raw r; r.x = x;
637
- return r;
638
- }
639
-
640
- inline
641
- __device__
642
- __half __ushort_as_half(unsigned short x)
643
- {
644
- __half_raw r; r.x = x;
645
- return r;
646
- }
647
-
648
- // float -> half | half2
649
- inline
650
- __HOST_DEVICE__
651
- __half __float2half(float x)
652
- {
653
- return __half_raw{static_cast<_Float16>(x)};
654
- }
655
- inline
656
- __HOST_DEVICE__
657
- __half __float2half_rn(float x)
658
- {
659
- return __half_raw{static_cast<_Float16>(x)};
660
- }
661
- #if !defined(__HIPCC_RTC__)
662
- // TODO: rounding behaviour is not correct for host functions.
663
- inline
664
- __host__
665
- __half __float2half_rz(float x)
666
- {
667
- return __half_raw{static_cast<_Float16>(x)};
668
- }
669
- inline
670
- __host__
671
- __half __float2half_rd(float x)
672
- {
673
- return __half_raw{static_cast<_Float16>(x)};
674
- }
675
- inline
676
- __host__
677
- __half __float2half_ru(float x)
678
- {
679
- return __half_raw{static_cast<_Float16>(x)};
680
- }
681
- #endif
682
- inline
683
- __device__
684
- __half __float2half_rz(float x)
685
- {
686
- return __half_raw{__ocml_cvtrtz_f16_f32(x)};
687
- }
688
- inline
689
- __device__
690
- __half __float2half_rd(float x)
691
- {
692
- return __half_raw{__ocml_cvtrtn_f16_f32(x)};
693
- }
694
- inline
695
- __device__
696
- __half __float2half_ru(float x)
697
- {
698
- return __half_raw{__ocml_cvtrtp_f16_f32(x)};
699
- }
700
- inline
701
- __HOST_DEVICE__
702
- __half2 __float2half2_rn(float x)
703
- {
704
- return __half2{
705
- _Float16_2{
706
- static_cast<_Float16>(x), static_cast<_Float16>(x)}};
707
- }
708
- inline
709
- __HOST_DEVICE__
710
- __half2 __floats2half2_rn(float x, float y)
711
- {
712
- return __half2{_Float16_2{
713
- static_cast<_Float16>(x), static_cast<_Float16>(y)}};
714
- }
715
- inline
716
- __HOST_DEVICE__
717
- __half2 __float22half2_rn(float2 x)
718
- {
719
- return __floats2half2_rn(x.x, x.y);
720
- }
721
-
722
- // half | half2 -> float
723
- inline
724
- __HOST_DEVICE__
725
- float __half2float(__half x)
726
- {
727
- return static_cast<__half_raw>(x).data;
728
- }
729
- inline
730
- __HOST_DEVICE__
731
- float __low2float(__half2 x)
732
- {
733
- return static_cast<__half2_raw>(x).data.x;
734
- }
735
- inline
736
- __HOST_DEVICE__
737
- float __high2float(__half2 x)
738
- {
739
- return static_cast<__half2_raw>(x).data.y;
740
- }
741
- inline
742
- __HOST_DEVICE__
743
- float2 __half22float2(__half2 x)
744
- {
745
- return make_float2(
746
- static_cast<__half2_raw>(x).data.x,
747
- static_cast<__half2_raw>(x).data.y);
748
- }
749
-
750
- // half -> int
751
- inline
752
- __device__
753
- int __half2int_rn(__half x)
754
- {
755
- return static_cast<__half_raw>(x).data;
756
- }
757
- inline
758
- __device__
759
- int __half2int_rz(__half x)
760
- {
761
- return static_cast<__half_raw>(x).data;
762
- }
763
- inline
764
- __device__
765
- int __half2int_rd(__half x)
766
- {
767
- return static_cast<__half_raw>(x).data;
768
- }
769
- inline
770
- __device__
771
- int __half2int_ru(__half x)
772
- {
773
- return static_cast<__half_raw>(x).data;
774
- }
775
-
776
- // int -> half
777
- inline
778
- __device__
779
- __half __int2half_rn(int x)
780
- {
781
- return __half_raw{static_cast<_Float16>(x)};
782
- }
783
- inline
784
- __device__
785
- __half __int2half_rz(int x)
786
- {
787
- return __half_raw{static_cast<_Float16>(x)};
788
- }
789
- inline
790
- __device__
791
- __half __int2half_rd(int x)
792
- {
793
- return __half_raw{static_cast<_Float16>(x)};
794
- }
795
- inline
796
- __device__
797
- __half __int2half_ru(int x)
798
- {
799
- return __half_raw{static_cast<_Float16>(x)};
800
- }
801
-
802
- // half -> short
803
- inline
804
- __device__
805
- short __half2short_rn(__half x)
806
- {
807
- return static_cast<__half_raw>(x).data;
808
- }
809
- inline
810
- __device__
811
- short __half2short_rz(__half x)
812
- {
813
- return static_cast<__half_raw>(x).data;
814
- }
815
- inline
816
- __device__
817
- short __half2short_rd(__half x)
818
- {
819
- return static_cast<__half_raw>(x).data;
820
- }
821
- inline
822
- __device__
823
- short __half2short_ru(__half x)
824
- {
825
- return static_cast<__half_raw>(x).data;
826
- }
827
-
828
- // short -> half
829
- inline
830
- __device__
831
- __half __short2half_rn(short x)
832
- {
833
- return __half_raw{static_cast<_Float16>(x)};
834
- }
835
- inline
836
- __device__
837
- __half __short2half_rz(short x)
838
- {
839
- return __half_raw{static_cast<_Float16>(x)};
840
- }
841
- inline
842
- __device__
843
- __half __short2half_rd(short x)
844
- {
845
- return __half_raw{static_cast<_Float16>(x)};
846
- }
847
- inline
848
- __device__
849
- __half __short2half_ru(short x)
850
- {
851
- return __half_raw{static_cast<_Float16>(x)};
852
- }
853
-
854
- // half -> long long
855
- inline
856
- __device__
857
- long long __half2ll_rn(__half x)
858
- {
859
- return static_cast<__half_raw>(x).data;
860
- }
861
- inline
862
- __device__
863
- long long __half2ll_rz(__half x)
864
- {
865
- return static_cast<__half_raw>(x).data;
866
- }
867
- inline
868
- __device__
869
- long long __half2ll_rd(__half x)
870
- {
871
- return static_cast<__half_raw>(x).data;
872
- }
873
- inline
874
- __device__
875
- long long __half2ll_ru(__half x)
876
- {
877
- return static_cast<__half_raw>(x).data;
878
- }
879
-
880
- // long long -> half
881
- inline
882
- __device__
883
- __half __ll2half_rn(long long x)
884
- {
885
- return __half_raw{static_cast<_Float16>(x)};
886
- }
887
- inline
888
- __device__
889
- __half __ll2half_rz(long long x)
890
- {
891
- return __half_raw{static_cast<_Float16>(x)};
892
- }
893
- inline
894
- __device__
895
- __half __ll2half_rd(long long x)
896
- {
897
- return __half_raw{static_cast<_Float16>(x)};
898
- }
899
- inline
900
- __device__
901
- __half __ll2half_ru(long long x)
902
- {
903
- return __half_raw{static_cast<_Float16>(x)};
904
- }
905
-
906
- // half -> unsigned int
907
- inline
908
- __device__
909
- unsigned int __half2uint_rn(__half x)
910
- {
911
- return static_cast<__half_raw>(x).data;
912
- }
913
- inline
914
- __device__
915
- unsigned int __half2uint_rz(__half x)
916
- {
917
- return static_cast<__half_raw>(x).data;
918
- }
919
- inline
920
- __device__
921
- unsigned int __half2uint_rd(__half x)
922
- {
923
- return static_cast<__half_raw>(x).data;
924
- }
925
- inline
926
- __device__
927
- unsigned int __half2uint_ru(__half x)
928
- {
929
- return static_cast<__half_raw>(x).data;
930
- }
931
-
932
- // unsigned int -> half
933
- inline
934
- __device__
935
- __half __uint2half_rn(unsigned int x)
936
- {
937
- return __half_raw{static_cast<_Float16>(x)};
938
- }
939
- inline
940
- __device__
941
- __half __uint2half_rz(unsigned int x)
942
- {
943
- return __half_raw{static_cast<_Float16>(x)};
944
- }
945
- inline
946
- __device__
947
- __half __uint2half_rd(unsigned int x)
948
- {
949
- return __half_raw{static_cast<_Float16>(x)};
950
- }
951
- inline
952
- __device__
953
- __half __uint2half_ru(unsigned int x)
954
- {
955
- return __half_raw{static_cast<_Float16>(x)};
956
- }
957
-
958
- // half -> unsigned short
959
- inline
960
- __device__
961
- unsigned short __half2ushort_rn(__half x)
962
- {
963
- return static_cast<__half_raw>(x).data;
964
- }
965
- inline
966
- __device__
967
- unsigned short __half2ushort_rz(__half x)
968
- {
969
- return static_cast<__half_raw>(x).data;
970
- }
971
- inline
972
- __device__
973
- unsigned short __half2ushort_rd(__half x)
974
- {
975
- return static_cast<__half_raw>(x).data;
976
- }
977
- inline
978
- __device__
979
- unsigned short __half2ushort_ru(__half x)
980
- {
981
- return static_cast<__half_raw>(x).data;
982
- }
983
-
984
- // unsigned short -> half
985
- inline
986
- __device__
987
- __half __ushort2half_rn(unsigned short x)
988
- {
989
- return __half_raw{static_cast<_Float16>(x)};
990
- }
991
- inline
992
- __device__
993
- __half __ushort2half_rz(unsigned short x)
994
- {
995
- return __half_raw{static_cast<_Float16>(x)};
996
- }
997
- inline
998
- __device__
999
- __half __ushort2half_rd(unsigned short x)
1000
- {
1001
- return __half_raw{static_cast<_Float16>(x)};
1002
- }
1003
- inline
1004
- __device__
1005
- __half __ushort2half_ru(unsigned short x)
1006
- {
1007
- return __half_raw{static_cast<_Float16>(x)};
1008
- }
1009
-
1010
- // half -> unsigned long long
1011
- inline
1012
- __device__
1013
- unsigned long long __half2ull_rn(__half x)
1014
- {
1015
- return static_cast<__half_raw>(x).data;
1016
- }
1017
- inline
1018
- __device__
1019
- unsigned long long __half2ull_rz(__half x)
1020
- {
1021
- return static_cast<__half_raw>(x).data;
1022
- }
1023
- inline
1024
- __device__
1025
- unsigned long long __half2ull_rd(__half x)
1026
- {
1027
- return static_cast<__half_raw>(x).data;
1028
- }
1029
- inline
1030
- __device__
1031
- unsigned long long __half2ull_ru(__half x)
1032
- {
1033
- return static_cast<__half_raw>(x).data;
1034
- }
1035
-
1036
- // unsigned long long -> half
1037
- inline
1038
- __device__
1039
- __half __ull2half_rn(unsigned long long x)
1040
- {
1041
- return __half_raw{static_cast<_Float16>(x)};
1042
- }
1043
- inline
1044
- __device__
1045
- __half __ull2half_rz(unsigned long long x)
1046
- {
1047
- return __half_raw{static_cast<_Float16>(x)};
1048
- }
1049
- inline
1050
- __device__
1051
- __half __ull2half_rd(unsigned long long x)
1052
- {
1053
- return __half_raw{static_cast<_Float16>(x)};
1054
- }
1055
- inline
1056
- __device__
1057
- __half __ull2half_ru(unsigned long long x)
1058
- {
1059
- return __half_raw{static_cast<_Float16>(x)};
1060
- }
1061
-
1062
- // Load primitives
1063
- inline
1064
- __device__
1065
- __half __ldg(const __half* ptr) { return *ptr; }
1066
- inline
1067
- __device__
1068
- __half __ldcg(const __half* ptr) { return *ptr; }
1069
- inline
1070
- __device__
1071
- __half __ldca(const __half* ptr) { return *ptr; }
1072
- inline
1073
- __device__
1074
- __half __ldcs(const __half* ptr) { return *ptr; }
1075
-
1076
- inline
1077
- __HOST_DEVICE__
1078
- __half2 __ldg(const __half2* ptr) { return *ptr; }
1079
- inline
1080
- __HOST_DEVICE__
1081
- __half2 __ldcg(const __half2* ptr) { return *ptr; }
1082
- inline
1083
- __HOST_DEVICE__
1084
- __half2 __ldca(const __half2* ptr) { return *ptr; }
1085
- inline
1086
- __HOST_DEVICE__
1087
- __half2 __ldcs(const __half2* ptr) { return *ptr; }
1088
-
1089
- // Relations
1090
- inline
1091
- __device__
1092
- bool __heq(__half x, __half y)
1093
- {
1094
- return static_cast<__half_raw>(x).data ==
1095
- static_cast<__half_raw>(y).data;
1096
- }
1097
- inline
1098
- __device__
1099
- bool __hne(__half x, __half y)
1100
- {
1101
- return static_cast<__half_raw>(x).data !=
1102
- static_cast<__half_raw>(y).data;
1103
- }
1104
- inline
1105
- __device__
1106
- bool __hle(__half x, __half y)
1107
- {
1108
- return static_cast<__half_raw>(x).data <=
1109
- static_cast<__half_raw>(y).data;
1110
- }
1111
- inline
1112
- __device__
1113
- bool __hge(__half x, __half y)
1114
- {
1115
- return static_cast<__half_raw>(x).data >=
1116
- static_cast<__half_raw>(y).data;
1117
- }
1118
- inline
1119
- __device__
1120
- bool __hlt(__half x, __half y)
1121
- {
1122
- return static_cast<__half_raw>(x).data <
1123
- static_cast<__half_raw>(y).data;
1124
- }
1125
- inline
1126
- __device__
1127
- bool __hgt(__half x, __half y)
1128
- {
1129
- return static_cast<__half_raw>(x).data >
1130
- static_cast<__half_raw>(y).data;
1131
- }
1132
- inline __device__
1133
- bool __hequ(__half x, __half y) {
1134
- return !(static_cast<__half_raw>(x).data < static_cast<__half_raw>(y).data) &&
1135
- !(static_cast<__half_raw>(x).data > static_cast<__half_raw>(y).data);
1136
- }
1137
- inline __device__
1138
- bool __hneu(__half x, __half y) {
1139
- return !(static_cast<__half_raw>(x).data == static_cast<__half_raw>(y).data);
1140
- }
1141
- inline __device__
1142
- bool __hleu(__half x, __half y) {
1143
- return !(static_cast<__half_raw>(x).data > static_cast<__half_raw>(y).data);
1144
- }
1145
- inline
1146
- __device__
1147
- bool __hgeu(__half x, __half y) {
1148
- return !(static_cast<__half_raw>(x).data < static_cast<__half_raw>(y).data);
1149
- }
1150
- inline
1151
- __device__
1152
- bool __hltu(__half x, __half y) {
1153
- return !(static_cast<__half_raw>(x).data >= static_cast<__half_raw>(y).data);
1154
- }
1155
- inline
1156
- __device__
1157
- bool __hgtu(__half x, __half y) {
1158
- return !(static_cast<__half_raw>(x).data <= static_cast<__half_raw>(y).data);
1159
- }
1160
-
1161
- inline
1162
- __HOST_DEVICE__
1163
- __half2 __heq2(__half2 x, __half2 y)
1164
- {
1165
- auto r = static_cast<__half2_raw>(x).data ==
1166
- static_cast<__half2_raw>(y).data;
1167
- return __builtin_convertvector(-r, _Float16_2);
1168
- }
1169
- inline
1170
- __HOST_DEVICE__
1171
- __half2 __hne2(__half2 x, __half2 y)
1172
- {
1173
- auto r = static_cast<__half2_raw>(x).data !=
1174
- static_cast<__half2_raw>(y).data;
1175
- return __builtin_convertvector(-r, _Float16_2);
1176
- }
1177
- inline
1178
- __HOST_DEVICE__
1179
- __half2 __hle2(__half2 x, __half2 y)
1180
- {
1181
- auto r = static_cast<__half2_raw>(x).data <=
1182
- static_cast<__half2_raw>(y).data;
1183
- return __builtin_convertvector(-r, _Float16_2);
1184
- }
1185
- inline
1186
- __HOST_DEVICE__
1187
- __half2 __hge2(__half2 x, __half2 y)
1188
- {
1189
- auto r = static_cast<__half2_raw>(x).data >=
1190
- static_cast<__half2_raw>(y).data;
1191
- return __builtin_convertvector(-r, _Float16_2);
1192
- }
1193
- inline
1194
- __HOST_DEVICE__
1195
- __half2 __hlt2(__half2 x, __half2 y)
1196
- {
1197
- auto r = static_cast<__half2_raw>(x).data <
1198
- static_cast<__half2_raw>(y).data;
1199
- return __builtin_convertvector(-r, _Float16_2);
1200
- }
1201
- inline
1202
- __HOST_DEVICE__
1203
- __half2 __hgt2(__half2 x, __half2 y)
1204
- {
1205
- auto r = static_cast<__half2_raw>(x).data >
1206
- static_cast<__half2_raw>(y).data;
1207
- return __builtin_convertvector(-r, _Float16_2);
1208
- }
1209
- inline __HOST_DEVICE__
1210
- __half2 __hequ2(__half2 x, __half2 y) {
1211
- auto r = !(static_cast<__half2_raw>(x).data < static_cast<__half2_raw>(y).data) &&
1212
- !(static_cast<__half2_raw>(x).data > static_cast<__half2_raw>(y).data);
1213
- return __builtin_convertvector(-r, _Float16_2);
1214
- }
1215
- inline
1216
- __HOST_DEVICE__
1217
- __half2 __hneu2(__half2 x, __half2 y) {
1218
- auto r = !(static_cast<__half2_raw>(x).data == static_cast<__half2_raw>(y).data);
1219
- return __builtin_convertvector(-r, _Float16_2);
1220
- }
1221
- inline
1222
- __HOST_DEVICE__
1223
- __half2 __hleu2(__half2 x, __half2 y) {
1224
- auto r = !(static_cast<__half2_raw>(x).data > static_cast<__half2_raw>(y).data);
1225
- return __builtin_convertvector(-r, _Float16_2);
1226
- }
1227
- inline
1228
- __HOST_DEVICE__
1229
- __half2 __hgeu2(__half2 x, __half2 y) {
1230
- auto r = !(static_cast<__half2_raw>(x).data < static_cast<__half2_raw>(y).data);
1231
- return __builtin_convertvector(-r, _Float16_2);
1232
- }
1233
- inline
1234
- __HOST_DEVICE__
1235
- __half2 __hltu2(__half2 x, __half2 y) {
1236
- auto r = !(static_cast<__half2_raw>(x).data >= static_cast<__half2_raw>(y).data);
1237
- return __builtin_convertvector(-r, _Float16_2);
1238
- }
1239
- inline
1240
- __HOST_DEVICE__
1241
- __half2 __hgtu2(__half2 x, __half2 y) {
1242
- auto r = !(static_cast<__half2_raw>(x).data <= static_cast<__half2_raw>(y).data);
1243
- return __builtin_convertvector(-r, _Float16_2);
1244
- }
1245
-
1246
- inline
1247
- __HOST_DEVICE__
1248
- bool __hbeq2(__half2 x, __half2 y)
1249
- {
1250
- auto r = static_cast<__half2_raw>(__heq2(x, y));
1251
- return r.data.x != 0 && r.data.y != 0;
1252
- }
1253
- inline
1254
- __HOST_DEVICE__
1255
- bool __hbne2(__half2 x, __half2 y)
1256
- {
1257
- auto r = static_cast<__half2_raw>(__hne2(x, y));
1258
- return r.data.x != 0 && r.data.y != 0;
1259
- }
1260
- inline
1261
- __HOST_DEVICE__
1262
- bool __hble2(__half2 x, __half2 y)
1263
- {
1264
- auto r = static_cast<__half2_raw>(__hle2(x, y));
1265
- return r.data.x != 0 && r.data.y != 0;
1266
- }
1267
- inline
1268
- __HOST_DEVICE__
1269
- bool __hbge2(__half2 x, __half2 y)
1270
- {
1271
- auto r = static_cast<__half2_raw>(__hge2(x, y));
1272
- return r.data.x != 0 && r.data.y != 0;
1273
- }
1274
- inline
1275
- __HOST_DEVICE__
1276
- bool __hblt2(__half2 x, __half2 y)
1277
- {
1278
- auto r = static_cast<__half2_raw>(__hlt2(x, y));
1279
- return r.data.x != 0 && r.data.y != 0;
1280
- }
1281
- inline
1282
- __HOST_DEVICE__
1283
- bool __hbgt2(__half2 x, __half2 y)
1284
- {
1285
- auto r = static_cast<__half2_raw>(__hgt2(x, y));
1286
- return r.data.x != 0 && r.data.y != 0;
1287
- }
1288
- inline
1289
- __HOST_DEVICE__
1290
- bool __hbequ2(__half2 x, __half2 y) { return __hbeq2(x, y); }
1291
- inline
1292
- __HOST_DEVICE__
1293
- bool __hbneu2(__half2 x, __half2 y) { return __hbne2(x, y); }
1294
- inline
1295
- __HOST_DEVICE__
1296
- bool __hbleu2(__half2 x, __half2 y) { return __hble2(x, y); }
1297
- inline
1298
- __HOST_DEVICE__
1299
- bool __hbgeu2(__half2 x, __half2 y) { return __hbge2(x, y); }
1300
- inline
1301
- __HOST_DEVICE__
1302
- bool __hbltu2(__half2 x, __half2 y) { return __hblt2(x, y); }
1303
- inline
1304
- __HOST_DEVICE__
1305
- bool __hbgtu2(__half2 x, __half2 y) { return __hbgt2(x, y); }
1306
- inline
1307
- __device__
1308
- __half __hmax(const __half x, const __half y) {
1309
- return __half_raw{__ocml_fmax_f16(static_cast<__half_raw>(x).data,
1310
- static_cast<__half_raw>(y).data)};
1311
- }
1312
- inline
1313
- __device__
1314
- __half __hmax_nan(const __half x, const __half y) {
1315
- if(__ocml_isnan_f16(static_cast<__half_raw>(x).data)) {
1316
- return x;
1317
- } else if (__ocml_isnan_f16(static_cast<__half_raw>(y).data)) {
1318
- return y;
1319
- }
1320
- return __hmax(x, y);
1321
- }
1322
- inline
1323
- __device__
1324
- __half __hmin(const __half x, const __half y) {
1325
- return __half_raw{__ocml_fmin_f16(static_cast<__half_raw>(x).data,
1326
- static_cast<__half_raw>(y).data)};
1327
- }
1328
- inline
1329
- __device__
1330
- __half __hmin_nan(const __half x, const __half y) {
1331
- if(__ocml_isnan_f16(static_cast<__half_raw>(x).data)) {
1332
- return x;
1333
- } else if (__ocml_isnan_f16(static_cast<__half_raw>(y).data)) {
1334
- return y;
1335
- }
1336
- return __hmin(x, y);
1337
- }
1338
-
1339
- // Arithmetic
1340
- inline
1341
- __device__
1342
- __half __clamp_01(__half x)
1343
- {
1344
- auto r = static_cast<__half_raw>(x);
1345
-
1346
- if (__hlt(x, __half_raw{0})) return __half_raw{0};
1347
- if (__hlt(__half_raw{1}, x)) return __half_raw{1};
1348
- return r;
1349
- }
1350
-
1351
- inline
1352
- __device__
1353
- __half __hadd(__half x, __half y)
1354
- {
1355
- return __half_raw{
1356
- static_cast<__half_raw>(x).data +
1357
- static_cast<__half_raw>(y).data};
1358
- }
1359
- inline
1360
- __device__
1361
- __half __habs(__half x)
1362
- {
1363
- return __half_raw{
1364
- __ocml_fabs_f16(static_cast<__half_raw>(x).data)};
1365
- }
1366
- inline
1367
- __device__
1368
- __half __hsub(__half x, __half y)
1369
- {
1370
- return __half_raw{
1371
- static_cast<__half_raw>(x).data -
1372
- static_cast<__half_raw>(y).data};
1373
- }
1374
- inline
1375
- __device__
1376
- __half __hmul(__half x, __half y)
1377
- {
1378
- return __half_raw{
1379
- static_cast<__half_raw>(x).data *
1380
- static_cast<__half_raw>(y).data};
1381
- }
1382
- inline
1383
- __device__
1384
- __half __hadd_sat(__half x, __half y)
1385
- {
1386
- return __clamp_01(__hadd(x, y));
1387
- }
1388
- inline
1389
- __device__
1390
- __half __hsub_sat(__half x, __half y)
1391
- {
1392
- return __clamp_01(__hsub(x, y));
1393
- }
1394
- inline
1395
- __device__
1396
- __half __hmul_sat(__half x, __half y)
1397
- {
1398
- return __clamp_01(__hmul(x, y));
1399
- }
1400
- inline
1401
- __device__
1402
- __half __hfma(__half x, __half y, __half z)
1403
- {
1404
- return __half_raw{__ocml_fma_f16(
1405
- static_cast<__half_raw>(x).data,
1406
- static_cast<__half_raw>(y).data,
1407
- static_cast<__half_raw>(z).data)};
1408
- }
1409
- inline
1410
- __device__
1411
- __half __hfma_sat(__half x, __half y, __half z)
1412
- {
1413
- return __clamp_01(__hfma(x, y, z));
1414
- }
1415
- inline
1416
- __device__
1417
- __half __hdiv(__half x, __half y)
1418
- {
1419
- return __half_raw{
1420
- static_cast<__half_raw>(x).data /
1421
- static_cast<__half_raw>(y).data};
1422
- }
1423
-
1424
- inline
1425
- __HOST_DEVICE__
1426
- __half2 __hadd2(__half2 x, __half2 y)
1427
- {
1428
- return __half2{
1429
- static_cast<__half2_raw>(x).data +
1430
- static_cast<__half2_raw>(y).data};
1431
- }
1432
- inline
1433
- __HOST_DEVICE__
1434
- __half2 __habs2(__half2 x)
1435
- {
1436
- return __half2{
1437
- __ocml_fabs_2f16(static_cast<__half2_raw>(x).data)};
1438
- }
1439
- inline
1440
- __HOST_DEVICE__
1441
- __half2 __hsub2(__half2 x, __half2 y)
1442
- {
1443
- return __half2{
1444
- static_cast<__half2_raw>(x).data -
1445
- static_cast<__half2_raw>(y).data};
1446
- }
1447
- inline
1448
- __HOST_DEVICE__
1449
- __half2 __hmul2(__half2 x, __half2 y)
1450
- {
1451
- return __half2{
1452
- static_cast<__half2_raw>(x).data *
1453
- static_cast<__half2_raw>(y).data};
1454
- }
1455
- inline
1456
- __HOST_DEVICE__
1457
- __half2 __hadd2_sat(__half2 x, __half2 y)
1458
- {
1459
- auto r = static_cast<__half2_raw>(__hadd2(x, y));
1460
- return __half2{
1461
- __clamp_01(__half_raw{r.data.x}),
1462
- __clamp_01(__half_raw{r.data.y})};
1463
- }
1464
- inline
1465
- __HOST_DEVICE__
1466
- __half2 __hsub2_sat(__half2 x, __half2 y)
1467
- {
1468
- auto r = static_cast<__half2_raw>(__hsub2(x, y));
1469
- return __half2{
1470
- __clamp_01(__half_raw{r.data.x}),
1471
- __clamp_01(__half_raw{r.data.y})};
1472
- }
1473
- inline
1474
- __HOST_DEVICE__
1475
- __half2 __hmul2_sat(__half2 x, __half2 y)
1476
- {
1477
- auto r = static_cast<__half2_raw>(__hmul2(x, y));
1478
- return __half2{
1479
- __clamp_01(__half_raw{r.data.x}),
1480
- __clamp_01(__half_raw{r.data.y})};
1481
- }
1482
- inline
1483
- __HOST_DEVICE__
1484
- __half2 __hfma2(__half2 x, __half2 y, __half2 z)
1485
- {
1486
- return __half2{__ocml_fma_2f16(x, y, z)};
1487
- }
1488
- inline
1489
- __HOST_DEVICE__
1490
- __half2 __hfma2_sat(__half2 x, __half2 y, __half2 z)
1491
- {
1492
- auto r = static_cast<__half2_raw>(__hfma2(x, y, z));
1493
- return __half2{
1494
- __clamp_01(__half_raw{r.data.x}),
1495
- __clamp_01(__half_raw{r.data.y})};
1496
- }
1497
- inline
1498
- __HOST_DEVICE__
1499
- __half2 __h2div(__half2 x, __half2 y)
1500
- {
1501
- return __half2{
1502
- static_cast<__half2_raw>(x).data /
1503
- static_cast<__half2_raw>(y).data};
1504
- }
1505
-
1506
- // Math functions
1507
- #if defined(__clang__) && defined(__HIP__)
1508
- inline
1509
- __device__
1510
- float amd_mixed_dot(__half2 a, __half2 b, float c, bool saturate) {
1511
- return __ockl_fdot2(static_cast<__half2_raw>(a).data,
1512
- static_cast<__half2_raw>(b).data,
1513
- c, saturate);
1514
- }
1515
- #endif
1516
- inline
1517
- __device__
1518
- __half htrunc(__half x)
1519
- {
1520
- return __half_raw{
1521
- __ocml_trunc_f16(static_cast<__half_raw>(x).data)};
1522
- }
1523
- inline
1524
- __device__
1525
- __half hceil(__half x)
1526
- {
1527
- return __half_raw{
1528
- __ocml_ceil_f16(static_cast<__half_raw>(x).data)};
1529
- }
1530
- inline
1531
- __device__
1532
- __half hfloor(__half x)
1533
- {
1534
- return __half_raw{
1535
- __ocml_floor_f16(static_cast<__half_raw>(x).data)};
1536
- }
1537
- inline
1538
- __device__
1539
- __half hrint(__half x)
1540
- {
1541
- return __half_raw{
1542
- __ocml_rint_f16(static_cast<__half_raw>(x).data)};
1543
- }
1544
- inline
1545
- __device__
1546
- __half hsin(__half x)
1547
- {
1548
- return __half_raw{
1549
- __ocml_sin_f16(static_cast<__half_raw>(x).data)};
1550
- }
1551
- inline
1552
- __device__
1553
- __half hcos(__half x)
1554
- {
1555
- return __half_raw{
1556
- __ocml_cos_f16(static_cast<__half_raw>(x).data)};
1557
- }
1558
- inline
1559
- __device__
1560
- __half hexp(__half x)
1561
- {
1562
- return __half_raw{
1563
- __ocml_exp_f16(static_cast<__half_raw>(x).data)};
1564
- }
1565
- inline
1566
- __device__
1567
- __half hexp2(__half x)
1568
- {
1569
- return __half_raw{
1570
- __ocml_exp2_f16(static_cast<__half_raw>(x).data)};
1571
- }
1572
- inline
1573
- __device__
1574
- __half hexp10(__half x)
1575
- {
1576
- return __half_raw{
1577
- __ocml_exp10_f16(static_cast<__half_raw>(x).data)};
1578
- }
1579
- inline
1580
- __device__
1581
- __half hlog2(__half x)
1582
- {
1583
- return __half_raw{
1584
- __ocml_log2_f16(static_cast<__half_raw>(x).data)};
1585
- }
1586
- inline
1587
- __device__
1588
- __half hlog(__half x)
1589
- {
1590
- return __half_raw{
1591
- __ocml_log_f16(static_cast<__half_raw>(x).data)};
1592
- }
1593
- inline
1594
- __device__
1595
- __half hlog10(__half x)
1596
- {
1597
- return __half_raw{
1598
- __ocml_log10_f16(static_cast<__half_raw>(x).data)};
1599
- }
1600
- inline
1601
- __device__
1602
- __half hrcp(__half x)
1603
- {
1604
- return __half_raw{
1605
- static_cast<_Float16>(1.0f) /static_cast<__half_raw>(x).data};
1606
- }
1607
- inline
1608
- __device__
1609
- __half hrsqrt(__half x)
1610
- {
1611
- return __half_raw{
1612
- __ocml_rsqrt_f16(static_cast<__half_raw>(x).data)};
1613
- }
1614
- inline
1615
- __device__
1616
- __half hsqrt(__half x)
1617
- {
1618
- return __half_raw{
1619
- __ocml_sqrt_f16(static_cast<__half_raw>(x).data)};
1620
- }
1621
- inline
1622
- __device__
1623
- bool __hisinf(__half x)
1624
- {
1625
- return __ocml_isinf_f16(static_cast<__half_raw>(x).data);
1626
- }
1627
- inline
1628
- __device__
1629
- bool __hisnan(__half x)
1630
- {
1631
- return __ocml_isnan_f16(static_cast<__half_raw>(x).data);
1632
- }
1633
- inline
1634
- __device__
1635
- __half __hneg(__half x)
1636
- {
1637
- return __half_raw{-static_cast<__half_raw>(x).data};
1638
- }
1639
-
1640
- inline
1641
- __HOST_DEVICE__
1642
- __half2 h2trunc(__half2 x)
1643
- {
1644
- return __half2{__ocml_trunc_2f16(x)};
1645
- }
1646
- inline
1647
- __HOST_DEVICE__
1648
- __half2 h2ceil(__half2 x)
1649
- {
1650
- return __half2{__ocml_ceil_2f16(x)};
1651
- }
1652
- inline
1653
- __HOST_DEVICE__
1654
- __half2 h2floor(__half2 x)
1655
- {
1656
- return __half2{__ocml_floor_2f16(x)};
1657
- }
1658
- inline
1659
- __HOST_DEVICE__
1660
- __half2 h2rint(__half2 x)
1661
- {
1662
- return __half2{__ocml_rint_2f16(x)};
1663
- }
1664
- inline
1665
- __HOST_DEVICE__
1666
- __half2 h2sin(__half2 x)
1667
- {
1668
- return __half2{__ocml_sin_2f16(x)};
1669
- }
1670
- inline
1671
- __HOST_DEVICE__
1672
- __half2 h2cos(__half2 x)
1673
- {
1674
- return __half2{__ocml_cos_2f16(x)};
1675
- }
1676
- inline
1677
- __HOST_DEVICE__
1678
- __half2 h2exp(__half2 x)
1679
- {
1680
- return __half2{__ocml_exp_2f16(x)};
1681
- }
1682
- inline
1683
- __HOST_DEVICE__
1684
- __half2 h2exp2(__half2 x)
1685
- {
1686
- return __half2{__ocml_exp2_2f16(x)};
1687
- }
1688
- inline
1689
- __HOST_DEVICE__
1690
- __half2 h2exp10(__half2 x)
1691
- {
1692
- return __half2{__ocml_exp10_2f16(x)};
1693
- }
1694
- inline
1695
- __HOST_DEVICE__
1696
- __half2 h2log2(__half2 x)
1697
- {
1698
- return __half2{__ocml_log2_2f16(x)};
1699
- }
1700
- inline
1701
- __HOST_DEVICE__
1702
- __half2 h2log(__half2 x) { return __ocml_log_2f16(x); }
1703
- inline
1704
- __HOST_DEVICE__
1705
- __half2 h2log10(__half2 x) { return __ocml_log10_2f16(x); }
1706
- inline
1707
- __HOST_DEVICE__
1708
- __half2 h2rcp(__half2 x) {
1709
- return _Float16_2{
1710
- _Float16_2{static_cast<_Float16>(1.0f), static_cast<_Float16>(1.0f)} / x.data};
1711
- }
1712
- inline
1713
- __HOST_DEVICE__
1714
- __half2 h2rsqrt(__half2 x) { return __ocml_rsqrt_2f16(x); }
1715
- inline
1716
- __HOST_DEVICE__
1717
- __half2 h2sqrt(__half2 x) { return __ocml_sqrt_2f16(x); }
1718
- inline
1719
- __HOST_DEVICE__
1720
- __half2 __hisinf2(__half2 x)
1721
- {
1722
- auto r = __ocml_isinf_2f16(x);
1723
- return __half2{_Float16_2{
1724
- static_cast<_Float16>(r.x), static_cast<_Float16>(r.y)}};
1725
- }
1726
- inline
1727
- __HOST_DEVICE__
1728
- __half2 __hisnan2(__half2 x)
1729
- {
1730
- auto r = __ocml_isnan_2f16(x);
1731
- return __half2{_Float16_2{
1732
- static_cast<_Float16>(r.x), static_cast<_Float16>(r.y)}};
1733
- }
1734
- inline
1735
- __HOST_DEVICE__
1736
- __half2 __hneg2(__half2 x)
1737
- {
1738
- return __half2{-static_cast<__half2_raw>(x).data};
1739
- }
1740
- } // Anonymous namespace.
1741
-
1742
- #if !defined(HIP_NO_HALF)
1743
- using half = __half;
1744
- using half2 = __half2;
1745
- #endif
1746
- __device__
1747
- inline
1748
- __half __shfl(__half var, int src_lane, int width = warpSize) {
1749
- union { int i; __half h; } tmp; tmp.h = var;
1750
- tmp.i = __shfl(tmp.i, src_lane, width);
1751
- return tmp.h;
1752
- }
1753
- __device__
1754
- inline
1755
- __half2 __shfl(__half2 var, int src_lane, int width = warpSize) {
1756
- union { int i; __half2 h; } tmp; tmp.h = var;
1757
- tmp.i = __shfl(tmp.i, src_lane, width);
1758
- return tmp.h;
1759
- }
1760
- __device__
1761
- inline
1762
- __half __shfl_up(__half var, unsigned int lane_delta, int width = warpSize) {
1763
- union { int i; __half h; } tmp; tmp.h = var;
1764
- tmp.i = __shfl_up(tmp.i, lane_delta, width);
1765
- return tmp.h;
1766
- }
1767
- __device__
1768
- inline
1769
- __half2 __shfl_up(__half2 var, unsigned int lane_delta, int width = warpSize) {
1770
- union { int i; __half2 h; } tmp; tmp.h = var;
1771
- tmp.i = __shfl_up(tmp.i, lane_delta, width);
1772
- return tmp.h;
1773
- }
1774
- __device__
1775
- inline
1776
- __half __shfl_down(__half var, unsigned int lane_delta, int width = warpSize) {
1777
- union { int i; __half h; } tmp; tmp.h = var;
1778
- tmp.i = __shfl_down(tmp.i, lane_delta, width);
1779
- return tmp.h;
1780
- }
1781
- __device__
1782
- inline
1783
- __half2 __shfl_down(__half2 var, unsigned int lane_delta, int width = warpSize) {
1784
- union { int i; __half2 h; } tmp; tmp.h = var;
1785
- tmp.i = __shfl_down(tmp.i, lane_delta, width);
1786
- return tmp.h;
1787
- }
1788
- __device__
1789
- inline
1790
- __half __shfl_xor(__half var, int lane_mask, int width = warpSize) {
1791
- union { int i; __half h; } tmp; tmp.h = var;
1792
- tmp.i = __shfl_xor(tmp.i, lane_mask, width);
1793
- return tmp.h;
1794
- }
1795
- __device__
1796
- inline
1797
- __half2 __shfl_xor(__half2 var, int lane_mask, int width = warpSize) {
1798
- union { int i; __half2 h; } tmp; tmp.h = var;
1799
- tmp.i = __shfl_xor(tmp.i, lane_mask, width);
1800
- return tmp.h;
1801
- }
1802
- #endif // defined(__cplusplus)
1803
- #elif defined(__GNUC__)
1804
- #if !defined(__HIPCC_RTC__)
1805
- #include "hip_fp16_gcc.h"
1806
- #endif
1807
- #endif // !defined(__clang__) && defined(__GNUC__)
1808
-
1809
- #endif // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H