triton-windows 3.3.0.post19__cp310-cp310-win_amd64.whl → 3.4.0.post20__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (173)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/runtime/tcc/lib/python310.def +1610 -0
  56. triton/runtime/tcc/lib/python311.def +1633 -0
  57. triton/runtime/tcc/lib/python312.def +1703 -0
  58. triton/runtime/tcc/lib/python313.def +1651 -0
  59. triton/runtime/tcc/lib/python313t.def +1656 -0
  60. triton/runtime/tcc/lib/python39.def +1644 -0
  61. triton/runtime/tcc/lib/python3t.def +905 -0
  62. triton/testing.py +16 -12
  63. triton/tools/disasm.py +3 -4
  64. triton/tools/tensor_descriptor.py +36 -0
  65. triton/windows_utils.py +14 -6
  66. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  67. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  68. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
  69. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  70. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  71. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  72. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  73. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  80. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  81. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  82. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  83. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  84. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  85. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  86. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  87. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  88. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  89. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  90. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  91. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  92. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  93. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  94. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  95. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  96. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  97. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  98. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  99. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  100. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  101. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  102. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  103. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  104. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  105. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  106. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  107. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  108. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  109. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  110. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  111. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  112. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  113. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  114. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  115. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  116. triton/backends/amd/include/hip/device_functions.h +0 -38
  117. triton/backends/amd/include/hip/driver_types.h +0 -468
  118. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  119. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  120. triton/backends/amd/include/hip/hip_common.h +0 -100
  121. triton/backends/amd/include/hip/hip_complex.h +0 -38
  122. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  123. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  124. triton/backends/amd/include/hip/hip_ext.h +0 -161
  125. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  126. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  127. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  128. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  129. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  130. triton/backends/amd/include/hip/hip_profile.h +0 -27
  131. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  132. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  133. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  134. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  135. triton/backends/amd/include/hip/hip_version.h +0 -17
  136. triton/backends/amd/include/hip/hiprtc.h +0 -421
  137. triton/backends/amd/include/hip/library_types.h +0 -78
  138. triton/backends/amd/include/hip/math_functions.h +0 -42
  139. triton/backends/amd/include/hip/surface_types.h +0 -63
  140. triton/backends/amd/include/hip/texture_types.h +0 -194
  141. triton/backends/amd/include/hsa/Brig.h +0 -1131
  142. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  143. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  144. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  145. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  146. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  147. triton/backends/amd/include/hsa/hsa.h +0 -5738
  148. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  149. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  150. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  151. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  152. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  153. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  154. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  155. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  156. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  157. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  158. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  159. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  160. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  161. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  162. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  163. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  164. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  165. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  166. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  167. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  168. triton/backends/amd/include/roctracer/roctx.h +0 -229
  169. triton/language/_utils.py +0 -21
  170. triton/language/extra/cuda/_experimental_tma.py +0 -106
  171. triton/tools/experimental_descriptor.py +0 -32
  172. triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
  173. triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
triton/runtime/jit.py CHANGED
@@ -1,17 +1,21 @@
1
1
  from __future__ import annotations, division
2
2
  import ast
3
+ import copy
3
4
  import hashlib
4
5
  import inspect
5
6
  import itertools
6
- import os
7
7
  import re
8
8
  import textwrap
9
9
  from collections import defaultdict
10
+ from dataclasses import dataclass
10
11
  from functools import cached_property
11
12
  from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
12
- from ..runtime.driver import driver
13
+
14
+ from triton.tools.tensor_descriptor import TensorDescriptor
13
15
  from types import ModuleType
14
- from .._utils import find_paths_if, get_iterable_path
16
+ from .. import knobs
17
+ from ..runtime.driver import driver
18
+ from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, canonicalize_dtype
15
19
 
16
20
  TRITON_MODULE = __name__[:-len(".runtime.jit")]
17
21
 
@@ -34,13 +38,14 @@ class DependenciesFinder(ast.NodeVisitor):
34
38
  otherwise we could recompile).
35
39
  """
36
40
 
37
- def __init__(self, name, globals, src) -> None:
41
+ def __init__(self, name, globals, nonlocals, src) -> None:
38
42
  super().__init__()
39
43
  self.name = name
40
44
  self.hasher = hashlib.sha256(src.encode("utf-8"))
41
45
 
42
46
  # This function's __globals__ dict.
43
47
  self.globals = globals
48
+ self.nonlocals = nonlocals
44
49
 
45
50
  # Python builtins that can be accessed from Triton kernels.
46
51
  self.supported_python_builtins = {
@@ -106,7 +111,16 @@ class DependenciesFinder(ast.NodeVisitor):
106
111
  # The global name is hidden by the local name.
107
112
  return None
108
113
 
109
- val = self.globals.get(node.id, None)
114
+ def name_lookup(name):
115
+ val = self.globals.get(name, None)
116
+ if val is not None:
117
+ return val, self.globals
118
+ val = self.nonlocals.get(name, None)
119
+ if val is not None:
120
+ return val, self.nonlocals
121
+ return None, None
122
+
123
+ val, var_dict = name_lookup(node.id)
110
124
 
111
125
  # Only keep track of "interesting" global variables, that non-evil users
112
126
  # might change. Don't consider functions, modules, builtins, etc. This
@@ -123,7 +137,7 @@ class DependenciesFinder(ast.NodeVisitor):
123
137
  # `bar` and then someone did `foo = baz`.
124
138
  and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) #
125
139
  and node.id not in self.supported_python_builtins):
126
- self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals)
140
+ self.used_global_vals[(node.id, id(var_dict))] = (copy.copy(val), var_dict)
127
141
 
128
142
  self._update_hash(val)
129
143
  return val
@@ -221,11 +235,29 @@ class DependenciesFinder(ast.NodeVisitor):
221
235
 
222
236
 
223
237
  def _normalize_ty(ty) -> str:
224
- if isinstance(ty, type):
225
- return ty.__name__
226
- elif isinstance(ty, str):
227
- return ty
228
- return repr(ty)
238
+ import triton.language.core as core
239
+ if isinstance(ty, str):
240
+ ty = ty.strip()
241
+ if ty.startswith("const "):
242
+ ty = ty.removeprefix("const")
243
+ ty = _normalize_ty(ty)
244
+ assert ty.startswith("*")
245
+ return "*k" + ty[1:]
246
+ if ty.endswith("*"):
247
+ return "*" + _normalize_ty(ty[:-1])
248
+ if ty.startswith("*"):
249
+ return "*" + _normalize_ty(ty[1:])
250
+ if ty.startswith("tl."):
251
+ return _normalize_ty(ty.removeprefix("tl."))
252
+ elif isinstance(ty, core.pointer_type):
253
+ return f"*{_normalize_ty(ty.element_ty)}"
254
+ elif isinstance(ty, core.dtype):
255
+ ty = ty.name
256
+ elif isinstance(ty, type):
257
+ ty = ty.__name__
258
+ else:
259
+ ty = str(ty)
260
+ return type_canonicalisation_dict.get(ty.replace("_t", ""), ty)
229
261
 
230
262
 
231
263
  class KernelParam:
@@ -243,20 +275,20 @@ class KernelParam:
243
275
  return self._param.name
244
276
 
245
277
  @cached_property
246
- def annotation(self):
278
+ def annotation(self) -> str:
247
279
  if not self._param.annotation or self._param.annotation == inspect.Parameter.empty:
248
280
  return ""
249
281
  return _normalize_ty(self._param.annotation)
250
282
 
251
283
  @cached_property
252
- def annotation_type(self):
253
- annotation = self.annotation
254
- for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
255
- width = annotation[annotation.find(ty1) + len(ty1):]
256
- if width and ty1 in annotation:
257
- return f"{ty2}{width}"
258
- if annotation == "bool":
259
- return "u1"
284
+ def annotation_type(self) -> str:
285
+ a = self.annotation
286
+ if a.startswith("*k"):
287
+ a = a[2:]
288
+ elif a.startswith("*"):
289
+ a = a[1:]
290
+ if a in set(type_canonicalisation_dict.values()):
291
+ return self.annotation
260
292
  return ""
261
293
 
262
294
  @cached_property
@@ -265,7 +297,9 @@ class KernelParam:
265
297
 
266
298
  @cached_property
267
299
  def is_const(self):
268
- return "const" in self.annotation and not self.is_constexpr
300
+ if self.is_constexpr:
301
+ return False
302
+ return "const" in self.annotation or self.annotation.startswith("*k")
269
303
 
270
304
  @property
271
305
  def default(self):
@@ -280,22 +314,16 @@ dtype2str = {}
280
314
  specialize_impl_cache = []
281
315
 
282
316
 
283
- def create_specialize_impl():
284
- if specialize_impl_cache:
285
- return specialize_impl_cache[-1]
317
+ def create_specialize_impl(specialize_extra):
286
318
 
287
319
  from ..language import constexpr
320
+ from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor
288
321
 
289
- def specialize_impl(arg, specialize_extra, is_const=False, specialize_value=True, align=True):
290
-
322
+ def specialize_impl(arg, is_const=False, specialize_value=True, align=True):
291
323
  if arg is None:
292
324
  return ("constexpr", None)
293
- elif isinstance(arg, JITFunction):
294
- return ("constexpr", arg.cache_key)
295
- elif isinstance(arg, constexpr):
296
- return ("constexpr", arg)
297
325
  elif isinstance(arg, bool):
298
- return ("i1", None)
326
+ return ("u1", None)
299
327
  elif isinstance(arg, int):
300
328
  key = specialize_extra(arg, "int", align=align) if specialize_value else None
301
329
  if arg == 1 and specialize_value:
@@ -308,31 +336,46 @@ def create_specialize_impl():
308
336
  return ("i64", key)
309
337
  elif isinstance(arg, float):
310
338
  return ("fp32", None)
339
+ elif hasattr(arg, "data_ptr"):
340
+ # dtypes are hashable so we can memoize this mapping:
341
+ dsk = (arg.dtype, is_const)
342
+ res = dtype2str.get(dsk, None)
343
+ if res is None:
344
+ res = ("*k" if dsk[1] else "*") + canonicalize_dtype(dsk[0])
345
+ dtype2str[dsk] = res
346
+ key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
347
+ return (res, key)
348
+ elif isinstance(arg, JITFunction):
349
+ return ("constexpr", arg.cache_key)
350
+ elif isinstance(arg, constexpr):
351
+ return ("constexpr", arg)
311
352
  elif hasattr(arg, "tma_desc_cpu_ptr"):
312
353
  return ("nvTmaDesc", None)
313
354
  elif isinstance(arg, tuple):
314
- spec = [specialize_impl(x, specialize_extra) for x in arg]
355
+ spec = [specialize_impl(x) for x in arg]
315
356
  make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
316
357
  tys = make_tuple([x[0] for x in spec])
317
358
  keys = make_tuple([x[1] for x in spec])
318
359
  return (tys, keys)
360
+ elif isinstance(arg, TensorDescriptor):
361
+ assert hasattr(arg.base, "data_ptr")
362
+ inner = canonicalize_dtype(arg.base.dtype)
363
+ return (f"tensordesc<{inner}{list(arg.block_shape)}>", None)
364
+ elif isinstance(arg, GluonTensorDescriptor):
365
+ assert hasattr(arg.base, "data_ptr")
366
+ inner = canonicalize_dtype(arg.base.dtype)
367
+ return (f"tensordesc<{inner}{list(arg.block_shape)},{arg.layout!r}>", None)
319
368
  else:
320
- # dtypes are hashable so we can memoize this mapping:
321
- dsk = (arg.dtype, is_const)
322
- res = dtype2str.get(dsk, None)
323
- if res is None:
324
- res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]]
325
- dtype2str[dsk] = res
326
- key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
327
- return (res, key)
369
+ raise TypeError("Unsupported type: %s" % type(arg))
328
370
 
329
- specialize_impl_cache.append(specialize_impl)
330
371
  return specialize_impl
331
372
 
332
373
 
333
374
  def mangle_type(arg, specialize=False):
334
- specialize_impl = create_specialize_impl()
335
- return specialize_impl(arg, lambda _, **kwargs: None, specialize_value=specialize)[0]
375
+ if len(specialize_impl_cache) == 0:
376
+ specialize_impl_cache.append(create_specialize_impl(lambda _, **kwargs: None))
377
+ specialize_impl = specialize_impl_cache[0]
378
+ return specialize_impl(arg, specialize_value=specialize)[0]
336
379
 
337
380
 
338
381
  class KernelInterface(Generic[T]):
@@ -378,9 +421,17 @@ def create_function_from_signature(sig, kparams, backend):
378
421
  is_const = 'True' if kp.is_const else 'False'
379
422
  specialize = 'False' if kp.do_not_specialize else 'True'
380
423
  align = 'False' if kp.do_not_specialize_on_alignment else 'True'
381
- ret = f"specialize_impl({name}, specialize_extra, {is_const}, {specialize}, {align})"
424
+ ret = f"specialize_impl({name}, {is_const}, {specialize}, {align})"
382
425
  if kp.annotation_type:
383
- specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
426
+ if isinstance(kp.annotation_type, str):
427
+ if kp.annotation_type == "u1" or kp.annotation_type[:2] in ["fp", "bf"]:
428
+ # we do not specialize non-constexpr floats and bools:
429
+ specialize = False
430
+ if specialize:
431
+ specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
432
+ else:
433
+ # skip runtime specialization:
434
+ specialization.append(f'("{kp.annotation_type}", None)')
384
435
  else:
385
436
  specialization.append(f"{ret}")
386
437
 
@@ -401,8 +452,7 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
401
452
  }
402
453
 
403
454
  func_namespace["JITFunction"] = JITFunction
404
- func_namespace["specialize_impl"] = create_specialize_impl()
405
- func_namespace["specialize_extra"] = backend.get_arg_specialization
455
+ func_namespace["specialize_impl"] = create_specialize_impl(backend.get_arg_specialization)
406
456
 
407
457
  # Execute the function string in func_namespace to create the function
408
458
  exec(func_body, func_namespace)
@@ -411,44 +461,25 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
411
461
  return func_namespace['dynamic_func']
412
462
 
413
463
 
414
- type_canonicalisation_dict = {
415
- "bool": "i1",
416
- "float8e4nv": "fp8e4nv",
417
- "float8e5": "fp8e5",
418
- "float8e4b15": "fp8e4b15",
419
- "float8_e4m3fn": "fp8e4nv",
420
- "float8e4b8": "fp8e4b8",
421
- "float8_e4m3fnuz": "fp8e4b8",
422
- "float8_e5m2": "fp8e5",
423
- "float8e5b16": "fp8e5b16",
424
- "float8_e5m2fnuz": "fp8e5b16",
425
- "float16": "fp16",
426
- "bfloat16": "bf16",
427
- "float32": "fp32",
428
- "float64": "fp64",
429
- "int8": "i8",
430
- "int16": "i16",
431
- "int32": "i32",
432
- "int64": "i64",
433
- "uint8": "u8",
434
- "uint16": "u16",
435
- "uint32": "u32",
436
- "uint64": "u64",
437
- }
438
-
439
- for v in list(type_canonicalisation_dict.values()):
440
- type_canonicalisation_dict[v] = v
464
+ def get_full_name(fn):
465
+ return f"{fn.__module__}.{fn.__qualname__}"
466
+
467
+
468
+ @dataclass
469
+ class JitFunctionInfo:
470
+ module: ModuleType
471
+ name: str
472
+ jit_function: JITFunction
441
473
 
442
474
 
443
475
  class JITFunction(KernelInterface[T]):
444
- # Hook for inspecting compiled functions and modules
445
- cache_hook = None
446
- # Hook to signal that a kernel is done compiling and inspect compiled function.
447
- # cache_hook will always be called before compilation and compiled_hook after.
448
- compiled_hook = None
476
+
477
+ def is_gluon(self):
478
+ return False
449
479
 
450
480
  def _call_hook(
451
481
  self,
482
+ hook,
452
483
  key,
453
484
  signature,
454
485
  device,
@@ -456,26 +487,17 @@ class JITFunction(KernelInterface[T]):
456
487
  options,
457
488
  configs,
458
489
  is_warmup,
459
- before,
460
- ):
461
- hook = JITFunction.cache_hook if before else JITFunction.compiled_hook
462
- if hook is None:
463
- return False
490
+ ) -> bool | None:
491
+ if not hook:
492
+ return None
464
493
 
465
- name = self.fn.__name__
494
+ name = self.fn.__qualname__
466
495
  module = self.fn.__module__
467
496
  arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
468
497
  repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
498
+ full_name = get_full_name(self.fn)
469
499
 
470
- class JitFunctionInfo:
471
-
472
- def __init__(self, module, name, jit_function):
473
- self.module = module
474
- self.name = name
475
- self.jit_function = jit_function
476
- pass
477
-
478
- specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key)
500
+ specialization_data = serialize_specialization_data(full_name, signature, constants, configs[0], options, key)
479
501
 
480
502
  kwargs = {
481
503
  'signature': signature,
@@ -523,7 +545,7 @@ class JITFunction(KernelInterface[T]):
523
545
  return {}, target, backend, binder
524
546
 
525
547
  def run(self, *args, grid, warmup, **kwargs):
526
- kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"
548
+ kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug
527
549
 
528
550
  # parse options
529
551
  device = driver.active.get_current_device()
@@ -534,6 +556,8 @@ class JITFunction(KernelInterface[T]):
534
556
  hook(*args, **kwargs)
535
557
 
536
558
  kernel_cache, target, backend, binder = self.device_caches[device]
559
+ # specialization is list[tuple[str, Any]], where first element of tuple is
560
+ # the type and the second parameter is the 'specialization' value.
537
561
  bound_args, specialization, options = binder(*args, **kwargs)
538
562
 
539
563
  # compute cache key
@@ -562,13 +586,15 @@ class JITFunction(KernelInterface[T]):
562
586
  attrvals = [x[1] for x in specialization]
563
587
  attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
564
588
  attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
565
- if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True):
589
+ if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs],
590
+ warmup):
566
591
  return None
567
592
  # compile the kernel
568
593
  src = self.ASTSource(self, signature, constexprs, attrs)
569
594
  kernel = self.compile(src, target=target, options=options.__dict__)
570
595
  kernel_cache[key] = kernel
571
- self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False)
596
+ self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs],
597
+ warmup)
572
598
 
573
599
  # Check that used global values have not changed.
574
600
  not_present = object()
@@ -588,9 +614,8 @@ class JITFunction(KernelInterface[T]):
588
614
  grid_2 = grid[2] if grid_size > 2 else 1
589
615
  # launch kernel
590
616
  launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
591
- kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
592
- launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook,
593
- *bound_args.values())
617
+ kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
618
+ knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values())
594
619
  return kernel
595
620
 
596
621
  def repr(self, _):
@@ -609,7 +634,7 @@ class JITFunction(KernelInterface[T]):
609
634
  self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
610
635
  self.starting_line_number = inspect.getsourcelines(fn)[1]
611
636
  self._repr = repr
612
- self._fn_name = fn.__name__
637
+ self._fn_name = get_full_name(fn)
613
638
  self.launch_metadata = launch_metadata
614
639
 
615
640
  self.params = []
@@ -654,19 +679,30 @@ class JITFunction(KernelInterface[T]):
654
679
  # reuse docs of wrapped function
655
680
  self.__doc__ = fn.__doc__
656
681
  self.__name__ = fn.__name__
682
+ self.__qualname__ = fn.__qualname__
657
683
  self.__globals__ = fn.__globals__
658
684
  self.__module__ = fn.__module__
659
685
 
686
+ def get_capture_scope(self):
687
+ return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals
688
+
660
689
  @property
661
690
  def cache_key(self):
662
691
  # TODO : hash should be attribute of `self`
663
692
  if self.hash is None:
664
- dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src)
693
+ nonlocals = inspect.getclosurevars(self.fn).nonlocals
694
+ dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals,
695
+ src=self.src)
665
696
  dependencies_finder.visit(self.parse())
666
697
  self.hash = dependencies_finder.ret + str(self.starting_line_number)
667
698
  self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
668
699
  return self.hash
669
700
 
701
+ @property
702
+ def type(self):
703
+ from triton.language.core import constexpr
704
+ return constexpr
705
+
670
706
  def warmup(self, *args, grid, **kwargs):
671
707
  return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)
672
708
 
@@ -676,9 +712,9 @@ class JITFunction(KernelInterface[T]):
676
712
  import triton.language as tl
677
713
  device = driver.active.get_current_device()
678
714
  deserialized_obj = json.loads(specialization_data)
679
- if deserialized_obj['name'] != self.fn.__name__:
715
+ if deserialized_obj['name'] != self._fn_name:
680
716
  raise RuntimeError(
681
- f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
717
+ f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}")
682
718
  constant_keys = map(tuple, deserialized_obj['constant_keys'])
683
719
  constant_vals = deserialized_obj['constant_vals']
684
720
  constants = {
@@ -729,7 +765,7 @@ class JITFunction(KernelInterface[T]):
729
765
  super().__setattr__('src', new_src)
730
766
 
731
767
  def __repr__(self):
732
- return f"JITFunction({self.module}:{self.fn.__name__})"
768
+ return f"JITFunction({self.module}:{self.fn.__qualname__})"
733
769
 
734
770
 
735
771
  # -----------------------------------------------------------------------------
@@ -748,8 +784,8 @@ def jit(
748
784
  version=None,
749
785
  repr: Optional[Callable] = None,
750
786
  launch_metadata: Optional[Callable] = None,
751
- do_not_specialize: Optional[Iterable[int]] = None,
752
- do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
787
+ do_not_specialize: Optional[Iterable[int | str]] = None,
788
+ do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
753
789
  debug: Optional[bool] = None,
754
790
  noinline: Optional[bool] = None,
755
791
  ) -> Callable[[T], JITFunction[T]]:
@@ -762,8 +798,8 @@ def jit(
762
798
  version=None,
763
799
  repr: Optional[Callable] = None,
764
800
  launch_metadata: Optional[Callable] = None,
765
- do_not_specialize: Optional[Iterable[int]] = None,
766
- do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
801
+ do_not_specialize: Optional[Iterable[int | str]] = None,
802
+ do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
767
803
  debug: Optional[bool] = None,
768
804
  noinline: Optional[bool] = None,
769
805
  ) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
@@ -787,7 +823,7 @@ def jit(
787
823
 
788
824
  def decorator(fn: T) -> JITFunction[T]:
789
825
  assert callable(fn)
790
- if os.getenv("TRITON_INTERPRET", "0") == "1":
826
+ if knobs.runtime.interpret:
791
827
  from .interpreter import InterpretedFunction
792
828
  return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize,
793
829
  do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug,