numba-cuda 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +246 -114
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
  13. numba_cuda/numba/cuda/cuda_paths.py +293 -99
  14. numba_cuda/numba/cuda/cudadecl.py +93 -79
  15. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  16. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  17. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  18. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  19. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  20. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  21. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  22. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  23. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  24. numba_cuda/numba/cuda/cudadrv/linkable_code.py +27 -3
  25. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  26. numba_cuda/numba/cuda/cudadrv/nvrtc.py +146 -30
  27. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  28. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  29. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  30. numba_cuda/numba/cuda/cudaimpl.py +296 -275
  31. numba_cuda/numba/cuda/cudamath.py +1 -1
  32. numba_cuda/numba/cuda/debuginfo.py +99 -7
  33. numba_cuda/numba/cuda/decorators.py +87 -45
  34. numba_cuda/numba/cuda/descriptor.py +1 -1
  35. numba_cuda/numba/cuda/device_init.py +68 -18
  36. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  37. numba_cuda/numba/cuda/dispatcher.py +300 -213
  38. numba_cuda/numba/cuda/errors.py +13 -10
  39. numba_cuda/numba/cuda/extending.py +55 -1
  40. numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
  41. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
  42. numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +1090 -927
  43. numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +468 -319
  44. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  45. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  46. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  47. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  48. numba_cuda/numba/cuda/initialize.py +5 -3
  49. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
  50. numba_cuda/numba/cuda/intrinsics.py +203 -28
  51. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  52. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  53. numba_cuda/numba/cuda/libdevice.py +317 -317
  54. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  55. numba_cuda/numba/cuda/locks.py +16 -0
  56. numba_cuda/numba/cuda/lowering.py +43 -0
  57. numba_cuda/numba/cuda/mathimpl.py +62 -57
  58. numba_cuda/numba/cuda/models.py +1 -5
  59. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  60. numba_cuda/numba/cuda/printimpl.py +9 -5
  61. numba_cuda/numba/cuda/random.py +46 -36
  62. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  63. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  64. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  65. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  66. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  67. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  68. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  69. numba_cuda/numba/cuda/simulator/api.py +38 -22
  70. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  71. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  72. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  73. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  74. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  75. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  76. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  77. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  78. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  79. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  80. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  81. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  82. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  83. numba_cuda/numba/cuda/simulator_init.py +2 -4
  84. numba_cuda/numba/cuda/stubs.py +134 -108
  85. numba_cuda/numba/cuda/target.py +92 -47
  86. numba_cuda/numba/cuda/testing.py +24 -19
  87. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  88. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  89. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  90. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  91. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  92. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  93. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  94. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  95. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  96. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  97. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  98. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  99. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  100. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  102. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  103. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  104. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  105. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  107. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  108. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  109. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  110. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  111. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  112. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  113. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  114. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  115. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  117. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  118. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  119. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +10 -7
  120. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  121. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  123. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  124. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  125. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  127. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +59 -23
  129. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  130. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  131. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  132. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  133. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  134. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  135. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  136. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  137. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  138. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  139. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  140. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  141. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  142. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  143. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +77 -28
  144. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  145. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  146. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +24 -7
  147. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  148. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  149. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +21 -12
  150. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  151. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  152. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  153. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  154. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  155. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  156. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  157. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  158. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  159. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +59 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  161. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  162. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  163. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  164. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  165. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +7 -7
  166. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  167. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  168. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  169. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  170. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  171. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  172. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  173. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  174. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  175. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  176. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  178. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  179. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  180. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  181. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  182. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  183. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  184. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  185. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  186. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  187. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  188. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  189. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  190. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  191. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  192. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  193. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  194. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  195. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  196. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  197. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  198. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  199. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  200. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  201. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  202. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  203. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  204. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  205. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +81 -30
  206. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  207. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  208. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  209. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  210. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  211. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  212. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  213. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  214. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  216. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  217. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  218. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  219. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  220. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  221. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  222. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  223. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  224. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  225. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  226. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  227. numba_cuda/numba/cuda/types.py +5 -2
  228. numba_cuda/numba/cuda/ufuncs.py +382 -362
  229. numba_cuda/numba/cuda/utils.py +2 -2
  230. numba_cuda/numba/cuda/vector_types.py +5 -3
  231. numba_cuda/numba/cuda/vectorizers.py +38 -33
  232. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/METADATA +1 -1
  233. numba_cuda-0.10.0.dist-info/RECORD +263 -0
  234. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/WHEEL +1 -1
  235. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  236. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/licenses/LICENSE +0 -0
  237. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/top_level.txt +0 -0
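The headline change in this release is bfloat16 support: Python-side bindings in numba_cuda/numba/cuda/cuda_bf16.py (item 12, +5155 lines), backed by vendored CUDA 11 and CUDA 12 copies of the cuda_bf16/cuda_fp16 headers (items 40-47) and tested by test_bfloat16_bindings.py (item 127). The hunk below shows one of those vendored headers. For orientation, here is a minimal CUDA C++ sketch of the conversion and arithmetic intrinsics that header defines; the kernel name is illustrative, not taken from the package:

    #include <cuda_bf16.h>   // defines __nv_bfloat16, __float2bfloat16, __hadd, ...

    // Adds 2.0 to each element, doing the arithmetic in bfloat16.
    // __hadd on __nv_bfloat16 needs sm_80 or newer, matching the
    // __CUDA_ARCH__ >= 800 guards visible in the header below.
    __global__ void add_two_bf16(const float *in, float *out, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            __nv_bfloat16 x   = __float2bfloat16(in[i]);   // round-to-nearest-even
            __nv_bfloat16 two = __float2bfloat16(2.0f);
            out[i] = __bfloat162float(__hadd(x, two));     // bf16 add, widen back
        }
    }

Compile with nvcc -arch=sm_80 (or newer); on older architectures only the conversion intrinsics shown below are available, not the __hadd arithmetic path.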
numba_cuda/numba/cuda/include/11/cuda_bf16.hpp (new file)
@@ -0,0 +1,2683 @@
+ /*
+ * Copyright 1993-2022 NVIDIA Corporation. All rights reserved.
+ *
+ * NOTICE TO LICENSEE:
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ *
+ * These Licensed Deliverables contained herein is PROPRIETARY and
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
+ * conditions of a form of NVIDIA software license agreement by and
+ * between NVIDIA and Licensee ("License Agreement") or electronically
+ * accepted by Licensee. Notwithstanding any terms or conditions to
+ * the contrary in the License Agreement, reproduction or disclosure
+ * of the Licensed Deliverables to any third party without the express
+ * written consent of NVIDIA is prohibited.
+ *
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
+ * OF THESE LICENSED DELIVERABLES.
+ *
+ * U.S. Government End Users. These Licensed Deliverables are a
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
+ * 1995), consisting of "commercial computer software" and "commercial
+ * computer software documentation" as such terms are used in 48
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
+ * U.S. Government End Users acquire the Licensed Deliverables with
+ * only those rights set forth herein.
+ *
+ * Any use of the Licensed Deliverables in individual and commercial
+ * software must include, in the user documentation and internal
+ * comments to the code, the above Disclaimer and U.S. Government End
+ * Users Notice.
+ */
+
+ #if !defined(__CUDA_BF16_HPP__)
+ #define __CUDA_BF16_HPP__
+
+ #if !defined(__CUDA_BF16_H__)
+ #error "Do not include this file directly. Instead, include cuda_bf16.h."
+ #endif
+
+ #if !defined(_MSC_VER) && __cplusplus >= 201103L
+ # define __CPP_VERSION_AT_LEAST_11_BF16
+ #elif _MSC_FULL_VER >= 190024210 && _MSVC_LANG >= 201103L
+ # define __CPP_VERSION_AT_LEAST_11_BF16
+ #endif
+
+ /* C++11 header for std::move.
+ * In RTC mode, std::move is provided implicitly; don't include the header
+ */
+ #if defined(__CPP_VERSION_AT_LEAST_11_BF16) && !defined(__CUDACC_RTC__)
+ #include <utility>
+ #endif /* defined(__CPP_VERSION_AT_LEAST_11_BF16) && !defined(__CUDACC_RTC__) */
+
+ /* C++ header for std::memcpy (used for type punning in host-side implementations).
+ * When compiling as a CUDA source file memcpy is provided implicitly.
+ * !defined(__CUDACC__) implies !defined(__CUDACC_RTC__).
+ */
+ #if defined(__cplusplus) && !defined(__CUDACC__)
+ #include <cstring>
+ #endif /* defined(__cplusplus) && !defined(__CUDACC__) */
+
+
+ /* Set up function decorations */
+ #if defined(__CUDACC__)
+ #define __CUDA_BF16_DECL__ static __device__ __inline__
+ #define __CUDA_HOSTDEVICE_BF16_DECL__ static __host__ __device__ __inline__
+ #define __VECTOR_FUNCTIONS_DECL__ static __inline__ __host__ __device__
+ #define __CUDA_HOSTDEVICE__ __host__ __device__
+ #else /* !defined(__CUDACC__) */
+ #if defined(__GNUC__)
+ #define __CUDA_HOSTDEVICE_BF16_DECL__ static __attribute__ ((unused))
+ #else
+ #define __CUDA_HOSTDEVICE_BF16_DECL__ static
+ #endif /* defined(__GNUC__) */
+ #define __CUDA_HOSTDEVICE__
+ #endif /* defined(__CUDACC__) */
+
+ /* Set up structure-alignment attribute */
+ #if defined(__CUDACC__)
+ #define __CUDA_ALIGN__(align) __align__(align)
+ #else
+ /* Define alignment macro based on compiler type (cannot assume C11 "_Alignas" is available) */
+ #if defined(__CPP_VERSION_AT_LEAST_11_BF16)
+ #define __CUDA_ALIGN__(n) alignas(n) /* C++11 kindly gives us a keyword for this */
+ #else /* defined(__CPP_VERSION_AT_LEAST_11_BF16) */
+ #if defined(__GNUC__)
+ #define __CUDA_ALIGN__(n) __attribute__ ((aligned(n)))
+ #elif defined(_MSC_VER)
+ #define __CUDA_ALIGN__(n) __declspec(align(n))
+ #else
+ #define __CUDA_ALIGN__(n)
+ #endif /* defined(__GNUC__) */
+ #endif /* defined(__CPP_VERSION_AT_LEAST_11_BF16) */
+ #endif /* defined(__CUDACC__) */
+
+ /* Macros to allow nv_bfloat16 & nv_bfloat162 to be used by inline assembly */
+ #define __BFLOAT16_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
+ #define __BFLOAT16_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
+ #define __BFLOAT162_TO_UI(var) *(reinterpret_cast<unsigned int *>(&(var)))
+ #define __BFLOAT162_TO_CUI(var) *(reinterpret_cast<const unsigned int *>(&(var)))
+
+ /**
+ * Types which allow static initialization of "nv_bfloat16" and "nv_bfloat162" until
+ * these become an actual builtin. Note this initialization is as a
+ * bitfield representation of "nv_bfloat16", and not a conversion from short->nv_bfloat16.
+ * Such a representation will be deprecated in a future version of CUDA.
+ * (Note these are visible to non-nvcc compilers, including C-only compilation)
+ */
+ typedef struct __CUDA_ALIGN__(2) {
+ unsigned short x;
+ } __nv_bfloat16_raw;
+
+ typedef struct __CUDA_ALIGN__(4) {
+ unsigned short x;
+ unsigned short y;
+ } __nv_bfloat162_raw;
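/* [Editorial note, not part of the vendored header] The raw types above exist
 * so that values can be statically initialized from a bit pattern, even from C
 * or non-nvcc compilers, e.g.:
 *     static const __nv_bfloat16_raw kOne = { 0x3F80 };  // bit pattern of 1.0
 * Per the comment above, 0x3F80 is the raw bfloat16 encoding, not an
 * integer-value conversion from short to nv_bfloat16. */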
+
+ /* All other definitions in this file are only visible to C++ compilers */
+ #if defined(__cplusplus)
+
+ /* Hide GCC member initialization list warnings because of host/device in-function init requirement */
+ #if defined(__GNUC__)
+ #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
+ #pragma GCC diagnostic push
+ #pragma GCC diagnostic ignored "-Wstrict-aliasing"
+ #pragma GCC diagnostic ignored "-Weffc++"
+ #endif /* __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6) */
+ #endif /* defined(__GNUC__) */
+
+ /* class' : multiple assignment operators specified
+ The class has multiple assignment operators of a single type. This warning is informational */
+ #if defined(_MSC_VER) && _MSC_VER >= 1500
+ #pragma warning( push )
+ #pragma warning( disable:4522 )
+ #endif /* defined(_MSC_VER) && _MSC_VER >= 1500 */
+
+ struct __CUDA_ALIGN__(2) __nv_bfloat16 {
+ protected:
+ unsigned short __x;
+
+ public:
+ #if defined(__CPP_VERSION_AT_LEAST_11_BF16)
+ __nv_bfloat16() = default;
+ #else
+ __CUDA_HOSTDEVICE__ __nv_bfloat16() { }
+ #endif /* defined(__CPP_VERSION_AT_LEAST_11_BF16) */
+
+ /* Convert to/from __nv_bfloat16_raw */
+ __CUDA_HOSTDEVICE__ __nv_bfloat16(const __nv_bfloat16_raw &hr) : __x(hr.x) { }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16 &operator=(const __nv_bfloat16_raw &hr) { __x = hr.x; return *this; }
+ __CUDA_HOSTDEVICE__ volatile __nv_bfloat16 &operator=(const __nv_bfloat16_raw &hr) volatile { __x = hr.x; return *this; }
+ __CUDA_HOSTDEVICE__ volatile __nv_bfloat16 &operator=(const volatile __nv_bfloat16_raw &hr) volatile { __x = hr.x; return *this; }
+ __CUDA_HOSTDEVICE__ operator __nv_bfloat16_raw() const { __nv_bfloat16_raw ret; ret.x = __x; return ret; }
+ __CUDA_HOSTDEVICE__ operator __nv_bfloat16_raw() const volatile { __nv_bfloat16_raw ret; ret.x = __x; return ret; }
+
+ #if !defined(__CUDA_NO_BFLOAT16_CONVERSIONS__)
+ /* Construct from float/double */
+ __CUDA_HOSTDEVICE__ __nv_bfloat16(const float f) { __x = __float2bfloat16(f).__x; }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16(const double f) { __x = __double2bfloat16(f).__x; }
+
+ __CUDA_HOSTDEVICE__ operator float() const { return __bfloat162float(*this); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16 &operator=(const float f) { __x = __float2bfloat16(f).__x; return *this; }
+
+ /* We omit "cast to double" operator, so as to not be ambiguous about up-cast */
+ __CUDA_HOSTDEVICE__ __nv_bfloat16 &operator=(const double f) { __x = __double2bfloat16(f).__x; return *this; }
+
+ /* Member functions only available to nvcc compilation so far */
+ #if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+ /* Allow automatic construction from types supported natively in hardware */
+ /* Note we do avoid constructor init-list because of special host/device compilation rules */
+ __CUDA_HOSTDEVICE__ __nv_bfloat16(short val) { __x = __short2bfloat16_rn(val).__x; }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16(unsigned short val) { __x = __ushort2bfloat16_rn(val).__x; }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16(int val) { __x = __int2bfloat16_rn(val).__x; }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16(unsigned int val) { __x = __uint2bfloat16_rn(val).__x; }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16(long long val) { __x = __ll2bfloat16_rn(val).__x; }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16(unsigned long long val) { __x = __ull2bfloat16_rn(val).__x; }
+
+ /* Allow automatic casts to supported builtin types, matching all that are permitted with float */
+ __CUDA_HOSTDEVICE__ operator short() const { return __bfloat162short_rz(*this); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16 &operator=(short val) { __x = __short2bfloat16_rn(val).__x; return *this; }
+
+ __CUDA_HOSTDEVICE__ operator unsigned short() const { return __bfloat162ushort_rz(*this); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16 &operator=(unsigned short val) { __x = __ushort2bfloat16_rn(val).__x; return *this; }
+
+ __CUDA_HOSTDEVICE__ operator int() const { return __bfloat162int_rz(*this); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16 &operator=(int val) { __x = __int2bfloat16_rn(val).__x; return *this; }
+
+ __CUDA_HOSTDEVICE__ operator unsigned int() const { return __bfloat162uint_rz(*this); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16 &operator=(unsigned int val) { __x = __uint2bfloat16_rn(val).__x; return *this; }
+
+ __CUDA_HOSTDEVICE__ operator long long() const { return __bfloat162ll_rz(*this); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16 &operator=(long long val) { __x = __ll2bfloat16_rn(val).__x; return *this; }
+
+ __CUDA_HOSTDEVICE__ operator unsigned long long() const { return __bfloat162ull_rz(*this); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat16 &operator=(unsigned long long val) { __x = __ull2bfloat16_rn(val).__x; return *this; }
+
+ /* Boolean conversion - note both 0 and -0 must return false */
+ __CUDA_HOSTDEVICE__ operator bool() const { return (__x & 0x7FFF) != 0; }
+ #endif /* defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) */
+ #endif /* !defined(__CUDA_NO_BFLOAT16_CONVERSIONS__) */
+ };
+
+ /* Global-space operator functions are only available to nvcc compilation */
+ #if defined(__CUDACC__)
+
+ #if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
+ #if !defined(__CUDA_NO_BFLOAT16_OPERATORS__)
+ /* Some basic arithmetic operations expected of a builtin */
+ __device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __hadd(lh, rh); }
+ __device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __hsub(lh, rh); }
+ __device__ __forceinline__ __nv_bfloat16 operator*(const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __hmul(lh, rh); }
+ __device__ __forceinline__ __nv_bfloat16 operator/(const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __hdiv(lh, rh); }
+
+ __device__ __forceinline__ __nv_bfloat16 &operator+=(__nv_bfloat16 &lh, const __nv_bfloat16 &rh) { lh = __hadd(lh, rh); return lh; }
+ __device__ __forceinline__ __nv_bfloat16 &operator-=(__nv_bfloat16 &lh, const __nv_bfloat16 &rh) { lh = __hsub(lh, rh); return lh; }
+ __device__ __forceinline__ __nv_bfloat16 &operator*=(__nv_bfloat16 &lh, const __nv_bfloat16 &rh) { lh = __hmul(lh, rh); return lh; }
+ __device__ __forceinline__ __nv_bfloat16 &operator/=(__nv_bfloat16 &lh, const __nv_bfloat16 &rh) { lh = __hdiv(lh, rh); return lh; }
+
+ /* Note for increment and decrement we use the raw value 0x3F80 equating to nv_bfloat16(1.0f), to avoid the extra conversion */
+ __device__ __forceinline__ __nv_bfloat16 &operator++(__nv_bfloat16 &h) { __nv_bfloat16_raw one; one.x = 0x3F80; h += one; return h; }
+ __device__ __forceinline__ __nv_bfloat16 &operator--(__nv_bfloat16 &h) { __nv_bfloat16_raw one; one.x = 0x3F80; h -= one; return h; }
+ __device__ __forceinline__ __nv_bfloat16 operator++(__nv_bfloat16 &h, const int ignored)
+ {
+ // ignored on purpose. Parameter only needed to distinguish the function declaration from other types of operators.
+ static_cast<void>(ignored);
+
+ const __nv_bfloat16 ret = h;
+ __nv_bfloat16_raw one;
+ one.x = 0x3F80;
+ h += one;
+ return ret;
+ }
+ __device__ __forceinline__ __nv_bfloat16 operator--(__nv_bfloat16 &h, const int ignored)
+ {
+ // ignored on purpose. Parameter only needed to distinguish the function declaration from other types of operators.
+ static_cast<void>(ignored);
+
+ const __nv_bfloat16 ret = h;
+ __nv_bfloat16_raw one;
+ one.x = 0x3F80;
+ h -= one;
+ return ret;
+ }
+ /* Unary plus and inverse operators */
+ __device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16 &h) { return h; }
+ __device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16 &h) { return __hneg(h); }
+
+ /* Some basic comparison operations to make it look like a builtin */
+ __device__ __forceinline__ bool operator==(const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __heq(lh, rh); }
+ __device__ __forceinline__ bool operator!=(const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __hneu(lh, rh); }
+ __device__ __forceinline__ bool operator> (const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __hgt(lh, rh); }
+ __device__ __forceinline__ bool operator< (const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __hlt(lh, rh); }
+ __device__ __forceinline__ bool operator>=(const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __hge(lh, rh); }
+ __device__ __forceinline__ bool operator<=(const __nv_bfloat16 &lh, const __nv_bfloat16 &rh) { return __hle(lh, rh); }
+ #endif /* !defined(__CUDA_NO_BFLOAT16_OPERATORS__) */
+ #endif /* __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) */
+ #endif /* defined(__CUDACC__) */
+
+ /* __nv_bfloat162 is visible to non-nvcc host compilers */
+ struct __CUDA_ALIGN__(4) __nv_bfloat162 {
+ __nv_bfloat16 x;
+ __nv_bfloat16 y;
+
+ // All construct/copy/assign/move
+ public:
+ #if defined(__CPP_VERSION_AT_LEAST_11_BF16)
+ __nv_bfloat162() = default;
+ __CUDA_HOSTDEVICE__ __nv_bfloat162(__nv_bfloat162 &&src) { __BFLOAT162_TO_UI(*this) = std::move(__BFLOAT162_TO_CUI(src)); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat162 &operator=(__nv_bfloat162 &&src) { __BFLOAT162_TO_UI(*this) = std::move(__BFLOAT162_TO_CUI(src)); return *this; }
+ #else
+ __CUDA_HOSTDEVICE__ __nv_bfloat162() { }
+ #endif /* defined(__CPP_VERSION_AT_LEAST_11_BF16) */
+ __CUDA_HOSTDEVICE__ __nv_bfloat162(const __nv_bfloat16 &a, const __nv_bfloat16 &b) : x(a), y(b) { }
+ __CUDA_HOSTDEVICE__ __nv_bfloat162(const __nv_bfloat162 &src) { __BFLOAT162_TO_UI(*this) = __BFLOAT162_TO_CUI(src); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat162 &operator=(const __nv_bfloat162 &src) { __BFLOAT162_TO_UI(*this) = __BFLOAT162_TO_CUI(src); return *this; }
+
+ /* Convert to/from __nv_bfloat162_raw */
+ __CUDA_HOSTDEVICE__ __nv_bfloat162(const __nv_bfloat162_raw &h2r ) { __BFLOAT162_TO_UI(*this) = __BFLOAT162_TO_CUI(h2r); }
+ __CUDA_HOSTDEVICE__ __nv_bfloat162 &operator=(const __nv_bfloat162_raw &h2r) { __BFLOAT162_TO_UI(*this) = __BFLOAT162_TO_CUI(h2r); return *this; }
+ __CUDA_HOSTDEVICE__ operator __nv_bfloat162_raw() const { __nv_bfloat162_raw ret; ret.x = 0U; ret.y = 0U; __BFLOAT162_TO_UI(ret) = __BFLOAT162_TO_CUI(*this); return ret; }
+ };
+
+ /* Global-space operator functions are only available to nvcc compilation */
+ #if defined(__CUDACC__)
+
+ #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) && !defined(__CUDA_NO_BFLOAT162_OPERATORS__)
+
+ __device__ __forceinline__ __nv_bfloat162 operator+(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __hadd2(lh, rh); }
+ __device__ __forceinline__ __nv_bfloat162 operator-(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __hsub2(lh, rh); }
+ __device__ __forceinline__ __nv_bfloat162 operator*(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __hmul2(lh, rh); }
+ __device__ __forceinline__ __nv_bfloat162 operator/(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __h2div(lh, rh); }
+
+ __device__ __forceinline__ __nv_bfloat162& operator+=(__nv_bfloat162 &lh, const __nv_bfloat162 &rh) { lh = __hadd2(lh, rh); return lh; }
+ __device__ __forceinline__ __nv_bfloat162& operator-=(__nv_bfloat162 &lh, const __nv_bfloat162 &rh) { lh = __hsub2(lh, rh); return lh; }
+ __device__ __forceinline__ __nv_bfloat162& operator*=(__nv_bfloat162 &lh, const __nv_bfloat162 &rh) { lh = __hmul2(lh, rh); return lh; }
+ __device__ __forceinline__ __nv_bfloat162& operator/=(__nv_bfloat162 &lh, const __nv_bfloat162 &rh) { lh = __h2div(lh, rh); return lh; }
+
+ __device__ __forceinline__ __nv_bfloat162 &operator++(__nv_bfloat162 &h) { __nv_bfloat162_raw one; one.x = 0x3F80; one.y = 0x3F80; h = __hadd2(h, one); return h; }
+ __device__ __forceinline__ __nv_bfloat162 &operator--(__nv_bfloat162 &h) { __nv_bfloat162_raw one; one.x = 0x3F80; one.y = 0x3F80; h = __hsub2(h, one); return h; }
+ __device__ __forceinline__ __nv_bfloat162 operator++(__nv_bfloat162 &h, const int ignored)
+ {
+ // ignored on purpose. Parameter only needed to distinguish the function declaration from other types of operators.
+ static_cast<void>(ignored);
+
+ const __nv_bfloat162 ret = h;
+ __nv_bfloat162_raw one;
+ one.x = 0x3F80;
+ one.y = 0x3F80;
+ h = __hadd2(h, one);
+ return ret;
+ }
+ __device__ __forceinline__ __nv_bfloat162 operator--(__nv_bfloat162 &h, const int ignored)
+ {
+ // ignored on purpose. Parameter only needed to distinguish the function declaration from other types of operators.
+ static_cast<void>(ignored);
+
+ const __nv_bfloat162 ret = h;
+ __nv_bfloat162_raw one;
+ one.x = 0x3F80;
+ one.y = 0x3F80;
+ h = __hsub2(h, one);
+ return ret;
+ }
+ __device__ __forceinline__ __nv_bfloat162 operator+(const __nv_bfloat162 &h) { return h; }
+ __device__ __forceinline__ __nv_bfloat162 operator-(const __nv_bfloat162 &h) { return __hneg2(h); }
+
+ __device__ __forceinline__ bool operator==(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __hbeq2(lh, rh); }
+ __device__ __forceinline__ bool operator!=(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __hbneu2(lh, rh); }
+ __device__ __forceinline__ bool operator>(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __hbgt2(lh, rh); }
+ __device__ __forceinline__ bool operator<(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __hblt2(lh, rh); }
+ __device__ __forceinline__ bool operator>=(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __hbge2(lh, rh); }
+ __device__ __forceinline__ bool operator<=(const __nv_bfloat162 &lh, const __nv_bfloat162 &rh) { return __hble2(lh, rh); }
+
+ #endif /* (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) && !defined(__CUDA_NO_BFLOAT162_OPERATORS__) */
+ #endif /* defined(__CUDACC__) */
+
+ /* Restore warning for multiple assignment operators */
+ #if defined(_MSC_VER) && _MSC_VER >= 1500
+ #pragma warning( pop )
+ #endif /* defined(_MSC_VER) && _MSC_VER >= 1500 */
+
+ /* Restore -Weffc++ warnings from here on */
+ #if defined(__GNUC__)
+ #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
+ #pragma GCC diagnostic pop
+ #endif /* __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6) */
+ #endif /* defined(__GNUC__) */
+
+ #undef __CUDA_HOSTDEVICE__
+ #undef __CUDA_ALIGN__
+
+ __CUDA_HOSTDEVICE_BF16_DECL__ unsigned short __internal_float2bfloat16(const float f, unsigned int &sign, unsigned int &remainder)
+ {
+ unsigned int x;
+
+ #if defined(__CUDA_ARCH__)
+ x = __float_as_uint(f);
+ #elif defined(__CUDACC__)
+ (void)memcpy(&x, &f, sizeof(f));
+ #else
+ (void)std::memcpy(&x, &f, sizeof(f));
+ #endif
+
+ if ((x & 0x7fffffffU) > 0x7f800000U) {
+ sign = 0U;
+ remainder = 0U;
+ return static_cast<unsigned short>(0x7fffU);
+ }
+ sign = x >> 31U;
+ remainder = x << 16U;
+ return static_cast<unsigned short>(x >> 16U);
+ }
+
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __double2bfloat16(const double x)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat16 val;
+ asm("{ cvt.rn.bf16.f64 %0, %1;}\n" : "=h"(__BFLOAT16_TO_US(val)) : "d"(x));
+ return val;
+ #else
+
+ float f = static_cast<float>(x);
+ const double d = static_cast<double>(f);
+ unsigned int u;
+
+ #if defined(__CUDA_ARCH__)
+ u = __float_as_uint(f);
+ #elif defined(__CUDACC__)
+ (void)memcpy(&u, &f, sizeof(f));
+ #else
+ (void)std::memcpy(&u, &f, sizeof(f));
+ #endif
+ bool x_is_not_nan = ((u << (unsigned)1U) <= (unsigned)0xFF000000U);
+
+
+ if ((x > 0.0) && (d > x)) {
+ u--;
+ }
+ if ((x < 0.0) && (d < x)) {
+ u--;
+ }
+ if ((d != x) && x_is_not_nan) {
+ u |= 1U;
+ }
+
+ #if defined(__CUDA_ARCH__)
+ f = __int_as_float(static_cast<int>(u));
+ #elif defined(__CUDACC__)
+ (void)memcpy(&f, &u, sizeof(f));
+ #else
+ (void)std::memcpy(&f, &u, sizeof(f));
+ #endif
+
+ return __float2bfloat16(f);
+
+ #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ }
+
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __float2bfloat16(const float a)
+ {
+ __nv_bfloat16 val;
+ #if __CUDA_ARCH__ >= 800
+ asm("{ cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(__BFLOAT16_TO_US(val)) : "f"(a));
+ #else
+ __nv_bfloat16_raw r;
+ unsigned int sign = 0U;
+ unsigned int remainder = 0U;
+ r.x = __internal_float2bfloat16(a, sign, remainder);
+ if ((remainder > 0x80000000U) || ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
+ r.x++;
+ }
+ val = r;
+ #endif
+ return val;
+ }
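/* [Editorial note, not part of the vendored header] Worked example of the
 * round-to-nearest-even fixup in __float2bfloat16 above: for a = 1.00390625f
 * (bits 0x3F808000), __internal_float2bfloat16 yields r.x = 0x3F80 and
 * remainder == 0x80000000, an exact tie; the kept LSB (r.x & 0x1U) is 0, so
 * nothing is added and the result stays 0x3F80 (1.0 in bfloat16). For
 * a = 1.01171875f (bits 0x3F818000) the tie lands on the odd value 0x3F81,
 * so r.x is bumped to the even neighbour 0x3F82 (1.015625). */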
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __float2bfloat16_rn(const float a)
+ {
+ __nv_bfloat16 val;
+ #if __CUDA_ARCH__ >= 800
+ asm("{ cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(__BFLOAT16_TO_US(val)) : "f"(a));
+ #else
+ __nv_bfloat16_raw r;
+ unsigned int sign = 0U;
+ unsigned int remainder = 0U;
+ r.x = __internal_float2bfloat16(a, sign, remainder);
+ if ((remainder > 0x80000000U) || ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
+ r.x++;
+ }
+ val = r;
+ #endif
+ return val;
+ }
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __float2bfloat16_rz(const float a)
+ {
+ __nv_bfloat16 val;
+ #if __CUDA_ARCH__ >= 800
+ asm("{ cvt.rz.bf16.f32 %0, %1;}\n" : "=h"(__BFLOAT16_TO_US(val)) : "f"(a));
+ #else
+ __nv_bfloat16_raw r;
+ unsigned int sign = 0U;
+ unsigned int remainder = 0U;
+ r.x = __internal_float2bfloat16(a, sign, remainder);
+ val = r;
+ #endif
+ return val;
+ }
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __float2bfloat16_rd(const float a)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat16 val;
+ asm("{ cvt.rm.bf16.f32 %0, %1;}\n" : "=h"(__BFLOAT16_TO_US(val)) : "f"(a));
+ return val;
+ #else
+ __nv_bfloat16 val;
+ __nv_bfloat16_raw r;
+ unsigned int sign = 0U;
+ unsigned int remainder = 0U;
+ r.x = __internal_float2bfloat16(a, sign, remainder);
+ if ((remainder != 0U) && (sign != 0U)) {
+ r.x++;
+ }
+ val = r;
+ return val;
+ #endif
+ }
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __float2bfloat16_ru(const float a)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat16 val;
+ asm("{ cvt.rp.bf16.f32 %0, %1;}\n" : "=h"(__BFLOAT16_TO_US(val)) : "f"(a));
+ return val;
+ #else
+ __nv_bfloat16 val;
+ __nv_bfloat16_raw r;
+ unsigned int sign = 0U;
+ unsigned int remainder = 0U;
+ r.x = __internal_float2bfloat16(a, sign, remainder);
+ if ((remainder != 0U) && (sign == 0U)) {
+ r.x++;
+ }
+ val = r;
+ return val;
+ #endif
+ }
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat162 __float2bfloat162_rn(const float a)
+ {
+ __nv_bfloat162 val;
+ #if __CUDA_ARCH__ >= 800
+ asm("{.reg .b16 low;\n"
+ " cvt.rn.bf16.f32 low, %1;\n"
+ " mov.b32 %0, {low,low};}\n" : "=r"(__BFLOAT162_TO_UI(val)) : "f"(a));
+ #else
+ val = __nv_bfloat162(__float2bfloat16_rn(a), __float2bfloat16_rn(a));
+ #endif
+ return val;
+ }
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat162 __floats2bfloat162_rn(const float a, const float b)
+ {
+ __nv_bfloat162 val;
+ #if __CUDA_ARCH__ >= 800
+ asm("{ cvt.rn.bf16x2.f32 %0, %2, %1;}\n"
+ : "=r"(__BFLOAT162_TO_UI(val)) : "f"(a), "f"(b));
+ #else
+ val = __nv_bfloat162(__float2bfloat16_rn(a), __float2bfloat16_rn(b));
+ #endif
+ return val;
+ }
+
+ __CUDA_HOSTDEVICE_BF16_DECL__ float __internal_bfloat162float(const unsigned short h)
+ {
+ float f;
+ #if defined(__CUDA_ARCH__)
+ #if (__CUDA_ARCH__ >= 900)
+ asm("{ cvt.f32.bf16 %0, %1;}\n" : "=f"(f) : "h"(h));
+ #else
+ asm("{ mov.b32 %0, {0,%1};}\n" : "=f"(f) : "h"(h));
+ #endif
+ #else
+ unsigned int u = static_cast<unsigned int>(h) << 16;
+ #if defined(__CUDACC__)
+ (void)memcpy(&f, &u, sizeof(f));
+ #else
+ (void)std::memcpy(&f, &u, sizeof(f));
+ #endif
+ #endif
+ return f;
+ }
+
+ __CUDA_HOSTDEVICE_BF16_DECL__ float __bfloat162float(const __nv_bfloat16 a)
+ {
+ return __internal_bfloat162float(static_cast<__nv_bfloat16_raw>(a).x);
+ }
+ __CUDA_HOSTDEVICE_BF16_DECL__ float __low2float(const __nv_bfloat162 a)
+ {
+ return __internal_bfloat162float(static_cast<__nv_bfloat162_raw>(a).x);
+ }
+
+ __CUDA_HOSTDEVICE_BF16_DECL__ float __high2float(const __nv_bfloat162 a)
+ {
+ return __internal_bfloat162float(static_cast<__nv_bfloat162_raw>(a).y);
+ }
+
+ #if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+
+ /* CUDA vector-types compatible vector creation function (note returns __nv_bfloat162, not nv_bfloat162) */
+ __VECTOR_FUNCTIONS_DECL__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
+ {
+ __nv_bfloat162 t; t.x = x; t.y = y; return t;
+ }
+ #undef __VECTOR_FUNCTIONS_DECL__
+
+
+ /* Definitions of intrinsics */
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat162 __float22bfloat162_rn(const float2 a)
+ {
+ __nv_bfloat162 val = __floats2bfloat162_rn(a.x, a.y);
+ return val;
+ }
+ __CUDA_HOSTDEVICE_BF16_DECL__ float2 __bfloat1622float2(const __nv_bfloat162 a)
+ {
+ float hi_float;
+ float lo_float;
+ lo_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).x);
+ hi_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).y);
+ return make_float2(lo_float, hi_float);
+ }
+ __CUDA_BF16_DECL__ int __bfloat162int_rn(const __nv_bfloat16 h)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ int val;
+ asm("{ cvt.rni.s32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__BFLOAT16_TO_CUS(h)));
+ return val;
+ #else
+ return __float2int_rn(__bfloat162float(h));
+ #endif
+ }
+ __CUDA_HOSTDEVICE_BF16_DECL__ int __bfloat162int_rz(const __nv_bfloat16 h)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ int val;
+ asm("{ cvt.rzi.s32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__BFLOAT16_TO_CUS(h)));
+ return val;
+ #else
+ const float f = __bfloat162float(h);
+ int i;
+ i = static_cast<int>(f);
+ #if !(defined __CUDA_ARCH__)
+ const int max_val = (int)0x7fffffffU;
+ const int min_val = (int)0x80000000U;
+ const unsigned short bits = static_cast<unsigned short>(static_cast<__nv_bfloat16_raw>(h).x << 1U);
+ // saturation fixup
+ if (bits > (unsigned short)0xFF00U) {
+ // NaN
+ i = 0;
+ } else if (f >= static_cast<float>(max_val)) {
+ // saturate maximum
+ i = max_val;
+ } else if (f < static_cast<float>(min_val)) {
+ // saturate minimum
+ i = min_val;
+ }
+ #endif
+ return i;
+ #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ }
+ __CUDA_BF16_DECL__ int __bfloat162int_rd(const __nv_bfloat16 h)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ int val;
+ asm("{ cvt.rmi.s32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__BFLOAT16_TO_CUS(h)));
+ return val;
+ #else
+ return __float2int_rd(__bfloat162float(h));
+ #endif
+ }
+ __CUDA_BF16_DECL__ int __bfloat162int_ru(const __nv_bfloat16 h)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ int val;
+ asm("{ cvt.rpi.s32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__BFLOAT16_TO_CUS(h)));
+ return val;
+ #else
+ return __float2int_ru(__bfloat162float(h));
+ #endif
+ }
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __int2bfloat16_rn(const int i)
+ {
+ #if (defined __CUDA_ARCH__)
+ #if (__CUDA_ARCH__ >= 900)
+ __nv_bfloat16 val;
+ asm("cvt.rn.bf16.s32 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "r"(i));
+ return val;
+ #else
+ const float ru = __int2float_ru(i);
+ const float rd = __int2float_rd(i);
+ float rz = __int2float_rz(i);
+ if (ru != rd) {
+ rz = __uint_as_float(__float_as_uint(rz) | 1U);
+ }
+ return __float2bfloat16_rn(rz);
+ #endif
+ #else
+ const double d = static_cast<double>(i);
+ return __double2bfloat16(d);
+ #endif
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 __int2bfloat16_rz(const int i)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat16 val;
+ asm("cvt.rz.bf16.s32 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "r"(i));
+ return val;
+ #else
+ return __float2bfloat16_rz(__int2float_rz(i));
+ #endif
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 __int2bfloat16_rd(const int i)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat16 val;
+ asm("cvt.rm.bf16.s32 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "r"(i));
+ return val;
+ #else
+ return __float2bfloat16_rd(__int2float_rd(i));
+ #endif
+ }
+
+ __CUDA_BF16_DECL__ __nv_bfloat16 __int2bfloat16_ru(const int i)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat16 val;
+ asm("cvt.rp.bf16.s32 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "r"(i));
+ return val;
+ #else
+ return __float2bfloat16_ru(__int2float_ru(i));
+ #endif
+ }
+
+ __CUDA_BF16_DECL__ short int __bfloat162short_rn(const __nv_bfloat16 h)
+ {
+ short int val;
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ asm("cvt.rni.s16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
+ #else
+ asm("{ .reg.f32 f;\n"
+ " mov.b32 f, {0,%1};\n"
+ " cvt.rni.s16.f32 %0,f;\n}"
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
+ #endif
+ return val;
+ }
+
+ __CUDA_HOSTDEVICE_BF16_DECL__ short int __bfloat162short_rz(const __nv_bfloat16 h)
+ {
+ short int val;
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ asm("cvt.rzi.s16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
+ #elif (defined __CUDA_ARCH__)
+ asm("{ .reg.f32 f;\n"
+ " mov.b32 f, {0,%1};\n"
+ " cvt.rzi.s16.f32 %0,f;\n}"
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
+ #else
+ const float f = __bfloat162float(h);
+ val = static_cast<short int>(f);
+ const short int max_val = (short int)0x7fffU;
+ const short int min_val = (short int)0x8000U;
+ const unsigned short bits = static_cast<unsigned short>(static_cast<__nv_bfloat16_raw>(h).x << 1U);
+ // saturation fixup
+ if (bits > (unsigned short)0xFF00U) {
+ // NaN
+ val = 0;
+ } else if (f > static_cast<float>(max_val)) {
+ // saturate maximum
+ val = max_val;
+ } else if (f < static_cast<float>(min_val)) {
+ // saturate minimum
+ val = min_val;
+ }
+ #endif
+ return val;
+ }
759
+ __CUDA_BF16_DECL__ short int __bfloat162short_rd(const __nv_bfloat16 h)
760
+ {
761
+ short int val;
762
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
763
+ asm("cvt.rmi.s16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
764
+ #else
765
+ asm("{ .reg.f32 f;\n"
766
+ " mov.b32 f, {0,%1};\n"
767
+ " cvt.rmi.s16.f32 %0,f;\n}"
768
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
769
+ #endif
770
+ return val;
771
+ }
772
+ __CUDA_BF16_DECL__ short int __bfloat162short_ru(const __nv_bfloat16 h)
773
+ {
774
+ short int val;
775
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
776
+ asm("cvt.rpi.s16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
777
+ #else
778
+ asm("{ .reg.f32 f;\n"
779
+ " mov.b32 f, {0,%1};\n"
780
+ " cvt.rpi.s16.f32 %0,f;\n}"
781
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
782
+ #endif
783
+ return val;
784
+ }
785
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __short2bfloat16_rn(const short int i)
786
+ {
787
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
788
+ __nv_bfloat16 val;
789
+ asm("cvt.rn.bf16.s16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(i));
790
+ return val;
791
+ #else
792
+ const float f = static_cast<float>(i);
793
+ return __float2bfloat16_rn(f);
794
+ #endif
795
+ }
796
+ __CUDA_BF16_DECL__ __nv_bfloat16 __short2bfloat16_rz(const short int i)
797
+ {
798
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
799
+ __nv_bfloat16 val;
800
+ asm("cvt.rz.bf16.s16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(i));
801
+ return val;
802
+ #else
803
+ return __float2bfloat16_rz(__int2float_rz(static_cast<int>(i)));
804
+ #endif
805
+ }
806
+ __CUDA_BF16_DECL__ __nv_bfloat16 __short2bfloat16_rd(const short int i)
807
+ {
808
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
809
+ __nv_bfloat16 val;
810
+ asm("cvt.rm.bf16.s16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(i));
811
+ return val;
812
+ #else
813
+ return __float2bfloat16_rd(__int2float_rd(static_cast<int>(i)));
814
+ #endif
815
+ }
816
+ __CUDA_BF16_DECL__ __nv_bfloat16 __short2bfloat16_ru(const short int i)
817
+ {
818
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
819
+ __nv_bfloat16 val;
820
+ asm("cvt.rp.bf16.s16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(i));
821
+ return val;
822
+ #else
823
+ return __float2bfloat16_ru(__int2float_ru(static_cast<int>(i)));
824
+ #endif
825
+ }
826
+
827
+ __CUDA_BF16_DECL__ unsigned int __bfloat162uint_rn(const __nv_bfloat16 h)
828
+ {
829
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
830
+ unsigned int val;
831
+ asm("{ cvt.rni.u32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__BFLOAT16_TO_CUS(h)));
832
+ return val;
833
+ #else
834
+ return __float2uint_rn(__bfloat162float(h));
835
+ #endif
836
+ }
837
+ __CUDA_HOSTDEVICE_BF16_DECL__ unsigned int __bfloat162uint_rz(const __nv_bfloat16 h)
838
+ {
839
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
840
+ unsigned int val;
841
+ asm("{ cvt.rzi.u32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__BFLOAT16_TO_CUS(h)));
842
+ return val;
843
+ #else
844
+
845
+ const float f = __bfloat162float(h);
846
+ unsigned int i;
847
+ i = static_cast<unsigned int>(f);
848
+ #if !(defined __CUDA_ARCH__)
849
+ const unsigned int max_val = 0xffffffffU;
850
+ const unsigned int min_val = 0U;
851
+ const unsigned short bits = static_cast<unsigned short>(static_cast<__nv_bfloat16_raw>(h).x << 1U);
852
+ // saturation fixup
853
+ if (bits > (unsigned short)0xFF00U) {
854
+ // NaN
855
+ i = 0U;
856
+ } else if (f >= static_cast<float>(max_val)) {
857
+ // saturate maximum
858
+ i = max_val;
859
+ } else if (f < static_cast<float>(min_val)) {
860
+ // saturate minimum
861
+ i = min_val;
862
+ }
863
+ #endif
864
+ return i;
865
+
866
+ #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
867
+ }
868
+ __CUDA_BF16_DECL__ unsigned int __bfloat162uint_rd(const __nv_bfloat16 h)
869
+ {
870
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
871
+ unsigned int val;
872
+ asm("{ cvt.rmi.u32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__BFLOAT16_TO_CUS(h)));
873
+ return val;
874
+ #else
875
+ return __float2uint_rd(__bfloat162float(h));
876
+ #endif
877
+ }
878
+ __CUDA_BF16_DECL__ unsigned int __bfloat162uint_ru(const __nv_bfloat16 h)
879
+ {
880
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
881
+ unsigned int val;
882
+ asm("{ cvt.rpi.u32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__BFLOAT16_TO_CUS(h)));
883
+ return val;
884
+ #else
885
+ return __float2uint_ru(__bfloat162float(h));
886
+ #endif
887
+ }
888
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __uint2bfloat16_rn(const unsigned int i)
889
+ {
890
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
891
+ __nv_bfloat16 val;
892
+ asm("cvt.rn.bf16.u32 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "r"(i));
893
+ return val;
894
+ #elif (defined __CUDA_ARCH__)
895
+ const float ru = __uint2float_ru(i);
896
+ const float rd = __uint2float_rd(i);
897
+ float rz = __uint2float_rz(i);
898
+ if (ru != rd) {
899
+ rz = __uint_as_float(__float_as_uint(rz) | 1U);
900
+ }
901
+ return __float2bfloat16_rn(rz);
902
+ #else
903
+ const double d = static_cast<double>(i);
904
+ return __double2bfloat16(d);
905
+ #endif
906
+ }
907
+ __CUDA_BF16_DECL__ __nv_bfloat16 __uint2bfloat16_rz(const unsigned int i)
908
+ {
909
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
910
+ __nv_bfloat16 val;
911
+ asm("cvt.rz.bf16.u32 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "r"(i));
912
+ return val;
913
+ #else
914
+ return __float2bfloat16_rz(__uint2float_rz(i));
915
+ #endif
916
+ }
917
+ __CUDA_BF16_DECL__ __nv_bfloat16 __uint2bfloat16_rd(const unsigned int i)
918
+ {
919
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
920
+ __nv_bfloat16 val;
921
+ asm("cvt.rm.bf16.u32 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "r"(i));
922
+ return val;
923
+ #else
924
+ return __float2bfloat16_rd(__uint2float_rd(i));
925
+ #endif
926
+ }
927
+ __CUDA_BF16_DECL__ __nv_bfloat16 __uint2bfloat16_ru(const unsigned int i)
928
+ {
929
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
930
+ __nv_bfloat16 val;
931
+ asm("cvt.rp.bf16.u32 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "r"(i));
932
+ return val;
933
+ #else
934
+ return __float2bfloat16_ru(__uint2float_ru(i));
935
+ #endif
936
+ }
937
+
938
+ __CUDA_BF16_DECL__ unsigned short int __bfloat162ushort_rn(const __nv_bfloat16 h)
939
+ {
940
+ unsigned short int val;
941
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
942
+ asm("cvt.rni.u16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
943
+ #else
944
+ asm("{ .reg.f32 f;\n"
945
+ " mov.b32 f, {0,%1};\n"
946
+ " cvt.rni.u16.f32 %0,f;\n}"
947
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
948
+ #endif
949
+ return val;
950
+ }
951
+ __CUDA_HOSTDEVICE_BF16_DECL__ unsigned short int __bfloat162ushort_rz(const __nv_bfloat16 h)
952
+ {
953
+ unsigned short int val;
954
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
955
+ asm("cvt.rzi.u16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
956
+ #elif (defined __CUDA_ARCH__)
957
+ asm("{ .reg.f32 f;\n"
958
+ " mov.b32 f, {0,%1};\n"
959
+ " cvt.rzi.u16.f32 %0,f;\n}"
960
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
961
+ #else
962
+ const float f = __bfloat162float(h);
963
+ val = static_cast<unsigned short int>(f);
964
+ const unsigned short int max_val = 0xffffU;
965
+ const unsigned short int min_val = 0U;
966
+ const unsigned short bits = static_cast<unsigned short>(static_cast<__nv_bfloat16_raw>(h).x << 1U);
967
+ // saturation fixup
968
+ if (bits > (unsigned short)0xFF00U) {
969
+ // NaN
970
+ val = 0U;
971
+ } else if (f > static_cast<float>(max_val)) {
972
+ // saturate maximum
973
+ val = max_val;
974
+ } else if (f < static_cast<float>(min_val)) {
975
+ // saturate minimum
976
+ val = min_val;
977
+ }
978
+ #endif
979
+ return val;
980
+ }
981
+ __CUDA_BF16_DECL__ unsigned short int __bfloat162ushort_rd(const __nv_bfloat16 h)
982
+ {
983
+ unsigned short int val;
984
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
985
+ asm("cvt.rmi.u16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
986
+ #else
987
+ asm("{ .reg.f32 f;\n"
988
+ " mov.b32 f, {0,%1};\n"
989
+ " cvt.rmi.u16.f32 %0,f;\n}"
990
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
991
+ #endif
992
+ return val;
993
+ }
994
+ __CUDA_BF16_DECL__ unsigned short int __bfloat162ushort_ru(const __nv_bfloat16 h)
995
+ {
996
+ unsigned short int val;
997
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
998
+ asm("cvt.rpi.u16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
999
+ #else
1000
+ asm("{ .reg.f32 f;\n"
1001
+ " mov.b32 f, {0,%1};\n"
1002
+ " cvt.rpi.u16.f32 %0,f;\n}"
1003
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(h)));
1004
+ #endif
1005
+ return val;
1006
+ }
1007
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __ushort2bfloat16_rn(const unsigned short int i)
1008
+ {
1009
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1010
+ __nv_bfloat16 val;
1011
+ asm("cvt.rn.bf16.u16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(i));
1012
+ return val;
1013
+ #else
1014
+ const float f = static_cast<float>(i);
1015
+ return __float2bfloat16_rn(f);
1016
+ #endif
1017
+ }
1018
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ushort2bfloat16_rz(const unsigned short int i)
1019
+ {
1020
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1021
+ __nv_bfloat16 val;
1022
+ asm("cvt.rz.bf16.u16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(i));
1023
+ return val;
1024
+ #else
1025
+ return __float2bfloat16_rz(__uint2float_rz(static_cast<unsigned int>(i)));
1026
+ #endif
1027
+ }
1028
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ushort2bfloat16_rd(const unsigned short int i)
1029
+ {
1030
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1031
+ __nv_bfloat16 val;
1032
+ asm("cvt.rm.bf16.u16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(i));
1033
+ return val;
1034
+ #else
1035
+ return __float2bfloat16_rd(__uint2float_rd(static_cast<unsigned int>(i)));
1036
+ #endif
1037
+ }
1038
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ushort2bfloat16_ru(const unsigned short int i)
1039
+ {
1040
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1041
+ __nv_bfloat16 val;
1042
+ asm("cvt.rp.bf16.u16 %0, %1;" : "=h"(__BFLOAT16_TO_US(val)) : "h"(i));
1043
+ return val;
1044
+ #else
1045
+ return __float2bfloat16_ru(__uint2float_ru(static_cast<unsigned int>(i)));
1046
+ #endif
1047
+ }
1048
+
1049
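The 16-bit paths above are simpler because every `unsigned short` fits in a `float` exactly; only the final rounding to bf16 can be inexact. Since bf16 carries 8 significand bits, integers up to 256 round-trip exactly. A small sketch under those assumptions (names illustrative):

    #include <cuda_bf16.h>

    // Hypothetical round-trip: ushort -> bf16 -> ushort. Exact for
    // in[i] <= 256; larger values first round to the nearest
    // representable bf16.
    __global__ void ushort_roundtrip(const unsigned short *in, unsigned short *out)
    {
        const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
        const __nv_bfloat16 h = __ushort2bfloat16_rn(in[i]);
        out[i] = __bfloat162ushort_rn(h);
    }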
+ __CUDA_BF16_DECL__ unsigned long long int __bfloat162ull_rn(const __nv_bfloat16 h)
1050
+ {
1051
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1052
+ unsigned long long int i;
1053
+ asm("cvt.rni.u64.bf16 %0, %1;" : "=l"(i) : "h"(__BFLOAT16_TO_CUS(h)));
1054
+ return i;
1055
+ #else
1056
+ return __float2ull_rn(__bfloat162float(h));
1057
+ #endif
1058
+ }
1059
+ __CUDA_HOSTDEVICE_BF16_DECL__ unsigned long long int __bfloat162ull_rz(const __nv_bfloat16 h)
1060
+ {
1061
+ unsigned long long int i;
1062
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1063
+ asm("cvt.rzi.u64.bf16 %0, %1;" : "=l"(i) : "h"(__BFLOAT16_TO_CUS(h)));
1064
+ return i;
1065
+ #else
1066
+ const float f = __bfloat162float(h);
1067
+ i = static_cast<unsigned long long int>(f);
1068
+ #if !(defined __CUDA_ARCH__)
1069
+ const unsigned long long int max_val = 0xffffffffffffffffULL;
1070
+ const unsigned long long int min_val = 0ULL;
1071
+ const unsigned short bits = static_cast<unsigned short>(static_cast<__nv_bfloat16_raw>(h).x << 1U);
1072
+ // saturation fixup
1073
+ if (bits > (unsigned short)0xFF00U) {
1074
+ // NaN
1075
+ i = 0x8000000000000000ULL;
1076
+ } else if (f >= static_cast<float>(max_val)) {
1077
+ // saturate maximum
1078
+ i = max_val;
1079
+ } else if (f < static_cast<float>(min_val)) {
1080
+ // saturate minimum
1081
+ i = min_val;
1082
+ }
1083
+ #endif
1084
+ #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1085
+ return i;
1086
+ }
1087
+ __CUDA_BF16_DECL__ unsigned long long int __bfloat162ull_rd(const __nv_bfloat16 h)
1088
+ {
1089
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1090
+ unsigned long long int i;
1091
+ asm("cvt.rmi.u64.bf16 %0, %1;" : "=l"(i) : "h"(__BFLOAT16_TO_CUS(h)));
1092
+ return i;
1093
+ #else
1094
+ return __float2ull_rd(__bfloat162float(h));
1095
+ #endif
1096
+ }
1097
+ __CUDA_BF16_DECL__ unsigned long long int __bfloat162ull_ru(const __nv_bfloat16 h)
1098
+ {
1099
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1100
+ unsigned long long int i;
1101
+ asm("cvt.rpi.u64.bf16 %0, %1;" : "=l"(i) : "h"(__BFLOAT16_TO_CUS(h)));
1102
+ return i;
1103
+ #else
1104
+ return __float2ull_ru(__bfloat162float(h));
1105
+ #endif
1106
+ }
1107
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __ull2bfloat16_rn(const unsigned long long int i)
1108
+ {
1109
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1110
+ __nv_bfloat16 h;
1111
+ asm("cvt.rn.bf16.u64 %0, %1;" : "=h"(__BFLOAT16_TO_US(h)) : "l"(i));
1112
+ return h;
1113
+ #elif (defined __CUDA_ARCH__)
1114
+ const float ru = __ull2float_ru(i);
1115
+ const float rd = __ull2float_rd(i);
1116
+ float rz = __ull2float_rz(i);
1117
+ if (ru != rd) {
1118
+ rz = __uint_as_float(__float_as_uint(rz) | 1U);
1119
+ }
1120
+ return __float2bfloat16_rn(rz);
1121
+ #else
1122
+ float f = static_cast<float>(i);
1123
+ const unsigned long long int uf = static_cast<unsigned long long int>(f);
1124
+ unsigned int u;
1125
+
1126
+ #if defined(__CUDA_ARCH__)
1127
+ u = __float_as_uint(f);
1128
+ #elif defined(__CUDACC__)
1129
+ (void)memcpy(&u, &f, sizeof(f));
1130
+ #else
1131
+ (void)std::memcpy(&u, &f, sizeof(f));
1132
+ #endif
1133
+
1134
+ // round up happened here
1135
+ // note: no need to handle round up to f == 0x1.p64 specially
1136
+ if (uf > i) {
1137
+ u--;
1138
+ }
1139
+ if (uf != i) {
1140
+ u |= 1U;
1141
+ }
1142
+
1143
+ #if defined(__CUDA_ARCH__)
1144
+ f = __int_as_float(static_cast<int>(u));
1145
+ #elif defined(__CUDACC__)
1146
+ (void)memcpy(&f, &u, sizeof(f));
1147
+ #else
1148
+ (void)std::memcpy(&f, &u, sizeof(f));
1149
+ #endif
1150
+
1151
+ return __float2bfloat16_rn(f);
1152
+ #endif
1153
+ }
1154
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ull2bfloat16_rz(const unsigned long long int i)
1155
+ {
1156
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1157
+ __nv_bfloat16 h;
1158
+ asm("cvt.rz.bf16.u64 %0, %1;" : "=h"(__BFLOAT16_TO_US(h)) : "l"(i));
1159
+ return h;
1160
+ #else
1161
+ return __float2bfloat16_rz(__ull2float_rz(i));
1162
+ #endif
1163
+ }
1164
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ull2bfloat16_rd(const unsigned long long int i)
1165
+ {
1166
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1167
+ __nv_bfloat16 h;
1168
+ asm("cvt.rm.bf16.u64 %0, %1;" : "=h"(__BFLOAT16_TO_US(h)) : "l"(i));
1169
+ return h;
1170
+ #else
1171
+ return __float2bfloat16_rd(__ull2float_rd(i));
1172
+ #endif
1173
+ }
1174
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ull2bfloat16_ru(const unsigned long long int i)
1175
+ {
1176
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1177
+ __nv_bfloat16 h;
1178
+ asm("cvt.rp.bf16.u64 %0, %1;" : "=h"(__BFLOAT16_TO_US(h)) : "l"(i));
1179
+ return h;
1180
+ #else
1181
+ return __float2bfloat16_ru(__ull2float_ru(i));
1182
+ #endif
1183
+ }
1184
+ __CUDA_BF16_DECL__ long long int __bfloat162ll_rn(const __nv_bfloat16 h)
1185
+ {
1186
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1187
+ long long int i;
1188
+ asm("cvt.rni.s64.bf16 %0, %1;" : "=l"(i) : "h"(__BFLOAT16_TO_CUS(h)));
1189
+ return i;
1190
+ #else
1191
+ return __float2ll_rn(__bfloat162float(h));
1192
+ #endif
1193
+ }
1194
+ __CUDA_HOSTDEVICE_BF16_DECL__ long long int __bfloat162ll_rz(const __nv_bfloat16 h)
1195
+ {
1196
+ long long int i;
1197
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1198
+ asm("cvt.rzi.s64.bf16 %0, %1;" : "=l"(i) : "h"(__BFLOAT16_TO_CUS(h)));
1199
+ #else
1200
+ const float f = __bfloat162float(h);
1201
+ i = static_cast<long long int>(f);
1202
+ #if !(defined __CUDA_ARCH__)
1203
+ const long long int max_val = (long long int)0x7fffffffffffffffULL;
1204
+ const long long int min_val = (long long int)0x8000000000000000ULL;
1205
+ const unsigned short bits = static_cast<unsigned short>(static_cast<__nv_bfloat16_raw>(h).x << 1U);
1206
+ // saturation fixup
1207
+ if (bits > (unsigned short)0xFF00U) {
1208
+ // NaN
1209
+ i = min_val;
1210
+ } else if (f >= static_cast<float>(max_val)) {
1211
+ // saturate maximum
1212
+ i = max_val;
1213
+ } else if (f < static_cast<float>(min_val)) {
1214
+ // saturate minimum
1215
+ i = min_val;
1216
+ }
1217
+ #endif
1218
+ #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1219
+ return i;
1220
+ }
1221
+ __CUDA_BF16_DECL__ long long int __bfloat162ll_rd(const __nv_bfloat16 h)
1222
+ {
1223
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1224
+ long long int i;
1225
+ asm("cvt.rmi.s64.bf16 %0, %1;" : "=l"(i) : "h"(__BFLOAT16_TO_CUS(h)));
1226
+ return i;
1227
+ #else
1228
+ return __float2ll_rd(__bfloat162float(h));
1229
+ #endif
1230
+ }
1231
+ __CUDA_BF16_DECL__ long long int __bfloat162ll_ru(const __nv_bfloat16 h)
1232
+ {
1233
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1234
+ long long int i;
1235
+ asm("cvt.rpi.s64.bf16 %0, %1;" : "=l"(i) : "h"(__BFLOAT16_TO_CUS(h)));
1236
+ return i;
1237
+ #else
1238
+ return __float2ll_ru(__bfloat162float(h));
1239
+ #endif
1240
+ }
1241
+ __CUDA_HOSTDEVICE_BF16_DECL__ __nv_bfloat16 __ll2bfloat16_rn(const long long int i)
1242
+ {
1243
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1244
+ __nv_bfloat16 h;
1245
+ asm("cvt.rn.bf16.s64 %0, %1;" : "=h"(__BFLOAT16_TO_US(h)) : "l"(i));
1246
+ return h;
1247
+ #elif (defined __CUDA_ARCH__)
1248
+ const float ru = __ll2float_ru(i);
1249
+ const float rd = __ll2float_rd(i);
1250
+ float rz = __ll2float_rz(i);
1251
+ if (ru != rd) {
1252
+ rz = __uint_as_float(__float_as_uint(rz) | 1U);
1253
+ }
1254
+ return __float2bfloat16_rn(rz);
1255
+ #else
1256
+ float f = static_cast<float>(i);
1257
+ const long long int lf = static_cast<long long int>(f);
1258
+ unsigned int u;
1259
+
1260
+ #if defined(__CUDA_ARCH__)
1261
+ u = __float_as_uint(f);
1262
+ #elif defined(__CUDACC__)
1263
+ (void)memcpy(&u, &f, sizeof(f));
1264
+ #else
1265
+ (void)std::memcpy(&u, &f, sizeof(f));
1266
+ #endif
1267
+
1268
+ if ((f > 0.0f) && (lf > i)) {
1269
+ u--;
1270
+ }
1271
+ if ((f < 0.0f) && (lf < i)) {
1272
+ u--;
1273
+ }
1274
+ if (lf != i) {
1275
+ u |= 1U;
1276
+ }
1277
+
1278
+ #if defined(__CUDA_ARCH__)
1279
+ f = __int_as_float(static_cast<int>(u));
1280
+ #elif defined(__CUDACC__)
1281
+ (void)memcpy(&f, &u, sizeof(f));
1282
+ #else
1283
+ (void)std::memcpy(&f, &u, sizeof(f));
1284
+ #endif
1285
+
1286
+ return __float2bfloat16_rn(f);
1287
+ #endif
1288
+ }
1289
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ll2bfloat16_rz(const long long int i)
1290
+ {
1291
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1292
+ __nv_bfloat16 h;
1293
+ asm("cvt.rz.bf16.s64 %0, %1;" : "=h"(__BFLOAT16_TO_US(h)) : "l"(i));
1294
+ return h;
1295
+ #else
1296
+ return __float2bfloat16_rz(__ll2float_rz(i));
1297
+ #endif
1298
+ }
1299
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ll2bfloat16_rd(const long long int i)
1300
+ {
1301
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1302
+ __nv_bfloat16 h;
1303
+ asm("cvt.rm.bf16.s64 %0, %1;" : "=h"(__BFLOAT16_TO_US(h)) : "l"(i));
1304
+ return h;
1305
+ #else
1306
+ return __float2bfloat16_rd(__ll2float_rd(i));
1307
+ #endif
1308
+ }
1309
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ll2bfloat16_ru(const long long int i)
1310
+ {
1311
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1312
+ __nv_bfloat16 h;
1313
+ asm("cvt.rp.bf16.s64 %0, %1;" : "=h"(__BFLOAT16_TO_US(h)) : "l"(i));
1314
+ return h;
1315
+ #else
1316
+ return __float2bfloat16_ru(__ll2float_ru(i));
1317
+ #endif
1318
+ }
1319
+
1320
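The 64-bit `_rn` fallbacks above apply the same round-to-odd idea in integer form: convert to `float` (round-to-nearest), step the bit pattern down one ULP if the conversion overshot the integer, then set the significand's low bit whenever the conversion was inexact. A host-side restatement of that step, mirroring the header's logic including its handling of values near 2^64 (see the header's own note above); the function name is illustrative:

    #include <cstdint>
    #include <cstring>

    // Round-to-odd conversion of a 64-bit integer to float, as in the
    // __ull2bfloat16_rn host fallback above. The result can then be
    // rounded to bf16 without double-rounding error.
    static float u64_to_float_round_to_odd(uint64_t i)
    {
        float f = static_cast<float>(i);              // rounds to nearest
        const uint64_t back = static_cast<uint64_t>(f);
        uint32_t u;
        std::memcpy(&u, &f, sizeof f);
        if (back > i) { u--; }                        // overshot: step down one ULP
        if (back != i) { u |= 1u; }                   // inexact: force odd low bit
        std::memcpy(&f, &u, sizeof f);
        return f;
    }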
+ __CUDA_BF16_DECL__ __nv_bfloat16 htrunc(const __nv_bfloat16 h)
1321
+ {
1322
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1323
+ __nv_bfloat16 r;
1324
+ asm("cvt.rzi.bf16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(r)) : "h"(__BFLOAT16_TO_CUS(h)));
1325
+ return r;
1326
+ #else
1327
+ return __float2bfloat16_rz(truncf(__bfloat162float(h)));
1328
+ #endif
1329
+ }
1330
+ __CUDA_BF16_DECL__ __nv_bfloat16 hceil(const __nv_bfloat16 h)
1331
+ {
1332
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1333
+ __nv_bfloat16 r;
1334
+ asm("cvt.rpi.bf16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(r)) : "h"(__BFLOAT16_TO_CUS(h)));
1335
+ return r;
1336
+ #else
1337
+ return __float2bfloat16_ru(ceilf(__bfloat162float(h)));
1338
+ #endif
1339
+ }
1340
+ __CUDA_BF16_DECL__ __nv_bfloat16 hfloor(const __nv_bfloat16 h)
1341
+ {
1342
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1343
+ __nv_bfloat16 r;
1344
+ asm("cvt.rmi.bf16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(r)) : "h"(__BFLOAT16_TO_CUS(h)));
1345
+ return r;
1346
+ #else
1347
+ return __float2bfloat16_rd(floorf(__bfloat162float(h)));
1348
+ #endif
1349
+ }
1350
+ __CUDA_BF16_DECL__ __nv_bfloat16 hrint(const __nv_bfloat16 h)
1351
+ {
1352
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1353
+ __nv_bfloat16 r;
1354
+ asm("cvt.rni.bf16.bf16 %0, %1;" : "=h"(__BFLOAT16_TO_US(r)) : "h"(__BFLOAT16_TO_CUS(h)));
1355
+ return r;
1356
+ #else
1357
+ return __float2bfloat16_rn(rintf(__bfloat162float(h)));
1358
+ #endif
1359
+ }
1360
+
1361
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2trunc(const __nv_bfloat162 h)
1362
+ {
1363
+ const __nv_bfloat16 low = __float2bfloat16_rz(truncf(__low2float(h)));
1364
+ const __nv_bfloat16 high = __float2bfloat16_rz(truncf(__high2float(h)));
1365
+ return __nv_bfloat162(low, high);
1366
+ }
1367
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2ceil(const __nv_bfloat162 h)
1368
+ {
1369
+ const __nv_bfloat16 low = __float2bfloat16_ru(ceilf(__low2float(h)));
1370
+ const __nv_bfloat16 high = __float2bfloat16_ru(ceilf(__high2float(h)));
1371
+ return __nv_bfloat162(low, high);
1372
+ }
1373
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2floor(const __nv_bfloat162 h)
1374
+ {
1375
+ const __nv_bfloat16 low = __float2bfloat16_rd(floorf(__low2float(h)));
1376
+ const __nv_bfloat16 high = __float2bfloat16_rd(floorf(__high2float(h)));
1377
+ return __nv_bfloat162(low, high);
1378
+ }
1379
+
1380
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2rint(const __nv_bfloat162 h)
1381
+ {
1382
+ return __halves2bfloat162(hrint(__low2bfloat16(h)), hrint(__high2bfloat16(h)));
1383
+ }
1384
+ __CUDA_BF16_DECL__ __nv_bfloat162 __lows2bfloat162(const __nv_bfloat162 a, const __nv_bfloat162 b)
1385
+ {
1386
+ __nv_bfloat162 val;
1387
+ asm("{.reg .b16 alow,ahigh,blow,bhigh;\n"
1388
+ " mov.b32 {alow,ahigh}, %1;\n"
1389
+ " mov.b32 {blow,bhigh}, %2;\n"
1390
+ " mov.b32 %0, {alow,blow};}\n" : "=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)), "r"(__BFLOAT162_TO_CUI(b)));
1391
+ return val;
1392
+ }
1393
+ __CUDA_BF16_DECL__ __nv_bfloat162 __highs2bfloat162(const __nv_bfloat162 a, const __nv_bfloat162 b)
1394
+ {
1395
+ __nv_bfloat162 val;
1396
+ asm("{.reg .b16 alow,ahigh,blow,bhigh;\n"
1397
+ " mov.b32 {alow,ahigh}, %1;\n"
1398
+ " mov.b32 {blow,bhigh}, %2;\n"
1399
+ " mov.b32 %0, {ahigh,bhigh};}\n" : "=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)), "r"(__BFLOAT162_TO_CUI(b)));
1400
+ return val;
1401
+ }
1402
+ __CUDA_BF16_DECL__ __nv_bfloat16 __low2bfloat16(const __nv_bfloat162 a)
1403
+ {
1404
+ __nv_bfloat16 ret;
1405
+ asm("{.reg .b16 low,high;\n"
1406
+ " mov.b32 {low,high}, %1;\n"
1407
+ " mov.b16 %0, low;}" : "=h"(__BFLOAT16_TO_US(ret)) : "r"(__BFLOAT162_TO_CUI(a)));
1408
+ return ret;
1409
+ }
1410
+ __CUDA_BF16_DECL__ int __hisinf(const __nv_bfloat16 a)
1411
+ {
1412
+ int retval;
1413
+ if (__BFLOAT16_TO_CUS(a) == 0xFF80U) {
1414
+ retval = -1;
1415
+ } else if (__BFLOAT16_TO_CUS(a) == 0x7F80U) {
1416
+ retval = 1;
1417
+ } else {
1418
+ retval = 0;
1419
+ }
1420
+ return retval;
1421
+ }
1422
+ __CUDA_BF16_DECL__ __nv_bfloat162 __low2bfloat162(const __nv_bfloat162 a)
1423
+ {
1424
+ __nv_bfloat162 val;
1425
+ asm("{.reg .b16 low,high;\n"
1426
+ " mov.b32 {low,high}, %1;\n"
1427
+ " mov.b32 %0, {low,low};}\n" : "=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)));
1428
+ return val;
1429
+ }
1430
+ __CUDA_BF16_DECL__ __nv_bfloat162 __high2bfloat162(const __nv_bfloat162 a)
1431
+ {
1432
+ __nv_bfloat162 val;
1433
+ asm("{.reg .b16 low,high;\n"
1434
+ " mov.b32 {low,high}, %1;\n"
1435
+ " mov.b32 %0, {high,high};}\n" : "=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)));
1436
+ return val;
1437
+ }
1438
+ __CUDA_BF16_DECL__ __nv_bfloat16 __high2bfloat16(const __nv_bfloat162 a)
1439
+ {
1440
+ __nv_bfloat16 ret;
1441
+ asm("{.reg .b16 low,high;\n"
1442
+ " mov.b32 {low,high}, %1;\n"
1443
+ " mov.b16 %0, high;}" : "=h"(__BFLOAT16_TO_US(ret)) : "r"(__BFLOAT162_TO_CUI(a)));
1444
+ return ret;
1445
+ }
1446
+ __CUDA_BF16_DECL__ __nv_bfloat162 __halves2bfloat162(const __nv_bfloat16 a, const __nv_bfloat16 b)
1447
+ {
1448
+ __nv_bfloat162 val;
1449
+ asm("{ mov.b32 %0, {%1,%2};}\n"
1450
+ : "=r"(__BFLOAT162_TO_UI(val)) : "h"(__BFLOAT16_TO_CUS(a)), "h"(__BFLOAT16_TO_CUS(b)));
1451
+ return val;
1452
+ }
1453
+ __CUDA_BF16_DECL__ __nv_bfloat162 __bfloat162bfloat162(const __nv_bfloat16 a)
1454
+ {
1455
+ __nv_bfloat162 val;
1456
+ asm("{ mov.b32 %0, {%1,%1};}\n"
1457
+ : "=r"(__BFLOAT162_TO_UI(val)) : "h"(__BFLOAT16_TO_CUS(a)));
1458
+ return val;
1459
+ }
1460
+ __CUDA_BF16_DECL__ __nv_bfloat162 __lowhigh2highlow(const __nv_bfloat162 a)
1461
+ {
1462
+ __nv_bfloat162 val;
1463
+ asm("{.reg .b16 low,high;\n"
1464
+ " mov.b32 {low,high}, %1;\n"
1465
+ " mov.b32 %0, {high,low};}\n" : "=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)));
1466
+ return val;
1467
+ }
1468
+ __CUDA_BF16_DECL__ short int __bfloat16_as_short(const __nv_bfloat16 h)
1469
+ {
1470
+ return static_cast<short int>(__BFLOAT16_TO_CUS(h));
1471
+ }
1472
+ __CUDA_BF16_DECL__ unsigned short int __bfloat16_as_ushort(const __nv_bfloat16 h)
1473
+ {
1474
+ return __BFLOAT16_TO_CUS(h);
1475
+ }
1476
+ __CUDA_BF16_DECL__ __nv_bfloat16 __short_as_bfloat16(const short int i)
1477
+ {
1478
+ __nv_bfloat16 h;
1479
+ __BFLOAT16_TO_US(h) = static_cast<unsigned short int>(i);
1480
+ return h;
1481
+ }
1482
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ushort_as_bfloat16(const unsigned short int i)
1483
+ {
1484
+ __nv_bfloat16 h;
1485
+ __BFLOAT16_TO_US(h) = i;
1486
+ return h;
1487
+ }
1488
+
1489
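The pack/unpack helpers above are register moves (`mov.b32` over 16-bit halves), so combining and swapping lanes is essentially free. A sketch (names illustrative) that gathers the low lanes of two vectors and then swaps them:

    #include <cuda_bf16.h>

    __global__ void repack(const __nv_bfloat162 *a, const __nv_bfloat162 *b,
                           __nv_bfloat162 *out)
    {
        const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
        // {a.low, b.low}, then swapped to {b.low, a.low}.
        const __nv_bfloat162 lows = __lows2bfloat162(a[i], b[i]);
        out[i] = __lowhigh2highlow(lows);
    }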
+ /******************************************************************************
1490
+ * __nv_bfloat16, __nv_bfloat162 warp shuffle *
1491
+ ******************************************************************************/
1492
+ #define __SHUFFLE_SYNC_BFLOAT162_MACRO(name) /* do */ {\
1493
+ __nv_bfloat162 r; \
1494
+ asm volatile ("{" __CUDA_BF16_STRINGIFY(name) " %0,%1,%2,%3,%4;\n}" \
1495
+ :"=r"(__BFLOAT162_TO_UI(r)): "r"(__BFLOAT162_TO_CUI(var)), "r"(delta), "r"(c), "r"(mask)); \
1496
+ return r; \
1497
+ } /* while(0) */
1498
+
1499
+ __CUDA_BF16_DECL__ __nv_bfloat162 __shfl_sync(const unsigned mask, const __nv_bfloat162 var, const int delta, const int width)
1500
+ {
1501
+ unsigned int warp_size;
1502
+ asm("{mov.u32 %0, WARP_SZ;\n}" : "=r"(warp_size));
1503
+ const unsigned int c = ((warp_size - static_cast<unsigned>(width)) << 8U) | 0x1fU;
1504
+ __SHUFFLE_SYNC_BFLOAT162_MACRO(shfl.sync.idx.b32)
1505
+ }
1506
+ __CUDA_BF16_DECL__ __nv_bfloat162 __shfl_up_sync(const unsigned mask, const __nv_bfloat162 var, const unsigned int delta, const int width)
1507
+ {
1508
+ unsigned int warp_size;
1509
+ asm("{mov.u32 %0, WARP_SZ;\n}" : "=r"(warp_size));
1510
+ const unsigned int c = (warp_size - static_cast<unsigned>(width)) << 8U;
1511
+ __SHUFFLE_SYNC_BFLOAT162_MACRO(shfl.sync.up.b32)
1512
+ }
1513
+ __CUDA_BF16_DECL__ __nv_bfloat162 __shfl_down_sync(const unsigned mask, const __nv_bfloat162 var, const unsigned int delta, const int width)
1514
+ {
1515
+ unsigned int warp_size;
1516
+ asm("{mov.u32 %0, WARP_SZ;\n}" : "=r"(warp_size));
1517
+ const unsigned int c = ((warp_size - static_cast<unsigned>(width)) << 8U) | 0x1fU;
1518
+ __SHUFFLE_SYNC_BFLOAT162_MACRO(shfl.sync.down.b32)
1519
+ }
1520
+ __CUDA_BF16_DECL__ __nv_bfloat162 __shfl_xor_sync(const unsigned mask, const __nv_bfloat162 var, const int delta, const int width)
1521
+ {
1522
+ unsigned int warp_size;
1523
+ asm("{mov.u32 %0, WARP_SZ;\n}" : "=r"(warp_size));
1524
+ const unsigned int c = ((warp_size - static_cast<unsigned>(width)) << 8U) | 0x1fU;
1525
+ __SHUFFLE_SYNC_BFLOAT162_MACRO(shfl.sync.bfly.b32)
1526
+ }
1527
+
1528
+ #undef __SHUFFLE_SYNC_BFLOAT162_MACRO
1529
+
1530
+ __CUDA_BF16_DECL__ __nv_bfloat16 __shfl_sync(const unsigned mask, const __nv_bfloat16 var, const int delta, const int width)
1531
+ {
1532
+ const __nv_bfloat162 temp1 = __halves2bfloat162(var, var);
1533
+ const __nv_bfloat162 temp2 = __shfl_sync(mask, temp1, delta, width);
1534
+ return __low2bfloat16(temp2);
1535
+ }
1536
+ __CUDA_BF16_DECL__ __nv_bfloat16 __shfl_up_sync(const unsigned mask, const __nv_bfloat16 var, const unsigned int delta, const int width)
1537
+ {
1538
+ const __nv_bfloat162 temp1 = __halves2bfloat162(var, var);
1539
+ const __nv_bfloat162 temp2 = __shfl_up_sync(mask, temp1, delta, width);
1540
+ return __low2bfloat16(temp2);
1541
+ }
1542
+ __CUDA_BF16_DECL__ __nv_bfloat16 __shfl_down_sync(const unsigned mask, const __nv_bfloat16 var, const unsigned int delta, const int width)
1543
+ {
1544
+ const __nv_bfloat162 temp1 = __halves2bfloat162(var, var);
1545
+ const __nv_bfloat162 temp2 = __shfl_down_sync(mask, temp1, delta, width);
1546
+ return __low2bfloat16(temp2);
1547
+ }
1548
+ __CUDA_BF16_DECL__ __nv_bfloat16 __shfl_xor_sync(const unsigned mask, const __nv_bfloat16 var, const int delta, const int width)
1549
+ {
1550
+ const __nv_bfloat162 temp1 = __halves2bfloat162(var, var);
1551
+ const __nv_bfloat162 temp2 = __shfl_xor_sync(mask, temp1, delta, width);
1552
+ return __low2bfloat16(temp2);
1553
+ }
1554
+
1555
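The scalar shuffles above work by broadcasting the value into both lanes of a `__nv_bfloat162`, shuffling the 32-bit register, and taking the low lane back. A common use is a warp-level reduction; a minimal sketch, assuming the full warp is active and using `__hadd` from the arithmetic section further down in this header:

    #include <cuda_bf16.h>

    // Hypothetical warp sum over __nv_bfloat16. The tree accumulation
    // order differs from a serial sum, so low-precision results may
    // differ slightly. The total is valid in lane 0 after the loop.
    __device__ __nv_bfloat16 warp_sum_bf16(__nv_bfloat16 v)
    {
        for (int offset = 16; offset > 0; offset >>= 1) {
            v = __hadd(v, __shfl_down_sync(0xffffffffu, v, offset, 32));
        }
        return v;
    }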
+ /******************************************************************************
1556
+ * __nv_bfloat16 and __nv_bfloat162 __ldg,__ldcg,__ldca,__ldcs *
1557
+ ******************************************************************************/
1558
+
1559
+ #if defined(__cplusplus)
1560
+ #if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
1561
+ #define __LDG_PTR "l"
1562
+ #else
1563
+ #define __LDG_PTR "r"
1564
+ #endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
1565
+ __CUDA_BF16_DECL__ __nv_bfloat162 __ldg(const __nv_bfloat162 *const ptr)
1566
+ {
1567
+ __nv_bfloat162 ret;
1568
+ asm ("ld.global.nc.b32 %0, [%1];" : "=r"(__BFLOAT162_TO_UI(ret)) : __LDG_PTR(ptr));
1569
+ return ret;
1570
+ }
1571
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ldg(const __nv_bfloat16 *const ptr)
1572
+ {
1573
+ __nv_bfloat16 ret;
1574
+ asm ("ld.global.nc.b16 %0, [%1];" : "=h"(__BFLOAT16_TO_US(ret)) : __LDG_PTR(ptr));
1575
+ return ret;
1576
+ }
1577
+ __CUDA_BF16_DECL__ __nv_bfloat162 __ldcg(const __nv_bfloat162 *const ptr)
1578
+ {
1579
+ __nv_bfloat162 ret;
1580
+ asm ("ld.global.cg.b32 %0, [%1];" : "=r"(__BFLOAT162_TO_UI(ret)) : __LDG_PTR(ptr));
1581
+ return ret;
1582
+ }
1583
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ldcg(const __nv_bfloat16 *const ptr)
1584
+ {
1585
+ __nv_bfloat16 ret;
1586
+ asm ("ld.global.cg.b16 %0, [%1];" : "=h"(__BFLOAT16_TO_US(ret)) : __LDG_PTR(ptr));
1587
+ return ret;
1588
+ }
1589
+ __CUDA_BF16_DECL__ __nv_bfloat162 __ldca(const __nv_bfloat162 *const ptr)
1590
+ {
1591
+ __nv_bfloat162 ret;
1592
+ asm ("ld.global.ca.b32 %0, [%1];" : "=r"(__BFLOAT162_TO_UI(ret)) : __LDG_PTR(ptr));
1593
+ return ret;
1594
+ }
1595
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ldca(const __nv_bfloat16 *const ptr)
1596
+ {
1597
+ __nv_bfloat16 ret;
1598
+ asm ("ld.global.ca.b16 %0, [%1];" : "=h"(__BFLOAT16_TO_US(ret)) : __LDG_PTR(ptr));
1599
+ return ret;
1600
+ }
1601
+ __CUDA_BF16_DECL__ __nv_bfloat162 __ldcs(const __nv_bfloat162 *const ptr)
1602
+ {
1603
+ __nv_bfloat162 ret;
1604
+ asm ("ld.global.cs.b32 %0, [%1];" : "=r"(__BFLOAT162_TO_UI(ret)) : __LDG_PTR(ptr));
1605
+ return ret;
1606
+ }
1607
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ldcs(const __nv_bfloat16 *const ptr)
1608
+ {
1609
+ __nv_bfloat16 ret;
1610
+ asm ("ld.global.cs.b16 %0, [%1];" : "=h"(__BFLOAT16_TO_US(ret)) : __LDG_PTR(ptr));
1611
+ return ret;
1612
+ }
1613
+ __CUDA_BF16_DECL__ __nv_bfloat162 __ldlu(const __nv_bfloat162 *const ptr)
1614
+ {
1615
+ __nv_bfloat162 ret;
1616
+ asm ("ld.global.lu.b32 %0, [%1];" : "=r"(__BFLOAT162_TO_UI(ret)) : __LDG_PTR(ptr) : "memory");
1617
+ return ret;
1618
+ }
1619
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ldlu(const __nv_bfloat16 *const ptr)
1620
+ {
1621
+ __nv_bfloat16 ret;
1622
+ asm ("ld.global.lu.b16 %0, [%1];" : "=h"(__BFLOAT16_TO_US(ret)) : __LDG_PTR(ptr) : "memory");
1623
+ return ret;
1624
+ }
1625
+ __CUDA_BF16_DECL__ __nv_bfloat162 __ldcv(const __nv_bfloat162 *const ptr)
1626
+ {
1627
+ __nv_bfloat162 ret;
1628
+ asm ("ld.global.cv.b32 %0, [%1];" : "=r"(__BFLOAT162_TO_UI(ret)) : __LDG_PTR(ptr) : "memory");
1629
+ return ret;
1630
+ }
1631
+ __CUDA_BF16_DECL__ __nv_bfloat16 __ldcv(const __nv_bfloat16 *const ptr)
1632
+ {
1633
+ __nv_bfloat16 ret;
1634
+ asm ("ld.global.cv.b16 %0, [%1];" : "=h"(__BFLOAT16_TO_US(ret)) : __LDG_PTR(ptr) : "memory");
1635
+ return ret;
1636
+ }
1637
+
1638
+ __CUDA_BF16_DECL__ void __stwb(__nv_bfloat162 *const ptr, const __nv_bfloat162 value)
1639
+ {
1640
+ asm ("st.global.wb.b32 [%0], %1;" :: __LDG_PTR(ptr), "r"(__BFLOAT162_TO_CUI(value)) : "memory");
1641
+ }
1642
+ __CUDA_BF16_DECL__ void __stwb(__nv_bfloat16 *const ptr, const __nv_bfloat16 value)
1643
+ {
1644
+ asm ("st.global.wb.b16 [%0], %1;" :: __LDG_PTR(ptr), "h"(__BFLOAT16_TO_CUS(value)) : "memory");
1645
+ }
1646
+ __CUDA_BF16_DECL__ void __stcg(__nv_bfloat162 *const ptr, const __nv_bfloat162 value)
1647
+ {
1648
+ asm ("st.global.cg.b32 [%0], %1;" :: __LDG_PTR(ptr), "r"(__BFLOAT162_TO_CUI(value)) : "memory");
1649
+ }
1650
+ __CUDA_BF16_DECL__ void __stcg(__nv_bfloat16 *const ptr, const __nv_bfloat16 value)
1651
+ {
1652
+ asm ("st.global.cg.b16 [%0], %1;" :: __LDG_PTR(ptr), "h"(__BFLOAT16_TO_CUS(value)) : "memory");
1653
+ }
1654
+ __CUDA_BF16_DECL__ void __stcs(__nv_bfloat162 *const ptr, const __nv_bfloat162 value)
1655
+ {
1656
+ asm ("st.global.cs.b32 [%0], %1;" :: __LDG_PTR(ptr), "r"(__BFLOAT162_TO_CUI(value)) : "memory");
1657
+ }
1658
+ __CUDA_BF16_DECL__ void __stcs(__nv_bfloat16 *const ptr, const __nv_bfloat16 value)
1659
+ {
1660
+ asm ("st.global.cs.b16 [%0], %1;" :: __LDG_PTR(ptr), "h"(__BFLOAT16_TO_CUS(value)) : "memory");
1661
+ }
1662
+ __CUDA_BF16_DECL__ void __stwt(__nv_bfloat162 *const ptr, const __nv_bfloat162 value)
1663
+ {
1664
+ asm ("st.global.wt.b32 [%0], %1;" :: __LDG_PTR(ptr), "r"(__BFLOAT162_TO_CUI(value)) : "memory");
1665
+ }
1666
+ __CUDA_BF16_DECL__ void __stwt(__nv_bfloat16 *const ptr, const __nv_bfloat16 value)
1667
+ {
1668
+ asm ("st.global.wt.b16 [%0], %1;" :: __LDG_PTR(ptr), "h"(__BFLOAT16_TO_CUS(value)) : "memory");
1669
+ }
1670
+
1671
+ #undef __LDG_PTR
1672
+ #endif /*defined(__cplusplus) */
1673
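The load/store variants above expose PTX cache hints: the read-only (non-coherent) path for `__ldg`, and `.cg`/`.ca`/`.cs`/`.lu`/`.cv` for loads, `.wb`/`.cg`/`.cs`/`.wt` for stores. For one-pass streaming kernels the `.cs` (evict-first) pair avoids displacing cache lines that will be reused; a sketch under that assumption, with illustrative names and `__hmul2` from the arithmetic section below:

    #include <cuda_bf16.h>

    __global__ void stream_scale(const __nv_bfloat162 *__restrict__ in,
                                 __nv_bfloat162 *out, const __nv_bfloat162 s)
    {
        const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
        const __nv_bfloat162 v = __ldcs(&in[i]); // load, evict-first hint
        __stcs(&out[i], __hmul2(v, s));          // store, evict-first hint
    }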
+ /******************************************************************************
1674
+ * __nv_bfloat162 comparison *
1675
+ ******************************************************************************/
1676
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1677
+ #define __COMPARISON_OP_BFLOAT162_MACRO(name) {\
1678
+ __nv_bfloat162 val; \
1679
+ asm( "{ " __CUDA_BF16_STRINGIFY(name) ".bf16x2.bf16x2 %0,%1,%2;\n}" \
1680
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1681
+ return val; \
1682
+ }
1683
+ #else
1684
+ #define __COMPARISON_OP_BFLOAT162_MACRO(name) {\
1685
+ __nv_bfloat162 val; \
1686
+ asm( "{.reg .b32 low_a,low_b,high_a,high_b,high_res,low_res;\n"\
1687
+ " and.b32 high_a, %1, 0xffff0000U;\n"\
1688
+ " and.b32 high_b, %2, 0xffff0000U;\n"\
1689
+ " shl.b32 low_a, %1, 16;\n"\
1690
+ " shl.b32 low_b, %2, 16;\n"\
1691
+ " " __CUDA_BF16_STRINGIFY(name) ".f32.f32 low_res, low_a, low_b;\n"\
1692
+ " " __CUDA_BF16_STRINGIFY(name) ".f32.f32 high_res, high_a, high_b;\n"\
1693
+ " shr.u32 low_res, low_res, 16;\n"\
1694
+ " or.b32 %0, high_res, low_res;}\n"\
1695
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1696
+ return val; \
1697
+ }
1698
+ #endif
1699
+
1700
+ __CUDA_BF16_DECL__ __nv_bfloat162 __heq2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1701
+ {
1702
+ __COMPARISON_OP_BFLOAT162_MACRO(set.eq)
1703
+ }
1704
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hne2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1705
+ {
1706
+ __COMPARISON_OP_BFLOAT162_MACRO(set.ne)
1707
+ }
1708
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hle2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1709
+ {
1710
+ __COMPARISON_OP_BFLOAT162_MACRO(set.le)
1711
+ }
1712
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hge2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1713
+ {
1714
+ __COMPARISON_OP_BFLOAT162_MACRO(set.ge)
1715
+ }
1716
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hlt2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1717
+ {
1718
+ __COMPARISON_OP_BFLOAT162_MACRO(set.lt)
1719
+ }
1720
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hgt2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1721
+ {
1722
+ __COMPARISON_OP_BFLOAT162_MACRO(set.gt)
1723
+ }
1724
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hequ2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1725
+ {
1726
+ __COMPARISON_OP_BFLOAT162_MACRO(set.equ)
1727
+ }
1728
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hneu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1729
+ {
1730
+ __COMPARISON_OP_BFLOAT162_MACRO(set.neu)
1731
+ }
1732
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hleu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1733
+ {
1734
+ __COMPARISON_OP_BFLOAT162_MACRO(set.leu)
1735
+ }
1736
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hgeu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1737
+ {
1738
+ __COMPARISON_OP_BFLOAT162_MACRO(set.geu)
1739
+ }
1740
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hltu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1741
+ {
1742
+ __COMPARISON_OP_BFLOAT162_MACRO(set.ltu)
1743
+ }
1744
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hgtu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1745
+ {
1746
+ __COMPARISON_OP_BFLOAT162_MACRO(set.gtu)
1747
+ }
1748
+ #undef __COMPARISON_OP_BFLOAT162_MACRO
1749
+
1750
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1751
+ #define __BOOL_COMPARISON_OP_BFLOAT162_MACRO(name) {\
1752
+ __nv_bfloat162 val; \
1753
+ bool retval; \
1754
+ asm( "{ " __CUDA_BF16_STRINGIFY(name) ".bf16x2.bf16x2 %0,%1,%2;\n}" \
1755
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1756
+ if (__BFLOAT162_TO_CUI(val) == 0x3F803F80U) {\
1757
+ retval = true; \
1758
+ } else { \
1759
+ retval = false; \
1760
+ }\
1761
+ return retval;\
1762
+ }
1763
+ #else
1764
+
1765
+ #define __BOOL_COMPARISON_OP_BFLOAT162_MACRO(name) {\
1766
+ unsigned int val; \
1767
+ asm( "{.reg .b32 low_a,low_b,high_a,high_b,high_res,low_res;\n"\
1768
+ " and.b32 high_a, %1, 0xffff0000U;\n"\
1769
+ " and.b32 high_b, %2, 0xffff0000U;\n"\
1770
+ " shl.b32 low_a, %1, 16;\n"\
1771
+ " shl.b32 low_b, %2, 16;\n"\
1772
+ " " __CUDA_BF16_STRINGIFY(name) ".f32.f32 low_res, low_a, low_b;\n"\
1773
+ " " __CUDA_BF16_STRINGIFY(name) ".f32.f32 high_res, high_a, high_b;\n"\
1774
+ " and.b32 %0, high_res, low_res;}\n"\
1775
+ :"=r"(val) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1776
+ return (val != 0U) ? true : false; \
1777
+ }
1778
+ #endif
1779
+
1780
+ __CUDA_BF16_DECL__ bool __hbeq2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1781
+ {
1782
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.eq)
1783
+ }
1784
+ __CUDA_BF16_DECL__ bool __hbne2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1785
+ {
1786
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.ne)
1787
+ }
1788
+ __CUDA_BF16_DECL__ bool __hble2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1789
+ {
1790
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.le)
1791
+ }
1792
+ __CUDA_BF16_DECL__ bool __hbge2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1793
+ {
1794
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.ge)
1795
+ }
1796
+ __CUDA_BF16_DECL__ bool __hblt2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1797
+ {
1798
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.lt)
1799
+ }
1800
+ __CUDA_BF16_DECL__ bool __hbgt2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1801
+ {
1802
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.gt)
1803
+ }
1804
+ __CUDA_BF16_DECL__ bool __hbequ2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1805
+ {
1806
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.equ)
1807
+ }
1808
+ __CUDA_BF16_DECL__ bool __hbneu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1809
+ {
1810
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.neu)
1811
+ }
1812
+ __CUDA_BF16_DECL__ bool __hbleu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1813
+ {
1814
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.leu)
1815
+ }
1816
+ __CUDA_BF16_DECL__ bool __hbgeu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1817
+ {
1818
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.geu)
1819
+ }
1820
+ __CUDA_BF16_DECL__ bool __hbltu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1821
+ {
1822
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.ltu)
1823
+ }
1824
+ __CUDA_BF16_DECL__ bool __hbgtu2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1825
+ {
1826
+ __BOOL_COMPARISON_OP_BFLOAT162_MACRO(set.gtu)
1827
+ }
1828
+ #undef __BOOL_COMPARISON_OP_BFLOAT162_MACRO
1829
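Two comparison families are defined above: the `__h*2` forms return a per-lane mask (1.0 in each bf16 lane where the predicate holds, 0.0 where it does not), while the `__hb*2` forms collapse both lanes into a single `bool`. Masks compose directly with arithmetic; a sketch with illustrative names, assuming finite `g` so that 0*g is exactly 0, and using `__hge2`/`__hfma2` as defined in this header:

    #include <cuda_bf16.h>

    __global__ void masked_accumulate(const __nv_bfloat162 *x,
                                      const __nv_bfloat162 *thresh,
                                      const __nv_bfloat162 *g,
                                      __nv_bfloat162 *y)
    {
        const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
        const __nv_bfloat162 m = __hge2(x[i], thresh[i]); // 1.0 / 0.0 per lane
        y[i] = __hfma2(m, g[i], y[i]);                    // gated update
    }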
+ /******************************************************************************
1830
+ * __nv_bfloat16 comparison *
1831
+ ******************************************************************************/
1832
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1833
+ #define __COMPARISON_OP_BFLOAT16_MACRO(name) {\
1834
+ unsigned short val; \
1835
+ asm( "{ .reg .pred __$temp3;\n" \
1836
+ " setp." __CUDA_BF16_STRINGIFY(name) ".bf16 __$temp3, %1, %2;\n" \
1837
+ " selp.u16 %0, 1, 0, __$temp3;}" \
1838
+ : "=h"(val) : "h"(__BFLOAT16_TO_CUS(a)), "h"(__BFLOAT16_TO_CUS(b))); \
1839
+ return (val != 0U) ? true : false; \
1840
+ }
1841
+ #else
1842
+ #define __COMPARISON_OP_BFLOAT16_MACRO(name) {\
1843
+ unsigned int val; \
1844
+ asm( "{.reg .b32 a,b;\n"\
1845
+ " mov.b32 a, {0, %1};\n"\
1846
+ " mov.b32 b, {0, %2};\n"\
1847
+ " set." __CUDA_BF16_STRINGIFY(name) ".f32.f32 %0, a, b;}\n"\
1848
+ :"=r"(val) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b))); \
1849
+ return (val != 0U) ? true : false; \
1850
+ }
1851
+ #endif
1852
+ __CUDA_BF16_DECL__ bool __heq(const __nv_bfloat16 a, const __nv_bfloat16 b)
1853
+ {
1854
+ __COMPARISON_OP_BFLOAT16_MACRO(eq)
1855
+ }
1856
+ __CUDA_BF16_DECL__ bool __hne(const __nv_bfloat16 a, const __nv_bfloat16 b)
1857
+ {
1858
+ __COMPARISON_OP_BFLOAT16_MACRO(ne)
1859
+ }
1860
+ __CUDA_BF16_DECL__ bool __hle(const __nv_bfloat16 a, const __nv_bfloat16 b)
1861
+ {
1862
+ __COMPARISON_OP_BFLOAT16_MACRO(le)
1863
+ }
1864
+ __CUDA_BF16_DECL__ bool __hge(const __nv_bfloat16 a, const __nv_bfloat16 b)
1865
+ {
1866
+ __COMPARISON_OP_BFLOAT16_MACRO(ge)
1867
+ }
1868
+ __CUDA_BF16_DECL__ bool __hlt(const __nv_bfloat16 a, const __nv_bfloat16 b)
1869
+ {
1870
+ __COMPARISON_OP_BFLOAT16_MACRO(lt)
1871
+ }
1872
+ __CUDA_BF16_DECL__ bool __hgt(const __nv_bfloat16 a, const __nv_bfloat16 b)
1873
+ {
1874
+ __COMPARISON_OP_BFLOAT16_MACRO(gt)
1875
+ }
1876
+ __CUDA_BF16_DECL__ bool __hequ(const __nv_bfloat16 a, const __nv_bfloat16 b)
1877
+ {
1878
+ __COMPARISON_OP_BFLOAT16_MACRO(equ)
1879
+ }
1880
+ __CUDA_BF16_DECL__ bool __hneu(const __nv_bfloat16 a, const __nv_bfloat16 b)
1881
+ {
1882
+ __COMPARISON_OP_BFLOAT16_MACRO(neu)
1883
+ }
1884
+ __CUDA_BF16_DECL__ bool __hleu(const __nv_bfloat16 a, const __nv_bfloat16 b)
1885
+ {
1886
+ __COMPARISON_OP_BFLOAT16_MACRO(leu)
1887
+ }
1888
+ __CUDA_BF16_DECL__ bool __hgeu(const __nv_bfloat16 a, const __nv_bfloat16 b)
1889
+ {
1890
+ __COMPARISON_OP_BFLOAT16_MACRO(geu)
1891
+ }
1892
+ __CUDA_BF16_DECL__ bool __hltu(const __nv_bfloat16 a, const __nv_bfloat16 b)
1893
+ {
1894
+ __COMPARISON_OP_BFLOAT16_MACRO(ltu)
1895
+ }
1896
+ __CUDA_BF16_DECL__ bool __hgtu(const __nv_bfloat16 a, const __nv_bfloat16 b)
1897
+ {
1898
+ __COMPARISON_OP_BFLOAT16_MACRO(gtu)
1899
+ }
1900
+ #undef __COMPARISON_OP_BFLOAT16_MACRO
1901
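As with the vector forms, the scalar predicates above come in ordered and unordered flavors: `__hlt` is false whenever either operand is NaN, while `__hltu` is true in that case. One place the distinction matters is NaN-tolerant min/max; a sketch using `__hisnan` from further down in this header (function name illustrative):

    #include <cuda_bf16.h>

    __device__ __nv_bfloat16 min_ignore_nan(__nv_bfloat16 a, __nv_bfloat16 b)
    {
        if (__hisnan(a)) { return b; } // prefer the non-NaN operand
        if (__hisnan(b)) { return a; }
        return __hlt(a, b) ? a : b;    // ordered compare: both non-NaN here
    }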
+ /******************************************************************************
1902
+ * __nv_bfloat162 arithmetic *
1903
+ ******************************************************************************/
1904
+ #define __BINARY_OP_BFLOAT162_MACRO(name) /* do */ {\
1905
+ __nv_bfloat162 val; \
1906
+ asm( "{.reg .b32 low_a,low_b,high_a,high_b,high_res,low_res;\n"\
1907
+ " .reg .b16 low,high;\n"\
1908
+ " and.b32 high_a, %1, 0xffff0000U;\n"\
1909
+ " and.b32 high_b, %2, 0xffff0000U;\n"\
1910
+ " shl.b32 low_a, %1, 16;\n"\
1911
+ " shl.b32 low_b, %2, 16;\n"\
1912
+ " " __CUDA_BF16_STRINGIFY(name) ".f32 low_res, low_a, low_b;\n"\
1913
+ " " __CUDA_BF16_STRINGIFY(name) ".f32 high_res, high_a, high_b;\n"\
1914
+ " cvt.rn.bf16.f32 low, low_res;\n"\
1915
+ " cvt.rn.bf16.f32 high, high_res;\n"\
1916
+ " mov.b32 %0, {low,high};}\n"\
1917
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1918
+ return val; \
1919
+ } /* while(0) */
1920
+
1921
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hadd2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1922
+ {
1923
+ __nv_bfloat162 val;
1924
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1925
+ asm( "{ add.bf16x2 %0,%1,%2; }\n"
1926
+ #else
1927
+ asm( "{.reg .b32 c;\n"
1928
+ " mov.b32 c, 0x3f803f80U;\n"
1929
+ " fma.rn.bf16x2 %0,%1,c,%2;}\n"
1930
+ #endif
1931
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1932
+ return val;
1933
+ }
1934
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hsub2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1935
+ {
1936
+ __nv_bfloat162 val;
1937
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1938
+ asm( "{ sub.bf16x2 %0,%1,%2; }\n"
1939
+ #else
1940
+ asm( "{.reg .b32 c;\n"
1941
+ " mov.b32 c, 0xbf80bf80U;\n"
1942
+ " fma.rn.bf16x2 %0,%2,c,%1;}\n"
1943
+ #endif
1944
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1945
+ return val;
1946
+ }
1947
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hmul2(const __nv_bfloat162 a, const __nv_bfloat162 b)
1948
+ {
1949
+ __nv_bfloat162 val;
1950
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1951
+ asm( "{ mul.bf16x2 %0,%1,%2; }\n"
1952
+ #else
1953
+ asm( "{.reg .b32 c;\n"
1954
+ " mov.b32 c, 0x80008000U;\n"
1955
+ " fma.rn.bf16x2 %0,%1,%2,c;}\n"
1956
+ #endif
1957
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1958
+ return val;
1959
+ }
1960
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hadd2_rn(const __nv_bfloat162 a, const __nv_bfloat162 b)
1961
+ {
1962
+ __nv_bfloat162 val;
1963
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1964
+ asm( "{ add.rn.bf16x2 %0,%1,%2; }\n"
1965
+ #else
1966
+ asm( "{.reg .b32 c;\n"
1967
+ " mov.b32 c, 0x3f803f80U;\n"
1968
+ " fma.rn.bf16x2 %0,%1,c,%2;}\n"
1969
+ #endif
1970
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1971
+ return val;
1972
+ }
1973
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hsub2_rn(const __nv_bfloat162 a, const __nv_bfloat162 b)
1974
+ {
1975
+ __nv_bfloat162 val;
1976
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1977
+ asm( "{ sub.rn.bf16x2 %0,%1,%2; }\n"
1978
+ #else
1979
+ asm( "{.reg .b32 c;\n"
1980
+ " mov.b32 c, 0xbf80bf80U;\n"
1981
+ " fma.rn.bf16x2 %0,%2,c,%1;}\n"
1982
+ #endif
1983
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1984
+ return val;
1985
+ }
1986
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hmul2_rn(const __nv_bfloat162 a, const __nv_bfloat162 b)
1987
+ {
1988
+ __nv_bfloat162 val;
1989
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
1990
+ asm( "{ mul.rn.bf16x2 %0,%1,%2; }\n"
1991
+ #else
1992
+ asm( "{.reg .b32 c;\n"
1993
+ " mov.b32 c, 0x80008000U;\n"
1994
+ " fma.rn.bf16x2 %0,%1,%2,c;}\n"
1995
+ #endif
1996
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
1997
+ return val;
1998
+ }
1999
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hadd2_sat(const __nv_bfloat162 a, const __nv_bfloat162 b)
2000
+ {
2001
+ __nv_bfloat162 val;
2002
+ asm( "{.reg .b32 f, one, zero;\n"
2003
+ " mov.b32 one, 0x3f803f80U;\n"
2004
+ " mov.b32 zero, 0;\n"
2005
+ " fma.rn.bf16x2 f,%1,one,%2;\n"
2006
+ " max.bf16x2 f, f, zero;\n"
2007
+ " min.bf16x2 %0, f, one;\n}"
2008
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
2009
+ return val;
2010
+ }
2011
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hsub2_sat(const __nv_bfloat162 a, const __nv_bfloat162 b)
2012
+ {
2013
+ __nv_bfloat162 val;
2014
+ asm( "{.reg .b32 f, one, zero, mone;\n"
2015
+ " mov.b32 one, 0x3f803f80U;\n"
2016
+ " mov.b32 zero, 0;\n"
2017
+ " mov.b32 mone, 0xbf80bf80U;\n"
2018
+ " fma.rn.bf16x2 f,%2,mone,%1;\n"
2019
+ " max.bf16x2 f, f, zero;\n"
2020
+ " min.bf16x2 %0, f, one;\n}"
2021
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
2022
+ return val;
2023
+ }
2024
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hmul2_sat(const __nv_bfloat162 a, const __nv_bfloat162 b)
2025
+ {
2026
+ __nv_bfloat162 val;
2027
+ asm( "{.reg .b32 f, one, zero, mzero;\n"
2028
+ " mov.b32 one, 0x3f803f80U;\n"
2029
+ " mov.b32 zero, 0;\n"
2030
+ " mov.b32 mzero, 0x80008000U;\n"
2031
+ " fma.rn.bf16x2 f,%1,%2,mzero;\n"
2032
+ " max.bf16x2 f, f, zero;\n"
2033
+ " min.bf16x2 %0, f, one;\n}"
2034
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b))); \
2035
+ return val;
2036
+ }
2037
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hfma2(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c)
2038
+ {
2039
+ __nv_bfloat162 val;
2040
+ asm( "{fma.rn.bf16x2 %0,%1,%2,%3;\n}"
2041
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b)),"r"(__BFLOAT162_TO_CUI(c)));
2042
+ return val;
2043
+ }
2044
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hfma2_sat(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c)
2045
+ {
2046
+ __nv_bfloat162 val;
2047
+ asm( "{ .reg .b32 f, one, zero;\n"
2048
+ " mov.b32 one, 0x3f803f80U;\n"
2049
+ " mov.b32 zero, 0;\n"
2050
+ " fma.rn.bf16x2 f, %1, %2, %3;\n"
2051
+ " max.bf16x2 f, f, zero;\n"
2052
+ " min.bf16x2 %0, f, one;\n}"
2053
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b)),"r"(__BFLOAT162_TO_CUI(c)));
2054
+ return val;
2055
+ }
2056
+ __CUDA_BF16_DECL__ __nv_bfloat162 __h2div(const __nv_bfloat162 a, const __nv_bfloat162 b) {
2057
+ __nv_bfloat16 ha, hb;
2058
+
2059
+ ha = __low2bfloat16(a);
2060
+ hb = __low2bfloat16(b);
2061
+
2062
+ const __nv_bfloat16 v1 = __hdiv(ha, hb);
2063
+
2064
+ ha = __high2bfloat16(a);
2065
+ hb = __high2bfloat16(b);
2066
+
2067
+ const __nv_bfloat16 v2 = __hdiv(ha, hb);
2068
+
2069
+ return __halves2bfloat162(v1, v2);
2070
+ }
2071
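Note the shape of the packed arithmetic above: on sm_90+ each operation is a single `bf16x2` instruction, and on older devices it is synthesized from one `fma.rn.bf16x2` with a constant operand, so even the fallback is one instruction plus a constant load. A packed AXPY sketch (illustrative names) that updates two elements per thread via `__hfma2`:

    #include <cuda_bf16.h>

    __global__ void axpy_bf16x2(const __nv_bfloat162 a,
                                const __nv_bfloat162 *x,
                                __nv_bfloat162 *y,
                                const unsigned int n2) // n2 = number of pairs
    {
        const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n2) {
            y[i] = __hfma2(a, x[i], y[i]); // y = a*x + y, two lanes at once
        }
    }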
+ /******************************************************************************
2072
+ * __nv_bfloat16 arithmetic *
2073
+ ******************************************************************************/
2074
+ #define __BINARY_OP_BFLOAT16_MACRO(name) /* do */ {\
2075
+ __nv_bfloat16 val; \
2076
+ asm( "{.reg .b32 a,b,res;\n"\
2077
+ " mov.b32 a, {0,%1};\n"\
2078
+ " mov.b32 b, {0,%2};\n"\
2079
+ " " __CUDA_BF16_STRINGIFY(name) ".f32 res, a, b;\n"\
2080
+ " cvt.rn.bf16.f32 %0, res;}\n"\
2081
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b))); \
2082
+ return val; \
2083
+ } /* while(0) */
2084
+
2085
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hadd(const __nv_bfloat16 a, const __nv_bfloat16 b)
2086
+ {
2087
+ __nv_bfloat16 val;
2088
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
2089
+ asm( "{ add.bf16 %0,%1,%2; }\n"
2090
+ #else
2091
+ asm( "{.reg .b16 c;\n"
2092
+ " mov.b16 c, 0x3f80U;\n"
2093
+ " fma.rn.bf16 %0,%1,c,%2;}\n"
2094
+ #endif
2095
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b))); \
2096
+ return val;
2097
+ }
2098
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hsub(const __nv_bfloat16 a, const __nv_bfloat16 b)
2099
+ {
2100
+ __nv_bfloat16 val;
2101
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
2102
+ asm( "{ sub.bf16 %0,%1,%2; }\n"
2103
+ #else
2104
+ asm( "{.reg .b16 c;\n"
2105
+ " mov.b16 c, 0xbf80U;\n"
2106
+ " fma.rn.bf16 %0,%2,c,%1;}\n"
2107
+ #endif
2108
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b))); \
2109
+ return val;
2110
+ }
2111
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hmul(const __nv_bfloat16 a, const __nv_bfloat16 b)
2112
+ {
2113
+ __nv_bfloat16 val;
2114
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
2115
+ asm( "{ mul.bf16 %0,%1,%2; }\n"
2116
+ #else
2117
+ asm( "{.reg .b16 c;\n"
2118
+ " mov.b16 c, 0x8000U;\n"
2119
+ " fma.rn.bf16 %0,%1,%2,c;}\n"
2120
+ #endif
2121
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b))); \
2122
+ return val;
2123
+ }
2124
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hadd_rn(const __nv_bfloat16 a, const __nv_bfloat16 b)
2125
+ {
2126
+ __nv_bfloat16 val;
2127
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
2128
+ asm( "{ add.rn.bf16 %0,%1,%2; }\n"
2129
+ #else
2130
+ asm( "{.reg .b16 c;\n"
2131
+ " mov.b16 c, 0x3f80U;\n"
2132
+ " fma.rn.bf16 %0,%1,c,%2;}\n"
2133
+ #endif
2134
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b))); \
2135
+ return val;
2136
+ }
2137
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hsub_rn(const __nv_bfloat16 a, const __nv_bfloat16 b)
2138
+ {
2139
+ __nv_bfloat16 val;
2140
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
2141
+ asm( "{ sub.rn.bf16 %0,%1,%2; }\n"
2142
+ #else
2143
+ asm( "{.reg .b16 c;\n"
2144
+ " mov.b16 c, 0xbf80U;\n"
2145
+ " fma.rn.bf16 %0,%2,c,%1;}\n"
2146
+ #endif
2147
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b))); \
2148
+ return val;
2149
+ }
2150
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hmul_rn(const __nv_bfloat16 a, const __nv_bfloat16 b)
2151
+ {
2152
+ __nv_bfloat16 val;
2153
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
2154
+ asm( "{ mul.rn.bf16 %0,%1,%2; }\n"
2155
+ #else
2156
+ asm( "{.reg .b16 c;\n"
2157
+ " mov.b16 c, 0x8000U;\n"
2158
+ " fma.rn.bf16 %0,%1,%2,c;}\n"
2159
+ #endif
2160
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b))); \
2161
+ return val;
2162
+ }
2163
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hadd_sat(const __nv_bfloat16 a, const __nv_bfloat16 b)
2164
+ {
2165
+ __nv_bfloat16 val;
2166
+ asm( "{ .reg .b16 f, one, zero;\n"
2167
+ " mov.b16 one, 0x3f80U;\n"
2168
+ " mov.b16 zero, 0;\n"
2169
+ " fma.rn.bf16 f, %1, one, %2;\n"
2170
+ " max.bf16 f, f, zero;\n"
2171
+ " min.bf16 %0, f, one;\n}"
2172
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)));
2173
+ return val;
2174
+ }
2175
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hsub_sat(const __nv_bfloat16 a, const __nv_bfloat16 b)
2176
+ {
2177
+ __nv_bfloat16 val;
2178
+ asm( "{ .reg .b16 f, one, zero, mone;\n"
2179
+ " mov.b16 one, 0x3f80U;\n"
2180
+ " mov.b16 zero, 0;\n"
2181
+ " mov.b16 mone, 0xbf80U;\n"
2182
+ " fma.rn.bf16 f, %2, mone, %1;\n"
2183
+ " max.bf16 f, f, zero;\n"
2184
+ " min.bf16 %0, f, one;\n}"
2185
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)));
2186
+ return val;
2187
+ }
2188
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hmul_sat(const __nv_bfloat16 a, const __nv_bfloat16 b)
2189
+ {
2190
+ __nv_bfloat16 val;
2191
+ asm( "{ .reg .b16 f, one, zero, mzero;\n"
2192
+ " mov.b16 one, 0x3f80U;\n"
2193
+ " mov.b16 zero, 0;\n"
2194
+ " mov.b16 mzero, 0x8000U;\n"
2195
+ " fma.rn.bf16 f, %1, %2, mzero;\n"
2196
+ " max.bf16 f, f, zero;\n"
2197
+ " min.bf16 %0, f, one;\n}"
2198
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)));
2199
+ return val;
2200
+ }
2201
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hfma(const __nv_bfloat16 a, const __nv_bfloat16 b, const __nv_bfloat16 c)
2202
+ {
2203
+ __nv_bfloat16 val;
2204
+ asm( "{fma.rn.bf16 %0,%1,%2,%3;\n}"
2205
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)),"h"(__BFLOAT16_TO_CUS(c)));
2206
+ return val;
2207
+ }
2208
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hfma_sat(const __nv_bfloat16 a, const __nv_bfloat16 b, const __nv_bfloat16 c)
2209
+ {
2210
+ __nv_bfloat16 val;
2211
+ asm( "{ .reg .b16 f, one, zero;\n"
2212
+ " mov.b16 one, 0x3f80U;\n"
2213
+ " mov.b16 zero, 0;\n"
2214
+ " fma.rn.bf16 f, %1, %2, %3;\n"
2215
+ " max.bf16 f, f, zero;\n"
2216
+ " min.bf16 %0, f, one;\n}"
2217
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)),"h"(__BFLOAT16_TO_CUS(c)));
2218
+ return val;
2219
+ }
2220
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hdiv(const __nv_bfloat16 a, const __nv_bfloat16 b) {
2221
+ __BINARY_OP_BFLOAT16_MACRO(div.rn)
2222
+ }
2223
+
2224
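The scalar fallbacks above reduce everything to one `fma.rn.bf16`: `a + b` as `fma(a, 1.0, b)`, `a - b` as `fma(b, -1.0, a)`, and `a * b` as `fma(a, b, -0.0)`. The `-0.0` addend (bit pattern 0x8000) is what preserves the sign of a zero product; adding `+0.0` would turn a `-0` result into `+0`. A small kernel exercising the three operations (name illustrative):

    #include <cuda_bf16.h>

    __global__ void bf16_arith(const __nv_bfloat16 *a, const __nv_bfloat16 *b,
                               __nv_bfloat16 *out)
    {
        const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
        out[3 * i + 0] = __hadd(a[i], b[i]);
        out[3 * i + 1] = __hsub(a[i], b[i]);
        out[3 * i + 2] = __hmul(a[i], b[i]);
    }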
+ /******************************************************************************
2225
+ * __nv_bfloat162 functions *
2226
+ ******************************************************************************/
2227
+ #define __APPROX_FCAST(fun) /* do */ {\
2228
+ __nv_bfloat16 val;\
2229
+ asm("{.reg.b32 f; \n"\
2230
+ " .reg.b16 r; \n"\
2231
+ " mov.b16 r,%1; \n"\
2232
+ " mov.b32 f,{0,r}; \n"\
2233
+ " " __CUDA_BF16_STRINGIFY(fun) ".approx.f32 f,f; \n"\
2234
+ " cvt.rn.bf16.f32 r,f; \n"\
2235
+ " mov.b16 %0,r; \n"\
2236
+ "}": "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)));\
2237
+ return val;\
2238
+ } /* while(0) */
2239
+ #define __APPROX_FCAST2(fun) /* do */ {\
2240
+ __nv_bfloat162 val;\
2241
+ asm("{.reg.b16 hl, hu; \n"\
2242
+ " .reg.b32 fl, fu; \n"\
2243
+ " mov.b32 {hl, hu}, %1; \n"\
2244
+ " mov.b32 fl, {0,hl}; \n"\
2245
+ " mov.b32 fu, {0,hu}; \n"\
2246
+ " " __CUDA_BF16_STRINGIFY(fun) ".approx.f32 fl, fl; \n"\
2247
+ " " __CUDA_BF16_STRINGIFY(fun) ".approx.f32 fu, fu; \n"\
2248
+ " cvt.rn.bf16.f32 hl, fl; \n"\
2249
+ " cvt.rn.bf16.f32 hu, fu; \n"\
2250
+ " mov.b32 %0, {hl, hu}; \n"\
2251
+ "}":"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a))); \
2252
+ return val;\
2253
+ } /* while(0) */
2254
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hsin_internal(const __nv_bfloat16 a) {
2255
+ float f = __bfloat162float(a);
2256
+ f = sinf(f);
2257
+ return __float2bfloat16_rn(f);
2258
+ }
2259
+ __CUDA_BF16_DECL__ __nv_bfloat16 hsin(const __nv_bfloat16 a) {
2260
+ return __hsin_internal(a);
2261
+ }
2262
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2sin(const __nv_bfloat162 a) {
2263
+ const __nv_bfloat16 l = __low2bfloat16(a);
2264
+ const __nv_bfloat16 h = __high2bfloat16(a);
2265
+ return __halves2bfloat162(__hsin_internal(l), __hsin_internal(h));
2266
+ }
2267
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hcos_internal(const __nv_bfloat16 a) {
2268
+ float f = __bfloat162float(a);
2269
+ f = cosf(f);
2270
+ return __float2bfloat16_rn(f);
2271
+ }
2272
+ __CUDA_BF16_DECL__ __nv_bfloat16 hcos(const __nv_bfloat16 a) {
2273
+ return __hcos_internal(a);
2274
+ }
2275
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2cos(const __nv_bfloat162 a) {
2276
+ const __nv_bfloat16 l = __low2bfloat16(a);
2277
+ const __nv_bfloat16 h = __high2bfloat16(a);
2278
+ return __halves2bfloat162(__hcos_internal(l), __hcos_internal(h));
2279
+ }
2280
+
2281
+ #define __BF16_SPEC_CASE2(i,r, spc, ulp) \
2282
+ "{.reg.b32 spc, ulp, p;\n"\
2283
+ " mov.b32 spc," __CUDA_BF16_STRINGIFY(spc) ";\n"\
2284
+ " mov.b32 ulp," __CUDA_BF16_STRINGIFY(ulp) ";\n"\
2285
+ " set.eq.f16x2.f16x2 p," __CUDA_BF16_STRINGIFY(i) ", spc;\n"\
2286
+ " fma.rn.bf16x2 " __CUDA_BF16_STRINGIFY(r) ",p,ulp," __CUDA_BF16_STRINGIFY(r) ";\n}\n"
2287
+ #define __BF16_SPEC_CASE(i,r, spc, ulp) \
2288
+ "{.reg.b16 spc, ulp, p;\n"\
2289
+ " mov.b16 spc," __CUDA_BF16_STRINGIFY(spc) ";\n"\
2290
+ " mov.b16 ulp," __CUDA_BF16_STRINGIFY(ulp) ";\n"\
2291
+ " set.eq.f16.f16 p," __CUDA_BF16_STRINGIFY(i) ", spc;\n"\
2292
+ " fma.rn.bf16 " __CUDA_BF16_STRINGIFY(r) ",p,ulp," __CUDA_BF16_STRINGIFY(r) ";\n}\n"
2293
+
2294
+ __CUDA_BF16_DECL__ __nv_bfloat16 hexp(const __nv_bfloat16 a) {
2295
+ __nv_bfloat16 val;
2296
+ asm("{.reg.b32 f, C; \n"
2297
+ " .reg.b16 h,r; \n"
2298
+ " mov.b16 h,%1; \n"
2299
+ " mov.b32 f,{0,h}; \n"
2300
+ " mov.b32 C, 0x3FB8AA3CU; \n"
2301
+ " mul.f32 f,f,C; \n"
2302
+ " ex2.approx.f32 f,f; \n"
2303
+ " cvt.rn.bf16.f32 r,f; \n"
2304
+ " mov.b16 %0,r; \n"
2305
+ "}": "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)));
2306
+ return val;
2307
+ }
2308
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2exp(const __nv_bfloat162 a) {
2309
+ __nv_bfloat162 val;
2310
+ asm("{.reg.b16 hl, hu; \n"
2311
+ " .reg.b32 h,r,fl,fu, C; \n"
2312
+ " mov.b32 {hl, hu}, %1; \n"
2313
+ " mov.b32 h, %1; \n"
2314
+ " mov.b32 fl, {0,hl}; \n"
2315
+ " mov.b32 fu, {0,hu}; \n"
2316
+ " mov.b32 C, 0x3FB8AA3CU; \n"
2317
+ " mul.f32 fl,fl,C; \n"
2318
+ " mul.f32 fu,fu,C; \n"
2319
+ " ex2.approx.f32 fl, fl; \n"
2320
+ " ex2.approx.f32 fu, fu; \n"
2321
+ " cvt.rn.bf16.f32 hl, fl; \n"
2322
+ " cvt.rn.bf16.f32 hu, fu; \n"
2323
+ " mov.b32 r, {hl, hu}; \n"
2324
+ " mov.b32 %0, r; \n"
2325
+ "}":"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)));
2326
+ return val;
2327
+ }
2328
+ __CUDA_BF16_DECL__ __nv_bfloat16 hexp2(const __nv_bfloat16 a) {
2329
+ __APPROX_FCAST(ex2)
2330
+ }
2331
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2exp2(const __nv_bfloat162 a) {
2332
+ __APPROX_FCAST2(ex2)
2333
+ }
2334
+ __CUDA_BF16_DECL__ __nv_bfloat16 hexp10(const __nv_bfloat16 a) {
2335
+ __nv_bfloat16 val;
2336
+ asm("{.reg.b16 h, r; \n"
2337
+ " .reg.b32 f, C; \n"
2338
+ " mov.b16 h, %1; \n"
2339
+ " mov.b32 f, {0,h}; \n"
2340
+ " mov.b32 C, 0x40549A78U; \n"
2341
+ " mul.f32 f,f,C; \n"
2342
+ " ex2.approx.f32 f, f; \n"
2343
+ " cvt.rn.bf16.f32 r, f; \n"
2344
+ __BF16_SPEC_CASE(%1, r, 0xBC95U,0xBF00U)
2345
+ " mov.b16 %0, r; \n"
2346
+ "}":"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)));
2347
+ return val;
2348
+ }
2349
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2exp10(const __nv_bfloat162 a) {
2350
+ __nv_bfloat162 val;
2351
+ asm("{.reg.b16 hl, hu; \n"
2352
+ " .reg.b32 h,r,fl,fu, C; \n"
2353
+ " mov.b32 {hl, hu}, %1; \n"
2354
+ " mov.b32 fl, {0,hl}; \n"
2355
+ " mov.b32 fu, {0,hu}; \n"
2356
+ " mov.b32 C, 0x40549A78U; \n"
2357
+ " mul.f32 fl,fl,C; \n"
2358
+ " mul.f32 fu,fu,C; \n"
2359
+ " ex2.approx.f32 fl, fl; \n"
2360
+ " ex2.approx.f32 fu, fu; \n"
2361
+ " cvt.rn.bf16.f32 hl, fl; \n"
2362
+ " cvt.rn.bf16.f32 hu, fu; \n"
2363
+ " mov.b32 r, {hl, hu}; \n"
2364
+ __BF16_SPEC_CASE2(%1, r, 0xBC95BC95U,0xBF00BF00U)
2365
+ " mov.b32 %0, r; \n"
2366
+ "}":"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)));
2367
+ return val;
2368
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 hlog2(const __nv_bfloat16 a) {
+ __APPROX_FCAST(lg2)
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2log2(const __nv_bfloat162 a) {
+ __APPROX_FCAST2(lg2)
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 hlog(const __nv_bfloat16 a) {
+ __nv_bfloat16 val;
+ asm("{.reg.b32 f, C; \n"
+ " .reg.b16 r,h; \n"
+ " mov.b16 h,%1; \n"
+ " mov.b32 f,{0,h}; \n"
+ " lg2.approx.f32 f,f; \n"
+ " mov.b32 C, 0x3f317218U; \n"
+ " mul.f32 f,f,C; \n"
+ " cvt.rn.bf16.f32 r,f; \n"
+ " mov.b16 %0,r; \n"
+ "}": "=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2log(const __nv_bfloat162 a) {
+ __nv_bfloat162 val;
+ asm("{.reg.b16 hl, hu; \n"
+ " .reg.b32 r, fl, fu, C, h; \n"
+ " mov.b32 {hl, hu}, %1; \n"
+ " mov.b32 h, %1; \n"
+ " mov.b32 fl, {0,hl}; \n"
+ " mov.b32 fu, {0,hu}; \n"
+ " lg2.approx.f32 fl, fl; \n"
+ " lg2.approx.f32 fu, fu; \n"
+ " mov.b32 C, 0x3f317218U; \n"
+ " mul.f32 fl,fl,C; \n"
+ " mul.f32 fu,fu,C; \n"
+ " cvt.rn.bf16.f32 hl, fl; \n"
+ " cvt.rn.bf16.f32 hu, fu; \n"
+ " mov.b32 r, {hl, hu}; \n"
+ " mov.b32 %0, r; \n"
+ "}":"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 hlog10(const __nv_bfloat16 a) {
+ __nv_bfloat16 val;
+ asm("{.reg.b16 h, r; \n"
+ " .reg.b32 f, C; \n"
+ " mov.b16 h, %1; \n"
+ " mov.b32 f, {0,h}; \n"
+ " lg2.approx.f32 f, f; \n"
+ " mov.b32 C, 0x3E9A209BU; \n"
+ " mul.f32 f,f,C; \n"
+ " cvt.rn.bf16.f32 r, f; \n"
+ " mov.b16 %0, r; \n"
+ "}":"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2log10(const __nv_bfloat162 a) {
+ __nv_bfloat162 val;
+ asm("{.reg.b16 hl, hu; \n"
+ " .reg.b32 r, fl, fu, C, h; \n"
+ " mov.b32 {hl, hu}, %1; \n"
+ " mov.b32 h, %1; \n"
+ " mov.b32 fl, {0,hl}; \n"
+ " mov.b32 fu, {0,hu}; \n"
+ " lg2.approx.f32 fl, fl; \n"
+ " lg2.approx.f32 fu, fu; \n"
+ " mov.b32 C, 0x3E9A209BU; \n"
+ " mul.f32 fl,fl,C; \n"
+ " mul.f32 fu,fu,C; \n"
+ " cvt.rn.bf16.f32 hl, fl; \n"
+ " cvt.rn.bf16.f32 hu, fu; \n"
+ " mov.b32 r, {hl, hu}; \n"
+ " mov.b32 %0, r; \n"
+ "}":"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)));
+ return val;
+ }
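
All of the exp/log routines above reduce to the hardware ex2.approx/lg2.approx instructions plus a single f32 multiply by a change-of-base factor loaded as raw bits: 0x3FB8AA3C ≈ log2(e) for hexp/h2exp, 0x40549A78 ≈ log2(10) for hexp10/h2exp10, 0x3f317218 ≈ ln(2) for hlog/h2log, and 0x3E9A209B ≈ log10(2) for hlog10/h2log10. A small host-side sanity check of those bit patterns (a sketch, standard C++ only):

// Check that the magic constants are the expected change-of-base
// factors, reinterpreted bit-exactly as 32-bit floats.
#include <cstdint>
#include <cstring>
#include <cstdio>
#include <cmath>

static float from_bits(uint32_t w) {
    float f;
    std::memcpy(&f, &w, sizeof f);
    return f;
}

int main() {
    // exp(x)  = 2^(x*log2 e),  exp10(x) = 2^(x*log2 10)
    // log(x)  = lg2(x)*ln 2,   log10(x) = lg2(x)*log10 2
    printf("%.7f ~ %.7f\n", from_bits(0x3FB8AA3Cu), (float)(1.0 / std::log(2.0))); // log2(e)
    printf("%.7f ~ %.7f\n", from_bits(0x40549A78u), (float)std::log2(10.0));       // log2(10)
    printf("%.7f ~ %.7f\n", from_bits(0x3f317218u), (float)std::log(2.0));         // ln(2)
    printf("%.7f ~ %.7f\n", from_bits(0x3E9A209Bu), (float)std::log10(2.0));       // log10(2)
    return 0;
}
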
+ #undef __BF16_SPEC_CASE2
+ #undef __BF16_SPEC_CASE
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2rcp(const __nv_bfloat162 a) {
+ __APPROX_FCAST2(rcp)
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 hrcp(const __nv_bfloat16 a) {
+ __APPROX_FCAST(rcp)
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2rsqrt(const __nv_bfloat162 a) {
+ __APPROX_FCAST2(rsqrt)
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 hrsqrt(const __nv_bfloat16 a) {
+ __APPROX_FCAST(rsqrt)
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 h2sqrt(const __nv_bfloat162 a) {
+ __APPROX_FCAST2(sqrt)
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 hsqrt(const __nv_bfloat16 a) {
+ __APPROX_FCAST(sqrt)
+ }
+ #undef __APPROX_FCAST
+ #undef __APPROX_FCAST2
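
hrcp/hrsqrt/hsqrt and their bf16x2 twins expand the __APPROX_FCAST / __APPROX_FCAST2 macros defined earlier in this header: widen bfloat16 to f32, run the approximate f32 instruction, and round back with cvt.rn.bf16.f32. A standalone sketch of what __APPROX_FCAST(sqrt) plausibly expands to, using the public __bfloat16_as_ushort/__ushort_as_bfloat16 accessors instead of the header-internal macros:

// Plausible expansion of __APPROX_FCAST(sqrt): widen, approximate, narrow.
// Sketch only; the real macro pastes the opcode name into this shape.
#include <cuda_bf16.h>

__device__ __nv_bfloat16 hsqrt_sketch(const __nv_bfloat16 a) {
    unsigned short in = __bfloat16_as_ushort(a);
    unsigned short out;
    asm("{.reg.b32 f;          \n"
        " .reg.b16 r,h;        \n"
        "  mov.b16 h,%1;       \n"
        "  mov.b32 f,{0,h};    \n"  // bf16 bits -> high half of an f32
        "  sqrt.approx.f32 f,f;\n"
        "  cvt.rn.bf16.f32 r,f;\n"
        "  mov.b16 %0,r;       \n"
        "}" : "=h"(out) : "h"(in));
    return __ushort_as_bfloat16(out);
}
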
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hisnan2(const __nv_bfloat162 a)
+ {
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat162 r;
+ asm("{set.nan.bf16x2.bf16x2 %0,%1,%1;\n}"
+ :"=r"(__BFLOAT162_TO_UI(r)) : "r"(__BFLOAT162_TO_CUI(a)));
+ return r;
+ #else
+ const __nv_bfloat162 b = a;
+ __BINARY_OP_BFLOAT162_MACRO(set.nan.f32)
+ #endif
+ }
+ __CUDA_BF16_DECL__ bool __hisnan(const __nv_bfloat16 a)
+ {
+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat16 r;
+ asm("{set.nan.bf16.bf16 %0,%1,%1;\n}"
+ :"=h"(__BFLOAT16_TO_US(r)) : "h"(__BFLOAT16_TO_CUS(a)));
+ return __BFLOAT16_TO_CUS(r) != 0U;
+ #else
+ unsigned int r;
+ asm( "{.reg .b32 a;\n"
+ " mov.b32 a, {0,%1};\n"
+ " set.nan.f32.f32 %0, a, a;}\n"
+ :"=r"(r) : "h"(__BFLOAT16_TO_CUS(a)));
+ return r != 0U;
+ #endif
+ }
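
Both pre-sm_90 fallback paths above exploit the fact that bfloat16 is the top 16 bits of an IEEE-754 binary32: `mov.b32 a, {0,%1}` places the bfloat16 bits in the high half of a 32-bit register, reconstructing the exact f32 value, after which an ordinary set.nan.f32 comparison works. The same trick on the host (a sketch, standard C++ only):

// bf16 -> f32 widening: shifting the 16 payload bits into the top of a
// 32-bit word yields the exact float value, NaNs included.
#include <cstdint>
#include <cstring>
#include <cmath>
#include <cassert>

static float widen_bf16(uint16_t h) {
    uint32_t w = (uint32_t)h << 16;   // same effect as mov.b32 f,{0,h}
    float f;
    std::memcpy(&f, &w, sizeof f);
    return f;
}

int main() {
    assert(widen_bf16(0x3F80u) == 1.0f);      // 1.0 in bfloat16
    assert(widen_bf16(0xBF00u) == -0.5f);     // -0.5 in bfloat16
    assert(std::isnan(widen_bf16(0x7FC0u)));  // a bf16 quiet NaN widens to an f32 NaN
    return 0;
}
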
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hneg2(const __nv_bfloat162 a)
+ {
+ __nv_bfloat162 r;
+ asm("{neg.bf16x2 %0,%1;\n}"
+ :"=r"(__BFLOAT162_TO_UI(r)) : "r"(__BFLOAT162_TO_CUI(a)));
+ return r;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hneg(const __nv_bfloat16 a)
+ {
+ __nv_bfloat16 r;
+ asm("{neg.bf16 %0,%1;\n}"
+ :"=h"(__BFLOAT16_TO_US(r)) : "h"(__BFLOAT16_TO_CUS(a)));
+ return r;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 __habs2(const __nv_bfloat162 a)
+ {
+ __nv_bfloat162 r;
+ asm("{abs.bf16x2 %0,%1;\n}"
+ :"=r"(__BFLOAT162_TO_UI(r)) : "r"(__BFLOAT162_TO_CUI(a)));
+ return r;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 __habs(const __nv_bfloat16 a)
+ {
+ __nv_bfloat16 r;
+ asm("{abs.bf16 %0,%1;\n}"
+ :"=h"(__BFLOAT16_TO_US(r)) : "h"(__BFLOAT16_TO_CUS(a)));
+ return r;
+ }
+ /******************************************************************************
+ * __nv_bfloat16 arithmetic *
+ ******************************************************************************/
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hmax(const __nv_bfloat16 a, const __nv_bfloat16 b)
+ {
+ __nv_bfloat16 val;
+ asm( "{ max.bf16 %0,%1,%2;\n}"
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hmin(const __nv_bfloat16 a, const __nv_bfloat16 b)
+ {
+ __nv_bfloat16 val;
+ asm( "{ min.bf16 %0,%1,%2;\n}"
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hmax_nan(const __nv_bfloat16 a, const __nv_bfloat16 b)
+ {
+ __nv_bfloat16 val;
+ asm( "{ max.NaN.bf16 %0,%1,%2;\n}"
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hmin_nan(const __nv_bfloat16 a, const __nv_bfloat16 b)
+ {
+ __nv_bfloat16 val;
+ asm( "{ min.NaN.bf16 %0,%1,%2;\n}"
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)));
+ return val;
+ }
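
__hmax/__hmin follow the IEEE-754 maxNum/minNum convention: if exactly one operand is NaN, the numeric operand is returned. The .NaN-suffixed variants instead propagate NaN whenever either input is NaN, which matters when a reduction must not silently discard NaNs. A device-side sketch of the difference (launch with a single thread; assumes cuda_bf16.h):

#include <cuda_bf16.h>

__global__ void max_semantics(__nv_bfloat16 *out) {
    const __nv_bfloat16 one = __float2bfloat16(1.0f);
    const __nv_bfloat16 nan = __float2bfloat16(__int_as_float(0x7FC00000));
    out[0] = __hmax(one, nan);      // maxNum: the numeric operand, 1.0
    out[1] = __hmax_nan(one, nan);  // NaN-propagating: canonical NaN
}
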
+ __CUDA_BF16_DECL__ __nv_bfloat16 __hfma_relu(const __nv_bfloat16 a, const __nv_bfloat16 b, const __nv_bfloat16 c)
+ {
+ __nv_bfloat16 val;
+ asm( "{ fma.rn.relu.bf16 %0,%1,%2,%3;\n}"
+ :"=h"(__BFLOAT16_TO_US(val)) : "h"(__BFLOAT16_TO_CUS(a)),"h"(__BFLOAT16_TO_CUS(b)),"h"(__BFLOAT16_TO_CUS(c)));
+ return val;
+ }
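
__hfma_relu fuses the multiply-add with clamping of negative results to zero, saving the separate max against zero that a bias-plus-ReLU epilogue would otherwise need. Up to NaN edge cases of the fused instruction, it computes the following (a sketch using the scalar helpers from this header):

#include <cuda_bf16.h>

// Reference behavior of __hfma_relu: relu(a*b + c), NaN handling of the
// fused form deliberately left aside.
__device__ __nv_bfloat16 hfma_relu_ref(__nv_bfloat16 a, __nv_bfloat16 b,
                                       __nv_bfloat16 c) {
    const __nv_bfloat16 zero = __float2bfloat16(0.0f);
    return __hmax(__hfma(a, b, c), zero);
}
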
+ /******************************************************************************
+ * __nv_bfloat162 arithmetic *
+ ******************************************************************************/
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hmax2(const __nv_bfloat162 a, const __nv_bfloat162 b)
+ {
+ __nv_bfloat162 val;
+ asm( "{ max.bf16x2 %0,%1,%2;\n}"
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hmin2(const __nv_bfloat162 a, const __nv_bfloat162 b)
+ {
+ __nv_bfloat162 val;
+ asm( "{ min.bf16x2 %0,%1,%2;\n}"
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hmax2_nan(const __nv_bfloat162 a, const __nv_bfloat162 b)
+ {
+ __nv_bfloat162 val;
+ asm( "{ max.NaN.bf16x2 %0,%1,%2;\n}"
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hmin2_nan(const __nv_bfloat162 a, const __nv_bfloat162 b)
+ {
+ __nv_bfloat162 val;
+ asm( "{ min.NaN.bf16x2 %0,%1,%2;\n}"
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b)));
+ return val;
+ }
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hfma2_relu(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c)
+ {
+ __nv_bfloat162 val;
+ asm( "{ fma.rn.relu.bf16x2 %0,%1,%2,%3;\n}"
+ :"=r"(__BFLOAT162_TO_UI(val)) : "r"(__BFLOAT162_TO_CUI(a)),"r"(__BFLOAT162_TO_CUI(b)),"r"(__BFLOAT162_TO_CUI(c)));
+ return val;
+ }
+
+ __CUDA_BF16_DECL__ __nv_bfloat162 __hcmadd(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c)
+ {
+ // fast version of complex multiply-accumulate
+ // (a.re, a.im) * (b.re, b.im) + (c.re, c.im)
+ // acc.re = (c.re + a.re*b.re) - a.im*b.im
+ // acc.im = (c.im + a.re*b.im) + a.im*b.re
+ __nv_bfloat16 real_tmp = __hfma(a.x, b.x, c.x);
+ __nv_bfloat16 img_tmp = __hfma(a.x, b.y, c.y);
+ real_tmp = __hfma(__hneg(a.y), b.y, real_tmp);
+ img_tmp = __hfma(a.y, b.x, img_tmp);
+ return make_bfloat162(real_tmp, img_tmp);
+ }
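
__hcmadd packs one complex number per __nv_bfloat162 (.x real, .y imaginary) and, per the comments above, evaluates each component as a chain of two fused multiply-adds, so each output rounds at most twice. A usage sketch, a serial complex dot product (array names are illustrative):

#include <cuda_bf16.h>

__global__ void cdot(const __nv_bfloat162 *a, const __nv_bfloat162 *b,
                     __nv_bfloat162 *out, int n) {
    if (blockIdx.x != 0 || threadIdx.x != 0) return;  // single-thread sketch
    __nv_bfloat162 acc = __float2bfloat162_rn(0.0f);  // (0, 0)
    for (int i = 0; i < n; ++i) {
        acc = __hcmadd(a[i], b[i], acc);  // acc += a[i] * b[i] (complex)
    }
    *out = acc;
}
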
+
+
+ /* Define __PTR for atomicAdd prototypes below, undef after done */
+ #if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
+ #define __PTR "l"
+ #else
+ #define __PTR "r"
+ #endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
+
+ __CUDA_BF16_DECL__ __nv_bfloat162 atomicAdd(__nv_bfloat162 *const address, const __nv_bfloat162 val)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat162 r;
+ asm volatile ("{ atom.add.noftz.bf16x2 %0,[%1],%2; }\n"
+ : "=r"(__BFLOAT162_TO_UI(r)) : __PTR(address), "r"(__BFLOAT162_TO_CUI(val))
+ : "memory");
+ return r;
+ #else
+ unsigned int* address_as_uint = (unsigned int*)address;
+ unsigned int old = *address_as_uint, assumed;
+ do {
+ assumed = old;
+ __nv_bfloat162 new_val = __hadd2(val, *(__nv_bfloat162*)&assumed);
+ old = atomicCAS(address_as_uint, assumed, *(unsigned int*)&new_val);
+ } while (assumed != old);
+ return *(__nv_bfloat162*)&old;
+ #endif
+ }
+
+ __CUDA_BF16_DECL__ __nv_bfloat16 atomicAdd(__nv_bfloat16 *const address, const __nv_bfloat16 val)
+ {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+ __nv_bfloat16 r;
+ asm volatile ("{ atom.add.noftz.bf16 %0,[%1],%2; }\n"
+ : "=h"(__BFLOAT16_TO_US(r))
+ : __PTR(address), "h"(__BFLOAT16_TO_CUS(val))
+ : "memory");
+ return r;
+ #else
+ unsigned short int* address_as_us = (unsigned short int*)address;
+ unsigned short int old = *address_as_us, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_us, assumed,
+ __bfloat16_as_ushort(__hadd(val, __ushort_as_bfloat16(assumed))));
+ } while (assumed != old);
+ return __ushort_as_bfloat16(old);
+ #endif
+ }
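
The native atom.add.noftz.bf16/bf16x2 path requires sm_90; on earlier architectures the overloads fall back to the classic atomicCAS retry loop on the containing 16- or 32-bit word. When values come in pairs, the __nv_bfloat162 overload halves the number of atomic operations. A usage sketch, accumulating values into bfloat16 bins (names are illustrative):

#include <cuda_bf16.h>

__global__ void accumulate(__nv_bfloat16 *bins, const int *bin_of,
                           const __nv_bfloat16 *vals, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // On sm_90+ this is a single atom.add.noftz.bf16; otherwise the
        // CAS loop above emulates it on the containing 16-bit word.
        atomicAdd(&bins[bin_of[i]], vals[i]);
    }
}
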
+
+ #undef __PTR
+ #undef __CUDA_BF16_DECL__
+ #endif /* defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) */
+ #endif /* defined(__cplusplus) */
+
+ #undef __BINARY_OP_BFLOAT162_MACRO
+ #undef __BINARY_OP_BFLOAT16_MACRO
+
+ #undef __CUDA_HOSTDEVICE_BF16_DECL__
+ #undef __CUDA_BF16_DECL__
+
+ /* Define first-class types "nv_bfloat16" and "nv_bfloat162", unless user specifies otherwise via "#define CUDA_NO_BFLOAT16" */
+ /* C cannot ever have these types defined here, because __nv_bfloat16 and __nv_bfloat162 are C++ classes */
+ #if defined(__cplusplus) && !defined(CUDA_NO_BFLOAT16)
+ typedef __nv_bfloat16 nv_bfloat16;
+ typedef __nv_bfloat162 nv_bfloat162;
+
+ #endif /* defined(__cplusplus) && !defined(CUDA_NO_BFLOAT16) */
+
+ #if defined(__CPP_VERSION_AT_LEAST_11_BF16)
+ #undef __CPP_VERSION_AT_LEAST_11_BF16
+ #endif /* defined(__CPP_VERSION_AT_LEAST_11_BF16) */
+
+ #endif /* end of include guard: __CUDA_BF16_HPP__ */