triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
triton/backends/nvidia/driver.py
@@ -0,0 +1,799 @@
import functools
import operator
import os
import subprocess
import triton
import re
from pathlib import Path
from triton import knobs
from triton.runtime.build import compile_module_from_src
from triton.runtime import _allocation
from triton.backends.compiler import GPUTarget
from triton.backends.driver import GPUDriver

dirname = os.path.dirname(os.path.realpath(__file__))
include_dirs = [os.path.join(dirname, "include")]
if os.name == "nt":
    from triton.windows_utils import find_cuda
    _, cuda_inc_dirs, _ = find_cuda()
    include_dirs += cuda_inc_dirs
libdevice_dir = os.path.join(dirname, "lib")
libraries = ['cuda']
PyCUtensorMap = None


@functools.lru_cache()
def libcuda_dirs():
    if env_libcuda_path := knobs.nvidia.libcuda_path:
        return [env_libcuda_path]

    if os.name == "nt":
        _, _, cuda_lib_dirs = find_cuda()
        return cuda_lib_dirs

    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
    dirs = [os.path.dirname(loc) for loc in locs]
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if env_ld_library_path and not dirs:
        dirs = [dir for dir in env_ld_library_path.split(":") if os.path.exists(os.path.join(dir, "libcuda.so.1"))]
    msg = 'libcuda.so cannot be found!\n'
    if locs:
        msg += 'Possible files are located at %s. ' % str(locs)
        msg += 'Please create a symlink of libcuda.so to any of the files.'
    else:
        msg += 'Please make sure a GPU is set up and then run "/sbin/ldconfig"'
        msg += ' (requires sudo) to refresh the linker cache.'
    assert any(os.path.exists(os.path.join(path, 'libcuda.so.1')) for path in dirs), msg
    return dirs
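
# Note on the parsing above (illustrative): for the sample ldconfig line shown
# in the comment, line.split()[-1] yields "/lib/x86_64-linux-gnu/libcuda.so.1",
# and os.path.dirname() reduces it to "/lib/x86_64-linux-gnu", which is the
# directory handed to the compiler as a library search path.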


@functools.lru_cache()
def library_dirs():
    return [libdevice_dir, *libcuda_dirs()]


# ------------------------
# Utils
# ------------------------


class CudaUtils(object):

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(CudaUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        mod = compile_module_from_src(
            src=Path(os.path.join(dirname, "driver.c")).read_text(),
            name="cuda_utils",
            library_dirs=library_dirs(),
            include_dirs=include_dirs,
            libraries=libraries,
        )
        global PyCUtensorMap
        PyCUtensorMap = mod.PyCUtensorMap
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
        self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
        self.set_printf_fifo_size = mod.set_printf_fifo_size
        self.fill_tma_descriptor = mod.fill_tma_descriptor


# ------------------------
# Launcher
# ------------------------


def ty_to_cpp(ty):
    if ty[0] == '*':
        return "CUdeviceptr"
    if ty.startswith("tensordesc"):
        return "CUtensorMap"
    return {
        "i1": "int8_t",
        "i8": "int8_t",
        "i16": "int16_t",
        "i32": "int32_t",
        "i64": "int64_t",
        "u1": "uint8_t",
        "u8": "uint8_t",
        "u16": "uint16_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        "fp16": "double",
        "bf16": "double",
        "fp32": "double",
        "f32": "double",
        "fp64": "double",
        "nvTmaDesc": "CUtensorMap",
    }[ty]


FLOAT_STORAGE_TYPE = {
    "fp16": "uint16_t",
    "bf16": "uint16_t",
    "fp32": "uint32_t",
    "f32": "uint32_t",
    "fp64": "uint64_t",
}
FLOAT_PACK_FUNCTION = {
    "fp16": "pack_fp16",
    "bf16": "pack_bf16",
    "fp32": "pack_fp32",
    "f32": "pack_fp32",
    "fp64": "pack_fp64",
}

_BASE_ARGS_FORMAT = "iiiKKppOOOOOO"
_BASE_ARGS_FORMAT_LEN = len(_BASE_ARGS_FORMAT)
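
# Worked example (illustrative): for a flattened signature
# {0: "*fp32", 1: "i32", 2: "fp32", 3: "constexpr"}, format_of (defined inside
# make_launcher below) yields "O", "i", "d", "O", so PyArg_ParseTuple is driven
# by _BASE_ARGS_FORMAT + "OidO": three grid ints, the stream and function
# handles, two launch flags, six PyObject* slots, then the kernel arguments.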


def make_launcher(constants, signature, tensordesc_meta):

    def _expand_signature(signature):
        output = []
        tensordesc_idx = 0
        # Expand tensor descriptor arguments into either nvTmaDesc, shape and
        # strides, or base pointer, shape and strides, depending on whether the
        # kernel was lowered to use an nvTmaDesc or not.
        for sig in signature:
            if isinstance(sig, str) and sig.startswith("tensordesc"):
                meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
                tensordesc_idx += 1

                match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
                dtype = match.group(1)
                shape = match.group(2)
                ndim = shape.count(",") + 1

                if meta is None:
                    output.append("*" + dtype)
                    # Currently the host side tensor descriptors get passed in as a
                    # tensor desc, shape, and strides. We have no way to use these
                    # shape and strides when processing tensor descriptors, which is
                    # why we provide our own decomposition above. Sadly this means
                    # we have to pass the shape and strides twice.
                    for _ in range(2 * ndim):
                        output.append("i64")
                    output.append("i1")
                else:
                    output.append("nvTmaDesc")

                    for _ in range(ndim):
                        output.append("i32")
                    for _ in range(ndim):
                        output.append("i64")
            else:
                output.append(sig)

        assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
        return output
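
    # Example (illustrative): "tensordesc<fp16[64,64]>" has ndim == 2, so with
    # no metadata it expands to ["*fp16", "i64", "i64", "i64", "i64", "i1"]
    # (base pointer, shape, strides, NaN-padding flag); with TMA metadata it
    # expands to ["nvTmaDesc", "i32", "i32", "i64", "i64"] instead.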

    def _flatten_signature(sig, output):
        # Flatten tuples
        if isinstance(sig, tuple):
            for x in sig:
                _flatten_signature(x, output)
        else:
            output.append(sig)

    def _extracted_type(ty):
        if isinstance(ty, tuple):
            val = ','.join(map(_extracted_type, ty))
            return f"[{val}]"
        if ty[0] == '*':
            return "PyObject*"
        if ty in ("constexpr", "nvTmaDesc"):
            return "PyObject*"
        return ty_to_cpp(ty)

    def format_of(ty):
        if isinstance(ty, tuple):
            val = ''.join(map(format_of, ty))
            return f"({val})"
        if ty[0] == '*':
            return "O"
        if ty in ("constexpr", "nvTmaDesc"):
            return "O"
        if ty.startswith("tensordesc"):
            return "O"
        return {
            "double": "d",
            "long": "l",
            "int8_t": "b",
            "int16_t": "h",
            "int32_t": "i",
            "int64_t": "L",
            "uint8_t": "B",
            "uint16_t": "H",
            "uint32_t": "I",
            "uint64_t": "K",
        }[ty_to_cpp(ty)]

    expand_signature = _expand_signature(signature.values())
    signature = {i: s for i, s in enumerate(expand_signature)}

    args_format = ''.join([format_of(ty) for ty in signature.values()])
    format = _BASE_ARGS_FORMAT + args_format

    flat_signature = []
    for sig in signature.values():
        _flatten_signature(sig, flat_signature)
    signature = {i: s for i, s in enumerate(flat_signature)}
    args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
    # Record the end of regular arguments;
    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
    arg_decl_list = []
    for i, ty in signature.items():
        if ty == "constexpr":
            continue
        if ty in FLOAT_STORAGE_TYPE:
            arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
        else:
            arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
    arg_decls = ', '.join(arg_decl_list)
    internal_args_list = []
    for i, ty in signature.items():
        if ty[0] == "*":
            internal_args_list.append(f"ptr_info{i}.dev_ptr")
        elif ty in FLOAT_STORAGE_TYPE:
            internal_args_list.append(f"_arg{i}_storage")
        elif ty == "nvTmaDesc":
            # Note: we have to dereference the pointer
            internal_args_list.append(f"*tma_ptr{i}")
        elif ty != "constexpr":
            internal_args_list.append(f"_arg{i}")

    # generate glue code
    newline = '\n  '
    ptr_decls = [
        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
        for i, ty in signature.items()
        if ty[0] == "*"
    ]
    tma_decls = [
        f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
        if ty == "nvTmaDesc"
    ]
    float_storage_decls = [
        f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
        for i, ty in signature.items()
        if ty in FLOAT_STORAGE_TYPE
    ]
    params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
    params.append("&global_scratch")
    params.append("&profile_scratch")
    src = f"""
#define _CRT_SECURE_NO_WARNINGS
#include \"cuda.h\"

#ifndef _WIN32
#include <dlfcn.h>
#else
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif

#include <stdbool.h>
#include <stdlib.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>

typedef struct {{
  PyObject_HEAD
  _Alignas(128) CUtensorMap tensorMap;
}} PyCUtensorMapObject;

static inline void gpuAssert(CUresult code, const char *file, int line)
{{
  if (code != CUDA_SUCCESS)
  {{
    const char* prefix = "Triton Error [CUDA]: ";
    const char* str;
    cuGetErrorString(code, &str);
    char err[1024] = {{0}};
    strcat(err, prefix);
    strcat(err, str);
    PyGILState_STATE gil_state;
    gil_state = PyGILState_Ensure();
    PyErr_SetString(PyExc_RuntimeError, err);
    PyGILState_Release(gil_state);
  }}
}}

#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);

#ifndef _WIN32
static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  // Open the shared library
  void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
  if (!handle) {{
    PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1");
    return NULL;
  }}
  // Clear any existing error
  dlerror();
  cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
  // Check for errors
  const char *dlsym_error = dlerror();
  if (dlsym_error) {{
    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
    return NULL;
  }}
  return cuLaunchKernelExHandle;
}}
#else
static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  // Open the shared library
  HMODULE handle = LoadLibraryA("nvcuda.dll");
  if (!handle) {{
    PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");
    return NULL;
  }}
  cuLaunchKernelEx_t cuLaunchKernelExHandle =
      (cuLaunchKernelEx_t)GetProcAddress((HMODULE)handle, "cuLaunchKernelEx");
  // Check for errors
  long error = GetLastError();
  if (error) {{
    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from nvcuda.dll");
    return NULL;
  }}
  return cuLaunchKernelExHandle;
}}
#endif

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
  void *params[] = {{ {', '.join(params)} }};
  if (gridX*gridY*gridZ > 0) {{
    // we can currently pass at most 4 attributes
    CUlaunchAttribute launchAttr[4];
    static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
    if (cuLaunchKernelExHandle == NULL) {{
      cuLaunchKernelExHandle = getLaunchKernelExHandle();
    }}
    CUlaunchConfig config;
    config.gridDimX = gridX;
    config.gridDimY = gridY;
    config.gridDimZ = gridZ;

    if (num_ctas != 1) {{
      config.gridDimX *= clusterDimX;
      config.gridDimY *= clusterDimY;
      config.gridDimZ *= clusterDimZ;
    }}

    config.blockDimX = 32 * num_warps;
    config.blockDimY = 1;
    config.blockDimZ = 1;
    config.sharedMemBytes = shared_memory;
    config.hStream = stream;
    config.attrs = launchAttr;
    int num_attrs = 0;

    if (launch_pdl != 0) {{
      CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1 }};
      launchAttr[num_attrs] = pdlAttr;
      ++num_attrs;
    }}

    if (launch_cooperative_grid != 0) {{
      CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1 }};
      launchAttr[num_attrs] = coopAttr;
      ++num_attrs;
    }}

    if (num_ctas != 1) {{
      CUlaunchAttribute clusterAttr = {{}};
      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      clusterAttr.value.clusterDim.x = clusterDimX;
      clusterAttr.value.clusterDim.y = clusterDimY;
      clusterAttr.value.clusterDim.z = clusterDimZ;
      launchAttr[num_attrs] = clusterAttr;
      ++num_attrs;

      CUlaunchAttribute clusterSchedulingAttr = {{}};
      clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
      launchAttr[num_attrs] = clusterSchedulingAttr;
      ++num_attrs;
    }}

    config.numAttrs = num_attrs;

    CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
  }}
}}

typedef struct _DevicePtrInfo {{
  CUdeviceptr dev_ptr;
  bool valid;
}} DevicePtrInfo;

static PyObject* data_ptr_str = NULL;
static PyObject* py_tensor_map_type = NULL;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {{
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }}
  if (obj == Py_None) {{
    // valid nullptr
    return ptr_info;
  }}
  PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
  if (!ret) {{
    PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
    ptr_info.valid = false;
    goto cleanup;
  }}
  if (!PyLong_Check(ret)) {{
    PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
    ptr_info.valid = false;
    goto cleanup;
  }}
  ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
  if (!ptr_info.dev_ptr)
    goto cleanup;
  uint64_t dev_ptr;
  int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
  if (status == CUDA_ERROR_INVALID_VALUE) {{
    PyErr_Format(PyExc_ValueError,
                 "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
    ptr_info.valid = false;
  }} else if (status != CUDA_SUCCESS) {{
    CUDA_CHECK(status);  // Catch any other cuda API errors
    ptr_info.valid = false;
  }}
  ptr_info.dev_ptr = dev_ptr;
cleanup:
  Py_XDECREF(ret);
  return ptr_info;
}}

static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
  if (sizeof(CUtensorMap*) != 8) {{
    PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation");
    return NULL;
  }}

  if (Py_TYPE(obj) != (PyTypeObject*)py_tensor_map_type) {{
    PyErr_Format(PyExc_TypeError, "object must be of type PyCUtensorMap, got %s", Py_TYPE(obj)->tp_name);
    return NULL;
  }}

  CUtensorMap* map = &((PyCUtensorMapObject*)obj)->tensorMap;
  uintptr_t align_128 = (uintptr_t)map & (128 - 1);
  if (align_128 != 0) {{
    PyErr_Format(PyExc_ValueError, "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %ld", align_128);
    return NULL;
  }}
  return map;
}}

static void ensureCudaContext() {{
  CUcontext pctx;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {{
    // Ensure device context.
    CUdevice device;
    CUDA_CHECK(cuDeviceGet(&device, 0));
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }}
}}

static uint16_t pack_fp16(double f) {{
  uint16_t result;
  // from https://github.com/python/pythoncapi-compat
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
  _PyFloat_Pack2(f, (unsigned char*)&result, 1);
#else
  PyFloat_Pack2(f, (unsigned char*)&result, 1);
#endif
  return result;
}}

static uint16_t pack_bf16(double f) {{
  float f32 = (float)f;
  uint32_t u32 = *(uint32_t*)&f32;
  return (uint16_t)(u32 >> 16);
}}

static uint32_t pack_fp32(double f) {{
  float f32 = (float)f;
  return *(uint32_t*)&f32;
}}

static uint64_t pack_fp64(double f) {{
  return *(uint64_t*)&f;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  // ensure the CUDA context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttribute
  ensureCudaContext();

  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int launch_cooperative_grid;
  int launch_pdl;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  PyObject *global_scratch_obj = NULL;
  PyObject *profile_scratch_obj = NULL;
  {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
  if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
                        &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj, &profile_scratch_obj,
                        &kernel_metadata, &launch_metadata,
                        &launch_enter_hook, &launch_exit_hook{args_list})) {{
    return NULL;
  }}

  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
    return NULL;
  }}

  // extract launch metadata
  if (launch_enter_hook != Py_None) {{
    PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  CUdeviceptr global_scratch = 0;
  if (global_scratch_obj != Py_None) {{
    DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1);
    if (!global_scratch_info.valid) {{
      return NULL;
    }}
    global_scratch = global_scratch_info.dev_ptr;
  }}

  CUdeviceptr profile_scratch = 0;
  if (profile_scratch_obj != Py_None) {{
    DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
    if (!profile_scratch_info.valid) {{
      return NULL;
    }}
    profile_scratch = profile_scratch_info.dev_ptr;
  }}

  // raise exception asap
  {newline.join(ptr_decls)}
  {newline.join(tma_decls)}
  {newline.join(float_storage_decls)}
  Py_BEGIN_ALLOW_THREADS;
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {{
    return NULL;
  }}

  if (launch_exit_hook != Py_None) {{
    PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  Py_RETURN_NONE;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}}  // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"__triton_launcher\",
  NULL,  // documentation
  -1,  // size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit___triton_launcher(void) {{
  data_ptr_str = PyUnicode_InternFromString("data_ptr");
  if (data_ptr_str == NULL) {{
    return NULL;
  }}
  PyObject* driver_mod = PyImport_ImportModule("triton.backends.nvidia.driver");
  if (driver_mod == NULL) {{
    return NULL;
  }}
  py_tensor_map_type = PyObject_GetAttrString(driver_mod, "PyCUtensorMap");
  if (py_tensor_map_type == NULL) {{
    return NULL;
  }}

  PyObject *m = PyModule_Create(&ModuleDef);
  if (m == NULL) {{
    return NULL;
  }}
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}}
"""
    return src
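

# For reference (illustrative sketch, derived from the parse format above): the
# generated module exposes a single launch() entry point, which
# CudaLauncher.__call__ below invokes roughly as
#   mod.launch(gridX, gridY, gridZ, stream, function,
#              launch_cooperative_grid, launch_pdl,
#              global_scratch, profile_scratch,
#              kernel_metadata, launch_metadata,
#              launch_enter_hook, launch_exit_hook, *kernel_args)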


# The TMA dtype enum values are slightly different on host vs device...
TMA_DTYPE_DEVICE_TO_HOST = dict((i, i) for i in range(16))
TMA_DTYPE_DEVICE_TO_HOST[8] = 10
TMA_DTYPE_DEVICE_TO_HOST[9] = 8
TMA_DTYPE_DEVICE_TO_HOST[10] = 9
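# The table is the identity map except for a 3-cycle on entries 8, 9 and 10:
# device 8 maps to host 10, device 9 to host 8, and device 10 to host 9.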


def make_tensordesc_arg(arg, metadata):
    if metadata is None:
        # Currently the host side tensor descriptors get decomposed in
        # the frontend to tensor desc, shape, and strides. We have no
        # way to use these shape and strides when processing tensor
        # descriptors, which is why we provide our own decomposition
        # above. Sadly this means we have to pass the shape and strides
        # twice.
        return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]

    swizzle = metadata["swizzle"]
    elem_size = metadata["elem_size"]
    elem_type = metadata["elem_type"]
    block_size = metadata["block_size"]
    fp4_padded = metadata["fp4_padded"]

    shape = arg.shape
    strides = arg.strides
    assert strides[-1] == 1
    padding = 1 if arg.padding == "nan" else 0

    if fp4_padded:
        shape = list(shape)
        shape[-1] *= 2

    cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor(
        arg.base.data_ptr(),
        swizzle,
        elem_size,
        TMA_DTYPE_DEVICE_TO_HOST[elem_type],
        block_size,
        shape,
        strides,
        padding,
    )

    return [cu_tensor_map, *shape, *strides]
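
# Example (illustrative): for a 2-D descriptor over a tensor with shape (M, N)
# and strides (N, 1), the no-metadata branch above returns
# [base, M, N, N, 1, padding, M, N, N, 1] -- shape and strides twice, as the
# comment explains -- while the TMA branch returns [cu_tensor_map, M, N, N, 1].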


def wrap_handle_tensordesc(launcher, signature, tensordesc_meta):
    has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
    if not has_tensor_desc_arg:
        return launcher

    tensordesc_indices = set(
        [i for i, sig in enumerate(signature.values()) if isinstance(sig, str) and sig.startswith("tensordesc")])
    assert not tensordesc_meta or len(tensordesc_meta) == len(tensordesc_indices)
    if not tensordesc_meta:
        tensordesc_meta = [None] * len(tensordesc_indices)

    def inner(*args):
        final_args = list(args[:_BASE_ARGS_FORMAT_LEN])
        tensordesc_idx = 0
        for i, arg in enumerate(args[_BASE_ARGS_FORMAT_LEN:]):
            if i in tensordesc_indices:
                final_args.extend(make_tensordesc_arg(arg, tensordesc_meta[tensordesc_idx]))
                tensordesc_idx += 1
            else:
                final_args.append(arg)
        return launcher(*final_args)

    return inner


class CudaLauncher(object):

    def __init__(self, src, metadata):
        constants = src.constants if hasattr(src, "constants") else dict()
        arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
        constants = {arg_idx(idx): value for idx, value in constants.items()}
        signature = {idx: value for idx, value in src.signature.items()}
        tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
        src = make_launcher(constants, signature, tensordesc_meta)
        mod = compile_module_from_src(
            src=src,
            name="__triton_launcher",
            library_dirs=library_dirs(),
            include_dirs=include_dirs,
            libraries=libraries,
        )

        self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
        self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
        self.global_scratch_size = metadata.global_scratch_size
        self.global_scratch_align = metadata.global_scratch_align
        self.profile_scratch_size = metadata.profile_scratch_size
        self.profile_scratch_align = metadata.profile_scratch_align
        self.launch_cooperative_grid = metadata.launch_cooperative_grid
        self.launch_pdl = metadata.launch_pdl

    def __call__(self, gridX, gridY, gridZ, stream, function, *args):

        def allocate_scratch(size, align, allocator):
            if size > 0:
                grid_size = gridX * gridY * gridZ
                alloc_size = grid_size * self.num_ctas * size
                alloc_fn = allocator.get()
                return alloc_fn(alloc_size, align, stream)
            return None

        global_scratch = allocate_scratch(self.global_scratch_size, self.global_scratch_align, _allocation._allocator)
        profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
                                           _allocation._profile_allocator)
        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
                    global_scratch, profile_scratch, *args)
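
        # Worked example (illustrative): with grid (8, 4, 1), num_ctas == 2 and
        # global_scratch_size == 256 bytes, allocate_scratch above requests
        # 8 * 4 * 1 * 2 * 256 = 16384 bytes from the active allocator.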


class CudaDriver(GPUDriver):

    def __init__(self):
        self.utils = CudaUtils()  # TODO: make static
        self.launcher_cls = CudaLauncher
        super().__init__()

    def get_current_target(self):
        device = self.get_current_device()
        capability = self.get_device_capability(device)
        capability = capability[0] * 10 + capability[1]
        warp_size = 32
        return GPUTarget("cuda", capability, warp_size)

    def get_active_torch_device(self):
        import torch
        return torch.device("cuda", self.get_current_device())

    def get_device_interface(self):
        import torch
        return torch.cuda

    @staticmethod
    def is_active():
        try:
            import torch
            return torch.cuda.is_available() and (torch.version.hip is None)
        except ImportError:
            return False

    def map_python_to_cpp_type(self, ty: str) -> str:
        return ty_to_cpp(ty)

    def get_benchmarker(self):
        from triton.testing import do_bench
        return do_bench

    def get_empty_cache_for_benchmark(self):
        import torch

        # We maintain a buffer of 256 MB that we clear
        # before each kernel call to make sure that the L2 cache
        # doesn't contain any input data before the run
        cache_size = 256 * 1024 * 1024
        return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')

    def clear_cache(self, cache):
        cache.zero_()