triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
triton/backends/nvidia/compiler.py
@@ -0,0 +1,533 @@
+ from triton.backends.compiler import BaseBackend, GPUTarget, Language
+ from triton._C.libtriton import ir, passes, llvm, nvidia
+ from triton import knobs
+ from triton.runtime.errors import PTXASError
+
+ from dataclasses import dataclass
+ import functools
+ from typing import Any, Dict, Tuple, Optional
+ from types import ModuleType
+ import hashlib
+ import re
+ import tempfile
+ import signal
+ import os
+ import subprocess
+ from pathlib import Path
+
+
+ def min_dot_size(target: GPUTarget):
+
+     def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:  # [m, n, k]
+         lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
+         rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
+         assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
+         # For small M/N, we can still use tensor cores with padding.
+         if lhs_bitwidth == 8:
+             return (1, 1, 32)
+         else:
+             return (1, 1, 16)
+
+     return check_dot_compatibility
+
+
+ def get_ptxas() -> knobs.NvidiaTool:
+     return knobs.nvidia.ptxas
+
+
+ @functools.lru_cache()
+ def get_ptxas_version():
+     mock_ver = knobs.nvidia.mock_ptx_version
+     if mock_ver is not None:
+         return mock_ver  # This is not really a version of ptxas, but it is good enough for testing
+     version = subprocess.check_output([get_ptxas().path, "--version"]).decode("utf-8")
+     return version
+
+
+ @functools.lru_cache()
+ def ptx_get_version(cuda_version) -> int:
+     '''
+     Get the highest PTX version supported by the current CUDA driver.
+     '''
+     assert isinstance(cuda_version, str)
+     major, minor = map(int, cuda_version.split('.'))
+     if major == 12:
+         if minor < 6:
+             return 80 + minor
+         else:
+             return 80 + minor - 1
+     if major == 11:
+         return 70 + minor
+     if major == 10:
+         return 63 + minor
+
+     if major >= 13:
+         base_ptx = 90
+         return base_ptx + (major - 13) * 10 + minor
+
+     raise RuntimeError("Triton only supports CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
+
+
+ def get_ptx_version_from_options(options, arch: int):
+     ptx_version = options.ptx_version
+     if ptx_version is None:
+         cuda_version = get_ptxas().version
+         ptx_version = ptx_get_version(cuda_version)
+     return ptx_version
+
+
+ @functools.lru_cache()
+ def get_features(options, arch: int):
+     ptx_version = get_ptx_version_from_options(options, arch)
+
+     # PTX 8.6 is the max version supported by llvm c1188642.
+     #
+     # To check if a newer PTX version is supported, increase this value
+     # and run a test. If it's not supported, LLVM will print a warning
+     # like "+ptx8.4 is not a recognized feature for this target".
+     llvm_ptx_version = min(86, ptx_version)
+     features = f'+ptx{llvm_ptx_version}'
+     return features
+
+
+ @functools.lru_cache(None)
+ def file_hash(path):
+     with open(path, "rb") as f:
+         return hashlib.sha256(f.read()).hexdigest()
+
+
+ def sm_arch_from_capability(capability: int):
+     # TODO: Handle non-"a" sms
+     suffix = "a" if capability >= 90 else ""
+     return f"sm_{capability}{suffix}"
+
+
+ # The file may be accessed in parallel
+ def try_remove(path):
+     if os.path.exists(path):
+         try:
+             os.remove(path)
+         except OSError:
+             import traceback
+             traceback.print_exc()
+
+
+ @dataclass(frozen=True)
+ class CUDAOptions:
+     num_warps: int = 4
+     num_ctas: int = 1
+     num_stages: int = 3
+     warp_size: int = 32
+     # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
+     # maximum number of 32-bit registers used by one thread.
+     maxnreg: Optional[int] = None
+     cluster_dims: tuple = (1, 1, 1)
+     ptx_version: int = None
+     ptx_options: str = None
+     ir_override: Optional[str] = None  # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
+     enable_fp_fusion: bool = True
+     launch_cooperative_grid: bool = False
+     launch_pdl: bool = False
+     supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e4b15")
+     deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
+     default_dot_input_precision: str = "tf32"
+     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
+     max_num_imprecise_acc_default: bool = None
+     extern_libs: dict = None
+     debug: bool = False
+     backend_name: str = 'cuda'
+     sanitize_overflow: bool = True
+     arch: str = None
+     instrumentation_mode: str = ""
+
+     def __post_init__(self):
+         default_libdir = Path(__file__).parent / 'lib'
+         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
+         if not extern_libs.get('libdevice', None):
+             extern_libs['libdevice'] = knobs.nvidia.libdevice_path or str(default_libdir / 'libdevice.10.bc')
+
+         object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
+         assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
+             "num_warps must be a power of 2"
+
+     def hash(self):
+         hash_dict = dict(self.__dict__)
+         hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"]))
+         key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
+         return hashlib.sha256(key.encode("utf-8")).hexdigest()
+
+
+ class CUDABackend(BaseBackend):
+     instrumentation = None
+
+     @staticmethod
+     def supports_target(target: GPUTarget):
+         return target.backend == 'cuda'
+
+     def _parse_arch(self, arch):
+         pattern = r"^sm(\d+)$"
+         match = re.fullmatch(pattern, arch)
+         if not match:
+             raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
+         return int(match.group(1))
+
+     def get_target_name(self, options) -> str:
+         capability = self._parse_arch(options.arch)
+         return f"cuda:{capability}"
+
+     def __init__(self, target: GPUTarget) -> None:
+         super().__init__(target)
+         self.binary_ext = "cubin"
+
+     def parse_options(self, opts) -> Any:
+         args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
+         args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
+         capability = int(self._parse_arch(args["arch"]))
+
+         if args.get("num_ctas", 1) > 1 and capability < 90:
+             raise ValueError((f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
+                               f"Current target is sm_{capability}. This configuration will fail. "
+                               f"Please set num_ctas=1 or target an SM90+ GPU."))
+
+         if "supported_fp8_dtypes" not in args:
+             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
+             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
+
+         if "deprecated_fp8_dot_operand_dtypes" not in args:
+             if capability >= 90:
+                 args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )
+
+         if "enable_fp_fusion" not in args:
+             args["enable_fp_fusion"] = knobs.language.default_fp_fusion
+
+         args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0
+
+         return CUDAOptions(**args)
+
+     def pack_metadata(self, metadata):
+         return (
+             metadata.num_warps,
+             metadata.num_ctas,
+             metadata.shared,
+             metadata.cluster_dims[0],
+             metadata.cluster_dims[1],
+             metadata.cluster_dims[2],
+         )
+
+     def get_codegen_implementation(self, options):
+         import triton.language.extra.cuda as cuda
+         capability = int(self._parse_arch(options.arch))
+         codegen_fns = {
+             "convert_custom_types":
+                 cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70,
+             "min_dot_size": min_dot_size(self.target)
+         }
+         return codegen_fns
+
+     def get_module_map(self) -> Dict[str, ModuleType]:
+         from triton.language.extra.cuda import libdevice
+         return {"triton.language.extra.libdevice": libdevice}
+
+     def load_dialects(self, ctx):
+         nvidia.load_dialects(ctx)
+         if CUDABackend.instrumentation:
+             CUDABackend.instrumentation.load_dialects(ctx)
+
+     @staticmethod
+     def make_ttir(mod, metadata, opt, capability):
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+         passes.common.add_inliner(pm)
+         passes.ttir.add_rewrite_tensor_pointer(pm)
+         if capability // 10 < 9:
+             passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
+         passes.common.add_canonicalizer(pm)
+         passes.ttir.add_combine(pm)
+         passes.ttir.add_reorder_broadcast(pm)
+         passes.common.add_cse(pm)
+         passes.common.add_symbol_dce(pm)
+         passes.ttir.add_loop_unroll(pm)
+         pm.run(mod)
+         return mod
+
+     @staticmethod
+     def make_ttgir(mod, metadata, opt, capability):
+         # Set maxnreg on all kernels, if it was provided.
+         if opt.maxnreg is not None:
+             mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
+
+         cluster_info = nvidia.ClusterInfo()
+         if opt.cluster_dims is not None:
+             cluster_info.clusterDimX = opt.cluster_dims[0]
+             cluster_info.clusterDimY = opt.cluster_dims[1]
+             cluster_info.clusterDimZ = opt.cluster_dims[2]
+         pm = ir.pass_manager(mod.context)
+         dump_enabled = pm.enable_debug()
+         passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
+         # optimize TTGIR
+         passes.ttgpuir.add_coalesce(pm)
+         if capability // 10 >= 8:
+             passes.ttgpuir.add_f32_dot_tc(pm)
+         # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
+         nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
+         passes.ttgpuir.add_remove_layout_conversions(pm)
+         passes.ttgpuir.add_optimize_thread_locality(pm)
+         passes.ttgpuir.add_accelerate_matmul(pm)
+         passes.ttgpuir.add_remove_layout_conversions(pm)
+         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+         nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
+         passes.ttir.add_loop_aware_cse(pm)
+         if capability // 10 in [8, 9]:
+             passes.ttgpuir.add_fuse_nested_loops(pm)
+             passes.common.add_canonicalizer(pm)
+             passes.ttir.add_triton_licm(pm)
+             passes.common.add_canonicalizer(pm)
+             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+             nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
+             passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
+             passes.ttgpuir.add_schedule_loops(pm)
+             passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
+         elif capability // 10 >= 10:
+             passes.ttgpuir.add_fuse_nested_loops(pm)
+             passes.common.add_canonicalizer(pm)
+             passes.ttir.add_triton_licm(pm)
+             passes.ttgpuir.add_optimize_accumulator_init(pm)
+             passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
+             nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
+             passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
+             passes.ttgpuir.add_schedule_loops(pm)
+             passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
+             passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
+             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+             # hoist again and allow hoisting out of if statements
+             passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
+             nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
+         else:
+             passes.ttir.add_triton_licm(pm)
+             passes.common.add_canonicalizer(pm)
+         passes.ttir.add_loop_aware_cse(pm)
+         passes.ttgpuir.add_prefetch(pm)
+         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+         passes.ttgpuir.add_coalesce_async_copy(pm)
+         nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
+         passes.ttgpuir.add_remove_layout_conversions(pm)
+         nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
+         passes.ttgpuir.add_reduce_data_duplication(pm)
+         passes.ttgpuir.add_reorder_instructions(pm)
+         passes.ttir.add_loop_aware_cse(pm)
+         passes.common.add_symbol_dce(pm)
+         if capability // 10 >= 9:
+             nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
+             nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
+         nvidia.passes.ttnvgpuir.add_lower_mma(pm)
+         passes.common.add_sccp(pm)
+         passes.common.add_cse(pm)
+         passes.common.add_canonicalizer(pm)
+
+         pm.run(mod)
+         metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
+         tensordesc_meta = mod.get_tensordesc_metadata()
+         metadata["tensordesc_meta"] = tensordesc_meta
+         return mod
+
+     def gluon_to_ttgir(self, src, metadata, options, capability):
+         mod = src
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+
+         passes.gluon.add_inliner(pm)
+         passes.gluon.add_resolve_auto_encodings(pm)
+         passes.common.add_sccp(pm)
+         passes.ttir.add_loop_aware_cse(pm)
+         passes.gluon.add_canonicalizer(pm)
+         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+
+         pm.run(mod)
+         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
+         return mod
+
+     def make_llir(self, src, metadata, options, capability):
+         ptx_version = get_ptx_version_from_options(options, self.target.arch)
+
+         mod = src
+         # TritonGPU -> LLVM-IR (MLIR)
+         pm = ir.pass_manager(mod.context)
+         pm.enable_debug()
+
+         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+         passes.ttgpuir.add_allocate_warp_groups(pm)
+         passes.convert.add_scf_to_cf(pm)
+         nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
+         nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
+         if knobs.compilation.enable_experimental_consan:
+             # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
+             passes.ttgpuir.add_concurrency_sanitizer(pm)
+         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
+         nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
+         # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
+         if CUDABackend.instrumentation:
+             CUDABackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
+         nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
+         passes.common.add_canonicalizer(pm)
+         passes.common.add_cse(pm)
+         nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
+         nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
+         passes.common.add_canonicalizer(pm)
+         passes.common.add_cse(pm)
+         passes.common.add_symbol_dce(pm)
+         passes.convert.add_nvvm_to_llvm(pm)
+         if not knobs.compilation.disable_line_info:
+             passes.llvmir.add_di_scope(pm)
+         if CUDABackend.instrumentation:
+             CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+
+         pm.run(mod)
+         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
+         llvm.init_targets()
+         context = llvm.context()
+         if knobs.compilation.enable_asan:
+             raise RuntimeError(
+                 "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
+         llvm_mod = llvm.to_module(mod, context)
+         proc = sm_arch_from_capability(capability)
+         features = get_features(options, self.target.arch)
+         triple = 'nvptx64-nvidia-cuda'
+         nvidia.set_short_ptr()
+         llvm.attach_datalayout(llvm_mod, triple, proc, features)
+         nvidia.set_nvvm_reflect_ftz(llvm_mod)
+
+         if options.extern_libs and nvidia.has_extern_deps(llvm_mod):
+             paths = [path for (name, path) in options.extern_libs]
+             llvm.link_extern_libs(llvm_mod, paths)
+
+         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
+
+         # Get some metadata
+         # warp-specialization mutates num_warps
+         total_num_warps = src.get_int_attr("ttg.total-num-warps")
+         if total_num_warps is not None:
+             metadata["num_warps"] = total_num_warps
+         metadata["shared"] = src.get_int_attr("ttg.shared")
+         metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
+         metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
+         metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
+         metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
+         metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
+         ret = str(llvm_mod)
+         del llvm_mod
+         del context
+         return ret
+
+     def make_ptx(self, src, metadata, opt, capability):
+         ptx_version = get_ptx_version_from_options(opt, self.target.arch)
+
+         triple = 'nvptx64-nvidia-cuda'
+         proc = sm_arch_from_capability(capability)
+         features = get_features(opt, self.target.arch)
+         ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
+         # Find kernel names (there should only be one)
+         names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
+         assert len(names) == 1
+         metadata["name"] = names[0]
+         # post-process
+         ptx_version = f'{ptx_version//10}.{ptx_version%10}'
+         ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
+         ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
+         # Remove the debug flag that prevents ptxas from optimizing the code
+         ret = re.sub(r",\s*debug|debug,\s*", "", ret)
+         if knobs.nvidia.dump_nvptx:
+             print("// -----// NVPTX Dump //----- //")
+             print(ret)
+         return ret
+
+     def make_cubin(self, src, metadata, opt, capability):
+         ptxas = get_ptxas().path
+         with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
+                 tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
+             fsrc.write(src)
+             fsrc.flush()
+             fbin = fsrc.name + '.o'
+
+             debug_info = []
+             if knobs.compilation.disable_line_info:
+                 # This option is ignored if used without -lineinfo
+                 debug_info += ["-lineinfo", "-suppress-debug-info"]
+             elif knobs.nvidia.disable_ptxas_opt:
+                 # Synthesize complete debug info
+                 debug_info += ["-g"]
+             else:
+                 # Only emit line info
+                 debug_info += ["-lineinfo"]
+
+             fmad = [] if opt.enable_fp_fusion else ["--fmad=false"]
+             arch = sm_arch_from_capability(capability)
+
+             # Disable ptxas optimizations if requested
+             disable_opt = ['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else []
+
+             # Accept more ptxas options if provided
+             ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else []
+
+             ptxas_cmd = [
+                 ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name,
+                 '-o', fbin
+             ]
+             try:
+                 # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
+                 # On Windows, both stdout and stderr need to be redirected to flog
+                 subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog,
+                                stderr=flog)
+                 if knobs.nvidia.dump_ptxas_log:
+                     with open(flog.name) as log_file:
+                         print(log_file.read())
+
+             except subprocess.CalledProcessError as e:
+                 with open(flog.name) as log_file:
+                     log = log_file.read()
+
+                 if e.returncode == 255:
+                     error = 'Internal Triton PTX codegen error'
+                 elif e.returncode == 128 + signal.SIGSEGV:
+                     error = '`ptxas` raised SIGSEGV'
+                 else:
+                     error = f'`ptxas` failed with error code {e.returncode}'
+
+                 error = (f"{error}\n"
+                          f"`ptxas` stderr:\n{log}\n"
+                          f'Repro command: {" ".join(ptxas_cmd)}\n')
+
+                 print(f"""
+
+ ================================================================
+ {error}
+
+ {src}
+ ================================================================
+ Please share the reproducer above with the Triton project.
+ """)
+                 raise PTXASError(error)
+
+             with open(fbin, 'rb') as f:
+                 cubin = f.read()
+             try_remove(fbin)
+
+         # It's better to remove the temp files outside the context managers
+         try_remove(fsrc.name)
+         try_remove(flog.name)
+         return cubin
+
+     def add_stages(self, stages, options, language):
+         capability = self._parse_arch(options.arch)
+         if language == Language.TRITON:
+             stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
+             stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
+         elif language == Language.GLUON:
+             stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability)
+         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
+         stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
+         stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
+
+     @functools.lru_cache()
+     def hash(self):
+         version = get_ptxas_version()
+         return f'{version}-{self.target.arch}'
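
For readers tracing the version logic above: the CUDA-to-PTX ISA mapping in ptx_get_version() can be checked in isolation. The sketch below is illustrative only and is not part of the wheel; the function name ptx_version_for and the spot-check values are assumptions derived from the rules shown in the diff.

# Standalone sketch (not shipped in the wheel) of the CUDA -> PTX ISA version
# mapping implemented by ptx_get_version() in triton/backends/nvidia/compiler.py.
def ptx_version_for(cuda_version: str) -> int:
    major, minor = map(int, cuda_version.split("."))
    if major == 12:
        # CUDA 12.0-12.5 map to PTX 8.0-8.5; from 12.6 the offset shifts down by one.
        return 80 + minor if minor < 6 else 80 + minor - 1
    if major == 11:
        return 70 + minor
    if major == 10:
        return 63 + minor
    if major >= 13:
        return 90 + (major - 13) * 10 + minor
    raise RuntimeError(f"unsupported CUDA version: {cuda_version}")

# Spot checks against the rules in the diff above:
assert ptx_version_for("12.4") == 84   # PTX 8.4
assert ptx_version_for("12.8") == 87   # CUDA 12.6+ shift down by one
assert ptx_version_for("11.8") == 78   # PTX 7.8
assert ptx_version_for("13.0") == 90   # CUDA 13.x starts at PTX 9.0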