triton-windows 3.1.0.post17__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (248)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/runtime/autotuner.py
@@ -0,0 +1,361 @@
+ from __future__ import annotations
+
+ import builtins
+ import os
+ import time
+ import inspect
+ from typing import Dict
+
+ from ..testing import do_bench, do_bench_cudagraph
+ from .jit import KernelInterface
+ from .errors import OutOfResources
+
+
+ class Autotuner(KernelInterface):
+
+     def __init__(
+         self,
+         fn,
+         arg_names,
+         configs,
+         key,
+         reset_to_zero,
+         restore_value,
+         pre_hook=None,
+         post_hook=None,
+         prune_configs_by: Dict = None,
+         warmup=25,
+         rep=100,
+         use_cuda_graph=False,
+     ):
+         """
+         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+             'perf_model': performance model used to predict the running time with different configs, returns running time
+             'top_k': number of configs to bench
+             'early_config_prune'(optional): a function used to prune configs early (e.g. by num_stages). It takes configs: List[Config] as its input, and returns pruned configs.
+         """
+         if not configs:
+             self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)]
+         else:
+             self.configs = configs
+         self.key_idx = [arg_names.index(k) for k in key]
+         self.cache = {}
+         self.arg_names = arg_names
+
+         # Reset to zero or restore values
+         self.reset_idx = []
+         if reset_to_zero is not None:
+             self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+         self.restore_idx = []
+         if restore_value is not None:
+             self.restore_idx = [arg_names.index(k) for k in restore_value]
+
+         # Hook to reset or restore for required tensors
+         self.pre_hook = lambda args, reset_only=False: 0
+         self.post_hook = lambda args, exception: 0
+         if pre_hook:
+             self.pre_hook = pre_hook
+         elif (len(self.reset_idx) > 0 or len(self.restore_idx) > 0):
+
+             def _pre_hook(args, reset_only=False):
+                 for i in self.reset_idx:
+                     args[i].zero_()
+                 if not reset_only:
+                     self.restore_copies = [args[i].clone() for i in self.restore_idx]
+
+             self.pre_hook = _pre_hook
+
+         if post_hook:
+             self.post_hook = post_hook
+         elif len(self.restore_idx) > 0:
+
+             def _post_hook(args, exception):
+                 for i, j in enumerate(self.restore_idx):
+                     args[j].copy_(self.restore_copies[i])
+                 self.restore_copies = []
+
+             self.post_hook = _post_hook
+
+         self.perf_model = None
+         self.configs_top_k = 1.0
+         self.early_config_prune = None
+         if prune_configs_by:
+             self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
+             self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
+             self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune)
+
+         self.fn = fn
+         self.base_fn = fn
+         while not inspect.isfunction(self.base_fn):
+             self.base_fn = self.base_fn.fn
+         self.num_warmups = warmup
+         self.num_reps = rep
+         import torch
+         self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
+
+     def _bench(self, *args, config, **meta):
+         from ..compiler.errors import CompileTimeAssertionFailure
+
+         # check for conflicts, i.e. meta-parameters both provided
+         # as kwargs and by the autotuner
+         conflicts = meta.keys() & config.kwargs.keys()
+         if conflicts:
+             raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
+                              " Make sure that you don't re-define auto-tuned symbols.")
+         # augment meta-parameters with tunable ones
+         current = dict(meta, **config.all_kwargs())
+         full_nargs = {**self.nargs, **current}
+
+         def kernel_call():
+             if config.pre_hook:
+                 config.pre_hook(full_nargs)
+             self.pre_hook(args)
+             try:
+                 self.fn.run(
+                     *args,
+                     **current,
+                 )
+             except Exception as e:
+                 try:
+                     self.post_hook(args, exception=e)
+                 finally:
+                     # Throw exception raised by `self.fn.run`
+                     raise
+
+             self.post_hook(args, exception=None)
+
+         try:
+             if self.use_cuda_graph:
+                 import torch
+                 with torch.cuda.stream(torch.cuda.Stream()):
+                     bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median")
+                 return bench_res
+             return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
+         except (OutOfResources, CompileTimeAssertionFailure):
+             return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")]
+
+     def run(self, *args, **kwargs):
+         self.nargs = dict(zip(self.arg_names, args))
+         used_cached_result = True
+         if len(self.configs) > 1:
+             all_args = {**self.nargs, **kwargs}
+             _args = []
+             for name in self.arg_names:
+                 if name in all_args:
+                     _args.append(all_args[name])
+             key = [_args[i] for i in self.key_idx]
+             for arg in _args:
+                 if hasattr(arg, "dtype"):
+                     key.append(str(arg.dtype))
+             key = tuple(key)
+             if key not in self.cache:
+                 # prune configs
+                 used_cached_result = False
+                 pruned_configs = self.prune_configs(kwargs)
+                 bench_start = time.time()
+                 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+                 bench_end = time.time()
+                 self.bench_time = bench_end - bench_start
+                 self.cache[key] = builtins.min(timings, key=timings.get)
+                 self.pre_hook(args, reset_only=True)
+                 self.configs_timings = timings
+             config = self.cache[key]
+         else:
+             config = self.configs[0]
+         self.best_config = config
+         if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
+             print(f"Triton autotuning for function {self.base_fn.__name__} finished after "
+                   f"{self.bench_time:.2f}s; best config selected: {self.best_config};")
+         if config.pre_hook is not None:
+             config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
+         ret = self.fn.run(
+             *args,
+             **kwargs,
+             **config.all_kwargs(),
+         )
+         self.nargs = None
+         return ret
+
+     def prune_configs(self, kwargs):
+         pruned_configs = self.configs
+         if self.early_config_prune:
+             pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+         if self.perf_model:
+             top_k = self.configs_top_k
+             if isinstance(top_k, float) and top_k <= 1.0:
+                 top_k = int(len(self.configs) * top_k)
+             if len(pruned_configs) > top_k:
+                 est_timing = {
+                     config: self.perf_model(
+                         **self.nargs,
+                         **kwargs,
+                         **config.all_kwargs(),
+                     )
+                     for config in pruned_configs
+                 }
+                 pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+         return pruned_configs
+
+     def warmup(self, *args, **kwargs):
+         self.nargs = dict(zip(self.arg_names, args))
+         ret = []
+         for config in self.prune_configs(kwargs):
+             ret.append(self.fn.warmup(
+                 *args,
+                 **kwargs,
+                 **config.all_kwargs(),
+             ))
+         self.nargs = None
+         return ret
+
+
+ class Config:
+     """
+     An object that represents a possible kernel configuration for the auto-tuner to try.
+
+     :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
+     :type kwargs: dict[str, Any]
+     :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
+                      `num_warps=8`, then each kernel instance will be automatically parallelized to
+                      cooperatively execute using `8 * 32 = 256` threads.
+     :type num_warps: int
+     :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
+                       Mostly useful for matrix multiplication workloads on SM80+ GPUs.
+     :type num_ctas: int
+     :ivar num_ctas: number of blocks in a block cluster. SM90+ only.
+     :type maxnreg: Optional[int]
+     :ivar maxnreg: maximum number of registers one thread can use. Corresponds
+                    to ptx .maxnreg directive. Not supported on all platforms.
+     :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
+                     function are args.
+     """
+
+     def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None):
+         self.kwargs = kwargs
+         self.num_warps = num_warps
+         self.num_ctas = num_ctas
+         self.num_stages = num_stages
+         self.maxnreg = maxnreg
+         self.pre_hook = pre_hook
+
+     def all_kwargs(self):
+         return {
+             **self.kwargs, **{
+                 k: v
+                 for (k, v) in (
+                     ("num_warps", self.num_warps),
+                     ("num_ctas", self.num_ctas),
+                     ("num_stages", self.num_stages),
+                     ("maxnreg", self.maxnreg),
+                 ) if v is not None
+             }
+         }
+
+     def __str__(self):
+         res = []
+         for k, v in self.kwargs.items():
+             res.append(f"{k}: {v}")
+         res.append(f"num_warps: {self.num_warps}")
+         res.append(f"num_ctas: {self.num_ctas}")
+         res.append(f"num_stages: {self.num_stages}")
+         res.append(f"maxnreg: {self.maxnreg}")
+         return ", ".join(res)
+
+
+ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
+              warmup=25, rep=100, use_cuda_graph=False):
+     """
+     Decorator for auto-tuning a :code:`triton.jit`'d function.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         @triton.autotune(configs=[
+             triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
+             triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
+           ],
+           key=['x_size']  # the two above configs will be evaluated anytime
+                           # the value of x_size changes
+         )
+         @triton.jit
+         def kernel(x_ptr, x_size, **META):
+             BLOCK_SIZE = META['BLOCK_SIZE']
+     :note: When all the configurations are evaluated, the kernel will run multiple times.
+            This means that whatever value the kernel updates will be updated multiple times.
+            To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
+            resets the value of the provided tensor to `zero` before running any configuration.
+
+     If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to
+     :code:`"1"`, Triton will print a message to stdout after autotuning each
+     kernel, including the time spent autotuning and the best configuration.
+
+     :param configs: a list of :code:`triton.Config` objects
+     :type configs: list[triton.Config]
+     :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+     :type key: list[str]
+     :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+         'perf_model': performance model used to predict the running time with different configs, returns running time
+         'top_k': number of configs to bench
+         'early_config_prune'(optional): a function used to prune configs early (e.g. by num_stages). It takes configs: List[Config] as its input, and returns pruned configs.
+     :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
+     :type reset_to_zero: list[str]
+     :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
+     :type restore_value: list[str]
+     :param pre_hook: a function that will be called before the kernel is called.
+         This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'.
+         'args': a list of arguments passed to the kernel.
+         'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook.
+     :type pre_hook: lambda args, reset_only
+     :param post_hook: a function that will be called after the kernel is called.
+         This overrides the default post_hook used for 'restore_value'.
+         'args': a list of arguments passed to the kernel.
+         'exception': the exception raised by the kernel in case of a compilation or runtime error.
+     :type post_hook: lambda args, exception
+     :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
+     :type warmup: int
+     :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
+     :type rep: int
+     """
+
+     def decorator(fn):
+         return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
+                          post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
+                          use_cuda_graph=use_cuda_graph)
+
+     return decorator
+
+
+ class Heuristics(KernelInterface):
+
+     def __init__(self, fn, arg_names, values) -> None:
+         self.fn = fn
+         self.values = values
+         self.arg_names = arg_names
+
+     def run(self, *args, **kwargs):
+         for v, heur in self.values.items():
+             kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
+         return self.fn.run(*args, **kwargs)
+
+
+ def heuristics(values):
+     """
+     Decorator for specifying how the values of certain meta-parameters may be computed.
+     This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
+         @triton.jit
+         def kernel(x_ptr, x_size, **META):
+             BLOCK_SIZE = META['BLOCK_SIZE']  # smallest power-of-two >= x_size
+     :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
+         each such function takes a list of positional arguments as input.
+     :type values: dict[str, Callable[[list[Any]], Any]]
+     """
+
+     def decorator(fn):
+         return Heuristics(fn, fn.arg_names, values)
+
+     return decorator
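
For orientation, autotuner.py above defines the objects that the package re-exports as triton.autotune, triton.Config, and triton.heuristics. Below is a minimal usage sketch that is not part of the package: a vector-add kernel autotuned over a hypothetical BLOCK_SIZE meta-parameter and keyed on n_elements (all names in the example are illustrative).

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["n_elements"],  # re-benchmark the configs whenever n_elements changes
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    # BLOCK_SIZE is injected by the selected Config, so it is not passed here.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements)
    return out

On the first call for a given key (here, each distinct n_elements plus the dtypes of the tensor arguments), Autotuner.run benchmarks every config via do_bench, caches the fastest one, and reuses it on later calls; setting TRITON_PRINT_AUTOTUNING=1 prints the selected config.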
triton/runtime/build.py
@@ -0,0 +1,129 @@
+ import contextlib
+ import sys
+ import io
+ import sysconfig
+ import os
+ import shutil
+ import subprocess
+ import setuptools
+
+ if os.name == "nt":
+     from triton.windows_utils import find_msvc_winsdk, find_python
+
+
+ @contextlib.contextmanager
+ def quiet():
+     old_stdout, old_stderr = sys.stdout, sys.stderr
+     sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
+     try:
+         yield
+     finally:
+         sys.stdout, sys.stderr = old_stdout, old_stderr
+
+
+ def get_cc():
+     cc = os.environ.get("CC")
+     if cc is None:
+         # Bundled TinyCC
+         cc = os.path.join(sysconfig.get_paths()["platlib"], "triton", "runtime", "tcc", "tcc.exe")
+         if not os.path.exists(cc):
+             cc = None
+     if cc is None:
+         cc = shutil.which("cl")
+     if cc is None:
+         cc = shutil.which("gcc")
+     if cc is None:
+         cc = shutil.which("clang")
+     if cc is None:
+         raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
+     return cc
+
+
+ def is_msvc(cc):
+     cc = os.path.basename(cc).lower()
+     return cc == "cl" or cc == "cl.exe"
+
+
+ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
+     if is_msvc(cc):
+         out_base = os.path.splitext(out)[0]
+         cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/wd4819"]
+         cc_cmd += [f"/I{dir}" for dir in include_dirs if dir is not None]
+         cc_cmd += [f"/Fo{out_base + '.obj'}"]
+         cc_cmd += ["/link"]
+         cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs]
+         cc_cmd += [f'{lib}.lib' for lib in libraries]
+         cc_cmd += [f"/OUT:{out}"]
+         cc_cmd += [f"/IMPLIB:{out_base + '.lib'}"]
+         cc_cmd += [f"/PDB:{out_base + '.pdb'}"]
+     else:
+         # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
+         cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", out]
+         cc_cmd += [f'-l{lib}' for lib in libraries]
+         cc_cmd += [f"-L{dir}" for dir in library_dirs]
+         cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
+     return cc_cmd
+
+
+ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
+     suffix = sysconfig.get_config_var('EXT_SUFFIX')
+     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
+     # try to avoid setuptools if possible
+     cc = get_cc()
+     # This function was renamed and made public in Python 3.10
+     if hasattr(sysconfig, 'get_default_scheme'):
+         scheme = sysconfig.get_default_scheme()
+     else:
+         scheme = sysconfig._get_default_scheme()
+     # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
+     # path changes to include 'local'. This change is required to use triton with system-wide python.
+     if scheme == 'posix_local':
+         scheme = 'posix_prefix'
+     py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
+     include_dirs = include_dirs + [srcdir, py_include_dir]
+     if os.name == "nt":
+         library_dirs += find_python()
+         # Link against Python stable ABI
+         # libraries is modified in place
+         if "python3" not in libraries:
+             libraries += ["python3"]
+         if is_msvc(cc):
+             msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
+             include_dirs += msvc_winsdk_inc_dirs
+             library_dirs += msvc_winsdk_lib_dirs
+     cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
+     ret = subprocess.check_call(cc_cmd)
+     if ret == 0:
+         return so
+     # fallback on setuptools
+     extra_compile_args = []
+     if is_msvc(cc):
+         extra_compile_args += ["/O2"]
+     else:
+         extra_compile_args += ["-O3"]
+     # extra arguments
+     extra_link_args = []
+     # create extension module
+     ext = setuptools.Extension(
+         name=name,
+         language='c',
+         sources=[src],
+         include_dirs=include_dirs,
+         extra_compile_args=extra_compile_args,
+         extra_link_args=extra_link_args,
+         library_dirs=library_dirs,
+         libraries=libraries,
+     )
+     # build extension module
+     args = ['build_ext']
+     args.append('--build-temp=' + srcdir)
+     args.append('--build-lib=' + srcdir)
+     args.append('-q')
+     args = dict(
+         name=name,
+         ext_modules=[ext],
+         script_args=args,
+     )
+     with quiet():
+         setuptools.setup(**args)
+     return so
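
build.py is the piece that turns a generated launcher source file into a Python extension module at runtime: get_cc() picks a compiler (the CC environment variable, the bundled tcc.exe under triton/runtime/tcc, or cl/gcc/clang from PATH), _cc_cmd() assembles the command line, and _build() runs it, with a setuptools-based section also present. A minimal sketch of inspecting that command line, using made-up cache paths purely for illustration:

# A sketch only: the real source/output paths and the include/library directories are
# produced at runtime by the driver and by find_python() / find_msvc_winsdk().
from triton.runtime.build import _cc_cmd, get_cc, is_msvc

cc = get_cc()  # CC env var, bundled tcc.exe, or cl/gcc/clang found on PATH
cmd = _cc_cmd(
    cc,
    src=r"C:\Users\me\.triton\cache\abc123\main.c",           # hypothetical generated launcher
    out=r"C:\Users\me\.triton\cache\abc123\cuda_utils.pyd",   # hypothetical extension module
    include_dirs=[r"C:\Python310\include"],
    library_dirs=[r"C:\Python310\libs"],
    libraries=["python3"],
)
print(is_msvc(cc), cmd)
# With tcc.exe (or gcc/clang) the list is roughly:
#   <cc> main.c -O3 -shared -fPIC -Wno-psabi -o cuda_utils.pyd -lpython3 -LC:\Python310\libs -IC:\Python310\include ...
# With cl.exe the MSVC form is used instead: /LD ... /link /LIBPATH:... python3.lib /OUT:...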