triton-windows 3.3.1.post19__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/runtime/autotuner.py CHANGED
@@ -1,34 +1,26 @@
 from __future__ import annotations

 import builtins
-import os
 import time
 import inspect
+import hashlib
+import json
+from functools import cached_property
 from typing import Dict, Tuple, List, Optional

-from .jit import KernelInterface
+from .. import knobs
+from .jit import KernelInterface, JITFunction
 from .errors import OutOfResources, PTXASError
 from .driver import driver
+from .cache import get_cache_manager, triton_key
+from triton._C.libtriton import get_cache_invalidating_env_vars


 class Autotuner(KernelInterface):

-    def __init__(
-        self,
-        fn,
-        arg_names,
-        configs,
-        key,
-        reset_to_zero,
-        restore_value,
-        pre_hook=None,
-        post_hook=None,
-        prune_configs_by: Optional[Dict] = None,
-        warmup=None,
-        rep=None,
-        use_cuda_graph=False,
-        do_bench=None,
-    ):
+    def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
+                 prune_configs_by: Optional[Dict] = None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None,
+                 cache_results=False):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
             'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -36,15 +28,13 @@ class Autotuner(KernelInterface):
             'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
         """
         if not configs:
-            self.configs = [
-                Config({}, num_warps=4, num_stages=3, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
-                       reg_dec_producer=0, reg_inc_consumer=0)
-            ]
+            self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)]
         else:
             self.configs = configs
         self.keys = key
         self.cache: Dict[Tuple, Config] = {}
         self.arg_names = arg_names
+        self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)

         # Reset to zero or restore values
         self.reset_to_zero = []
@@ -97,6 +87,7 @@ class Autotuner(KernelInterface):
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn

+        self._do_bench = do_bench
         self.num_warmups = warmup
         self.num_reps = rep
         self.use_cuda_graph = use_cuda_graph
@@ -110,7 +101,7 @@ class Autotuner(KernelInterface):
                           stacklevel=1)
             if use_cuda_graph:
                 from ..testing import do_bench_cudagraph
-                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
                     kernel_call,
                     rep=rep if rep is not None else 100,
                     quantiles=quantiles,
@@ -118,7 +109,7 @@ class Autotuner(KernelInterface):
                 return

             import triton.testing
-            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+            self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
                 kernel_call,
                 warmup=warmup if warmup is not None else 25,
                 rep=rep if rep is not None else 100,
@@ -126,15 +117,16 @@ class Autotuner(KernelInterface):
             )
             return

-        if do_bench is None:
-            self.do_bench = driver.active.get_benchmarker()
-        else:
-            self.do_bench = do_bench
+    @cached_property
+    def do_bench(self):
+        if self._do_bench is None:
+            return driver.active.get_benchmarker()
+        return self._do_bench

     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure

-        verbose = os.environ.get("TRITON_PRINT_AUTOTUNING", None) == "1"
+        verbose = knobs.autotuning.print
         if verbose:
             print(f"Autotuning kernel {self.base_fn.__name__} with config {config}")
@@ -173,6 +165,48 @@ class Autotuner(KernelInterface):
                 print(f"Autotuning failed with {e}")
             return [float("inf"), float("inf"), float("inf")]

+    def check_disk_cache(self, tuning_key, configs, bench_fn):
+        # We can't serialize prehooks, so just give up and run the benchmarks.
+        if not tuning_key or any(cfg.pre_hook for cfg in configs):
+            bench_fn()
+            return False
+
+        from triton.compiler.compiler import make_backend
+
+        fn = self.fn
+        while not isinstance(fn, JITFunction):
+            fn = fn.fn
+
+        env_vars = get_cache_invalidating_env_vars()
+        cache_key = [
+            triton_key(),
+            make_backend(driver.active.get_current_target()).hash(),
+            fn.cache_key,
+            str(sorted(env_vars.items())),
+            str(tuning_key),
+        ] + [str(c) for c in configs]
+        cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
+        cache = get_cache_manager(cache_key)
+        file_name = f"{fn.__name__[:150]}.autotune.json"
+        path = cache.get_file(file_name)
+        if path:
+            with open(path, "r") as cached_configs:
+                timings = json.load(cached_configs)["configs_timings"]
+                timings = {Config(**config): timing for config, timing in timings}
+                self.cache[tuning_key] = builtins.min(timings, key=timings.get)
+                self.configs_timings = timings
+            return True
+
+        bench_fn()
+        cache.put(
+            json.dumps({
+                "key":
+                tuning_key,
+                "configs_timings":
+                [(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
+            }), file_name, binary=False)
+        return False
+
     def run(self, *args, **kwargs):
         self.nargs = dict(zip(self.arg_names, args))
         used_cached_result = True
@@ -185,24 +219,31 @@ class Autotuner(KernelInterface):
                     key.append(str(arg.dtype))
             key = tuple(key)
             if key not in self.cache:
-                # prune configs
                 used_cached_result = False
                 pruned_configs = self.prune_configs(kwargs)
-                bench_start = time.time()
-                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
-                bench_end = time.time()
-                self.bench_time = bench_end - bench_start
-                self.cache[key] = builtins.min(timings, key=timings.get)
-                full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
-                self.pre_hook(full_nargs, reset_only=True)
-                self.configs_timings = timings
+
+                def benchmark():
+                    bench_start = time.perf_counter()
+                    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+                    bench_end = time.perf_counter()
+                    self.bench_time = bench_end - bench_start
+                    self.cache[key] = builtins.min(timings, key=timings.get)
+                    full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
+                    self.pre_hook(full_nargs, reset_only=True)
+                    self.configs_timings = timings
+
+                if self.cache_results:
+                    used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
+                else:
+                    benchmark()
+
             config = self.cache[key]
         else:
             config = self.configs[0]
         self.best_config = config
-        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
-            print(f"Triton autotuning for function {self.base_fn.__name__} finished after "
-                  f"{self.bench_time:.2f}s; best config selected: {self.best_config};")
+        if knobs.autotuning.print and not used_cached_result:
+            print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
+                  f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
         if config.pre_hook is not None:
             full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
             config.pre_hook(full_nargs)
@@ -241,11 +282,11 @@ class Autotuner(KernelInterface):
     def warmup(self, *args, **kwargs):
         self.nargs = dict(zip(self.arg_names, args))
         ret = []
-        for config in self.prune_configs(kwargs):
+        for autotune_config in self.prune_configs(kwargs):
             ret.append(self.fn.warmup(
                 *args,
                 **kwargs,
-                **config.all_kwargs(),
+                **autotune_config.all_kwargs(),
             ))
         self.nargs = None
         return ret
@@ -263,27 +304,34 @@ class Config:
     :type num_warps: int
     :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
                        Mostly useful for matrix multiplication workloads on SM80+ GPUs.
-    :type num_ctas: int
+    :type num_stages: int
     :ivar num_ctas: number of blocks in a block cluster. SM90+ only.
+    :type num_ctas: int
     :type maxnreg: Optional[int]
    :ivar maxnreg: maximum number of registers one thread can use. Corresponds
                   to ptx .maxnreg directive. Not supported on all platforms.
     :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
                     function are args.
+    :ivar ir_override: filename of a user-defined IR (*.{ttgir|llir|ptx|amdgcn}).
     """

-    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
-                 reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None):
+    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None, ir_override=None):
         self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_ctas = num_ctas
         self.num_stages = num_stages
-        self.num_buffers_warp_spec = num_buffers_warp_spec
-        self.num_consumer_groups = num_consumer_groups
-        self.reg_dec_producer = reg_dec_producer
-        self.reg_inc_consumer = reg_inc_consumer
         self.maxnreg = maxnreg
         self.pre_hook = pre_hook
+        self.ir_override = ir_override
+
+    def __setstate__(self, state):
+        self.kwargs = state.get("kwargs", {})
+        self.num_warps = state.get("num_warps", 4)
+        self.num_stages = state.get("num_stages", 3)
+        self.num_ctas = state.get("num_ctas", 1)
+        self.maxnreg = state.get("maxnreg", None)
+        self.pre_hook = state.get("pre_hook", None)
+        self.ir_override = state.get("ir_override", None)

     def all_kwargs(self):
         return {
@@ -293,11 +341,8 @@ class Config:
                     ("num_warps", self.num_warps),
                     ("num_ctas", self.num_ctas),
                     ("num_stages", self.num_stages),
-                    ("num_buffers_warp_spec", self.num_buffers_warp_spec),
-                    ("num_consumer_groups", self.num_consumer_groups),
-                    ("reg_dec_producer", self.reg_dec_producer),
-                    ("reg_inc_consumer", self.reg_inc_consumer),
                     ("maxnreg", self.maxnreg),
+                    ("ir_override", self.ir_override),
                 ) if v is not None
             }
         }
@@ -309,16 +354,26 @@ class Config:
         res.append(f"num_warps: {self.num_warps}")
         res.append(f"num_ctas: {self.num_ctas}")
         res.append(f"num_stages: {self.num_stages}")
-        res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}")
-        res.append(f"num_consumer_groups: {self.num_consumer_groups}")
-        res.append(f"reg_dec_producer: {self.reg_dec_producer}")
-        res.append(f"reg_inc_consumer: {self.reg_inc_consumer}")
         res.append(f"maxnreg: {self.maxnreg}")
         return ", ".join(res)

+    def __hash__(self):
+        return hash((*self.all_kwargs().items(), self.pre_hook))
+
+    def __eq__(self, other):
+        self_tuple = tuple((
+            *self.all_kwargs().items(),
+            self.pre_hook,
+        ))
+        other_tuple = tuple((
+            *other.all_kwargs().items(),
+            other.pre_hook,
+        ))
+        return self_tuple == other_tuple
+

 def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None, cache_results=False):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.

@@ -372,12 +427,14 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
     :type rep: int
     :param do_bench: a benchmark function to measure the time of each run.
     :type do_bench: lambda fn, quantiles
+    :param cache_results: whether to cache autotune timings to disk. Defaults to False.
+    "type cache_results: bool
     """

     def decorator(fn):
         return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
                          post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
-                         use_cuda_graph=use_cuda_graph, do_bench=do_bench)
+                         use_cuda_graph=use_cuda_graph, do_bench=do_bench, cache_results=cache_results)

     return decorator

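For context, a minimal sketch of how the new cache_results flag is used from the caller's side; the kernel, block size, and argument names below are hypothetical, not taken from this package:

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 256}, num_warps=8),
    ],
    key=["n_elements"],
    cache_results=True,  # persist per-key timings via check_disk_cache()
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)

On the first run for a given tuning key, every config is benchmarked and the timings are written to a <function>.autotune.json file in the Triton cache; later runs with the same key (and unchanged triton_key(), backend hash, and cache-invalidating environment variables) reload the best config instead of re-benchmarking.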
triton/runtime/build.py CHANGED
@@ -1,14 +1,25 @@
+from __future__ import annotations
+
 import functools
-import sysconfig
+import hashlib
+import importlib.util
+import logging
 import os
 import shutil
 import subprocess
+import sysconfig
+import tempfile
+
+from types import ModuleType
+
+from .cache import get_cache_manager
+from .. import knobs

 if os.name == "nt":
     from triton.windows_utils import find_msvc_winsdk, find_python


-@functools.cache
+@functools.lru_cache
 def get_cc():
     cc = os.environ.get("CC")
     if cc is None:
@@ -30,6 +41,11 @@ def get_cc():
     return cc


+def is_tcc(cc):
+    cc = os.path.basename(cc).lower()
+    return cc == "tcc" or cc == "tcc.exe"
+
+
 def is_msvc(cc):
     cc = os.path.basename(cc).lower()
     return cc == "cl" or cc == "cl.exe"
@@ -40,10 +56,11 @@ def is_clang(cc):
     return cc == "clang" or cc == "clang.exe"


-def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
+def _cc_cmd(cc: str, src: str, out: str, include_dirs: list[str], library_dirs: list[str], libraries: list[str],
+            ccflags: list[str]) -> list[str]:
     if is_msvc(cc):
         out_base = os.path.splitext(out)[0]
-        cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/wd4819"]
+        cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/std:c11", "/wd4819"]
         cc_cmd += [f"/I{dir}" for dir in include_dirs if dir is not None]
         cc_cmd += [f"/Fo{out_base + '.obj'}"]
         cc_cmd += ["/link"]
@@ -58,45 +75,94 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
         if not (os.name == "nt" and is_clang(cc)):
             # Clang does not support -fPIC on Windows
             cc_cmd += ["-fPIC"]
+        if is_tcc(cc):
+            cc_cmd += ["-D_Py_USE_GCC_BUILTIN_ATOMICS"]
         cc_cmd += [f'-l{lib}' for lib in libraries]
         cc_cmd += [f"-L{dir}" for dir in library_dirs]
         cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
+        cc_cmd += ccflags
     return cc_cmd


-def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
+def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
+           ccflags: list[str]) -> str:
+    if impl := knobs.build.impl:
+        return impl(name, src, srcdir, library_dirs, include_dirs, libraries)
     suffix = sysconfig.get_config_var('EXT_SUFFIX')
     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
-    # try to avoid setuptools if possible
     cc = get_cc()
     # This function was renamed and made public in Python 3.10
     if hasattr(sysconfig, 'get_default_scheme'):
         scheme = sysconfig.get_default_scheme()
     else:
-        scheme = sysconfig._get_default_scheme()
+        scheme = sysconfig._get_default_scheme()  # type: ignore
     # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
     # path changes to include 'local'. This change is required to use triton with system-wide python.
     if scheme == 'posix_local':
         scheme = 'posix_prefix'
     py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
-    custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH'))
+    custom_backend_dirs = knobs.build.backend_dirs
+    # Don't append in place
     include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
     if os.name == "nt":
-        library_dirs += find_python()
-        # Link against Python stable ABI
-        # libraries is modified in place
-        if "python3" not in libraries:
-            libraries += ["python3"]
+        library_dirs = library_dirs + find_python()
+        version = sysconfig.get_python_version().replace(".", "")
+        if sysconfig.get_config_var("Py_GIL_DISABLED"):
+            version += "t"
+        libraries = libraries + [f"python{version}"]
         if is_msvc(cc):
             _, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
-            include_dirs += msvc_winsdk_inc_dirs
-            library_dirs += msvc_winsdk_lib_dirs
-    cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
+            include_dirs = include_dirs + msvc_winsdk_inc_dirs
+            library_dirs = library_dirs + msvc_winsdk_lib_dirs
+    cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries, ccflags)

     try:
-        ret = subprocess.check_call(cc_cmd)
+        subprocess.check_call(cc_cmd)
     except Exception as e:
         print("Failed to compile. cc_cmd:", cc_cmd)
         raise e

     return so
+
+
+@functools.lru_cache
+def platform_key() -> str:
+    from platform import machine, system, architecture
+    return ",".join([machine(), system(), *architecture()])
+
+
+def _load_module_from_path(name: str, path: str) -> ModuleType:
+    # Loading module with relative path may cause error
+    path = os.path.abspath(path)
+    spec = importlib.util.spec_from_file_location(name, path)
+    if not spec or not spec.loader:
+        raise RuntimeError(f"Failed to load newly compiled {name} from {path}")
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
+                            include_dirs: list[str] | None = None, libraries: list[str] | None = None,
+                            ccflags: list[str] | None = None) -> ModuleType:
+    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
+    cache = get_cache_manager(key)
+    suffix = sysconfig.get_config_var("EXT_SUFFIX")
+    cache_path = cache.get_file(f"{name}{suffix}")
+
+    if cache_path is not None:
+        try:
+            return _load_module_from_path(name, cache_path)
+        except (RuntimeError, ImportError):
+            log = logging.getLogger(__name__)
+            log.warning(f"Triton cache error: compiled module {name}.so could not be loaded")
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        src_path = os.path.join(tmpdir, name + ".c")
+        with open(src_path, "w") as f:
+            f.write(src)
+        so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
+        with open(so, "rb") as f:
+            cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
+
+    return _load_module_from_path(name, cache_path)
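To illustrate the new build-and-cache path, here is a self-contained sketch of calling compile_module_from_src directly; the C source and the module name "demo_ext" are invented for illustration and are not taken from the Triton drivers:

from triton.runtime.build import compile_module_from_src

# A trivial CPython extension; the init function must be named after the
# `name` argument so PyInit_demo_ext is found once the module is loaded.
src = r"""
#define PY_SSIZE_T_CLEAN
#include <Python.h>

static PyObject *hello(PyObject *self, PyObject *args) {
    return PyUnicode_FromString("hello from a Triton-cached extension");
}

static PyMethodDef methods[] = {{"hello", hello, METH_NOARGS, NULL}, {NULL, NULL, 0, NULL}};
static struct PyModuleDef module_def = {PyModuleDef_HEAD_INIT, "demo_ext", NULL, -1, methods};

PyMODINIT_FUNC PyInit_demo_ext(void) { return PyModule_Create(&module_def); }
"""

# The first call compiles in a temporary directory and stores the artifact in
# the file cache keyed by sha256(src + platform_key()); later calls load the
# cached extension without invoking the compiler at all.
mod = compile_module_from_src(src, "demo_ext")
print(mod.hello())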
triton/runtime/cache.py CHANGED
@@ -1,33 +1,19 @@
-import importlib
 import json
 import os
 import uuid
 from abc import ABC, abstractmethod
-from pathlib import Path
 from typing import Dict, List, Optional
 import base64
 import hashlib
+import functools
+import sysconfig

-
-def get_home_dir():
-    return os.getenv("TRITON_HOME", Path.home())
-
-
-def default_cache_dir():
-    return os.path.join(get_home_dir(), ".triton", "cache")
-
-
-def default_override_dir():
-    return os.path.join(get_home_dir(), ".triton", "override")
-
-
-def default_dump_dir():
-    return os.path.join(get_home_dir(), ".triton", "dump")
+from triton import __version__, knobs


 class CacheManager(ABC):

-    def __init__(self, key):
+    def __init__(self, key, override=False, dump=False):
         pass

     @abstractmethod
@@ -53,16 +39,16 @@ class FileCacheManager(CacheManager):
         self.key = key
         self.lock_path = None
         if dump:
-            self.cache_dir = os.getenv("TRITON_DUMP_DIR", "").strip() or default_dump_dir()
+            self.cache_dir = knobs.cache.dump_dir
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
         elif override:
-            self.cache_dir = os.getenv("TRITON_OVERRIDE_DIR", "").strip() or default_override_dir()
+            self.cache_dir = knobs.cache.override_dir
             self.cache_dir = os.path.join(self.cache_dir, self.key)
         else:
             # create cache directory if it doesn't exist
-            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
+            self.cache_dir = knobs.cache.dir
             if self.cache_dir:
                 self.cache_dir = os.path.join(self.cache_dir, self.key)
                 self.lock_path = os.path.join(self.cache_dir, "lock")
@@ -166,10 +152,10 @@ class RedisRemoteCacheBackend(RemoteCacheBackend):
     def __init__(self, key):
         import redis
         self._key = key
-        self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
+        self._key_fmt = knobs.cache.redis.key_format
         self._redis = redis.Redis(
-            host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
-            port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
+            host=knobs.cache.redis.host,
+            port=knobs.cache.redis.port,
         )

     def _get_key(self, filename: str) -> str:
@@ -187,10 +173,10 @@ class RemoteCacheManager(CacheManager):

     def __init__(self, key, override=False, dump=False):
         # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`.
-        remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
-        module_path, clz_nme = remote_cache_manager.split(":")
-        module = importlib.import_module(module_path)
-        remote_cache_cls = getattr(module, clz_nme)
+        remote_cache_cls = knobs.cache.remote_manager_class
+        if not remote_cache_cls:
+            raise RuntimeError(
+                "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
         self._backend = remote_cache_cls(key)

         self._override = override
@@ -260,37 +246,24 @@ class RemoteCacheManager(CacheManager):
         return self.put(grp_contents, grp_filename)


-__cache_cls = FileCacheManager
-__cache_cls_nme = "DEFAULT"
-
-
 def _base32(key):
     # Assume key is a hex string.
     return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")


 def get_cache_manager(key) -> CacheManager:
-    import os
-
-    user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
-    global __cache_cls
-    global __cache_cls_nme
-
-    if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
-        module_path, clz_nme = user_cache_manager.split(":")
-        module = importlib.import_module(module_path)
-        __cache_cls = getattr(module, clz_nme)
-        __cache_cls_nme = user_cache_manager
-
-    return __cache_cls(_base32(key))
+    cls = knobs.cache.manager_class or FileCacheManager
+    return cls(_base32(key))


 def get_override_manager(key) -> CacheManager:
-    return __cache_cls(_base32(key), override=True)
+    cls = knobs.cache.manager_class or FileCacheManager
+    return cls(_base32(key), override=True)


 def get_dump_manager(key) -> CacheManager:
-    return __cache_cls(_base32(key), dump=True)
+    cls = knobs.cache.manager_class or FileCacheManager
+    return cls(_base32(key), dump=True)


 def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
@@ -301,3 +274,44 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
         key = f"{key}-{kwargs.get(kw)}"
     key = hashlib.sha256(key.encode("utf-8")).hexdigest()
     return _base32(key)
+
+
+@functools.lru_cache()
+def triton_key():
+    import pkgutil
+    TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    contents = []
+    # frontend
+    with open(__file__, "rb") as f:
+        contents += [hashlib.sha256(f.read()).hexdigest()]
+    # compiler
+    path_prefixes = [
+        (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
+        (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
+    ]
+    for path, prefix in path_prefixes:
+        for lib in pkgutil.walk_packages([path], prefix=prefix):
+            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
+                contents += [hashlib.sha256(f.read()).hexdigest()]
+
+    # backend
+    libtriton_hash = hashlib.sha256()
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
+        while True:
+            chunk = f.read(1024**2)
+            if not chunk:
+                break
+            libtriton_hash.update(chunk)
+    contents.append(libtriton_hash.hexdigest())
+    # language
+    language_path = os.path.join(TRITON_PATH, 'language')
+    for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
+        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
+            contents += [hashlib.sha256(f.read()).hexdigest()]
+    return f'{__version__}' + '-'.join(contents)
+
+
+def get_cache_key(src, backend, backend_options, env_vars):
+    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}"
+    return key
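As a usage sketch of the reworked cache plumbing: a custom cache manager is now supplied through the knobs layer instead of the old module-level __cache_cls globals. The subclass below is hypothetical, and it assumes knobs.cache.manager_class can be assigned at runtime (the TRITON_CACHE_MANAGER environment variable is expected to keep working through the same knob):

import triton
from triton.runtime.cache import FileCacheManager


class LoggingCacheManager(FileCacheManager):
    """File-backed cache that logs every artifact it stores."""

    def put(self, data, filename, binary=True):
        print(f"[triton-cache] writing {filename} under {self.cache_dir}")
        return super().put(data, filename, binary)


# get_cache_manager(), get_override_manager() and get_dump_manager() pick this
# class up in place of the default FileCacheManager.
triton.knobs.cache.manager_class = LoggingCacheManager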