triton-windows 3.1.0.post17__cp311-cp311-win_amd64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +73 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +262 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +497 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +76 -0
- triton/backends/driver.py +34 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +347 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +430 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1302 -0
- triton/compiler/compiler.py +416 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +284 -0
- triton/language/core.py +2621 -0
- triton/language/extra/__init__.py +4 -0
- triton/language/extra/cuda/__init__.py +8 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +468 -0
- triton/language/extra/libdevice.py +1213 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1621 -0
- triton/language/standard.py +441 -0
- triton/ops/__init__.py +7 -0
- triton/ops/blocksparse/__init__.py +7 -0
- triton/ops/blocksparse/matmul.py +432 -0
- triton/ops/blocksparse/softmax.py +228 -0
- triton/ops/cross_entropy.py +96 -0
- triton/ops/flash_attention.py +466 -0
- triton/ops/matmul.py +219 -0
- triton/ops/matmul_perf_model.py +171 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +361 -0
- triton/runtime/build.py +129 -0
- triton/runtime/cache.py +289 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1127 -0
- triton/runtime/jit.py +956 -0
- triton/runtime/tcc/include/_mingw.h +170 -0
- triton/runtime/tcc/include/assert.h +57 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +57 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/limits.h +111 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +737 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdarg.h +79 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +54 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +580 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +118 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/winbase.h +2951 -0
- triton/runtime/tcc/include/winapi/wincon.h +301 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnt.h +5835 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +496 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +145 -0
- triton/tools/disasm.py +142 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +373 -0
- triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
- triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
- triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
- triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/backends/nvidia/driver.py
@@ -0,0 +1,430 @@
+import functools
+import os
+import hashlib
+import subprocess
+import tempfile
+from pathlib import Path
+from triton.runtime.build import _build
+from triton.runtime.cache import get_cache_manager
+from triton.backends.compiler import GPUTarget
+from triton.backends.driver import GPUDriver
+
+dirname = os.path.dirname(os.path.realpath(__file__))
+include_dir = [os.path.join(dirname, "include")]
+if os.name == "nt":
+    from triton.windows_utils import find_cuda
+    _, cuda_inc_dirs, _ = find_cuda()
+    include_dir += cuda_inc_dirs
+libdevice_dir = os.path.join(dirname, "lib")
+libraries = ['cuda']
+
+
+@functools.lru_cache()
+def libcuda_dirs():
+    env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
+    if env_libcuda_path:
+        return [env_libcuda_path]
+
+    if os.name == "nt":
+        _, _, cuda_lib_dirs = find_cuda()
+        return cuda_lib_dirs
+
+    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
+    # each line looks like the following:
+    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
+    locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
+    dirs = [os.path.dirname(loc) for loc in locs]
+    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
+    if env_ld_library_path and not dirs:
+        dirs = [dir for dir in env_ld_library_path.split(":") if os.path.exists(os.path.join(dir, "libcuda.so.1"))]
+    msg = 'libcuda.so cannot found!\n'
+    if locs:
+        msg += 'Possible files are located at %s.' % str(locs)
+        msg += 'Please create a symlink of libcuda.so to any of the files.'
+    else:
+        msg += 'Please make sure GPU is set up and then run "/sbin/ldconfig"'
+        msg += ' (requires sudo) to refresh the linker cache.'
+    assert any(os.path.exists(os.path.join(path, 'libcuda.so.1')) for path in dirs), msg
+    return dirs
+
+
+@functools.lru_cache()
+def library_dirs():
+    return [libdevice_dir, *libcuda_dirs()]
+
+
+def compile_module_from_src(src, name):
+    key = hashlib.sha256(src.encode("utf-8")).hexdigest()
+    cache = get_cache_manager(key)
+    if os.name == "nt":
+        so_name = f"{name}.pyd"
+    else:
+        so_name = f"{name}.so"
+    cache_path = cache.get_file(so_name)
+    if cache_path is None:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            src_path = os.path.join(tmpdir, f"{name}.c")
+            with open(src_path, "w") as f:
+                f.write(src)
+            so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
+            with open(so, "rb") as f:
+                cache_path = cache.put(f.read(), so_name, binary=True)
+    import importlib.util
+    spec = importlib.util.spec_from_file_location(name, cache_path)
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+# ------------------------
+# Utils
+# ------------------------
+
+
+class CudaUtils(object):
+
+    def __new__(cls):
+        if not hasattr(cls, "instance"):
+            cls.instance = super(CudaUtils, cls).__new__(cls)
+        return cls.instance
+
+    def __init__(self):
+        mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
+        self.load_binary = mod.load_binary
+        self.get_device_properties = mod.get_device_properties
+        self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
+        self.set_printf_fifo_size = mod.set_printf_fifo_size
+        self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor
+        self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor
+
+
+# ------------------------
+# Launcher
+# ------------------------
+
+
+def ty_to_cpp(ty):
+    if ty[0] == '*':
+        return "CUdeviceptr"
+    return {
+        "i1": "int32_t",
+        "i8": "int8_t",
+        "i16": "int16_t",
+        "i32": "int32_t",
+        "i64": "int64_t",
+        "u1": "uint32_t",
+        "u8": "uint8_t",
+        "u16": "uint16_t",
+        "u32": "uint32_t",
+        "u64": "uint64_t",
+        "fp16": "float",
+        "bf16": "float",
+        "fp32": "float",
+        "f32": "float",
+        "fp64": "double",
+    }[ty]
+
+
+def make_launcher(constants, signature, ids):
+    # Record the end of regular arguments;
+    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
+    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
+
+    def _extracted_type(ty):
+        if ty[0] == '*':
+            return "PyObject*"
+        return ty_to_cpp(ty)
+
+    def format_of(ty):
+        return {
+            "PyObject*": "O",
+            "float": "f",
+            "double": "d",
+            "long": "l",
+            "int8_t": "b",
+            "int16_t": "h",
+            "int32_t": "i",
+            "int64_t": "L",
+            "uint8_t": "B",
+            "uint16_t": "H",
+            "uint32_t": "I",
+            "uint64_t": "K",
+        }[ty]
+
+    args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
+    format = "iiiKKOOOO" + args_format
+    args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
+
+    # generate glue code
+    params = [i for i in signature.keys() if i not in constants]
+    if params:
+        params_decl = ", ".join(f"&arg{i}" for i in params)
+        params_decl = f"void *params[] = {{ {params_decl} }};"
+    else:
+        params_decl = "void **params = NULL;"
+    src = f"""
+#include \"cuda.h\"
+#include <stdbool.h>
+#define PY_SSIZE_T_CLEAN
+#define Py_LIMITED_API 0x03090000
+#include <Python.h>
+
+#ifndef _WIN32
+#include <dlfcn.h>
+#else
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#endif
+
+static inline void gpuAssert(CUresult code, const char *file, int line)
+{{
+   if (code != CUDA_SUCCESS)
+   {{
+      const char* prefix = "Triton Error [CUDA]: ";
+      const char* str;
+      cuGetErrorString(code, &str);
+      char err[1024] = {{0}};
+      strcat(err, prefix);
+      strcat(err, str);
+      PyGILState_STATE gil_state;
+      gil_state = PyGILState_Ensure();
+      PyErr_SetString(PyExc_RuntimeError, err);
+      PyGILState_Release(gil_state);
+   }}
+}}
+
+#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
+
+typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);
+
+#ifndef _WIN32
+static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
+  // Open the shared library
+  void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
+  if (!handle) {{
+    PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1");
+    return NULL;
+  }}
+  // Clear any existing error
+  dlerror();
+  cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
+  // Check for errors
+  const char *dlsym_error = dlerror();
+  if (dlsym_error) {{
+    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
+    return NULL;
+  }}
+  return cuLaunchKernelExHandle;
+}}
+#else
+static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
+  // Open the shared library
+  HMODULE handle = LoadLibraryA("nvcuda.dll");
+  if (!handle) {{
+    PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");
+    return NULL;
+  }}
+  cuLaunchKernelEx_t cuLaunchKernelExHandle =
+    (cuLaunchKernelEx_t)GetProcAddress((HMODULE)handle, "cuLaunchKernelEx");
+  // Check for errors
+  long error = GetLastError();
+  if (error) {{
+    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from nvcuda.dll");
+    return NULL;
+  }}
+  return cuLaunchKernelExHandle;
+}}
+#endif
+
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+  {params_decl}
+  if (gridX*gridY*gridZ > 0) {{
+    if (num_ctas == 1) {{
+      CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
+    }} else {{
+      CUlaunchAttribute launchAttr[2];
+      launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+      launchAttr[0].value.clusterDim.x = clusterDimX;
+      launchAttr[0].value.clusterDim.y = clusterDimY;
+      launchAttr[0].value.clusterDim.z = clusterDimZ;
+      launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+      launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+      CUlaunchConfig config;
+      config.gridDimX = gridX * clusterDimX;
+      config.gridDimY = gridY * clusterDimY;
+      config.gridDimZ = gridZ * clusterDimZ;
+      config.blockDimX = 32 * num_warps;
+      config.blockDimY = 1;
+      config.blockDimZ = 1;
+      config.sharedMemBytes = shared_memory;
+      config.hStream = stream;
+      config.attrs = launchAttr;
+      config.numAttrs = 2;
+      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
+      if (cuLaunchKernelExHandle == NULL) {{
+        cuLaunchKernelExHandle = getLaunchKernelExHandle();
+      }}
+      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
+    }}
+  }}
+}}
+
+typedef struct _DevicePtrInfo {{
+  CUdeviceptr dev_ptr;
+  bool valid;
+}} DevicePtrInfo;
+
+static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
+  DevicePtrInfo ptr_info;
+  ptr_info.dev_ptr = 0;
+  ptr_info.valid = true;
+  if (PyLong_Check(obj)) {{
+    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
+    return ptr_info;
+  }}
+  if (obj == Py_None) {{
+    // valid nullptr
+    return ptr_info;
+  }}
+  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
+  if(ptr){{
+    PyObject *empty_tuple = PyTuple_New(0);
+    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
+    Py_DECREF(empty_tuple);
+    Py_DECREF(ptr);
+    if (!PyLong_Check(ret)) {{
+      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
+      ptr_info.valid = false;
+      return ptr_info;
+    }}
+    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
+    if(!ptr_info.dev_ptr)
+      return ptr_info;
+    uint64_t dev_ptr;
+    int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
+    if (status == CUDA_ERROR_INVALID_VALUE) {{
+      PyErr_Format(PyExc_ValueError,
+                   "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
+      ptr_info.valid = false;
+    }}
+    ptr_info.dev_ptr = dev_ptr;
+    Py_DECREF(ret);  // Thanks ChatGPT!
+    return ptr_info;
+  }}
+  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
+  ptr_info.valid = false;
+  return ptr_info;
+}}
+
+static PyObject* launch(PyObject* self, PyObject* args) {{
+  int gridX, gridY, gridZ;
+  uint64_t _stream;
+  uint64_t _function;
+  PyObject *launch_enter_hook = NULL;
+  PyObject *launch_exit_hook = NULL;
+  PyObject *kernel_metadata = NULL;
+  PyObject *launch_metadata = NULL;
+  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
+  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function,
+                                           &kernel_metadata, &launch_metadata,
+                                           &launch_enter_hook, &launch_exit_hook {args_list})) {{
+    return NULL;
+  }}
+
+  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
+  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
+    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
+    return NULL;
+  }}
+
+  // extract launch metadata
+  if (launch_enter_hook != Py_None){{
+    PyObject* args = Py_BuildValue("(O)", launch_metadata);
+    PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
+    Py_DECREF(args);
+    if (!ret)
+      return NULL;
+  }}
+
+  // raise exception asap
+  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
+  Py_BEGIN_ALLOW_THREADS;
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
+  Py_END_ALLOW_THREADS;
+  if (PyErr_Occurred()) {{
+    return NULL;
+  }}
+
+  if(launch_exit_hook != Py_None){{
+    PyObject* args = Py_BuildValue("(O)", launch_metadata);
+    PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
+    Py_DECREF(args);
+    if (!ret)
+      return NULL;
+
+  }}
+
+  // return None
+  Py_INCREF(Py_None);
+  return Py_None;
+}}
+
+static PyMethodDef ModuleMethods[] = {{
+  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
+  {{NULL, NULL, 0, NULL}} // sentinel
+}};
+
+static struct PyModuleDef ModuleDef = {{
+  PyModuleDef_HEAD_INIT,
+  \"__triton_launcher\",
+  NULL, //documentation
+  -1, //size
+  ModuleMethods
+}};
+
+PyMODINIT_FUNC PyInit___triton_launcher(void) {{
+  PyObject *m = PyModule_Create(&ModuleDef);
+  if(m == NULL) {{
+    return NULL;
+  }}
+  PyModule_AddFunctions(m, ModuleMethods);
+  return m;
+}}
+"""
+    return src
+
+
+class CudaLauncher(object):
+
+    def __init__(self, src, metadata):
+        ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
+        constants = src.constants if hasattr(src, "constants") else dict()
+        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
+        constants = {cst_key(key): value for key, value in constants.items()}
+        signature = {cst_key(key): value for key, value in src.signature.items()}
+        src = make_launcher(constants, signature, ids)
+        mod = compile_module_from_src(src, "__triton_launcher")
+        self.launch = mod.launch
+
+    def __call__(self, *args, **kwargs):
+        self.launch(*args, **kwargs)
+
+
+class CudaDriver(GPUDriver):
+
+    def __init__(self):
+        self.utils = CudaUtils()  # TODO: make static
+        self.launcher_cls = CudaLauncher
+        super().__init__()
+
+    def get_current_target(self):
+        device = self.get_current_device()
+        capability = self.get_device_capability(device)
+        capability = capability[0] * 10 + capability[1]
+        warp_size = 32
+        return GPUTarget("cuda", capability, warp_size)
+
+    @staticmethod
+    def is_active():
+        import torch
+        return torch.cuda.is_available() and (torch.version.hip is None)