triton-windows 3.2.0.post11__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +85 -0
- triton/_internal_testing.py +123 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +368 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +512 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +304 -0
- triton/backends/driver.py +48 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +410 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +524 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1303 -0
- triton/compiler/compiler.py +430 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +294 -0
- triton/language/_utils.py +21 -0
- triton/language/core.py +2694 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +13 -0
- triton/language/extra/cuda/_experimental_tma.py +108 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +475 -0
- triton/language/extra/libdevice.py +786 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1796 -0
- triton/language/standard.py +452 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +408 -0
- triton/runtime/build.py +111 -0
- triton/runtime/cache.py +295 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1235 -0
- triton/runtime/jit.py +951 -0
- triton/testing.py +511 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +155 -0
- triton/tools/disasm.py +144 -0
- triton/tools/experimental_descriptor.py +32 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +375 -0
- triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
- triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
- triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
- triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
triton/runtime/cache.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import uuid
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Dict, List, Optional
|
|
8
|
+
import base64
|
|
9
|
+
import hashlib
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_home_dir():
    """Return the root directory under which Triton keeps its ``.triton`` tree.

    If ``TRITON_HOME`` is set, its raw string value is returned (even when
    empty). Otherwise the user's home directory is returned as a ``Path``.

    ``Path.home()`` is only evaluated when ``TRITON_HOME`` is unset. The
    override therefore still works in environments where the home directory
    cannot be resolved, which is where ``Path.home()`` raises ``RuntimeError``.
    """
    triton_home = os.getenv("TRITON_HOME")
    if triton_home is not None:
        return triton_home
    return Path.home()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def default_cache_dir():
    """Return ``<home>/.triton/cache``, where ``<home>`` is ``get_home_dir()``."""
    triton_root = os.path.join(get_home_dir(), ".triton")
    return os.path.join(triton_root, "cache")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def default_override_dir():
    """Return ``<home>/.triton/override``, where ``<home>`` is ``get_home_dir()``."""
    triton_root = os.path.join(get_home_dir(), ".triton")
    return os.path.join(triton_root, "override")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def default_dump_dir():
    """Return ``<home>/.triton/dump``, where ``<home>`` is ``get_home_dir()``."""
    triton_root = os.path.join(get_home_dir(), ".triton")
    return os.path.join(triton_root, "dump")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class CacheManager(ABC):
    """Abstract interface for storing and retrieving cached artifacts by key.

    Concrete managers are constructed with a cache ``key``. They must
    implement single-file access (``get_file``/``put``) and group access
    (``get_group``/``put_group``).
    """

    def __init__(self, key):
        """Accept the cache key; the base class keeps no state."""

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Return a local path for *filename*, or ``None`` if it is not cached."""

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store *data* under *filename* and return the resulting path."""

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return the group recorded under *filename*, or ``None`` if absent."""

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record *group* (child name -> path) under *filename*."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class FileCacheManager(CacheManager):
    """Cache manager that stores each entry as a plain file in a per-key directory.

    The directory root depends on the flags:
      * ``dump=True``: ``TRITON_DUMP_DIR`` or ``default_dump_dir()``.
      * ``override=True``: ``TRITON_OVERRIDE_DIR`` or ``default_override_dir()``.
      * otherwise: ``TRITON_CACHE_DIR`` or ``default_cache_dir()``.

    A blank variable falls back to the default. Entries live in
    ``<root>/<key>/``. Override managers get no ``lock_path`` and no directory
    is created for them, so ``put`` on them fails its assertion (read-only).
    """

    def __init__(self, key, override=False, dump=False):
        self.key = key
        # Set only for writable (dump / regular cache) managers; `put` asserts on it.
        self.lock_path = None
        if dump:
            self.cache_dir = os.getenv("TRITON_DUMP_DIR", "").strip() or default_dump_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = os.getenv("TRITON_OVERRIDE_DIR", "").strip() or default_override_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
        else:
            # create cache directory if it doesn't exist
            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            if self.cache_dir:
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")

    def _make_path(self, filename) -> str:
        """Return the path *filename* would have inside this manager's directory."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        """Return True if *filename* exists in the cache directory.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        """Return the full path of *filename* if it is cached, else None."""
        if self.has_file(filename):
            return self._make_path(filename)
        else:
            return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Load the group file ``__grp__<filename>``.

        Returns None if the group file is missing or has no ``child_paths``.
        Otherwise returns the ``child_paths`` mapping (child name -> path),
        keeping only the entries whose path still exists on disk.
        """
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        grp_filepath = self._make_path(grp_filename)
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        result = {}
        for c, p in child_paths.items():
            if os.path.exists(p):
                result[c] = p
        return result

    # Note a group of pushed files as being part of a group
    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        """Record *group* (child name -> path) as JSON in ``__grp__<filename>``.

        Returns the path of the written group file.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    def put(self, data, filename, binary=True) -> str:
        """Write *data* to *filename* in the cache directory and return its path.

        The ``binary`` argument is ignored. The mode is inferred from the
        payload: ``bytes`` are written in binary mode and anything else is
        converted with ``str()`` and written as text.

        The data is first written to a private temp directory and then moved
        into place with ``os.replace``. If the write or the move fails, the
        temp file and directory are removed before the error propagates, so
        failed puts do not leave ``tmp.pid_*`` directories in the cache.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = str(uuid.uuid4())
        # we use the PID in case a bunch of these around so we can see what PID made it
        pid = os.getpid()
        # use temp dir to be robust against program interruptions
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)

        mode = "wb" if binary else "w"
        try:
            with open(temp_path, mode) as f:
                f.write(data)
            # Replace is guaranteed to be atomic on POSIX systems if it succeeds
            # so filepath cannot see a partial write. On Windows it can fail,
            # e.g. with PermissionError while the target is in use.
            os.replace(temp_path, filepath)
        except BaseException:
            # Best-effort cleanup (also on KeyboardInterrupt); never mask the
            # original error.
            try:
                os.remove(temp_path)
            except OSError:
                pass
            try:
                os.rmdir(temp_dir)
            except OSError:
                pass
            raise
        # The temp dir is empty after a successful replace. Only remove it, not
        # its parents (os.removedirs would also try to remove the cache dir).
        os.rmdir(temp_dir)
        return filepath
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class RemoteCacheBackend:
    """Interface for a remote/distributed store used behind ``RemoteCacheManager``.

    Implementations are constructed with the cache ``key`` and exchange raw
    bytes keyed by filename.

    NOTE(review): this class does not derive from ``ABC``, so the
    ``@abstractmethod`` markers below are not enforced when instantiating it.
    """

    def __init__(self, key: str):
        """Accept the cache key; the base implementation stores nothing."""

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch *filenames*; expected to return only the entries that were found."""

    @abstractmethod
    def put(self, filename: str, data: bytes):
        """Store *data* under *filename*."""
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class RedisRemoteCacheBackend(RemoteCacheBackend):
    """Remote cache backend that stores entries in Redis.

    Connection settings come from ``TRITON_REDIS_HOST`` (default
    ``"localhost"``) and ``TRITON_REDIS_PORT`` (default ``6379``). Redis keys
    are built from ``TRITON_REDIS_KEY_FORMAT`` (default
    ``"triton:{key}:{filename}"``).
    """

    def __init__(self, key):
        # Imported lazily so redis is only required when this backend is selected.
        import redis
        self._key = key
        self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
        self._redis = redis.Redis(
            host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
        )

    def _get_key(self, filename: str) -> str:
        """Return the Redis key for *filename* under this backend's cache key."""
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch *filenames* in one MGET; missing keys are omitted from the result.

        Values are the raw stored payloads (this client is created without
        ``decode_responses``).
        """
        if not filenames:
            # Skip the round-trip. Some redis-py versions raise on an empty MGET.
            return {}
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes) -> None:
        """Store *data* under the Redis key for *filename*."""
        self._redis.set(self._get_key(filename), data)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class RemoteCacheManager(CacheManager):
    """Cache manager backed by a pluggable remote store.

    The backend class is named by ``TRITON_REMOTE_CACHE_BACKEND`` as
    ``"module.path:ClassName"`` and is instantiated with the cache key.
    Entries read from or written to the backend are also materialized as
    local files through an internal ``FileCacheManager``. In dump or override
    mode the backend is bypassed and every call goes to that local manager.
    """

    def __init__(self, key, override=False, dump=False):
        # Resolve and build the backend named by TRITON_REMOTE_CACHE_BACKEND.
        backend_spec = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
        backend_module_name, backend_cls_name = backend_spec.split(":")
        backend_cls = getattr(importlib.import_module(backend_module_name), backend_cls_name)
        self._backend = backend_cls(key)

        self._override = override
        self._dump = dump

        # Local mirror that turns remote payloads into on-disk paths.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _local_only(self):
        """Truthy when dump/override mode is active and the backend must be bypassed."""
        return self._dump or self._override

    def _materialize(self, filename: str, data: bytes):
        """Write *data* to the local mirror and return the resulting path."""
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        """Return a local path for *filename*, fetching it from the backend.

        The backend is queried even when a local copy already exists, so the
        remote store sees every access (the original authors cite LRU
        accounting).
        """
        if self._local_only():
            return self._file_cache_manager.get_file(filename)

        found = self._backend.get([filename])
        if not found:
            return None
        [data] = found.values()
        return self._materialize(filename, data)

    def put(self, data, filename: str, binary=True) -> str:
        """Upload *data* (encoded as UTF-8 text if not bytes) and mirror it locally."""
        if self._local_only():
            return self._file_cache_manager.put(data, filename, binary=binary)

        payload = data if isinstance(data, bytes) else str(data).encode("utf-8")
        self._backend.put(filename, payload)
        return self._materialize(filename, payload)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Fetch the group ``__grp__<filename>`` and materialize its children.

        Returns None if the group entry is missing or has no ``child_paths``.
        Otherwise returns a mapping from child name to local path for the
        children the backend returned.
        """
        if self._local_only():
            return self._file_cache_manager.get_group(filename)

        grp_filepath = self.get_file(f"__grp__{filename}")
        if grp_filepath is None:
            return None
        with open(grp_filepath) as f:
            child_names = json.load(f).get("child_paths", None)
        if child_names is None:
            return None
        fetched = self._backend.get(child_names)
        return {name: self._materialize(name, blob) for name, blob in fetched.items()}

    def put_group(self, filename: str, group: Dict[str, str]):
        """Store the sorted child names of *group* as ``__grp__<filename>``."""
        if self._local_only():
            return self._file_cache_manager.put_group(filename, group)

        child_names = sorted(group)
        return self.put(json.dumps({"child_paths": child_names}), f"__grp__{filename}")
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
# Cache-manager class currently used by get_cache_manager / get_override_manager /
# get_dump_manager, and the TRITON_CACHE_MANAGER value it was loaded from
# ("DEFAULT" means the built-in FileCacheManager).
__cache_cls, __cache_cls_nme = FileCacheManager, "DEFAULT"
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _base64(key):
|
|
260
|
+
# Assume key is a hex string.
|
|
261
|
+
return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def get_cache_manager(key) -> CacheManager:
    """Return a cache manager for the hex-encoded *key*.

    The manager class is ``FileCacheManager`` unless ``TRITON_CACHE_MANAGER``
    names another class as ``"module.path:ClassName"``. The resolved class is
    stored in module globals and is only re-imported when the variable's
    value differs from the stored one. The stored class is also what
    ``get_override_manager`` and ``get_dump_manager`` use. Unsetting or
    blanking the variable later keeps the previously selected class.

    Raises:
        ValueError: if ``TRITON_CACHE_MANAGER`` is not of the form
            ``"module.path:ClassName"``.
    """
    global __cache_cls
    global __cache_cls_nme

    # A blank value is treated like an unset variable, matching how the
    # TRITON_*_DIR variables are read elsewhere in this module.
    user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", "").strip() or None

    if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
        parts = user_cache_manager.split(":")
        if len(parts) != 2:
            raise ValueError("TRITON_CACHE_MANAGER must have the form 'module.path:ClassName', "
                             f"got {user_cache_manager!r}")
        module_path, clz_nme = parts
        module = importlib.import_module(module_path)
        __cache_cls = getattr(module, clz_nme)
        __cache_cls_nme = user_cache_manager

    return __cache_cls(_base64(key))
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def get_override_manager(key) -> CacheManager:
    """Return a manager for *key* constructed with ``override=True``.

    Uses the currently selected cache-manager class. Unlike
    ``get_cache_manager``, it does not re-read ``TRITON_CACHE_MANAGER``.
    """
    encoded_key = _base64(key)
    return __cache_cls(encoded_key, override=True)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def get_dump_manager(key) -> CacheManager:
    """Return a manager for *key* constructed with ``dump=True``.

    Uses the currently selected cache-manager class. Unlike
    ``get_cache_manager``, it does not re-read ``TRITON_CACHE_MANAGER``.
    """
    encoded_key = _base64(key)
    return __cache_cls(encoded_key, dump=True)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    """Build a short, unique cache key from compilation inputs.

    The name suggests it is used for compiled shared-object modules
    (assumption; confirm against callers).

    Every signature type string starting with ``"*"`` is collapsed to
    ``"ptr"``, so signatures that differ only in pointee type map to the same
    key. The version hash, the joined argument types, ``constants``, ``ids``
    and each extra keyword value (in order) are joined with ``"-"``. The
    result is hashed with SHA-256 and returned as unpadded URL-safe base64.
    """
    arg_types = "".join("ptr" if ty[0] == "*" else ty for ty in signature.values())
    extras = "".join(f"-{value}" for value in kwargs.values())
    raw_key = f"{version_hash}-{arg_types}-{constants}-{ids}{extras}"
    digest = hashlib.sha256(raw_key.encode("utf-8")).hexdigest()
    return _base64(digest)
|
triton/runtime/driver.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from ..backends import backends
|
|
2
|
+
from ..backends import DriverBase
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _create_driver():
    """Instantiate the driver of the single backend whose driver reports active.

    Raises:
        RuntimeError: if zero or more than one backend driver is active.
    """
    active_driver_classes = []
    for backend in backends.values():
        if backend.driver.is_active():
            active_driver_classes.append(backend.driver)
    if len(active_driver_classes) == 1:
        return active_driver_classes[0]()
    raise RuntimeError(f"{len(active_driver_classes)} active drivers ({active_driver_classes}). There should only be one.")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LazyProxy:
    """Proxy that builds its target object on first use.

    ``init_fn`` is called when an attribute is first read, set or deleted
    through the proxy, or when ``str()`` is taken. ``repr()`` does not force
    initialization. If ``init_fn`` returns None, it is called again on the
    next access.
    """

    def __init__(self, init_fn):
        self._init_fn = init_fn
        self._obj = None

    def _initialize_obj(self):
        """Create the target object if it has not been created yet."""
        if self._obj is None:
            self._obj = self._init_fn()

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If the proxy's own
        # slots are missing (an instance made via __new__ without __init__,
        # as copy/pickle do), forwarding would recurse forever through
        # _initialize_obj -> self._obj -> __getattr__. Fail cleanly instead.
        if name in ("_init_fn", "_obj"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        self._initialize_obj()
        return getattr(self._obj, name)

    def __setattr__(self, name, value):
        # The proxy's own slots live on the proxy; everything else is forwarded.
        if name in ["_init_fn", "_obj"]:
            super().__setattr__(name, value)
        else:
            self._initialize_obj()
            setattr(self._obj, name, value)

    def __delattr__(self, name):
        self._initialize_obj()
        delattr(self._obj, name)

    def __repr__(self):
        if self._obj is None:
            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
        return repr(self._obj)

    def __str__(self):
        self._initialize_obj()
        return str(self._obj)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class DriverConfig:
    """Holds a lazily created default driver and the currently active driver.

    ``active`` starts out as ``default``. ``set_active`` swaps in another
    driver and ``reset_active`` restores the default.
    """

    def __init__(self):
        # The default driver is wrapped in a LazyProxy, so no backend driver is
        # constructed until it is first used.
        self.default = self.active = LazyProxy(_create_driver)

    def set_active(self, driver: DriverBase):
        """Make *driver* the active driver."""
        self.active = driver

    def reset_active(self):
        """Make the default driver active again."""
        self.active = self.default
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Module-level DriverConfig instance created at import time. Its default driver
# is a LazyProxy, so the backend driver itself is only built on first use.
driver = DriverConfig()
|
triton/runtime/errors.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from ..errors import TritonError
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class InterpreterError(TritonError):
    """Error carrying an optional free-form message.

    ``str()`` yields the message, or an empty string when none was given.
    """

    def __init__(self, error_message: Optional[str] = None):
        self.error_message = error_message

    def __str__(self) -> str:
        message = self.error_message
        return message if message else ""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OutOfResources(TritonError):
    """Raised when more of a resource is required than the hardware limit allows.

    Args:
        required: amount of the resource that is needed.
        limit: amount the hardware provides.
        name: name of the resource, as shown in the error message.
    """

    def __init__(self, required, limit, name):
        self.required = required
        self.limit = limit
        self.name = name

    def __str__(self) -> str:
        return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help."

    def __reduce__(self):
        # Makes OutOfResources picklable. __init__ takes three arguments and
        # does not forward them to the base exception, so rebuild the
        # instance from the stored fields rather than from self.args.
        return (type(self), (self.required, self.limit, self.name))
|