triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
triton/tools/disasm.py ADDED
@@ -0,0 +1,143 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2020 Da Yan @ HKUST
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import functools
24
+ import os
25
+ import re
26
+ import subprocess
27
+ import tempfile
28
+
29
# Instruction line: "/*<4-char offset>*/  <asm> ;  /* 0x<16 hex digits> */".
# group(1) is the assembly text up to and including ';', group(2) the encoding.
FLINE_RE = re.compile(r"\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*")
# Second line of an instruction pair, holding only an encoding word.
SLINE_RE = re.compile(r"\s*/\* 0x(\w{16}) \*/\s*")
# Function header "Function : <name>"; group(1) is the name.
FNAME_RE = re.compile(r"\s*Function : (\w+)\s*")
# Branch: text through "BRA " / "BRA.U " in group(1), hex target in group(2).
BRA_RE = re.compile(r"(.*BRA(?:\.U)? )(0x\w+);")
33
+
34
+
35
def parseCtrl(sline):
    """Decode the control bits of a SASS encoding line into a short string.

    Args:
        sline: the second line of a cuobjdump instruction pair; it must match
            ``SLINE_RE`` (``/* 0x<16 hex digits> */``).

    Returns:
        ``"<wait>:<read>:<write>:<yield>:<stall>"``, built from these bit
        fields of the 64-bit word:

        * bits 52-57 (wait mask): ``--`` when zero, else two decimal digits;
        * bits 49-51 (read barrier): ``-`` when 7, else the digit;
        * bits 46-48 (write barrier): ``-`` when 7, else the digit;
        * bit 45 (yield): ``Y`` when the bit is 0, ``-`` when it is 1;
        * bits 41-44 (stall): lowercase hex digit.
    """
    bits = int(SLINE_RE.match(sline).group(1), 16)

    stall = (bits >> 41) & 0xF
    yield_bit = (bits >> 45) & 0x1
    write_barrier = (bits >> 46) & 0x7
    read_barrier = (bits >> 49) & 0x7
    wait_mask = (bits >> 52) & 0x3F

    fields = [
        f'{wait_mask:02d}' if wait_mask else '--',
        '-' if read_barrier == 7 else str(read_barrier),
        '-' if write_barrier == 7 else str(write_barrier),
        '-' if yield_bit else 'Y',
        format(stall, 'x'),
    ]
    return ':'.join(fields)
48
+
49
+
50
def processSassLines(fline, sline, labels):
    """Turn one cuobjdump instruction pair into a ``(ctrl, asm)`` tuple.

    Args:
        fline: offset/assembly/encoding line (must match ``FLINE_RE``).
        sline: following encoding-only line (must match ``SLINE_RE``).
        labels: dict mapping branch-target address -> label index. Updated in
            place: a new target seen in a ``BRA``/``BRA.U`` gets the next index.

    Returns:
        ``(ctrl, asm)``, where ``ctrl`` comes from ``parseCtrl(sline)`` and
        ``asm`` is the assembly text ending in ``;``.
    """
    asm = FLINE_RE.match(fline).group(1)
    # Remove the trailing space cuobjdump leaves before the semicolon.
    if asm.endswith(" ;"):
        asm = f"{asm[:-2]};"
    ctrl = parseCtrl(sline)
    branch = BRA_RE.match(asm)
    if branch is not None:
        # First-seen targets get consecutive label indices.
        labels.setdefault(int(branch.group(2), 16), len(labels))
    return ctrl, asm
64
+
65
+
66
@functools.lru_cache()
def get_sass(cubin_asm, fun=None):
    """Disassemble an in-memory cubin and return the formatted SASS listing.

    The bytes are written to a temporary file for ``extract`` (which runs
    cuobjdump), and the file is removed afterwards even if disassembly fails.
    Results are memoized by ``functools.lru_cache`` keyed on
    ``(cubin_asm, fun)``, so both arguments must be hashable.
    """
    handle, tmp_path = tempfile.mkstemp()
    try:
        with open(handle, 'wb') as tmp_file:
            tmp_file.write(cubin_asm)
        return extract(tmp_path, fun)
    finally:
        os.remove(tmp_path)
76
+
77
+
78
def path_to_cuobjdump():
    """Return the cuobjdump executable path configured in ``triton.knobs``.

    The import is done at call time; NOTE(review): presumably this avoids an
    import cycle with the ``triton`` package. Confirm before hoisting it.
    """
    from triton import knobs
    cuobjdump_knob = knobs.nvidia.cuobjdump
    return cuobjdump_knob.path
81
+
82
+
83
def extract(file_path, fun):
    """Run ``cuobjdump -sass`` on a cubin file and return a formatted listing.

    Args:
        file_path: path of the cubin file on disk.
        fun: optional function name. When given, cuobjdump is called with
            ``-fun <fun>`` so only that function is dumped.

    Returns:
        The listing of the first ``Function :`` section of cuobjdump's output
        (see ``_format_first_function``), or ``None`` if the output contains
        no function header.

    Raises:
        subprocess.CalledProcessError: if cuobjdump exits with non-zero status.
    """
    cuobjdump = path_to_cuobjdump()
    if fun is None:
        cmd = [cuobjdump, "-sass", file_path]
    else:
        cmd = [cuobjdump, "-fun", fun, "-sass", file_path]
    sass_str = subprocess.check_output(cmd)
    return _format_first_function(sass_str.splitlines())


def _format_first_function(sass_lines):
    """Format the first function section of raw ``cuobjdump -sass`` lines.

    Expected layout of a section (lines are bytes):
        Function : <function_name>
        .headerflags: ...
        /*0000*/  asmstr ;   /* 0x... */
                             /* 0x... */
    Each instruction occupies two lines: offset/assembly/encoding, then a
    second encoding line that carries the control bits.

    The result starts with a "Function:<name>" line. Each instruction line
    follows as "<ctrl>" TAB "<asm>". A "LBB<i>:" line precedes every branch
    target, and BRA targets are rewritten to those label names. The text ends
    with an empty line. Label addresses are computed as index * 16, i.e.
    every instruction is assumed to be 16 bytes.

    Returns None if no function header is present.
    """
    num_lines = len(sass_lines)

    def line_at(idx):
        # Past the end, return '' (which matches none of the regexes) so a
        # listing that ends right after an instruction pair stops the scan
        # instead of raising IndexError.
        return sass_lines[idx].decode() if idx < num_lines else ''

    # Look for the first function header ("Function : <name>").
    line_idx = 0
    while line_idx < num_lines and FNAME_RE.match(line_at(line_idx)) is None:
        line_idx += 1
    if line_idx >= num_lines:
        return None

    fname = FNAME_RE.match(line_at(line_idx)).group(1)
    line_idx += 2  # bypass .headerflags

    labels = {}  # branch target address -> label index
    # (ctrl, asm) pairs are buffered first, then printed once all labels are known.
    asm_buffer = []
    line = line_at(line_idx)
    while FLINE_RE.match(line) is not None and line_idx + 1 < num_lines:
        sline = line_at(line_idx + 1)
        asm_buffer.append(processSassLines(line, sline, labels))
        line_idx += 2
        line = line_at(line_idx)

    # label naming convention: LBB#i
    parts = [f'Function:{fname}\n']
    for idx, (ctrl, asm) in enumerate(asm_buffer):
        offset = idx * 16
        if offset in labels:
            parts.append(f'LBB{labels[offset]}:\n')
        branch = BRA_RE.match(asm)
        if branch is not None:
            # Every BRA target was registered by processSassLines above.
            target_name = f'LBB{labels[int(branch.group(2), 16)]}'
            asm = BRA_RE.sub(rf'\1{target_name};', asm)
        parts.append(f'{ctrl}\t{asm}\n')
    parts.append('\n')
    return ''.join(parts)
@@ -0,0 +1,70 @@
1
+ /* clang-format off */
2
+ #include <stdio.h>
3
+ #include <stdint.h>
4
+ #include <inttypes.h>
5
+ #include <string.h>
6
+ #include <cuda.h>
7
+
8
+
9
+ // helpers to check for cuda errors
10
+ #define CUDA_CHECK(ans) {{\
11
+ gpuAssert((ans), __FILE__, __LINE__);\
12
+ }}\
13
+
14
+ static inline void gpuAssert(CUresult code, const char *file, int line) {{
15
+ if (code != CUDA_SUCCESS) {{
16
+ const char *prefix = "Triton Error [CUDA]: ";
17
+ const char *str;
18
+ cuGetErrorString(code, &str);
19
+ char err[1024] = {{0}};
20
+ strcat(err, prefix);
21
+ strcat(err, str);
22
+ printf("%s\\n", err);
23
+ exit(code);
24
+ }}
25
+ }}
26
+
27
// globals
// Names in single braces are template placeholders filled in at generation
// time; doubled braces become literal braces in the generated C.
// CUBIN_NAME is the embedded kernel image (bin_size bytes of bin_data). The
// module and function handles start as NULL and are set by the load function.
#define CUBIN_NAME {kernel_name}_cubin
CUmodule {kernel_name}_mod = NULL;
CUfunction {kernel_name}_func = NULL;
unsigned char CUBIN_NAME[{bin_size}] = {{ {bin_data} }};


// Unload the module created by the matching load function; exits via
// CUDA_CHECK on failure.
void unload_{kernel_name}(void) {{
  CUDA_CHECK(cuModuleUnload({kernel_name}_mod));
}}
37
+
38
// TODO: some code duplication with `runtime/backend/cuda.c`
// Load the embedded cubin into a module and look up the Triton kernel by
// name. When the kernel needs more than 49152 bytes (48 KiB) of shared
// memory and the device's opt-in limit is larger than that, prefer shared
// memory in the cache config and raise the kernel's dynamic shared-memory
// limit to the device opt-in maximum.
// NOTE(review): the opt-in limit is always queried for device handle 0
// (dev = 0), and cuModuleLoadData presumably needs a current CUDA context.
// Confirm both match how callers use this.
void load_{kernel_name}() {{
  int dev = 0;
  void *bin = (void *)&CUBIN_NAME;
  int shared = {shared};
  CUDA_CHECK(cuModuleLoadData(&{kernel_name}_mod, bin));
  CUDA_CHECK(cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
  // set dynamic shared memory if necessary
  int shared_optin;
  CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev));
  if (shared > 49152 && shared_optin > 49152) {{
    CUDA_CHECK(cuFuncSetCacheConfig({kernel_name}_func, CU_FUNC_CACHE_PREFER_SHARED));
    // No trailing semicolon is needed: CUDA_CHECK expands to a braced block.
    CUDA_CHECK(cuFuncSetAttribute({kernel_name}_func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin))
  }}
}}
53
+
54
+ /*
55
+ {kernel_docstring}
56
+ */
57
+ CUresult {kernel_name}(CUstream stream, {signature}) {{
58
+ if ({kernel_name}_func == NULL)
59
+ load_{kernel_name}();
60
+ unsigned int gX = {gridX};
61
+ unsigned int gY = {gridY};
62
+ unsigned int gZ = {gridZ};
63
+ CUdeviceptr global_scratch = 0;
64
+ CUdeviceptr profile_scratch = 0;
65
+ void *args[{num_args}] = {{ {arg_pointers} }};
66
+ // TODO: shared memory
67
+ if(gX * gY * gZ > 0)
68
+ return cuLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * 32, 1, 1, {shared}, stream, args, NULL);
69
+ return (CUresult)NULL;
70
+ }}
@@ -0,0 +1,14 @@
1
/* The include guard below covers only the shared includes. The declarations
   after the closing endif are outside it, so they are still emitted when
   another generated header has already defined TT_KERNEL_INCLUDES. */
#ifndef TT_KERNEL_INCLUDES
#define TT_KERNEL_INCLUDES

#include <cuda.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

#endif

void unload_{kernel_name}(void);
void load_{kernel_name}(void);
/* The next comment line is a machine-readable directive parsed by
   tools/link.py (HeaderParser); keep its exact format. _placeholder is
   filled in by the generator, and its purpose is not visible here. */
// tt-linker: {kernel_name}:{full_signature}:{algo_info}
CUresult{_placeholder} {kernel_name}(CUstream stream, {signature});
@@ -0,0 +1,66 @@
1
+ // SPDX-License-Identifier: MIT
2
+ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+ /* clang-format off */
5
+ #include <stdio.h>
6
+ #include <stdint.h>
7
+ #include <inttypes.h>
8
+ #include <string.h>
9
+ #include <hip/hip_runtime.h>
10
+
11
+ // helpers to check for hip errors
12
+ #define HIP_CHECK(ans) {{\
13
+ gpuAssert((ans), __FILE__, __LINE__);\
14
+ }}\
15
+
16
+ static inline void gpuAssert(hipError_t code, const char *file, int line) {{
17
+ if (code != hipSuccess) {{
18
+ const char *prefix = "Triton Error [HIP]: ";
19
+ const char *str;
20
+ hipDrvGetErrorString(code, &str);
21
+ char err[1024] = {{0}};
22
+ strcat(err, prefix);
23
+ strcat(err, str);
24
+ printf("%s\\n", err);
25
+ exit(code);
26
+ }}
27
+ }}
28
+
29
// globals
// Names in single braces are template placeholders filled in at generation
// time; doubled braces become literal braces in the generated code.
// HSACO_NAME is the embedded code object (bin_size bytes of bin_data). The
// module and function handles start as nullptr and are set by the load
// function.
#define HSACO_NAME {kernel_name}_hsaco
hipModule_t {kernel_name}_mod = nullptr;
hipFunction_t {kernel_name}_func = nullptr;
unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};


// Unload the module created by the matching load function; exits via
// HIP_CHECK on failure.
void unload_{kernel_name}(void) {{
  HIP_CHECK(hipModuleUnload({kernel_name}_mod));
}}
39
+
40
+
41
+ void load_{kernel_name}() {{
42
+ int dev = 0;
43
+ void *bin = (void *)&HSACO_NAME;
44
+ int shared = {shared};
45
+ HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin));
46
+ HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
47
+ }}
48
+
49
+ /*
50
+ {kernel_docstring}
51
+ */
52
+ hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
53
+ if ({kernel_name}_func == nullptr)
54
+ load_{kernel_name}();
55
+ unsigned int gX = {gridX};
56
+ unsigned int gY = {gridY};
57
+ unsigned int gZ = {gridZ};
58
+ hipDeviceptr_t global_scratch = 0;
59
+ hipDeviceptr_t profile_scratch = 0;
60
+ void *args[{num_args}] = {{ {arg_pointers} }};
61
+ // TODO: shared memory
62
+ if(gX * gY * gZ > 0)
63
+ return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr);
64
+ else
65
+ return hipErrorInvalidValue;
66
+ }}
@@ -0,0 +1,13 @@
1
+ // SPDX-License-Identifier: MIT
2
+ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
#pragma once

#include <hip/hip_runtime.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

// Declarations for one generated HIP kernel. Names in single braces are
// template placeholders; _placeholder is filled in by the generator and its
// purpose is not visible here.
// NOTE(review): unlike the CUDA header template, this one emits no linker
// directive comment, so tools/link.py's HeaderParser finds nothing in it.
// Confirm that linking HIP kernels is intentionally unsupported.
void unload_{kernel_name}(void);
void load_{kernel_name}(void);
hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature});
triton/tools/link.py ADDED
@@ -0,0 +1,322 @@
1
+ from collections import defaultdict
2
+ from pathlib import Path
3
+ from typing import Sequence, Union
4
+
5
+ from dataclasses import dataclass
6
+
7
+
8
+ def _exists(x):
9
+ return x is not None
10
+
11
+
12
class LinkerError(Exception):
    """Raised when a header's linker directive is malformed or conflicts with an earlier one."""
14
+
15
+
16
@dataclass
class KernelLinkerMeta:
    """Metadata for one compiled kernel variant, parsed from a linker directive.

    The decorated kernel name has the form ``<orig_kernel_name>_<sig_hash>_<suffix>``.
    ``sizes`` holds one entry per argument: 16, 1, or None (unspecialized).
    ``HeaderParser`` sets ``triton_suffix`` and ``suffix`` to the same value.
    """

    orig_kernel_name: str
    arg_names: Sequence[str]
    arg_ctypes: Sequence[str]
    sizes: Sequence[Union[int, None]]
    sig_hash: str
    triton_suffix: str
    suffix: str
    num_specs: int  # number of specialized arguments
27
+
28
+
29
+ class HeaderParser:
30
+
31
+ def __init__(self) -> None:
32
+ import re
33
+
34
+ # [kernel_name, c signature]
35
+ self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)")
36
+ # [name, hash, suffix]
37
+ self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
38
+ # [(type, name)]
39
+ self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?")
40
+ # [d|c]
41
+ self.arg_suffix = re.compile("[c,d]")
42
+
43
+ self.kernels = defaultdict(list)
44
+
45
+ def extract_linker_meta(self, header: str):
46
+ for ln in header.splitlines():
47
+ if ln.startswith("//"):
48
+ m = self.linker_directives.match(ln)
49
+ if _exists(m):
50
+ ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3)
51
+ name, sig_hash, suffix = self._match_name(ker_name)
52
+ c_types, arg_names = self._match_c_sig(c_sig)
53
+ num_specs, sizes = self._match_suffix(suffix, c_sig)
54
+ self._add_kernel(
55
+ "_".join([name, algo_info]),
56
+ KernelLinkerMeta(
57
+ orig_kernel_name=name,
58
+ arg_names=arg_names,
59
+ arg_ctypes=c_types,
60
+ sizes=sizes,
61
+ sig_hash=sig_hash,
62
+ triton_suffix=suffix,
63
+ suffix=suffix,
64
+ num_specs=num_specs,
65
+ ),
66
+ )
67
+
68
+ def _match_name(self, ker_name: str):
69
+ m = self.kernel_name.match(ker_name)
70
+ if _exists(m):
71
+ name, sig_hash, suffix = m.group(1), m.group(2), m.group(3)
72
+ return name, sig_hash, suffix
73
+ raise LinkerError(f"{ker_name} is not a valid kernel name")
74
+
75
+ def _match_c_sig(self, c_sig: str):
76
+ m = self.c_sig.findall(c_sig)
77
+ if len(m):
78
+ tys, args = [], []
79
+ for ty, arg_name in m:
80
+ tys.append(ty)
81
+ args.append(arg_name)
82
+ return tys, args
83
+
84
+ raise LinkerError(f"{c_sig} is not a valid argument signature")
85
+
86
+ def _match_suffix(self, suffix: str, c_sig: str):
87
+ args = c_sig.split(",")
88
+ s2i = {"c": 1, "d": 16}
89
+ num_specs = 0
90
+ sizes = []
91
+ # scan through suffix, first find the index,
92
+ # then see if it is followed by d or c
93
+ for i in range(len(args)):
94
+ pos = suffix.find(str(i))
95
+ if pos == -1:
96
+ raise LinkerError(f"{suffix} is not a valid kernel suffix")
97
+ pos += len(str(i))
98
+ if self.arg_suffix.match(suffix, pos):
99
+ num_specs += 1
100
+ sizes.extend([None] * (i - len(sizes)))
101
+ sizes.append(s2i[suffix[pos]])
102
+ pos += 1
103
+ if i < len(args) - 1:
104
+ suffix = suffix[pos:]
105
+ else:
106
+ sizes.extend([None] * (len(args) - len(sizes)))
107
+ return num_specs, sizes
108
+
109
+ def _add_kernel(self, name: str, ker: KernelLinkerMeta):
110
+ if name in self.kernels:
111
+ last: KernelLinkerMeta = self.kernels[name][-1]
112
+
113
+ for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes):
114
+ if cur != new_:
115
+ raise LinkerError(
116
+ f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
117
+ )
118
+
119
+ self.kernels[name].append(ker)
120
+
121
+
122
def gen_signature_with_full_args(m):
    """Return the full C parameter list ``"<type> <name>, ..."`` of *m*, including specialized args."""
    params = map("{} {}".format, m.arg_ctypes, m.arg_names)
    return ", ".join(params)
124
+
125
+
126
def gen_signature(m):
    """Return the C parameter list of *m* without arguments whose hint is 1.

    Arguments specialized to the value 1 are baked into the compiled variant,
    so they are dropped from its signature.
    """
    kept = [
        f"{ty} {arg}"
        for ty, arg, hint in zip(m.arg_ctypes, m.arg_names, m.sizes)
        if hint != 1
    ]
    return ", ".join(kept)
131
+
132
+
133
# generate declarations of kernels with meta-parameter and constant values
def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
    """Return C prototypes for the hint dispatcher *name* and its load/unload functions.

    The full signature is taken from the last entry of *metas*.
    """
    full_sig = gen_signature_with_full_args(metas[-1])
    return f"""
CUresult {name}(CUstream stream, {full_sig});
void load_{name}();
void unload_{name}();
    """
140
+
141
+
142
# generate declarations of kernels with meta-parameter and constant values
def make_global_decl(meta: KernelLinkerMeta) -> str:
    """Return C prototypes for the global entry points of ``meta.orig_kernel_name``.

    Declares the ``_default`` wrapper, the ``algo_id`` dispatcher, and the
    global load/unload functions.
    """
    kname = meta.orig_kernel_name
    full_sig = gen_signature_with_full_args(meta)
    return f"""
CUresult {kname}_default(CUstream stream, {full_sig});
CUresult {kname}(CUstream stream, {full_sig}, int algo_id);
void load_{kname}();
void unload_{kname}();
    """
150
+
151
+
152
# generate the <kernel>_default wrapper that dispatches to algo 0
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
    """Return the C definition of ``<kernel>_default``, which calls the dispatcher with ``algo_id = 0``."""
    kname = meta.orig_kernel_name
    forwarded = ", ".join(meta.arg_names)
    lines = [
        f"CUresult {kname}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{",
        f"  return {kname}(stream, {forwarded}, 0);",
        "}",
    ]
    return "\n".join(lines) + "\n"
158
+
159
+
160
# generate dispatcher function for kernels with different integer value hints
def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
    """Return C source that picks the best-specialized variant of one kernel at run time.

    *name* is the dispatcher name (``"<kernel>_<algo info>"``). *metas* holds
    the compiled variants. Variants are tried from most to least specialized:
    hint 16 requires ``arg % 16 == 0`` and hint 1 requires ``arg == 1``. A
    variant with no hints matches unconditionally. If no variant matches, the
    dispatcher returns ``CUDA_ERROR_INVALID_VALUE``. Load/unload wrappers for
    all variants are emitted as well.

    Raises:
        LinkerError: if a hint other than 1 or 16 is encountered.
    """

    def cond(val, hint):
        # One C condition per specialized argument.
        if hint == 16:
            return f"({val} % {hint} == 0)"
        if hint == 1:
            return f"({val} == {hint})"
        raise LinkerError(f"unsupported specialization hint {hint} for argument {val} of {name}")

    def variant(meta):
        # Decorated name of one compiled variant.
        return f"{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}"

    # Most specialized first, so the first matching condition wins.
    ordered = sorted(metas, key=lambda m: -m.num_specs)

    src = f"// launcher for: {name}\n"
    for meta in ordered:
        src += f"CUresult {variant(meta)}(CUstream stream, {gen_signature(meta)});\n"
    src += "\n"

    src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{\n"
    for meta in ordered:
        conds = " && ".join(cond(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None)
        # Edge case: no specializations, hence no condition to test.
        src += f"  if ({conds or '1'})\n"
        # Arguments specialized to 1 are not part of the variant's signature.
        arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
        src += f"    return {variant(meta)}(stream, {', '.join(arg_names)});\n"
        src += "\n"
    src += "  return CUDA_ERROR_INVALID_VALUE;\n"
    src += "}\n"

    for mode in ("load", "unload"):
        src += f"\n// {mode} for: {name}\n"
        for meta in ordered:
            src += f"void {mode}_{variant(meta)}();\n"
        src += f"void {mode}_{name}() {{\n"
        for meta in ordered:
            src += f"  {mode}_{variant(meta)}();\n"
        src += "}\n"
    return src
199
+
200
+
201
# generate dispatcher function for kernels with different meta-parameter and constant values
def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str:
    """Return the C definition of ``<kernel>(stream, ..., int algo_id)``.

    The generated function indexes the ``<kernel>_kernels`` function-pointer
    table with ``algo_id``. The bounds assert compares against the number of
    table entries (``sizeof(table) / sizeof(table[0])``); a bare
    ``sizeof(table)`` is a byte count and would accept out-of-range ids.
    Negative ids are rejected too. As with any C ``assert``, the check is
    compiled out under ``NDEBUG``.
    """
    table = f"{meta.orig_kernel_name}_kernels"
    src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n"
    src += f"  assert (algo_id >= 0 && algo_id < (int)(sizeof({table}) / sizeof({table}[0])));\n"
    src += f"  return {table}[algo_id](stream, {', '.join(meta.arg_names)});\n"
    src += "}\n"
    return src
208
+
209
+
210
# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values
def make_func_pointers(names: Sequence[str], meta: KernelLinkerMeta) -> str:
    """Return C source for the ``<kernel>_kernels`` table of per-algo dispatchers.

    Args:
        names: dispatcher names in table order; the position of a name is
            its ``algo_id``.
        meta: metadata whose full signature defines the shared
            ``kernel_func_t`` pointer type. All dispatchers in *names* must
            therefore share this signature.
    """
    # the table of hint dispatchers
    src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n"
    src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n"
    for name in names:
        src += f"  {name},\n"
    src += "};\n"
    return src
219
+
220
+
221
# generate definition for load/unload functions for kernels with different meta-parameter and constant values
def make_kernel_load_def(names: Sequence[str], meta: KernelLinkerMeta) -> str:
    """Return C definitions of ``load_<kernel>(void)`` and ``unload_<kernel>(void)``.

    Each generated function calls ``load_<name>()`` / ``unload_<name>()`` for
    every dispatcher name in *names*, in order.

    Args:
        names: per-algo dispatcher names (a sequence of strings).
        meta: metadata providing ``orig_kernel_name``.
    """
    src = ""
    for mode in ["load", "unload"]:
        src += f"void {mode}_{meta.orig_kernel_name}(void){{\n"
        for name in names:
            src += f"  {mode}_{name}();\n"
        src += "}\n\n"
    return src
230
+
231
+
232
def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str:
    """Return the C prototype of ``<kernel>_get_num_algos``."""
    return "int {}_get_num_algos(void);".format(meta.orig_kernel_name)
235
+
236
+
237
def make_get_num_algos_def(meta: KernelLinkerMeta) -> str:
    """Return the C definition of ``<kernel>_get_num_algos``.

    The generated function returns the number of entries in the
    ``<kernel>_kernels`` table.
    """
    table = f"{meta.orig_kernel_name}_kernels"
    return (
        f"int {meta.orig_kernel_name}_get_num_algos(void){{\n"
        f"  return (int)(sizeof({table}) / sizeof({table}[0]));\n"
        "}\n"
    )
242
+
243
+
244
# Description text for the command-line interface (passed to ArgumentParser).
desc = (
    "\n"
    "Triton ahead-of-time linker:\n"
    "\n"
    "This program takes in header files generated by compile.py, and generates a\n"
    "single entry-point responsible for dispatching the user's input to the right\n"
    "kernel given the specializations that were compiled.\n"
    "\n"
    "Example usage:\n"
    "python link.py /path/to/headers/*.h -o kernel_name\n"
)
254
+
255
if __name__ == "__main__":
    from argparse import ArgumentParser

    arg_parser = ArgumentParser(description=desc)
    arg_parser.add_argument(
        "headers",
        nargs="+",
        help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
    )
    arg_parser.add_argument("--out", "-o", type=Path, help="Out filename")
    # NOTE(review): --prefix is parsed but never applied to the generated names.
    arg_parser.add_argument(
        "--prefix",
        type=str,
        default="",
        help="String to prefix kernel dispatcher names",
    )
    args = arg_parser.parse_args()
    if args.out is None:
        # Both output files are derived from --out; report a usage error
        # instead of failing later with an AttributeError on None.
        arg_parser.error("the following arguments are required: --out/-o")

    # metadata
    header_parser = HeaderParser()
    for header in args.headers:
        header_parser.extract_linker_meta(Path(header).read_text())
    if not header_parser.kernels:
        arg_parser.error("no linker directives found in the given headers")

    # The global entry points (pointer table, algo_id dispatcher, default
    # wrapper) are generated from the first parsed kernel's metadata.
    meta = next(iter(header_parser.kernels.values()))[0]
    names = list(header_parser.kernels)

    # generate headers
    algo_decls = [make_algo_decls(name, metas) for name, metas in header_parser.kernels.items()]
    get_num_algos_decl = make_get_num_algos_decl(meta)
    global_decl = make_global_decl(meta)
    with args.out.with_suffix(".h").open("w") as fp:
        out = "#include <cuda.h>\n"
        out += "\n".join(algo_decls)
        out += "\n"
        out += get_num_algos_decl
        out += "\n"
        out += global_decl
        fp.write(out)

    # generate source
    defs = [make_kernel_hints_dispatcher(name, metas) for name, metas in header_parser.kernels.items()]
    func_pointers_def = make_func_pointers(names, meta)
    meta_const_def = make_kernel_meta_const_dispatcher(meta)
    load_unload_def = make_kernel_load_def(names, meta)
    get_num_algos_def = make_get_num_algos_def(meta)
    default_algo_kernel = make_default_algo_kernel(meta)
    with args.out.with_suffix(".c").open("w") as fp:
        out = ""
        out += "#include <cuda.h>\n"
        out += "#include <stdint.h>\n"
        out += "#include <assert.h>\n"
        out += "\n"
        out += "\n".join(defs)
        out += "\n"
        out += func_pointers_def
        out += "\n"
        out += get_num_algos_def
        out += "\n"
        out += meta_const_def
        out += "\n"
        out += load_unload_def
        out += "\n"
        out += default_algo_kernel
        fp.write(out)