quack-kernels 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/gemm_symmetric.py
@@ -0,0 +1,330 @@
+ from typing import Tuple, Optional, Callable
+ from functools import partial
+ from torch import Tensor
+ from quack.gemm_act import GemmActMixin, act_fn_map, gemm_act
+ from quack.gemm_sm90 import GemmSm90
+ from quack.gemm_sm100 import GemmSm100
+ from quack.tile_scheduler import TriangularTileScheduler
+ from quack.gemm_wrapper_utils import GemmWrapperBase
+ from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
+ from quack.varlen_utils import VarlenManager
+ import quack.copy_utils as copy_utils
+ import cutlass
+ import cutlass.cute as cute
+ import cutlass.torch as cutlass_torch
+ from cutlass.cute.runtime import make_ptr
+ from cutlass import Int32, Float32, Boolean, const_expr
+ import cutlass.utils.hopper_helpers as sm90_utils_og
+ import cutlass.utils.blackwell_helpers as sm100_utils
+ from cutlass.cutlass_dsl import if_generate
+
+
+ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
+     def get_scheduler_class(self, varlen_m: bool = False):
+         return TriangularTileScheduler
+
+     @cute.jit
+     def epilogue(
+         self,
+         params: GemmActMixin.EpilogueParams,
+         epi_smem_tensors: Tuple[cute.Tensor, ...],
+         tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
+         epi_pipeline: cutlass.pipeline.PipelineAsync,
+         epi_store_pipeline: cutlass.pipeline.PipelineAsync,
+         epi_read_state: cutlass.pipeline.PipelineState,
+         epi_producer_state: cutlass.pipeline.PipelineState,
+         epi_tile: cute.Tile,
+         load_acc_subtile: Callable,
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor],
+         tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
+         tiled_copy_r2s: cute.TiledCopy,
+         tRS_sD: cute.Tensor,
+         tiled_copy_s2r: Optional[cute.TiledCopy],
+         tSR_rC: Optional[cute.Tensor],
+         tSR_sC: Optional[cute.Tensor],
+         copy_D: Optional[Callable],
+         copy_C: Optional[Callable],
+         tile_coord_mnkl: cute.Coord,
+         varlen_manager: VarlenManager,
+         epilogue_barrier: cutlass.pipeline.NamedBarrier,
+         tile_scheduler,
+         tidx: Int32,
+         is_tma_warp: Boolean,
+     ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
+         has_C = const_expr(tRS_rC is not None)
+         has_D = const_expr(copy_D is not None)
+
+         tma_atom_postact = params.tma_atom_postact
+         mPostAct_mnl = params.mPostAct_mnl
+         sRowVec, sColVec, sPostAct = epi_smem_tensors
+         get_smem_store_op = (
+             partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
+             if self.arch == 100
+             else sm90_utils_og.sm90_get_smem_store_op
+         )
+         copy_atom_postact_r2s = get_smem_store_op(
+             self.postact_layout, self.postact_dtype, self.acc_dtype
+         )
+         # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
+         # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
+         tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
+         tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
+         (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
+         batch_idx = tile_coord_mnkl[3]
+         copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
+             tma_atom_postact,
+             varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
+             self.cta_tile_shape_postact_mn,
+             params.epi_tile_postact,
+             sPostAct,
+             tile_coord_mnkl,
+             tma_desc_ptr=tma_desc_postact_ptr,
+         )
+
+         # We iterate over epi tiles in the N dimension first before the M dimension
+         epi_tile_shape = cute.zipped_divide(
+             cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
+         ).shape[1]
+         epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
+         epi_tile_num = cute.size(epi_tile_shape)
+         num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
+
+         epi_tensors = self.epi_begin(
+             params,
+             epi_smem_tensors,
+             epi_tile,
+             tiled_copy_t2r,
+             tiled_copy_r2s,
+             tile_coord_mnkl,
+             varlen_manager,
+             epilogue_barrier,
+             tidx,
+         )
+
+         if const_expr(copy_C is not None):
+             for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
+                 gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
+                 if is_tma_warp:
+                     epi_pipeline.producer_acquire(epi_producer_state)
+                     copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                     epi_pipeline.producer_commit(epi_producer_state)
+                 epi_producer_state.advance()
+
+         def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl):
+             pid_m = tile_coord_mnkl[0]
+             pid_n = tile_coord_mnkl[1]
+             # Fence and barrier to make sure shared memory store is visible to TMA store
+             cute.arch.fence_proxy(
+                 cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+             )
+             epilogue_barrier.arrive_and_wait()
+             # Copy from shared memory to global memory
+             if is_tma_warp:
+                 square_tile_m = pid_m // self.cluster_shape_mnk[0]
+                 square_tile_n = pid_n // self.cluster_shape_mnk[1]
+                 if const_expr(has_D):
+                     copy_D(src_idx=src_idx, dst_idx=dst_idx)
+                 if square_tile_m != square_tile_n:  # don't write twice to the same tile
+                     copy_postact(src_idx=src_idx, dst_idx=dst_idx)
+             # Can't use if statement here, epi_store_pipeline object isn't captured somehow
+             if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
+             if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
+             epilogue_barrier.arrive_and_wait()
+
+         delay_tma_store = True
+
+         src_idx_prev, dst_idx_prev = None, None
+         for epi_idx in cutlass.range_constexpr(epi_tile_num):
+             # The global memory coordinate for the current epi tile
+             gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
+             # Copy from acc to D registers
+             load_acc_subtile(tRS_rD, epi_idx)
+             epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
+             if const_expr(has_C):
+                 epi_pipeline.consumer_wait(epi_read_state)
+                 cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
+                 # Fence to make sure shared memory read is visible to TMA load
+                 cute.arch.fence_proxy(
+                     cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                 )
+                 cute.arch.sync_warp()
+                 with cute.arch.elect_one():
+                     epi_pipeline.consumer_release(epi_read_state)
+                 epi_read_state.advance()
+             if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+                 gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
+                 if is_tma_warp:
+                     epi_pipeline.producer_acquire(epi_producer_state)
+                     copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                     epi_pipeline.producer_commit(epi_producer_state)
+                 epi_producer_state.advance()
+             tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
+             epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+             if const_expr(delay_tma_store):
+                 if const_expr(epi_idx > 0):
+                     tma_store_fn(
+                         src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
+                     )
+                 src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
+             # Copy from D registers to shared memory
+             if const_expr(has_D):
+                 copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
+             cute.copy(
+                 tiled_copy_postact_r2s,
+                 tiled_copy_postact_r2s.retile(tRS_rPostAct),
+                 tRS_sPostAct[None, None, None, epi_buffer],
+             )
+             if const_expr(not delay_tma_store):
+                 tma_store_fn(
+                     src_idx=epi_buffer, dst_idx=gmem_coord, tile_coord_mnkl=tile_coord_mnkl
+                 )
+
+         if const_expr(delay_tma_store):
+             tma_store_fn(
+                 src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
+             )
+
+         self.epi_end(
+             params,
+             epi_tensors,
+             epi_tile,
+             tiled_copy_t2r,
+             tiled_copy_r2s,
+             tile_coord_mnkl,
+             varlen_manager,
+             tidx,
+         )
+
+         return epi_read_state, epi_producer_state
+
+
+ class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90):
+     pass
+
+
+ class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100):
+     pass
+
+
+ def gemm_symmetric(
+     A: Tensor,  # (l, m, k)
+     B: Tensor,  # (l, m, k)
+     D: Optional[Tensor],  # (l, m, m)
+     C: Optional[Tensor],  # (l, m, m)
+     tile_count_semaphore: Optional[Tensor],  # (1,)
+     tile_M: int,
+     tile_N: int,
+     cluster_M: int,
+     cluster_N: int,
+     pingpong: bool = False,
+     persistent: bool = True,
+     max_swizzle_size: int = 8,
+     alpha: float | Tensor = 1.0,
+     beta: float | Tensor = 1.0,
+ ) -> None:
+     # Transpose D so the "activation" is a write to the mirrored tile
+     PostAct = D.mT
+
+     L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+         A, B, D, C, additional_tensors={"PostAct": PostAct}
+     )
+     assert M == N, "M and N must be the same; symmetric gemm only supports square matrices"
+     GemmWrapperBase.permute_tensors(tensor_infos)
+     GemmWrapperBase.extract_dtypes(tensor_infos)
+     major_configs = {
+         "A": ("m", "k", "l"),
+         "B": ("n", "k", "l"),
+         "D": ("m", "n", "l"),
+         "C": ("m", "n", "l"),
+         "PostAct": ("m", "n", "l"),
+     }
+     GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+     device_capacity = get_device_capacity(A.device)
+     assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
+     GemmCls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100
+
+     acc_dtype = Float32
+     tile_shape_mn = (tile_M, tile_N)
+     cluster_shape_mnk = (cluster_M, cluster_N, 1)
+     if not GemmCls.is_valid_dtypes(
+         tensor_infos["A"].dtype,
+         tensor_infos["B"].dtype,
+         acc_dtype,
+         tensor_infos["D"].dtype,
+         tensor_infos["A"].major,
+         tensor_infos["B"].major,
+     ):
+         raise TypeError("Skipping due to unsupported combination of types and majors")
+
+     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+     GemmWrapperBase.create_cute_tensors({k: v for k, v in tensor_infos.items()}, major_configs)
+
+     def scalar_arg(scalar: float | Tensor):
+         if isinstance(scalar, float):
+             return Float32(scalar) if scalar != 1.0 else None
+         else:
+             assert isinstance(scalar, Tensor)
+             return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
+
+     activation = None  # Equivalent to identity
+     act_fn = act_fn_map[activation]
+     epi_args = GemmCls.EpilogueArguments(
+         tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta)
+     )
+     scheduler_args = GemmWrapperBase.create_scheduler_args(
+         max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
+     )
+     varlen_args = None
+
+     current_stream = cutlass_torch.current_stream()
+     compile_key = GemmWrapperBase.get_compile_key(
+         tensor_infos,
+         activation,
+         tile_shape_mn,
+         cluster_shape_mnk,
+         pingpong,
+         persistent,
+         tile_count_semaphore is not None,
+         device_capacity,
+         max_swizzle_size,
+         2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
+         2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
+         key_tensor_names=("A", "B", "D", "PostAct", "C"),
+     )
+     cache = gemm_act.compile_cache
+     if compile_key not in cache:
+         if device_capacity[0] == 9:
+             GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
+         gemm_obj = GemmCls(
+             acc_dtype,
+             tensor_infos["A"].dtype,
+             tile_shape_mn,
+             cluster_shape_mnk,
+             gather_A=False,
+         )
+         cache[compile_key] = cute.compile(
+             gemm_obj,
+             tensor_infos["A"].cute_tensor,
+             tensor_infos["B"].cute_tensor,
+             tensor_infos["D"].cute_tensor,
+             tensor_infos["C"].cute_tensor,
+             epi_args,
+             scheduler_args,
+             varlen_args,
+             current_stream,
+         )
+     cache[compile_key](
+         tensor_infos["A"].cute_tensor,
+         tensor_infos["B"].cute_tensor,
+         tensor_infos["D"].cute_tensor,
+         tensor_infos["C"].cute_tensor,
+         epi_args,
+         scheduler_args,
+         varlen_args,
+         current_stream,
+     )
+
+
+ gemm_act.compile_cache = {}
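
For orientation, here is a minimal, hypothetical call into the new gemm_symmetric entry point. It is not part of the diff: the tensor shapes follow the comments in the signature above, while the dtype, tile sizes, and cluster shape are illustrative assumptions rather than values taken from the package.

import torch
from quack.gemm_symmetric import gemm_symmetric

# Batched symmetric GEMM on an SM90/SM100 GPU (asserted in the wrapper above).
# Per batch, A and B are (m, k) and D is the (m, m) output; the transposed
# PostAct store above writes the mirrored tile (skipping diagonal tiles).
l, m, k = 2, 1024, 512
A = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
D = torch.empty(l, m, m, device="cuda", dtype=torch.bfloat16)

gemm_symmetric(
    A, B, D, C=None, tile_count_semaphore=None,
    tile_M=128, tile_N=128,    # assumed tile sizes
    cluster_M=1, cluster_N=1,  # assumed cluster shape
)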
quack/gemm_wrapper_utils.py
@@ -2,6 +2,7 @@
  from typing import Optional, Tuple, Dict, Any
  from dataclasses import dataclass

+ import torch
  from torch import Tensor

  import cutlass.cute as cute
@@ -9,7 +10,8 @@ from cutlass import Int32
  from cutlass.cute.runtime import from_dlpack, make_ptr

  from quack.cute_dsl_utils import torch2cute_dtype_map
- from quack.dense_gemm_sm90 import TileSchedulerOptions
+ from quack.varlen_utils import VarlenArguments
+ from quack.tile_scheduler import TileSchedulerOptions


  @dataclass
@@ -22,8 +24,8 @@ class GemmTensorInfo:

  class GemmWrapperBase:
      @staticmethod
-     def validate_tensor_3d(tensor: Tensor, name: str) -> None:
-         assert tensor.dim() == 3 and tensor.is_cuda, f"{name} must be a 3D CUDA tensor"
+     def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None:
+         assert tensor.dim() == ndim and tensor.is_cuda, f"{name} must be a {ndim}D CUDA tensor"
          assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"

      @staticmethod
@@ -47,7 +49,7 @@ class GemmWrapperBase:
      ) -> Optional[cute.Tensor]:
          if tensor is None:
              return None
-         # Tensor is already permuted to (dims[0], dims[1], dims[2])
+         # Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dim[0], dim[1])
          # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
          leading_dim = 1 if major == dims[1] else 0
          return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
@@ -61,43 +63,131 @@ class GemmWrapperBase:
          D: Optional[Tensor] = None,
          C: Optional[Tensor] = None,
          additional_tensors: Optional[Dict[str, Tensor]] = None,
+         cu_seqlens_m: Optional[Tensor] = None,
+         cu_seqlens_k: Optional[Tensor] = None,
+         A_idx: Optional[Tensor] = None,
      ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
-         GemmWrapperBase.validate_tensor_3d(A, "A")
-         L, M, K = A.shape
-         GemmWrapperBase.validate_tensor_3d(B, "B")
-         _, N, _ = B.shape
+         assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
+             "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
+         )
          assert B.dtype == A.dtype, "A and B must have the same dtype"
-         GemmWrapperBase.validate_shape(B, (L, N, K), "B")
+
+         # Validate A_idx if provided (for gather_A case)
+         gather_A = A_idx is not None
+         if gather_A:
+             assert cu_seqlens_m is not None or cu_seqlens_k is not None, (
+                 "gather_A requires either varlen_m or varlen_k"
+             )
+             assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}"
+             assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D"
+
+         # Determine mode and extract dimensions
+         if cu_seqlens_m is not None:
+             # varlen_m: A is (total_m, k) or (whatever, k) if gather_A, B is (l, n, k), D/C are (total_m, n)
+             assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D"
+             assert B.dim() == 3, f"B must be 3D with varlen_m, got {B.dim()}D"
+
+             if gather_A:
+                 # When gather_A, A can have any number of rows, we use A_idx.shape[0] as total_M
+                 total_M = A_idx.shape[0]
+                 _, K = A.shape
+             else:
+                 total_M, K = A.shape
+
+             L, N, K_B = B.shape
+             assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
+             assert cu_seqlens_m.shape == (L + 1,), (
+                 f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}"
+             )
+             M = total_M
+             dc_shape = (total_M, N)
+             dc_ndim = 2
+         elif cu_seqlens_k is not None:
+             # varlen_k: A is (m, total_k) or (m, whatever) if gather_A, B is (n, total_k), D/C are (l, m, n)
+             assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D"
+             assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D"
+
+             if gather_A:
+                 # When gather_A with varlen_k, A can have any number of columns, we use A_idx.shape[0] as total_K
+                 M, _ = A.shape
+                 total_K = A_idx.shape[0]
+             else:
+                 M, total_K = A.shape
+
+             N, K_B = B.shape
+             assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}"
+             L = cu_seqlens_k.shape[0] - 1
+             assert cu_seqlens_k.shape == (L + 1,), (
+                 f"cu_seqlens_k must have shape ({L + 1},), got {cu_seqlens_k.shape}"
+             )
+             K = total_K
+             dc_shape = (L, M, N)
+             dc_ndim = 3
+         else:
+             # Normal case - all tensors must be 3D
+             GemmWrapperBase.validate_tensor(A, "A", 3)
+             GemmWrapperBase.validate_tensor(B, "B", 3)
+             L, M, K = A.shape
+             _, N, K_B = B.shape
+             assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
+             GemmWrapperBase.validate_shape(B, (L, N, K), "B")
+             dc_shape = (L, M, N)
+             dc_ndim = 3
+
+         # Validate D and C shapes uniformly
+         for tensor, name in [(D, "D"), (C, "C")]:
+             if tensor is not None:
+                 assert tensor.dim() == dc_ndim, (
+                     f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
+                 )
+                 assert tensor.shape == dc_shape, (
+                     f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
+                 )
+
          tensors = {
              "A": GemmTensorInfo(A),
              "B": GemmTensorInfo(B),
              "D": GemmTensorInfo(D),
              "C": GemmTensorInfo(C),
          }
-         if D is not None:
-             GemmWrapperBase.validate_tensor_3d(D, "D")
-             GemmWrapperBase.validate_shape(D, (L, M, N), "D")
-         if C is not None:
-             GemmWrapperBase.validate_tensor_3d(C, "C")
-             GemmWrapperBase.validate_shape(C, (L, M, N), "C")
+
          if additional_tensors:
              for name, tensor in additional_tensors.items():
                  if tensor is not None:
-                     GemmWrapperBase.validate_tensor_3d(tensor, name)
-                     GemmWrapperBase.validate_shape(tensor, (L, M, N), name)
+                     assert tensor.dim() == dc_ndim, (
+                         f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
+                     )
+                     assert tensor.shape == dc_shape, (
+                         f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
+                     )
                      tensors[name] = GemmTensorInfo(tensor)

          return L, M, K, N, tensors

      @staticmethod
-     def permute_tensors(tensors: Dict[str, GemmTensorInfo]) -> None:
-         for info in tensors.values():
-             if info.tensor is not None:
-                 info.tensor = info.tensor.permute(1, 2, 0)
+     def permute_tensors(
+         tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False
+     ) -> None:
+         # Determine which tensors need permutation
+         if varlen_m:
+             # Only B needs permutation (3D tensor)
+             tensors_to_permute = ["B"]
+         elif varlen_k:
+             # Only D and C need permutation (3D tensors)
+             tensors_to_permute = ["D", "C"]
+         else:
+             # All tensors need permutation
+             tensors_to_permute = None
+
+         # Apply permutation from (L, *, *) -> (*, *, L) for selected tensors
+         for name, info in tensors.items():
+             if info.tensor is not None and info.tensor.ndim == 3:
+                 if tensors_to_permute is None or name in tensors_to_permute:
+                     info.tensor = info.tensor.permute(1, 2, 0)

      @staticmethod
      def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
-         for info in tensors.values():
+         for name, info in tensors.items():
             if info.tensor is not None:
                 info.dtype = torch2cute_dtype_map[info.tensor.dtype]
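
To make the varlen_m shape conventions asserted above concrete, here is a small hypothetical setup; the sizes and dtype are invented for illustration and are not taken from the package.

import torch
from quack.gemm_wrapper_utils import GemmWrapperBase

# Three variable-length groups packed along M: seqlens 100, 40, 60 -> total_M = 200.
cu_seqlens_m = torch.tensor([0, 100, 140, 200], dtype=torch.int32, device="cuda")
L, N, K = 3, 256, 128
A = torch.randn(200, K, device="cuda", dtype=torch.bfloat16)   # A is 2D in varlen_m mode
B = torch.randn(L, N, K, device="cuda", dtype=torch.bfloat16)  # B stays 3D
D = torch.empty(200, N, device="cuda", dtype=torch.bfloat16)   # D/C collapse to (total_M, N)

L_, M_, K_, N_, infos = GemmWrapperBase.validate_and_prepare_tensors(
    A, B, D, None, cu_seqlens_m=cu_seqlens_m
)
# With varlen_m, permute_tensors only moves B from (L, N, K) to (N, K, L); A and D stay 2D.
GemmWrapperBase.permute_tensors(infos, varlen_m=True)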
 
@@ -121,7 +211,10 @@

      @staticmethod
      def create_scheduler_args(
-         max_active_clusters: int, tile_count_semaphore: Optional[Tensor] = None
+         max_active_clusters: int,
+         tile_count_semaphore: Optional[Tensor] = None,
+         batch_idx_permute: Optional[Tensor] = None,
+         max_swizzle_size: int = 8,
      ) -> TileSchedulerOptions:
          return TileSchedulerOptions(
              Int32(max_active_clusters),
@@ -130,6 +223,72 @@
              )
              if tile_count_semaphore is not None
              else None,
+             batch_idx_permute=(
+                 from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0)
+             )
+             if batch_idx_permute is not None
+             else None,
+             max_swizzle_size=Int32(max_swizzle_size),
+         )
+
+     @staticmethod
+     def create_varlen_args(
+         cu_seqlens_m: Optional[Tensor],
+         cu_seqlens_k: Optional[Tensor],
+         A_idx: Optional[Tensor],
+         max_active_clusters: int,
+         cluster_shape_mnk: Tuple[int, int, int],
+         tensors: Dict[str, GemmTensorInfo],
+         num_epi_tensormaps: int = 0,
+         pingpong: bool = False,
+     ) -> Optional[Any]:
+         if cu_seqlens_m is None and cu_seqlens_k is None:
+             return None
+         # When varlen_m, we assume persistent=True
+         # Grid size depends on num_active_clusters and cluster size
+         cluster_size = cluster_shape_mnk[0] * cluster_shape_mnk[1]
+         num_blocks = max_active_clusters * cluster_size
+         # Calculate number of tensormaps needed
+         if cu_seqlens_m is not None:
+             # For varlen_m: need tensormaps for D and epilogue tensors
+             num_tensormaps = num_epi_tensormaps * (1 if not pingpong else 2)
+             if tensors["D"].tensor is not None:
+                 num_tensormaps += 1 if not pingpong else 2  # D tensormap
+         else:
+             # For varlen_k: need tensormaps for A & B
+             num_tensormaps = 2 if A_idx is None else 1
+         # Create tensormap buffer (each tensormap is 128 bytes = 16 int64s)
+         tensormap_size = 128 // 8  # 16 int64s
+         if num_tensormaps > 0:
+             device = cu_seqlens_m.device if cu_seqlens_m is not None else cu_seqlens_k.device
+             tensormaps = torch.empty(
+                 (num_blocks, num_tensormaps, tensormap_size),
+                 dtype=torch.int64,
+                 device=device,
+             )
+             tensormaps_cute = from_dlpack(tensormaps, assumed_align=128).mark_compact_shape_dynamic(
+                 mode=0, stride_order=(0, 1, 2)
+             )
+         else:
+             tensormaps_cute = None
+
+         return VarlenArguments(
+             mCuSeqlensM=(
+                 from_dlpack(cu_seqlens_m, assumed_align=4).mark_layout_dynamic(leading_dim=0)
+                 if cu_seqlens_m is not None
+                 else None
+             ),
+             mCuSeqlensK=(
+                 from_dlpack(cu_seqlens_k, assumed_align=4).mark_layout_dynamic(leading_dim=0)
+                 if cu_seqlens_k is not None
+                 else None
+             ),
+             mTensormaps=tensormaps_cute,
+             mAIdx=(
+                 from_dlpack(A_idx, assumed_align=4).mark_layout_dynamic(leading_dim=0)
+                 if A_idx is not None
+                 else None
+             ),
          )

      @staticmethod
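
Continuing the hypothetical varlen_m sketch above, create_varlen_args would then build the VarlenArguments and tensormap workspace; max_active_clusters and the cluster shape below are assumed values, not defaults from the package.

varlen_args = GemmWrapperBase.create_varlen_args(
    cu_seqlens_m=cu_seqlens_m,
    cu_seqlens_k=None,
    A_idx=None,
    max_active_clusters=132,      # assumed; normally derived via get_max_active_clusters()
    cluster_shape_mnk=(2, 1, 1),
    tensors=infos,
    num_epi_tensormaps=0,
)
# With D present and pingpong=False this sizes one tensormap per CTA:
# num_blocks = 132 * (2 * 1) = 264, giving a (264, 1, 16) int64 buffer
# (16 int64s = 128 bytes per tensormap, as computed above).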