quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
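The largest change in this release is the GEMM rewrite: quack/dense_gemm_sm100.py becomes quack/gemm_sm100.py, and its PersistentDenseGemmKernel class is replaced by GemmSm100(GemmSm90), which drops use_2cta_instrs and cluster_shape_mn in favor of an explicit cluster_shape_mnk and adds a gather_A path for variable-length inputs. The following is a minimal, illustrative sketch of constructing the kernel under the new __init__ signature visible in the diff below; the argument values are examples only, and the compile/launch plumbing (normally handled by quack.gemm_interface) is not shown here.

```python
# Illustrative sketch only: arguments follow the new GemmSm100.__init__ signature
# from this diff; the tile/cluster values are examples, not recommended defaults.
import cutlass
from cutlass import Float32
from quack.gemm_sm100 import GemmSm100

gemm = GemmSm100(
    acc_dtype=Float32,             # accumulator dtype
    a_dtype=cutlass.BFloat16,      # per the diff, currently ignored by __init__
    mma_tiler_mn=(128, 128),       # MMA tile (M, N); K is deferred to _setup_attributes
    cluster_shape_mnk=(2, 1, 1),   # replaces cluster_shape_mn + use_2cta_instrs
    gather_A=False,                # new: cp.async gather path for varlen A
)
# The kernel object is then invoked with the A/B/D/C tensors, epilogue, scheduler,
# and varlen argument structs, and a CUDA stream, per the updated call signature below.
```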
@@ -1,33 +1,8 @@
- # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
-
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
-
- # 1. Redistributions of source code must retain the above copyright notice, this
- # list of conditions and the following disclaimer.
-
- # 2. Redistributions in binary form must reproduce the above copyright notice,
- # this list of conditions and the following disclaimer in the documentation
- # and/or other materials provided with the distribution.
-
- # 3. Neither the name of the copyright holder nor the names of its
- # contributors may be used to endorse or promote products derived from
- # this software without specific prior written permission.
-
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ # Based on the cute-dsl example:
+ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py

  import argparse
- from typing import Optional, Type, Tuple, Union, Callable
+ from typing import Optional, Type, Tuple, Union, Callable, Literal
  from functools import partial

  import cuda.bindings.driver as cuda
@@ -40,15 +15,25 @@ import cutlass.torch as cutlass_torch
  import cutlass.pipeline as pipeline
  import cutlass.utils.blackwell_helpers as sm100_utils
  import cutlass.utils.blockscaled_layout as blockscaled_utils
+ from cutlass.cute.nvgpu.warp import (
+ LdMatrix8x8x16bOp,
+ LdMatrix16x16x8bOp,
+ StMatrix8x8x16bOp,
+ StMatrix16x8x8bOp,
+ )
+ from cutlass import Int32, Float32, Boolean, const_expr
+ from cutlass.utils import LayoutEnum
  from cutlass.cute.runtime import from_dlpack, make_ptr
- from cutlass import Int32, const_expr

- from quack.cute_dsl_utils import ParamsBase
- from quack.tile_scheduler import (
- TileSchedulerArguments,
- TileScheduler,
- RasterOrderOption,
- )
+ from quack.pipeline import PipelineTmaCpAsyncUmma
+ from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
+ from quack.tile_scheduler import TileSchedulerOptions
+ from quack.varlen_utils import VarlenArguments, VarlenManager
+ from quack.gemm_sm90 import GemmSm90, NamedBarrierGemm
+ import quack.copy_utils as copy_utils
+ import quack.sm100_utils as quack_sm100_utils
+
+ # return PipelineStateWAdvance instead of PipelineState

  """
  A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
@@ -72,8 +57,6 @@ This GEMM works as follows:
  - Type convert C matrix to output type.
  - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
  or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
- - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
- e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))

  SM100 tcgen05.mma instructions operate as follows:
  - Read matrix A from SMEM
@@ -105,7 +88,7 @@ To collect performance with NCU profiler:

  Constraints are same as dense_gemm.py:
  * Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2),
- see detailed valid dtype combinations in below PersistentDenseGemmKernel class documentation
+ see detailed valid dtype combinations in below GemmSm100 class documentation
  * A/B tensor must have the same data type
  * Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
  * Mma tiler N must be 32-256, step 32
@@ -118,14 +101,12 @@ Constraints are same as dense_gemm.py:
  """


- class PersistentDenseGemmKernel:
+ class GemmSm100(GemmSm90):
  """This class implements batched matrix multiplication (C = A x B) with support for various data types
  and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.

  :param acc_dtype: Data type for accumulation during computation
  :type acc_dtype: type[cutlass.Numeric]
- :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
- :type use_2cta_instrs: bool
  :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
  :type mma_tiler_mn: Tuple[int, int]
  :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
@@ -159,22 +140,28 @@ class PersistentDenseGemmKernel:
  - Cluster shape M/N must be positive and power of 2, total cluster size <= 16

  Example:
- >>> gemm = PersistentDenseGemmKernel(
- ... acc_dtype=cutlass.Float32,
- ... use_2cta_instrs=True,
+ >>> gemm = GemmSm100(
+ ... acc_dtype=Float32,
  ... mma_tiler_mn=(128, 128),
  ... cluster_shape_mn=(2, 2)
  ... )
  >>> gemm(mA, mB, mD, max_active_clusters, stream)
  """

+ arch = 100
+ num_epi_tensormaps = GemmSm90.num_epi_tensormaps
+
+ EpilogueArguments = GemmSm90.EpilogueArguments
+ EpilogueParams = GemmSm90.EpilogueParams
+
  def __init__(
  self,
  acc_dtype: Type[cutlass.Numeric],
- use_2cta_instrs: bool,
+ a_dtype: Type[cutlass.Numeric],  # ignored for now
  mma_tiler_mn: Tuple[int, int],
- cluster_shape_mn: Tuple[int, int],
+ cluster_shape_mnk: Tuple[int, int, int],
  sf_vec_size: Optional[int] = None,
+ gather_A: bool = False,
  ):
  """Initializes the configuration for a Blackwell dense GEMM kernel.

@@ -187,50 +174,54 @@ class PersistentDenseGemmKernel:
  with cta_group=2 should be used.

  2. Cluster Shape:
- - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
+ - cluster_shape_mnk: The (ClusterM, ClusterN) shape of the CTA cluster.

  :param acc_dtype: Data type of the accumulator.
  :type acc_dtype: type[cutlass.Numeric]
  :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
  :type mma_tiler_mn: Tuple[int, int]
- :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant.
- :type use_2cta_instrs: bool
- :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
- :type cluster_shape_mn: Tuple[int, int]
+ :param cluster_shape_mnk: Tuple (ClusterM, ClusterN) shape of the cluster.
+ :type cluster_shape_mnk: Tuple[int, int]
  """

  self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
- self.use_2cta_instrs = use_2cta_instrs
- self.cluster_shape_mn = cluster_shape_mn
+ self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (256,)
+ self.cluster_shape_mnk = cluster_shape_mnk
+ assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1"
  # K dimension is deferred in _setup_attributes
  self.mma_tiler = (*mma_tiler_mn, 1)
  self.sf_vec_size = sf_vec_size
  self.blockscaled = sf_vec_size is not None
+ self.is_persistent = True
+ self.pingpong = False  # for compatibility with GemmSm90
+ self.gather_A = gather_A
+ if gather_A:
+ assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "

- self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
+ self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE

+ self.num_ab_load_warps = 1 if not self.gather_A else 5
  self.occupancy = 1
  # Set specialized warp ids
- self.epilog_warp_id = (
- 0,
- 1,
- 2,
- 3,
- )
+ self.epilog_warp_id = (0, 1, 2, 3)
  self.mma_warp_id = 4
- self.tma_warp_id = 5
- self.tma_epi_warp_id = 6
- self.threads_per_cta = 32 * len(
- (self.mma_warp_id, self.tma_warp_id, self.tma_epi_warp_id, *self.epilog_warp_id)
+ self.ab_load_warp_id = 5
+ self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
+ self.scheduler_warp_id = self.epi_load_warp_id + 1
+ self.num_epi_warps = len(self.epilog_warp_id)
+ self.threads_per_cta = cute.arch.WARP_SIZE * (
+ self.num_ab_load_warps
+ + len(
+ (
+ self.mma_warp_id,
+ self.epi_load_warp_id,
+ self.scheduler_warp_id,
+ *self.epilog_warp_id,
+ )
+ )
  )
- # Set barrier id for cta sync, epilogue sync and tmem ptr sync
- self.cta_sync_bar_id = 0
- self.epilog_sync_bar_id = 1
- self.tmem_ptr_sync_bar_id = 2
- self.epilog_load_bar_id = 3
- self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_100")
-
- def _setup_attributes(self):
+
+ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments):
  """Set up configurations that are dependent on GEMM inputs

  This method configures various attributes based on the input tensor properties
@@ -261,7 +252,7 @@ class PersistentDenseGemmKernel:

  # Configure tiled mma
  if const_expr(not self.blockscaled):
- tiled_mma = sm100_utils.make_trivial_tiled_mma(
+ self.tiled_mma = sm100_utils.make_trivial_tiled_mma(
  self.a_dtype,
  self.a_major_mode,
  self.b_major_mode,
@@ -269,9 +260,9 @@
  self.cta_group,
  self.mma_tiler[:2],
  )
- tiled_mma_sfb = None
+ self.tiled_mma_sfb = None
  else:
- tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
  self.a_dtype,
  self.a_major_mode,
  self.b_major_mode,
@@ -280,13 +271,13 @@
  self.cta_group,
  self.mma_inst_shape_mnk[:2],
  )
- tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
  self.a_dtype,
  self.a_major_mode,
  self.b_major_mode,
  self.sf_dtype,
  self.sf_vec_size,
- cute.nvgpu.tcgen05.CtaGroup.ONE,
+ tcgen05.CtaGroup.ONE,
  self.mma_inst_shape_mnk_sfb[:2],
  )

@@ -306,26 +297,28 @@
  else:
  self.mma_tiler_sfb = None
  self.cta_tile_shape_mnk = (
- self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler[0] // cute.size(self.tiled_mma.thr_id.shape),
  self.mma_tiler[1],
  self.mma_tiler[2],
  )

  # Compute cluster layout
  self.cluster_layout_vmnk = cute.tiled_divide(
- cute.make_layout((*self.cluster_shape_mn, 1)),
- (tiled_mma.thr_id.shape,),
+ cute.make_layout(self.cluster_shape_mnk),
+ (self.tiled_mma.thr_id.shape,),
  )
  if const_expr(self.blockscaled):
  self.cluster_layout_sfb_vmnk = cute.tiled_divide(
- cute.make_layout((*self.cluster_shape_mn, 1)),
- (tiled_mma_sfb.thr_id.shape,),
+ cute.make_layout(self.cluster_shape_mnk),
+ (self.tiled_mma_sfb.thr_id.shape,),
  )
  else:
  self.cluster_layout_sfb_vmnk = None

  # Compute number of multicast CTAs for A/B
  self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+ if self.gather_A:
+ assert self.num_mcast_ctas_a == 1
  self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
  self.is_a_mcast = self.num_mcast_ctas_a > 1
  self.is_b_mcast = self.num_mcast_ctas_b > 1
@@ -337,60 +330,82 @@
  self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
  self.cta_tile_shape_mnk,
  self.use_2cta_instrs,
- self.d_layout,
- self.d_dtype,
+ self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR,
+ self.d_dtype if self.d_dtype is not None else cutlass.BFloat16,
+ layout_c=self.c_layout,
+ elem_ty_c=self.c_dtype,
  )

  # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
+ prefetch_A_idx = (
+ None
+ if not self.gather_A
+ else ("varlen_m" if varlen_args.mCuSeqlensM is not None else "varlen_k")
+ )
  (
  self.num_acc_stage,
- self.num_ab_stage,
- self.num_d_stage,
- self.num_c_stage,
+ self.ab_stage,
+ self.epi_stage,
+ self.epi_c_stage,
  ) = self._compute_stages(
- tiled_mma,
+ self.tiled_mma,
  self.mma_tiler,
+ self.cta_tile_shape_mnk,
+ self.epi_tile,
  self.a_dtype,
  self.b_dtype,
- self.epi_tile,
+ self.sf_dtype,
+ self.sf_vec_size,
  self.d_dtype,
  self.c_dtype,
  self.d_layout,
  self.c_layout,
- self.sf_dtype,
- self.sf_vec_size,
- self.smem_capacity,
+ epilogue_args,
+ prefetch_A_idx,
+ cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"),  # smem_capacity
  self.occupancy,
  )
+ self.sched_stage = 1
+ self.a_prefetch_stage = (
+ 0
+ if not self.gather_A
+ else (2 if varlen_args.mCuSeqlensM is not None else self.ab_stage)
+ )

  # Compute A/B/SFA/SFB/C shared memory layout
  self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
- tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage
+ self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
  )
+ self.a_smem_load_layout_staged = self.a_smem_layout_staged
+ if const_expr(self.gather_A):
+ self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_cpasync_a(
+ self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
+ )
  self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
- tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage
- )
- self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
- self.d_dtype, self.d_layout, self.epi_tile, self.num_d_stage
+ self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage
  )
+ self.epi_smem_layout_staged = None
+ if const_expr(self.d_dtype is not None):
+ self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.d_dtype, self.d_layout, self.epi_tile, self.epi_stage
+ )
+ self.epi_c_smem_layout_staged = None
  if const_expr(self.c_dtype is not None):
  self.epi_c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
- self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage
+ self.c_dtype, self.c_layout, self.epi_tile, self.epi_c_stage
  )
- else:
- self.epi_c_smem_layout_staged = None
  if const_expr(self.blockscaled):
  self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
- tiled_mma,
+ self.tiled_mma,
  self.mma_tiler,
  self.sf_vec_size,
- self.num_ab_stage,
+ self.ab_stage,
  )
  self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
- tiled_mma,
+ self.tiled_mma,
  self.mma_tiler,
  self.sf_vec_size,
- self.num_ab_stage,
+ self.ab_stage,
  )
  else:
  self.sfa_smem_layout_staged, self.sfb_smem_layout_staged = None, None
@@ -398,7 +413,7 @@
  # Compute the number of tensor memory allocation columns
  if const_expr(not self.blockscaled):
  self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(
- tiled_mma, self.mma_tiler, self.num_acc_stage
+ self.tiled_mma, self.mma_tiler, self.num_acc_stage
  )
  else:
  SM100_TMEM_CAPACITY_COLUMNS = 512
@@ -409,14 +424,14 @@
  self,
  mA: cute.Tensor,
  mB: cute.Tensor,
- mD: cute.Tensor,
+ mD: Optional[cute.Tensor],
  mC: Optional[cute.Tensor],
- tile_count_semaphore: Optional[cute.Pointer],
- max_active_clusters: cutlass.Constexpr,
+ epilogue_args: ArgumentsBase,
+ scheduler_args: TileSchedulerOptions,
+ varlen_args: Optional[VarlenArguments],
  stream: cuda.CUstream,
  mSFA: Optional[cute.Tensor] = None,
  mSFB: Optional[cute.Tensor] = None,
- epilogue_op: cutlass.Constexpr = lambda x: x,
  ):
  """Execute the GEMM operation in steps:
  - Setup static attributes before smem/grid/tma computation
@@ -435,32 +450,48 @@
  :type max_active_clusters: cutlass.Constexpr
  :param stream: CUDA stream for asynchronous execution
  :type stream: cuda.CUstream
- :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
- :type epilogue_op: cutlass.Constexpr
  :raises TypeError: If input data types are incompatible with the MMA instruction.
  :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled.
  """
  if const_expr(self.blockscaled):
  assert mSFA is not None and mSFB is not None
  # Setup static attributes before smem/grid/tma computation
- self.a_dtype: Type[cutlass.Numeric] = mA.element_type
- self.b_dtype: Type[cutlass.Numeric] = mB.element_type
- self.d_dtype: Type[cutlass.Numeric] = mD.element_type
+ self.a_dtype = mA.element_type
+ self.b_dtype = mB.element_type
+ self.d_dtype = mD.element_type if mD is not None else None
  self.c_dtype = mC.element_type if mC is not None else None
  self.sf_dtype: Optional[Type[cutlass.Numeric]] = (
  mSFA.element_type if mSFA is not None else None
  )
- self.a_major_mode = cutlass.utils.LayoutEnum.from_tensor(mA).mma_major_mode()
- self.b_major_mode = cutlass.utils.LayoutEnum.from_tensor(mB).mma_major_mode()
- self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
- self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None
+ self.a_layout = LayoutEnum.from_tensor(mA)
+ self.b_layout = LayoutEnum.from_tensor(mB)
+ self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
+ self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None
+ self.a_major_mode = LayoutEnum.from_tensor(mA).mma_major_mode()
+ self.b_major_mode = LayoutEnum.from_tensor(mB).mma_major_mode()

  # Check if input data types are compatible with MMA instruction
  if const_expr(self.a_dtype != self.b_dtype):
  raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")

+ if const_expr(varlen_args is None):
+ varlen_args = VarlenArguments()
+ assert (varlen_args.mAIdx is not None) == self.gather_A
+
+ # Assume all strides are divisible by 128 bits except the last stride
+ new_stride = lambda t: tuple(
+ cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
+ for s in t.stride
+ )
+ mA, mD = [
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+ if t is not None
+ else None
+ for t in (mA, mD)
+ ]
+
  # Setup attributes that dependent on gemm inputs
- self._setup_attributes()
+ self._setup_attributes(epilogue_args, varlen_args)

  if const_expr(self.blockscaled):
  # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout
@@ -471,67 +502,44 @@
  sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(mB.shape, self.sf_vec_size)
  mSFB = cute.make_tensor(mSFB.iterator, sfb_layout)

- if const_expr(not self.blockscaled):
- tiled_mma = sm100_utils.make_trivial_tiled_mma(
- self.a_dtype,
- self.a_major_mode,
- self.b_major_mode,
- self.acc_dtype,
- self.cta_group,
- self.mma_tiler[:2],
- )
- tiled_mma_sfb = None
- else:
- tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
- self.a_dtype,
- self.a_major_mode,
- self.b_major_mode,
- self.sf_dtype,
- self.sf_vec_size,
- self.cta_group,
- self.mma_inst_shape_mnk[:2],
- )
- tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
- self.a_dtype,
- self.a_major_mode,
- self.b_major_mode,
- self.sf_dtype,
- self.sf_vec_size,
- cute.nvgpu.tcgen05.CtaGroup.ONE,
- self.mma_inst_shape_mnk_sfb[:2],
- )
- atom_thr_size = cute.size(tiled_mma.thr_id.shape)
+ atom_thr_size = cute.size(self.tiled_mma.thr_id.shape)

- # Setup TMA load for A
- a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ # Setup TMA load for A & B
  a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
- tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
- a_op,
- mA,
- a_smem_layout,
- self.mma_tiler,
- tiled_mma,
- self.cluster_layout_vmnk.shape,
- internal_type=(cutlass.TFloat32 if mA.element_type is cutlass.Float32 else None),
- )
-
- # Setup TMA load for B
- b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
  b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
+ tma_atom_a, tma_tensor_a = None, None
+ if const_expr(not self.gather_A):
+ a_op = sm100_utils.cluster_shape_to_tma_atom_A(
+ self.cluster_shape_mnk, self.tiled_mma.thr_id
+ )
+ tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
+ a_op,
+ mA,
+ a_smem_layout,
+ self.mma_tiler,
+ self.tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None),
+ )
+ b_op = sm100_utils.cluster_shape_to_tma_atom_B(
+ self.cluster_shape_mnk, self.tiled_mma.thr_id
+ )
  tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
  b_op,
  mB,
  b_smem_layout,
  self.mma_tiler,
- tiled_mma,
+ self.tiled_mma,
  self.cluster_layout_vmnk.shape,
- internal_type=(cutlass.TFloat32 if mB.element_type is cutlass.Float32 else None),
+ internal_type=(cutlass.TFloat32 if mB.element_type is Float32 else None),
  )

+ tma_atom_sfa, tma_tensor_sfa = None, None
+ tma_atom_sfb, tma_tensor_sfb = None, None
  if const_expr(self.blockscaled):
  # Setup TMA load for SFA
  sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
- self.cluster_shape_mn, tiled_mma.thr_id
+ self.cluster_shape_mnk, self.tiled_mma.thr_id
  )
  sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
  tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
@@ -539,13 +547,13 @@
  mSFA,
  sfa_smem_layout,
  self.mma_tiler,
- tiled_mma,
+ self.tiled_mma,
  self.cluster_layout_vmnk.shape,
  internal_type=cutlass.Int16,
  )
  # Setup TMA load for SFB
  sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
- self.cluster_shape_mn, tiled_mma.thr_id
+ self.cluster_shape_mnk, self.tiled_mma.thr_id
  )
  sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
  tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
@@ -553,58 +561,50 @@
  mSFB,
  sfb_smem_layout,
  self.mma_tiler_sfb,
- tiled_mma_sfb,
+ self.tiled_mma_sfb,
  self.cluster_layout_sfb_vmnk.shape,
  internal_type=cutlass.Int16,
  )
- else:
- tma_atom_sfa, tma_tensor_sfa = None, None
- tma_atom_sfb, tma_tensor_sfb = None, None

- a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
- b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
- self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
+ self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+ if const_expr(not self.gather_A):
+ self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
  if const_expr(self.blockscaled):
  sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
  sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
- self.num_tma_load_bytes += (sfa_copy_size + sfb_copy_size) * atom_thr_size
+ self.num_tma_load_bytes += sfa_copy_size + sfb_copy_size
+ self.num_tma_load_bytes *= atom_thr_size

  # Setup TMA store for D
- epi_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
- tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
- cpasync.CopyBulkTensorTileS2GOp(),
- mD,
- epi_smem_layout,
- self.epi_tile,
- )
- if const_expr(mC is not None):
- epi_c_smem_layout = cute.slice_(self.epi_c_smem_layout_staged, (None, None, 0))
- tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
- cpasync.CopyBulkTensorTileG2SOp(),
- mC,
- epi_c_smem_layout,
+ tma_atom_d, tma_tensor_d = None, None
+ if const_expr(mD is not None):
+ tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
+ mD,
+ self.epi_smem_layout_staged,
  self.epi_tile,
+ op_type="store"
+ if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
+ else "add",
+ )
+ tma_atom_c, tma_tensor_c = None, None
+ if const_expr(mC is not None):
+ tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
+ mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
  )
- else:
- tma_atom_c, tma_tensor_c = None, None

- problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.cta_tile_shape_mnk[:2]) + (
- mD.shape[2],
- )
- TileSchedulerCls = TileScheduler
- tile_sched_args = TileSchedulerArguments(
- problem_shape_ntile_mnl=problem_shape_ntile_mnl,
- raster_order=RasterOrderOption.Heuristic,
- group_size=8,
- cluster_shape_mnk=(*self.cluster_shape_mn, 1),
- tile_count_semaphore=tile_count_semaphore,
- is_persistent=True,
- )
+ epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
+ varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
+
+ TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
+ tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
  tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
- grid = TileSchedulerCls.get_grid_shape(tile_sched_params, max_active_clusters)
+ grid = TileSchedulerCls.get_grid_shape(
+ tile_sched_params, scheduler_args.max_active_clusters
+ )

  self.buffer_align_bytes = 1024

+ epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
  epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
  sf_dtype = self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU
  sfa_smem_size = (
@@ -613,22 +613,33 @@
  sfb_smem_size = (
  cute.cosize(self.sfb_smem_layout_staged) if const_expr(self.blockscaled) else 0
  )
+ a_idx_smem_size = 0
+ if const_expr(self.gather_A):
+ a_idx_smem_size = self.a_prefetch_stage * (
+ self.cta_tile_shape_mnk[0]
+ if varlen_args.mCuSeqlensM is not None
+ else self.cta_tile_shape_mnk[2]
+ )

  # Define shared storage for kernel
  @cute.struct
  class SharedStorage:
- ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
- ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
- epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage * 2]
- acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
- acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
+ ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
+ epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
+ acc_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
+ sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
+ a_prefetch_pipeline_array_ptr: cute.struct.MemRange[
+ cutlass.Int64, self.a_prefetch_stage * 2
+ ]
+ tile_count: cute.struct.MemRange[Int32, self.sched_stage]
  tmem_dealloc_mbar_ptr: cutlass.Int64
  tmem_holding_buf: Int32
- sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, 2]
- tile_count: cute.struct.MemRange[cutlass.Int32, 1]
+ sAIdx: cute.struct.Align[cute.struct.MemRange[Int32, a_idx_smem_size], 16]
  # (EPI_TILE_M, EPI_TILE_N, STAGE)
  sD: cute.struct.Align[
- cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)],
+ cute.struct.MemRange[
+ self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
+ ],
  self.buffer_align_bytes,
  ]
  sC: cute.struct.Align[
@@ -637,6 +648,7 @@
  ],
  self.buffer_align_bytes,
  ]
+ epi: self.epi_get_smem_struct(epilogue_params)
  # (MMA, MMA_M, MMA_K, STAGE)
  sA: cute.struct.Align[
  cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
@@ -662,10 +674,10 @@

  # Launch the kernel synchronously
  self.kernel(
- tiled_mma,
- tiled_mma_sfb,
+ self.tiled_mma,
+ self.tiled_mma_sfb,
  tma_atom_a,
- tma_tensor_a,
+ tma_tensor_a if const_expr(not self.gather_A) else mA,
  tma_atom_b,
  tma_tensor_b,
  tma_atom_sfa,
@@ -676,24 +688,26 @@
  tma_tensor_d,
  tma_atom_c,
  tma_tensor_c,
+ epilogue_params,
+ varlen_params,
  self.cluster_layout_vmnk,
  self.cluster_layout_sfb_vmnk,
  self.a_smem_layout_staged,
+ self.a_smem_load_layout_staged,
  self.b_smem_layout_staged,
  self.sfa_smem_layout_staged,
  self.sfb_smem_layout_staged,
- self.d_smem_layout_staged,
+ self.epi_smem_layout_staged,
  self.epi_c_smem_layout_staged,
  self.epi_tile,
  tile_sched_params,
  TileSchedulerCls,
- epilogue_op,
  ).launch(
  grid=grid,
  block=[self.threads_per_cta, 1, 1],
- cluster=(*self.cluster_shape_mn, 1),
- smem=self.shared_storage.size_in_bytes(),
+ cluster=self.cluster_shape_mnk,
  stream=stream,
+ min_blocks_per_mp=1,
  )
  return

@@ -703,7 +717,7 @@
  self,
  tiled_mma: cute.TiledMma,
  tiled_mma_sfb: Optional[cute.TiledMma],
- tma_atom_a: cute.CopyAtom,
+ tma_atom_a: Optional[cute.CopyAtom],
  mA_mkl: cute.Tensor,
  tma_atom_b: cute.CopyAtom,
  mB_nkl: cute.Tensor,
@@ -712,37 +726,52 @@
  tma_atom_sfb: Optional[cute.CopyAtom],
  mSFB_nkl: Optional[cute.Tensor],
  tma_atom_d: Optional[cute.CopyAtom],
- mD_mnl: cute.Tensor,
+ mD_mnl: Optional[cute.Tensor],
  tma_atom_c: Optional[cute.CopyAtom],
  mC_mnl: Optional[cute.Tensor],
+ epilogue_params: ParamsBase,
+ varlen_params: VarlenManager.Params,
  cluster_layout_vmnk: cute.Layout,
  cluster_layout_sfb_vmnk: Optional[cute.Layout],
- a_smem_layout_staged: cute.ComposedLayout,
- b_smem_layout_staged: cute.ComposedLayout,
- sfa_smem_layout_staged: Optional[cute.Layout],
- sfb_smem_layout_staged: Optional[cute.Layout],
- d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
- epi_c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
+ a_smem_layout: cute.ComposedLayout,
+ a_smem_load_layout: cute.ComposedLayout,
+ b_smem_layout: cute.ComposedLayout,
+ sfa_smem_layout: Optional[cute.Layout],
+ sfb_smem_layout: Optional[cute.Layout],
+ epi_smem_layout: Union[cute.Layout, cute.ComposedLayout, None],
+ epi_c_smem_layout: Union[cute.Layout, cute.ComposedLayout, None],
  epi_tile: cute.Tile,
  tile_sched_params: ParamsBase,
  TileSchedulerCls: cutlass.Constexpr[Callable],
- epilogue_op: cutlass.Constexpr[Callable],
  ):
  """
  GPU device kernel performing the Persistent batched GEMM computation.
  """
+
+ varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
+ varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
+ assert not (varlen_m and varlen_k)
+ if const_expr(self.gather_A):
+ assert varlen_m or varlen_k
+ has_D = const_expr(mD_mnl is not None)
+ has_C = const_expr(mC_mnl is not None)
+
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

- #
- # Prefetch tma desc
- #
- if warp_idx == self.tma_warp_id:
- cpasync.prefetch_descriptor(tma_atom_a)
- cpasync.prefetch_descriptor(tma_atom_b)
- if const_expr(self.blockscaled):
- cpasync.prefetch_descriptor(tma_atom_sfa)
- cpasync.prefetch_descriptor(tma_atom_sfb)
- cpasync.prefetch_descriptor(tma_atom_d)
+ # /////////////////////////////////////////////////////////////////////////////
+ # Prefetch Tma desc
+ # /////////////////////////////////////////////////////////////////////////////
+ if warp_idx == self.ab_load_warp_id:
+ for tma_atom in (
+ tma_atom_a,
+ tma_atom_b,
+ tma_atom_sfa,
+ tma_atom_sfb,
+ tma_atom_d,
+ tma_atom_c,
+ ):
+ if const_expr(tma_atom is not None):
+ cpasync.prefetch_descriptor(tma_atom)

  use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2

@@ -754,13 +783,6 @@
  mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
  is_leader_cta = mma_tile_coord_v == 0
  cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
- block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
- if const_expr(self.blockscaled):
- block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
- cta_rank_in_cluster
- )
- else:
- block_in_cluster_coord_sfb_vmnk = None
  # Coord inside cta
  tidx, _, _ = cute.arch.thread_idx()

@@ -775,104 +797,68 @@

  # Tensor memory dealloc barrier init
  if use_2cta_instrs:
- if warp_idx == self.tma_warp_id:
+ if warp_idx == self.ab_load_warp_id:
  num_tmem_dealloc_threads = 32
  cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)

- # Initialize mainloop ab_pipeline (barrier) and states
- ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
- num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
- ab_pipeline_consumer_group = pipeline.CooperativeGroup(
- pipeline.Agent.Thread, num_tma_producer
+ # Initialize pipelines and states
+ ab_pipeline = self.make_ab_pipeline(
+ tiled_mma=tiled_mma,
+ cluster_layout_vmnk=cluster_layout_vmnk,
+ ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
+ is_leader_cta=is_leader_cta,
  )
- ab_pipeline = pipeline.PipelineTmaUmma.create(
- barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
- num_stages=self.num_ab_stage,
- producer_group=ab_pipeline_producer_group,
- consumer_group=ab_pipeline_consumer_group,
- tx_count=self.num_tma_load_bytes,
- cta_layout_vmnk=cluster_layout_vmnk,
+ epi_pipeline = None
+ if const_expr(has_C):
+ epi_pipeline = self.make_epi_pipeline(
+ c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
+ epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
+ )
+ acc_pipeline = self.make_acc_pipeline(
+ cluster_layout_vmnk=cluster_layout_vmnk,
+ acc_pipeline_mbar_ptr=storage.acc_pipeline_array_ptr.data_ptr(),
  )
-
- if const_expr(mC_mnl is not None):
- # Threads/warps participating in this pipeline
- epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
- # Each warp will contribute 1 to the arrive count
- consumer_arrive_cnt = len(self.epilog_warp_id)
- epi_pipeline_consumer_group = pipeline.CooperativeGroup(
- pipeline.Agent.Thread, consumer_arrive_cnt
+ sched_pipeline = None
+ tile_count = None
+ if const_expr(tile_sched_params.tile_count_semaphore is not None):
+ # Dynamic persistent scheduler
+ sched_pipeline = self.make_sched_pipeline(
+ self.cluster_shape_mnk,
+ sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
+ has_C=has_C,
  )
- c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
- tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
- epi_pipeline = pipeline.PipelineTmaAsync.create(
- barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
- num_stages=self.num_c_stage,
- producer_group=epi_pipeline_producer_group,
- consumer_group=epi_pipeline_consumer_group,
- tx_count=tma_copy_c_bytes,
+ tile_count = storage.tile_count.get_tensor((self.sched_stage,))
+ a_prefetch_pipeline = None
+ if const_expr(self.gather_A):
+ a_prefetch_pipeline = self.make_a_prefetch_pipeline(
+ storage.a_prefetch_pipeline_array_ptr.data_ptr(),
  )
- else:
- epi_pipeline = None
-
- # Initialize acc_pipeline (barrier) and states
- acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
- num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
- acc_pipeline_consumer_group = pipeline.CooperativeGroup(
- pipeline.Agent.Thread, num_acc_consumer_threads
- )
- acc_pipeline = pipeline.PipelineUmmaAsync.create(
- barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
- num_stages=self.num_acc_stage,
- producer_group=acc_pipeline_producer_group,
- consumer_group=acc_pipeline_consumer_group,
- cta_layout_vmnk=cluster_layout_vmnk,
- )
-
- # if const_expr(tile_sched_params.tile_count_semaphore is not None):
- # # Dynamic persistent scheduler
- # # Threads/warps participating in this pipeline
- # sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
- # cluster_size = cute.size(cluster_layout_vmnk)
- # # Each warp that are not the scheduler warp will contribute 1 to the arrive count
- # consumer_arrive_cnt = (
- # (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
- # ) * cluster_size - 1
- # sched_pipeline_consumer_group = pipeline.CooperativeGroup(
- # pipeline.Agent.Thread, consumer_arrive_cnt
- # )
- # sched_pipeline = pipeline.PipelineAsync.create(
- # barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
- # num_stages=self.sched_stage,
- # producer_group=sched_pipeline_producer_group,
- # consumer_group=sched_pipeline_consumer_group,
- # # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
- # consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
- # )
- # tile_count = storage.tile_count.get_tensor((self.sched_stage,))
- # else:
- # sched_pipeline = None
- # tile_count = None

  # Setup smem tensor A/B/D
  # (MMA, MMA_M, MMA_K, STAGE)
- sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
+ sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
+ sA = storage.sA.get_tensor(a_smem_load_layout.outer, swizzle=a_smem_load_layout.inner)
  # (MMA, MMA_N, MMA_K, STAGE)
- sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
+ sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
+ sAIdx = None
+ if const_expr(self.gather_A):
+ a_idx_smem_dim = self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2]
+ a_idx_smem_layout = cute.make_layout((a_idx_smem_dim, self.a_prefetch_stage))
+ sAIdx = storage.sAIdx.get_tensor(a_idx_smem_layout)
+ sSFA, sSFB = None, None
  if const_expr(self.blockscaled):
  # (MMA, MMA_M, MMA_K, STAGE)
- sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
+ sSFA = storage.sSFA.get_tensor(sfa_smem_layout)
  # (MMA, MMA_N, MMA_K, STAGE)
- sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
- else:
- sSFA, sSFB = None, None
- # (EPI_TILE_M, EPI_TILE_N, STAGE)
- sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
- if const_expr(mC_mnl is not None):
- sC = storage.sC.get_tensor(
- epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
- )
- else:
- sC = None
+ sSFB = storage.sSFB.get_tensor(sfb_smem_layout)
+ sD = None
+ if const_expr(has_D):
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
+ sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
+ sC = None
+ if const_expr(has_C):
+ sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
+ epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)

  thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
  thr_mma_sfb = (
@@ -884,26 +870,51 @@
  # (MMA, MMA_M, MMA_N, STAGE)
  tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))

- tmem_ptr_read_threads = cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id))
- tmem_alloc_barrier = pipeline.NamedBarrier(
- barrier_id=self.tmem_ptr_sync_bar_id, num_threads=tmem_ptr_read_threads
+ varlen_manager = VarlenManager.create(
+ varlen_params,
+ has_D,
+ self.num_epi_tensormaps,
+ # Only used if not varlen_m
+ len_m_static=Int32(
+ mA_mkl.shape[0]
+ if varlen_k or varlen_params.mAIdx is None
+ else varlen_params.mAIdx.shape[0]
+ ),
+ len_k_static=Int32(mA_mkl.shape[1]),
  )

- TileSchedulerCls = partial(TileSchedulerCls.create, tile_sched_params)
- k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.mma_tiler[2])
+ TileSchedulerCls = partial(
+ TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
+ )

- if const_expr(mC_mnl is not None):
+ tmem_alloc_barrier = pipeline.NamedBarrier(
+ barrier_id=int(NamedBarrierGemm.TmemPtr),
+ num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)),
+ )
+ epi_load_barrier = None
+ if const_expr(has_C):
  epi_load_barrier = pipeline.NamedBarrier(
- barrier_id=int(self.epilog_load_bar_id), num_threads=2 * cute.arch.WARP_SIZE
+ barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE
  )
- else:
- epi_load_barrier = None

  #
- # Specialized TMA load warp
+ # Specialized AB load warps
  #
- if warp_idx == self.tma_warp_id:
+ if warp_idx == self.ab_load_warp_id:
+ is_tma_warp = True
+ # initialize tensormap for A & B
+ varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
+ tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
+ tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
  # Compute multicast mask for A/B buffer full
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
+ block_in_cluster_coord_sfb_vmnk = None
+ if const_expr(self.blockscaled):
+ block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
+ cta_rank_in_cluster
+ )
+ a_mcast_mask, b_mcast_mask = None, None
+ sfa_mcast_mask, sfb_mcast_mask = None, None
  if const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
  a_mcast_mask = cpasync.create_tma_multicast_mask(
  cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
@@ -918,141 +929,139 @@
  )
  sfb_mcast_mask = cpasync.create_tma_multicast_mask(
  cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1
  )
- else:
- sfa_mcast_mask, sfb_mcast_mask = None, None
- else:
- a_mcast_mask, b_mcast_mask = None, None
- sfa_mcast_mask, sfb_mcast_mask = None, None

  # Persistent tile scheduling loop
  tile_scheduler = TileSchedulerCls()
  work_tile = tile_scheduler.initial_work_tile_info()
  ab_producer_state = pipeline.make_pipeline_state(
- pipeline.PipelineUserType.Producer, self.num_ab_stage
+ pipeline.PipelineUserType.Producer, self.ab_stage
  )
- do_epi_load_barrier_arrive = cutlass.Boolean(True)
+ if const_expr(varlen_k):
+ # wait tensormap initialization complete before update
+ varlen_manager.fence_tensormap_init()
+ do_epi_load_barrier_arrive = Boolean(True)
  while work_tile.is_valid_tile:
- # Get tile coord from tile scheduler
  tile_coord_mnkl = work_tile.tile_idx
+ batch_idx = tile_coord_mnkl[3]
+ varlen_manager.update_tensormap_AB(
+ batch_idx,
+ self.a_layout,
+ self.b_layout,
+ is_tma_warp,
+ )
+ # ///////////////////////////////////////////////////////////////////////////
+ # Local_tile partition global tensors
+ # ///////////////////////////////////////////////////////////////////////////
  mma_tile_coord_mnl = (
  tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
  tile_coord_mnkl[1],
  tile_coord_mnkl[3],
  )
- # Local_tile partition global tensors
- # (bM, bK, RestK)
- gA_mkl = cute.local_tile(
- mA_mkl,
- cute.slice_(self.mma_tiler, (None, 0, None)),
- (mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]),
- )
+ gA_mk = None
+ if const_expr(not self.gather_A):
+ mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
+ # (bM, bK, RestK)
+ gA_mk = cute.local_tile(
+ mA_mk,
+ cute.select(self.mma_tiler, [0, 2]),
+ (mma_tile_coord_mnl[0], None),
+ )
  # (bN, bK, RestK)
- gB_nkl = cute.local_tile(
- mB_nkl,
- cute.slice_(self.mma_tiler, (0, None, None)),
- (mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]),
+ gB_nk = cute.local_tile(
+ varlen_manager.offset_batch_B(mB_nkl, batch_idx),
+ cute.select(self.mma_tiler, [1, 2]),
+ (mma_tile_coord_mnl[1], None),
  )
  if const_expr(self.blockscaled):
  # (bM, bK)
  gSFA_mkl = cute.local_tile(
- mSFA_mkl,
- cute.slice_(self.mma_tiler, (None, 0, None)),
- (mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]),
+ varlen_manager.offset_batch_A(mSFA_mkl, batch_idx),
+ cute.select(self.mma_tiler, [0, 2]),
+ (mma_tile_coord_mnl[0], None),
  )
  # (bN, bK)
  gSFB_nkl = cute.local_tile(
- mSFB_nkl,
- cute.slice_(self.mma_tiler, (0, None, None)),
- (mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]),
+ varlen_manager.offset_batch_B(mSFB_nkl, batch_idx),
+ cute.select(self.mma_tiler, [1, 2]),
+ (mma_tile_coord_mnl[1], None),
  )
+
  # Partition global tensor for TiledMMA_A/B/D
- # (MMA, MMA_M, MMA_K, RestK)
- tCgA = thr_mma.partition_A(gA_mkl)
+ # Then partition global/shared tensor for TMA load A/B
+ varlen_manager.fence_tensormap_update_AB(is_tma_warp)
+ len_k = varlen_manager.len_k(batch_idx)
+ # TMA load A partition_S/D
+ a_cta_layout = cute.make_layout(
+ cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
+ )
+ copy_A = None
+ if const_expr(not self.gather_A):
+ # (MMA, MMA_M, MMA_K, RestK)
+ tCgA = thr_mma.partition_A(gA_mk)
+ copy_A, _, _ = copy_utils.tma_get_copy_fn(
+ tma_atom_a,
+ cta_coord=block_in_cluster_coord_vmnk[2],
+ cta_layout=a_cta_layout,
+ src_tensor=tCgA,
+ dst_tensor=sA,
+ mcast_mask=a_mcast_mask,
+ tma_desc_ptr=tma_desc_a_ptr,
+ )
  # (MMA, MMA_N, MMA_K, RestK)
- tCgB = thr_mma.partition_B(gB_nkl)
+ tCgB = thr_mma.partition_B(gB_nk)
  if const_expr(self.blockscaled):
  # (MMA, MMA_M, MMA_K)
  tCgSFA = thr_mma.partition_A(gSFA_mkl)
  # (MMA, MMA_N, MMA_K)
  tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
- # Partition global/shared tensor for TMA load A/B
- # TMA load A partition_S/D
- a_cta_layout = cute.make_layout(
- cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
- )
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tAsA, tAgA = cpasync.tma_partition(
- tma_atom_a,
- block_in_cluster_coord_vmnk[2],
- a_cta_layout,
- cute.group_modes(sA, 0, 3),
- cute.group_modes(tCgA, 0, 3),
- )
  # TMA load B partition_S/D
- b_cta_layout = cute.make_layout(
- cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
- )
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tBsB, tBgB = cpasync.tma_partition(
+ copy_B, _, _ = copy_utils.tma_get_copy_fn(
  tma_atom_b,
- block_in_cluster_coord_vmnk[1],
- b_cta_layout,
- cute.group_modes(sB, 0, 3),
- cute.group_modes(tCgB, 0, 3),
+ cta_coord=block_in_cluster_coord_vmnk[1],
+ cta_layout=cute.make_layout(
+ cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
+ ),
+ src_tensor=tCgB,
+ dst_tensor=sB,
+ mcast_mask=b_mcast_mask,
+ tma_desc_ptr=tma_desc_b_ptr,
  )
+ copy_SFA, copy_SFB = None, None
  if const_expr(self.blockscaled):
  # TMA load SFA partition_S/D
- sfa_cta_layout = a_cta_layout
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition(
+ copy_SFA, _, _ = copy_utils.tma_get_copy_fn(
  tma_atom_sfa,
- block_in_cluster_coord_vmnk[2],
- sfa_cta_layout,
- cute.group_modes(sSFA, 0, 3),
- cute.group_modes(tCgSFA, 0, 3),
+ cta_coord=block_in_cluster_coord_vmnk[2],
+ cta_layout=a_cta_layout,
+ src_tensor=tCgSFA,
+ dst_tensor=sSFA,
+ filter_zeros=True,
+ mcast_mask=sfa_mcast_mask,
+ # tma_desc_ptr=tma_desc_sfa_ptr,
  )
- tAsSFA = cute.filter_zeros(tAsSFA)
- tAgSFA = cute.filter_zeros(tAgSFA)
  # TMA load SFB partition_S/D
  sfb_cta_layout = cute.make_layout(
  cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape
  )
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition(
+ copy_SFB, _, _ = copy_utils.tma_get_copy_fn(
  tma_atom_sfb,
- block_in_cluster_coord_sfb_vmnk[1],
- sfb_cta_layout,
- cute.group_modes(sSFB, 0, 3),
- cute.group_modes(tCgSFB, 0, 3),
+ cta_coord=block_in_cluster_coord_sfb_vmnk[1],
+ cta_layout=sfb_cta_layout,
+ src_tensor=tCgSFB,
+ dst_tensor=sSFB,
+ filter_zeros=True,
+ mcast_mask=sfb_mcast_mask,
+ # tma_desc_ptr=tma_desc_sfa_ptr,
  )
- tBsSFB = cute.filter_zeros(tBsSFB)
- tBgSFB = cute.filter_zeros(tBgSFB)
- else:
- tAsSFA, tAgSFA = None, None
- tBsSFB, tBgSFB = None, None
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
  ab_producer_state = self.load_AB(
  ab_pipeline,
  ab_producer_state,
- tma_atom_a,
- tAgA,
- tAsA,
- a_mcast_mask,
- tma_atom_b,
- tBgB,
- tBsB,
- b_mcast_mask,
- tma_atom_sfa,
- tAgSFA,
- tAsSFA,
- sfa_mcast_mask,
- tma_atom_sfb,
- tBgSFB,
- tBsSFB,
- sfb_mcast_mask,
+ copy_A,
+ copy_B,
+ k_tile_cnt,
+ copy_SFA,
+ copy_SFB,
  )
  if const_expr(epi_load_barrier is not None):
  # In the first work tile, the epi load warp will wait for the signal
@@ -1060,58 +1069,209 @@ class PersistentDenseGemmKernel:
1060
1069
  # with loading A and B.
1061
1070
  if do_epi_load_barrier_arrive:
1062
1071
  epi_load_barrier.arrive()
1063
- do_epi_load_barrier_arrive = cutlass.Boolean(False)
1072
+ do_epi_load_barrier_arrive = Boolean(False)
1064
1073
  # Advance to next tile
1065
1074
  tile_scheduler.advance_to_next_work()
1066
1075
  work_tile = tile_scheduler.get_current_work()
1067
1076
  # Wait A/B buffer empty
1068
1077
  ab_pipeline.producer_tail(ab_producer_state)
1069
1078
 
1079
+ if const_expr(self.gather_A):
1080
+ if (
1081
+ warp_idx >= self.ab_load_warp_id + 1
1082
+ and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
1083
+ ):
1084
+ # Persistent tile scheduling loop
1085
+ tile_scheduler = TileSchedulerCls()
1086
+ work_tile = tile_scheduler.initial_work_tile_info()
1087
+ ab_producer_state = pipeline.make_pipeline_state(
1088
+ pipeline.PipelineUserType.Producer, self.ab_stage
1089
+ )
1090
+ a_prefetch_consumer_state = pipeline.make_pipeline_state(
1091
+ pipeline.PipelineUserType.Consumer, self.a_prefetch_stage
1092
+ )
1093
+ while work_tile.is_valid_tile:
1094
+ tile_coord_mnkl = work_tile.tile_idx
1095
+ batch_idx = tile_coord_mnkl[3]
1096
+ # ///////////////////////////////////////////////////////////////////////////
1097
+ # Local_tile partition global tensors
1098
+ # ///////////////////////////////////////////////////////////////////////////
1099
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
1100
+ if const_expr(varlen_m):
1101
+ # (M, K)
1102
+ mA_mk = mA_mkl
1103
+ else:
1104
+ assert varlen_k
1105
+ # (tile_M, K)
1106
+ mA_mk = cute.local_tile(
1107
+ mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
1108
+ )
1109
+ # Partition global tensor for TiledMMA_A/B/D
1110
+ len_m = varlen_manager.len_m(batch_idx)
1111
+ len_k = varlen_manager.len_k(batch_idx)
1112
+ # TMA load A partition_S/D
1113
+ tiled_copy_A = self._make_gmem_tiled_copy_A(
1114
+ mA_mkl.element_type, self.a_layout, (self.num_ab_load_warps - 1) * 32
1115
+ )
1116
+ tidx = cute.arch.thread_idx()[0] - (self.ab_load_warp_id + 1) * 32
1117
+ thr_copy_A = tiled_copy_A.get_slice(tidx)
1118
+ copy_A, prefetch_A = None, None
1119
+ if const_expr(varlen_m):
1120
+ a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
1121
+ copy_A = copy_utils.gather_m_get_copy_fn(
1122
+ thr_copy_A,
1123
+ mA_mk,
1124
+ sA,
1125
+ sAIdx[None, a_prefetch_consumer_state.index],
1126
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
1127
+ limit_k=len_k,
1128
+ )
1129
+ cute.arch.sync_warp()
1130
+ with cute.arch.elect_one():
1131
+ a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
1132
+ a_prefetch_consumer_state.advance()
1133
+ else:
1134
+ copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
1135
+ thr_copy_A,
1136
+ mA_mk,
1137
+ sA,
1138
+ sAIdx,
1139
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
1140
+ limit_k=len_k,
1141
+ )
1142
+ prefetch_A = partial(prefetch_A, a_prefetch_pipeline)
1143
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
1144
+ ab_producer_state, a_prefetch_consumer_state = self.load_A_gather_A(
1145
+ ab_pipeline,
1146
+ ab_producer_state,
1147
+ a_prefetch_consumer_state,
1148
+ copy_A,
1149
+ prefetch_A,
1150
+ k_tile_cnt,
1151
+ )
1152
+ # Advance to next tile
1153
+ tile_scheduler.advance_to_next_work()
1154
+ work_tile = tile_scheduler.get_current_work()
1155
+
1156
+ #
1157
+ # Specialized scheduler warp. Will also prefetch A indices if gather_A
1158
+ #
1159
+ if const_expr(tile_sched_params.tile_count_semaphore is not None or self.gather_A):
1160
+ if warp_idx == self.scheduler_warp_id:
1161
+ is_scheduler_warp = True
1162
+ if const_expr(cute.size(cluster_layout_vmnk) > 1):
1163
+ is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0
1164
+ tile_M = self.cta_tile_shape_mnk[0]
1165
+ tile_K = self.cta_tile_shape_mnk[2]
1166
+ thr_copy_AIdx, tAsAIdx, tAcAIdx = None, None, None
1167
+ if const_expr(self.gather_A):
1168
+ tiled_copy_AIdx = copy_utils.tiled_copy_1d(Int32, num_threads=32, is_async=True)
1169
+ thr_copy_AIdx = tiled_copy_AIdx.get_slice(cute.arch.lane_idx())
1170
+ tAsAIdx = thr_copy_AIdx.partition_D(sAIdx)
1171
+ tAcAIdx = thr_copy_AIdx.partition_S(
1172
+ cute.make_identity_tensor(tile_M if varlen_m else tile_K)
1173
+ )
1174
+ # Persistent tile scheduling loop
1175
+ tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
1176
+ work_tile = tile_scheduler.initial_work_tile_info()
1177
+ a_prefetch_producer_state = None
1178
+ if const_expr(self.gather_A):
1179
+ a_prefetch_producer_state = pipeline.make_pipeline_state(
1180
+ pipeline.PipelineUserType.Producer, self.a_prefetch_stage
1181
+ )
1182
+ while work_tile.is_valid_tile:
1183
+ if const_expr(self.gather_A):
1184
+ tile_coord_mnkl = work_tile.tile_idx
1185
+ batch_idx = tile_coord_mnkl[3]
1186
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
1187
+ if const_expr(varlen_m):
1188
+ # (tile_M,)
1189
+ gAIdx = cute.local_tile(mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],))
1190
+ tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
1191
+ len_m = varlen_manager.len_m(batch_idx)
1192
+ m_limit = len_m - tile_coord_mnkl[0] * tile_M
1193
+ tApAIdx_m = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
1194
+ for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
1195
+ tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit
1196
+ a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
1197
+ cute.copy(
1198
+ thr_copy_AIdx,
1199
+ tAgAIdx,
1200
+ tAsAIdx[None, None, a_prefetch_producer_state.index],
1201
+ pred=tApAIdx_m,
1202
+ )
1203
+ a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
1204
+ a_prefetch_producer_state.advance()
1205
+ else:
1206
+ # (tile_K, RestK)
1207
+ gAIdx = cute.flat_divide(mAIdx_mk, (tile_K,))
1208
+ tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
1209
+ len_k = varlen_manager.len_k(batch_idx)
1210
+ k_tile_cnt = cute.ceil_div(len_k, tile_K)
1211
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
1212
+ a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
1213
+ cute.copy(
1214
+ thr_copy_AIdx,
1215
+ tAgAIdx[None, None, k_tile],
1216
+ tAsAIdx[None, None, a_prefetch_producer_state.index],
1217
+ )
1218
+ a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
1219
+ a_prefetch_producer_state.advance()
1220
+ if 0 < k_tile_cnt:
1221
+ k_tile = k_tile_cnt - 1
1222
+ k_limit = len_k - k_tile * tile_K
1223
+ tApAIdx_k = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
1224
+ for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
1225
+ tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit
1226
+ a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
1227
+ cute.copy(
1228
+ tiled_copy_AIdx,
1229
+ tAgAIdx[None, None, k_tile],
1230
+ tAsAIdx[None, None, a_prefetch_producer_state.index],
1231
+ pred=tApAIdx_k,
1232
+ )
1233
+ a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
1234
+ a_prefetch_producer_state.advance()
1235
+ # Advance to next tile
1236
+ tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
1237
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
1238
+ work_tile = tile_scheduler.get_current_work()
1239
+ # End of persistent scheduler loop
1240
+ if is_scheduler_warp:
1241
+ tile_scheduler.producer_tail()
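The scheduler warp above prefetches the A gather indices one tile at a time, predicating the copy on the final tile so that only indices below the valid length are read. A rough stand-alone sketch of that bounds-masked tail copy (plain Python, hypothetical names):

def prefetch_index_tiles(a_idx, len_valid, tile, stage_bufs):
    n_tiles = -(-len_valid // tile)                 # ceil_div
    for t in range(n_tiles):
        base = t * tile
        limit = len_valid - base                    # only the tail tile is partial
        pred = [i < limit for i in range(tile)]     # per-element predicate
        stage = t % len(stage_bufs)                 # producer_acquire would gate this
        stage_bufs[stage] = [a_idx[base + i] if pred[i] else 0 for i in range(tile)]
        # producer_commit / state.advance() would follow here
    return n_tiles

if __name__ == "__main__":
    stages = [None] * 3                             # prefetch stage count is made up
    print(prefetch_index_tiles(list(range(300)), len_valid=300, tile=128,
                               stage_bufs=stages))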
1242
+
1070
1243
  #
1071
1244
  # Specialized TMA epi load warp
1072
1245
  #
1073
1246
  if const_expr(mC_mnl is not None):
1074
- if warp_idx == self.tma_epi_warp_id:
1247
+ if warp_idx == self.epi_load_warp_id:
1075
1248
  epi_producer_state = pipeline.make_pipeline_state(
1076
- pipeline.PipelineUserType.Producer, self.num_c_stage
1249
+ pipeline.PipelineUserType.Producer, self.epi_c_stage
1077
1250
  )
1078
- do_epi_load_barrier_wait = cutlass.Boolean(True)
1251
+ do_epi_load_barrier_wait = Boolean(True)
1079
1252
  # Persistent tile scheduling loop
1080
1253
  tile_scheduler = TileSchedulerCls()
1081
1254
  work_tile = tile_scheduler.initial_work_tile_info()
1082
1255
  while work_tile.is_valid_tile:
1083
1256
  # Get tile coord from tile scheduler
1084
1257
  tile_coord_mnkl = work_tile.tile_idx
1085
- mma_tile_coord_mnl = (
1086
- tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
1087
- tile_coord_mnkl[1],
1088
- tile_coord_mnkl[3],
1089
- )
1090
- # Local_tile partition global tensors
1091
- # (bM, bN)
1092
- gC_mnl = cute.local_tile(
1093
- mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
1258
+ batch_idx = tile_coord_mnkl[3]
1259
+ copy_C_fn, _, bGS_gC = self.epilog_gmem_copy_and_partition(
1260
+ tma_atom_c,
1261
+ varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
1262
+ self.cta_tile_shape_mnk[:2],
1263
+ epi_tile,
1264
+ sC,
1265
+ tile_coord_mnkl,
1094
1266
  )
1095
- # Partition global tensor for TiledMMA_A/B/D
1096
- # (MMA, MMA_M, MMA_N)
1097
- tCgC = thr_mma.partition_C(gC_mnl)
1098
- # bGS_gC has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
1099
- bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
1100
- tma_atom_c, tCgC, epi_tile, sC
1101
- )
1102
- bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
1267
+ copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
1103
1268
  if do_epi_load_barrier_wait:
1104
1269
  epi_load_barrier.arrive_and_wait()
1105
- do_epi_load_barrier_wait = cutlass.Boolean(False)
1270
+ do_epi_load_barrier_wait = Boolean(False)
1106
1271
  epi_tile_num = const_expr(cute.size(bGS_gC, mode=[1]))
1107
- for subtile_idx in cutlass.range(epi_tile_num, unroll=1):
1272
+ for epi_idx in cutlass.range(epi_tile_num, unroll=1):
1108
1273
  epi_pipeline.producer_acquire(epi_producer_state)
1109
- cute.copy(
1110
- tma_atom_c,
1111
- bGS_gC[None, subtile_idx],
1112
- bGS_sC[None, epi_producer_state.index],
1113
- tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1114
- )
1274
+ copy_C(src_idx=epi_idx, producer_state=epi_producer_state)
1115
1275
  # Epi pipeline's producer commit is a NOP
1116
1276
  epi_pipeline.producer_commit(epi_producer_state)
1117
1277
  epi_producer_state.advance()
@@ -1132,7 +1292,7 @@ class PersistentDenseGemmKernel:
1132
1292
  )
1133
1293
  # Partition shared/tensor memory tensor for TiledMMA_A/B/D
1134
1294
  # (MMA, MMA_M, MMA_K, STAGE)
1135
- tCrA = tiled_mma.make_fragment_A(sA)
1295
+ tCrA = tiled_mma.make_fragment_A(sA_mma)
1136
1296
  # (MMA, MMA_N, MMA_K, STAGE)
1137
1297
  tCrB = tiled_mma.make_fragment_B(sB)
1138
1298
  # (MMA, MMA_M, MMA_N, STAGE)
@@ -1149,10 +1309,9 @@ class PersistentDenseGemmKernel:
1149
1309
  tiled_mma,
1150
1310
  self.mma_tiler,
1151
1311
  self.sf_vec_size,
1152
- cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
1312
+ cute.slice_(sfa_smem_layout, (None, None, None, 0)),
1153
1313
  )
1154
1314
  tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
1155
-
1156
1315
  # Make SFB tmem tensor
1157
1316
  sfb_tmem_ptr = cute.recast_ptr(
1158
1317
  acc_tmem_ptr
@@ -1165,7 +1324,7 @@ class PersistentDenseGemmKernel:
1165
1324
  tiled_mma,
1166
1325
  self.mma_tiler,
1167
1326
  self.sf_vec_size,
1168
- cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
1327
+ cute.slice_(sfb_smem_layout, (None, None, None, 0)),
1169
1328
  )
1170
1329
  tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
1171
1330
  # Partition for S2T copy of SFA/SFB
@@ -1180,6 +1339,7 @@ class PersistentDenseGemmKernel:
1180
1339
  tCtSFB_compact_s2t,
1181
1340
  ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
1182
1341
  else:
1342
+ tCtSFA, tCtSFB = None, None
1183
1343
  tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None
1184
1344
  tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None
1185
1345
 
@@ -1187,7 +1347,7 @@ class PersistentDenseGemmKernel:
1187
1347
  tile_scheduler = TileSchedulerCls()
1188
1348
  work_tile = tile_scheduler.initial_work_tile_info()
1189
1349
  ab_consumer_state = pipeline.make_pipeline_state(
1190
- pipeline.PipelineUserType.Consumer, self.num_ab_stage
1350
+ pipeline.PipelineUserType.Consumer, self.ab_stage
1191
1351
  )
1192
1352
  acc_producer_state = pipeline.make_pipeline_state(
1193
1353
  pipeline.PipelineUserType.Producer, self.num_acc_stage
@@ -1195,6 +1355,9 @@ class PersistentDenseGemmKernel:
1195
1355
  while work_tile.is_valid_tile:
1196
1356
  # Get tile coord from tile scheduler
1197
1357
  tile_coord_mnkl = work_tile.tile_idx
1358
+ batch_idx = tile_coord_mnkl[3]
1359
+ k_len = varlen_manager.len_k(batch_idx)
1360
+ k_tile_cnt = cute.ceil_div(k_len, self.mma_tiler[2])
1198
1361
  # Set tensor memory buffer for current tile
1199
1362
  # (MMA, MMA_M, MMA_N)
1200
1363
  tCtAcc = tCtAcc_base[None, None, None, acc_producer_state.index]
@@ -1209,6 +1372,9 @@ class PersistentDenseGemmKernel:
1209
1372
  tCtAcc,
1210
1373
  k_tile_cnt,
1211
1374
  is_leader_cta,
1375
+ cta_rank_in_cluster,
1376
+ tCtSFA,
1377
+ tCtSFB,
1212
1378
  tiled_copy_s2t_sfa,
1213
1379
  tiled_copy_s2t_sfb,
1214
1380
  tCsSFA_compact_s2t,
@@ -1234,6 +1400,14 @@ class PersistentDenseGemmKernel:
1234
1400
  )
1235
1401
  # Bar sync for retrieve tensor memory ptr from shared memory
1236
1402
  tmem_alloc_barrier.arrive_and_wait()
1403
+
1404
+ is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
1405
+ varlen_manager.init_tensormap_epi(
1406
+ tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
1407
+ )
1408
+ tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
1409
+ tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
1410
+
1237
1411
  # Retrieving tensor memory ptr and make accumulator tensor
1238
1412
  acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
1239
1413
  self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
@@ -1241,9 +1415,9 @@ class PersistentDenseGemmKernel:
1241
1415
  # (MMA, MMA_M, MMA_N, STAGE)
1242
1416
  tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
1243
1417
 
1244
- epilog_threads = cute.arch.WARP_SIZE * len(self.epilog_warp_id)
1245
1418
  epilogue_barrier = pipeline.NamedBarrier(
1246
- barrier_id=self.epilog_sync_bar_id, num_threads=epilog_threads
1419
+ barrier_id=int(NamedBarrierGemm.Epilogue),
1420
+ num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
1247
1421
  )
1248
1422
 
1249
1423
  # Partition for epilogue
@@ -1252,19 +1426,16 @@ class PersistentDenseGemmKernel:
1252
1426
  epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
1253
1427
  )
1254
1428
 
1255
- tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.d_dtype)
1256
- tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
1257
- tiled_copy_t2r, tTR_rD, epi_tidx, sD
1429
+ tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.acc_dtype)
1430
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
1431
+ tiled_copy_t2r, self.d_layout, self.d_dtype, tTR_rD, sD, epi_tidx
1258
1432
  )
1433
+ tRS_rC, tSR_rC, tSR_sC = None, None, None
1434
+ tiled_copy_s2r = None
1259
1435
  if const_expr(mC_mnl is not None):
1260
- tTR_rC = cute.make_fragment_like(tTR_rD, self.c_dtype)
1261
- tiled_copy_s2r, tSR_rC, tSR_sC = self.epilog_smem_copy_and_partition(
1262
- tiled_copy_t2r, tTR_rC, epi_tidx, sC
1436
+ tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
1437
+ tiled_copy_t2r, self.c_layout, self.c_dtype, sC, tRS_rD.layout, epi_tidx
1263
1438
  )
1264
- # TODO: for m major, D is being stored w STSM so we'd need LDSM here
1265
- # tRS_rC = tSR_rC # TODO: retile?
1266
- tRS_rC = cute.make_fragment(tRS_rD.layout, self.c_dtype)
1267
- tSR_rC = tiled_copy_s2r.get_slice(epi_tidx).retile(tRS_rC)
1268
1439
 
1269
1440
  # Persistent tile scheduling loop
1270
1441
  tile_scheduler = TileSchedulerCls()
@@ -1272,37 +1443,27 @@ class PersistentDenseGemmKernel:
1272
1443
  acc_consumer_state = pipeline.make_pipeline_state(
1273
1444
  pipeline.PipelineUserType.Consumer, self.num_acc_stage
1274
1445
  )
1275
- # Threads/warps participating in tma store pipeline
1276
- d_producer_group = pipeline.CooperativeGroup(
1277
- pipeline.Agent.Thread,
1278
- 32 * len(self.epilog_warp_id),
1279
- 32 * len(self.epilog_warp_id),
1280
- )
1281
- d_pipeline = pipeline.PipelineTmaStore.create(
1282
- num_stages=self.num_d_stage, producer_group=d_producer_group
1283
- )
1446
+ epi_store_pipeline = self.make_epi_store_pipeline()
1284
1447
  epi_read_state = pipeline.make_pipeline_state(
1285
- pipeline.PipelineUserType.Consumer, self.num_c_stage
1448
+ pipeline.PipelineUserType.Consumer, self.epi_c_stage
1286
1449
  )
1287
-
1450
+ if const_expr(varlen_m):
1451
+ # wait tensormap initialization complete before update
1452
+ varlen_manager.fence_tensormap_init()
1288
1453
  while work_tile.is_valid_tile:
1289
1454
  # Get tile coord from tile scheduler
1290
1455
  tile_coord_mnkl = work_tile.tile_idx
1291
- mma_tile_coord_mnl = (
1292
- tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
1293
- tile_coord_mnkl[1],
1294
- tile_coord_mnkl[3],
1456
+ batch_idx = tile_coord_mnkl[3]
1457
+ epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
1458
+ epilogue_params, varlen_params.cu_seqlens_m, batch_idx
1295
1459
  )
1296
- # Local_tile partition global tensors
1297
- # (bM, bN)
1298
- gD_mnl = cute.local_tile(
1299
- mD_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
1460
+ varlen_manager.update_tensormap_epi(
1461
+ batch_idx,
1462
+ self.d_layout,
1463
+ epi_shapes,
1464
+ epi_orders,
1465
+ is_tma_warp,
1300
1466
  )
1301
- # Partition global tensor for TiledMMA_A/B/D
1302
- # (MMA, MMA_M, MMA_N)
1303
- tDgD = thr_mma.partition_C(gD_mnl)
1304
- # bSG_gD has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
1305
- bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(tma_atom_d, tDgD, epi_tile, sD)
1306
1467
 
1307
1468
  # Set tensor memory buffer for current tile
1308
1469
  # (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
@@ -1311,49 +1472,59 @@ class PersistentDenseGemmKernel:
1311
1472
  # Wait for accumulator buffer full
1312
1473
  acc_pipeline.consumer_wait(acc_consumer_state)
1313
1474
 
1314
- tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
1315
- bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
1316
-
1317
- # Store accumulator to global memory in subtiles
1318
- subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
1319
- num_prev_subtiles = tile_scheduler.num_tiles_executed * subtile_cnt
1320
- for subtile_idx in cutlass.range(subtile_cnt):
1321
- # Load accumulator from tensor memory buffer to register
1322
- tTR_tAcc_mn = tTR_tAcc[None, None, None, subtile_idx]
1323
- cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
1324
- # Convert to D type
1325
- acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
1326
- acc_vec = epilogue_op(acc_vec)
1327
- if const_expr(mC_mnl is not None):
1328
- epi_pipeline.consumer_wait(epi_read_state)
1329
- cute.copy(
1330
- tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
1331
- )
1332
- # Fence to make sure shared memory read is visible to TMA load
1333
- cute.arch.fence_proxy(
1334
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1335
- )
1336
- cute.arch.sync_warp()
1337
- with cute.arch.elect_one():
1338
- epi_pipeline.consumer_release(epi_read_state)
1339
- epi_read_state.advance()
1340
- acc_vec = acc_vec + tRS_rC.load().to(self.acc_dtype)
1341
- tRS_rD.store(acc_vec.to(self.d_dtype))
1342
- # Store D to shared memory
1343
- d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage
1344
- cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
1345
- # Fence and barrier to make sure shared memory store is visible to TMA store
1346
- cute.arch.fence_proxy(
1347
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1475
+ varlen_manager.fence_tensormap_update_epi(is_tma_warp)
1476
+
1477
+ copy_D = None
1478
+ if const_expr(has_D):
1479
+ copy_D, _, _ = self.epilog_gmem_copy_and_partition(
1480
+ tma_atom_d,
1481
+ varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
1482
+ self.cta_tile_shape_mnk[:2],
1483
+ epi_tile,
1484
+ sD,
1485
+ tile_coord_mnkl,
1486
+ tma_desc_ptr=tma_desc_d_ptr,
1348
1487
  )
1349
- epilogue_barrier.arrive_and_wait()
1350
- # TMA store D to global memory
1351
- if warp_idx == self.epilog_warp_id[0]:
1352
- cute.copy(tma_atom_d, bSG_sD[None, d_buffer], bSG_gD[None, subtile_idx])
1353
- # Fence and barrier to make sure shared memory store is visible to TMA store
1354
- d_pipeline.producer_commit()
1355
- d_pipeline.producer_acquire()
1356
- epilogue_barrier.arrive_and_wait()
1488
+ copy_C = None # We're using a separate warp to load C
1489
+
1490
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
1491
+ k_len = varlen_manager.len_k(batch_idx)
1492
+ load_acc_subtile = partial(
1493
+ self.epi_load_acc_subtile,
1494
+ tiled_copy_t2r,
1495
+ tiled_copy_r2s,
1496
+ tTR_tAcc,
1497
+ tTR_rAcc,
1498
+ clear_acc=varlen_k and k_len == 0,
1499
+ )
1500
+
1501
+ epi_read_state, _ = self.epilogue(
1502
+ epilogue_params,
1503
+ epi_smem_tensors,
1504
+ tma_desc_epi_ptrs,
1505
+ epi_pipeline,
1506
+ epi_store_pipeline,
1507
+ epi_read_state,
1508
+ None, # epi_producer_state
1509
+ epi_tile,
1510
+ load_acc_subtile,
1511
+ tRS_rD,
1512
+ tRS_rC,
1513
+ tiled_copy_t2r,
1514
+ tiled_copy_r2s,
1515
+ tRS_sD,
1516
+ tiled_copy_s2r,
1517
+ tSR_rC,
1518
+ tSR_sC,
1519
+ copy_D,
1520
+ copy_C,
1521
+ tile_coord_mnkl,
1522
+ varlen_manager,
1523
+ epilogue_barrier,
1524
+ tile_scheduler,
1525
+ epi_tidx,
1526
+ is_tma_warp,
1527
+ )
1357
1528
 
1358
1529
  # Async arrive accumulator buffer empty
1359
1530
  with cute.arch.elect_one():
@@ -1369,7 +1540,7 @@ class PersistentDenseGemmKernel:
1369
1540
  cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
1370
1541
  epilogue_barrier.arrive_and_wait()
1371
1542
  if warp_idx == self.epilog_warp_id[0]:
1372
- if use_2cta_instrs:
1543
+ if const_expr(use_2cta_instrs):
1373
1544
  cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1)
1374
1545
  cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
1375
1546
  cute.arch.dealloc_tmem(
@@ -1377,82 +1548,54 @@ class PersistentDenseGemmKernel:
1377
1548
  )
1378
1549
 
1379
1550
  # Wait for D store complete
1380
- d_pipeline.producer_tail()
1551
+ if is_tma_warp:
1552
+ epi_store_pipeline.producer_tail()
1381
1553
 
1382
1554
  @cute.jit
1383
- def load_AB(
1555
+ def load_A_gather_A(
1384
1556
  self,
1385
- ab_pipeline: cutlass.pipeline.PipelineAsync,
1386
- ab_producer_state: cutlass.pipeline.PipelineState,
1387
- tma_atom_a: cute.CopyAtom,
1388
- tAgA: cute.Tensor,
1389
- tAsA: cute.Tensor,
1390
- a_mcast_mask: cutlass.Int16,
1391
- tma_atom_b: cute.CopyAtom,
1392
- tBgB: cute.Tensor,
1393
- tBsB: cute.Tensor,
1394
- b_mcast_mask: cutlass.Int16,
1395
- tma_atom_sfa: Optional[cute.CopyAtom] = None,
1396
- tAgSFA: Optional[cute.Tensor] = None,
1397
- tAsSFA: Optional[cute.Tensor] = None,
1398
- sfa_mcast_mask: Optional[cutlass.Int16] = None,
1399
- tma_atom_sfb: Optional[cute.CopyAtom] = None,
1400
- tBgSFB: Optional[cute.Tensor] = None,
1401
- tBsSFB: Optional[cute.Tensor] = None,
1402
- sfb_mcast_mask: Optional[cutlass.Int16] = None,
1403
- ) -> cutlass.pipeline.PipelineState:
1404
- blockscaled = const_expr(tma_atom_sfa is not None)
1405
- if const_expr(blockscaled):
1406
- assert all(x is not None for x in (tma_atom_sfa, tAgSFA, tAsSFA))
1407
- assert all(x is not None for x in (tma_atom_sfb, tBgSFB, tBsSFB))
1408
- k_tile_cnt = cute.size(tAgA, mode=[1])
1557
+ a_pipeline: cutlass.pipeline.PipelineAsync,
1558
+ a_producer_state: cutlass.pipeline.PipelineState,
1559
+ a_prefetch_consumer_state: Optional[cutlass.pipeline.PipelineState],
1560
+ copy_A: Callable,
1561
+ prefetch_A: Optional[Callable],
1562
+ k_tile_cnt: Int32,
1563
+ ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]:
1409
1564
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
1410
- peek_ab_empty_status = cutlass.Boolean(True)
1565
+ peek_a_empty_status = Boolean(True)
1411
1566
  if 0 < k_tile_cnt:
1412
- peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1567
+ peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
1413
1568
  # /////////////////////////////////////////////////////////////////////////
1414
- # TMA load
1569
+ # cp.async on A
1415
1570
  # /////////////////////////////////////////////////////////////////////////
1416
- for k_tile in cutlass.range(k_tile_cnt, unroll=1):
1417
- # Wait for A/B buffers to be empty before loading into them
1418
- # Also sets the transaction barrier for the A/B buffers
1419
- ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
1420
- cute.copy(
1421
- tma_atom_a,
1422
- tAgA[None, k_tile],
1423
- tAsA[None, ab_producer_state.index],
1424
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1425
- mcast_mask=a_mcast_mask,
1426
- )
1427
- cute.copy(
1428
- tma_atom_b,
1429
- tBgB[None, k_tile],
1430
- tBsB[None, ab_producer_state.index],
1431
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1432
- mcast_mask=b_mcast_mask,
1433
- )
1434
- if const_expr(blockscaled):
1435
- cute.copy(
1436
- tma_atom_sfa,
1437
- tAgSFA[None, ab_producer_state.count],
1438
- tAsSFA[None, ab_producer_state.index],
1439
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1440
- mcast_mask=sfa_mcast_mask,
1441
- )
1442
- cute.copy(
1443
- tma_atom_sfb,
1444
- tBgSFB[None, ab_producer_state.count],
1445
- tBsSFB[None, ab_producer_state.index],
1446
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1447
- mcast_mask=sfb_mcast_mask,
1448
- )
1449
- # Mainloop pipeline's producer commit is a NOP
1450
- ab_pipeline.producer_commit(ab_producer_state)
1451
- ab_producer_state.advance()
1452
- peek_ab_empty_status = cutlass.Boolean(True)
1571
+ is_tma_warp = False
1572
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
1573
+ smem_idx = a_producer_state.index
1574
+ prefetch_out = ()
1575
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
1576
+ prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),)
1577
+ a_prefetch_consumer_state.advance()
1578
+ a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp)
1579
+ copy_A(k_tile, smem_idx, *prefetch_out)
1580
+ # This tells mbarrier to track the completion of cp.async
1581
+ a_pipeline.producer_cpasync_commit(a_producer_state)
1582
+ a_producer_state.advance()
1583
+ peek_a_empty_status = Boolean(True)
1453
1584
  if k_tile + 1 < k_tile_cnt:
1454
- peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1455
- return ab_producer_state
1585
+ peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
1586
+ # Bounds checking in the K dimension on the last k_tile
1587
+ if 0 < k_tile_cnt:
1588
+ k_tile = k_tile_cnt - 1
1589
+ smem_idx = a_producer_state.index
1590
+ prefetch_out = ()
1591
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
1592
+ prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True),)
1593
+ a_prefetch_consumer_state.advance()
1594
+ a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp)
1595
+ copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
1596
+ a_pipeline.producer_cpasync_commit(a_producer_state)
1597
+ a_producer_state.advance()
1598
+ return a_producer_state, a_prefetch_consumer_state
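load_A_gather_A follows the usual producer pattern: peek the next buffer with a non-blocking try-acquire while the current tile is still in flight, then acquire, copy, commit, and advance, with the last k-tile split out so the copy can be predicated. A toy model of that loop (plain Python stand-ins, not the cutlass.pipeline API):

class ToyPipeline:
    def __init__(self, num_stages):
        self.empty = [True] * num_stages
    def producer_try_acquire(self, idx):
        return self.empty[idx]                      # non-blocking peek
    def producer_acquire(self, idx, peeked):
        # the real code blocks on an mbarrier only when the peek failed
        self.empty[idx] = False
    def producer_commit(self, idx):
        self.empty[idx] = True                      # consumer release folded in for brevity

def produce(pipe, k_tile_cnt, num_stages, copy_fn):
    peek = pipe.producer_try_acquire(0) if k_tile_cnt > 0 else True
    for k in range(k_tile_cnt):
        idx = k % num_stages
        pipe.producer_acquire(idx, peek)
        copy_fn(k, idx, pred=(k == k_tile_cnt - 1))  # predicate only the tail tile
        pipe.producer_commit(idx)
        if k + 1 < k_tile_cnt:
            peek = pipe.producer_try_acquire((k + 1) % num_stages)

if __name__ == "__main__":
    produce(ToyPipeline(4), k_tile_cnt=6, num_stages=4,
            copy_fn=lambda k, idx, pred: print(f"copy k={k} -> stage {idx}, pred={pred}"))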
1456
1599
 
1457
1600
  @cute.jit
1458
1601
  def mma(
@@ -1466,7 +1609,10 @@ class PersistentDenseGemmKernel:
1466
1609
  tCrB: cute.Tensor,
1467
1610
  acc: cute.Tensor,
1468
1611
  k_tile_cnt: Int32,
1469
- is_leader_cta: cutlass.Boolean,
1612
+ is_leader_cta: Boolean,
1613
+ cta_rank_in_cluster: Int32,
1614
+ tCtSFA: Optional[cute.Tensor] = None,
1615
+ tCtSFB: Optional[cute.Tensor] = None,
1470
1616
  tiled_copy_s2t_sfa: Optional[cute.TiledCopy] = None,
1471
1617
  tiled_copy_s2t_sfb: Optional[cute.TiledCopy] = None,
1472
1618
  tCsSFA_compact_s2t: Optional[cute.Tensor] = None,
@@ -1476,12 +1622,17 @@ class PersistentDenseGemmKernel:
1476
1622
  ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma]:
1477
1623
  blockscaled = const_expr(tiled_copy_s2t_sfa is not None)
1478
1624
  if const_expr(blockscaled):
1625
+ assert all(x is not None for x in (tCtSFA, tCtSFB))
1479
1626
  assert all(x is not None for x in (tiled_copy_s2t_sfa, tiled_copy_s2t_sfb))
1480
1627
  assert all(x is not None for x in (tCsSFA_compact_s2t, tCsSFB_compact_s2t))
1481
1628
  assert all(x is not None for x in (tCtSFA_compact_s2t, tCtSFB_compact_s2t))
1629
+ # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will
1630
+ # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader
1631
+ # CTA will wait for that and then arrive at the mbarrier on the leader CTA.
1632
+ need_nonleader_cta = const_expr(self.gather_A and self.use_2cta_instrs)
1482
1633
  # Peek (try_wait) AB buffer full for k_tile = 0
1483
- peek_ab_full_status = cutlass.Boolean(True)
1484
- if 0 < k_tile_cnt and is_leader_cta:
1634
+ peek_ab_full_status = Boolean(True)
1635
+ if 0 < k_tile_cnt and (is_leader_cta or need_nonleader_cta):
1485
1636
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
1486
1637
  # Wait for accumulator buffer empty
1487
1638
  if is_leader_cta:
@@ -1491,6 +1642,14 @@ class PersistentDenseGemmKernel:
1491
1642
  # Mma mainloop
1492
1643
  num_k_blocks = cute.size(tCrA, mode=[2])
1493
1644
  for k_tile in cutlass.range(k_tile_cnt, unroll=1):
1645
+ if const_expr(need_nonleader_cta):
1646
+ if not is_leader_cta:
1647
+ ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
1648
+ with cute.arch.elect_one():
1649
+ # The odd CTA signals the even CTA
1650
+ ab_pipeline.sync_object_full.arrive_mbarrier(
1651
+ ab_consumer_state.index, dst_rank=cta_rank_in_cluster & 0xFE
1652
+ )
1494
1653
  if is_leader_cta:
1495
1654
  # Conditionally wait for AB buffer full
1496
1655
  ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
@@ -1503,14 +1662,19 @@ class PersistentDenseGemmKernel:
1503
1662
  cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t)
1504
1663
  for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1505
1664
  k_blk_coord = (None, None, k_blk_idx, ab_consumer_state.index)
1665
+ if const_expr(blockscaled):
1666
+ # Set SFA/SFB tensor to tiled_mma
1667
+ sf_kblock_coord = (None, None, k_blk_idx)
1668
+ tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
1669
+ tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator)
1506
1670
  cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1507
1671
  tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
1508
1672
  # Async arrive AB buffer empty
1509
1673
  ab_pipeline.consumer_release(ab_consumer_state)
1510
1674
  ab_consumer_state.advance()
1511
1675
  # Peek (try_wait) AB buffer full for k_tile = k_tile + 1
1512
- peek_ab_full_status = cutlass.Boolean(True)
1513
- if k_tile + 1 < k_tile_cnt and is_leader_cta:
1676
+ peek_ab_full_status = Boolean(True)
1677
+ if k_tile + 1 < k_tile_cnt and (is_leader_cta or need_nonleader_cta):
1514
1678
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
1515
1679
  # Async arrive accumulator buffer full
1516
1680
  if is_leader_cta:
@@ -1520,6 +1684,25 @@ class PersistentDenseGemmKernel:
1520
1684
  # "operand #0 does not dominate this use"
1521
1685
  return ab_consumer_state, acc_producer_state, tiled_mma
1522
1686
 
1687
+ @cute.jit
1688
+ def epi_load_acc_subtile(
1689
+ self,
1690
+ tiled_copy_t2r: cute.TiledCopy,
1691
+ tiled_copy_r2s: cute.TiledCopy,
1692
+ tTR_tAcc: cute.Tensor,
1693
+ tTR_rAcc: cute.Tensor,
1694
+ tRS_rD: cute.Tensor,
1695
+ epi_idx: int,
1696
+ clear_acc: Boolean = False,
1697
+ ):
1698
+ if not clear_acc:
1699
+ # Load accumulator from tensor memory buffer to register
1700
+ cute.copy(tiled_copy_t2r, tTR_tAcc[None, None, None, epi_idx], tTR_rAcc)
1701
+ tRS_rAcc = tiled_copy_r2s.retile(tTR_rAcc)
1702
+ tRS_rD.store(tRS_rAcc.load())
1703
+ else:
1704
+ tRS_rD.fill(0.0)
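epi_load_acc_subtile either pulls one accumulator subtile out of tensor memory or, for a varlen batch with k_len == 0 where no MMAs ran, zero-fills the destination fragment so the epilogue still writes a well-defined output. A tiny illustration with lists standing in for fragments:

def load_acc_subtile(acc_subtiles, epi_idx, out_frag, clear_acc=False):
    if not clear_acc:
        out_frag[:] = acc_subtiles[epi_idx]         # T2R copy + retile in the real kernel
    else:
        out_frag[:] = [0.0] * len(out_frag)         # k_len == 0: nothing was accumulated

if __name__ == "__main__":
    frag = [None] * 4
    load_acc_subtile([[1, 2, 3, 4], [5, 6, 7, 8]], epi_idx=1, out_frag=frag)
    print(frag)                                     # [5, 6, 7, 8]
    load_acc_subtile([], epi_idx=0, out_frag=frag, clear_acc=True)
    print(frag)                                     # [0.0, 0.0, 0.0, 0.0]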
1705
+
1523
1706
  def mainloop_s2t_copy_and_partition(
1524
1707
  self,
1525
1708
  sSF: cute.Tensor,
@@ -1560,7 +1743,7 @@ class PersistentDenseGemmKernel:
1560
1743
  tidx: Int32,
1561
1744
  tAcc: cute.Tensor,
1562
1745
  epi_tile: cute.Tile,
1563
- use_2cta_instrs: Union[cutlass.Boolean, bool],
1746
+ use_2cta_instrs: Union[Boolean, bool],
1564
1747
  ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1565
1748
  """
1566
1749
  Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
@@ -1583,8 +1766,8 @@ class PersistentDenseGemmKernel:
1583
1766
  # Make tiledCopy for tensor memory load
1584
1767
  copy_atom_t2r = sm100_utils.get_tmem_load_op(
1585
1768
  self.cta_tile_shape_mnk,
1586
- self.d_layout,
1587
- self.d_dtype,
1769
+ self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR,
1770
+ self.d_dtype if self.d_dtype is not None else cutlass.BFloat16,
1588
1771
  self.acc_dtype,
1589
1772
  epi_tile,
1590
1773
  use_2cta_instrs,
@@ -1607,12 +1790,14 @@ class PersistentDenseGemmKernel:
1607
1790
  tTR_rAcc = cute.make_fragment(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
1608
1791
  return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
1609
1792
 
1610
- def epilog_smem_copy_and_partition(
1793
+ def epilog_smem_store_and_partition(
1611
1794
  self,
1612
1795
  tiled_copy_t2r: cute.TiledCopy,
1796
+ d_layout: Optional[LayoutEnum],
1797
+ dtype: Optional[Type[cutlass.Numeric]],
1613
1798
  tTR_rD: cute.Tensor,
1614
- tidx: Int32,
1615
1799
  sD: cute.Tensor,
1800
+ tidx: Int32,
1616
1801
  ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1617
1802
  """
1618
1803
  Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
@@ -1634,93 +1819,183 @@ class PersistentDenseGemmKernel:
1634
1819
  :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
1635
1820
  """
1636
1821
  copy_atom_r2s = sm100_utils.get_smem_store_op(
1637
- self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r
1822
+ d_layout if d_layout is not None else LayoutEnum.ROW_MAJOR,
1823
+ dtype if dtype is not None else cutlass.BFloat16,
1824
+ self.acc_dtype,
1825
+ tiled_copy_t2r,
1638
1826
  )
1639
1827
  tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
1640
1828
  # (R2S, R2S_M, R2S_N, PIPE_D)
1641
1829
  thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
1642
- tRS_sD = thr_copy_r2s.partition_D(sD)
1830
+ tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
1643
1831
  # (R2S, R2S_M, R2S_N)
1644
1832
  tRS_rD = tiled_copy_r2s.retile(tTR_rD)
1645
1833
  return tiled_copy_r2s, tRS_rD, tRS_sD
1646
1834
 
1647
- # def epilog_smem_load_copy_and_partition(
1648
- # self,
1649
- # tiled_copy_t2r: cute.TiledCopy,
1650
- # tTR_rC: cute.Tensor,
1651
- # tidx: Int32,
1652
- # sC: cute.Tensor,
1653
- # ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1654
- # copy_atom_s2r = cute.make_copy_atom(
1655
- # warp.LdMatrix8x8x16bOp(self.c_layout.is_m_major_c(), num_matrices=4),
1656
- # self.c_dtype, # TODO: this probably only works for f16 for now?
1657
- # )
1658
- # # copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
1659
- # tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
1660
- # # (R2S, R2S_M, R2S_N, PIPE_D)
1661
- # thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1662
- # # (R2S, R2S_M, R2S_N)
1663
- # tSR_sC = thr_copy_s2r.partition_S(sC)
1664
- # return tiled_copy_s2r, tSR_sC
1665
-
1666
- def epilog_gmem_copy_and_partition(
1835
+ def epilog_smem_load_and_partition(
1667
1836
  self,
1668
- atom: Union[cute.CopyAtom, cute.TiledCopy],
1669
- gD_mnl: cute.Tensor,
1670
- epi_tile: cute.Tile,
1671
- sD: cute.Tensor,
1672
- ) -> Tuple[cute.Tensor, cute.Tensor]:
1673
- """Make tiledCopy for global memory store, then use it to:
1674
- - partition register array (source) and global memory (destination) for none TMA store version;
1675
- - partition shared memory (source) and global memory (destination) for TMA store version.
1676
-
1677
- :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version
1678
- :type atom: cute.CopyAtom or cute.TiledCopy
1679
- :param gD_mnl: The global tensor C
1680
- :type gD_mnl: cute.Tensor
1681
- :param epi_tile: The epilogue tiler
1682
- :type epi_tile: cute.Tile
1683
- :param sD: The shared memory tensor to be copied and partitioned
1684
- :type sD: cute.Tensor
1837
+ tiled_copy_t2r: cute.TiledCopy,
1838
+ c_layout: LayoutEnum,
1839
+ dtype: Type[cutlass.Numeric],
1840
+ # tTR_rC: cute.Tensor,
1841
+ sC: cute.Tensor,
1842
+ tRS_rD_layout: cutlass.Layout,
1843
+ tidx: Int32,
1844
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1845
+ copy_atom_r2s = sm100_utils.get_smem_store_op(
1846
+ c_layout, dtype, self.acc_dtype, tiled_copy_t2r
1847
+ )
1848
+ store_op = copy_atom_r2s.op
1849
+ # m8n8 16-bit path
1850
+ if isinstance(store_op, StMatrix8x8x16bOp):
1851
+ op = LdMatrix8x8x16bOp(num_matrices=store_op.num_matrices, transpose=store_op.transpose)
1852
+ # m16n8 8-bit store -> m16n16 8-bit load
1853
+ elif isinstance(store_op, StMatrix16x8x8bOp) and store_op.num_matrices in [2, 4]:
1854
+ # transpose=True is enforced by the class
1855
+ op = LdMatrix16x16x8bOp(num_matrices=store_op.num_matrices // 2)
1856
+ else:
1857
+ op = cute.nvgpu.CopyUniversalOp()
1858
+ copy_atom_s2r = cute.make_copy_atom(op, dtype)
1859
+ tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
1860
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1861
+ # (R2S, R2S_M, R2S_N, PIPE_D)
1862
+ tSR_sC = thr_copy_s2r.partition_S(sC)
1863
+ tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
1864
+ # (R2S, R2S_M, R2S_N)
1865
+ tSR_rC = tiled_copy_s2r.retile(tRS_rC)
1866
+ return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
1685
1867
 
1686
- :return: A tuple containing either:
1687
- - For TMA store: (tma_atom_d, bSG_sD, bSG_gD) where:
1688
- - tma_atom_d: The TMA copy atom
1689
- - bSG_sD: The partitioned shared memory tensor C
1690
- - bSG_gD: The partitioned global tensor C
1691
- - For non-TMA store: (simt_atom, tTR_rD, tTR_gD) where:
1692
- - simt_atom: The SIMT copy atom
1693
- - tTR_rD: The register tensor C
1694
- - tTR_gD: The partitioned global tensor C
1695
- :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
1696
- """
1697
- # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
1698
- gD_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0)], epi_tile)
1699
- sD_for_tma_partition = cute.group_modes(sD, 0, 2)
1700
- gD_for_tma_partition = cute.group_modes(gD_epi, 0, 2)
1701
- # ((ATOM_V, REST_V), EPI_M, EPI_N)
1702
- bSG_sD, bSG_gD = cpasync.tma_partition(
1703
- atom,
1704
- 0,
1705
- cute.make_layout(1),
1706
- sD_for_tma_partition,
1707
- gD_for_tma_partition,
1868
+ @cute.jit
1869
+ def make_ab_pipeline(
1870
+ self,
1871
+ tiled_mma: cute.TiledMma,
1872
+ cluster_layout_vmnk: cute.Layout,
1873
+ ab_pipeline_mbar_ptr: cute.Pointer,
1874
+ is_leader_cta: Boolean,
1875
+ ) -> pipeline.PipelineAsync:
1876
+ # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will
1877
+ # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader
1878
+ # CTA will wait for that and then arrive at the mbarrier on the leader CTA.
1879
+ # The producer count for the leader CTA is 1 (TMA) + num_cpasync_threads
1880
+ # + 1 (from non-leader CTA).
1881
+ # The producer count for the non-leader CTA is num_cpasync_threads
1882
+ # (TMA doesn't arrive there).
1883
+ if const_expr(not self.gather_A):
1884
+ producer_cnt = 1
1885
+ else:
1886
+ producer_cnt = (self.num_ab_load_warps - 1) * 32 + (
1887
+ 1 if const_expr(not self.use_2cta_instrs) else 2
1888
+ )
1889
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
1890
+ # Each warp will contribute the mcast size to the arrive count
1891
+ mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
1892
+ consumer_arrive_cnt = mcast_size
1893
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(
1894
+ pipeline.Agent.Thread, consumer_arrive_cnt
1708
1895
  )
1709
- return bSG_sD, bSG_gD
1896
+ if const_expr(not self.gather_A):
1897
+ pipeline_ab = pipeline.PipelineTmaUmma.create(
1898
+ barrier_storage=ab_pipeline_mbar_ptr,
1899
+ num_stages=self.ab_stage,
1900
+ producer_group=ab_pipeline_producer_group,
1901
+ consumer_group=ab_pipeline_consumer_group,
1902
+ tx_count=self.num_tma_load_bytes,
1903
+ cta_layout_vmnk=cluster_layout_vmnk,
1904
+ )
1905
+ else:
1906
+ pipeline_ab = PipelineTmaCpAsyncUmma.create(
1907
+ barrier_storage=ab_pipeline_mbar_ptr,
1908
+ num_stages=self.ab_stage,
1909
+ producer_group=ab_pipeline_producer_group,
1910
+ consumer_group=ab_pipeline_consumer_group,
1911
+ tx_count=self.num_tma_load_bytes,
1912
+ cta_layout_vmnk=cluster_layout_vmnk,
1913
+ producer_drop_count=None
1914
+ if not self.use_2cta_instrs
1915
+ else (2 if not is_leader_cta else 0),
1916
+ )
1917
+ return pipeline_ab
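The producer-count arithmetic spelled out in the comment above can be sanity-checked in isolation: TMA contributes a single arrival per stage, each cp.async thread contributes one, and with 2-CTA MMA the non-leader CTA forwards one extra arrival to the leader. A back-of-the-envelope sketch (assuming num_ab_load_warps includes the TMA warp and the remaining warps issue cp.async with 32 threads each):

def ab_producer_count(gather_A, use_2cta_instrs, num_ab_load_warps):
    if not gather_A:
        return 1                                     # one TMA arrival per stage
    cpasync_threads = (num_ab_load_warps - 1) * 32   # remaining warps issue cp.async
    # the leader also sees the TMA arrival and, with 2-CTA MMA, one forwarded arrival
    return cpasync_threads + (1 if not use_2cta_instrs else 2)

if __name__ == "__main__":
    print(ab_producer_count(False, False, 4))        # 1
    print(ab_producer_count(True, False, 4))         # 96 + 1 = 97
    print(ab_producer_count(True, True, 4))          # 96 + 2 = 98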
1710
1918
 
1711
- @staticmethod
1919
+ def make_acc_pipeline(
1920
+ self, cluster_layout_vmnk: cute.Layout, acc_pipeline_mbar_ptr: cute.Pointer
1921
+ ) -> pipeline.PipelineAsync:
1922
+ acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1923
+ num_acc_consumer_threads = self.num_epi_warps * (2 if self.use_2cta_instrs else 1)
1924
+ acc_pipeline_consumer_group = pipeline.CooperativeGroup(
1925
+ pipeline.Agent.Thread, num_acc_consumer_threads
1926
+ )
1927
+ return pipeline.PipelineUmmaAsync.create(
1928
+ barrier_storage=acc_pipeline_mbar_ptr,
1929
+ num_stages=self.num_acc_stage,
1930
+ producer_group=acc_pipeline_producer_group,
1931
+ consumer_group=acc_pipeline_consumer_group,
1932
+ cta_layout_vmnk=cluster_layout_vmnk,
1933
+ )
1934
+
1935
+ def make_sched_pipeline(
1936
+ self,
1937
+ cluster_layout_mnk: cute.Layout,
1938
+ sched_pipeline_mbar_ptr: cute.Pointer,
1939
+ has_C: bool = False,
1940
+ ) -> pipeline.PipelineAsync:
1941
+ # Threads/warps participating in this pipeline
1942
+ sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1943
+ cluster_size = cute.size(cluster_layout_mnk)
1944
+ # Each warp that is not the scheduler warp will contribute 1 to the arrive count
1945
+ warps_per_cta = self.num_ab_load_warps + len(
1946
+ (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id)
1947
+ )
1948
+ if has_C:
1949
+ warps_per_cta += 1
1950
+ consumer_arrive_cnt = warps_per_cta * cluster_size - 1
1951
+ sched_pipeline_consumer_group = pipeline.CooperativeGroup(
1952
+ pipeline.Agent.Thread, consumer_arrive_cnt
1953
+ )
1954
+ return pipeline.PipelineAsync.create(
1955
+ barrier_storage=sched_pipeline_mbar_ptr,
1956
+ num_stages=self.sched_stage,
1957
+ producer_group=sched_pipeline_producer_group,
1958
+ consumer_group=sched_pipeline_consumer_group,
1959
+ # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
1960
+ consumer_mask=None if const_expr(cluster_size == 1) else 0,
1961
+ )
1962
+
1963
+ @cute.jit
1964
+ def make_a_prefetch_pipeline(
1965
+ self, a_prefetch_pipeline_mbar_ptr: cute.Pointer
1966
+ ) -> pipeline.PipelineAsync:
1967
+ producer_cnt = 32
1968
+ a_prefetch_producer_group = pipeline.CooperativeGroup(
1969
+ pipeline.Agent.Thread, producer_cnt, alignment=producer_cnt
1970
+ )
1971
+ consumer_arrive_cnt = self.num_ab_load_warps - 1
1972
+ a_prefetch_consumer_group = pipeline.CooperativeGroup(
1973
+ pipeline.Agent.Thread, consumer_arrive_cnt
1974
+ )
1975
+ return pipeline.PipelineCpAsync.create(
1976
+ barrier_storage=a_prefetch_pipeline_mbar_ptr,
1977
+ num_stages=self.a_prefetch_stage,
1978
+ producer_group=a_prefetch_producer_group,
1979
+ consumer_group=a_prefetch_consumer_group,
1980
+ )
1981
+
1982
+ @classmethod
1712
1983
  def _compute_stages(
1984
+ cls,
1713
1985
  tiled_mma: cute.TiledMma,
1714
1986
  mma_tiler_mnk: Tuple[int, int, int],
1987
+ cta_tile_shape_mnk: Tuple[int, int, int],
1988
+ epi_tile: cute.Tile,
1715
1989
  a_dtype: Type[cutlass.Numeric],
1716
1990
  b_dtype: Type[cutlass.Numeric],
1717
- epi_tile: cute.Tile,
1718
- d_dtype: Type[cutlass.Numeric],
1719
- c_dtype: Optional[Type[cutlass.Numeric]],
1720
- d_layout: cutlass.utils.LayoutEnum,
1721
- c_layout: Optional[cutlass.utils.LayoutEnum],
1722
1991
  sf_dtype: Optional[Type[cutlass.Numeric]],
1723
1992
  sf_vec_size: Optional[int],
1993
+ d_dtype: Optional[Type[cutlass.Numeric]],
1994
+ c_dtype: Optional[Type[cutlass.Numeric]],
1995
+ d_layout: Optional[LayoutEnum],
1996
+ c_layout: Optional[LayoutEnum],
1997
+ epilogue_args: EpilogueArguments,
1998
+ prefetch_A_idx: Literal[None, "varlen_m", "varlen_k"],
1724
1999
  smem_capacity: int,
1725
2000
  occupancy: int,
1726
2001
  ) -> Tuple[int, int, int]:
@@ -1738,8 +2013,8 @@ class PersistentDenseGemmKernel:
1738
2013
  :type epi_tile: cute.Tile
1739
2014
  :param d_dtype: Data type of operand C (output).
1740
2015
  :type d_dtype: type[cutlass.Numeric]
1741
- :param d_layout: Layout enum of operand C.
1742
- :type d_layout: cutlass.utils.LayoutEnum
2016
+ :param d_layout: Layout enum of operand D.
2017
+ :type d_layout: LayoutEnum
1743
2018
  :param smem_capacity: Total available shared memory capacity in bytes.
1744
2019
  :type smem_capacity: int
1745
2020
  :param occupancy: Target number of CTAs per SM (occupancy).
@@ -1757,8 +2032,8 @@ class PersistentDenseGemmKernel:
1757
2032
  num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
1758
2033
 
1759
2034
  # Default D stages
1760
- num_d_stage = 2
1761
- num_c_stage = 2 if c_dtype is not None else 0
2035
+ epi_stage = 4 if cute.size(epi_tile[1]) <= 16 else 2
2036
+ epi_c_stage = 0 if c_dtype is None else (4 if cute.size(epi_tile[1]) <= 16 else 2)
1762
2037
 
1763
2038
  # Calculate smem layout and size for one stage of A, B, and C
1764
2039
  a_smem_layout_staged_one = sm100_utils.make_smem_layout_a(
@@ -1773,7 +2048,11 @@ class PersistentDenseGemmKernel:
1773
2048
  b_dtype,
1774
2049
  1, # a tmp 1 stage is provided
1775
2050
  )
1776
- d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
2051
+ d_smem_layout_staged_one = (
2052
+ sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
2053
+ if d_dtype is not None
2054
+ else None
2055
+ )
1777
2056
  c_smem_layout_staged_one = (
1778
2057
  sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1)
1779
2058
  if c_dtype is not None
@@ -1796,34 +2075,38 @@ class PersistentDenseGemmKernel:
1796
2075
  ab_bytes_per_stage = cute.size_in_bytes(
1797
2076
  a_dtype, a_smem_layout_staged_one
1798
2077
  ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
2078
+ if const_expr(prefetch_A_idx == "varlen_k"): # Need smem to prefetch A indices
2079
+ ab_bytes_per_stage += Int32.width // 8 * cta_tile_shape_mnk[2]
1799
2080
  if const_expr(blockscaled):
1800
2081
  ab_bytes_per_stage += cute.size_in_bytes(
1801
2082
  sf_dtype, sfa_smem_layout_staged_one
1802
2083
  ) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
1803
2084
  mbar_helpers_bytes = 1024
1804
- d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
1805
- epi_bytes = d_bytes_per_stage * num_d_stage
2085
+ if const_expr(prefetch_A_idx == "varlen_m"):
2086
+ mbar_helpers_bytes += Int32.width // 8 * cta_tile_shape_mnk[0] * 2
2087
+ d_bytes_per_stage = (
2088
+ cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) if d_dtype is not None else 0
2089
+ )
2090
+ epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
2091
+ epilogue_args, cta_tile_shape_mnk, epi_tile
2092
+ )
2093
+ epi_bytes = epi_bytes_per_stage * epi_stage
1806
2094
  if const_expr(c_dtype is not None):
1807
2095
  c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
1808
- epi_bytes += c_bytes_per_stage * num_c_stage
2096
+ epi_bytes += c_bytes_per_stage * epi_c_stage
1809
2097
 
1810
2098
  # Calculate A/B/SFA/SFB stages:
1811
2099
  # Start with total smem per CTA (capacity / occupancy)
1812
2100
  # Subtract reserved bytes and initial C stages bytes
1813
2101
  # Divide remaining by bytes needed per A/B/SFA/SFB stage
1814
- num_ab_stage = (
1815
- smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)
1816
- ) // ab_bytes_per_stage
2102
+ remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
2103
+ ab_stage = remaining_bytes // ab_bytes_per_stage
1817
2104
 
1818
2105
  # Refine epilogue stages:
1819
2106
  # Calculate remaining smem after allocating for A/B stages and reserved bytes
1820
2107
  # Add remaining unused smem to epilogue
1821
- num_d_stage += (
1822
- smem_capacity
1823
- - occupancy * ab_bytes_per_stage * num_ab_stage
1824
- - occupancy * (mbar_helpers_bytes + epi_bytes)
1825
- ) // (occupancy * d_bytes_per_stage)
1826
- return num_acc_stage, num_ab_stage, num_d_stage, num_c_stage
2108
+ epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // (epi_bytes_per_stage)
2109
+ return num_acc_stage, ab_stage, epi_stage, epi_c_stage
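The stage computation above amounts to a simple byte budget: reserve mbarrier/epilogue space, fill what is left with A/B stages, then hand any leftover back to the epilogue as extra stages. A worked example with made-up byte counts:

def compute_stages(smem_capacity, occupancy, ab_bytes_per_stage,
                   mbar_helpers_bytes, epi_bytes_per_stage, epi_stage_init):
    epi_bytes = epi_bytes_per_stage * epi_stage_init
    remaining = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
    ab_stage = remaining // ab_bytes_per_stage
    epi_stage = epi_stage_init + (remaining - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
    return ab_stage, epi_stage

if __name__ == "__main__":
    # e.g. 228 KiB of smem per CTA, 32 KiB per A+B stage, 16 KiB per epilogue stage
    print(compute_stages(228 * 1024, 1, 32 * 1024, 1024, 16 * 1024, 2))   # (6, 2)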
1827
2110
 
1828
2111
  @staticmethod
1829
2112
  def _compute_num_tmem_alloc_cols(
@@ -1851,9 +2134,12 @@ class PersistentDenseGemmKernel:
1851
2134
 
1852
2135
  @staticmethod
1853
2136
  def is_valid_dtypes(
1854
- ab_dtype: Type[cutlass.Numeric],
2137
+ a_dtype: Type[cutlass.Numeric],
2138
+ b_dtype: Type[cutlass.Numeric],
1855
2139
  acc_dtype: Type[cutlass.Numeric],
1856
- d_dtype: Type[cutlass.Numeric],
2140
+ d_dtype: Optional[Type[cutlass.Numeric]],
2141
+ a_major: str,
2142
+ b_major: str,
1857
2143
  ) -> bool:
1858
2144
  """
1859
2145
  Check if the dtypes are valid
@@ -1869,6 +2155,9 @@ class PersistentDenseGemmKernel:
1869
2155
  :rtype: bool
1870
2156
  """
1871
2157
  is_valid = True
2158
+ if b_dtype != a_dtype:
2159
+ is_valid = False
2160
+ ab_dtype = a_dtype
1872
2161
  if ab_dtype not in {
1873
2162
  cutlass.Float16,
1874
2163
  cutlass.BFloat16,
@@ -1880,18 +2169,18 @@ class PersistentDenseGemmKernel:
1880
2169
  }:
1881
2170
  is_valid = False
1882
2171
  if (
1883
- acc_dtype not in {cutlass.Float32, cutlass.Float16, Int32}
2172
+ acc_dtype not in {Float32, cutlass.Float16, Int32}
1884
2173
  or acc_dtype == cutlass.Float16
1885
2174
  and ab_dtype not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}
1886
2175
  or acc_dtype == Int32
1887
2176
  and ab_dtype not in {cutlass.Uint8, cutlass.Int8}
1888
2177
  ):
1889
2178
  is_valid = False
1890
- if (
1891
- acc_dtype == cutlass.Float32
2179
+ if d_dtype is not None and (
2180
+ acc_dtype == Float32
1892
2181
  and d_dtype
1893
2182
  not in {
1894
- cutlass.Float32,
2183
+ Float32,
1895
2184
  cutlass.Float16,
1896
2185
  cutlass.BFloat16,
1897
2186
  cutlass.Float8E4M3FN,
@@ -1911,13 +2200,15 @@ class PersistentDenseGemmKernel:
1911
2200
  not in {
1912
2201
  cutlass.BFloat16,
1913
2202
  cutlass.Float16,
1914
- cutlass.Float32,
2203
+ Float32,
1915
2204
  Int32,
1916
2205
  cutlass.Int8,
1917
2206
  cutlass.Uint8,
1918
2207
  }
1919
2208
  ):
1920
2209
  is_valid = False
2210
+ if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
2211
+ is_valid = False
1921
2212
  return is_valid
1922
2213
 
1923
2214
  @staticmethod
@@ -1964,7 +2255,7 @@ class PersistentDenseGemmKernel:
1964
2255
 
1965
2256
  # Check valid d_dtype
1966
2257
  if d_dtype not in {
1967
- cutlass.Float32,
2258
+ Float32,
1968
2259
  cutlass.Float16,
1969
2260
  cutlass.BFloat16,
1970
2261
  cutlass.Float8E5M2,
@@ -1974,37 +2265,8 @@ class PersistentDenseGemmKernel:
1974
2265
 
1975
2266
  return is_valid
1976
2267
 
1977
- @staticmethod
1978
- def is_valid_layouts(
1979
- ab_dtype: Type[cutlass.Numeric],
1980
- a_major: str,
1981
- b_major: str,
1982
- ) -> bool:
1983
- """
1984
- Check if the dtypes and sf_vec_size are valid combinations
1985
-
1986
- :param ab_dtype: The data type of the A and B operands
1987
- :type ab_dtype: Type[cutlass.Numeric]
1988
- :param d_dtype: The data type of the output tensor
1989
- :type d_dtype: Type[cutlass.Numeric]
1990
- :param a_major: The major dimension of the A tensor
1991
- :type a_major: str
1992
- :param b_major: The major dimension of the B tensor
1993
- :type b_major: str
1994
- :param d_major: The major dimension of the C tensor
1995
- :type d_major: str
1996
-
1997
- :return: True if the layouts are valid, False otherwise
1998
- :rtype: bool
1999
- """
2000
- is_valid = True
2001
- if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
2002
- is_valid = False
2003
- return is_valid
2004
-
2005
2268
  @staticmethod
2006
2269
  def is_valid_mma_tiler_and_cluster_shape(
2007
- use_2cta_instrs: bool,
2008
2270
  mma_tiler_mn: Tuple[int, int],
2009
2271
  cluster_shape_mn: Tuple[int, int],
2010
2272
  blockscaled: bool,
@@ -2012,8 +2274,6 @@ class PersistentDenseGemmKernel:
2012
2274
  """
2013
2275
  Check if the mma tiler and cluster shape are valid
2014
2276
 
2015
- :param use_2cta_instrs: Whether to use 2 CTA groups
2016
- :type use_2cta_instrs: bool
2017
2277
  :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
2018
2278
  :type mma_tiler_mn: Tuple[int, int]
2019
2279
  :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
@@ -2024,10 +2284,7 @@ class PersistentDenseGemmKernel:
2024
2284
  """
2025
2285
  is_valid = True
2026
2286
  # Skip invalid mma tile shape
2027
- if not (
2028
- (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
2029
- or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
2030
- ):
2287
+ if mma_tiler_mn[0] not in [64, 128, 256]:
2031
2288
  is_valid = False
2032
2289
  if not blockscaled:
2033
2290
  if mma_tiler_mn[1] not in range(32, 257, 32):
@@ -2035,9 +2292,6 @@ class PersistentDenseGemmKernel:
2035
2292
  else:
2036
2293
  if mma_tiler_mn[1] not in [128, 256]:
2037
2294
  is_valid = False
2038
- # Skip illegal cluster shape
2039
- if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
2040
- is_valid = False
2041
2295
  # Skip invalid cluster shape
2042
2296
  is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
2043
2297
  if (
@@ -2113,7 +2367,6 @@ class PersistentDenseGemmKernel:
2113
2367
  ab_dtype: Type[cutlass.Numeric],
2114
2368
  acc_dtype: Type[cutlass.Numeric],
2115
2369
  d_dtype: Type[cutlass.Numeric],
2116
- use_2cta_instrs: bool,
2117
2370
  mma_tiler_mn: Tuple[int, int],
2118
2371
  cluster_shape_mn: Tuple[int, int],
2119
2372
  m: int,
@@ -2133,8 +2386,6 @@ class PersistentDenseGemmKernel:
2133
2386
  :type acc_dtype: Type[cutlass.Numeric]
2134
2387
  :param d_dtype: The data type of the output tensor
2135
2388
  :type d_dtype: Type[cutlass.Numeric]
2136
- :param use_2cta_instrs: Whether to use 2 CTA groups
2137
- :type use_2cta_instrs: bool
2138
2389
  :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
2139
2390
  :type mma_tiler_mn: Tuple[int, int]
2140
2391
  :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
@@ -2159,15 +2410,15 @@ class PersistentDenseGemmKernel:
2159
2410
  """
2160
2411
  can_implement = True
2161
2412
  # Skip unsupported types
2162
- if not PersistentDenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, d_dtype):
2413
+ if not GemmSm100.is_valid_dtypes(ab_dtype, ab_dtype, acc_dtype, d_dtype, a_major, b_major):
2163
2414
  can_implement = False
2164
2415
  # Skip invalid mma tile shape and cluster shape
2165
- if not PersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(
2166
- use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, blockscaled=False
2416
+ if not GemmSm100.is_valid_mma_tiler_and_cluster_shape(
2417
+ mma_tiler_mn, cluster_shape_mn, blockscaled=False
2167
2418
  ):
2168
2419
  can_implement = False
2169
2420
  # Skip illegal problem shape for load/store alignment
2170
- if not PersistentDenseGemmKernel.is_valid_tensor_alignment(
2421
+ if not GemmSm100.is_valid_tensor_alignment(
2171
2422
  m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major
2172
2423
  ):
2173
2424
  can_implement = False
@@ -2186,7 +2437,6 @@ def run(
2186
2437
  c_major: str,
2187
2438
  mma_tiler_mn: Tuple[int, int] = (256, 256),
2188
2439
  cluster_shape_mn: Tuple[int, int] = (2, 1),
2189
- use_2cta_instrs: bool = True,
2190
2440
  tolerance: float = 1e-01,
2191
2441
  warmup_iterations: int = 0,
2192
2442
  iterations: int = 1,
@@ -2215,9 +2465,6 @@ def run(
     :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
         default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
     :type cluster_shape_mn: Tuple[int, int], optional
-    :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
-        will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
-    :type use_2cta_instrs: bool, optional
     :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
     :type tolerance: float, optional
     :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
@@ -2236,7 +2483,6 @@ def run(
     print(f"AB dtype: {ab_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
     print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
     print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
-    print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
     print(f"Tolerance: {tolerance}")
     print(f"Warmup iterations: {warmup_iterations}")
     print(f"Iterations: {iterations}")
@@ -2248,11 +2494,10 @@ def run(
     m, n, k, l = mnkl

     # Skip unsupported testcase
-    if not PersistentDenseGemmKernel.can_implement(
+    if not GemmSm100.can_implement(
         ab_dtype,
         acc_dtype,
         d_dtype,
-        use_2cta_instrs,
         mma_tiler_mn,
         cluster_shape_mn,
         m,
@@ -2264,7 +2509,7 @@ def run(
         d_major,
     ):
         raise TypeError(
-            f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}"
+            f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}"
         )

     if not torch.cuda.is_available():
@@ -2339,12 +2584,8 @@ def run(
         c, mC, c_torch = None, None, None

     # Configure gemm kernel
-    gemm = PersistentDenseGemmKernel(
-        acc_dtype,
-        use_2cta_instrs,
-        mma_tiler_mn,
-        cluster_shape_mn,
-    )
+    cluster_shape_mnk = (*cluster_shape_mn, 1)
+    gemm = GemmSm100(acc_dtype, ab_dtype, mma_tiler_mn, cluster_shape_mnk)

     # Compute max active clusters on current device
     hardware_info = cutlass.utils.HardwareInfo()
@@ -2356,6 +2597,17 @@ def run(
     else:
         tile_count_semaphore = None

+    scheduler_args = TileSchedulerOptions(
+        Int32(max_active_clusters),
+        tile_count_semaphore=make_ptr(
+            Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
+        )
+        if tile_count_semaphore is not None
+        else None,
+    )
+    epi_args = gemm.EpilogueArguments()
+    varlen_args = VarlenArguments()
+
     # Get current CUDA stream from PyTorch
     torch_stream = torch.cuda.current_stream()
     # Get the raw stream pointer as a CUstream
@@ -2367,15 +2619,14 @@ def run(
         mB,
         mD,
         mC,
-        make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
-        if tile_count_semaphore is not None
-        else None,
-        max_active_clusters,
+        epi_args,
+        scheduler_args,
+        varlen_args,
         current_stream,
     )

     if not skip_ref_check:
-        compiled_gemm(mA, mB, mD, mC, tile_count_semaphore, current_stream)
+        compiled_gemm(mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream)
         if ab_dtype in {
             cutlass.Int8,
             cutlass.Uint8,
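Taken together, the last few hunks change the host-side flow: the kernel object is now `GemmSm100(acc_dtype, ab_dtype, mma_tiler_mn, cluster_shape_mnk)` with a 3-D cluster shape, and both the compile call and the compiled kernel take explicit `epi_args`, `scheduler_args`, and `varlen_args` objects instead of a raw tile-count semaphore plus cluster count. A condensed sketch of that flow follows; it assumes the compile step uses `cute.compile` (whose opening line falls outside these hunks), that `GemmSm100`, `TileSchedulerOptions`, `VarlenArguments`, `make_ptr`, and `Int32` are imported as elsewhere in this file, and that the tensors and dtypes come from the surrounding `run()` body:

```python
# Condensed, hedged reconstruction of the new launch sequence from this diff.
cluster_shape_mnk = (*cluster_shape_mn, 1)
gemm = GemmSm100(acc_dtype, ab_dtype, mma_tiler_mn, cluster_shape_mnk)

scheduler_args = TileSchedulerOptions(
    Int32(max_active_clusters),
    # Wrap the optional persistent-scheduler semaphore as a gmem pointer.
    tile_count_semaphore=make_ptr(
        Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
    )
    if tile_count_semaphore is not None
    else None,
)
epi_args = gemm.EpilogueArguments()  # default epilogue arguments
varlen_args = VarlenArguments()      # no variable-length batching in this benchmark

compiled_gemm = cute.compile(
    gemm, mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream
)
compiled_gemm(mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream)
```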
@@ -2393,7 +2644,7 @@ def run(
         gpu_d = d_torch.cpu()

         # Convert ref to c_type
-        if d_dtype == cutlass.Float32:
+        if d_dtype == Float32:
             ref_d = ref
         elif d_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
             # m major: (l, n, m) -> (m, n, l)
@@ -2463,7 +2714,9 @@ def run(
         print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")

     time.sleep(0.5)
-    fn = lambda: compiled_gemm(mA, mB, mD, mC, tile_count_semaphore, current_stream)
+    fn = lambda: compiled_gemm(
+        mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream
+    )
     timing = do_bench(fn, warmup=warmup, rep=repeats)
     tflops = flops / (timing * 1e9)  # Convert to TFlops
     print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
@@ -2505,12 +2758,7 @@ if __name__ == "__main__":
     parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
     parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
     parser.add_argument("--c_dtype", type=cutlass.dtype, default=None)
-    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
-    parser.add_argument(
-        "--use_2cta_instrs",
-        action="store_true",
-        help="Enable 2CTA MMA instructions feature",
-    )
+    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=Float32)
     parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
     parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
     parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
@@ -2552,7 +2800,6 @@ if __name__ == "__main__":
         args.c_major,
         args.mma_tiler_mn,
         args.cluster_shape_mn,
-        args.use_2cta_instrs,
         args.tolerance,
         args.warmup_iterations,
         args.iterations,
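With the flag removed from both the argument parser and this positional call, invoking the benchmark entry point directly would look roughly as follows. The keyword names are inferred from the signature fragments, prints, and docstring visible in this diff, so treat them as illustrative rather than the definitive `run()` signature:

```python
# Hypothetical direct call of run() after use_2cta_instrs was removed; keyword
# names are inferred from this diff and may not match the real signature exactly.
run(
    mnkl=(8192, 8192, 8192, 1),
    ab_dtype=cutlass.BFloat16,
    d_dtype=cutlass.BFloat16,
    c_dtype=None,
    acc_dtype=cutlass.Float32,
    a_major="k",
    b_major="k",
    d_major="n",
    c_major="n",
    mma_tiler_mn=(256, 256),
    cluster_shape_mn=(2, 1),
    tolerance=1e-01,
    warmup_iterations=0,
    iterations=1,
)
```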