triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
triton/tools/disasm.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2020 Da Yan @ HKUST
|
|
4
|
+
|
|
5
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
# in the Software without restriction, including without limitation the rights
|
|
8
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
# furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
# copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
# SOFTWARE.
|
|
22
|
+
|
|
23
|
+
import functools
|
|
24
|
+
import os
|
|
25
|
+
import re
|
|
26
|
+
import subprocess
|
|
27
|
+
import tempfile
|
|
28
|
+
|
|
29
|
+
# Patterns for the text printed by ``cuobjdump -sass``, whose per-function
# layout is:
#     Function : <function_name>
#     .headerflags: ...
#     /*0000*/ asmstr /*0x...*/
#     /*0x...*/
#
# Instruction line: group(1) is the assembly text up to and including ';',
# group(2) the 16 hex digits of the first encoding word. The offset comment
# is accepted with four or more hex digits so that offsets wider than four
# digits (presumably printed for code past 0xffff bytes -- confirm) do not
# silently end the instruction scan.
FLINE_RE = re.compile(r'\s*/\*\w{4,}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
# Continuation line holding only the second encoding word (group(1)).
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')
# Function header; group(1) is the function name.
FNAME_RE = re.compile(r'\s*Function : (\w+)\s*')
# BRA / BRA.U with a hexadecimal target: group(1) is everything before the
# target, group(2) the target address such as "0x1a0".
BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);')
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def parseCtrl(sline):
    """Render the control fields of an instruction's second encoding word.

    *sline* must match ``SLINE_RE``. The result is five colon-separated
    fields, in this order:

    * watdb -- bits 52-57, two-digit decimal, or ``--`` when 0
    * readb -- bits 49-51, or ``-`` when 7
    * wrtdb -- bits 46-48, or ``-`` when 7
    * yld -- bit 45, ``Y`` when the bit is 0, ``-`` otherwise
    * stall -- bits 41-44, in hex
    """
    word = int(SLINE_RE.match(sline).group(1), 16)

    def bits(shift, mask):
        return (word >> shift) & mask

    stall = bits(41, 0xf)
    yield_bit = bits(45, 0x1)
    write_db = bits(46, 0x7)
    read_db = bits(49, 0x7)
    wait_db = bits(52, 0x3f)

    fields = [
        f'{wait_db:02d}' if wait_db else '--',
        str(read_db) if read_db != 7 else '-',
        str(write_db) if write_db != 7 else '-',
        '-' if yield_bit else 'Y',
        f'{stall:x}',
    ]
    return ':'.join(fields)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def processSassLines(fline, sline, labels):
    """Convert one two-line SASS instruction into a ``(ctrl, asm)`` pair.

    *fline* is the instruction line (matched by ``FLINE_RE``) and *sline* the
    encoding line that follows it (decoded by ``parseCtrl``). When the
    instruction is a branch, its target address is added to *labels*
    (address -> label index, mutated in place) unless already present; new
    targets receive the next free index.
    """
    asm_text = FLINE_RE.match(fline).group(1)
    # Drop the space that may precede the terminating ';'.
    if asm_text.endswith(" ;"):
        asm_text = asm_text[:-2] + ";"
    ctrl_text = parseCtrl(sline)
    branch = BRA_RE.match(asm_text)
    if branch is not None:
        labels.setdefault(int(branch.group(2), 16), len(labels))
    return (ctrl_text, asm_text)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@functools.lru_cache()
def get_sass(cubin_asm, fun=None):
    """Disassemble the cubin bytes *cubin_asm* into formatted SASS text.

    The bytes are written to a temporary file, which is closed before
    ``extract`` runs on it. The file is deleted afterwards, even when
    ``extract`` raises. *fun* optionally selects a single function. Results
    are memoized by ``lru_cache``, so both arguments must be hashable.
    """
    handle, tmp_path = tempfile.mkstemp()
    try:
        with os.fdopen(handle, 'wb') as out:
            out.write(cubin_asm)
        return extract(tmp_path, fun)
    finally:
        os.remove(tmp_path)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def path_to_cuobjdump():
    """Return the ``cuobjdump`` executable path configured in ``triton.knobs``."""
    from triton import knobs
    tool = knobs.nvidia.cuobjdump
    return tool.path
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def extract(file_path, fun):
    """Run ``cuobjdump -sass`` on *file_path* and reformat its listing.

    Args:
        file_path: path of the object file handed to ``cuobjdump``.
        fun: name of one function to disassemble, or ``None`` for all.

    Returns:
        One ``Function:<name>`` section per function found, in output order.
        Each instruction is printed as its control string, a tab, then the
        assembly. Absolute branch targets are replaced by ``LBB<i>`` labels.
        The result is ``''`` when the output contains no function header.

    Raises:
        subprocess.CalledProcessError: if ``cuobjdump`` exits non-zero.
    """
    cuobjdump = path_to_cuobjdump()
    if fun is None:
        cmd = [cuobjdump, "-sass", file_path]
    else:
        cmd = [cuobjdump, "-fun", fun, "-sass", file_path]
    sass_lines = [raw.decode() for raw in subprocess.check_output(cmd).splitlines()]
    num_lines = len(sass_lines)

    # format:
    #   function : <function_name>
    #   .headerflags: ...
    #   /*0000*/ asmstr /*0x...*/
    #   /*0x...*/
    # Output accumulates across all functions. Trailing non-function lines
    # are skipped rather than turning the result into None.
    ret = ''
    line_idx = 0
    while line_idx < num_lines:
        header = FNAME_RE.match(sass_lines[line_idx])
        if header is None:
            line_idx += 1
            continue

        ret += f'Function:{header.group(1)}\n'
        line_idx += 2  # bypass .headerflags
        labels = {}  # address -> label_idx
        # Buffer (ctrl, asm) pairs so every branch target is known before
        # any instruction is printed.
        asm_buffer = []
        # An instruction is a pair of lines; never read past the end.
        while line_idx + 1 < num_lines and FLINE_RE.match(sass_lines[line_idx]) is not None:
            asm_buffer.append(processSassLines(sass_lines[line_idx], sass_lines[line_idx + 1], labels))
            line_idx += 2

        # label naming convention: LBB#i
        for idx, (ctrl, asm) in enumerate(asm_buffer):
            # Instructions are taken to be 16 bytes apart.
            offset = idx * 16
            if offset in labels:
                ret += f'LBB{labels[offset]}:\n'
            # if this is BRA, remap offset to label
            branch = BRA_RE.match(asm)
            if branch is not None:
                target_name = f'LBB{labels[int(branch.group(2), 16)]}'
                asm = BRA_RE.sub(rf'\1{target_name};', asm)
            ret += ctrl + '\t' + asm + '\n'
        ret += '\n'
    return ret
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
/* clang-format off */
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <inttypes.h>
#include <string.h>
#include <cuda.h>


// helpers to check for cuda errors; the do/while(0) wrapper makes
// CUDA_CHECK(x); a single statement, safe inside unbraced if/else.
#define CUDA_CHECK(ans) do {{ gpuAssert((ans), __FILE__, __LINE__); }} while (0)

// Print a readable message for a failed driver call and exit with its code.
static inline void gpuAssert(CUresult code, const char *file, int line) {{
  if (code != CUDA_SUCCESS) {{
    const char *prefix = "Triton Error [CUDA]: ";
    const char *str = NULL;
    // cuGetErrorString provides no message for unrecognized codes.
    if (cuGetErrorString(code, &str) != CUDA_SUCCESS || str == NULL)
      str = "unknown error";
    char err[1024] = {{0}};
    snprintf(err, sizeof(err), "%s%s", prefix, str);
    // Single backslash: this template is formatted, not parsed as a Python
    // string literal, so a doubled one would print a literal backslash-n.
    printf("%s\n", err);
    exit(code);
  }}
}}

// globals
#define CUBIN_NAME {kernel_name}_cubin
CUmodule {kernel_name}_mod = NULL;
CUfunction {kernel_name}_func = NULL;
unsigned char CUBIN_NAME[{bin_size}] = {{ {bin_data} }};


void unload_{kernel_name}(void) {{
  CUDA_CHECK(cuModuleUnload({kernel_name}_mod));
}}

// TODO: some code duplication with `runtime/backend/cuda.c`
void load_{kernel_name}(void) {{
  int dev = 0;
  void *bin = (void *)&CUBIN_NAME;
  int shared = {shared};
  CUDA_CHECK(cuModuleLoadData(&{kernel_name}_mod, bin));
  CUDA_CHECK(cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
  // set dynamic shared memory if necessary
  int shared_optin;
  CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev));
  if (shared > 49152 && shared_optin > 49152) {{
    CUDA_CHECK(cuFuncSetCacheConfig({kernel_name}_func, CU_FUNC_CACHE_PREFER_SHARED));
    CUDA_CHECK(cuFuncSetAttribute({kernel_name}_func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin));
  }}
}}

/*
{kernel_docstring}
*/
CUresult {kernel_name}(CUstream stream, {signature}) {{
  if ({kernel_name}_func == NULL)
    load_{kernel_name}();
  unsigned int gX = {gridX};
  unsigned int gY = {gridY};
  unsigned int gZ = {gridZ};
  CUdeviceptr global_scratch = 0;
  CUdeviceptr profile_scratch = 0;
  void *args[{num_args}] = {{ {arg_pointers} }};
  // TODO: shared memory
  // Test each dimension: the unsigned product gX * gY * gZ can wrap to 0
  // for large grids and would silently skip the launch.
  if (gX > 0 && gY > 0 && gZ > 0)
    return cuLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * 32, 1, 1, {shared}, stream, args, NULL);
  // Empty grid: nothing to launch.
  return CUDA_SUCCESS;
}}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
/* Declarations for one ahead-of-time compiled Triton kernel. The
   placeholder fields of this template are substituted by the generator. */
#ifndef TT_KERNEL_INCLUDES
#define TT_KERNEL_INCLUDES

#include <cuda.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

#endif

// Only the system includes are guarded; the declarations below are
// kernel-specific (presumably so that several generated headers can be
// included together -- confirm).
void unload_{kernel_name}(void);
void load_{kernel_name}(void);
// The next line is a directive parsed by the AOT linker; keep its format.
// tt-linker: {kernel_name}:{full_signature}:{algo_info}
CUresult{_placeholder} {kernel_name}(CUstream stream, {signature});
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
// SPDX-License-Identifier: MIT
|
|
2
|
+
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
3
|
+
|
|
4
|
+
/* clang-format off */
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <inttypes.h>
#include <string.h>
#include <hip/hip_runtime.h>

// helpers to check for hip errors; the do/while(0) wrapper makes
// HIP_CHECK(x); a single statement, safe inside unbraced if/else.
#define HIP_CHECK(ans) do {{ gpuAssert((ans), __FILE__, __LINE__); }} while (0)

// Print a readable message for a failed HIP call and exit with its code.
static inline void gpuAssert(hipError_t code, const char *file, int line) {{
  if (code != hipSuccess) {{
    const char *prefix = "Triton Error [HIP]: ";
    const char *str = nullptr;
    // Guard against no message being available for this code.
    if (hipDrvGetErrorString(code, &str) != hipSuccess || str == nullptr)
      str = "unknown error";
    char err[1024] = {{0}};
    snprintf(err, sizeof(err), "%s%s", prefix, str);
    // Single backslash: this template is formatted, not parsed as a Python
    // string literal, so a doubled one would print a literal backslash-n.
    printf("%s\n", err);
    exit(code);
  }}
}}

// globals
#define HSACO_NAME {kernel_name}_hsaco
hipModule_t {kernel_name}_mod = nullptr;
hipFunction_t {kernel_name}_func = nullptr;
unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};


void unload_{kernel_name}(void) {{
  HIP_CHECK(hipModuleUnload({kernel_name}_mod));
}}


void load_{kernel_name}(void) {{
  void *bin = (void *)&HSACO_NAME;
  HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin));
  HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
}}

/*
{kernel_docstring}
*/
hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
  if ({kernel_name}_func == nullptr)
    load_{kernel_name}();
  unsigned int gX = {gridX};
  unsigned int gY = {gridY};
  unsigned int gZ = {gridZ};
  hipDeviceptr_t global_scratch = 0;
  hipDeviceptr_t profile_scratch = 0;
  void *args[{num_args}] = {{ {arg_pointers} }};
  // TODO: shared memory
  // Test each dimension: the unsigned product gX * gY * gZ can wrap to 0
  // for large grids.
  if (gX > 0 && gY > 0 && gZ > 0)
    return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr);
  else
    return hipErrorInvalidValue;
}}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
// SPDX-License-Identifier: MIT
|
|
2
|
+
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
3
|
+
|
|
4
|
+
#pragma once

#include <hip/hip_runtime.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

// Entry points generated for one ahead-of-time compiled kernel:
// load_/unload_ manage the module created from the embedded code object,
// and the launcher enqueues the kernel on the given stream. The
// placeholder fields of this template are substituted by the generator.
void unload_{kernel_name}(void);
void load_{kernel_name}(void);
hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature});
|
triton/tools/link.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Sequence, Union
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _exists(x):
|
|
9
|
+
return x is not None
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LinkerError(Exception):
    """Raised when kernel headers cannot be parsed or linked together."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class KernelLinkerMeta:
    """Metadata for one compiled kernel variant, read from a linker directive."""

    # Kernel name with the signature hash and suffix removed.
    orig_kernel_name: str
    # C argument names and C types, in declaration order.
    arg_names: Sequence[str]
    arg_ctypes: Sequence[str]
    # Per-argument value hint: 1 ("c" suffix), 16 ("d" suffix) or None.
    sizes: Sequence[Union[int, None]]
    # Hash part of the compiled kernel's name.
    sig_hash: str
    # Specialization suffix; HeaderParser stores the same value in both.
    triton_suffix: str
    suffix: str
    # Number of specialized arguments.
    num_specs: int
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class HeaderParser:
    """Collect linker directives from headers generated by compile.py.

    Parsed kernels are grouped in ``self.kernels`` under
    ``"<name>_<algo_info>"``. All variants in one group must have identical
    C argument type lists.
    """

    def __init__(self) -> None:
        import re

        # [kernel_name, c signature]
        self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)")
        # [name, hash, suffix]
        self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
        # [(type, name)]
        self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?")
        # [d|c] -- exactly the markers that have an entry in the size table
        # (a "[c,d]" class would also accept ',').
        self.arg_suffix = re.compile("[cd]")

        self.kernels = defaultdict(list)

    def extract_linker_meta(self, header: str):
        """Parse every linker directive line in *header* and record it."""
        for ln in header.splitlines():
            if ln.startswith("//"):
                m = self.linker_directives.match(ln)
                if _exists(m):
                    ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3)
                    name, sig_hash, suffix = self._match_name(ker_name)
                    c_types, arg_names = self._match_c_sig(c_sig)
                    num_specs, sizes = self._match_suffix(suffix, c_sig)
                    self._add_kernel(
                        "_".join([name, algo_info]),
                        KernelLinkerMeta(
                            orig_kernel_name=name,
                            arg_names=arg_names,
                            arg_ctypes=c_types,
                            sizes=sizes,
                            sig_hash=sig_hash,
                            triton_suffix=suffix,
                            suffix=suffix,
                            num_specs=num_specs,
                        ),
                    )

    def _match_name(self, ker_name: str):
        """Split ``<name>_<hash>_<suffix>``; raise LinkerError otherwise."""
        m = self.kernel_name.match(ker_name)
        if _exists(m):
            name, sig_hash, suffix = m.group(1), m.group(2), m.group(3)
            return name, sig_hash, suffix
        raise LinkerError(f"{ker_name} is not a valid kernel name")

    def _match_c_sig(self, c_sig: str):
        """Return ``(types, names)`` lists parsed from a C argument list."""
        m = self.c_sig.findall(c_sig)
        if not m:
            raise LinkerError(f"{c_sig} is not a valid argument signature")
        tys = [ty for ty, _ in m]
        args = [arg_name for _, arg_name in m]
        return tys, args

    def _match_suffix(self, suffix: str, c_sig: str):
        """Decode a suffix such as ``"0d1c2"`` into ``(num_specs, sizes)``.

        Each argument index may be followed by ``c`` (value 1) or ``d``
        (divisible by 16). ``sizes`` has one entry per argument of *c_sig*,
        with ``None`` for unspecialized arguments.
        """
        full_suffix = suffix
        args = c_sig.split(",")
        s2i = {"c": 1, "d": 16}
        num_specs = 0
        sizes = []
        # scan through suffix, first find the index,
        # then see if it is followed by d or c
        for i in range(len(args)):
            pos = suffix.find(str(i))
            if pos == -1:
                raise LinkerError(f"{full_suffix} is not a valid kernel suffix")
            pos += len(str(i))
            if self.arg_suffix.match(suffix, pos):
                num_specs += 1
                sizes.extend([None] * (i - len(sizes)))
                sizes.append(s2i[suffix[pos]])
                pos += 1
            if i < len(args) - 1:
                suffix = suffix[pos:]
            else:
                sizes.extend([None] * (len(args) - len(sizes)))
        return num_specs, sizes

    def _add_kernel(self, name: str, ker: "KernelLinkerMeta"):
        """Append *ker* to group *name*, checking its C signature matches."""
        if name in self.kernels:
            last: "KernelLinkerMeta" = self.kernels[name][-1]
            # Compare whole lists: zip() would ignore a differing arg count.
            if list(last.arg_ctypes) != list(ker.arg_ctypes):
                raise LinkerError(
                    f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
                )

        self.kernels[name].append(ker)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def gen_signature_with_full_args(m):
    """Render every argument of *m* as ``"<ctype> <name>"``, comma-separated."""
    pieces = []
    for ctype, arg_name in zip(m.arg_ctypes, m.arg_names):
        pieces.append(f"{ctype} {arg_name}")
    return ", ".join(pieces)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def gen_signature(m):
    """Render the C parameter list of a specialized variant.

    Like the full-argument signature, but arguments whose hint is 1 are left
    out; the hints dispatcher checks those against ``1`` and does not pass
    them.
    """
    kept = [f"{ty} {arg}" for ty, arg, hint in zip(m.arg_ctypes, m.arg_names, m.sizes) if hint != 1]
    return ", ".join(kept)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# generate declarations of kernels with meta-parameter and constant values
def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
    """Declare the launcher and load/unload functions of one algorithm group.

    The launcher takes the full argument list of the last variant in
    *metas*; variants of a group are checked for matching argument types
    when they are added. Empty parameter lists are spelled ``(void)`` so
    these are C prototypes, as in the per-kernel generated headers.
    """
    return f"""
CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])});
void load_{name}(void);
void unload_{name}(void);
    """
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# generate declarations of the public entry points of the original kernel
def make_global_decl(meta: KernelLinkerMeta) -> str:
    """Declare ``<kernel>_default``, the ``algo_id`` dispatcher and load/unload.

    Empty parameter lists are spelled ``(void)`` so these are C prototypes,
    consistent with the ``(void)`` definitions written by
    ``make_kernel_load_def``.
    """
    return f"""
CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)});
CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id);
void load_{meta.orig_kernel_name}(void);
void unload_{meta.orig_kernel_name}(void);
    """
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# generate the "<kernel>_default" wrapper, which always selects algorithm 0
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
    """Emit ``<kernel>_default``, forwarding every argument with ``algo_id = 0``."""
    kernel = meta.orig_kernel_name
    header = f"CUresult {kernel}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
    body = f" return {kernel}(stream, {', '.join(meta.arg_names)}, 0);\n"
    return header + body + "}\n"
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# generate dispatcher function for kernels with different integer value hints
def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
    """Emit the hint-based launcher ``name`` plus its load/unload functions.

    Variants are tried from most to least specialized. The first one whose
    hint conditions hold is called: hint 16 means the value is divisible by
    16, and hint 1 means it equals 1. If none matches, the launcher returns
    ``CUDA_ERROR_INVALID_VALUE``.

    Raises:
        LinkerError: if a variant carries a hint other than 1 or 16.
    """
    # Most specialized first, so the first satisfied condition wins.
    ordered = sorted(metas, key=lambda m: -m.num_specs)

    def variant_name(meta):
        return f"{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}"

    def cond_fn(val, hint):
        if hint == 16:
            return f"({val} % {hint} == 0)"
        if hint == 1:
            return f"({val} == {hint})"
        raise LinkerError(f"unsupported value hint {hint} for argument {val} of kernel {name}")

    src = f"// launcher for: {name}\n"
    for meta in ordered:
        src += f"CUresult {variant_name(meta)}(CUstream stream, {gen_signature(meta)});\n"
    src += "\n"

    src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{"
    src += "\n"
    for meta in ordered:
        conds = " && ".join(cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None)
        # Edge case where no specializations hence no dispatching required
        src += f" if ({conds})\n" if any(meta.sizes) else " if (1)\n"
        arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
        src += f" return {variant_name(meta)}(stream, {', '.join(arg_names)});\n"
    src += "\n"
    src += " return CUDA_ERROR_INVALID_VALUE;\n"
    src += "}\n"

    for mode in ["load", "unload"]:
        src += f"\n// {mode} for: {name}\n"
        for meta in ordered:
            src += f"void {mode}_{variant_name(meta)}(void);\n"
        src += f"void {mode}_{name}(void) {{"
        src += "\n"
        for meta in ordered:
            src += f" {mode}_{variant_name(meta)}();\n"
        src += "}\n"
    return src
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# generate the "<kernel>(..., int algo_id)" entry point, which indexes the
# "<kernel>_kernels" table of per-algorithm dispatchers
def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str:
    """Emit the ``algo_id``-based dispatcher for ``meta.orig_kernel_name``.

    ``algo_id`` is checked against the element count of ``<kernel>_kernels``,
    not its ``sizeof`` in bytes, and negative ids are rejected. Out-of-range
    ids return ``CUDA_ERROR_INVALID_VALUE`` instead of calling through an
    invalid function pointer. The check does not depend on ``NDEBUG``.
    """
    kernel = meta.orig_kernel_name
    table = f"{kernel}_kernels"
    src = f"CUresult {kernel}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n"
    src += f" if (algo_id < 0 || algo_id >= (int)(sizeof({table}) / sizeof({table}[0])))\n"
    src += "  return CUDA_ERROR_INVALID_VALUE;\n"
    src += f" return {table}[algo_id](stream, {', '.join(meta.arg_names)});\n"
    src += "}\n"
    return src
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values
def make_func_pointers(names: Sequence[str], meta: KernelLinkerMeta) -> str:
    """Emit ``kernel_func_t`` and the ``<kernel>_kernels`` dispatcher table.

    *names* is a sequence of dispatcher function names, iterated one entry
    per table slot. The table is indexed by ``algo_id``, so the order of
    *names* defines the algorithm ids. Every entry must take the full
    argument list of *meta*.
    """
    # the table of hint dispatchers
    src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n"
    src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n"
    src += "".join(f" {name},\n" for name in names)
    src += "};\n"
    return src
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# generate definition for load/unload functions for kernels with different meta-parameter and constant values
def make_kernel_load_def(names: list[str], meta: KernelLinkerMeta) -> str:
    """Return C source for ``load_<kernel>()`` and ``unload_<kernel>()``.

    Each generated function calls ``<mode>_<name>()`` for every entry of
    ``names``, in order.

    Args:
        names: dispatcher names, one per linked kernel. It is iterated once per
            mode, so it must be a re-iterable sequence (e.g. a list), not a
            one-shot iterator or a single string.
        meta: linker metadata; only ``orig_kernel_name`` is read.

    Returns:
        The ``load`` definition followed by the ``unload`` definition, each
        followed by a blank line.
    """
    src = ""
    for mode in ("load", "unload"):
        calls = "".join(f" {mode}_{name}();\n" for name in names)
        src += f"void {mode}_{meta.orig_kernel_name}(void){{\n{calls}}}\n\n"
    return src
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str:
    """Return the C prototype of ``<kernel>_get_num_algos``.

    Args:
        meta: linker metadata; only ``orig_kernel_name`` is read.

    Returns:
        A single declaration with no trailing newline.
    """
    return "int {}_get_num_algos(void);".format(meta.orig_kernel_name)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def make_get_num_algos_def(meta: KernelLinkerMeta) -> str:
    """Return the C definition of ``<kernel>_get_num_algos``.

    The generated function returns the number of entries in the
    ``<kernel>_kernels`` table: its total ``sizeof`` divided by the ``sizeof``
    of one element.

    Args:
        meta: linker metadata; only ``orig_kernel_name`` is read.

    Returns:
        The C function definition, terminated by a newline.
    """
    table = f"{meta.orig_kernel_name}_kernels"
    return (
        f"int {meta.orig_kernel_name}_get_num_algos(void){{\n"
        f" return (int)(sizeof({table}) / sizeof({table}[0]));\n"
        "}\n"
    )
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
desc = """
Triton ahead-of-time linker:

This program takes in header files generated by compile.py, and generates a
single entry-point responsible for dispatching the user's input to the right
kernel given the specializations that were compiled.

Example usage:
python link.py /path/to/headers/*.h -o kernel_name
"""

if __name__ == "__main__":
    from argparse import ArgumentParser

    # Kept distinct from the header parser below (the two used to share the name
    # ``parser``) so CLI-style errors can still be reported after parsing headers.
    arg_parser = ArgumentParser(description=desc)
    arg_parser.add_argument(
        "headers",
        nargs="+",
        help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
    )
    # Both generated files are derived from this path, so it is mandatory; without
    # it the script previously died later with an AttributeError on None.
    arg_parser.add_argument("--out", "-o", type=Path, required=True, help="Out filename")
    # NOTE(review): --prefix is parsed but never read below -- confirm whether it
    # should be applied to the generated dispatcher names.
    arg_parser.add_argument(
        "--prefix",
        type=str,
        default="",
        help="String to prefix kernel dispatcher names",
    )
    args = arg_parser.parse_args()

    # metadata
    header_parser = HeaderParser()
    includes = []  # NOTE(review): collected but not used in the generated output
    for header in args.headers:
        h_path = Path(header)
        includes.append(h_path.name)
        header_parser.extract_linker_meta(h_path.read_text())

    # Without any linker metadata there is nothing to link; fail with a usage error
    # instead of an IndexError further down.
    if not header_parser.kernels:
        arg_parser.error("no kernel linker directives found in the given headers")

    # generate headers
    algo_decls = [make_algo_decls(name, metas) for name, metas in header_parser.kernels.items()]
    # The shared entry points are generated from the first specialization of the
    # first kernel; this assumes all linked kernels share its argument list -- TODO confirm.
    meta = next(iter(header_parser.kernels.values()))[0]
    get_num_algos_decl = make_get_num_algos_decl(meta)
    global_decl = make_global_decl(meta)
    header_out = "".join([
        "#include <cuda.h>\n",
        "\n".join(algo_decls),
        "\n",
        get_num_algos_decl,
        "\n",
        global_decl,
    ])
    with args.out.with_suffix(".h").open("w") as fp:
        fp.write(header_out)

    # generate source
    defs = [make_kernel_hints_dispatcher(name, metas) for name, metas in header_parser.kernels.items()]
    names = list(header_parser.kernels.keys())
    func_pointers_def = make_func_pointers(names, meta)
    meta_const_def = make_kernel_meta_const_dispatcher(meta)
    load_unload_def = make_kernel_load_def(names, meta)
    get_num_algos_def = make_get_num_algos_def(meta)
    default_algo_kernel = make_default_algo_kernel(meta)
    # Order matters: the function-pointer table must precede the code that indexes it.
    source_out = "".join([
        "#include <cuda.h>\n",
        "#include <stdint.h>\n",
        "#include <assert.h>\n",
        "\n",
        "\n".join(defs),
        "\n",
        func_pointers_def,
        "\n",
        get_num_algos_def,
        "\n",
        meta_const_def,
        "\n",
        load_unload_def,
        "\n",
        default_algo_kernel,
    ])
    with args.out.with_suffix(".c").open("w") as fp:
        fp.write(source_out)
|