quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/utils.py CHANGED
@@ -1,7 +1,8 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
  import math
- from typing import Optional, Tuple, Type, Union
+ from functools import partial
+ from typing import Optional, Tuple, Union
 
  import cutlass
  import cutlass.cute as cute
@@ -9,70 +10,18 @@ import cutlass.cute as cute
  from cutlass import Float32, Int32, const_expr
  from cutlass.cutlass_dsl import T, dsl_user_op
  from cutlass._mlir.dialects import llvm, nvvm, vector
- from cutlass.cute.runtime import from_dlpack
 
 
- def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
-     return (
-         from_dlpack(x, assumed_align=alignment)
-         .mark_layout_dynamic(leading_dim=leading_dim)
-         .mark_compact_shape_dynamic(
-             mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility
-         )
-     )
-
-
- def transpose_view(a: cute.Tensor) -> cute.Tensor:
-     """Transpose the first two dimensions of a tensor on smem."""
-     shape = (a.shape[1], a.shape[0], *a.shape[2:])
-     order = (1, 0, *range(2, cute.rank(a)))
-     return cute.composition(a, cute.make_ordered_layout(shape, order=order))
-
-
- def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
-     return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
-
-
- @dsl_user_op
- def get_copy_atom(
-     dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
- ) -> cute.CopyAtom:
-     num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
-     copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
-     return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
-
-
- @dsl_user_op
- def copy(
-     src: cute.Tensor,
-     dst: cute.Tensor,
-     *,
-     pred: Optional[cute.Tensor] = None,
-     num_copy_elems: int = 1,
-     is_async: bool = False,
-     loc=None,
-     ip=None,
-     **kwargs,
- ) -> None:
-     copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
-     cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
-
-
- def tiled_copy_2d(
-     dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = True
- ) -> cute.TiledCopy:
-     num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
-     copy_elems = num_copy_bits // dtype.width
-     copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
-     copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
-     gmem_threads_per_row = major_mode_size // copy_elems
-     assert num_threads % gmem_threads_per_row == 0
-     thr_layout = cute.make_ordered_layout(
-         (num_threads // gmem_threads_per_row, gmem_threads_per_row),
-         order=(1, 0),
-     )
-     val_layout = cute.make_layout((1, copy_elems))
-     return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+ # cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default
+ fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+ mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+ add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+ sub_packed_f32x2 = partial(
+     cute.arch.calc_packed_f32x2_op,
+     src_c=None,
+     calc_func=nvvm.sub_packed_f32x2,
+     rnd=nvvm.RoundingModeKind.RN,
+ )
 
 
  @dsl_user_op
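The added helpers above only re-bind the rounding-mode keyword of the existing `cute.arch` packed-f32x2 intrinsics from the RZ default to round-to-nearest via `functools.partial`. A minimal standalone sketch of that binding pattern, using a hypothetical stand-in function rather than the real intrinsics:

```python
from functools import partial


# Stand-in for an intrinsic whose default rounding mode we want to override
# (illustrative only; not part of quack or cutlass).
def packed_op(a, b, rnd="rz"):
    return f"op({a}, {b}, rnd={rnd})"


# Bind the keyword once so every call site gets round-to-nearest by default,
# mirroring how utils.py wraps cute.arch.{fma,mul,add}_packed_f32x2 with rnd=RN.
packed_op_rn = partial(packed_op, rnd="rn")

print(packed_op(1.0, 2.0))     # op(1.0, 2.0, rnd=rz)
print(packed_op_rn(1.0, 2.0))  # op(1.0, 2.0, rnd=rn)
```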
@@ -91,11 +40,11 @@ def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
 
  @dsl_user_op
  def set_block_rank(
-     smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
- ) -> cutlass.Int32:
+     smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
+ ) -> Int32:
      """Map the given smem pointer to the address at another CTA rank in the cluster."""
      smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
-     return cutlass.Int32(
+     return Int32(
          llvm.inline_asm(
              T.i32(),
              [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
@@ -110,7 +59,7 @@ def set_block_rank(
 
  @dsl_user_op
  def store_shared_remote(
-     val: float | Float32 | cutlass.Int64,
+     val: float | Float32 | Int32 | cutlass.Int64,
      smem_ptr: cute.Pointer,
      mbar_ptr: cute.Pointer,
      peer_cta_rank_in_cluster: cute.typing.Int,
@@ -153,6 +102,21 @@ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=Non
      )
 
 
+ @dsl_user_op
+ def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "sqrt.approx.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
  @dsl_user_op
  def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
      return Int32(
@@ -187,55 +151,6 @@ def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -
      )
 
 
- @cute.jit
- def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
-     assert t.element_type.width == 16
-     assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
-     t_u32 = cute.recast_tensor(t, Int32)
-
-     quad_idx = cute.arch.lane_idx() % 4
-     lane_03 = quad_idx == 0 or quad_idx == 3
-     selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
-     selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
-     # upper_map = [0, 3, 1, 2]
-     # lower_map = [1, 2, 0, 3]
-     # upper_idx = upper_map[quad_idx]
-     # indexing isn't supported so we have to do arithmetic
-     upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
-     lower_idx = upper_idx ^ 1
-
-     # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
-     width = 4
-     mask = cute.arch.WARP_SIZE - width
-     clamp = cute.arch.WARP_SIZE - 1
-     mask_and_clamp = mask << 8 | clamp
-
-     for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
-         upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
-         upper0 = upper if lane_03 else lower
-         lower0 = lower if lane_03 else upper
-         upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
-         lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
-         t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
-         t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
-
-
- @cute.jit
- def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
-     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
-     tApA = cute.make_fragment(
-         cute.make_layout(
-             (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
-             stride=(cute.size(tAcA, mode=[2]), 0, 1),
-         ),
-         cutlass.Boolean,
-     )
-     for rest_v in cutlass.range_constexpr(tApA.shape[0]):
-         for rest_k in cutlass.range_constexpr(tApA.shape[2]):
-             tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
-     return tApA
-
-
  @cute.jit
  def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
      """Fill out-of-bounds values in shared memory tensor.
@@ -281,43 +196,8 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
      return res0, res1
 
 
- @dsl_user_op
- def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
-     flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
-     flat_stride = cute.flatten_to_tuple(tensor.stride)
-     assert len(flat_coord_i64) == len(flat_stride), (
-         "Coordinate and stride must have the same length"
-     )
-     offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
-     assert isinstance(tensor.iterator, cute.Pointer)
-     # HACK: we assume that applying the offset does not change the pointer alignment
-     new_ptr = cute.make_ptr(
-         tensor.element_type,
-         tensor.iterator.toint() + offset * tensor.element_type.width // 8,
-         tensor.memspace,
-         assumed_align=tensor.iterator.max_alignment,
-     )
-     return cute.make_tensor(new_ptr, tensor.layout)
-
-
- @dsl_user_op
- def coord_offset_i64(
-     idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
- ) -> cute.Tensor:
-     offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
-     assert isinstance(tensor.iterator, cute.Pointer)
-     # HACK: we assume that applying the offset does not change the pointer alignment
-     new_ptr = cute.make_ptr(
-         tensor.element_type,
-         tensor.iterator.toint() + offset * tensor.element_type.width // 8,
-         tensor.memspace,
-         assumed_align=tensor.iterator.max_alignment,
-     )
-     return cute.make_tensor(new_ptr, tensor.layout)
-
-
  @cute.jit
- def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
+ def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32:
      if const_expr(lane is None):
          lane = cute.arch.lane_idx()
      for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
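The loop in `warp_prefix_sum` (context above) runs one shuffle step per power of two up to the warp width, which is the standard Hillis-Steele inclusive scan. A CPU-only reference of what that computes, assuming a 32-lane warp and no cute/cutlass dependency:

```python
import math

WARP_SIZE = 32  # assumption: matches cute.arch.WARP_SIZE


def warp_prefix_sum_reference(vals: list[int]) -> list[int]:
    """Hillis-Steele inclusive scan: one step per power of two."""
    assert len(vals) == WARP_SIZE
    out = list(vals)
    for i in range(int(math.log2(WARP_SIZE))):
        offset = 1 << i
        # Each "lane" adds the value held `offset` lanes below it, like a
        # shuffle-up guarded by the lane index.
        out = [
            out[lane] + (out[lane - offset] if lane >= offset else 0)
            for lane in range(WARP_SIZE)
        ]
    return out


assert warp_prefix_sum_reference([1] * WARP_SIZE) == list(range(1, WARP_SIZE + 1))
```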
@@ -329,74 +209,6 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) ->
      return val
 
 
- def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
-     """
-     For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
-     For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
-     """
-     acc_layout_col_major = cute.make_layout(acc_layout.shape)
-     acc_layout_mn = cute.make_layout(
-         (
-             (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
-             (
-                 acc_layout_col_major.shape[0][0],
-                 *acc_layout_col_major.shape[0][2:],
-                 acc_layout_col_major.shape[2],
-             ),  # MMA_N
-             *acc_layout_col_major.shape[3:],
-         ),
-         stride=(
-             (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
-             (
-                 acc_layout_col_major.stride[0][0],
-                 *acc_layout_col_major.stride[0][2:],
-                 acc_layout_col_major.stride[2],
-             ),  # MMA_N
-             *acc_layout_col_major.stride[3:],
-         ),
-     )
-     return cute.composition(acc_layout, acc_layout_mn)
-
-
- def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
-     return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
-
-
- @dsl_user_op
- def sm90_get_smem_load_op(
-     layout_c: cutlass.utils.LayoutEnum,
-     elem_ty_c: Type[cutlass.Numeric],
-     *,
-     loc=None,
-     ip=None,
- ) -> cute.CopyAtom:
-     """
-     Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
-
-     Parameters:
-     -----------
-     layout_c : LayoutEnum
-         The layout enum of the output tensor D.
-
-     elem_ty_c : Type[Numeric]
-         The element type for output tensor D.
-
-     Returns:
-     --------
-     Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
-     """
-
-     if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
-         raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
-     is_m_major = layout_c.is_m_major_c()
-     if elem_ty_c.width == 16:
-         return cute.make_copy_atom(
-             cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
-         )
-     else:
-         return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
-
-
  @dsl_user_op
  def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
      return nvvm.atomicrmw(
quack/varlen_utils.py CHANGED
@@ -3,9 +3,13 @@
  from typing import Optional
  from dataclasses import dataclass
 
+ import cutlass
  import cutlass.cute as cute
+ from cutlass import Int32, Boolean, const_expr
+ from cutlass.utils import LayoutEnum
 
- from quack.cute_dsl_utils import ArgumentsBase
+ from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
+ from quack.tensormap_manager import TensorMapManagerSm90
 
 
  # Grouping arguments together that should be passed to __call__
@@ -15,3 +19,368 @@ class VarlenArguments(ArgumentsBase):
      mCuSeqlensK: Optional[cute.Tensor] = None
      mTensormaps: Optional[cute.Tensor] = None
      mAIdx: Optional[cute.Tensor] = None
+
+
+ class VarlenManager:
+     bytes_per_tensormap = 128
+
+     @dataclass
+     class Params(ParamsBase):
+         cu_seqlens_m: Optional[cute.Tensor] = None
+         cu_seqlens_k: Optional[cute.Tensor] = None
+         tensormaps: Optional[cute.Tensor] = None
+         mAIdx: Optional[cute.Tensor] = None
+
+         @staticmethod
+         @cute.jit
+         def create(args: VarlenArguments, *, loc=None, ip=None) -> "VarlenManager.Params":
+             return VarlenManager.Params(
+                 cu_seqlens_m=args.mCuSeqlensM,
+                 cu_seqlens_k=args.mCuSeqlensK,
+                 tensormaps=args.mTensormaps,
+                 mAIdx=args.mAIdx,
+             )
+
+     def __init__(
+         self,
+         params: Params,
+         tensormap_manager: Optional[cutlass.utils.TensorMapManager],
+         tensormap_a_ptr: Optional[cute.Pointer],
+         tensormap_b_ptr: Optional[cute.Pointer],
+         tensormap_d_ptr: Optional[cute.Pointer],
+         tensormap_epi_ptrs: list[Optional[cute.Pointer]],
+         len_m_static: Int32,
+         len_k_static: Int32,
+         last_batch_idx: Int32 = Int32(-1),
+         is_group_changed: Boolean = Boolean(True),
+         *,
+         loc=None,
+         ip=None,
+     ):
+         self.params = params
+         self.tensormap_manager = tensormap_manager
+         self._tensormap_a_ptr = tensormap_a_ptr
+         self._tensormap_b_ptr = tensormap_b_ptr
+         self._tensormap_d_ptr = tensormap_d_ptr
+         self._tensormap_epi_ptrs = tensormap_epi_ptrs
+         self._len_m_static = len_m_static
+         self._len_k_static = len_k_static
+         self._last_batch_idx = last_batch_idx
+         self._is_group_changed = is_group_changed
+         self.varlen_m = const_expr(params.cu_seqlens_m is not None)
+         self.varlen_k = const_expr(params.cu_seqlens_k is not None)
+         self.gather_A = const_expr(params.mAIdx is not None)
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: VarlenArguments, *, loc=None, ip=None) -> Params:
+         assert not (args.mCuSeqlensM is not None and args.mCuSeqlensK is not None), (
+             "Only support either varlen_m or varlen_k"
+         )
+         return VarlenManager.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(
+         params: Params,
+         has_D: bool,
+         num_epi_tensormaps: int,
+         len_m_static: Int32,
+         len_k_static: Int32,
+         pingpong: bool = False,
+         warp_idx: int | Int32 = 0,
+         *,
+         loc=None,
+         ip=None,
+     ) -> "VarlenManager":
+         tensormap_manager = None
+         tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
+         tensormap_epi_ptrs = [None] * num_epi_tensormaps
+         varlen_m = const_expr(params.cu_seqlens_m is not None)
+         varlen_k = const_expr(params.cu_seqlens_k is not None)
+         if const_expr(varlen_m or varlen_k):
+             tensormap_manager = TensorMapManagerSm90(
+                 cutlass.utils.TensorMapUpdateMode.GMEM, VarlenManager.bytes_per_tensormap
+             )
+             # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
+             tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
+             if const_expr(varlen_m):
+                 tensormap_d_idx = warp_idx // 4 if const_expr(pingpong) else 0
+                 tensormap_epi_offset = tensormap_d_idx
+                 if const_expr(has_D):
+                     tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
+                         params.tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
+                     )
+                     tensormap_epi_offset += 1 if not pingpong else 2
+                 tensormap_epi_ptrs = [
+                     tensormap_manager.get_tensormap_ptr(
+                         params.tensormaps[
+                             tensormap_workspace_idx,
+                             tensormap_epi_offset + i * (1 if not pingpong else 2),
+                             None,
+                         ].iterator
+                     )
+                     for i in range(num_epi_tensormaps)
+                 ]
+             else:
+                 assert varlen_k
+                 gather_A = const_expr(params.mAIdx is not None)
+                 if const_expr(not gather_A):
+                     tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
+                         params.tensormaps[tensormap_workspace_idx, 0, None].iterator
+                     )
+                 tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
+                     params.tensormaps[
+                         tensormap_workspace_idx, 1 if not gather_A else 0, None
+                     ].iterator
+                 )
+         return VarlenManager(
+             params,
+             tensormap_manager,
+             tensormap_a_ptr,
+             tensormap_b_ptr,
+             tensormap_d_ptr,
+             tensormap_epi_ptrs,
+             len_m_static=len_m_static,
+             len_k_static=len_k_static,
+         )
+
+     def len_m(self, batch_idx: Int32) -> Int32:
+         if const_expr(self.varlen_m):
+             return self.params.cu_seqlens_m[batch_idx + 1] - self.params.cu_seqlens_m[batch_idx]
+         else:
+             return self._len_m_static
+
+     def len_k(self, batch_idx: Int32) -> Int32:
+         if const_expr(self.varlen_k):
+             return self.params.cu_seqlens_k[batch_idx + 1] - self.params.cu_seqlens_k[batch_idx]
+         else:
+             return self._len_k_static
+
+     def offset_batch_A(self, mA_mkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
+         params = self.params
+         if const_expr(self.varlen_m):
+             mA_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mA_mkl)
+         elif const_expr(self.varlen_k):
+             mA_mk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mA_mkl)
+         else:
+             mA_mk = mA_mkl[None, None, batch_idx]
+         return mA_mk
+
+     def offset_batch_AIdx(self, batch_idx: Int32) -> cute.Tensor:
+         params = self.params
+         if const_expr(self.varlen_m):
+             mAIdx_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx],), params.mAIdx)
+         elif const_expr(self.varlen_k):
+             mAIdx_mk = cute.domain_offset((params.cu_seqlens_k[batch_idx],), params.mAIdx)
+         else:
+             mAIdx_mk = params.mAIdx[None, batch_idx]
+         return mAIdx_mk
+
+     def offset_batch_B(self, mB_nkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
+         params = self.params
+         if const_expr(self.varlen_k):
+             mB_nk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mB_nkl)
+         else:
+             mB_nk = mB_nkl[None, None, batch_idx]
+         return mB_nk
+
+     def offset_batch_epi(self, mD_mnl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
+         params = self.params
+         if const_expr(self.varlen_m):
+             mD_mn = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mD_mnl)
+         else:
+             mD_mn = mD_mnl[None, None, batch_idx]
+         return mD_mn
+
+     def init_tensormap_AB(
+         self,
+         tma_atom_a: Optional[cute.CopyAtom],
+         tma_atom_b: cute.CopyAtom,
+         is_manager_warp: bool | Boolean = True,
+     ) -> None:
+         if const_expr(self.varlen_k):
+             if const_expr(not self.gather_A):
+                 self.tensormap_manager.init_tensormap_from_atom(
+                     tma_atom_a, self._tensormap_a_ptr, is_manager_warp
+                 )
+             self.tensormap_manager.init_tensormap_from_atom(
+                 tma_atom_b, self._tensormap_b_ptr, is_manager_warp
+             )
+
+     def init_tensormap_epi(
+         self,
+         tma_atom_d: Optional[cute.CopyAtom],
+         tma_atoms_epi: list[cute.CopyAtom],
+         is_manager_warp: bool | Boolean = True,
+     ) -> None:
+         if const_expr(self.varlen_m):
+             if const_expr(self._tensormap_d_ptr is not None):
+                 self.tensormap_manager.init_tensormap_from_atom(
+                     tma_atom_d, self._tensormap_d_ptr, is_manager_warp
+                 )
+             for tma_atom, tensormap_epi_ptr in zip(tma_atoms_epi, self._tensormap_epi_ptrs):
+                 self.tensormap_manager.init_tensormap_from_atom(
+                     tma_atom, tensormap_epi_ptr, is_manager_warp
+                 )
+
+     def fence_tensormap_init(self) -> None:
+         self.tensormap_manager.fence_tensormap_initialization()
+
+     @cute.jit
+     def update_tensormap_AB(
+         self,
+         batch_idx: Int32,
+         a_layout: LayoutEnum,
+         b_layout: LayoutEnum,
+         is_manager_warp: bool | Boolean = True,
+     ) -> None:
+         if const_expr(self.varlen_k):
+             self._is_group_changed = Boolean(batch_idx != self._last_batch_idx)
+             self._last_batch_idx = batch_idx
+             if self._is_group_changed:
+                 # construct tensor A/B based on real address, shape and stride information
+                 cu_seqlens_k = self.params.cu_seqlens_k
+                 tensormap_ptrs = [self._tensormap_b_ptr]
+                 shapes = [cu_seqlens_k[batch_idx + 1]]
+                 orders = [0 if const_expr(b_layout == LayoutEnum.ROW_MAJOR) else 1]
+                 if const_expr(not self.gather_A):
+                     tensormap_ptrs.insert(0, self._tensormap_a_ptr)
+                     shapes.insert(0, cu_seqlens_k[batch_idx + 1])
+                     orders.insert(0, 0 if const_expr(a_layout == LayoutEnum.ROW_MAJOR) else 1)
+                 self.tensormap_manager.update_tensormap_shape(
+                     tensormap_ptrs,
+                     is_manager_warp=is_manager_warp,
+                     shapes=shapes,
+                     orders=orders,
+                     tensormap_smem_ptr=None,
+                 )
+
+     @cute.jit
+     def update_tensormap_epi(
+         self,
+         batch_idx: Int32,
+         d_layout: LayoutEnum,
+         epi_shapes: list[Int32],
+         epi_orders: list[int],
+         is_manager_warp: bool | Boolean = True,
+     ) -> None:
+         if const_expr(self.varlen_m):
+             self._is_group_changed = Boolean(batch_idx != self._last_batch_idx)
+             self._last_batch_idx = batch_idx
+             # Cute-DSL doesn't like this under if statement
+             order_d = (
+                 (0 if const_expr(d_layout.is_m_major_c()) else 1) if d_layout is not None else None
+             )
+             if self._is_group_changed:
+                 # construct tensor A/B based on real address, shape and stride information
+                 cu_seqlens_m = self.params.cu_seqlens_m
+                 # construct tensor D based on real address, shape and stride information
+                 tensormap_ptrs, shapes, orders = [], [], []
+                 if const_expr(self._tensormap_d_ptr is not None):
+                     tensormap_ptrs.append(self._tensormap_d_ptr)
+                     shapes.append(cu_seqlens_m[batch_idx + 1])
+                     orders.append(order_d)
+                 tensormap_ptrs.extend(self._tensormap_epi_ptrs)
+                 shapes.extend(epi_shapes)
+                 orders.extend(epi_orders)
+                 self.tensormap_manager.update_tensormap_shape(
+                     tensormap_ptrs,
+                     is_manager_warp=is_manager_warp,
+                     shapes=shapes,
+                     orders=orders,
+                     tensormap_smem_ptr=None,
+                 )
+
+     @cute.jit
+     def fence_tensormap_update_AB(self, is_manager_warp: bool | Boolean = True) -> None:
+         if const_expr(self.varlen_k):
+             if self._is_group_changed and is_manager_warp:
+                 if const_expr(not self.gather_A):
+                     self.tensormap_manager.fence_tensormap_update(self._tensormap_a_ptr)
+                 self.tensormap_manager.fence_tensormap_update(self._tensormap_b_ptr)
+
+     @cute.jit
+     def fence_tensormap_update_epi(self, is_manager_warp: bool | Boolean = True) -> None:
+         if const_expr(self.varlen_m):
+             if self._is_group_changed and is_manager_warp:
+                 if const_expr(self._tensormap_d_ptr is not None):
+                     self.tensormap_manager.fence_tensormap_update(self._tensormap_d_ptr)
+                 for tensormap_epi_ptr in self._tensormap_epi_ptrs:
+                     if const_expr(tensormap_epi_ptr is not None):
+                         self.tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
+
+     def get_tma_desc_a_ptr(self) -> Optional[cute.Pointer]:
+         tma_desc_a_ptr = None
+         if const_expr(self.varlen_k and self._tensormap_a_ptr is not None):
+             tma_desc_a_ptr = self.tensormap_manager.get_tensormap_ptr(
+                 self._tensormap_a_ptr, cute.AddressSpace.generic
+             )
+         return tma_desc_a_ptr
+
+     def get_tma_desc_b_ptr(self) -> Optional[cute.Pointer]:
+         tma_desc_b_ptr = None
+         if const_expr(self.varlen_k):
+             tma_desc_b_ptr = self.tensormap_manager.get_tensormap_ptr(
+                 self._tensormap_b_ptr, cute.AddressSpace.generic
+             )
+         return tma_desc_b_ptr
+
+     def get_tma_desc_d_ptr(self) -> Optional[cute.Pointer]:
+         tma_desc_d_ptr = None
+         if const_expr(self.varlen_m and self._tensormap_d_ptr is not None):
+             tma_desc_d_ptr = self.tensormap_manager.get_tensormap_ptr(
+                 self._tensormap_d_ptr, cute.AddressSpace.generic
+             )
+         return tma_desc_d_ptr
+
+     def get_tma_desc_epi_ptrs(self) -> list[Optional[cute.Pointer]]:
+         tma_desc_epi_ptrs = [None] * len(self._tensormap_epi_ptrs)
+         if const_expr(self.varlen_m):
+             for i, tensormap_epi_ptr in enumerate(self._tensormap_epi_ptrs):
+                 if const_expr(tensormap_epi_ptr is not None):
+                     tma_desc_epi_ptrs[i] = self.tensormap_manager.get_tensormap_ptr(
+                         tensormap_epi_ptr, cute.AddressSpace.generic
+                     )
+         return tma_desc_epi_ptrs
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [
+             self.params,
+             self.tensormap_manager,
+             self._tensormap_a_ptr,
+             self._tensormap_b_ptr,
+             self._tensormap_d_ptr,
+             self._tensormap_epi_ptrs,
+             self._len_m_static,
+             self._len_k_static,
+             self._last_batch_idx,
+             self._is_group_changed,
+         ]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip(
+             [
+                 self.params,
+                 self.tensormap_manager,
+                 self._tensormap_a_ptr,
+                 self._tensormap_b_ptr,
+                 self._tensormap_d_ptr,
+                 self._tensormap_epi_ptrs,
+                 self._len_m_static,
+                 self._len_k_static,
+                 self._last_batch_idx,
+                 self._is_group_changed,
+             ],
+             self._values_pos,
+         ):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return self.__class__(*(tuple(obj_list)), loc=self._loc)
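The new `VarlenManager` derives per-batch extents from cumulative-sequence-length tensors: `len_m(b)` is `cu_seqlens_m[b + 1] - cu_seqlens_m[b]`, and the `offset_batch_*` helpers shift the tensor origin by `cu_seqlens_*[b]`. A minimal host-side sketch of that convention (plain Python, independent of the CuTe DSL; the helper names here are illustrative):

```python
from itertools import accumulate


def make_cu_seqlens(seqlens: list[int]) -> list[int]:
    """cu_seqlens[b] is the start row of batch b; cu_seqlens[-1] is the total length."""
    return [0] + list(accumulate(seqlens))


def batch_extent(cu_seqlens: list[int], batch_idx: int) -> tuple[int, int]:
    """(row offset, row count) for one batch, the quantities len_m/offset_batch_A consume."""
    start = cu_seqlens[batch_idx]
    length = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
    return start, length


cu = make_cu_seqlens([3, 5, 2])   # -> [0, 3, 8, 10]
assert batch_extent(cu, 1) == (3, 5)
```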