triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
triton/_C/libtriton.pyd
CHANGED
Binary file
triton/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 """isort:skip_file"""
-__version__ = '3.2.0'
+__version__ = '3.3.0'

 # Users may not know how to add cl and CUDA to PATH. Let's do it before loading anything
 import os
@@ -32,6 +32,7 @@ from .runtime import (
 from .runtime.jit import jit
 from .compiler import compile, CompilationError
 from .errors import TritonError
+from .runtime._allocation import set_allocator

 from . import language
 from . import testing
@@ -44,7 +45,6 @@ __all__ = [
     "compile",
     "Config",
     "heuristics",
-    "impl",
     "InterpreterError",
     "jit",
     "JITFunction",
@@ -52,10 +52,10 @@ __all__ = [
     "language",
     "MockTensor",
     "next_power_of_2",
-    "ops",
     "OutOfResources",
     "reinterpret",
     "runtime",
+    "set_allocator",
     "TensorWrapper",
     "TritonError",
     "testing",
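Note: `set_allocator` (re-exported above from `triton.runtime._allocation`) lets user code supply the device buffers Triton requests at launch time. A minimal sketch, assuming the 3.3 callback shape of (size, alignment, stream) returning a device buffer; the allocator name is illustrative:

import torch
import triton

def my_allocator(size: int, alignment: int, stream):
    # torch.empty on CUDA returns allocations that satisfy typical
    # alignment requests, so the raw byte buffer can be handed back.
    return torch.empty(size, dtype=torch.int8, device="cuda")

triton.set_allocator(my_allocator)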
triton/_internal_testing.py
CHANGED
@@ -4,16 +4,18 @@ import numpy as np
 import torch
 import triton
 import triton.language as tl
+from triton.backends.nvidia.compiler import _path_to_binary
 import pytest

 from numpy.random import RandomState
 from typing import Optional, Union
-from triton.runtime.jit import TensorWrapper, reinterpret
+from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict

 int_dtypes = ['int8', 'int16', 'int32', 'int64']
 uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 integral_dtypes = int_dtypes + uint_dtypes
 float_dtypes = ['float16', 'float32', 'float64']
+float_dtypes_with_bfloat16 = float_dtypes + ['bfloat16']
 dtypes = integral_dtypes + float_dtypes
 dtypes_with_bfloat16 = dtypes + ['bfloat16']
 torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
@@ -35,11 +37,45 @@ def is_cuda():
     return False if target is None else target.backend == "cuda"


+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+
+
 def is_hip():
     target = get_current_target()
     return False if target is None else target.backend == "hip"


+def is_hip_mi200():
+    target = get_current_target()
+    if target is None or target.backend != 'hip':
+        return False
+    return target.arch == 'gfx90a'
+
+
+def is_hip_mi300():
+    target = get_current_target()
+    if target is None or target.backend != 'hip':
+        return False
+    return target.arch in ('gfx940', 'gfx941', 'gfx942')
+
+
+def is_hip_mi350():
+    target = get_current_target()
+    if target is None or target.backend != 'hip':
+        return False
+    return target.arch in ('gfx950')
+
+
+def is_hip_cdna():
+    return is_hip_mi200() or is_hip_mi300() or is_hip_mi350()
+
+
+def is_xpu():
+    target = get_current_target()
+    return False if target is None else target.backend == "xpu"
+
+
 def get_arch():
     target = get_current_target()
     return "" if target is None else str(target.arch)
@@ -94,6 +130,10 @@ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
     return torch.tensor(x, device=device)


+def str_to_triton_dtype(x: str) -> tl.dtype:
+    return tl.str_to_ty(type_canonicalisation_dict[x])
+
+
 def torch_dtype_name(dtype) -> str:
     if isinstance(dtype, triton.language.dtype):
         return dtype.name
@@ -116,8 +156,23 @@ def to_numpy(x):
     raise ValueError(f"Not a triton-compatible tensor: {x}")


-def supports_tma():
+def supports_tma(byval_only=False):
+    if is_interpreter():
+        return True
+    if not is_cuda():
+        return False
+    _, cuda_version = _path_to_binary("ptxas")
+    min_cuda_version = (12, 0) if byval_only else (12, 3)
+    cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
+    assert len(cuda_version_tuple) == 2, cuda_version_tuple
+    return torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version
+
+
+def tma_skip_msg(byval_only=False):
+    if byval_only:
+        return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)"
+    else:
+        return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)"


-requires_tma = pytest.mark.skipif(not supports_tma(), reason=
+requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
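For reference, the new gate in `supports_tma` compares the `ptxas` version as an integer tuple; a small illustration with made-up values:

# How the CUDA version check behaves (values are hypothetical).
cuda_version = "12.4"  # e.g. parsed from `ptxas --version`
cuda_version_tuple = tuple(map(int, cuda_version.split(".")))  # (12, 4)

assert cuda_version_tuple >= (12, 3)  # advanced TMA path (byval_only=False)
assert cuda_version_tuple >= (12, 0)  # __grid_constant__ path (byval_only=True)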
triton/_utils.py
ADDED
@@ -0,0 +1,35 @@
+from functools import reduce
+
+
+def get_iterable_path(iterable, path):
+    return reduce(lambda a, idx: a[idx], path, iterable)
+
+
+def set_iterable_path(iterable, path, val):
+    prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
+    prev[path[-1]] = val
+
+
+def find_paths_if(iterable, pred):
+    from .language import core
+    is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
+    ret = dict()
+
+    def _impl(current, path):
+        path = (path[0], ) if len(path) == 1 else tuple(path)
+        if is_iterable(current):
+            for idx, item in enumerate(current):
+                _impl(item, path + (idx, ))
+        elif pred(path, current):
+            if len(path) == 1:
+                ret[(path[0], )] = None
+            else:
+                ret[tuple(path)] = None
+
+    if is_iterable(iterable):
+        _impl(iterable, [])
+    elif pred(list(), iterable):
+        ret = {tuple(): None}
+    else:
+        ret = dict()
+    return list(ret.keys())
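These helpers address nested argument structures by index paths (tuples of indices); a quick sketch of their behavior on plain lists and tuples:

from triton._utils import find_paths_if, get_iterable_path, set_iterable_path

nested = [1, ("a", 2), [3, "b"]]

# Index paths of every leaf matching the predicate:
print(find_paths_if(nested, lambda path, x: isinstance(x, str)))  # [(1, 0), (2, 1)]

print(get_iterable_path(nested, (1, 0)))  # 'a'
set_iterable_path(nested, (2, 1), "c")    # nested becomes [1, ('a', 2), [3, 'c']]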
triton/backends/amd/compiler.py
CHANGED
@@ -1,4 +1,4 @@
-from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor
+from triton.backends.compiler import BaseBackend, GPUTarget
 from triton._C.libtriton import ir, passes, llvm, amd
 from dataclasses import dataclass
 from typing import Any, Dict, Tuple
@@ -13,16 +13,13 @@ from pathlib import Path


 def min_dot_size(target: GPUTarget):
-    return lambda lhsType, rhsType: (16, 16, 8)
-    # Other architectures will only support 16,16,16
-    return lambda lhsType, rhsType: (16, 16, 16)
+    # If some given configuration is not supported in hardware we fallback to FMA and cast arguments
+    return lambda lhsType, rhsType: (1, 1, 1)
+
+
+def is_pingpong_enabled(arch):
+    default = "1" if arch == "gfx942" else "0"
+    return os.getenv("TRITON_HIP_USE_BLOCK_PINGPONG", default) == "1"


 @dataclass(frozen=True)
@@ -31,10 +28,6 @@ class HIPOptions:
     waves_per_eu: int = 1
     num_stages: int = 2
     num_ctas: int = 1
-    num_buffers_warp_spec: int = 0
-    num_consumer_groups: int = 0
-    reg_dec_producer: int = 0
-    reg_inc_consumer: int = 0
     extern_libs: dict = None
     cluster_dims: tuple = (1, 1, 1)
     debug: bool = False
@@ -45,6 +38,7 @@ class HIPOptions:
     default_dot_input_precision: str = "ieee"
     allowed_dot_input_precisions: Tuple[str] = ("ieee", )
     enable_fp_fusion: bool = True
+    launch_cooperative_grid: bool = False
     matrix_instr_nonkdim: int = 0
     kpack: int = 1
     allow_flush_denorm: bool = False
@@ -52,11 +46,23 @@ class HIPOptions:
     backend_name: str = 'hip'

     # The following option provides hints to the AMDGPU backend regarding instruction scheduling
-    # for all `tt.dot` operations in a kernel. The "
+    # for all `tt.dot` operations in a kernel. The "none" variant preserves the default
     # instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
     # The option is experimental and may change at any time regarding its semantics and/or may
     # be gone entirely anytime.
+    #
+    # Current experimental scheduling variants:
+    #
+    # llvm-iglp-0: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `0` to the GEMM's
+    #              k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels".
+    # llvm-iglp-1: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `1` to the GEMM's
+    #              k-loop; i.e., "interleave DS and MFMA instructions for single wave small
+    #              GEMM kernels.".
+    # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
+    #                 Kernel library. Note, this variant requires the use of buffer load/store ops
+    #                 and a special software pipelining style - i.e., 1x LDS and 1x register
+    #                 prefetch buffers for each GEMM tile.
+    instruction_sched_variant: str = 'none'

     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
@@ -64,6 +70,9 @@ class HIPOptions:
         # Ignore user-defined warp size for gfx9
         warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
         object.__setattr__(self, 'warp_size', warp_size)
+        # Only kpack=1 is supported on gfx950
+        kpack = 1 if self.arch == 'gfx950' else self.kpack
+        object.__setattr__(self, 'kpack', kpack)
         libs = ["ocml", "ockl"]
         for lib in libs:
             extern_libs[lib] = str(default_libdir / f'{lib}.bc')
@@ -76,44 +85,6 @@ class HIPOptions:
         return hashlib.sha256(key.encode("utf-8")).hexdigest()


-@register_descriptor
-class HIPAttrsDescriptor(AttrsDescriptor):
-    # This property asserts if the underlying storage area of a given pointer
-    # can be resepresented as a 32 bit integer. When this is true, we can be
-    # sure that all indices into the tensor behind that pointer can use 32-bit
-    # indexing. That opens the door for the AMD backend to use buffer load/store
-    # instrinsics, which requires this property. Buffer load/store intrinsics
-    # gives direct out-of-bound support and simplifies index calculation for
-    # lower register pressure.
-    __slots__ = ("pointer_range_32")
-
-    def _add_backend_properties(self, params=None, values=None):
-        self.property_values["tt.pointer_range"] = 32
-        if params is None or values is None:
-            return
-
-        self.arg_properties["tt.pointer_range"] = [
-            param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg)
-            and not param.do_not_specialize and not param.do_not_specialize_on_alignment
-        ]
-
-    @staticmethod
-    def is_within2gb(arg):
-        if hasattr(arg, "ptr_range"):
-            return arg.ptr_range() <= 2**31 - 1
-        if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"):
-            # Please note that 2**31-1 is the max int32 positive limit
-            return arg.untyped_storage().size() <= 2**31 - 1
-        return False
-
-    @staticmethod
-    def get_property_key(val, align):
-        generic_key = AttrsDescriptor.get_property_key(val, align)
-        hip_key = "S" if HIPAttrsDescriptor.is_within2gb(val) else "N"
-        key = (generic_key + hip_key).replace("N", "")
-        return key if key else "N"
-
-
 class HIPBackend(BaseBackend):

     @staticmethod
@@ -126,17 +97,25 @@ class HIPBackend(BaseBackend):
         self.binary_ext = "hsaco"

     def parse_options(self, opts) -> Any:
-        args = {'arch': self.target.arch}
+        args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)}
+
+        # Enable XF32 (TF32) for CDNA3 GPUs
+        if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+            allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
+            allowed_dot_input_precisions.update({'tf32'})
+            args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))

         if "supported_fp8_dtypes" not in opts:
             supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
             if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
-                supported_fp8_dtypes.update({'fp8e4b8', 'fp8e5b16'})
+                supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
+            elif self.target.arch in ('gfx950'):
+                supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

         if "enable_fp_fusion" not in opts:
             args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
-        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts})
+        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None})
         return HIPOptions(**args)

     def pack_metadata(self, metadata):
@@ -149,23 +128,49 @@ class HIPBackend(BaseBackend):
             metadata.cluster_dims[2],
         )

-    def get_codegen_implementation(self):
+    def get_codegen_implementation(self, options):
         codegen_fns = {"min_dot_size": min_dot_size(self.target)}
         return codegen_fns

     def get_module_map(self) -> Dict[str, ModuleType]:
         from triton.language.extra.hip import libdevice
+
         return {"triton.language.extra.libdevice": libdevice}

     def load_dialects(self, ctx):
         amd.load_dialects(ctx)

+    @staticmethod
+    @functools.lru_cache()
+    def use_buffer_ops():
+        return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
+
+    @staticmethod
+    def is_within_2gb(arg):
+        import torch
+
+        MAX_INT_32 = 2**31 - 1
+        if hasattr(arg, "ptr_range"):
+            return arg.ptr_range() <= MAX_INT_32
+        if isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"):
+            return arg.untyped_storage().size() <= MAX_INT_32
+        return False
+
+    @staticmethod
+    def parse_attr(desc):
+        ret = BaseBackend.parse_attr(desc)
+        if "S" in desc:
+            ret += [["tt.pointer_range", 32]]
+        return ret

     @staticmethod
-    def
+    def get_arg_specialization(arg, ty, **kwargs):
+        ret = BaseBackend.get_arg_specialization(arg, ty, **kwargs)
+        # Only attempt to do buffer ops specialization if buffer ops are enabled.
+        # Otherwise the is_within_2gb check is unnecessary overhead.
+        if HIPBackend.use_buffer_ops() and ty == "tensor" and HIPBackend.is_within_2gb(arg):
+            ret += "S"
+        return ret

     @staticmethod
     def path_to_rocm_lld():
@@ -193,8 +198,8 @@ class HIPBackend(BaseBackend):
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
-        passes.ttir.add_combine(pm)
         passes.common.add_canonicalizer(pm)
+        passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
         passes.common.add_cse(pm)
         passes.common.add_licm(pm)
@@ -219,24 +224,38 @@ class HIPBackend(BaseBackend):
         passes.ttgpuir.add_remove_layout_conversions(pm)
         amd.passes.ttgpuir.add_optimize_epilogue(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
+        amd.passes.ttgpuir.add_hoist_layout_conversions(pm)
+
+        global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
+        local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
+
+        # The `local-prefetch` scheduling variant requires turning on buffer ops.
+        if options.instruction_sched_variant == "local-prefetch":
+            global_prefetch = local_prefetch = 1
+
         if amd.has_matrix_core_feature(options.arch):
             assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
                                              "We used to trigger software pipelining with "
                                              "num_stages == 0. Now it will not happen anymore; "
                                              "please update to use num_stages == 2 for "
                                              "equivalent behavior in the past.")
-            amd.passes.ttgpuir.
+            amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch)
             passes.common.add_canonicalizer(pm)
+        if options.instruction_sched_variant.lower() != "none":
+            amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_reduce_data_duplication(pm)
         if amd.has_matrix_core_feature(options.arch):
             amd.passes.ttgpuir.add_reorder_instructions(pm)
+            use_block_pingpong = is_pingpong_enabled(options.arch)
+            if use_block_pingpong and options.num_stages == 2:
+                amd.passes.ttgpuir.add_block_pingpong(pm)
+
+        if HIPBackend.use_buffer_ops():
             amd.passes.ttgpuir.add_canonicalize_pointers(pm)
             passes.common.add_canonicalizer(pm)
-            amd.passes.ttgpuir.add_convert_to_buffer_ops(pm)
+            amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
@@ -278,7 +297,8 @@ class HIPBackend(BaseBackend):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
+        if options.instruction_sched_variant.lower() != "none":
+            amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
         if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
             passes.llvmir.add_di_scope(pm)
         amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
@@ -289,12 +309,15 @@ class HIPBackend(BaseBackend):
         context = llvm.context()
         llvm_mod = llvm.to_module(mod, context)
         amd.attach_target_triple(llvm_mod)
+        target_features = ''
+        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+            target_features = '+xnack'
+        llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)

         # Set various control constants on the LLVM module so that device
         # libraries can resolve references to them.
         amd.set_isa_version(llvm_mod, options.arch)
-        amd.set_abi_version(llvm_mod,
+        amd.set_abi_version(llvm_mod, 500)
         amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False)
         amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True)
         amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False)
@@ -305,25 +328,46 @@ class HIPBackend(BaseBackend):
         # The public kernel should be kernel 0.
         fns[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL)
         fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}")
+        # LLVM AMDGPU backend supports the attribute "amdgpu-waves-per-eu"="<min>[, <max>]".
+        # This attribute may be attached to a kernel function definition and is an optimization hint.
+        # <min> parameter specifies the requested minimum number of waves per EU, and optional <max> parameter
+        # specifies the requested maximum number of waves per EU (must be greater than <min> if specified).
+        # If <max> is omitted, then there is no restriction on the maximum number of waves per EU other than
+        # the one dictated by the hardware for which the kernel is compiled. Passing 0, 0 as <min>, <max>
+        # implies the default behavior (no limits).
         fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
         denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
         fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
+        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+            fns[0].add_fn_target_feature("+xnack")
+            fns[0].add_fn_asan_attr()

         # Hint the compiler that we'd like the firmware to set the kernel arguments
         # to user SGPRs so that the kernel does not need to s_load its arguments
         # from memory.
         amd.set_all_fn_arg_inreg(fns[0])

-        if options.extern_libs:
+        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+            default_libdir = Path(__file__).parent / 'lib'
+            paths = [
+                str(default_libdir / 'asanrtl.bc'),
+                str(default_libdir / "ocml.bc"),
+                str(default_libdir / "ockl.bc")
+            ]
+            llvm.link_extern_libs(llvm_mod, paths)
+        elif options.extern_libs:
             paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
             llvm.link_extern_libs(llvm_mod, paths)

         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)

         # Get some metadata
-        metadata["shared"] = src.get_int_attr("
+        metadata["shared"] = src.get_int_attr("ttg.shared")

         amd.cleanup_bitcode_metadata(llvm_mod)
+        # Disable inlining of print related functions,
+        # because inlining of these function could slow down compilation significantly
+        amd.disable_print_inline(llvm_mod)
         return str(llvm_mod)

     @staticmethod
@@ -343,7 +387,10 @@ class HIPBackend(BaseBackend):

     @staticmethod
     def make_hsaco(src, metadata, options):
+        target_features = ''
+        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+            target_features = '+xnack'
+        hsaco = amd.assemble_amdgcn(src, options.arch, target_features)

         rocm_path = HIPBackend.path_to_rocm_lld()
         with tempfile.NamedTemporaryFile() as tmp_out: