triton-windows 3.4.0.post20__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/_C/libtriton.pyd
CHANGED
Binary file
triton/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 """isort:skip_file"""
-__version__ = '3.4.0'
+__version__ = '3.5.0'

 # ---------------------------------------
 # Note: import order is significant here.
@@ -17,7 +17,8 @@ from .runtime import (
     InterpreterError,
     MockTensor,
 )
-from .runtime.jit import jit
+from .runtime.jit import constexpr_function, jit
+from .runtime._async_compile import AsyncCompileMode, FutureKernel
 from .compiler import compile, CompilationError
 from .errors import TritonError
 from .runtime._allocation import set_allocator
@@ -29,11 +30,14 @@ from . import tools
 must_use_result = language.core.must_use_result

 __all__ = [
+    "AsyncCompileMode",
     "autotune",
     "cdiv",
     "CompilationError",
     "compile",
     "Config",
+    "constexpr_function",
+    "FutureKernel",
     "heuristics",
     "InterpreterError",
     "jit",
@@ -59,10 +63,12 @@ __all__ = [
 # -------------------------------------


+@constexpr_function
 def cdiv(x: int, y: int):
     return (x + y - 1) // y


+@constexpr_function
 def next_power_of_2(n: int):
     """Return the smallest power of 2 greater than or equal to n"""
     n -= 1
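The notable additions above are the `constexpr_function` re-export (now also applied to `cdiv` and `next_power_of_2`) and the new async-compile exports. A minimal host-side sketch, assuming only that the 3.5.0 wheel is installed; the variable values are made up for illustration and are not taken from this diff:

import triton

n_elements = 1000                           # example value, not from the diff
BLOCK = triton.next_power_of_2(300)         # smallest power of 2 >= 300 -> 512
grid = (triton.cdiv(n_elements, BLOCK),)    # ceil-division launch grid -> (2,)
print(BLOCK, grid)

# The new async-compile names are importable from the package root as well:
from triton import AsyncCompileMode, FutureKernel  # noqa: F401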
triton/_filecheck.py
CHANGED
@@ -1,3 +1,4 @@
+import functools
 import os
 import inspect
 import subprocess
@@ -7,6 +8,7 @@ import triton
 from triton.compiler import ASTSource, make_backend
 from triton.backends.compiler import GPUTarget
 from triton.experimental.gluon._runtime import GluonASTSource
+from triton.runtime.jit import create_function_from_signature
 from triton._C.libtriton import ir

 # ===-----------------------------------------------------------------------===#
@@ -15,7 +17,6 @@ from triton._C.libtriton import ir

 # Stub target for testing the frontend.
 stub_target = GPUTarget("cuda", 100, 32)
-stub_backend = make_backend(stub_target)

 triton_dir = os.path.dirname(__file__)
 filecheck_path = os.path.join(triton_dir, "FileCheck")
@@ -42,29 +43,37 @@ def run_filecheck(name, module_str, check_template):
         temp.write(check_template)

     try:
-        subprocess.check_output(
-
+        subprocess.check_output(
+            [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"],
+            stderr=subprocess.STDOUT)
     except subprocess.CalledProcessError as error:
         decoded = error.output.decode('unicode_escape')
         raise ValueError(decoded)


-def run_parser(kernel_fn):
-
-
-
+def run_parser(kernel_fn, args=(), kwargs={}, target=stub_target):
+    if "sanitize_overflow" not in kwargs:
+        kwargs = dict(kwargs)
+        kwargs["sanitize_overflow"] = False
+    backend = make_backend(target)
+    binder = create_function_from_signature(
+        kernel_fn.signature,
+        kernel_fn.params,
+        backend,
+    )
+
+    bound_args, specialization, options = binder(*args, **kwargs)
+    options, signature, constexprs, attrs = kernel_fn._pack_args(backend, kwargs, bound_args, specialization, options)
     source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
-    src = source_cls(
+    src = source_cls(kernel_fn, signature, constexprs, attrs)

     context = ir.context()
     ir.load_dialects(context)
-
+    backend.load_dialects(context)

-
-
-
-    module_map = stub_backend.get_module_map()
-    module = src.make_ir(options, codegen_fns, module_map, context)
+    codegen_fns = backend.get_codegen_implementation(options)
+    module_map = backend.get_module_map()
+    module = src.make_ir(target, options, codegen_fns, module_map, context)
     assert module.verify()
     return module

@@ -81,6 +90,7 @@ def run_filecheck_test(kernel_fn):

 def filecheck_test(fn):

+    @functools.wraps(fn)
     def test_fn():
         run_filecheck_test(fn)
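The only behavioral change to `filecheck_test` itself is the added `functools.wraps(fn)`. A standalone sketch (not triton code) of what that preserves; without it, every decorated test would report the generic inner function name:

import functools

def plain_decorator(fn):
    def test_fn():
        return fn()
    return test_fn

def wrapping_decorator(fn):
    @functools.wraps(fn)   # copies __name__, __doc__, etc. from fn onto the wrapper
    def test_fn():
        return fn()
    return test_fn

def test_example():
    """Docstring and name survive only with wraps."""

print(plain_decorator(test_example).__name__)     # test_fn
print(wrapping_decorator(test_example).__name__)  # test_example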
triton/_internal_testing.py
CHANGED
@@ -5,10 +5,10 @@ import torch
 import triton
 import triton.language as tl
 from triton import knobs
+from typing import Optional, Set, Union
 import pytest

 from numpy.random import RandomState
-from typing import Optional, Union
 from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict

 int_dtypes = ['int8', 'int16', 'int32', 'int64']
@@ -38,10 +38,22 @@ def is_cuda():
     return False if target is None else target.backend == "cuda"


-def
+def is_ampere_or_newer():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 8
+
+
+def is_blackwell():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 10
+
+
+def is_hopper_or_newer():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 9
+
+
 def is_hip():
     target = get_current_target()
     return False if target is None else target.backend == "hip"
@@ -62,9 +74,13 @@ def is_hip_cdna4():
     return target is not None and target.backend == 'hip' and target.arch == 'gfx950'


+def is_hip_gfx11():
+    target = get_current_target()
+    return target is not None and target.backend == 'hip' and 'gfx11' in target.arch
+
+
 def is_hip_gfx12():
     target = get_current_target()
-    print(target.arch)
     return target is not None and target.backend == 'hip' and 'gfx12' in target.arch


@@ -72,6 +88,10 @@ def is_hip_cdna():
     return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()


+def get_hip_lds_size():
+    return 163840 if is_hip_cdna4() else 65536
+
+
 def is_xpu():
     target = get_current_target()
     return False if target is None else target.backend == "xpu"
@@ -132,7 +152,7 @@ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]:


 def str_to_triton_dtype(x: str) -> tl.dtype:
-    return tl.str_to_ty(type_canonicalisation_dict[x])
+    return tl.str_to_ty(type_canonicalisation_dict[x], None)


 def torch_dtype_name(dtype) -> str:
@@ -187,3 +207,49 @@ def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> torch.Tensor:
     if isinstance(t, triton.runtime.jit.TensorWrapper):
         return t.base
     return t
+
+
+def _fresh_knobs_impl(skipped_attr: Optional[Set[str]] = None):
+    from triton import knobs
+
+    if skipped_attr is None:
+        skipped_attr = set()
+
+    monkeypatch = pytest.MonkeyPatch()
+
+    knobs_map = {
+        name: knobset
+        for name, knobset in knobs.__dict__.items()
+        if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
+    }
+
+    # We store which variables we need to unset below in finally because
+    # monkeypatch doesn't appear to reset variables that were never set
+    # before the monkeypatch.delenv call below.
+    env_to_unset = []
+    prev_propagate_env = knobs.propagate_env
+
+    def fresh_function():
+        nonlocal env_to_unset
+        for name, knobset in knobs_map.items():
+            setattr(knobs, name, knobset.copy().reset())
+            for knob in knobset.knob_descriptors.values():
+                if knob.key in os.environ:
+                    monkeypatch.delenv(knob.key, raising=False)
+                else:
+                    env_to_unset.append(knob.key)
+        knobs.propagate_env = True
+        return knobs
+
+    def reset_function():
+        for name, knobset in knobs_map.items():
+            setattr(knobs, name, knobset)
+        # `undo` should be placed before `del os.environ`
+        # Otherwise, it may restore environment variables that monkeypatch deleted
+        monkeypatch.undo()
+        for k in env_to_unset:
+            if k in os.environ:
+                del os.environ[k]
+        knobs.propagate_env = prev_propagate_env
+
+    return fresh_function, reset_function
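`_fresh_knobs_impl` returns a `(fresh_function, reset_function)` pair. One plausible way to consume it, sketched here as a pytest fixture; the fixture itself is not part of this wheel, only the helper is:

import pytest
from triton._internal_testing import _fresh_knobs_impl

@pytest.fixture
def fresh_knobs():
    fresh_function, reset_function = _fresh_knobs_impl()
    try:
        # yields the knobs module with every knob reset and env propagation enabled
        yield fresh_function()
    finally:
        # restores the original knob objects and any environment variables touched above
        reset_function()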
triton/_utils.py
CHANGED
@@ -16,9 +16,11 @@ def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:


 def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any):
+    from .language import core
     assert len(path) != 0
     prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
-    prev
+    assert isinstance(prev, core.tuple)
+    prev._setitem(path[-1], val)


 def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]:
triton/backends/amd/compiler.py
CHANGED
@@ -7,8 +7,8 @@ from types import ModuleType
 import hashlib
 import tempfile
 import re
-import subprocess
 import functools
+import warnings
 from pathlib import Path


@@ -18,8 +18,9 @@ def get_min_dot_size(target: GPUTarget):
     return lambda lhs_type, rhs_type: (1, 1, 1)


-def is_pingpong_schedule_enabled(arch):
-    return (arch == "gfx942"
+def is_pingpong_schedule_enabled(arch, use_async_copy):
+    return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)
+            ) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong


 def is_in_thread_transpose_enabled(arch):
@@ -37,7 +38,11 @@ class HIPOptions:
     debug: bool = False
     sanitize_overflow: bool = True
     arch: str = None
-
+    # We have native support for OCP fp8 variants since CDNA4/RDNA4. For earlier generations,
+    # we software emulate the support for them.
+    # UZ fp8 variants (fp8e4b8 and fp8e5b16) are natively supported for CDNA3. For other
+    # architectures they are software emulated.
+    supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8")
     deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "ieee"
     allowed_dot_input_precisions: Tuple[str] = ("ieee", )
@@ -48,6 +53,7 @@ class HIPOptions:
     allow_flush_denorm: bool = False
     max_num_imprecise_acc_default: int = 0
     backend_name: str = 'hip'
+    instrumentation_mode: str = ""

     # The following option provides hints to the AMDGPU backend regarding instruction scheduling
     # for all `tt.dot` operations in a kernel. The "none" variant preserves the default
@@ -57,10 +63,6 @@ class HIPOptions:
     #
     # Current experimental scheduling variants:
     #
-    # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
-    #                 Kernel library. Note, this variant requires the use of buffer load/store ops
-    #                 and a special software pipelining style - i.e., 1x LDS and 1x register
-    #                 prefetch buffers for each GEMM tile.
     # attention: enables a bunch of optimizations for attention kernels, including:
     #            - iglp 2 and sched.barrier around it
     #            - sink-insts-to-avoid-spills flag to avoid register spills
@@ -73,8 +75,11 @@ class HIPOptions:
         assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
             "num_warps must be a power of 2"

-        if self.arch == 'gfx950':
-
+        if (self.arch == 'gfx950') and (self.kpack != 1):
+            warnings.warn(
+                f"kpack is deprecated starting from gfx950 and will be removed in later releases. So for now kpack = {self.kpack} will be overwritten to 1 to make transitioning easier."
+            )
+            object.__setattr__(self, 'kpack', 1)

         default_libdir = Path(__file__).parent / 'lib'
         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
@@ -88,6 +93,7 @@ class HIPOptions:


 class HIPBackend(BaseBackend):
+    instrumentation = None

     @staticmethod
     def supports_target(target: GPUTarget):
@@ -104,6 +110,9 @@ class HIPBackend(BaseBackend):
     def parse_options(self, opts) -> Any:
         args = {'arch': knobs.runtime.override_arch or self.target.arch}

+        if opts.get("num_ctas", 1) > 1:
+            raise ValueError("num_ctas > 1 not supported for AMD GPUs")
+
         # Enable XF32 (TF32) for CDNA3 GPUs
         if self.target.arch == 'gfx942':
             allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
@@ -111,14 +120,12 @@ class HIPBackend(BaseBackend):
             args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))

         if "supported_fp8_dtypes" not in opts:
-            supported_fp8_dtypes =
-
-
-
-
-            supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
-            args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
+            args["supported_fp8_dtypes"] = tuple(sorted(HIPOptions.supported_fp8_dtypes))
+
+        if self.target.arch == 'gfx950':
+            deprecated_fp8_dot_operand_dtypes = set(HIPOptions.deprecated_fp8_dot_operand_dtypes)
+            deprecated_fp8_dot_operand_dtypes.update({"fp8e5b16", "fp8e4b8"})
+            args["deprecated_fp8_dot_operand_dtypes"] = tuple(sorted(deprecated_fp8_dot_operand_dtypes))

         if "enable_fp_fusion" not in opts:
             args["enable_fp_fusion"] = knobs.language.default_fp_fusion
@@ -146,6 +153,8 @@ class HIPBackend(BaseBackend):

     def load_dialects(self, ctx):
         amd.load_dialects(ctx)
+        if HIPBackend.instrumentation:
+            HIPBackend.instrumentation.load_dialects(ctx)

     @staticmethod
     def is_within_2gb(arg):
@@ -174,26 +183,6 @@ class HIPBackend(BaseBackend):
             ret += "S"
         return ret

-    @staticmethod
-    def path_to_rocm_lld():
-        # Check env path for ld.lld
-        lld_env_path = knobs.amd.lld_path
-        if lld_env_path is not None:
-            lld = Path(lld_env_path)
-            if lld.is_file():
-                return lld
-        # Check backend for ld.lld (used for pytorch wheels)
-        lld = Path(__file__).parent / "llvm/bin/ld.lld"
-        if lld.is_file():
-            return lld
-        lld = Path("/opt/rocm/llvm/bin/ld.lld")
-        if lld.is_file():
-            return lld
-        lld = Path("/usr/bin/ld.lld")
-        if lld.is_file():
-            return lld
-        raise Exception("ROCm linker /opt/rocm/llvm/bin/ld.lld not found. Set 'TRITON_HIP_LLD_PATH' to its path.")
-
     @staticmethod
     def make_ttir(mod, metadata, options):
         pm = ir.pass_manager(mod.context)
@@ -237,12 +226,10 @@ class HIPBackend(BaseBackend):
         global_prefetch = knobs.amd.global_prefetch
         local_prefetch = knobs.amd.local_prefetch
         use_async_copy = knobs.amd.use_async_copy
+        use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy)

-
-
-            global_prefetch = local_prefetch = 1
-
-        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
+        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
+                                               use_block_pingpong)
         if use_async_copy:
             amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
         passes.common.add_canonicalizer(pm)
@@ -255,14 +242,13 @@ class HIPBackend(BaseBackend):
         amd.passes.ttgpuir.add_in_thread_transpose(pm)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         amd.passes.ttgpuir.add_reorder_instructions(pm)
-        use_block_pingpong
-        if use_block_pingpong and options.num_stages == 2:
+        if use_block_pingpong and options.num_stages > 1:
             amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)

         if knobs.amd.use_buffer_ops:
             amd.passes.ttgpuir.add_canonicalize_pointers(pm)
             passes.common.add_canonicalizer(pm)
-            amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
+            amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch, knobs.amd.use_buffer_atomics)

         amd.passes.ttgpuir.add_fold_true_cmpi(pm)
         passes.common.add_canonicalizer(pm)
@@ -274,15 +260,16 @@ class HIPBackend(BaseBackend):
         return mod

     @staticmethod
-    def
+    def gluon_to_ttgir(src, metadata, options):
         mod = src
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()

-        passes.
+        passes.gluon.add_inliner(pm)
+        passes.gluon.add_resolve_auto_encodings(pm)
         passes.common.add_sccp(pm)
         passes.ttir.add_loop_aware_cse(pm)
-        passes.
+        passes.gluon.add_canonicalizer(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)

         pm.run(mod)
@@ -304,7 +291,10 @@ class HIPBackend(BaseBackend):
         passes.convert.add_scf_to_cf(pm)
         passes.convert.add_index_to_llvmir(pm)

-        passes.ttgpuir.add_allocate_shared_memory(pm)
+        amd.passes.ttgpuir.add_allocate_shared_memory(pm)
+        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
+        if HIPBackend.instrumentation:
+            HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
         ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
         ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
         ##    of the value of kernel arg `allow_flush_denorm`.
@@ -322,10 +312,17 @@ class HIPBackend(BaseBackend):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
+
         if options.schedule_hint.lower() != "none":
             amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
+
+        # This can not be moved below the di_scope pass
+        if HIPBackend.instrumentation:
+            HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+
         if not knobs.compilation.disable_line_info:
             passes.llvmir.add_di_scope(pm)
+
         amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
         pm.run(mod)

@@ -382,15 +379,27 @@ class HIPBackend(BaseBackend):
             llvm.link_extern_libs(llvm_mod, paths)
         elif options.extern_libs:
             paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
-
+            if len(paths) > 0:
+                llvm.link_extern_libs(llvm_mod, paths)

         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)

+        # Architectures with architected SGPRs store the workgroup id in ttmp9 (X) and ttmp7 (Y[15:0], Z[31:16]).
+        # These attributes are used to determine if Z should be masked out when loading Y. They are inferred during
+        # optimize_module from calls to @llvm.amdgcn.workgroup.id.x/y/z(). We cannot rely on this because a
+        # dispatch dimensions might be used even if there is no program_id() call for it.
+        if amd.has_architected_sgprs(options.arch):
+            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-x")
+            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-y")
+            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-z")
+
         if knobs.amd.scalarize_packed_fops:
             amd.add_scalarize_packed_fops_llvm_pass(fns[0])

         # Get some metadata
         metadata["shared"] = src.get_int_attr("ttg.shared")
+        metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
+        metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1

         amd.cleanup_bitcode_metadata(llvm_mod)
         # Disable inlining of print related functions,
@@ -414,7 +423,9 @@ class HIPBackend(BaseBackend):
         # the regression is not significant. It would be better to have some heuristics.
         if options.schedule_hint == 'attention':
             flags.append('sink-insts-to-avoid-spills')
-
+        features = '-real-true16' if 'gfx11' in options.arch else ''
+        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
+                                       False)
         if knobs.amd.dump_amdgcn:
             print("// -----// AMDGCN Dump //----- //")
             print(amdgcn)
@@ -426,14 +437,12 @@ class HIPBackend(BaseBackend):
         if knobs.compilation.enable_asan:
             target_features = '+xnack'
         hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
-
-        rocm_path = HIPBackend.path_to_rocm_lld()
         with tempfile.NamedTemporaryFile() as tmp_out:
             with tempfile.NamedTemporaryFile() as tmp_in:
-                with open(tmp_in.name,
+                with open(tmp_in.name, "wb") as fd_in:
                     fd_in.write(hsaco)
-
-                with open(tmp_out.name,
+                amd.link_hsaco(tmp_in.name, tmp_out.name)
+                with open(tmp_out.name, "rb") as fd_out:
                     ret = fd_out.read()
         return ret

@@ -442,12 +451,11 @@ class HIPBackend(BaseBackend):
             stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
             stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
         elif language == Language.GLUON:
-            stages["ttgir"] = lambda src, metadata: self.
+            stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
         stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
         stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)

     @functools.lru_cache()
     def hash(self):
-
-        return f'{version}-{self.target}'
+        return f'{self.target}'