quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/utils.py
CHANGED

@@ -1,7 +1,8 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

 import math
-from
+from functools import partial
+from typing import Optional, Tuple, Union

 import cutlass
 import cutlass.cute as cute
@@ -9,70 +10,18 @@ import cutlass.cute as cute
 from cutlass import Float32, Int32, const_expr
 from cutlass.cutlass_dsl import T, dsl_user_op
 from cutlass._mlir.dialects import llvm, nvvm, vector
-from cutlass.cute.runtime import from_dlpack


-
-
-
-
-
-
-
-
-
-
-def transpose_view(a: cute.Tensor) -> cute.Tensor:
-    """Transpose the first two dimensions of a tensor on smem."""
-    shape = (a.shape[1], a.shape[0], *a.shape[2:])
-    order = (1, 0, *range(2, cute.rank(a)))
-    return cute.composition(a, cute.make_ordered_layout(shape, order=order))
-
-
-def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
-    return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
-
-
-@dsl_user_op
-def get_copy_atom(
-    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
-) -> cute.CopyAtom:
-    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
-    copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
-    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
-
-
-@dsl_user_op
-def copy(
-    src: cute.Tensor,
-    dst: cute.Tensor,
-    *,
-    pred: Optional[cute.Tensor] = None,
-    num_copy_elems: int = 1,
-    is_async: bool = False,
-    loc=None,
-    ip=None,
-    **kwargs,
-) -> None:
-    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
-    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
-
-
-def tiled_copy_2d(
-    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = True
-) -> cute.TiledCopy:
-    num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
-    copy_elems = num_copy_bits // dtype.width
-    copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
-    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
-    gmem_threads_per_row = major_mode_size // copy_elems
-    assert num_threads % gmem_threads_per_row == 0
-    thr_layout = cute.make_ordered_layout(
-        (num_threads // gmem_threads_per_row, gmem_threads_per_row),
-        order=(1, 0),
-    )
-    val_layout = cute.make_layout((1, copy_elems))
-    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+# cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default
+fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+sub_packed_f32x2 = partial(
+    cute.arch.calc_packed_f32x2_op,
+    src_c=None,
+    calc_func=nvvm.sub_packed_f32x2,
+    rnd=nvvm.RoundingModeKind.RN,
+)


 @dsl_user_op
@@ -91,11 +40,11 @@ def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:

 @dsl_user_op
 def set_block_rank(
-    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster:
-) ->
+    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
+) -> Int32:
     """Map the given smem pointer to the address at another CTA rank in the cluster."""
     smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
-    return
+    return Int32(
         llvm.inline_asm(
             T.i32(),
             [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
@@ -110,7 +59,7 @@ def set_block_rank(

 @dsl_user_op
 def store_shared_remote(
-    val: float | Float32 | cutlass.Int64,
+    val: float | Float32 | Int32 | cutlass.Int64,
     smem_ptr: cute.Pointer,
     mbar_ptr: cute.Pointer,
     peer_cta_rank_in_cluster: cute.typing.Int,
@@ -153,6 +102,21 @@ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=Non
     )


+@dsl_user_op
+def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Float32(a).ir_value(loc=loc, ip=ip)],
+            "sqrt.approx.f32 $0, $1;",
+            "=f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
 @dsl_user_op
 def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
     return Int32(
@@ -187,55 +151,6 @@ def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -
     )


-@cute.jit
-def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
-    assert t.element_type.width == 16
-    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
-    t_u32 = cute.recast_tensor(t, Int32)
-
-    quad_idx = cute.arch.lane_idx() % 4
-    lane_03 = quad_idx == 0 or quad_idx == 3
-    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
-    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
-    # upper_map = [0, 3, 1, 2]
-    # lower_map = [1, 2, 0, 3]
-    # upper_idx = upper_map[quad_idx]
-    # indexing isn't supported so we have to do arithmetic
-    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
-    lower_idx = upper_idx ^ 1
-
-    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
-    width = 4
-    mask = cute.arch.WARP_SIZE - width
-    clamp = cute.arch.WARP_SIZE - 1
-    mask_and_clamp = mask << 8 | clamp
-
-    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
-        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
-        upper0 = upper if lane_03 else lower
-        lower0 = lower if lane_03 else upper
-        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
-        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
-        t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
-        t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
-
-
-@cute.jit
-def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
-    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
-    tApA = cute.make_fragment(
-        cute.make_layout(
-            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
-            stride=(cute.size(tAcA, mode=[2]), 0, 1),
-        ),
-        cutlass.Boolean,
-    )
-    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
-        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
-            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
-    return tApA
-
-
 @cute.jit
 def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
     """Fill out-of-bounds values in shared memory tensor.
@@ -281,43 +196,8 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
     return res0, res1


-@dsl_user_op
-def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
-    flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
-    flat_stride = cute.flatten_to_tuple(tensor.stride)
-    assert len(flat_coord_i64) == len(flat_stride), (
-        "Coordinate and stride must have the same length"
-    )
-    offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
-    assert isinstance(tensor.iterator, cute.Pointer)
-    # HACK: we assume that applying the offset does not change the pointer alignment
-    new_ptr = cute.make_ptr(
-        tensor.element_type,
-        tensor.iterator.toint() + offset * tensor.element_type.width // 8,
-        tensor.memspace,
-        assumed_align=tensor.iterator.max_alignment,
-    )
-    return cute.make_tensor(new_ptr, tensor.layout)
-
-
-@dsl_user_op
-def coord_offset_i64(
-    idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
-) -> cute.Tensor:
-    offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
-    assert isinstance(tensor.iterator, cute.Pointer)
-    # HACK: we assume that applying the offset does not change the pointer alignment
-    new_ptr = cute.make_ptr(
-        tensor.element_type,
-        tensor.iterator.toint() + offset * tensor.element_type.width // 8,
-        tensor.memspace,
-        assumed_align=tensor.iterator.max_alignment,
-    )
-    return cute.make_tensor(new_ptr, tensor.layout)
-
-
 @cute.jit
-def warp_prefix_sum(val:
+def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32:
     if const_expr(lane is None):
         lane = cute.arch.lane_idx()
     for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
@@ -329,74 +209,6 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) ->
     return val


-def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
-    """
-    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
-    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
-    """
-    acc_layout_col_major = cute.make_layout(acc_layout.shape)
-    acc_layout_mn = cute.make_layout(
-        (
-            (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
-            (
-                acc_layout_col_major.shape[0][0],
-                *acc_layout_col_major.shape[0][2:],
-                acc_layout_col_major.shape[2],
-            ),  # MMA_N
-            *acc_layout_col_major.shape[3:],
-        ),
-        stride=(
-            (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
-            (
-                acc_layout_col_major.stride[0][0],
-                *acc_layout_col_major.stride[0][2:],
-                acc_layout_col_major.stride[2],
-            ),  # MMA_N
-            *acc_layout_col_major.stride[3:],
-        ),
-    )
-    return cute.composition(acc_layout, acc_layout_mn)
-
-
-def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
-    return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
-
-
-@dsl_user_op
-def sm90_get_smem_load_op(
-    layout_c: cutlass.utils.LayoutEnum,
-    elem_ty_c: Type[cutlass.Numeric],
-    *,
-    loc=None,
-    ip=None,
-) -> cute.CopyAtom:
-    """
-    Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
-
-    Parameters:
-    -----------
-    layout_c : LayoutEnum
-        The layout enum of the output tensor D.
-
-    elem_ty_c : Type[Numeric]
-        The element type for output tensor D.
-
-    Returns:
-    --------
-    Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
-    """
-
-    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
-        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
-    is_m_major = layout_c.is_m_major_c()
-    if elem_ty_c.width == 16:
-        return cute.make_copy_atom(
-            cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
-        )
-    else:
-        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
-
-
 @dsl_user_op
 def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
     return nvvm.atomicrmw(
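The 0.2.3 wheel replaces the RZ-default packed f32x2 helpers with functools.partial wrappers that pin round-to-nearest. A minimal plain-Python sketch of that pattern (toy stand-in function, no CUTLASS dependency), showing how the keyword default is fixed once rather than at every call site:

from functools import partial

# Toy stand-in for a packed f32x2 op that takes a rounding-mode keyword;
# the real cute.arch ops operate on packed register pairs on the GPU.
def fma_packed(a, b, c, *, rnd="rz"):
    return tuple(ai * bi + ci for ai, bi, ci in zip(a, b, c)), rnd

# Mirrors fma_packed_f32x2 above: callers now get round-to-nearest implicitly.
fma_packed_rn = partial(fma_packed, rnd="rn")

print(fma_packed_rn((1.0, 2.0), (3.0, 4.0), (0.5, 0.5)))  # ((3.5, 8.5), 'rn')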
quack/varlen_utils.py
CHANGED

@@ -3,9 +3,13 @@
 from typing import Optional
 from dataclasses import dataclass

+import cutlass
 import cutlass.cute as cute
+from cutlass import Int32, Boolean, const_expr
+from cutlass.utils import LayoutEnum

-from quack.cute_dsl_utils import ArgumentsBase
+from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
+from quack.tensormap_manager import TensorMapManagerSm90


 # Grouping arguments together that should be passed to __call__
@@ -15,3 +19,368 @@ class VarlenArguments(ArgumentsBase):
     mCuSeqlensK: Optional[cute.Tensor] = None
     mTensormaps: Optional[cute.Tensor] = None
     mAIdx: Optional[cute.Tensor] = None
+
+
+class VarlenManager:
+    bytes_per_tensormap = 128
+
+    @dataclass
+    class Params(ParamsBase):
+        cu_seqlens_m: Optional[cute.Tensor] = None
+        cu_seqlens_k: Optional[cute.Tensor] = None
+        tensormaps: Optional[cute.Tensor] = None
+        mAIdx: Optional[cute.Tensor] = None
+
+        @staticmethod
+        @cute.jit
+        def create(args: VarlenArguments, *, loc=None, ip=None) -> "VarlenManager.Params":
+            return VarlenManager.Params(
+                cu_seqlens_m=args.mCuSeqlensM,
+                cu_seqlens_k=args.mCuSeqlensK,
+                tensormaps=args.mTensormaps,
+                mAIdx=args.mAIdx,
+            )
+
+    def __init__(
+        self,
+        params: Params,
+        tensormap_manager: Optional[cutlass.utils.TensorMapManager],
+        tensormap_a_ptr: Optional[cute.Pointer],
+        tensormap_b_ptr: Optional[cute.Pointer],
+        tensormap_d_ptr: Optional[cute.Pointer],
+        tensormap_epi_ptrs: list[Optional[cute.Pointer]],
+        len_m_static: Int32,
+        len_k_static: Int32,
+        last_batch_idx: Int32 = Int32(-1),
+        is_group_changed: Boolean = Boolean(True),
+        *,
+        loc=None,
+        ip=None,
+    ):
+        self.params = params
+        self.tensormap_manager = tensormap_manager
+        self._tensormap_a_ptr = tensormap_a_ptr
+        self._tensormap_b_ptr = tensormap_b_ptr
+        self._tensormap_d_ptr = tensormap_d_ptr
+        self._tensormap_epi_ptrs = tensormap_epi_ptrs
+        self._len_m_static = len_m_static
+        self._len_k_static = len_k_static
+        self._last_batch_idx = last_batch_idx
+        self._is_group_changed = is_group_changed
+        self.varlen_m = const_expr(params.cu_seqlens_m is not None)
+        self.varlen_k = const_expr(params.cu_seqlens_k is not None)
+        self.gather_A = const_expr(params.mAIdx is not None)
+        self._loc = loc
+        self._ip = ip
+
+    @staticmethod
+    def to_underlying_arguments(args: VarlenArguments, *, loc=None, ip=None) -> Params:
+        assert not (args.mCuSeqlensM is not None and args.mCuSeqlensK is not None), (
+            "Only support either varlen_m or varlen_k"
+        )
+        return VarlenManager.Params.create(args, loc=loc, ip=ip)
+
+    @staticmethod
+    @cute.jit
+    def create(
+        params: Params,
+        has_D: bool,
+        num_epi_tensormaps: int,
+        len_m_static: Int32,
+        len_k_static: Int32,
+        pingpong: bool = False,
+        warp_idx: int | Int32 = 0,
+        *,
+        loc=None,
+        ip=None,
+    ) -> "VarlenManager":
+        tensormap_manager = None
+        tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
+        tensormap_epi_ptrs = [None] * num_epi_tensormaps
+        varlen_m = const_expr(params.cu_seqlens_m is not None)
+        varlen_k = const_expr(params.cu_seqlens_k is not None)
+        if const_expr(varlen_m or varlen_k):
+            tensormap_manager = TensorMapManagerSm90(
+                cutlass.utils.TensorMapUpdateMode.GMEM, VarlenManager.bytes_per_tensormap
+            )
+            # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
+            tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
+            if const_expr(varlen_m):
+                tensormap_d_idx = warp_idx // 4 if const_expr(pingpong) else 0
+                tensormap_epi_offset = tensormap_d_idx
+                if const_expr(has_D):
+                    tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
+                        params.tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
+                    )
+                    tensormap_epi_offset += 1 if not pingpong else 2
+                tensormap_epi_ptrs = [
+                    tensormap_manager.get_tensormap_ptr(
+                        params.tensormaps[
+                            tensormap_workspace_idx,
+                            tensormap_epi_offset + i * (1 if not pingpong else 2),
+                            None,
+                        ].iterator
+                    )
+                    for i in range(num_epi_tensormaps)
+                ]
+            else:
+                assert varlen_k
+                gather_A = const_expr(params.mAIdx is not None)
+                if const_expr(not gather_A):
+                    tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
+                        params.tensormaps[tensormap_workspace_idx, 0, None].iterator
+                    )
+                tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
+                    params.tensormaps[
+                        tensormap_workspace_idx, 1 if not gather_A else 0, None
+                    ].iterator
+                )
+        return VarlenManager(
+            params,
+            tensormap_manager,
+            tensormap_a_ptr,
+            tensormap_b_ptr,
+            tensormap_d_ptr,
+            tensormap_epi_ptrs,
+            len_m_static=len_m_static,
+            len_k_static=len_k_static,
+        )
+
+    def len_m(self, batch_idx: Int32) -> Int32:
+        if const_expr(self.varlen_m):
+            return self.params.cu_seqlens_m[batch_idx + 1] - self.params.cu_seqlens_m[batch_idx]
+        else:
+            return self._len_m_static
+
+    def len_k(self, batch_idx: Int32) -> Int32:
+        if const_expr(self.varlen_k):
+            return self.params.cu_seqlens_k[batch_idx + 1] - self.params.cu_seqlens_k[batch_idx]
+        else:
+            return self._len_k_static
+
+    def offset_batch_A(self, mA_mkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
+        params = self.params
+        if const_expr(self.varlen_m):
+            mA_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mA_mkl)
+        elif const_expr(self.varlen_k):
+            mA_mk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mA_mkl)
+        else:
+            mA_mk = mA_mkl[None, None, batch_idx]
+        return mA_mk
+
+    def offset_batch_AIdx(self, batch_idx: Int32) -> cute.Tensor:
+        params = self.params
+        if const_expr(self.varlen_m):
+            mAIdx_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx],), params.mAIdx)
+        elif const_expr(self.varlen_k):
+            mAIdx_mk = cute.domain_offset((params.cu_seqlens_k[batch_idx],), params.mAIdx)
+        else:
+            mAIdx_mk = params.mAIdx[None, batch_idx]
+        return mAIdx_mk
+
+    def offset_batch_B(self, mB_nkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
+        params = self.params
+        if const_expr(self.varlen_k):
+            mB_nk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mB_nkl)
+        else:
+            mB_nk = mB_nkl[None, None, batch_idx]
+        return mB_nk
+
+    def offset_batch_epi(self, mD_mnl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
+        params = self.params
+        if const_expr(self.varlen_m):
+            mD_mn = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mD_mnl)
+        else:
+            mD_mn = mD_mnl[None, None, batch_idx]
+        return mD_mn
+
+    def init_tensormap_AB(
+        self,
+        tma_atom_a: Optional[cute.CopyAtom],
+        tma_atom_b: cute.CopyAtom,
+        is_manager_warp: bool | Boolean = True,
+    ) -> None:
+        if const_expr(self.varlen_k):
+            if const_expr(not self.gather_A):
+                self.tensormap_manager.init_tensormap_from_atom(
+                    tma_atom_a, self._tensormap_a_ptr, is_manager_warp
+                )
+            self.tensormap_manager.init_tensormap_from_atom(
+                tma_atom_b, self._tensormap_b_ptr, is_manager_warp
+            )
+
+    def init_tensormap_epi(
+        self,
+        tma_atom_d: Optional[cute.CopyAtom],
+        tma_atoms_epi: list[cute.CopyAtom],
+        is_manager_warp: bool | Boolean = True,
+    ) -> None:
+        if const_expr(self.varlen_m):
+            if const_expr(self._tensormap_d_ptr is not None):
+                self.tensormap_manager.init_tensormap_from_atom(
+                    tma_atom_d, self._tensormap_d_ptr, is_manager_warp
+                )
+            for tma_atom, tensormap_epi_ptr in zip(tma_atoms_epi, self._tensormap_epi_ptrs):
+                self.tensormap_manager.init_tensormap_from_atom(
+                    tma_atom, tensormap_epi_ptr, is_manager_warp
+                )
+
+    def fence_tensormap_init(self) -> None:
+        self.tensormap_manager.fence_tensormap_initialization()
+
+    @cute.jit
+    def update_tensormap_AB(
+        self,
+        batch_idx: Int32,
+        a_layout: LayoutEnum,
+        b_layout: LayoutEnum,
+        is_manager_warp: bool | Boolean = True,
+    ) -> None:
+        if const_expr(self.varlen_k):
+            self._is_group_changed = Boolean(batch_idx != self._last_batch_idx)
+            self._last_batch_idx = batch_idx
+            if self._is_group_changed:
+                # construct tensor A/B based on real address, shape and stride information
+                cu_seqlens_k = self.params.cu_seqlens_k
+                tensormap_ptrs = [self._tensormap_b_ptr]
+                shapes = [cu_seqlens_k[batch_idx + 1]]
+                orders = [0 if const_expr(b_layout == LayoutEnum.ROW_MAJOR) else 1]
+                if const_expr(not self.gather_A):
+                    tensormap_ptrs.insert(0, self._tensormap_a_ptr)
+                    shapes.insert(0, cu_seqlens_k[batch_idx + 1])
+                    orders.insert(0, 0 if const_expr(a_layout == LayoutEnum.ROW_MAJOR) else 1)
+                self.tensormap_manager.update_tensormap_shape(
+                    tensormap_ptrs,
+                    is_manager_warp=is_manager_warp,
+                    shapes=shapes,
+                    orders=orders,
+                    tensormap_smem_ptr=None,
+                )
+
+    @cute.jit
+    def update_tensormap_epi(
+        self,
+        batch_idx: Int32,
+        d_layout: LayoutEnum,
+        epi_shapes: list[Int32],
+        epi_orders: list[int],
+        is_manager_warp: bool | Boolean = True,
+    ) -> None:
+        if const_expr(self.varlen_m):
+            self._is_group_changed = Boolean(batch_idx != self._last_batch_idx)
+            self._last_batch_idx = batch_idx
+            # Cute-DSL doesn't like this under if statement
+            order_d = (
+                (0 if const_expr(d_layout.is_m_major_c()) else 1) if d_layout is not None else None
+            )
+            if self._is_group_changed:
+                # construct tensor A/B based on real address, shape and stride information
+                cu_seqlens_m = self.params.cu_seqlens_m
+                # construct tensor D based on real address, shape and stride information
+                tensormap_ptrs, shapes, orders = [], [], []
+                if const_expr(self._tensormap_d_ptr is not None):
+                    tensormap_ptrs.append(self._tensormap_d_ptr)
+                    shapes.append(cu_seqlens_m[batch_idx + 1])
+                    orders.append(order_d)
+                tensormap_ptrs.extend(self._tensormap_epi_ptrs)
+                shapes.extend(epi_shapes)
+                orders.extend(epi_orders)
+                self.tensormap_manager.update_tensormap_shape(
+                    tensormap_ptrs,
+                    is_manager_warp=is_manager_warp,
+                    shapes=shapes,
+                    orders=orders,
+                    tensormap_smem_ptr=None,
+                )
+
+    @cute.jit
+    def fence_tensormap_update_AB(self, is_manager_warp: bool | Boolean = True) -> None:
+        if const_expr(self.varlen_k):
+            if self._is_group_changed and is_manager_warp:
+                if const_expr(not self.gather_A):
+                    self.tensormap_manager.fence_tensormap_update(self._tensormap_a_ptr)
+                self.tensormap_manager.fence_tensormap_update(self._tensormap_b_ptr)
+
+    @cute.jit
+    def fence_tensormap_update_epi(self, is_manager_warp: bool | Boolean = True) -> None:
+        if const_expr(self.varlen_m):
+            if self._is_group_changed and is_manager_warp:
+                if const_expr(self._tensormap_d_ptr is not None):
+                    self.tensormap_manager.fence_tensormap_update(self._tensormap_d_ptr)
+                for tensormap_epi_ptr in self._tensormap_epi_ptrs:
+                    if const_expr(tensormap_epi_ptr is not None):
+                        self.tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
+
+    def get_tma_desc_a_ptr(self) -> Optional[cute.Pointer]:
+        tma_desc_a_ptr = None
+        if const_expr(self.varlen_k and self._tensormap_a_ptr is not None):
+            tma_desc_a_ptr = self.tensormap_manager.get_tensormap_ptr(
+                self._tensormap_a_ptr, cute.AddressSpace.generic
+            )
+        return tma_desc_a_ptr
+
+    def get_tma_desc_b_ptr(self) -> Optional[cute.Pointer]:
+        tma_desc_b_ptr = None
+        if const_expr(self.varlen_k):
+            tma_desc_b_ptr = self.tensormap_manager.get_tensormap_ptr(
+                self._tensormap_b_ptr, cute.AddressSpace.generic
+            )
+        return tma_desc_b_ptr
+
+    def get_tma_desc_d_ptr(self) -> Optional[cute.Pointer]:
+        tma_desc_d_ptr = None
+        if const_expr(self.varlen_m and self._tensormap_d_ptr is not None):
+            tma_desc_d_ptr = self.tensormap_manager.get_tensormap_ptr(
+                self._tensormap_d_ptr, cute.AddressSpace.generic
+            )
+        return tma_desc_d_ptr
+
+    def get_tma_desc_epi_ptrs(self) -> list[Optional[cute.Pointer]]:
+        tma_desc_epi_ptrs = [None] * len(self._tensormap_epi_ptrs)
+        if const_expr(self.varlen_m):
+            for i, tensormap_epi_ptr in enumerate(self._tensormap_epi_ptrs):
+                if const_expr(tensormap_epi_ptr is not None):
+                    tma_desc_epi_ptrs[i] = self.tensormap_manager.get_tensormap_ptr(
+                        tensormap_epi_ptr, cute.AddressSpace.generic
+                    )
+        return tma_desc_epi_ptrs
+
+    def __extract_mlir_values__(self):
+        values, self._values_pos = [], []
+        for obj in [
+            self.params,
+            self.tensormap_manager,
+            self._tensormap_a_ptr,
+            self._tensormap_b_ptr,
+            self._tensormap_d_ptr,
+            self._tensormap_epi_ptrs,
+            self._len_m_static,
+            self._len_k_static,
+            self._last_batch_idx,
+            self._is_group_changed,
+        ]:
+            obj_values = cutlass.extract_mlir_values(obj)
+            values += obj_values
+            self._values_pos.append(len(obj_values))
+        return values
+
+    def __new_from_mlir_values__(self, values):
+        obj_list = []
+        for obj, n_items in zip(
+            [
+                self.params,
+                self.tensormap_manager,
+                self._tensormap_a_ptr,
+                self._tensormap_b_ptr,
+                self._tensormap_d_ptr,
+                self._tensormap_epi_ptrs,
+                self._len_m_static,
+                self._len_k_static,
+                self._last_batch_idx,
+                self._is_group_changed,
+            ],
+            self._values_pos,
+        ):
+            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+            values = values[n_items:]
+        return self.__class__(*(tuple(obj_list)), loc=self._loc)
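The new VarlenManager reads per-batch extents from cu_seqlens tensors, i.e. prefix sums of the per-batch sequence lengths, so len_m/len_k are just differences of adjacent entries. A small plain-Python illustration of that convention (illustrative only, no CUTLASS types):

# cu_seqlens holds cumulative sequence lengths, so batch i occupies
# rows [cu_seqlens[i], cu_seqlens[i + 1]) of the packed varlen tensor.
seqlens = [3, 5, 2]
cu_seqlens = [0]
for n in seqlens:
    cu_seqlens.append(cu_seqlens[-1] + n)   # [0, 3, 8, 10]

def len_m(batch_idx):
    # same arithmetic as VarlenManager.len_m in the varlen_m case
    return cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]

assert [len_m(i) for i in range(len(seqlens))] == seqlens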