triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
triton/runtime/cache.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import uuid
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import Dict, List, Optional
|
|
6
|
+
import base64
|
|
7
|
+
import hashlib
|
|
8
|
+
import functools
|
|
9
|
+
import sysconfig
|
|
10
|
+
|
|
11
|
+
from triton import __version__, knobs
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class CacheManager(ABC):
    """Abstract interface shared by the cache managers defined in this module.

    Concrete subclasses must provide ``get_file``, ``put``, ``get_group`` and
    ``put_group``; the class cannot be instantiated until all four exist.
    """

    def __init__(self, key, override=False, dump=False):
        """Accept the common constructor arguments; the base class stores nothing."""

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Look up ``filename`` and return a string for it, or ``None`` if absent."""

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store ``data`` under ``filename`` and return a string identifying the entry."""

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return the mapping recorded for the group ``filename``, or ``None``."""

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record ``group`` as the set of entries belonging to ``filename``."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class FileCacheManager(CacheManager):
    """Cache manager backed by a directory on the local filesystem.

    Every entry lives at ``<base>/<key>/<filename>``. The base directory comes
    from ``knobs.cache``: ``dump_dir`` when ``dump`` is true, ``override_dir``
    when ``override`` is true, and ``dir`` otherwise. ``dump`` takes
    precedence when both flags are set.
    """

    def __init__(self, key, override=False, dump=False):
        """Resolve the cache directory for ``key``.

        The directory is created in the dump and default modes; in override
        mode it is only computed, never created.

        Raises:
            RuntimeError: in the default mode when ``knobs.cache.dir`` is empty.
        """
        self.key = key
        # Only the dump and default modes set a lock path. `put` asserts on
        # it, so a manager created in override mode cannot be written to.
        self.lock_path = None
        if dump:
            self.cache_dir = os.path.join(knobs.cache.dump_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = os.path.join(knobs.cache.override_dir, self.key)
        else:
            base_dir = knobs.cache.dir
            if not base_dir:
                raise RuntimeError("Could not create or locate cache dir")
            # create cache directory if it doesn't exist
            self.cache_dir = os.path.join(base_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)

    def _require_cache_dir(self) -> None:
        """Raise ``RuntimeError`` if no cache directory has been resolved."""
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")

    @staticmethod
    def _discard_temp(temp_path, temp_dir) -> None:
        """Best-effort removal of a temporary file and its directory.

        Only used while another exception is already propagating, so
        secondary ``OSError``s are deliberately suppressed to avoid masking it.
        """
        for remove, target in ((os.remove, temp_path), (os.rmdir, temp_dir)):
            try:
                remove(target)
            except OSError:
                pass

    def _make_path(self, filename) -> str:
        """Return the path of ``filename`` inside the cache directory."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        """Return whether ``filename`` currently exists in the cache directory.

        Raises:
            RuntimeError: if no cache directory is set.
        """
        self._require_cache_dir()
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        """Return the cached path of ``filename``, or ``None`` if it is absent."""
        if self.has_file(filename):
            return self._make_path(filename)
        return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Load the group recorded for ``filename``.

        Returns:
            ``None`` when the ``__grp__<filename>`` file is missing or has no
            ``child_paths`` entry. Otherwise the recorded name-to-path
            mapping, restricted to paths that still exist on disk.
        """
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        with open(self._make_path(grp_filename)) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        return {name: path for name, path in child_paths.items() if os.path.exists(path)}

    # Note a group of pushed files as being part of a group
    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        """Write ``group`` as JSON to ``__grp__<filename>`` and return its path.

        Raises:
            RuntimeError: if no cache directory is set.
        """
        self._require_cache_dir()
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    def put(self, data, filename, binary=True) -> str:
        """Write ``data`` to ``filename`` in the cache directory and return its path.

        ``bytes`` payloads are written in binary mode; anything else is
        converted with ``str()`` and written in text mode. The ``binary``
        argument is accepted for interface compatibility but is not consulted:
        the mode is derived from the payload type.

        The payload is first written into a private temporary directory and
        then moved into place with ``os.replace``. If anything fails before
        the move completes, the temporary file and directory are removed
        instead of being left behind in the cache directory.

        Raises:
            RuntimeError: if no cache directory is set.
        """
        self._require_cache_dir()
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = str(uuid.uuid4())
        # we use the PID in case a bunch of these around so we can see what PID made it
        pid = os.getpid()
        # use temp dir to be robust against program interruptions
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)

        mode = "wb" if binary else "w"
        committed = False
        try:
            with open(temp_path, mode) as f:
                f.write(data)
            # Replace is guaranteed to be atomic on POSIX systems if it succeeds
            # so filepath cannot see a partial write
            try:
                os.replace(temp_path, filepath)
            except PermissionError:
                # Ignore PermissionError on Windows because it happens when another process already
                # put a file into the cache and locked it by opening it.
                if os.name != "nt":
                    raise
                os.remove(temp_path)
            # Remove only the temporary directory itself. `os.removedirs`
            # would additionally prune any parent directory that is empty.
            os.rmdir(temp_dir)
            committed = True
        finally:
            if not committed:
                self._discard_temp(temp_path, temp_dir)
        return filepath
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class RemoteCacheBackend(ABC):
    """
    A backend implementation for accessing a remote/distributed cache.

    Deriving from ``ABC`` makes the ``@abstractmethod`` markers below
    effective: a subclass must implement both ``get`` and ``put`` before it
    can be instantiated.
    """

    def __init__(self, key: str):
        """Accept the cache ``key``; the base class stores nothing."""

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch the given filenames and return a filename-to-data mapping."""

    @abstractmethod
    def put(self, filename: str, data: bytes):
        """Store ``data`` under ``filename``."""
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class RedisRemoteCacheBackend(RemoteCacheBackend):
    """Remote cache backend that stores entries in Redis.

    The connection host/port and the key format string are read from
    ``knobs.cache.redis``.
    """

    def __init__(self, key):
        """Create the Redis client for ``key``.

        ``redis`` is imported here rather than at module level, so the package
        is only needed when this backend is actually instantiated.
        """
        import redis
        self._key = key
        self._key_fmt = knobs.cache.redis.key_format
        self._redis = redis.Redis(
            host=knobs.cache.redis.host,
            port=knobs.cache.redis.port,
        )

    def _get_key(self, filename: str) -> str:
        """Build the Redis key for ``filename`` from the configured format string."""
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch all ``filenames`` with a single ``mget`` call.

        Returns:
            A mapping from filename to the stored value. Filenames for which
            Redis returned ``None`` are omitted.
        """
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes) -> None:
        """Store ``data`` under the Redis key derived from ``filename``."""
        self._redis.set(self._get_key(filename), data)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class RemoteCacheManager(CacheManager):
    """Cache manager that reads and writes through a remote backend.

    Data fetched from or sent to the backend is also written locally through
    an internal ``FileCacheManager`` so callers receive file paths. When the
    manager is created with ``dump`` or ``override`` set, every operation is
    delegated to that local file cache and the backend is not consulted.
    """

    def __init__(self, key, override=False, dump=False):
        """Instantiate the configured backend and the local file cache.

        Raises:
            RuntimeError: if ``knobs.cache.remote_manager_class`` is not set.
        """
        # Set up the backend pointed to by `TRITON_REMOTE_CACHE_BACKEND`.
        backend_cls = knobs.cache.remote_manager_class
        if not backend_cls:
            raise RuntimeError(
                "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
        self._backend = backend_cls(key)

        self._override = override
        self._dump = dump

        # A `FileCacheManager` materializes remote cache entries as local paths.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _materialize(self, filename: str, data: bytes):
        """Write ``data`` into the local file cache and return the resulting path."""
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        """Return a local path for ``filename``, or ``None`` if the backend lacks it."""
        # The dump/override cases are served by the local file cache only.
        if self._dump or self._override:
            return self._file_cache_manager.get_file(filename)

        # The remote backend is always queried -- even if the local file cache
        # already has the item -- to make sure LRU accounting works as expected.
        fetched = self._backend.get([filename])
        if len(fetched) == 0:
            return None
        [(_name, payload)] = fetched.items()
        return self._materialize(filename, payload)

    def put(self, data, filename: str, binary=True) -> str:
        """Send ``data`` to the backend, then write it locally and return that path.

        Non-``bytes`` payloads are converted with ``str()`` and UTF-8 encoded
        before being sent.
        """
        # The dump/override cases are served by the local file cache only.
        if self._dump or self._override:
            return self._file_cache_manager.put(data, filename, binary=binary)

        payload = data if isinstance(data, bytes) else str(data).encode("utf-8")
        self._backend.put(filename, payload)
        return self._materialize(filename, payload)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Fetch the group recorded for ``filename`` and materialize its members.

        Returns:
            ``None`` if the group file is missing or has no ``child_paths``
            entry; otherwise a mapping from each member returned by the
            backend to its local path.
        """
        # The dump/override cases are served by the local file cache only.
        if self._dump or self._override:
            return self._file_cache_manager.get_group(filename)

        group_path = self.get_file(f"__grp__{filename}")
        if group_path is None:
            return None
        with open(group_path) as f:
            group_data = json.load(f)

        members = group_data.get("child_paths", None)
        if members is None:
            return None

        # Found group data.
        return {name: self._materialize(name, blob) for name, blob in self._backend.get(members).items()}

    def put_group(self, filename: str, group: Dict[str, str]):
        """Store the sorted member names of ``group`` under ``__grp__<filename>``."""
        # The dump/override cases are served by the local file cache only.
        if self._dump or self._override:
            return self._file_cache_manager.put_group(filename, group)

        member_names = sorted(group.keys())
        serialized = json.dumps({"child_paths": member_names})
        return self.put(serialized, f"__grp__{filename}")
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def _base32(key):
|
|
250
|
+
# Assume key is a hex string.
|
|
251
|
+
return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def get_cache_manager(key) -> CacheManager:
    """Build a cache manager for the hex ``key``.

    Uses ``knobs.cache.manager_class`` when set and ``FileCacheManager``
    otherwise; the key is base32-encoded before being passed on.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    encoded_key = _base32(key)
    return manager_cls(encoded_key)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def get_override_manager(key) -> CacheManager:
    """Build a cache manager for the hex ``key`` with ``override=True``.

    Uses ``knobs.cache.manager_class`` when set and ``FileCacheManager``
    otherwise; the key is base32-encoded before being passed on.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    encoded_key = _base32(key)
    return manager_cls(encoded_key, override=True)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def get_dump_manager(key) -> CacheManager:
    """Build a cache manager for the hex ``key`` with ``dump=True``.

    Uses ``knobs.cache.manager_class`` when set and ``FileCacheManager``
    otherwise; the key is base32-encoded before being passed on.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    encoded_key = _base32(key)
    return manager_cls(encoded_key, dump=True)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    """Derive a base32 cache key from the given compilation inputs.

    Signature entries whose type string starts with ``*`` are collapsed to
    ``'ptr'``. The version hash, the concatenated signature types,
    ``constants``, ``ids`` and every keyword-argument value (in call order)
    are joined with ``-``, hashed with SHA-256 and base32-encoded.
    Keyword-argument names do not contribute to the key, only their values.
    """
    # Get unique key for the compiled code
    normalized = {name: ('ptr' if ty[0] == '*' else ty) for name, ty in signature.items()}
    pieces = [f"{version_hash}-{''.join(normalized.values())}-{constants}-{ids}"]
    pieces.extend(f"{value}" for value in kwargs.values())
    digest = hashlib.sha256("-".join(pieces).encode("utf-8")).hexdigest()
    return _base32(digest)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
@functools.lru_cache()
def triton_key():
    """Return a string fingerprint of the installed package contents.

    The result is ``__version__`` immediately followed by the ``-``-joined
    SHA-256 hex digests of, in this order: this source file, every module
    found under ``compiler`` and ``backends``, the ``_C/libtriton`` extension
    binary, and every module found under ``language``. The value is memoized
    by ``functools.lru_cache``, so the files are read once per process.
    """
    import pkgutil

    def _file_digest(path):
        # Hash a whole file in one read.
        with open(path, "rb") as f:
            return hashlib.sha256(f.read()).hexdigest()

    def _package_digests(directory, prefix):
        # One digest per module discovered under `directory`, in walk order.
        return [
            _file_digest(mod.module_finder.find_spec(mod.name).origin)
            for mod in pkgutil.walk_packages([directory], prefix=prefix)
        ]

    triton_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # frontend
    digests = [_file_digest(__file__)]

    # compiler
    for subpackage in ("compiler", "backends"):
        digests.extend(_package_digests(os.path.join(triton_root, subpackage), f"triton.{subpackage}."))

    # backend
    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
    native_hash = hashlib.sha256()
    with open(os.path.join(triton_root, "_C", f"libtriton.{ext}"), "rb") as f:
        # The extension binary is hashed in 1 MiB chunks rather than in one read.
        for chunk in iter(lambda: f.read(1024**2), b""):
            native_hash.update(chunk)
    digests.append(native_hash.hexdigest())

    # language
    digests.extend(_package_digests(os.path.join(triton_root, 'language'), "triton.language."))

    return f'{__version__}' + '-'.join(digests)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def get_cache_key(src, backend, backend_options, env_vars):
    """Compose the cache key for one compilation.

    The key is the ``-``-joined sequence of ``triton_key()``, the ``hash()``
    results of ``src``, ``backend`` and ``backend_options``, and the string
    form of the sorted ``env_vars`` items.
    """
    base = triton_key()
    src_hash = src.hash()
    backend_hash = backend.hash()
    options_hash = backend_options.hash()
    env_repr = str(sorted(env_vars.items()))
    return f"{base}-{src_hash}-{backend_hash}-{options_hash}-{env_repr}"
|
triton/runtime/driver.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from ..backends import backends, DriverBase
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _create_driver() -> DriverBase:
    """Instantiate the single driver that reports itself as active.

    Raises:
        RuntimeError: if the number of active drivers among ``backends`` is
            not exactly one.
    """
    candidates = []
    for backend in backends.values():
        if backend.driver.is_active():
            candidates.append(backend.driver)
    if len(candidates) == 1:
        return candidates[0]()
    raise RuntimeError(f"{len(candidates)} active drivers ({candidates}). There should only be one.")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DriverConfig:
    """Holds a lazily created default driver and the currently active driver."""

    def __init__(self) -> None:
        # Both slots start empty and are filled on first access.
        self._active: DriverBase | None = None
        self._default: DriverBase | None = None

    @property
    def default(self) -> DriverBase:
        """The default driver, built by ``_create_driver()`` on first access and then reused."""
        drv = self._default
        if drv is None:
            drv = self._default = _create_driver()
        return drv

    @property
    def active(self) -> DriverBase:
        """The active driver; falls back to (and remembers) the default when unset."""
        if self._active is not None:
            return self._active
        self._active = self.default
        return self._active

    def set_active(self, driver: DriverBase) -> None:
        """Make ``driver`` the active driver."""
        self._active = driver

    def reset_active(self) -> None:
        """Make the default driver the active one again."""
        self._active = self.default
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Module-level DriverConfig instance, created when this module is imported.
driver = DriverConfig()
|
triton/runtime/errors.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from ..errors import TritonError
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class InterpreterError(TritonError):
    """Error carrying an optional message, which is also its string form."""

    def __init__(self, error_message: Optional[str] = None):
        self.error_message = error_message

    def __str__(self) -> str:
        # A missing or empty message renders as the empty string.
        message = self.error_message
        return message if message else ""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OutOfResources(TritonError):
    """Error reporting that a named resource requirement exceeds its limit.

    Attributes are stored as given: ``required``, ``limit`` and ``name``.
    """

    def __init__(self, required, limit, name):
        self.name = name
        self.limit = limit
        self.required = required

    def __str__(self) -> str:
        return (f"out of resource: {self.name}, Required: {self.required}, "
                f"Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help.")

    def __reduce__(self):
        # Pickle support: rebuild the instance by calling the constructor
        # with the three stored fields in positional order.
        ctor_args = (self.required, self.limit, self.name)
        return (type(self), ctor_args)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class PTXASError(TritonError):
    """Error carrying an optional message, rendered with a ``PTXAS error: `` prefix."""

    def __init__(self, error_message: Optional[str] = None):
        self.error_message = error_message

    def __str__(self) -> str:
        # A missing or empty message leaves just the prefix.
        return f"PTXAS error: {self.error_message or ''}"
|