triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic; see the package's registry page for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/runtime/autotuner.py
CHANGED
|
@@ -1,34 +1,26 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import builtins
|
|
4
|
-
import os
|
|
5
4
|
import time
|
|
6
5
|
import inspect
|
|
6
|
+
import hashlib
|
|
7
|
+
import json
|
|
8
|
+
from functools import cached_property
|
|
7
9
|
from typing import Dict, Tuple, List, Optional
|
|
8
10
|
|
|
9
|
-
from
|
|
11
|
+
from .. import knobs
|
|
12
|
+
from .jit import KernelInterface, JITFunction
|
|
10
13
|
from .errors import OutOfResources, PTXASError
|
|
11
14
|
from .driver import driver
|
|
15
|
+
from .cache import get_cache_manager, triton_key
|
|
16
|
+
from triton._C.libtriton import get_cache_invalidating_env_vars
|
|
12
17
|
|
|
13
18
|
|
|
14
19
|
class Autotuner(KernelInterface):
|
|
15
20
|
|
|
16
|
-
def __init__(
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
arg_names,
|
|
20
|
-
configs,
|
|
21
|
-
key,
|
|
22
|
-
reset_to_zero,
|
|
23
|
-
restore_value,
|
|
24
|
-
pre_hook=None,
|
|
25
|
-
post_hook=None,
|
|
26
|
-
prune_configs_by: Optional[Dict] = None,
|
|
27
|
-
warmup=None,
|
|
28
|
-
rep=None,
|
|
29
|
-
use_cuda_graph=False,
|
|
30
|
-
do_bench=None,
|
|
31
|
-
):
|
|
21
|
+
def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
|
|
22
|
+
prune_configs_by: Optional[Dict] = None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None,
|
|
23
|
+
cache_results=False):
|
|
32
24
|
"""
|
|
33
25
|
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
|
34
26
|
'perf_model': performance model used to predicate running time with different configs, returns running time
|
|
@@ -36,15 +28,13 @@ class Autotuner(KernelInterface):
|
|
|
36
28
|
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
|
|
37
29
|
"""
|
|
38
30
|
if not configs:
|
|
39
|
-
self.configs = [
|
|
40
|
-
Config({}, num_warps=4, num_stages=3, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
|
|
41
|
-
reg_dec_producer=0, reg_inc_consumer=0)
|
|
42
|
-
]
|
|
31
|
+
self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)]
|
|
43
32
|
else:
|
|
44
33
|
self.configs = configs
|
|
45
34
|
self.keys = key
|
|
46
35
|
self.cache: Dict[Tuple, Config] = {}
|
|
47
36
|
self.arg_names = arg_names
|
|
37
|
+
self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
|
|
48
38
|
|
|
49
39
|
# Reset to zero or restore values
|
|
50
40
|
self.reset_to_zero = []
|
|
@@ -97,6 +87,7 @@ class Autotuner(KernelInterface):
|
|
|
97
87
|
while not inspect.isfunction(self.base_fn):
|
|
98
88
|
self.base_fn = self.base_fn.fn
|
|
99
89
|
|
|
90
|
+
self._do_bench = do_bench
|
|
100
91
|
self.num_warmups = warmup
|
|
101
92
|
self.num_reps = rep
|
|
102
93
|
self.use_cuda_graph = use_cuda_graph
|
|
@@ -110,7 +101,7 @@ class Autotuner(KernelInterface):
|
|
|
110
101
|
stacklevel=1)
|
|
111
102
|
if use_cuda_graph:
|
|
112
103
|
from ..testing import do_bench_cudagraph
|
|
113
|
-
self.
|
|
104
|
+
self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
|
|
114
105
|
kernel_call,
|
|
115
106
|
rep=rep if rep is not None else 100,
|
|
116
107
|
quantiles=quantiles,
|
|
@@ -118,7 +109,7 @@ class Autotuner(KernelInterface):
|
|
|
118
109
|
return
|
|
119
110
|
|
|
120
111
|
import triton.testing
|
|
121
|
-
self.
|
|
112
|
+
self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
|
|
122
113
|
kernel_call,
|
|
123
114
|
warmup=warmup if warmup is not None else 25,
|
|
124
115
|
rep=rep if rep is not None else 100,
|
|
@@ -126,15 +117,16 @@ class Autotuner(KernelInterface):
|
|
|
126
117
|
)
|
|
127
118
|
return
|
|
128
119
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
120
|
+
@cached_property
|
|
121
|
+
def do_bench(self):
|
|
122
|
+
if self._do_bench is None:
|
|
123
|
+
return driver.active.get_benchmarker()
|
|
124
|
+
return self._do_bench
|
|
133
125
|
|
|
134
126
|
def _bench(self, *args, config, **meta):
|
|
135
127
|
from ..compiler.errors import CompileTimeAssertionFailure
|
|
136
128
|
|
|
137
|
-
verbose =
|
|
129
|
+
verbose = knobs.autotuning.print
|
|
138
130
|
if verbose:
|
|
139
131
|
print(f"Autotuning kernel {self.base_fn.__name__} with config {config}")
|
|
140
132
|
|
|
@@ -173,6 +165,48 @@ class Autotuner(KernelInterface):
|
|
|
173
165
|
print(f"Autotuning failed with {e}")
|
|
174
166
|
return [float("inf"), float("inf"), float("inf")]
|
|
175
167
|
|
|
168
|
+
def check_disk_cache(self, tuning_key, configs, bench_fn):
|
|
169
|
+
# We can't serialize prehooks, so just give up and run the benchmarks.
|
|
170
|
+
if not tuning_key or any(cfg.pre_hook for cfg in configs):
|
|
171
|
+
bench_fn()
|
|
172
|
+
return False
|
|
173
|
+
|
|
174
|
+
from triton.compiler.compiler import make_backend
|
|
175
|
+
|
|
176
|
+
fn = self.fn
|
|
177
|
+
while not isinstance(fn, JITFunction):
|
|
178
|
+
fn = fn.fn
|
|
179
|
+
|
|
180
|
+
env_vars = get_cache_invalidating_env_vars()
|
|
181
|
+
cache_key = [
|
|
182
|
+
triton_key(),
|
|
183
|
+
make_backend(driver.active.get_current_target()).hash(),
|
|
184
|
+
fn.cache_key,
|
|
185
|
+
str(sorted(env_vars.items())),
|
|
186
|
+
str(tuning_key),
|
|
187
|
+
] + [str(c) for c in configs]
|
|
188
|
+
cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
|
|
189
|
+
cache = get_cache_manager(cache_key)
|
|
190
|
+
file_name = f"{fn.__name__[:150]}.autotune.json"
|
|
191
|
+
path = cache.get_file(file_name)
|
|
192
|
+
if path:
|
|
193
|
+
with open(path, "r") as cached_configs:
|
|
194
|
+
timings = json.load(cached_configs)["configs_timings"]
|
|
195
|
+
timings = {Config(**config): timing for config, timing in timings}
|
|
196
|
+
self.cache[tuning_key] = builtins.min(timings, key=timings.get)
|
|
197
|
+
self.configs_timings = timings
|
|
198
|
+
return True
|
|
199
|
+
|
|
200
|
+
bench_fn()
|
|
201
|
+
cache.put(
|
|
202
|
+
json.dumps({
|
|
203
|
+
"key":
|
|
204
|
+
tuning_key,
|
|
205
|
+
"configs_timings":
|
|
206
|
+
[(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
|
|
207
|
+
}), file_name, binary=False)
|
|
208
|
+
return False
|
|
209
|
+
|
|
176
210
|
def run(self, *args, **kwargs):
|
|
177
211
|
self.nargs = dict(zip(self.arg_names, args))
|
|
178
212
|
used_cached_result = True
|
|
@@ -185,24 +219,31 @@ class Autotuner(KernelInterface):
|
|
|
185
219
|
key.append(str(arg.dtype))
|
|
186
220
|
key = tuple(key)
|
|
187
221
|
if key not in self.cache:
|
|
188
|
-
# prune configs
|
|
189
222
|
used_cached_result = False
|
|
190
223
|
pruned_configs = self.prune_configs(kwargs)
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
224
|
+
|
|
225
|
+
def benchmark():
|
|
226
|
+
bench_start = time.perf_counter()
|
|
227
|
+
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
|
|
228
|
+
bench_end = time.perf_counter()
|
|
229
|
+
self.bench_time = bench_end - bench_start
|
|
230
|
+
self.cache[key] = builtins.min(timings, key=timings.get)
|
|
231
|
+
full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
|
|
232
|
+
self.pre_hook(full_nargs, reset_only=True)
|
|
233
|
+
self.configs_timings = timings
|
|
234
|
+
|
|
235
|
+
if self.cache_results:
|
|
236
|
+
used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
|
|
237
|
+
else:
|
|
238
|
+
benchmark()
|
|
239
|
+
|
|
199
240
|
config = self.cache[key]
|
|
200
241
|
else:
|
|
201
242
|
config = self.configs[0]
|
|
202
243
|
self.best_config = config
|
|
203
|
-
if
|
|
204
|
-
print(f"Triton autotuning for function {self.base_fn.__name__}
|
|
205
|
-
f"{self.bench_time:.2f}s
|
|
244
|
+
if knobs.autotuning.print and not used_cached_result:
|
|
245
|
+
print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
|
|
246
|
+
f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
|
|
206
247
|
if config.pre_hook is not None:
|
|
207
248
|
full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
|
|
208
249
|
config.pre_hook(full_nargs)
|
|
@@ -241,11 +282,11 @@ class Autotuner(KernelInterface):
|
|
|
241
282
|
def warmup(self, *args, **kwargs):
|
|
242
283
|
self.nargs = dict(zip(self.arg_names, args))
|
|
243
284
|
ret = []
|
|
244
|
-
for
|
|
285
|
+
for autotune_config in self.prune_configs(kwargs):
|
|
245
286
|
ret.append(self.fn.warmup(
|
|
246
287
|
*args,
|
|
247
288
|
**kwargs,
|
|
248
|
-
**
|
|
289
|
+
**autotune_config.all_kwargs(),
|
|
249
290
|
))
|
|
250
291
|
self.nargs = None
|
|
251
292
|
return ret
|
|
@@ -263,27 +304,34 @@ class Config:
|
|
|
263
304
|
:type num_warps: int
|
|
264
305
|
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
|
|
265
306
|
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
|
|
266
|
-
:type
|
|
307
|
+
:type num_stages: int
|
|
267
308
|
:ivar num_ctas: number of blocks in a block cluster. SM90+ only.
|
|
309
|
+
:type num_ctas: int
|
|
268
310
|
:type maxnreg: Optional[int]
|
|
269
311
|
:ivar maxnreg: maximum number of registers one thread can use. Corresponds
|
|
270
312
|
to ptx .maxnreg directive. Not supported on all platforms.
|
|
271
313
|
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
|
|
272
314
|
function are args.
|
|
315
|
+
:ivar ir_override: filename of a user-defined IR (*.{ttgir|llir|ptx|amdgcn}).
|
|
273
316
|
"""
|
|
274
317
|
|
|
275
|
-
def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1,
|
|
276
|
-
reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None):
|
|
318
|
+
def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None, ir_override=None):
|
|
277
319
|
self.kwargs = kwargs
|
|
278
320
|
self.num_warps = num_warps
|
|
279
321
|
self.num_ctas = num_ctas
|
|
280
322
|
self.num_stages = num_stages
|
|
281
|
-
self.num_buffers_warp_spec = num_buffers_warp_spec
|
|
282
|
-
self.num_consumer_groups = num_consumer_groups
|
|
283
|
-
self.reg_dec_producer = reg_dec_producer
|
|
284
|
-
self.reg_inc_consumer = reg_inc_consumer
|
|
285
323
|
self.maxnreg = maxnreg
|
|
286
324
|
self.pre_hook = pre_hook
|
|
325
|
+
self.ir_override = ir_override
|
|
326
|
+
|
|
327
|
+
def __setstate__(self, state):
|
|
328
|
+
self.kwargs = state.get("kwargs", {})
|
|
329
|
+
self.num_warps = state.get("num_warps", 4)
|
|
330
|
+
self.num_stages = state.get("num_stages", 3)
|
|
331
|
+
self.num_ctas = state.get("num_ctas", 1)
|
|
332
|
+
self.maxnreg = state.get("maxnreg", None)
|
|
333
|
+
self.pre_hook = state.get("pre_hook", None)
|
|
334
|
+
self.ir_override = state.get("ir_override", None)
|
|
287
335
|
|
|
288
336
|
def all_kwargs(self):
|
|
289
337
|
return {
|
|
@@ -293,11 +341,8 @@ class Config:
|
|
|
293
341
|
("num_warps", self.num_warps),
|
|
294
342
|
("num_ctas", self.num_ctas),
|
|
295
343
|
("num_stages", self.num_stages),
|
|
296
|
-
("num_buffers_warp_spec", self.num_buffers_warp_spec),
|
|
297
|
-
("num_consumer_groups", self.num_consumer_groups),
|
|
298
|
-
("reg_dec_producer", self.reg_dec_producer),
|
|
299
|
-
("reg_inc_consumer", self.reg_inc_consumer),
|
|
300
344
|
("maxnreg", self.maxnreg),
|
|
345
|
+
("ir_override", self.ir_override),
|
|
301
346
|
) if v is not None
|
|
302
347
|
}
|
|
303
348
|
}
|
|
@@ -309,16 +354,26 @@ class Config:
|
|
|
309
354
|
res.append(f"num_warps: {self.num_warps}")
|
|
310
355
|
res.append(f"num_ctas: {self.num_ctas}")
|
|
311
356
|
res.append(f"num_stages: {self.num_stages}")
|
|
312
|
-
res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}")
|
|
313
|
-
res.append(f"num_consumer_groups: {self.num_consumer_groups}")
|
|
314
|
-
res.append(f"reg_dec_producer: {self.reg_dec_producer}")
|
|
315
|
-
res.append(f"reg_inc_consumer: {self.reg_inc_consumer}")
|
|
316
357
|
res.append(f"maxnreg: {self.maxnreg}")
|
|
317
358
|
return ", ".join(res)
|
|
318
359
|
|
|
360
|
+
def __hash__(self):
|
|
361
|
+
return hash((*self.all_kwargs().items(), self.pre_hook))
|
|
362
|
+
|
|
363
|
+
def __eq__(self, other):
|
|
364
|
+
self_tuple = tuple((
|
|
365
|
+
*self.all_kwargs().items(),
|
|
366
|
+
self.pre_hook,
|
|
367
|
+
))
|
|
368
|
+
other_tuple = tuple((
|
|
369
|
+
*other.all_kwargs().items(),
|
|
370
|
+
other.pre_hook,
|
|
371
|
+
))
|
|
372
|
+
return self_tuple == other_tuple
|
|
373
|
+
|
|
319
374
|
|
|
320
375
|
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
|
|
321
|
-
warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
|
|
376
|
+
warmup=None, rep=None, use_cuda_graph=False, do_bench=None, cache_results=False):
|
|
322
377
|
"""
|
|
323
378
|
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
|
324
379
|
|
|
@@ -372,12 +427,14 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
|
|
|
372
427
|
:type rep: int
|
|
373
428
|
:param do_bench: a benchmark function to measure the time of each run.
|
|
374
429
|
:type do_bench: lambda fn, quantiles
|
|
430
|
+
:param cache_results: whether to cache autotune timings to disk. Defaults to False.
|
|
431
|
+
"type cache_results: bool
|
|
375
432
|
"""
|
|
376
433
|
|
|
377
434
|
def decorator(fn):
|
|
378
435
|
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
|
|
379
436
|
post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
|
|
380
|
-
use_cuda_graph=use_cuda_graph, do_bench=do_bench)
|
|
437
|
+
use_cuda_graph=use_cuda_graph, do_bench=do_bench, cache_results=cache_results)
|
|
381
438
|
|
|
382
439
|
return decorator
|
|
383
440
|
|
triton/runtime/build.py
CHANGED
|
@@ -1,14 +1,25 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import functools
|
|
2
|
-
import
|
|
4
|
+
import hashlib
|
|
5
|
+
import importlib.util
|
|
6
|
+
import logging
|
|
3
7
|
import os
|
|
4
8
|
import shutil
|
|
5
9
|
import subprocess
|
|
10
|
+
import sysconfig
|
|
11
|
+
import tempfile
|
|
12
|
+
|
|
13
|
+
from types import ModuleType
|
|
14
|
+
|
|
15
|
+
from .cache import get_cache_manager
|
|
16
|
+
from .. import knobs
|
|
6
17
|
|
|
7
18
|
if os.name == "nt":
|
|
8
19
|
from triton.windows_utils import find_msvc_winsdk, find_python
|
|
9
20
|
|
|
10
21
|
|
|
11
|
-
@functools.
|
|
22
|
+
@functools.lru_cache
|
|
12
23
|
def get_cc():
|
|
13
24
|
cc = os.environ.get("CC")
|
|
14
25
|
if cc is None:
|
|
@@ -30,6 +41,11 @@ def get_cc():
|
|
|
30
41
|
return cc
|
|
31
42
|
|
|
32
43
|
|
|
44
|
+
def is_tcc(cc):
|
|
45
|
+
cc = os.path.basename(cc).lower()
|
|
46
|
+
return cc == "tcc" or cc == "tcc.exe"
|
|
47
|
+
|
|
48
|
+
|
|
33
49
|
def is_msvc(cc):
|
|
34
50
|
cc = os.path.basename(cc).lower()
|
|
35
51
|
return cc == "cl" or cc == "cl.exe"
|
|
@@ -40,10 +56,11 @@ def is_clang(cc):
|
|
|
40
56
|
return cc == "clang" or cc == "clang.exe"
|
|
41
57
|
|
|
42
58
|
|
|
43
|
-
def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries
|
|
59
|
+
def _cc_cmd(cc: str, src: str, out: str, include_dirs: list[str], library_dirs: list[str], libraries: list[str],
|
|
60
|
+
ccflags: list[str]) -> list[str]:
|
|
44
61
|
if is_msvc(cc):
|
|
45
62
|
out_base = os.path.splitext(out)[0]
|
|
46
|
-
cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/wd4819"]
|
|
63
|
+
cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/std:c11", "/wd4819"]
|
|
47
64
|
cc_cmd += [f"/I{dir}" for dir in include_dirs if dir is not None]
|
|
48
65
|
cc_cmd += [f"/Fo{out_base + '.obj'}"]
|
|
49
66
|
cc_cmd += ["/link"]
|
|
@@ -58,45 +75,94 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
|
|
|
58
75
|
if not (os.name == "nt" and is_clang(cc)):
|
|
59
76
|
# Clang does not support -fPIC on Windows
|
|
60
77
|
cc_cmd += ["-fPIC"]
|
|
78
|
+
if is_tcc(cc):
|
|
79
|
+
cc_cmd += ["-D_Py_USE_GCC_BUILTIN_ATOMICS"]
|
|
61
80
|
cc_cmd += [f'-l{lib}' for lib in libraries]
|
|
62
81
|
cc_cmd += [f"-L{dir}" for dir in library_dirs]
|
|
63
82
|
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
|
|
83
|
+
cc_cmd += ccflags
|
|
64
84
|
return cc_cmd
|
|
65
85
|
|
|
66
86
|
|
|
67
|
-
def _build(name, src, srcdir, library_dirs, include_dirs, libraries
|
|
87
|
+
def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
|
|
88
|
+
ccflags: list[str]) -> str:
|
|
89
|
+
if impl := knobs.build.impl:
|
|
90
|
+
return impl(name, src, srcdir, library_dirs, include_dirs, libraries)
|
|
68
91
|
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
|
69
92
|
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
|
|
70
|
-
# try to avoid setuptools if possible
|
|
71
93
|
cc = get_cc()
|
|
72
94
|
# This function was renamed and made public in Python 3.10
|
|
73
95
|
if hasattr(sysconfig, 'get_default_scheme'):
|
|
74
96
|
scheme = sysconfig.get_default_scheme()
|
|
75
97
|
else:
|
|
76
|
-
scheme = sysconfig._get_default_scheme()
|
|
98
|
+
scheme = sysconfig._get_default_scheme() # type: ignore
|
|
77
99
|
# 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
|
|
78
100
|
# path changes to include 'local'. This change is required to use triton with system-wide python.
|
|
79
101
|
if scheme == 'posix_local':
|
|
80
102
|
scheme = 'posix_prefix'
|
|
81
103
|
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
|
|
82
|
-
custom_backend_dirs =
|
|
104
|
+
custom_backend_dirs = knobs.build.backend_dirs
|
|
105
|
+
# Don't append in place
|
|
83
106
|
include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
|
|
84
107
|
if os.name == "nt":
|
|
85
|
-
library_dirs
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
libraries
|
|
108
|
+
library_dirs = library_dirs + find_python()
|
|
109
|
+
version = sysconfig.get_python_version().replace(".", "")
|
|
110
|
+
if sysconfig.get_config_var("Py_GIL_DISABLED"):
|
|
111
|
+
version += "t"
|
|
112
|
+
libraries = libraries + [f"python{version}"]
|
|
90
113
|
if is_msvc(cc):
|
|
91
114
|
_, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
|
|
92
|
-
include_dirs
|
|
93
|
-
library_dirs
|
|
94
|
-
cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
|
|
115
|
+
include_dirs = include_dirs + msvc_winsdk_inc_dirs
|
|
116
|
+
library_dirs = library_dirs + msvc_winsdk_lib_dirs
|
|
117
|
+
cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries, ccflags)
|
|
95
118
|
|
|
96
119
|
try:
|
|
97
|
-
|
|
120
|
+
subprocess.check_call(cc_cmd)
|
|
98
121
|
except Exception as e:
|
|
99
122
|
print("Failed to compile. cc_cmd:", cc_cmd)
|
|
100
123
|
raise e
|
|
101
124
|
|
|
102
125
|
return so
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@functools.lru_cache
|
|
129
|
+
def platform_key() -> str:
|
|
130
|
+
from platform import machine, system, architecture
|
|
131
|
+
return ",".join([machine(), system(), *architecture()])
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _load_module_from_path(name: str, path: str) -> ModuleType:
|
|
135
|
+
# Loading module with relative path may cause error
|
|
136
|
+
path = os.path.abspath(path)
|
|
137
|
+
spec = importlib.util.spec_from_file_location(name, path)
|
|
138
|
+
if not spec or not spec.loader:
|
|
139
|
+
raise RuntimeError(f"Failed to load newly compiled {name} from {path}")
|
|
140
|
+
mod = importlib.util.module_from_spec(spec)
|
|
141
|
+
spec.loader.exec_module(mod)
|
|
142
|
+
return mod
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
|
|
146
|
+
include_dirs: list[str] | None = None, libraries: list[str] | None = None,
|
|
147
|
+
ccflags: list[str] | None = None) -> ModuleType:
|
|
148
|
+
key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
|
|
149
|
+
cache = get_cache_manager(key)
|
|
150
|
+
suffix = sysconfig.get_config_var("EXT_SUFFIX")
|
|
151
|
+
cache_path = cache.get_file(f"{name}{suffix}")
|
|
152
|
+
|
|
153
|
+
if cache_path is not None:
|
|
154
|
+
try:
|
|
155
|
+
return _load_module_from_path(name, cache_path)
|
|
156
|
+
except (RuntimeError, ImportError):
|
|
157
|
+
log = logging.getLogger(__name__)
|
|
158
|
+
log.warning(f"Triton cache error: compiled module {name}.so could not be loaded")
|
|
159
|
+
|
|
160
|
+
with tempfile.TemporaryDirectory() as tmpdir:
|
|
161
|
+
src_path = os.path.join(tmpdir, name + ".c")
|
|
162
|
+
with open(src_path, "w") as f:
|
|
163
|
+
f.write(src)
|
|
164
|
+
so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
|
|
165
|
+
with open(so, "rb") as f:
|
|
166
|
+
cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
|
|
167
|
+
|
|
168
|
+
return _load_module_from_path(name, cache_path)
|
triton/runtime/cache.py
CHANGED
|
@@ -1,33 +1,19 @@
|
|
|
1
|
-
import importlib
|
|
2
1
|
import json
|
|
3
2
|
import os
|
|
4
3
|
import uuid
|
|
5
4
|
from abc import ABC, abstractmethod
|
|
6
|
-
from pathlib import Path
|
|
7
5
|
from typing import Dict, List, Optional
|
|
8
6
|
import base64
|
|
9
7
|
import hashlib
|
|
8
|
+
import functools
|
|
9
|
+
import sysconfig
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
def get_home_dir():
|
|
13
|
-
return os.getenv("TRITON_HOME", Path.home())
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def default_cache_dir():
|
|
17
|
-
return os.path.join(get_home_dir(), ".triton", "cache")
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def default_override_dir():
|
|
21
|
-
return os.path.join(get_home_dir(), ".triton", "override")
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def default_dump_dir():
|
|
25
|
-
return os.path.join(get_home_dir(), ".triton", "dump")
|
|
11
|
+
from triton import __version__, knobs
|
|
26
12
|
|
|
27
13
|
|
|
28
14
|
class CacheManager(ABC):
|
|
29
15
|
|
|
30
|
-
def __init__(self, key):
|
|
16
|
+
def __init__(self, key, override=False, dump=False):
|
|
31
17
|
pass
|
|
32
18
|
|
|
33
19
|
@abstractmethod
|
|
@@ -53,16 +39,16 @@ class FileCacheManager(CacheManager):
|
|
|
53
39
|
self.key = key
|
|
54
40
|
self.lock_path = None
|
|
55
41
|
if dump:
|
|
56
|
-
self.cache_dir =
|
|
42
|
+
self.cache_dir = knobs.cache.dump_dir
|
|
57
43
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
|
58
44
|
self.lock_path = os.path.join(self.cache_dir, "lock")
|
|
59
45
|
os.makedirs(self.cache_dir, exist_ok=True)
|
|
60
46
|
elif override:
|
|
61
|
-
self.cache_dir =
|
|
47
|
+
self.cache_dir = knobs.cache.override_dir
|
|
62
48
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
|
63
49
|
else:
|
|
64
50
|
# create cache directory if it doesn't exist
|
|
65
|
-
self.cache_dir =
|
|
51
|
+
self.cache_dir = knobs.cache.dir
|
|
66
52
|
if self.cache_dir:
|
|
67
53
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
|
68
54
|
self.lock_path = os.path.join(self.cache_dir, "lock")
|
|
@@ -166,10 +152,10 @@ class RedisRemoteCacheBackend(RemoteCacheBackend):
|
|
|
166
152
|
def __init__(self, key):
|
|
167
153
|
import redis
|
|
168
154
|
self._key = key
|
|
169
|
-
self._key_fmt =
|
|
155
|
+
self._key_fmt = knobs.cache.redis.key_format
|
|
170
156
|
self._redis = redis.Redis(
|
|
171
|
-
host=
|
|
172
|
-
port=
|
|
157
|
+
host=knobs.cache.redis.host,
|
|
158
|
+
port=knobs.cache.redis.port,
|
|
173
159
|
)
|
|
174
160
|
|
|
175
161
|
def _get_key(self, filename: str) -> str:
|
|
@@ -187,10 +173,10 @@ class RemoteCacheManager(CacheManager):
|
|
|
187
173
|
|
|
188
174
|
def __init__(self, key, override=False, dump=False):
|
|
189
175
|
# Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`.
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
176
|
+
remote_cache_cls = knobs.cache.remote_manager_class
|
|
177
|
+
if not remote_cache_cls:
|
|
178
|
+
raise RuntimeError(
|
|
179
|
+
"Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
|
|
194
180
|
self._backend = remote_cache_cls(key)
|
|
195
181
|
|
|
196
182
|
self._override = override
|
|
@@ -260,37 +246,24 @@ class RemoteCacheManager(CacheManager):
|
|
|
260
246
|
return self.put(grp_contents, grp_filename)
|
|
261
247
|
|
|
262
248
|
|
|
263
|
-
__cache_cls = FileCacheManager
|
|
264
|
-
__cache_cls_nme = "DEFAULT"
|
|
265
|
-
|
|
266
|
-
|
|
267
249
|
def _base32(key):
|
|
268
250
|
# Assume key is a hex string.
|
|
269
251
|
return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
|
|
270
252
|
|
|
271
253
|
|
|
272
254
|
def get_cache_manager(key) -> CacheManager:
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
|
|
276
|
-
global __cache_cls
|
|
277
|
-
global __cache_cls_nme
|
|
278
|
-
|
|
279
|
-
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
|
|
280
|
-
module_path, clz_nme = user_cache_manager.split(":")
|
|
281
|
-
module = importlib.import_module(module_path)
|
|
282
|
-
__cache_cls = getattr(module, clz_nme)
|
|
283
|
-
__cache_cls_nme = user_cache_manager
|
|
284
|
-
|
|
285
|
-
return __cache_cls(_base32(key))
|
|
255
|
+
cls = knobs.cache.manager_class or FileCacheManager
|
|
256
|
+
return cls(_base32(key))
|
|
286
257
|
|
|
287
258
|
|
|
288
259
|
def get_override_manager(key) -> CacheManager:
|
|
289
|
-
|
|
260
|
+
cls = knobs.cache.manager_class or FileCacheManager
|
|
261
|
+
return cls(_base32(key), override=True)
|
|
290
262
|
|
|
291
263
|
|
|
292
264
|
def get_dump_manager(key) -> CacheManager:
|
|
293
|
-
|
|
265
|
+
cls = knobs.cache.manager_class or FileCacheManager
|
|
266
|
+
return cls(_base32(key), dump=True)
|
|
294
267
|
|
|
295
268
|
|
|
296
269
|
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
|
|
@@ -301,3 +274,44 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
|
|
|
301
274
|
key = f"{key}-{kwargs.get(kw)}"
|
|
302
275
|
key = hashlib.sha256(key.encode("utf-8")).hexdigest()
|
|
303
276
|
return _base32(key)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
@functools.lru_cache()
|
|
280
|
+
def triton_key():
|
|
281
|
+
import pkgutil
|
|
282
|
+
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
283
|
+
contents = []
|
|
284
|
+
# frontend
|
|
285
|
+
with open(__file__, "rb") as f:
|
|
286
|
+
contents += [hashlib.sha256(f.read()).hexdigest()]
|
|
287
|
+
# compiler
|
|
288
|
+
path_prefixes = [
|
|
289
|
+
(os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
|
|
290
|
+
(os.path.join(TRITON_PATH, "backends"), "triton.backends."),
|
|
291
|
+
]
|
|
292
|
+
for path, prefix in path_prefixes:
|
|
293
|
+
for lib in pkgutil.walk_packages([path], prefix=prefix):
|
|
294
|
+
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
|
295
|
+
contents += [hashlib.sha256(f.read()).hexdigest()]
|
|
296
|
+
|
|
297
|
+
# backend
|
|
298
|
+
libtriton_hash = hashlib.sha256()
|
|
299
|
+
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
|
|
300
|
+
with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
|
|
301
|
+
while True:
|
|
302
|
+
chunk = f.read(1024**2)
|
|
303
|
+
if not chunk:
|
|
304
|
+
break
|
|
305
|
+
libtriton_hash.update(chunk)
|
|
306
|
+
contents.append(libtriton_hash.hexdigest())
|
|
307
|
+
# language
|
|
308
|
+
language_path = os.path.join(TRITON_PATH, 'language')
|
|
309
|
+
for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
|
|
310
|
+
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
|
311
|
+
contents += [hashlib.sha256(f.read()).hexdigest()]
|
|
312
|
+
return f'{__version__}' + '-'.join(contents)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def get_cache_key(src, backend, backend_options, env_vars):
|
|
316
|
+
key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}"
|
|
317
|
+
return key
|