triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/backends/nvidia/driver.py
CHANGED
@@ -1,37 +1,37 @@
 import functools
+import operator
 import os
-import sysconfig
-import hashlib
 import subprocess
-import tempfile
+import triton
+import re
 from pathlib import Path
-from triton.runtime.build import _build
-from triton.runtime.cache import get_cache_manager
+from triton import knobs
+from triton.runtime.build import compile_module_from_src
 from triton.runtime import _allocation
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver
 
 dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+include_dirs = [os.path.join(dirname, "include")]
 if os.name == "nt":
     from triton.windows_utils import find_cuda
     _, cuda_inc_dirs, _ = find_cuda()
-    include_dir += cuda_inc_dirs
+    include_dirs += cuda_inc_dirs
 libdevice_dir = os.path.join(dirname, "lib")
 libraries = ['cuda']
+PyCUtensorMap = None
 
 
 @functools.lru_cache()
 def libcuda_dirs():
-    env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
-    if env_libcuda_path:
+    if env_libcuda_path := knobs.nvidia.libcuda_path:
         return [env_libcuda_path]
 
     if os.name == "nt":
         _, _, cuda_lib_dirs = find_cuda()
         return cuda_lib_dirs
 
-    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
+    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
     # each line looks like the following:
     # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
     locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
@@ -55,36 +55,6 @@ def library_dirs():
     return [libdevice_dir, *libcuda_dirs()]
 
 
-@functools.lru_cache()
-def platform_key():
-    from platform import machine, system, architecture
-    return ",".join([machine(), system(), *architecture()])
-
-
-def compile_module_from_src(src, name):
-    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
-    cache = get_cache_manager(key)
-    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
-    cache_path = cache.get_file(f"{name}.{ext}")
-    if cache_path is None:
-        with tempfile.TemporaryDirectory() as tmpdir:
-            src_path = os.path.join(tmpdir, f"{name}.c")
-            with open(src_path, "w") as f:
-                f.write(src)
-            so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
-            with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
-
-    # Loading module with relative path may cause error
-    cache_path = os.path.abspath(cache_path)
-
-    import importlib.util
-    spec = importlib.util.spec_from_file_location(name, cache_path)
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)
-    return mod
-
-
 # ------------------------
 # Utils
 # ------------------------
@@ -98,13 +68,20 @@ class CudaUtils(object):
         return cls.instance
 
     def __init__(self):
-        mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
+        mod = compile_module_from_src(
+            src=Path(os.path.join(dirname, "driver.c")).read_text(),
+            name="cuda_utils",
+            library_dirs=library_dirs(),
+            include_dirs=include_dirs,
+            libraries=libraries,
+        )
+        global PyCUtensorMap
+        PyCUtensorMap = mod.PyCUtensorMap
         self.load_binary = mod.load_binary
         self.get_device_properties = mod.get_device_properties
         self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
         self.set_printf_fifo_size = mod.set_printf_fifo_size
-        self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor
-        self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor
+        self.fill_tma_descriptor = mod.fill_tma_descriptor
 
 
 # ------------------------
@@ -115,32 +92,95 @@ class CudaUtils(object):
 def ty_to_cpp(ty):
     if ty[0] == '*':
         return "CUdeviceptr"
+    if ty.startswith("tensordesc"):
+        return "CUtensorMap"
     return {
-        "i1": "int32_t",
+        "i1": "int8_t",
         "i8": "int8_t",
         "i16": "int16_t",
         "i32": "int32_t",
         "i64": "int64_t",
-        "u1": "uint32_t",
+        "u1": "uint8_t",
         "u8": "uint8_t",
         "u16": "uint16_t",
         "u32": "uint32_t",
         "u64": "uint64_t",
-        "fp16": "float",
-        "bf16": "float",
-        "fp32": "float",
-        "f32": "float",
+        "fp16": "double",
+        "bf16": "double",
+        "fp32": "double",
+        "f32": "double",
         "fp64": "double",
         "nvTmaDesc": "CUtensorMap",
     }[ty]
 
 
-def make_launcher(constants, signature):
-
-    def _serialize_signature(sig):
+FLOAT_STORAGE_TYPE = {
+    "fp16": "uint16_t",
+    "bf16": "uint16_t",
+    "fp32": "uint32_t",
+    "f32": "uint32_t",
+    "fp64": "uint64_t",
+}
+FLOAT_PACK_FUNCTION = {
+    "fp16": "pack_fp16",
+    "bf16": "pack_bf16",
+    "fp32": "pack_fp32",
+    "f32": "pack_fp32",
+    "fp64": "pack_fp64",
+}
+
+_BASE_ARGS_FORMAT = "iiiKKppOOOOOO"
+_BASE_ARGS_FORMAT_LEN = len(_BASE_ARGS_FORMAT)
+
+
+def make_launcher(constants, signature, tensordesc_meta):
+
+    def _expand_signature(signature):
+        output = []
+        tensordesc_idx = 0
+        # Expand tensor descriptor arguments into either nvTmaDesc, shape and
+        # strides, or base pointer, shape and strides depending on whether the
+        # kernel was lowered to use the nvTmaDesc or not.
+        for sig in signature:
+            if isinstance(sig, str) and sig.startswith("tensordesc"):
+                meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
+                tensordesc_idx += 1
+
+                match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
+                dtype = match.group(1)
+                shape = match.group(2)
+                ndim = shape.count(",") + 1
+
+                if meta is None:
+                    output.append("*" + dtype)
+                    # Currently the host side tensor descriptors get passed in as a
+                    # tensor desc, shape, and strides. We have no way to use these
+                    # shape and strides when processing tensor descriptors which is
+                    # why we provide our own decomposition above. Sadly this means
+                    # we have to pass the shape and strides twice.
+                    for _ in range(2 * ndim):
+                        output.append("i64")
+                    output.append("i1")
+                else:
+                    output.append("nvTmaDesc")
+
+                    for _ in range(ndim):
+                        output.append("i32")
+                    for _ in range(ndim):
+                        output.append("i64")
+            else:
+                output.append(sig)
+
+        assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
+        return output
+
+    def _flatten_signature(sig, output):
+        # Flatten tuples
         if isinstance(sig, tuple):
-            return ','.join(map(_serialize_signature, sig))
-        return sig
+            for x in sig:
+                _flatten_signature(x, output)
+        else:
+            output.append(sig)
 
     def _extracted_type(ty):
        if isinstance(ty, tuple):
@@ -160,8 +200,9 @@ def make_launcher(constants, signature):
            return "O"
        if ty in ("constexpr", "nvTmaDesc"):
            return "O"
+        if ty.startswith("tensordesc"):
+            return "O"
        return {
-            "float": "f",
            "double": "d",
            "long": "l",
            "int8_t": "b",
@@ -174,19 +215,34 @@ def make_launcher(constants, signature):
            "uint64_t": "K",
        }[ty_to_cpp(ty)]
 
+    expand_signature = _expand_signature(signature.values())
+    signature = {i: s for i, s in enumerate(expand_signature)}
+
    args_format = ''.join([format_of(ty) for ty in signature.values()])
-    format = "iiiKKpOOOOO" + args_format
-    signature = ','.join(map(_serialize_signature, signature.values()))
-    signature = list(filter(bool, signature.split(',')))
-    signature = {i: s for i, s in enumerate(signature)}
+    format = _BASE_ARGS_FORMAT + args_format
+
+    flat_signature = []
+    for sig in signature.values():
+        _flatten_signature(sig, flat_signature)
+    signature = {i: s for i, s in enumerate(flat_signature)}
    args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
    # Record the end of regular arguments;
    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
+    arg_decl_list = []
+    for i, ty in signature.items():
+        if ty == "constexpr":
+            continue
+        if ty in FLOAT_STORAGE_TYPE:
+            arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
+        else:
+            arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
+    arg_decls = ', '.join(arg_decl_list)
    internal_args_list = []
    for i, ty in signature.items():
        if ty[0] == "*":
            internal_args_list.append(f"ptr_info{i}.dev_ptr")
+        elif ty in FLOAT_STORAGE_TYPE:
+            internal_args_list.append(f"_arg{i}_storage")
        elif ty == "nvTmaDesc":
            # Note: we have to dereference the pointer
            internal_args_list.append(f"*tma_ptr{i}")
@@ -205,15 +261,17 @@ def make_launcher(constants, signature):
        f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
        if ty == "nvTmaDesc"
    ]
+    float_storage_decls = [
+        f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
+        for i, ty in signature.items()
+        if ty in FLOAT_STORAGE_TYPE
+    ]
    params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
    params.append("&global_scratch")
+    params.append("&profile_scratch")
    src = f"""
 #define _CRT_SECURE_NO_WARNINGS
 #include \"cuda.h\"
-#include <stdbool.h>
-#define PY_SSIZE_T_CLEAN
-#define Py_LIMITED_API 0x03090000
-#include <Python.h>
 
 #ifndef _WIN32
 #include <dlfcn.h>
@@ -222,6 +280,16 @@ def make_launcher(constants, signature):
 #include <windows.h>
 #endif
 
+#include <stdbool.h>
+#include <stdlib.h>
+#define PY_SSIZE_T_CLEAN
+#include <Python.h>
+
+typedef struct {{
+  PyObject_HEAD
+  _Alignas(128) CUtensorMap tensorMap;
+}} PyCUtensorMapObject;
+
 static inline void gpuAssert(CUresult code, const char *file, int line)
 {{
   if (code != CUDA_SUCCESS)
@@ -282,67 +350,65 @@ static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
 }}
 #endif
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
  void *params[] = {{ {', '.join(params)} }};
  if (gridX*gridY*gridZ > 0) {{
-    if (0 == launch_cooperative_grid && num_ctas == 1) {{
-      CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
-    }} else if (num_ctas == 1) {{
-      CUlaunchAttribute launchAttr[1];
+    // 4 attributes that we can currently pass maximum
+    CUlaunchAttribute launchAttr[4];
+    static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
+    if (cuLaunchKernelExHandle == NULL) {{
+      cuLaunchKernelExHandle = getLaunchKernelExHandle();
+    }}
+    CUlaunchConfig config;
+    config.gridDimX = gridX;
+    config.gridDimY = gridY;
+    config.gridDimZ = gridZ;
+
+    if (num_ctas != 1) {{
+      config.gridDimX *= clusterDimX;
+      config.gridDimY *= clusterDimY;
+      config.gridDimZ *= clusterDimZ;
+    }}
+
+    config.blockDimX = 32 * num_warps;
+    config.blockDimY = 1;
+    config.blockDimZ = 1;
+    config.sharedMemBytes = shared_memory;
+    config.hStream = stream;
+    config.attrs = launchAttr;
+    int num_attrs = 0;
+
+    if (launch_pdl != 0) {{
+      CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
+      launchAttr[num_attrs] = pdlAttr;
+      ++num_attrs;
+    }}
+
+    if (launch_cooperative_grid != 0) {{
      CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-      launchAttr[0] = coopAttr;
-
-      CUlaunchConfig config;
-      config.gridDimX = gridX;
-      config.gridDimY = gridY;
-      config.gridDimZ = gridZ;
-      config.blockDimX = 32 * num_warps;
-      config.blockDimY = 1;
-      config.blockDimZ = 1;
-      config.sharedMemBytes = shared_memory;
-      config.hStream = stream;
-      config.attrs = launchAttr;
-      config.numAttrs = 1;
-
-      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-      if (cuLaunchKernelExHandle == NULL) {{
-        cuLaunchKernelExHandle = getLaunchKernelExHandle();
-      }}
-      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
-
-    }} else {{
-      CUlaunchAttribute launchAttr[3];
-      launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-      launchAttr[0].value.clusterDim.x = clusterDimX;
-      launchAttr[0].value.clusterDim.y = clusterDimY;
-      launchAttr[0].value.clusterDim.z = clusterDimZ;
-      launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
-      launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
-
-      unsigned numAttrs = 2;
-      if (0 != launch_cooperative_grid) {{
-        CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-        launchAttr[2] = coopAttr;
-        numAttrs = 3;
-      }}
-
-      CUlaunchConfig config;
-      config.gridDimX = gridX * clusterDimX;
-      config.gridDimY = gridY * clusterDimY;
-      config.gridDimZ = gridZ * clusterDimZ;
-      config.blockDimX = 32 * num_warps;
-      config.blockDimY = 1;
-      config.blockDimZ = 1;
-      config.sharedMemBytes = shared_memory;
-      config.hStream = stream;
-      config.attrs = launchAttr;
-      config.numAttrs = numAttrs;
-      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-      if (cuLaunchKernelExHandle == NULL) {{
-        cuLaunchKernelExHandle = getLaunchKernelExHandle();
-      }}
-      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
+      launchAttr[num_attrs] = coopAttr;
+      ++num_attrs;
    }}
+
+    if (num_ctas != 1) {{
+      CUlaunchAttribute clusterAttr = {{}};
+      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+      clusterAttr.value.clusterDim.x = clusterDimX;
+      clusterAttr.value.clusterDim.y = clusterDimY;
+      clusterAttr.value.clusterDim.z = clusterDimZ;
+      launchAttr[num_attrs] = clusterAttr;
+      ++num_attrs;
+
+      CUlaunchAttribute clusterSchedulingAttr = {{}};
+      clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+      launchAttr[num_attrs] = clusterSchedulingAttr;
+      ++num_attrs;
+    }}
+
+    config.numAttrs = num_attrs;
+
+    CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
  }}
 }}
 
@@ -351,6 +417,9 @@ typedef struct _DevicePtrInfo {{
  bool valid;
 }} DevicePtrInfo;
 
+static PyObject* data_ptr_str = NULL;
+static PyObject* py_tensor_map_type = NULL;
+
 static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
@@ -363,37 +432,35 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
    // valid nullptr
    return ptr_info;
  }}
-  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
-  if(ptr){{
-    PyObject *empty_tuple = PyTuple_New(0);
-    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
-    Py_DECREF(empty_tuple);
-    Py_DECREF(ptr);
-    if (!PyLong_Check(ret)) {{
-      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
-      ptr_info.valid = false;
-      return ptr_info;
-    }}
-    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
-    if(!ptr_info.dev_ptr)
-      return ptr_info;
-    uint64_t dev_ptr;
-    int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
-    if (status == CUDA_ERROR_INVALID_VALUE) {{
-      PyErr_Format(PyExc_ValueError,
-                   "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
-      ptr_info.valid = false;
-    }} else if (status != CUDA_SUCCESS) {{
-      CUDA_CHECK(status); // Catch any other cuda API errors
-      ptr_info.valid = false;
-    }}
-    ptr_info.dev_ptr = dev_ptr;
-    Py_DECREF(ret); // Thanks ChatGPT!
+  PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
+  if (!ret) {{
+    PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
+    ptr_info.valid = false;
+    goto cleanup;
+  }}
+  if (!PyLong_Check(ret)) {{
+    PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
+    ptr_info.valid = false;
+    goto cleanup;
+  }}
+  ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
+  if(!ptr_info.dev_ptr)
    return ptr_info;
+  uint64_t dev_ptr;
+  int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
+  if (status == CUDA_ERROR_INVALID_VALUE) {{
+    PyErr_Format(PyExc_ValueError,
+                 "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
+    ptr_info.valid = false;
+  }} else if (status != CUDA_SUCCESS) {{
+    CUDA_CHECK(status); // Catch any other cuda API errors
+    ptr_info.valid = false;
  }}
-  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
-  ptr_info.valid = false;
+  ptr_info.dev_ptr = dev_ptr;
+cleanup:
+  Py_XDECREF(ret);
  return ptr_info;
+
 }}
 
 static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
@@ -402,44 +469,18 @@ static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
    return NULL;
  }}
 
-  PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr");
-  if (!method_handle) {{
-    PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist");
-    return NULL;
-  }}
-
-  PyObject *empty_tuple = PyTuple_New(0);
-  if (!empty_tuple) {{
-    Py_DECREF(method_handle);
-    PyErr_SetString(PyExc_SystemError, "Internal Python error!");
-    return NULL;
-  }}
-  PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL);
-  Py_DECREF(empty_tuple);
-  Py_DECREF(method_handle);
-  if (!method_ret) {{
-    PyErr_SetString(PyExc_SystemError, "Internal Python error!");
-    return NULL;
-  }}
-
-  if (!PyLong_Check(method_ret)) {{
-    PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int");
-    Py_DECREF(method_ret);
+  if (Py_TYPE(obj) != (PyTypeObject*)py_tensor_map_type) {{
+    PyErr_Format(PyExc_TypeError, "object must be of type PyCUtensorMap, got %s", Py_TYPE(obj)->tp_name);
    return NULL;
  }}
 
-  uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret);
-  Py_DECREF(method_ret);
-  if (!ptr_as_uint) {{
-    PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()");
-    return NULL;
-  }}
-  if (ptr_as_uint % 64 != 0) {{
-    PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned");
+  CUtensorMap* map = &((PyCUtensorMapObject*)obj)->tensorMap;
+  uintptr_t align_128 = (uintptr_t)map & (128 - 1);
+  if (align_128 != 0) {{
+    PyErr_Format(PyExc_ValueError, "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %ld", align_128);
    return NULL;
  }}
-
-  return (CUtensorMap*)(ptr_as_uint);
+  return map;
 }}
 
 static void ensureCudaContext() {{
@@ -454,6 +495,32 @@ static void ensureCudaContext() {{
  }}
 }}
 
+static uint16_t pack_fp16(double f) {{
+  uint16_t result;
+  // from https://github.com/python/pythoncapi-compat
+#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
+  _PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#else
+  PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#endif
+  return result;
+}}
+
+static uint16_t pack_bf16(double f) {{
+  float f32 = (float)f;
+  uint32_t u32 = *(uint32_t*)&f32;
+  return (uint16_t)(u32 >> 16);
+}}
+
+static uint32_t pack_fp32(double f) {{
+  float f32 = (float)f;
+  return *(uint32_t*)&f32;
+}}
+
+static uint64_t pack_fp64(double f) {{
+  return *(uint64_t*)&f;
+}}
+
 static PyObject* launch(PyObject* self, PyObject* args) {{
  // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
  ensureCudaContext();
@@ -462,14 +529,16 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
  uint64_t _stream;
  uint64_t _function;
  int launch_cooperative_grid;
+  int launch_pdl;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  PyObject *global_scratch_obj = NULL;
+  PyObject *profile_scratch_obj = NULL;
  {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
-                                           &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
+                                           &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj, &profile_scratch_obj,
                                            &kernel_metadata, &launch_metadata,
                                            &launch_enter_hook, &launch_exit_hook{args_list})) {{
    return NULL;
@@ -483,11 +552,10 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
 
  // extract launch metadata
  if (launch_enter_hook != Py_None){{
-    PyObject* args = Py_BuildValue("(O)", launch_metadata);
-    PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
-    Py_DECREF(args);
+    PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
    if (!ret)
      return NULL;
+    Py_DECREF(ret);
  }}
 
  CUdeviceptr global_scratch = 0;
@@ -499,23 +567,31 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
    global_scratch = global_scratch_info.dev_ptr;
  }}
 
+  CUdeviceptr profile_scratch = 0;
+  if (profile_scratch_obj != Py_None) {{
+    DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
+    if (!profile_scratch_info.valid) {{
+      return NULL;
+    }}
+    profile_scratch = profile_scratch_info.dev_ptr;
+  }}
+
  // raise exception asap
  {newline.join(ptr_decls)}
  {newline.join(tma_decls)}
+  {newline.join(float_storage_decls)}
  Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {{
    return NULL;
  }}
 
  if(launch_exit_hook != Py_None){{
-    PyObject* args = Py_BuildValue("(O)", launch_metadata);
-    PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
-    Py_DECREF(args);
+    PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
    if (!ret)
      return NULL;
-
+    Py_DECREF(ret);
  }}
 
  Py_RETURN_NONE;
@@ -535,6 +611,19 @@ static struct PyModuleDef ModuleDef = {{
 }};
 
 PyMODINIT_FUNC PyInit___triton_launcher(void) {{
+  data_ptr_str = PyUnicode_InternFromString("data_ptr");
+  if(data_ptr_str == NULL) {{
+    return NULL;
+  }}
+  PyObject* driver_mod = PyImport_ImportModule("triton.backends.nvidia.driver");
+  if (driver_mod == NULL) {{
+    return NULL;
+  }}
+  py_tensor_map_type = PyObject_GetAttrString(driver_mod, "PyCUtensorMap");
+  if (py_tensor_map_type == NULL) {{
+    return NULL;
+  }}
+
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
@@ -546,6 +635,77 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
    return src
 
 
+# The TMA dtype enum values are slightly different on host vs device...
+TMA_DTYPE_DEVICE_TO_HOST = dict((i, i) for i in range(16))
+TMA_DTYPE_DEVICE_TO_HOST[8] = 10
+TMA_DTYPE_DEVICE_TO_HOST[9] = 8
+TMA_DTYPE_DEVICE_TO_HOST[10] = 9
+
+
+def make_tensordesc_arg(arg, metadata):
+    if metadata is None:
+        # Currently the host side tensor descriptors get decomposed in
+        # the frontend to tensor desc, shape, and strides. We have no
+        # way to use these shape and strides when processing tensor
+        # descriptors which is why we provide our own decomposition
+        # above. Sadly this means we have to pass the shape and strides
+        # twice.
+        return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]
+
+    swizzle = metadata["swizzle"]
+    elem_size = metadata["elem_size"]
+    elem_type = metadata["elem_type"]
+    block_size = metadata["block_size"]
+    fp4_padded = metadata["fp4_padded"]
+
+    shape = arg.shape
+    strides = arg.strides
+    assert strides[-1] == 1
+    padding = 1 if arg.padding == "nan" else 0
+
+    if fp4_padded:
+        shape = list(shape)
+        shape[-1] *= 2
+
+    cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor(
+        arg.base.data_ptr(),
+        swizzle,
+        elem_size,
+        TMA_DTYPE_DEVICE_TO_HOST[elem_type],
+        block_size,
+        shape,
+        strides,
+        padding,
+    )
+
+    return [cu_tensor_map, *shape, *strides]
+
+
+def wrap_handle_tensordesc(launcher, signature, tensordesc_meta):
+    has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
+    if not has_tensor_desc_arg:
+        return launcher
+
+    tensordesc_indices = set(
+        [i for i, sig in enumerate(signature.values()) if isinstance(sig, str) and sig.startswith("tensordesc")])
+    assert not tensordesc_meta or len(tensordesc_meta) == len(tensordesc_indices)
+    if not tensordesc_meta:
+        tensordesc_meta = [None] * len(tensordesc_indices)
+
+    def inner(*args):
+        final_args = list(args[:_BASE_ARGS_FORMAT_LEN])
+        tensordesc_idx = 0
+        for i, arg in enumerate(args[_BASE_ARGS_FORMAT_LEN:]):
+            if i in tensordesc_indices:
+                final_args.extend(make_tensordesc_arg(arg, tensordesc_meta[tensordesc_idx]))
+                tensordesc_idx += 1
+            else:
+                final_args.append(arg)
+        return launcher(*final_args)
+
+    return inner
+
+
 class CudaLauncher(object):
 
    def __init__(self, src, metadata):
@@ -553,21 +713,40 @@ class CudaLauncher(object):
        arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
        constants = {arg_idx(idx): value for idx, value in constants.items()}
        signature = {idx: value for idx, value in src.signature.items()}
-        src = make_launcher(constants, signature)
-        mod = compile_module_from_src(src, "__triton_launcher")
-        self.launch = mod.launch
+        tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
+        src = make_launcher(constants, signature, tensordesc_meta)
+        mod = compile_module_from_src(
+            src=src,
+            name="__triton_launcher",
+            library_dirs=library_dirs(),
+            include_dirs=include_dirs,
+            libraries=libraries,
+        )
+
+        self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
+        self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
        self.global_scratch_size = metadata.global_scratch_size
        self.global_scratch_align = metadata.global_scratch_align
+        self.profile_scratch_size = metadata.profile_scratch_size
+        self.profile_scratch_align = metadata.profile_scratch_align
        self.launch_cooperative_grid = metadata.launch_cooperative_grid
+        self.launch_pdl = metadata.launch_pdl
 
    def __call__(self, gridX, gridY, gridZ, stream, function, *args):
-        if self.global_scratch_size > 0:
-            grid_size = gridX * gridY * gridZ
-            alloc_size = grid_size * self.global_scratch_size
-            global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
-        else:
-            global_scratch = None
-        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
+
+        def allocate_scratch(size, align, allocator):
+            if size > 0:
+                grid_size = gridX * gridY * gridZ
+                alloc_size = grid_size * self.num_ctas * size
+                alloc_fn = allocator.get()
+                return alloc_fn(alloc_size, align, stream)
+            return None
+
+        global_scratch = allocate_scratch(self.global_scratch_size, self.global_scratch_align, _allocation._allocator)
+        profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
+                                           _allocation._profile_allocator)
+        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
+                    global_scratch, profile_scratch, *args)
 
 
 class CudaDriver(GPUDriver):
@@ -600,6 +779,9 @@ class CudaDriver(GPUDriver):
        except ImportError:
            return False
 
+    def map_python_to_cpp_type(self, ty: str) -> str:
+        return ty_to_cpp(ty)
+
    def get_benchmarker(self):
        from triton.testing import do_bench
        return do_bench