triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (217) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,109 @@
1
+ from triton.language import core
2
+
3
+
4
@core.extern
def globaltimer(_semantic=None):
    """Read ``%globaltimer`` into a 64-bit integer through inline assembly (marked impure)."""
    asm = "mov.u64 $0, %globaltimer;"
    return core.inline_asm_elementwise(asm, "=l", [], dtype=core.int64, is_pure=False, pack=1, _semantic=_semantic)
8
+
9
+
10
@core.extern
def smid(_semantic=None):
    """Read ``%smid`` into a 32-bit integer through inline assembly (marked pure)."""
    asm = "mov.u32 $0, %smid;"
    return core.inline_asm_elementwise(asm, "=r", [], dtype=core.int32, is_pure=True, pack=1, _semantic=_semantic)
14
+
15
+
16
@core.builtin
def num_threads(_semantic=None):
    """Return ``num_warps * 32`` (taken from the builder options) wrapped in ``core.constexpr``."""
    warps = _semantic.builder.options.num_warps
    return core.constexpr(warps * 32)
19
+
20
+
21
@core.builtin
def num_warps(_semantic=None):
    """Return the builder options' ``num_warps`` value wrapped in ``core.constexpr``."""
    options = _semantic.builder.options
    return core.constexpr(options.num_warps)
24
+
25
+
26
+ # ----- FP8E4M3B15 ------
27
+ # This data-type is a variant of the standard FP8E4M3 format.
28
+ # It was designed for fast software conversion to FP16 on
29
+ # nvidia GPUs that do not support it natively.
30
+ # This is the same format as FP8E4M3Nv, but:
31
+ # - the exponent bias is 15 instead of 7
32
+ # - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
33
@core.builtin
def convert_fp8e4b15_to_float16(arg, _semantic=None):
    """Convert ``arg`` to ``core.float16`` with an inline-asm sequence, four elements per asm instance (``pack=4``).

    The asm takes one 32-bit input register (``$2``) and writes two 32-bit
    output registers (``$0`` and ``$1``).
    """
    ptx = ("{ \n"
           ".reg .b32 a<2>, b<2>; \n"
           "prmt.b32 a0, 0, $2, 0x5746; \n"
           "and.b32 b0, a0, 0x7f007f00; \n"
           "and.b32 b1, a0, 0x00ff00ff; \n"
           "and.b32 a1, a0, 0x00800080; \n"
           "shr.b32 b0, b0, 1; \n"
           "add.u32 b1, b1, a1; \n"
           "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
           "shl.b32 $1, b1, 7; \n"
           "} \n")
    constraints = "=r,=r,r"
    return core.inline_asm_elementwise(ptx, constraints, [arg], dtype=core.float16, is_pure=True, pack=4,
                                       _semantic=_semantic)
48
+
49
+
50
@core.builtin
def convert_float16_to_fp8e4b15(arg, has_minx2, _semantic=None):
    """Convert ``arg`` to ``core.float8e4b15`` with an inline-asm sequence, four elements per asm instance.

    The asm is assembled from three pieces: a shared prologue, a clamp against
    the fp16 bit pattern ``0x3F00``, and a shared epilogue. When ``has_minx2``
    is truthy the clamp uses ``min.f16x2``; otherwise it uses a
    ``setp``/``selp`` sequence.
    """
    prologue = ("{\n"
                ".reg .pred p<4>;\n"
                ".reg .b32 a<2>, b<2>;\n"
                ".reg .b16 c<4>;\n"
                ".reg .b16 max_val_f16;\n"
                ".reg .b32 max_val_f16x2;\n"
                "mov.b16 max_val_f16, 0x3F00;\n"
                "mov.b32 max_val_f16x2, 0x3F003F00;\n"
                "and.b32 a0, $1, 0x7fff7fff;\n"
                "and.b32 a1, $2, 0x7fff7fff;")
    clamp_with_min = ("min.f16x2 a0, a0, max_val_f16x2;\n"
                      "min.f16x2 a1, a1, max_val_f16x2;")
    clamp_with_select = ("setp.lt.f16x2 p0|p1, a0, max_val_f16x2;\n"
                         "setp.lt.f16x2 p2|p3, a1, max_val_f16x2;\n"
                         "mov.b32 {c0, c1}, a0;\n"
                         "mov.b32 {c2, c3}, a1;\n"
                         "selp.b16 c0, c0, max_val_f16, p0;\n"
                         "selp.b16 c1, c1, max_val_f16, p1;\n"
                         "selp.b16 c2, c2, max_val_f16, p2;\n"
                         "selp.b16 c3, c3, max_val_f16, p3;\n"
                         "mov.b32 a0, {c0, c1};\n"
                         "mov.b32 a1, {c2, c3};")
    epilogue = ("mad.lo.u32 a0, a0, 2, 0x00800080;\n"
                "mad.lo.u32 a1, a1, 2, 0x00800080;\n"
                "lop3.b32 b0, $1, 0x80008000, a0, 0xea;\n"
                "lop3.b32 b1, $2, 0x80008000, a1, 0xea;\n"
                "prmt.b32 $0, b0, b1, 0x7531;\n"
                "}")
    clamp = clamp_with_min if has_minx2 else clamp_with_select
    asm = prologue + clamp + epilogue
    return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4,
                                       _semantic=_semantic)
84
+
85
+
86
@core.builtin
def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _semantic=None):
    """Convert between fp8e4b15 and fp16/fp32 using the inline-asm helpers in this module.

    * fp8e4b15 input: upcast to float16, then to float32 when ``dst_ty`` is fp32.
    * fp16/fp32 input: fp32 is first narrowed to float16 with ``"rtz"`` rounding,
      then the float16 value is converted to fp8e4b15.

    ``fp_downcast_rounding`` is accepted but not read by this body.
    ``has_minx2`` is forwarded to ``convert_float16_to_fp8e4b15``.
    """
    src_scalar = arg.type.scalar

    # Upcast path.
    if src_scalar.is_fp8e4b15():
        widened = convert_fp8e4b15_to_float16(arg, _semantic=_semantic)
        if not dst_ty.scalar.is_fp32():
            return widened
        return widened.to(core.float32, _semantic=_semantic)

    # Downcast path: only fp16 and fp32 sources are accepted.
    assert src_scalar.is_fp16() or src_scalar.is_fp32()
    if src_scalar.is_fp32():
        half = arg.to(core.float16, fp_downcast_rounding="rtz", _semantic=_semantic)
    else:
        half = arg
    return convert_float16_to_fp8e4b15(half, has_minx2=has_minx2, _semantic=_semantic)
100
+
101
+
102
@core.builtin
def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
    """Call ``convert_custom_float8`` with ``has_minx2=True``."""
    converted = convert_custom_float8(
        arg,
        dst_ty,
        fp_downcast_rounding,
        has_minx2=True,
        _semantic=_semantic,
    )
    return converted
105
+
106
+
107
@core.builtin
def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
    """Call ``convert_custom_float8`` with ``has_minx2=False``."""
    converted = convert_custom_float8(
        arg,
        dst_ty,
        fp_downcast_rounding,
        has_minx2=False,
        _semantic=_semantic,
    )
    return converted
@@ -0,0 +1,5 @@
1
+ from . import libdevice
2
+
3
+ from .utils import memrealtime
4
+
5
+ __all__ = ["libdevice", "memrealtime"]
@@ -0,0 +1,491 @@
1
+ from triton.language import core
2
+
3
+
4
@core.extern
def abs(arg0, _semantic=None):
    """Dispatch int32/int64 ``arg0`` to ``__triton_hip_iabs`` and fp32/fp64 ``arg0`` to ``__triton_hip_fabs``."""
    int32, int64 = core.dtype("int32"), core.dtype("int64")
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (int32, ): ("__triton_hip_iabs", int32),
        (int64, ): ("__triton_hip_iabs", int64),
        (fp32, ): ("__triton_hip_fabs", fp32),
        (fp64, ): ("__triton_hip_fabs", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
13
+
14
+
15
@core.extern
def floor(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_floor_f32`` / ``__ocml_floor_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_floor_f32", fp32),
        (fp64, ): ("__ocml_floor_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
22
+
23
+
24
@core.extern
def rsqrt(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_rsqrt_f32`` / ``__ocml_rsqrt_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_rsqrt_f32", fp32),
        (fp64, ): ("__ocml_rsqrt_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
31
+
32
+
33
@core.extern
def ceil(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_ceil_f32`` / ``__ocml_ceil_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_ceil_f32", fp32),
        (fp64, ): ("__ocml_ceil_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
40
+
41
+
42
@core.extern
def trunc(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_trunc_f32`` / ``__ocml_trunc_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_trunc_f32", fp32),
        (fp64, ): ("__ocml_trunc_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
49
+
50
+
51
@core.extern
def exp2(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_exp2_f32`` / ``__ocml_exp2_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_exp2_f32", fp32),
        (fp64, ): ("__ocml_exp2_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
58
+
59
+
60
@core.extern
def exp(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_exp_f32`` / ``__ocml_exp_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_exp_f32", fp32),
        (fp64, ): ("__ocml_exp_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
67
+
68
+
69
@core.extern
def fast_expf(arg0, _semantic=None):
    """Dispatch fp32 ``arg0`` to ``__triton_hip_fast_expf``; only an fp32 overload is registered."""
    fp32 = core.dtype("fp32")
    symbols = {
        (fp32, ): ("__triton_hip_fast_expf", fp32),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
74
+
75
+
76
@core.extern
def fast_tanhf(arg0, _semantic=None):
    """Dispatch fp32 ``arg0`` to ``__triton_hip_fast_tanhf``; only an fp32 overload is registered."""
    fp32 = core.dtype("fp32")
    symbols = {
        (fp32, ): ("__triton_hip_fast_tanhf", fp32),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
81
+
82
+
83
@core.extern
def fast_dividef(arg0, arg1, _semantic=None):
    """Dispatch an fp32 ``(arg0, arg1)`` pair to ``__triton_hip_fast_fdividef``; only fp32 is registered."""
    fp32 = core.dtype("fp32")
    symbols = {
        (fp32, fp32): ("__triton_hip_fast_fdividef", fp32),
    }
    return core.extern_elementwise("", "", [arg0, arg1], symbols, is_pure=True, _semantic=_semantic)
88
+
89
+
90
@core.extern
def sqrt(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_sqrt_f32`` / ``__ocml_sqrt_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_sqrt_f32", fp32),
        (fp64, ): ("__ocml_sqrt_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
97
+
98
+
99
@core.extern
def llrint(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__triton_hip_llrint``; both overloads return int64."""
    int64 = core.dtype("int64")
    symbols = {
        (core.dtype("fp32"), ): ("__triton_hip_llrint", int64),
        (core.dtype("fp64"), ): ("__triton_hip_llrint", int64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
106
+
107
+
108
@core.extern
def nearbyint(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_nearbyint_f32`` / ``__ocml_nearbyint_f64``; dtype is preserved."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_nearbyint_f32", fp32),
        (fp64, ): ("__ocml_nearbyint_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
117
+
118
+
119
@core.extern
def isnan(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_isnan_f32`` / ``__ocml_isnan_f64`` and cast the int32 result to int1."""
    int32 = core.dtype("int32")
    symbols = {
        (core.dtype("fp32"), ): ("__ocml_isnan_f32", int32),
        (core.dtype("fp64"), ): ("__ocml_isnan_f64", int32),
    }
    flags = core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
    return flags.to(core.int1, _semantic=_semantic)
128
+
129
+
130
@core.extern
def signbit(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_signbit_f32`` / ``__ocml_signbit_f64``; the result is int32."""
    int32 = core.dtype("int32")
    symbols = {
        (core.dtype("fp32"), ): ("__ocml_signbit_f32", int32),
        (core.dtype("fp64"), ): ("__ocml_signbit_f64", int32),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
139
+
140
+
141
@core.extern
def copysign(arg0, arg1, _semantic=None):
    """Dispatch an fp32 or fp64 ``(arg0, arg1)`` pair to ``__ocml_copysign_f32`` / ``__ocml_copysign_f64``."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, fp32): ("__ocml_copysign_f32", fp32),
        (fp64, fp64): ("__ocml_copysign_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], symbols, is_pure=True, _semantic=_semantic)
148
+
149
+
150
@core.extern
def isinf(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_isinf_f32`` / ``__ocml_isinf_f64`` and cast the int32 result to int1."""
    int32 = core.dtype("int32")
    symbols = {
        (core.dtype("fp32"), ): ("__ocml_isinf_f32", int32),
        (core.dtype("fp64"), ): ("__ocml_isinf_f64", int32),
    }
    flags = core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
    return flags.to(core.int1, _semantic=_semantic)
157
+
158
+
159
@core.extern
def nextafter(arg0, arg1, _semantic=None):
    """Dispatch an fp32 or fp64 ``(arg0, arg1)`` pair to ``__ocml_nextafter_f32`` / ``__ocml_nextafter_f64``."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, fp32): ("__ocml_nextafter_f32", fp32),
        (fp64, fp64): ("__ocml_nextafter_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], symbols, is_pure=True, _semantic=_semantic)
166
+
167
+
168
@core.extern
def sin(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_sin_f32`` / ``__ocml_sin_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_sin_f32", fp32),
        (fp64, ): ("__ocml_sin_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
175
+
176
+
177
@core.extern
def cos(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_cos_f32`` / ``__ocml_cos_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_cos_f32", fp32),
        (fp64, ): ("__ocml_cos_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
184
+
185
+
186
@core.extern
def tan(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_tan_f32`` / ``__ocml_tan_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_tan_f32", fp32),
        (fp64, ): ("__ocml_tan_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
193
+
194
+
195
@core.extern
def log2(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_log2_f32`` / ``__ocml_log2_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_log2_f32", fp32),
        (fp64, ): ("__ocml_log2_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
202
+
203
+
204
@core.extern
def cosh(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_cosh_f32`` / ``__ocml_cosh_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_cosh_f32", fp32),
        (fp64, ): ("__ocml_cosh_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
211
+
212
+
213
@core.extern
def sinh(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_sinh_f32`` / ``__ocml_sinh_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_sinh_f32", fp32),
        (fp64, ): ("__ocml_sinh_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
220
+
221
+
222
@core.extern
def tanh(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_tanh_f32`` / ``__ocml_tanh_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_tanh_f32", fp32),
        (fp64, ): ("__ocml_tanh_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
229
+
230
+
231
@core.extern
def atan2(arg0, arg1, _semantic=None):
    """Dispatch an fp32 or fp64 ``(arg0, arg1)`` pair to ``__ocml_atan2_f32`` / ``__ocml_atan2_f64``."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, fp32): ("__ocml_atan2_f32", fp32),
        (fp64, fp64): ("__ocml_atan2_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], symbols, is_pure=True, _semantic=_semantic)
238
+
239
+
240
@core.extern
def atan(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_atan_f32`` / ``__ocml_atan_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_atan_f32", fp32),
        (fp64, ): ("__ocml_atan_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
247
+
248
+
249
@core.extern
def asin(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_asin_f32`` / ``__ocml_asin_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_asin_f32", fp32),
        (fp64, ): ("__ocml_asin_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
256
+
257
+
258
@core.extern
def acos(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_acos_f32`` / ``__ocml_acos_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_acos_f32", fp32),
        (fp64, ): ("__ocml_acos_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
265
+
266
+
267
@core.extern
def log(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_log_f32`` / ``__ocml_log_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_log_f32", fp32),
        (fp64, ): ("__ocml_log_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
274
+
275
+
276
@core.extern
def log10(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_log10_f32`` / ``__ocml_log10_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_log10_f32", fp32),
        (fp64, ): ("__ocml_log10_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
283
+
284
+
285
@core.extern
def log1p(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_log1p_f32`` / ``__ocml_log1p_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_log1p_f32", fp32),
        (fp64, ): ("__ocml_log1p_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
292
+
293
+
294
@core.extern
def acosh(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_acosh_f32`` / ``__ocml_acosh_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_acosh_f32", fp32),
        (fp64, ): ("__ocml_acosh_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
301
+
302
+
303
@core.extern
def asinh(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_asinh_f32`` / ``__ocml_asinh_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_asinh_f32", fp32),
        (fp64, ): ("__ocml_asinh_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
310
+
311
+
312
@core.extern
def atanh(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_atanh_f32`` / ``__ocml_atanh_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_atanh_f32", fp32),
        (fp64, ): ("__ocml_atanh_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
319
+
320
+
321
@core.extern
def expm1(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_expm1_f32`` / ``__ocml_expm1_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_expm1_f32", fp32),
        (fp64, ): ("__ocml_expm1_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
328
+
329
+
330
@core.extern
def hypot(arg0, arg1, _semantic=None):
    """Dispatch an fp32 or fp64 ``(arg0, arg1)`` pair to ``__ocml_hypot_f32`` / ``__ocml_hypot_f64``."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, fp32): ("__ocml_hypot_f32", fp32),
        (fp64, fp64): ("__ocml_hypot_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], symbols, is_pure=True, _semantic=_semantic)
337
+
338
+
339
@core.extern
def j0(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_j0_f32`` / ``__ocml_j0_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_j0_f32", fp32),
        (fp64, ): ("__ocml_j0_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
346
+
347
+
348
@core.extern
def j1(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_j1_f32`` / ``__ocml_j1_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_j1_f32", fp32),
        (fp64, ): ("__ocml_j1_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
355
+
356
+
357
@core.extern
def y0(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_y0_f32`` / ``__ocml_y0_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_y0_f32", fp32),
        (fp64, ): ("__ocml_y0_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
364
+
365
+
366
@core.extern
def y1(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_y1_f32`` / ``__ocml_y1_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_y1_f32", fp32),
        (fp64, ): ("__ocml_y1_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
373
+
374
+
375
@core.extern
def cyl_bessel_i0(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_i0_f32`` / ``__ocml_i0_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_i0_f32", fp32),
        (fp64, ): ("__ocml_i0_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
382
+
383
+
384
@core.extern
def cyl_bessel_i1(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_i1_f32`` / ``__ocml_i1_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_i1_f32", fp32),
        (fp64, ): ("__ocml_i1_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
391
+
392
+
393
@core.extern
def erf(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_erf_f32`` / ``__ocml_erf_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_erf_f32", fp32),
        (fp64, ): ("__ocml_erf_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
400
+
401
+
402
@core.extern
def erfinv(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_erfinv_f32`` / ``__ocml_erfinv_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_erfinv_f32", fp32),
        (fp64, ): ("__ocml_erfinv_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
409
+
410
+
411
@core.extern
def erfc(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_erfc_f32`` / ``__ocml_erfc_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_erfc_f32", fp32),
        (fp64, ): ("__ocml_erfc_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
418
+
419
+
420
@core.extern
def erfcx(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_erfcx_f32`` / ``__ocml_erfcx_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_erfcx_f32", fp32),
        (fp64, ): ("__ocml_erfcx_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
427
+
428
+
429
@core.extern
def lgamma(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_lgamma_f32`` / ``__ocml_lgamma_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_lgamma_f32", fp32),
        (fp64, ): ("__ocml_lgamma_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
436
+
437
+
438
@core.extern
def ldexp(arg0, arg1, _semantic=None):
    """Dispatch ``(fp32, int32)`` to ``__ocml_ldexp_f32`` and ``(fp64, int32)`` to ``__ocml_ldexp_f64``.

    The result dtype matches ``arg0``.
    """
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    int32 = core.dtype("int32")
    symbols = {
        (fp32, int32): ("__ocml_ldexp_f32", fp32),
        (fp64, int32): ("__ocml_ldexp_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], symbols, is_pure=True, _semantic=_semantic)
445
+
446
+
447
@core.extern
def fmod(arg0, arg1, _semantic=None):
    """Dispatch an fp32 or fp64 ``(arg0, arg1)`` pair to ``__ocml_fmod_f32`` / ``__ocml_fmod_f64``."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, fp32): ("__ocml_fmod_f32", fp32),
        (fp64, fp64): ("__ocml_fmod_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], symbols, is_pure=True, _semantic=_semantic)
454
+
455
+
456
@core.extern
def fma(arg0, arg1, arg2, _semantic=None):
    """Dispatch an all-fp32 or all-fp64 ``(arg0, arg1, arg2)`` triple to ``__ocml_fma_f32`` / ``__ocml_fma_f64``."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, fp32, fp32): ("__ocml_fma_f32", fp32),
        (fp64, fp64, fp64): ("__ocml_fma_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0, arg1, arg2], symbols, is_pure=True, _semantic=_semantic)
463
+
464
+
465
@core.extern
def pow(arg0, arg1, _semantic=None):
    """Dispatch on the ``(arg0, arg1)`` dtypes to an OCML power symbol.

    An int32 ``arg1`` selects ``__ocml_pown_f32`` / ``__ocml_pown_f64``; an
    ``arg1`` of the same floating dtype as ``arg0`` selects ``__ocml_pow_f32``
    / ``__ocml_pow_f64``. The result dtype matches ``arg0``.
    """
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    int32 = core.dtype("int32")
    symbols = {
        (fp32, int32): ("__ocml_pown_f32", fp32),
        (fp64, int32): ("__ocml_pown_f64", fp64),
        (fp32, fp32): ("__ocml_pow_f32", fp32),
        (fp64, fp64): ("__ocml_pow_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], symbols, is_pure=True, _semantic=_semantic)
474
+
475
+
476
@core.extern
def ilogb(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_ilogb_f32`` / ``__ocml_ilogb_f64``; the result is int32."""
    int32 = core.dtype("int32")
    symbols = {
        (core.dtype("fp32"), ): ("__ocml_ilogb_f32", int32),
        (core.dtype("fp64"), ): ("__ocml_ilogb_f64", int32),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
483
+
484
+
485
@core.extern
def round(arg0, _semantic=None):
    """Dispatch fp32/fp64 ``arg0`` to ``__ocml_round_f32`` / ``__ocml_round_f64``; result dtype matches the input."""
    fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
    symbols = {
        (fp32, ): ("__ocml_round_f32", fp32),
        (fp64, ): ("__ocml_round_f64", fp64),
    }
    return core.extern_elementwise("", "", [arg0], symbols, is_pure=True, _semantic=_semantic)
@@ -0,0 +1,35 @@
1
+ from triton.language import core
2
+
3
+
4
@core.extern
def memrealtime(_semantic=None):
    """
    Return a 64-bit real-time counter value read through inline assembly.

    The instruction sequence is chosen from the builder's ``arch`` option:
    names containing ``gfx11`` or ``gfx12`` use ``s_sendmsg_rtn_b64`` with
    ``MSG_RTN_GET_REALTIME``; every other arch uses ``s_memrealtime``.
    """
    arch = _semantic.builder.options.arch
    if 'gfx11' in arch or 'gfx12' in arch:
        asm = "\ns_sendmsg_rtn_b64 $0, sendmsg(MSG_RTN_GET_REALTIME)\ns_waitcnt lgkmcnt(0)\n"
    else:
        # NOTE(review): this path waits on vmcnt(0) while the branch above
        # waits on lgkmcnt(0) — confirm the intended counter for s_memrealtime.
        asm = "\ns_memrealtime $0\ns_waitcnt vmcnt(0)\n"
    return core.inline_asm_elementwise(
        asm,
        "=r",
        [],
        dtype=core.int64,
        is_pure=False,
        pack=1,
        _semantic=_semantic,
    )