quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/utils.py
CHANGED
|
@@ -1,25 +1,27 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
3
|
import math
|
|
4
|
-
from
|
|
4
|
+
from functools import partial
|
|
5
|
+
from typing import Optional, Tuple, Union
|
|
5
6
|
|
|
6
7
|
import cutlass
|
|
7
8
|
import cutlass.cute as cute
|
|
8
9
|
|
|
9
|
-
from cutlass import Float32, Int32
|
|
10
|
+
from cutlass import Float32, Int32, const_expr
|
|
10
11
|
from cutlass.cutlass_dsl import T, dsl_user_op
|
|
11
12
|
from cutlass._mlir.dialects import llvm, nvvm, vector
|
|
12
|
-
from cutlass.cute.runtime import from_dlpack
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
15
|
+
# cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default
|
|
16
|
+
fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
|
|
17
|
+
mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
|
|
18
|
+
add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
|
|
19
|
+
sub_packed_f32x2 = partial(
|
|
20
|
+
cute.arch.calc_packed_f32x2_op,
|
|
21
|
+
src_c=None,
|
|
22
|
+
calc_func=nvvm.sub_packed_f32x2,
|
|
23
|
+
rnd=nvvm.RoundingModeKind.RN,
|
|
24
|
+
)
|
|
23
25
|
|
|
24
26
|
|
|
25
27
|
@dsl_user_op
|
|
@@ -29,7 +31,7 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut
|
|
|
29
31
|
|
|
30
32
|
@cute.jit
|
|
31
33
|
def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
|
|
32
|
-
if
|
|
34
|
+
if const_expr(isinstance(x, cute.Pointer)):
|
|
33
35
|
return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
|
|
34
36
|
else:
|
|
35
37
|
assert isinstance(x, Float32)
|
|
@@ -38,11 +40,11 @@ def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
|
|
|
38
40
|
|
|
39
41
|
@dsl_user_op
|
|
40
42
|
def set_block_rank(
|
|
41
|
-
smem_ptr: cute.Pointer, peer_cta_rank_in_cluster:
|
|
42
|
-
) ->
|
|
43
|
+
smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
|
|
44
|
+
) -> Int32:
|
|
43
45
|
"""Map the given smem pointer to the address at another CTA rank in the cluster."""
|
|
44
46
|
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
|
45
|
-
return
|
|
47
|
+
return Int32(
|
|
46
48
|
llvm.inline_asm(
|
|
47
49
|
T.i32(),
|
|
48
50
|
[smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
|
|
@@ -57,7 +59,7 @@ def set_block_rank(
|
|
|
57
59
|
|
|
58
60
|
@dsl_user_op
|
|
59
61
|
def store_shared_remote(
|
|
60
|
-
val: float | Float32 | cutlass.Int64,
|
|
62
|
+
val: float | Float32 | Int32 | cutlass.Int64,
|
|
61
63
|
smem_ptr: cute.Pointer,
|
|
62
64
|
mbar_ptr: cute.Pointer,
|
|
63
65
|
peer_cta_rank_in_cluster: cute.typing.Int,
|
|
@@ -71,7 +73,7 @@ def store_shared_remote(
|
|
|
71
73
|
remote_mbar_ptr_i32 = set_block_rank(
|
|
72
74
|
mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
|
73
75
|
).ir_value()
|
|
74
|
-
if
|
|
76
|
+
if const_expr(isinstance(val, float)):
|
|
75
77
|
val = Float32(val)
|
|
76
78
|
assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
|
|
77
79
|
suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
|
|
@@ -100,6 +102,21 @@ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=Non
|
|
|
100
102
|
)
|
|
101
103
|
|
|
102
104
|
|
|
105
|
+
@dsl_user_op
|
|
106
|
+
def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
107
|
+
return Float32(
|
|
108
|
+
llvm.inline_asm(
|
|
109
|
+
T.f32(),
|
|
110
|
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
111
|
+
"sqrt.approx.f32 $0, $1;",
|
|
112
|
+
"=f,f",
|
|
113
|
+
has_side_effects=False,
|
|
114
|
+
is_align_stack=False,
|
|
115
|
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
116
|
+
)
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
|
|
103
120
|
@dsl_user_op
|
|
104
121
|
def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
|
|
105
122
|
return Int32(
|
|
@@ -134,55 +151,6 @@ def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -
|
|
|
134
151
|
)
|
|
135
152
|
|
|
136
153
|
|
|
137
|
-
@cute.jit
|
|
138
|
-
def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
|
|
139
|
-
assert t.element_type.width == 16
|
|
140
|
-
assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
|
|
141
|
-
t_u32 = cute.recast_tensor(t, Int32)
|
|
142
|
-
|
|
143
|
-
quad_idx = cute.arch.lane_idx() % 4
|
|
144
|
-
lane_03 = quad_idx == 0 or quad_idx == 3
|
|
145
|
-
selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
|
|
146
|
-
selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
|
|
147
|
-
# upper_map = [0, 3, 1, 2]
|
|
148
|
-
# lower_map = [1, 2, 0, 3]
|
|
149
|
-
# upper_idx = upper_map[quad_idx]
|
|
150
|
-
# indexing isn't supported so we have to do arithmetic
|
|
151
|
-
upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
|
|
152
|
-
lower_idx = upper_idx ^ 1
|
|
153
|
-
|
|
154
|
-
# 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
|
|
155
|
-
width = 4
|
|
156
|
-
mask = cute.arch.WARP_SIZE - width
|
|
157
|
-
clamp = cute.arch.WARP_SIZE - 1
|
|
158
|
-
mask_and_clamp = mask << 8 | clamp
|
|
159
|
-
|
|
160
|
-
for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
|
|
161
|
-
upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
|
|
162
|
-
upper0 = upper if lane_03 else lower
|
|
163
|
-
lower0 = lower if lane_03 else upper
|
|
164
|
-
upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
|
|
165
|
-
lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
|
|
166
|
-
t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
|
|
167
|
-
t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
@cute.jit
|
|
171
|
-
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
|
|
172
|
-
# Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
|
|
173
|
-
tApA = cute.make_fragment(
|
|
174
|
-
cute.make_layout(
|
|
175
|
-
(cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
|
|
176
|
-
stride=(cute.size(tAcA, mode=[2]), 0, 1),
|
|
177
|
-
),
|
|
178
|
-
cutlass.Boolean,
|
|
179
|
-
)
|
|
180
|
-
for rest_v in cutlass.range_constexpr(tApA.shape[0]):
|
|
181
|
-
for rest_k in cutlass.range_constexpr(tApA.shape[2]):
|
|
182
|
-
tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
|
|
183
|
-
return tApA
|
|
184
|
-
|
|
185
|
-
|
|
186
154
|
@cute.jit
|
|
187
155
|
def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
|
|
188
156
|
"""Fill out-of-bounds values in shared memory tensor.
|
|
@@ -196,7 +164,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu
|
|
|
196
164
|
tXrX_fill.fill(fill_value)
|
|
197
165
|
for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
|
|
198
166
|
for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
|
|
199
|
-
if
|
|
167
|
+
if const_expr(tXpX is not None):
|
|
200
168
|
if not tXpX[rest_v, 0, rest_k]:
|
|
201
169
|
cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
|
|
202
170
|
else:
|
|
@@ -228,44 +196,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
|
|
|
228
196
|
return res0, res1
|
|
229
197
|
|
|
230
198
|
|
|
231
|
-
@dsl_user_op
|
|
232
|
-
def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
|
233
|
-
flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
|
|
234
|
-
flat_stride = cute.flatten_to_tuple(tensor.stride)
|
|
235
|
-
assert len(flat_coord_i64) == len(
|
|
236
|
-
flat_stride
|
|
237
|
-
), "Coordinate and stride must have the same length"
|
|
238
|
-
offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
|
|
239
|
-
assert isinstance(tensor.iterator, cute.Pointer)
|
|
240
|
-
# HACK: we assume that applying the offset does not change the pointer alignment
|
|
241
|
-
new_ptr = cute.make_ptr(
|
|
242
|
-
tensor.element_type,
|
|
243
|
-
tensor.iterator.toint() + offset * tensor.element_type.width // 8,
|
|
244
|
-
tensor.memspace,
|
|
245
|
-
assumed_align=tensor.iterator.max_alignment,
|
|
246
|
-
)
|
|
247
|
-
return cute.make_tensor(new_ptr, tensor.layout)
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
@dsl_user_op
|
|
251
|
-
def coord_offset_i64(
|
|
252
|
-
idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
|
|
253
|
-
) -> cute.Tensor:
|
|
254
|
-
offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
|
|
255
|
-
assert isinstance(tensor.iterator, cute.Pointer)
|
|
256
|
-
# HACK: we assume that applying the offset does not change the pointer alignment
|
|
257
|
-
new_ptr = cute.make_ptr(
|
|
258
|
-
tensor.element_type,
|
|
259
|
-
tensor.iterator.toint() + offset * tensor.element_type.width // 8,
|
|
260
|
-
tensor.memspace,
|
|
261
|
-
assumed_align=tensor.iterator.max_alignment,
|
|
262
|
-
)
|
|
263
|
-
return cute.make_tensor(new_ptr, tensor.layout)
|
|
264
|
-
|
|
265
|
-
|
|
266
199
|
@cute.jit
|
|
267
|
-
def warp_prefix_sum(val:
|
|
268
|
-
if
|
|
200
|
+
def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32:
|
|
201
|
+
if const_expr(lane is None):
|
|
269
202
|
lane = cute.arch.lane_idx()
|
|
270
203
|
for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
|
|
271
204
|
offset = 1 << i
|
|
@@ -276,74 +209,6 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) ->
|
|
|
276
209
|
return val
|
|
277
210
|
|
|
278
211
|
|
|
279
|
-
def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
|
|
280
|
-
"""
|
|
281
|
-
For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
|
|
282
|
-
For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
|
|
283
|
-
"""
|
|
284
|
-
acc_layout_col_major = cute.make_layout(acc_layout.shape)
|
|
285
|
-
acc_layout_mn = cute.make_layout(
|
|
286
|
-
(
|
|
287
|
-
(acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M
|
|
288
|
-
(
|
|
289
|
-
acc_layout_col_major.shape[0][0],
|
|
290
|
-
*acc_layout_col_major.shape[0][2:],
|
|
291
|
-
acc_layout_col_major.shape[2],
|
|
292
|
-
), # MMA_N
|
|
293
|
-
*acc_layout_col_major.shape[3:],
|
|
294
|
-
),
|
|
295
|
-
stride=(
|
|
296
|
-
(acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M
|
|
297
|
-
(
|
|
298
|
-
acc_layout_col_major.stride[0][0],
|
|
299
|
-
*acc_layout_col_major.stride[0][2:],
|
|
300
|
-
acc_layout_col_major.stride[2],
|
|
301
|
-
), # MMA_N
|
|
302
|
-
*acc_layout_col_major.stride[3:],
|
|
303
|
-
),
|
|
304
|
-
)
|
|
305
|
-
return cute.composition(acc_layout, acc_layout_mn)
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
|
|
309
|
-
return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
@dsl_user_op
|
|
313
|
-
def sm90_get_smem_load_op(
|
|
314
|
-
layout_c: cutlass.utils.LayoutEnum,
|
|
315
|
-
elem_ty_c: Type[cutlass.Numeric],
|
|
316
|
-
*,
|
|
317
|
-
loc=None,
|
|
318
|
-
ip=None,
|
|
319
|
-
) -> cute.CopyAtom:
|
|
320
|
-
"""
|
|
321
|
-
Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
|
|
322
|
-
|
|
323
|
-
Parameters:
|
|
324
|
-
-----------
|
|
325
|
-
layout_c : LayoutEnum
|
|
326
|
-
The layout enum of the output tensor D.
|
|
327
|
-
|
|
328
|
-
elem_ty_c : Type[Numeric]
|
|
329
|
-
The element type for output tensor D.
|
|
330
|
-
|
|
331
|
-
Returns:
|
|
332
|
-
--------
|
|
333
|
-
Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
|
|
334
|
-
"""
|
|
335
|
-
|
|
336
|
-
if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
|
|
337
|
-
raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
|
|
338
|
-
is_m_major = layout_c.is_m_major_c()
|
|
339
|
-
if elem_ty_c.width == 16:
|
|
340
|
-
return cute.make_copy_atom(
|
|
341
|
-
cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
|
|
342
|
-
)
|
|
343
|
-
else:
|
|
344
|
-
return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
|
|
345
|
-
|
|
346
|
-
|
|
347
212
|
@dsl_user_op
|
|
348
213
|
def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
|
|
349
214
|
return nvvm.atomicrmw(
|
quack/varlen_utils.py
CHANGED
|
@@ -3,9 +3,13 @@
|
|
|
3
3
|
from typing import Optional
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
|
|
6
|
+
import cutlass
|
|
6
7
|
import cutlass.cute as cute
|
|
8
|
+
from cutlass import Int32, Boolean, const_expr
|
|
9
|
+
from cutlass.utils import LayoutEnum
|
|
7
10
|
|
|
8
|
-
from quack.cute_dsl_utils import ArgumentsBase
|
|
11
|
+
from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
|
|
12
|
+
from quack.tensormap_manager import TensorMapManagerSm90
|
|
9
13
|
|
|
10
14
|
|
|
11
15
|
# Grouping arguments together that should be passed to __call__
|
|
@@ -14,9 +18,369 @@ class VarlenArguments(ArgumentsBase):
|
|
|
14
18
|
mCuSeqlensM: Optional[cute.Tensor] = None
|
|
15
19
|
mCuSeqlensK: Optional[cute.Tensor] = None
|
|
16
20
|
mTensormaps: Optional[cute.Tensor] = None
|
|
21
|
+
mAIdx: Optional[cute.Tensor] = None
|
|
17
22
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
+
|
|
24
|
+
class VarlenManager:
|
|
25
|
+
bytes_per_tensormap = 128
|
|
26
|
+
|
|
27
|
+
@dataclass
|
|
28
|
+
class Params(ParamsBase):
|
|
29
|
+
cu_seqlens_m: Optional[cute.Tensor] = None
|
|
30
|
+
cu_seqlens_k: Optional[cute.Tensor] = None
|
|
31
|
+
tensormaps: Optional[cute.Tensor] = None
|
|
32
|
+
mAIdx: Optional[cute.Tensor] = None
|
|
33
|
+
|
|
34
|
+
@staticmethod
|
|
35
|
+
@cute.jit
|
|
36
|
+
def create(args: VarlenArguments, *, loc=None, ip=None) -> "VarlenManager.Params":
|
|
37
|
+
return VarlenManager.Params(
|
|
38
|
+
cu_seqlens_m=args.mCuSeqlensM,
|
|
39
|
+
cu_seqlens_k=args.mCuSeqlensK,
|
|
40
|
+
tensormaps=args.mTensormaps,
|
|
41
|
+
mAIdx=args.mAIdx,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
params: Params,
|
|
47
|
+
tensormap_manager: Optional[cutlass.utils.TensorMapManager],
|
|
48
|
+
tensormap_a_ptr: Optional[cute.Pointer],
|
|
49
|
+
tensormap_b_ptr: Optional[cute.Pointer],
|
|
50
|
+
tensormap_d_ptr: Optional[cute.Pointer],
|
|
51
|
+
tensormap_epi_ptrs: list[Optional[cute.Pointer]],
|
|
52
|
+
len_m_static: Int32,
|
|
53
|
+
len_k_static: Int32,
|
|
54
|
+
last_batch_idx: Int32 = Int32(-1),
|
|
55
|
+
is_group_changed: Boolean = Boolean(True),
|
|
56
|
+
*,
|
|
57
|
+
loc=None,
|
|
58
|
+
ip=None,
|
|
59
|
+
):
|
|
60
|
+
self.params = params
|
|
61
|
+
self.tensormap_manager = tensormap_manager
|
|
62
|
+
self._tensormap_a_ptr = tensormap_a_ptr
|
|
63
|
+
self._tensormap_b_ptr = tensormap_b_ptr
|
|
64
|
+
self._tensormap_d_ptr = tensormap_d_ptr
|
|
65
|
+
self._tensormap_epi_ptrs = tensormap_epi_ptrs
|
|
66
|
+
self._len_m_static = len_m_static
|
|
67
|
+
self._len_k_static = len_k_static
|
|
68
|
+
self._last_batch_idx = last_batch_idx
|
|
69
|
+
self._is_group_changed = is_group_changed
|
|
70
|
+
self.varlen_m = const_expr(params.cu_seqlens_m is not None)
|
|
71
|
+
self.varlen_k = const_expr(params.cu_seqlens_k is not None)
|
|
72
|
+
self.gather_A = const_expr(params.mAIdx is not None)
|
|
73
|
+
self._loc = loc
|
|
74
|
+
self._ip = ip
|
|
75
|
+
|
|
76
|
+
@staticmethod
|
|
77
|
+
def to_underlying_arguments(args: VarlenArguments, *, loc=None, ip=None) -> Params:
|
|
78
|
+
assert not (args.mCuSeqlensM is not None and args.mCuSeqlensK is not None), (
|
|
79
|
+
"Only support either varlen_m or varlen_k"
|
|
80
|
+
)
|
|
81
|
+
return VarlenManager.Params.create(args, loc=loc, ip=ip)
|
|
82
|
+
|
|
83
|
+
@staticmethod
|
|
84
|
+
@cute.jit
|
|
85
|
+
def create(
|
|
86
|
+
params: Params,
|
|
87
|
+
has_D: bool,
|
|
88
|
+
num_epi_tensormaps: int,
|
|
89
|
+
len_m_static: Int32,
|
|
90
|
+
len_k_static: Int32,
|
|
91
|
+
pingpong: bool = False,
|
|
92
|
+
warp_idx: int | Int32 = 0,
|
|
93
|
+
*,
|
|
94
|
+
loc=None,
|
|
95
|
+
ip=None,
|
|
96
|
+
) -> "VarlenManager":
|
|
97
|
+
tensormap_manager = None
|
|
98
|
+
tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
|
|
99
|
+
tensormap_epi_ptrs = [None] * num_epi_tensormaps
|
|
100
|
+
varlen_m = const_expr(params.cu_seqlens_m is not None)
|
|
101
|
+
varlen_k = const_expr(params.cu_seqlens_k is not None)
|
|
102
|
+
if const_expr(varlen_m or varlen_k):
|
|
103
|
+
tensormap_manager = TensorMapManagerSm90(
|
|
104
|
+
cutlass.utils.TensorMapUpdateMode.GMEM, VarlenManager.bytes_per_tensormap
|
|
105
|
+
)
|
|
106
|
+
# equivalent to bidx + bidy * gridDim.x + bidz * gridDim.x * gridDim.y
|
|
107
|
+
tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
|
|
108
|
+
if const_expr(varlen_m):
|
|
109
|
+
tensormap_d_idx = warp_idx // 4 if const_expr(pingpong) else 0
|
|
110
|
+
tensormap_epi_offset = tensormap_d_idx
|
|
111
|
+
if const_expr(has_D):
|
|
112
|
+
tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
|
|
113
|
+
params.tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
|
|
114
|
+
)
|
|
115
|
+
tensormap_epi_offset += 1 if not pingpong else 2
|
|
116
|
+
tensormap_epi_ptrs = [
|
|
117
|
+
tensormap_manager.get_tensormap_ptr(
|
|
118
|
+
params.tensormaps[
|
|
119
|
+
tensormap_workspace_idx,
|
|
120
|
+
tensormap_epi_offset + i * (1 if not pingpong else 2),
|
|
121
|
+
None,
|
|
122
|
+
].iterator
|
|
123
|
+
)
|
|
124
|
+
for i in range(num_epi_tensormaps)
|
|
125
|
+
]
|
|
126
|
+
else:
|
|
127
|
+
assert varlen_k
|
|
128
|
+
gather_A = const_expr(params.mAIdx is not None)
|
|
129
|
+
if const_expr(not gather_A):
|
|
130
|
+
tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
|
|
131
|
+
params.tensormaps[tensormap_workspace_idx, 0, None].iterator
|
|
132
|
+
)
|
|
133
|
+
tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
|
|
134
|
+
params.tensormaps[
|
|
135
|
+
tensormap_workspace_idx, 1 if not gather_A else 0, None
|
|
136
|
+
].iterator
|
|
137
|
+
)
|
|
138
|
+
return VarlenManager(
|
|
139
|
+
params,
|
|
140
|
+
tensormap_manager,
|
|
141
|
+
tensormap_a_ptr,
|
|
142
|
+
tensormap_b_ptr,
|
|
143
|
+
tensormap_d_ptr,
|
|
144
|
+
tensormap_epi_ptrs,
|
|
145
|
+
len_m_static=len_m_static,
|
|
146
|
+
len_k_static=len_k_static,
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
def len_m(self, batch_idx: Int32) -> Int32:
|
|
150
|
+
if const_expr(self.varlen_m):
|
|
151
|
+
return self.params.cu_seqlens_m[batch_idx + 1] - self.params.cu_seqlens_m[batch_idx]
|
|
152
|
+
else:
|
|
153
|
+
return self._len_m_static
|
|
154
|
+
|
|
155
|
+
def len_k(self, batch_idx: Int32) -> Int32:
|
|
156
|
+
if const_expr(self.varlen_k):
|
|
157
|
+
return self.params.cu_seqlens_k[batch_idx + 1] - self.params.cu_seqlens_k[batch_idx]
|
|
158
|
+
else:
|
|
159
|
+
return self._len_k_static
|
|
160
|
+
|
|
161
|
+
def offset_batch_A(self, mA_mkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
|
|
162
|
+
params = self.params
|
|
163
|
+
if const_expr(self.varlen_m):
|
|
164
|
+
mA_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mA_mkl)
|
|
165
|
+
elif const_expr(self.varlen_k):
|
|
166
|
+
mA_mk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mA_mkl)
|
|
167
|
+
else:
|
|
168
|
+
mA_mk = mA_mkl[None, None, batch_idx]
|
|
169
|
+
return mA_mk
|
|
170
|
+
|
|
171
|
+
def offset_batch_AIdx(self, batch_idx: Int32) -> cute.Tensor:
|
|
172
|
+
params = self.params
|
|
173
|
+
if const_expr(self.varlen_m):
|
|
174
|
+
mAIdx_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx],), params.mAIdx)
|
|
175
|
+
elif const_expr(self.varlen_k):
|
|
176
|
+
mAIdx_mk = cute.domain_offset((params.cu_seqlens_k[batch_idx],), params.mAIdx)
|
|
177
|
+
else:
|
|
178
|
+
mAIdx_mk = params.mAIdx[None, batch_idx]
|
|
179
|
+
return mAIdx_mk
|
|
180
|
+
|
|
181
|
+
def offset_batch_B(self, mB_nkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
|
|
182
|
+
params = self.params
|
|
183
|
+
if const_expr(self.varlen_k):
|
|
184
|
+
mB_nk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mB_nkl)
|
|
185
|
+
else:
|
|
186
|
+
mB_nk = mB_nkl[None, None, batch_idx]
|
|
187
|
+
return mB_nk
|
|
188
|
+
|
|
189
|
+
def offset_batch_epi(self, mD_mnl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
|
|
190
|
+
params = self.params
|
|
191
|
+
if const_expr(self.varlen_m):
|
|
192
|
+
mD_mn = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mD_mnl)
|
|
193
|
+
else:
|
|
194
|
+
mD_mn = mD_mnl[None, None, batch_idx]
|
|
195
|
+
return mD_mn
|
|
196
|
+
|
|
197
|
+
def init_tensormap_AB(
|
|
198
|
+
self,
|
|
199
|
+
tma_atom_a: Optional[cute.CopyAtom],
|
|
200
|
+
tma_atom_b: cute.CopyAtom,
|
|
201
|
+
is_manager_warp: bool | Boolean = True,
|
|
202
|
+
) -> None:
|
|
203
|
+
if const_expr(self.varlen_k):
|
|
204
|
+
if const_expr(not self.gather_A):
|
|
205
|
+
self.tensormap_manager.init_tensormap_from_atom(
|
|
206
|
+
tma_atom_a, self._tensormap_a_ptr, is_manager_warp
|
|
207
|
+
)
|
|
208
|
+
self.tensormap_manager.init_tensormap_from_atom(
|
|
209
|
+
tma_atom_b, self._tensormap_b_ptr, is_manager_warp
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
def init_tensormap_epi(
|
|
213
|
+
self,
|
|
214
|
+
tma_atom_d: Optional[cute.CopyAtom],
|
|
215
|
+
tma_atoms_epi: list[cute.CopyAtom],
|
|
216
|
+
is_manager_warp: bool | Boolean = True,
|
|
217
|
+
) -> None:
|
|
218
|
+
if const_expr(self.varlen_m):
|
|
219
|
+
if const_expr(self._tensormap_d_ptr is not None):
|
|
220
|
+
self.tensormap_manager.init_tensormap_from_atom(
|
|
221
|
+
tma_atom_d, self._tensormap_d_ptr, is_manager_warp
|
|
222
|
+
)
|
|
223
|
+
for tma_atom, tensormap_epi_ptr in zip(tma_atoms_epi, self._tensormap_epi_ptrs):
|
|
224
|
+
self.tensormap_manager.init_tensormap_from_atom(
|
|
225
|
+
tma_atom, tensormap_epi_ptr, is_manager_warp
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
def fence_tensormap_init(self) -> None:
|
|
229
|
+
self.tensormap_manager.fence_tensormap_initialization()
|
|
230
|
+
|
|
231
|
+
@cute.jit
|
|
232
|
+
def update_tensormap_AB(
|
|
233
|
+
self,
|
|
234
|
+
batch_idx: Int32,
|
|
235
|
+
a_layout: LayoutEnum,
|
|
236
|
+
b_layout: LayoutEnum,
|
|
237
|
+
is_manager_warp: bool | Boolean = True,
|
|
238
|
+
) -> None:
|
|
239
|
+
if const_expr(self.varlen_k):
|
|
240
|
+
self._is_group_changed = Boolean(batch_idx != self._last_batch_idx)
|
|
241
|
+
self._last_batch_idx = batch_idx
|
|
242
|
+
if self._is_group_changed:
|
|
243
|
+
# construct tensor A/B based on real address, shape and stride information
|
|
244
|
+
cu_seqlens_k = self.params.cu_seqlens_k
|
|
245
|
+
tensormap_ptrs = [self._tensormap_b_ptr]
|
|
246
|
+
shapes = [cu_seqlens_k[batch_idx + 1]]
|
|
247
|
+
orders = [0 if const_expr(b_layout == LayoutEnum.ROW_MAJOR) else 1]
|
|
248
|
+
if const_expr(not self.gather_A):
|
|
249
|
+
tensormap_ptrs.insert(0, self._tensormap_a_ptr)
|
|
250
|
+
shapes.insert(0, cu_seqlens_k[batch_idx + 1])
|
|
251
|
+
orders.insert(0, 0 if const_expr(a_layout == LayoutEnum.ROW_MAJOR) else 1)
|
|
252
|
+
self.tensormap_manager.update_tensormap_shape(
|
|
253
|
+
tensormap_ptrs,
|
|
254
|
+
is_manager_warp=is_manager_warp,
|
|
255
|
+
shapes=shapes,
|
|
256
|
+
orders=orders,
|
|
257
|
+
tensormap_smem_ptr=None,
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
@cute.jit
|
|
261
|
+
def update_tensormap_epi(
|
|
262
|
+
self,
|
|
263
|
+
batch_idx: Int32,
|
|
264
|
+
d_layout: LayoutEnum,
|
|
265
|
+
epi_shapes: list[Int32],
|
|
266
|
+
epi_orders: list[int],
|
|
267
|
+
is_manager_warp: bool | Boolean = True,
|
|
268
|
+
) -> None:
|
|
269
|
+
if const_expr(self.varlen_m):
|
|
270
|
+
self._is_group_changed = Boolean(batch_idx != self._last_batch_idx)
|
|
271
|
+
self._last_batch_idx = batch_idx
|
|
272
|
+
# Cute-DSL doesn't like this under if statement
|
|
273
|
+
order_d = (
|
|
274
|
+
(0 if const_expr(d_layout.is_m_major_c()) else 1) if d_layout is not None else None
|
|
275
|
+
)
|
|
276
|
+
if self._is_group_changed:
|
|
277
|
+
# construct tensor A/B based on real address, shape and stride information
|
|
278
|
+
cu_seqlens_m = self.params.cu_seqlens_m
|
|
279
|
+
# construct tensor D based on real address, shape and stride information
|
|
280
|
+
tensormap_ptrs, shapes, orders = [], [], []
|
|
281
|
+
if const_expr(self._tensormap_d_ptr is not None):
|
|
282
|
+
tensormap_ptrs.append(self._tensormap_d_ptr)
|
|
283
|
+
shapes.append(cu_seqlens_m[batch_idx + 1])
|
|
284
|
+
orders.append(order_d)
|
|
285
|
+
tensormap_ptrs.extend(self._tensormap_epi_ptrs)
|
|
286
|
+
shapes.extend(epi_shapes)
|
|
287
|
+
orders.extend(epi_orders)
|
|
288
|
+
self.tensormap_manager.update_tensormap_shape(
|
|
289
|
+
tensormap_ptrs,
|
|
290
|
+
is_manager_warp=is_manager_warp,
|
|
291
|
+
shapes=shapes,
|
|
292
|
+
orders=orders,
|
|
293
|
+
tensormap_smem_ptr=None,
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
@cute.jit
|
|
297
|
+
def fence_tensormap_update_AB(self, is_manager_warp: bool | Boolean = True) -> None:
|
|
298
|
+
if const_expr(self.varlen_k):
|
|
299
|
+
if self._is_group_changed and is_manager_warp:
|
|
300
|
+
if const_expr(not self.gather_A):
|
|
301
|
+
self.tensormap_manager.fence_tensormap_update(self._tensormap_a_ptr)
|
|
302
|
+
self.tensormap_manager.fence_tensormap_update(self._tensormap_b_ptr)
|
|
303
|
+
|
|
304
|
+
@cute.jit
|
|
305
|
+
def fence_tensormap_update_epi(self, is_manager_warp: bool | Boolean = True) -> None:
|
|
306
|
+
if const_expr(self.varlen_m):
|
|
307
|
+
if self._is_group_changed and is_manager_warp:
|
|
308
|
+
if const_expr(self._tensormap_d_ptr is not None):
|
|
309
|
+
self.tensormap_manager.fence_tensormap_update(self._tensormap_d_ptr)
|
|
310
|
+
for tensormap_epi_ptr in self._tensormap_epi_ptrs:
|
|
311
|
+
if const_expr(tensormap_epi_ptr is not None):
|
|
312
|
+
self.tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
|
|
313
|
+
|
|
314
|
+
def get_tma_desc_a_ptr(self) -> Optional[cute.Pointer]:
|
|
315
|
+
tma_desc_a_ptr = None
|
|
316
|
+
if const_expr(self.varlen_k and self._tensormap_a_ptr is not None):
|
|
317
|
+
tma_desc_a_ptr = self.tensormap_manager.get_tensormap_ptr(
|
|
318
|
+
self._tensormap_a_ptr, cute.AddressSpace.generic
|
|
319
|
+
)
|
|
320
|
+
return tma_desc_a_ptr
|
|
321
|
+
|
|
322
|
+
def get_tma_desc_b_ptr(self) -> Optional[cute.Pointer]:
|
|
323
|
+
tma_desc_b_ptr = None
|
|
324
|
+
if const_expr(self.varlen_k):
|
|
325
|
+
tma_desc_b_ptr = self.tensormap_manager.get_tensormap_ptr(
|
|
326
|
+
self._tensormap_b_ptr, cute.AddressSpace.generic
|
|
327
|
+
)
|
|
328
|
+
return tma_desc_b_ptr
|
|
329
|
+
|
|
330
|
+
def get_tma_desc_d_ptr(self) -> Optional[cute.Pointer]:
|
|
331
|
+
tma_desc_d_ptr = None
|
|
332
|
+
if const_expr(self.varlen_m and self._tensormap_d_ptr is not None):
|
|
333
|
+
tma_desc_d_ptr = self.tensormap_manager.get_tensormap_ptr(
|
|
334
|
+
self._tensormap_d_ptr, cute.AddressSpace.generic
|
|
335
|
+
)
|
|
336
|
+
return tma_desc_d_ptr
|
|
337
|
+
|
|
338
|
+
def get_tma_desc_epi_ptrs(self) -> list[Optional[cute.Pointer]]:
|
|
339
|
+
tma_desc_epi_ptrs = [None] * len(self._tensormap_epi_ptrs)
|
|
340
|
+
if const_expr(self.varlen_m):
|
|
341
|
+
for i, tensormap_epi_ptr in enumerate(self._tensormap_epi_ptrs):
|
|
342
|
+
if const_expr(tensormap_epi_ptr is not None):
|
|
343
|
+
tma_desc_epi_ptrs[i] = self.tensormap_manager.get_tensormap_ptr(
|
|
344
|
+
tensormap_epi_ptr, cute.AddressSpace.generic
|
|
345
|
+
)
|
|
346
|
+
return tma_desc_epi_ptrs
|
|
347
|
+
|
|
348
|
+
def __extract_mlir_values__(self):
|
|
349
|
+
values, self._values_pos = [], []
|
|
350
|
+
for obj in [
|
|
351
|
+
self.params,
|
|
352
|
+
self.tensormap_manager,
|
|
353
|
+
self._tensormap_a_ptr,
|
|
354
|
+
self._tensormap_b_ptr,
|
|
355
|
+
self._tensormap_d_ptr,
|
|
356
|
+
self._tensormap_epi_ptrs,
|
|
357
|
+
self._len_m_static,
|
|
358
|
+
self._len_k_static,
|
|
359
|
+
self._last_batch_idx,
|
|
360
|
+
self._is_group_changed,
|
|
361
|
+
]:
|
|
362
|
+
obj_values = cutlass.extract_mlir_values(obj)
|
|
363
|
+
values += obj_values
|
|
364
|
+
self._values_pos.append(len(obj_values))
|
|
365
|
+
return values
|
|
366
|
+
|
|
367
|
+
def __new_from_mlir_values__(self, values):
|
|
368
|
+
obj_list = []
|
|
369
|
+
for obj, n_items in zip(
|
|
370
|
+
[
|
|
371
|
+
self.params,
|
|
372
|
+
self.tensormap_manager,
|
|
373
|
+
self._tensormap_a_ptr,
|
|
374
|
+
self._tensormap_b_ptr,
|
|
375
|
+
self._tensormap_d_ptr,
|
|
376
|
+
self._tensormap_epi_ptrs,
|
|
377
|
+
self._len_m_static,
|
|
378
|
+
self._len_k_static,
|
|
379
|
+
self._last_batch_idx,
|
|
380
|
+
self._is_group_changed,
|
|
381
|
+
],
|
|
382
|
+
self._values_pos,
|
|
383
|
+
):
|
|
384
|
+
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
|
385
|
+
values = values[n_items:]
|
|
386
|
+
return self.__class__(*(tuple(obj_list)), loc=self._loc)
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: quack-kernels
|
|
3
|
-
Version: 0.2.1
|
|
3
|
+
Version: 0.2.3
|
|
4
4
|
Requires-Python: >=3.10
|
|
5
5
|
License-File: LICENSE
|
|
6
|
-
Requires-Dist: nvidia-cutlass-dsl==4.
|
|
6
|
+
Requires-Dist: nvidia-cutlass-dsl==4.3.3
|
|
7
7
|
Requires-Dist: torch
|
|
8
|
+
Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
|
|
9
|
+
Requires-Dist: torch-c-dlpack-ext
|
|
8
10
|
Provides-Extra: dev
|
|
9
11
|
Requires-Dist: pre-commit; extra == "dev"
|
|
10
12
|
Requires-Dist: ruff; extra == "dev"
|