triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
triton/backends/nvidia/driver.py
CHANGED
@@ -1,11 +1,13 @@
 import functools
 import os
+import sysconfig
 import hashlib
 import subprocess
 import tempfile
 from pathlib import Path
 from triton.runtime.build import _build
 from triton.runtime.cache import get_cache_manager
+from triton.runtime import _allocation
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver
 
@@ -53,14 +55,17 @@ def library_dirs():
     return [libdevice_dir, *libcuda_dirs()]
 
 
+@functools.lru_cache()
+def platform_key():
+    from platform import machine, system, architecture
+    return ",".join([machine(), system(), *architecture()])
+
+
 def compile_module_from_src(src, name):
-    key = hashlib.sha256(src.encode("utf-8")).hexdigest()
+    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
     cache = get_cache_manager(key)
-
-
-    else:
-        so_name = f"{name}.so"
-    cache_path = cache.get_file(so_name)
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    cache_path = cache.get_file(f"{name}.{ext}")
     if cache_path is None:
         with tempfile.TemporaryDirectory() as tmpdir:
             src_path = os.path.join(tmpdir, f"{name}.c")
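The reworked cache key folds the host platform into the hash and derives the module filename from the interpreter's native extension suffix, so a launcher compiled for one OS/ABI (e.g. a Windows .pyd) is never reused on another. A minimal sketch of the same key computation, with a stand-in source string (the real src is the generated C glue code):

import hashlib
import sysconfig
from platform import machine, system, architecture

src = "/* stand-in for the generated launcher source */"

# Same recipe as above: mix a platform fingerprint into the cache key.
platform_key = ",".join([machine(), system(), *architecture()])
key = hashlib.sha256((src + platform_key).encode("utf-8")).hexdigest()

# EXT_SUFFIX ends with the native extension suffix, e.g. ".pyd" on Windows
# or ".so" on Linux; only its last dot-separated component is kept.
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
print(key, f"__triton_launcher.{ext}")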
@@ -68,7 +73,7 @@ def compile_module_from_src(src, name):
             f.write(src)
         so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
         with open(so, "rb") as f:
-            cache_path = cache.put(f.read(),
+            cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
     import importlib.util
     spec = importlib.util.spec_from_file_location(name, cache_path)
     mod = importlib.util.module_from_spec(spec)
@@ -126,22 +131,32 @@ def ty_to_cpp(ty):
     }[ty]
 
 
-def make_launcher(constants, signature
-
-
-
+def make_launcher(constants, signature):
+
+    def _serialize_signature(sig):
+        if isinstance(sig, tuple):
+            return ','.join(map(_serialize_signature, sig))
+        return sig
 
     def _extracted_type(ty):
+        if isinstance(ty, tuple):
+            val = ','.join(map(_extracted_type, ty))
+            return f"[{val}]"
         if ty[0] == '*':
             return "PyObject*"
-        if ty
+        if ty in ("constexpr", "nvTmaDesc"):
             return "PyObject*"
-
         return ty_to_cpp(ty)
 
     def format_of(ty):
+        if isinstance(ty, tuple):
+            val = ''.join(map(format_of, ty))
+            return f"({val})"
+        if ty[0] == '*':
+            return "O"
+        if ty in ("constexpr", "nvTmaDesc"):
+            return "O"
         return {
-            "PyObject*": "O",
             "float": "f",
             "double": "d",
             "long": "l",
@@ -153,12 +168,17 @@ def make_launcher(constants, signature, ids):
             "uint16_t": "H",
             "uint32_t": "I",
             "uint64_t": "K",
-        }[ty]
+        }[ty_to_cpp(ty)]
 
-    args_format = ''.join([format_of(
-    format = "
+    args_format = ''.join([format_of(ty) for ty in signature.values()])
+    format = "iiiKKpOOOOO" + args_format
+    signature = ','.join(map(_serialize_signature, signature.values()))
+    signature = list(filter(bool, signature.split(',')))
+    signature = {i: s for i, s in enumerate(signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
-
+    # Record the end of regular arguments;
+    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
+    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
     internal_args_list = []
     for i, ty in signature.items():
         if ty[0] == "*":
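make_launcher now accepts nested (tuple-typed) entries in the signature: _serialize_signature flattens them into a comma-separated list before the per-argument bookkeeping, format_of wraps their PyArg_ParseTuple codes in parentheses, and constexpr entries survive the flattening but are dropped from the C argument list. A small sketch of the flattening step with a hypothetical signature (the type strings are illustrative):

# Flatten a signature containing a nested tuple argument, the same way the
# launcher does before numbering arguments.
def _serialize_signature(sig):
    if isinstance(sig, tuple):
        return ','.join(map(_serialize_signature, sig))
    return sig

signature = {0: "*fp32", 1: ("i32", "i32"), 2: "constexpr"}  # hypothetical
flat = ','.join(map(_serialize_signature, signature.values()))
flat = {i: s for i, s in enumerate(filter(bool, flat.split(',')))}
print(flat)  # {0: '*fp32', 1: 'i32', 2: 'i32', 3: 'constexpr'}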
@@ -166,16 +186,23 @@ def make_launcher(constants, signature, ids):
         elif ty == "nvTmaDesc":
             # Note: we have to dereference the pointer
             internal_args_list.append(f"*tma_ptr{i}")
-
+        elif ty != "constexpr":
             internal_args_list.append(f"_arg{i}")
+    params = range(len(signature))
 
     # generate glue code
-
-
-
-
-
-
+    newline = '\n '
+    ptr_decls = [
+        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
+        for i, ty in signature.items()
+        if ty[0] == "*"
+    ]
+    tma_decls = [
+        f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
+        if ty == "nvTmaDesc"
+    ]
+    params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
+    params.append("&global_scratch")
     src = f"""
 #include \"cuda.h\"
 #include <stdbool.h>
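To make the glue-code builders above concrete, here is an illustrative run over a hypothetical three-argument signature. The ty_to_cpp mapping is stubbed with the values that function is assumed to return for these types ("*fp32" → CUdeviceptr, "i32" → int32_t), so treat the exact C type names as assumptions:

ty_to_cpp = {"*fp32": "CUdeviceptr", "i32": "int32_t"}.get  # stub for illustration
signature = {0: "*fp32", 1: "i32", 2: "constexpr"}          # hypothetical

arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] + ["&global_scratch"]
print(arg_decls)  # CUdeviceptr arg0, int32_t arg1
print(params)     # ['&arg0', '&arg1', '&global_scratch']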
@@ -248,19 +275,50 @@ static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
 }}
 #endif
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
-  {
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+  void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0) {{
-    if (num_ctas == 1) {{
+    if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{
       CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
+    }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{
+      CUlaunchAttribute launchAttr[1];
+      CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
+      launchAttr[0] = coopAttr;
+
+      CUlaunchConfig config;
+      config.gridDimX = gridX;
+      config.gridDimY = gridY;
+      config.gridDimZ = gridZ;
+      config.blockDimX = 32 * num_warps;
+      config.blockDimY = 1;
+      config.blockDimZ = 1;
+      config.sharedMemBytes = shared_memory;
+      config.hStream = stream;
+      config.attrs = launchAttr;
+      config.numAttrs = 1;
+
+      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
+      if (cuLaunchKernelExHandle == NULL) {{
+        cuLaunchKernelExHandle = getLaunchKernelExHandle();
+      }}
+      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
+
     }} else {{
-      CUlaunchAttribute launchAttr[
+      CUlaunchAttribute launchAttr[3];
       launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
       launchAttr[0].value.clusterDim.x = clusterDimX;
       launchAttr[0].value.clusterDim.y = clusterDimY;
       launchAttr[0].value.clusterDim.z = clusterDimZ;
       launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
       launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+
+      unsigned numAttrs = 2;
+      if (0 != launch_cooperative_grid) {{
+        CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
+        launchAttr[2] = coopAttr;
+        numAttrs = 3;
+      }}
+
       CUlaunchConfig config;
       config.gridDimX = gridX * clusterDimX;
       config.gridDimY = gridY * clusterDimY;
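When launch_cooperative_grid is set, the launcher now routes through cuLaunchKernelEx with CU_LAUNCH_ATTRIBUTE_COOPERATIVE even for a single CTA, instead of plain cuLaunchKernel. A hedged sketch of how this could look from user code, assuming the flag is exposed as a per-launch keyword named launch_cooperative_grid (mirroring the metadata field consumed by CudaLauncher further below); the keyword itself is an assumption, not a confirmed API:

import torch
import triton
import triton.language as tl

@triton.jit
def zero_kernel(out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    tl.store(out_ptr + offs, 0.0, mask=offs < n)

x = torch.empty(1024, device="cuda")
grid = (triton.cdiv(x.numel(), 256),)
# With the flag set, _launch takes the cuLaunchKernelEx path and attaches
# CU_LAUNCH_ATTRIBUTE_COOPERATIVE even when num_ctas == 1.
zero_kernel[grid](x, x.numel(), BLOCK=256, launch_cooperative_grid=True)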
@@ -271,7 +329,7 @@ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas
       config.sharedMemBytes = shared_memory;
       config.hStream = stream;
       config.attrs = launchAttr;
-      config.numAttrs =
+      config.numAttrs = numAttrs;
       static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
       if (cuLaunchKernelExHandle == NULL) {{
         cuLaunchKernelExHandle = getLaunchKernelExHandle();
@@ -396,14 +454,17 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
   int gridX, gridY, gridZ;
   uint64_t _stream;
   uint64_t _function;
+  int launch_cooperative_grid;
   PyObject *launch_enter_hook = NULL;
   PyObject *launch_exit_hook = NULL;
   PyObject *kernel_metadata = NULL;
   PyObject *launch_metadata = NULL;
-
-
+  PyObject *global_scratch_obj = NULL;
+  {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
+  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
+                       &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
                        &kernel_metadata, &launch_metadata,
-                       &launch_enter_hook, &launch_exit_hook
+                       &launch_enter_hook, &launch_exit_hook{args_list})) {{
     return NULL;
   }}
 
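For reference, the fixed prefix of the new format string "iiiKKpOOOOO" lines up with the parsed C variables as: iii → gridX, gridY, gridZ; KK → _stream, _function; p → launch_cooperative_grid (a Python bool parsed into an int); OOOOO → global_scratch_obj, kernel_metadata, launch_metadata, launch_enter_hook, launch_exit_hook. The per-argument codes from args_format are appended after this prefix.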
@@ -422,11 +483,20 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
     return NULL;
   }}
 
+  CUdeviceptr global_scratch = 0;
+  if (global_scratch_obj != Py_None) {{
+    DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1);
+    if (!global_scratch_info.valid) {{
+      return NULL;
+    }}
+    global_scratch = global_scratch_info.dev_ptr;
+  }}
+
   // raise exception asap
-  {
-  {
+  {newline.join(ptr_decls)}
+  {newline.join(tma_decls)}
   Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
   Py_END_ALLOW_THREADS;
   if (PyErr_Occurred()) {{
     return NULL;
@@ -441,9 +511,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
 
   }}
 
-
-  Py_INCREF(Py_None);
-  return Py_None;
+  Py_RETURN_NONE;
 }}
 
 static PyMethodDef ModuleMethods[] = {{
@@ -474,17 +542,25 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
 class CudaLauncher(object):
 
     def __init__(self, src, metadata):
-        ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else dict()
-
-        constants = {
-        signature = {
-        src = make_launcher(constants, signature
+        arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
+        constants = {arg_idx(idx): value for idx, value in constants.items()}
+        signature = {idx: value for idx, value in src.signature.items()}
+        src = make_launcher(constants, signature)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch
-
-
-        self.
+        self.global_scratch_size = metadata.global_scratch_size
+        self.global_scratch_align = metadata.global_scratch_align
+        self.launch_cooperative_grid = metadata.launch_cooperative_grid
+
+    def __call__(self, gridX, gridY, gridZ, stream, function, *args):
+        if self.global_scratch_size > 0:
+            grid_size = gridX * gridY * gridZ
+            alloc_size = grid_size * self.global_scratch_size
+            global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
+        else:
+            global_scratch = None
+        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
 
 
 class CudaDriver(GPUDriver):
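CudaLauncher.__call__ now sizes a per-launch global-scratch buffer (grid_size * global_scratch_size bytes, aligned to global_scratch_align) and requests it from triton.runtime._allocation. Only _allocation._allocator is visible in this diff, so the set_allocator registration hook in the sketch below is an assumption about the new _allocation module rather than a confirmed API:

import torch
from triton.runtime import _allocation

def torch_scratch_allocator(size: int, alignment: int, stream):
    # Return any object exposing data_ptr(); the launcher resolves it to a
    # CUdeviceptr via getPointer() before calling _launch.
    return torch.empty(size, dtype=torch.int8, device="cuda")

# Assumed registration hook; only _allocation._allocator appears in this diff.
_allocation.set_allocator(torch_scratch_allocator)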
@@ -501,14 +577,21 @@ class CudaDriver(GPUDriver):
         warp_size = 32
         return GPUTarget("cuda", capability, warp_size)
 
+    def get_active_torch_device(self):
+        import torch
+        return torch.device("cuda", self.get_current_device())
+
     def get_device_interface(self):
         import torch
         return torch.cuda
 
     @staticmethod
     def is_active():
-
-
+        try:
+            import torch
+            return torch.cuda.is_available() and (torch.version.hip is None)
+        except ImportError:
+            return False
 
     def get_benchmarker(self):
         from triton.testing import do_bench
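The new get_active_torch_device gives callers a ready-made torch.device for the active backend, alongside the stricter is_active check that now also requires torch.version.hip to be unset. A small usage sketch, assuming the active driver is reachable as triton.runtime.driver.active (the accessor pattern Triton uses elsewhere):

import triton

# e.g. device(type='cuda', index=0) on a CUDA machine
device = triton.runtime.driver.active.get_active_torch_device()
print(device)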
@@ -522,3 +605,6 @@ class CudaDriver(GPUDriver):
         # doesn't contain any input data before the run
         cache_size = 256 * 1024 * 1024
         return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
+
+    def clear_cache(self, cache):
+        cache.zero_()