triton-windows 3.1.0.post17__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (248)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/runtime/jit.py ADDED
@@ -0,0 +1,956 @@
1
+ from __future__ import annotations, division
2
+ import ast
3
+ import hashlib
4
+ import inspect
5
+ import itertools
6
+ import os
7
+ import re
8
+ import textwrap
9
+ from collections import defaultdict
10
+ from functools import cached_property
11
+ from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
12
+ from ..runtime.driver import driver
13
+ from types import ModuleType
14
+
15
# Name of the top-level package (normally "triton", since this module lives
# at triton/runtime/jit.py): this module's dotted name with the trailing
# ".runtime.jit" sliced off.
TRITON_MODULE = __name__[:-len(".runtime.jit")]

# Type variable for the callable type stored in KernelInterface.run.
T = TypeVar("T")
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # Dependencies Finder
21
+ # -----------------------------------------------------------------------------
22
+
23
+
24
class DependenciesFinder(ast.NodeVisitor):
    """
    This AST visitor is used to find dependencies of a JITFunction. This can
    be used to invalidate a JITFunction's hash when its source code -- or
    that of its dependencies -- changes.

    This visitor also keeps track of the global variables touched by the
    JITFunction. When we launch the kernel, we check that these have the same
    values as they did when we ran this visitor. If not, we raise an error (or
    otherwise we could recompile).

    Args:
        name: name of the function being analysed (used in error messages).
        globals: the function's ``__globals__`` dict, used to resolve free names.
        src: the function's source text; it seeds the SHA-256 hasher whose
            hex digest is exposed as ``ret``.
    """

    def __init__(self, name, globals, src) -> None:
        super().__init__()
        self.name = name
        self.hasher = hashlib.sha256(src.encode("utf-8"))

        # This function's __globals__ dict.
        self.globals = globals

        # Names bound locally (parameters, assignment and loop targets); they
        # hide globals of the same name. Initialised here so that visiting a
        # node outside of a FunctionDef does not fail with AttributeError.
        self.local_names = set()

        # Python builtins that can be accessed from Triton kernels.
        self.supported_python_builtins = {
            'float',
            'getattr',
            'int',
            'isinstance',
            'len',
            'list',
            'max',
            'min',
            'print',
            'range',
        }

        # used_global_vals tells us which global variables are used by this
        # function and all those it transitively calls, plus the values of those
        # variables when each function was initially run. (That is, if A calls
        # C, and B calls C, then the values for C in used_global_vals will be
        # from the first time C was run, either by A or B.)
        #
        # Each function may have a different __globals__ dict, so the global
        # variable `foo` may actually have a different value in the different
        # functions. Thus this map is actually
        # (var_name, id(__globals__)) -> (var_value, __globals__).
        self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}

        self.visiting_arg_default_value = False

    @property
    def ret(self):
        """Hex digest of the source plus the cache keys of all called JITFunctions."""
        return self.hasher.hexdigest()

    def visit_Name(self, node):
        # Assignment targets yield their identifier so callers can record it
        # as a local name.
        if isinstance(node.ctx, ast.Store):
            return node.id

        if node.id in self.local_names:
            # The global name is hidden by the local name.
            return None

        val = self.globals.get(node.id, None)

        # Only keep track of "interesting" global variables, that non-evil users
        # might change. Don't consider functions, modules, builtins, etc. This
        # helps keep the list of vars we have to check small.
        if (val is not None  #
                # Python default arguments are resolved only once, when the
                # function is defined. So if you do `foo(a=A)` and the value of
                # A changes, foo will still use the old value of A.
                and not self.visiting_arg_default_value
                # It would be pretty evil if someone did `import x` and then
                # `x = blah`.
                and type(val) != ModuleType
                # It would be pretty evil if we used function `foo` inside of
                # `bar` and then someone did `foo = baz`.
                and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False)  #
                and node.id not in self.supported_python_builtins  #
            ):
            self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals)

        return val

    def visit_Tuple(self, node):
        # We need to explicitly return the tuple values so that visit_Assign can
        # access them in the case of `a, b = ...`.
        return [self.visit(elt) for elt in node.elts]

    def visit_Attribute(self, node):
        lhs = self.visit(node.value)
        while isinstance(lhs, ast.Attribute):
            lhs = self.visit(lhs.value)
        # Attributes of the top-level triton package itself are not resolved.
        if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE):
            return None
        return getattr(lhs, node.attr)

    def visit_Call(self, node):

        def is_triton_builtin(func):
            # NOTE(review): this inspects the AST node `node.func`, not the
            # resolved object `func`. An AST node is never a builtin function,
            # so this branch never fires; it is kept as-is to preserve behavior.
            if inspect.isbuiltin(node.func):
                return True
            # `__module__` may be missing or None (e.g. some C-level bound
            # methods); treat those as "not from triton" instead of crashing.
            module = getattr(func, "__module__", None)
            return isinstance(module, str) and module.startswith(TRITON_MODULE)

        func = self.visit(node.func)
        assert func is None or is_triton_builtin(func) or isinstance(
            func, JITFunction
        ), f'Function "{getattr(func, "__name__", repr(func))}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'

        # Traverse arguments as well as node.func so we can find JITFunctions
        # passed to tl.reduce or tl.associative_scan as the combine_fn
        for obj in itertools.chain(
            (func, ),
                map(self.visit, node.args),
            (self.visit(kw.value) for kw in node.keywords),
        ):
            if not isinstance(obj, JITFunction):
                continue
            if is_triton_builtin(obj):
                continue

            func_cache_key = obj.cache_key
            # The conflicting function is `obj` -- which may be an argument
            # (e.g. a combine_fn) rather than the callee `func`.
            inner_name = getattr(obj, "__name__", repr(obj))

            # Merge our used_global_vals with those of the called function,
            # after checking that all overlapping values are consistent.
            for k in self.used_global_vals.keys() & obj.used_global_vals.keys():
                var_name, _ = k
                v1, _ = self.used_global_vals[k]
                v2, _ = obj.used_global_vals[k]
                if v1 != v2:
                    raise RuntimeError(
                        f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {inner_name} has conflicting value {v2} from when it was first compiled. This is not allowed."
                    )

            self.used_global_vals.update(obj.used_global_vals)

            noinline = str(getattr(obj, "noinline", False))

            key = func_cache_key + noinline
            self.hasher.update(key.encode("utf-8"))

    def visit_FunctionDef(self, node):
        # Save the local names, which may hide global names. Every parameter
        # kind binds a local: positional-only, regular, *args, keyword-only
        # and **kwargs.
        args = node.args
        params = itertools.chain(
            args.posonlyargs,
            args.args,
            [args.vararg] if args.vararg else [],
            args.kwonlyargs,
            [args.kwarg] if args.kwarg else [],
        )
        self.local_names = {arg.arg for arg in params}
        self.generic_visit(node)

    def visit_arguments(self, node):
        # The purpose of this function is to visit everything in `arguments`
        # just like `generic_visit`, except when we're visiting default values
        # (i.e. the `foo` part of `def fn(x = foo)`), we set
        # self.visiting_arg_default_value = True. This allows visit_Name to be
        # aware that we're inside function default values, which have special
        # semantics.

        # According to the AST docs, the arguments node has the following structure.
        #
        # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs,
        #              expr* kw_defaults, arg? kwarg, expr* defaults)
        def visit_defaults(defaults):
            try:
                assert not self.visiting_arg_default_value
                self.visiting_arg_default_value = True
                for expr in defaults:
                    if expr is not None:
                        self.visit(expr)
            finally:
                self.visiting_arg_default_value = False

        for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs):
            self.visit(arg)

        visit_defaults(node.kw_defaults)

        if node.kwarg is not None:
            self.visit(node.kwarg)

        visit_defaults(node.defaults)

    def visitAssnTarget(self, node):
        # Visiting a target yields a string (a Name), a possibly nested list
        # (a Tuple target), or something else (e.g. None for a Subscript).
        # Record every identifier; flattening avoids "unhashable list" on
        # nested targets like `(a, b), c = ...`.
        def identifiers(target):
            if isinstance(target, list):
                for item in target:
                    yield from identifiers(item)
            elif isinstance(target, str):
                yield target

        self.local_names.update(identifiers(self.visit(node)))

    def visit_Assign(self, node):
        if len(node.targets) != 1:
            # TODO(jlebar): I don't actually know how to hit this. You don't
            # get it from `a, b = ...` -- in that case, node.targets is a single
            # Tuple, and in fact we *do* need to handle that case if we want
            # existing code to work.
            raise TypeError("Simultaneous multiple assignment is not supported.")

        self.visitAssnTarget(node.targets[0])

        # This will re-visit the target, but that's OK.
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        self.visitAssnTarget(node.target)

        # This will re-visit the target, but that's OK.
        self.generic_visit(node)

    def visit_For(self, node):
        self.visitAssnTarget(node.target)

        # This will re-visit the target, but that's fine.
        self.generic_visit(node)
+
236
+ # -----------------------------------------------------------------------------
237
+ # JITFunction
238
+ # -----------------------------------------------------------------------------
239
+
240
+
241
+ def _normalize_ty(ty) -> str:
242
+ if isinstance(ty, type):
243
+ return ty.__name__
244
+ elif isinstance(ty, str):
245
+ return ty
246
+ return repr(ty)
247
+
248
+
249
+ class KernelParam:
250
+ """Represents a parameter (name plus metadata) to a @jit'ed function."""
251
+
252
+ def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
253
+ self.num = num
254
+ self._param = param
255
+ self.do_not_specialize = do_not_specialize
256
+
257
+ @cached_property
258
+ def name(self):
259
+ return self._param.name
260
+
261
+ @cached_property
262
+ def annotation(self):
263
+ if not self._param.annotation or self._param.annotation == inspect.Parameter.empty:
264
+ return ""
265
+ return _normalize_ty(self._param.annotation)
266
+
267
+ @cached_property
268
+ def annotation_type(self):
269
+ annotation = self.annotation
270
+ for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
271
+ width = annotation[annotation.find(ty1) + len(ty1):]
272
+ if width and ty1 in annotation:
273
+ return f"{ty2}{width}"
274
+ if annotation == "bool":
275
+ return "u1"
276
+ return ""
277
+
278
+ @cached_property
279
+ def is_constexpr(self):
280
+ return "constexpr" in self.annotation
281
+
282
+ @cached_property
283
+ def is_const(self):
284
+ return "const" in self.annotation and not self.is_constexpr
285
+
286
+ @property
287
+ def default(self):
288
+ return self._param.default
289
+
290
+ @property
291
+ def has_default(self):
292
+ return self._param.default != inspect.Parameter.empty
293
+
294
+
295
def compute_spec_key(v):
    """Classify an argument for specialization.

    Returns:
        "D" if ``v`` exposes ``data_ptr()`` and that address is a multiple of
        16, or if ``v`` is an int divisible by 16;
        "1" if ``v`` is an int equal to 1;
        "N" otherwise.
    """
    if hasattr(v, "data_ptr") and v.data_ptr() % 16 == 0:
        return "D"
    if not isinstance(v, int):
        return "N"
    # bool is a subclass of int, so True/False are classified here as well.
    if v % 16 == 0:
        return "D"
    return "1" if v == 1 else "N"
306
+
307
+
308
# Memo of (dtype, is_const) -> mangled pointer type string, filled by mangle_type.
dtype2str = {}


def mangle_type(arg, is_const=False):
    """Return the signature type string for a kernel argument.

    None -> "none"; bool -> "i1"; int -> "i32" if it fits in 32-bit signed,
    "u64" if in [2**63, 2**64 - 1], otherwise "i64"; float -> "fp32".
    Any other argument is treated as a pointer-like object with a ``dtype``:
    the result is "*" (or "*k" when ``is_const``) followed by the canonical
    name of the last dotted component of ``str(arg.dtype)``.
    """
    if arg is None:
        return "none"
    if isinstance(arg, bool):
        return "i1"
    if isinstance(arg, int):
        if -(2**31) <= arg <= 2**31 - 1:
            return "i32"
        if 2**63 <= arg <= 2**64 - 1:
            return "u64"
        return "i64"
    if isinstance(arg, float):
        return "fp32"

    # dtypes are hashable, so the mapping can be memoized per (dtype, constness).
    dtype = arg.dtype
    memo_key = (dtype, is_const)
    mangled = dtype2str.get(memo_key, None)
    if mangled is None:
        prefix = "*k" if is_const else "*"
        mangled = prefix + type_canonicalisation_dict[str(dtype).split('.')[-1]]
        dtype2str[memo_key] = mangled
    return mangled
334
+
335
+
336
class KernelInterface(Generic[T]):
    """Base for launchable kernels: ``kernel[grid](*args, **kwargs)`` calls ``run``."""

    run: T

    def __getitem__(self, grid) -> T:
        """
        A JIT function is launched with: fn[grid](*args, **kwargs).

        Returns a callable that remembers ``grid`` and forwards its arguments
        to ``self.run(grid=grid, warmup=False, ...)``.
        """

        def launch(*args, **kwargs):
            return self.run(grid=grid, warmup=False, *args, **kwargs)

        return launch
347
+
348
+
349
def serialize_specialization_data(name, signature, constants, attrs, options, key):
    """Serialize the data describing one kernel specialization to JSON.

    Constant values whose class is named ``dtype`` are replaced by their
    ``str()`` form; all other values are handed to ``json.dumps`` unchanged and
    must therefore be JSON-serializable.

    Args:
        name: kernel name.
        signature: kernel signature.
        constants: mapping of constant arguments to their values.
        attrs: object providing ``to_dict()``.
        options: options object whose ``__dict__`` is serialized.
        key: specialization cache key.

    Returns:
        A JSON string with keys "name", "signature", "constants", "attrs",
        "options" and "key", in that order.
    """
    import json

    def as_jsonable(value):
        return str(value) if value.__class__.__name__ == "dtype" else value

    payload = {
        'name': name,
        'signature': signature,
        'constants': {const_name: as_jsonable(const_value) for const_name, const_value in constants.items()},
        'attrs': attrs.to_dict(),
        'options': options.__dict__,
        'key': key,
    }
    return json.dumps(payload)
358
+
359
+
360
def create_function_from_signature(sig, kparams):
    """
    Build a fast stand-in for ``sig.bind(...)`` followed by ``apply_defaults()``.

    The source of a function named ``dynamic_func`` is generated from ``sig``
    and run through ``exec``. The result can be memoized per kernel, so the
    expensive ``inspect`` binding machinery -- much of the kernel launch
    overhead -- does not run on every launch.

    The generated function takes the kernel's parameters (with their defaults)
    plus arbitrary extra keyword arguments, and returns a 5-tuple:

    * a dict mapping every parameter name to its bound value;
    * a cache-key tuple: one type entry per non-constexpr parameter (the
      literal ``annotation_type`` if it is set, else
      ``mangle_type(value, is_const)``), followed by ``compute_spec_key(value)``
      for each non-constexpr parameter not marked ``do_not_specialize``;
    * a tuple of the constexpr parameter values;
    * a tuple of the non-constexpr parameter values;
    * a dict of the extra keyword arguments.

    Args:
        sig: ``inspect.Signature`` of the kernel.
        kparams: per-parameter metadata objects, in signature order.
    """
    assert len(sig.parameters) == len(kparams)

    parameters = list(sig.parameters.items())

    def tuple_body(exprs):
        # "a, b, " inside parentheses is always a tuple -- even with a single
        # element (trailing comma) or none at all ("()").
        return "".join(expr + ", " for expr in exprs)

    arg_decls = []
    type_exprs = []
    spec_exprs = []
    constexpr_names = []
    runtime_names = []
    for (pname, param), kp in zip(parameters, kparams):
        has_default = param.default is not inspect.Parameter.empty
        arg_decls.append(f"{pname}=default_{pname}" if has_default else pname)
        if kp.is_constexpr:
            constexpr_names.append(pname)
            continue
        runtime_names.append(pname)
        if not kp.do_not_specialize:
            spec_exprs.append(f"compute_spec_key({pname})")
        if kp.annotation_type:
            type_exprs.append(f'"{kp.annotation_type}"')
        else:
            const_flag = 'True' if kp.is_const else 'False'
            type_exprs.append(f"mangle_type({pname}, {const_flag})")
    arg_decls.append("**excess_kwargs")

    bound_dict = ", ".join(f"'{pname}': {pname}" for pname, _ in parameters)
    func_body = "def dynamic_func(%s):\n    return {%s}, (%s), (%s), (%s), excess_kwargs" % (
        ", ".join(arg_decls),
        bound_dict,
        tuple_body(type_exprs + spec_exprs),
        tuple_body(constexpr_names),
        tuple_body(runtime_names),
    )

    # Globals for the generated function: the default values it refers to as
    # `default_<name>`, plus the two helpers used in the cache-key tuple.
    namespace = {
        f"default_{pname}": param.default
        for pname, param in parameters
        if param.default is not inspect.Parameter.empty
    }
    namespace["mangle_type"] = mangle_type
    namespace["compute_spec_key"] = compute_spec_key

    exec(func_body, namespace)
    return namespace["dynamic_func"]
423
+
424
+
425
# Maps dtype spellings -- looked up by the last dotted component of
# `str(dtype)` -- to the canonical short type names used in kernel signatures.
type_canonicalisation_dict = {
    "bool": "i1",
    "float8e4nv": "fp8e4nv",
    "float8e5": "fp8e5",
    "float8e4b15": "fp8e4b15",
    "float8_e4m3fn": "fp8e4nv",
    "float8e4b8": "fp8e4b8",
    "float8_e4m3fnuz": "fp8e4b8",
    "float8_e5m2": "fp8e5",
    "float8e5b16": "fp8e5b16",
    "float8_e5m2fnuz": "fp8e5b16",
    "float16": "fp16",
    "bfloat16": "bf16",
    "float32": "fp32",
    "float64": "fp64",
    "int8": "i8",
    "int16": "i16",
    "int32": "i32",
    "int64": "i64",
    "uint8": "u8",
    "uint16": "u16",
    "uint32": "u32",
    "uint64": "u64",
}

# Canonical names also map to themselves, so already-canonical input is
# accepted. A comprehension (instead of a module-level `for` loop) keeps the
# loop variable from leaking into the module namespace.
type_canonicalisation_dict.update({canonical: canonical for canonical in list(type_canonicalisation_dict.values())})
452
+
453
+
454
class JITFunction(KernelInterface[T]):
    """Runtime wrapper for a function decorated with ``@triton.jit``.

    Binds launch arguments, derives a cache key from their signature /
    specialization and constexpr values, compiles on a cache miss (one cache
    dict per device in ``self.cache``) and launches the compiled kernel.
    """

    # Hook for inspecting compiled functions and modules. When set, it is
    # invoked by `_call_hook` before compilation; a truthy return value makes
    # `run` skip compiling and launching.
    cache_hook = None
    # Modulus used to decide whether pointer addresses and integer arguments
    # qualify for divisibility-based specialization.
    divisibility = 16

    @staticmethod
    def _key_of(arg):
        """Return a type key for ``arg``.

        Objects with a ``dtype`` attribute yield that dtype; ``bool`` -> ``"i1"``;
        ``int`` -> ``"i32"``, ``"u64"`` or ``"i64"`` depending on its range;
        ``float`` -> ``"fp32"``; ``None`` -> ``None``.

        Raises:
            TypeError: for any other argument type.
        """
        if hasattr(arg, "dtype"):
            return arg.dtype
        elif isinstance(arg, bool):
            # Checked before `int` because bool is a subclass of int.
            return "i1"
        elif isinstance(arg, int):
            if -(2**31) <= arg and arg <= 2**31 - 1:
                return "i32"
            elif 2**63 <= arg and arg <= 2**64 - 1:
                return "u64"
            else:
                return "i64"
        elif isinstance(arg, float):
            return "fp32"
        elif arg is None:
            return None
        else:
            raise TypeError(f"Unsupported type {type(arg)} for {arg}")

    @staticmethod
    def _spec_of(arg):
        """Return the specialization properties of ``arg``.

        Objects with ``data_ptr`` yield whether their address is a multiple of
        ``JITFunction.divisibility``; ints yield a ``(divisible, equals_one)``
        pair; anything else yields ``(is_none,)``.
        """
        if hasattr(arg, "data_ptr"):
            return arg.data_ptr() % JITFunction.divisibility == 0
        elif isinstance(arg, int):
            # Use the same modulus as the pointer branch and `_get_config`
            # rather than a hard-coded 16.
            return (arg % JITFunction.divisibility == 0, arg == 1)
        return (arg is None, )

    def _get_config(self, *args):
        """Build the ``AttrsDescriptor`` describing how ``args`` are specialized.

        Collects the indices (``param.num``) of parameters whose value is
        divisible by ``JITFunction.divisibility`` (``None`` counts as divisible)
        and of non-bool ints equal to 1. Parameters marked
        ``do_not_specialize`` are excluded from both sets.
        """
        from ..compiler import AttrsDescriptor

        def is_divisible_by_16(x):
            if hasattr(x, "data_ptr"):
                return x.data_ptr() % JITFunction.divisibility == 0
            elif isinstance(x, int):
                return x % JITFunction.divisibility == 0
            if x is None:
                return True
            return False

        divisible_by_16 = {
            param.num
            for param, arg in zip(self.params, args)
            if is_divisible_by_16(arg) and not param.do_not_specialize
        }
        equal_to_1 = {
            param.num
            for param, arg in zip(self.params, args)
            if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize
        }
        # folded equal_to_1 and None
        # TODO: method to collect all folded args
        return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1))

    @staticmethod
    def _type_of(key, is_const=False):
        """Convert a type key into a Triton signature type string.

        ``None`` becomes ``"*i8"`` and strings are returned unchanged. Any other
        key is treated as a dtype: the last dotted component of ``str(key)`` is
        canonicalised through ``type_canonicalisation_dict`` and prefixed with
        ``"*"`` (or ``"*k"`` when ``is_const``).
        """
        # `None` is nullptr. Implicitly convert to *i8.
        if key is None:
            return "*i8"
        elif isinstance(key, str):
            return key

        dtype_str = str(key).split(".")[-1]
        dtype_str = type_canonicalisation_dict[dtype_str]
        const_str = "*k" if is_const else "*"
        return const_str + dtype_str

    def _make_constants(self, constexpr_key):
        """Map constexpr parameter indices (``self.constexprs``) to the given values."""
        constants = dict(zip(self.constexprs, constexpr_key))
        return constants

    def _call_hook(
        self,
        key,
        signature,
        device,
        constants,
        options,
        configs,
    ):
        """Invoke ``JITFunction.cache_hook`` (if installed) before compiling.

        Returns ``False`` when no hook is installed, otherwise whatever the hook
        returns; `run` skips compilation and launch when that value is truthy.
        """
        if JITFunction.cache_hook is None:
            return False

        name = self.fn.__name__
        module = self.fn.__module__
        # `run` passes `key` as a flat string (''.join(sig_and_spec) + ...), so
        # indexing into it would yield single characters rather than per-argument
        # types. Describe the runtime arguments from the resolved signature.
        arg_reprs = ", ".join(f"{arg_name}: {ty}" for arg_name, ty in signature.items())
        kernel_repr = (f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, "
                       f"num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})")

        class JitFunctionInfo:

            def __init__(self, module, name, jit_function):
                self.module = module
                self.name = name
                self.jit_function = jit_function

        specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key)

        kwargs = {
            'signature': signature,
            'device': device,
            'constants': constants,
            'num_warps': options.num_warps,
            'num_ctas': options.num_ctas,
            'num_stages': options.num_stages,
            'enable_fp_fusion': options.enable_fp_fusion,
            'extern_libs': options.extern_libs,
            'configs': configs,
            'specialization_data': specialization_data,
        }

        return JITFunction.cache_hook(
            key=key,
            repr=kernel_repr,
            fn=JitFunctionInfo(module, name, self),
            compile={"key": key, **kwargs},
            is_manual_warmup=False,
            already_compiled=False,
        )

    def add_pre_run_hook(self, hook):
        '''
        Add a hook that will be executed prior to the execution of run
        function with args and kwargs passed into the kernel
        '''
        assert callable(hook)
        self.pre_run_hooks.append(hook)

    def create_binder(self):
        """
        Precompute as much as possible.

        Imports the compiler entry points lazily, builds the argument binder
        from the function signature, and records which parameter indices are
        constexpr, non-constexpr, and specialised.
        """
        from ..compiler import CompiledKernel, compile, ASTSource, make_backend
        self.CompiledKernel = CompiledKernel
        self.compile = compile
        self.ASTSource = ASTSource
        self.make_backend = make_backend
        self.binder = create_function_from_signature(self.signature, self.params)
        self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr]
        self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr]
        self.specialised_indices = [
            i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr)
        ]

    def run(self, *args, grid, warmup, **kwargs):
        """Bind arguments, compile on a cache miss and optionally launch.

        Args:
            grid: launch grid (sequence of up to 3 ints) or a callable that
                receives the bound-argument dict and returns one. Required
                unless ``warmup`` is true.
            warmup: when true, compile (if needed) but do not launch.
            kwargs: kernel arguments and compiler options.

        Returns:
            The compiled kernel, or ``None`` when ``cache_hook`` returned a
            truthy value.

        Raises:
            KeyError: a keyword argument is neither a kernel parameter nor a
                recognised option.
            TypeError: a constexpr / folded argument is callable.
            RuntimeError: a global used by the kernel changed since compilation.
        """
        # parse options
        device = driver.active.get_current_device()
        stream = driver.active.get_current_stream(device)
        kwargs["debug"] = self.debug

        # Execute pre run hooks with args and kwargs
        for hook in self.pre_run_hooks:
            hook(*args, **kwargs)

        if self.binder is None:
            self.create_binder()

        bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)

        # compute cache key
        key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
        kernel = self.cache[device].get(key, None)

        if kernel is None:
            # Kernel is not cached; we have to compile.
            target = driver.active.get_current_target()
            backend = self.make_backend(target)
            options = backend.parse_options(kwargs)

            # deprecated arguments
            assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
            assert "device" not in kwargs, "device option is deprecated; current device will be used"
            assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
            for k in excess_kwargs:
                if k not in options.__dict__:
                    raise KeyError("Keyword argument %s was specified but unrecognised" % k)

            bound_vals = tuple(bound_args.values())

            # `None` is nullptr. Implicitly convert to *i8. This needs to be
            # done here rather than when we build the signature as otherwise
            # the kernel cache key could not distinguish between byte pointers
            # and None arguments, resulting in a downstream mismatch:
            sigkeys = [self.params[i].name for i in self.non_constexpr_indices]
            sigvals = sig_and_spec[:len(sigkeys)]
            signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)}

            configs = (self._get_config(*bound_vals), )
            # Constexprs, ints folded to 1, and None arguments are baked into
            # the compiled kernel as constants.
            constants = {
                p.name: v
                for (v, p) in zip(bound_vals, self.params)
                if p.is_constexpr or p.num in configs[0].equal_to_1 or v is None
            }
            for param_name, arg in constants.items():
                if callable(arg):
                    raise TypeError(f"Callable constexpr for parameter '{param_name}' is not supported")

            if self._call_hook(key, signature, device, constants, options, configs):
                return None
            # compile the kernel
            src = self.ASTSource(self, signature, constants, configs[0])
            kernel = self.compile(
                src,
                target=target,
                options=options.__dict__,
            )
            self.cache[device][key] = kernel

        # Check that used global values have not changed.
        not_present = object()
        for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items():
            if (new_val := globals_dict.get(name, not_present)) != val:
                raise RuntimeError(
                    f"Global variable {name} has changed since we compiled this kernel, from {val} to {new_val}")

        if not warmup:
            # canonicalize grid
            assert grid is not None
            if callable(grid):
                # Arguments are passed as a dict to `grid`, by contract.
                # TODO(jlebar): In the new launch API, pass the compiler flags as a
                # second parameter to `grid`.
                grid = grid(bound_args)
            grid_size = len(grid)
            grid_0 = grid[0]
            grid_1 = grid[1] if grid_size > 1 else 1
            grid_2 = grid[2] if grid_size > 2 else 1

            # launch kernel
            launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
            kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
                       self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)
        return kernel

    def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None,
                 launch_metadata=None):
        """Wrap ``fn`` for just-in-time compilation.

        Args:
            fn: the Python function to compile; its source must be retrievable
                through ``inspect``.
            version: stored as ``self.version``.
            do_not_specialize: positions or names of parameters that must not
                be specialized; ``None``/empty means none.
            debug: debug flag; forced to ``True`` when ``TRITON_DEBUG=1``.
            noinline: stored as ``self.noinline``.
            repr: optional callable producing this kernel's repr string;
                defaults to ``fn.__name__``.
            launch_metadata: stored as ``self.launch_metadata``.
        """
        do_not_specialize = do_not_specialize if do_not_specialize else []

        self.fn = fn
        self.module = fn.__module__
        self.version = version
        self.signature = inspect.signature(fn)
        self.do_not_specialize = do_not_specialize
        self.starting_line_number = inspect.getsourcelines(fn)[1]
        self.repr = lambda _: fn.__name__ if repr is None else repr(_)
        self.launch_metadata = launch_metadata

        self.binder = None

        self.params = []
        for i, param in enumerate(self.signature.parameters.values()):
            # A parameter may be excluded either by position or by name.
            dns = i in do_not_specialize or param.name in do_not_specialize
            self.params.append(KernelParam(i, param, dns))

        # function source code (without decorators)
        self.src = textwrap.dedent(inspect.getsource(fn))
        self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
        # cache of just-in-time compiled kernels
        self.cache = defaultdict(dict)
        self.hash = None

        # Map of global variables used by the function and any functions it
        # transitively calls, plus their values. The values are collected when
        # the function is first compiled. Then every time we run the function,
        # we check that the values of the globals match what's expected,
        # otherwise we raise an error.
        #
        # Different functions can have different __globals__ maps, so the map
        # key is actually (var name, id(__globals__)), and the map value is
        # (value, __globals__).
        self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}

        # JITFunction can be instantiated as kernel
        # when called with a grid using __getitem__
        self.kernel = None
        self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
        self.noinline = noinline

        # TODO(jlebar): Remove uses of these fields outside this file, then
        # remove the fields here.
        self.arg_names = [p.name for p in self.params]
        self.constexprs = [p.num for p in self.params if p.is_constexpr]

        # Hooks that will be called prior to executing "run"
        self.pre_run_hooks = []

        # reuse docs of wrapped function
        self.__doc__ = fn.__doc__
        self.__name__ = fn.__name__
        self.__globals__ = fn.__globals__
        self.__module__ = fn.__module__

    @property
    def cache_key(self):
        """Hash of this function's source, its dependencies and start line.

        Computed lazily with ``DependenciesFinder`` (which also refreshes
        ``used_global_vals``) and reset whenever ``src`` is reassigned.
        """
        # TODO : hash should be attribute of `self`
        if self.hash is None:
            dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src)
            dependencies_finder.visit(self.parse())
            self.hash = dependencies_finder.ret + str(self.starting_line_number)
            self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
        return self.hash

    def warmup(self, *args, grid, **kwargs):
        """Compile without launching; bare torch dtypes in ``args`` are wrapped in ``MockTensor``."""
        return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)

    def preload(self, specialization_data):
        """Compile from serialized specialization data and store it in the cache.

        ``specialization_data`` is a JSON string with ``name``, ``signature``,
        ``constants``, ``attrs``, ``options`` and ``key`` fields.

        Raises:
            RuntimeError: the data was produced for a different function name.
        """
        from ..compiler import AttrsDescriptor, compile, ASTSource
        import json
        import triton.language as tl
        device = driver.active.get_current_device()
        deserialized_obj = json.loads(specialization_data)
        if deserialized_obj['name'] != self.fn.__name__:
            raise RuntimeError(
                f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
        constants = {
            key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
            for key, value in deserialized_obj['constants'].items()
        }
        signature = dict(deserialized_obj['signature'].items())
        src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs']))
        # JSON has no tuples; restore list-valued options to tuples.
        options = {
            key: tuple(value) if isinstance(value, list) else value
            for key, value in deserialized_obj['options'].items()
        }
        key = deserialized_obj['key']
        kernel = compile(src, None, options)
        self.cache[device][key] = kernel
        return kernel

    # we do not parse `src` in the constructor because
    # the user might want to monkey-patch self.src dynamically.
    # Our unit tests do this, for example.
    def parse(self):
        """Parse ``self.src`` and check it holds exactly one function definition."""
        tree = ast.parse(self.src)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    def __call__(self, *args, **kwargs):
        raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        # - when `.src` attribute is set, cache path needs
        #   to be reinitialized
        if name == "src":
            self.hash = None

    def __repr__(self):
        return f"JITFunction({self.module}:{self.fn.__name__})"
812
+
813
+
814
+ # -----------------------------------------------------------------------------
815
+ # `jit` decorator
816
+ # -----------------------------------------------------------------------------
817
+
818
+
819
@overload
def jit(fn: T) -> JITFunction[T]:
    ...


@overload
def jit(
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[Union[int, str]]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
    ...


def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[Union[int, str]]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    Can be used bare (``@jit``) or with keyword options (``@jit(debug=True)``).

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * builtins within the triton package,
           * arguments to this function,
           * other jit'd functions

    :note: When the environment variable ``TRITON_INTERPRET`` is ``"1"``, an
        ``InterpretedFunction`` wrapping ``fn`` is returned instead, and the
        keyword options below are not forwarded to it.

    :param fn: the function to be jit-compiled
    :type fn: Callable
    :param version: stored on the resulting ``JITFunction`` as ``version``.
    :param repr: optional callable used to build the kernel's repr string;
        defaults to the function's name.
    :param launch_metadata: stored on the resulting ``JITFunction`` as
        ``launch_metadata``.
    :param do_not_specialize: positions and/or names of parameters that must
        not be specialized on their runtime values.
    :param debug: debug flag; forced on when ``TRITON_DEBUG=1``.
    :param noinline: stored on the resulting ``JITFunction`` as ``noinline``.
    :return: a ``JITFunction`` when ``fn`` is given, otherwise a decorator.
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        if os.getenv("TRITON_INTERPRET", "0") == "1":
            from .interpreter import InterpretedFunction
            return InterpretedFunction(fn)
        else:
            return JITFunction(
                fn,
                version=version,
                do_not_specialize=do_not_specialize,
                debug=debug,
                noinline=noinline,
                repr=repr,
                launch_metadata=launch_metadata,
            )

    if fn is not None:
        return decorator(fn)
    return decorator
886
+
887
+
888
+ # -----------------------------------------------------------------------------
889
+ # Utilities for mocking tensors
890
+ # -----------------------------------------------------------------------------
891
+
892
+
893
class MockTensor:
    """
    Dtype-only stand-in for a real tensor, so a kernel can be warmed up
    without allocating device memory, e.g.:
        kernel.warmup(MockTensor(torch.float32), ...)
    """

    @staticmethod
    def wrap_dtype(arg):
        # Recognise a torch dtype by class name and module rather than
        # isinstance, so no torch import is needed; anything else passes through.
        looks_like_torch_dtype = arg.__class__.__name__ == "dtype" and arg.__module__ == "torch"
        return MockTensor(arg) if looks_like_torch_dtype else arg

    def __init__(self, dtype):
        self.dtype = dtype

    @staticmethod
    def data_ptr():
        # A null address: optimistically assumes a multiple of 16.
        return 0
911
+
912
+
913
class TensorWrapper:
    """View of a tensor-like ``base`` that advertises a different ``dtype``.

    Pointer, strides, element size, device and shape are forwarded from
    ``base``; only the reported ``dtype`` differs.
    """

    def __init__(self, base, dtype):
        self.dtype = dtype
        self.base = base
        self.data = base.data
        self.device = base.device
        # Snapshot of the base's shape taken at construction time.
        self.shape = self.base.shape

    def data_ptr(self):
        return self.base.data_ptr()

    def stride(self, i):
        return self.base.stride(i)

    def __str__(self) -> str:
        return f"TensorWrapper[{self.dtype}]({self.base})"

    def element_size(self):
        # Element size of the base tensor, not of the advertised `dtype`.
        return self.base.element_size()

    def cpu(self):
        return TensorWrapper(self.base.cpu(), self.dtype)

    def copy_(self, other):
        """Copy ``other`` into the wrapped base in place and return ``self``.

        ``other`` may be another ``TensorWrapper`` (its base is copied) or an
        unwrapped tensor-like object, which is passed through unchanged.
        """
        source = other.base if isinstance(other, TensorWrapper) else other
        self.base.copy_(source)
        return self

    def to(self, device):
        return TensorWrapper(self.base.to(device), self.dtype)
942
+
943
+
944
def reinterpret(tensor, dtype):
    """Return a view of ``tensor`` that reports ``dtype``.

    A ``TensorWrapper`` is unwrapped when ``dtype`` equals its base's dtype and
    re-wrapped around the same base otherwise. Any other object exposing
    ``data_ptr`` is wrapped in a new ``TensorWrapper``.

    Raises:
        TypeError: ``tensor`` is neither a ``TensorWrapper`` nor has ``data_ptr``.
    """
    if not isinstance(tensor, TensorWrapper):
        if not hasattr(tensor, "data_ptr"):
            raise TypeError(f"Cannot reinterpret a {type(tensor)}.")
        # Plain tensor: a new wrapper is needed.
        return TensorWrapper(tensor, dtype)

    base = tensor.base
    # Asking for the base's own dtype undoes the wrapping; any other dtype
    # re-wraps the same underlying tensor.
    return base if dtype == base.dtype else TensorWrapper(base, dtype)