triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,365 @@
1
+ import argparse
2
+ import subprocess
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, List, Optional
5
+
6
+
7
class Symbol:
    '''
    One extern function declaration together with its Triton-facing signature.

    All fields are read through properties; the argument-name and argument-type
    sequences are copied on construction so later changes by the caller do not
    leak into the symbol.
    '''
    _name: str
    _op_name: str
    _ret_type: str
    _arg_names: List[str]
    _arg_types: List[str]

    def __init__(
        self,
        name: str,
        op_name: str,
        ret_type: str,
        arg_names: List[str],
        arg_types: List[str],
    ) -> None:
        '''
        :param name: name of the symbol as it appears in the library
        :param op_name: name of the operation exposed to Triton
        :param ret_type: Triton dtype name of the return value
        :param arg_names: names of the arguments (copied)
        :param arg_types: Triton dtype names of the arguments (copied)
        '''
        self._name, self._op_name, self._ret_type = name, op_name, ret_type
        # Defensive copies: the caller keeps ownership of the sequences it passed.
        self._arg_names = [*arg_names]
        self._arg_types = [*arg_types]

    @property
    def name(self) -> str:
        '''Symbol name as found in the library.'''
        return self._name

    @property
    def op_name(self) -> str:
        '''Operation name exposed to Triton.'''
        return self._op_name

    @property
    def ret_type(self) -> str:
        '''Triton dtype name of the return value.'''
        return self._ret_type

    @property
    def arg_names(self) -> List[str]:
        '''Argument names (the stored list itself, not a copy).'''
        return self._arg_names

    @property
    def arg_types(self) -> List[str]:
        '''Argument dtype names (the stored list itself, not a copy).'''
        return self._arg_types
55
+
56
+
57
def convert_type(type_str) -> Optional[str]:
    '''
    Translate an LLVM IR scalar type spelling into a Triton dtype name.

    Only 32/64-bit integers and single/double precision floats are recognised.
    Any other spelling (pointer types included) yields None so callers can skip
    the declaration.
    '''
    known_types = (
        ("i32", "int32"),
        ("u32", "uint32"),
        ("i64", "int64"),
        ("u64", "uint64"),
        ("float", "fp32"),
        ("double", "fp64"),
    )
    for llvm_name, triton_name in known_types:
        if type_str == llvm_name:
            return triton_name
    # ignore other types, such as pointer types
    return None
73
+
74
+
75
def to_unsigned(type_str) -> str:
    '''
    Return the unsigned counterpart of a signed Triton integer dtype name.

    Only "int32" and "int64" are mapped; every other value is returned unchanged.
    '''
    for signed_name, unsigned_name in (("int32", "uint32"), ("int64", "uint64")):
        if type_str == signed_name:
            return unsigned_name
    return type_str
82
+
83
+
84
class ExternLibrary(ABC):
    '''
    Base class for an extern library whose symbols are emitted as a Python stub module.

    Subclasses implement ``parse_symbols`` and ``_output_stubs``; ``generate_stub_file``
    writes the text returned by ``_output_stubs`` to disk.
    '''
    _name: str
    _path: str
    # Quoted so this class does not require Symbol to be defined first.
    _symbols: Dict[str, "Symbol"]
    _format: bool
    _grouping: bool

    def __init__(
        self,
        name: str,
        path: str,
        format: bool = True,
        grouping: bool = True,
    ) -> None:
        '''
        Abstract class for extern library.
        :param name: name of the library; also the base name of the generated stub file
        :param path: path of the library
        :param format: whether to format the generated stub file with autopep8 and isort
        :param grouping: stored and exposed through the ``grouping`` property
                         (not read by this class itself)
        '''
        self._name = name
        self._path = path
        self._symbols = {}
        self._format = format
        self._grouping = grouping

    @property
    def name(self) -> str:
        return self._name

    @property
    def path(self) -> str:
        return self._path

    @property
    def symbols(self) -> Dict[str, "Symbol"]:
        return self._symbols

    @property
    def grouping(self) -> bool:
        return self._grouping

    @abstractmethod
    def parse_symbols(self, input_file) -> None:
        '''Read symbol declarations from ``input_file`` (subclass-defined).'''
        pass

    @abstractmethod
    def _output_stubs(self) -> str:
        '''Return the complete source text of the stub module.'''
        pass

    def generate_stub_file(self, output_dir) -> None:
        '''
        Write the stub source to ``<output_dir>/<name>.py`` and optionally format it.
        :param output_dir: directory the stub file is written into
        :raises ValueError: if ``_output_stubs`` returns None or an empty string
        '''
        file_str = self._output_stubs()
        if not file_str:
            raise ValueError("file_str is empty")

        output_file = f"{output_dir}/{self._name}.py"
        # The ``with`` statement closes the file; an explicit close() is redundant.
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(file_str)
        if self._format:
            # Formatter output is captured and discarded, and their exit status is
            # not checked (formatting is best-effort).
            for command in (["autopep8", "-a", "-r", "-i", output_file], ["isort", output_file]):
                subprocess.run(command, stdout=subprocess.PIPE, check=False)
146
+
147
+
148
class Libdevice(ExternLibrary):
    '''
    Extern library wrapper for libdevice.

    Symbols are read from a textual LLVM IR file, grouped under shared operation
    names (e.g. ``sinf`` and ``sin`` both become ``sin``) and rendered into a stub
    module by ``_output_stubs``.
    '''
    _symbol_groups: Dict[str, List[Symbol]]

    # Raw op names (``__nv_`` prefix stripped) that are grouped under another name.
    # Op names not listed here keep their own name.
    _RENAMING: Dict[str, str] = {
        'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn':
        'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz':
        'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh',
        'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos',
        'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
        'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru',
        'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf':
        'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2',
        'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll':
        'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru',
        'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff':
        'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
        'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f':
        'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax':
        'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min',
        'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn',
        'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24',
        'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf':
        'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv',
        'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd',
        'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru',
        'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
        'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt',
        'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit',
        'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd':
        'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru',
        'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn',
        'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
        'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf':
        'yn'
    }

    def __init__(self, path) -> None:
        '''
        Constructor for Libdevice.
        :param path: path of the libdevice library
        '''
        super().__init__("libdevice", path)
        self._symbol_groups = {}
        # Rendered into every generated call as ``is_pure=...``.
        self.is_pure = True

    @staticmethod
    def _extract_symbol(line) -> Optional[Symbol]:
        '''
        Parse one textual LLVM IR line into a Symbol.

        Expected format: "define [internal] <ret_type> @<name>(<arg_types>,)".
        Returns None for lines of any other shape, for internal symbols, for op
        names containing "ieee", and whenever the return type or an argument type
        is not recognised by ``convert_type``.
        '''
        entries = line.split("@")
        ret_strs = entries[0].split()
        # parse_symbols passes on every line that merely contains "define";
        # keep only lines that actually look like a function definition.
        if len(entries) < 2 or len(ret_strs) < 2 or ret_strs[0] != "define":
            return None
        # Get ret_type, skip internal symbols
        if ret_strs[1] == "internal":
            return None
        ret_type = convert_type(ret_strs[1])
        if ret_type is None:
            return None
        # Get function name
        func_strs = entries[1].split("(")
        if len(func_strs) < 2:
            return None
        func_name = func_strs[0]
        op_name = func_name.replace("__nv_", "")
        if 'ieee' in op_name:
            return None
        # Get arg_types
        arg_types = []
        arg_names = []
        for i, arg_str in enumerate(func_strs[1].split(",")):
            tokens = arg_str.split()
            arg_type = convert_type(tokens[0]) if tokens else None
            if arg_type is None:
                return None
            arg_types.append(arg_type)
            arg_names.append(f"arg{i}")
        if op_name == "sad":
            # Special case for sad, where the last argument is an unsigned int
            arg_types[-1] = to_unsigned(arg_types[-1])
        elif op_name.startswith("u"):
            # LLVM does not differentiate between signed and unsigned integer type.
            # We have to convert the types to unsigned
            ret_type = to_unsigned(ret_type)
            arg_types = [to_unsigned(arg_type) for arg_type in arg_types]
        return Symbol(func_name, op_name, ret_type, arg_names, arg_types)

    def _group_symbols(self) -> None:
        '''
        Bucket symbols by (possibly renamed) op name into ``_symbol_groups``.

        The grouped name is written back onto each symbol so that ``_output_stubs``
        emits one function per group.
        '''
        for symbol in self._symbols.values():
            op_name = self._RENAMING.get(symbol.op_name, symbol.op_name)
            symbol._op_name = op_name
            self._symbol_groups.setdefault(op_name, []).append(symbol)

    def parse_symbols(self, input_file) -> None:
        '''
        Populate ``symbols`` from a textual LLVM IR file; a no-op if already populated.
        :param input_file: path of the ``.ll`` file to scan
        '''
        if self.symbols:
            return
        # Filter in Python rather than shelling out to ``grep`` (not available on
        # Windows); like grep, keep every line that contains "define".
        with open(input_file, encoding="utf-8") as f:
            output = [line for line in f.read().splitlines() if "define" in line]
        for line in output:
            symbol = self._extract_symbol(line)
            if symbol is None:
                continue
            self._symbols[symbol.name] = symbol

        self._group_symbols()

    def _output_stubs(self) -> str:
        # Generate python functions in the following format:
        # @extern.extern
        # def <op_name>(<args>, _builder=None):
        #   arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
        #   return core.extern_elementwise("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
        # NOTE(review): the generated code calls ``libdevice_path()`` but only imports
        # ``core``; confirm the emitted module is post-edited before use.
        import_str = "from . import core\n"

        func_chunks = []
        for symbols in self._symbol_groups.values():
            head = symbols[0]
            args_str = "".join(f"{arg_name}, " for arg_name in head.arg_names)

            # One "(<arg dtypes>): (<symbol name>, <ret dtype>)" entry per overload.
            dict_entries = []
            for symbol in symbols:
                dtypes = "".join(f'core.dtype("{arg_type}"),' for arg_type in symbol.arg_types)
                dict_entries.append(f'({dtypes}): ("{symbol.name}", core.dtype("{symbol.ret_type}")),\n')
            arg_type_symbol_dict_str = "{" + "".join(dict_entries) + "}"

            func_chunks.append(
                "@core.extern\n"
                f"def {head.op_name}({args_str}_builder=None):\n"
                f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [{args_str}], \n"
                f"{arg_type_symbol_dict_str}, is_pure={self.is_pure}, _builder=_builder)\n"
                "\n")

        return import_str + "".join(func_chunks)
303
+
304
+
305
class LLVMDisassembler:
    '''Runs ``llvm-dis`` to turn an LLVM bitcode file into a textual ``.ll`` file.'''
    _path: str
    _ll_file: str

    def __init__(self, path) -> None:
        '''
        Invoke llvm-dis to disassemble the given file.
        :param path: path to llvm-dis
        '''
        import os
        import tempfile

        self._path = path
        # Use the platform temp directory instead of a hard-coded "/tmp", which
        # typically does not exist on Windows.
        self._ll_file = os.path.join(tempfile.gettempdir(), "extern_lib.ll")

    def disasm(self, lib_path: str) -> None:
        '''
        Disassemble ``lib_path`` into ``ll_file``.
        :param lib_path: path of the bitcode file to disassemble
        :raises subprocess.CalledProcessError: if llvm-dis exits with a non-zero status
        '''
        # check=True: otherwise a failed run would leave a missing or stale .ll
        # file that later gets parsed without any error.
        subprocess.run([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE, check=True)

    @property
    def ll_file(self) -> str:
        return self._ll_file

    @property
    def path(self) -> str:
        return self._path
327
+
328
+
329
# Extern library names known to this script; ``build`` currently handles only "libdevice".
extern_libs: List[str] = ["libdevice"]
330
+
331
+
332
def build(
    llvm_dis_path: str,
    lib_path: str,
    lib_name: str,
    output_dir: str,
) -> None:
    '''
    Disassemble an extern library and write its Python stub module.
    :param llvm_dis_path: path to the llvm-dis binary
    :param lib_path: path to the external library file
    :param lib_name: name of the library; only "libdevice" is supported
    :param output_dir: path to the output directory
    :raises Exception: if ``lib_name`` is not a supported library
    '''
    # Reject unsupported libraries before doing any work.
    if lib_name != "libdevice":
        raise Exception(f"Unknown extern library: {lib_name}")
    library = Libdevice(lib_path)

    disassembler = LLVMDisassembler(llvm_dis_path)
    disassembler.disasm(lib_path)
    ll_path = disassembler.ll_file

    library.parse_symbols(ll_path)
    library.generate_stub_file(output_dir)
355
+
356
+
357
if __name__ == "__main__":
    import tempfile

    parser = argparse.ArgumentParser()
    parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis")
    # Both values are needed by build(); without them it fails later with an obscure error.
    parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library", required=True)
    parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library", choices=extern_libs,
                        required=True)
    # The stub is written to "<output>/<lib-name>.py", so this is a directory. Default to
    # the platform temp directory rather than "/tmp/", which typically does not exist on Windows.
    parser.add_argument("--output", dest="output_dir", help="Output directory", default=tempfile.gettempdir())
    args = parser.parse_args()

    build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)
@@ -0,0 +1,210 @@
1
+ import binascii
2
+ import hashlib
3
+ import importlib.util
4
+ import sys
5
+ from argparse import ArgumentParser
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import List
9
+
10
+ import triton
11
+ import triton.backends
12
+
13
+
14
+ @dataclass
15
+ class CompileArgs:
16
+ '''
17
+ A class to contain arguments from command-line parser.
18
+ '''
19
+ path: str = ''
20
+ kernel_name: str = ''
21
+ signature: str = ''
22
+ grid: str = ''
23
+ target: str | None = None
24
+ num_warps: int = 1
25
+ num_stages: int = 3
26
+ out_name: str | None = None
27
+ out_path: Path | None = None
28
+
29
+
30
# Help text shown by the command-line parser (passed as the ArgumentParser description).
desc = """
Triton ahead-of-time compiler:

This program compiles the kernel with name `kernel-name` in the file at the
provided `path` into self-contained C source-code that embeds the `cubin`
data along with utilities to load, unload and launch the kernel.

The signature is provided as a list of (optionally divisibility-hinted) types
or constexpr values, e.g.

`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --grid "1,1,1" --out-name kernel /path/to/kernel.py`

will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`.
Said kernel will be specialized such that arguments 0 and 1 are assumed to be multiples of 16,
and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype.

The resulting entry point will have signature

CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg3)

Different such specialized entry points can be combined using the `link.py` script.

NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed (with its parent directory
prepended to sys.path) by the python interpreter used to run this `compile.py` script
"""
55
+
56
+
57
def main():
    '''Command-line entry point: parse the options into CompileArgs and compile the kernel.'''
    parser = ArgumentParser(description=desc)
    add = parser.add_argument
    add("path", help="Path to Python source containing desired kernel in its scope. File will be executed.")
    add("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True)
    add("--target", "-t", type=str, default=None,
        help="The target to compile towards, in format of '<backend>:<arch>:<warp-size>'; "
        "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target")
    add("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
    add("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)")
    add("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
    add("--out-path", "-o", type=Path, default=None, help="Out filename")
    add("--signature", "-s", type=str, help="Signature of the kernel", required=True)
    add("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
    namespace = parser.parse_args()
    # Building CompileArgs from the namespace also checks that the dataclass
    # fields stay in sync with the parser's options.
    compile_kernel(CompileArgs(**vars(namespace)))
78
+
79
+
80
def compile_kernel(args: CompileArgs):
    '''
    Compile one Triton kernel ahead of time and write C sources that embed and launch it.

    :param args: parsed command-line options (see CompileArgs)
    :returns: ``(func_name, output_files)``: the generated C entry-point name and the
        paths of the files rendered from the backend's ``compile.*`` templates
    :raises ImportError: if ``args.path`` cannot be loaded as a Python module
    :raises ValueError: if the signature does not have exactly one entry per kernel
        argument, or ``args.target`` is not of the form '<backend>:<arch>:<warp-size>'
    :raises RuntimeError: if the compiled kernel needs global or profile scratch memory
    '''
    out_name = args.out_name if args.out_name else args.kernel_name
    out_path = args.out_path if args.out_path else Path(out_name)

    # execute python sources and extract functions wrapped in JITFunction
    arg_path = Path(args.path)
    # Make the kernel file's sibling modules importable while it executes.
    # NOTE(review): the entry is never removed from sys.path -- confirm that is intended.
    sys.path.insert(0, str(arg_path.parent))
    spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when no loader matches the file suffix.
        raise ImportError(f"cannot load {args.path!r} as a Python module")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    kernel = getattr(mod, args.kernel_name)
    grid = args.grid.split(",")
    assert len(grid) == 3, f"--grid must have exactly 3 comma-separated entries, got {args.grid!r}"

    # validate and parse signature
    signature = [s.strip(" ") for s in args.signature.split(",")]
    # Every kernel argument (constexprs included) needs exactly one signature entry;
    # otherwise the lookups below fail with an IndexError or KeyError.
    if len(signature) != len(kernel.arg_names):
        raise ValueError(f"signature has {len(signature)} entries but kernel {args.kernel_name!r} "
                         f"takes {len(kernel.arg_names)} arguments: {kernel.arg_names}")

    def hash_signature(signature: List[str]):
        m = hashlib.sha256()
        m.update(" ".join(signature).encode())
        return m.hexdigest()[:8]

    meta_sig = f"warps{args.num_warps}xstages{args.num_stages}"
    sig_hash = hash_signature(signature + [meta_sig])

    def constexpr(s):
        # Parse a literal as int, then float; None means "not a literal".
        try:
            return int(s)
        except ValueError:
            pass
        try:
            return float(s)
        except ValueError:
            pass
        return None

    # hints: {(arg_index,): value} from "<type>:<value>" entries.
    hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
    hints = {k: v for k, v in hints.items() if v is not None}
    constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)}
    constants = {k: v for k, v in constants.items() if v is not None}
    # An argument hinted ":1" is specialized to the constant 1.
    for key, value in hints.items():
        if value == 1:
            constants[kernel.arg_names[key[0]]] = value
    signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)}
    for key in constants:
        signature[key] = 'constexpr'
    const_sig = 'x'.join([str(v) for v in constants.values()])
    doc_string = [f"{k}={v}" for k, v in constants.items()]
    doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]
    # compile ast into cubin
    for h in hints.values():
        assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
    attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
    src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)

    if args.target:
        target_fields = args.target.split(":")
        if len(target_fields) != 3:
            raise ValueError(f"--target must look like '<backend>:<arch>:<warp-size>', got {args.target!r}")
        target_backend, target_arch, target_warp_size = target_fields
        # Command-line fields are strings; convert the numeric ones
        # (e.g. 'cuda:80:32' -> ('cuda', 80, 32)), leaving archs like 'gfx942' as str.
        # NOTE(review): assumes GPUTarget expects an int warp size and an int arch for
        # numeric archs -- confirm against triton.backends.compiler.GPUTarget.
        target = triton.backends.compiler.GPUTarget(
            target_backend,
            int(target_arch) if target_arch.isdigit() else target_arch,
            int(target_warp_size),
        )
    else:
        target = triton.runtime.driver.active.get_current_target()
    backend = triton.compiler.make_backend(target)
    kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
    options = backend.parse_options(kwargs)
    ccinfo = triton.compile(src, target=target, options=options.__dict__)

    # Read both scratch sizes defensively, treating a missing field as 0.
    if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
        raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")
    if getattr(ccinfo.metadata, "profile_scratch_size", 0) > 0:
        raise RuntimeError("AOT compiling kernels with profile scratch requirements is not yet implemented")

    # arg_names/arg_types: the full C signature; the *_not_1 lists exclude ":1"-hinted
    # arguments and are what actually get passed to the kernel (see "arg_pointers").
    arg_names = []
    arg_types = []
    arg_names_not_1 = []
    arg_types_not_1 = []
    for i, arg_name in enumerate(kernel.arg_names):
        if arg_name not in constants:
            arg_names.append(arg_name)
            arg_types.append(signature[arg_name])
            arg_names_not_1.append(arg_name)
            arg_types_not_1.append(signature[arg_name])
        elif hints.get((i, ), None) == 1:
            arg_names.append(arg_name)
            arg_types.append("i32")

    # dump C stub code
    suffix = ''
    for i in range(len(signature)):
        suffix += str(i)
        hint = hints.get((i, ), None)
        if hint == 1:
            suffix += 'c'
        if hint == 16:
            suffix += 'd'
    func_name = '_'.join([out_name, sig_hash, suffix])
    asm = ccinfo.asm[backend.binary_ext]  # store binary data once

    ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type

    params = {
        "kernel_name": func_name,
        "triton_kernel_name": args.kernel_name,
        "bin_size": len(asm),
        # Two lowercase hex digits per byte, e.g. "0x7f, 0x45, ...".
        "bin_data": ", ".join(f"0x{byte:02x}" for byte in asm),
        "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
        "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
        "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] + ["&profile_scratch"]),
        "num_args": len(arg_names_not_1) + 2,  # +2 for global and profile scratch
        # NOTE(review): a list, so the template renders its repr -- confirm intended.
        "kernel_docstring": doc_string,
        "shared": ccinfo.metadata.shared,
        "num_warps": args.num_warps,
        "algo_info": "_".join([const_sig, meta_sig]),
        "gridX": grid[0],
        "gridY": grid[1],
        "gridZ": grid[2],
        "_placeholder": "",
    }
    output_files = []
    backend_name = target.backend
    template_dir = Path(__file__).parent / "extra" / backend_name
    for template_path in template_dir.glob('compile.*'):
        ext = template_path.suffix
        output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
        with output_file.open("w") as fp:
            fp.write(template_path.read_text().format(**params))
        output_files.append(output_file)

    return func_name, output_files
207
+
208
+
209
# Run the AOT compiler command line only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
+ main()