quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/cute_dsl_utils.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2025, Tri Dao.
|
|
2
2
|
|
|
3
|
-
import
|
|
4
|
-
import
|
|
5
|
-
from functools import partial, lru_cache
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
from functools import lru_cache
|
|
6
5
|
from dataclasses import dataclass, fields
|
|
7
6
|
|
|
8
7
|
import torch
|
|
@@ -14,6 +13,7 @@ except ImportError:
|
|
|
14
13
|
|
|
15
14
|
import cutlass
|
|
16
15
|
import cutlass.cute as cute
|
|
16
|
+
from cutlass import Int32, Int64, Float16, BFloat16, Float32
|
|
17
17
|
from cutlass.base_dsl.typing import JitArgument
|
|
18
18
|
from cutlass.cutlass_dsl import NumericMeta
|
|
19
19
|
|
|
@@ -26,9 +26,11 @@ cute_compile_og = cute.compile
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
# Lookup table translating torch dtypes into their cutlass DSL numeric types.
torch2cute_dtype_map = {
    torch_dtype: cute_dtype
    for torch_dtype, cute_dtype in (
        (torch.float16, Float16),
        (torch.bfloat16, BFloat16),
        (torch.float32, Float32),
        (torch.int32, Int32),
        (torch.int64, Int64),
    )
}
|
|
33
35
|
|
|
34
36
|
|
|
@@ -37,6 +39,11 @@ def get_max_active_clusters(cluster_size):
|
|
|
37
39
|
return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
|
|
38
40
|
|
|
39
41
|
|
|
42
|
+
@lru_cache
def get_device_capacity(device: Optional[torch.device] = None) -> Tuple[int, int]:
    """Return the CUDA compute capability ``(major, minor)`` of ``device``.

    Args:
        device: CUDA device to query; ``None`` queries the current device.

    Cached with ``lru_cache``: a device's capability never changes within a
    process, so repeated driver queries are avoided.
    """
    # Fix: the parameter defaults to None, so the annotation must be
    # Optional[torch.device] (PEP 484 implicit Optional is deprecated).
    return torch.cuda.get_device_capability(device)
|
|
45
|
+
|
|
46
|
+
|
|
40
47
|
@dataclass
|
|
41
48
|
class ParamsBase:
|
|
42
49
|
def __extract_mlir_values__(self):
|
|
@@ -75,10 +82,14 @@ class ArgumentsBase(JitArgument):
|
|
|
75
82
|
def __get_mlir_types__(self):
    """Collect the MLIR types of all non-constexpr dataclass fields.

    Side effect: records in ``self._values_pos`` how many MLIR values each
    non-constexpr field contributes (0 for fields without a
    ``__get_mlir_types__`` hook), so ``__new_from_mlir_values__`` can split
    a flat value list back into per-field chunks.
    """
    all_fields = [getattr(self, field.name) for field in fields(self)]
    # Constexpr (static) fields carry no runtime MLIR representation.
    non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
    types, self._values_pos = [], []
    for obj in non_constexpr_fields:
        if hasattr(obj, "__get_mlir_types__"):
            obj_types = obj.__get_mlir_types__()
            types.extend(obj_types)
            self._values_pos.append(len(obj_types))
        else:
            # Field has no MLIR values of its own; remember a zero-length
            # slot so positional bookkeeping stays aligned with the fields.
            self._values_pos.append(0)
    return types
|
|
83
94
|
|
|
84
95
|
def __new_from_mlir_values__(self, values):
|
|
@@ -87,32 +98,7 @@ class ArgumentsBase(JitArgument):
|
|
|
87
98
|
non_constexpr_fields = {
|
|
88
99
|
n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
|
|
89
100
|
}
|
|
90
|
-
|
|
91
|
-
for name, field in non_constexpr_fields.items():
|
|
92
|
-
# non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
|
|
93
|
-
# values = values[n_items:]
|
|
94
|
-
n_items = 1
|
|
101
|
+
for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
|
|
95
102
|
non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
|
|
96
103
|
values = values[n_items:]
|
|
97
104
|
return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def load_cubin_module_data_patched(cubin_data, filepath):
|
|
101
|
-
pathlib.Path(filepath).write_bytes(cubin_data)
|
|
102
|
-
return load_cubin_module_data_og(cubin_data)
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def cute_compile_patched(*args, **kwargs):
|
|
106
|
-
"""A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set."""
|
|
107
|
-
cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
|
|
108
|
-
if cubin_path is not None:
|
|
109
|
-
cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
|
|
110
|
-
load_cubin_module_data_patched, filepath=cubin_path
|
|
111
|
-
)
|
|
112
|
-
output = cute_compile_og(*args, **kwargs)
|
|
113
|
-
if cubin_path is not None:
|
|
114
|
-
cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
|
|
115
|
-
if extract is not None:
|
|
116
|
-
sass = extract(cubin_path, None)
|
|
117
|
-
pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
|
|
118
|
-
return output
|
quack/gemm.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
import cutlass.cute as cute
|
|
7
|
+
import cutlass.torch as cutlass_torch
|
|
8
|
+
from cutlass import Float32
|
|
9
|
+
from cutlass.cute.runtime import from_dlpack, make_ptr
|
|
10
|
+
|
|
11
|
+
from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
|
12
|
+
from quack.gemm_wrapper_utils import GemmWrapperBase
|
|
13
|
+
from quack.gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def gemm(
    # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k
    A: Tensor,
    B: Tensor,  # (l, n, k) or (n, total_k) if varlen_k
    D: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
    tile_count_semaphore: Optional[Tensor],  # (1,)
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = False,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    rowvec_bias: Optional[Tensor] = None,  # (l, n)
    colvec_bias: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
    cu_seqlens_k: Optional[Tensor] = None,  # (l+1,) cumulative sum of k values for variable length
    A_idx: Optional[Tensor] = None,  # (total_m,) or (total_k,) indices for gather_A when varlen
    batch_idx_permute: Optional[Tensor] = None,  # (l,) permutation of batch indices for scheduler
    add_to_output: bool = False,
) -> None:
    """Launch a batched GEMM on SM90/SM100 GPUs via the CUTE DSL.

    The result is written into ``D`` (the function returns ``None``).
    Presumably the epilogue computes alpha * (A @ B) + beta * C plus the
    optional row/col vector biases — confirm against
    GemmDefaultSm90/GemmDefaultSm100's EpilogueArguments.

    Variable-length batching is supported along M (``cu_seqlens_m``) or K
    (``cu_seqlens_k``), but not both; ``A_idx`` additionally gathers rows
    (or K-slices) of A and requires varlen mode with ``cluster_N == 1``.

    Compiled kernels are memoized in ``gemm.compile_cache``, keyed on
    dtypes, tile/cluster shapes, and every flag that affects codegen.

    Raises:
        TypeError: if the dtype/major-order combination is unsupported.
        AssertionError: on invalid argument combinations or strides.
    """
    varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
    assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
        "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
    )
    gather_A = A_idx is not None
    if gather_A:
        assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
        assert cluster_N == 1, "gather_A requires cluster_N=1"
    if varlen:
        assert persistent, "varlen requires persistent=True"
    if add_to_output:
        assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
    # Varlen modes pin the memory layout of the ragged operands.
    if cu_seqlens_m is not None:
        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
        assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
    if cu_seqlens_k is not None:
        assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
        assert B.stride(-2) == 1, "varlen_k requires B to be n-major"

    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
    )
    GemmWrapperBase.permute_tensors(
        tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
    )
    GemmWrapperBase.extract_dtypes(tensor_infos)
    # Expected logical dimension order (fastest to slowest) for each operand.
    major_configs = {
        "A": ("m", "k", "l"),
        "B": ("n", "k", "l"),
        "D": ("m", "n", "l"),
        "C": ("m", "n", "l"),
    }
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    # Pick the kernel class from the GPU generation: SM90 (Hopper) vs SM100.
    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90

    acc_dtype = Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    if not GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    ):
        raise TypeError("Skipping due to unsupported combination of types and majors")

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)

    def scalar_arg(scalar: float | Tensor):
        # Scale factors: None signals the compile-time-1.0 fast path; a
        # Tensor becomes a gmem pointer read by the kernel at runtime.
        if isinstance(scalar, float):
            return Float32(scalar) if scalar != 1.0 else None
        else:
            assert isinstance(scalar, Tensor)
            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)

    epi_args = GemmCls.EpilogueArguments(
        scalar_arg(alpha),
        scalar_arg(beta),
        mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=1
        )
        if rowvec_bias is not None
        else None,
        # With varlen_m the colvec bias is 1-D (total_m,), so its leading dim is 0.
        mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=1 if cu_seqlens_m is None else 0
        )
        if colvec_bias is not None
        else None,
        add_to_output=add_to_output,
    )
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters,
        tile_count_semaphore,
        batch_idx_permute,
        max_swizzle_size,
    )

    # Create varlen arguments if needed (assumes persistent=True when varlen)
    varlen_args = GemmWrapperBase.create_varlen_args(
        cu_seqlens_m,
        cu_seqlens_k,
        A_idx,
        max_active_clusters,
        cluster_shape_mnk,
        tensor_infos,
        GemmCls.num_epi_tensormaps,
        pingpong,
    )

    current_stream = cutlass_torch.current_stream()
    # The compile key must cover everything that changes generated code;
    # alpha/beta are encoded as 0 = static scalar, 1 = identity (1.0), 2 = tensor.
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        None,  # activation
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        # Technically we don't need to recompile for different max_swizzle_size, but currently
        # not recompiling will skew the autotuning results due to power throttling.
        # Effectively we're recompiling as a way to pause between benchmarks during autotuning.
        max_swizzle_size,
        rowvec_bias.dtype if rowvec_bias is not None else None,
        colvec_bias.dtype if colvec_bias is not None else None,
        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
        add_to_output,
        cu_seqlens_m is not None,
        cu_seqlens_k is not None,
        gather_A,
        batch_idx_permute is not None,
        key_tensor_names=("A", "B", "D", "C"),
    )
    cache = gemm.compile_cache
    if compile_key not in cache:
        # SM90 kernels take pingpong/persistence as constructor options;
        # SM100 presumably does not — confirm against GemmDefaultSm100.
        if device_capacity[0] == 9:
            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
        gemm_obj = GemmCls(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(
            gemm_obj,
            tensor_infos["A"].cute_tensor,
            tensor_infos["B"].cute_tensor,
            tensor_infos["D"].cute_tensor,
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
            varlen_args,
            current_stream,
        )
    # Launch the (possibly just-compiled) kernel.
    cache[compile_key](
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )


# Per-process cache of compiled kernels, keyed by get_compile_key output.
gemm.compile_cache = {}
|