triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,724 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import os
|
|
3
|
+
import subprocess
|
|
4
|
+
import re
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from triton import knobs
|
|
7
|
+
from triton.backends.compiler import GPUTarget
|
|
8
|
+
from triton.backends.driver import GPUDriver
|
|
9
|
+
from triton.runtime import _allocation
|
|
10
|
+
from triton.runtime.build import compile_module_from_src
|
|
11
|
+
from triton.tools.tensor_descriptor import TensorDescriptor
|
|
12
|
+
|
|
13
|
+
# Directory containing this driver module; used to locate the bundled
# driver.c glue source and its headers.
dirname = os.path.dirname(os.path.realpath(__file__))
# Header search path handed to the C compiler when building the glue module.
include_dirs = [os.path.join(dirname, "include")]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _find_already_mmapped_dylib_on_linux(lib_name):
    """Return the path of a shared library already mapped into this process.

    Walks the process's loaded shared objects via glibc's ``dl_iterate_phdr``
    and returns the path of the first one whose file name contains
    ``lib_name``. Returns None on non-Linux platforms, when libc cannot be
    loaded (e.g. musl systems — TODO confirm), or when no match is found.
    """
    import platform
    if platform.system() != 'Linux':
        return None

    # Use dl_iterate_phdr to walk through the list of shared libraries at runtime.
    # See https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html for details.

    import ctypes
    from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER

    # Partial mirror of the C `struct dl_phdr_info`; only the leading fields
    # we read are declared.
    class DlPhdrInfo(ctypes.Structure):
        _fields_ = [
            ('dlpi_addr', c_void_p),
            ('dlpi_name', c_char_p),
            # We don't care about the remaining fields.
        ]

    # callback_t must use POINTER(c_char) to avoid copying.
    callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char))

    # Load libc and get the dl_iterate_phdr symbol.
    try:
        dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr
    except Exception:
        # libc.so.6 not available or symbol missing; treat as "not found".
        return None
    # argtypes must use c_char_p to accept create_string_buffer.
    dl_iterate_phdr.argtypes = [callback_t, c_char_p]
    dl_iterate_phdr.restype = c_int

    max_path_length = 4096
    # create_string_buffer zero-fills, so the copied path below is always
    # NUL-terminated for string_at().
    path = ctypes.create_string_buffer(max_path_length + 1)

    # Define callback to get the loaded dylib path.
    def callback(info, size, data):
        dlpi_name = info.contents.dlpi_name
        p = Path(os.fsdecode(dlpi_name))
        if lib_name in p.name:
            # Found the dylib; get its path.
            ctypes.memmove(data, dlpi_name, min(max_path_length, len(dlpi_name)))
            # Returning non-zero stops the iteration.
            return 1
        return 0

    # dl_iterate_phdr returns the last callback return value: non-zero means
    # the callback matched and copied a path into `path`.
    if dl_iterate_phdr(callback_t(callback), path):
        return os.fsdecode(ctypes.string_at(path))
    return None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@functools.lru_cache()
def _get_path_to_hip_runtime_dylib():
    """Locate the HIP runtime dynamic library (libamdhip64.so).

    Search order:
      1. The ``knobs.amd.libhip_path`` (TRITON_LIBHIP_PATH) override.
      2. A copy already mmapped into this process (Linux only).
      3. The backend's bundled ``lib`` directory.
      4. PyTorch's bundled copy under site-packages.
      5. Directories listed in LD_LIBRARY_PATH.
      6. The HIP SDK root (HIP_PATH env var, then ``hipconfig --path``).
      7. The ROCm install root (ROCM_PATH env var).
      8. Paths known to the dynamic loader (``/sbin/ldconfig -p``).
      9. The conventional /opt/rocm/lib location.

    Returns:
        The first existing candidate path.

    Raises:
        RuntimeError: when an explicit override is invalid, or when no
            candidate exists (the message lists every attempted path).
    """
    lib_name = "libamdhip64.so"

    # If we are told explicitly what HIP runtime dynamic library to use, obey that.
    if env_libhip_path := knobs.amd.libhip_path:
        if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
            return env_libhip_path
        raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")

    # If the shared object is already mmapped to address space, use it.
    mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name)
    if mmapped_path:
        if os.path.exists(mmapped_path):
            return mmapped_path
        raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")

    # Every candidate we try but don't find, for the final error message.
    paths = []

    # Check backend
    local_lib = os.path.join(os.path.dirname(__file__), "lib", lib_name)
    if os.path.exists(local_lib):
        return local_lib
    paths.append(local_lib)

    import site
    # First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
    # that we run Triton together with PyTorch. This makes sure we use the same dynamic
    # library to avoid version mismatch.
    site_packages = site.getsitepackages()
    user_site = site.getusersitepackages()
    if site.ENABLE_USER_SITE:  # ENABLE_USER_SITE is initialized in getusersitepackages()
        site_packages = [user_site] + site_packages
    for path in site_packages:
        path = os.path.join(path, "torch", "lib", lib_name)
        if os.path.exists(path):
            return path
        paths.append(path)

    # Then try to see if developer provides a HIP runtime dynamic library using LD_LIBRARY_PATH.
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if env_ld_library_path:
        for d in env_ld_library_path.split(":"):
            f = os.path.join(d, lib_name)
            if os.path.exists(f):
                return f
            paths.append(f)

    # HIP_PATH should point to HIP SDK root if set
    env_hip_path = os.getenv("HIP_PATH")
    if env_hip_path:
        hip_lib_path = os.path.join(env_hip_path, "lib", lib_name)
        if os.path.exists(hip_lib_path):
            return hip_lib_path
        paths.append(hip_lib_path)

    # if available, `hipconfig --path` prints the HIP SDK root
    try:
        hip_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip()
        if hip_root:
            hip_lib_path = os.path.join(hip_root, "lib", lib_name)
            if os.path.exists(hip_lib_path):
                return hip_lib_path
            paths.append(hip_lib_path)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # hipconfig may not be available
        pass

    # ROCm lib dir based on env var
    env_rocm_path = os.getenv("ROCM_PATH")
    if env_rocm_path:
        rocm_lib_path = os.path.join(env_rocm_path, "lib", lib_name)
        if os.path.exists(rocm_lib_path):
            return rocm_lib_path
        paths.append(rocm_lib_path)

    # Afterwards try to search the loader dynamic library resolution paths.
    try:
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
        # each line looks like the following:
        # libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
        # libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
        locs = [line.split()[-1] for line in libs.splitlines() if line.strip().endswith(lib_name)]
        for loc in locs:
            if os.path.exists(loc):
                return loc
            paths.append(loc)
    except (subprocess.CalledProcessError, OSError):
        # /sbin/ldconfig may not exist (non-Linux systems, containers);
        # fall through so the informative RuntimeError below is raised
        # instead of an unhandled FileNotFoundError. Mirrors the hipconfig
        # handling above.
        pass

    # As a last resort, guess if we have it in some common installation path.
    common_install_path = os.path.join('/opt/rocm/lib/', lib_name)
    if os.path.exists(common_install_path):
        return common_install_path
    paths.append(common_install_path)

    raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class HIPUtils(object):
    """Singleton handle to the compiled ``hip_utils`` C extension.

    Exposes ``load_binary`` and ``get_device_properties`` from the glue
    module built out of the bundled driver.c source.
    """

    def __new__(cls):
        # Classic cached-attribute singleton: every instantiation yields
        # the same object.
        if not hasattr(cls, "instance"):
            cls.instance = super(HIPUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        hip_runtime_path = _get_path_to_hip_runtime_dylib()
        driver_src = Path(os.path.join(dirname, "driver.c")).read_text()
        # A single literal replace (count=1) is used rather than templates or
        # format strings, so the C code's curly brackets need no escaping and
        # exactly one substitution happens.
        driver_src = driver_src.replace('/*py_libhip_search_path*/', hip_runtime_path, 1)
        utils_mod = compile_module_from_src(src=driver_src, name="hip_utils", include_dirs=include_dirs)
        self.load_binary = utils_mod.load_binary
        self.get_device_properties = utils_mod.get_device_properties
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# -------------------- Launcher ----------------------------
|
|
181
|
+
def ty_to_cpp(ty):
    """Translate a Triton signature type string into the C type used by the
    generated launcher.

    Any pointer type (leading '*') maps to a generic HIP device pointer.
    All floating-point scalars are received from Python as ``double``; they
    are re-packed to their storage width later (see FLOAT_PACK_FUNCTION).
    Raises KeyError for unknown scalar types.
    """
    if ty[0] == '*':
        return "hipDeviceptr_t"
    scalar_to_cpp = {
        # signed integers
        "i1": "int8_t",
        "i8": "int8_t",
        "i16": "int16_t",
        "i32": "int32_t",
        "i64": "int64_t",
        # unsigned integers
        "u1": "uint8_t",
        "u8": "uint8_t",
        "u16": "uint16_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        # floats all arrive as double from the Python side
        "fp16": "double",
        "bf16": "double",
        "fp32": "double",
        "f32": "double",
        "fp64": "double",
    }
    return scalar_to_cpp[ty]
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# Integer storage type used to bit-pack each float kernel argument before it
# is handed to the kernel, keyed by Triton float type name.
FLOAT_STORAGE_TYPE = {
    "fp16": "uint16_t",
    "bf16": "uint16_t",
    "fp32": "uint32_t",
    "f32": "uint32_t",
    "fp64": "uint64_t",
}
# Name of the C helper (emitted into the generated launcher source) that
# packs a double into the matching storage integer above.
FLOAT_PACK_FUNCTION = {
    "fp16": "pack_fp16",
    "bf16": "pack_bf16",
    "fp32": "pack_fp32",
    "f32": "pack_fp32",
    "fp64": "pack_fp64",
}
# PyArg_ParseTuple format for the fixed leading arguments of launch()
# (cooperative-grid flag, grid X/Y/Z, stream, function handle, then Python
# object arguments such as hooks and metadata); per-kernel-argument formats
# are appended after this prefix.
_BASE_ARGS_FORMAT = "piiiKKOOOOO"
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def make_launcher(constants, signature, warp_size):
|
|
222
|
+
|
|
223
|
+
def _expand_signature(signature):
|
|
224
|
+
output = []
|
|
225
|
+
# Expand tensor descriptor arguments into base pointer, shape, and
|
|
226
|
+
# strides
|
|
227
|
+
for sig in signature:
|
|
228
|
+
if isinstance(sig, str) and sig.startswith("tensordesc"):
|
|
229
|
+
ndim = sig.count(",") + 1
|
|
230
|
+
dtype = re.match("tensordesc<([^[>]*)", sig).group()
|
|
231
|
+
|
|
232
|
+
output.append("*" + dtype)
|
|
233
|
+
for _ in range(2 * ndim):
|
|
234
|
+
output.append("i64")
|
|
235
|
+
output.append("i1")
|
|
236
|
+
# Currently the host side tensor descriptors get passed in as a
|
|
237
|
+
# tensor desc, shape, and strides. We have no way to use these
|
|
238
|
+
# shape and strides when processing tensor descriptors which is
|
|
239
|
+
# why we provide our own decomposition above. Sadly this means
|
|
240
|
+
# we have to pass the shape and strides twice.
|
|
241
|
+
for _ in range(ndim):
|
|
242
|
+
output.append("i32")
|
|
243
|
+
for _ in range(ndim):
|
|
244
|
+
output.append("i64")
|
|
245
|
+
else:
|
|
246
|
+
output.append(sig)
|
|
247
|
+
|
|
248
|
+
return output
|
|
249
|
+
|
|
250
|
+
def _serialize_signature(sig):
|
|
251
|
+
if isinstance(sig, tuple):
|
|
252
|
+
return ','.join(map(_serialize_signature, sig))
|
|
253
|
+
return sig
|
|
254
|
+
|
|
255
|
+
def _extracted_type(ty):
|
|
256
|
+
if isinstance(ty, tuple):
|
|
257
|
+
val = ','.join(map(_extracted_type, ty))
|
|
258
|
+
return f"[{val}]"
|
|
259
|
+
if ty[0] == '*':
|
|
260
|
+
return "PyObject*"
|
|
261
|
+
if ty == "constexpr":
|
|
262
|
+
return "PyObject*"
|
|
263
|
+
return ty_to_cpp(ty)
|
|
264
|
+
|
|
265
|
+
def format_of(ty):
|
|
266
|
+
if isinstance(ty, tuple):
|
|
267
|
+
val = ''.join(map(format_of, ty))
|
|
268
|
+
return f"({val})"
|
|
269
|
+
if ty[0] == '*':
|
|
270
|
+
return "O"
|
|
271
|
+
if ty == "constexpr":
|
|
272
|
+
return "O"
|
|
273
|
+
return {
|
|
274
|
+
"double": "d",
|
|
275
|
+
"long": "l",
|
|
276
|
+
"int8_t": "b",
|
|
277
|
+
"int16_t": "h",
|
|
278
|
+
"int32_t": "i",
|
|
279
|
+
"int64_t": "L",
|
|
280
|
+
"uint8_t": "B",
|
|
281
|
+
"uint16_t": "H",
|
|
282
|
+
"uint32_t": "I",
|
|
283
|
+
"uint64_t": "K",
|
|
284
|
+
}[ty_to_cpp(ty)]
|
|
285
|
+
|
|
286
|
+
signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
|
|
287
|
+
|
|
288
|
+
args_format = ''.join([format_of(ty) for ty in signature.values()])
|
|
289
|
+
format = _BASE_ARGS_FORMAT + args_format
|
|
290
|
+
signature = ','.join(map(_serialize_signature, signature.values()))
|
|
291
|
+
signature = list(filter(bool, signature.split(',')))
|
|
292
|
+
signature = {i: s for i, s in enumerate(signature)}
|
|
293
|
+
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
|
|
294
|
+
# Record the end of regular arguments;
|
|
295
|
+
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
|
|
296
|
+
arg_decl_list = []
|
|
297
|
+
for i, ty in signature.items():
|
|
298
|
+
if ty == "constexpr":
|
|
299
|
+
continue
|
|
300
|
+
if ty in FLOAT_STORAGE_TYPE:
|
|
301
|
+
arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
|
|
302
|
+
else:
|
|
303
|
+
arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
|
|
304
|
+
arg_decls = ', '.join(arg_decl_list)
|
|
305
|
+
internal_args_list = []
|
|
306
|
+
for i, ty in signature.items():
|
|
307
|
+
if ty[0] == "*":
|
|
308
|
+
internal_args_list.append(f"ptr_info{i}.dev_ptr")
|
|
309
|
+
elif ty in FLOAT_STORAGE_TYPE:
|
|
310
|
+
internal_args_list.append(f"_arg{i}_storage")
|
|
311
|
+
elif ty != "constexpr":
|
|
312
|
+
internal_args_list.append(f"_arg{i}")
|
|
313
|
+
|
|
314
|
+
float_storage_decls = [
|
|
315
|
+
f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
|
|
316
|
+
for i, ty in signature.items()
|
|
317
|
+
if ty in FLOAT_STORAGE_TYPE
|
|
318
|
+
]
|
|
319
|
+
|
|
320
|
+
libhip_path = _get_path_to_hip_runtime_dylib()
|
|
321
|
+
|
|
322
|
+
# generate glue code
|
|
323
|
+
params = list(range(len(signature)))
|
|
324
|
+
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
|
|
325
|
+
params.append("&global_scratch")
|
|
326
|
+
params.append("&profile_scratch")
|
|
327
|
+
src = f"""
|
|
328
|
+
#define __HIP_PLATFORM_AMD__
|
|
329
|
+
#include <hip/hip_runtime.h>
|
|
330
|
+
#include <hip/hip_runtime_api.h>
|
|
331
|
+
#include <Python.h>
|
|
332
|
+
#include <dlfcn.h>
|
|
333
|
+
#include <stdbool.h>
|
|
334
|
+
#include <dlfcn.h>
|
|
335
|
+
|
|
336
|
+
// The list of paths to search for the HIP runtime library. The caller Python
|
|
337
|
+
// code should substitute the search path placeholder.
|
|
338
|
+
static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};
|
|
339
|
+
|
|
340
|
+
// The list of HIP dynamic library symbols and their signature we are interested
|
|
341
|
+
// in this file.
|
|
342
|
+
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\
|
|
343
|
+
FOR_EACH_STR_FN(hipGetLastError) \\
|
|
344
|
+
FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \\
|
|
345
|
+
FOR_EACH_ERR_FN(hipModuleLaunchKernel, hipFunction_t f, \\
|
|
346
|
+
unsigned int gridDimX, unsigned int gridDimY, \\
|
|
347
|
+
unsigned int gridDimZ, unsigned int blockDimX, \\
|
|
348
|
+
unsigned int blockDimY, unsigned int blockDimZ, \\
|
|
349
|
+
unsigned int sharedMemBytes, hipStream_t stream, \\
|
|
350
|
+
void **kernelParams, void **extra) \\
|
|
351
|
+
FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, hipFunction_t f, \\
|
|
352
|
+
unsigned int gridDimX, unsigned int gridDimY, \\
|
|
353
|
+
unsigned int gridDimZ, unsigned int blockDimX, \\
|
|
354
|
+
unsigned int blockDimY, unsigned int blockDimZ, \\
|
|
355
|
+
unsigned int sharedMemBytes, hipStream_t stream, \\
|
|
356
|
+
void **kernelParams, void **extra) \\
|
|
357
|
+
FOR_EACH_ERR_FN(hipPointerGetAttribute, void *data, \\
|
|
358
|
+
hipPointer_attribute attribute, hipDeviceptr_t ptr)
|
|
359
|
+
|
|
360
|
+
// The HIP symbol table for holding resolved dynamic library symbols.
|
|
361
|
+
struct HIPSymbolTable {{
|
|
362
|
+
#define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \\
|
|
363
|
+
hipError_t (*hipSymbolName)(__VA_ARGS__);
|
|
364
|
+
#define DEFINE_EACH_STR_FIELD(hipSymbolName, ...) \\
|
|
365
|
+
const char *(*hipSymbolName)(__VA_ARGS__);
|
|
366
|
+
|
|
367
|
+
HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
|
|
368
|
+
}};
|
|
369
|
+
|
|
370
|
+
static struct HIPSymbolTable hipSymbolTable;
|
|
371
|
+
|
|
372
|
+
bool initSymbolTable() {{
|
|
373
|
+
// Use the HIP runtime library loaded into the existing process if it exits.
|
|
374
|
+
void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
|
|
375
|
+
|
|
376
|
+
// Otherwise, go through the list of search paths to dlopen the first HIP
|
|
377
|
+
// driver library.
|
|
378
|
+
if (!lib) {{
|
|
379
|
+
int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
|
|
380
|
+
for (int i = 0; i < n; ++i) {{
|
|
381
|
+
void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
|
|
382
|
+
if (handle) {{
|
|
383
|
+
lib = handle;
|
|
384
|
+
}}
|
|
385
|
+
}}
|
|
386
|
+
}}
|
|
387
|
+
if (!lib) {{
|
|
388
|
+
PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
|
|
389
|
+
return false;
|
|
390
|
+
}}
|
|
391
|
+
|
|
392
|
+
typedef hipError_t (*hipGetProcAddress_fn)(
|
|
393
|
+
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
|
|
394
|
+
hipDriverProcAddressQueryResult *symbolStatus);
|
|
395
|
+
hipGetProcAddress_fn hipGetProcAddress;
|
|
396
|
+
dlerror(); // Clear existing errors
|
|
397
|
+
const char *error = NULL;
|
|
398
|
+
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
|
|
399
|
+
error = dlerror();
|
|
400
|
+
if (error) {{
|
|
401
|
+
PyErr_SetString(PyExc_RuntimeError,
|
|
402
|
+
"cannot query 'hipGetProcAddress' from libamdhip64.so");
|
|
403
|
+
dlclose(lib);
|
|
404
|
+
return false;
|
|
405
|
+
}}
|
|
406
|
+
|
|
407
|
+
// Resolve all symbols we are interested in.
|
|
408
|
+
int hipVersion = HIP_VERSION;
|
|
409
|
+
uint64_t hipFlags = 0;
|
|
410
|
+
hipDriverProcAddressQueryResult symbolStatus;
|
|
411
|
+
hipError_t status = hipSuccess;
|
|
412
|
+
#define QUERY_EACH_FN(hipSymbolName, ...) \
|
|
413
|
+
status = hipGetProcAddress(#hipSymbolName, \
|
|
414
|
+
(void **)&hipSymbolTable.hipSymbolName, \
|
|
415
|
+
hipVersion, hipFlags, &symbolStatus); \
|
|
416
|
+
if (status != hipSuccess) {{ \
|
|
417
|
+
PyErr_SetString(PyExc_RuntimeError, \
|
|
418
|
+
"cannot get address for '" #hipSymbolName \
|
|
419
|
+
"' from libamdhip64.so"); \
|
|
420
|
+
dlclose(lib); \
|
|
421
|
+
return false; \
|
|
422
|
+
}}
|
|
423
|
+
|
|
424
|
+
HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
|
|
425
|
+
|
|
426
|
+
return true;
|
|
427
|
+
}}
|
|
428
|
+
|
|
429
|
+
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
|
430
|
+
{{
|
|
431
|
+
if (code != HIP_SUCCESS)
|
|
432
|
+
{{
|
|
433
|
+
const char* prefix = "Triton Error [HIP]: ";
|
|
434
|
+
const char* str = hipSymbolTable.hipGetErrorString(code);
|
|
435
|
+
char err[1024] = {{0}};
|
|
436
|
+
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
|
437
|
+
PyErr_SetString(PyExc_RuntimeError, err);
|
|
438
|
+
}}
|
|
439
|
+
}}
|
|
440
|
+
|
|
441
|
+
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
|
442
|
+
|
|
443
|
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
|
444
|
+
hipDeviceptr_t global_scratch = 0;
|
|
445
|
+
void *params[] = {{ {', '.join(params)} }};
|
|
446
|
+
if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
|
|
447
|
+
HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
|
|
448
|
+
return;
|
|
449
|
+
}}
|
|
450
|
+
if (gridX*gridY*gridZ > 0) {{
|
|
451
|
+
HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
|
|
452
|
+
}}
|
|
453
|
+
}}
|
|
454
|
+
|
|
455
|
+
typedef struct _DevicePtrInfo {{
|
|
456
|
+
hipDeviceptr_t dev_ptr;
|
|
457
|
+
bool valid;
|
|
458
|
+
}} DevicePtrInfo;
|
|
459
|
+
|
|
460
|
+
static PyObject* data_ptr_str = NULL;
|
|
461
|
+
|
|
462
|
+
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
|
463
|
+
DevicePtrInfo ptr_info;
|
|
464
|
+
hipError_t status = hipSuccess;
|
|
465
|
+
ptr_info.dev_ptr = 0;
|
|
466
|
+
ptr_info.valid = true;
|
|
467
|
+
if (PyLong_Check(obj)) {{
|
|
468
|
+
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
|
|
469
|
+
return ptr_info;
|
|
470
|
+
}}
|
|
471
|
+
if (obj == Py_None) {{
|
|
472
|
+
// valid nullptr
|
|
473
|
+
return ptr_info;
|
|
474
|
+
}}
|
|
475
|
+
PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
|
|
476
|
+
if (!ret) {{
|
|
477
|
+
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
|
478
|
+
ptr_info.valid = false;
|
|
479
|
+
goto cleanup;
|
|
480
|
+
}}
|
|
481
|
+
if (!PyLong_Check(ret)) {{
|
|
482
|
+
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
|
483
|
+
ptr_info.valid = false;
|
|
484
|
+
goto cleanup;
|
|
485
|
+
}}
|
|
486
|
+
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
|
487
|
+
if (!ptr_info.dev_ptr)
|
|
488
|
+
goto cleanup;
|
|
489
|
+
uint64_t dev_ptr;
|
|
490
|
+
status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
|
491
|
+
if (status == hipErrorInvalidValue) {{
|
|
492
|
+
PyErr_Format(PyExc_ValueError,
|
|
493
|
+
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
|
494
|
+
ptr_info.valid = false;
|
|
495
|
+
// Clear and ignore HIP error
|
|
496
|
+
(void)hipSymbolTable.hipGetLastError();
|
|
497
|
+
}}
|
|
498
|
+
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
|
499
|
+
cleanup:
|
|
500
|
+
Py_DECREF(ret);
|
|
501
|
+
return ptr_info;
|
|
502
|
+
}}
|
|
503
|
+
|
|
504
|
+
static uint16_t pack_fp16(double f) {{
|
|
505
|
+
uint16_t result;
|
|
506
|
+
// from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
|
|
507
|
+
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
|
|
508
|
+
_PyFloat_Pack2(f, (unsigned char*)&result, 1);
|
|
509
|
+
#else
|
|
510
|
+
PyFloat_Pack2(f, (char*)&result, 1);
|
|
511
|
+
#endif
|
|
512
|
+
return result;
|
|
513
|
+
}}
|
|
514
|
+
|
|
515
|
+
static uint16_t pack_bf16(double f) {{
|
|
516
|
+
float f32 = (float)f;
|
|
517
|
+
uint32_t u32 = *(uint32_t*)&f32;
|
|
518
|
+
return (uint16_t)(u32 >> 16);
|
|
519
|
+
}}
|
|
520
|
+
|
|
521
|
+
static uint32_t pack_fp32(double f) {{
|
|
522
|
+
float f32 = (float)f;
|
|
523
|
+
return *(uint32_t*)&f32;
|
|
524
|
+
}}
|
|
525
|
+
|
|
526
|
+
static uint64_t pack_fp64(double f) {{
|
|
527
|
+
return *(uint64_t*)&f;
|
|
528
|
+
}}
|
|
529
|
+
|
|
530
|
+
static PyObject* launch(PyObject* self, PyObject* args) {{
|
|
531
|
+
int gridX, gridY, gridZ;
|
|
532
|
+
uint64_t _stream;
|
|
533
|
+
uint64_t _function;
|
|
534
|
+
int launch_cooperative_grid;
|
|
535
|
+
PyObject *profile_scratch_obj = NULL;
|
|
536
|
+
PyObject *launch_enter_hook = NULL;
|
|
537
|
+
PyObject *launch_exit_hook = NULL;
|
|
538
|
+
PyObject *kernel_metadata = NULL;
|
|
539
|
+
PyObject *launch_metadata = NULL;
|
|
540
|
+
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
|
541
|
+
if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
|
|
542
|
+
&gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
|
|
543
|
+
&kernel_metadata, &launch_metadata,
|
|
544
|
+
&launch_enter_hook, &launch_exit_hook {args_list})) {{
|
|
545
|
+
return NULL;
|
|
546
|
+
}}
|
|
547
|
+
|
|
548
|
+
{' '.join(float_storage_decls)}
|
|
549
|
+
|
|
550
|
+
// extract kernel metadata
|
|
551
|
+
int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
|
|
552
|
+
if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
|
|
553
|
+
return NULL;
|
|
554
|
+
}}
|
|
555
|
+
// extract launch metadata
|
|
556
|
+
if (launch_enter_hook != Py_None){{
|
|
557
|
+
PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
|
|
558
|
+
if (!ret)
|
|
559
|
+
return NULL;
|
|
560
|
+
Py_DECREF(ret);
|
|
561
|
+
}}
|
|
562
|
+
|
|
563
|
+
hipDeviceptr_t profile_scratch = 0;
|
|
564
|
+
if (profile_scratch_obj != Py_None) {{
|
|
565
|
+
DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
|
|
566
|
+
if (!profile_scratch_info.valid) {{
|
|
567
|
+
return NULL;
|
|
568
|
+
}}
|
|
569
|
+
profile_scratch = profile_scratch_info.dev_ptr;
|
|
570
|
+
}}
|
|
571
|
+
|
|
572
|
+
// raise exception asap
|
|
573
|
+
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
|
574
|
+
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
|
|
575
|
+
|
|
576
|
+
if(launch_exit_hook != Py_None){{
|
|
577
|
+
PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
|
|
578
|
+
if (!ret)
|
|
579
|
+
return NULL;
|
|
580
|
+
Py_DECREF(ret);
|
|
581
|
+
}}
|
|
582
|
+
|
|
583
|
+
if(PyErr_Occurred()) {{
|
|
584
|
+
return NULL;
|
|
585
|
+
}}
|
|
586
|
+
Py_RETURN_NONE;
|
|
587
|
+
}}
|
|
588
|
+
|
|
589
|
+
static PyMethodDef ModuleMethods[] = {{
|
|
590
|
+
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
|
591
|
+
{{NULL, NULL, 0, NULL}} // sentinel
|
|
592
|
+
}};
|
|
593
|
+
|
|
594
|
+
static struct PyModuleDef ModuleDef = {{
|
|
595
|
+
PyModuleDef_HEAD_INIT,
|
|
596
|
+
\"__triton_launcher\",
|
|
597
|
+
NULL, //documentation
|
|
598
|
+
-1, //size
|
|
599
|
+
ModuleMethods
|
|
600
|
+
}};
|
|
601
|
+
|
|
602
|
+
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
|
603
|
+
if (!initSymbolTable()) {{
|
|
604
|
+
return NULL;
|
|
605
|
+
}}
|
|
606
|
+
PyObject *m = PyModule_Create(&ModuleDef);
|
|
607
|
+
if(m == NULL) {{
|
|
608
|
+
return NULL;
|
|
609
|
+
}}
|
|
610
|
+
data_ptr_str = PyUnicode_InternFromString("data_ptr");
|
|
611
|
+
if(data_ptr_str == NULL) {{
|
|
612
|
+
return NULL;
|
|
613
|
+
}}
|
|
614
|
+
PyModule_AddFunctions(m, ModuleMethods);
|
|
615
|
+
return m;
|
|
616
|
+
}}
|
|
617
|
+
"""
|
|
618
|
+
return src
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
def wrap_handle_tensor_descriptor(launcher):
    """
    Replace all tensor descriptors with the base ptr, shape, and strides
    """

    n_meta = len(_BASE_ARGS_FORMAT)

    def inner(*args):
        meta_args = args[:n_meta]
        expanded = []
        for arg in args[n_meta:]:
            if not isinstance(arg, TensorDescriptor):
                expanded.append(arg)
                continue
            # Currently the host side tensor descriptors get decomposed in
            # the frontend to tensor desc, shape, and strides. We have no
            # way to use these shape and strides when processing tensor
            # descriptors which is why we provide our own decomposition
            # above. Sadly this means we have to pass the shape and strides
            # twice.
            expanded.append(arg.base)
            expanded.extend(arg.shape)
            expanded.extend(arg.strides)
            expanded.append(arg.padding == "nan")
            expanded.extend(arg.shape)
            expanded.extend(arg.strides)
        return launcher(*meta_args, *expanded)

    return inner
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
class HIPLauncher(object):
    """Builds and wraps the native C launcher for one HIP kernel signature."""

    def __init__(self, src, metadata):
        # ``constants`` may be keyed either by argument name or by an index
        # tuple; normalize every key to an index tuple.
        def normalize(key):
            return (src.fn.arg_names.index(key), ) if isinstance(key, str) else key

        constants = {normalize(key): val for key, val in getattr(src, "constants", dict()).items()}
        signature = dict(src.signature)
        launcher_src = make_launcher(constants, signature, metadata.warp_size)
        mod = compile_module_from_src(src=launcher_src, name="__triton_launcher", include_dirs=include_dirs)

        # Tensor-descriptor arguments must be decomposed on the Python side
        # before they reach the C entry point.
        needs_desc_wrapper = any(
            isinstance(val, str) and val.startswith("tensordesc") for val in signature.values())
        self.launch = wrap_handle_tensor_descriptor(mod.launch) if needs_desc_wrapper else mod.launch
        self.launch_cooperative_grid = metadata.launch_cooperative_grid
        self.profile_scratch_size = metadata.profile_scratch_size
        self.profile_scratch_align = metadata.profile_scratch_align

    def __call__(self, gridX, gridY, gridZ, stream, function, *args):

        def alloc_scratch(per_block, align, allocator):
            # One ``per_block``-sized slot per launched block; nothing when
            # the kernel requested no scratch space.
            if per_block <= 0:
                return None
            allocate = allocator.get()
            return allocate(gridX * gridY * gridZ * per_block, align, stream)

        profile_scratch = alloc_scratch(self.profile_scratch_size, self.profile_scratch_align,
                                        _allocation._profile_allocator)

        self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
class HIPDriver(GPUDriver):
    """Triton driver for AMD GPUs exposed through the HIP runtime."""

    def __init__(self):
        super().__init__()
        self.utils = HIPUtils()
        self.launcher_cls = HIPLauncher

    def get_device_interface(self):
        import torch
        return torch.cuda

    @staticmethod
    def is_active():
        # Active only when torch is importable and was built against HIP.
        try:
            import torch
            return torch.cuda.is_available() and (torch.version.hip is not None)
        except ImportError:
            return False

    def map_python_to_cpp_type(self, ty: str) -> str:
        return ty_to_cpp(ty)

    def get_current_target(self):
        props = self.utils.get_device_properties(self.get_current_device())
        arch = knobs.runtime.override_arch or props['arch']
        # Drop target-feature suffixes (anything after ':') from the arch name.
        return GPUTarget("hip", arch.split(':')[0], props['warpSize'])

    def get_active_torch_device(self):
        import torch
        # when using hip devices, the device string in pytorch is "cuda"
        return torch.device("cuda", self.get_current_device())

    def get_benchmarker(self):
        from triton.testing import do_bench
        return do_bench

    def get_empty_cache_for_benchmark(self):
        import torch

        # 256 MiB of int32 — the same cache-flush buffer the Nvidia backend uses.
        num_ints = 256 * 1024 * 1024 // 4
        return torch.empty(int(num_ints), dtype=torch.int, device='cuda')

    def clear_cache(self, cache):
        cache.zero_()
|