triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,109 @@
1
+ from triton.language import core
2
+
3
+
4
@core.extern
def globaltimer(_semantic=None):
    """Read the PTX ``%globaltimer`` special register as an int64 value.

    Emitted as non-pure inline asm (``is_pure=False``) with a single 64-bit
    ("=l") output and no inputs.
    """
    read_timer = "mov.u64 $0, %globaltimer;"
    return core.inline_asm_elementwise(
        read_timer,
        "=l",
        [],
        dtype=core.int64,
        is_pure=False,
        pack=1,
        _semantic=_semantic,
    )
8
+
9
+
10
@core.extern
def smid(_semantic=None):
    """Read the PTX ``%smid`` special register as an int32 value.

    Declared pure (``is_pure=True``) with a single 32-bit ("=r") output.
    NOTE(review): the PTX ISA describes %smid as able to change during
    execution; confirm that marking the read pure is intended.
    """
    read_smid = "mov.u32 $0, %smid;"
    return core.inline_asm_elementwise(
        read_smid,
        "=r",
        [],
        dtype=core.int32,
        is_pure=True,
        pack=1,
        _semantic=_semantic,
    )
14
+
15
+
16
@core.builtin
def num_threads(_semantic=None):
    """Return the compile-time thread count per program: ``num_warps * 32``."""
    warps = _semantic.builder.options.num_warps
    threads_per_warp = 32
    return core.constexpr(threads_per_warp * warps)
19
+
20
+
21
@core.builtin
def num_warps(_semantic=None):
    """Return the ``num_warps`` compile option as a constexpr."""
    options = _semantic.builder.options
    return core.constexpr(options.num_warps)
24
+
25
+
26
+ # ----- FP8E4M3B15 ------
27
+ # This data-type is a variant of the standard FP8E4M3 format.
28
+ # It was designed for fast software conversion to FP16 on
29
+ # nvidia GPUs that do not support it natively.
30
+ # This is the same format as FP8E4M3Nv, but:
31
+ # - the exponent bias is 15 instead of 7
32
+ # - 0xff and 0x7f decode to -1.875 and +1.875 instead of -nan and +nan
+ #   (the fp16 -> fp8 conversion saturates at +-1.750, i.e. 0xfe / 0x7e)
33
@core.builtin
def convert_fp8e4b15_to_float16(arg, _semantic=None):
    """Upcast fp8e4b15 values to float16 with inline PTX.

    Each asm invocation reads four fp8 bytes packed in one 32-bit input ($2)
    and writes four fp16 values into two 32-bit outputs ($0 holds inputs 0/1,
    $1 holds inputs 2/3), matching ``pack=4`` and the "=r,=r,r" constraints.

    The 4 exponent bits and 3 mantissa bits of each fp8 value are moved into
    the low 4 exponent bits and top 3 mantissa bits of an fp16 value, and the
    sign bit is moved to bit 15. Both formats use an exponent bias of 15, so
    this bit placement is the whole conversion.
    """
    return core.inline_asm_elementwise(
        "{ \n"
        ".reg .b32 a<2>, b<2>; \n"
        # a0 bytes (3..0) = [in1, in3, in0, in2]: inputs 0/1 land in the high
        # byte of each 16-bit half, inputs 2/3 in the low byte.
        "prmt.b32 a0, 0, $2, 0x5746; \n"
        # b0: exponent+mantissa of inputs 0/1 (sign cleared).
        "and.b32 b0, a0, 0x7f007f00; \n"
        # b1: inputs 2/3 including their sign bit (bit 7 of each half).
        "and.b32 b1, a0, 0x00ff00ff; \n"
        # a1: just the sign bits of inputs 2/3.
        "and.b32 a1, a0, 0x00800080; \n"
        # Move the 7 exponent/mantissa bits of inputs 0/1 to bits 7..13.
        "shr.b32 b0, b0, 1; \n"
        # Adding the sign bit to itself clears bit 7 and carries it into bit 8.
        "add.u32 b1, b1, a1; \n"
        # lop3 0xf8 = a | (b & c): re-attach the sign bits of inputs 0/1.
        "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
        # Shift inputs 2/3: exp/mantissa -> bits 7..13, sign bit 8 -> bit 15.
        "shl.b32 $1, b1, 7; \n"
        "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4,
        _semantic=_semantic)
48
+
49
+
50
@core.builtin
def convert_float16_to_fp8e4b15(arg, has_minx2, _semantic=None):
    """Downcast float16 values to fp8e4b15 with inline PTX.

    Each asm invocation reads four fp16 values packed in two 32-bit inputs
    ($1, $2) and writes four fp8 bytes into one 32-bit output ($0), matching
    ``pack=4`` and the "=r,r,r" constraints.

    Processing steps:

    - Clamp each magnitude to 1.75 (fp16 0x3F00, which encodes as fp8 0x7e).
    - Double the fp16 bit pattern and add 0x0080, so that keeping the high
      byte rounds to nearest (ties away from zero on the magnitude).
    - OR each input's sign bit back in.
    - Gather the four high bytes with ``prmt``.

    Args:
        arg: float16 tensor to convert.
        has_minx2: if true, clamp with ``min.f16x2``; otherwise emulate the
            clamp with ``setp.lt.f16x2`` + ``selp.b16`` (the sm70 entry point
            passes False).
    """
    asm = """{
            .reg .pred p<4>;
            .reg .b32 a<2>, b<2>;
            .reg .b16 c<4>;
            .reg .b16 max_val_f16;
            .reg .b32 max_val_f16x2;
            mov.b16 max_val_f16, 0x3F00;
            mov.b32 max_val_f16x2, 0x3F003F00;
            and.b32 a0, $1, 0x7fff7fff;
            and.b32 a1, $2, 0x7fff7fff;"""
    if has_minx2:
        # Clamp |x| to 1.75 with one packed-half instruction per register.
        asm += """min.f16x2 a0, a0, max_val_f16x2;
            min.f16x2 a1, a1, max_val_f16x2;"""
    else:
        # Without min.f16x2: compare each 16-bit lane against 1.75, then keep
        # the lane when it is smaller and substitute 1.75 otherwise.
        asm += """setp.lt.f16x2 p0|p1, a0, max_val_f16x2;
            setp.lt.f16x2 p2|p3, a1, max_val_f16x2;
            mov.b32 {c0, c1}, a0;
            mov.b32 {c2, c3}, a1;
            selp.b16 c0, c0, max_val_f16, p0;
            selp.b16 c1, c1, max_val_f16, p1;
            selp.b16 c2, c2, max_val_f16, p2;
            selp.b16 c3, c3, max_val_f16, p3;
            mov.b32 a0, {c0, c1};
            mov.b32 a1, {c2, c3};"""
    # Round (x*2 + 0x80), OR the original sign bits back in (lop3 0xea =
    # (a & b) | c), and pick the high byte of every 16-bit lane.
    asm += """mad.lo.u32 a0, a0, 2, 0x00800080;
            mad.lo.u32 a1, a1, 2, 0x00800080;
            lop3.b32 b0, $1, 0x80008000, a0, 0xea;
            lop3.b32 b1, $2, 0x80008000, a1, 0xea;
            prmt.b32 $0, b0, b1, 0x7531;
            }"""
    return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4,
                                       _semantic=_semantic)
84
+
85
+
86
@core.builtin
def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _semantic=None):
    """Convert between fp8e4b15 and fp16/fp32 using the inline-PTX helpers.

    - If ``arg`` is fp8e4b15, it is upcast to fp16. When ``dst_ty``'s scalar
      type is fp32, the result is further converted to fp32.
    - Otherwise ``arg`` must be fp16 or fp32 (asserted). fp32 input is first
      narrowed to fp16 with round-toward-zero, then packed into fp8e4b15.

    NOTE(review): ``fp_downcast_rounding`` is accepted but unused; the
    fp32 -> fp16 step always uses "rtz".
    """
    src_scalar = arg.type.scalar
    if not src_scalar.is_fp8e4b15():
        # Downcast path: fp16/fp32 -> fp8e4b15.
        assert src_scalar.is_fp16() or src_scalar.is_fp32()
        as_fp16 = arg
        if src_scalar.is_fp32():
            as_fp16 = as_fp16.to(core.float16, fp_downcast_rounding="rtz", _semantic=_semantic)
        return convert_float16_to_fp8e4b15(as_fp16, has_minx2=has_minx2, _semantic=_semantic)

    # Upcast path: fp8e4b15 -> fp16 (-> fp32 if requested).
    widened = convert_fp8e4b15_to_float16(arg, _semantic=_semantic)
    if dst_ty.scalar.is_fp32():
        widened = widened.to(core.float32, _semantic=_semantic)
    return widened
100
+
101
+
102
@core.builtin
def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
    """fp8e4b15 conversion variant that clamps with ``min.f16x2``."""
    return convert_custom_float8(
        arg=arg,
        dst_ty=dst_ty,
        fp_downcast_rounding=fp_downcast_rounding,
        has_minx2=True,
        _semantic=_semantic,
    )
105
+
106
+
107
@core.builtin
def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
    """fp8e4b15 conversion variant that emulates the clamp without ``min.f16x2``."""
    return convert_custom_float8(
        arg=arg,
        dst_ty=dst_ty,
        fp_downcast_rounding=fp_downcast_rounding,
        has_minx2=False,
        _semantic=_semantic,
    )
@@ -0,0 +1,5 @@
1
+ from . import libdevice
2
+
3
+ from .utils import memrealtime
4
+
5
# Public names re-exported by this package.
__all__ = [
    "libdevice",
    "memrealtime",
]
@@ -0,0 +1,491 @@
1
+ from triton.language import core
2
+
3
+
4
@core.extern
def abs(arg0, _semantic=None):
    """Elementwise abs: int32/int64 -> ``__triton_hip_iabs``, fp32/fp64 -> ``__triton_hip_fabs``."""
    i32, i64 = core.dtype("int32"), core.dtype("int64")
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (i32,): ("__triton_hip_iabs", i32),
        (i64,): ("__triton_hip_iabs", i64),
        (f32,): ("__triton_hip_fabs", f32),
        (f64,): ("__triton_hip_fabs", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
13
+
14
+
15
@core.extern
def floor(arg0, _semantic=None):
    """Elementwise floor: fp32 -> ``__ocml_floor_f32``, fp64 -> ``__ocml_floor_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_floor_f32", f32),
        (f64,): ("__ocml_floor_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
22
+
23
+
24
@core.extern
def rsqrt(arg0, _semantic=None):
    """Elementwise rsqrt: fp32 -> ``__ocml_rsqrt_f32``, fp64 -> ``__ocml_rsqrt_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_rsqrt_f32", f32),
        (f64,): ("__ocml_rsqrt_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
31
+
32
+
33
@core.extern
def ceil(arg0, _semantic=None):
    """Elementwise ceil: fp32 -> ``__ocml_ceil_f32``, fp64 -> ``__ocml_ceil_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_ceil_f32", f32),
        (f64,): ("__ocml_ceil_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
40
+
41
+
42
@core.extern
def trunc(arg0, _semantic=None):
    """Elementwise trunc: fp32 -> ``__ocml_trunc_f32``, fp64 -> ``__ocml_trunc_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_trunc_f32", f32),
        (f64,): ("__ocml_trunc_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
49
+
50
+
51
@core.extern
def exp2(arg0, _semantic=None):
    """Elementwise exp2: fp32 -> ``__ocml_exp2_f32``, fp64 -> ``__ocml_exp2_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_exp2_f32", f32),
        (f64,): ("__ocml_exp2_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
58
+
59
+
60
@core.extern
def exp(arg0, _semantic=None):
    """Elementwise exp: fp32 -> ``__ocml_exp_f32``, fp64 -> ``__ocml_exp_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_exp_f32", f32),
        (f64,): ("__ocml_exp_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
67
+
68
+
69
@core.extern
def fast_expf(arg0, _semantic=None):
    """fp32-only fast exp via ``__triton_hip_fast_expf``."""
    f32 = core.dtype("fp32")
    overloads = {(f32,): ("__triton_hip_fast_expf", f32)}
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
74
+
75
+
76
@core.extern
def fast_tanhf(arg0, _semantic=None):
    """fp32-only fast tanh via ``__triton_hip_fast_tanhf``."""
    f32 = core.dtype("fp32")
    overloads = {(f32,): ("__triton_hip_fast_tanhf", f32)}
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
81
+
82
+
83
@core.extern
def fast_dividef(arg0, arg1, _semantic=None):
    """fp32-only fast division ``arg0 / arg1`` via ``__triton_hip_fast_fdividef``."""
    f32 = core.dtype("fp32")
    overloads = {(f32, f32): ("__triton_hip_fast_fdividef", f32)}
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
88
+
89
+
90
@core.extern
def sqrt(arg0, _semantic=None):
    """Elementwise sqrt: fp32 -> ``__ocml_sqrt_f32``, fp64 -> ``__ocml_sqrt_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_sqrt_f32", f32),
        (f64,): ("__ocml_sqrt_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
97
+
98
+
99
@core.extern
def llrint(arg0, _semantic=None):
    """Elementwise llrint: fp32/fp64 -> int64 via ``__triton_hip_llrint``."""
    i64 = core.dtype("int64")
    overloads = {
        (core.dtype("fp32"),): ("__triton_hip_llrint", i64),
        (core.dtype("fp64"),): ("__triton_hip_llrint", i64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
106
+
107
+
108
@core.extern
def nearbyint(arg0, _semantic=None):
    """Elementwise nearbyint: fp32 -> ``__ocml_nearbyint_f32``, fp64 -> ``__ocml_nearbyint_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_nearbyint_f32", f32),
        (f64,): ("__ocml_nearbyint_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
117
+
118
+
119
@core.extern
def isnan(arg0, _semantic=None):
    """Elementwise isnan via ``__ocml_isnan_f32/_f64``; the int32 result is cast to int1."""
    i32 = core.dtype("int32")
    overloads = {
        (core.dtype("fp32"),): ("__ocml_isnan_f32", i32),
        (core.dtype("fp64"),): ("__ocml_isnan_f64", i32),
    }
    as_int32 = core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
    return as_int32.to(core.int1, _semantic=_semantic)
128
+
129
+
130
@core.extern
def signbit(arg0, _semantic=None):
    """Elementwise signbit via ``__ocml_signbit_f32/_f64``; returns the raw int32 result."""
    i32 = core.dtype("int32")
    overloads = {
        (core.dtype("fp32"),): ("__ocml_signbit_f32", i32),
        (core.dtype("fp64"),): ("__ocml_signbit_f64", i32),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
139
+
140
+
141
@core.extern
def copysign(arg0, arg1, _semantic=None):
    """Elementwise copysign on matching fp32/fp64 pairs via ``__ocml_copysign_*``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_copysign_f32", f32),
        (f64, f64): ("__ocml_copysign_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
148
+
149
+
150
@core.extern
def isinf(arg0, _semantic=None):
    """Elementwise isinf via ``__ocml_isinf_f32/_f64``; the int32 result is cast to int1."""
    i32 = core.dtype("int32")
    overloads = {
        (core.dtype("fp32"),): ("__ocml_isinf_f32", i32),
        (core.dtype("fp64"),): ("__ocml_isinf_f64", i32),
    }
    as_int32 = core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
    return as_int32.to(core.int1, _semantic=_semantic)
157
+
158
+
159
@core.extern
def nextafter(arg0, arg1, _semantic=None):
    """Elementwise nextafter on matching fp32/fp64 pairs via ``__ocml_nextafter_*``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_nextafter_f32", f32),
        (f64, f64): ("__ocml_nextafter_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
166
+
167
+
168
@core.extern
def sin(arg0, _semantic=None):
    """Elementwise sin: fp32 -> ``__ocml_sin_f32``, fp64 -> ``__ocml_sin_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_sin_f32", f32),
        (f64,): ("__ocml_sin_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
175
+
176
+
177
@core.extern
def cos(arg0, _semantic=None):
    """Elementwise cos: fp32 -> ``__ocml_cos_f32``, fp64 -> ``__ocml_cos_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_cos_f32", f32),
        (f64,): ("__ocml_cos_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
184
+
185
+
186
@core.extern
def tan(arg0, _semantic=None):
    """Elementwise tan: fp32 -> ``__ocml_tan_f32``, fp64 -> ``__ocml_tan_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_tan_f32", f32),
        (f64,): ("__ocml_tan_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
193
+
194
+
195
@core.extern
def log2(arg0, _semantic=None):
    """Elementwise log2: fp32 -> ``__ocml_log2_f32``, fp64 -> ``__ocml_log2_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_log2_f32", f32),
        (f64,): ("__ocml_log2_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
202
+
203
+
204
@core.extern
def cosh(arg0, _semantic=None):
    """Elementwise cosh: fp32 -> ``__ocml_cosh_f32``, fp64 -> ``__ocml_cosh_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_cosh_f32", f32),
        (f64,): ("__ocml_cosh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
211
+
212
+
213
@core.extern
def sinh(arg0, _semantic=None):
    """Elementwise sinh: fp32 -> ``__ocml_sinh_f32``, fp64 -> ``__ocml_sinh_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_sinh_f32", f32),
        (f64,): ("__ocml_sinh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
220
+
221
+
222
@core.extern
def tanh(arg0, _semantic=None):
    """Elementwise tanh: fp32 -> ``__ocml_tanh_f32``, fp64 -> ``__ocml_tanh_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_tanh_f32", f32),
        (f64,): ("__ocml_tanh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
229
+
230
+
231
@core.extern
def atan2(arg0, arg1, _semantic=None):
    """Elementwise atan2 on matching fp32/fp64 pairs via ``__ocml_atan2_*``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_atan2_f32", f32),
        (f64, f64): ("__ocml_atan2_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
238
+
239
+
240
@core.extern
def atan(arg0, _semantic=None):
    """Elementwise atan: fp32 -> ``__ocml_atan_f32``, fp64 -> ``__ocml_atan_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_atan_f32", f32),
        (f64,): ("__ocml_atan_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
247
+
248
+
249
@core.extern
def asin(arg0, _semantic=None):
    """Elementwise asin: fp32 -> ``__ocml_asin_f32``, fp64 -> ``__ocml_asin_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_asin_f32", f32),
        (f64,): ("__ocml_asin_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
256
+
257
+
258
@core.extern
def acos(arg0, _semantic=None):
    """Elementwise acos: fp32 -> ``__ocml_acos_f32``, fp64 -> ``__ocml_acos_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_acos_f32", f32),
        (f64,): ("__ocml_acos_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
265
+
266
+
267
@core.extern
def log(arg0, _semantic=None):
    """Elementwise log: fp32 -> ``__ocml_log_f32``, fp64 -> ``__ocml_log_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_log_f32", f32),
        (f64,): ("__ocml_log_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
274
+
275
+
276
@core.extern
def log10(arg0, _semantic=None):
    """Elementwise log10: fp32 -> ``__ocml_log10_f32``, fp64 -> ``__ocml_log10_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_log10_f32", f32),
        (f64,): ("__ocml_log10_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
283
+
284
+
285
@core.extern
def log1p(arg0, _semantic=None):
    """Elementwise log1p: fp32 -> ``__ocml_log1p_f32``, fp64 -> ``__ocml_log1p_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_log1p_f32", f32),
        (f64,): ("__ocml_log1p_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
292
+
293
+
294
@core.extern
def acosh(arg0, _semantic=None):
    """Elementwise acosh: fp32 -> ``__ocml_acosh_f32``, fp64 -> ``__ocml_acosh_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_acosh_f32", f32),
        (f64,): ("__ocml_acosh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
301
+
302
+
303
@core.extern
def asinh(arg0, _semantic=None):
    """Elementwise asinh: fp32 -> ``__ocml_asinh_f32``, fp64 -> ``__ocml_asinh_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_asinh_f32", f32),
        (f64,): ("__ocml_asinh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
310
+
311
+
312
@core.extern
def atanh(arg0, _semantic=None):
    """Elementwise atanh: fp32 -> ``__ocml_atanh_f32``, fp64 -> ``__ocml_atanh_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_atanh_f32", f32),
        (f64,): ("__ocml_atanh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
319
+
320
+
321
@core.extern
def expm1(arg0, _semantic=None):
    """Elementwise expm1: fp32 -> ``__ocml_expm1_f32``, fp64 -> ``__ocml_expm1_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_expm1_f32", f32),
        (f64,): ("__ocml_expm1_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
328
+
329
+
330
@core.extern
def hypot(arg0, arg1, _semantic=None):
    """Elementwise hypot on matching fp32/fp64 pairs via ``__ocml_hypot_*``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_hypot_f32", f32),
        (f64, f64): ("__ocml_hypot_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
337
+
338
+
339
@core.extern
def j0(arg0, _semantic=None):
    """Elementwise j0: fp32 -> ``__ocml_j0_f32``, fp64 -> ``__ocml_j0_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_j0_f32", f32),
        (f64,): ("__ocml_j0_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
346
+
347
+
348
@core.extern
def j1(arg0, _semantic=None):
    """Elementwise j1: fp32 -> ``__ocml_j1_f32``, fp64 -> ``__ocml_j1_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_j1_f32", f32),
        (f64,): ("__ocml_j1_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
355
+
356
+
357
@core.extern
def y0(arg0, _semantic=None):
    """Elementwise y0: fp32 -> ``__ocml_y0_f32``, fp64 -> ``__ocml_y0_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_y0_f32", f32),
        (f64,): ("__ocml_y0_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
364
+
365
+
366
@core.extern
def y1(arg0, _semantic=None):
    """Elementwise y1: fp32 -> ``__ocml_y1_f32``, fp64 -> ``__ocml_y1_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_y1_f32", f32),
        (f64,): ("__ocml_y1_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
373
+
374
+
375
@core.extern
def cyl_bessel_i0(arg0, _semantic=None):
    """Elementwise i0: fp32 -> ``__ocml_i0_f32``, fp64 -> ``__ocml_i0_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_i0_f32", f32),
        (f64,): ("__ocml_i0_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
382
+
383
+
384
@core.extern
def cyl_bessel_i1(arg0, _semantic=None):
    """Elementwise i1: fp32 -> ``__ocml_i1_f32``, fp64 -> ``__ocml_i1_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_i1_f32", f32),
        (f64,): ("__ocml_i1_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
391
+
392
+
393
@core.extern
def erf(arg0, _semantic=None):
    """Elementwise erf: fp32 -> ``__ocml_erf_f32``, fp64 -> ``__ocml_erf_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_erf_f32", f32),
        (f64,): ("__ocml_erf_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
400
+
401
+
402
@core.extern
def erfinv(arg0, _semantic=None):
    """Elementwise erfinv: fp32 -> ``__ocml_erfinv_f32``, fp64 -> ``__ocml_erfinv_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_erfinv_f32", f32),
        (f64,): ("__ocml_erfinv_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
409
+
410
+
411
@core.extern
def erfc(arg0, _semantic=None):
    """Elementwise erfc: fp32 -> ``__ocml_erfc_f32``, fp64 -> ``__ocml_erfc_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_erfc_f32", f32),
        (f64,): ("__ocml_erfc_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
418
+
419
+
420
@core.extern
def erfcx(arg0, _semantic=None):
    """Elementwise erfcx: fp32 -> ``__ocml_erfcx_f32``, fp64 -> ``__ocml_erfcx_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_erfcx_f32", f32),
        (f64,): ("__ocml_erfcx_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
427
+
428
+
429
@core.extern
def lgamma(arg0, _semantic=None):
    """Elementwise lgamma: fp32 -> ``__ocml_lgamma_f32``, fp64 -> ``__ocml_lgamma_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_lgamma_f32", f32),
        (f64,): ("__ocml_lgamma_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
436
+
437
+
438
@core.extern
def ldexp(arg0, arg1, _semantic=None):
    """Elementwise ldexp: (fp32|fp64 mantissa, int32 exponent) via ``__ocml_ldexp_*``."""
    f32, f64, i32 = core.dtype("fp32"), core.dtype("fp64"), core.dtype("int32")
    overloads = {
        (f32, i32): ("__ocml_ldexp_f32", f32),
        (f64, i32): ("__ocml_ldexp_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
445
+
446
+
447
@core.extern
def fmod(arg0, arg1, _semantic=None):
    """Elementwise fmod on matching fp32/fp64 pairs via ``__ocml_fmod_*``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_fmod_f32", f32),
        (f64, f64): ("__ocml_fmod_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
454
+
455
+
456
@core.extern
def fma(arg0, arg1, arg2, _semantic=None):
    """Elementwise fma on matching fp32/fp64 triples via ``__ocml_fma_*``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32, f32): ("__ocml_fma_f32", f32),
        (f64, f64, f64): ("__ocml_fma_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1, arg2], overloads, is_pure=True, _semantic=_semantic)
463
+
464
+
465
@core.extern
def pow(arg0, arg1, _semantic=None):
    """Elementwise pow.

    An int32 exponent dispatches to ``__ocml_pown_*``; a floating exponent of
    the same width as the base dispatches to ``__ocml_pow_*``.
    """
    f32, f64, i32 = core.dtype("fp32"), core.dtype("fp64"), core.dtype("int32")
    overloads = {
        (f32, i32): ("__ocml_pown_f32", f32),
        (f64, i32): ("__ocml_pown_f64", f64),
        (f32, f32): ("__ocml_pow_f32", f32),
        (f64, f64): ("__ocml_pow_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
474
+
475
+
476
@core.extern
def ilogb(arg0, _semantic=None):
    """Elementwise ilogb: fp32/fp64 -> int32 via ``__ocml_ilogb_*``."""
    i32 = core.dtype("int32")
    overloads = {
        (core.dtype("fp32"),): ("__ocml_ilogb_f32", i32),
        (core.dtype("fp64"),): ("__ocml_ilogb_f64", i32),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
483
+
484
+
485
@core.extern
def round(arg0, _semantic=None):
    """Elementwise round: fp32 -> ``__ocml_round_f32``, fp64 -> ``__ocml_round_f64``."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32,): ("__ocml_round_f32", f32),
        (f64,): ("__ocml_round_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
@@ -0,0 +1,35 @@
1
+ from triton.language import core
2
+
3
+
4
@core.extern
def memrealtime(_semantic=None):
    """
    Returns a 64-bit real time-counter value

    On gfx11/gfx12 targets the counter is read with ``s_sendmsg_rtn_b64``;
    every other target uses ``s_memrealtime``. Both instructions return their
    SGPR result through the LGKM counter, so the read must be followed by
    ``s_waitcnt lgkmcnt(0)`` before the result is consumed. (``vmcnt`` only
    tracks vector-memory operations and does not order this read.)

    The asm is emitted as non-pure (``is_pure=False``) so successive calls
    are not merged.
    """
    target_arch = _semantic.builder.options.arch
    if 'gfx11' in target_arch or 'gfx12' in target_arch:
        asm = """
        s_sendmsg_rtn_b64 $0, sendmsg(MSG_RTN_GET_REALTIME)
        s_waitcnt lgkmcnt(0)
        """
    else:
        # s_memrealtime is a scalar-memory (SMEM) instruction: wait on
        # lgkmcnt, not vmcnt, or the destination may be read before it is written.
        asm = """
        s_memrealtime $0
        s_waitcnt lgkmcnt(0)
        """
    return core.inline_asm_elementwise(
        asm,
        "=r",
        [],
        dtype=core.int64,
        is_pure=False,
        pack=1,
        _semantic=_semantic,
    )