triton-windows 3.1.0.post17__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (248)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/backends/compiler.py
@@ -0,0 +1,76 @@
+ import os
+ import re
+ import subprocess
+
+ from abc import ABCMeta, abstractmethod, abstractclassmethod
+ from dataclasses import dataclass
+ from typing import Union
+
+
+ @dataclass(frozen=True)
+ class GPUTarget(object):
+     # Target backend, e.g., cuda, hip
+     backend: str
+     # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip)
+     arch: Union[int, str]
+     warp_size: int
+
+
+ class BaseBackend(metaclass=ABCMeta):
+
+     def __init__(self, target: GPUTarget) -> None:
+         self.target = target
+         assert self.supports_target(target)
+
+     @staticmethod
+     def _path_to_binary(binary: str):
+         base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
+         paths = [
+             os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
+             os.path.join(base_dir, "third_party", "cuda", "bin", binary),
+         ]
+         for p in paths:
+             bin = p.split(" ")[0]
+             if os.path.exists(bin) and os.path.isfile(bin):
+                 result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
+                 if result is not None:
+                     version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
+                     if version is not None:
+                         return p, version.group(1)
+         raise RuntimeError(f"Cannot find {binary}")
+
+     @abstractclassmethod
+     def supports_target(target: GPUTarget):
+         raise NotImplementedError
+
+     @abstractmethod
+     def hash(self) -> str:
+         """Returns a unique identifier for this backend"""
+         raise NotImplementedError
+
+     @abstractmethod
+     def parse_options(self, options: dict) -> object:
+         """
+         Converts an `options` dictionary into an arbitrary object and returns it.
+         This function may contain target-specific heuristics and check the legality of the provided options
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def add_stages(self, stages: dict, options: object) -> None:
+         """
+         Populates `stages` dictionary with entries of the form:
+         ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
+         The value of each entry may populate a `metadata` dictionary.
+         Stages will be run sequentially (in insertion order) and can communicate using `metadata`.
+         All stages are expected to return a `str` object, except for the last stage which returns
+         a `bytes` object for execution by the launcher.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def load_dialects(self, context):
+         """
+         Load additional MLIR dialects into the provided `context`
+         """
+         raise NotImplementedError
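
The abstract interface above is all a backend plugin has to implement. As a rough illustration only (not part of this package; the class name and target values below are made up), a minimal subclass could look like the following, with the last stage returning bytes as the add_stages docstring requires:

    from triton.backends.compiler import BaseBackend, GPUTarget

    class EchoBackend(BaseBackend):
        # Hypothetical example backend: it does no real compilation and only
        # demonstrates the contract described in the docstrings above.

        @staticmethod
        def supports_target(target: GPUTarget):
            return target.backend == "echo"

        def hash(self) -> str:
            return "echo-backend-v0"

        def parse_options(self, options: dict) -> object:
            # No target-specific heuristics here; just copy what was given.
            return dict(options)

        def load_dialects(self, context):
            pass  # no extra MLIR dialects needed

        def add_stages(self, stages: dict, options: object) -> None:
            # Stages run in insertion order; every stage but the last returns str,
            # the last returns bytes for the launcher.
            stages["ir"] = lambda src, metadata: src
            stages["binary"] = lambda src, metadata: src.encode("utf-8")

    target = GPUTarget(backend="echo", arch=0, warp_size=32)
    backend = EchoBackend(target)  # __init__ asserts supports_target(target)
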
triton/backends/driver.py
@@ -0,0 +1,34 @@
+ from abc import ABCMeta, abstractmethod, abstractclassmethod
+
+
+ class DriverBase(metaclass=ABCMeta):
+
+     @abstractclassmethod
+     def is_active(self):
+         pass
+
+     @abstractmethod
+     def get_current_target(self):
+         pass
+
+     def __init__(self) -> None:
+         pass
+
+
+ class GPUDriver(DriverBase):
+
+     def __init__(self):
+         # TODO: support other frameworks than torch
+         import torch
+         self.get_device_capability = torch.cuda.get_device_capability
+         try:
+             from torch._C import _cuda_getCurrentRawStream
+             self.get_current_stream = _cuda_getCurrentRawStream
+         except ImportError:
+             self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
+         self.get_current_device = torch.cuda.current_device
+         self.set_current_device = torch.cuda.set_device
+
+     # TODO: remove once TMA is cleaned up
+     def assemble_tensormap_to_arg(self, tensormaps_info, args):
+         return args
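
GPUDriver only wires up torch's CUDA helpers; concrete drivers (such as triton/backends/nvidia/driver.py listed above) still provide is_active and get_current_target themselves. A minimal sketch of such a subclass, assuming torch is installed (the class name and details below are illustrative, not the package's actual implementation):

    import torch
    from triton.backends.compiler import GPUTarget
    from triton.backends.driver import GPUDriver

    class ExampleCudaDriver(GPUDriver):
        # Hypothetical driver, for illustration only.

        @classmethod
        def is_active(cls):
            # Only active when a CUDA device is actually usable.
            return torch.cuda.is_available()

        def get_current_target(self):
            # Encode compute capability (e.g. 8.6 -> 86) as the integer arch,
            # matching the GPUTarget convention documented above.
            major, minor = self.get_device_capability(self.get_current_device())
            return GPUTarget(backend="cuda", arch=major * 10 + minor, warp_size=32)
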
triton/backends/nvidia/__init__.py: File without changes
triton/backends/nvidia/bin/ptxas.exe: Binary file
triton/backends/nvidia/compiler.py
@@ -0,0 +1,347 @@
+ from triton.backends.compiler import BaseBackend, GPUTarget
+ from triton._C.libtriton import ir, passes, llvm, nvidia
+
+ from dataclasses import dataclass
+ import functools
+ from typing import Any, Tuple, Optional
+ import hashlib
+ import re
+ import tempfile
+ import signal
+ import os
+ import subprocess
+ from pathlib import Path
+
+
+ @functools.lru_cache()
+ def _path_to_binary(binary: str):
+     paths = [
+         os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
+     ]
+     if os.name == "nt":
+         binary += ".exe"
+     paths += [
+         os.path.join(os.path.dirname(__file__), "bin", binary),
+     ]
+     if os.name == "nt":
+         from triton.windows_utils import find_cuda
+         cuda_bin_path, _, _ = find_cuda()
+         if cuda_bin_path:
+             paths += [os.path.join(cuda_bin_path, binary)]
+
+     for bin in paths:
+         if os.path.exists(bin) and os.path.isfile(bin):
+             result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
+             if result is not None:
+                 version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
+                 if version is not None:
+                     return bin, version.group(1)
+     raise RuntimeError(f"Cannot find {binary}")
+
+
+ @functools.lru_cache()
+ def get_ptxas_version():
+     version = subprocess.check_output([_path_to_binary("ptxas")[0], "--version"]).decode("utf-8")
+     return version
+
+
+ @functools.lru_cache()
+ def ptx_get_version(cuda_version) -> int:
+     '''
+     Get the highest PTX version supported by the current CUDA driver.
+     '''
+     assert isinstance(cuda_version, str)
+     major, minor = map(int, cuda_version.split('.'))
+     if major == 12:
+         if minor < 6:
+             return 80 + minor
+         else:
+             return 79 + minor
+     if major == 11:
+         return 70 + minor
+     if major == 10:
+         return 63 + minor
+     raise RuntimeError("Triton only supports CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
+
+
+ @functools.lru_cache(None)
+ def file_hash(path):
+     with open(path, "rb") as f:
+         return hashlib.sha256(f.read()).hexdigest()
+
+
+ # The file may be accessed in parallel
+ def try_remove(path):
+     if os.path.exists(path):
+         try:
+             os.remove(path)
+         except OSError:
+             import traceback
+             traceback.print_exc()
+
+
+ @dataclass(frozen=True)
+ class CUDAOptions:
+     num_warps: int = 4
+     num_ctas: int = 1
+     num_stages: int = 3
+     # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
+     # maximum number of 32-bit registers used by one thread.
+     maxnreg: Optional[int] = None
+     cluster_dims: tuple = (1, 1, 1)
+     ptx_version: int = None
+     enable_fp_fusion: bool = True
+     allow_fp8e4nv: bool = False
+     allow_fp8e4b15: bool = False
+     default_dot_input_precision: str = "tf32"
+     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
+     max_num_imprecise_acc_default: bool = None
+     extern_libs: dict = None
+     debug: bool = False
+     backend_name: str = 'cuda'
+
+     def __post_init__(self):
+         default_libdir = Path(__file__).parent / 'lib'
+         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
+         if not extern_libs.get('libdevice', None):
+             extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.10.bc'))
+         object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
+         assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
+             "num_warps must be a power of 2"
+
+     def hash(self):
+         hash_dict = dict(self.__dict__)
+         hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"]))
+         key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
+         return hashlib.sha256(key.encode("utf-8")).hexdigest()
+
+
+ class CUDABackend(BaseBackend):
+
+     @staticmethod
+     def supports_target(target: GPUTarget):
+         return target.backend == 'cuda'
+
+     def __init__(self, target: GPUTarget) -> None:
+         super().__init__(target)
+         self.capability = target.arch
+         assert isinstance(self.capability, int)
+         self.binary_ext = "cubin"
+
+     def parse_options(self, opts) -> Any:
+         args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
+         args["allow_fp8e4nv"] = self.capability >= 89
+         args["allow_fp8e4b15"] = self.capability < 90
+         args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
+         return CUDAOptions(**args)
+
+     def pack_metadata(self, metadata):
+         return (
+             metadata.num_warps,
+             metadata.num_ctas,
+             metadata.shared,
+             metadata.cluster_dims[0],
+             metadata.cluster_dims[1],
+             metadata.cluster_dims[2],
+         )
+
+     def get_codegen_implementation(self):
+         import triton.language.extra.cuda as cuda
+         codegen_fns = {
+             "convert_custom_types":
+             cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70
+         }
+         return codegen_fns
+
+     def load_dialects(self, ctx):
+         nvidia.load_dialects(ctx)
+
+     @staticmethod
+     def make_ttir(mod, metadata, opt):
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+         passes.common.add_inliner(pm)
+         passes.ttir.add_rewrite_tensor_pointer(pm)
+         passes.ttir.add_combine(pm)
+         passes.common.add_canonicalizer(pm)
+         passes.ttir.add_reorder_broadcast(pm)
+         passes.common.add_cse(pm)
+         passes.common.add_licm(pm)
+         passes.common.add_symbol_dce(pm)
+         pm.run(mod)
+         return mod
+
+     @staticmethod
+     def make_ttgir(mod, metadata, opt, capability):
+         cluster_info = nvidia.ClusterInfo()
+         if opt.cluster_dims is not None:
+             cluster_info.clusterDimX = opt.cluster_dims[0]
+             cluster_info.clusterDimY = opt.cluster_dims[1]
+             cluster_info.clusterDimZ = opt.cluster_dims[2]
+         # TTIR -> TTGIR
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+         passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
+         # optimize TTGIR
+         passes.ttgpuir.add_coalesce(pm)
+         if capability // 10 >= 8:
+             passes.ttgpuir.add_f32_dot_tc(pm)
+         # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
+         nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
+         passes.ttgpuir.add_remove_layout_conversions(pm)
+         passes.ttgpuir.add_optimize_thread_locality(pm)
+         passes.ttgpuir.add_accelerate_matmul(pm)
+         passes.ttgpuir.add_remove_layout_conversions(pm)
+         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+         passes.common.add_cse(pm)
+         if capability // 10 >= 8:
+             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+             passes.ttgpuir.add_pipeline(pm, opt.num_stages)
+         passes.ttgpuir.add_prefetch(pm)
+         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+         passes.ttgpuir.add_remove_layout_conversions(pm)
+         passes.ttgpuir.add_reduce_data_duplication(pm)
+         passes.ttgpuir.add_reorder_instructions(pm)
+         passes.common.add_cse(pm)
+         passes.common.add_symbol_dce(pm)
+         if capability // 10 >= 9:
+             nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
+         nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
+         passes.common.add_canonicalizer(pm)
+         pm.run(mod)
+         metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
+         return mod
+
+     @staticmethod
+     def make_llir(src, metadata, options, capability):
+         # warp-specialization mutates num_warps
+         num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
+         if num_warp_groups is not None:
+             metadata["num_warps"] *= num_warp_groups
+         mod = src
+         # TritonGPU -> LLVM-IR (MLIR)
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+         nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
+         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+         passes.convert.add_scf_to_cf(pm)
+         passes.convert.add_index_to_llvmir(pm)
+         passes.ttgpuir.add_allocate_shared_memory(pm)
+         nvidia.passes.ttgpuir.add_to_llvmir(pm, capability)
+         nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
+         passes.convert.add_arith_to_llvmir(pm)
+         passes.common.add_canonicalizer(pm)
+         passes.common.add_cse(pm)
+         passes.common.add_symbol_dce(pm)
+         if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
+             passes.llvmir.add_di_scope(pm)
+         pm.run(mod)
+         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
+         llvm.init_targets()
+         context = llvm.context()
+         llvm_mod = llvm.to_module(mod, context)
+         nvidia.set_nvvm_reflect_ftz(llvm_mod)
+
+         # Set maxnreg on all kernels, if it was provided.
+         if options.maxnreg is not None:
+             for k in llvm_mod.get_functions():
+                 if not k.is_declaration() and k.is_external_linkage():
+                     k.set_nvvm_maxnreg(options.maxnreg)
+
+         if options.extern_libs:
+             paths = [path for (name, path) in options.extern_libs]
+             llvm.link_extern_libs(llvm_mod, paths)
+
+         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
+
+         # Get some metadata
+         metadata["shared"] = src.get_int_attr("triton_gpu.shared")
+         ret = str(llvm_mod)
+         del llvm_mod
+         del context
+         return ret
+
+     @staticmethod
+     def make_ptx(src, metadata, opt, capability):
+         ptx_version = opt.ptx_version
+         if ptx_version is None:
+             _, cuda_version = _path_to_binary("ptxas")
+             ptx_version = ptx_get_version(cuda_version)
+
+         # PTX 8.3 is the max version supported by llvm 3a83162168.
+         #
+         # To check if a newer PTX version is supported, increase this value
+         # and run a test. If it's not supported, LLVM will print a warning
+         # like "+ptx8.4 is not a recognized feature for this target".
+         llvm_ptx_version = min(83, ptx_version)
+
+         triple = 'nvptx64-nvidia-cuda'
+         proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
+         features = f'+ptx{llvm_ptx_version}'
+         ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
+         # Find kernel names (there should only be one)
+         names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
+         assert len(names) == 1
+         metadata["name"] = names[0]
+         # post-process
+         ptx_version = f'{ptx_version//10}.{ptx_version%10}'
+         ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
+         # Remove the debug flag that prevents ptxas from optimizing the code
+         ret = re.sub(r",\s*debug|debug,\s*", "", ret)
+         if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
+             print("// -----// NVPTX Dump //----- //")
+             print(ret)
+         return ret
+
+     @staticmethod
+     def make_cubin(src, metadata, opt, capability):
+         ptxas, _ = _path_to_binary("ptxas")
+         with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
+                 tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
+             fsrc.write(src)
+             fsrc.flush()
+             fbin = fsrc.name + '.o'
+
+             line_info = [] if os.environ.get('TRITON_DISABLE_LINE_INFO') else ['-lineinfo']
+             fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
+             suffix = 'a' if capability == 90 else ''
+             opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else []
+             ptxas_cmd = [
+                 ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin
+             ]
+             try:
+                 # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
+                 # On Windows, both stdout and stderr need to be redirected to flog
+                 subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog, stderr=flog)
+             except subprocess.CalledProcessError as e:
+                 with open(flog.name) as log_file:
+                     log = log_file.read()
+
+                 if e.returncode == 255:
+                     raise RuntimeError(f'Internal Triton PTX codegen error: \n{log}')
+                 elif e.returncode == 128 + signal.SIGSEGV:
+                     raise RuntimeError(
+                         f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}')
+                 else:
+                     raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}')
+
+             with open(fbin, 'rb') as f:
+                 cubin = f.read()
+             try_remove(fbin)
+
+         # It's better to remove the temp files outside the context managers
+         try_remove(fsrc.name)
+         try_remove(flog.name)
+         return cubin
+
+     def add_stages(self, stages, options):
+         stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
+         stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
+         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
+         stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
+         stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)
+
+     @functools.lru_cache()
+     def hash(self):
+         version = get_ptxas_version()
+         return f'{version}-{self.capability}'
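
Taken together, add_stages above defines a five-step lowering pipeline (ttir -> ttgir -> llir -> ptx -> cubin). A rough sketch, for illustration only, of how a caller might drive those stages in insertion order (the actual driver is triton/compiler/compiler.py listed above); the PTX-version numbers in the trailing comment simply follow the arithmetic in ptx_get_version:

    def run_stages(backend, src, options):
        # Hypothetical helper, not part of the package: threads a shared metadata
        # dict through each stage, as the BaseBackend.add_stages docstring describes.
        stages, metadata = {}, {}
        backend.add_stages(stages, options)
        out = src
        for name, stage in stages.items():  # insertion order: ttir, ttgir, llir, ptx, cubin
            out = stage(out, metadata)
        return out, metadata  # the final stage yields bytes (the cubin)

    # ptx_get_version examples: CUDA "11.8" -> 70 + 8 = 78 (PTX 7.8),
    # CUDA "12.4" -> 80 + 4 = 84 (PTX 8.4), CUDA "12.6" -> 79 + 6 = 85 (PTX 8.5).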