quack_kernels-0.2.1-py3-none-any.whl → quack_kernels-0.2.2-py3-none-any.whl
- quack/__init__.py +1 -1
- quack/autotuner.py +64 -5
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/rmsnorm.py +83 -149
- quack/tile_scheduler.py +34 -47
- quack/utils.py +61 -8
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +2 -2
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/RECORD +18 -18
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/autotuner.py
CHANGED
@@ -11,7 +11,7 @@ import hashlib
 import json
 from pathlib import Path
 from functools import cached_property, partial
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List, Optional, Any
 
 import torch
 from torch import Tensor
@@ -53,7 +53,22 @@ def _base32(key):
 
 
 class Autotuner:
-    def __init__(self, fn, key, configs, restore_value=None, do_bench=None, cache_results=False):
+    def __init__(
+        self,
+        fn,
+        key,
+        configs,
+        restore_value=None,
+        prune_configs_by: Optional[Dict] = None,
+        do_bench=None,
+        cache_results=False,
+    ):
+        """
+        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+            'perf_model': performance model used to predicate running time with different configs, returns running time
+            'top_k': number of configs to bench
+            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
+        """
         if not configs:
             self.configs = [AutotuneConfig()]
         else:
@@ -90,6 +105,16 @@ class Autotuner:
         else:
             self.post_hook = None
 
+        self.perf_model = None
+        self.configs_top_k = 1.0
+        self.early_config_prune = None
+        if prune_configs_by:
+            self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
+            self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
+            self.early_config_prune = prune_configs_by.get(
+                "early_config_prune", self.early_config_prune
+            )
+
         self.fn = fn
         self._do_bench = do_bench
 
@@ -198,13 +223,14 @@ class Autotuner:
         key = tuple(key)
         if key not in self.cache:
            used_cached_result = False
+            pruned_configs = self.prune_configs(kwargs)
 
            @torch.compiler.disable  # Don't want any tracing here
            def benchmark():
                bench_start = time.time()
                timings = {
                    config: self._bench(*args, config=config, **kwargs)
-                    for config in self.configs
+                    for config in pruned_configs
                }
                bench_end = time.time()
                if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
@@ -215,7 +241,7 @@ class Autotuner:
                self.configs_timings = timings
 
            if self.cache_results:
-                self.check_disk_cache(key, self.configs, benchmark)
+                self.check_disk_cache(key, pruned_configs, benchmark)
            else:
                benchmark()
 
@@ -239,6 +265,32 @@ class Autotuner:
        self.nargs = None
        return ret
 
+    def prune_configs(self, kwargs: Dict) -> List[Any]:
+        pruned_configs = self.configs
+        if self.early_config_prune:
+            pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+        if self.perf_model:
+            top_k = self.configs_top_k
+            if isinstance(top_k, float) and top_k <= 1.0:
+                top_k = int(len(self.configs) * top_k)
+            elif not isinstance(top_k, int):
+                # Slice index must be an integer
+                raise TypeError(
+                    "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
+                )
+
+            if len(pruned_configs) > top_k:
+                est_timing = {
+                    config: self.perf_model(
+                        **self.nargs,
+                        **kwargs,
+                        **config.all_kwargs(),
+                    )
+                    for config in pruned_configs
+                }
+                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+        return pruned_configs
+
 
 class AutotuneConfig:
     """
@@ -272,7 +324,9 @@ class AutotuneConfig:
        return self_tuple == other_tuple
 
 
-def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results=True):
+def autotune(
+    configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
+):
    f"""
    Decorator for auto-tuning a function function.
 
@@ -286,6 +340,10 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
    :type configs: list[AutotuneConfig]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
+    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+        'perf_model': performance model used to predicate running time with different configs, returns running time
+        'top_k': number of configs to bench
+        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
    :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
    :type restore_value: list[str]
    :param do_bench: a benchmark function to measure the time of each run.
@@ -303,6 +361,7 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
            key,
            configs,
            restore_value=restore_value,
+            prune_configs_by=prune_configs_by,
            do_bench=do_bench,
            cache_results=cache_results,
        )
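The new `prune_configs_by` hook mirrors Triton's autotuner API: an optional `perf_model` estimates each config's runtime from the call's arguments, only the `top_k` best-ranked configs (an int, or a float fraction of the full config list) are actually benchmarked, and an optional `early_config_prune` callback can discard configs outright before ranking. A minimal sketch of how a caller might wire this up — the call signatures for `perf_model` and `early_config_prune` come from the diff above, but the `AutotuneConfig({...})` constructor form, its `.kwargs` attribute, and the kernel itself are assumptions:

    from quack.autotuner import autotune, AutotuneConfig

    # Hypothetical perf model. The autotuner invokes it as
    # perf_model(**nargs, **kwargs, **config.all_kwargs()), so it sees the tuned
    # function's arguments alongside each config's tuning knobs.
    def rough_perf_model(M=None, N=None, tile_m=None, tile_n=None, **kwargs):
        return (M // tile_m) * (N // tile_n)  # lower estimate == assumed faster

    def drop_oversized(configs, nargs, **kwargs):
        # early_config_prune is invoked as fn(configs, nargs, **kwargs) and must
        # return the surviving configs.
        return [c for c in configs if c.kwargs["tile_m"] <= nargs["M"]]

    @autotune(
        configs=[
            AutotuneConfig({"tile_m": m, "tile_n": n})  # constructor form assumed
            for m in (64, 128)
            for n in (64, 128)
        ],
        key=["M", "N"],
        prune_configs_by={
            "perf_model": rough_perf_model,
            "top_k": 0.5,  # a float <= 1.0 is a fraction of the full config list
            "early_config_prune": drop_oversized,
        },
    )
    def my_kernel(M, N, tile_m=64, tile_n=64):
        ...  # hypothetical kernel body

With the four configs above and `top_k=0.5`, `prune_configs` would rank the survivors of `drop_oversized` with `rough_perf_model` and benchmark only `int(4 * 0.5) = 2` of them. Note the fraction is taken of `len(self.configs)`, the full config list, not of the early-pruned subset.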
quack/cute_dsl_utils.py
CHANGED
@@ -98,22 +98,21 @@ class ArgumentsBase(JitArgument):
 
 
 def load_cubin_module_data_patched(cubin_data, filepath):
-    path = pathlib.Path(filepath)
-    path.write_bytes(cubin_data)
+    pathlib.Path(filepath).write_bytes(cubin_data)
    return load_cubin_module_data_og(cubin_data)
 
 
 def cute_compile_patched(*args, **kwargs):
    """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set."""
-    if os.getenv("CUTE_CUBIN_PATH", None) is not None:
+    cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
+    if cubin_path is not None:
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
-            load_cubin_module_data_patched, filepath=os.getenv("CUTE_CUBIN_PATH")
+            load_cubin_module_data_patched, filepath=cubin_path
        )
    output = cute_compile_og(*args, **kwargs)
-    if os.getenv("CUTE_CUBIN_PATH", None) is not None:
+    if cubin_path is not None:
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
        if extract is not None:
-            cubin_path = pathlib.Path(os.getenv("CUTE_CUBIN_PATH"))
            sass = extract(cubin_path, None)
-            cubin_path.with_suffix(".annotated.sass").write_text(sass)
+            pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
    return output
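Net effect of this refactor: `CUTE_CUBIN_PATH` is read into a local `cubin_path` once per compilation instead of being re-queried from the environment at every step, and the intermediate `path` variable disappears. The workflow is unchanged: point the variable at a file before compiling, and the patched `cute.compile` temporarily swaps in the CUBIN-dumping loader, compiles, restores the original loader, and writes an annotated SASS dump when the `extract` hook is available. A sketch of the assumed usage (the kernel trigger is left abstract):

    import os
    import pathlib

    # Assumption: CUTE_CUBIN_PATH must be set before any quack kernel compiles,
    # since cute_compile_patched reads it at compile time.
    os.environ["CUTE_CUBIN_PATH"] = "/tmp/kernel.cubin"

    # ... import and run a quack kernel here so that cute.compile executes ...

    cubin = pathlib.Path(os.environ["CUTE_CUBIN_PATH"])
    # The patched loader writes the raw CUBIN to this path and, when the SASS
    # `extract` hook is not None, an annotated disassembly next to it:
    annotated = cubin.with_suffix(".annotated.sass")
    print(cubin, annotated)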