quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff shows the changes between two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/utils.py CHANGED
@@ -1,25 +1,27 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

  import math
- from typing import Optional, Tuple, Type, Union
+ from functools import partial
+ from typing import Optional, Tuple, Union

  import cutlass
  import cutlass.cute as cute

- from cutlass import Float32, Int32
+ from cutlass import Float32, Int32, const_expr
  from cutlass.cutlass_dsl import T, dsl_user_op
  from cutlass._mlir.dialects import llvm, nvvm, vector
- from cutlass.cute.runtime import from_dlpack


- def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
-     return (
-         from_dlpack(x, assumed_align=alignment)
-         .mark_layout_dynamic(leading_dim=leading_dim)
-         .mark_compact_shape_dynamic(
-             mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility
-         )
-     )
+ # cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default
+ fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+ mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+ add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+ sub_packed_f32x2 = partial(
+     cute.arch.calc_packed_f32x2_op,
+     src_c=None,
+     calc_func=nvvm.sub_packed_f32x2,
+     rnd=nvvm.RoundingModeKind.RN,
+ )


  @dsl_user_op
@@ -29,7 +31,7 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut

  @cute.jit
  def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
-     if cutlass.const_expr(isinstance(x, cute.Pointer)):
+     if const_expr(isinstance(x, cute.Pointer)):
          return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
      else:
          assert isinstance(x, Float32)
@@ -38,11 +40,11 @@ def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:

  @dsl_user_op
  def set_block_rank(
-     smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
- ) -> cutlass.Int32:
+     smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
+ ) -> Int32:
      """Map the given smem pointer to the address at another CTA rank in the cluster."""
      smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
-     return cutlass.Int32(
+     return Int32(
          llvm.inline_asm(
              T.i32(),
              [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
@@ -57,7 +59,7 @@ def set_block_rank(

  @dsl_user_op
  def store_shared_remote(
-     val: float | Float32 | cutlass.Int64,
+     val: float | Float32 | Int32 | cutlass.Int64,
      smem_ptr: cute.Pointer,
      mbar_ptr: cute.Pointer,
      peer_cta_rank_in_cluster: cute.typing.Int,
@@ -71,7 +73,7 @@ def store_shared_remote(
      remote_mbar_ptr_i32 = set_block_rank(
          mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
      ).ir_value()
-     if cutlass.const_expr(isinstance(val, float)):
+     if const_expr(isinstance(val, float)):
          val = Float32(val)
      assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
      suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
@@ -100,6 +102,21 @@ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=Non
      )


+ @dsl_user_op
+ def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "sqrt.approx.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
  @dsl_user_op
  def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
      return Int32(
@@ -134,55 +151,6 @@ def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -
      )


- @cute.jit
- def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
-     assert t.element_type.width == 16
-     assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
-     t_u32 = cute.recast_tensor(t, Int32)
-
-     quad_idx = cute.arch.lane_idx() % 4
-     lane_03 = quad_idx == 0 or quad_idx == 3
-     selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
-     selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
-     # upper_map = [0, 3, 1, 2]
-     # lower_map = [1, 2, 0, 3]
-     # upper_idx = upper_map[quad_idx]
-     # indexing isn't supported so we have to do arithmetic
-     upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
-     lower_idx = upper_idx ^ 1
-
-     # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
-     width = 4
-     mask = cute.arch.WARP_SIZE - width
-     clamp = cute.arch.WARP_SIZE - 1
-     mask_and_clamp = mask << 8 | clamp
-
-     for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
-         upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
-         upper0 = upper if lane_03 else lower
-         lower0 = lower if lane_03 else upper
-         upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
-         lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
-         t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
-         t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
-
-
- @cute.jit
- def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
-     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
-     tApA = cute.make_fragment(
-         cute.make_layout(
-             (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
-             stride=(cute.size(tAcA, mode=[2]), 0, 1),
-         ),
-         cutlass.Boolean,
-     )
-     for rest_v in cutlass.range_constexpr(tApA.shape[0]):
-         for rest_k in cutlass.range_constexpr(tApA.shape[2]):
-             tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
-     return tApA
-
-
  @cute.jit
  def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
      """Fill out-of-bounds values in shared memory tensor.
@@ -196,7 +164,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu
      tXrX_fill.fill(fill_value)
      for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
          for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
-             if cutlass.const_expr(tXpX is not None):
+             if const_expr(tXpX is not None):
                  if not tXpX[rest_v, 0, rest_k]:
                      cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
              else:
@@ -228,44 +196,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
      return res0, res1


- @dsl_user_op
- def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
-     flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
-     flat_stride = cute.flatten_to_tuple(tensor.stride)
-     assert len(flat_coord_i64) == len(
-         flat_stride
-     ), "Coordinate and stride must have the same length"
-     offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
-     assert isinstance(tensor.iterator, cute.Pointer)
-     # HACK: we assume that applying the offset does not change the pointer alignment
-     new_ptr = cute.make_ptr(
-         tensor.element_type,
-         tensor.iterator.toint() + offset * tensor.element_type.width // 8,
-         tensor.memspace,
-         assumed_align=tensor.iterator.max_alignment,
-     )
-     return cute.make_tensor(new_ptr, tensor.layout)
-
-
- @dsl_user_op
- def coord_offset_i64(
-     idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
- ) -> cute.Tensor:
-     offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
-     assert isinstance(tensor.iterator, cute.Pointer)
-     # HACK: we assume that applying the offset does not change the pointer alignment
-     new_ptr = cute.make_ptr(
-         tensor.element_type,
-         tensor.iterator.toint() + offset * tensor.element_type.width // 8,
-         tensor.memspace,
-         assumed_align=tensor.iterator.max_alignment,
-     )
-     return cute.make_tensor(new_ptr, tensor.layout)
-
-
  @cute.jit
- def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
-     if cutlass.const_expr(lane is None):
+ def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32:
+     if const_expr(lane is None):
          lane = cute.arch.lane_idx()
      for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
          offset = 1 << i
@@ -276,74 +209,6 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) ->
      return val


- def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
-     """
-     For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
-     For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
-     """
-     acc_layout_col_major = cute.make_layout(acc_layout.shape)
-     acc_layout_mn = cute.make_layout(
-         (
-             (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
-             (
-                 acc_layout_col_major.shape[0][0],
-                 *acc_layout_col_major.shape[0][2:],
-                 acc_layout_col_major.shape[2],
-             ),  # MMA_N
-             *acc_layout_col_major.shape[3:],
-         ),
-         stride=(
-             (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
-             (
-                 acc_layout_col_major.stride[0][0],
-                 *acc_layout_col_major.stride[0][2:],
-                 acc_layout_col_major.stride[2],
-             ),  # MMA_N
-             *acc_layout_col_major.stride[3:],
-         ),
-     )
-     return cute.composition(acc_layout, acc_layout_mn)
-
-
- def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
-     return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
-
-
- @dsl_user_op
- def sm90_get_smem_load_op(
-     layout_c: cutlass.utils.LayoutEnum,
-     elem_ty_c: Type[cutlass.Numeric],
-     *,
-     loc=None,
-     ip=None,
- ) -> cute.CopyAtom:
-     """
-     Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
-
-     Parameters:
-     -----------
-     layout_c : LayoutEnum
-         The layout enum of the output tensor D.
-
-     elem_ty_c : Type[Numeric]
-         The element type for output tensor D.
-
-     Returns:
-     --------
-     Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
-     """
-
-     if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
-         raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
-     is_m_major = layout_c.is_m_major_c()
-     if elem_ty_c.width == 16:
-         return cute.make_copy_atom(
-             cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
-         )
-     else:
-         return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
-
-
  @dsl_user_op
  def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
      return nvvm.atomicrmw(
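
Note on the new helpers above: cute.arch.{fma,mul,add}_packed_f32x2 default to round-toward-zero (RZ), so the module now binds round-to-nearest-even (RN) once with functools.partial rather than passing rnd=... at every call site. The snippet below is a minimal sketch of that pattern, not code from the package; the commented call site assumes the packed ops take pairs of f32 operands.

from functools import partial

import cutlass.cute as cute
from cutlass._mlir.dialects import nvvm

# Bind the rounding mode once so every call site gets round-to-nearest-even
# instead of the RZ default (same pattern as quack/utils.py above).
fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)

# Inside a @cute.jit or kernel body one would then write, e.g.
#   d01 = fma_packed_f32x2(a01, b01, c01)  # two f32 FMAs per packed op
# rather than threading rnd=nvvm.RoundingModeKind.RN through each call.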
quack/varlen_utils.py CHANGED
@@ -3,9 +3,13 @@
  from typing import Optional
  from dataclasses import dataclass

+ import cutlass
  import cutlass.cute as cute
+ from cutlass import Int32, Boolean, const_expr
+ from cutlass.utils import LayoutEnum

- from quack.cute_dsl_utils import ArgumentsBase
+ from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
+ from quack.tensormap_manager import TensorMapManagerSm90


  # Grouping arguments together that should be passed to __call__
@@ -14,9 +18,369 @@ class VarlenArguments(ArgumentsBase):
      mCuSeqlensM: Optional[cute.Tensor] = None
      mCuSeqlensK: Optional[cute.Tensor] = None
      mTensormaps: Optional[cute.Tensor] = None
+     mAIdx: Optional[cute.Tensor] = None

-     def __post_init__(self):
-         if self.mCuSeqlensM is not None or self.mCuSeqlensK is not None:
-             assert (
-                 self.mTensormaps is not None
-             ), "mTensormaps must be provided if mCuSeqlensM or mCuSeqlensK is provided"
+
+ class VarlenManager:
+     bytes_per_tensormap = 128
+
+     @dataclass
+     class Params(ParamsBase):
+         cu_seqlens_m: Optional[cute.Tensor] = None
+         cu_seqlens_k: Optional[cute.Tensor] = None
+         tensormaps: Optional[cute.Tensor] = None
+         mAIdx: Optional[cute.Tensor] = None
+
+         @staticmethod
+         @cute.jit
+         def create(args: VarlenArguments, *, loc=None, ip=None) -> "VarlenManager.Params":
+             return VarlenManager.Params(
+                 cu_seqlens_m=args.mCuSeqlensM,
+                 cu_seqlens_k=args.mCuSeqlensK,
+                 tensormaps=args.mTensormaps,
+                 mAIdx=args.mAIdx,
+             )
+
+     def __init__(
+         self,
+         params: Params,
+         tensormap_manager: Optional[cutlass.utils.TensorMapManager],
+         tensormap_a_ptr: Optional[cute.Pointer],
+         tensormap_b_ptr: Optional[cute.Pointer],
+         tensormap_d_ptr: Optional[cute.Pointer],
+         tensormap_epi_ptrs: list[Optional[cute.Pointer]],
+         len_m_static: Int32,
+         len_k_static: Int32,
+         last_batch_idx: Int32 = Int32(-1),
+         is_group_changed: Boolean = Boolean(True),
+         *,
+         loc=None,
+         ip=None,
+     ):
+         self.params = params
+         self.tensormap_manager = tensormap_manager
+         self._tensormap_a_ptr = tensormap_a_ptr
+         self._tensormap_b_ptr = tensormap_b_ptr
+         self._tensormap_d_ptr = tensormap_d_ptr
+         self._tensormap_epi_ptrs = tensormap_epi_ptrs
+         self._len_m_static = len_m_static
+         self._len_k_static = len_k_static
+         self._last_batch_idx = last_batch_idx
+         self._is_group_changed = is_group_changed
+         self.varlen_m = const_expr(params.cu_seqlens_m is not None)
+         self.varlen_k = const_expr(params.cu_seqlens_k is not None)
+         self.gather_A = const_expr(params.mAIdx is not None)
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: VarlenArguments, *, loc=None, ip=None) -> Params:
+         assert not (args.mCuSeqlensM is not None and args.mCuSeqlensK is not None), (
+             "Only support either varlen_m or varlen_k"
+         )
+         return VarlenManager.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(
+         params: Params,
+         has_D: bool,
+         num_epi_tensormaps: int,
+         len_m_static: Int32,
+         len_k_static: Int32,
+         pingpong: bool = False,
+         warp_idx: int | Int32 = 0,
+         *,
+         loc=None,
+         ip=None,
+     ) -> "VarlenManager":
+         tensormap_manager = None
+         tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
+         tensormap_epi_ptrs = [None] * num_epi_tensormaps
+         varlen_m = const_expr(params.cu_seqlens_m is not None)
+         varlen_k = const_expr(params.cu_seqlens_k is not None)
+         if const_expr(varlen_m or varlen_k):
+             tensormap_manager = TensorMapManagerSm90(
+                 cutlass.utils.TensorMapUpdateMode.GMEM, VarlenManager.bytes_per_tensormap
+             )
+             # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
+             tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
+             if const_expr(varlen_m):
+                 tensormap_d_idx = warp_idx // 4 if const_expr(pingpong) else 0
+                 tensormap_epi_offset = tensormap_d_idx
+                 if const_expr(has_D):
+                     tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
+                         params.tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
+                     )
+                     tensormap_epi_offset += 1 if not pingpong else 2
+                 tensormap_epi_ptrs = [
+                     tensormap_manager.get_tensormap_ptr(
+                         params.tensormaps[
+                             tensormap_workspace_idx,
+                             tensormap_epi_offset + i * (1 if not pingpong else 2),
+                             None,
+                         ].iterator
+                     )
+                     for i in range(num_epi_tensormaps)
+                 ]
+             else:
+                 assert varlen_k
+                 gather_A = const_expr(params.mAIdx is not None)
+                 if const_expr(not gather_A):
+                     tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
+                         params.tensormaps[tensormap_workspace_idx, 0, None].iterator
+                     )
+                 tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
+                     params.tensormaps[
+                         tensormap_workspace_idx, 1 if not gather_A else 0, None
+                     ].iterator
+                 )
+         return VarlenManager(
+             params,
+             tensormap_manager,
+             tensormap_a_ptr,
+             tensormap_b_ptr,
+             tensormap_d_ptr,
+             tensormap_epi_ptrs,
+             len_m_static=len_m_static,
+             len_k_static=len_k_static,
+         )
+
+     def len_m(self, batch_idx: Int32) -> Int32:
+         if const_expr(self.varlen_m):
+             return self.params.cu_seqlens_m[batch_idx + 1] - self.params.cu_seqlens_m[batch_idx]
+         else:
+             return self._len_m_static
+
+     def len_k(self, batch_idx: Int32) -> Int32:
+         if const_expr(self.varlen_k):
+             return self.params.cu_seqlens_k[batch_idx + 1] - self.params.cu_seqlens_k[batch_idx]
+         else:
+             return self._len_k_static
+
+     def offset_batch_A(self, mA_mkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
+         params = self.params
+         if const_expr(self.varlen_m):
+             mA_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mA_mkl)
+         elif const_expr(self.varlen_k):
+             mA_mk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mA_mkl)
+         else:
+             mA_mk = mA_mkl[None, None, batch_idx]
+         return mA_mk
+
+     def offset_batch_AIdx(self, batch_idx: Int32) -> cute.Tensor:
+         params = self.params
+         if const_expr(self.varlen_m):
+             mAIdx_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx],), params.mAIdx)
+         elif const_expr(self.varlen_k):
+             mAIdx_mk = cute.domain_offset((params.cu_seqlens_k[batch_idx],), params.mAIdx)
+         else:
+             mAIdx_mk = params.mAIdx[None, batch_idx]
+         return mAIdx_mk
+
+     def offset_batch_B(self, mB_nkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
+         params = self.params
+         if const_expr(self.varlen_k):
+             mB_nk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mB_nkl)
+         else:
+             mB_nk = mB_nkl[None, None, batch_idx]
+         return mB_nk
+
+     def offset_batch_epi(self, mD_mnl: cute.Tensor, batch_idx: Int32) -> cute.Tensor:
+         params = self.params
+         if const_expr(self.varlen_m):
+             mD_mn = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mD_mnl)
+         else:
+             mD_mn = mD_mnl[None, None, batch_idx]
+         return mD_mn
+
+     def init_tensormap_AB(
+         self,
+         tma_atom_a: Optional[cute.CopyAtom],
+         tma_atom_b: cute.CopyAtom,
+         is_manager_warp: bool | Boolean = True,
+     ) -> None:
+         if const_expr(self.varlen_k):
+             if const_expr(not self.gather_A):
+                 self.tensormap_manager.init_tensormap_from_atom(
+                     tma_atom_a, self._tensormap_a_ptr, is_manager_warp
+                 )
+             self.tensormap_manager.init_tensormap_from_atom(
+                 tma_atom_b, self._tensormap_b_ptr, is_manager_warp
+             )
+
+     def init_tensormap_epi(
+         self,
+         tma_atom_d: Optional[cute.CopyAtom],
+         tma_atoms_epi: list[cute.CopyAtom],
+         is_manager_warp: bool | Boolean = True,
+     ) -> None:
+         if const_expr(self.varlen_m):
+             if const_expr(self._tensormap_d_ptr is not None):
+                 self.tensormap_manager.init_tensormap_from_atom(
+                     tma_atom_d, self._tensormap_d_ptr, is_manager_warp
+                 )
+             for tma_atom, tensormap_epi_ptr in zip(tma_atoms_epi, self._tensormap_epi_ptrs):
+                 self.tensormap_manager.init_tensormap_from_atom(
+                     tma_atom, tensormap_epi_ptr, is_manager_warp
+                 )
+
+     def fence_tensormap_init(self) -> None:
+         self.tensormap_manager.fence_tensormap_initialization()
+
+     @cute.jit
+     def update_tensormap_AB(
+         self,
+         batch_idx: Int32,
+         a_layout: LayoutEnum,
+         b_layout: LayoutEnum,
+         is_manager_warp: bool | Boolean = True,
+     ) -> None:
+         if const_expr(self.varlen_k):
+             self._is_group_changed = Boolean(batch_idx != self._last_batch_idx)
+             self._last_batch_idx = batch_idx
+             if self._is_group_changed:
+                 # construct tensor A/B based on real address, shape and stride information
+                 cu_seqlens_k = self.params.cu_seqlens_k
+                 tensormap_ptrs = [self._tensormap_b_ptr]
+                 shapes = [cu_seqlens_k[batch_idx + 1]]
+                 orders = [0 if const_expr(b_layout == LayoutEnum.ROW_MAJOR) else 1]
+                 if const_expr(not self.gather_A):
+                     tensormap_ptrs.insert(0, self._tensormap_a_ptr)
+                     shapes.insert(0, cu_seqlens_k[batch_idx + 1])
+                     orders.insert(0, 0 if const_expr(a_layout == LayoutEnum.ROW_MAJOR) else 1)
+                 self.tensormap_manager.update_tensormap_shape(
+                     tensormap_ptrs,
+                     is_manager_warp=is_manager_warp,
+                     shapes=shapes,
+                     orders=orders,
+                     tensormap_smem_ptr=None,
+                 )
+
+     @cute.jit
+     def update_tensormap_epi(
+         self,
+         batch_idx: Int32,
+         d_layout: LayoutEnum,
+         epi_shapes: list[Int32],
+         epi_orders: list[int],
+         is_manager_warp: bool | Boolean = True,
+     ) -> None:
+         if const_expr(self.varlen_m):
+             self._is_group_changed = Boolean(batch_idx != self._last_batch_idx)
+             self._last_batch_idx = batch_idx
+             # Cute-DSL doesn't like this under if statement
+             order_d = (
+                 (0 if const_expr(d_layout.is_m_major_c()) else 1) if d_layout is not None else None
+             )
+             if self._is_group_changed:
+                 # construct tensor A/B based on real address, shape and stride information
+                 cu_seqlens_m = self.params.cu_seqlens_m
+                 # construct tensor D based on real address, shape and stride information
+                 tensormap_ptrs, shapes, orders = [], [], []
+                 if const_expr(self._tensormap_d_ptr is not None):
+                     tensormap_ptrs.append(self._tensormap_d_ptr)
+                     shapes.append(cu_seqlens_m[batch_idx + 1])
+                     orders.append(order_d)
+                 tensormap_ptrs.extend(self._tensormap_epi_ptrs)
+                 shapes.extend(epi_shapes)
+                 orders.extend(epi_orders)
+                 self.tensormap_manager.update_tensormap_shape(
+                     tensormap_ptrs,
+                     is_manager_warp=is_manager_warp,
+                     shapes=shapes,
+                     orders=orders,
+                     tensormap_smem_ptr=None,
+                 )
+
+     @cute.jit
+     def fence_tensormap_update_AB(self, is_manager_warp: bool | Boolean = True) -> None:
+         if const_expr(self.varlen_k):
+             if self._is_group_changed and is_manager_warp:
+                 if const_expr(not self.gather_A):
+                     self.tensormap_manager.fence_tensormap_update(self._tensormap_a_ptr)
+                 self.tensormap_manager.fence_tensormap_update(self._tensormap_b_ptr)
+
+     @cute.jit
+     def fence_tensormap_update_epi(self, is_manager_warp: bool | Boolean = True) -> None:
+         if const_expr(self.varlen_m):
+             if self._is_group_changed and is_manager_warp:
+                 if const_expr(self._tensormap_d_ptr is not None):
+                     self.tensormap_manager.fence_tensormap_update(self._tensormap_d_ptr)
+                 for tensormap_epi_ptr in self._tensormap_epi_ptrs:
+                     if const_expr(tensormap_epi_ptr is not None):
+                         self.tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
+
+     def get_tma_desc_a_ptr(self) -> Optional[cute.Pointer]:
+         tma_desc_a_ptr = None
+         if const_expr(self.varlen_k and self._tensormap_a_ptr is not None):
+             tma_desc_a_ptr = self.tensormap_manager.get_tensormap_ptr(
+                 self._tensormap_a_ptr, cute.AddressSpace.generic
+             )
+         return tma_desc_a_ptr
+
+     def get_tma_desc_b_ptr(self) -> Optional[cute.Pointer]:
+         tma_desc_b_ptr = None
+         if const_expr(self.varlen_k):
+             tma_desc_b_ptr = self.tensormap_manager.get_tensormap_ptr(
+                 self._tensormap_b_ptr, cute.AddressSpace.generic
+             )
+         return tma_desc_b_ptr
+
+     def get_tma_desc_d_ptr(self) -> Optional[cute.Pointer]:
+         tma_desc_d_ptr = None
+         if const_expr(self.varlen_m and self._tensormap_d_ptr is not None):
+             tma_desc_d_ptr = self.tensormap_manager.get_tensormap_ptr(
+                 self._tensormap_d_ptr, cute.AddressSpace.generic
+             )
+         return tma_desc_d_ptr
+
+     def get_tma_desc_epi_ptrs(self) -> list[Optional[cute.Pointer]]:
+         tma_desc_epi_ptrs = [None] * len(self._tensormap_epi_ptrs)
+         if const_expr(self.varlen_m):
+             for i, tensormap_epi_ptr in enumerate(self._tensormap_epi_ptrs):
+                 if const_expr(tensormap_epi_ptr is not None):
+                     tma_desc_epi_ptrs[i] = self.tensormap_manager.get_tensormap_ptr(
+                         tensormap_epi_ptr, cute.AddressSpace.generic
+                     )
+         return tma_desc_epi_ptrs
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [
+             self.params,
+             self.tensormap_manager,
+             self._tensormap_a_ptr,
+             self._tensormap_b_ptr,
+             self._tensormap_d_ptr,
+             self._tensormap_epi_ptrs,
+             self._len_m_static,
+             self._len_k_static,
+             self._last_batch_idx,
+             self._is_group_changed,
+         ]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip(
+             [
+                 self.params,
+                 self.tensormap_manager,
+                 self._tensormap_a_ptr,
+                 self._tensormap_b_ptr,
+                 self._tensormap_d_ptr,
+                 self._tensormap_epi_ptrs,
+                 self._len_m_static,
+                 self._len_k_static,
+                 self._last_batch_idx,
+                 self._is_group_changed,
+             ],
+             self._values_pos,
+         ):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return self.__class__(*(tuple(obj_list)), loc=self._loc)
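
The VarlenManager added above centralizes variable-length bookkeeping for the GEMM kernels: it owns the per-CTA tensormap pointers and applies cu_seqlens-based offsets to A/B/D. Below is a hedged sketch of the varlen-K flow its API implies (create, init, per-batch update, fence, then batch slicing); the surrounding kernel scaffolding (params, TMA atoms, input tensors, batch loop) is assumed, and only the VarlenManager calls mirror this diff.

from cutlass import Int32
from cutlass.utils import LayoutEnum

from quack.varlen_utils import VarlenManager


def varlen_k_batch_setup(
    params, tma_atom_a, tma_atom_b, mA_mkl, mB_nkl, batch_idx: Int32,
    a_layout: LayoutEnum, b_layout: LayoutEnum, is_manager_warp=True,
):
    # One manager per CTA; no D or epilogue tensormaps in this sketch.
    varlen = VarlenManager.create(
        params, has_D=False, num_epi_tensormaps=0,
        len_m_static=Int32(0), len_k_static=Int32(0),
    )
    varlen.init_tensormap_AB(tma_atom_a, tma_atom_b, is_manager_warp)
    varlen.fence_tensormap_init()
    # Per batch: rewrite the A/B TMA descriptors to this batch's K extent
    # (only done when the batch index actually changed), then fence before use.
    varlen.update_tensormap_AB(batch_idx, a_layout, b_layout, is_manager_warp)
    varlen.fence_tensormap_update_AB(is_manager_warp)
    # Slice A/B to the current batch via the cu_seqlens offsets.
    mA_mk = varlen.offset_batch_A(mA_mkl, batch_idx)
    mB_nk = varlen.offset_batch_B(mB_nkl, batch_idx)
    return varlen, mA_mk, mB_nk, varlen.len_k(batch_idx)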
{quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA CHANGED
@@ -1,10 +1,12 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.2.1
+ Version: 0.2.3
  Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.2.0
+ Requires-Dist: nvidia-cutlass-dsl==4.3.3
  Requires-Dist: torch
+ Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
+ Requires-Dist: torch-c-dlpack-ext
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
  Requires-Dist: ruff; extra == "dev"
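
The metadata changes pin a newer Cute-DSL release and add two runtime dependencies. For reference, a requirements listing that matches the 0.2.3 metadata (pins copied from the diff above; the file itself is illustrative and not shipped with the package):

# requirements.txt (illustrative, pins taken from the METADATA diff)
quack-kernels==0.2.3
nvidia-cutlass-dsl==4.3.3
torch
apache-tvm-ffi>=0.1.5,<0.2
torch-c-dlpack-ext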