triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1614 @@
1
+ import ast
2
+ import builtins
3
+ import contextlib
4
+ import copy
5
+ import inspect
6
+ import re
7
+ import warnings
8
+ import textwrap
9
+ import itertools
10
+ from dataclasses import dataclass
11
+ from types import ModuleType
12
+ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
13
+
14
+ from .. import knobs, language
15
+ from .._C.libtriton import ir, gluon_ir
16
+ from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple
17
+ from ..language.core import _unwrap_if_constexpr, base_value, base_type
18
+ # ideally we wouldn't need any runtime component
19
+ from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, BoundConstexprFunction, ConstexprFunction, JITFunction
20
+ from .._utils import find_paths_if, get_iterable_path, set_iterable_path
21
+
22
+ from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
23
+
24
+
25
def check_identifier_legality(name, type):
    """Validate that ``name`` is a legal ASCII identifier and return it unchanged.

    ``type`` only serves to word the error message (e.g. ``"function"``).

    Raises:
        CompilationError: if ``name`` does not match ``[a-zA-Z_][a-zA-Z0-9_]*``
            in its entirety.
    """
    pattern = r'[a-zA-Z_][a-zA-Z0-9_]*'
    # fullmatch instead of match(r'^...$'): '$' also matches just before a
    # trailing newline, which would let names such as "foo" + newline through.
    if re.fullmatch(pattern, name) is None:
        raise CompilationError(f"invalid {type} identifier: {name}", name)
    return name
30
+
31
+
32
def mangle_fn(name, arg_tys, constants, caller_context):
    """Build the symbol name for one specialization of function ``name``.

    The symbol is ``<name>__<mangled arg types>__<constants>``, optionally
    followed by ``caller_context.mangle()``. The return type is not encoded
    because it must be a function of the argument types.
    """
    arg_part = '_'.join(ty.mangle() for ty in arg_tys)
    const_part = '_'.join(f'{key}c{constants[key]!r}' for key in sorted(constants))
    # '.' and "'" get readable escapes; '[' and ']' are not allowed in LLVM
    # identifiers, so they collapse to underscores.
    escapes = str.maketrans({'.': '_d_', "'": '_sq_', '[': '_', ']': '_'})
    symbol = f'{name}__{arg_part}__{const_part.translate(escapes)}'
    if caller_context is None:
        return symbol
    return symbol + caller_context.mangle()
44
+
45
+
46
def _is_triton_value(o: Any) -> bool:
    """Return True if ``o`` is a Triton frontend value (any ``base_value`` instance)."""
    is_value = isinstance(o, base_value)
    return is_value
48
+
49
+
50
def _is_triton_tensor(o: Any) -> bool:
    """Return True if ``o`` is a Triton frontend ``tensor``."""
    is_tensor = isinstance(o, tensor)
    return is_tensor
52
+
53
+
54
+ def _is_constexpr(o: Any) -> bool:
55
+ return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable))
56
+
57
+
58
def _is_non_scalar_tensor(o: Any) -> bool:
    """Return True for a block-typed Triton tensor whose element count is not 1."""
    if not _is_triton_tensor(o):
        return False
    return o.type.is_block() and o.type.numel != 1
60
+
61
+
62
+ def _is_list_like(o: Any) -> bool:
63
+ return isinstance(o, (list, tuple))
64
+
65
+
66
+ def _check_fn_args(node, fn, args):
67
+ if fn.noinline:
68
+ for idx, arg in enumerate(args):
69
+ if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
70
+ raise UnsupportedLanguageConstruct(
71
+ fn.src, node,
72
+ f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
73
+ )
74
+
75
+
76
+ def _is_namedtuple(val):
77
+ return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")
78
+
79
+
80
def _apply_to_tuple_values(value, fn):
    """Apply ``fn`` to each element of a tuple-like value and rebuild a ``language.tuple``.

    ``value`` may be a namedtuple instance or a ``language.tuple``. Its field
    names are carried over to the result's ``tuple_type``. ``None`` results
    from ``fn`` are wrapped in ``constexpr`` so every element has a ``.type``.

    Raises:
        TypeError: if ``value`` is neither a namedtuple nor a ``language.tuple``.
    """
    if _is_namedtuple(type(value)):
        fields = value._fields
    elif isinstance(value, language.tuple):
        fields = value.type.fields
    else:
        # An explicit raise instead of `assert False`: asserts vanish under
        # `python -O`, which would leave `fields` unbound below.
        raise TypeError(f"Unsupported type {type(value)}")

    vals = [fn(v) for v in value]
    vals = [constexpr(v) if v is None else v for v in vals]
    types = [v.type for v in vals]
    return language.tuple(vals, language.tuple_type(types, fields))
92
+
93
+
94
def flatten_values_to_ir(values: Iterable[base_value]):
    """Return one flat list holding the IR handles of every value in ``values``, in order."""
    flat_handles = []
    for frontend_value in values:
        # each value appends its own handle(s) to the shared list
        frontend_value._flatten_ir(flat_handles)
    return flat_handles
99
+
100
+
101
def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
    """Lazily rebuild one frontend value per entry of ``types`` from flat IR ``handles``.

    Each type consumes as many handles as it needs starting at the current
    position. Once the generator is exhausted, it asserts that every handle
    was consumed.
    """
    position = 0
    for value_type in types:
        rebuilt, position = value_type._unflatten_ir(handles, position)
        yield rebuilt
    assert position == len(handles)
107
+
108
+
109
+ _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
110
+
111
+
112
+ def _clone_triton_value(val):
113
+ handles = []
114
+ val._flatten_ir(handles)
115
+ clone, _ = val.type._unflatten_ir(handles, 0)
116
+ return clone
117
+
118
+
119
def _clone_scope(scope):
    """Copy a name->value scope, re-wrapping Triton values and keeping others as-is."""
    cloned = {}
    for var_name, var_val in scope.items():
        cloned[var_name] = _clone_triton_value(var_val) if _is_triton_value(var_val) else var_val
    return cloned
121
+
122
+
123
class enter_sub_region:
    """Context manager that snapshots and restores a code generator's region state.

    On entry it clones ``lscope`` and ``local_defs`` (Triton values get fresh
    frontend wrappers around the same IR handles), resets ``local_defs`` to an
    empty dict, and records the builder's insertion block and point. It returns
    ``(liveins, insert_block)``.

    On exit it restores the insertion point and reinstates the cloned scopes.
    Exceptions are not suppressed.
    """

    def __init__(self, generator):
        self.generator = generator

    def __enter__(self):
        gen = self.generator
        builder = gen.builder
        # record lscope & local_defs of the parent scope
        self.liveins = _clone_scope(gen.lscope)
        self.prev_defs = _clone_scope(gen.local_defs)
        gen.local_defs = {}
        self.insert_block = builder.get_insertion_block()
        self.insert_point = builder.get_insertion_point()
        return self.liveins, self.insert_block

    def __exit__(self, *args, **kwargs):
        gen = self.generator
        gen.builder.restore_insertion_point(self.insert_point)
        gen.lscope = self.liveins
        gen.local_defs = self.prev_defs
141
+
142
+
143
# Check if the given syntax node has an "early" return
class ContainsReturnChecker(ast.NodeVisitor):
    """Answer "does this AST contain a ``return``?" through ``visit(node) -> bool``.

    Every visitor returns a bool and stops exploring at the first ``return``.
    Assignments are never searched. Calls only inspect the callee expression.
    Callees resolved through ``gscope`` are not descended into (see
    ``_visit_function``).
    """

    def __init__(self, gscope):
        # global scope used to resolve names/attributes that may be callees
        self.gscope = gscope

    def _visit_stmts(self, body) -> bool:
        for stmt in body:
            if self.visit(stmt):
                return True
        return False

    def _visit_function(self, fn) -> bool:
        # No need to check within the function as it won't cause an early return.
        # If the function itself has unstructured control flow we may not be able to inline it causing poor performance,
        # we should check for this and emit a warning.
        return False

    def generic_visit(self, node) -> bool:
        # Default: search child nodes (scalar and list-valued fields) in field
        # order, stopping at the first one that contains a return.
        for _, field_value in ast.iter_fields(node):
            children = field_value if isinstance(field_value, list) else [field_value]
            for child in children:
                if isinstance(child, ast.AST) and self.visit(child):
                    return True
        return False

    def visit_Attribute(self, node: ast.Attribute) -> bool:
        # If the left part is a name, it's possible that
        # we call triton native function or a jit function from another module.
        # If the left part is not a name, it must return a tensor or a constexpr
        # whose methods do not contain return statements
        # e.g., (tl.load(x)).to(y)
        # So we only check if the expressions within value have return or not
        base = node.value
        if not isinstance(base, ast.Name):
            return self.visit(base)
        if base.id not in self.gscope:
            return False
        return self._visit_function(getattr(self.gscope[base.id], node.attr))

    def visit_Name(self, node: ast.Name) -> bool:
        # stores never return; loads only matter when they name a global
        if type(node.ctx) is ast.Store or node.id not in self.gscope:
            return False
        return self._visit_function(self.gscope[node.id])

    def visit_Return(self, node: ast.Return) -> bool:
        return True

    def visit_Assign(self, node: ast.Assign) -> bool:
        # There couldn't be an early return: x = ...
        return False

    def visit_AugAssign(self, node: ast.AugAssign) -> bool:
        # There couldn't be an early return: x += ...
        return False

    def visit_Module(self, node: ast.Module) -> bool:
        return self._visit_stmts(node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> bool:
        return self._visit_stmts(node.body)

    def visit_If(self, node: ast.If) -> bool:
        # TODO: optimize the following case in which we actually don't have
        # a return when static_cond is false:
        # if dynamic_cond
        #   if static_cond
        #     func_with_return
        #   else
        #     func_without_return
        return self._visit_stmts(node.body) or self._visit_stmts(node.orelse)

    def visit_IfExp(self, node: ast.IfExp) -> bool:
        return self.visit(node.body) or self.visit(node.orelse)

    def visit_Call(self, node: ast.Call) -> bool:
        return self.visit(node.func)
229
+
230
+
231
class ASTFunction:
    """Frontend signature of a function being code-generated.

    ``arg_types`` may be nested (tuples/lists), ``constants`` maps argument
    paths to constexpr values, and ``attrs`` maps argument paths to lists of
    ``(attr_name, attr_value)`` pairs applied to the IR function arguments.
    """

    def __init__(self, ret_types, arg_types, constants, attrs):
        self.ret_types = ret_types
        self.arg_types = arg_types
        self.constants = constants
        self.attrs = attrs

    def _runtime_arg_paths(self):
        """Paths into ``arg_types`` that are non-constexpr and whose type is not None."""
        return list(find_paths_if(self.arg_types, lambda path, ty: path not in self.constants and ty is not None))

    def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
        """Flatten frontend types into a list of IR types, skipping ``None`` entries."""
        ir_types = []
        for ty in (t for t in types if t is not None):
            ty._flatten_ir_types(builder, ir_types)
        return ir_types

    def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
        """IR types of the function results."""
        return self.flatten_ir_types(builder, self.ret_types)

    def serialize(self, builder: ir.builder):
        """Build the IR function type from the runtime argument types and the return types."""
        runtime_arg_tys = [get_iterable_path(self.arg_types, path) for path in self._runtime_arg_paths()]
        return builder.get_function_ty(self.flatten_ir_types(builder, runtime_arg_tys), self.return_types_ir(builder))

    def deserialize(self, fn):
        """Rebuild the (possibly nested) frontend arguments from IR function ``fn``.

        Runtime arguments are rebuilt from ``fn``'s IR arguments in path order.
        Their attributes are set first. Constexpr paths are then filled with
        ``language.constexpr`` values.
        """

        def make_template(ty):
            # nested argument types become (nested) tuple placeholders
            if isinstance(ty, (list, tuple, language.tuple_type)):
                return language.tuple([make_template(x) for x in ty], ty)
            return language.constexpr(None)

        vals = make_template(self.arg_types)
        runtime_paths = self._runtime_arg_paths()
        cursor = 0
        handles = [fn.args(i) for i in range(fn.get_num_args())]
        for path in runtime_paths:
            ty = get_iterable_path(self.arg_types, path)
            # attributes go on the position of this value's first IR argument
            for attr_name, attr_val in self.attrs.get(path, []):
                fn.set_arg_attr(cursor, attr_name, attr_val)
            val, cursor = ty._unflatten_ir(handles, cursor)
            set_iterable_path(vals, path, val)
        for path, val in self.constants.items():
            set_iterable_path(vals, path, language.constexpr(val))
        return vals
287
+
288
+
289
@dataclass(eq=True, frozen=True)
class BoundJITMethod:
    """Immutable pairing of a frontend value with a JIT function.

    The fields use the attribute names of Python bound methods.
    """
    # the frontend value the function is bound to
    __self__: base_value
    # the JIT function itself
    __func__: JITFunction
293
+
294
+
295
+ class CodeGenerator(ast.NodeVisitor):
296
+
297
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns,
             module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None,
             noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0):
    """Set up the builder, semantic layer and scopes used to generate one function.

    ``module_map`` maps module names to replacement modules. It is used both by
    the builder and to remap entries of ``gscope``. ``None`` is treated as an
    empty mapping.
    """
    self.context = context
    self.is_gluon = is_gluon
    if is_gluon:
        from triton.experimental.gluon.language._semantic import GluonSemantic
        self.builder = gluon_ir.GluonOpBuilder(context)
        self.semantic = GluonSemantic(self.builder)
    else:
        from triton.language.semantic import TritonSemantic
        self.builder = ir.builder(context)
        self.semantic = TritonSemantic(self.builder)

    self.name_loc_as_prefix = None
    self.file_name = file_name
    # node.lineno starts from 1, so we need to subtract 1
    self.begin_line = begin_line - 1
    self.builder.set_loc(file_name, begin_line, 0)
    self.builder.options = options
    # dict of functions provided by the backend. Below are the list of possible functions:
    # Convert custom types not natively supported on HW.
    # convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None)
    self.builder.codegen_fns = codegen_fns
    # Normalize once: the gscope remapping below indexes module_map directly,
    # which previously failed (None.get / `in None`) when module_map was None.
    module_map = {} if module_map is None else module_map
    self.builder.module_map = module_map
    self.module = self.builder.create_module() if module is None else module
    self.function_ret_types = {} if function_types is None else function_types
    self.prototype = prototype

    self.gscope = {}
    for k, v in gscope.items():
        if isinstance(v, ModuleType):
            self.gscope[k] = module_map.get(v.__name__, v)
            continue

        module_name = getattr(v, "__module__", "")
        if module_name in module_map:
            # re-resolve the object by name in the replacement module
            self.gscope[k] = getattr(module_map[module_name], v.__name__)
        else:
            self.gscope[k] = v

    self.lscope = {}
    self.jit_fn = jit_fn
    # TODO: we currently generate illegal names for non-kernel functions involving constexprs!
    if is_kernel:
        function_name = function_name[function_name.rfind('.') + 1:]
        function_name = check_identifier_legality(function_name, "function")
    self.function_name = function_name
    self.is_kernel = is_kernel
    self.cur_node = None
    self.noinline = noinline
    self.caller_context = caller_context
    self.scf_stack = []
    self.ret_type = None
    # SSA-construction
    # name => language.tensor
    self.local_defs: Dict[str, tensor] = {}
    self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
    self.fn = None
    # Are we currently visiting an ast.arg's default value? These have some
    # special handling.
    self.visiting_arg_default_value = False
359
+
360
# Names kernels may use without defining them: selected Python builtins,
# with print/min/max redirected to Triton's device_print/minimum/maximum.
builtin_namespace: Dict[str, Any] = {
    "len": len,
    "list": list,
    "range": range,
    "float": float,
    "int": int,
    "isinstance": isinstance,
    "getattr": getattr,
    "hasattr": hasattr,
    "print": language.core.device_print,
    "min": language.minimum,
    "max": language.maximum,
}
369
+
370
def _unsupported(self, node, message):
    """Build (without raising) an UnsupportedLanguageConstruct for ``node`` in this function's source."""
    source = self.jit_fn.src
    return UnsupportedLanguageConstruct(source, node, message)
372
+
373
+ def _is_constexpr_global(self, name):
374
+ absent_marker = object()
375
+ val = self.gscope.get(name, absent_marker)
376
+ if val is absent_marker:
377
+ return False
378
+
379
+ if _is_constexpr(val):
380
+ return True
381
+
382
+ return False
383
+
384
+ def _define_name_lookup(self):
385
+
386
+ def local_lookup(name: str, absent):
387
+ # this needs to be re-fetched from `self` every time, because it gets switched occasionally
388
+ return self.lscope.get(name, absent)
389
+
390
+ def global_lookup(name: str, absent):
391
+ val = self.gscope.get(name, absent)
392
+ # The high-level rule is that only constexpr globals are allowed.
393
+ # But actually a bunch of other things, such as module imports, are
394
+ # technically Python globals. We have to allow these too!
395
+ if any([
396
+ val is absent,
397
+ name in self.builtin_namespace, #
398
+ type(val) is ModuleType, #
399
+ isinstance(val, JITCallable), #
400
+ getattr(val, "__triton_builtin__", False), #
401
+ getattr(val, "__triton_aggregate__", False), #
402
+ getattr(val, "__module__", "").startswith("triton.language"), #
403
+ getattr(val, "__module__", "").startswith("triton.experimental.gluon.language"), #
404
+ isinstance(val, language.dtype), #
405
+ _is_namedtuple(val),
406
+ self._is_constexpr_global(name), #
407
+ # Allow accesses to globals while visiting an ast.arg
408
+ # because you should be able to do
409
+ # @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
410
+ self.visiting_arg_default_value, #
411
+ knobs.compilation.allow_non_constexpr_globals,
412
+ ]):
413
+ return val
414
+ raise NameError(
415
+ textwrap.dedent(f"""\
416
+ Cannot access global variable {name} from within @jit'ed
417
+ function. Triton kernels can only access global variables that
418
+ are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from
419
+ annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the
420
+ envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not
421
+ promise to support this forever.""").replace("\n", " "))
422
+
423
+ absent_marker = object()
424
+
425
+ def name_lookup(name: str) -> Any:
426
+ absent = absent_marker
427
+ for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get:
428
+ value = lookup_function(name, absent)
429
+ if value is not absent:
430
+ return value
431
+ raise NameError(f'{name} is not defined')
432
+
433
+ return name_lookup
434
+
435
+ @contextlib.contextmanager
436
+ def _name_loc_prefix(self, prefix):
437
+ self.name_loc_as_prefix = prefix
438
+ yield
439
+ self.name_loc_as_prefix = None
440
+
441
def _maybe_set_loc_to_name(self, val, name):
    """Wrap the location of ``val``'s IR handle(s) in a name location for ``name``.

    Raw IR values and block arguments are updated directly. Triton frontend
    values are flattened and each handle is updated. Anything else is left
    untouched.
    """
    if isinstance(val, (ir.value, ir.block_argument)):
        targets = [val]
    elif _is_triton_value(val):
        targets = []
        val._flatten_ir(targets)
    else:
        return
    for handle in targets:
        handle.set_loc(self.builder.create_name_loc(name, handle.get_loc()))
449
+
450
def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
    """Bind ``name`` to ``value`` (store an lvalue).

    The binding is written to ``self.lscope``, used for lookups, and to
    ``self.local_defs``, which records names defined locally.
    FIXME: local definitions should take control flow into account.
    """
    self.lscope[name] = self.local_defs[name] = value
458
+
459
+ def _get_insertion_point_and_loc(self):
460
+ # XXX: this is a hack to get the location of the insertion point.
461
+ # The insertion point's location could be invalid sometimes,
462
+ # so we need to explicitly set the location
463
+ loc = self.builder.get_loc()
464
+ ip = self.builder.get_insertion_point()
465
+ return ip, loc
466
+
467
+ def _set_insertion_point_and_loc(self, ip, loc):
468
+ self.builder.restore_insertion_point(ip)
469
+ self.builder.set_loc(loc)
470
+
471
def _find_carries(self, node, liveins):
    """Discover loop-carried variables by dry-running the loop body.

    The body of ``node`` is visited into a scratch block, which is then erased.
    Each Triton value in ``liveins`` whose IR handles changed during that visit
    is loop-carried (the pair is checked by ``_verify_loop_carried_variable``).
    Afterwards ``lscope`` is reset to a copy of ``liveins`` and ``local_defs``
    is cleared, so nothing from the dry run leaks.

    Returns:
        ``(names, init_handles, init_tys)``: carried names, the flattened IR
        handles of their incoming values, and their incoming types.
    """
    scratch = self.builder.create_block()
    self.builder.set_insertion_point_to_start(scratch)
    self.scf_stack.append(node)
    self.visit_compound_statement(node.body)
    self.scf_stack.pop()
    scratch.erase()

    names, init_handles, init_tys = [], [], []
    for name, live_val in liveins.items():
        if not _is_triton_value(live_val):
            assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value'
            continue
        loop_val = self.lscope[name]
        self._verify_loop_carried_variable(name, loop_val, live_val)
        live_handles = flatten_values_to_ir([live_val])
        if live_handles == flatten_values_to_ir([loop_val]):
            continue  # unchanged inside the loop: not carried
        names.append(name)
        init_tys.append(live_val.type)
        init_handles.extend(live_handles)

    # reset local scope to not pick up local defs from the dry run.
    self.lscope = liveins.copy()
    self.local_defs = {}
    return names, init_handles, init_tys
507
+
508
#
# AST visitor
#
def visit_compound_statement(self, stmts):
    """Visit a statement or list/tuple of statements, stopping after the first ``return``.

    Statements after a ``return`` are dead code and are never visited.
    """
    statements = stmts if _is_list_like(stmts) else [stmts]
    for statement in statements:
        self.visit(statement)
        if isinstance(statement, ast.Return):
            break
521
+
522
def visit_Module(self, node):
    """Visit every child of the module node using the stock NodeVisitor traversal."""
    traverse_children = ast.NodeVisitor.generic_visit
    traverse_children(self, node)
524
+
525
def visit_List(self, node):
    """Lower a list display to a ``language.tuple`` of its visited elements."""
    ctx_result = self.visit(node.ctx)
    assert ctx_result is None
    items = [self.visit(elt) for elt in node.elts]
    return language.tuple(items)
530
+
531
def visit_ListComp(self, node: ast.ListComp):
    """Unroll ``[elt for name in iterable]`` at compile time.

    ``iterable`` must evaluate to a Triton tuple. ``elt`` is visited once per
    element, with ``name`` bound to that element through ``set_value``. The
    results form a new tuple.

    Raises:
        ValueError: for more than one ``for`` clause.
        UnsupportedLanguageConstruct: for ``if`` filters or non-name targets.
        NotImplementedError: if the iterable is not a tuple.
    """
    if len(node.generators) != 1:
        raise ValueError("nested comprehensions are not supported")

    comp = node.generators[0]
    # `if` clauses used to be silently ignored, producing unfiltered results.
    if comp.ifs:
        raise self._unsupported(node, "conditions in comprehensions are not supported")
    # Only `for <name> in ...` is handled; tuple targets would fail on `.id`.
    if not isinstance(comp.target, ast.Name):
        raise self._unsupported(node, "only simple names are supported as comprehension targets")

    iterable = self.visit(comp.iter)
    if not isinstance(iterable, tl_tuple):
        raise NotImplementedError("only tuple comprehensions are supported")

    results = []
    for item in iterable:
        self.set_value(comp.target.id, item)
        results.append(self.visit(node.elt))
    return tl_tuple(results)
545
+
546
# By design, only non-kernel functions can return
def visit_Return(self, node):
    """Emit a return op for ``node`` and record/check the function's return type.

    Tuple results are processed element-wise. Constexpr, int and float results
    are converted with ``self.semantic.to_tensor``. Afterwards a fresh block
    becomes the insertion point, so that any ops following the return do not
    land after the terminator.

    Raises:
        TypeError: if this return's type differs from an earlier return's.
    """

    def decay(value):
        if isinstance(value, language.tuple):
            return _apply_to_tuple_values(value, decay)
        if isinstance(value, (language.constexpr, int, float)):
            return self.semantic.to_tensor(value)
        return value

    ret_value = decay(self.visit(node.value))

    handles = []
    if ret_value is not None:
        assert isinstance(ret_value, language.core.base_value)
        ret_value._flatten_ir(handles)
        ret_ty = ret_value.type
    else:
        ret_ty = language.void
    self.builder.ret(handles)

    if self.ret_type is None:
        self.ret_type = ret_ty
    elif self.ret_type != ret_ty:
        raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}')

    # A return op must always terminate the basic block, so we create a dead
    # basic block in case there are any ops after the return.
    dead_block = self.builder.create_block()
    self.builder.set_insertion_point_to_end(dead_block)
576
+
577
def visit_Starred(self, node) -> Any:
    """Expand ``*expr``, where ``expr`` is a triton tuple, into its element list."""
    starred = self.visit(node.value)
    assert isinstance(starred, language.core.tuple)
    elements = starred.values
    return elements
581
+
582
def visit_FunctionDef(self, node):
    """Lower the function being compiled into an IR function.

    The steps below:
      1. collect argument names;
      2. evaluate default values by visiting synthesized
         ``name = default`` / ``name: ann = default`` nodes;
      3. create (or fetch) the IR function and its entry block;
      4. bind the deserialized IR arguments to the argument names;
      5. visit the body;
      6. terminate the function with a return.

    Raises:
        The unsupported-construct error from ``self._unsupported`` when
        ``self.fn`` is already set, i.e. for a nested ``def``.
    """
    arg_names, kwarg_names = self.visit(node.args)
    if self.fn:
        raise self._unsupported(node, "nested function definition is not supported.")
    # initialize defaults
    # `defaults` lines up with the *last* len(defaults) positional args, so
    # both lists are walked from the end.
    for i, default_value in enumerate(node.args.defaults[::-1]):
        arg_node = node.args.args[-i - 1]
        annotation = arg_node.annotation
        name = arg_node.arg
        st_target = ast.Name(id=name, ctx=ast.Store())
        if annotation is None:
            init_node = ast.Assign(targets=[st_target], value=default_value)
        else:
            init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
        try:
            # Flag the default-value evaluation; presumably read by other
            # visitors elsewhere in the class (TODO confirm). Re-entry is not expected.
            assert not self.visiting_arg_default_value
            self.visiting_arg_default_value = True
            self.visit(init_node)
        finally:
            self.visiting_arg_default_value = False

    # initialize function
    visibility = "public" if self.is_kernel else "private"
    fn_ty = self.prototype.serialize(self.builder)
    self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline)
    self.module.push_back(self.fn)
    entry = self.fn.add_entry_block()
    arg_values = self.prototype.deserialize(self.fn)
    if self.caller_context is not None:
        self.caller_context.initialize_callee(self.fn, self.builder)
    # bind arguments to symbols
    for arg_name, arg_value in zip(arg_names, arg_values):
        self._maybe_set_loc_to_name(arg_value, arg_name)
        self.set_value(arg_name, arg_value)
    # Remember the caller's insertion block so it can be restored at the end.
    insert_pt = self.builder.get_insertion_block()
    self.builder.set_insertion_point_to_start(entry)
    # visit function body
    self.visit_compound_statement(node.body)

    # finalize function
    assert not self.builder.get_insertion_block().has_terminator()
    if self.ret_type is None or self.ret_type == language.void:
        # No value was ever returned: close the fall-through path with `ret []`.
        self.ret_type = language.void
        self.builder.ret([])
    else:
        # Install the return type(s) recorded by visit_Return into the
        # prototype, retype the IR function, and close the fall-through
        # path with poison values of those types.
        if isinstance(self.ret_type, language.tuple_type):
            self.prototype.ret_types = self.ret_type.types
        else:
            self.prototype.ret_types = [self.ret_type]
        self.fn.reset_type(self.prototype.serialize(self.builder))
        self.builder.ret([self.builder.create_poison(ty) for ty in self.prototype.return_types_ir(self.builder)])
    self.fn.finalize()

    if insert_pt:
        self.builder.set_insertion_point_to_end(insert_pt)
637
+
638
def visit_arguments(self, node):
    """Return ``(positional_arg_names, kwarg_name)`` for a function signature.

    ``kwarg_name`` is whatever visiting ``node.kwarg`` yields (the name of a
    ``**kwargs`` parameter, or the visit result of ``None`` if absent).
    """
    positional = [self.visit(arg) for arg in node.args]
    return positional, self.visit(node.kwarg)
644
+
645
def visit_arg(self, node):
    """Visit the argument's children (e.g. its annotation) and return its name."""
    ast.NodeVisitor.generic_visit(self, node)
    arg_name = node.arg
    return arg_name
648
+
649
def visit_AnnAssign(self, node):
    """Lower an annotated assignment ``target: annotation = value``.

    With a ``constexpr`` annotation, the value is wrapped in ``constexpr`` and
    bound directly in ``self.lscope``. Any other annotation is handled like a
    plain assignment by ``visit_Assign``.

    Returns:
        The bound constexpr in the constexpr case; otherwise whatever
        ``visit_Assign`` returns.

    Raises:
        ValueError: a constexpr target name is already in ``self.lscope``.
    """
    # extract attributes
    annotation = self.visit(node.annotation)
    is_constexpr = annotation == constexpr
    if not is_constexpr:
        # default: call visit_Assign. It evaluates `node.value` itself, so the
        # value must not be visited here too; that used to emit its IR twice.
        return self.visit_Assign(node)
    # constexpr
    target = self.visit(node.target)
    value = self.visit(node.value)
    if target in self.lscope:
        raise ValueError(f'{target} is already defined.'
                         f' constexpr cannot be reassigned.')
    value = constexpr(value)
    self.lscope[target] = value
    return self.lscope[target]
664
+
665
def assignTarget(self, target, value):
    """Bind ``value`` to the assignment-target node ``target``.

    * ``ast.Subscript`` targets go to ``visit_Subscript_Store``.
    * ``ast.Tuple`` targets are unpacked recursively against ``value.values``.
      The number of targets must equal the number of values.
    * ``ast.Attribute`` targets are rejected.
    * ``ast.Name`` targets are bound with ``set_value``.

    Raises:
        ValueError: tuple unpacking with a mismatched number of values.
        NotImplementedError: attribute assignment.
    """
    assert isinstance(target.ctx, ast.Store)
    if isinstance(target, ast.Subscript):
        return self.visit_Subscript_Store(target, value)
    if isinstance(target, ast.Tuple):
        elts = target.elts
        values = value.values
        # Python rejects `a, b = (1, 2, 3)`. Previously surplus values were
        # silently dropped, and missing ones surfaced as an IndexError.
        if len(elts) != len(values):
            raise ValueError(f"cannot unpack {len(values)} values into {len(elts)} targets")
        for elt, elt_value in zip(elts, values):
            self.assignTarget(elt, elt_value)
        return
    if isinstance(target, ast.Attribute):
        raise NotImplementedError("Attribute assignment is not supported in triton")
    assert isinstance(target, ast.Name)
    self.set_value(self.visit(target), value)
677
+
678
def visit_Assign(self, node):
    """Lower ``target = value``; also used for non-constexpr annotated assignments.

    The right-hand side is visited under a location-name prefix when the
    target is a plain name. It is then sanitized, so that plain Python values
    become tensors, and bound to the single target via ``assignTarget``.
    """
    def _sanitize_value(value):
        # Tuples are sanitized element-wise.
        if isinstance(value, language.tuple):
            return _apply_to_tuple_values(value, _sanitize_value)
        value = _unwrap_if_constexpr(value)
        # None, triton values, dtypes and tuples pass through unchanged;
        # any other Python value is converted to a tensor.
        keep_as_is = (value is None or _is_triton_value(value)
                      or isinstance(value, (language.dtype, language.tuple)))
        if keep_as_is:
            return value
        return self.semantic.to_tensor(value)

    if isinstance(node, ast.AnnAssign):
        targets = [node.target]
    else:
        targets = node.targets
    assert len(targets) == 1
    target = targets[0]
    rhs = node.value
    if not isinstance(target, ast.Name):
        self.assignTarget(target, _sanitize_value(self.visit(rhs)))
        return
    with self._name_loc_prefix(target.id):
        values = _sanitize_value(self.visit(rhs))
    self.assignTarget(target, values)
700
+
701
def visit_AugAssign(self, node):
    """Desugar ``target op= value`` into ``target = target op value``.

    The right-hand ``target`` is a deep copy of the target node with a
    ``Load`` context. After the assignment has been visited, the target is
    re-read and returned.
    """
    load_target = copy.deepcopy(node.target)
    load_target.ctx = ast.Load()
    desugared = ast.Assign(targets=[node.target], value=ast.BinOp(load_target, node.op, node.value))
    self.visit(desugared)
    return self.visit(load_target)
708
+
709
def visit_Name(self, node):
    """Return the identifier for a store; otherwise resolve the name's value."""
    if type(node.ctx) is not ast.Store:
        return self.dereference_name(node.id)
    return node.id
713
+
714
def visit_Store(self, node):
    """A ``Store`` context carries no data: traverse it generically and return None."""
    ast.NodeVisitor.generic_visit(self, node)
    return None
716
+
717
def visit_Load(self, node):
    """A ``Load`` context carries no data: traverse it generically and return None."""
    ast.NodeVisitor.generic_visit(self, node)
    return None
719
+
720
def visit_Tuple(self, node):
    """Lower a tuple display into a ``language.tuple`` of its visited elements."""
    elements = list(map(self.visit, node.elts))
    return language.tuple(elements)
723
+
724
def _apply_binary_method(self, method_name, lhs, rhs):
    """Dispatch the binary dunder ``method_name`` on ``lhs``/``rhs``.

    A tensor on the left gets the forward method and a tensor on the right
    gets the reflected one; both receive ``_semantic``. Otherwise ``lhs`` is
    wrapped in ``constexpr`` when ``rhs`` is a constexpr (unless ``lhs`` is
    already a constexpr or a tuple), and the plain method is called.
    """
    # TODO: raise something meaningful if getattr fails below, esp for reverse method
    if _is_triton_tensor(lhs):
        return getattr(lhs, method_name)(rhs, _semantic=self.semantic)
    if _is_triton_tensor(rhs):
        # '__add__' -> '__radd__', etc.
        reflected = re.sub(r"__(.*)__", r"__r\1__", method_name)
        return getattr(rhs, reflected)(lhs, _semantic=self.semantic)
    lhs_is_wrapped = isinstance(lhs, (constexpr, language.tuple))
    if isinstance(rhs, constexpr) and not lhs_is_wrapped:
        lhs = constexpr(lhs)
    return getattr(lhs, method_name)(rhs)
734
+
735
def visit_BinOp(self, node):
    """Lower a binary-operator expression via the operator's dunder method.

    Both operands are visited. The method name is looked up in
    ``_method_name_for_bin_op`` and applied with ``_apply_binary_method``.

    Raises:
        The unsupported-construct error from ``self._unsupported`` when the
        operator has no mapping.
    """
    lhs = self.visit(node.left)
    rhs = self.visit(node.right)
    op_type = type(node.op)
    method_name = self._method_name_for_bin_op.get(op_type)
    if method_name is None:
        # `node.op` is an operator *instance*; only its class has `__name__`.
        # Reading it from the instance raised AttributeError instead of this error.
        raise self._unsupported(node,
                                "AST binary operator '{}' is not (currently) implemented.".format(op_type.__name__))
    return self._apply_binary_method(method_name, lhs, rhs)
743
+
744
# Maps each supported binary-operator AST class to the dunder method that implements it.
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
    # arithmetic
    ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__',
    ast.Div: '__truediv__', ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__',
    ast.Pow: '__pow__',
    # shifts and bitwise
    ast.LShift: '__lshift__', ast.RShift: '__rshift__',
    ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__',
}
758
+
759
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
    """Visit both arms of an ``if`` and work out which variables must be merged.

    Args:
        node: the ``ast.If`` node.
        liveins: snapshot of the symbol table on entry to the ``if``.
        then_block: block to emit the ``then`` arm into.
        else_block: block to emit the ``else`` arm into. It is only used when
            ``node.orelse`` is non-empty.

    Returns:
        ``(then_defs, else_defs, then_block, else_block, names)``.
        The blocks are the insertion blocks current at the end of each arm.
        ``names`` lists the variables that need merging after the ``if``:
        first, live-in triton values whose flattened IR handles differ between
        the arms; then, in sorted order, the remaining names defined in both arms.

    Raises:
        AssertionError: a merged variable has a different Python class or
            ``.type`` between the arms, or versus its live-in value.
    """
    # then block
    # NOTE(review): assumes the caller has reset self.local_defs for this
    # region (e.g. via enter_sub_region) - TODO confirm.
    self.builder.set_insertion_point_to_start(then_block)
    self.visit_compound_statement(node.body)
    then_block = self.builder.get_insertion_block()
    then_defs = self.local_defs.copy()
    then_vals = self.lscope.copy()
    # else block
    # Without an else arm, every live-in keeps its entry value on that path.
    else_defs = {}
    else_vals = liveins.copy()
    if node.orelse:
        self.builder.set_insertion_point_to_start(else_block)
        # The else arm must not see definitions made by the then arm.
        self.lscope = liveins.copy()
        self.local_defs = {}
        self.visit_compound_statement(node.orelse)
        else_defs = self.local_defs.copy()
        else_block = self.builder.get_insertion_block()
        else_vals = self.lscope.copy()

    # update block arguments
    names = []
    # variables in livein whose value is updated in `if`
    for name, value in liveins.items():
        # livein variable changed value in either then or else
        if not _is_triton_value(value):
            continue
        then_handles = flatten_values_to_ir([then_vals[name]])
        else_handles = flatten_values_to_ir([else_vals[name]])
        if then_handles == else_handles:
            continue
        names.append(name)
        then_defs[name] = then_vals[name]
        else_defs[name] = else_vals[name]
        # check type
        for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
            type_equal = type(defs[name]) == type(value)  # noqa: E721
            assert type_equal and defs[name].type == value.type, \
                f'initial value for `{name}` is of type {value}, '\
                f'but the {block_name} block redefines it as {defs[name]}'

    # variables that are both in then and else but not in liveins
    # TODO: could probably be cleaned up
    # Sorting keeps the merged-value order deterministic.
    for name in sorted(then_defs.keys() & else_defs.keys()):
        if name in names:
            continue
        then_val = then_defs[name]
        then_ty = then_val.type
        else_val = else_defs[name]
        else_ty = else_val.type
        type_equal = type(then_val) == type(else_val)  # noqa: E721
        assert type_equal and then_ty == else_ty, \
            f'Mismatched type for {name} between then block ({then_ty}) '\
            f'and else block ({else_ty})'
        names.append(name)

    return then_defs, else_defs, then_block, else_block, names
815
+
816
def visit_if_top_level(self, cond, node):
    """Lower an ``if`` with unstructured control flow (used when the body contains ``return``).

    The lowering has three parts:
      1. A conditional branch jumps to separate then/else blocks.
      2. Each arm ends with a branch to a new merge block, passing the values
         named by ``visit_then_else_blocks`` as block arguments.
      3. The merge block's arguments become the new values of those names.

    Args:
        cond: scalar condition value exposing ``.handle``.
        node: the ``ast.If`` node.
    """
    with enter_sub_region(self) as sr:
        liveins, ip_block = sr
        then_block = self.builder.create_block()
        else_block = self.builder.create_block()
        # create branch
        self.builder.set_insertion_point_to_end(ip_block)
        self.builder.create_cond_branch(cond.handle, then_block, else_block)
        # visit then and else blocks
        then_defs, else_defs, then_block, else_block, names = \
            self.visit_then_else_blocks(node, liveins, then_block, else_block)
        # create basic-block after conditional
        endif_block = self.builder.create_block()
        # then terminator
        self.builder.set_insertion_point_to_end(then_block)
        assert not then_block.has_terminator(), f"{then_block}"
        then_handles = flatten_values_to_ir(then_defs[name] for name in names)
        self.builder.create_branch(endif_block, then_handles)
        # else terminator
        self.builder.set_insertion_point_to_end(else_block)
        assert not else_block.has_terminator(), f"{else_block}"
        else_handles = flatten_values_to_ir(else_defs[name] for name in names)
        self.builder.create_branch(endif_block, else_handles)
        # Both arms must pass the same IR types; declare one merge-block
        # argument per flattened handle.
        assert len(then_handles) == len(else_handles)
        for then_h, else_h in zip(then_handles, else_handles):
            ty = then_h.get_type()
            assert ty == else_h.get_type()
            endif_block.add_argument(ty)

    # change block
    self.builder.set_insertion_point_to_start(endif_block)
    # update value
    # Rebuild frontend values from the merge-block arguments and rebind them.
    res_handles = [endif_block.arg(i) for i in range(len(then_handles))]
    types = [then_defs[name].type for name in names]
    new_values = unflatten_ir_values(res_handles, types)
    for name, new_value in zip(names, new_values):
        self.set_value(name, new_value)
853
+
854
+ # TODO: refactor
855
def visit_if_scf(self, cond, node):
    """Lower an ``if`` (without ``return``) into a structured if op.

    Both arms are emitted into detached blocks by ``visit_then_else_blocks``.
    Those blocks are then merged into the regions of an op from
    ``builder.create_if_op``. When any variables need merging, each region
    yields them. Finally the op's results are rebound to those names.

    Args:
        cond: scalar condition value exposing ``.handle``.
        node: the ``ast.If`` node.
    """
    with enter_sub_region(self) as sr:
        liveins, _ = sr
        # The if op must be created at the current insertion point, not
        # inside the arm blocks, so remember it now.
        ip, last_loc = self._get_insertion_point_and_loc()
        then_block = self.builder.create_block()
        else_block = self.builder.create_block() if node.orelse else None
        then_defs, else_defs, then_block, else_block, names = \
            self.visit_then_else_blocks(node, liveins, then_block, else_block)
        # create if op
        then_handles = flatten_values_to_ir(then_defs[name] for name in names)
        for name, val in zip(names, then_handles):
            self._maybe_set_loc_to_name(val, name)
        self._set_insertion_point_and_loc(ip, last_loc)
        # Result types come from the flattened then-arm handles.
        if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)
        then_block.merge_block_before(if_op.get_then_block())
        self.builder.set_insertion_point_to_end(if_op.get_then_block())
        if len(names) > 0:
            self.builder.create_yield_op(then_handles)
        if not node.orelse:
            else_block = if_op.get_else_block()
        else:
            else_block.merge_block_before(if_op.get_else_block())
        self.builder.set_insertion_point_to_end(if_op.get_else_block())
        if len(names) > 0:
            else_handles = flatten_values_to_ir(else_defs[name] for name in names)
            for name, val in zip(names, else_handles):
                self._maybe_set_loc_to_name(val, name)
            self.builder.create_yield_op(else_handles)
    # update values
    # Rebuild frontend values from the if op's results and rebind them.
    res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
    types = [then_defs[name].type for name in names]
    new_values = unflatten_ir_values(res_handles, types)
    for name, new_value in zip(names, new_values):
        self.set_value(name, new_value)
889
+
890
def visit_If(self, node):
    """Lower an ``if`` statement.

    Non-tensor conditions must be one of ``_condition_types``; the selected
    branch is visited at compile time. A tensor condition must hold a single
    value and is converted to ``int1``. Bodies containing ``return`` use
    ``visit_if_top_level`` (not allowed inside loops); all other bodies use
    ``visit_if_scf``.
    """
    cond = self.visit(node.test)

    if not _is_triton_tensor(cond):
        cond = _unwrap_if_constexpr(cond)
        # not isinstance - we insist the real thing, no subclasses and no ducks
        if type(cond) not in _condition_types:
            raise self._unsupported(
                node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                    ', '.join(_.__name__ for _ in _condition_types),
                    type(cond).__name__))
        self.visit_compound_statement(node.body if cond else node.orelse)
        return

    if _is_non_scalar_tensor(cond):
        raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
    if cond.type.is_block():
        warnings.warn(
            "If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead"
            % ast.unparse(node.test))
        cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
    cond = cond.to(language.int1, _semantic=self.semantic)

    if not ContainsReturnChecker(self.gscope).visit(node):
        self.visit_if_scf(cond, node)
        return
    if self.scf_stack:
        raise self._unsupported(
            node, "Cannot have `return` statements inside `while` or `for` statements in triton.")
    self.visit_if_top_level(cond, node)
920
+
921
def visit_IfExp(self, node):
    """Lower a conditional expression ``body if test else orelse``.

    Tensor condition:
      - Both arms are evaluated into separate blocks and converted with
        ``to_tensor``. The two arms must have the same type.
      - The arms are merged into an if op. The op's single result is the
        expression's value, or ``None`` when the type is void.

    Non-tensor condition:
      - The condition must be one of ``_condition_types``.
      - Only the selected arm is visited.
    """
    cond = self.visit(node.test)
    if _is_triton_tensor(cond):
        cond = cond.to(language.int1, _semantic=self.semantic)
        # TODO: Deal w/ more complicated return types (e.g tuple)
        with enter_sub_region(self):
            # Remember where the if op itself has to be inserted.
            ip, last_loc = self._get_insertion_point_and_loc()

            # Evaluate the `then` arm into a detached block; re-read the
            # current block since visiting may have moved the insertion point.
            then_block = self.builder.create_block()
            self.builder.set_insertion_point_to_start(then_block)
            then_val = self.semantic.to_tensor(self.visit(node.body))
            then_block = self.builder.get_insertion_block()

            else_block = self.builder.create_block()
            self.builder.set_insertion_point_to_start(else_block)
            # do not need to reset lscope since
            # ternary expressions cannot define new variables
            else_val = self.semantic.to_tensor(self.visit(node.orelse))
            else_block = self.builder.get_insertion_block()

            self._set_insertion_point_and_loc(ip, last_loc)

            assert then_val.type == else_val.type, \
                f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
            ret_type = then_val.type

            # A void-typed expression produces an if op with no results.
            ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
            if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
            then_block.merge_block_before(if_op.get_then_block())
            if ret_type_ir:
                self.builder.set_insertion_point_to_end(if_op.get_then_block())
                self.builder.create_yield_op([then_val.handle])

            # NOTE(review): this re-positions at the end of the then block
            # before merging the else arm; it appears redundant - confirm.
            self.builder.set_insertion_point_to_end(if_op.get_then_block())
            else_block.merge_block_before(if_op.get_else_block())
            if ret_type_ir:
                self.builder.set_insertion_point_to_end(if_op.get_else_block())
                self.builder.create_yield_op([else_val.handle])
            return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
    else:
        cond = _unwrap_if_constexpr(cond)

        # not isinstance - we insist the real thing, no subclasses and no ducks
        if type(cond) not in _condition_types:
            raise self._unsupported(
                node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                    ', '.join(_.__name__ for _ in _condition_types),
                    type(cond).__name__))
        if cond:
            return self.visit(node.body)
        else:
            return self.visit(node.orelse)
973
+
974
def visit_With(self, node):
    """Lower a ``with`` statement whose items are calls building context managers.

    Single ``language.async_task(...)`` item: it is constructed with
    ``_builder`` injected and entered around the body.

    Any other ``with`` statement:
      - Each callee is invoked with ``_semantic`` injected.
      - The managers are entered in order, and ``as`` targets are bound.
      - The body is visited.
      - The managers are exited in reverse order.

    Raises:
        Unsupported-construct error (``self._unsupported``) when a context
        expression is not a call, or when the body contains ``return``
        (async_task path excepted).
    """
    # Every item must be a call: `.func`/`.args`/`.keywords` are read below,
    # and a non-call previously crashed with an AttributeError.
    for item in node.items:
        if not isinstance(item.context_expr, ast.Call):
            raise self._unsupported(node, "`with` items must be calls that construct a context manager")

    # Lower `with` statements by constructing context managers and calling their enter/exit hooks
    # Instantiate each context manager with builder injection
    if len(node.items) == 1:  # Handle async_task
        context = node.items[0].context_expr
        withitemClass = self.visit(context.func)
        if withitemClass == language.async_task:
            args = [self.visit(arg) for arg in context.args]
            with withitemClass(*args, _builder=self.builder):
                self.visit_compound_statement(node.body)
            return

    # Reject `return` before any manager is entered, so that no `__enter__`
    # is left without its matching `__exit__`.
    if ContainsReturnChecker(self.gscope).visit(node):
        raise self._unsupported(node, "Cannot have `return` statements inside `with` statements in triton ")

    cm_list = []
    for item in node.items:
        call = item.context_expr
        fn = self.visit(call.func)
        args = [self.visit(arg) for arg in call.args]
        kws = dict(self.visit(kw) for kw in call.keywords)
        cm_list.append(fn(*args, _semantic=self.semantic, **kws))
    for cm, item in zip(cm_list, node.items):
        res = cm.__enter__()
        if item.optional_vars is not None:
            var_name = self.visit(item.optional_vars)
            self.set_value(var_name, res)
    self.visit_compound_statement(node.body)
    for cm in reversed(cm_list):
        cm.__exit__(None, None, None)
1004
+
1005
def visit_Pass(self, node):
    """``pass`` emits nothing."""
    return None
1007
+
1008
def visit_Compare(self, node):
    """Lower a single comparison ``lhs <op> rhs``.

    ``is`` / ``is not`` compare the constexpr-unwrapped operands by identity
    and return a ``constexpr`` bool. The other supported operators dispatch
    through ``_method_name_for_comp_op`` and ``_apply_binary_method``.

    Raises:
        The unsupported-construct error from ``self._unsupported`` for
        chained comparisons and for operators without a mapping
        (e.g. ``in`` / ``not in``).
    """
    if not (len(node.comparators) == 1 and len(node.ops) == 1):
        raise self._unsupported(node, "simultaneous multiple comparison is not supported")
    lhs = self.visit(node.left)
    rhs = self.visit(node.comparators[0])
    lhs_value = _unwrap_if_constexpr(lhs)
    rhs_value = _unwrap_if_constexpr(rhs)
    op_type = type(node.ops[0])
    if op_type is ast.Is:
        return constexpr(lhs_value is rhs_value)
    if op_type is ast.IsNot:
        return constexpr(lhs_value is not rhs_value)
    method_name = self._method_name_for_comp_op.get(op_type)
    if method_name is None:
        # `node.ops[0]` is an operator instance; its class carries `__name__`.
        # Reading it from the instance raised AttributeError instead of this error.
        raise self._unsupported(
            node, "AST comparison operator '{}' is not (currently) implemented.".format(op_type.__name__))
    return self._apply_binary_method(method_name, lhs, rhs)
1024
+
1025
# Maps each supported comparison-operator AST class to its rich-comparison dunder.
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
    ast.Eq: '__eq__',
    ast.NotEq: '__ne__',
    ast.Lt: '__lt__',
    ast.LtE: '__le__',
    ast.Gt: '__gt__',
    ast.GtE: '__ge__',
}
1028
+
1029
def visit_UnaryOp(self, node):
    """Lower a unary operator (``-x``, ``+x``, ``not x``, ``~x``).

    Tensor operands call the mapped dunder with ``_semantic`` injected. Other
    operands call the plain dunder. When the operand lacks the dunder, ``not``
    falls back to Python's ``not`` wrapped in ``constexpr``.

    Raises:
        The unsupported-construct error from ``self._unsupported`` when the
        operator has no mapping, or when the operand lacks the dunder (other
        than for ``not``).
    """
    operand = self.visit(node.operand)
    op_type = type(node.op)
    fn = self._method_name_for_unary_op.get(op_type)
    if fn is None:
        # `node.op` is an operator instance; its class carries `__name__`.
        raise self._unsupported(node, f"AST unary operator '{op_type.__name__}' is not (currently) implemented.")
    if _is_triton_tensor(operand):
        return getattr(operand, fn)(_semantic=self.semantic)
    # Only the attribute lookup is guarded. An AttributeError raised *inside*
    # the dunder is a real error and must not be reported as "unsupported".
    try:
        method = getattr(operand, fn)
    except AttributeError:
        if fn == "__not__":
            return constexpr(not operand)
        raise self._unsupported(
            node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
    return method()
1043
+
1044
# Maps each supported unary-operator AST class to the dunder method that implements it.
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
    ast.USub: '__neg__',
    ast.UAdd: '__pos__',
    ast.Not: '__not__',
    ast.Invert: '__invert__',
}
1047
+
1048
def _verify_loop_carried_variable(self, name, loop_val, live_val):
    """Check that a loop-carried variable keeps its kind and type across the loop.

    Args:
        name: variable name, used only in the error messages.
        loop_val: the value ``name`` is re-assigned to inside the loop.
        live_val: the value ``name`` had on entry to the loop (its initial value).

    Raises:
        AssertionError: either value is not a triton value (e.g. a constexpr),
            the two values have different Python classes, or tensor types differ.
    """
    assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
    assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
    # `live_val` is the initial value and `loop_val` the re-assigned one; the
    # message previously reported them the other way round.
    assert type(loop_val) is type(live_val), (
        f'Loop carried variable {name} changed type, was {type(live_val)} but is now {type(loop_val)}')
    assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
        f'Loop-carried variable {name} has initial type {live_val.type} '\
        f'but is re-assigned to {loop_val.type} in loop! '\
        f'Please make sure that the type stays consistent.'
1057
+
1058
def visit_withitem(self, node):
    """Evaluate and return the context expression of a single ``with`` item."""
    context_expr = node.context_expr
    return self.visit(context_expr)
1060
+
1061
def visit_While(self, node):
    """Lower a ``while`` loop to a while op with "before" and "after" regions.

    Loop-carried variables are found with ``_find_carries`` and threaded
    through four places:
      - the op's operands;
      - the block arguments of both regions;
      - the yield at the end of the body;
      - the op's results.
    The condition is visited in the "before" region. A ``language.condition``
    wrapper may request that LICM be disabled for the loop. After the op, the
    carried names are rebound to its results.

    Raises:
        Unsupported-construct error (``self._unsupported``) when the loop has
        an ``else`` clause.
    """
    if node.orelse:
        # Previously rejected with `assert False` only after the loop had
        # been lowered. Under `python -O` that assert disappears and the
        # else-body's children were then visited as ordinary code.
        raise self._unsupported(node, "`else` clauses on `while` loops are not supported")

    with enter_sub_region(self) as sr:
        liveins, insert_block = sr
        ip, last_loc = self._get_insertion_point_and_loc()

        names, init_handles, init_fe_tys = self._find_carries(node, liveins)

        init_tys = [h.get_type() for h in init_handles]
        self._set_insertion_point_and_loc(ip, last_loc)
        while_op = self.builder.create_while_op(init_tys, init_handles)
        # merge the condition region
        before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys)
        self.builder.set_insertion_point_to_start(before_block)
        block_args = [before_block.arg(i) for i in range(len(init_handles))]
        condition_args = unflatten_ir_values(block_args, init_fe_tys)
        for name, val in zip(names, condition_args):
            self.lscope[name] = val
            self.local_defs[name] = val
            self._maybe_set_loc_to_name(val, name)
        cond = self.visit(node.test)
        if isinstance(cond, language.condition):
            if cond.disable_licm:
                while_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
            cond = cond.condition
        self.builder.set_insertion_point_to_end(before_block)
        # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
        self.builder.create_condition_op(cond.handle, block_args)
        # merge the loop body
        after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys)

        # generate loop body
        self.builder.set_insertion_point_to_start(after_block)
        body_handles = [after_block.arg(i) for i in range(len(init_handles))]
        body_args = unflatten_ir_values(body_handles, init_fe_tys)
        for name, val in zip(names, body_args):
            self.lscope[name] = val
            self.local_defs[name] = val
            self._maybe_set_loc_to_name(val, name)
        self.scf_stack.append(node)
        self.visit_compound_statement(node.body)
        self.scf_stack.pop()

        yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
        self.builder.create_yield_op(yield_handles)

    # WhileOp defines new values, update the symbol table (lscope, local_defs)
    result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
    result_vals = unflatten_ir_values(result_handles, init_fe_tys)
    for name, new_def in zip(names, result_vals):
        self.lscope[name] = new_def
        self.local_defs[name] = new_def
        self._maybe_set_loc_to_name(new_def, name)
1117
+
1118
def visit_Subscript_Load(self, node):
    """Lower ``value[slice]`` in a load context.

    Triton values go through ``call_Method`` with their ``__getitem__``.
    Anything else (e.g. a plain Python container) is indexed directly.
    """
    assert isinstance(node.ctx, ast.Load)
    base = self.visit(node.value)
    index = self.visit(node.slice)
    if not _is_triton_value(base):
        return base[index]
    return self.call_Method(node, base.__getitem__, base, [index], {})
1125
+
1126
def visit_Subscript_Store(self, node, value):
    """Item assignment (``x[i] = v``) is unsupported; this always raises NotImplementedError."""
    message = "__setitem__ is not supported in triton"
    raise NotImplementedError(message)
1128
+
1129
def visit_Subscript(self, node):
    """Treat a subscript reached through generic visiting as a load."""
    load = self.visit_Subscript_Load
    return load(node)
1131
+
1132
def visit_ExtSlice(self, node):
    """Return the visited value of each dimension of a legacy ``ExtSlice`` node, as a list."""
    return list(map(self.visit, node.dims))
1134
+
1135
def visit_For(self, node):
    """Lower a ``for`` loop.

    ``static_range`` loops are unrolled at compile time: the body is visited
    once per iteration with the target bound to a ``constexpr``.

    Loops over ``language.range`` or the builtin ``range`` become a for op
    with integer bounds:
      - Loop-carried variables are found with ``_find_carries`` and threaded
        through the op's block arguments and results.
      - A negative constexpr step is normalized by negating it and swapping
        the bounds; the user-visible induction variable is then rebuilt
        inside the body.

    Raises:
        RuntimeError: the iterator is not ``range``/``language.range``/``static_range``.
        TypeError: the bounds or step are not integers.
        Unsupported-construct error (``self._unsupported``): an ``else`` clause
            on a non-static loop.
    """
    IteratorClass = self.visit(node.iter.func)
    iter_args = [self.visit(arg) for arg in node.iter.args]
    iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords)
    if IteratorClass == language.static_range:
        iterator = IteratorClass(*iter_args, **iter_kwargs)
        static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
        for i in static_range:
            self.lscope[node.target.id] = constexpr(i)
            self.visit_compound_statement(node.body)
            for stmt in node.orelse:
                ast.NodeVisitor.generic_visit(self, stmt)
        return
    if node.orelse:
        # Previously rejected with `assert False` only after the loop had been
        # lowered. Under `python -O` that assert disappears and the else-body's
        # children were then visited as ordinary code.
        raise self._unsupported(node, "`else` clauses on `for` loops are not supported")
    num_stages = None
    loop_unroll_factor = None
    disallow_acc_multi_buffer = False
    flatten = False
    warp_specialize = False
    disable_licm = False
    if IteratorClass is language.range:
        iterator = IteratorClass(*iter_args, **iter_kwargs)
        # collect lower bound (lb), upper bound (ub), and step, plus the
        # loop attributes carried by `language.range`
        lb = iterator.start
        ub = iterator.end
        step = iterator.step
        num_stages = iterator.num_stages
        loop_unroll_factor = iterator.loop_unroll_factor
        disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
        flatten = iterator.flatten
        warp_specialize = iterator.warp_specialize
        disable_licm = iterator.disable_licm
    elif IteratorClass is range:
        # collect lower bound (lb), upper bound (ub), and step.
        # `ast.Num` is deprecated and removed in Python 3.14; since 3.8 it has
        # produced `ast.Constant` nodes anyway. For `range(n)` the upper bound
        # reuses the already-visited argument: visiting the AST node again
        # would emit its IR a second time.
        lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Constant(0))
        ub = iter_args[1] if len(iter_args) > 1 else iter_args[0]
        step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Constant(1))
    else:
        raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
    # handle negative constant step (not supported by scf.for in MLIR)
    negative_step = False
    if _is_constexpr(step) and step.value < 0:
        step = constexpr(-step.value)
        negative_step = True
        lb, ub = ub, lb
    lb = self.semantic.to_tensor(lb)
    ub = self.semantic.to_tensor(ub)
    step = self.semantic.to_tensor(step)
    # induction variable type
    if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
        raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
    iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype)
    iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype)
    iv_ir_type = iv_type.to_ir(self.builder)
    iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
    # lb/ub/step might be constexpr, we need to cast them to tensor
    lb = lb.handle
    ub = ub.handle
    step = step.handle
    # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
    lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed)
    ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed)
    step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed)
    # Create placeholder for the loop induction variable
    iv = self.builder.create_poison(iv_ir_type)
    self.set_value(node.target.id, language.core.tensor(iv, iv_type))

    with enter_sub_region(self) as sr:
        liveins, insert_block = sr
        ip, last_loc = self._get_insertion_point_and_loc()

        names, init_handles, init_tys = self._find_carries(node, liveins)

        # create ForOp
        self._set_insertion_point_and_loc(ip, last_loc)
        for_op = self.builder.create_for_op(lb, ub, step, init_handles)
        if _unwrap_if_constexpr(num_stages) is not None:
            for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
        if _unwrap_if_constexpr(loop_unroll_factor) is not None:
            for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
        if disallow_acc_multi_buffer:
            for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
        if flatten:
            for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
        if warp_specialize:
            for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())
        if disable_licm:
            for_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())

        self.scf_stack.append(node)
        for_op_body = for_op.get_body(0)
        self.builder.set_insertion_point_to_start(for_op_body)
        # Block argument 0 is the induction variable; carried values follow.
        block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
        block_args = unflatten_ir_values(block_handles, init_tys)
        for name, val in zip(names, block_args):
            self._maybe_set_loc_to_name(val, name)
            self.set_value(name, val)
        self.visit_compound_statement(node.body)
        self.scf_stack.pop()
        yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)

        # create YieldOp
        if len(yield_handles) > 0:
            self.builder.create_yield_op(yield_handles)
        for_op_region = for_op_body.get_parent()
        assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"

        # update induction variable with actual value, and replace all uses
        self.builder.set_insertion_point_to_start(for_op_body)
        iv = for_op.get_induction_var()
        if negative_step:
            iv = self.builder.create_sub(ub, iv)
        iv = self.builder.create_add(iv, lb)
        self.lscope[node.target.id].handle.replace_all_uses_with(iv)
        self.set_value(node.target.id, language.core.tensor(iv, iv_type))
        self._maybe_set_loc_to_name(iv, node.target.id)

    # update lscope & local_defs (ForOp defines new values)
    result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
    result_values = unflatten_ir_values(result_handles, init_tys)
    for name, val in zip(names, result_values):
        self.set_value(name, val)
        self._maybe_set_loc_to_name(val, name)
1265
+
1266
def visit_Slice(self, node):
    """Lower ``lower:upper:step`` into a ``language.slice`` of the visited parts."""
    lower, upper, step = (self.visit(part) for part in (node.lower, node.upper, node.step))
    return language.slice(lower, upper, step)
1271
+
1272
def visit_Index(self, node):
    """Unwrap a legacy ``Index`` node by visiting its inner value."""
    inner = node.value
    return self.visit(inner)
1274
+
1275
def visit_keyword(self, node) -> Tuple[str, Any]:
    """Return ``(name, visited_value)`` for a call keyword ``name=value``."""
    keyword_name = node.arg
    return (keyword_name, self.visit(node.value))
1277
+
1278
def visit_Assert(self, node) -> Any:
    """Lower ``assert test[, msg]`` into a device-side assertion.

    A missing message becomes the empty string.
    """
    test = self.visit(node.test)
    msg = "" if node.msg is None else self.visit(node.msg)
    return language.core.device_assert(test, msg, _semantic=self.semantic)
1282
+
1283
def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
    """Emit a call from the function being generated to the ``@jit`` function ``fn``.

    The call's arguments are bound against ``fn.fn``'s Python signature and split into
    constexpr leaves and runtime values. The constexpr leaves feed the mangled callee name
    and the callee prototype. The runtime values are flattened into the call's IR operands.
    A body for the mangled name is generated into ``self.module`` only if the module does not
    already contain a function with that name. Otherwise the previously recorded return type
    is reused.

    Args:
        fn: the callee.
        args, kwargs: the call's positional and keyword arguments.
        caller_context: forwarded to ``mangle_fn`` and to the callee's generator. Defaults to
            ``self.caller_context`` when falsy.

    Returns:
        ``None`` when the callee's return type is ``language.void``. Otherwise, the callee's
        return value is rebuilt from the call op's IR results.

    Raises:
        CompilationError: wraps any exception raised while generating the callee, located at
            the current call node. When ``knobs.compilation.front_end_debugging`` is set, the
            original exception propagates instead.
    """
    # Bind to the callee's Python signature, then order the bound values by fn.arg_names.
    args = inspect.getcallargs(fn.fn, *args, **kwargs)
    args = [args[name] for name in fn.arg_names]
    # Wrap plain Python scalars, dtypes and JIT functions as constexpr so that they are
    # classified as compile-time arguments below.
    for i, arg in enumerate(args):
        if isinstance(arg, (language.dtype, float, int, bool, JITFunction)):
            args[i] = language.core.constexpr(arg)
    # Constexpr leaves are keyed by their path into (possibly nested) arguments.
    # The remaining leaves are the runtime values passed as call operands.
    args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
    args_cst = {path: get_iterable_path(args, path) for path in args_cst}
    args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
    args_val = [get_iterable_path(args, path) for path in args_path]
    # mangle: the name is built from the callee's full name, the runtime argument types,
    # the constexpr values and the caller context.
    caller_context = caller_context or self.caller_context
    fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst, caller_context)
    # generate function def if necessary
    if not self.module.has_function(fn_name):
        # The callee's generator reuses the caller's builder options, codegen hooks, module map,
        # gluon mode and function_ret_types table.
        file_name, begin_line = get_jit_fn_file_line(fn)
        # Prototype argument types: None / bool / int / dtype arguments are typed as constexpr;
        # any other argument uses its own ``.type``.
        arg_types = [
            language.core.constexpr if arg is None or isinstance(arg,
                                                                 (bool, int, language.core.dtype)) else arg.type
            for arg in args
        ]
        prototype = ASTFunction([], arg_types, args_cst, dict())
        generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn,
                                  function_name=fn_name, function_types=self.function_ret_types,
                                  noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
                                  options=self.builder.options, codegen_fns=self.builder.codegen_fns,
                                  module_map=self.builder.module_map, caller_context=caller_context,
                                  is_gluon=self.is_gluon)
        try:
            generator.visit(fn.parse())
        except Exception as e:
            # Wrap the error in the callee with the location of the call.
            if knobs.compilation.front_end_debugging:
                raise
            raise CompilationError(self.jit_fn.src, self.cur_node, None) from e

        callee_ret_type = generator.ret_type
        self.function_ret_types[fn_name] = callee_ret_type
    else:
        # Already generated: reuse the return type recorded when it was first generated.
        callee_ret_type = self.function_ret_types[fn_name]
    symbol = self.module.get_function(fn_name)
    args_val = flatten_values_to_ir(args_val)
    call_op = self.builder.call(symbol, args_val)
    if callee_ret_type == language.void:
        return None
    # Rebuild the return value from the call op's flat list of IR results.
    handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
    return next(unflatten_ir_values(handles, [callee_ret_type]))
1331
+
1332
def call_Function(self, node, fn, args, kws):
    """Dispatch a call to ``fn`` with already-evaluated ``args`` (list) and ``kws`` (dict).

    Bound ``BoundJITMethod`` / ``BoundConstexprFunction`` callees are unbound first: the
    receiver is inserted at the front of ``args``, which mutates the caller's list, and ``fn``
    becomes the underlying function. Then, in order:

    1. ``JITFunction``: arguments are checked with ``_check_fn_args`` and the call is emitted
       via ``call_JitFunction``.
    2. Methods of Triton values, Triton builtins (``language.core.is_builtin``) and
       ``ConstexprFunction``: called directly. ``_semantic`` / ``_generator`` are passed when
       the callee's signature declares them. A plain ``tuple`` result is wrapped in
       ``language.tuple``. Exceptions become ``CompilationError`` at ``node``, unless
       front-end debugging is on.
    3. Anything else is called as ordinary Python at compile time. For callees found in
       ``self.builtin_namespace`` the arguments are unwrapped from constexpr first. The result,
       or each element of a tuple result, is wrapped in ``constexpr`` unless it is already a
       Triton value.
    """
    if isinstance(fn, (BoundJITMethod, BoundConstexprFunction)):
        args.insert(0, fn.__self__)
        fn = fn.__func__
    if isinstance(fn, JITFunction):
        _check_fn_args(node, fn, args)
        return self.call_JitFunction(fn, args, kws)
    if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance(
            fn, ConstexprFunction):
        extra_kwargs = dict()

        # Inspect the callee to decide which generator-provided keyword arguments it accepts.
        if isinstance(fn, ConstexprFunction):
            sig = inspect.signature(fn.__call__)
        else:
            sig = inspect.signature(fn)
        if '_semantic' in sig.parameters:
            extra_kwargs["_semantic"] = self.semantic
        if '_generator' in sig.parameters:
            extra_kwargs['_generator'] = self
        try:
            ret = fn(*args, **extra_kwargs, **kws)
            # builtin functions return plain tuples for readability
            if isinstance(ret, tuple):
                ret = language.tuple(ret)
            return ret
        except Exception as e:
            if knobs.compilation.front_end_debugging:
                raise
            # Normally when we raise a CompilationError, we raise it as
            # `from None`, because the original fileline from the exception
            # is not relevant (and often points into code_generator.py
            # itself). But when calling a function, we raise as `from e` to
            # preserve the traceback of the original error, which may e.g.
            # be in core.py.
            raise CompilationError(self.jit_fn.src, node, str(e)) from e

    # Plain Python callable executed at compile time.
    if fn in self.builtin_namespace.values():
        args = map(_unwrap_if_constexpr, args)
    ret = fn(*args, **kws)

    def wrap_constexpr(x):
        # Triton values pass through unchanged; any other Python value becomes a constexpr.
        if _is_triton_value(x):
            return x
        return constexpr(x)

    if isinstance(ret, (builtins.tuple, language.tuple)):
        return _apply_to_tuple_values(ret, wrap_constexpr)
    return wrap_constexpr(ret)
1380
+
1381
def call_Method(self, node, fn, fn_self, args, kws):
    """Call ``fn`` as a method of ``fn_self``.

    For a ``JITFunction`` callee, the receiver is inserted at the front of ``args`` (mutating
    the list in place) so that it binds to the first parameter. Any other callee is forwarded
    to ``call_Function`` unchanged.
    """
    if not isinstance(fn, JITFunction):
        return self.call_Function(node, fn, args, kws)
    args.insert(0, fn_self)
    return self.call_Function(node, fn, args, kws)
1385
+
1386
def visit_Call(self, node):
    """Lower a call expression.

    The callee is visited and constexpr-unwrapped. Callees listed in
    ``statically_implemented_functions`` are handled directly from the AST; a
    ``BoundJITMethod`` is never looked up there. A callee marked ``_must_use_result`` whose
    call appears as a bare expression statement is rejected. Otherwise, all keyword arguments
    are visited first, then all positional arguments. Positional arguments that evaluate to
    a ``list`` are spliced in, and the call is dispatched through ``call_Function``.

    Raises:
        CompilationError: if a must-use result is discarded.
    """
    callee = _unwrap_if_constexpr(self.visit(node.func))

    if not isinstance(callee, BoundJITMethod):
        static_impl = self.statically_implemented_functions.get(callee)
        if static_impl is not None:
            return static_impl(self, node)

    must_use = getattr(callee, '_must_use_result', False)
    if must_use and getattr(node, '_is_unused', False):
        parts = ["The result of %s is not being used." % ast.unparse(node.func)]
        if isinstance(must_use, str):
            # A string marker carries an extra explanation for the user.
            parts.append(must_use)
        raise CompilationError(self.jit_fn.src, node, " ".join(parts))

    kws = dict(map(self.visit, node.keywords))
    flat_args = []
    for visited in (self.visit(arg) for arg in node.args):
        if isinstance(visited, list):
            flat_args.extend(visited)
        else:
            flat_args.append(visited)

    return self.call_Function(node, callee, flat_args, kws)
1405
+
1406
def visit_Constant(self, node):
    """Wrap the Python literal carried by an ``ast.Constant`` node as a ``constexpr``."""
    literal = node.value
    return constexpr(literal)
1408
+
1409
def visit_BoolOp(self, node: ast.BoolOp):
    """Lower an ``and`` / ``or`` chain, short-circuiting on compile-time operands.

    Operands are visited left to right.

    - A non-tensor (compile-time) operand that decides the result is returned immediately,
      and the remaining operands are not visited. Under ``and`` this is a falsy value; under
      ``or`` it is a truthy value.
    - Any other compile-time operand is dropped, since it cannot affect the result.
    - The remaining tensor operands are folded pairwise, starting from the right, with
      ``logical_and`` / ``logical_or``.
    - If no tensor operand remains, the last visited operand is returned, as in Python.

    Using ``and`` / ``or`` on block (non-scalar) tensors emits a deprecation ``UserWarning``.

    Raises:
        The exception built by ``self._unsupported`` when the operator has no entry in
        ``_method_name_for_bool_op``.
    """
    method_name = self._method_name_for_bool_op.get(type(node.op))
    if method_name is None:
        # ``node.op`` is an operator *instance* (e.g. ``ast.And()``) and has no ``__name__``;
        # report the operator's class name instead of crashing with AttributeError.
        raise self._unsupported(
            node, "AST boolean operator '{}' is not (currently) implemented.".format(type(node.op).__name__))

    nontrivial_values = []

    for subnode in node.values:
        # we visit the values in order, executing their side-effects
        # and possibly early-exiting:
        value = self.visit(subnode)
        if not _is_triton_tensor(value):
            # this is a constexpr, so we might be able to short-circuit:
            bv = bool(value)
            if (bv is False) and (method_name == "logical_and"):
                # value is falsey so return that:
                return value
            if (bv is True) and (method_name == "logical_or"):
                # value is truthy so return that:
                return value
            # otherwise, our constexpr has no effect on the output of the
            # expression so we do not append it to nontrivial_values.
        else:
            if value.type.is_block():
                lineno = getattr(node, "lineno", None)
                if lineno is not None:
                    # Make the line number file-relative, as visit() does for IR locations.
                    lineno += self.begin_line
                warnings.warn_explicit(
                    "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
                    category=UserWarning,
                    filename=self.file_name,
                    lineno=lineno,
                    source=ast.unparse(node),
                )
            # not a constexpr so we must append it:
            nontrivial_values.append(value)

    if len(nontrivial_values) == 0:
        # the semantics of a disjunction of falsey values or conjunction
        # of truthy values is to return the final value:
        nontrivial_values.append(value)

    # Fold from the right: [a, b, c] -> a OP (b OP c).
    while len(nontrivial_values) >= 2:
        rhs = nontrivial_values.pop()
        lhs = nontrivial_values.pop()
        res = self._apply_binary_method(method_name, lhs, rhs)
        nontrivial_values.append(res)

    assert len(nontrivial_values) == 1
    return nontrivial_values[0]
1460
+
1461
# Maps each supported ``ast.boolop`` class to the name of the binary method that
# visit_BoolOp passes to _apply_binary_method to combine tensor operands.
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
1462
+
1463
def visit_Attribute(self, node):
    """Lower ``obj.attr``.

    - ``tensor.T`` is lowered to a ``(1, 0)`` permute.
    - On a ``constexpr`` receiver the attribute is read from the wrapped Python value.
      The exceptions are ``.value`` and ``.type``, which are read from the ``constexpr``
      itself for backward compatibility.
    - A ``JITFunction`` found on a Triton value is returned bound to that value as a
      ``BoundJITMethod``.
    """
    receiver = self.visit(node.value)
    attr_name = node.attr
    if _is_triton_tensor(receiver) and attr_name == "T":
        return self.semantic.permute(receiver, (1, 0))
    # NOTE: special case ".value" for BC
    if isinstance(receiver, constexpr) and attr_name not in ("value", "type"):
        receiver = receiver.value
    member = getattr(receiver, attr_name)
    if _is_triton_value(receiver) and isinstance(member, JITFunction):
        return BoundJITMethod(receiver, member)
    return member
1474
+
1475
def visit_Expr(self, node):
    """Visit an expression statement whose value is discarded.

    The value node is tagged ``_is_unused`` so that ``visit_Call`` can reject calls whose
    result must be used. The children are then visited generically.
    """
    discarded = node.value
    discarded._is_unused = True
    ast.NodeVisitor.generic_visit(self, node)
1478
+
1479
def visit_NoneType(self, node):
    """Visitor for ``NoneType`` nodes: there is nothing to lower, so produce ``None``."""
    return
1481
+
1482
def visit_JoinedStr(self, node):
    """Evaluate an f-string at compile time and return the resulting ``str``.

    Literal parts are kept verbatim. Each replacement field must evaluate to a ``constexpr``.
    Its value is passed through the field's conversion (``!s``, ``!r`` or ``!a``), if any,
    and then formatted with the field's format spec (e.g. ``:>8``). This matches what Python
    produces for the same values. Previously the format spec was silently ignored.

    Raises:
        The exception built by ``self._unsupported`` when a replacement field is not a
        ``constexpr``.
        AssertionError: on an unexpected child node type.
    """
    conversions = {"s": str, "r": repr, "a": ascii}
    values = list(node.values)
    for i, value in enumerate(values):
        if isinstance(value, ast.Constant):
            values[i] = str(value.value)
        elif isinstance(value, ast.FormattedValue):
            conversion_code = value.conversion
            evaluated = self.visit(value.value)
            if not _is_constexpr(evaluated):
                raise self._unsupported(
                    node,
                    "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
                    + str(type(evaluated)))
            # The format spec is itself a JoinedStr (it may contain nested fields); visiting
            # it evaluates it to a plain string. A missing spec is equivalent to "".
            format_spec = "" if value.format_spec is None else self.visit(value.format_spec)
            obj = evaluated.value
            if conversion_code >= 0:  # -1 means "no conversion"
                obj = conversions[chr(conversion_code)](obj)
            values[i] = format(obj, format_spec)
        else:
            raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
    return ''.join(values)
1499
+
1500
def visit(self, node):
    """Visit ``node``, tracking the current node and IR location, and normalizing errors.

    Returns ``None`` for a ``None`` node. Otherwise ``self.cur_node`` is set to ``node``.
    If the node has ``lineno`` / ``col_offset``, the builder location is set to that position:
    the line is offset by ``self.begin_line``, and the location is wrapped in a name location
    when ``self.name_loc_as_prefix`` is set. The node is then dispatched through
    ``ast.NodeVisitor.visit``.

    Raises:
        CompilationError: re-raised as-is. Any other exception is wrapped in a
            ``CompilationError`` at ``self.cur_node``, unless
            ``knobs.compilation.front_end_debugging`` is set, in which case it propagates
            unchanged.
    """
    if node is None:
        return
    with warnings.catch_warnings():
        # The ast library added visit_Constant and deprecated some other
        # methods but we can't move to that without breaking Python 3.6 and 3.7.
        warnings.simplefilter("ignore", DeprecationWarning)  # python 3.9
        warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
        last_node = self.cur_node
        last_loc = self.builder.get_loc()
        self.cur_node = node
        if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
            here_loc = self.builder.create_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
            if self.name_loc_as_prefix is not None:
                self.builder.set_loc(self.builder.create_name_loc(self.name_loc_as_prefix, here_loc))
            else:
                self.builder.set_loc(here_loc)
            # NOTE(review): this overwrites the pre-visit location captured above, so the
            # reset below restores *this* node's location, not the one in effect before the
            # visit. Confirm whether that is intended.
            last_loc = self.builder.get_loc()
        try:
            ret = super().visit(node)
        except CompilationError:
            raise
        except Exception as e:
            if knobs.compilation.front_end_debugging:
                raise
            # Wrap the error in a CompilationError which contains the source
            # of the @jit function.
            raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None

        # Reset the location to the last one before the visit
        # (see NOTE above). cur_node is restored only when last_loc is truthy.
        if last_loc:
            self.cur_node = last_node
            self.builder.set_loc(last_loc)
        return ret
1534
+
1535
def generic_visit(self, node):
    """Reject any AST node type that has no dedicated ``visit_*`` method."""
    node_kind = type(node).__name__
    raise self._unsupported(node, f"unsupported AST node type: {node_kind}")
1537
+
1538
def execute_static_assert(self, node: ast.Call) -> None:
    """Evaluate ``static_assert(cond[, msg])`` at compile time.

    The message is evaluated only when the assertion fails. If evaluating it raises, the
    error text is reported as the message rather than masking the assertion failure.

    Raises:
        TypeError: unless the call has exactly one or two positional arguments and no keywords.
        NotImplementedError: if ``cond`` does not evaluate to a Python ``bool``.
        CompileTimeAssertionFailure: if ``cond`` is ``False``.
    """
    n_args = len(node.args)
    if n_args not in (1, 2) or node.keywords:
        raise TypeError("`static_assert` requires one or two positional arguments only")

    condition = _unwrap_if_constexpr(self.visit(node.args[0]))
    if not isinstance(condition, bool):
        raise NotImplementedError(
            "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
        )
    if condition:
        return None

    message = ""
    if n_args == 2:
        try:
            message = self.visit(node.args[1])
        except Exception as e:
            message = "<failed to evaluate assertion message: " + repr(e) + ">"
    raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message))
1559
+
1560
def static_executor(python_fn):
    """Build a static call implementation that runs ``python_fn`` on compile-time values.

    The returned callable takes ``(generator, call_node)``. It visits the call's keyword
    arguments and then its positional arguments, unwrapping each from ``constexpr`` as it
    goes. It then calls ``python_fn`` in Python and wraps the result in a ``constexpr``.
    """

    def run_statically(self, node: ast.Call):
        keyword_values = {}
        for name, value in map(self.visit, node.keywords):
            keyword_values[name] = _unwrap_if_constexpr(value)
        positional = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args]
        return constexpr(python_fn(*positional, **keyword_values))

    return run_statically
1571
+
1572
# Imported inside the class body rather than at module top; presumably deferred to avoid a
# circular import with the Gluon language package (TODO confirm).
from ..experimental.gluon import language as ttgl

# Callees that are evaluated directly from the call's AST instead of being lowered to IR.
# visit_Call looks up the constexpr-unwrapped callee here; on a hit it returns
# ``impl(self, node)`` as the value of the call. The entries take (generator, call_node).
statically_implemented_functions: Dict[object, Callable[[Any, ast.Call], Any]] = {
    language.core.static_assert: execute_static_assert,
    language.core.static_print: static_executor(print),
    ttgl.static_assert: execute_static_assert,
    ttgl.static_print: static_executor(print),
    int: static_executor(int),
    len: static_executor(len),
}
1581
+
1582
+
1583
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
    """Generate and verify the IR module for kernel ``fn`` specialized as described by ``src``.

    Args:
        fn: the ``@jit`` kernel. Its ``arg_names``, ``parse()``, ``repr()``,
            ``get_capture_scope()`` and ``is_gluon()`` are used.
        src: the specialization. This function reads ``src.signature`` (argument name ->
            presumably a type string, passed to ``str_to_ty``), ``src.constants`` (argument
            path tuple -> constant value) and ``src.attrs``.
        context: IR context. Ownership is handed to the returned module.
        options, codegen_fns, module_map: forwarded to ``CodeGenerator``.
        module: optional existing module to generate into.

    Returns:
        The generated module.

    Raises:
        RuntimeError: if the generated module fails verification. The module is printed
            first for non-Gluon kernels.
    """
    arg_types = [None] * len(fn.arg_names)
    const_iter = iter(src.constants.items())
    kc, vc = next(const_iter, (None, None))

    for i, (ks, v) in enumerate(src.signature.items()):
        idx = fn.arg_names.index(ks)
        cexpr = None
        # NOTE(review): constants are consumed in iteration order and matched on the first path
        # element only. This assumes src.constants is ordered by argument position with at most
        # one entry per top-level argument; verify for tuple args with several constexpr leaves.
        if kc is not None and kc[0] == i:
            cexpr = vc
            kc, vc = next(const_iter, (None, None))
        arg_types[idx] = str_to_ty(v, cexpr)
    prototype = ASTFunction([], arg_types, src.constants, src.attrs)
    file_name, begin_line = get_jit_fn_file_line(fn)
    # query function representation
    from collections import namedtuple
    # Only single-element paths (top-level arguments) are exposed by name to fn.repr().
    leaves = filter(lambda v: len(v) == 1, src.constants)
    constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves}
    signature = src.signature
    proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
    generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
                              jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
                              codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
    generator.visit(fn.parse())
    module = generator.module
    # module takes ownership of the context
    module.context = context
    if not module.verify_with_diagnostics():
        if not fn.is_gluon():
            print(module)
        raise RuntimeError("error encountered during parsing")
    return module