triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of triton-windows has been flagged as a potentially problematic release.
Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/runtime/jit.py CHANGED
@@ -1,19 +1,28 @@
  from __future__ import annotations, division
  import ast
+ import copy
  import hashlib
  import inspect
  import itertools
- import os
+ import threading
  import re
  import textwrap
  from collections import defaultdict
+ from dataclasses import dataclass
  from functools import cached_property
  from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
- from ..runtime.driver import driver
+
+ from triton.tools.tensor_descriptor import TensorDescriptor
  from types import ModuleType
- from .._utils import find_paths_if, get_iterable_path
+ from .. import knobs
+ from .driver import driver
+ from . import _async_compile
+ from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, canonicalize_dtype
+ from .cache import get_cache_key
+ from triton._C.libtriton import get_cache_invalidating_env_vars

- TRITON_MODULE = __name__[:-len(".runtime.jit")]
+ TRITON_MODULE = "triton.language"
+ GLUON_MODULE = "triton.experimental.gluon.language"

  T = TypeVar("T")

@@ -34,13 +43,14 @@ class DependenciesFinder(ast.NodeVisitor):
  otherwise we could recompile).
  """

- def __init__(self, name, globals, src) -> None:
+ def __init__(self, name, globals, nonlocals, src) -> None:
  super().__init__()
  self.name = name
  self.hasher = hashlib.sha256(src.encode("utf-8"))

  # This function's __globals__ dict.
  self.globals = globals
+ self.nonlocals = nonlocals

  # Python builtins that can be accessed from Triton kernels.
  self.supported_python_builtins = {
@@ -55,6 +65,12 @@ class DependenciesFinder(ast.NodeVisitor):
  'print',
  'range',
  }
+ self.supported_modules = {
+ GLUON_MODULE,
+ TRITON_MODULE,
+ "copy",
+ "math",
+ }

  # used_global_vals tells us which global variables are used by this
  # function and all those it transitively calls, plus the values of those
@@ -81,22 +97,56 @@ class DependenciesFinder(ast.NodeVisitor):
  return module.startswith(TRITON_MODULE)

  def _update_hash(self, func):
- if isinstance(func, JITFunction):
- # Merge our used_global_vals with those of the called function,
- # after checking that all overlapping values are consistent.
- for k in self.used_global_vals.keys() & func.used_global_vals.keys():
- var_name, _ = k
- v1, _ = self.used_global_vals[k]
- v2, _ = func.used_global_vals[k]
- if v1 != v2:
- raise RuntimeError(
- f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
- )
- self.used_global_vals.update(func.used_global_vals)
- # update hash
- func_key = func.cache_key
- func_key += str(getattr(func, "noinline", False))
- self.hasher.update(func_key.encode("utf-8"))
+ assert isinstance(func, JITCallable)
+ # Merge our used_global_vals with those of the called function,
+ # after checking that all overlapping values are consistent.
+ for k in self.used_global_vals.keys() & func.used_global_vals.keys():
+ var_name, _ = k
+ v1, _ = self.used_global_vals[k]
+ v2, _ = func.used_global_vals[k]
+ if v1 != v2:
+ raise RuntimeError(
+ f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
+ )
+ self.used_global_vals.update(func.used_global_vals)
+ # update hash
+ func_key = func.cache_key
+ func_key += str(getattr(func, "noinline", False))
+ self.hasher.update(func_key.encode("utf-8"))
+
+ def record_reference(self, val, var_dict=None, name=None):
+ from ..language.core import constexpr
+ # Only keep track of "interesting" global variables, that non-evil users
+ # might change. Don't consider functions, modules, builtins, etc. This
+ # helps keep the list of vars we have to check small.
+ if val is None or type(val) is ModuleType:
+ return
+
+ if getattr(val, "__triton_builtin__", False):
+ return
+
+ # Stubs that aren't real functions
+ if getattr(val, "__module__", "") == "triton.language.extra.libdevice":
+ return
+
+ if isinstance(val, JITCallable):
+ self._update_hash(val)
+ return
+
+ if callable(val) and not isinstance(val, type) and not isinstance(val, constexpr):
+ raise RuntimeError(f"Unsupported function referenced: {val}")
+
+ # Python default arguments are resolved only once, when the
+ # function is defined. So if you do `foo(a=A)` and the value of
+ # A changes, foo will still use the old value of A.
+ # It would be pretty evil if someone did `import x` and then
+ # `x = blah`.
+ if self.visiting_arg_default_value:
+ return
+
+ if var_dict is not None:
+ self.used_global_vals[(name, id(var_dict))] = (copy.deepcopy(val), var_dict)
+ return

  def visit_Name(self, node):
  if type(node.ctx) is ast.Store:
@@ -106,26 +156,20 @@ class DependenciesFinder(ast.NodeVisitor):
  # The global name is hidden by the local name.
  return None

- val = self.globals.get(node.id, None)
+ def name_lookup(name):
+ val = self.globals.get(name, None)
+ if val is not None:
+ return val, self.globals
+ val = self.nonlocals.get(name, None)
+ if val is not None:
+ return val, self.nonlocals
+ return None, None

- # Only keep track of "interesting" global variables, that non-evil users
- # might change. Don't consider functions, modules, builtins, etc. This
- # helps keep the list of vars we have to check small.
- if (val is not None #
- # Python default arguments are resolved only once, when the
- # function is defined. So if you do `foo(a=A)` and the value of
- # A changes, foo will still use the old value of A.
- and not self.visiting_arg_default_value
- # It would be pretty evil if someone did `import x` and then
- # `x = blah`.
- and type(val) is not ModuleType
- # It would be pretty evil if we used function `foo` inside of
- # `bar` and then someone did `foo = baz`.
- and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) #
- and node.id not in self.supported_python_builtins):
- self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals)
-
- self._update_hash(val)
+ val, var_dict = name_lookup(node.id)
+ if node.id in self.supported_python_builtins:
+ return val
+
+ self.record_reference(val, var_dict, node.id)
  return val

  def visit_Tuple(self, node):
@@ -137,10 +181,11 @@ class DependenciesFinder(ast.NodeVisitor):
  lhs = self.visit(node.value)
  while isinstance(lhs, ast.Attribute):
  lhs = self.visit(lhs.value)
- if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE):
+ lhs_name = getattr(lhs, "__name__", "")
+ if lhs is None or lhs_name in self.supported_modules:
  return None
  ret = getattr(lhs, node.attr)
- self._update_hash(ret)
+ self.record_reference(ret)
  return ret

  def visit_FunctionDef(self, node):
@@ -221,11 +266,29 @@ class DependenciesFinder(ast.NodeVisitor):


  def _normalize_ty(ty) -> str:
- if isinstance(ty, type):
- return ty.__name__
- elif isinstance(ty, str):
- return ty
- return repr(ty)
+ import triton.language.core as core
+ if isinstance(ty, str):
+ ty = ty.strip()
+ if ty.startswith("const "):
+ ty = ty.removeprefix("const")
+ ty = _normalize_ty(ty)
+ assert ty.startswith("*")
+ return "*k" + ty[1:]
+ if ty.endswith("*"):
+ return "*" + _normalize_ty(ty[:-1])
+ if ty.startswith("*"):
+ return "*" + _normalize_ty(ty[1:])
+ if ty.startswith("tl."):
+ return _normalize_ty(ty.removeprefix("tl."))
+ elif isinstance(ty, core.pointer_type):
+ return f"*{_normalize_ty(ty.element_ty)}"
+ elif isinstance(ty, core.dtype):
+ ty = ty.name
+ elif isinstance(ty, type):
+ ty = ty.__name__
+ else:
+ ty = str(ty)
+ return type_canonicalisation_dict.get(ty.replace("_t", ""), ty)


  class KernelParam:
@@ -243,20 +306,20 @@ class KernelParam:
  return self._param.name

  @cached_property
- def annotation(self):
+ def annotation(self) -> str:
  if not self._param.annotation or self._param.annotation == inspect.Parameter.empty:
  return ""
  return _normalize_ty(self._param.annotation)

  @cached_property
- def annotation_type(self):
- annotation = self.annotation
- for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
- width = annotation[annotation.find(ty1) + len(ty1):]
- if width and ty1 in annotation:
- return f"{ty2}{width}"
- if annotation == "bool":
- return "u1"
+ def annotation_type(self) -> str:
+ a = self.annotation
+ if a.startswith("*k"):
+ a = a[2:]
+ elif a.startswith("*"):
+ a = a[1:]
+ if a in set(type_canonicalisation_dict.values()):
+ return self.annotation
  return ""

  @cached_property
@@ -265,7 +328,9 @@ class KernelParam:

  @cached_property
  def is_const(self):
- return "const" in self.annotation and not self.is_constexpr
+ if self.is_constexpr:
+ return False
+ return "const" in self.annotation or self.annotation.startswith("*k")

  @property
  def default(self):
@@ -280,22 +345,16 @@ dtype2str = {}
  specialize_impl_cache = []


- def create_specialize_impl():
- if specialize_impl_cache:
- return specialize_impl_cache[-1]
+ def create_specialize_impl(specialize_extra):

  from ..language import constexpr
+ from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor

- def specialize_impl(arg, specialize_extra, is_const=False, specialize_value=True, align=True):
-
+ def specialize_impl(arg, is_const=False, specialize_value=True, align=True):
  if arg is None:
  return ("constexpr", None)
- elif isinstance(arg, JITFunction):
- return ("constexpr", arg.cache_key)
- elif isinstance(arg, constexpr):
- return ("constexpr", arg)
  elif isinstance(arg, bool):
- return ("i1", None)
+ return ("u1", None)
  elif isinstance(arg, int):
  key = specialize_extra(arg, "int", align=align) if specialize_value else None
  if arg == 1 and specialize_value:
@@ -308,31 +367,44 @@ def create_specialize_impl():
  return ("i64", key)
  elif isinstance(arg, float):
  return ("fp32", None)
- elif hasattr(arg, "tma_desc_cpu_ptr"):
- return ("nvTmaDesc", None)
- elif isinstance(arg, tuple):
- spec = [specialize_impl(x, specialize_extra) for x in arg]
- make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
- tys = make_tuple([x[0] for x in spec])
- keys = make_tuple([x[1] for x in spec])
- return (tys, keys)
- else:
+ elif hasattr(arg, "data_ptr"):
  # dtypes are hashable so we can memoize this mapping:
  dsk = (arg.dtype, is_const)
  res = dtype2str.get(dsk, None)
  if res is None:
- res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]]
+ res = ("*k" if dsk[1] else "*") + canonicalize_dtype(dsk[0])
  dtype2str[dsk] = res
  key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
  return (res, key)
+ elif isinstance(arg, JITCallable):
+ return ("constexpr", arg.cache_key)
+ elif isinstance(arg, constexpr):
+ return ("constexpr", arg)
+ elif isinstance(arg, tuple):
+ spec = [specialize_impl(x) for x in arg]
+ make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
+ tys = make_tuple([x[0] for x in spec])
+ keys = make_tuple([x[1] for x in spec])
+ return (tys, keys)
+ elif isinstance(arg, TensorDescriptor):
+ assert hasattr(arg.base, "data_ptr")
+ inner = canonicalize_dtype(arg.base.dtype)
+ return (f"tensordesc<{inner}{list(arg.block_shape)}>", None)
+ elif isinstance(arg, GluonTensorDescriptor):
+ assert hasattr(arg.base, "data_ptr")
+ inner = canonicalize_dtype(arg.base.dtype)
+ return (f"tensordesc<{inner}{list(arg.block_shape)},{arg.layout!r}>", None)
+ else:
+ raise TypeError("Unsupported type: %s" % type(arg))

- specialize_impl_cache.append(specialize_impl)
  return specialize_impl


  def mangle_type(arg, specialize=False):
- specialize_impl = create_specialize_impl()
- return specialize_impl(arg, lambda _, **kwargs: None, specialize_value=specialize)[0]
+ if len(specialize_impl_cache) == 0:
+ specialize_impl_cache.append(create_specialize_impl(lambda _, **kwargs: None))
+ specialize_impl = specialize_impl_cache[0]
+ return specialize_impl(arg, specialize_value=specialize)[0]


  class KernelInterface(Generic[T]):
@@ -378,9 +450,17 @@ def create_function_from_signature(sig, kparams, backend):
  is_const = 'True' if kp.is_const else 'False'
  specialize = 'False' if kp.do_not_specialize else 'True'
  align = 'False' if kp.do_not_specialize_on_alignment else 'True'
- ret = f"specialize_impl({name}, specialize_extra, {is_const}, {specialize}, {align})"
+ ret = f"specialize_impl({name}, {is_const}, {specialize}, {align})"
  if kp.annotation_type:
- specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
+ if isinstance(kp.annotation_type, str):
+ if kp.annotation_type == "u1" or kp.annotation_type[:2] in ["fp", "bf"]:
+ # we do not specialize non-constexpr floats and bools:
+ specialize = False
+ if specialize:
+ specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
+ else:
+ # skip runtime specialization:
+ specialization.append(f'("{kp.annotation_type}", None)')
  else:
  specialization.append(f"{ret}")

@@ -400,9 +480,8 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
  if param.default is not inspect.Parameter.empty
  }

- func_namespace["JITFunction"] = JITFunction
- func_namespace["specialize_impl"] = create_specialize_impl()
- func_namespace["specialize_extra"] = backend.get_arg_specialization
+ func_namespace["JITCallable"] = JITCallable
+ func_namespace["specialize_impl"] = create_specialize_impl(backend.get_arg_specialization)

  # Execute the function string in func_namespace to create the function
  exec(func_body, func_namespace)
@@ -411,44 +490,134 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
  return func_namespace['dynamic_func']


- type_canonicalisation_dict = {
- "bool": "i1",
- "float8e4nv": "fp8e4nv",
- "float8e5": "fp8e5",
- "float8e4b15": "fp8e4b15",
- "float8_e4m3fn": "fp8e4nv",
- "float8e4b8": "fp8e4b8",
- "float8_e4m3fnuz": "fp8e4b8",
- "float8_e5m2": "fp8e5",
- "float8e5b16": "fp8e5b16",
- "float8_e5m2fnuz": "fp8e5b16",
- "float16": "fp16",
- "bfloat16": "bf16",
- "float32": "fp32",
- "float64": "fp64",
- "int8": "i8",
- "int16": "i16",
- "int32": "i32",
- "int64": "i64",
- "uint8": "u8",
- "uint16": "u16",
- "uint32": "u32",
- "uint64": "u64",
- }
-
- for v in list(type_canonicalisation_dict.values()):
- type_canonicalisation_dict[v] = v
-
-
- class JITFunction(KernelInterface[T]):
- # Hook for inspecting compiled functions and modules
- cache_hook = None
- # Hook to signal that a kernel is done compiling and inspect compiled function.
- # cache_hook will always be called before compilation and compiled_hook after.
- compiled_hook = None
+ def get_full_name(fn):
+ return f"{fn.__module__}.{fn.__qualname__}"
+
+
+ class JITCallable:
+
+ def __init__(self, fn):
+ self.fn = fn
+ self.signature = inspect.signature(fn)
+ try:
+ self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
+ except OSError as e:
+ raise ValueError("@jit functions should be defined in a Python file") from e
+ self._fn_name = get_full_name(fn)
+ self._hash_lock = threading.RLock()
+
+ # function source code (without decorators)
+ src = textwrap.dedent("".join(self.raw_src))
+ src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
+ self._src = src
+ self.hash = None
+
+ # Map of global variables used by the function and any functions it
+ # transitively calls, plus their values. The values are collected when
+ # the function is first compiled. Then every time we run the function,
+ # we check that the values of the globals match what's expected,
+ # otherwise we raise an error.
+ #
+ # Different functions can have different __globals__ maps, so the map
+ # key is actually (var name, id(__globals__)), and the map value is
+ # (value, __globals__).
+ self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
+
+ # reuse docs of wrapped function
+ self.__doc__ = fn.__doc__
+ self.__name__ = fn.__name__
+ self.__qualname__ = fn.__qualname__
+ self.__globals__ = fn.__globals__
+ self.__module__ = fn.__module__
+
+ def get_capture_scope(self):
+ return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals
+
+ @property
+ def cache_key(self):
+ # TODO : hash should be attribute of `self`
+ with self._hash_lock:
+ if self.hash is not None:
+ return self.hash
+ # Set a placeholder hash to break recursion in case the function
+ # transitively calls itself. The full hash is set after.
+ self.hash = f"recursion:{self._fn_name}"
+ nonlocals = inspect.getclosurevars(self.fn).nonlocals
+ dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals,
+ src=self.src)
+ dependencies_finder.visit(self.parse())
+ self.hash = dependencies_finder.ret + str(self.starting_line_number)
+ self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
+
+ from triton.language.core import constexpr
+ self.hash += str([(name, val)
+ for (name, _), (val, _) in self.used_global_vals.items()
+ if isinstance(val, constexpr)])
+ self.hash = hashlib.sha256(self.hash.encode("utf-8")).hexdigest()
+ return self.hash
+
+ # we do not parse `src` in the constructor because
+ # the user might want to monkey-patch self.src dynamically.
+ # Our unit tests do this, for example.
+ def parse(self):
+ tree = ast.parse(self._src)
+ assert isinstance(tree, ast.Module)
+ assert len(tree.body) == 1
+ assert isinstance(tree.body[0], ast.FunctionDef)
+ return tree
+
+ @property
+ def type(self):
+ from triton.language.core import constexpr_type
+ return constexpr_type(self)
+
+ def _unsafe_update_src(self, new_src):
+ """
+ The only method allowed to modify src.
+ Bypasses the __setattr__ restriction by calling super().__setattr__ directly.
+
+ Note that it is the callers responsibility to make sure any triton functions that call this function have the `.hash` value reset to None.
+ """
+ self.hash = None
+ self._src = new_src
+
+ def _set_src(self):
+ raise AttributeError("Cannot set attribute 'src' directly. "
+ "Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
+ "instead.")
+
+ def _get_src(self):
+ return self._src
+
+ src = property(fget=_get_src, fset=_set_src)
+
+
+ @dataclass
+ class JitFunctionInfo:
+ module: ModuleType
+ name: str
+ jit_function: JITFunction
+
+
+ def compute_cache_key(kernel_key_cache, specialization, options):
+ key = (tuple(specialization), str(options))
+ cache_key = kernel_key_cache.get(key, None)
+ if cache_key is not None:
+ return cache_key
+
+ cache_key = str(specialization) + str(options)
+ kernel_key_cache[key] = cache_key
+ return cache_key
+
+
+ class JITFunction(JITCallable, KernelInterface[T]):
+
+ def is_gluon(self):
+ return False

  def _call_hook(
  self,
+ hook,
  key,
  signature,
  device,
@@ -456,26 +625,17 @@ class JITFunction(KernelInterface[T]):
  options,
  configs,
  is_warmup,
- before,
- ):
- hook = JITFunction.cache_hook if before else JITFunction.compiled_hook
- if hook is None:
- return False
+ ) -> bool | None:
+ if not hook:
+ return None

- name = self.fn.__name__
+ name = self.fn.__qualname__
  module = self.fn.__module__
  arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
  repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
+ full_name = get_full_name(self.fn)

- class JitFunctionInfo:
-
- def __init__(self, module, name, jit_function):
- self.module = module
- self.name = name
- self.jit_function = jit_function
- pass
-
- specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key)
+ specialization_data = serialize_specialization_data(full_name, signature, constants, configs[0], options, key)

  kwargs = {
  'signature': signature,
@@ -520,10 +680,34 @@ class JITFunction(KernelInterface[T]):
  self.compile = compile
  self.ASTSource = ASTSource
  binder = create_function_from_signature(self.signature, self.params, backend)
- return {}, target, backend, binder
+ return {}, {}, target, backend, binder
+
+ def _pack_args(self, backend, kwargs, bound_args, specialization, options):
+ # options
+ options = backend.parse_options(kwargs)
+ # signature
+ sigkeys = [x.name for x in self.params]
+ sigvals = [x[0] for x in specialization]
+ signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
+ # check arguments
+ assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
+ assert "device" not in kwargs, "device option is deprecated; current device will be used"
+ assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
+ for k in kwargs:
+ if k not in options.__dict__ and k not in sigkeys:
+ raise KeyError("Keyword argument %s was specified but unrecognised" % k)
+ # constexprs
+ constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
+ constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
+ # attributes
+ attrvals = [x[1] for x in specialization]
+ attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
+ attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
+
+ return options, signature, constexprs, attrs

  def run(self, *args, grid, warmup, **kwargs):
- kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"
+ kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug

  # parse options
  device = driver.active.get_current_device()
@@ -533,42 +717,22 @@ class JITFunction(KernelInterface[T]):
  for hook in self.pre_run_hooks:
  hook(*args, **kwargs)

- kernel_cache, target, backend, binder = self.device_caches[device]
+ kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
+ # specialization is list[tuple[str, Any]], where first element of tuple is
+ # the type and the second parameter is the 'specialization' value.
  bound_args, specialization, options = binder(*args, **kwargs)

- # compute cache key
- key = str(specialization) + str(options)
+ key = compute_cache_key(kernel_key_cache, specialization, options)
  kernel = kernel_cache.get(key, None)

  # Kernel is not cached; we have to compile.
  if kernel is None:
- # options
- options = backend.parse_options(kwargs)
- # signature
- sigkeys = [x.name for x in self.params]
- sigvals = [x[0] for x in specialization]
- signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
- # check arguments
- assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
- assert "device" not in kwargs, "device option is deprecated; current device will be used"
- assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
- for k in kwargs:
- if k not in options.__dict__ and k not in sigkeys:
- raise KeyError("Keyword argument %s was specified but unrecognised" % k)
- # constexprs
- constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
- constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
- # attributes
- attrvals = [x[1] for x in specialization]
- attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
- attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
- if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True):
+ options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization,
+ options)
+
+ kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
+ if kernel is None:
  return None
- # compile the kernel
- src = self.ASTSource(self, signature, constexprs, attrs)
- kernel = self.compile(src, target=target, options=options.__dict__)
- kernel_cache[key] = kernel
- self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False)

  # Check that used global values have not changed.
  not_present = object()
@@ -586,11 +750,12 @@ class JITFunction(KernelInterface[T]):
  grid_0 = grid[0]
  grid_1 = grid[1] if grid_size > 1 else 1
  grid_2 = grid[2] if grid_size > 2 else 1
+ if hasattr(kernel, "result"):
+ kernel = kernel.result()
  # launch kernel
  launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
- kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
- launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook,
- *bound_args.values())
+ kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
+ knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values())
  return kernel

  def repr(self, _):
@@ -601,15 +766,12 @@ class JITFunction(KernelInterface[T]):
  do_not_specialize = do_not_specialize if do_not_specialize else []
  do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else []

- self.fn = fn
+ super().__init__(fn)
  self.module = fn.__module__
  self.version = version
- self.signature = inspect.signature(fn)
  self.do_not_specialize = do_not_specialize
  self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
- self.starting_line_number = inspect.getsourcelines(fn)[1]
  self._repr = repr
- self._fn_name = fn.__name__
  self.launch_metadata = launch_metadata

  self.params = []
@@ -618,24 +780,8 @@ class JITFunction(KernelInterface[T]):
  dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment
  self.params.append(KernelParam(i, param, dns, dns_oa))

- # function source code (without decorators)
- src = textwrap.dedent(inspect.getsource(fn))
- src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
- self._unsafe_update_src(src)
  # cache of just-in-time compiled kernels
  self.device_caches = defaultdict(self.create_binder)
- self.hash = None
-
- # Map of global variables used by the function and any functions it
- # transitively calls, plus their values. The values are collected when
- # the function is first compiled. Then every time we run the function,
- # we check that the values of the globals match what's expected,
- # otherwise we raise an error.
- #
- # Different functions can have different __globals__ maps, so the map
- # key is actually (var name, id(__globals__)), and the map value is
- # (value, __globals__).
- self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}

  # JITFunction can be instantiated as kernel
  # when called with a grid using __getitem__
@@ -651,37 +797,20 @@ class JITFunction(KernelInterface[T]):
  # Hooks that will be called prior to executing "run"
  self.pre_run_hooks = []

- # reuse docs of wrapped function
- self.__doc__ = fn.__doc__
- self.__name__ = fn.__name__
- self.__globals__ = fn.__globals__
- self.__module__ = fn.__module__
-
- @property
- def cache_key(self):
- # TODO : hash should be attribute of `self`
- if self.hash is None:
- dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src)
- dependencies_finder.visit(self.parse())
- self.hash = dependencies_finder.ret + str(self.starting_line_number)
- self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
- return self.hash
-
  def warmup(self, *args, grid, **kwargs):
  return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)

  def preload(self, specialization_data):
- from ..compiler import compile, ASTSource
  import json
  import triton.language as tl
  device = driver.active.get_current_device()
  deserialized_obj = json.loads(specialization_data)
- if deserialized_obj['name'] != self.fn.__name__:
+ if deserialized_obj['name'] != self._fn_name:
  raise RuntimeError(
- f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
+ f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}")
  constant_keys = map(tuple, deserialized_obj['constant_keys'])
  constant_vals = deserialized_obj['constant_vals']
- constants = {
+ constexprs = {
  key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
  for key, value in zip(constant_keys, constant_vals)
  }
@@ -689,47 +818,57 @@ class JITFunction(KernelInterface[T]):
  attrs_vals = deserialized_obj['attrs_vals']
  attrs = dict(zip(attrs_keys, attrs_vals))
  signature = dict(deserialized_obj['signature'].items())
- src = ASTSource(self, signature, constants, attrs)
  options = {
  key: tuple(value) if isinstance(value, list) else value
  for key, value in deserialized_obj['options'].items()
  }
  key = deserialized_obj['key']
- kernel = compile(src, None, options)
- self.device_caches[device][0][key] = kernel
- return kernel
+ _, _, _, backend, _ = self.device_caches[device]
+ options = backend.parse_options(options)
+ return self._do_compile(
+ key,
+ signature,
+ device,
+ constexprs,
+ options,
+ attrs,
+ warmup=True,
+ )

- # we do not parse `src` in the constructor because
- # the user might want to monkey-patch self.src dynamically.
- # Our unit tests do this, for example.
- def parse(self):
- tree = ast.parse(self.src)
- assert isinstance(tree, ast.Module)
- assert len(tree.body) == 1
- assert isinstance(tree.body[0], ast.FunctionDef)
- return tree
+ def _do_compile(self, key, signature, device, constexprs, options, attrs, warmup):
+ kernel_cache, _, target, backend, _ = self.device_caches[device]

- def __call__(self, *args, **kwargs):
- raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
+ if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs], warmup):
+ return None
+ src = self.ASTSource(self, signature, constexprs, attrs)

- def __setattr__(self, name, value):
- # - when `.src` attribute is set, cache key of all callers need to be re-computed
- if name == "src":
- raise AttributeError(f"Cannot set attribute '{name}' directly. "
- f"Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
- f"instead.")
- super(JITFunction, self).__setattr__(name, value)
+ async_mode = _async_compile.active_mode.get()
+ if async_mode is not None:

- def _unsafe_update_src(self, new_src):
- """
- The only method allowed to modify src.
- Bypasses the __setattr__ restriction by calling super().__setattr__ directly.
- """
- self.hash = None
- super().__setattr__('src', new_src)
+ env_vars = get_cache_invalidating_env_vars()
+ cache_key = get_cache_key(src, backend, options, env_vars)
+
+ def async_compile():
+ return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars)
+
+ def finalize_compile(kernel):
+ kernel_cache[key] = kernel
+ self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options,
+ [attrs], warmup)
+
+ kernel = async_mode.submit(cache_key, async_compile, finalize_compile)
+ else:
+ kernel = self.compile(src, target=target, options=options.__dict__)
+ kernel_cache[key] = kernel
+ self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs],
+ warmup)
+ return kernel
+
+ def __call__(self, *args, **kwargs):
+ raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

  def __repr__(self):
- return f"JITFunction({self.module}:{self.fn.__name__})"
+ return f"JITFunction({self.module}:{self.fn.__qualname__})"


  # -----------------------------------------------------------------------------
@@ -748,8 +887,8 @@ def jit(
  version=None,
  repr: Optional[Callable] = None,
  launch_metadata: Optional[Callable] = None,
- do_not_specialize: Optional[Iterable[int]] = None,
- do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
+ do_not_specialize: Optional[Iterable[int | str]] = None,
+ do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
  debug: Optional[bool] = None,
  noinline: Optional[bool] = None,
  ) -> Callable[[T], JITFunction[T]]:
@@ -762,8 +901,8 @@ def jit(
  version=None,
  repr: Optional[Callable] = None,
  launch_metadata: Optional[Callable] = None,
- do_not_specialize: Optional[Iterable[int]] = None,
- do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
+ do_not_specialize: Optional[Iterable[int | str]] = None,
+ do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
  debug: Optional[bool] = None,
  noinline: Optional[bool] = None,
  ) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
@@ -787,7 +926,7 @@ def jit(

  def decorator(fn: T) -> JITFunction[T]:
  assert callable(fn)
- if os.getenv("TRITON_INTERPRET", "0") == "1":
+ if knobs.runtime.interpret:
  from .interpreter import InterpretedFunction
  return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize,
  do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug,
@@ -828,8 +967,17 @@ class MockTensor:
  return MockTensor(arg)
  return arg

- def __init__(self, dtype):
+ def __init__(self, dtype, shape=None):
+ if shape is None:
+ shape = [1]
  self.dtype = dtype
+ self.shape = shape
+
+ def stride(self):
+ strides = [1]
+ for size in self.shape[1:]:
+ strides.append(strides[-1] * size)
+ return tuple(reversed(strides))

  @staticmethod
  def data_ptr():
@@ -894,17 +1042,66 @@ def reinterpret(tensor, dtype):

  def get_jit_fn_file_line(fn):
  base_fn = fn
- while not isinstance(base_fn, JITFunction):
+ while not isinstance(base_fn, JITCallable):
  base_fn = base_fn.fn
  file_name = base_fn.fn.__code__.co_filename
- lines, begin_line = inspect.getsourcelines(base_fn.fn)
+ begin_line = base_fn.starting_line_number
  # Match the following pattern:
  # @triton.autotune(...) <- foo.__code__.co_firstlineno
  # @triton.heuristics(...)
  # @triton.jit
  # def foo(...): <- this line is the first line
- for idx, line in enumerate(lines):
+ for idx, line in enumerate(base_fn.raw_src):
  if line.strip().startswith("def "):
  begin_line += idx
  break
  return file_name, begin_line
+
+
+ class BoundConstexprFunction(JITCallable):
+
+ def __init__(self, instance, fn):
+ self.__self__ = instance
+ self.__func__ = fn
+
+ def __call__(self, *args, **kwargs):
+ return self.__func__(self.__self__, *args, **kwargs)
+
+
+ class ConstexprFunction(JITCallable):
+
+ def __init__(self, fn):
+ super().__init__(fn)
+
+ def __get__(self, obj, objclass):
+ # Create a bound function to support constexpr_function methods
+ if obj is not None:
+ return BoundConstexprFunction(obj, self)
+ return self
+
+ def __call__(self, *args, _semantic=None, **kwargs):
+ from triton.language.core import _unwrap_if_constexpr, constexpr
+ # de-constexpr arguments and discard the _semantic keyword argument:
+ args = [_unwrap_if_constexpr(x) for x in args]
+ kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()}
+
+ # call the raw Python function f:
+ res = self.fn(*args, **kwargs)
+
+ if _semantic is None:
+ # Not called by triton code generator, e.g. in host code, another constexpr function, or even an aggreate's __init__ function
+ return res
+
+ # convert result back to a Triton constexpr:
+ if knobs.runtime.interpret:
+ return res # No constexpr in interpreter
+ return constexpr(res)
+
+
+ def constexpr_function(fn):
+ """
+ Wraps an arbitrary Python function so that it can be called at
+ compile-time on constexpr arguments in a Triton function and
+ returns a constexpr result.
+ """
+ return ConstexprFunction(fn)
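
For context, a minimal usage sketch of the new constexpr_function decorator introduced above. This is not taken from the package: it assumes the decorator is re-exported at the top level as triton.constexpr_function, and the helper and kernel below are illustrative only.

import triton
import triton.language as tl


# Hypothetical compile-time helper: runs in plain Python while the kernel is
# being compiled, and its result comes back into the kernel as a constexpr.
@triton.constexpr_function
def next_power_of_2(n):
    return 1 << (int(n) - 1).bit_length()


@triton.jit
def double_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Evaluated once at compile time; no bit twiddling happens on the GPU.
    BLOCK_P2: tl.constexpr = next_power_of_2(BLOCK)
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK_P2)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2, mask=mask)

Because ConstexprFunction unwraps constexpr arguments before calling the raw Python function, the same helper can also be called from ordinary host code, where it simply returns a plain Python value.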