quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/cute_dsl_utils.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import pathlib
|
|
5
|
+
from functools import partial, lru_cache
|
|
6
|
+
from dataclasses import dataclass, fields
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from triton.tools.disasm import extract
|
|
12
|
+
except ImportError:
|
|
13
|
+
extract = None
|
|
14
|
+
|
|
15
|
+
import cutlass
|
|
16
|
+
import cutlass.cute as cute
|
|
17
|
+
from cutlass.base_dsl.typing import JitArgument
|
|
18
|
+
from cutlass.cutlass_dsl import NumericMeta
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Field values that are instances of these types are treated as compile-time
# ("static") by ParamsBase and ArgumentsBase: they are excluded from MLIR
# value/type extraction and passed through unchanged when an object is rebuilt.
StaticTypes = (
    cutlass.Constexpr,
    NumericMeta,  # presumably matches numeric type *classes* (e.g. cutlass.Float32 itself) -- confirm
    int,
    bool,
    str,
    float,
    type(None),
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Handles to the unpatched implementations, captured once at import time (before
# any patching in this module) so the patched wrappers below can delegate to them.
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
cute_compile_og = cute.compile
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Lookup table from torch floating-point dtypes to the CuTe DSL numeric types of
# the same name. Only these three dtypes have an entry.
torch2cute_dtype_map = dict(
    (
        (torch.float16, cutlass.Float16),
        (torch.bfloat16, cutlass.BFloat16),
        (torch.float32, cutlass.Float32),
    )
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@lru_cache
def get_max_active_clusters(cluster_size):
    """Return the max number of simultaneously active clusters of ``cluster_size``.

    The value comes from ``cutlass.utils.HardwareInfo().get_max_active_clusters``
    and is memoized per ``cluster_size``, so the query runs once per distinct size.
    """
    # NOTE(review): the cache key is only ``cluster_size``; if HardwareInfo()
    # reflects the current CUDA device, switching devices would reuse the first
    # device's answer -- confirm this is acceptable.
    hw_info = cutlass.utils.HardwareInfo()
    return hw_info.get_max_active_clusters(cluster_size=cluster_size)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class ParamsBase:
    """Base dataclass for parameter bundles that are flattened to / rebuilt from MLIR values.

    A field whose current value is an instance of ``StaticTypes`` is "static":
    it is left out of the MLIR values and copied unchanged on reconstruction.
    Every other field is "dynamic" and is handled in field-declaration order.
    """

    def __extract_mlir_values__(self):
        """Flatten all dynamic fields into a single list of MLIR values.

        Side effect: stores in ``self._values_pos`` how many values each dynamic
        field produced, which ``__new_from_mlir_values__`` uses to split the list.
        """
        dynamic = [
            value
            for value in (getattr(self, f.name) for f in fields(self))
            if not isinstance(value, StaticTypes)
        ]
        counts = []
        self._values_pos = counts
        flat = []
        for value in dynamic:
            chunk = cutlass.extract_mlir_values(value)
            flat.extend(chunk)
            counts.append(len(chunk))
        return flat

    def __new_from_mlir_values__(self, values):
        """Build a new instance of this class from a flat list of MLIR values.

        Static fields are copied as-is. Dynamic fields are rebuilt in field order,
        each from the next run of ``values`` whose length was recorded in
        ``self._values_pos`` by a prior ``__extract_mlir_values__`` call (an
        AttributeError is raised if that call never happened).
        """
        static_kwargs, dynamic_kwargs = {}, {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, StaticTypes):
                static_kwargs[f.name] = value
            else:
                dynamic_kwargs[f.name] = value
        offset = 0
        for (name, value), count in zip(list(dynamic_kwargs.items()), self._values_pos):
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(
                value, values[offset : offset + count]
            )
            offset += count
        return self.__class__(**dynamic_kwargs, **static_kwargs)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class ArgumentsBase(JitArgument):
    """Base dataclass for host-side JIT arguments composed of several sub-arguments.

    A field whose current value is an instance of ``StaticTypes`` is "static":
    it contributes no C pointers or MLIR types and is copied unchanged when the
    object is rebuilt. Every other ("dynamic") field is handled in
    field-declaration order by delegating to the field object itself.
    """

    def __c_pointers__(self):
        """Return the C pointers of all dynamic fields, concatenated in field order.

        Dynamic fields that do not define ``__c_pointers__`` are skipped.
        """
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        c_ptrs = []
        for obj in non_constexpr_fields:
            if hasattr(obj, "__c_pointers__"):
                c_ptrs.extend(obj.__c_pointers__())
        return c_ptrs

    def __get_mlir_types__(self):
        """Return the MLIR types of all dynamic fields, concatenated in field order.

        Dynamic fields that do not define ``__get_mlir_types__`` are skipped.
        """
        all_fields = [getattr(self, field.name) for field in fields(self)]
        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
        types = []
        for obj in non_constexpr_fields:
            if hasattr(obj, "__get_mlir_types__"):
                types.extend(obj.__get_mlir_types__())
        return types

    def __new_from_mlir_values__(self, values):
        """Build a new instance of this class from a flat sequence of MLIR values.

        Static fields are copied as-is. Dynamic fields consume ``values`` in field
        order. A field that defines ``__get_mlir_types__`` takes as many values as
        the number of types it reports. That is the same count it contributes to
        ``__get_mlir_types__``, so a field with several values does not shift the
        values of the fields after it. A dynamic field without
        ``__get_mlir_types__`` takes exactly one value.

        Args:
            values: flat sequence of MLIR values; presumably the block arguments
                matching ``self.__get_mlir_types__()`` -- TODO confirm against the DSL.

        Returns:
            A new instance of ``type(self)``.
        """
        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
        non_constexpr_fields = {
            n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
        }
        for name, field in list(non_constexpr_fields.items()):
            if hasattr(field, "__get_mlir_types__"):
                n_items = len(list(field.__get_mlir_types__()))
            else:
                # NOTE(review): such a field contributes no types in __get_mlir_types__,
                # yet still consumes one value here (the long-standing behavior for
                # these objects) -- confirm this case is intended.
                n_items = 1
            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
            values = values[n_items:]
        return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def load_cubin_module_data_patched(cubin_data, filepath):
    """Write ``cubin_data`` to ``filepath``, then load it with the original loader.

    ``cute_compile_patched`` installs this, with ``filepath`` bound through
    ``functools.partial``, in place of
    ``cutlass.base_dsl.runtime.cuda.load_cubin_module_data``, so the compiled
    cubin can be inspected on disk. Missing parent directories are created.

    Args:
        cubin_data: the bytes handed to the loader; written verbatim.
        filepath: destination path (str or path-like); overwritten if it exists.

    Returns:
        Whatever ``load_cubin_module_data_og(cubin_data)`` returns.
    """
    path = pathlib.Path(filepath)
    # The user-supplied CUTE_CUBIN_PATH may point into a directory that does not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(cubin_data)
    return load_cubin_module_data_og(cubin_data)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def cute_compile_patched(*args, **kwargs):
    """A patched version of ``cute.compile`` that dumps the cubin (and SASS) if CUTE_CUBIN_PATH is set.

    When the ``CUTE_CUBIN_PATH`` environment variable is unset, this simply forwards
    to the original ``cute.compile``.

    When it is set, ``cutlass.base_dsl.runtime.cuda.load_cubin_module_data`` is
    temporarily replaced by ``load_cubin_module_data_patched``, which writes the
    cubin to that path. The previous loader is restored afterwards, even if
    compilation raises. If ``triton.tools.disasm.extract`` was importable, the
    disassembly is then written next to the cubin with an ``.annotated.sass`` suffix.

    All arguments are forwarded to the original ``cute.compile``, and its result
    is returned.
    """
    # Read the env var once so the patch/restore/dump decisions cannot disagree.
    cubin_path = os.getenv("CUTE_CUBIN_PATH")
    if cubin_path is None:
        return cute_compile_og(*args, **kwargs)
    cuda_runtime = cutlass.base_dsl.runtime.cuda
    prev_loader = cuda_runtime.load_cubin_module_data
    cuda_runtime.load_cubin_module_data = partial(
        load_cubin_module_data_patched, filepath=cubin_path
    )
    try:
        output = cute_compile_og(*args, **kwargs)
    finally:
        # Never leave the global loader patched, even when compilation fails.
        cuda_runtime.load_cubin_module_data = prev_loader
    if extract is not None:
        cubin_file = pathlib.Path(cubin_path)
        sass = extract(cubin_file, None)
        cubin_file.with_suffix(".annotated.sass").write_text(sass)
    return output
|