triton-windows 3.3.0.post19__cp39-cp39-win_amd64.whl → 3.4.0.post20__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
- triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
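The most visible addition in this release is the experimental Gluon package (the triton/experimental/gluon entries above). A speculative smoke test, assuming only that the listed modules are importable from the installed wheel; the aliases are illustrative, not an official convention:

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

# Both subpackages ship in the wheel per the RECORD listed above.
print(gluon.__file__)
print([name for name in dir(ttgl) if not name.startswith("_")])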
triton/backends/amd/compiler.py
CHANGED
@@ -1,25 +1,29 @@
-from triton.backends.compiler import BaseBackend, GPUTarget
+from triton.backends.compiler import BaseBackend, GPUTarget, Language
 from triton._C.libtriton import ir, passes, llvm, amd
+from triton import knobs
 from dataclasses import dataclass
 from typing import Any, Dict, Tuple
 from types import ModuleType
 import hashlib
 import tempfile
-import os
 import re
 import subprocess
 import functools
 from pathlib import Path


-def
-#
-
+def get_min_dot_size(target: GPUTarget):
+    # We fallback to use FMA and cast arguments if certain configurations is
+    # not supported natively by matrix core units.
+    return lambda lhs_type, rhs_type: (1, 1, 1)


-def
-
-
+def is_pingpong_schedule_enabled(arch):
+    return (arch == "gfx942") if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
+
+
+def is_in_thread_transpose_enabled(arch):
+    return (arch == "gfx942") if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose


 @dataclass(frozen=True)
@@ -28,17 +32,13 @@ class HIPOptions:
     waves_per_eu: int = 1
     num_stages: int = 2
     num_ctas: int = 1
-    num_buffers_warp_spec: int = 0
-    num_consumer_groups: int = 0
-    reg_dec_producer: int = 0
-    reg_inc_consumer: int = 0
     extern_libs: dict = None
     cluster_dims: tuple = (1, 1, 1)
     debug: bool = False
     sanitize_overflow: bool = True
     arch: str = None
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", )
-
+    deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "ieee"
     allowed_dot_input_precisions: Tuple[str] = ("ieee", )
     enable_fp_fusion: bool = True
@@ -57,32 +57,30 @@ class HIPOptions:
     #
     # Current experimental scheduling variants:
     #
-    # llvm-iglp-0: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `0` to the GEMM's
-    #              k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels".
-    # llvm-iglp-1: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `1` to the GEMM's
-    #              k-loop; i.e., "interleave DS and MFMA instructions for single wave small
-    #              GEMM kernels.".
     # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
     #                 Kernel library. Note, this variant requires the use of buffer load/store ops
     #                 and a special software pipelining style - i.e., 1x LDS and 1x register
     #                 prefetch buffers for each GEMM tile.
-
+    # attention: enables a bunch of optimizations for attention kernels, including:
+    #   - iglp 2 and sched.barrier around it
+    #   - sink-insts-to-avoid-spills flag to avoid register spills
+    schedule_hint: str = 'none'

     def __post_init__(self):
+        gfx_major = int(self.arch[3:-2])  # Drop "gfx" prefix and minor/patch number
+        warp_size = 32 if gfx_major >= 10 else 64
+        object.__setattr__(self, 'warp_size', warp_size)
+        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
+            "num_warps must be a power of 2"
+
+        if self.arch == 'gfx950':
+            assert self.kpack == 1, "gfx950 only accepts kpack == 1"
+
         default_libdir = Path(__file__).parent / 'lib'
         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
-
-        warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
-        object.__setattr__(self, 'warp_size', warp_size)
-        # Only kpack=1 is supported on gfx950
-        kpack = 1 if self.arch == 'gfx950' else self.kpack
-        object.__setattr__(self, 'kpack', kpack)
-        libs = ["ocml", "ockl"]
-        for lib in libs:
+        for lib in ["ocml", "ockl"]:
             extern_libs[lib] = str(default_libdir / f'{lib}.bc')
         object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
-        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
-            "num_warps must be a power of 2"

     def hash(self):
         key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
@@ -100,26 +98,32 @@ class HIPBackend(BaseBackend):
         assert isinstance(target.arch, str)
         self.binary_ext = "hsaco"

+    def get_target_name(self, options) -> str:
+        return f"hip:{options.arch}"
+
     def parse_options(self, opts) -> Any:
-        args = {'arch':
+        args = {'arch': knobs.runtime.override_arch or self.target.arch}

         # Enable XF32 (TF32) for CDNA3 GPUs
-        if self.target.arch
+        if self.target.arch == 'gfx942':
             allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
             allowed_dot_input_precisions.update({'tf32'})
             args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))

         if "supported_fp8_dtypes" not in opts:
             supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
-            if self.target.arch
+            if self.target.arch == 'gfx942':
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
-            elif self.target.arch
+            elif self.target.arch == 'gfx950':
+                supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
+            elif 'gfx12' in self.target.arch:
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

         if "enable_fp_fusion" not in opts:
-            args["enable_fp_fusion"] =
-        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys()
+            args["enable_fp_fusion"] = knobs.language.default_fp_fusion
+        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() \
+                     if k in opts and opts[k] is not None})
         return HIPOptions(**args)

     def pack_metadata(self, metadata):
@@ -133,8 +137,7 @@ class HIPBackend(BaseBackend):
         )

     def get_codegen_implementation(self, options):
-
-        return codegen_fns
+        return {"min_dot_size": get_min_dot_size(self.target)}

     def get_module_map(self) -> Dict[str, ModuleType]:
         from triton.language.extra.hip import libdevice
@@ -144,11 +147,6 @@ class HIPBackend(BaseBackend):
     def load_dialects(self, ctx):
         amd.load_dialects(ctx)

-    @staticmethod
-    @functools.lru_cache()
-    def use_buffer_ops():
-        return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
-
     @staticmethod
     def is_within_2gb(arg):
         import torch
@@ -172,14 +170,14 @@ class HIPBackend(BaseBackend):
         ret = BaseBackend.get_arg_specialization(arg, ty, **kwargs)
         # Only attempt to do buffer ops specialization if buffer ops are enabled.
         # Otherwise the is_within_2gb check is unnecessary overhead.
-        if
+        if knobs.amd.use_buffer_ops and ty == "tensor" and HIPBackend.is_within_2gb(arg):
             ret += "S"
         return ret

     @staticmethod
     def path_to_rocm_lld():
         # Check env path for ld.lld
-        lld_env_path =
+        lld_env_path = knobs.amd.lld_path
         if lld_env_path is not None:
             lld = Path(lld_env_path)
             if lld.is_file():
@@ -202,11 +200,12 @@ class HIPBackend(BaseBackend):
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
+        passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
         passes.common.add_canonicalizer(pm)
         passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
         passes.common.add_cse(pm)
-        passes.
+        passes.ttir.add_triton_licm(pm)
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         pm.run(mod)
@@ -230,39 +229,62 @@ class HIPBackend(BaseBackend):
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         amd.passes.ttgpuir.add_hoist_layout_conversions(pm)

-
-
+        passes.ttgpuir.add_fuse_nested_loops(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.ttir.add_triton_licm(pm)
+        passes.common.add_canonicalizer(pm)
+
+        global_prefetch = knobs.amd.global_prefetch
+        local_prefetch = knobs.amd.local_prefetch
+        use_async_copy = knobs.amd.use_async_copy

         # The `local-prefetch` scheduling variant requires turning on buffer ops.
-        if options.
+        if options.schedule_hint == "local-prefetch":
             global_prefetch = local_prefetch = 1

-
-
-
-
-
-
-        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch)
-        passes.common.add_canonicalizer(pm)
-        if options.instruction_sched_variant.lower() != "none":
-            amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
+        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
+        if use_async_copy:
+            amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
+        passes.common.add_canonicalizer(pm)
+        if options.schedule_hint.lower() != "none":
+            amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.schedule_hint)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_reduce_data_duplication(pm)
-        if
-        amd.passes.ttgpuir.
-
-
-
-
-
+        if is_in_thread_transpose_enabled(options.arch):
+            amd.passes.ttgpuir.add_in_thread_transpose(pm)
+            passes.ttgpuir.add_remove_layout_conversions(pm)
+        amd.passes.ttgpuir.add_reorder_instructions(pm)
+        use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
+        if use_block_pingpong and options.num_stages == 2:
+            amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
+
+        if knobs.amd.use_buffer_ops:
             amd.passes.ttgpuir.add_canonicalize_pointers(pm)
             passes.common.add_canonicalizer(pm)
             amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
+
+        amd.passes.ttgpuir.add_fold_true_cmpi(pm)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
+        if use_async_copy:
+            amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
+        pm.run(mod)
+        return mod
+
+    @staticmethod
+    def ttgir_opt(src, metadata, options):
+        mod = src
+        pm = ir.pass_manager(mod.context)
+        pm.enable_debug()
+
+        passes.ttgpuir.add_inliner(pm)
+        passes.common.add_sccp(pm)
+        passes.ttir.add_loop_aware_cse(pm)
+        passes.ttgpuir.add_canonicalizer(pm)
+        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+
         pm.run(mod)
         return mod

@@ -272,7 +294,6 @@ class HIPBackend(BaseBackend):
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
-        amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
         # custom_lds_size is an experimental parameter that defines amount of LDS available
         # for one thread block. Measured in bytes.
         #
@@ -301,9 +322,9 @@ class HIPBackend(BaseBackend):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
-        if options.
+        if options.schedule_hint.lower() != "none":
             amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
-        if
+        if not knobs.compilation.disable_line_info:
             passes.llvmir.add_di_scope(pm)
         amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
         pm.run(mod)
@@ -314,7 +335,7 @@ class HIPBackend(BaseBackend):
         llvm_mod = llvm.to_module(mod, context)
         amd.attach_target_triple(llvm_mod)
         target_features = ''
-        if
+        if knobs.compilation.enable_asan:
             target_features = '+xnack'
         llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)

@@ -342,7 +363,7 @@ class HIPBackend(BaseBackend):
         fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
         denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
         fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
-        if
+        if knobs.compilation.enable_asan:
             fns[0].add_fn_target_feature("+xnack")
             fns[0].add_fn_asan_attr()

@@ -351,7 +372,7 @@ class HIPBackend(BaseBackend):
         # from memory.
         amd.set_all_fn_arg_inreg(fns[0])

-        if
+        if knobs.compilation.enable_asan:
             default_libdir = Path(__file__).parent / 'lib'
             paths = [
                 str(default_libdir / 'asanrtl.bc'),
@@ -365,6 +386,9 @@ class HIPBackend(BaseBackend):

         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)

+        if knobs.amd.scalarize_packed_fops:
+            amd.add_scalarize_packed_fops_llvm_pass(fns[0])
+
         # Get some metadata
         metadata["shared"] = src.get_int_attr("ttg.shared")

@@ -377,14 +401,21 @@ class HIPBackend(BaseBackend):
     @staticmethod
     def make_amdgcn(src, metadata, options):
         # Find kernel names (there should only be one)
-        # We get the name at the last possible step to
+        # We get the name at the last possible step to accommodate `triton.compile`
         # on user-provided LLVM
         names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src)
         assert len(names) == 1
         metadata["name"] = names[0]
         # llvm -> hsaco
-
-
+        flags = []
+        # The sink-insts-to-avoid-spills flag asks LLVM backend to sink instructions
+        # into loops to avoid register spills in the MachineSinking pass, while it
+        # can also lead to regression in some cases. But from current observation,
+        # the regression is not significant. It would be better to have some heuristics.
+        if options.schedule_hint == 'attention':
+            flags.append('sink-insts-to-avoid-spills')
+        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', flags, options.enable_fp_fusion, False)
+        if knobs.amd.dump_amdgcn:
             print("// -----// AMDGCN Dump //----- //")
             print(amdgcn)
         return amdgcn
@@ -392,7 +423,7 @@ class HIPBackend(BaseBackend):
     @staticmethod
     def make_hsaco(src, metadata, options):
         target_features = ''
-        if
+        if knobs.compilation.enable_asan:
             target_features = '+xnack'
         hsaco = amd.assemble_amdgcn(src, options.arch, target_features)

@@ -406,9 +437,12 @@ class HIPBackend(BaseBackend):
             ret = fd_out.read()
         return ret

-    def add_stages(self, stages, options):
-
-
+    def add_stages(self, stages, options, language):
+        if language == Language.TRITON:
+            stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
+            stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
+        elif language == Language.GLUON:
+            stages["ttgir"] = lambda src, metadata: self.ttgir_opt(src, metadata, options)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
         stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
         stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)
triton/backends/amd/driver.c
CHANGED
@@ -172,15 +172,18 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
   // get allocated registers and spilled registers from the function
   int n_regs = 0;
   int n_spills = 0;
+  int32_t n_max_threads = 0;
   hipSymbolTable.hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun);
   hipSymbolTable.hipFuncGetAttribute(&n_spills,
                                      HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
+  hipSymbolTable.hipFuncGetAttribute(
+      &n_max_threads, HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun);
   n_spills /= 4;
   if (PyErr_Occurred()) {
     return NULL;
   }
-  return Py_BuildValue("(
-                       n_spills);
+  return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
+                       n_spills, n_max_threads);
 }

 static PyMethodDef ModuleMethods[] = {
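On the Python side, this widens the tuple returned by the HIP utilities' load_binary binding from four to five elements. A hypothetical caller-side sketch; the helper function and its occupancy check are illustrative additions, and only the five-element return shape ("(KKiii)") comes from the diff above:

def load_kernel(utils, name, binary, shared, device, threads_per_block):
    # loadBinary now packs five values: module handle, function handle,
    # register count, spill count, and the kernel's max threads per block.
    mod, func, n_regs, n_spills, n_max_threads = utils.load_binary(
        name, binary, shared, device)
    if threads_per_block > n_max_threads:
        raise RuntimeError(
            f"{name} supports at most {n_max_threads} threads per block")
    return mod, func, n_regs, n_spills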