quack-kernels 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/cute_dsl_utils.py CHANGED
@@ -1,6 +1,11 @@
1
+ # Copyright (c) 2025, Tri Dao.
2
+
1
3
  import os
2
4
  import pathlib
3
- from functools import partial
5
+ from functools import partial, lru_cache
6
+ from dataclasses import dataclass, fields
7
+
8
+ import torch
4
9
 
5
10
  try:
6
11
  from triton.tools.disasm import extract
@@ -9,12 +14,89 @@ except ImportError:
9
14
 
10
15
  import cutlass
11
16
  import cutlass.cute as cute
17
+ from cutlass.base_dsl.typing import JitArgument
18
+ from cutlass.cutlass_dsl import NumericMeta
19
+
20
+
21
# Field values that are instances of any of these types are treated as static
# (compile-time) data by ParamsBase / ArgumentsBase: they are passed through
# unchanged instead of being converted to / rebuilt from MLIR values.
# NOTE(review): NumericMeta presumably matches cutlass numeric *type objects*
# (e.g. a dtype stored in a field) rather than numeric values — confirm.
StaticTypes = (
    cutlass.Constexpr,
    NumericMeta,
    int,
    bool,
    str,
    float,
    type(None),
)
12
22
 
13
23
 
14
24
# Keep handles to the original (unpatched) cutlass callables under an `_og`
# suffix so they stay reachable even if the module attributes are later replaced.
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
cute_compile_og = cute.compile
16
26
 
17
27
 
28
# Lookup table from torch floating-point dtypes to the corresponding cutlass
# numeric types. Only the three dtypes below are covered; any other torch dtype
# raises KeyError on lookup.
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,  # half
    torch.bfloat16: cutlass.BFloat16,  # brain float
    torch.float32: cutlass.Float32,  # single
}
33
+
34
+
35
@lru_cache(maxsize=128)
def get_max_active_clusters(cluster_size):
    """Return ``HardwareInfo().get_max_active_clusters`` for ``cluster_size``, memoized.

    The hardware query runs at most once per distinct ``cluster_size``; later calls
    with the same value return the cached result.
    """
    # NOTE(review): the cache key is only ``cluster_size``. If HardwareInfo() reflects
    # the currently active CUDA device and that device can change between calls,
    # the cached value may not match the new device — confirm.
    hw_info = cutlass.utils.HardwareInfo()
    return hw_info.get_max_active_clusters(cluster_size=cluster_size)
38
+
39
+
40
@dataclass
class ParamsBase:
    """Dataclass base that lets a parameter bundle round-trip through MLIR values.

    A field whose current value is an instance of ``StaticTypes`` is "static": it is
    skipped when flattening and copied unchanged when rebuilding. Every other field is
    flattened with ``cutlass.extract_mlir_values`` and rebuilt with
    ``cutlass.new_from_mlir_values``.
    """

    def __extract_mlir_values__(self):
        """Flatten all non-static fields, in declaration order, into one list.

        Side effect: stores in ``self._values_pos`` how many values each non-static
        field contributed, so ``__new_from_mlir_values__`` can split the list again.
        """
        dynamic_values = [
            value
            for value in (getattr(self, f.name) for f in fields(self))
            if not isinstance(value, StaticTypes)
        ]
        flat = []
        counts = self._values_pos = []
        for value in dynamic_values:
            chunk = cutlass.extract_mlir_values(value)
            flat.extend(chunk)
            counts.append(len(chunk))
        return flat

    def __new_from_mlir_values__(self, values):
        """Build a new instance of the same class from a flat list of MLIR values.

        Static fields are reused as-is. Non-static fields are rebuilt in declaration
        order, each taking the number of values recorded in ``self._values_pos`` by a
        previous ``__extract_mlir_values__`` call (which must have happened first).
        """
        static_kwargs = {}
        dynamic_kwargs = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, StaticTypes):
                static_kwargs[f.name] = value
            else:
                dynamic_kwargs[f.name] = value
        offset = 0
        for name, count in zip(list(dynamic_kwargs), self._values_pos):
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(
                dynamic_kwargs[name], values[offset : offset + count]
            )
            offset += count
        return self.__class__(**dynamic_kwargs, **static_kwargs)
62
+
63
+
64
@dataclass
class ArgumentsBase(JitArgument):
    """Dataclass base for host-side kernel arguments handed to the JIT as one object.

    A field whose current value is an instance of ``StaticTypes`` is "static": it is
    excluded from the C-pointer and MLIR-type lists and copied unchanged into rebuilt
    instances. Every other ("dynamic") field is forwarded through its own
    ``__c_pointers__`` / ``__get_mlir_types__`` hooks. A dynamic field lacking a hook
    contributes nothing to the corresponding list.
    """

    def _split_fields(self):
        """Return ``(dynamic, static)`` dicts of field name -> value, in declaration order."""
        dynamic, static = {}, {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, StaticTypes):
                static[f.name] = value
            else:
                dynamic[f.name] = value
        return dynamic, static

    def __c_pointers__(self):
        """Concatenate the C pointers of every dynamic field that provides them."""
        dynamic, _ = self._split_fields()
        c_ptrs = []
        for obj in dynamic.values():
            if hasattr(obj, "__c_pointers__"):
                c_ptrs.extend(obj.__c_pointers__())
        return c_ptrs

    def __get_mlir_types__(self):
        """Concatenate the MLIR types of every dynamic field that provides them.

        Side effect: stores in ``self._values_pos`` how many types each dynamic field
        contributed (0 for fields without ``__get_mlir_types__``). These counts are
        what ``__new_from_mlir_values__`` uses to split the flat value list.
        """
        dynamic, _ = self._split_fields()
        types, values_pos = [], []
        for obj in dynamic.values():
            if hasattr(obj, "__get_mlir_types__"):
                obj_types = list(obj.__get_mlir_types__())
            else:
                obj_types = []
            types.extend(obj_types)
            values_pos.append(len(obj_types))
        self._values_pos = values_pos
        return types

    def __new_from_mlir_values__(self, values):
        """Build a new instance of this class from a flat list of MLIR values.

        Static fields are reused as-is. Each dynamic field, in declaration order,
        consumes as many values as the number of types it reported in
        ``__get_mlir_types__``.
        """
        dynamic, static = self._split_fields()
        # Per-field value counts must match the types list. A fixed count of one value
        # per field mis-splits `values` for fields backed by zero or several MLIR values.
        values_pos = getattr(self, "_values_pos", None)
        if values_pos is None or len(values_pos) != len(dynamic):
            # Counts were never recorded on this instance (or are stale): recompute them.
            self.__get_mlir_types__()
            values_pos = self._values_pos
        offset = 0
        for name, n_items in zip(list(dynamic), values_pos):
            dynamic[name] = cutlass.new_from_mlir_values(
                dynamic[name], values[offset : offset + n_items]
            )
            offset += n_items
        return self.__class__(**dynamic, **static)
98
+
99
+
18
100
  def load_cubin_module_data_patched(cubin_data, filepath):
19
101
  path = pathlib.Path(filepath)
20
102
  path.write_bytes(cubin_data)
@@ -35,6 +117,3 @@ def cute_compile_patched(*args, **kwargs):
35
117
  sass = extract(cubin_path, None)
36
118
  cubin_path.with_suffix(".annotated.sass").write_text(sass)
37
119
  return output
38
-
39
-
40
- cute.compile = cute_compile_patched
quack/dense_gemm_sm100.py CHANGED
@@ -43,10 +43,10 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils
43
43
  from cutlass.cute.runtime import from_dlpack, make_ptr
44
44
  from cutlass import Int32, const_expr
45
45
 
46
+ from quack.cute_dsl_utils import ParamsBase
46
47
  from quack.tile_scheduler import (
47
48
  TileSchedulerArguments,
48
49
  TileScheduler,
49
- ParamsBase,
50
50
  RasterOrderOption,
51
51
  )
52
52