triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
triton/compiler/compiler.py ADDED
@@ -0,0 +1,509 @@
+ from __future__ import annotations
+ import hashlib
+ import json
+ from .._C.libtriton import get_cache_invalidating_env_vars, ir
+ from ..backends import backends
+ from ..backends.compiler import Language
+ from ..backends.compiler import BaseBackend, GPUTarget
+ from .. import __version__, knobs
+ from ..runtime.autotuner import OutOfResources
+ from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager, get_cache_key
+ from ..runtime.driver import driver
+ from ..tools.disasm import get_sass
+ from pathlib import Path
+ import re
+ import functools
+ import os
+ import time
+ import copy
+
+ # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
+ #   and any following whitespace
+ # - (public\s+)? : optionally match the keyword public and any following whitespace
+ # - (@\w+) : match an @ symbol followed by one or more word characters
+ #   (letters, digits, or underscores), and capture it as group 1 (the function name)
+ # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
+ #   zero or more arguments separated by commas, and capture it as group 2 (the argument list)
+ # - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
+ ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
+ prototype_pattern = {
+     "ptx": ptx_prototype_pattern,
+ }
+
+ ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
+ arg_type_pattern = {
+     "ptx": ptx_arg_type_pattern,
+ }
+
+
+ def convert_type_repr(x):
+     # Currently we only capture the pointer type and assume the pointer is on global memory.
+     # TODO: Capture and support shared memory space
+     match = re.search(r'!tt\.ptr<([^,]+)', x)
+     tma = re.search(r'tt.nv_tma_desc = 1', x)
+     if tma is not None:
+         return 'nvTmaDesc'
+     x = re.sub(r' {[^}]+}', '', x)
+     if match is not None:
+         return '*' + convert_type_repr(match.group(1))
+     return x
+
+
+ class ASTSource:
+
+     def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
+         self.fn = fn
+         self.language = Language.TRITON
+         self.ext = "ttir"
+         self.name = fn.__name__
+         self.signature = signature
+         self.constants = dict()
+         if constexprs is not None:
+             for k, v in constexprs.items():
+                 k = (fn.arg_names.index(k), ) if isinstance(k, str) else k
+                 assert isinstance(k, tuple)
+                 self.constants[k] = v
+         self.attrs = attrs or dict()
+         for k in self.signature.keys():
+             if not isinstance(k, str):
+                 raise TypeError("Signature keys must be string")
+
+     def hash(self):
+         sorted_sig = [v for k, v in sorted(self.signature.items())]
+         get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x)
+         constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())])
+         key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
+         return hashlib.sha256(key.encode("utf-8")).hexdigest()
+
+     def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
+         from .code_generator import ast_to_ttir
+         return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
+                            module_map=module_map)
+
+     def parse_options(self):
+         return dict()
+
+
+ class IRSource:
+
+     def __init__(self, path, context, backend):
+         self.path = path
+         path = Path(path)
+         self.ext = path.suffix[1:]
+         self.language = Language.TRITON
+         self.src = path.read_text()
+         ir.load_dialects(context)
+         backend.load_dialects(context)
+
+         # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now.
+         # TODO - replace with a proper parser
+         if self.ext == "ptx":
+             match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
+             self.name = match.group(1)
+             signature = match.group(2)
+             types = re.findall(arg_type_pattern[self.ext], signature)
+             self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
+         else:
+             self.module = ir.parse_mlir_module(self.path, context)
+             fn_name = self.module.get_entry_func_name()
+             self.name = "@" + fn_name
+             funcOp = self.module.get_function(fn_name)
+             func_ty = self.module.get_function_signature(funcOp)
+             self.signature = {k: ty for k, ty in enumerate(func_ty)}
+
+     def hash(self):
+         return hashlib.sha256(self.src.encode("utf-8")).hexdigest()
+
+     def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
+         self.module.context = context
+         return self.module
+
+     def parse_options(self):
+         if self.ext == "ttgir":
+             num_warps = self.module.get_int_attr("ttg.num-warps")
+             assert num_warps is not None, "Unable to parse ttg.num-warps attribute"
+             return {'num_warps': num_warps}
+         return dict()
+
+
+ @functools.lru_cache()
+ def max_shared_mem(device):
+     return driver.active.utils.get_device_properties(device)["max_shared_mem"]
+
+
+ def parse(full_name, ext, context):
+     if ext == "ttir" or ext == "ttgir":
+         module = ir.parse_mlir_module(full_name, context)
+         module.context = context
+         return module
+     if ext == "llir" or ext == "ptx" or ext == "amdgcn":
+         return Path(full_name).read_text()
+     if ext == "cubin" or ext == "hsaco":
+         return Path(full_name).read_bytes()
+
+
+ def filter_traceback(e: BaseException):
+     """
+     Removes code_generator.py and related files from tracebacks.
+
+     These are uninteresting to the user -- "just show me *my* code!"
+     """
+     if knobs.compilation.front_end_debugging:
+         return
+
+     if e.__cause__ is not None:
+         filter_traceback(e.__cause__)
+     if e.__context__ is not None:
+         filter_traceback(e.__context__)
+
+     # If a user has a file that matches one of these, they're out of luck.
+     BAD_FILES = [
+         "/triton/compiler/code_generator.py",
+         "/ast.py",
+     ]
+     BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES]
+
+     tb = e.__traceback__
+     frames = []
+     while tb is not None:
+         if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)):
+             frames.append(tb)
+         tb = tb.tb_next
+
+     for (cur_frame, next_frame) in zip(frames, frames[1:]):
+         cur_frame.tb_next = next_frame
+
+     if not frames:
+         e.__traceback__ = None
+     else:
+         frames[-1].tb_next = None
+         e.__traceback__ = frames[0]
+
+
+ class CompileTimer:
+
+     def __init__(self) -> None:
+         self.start: float = time.perf_counter()
+         self.ir_initialization_end: float | None = None
+         self.lowering_stage_ends: list[tuple[str, float]] = []
+         self.store_results_end: float | None = None
+
+     def finished_ir_initialization(self) -> None:
+         self.ir_initialization_end = time.perf_counter()
+
+     def stage_finished(self, stage_name: str) -> None:
+         self.lowering_stage_ends.append((stage_name, time.perf_counter()))
+
+     def end(self) -> knobs.CompileTimes:
+         timestamp = time.perf_counter()
+         if self.ir_initialization_end is None:
+             self.ir_initialization_end = timestamp
+         else:
+             self.store_results_end = timestamp
+
+         def delta(start: float, end: float | None) -> int:
+             if end is None:
+                 return 0
+             return int((end - start) * 1000000)
+
+         lowering_stage_durations = []
+         stage_start = self.ir_initialization_end
+         for stage_name, stage_end in self.lowering_stage_ends:
+             lowering_stage_durations.append((stage_name, delta(stage_start, stage_end)))
+             stage_start = stage_end
+
+         return knobs.CompileTimes(
+             ir_initialization=delta(self.start, self.ir_initialization_end),
+             lowering_stages=lowering_stage_durations,
+             store_results=delta(stage_start, self.store_results_end),
+         )
+
+
+ def compile(src, target=None, options=None, _env_vars=None):
+     compilation_listener = knobs.compilation.listener
+     if compilation_listener:
+         timer = CompileTimer()
+
+     if target is None:
+         target = driver.active.get_current_target()
+     assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
+     backend = make_backend(target)
+     ir_source = not isinstance(src, ASTSource)
+     # create backend
+     if ir_source:
+         assert isinstance(src, str), "source must be either AST or a filepath"
+         context = ir.context()
+         src = IRSource(src, context, backend)
+
+     extra_options = src.parse_options()
+     options = backend.parse_options(dict(options or dict(), **extra_options))
+     # create cache manager
+     env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars
+     key = get_cache_key(src, backend, options, env_vars=env_vars)
+     hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
+     fn_cache_manager = get_cache_manager(hash)
+     # For dumping/overriding only hash the source as we want it to be independent of triton
+     # core changes to make it easier to track kernels by hash.
+     enable_override = knobs.compilation.override
+     enable_ir_dump = knobs.compilation.dump_ir
+     store_only_binary = knobs.compilation.store_binary_only
+     fn_override_manager = get_override_manager(src.hash()) if enable_override else None
+     fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
+     # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
+     # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}".
+     # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate
+     # the file name to 150 characters to be safe.
+     file_name = src.name[:150]
+     metadata_filename = f"{file_name}.json"
+     metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
+     metadata_path = metadata_group.get(metadata_filename)
+     always_compile = knobs.compilation.always_compile
+     if not always_compile and metadata_path is not None:
+         # cache hit!
+         res = CompiledKernel(src, metadata_group, hash)
+         if compilation_listener:
+             compilation_listener(
+                 src=src,
+                 metadata=res.metadata._asdict(),
+                 metadata_group=metadata_group,
+                 times=timer.end(),
+                 cache_hit=True,
+             )
+         return res
+
+     # initialize metadata
+     metadata = {
+         "hash": hash,
+         "target": target,
+         **options.__dict__,
+         **env_vars,
+     }
+     metadata["triton_version"] = __version__
+     # run compilation pipeline and populate metadata
+     stages = dict()
+     backend.add_stages(stages, options, src.language)
+     first_stage = list(stages.keys()).index(src.ext)
+     # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
+     if ir_source:
+         first_stage += 1
+
+     # For IRSource, we have already grabbed the context + called both
+     # ir.load_dialects and backend.load_dialects.
+     if not isinstance(src, IRSource):
+         context = ir.context()
+         ir.load_dialects(context)
+         backend.load_dialects(context)
+
+     codegen_fns = backend.get_codegen_implementation(options)
+     module_map = backend.get_module_map()
+     try:
+         module = src.make_ir(target, options, codegen_fns, module_map, context)
+     except Exception as e:
+         filter_traceback(e)
+         raise
+
+     if ir_source:
+         ir_filename = f"{file_name}.{src.ext}"
+         metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
+     else:
+         ir_filename = f"{file_name}.source"
+         metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
+
+     use_ir_loc = knobs.compilation.use_ir_loc
+     if ir_source and use_ir_loc:
+         module.create_location_snapshot(src.path)
+         print(f"Creating new locations for {src.path}")
+
+     if compilation_listener:
+         timer.finished_ir_initialization()
+     for ext, compile_ir in list(stages.items())[first_stage:]:
+         next_module = compile_ir(module, metadata)
+         ir_filename = f"{file_name}.{ext}"
+         if fn_override_manager is None:
+             # Users can override kernels at scale by setting `ir_override` in autotune config
+             # without TRITON_KERNEL_OVERRIDE
+             if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"):
+                 next_module = parse(ir_override, ext, context)
+         elif full_name := fn_override_manager.get_file(ir_filename):
+             print(f"\nOverriding kernel with file {full_name}")
+             next_module = parse(full_name, ext, context)
+         # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
+         if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")):
+             metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
+         if fn_dump_manager is not None:
+             fn_dump_manager.put(next_module, ir_filename)
+             if ext == "cubin":
+                 sass = get_sass(next_module)
+                 fn_dump_manager.put(sass, file_name + ".sass")
+         # use an env variable to parse ir from file
+         if use_ir_loc == ext:
+             ir_full_name = fn_cache_manager.get_file(ir_filename)
+             next_module.create_location_snapshot(ir_full_name)
+             print(f"Creating new locations for {ir_full_name}")
+         module = next_module
+         if compilation_listener:
+             timer.stage_finished(ext)
+     # write-back metadata
+     metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
+                                                              binary=False)
+     fn_cache_manager.put_group(metadata_filename, metadata_group)
+     # Compilation completed, disabling multithreading in context.
+     # This is needed to safely finalize threads pool inside context: if current process forks before
+     # python GC deletes context object, thread pool in child process will be invalid, which could
+     # lead to child crash or hang.
+     #
+     # However disabling multithreading causes the code to hang if the ASAN pass is enabled
+     # this is likely due to the llvm-symbolizer forking a process
+     # TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
+     # multithreading in the MLIR context
+     if not knobs.compilation.enable_asan:
+         context.disable_multithreading()
+
+     # notify any listener
+     if compilation_listener:
+         compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(),
+                              cache_hit=False)
+     # return handle to compiled kernel
+     return CompiledKernel(src, metadata_group, hash)
+
+
+ def make_backend(target: GPUTarget) -> BaseBackend:
+     actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
+     if len(actives) != 1:
+         raise RuntimeError(
+             f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.")
+     return actives[0](target)
+
+
+ class LazyDict:
+
+     def __init__(self, data):
+         self.data = data
+         self.extras = []
+
+     def get(self):
+         for func, args in self.extras:
+             self.data = self.data | func(*args)
+         self.extras.clear()
+         return self.data
+
+     def add(self, func, args):
+         self.extras.append((func, args))
+
+
+ class AsmDict(dict):
+
+     def __missing__(self, key):
+
+         if key == "sass":
+             value = get_sass(self["cubin"])
+         else:
+             raise KeyError("Unknown key: '%s'" % key)
+
+         self[key] = value
+         return value
+
+
+ def _raise_error(err, *args, **kwargs):
+     raise copy.deepcopy(err)
+
+
+ class CompiledKernel:
+
+     def __init__(self, src, metadata_group, hash):
+         from collections import namedtuple
+         metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
+         metadata = json.loads(metadata_path.read_text())
+         metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
+         # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
+         target = metadata['target']
+         metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
+         KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
+         self.metadata = KernelMetadata(**metadata)
+         backend = make_backend(self.metadata.target)
+         self.packed_metadata = backend.pack_metadata(self.metadata)
+         self.src = src
+         self.hash = hash
+         self.name = self.metadata.name
+         # stores the text of each level of IR that was generated during compilation
+         asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
+         binary_ext = backend.binary_ext
+         self.asm = AsmDict({
+             file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
+             for file in asm_files
+         })
+         self.metadata_group = metadata_group
+         self.kernel = self.asm[binary_ext]
+         # binaries are lazily initialized
+         # because it involves doing runtime things
+         # (e.g., checking amount of shared memory on current device)
+         self.module = None
+         self.function = None
+         self._run = None
+
+     def _init_handles(self):
+         if self.module is not None:
+             return
+
+         def raise_(err):
+             # clone the exception object so that the one saved in the closure
+             # of the partial function below doesn't get assigned a stack trace
+             # after the subsequent raise. otherwise, the CompiledKernel instance
+             # saved in the (global) kernel cache will keep references to all the
+             # locals in the traceback via the exception instance in the closure.
+             cloned_err = copy.deepcopy(err)
+             self._run = functools.partial(_raise_error, cloned_err)
+             raise err
+
+         device = driver.active.get_current_device()
+         # create launcher
+         self._run = driver.active.launcher_cls(self.src, self.metadata)
+         # not enough shared memory to run the kernel
+         max_shared = max_shared_mem(device)
+         if self.metadata.shared > max_shared:
+             raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
+         if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
+             # Use blackwell max tmem size for now, this should be moved in device properties
+             max_tmem_size = 512  # tmem size in number of columns
+             if self.metadata.tmem_size > max_tmem_size:
+                 raise_(OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory"))
+         if knobs.runtime.kernel_load_start_hook is not None:
+             knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
+         # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
+         self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
+             self.name, self.kernel, self.metadata.shared, device)
+         warp_size = driver.active.get_current_target().warp_size
+         if self.metadata.num_warps * warp_size > self.n_max_threads:
+             raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads"))
+         if knobs.runtime.kernel_load_end_hook is not None:
+             knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
+
+     @property
+     def run(self):
+         if self._run is None:
+             self._init_handles()
+         return self._run
+
+     def launch_metadata(self, grid, stream, *args):
+         if knobs.runtime.launch_enter_hook is None:
+             return None
+         self._init_handles()
+         ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
+         if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
+             return ret
+         arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)}
+         ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
+         return ret
+
+     def __getitem__(self, grid):
+         self._init_handles()
+
+         def runner(*args, stream=None):
+             if stream is None:
+                 device = driver.active.get_current_device()
+                 stream = driver.active.get_current_stream(device)
+             launch_metadata = self.launch_metadata(grid, stream, *args)
+             self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
+                      knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)
+
+         return runner
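
The compile() entry point above can also be driven ahead of time with an ASTSource. The sketch below is illustrative only: the add_kernel function, the signature strings, and the launch-argument convention are assumptions rather than something this diff documents, and importing ASTSource and compile from triton.compiler assumes the re-exports in triton/compiler/__init__.py.

    import triton
    import triton.language as tl
    from triton.compiler import ASTSource, compile

    # Hypothetical element-wise kernel, used only to exercise the entry point.
    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    # Describe argument types and constexpr values, then compile ahead of time.
    src = ASTSource(
        fn=add_kernel,
        signature={"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n_elements": "i32", "BLOCK": "constexpr"},
        constexprs={"BLOCK": 1024},
    )
    kernel = compile(src)     # CompiledKernel; a later call with the same cache key is a cache hit
    print(kernel.asm.keys())  # per-stage artifacts kept by AsmDict, e.g. "ttir", "ttgir", "ptx", "cubin"
    # kernel[(grid_x, 1, 1)](x, y, out, n_elements) would launch it through __getitem__/runner above.
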
triton/compiler/errors.py ADDED
@@ -0,0 +1,51 @@
+ import ast
+ from typing import Optional
+ from ..errors import TritonError
+
+
+ class CompilationError(TritonError):
+     """Base class for all errors raised during compilation"""
+     source_line_count_max_in_message = 12
+
+     def _format_message(self) -> str:
+         node = self.node
+         if self.src is None:
+             source_excerpt = " <source unavailable>"
+         else:
+             if hasattr(node, 'lineno'):
+                 source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
+                 if source_excerpt:
+                     source_excerpt.append(' ' * node.col_offset + '^')
+                     source_excerpt = '\n'.join(source_excerpt)
+                 else:
+                     source_excerpt = " <source empty>"
+             else:
+                 source_excerpt = self.src
+
+         message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr(
+             node, 'lineno') else source_excerpt
+         if self.error_message:
+             message += '\n' + self.error_message
+         return message
+
+     def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None):
+         self.src = src
+         self.node = node
+         self.error_message = error_message
+         self.message = self._format_message()
+
+     def __str__(self):
+         return self.message
+
+     def __reduce__(self):
+         # this is necessary to make CompilationError picklable
+         return type(self), (self.src, self.node, self.error_message)
+
+
+ class CompileTimeAssertionFailure(CompilationError):
+     """Specific exception for failed tests in `static_assert` invocations"""
+     pass
+
+
+ class UnsupportedLanguageConstruct(CompilationError):
+     pass
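
As an illustration of how CompilationError renders its message (a sketch under assumptions, not code shipped in the package): _format_message excerpts up to source_line_count_max_in_message lines of the kernel source ending at the failing AST node and appends a caret under the node's column.

    import ast
    from triton.compiler.errors import CompilationError

    source = "def kernel():\n    return undefined_symbol\n"
    node = ast.parse(source).body[0].body[0]  # the `return` statement, which carries lineno/col_offset
    err = CompilationError(source, node, "name 'undefined_symbol' is not defined")
    print(err)
    # at 2:4:
    # def kernel():
    #     return undefined_symbol
    #     ^
    # name 'undefined_symbol' is not defined
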
triton/compiler/make_launcher.py
File without changes
triton/errors.py ADDED
@@ -0,0 +1,5 @@
+ """Base class for all errors raised by Triton"""
+
+
+ class TritonError(Exception):
+     ...
triton/experimental/__init__.py
File without changes
triton/experimental/gluon/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from . import nvidia
+ from ._runtime import constexpr_function, jit
+ from triton.language.core import must_use_result
+
+ __all__ = ["constexpr_function", "jit", "must_use_result", "nvidia"]
triton/experimental/gluon/_compiler.py
File without changes
triton/experimental/gluon/_runtime.py ADDED
@@ -0,0 +1,102 @@
+ from __future__ import annotations
+ from triton.compiler.compiler import ASTSource
+ from triton.backends.compiler import Language
+ from triton.runtime.jit import JITFunction, constexpr_function
+ from typing import TypeVar, Optional, Callable, Iterable, Union
+ from triton._C.libtriton import ir
+
+ T = TypeVar("T")
+
+ __all__ = ["constexpr_function", "jit"]
+
+
+ class GluonASTSource(ASTSource):
+
+     def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
+         super().__init__(fn, signature, constexprs, attrs)
+         self.language = Language.GLUON
+         self.ext = "ttgir"
+
+     def make_ir(self, target, options, codegen_fns, module_map, context):
+         from triton.compiler.compiler import make_backend
+         from triton.compiler.code_generator import ast_to_ttir
+
+         builder = ir.builder(context)
+         module = builder.create_module()
+
+         # Assign module attributes eagerly, as they are needed to verify layouts
+         backend = make_backend(target)
+         target = backend.get_target_name(options)
+
+         module.set_attr("ttg.target", builder.get_string_attr(target))
+         module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
+         module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
+         module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(options.warp_size))
+
+         is_cuda = options.backend_name == "cuda"
+         if is_cuda and options.maxnreg is not None:
+             module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
+
+         module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
+                              module_map=module_map, module=module)
+         return module
+
+
+ class GluonJITFunction(JITFunction[T]):
+
+     def create_binder(self):
+         result = super().create_binder()
+         self.ASTSource = GluonASTSource
+         return result
+
+     def is_gluon(self):
+         return True
+
+
+ def jit(
+     fn: Optional[T] = None,
+     *,
+     version=None,
+     repr: Optional[Callable] = None,
+     launch_metadata: Optional[Callable] = None,
+     do_not_specialize: Optional[Iterable[int | str]] = None,
+     do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
+     debug: Optional[bool] = None,
+     noinline: Optional[bool] = None,
+ ) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]:
+     """
+     Decorator for JIT-compiling a function using the Triton compiler.
+
+     :note: When a jit'd function is called, arguments are
+         implicitly converted to pointers if they have a :code:`.data_ptr()` method
+         and a `.dtype` attribute.
+
+     :note: This function will be compiled and run on the GPU. It will only have access to:
+
+         * python primitives,
+         * builtins within the triton package,
+         * arguments to this function,
+         * other jit'd functions
+
+     :param fn: the function to be jit-compiled
+     :type fn: Callable
+     """
+
+     def decorator(fn: T) -> JITFunction[T]:
+         assert callable(fn)
+         return GluonJITFunction(
+             fn,
+             version=version,
+             do_not_specialize=do_not_specialize,
+             do_not_specialize_on_alignment=do_not_specialize_on_alignment,
+             debug=debug,
+             noinline=noinline,
+             repr=repr,
+             launch_metadata=launch_metadata,
+         )
+
+     if fn is not None:
+         return decorator(fn)
+
+     else:
+         return decorator
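
A minimal sketch of the jit decorator defined above. The kernel body and the Gluon language calls (gl.program_id, gl.arange, gl.BlockedLayout, and friends) are assumptions made for illustration; only the decorator behavior comes from the code in this file.

    from triton.experimental import gluon
    from triton.experimental.gluon import language as gl

    @gluon.jit
    def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: gl.constexpr):
        # Gluon lowers straight to TTGIR (GluonASTSource sets ext = "ttgir"),
        # so layouts are spelled out explicitly; this layout is illustrative.
        layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32],
                                                warps_per_cta=[4], order=[0])
        pid = gl.program_id(0)
        offsets = pid * BLOCK + gl.arange(0, BLOCK, layout=layout)
        mask = offsets < n_elements
        gl.store(dst_ptr + offsets, gl.load(src_ptr + offsets, mask=mask), mask=mask)

    # Because jit accepts either a function or keyword arguments, both forms work:
    #   @gluon.jit                 -> decorator(fn) runs immediately (fn is not None)
    #   @gluon.jit(noinline=True)  -> jit returns `decorator`, which is then applied to the function
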