quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,119 @@
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ import os
4
+ import pathlib
5
+ from functools import partial, lru_cache
6
+ from dataclasses import dataclass, fields
7
+
8
+ import torch
9
+
10
+ try:
11
+ from triton.tools.disasm import extract
12
+ except ImportError:
13
+ extract = None
14
+
15
+ import cutlass
16
+ import cutlass.cute as cute
17
+ from cutlass.base_dsl.typing import JitArgument
18
+ from cutlass.cutlass_dsl import NumericMeta
19
+
20
+
21
# Field value types that the dataclass helpers in this module treat as static
# (compile-time) data: such fields are skipped when flattening to MLIR values
# and are passed through unchanged when rebuilding.
# NOTE(review): NumericMeta presumably matches cutlass numeric *type objects*
# (e.g. cutlass.Float32 stored as a dtype field) — confirm.
StaticTypes = (
    cutlass.Constexpr,
    NumericMeta,
    int,
    bool,
    str,
    float,
    type(None),
)

# Unpatched originals, captured at import time so the patched wrappers in this
# module can delegate to them and restore them afterwards.
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
cute_compile_og = cute.compile

# torch floating-point dtype -> corresponding cutlass numeric type.
torch2cute_dtype_map = dict(
    [
        (torch.float16, cutlass.Float16),
        (torch.bfloat16, cutlass.BFloat16),
        (torch.float32, cutlass.Float32),
    ]
)
33
+
34
+
35
@lru_cache(maxsize=128)
def get_max_active_clusters(cluster_size):
    """Return the hardware's max number of active clusters for ``cluster_size``.

    The result is memoized per ``cluster_size`` (up to 128 distinct sizes), so
    ``cutlass.utils.HardwareInfo`` is only queried once per size.
    """
    hardware_info = cutlass.utils.HardwareInfo()
    return hardware_info.get_max_active_clusters(cluster_size=cluster_size)
38
+
39
+
40
@dataclass
class ParamsBase:
    """Dataclass base that lets the CuTe DSL flatten and rebuild a parameter bundle.

    Fields whose values are instances of ``StaticTypes`` are static: they
    contribute no MLIR values and are copied unchanged into rebuilt objects.
    Every other field is flattened with ``cutlass.extract_mlir_values`` and
    rebuilt with ``cutlass.new_from_mlir_values``.
    """

    def _split_fields(self):
        """Return ``(static, dynamic)`` dicts of field name -> value, in field order."""
        static, dynamic = {}, {}
        for f in fields(self):
            value = getattr(self, f.name)
            target = static if isinstance(value, StaticTypes) else dynamic
            target[f.name] = value
        return static, dynamic

    def __extract_mlir_values__(self):
        """Flatten all non-static fields into a single list of MLIR values.

        Side effect: stores in ``self._values_pos`` how many values each
        non-static field produced, so ``__new_from_mlir_values__`` can split
        the flat list back up.
        """
        _, dynamic = self._split_fields()
        values, self._values_pos = [], []
        for obj in dynamic.values():
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        """Build a new instance of this class from a flat list of MLIR values.

        ``values`` is consumed left to right: each non-static field (in field
        order) takes as many values as it produced when extracted; static
        fields are passed through unchanged.
        """
        static, dynamic = self._split_fields()
        counts = getattr(self, "_values_pos", None)
        if counts is None:
            # __extract_mlir_values__ has not run on this instance (e.g. it is
            # itself a rebuilt copy, whose __init__ never sets _values_pos).
            # Recompute the per-field counts exactly as that method would
            # instead of failing with AttributeError.
            counts = [len(cutlass.extract_mlir_values(obj)) for obj in dynamic.values()]
        rebuilt = dict(dynamic)
        offset = 0
        # zip() stops at the shorter sequence, preserving the original pairing.
        for (name, field), n_items in zip(dynamic.items(), counts):
            rebuilt[name] = cutlass.new_from_mlir_values(field, values[offset : offset + n_items])
            offset += n_items
        return self.__class__(**rebuilt, **static)
62
+
63
+
64
@dataclass
class ArgumentsBase(JitArgument):
    """Dataclass base implementing the JitArgument hooks by delegating to its fields.

    Fields whose values are instances of ``StaticTypes`` are static and are
    skipped by every hook (rebuilt objects receive them unchanged). A
    non-static field contributes C pointers / MLIR types only if it defines
    ``__c_pointers__`` / ``__get_mlir_types__`` respectively.
    """

    def _split_fields(self):
        """Return ``(static, dynamic)`` dicts of field name -> value, in field order."""
        static, dynamic = {}, {}
        for f in fields(self):
            value = getattr(self, f.name)
            target = static if isinstance(value, StaticTypes) else dynamic
            target[f.name] = value
        return static, dynamic

    def __c_pointers__(self):
        """Concatenate the C pointers of all non-static fields that expose them."""
        _, dynamic = self._split_fields()
        c_ptrs = []
        for obj in dynamic.values():
            if hasattr(obj, "__c_pointers__"):
                c_ptrs.extend(obj.__c_pointers__())
        return c_ptrs

    def __get_mlir_types__(self):
        """Concatenate the MLIR types of all non-static fields that expose them."""
        _, dynamic = self._split_fields()
        types = []
        for obj in dynamic.values():
            if hasattr(obj, "__get_mlir_types__"):
                types.extend(obj.__get_mlir_types__())
        return types

    def __new_from_mlir_values__(self, values):
        """Build a new instance of this class from a flat list of MLIR values.

        ``values`` is split to mirror ``__get_mlir_types__``: each non-static
        field that defines ``__get_mlir_types__`` consumes as many values as
        it reports types. Previously every non-static field consumed exactly
        one value, which misaligned the list whenever a field reported zero or
        several types. Fields without ``__get_mlir_types__`` contribute no
        types, own no values, and are passed through unchanged, as are static
        fields.
        """
        static, dynamic = self._split_fields()
        rebuilt = dict(dynamic)
        offset = 0
        for name, field in dynamic.items():
            if not hasattr(field, "__get_mlir_types__"):
                continue
            n_items = len(field.__get_mlir_types__())
            rebuilt[name] = cutlass.new_from_mlir_values(field, values[offset : offset + n_items])
            offset += n_items
        return self.__class__(**rebuilt, **static)
98
+
99
+
100
def load_cubin_module_data_patched(cubin_data, filepath):
    """Dump ``cubin_data`` to ``filepath``, then load it with the original loader.

    Args:
        cubin_data: The cubin bytes handed to the loader; written verbatim.
        filepath: Destination of the dump. Missing parent directories are created.

    Returns:
        Whatever the unpatched ``load_cubin_module_data`` returns for ``cubin_data``.
    """
    path = pathlib.Path(filepath)
    # Without this, a dump path in a not-yet-existing directory raises
    # FileNotFoundError here and aborts the whole compilation.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(cubin_data)
    return load_cubin_module_data_og(cubin_data)
104
+
105
+
106
def cute_compile_patched(*args, **kwargs):
    """Drop-in replacement for ``cute.compile`` that can dump the compiled cubin.

    If the ``CUTE_CUBIN_PATH`` environment variable is set:

    * During compilation, ``cutlass.base_dsl.runtime.cuda.load_cubin_module_data``
      is temporarily swapped for a version that also writes the cubin to that
      path. The original is always restored, even if compilation raises.
    * Afterwards, if triton's ``extract`` disassembler imported successfully,
      an annotated SASS listing is written next to the cubin with the suffix
      ``.annotated.sass``.

    All arguments are forwarded unchanged, and the result of the original
    ``cute.compile`` is returned.
    """
    cubin_path_env = os.getenv("CUTE_CUBIN_PATH")
    if cubin_path_env is None:
        return cute_compile_og(*args, **kwargs)

    cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
        load_cubin_module_data_patched, filepath=cubin_path_env
    )
    try:
        output = cute_compile_og(*args, **kwargs)
    finally:
        # Undo the monkeypatch unconditionally so a failed compile does not
        # leave every later load writing to the dump path.
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og

    if extract is not None:
        cubin_path = pathlib.Path(cubin_path_env)
        sass = extract(cubin_path, None)
        cubin_path.with_suffix(".annotated.sass").write_text(sass)
    return output