quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +16 -25
- quack/autotuner.py +64 -5
- quack/cross_entropy.py +6 -10
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/layernorm.py +1 -1
- quack/reduce.py +6 -7
- quack/rmsnorm.py +126 -158
- quack/softmax.py +1 -1
- quack/tile_scheduler.py +37 -49
- quack/utils.py +61 -71
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +3 -3
- quack_kernels-0.2.2.dist-info/RECORD +37 -0
- quack_kernels-0.2.0.dist-info/RECORD +0 -37
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
quack/gemm_wrapper_utils.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
from typing import Optional, Tuple, Dict, Any
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
|
|
5
|
+
import torch
|
|
5
6
|
from torch import Tensor
|
|
6
7
|
|
|
7
8
|
import cutlass.cute as cute
|
|
@@ -9,6 +10,7 @@ from cutlass import Int32
|
|
|
9
10
|
from cutlass.cute.runtime import from_dlpack, make_ptr
|
|
10
11
|
|
|
11
12
|
from quack.cute_dsl_utils import torch2cute_dtype_map
|
|
13
|
+
from quack.varlen_utils import VarlenArguments
|
|
12
14
|
from quack.dense_gemm_sm90 import TileSchedulerOptions
|
|
13
15
|
|
|
14
16
|
|
|
@@ -22,8 +24,8 @@ class GemmTensorInfo:
|
|
|
22
24
|
|
|
23
25
|
class GemmWrapperBase:
|
|
24
26
|
@staticmethod
|
|
25
|
-
def
|
|
26
|
-
assert tensor.dim() ==
|
|
27
|
+
def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None:
|
|
28
|
+
assert tensor.dim() == ndim and tensor.is_cuda, f"{name} must be a {ndim}D CUDA tensor"
|
|
27
29
|
assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"
|
|
28
30
|
|
|
29
31
|
@staticmethod
|
|
@@ -47,7 +49,7 @@ class GemmWrapperBase:
|
|
|
47
49
|
) -> Optional[cute.Tensor]:
|
|
48
50
|
if tensor is None:
|
|
49
51
|
return None
|
|
50
|
-
# Tensor is already permuted to (dims[0], dims[1], dims[2])
|
|
52
|
+
# Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dims[0], dims[1])
|
|
51
53
|
# If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
|
|
52
54
|
leading_dim = 1 if major == dims[1] else 0
|
|
53
55
|
return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
|
|
@@ -61,43 +63,131 @@ class GemmWrapperBase:
|
|
|
61
63
|
D: Optional[Tensor] = None,
|
|
62
64
|
C: Optional[Tensor] = None,
|
|
63
65
|
additional_tensors: Optional[Dict[str, Tensor]] = None,
|
|
66
|
+
cu_seqlens_m: Optional[Tensor] = None,
|
|
67
|
+
cu_seqlens_k: Optional[Tensor] = None,
|
|
68
|
+
A_idx: Optional[Tensor] = None,
|
|
64
69
|
) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
_, N, _ = B.shape
|
|
70
|
+
assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
|
|
71
|
+
"Only one of cu_seqlens_m and cu_seqlens_k can be specified"
|
|
72
|
+
)
|
|
69
73
|
assert B.dtype == A.dtype, "A and B must have the same dtype"
|
|
70
|
-
|
|
74
|
+
|
|
75
|
+
# Validate A_idx if provided (for gather_A case)
|
|
76
|
+
gather_A = A_idx is not None
|
|
77
|
+
if gather_A:
|
|
78
|
+
assert cu_seqlens_m is not None or cu_seqlens_k is not None, (
|
|
79
|
+
"gather_A requires either varlen_m or varlen_k"
|
|
80
|
+
)
|
|
81
|
+
assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}"
|
|
82
|
+
assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D"
|
|
83
|
+
|
|
84
|
+
# Determine mode and extract dimensions
|
|
85
|
+
if cu_seqlens_m is not None:
|
|
86
|
+
# varlen_m: A is (total_m, k) or (whatever, k) if gather_A, B is (l, n, k), D/C are (total_m, n)
|
|
87
|
+
assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D"
|
|
88
|
+
assert B.dim() == 3, f"B must be 3D with varlen_m, got {B.dim()}D"
|
|
89
|
+
|
|
90
|
+
if gather_A:
|
|
91
|
+
# When gather_A, A can have any number of rows, we use A_idx.shape[0] as total_M
|
|
92
|
+
total_M = A_idx.shape[0]
|
|
93
|
+
_, K = A.shape
|
|
94
|
+
else:
|
|
95
|
+
total_M, K = A.shape
|
|
96
|
+
|
|
97
|
+
L, N, K_B = B.shape
|
|
98
|
+
assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
|
|
99
|
+
assert cu_seqlens_m.shape == (L + 1,), (
|
|
100
|
+
f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}"
|
|
101
|
+
)
|
|
102
|
+
M = total_M
|
|
103
|
+
dc_shape = (total_M, N)
|
|
104
|
+
dc_ndim = 2
|
|
105
|
+
elif cu_seqlens_k is not None:
|
|
106
|
+
# varlen_k: A is (m, total_k) or (m, whatever) if gather_A, B is (n, total_k), D/C are (l, m, n)
|
|
107
|
+
assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D"
|
|
108
|
+
assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D"
|
|
109
|
+
|
|
110
|
+
if gather_A:
|
|
111
|
+
# When gather_A with varlen_k, A can have any number of columns, we use A_idx.shape[0] as total_K
|
|
112
|
+
M, _ = A.shape
|
|
113
|
+
total_K = A_idx.shape[0]
|
|
114
|
+
else:
|
|
115
|
+
M, total_K = A.shape
|
|
116
|
+
|
|
117
|
+
N, K_B = B.shape
|
|
118
|
+
assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}"
|
|
119
|
+
L = cu_seqlens_k.shape[0] - 1
|
|
120
|
+
assert cu_seqlens_k.shape == (L + 1,), (
|
|
121
|
+
f"cu_seqlens_k must have shape ({L + 1},), got {cu_seqlens_k.shape}"
|
|
122
|
+
)
|
|
123
|
+
K = total_K
|
|
124
|
+
dc_shape = (L, M, N)
|
|
125
|
+
dc_ndim = 3
|
|
126
|
+
else:
|
|
127
|
+
# Normal case - all tensors must be 3D
|
|
128
|
+
GemmWrapperBase.validate_tensor(A, "A", 3)
|
|
129
|
+
GemmWrapperBase.validate_tensor(B, "B", 3)
|
|
130
|
+
L, M, K = A.shape
|
|
131
|
+
_, N, K_B = B.shape
|
|
132
|
+
assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
|
|
133
|
+
GemmWrapperBase.validate_shape(B, (L, N, K), "B")
|
|
134
|
+
dc_shape = (L, M, N)
|
|
135
|
+
dc_ndim = 3
|
|
136
|
+
|
|
137
|
+
# Validate D and C shapes uniformly
|
|
138
|
+
for tensor, name in [(D, "D"), (C, "C")]:
|
|
139
|
+
if tensor is not None:
|
|
140
|
+
assert tensor.dim() == dc_ndim, (
|
|
141
|
+
f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
|
|
142
|
+
)
|
|
143
|
+
assert tensor.shape == dc_shape, (
|
|
144
|
+
f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
|
|
145
|
+
)
|
|
146
|
+
|
|
71
147
|
tensors = {
|
|
72
148
|
"A": GemmTensorInfo(A),
|
|
73
149
|
"B": GemmTensorInfo(B),
|
|
74
150
|
"D": GemmTensorInfo(D),
|
|
75
151
|
"C": GemmTensorInfo(C),
|
|
76
152
|
}
|
|
77
|
-
|
|
78
|
-
GemmWrapperBase.validate_tensor_3d(D, "D")
|
|
79
|
-
GemmWrapperBase.validate_shape(D, (L, M, N), "D")
|
|
80
|
-
if C is not None:
|
|
81
|
-
GemmWrapperBase.validate_tensor_3d(C, "C")
|
|
82
|
-
GemmWrapperBase.validate_shape(C, (L, M, N), "C")
|
|
153
|
+
|
|
83
154
|
if additional_tensors:
|
|
84
155
|
for name, tensor in additional_tensors.items():
|
|
85
156
|
if tensor is not None:
|
|
86
|
-
|
|
87
|
-
|
|
157
|
+
assert tensor.dim() == dc_ndim, (
|
|
158
|
+
f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
|
|
159
|
+
)
|
|
160
|
+
assert tensor.shape == dc_shape, (
|
|
161
|
+
f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
|
|
162
|
+
)
|
|
88
163
|
tensors[name] = GemmTensorInfo(tensor)
|
|
89
164
|
|
|
90
165
|
return L, M, K, N, tensors
|
|
91
166
|
|
|
92
167
|
@staticmethod
|
|
93
|
-
def permute_tensors(
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
168
|
+
def permute_tensors(
|
|
169
|
+
tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False
|
|
170
|
+
) -> None:
|
|
171
|
+
# Determine which tensors need permutation
|
|
172
|
+
if varlen_m:
|
|
173
|
+
# Only B needs permutation (3D tensor)
|
|
174
|
+
tensors_to_permute = ["B"]
|
|
175
|
+
elif varlen_k:
|
|
176
|
+
# Only D and C need permutation (3D tensors)
|
|
177
|
+
tensors_to_permute = ["D", "C"]
|
|
178
|
+
else:
|
|
179
|
+
# All tensors need permutation
|
|
180
|
+
tensors_to_permute = None
|
|
181
|
+
|
|
182
|
+
# Apply permutation from (L, *, *) -> (*, *, L) for selected tensors
|
|
183
|
+
for name, info in tensors.items():
|
|
184
|
+
if info.tensor is not None and info.tensor.ndim == 3:
|
|
185
|
+
if tensors_to_permute is None or name in tensors_to_permute:
|
|
186
|
+
info.tensor = info.tensor.permute(1, 2, 0)
|
|
97
187
|
|
|
98
188
|
@staticmethod
|
|
99
189
|
def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
|
|
100
|
-
for info in tensors.
|
|
190
|
+
for name, info in tensors.items():
|
|
101
191
|
if info.tensor is not None:
|
|
102
192
|
info.dtype = torch2cute_dtype_map[info.tensor.dtype]
|
|
103
193
|
|
|
@@ -121,7 +211,9 @@ class GemmWrapperBase:
|
|
|
121
211
|
|
|
122
212
|
@staticmethod
|
|
123
213
|
def create_scheduler_args(
|
|
124
|
-
max_active_clusters: int,
|
|
214
|
+
max_active_clusters: int,
|
|
215
|
+
tile_count_semaphore: Optional[Tensor] = None,
|
|
216
|
+
batch_idx_permute: Optional[Tensor] = None,
|
|
125
217
|
) -> TileSchedulerOptions:
|
|
126
218
|
return TileSchedulerOptions(
|
|
127
219
|
Int32(max_active_clusters),
|
|
@@ -130,6 +222,71 @@ class GemmWrapperBase:
|
|
|
130
222
|
)
|
|
131
223
|
if tile_count_semaphore is not None
|
|
132
224
|
else None,
|
|
225
|
+
batch_idx_permute=(
|
|
226
|
+
from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
227
|
+
)
|
|
228
|
+
if batch_idx_permute is not None
|
|
229
|
+
else None,
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
@staticmethod
|
|
233
|
+
def create_varlen_args(
|
|
234
|
+
cu_seqlens_m: Optional[Tensor],
|
|
235
|
+
cu_seqlens_k: Optional[Tensor],
|
|
236
|
+
A_idx: Optional[Tensor],
|
|
237
|
+
max_active_clusters: int,
|
|
238
|
+
cluster_shape_mnk: Tuple[int, int, int],
|
|
239
|
+
tensors: Dict[str, GemmTensorInfo],
|
|
240
|
+
num_epi_tensormaps: int = 0,
|
|
241
|
+
pingpong: bool = False,
|
|
242
|
+
) -> Optional[Any]:
|
|
243
|
+
if cu_seqlens_m is None and cu_seqlens_k is None:
|
|
244
|
+
return None
|
|
245
|
+
# When varlen_m, we assume persistent=True
|
|
246
|
+
# Grid size depends on num_active_clusters and cluster size
|
|
247
|
+
cluster_size = cluster_shape_mnk[0] * cluster_shape_mnk[1]
|
|
248
|
+
num_blocks = max_active_clusters * cluster_size
|
|
249
|
+
# Calculate number of tensormaps needed
|
|
250
|
+
if cu_seqlens_m is not None:
|
|
251
|
+
# For varlen_m: need tensormaps for D and epilogue tensors
|
|
252
|
+
num_tensormaps = num_epi_tensormaps * (1 if not pingpong else 2)
|
|
253
|
+
if tensors["D"].tensor is not None:
|
|
254
|
+
num_tensormaps += 1 if not pingpong else 2 # D tensormap
|
|
255
|
+
else:
|
|
256
|
+
# For varlen_k: need tensormaps for A & B
|
|
257
|
+
num_tensormaps = 2 if A_idx is None else 1
|
|
258
|
+
# Create tensormap buffer (each tensormap is 128 bytes = 16 int64s)
|
|
259
|
+
tensormap_size = 128 // 8 # 16 int64s
|
|
260
|
+
if num_tensormaps > 0:
|
|
261
|
+
device = cu_seqlens_m.device if cu_seqlens_m is not None else cu_seqlens_k.device
|
|
262
|
+
tensormaps = torch.empty(
|
|
263
|
+
(num_blocks, num_tensormaps, tensormap_size),
|
|
264
|
+
dtype=torch.int64,
|
|
265
|
+
device=device,
|
|
266
|
+
)
|
|
267
|
+
tensormaps_cute = from_dlpack(tensormaps, assumed_align=128).mark_compact_shape_dynamic(
|
|
268
|
+
mode=0, stride_order=(0, 1, 2)
|
|
269
|
+
)
|
|
270
|
+
else:
|
|
271
|
+
tensormaps_cute = None
|
|
272
|
+
|
|
273
|
+
return VarlenArguments(
|
|
274
|
+
mCuSeqlensM=(
|
|
275
|
+
from_dlpack(cu_seqlens_m, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
276
|
+
if cu_seqlens_m is not None
|
|
277
|
+
else None
|
|
278
|
+
),
|
|
279
|
+
mCuSeqlensK=(
|
|
280
|
+
from_dlpack(cu_seqlens_k, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
281
|
+
if cu_seqlens_k is not None
|
|
282
|
+
else None
|
|
283
|
+
),
|
|
284
|
+
mTensormaps=tensormaps_cute,
|
|
285
|
+
mAIdx=(
|
|
286
|
+
from_dlpack(A_idx, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
287
|
+
if A_idx is not None
|
|
288
|
+
else None
|
|
289
|
+
),
|
|
133
290
|
)
|
|
134
291
|
|
|
135
292
|
@staticmethod
|
quack/layernorm.py
CHANGED
|
@@ -217,7 +217,7 @@ class LayerNorm(ReductionBase):
|
|
|
217
217
|
mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
218
218
|
init_val=0.0,
|
|
219
219
|
)
|
|
220
|
-
rstd =
|
|
220
|
+
rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
|
|
221
221
|
if cutlass.const_expr(mRstd is not None):
|
|
222
222
|
# Only the thread corresponding to column 0 writes out the rstd to gmem
|
|
223
223
|
if (
|
quack/reduce.py
CHANGED
|
@@ -159,8 +159,7 @@ def online_softmax_reduce(
|
|
|
159
159
|
width=min(threads_per_row, cute.arch.WARP_SIZE),
|
|
160
160
|
)
|
|
161
161
|
log2_e = math.log2(math.e)
|
|
162
|
-
exp_x =
|
|
163
|
-
# exp_x = exp2f((x - max_x) * log2_e)
|
|
162
|
+
exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
|
|
164
163
|
sum_exp_x = warp_reduce(
|
|
165
164
|
exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
|
|
166
165
|
operator.add,
|
|
@@ -190,10 +189,10 @@ def online_softmax_reduce(
|
|
|
190
189
|
reduction_buffer[row_idx, lane_idx]
|
|
191
190
|
)
|
|
192
191
|
max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
|
|
193
|
-
sum_exp_x *=
|
|
192
|
+
sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
|
|
194
193
|
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
195
194
|
if cutlass.const_expr(return_exp_x):
|
|
196
|
-
exp_x *=
|
|
195
|
+
exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
|
|
197
196
|
max_x = max_x_final
|
|
198
197
|
else:
|
|
199
198
|
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
|
@@ -231,11 +230,11 @@ def online_softmax_reduce(
|
|
|
231
230
|
max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
|
|
232
231
|
sum_exp_x = 0.0
|
|
233
232
|
for i in cutlass.range_constexpr(num_iter):
|
|
234
|
-
sum_exp_x += sum_exp_x_single_warp[i] *
|
|
235
|
-
|
|
233
|
+
sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
|
|
234
|
+
max_x_single_warp[i] - max_x_final, fastmath=True
|
|
236
235
|
)
|
|
237
236
|
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
238
237
|
if cutlass.const_expr(return_exp_x):
|
|
239
|
-
exp_x *=
|
|
238
|
+
exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
|
|
240
239
|
max_x = max_x_final
|
|
241
240
|
return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
|