triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (225) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
@@ -3,19 +3,19 @@ import hashlib
3
3
  import json
4
4
  from .._C.libtriton import get_cache_invalidating_env_vars, ir
5
5
  from ..backends import backends
6
- from ..backends.compiler import GPUTarget
7
- from .. import __version__
6
+ from ..backends.compiler import Language
7
+ from ..backends.compiler import BaseBackend, GPUTarget
8
+ from .. import __version__, knobs
8
9
  from ..runtime.autotuner import OutOfResources
9
- from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
10
+ from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager, get_cache_key
10
11
  from ..runtime.driver import driver
11
12
  from ..tools.disasm import get_sass
12
- # TODO: this shouldn't be here
13
- from .code_generator import ast_to_ttir
14
13
  from pathlib import Path
15
14
  import re
16
15
  import functools
17
16
  import os
18
- import sysconfig
17
+ import time
18
+ import copy
19
19
 
20
20
  # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
21
21
  # and any following whitespace
@@ -53,6 +53,7 @@ class ASTSource:
53
53
 
54
54
  def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
55
55
  self.fn = fn
56
+ self.language = Language.TRITON
56
57
  self.ext = "ttir"
57
58
  self.name = fn.__name__
58
59
  self.signature = signature
@@ -63,12 +64,9 @@ class ASTSource:
63
64
  assert isinstance(k, tuple)
64
65
  self.constants[k] = v
65
66
  self.attrs = attrs or dict()
66
- if isinstance(self.signature, str):
67
- self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))}
68
- else:
69
- for k in self.signature.keys():
70
- if not isinstance(k, str):
71
- raise TypeError("Signature keys must be string")
67
+ for k in self.signature.keys():
68
+ if not isinstance(k, str):
69
+ raise TypeError("Signature keys must be string")
72
70
 
73
71
  def hash(self):
74
72
  sorted_sig = [v for k, v in sorted(self.signature.items())]
@@ -77,7 +75,8 @@ class ASTSource:
77
75
  key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
78
76
  return hashlib.sha256(key.encode("utf-8")).hexdigest()
79
77
 
80
- def make_ir(self, options, codegen_fns, module_map, context):
78
+ def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
79
+ from .code_generator import ast_to_ttir
81
80
  return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
82
81
  module_map=module_map)
83
82
 
@@ -91,6 +90,7 @@ class IRSource:
91
90
  self.path = path
92
91
  path = Path(path)
93
92
  self.ext = path.suffix[1:]
93
+ self.language = Language.TRITON
94
94
  self.src = path.read_text()
95
95
  ir.load_dialects(context)
96
96
  backend.load_dialects(context)
@@ -114,7 +114,7 @@ class IRSource:
114
114
  def hash(self):
115
115
  return hashlib.sha256(self.src.encode("utf-8")).hexdigest()
116
116
 
117
- def make_ir(self, options, codegen_fns, module_map, context):
117
+ def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
118
118
  self.module.context = context
119
119
  return self.module
120
120
 
@@ -127,39 +127,8 @@ class IRSource:
127
127
 
128
128
 
129
129
  @functools.lru_cache()
130
- def triton_key():
131
- import pkgutil
132
- TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
133
- contents = []
134
- # frontend
135
- with open(__file__, "rb") as f:
136
- contents += [hashlib.sha256(f.read()).hexdigest()]
137
- # compiler
138
- path_prefixes = [
139
- (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
140
- (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
141
- ]
142
- for path, prefix in path_prefixes:
143
- for lib in pkgutil.walk_packages([path], prefix=prefix):
144
- with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
145
- contents += [hashlib.sha256(f.read()).hexdigest()]
146
-
147
- # backend
148
- libtriton_hash = hashlib.sha256()
149
- ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
150
- with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
151
- while True:
152
- chunk = f.read(1024**2)
153
- if not chunk:
154
- break
155
- libtriton_hash.update(chunk)
156
- contents.append(libtriton_hash.hexdigest())
157
- # language
158
- language_path = os.path.join(TRITON_PATH, 'language')
159
- for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
160
- with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
161
- contents += [hashlib.sha256(f.read()).hexdigest()]
162
- return f'{__version__}' + '-'.join(contents)
130
+ def max_shared_mem(device):
131
+ return driver.active.utils.get_device_properties(device)["max_shared_mem"]
163
132
 
164
133
 
165
134
  def parse(full_name, ext, context):
@@ -179,7 +148,7 @@ def filter_traceback(e: BaseException):
179
148
 
180
149
  These are uninteresting to the user -- "just show me *my* code!"
181
150
  """
182
- if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1":
151
+ if knobs.compilation.front_end_debugging:
183
152
  return
184
153
 
185
154
  if e.__cause__ is not None:
@@ -211,7 +180,50 @@ def filter_traceback(e: BaseException):
211
180
  e.__traceback__ = frames[0]
212
181
 
213
182
 
214
- def compile(src, target=None, options=None):
183
+ class CompileTimer:
184
+
185
+ def __init__(self) -> None:
186
+ self.start: float = time.perf_counter()
187
+ self.ir_initialization_end: float | None = None
188
+ self.lowering_stage_ends: list[tuple[str, float]] = []
189
+ self.store_results_end: float | None = None
190
+
191
+ def finished_ir_initialization(self) -> None:
192
+ self.ir_initialization_end = time.perf_counter()
193
+
194
+ def stage_finished(self, stage_name: str) -> None:
195
+ self.lowering_stage_ends.append((stage_name, time.perf_counter()))
196
+
197
+ def end(self) -> knobs.CompileTimes:
198
+ timestamp = time.perf_counter()
199
+ if self.ir_initialization_end is None:
200
+ self.ir_initialization_end = timestamp
201
+ else:
202
+ self.store_results_end = timestamp
203
+
204
+ def delta(start: float, end: float | None) -> int:
205
+ if end is None:
206
+ return 0
207
+ return int((end - start) * 1000000)
208
+
209
+ lowering_stage_durations = []
210
+ stage_start = self.ir_initialization_end
211
+ for stage_name, stage_end in self.lowering_stage_ends:
212
+ lowering_stage_durations.append((stage_name, delta(stage_start, stage_end)))
213
+ stage_start = stage_end
214
+
215
+ return knobs.CompileTimes(
216
+ ir_initialization=delta(self.start, self.ir_initialization_end),
217
+ lowering_stages=lowering_stage_durations,
218
+ store_results=delta(stage_start, self.store_results_end),
219
+ )
220
+
221
+
222
+ def compile(src, target=None, options=None, _env_vars=None):
223
+ compilation_listener = knobs.compilation.listener
224
+ if compilation_listener:
225
+ timer = CompileTimer()
226
+
215
227
  if target is None:
216
228
  target = driver.active.get_current_target()
217
229
  assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
@@ -226,15 +238,15 @@ def compile(src, target=None, options=None):
226
238
  extra_options = src.parse_options()
227
239
  options = backend.parse_options(dict(options or dict(), **extra_options))
228
240
  # create cache manager
229
- env_vars = get_cache_invalidating_env_vars()
230
- key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}"
241
+ env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars
242
+ key = get_cache_key(src, backend, options, env_vars=env_vars)
231
243
  hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
232
244
  fn_cache_manager = get_cache_manager(hash)
233
245
  # For dumping/overriding only hash the source as we want it to be independent of triton
234
246
  # core changes to make it easier to track kernels by hash.
235
- enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
236
- enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
237
- store_only_binary = os.environ.get("TRITON_STORE_BINARY_ONLY", "0") == "1"
247
+ enable_override = knobs.compilation.override
248
+ enable_ir_dump = knobs.compilation.dump_ir
249
+ store_only_binary = knobs.compilation.store_binary_only
238
250
  fn_override_manager = get_override_manager(src.hash()) if enable_override else None
239
251
  fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
240
252
  # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
@@ -245,10 +257,20 @@ def compile(src, target=None, options=None):
245
257
  metadata_filename = f"{file_name}.json"
246
258
  metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
247
259
  metadata_path = metadata_group.get(metadata_filename)
248
- always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1"
260
+ always_compile = knobs.compilation.always_compile
249
261
  if not always_compile and metadata_path is not None:
250
262
  # cache hit!
251
- return CompiledKernel(src, metadata_group, hash)
263
+ res = CompiledKernel(src, metadata_group, hash)
264
+ if compilation_listener:
265
+ compilation_listener(
266
+ src=src,
267
+ metadata=res.metadata._asdict(),
268
+ metadata_group=metadata_group,
269
+ times=timer.end(),
270
+ cache_hit=True,
271
+ )
272
+ return res
273
+
252
274
  # initialize metadata
253
275
  metadata = {
254
276
  "hash": hash,
@@ -259,7 +281,7 @@ def compile(src, target=None, options=None):
259
281
  metadata["triton_version"] = __version__
260
282
  # run compilation pipeline and populate metadata
261
283
  stages = dict()
262
- backend.add_stages(stages, options)
284
+ backend.add_stages(stages, options, src.language)
263
285
  first_stage = list(stages.keys()).index(src.ext)
264
286
  # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
265
287
  if ir_source:
@@ -275,15 +297,34 @@ def compile(src, target=None, options=None):
275
297
  codegen_fns = backend.get_codegen_implementation(options)
276
298
  module_map = backend.get_module_map()
277
299
  try:
278
- module = src.make_ir(options, codegen_fns, module_map, context)
300
+ module = src.make_ir(target, options, codegen_fns, module_map, context)
279
301
  except Exception as e:
280
302
  filter_traceback(e)
281
303
  raise
282
- use_ir_loc = os.environ.get("USE_IR_LOC", None)
304
+
305
+ if ir_source:
306
+ ir_filename = f"{file_name}.{src.ext}"
307
+ metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
308
+ else:
309
+ ir_filename = f"{file_name}.source"
310
+ metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
311
+
312
+ use_ir_loc = knobs.compilation.use_ir_loc
313
+ if ir_source and use_ir_loc:
314
+ module.create_location_snapshot(src.path)
315
+ print(f"Creating new locations for {src.path}")
316
+
317
+ if compilation_listener:
318
+ timer.finished_ir_initialization()
283
319
  for ext, compile_ir in list(stages.items())[first_stage:]:
284
320
  next_module = compile_ir(module, metadata)
285
321
  ir_filename = f"{file_name}.{ext}"
286
- if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
322
+ if fn_override_manager is None:
323
+ # Users can override kernels at scale by setting `ir_override` in autotune config
324
+ # without TRITON_KERNEL_OVERRIDE
325
+ if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"):
326
+ next_module = parse(ir_override, ext, context)
327
+ elif full_name := fn_override_manager.get_file(ir_filename):
287
328
  print(f"\nOverriding kernel with file {full_name}")
288
329
  next_module = parse(full_name, ext, context)
289
330
  # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
@@ -291,12 +332,17 @@ def compile(src, target=None, options=None):
291
332
  metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
292
333
  if fn_dump_manager is not None:
293
334
  fn_dump_manager.put(next_module, ir_filename)
335
+ if ext == "cubin":
336
+ sass = get_sass(next_module)
337
+ fn_dump_manager.put(sass, file_name + ".sass")
294
338
  # use an env variable to parse ir from file
295
339
  if use_ir_loc == ext:
296
340
  ir_full_name = fn_cache_manager.get_file(ir_filename)
297
341
  next_module.create_location_snapshot(ir_full_name)
298
342
  print(f"Creating new locations for {ir_full_name}")
299
343
  module = next_module
344
+ if compilation_listener:
345
+ timer.stage_finished(ext)
300
346
  # write-back metadata
301
347
  metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
302
348
  binary=False)
@@ -310,13 +356,18 @@ def compile(src, target=None, options=None):
310
356
  # this is likely due to the llvm-symbolizer forking a process
311
357
  # TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
312
358
  # multithreading in the MLIR context
313
- if not os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
359
+ if not knobs.compilation.enable_asan:
314
360
  context.disable_multithreading()
361
+
362
+ # notify any listener
363
+ if compilation_listener:
364
+ compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(),
365
+ cache_hit=False)
315
366
  # return handle to compiled kernel
316
367
  return CompiledKernel(src, metadata_group, hash)
317
368
 
318
369
 
319
- def make_backend(target):
370
+ def make_backend(target: GPUTarget) -> BaseBackend:
320
371
  actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
321
372
  if len(actives) != 1:
322
373
  raise RuntimeError(
@@ -330,7 +381,7 @@ class LazyDict:
330
381
  self.data = data
331
382
  self.extras = []
332
383
 
333
- def get(self) -> None:
384
+ def get(self):
334
385
  for func, args in self.extras:
335
386
  self.data = self.data | func(*args)
336
387
  self.extras.clear()
@@ -353,12 +404,11 @@ class AsmDict(dict):
353
404
  return value
354
405
 
355
406
 
356
- class CompiledKernel:
407
+ def _raise_error(err, *args, **kwargs):
408
+ raise copy.deepcopy(err)
357
409
 
358
- # Hooks for external tools to monitor the execution of triton kernels
359
- # TODO: move out of this namespace since it's a runtime thing
360
- launch_enter_hook = None
361
- launch_exit_hook = None
410
+
411
+ class CompiledKernel:
362
412
 
363
413
  def __init__(self, src, metadata_group, hash):
364
414
  from collections import namedtuple
@@ -382,48 +432,66 @@ class CompiledKernel:
382
432
  file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
383
433
  for file in asm_files
384
434
  })
435
+ self.metadata_group = metadata_group
385
436
  self.kernel = self.asm[binary_ext]
386
437
  # binaries are lazily initialized
387
438
  # because it involves doing runtime things
388
439
  # (e.g., checking amount of shared memory on current device)
389
440
  self.module = None
390
441
  self.function = None
442
+ self._run = None
391
443
 
392
444
  def _init_handles(self):
393
445
  if self.module is not None:
394
446
  return
447
+
448
+ def raise_(err):
449
+ # clone the exception object so that the one saved in the closure
450
+ # of the partial function below doesn't get assigned a stack trace
451
+ # after the subsequent raise. otherwise, the CompiledKernel instance
452
+ # saved in the (global) kernel cache will keep references to all the
453
+ # locals in the traceback via the exception instance in the closure.
454
+ cloned_err = copy.deepcopy(err)
455
+ self._run = functools.partial(_raise_error, cloned_err)
456
+ raise err
457
+
395
458
  device = driver.active.get_current_device()
396
459
  # create launcher
397
- self.run = driver.active.launcher_cls(self.src, self.metadata)
460
+ self._run = driver.active.launcher_cls(self.src, self.metadata)
398
461
  # not enough shared memory to run the kernel
399
- max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
462
+ max_shared = max_shared_mem(device)
400
463
  if self.metadata.shared > max_shared:
401
- raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
464
+ raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
402
465
  if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
403
466
  # Use blackwell max tmem size for now, this should be moved in device properties
404
467
  max_tmem_size = 512 # tmem size in number of columns
405
468
  if self.metadata.tmem_size > max_tmem_size:
406
- raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory")
469
+ raise_(OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory"))
470
+ if knobs.runtime.kernel_load_start_hook is not None:
471
+ knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
407
472
  # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
408
- self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
473
+ self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
409
474
  self.name, self.kernel, self.metadata.shared, device)
410
-
411
- def __getattribute__(self, name):
412
- if name == 'run':
475
+ warp_size = driver.active.get_current_target().warp_size
476
+ if self.metadata.num_warps * warp_size > self.n_max_threads:
477
+ raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads"))
478
+ if knobs.runtime.kernel_load_end_hook is not None:
479
+ knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
480
+
481
+ @property
482
+ def run(self):
483
+ if self._run is None:
413
484
  self._init_handles()
414
- return super().__getattribute__(name)
485
+ return self._run
415
486
 
416
487
  def launch_metadata(self, grid, stream, *args):
417
- if CompiledKernel.launch_enter_hook is None:
488
+ if knobs.runtime.launch_enter_hook is None:
418
489
  return None
490
+ self._init_handles()
419
491
  ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
420
492
  if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
421
493
  return ret
422
- arg_dict = {}
423
- arg_idx = 0
424
- for i, arg_name in enumerate(self.src.fn.arg_names):
425
- arg_dict[arg_name] = args[arg_idx]
426
- arg_idx += 1
494
+ arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)}
427
495
  ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
428
496
  return ret
429
497
 
@@ -436,6 +504,6 @@ class CompiledKernel:
436
504
  stream = driver.active.get_current_stream(device)
437
505
  launch_metadata = self.launch_metadata(grid, stream, *args)
438
506
  self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
439
- CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args)
507
+ knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)
440
508
 
441
509
  return runner
File without changes
@@ -0,0 +1,5 @@
1
+ from . import nvidia
2
+ from ._runtime import constexpr_function, jit
3
+ from triton.language.core import must_use_result
4
+
5
+ __all__ = ["constexpr_function", "jit", "must_use_result", "nvidia"]
File without changes
@@ -0,0 +1,102 @@
1
+ from __future__ import annotations
2
+ from triton.compiler.compiler import ASTSource
3
+ from triton.backends.compiler import Language
4
+ from triton.runtime.jit import JITFunction, constexpr_function
5
+ from typing import TypeVar, Optional, Callable, Iterable, Union
6
+ from triton._C.libtriton import ir
7
+
8
+ T = TypeVar("T")
9
+
10
+ __all__ = ["constexpr_function", "jit"]
11
+
12
+
13
+ class GluonASTSource(ASTSource):
14
+
15
+ def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
16
+ super().__init__(fn, signature, constexprs, attrs)
17
+ self.language = Language.GLUON
18
+ self.ext = "ttgir"
19
+
20
+ def make_ir(self, target, options, codegen_fns, module_map, context):
21
+ from triton.compiler.compiler import make_backend
22
+ from triton.compiler.code_generator import ast_to_ttir
23
+
24
+ builder = ir.builder(context)
25
+ module = builder.create_module()
26
+
27
+ # Assign module attributes eagerly, as they are needed to verify layouts
28
+ backend = make_backend(target)
29
+ target = backend.get_target_name(options)
30
+
31
+ module.set_attr("ttg.target", builder.get_string_attr(target))
32
+ module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
33
+ module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
34
+ module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(options.warp_size))
35
+
36
+ is_cuda = options.backend_name == "cuda"
37
+ if is_cuda and options.maxnreg is not None:
38
+ module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
39
+
40
+ module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
41
+ module_map=module_map, module=module)
42
+ return module
43
+
44
+
45
+ class GluonJITFunction(JITFunction[T]):
46
+
47
+ def create_binder(self):
48
+ result = super().create_binder()
49
+ self.ASTSource = GluonASTSource
50
+ return result
51
+
52
+ def is_gluon(self):
53
+ return True
54
+
55
+
56
+ def jit(
57
+ fn: Optional[T] = None,
58
+ *,
59
+ version=None,
60
+ repr: Optional[Callable] = None,
61
+ launch_metadata: Optional[Callable] = None,
62
+ do_not_specialize: Optional[Iterable[int | str]] = None,
63
+ do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
64
+ debug: Optional[bool] = None,
65
+ noinline: Optional[bool] = None,
66
+ ) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]:
67
+ """
68
+ Decorator for JIT-compiling a function using the Triton compiler.
69
+
70
+ :note: When a jit'd function is called, arguments are
71
+ implicitly converted to pointers if they have a :code:`.data_ptr()` method
72
+ and a `.dtype` attribute.
73
+
74
+ :note: This function will be compiled and run on the GPU. It will only have access to:
75
+
76
+ * python primitives,
77
+ * builtins within the triton package,
78
+ * arguments to this function,
79
+ * other jit'd functions
80
+
81
+ :param fn: the function to be jit-compiled
82
+ :type fn: Callable
83
+ """
84
+
85
+ def decorator(fn: T) -> JITFunction[T]:
86
+ assert callable(fn)
87
+ return GluonJITFunction(
88
+ fn,
89
+ version=version,
90
+ do_not_specialize=do_not_specialize,
91
+ do_not_specialize_on_alignment=do_not_specialize_on_alignment,
92
+ debug=debug,
93
+ noinline=noinline,
94
+ repr=repr,
95
+ launch_metadata=launch_metadata,
96
+ )
97
+
98
+ if fn is not None:
99
+ return decorator(fn)
100
+
101
+ else:
102
+ return decorator
@@ -0,0 +1,119 @@
1
+ from ._core import (
2
+ base_value,
3
+ base_type,
4
+ block_type,
5
+ broadcast,
6
+ constexpr,
7
+ dtype,
8
+ void,
9
+ int1,
10
+ int8,
11
+ int16,
12
+ int32,
13
+ int64,
14
+ uint8,
15
+ uint16,
16
+ uint32,
17
+ uint64,
18
+ float8e5,
19
+ float8e5b16,
20
+ float8e4nv,
21
+ float8e4b8,
22
+ float8e4b15,
23
+ float16,
24
+ bfloat16,
25
+ float32,
26
+ float64,
27
+ pointer_type,
28
+ shared_memory_descriptor,
29
+ tensor,
30
+ tuple,
31
+ tuple_type,
32
+ _unwrap_if_constexpr,
33
+ # API Functions
34
+ allocate_shared_memory,
35
+ arange,
36
+ associative_scan,
37
+ atomic_add,
38
+ atomic_and,
39
+ atomic_cas,
40
+ atomic_max,
41
+ atomic_min,
42
+ atomic_or,
43
+ atomic_xchg,
44
+ atomic_xor,
45
+ convert_layout,
46
+ device_assert,
47
+ expand_dims,
48
+ full,
49
+ histogram,
50
+ inline_asm_elementwise,
51
+ join,
52
+ load,
53
+ map_elementwise,
54
+ max_constancy,
55
+ max_contiguous,
56
+ maximum,
57
+ minimum,
58
+ multiple_of,
59
+ num_programs,
60
+ permute,
61
+ program_id,
62
+ reduce,
63
+ reshape,
64
+ set_auto_layout,
65
+ split,
66
+ static_assert,
67
+ static_print,
68
+ static_range,
69
+ store,
70
+ thread_barrier,
71
+ to_tensor,
72
+ warp_specialize,
73
+ where,
74
+ )
75
+ from ._layouts import (
76
+ AutoLayout,
77
+ BlockedLayout,
78
+ SliceLayout,
79
+ DistributedLinearLayout,
80
+ DotOperandLayout,
81
+ NVMMADistributedLayout,
82
+ NVMMASharedLayout,
83
+ SwizzledSharedLayout,
84
+ PaddedSharedLayout,
85
+ )
86
+ from ._math import (
87
+ umulhi,
88
+ exp,
89
+ exp2,
90
+ fma,
91
+ log,
92
+ log2,
93
+ cos,
94
+ rsqrt,
95
+ sin,
96
+ sqrt,
97
+ sqrt_rn,
98
+ abs,
99
+ fdiv,
100
+ div_rn,
101
+ erf,
102
+ floor,
103
+ ceil,
104
+ )
105
+ from ._standard import (
106
+ cdiv,
107
+ full_like,
108
+ max,
109
+ min,
110
+ reduce_or,
111
+ sum,
112
+ xor_sum,
113
+ zeros,
114
+ zeros_like,
115
+ )
116
+
117
+ from . import nvidia
118
+ from . import amd
119
+ from . import extra