triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,533 @@
|
|
|
1
|
+
from triton.backends.compiler import BaseBackend, GPUTarget, Language
|
|
2
|
+
from triton._C.libtriton import ir, passes, llvm, nvidia
|
|
3
|
+
from triton import knobs
|
|
4
|
+
from triton.runtime.errors import PTXASError
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
import functools
|
|
8
|
+
from typing import Any, Dict, Tuple, Optional
|
|
9
|
+
from types import ModuleType
|
|
10
|
+
import hashlib
|
|
11
|
+
import re
|
|
12
|
+
import tempfile
|
|
13
|
+
import signal
|
|
14
|
+
import os
|
|
15
|
+
import subprocess
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def min_dot_size(target: GPUTarget):
    """Build the callback that reports the minimum ``(M, N, K)`` shape of a dot.

    ``target`` is accepted for interface parity with other backends but is not
    consulted: the returned callback depends only on the operand bitwidth.
    """

    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:  # [m, n, k]
        lhs_width = lhs_type.scalar.primitive_bitwidth
        rhs_width = rhs_type.scalar.primitive_bitwidth
        assert lhs_width == rhs_width, "lhs and rhs bitwidth must be the same"
        # M and N may be as small as 1: tensor cores are still usable via padding.
        # K must be at least 32 for 8-bit operands and at least 16 otherwise.
        min_k = 32 if lhs_width == 8 else 16
        return (1, 1, min_k)

    return check_dot_compatibility
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_ptxas() -> knobs.NvidiaTool:
    """Return the ``ptxas`` tool descriptor configured in ``knobs.nvidia``."""
    tool = knobs.nvidia.ptxas
    return tool
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@functools.lru_cache()
def get_ptxas_version():
    """Return the raw text printed by ``ptxas --version`` (memoized).

    If ``knobs.nvidia.mock_ptx_version`` is set, that value is returned instead
    and ``ptxas`` is never invoked.
    """
    mocked = knobs.nvidia.mock_ptx_version
    if mocked is None:
        raw = subprocess.check_output([get_ptxas().path, "--version"])
        return raw.decode("utf-8")
    # This is not really a version of ptxas, but it is good enough for testing.
    return mocked
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
    '''
    Get the highest PTX version supported by the current CUDA driver.

    :param cuda_version: CUDA version string such as ``"12.4"``. Only the
        ``major.minor`` prefix is used, so extra components (``"12.4.1"``)
        are ignored.
    :return: the PTX ISA version encoded as ``10 * major + minor``
        (e.g. ``84`` means PTX 8.4).
    :raises ValueError: if ``cuda_version`` has no ``major.minor`` prefix.
    :raises RuntimeError: if the CUDA version is older than 10.0.
    '''
    assert isinstance(cuda_version, str)
    parts = cuda_version.split('.')
    if len(parts) < 2:
        raise ValueError(f"Expected a CUDA version of the form 'major.minor', but got: {cuda_version!r}")
    # Ignore any patch component instead of failing to unpack it.
    major, minor = int(parts[0]), int(parts[1])
    if major == 12:
        # From CUDA 12.6 on, the PTX minor version lags the CUDA minor by one
        # (12.6 -> 8.5, 12.8 -> 8.7).
        if minor < 6:
            return 80 + minor
        else:
            return 80 + minor - 1
    if major == 11:
        return 70 + minor
    if major == 10:
        return 63 + minor

    if major >= 13:
        # CUDA 13.0 maps to PTX 9.0; each later major release adds 1.0.
        base_ptx = 90
        return base_ptx + (major - 13) * 10 + minor

    raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_ptx_version_from_options(options, arch: int):
    """Resolve the PTX version to compile for.

    An explicit ``options.ptx_version`` takes precedence; otherwise it is
    derived from the CUDA version reported by the configured ``ptxas``.
    ``arch`` is currently unused.
    """
    explicit = options.ptx_version
    if explicit is not None:
        return explicit
    return ptx_get_version(get_ptxas().version)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@functools.lru_cache()
def get_features(options, arch: int):
    """Return the LLVM NVPTX feature string (e.g. ``"+ptx86"``).

    The resolved PTX version is clamped to the highest one LLVM understands.
    """
    # PTX 8.6 is the max version supported by llvm c1188642.
    #
    # To check if a newer PTX version is supported, increase this value
    # and run a test. If it's not supported, LLVM will print a warning
    # like "+ptx8.4 is not a recognized feature for this target".
    max_llvm_ptx = 86
    requested = get_ptx_version_from_options(options, arch)
    return f'+ptx{min(max_llvm_ptx, requested)}'
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@functools.lru_cache(None)
def file_hash(path):
    """Return the hex SHA-256 digest of the file at ``path`` (memoized per path)."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # Feed the hash in 1 MiB chunks; the digest is identical to hashing the whole file at once.
        for chunk in iter(lambda: stream.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def sm_arch_from_capability(capability: int):
    """Map a compute capability such as ``90`` to an arch name such as ``"sm_90a"``.

    Capabilities of 90 and above get the ``a`` suffix; lower ones are returned
    without a suffix.
    """
    # TODO: Handle non-"a" sms
    if capability < 90:
        return f"sm_{capability}"
    return f"sm_{capability}a"
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# The file may be accessed in parallel
def try_remove(path):
    """Best-effort deletion of ``path``.

    A file that is already gone is silently ignored. Another thread or process
    may delete it concurrently, so a separate existence check before removing
    would be racy. Any other ``OSError`` (e.g. a permission problem) is
    reported by printing the traceback and is otherwise swallowed.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        # Already removed, possibly by a concurrent caller: nothing to do.
        pass
    except OSError:
        import traceback
        traceback.print_exc()
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass(frozen=True)
class CUDAOptions:
    """Compilation options for the NVIDIA (CUDA) backend.

    The dataclass is frozen, so instances are immutable and hashable and can
    serve as cache keys. ``__post_init__`` normalizes ``extern_libs`` into a
    tuple of ``(name, path)`` pairs that always contains a ``libdevice`` entry,
    and validates ``num_warps``.
    """
    num_warps: int = 4
    num_ctas: int = 1
    num_stages: int = 3
    warp_size: int = 32
    # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
    # maximum number of 32-bit registers used by one thread.
    maxnreg: Optional[int] = None
    cluster_dims: tuple = (1, 1, 1)
    # Explicit PTX ISA version encoded as 10 * major + minor (e.g. 84 for PTX 8.4);
    # None means "derive it from the installed ptxas".
    ptx_version: Optional[int] = None
    # Extra whitespace-separated command-line options forwarded to ptxas.
    ptx_options: Optional[str] = None
    ir_override: Optional[str] = None  # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
    enable_fp_fusion: bool = True
    launch_cooperative_grid: bool = False
    launch_pdl: bool = False
    supported_fp8_dtypes: Tuple[str, ...] = ("fp8e4nv", "fp8e5", "fp8e4b15")
    deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = ()
    default_dot_input_precision: str = "tf32"
    allowed_dot_input_precisions: Tuple[str, ...] = ("tf32", "tf32x3", "ieee")
    # An integer limit (the backend fills in 2**30 or 0), not a bool.
    max_num_imprecise_acc_default: Optional[int] = None
    # Accepts a {name: path} mapping (or (name, path) pairs); __post_init__
    # stores it as a tuple of (name, path) pairs.
    extern_libs: Optional[dict] = None
    debug: bool = False
    backend_name: str = 'cuda'
    sanitize_overflow: bool = True
    # Target architecture in the form "sm<N>", e.g. "sm90".
    arch: Optional[str] = None
    instrumentation_mode: str = ""

    def __post_init__(self):
        default_libdir = Path(__file__).parent / 'lib'
        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
        # Fall back to the configured libdevice, then to the one bundled next to this file.
        if not extern_libs.get('libdevice', None):
            extern_libs['libdevice'] = knobs.nvidia.libdevice_path or str(default_libdir / 'libdevice.10.bc')

        # The dataclass is frozen, so bypass __setattr__; a tuple keeps the instance hashable.
        object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
            "num_warps must be a power of 2"

    def hash(self):
        """Return a SHA-256 hex digest of all option values.

        Extern library paths are replaced by the hashes of their file contents,
        so the digest changes when a library file changes even if its path does not.
        """
        hash_dict = dict(self.__dict__)
        hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"]))
        key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
        return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class CUDABackend(BaseBackend):
    """Triton compiler backend for NVIDIA GPUs.

    ``add_stages`` registers the lowering pipeline: TTIR -> TTGIR (or Gluon ->
    TTGIR) -> LLVM IR -> PTX -> cubin.
    """
    # Optional instrumentation hook. When set, it can load extra dialects
    # (``load_dialects``) and patch the pass manager in ``make_llir``.
    instrumentation = None

    @staticmethod
    def supports_target(target: GPUTarget):
        return target.backend == 'cuda'

    def _parse_arch(self, arch):
        """Parse an ``"sm<N>"`` arch string and return ``N`` as an int."""
        pattern = r"^sm(\d+)$"
        match = re.fullmatch(pattern, arch)
        if not match:
            raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
        return int(match.group(1))

    def get_target_name(self, options) -> str:
        capability = self._parse_arch(options.arch)
        return f"cuda:{capability}"

    def __init__(self, target: GPUTarget) -> None:
        super().__init__(target)
        self.binary_ext = "cubin"

    def parse_options(self, opts) -> Any:
        """Build a ``CUDAOptions`` from the user-supplied ``opts`` mapping.

        Keys that are not ``CUDAOptions`` fields, or whose value is ``None``,
        are ignored. The arch comes from ``knobs.runtime.override_arch`` when
        set, otherwise from the target.

        :raises ValueError: if ``num_ctas > 1`` is requested on a pre-SM90 arch.
        """
        args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
        args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
        capability = int(self._parse_arch(args["arch"]))

        if args.get("num_ctas", 1) > 1 and capability < 90:
            raise ValueError((f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
                              f"Current target is sm_{capability}. This configuration will fail. "
                              f"Please set num_ctas=1 or target an SM90+ GPU."))

        if "supported_fp8_dtypes" not in args:
            supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
            args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

        if "deprecated_fp8_dot_operand_dtypes" not in args:
            if capability >= 90:
                args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )

        if "enable_fp_fusion" not in args:
            args["enable_fp_fusion"] = knobs.language.default_fp_fusion

        # NOTE: this is set unconditionally, overriding any user-provided value.
        args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0

        return CUDAOptions(**args)

    def pack_metadata(self, metadata):
        """Flatten the launch-relevant metadata into a fixed-order tuple:
        (num_warps, num_ctas, shared, cluster_x, cluster_y, cluster_z)."""
        return (
            metadata.num_warps,
            metadata.num_ctas,
            metadata.shared,
            metadata.cluster_dims[0],
            metadata.cluster_dims[1],
            metadata.cluster_dims[2],
        )

    def get_codegen_implementation(self, options):
        """Return the arch-dependent code-generation hooks used by the frontend."""
        import triton.language.extra.cuda as cuda
        capability = int(self._parse_arch(options.arch))
        codegen_fns = {
            "convert_custom_types":
            cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70,
            "min_dot_size": min_dot_size(self.target),
        }
        return codegen_fns

    def get_module_map(self) -> Dict[str, ModuleType]:
        """Map the generic ``triton.language.extra.libdevice`` module to the CUDA one."""
        from triton.language.extra.cuda import libdevice
        return {"triton.language.extra.libdevice": libdevice}

    def load_dialects(self, ctx):
        nvidia.load_dialects(ctx)
        if CUDABackend.instrumentation:
            CUDABackend.instrumentation.load_dialects(ctx)

    @staticmethod
    def make_ttir(mod, metadata, opt, capability):
        """Run the TTIR-level cleanup/optimization pipeline on ``mod`` in place."""
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        # Tensor descriptors are lowered to plain pointers below SM90.
        if capability // 10 < 9:
            passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_combine(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.ttir.add_loop_unroll(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_ttgir(mod, metadata, opt, capability):
        """Convert TTIR to TTGIR and run the arch-specific TTGIR pipeline.

        Records ``cluster_dims`` and ``tensordesc_meta`` in ``metadata``.
        """
        # Set maxnreg on all kernels, if it was provided.
        if opt.maxnreg is not None:
            mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))

        cluster_info = nvidia.ClusterInfo()
        if opt.cluster_dims is not None:
            cluster_info.clusterDimX = opt.cluster_dims[0]
            cluster_info.clusterDimY = opt.cluster_dims[1]
            cluster_info.clusterDimZ = opt.cluster_dims[2]
        pm = ir.pass_manager(mod.context)
        dump_enabled = pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
        # optimize TTGIR
        passes.ttgpuir.add_coalesce(pm)
        if capability // 10 >= 8:
            passes.ttgpuir.add_f32_dot_tc(pm)
        # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
        nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        passes.ttgpuir.add_accelerate_matmul(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
        passes.ttir.add_loop_aware_cse(pm)
        if capability // 10 in [8, 9]:
            # Ampere / Hopper pipelining path.
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
        elif capability // 10 >= 10:
            # SM100+ path (tensor-memory passes and warp specialization).
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.ttgpuir.add_optimize_accumulator_init(pm)
            passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
            nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            # hoist again and allow hoisting out of if statements
            passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
            nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
        else:
            passes.ttir.add_triton_licm(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.ttgpuir.add_prefetch(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.ttgpuir.add_coalesce_async_copy(pm)
        nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        passes.ttgpuir.add_reorder_instructions(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.common.add_symbol_dce(pm)
        if capability // 10 >= 9:
            nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
        nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
        passes.common.add_sccp(pm)
        passes.common.add_cse(pm)
        passes.common.add_canonicalizer(pm)

        pm.run(mod)
        metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
        tensordesc_meta = mod.get_tensordesc_metadata()
        metadata["tensordesc_meta"] = tensordesc_meta
        return mod

    def gluon_to_ttgir(self, src, metadata, options, capability):
        """Run the Gluon cleanup pipeline; Gluon modules are already TTGIR."""
        mod = src
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()

        passes.gluon.add_inliner(pm)
        passes.gluon.add_resolve_auto_encodings(pm)
        passes.common.add_sccp(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.gluon.add_canonicalizer(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)

        pm.run(mod)
        metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
        return mod

    def make_llir(self, src, metadata, options, capability):
        """Lower TTGIR to LLVM IR text, optimize it at O3, and fill in resource metadata.

        :raises RuntimeError: if the address sanitizer is enabled (AMD-only).
        """
        ptx_version = get_ptx_version_from_options(options, self.target.arch)

        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()

        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        passes.ttgpuir.add_allocate_warp_groups(pm)
        passes.convert.add_scf_to_cf(pm)
        nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
        nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
        if knobs.compilation.enable_experimental_consan:
            # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
            passes.ttgpuir.add_concurrency_sanitizer(pm)
        passes.ttgpuir.add_allocate_global_scratch_memory(pm)
        nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
        if CUDABackend.instrumentation:
            CUDABackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
        nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
        nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.convert.add_nvvm_to_llvm(pm)
        if not knobs.compilation.disable_line_info:
            passes.llvmir.add_di_scope(pm)
        if CUDABackend.instrumentation:
            CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)

        pm.run(mod)
        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        if knobs.compilation.enable_asan:
            raise RuntimeError(
                "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
        llvm_mod = llvm.to_module(mod, context)
        proc = sm_arch_from_capability(capability)
        features = get_features(options, self.target.arch)
        triple = 'nvptx64-nvidia-cuda'
        nvidia.set_short_ptr()
        llvm.attach_datalayout(llvm_mod, triple, proc, features)
        nvidia.set_nvvm_reflect_ftz(llvm_mod)

        # Only link external bitcode (e.g. libdevice) when the module references it.
        if options.extern_libs and nvidia.has_extern_deps(llvm_mod):
            paths = [path for _, path in options.extern_libs]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

        # Get some metadata
        # warp-specialization mutates num_warps
        total_num_warps = src.get_int_attr("ttg.total-num-warps")
        if total_num_warps is not None:
            metadata["num_warps"] = total_num_warps
        metadata["shared"] = src.get_int_attr("ttg.shared")
        metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
        metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
        metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
        metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
        metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
        ret = str(llvm_mod)
        del llvm_mod
        del context
        return ret

    def make_ptx(self, src, metadata, opt, capability):
        """Translate LLVM IR text to PTX and post-process the header.

        Stores the single kernel name in ``metadata["name"]``, then rewrites the
        ``.version`` and ``.target`` directives and strips the ``debug`` flag.
        """
        ptx_version = get_ptx_version_from_options(opt, self.target.arch)

        triple = 'nvptx64-nvidia-cuda'
        proc = sm_arch_from_capability(capability)
        features = get_features(opt, self.target.arch)
        ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
        # Find kernel names (there should only be one)
        names = re.findall(r"\.visible \.entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
        assert len(names) == 1
        metadata["name"] = names[0]
        # post-process
        ptx_version = f'{ptx_version//10}.{ptx_version%10}'
        ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
        ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
        # Remove the debug flag that prevents ptxas from optimizing the code
        ret = re.sub(r",\s*debug|debug,\s*", "", ret)
        if knobs.nvidia.dump_nvptx:
            print("// -----// NVPTX Dump //----- //")
            print(ret)
        return ret

    def make_cubin(self, src, metadata, opt, capability):
        """Assemble PTX text into a cubin with ``ptxas`` and return the binary.

        The PTX, log and object files are written to temporary paths. On
        success they are removed. On failure a reproducer is printed and
        ``PTXASError`` is raised; the temporary files are then left on disk,
        where the printed repro command refers to them.
        """
        ptxas = get_ptxas().path
        with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
            tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
            fsrc.write(src)
            fsrc.flush()
            fbin = fsrc.name + '.o'

            debug_info = []
            if knobs.compilation.disable_line_info:
                # This option is ignored if used without -lineinfo
                debug_info += ["-lineinfo", "-suppress-debug-info"]
            elif knobs.nvidia.disable_ptxas_opt:
                # Synthesize complete debug info
                debug_info += ["-g"]
            else:
                # Only emit line info
                debug_info += ["-lineinfo"]

            fmad = [] if opt.enable_fp_fusion else ["--fmad=false"]
            arch = sm_arch_from_capability(capability)

            # Disable ptxas optimizations if requested
            disable_opt = ['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else []

            # Accept more ptxas options if provided. Split on runs of whitespace so
            # repeated/leading/trailing spaces do not produce empty argv entries.
            ptx_extra_options = opt.ptx_options.split() if opt.ptx_options else []

            ptxas_cmd = [
                ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name,
                '-o', fbin
            ]
            try:
                # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
                # On Windows, both stdout and stderr need to be redirected to flog
                subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog,
                               stderr=flog)
                if knobs.nvidia.dump_ptxas_log:
                    with open(flog.name) as log_file:
                        print(log_file.read())

            except subprocess.CalledProcessError as e:
                with open(flog.name) as log_file:
                    log = log_file.read()

                # On POSIX, subprocess reports a child killed by signal N as
                # returncode -N; 128 + N is the shell convention (e.g. when ptxas
                # is reached through a wrapper script). Accept both.
                if e.returncode == 255:
                    error = 'Internal Triton PTX codegen error'
                elif e.returncode in (128 + signal.SIGSEGV, -signal.SIGSEGV):
                    error = '`ptxas` raised SIGSEGV'
                else:
                    error = f'`ptxas` failed with error code {e.returncode}'

                error = (f"{error}\n"
                         f"`ptxas` stderr:\n{log}\n"
                         f'Repro command: {" ".join(ptxas_cmd)}\n')

                print(f"""

================================================================
{error}

{src}
================================================================
please share the reproducer above with Triton project.
""")
                raise PTXASError(error)

            with open(fbin, 'rb') as f:
                cubin = f.read()
            try_remove(fbin)

        # It's better to remove the temp files outside the context managers
        try_remove(fsrc.name)
        try_remove(flog.name)
        return cubin

    def add_stages(self, stages, options, language):
        """Register the compilation stage callables for ``language`` in ``stages``."""
        capability = self._parse_arch(options.arch)
        if language == Language.TRITON:
            stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
            stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
        elif language == Language.GLUON:
            stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
        # NOTE(review): ptx/cubin receive self.target.arch, not the parsed options.arch
        # capability used above; these differ when knobs.runtime.override_arch is set.
        # Confirm this is intended.
        stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)

    def hash(self):
        """Return the backend cache key: ``<ptxas --version output>-<target arch>``."""
        # get_ptxas_version() is already memoized at module level, so no per-method
        # cache is needed; an lru_cache on this method would also keep every backend
        # instance alive for the life of the process.
        version = get_ptxas_version()
        return f'{version}-{self.target.arch}'
|