triton-windows 3.1.0.post17__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (248) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
@@ -0,0 +1,289 @@
1
+ import importlib
2
+ import json
3
+ import os
4
+ import uuid
5
+ from abc import ABC, abstractmethod
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional
8
+ import hashlib
9
+
10
+
11
def default_cache_dir():
    """Return the default cache location, ``<home>/.triton/cache``, as a string."""
    home = Path.home()
    return os.path.join(home, ".triton", "cache")
13
+
14
+
15
def default_override_dir():
    """Return the default override location, ``<home>/.triton/override``, as a string."""
    home = Path.home()
    return os.path.join(home, ".triton", "override")
17
+
18
+
19
def default_dump_dir():
    """Return the default dump location, ``<home>/.triton/dump``, as a string."""
    home = Path.home()
    return os.path.join(home, ".triton", "dump")
21
+
22
+
23
class CacheManager(ABC):
    """Abstract interface every cache manager must implement.

    Concrete subclasses provide single-file access (``get_file`` / ``put``)
    and grouped access (``get_group`` / ``put_group``).
    """

    def __init__(self, key):
        """Accept the cache ``key``; the base class itself stores nothing."""

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Return a path for ``filename``, or ``None``."""

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store ``data`` under ``filename`` and return a path string."""

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return the mapping recorded for the group ``filename``, or ``None``."""

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record ``group`` as the group named ``filename``."""
43
+
44
+
45
class FileCacheManager(CacheManager):
    """Cache manager that keeps entries as files in a per-key directory.

    The directory is ``<base>/<key>`` where ``<base>`` is chosen by the
    constructor flags:

    * ``dump=True``     -> ``default_dump_dir()``
    * ``override=True`` -> ``default_override_dir()``
    * otherwise         -> ``$TRITON_CACHE_DIR`` if set and non-blank, else
      ``default_cache_dir()``
    """

    def __init__(self, key, override=False, dump=False):
        """Resolve (and, except for override, create) the directory for ``key``.

        Raises:
            RuntimeError: in the default mode when no base directory can be
                determined.
        """
        self.key = key
        self.lock_path = None
        if dump:
            self.cache_dir = os.path.join(default_dump_dir(), self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            # The override directory is neither created nor given a lock path,
            # so `put` on an override manager fails its `lock_path` assertion.
            self.cache_dir = os.path.join(default_override_dir(), self.key)
        else:
            base_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            if not base_dir:
                raise RuntimeError("Could not create or locate cache dir")
            # create cache directory if it doesn't exist
            self.cache_dir = os.path.join(base_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)

    def _make_path(self, filename) -> str:
        """Return the path of ``filename`` inside this manager's directory."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        """Return whether ``filename`` exists in the cache directory.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        """Return the cached path of ``filename``, or ``None`` if it is absent."""
        if self.has_file(filename):
            return self._make_path(filename)
        return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Load the group file ``__grp__<filename>``.

        Returns:
            ``None`` when the group file is missing or has no ``child_paths``
            entry; otherwise the ``child_paths`` mapping restricted to the
            entries whose path still exists on disk.
        """
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        with open(self._make_path(grp_filename)) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        return {c: p for c, p in child_paths.items() if os.path.exists(p)}

    # Note a group of pushed files as being part of a group
    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        """Write ``group`` as JSON to ``__grp__<filename>`` and return its path.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    def put(self, data, filename, binary=True) -> str:
        """Write ``data`` to ``filename`` in the cache directory.

        The data is first written to a uniquely named temporary file and then
        moved over the final path with ``os.replace``. The temporary file is
        removed on every exit path, so a failed write or move does not leave
        ``*.tmp.pid_*`` files behind.

        Args:
            data: ``bytes`` are written in binary mode; anything else is
                converted with ``str()`` and written in text mode.
            filename: Name of the entry inside the cache directory.
            binary: Ignored; the mode is derived from the type of ``data``.

        Returns:
            The final path of the entry.

        Raises:
            RuntimeError: if no cache directory is configured.
            OSError: if writing or moving the file fails, except for a
                ``PermissionError`` from the move on Windows, which is ignored.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = str(uuid.uuid4())
        # we use the PID in case a bunch of these around so we can see what PID made it
        pid = os.getpid()
        # use tempfile to be robust against program interruptions
        temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
        mode = "wb" if binary else "w"
        try:
            with open(temp_path, mode) as f:
                f.write(data)
            # Replace is guaranteed to be atomic on POSIX systems if it succeeds
            # so filepath cannot see a partial write
            try:
                os.replace(temp_path, filepath)
            except PermissionError:
                # Ignore PermissionError on Windows because it happens when another process already
                # put a file into the cache and locked it by opening it.
                if os.name != "nt":
                    raise
        finally:
            # After a successful replace the temp path is already gone; in
            # every other case this removes the leftover temp file.
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                pass
        return filepath
136
+
137
+
138
class RemoteCacheBackend:
    """
    Interface of a backend that talks to a remote/distributed cache.

    NOTE: this class does not derive from ``ABC``, so the ``@abstractmethod``
    markers on ``get`` and ``put`` are not enforced at instantiation time.
    """

    def __init__(self, key: str):
        """Accept the cache ``key``; the base class itself stores nothing."""

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch the entries named in ``filenames``, keyed by file name."""

    @abstractmethod
    def put(self, filename: str, data: bytes):
        """Store ``data`` under ``filename``."""
153
+
154
+
155
class RedisRemoteCacheBackend(RemoteCacheBackend):
    """Remote cache backend that stores entries in a Redis server.

    Configuration is read from the environment:

    * ``TRITON_REDIS_HOST`` (default ``"localhost"``)
    * ``TRITON_REDIS_PORT`` (default ``6379``)
    * ``TRITON_REDIS_KEY_FORMAT`` (default ``"triton:{key}:{filename}"``),
      a ``str.format`` template receiving ``key`` and ``filename``.
    """

    def __init__(self, key):
        """Connect a Redis client for the cache ``key``."""
        # Imported here so `redis` is only needed when this backend is used.
        import redis
        self._key = key
        self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
        self._redis = redis.Redis(
            host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
        )

    def _get_key(self, filename: str) -> str:
        """Return the Redis key used for ``filename``."""
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch ``filenames`` with a single ``MGET``.

        Returns:
            A mapping from file name to the stored value, containing only the
            names for which Redis returned a value. Values are whatever the
            client returns (bytes, given that the client above is created
            without ``decode_responses``).
        """
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes) -> None:
        """Store ``data`` under the Redis key derived from ``filename``."""
        self._redis.set(self._get_key(filename), data)
175
+
176
+
177
class RemoteCacheManager(CacheManager):
    """Cache manager that reads and writes through a remote backend.

    The backend class is named by the ``TRITON_REMOTE_CACHE_BACKEND``
    environment variable in ``module.path:ClassName`` form. Data obtained from
    or sent to the backend is also written to a local ``FileCacheManager`` so
    that callers receive file paths. In dump/override mode every call is
    forwarded to that local manager and the backend is not consulted.
    """

    def __init__(self, key, override=False, dump=False):
        """Instantiate the configured backend and the local file manager."""
        # Setup backend pointed to by `TRITON_REMOTE_CACHE_BACKEND`.
        backend_spec = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
        module_path, clz_nme = backend_spec.split(":")
        backend_cls = getattr(importlib.import_module(module_path), clz_nme)
        self._backend = backend_cls(key)

        self._override = override
        self._dump = dump

        # Use a `FileCacheManager` to materialize remote cache paths locally.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _materialize(self, filename: str, data: bytes):
        """Write ``data`` to the local file cache and return the resulting path."""
        # We use a backing `FileCacheManager` to provide the materialized data.
        local = self._file_cache_manager
        return local.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        """Return a local path for ``filename``, or ``None`` if it is not cached."""
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_file(filename)

        # We always check the remote cache backend -- even if our internal file-
        # based cache has the item -- to make sure LRU accounting works as
        # expected.
        fetched = self._backend.get([filename])
        if len(fetched) == 0:
            return None
        # Exactly one entry is expected back for the single requested name.
        [(_, payload)] = fetched.items()
        return self._materialize(filename, payload)

    def put(self, data, filename: str, binary=True) -> str:
        """Send ``data`` to the backend and return the local path of its copy.

        Non-bytes data is converted with ``str()`` and encoded as UTF-8 first.
        """
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put(data, filename, binary=binary)

        payload = data if isinstance(data, bytes) else str(data).encode("utf-8")
        self._backend.put(filename, payload)
        return self._materialize(filename, payload)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return local paths for the members of group ``filename``.

        Returns ``None`` when the group file is unavailable or has no
        ``child_paths`` entry; otherwise a mapping from each child name the
        backend returned to the local path it was materialized at.
        """
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_group(filename)

        grp_filepath = self.get_file(f"__grp__{filename}")
        if grp_filepath is None:
            return None
        with open(grp_filepath) as f:
            grp_data = json.load(f)

        child_paths = grp_data.get("child_paths", None)
        if child_paths is None:
            return None

        # Found group data.
        fetched = self._backend.get(child_paths)
        return {child: self._materialize(child, payload) for child, payload in fetched.items()}

    def put_group(self, filename: str, group: Dict[str, str]):
        """Store the sorted key names of ``group`` as the group ``filename``."""
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put_group(filename, group)

        child_names = sorted(group)
        grp_contents = json.dumps({"child_paths": child_names})
        return self.put(grp_contents, f"__grp__{filename}")
252
+
253
+
254
# Cache manager class currently in effect. `get_cache_manager` replaces it
# when TRITON_CACHE_MANAGER names a different class.
__cache_cls = FileCacheManager
# The TRITON_CACHE_MANAGER value that `__cache_cls` was loaded from
# ("DEFAULT" until an override has been loaded).
__cache_cls_nme = "DEFAULT"
256
+
257
+
258
def get_cache_manager(key) -> CacheManager:
    """Return a cache manager for ``key``.

    If ``TRITON_CACHE_MANAGER`` is set to a ``module.path:ClassName`` value
    that differs from the one currently loaded, that class is imported and
    becomes the module-wide cache manager class before it is instantiated.
    """
    global __cache_cls
    global __cache_cls_nme

    requested = os.environ.get("TRITON_CACHE_MANAGER", None)
    needs_reload = requested is not None and requested != __cache_cls_nme
    if needs_reload:
        module_path, clz_nme = requested.split(":")
        module = importlib.import_module(module_path)
        __cache_cls = getattr(module, clz_nme)
        __cache_cls_nme = requested

    return __cache_cls(key)
272
+
273
+
274
def get_override_manager(key) -> CacheManager:
    """Return an override-mode manager for ``key``.

    Uses the currently selected cache manager class; this function does not
    itself read ``TRITON_CACHE_MANAGER``.
    """
    manager_cls = __cache_cls
    return manager_cls(key, override=True)
276
+
277
+
278
def get_dump_manager(key) -> CacheManager:
    """Return a dump-mode manager for ``key``.

    Uses the currently selected cache manager class; this function does not
    itself read ``TRITON_CACHE_MANAGER``.
    """
    manager_cls = __cache_cls
    return manager_cls(key, dump=True)
280
+
281
+
282
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    """Return the SHA-256 hex digest identifying one compiled artifact.

    The digest covers, joined by ``-``: ``version_hash``, the concatenated
    signature types (any type starting with ``*`` is replaced by ``ptr``),
    ``constants``, ``ids``, and then each keyword value in call order.
    """
    # Get unique key for the compiled code
    normalized = {}
    for arg, ty in signature.items():
        normalized[arg] = 'ptr' if ty[0] == '*' else ty
    pieces = [f"{version_hash}", ''.join(normalized.values()), f"{constants}", f"{ids}"]
    pieces.extend(f"{value}" for value in kwargs.values())
    digest = hashlib.sha256("-".join(pieces).encode("utf-8"))
    return digest.hexdigest()
@@ -0,0 +1,60 @@
1
+ from ..backends import backends
2
+ from ..backends import DriverBase
3
+
4
+
5
def _create_driver():
    """Instantiate the single active backend driver.

    Raises:
        RuntimeError: if the number of backends whose driver reports itself
            active is not exactly one.
    """
    actives = []
    for backend in backends.values():
        if backend.driver.is_active():
            actives.append(backend.driver)
    if len(actives) != 1:
        raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.")
    driver_cls = actives[0]
    return driver_cls()
10
+
11
+
12
class LazyProxy:
    """Proxy that builds its target on first use and forwards attributes to it.

    ``init_fn`` is called (with no arguments) the first time an attribute is
    read, set or deleted through the proxy, or when ``str()`` is taken. Its
    result is kept in ``_obj`` and reused afterwards. While ``_obj`` is
    ``None`` the target counts as not yet built.
    """

    # Names stored on the proxy itself rather than forwarded to the target.
    _OWN_FIELDS = ("_init_fn", "_obj")

    def __init__(self, init_fn):
        """Remember ``init_fn``; the target is not created yet."""
        self._init_fn = init_fn
        self._obj = None

    def _initialize_obj(self):
        """Create the target if it has not been created yet."""
        if self._obj is None:
            self._obj = self._init_fn()

    def __getattr__(self, name):
        """Forward a failed attribute lookup to the (lazily created) target.

        Raises:
            AttributeError: if the proxy's own fields are missing, or the
                target has no attribute ``name``.
        """
        # __getattr__ only runs after normal lookup failed. If one of the
        # proxy's own fields is what is missing (instance created without
        # __init__, e.g. while `copy` rebuilds it), forwarding would re-enter
        # this method without bound, so fail with AttributeError instead.
        if name in LazyProxy._OWN_FIELDS:
            raise AttributeError(name)
        self._initialize_obj()
        return getattr(self._obj, name)

    def __setattr__(self, name, value):
        """Set the proxy's own fields locally; forward everything else."""
        if name in ["_init_fn", "_obj"]:
            super().__setattr__(name, value)
        else:
            self._initialize_obj()
            setattr(self._obj, name, value)

    def __delattr__(self, name):
        """Delete ``name`` on the (lazily created) target."""
        self._initialize_obj()
        delattr(self._obj, name)

    def __repr__(self):
        """Return the target's repr, or a placeholder if it is not built yet.

        Unlike ``__str__`` this never triggers creation of the target.
        """
        if self._obj is None:
            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
        return repr(self._obj)

    def __str__(self):
        """Return ``str()`` of the target, creating it if necessary."""
        self._initialize_obj()
        return str(self._obj)
45
+
46
+
47
class DriverConfig:
    """Holds the default driver and the driver currently in use.

    ``default`` is a ``LazyProxy`` around ``_create_driver``; ``active``
    starts out as that same proxy.
    """

    def __init__(self):
        """Create the lazy default driver and make it the active one."""
        default_driver = LazyProxy(_create_driver)
        self.default = default_driver
        self.active = default_driver

    def set_active(self, driver: DriverBase):
        """Make ``driver`` the active driver."""
        self.active = driver

    def reset_active(self):
        """Make the default driver active again."""
        self.active = self.default
58
+
59
+
60
# Module-level DriverConfig instance, created when this module is imported.
driver = DriverConfig()
@@ -0,0 +1,26 @@
1
+ from ..errors import TritonError
2
+ from typing import Optional
3
+
4
+
5
class InterpreterError(TritonError):
    """Triton error carrying an optional message."""

    def __init__(self, error_message: Optional[str] = None):
        """Store ``error_message`` (may be ``None``)."""
        self.error_message = error_message

    def __str__(self) -> str:
        """Return the stored message, or an empty string when it is unset or empty."""
        if self.error_message:
            return self.error_message
        return ""
12
+
13
+
14
class OutOfResources(TritonError):
    """Raised to report that a named resource exceeds its hardware limit.

    Attributes are ``required`` (amount requested), ``limit`` (hardware
    limit) and ``name`` (resource name), exactly as passed to the constructor.
    """

    def __init__(self, required, limit, name):
        """Store the requested amount, the limit and the resource name."""
        self.required = required
        self.limit = limit
        self.name = name

    def __str__(self) -> str:
        """Return the human-readable out-of-resource message."""
        return (f"out of resource: {self.name}, Required: {self.required}, "
                f"Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help.")

    def __reduce__(self):
        """Support pickling by re-invoking the class with its three arguments."""
        ctor_args = (self.required, self.limit, self.name)
        return (type(self), ctor_args)