triton-windows 3.3.1.post19__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/backends/amd/compiler.py
CHANGED
@@ -1,25 +1,30 @@
-from triton.backends.compiler import BaseBackend, GPUTarget
+from triton.backends.compiler import BaseBackend, GPUTarget, Language
 from triton._C.libtriton import ir, passes, llvm, amd
+from triton import knobs
 from dataclasses import dataclass
 from typing import Any, Dict, Tuple
 from types import ModuleType
 import hashlib
 import tempfile
-import os
 import re
-import subprocess
 import functools
+import warnings
 from pathlib import Path
 
 
-def
-#
-
+def get_min_dot_size(target: GPUTarget):
+    # We fallback to use FMA and cast arguments if certain configurations is
+    # not supported natively by matrix core units.
+    return lambda lhs_type, rhs_type: (1, 1, 1)
 
 
-def
-
-
+def is_pingpong_schedule_enabled(arch, use_async_copy):
+    return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)
+            ) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
+
+
+def is_in_thread_transpose_enabled(arch):
+    return (arch == "gfx942") if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose
 
 
 @dataclass(frozen=True)
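
The two helpers introduced above follow a tri-state override pattern that recurs throughout this release: a knob left at None falls back to an architecture-derived default, while an explicit user setting wins unconditionally. A minimal standalone sketch of that pattern (the resolve helper is hypothetical, for illustration only):

    # Hypothetical sketch of the tri-state knob pattern behind
    # is_pingpong_schedule_enabled and is_in_thread_transpose_enabled.
    def resolve(knob_value, arch_default):
        # None means the user never set the knob: use the per-arch default.
        return arch_default if knob_value is None else knob_value

    assert resolve(None, True) is True    # unset knob -> architecture default
    assert resolve(False, True) is False  # explicit user setting always wins
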
@@ -28,17 +33,17 @@ class HIPOptions:
     waves_per_eu: int = 1
     num_stages: int = 2
     num_ctas: int = 1
-    num_buffers_warp_spec: int = 0
-    num_consumer_groups: int = 0
-    reg_dec_producer: int = 0
-    reg_inc_consumer: int = 0
     extern_libs: dict = None
     cluster_dims: tuple = (1, 1, 1)
     debug: bool = False
     sanitize_overflow: bool = True
     arch: str = None
-
-
+    # We have native support for OCP fp8 variants since CDNA4/RDNA4. For earlier generations,
+    # we software emulate the support for them.
+    # UZ fp8 variants (fp8e4b8 and fp8e5b16) are natively supported for CDNA3. For other
+    # architectures they are software emulated.
+    supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8")
+    deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "ieee"
     allowed_dot_input_precisions: Tuple[str] = ("ieee", )
     enable_fp_fusion: bool = True
@@ -48,6 +53,7 @@ class HIPOptions:
     allow_flush_denorm: bool = False
     max_num_imprecise_acc_default: int = 0
     backend_name: str = 'hip'
+    instrumentation_mode: str = ""
 
     # The following option provides hints to the AMDGPU backend regarding instruction scheduling
     # for all `tt.dot` operations in a kernel. The "none" variant preserves the default
@@ -57,32 +63,29 @@ class HIPOptions:
     #
     # Current experimental scheduling variants:
     #
-    #
-    #
-    #
-
-    # GEMM kernels.".
-    # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
-    #                 Kernel library. Note, this variant requires the use of buffer load/store ops
-    #                 and a special software pipelining style - i.e., 1x LDS and 1x register
-    #                 prefetch buffers for each GEMM tile.
-    instruction_sched_variant: str = 'none'
+    # attention: enables a bunch of optimizations for attention kernels, including:
+    #   - iglp 2 and sched.barrier around it
+    #   - sink-insts-to-avoid-spills flag to avoid register spills
+    schedule_hint: str = 'none'
 
     def __post_init__(self):
+        gfx_major = int(self.arch[3:-2])  # Drop "gfx" prefix and minor/patch number
+        warp_size = 32 if gfx_major >= 10 else 64
+        object.__setattr__(self, 'warp_size', warp_size)
+        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
+            "num_warps must be a power of 2"
+
+        if (self.arch == 'gfx950') and (self.kpack != 1):
+            warnings.warn(
+                f"kpack is deprecated starting from gfx950 and will be removed in later releases. So for now kpack = {self.kpack} will be overwritten to 1 to make transitioning easier."
+            )
+            object.__setattr__(self, 'kpack', 1)
+
         default_libdir = Path(__file__).parent / 'lib'
         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
-
-        warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
-        object.__setattr__(self, 'warp_size', warp_size)
-        # Only kpack=1 is supported on gfx950
-        kpack = 1 if self.arch == 'gfx950' else self.kpack
-        object.__setattr__(self, 'kpack', kpack)
-        libs = ["ocml", "ockl"]
-        for lib in libs:
+        for lib in ["ocml", "ockl"]:
             extern_libs[lib] = str(default_libdir / f'{lib}.bc')
         object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
-        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
-            "num_warps must be a power of 2"
 
     def hash(self):
         key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
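
The rewritten __post_init__ derives the wavefront size from the numeric generation in the gfx architecture string rather than substring checks, so new architecture names need no enumeration. A quick plain-Python check of the slice used above:

    # arch[3:-2] drops the "gfx" prefix and the trailing minor/patch digits,
    # leaving the generation number that decides the wavefront size.
    for arch, expected in [("gfx942", 64), ("gfx950", 64), ("gfx1100", 32), ("gfx1201", 32)]:
        gfx_major = int(arch[3:-2])
        warp_size = 32 if gfx_major >= 10 else 64
        assert warp_size == expected, arch
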
@@ -90,6 +93,7 @@ class HIPOptions:
 
 
 class HIPBackend(BaseBackend):
+    instrumentation = None
 
     @staticmethod
     def supports_target(target: GPUTarget):
@@ -100,26 +104,33 @@ class HIPBackend(BaseBackend):
         assert isinstance(target.arch, str)
         self.binary_ext = "hsaco"
 
+    def get_target_name(self, options) -> str:
+        return f"hip:{options.arch}"
+
     def parse_options(self, opts) -> Any:
-        args = {'arch':
+        args = {'arch': knobs.runtime.override_arch or self.target.arch}
+
+        if opts.get("num_ctas", 1) > 1:
+            raise ValueError("num_ctas > 1 not supported for AMD GPUs")
 
         # Enable XF32 (TF32) for CDNA3 GPUs
-        if self.target.arch
+        if self.target.arch == 'gfx942':
             allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
             allowed_dot_input_precisions.update({'tf32'})
             args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))
 
         if "supported_fp8_dtypes" not in opts:
-            supported_fp8_dtypes =
-
-
-
-
-            args["
+            args["supported_fp8_dtypes"] = tuple(sorted(HIPOptions.supported_fp8_dtypes))
+
+        if self.target.arch == 'gfx950':
+            deprecated_fp8_dot_operand_dtypes = set(HIPOptions.deprecated_fp8_dot_operand_dtypes)
+            deprecated_fp8_dot_operand_dtypes.update({"fp8e5b16", "fp8e4b8"})
+            args["deprecated_fp8_dot_operand_dtypes"] = tuple(sorted(deprecated_fp8_dot_operand_dtypes))
 
         if "enable_fp_fusion" not in opts:
-            args["enable_fp_fusion"] =
-        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys()
+            args["enable_fp_fusion"] = knobs.language.default_fp_fusion
+        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() \
+                     if k in opts and opts[k] is not None})
         return HIPOptions(**args)
 
     def pack_metadata(self, metadata):
@@ -133,8 +144,7 @@ class HIPBackend(BaseBackend):
         )
 
     def get_codegen_implementation(self, options):
-
-        return codegen_fns
+        return {"min_dot_size": get_min_dot_size(self.target)}
 
     def get_module_map(self) -> Dict[str, ModuleType]:
         from triton.language.extra.hip import libdevice
@@ -143,11 +153,8 @@ class HIPBackend(BaseBackend):
 
     def load_dialects(self, ctx):
         amd.load_dialects(ctx)
-
-
-    @functools.lru_cache()
-    def use_buffer_ops():
-        return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
+        if HIPBackend.instrumentation:
+            HIPBackend.instrumentation.load_dialects(ctx)
 
     @staticmethod
     def is_within_2gb(arg):
@@ -172,41 +179,22 @@ class HIPBackend(BaseBackend):
         ret = BaseBackend.get_arg_specialization(arg, ty, **kwargs)
         # Only attempt to do buffer ops specialization if buffer ops are enabled.
         # Otherwise the is_within_2gb check is unnecessary overhead.
-        if
+        if knobs.amd.use_buffer_ops and ty == "tensor" and HIPBackend.is_within_2gb(arg):
             ret += "S"
         return ret
 
-    @staticmethod
-    def path_to_rocm_lld():
-        # Check env path for ld.lld
-        lld_env_path = os.getenv("TRITON_HIP_LLD_PATH")
-        if lld_env_path is not None:
-            lld = Path(lld_env_path)
-            if lld.is_file():
-                return lld
-        # Check backend for ld.lld (used for pytorch wheels)
-        lld = Path(__file__).parent / "llvm/bin/ld.lld"
-        if lld.is_file():
-            return lld
-        lld = Path("/opt/rocm/llvm/bin/ld.lld")
-        if lld.is_file():
-            return lld
-        lld = Path("/usr/bin/ld.lld")
-        if lld.is_file():
-            return lld
-        raise Exception("ROCm linker /opt/rocm/llvm/bin/ld.lld not found. Set 'TRITON_HIP_LLD_PATH' to its path.")
-
     @staticmethod
     def make_ttir(mod, metadata, options):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
+        passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
         passes.common.add_canonicalizer(pm)
         passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
         passes.common.add_cse(pm)
-        passes.
+        passes.ttir.add_triton_licm(pm)
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         pm.run(mod)
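
Settings that 3.3.x read from environment variables at each call site (AMDGCN_USE_BUFFER_OPS above, TRITON_HIP_LLD_PATH in the deleted path_to_rocm_lld) now go through the new centralized triton/knobs.py module (+546 lines in this release). A hedged sketch of the migration, assuming the knobs are plain settable attributes as the call sites in this diff suggest:

    # 3.3.x style, removed in this diff: an env var read at each call site.
    # os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"

    # 3.5.x style suggested by this diff: set the knob once from Python.
    from triton import knobs
    knobs.amd.use_buffer_ops = True  # assumption: knobs are settable attributes
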
@@ -230,39 +218,60 @@ class HIPBackend(BaseBackend):
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         amd.passes.ttgpuir.add_hoist_layout_conversions(pm)
 
-
-
+        passes.ttgpuir.add_fuse_nested_loops(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.ttir.add_triton_licm(pm)
+        passes.common.add_canonicalizer(pm)
 
-
-
-
+        global_prefetch = knobs.amd.global_prefetch
+        local_prefetch = knobs.amd.local_prefetch
+        use_async_copy = knobs.amd.use_async_copy
+        use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy)
 
-
-
-
-
-
-
-        amd.passes.ttgpuir.
-        passes.common.add_canonicalizer(pm)
-        if options.instruction_sched_variant.lower() != "none":
-            amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
+        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
+                                               use_block_pingpong)
+        if use_async_copy:
+            amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
+        passes.common.add_canonicalizer(pm)
+        if options.schedule_hint.lower() != "none":
+            amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.schedule_hint)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_reduce_data_duplication(pm)
-        if
-        amd.passes.ttgpuir.
-
-
-
-
+        if is_in_thread_transpose_enabled(options.arch):
+            amd.passes.ttgpuir.add_in_thread_transpose(pm)
+            passes.ttgpuir.add_remove_layout_conversions(pm)
+        amd.passes.ttgpuir.add_reorder_instructions(pm)
+        if use_block_pingpong and options.num_stages > 1:
+            amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
+
+        if knobs.amd.use_buffer_ops:
             amd.passes.ttgpuir.add_canonicalize_pointers(pm)
             passes.common.add_canonicalizer(pm)
-            amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
+            amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch, knobs.amd.use_buffer_atomics)
+
+        amd.passes.ttgpuir.add_fold_true_cmpi(pm)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
+        if use_async_copy:
+            amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
+        pm.run(mod)
+        return mod
+
+    @staticmethod
+    def gluon_to_ttgir(src, metadata, options):
+        mod = src
+        pm = ir.pass_manager(mod.context)
+        pm.enable_debug()
+
+        passes.gluon.add_inliner(pm)
+        passes.gluon.add_resolve_auto_encodings(pm)
+        passes.common.add_sccp(pm)
+        passes.ttir.add_loop_aware_cse(pm)
+        passes.gluon.add_canonicalizer(pm)
+        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+
         pm.run(mod)
         return mod
 
@@ -272,7 +281,6 @@ class HIPBackend(BaseBackend):
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
-        amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
         # custom_lds_size is an experimental parameter that defines amount of LDS available
         # for one thread block. Measured in bytes.
         #
@@ -283,7 +291,10 @@ class HIPBackend(BaseBackend):
         passes.convert.add_scf_to_cf(pm)
         passes.convert.add_index_to_llvmir(pm)
 
-        passes.ttgpuir.add_allocate_shared_memory(pm)
+        amd.passes.ttgpuir.add_allocate_shared_memory(pm)
+        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
+        if HIPBackend.instrumentation:
+            HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
         ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
         ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
         ## of the value of kernel arg `allow_flush_denorm`.
@@ -301,10 +312,17 @@ class HIPBackend(BaseBackend):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
-
+
+        if options.schedule_hint.lower() != "none":
             amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
-
+
+        # This can not be moved below the di_scope pass
+        if HIPBackend.instrumentation:
+            HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+
+        if not knobs.compilation.disable_line_info:
             passes.llvmir.add_di_scope(pm)
+
         amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
         pm.run(mod)
 
@@ -314,7 +332,7 @@ class HIPBackend(BaseBackend):
         llvm_mod = llvm.to_module(mod, context)
         amd.attach_target_triple(llvm_mod)
         target_features = ''
-        if
+        if knobs.compilation.enable_asan:
             target_features = '+xnack'
         llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)
 
@@ -342,7 +360,7 @@ class HIPBackend(BaseBackend):
         fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
         denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
         fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
-        if
+        if knobs.compilation.enable_asan:
             fns[0].add_fn_target_feature("+xnack")
             fns[0].add_fn_asan_attr()
 
@@ -351,7 +369,7 @@ class HIPBackend(BaseBackend):
         # from memory.
         amd.set_all_fn_arg_inreg(fns[0])
 
-        if
+        if knobs.compilation.enable_asan:
             default_libdir = Path(__file__).parent / 'lib'
             paths = [
                 str(default_libdir / 'asanrtl.bc'),
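
All three AddressSanitizer branches in make_llir (and the one in make_hsaco below) now consult a single knobs.compilation.enable_asan switch instead of re-reading configuration per branch. Under the same settable-attribute assumption as above, enabling ASan for HIP compilation becomes one assignment:

    from triton import knobs
    # Assumption: settable knob; per this diff it gates '+xnack', the ASan
    # function attribute, and linking of asanrtl.bc.
    knobs.compilation.enable_asan = True
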
@@ -361,12 +379,27 @@ class HIPBackend(BaseBackend):
             llvm.link_extern_libs(llvm_mod, paths)
         elif options.extern_libs:
             paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
-
+            if len(paths) > 0:
+                llvm.link_extern_libs(llvm_mod, paths)
 
         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)
 
+        # Architectures with architected SGPRs store the workgroup id in ttmp9 (X) and ttmp7 (Y[15:0], Z[31:16]).
+        # These attributes are used to determine if Z should be masked out when loading Y. They are inferred during
+        # optimize_module from calls to @llvm.amdgcn.workgroup.id.x/y/z(). We cannot rely on this because a
+        # dispatch dimensions might be used even if there is no program_id() call for it.
+        if amd.has_architected_sgprs(options.arch):
+            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-x")
+            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-y")
+            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-z")
+
+        if knobs.amd.scalarize_packed_fops:
+            amd.add_scalarize_packed_fops_llvm_pass(fns[0])
+
         # Get some metadata
         metadata["shared"] = src.get_int_attr("ttg.shared")
+        metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
+        metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
 
         amd.cleanup_bitcode_metadata(llvm_mod)
         # Disable inlining of print related functions,
@@ -377,14 +410,23 @@ class HIPBackend(BaseBackend):
     @staticmethod
     def make_amdgcn(src, metadata, options):
         # Find kernel names (there should only be one)
-        # We get the name at the last possible step to
+        # We get the name at the last possible step to accommodate `triton.compile`
         # on user-provided LLVM
         names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src)
         assert len(names) == 1
         metadata["name"] = names[0]
         # llvm -> hsaco
-
-
+        flags = []
+        # The sink-insts-to-avoid-spills flag asks LLVM backend to sink instructions
+        # into loops to avoid register spills in the MachineSinking pass, while it
+        # can also lead to regression in some cases. But from current observation,
+        # the regression is not significant. It would be better to have some heuristics.
+        if options.schedule_hint == 'attention':
+            flags.append('sink-insts-to-avoid-spills')
+        features = '-real-true16' if 'gfx11' in options.arch else ''
+        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
+                                       False)
+        if knobs.amd.dump_amdgcn:
             print("// -----// AMDGCN Dump //----- //")
             print(amdgcn)
         return amdgcn
@@ -392,28 +434,28 @@ class HIPBackend(BaseBackend):
     @staticmethod
     def make_hsaco(src, metadata, options):
         target_features = ''
-        if
+        if knobs.compilation.enable_asan:
             target_features = '+xnack'
         hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
-
-        rocm_path = HIPBackend.path_to_rocm_lld()
         with tempfile.NamedTemporaryFile() as tmp_out:
             with tempfile.NamedTemporaryFile() as tmp_in:
-                with open(tmp_in.name,
+                with open(tmp_in.name, "wb") as fd_in:
                     fd_in.write(hsaco)
-
-                with open(tmp_out.name,
+                amd.link_hsaco(tmp_in.name, tmp_out.name)
+            with open(tmp_out.name, "rb") as fd_out:
                 ret = fd_out.read()
             return ret
 
-    def add_stages(self, stages, options):
-
-
+    def add_stages(self, stages, options, language):
+        if language == Language.TRITON:
+            stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
+            stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
+        elif language == Language.GLUON:
+            stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
        stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
        stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)
 
     @functools.lru_cache()
     def hash(self):
-
-        return f'{version}-{self.target}'
+        return f'{self.target}'