triton-windows 3.1.0.post17__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (248)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
File: triton/language/extra/cuda/utils.py
@@ -0,0 +1,109 @@
1
+ from triton.language import core
2
+
3
+
4
+ @core.extern
5
+ def globaltimer(_builder=None):
6
+ return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
7
+ _builder=_builder)
8
+
9
+
10
+ @core.extern
11
+ def smid(_builder=None):
12
+ return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
13
+ _builder=_builder)
14
+
15
+
16
+ @core.builtin
17
+ def num_threads(_builder=None):
18
+ return core.constexpr(_builder.options.num_warps * 32)
19
+
20
+
21
+ @core.builtin
22
+ def num_warps(_builder=None):
23
+ return core.constexpr(_builder.options.num_warps)
24
+
25
+
26
+ # ----- FP8E4M3B15 ------
27
+ # This data-type is a variant of the standard FP8E4M3 format.
28
+ # It was designed for fast software conversion to FP16 on
29
+ # nvidia GPUs that do not support it natively.
30
+ # This is the same format as FP8E4M3Nv, but:
31
+ # - the exponent bias is 15 instead of 7
32
+ # - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
33
+ @core.builtin
34
+ def convert_fp8e4b15_to_float16(arg, _builder=None):
35
+ return core.inline_asm_elementwise(
36
+ "{ \n"
37
+ ".reg .b32 a<2>, b<2>; \n"
38
+ "prmt.b32 a0, 0, $2, 0x5746; \n"
39
+ "and.b32 b0, a0, 0x7f007f00; \n"
40
+ "and.b32 b1, a0, 0x00ff00ff; \n"
41
+ "and.b32 a1, a0, 0x00800080; \n"
42
+ "shr.b32 b0, b0, 1; \n"
43
+ "add.u32 b1, b1, a1; \n"
44
+ "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
45
+ "shl.b32 $1, b1, 7; \n"
46
+ "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4,
47
+ _builder=_builder)
48
+
49
+
50
+ @core.builtin
51
+ def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None):
52
+ asm = """{
53
+ .reg .pred p<4>;
54
+ .reg .b32 a<2>, b<2>;
55
+ .reg .b16 c<4>;
56
+ .reg .b16 max_val_f16;
57
+ .reg .b32 max_val_f16x2;
58
+ mov.b16 max_val_f16, 0x3F00;
59
+ mov.b32 max_val_f16x2, 0x3F003F00;
60
+ and.b32 a0, $1, 0x7fff7fff;
61
+ and.b32 a1, $2, 0x7fff7fff;"""
62
+ if has_minx2:
63
+ asm += """min.f16x2 a0, a0, max_val_f16x2;
64
+ min.f16x2 a1, a1, max_val_f16x2;"""
65
+ else:
66
+ asm += """setp.lt.f16x2 p0|p1, a0, max_val_f16x2;
67
+ setp.lt.f16x2 p2|p3, a1, max_val_f16x2;
68
+ mov.b32 {c0, c1}, a0;
69
+ mov.b32 {c2, c3}, a1;
70
+ selp.b16 c0, c0, max_val_f16, p0;
71
+ selp.b16 c1, c1, max_val_f16, p1;
72
+ selp.b16 c2, c2, max_val_f16, p2;
73
+ selp.b16 c3, c3, max_val_f16, p3;
74
+ mov.b32 a0, {c0, c1};
75
+ mov.b32 a1, {c2, c3};"""
76
+ asm += """mad.lo.u32 a0, a0, 2, 0x00800080;
77
+ mad.lo.u32 a1, a1, 2, 0x00800080;
78
+ lop3.b32 b0, $1, 0x80008000, a0, 0xea;
79
+ lop3.b32 b1, $2, 0x80008000, a1, 0xea;
80
+ prmt.b32 $0, b0, b1, 0x7531;
81
+ }"""
82
+ return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4,
83
+ _builder=_builder)
84
+
85
+
86
+ @core.builtin
87
+ def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None):
88
+ if arg.type.scalar.is_fp8e4b15():
89
+ upcast_val = convert_fp8e4b15_to_float16(arg, _builder=_builder)
90
+ if dst_ty.scalar.is_fp32():
91
+ upcast_val = upcast_val.to(core.float32, _builder=_builder)
92
+ return upcast_val
93
+
94
+ assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32()
95
+ downcast_val = arg
96
+ if arg.type.scalar.is_fp32():
97
+ downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder)
98
+ downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _builder=_builder)
99
+ return downcast_val
100
+
101
+
102
+ @core.builtin
103
+ def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None):
104
+ return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder)
105
+
106
+
107
+ @core.builtin
108
+ def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None):
109
+ return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder)
File: triton/language/extra/hip/__init__.py
@@ -0,0 +1,3 @@
1
+ from . import libdevice
2
+
3
+ __all__ = ["libdevice"]
File: triton/language/extra/hip/libdevice.py
@@ -0,0 +1,468 @@
1
+ from triton.language import core
2
+
3
+
4
+ @core.extern
5
+ def abs(arg0, _builder=None):
6
+ return core.extern_elementwise(
7
+ "", "", [arg0], {
8
+ (core.dtype("int32"), ): ("__triton_hip_iabs", core.dtype("int32")),
9
+ (core.dtype("int64"), ): ("__triton_hip_iabs", core.dtype("int64")),
10
+ (core.dtype("fp32"), ): ("__triton_hip_fabs", core.dtype("fp32")),
11
+ (core.dtype("fp64"), ): ("__triton_hip_fabs", core.dtype("fp64")),
12
+ }, is_pure=True, _builder=_builder)
13
+
14
+
15
+ @core.extern
16
+ def floor(arg0, _builder=None):
17
+ return core.extern_elementwise(
18
+ "", "", [arg0], {
19
+ (core.dtype("fp32"), ): ("__ocml_floor_f32", core.dtype("fp32")),
20
+ (core.dtype("fp64"), ): ("__ocml_floor_f64", core.dtype("fp64")),
21
+ }, is_pure=True, _builder=_builder)
22
+
23
+
24
+ @core.extern
25
+ def rsqrt(arg0, _builder=None):
26
+ return core.extern_elementwise(
27
+ "", "", [arg0], {
28
+ (core.dtype("fp32"), ): ("__ocml_rsqrt_f32", core.dtype("fp32")),
29
+ (core.dtype("fp64"), ): ("__ocml_rsqrt_f64", core.dtype("fp64")),
30
+ }, is_pure=True, _builder=_builder)
31
+
32
+
33
+ @core.extern
34
+ def ceil(arg0, _builder=None):
35
+ return core.extern_elementwise(
36
+ "", "", [arg0], {
37
+ (core.dtype("fp32"), ): ("__ocml_ceil_f32", core.dtype("fp32")),
38
+ (core.dtype("fp64"), ): ("__ocml_ceil_f64", core.dtype("fp64")),
39
+ }, is_pure=True, _builder=_builder)
40
+
41
+
42
+ @core.extern
43
+ def trunc(arg0, _builder=None):
44
+ return core.extern_elementwise(
45
+ "", "", [arg0], {
46
+ (core.dtype("fp32"), ): ("__ocml_trunc_f32", core.dtype("fp32")),
47
+ (core.dtype("fp64"), ): ("__ocml_trunc_f64", core.dtype("fp64")),
48
+ }, is_pure=True, _builder=_builder)
49
+
50
+
51
+ @core.extern
52
+ def exp2(arg0, _builder=None):
53
+ return core.extern_elementwise(
54
+ "", "", [arg0], {
55
+ (core.dtype("fp32"), ): ("__ocml_exp2_f32", core.dtype("fp32")),
56
+ (core.dtype("fp64"), ): ("__ocml_exp2_f64", core.dtype("fp64")),
57
+ }, is_pure=True, _builder=_builder)
58
+
59
+
60
+ @core.extern
61
+ def exp(arg0, _builder=None):
62
+ return core.extern_elementwise(
63
+ "", "", [arg0], {
64
+ (core.dtype("fp32"), ): ("__ocml_exp_f32", core.dtype("fp32")),
65
+ (core.dtype("fp64"), ): ("__ocml_exp_f64", core.dtype("fp64")),
66
+ }, is_pure=True, _builder=_builder)
67
+
68
+
69
+ @core.extern
70
+ def fast_dividef(arg0, arg1, _builder=None):
71
+ return core.extern_elementwise("", "", [arg0, arg1], {
72
+ (core.dtype("fp32"), core.dtype("fp32")): ("__triton_hip_fast_fdividef", core.dtype("fp32")),
73
+ }, is_pure=True, _builder=_builder)
74
+
75
+
76
+ @core.extern
77
+ def sqrt(arg0, _builder=None):
78
+ return core.extern_elementwise(
79
+ "", "", [arg0], {
80
+ (core.dtype("fp32"), ): ("__ocml_sqrt_f32", core.dtype("fp32")),
81
+ (core.dtype("fp64"), ): ("__ocml_sqrt_f64", core.dtype("fp64")),
82
+ }, is_pure=True, _builder=_builder)
83
+
84
+
85
+ @core.extern
86
+ def llrint(arg0, _builder=None):
87
+ return core.extern_elementwise(
88
+ "", "", [arg0], {
89
+ (core.dtype("fp32"), ): ("__triton_hip_llrint", core.dtype("int64")),
90
+ (core.dtype("fp64"), ): ("__triton_hip_llrint", core.dtype("int64")),
91
+ }, is_pure=True, _builder=_builder)
92
+
93
+
94
+ @core.extern
95
+ def nearbyint(arg0, _builder=None):
96
+ return core.extern_elementwise(
97
+ "", "", [
98
+ arg0,
99
+ ], {
100
+ (core.dtype("fp32"), ): ("__ocml_nearbyint_f32", core.dtype("fp32")),
101
+ (core.dtype("fp64"), ): ("__ocml_nearbyint_f64", core.dtype("fp64")),
102
+ }, is_pure=True, _builder=_builder)
103
+
104
+
105
+ @core.extern
106
+ def isnan(arg0, _builder=None):
107
+ return core.extern_elementwise(
108
+ "", "", [
109
+ arg0,
110
+ ], {
111
+ (core.dtype("fp32"), ): ("__ocml_isnan_f32", core.dtype("int32")),
112
+ (core.dtype("fp64"), ): ("__ocml_isnan_f64", core.dtype("int32")),
113
+ }, is_pure=True, _builder=_builder)
114
+
115
+
116
+ @core.extern
117
+ def signbit(arg0, _builder=None):
118
+ return core.extern_elementwise(
119
+ "", "", [
120
+ arg0,
121
+ ], {
122
+ (core.dtype("fp32"), ): ("__ocml_signbit_f32", core.dtype("int32")),
123
+ (core.dtype("fp64"), ): ("__ocml_signbit_f64", core.dtype("int32")),
124
+ }, is_pure=True, _builder=_builder)
125
+
126
+
127
+ @core.extern
128
+ def copysign(arg0, arg1, _builder=None):
129
+ return core.extern_elementwise(
130
+ "", "", [arg0, arg1], {
131
+ (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_copysign_f32", core.dtype("fp32")),
132
+ (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_copysign_f64", core.dtype("fp64")),
133
+ }, is_pure=True, _builder=_builder)
134
+
135
+
136
+ @core.extern
137
+ def isinf(arg0, _builder=None):
138
+ return core.extern_elementwise(
139
+ "", "", [arg0], {
140
+ (core.dtype("fp32"), ): ("__ocml_isinf_f32", core.dtype("int32")),
141
+ (core.dtype("fp64"), ): ("__ocml_isinf_f64", core.dtype("int32")),
142
+ }, is_pure=True, _builder=_builder)
143
+
144
+
145
+ @core.extern
146
+ def nextafter(arg0, arg1, _builder=None):
147
+ return core.extern_elementwise(
148
+ "", "", [arg0, arg1], {
149
+ (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_nextafter_f32", core.dtype("fp32")),
150
+ (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_nextafter_f64", core.dtype("fp64")),
151
+ }, is_pure=True, _builder=_builder)
152
+
153
+
154
+ @core.extern
155
+ def sin(arg0, _builder=None):
156
+ return core.extern_elementwise(
157
+ "", "", [arg0], {
158
+ (core.dtype("fp32"), ): ("__ocml_sin_f32", core.dtype("fp32")),
159
+ (core.dtype("fp64"), ): ("__ocml_sin_f64", core.dtype("fp64")),
160
+ }, is_pure=True, _builder=_builder)
161
+
162
+
163
+ @core.extern
164
+ def cos(arg0, _builder=None):
165
+ return core.extern_elementwise(
166
+ "", "", [arg0], {
167
+ (core.dtype("fp32"), ): ("__ocml_cos_f32", core.dtype("fp32")),
168
+ (core.dtype("fp64"), ): ("__ocml_cos_f64", core.dtype("fp64")),
169
+ }, is_pure=True, _builder=_builder)
170
+
171
+
172
+ @core.extern
173
+ def tan(arg0, _builder=None):
174
+ return core.extern_elementwise(
175
+ "", "", [arg0], {
176
+ (core.dtype("fp32"), ): ("__ocml_tan_f32", core.dtype("fp32")),
177
+ (core.dtype("fp64"), ): ("__ocml_tan_f64", core.dtype("fp64")),
178
+ }, is_pure=True, _builder=_builder)
179
+
180
+
181
+ @core.extern
182
+ def log2(arg0, _builder=None):
183
+ return core.extern_elementwise(
184
+ "", "", [arg0], {
185
+ (core.dtype("fp32"), ): ("__ocml_log2_f32", core.dtype("fp32")),
186
+ (core.dtype("fp64"), ): ("__ocml_log2_f64", core.dtype("fp64")),
187
+ }, is_pure=True, _builder=_builder)
188
+
189
+
190
+ @core.extern
191
+ def cosh(arg0, _builder=None):
192
+ return core.extern_elementwise(
193
+ "", "", [arg0], {
194
+ (core.dtype("fp32"), ): ("__ocml_cosh_f32", core.dtype("fp32")),
195
+ (core.dtype("fp64"), ): ("__ocml_cosh_f64", core.dtype("fp64")),
196
+ }, is_pure=True, _builder=_builder)
197
+
198
+
199
+ @core.extern
200
+ def sinh(arg0, _builder=None):
201
+ return core.extern_elementwise(
202
+ "", "", [arg0], {
203
+ (core.dtype("fp32"), ): ("__ocml_sinh_f32", core.dtype("fp32")),
204
+ (core.dtype("fp64"), ): ("__ocml_sinh_f64", core.dtype("fp64")),
205
+ }, is_pure=True, _builder=_builder)
206
+
207
+
208
+ @core.extern
209
+ def tanh(arg0, _builder=None):
210
+ return core.extern_elementwise(
211
+ "", "", [arg0], {
212
+ (core.dtype("fp32"), ): ("__ocml_tanh_f32", core.dtype("fp32")),
213
+ (core.dtype("fp64"), ): ("__ocml_tanh_f64", core.dtype("fp64")),
214
+ }, is_pure=True, _builder=_builder)
215
+
216
+
217
+ @core.extern
218
+ def atan2(arg0, arg1, _builder=None):
219
+ return core.extern_elementwise(
220
+ "", "", [arg0, arg1], {
221
+ (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_atan2_f32", core.dtype("fp32")),
222
+ (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_atan2_f64", core.dtype("fp64")),
223
+ }, is_pure=True, _builder=_builder)
224
+
225
+
226
+ @core.extern
227
+ def atan(arg0, _builder=None):
228
+ return core.extern_elementwise(
229
+ "", "", [arg0], {
230
+ (core.dtype("fp32"), ): ("__ocml_atan_f32", core.dtype("fp32")),
231
+ (core.dtype("fp64"), ): ("__ocml_atan_f64", core.dtype("fp64")),
232
+ }, is_pure=True, _builder=_builder)
233
+
234
+
235
+ @core.extern
236
+ def asin(arg0, _builder=None):
237
+ return core.extern_elementwise(
238
+ "", "", [arg0], {
239
+ (core.dtype("fp32"), ): ("__ocml_asin_f32", core.dtype("fp32")),
240
+ (core.dtype("fp64"), ): ("__ocml_asin_f64", core.dtype("fp64")),
241
+ }, is_pure=True, _builder=_builder)
242
+
243
+
244
+ @core.extern
245
+ def acos(arg0, _builder=None):
246
+ return core.extern_elementwise(
247
+ "", "", [arg0], {
248
+ (core.dtype("fp32"), ): ("__ocml_acos_f32", core.dtype("fp32")),
249
+ (core.dtype("fp64"), ): ("__ocml_acos_f64", core.dtype("fp64")),
250
+ }, is_pure=True, _builder=_builder)
251
+
252
+
253
+ @core.extern
254
+ def log(arg0, _builder=None):
255
+ return core.extern_elementwise(
256
+ "", "", [arg0], {
257
+ (core.dtype("fp32"), ): ("__ocml_log_f32", core.dtype("fp32")),
258
+ (core.dtype("fp64"), ): ("__ocml_log_f64", core.dtype("fp64")),
259
+ }, is_pure=True, _builder=_builder)
260
+
261
+
262
+ @core.extern
263
+ def log10(arg0, _builder=None):
264
+ return core.extern_elementwise(
265
+ "", "", [arg0], {
266
+ (core.dtype("fp32"), ): ("__ocml_log10_f32", core.dtype("fp32")),
267
+ (core.dtype("fp64"), ): ("__ocml_log10_f64", core.dtype("fp64")),
268
+ }, is_pure=True, _builder=_builder)
269
+
270
+
271
+ @core.extern
272
+ def log1p(arg0, _builder=None):
273
+ return core.extern_elementwise(
274
+ "", "", [arg0], {
275
+ (core.dtype("fp32"), ): ("__ocml_log1p_f32", core.dtype("fp32")),
276
+ (core.dtype("fp64"), ): ("__ocml_log1p_f64", core.dtype("fp64")),
277
+ }, is_pure=True, _builder=_builder)
278
+
279
+
280
+ @core.extern
281
+ def acosh(arg0, _builder=None):
282
+ return core.extern_elementwise(
283
+ "", "", [arg0], {
284
+ (core.dtype("fp32"), ): ("__ocml_acosh_f32", core.dtype("fp32")),
285
+ (core.dtype("fp64"), ): ("__ocml_acosh_f64", core.dtype("fp64")),
286
+ }, is_pure=True, _builder=_builder)
287
+
288
+
289
+ @core.extern
290
+ def asinh(arg0, _builder=None):
291
+ return core.extern_elementwise(
292
+ "", "", [arg0], {
293
+ (core.dtype("fp32"), ): ("__ocml_asinh_f32", core.dtype("fp32")),
294
+ (core.dtype("fp64"), ): ("__ocml_asinh_f64", core.dtype("fp64")),
295
+ }, is_pure=True, _builder=_builder)
296
+
297
+
298
+ @core.extern
299
+ def atanh(arg0, _builder=None):
300
+ return core.extern_elementwise(
301
+ "", "", [arg0], {
302
+ (core.dtype("fp32"), ): ("__ocml_atanh_f32", core.dtype("fp32")),
303
+ (core.dtype("fp64"), ): ("__ocml_atanh_f64", core.dtype("fp64")),
304
+ }, is_pure=True, _builder=_builder)
305
+
306
+
307
+ @core.extern
308
+ def expm1(arg0, _builder=None):
309
+ return core.extern_elementwise(
310
+ "", "", [arg0], {
311
+ (core.dtype("fp32"), ): ("__ocml_expm1_f32", core.dtype("fp32")),
312
+ (core.dtype("fp64"), ): ("__ocml_expm1_f64", core.dtype("fp64")),
313
+ }, is_pure=True, _builder=_builder)
314
+
315
+
316
+ @core.extern
317
+ def hypot(arg0, arg1, _builder=None):
318
+ return core.extern_elementwise(
319
+ "", "", [arg0, arg1], {
320
+ (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_hypot_f32", core.dtype("fp32")),
321
+ (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_hypot_f64", core.dtype("fp64")),
322
+ }, is_pure=True, _builder=_builder)
323
+
324
+
325
+ @core.extern
326
+ def j0(arg0, _builder=None):
327
+ return core.extern_elementwise(
328
+ "", "", [arg0], {
329
+ (core.dtype("fp32"), ): ("__ocml_j0_f32", core.dtype("fp32")),
330
+ (core.dtype("fp64"), ): ("__ocml_j0_f64", core.dtype("fp64")),
331
+ }, is_pure=True, _builder=_builder)
332
+
333
+
334
+ @core.extern
335
+ def j1(arg0, _builder=None):
336
+ return core.extern_elementwise(
337
+ "", "", [arg0], {
338
+ (core.dtype("fp32"), ): ("__ocml_j1_f32", core.dtype("fp32")),
339
+ (core.dtype("fp64"), ): ("__ocml_j1_f64", core.dtype("fp64")),
340
+ }, is_pure=True, _builder=_builder)
341
+
342
+
343
+ @core.extern
344
+ def y0(arg0, _builder=None):
345
+ return core.extern_elementwise(
346
+ "", "", [arg0], {
347
+ (core.dtype("fp32"), ): ("__ocml_y0_f32", core.dtype("fp32")),
348
+ (core.dtype("fp64"), ): ("__ocml_y0_f64", core.dtype("fp64")),
349
+ }, is_pure=True, _builder=_builder)
350
+
351
+
352
+ @core.extern
353
+ def y1(arg0, _builder=None):
354
+ return core.extern_elementwise(
355
+ "", "", [arg0], {
356
+ (core.dtype("fp32"), ): ("__ocml_y1_f32", core.dtype("fp32")),
357
+ (core.dtype("fp64"), ): ("__ocml_y1_f64", core.dtype("fp64")),
358
+ }, is_pure=True, _builder=_builder)
359
+
360
+
361
+ @core.extern
362
+ def cyl_bessel_i0(arg0, _builder=None):
363
+ return core.extern_elementwise(
364
+ "", "", [arg0], {
365
+ (core.dtype("fp32"), ): ("__ocml_i0_f32", core.dtype("fp32")),
366
+ (core.dtype("fp64"), ): ("__ocml_i0_f64", core.dtype("fp64")),
367
+ }, is_pure=True, _builder=_builder)
368
+
369
+
370
+ @core.extern
371
+ def cyl_bessel_i1(arg0, _builder=None):
372
+ return core.extern_elementwise(
373
+ "", "", [arg0], {
374
+ (core.dtype("fp32"), ): ("__ocml_i1_f32", core.dtype("fp32")),
375
+ (core.dtype("fp64"), ): ("__ocml_i1_f64", core.dtype("fp64")),
376
+ }, is_pure=True, _builder=_builder)
377
+
378
+
379
+ @core.extern
380
+ def erf(arg0, _builder=None):
381
+ return core.extern_elementwise(
382
+ "", "", [arg0], {
383
+ (core.dtype("fp32"), ): ("__ocml_erf_f32", core.dtype("fp32")),
384
+ (core.dtype("fp64"), ): ("__ocml_erf_f64", core.dtype("fp64")),
385
+ }, is_pure=True, _builder=_builder)
386
+
387
+
388
+ @core.extern
389
+ def erfinv(arg0, _builder=None):
390
+ return core.extern_elementwise(
391
+ "", "", [arg0], {
392
+ (core.dtype("fp32"), ): ("__ocml_erfinv_f32", core.dtype("fp32")),
393
+ (core.dtype("fp64"), ): ("__ocml_erfinv_f64", core.dtype("fp64")),
394
+ }, is_pure=True, _builder=_builder)
395
+
396
+
397
+ @core.extern
398
+ def erfc(arg0, _builder=None):
399
+ return core.extern_elementwise(
400
+ "", "", [arg0], {
401
+ (core.dtype("fp32"), ): ("__ocml_erfc_f32", core.dtype("fp32")),
402
+ (core.dtype("fp64"), ): ("__ocml_erfc_f64", core.dtype("fp64")),
403
+ }, is_pure=True, _builder=_builder)
404
+
405
+
406
+ @core.extern
407
+ def erfcx(arg0, _builder=None):
408
+ return core.extern_elementwise(
409
+ "", "", [arg0], {
410
+ (core.dtype("fp32"), ): ("__ocml_erfcx_f32", core.dtype("fp32")),
411
+ (core.dtype("fp64"), ): ("__ocml_erfcx_f64", core.dtype("fp64")),
412
+ }, is_pure=True, _builder=_builder)
413
+
414
+
415
+ @core.extern
416
+ def lgamma(arg0, _builder=None):
417
+ return core.extern_elementwise(
418
+ "", "", [arg0], {
419
+ (core.dtype("fp32"), ): ("__ocml_lgamma_f32", core.dtype("fp32")),
420
+ (core.dtype("fp64"), ): ("__ocml_lgamma_f64", core.dtype("fp64")),
421
+ }, is_pure=True, _builder=_builder)
422
+
423
+
424
+ @core.extern
425
+ def ldexp(arg0, arg1, _builder=None):
426
+ return core.extern_elementwise(
427
+ "", "", [arg0, arg1], {
428
+ (core.dtype("fp32"), core.dtype("int32")): ("__ocml_ldexp_f32", core.dtype("fp32")),
429
+ (core.dtype("fp64"), core.dtype("int32")): ("__ocml_ldexp_f64", core.dtype("fp64")),
430
+ }, is_pure=True, _builder=_builder)
431
+
432
+
433
+ @core.extern
434
+ def fmod(arg0, arg1, _builder=None):
435
+ return core.extern_elementwise(
436
+ "", "", [arg0, arg1], {
437
+ (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fmod_f32", core.dtype("fp32")),
438
+ (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fmod_f64", core.dtype("fp64")),
439
+ }, is_pure=True, _builder=_builder)
440
+
441
+
442
+ @core.extern
443
+ def fma(arg0, arg1, arg2, _builder=None):
444
+ return core.extern_elementwise(
445
+ "", "", [arg0, arg1, arg2], {
446
+ (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fma_f32", core.dtype("fp32")),
447
+ (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fma_f64", core.dtype("fp64")),
448
+ }, is_pure=True, _builder=_builder)
449
+
450
+
451
+ @core.extern
452
+ def pow(arg0, arg1, _builder=None):
453
+ return core.extern_elementwise(
454
+ "", "", [arg0, arg1], {
455
+ (core.dtype("fp32"), core.dtype("int32")): ("__ocml_pown_f32", core.dtype("fp32")),
456
+ (core.dtype("fp64"), core.dtype("int32")): ("__ocml_pown_f64", core.dtype("fp64")),
457
+ (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_pow_f32", core.dtype("fp32")),
458
+ (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_pow_f64", core.dtype("fp64")),
459
+ }, is_pure=True, _builder=_builder)
460
+
461
+
462
+ @core.extern
463
+ def ilogb(arg0, _builder=None):
464
+ return core.extern_elementwise(
465
+ "", "", [arg0], {
466
+ (core.dtype("fp32"), ): ("__ocml_ilogb_f32", core.dtype("int32")),
467
+ (core.dtype("fp64"), ): ("__ocml_ilogb_f64", core.dtype("int32")),
468
+ }, is_pure=True, _builder=_builder)