triton-windows 3.3.1.post21__cp312-cp312-win_amd64.whl → 3.4.0.post21__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +143 -46
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +94 -94
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +296 -125
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +73 -9
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +47 -83
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
- triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
triton/backends/nvidia/driver.py
CHANGED
@@ -1,36 +1,33 @@
 import functools
+import operator
 import os
-import sysconfig
-import hashlib
 import subprocess
-import tempfile
+import triton
+import re
 from pathlib import Path
-from triton.runtime.build import _build
-from triton.runtime.cache import get_cache_manager
+from triton import knobs
+from triton.runtime.build import compile_module_from_src
 from triton.runtime import _allocation
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver

 dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+include_dirs = [os.path.join(dirname, "include")]
 if os.name == "nt":
     from triton.windows_utils import find_cuda
     _, cuda_inc_dirs, _ = find_cuda()
-    include_dir += cuda_inc_dirs
+    include_dirs += cuda_inc_dirs
 libdevice_dir = os.path.join(dirname, "lib")
 libraries = ['cuda']


 @functools.lru_cache()
 def libcuda_dirs():
-    env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
-    if env_libcuda_path:
+    if env_libcuda_path := knobs.nvidia.libcuda_path:
         return [env_libcuda_path]
-
     if os.name == "nt":
         _, _, cuda_lib_dirs = find_cuda()
         return cuda_lib_dirs
-
     libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
     # each line looks like the following:
     # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
@@ -55,36 +52,6 @@ def library_dirs():
     return [libdevice_dir, *libcuda_dirs()]


-@functools.lru_cache()
-def platform_key():
-    from platform import machine, system, architecture
-    return ",".join([machine(), system(), *architecture()])
-
-
-def compile_module_from_src(src, name):
-    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
-    cache = get_cache_manager(key)
-    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
-    cache_path = cache.get_file(f"{name}.{ext}")
-    if cache_path is None:
-        with tempfile.TemporaryDirectory() as tmpdir:
-            src_path = os.path.join(tmpdir, f"{name}.c")
-            with open(src_path, "w") as f:
-                f.write(src)
-            so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
-            with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
-
-    # Loading module with relative path may cause error
-    cache_path = os.path.abspath(cache_path)
-
-    import importlib.util
-    spec = importlib.util.spec_from_file_location(name, cache_path)
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)
-    return mod
-
-
 # ------------------------
 # Utils
 # ------------------------
@@ -98,13 +65,18 @@ class CudaUtils(object):
         return cls.instance

     def __init__(self):
-        mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
+        mod = compile_module_from_src(
+            src=Path(os.path.join(dirname, "driver.c")).read_text(),
+            name="cuda_utils",
+            library_dirs=library_dirs(),
+            include_dirs=include_dirs,
+            libraries=libraries,
+        )
         self.load_binary = mod.load_binary
         self.get_device_properties = mod.get_device_properties
         self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
         self.set_printf_fifo_size = mod.set_printf_fifo_size
-        self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor
-        self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor
+        self.fill_tma_descriptor = mod.fill_tma_descriptor


 # ------------------------
@@ -115,6 +87,8 @@ class CudaUtils(object):
 def ty_to_cpp(ty):
     if ty[0] == '*':
         return "CUdeviceptr"
+    if ty.startswith("tensordesc"):
+        return "CUtensorMap"
     return {
         "i1": "int32_t",
         "i8": "int8_t",
@@ -126,21 +100,80 @@ def ty_to_cpp(ty):
         "u16": "uint16_t",
         "u32": "uint32_t",
         "u64": "uint64_t",
-        "fp16": "float",
-        "bf16": "float",
-        "fp32": "float",
-        "f32": "float",
+        "fp16": "double",
+        "bf16": "double",
+        "fp32": "double",
+        "f32": "double",
         "fp64": "double",
         "nvTmaDesc": "CUtensorMap",
     }[ty]


-def make_launcher(constants, signature):
-
-
+FLOAT_STORAGE_TYPE = {
+    "fp16": "uint16_t",
+    "bf16": "uint16_t",
+    "fp32": "uint32_t",
+    "f32": "uint32_t",
+    "fp64": "uint64_t",
+}
+FLOAT_PACK_FUNCTION = {
+    "fp16": "pack_fp16",
+    "bf16": "pack_bf16",
+    "fp32": "pack_fp32",
+    "f32": "pack_fp32",
+    "fp64": "pack_fp64",
+}
+
+_BASE_ARGS_FORMAT = "iiiKKppOOOOO"
+
+
+def make_launcher(constants, signature, tensordesc_meta):
+
+    def _expand_signature(signature):
+        output = []
+        tensordesc_idx = 0
+        # Expand tensor descriptor arguments into either nvTmaDesc, shape and
+        # strides, or base pointer, shape and strides depending on whether the
+        # kernel was lowered to use the nvTmaDesc or not.
+        for sig in signature:
+            if isinstance(sig, str) and sig.startswith("tensordesc"):
+                meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
+                tensordesc_idx += 1
+
+                match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
+                dtype = match.group(1)
+                shape = match.group(2)
+                ndim = shape.count(",") + 1
+
+                if meta is None:
+                    output.append("*" + dtype)
+                    # Currently the host side tensor descriptors get passed in as a
+                    # tensor desc, shape, and strides. We have no way to use these
+                    # shape and strides when processing tensor descriptors which is
+                    # why we provide our own decomposition above. Sadly this means
+                    # we have to pass the shape and strides twice.
+                    for _ in range(2 * ndim):
+                        output.append("i64")
+                else:
+                    output.append("nvTmaDesc")
+
+                    for _ in range(ndim):
+                        output.append("i32")
+                    for _ in range(ndim):
+                        output.append("i64")
+            else:
+                output.append(sig)
+
+        assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
+        return output
+
+    def _flatten_signature(sig, output):
+        # Flatten tuples
         if isinstance(sig, tuple):
-
-
+            for x in sig:
+                _flatten_signature(x, output)
+        else:
+            output.append(sig)

     def _extracted_type(ty):
         if isinstance(ty, tuple):
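A note on the expansion above (illustrative only, not part of the wheel): the sketch below shows what `_expand_signature` produces for a hypothetical `tensordesc<fp16[64,64]>` signature entry, with and without TMA metadata, using the same regex and slot counts as the code in the diff.

# Hedged sketch, not package code: expansion of one tensordesc signature entry.
import re

sig = "tensordesc<fp16[64,64]>"                        # hypothetical example entry
match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
dtype, shape = match.group(1), match.group(2)          # "fp16", "64,64"
ndim = shape.count(",") + 1                            # 2

# Kernel not lowered to TMA (meta is None): base pointer plus 2*ndim i64 slots.
no_tma = ["*" + dtype] + ["i64"] * (2 * ndim)
print(no_tma)    # ['*fp16', 'i64', 'i64', 'i64', 'i64']

# Kernel lowered to TMA: one nvTmaDesc plus i32 shape slots and i64 stride slots.
with_tma = ["nvTmaDesc"] + ["i32"] * ndim + ["i64"] * ndim
print(with_tma)  # ['nvTmaDesc', 'i32', 'i32', 'i64', 'i64']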
@@ -160,8 +193,9 @@ def make_launcher(constants, signature):
             return "O"
         if ty in ("constexpr", "nvTmaDesc"):
             return "O"
+        if ty.startswith("tensordesc"):
+            return "O"
         return {
-            "float": "f",
             "double": "d",
             "long": "l",
             "int8_t": "b",
@@ -174,19 +208,34 @@ def make_launcher(constants, signature):
             "uint64_t": "K",
         }[ty_to_cpp(ty)]

+    expand_signature = _expand_signature(signature.values())
+    signature = {i: s for i, s in enumerate(expand_signature)}
+
     args_format = ''.join([format_of(ty) for ty in signature.values()])
-    format = "iiiKKpOOOOO" + args_format
-
-
-
+    format = _BASE_ARGS_FORMAT + args_format
+
+    flat_signature = []
+    for sig in signature.values():
+        _flatten_signature(sig, flat_signature)
+    signature = {i: s for i, s in enumerate(flat_signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
     # Record the end of regular arguments;
     # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
-
+    arg_decl_list = []
+    for i, ty in signature.items():
+        if ty == "constexpr":
+            continue
+        if ty in FLOAT_STORAGE_TYPE:
+            arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
+        else:
+            arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
+    arg_decls = ', '.join(arg_decl_list)
     internal_args_list = []
     for i, ty in signature.items():
         if ty[0] == "*":
             internal_args_list.append(f"ptr_info{i}.dev_ptr")
+        elif ty in FLOAT_STORAGE_TYPE:
+            internal_args_list.append(f"_arg{i}_storage")
         elif ty == "nvTmaDesc":
             # Note: we have to dereference the pointer
             internal_args_list.append(f"*tma_ptr{i}")
@@ -205,14 +254,17 @@ def make_launcher(constants, signature):
         f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
         if ty == "nvTmaDesc"
     ]
+    float_storage_decls = [
+        f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
+        for i, ty in signature.items()
+        if ty in FLOAT_STORAGE_TYPE
+    ]
     params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
     params.append("&global_scratch")
     src = f"""
 #define _CRT_SECURE_NO_WARNINGS
 #include \"cuda.h\"
 #include <stdbool.h>
-#define PY_SSIZE_T_CLEAN
-#define Py_LIMITED_API 0x03090000
 #include <Python.h>

 #ifndef _WIN32
@@ -282,67 +334,65 @@ static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
 }}
 #endif

-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0) {{
-
-
-
-
+    // 4 attributes that we can currently pass maxmimum
+    CUlaunchAttribute launchAttr[4];
+    static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
+    if (cuLaunchKernelExHandle == NULL) {{
+      cuLaunchKernelExHandle = getLaunchKernelExHandle();
+    }}
+    CUlaunchConfig config;
+    config.gridDimX = gridX;
+    config.gridDimY = gridY;
+    config.gridDimZ = gridZ;
+
+    if (num_ctas != 1) {{
+      config.gridDimX *= clusterDimX;
+      config.gridDimY *= clusterDimY;
+      config.gridDimZ *= clusterDimZ;
+    }}
+
+    config.blockDimX = 32 * num_warps;
+    config.blockDimY = 1;
+    config.blockDimZ = 1;
+    config.sharedMemBytes = shared_memory;
+    config.hStream = stream;
+    config.attrs = launchAttr;
+    int num_attrs = 0;
+
+    if (launch_pdl != 0) {{
+      CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
+      launchAttr[num_attrs] = pdlAttr;
+      ++num_attrs;
+    }}
+
+    if (launch_cooperative_grid != 0) {{
       CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-      launchAttr[
-
-      CUlaunchConfig config;
-      config.gridDimX = gridX;
-      config.gridDimY = gridY;
-      config.gridDimZ = gridZ;
-      config.blockDimX = 32 * num_warps;
-      config.blockDimY = 1;
-      config.blockDimZ = 1;
-      config.sharedMemBytes = shared_memory;
-      config.hStream = stream;
-      config.attrs = launchAttr;
-      config.numAttrs = 1;
-
-      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-      if (cuLaunchKernelExHandle == NULL) {{
-        cuLaunchKernelExHandle = getLaunchKernelExHandle();
-      }}
-      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
-
-    }} else {{
-      CUlaunchAttribute launchAttr[3];
-      launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-      launchAttr[0].value.clusterDim.x = clusterDimX;
-      launchAttr[0].value.clusterDim.y = clusterDimY;
-      launchAttr[0].value.clusterDim.z = clusterDimZ;
-      launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
-      launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
-
-      unsigned numAttrs = 2;
-      if (0 != launch_cooperative_grid) {{
-        CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-        launchAttr[2] = coopAttr;
-        numAttrs = 3;
-      }}
-
-      CUlaunchConfig config;
-      config.gridDimX = gridX * clusterDimX;
-      config.gridDimY = gridY * clusterDimY;
-      config.gridDimZ = gridZ * clusterDimZ;
-      config.blockDimX = 32 * num_warps;
-      config.blockDimY = 1;
-      config.blockDimZ = 1;
-      config.sharedMemBytes = shared_memory;
-      config.hStream = stream;
-      config.attrs = launchAttr;
-      config.numAttrs = numAttrs;
-      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-      if (cuLaunchKernelExHandle == NULL) {{
-        cuLaunchKernelExHandle = getLaunchKernelExHandle();
-      }}
-      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
+      launchAttr[num_attrs] = coopAttr;
+      ++num_attrs;
     }}
+
+    if (num_ctas != 1) {{
+      CUlaunchAttribute clusterAttr = {{}};
+      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+      clusterAttr.value.clusterDim.x = clusterDimX;
+      clusterAttr.value.clusterDim.y = clusterDimY;
+      clusterAttr.value.clusterDim.z = clusterDimZ;
+      launchAttr[num_attrs] = clusterAttr;
+      ++num_attrs;
+
+      CUlaunchAttribute clusterSchedulingAttr = {{}};
+      clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+      launchAttr[num_attrs] = clusterSchedulingAttr;
+      ++num_attrs;
+    }}
+
+    config.numAttrs = num_attrs;
+
+    CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
   }}
 }}

@@ -457,6 +507,32 @@ static void ensureCudaContext() {{
   }}
 }}

+static uint16_t pack_fp16(double f) {{
+  uint16_t result;
+  // from https://github.com/python/pythoncapi-compat
+#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
+  _PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#else
+  PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#endif
+  return result;
+}}
+
+static uint16_t pack_bf16(double f) {{
+  float f32 = (float)f;
+  uint32_t u32 = *(uint32_t*)&f32;
+  return (uint16_t)(u32 >> 16);
+}}
+
+static uint32_t pack_fp32(double f) {{
+  float f32 = (float)f;
+  return *(uint32_t*)&f32;
+}}
+
+static uint64_t pack_fp64(double f) {{
+  return *(uint64_t*)&f;
+}}
+
 static PyObject* launch(PyObject* self, PyObject* args) {{
   // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
   ensureCudaContext();
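For reference (an illustration, not part of the wheel): the generated `pack_*` helpers above re-encode Python floats, which reach the launcher as C `double`s, into the integer storage types listed in `FLOAT_STORAGE_TYPE`. A rough Python equivalent using `struct`, assuming a little-endian host:

# Hedged sketch, not package code: Python analogues of the generated pack_* helpers.
import struct

def pack_fp16(f: float) -> int:
    # struct's "e" format is IEEE 754 binary16, comparable to PyFloat_Pack2 above
    return struct.unpack("<H", struct.pack("<e", f))[0]

def pack_fp32(f: float) -> int:
    # reinterpret the float32 bits as uint32, like *(uint32_t*)&f32
    return struct.unpack("<I", struct.pack("<f", f))[0]

def pack_bf16(f: float) -> int:
    # bfloat16 keeps the top 16 bits of the float32 encoding (truncation, no rounding)
    return pack_fp32(f) >> 16

def pack_fp64(f: float) -> int:
    return struct.unpack("<Q", struct.pack("<d", f))[0]

assert pack_fp32(1.0) == 0x3F800000 and pack_bf16(1.0) == 0x3F80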
@@ -465,6 +541,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
   uint64_t _stream;
   uint64_t _function;
   int launch_cooperative_grid;
+  int launch_pdl;
   PyObject *launch_enter_hook = NULL;
   PyObject *launch_exit_hook = NULL;
   PyObject *kernel_metadata = NULL;
@@ -472,7 +549,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
   PyObject *global_scratch_obj = NULL;
   {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
   if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
-                       &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
+                       &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj,
                        &kernel_metadata, &launch_metadata,
                        &launch_enter_hook, &launch_exit_hook{args_list})) {{
     return NULL;
@@ -506,8 +583,9 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
   // raise exception asap
   {newline.join(ptr_decls)}
   {newline.join(tma_decls)}
+  {newline.join(float_storage_decls)}
   Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
   Py_END_ALLOW_THREADS;
   if (PyErr_Occurred()) {{
     return NULL;
@@ -550,6 +628,87 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
     return src


+class TmaDescKernelParam:
+    TMA_DESC_SIZE = 128
+
+    def __init__(self):
+        import torch
+        self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu")
+
+    # Return a CUtensorMap* pointer in host memory
+    def tma_desc_cpu_ptr(self):
+        return self.desc.data_ptr()
+
+
+# The TMA dtype enum values are slightly different on host vs device...
+TMA_DTYPE_DEVICE_TO_HOST = dict((i, i) for i in range(16))
+TMA_DTYPE_DEVICE_TO_HOST[8] = 10
+TMA_DTYPE_DEVICE_TO_HOST[9] = 8
+TMA_DTYPE_DEVICE_TO_HOST[10] = 9
+
+
+def make_tensordesc_arg(arg, metadata):
+    if metadata is None:
+        # Currently the host side tensor descriptors get decomposed in
+        # the frontend to tensor desc, shape, and strides. We have no
+        # way to use these shape and strides when processing tensor
+        # descriptors which is why we provide our own decomposition
+        # above. Sadly this means we have to pass the shape and strides
+        # twice.
+        return [arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides]
+
+    swizzle = metadata["swizzle"]
+    elem_size = metadata["elem_size"]
+    elem_type = metadata["elem_type"]
+    block_size = metadata["block_size"]
+    fp4_padded = metadata["fp4_padded"]
+
+    data_ptr = arg.base.data_ptr()
+    shape = arg.shape
+    strides = arg.strides
+    assert strides[-1] == 1
+
+    desc = TmaDescKernelParam()
+    result = [desc, *shape, *strides]
+
+    if fp4_padded:
+        shape = list(shape)
+        shape[-1] *= 2
+    triton.runtime.driver.active.utils.fill_tma_descriptor(
+        desc.tma_desc_cpu_ptr(),
+        data_ptr,
+        swizzle,
+        elem_size,
+        TMA_DTYPE_DEVICE_TO_HOST[elem_type],
+        block_size,
+        shape,
+        strides,
+    )
+    return result
+
+
+def wrap_handle_tensordesc(launcher, tensordesc_meta):
+    from triton.tools.tensor_descriptor import TensorDescriptor
+    from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor
+
+    def inner(*args):
+        meta_args = args[:len(_BASE_ARGS_FORMAT)]
+        raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
+        tensordesc_idx = 0
+        final_args = []
+        for i, arg in enumerate(raw_kernel_args):
+            if isinstance(arg, (TensorDescriptor, GluonTensorDescriptor)):
+                meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
+                tensordesc_idx += 1
+                final_args.extend(make_tensordesc_arg(arg, meta))
+            else:
+                final_args.append(arg)
+        assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
+        return launcher(*meta_args, *final_args)
+
+    return inner
+
+
 class CudaLauncher(object):

     def __init__(self, src, metadata):
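A side note (illustrative, not package code): in the no-metadata path, `make_tensordesc_arg` flattens a host-side descriptor into its base tensor followed by shape and strides passed twice, per the comment in the code. A minimal sketch with a stand-in object in place of `triton.tools.tensor_descriptor.TensorDescriptor`:

# Hedged sketch, not package code: the meta-is-None flattening of a descriptor.
from types import SimpleNamespace

# stand-in for a TensorDescriptor with .base (a device tensor), .shape and .strides
arg = SimpleNamespace(base="<device tensor>", shape=[64, 64], strides=[64, 1])

flat = [arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides]
print(flat)  # ['<device tensor>', 64, 64, 64, 1, 64, 64, 64, 1]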
@@ -557,21 +716,33 @@ class CudaLauncher(object):
         arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
         constants = {arg_idx(idx): value for idx, value in constants.items()}
         signature = {idx: value for idx, value in src.signature.items()}
-
-
-
+        tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
+        src = make_launcher(constants, signature, tensordesc_meta)
+        mod = compile_module_from_src(
+            src=src,
+            name="__triton_launcher",
+            library_dirs=library_dirs(),
+            include_dirs=include_dirs,
+            libraries=libraries,
+        )
+        has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
+
+        self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
+        self.launch = wrap_handle_tensordesc(mod.launch, tensordesc_meta) if has_tensor_desc_arg else mod.launch
         self.global_scratch_size = metadata.global_scratch_size
         self.global_scratch_align = metadata.global_scratch_align
         self.launch_cooperative_grid = metadata.launch_cooperative_grid
+        self.launch_pdl = metadata.launch_pdl

     def __call__(self, gridX, gridY, gridZ, stream, function, *args):
         if self.global_scratch_size > 0:
             grid_size = gridX * gridY * gridZ
-            alloc_size = grid_size * self.global_scratch_size
+            alloc_size = grid_size * self.num_ctas * self.global_scratch_size
             global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
         else:
             global_scratch = None
-        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
+        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
+                    global_scratch, *args)


 class CudaDriver(GPUDriver):