triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,317 @@
1
+ import json
2
+ import os
3
+ import uuid
4
+ from abc import ABC, abstractmethod
5
+ from typing import Dict, List, Optional
6
+ import base64
7
+ import hashlib
8
+ import functools
9
+ import sysconfig
10
+
11
+ from triton import __version__, knobs
12
+
13
+
14
class CacheManager(ABC):
    """Abstract interface shared by the cache managers in this module.

    Subclasses must implement ``get_file``, ``put``, ``get_group`` and
    ``put_group``; the base class itself keeps no state.
    """

    def __init__(self, key, override=False, dump=False):
        """Accept the common constructor arguments without storing them."""

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Return a string locating ``filename`` in the cache, or ``None``."""

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store ``data`` under ``filename`` and return a string locating it."""

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return the group mapping recorded for ``filename``, or ``None``."""

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record ``group`` as the set of files belonging to ``filename``."""
34
+
35
+
36
class FileCacheManager(CacheManager):
    """Cache manager that keeps entries as files below a per-key directory."""

    def __init__(self, key, override=False, dump=False):
        """Pick the cache directory for ``key``.

        ``dump`` takes precedence over ``override``.  The dump and default
        directories are created on demand and get a ``lock_path``; the
        override directory is only referenced, never created, and leaves
        ``lock_path`` as ``None``.

        Raises:
            RuntimeError: if the default cache directory knob is empty.
        """
        self.key = key
        self.lock_path = None
        if dump:
            root, writable = knobs.cache.dump_dir, True
        elif override:
            root, writable = knobs.cache.override_dir, False
        else:
            root, writable = knobs.cache.dir, True
            if not root:
                raise RuntimeError("Could not create or locate cache dir")
        self.cache_dir = os.path.join(root, self.key)
        if writable:
            self.lock_path = os.path.join(self.cache_dir, "lock")
            # create cache directory if it doesn't exist
            os.makedirs(self.cache_dir, exist_ok=True)

    def _make_path(self, filename) -> str:
        """Return the path of ``filename`` inside this manager's directory."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        """Tell whether ``filename`` currently exists in the cache directory."""
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        """Return the cached path of ``filename``, or ``None`` when absent."""
        return self._make_path(filename) if self.has_file(filename) else None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Load the group manifest for ``filename``.

        Returns ``None`` when the manifest file is missing or has no
        ``child_paths`` entry; otherwise returns only those children whose
        recorded path still exists on disk.
        """
        manifest_name = f"__grp__{filename}"
        if not self.has_file(manifest_name):
            return None
        with open(self._make_path(manifest_name)) as fh:
            manifest = json.load(fh)
        children = manifest.get("child_paths", None)
        if children is None:
            # Invalid group data.
            return None
        return {name: path for name, path in children.items() if os.path.exists(path)}

    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        """Write ``group`` as the JSON manifest for ``filename`` and return its path."""
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        manifest = json.dumps({"child_paths": group})
        return self.put(manifest, f"__grp__{filename}", binary=False)

    def put(self, data, filename, binary=True) -> str:
        """Write ``data`` to ``filename`` in the cache and return the final path.

        The ``binary`` argument is ignored: bytes are written in binary
        mode and everything else is converted with ``str`` and written as
        text.  The data goes to a private temporary directory first and is
        then moved into place with ``os.replace``.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        if isinstance(data, bytes):
            mode = "wb"
        else:
            data = str(data)
            mode = "w"
        assert self.lock_path is not None
        final_path = self._make_path(filename)
        # A random ID avoids collisions; the PID shows which process made
        # the directory if several of them are left lying around.
        unique = str(uuid.uuid4())
        owner_pid = os.getpid()
        # Staging in a temp dir keeps interrupted writes out of the cache.
        staging_dir = os.path.join(self.cache_dir, f"tmp.pid_{owner_pid}_{unique}")
        os.makedirs(staging_dir, exist_ok=True)
        staging_path = os.path.join(staging_dir, filename)

        with open(staging_path, mode) as fh:
            fh.write(data)
        # Replace is guaranteed to be atomic on POSIX systems if it succeeds
        # so final_path cannot see a partial write
        try:
            os.replace(staging_path, final_path)
        except PermissionError:
            # On Windows this happens when another process already put the
            # file into the cache and locked it by opening it; drop our copy.
            if os.name != "nt":
                raise
            os.remove(staging_path)
        os.removedirs(staging_dir)
        return final_path
131
+
132
+
133
class RemoteCacheBackend:
    """Interface of a backend that talks to a remote/distributed cache.

    NOTE(review): the methods are marked ``@abstractmethod`` but the class
    does not derive from ``ABC``, so the marker is not enforced at
    instantiation time.
    """

    def __init__(self, key: str):
        """Accept the cache ``key``; the base class stores nothing."""

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch the entries named in ``filenames`` as a name-to-bytes mapping."""

    @abstractmethod
    def put(self, filename: str, data: bytes):
        """Store ``data`` under ``filename``."""
148
+
149
+
150
class RedisRemoteCacheBackend(RemoteCacheBackend):
    """Remote cache backend that stores entries in a Redis server."""

    def __init__(self, key):
        """Connect to the Redis host/port configured in ``knobs.cache.redis``.

        ``redis`` is imported here so the dependency is only needed when
        this backend is actually instantiated.
        """
        import redis
        self._key = key
        self._key_fmt = knobs.cache.redis.key_format
        self._redis = redis.Redis(
            host=knobs.cache.redis.host,
            port=knobs.cache.redis.port,
        )

    def _get_key(self, filename: str) -> str:
        """Build the Redis key for ``filename`` from the configured format string."""
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch ``filenames`` in one ``MGET``; names with no stored value are omitted."""
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes) -> None:
        """Store ``data`` under the Redis key derived from ``filename``."""
        self._redis.set(self._get_key(filename), data)
170
+
171
+
172
class RemoteCacheManager(CacheManager):
    """Cache manager backed by a remote backend plus a local file cache.

    Remote entries are written into a ``FileCacheManager`` so callers get
    local paths.  In dump/override mode every call is forwarded to that
    file cache and the remote backend is not consulted.
    """

    def __init__(self, key, override=False, dump=False):
        """Instantiate the configured remote backend and the local file cache.

        Raises:
            RuntimeError: if ``knobs.cache.remote_manager_class`` is unset.
        """
        # Setup backend pointed to by `TRITON_REMOTE_CACHE_BACKEND`.
        backend_cls = knobs.cache.remote_manager_class
        if not backend_cls:
            raise RuntimeError(
                "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
        self._backend = backend_cls(key)

        self._override = override
        self._dump = dump

        # Use a `FileCacheManager` to materialize remote cache paths locally.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _materialize(self, filename: str, data: bytes):
        """Write ``data`` into the local file cache and return the local path."""
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        """Return a local path for ``filename``, or ``None`` on a remote miss."""
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_file(filename)

        # The remote backend is always queried -- even if the local file
        # cache already has the item -- so that LRU accounting works as
        # expected.
        found = self._backend.get([filename])
        if len(found) == 0:
            return None
        # Exactly one entry is expected; anything else fails the unpacking.
        [(_, payload)] = found.items()
        return self._materialize(filename, payload)

    def put(self, data, filename: str, binary=True) -> str:
        """Upload ``data`` to the remote backend and return its local path.

        Non-bytes data is converted with ``str`` and encoded as UTF-8
        before it is sent.
        """
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put(data, filename, binary=binary)

        payload = data if isinstance(data, bytes) else str(data).encode("utf-8")
        self._backend.put(filename, payload)
        return self._materialize(filename, payload)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Fetch the group manifest for ``filename`` and materialize its children.

        Returns ``None`` when the manifest is missing or has no
        ``child_paths`` entry; otherwise maps each child returned by the
        backend to its local path.
        """
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_group(filename)

        manifest_path = self.get_file(f"__grp__{filename}")
        if manifest_path is None:
            return None
        with open(manifest_path) as fh:
            manifest = json.load(fh)
        child_names = manifest.get("child_paths", None)
        if child_names is None:
            return None

        # Found group data.
        fetched = self._backend.get(child_names)
        return {name: self._materialize(name, blob) for name, blob in fetched.items()}

    def put_group(self, filename: str, group: Dict[str, str]):
        """Store the sorted child names of ``group`` as the manifest for ``filename``."""
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put_group(filename, group)

        manifest = json.dumps({"child_paths": sorted(group.keys())})
        return self.put(manifest, f"__grp__{filename}")
247
+
248
+
249
+ def _base32(key):
250
+ # Assume key is a hex string.
251
+ return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
252
+
253
+
254
def get_cache_manager(key) -> CacheManager:
    """Build the regular cache manager for the hex ``key``.

    Uses ``knobs.cache.manager_class`` when it is set and falls back to
    ``FileCacheManager`` otherwise.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    return manager_cls(_base32(key))
257
+
258
+
259
def get_override_manager(key) -> CacheManager:
    """Build a cache manager for the hex ``key`` with ``override=True``.

    Uses ``knobs.cache.manager_class`` when it is set and falls back to
    ``FileCacheManager`` otherwise.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    return manager_cls(_base32(key), override=True)
262
+
263
+
264
def get_dump_manager(key) -> CacheManager:
    """Build a cache manager for the hex ``key`` with ``dump=True``.

    Uses ``knobs.cache.manager_class`` when it is set and falls back to
    ``FileCacheManager`` otherwise.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    return manager_cls(_base32(key), dump=True)
267
+
268
+
269
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    """Return a base32 key identifying one compiled artifact.

    Every signature entry starting with ``'*'`` is collapsed to ``'ptr'``.
    The key text is the version hash, the concatenated signature values,
    ``constants``, ``ids`` and each keyword value in call order, joined by
    dashes, then hashed with SHA-256.
    """
    # Get unique key for the compiled code
    normalized = {name: ('ptr' if ty[0] == '*' else ty) for name, ty in signature.items()}
    pieces = [f"{version_hash}-{''.join(normalized.values())}-{constants}-{ids}"]
    pieces.extend(f"{value}" for value in kwargs.values())
    digest = hashlib.sha256("-".join(pieces).encode("utf-8")).hexdigest()
    return _base32(digest)
277
+
278
+
279
@functools.lru_cache()
def triton_key():
    """Return a string fingerprinting the installed Triton sources and binary.

    The result is ``__version__`` followed directly (no separator) by the
    dash-joined SHA-256 digests of, in order: this file, every module
    under ``compiler`` and ``backends``, the ``libtriton`` extension, and
    every module under ``language``.  It is computed once per process.
    """
    import pkgutil

    def _sha256_of(path):
        # Digest of one whole file, read in a single call.
        with open(path, "rb") as fh:
            return hashlib.sha256(fh.read()).hexdigest()

    def _package_digests(directory, prefix):
        # Digests of every module pkgutil discovers below `directory`.
        return [
            _sha256_of(mod.module_finder.find_spec(mod.name).origin)
            for mod in pkgutil.walk_packages([directory], prefix=prefix)
        ]

    triton_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # frontend
    digests = [_sha256_of(__file__)]

    # compiler
    for subdir, prefix in (("compiler", "triton.compiler."), ("backends", "triton.backends.")):
        digests.extend(_package_digests(os.path.join(triton_root, subdir), prefix))

    # backend: the extension module is hashed in 1 MiB chunks
    suffix = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
    lib_digest = hashlib.sha256()
    with open(os.path.join(triton_root, "_C", f"libtriton.{suffix}"), "rb") as fh:
        for chunk in iter(lambda: fh.read(1024**2), b""):
            lib_digest.update(chunk)
    digests.append(lib_digest.hexdigest())

    # language
    digests.extend(_package_digests(os.path.join(triton_root, 'language'), "triton.language."))
    return f'{__version__}' + '-'.join(digests)
313
+
314
+
315
def get_cache_key(src, backend, backend_options, env_vars):
    """Join the Triton fingerprint, the three ``hash()`` results and the sorted env vars with dashes."""
    components = (
        triton_key(),
        src.hash(),
        backend.hash(),
        backend_options.hash(),
        str(sorted(env_vars.items())),
    )
    return "-".join(f"{component}" for component in components)
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ from ..backends import backends, DriverBase
4
+
5
+
6
def _create_driver() -> DriverBase:
    """Instantiate the single driver class that reports itself active.

    Raises:
        RuntimeError: if zero or more than one registered backend driver
            is active.
    """
    candidates = []
    for backend in backends.values():
        if backend.driver.is_active():
            candidates.append(backend.driver)
    if len(candidates) != 1:
        raise RuntimeError(f"{len(candidates)} active drivers ({candidates}). There should only be one.")
    (driver_cls, ) = candidates
    return driver_cls()
11
+
12
+
13
class DriverConfig:
    """Holds the default driver and the currently active driver, both created lazily."""

    def __init__(self) -> None:
        """Start with neither driver resolved."""
        self._default: DriverBase | None = None
        self._active: DriverBase | None = None

    @property
    def default(self) -> DriverBase:
        """Driver from ``_create_driver()``, built on first access and then reused."""
        cached = self._default
        if cached is None:
            cached = self._default = _create_driver()
        return cached

    @property
    def active(self) -> DriverBase:
        """Driver currently in use; falls back to ``default`` until one is set."""
        current = self._active
        if current is None:
            current = self._active = self.default
        return current

    def set_active(self, driver: DriverBase) -> None:
        """Make ``driver`` the active driver."""
        self._active = driver

    def reset_active(self) -> None:
        """Make the default driver active again."""
        self._active = self.default


# Module-level instance of the driver configuration.
driver = DriverConfig()
@@ -0,0 +1,36 @@
1
+ from ..errors import TritonError
2
+ from typing import Optional
3
+
4
+
5
class InterpreterError(TritonError):
    """Triton error carrying an optional message."""

    def __init__(self, error_message: Optional[str] = None):
        """Remember ``error_message`` for ``__str__``."""
        self.error_message = error_message

    def __str__(self) -> str:
        """Return the stored message, or an empty string when there is none."""
        message = self.error_message
        return message if message else ""
12
+
13
+
14
class OutOfResources(TritonError):
    """Triton error reporting that a resource requirement exceeds its limit."""

    def __init__(self, required, limit, name):
        """Record the resource ``name``, the amount ``required`` and the ``limit``."""
        self.required = required
        self.limit = limit
        self.name = name

    def __str__(self) -> str:
        """Describe the exceeded resource together with a mitigation hint."""
        return (f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. "
                "Reducing block sizes or `num_stages` may help.")

    def __reduce__(self):
        """Return the class and its three constructor arguments for pickling."""
        # Rebuilding from the constructor arguments keeps this error picklable.
        ctor_args = (self.required, self.limit, self.name)
        return (type(self), ctor_args)
27
+
28
+
29
class PTXASError(TritonError):
    """Triton error carrying an optional message, rendered with a ``PTXAS error: `` prefix."""

    def __init__(self, error_message: Optional[str] = None):
        """Remember ``error_message`` for ``__str__``."""
        self.error_message = error_message

    def __str__(self) -> str:
        """Return the prefixed message; a missing message renders as empty text."""
        detail = self.error_message if self.error_message else ""
        return f"PTXAS error: {detail}"