triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217):
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
triton/runtime/jit.py ADDED
@@ -0,0 +1,1107 @@
1
+ from __future__ import annotations, division
2
+ import ast
3
+ import copy
4
+ import hashlib
5
+ import inspect
6
+ import itertools
7
+ import threading
8
+ import re
9
+ import textwrap
10
+ from collections import defaultdict
11
+ from dataclasses import dataclass
12
+ from functools import cached_property
13
+ from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
14
+
15
+ from triton.tools.tensor_descriptor import TensorDescriptor
16
+ from types import ModuleType
17
+ from .. import knobs
18
+ from .driver import driver
19
+ from . import _async_compile
20
+ from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, canonicalize_dtype
21
+ from .cache import get_cache_key
22
+ from triton._C.libtriton import get_cache_invalidating_env_vars
23
+
24
+ TRITON_MODULE = "triton.language"
25
+ GLUON_MODULE = "triton.experimental.gluon.language"
26
+
27
+ T = TypeVar("T")
28
+
29
+ # -----------------------------------------------------------------------------
30
+ # Dependencies Finder
31
+ # -----------------------------------------------------------------------------
32
+
33
+
34
class DependenciesFinder(ast.NodeVisitor):
    """
    This AST visitor is used to find dependencies of a JITFunction. This can
    be used to invalidate a JITFunction's hash when its source code -- or
    that of its dependencies -- changes.

    This visitor also keeps track of the global variables touched by the
    JITFunction. When we launch the kernel, we check that these have the same
    values as they did when we ran this visitor. If not, we raise an error (or
    otherwise we could recompile).
    """

    def __init__(self, name, globals, nonlocals, src) -> None:
        super().__init__()
        self.name = name
        # Running hash, seeded with the kernel's own source text; hashes of
        # referenced jitted functions are folded in later (see _update_hash).
        self.hasher = hashlib.sha256(src.encode("utf-8"))

        # This function's __globals__ dict.
        self.globals = globals
        self.nonlocals = nonlocals

        # Python builtins that can be accessed from Triton kernels.
        self.supported_python_builtins = {
            'float',
            'getattr',
            'int',
            'isinstance',
            'len',
            'list',
            'max',
            'min',
            'print',
            'range',
        }
        # Modules whose attribute accesses are not recorded as global-value
        # dependencies (see visit_Attribute).
        self.supported_modules = {
            GLUON_MODULE,
            TRITON_MODULE,
            "copy",
            "math",
        }

        # used_global_vals tells us which global variables are used by this
        # function and all those it transitively calls, plus the values of those
        # variables when each function was initially run. (That is, if A calls
        # C, and B calls C, then the values for C in used_global_vals will be
        # from the first time C was run, either by A or B.)
        #
        # Each function may have a different __globals__ dict, so the global
        # variable `foo` may actually have a different value in the different
        # functions. Thus this map is actually
        # (var_name, id(__globals__)) -> (var_value, __globals__).
        self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}

        # True only while visiting a parameter's default-value expression;
        # names seen there are deliberately not recorded (see visit_arguments).
        self.visiting_arg_default_value = False

    @property
    def ret(self):
        # Hex digest of everything hashed so far.
        return self.hasher.hexdigest()

    def _is_triton_builtin(self, node, func):
        """True if the called object is a C builtin or lives under triton.language."""
        if inspect.isbuiltin(node.func):
            return True
        module = getattr(func, "__module__", "")
        return module.startswith(TRITON_MODULE)

    def _update_hash(self, func):
        """Fold a referenced JITCallable's cache key and global values into ours."""
        assert isinstance(func, JITCallable)
        # Merge our used_global_vals with those of the called function,
        # after checking that all overlapping values are consistent.
        for k in self.used_global_vals.keys() & func.used_global_vals.keys():
            var_name, _ = k
            v1, _ = self.used_global_vals[k]
            v2, _ = func.used_global_vals[k]
            if v1 != v2:
                raise RuntimeError(
                    f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
                )
        self.used_global_vals.update(func.used_global_vals)
        # update hash
        func_key = func.cache_key
        func_key += str(getattr(func, "noinline", False))
        self.hasher.update(func_key.encode("utf-8"))

    def record_reference(self, val, var_dict=None, name=None):
        """Record a referenced value: hash jitted callees, snapshot plain globals."""
        from ..language.core import constexpr
        # Only keep track of "interesting" global variables, that non-evil users
        # might change. Don't consider functions, modules, builtins, etc. This
        # helps keep the list of vars we have to check small.
        if val is None or type(val) is ModuleType:
            return

        if getattr(val, "__triton_builtin__", False):
            return

        # Stubs that aren't real functions
        if getattr(val, "__module__", "") == "triton.language.extra.libdevice":
            return

        if isinstance(val, JITCallable):
            self._update_hash(val)
            return

        if callable(val) and not isinstance(val, type) and not isinstance(val, constexpr):
            raise RuntimeError(f"Unsupported function referenced: {val}")

        # Python default arguments are resolved only once, when the
        # function is defined. So if you do `foo(a=A)` and the value of
        # A changes, foo will still use the old value of A.
        # It would be pretty evil if someone did `import x` and then
        # `x = blah`.
        if self.visiting_arg_default_value:
            return

        if var_dict is not None:
            # Deep-copy so a later in-place mutation of the global is detected.
            self.used_global_vals[(name, id(var_dict))] = (copy.deepcopy(val), var_dict)
        return

    def visit_Name(self, node):
        if type(node.ctx) is ast.Store:
            # Assignment target: return the name so visitAssnTarget can
            # register it as a local.
            return node.id

        if node.id in self.local_names:
            # The global name is hidden by the local name.
            return None

        def name_lookup(name):
            # Look in __globals__ first, then the closure's nonlocals.
            val = self.globals.get(name, None)
            if val is not None:
                return val, self.globals
            val = self.nonlocals.get(name, None)
            if val is not None:
                return val, self.nonlocals
            return None, None

        val, var_dict = name_lookup(node.id)
        if node.id in self.supported_python_builtins:
            # Allowed builtins are usable but never tracked as dependencies.
            return val

        self.record_reference(val, var_dict, node.id)
        return val

    def visit_Tuple(self, node):
        # We need to explicitly return the tuple values so that visit_Assign can
        # access them in the case of `a, b = ...`.
        return [self.visit(elt) for elt in node.elts]

    def visit_Attribute(self, node):
        lhs = self.visit(node.value)
        while isinstance(lhs, ast.Attribute):
            lhs = self.visit(lhs.value)
        lhs_name = getattr(lhs, "__name__", "")
        if lhs is None or lhs_name in self.supported_modules:
            # Unresolvable base, or an attribute of an allowed module: skip.
            return None
        ret = getattr(lhs, node.attr)
        self.record_reference(ret)
        return ret

    def visit_FunctionDef(self, node):
        # Save the local name, which may hide the global name.
        self.local_names = {arg.arg for arg in node.args.args}
        self.generic_visit(node)

    def visit_arguments(self, node):
        # The purpose of this function is to visit everything in `arguments`
        # just like `generic_visit`, except when we're visiting default values
        # (i.e. the `foo` part of `def fn(x = foo)`), we set
        # self.visiting_arg_default_value = True. This allows visit_Name to be
        # aware that we're inside function default values, which have special
        # semantics.

        # According to the AST docs, the arguments node has the following structure.
        #
        # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs,
        #              expr* kw_defaults, arg? kwarg, expr* defaults)
        def visit_defaults(defaults):
            try:
                assert not self.visiting_arg_default_value
                self.visiting_arg_default_value = True
                for expr in defaults:
                    if expr is not None:
                        self.visit(expr)
            finally:
                self.visiting_arg_default_value = False

        for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs):
            self.visit(arg)

        visit_defaults(node.kw_defaults)

        if node.kwarg is not None:
            self.visit(node.kwarg)

        visit_defaults(node.defaults)

    def visitAssnTarget(self, node):
        # Target is either a single string, or a list of strings (if the assn
        # target is a tuple).
        target = self.visit(node)
        if isinstance(target, list):
            self.local_names |= set(target)
        else:
            self.local_names.add(target)

    def visit_Assign(self, node):
        if len(node.targets) != 1:
            # TODO(jlebar): I don't actually know how to hit this. You don't
            # get it from `a, b = ...` -- in that case, node.targets is a single
            # Tuple, and in fact we *do* need to handle that case if we want
            # existing code to work.
            raise TypeError("Simultaneous multiple assignment is not supported.")

        self.visitAssnTarget(node.targets[0])

        # This will re-visit the target, but that's OK.
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        self.visitAssnTarget(node.target)

        # This will re-visit the target, but that's OK.
        self.generic_visit(node)

    def visit_For(self, node):
        self.visitAssnTarget(node.target)

        # This will re-visit the target, but that's fine.
        self.generic_visit(node)
261
+
262
+
263
+ # -----------------------------------------------------------------------------
264
+ # JITFunction
265
+ # -----------------------------------------------------------------------------
266
+
267
+
268
def _normalize_ty(ty) -> str:
    """Canonicalise a parameter annotation into Triton's short type spelling.

    Accepts annotation strings ("const fp32*", "tl.int32", "*i8"), Triton
    dtype / pointer_type objects, or plain Python types, and returns the
    canonical form ("*" prefix for pointers, "*k" for const pointers).
    """
    import triton.language.core as core
    if isinstance(ty, str):
        ty = ty.strip()
        if ty.startswith("const "):
            # Const pointer: canonicalise the pointee, then mark it "*k".
            inner = _normalize_ty(ty.removeprefix("const"))
            assert inner.startswith("*")
            return "*k" + inner[1:]
        if ty.endswith("*"):
            return "*" + _normalize_ty(ty[:-1])
        if ty.startswith("*"):
            return "*" + _normalize_ty(ty[1:])
        if ty.startswith("tl."):
            # Strip the "tl." namespace and canonicalise the remainder.
            return _normalize_ty(ty.removeprefix("tl."))
    elif isinstance(ty, core.pointer_type):
        return f"*{_normalize_ty(ty.element_ty)}"
    elif isinstance(ty, core.dtype):
        ty = ty.name
    elif isinstance(ty, type):
        ty = ty.__name__
    else:
        ty = str(ty)
    # Drop any "_t" suffix (e.g. "int32_t") before the canonical lookup.
    return type_canonicalisation_dict.get(ty.replace("_t", ""), ty)
292
+
293
+
294
class KernelParam:
    """Represents a parameter (name plus metadata) to a @jit'ed function."""

    def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool,
                 do_not_specialize_on_alignment: bool):
        self.num = num
        self._param = param
        self.do_not_specialize = do_not_specialize
        self.do_not_specialize_on_alignment = do_not_specialize_on_alignment

    @cached_property
    def name(self):
        """Parameter name as declared in the kernel signature."""
        return self._param.name

    @cached_property
    def annotation(self) -> str:
        """Canonicalised type annotation, or "" when the parameter has none."""
        raw = self._param.annotation
        if not raw or raw == inspect.Parameter.empty:
            return ""
        return _normalize_ty(raw)

    @cached_property
    def annotation_type(self) -> str:
        """The annotation when it names a known canonical type, else ""."""
        base = self.annotation
        # Strip the const-pointer / pointer prefix before the lookup.
        if base.startswith("*k"):
            base = base[2:]
        elif base.startswith("*"):
            base = base[1:]
        known = set(type_canonicalisation_dict.values())
        return self.annotation if base in known else ""

    @cached_property
    def is_constexpr(self):
        """True when the parameter is annotated as a compile-time constant."""
        return "constexpr" in self.annotation

    @cached_property
    def is_const(self):
        """True for const (read-only) pointer parameters; constexprs excluded."""
        if self.is_constexpr:
            return False
        return "const" in self.annotation or self.annotation.startswith("*k")

    @property
    def default(self):
        """The declared default value (inspect.Parameter.empty when absent)."""
        return self._param.default

    @property
    def has_default(self):
        """Whether the parameter declares a default value."""
        return self._param.default != inspect.Parameter.empty
342
+
343
+
344
# Memoises (dtype, is_const) -> pointer type string inside specialize_impl;
# dtypes are hashable, so this lookup is safe to cache process-wide.
dtype2str = {}
# One-element cache holding the lazily built default specialize_impl used by
# mangle_type; a list is used so it can be filled in place on first use.
specialize_impl_cache = []
346
+
347
+
348
def create_specialize_impl(specialize_extra):
    """Build the per-argument specialisation function used at launch time.

    ``specialize_extra`` is a backend hook (see ``backend.get_arg_specialization``
    in create_function_from_signature) producing an extra cache-key component
    for int and tensor arguments. The returned closure maps one kernel argument
    to a ``(type_string, specialization_key)`` pair.
    """

    from ..language import constexpr
    from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor

    def specialize_impl(arg, is_const=False, specialize_value=True, align=True):
        if arg is None:
            return ("constexpr", None)
        elif isinstance(arg, bool):
            # Checked before int: bool is a subclass of int in Python.
            return ("u1", None)
        elif isinstance(arg, int):
            key = specialize_extra(arg, "int", align=align) if specialize_value else None
            if arg == 1 and specialize_value:
                # 1 is specialised to a compile-time constant (common for strides).
                return ("constexpr", 1)
            elif -(2**31) <= arg and arg <= 2**31 - 1:
                return ("i32", key)
            elif 2**63 <= arg and arg <= 2**64 - 1:
                # Only representable as unsigned 64-bit.
                return ("u64", key)
            else:
                return ("i64", key)
        elif isinstance(arg, float):
            return ("fp32", None)
        elif hasattr(arg, "data_ptr"):
            # Tensor-like argument, duck-typed on data_ptr.
            # dtypes are hashable so we can memoize this mapping:
            dsk = (arg.dtype, is_const)
            res = dtype2str.get(dsk, None)
            if res is None:
                res = ("*k" if dsk[1] else "*") + canonicalize_dtype(dsk[0])
                dtype2str[dsk] = res
            key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
            return (res, key)
        elif isinstance(arg, JITCallable):
            # A jitted function passed as an argument: keyed by its cache key.
            return ("constexpr", arg.cache_key)
        elif isinstance(arg, constexpr):
            return ("constexpr", arg)
        elif isinstance(arg, tuple):
            # Recurse elementwise; preserve namedtuple types (detected via _fields).
            spec = [specialize_impl(x) for x in arg]
            make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
            tys = make_tuple([x[0] for x in spec])
            keys = make_tuple([x[1] for x in spec])
            return (tys, keys)
        elif isinstance(arg, TensorDescriptor):
            assert hasattr(arg.base, "data_ptr")
            inner = canonicalize_dtype(arg.base.dtype)
            return (f"tensordesc<{inner}{list(arg.block_shape)}>", None)
        elif isinstance(arg, GluonTensorDescriptor):
            # Gluon descriptors additionally encode their layout in the type.
            assert hasattr(arg.base, "data_ptr")
            inner = canonicalize_dtype(arg.base.dtype)
            return (f"tensordesc<{inner}{list(arg.block_shape)},{arg.layout!r}>", None)
        else:
            raise TypeError("Unsupported type: %s" % type(arg))

    return specialize_impl
401
+
402
+
403
def mangle_type(arg, specialize=False):
    """Return the mangled type string for *arg* (the first element of the
    (type, key) pair produced by specialize_impl)."""
    if not specialize_impl_cache:
        # Lazily build one shared impl whose extra-key hook always returns None.
        specialize_impl_cache.append(create_specialize_impl(lambda _, **kwargs: None))
    impl = specialize_impl_cache[0]
    ty, _ = impl(arg, specialize_value=specialize)
    return ty
408
+
409
+
410
class KernelInterface(Generic[T]):
    run: T

    def __getitem__(self, grid) -> T:
        """
        A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        """

        def launcher(*args, **kwargs):
            return self.run(grid=grid, warmup=False, *args, **kwargs)

        return launcher
421
+
422
+
423
def serialize_specialization_data(name, signature, constants, attrs, options, key):
    """Serialise a kernel's specialisation data to a JSON string.

    Tuple keys of `constants` and `attrs` are flattened to lists (JSON has no
    tuple keys); dtype constant values are stringified.
    """
    import json

    def jsonable(value):
        # dtypes are not JSON-serialisable; represent them by their str() form.
        return str(value) if value.__class__.__name__ == "dtype" else value

    constants = {k: jsonable(v) for k, v in constants.items()}
    payload = {
        'name': name,
        'signature': signature,
        'constant_keys': [list(k) for k in constants.keys()],
        'constant_vals': list(constants.values()),
        'attrs_keys': [list(k) for k in attrs.keys()],
        'attrs_vals': list(attrs.values()),
        'options': options.__dict__,
        'key': key,
    }
    return json.dumps(payload)
433
+
434
+
435
def create_function_from_signature(sig, kparams, backend):
    """
    Equivalent to sig.bind followed by apply_defaults. This generates a
    native Python function (using exec) which can be memoized on a per-kernel
    basis to avoid having to run these expensive functions -- which constitute
    much of the kernel launch overhead -- every time we run the kernel.

    Returns dynamic_func(*args, **options) -> (params, specialization, options)
    where `params` maps parameter names to values and `specialization` is the
    list of (type, key) pairs produced by specialize_impl.
    """
    assert len(sig.parameters) == len(kparams)
    # Create the function argument list and the dict entries for the return statement
    specialization = []
    # signature
    for name, kp in zip(sig.parameters.keys(), kparams):
        if kp.is_constexpr:
            specialization.append(f'("constexpr", {name})')
        else:
            # These three are embedded as the literals 'True'/'False' in the
            # generated source, so they are strings here on purpose.
            is_const = 'True' if kp.is_const else 'False'
            specialize = 'False' if kp.do_not_specialize else 'True'
            align = 'False' if kp.do_not_specialize_on_alignment else 'True'
            ret = f"specialize_impl({name}, {is_const}, {specialize}, {align})"
            if kp.annotation_type:
                if isinstance(kp.annotation_type, str):
                    if kp.annotation_type == "u1" or kp.annotation_type[:2] in ["fp", "bf"]:
                        # we do not specialize non-constexpr floats and bools:
                        specialize = False
                # NOTE(review): unless rebound to boolean False just above,
                # `specialize` is the string 'True'/'False', and both strings are
                # truthy -- so only the float/bool case takes the else branch.
                # The result still looks correct because the textual 'False' is
                # passed through to specialize_impl in `ret`; verify intent.
                if specialize:
                    # Annotation pins the type; keep only the runtime key from ret.
                    specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
                else:
                    # skip runtime specialization:
                    specialization.append(f'("{kp.annotation_type}", None)')
            else:
                specialization.append(f"{ret}")

    # compute argument string for a given parameter
    arg = lambda x: x[0] if x[1].default is inspect.Parameter.empty else f"{x[0]}=default_{x[0]}"
    # Join all arguments into a function definition string
    func_body = f"""
def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options"])}):
    params = {{{', '.join([f"'{name}': {name}" for name in sig.parameters.keys()])}}}
    specialization = [{','.join(specialization)}]
    return params, specialization, options
"""
    # Prepare defaults to be inserted into function namespace
    func_namespace = {
        f"default_{name}": param.default
        for name, param in sig.parameters.items()
        if param.default is not inspect.Parameter.empty
    }

    # Names the generated source refers to at call time.
    func_namespace["JITCallable"] = JITCallable
    func_namespace["specialize_impl"] = create_specialize_impl(backend.get_arg_specialization)

    # Execute the function string in func_namespace to create the function
    exec(func_body, func_namespace)

    # Extract the newly created function from the namespace
    return func_namespace['dynamic_func']
491
+
492
+
493
def get_full_name(fn):
    """Fully-qualified name of *fn*, e.g. "pkg.mod.Class.method"."""
    return ".".join((fn.__module__, fn.__qualname__))
495
+
496
+
497
class JITCallable:
    """Wrapper around a Python function whose source participates in Triton's
    compilation cache.

    Captures the function's source (minus decorators), lazily computes a
    content hash over it and all transitively referenced jitted functions,
    and records the global values the function depends on.
    """

    def __init__(self, fn):
        self.fn = fn
        self.signature = inspect.signature(fn)
        try:
            self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
        except OSError as e:
            raise ValueError("@jit functions should be defined in a Python file") from e
        self._fn_name = get_full_name(fn)
        # Guards lazy hash computation in cache_key against concurrent callers.
        self._hash_lock = threading.RLock()

        # function source code (without decorators)
        src = textwrap.dedent("".join(self.raw_src))
        src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
        self._src = src
        # Lazily computed by `cache_key`; reset whenever src is replaced.
        self.hash = None

        # Map of global variables used by the function and any functions it
        # transitively calls, plus their values. The values are collected when
        # the function is first compiled. Then every time we run the function,
        # we check that the values of the globals match what's expected,
        # otherwise we raise an error.
        #
        # Different functions can have different __globals__ maps, so the map
        # key is actually (var name, id(__globals__)), and the map value is
        # (value, __globals__).
        self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}

        # reuse docs of wrapped function
        self.__doc__ = fn.__doc__
        self.__name__ = fn.__name__
        self.__qualname__ = fn.__qualname__
        self.__globals__ = fn.__globals__
        self.__module__ = fn.__module__

    def get_capture_scope(self):
        """Names visible to the function: its globals merged with closure nonlocals."""
        return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals

    @property
    def cache_key(self):
        """Content hash of this function, its dependencies, and constexpr globals."""
        # TODO : hash should be attribute of `self`
        with self._hash_lock:
            if self.hash is not None:
                return self.hash
            # Set a placeholder hash to break recursion in case the function
            # transitively calls itself. The full hash is set after.
            self.hash = f"recursion:{self._fn_name}"
            nonlocals = inspect.getclosurevars(self.fn).nonlocals
            dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals,
                                                     src=self.src)
            dependencies_finder.visit(self.parse())
            # Include the starting line so moving the function invalidates the cache.
            self.hash = dependencies_finder.ret + str(self.starting_line_number)
            self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))

            # Fold referenced constexpr globals into the hash so changing
            # their values invalidates the cache.
            from triton.language.core import constexpr
            self.hash += str([(name, val)
                              for (name, _), (val, _) in self.used_global_vals.items()
                              if isinstance(val, constexpr)])
            self.hash = hashlib.sha256(self.hash.encode("utf-8")).hexdigest()
            return self.hash

    # we do not parse `src` in the constructor because
    # the user might want to monkey-patch self.src dynamically.
    # Our unit tests do this, for example.
    def parse(self):
        """Parse `src` and return the module AST; it must hold exactly one function."""
        tree = ast.parse(self._src)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    @property
    def type(self):
        from triton.language.core import constexpr_type
        return constexpr_type(self)

    def _unsafe_update_src(self, new_src):
        """
        The only method allowed to modify src (the `src` property rejects writes).

        Note that it is the callers responsibility to make sure any triton functions that call this function have the `.hash` value reset to None.
        """
        self.hash = None
        self._src = new_src

    def _set_src(self, value):
        # BUGFIX: a property's fset is invoked as fset(obj, value); the previous
        # one-argument signature made `obj.src = ...` raise TypeError (wrong
        # arity) instead of the intended AttributeError below.
        raise AttributeError("Cannot set attribute 'src' directly. "
                             "Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
                             "instead.")

    def _get_src(self):
        return self._src

    src = property(fget=_get_src, fset=_set_src)
593
+
594
+
595
@dataclass
class JitFunctionInfo:
    """Lightweight record describing a jitted function, handed to JIT compile hooks."""
    # NOTE(review): _call_hook constructs this as JitFunctionInfo(module, name, self)
    # with module = fn.__module__, which is a str despite this annotation — confirm.
    module: ModuleType
    # Qualified name of the wrapped Python function.
    name: str
    # The JITFunction instance being compiled.
    jit_function: JITFunction
600
+
601
+
602
def compute_cache_key(kernel_key_cache, specialization, options):
    """Return the in-process cache key for a (specialization, options) pair.

    The string key is built from the stringified specialization and options;
    results are memoized in ``kernel_key_cache`` so repeated launches with the
    same binding skip the string construction.
    """
    memo_key = (tuple(specialization), str(options))
    cached = kernel_key_cache.get(memo_key)
    if cached is None:
        cached = str(specialization) + str(options)
        kernel_key_cache[memo_key] = cached
    return cached
611
+
612
+
613
class JITFunction(JITCallable, KernelInterface[T]):
    """A kernel compiled just-in-time from a Python function.

    Created by the ``@triton.jit`` decorator. Launches go through
    :meth:`run` (reached via ``KernelInterface``'s grid subscription), which
    binds arguments, looks up or compiles a per-device kernel, and launches it.
    """

    def is_gluon(self):
        # Plain @jit kernels are not Gluon kernels (Gluon has its own front end).
        return False

    def _call_hook(
        self,
        hook,
        key,
        signature,
        device,
        constants,
        options,
        configs,
        is_warmup,
    ) -> bool | None:
        """Invoke a compilation hook, if one is installed.

        Returns the hook's result (a truthy value makes ``_do_compile`` skip
        compilation), or ``None`` when ``hook`` is unset.
        """
        if not hook:
            return None

        name = self.fn.__qualname__
        module = self.fn.__module__
        # Human-readable "name: type" list for the hook's repr string.
        arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
        full_name = get_full_name(self.fn)

        # Serialized form that `preload` can later use to recompile this kernel.
        specialization_data = serialize_specialization_data(full_name, signature, constants, configs[0], options, key)

        kwargs = {
            'signature': signature,
            'device': device,
            'constants': constants,
            'num_warps': options.num_warps,
            'num_ctas': options.num_ctas,
            'num_stages': options.num_stages,
            'enable_fp_fusion': options.enable_fp_fusion,
            'launch_cooperative_grid': options.launch_cooperative_grid,
            'extern_libs': options.extern_libs,
            'configs': configs,
            'specialization_data': specialization_data,
            'is_warmup': is_warmup,
        }

        return hook(
            key=key,
            repr=repr,
            fn=JitFunctionInfo(module, name, self),
            compile={"key": key, **kwargs},
            is_manual_warmup=is_warmup,
            already_compiled=False,
        )

    def add_pre_run_hook(self, hook):
        '''
        Add a hook that will be executed prior to the execution of run
        function with args and kwargs passed into the kernel
        '''
        assert callable(hook)
        self.pre_run_hooks.append(hook)

    def create_binder(self):
        """
        Precompute as much as possible.

        Returns the per-device cache entry used by :meth:`run`:
        ``(kernel_cache, kernel_key_cache, target, backend, binder)``.
        """
        from ..compiler import CompiledKernel, compile, ASTSource, make_backend
        target = driver.active.get_current_target()
        backend = make_backend(target)
        self.CompiledKernel = CompiledKernel
        self.compile = compile
        self.ASTSource = ASTSource
        binder = create_function_from_signature(self.signature, self.params, backend)
        return {}, {}, target, backend, binder

    def _pack_args(self, backend, kwargs, bound_args, specialization, options):
        """Derive ``(options, signature, constexprs, attrs)`` from the bound
        arguments and their specialization; validates keyword arguments.

        :raises KeyError: on a keyword that is neither a backend option nor a
            kernel parameter name
        """
        # options
        options = backend.parse_options(kwargs)
        # signature
        sigkeys = [x.name for x in self.params]
        sigvals = [x[0] for x in specialization]
        signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
        # check arguments
        assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
        assert "device" not in kwargs, "device option is deprecated; current device will be used"
        assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
        for k in kwargs:
            if k not in options.__dict__ and k not in sigkeys:
                raise KeyError("Keyword argument %s was specified but unrecognised" % k)
        # constexprs
        constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
        constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
        # attributes
        attrvals = [x[1] for x in specialization]
        attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
        attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}

        return options, signature, constexprs, attrs

    def run(self, *args, grid, warmup, **kwargs):
        """Bind arguments, compile if not cached, and (unless ``warmup``)
        launch the kernel on the current device/stream.

        Returns the compiled kernel, or ``None`` when the jit_cache_hook
        short-circuited compilation.
        """
        kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug

        # parse options
        device = driver.active.get_current_device()
        stream = driver.active.get_current_stream(device)

        # Execute pre run hooks with args and kwargs
        for hook in self.pre_run_hooks:
            hook(*args, **kwargs)

        kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
        # specialization is list[tuple[str, Any]], where first element of tuple is
        # the type and the second parameter is the 'specialization' value.
        bound_args, specialization, options = binder(*args, **kwargs)

        key = compute_cache_key(kernel_key_cache, specialization, options)
        kernel = kernel_cache.get(key, None)

        # Kernel is not cached; we have to compile.
        if kernel is None:
            options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization,
                                                                    options)

            kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
            if kernel is None:
                # The jit_cache_hook asked us to skip compilation.
                return None

        # Check that used global values have not changed.
        # This runs on every call, including cache hits.
        not_present = object()
        for (name, _), (val, globals_dict) in self.used_global_vals.items():
            if (newVal := globals_dict.get(name, not_present)) != val:
                raise RuntimeError(
                    f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}")

        if not warmup:
            # canonicalize grid
            assert grid is not None
            if callable(grid):
                grid = grid(bound_args)
            grid_size = len(grid)
            grid_0 = grid[0]
            grid_1 = grid[1] if grid_size > 1 else 1
            grid_2 = grid[2] if grid_size > 2 else 1
            if hasattr(kernel, "result"):
                # Async compilation produced a future; block until ready.
                kernel = kernel.result()
            # launch kernel
            launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
            kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
                       knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values())
        return kernel

    def repr(self, _):
        # A user-supplied repr callable (passed to @jit) wins over the plain name.
        return self._fn_name if self._repr is None else self._repr(_)

    def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
                 noinline=None, repr=None, launch_metadata=None):
        """Wrap ``fn`` as a JIT-compiled kernel.

        :param do_not_specialize: parameters (by index or name) to exclude
            from value specialization
        :param do_not_specialize_on_alignment: parameters (by index or name)
            to exclude from alignment specialization
        """
        do_not_specialize = do_not_specialize if do_not_specialize else []
        do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else []

        super().__init__(fn)
        self.module = fn.__module__
        self.version = version
        self.do_not_specialize = do_not_specialize
        self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
        self._repr = repr
        self.launch_metadata = launch_metadata

        self.params = []
        for i, param in enumerate(self.signature.parameters.values()):
            # Parameters may be opted out of specialization by index or by name.
            dns = i in do_not_specialize or param.name in do_not_specialize
            dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment
            self.params.append(KernelParam(i, param, dns, dns_oa))

        # cache of just-in-time compiled kernels
        self.device_caches = defaultdict(self.create_binder)

        # JITFunction can be instantiated as kernel
        # when called with a grid using __getitem__
        self.kernel = None
        self.debug = debug
        self.noinline = noinline

        # TODO(jlebar): Remove uses of these fields outside this file, then
        # remove the fields here.
        self.arg_names = [p.name for p in self.params]
        self.constexprs = [p.num for p in self.params if p.is_constexpr]

        # Hooks that will be called prior to executing "run"
        self.pre_run_hooks = []

    def warmup(self, *args, grid, **kwargs):
        """Compile for the given arguments without launching; torch dtypes in
        ``args`` are wrapped as MockTensors so no real memory is needed."""
        return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)

    def preload(self, specialization_data):
        """Recompile a kernel from serialized specialization data (as produced
        for compile hooks in ``_call_hook``) without launching it.

        :raises RuntimeError: if the data belongs to a different function
        """
        import json
        import triton.language as tl
        device = driver.active.get_current_device()
        deserialized_obj = json.loads(specialization_data)
        if deserialized_obj['name'] != self._fn_name:
            raise RuntimeError(
                f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}")
        constant_keys = map(tuple, deserialized_obj['constant_keys'])
        constant_vals = deserialized_obj['constant_vals']
        constexprs = {
            key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
            for key, value in zip(constant_keys, constant_vals)
        }
        attrs_keys = map(tuple, deserialized_obj['attrs_keys'])
        attrs_vals = deserialized_obj['attrs_vals']
        attrs = dict(zip(attrs_keys, attrs_vals))
        signature = dict(deserialized_obj['signature'].items())
        # JSON turns tuples into lists; restore tuple-valued options.
        options = {
            key: tuple(value) if isinstance(value, list) else value
            for key, value in deserialized_obj['options'].items()
        }
        key = deserialized_obj['key']
        _, _, _, backend, _ = self.device_caches[device]
        options = backend.parse_options(options)
        return self._do_compile(
            key,
            signature,
            device,
            constexprs,
            options,
            attrs,
            warmup=True,
        )

    def _do_compile(self, key, signature, device, constexprs, options, attrs, warmup):
        """Compile the kernel (synchronously or via the active async mode) and
        store it in the per-device kernel cache.

        Returns ``None`` when the jit_cache_hook returns truthy (compilation
        skipped); otherwise the compiled kernel or an async future-like object.
        """
        kernel_cache, _, target, backend, _ = self.device_caches[device]

        if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs], warmup):
            return None
        src = self.ASTSource(self, signature, constexprs, attrs)

        async_mode = _async_compile.active_mode.get()
        if async_mode is not None:
            # Hand compilation to the async mode; the cache entry and the
            # post-compile hook fire in finalize_compile once the kernel is ready.
            env_vars = get_cache_invalidating_env_vars()
            cache_key = get_cache_key(src, backend, options, env_vars)

            def async_compile():
                return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars)

            def finalize_compile(kernel):
                kernel_cache[key] = kernel
                self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options,
                                [attrs], warmup)

            kernel = async_mode.submit(cache_key, async_compile, finalize_compile)
        else:
            kernel = self.compile(src, target=target, options=options.__dict__)
            kernel_cache[key] = kernel
            self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs],
                            warmup)
        return kernel

    def __call__(self, *args, **kwargs):
        # Kernels are launched as fn[grid](...); direct calls are only legal
        # inside another @jit function, which the code generator handles.
        raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

    def __repr__(self):
        return f"JITFunction({self.module}:{self.fn.__qualname__})"
872
+
873
+
874
+ # -----------------------------------------------------------------------------
875
+ # `jit` decorator
876
+ # -----------------------------------------------------------------------------
877
+
878
+
879
@overload
def jit(fn: T) -> JITFunction[T]:
    """Overload: bare usage, ``@jit`` applied directly to the function."""
    ...
882
+
883
+
884
@overload
def jit(
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[int | str]] = None,
    do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
    """Overload: parameterized usage, ``@jit(...)`` returns a decorator."""
    ...
896
+
897
+
898
def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[int | str]] = None,
    do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

        * python primitives,
        * builtins within the triton package,
        * arguments to this function,
        * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        # Both wrapper classes take the same keyword arguments.
        wrapper_kwargs = dict(
            version=version,
            do_not_specialize=do_not_specialize,
            do_not_specialize_on_alignment=do_not_specialize_on_alignment,
            debug=debug,
            noinline=noinline,
            repr=repr,
            launch_metadata=launch_metadata,
        )
        if knobs.runtime.interpret:
            # Interpreter mode emulates the kernel on CPU instead of compiling.
            from .interpreter import InterpretedFunction
            return InterpretedFunction(fn, **wrapper_kwargs)
        return JITFunction(fn, **wrapper_kwargs)

    # Support both bare `@jit` and parameterized `@jit(...)` usage.
    return decorator if fn is None else decorator(fn)
951
+
952
+
953
+ # -----------------------------------------------------------------------------
954
+ # Utilities for mocking tensors
955
+ # -----------------------------------------------------------------------------
956
+
957
+
958
class MockTensor:
    """
    Can be used in place of real tensors when calling:
        kernel.warmup(MockTensor(torch.float32), ...)
    """

    @staticmethod
    def wrap_dtype(arg):
        # Promote a bare torch dtype to a MockTensor; pass anything else through.
        looks_like_torch_dtype = arg.__class__.__name__ == "dtype" and arg.__module__ == "torch"
        return MockTensor(arg) if looks_like_torch_dtype else arg

    def __init__(self, dtype, shape=None):
        self.dtype = dtype
        self.shape = [1] if shape is None else shape

    def stride(self):
        # Build cumulative products of shape[1:], then reverse the list so the
        # first entry carries the largest stride (matches the historical mock).
        acc = 1
        forward = [acc]
        for extent in self.shape[1:]:
            acc *= extent
            forward.append(acc)
        return tuple(forward[::-1])

    @staticmethod
    def data_ptr():
        return 0  # optimistically assumes multiple of 16

    @staticmethod
    def ptr_range():
        return 0  # optimistically assumes 32 bit pointer range
989
+
990
+
991
class TensorWrapper:
    """Presents ``base`` as a tensor of a different ``dtype``, delegating
    storage, strides and device placement to the wrapped tensor."""

    def __init__(self, base, dtype):
        self.base = base
        self.dtype = dtype
        # Mirror the attributes callers read directly off a tensor.
        self.data = base.data
        self.device = base.device
        self.shape = base.shape

    def data_ptr(self):
        return self.base.data_ptr()

    def stride(self, *args):
        return self.base.stride(*args)

    def __str__(self) -> str:
        return "TensorWrapper[{}]({})".format(self.dtype, self.base)

    def element_size(self):
        return self.base.element_size()

    def cpu(self):
        # Re-wrap so the dtype reinterpretation survives the device move.
        return TensorWrapper(self.base.cpu(), self.dtype)

    def copy_(self, other):
        self.base.copy_(other.base)

    def clone(self):
        return TensorWrapper(self.base.clone(), self.dtype)

    def to(self, device):
        return TensorWrapper(self.base.to(device), self.dtype)

    def new_empty(self, sizes):
        return TensorWrapper(self.base.new_empty(sizes), self.dtype)
1026
+
1027
+
1028
def reinterpret(tensor, dtype):
    """View ``tensor`` as ``dtype``: unwrap when reverting to the base dtype,
    otherwise (re)wrap in a TensorWrapper.

    :raises TypeError: if ``tensor`` is neither a TensorWrapper nor tensor-like
        (i.e. it has no ``data_ptr`` method)
    """
    if isinstance(tensor, TensorWrapper):
        if dtype == tensor.base.dtype:
            # Reinterpreting to the original interpretation; return the base.
            return tensor.base
        # Reinterpreting a wrapped tensor to a different type.
        return TensorWrapper(tensor.base, dtype)
    if hasattr(tensor, "data_ptr"):
        # A new wrapper is needed around an unwrapped tensor.
        return TensorWrapper(tensor, dtype)
    raise TypeError(f"Cannot reinterpret a {type(tensor)}.")
1041
+
1042
+
1043
def get_jit_fn_file_line(fn):
    """Return (filename, line number of the `def` statement) for a jitted fn,
    unwrapping nested wrappers until the underlying JITCallable is reached."""
    inner = fn
    while not isinstance(inner, JITCallable):
        inner = inner.fn
    file_name = inner.fn.__code__.co_filename
    # starting_line_number points at the outermost decorator; scan forward to
    # the actual `def` line. Match the following pattern:
    # @triton.autotune(...) <- foo.__code__.co_firstlineno
    # @triton.heuristics(...)
    # @triton.jit
    # def foo(...): <- this line is the first line
    def_offset = next((idx for idx, text in enumerate(inner.raw_src) if text.strip().startswith("def ")), 0)
    return file_name, inner.starting_line_number + def_offset
1059
+
1060
+
1061
class BoundConstexprFunction(JITCallable):
    """A ConstexprFunction bound to an instance, analogous to a Python bound method."""

    def __init__(self, instance, fn):
        # NOTE(review): deliberately does not call JITCallable.__init__ —
        # source/hash bookkeeping appears to live on the wrapped function; confirm.
        # Attribute names mirror Python's bound-method protocol.
        self.__self__ = instance
        self.__func__ = fn

    def __call__(self, *args, **kwargs):
        # Forward the call with the bound instance prepended as first argument.
        return self.__func__(self.__self__, *args, **kwargs)
1069
+
1070
+
1071
class ConstexprFunction(JITCallable):
    """Wraps a plain Python function so Triton can evaluate it at compile time
    on constexpr arguments (see :func:`constexpr_function`)."""

    def __init__(self, fn):
        super().__init__(fn)

    def __get__(self, obj, objclass):
        # Create a bound function to support constexpr_function methods
        if obj is None:
            return self
        return BoundConstexprFunction(obj, self)

    def __call__(self, *args, _semantic=None, **kwargs):
        from triton.language.core import _unwrap_if_constexpr, constexpr
        # Strip constexpr wrappers before invoking the raw Python function;
        # the _semantic keyword is consumed here and never forwarded.
        plain_args = [_unwrap_if_constexpr(value) for value in args]
        plain_kwargs = {name: _unwrap_if_constexpr(value) for name, value in kwargs.items()}

        result = self.fn(*plain_args, **plain_kwargs)

        if _semantic is None:
            # Not called by the Triton code generator (host code, another
            # constexpr function, or an aggregate's __init__): raw value.
            return result

        if knobs.runtime.interpret:
            return result  # No constexpr in interpreter
        # Called by the code generator: hand back a Triton constexpr.
        return constexpr(result)
1099
+
1100
+
1101
def constexpr_function(fn):
    """
    Wraps an arbitrary Python function so that it can be called at
    compile-time on constexpr arguments in a Triton function and
    returns a constexpr result.

    :param fn: the Python function to evaluate at compile time
    :return: a ConstexprFunction wrapping ``fn``
    """
    return ConstexprFunction(fn)