triton-windows 3.2.0.post11__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (154) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
@@ -0,0 +1,295 @@
1
+ import importlib
2
+ import json
3
+ import os
4
+ import uuid
5
+ from abc import ABC, abstractmethod
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional
8
+ import base64
9
+ import hashlib
10
+
11
+
12
def get_home_dir():
    """Return the base directory under which Triton keeps its ``.triton`` folder.

    The value of the ``TRITON_HOME`` environment variable is returned when it
    is set (as a ``str``); otherwise the result of ``Path.home()`` is returned
    (as a ``Path``).
    """
    # Evaluated unconditionally so that the fallback is computed even when the
    # environment variable is present.
    fallback = Path.home()
    return os.environ.get("TRITON_HOME", fallback)
14
+
15
+
16
def default_cache_dir():
    """Return ``<home>/.triton/cache``, with ``<home>`` taken from ``get_home_dir()``."""
    home = get_home_dir()
    return os.path.join(home, ".triton", "cache")
18
+
19
+
20
def default_override_dir():
    """Return ``<home>/.triton/override``, with ``<home>`` taken from ``get_home_dir()``."""
    home = get_home_dir()
    return os.path.join(home, ".triton", "override")
22
+
23
+
24
def default_dump_dir():
    """Return ``<home>/.triton/dump``, with ``<home>`` taken from ``get_home_dir()``."""
    home = get_home_dir()
    return os.path.join(home, ".triton", "dump")
26
+
27
+
28
class CacheManager(ABC):
    """Abstract interface for a per-key cache of named files and file groups.

    Concrete managers must implement all four abstract methods below; the
    base implementations do nothing and return ``None``.
    """

    def __init__(self, key):
        """Accept the cache ``key``; the base class stores nothing."""

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Return a path for ``filename``, or ``None`` when it is not cached."""

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store ``data`` under ``filename`` and return a path to it."""

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return the name-to-path mapping of the group ``filename``, or ``None``."""

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record ``group`` (a name-to-path mapping) under ``filename``."""
48
+
49
+
50
class FileCacheManager(CacheManager):
    """Cache manager that keeps every entry as a file in a per-key directory.

    The directory is ``<root>/<key>`` where ``<root>`` is selected by mode:

    * ``dump=True``: ``TRITON_DUMP_DIR`` or ``default_dump_dir()``
    * ``override=True``: ``TRITON_OVERRIDE_DIR`` or ``default_override_dir()``
    * otherwise: ``TRITON_CACHE_DIR`` or ``default_cache_dir()``

    ``dump`` takes precedence over ``override``. In override mode the
    directory is not created and ``lock_path`` stays ``None``, so ``put``
    fails its ``lock_path`` assertion.
    """

    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None
        if dump:
            root = os.getenv("TRITON_DUMP_DIR", "").strip() or default_dump_dir()
        elif override:
            root = os.getenv("TRITON_OVERRIDE_DIR", "").strip() or default_override_dir()
        else:
            root = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            if not root:
                raise RuntimeError("Could not create or locate cache dir")
        self.cache_dir = os.path.join(root, self.key)
        # Only the dump and regular cache modes create the directory and get a
        # lock path; override mode only ever looks files up.
        if dump or not override:
            self.lock_path = os.path.join(self.cache_dir, "lock")
            # create cache directory if it doesn't exist
            os.makedirs(self.cache_dir, exist_ok=True)

    def _make_path(self, filename) -> str:
        """Return the path of ``filename`` inside this manager's directory."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        """Return whether ``filename`` currently exists in the cache directory."""
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        """Return the path of ``filename`` if it exists, else ``None``."""
        if self.has_file(filename):
            return self._make_path(filename)
        return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Load the group file ``__grp__<filename>``.

        Returns ``None`` when the group file is missing or has no
        ``"child_paths"`` entry; otherwise returns the recorded name-to-path
        mapping restricted to paths that still exist on disk.
        """
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        with open(self._make_path(grp_filename)) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        return {name: path for name, path in child_paths.items() if os.path.exists(path)}

    # Note a group of pushed files as being part of a group
    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        """Write ``group`` as JSON to ``__grp__<filename>`` and return its path."""
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    @staticmethod
    def _discard_temp(temp_path, temp_dir):
        """Best-effort removal of the staging file and its directory."""
        for remove, target in ((os.remove, temp_path), (os.rmdir, temp_dir)):
            try:
                remove(target)
            except OSError:
                # Already gone (the normal case for the file after a successful
                # replace) or not removable right now; a leftover staging
                # entry must not turn a finished write into a failure.
                pass

    def put(self, data, filename, binary=True) -> str:
        """Write ``data`` to ``filename`` in the cache directory and return its path.

        ``bytes`` are written in binary mode; anything else is converted with
        ``str()`` and written in text mode. The ``binary`` argument is
        accepted for interface compatibility but the mode is always derived
        from the type of ``data``.

        The data is first written to a uniquely named staging directory and
        then moved into place with ``os.replace``. The staging file and
        directory are cleaned up whether or not the write succeeds.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = str(uuid.uuid4())
        # we use the PID in case a bunch of these around so we can see what PID made it
        pid = os.getpid()
        # use temp dir to be robust against program interruptions
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)

        mode = "wb" if binary else "w"
        try:
            with open(temp_path, mode) as f:
                f.write(data)
            # Replace is guaranteed to be atomic on POSIX systems if it succeeds
            # so filepath cannot see a partial write
            os.replace(temp_path, filepath)
        finally:
            # Remove only the staging entries. os.rmdir never touches parent
            # directories, unlike os.removedirs, which keeps pruning empty
            # ancestors.
            self._discard_temp(temp_path, temp_dir)
        return filepath
137
+
138
+
139
class RemoteCacheBackend:
    """
    A backend implementation for accessing a remote/distributed cache.

    Subclasses are expected to provide ``get`` and ``put``. Note that this
    class does not derive from ``ABC``, so the ``@abstractmethod`` markers are
    not enforced at instantiation time; the base methods simply return ``None``.
    """

    def __init__(self, key: str):
        """Accept the cache ``key``; the base class stores nothing."""

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch the entries named in ``filenames``, keyed by filename."""

    @abstractmethod
    def put(self, filename: str, data: bytes):
        """Store ``data`` under ``filename``."""
154
+
155
+
156
class RedisRemoteCacheBackend(RemoteCacheBackend):
    """Remote cache backend that stores entries in Redis.

    Configuration is read from the environment:

    * ``TRITON_REDIS_HOST`` (default ``"localhost"``)
    * ``TRITON_REDIS_PORT`` (default ``6379``)
    * ``TRITON_REDIS_KEY_FORMAT`` (default ``"triton:{key}:{filename}"``),
      a ``str.format`` template receiving ``key`` and ``filename``.
    """

    def __init__(self, key):
        # Imported lazily so the `redis` package is only required when this
        # backend is actually used.
        import redis
        self._key = key
        self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
        self._redis = redis.Redis(
            host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
        )

    def _get_key(self, filename: str) -> str:
        """Return the Redis key used for ``filename`` under this cache key."""
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Return the stored values for ``filenames``, omitting missing entries.

        The mapping is keyed by the requested filename (not the Redis key).
        """
        if not filenames:
            # Nothing requested: the answer is empty, so skip the Redis call.
            return {}
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes) -> None:
        """Store ``data`` under the Redis key derived from ``filename``."""
        self._redis.set(self._get_key(filename), data)
176
+
177
+
178
class RemoteCacheManager(CacheManager):
    """Cache manager backed by a remote store, with files materialized locally.

    The remote backend class is named by the ``TRITON_REMOTE_CACHE_BACKEND``
    environment variable in ``"module.path:ClassName"`` form. Data fetched
    from (or pushed to) the backend is also written through a
    ``FileCacheManager`` so callers receive local file paths. In dump or
    override mode every operation is delegated to that file manager alone.
    """

    def __init__(self, key, override=False, dump=False):
        # Set up the backend pointed to by `TRITON_REMOTE_CACHE_BACKEND`.
        backend_spec = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
        module_path, clz_nme = backend_spec.split(":")
        backend_cls = getattr(importlib.import_module(module_path), clz_nme)
        self._backend = backend_cls(key)

        self._override = override
        self._dump = dump

        # Use a `FileCacheManager` to materialize remote cache paths locally.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _materialize(self, filename: str, data: bytes):
        """Write ``data`` into the local file cache and return the local path."""
        # We use a backing `FileCacheManager` to provide the materialized data.
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        """Return a local path for ``filename``, or ``None`` if the backend lacks it."""
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_file(filename)

        # We always check the remote cache backend -- even if our internal file-
        # based cache has the item -- to make sure LRU accounting works as
        # expected.
        fetched = self._backend.get([filename])
        if len(fetched) == 0:
            return None
        # Exactly one entry is expected; unpacking fails otherwise.
        [(_, payload)] = fetched.items()
        return self._materialize(filename, payload)

    def put(self, data, filename: str, binary=True) -> str:
        """Push ``data`` to the backend and return the locally materialized path.

        Non-``bytes`` data is converted with ``str()`` and UTF-8 encoded first.
        """
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put(data, filename, binary=binary)

        payload = data if isinstance(data, bytes) else str(data).encode("utf-8")
        self._backend.put(filename, payload)
        return self._materialize(filename, payload)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return a name-to-local-path mapping for the group ``filename``.

        Returns ``None`` when the group file is unavailable or has no
        ``"child_paths"`` entry.
        """
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_group(filename)

        grp_filepath = self.get_file(f"__grp__{filename}")
        if grp_filepath is None:
            return None
        with open(grp_filepath) as f:
            child_paths = json.load(f).get("child_paths", None)

        # No group data recorded.
        if child_paths is None:
            return None

        # Found group data: fetch every child and materialize it locally.
        fetched = self._backend.get(child_paths)
        return {child: self._materialize(child, blob) for child, blob in fetched.items()}

    def put_group(self, filename: str, group: Dict[str, str]):
        """Record the sorted names of ``group`` under ``__grp__<filename>``."""
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put_group(filename, group)

        # Only the child names are stored remotely; local paths are
        # re-created on retrieval.
        grp_contents = json.dumps({"child_paths": sorted(group)})
        return self.put(grp_contents, f"__grp__{filename}")
253
+
254
+
255
# Label identifying which manager class is currently selected. It is compared
# against the TRITON_CACHE_MANAGER value in get_cache_manager().
__cache_cls_nme = "DEFAULT"
# Class instantiated by get_cache_manager(), get_override_manager() and
# get_dump_manager(). It is replaced by get_cache_manager() when a different
# manager is requested.
__cache_cls = FileCacheManager
257
+
258
+
259
def _base64(key):
    """Re-encode the hex string ``key`` as unpadded URL-safe base64 text."""
    # Assume key is a hex string.
    raw = bytes.fromhex(key)
    encoded = base64.urlsafe_b64encode(raw).decode("utf-8")
    return encoded.rstrip("=")
262
+
263
+
264
def get_cache_manager(key) -> CacheManager:
    """Return a cache manager instance for the hex string ``key``.

    When ``TRITON_CACHE_MANAGER`` is set (as ``"module.path:ClassName"``) and
    differs from the currently selected manager, that class is imported and
    becomes the module-wide manager class before the instance is created.
    """
    global __cache_cls
    global __cache_cls_nme

    requested = os.environ.get("TRITON_CACHE_MANAGER", None)
    needs_switch = requested is not None and requested != __cache_cls_nme
    if needs_switch:
        module_path, clz_nme = requested.split(":")
        __cache_cls = getattr(importlib.import_module(module_path), clz_nme)
        __cache_cls_nme = requested

    return __cache_cls(_base64(key))
278
+
279
+
280
def get_override_manager(key) -> CacheManager:
    """Return an override-mode manager of the currently selected class for ``key``."""
    encoded_key = _base64(key)
    return __cache_cls(encoded_key, override=True)
282
+
283
+
284
def get_dump_manager(key) -> CacheManager:
    """Return a dump-mode manager of the currently selected class for ``key``."""
    encoded_key = _base64(key)
    return __cache_cls(encoded_key, dump=True)
286
+
287
+
288
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    """Build the cache key for a compiled artifact.

    The key text is ``version_hash``, the concatenated signature values (with
    every value starting with ``*`` replaced by ``ptr``), ``constants``,
    ``ids`` and each extra keyword value in call order, joined by ``-``. The
    SHA-256 hex digest of that text is returned in the ``_base64`` encoding.
    """
    # Get unique key for the compiled code
    normalized = {name: ('ptr' if ty[0] == '*' else ty) for name, ty in signature.items()}
    sig_text = ''.join(normalized.values())
    key = f"{version_hash}-{sig_text}-{constants}-{ids}"
    for extra in kwargs.values():
        key = f"{key}-{extra}"
    digest = hashlib.sha256(key.encode("utf-8")).hexdigest()
    return _base64(digest)
@@ -0,0 +1,60 @@
1
+ from ..backends import backends
2
+ from ..backends import DriverBase
3
+
4
+
5
def _create_driver():
    """Instantiate the single active backend driver.

    Raises:
        RuntimeError: if the number of backends whose driver reports
            ``is_active()`` is not exactly one.
    """
    actives = []
    for backend in backends.values():
        if backend.driver.is_active():
            actives.append(backend.driver)
    if len(actives) != 1:
        raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.")
    (driver_cls, ) = actives
    return driver_cls()
10
+
11
+
12
class LazyProxy:
    """Proxy that creates its target object on first use.

    ``init_fn`` is called (with no arguments) the first time an attribute is
    read, set or deleted through the proxy, or when ``str()`` is taken. After
    that, attribute access is forwarded to the created object. ``repr()``
    never triggers creation.
    """

    def __init__(self, init_fn):
        self._init_fn = init_fn
        self._obj = None

    def _initialize_obj(self):
        """Create the target object if it has not been created yet."""
        if self._obj is None:
            self._obj = self._init_fn()

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If that happens for
        # the proxy's own slots (an instance built via __new__ without
        # __init__), forwarding would re-enter this method forever, so report
        # the attribute as missing instead.
        if name in ("_init_fn", "_obj"):
            raise AttributeError(name)
        self._initialize_obj()
        return getattr(self._obj, name)

    def __setattr__(self, name, value):
        # The two bookkeeping slots live on the proxy; everything else is set
        # on the target object.
        if name in ["_init_fn", "_obj"]:
            super().__setattr__(name, value)
        else:
            self._initialize_obj()
            setattr(self._obj, name, value)

    def __delattr__(self, name):
        self._initialize_obj()
        delattr(self._obj, name)

    def __repr__(self):
        if self._obj is None:
            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
        return repr(self._obj)

    def __str__(self):
        self._initialize_obj()
        return str(self._obj)
45
+
46
+
47
class DriverConfig:
    """Holds the default driver and the driver currently in use.

    ``default`` is a ``LazyProxy`` around ``_create_driver``. ``active``
    starts out as that same proxy and can be swapped with ``set_active`` /
    restored with ``reset_active``.
    """

    def __init__(self):
        lazy_default = LazyProxy(_create_driver)
        self.default = lazy_default
        self.active = lazy_default

    def set_active(self, driver: DriverBase):
        """Make ``driver`` the active driver."""
        self.active = driver

    def reset_active(self):
        """Make the default driver active again."""
        self.active = self.default


# Module-level driver configuration instance.
driver = DriverConfig()
@@ -0,0 +1,26 @@
1
+ from ..errors import TritonError
2
+ from typing import Optional
3
+
4
+
5
class InterpreterError(TritonError):
    """Error carrying an optional message; ``str()`` yields the message or ``""``."""

    def __init__(self, error_message: Optional[str] = None):
        self.error_message = error_message

    def __str__(self) -> str:
        message = self.error_message
        if message:
            return message
        return ""
12
+
13
+
14
class OutOfResources(TritonError):
    """Error reporting that a named resource requirement exceeds a limit.

    Attributes:
        required: the amount requested.
        limit: the limit that was exceeded.
        name: the name of the resource.
    """

    def __init__(self, required, limit, name):
        self.required, self.limit, self.name = required, limit, name

    def __str__(self) -> str:
        return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help."

    def __reduce__(self):
        # Pickle support: rebuild the exception from its three constructor
        # arguments. (The upstream comment says this is needed to make
        # CompilationError picklable; that relationship is not visible here.)
        ctor_args = (self.required, self.limit, self.name)
        return (type(self), ctor_args)