quack-kernels 0.2.1__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. {quack_kernels-0.2.1/quack_kernels.egg-info → quack_kernels-0.2.2}/PKG-INFO +2 -2
  2. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/pyproject.toml +1 -1
  3. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/__init__.py +1 -1
  4. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/autotuner.py +64 -5
  5. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/cute_dsl_utils.py +6 -7
  6. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/dense_gemm_sm90.py +582 -287
  7. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/gemm_act_sm90.py +70 -29
  8. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/gemm_dact_sm90.py +43 -10
  9. quack_kernels-0.2.2/quack/gemm_interface.py +892 -0
  10. quack_kernels-0.2.1/quack/dense_gemm_sm100.py → quack_kernels-0.2.2/quack/gemm_sm100.py +443 -419
  11. quack_kernels-0.2.2/quack/gemm_wrapper_utils.py +315 -0
  12. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/rmsnorm.py +83 -149
  13. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/tile_scheduler.py +34 -47
  14. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/utils.py +61 -8
  15. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/varlen_utils.py +1 -6
  16. {quack_kernels-0.2.1 → quack_kernels-0.2.2/quack_kernels.egg-info}/PKG-INFO +2 -2
  17. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack_kernels.egg-info/SOURCES.txt +3 -1
  18. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack_kernels.egg-info/requires.txt +1 -1
  19. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/tests/test_linear.py +6 -1
  20. quack_kernels-0.2.2/tests/test_linear_varlen_k.py +266 -0
  21. quack_kernels-0.2.2/tests/test_linear_varlen_m.py +376 -0
  22. quack_kernels-0.2.1/quack/gemm_interface.py +0 -569
  23. quack_kernels-0.2.1/quack/gemm_wrapper_utils.py +0 -158
  24. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/LICENSE +0 -0
  25. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/README.md +0 -0
  26. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/activation.py +0 -0
  27. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/cross_entropy.py +0 -0
  28. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/fast_math.py +0 -0
  29. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/gemm_config.py +0 -0
  30. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/layernorm.py +0 -0
  31. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/linear.py +0 -0
  32. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/linear_cross_entropy.py +0 -0
  33. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/mlp.py +0 -0
  34. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/pipeline.py +0 -0
  35. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/reduce.py +0 -0
  36. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/reduction_base.py +0 -0
  37. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/softmax.py +0 -0
  38. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/sort/bitonic_sort.py +0 -0
  39. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/sort/generate_sorting_networks.py +0 -0
  40. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/sort/sorting_networks.py +0 -0
  41. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/sort/utils.py +0 -0
  42. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/symmetric_dense_gemm_sm90.py +0 -0
  43. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/tensormap_manager.py +0 -0
  44. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack/topk.py +0 -0
  45. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack_kernels.egg-info/dependency_links.txt +0 -0
  46. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/quack_kernels.egg-info/top_level.txt +0 -0
  47. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/setup.cfg +0 -0
  48. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/tests/test_cross_entropy.py +0 -0
  49. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/tests/test_layernorm.py +0 -0
  50. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/tests/test_linear_cross_entropy.py +0 -0
  51. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/tests/test_rmsnorm.py +0 -0
  52. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/tests/test_softmax.py +0 -0
  53. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/tests/test_symmetric_dense_gemm_sm90.py +0 -0
  54. {quack_kernels-0.2.1 → quack_kernels-0.2.2}/tests/test_topk.py +0 -0
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quack-kernels
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Requires-Python: >=3.10
5
5
  License-File: LICENSE
6
- Requires-Dist: nvidia-cutlass-dsl==4.2.0
6
+ Requires-Dist: nvidia-cutlass-dsl==4.2.1
7
7
  Requires-Dist: torch
8
8
  Provides-Extra: dev
9
9
  Requires-Dist: pre-commit; extra == "dev"
@@ -7,7 +7,7 @@ name = "quack-kernels"
7
7
  dynamic = ["version"]
8
8
  requires-python = ">=3.10"
9
9
  dependencies = [
10
- "nvidia-cutlass-dsl==4.2.0",
10
+ "nvidia-cutlass-dsl==4.2.1",
11
11
  "torch",
12
12
  ]
13
13
 
@@ -1,4 +1,4 @@
1
- __version__ = "0.2.1"
1
+ __version__ = "0.2.2"
2
2
 
3
3
  import cutlass.cute as cute
4
4
 
@@ -11,7 +11,7 @@ import hashlib
11
11
  import json
12
12
  from pathlib import Path
13
13
  from functools import cached_property, partial
14
- from typing import Dict, Tuple
14
+ from typing import Dict, Tuple, List, Optional, Any
15
15
 
16
16
  import torch
17
17
  from torch import Tensor
@@ -53,7 +53,22 @@ def _base32(key):
53
53
 
54
54
 
55
55
  class Autotuner:
56
- def __init__(self, fn, key, configs, restore_value=None, do_bench=None, cache_results=False):
56
+ def __init__(
57
+ self,
58
+ fn,
59
+ key,
60
+ configs,
61
+ restore_value=None,
62
+ prune_configs_by: Optional[Dict] = None,
63
+ do_bench=None,
64
+ cache_results=False,
65
+ ):
66
+ """
67
+ :param prune_configs_by: a dict of functions that are used to prune configs, fields:
68
+ 'perf_model': performance model used to predict running time with different configs, returns running time
69
+ 'top_k': number of configs to bench
70
+ 'early_config_prune'(optional): a function used to do early pruning (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
71
+ """
57
72
  if not configs:
58
73
  self.configs = [AutotuneConfig()]
59
74
  else:
@@ -90,6 +105,16 @@ class Autotuner:
90
105
  else:
91
106
  self.post_hook = None
92
107
 
108
+ self.perf_model = None
109
+ self.configs_top_k = 1.0
110
+ self.early_config_prune = None
111
+ if prune_configs_by:
112
+ self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
113
+ self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
114
+ self.early_config_prune = prune_configs_by.get(
115
+ "early_config_prune", self.early_config_prune
116
+ )
117
+
93
118
  self.fn = fn
94
119
  self._do_bench = do_bench
95
120
 
@@ -198,13 +223,14 @@ class Autotuner:
198
223
  key = tuple(key)
199
224
  if key not in self.cache:
200
225
  used_cached_result = False
226
+ pruned_configs = self.prune_configs(kwargs)
201
227
 
202
228
  @torch.compiler.disable # Don't want any tracing here
203
229
  def benchmark():
204
230
  bench_start = time.time()
205
231
  timings = {
206
232
  config: self._bench(*args, config=config, **kwargs)
207
- for config in self.configs
233
+ for config in pruned_configs
208
234
  }
209
235
  bench_end = time.time()
210
236
  if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
@@ -215,7 +241,7 @@ class Autotuner:
215
241
  self.configs_timings = timings
216
242
 
217
243
  if self.cache_results:
218
- self.check_disk_cache(key, self.configs, benchmark)
244
+ self.check_disk_cache(key, pruned_configs, benchmark)
219
245
  else:
220
246
  benchmark()
221
247
 
@@ -239,6 +265,32 @@ class Autotuner:
239
265
  self.nargs = None
240
266
  return ret
241
267
 
268
+ def prune_configs(self, kwargs: Dict) -> List[Any]:
269
+ pruned_configs = self.configs
270
+ if self.early_config_prune:
271
+ pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
272
+ if self.perf_model:
273
+ top_k = self.configs_top_k
274
+ if isinstance(top_k, float) and top_k <= 1.0:
275
+ top_k = int(len(self.configs) * top_k)
276
+ elif not isinstance(top_k, int):
277
+ # Slice index must be an integer
278
+ raise TypeError(
279
+ "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
280
+ )
281
+
282
+ if len(pruned_configs) > top_k:
283
+ est_timing = {
284
+ config: self.perf_model(
285
+ **self.nargs,
286
+ **kwargs,
287
+ **config.all_kwargs(),
288
+ )
289
+ for config in pruned_configs
290
+ }
291
+ pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
292
+ return pruned_configs
293
+
242
294
 
243
295
  class AutotuneConfig:
244
296
  """
@@ -272,7 +324,9 @@ class AutotuneConfig:
272
324
  return self_tuple == other_tuple
273
325
 
274
326
 
275
- def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results=True):
327
+ def autotune(
328
+ configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
329
+ ):
276
330
  f"""
277
331
  Decorator for auto-tuning a function.
278
332
 
@@ -286,6 +340,10 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
286
340
  :type configs: list[AutotuneConfig]
287
341
  :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
288
342
  :type key: list[str]
343
+ :param prune_configs_by: a dict of functions that are used to prune configs, fields:
344
+ 'perf_model': performance model used to predict running time with different configs, returns running time
345
+ 'top_k': number of configs to bench
346
+ 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
289
347
  :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
290
348
  :type restore_value: list[str]
291
349
  :param do_bench: a benchmark function to measure the time of each run.
@@ -303,6 +361,7 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
303
361
  key,
304
362
  configs,
305
363
  restore_value=restore_value,
364
+ prune_configs_by=prune_configs_by,
306
365
  do_bench=do_bench,
307
366
  cache_results=cache_results,
308
367
  )
@@ -98,22 +98,21 @@ class ArgumentsBase(JitArgument):
98
98
 
99
99
 
100
100
  def load_cubin_module_data_patched(cubin_data, filepath):
101
- path = pathlib.Path(filepath)
102
- path.write_bytes(cubin_data)
101
+ pathlib.Path(filepath).write_bytes(cubin_data)
103
102
  return load_cubin_module_data_og(cubin_data)
104
103
 
105
104
 
106
105
  def cute_compile_patched(*args, **kwargs):
107
106
  """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set."""
108
- if os.getenv("CUTE_CUBIN_PATH") is not None:
107
+ cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
108
+ if cubin_path is not None:
109
109
  cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
110
- load_cubin_module_data_patched, filepath=os.getenv("CUTE_CUBIN_PATH")
110
+ load_cubin_module_data_patched, filepath=cubin_path
111
111
  )
112
112
  output = cute_compile_og(*args, **kwargs)
113
- if os.getenv("CUTE_CUBIN_PATH") is not None:
113
+ if cubin_path is not None:
114
114
  cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
115
115
  if extract is not None:
116
- cubin_path = pathlib.Path(os.getenv("CUTE_CUBIN_PATH"))
117
116
  sass = extract(cubin_path, None)
118
- cubin_path.with_suffix(".annotated.sass").write_text(sass)
117
+ pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
119
118
  return output