quack-kernels 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
quack/__init__.py CHANGED
@@ -1,9 +1,12 @@
- __version__ = "0.1.10"
+ __version__ = "0.1.11"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
  from quack.cross_entropy import cross_entropy

+ # ruff: noqa
+ import quack.cute_dsl_utils  # Patch cute.compile to optionally dump SASS
+
  __all__ = [
      "rmsnorm",
      "softmax",
quack/autotuner.py ADDED
@@ -0,0 +1,309 @@
+ # Adapted from https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py
+ # Copyright (C) 2025, Tri Dao.
+ from __future__ import annotations
+
+ import builtins
+ import os
+ import time
+ import inspect
+ import base64
+ import hashlib
+ import json
+ from pathlib import Path
+ from functools import cached_property, partial
+ from typing import Dict, Tuple
+
+ import torch
+ from torch import Tensor
+
+ import triton
+
+ from . import __version__
+
+
+ PACKAGE_NAME = "quack"
+ VERSION = __version__
+
+
+ def get_home_dir():
+     return os.getenv(f"{PACKAGE_NAME.upper()}_HOME", Path.home())
+
+
+ def default_cache_dir():
+     return os.path.join(get_home_dir(), f".{PACKAGE_NAME}", "cache")
+
+
+ class FileCacheManager(triton.runtime.cache.FileCacheManager):
+     def __init__(self, key):
+         super().__init__(key)
+         self.cache_dir = (
+             os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_DIR", "").strip() or default_cache_dir()
+         )
+         if self.cache_dir:
+             self.cache_dir = os.path.join(self.cache_dir, self.key)
+             self.lock_path = os.path.join(self.cache_dir, "lock")
+             os.makedirs(self.cache_dir, exist_ok=True)
+         else:
+             raise RuntimeError("Could not create or locate cache dir")
+
+
+ def _base32(key):
+     # Assume key is a hex string.
+     return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
+
+
+ class Autotuner:
+     def __init__(self, fn, key, configs, restore_value=None, do_bench=None, cache_results=False):
+         if not configs:
+             self.configs = [AutotuneConfig()]
+         else:
+             self.configs = configs
+         signature = inspect.signature(fn)
+         self.keys = key
+         self.cache: Dict[Tuple, AutotuneConfig] = {}
+         self.arg_names = list(signature.parameters.keys())
+         self.cache_results = (
+             cache_results or os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_AUTOTUNING", None) == "1"
+         )
+
+         self.restore_value = []
+         if restore_value is not None:
+             self.restore_value = list(restore_value)
+
+         if len(self.restore_value) > 0:
+
+             def _pre_hook(kwargs):
+                 self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value}
+
+             self.pre_hook = _pre_hook
+         else:
+             self.pre_hook = None
+
+         if len(self.restore_value) > 0:
+
+             def _post_hook(kwargs, exception):
+                 for name in self.restore_value:
+                     kwargs[name].copy_(self.restore_copies[name])
+                 self.restore_copies = {}
+
+             self.post_hook = _post_hook
+         else:
+             self.post_hook = None
+
+         self.fn = fn
+         self._do_bench = do_bench
+
+     @cached_property
+     def do_bench(self):
+         if self._do_bench is None:
+             return partial(triton.testing.do_bench, warmup=5, rep=25)
+         return self._do_bench
+
+     def _bench(self, *args, config, **meta):
+         verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
+         if verbose:
+             print(f"Autotuning kernel {self.fn.__name__} with config {config}")
+
+         # check for conflicts, i.e. meta-parameters both provided
+         # as kwargs and by the autotuner
+         conflicts = meta.keys() & config.kwargs.keys()
+         if conflicts:
+             raise ValueError(
+                 f"Conflicting meta-parameters: {', '.join(conflicts)}."
+                 " Make sure that you don't re-define auto-tuned symbols."
+             )
+         # augment meta-parameters with tunable ones
+         current = dict(meta, **config.all_kwargs())
+         full_nargs = {**self.nargs, **current}
+
+         def kernel_call():
+             if self.pre_hook is not None:
+                 self.pre_hook(full_nargs)
+             try:
+                 self.fn.__call__(
+                     *args,
+                     **current,
+                 )
+             except Exception as e:
+                 try:
+                     if self.post_hook is not None:
+                         self.post_hook(full_nargs, exception=e)
+                 finally:
+                     # Throw exception raised by `self.fn.run`
+                     raise
+
+             if self.post_hook is not None:
+                 self.post_hook(full_nargs, exception=None)
+
+         try:
+             return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
+         except Exception as e:
+             if verbose:
+                 print(f"Autotuning failed with {e}")
+             return [float("inf"), float("inf"), float("inf")]
+
+     @torch.compiler.disable
+     def check_disk_cache(self, tuning_key, configs, bench_fn):
+         if not tuning_key:
+             bench_fn()
+             return
+
+         fn = self.fn
+         config_str_list = [str(c) for c in configs]
+         assert len(config_str_list) == len(set(config_str_list)), "Config strings must be unique"
+         cache_key = [VERSION, str(tuning_key)] + config_str_list
+         cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
+         cache = FileCacheManager(_base32(cache_key))
+         file_name = f"{fn.__name__[:150]}.autotune.json"
+         path = cache.get_file(file_name)
+         # There's an environment variable to force cache update
+         if path and not os.environ.get(f"{PACKAGE_NAME.upper()}_FORCE_CACHE_UPDATE", False):
+             str2config = {s: c for s, c in zip(config_str_list, configs)}
+             with open(path, "r") as cached_configs:
+                 timings = json.load(cached_configs)["configs_timings"]
+                 timings = {str2config[config]: timing for config, timing in timings}
+             self.cache[tuning_key] = builtins.min(timings, key=timings.get)
+             self.configs_timings = timings
+             self.bench_time = 0
+             return
+
+         bench_fn()
+         cache.put(
+             json.dumps(
+                 {
+                     "key": tuning_key,
+                     "configs_timings": [
+                         (str(config), timings) for config, timings in self.configs_timings.items()
+                     ],
+                 }
+             ),
+             file_name,
+             binary=False,
+         )
+
+     def __call__(self, *args, **kwargs):
+         self.nargs = dict(zip(self.arg_names, args))
+         used_cached_result = True
+         if len(self.configs) > 1:
+             all_args = {**self.nargs, **kwargs}
+             _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
+             key = [_args[key] for key in self.keys if key in _args]
+             for _, arg in _args.items():
+                 if isinstance(arg, Tensor):
+                     key.append(str(arg.shape))
+                     # If stride != 0, 1, we just cache it as 2
+                     key.append(str([s if s in {0, 1} else 2 for s in arg.stride()]))
+                     key.append(str(arg.dtype))
+             key = tuple(key)
+             if key not in self.cache:
+                 used_cached_result = False
+
+                 @torch.compiler.disable  # Don't want any tracing here
+                 def benchmark():
+                     bench_start = time.time()
+                     timings = {
+                         config: self._bench(*args, config=config, **kwargs)
+                         for config in self.configs
+                     }
+                     bench_end = time.time()
+                     if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
+                         for config, time_ in timings.items():
+                             print(f"[{config}] -> {time_[0]:.3f}ms")
+                     self.bench_time = bench_end - bench_start
+                     self.cache[key] = builtins.min(timings, key=timings.get)
+                     self.configs_timings = timings
+
+                 if self.cache_results:
+                     self.check_disk_cache(key, self.configs, benchmark)
+                 else:
+                     benchmark()
+
+             config = self.cache[key]
+         else:
+             config = self.configs[0]
+         self.best_config = config
+         if (
+             os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
+             and not used_cached_result
+         ):
+             print(
+                 f"{PACKAGE_NAME} autotuning for function {self.fn.__name__} finished after "
+                 f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
+             )
+         ret = self.fn.__call__(
+             *args,
+             **kwargs,
+             **config.all_kwargs(),
+         )
+         self.nargs = None
+         return ret
+
+
+ class AutotuneConfig:
+     """
+     An object that represents a possible kernel configuration for the auto-tuner to try.
+
+     :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
+     :type kwargs: dict[Str, Any]
+     """
+
+     def __init__(self, **kwargs):
+         self.kwargs = kwargs
+
+     def __setstate__(self, state):
+         self.kwargs = state.get("kwargs", {})
+
+     def all_kwargs(self):
+         return self.kwargs
+
+     def __str__(self):
+         res = []
+         for k, v in self.kwargs.items():
+             res.append(f"{k}: {v}")
+         return ", ".join(res)
+
+     def __hash__(self):
+         return hash(tuple(*self.all_kwargs().items()))
+
+     def __eq__(self, other):
+         self_tuple = tuple(*self.all_kwargs().items())
+         other_tuple = tuple(*other.all_kwargs().items())
+         return self_tuple == other_tuple
+
+
+ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results=True):
+     f"""
+     Decorator for auto-tuning a function.
+
+     .. highlight:: python
+
+     If the environment variable :code:`{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING` is set to
+     :code:`"1"`, we will print a message to stdout after autotuning each
+     kernel, including the time spent autotuning and the best configuration.
+
+     :param configs: a list of :code:`AutotuneConfig` objects
+     :type configs: list[AutotuneConfig]
+     :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+     :type key: list[str]
+     :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
+     :type restore_value: list[str]
+     :param do_bench: a benchmark function to measure the time of each run.
+     :type do_bench: lambda fn, quantiles
+     :param cache_results: whether to cache autotune timings to disk. Defaults to True.
+     :type cache_results: bool
+     """
+
+     if key is None:
+         key = []
+
+     def decorator(fn):
+         return Autotuner(
+             fn,
+             key,
+             configs,
+             restore_value=restore_value,
+             do_bench=do_bench,
+             cache_results=cache_results,
+         )
+
+     return decorator
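For orientation, this is how the decorator above would typically be applied. A minimal sketch, not part of the package: the function scaled_copy and the tunable parameter tile_n are hypothetical, and a CUDA device is assumed since the default benchmark is triton.testing.do_bench (warmup=5, rep=25).

import torch
from quack.autotuner import autotune, AutotuneConfig

@autotune(
    configs=[AutotuneConfig(tile_n=64), AutotuneConfig(tile_n=128)],  # candidate meta-parameters
    key=["N"],  # re-benchmark all configs whenever N changes
)
def scaled_copy(x: torch.Tensor, out: torch.Tensor, N: int, tile_n: int = 64):
    # Toy stand-in for a kernel launch; the autotuner injects the winning tile_n.
    for start in range(0, N, tile_n):
        out[start:start + tile_n] = 2 * x[start:start + tile_n]

x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
scaled_copy(x, out, N=4096)  # first call benchmarks both configs; later calls reuse the cached choice

Setting QUACK_PRINT_AUTOTUNING=1 prints the per-config timings, and cache_results=True (the default of the decorator) or QUACK_CACHE_AUTOTUNING=1 persists them under ~/.quack/cache, overridable via QUACK_CACHE_DIR or QUACK_HOME.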
quack/cute_dsl_utils.py ADDED
@@ -0,0 +1,40 @@
+ import os
+ import pathlib
+ from functools import partial
+
+ try:
+     from triton.tools.disasm import extract
+ except ImportError:
+     extract = None
+
+ import cutlass
+ import cutlass.cute as cute
+
+
+ load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
+ cute_compile_og = cute.compile
+
+
+ def load_cubin_module_data_patched(cubin_data, filepath):
+     path = pathlib.Path(filepath)
+     path.write_bytes(cubin_data)
+     return load_cubin_module_data_og(cubin_data)
+
+
+ def cute_compile_patched(*args, **kwargs):
+     """A patched version of cute.compile that dumps the SASS to a file if CUTE_CUBIN_PATH is set."""
+     if os.getenv("CUTE_CUBIN_PATH") is not None:
+         cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
+             load_cubin_module_data_patched, filepath=os.getenv("CUTE_CUBIN_PATH")
+         )
+     output = cute_compile_og(*args, **kwargs)
+     if os.getenv("CUTE_CUBIN_PATH") is not None:
+         cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
+         if extract is not None:
+             cubin_path = pathlib.Path(os.getenv("CUTE_CUBIN_PATH"))
+             sass = extract(cubin_path, None)
+             cubin_path.with_suffix(".annotated.sass").write_text(sass)
+     return output
+
+
+ cute.compile = cute_compile_patched
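And this is roughly how the patch would be exercised. A sketch under stated assumptions: a CUDA setup with NVIDIA's binary utilities available (triton.tools.disasm needs them to produce SASS), and some CuTe DSL kernel to compile; my_kernel and its example arguments are placeholders, so that call is left commented out.

import os
os.environ["CUTE_CUBIN_PATH"] = "/tmp/kernel.cubin"  # any writable path; set before compiling

import quack  # importing quack installs the patched cute.compile
import cutlass.cute as cute

# Placeholder for a real CuTe kernel and its example arguments:
# compiled = cute.compile(my_kernel, *example_args)
# Afterwards /tmp/kernel.cubin holds the raw cubin and, if triton.tools.disasm is
# importable, /tmp/kernel.annotated.sass holds the disassembled SASS.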