quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +4 -1
- quack/autotuner.py +309 -0
- quack/cross_entropy.py +2 -5
- quack/cute_dsl_utils.py +40 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +2474 -0
- quack/fast_math.py +97 -0
- quack/gemm_config.py +61 -0
- quack/gemm_interface.py +321 -0
- quack/linear.py +176 -0
- quack/lse.py +62 -0
- quack/mlp.py +204 -0
- quack/pipeline.py +166 -0
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2088 -0
- quack/tensormap_manager.py +114 -0
- quack/tile_scheduler.py +935 -0
- quack/topk.py +221 -0
- quack/utils.py +237 -19
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/METADATA +3 -3
- quack_kernels-0.1.11.dist-info/RECORD +31 -0
- quack_kernels-0.1.9.dist-info/RECORD +0 -12
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
|
-
__version__ = "0.1.
|
|
1
|
+
__version__ = "0.1.11"
|
|
2
2
|
|
|
3
3
|
from quack.rmsnorm import rmsnorm
|
|
4
4
|
from quack.softmax import softmax
|
|
5
5
|
from quack.cross_entropy import cross_entropy
|
|
6
6
|
|
|
7
|
+
# ruff: noqa
|
|
8
|
+
import quack.cute_dsl_utils # Patch cute.compile to optionally dump SASS
|
|
9
|
+
|
|
7
10
|
__all__ = [
|
|
8
11
|
"rmsnorm",
|
|
9
12
|
"softmax",
|
quack/autotuner.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
# Adapted from https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py
|
|
2
|
+
# Copyright (C) 2025, Tri Dao.
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import builtins
|
|
6
|
+
import os
|
|
7
|
+
import time
|
|
8
|
+
import inspect
|
|
9
|
+
import base64
|
|
10
|
+
import hashlib
|
|
11
|
+
import json
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from functools import cached_property, partial
|
|
14
|
+
from typing import Dict, Tuple
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from torch import Tensor
|
|
18
|
+
|
|
19
|
+
import triton
|
|
20
|
+
|
|
21
|
+
from . import __version__
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Package name. Its upper-cased form prefixes the environment variables read in this
# module, and "." + PACKAGE_NAME names the directory used by default_cache_dir().
PACKAGE_NAME = "quack"
|
|
25
|
+
# Package version; it is part of the on-disk autotune cache key in check_disk_cache.
VERSION = __version__
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_home_dir():
    """Return the base directory under which the package stores its files.

    The environment variable ``<PACKAGE_NAME upper-cased>_HOME`` takes precedence;
    when it is unset the current user's home directory is used.

    :return: the environment value (a ``str``) if set, otherwise ``Path.home()``.
    """
    env_name = f"{PACKAGE_NAME.upper()}_HOME"
    fallback = Path.home()
    return os.environ.get(env_name, fallback)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def default_cache_dir():
    """Return the default cache location: ``<home>/.<PACKAGE_NAME>/cache``.

    ``<home>`` is whatever :func:`get_home_dir` returns.
    """
    home = get_home_dir()
    hidden_dir = f".{PACKAGE_NAME}"
    return os.path.join(home, hidden_dir, "cache")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class FileCacheManager(triton.runtime.cache.FileCacheManager):
    """Triton file cache manager redirected to this package's own cache directory.

    After the parent initializer has run, ``cache_dir`` is replaced by
    ``<root>/<key>``, where ``<root>`` is the stripped value of the
    ``<PACKAGE_NAME upper-cased>_CACHE_DIR`` environment variable or, when that is
    unset or blank, :func:`default_cache_dir`. The directory is created if missing
    and ``lock_path`` is pointed at a ``lock`` file inside it.
    """

    def __init__(self, key):
        """
        :param key: cache key; used as the name of the per-entry sub-directory.
        :raises RuntimeError: if no cache root could be determined.
        """
        super().__init__(key)
        env_root = os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_DIR", "").strip()
        self.cache_dir = env_root or default_cache_dir()
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        self.cache_dir = os.path.join(self.cache_dir, self.key)
        self.lock_path = os.path.join(self.cache_dir, "lock")
        os.makedirs(self.cache_dir, exist_ok=True)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _base32(key):
|
|
51
|
+
# Assume key is a hex string.
|
|
52
|
+
return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Autotuner:
    """Callable wrapper that runs ``fn`` with the fastest of several configs.

    On each call a tuning key is built from the values of the arguments named in
    ``key`` plus the shape, a coarsened stride pattern and the dtype of every
    ``torch.Tensor`` argument. The first time a key is seen, every config is
    benchmarked and the fastest one is stored in ``self.cache``; later calls with
    the same key reuse it. With a single config nothing is benchmarked.

    Environment variables (``<PREFIX>`` is ``PACKAGE_NAME.upper()``):

    * ``<PREFIX>_CACHE_AUTOTUNING=1``: persist timings to disk even if
      ``cache_results`` is False.
    * ``<PREFIX>_PRINT_AUTOTUNING=1``: print progress and results to stdout.
    * ``<PREFIX>_FORCE_CACHE_UPDATE``: any non-empty value makes an existing
      on-disk entry be ignored and re-benchmarked.
    """

    def __init__(self, fn, key, configs, restore_value=None, do_bench=None, cache_results=False):
        """
        :param fn: the function to tune; its signature provides the argument names.
        :param key: argument names whose values become part of the tuning key.
        :param configs: list of :code:`AutotuneConfig`; an empty/None value is
            replaced by a single config without kwargs.
        :param restore_value: names of tensor arguments that are cloned before and
            copied back after every benchmarked invocation.
        :param do_bench: benchmark callable taking ``(fn, quantiles=...)``; defaults
            to ``triton.testing.do_bench`` with ``warmup=5, rep=25``.
        :param cache_results: whether to persist timings to disk.
        """
        self.configs = configs if configs else [AutotuneConfig()]
        signature = inspect.signature(fn)
        self.keys = key
        self.cache: Dict[Tuple, AutotuneConfig] = {}
        self.arg_names = list(signature.parameters.keys())
        self.cache_results = (
            cache_results or os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_AUTOTUNING", None) == "1"
        )

        self.restore_value = [] if restore_value is None else list(restore_value)

        if self.restore_value:
            # Benchmarking runs fn repeatedly; snapshot the listed arguments before each
            # run and copy the snapshot back afterwards.

            def _pre_hook(kwargs):
                self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value}

            def _post_hook(kwargs, exception):
                for name in self.restore_value:
                    kwargs[name].copy_(self.restore_copies[name])
                self.restore_copies = {}

            self.pre_hook = _pre_hook
            self.post_hook = _post_hook
        else:
            self.pre_hook = None
            self.post_hook = None

        self.fn = fn
        self._do_bench = do_bench

    @cached_property
    def do_bench(self):
        """Benchmark callable: the user-supplied one, else triton's ``do_bench``."""
        if self._do_bench is None:
            return partial(triton.testing.do_bench, warmup=5, rep=25)
        return self._do_bench

    def _bench(self, *args, config, **meta):
        """Time ``fn(*args, **meta, **config kwargs)`` with ``self.do_bench``.

        :return: whatever ``do_bench`` returns for ``quantiles=(0.5, 0.2, 0.8)``, or
            three ``inf`` values if benchmarking raised.
        :raises ValueError: if a keyword is supplied both by the caller and by ``config``.
        """
        verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
        if verbose:
            print(f"Autotuning kernel {self.fn.__name__} with config {config}")

        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.all_kwargs())
        full_nargs = {**self.nargs, **current}

        def kernel_call():
            if self.pre_hook is not None:
                self.pre_hook(full_nargs)
            try:
                self.fn(*args, **current)
            except Exception as e:
                try:
                    if self.post_hook is not None:
                        self.post_hook(full_nargs, exception=e)
                finally:
                    # Throw exception raised by `self.fn`
                    raise

            if self.post_hook is not None:
                self.post_hook(full_nargs, exception=None)

        try:
            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
        except Exception as e:
            # A config that fails is kept as a candidate with infinite time.
            if verbose:
                print(f"Autotuning failed with {e}")
            return [float("inf"), float("inf"), float("inf")]

    @torch.compiler.disable
    def check_disk_cache(self, tuning_key, configs, bench_fn):
        """Load timings for ``tuning_key`` from disk, or benchmark and store them.

        :param tuning_key: key identifying the call; if falsy, ``bench_fn`` is run
            and nothing is read from or written to disk.
        :param configs: candidate configs; their ``str()`` forms must be unique
            because they identify configs in the stored file.
        :param bench_fn: zero-argument callable that benchmarks and fills
            ``self.cache[tuning_key]`` and ``self.configs_timings``.
        """
        if not tuning_key:
            bench_fn()
            return

        fn = self.fn
        config_str_list = [str(c) for c in configs]
        assert len(config_str_list) == len(set(config_str_list)), "Config strings must be unique"
        # The cache entry depends on package version, tuning key and the config set.
        cache_key = [VERSION, str(tuning_key)] + config_str_list
        cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
        cache = FileCacheManager(_base32(cache_key))
        file_name = f"{fn.__name__[:150]}.autotune.json"
        path = cache.get_file(file_name)
        # There's an environment variable to force cache update
        if path and not os.environ.get(f"{PACKAGE_NAME.upper()}_FORCE_CACHE_UPDATE", False):
            str2config = dict(zip(config_str_list, configs))
            with open(path, "r") as cached_configs:
                stored = json.load(cached_configs)["configs_timings"]
            timings = {str2config[config]: timing for config, timing in stored}
            self.cache[tuning_key] = builtins.min(timings, key=timings.get)
            self.configs_timings = timings
            self.bench_time = 0
            return

        bench_fn()
        cache.put(
            json.dumps(
                {
                    "key": tuning_key,
                    "configs_timings": [
                        (str(config), timings) for config, timings in self.configs_timings.items()
                    ],
                }
            ),
            file_name,
            binary=False,
        )

    def _tuning_key(self, kwargs):
        """Build the hashable tuning key for the current call (uses ``self.nargs``)."""
        all_args = {**self.nargs, **kwargs}
        named_args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
        key = [named_args[name] for name in self.keys if name in named_args]
        for arg in named_args.values():
            if isinstance(arg, Tensor):
                key.append(str(arg.shape))
                # If stride != 0, 1, we just cache it as 2
                key.append(str([s if s in {0, 1} else 2 for s in arg.stride()]))
                key.append(str(arg.dtype))
        return tuple(key)

    def __call__(self, *args, **kwargs):
        """Call ``fn`` with the best known config, benchmarking first if needed.

        :return: the return value of ``fn``.
        """
        self.nargs = dict(zip(self.arg_names, args))
        try:
            used_cached_result = True
            if len(self.configs) > 1:
                key = self._tuning_key(kwargs)
                if key not in self.cache:
                    used_cached_result = False

                    @torch.compiler.disable  # Don't want any tracing here
                    def benchmark():
                        bench_start = time.time()
                        timings = {
                            config: self._bench(*args, config=config, **kwargs)
                            for config in self.configs
                        }
                        bench_end = time.time()
                        if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
                            for config, time_ in timings.items():
                                print(f"[{config}] -> {time_[0]:.3f}ms")
                        self.bench_time = bench_end - bench_start
                        self.cache[key] = builtins.min(timings, key=timings.get)
                        self.configs_timings = timings

                    if self.cache_results:
                        self.check_disk_cache(key, self.configs, benchmark)
                    else:
                        benchmark()

                config = self.cache[key]
            else:
                config = self.configs[0]
            self.best_config = config
            if (
                os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
                and not used_cached_result
            ):
                print(
                    f"{PACKAGE_NAME} autotuning for function {self.fn.__name__} finished after "
                    f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
                )
            return self.fn(
                *args,
                **kwargs,
                **config.all_kwargs(),
            )
        finally:
            # Drop the argument references even when fn or the benchmarking raised.
            self.nargs = None
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class AutotuneConfig:
    """
    An object that represents a possible kernel configuration for the auto-tuner to try.

    Instances are hashable and compare equal when their keyword arguments are equal
    and were given in the same order, so they can be used as dictionary keys.

    :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
    :type kwargs: dict[Str, Any]
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __setstate__(self, state):
        # Tolerate a state without "kwargs" by falling back to an empty config.
        self.kwargs = state.get("kwargs", {})

    def all_kwargs(self):
        """Return the keyword arguments this config passes to the tuned function."""
        return self.kwargs

    def __str__(self):
        return ", ".join(f"{k}: {v}" for k, v in self.kwargs.items())

    def _identity(self):
        """Return the tuple of (name, value) pairs used for hashing and equality."""
        # tuple(items), not tuple(*items): the starred form raises TypeError as soon
        # as there is more than one keyword argument.
        return tuple(self.all_kwargs().items())

    def __hash__(self):
        return hash(self._identity())

    def __eq__(self, other):
        if not isinstance(other, AutotuneConfig):
            return NotImplemented
        return self._identity() == other._identity()
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results=True):
    """
    Decorator for auto-tuning a function.

    .. highlight:: python

    If the environment variable :code:`QUACK_PRINT_AUTOTUNING` is set to
    :code:`"1"`, we will print a message to stdout after autotuning each
    kernel, including the time spent autotuning and the best configuration.

    :param configs: a list of :code:`AutotuneConfig` objects
    :type configs: list[AutotuneConfig]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
    :type restore_value: list[str]
    :param do_bench: a benchmark function to measure the time of each run.
    :type do_bench: lambda fn, quantiles
    :param cache_results: whether to cache autotune timings to disk. Defaults to True.
    :type cache_results: bool
    :return: a decorator that wraps the decorated function in an :code:`Autotuner`.
    """
    # NOTE: this must be a plain string literal; an f-string here is not a docstring
    # and would leave ``autotune.__doc__`` as None.
    if key is None:
        key = []

    def decorator(fn):
        return Autotuner(
            fn,
            key,
            configs,
            restore_value=restore_value,
            do_bench=do_bench,
            cache_results=cache_results,
        )

    return decorator
|
quack/cross_entropy.py
CHANGED
|
@@ -446,13 +446,10 @@ class CrossEntropyBackward:
|
|
|
446
446
|
log2_e = math.log2(math.e)
|
|
447
447
|
probs = utils.exp2f((x - lse) * log2_e)
|
|
448
448
|
prob_shifted = probs - 1.0
|
|
449
|
-
|
|
450
449
|
mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
|
|
451
|
-
for i in cutlass.
|
|
450
|
+
for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
|
|
452
451
|
mask[i] = tXcFull[i][1] == label
|
|
453
|
-
|
|
454
|
-
mask = mask.load()
|
|
455
|
-
grad = cute.where(mask, prob_shifted, probs)
|
|
452
|
+
grad = cute.where(mask.load(), prob_shifted, probs)
|
|
456
453
|
grad = grad * dloss
|
|
457
454
|
|
|
458
455
|
tXrO.store(grad.to(tXrO.element_type))
|
quack/cute_dsl_utils.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pathlib
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
from triton.tools.disasm import extract
|
|
7
|
+
except ImportError:
|
|
8
|
+
extract = None
|
|
9
|
+
|
|
10
|
+
import cutlass
|
|
11
|
+
import cutlass.cute as cute
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Keep references to the unpatched callables: the wrappers defined in this module
# delegate to them, and the original loader is put back after compilation.
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
|
|
15
|
+
cute_compile_og = cute.compile
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def load_cubin_module_data_patched(cubin_data, filepath):
    """Save ``cubin_data`` to ``filepath``, then load it with the original loader.

    :param cubin_data: the cubin contents as bytes.
    :param filepath: destination file; overwritten if it already exists.
    :return: whatever the original ``load_cubin_module_data`` returns.
    """
    pathlib.Path(filepath).write_bytes(cubin_data)
    module = load_cubin_module_data_og(cubin_data)
    return module
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def cute_compile_patched(*args, **kwargs):
    """A patched version of cute.compile that dumps the SASS to a file if CUTE_CUBIN_PATH is set.

    When the environment variable is unset this is a plain call to the original
    ``cute.compile``. Otherwise the cubin loader is temporarily replaced by one that
    also writes the cubin to ``CUTE_CUBIN_PATH``; afterwards, if triton's ``extract``
    could be imported, its output for that file is written next to it with the
    suffix ``.annotated.sass``.

    :return: the result of the original ``cute.compile``.
    """
    # Read the variable once so install and restore cannot disagree.
    cubin_path_str = os.getenv("CUTE_CUBIN_PATH")
    if cubin_path_str is None:
        return cute_compile_og(*args, **kwargs)

    cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
        load_cubin_module_data_patched, filepath=cubin_path_str
    )
    try:
        output = cute_compile_og(*args, **kwargs)
    finally:
        # Always restore the original loader, even if compilation fails.
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
    if extract is not None:
        cubin_path = pathlib.Path(cubin_path_str)
        sass = extract(cubin_path, None)
        cubin_path.with_suffix(".annotated.sass").write_text(sass)
    return output
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Install the wrapper at import time; calls made through the `cute.compile`
# attribute after this point go through cute_compile_patched.
cute.compile = cute_compile_patched
|