triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/backends/amd/driver.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import os
|
|
3
|
-
import hashlib
|
|
4
3
|
import subprocess
|
|
5
|
-
import
|
|
4
|
+
import re
|
|
6
5
|
from pathlib import Path
|
|
7
|
-
from triton
|
|
8
|
-
from triton.runtime.cache import get_cache_manager
|
|
6
|
+
from triton import knobs
|
|
9
7
|
from triton.backends.compiler import GPUTarget
|
|
10
8
|
from triton.backends.driver import GPUDriver
|
|
9
|
+
from triton.runtime import _allocation
|
|
10
|
+
from triton.runtime.build import compile_module_from_src
|
|
11
|
+
from triton.tools.tensor_descriptor import TensorDescriptor
|
|
11
12
|
|
|
12
13
|
dirname = os.path.dirname(os.path.realpath(__file__))
|
|
13
|
-
|
|
14
|
+
include_dirs = [os.path.join(dirname, "include")]
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
def _find_already_mmapped_dylib_on_linux(lib_name):
|
|
@@ -66,8 +67,7 @@ def _get_path_to_hip_runtime_dylib():
|
|
|
66
67
|
lib_name = "libamdhip64.so"
|
|
67
68
|
|
|
68
69
|
# If we are told explicitly what HIP runtime dynamic library to use, obey that.
|
|
69
|
-
env_libhip_path
|
|
70
|
-
if env_libhip_path:
|
|
70
|
+
if env_libhip_path := knobs.amd.libhip_path:
|
|
71
71
|
if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
|
|
72
72
|
return env_libhip_path
|
|
73
73
|
raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
|
|
@@ -81,6 +81,12 @@ def _get_path_to_hip_runtime_dylib():
|
|
|
81
81
|
|
|
82
82
|
paths = []
|
|
83
83
|
|
|
84
|
+
# Check backend
|
|
85
|
+
local_lib = os.path.join(os.path.dirname(__file__), "lib", lib_name)
|
|
86
|
+
if os.path.exists(local_lib):
|
|
87
|
+
return local_lib
|
|
88
|
+
paths.append(local_lib)
|
|
89
|
+
|
|
84
90
|
import site
|
|
85
91
|
# First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
|
|
86
92
|
# that we run Triton together with PyTorch. This makes sure we use the same dynamic
|
|
@@ -104,8 +110,36 @@ def _get_path_to_hip_runtime_dylib():
|
|
|
104
110
|
return f
|
|
105
111
|
paths.append(f)
|
|
106
112
|
|
|
113
|
+
# HIP_PATH should point to HIP SDK root if set
|
|
114
|
+
env_hip_path = os.getenv("HIP_PATH")
|
|
115
|
+
if env_hip_path:
|
|
116
|
+
hip_lib_path = os.path.join(env_hip_path, "lib", lib_name)
|
|
117
|
+
if os.path.exists(hip_lib_path):
|
|
118
|
+
return hip_lib_path
|
|
119
|
+
paths.append(hip_lib_path)
|
|
120
|
+
|
|
121
|
+
# if available, `hipconfig --path` prints the HIP SDK root
|
|
122
|
+
try:
|
|
123
|
+
hip_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip()
|
|
124
|
+
if hip_root:
|
|
125
|
+
hip_lib_path = os.path.join(hip_root, "lib", lib_name)
|
|
126
|
+
if os.path.exists(hip_lib_path):
|
|
127
|
+
return hip_lib_path
|
|
128
|
+
paths.append(hip_lib_path)
|
|
129
|
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
130
|
+
# hipconfig may not be available
|
|
131
|
+
pass
|
|
132
|
+
|
|
133
|
+
# ROCm lib dir based on env var
|
|
134
|
+
env_rocm_path = os.getenv("ROCM_PATH")
|
|
135
|
+
if env_rocm_path:
|
|
136
|
+
rocm_lib_path = os.path.join(env_rocm_path, "lib", lib_name)
|
|
137
|
+
if os.path.exists(rocm_lib_path):
|
|
138
|
+
return rocm_lib_path
|
|
139
|
+
paths.append(rocm_lib_path)
|
|
140
|
+
|
|
107
141
|
# Afterwards try to search the loader dynamic library resolution paths.
|
|
108
|
-
libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
|
|
142
|
+
libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
|
|
109
143
|
# each line looks like the following:
|
|
110
144
|
# libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
|
|
111
145
|
# libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
|
|
@@ -124,25 +158,6 @@ def _get_path_to_hip_runtime_dylib():
|
|
|
124
158
|
raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")
|
|
125
159
|
|
|
126
160
|
|
|
127
|
-
def compile_module_from_src(src, name):
|
|
128
|
-
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
|
|
129
|
-
cache = get_cache_manager(key)
|
|
130
|
-
cache_path = cache.get_file(f"{name}.so")
|
|
131
|
-
if cache_path is None:
|
|
132
|
-
with tempfile.TemporaryDirectory() as tmpdir:
|
|
133
|
-
src_path = os.path.join(tmpdir, f"{name}.c")
|
|
134
|
-
with open(src_path, "w") as f:
|
|
135
|
-
f.write(src)
|
|
136
|
-
so = _build(name, src_path, tmpdir, [], include_dir, [])
|
|
137
|
-
with open(so, "rb") as f:
|
|
138
|
-
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
|
|
139
|
-
import importlib.util
|
|
140
|
-
spec = importlib.util.spec_from_file_location(name, cache_path)
|
|
141
|
-
mod = importlib.util.module_from_spec(spec)
|
|
142
|
-
spec.loader.exec_module(mod)
|
|
143
|
-
return mod
|
|
144
|
-
|
|
145
|
-
|
|
146
161
|
class HIPUtils(object):
|
|
147
162
|
|
|
148
163
|
def __new__(cls):
|
|
@@ -157,7 +172,7 @@ class HIPUtils(object):
|
|
|
157
172
|
# This way we don't need to escape-quote C code curly brackets and we can replace
|
|
158
173
|
# exactly once.
|
|
159
174
|
src = src.replace('/*py_libhip_search_path*/', libhip_path, 1)
|
|
160
|
-
mod = compile_module_from_src(src, "hip_utils")
|
|
175
|
+
mod = compile_module_from_src(src=src, name="hip_utils", include_dirs=include_dirs)
|
|
161
176
|
self.load_binary = mod.load_binary
|
|
162
177
|
self.get_device_properties = mod.get_device_properties
|
|
163
178
|
|
|
@@ -167,26 +182,71 @@ def ty_to_cpp(ty):
|
|
|
167
182
|
if ty[0] == '*':
|
|
168
183
|
return "hipDeviceptr_t"
|
|
169
184
|
return {
|
|
170
|
-
"i1": "
|
|
185
|
+
"i1": "int8_t",
|
|
171
186
|
"i8": "int8_t",
|
|
172
187
|
"i16": "int16_t",
|
|
173
188
|
"i32": "int32_t",
|
|
174
189
|
"i64": "int64_t",
|
|
175
|
-
"u1": "
|
|
190
|
+
"u1": "uint8_t",
|
|
176
191
|
"u8": "uint8_t",
|
|
177
192
|
"u16": "uint16_t",
|
|
178
193
|
"u32": "uint32_t",
|
|
179
194
|
"u64": "uint64_t",
|
|
180
|
-
"fp16": "
|
|
181
|
-
"bf16": "
|
|
182
|
-
"fp32": "
|
|
183
|
-
"f32": "
|
|
195
|
+
"fp16": "double",
|
|
196
|
+
"bf16": "double",
|
|
197
|
+
"fp32": "double",
|
|
198
|
+
"f32": "double",
|
|
184
199
|
"fp64": "double",
|
|
185
200
|
}[ty]
|
|
186
201
|
|
|
187
202
|
|
|
203
|
+
FLOAT_STORAGE_TYPE = {
|
|
204
|
+
"fp16": "uint16_t",
|
|
205
|
+
"bf16": "uint16_t",
|
|
206
|
+
"fp32": "uint32_t",
|
|
207
|
+
"f32": "uint32_t",
|
|
208
|
+
"fp64": "uint64_t",
|
|
209
|
+
}
|
|
210
|
+
FLOAT_PACK_FUNCTION = {
|
|
211
|
+
"fp16": "pack_fp16",
|
|
212
|
+
"bf16": "pack_bf16",
|
|
213
|
+
"fp32": "pack_fp32",
|
|
214
|
+
"f32": "pack_fp32",
|
|
215
|
+
"fp64": "pack_fp64",
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
_BASE_ARGS_FORMAT = "piiiKKOOOOO"
|
|
219
|
+
|
|
220
|
+
|
|
188
221
|
def make_launcher(constants, signature, warp_size):
|
|
189
222
|
|
|
223
|
+
def _expand_signature(signature):
|
|
224
|
+
output = []
|
|
225
|
+
# Expand tensor descriptor arguments into base pointer, shape, and
|
|
226
|
+
# strides
|
|
227
|
+
for sig in signature:
|
|
228
|
+
if isinstance(sig, str) and sig.startswith("tensordesc"):
|
|
229
|
+
ndim = sig.count(",") + 1
|
|
230
|
+
dtype = re.match("tensordesc<([^[>]*)", sig).group()
|
|
231
|
+
|
|
232
|
+
output.append("*" + dtype)
|
|
233
|
+
for _ in range(2 * ndim):
|
|
234
|
+
output.append("i64")
|
|
235
|
+
output.append("i1")
|
|
236
|
+
# Currently the host side tensor descriptors get passed in as a
|
|
237
|
+
# tensor desc, shape, and strides. We have no way to use these
|
|
238
|
+
# shape and strides when processing tensor descriptors which is
|
|
239
|
+
# why we provide our own decomposition above. Sadly this means
|
|
240
|
+
# we have to pass the shape and strides twice.
|
|
241
|
+
for _ in range(ndim):
|
|
242
|
+
output.append("i32")
|
|
243
|
+
for _ in range(ndim):
|
|
244
|
+
output.append("i64")
|
|
245
|
+
else:
|
|
246
|
+
output.append(sig)
|
|
247
|
+
|
|
248
|
+
return output
|
|
249
|
+
|
|
190
250
|
def _serialize_signature(sig):
|
|
191
251
|
if isinstance(sig, tuple):
|
|
192
252
|
return ','.join(map(_serialize_signature, sig))
|
|
@@ -198,7 +258,7 @@ def make_launcher(constants, signature, warp_size):
|
|
|
198
258
|
return f"[{val}]"
|
|
199
259
|
if ty[0] == '*':
|
|
200
260
|
return "PyObject*"
|
|
201
|
-
if ty
|
|
261
|
+
if ty == "constexpr":
|
|
202
262
|
return "PyObject*"
|
|
203
263
|
return ty_to_cpp(ty)
|
|
204
264
|
|
|
@@ -208,10 +268,9 @@ def make_launcher(constants, signature, warp_size):
|
|
|
208
268
|
return f"({val})"
|
|
209
269
|
if ty[0] == '*':
|
|
210
270
|
return "O"
|
|
211
|
-
if ty
|
|
271
|
+
if ty == "constexpr":
|
|
212
272
|
return "O"
|
|
213
273
|
return {
|
|
214
|
-
"float": "f",
|
|
215
274
|
"double": "d",
|
|
216
275
|
"long": "l",
|
|
217
276
|
"int8_t": "b",
|
|
@@ -224,30 +283,51 @@ def make_launcher(constants, signature, warp_size):
|
|
|
224
283
|
"uint64_t": "K",
|
|
225
284
|
}[ty_to_cpp(ty)]
|
|
226
285
|
|
|
286
|
+
signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
|
|
287
|
+
|
|
227
288
|
args_format = ''.join([format_of(ty) for ty in signature.values()])
|
|
228
|
-
format =
|
|
289
|
+
format = _BASE_ARGS_FORMAT + args_format
|
|
229
290
|
signature = ','.join(map(_serialize_signature, signature.values()))
|
|
230
291
|
signature = list(filter(bool, signature.split(',')))
|
|
231
292
|
signature = {i: s for i, s in enumerate(signature)}
|
|
232
293
|
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
|
|
233
294
|
# Record the end of regular arguments;
|
|
234
295
|
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
|
|
235
|
-
|
|
296
|
+
arg_decl_list = []
|
|
297
|
+
for i, ty in signature.items():
|
|
298
|
+
if ty == "constexpr":
|
|
299
|
+
continue
|
|
300
|
+
if ty in FLOAT_STORAGE_TYPE:
|
|
301
|
+
arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
|
|
302
|
+
else:
|
|
303
|
+
arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
|
|
304
|
+
arg_decls = ', '.join(arg_decl_list)
|
|
236
305
|
internal_args_list = []
|
|
237
306
|
for i, ty in signature.items():
|
|
238
307
|
if ty[0] == "*":
|
|
239
308
|
internal_args_list.append(f"ptr_info{i}.dev_ptr")
|
|
309
|
+
elif ty in FLOAT_STORAGE_TYPE:
|
|
310
|
+
internal_args_list.append(f"_arg{i}_storage")
|
|
240
311
|
elif ty != "constexpr":
|
|
241
312
|
internal_args_list.append(f"_arg{i}")
|
|
313
|
+
|
|
314
|
+
float_storage_decls = [
|
|
315
|
+
f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
|
|
316
|
+
for i, ty in signature.items()
|
|
317
|
+
if ty in FLOAT_STORAGE_TYPE
|
|
318
|
+
]
|
|
319
|
+
|
|
242
320
|
libhip_path = _get_path_to_hip_runtime_dylib()
|
|
243
321
|
|
|
244
322
|
# generate glue code
|
|
245
323
|
params = list(range(len(signature)))
|
|
246
324
|
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
|
|
247
325
|
params.append("&global_scratch")
|
|
326
|
+
params.append("&profile_scratch")
|
|
248
327
|
src = f"""
|
|
249
328
|
#define __HIP_PLATFORM_AMD__
|
|
250
329
|
#include <hip/hip_runtime.h>
|
|
330
|
+
#include <hip/hip_runtime_api.h>
|
|
251
331
|
#include <Python.h>
|
|
252
332
|
#include <dlfcn.h>
|
|
253
333
|
#include <stdbool.h>
|
|
@@ -260,6 +340,7 @@ static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};
|
|
|
260
340
|
// The list of HIP dynamic library symbols and their signature we are interested
|
|
261
341
|
// in this file.
|
|
262
342
|
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\
|
|
343
|
+
FOR_EACH_STR_FN(hipGetLastError) \\
|
|
263
344
|
FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \\
|
|
264
345
|
FOR_EACH_ERR_FN(hipModuleLaunchKernel, hipFunction_t f, \\
|
|
265
346
|
unsigned int gridDimX, unsigned int gridDimY, \\
|
|
@@ -291,9 +372,6 @@ static struct HIPSymbolTable hipSymbolTable;
|
|
|
291
372
|
bool initSymbolTable() {{
|
|
292
373
|
// Use the HIP runtime library loaded into the existing process if it exits.
|
|
293
374
|
void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
|
|
294
|
-
if (lib) {{
|
|
295
|
-
// printf("[triton] chosen loaded libamdhip64.so in the process\\n");
|
|
296
|
-
}}
|
|
297
375
|
|
|
298
376
|
// Otherwise, go through the list of search paths to dlopen the first HIP
|
|
299
377
|
// driver library.
|
|
@@ -303,7 +381,6 @@ bool initSymbolTable() {{
|
|
|
303
381
|
void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
|
|
304
382
|
if (handle) {{
|
|
305
383
|
lib = handle;
|
|
306
|
-
// printf("[triton] chosen %s\\n", hipLibSearchPaths[i]);
|
|
307
384
|
}}
|
|
308
385
|
}}
|
|
309
386
|
}}
|
|
@@ -312,17 +389,36 @@ bool initSymbolTable() {{
|
|
|
312
389
|
return false;
|
|
313
390
|
}}
|
|
314
391
|
|
|
315
|
-
|
|
392
|
+
typedef hipError_t (*hipGetProcAddress_fn)(
|
|
393
|
+
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
|
|
394
|
+
hipDriverProcAddressQueryResult *symbolStatus);
|
|
395
|
+
hipGetProcAddress_fn hipGetProcAddress;
|
|
316
396
|
dlerror(); // Clear existing errors
|
|
317
397
|
const char *error = NULL;
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
398
|
+
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
|
|
399
|
+
error = dlerror();
|
|
400
|
+
if (error) {{
|
|
401
|
+
PyErr_SetString(PyExc_RuntimeError,
|
|
402
|
+
"cannot query 'hipGetProcAddress' from libamdhip64.so");
|
|
403
|
+
dlclose(lib);
|
|
404
|
+
return false;
|
|
405
|
+
}}
|
|
406
|
+
|
|
407
|
+
// Resolve all symbols we are interested in.
|
|
408
|
+
int hipVersion = HIP_VERSION;
|
|
409
|
+
uint64_t hipFlags = 0;
|
|
410
|
+
hipDriverProcAddressQueryResult symbolStatus;
|
|
411
|
+
hipError_t status = hipSuccess;
|
|
412
|
+
#define QUERY_EACH_FN(hipSymbolName, ...) \
|
|
413
|
+
status = hipGetProcAddress(#hipSymbolName, \
|
|
414
|
+
(void **)&hipSymbolTable.hipSymbolName, \
|
|
415
|
+
hipVersion, hipFlags, &symbolStatus); \
|
|
416
|
+
if (status != hipSuccess) {{ \
|
|
417
|
+
PyErr_SetString(PyExc_RuntimeError, \
|
|
418
|
+
"cannot get address for '" #hipSymbolName \
|
|
419
|
+
"' from libamdhip64.so"); \
|
|
420
|
+
dlclose(lib); \
|
|
421
|
+
return false; \
|
|
326
422
|
}}
|
|
327
423
|
|
|
328
424
|
HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
|
|
@@ -344,8 +440,7 @@ static inline void gpuAssert(hipError_t code, const char *file, int line)
|
|
|
344
440
|
|
|
345
441
|
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
|
346
442
|
|
|
347
|
-
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
|
348
|
-
// printf("_launch hip kernel\\n");
|
|
443
|
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
|
349
444
|
hipDeviceptr_t global_scratch = 0;
|
|
350
445
|
void *params[] = {{ {', '.join(params)} }};
|
|
351
446
|
if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
|
|
@@ -362,8 +457,11 @@ typedef struct _DevicePtrInfo {{
|
|
|
362
457
|
bool valid;
|
|
363
458
|
}} DevicePtrInfo;
|
|
364
459
|
|
|
460
|
+
static PyObject* data_ptr_str = NULL;
|
|
461
|
+
|
|
365
462
|
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
|
366
463
|
DevicePtrInfo ptr_info;
|
|
464
|
+
hipError_t status = hipSuccess;
|
|
367
465
|
ptr_info.dev_ptr = 0;
|
|
368
466
|
ptr_info.valid = true;
|
|
369
467
|
if (PyLong_Check(obj)) {{
|
|
@@ -374,53 +472,81 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
|
|
374
472
|
// valid nullptr
|
|
375
473
|
return ptr_info;
|
|
376
474
|
}}
|
|
377
|
-
PyObject *
|
|
378
|
-
if(
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
475
|
+
PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
|
|
476
|
+
if (!ret) {{
|
|
477
|
+
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
|
478
|
+
ptr_info.valid = false;
|
|
479
|
+
goto cleanup;
|
|
480
|
+
}}
|
|
481
|
+
if (!PyLong_Check(ret)) {{
|
|
482
|
+
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
|
483
|
+
ptr_info.valid = false;
|
|
484
|
+
goto cleanup;
|
|
485
|
+
}}
|
|
486
|
+
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
|
487
|
+
if (!ptr_info.dev_ptr)
|
|
488
|
+
goto cleanup;
|
|
489
|
+
uint64_t dev_ptr;
|
|
490
|
+
status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
|
491
|
+
if (status == hipErrorInvalidValue) {{
|
|
492
|
+
PyErr_Format(PyExc_ValueError,
|
|
493
|
+
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
|
385
494
|
ptr_info.valid = false;
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
|
389
|
-
if(!ptr_info.dev_ptr)
|
|
390
|
-
return ptr_info;
|
|
391
|
-
uint64_t dev_ptr;
|
|
392
|
-
hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
|
393
|
-
if (status == hipErrorInvalidValue) {{
|
|
394
|
-
PyErr_Format(PyExc_ValueError,
|
|
395
|
-
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
|
396
|
-
ptr_info.valid = false;
|
|
397
|
-
}}
|
|
398
|
-
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
|
399
|
-
Py_DECREF(ret);
|
|
400
|
-
return ptr_info;
|
|
495
|
+
// Clear and ignore HIP error
|
|
496
|
+
(void)hipSymbolTable.hipGetLastError();
|
|
401
497
|
}}
|
|
402
|
-
|
|
498
|
+
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
|
499
|
+
cleanup:
|
|
500
|
+
Py_DECREF(ret);
|
|
403
501
|
return ptr_info;
|
|
404
502
|
}}
|
|
405
503
|
|
|
504
|
+
static uint16_t pack_fp16(double f) {{
|
|
505
|
+
uint16_t result;
|
|
506
|
+
// from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
|
|
507
|
+
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
|
|
508
|
+
_PyFloat_Pack2(f, (unsigned char*)&result, 1);
|
|
509
|
+
#else
|
|
510
|
+
PyFloat_Pack2(f, (char*)&result, 1);
|
|
511
|
+
#endif
|
|
512
|
+
return result;
|
|
513
|
+
}}
|
|
514
|
+
|
|
515
|
+
static uint16_t pack_bf16(double f) {{
|
|
516
|
+
float f32 = (float)f;
|
|
517
|
+
uint32_t u32 = *(uint32_t*)&f32;
|
|
518
|
+
return (uint16_t)(u32 >> 16);
|
|
519
|
+
}}
|
|
520
|
+
|
|
521
|
+
static uint32_t pack_fp32(double f) {{
|
|
522
|
+
float f32 = (float)f;
|
|
523
|
+
return *(uint32_t*)&f32;
|
|
524
|
+
}}
|
|
525
|
+
|
|
526
|
+
static uint64_t pack_fp64(double f) {{
|
|
527
|
+
return *(uint64_t*)&f;
|
|
528
|
+
}}
|
|
529
|
+
|
|
406
530
|
static PyObject* launch(PyObject* self, PyObject* args) {{
|
|
407
|
-
// printf("launch\\n");
|
|
408
531
|
int gridX, gridY, gridZ;
|
|
409
532
|
uint64_t _stream;
|
|
410
533
|
uint64_t _function;
|
|
411
534
|
int launch_cooperative_grid;
|
|
535
|
+
PyObject *profile_scratch_obj = NULL;
|
|
412
536
|
PyObject *launch_enter_hook = NULL;
|
|
413
537
|
PyObject *launch_exit_hook = NULL;
|
|
414
538
|
PyObject *kernel_metadata = NULL;
|
|
415
539
|
PyObject *launch_metadata = NULL;
|
|
416
540
|
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
|
417
541
|
if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
|
|
418
|
-
&gridX, &gridY, &gridZ, &_stream, &_function,
|
|
542
|
+
&gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
|
|
419
543
|
&kernel_metadata, &launch_metadata,
|
|
420
544
|
&launch_enter_hook, &launch_exit_hook {args_list})) {{
|
|
421
545
|
return NULL;
|
|
422
546
|
}}
|
|
423
547
|
|
|
548
|
+
{' '.join(float_storage_decls)}
|
|
549
|
+
|
|
424
550
|
// extract kernel metadata
|
|
425
551
|
int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
|
|
426
552
|
if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
|
|
@@ -428,32 +554,36 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
|
|
428
554
|
}}
|
|
429
555
|
// extract launch metadata
|
|
430
556
|
if (launch_enter_hook != Py_None){{
|
|
431
|
-
PyObject*
|
|
432
|
-
PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
|
|
433
|
-
Py_DECREF(args);
|
|
557
|
+
PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
|
|
434
558
|
if (!ret)
|
|
435
559
|
return NULL;
|
|
560
|
+
Py_DECREF(ret);
|
|
436
561
|
}}
|
|
437
562
|
|
|
563
|
+
hipDeviceptr_t profile_scratch = 0;
|
|
564
|
+
if (profile_scratch_obj != Py_None) {{
|
|
565
|
+
DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
|
|
566
|
+
if (!profile_scratch_info.valid) {{
|
|
567
|
+
return NULL;
|
|
568
|
+
}}
|
|
569
|
+
profile_scratch = profile_scratch_info.dev_ptr;
|
|
570
|
+
}}
|
|
438
571
|
|
|
439
572
|
// raise exception asap
|
|
440
573
|
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
|
441
|
-
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
|
|
574
|
+
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
|
|
442
575
|
|
|
443
576
|
if(launch_exit_hook != Py_None){{
|
|
444
|
-
PyObject*
|
|
445
|
-
PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
|
|
446
|
-
Py_DECREF(args);
|
|
577
|
+
PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
|
|
447
578
|
if (!ret)
|
|
448
579
|
return NULL;
|
|
580
|
+
Py_DECREF(ret);
|
|
449
581
|
}}
|
|
450
582
|
|
|
451
583
|
if(PyErr_Occurred()) {{
|
|
452
584
|
return NULL;
|
|
453
585
|
}}
|
|
454
|
-
|
|
455
|
-
Py_INCREF(Py_None);
|
|
456
|
-
return Py_None;
|
|
586
|
+
Py_RETURN_NONE;
|
|
457
587
|
}}
|
|
458
588
|
|
|
459
589
|
static PyMethodDef ModuleMethods[] = {{
|
|
@@ -477,6 +607,10 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
|
|
477
607
|
if(m == NULL) {{
|
|
478
608
|
return NULL;
|
|
479
609
|
}}
|
|
610
|
+
data_ptr_str = PyUnicode_InternFromString("data_ptr");
|
|
611
|
+
if(data_ptr_str == NULL) {{
|
|
612
|
+
return NULL;
|
|
613
|
+
}}
|
|
480
614
|
PyModule_AddFunctions(m, ModuleMethods);
|
|
481
615
|
return m;
|
|
482
616
|
}}
|
|
@@ -484,6 +618,31 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
|
|
484
618
|
return src
|
|
485
619
|
|
|
486
620
|
|
|
621
|
+
def wrap_handle_tensor_descriptor(launcher):
|
|
622
|
+
"""
|
|
623
|
+
Replace all tensor descriptors with the base ptr, shape, and strides
|
|
624
|
+
"""
|
|
625
|
+
|
|
626
|
+
def inner(*args):
|
|
627
|
+
meta_args = args[:len(_BASE_ARGS_FORMAT)]
|
|
628
|
+
raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
|
|
629
|
+
final_args = []
|
|
630
|
+
for arg in raw_kernel_args:
|
|
631
|
+
if isinstance(arg, TensorDescriptor):
|
|
632
|
+
# Currently the host side tensor descriptors get decomposed in
|
|
633
|
+
# the frontend to tensor desc, shape, and strides. We have no
|
|
634
|
+
# way to use these shape and strides when processing tensor
|
|
635
|
+
# descriptors which is why we provide our own decomposition
|
|
636
|
+
# above. Sadly this means we have to pass the shape and strides
|
|
637
|
+
# twice.
|
|
638
|
+
final_args.extend([arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides])
|
|
639
|
+
else:
|
|
640
|
+
final_args.append(arg)
|
|
641
|
+
return launcher(*meta_args, *final_args)
|
|
642
|
+
|
|
643
|
+
return inner
|
|
644
|
+
|
|
645
|
+
|
|
487
646
|
class HIPLauncher(object):
|
|
488
647
|
|
|
489
648
|
def __init__(self, src, metadata):
|
|
@@ -492,12 +651,28 @@ class HIPLauncher(object):
|
|
|
492
651
|
constants = {arg_idx(idx): value for idx, value in constants.items()}
|
|
493
652
|
signature = {idx: value for idx, value in src.signature.items()}
|
|
494
653
|
src = make_launcher(constants, signature, metadata.warp_size)
|
|
495
|
-
mod = compile_module_from_src(src, "__triton_launcher")
|
|
496
|
-
|
|
654
|
+
mod = compile_module_from_src(src=src, name="__triton_launcher", include_dirs=include_dirs)
|
|
655
|
+
has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
|
|
656
|
+
|
|
657
|
+
self.launch = wrap_handle_tensor_descriptor(mod.launch) if has_tensor_desc_arg else mod.launch
|
|
497
658
|
self.launch_cooperative_grid = metadata.launch_cooperative_grid
|
|
659
|
+
self.profile_scratch_size = metadata.profile_scratch_size
|
|
660
|
+
self.profile_scratch_align = metadata.profile_scratch_align
|
|
661
|
+
|
|
662
|
+
def __call__(self, gridX, gridY, gridZ, stream, function, *args):
|
|
663
|
+
|
|
664
|
+
def allocate_scratch(size, align, allocator):
|
|
665
|
+
if size > 0:
|
|
666
|
+
grid_size = gridX * gridY * gridZ
|
|
667
|
+
alloc_size = grid_size * size
|
|
668
|
+
alloc_fn = allocator.get()
|
|
669
|
+
return alloc_fn(alloc_size, align, stream)
|
|
670
|
+
return None
|
|
498
671
|
|
|
499
|
-
|
|
500
|
-
|
|
672
|
+
profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
|
|
673
|
+
_allocation._profile_allocator)
|
|
674
|
+
|
|
675
|
+
self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
|
|
501
676
|
|
|
502
677
|
|
|
503
678
|
class HIPDriver(GPUDriver):
|
|
@@ -515,14 +690,17 @@ class HIPDriver(GPUDriver):
|
|
|
515
690
|
def is_active():
|
|
516
691
|
try:
|
|
517
692
|
import torch
|
|
518
|
-
return torch.version.hip is not None
|
|
693
|
+
return torch.cuda.is_available() and (torch.version.hip is not None)
|
|
519
694
|
except ImportError:
|
|
520
695
|
return False
|
|
521
696
|
|
|
697
|
+
def map_python_to_cpp_type(self, ty: str) -> str:
|
|
698
|
+
return ty_to_cpp(ty)
|
|
699
|
+
|
|
522
700
|
def get_current_target(self):
|
|
523
701
|
device = self.get_current_device()
|
|
524
702
|
device_properties = self.utils.get_device_properties(device)
|
|
525
|
-
arch = device_properties['arch']
|
|
703
|
+
arch = knobs.runtime.override_arch or device_properties['arch']
|
|
526
704
|
warp_size = device_properties['warpSize']
|
|
527
705
|
return GPUTarget("hip", arch.split(':')[0], warp_size)
|
|
528
706
|
|
triton/backends/compiler.py
CHANGED
|
@@ -1,9 +1,6 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import re
|
|
3
|
-
import subprocess
|
|
4
|
-
import sysconfig
|
|
5
1
|
from abc import ABCMeta, abstractmethod
|
|
6
2
|
from dataclasses import dataclass
|
|
3
|
+
from enum import Enum
|
|
7
4
|
from typing import Dict, Union
|
|
8
5
|
from types import ModuleType
|
|
9
6
|
|
|
@@ -17,6 +14,12 @@ class GPUTarget(object):
|
|
|
17
14
|
warp_size: int
|
|
18
15
|
|
|
19
16
|
|
|
17
|
+
class Language(Enum):
|
|
18
|
+
"""The input language being compiled by the backend."""
|
|
19
|
+
TRITON = 0
|
|
20
|
+
GLUON = 1
|
|
21
|
+
|
|
22
|
+
|
|
20
23
|
class BaseBackend(metaclass=ABCMeta):
|
|
21
24
|
|
|
22
25
|
def __init__(self, target: GPUTarget) -> None:
|
|
@@ -24,23 +27,6 @@ class BaseBackend(metaclass=ABCMeta):
|
|
|
24
27
|
assert self.supports_target(target)
|
|
25
28
|
|
|
26
29
|
@staticmethod
|
|
27
|
-
def _path_to_binary(binary: str):
|
|
28
|
-
binary += sysconfig.get_config_var("EXE")
|
|
29
|
-
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
|
|
30
|
-
paths = [
|
|
31
|
-
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
|
|
32
|
-
os.path.join(base_dir, "third_party", "cuda", "bin", binary),
|
|
33
|
-
]
|
|
34
|
-
for path in paths:
|
|
35
|
-
if os.path.exists(path) and os.path.isfile(path):
|
|
36
|
-
result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
|
|
37
|
-
if result is not None:
|
|
38
|
-
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
|
|
39
|
-
if version is not None:
|
|
40
|
-
return path, version.group(1)
|
|
41
|
-
raise RuntimeError(f"Cannot find {binary}")
|
|
42
|
-
|
|
43
|
-
@classmethod
|
|
44
30
|
@abstractmethod
|
|
45
31
|
def supports_target(target: GPUTarget):
|
|
46
32
|
raise NotImplementedError
|
triton/backends/driver.py
CHANGED
|
@@ -15,6 +15,19 @@ class DriverBase(metaclass=ABCMeta):
|
|
|
15
15
|
def is_active(self):
|
|
16
16
|
pass
|
|
17
17
|
|
|
18
|
+
@abstractmethod
|
|
19
|
+
def map_python_to_cpp_type(self, ty: str) -> str:
|
|
20
|
+
"""
|
|
21
|
+
Converts a Triton type string to its corresponding C++ type string for this backend.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
str: The C++ type string.
|
|
28
|
+
"""
|
|
29
|
+
pass
|
|
30
|
+
|
|
18
31
|
@abstractmethod
|
|
19
32
|
def get_current_target(self):
|
|
20
33
|
pass
|
|
Binary file
|