quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/cute_dsl_utils.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
1
3
|
import os
|
|
2
4
|
import pathlib
|
|
3
|
-
from functools import partial
|
|
5
|
+
from functools import partial, lru_cache
|
|
6
|
+
from dataclasses import dataclass, fields
|
|
7
|
+
|
|
8
|
+
import torch
|
|
4
9
|
|
|
5
10
|
try:
|
|
6
11
|
from triton.tools.disasm import extract
|
|
@@ -9,12 +14,89 @@ except ImportError:
|
|
|
9
14
|
|
|
10
15
|
import cutlass
|
|
11
16
|
import cutlass.cute as cute
|
|
17
|
+
from cutlass.base_dsl.typing import JitArgument
|
|
18
|
+
from cutlass.cutlass_dsl import NumericMeta
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Types whose instances are treated as static: the dataclass helpers in this module
# test field values against this tuple with isinstance() to decide which fields to
# exclude from MLIR value / type / pointer collection.
StaticTypes = (
    cutlass.Constexpr,
    NumericMeta,
    int,
    bool,
    str,
    float,
    type(None),
)
|
|
12
22
|
|
|
13
23
|
|
|
14
24
|
# Module-level handles to the cutlass / cute callables as they are at import time.
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
cute_compile_og = cute.compile
|
|
16
26
|
|
|
17
27
|
|
|
28
|
+
# Lookup table: torch dtype -> matching cutlass numeric type.
# Only the three floating-point dtypes listed here are covered.
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@lru_cache
def get_max_active_clusters(cluster_size):
    """Return ``HardwareInfo().get_max_active_clusters`` for ``cluster_size``.

    The result is memoized per ``cluster_size`` by ``lru_cache``, so the hardware
    query is issued only the first time a given size is requested.
    """
    hardware_info = cutlass.utils.HardwareInfo()
    return hardware_info.get_max_active_clusters(cluster_size=cluster_size)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class ParamsBase:
    """Dataclass base that flattens its non-static fields to MLIR values and back.

    A field is considered static when its current value is an instance of one of
    ``StaticTypes``; static fields are never flattened and are passed through
    unchanged when a new instance is rebuilt.
    """

    def __extract_mlir_values__(self):
        """Return the concatenated MLIR values of every non-static field.

        As a side effect, ``self._values_pos`` is set to the list of per-field value
        counts (in field order) so ``__new_from_mlir_values__`` can split the flat
        list again.
        """
        dynamic_members = [
            member
            for member in (getattr(self, spec.name) for spec in fields(self))
            if not isinstance(member, StaticTypes)
        ]
        flat_values = []
        counts = []
        # Bind before extraction starts; later appends to `counts` are visible here.
        self._values_pos = counts
        for member in dynamic_members:
            member_values = cutlass.extract_mlir_values(member)
            flat_values.extend(member_values)
            counts.append(len(member_values))
        return flat_values

    def __new_from_mlir_values__(self, values):
        """Build a new instance of the same class from a flat list of MLIR values.

        ``values`` is consumed in field order using the counts stored in
        ``self._values_pos`` by ``__extract_mlir_values__``. Static fields are
        copied over as they are.
        """
        static_kwargs = {}
        dynamic_kwargs = {}
        for spec in fields(self):
            member = getattr(self, spec.name)
            if isinstance(member, StaticTypes):
                static_kwargs[spec.name] = member
            else:
                dynamic_kwargs[spec.name] = member
        cursor = 0
        # zip() stops at the shorter input; any field left over keeps its old value.
        for name, count in zip(list(dynamic_kwargs), self._values_pos):
            chunk = values[cursor : cursor + count]
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(dynamic_kwargs[name], chunk)
            cursor += count
        return self.__class__(**dynamic_kwargs, **static_kwargs)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class ArgumentsBase(JitArgument):
    """Dataclass base implementing the ``JitArgument`` hooks by delegating to its fields.

    Fields whose value is an instance of one of ``StaticTypes`` are treated as static
    and skipped by every hook; the remaining fields are visited in declaration order.
    """

    def __c_pointers__(self):
        """Return the concatenated ``__c_pointers__()`` of the non-static fields.

        Non-static fields that do not define ``__c_pointers__`` contribute nothing.
        """
        pointers = []
        for spec in fields(self):
            member = getattr(self, spec.name)
            if isinstance(member, StaticTypes):
                continue
            if hasattr(member, "__c_pointers__"):
                pointers.extend(member.__c_pointers__())
        return pointers

    def __get_mlir_types__(self):
        """Return the concatenated ``__get_mlir_types__()`` of the non-static fields.

        Non-static fields that do not define ``__get_mlir_types__`` contribute nothing.
        """
        mlir_types = []
        for spec in fields(self):
            member = getattr(self, spec.name)
            if isinstance(member, StaticTypes):
                continue
            if hasattr(member, "__get_mlir_types__"):
                mlir_types.extend(member.__get_mlir_types__())
        return mlir_types

    def __new_from_mlir_values__(self, values):
        """Build a new instance of the same class from a flat list of MLIR values.

        Static fields are copied over unchanged. Every non-static field is rebuilt
        from exactly one entry of ``values``, taken in field order.
        """
        static_kwargs = {}
        dynamic_kwargs = {}
        for spec in fields(self):
            member = getattr(self, spec.name)
            if isinstance(member, StaticTypes):
                static_kwargs[spec.name] = member
            else:
                dynamic_kwargs[spec.name] = member
        # NOTE: unlike ParamsBase there is no recorded per-field value count here;
        # the code hard-codes one value per non-static field.
        for index, name in enumerate(list(dynamic_kwargs)):
            single = values[index : index + 1]
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(dynamic_kwargs[name], single)
        return self.__class__(**dynamic_kwargs, **static_kwargs)
|
|
98
|
+
|
|
99
|
+
|
|
18
100
|
def load_cubin_module_data_patched(cubin_data, filepath):
|
|
19
101
|
path = pathlib.Path(filepath)
|
|
20
102
|
path.write_bytes(cubin_data)
|
|
@@ -35,6 +117,3 @@ def cute_compile_patched(*args, **kwargs):
|
|
|
35
117
|
sass = extract(cubin_path, None)
|
|
36
118
|
cubin_path.with_suffix(".annotated.sass").write_text(sass)
|
|
37
119
|
return output
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
cute.compile = cute_compile_patched
|
quack/dense_gemm_sm100.py
CHANGED
|
@@ -43,10 +43,10 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils
|
|
|
43
43
|
from cutlass.cute.runtime import from_dlpack, make_ptr
|
|
44
44
|
from cutlass import Int32, const_expr
|
|
45
45
|
|
|
46
|
+
from quack.cute_dsl_utils import ParamsBase
|
|
46
47
|
from quack.tile_scheduler import (
|
|
47
48
|
TileSchedulerArguments,
|
|
48
49
|
TileScheduler,
|
|
49
|
-
ParamsBase,
|
|
50
50
|
RasterOrderOption,
|
|
51
51
|
)
|
|
52
52
|
|