quack-kernels 0.2.0-py3-none-any.whl → 0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,30 +1,5 @@
1
- # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD-3-Clause
3
-
4
- # Redistribution and use in source and binary forms, with or without
5
- # modification, are permitted provided that the following conditions are met:
6
-
7
- # 1. Redistributions of source code must retain the above copyright notice, this
8
- # list of conditions and the following disclaimer.
9
-
10
- # 2. Redistributions in binary form must reproduce the above copyright notice,
11
- # this list of conditions and the following disclaimer in the documentation
12
- # and/or other materials provided with the distribution.
13
-
14
- # 3. Neither the name of the copyright holder nor the names of its
15
- # contributors may be used to endorse or promote products derived from
16
- # this software without specific prior written permission.
17
-
18
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
1
+ # Based on the cute-dsl example:
2
+ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py
28
3
 
29
4
  import argparse
30
5
  from typing import Optional, Type, Tuple, Union, Callable
@@ -40,15 +15,16 @@ import cutlass.torch as cutlass_torch
40
15
  import cutlass.pipeline as pipeline
41
16
  import cutlass.utils.blackwell_helpers as sm100_utils
42
17
  import cutlass.utils.blockscaled_layout as blockscaled_utils
18
+ from cutlass import Int32, Float32, Boolean, const_expr
19
+ from cutlass.utils import LayoutEnum
43
20
  from cutlass.cute.runtime import from_dlpack, make_ptr
44
- from cutlass import Int32, const_expr
45
21
 
46
- from quack.cute_dsl_utils import ParamsBase
47
- from quack.tile_scheduler import (
48
- TileSchedulerArguments,
49
- TileScheduler,
50
- RasterOrderOption,
51
- )
22
+ from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
23
+ from quack.tile_scheduler import TileSchedulerOptions
24
+ from quack.varlen_utils import VarlenArguments
25
+ from quack.dense_gemm_sm90 import GemmSm90, NamedBarrierGemm
26
+
27
+ # return PipelineStateWAdvance instead of PipelineState
52
28
 
53
29
  """
54
30
  A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
@@ -72,8 +48,6 @@ This GEMM works as follows:
72
48
  - Type convert C matrix to output type.
73
49
  - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
74
50
  or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
75
- - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
76
- e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
77
51
 
78
52
  SM100 tcgen05.mma instructions operate as follows:
79
53
  - Read matrix A from SMEM
@@ -105,7 +79,7 @@ To collect performance with NCU profiler:
105
79
 
106
80
  Constraints are same as dense_gemm.py:
107
81
  * Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2),
108
- see detailed valid dtype combinations in below PersistentDenseGemmKernel class documentation
82
+ see detailed valid dtype combinations in below GemmSm100 class documentation
109
83
  * A/B tensor must have the same data type
110
84
  * Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
111
85
  * Mma tiler N must be 32-256, step 32
@@ -118,14 +92,12 @@ Constraints are same as dense_gemm.py:
118
92
  """
119
93
 
120
94
 
121
- class PersistentDenseGemmKernel:
95
+ class GemmSm100(GemmSm90):
122
96
  """This class implements batched matrix multiplication (C = A x B) with support for various data types
123
97
  and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
124
98
 
125
99
  :param acc_dtype: Data type for accumulation during computation
126
100
  :type acc_dtype: type[cutlass.Numeric]
127
- :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
128
- :type use_2cta_instrs: bool
129
101
  :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
130
102
  :type mma_tiler_mn: Tuple[int, int]
131
103
  :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
@@ -159,22 +131,27 @@ class PersistentDenseGemmKernel:
159
131
  - Cluster shape M/N must be positive and power of 2, total cluster size <= 16
160
132
 
161
133
  Example:
162
- >>> gemm = PersistentDenseGemmKernel(
163
- ... acc_dtype=cutlass.Float32,
164
- ... use_2cta_instrs=True,
134
+ >>> gemm = GemmSm100(
135
+ ... acc_dtype=Float32,
165
136
  ... mma_tiler_mn=(128, 128),
166
137
  ... cluster_shape_mn=(2, 2)
167
138
  ... )
168
139
  >>> gemm(mA, mB, mD, max_active_clusters, stream)
169
140
  """
170
141
 
142
+ arch = 100
143
+ num_epi_tensormaps = GemmSm90.num_epi_tensormaps
144
+
145
+ EpilogueArguments = GemmSm90.EpilogueArguments
146
+ EpilogueParams = GemmSm90.EpilogueParams
147
+
171
148
  def __init__(
172
149
  self,
173
150
  acc_dtype: Type[cutlass.Numeric],
174
- use_2cta_instrs: bool,
175
151
  mma_tiler_mn: Tuple[int, int],
176
- cluster_shape_mn: Tuple[int, int],
152
+ cluster_shape_mnk: Tuple[int, int, int],
177
153
  sf_vec_size: Optional[int] = None,
154
+ gather_A: bool = False,
178
155
  ):
179
156
  """Initializes the configuration for a Blackwell dense GEMM kernel.
180
157
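The hunks above rename PersistentDenseGemmKernel to GemmSm100 (now a subclass of GemmSm90) and change the constructor: the explicit use_2cta_instrs flag and the 2-element cluster_shape_mn are gone, replaced by a 3-element cluster_shape_mnk and an optional gather_A flag, with use_2cta_instrs now inferred internally. A minimal call-site migration sketch follows; the tile and cluster values are illustrative, not taken from the diff:

    # quack-kernels 0.2.0
    gemm = PersistentDenseGemmKernel(
        acc_dtype=cutlass.Float32,
        use_2cta_instrs=True,
        mma_tiler_mn=(256, 128),
        cluster_shape_mn=(2, 1),
    )

    # quack-kernels 0.2.2: use_2cta_instrs is inferred inside __init__ as
    # cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (128, 256)
    gemm = GemmSm100(
        acc_dtype=Float32,
        mma_tiler_mn=(256, 128),
        cluster_shape_mnk=(2, 1, 1),  # cluster K must be 1 (asserted in __init__)
        gather_A=False,               # new flag; requires cluster N == 1 when True
    )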
 
@@ -187,47 +164,42 @@ class PersistentDenseGemmKernel:
187
164
  with cta_group=2 should be used.
188
165
 
189
166
  2. Cluster Shape:
190
- - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
167
+ - cluster_shape_mnk: The (ClusterM, ClusterN) shape of the CTA cluster.
191
168
 
192
169
  :param acc_dtype: Data type of the accumulator.
193
170
  :type acc_dtype: type[cutlass.Numeric]
194
171
  :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
195
172
  :type mma_tiler_mn: Tuple[int, int]
196
- :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant.
197
- :type use_2cta_instrs: bool
198
- :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
199
- :type cluster_shape_mn: Tuple[int, int]
173
+ :param cluster_shape_mnk: Tuple (ClusterM, ClusterN, ClusterK) shape of the cluster.
174
+ :type cluster_shape_mnk: Tuple[int, int, int]
200
175
  """
201
176
 
202
177
  self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
203
- self.use_2cta_instrs = use_2cta_instrs
204
- self.cluster_shape_mn = cluster_shape_mn
178
+ self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (128, 256)
179
+ self.cluster_shape_mnk = cluster_shape_mnk
180
+ assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1"
205
181
  # K dimension is deferred in _setup_attributes
206
182
  self.mma_tiler = (*mma_tiler_mn, 1)
207
183
  self.sf_vec_size = sf_vec_size
208
184
  self.blockscaled = sf_vec_size is not None
185
+ self.is_persistent = True
186
+ self.pingpong = False # for compatibility with GemmSm90
187
+ self.gather_A = gather_A
188
+ if gather_A:
189
+ assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
209
190
 
210
- self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
191
+ self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
211
192
 
212
193
  self.occupancy = 1
213
194
  # Set specialized warp ids
214
- self.epilog_warp_id = (
215
- 0,
216
- 1,
217
- 2,
218
- 3,
219
- )
195
+ self.epilog_warp_id = (0, 1, 2, 3)
220
196
  self.mma_warp_id = 4
221
197
  self.tma_warp_id = 5
222
198
  self.tma_epi_warp_id = 6
199
+ self.num_epi_warps = len(self.epilog_warp_id)
223
200
  self.threads_per_cta = 32 * len(
224
201
  (self.mma_warp_id, self.tma_warp_id, self.tma_epi_warp_id, *self.epilog_warp_id)
225
202
  )
226
- # Set barrier id for cta sync, epilogue sync and tmem ptr sync
227
- self.cta_sync_bar_id = 0
228
- self.epilog_sync_bar_id = 1
229
- self.tmem_ptr_sync_bar_id = 2
230
- self.epilog_load_bar_id = 3
231
203
  self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_100")
232
204
 
233
205
  def _setup_attributes(self):
@@ -261,7 +233,7 @@ class PersistentDenseGemmKernel:
261
233
 
262
234
  # Configure tiled mma
263
235
  if const_expr(not self.blockscaled):
264
- tiled_mma = sm100_utils.make_trivial_tiled_mma(
236
+ self.tiled_mma = sm100_utils.make_trivial_tiled_mma(
265
237
  self.a_dtype,
266
238
  self.a_major_mode,
267
239
  self.b_major_mode,
@@ -269,9 +241,9 @@ class PersistentDenseGemmKernel:
269
241
  self.cta_group,
270
242
  self.mma_tiler[:2],
271
243
  )
272
- tiled_mma_sfb = None
244
+ self.tiled_mma_sfb = None
273
245
  else:
274
- tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
246
+ self.tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
275
247
  self.a_dtype,
276
248
  self.a_major_mode,
277
249
  self.b_major_mode,
@@ -280,13 +252,13 @@ class PersistentDenseGemmKernel:
280
252
  self.cta_group,
281
253
  self.mma_inst_shape_mnk[:2],
282
254
  )
283
- tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
255
+ self.tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
284
256
  self.a_dtype,
285
257
  self.a_major_mode,
286
258
  self.b_major_mode,
287
259
  self.sf_dtype,
288
260
  self.sf_vec_size,
289
- cute.nvgpu.tcgen05.CtaGroup.ONE,
261
+ tcgen05.CtaGroup.ONE,
290
262
  self.mma_inst_shape_mnk_sfb[:2],
291
263
  )
292
264
 
@@ -306,20 +278,20 @@ class PersistentDenseGemmKernel:
306
278
  else:
307
279
  self.mma_tiler_sfb = None
308
280
  self.cta_tile_shape_mnk = (
309
- self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
281
+ self.mma_tiler[0] // cute.size(self.tiled_mma.thr_id.shape),
310
282
  self.mma_tiler[1],
311
283
  self.mma_tiler[2],
312
284
  )
313
285
 
314
286
  # Compute cluster layout
315
287
  self.cluster_layout_vmnk = cute.tiled_divide(
316
- cute.make_layout((*self.cluster_shape_mn, 1)),
317
- (tiled_mma.thr_id.shape,),
288
+ cute.make_layout(self.cluster_shape_mnk),
289
+ (self.tiled_mma.thr_id.shape,),
318
290
  )
319
291
  if const_expr(self.blockscaled):
320
292
  self.cluster_layout_sfb_vmnk = cute.tiled_divide(
321
- cute.make_layout((*self.cluster_shape_mn, 1)),
322
- (tiled_mma_sfb.thr_id.shape,),
293
+ cute.make_layout(self.cluster_shape_mnk),
294
+ (self.tiled_mma_sfb.thr_id.shape,),
323
295
  )
324
296
  else:
325
297
  self.cluster_layout_sfb_vmnk = None
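As a quick sanity check on the arithmetic in the hunk above: the per-CTA tile M is just the MMA tile M divided by the number of CTAs cooperating on one tcgen05 MMA atom (cute.size(self.tiled_mma.thr_id.shape), which is 2 for the 2-CTA variant and 1 otherwise). A plain-Python illustration with assumed tile sizes:

    mma_tiler = (256, 128, 64)   # assumed (M, N, K); in the class, K starts as 1 and is filled in later
    atom_thr_size = 2            # 2-CTA tcgen05 MMA; 1 for the single-CTA variant
    cta_tile_shape_mnk = (mma_tiler[0] // atom_thr_size, mma_tiler[1], mma_tiler[2])
    assert cta_tile_shape_mnk == (128, 128, 64)  # each CTA of the pair owns half of the M extent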
@@ -344,11 +316,11 @@ class PersistentDenseGemmKernel:
344
316
  # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
345
317
  (
346
318
  self.num_acc_stage,
347
- self.num_ab_stage,
348
- self.num_d_stage,
349
- self.num_c_stage,
319
+ self.ab_stage,
320
+ self.epi_stage,
321
+ self.epi_c_stage,
350
322
  ) = self._compute_stages(
351
- tiled_mma,
323
+ self.tiled_mma,
352
324
  self.mma_tiler,
353
325
  self.a_dtype,
354
326
  self.b_dtype,
@@ -362,35 +334,36 @@ class PersistentDenseGemmKernel:
362
334
  self.smem_capacity,
363
335
  self.occupancy,
364
336
  )
337
+ self.sched_stage = 1 # For compatibility with GemmSm90
365
338
 
366
339
  # Compute A/B/SFA/SFB/C shared memory layout
367
340
  self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
368
- tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage
341
+ self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
369
342
  )
370
343
  self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
371
- tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage
344
+ self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage
372
345
  )
373
- self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
374
- self.d_dtype, self.d_layout, self.epi_tile, self.num_d_stage
346
+ self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi(
347
+ self.d_dtype, self.d_layout, self.epi_tile, self.epi_stage
375
348
  )
376
349
  if const_expr(self.c_dtype is not None):
377
350
  self.epi_c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
378
- self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage
351
+ self.c_dtype, self.c_layout, self.epi_tile, self.epi_c_stage
379
352
  )
380
353
  else:
381
354
  self.epi_c_smem_layout_staged = None
382
355
  if const_expr(self.blockscaled):
383
356
  self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
384
- tiled_mma,
357
+ self.tiled_mma,
385
358
  self.mma_tiler,
386
359
  self.sf_vec_size,
387
- self.num_ab_stage,
360
+ self.ab_stage,
388
361
  )
389
362
  self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
390
- tiled_mma,
363
+ self.tiled_mma,
391
364
  self.mma_tiler,
392
365
  self.sf_vec_size,
393
- self.num_ab_stage,
366
+ self.ab_stage,
394
367
  )
395
368
  else:
396
369
  self.sfa_smem_layout_staged, self.sfb_smem_layout_staged = None, None
@@ -398,7 +371,7 @@ class PersistentDenseGemmKernel:
398
371
  # Compute the number of tensor memory allocation columns
399
372
  if const_expr(not self.blockscaled):
400
373
  self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(
401
- tiled_mma, self.mma_tiler, self.num_acc_stage
374
+ self.tiled_mma, self.mma_tiler, self.num_acc_stage
402
375
  )
403
376
  else:
404
377
  SM100_TMEM_CAPACITY_COLUMNS = 512
@@ -409,14 +382,14 @@ class PersistentDenseGemmKernel:
409
382
  self,
410
383
  mA: cute.Tensor,
411
384
  mB: cute.Tensor,
412
- mD: cute.Tensor,
385
+ mD: Optional[cute.Tensor],
413
386
  mC: Optional[cute.Tensor],
414
- tile_count_semaphore: Optional[cute.Pointer],
415
- max_active_clusters: cutlass.Constexpr,
387
+ epilogue_args: ArgumentsBase,
388
+ scheduler_args: TileSchedulerOptions,
389
+ varlen_args: Optional[VarlenArguments],
416
390
  stream: cuda.CUstream,
417
391
  mSFA: Optional[cute.Tensor] = None,
418
392
  mSFB: Optional[cute.Tensor] = None,
419
- epilogue_op: cutlass.Constexpr = lambda x: x,
420
393
  ):
421
394
  """Execute the GEMM operation in steps:
422
395
  - Setup static attributes before smem/grid/tma computation
@@ -435,30 +408,46 @@ class PersistentDenseGemmKernel:
435
408
  :type max_active_clusters: cutlass.Constexpr
436
409
  :param stream: CUDA stream for asynchronous execution
437
410
  :type stream: cuda.CUstream
438
- :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
439
- :type epilogue_op: cutlass.Constexpr
440
411
  :raises TypeError: If input data types are incompatible with the MMA instruction.
441
412
  :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled.
442
413
  """
443
414
  if const_expr(self.blockscaled):
444
415
  assert mSFA is not None and mSFB is not None
445
416
  # Setup static attributes before smem/grid/tma computation
446
- self.a_dtype: Type[cutlass.Numeric] = mA.element_type
447
- self.b_dtype: Type[cutlass.Numeric] = mB.element_type
448
- self.d_dtype: Type[cutlass.Numeric] = mD.element_type
417
+ self.a_dtype = mA.element_type
418
+ self.b_dtype = mB.element_type
419
+ self.d_dtype = mD.element_type if mD is not None else None
449
420
  self.c_dtype = mC.element_type if mC is not None else None
450
421
  self.sf_dtype: Optional[Type[cutlass.Numeric]] = (
451
422
  mSFA.element_type if mSFA is not None else None
452
423
  )
453
- self.a_major_mode = cutlass.utils.LayoutEnum.from_tensor(mA).mma_major_mode()
454
- self.b_major_mode = cutlass.utils.LayoutEnum.from_tensor(mB).mma_major_mode()
455
- self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
456
- self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None
424
+ self.a_layout = LayoutEnum.from_tensor(mA)
425
+ self.b_layout = LayoutEnum.from_tensor(mB)
426
+ self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
427
+ self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None
428
+ self.a_major_mode = LayoutEnum.from_tensor(mA).mma_major_mode()
429
+ self.b_major_mode = LayoutEnum.from_tensor(mB).mma_major_mode()
457
430
 
458
431
  # Check if input data types are compatible with MMA instruction
459
432
  if const_expr(self.a_dtype != self.b_dtype):
460
433
  raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
461
434
 
435
+ if const_expr(varlen_args is None):
436
+ varlen_args = VarlenArguments()
437
+ assert (varlen_args.mAIdx is not None) == self.gather_A
438
+
439
+ # Assume all strides are divisible by 128 bits except the last stride
440
+ new_stride = lambda t: tuple(
441
+ cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
442
+ for s in t.stride
443
+ )
444
+ mA, mD = [
445
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
446
+ if t is not None
447
+ else None
448
+ for t in (mA, mD)
449
+ ]
450
+
462
451
  # Setup attributes that dependent on gemm inputs
463
452
  self._setup_attributes()
464
453
 
@@ -471,67 +460,44 @@ class PersistentDenseGemmKernel:
471
460
  sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(mB.shape, self.sf_vec_size)
472
461
  mSFB = cute.make_tensor(mSFB.iterator, sfb_layout)
473
462
 
474
- if const_expr(not self.blockscaled):
475
- tiled_mma = sm100_utils.make_trivial_tiled_mma(
476
- self.a_dtype,
477
- self.a_major_mode,
478
- self.b_major_mode,
479
- self.acc_dtype,
480
- self.cta_group,
481
- self.mma_tiler[:2],
482
- )
483
- tiled_mma_sfb = None
484
- else:
485
- tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
486
- self.a_dtype,
487
- self.a_major_mode,
488
- self.b_major_mode,
489
- self.sf_dtype,
490
- self.sf_vec_size,
491
- self.cta_group,
492
- self.mma_inst_shape_mnk[:2],
493
- )
494
- tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
495
- self.a_dtype,
496
- self.a_major_mode,
497
- self.b_major_mode,
498
- self.sf_dtype,
499
- self.sf_vec_size,
500
- cute.nvgpu.tcgen05.CtaGroup.ONE,
501
- self.mma_inst_shape_mnk_sfb[:2],
502
- )
503
- atom_thr_size = cute.size(tiled_mma.thr_id.shape)
463
+ atom_thr_size = cute.size(self.tiled_mma.thr_id.shape)
504
464
 
505
- # Setup TMA load for A
506
- a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
465
+ # Setup TMA load for A & B
507
466
  a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
508
- tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
509
- a_op,
510
- mA,
511
- a_smem_layout,
512
- self.mma_tiler,
513
- tiled_mma,
514
- self.cluster_layout_vmnk.shape,
515
- internal_type=(cutlass.TFloat32 if mA.element_type is cutlass.Float32 else None),
516
- )
517
-
518
- # Setup TMA load for B
519
- b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
520
467
  b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
468
+ tma_atom_a, tma_tensor_a = None, None
469
+ if const_expr(not self.gather_A):
470
+ a_op = sm100_utils.cluster_shape_to_tma_atom_A(
471
+ self.cluster_shape_mnk, self.tiled_mma.thr_id
472
+ )
473
+ tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
474
+ a_op,
475
+ mA,
476
+ a_smem_layout,
477
+ self.mma_tiler,
478
+ self.tiled_mma,
479
+ self.cluster_layout_vmnk.shape,
480
+ internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None),
481
+ )
482
+ b_op = sm100_utils.cluster_shape_to_tma_atom_B(
483
+ self.cluster_shape_mnk, self.tiled_mma.thr_id
484
+ )
521
485
  tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
522
486
  b_op,
523
487
  mB,
524
488
  b_smem_layout,
525
489
  self.mma_tiler,
526
- tiled_mma,
490
+ self.tiled_mma,
527
491
  self.cluster_layout_vmnk.shape,
528
- internal_type=(cutlass.TFloat32 if mB.element_type is cutlass.Float32 else None),
492
+ internal_type=(cutlass.TFloat32 if mB.element_type is Float32 else None),
529
493
  )
530
494
 
495
+ tma_atom_sfa, tma_tensor_sfa = None, None
496
+ tma_atom_sfb, tma_tensor_sfb = None, None
531
497
  if const_expr(self.blockscaled):
532
498
  # Setup TMA load for SFA
533
499
  sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
534
- self.cluster_shape_mn, tiled_mma.thr_id
500
+ self.cluster_shape_mnk, self.tiled_mma.thr_id
535
501
  )
536
502
  sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
537
503
  tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
@@ -539,13 +505,13 @@ class PersistentDenseGemmKernel:
539
505
  mSFA,
540
506
  sfa_smem_layout,
541
507
  self.mma_tiler,
542
- tiled_mma,
508
+ self.tiled_mma,
543
509
  self.cluster_layout_vmnk.shape,
544
510
  internal_type=cutlass.Int16,
545
511
  )
546
512
  # Setup TMA load for SFB
547
513
  sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
548
- self.cluster_shape_mn, tiled_mma.thr_id
514
+ self.cluster_shape_mnk, self.tiled_mma.thr_id
549
515
  )
550
516
  sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
551
517
  tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
@@ -553,30 +519,31 @@ class PersistentDenseGemmKernel:
553
519
  mSFB,
554
520
  sfb_smem_layout,
555
521
  self.mma_tiler_sfb,
556
- tiled_mma_sfb,
522
+ self.tiled_mma_sfb,
557
523
  self.cluster_layout_sfb_vmnk.shape,
558
524
  internal_type=cutlass.Int16,
559
525
  )
560
- else:
561
- tma_atom_sfa, tma_tensor_sfa = None, None
562
- tma_atom_sfb, tma_tensor_sfb = None, None
563
526
 
564
- a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
565
- b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
566
- self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
527
+ self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
528
+ if const_expr(not self.gather_A):
529
+ self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
567
530
  if const_expr(self.blockscaled):
568
531
  sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
569
532
  sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
570
- self.num_tma_load_bytes += (sfa_copy_size + sfb_copy_size) * atom_thr_size
533
+ self.num_tma_load_bytes += sfa_copy_size + sfb_copy_size
534
+ self.num_tma_load_bytes *= atom_thr_size
571
535
 
572
536
  # Setup TMA store for D
573
- epi_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
574
- tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
575
- cpasync.CopyBulkTensorTileS2GOp(),
576
- mD,
577
- epi_smem_layout,
578
- self.epi_tile,
579
- )
537
+ tma_atom_d, tma_tensor_d = None, None
538
+ if const_expr(mD is not None):
539
+ epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0))
540
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
541
+ cpasync.CopyBulkTensorTileS2GOp(),
542
+ mD,
543
+ epi_smem_layout,
544
+ self.epi_tile,
545
+ )
546
+ tma_atom_c, tma_tensor_c = None, None
580
547
  if const_expr(mC is not None):
581
548
  epi_c_smem_layout = cute.slice_(self.epi_c_smem_layout_staged, (None, None, 0))
582
549
  tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
@@ -585,26 +552,19 @@ class PersistentDenseGemmKernel:
585
552
  epi_c_smem_layout,
586
553
  self.epi_tile,
587
554
  )
588
- else:
589
- tma_atom_c, tma_tensor_c = None, None
590
555
 
591
- problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.cta_tile_shape_mnk[:2]) + (
592
- mD.shape[2],
593
- )
594
- TileSchedulerCls = TileScheduler
595
- tile_sched_args = TileSchedulerArguments(
596
- problem_shape_ntile_mnl=problem_shape_ntile_mnl,
597
- raster_order=RasterOrderOption.Heuristic,
598
- group_size=8,
599
- cluster_shape_mnk=(*self.cluster_shape_mn, 1),
600
- tile_count_semaphore=tile_count_semaphore,
601
- is_persistent=True,
602
- )
556
+ epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
557
+
558
+ TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
559
+ tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
603
560
  tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
604
- grid = TileSchedulerCls.get_grid_shape(tile_sched_params, max_active_clusters)
561
+ grid = TileSchedulerCls.get_grid_shape(
562
+ tile_sched_params, scheduler_args.max_active_clusters
563
+ )
605
564
 
606
565
  self.buffer_align_bytes = 1024
607
566
 
567
+ epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
608
568
  epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
609
569
  sf_dtype = self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU
610
570
  sfa_smem_size = (
@@ -617,18 +577,18 @@ class PersistentDenseGemmKernel:
617
577
  # Define shared storage for kernel
618
578
  @cute.struct
619
579
  class SharedStorage:
620
- ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
621
- ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
622
- epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage * 2]
623
- acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
624
- acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
580
+ ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
581
+ epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
582
+ acc_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
583
+ sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
584
+ tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
625
585
  tmem_dealloc_mbar_ptr: cutlass.Int64
626
586
  tmem_holding_buf: Int32
627
- sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, 2]
628
- tile_count: cute.struct.MemRange[cutlass.Int32, 1]
629
587
  # (EPI_TILE_M, EPI_TILE_N, STAGE)
630
588
  sD: cute.struct.Align[
631
- cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)],
589
+ cute.struct.MemRange[
590
+ self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
591
+ ],
632
592
  self.buffer_align_bytes,
633
593
  ]
634
594
  sC: cute.struct.Align[
@@ -637,6 +597,7 @@ class PersistentDenseGemmKernel:
637
597
  ],
638
598
  self.buffer_align_bytes,
639
599
  ]
600
+ epi: self.epi_get_smem_struct(epilogue_params)
640
601
  # (MMA, MMA_M, MMA_K, STAGE)
641
602
  sA: cute.struct.Align[
642
603
  cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
@@ -662,10 +623,10 @@ class PersistentDenseGemmKernel:
662
623
 
663
624
  # Launch the kernel synchronously
664
625
  self.kernel(
665
- tiled_mma,
666
- tiled_mma_sfb,
626
+ self.tiled_mma,
627
+ self.tiled_mma_sfb,
667
628
  tma_atom_a,
668
- tma_tensor_a,
629
+ tma_tensor_a if const_expr(not self.gather_A) else mA,
669
630
  tma_atom_b,
670
631
  tma_tensor_b,
671
632
  tma_atom_sfa,
@@ -676,24 +637,29 @@ class PersistentDenseGemmKernel:
676
637
  tma_tensor_d,
677
638
  tma_atom_c,
678
639
  tma_tensor_c,
640
+ epilogue_params,
641
+ varlen_args.mCuSeqlensM,
642
+ varlen_args.mCuSeqlensK,
643
+ varlen_args.mTensormaps,
644
+ varlen_args.mAIdx,
679
645
  self.cluster_layout_vmnk,
680
646
  self.cluster_layout_sfb_vmnk,
681
647
  self.a_smem_layout_staged,
682
648
  self.b_smem_layout_staged,
683
649
  self.sfa_smem_layout_staged,
684
650
  self.sfb_smem_layout_staged,
685
- self.d_smem_layout_staged,
651
+ self.epi_smem_layout_staged,
686
652
  self.epi_c_smem_layout_staged,
687
653
  self.epi_tile,
688
654
  tile_sched_params,
689
655
  TileSchedulerCls,
690
- epilogue_op,
691
656
  ).launch(
692
657
  grid=grid,
693
658
  block=[self.threads_per_cta, 1, 1],
694
- cluster=(*self.cluster_shape_mn, 1),
659
+ cluster=self.cluster_shape_mnk,
695
660
  smem=self.shared_storage.size_in_bytes(),
696
661
  stream=stream,
662
+ min_blocks_per_mp=1,
697
663
  )
698
664
  return
699
665
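Putting the __call__ changes from the hunks above together, a call site roughly migrates as sketched below. Only the parameter order comes from the diff; how the EpilogueArguments and TileSchedulerOptions objects are constructed is not shown in this diff, so those lines are placeholders:

    # quack-kernels 0.2.0
    gemm(mA, mB, mD, mC, tile_count_semaphore, max_active_clusters, stream)

    # quack-kernels 0.2.2 (sketch)
    gemm(
        mA, mB, mD, mC,      # mD may now be None (no D output)
        epilogue_args,       # e.g. a GemmSm100.EpilogueArguments instance (construction not shown here)
        scheduler_args,      # a TileSchedulerOptions; carries max_active_clusters
        None,                # varlen_args -> defaults to VarlenArguments() inside __call__
        stream,
        # mSFA=..., mSFB=...  still keyword-optional for the blockscaled path
    )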
 
@@ -703,7 +669,7 @@ class PersistentDenseGemmKernel:
703
669
  self,
704
670
  tiled_mma: cute.TiledMma,
705
671
  tiled_mma_sfb: Optional[cute.TiledMma],
706
- tma_atom_a: cute.CopyAtom,
672
+ tma_atom_a: Optional[cute.CopyAtom],
707
673
  mA_mkl: cute.Tensor,
708
674
  tma_atom_b: cute.CopyAtom,
709
675
  mB_nkl: cute.Tensor,
@@ -712,37 +678,54 @@ class PersistentDenseGemmKernel:
712
678
  tma_atom_sfb: Optional[cute.CopyAtom],
713
679
  mSFB_nkl: Optional[cute.Tensor],
714
680
  tma_atom_d: Optional[cute.CopyAtom],
715
- mD_mnl: cute.Tensor,
681
+ mD_mnl: Optional[cute.Tensor],
716
682
  tma_atom_c: Optional[cute.CopyAtom],
717
683
  mC_mnl: Optional[cute.Tensor],
684
+ epilogue_params: ParamsBase,
685
+ cu_seqlens_m: Optional[cute.Tensor],
686
+ cu_seqlens_k: Optional[cute.Tensor],
687
+ tensormaps: Optional[cute.Tensor],
688
+ mAIdx: Optional[cute.Tensor],
718
689
  cluster_layout_vmnk: cute.Layout,
719
690
  cluster_layout_sfb_vmnk: Optional[cute.Layout],
720
- a_smem_layout_staged: cute.ComposedLayout,
721
- b_smem_layout_staged: cute.ComposedLayout,
722
- sfa_smem_layout_staged: Optional[cute.Layout],
723
- sfb_smem_layout_staged: Optional[cute.Layout],
724
- d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
725
- epi_c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
691
+ a_smem_layout: cute.ComposedLayout,
692
+ b_smem_layout: cute.ComposedLayout,
693
+ sfa_smem_layout: Optional[cute.Layout],
694
+ sfb_smem_layout: Optional[cute.Layout],
695
+ epi_smem_layout: Union[cute.Layout, cute.ComposedLayout, None],
696
+ epi_c_smem_layout: Union[cute.Layout, cute.ComposedLayout, None],
726
697
  epi_tile: cute.Tile,
727
698
  tile_sched_params: ParamsBase,
728
699
  TileSchedulerCls: cutlass.Constexpr[Callable],
729
- epilogue_op: cutlass.Constexpr[Callable],
730
700
  ):
731
701
  """
732
702
  GPU device kernel performing the Persistent batched GEMM computation.
733
703
  """
704
+
705
+ varlen_m = const_expr(cu_seqlens_m is not None)
706
+ varlen_k = const_expr(cu_seqlens_k is not None)
707
+ assert not (varlen_m and varlen_k)
708
+ if const_expr(self.gather_A):
709
+ assert varlen_m or varlen_k
710
+ has_D = const_expr(mD_mnl is not None)
711
+ has_C = const_expr(mC_mnl is not None)
712
+
734
713
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
735
714
 
736
- #
737
- # Prefetch tma desc
738
- #
715
+ # /////////////////////////////////////////////////////////////////////////////
716
+ # Prefetch Tma desc
717
+ # /////////////////////////////////////////////////////////////////////////////
739
718
  if warp_idx == self.tma_warp_id:
740
- cpasync.prefetch_descriptor(tma_atom_a)
741
- cpasync.prefetch_descriptor(tma_atom_b)
742
- if const_expr(self.blockscaled):
743
- cpasync.prefetch_descriptor(tma_atom_sfa)
744
- cpasync.prefetch_descriptor(tma_atom_sfb)
745
- cpasync.prefetch_descriptor(tma_atom_d)
719
+ for tma_atom in (
720
+ tma_atom_a,
721
+ tma_atom_b,
722
+ tma_atom_sfa,
723
+ tma_atom_sfb,
724
+ tma_atom_d,
725
+ tma_atom_c,
726
+ ):
727
+ if const_expr(tma_atom is not None):
728
+ cpasync.prefetch_descriptor(tma_atom)
746
729
 
747
730
  use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
748
731
 
@@ -754,13 +737,6 @@ class PersistentDenseGemmKernel:
754
737
  mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
755
738
  is_leader_cta = mma_tile_coord_v == 0
756
739
  cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
757
- block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
758
- if const_expr(self.blockscaled):
759
- block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
760
- cta_rank_in_cluster
761
- )
762
- else:
763
- block_in_cluster_coord_sfb_vmnk = None
764
740
  # Coord inside cta
765
741
  tidx, _, _ = cute.arch.thread_idx()
766
742
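As a side note on the unchanged context lines here: with the 2-CTA MMA variant, each pair of CTAs elects a leader purely from the block index (mma_tile_coord_v = bidx % 2, leader when it equals 0); later in the diff the is_leader_cta guards show that only the leader drives the AB pipeline waits and MMA issue. A tiny illustration, with the thr_id size assumed:

    thr_id_size = 2                     # 2-CTA tcgen05 MMA; would be 1 for the single-CTA variant
    for bidx in range(4):
        is_leader_cta = (bidx % thr_id_size) == 0
        print(bidx, is_leader_cta)      # blocks 0 and 2 lead their pairs, 1 and 3 follow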
 
@@ -779,100 +755,53 @@ class PersistentDenseGemmKernel:
779
755
  num_tmem_dealloc_threads = 32
780
756
  cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)
781
757
 
782
- # Initialize mainloop ab_pipeline (barrier) and states
783
- ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
784
- num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
785
- ab_pipeline_consumer_group = pipeline.CooperativeGroup(
786
- pipeline.Agent.Thread, num_tma_producer
787
- )
788
- ab_pipeline = pipeline.PipelineTmaUmma.create(
789
- barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
790
- num_stages=self.num_ab_stage,
791
- producer_group=ab_pipeline_producer_group,
792
- consumer_group=ab_pipeline_consumer_group,
793
- tx_count=self.num_tma_load_bytes,
794
- cta_layout_vmnk=cluster_layout_vmnk,
758
+ # Initialize pipelines and states
759
+ ab_pipeline = self.make_ab_pipeline(
760
+ tiled_mma=tiled_mma,
761
+ cluster_layout_vmnk=cluster_layout_vmnk,
762
+ ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
795
763
  )
796
-
797
- if const_expr(mC_mnl is not None):
798
- # Threads/warps participating in this pipeline
799
- epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
800
- # Each warp will contribute 1 to the arrive count
801
- consumer_arrive_cnt = len(self.epilog_warp_id)
802
- epi_pipeline_consumer_group = pipeline.CooperativeGroup(
803
- pipeline.Agent.Thread, consumer_arrive_cnt
804
- )
805
- c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
806
- tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
807
- epi_pipeline = pipeline.PipelineTmaAsync.create(
808
- barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
809
- num_stages=self.num_c_stage,
810
- producer_group=epi_pipeline_producer_group,
811
- consumer_group=epi_pipeline_consumer_group,
812
- tx_count=tma_copy_c_bytes,
764
+ epi_pipeline = None
765
+ if const_expr(has_C):
766
+ epi_pipeline = self.make_epi_pipeline(
767
+ c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
768
+ epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
813
769
  )
814
- else:
815
- epi_pipeline = None
816
-
817
- # Initialize acc_pipeline (barrier) and states
818
- acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
819
- num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
820
- acc_pipeline_consumer_group = pipeline.CooperativeGroup(
821
- pipeline.Agent.Thread, num_acc_consumer_threads
822
- )
823
- acc_pipeline = pipeline.PipelineUmmaAsync.create(
824
- barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
825
- num_stages=self.num_acc_stage,
826
- producer_group=acc_pipeline_producer_group,
827
- consumer_group=acc_pipeline_consumer_group,
828
- cta_layout_vmnk=cluster_layout_vmnk,
770
+ acc_pipeline = self.make_acc_pipeline(
771
+ cluster_layout_vmnk=cluster_layout_vmnk,
772
+ acc_pipeline_mbar_ptr=storage.acc_pipeline_array_ptr.data_ptr(),
829
773
  )
830
-
831
- # if const_expr(tile_sched_params.tile_count_semaphore is not None):
832
- # # Dynamic persistent scheduler
833
- # # Threads/warps participating in this pipeline
834
- # sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
835
- # cluster_size = cute.size(cluster_layout_vmnk)
836
- # # Each warp that are not the scheduler warp will contribute 1 to the arrive count
837
- # consumer_arrive_cnt = (
838
- # (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
839
- # ) * cluster_size - 1
840
- # sched_pipeline_consumer_group = pipeline.CooperativeGroup(
841
- # pipeline.Agent.Thread, consumer_arrive_cnt
842
- # )
843
- # sched_pipeline = pipeline.PipelineAsync.create(
844
- # barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
845
- # num_stages=self.sched_stage,
846
- # producer_group=sched_pipeline_producer_group,
847
- # consumer_group=sched_pipeline_consumer_group,
848
- # # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
849
- # consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
850
- # )
851
- # tile_count = storage.tile_count.get_tensor((self.sched_stage,))
852
- # else:
853
- # sched_pipeline = None
854
- # tile_count = None
774
+ sched_pipeline = None
775
+ tile_count = None
776
+ if const_expr(tile_sched_params.tile_count_semaphore is not None):
777
+ # TODO: Untested, not sure if this is right for Sm100
778
+ # Dynamic persistent scheduler
779
+ sched_pipeline = self.make_sched_pipeline(
780
+ self.cluster_shape_mnk,
781
+ sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
782
+ varlen_k=varlen_k,
783
+ )
784
+ tile_count = storage.tile_count.get_tensor((self.sched_stage,))
855
785
 
856
786
  # Setup smem tensor A/B/D
857
787
  # (MMA, MMA_M, MMA_K, STAGE)
858
- sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
788
+ sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
859
789
  # (MMA, MMA_N, MMA_K, STAGE)
860
- sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
790
+ sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
791
+ sSFA, sSFB = None, None
861
792
  if const_expr(self.blockscaled):
862
793
  # (MMA, MMA_M, MMA_K, STAGE)
863
- sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
794
+ sSFA = storage.sSFA.get_tensor(sfa_smem_layout)
864
795
  # (MMA, MMA_N, MMA_K, STAGE)
865
- sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
866
- else:
867
- sSFA, sSFB = None, None
868
- # (EPI_TILE_M, EPI_TILE_N, STAGE)
869
- sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
870
- if const_expr(mC_mnl is not None):
871
- sC = storage.sC.get_tensor(
872
- epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
873
- )
874
- else:
875
- sC = None
796
+ sSFB = storage.sSFB.get_tensor(sfb_smem_layout)
797
+ sD = None
798
+ if const_expr(has_D):
799
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
800
+ sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
801
+ sC = None
802
+ if const_expr(has_C):
803
+ sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
804
+ epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
876
805
 
877
806
  thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
878
807
  thr_mma_sfb = (
@@ -884,26 +813,51 @@ class PersistentDenseGemmKernel:
884
813
  # (MMA, MMA_M, MMA_N, STAGE)
885
814
  tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
886
815
 
887
- tmem_ptr_read_threads = cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id))
888
- tmem_alloc_barrier = pipeline.NamedBarrier(
889
- barrier_id=self.tmem_ptr_sync_bar_id, num_threads=tmem_ptr_read_threads
816
+ # Get tensormap buffer address
817
+ tensormap_manager, tensormap_ab_ptrs, tensormap_d_ptr, tensormap_epi_ptrs = (
818
+ self.tensormap_init(tensormaps, varlen_m, varlen_k, has_D, warp_idx)
890
819
  )
891
820
 
892
- TileSchedulerCls = partial(TileSchedulerCls.create, tile_sched_params)
893
- k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.mma_tiler[2])
821
+ TileSchedulerCls = partial(
822
+ TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
823
+ )
894
824
 
895
- if const_expr(mC_mnl is not None):
825
+ tmem_alloc_barrier = pipeline.NamedBarrier(
826
+ barrier_id=int(NamedBarrierGemm.TmemPtr),
827
+ num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)),
828
+ )
829
+ epi_load_barrier = None
830
+ if const_expr(has_C):
896
831
  epi_load_barrier = pipeline.NamedBarrier(
897
- barrier_id=int(self.epilog_load_bar_id), num_threads=2 * cute.arch.WARP_SIZE
832
+ barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE
898
833
  )
899
- else:
900
- epi_load_barrier = None
901
834
 
902
835
  #
903
836
  # Specialized TMA load warp
904
837
  #
905
838
  if warp_idx == self.tma_warp_id:
839
+ if const_expr(varlen_k):
840
+ # initialize tensormap for A & B
841
+ if const_expr(not self.gather_A):
842
+ tensormap_manager.init_tensormap_from_atom(
843
+ tma_atom_a,
844
+ tensormap_ab_ptrs[0],
845
+ is_manager_warp=True,
846
+ )
847
+ tensormap_manager.init_tensormap_from_atom(
848
+ tma_atom_b,
849
+ tensormap_ab_ptrs[1],
850
+ is_manager_warp=True,
851
+ )
906
852
  # Compute multicast mask for A/B buffer full
853
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
854
+ block_in_cluster_coord_sfb_vmnk = None
855
+ if const_expr(self.blockscaled):
856
+ block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
857
+ cta_rank_in_cluster
858
+ )
859
+ a_mcast_mask, b_mcast_mask = None, None
860
+ sfa_mcast_mask, sfb_mcast_mask = None, None
907
861
  if const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
908
862
  a_mcast_mask = cpasync.create_tma_multicast_mask(
909
863
  cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
@@ -918,28 +872,45 @@ class PersistentDenseGemmKernel:
918
872
  sfb_mcast_mask = cpasync.create_tma_multicast_mask(
919
873
  cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1
920
874
  )
921
- else:
922
- sfa_mcast_mask, sfb_mcast_mask = None, None
923
- else:
924
- a_mcast_mask, b_mcast_mask = None, None
925
- sfa_mcast_mask, sfb_mcast_mask = None, None
926
875
 
927
876
  # Persistent tile scheduling loop
928
- tile_scheduler = TileSchedulerCls()
877
+ is_scheduler_warp = True
878
+ if const_expr(cute.size(cluster_layout_vmnk) > 1):
879
+ is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0
880
+ tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
929
881
  work_tile = tile_scheduler.initial_work_tile_info()
930
882
  ab_producer_state = pipeline.make_pipeline_state(
931
- pipeline.PipelineUserType.Producer, self.num_ab_stage
883
+ pipeline.PipelineUserType.Producer, self.ab_stage
932
884
  )
933
- do_epi_load_barrier_arrive = cutlass.Boolean(True)
885
+ if const_expr(varlen_k):
886
+ # wait tensormap initialization complete before update
887
+ tensormap_manager.fence_tensormap_initialization()
888
+ # batch index of last tile
889
+ last_batch_idx = cutlass.Int32(-1)
890
+ do_epi_load_barrier_arrive = Boolean(True)
934
891
  while work_tile.is_valid_tile:
935
- # Get tile coord from tile scheduler
936
892
  tile_coord_mnkl = work_tile.tile_idx
893
+ batch_idx = tile_coord_mnkl[3]
894
+ if const_expr(varlen_k):
895
+ is_group_changed = batch_idx != last_batch_idx
896
+ last_batch_idx = batch_idx
897
+ if is_group_changed:
898
+ self.tensormap_update_AB(
899
+ tensormap_manager,
900
+ tensormap_ab_ptrs,
901
+ cu_seqlens_k,
902
+ batch_idx,
903
+ is_manager_warp=True,
904
+ )
905
+ # ///////////////////////////////////////////////////////////////////////////
906
+ # Local_tile partition global tensors
907
+ # ///////////////////////////////////////////////////////////////////////////
937
908
  mma_tile_coord_mnl = (
938
909
  tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
939
910
  tile_coord_mnkl[1],
940
911
  tile_coord_mnkl[3],
941
912
  )
942
- # Local_tile partition global tensors
913
+ # TODO: varlen_m
943
914
  # (bM, bK, RestK)
944
915
  gA_mkl = cute.local_tile(
945
916
  mA_mkl,
@@ -1007,7 +978,7 @@ class PersistentDenseGemmKernel:
1007
978
  sfa_cta_layout = a_cta_layout
1008
979
  # ((atom_v, rest_v), STAGE)
1009
980
  # ((atom_v, rest_v), RestK)
1010
- tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition(
981
+ tAsSFA, tAgSFA = cpasync.tma_partition(
1011
982
  tma_atom_sfa,
1012
983
  block_in_cluster_coord_vmnk[2],
1013
984
  sfa_cta_layout,
@@ -1022,7 +993,7 @@ class PersistentDenseGemmKernel:
1022
993
  )
1023
994
  # ((atom_v, rest_v), STAGE)
1024
995
  # ((atom_v, rest_v), RestK)
1025
- tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition(
996
+ tBsSFB, tBgSFB = cpasync.tma_partition(
1026
997
  tma_atom_sfb,
1027
998
  block_in_cluster_coord_sfb_vmnk[1],
1028
999
  sfb_cta_layout,
@@ -1060,12 +1031,15 @@ class PersistentDenseGemmKernel:
1060
1031
  # with loading A and B.
1061
1032
  if do_epi_load_barrier_arrive:
1062
1033
  epi_load_barrier.arrive()
1063
- do_epi_load_barrier_arrive = cutlass.Boolean(False)
1034
+ do_epi_load_barrier_arrive = Boolean(False)
1064
1035
  # Advance to next tile
1036
+ tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
1065
1037
  tile_scheduler.advance_to_next_work()
1066
1038
  work_tile = tile_scheduler.get_current_work()
1067
1039
  # Wait A/B buffer empty
1068
1040
  ab_pipeline.producer_tail(ab_producer_state)
1041
+ if is_scheduler_warp:
1042
+ tile_scheduler.producer_tail()
1069
1043
 
1070
1044
  #
1071
1045
  # Specialized TMA epi load warp
@@ -1073,15 +1047,16 @@ class PersistentDenseGemmKernel:
1073
1047
  if const_expr(mC_mnl is not None):
1074
1048
  if warp_idx == self.tma_epi_warp_id:
1075
1049
  epi_producer_state = pipeline.make_pipeline_state(
1076
- pipeline.PipelineUserType.Producer, self.num_c_stage
1050
+ pipeline.PipelineUserType.Producer, self.epi_c_stage
1077
1051
  )
1078
- do_epi_load_barrier_wait = cutlass.Boolean(True)
1052
+ do_epi_load_barrier_wait = Boolean(True)
1079
1053
  # Persistent tile scheduling loop
1080
1054
  tile_scheduler = TileSchedulerCls()
1081
1055
  work_tile = tile_scheduler.initial_work_tile_info()
1082
1056
  while work_tile.is_valid_tile:
1083
1057
  # Get tile coord from tile scheduler
1084
1058
  tile_coord_mnkl = work_tile.tile_idx
1059
+ # TODO: varlen_m
1085
1060
  mma_tile_coord_mnl = (
1086
1061
  tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
1087
1062
  tile_coord_mnkl[1],
@@ -1102,7 +1077,7 @@ class PersistentDenseGemmKernel:
1102
1077
  bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
1103
1078
  if do_epi_load_barrier_wait:
1104
1079
  epi_load_barrier.arrive_and_wait()
1105
- do_epi_load_barrier_wait = cutlass.Boolean(False)
1080
+ do_epi_load_barrier_wait = Boolean(False)
1106
1081
  epi_tile_num = const_expr(cute.size(bGS_gC, mode=[1]))
1107
1082
  for subtile_idx in cutlass.range(epi_tile_num, unroll=1):
1108
1083
  epi_pipeline.producer_acquire(epi_producer_state)
@@ -1149,10 +1124,9 @@ class PersistentDenseGemmKernel:
1149
1124
  tiled_mma,
1150
1125
  self.mma_tiler,
1151
1126
  self.sf_vec_size,
1152
- cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
1127
+ cute.slice_(sfa_smem_layout, (None, None, None, 0)),
1153
1128
  )
1154
1129
  tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
1155
-
1156
1130
  # Make SFB tmem tensor
1157
1131
  sfb_tmem_ptr = cute.recast_ptr(
1158
1132
  acc_tmem_ptr
@@ -1165,7 +1139,7 @@ class PersistentDenseGemmKernel:
1165
1139
  tiled_mma,
1166
1140
  self.mma_tiler,
1167
1141
  self.sf_vec_size,
1168
- cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
1142
+ cute.slice_(sfb_smem_layout, (None, None, None, 0)),
1169
1143
  )
1170
1144
  tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
1171
1145
  # Partition for S2T copy of SFA/SFB
@@ -1183,11 +1157,12 @@ class PersistentDenseGemmKernel:
1183
1157
  tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None
1184
1158
  tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None
1185
1159
 
1160
+ k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.mma_tiler[2])
1186
1161
  # Persistent tile scheduling loop
1187
1162
  tile_scheduler = TileSchedulerCls()
1188
1163
  work_tile = tile_scheduler.initial_work_tile_info()
1189
1164
  ab_consumer_state = pipeline.make_pipeline_state(
1190
- pipeline.PipelineUserType.Consumer, self.num_ab_stage
1165
+ pipeline.PipelineUserType.Consumer, self.ab_stage
1191
1166
  )
1192
1167
  acc_producer_state = pipeline.make_pipeline_state(
1193
1168
  pipeline.PipelineUserType.Producer, self.num_acc_stage
@@ -1241,11 +1216,29 @@ class PersistentDenseGemmKernel:
1241
1216
  # (MMA, MMA_M, MMA_N, STAGE)
1242
1217
  tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
1243
1218
 
1244
- epilog_threads = cute.arch.WARP_SIZE * len(self.epilog_warp_id)
1245
1219
  epilogue_barrier = pipeline.NamedBarrier(
1246
- barrier_id=self.epilog_sync_bar_id, num_threads=epilog_threads
1220
+ barrier_id=int(NamedBarrierGemm.Epilogue),
1221
+ num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
1247
1222
  )
1248
1223
 
1224
+ is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
1225
+ if const_expr(varlen_m):
1226
+ # initialize tensormap for D
1227
+ if const_expr(has_D):
1228
+ tensormap_manager.init_tensormap_from_atom(
1229
+ tma_atom_d,
1230
+ tensormap_d_ptr,
1231
+ is_manager_warp=is_tma_warp,
1232
+ )
1233
+ for tma_atom, tensormap_epi_ptr in zip(
1234
+ self.epi_get_tma_atoms(epilogue_params), tensormap_epi_ptrs
1235
+ ):
1236
+ tensormap_manager.init_tensormap_from_atom(
1237
+ tma_atom,
1238
+ tensormap_epi_ptr,
1239
+ is_manager_warp=is_tma_warp,
1240
+ )
1241
+
1249
1242
  # Partition for epilogue
1250
1243
  epi_tidx = tidx
1251
1244
  tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
@@ -1256,6 +1249,7 @@ class PersistentDenseGemmKernel:
1256
1249
  tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
1257
1250
  tiled_copy_t2r, tTR_rD, epi_tidx, sD
1258
1251
  )
1252
+ tRS_rC, tSR_rC = None, None
1259
1253
  if const_expr(mC_mnl is not None):
1260
1254
  tTR_rC = cute.make_fragment_like(tTR_rD, self.c_dtype)
1261
1255
  tiled_copy_s2r, tSR_rC, tSR_sC = self.epilog_smem_copy_and_partition(
@@ -1272,22 +1266,33 @@ class PersistentDenseGemmKernel:
1272
1266
  acc_consumer_state = pipeline.make_pipeline_state(
1273
1267
  pipeline.PipelineUserType.Consumer, self.num_acc_stage
1274
1268
  )
1275
- # Threads/warps participating in tma store pipeline
1276
- d_producer_group = pipeline.CooperativeGroup(
1277
- pipeline.Agent.Thread,
1278
- 32 * len(self.epilog_warp_id),
1279
- 32 * len(self.epilog_warp_id),
1280
- )
1281
- d_pipeline = pipeline.PipelineTmaStore.create(
1282
- num_stages=self.num_d_stage, producer_group=d_producer_group
1283
- )
1269
+ epi_store_pipeline = self.make_epi_store_pipeline()
1284
1270
  epi_read_state = pipeline.make_pipeline_state(
1285
- pipeline.PipelineUserType.Consumer, self.num_c_stage
1271
+ pipeline.PipelineUserType.Consumer, self.epi_c_stage
1286
1272
  )
1287
-
1273
+ if const_expr(varlen_m):
1274
+ # wait tensormap initialization complete before update
1275
+ tensormap_manager.fence_tensormap_initialization()
1276
+ # batch index of last tile
1277
+ last_batch_idx = cutlass.Int32(-1)
1288
1278
  while work_tile.is_valid_tile:
1289
1279
  # Get tile coord from tile scheduler
1290
1280
  tile_coord_mnkl = work_tile.tile_idx
1281
+ batch_idx = tile_coord_mnkl[3]
1282
+ if const_expr(varlen_m):
1283
+ is_group_changed = batch_idx != last_batch_idx
1284
+ last_batch_idx = batch_idx
1285
+ if is_group_changed:
1286
+ self.tensormap_update_D_epi(
1287
+ tensormap_manager,
1288
+ tensormap_d_ptr,
1289
+ tensormap_epi_ptrs,
1290
+ epilogue_params,
1291
+ cu_seqlens_m,
1292
+ batch_idx,
1293
+ is_manager_warp=is_tma_warp,
1294
+ )
1295
+
1291
1296
  mma_tile_coord_mnl = (
1292
1297
  tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
1293
1298
  tile_coord_mnkl[1],
@@ -1311,6 +1316,25 @@ class PersistentDenseGemmKernel:
1311
1316
  # Wait for accumulator buffer full
1312
1317
  acc_pipeline.consumer_wait(acc_consumer_state)
1313
1318
 
1319
+ tma_desc_d_ptr, tma_desc_epi_ptrs = None, [None] * self.num_epi_tensormaps
1320
+ if const_expr(varlen_m):
1321
+ # ensure the update to tensormap has completed before using it
1322
+ if is_group_changed and is_tma_warp:
1323
+ if const_expr(has_D):
1324
+ tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
1325
+ for tensormap_epi_ptr in tensormap_epi_ptrs:
1326
+ tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
1327
+ if const_expr(has_D):
1328
+ tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
1329
+ tensormap_d_ptr, cute.AddressSpace.generic
1330
+ )
1331
+ tma_desc_epi_ptrs = [
1332
+ tensormap_manager.get_tensormap_ptr(
1333
+ tensormap_epi_ptr, cute.AddressSpace.generic
1334
+ )
1335
+ for tensormap_epi_ptr in tensormap_epi_ptrs
1336
+ ]
1337
+
1314
1338
  tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
1315
1339
  bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
1316
1340
 
@@ -1323,7 +1347,6 @@ class PersistentDenseGemmKernel:
1323
1347
  cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
1324
1348
  # Convert to D type
1325
1349
  acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
1326
- acc_vec = epilogue_op(acc_vec)
1327
1350
  if const_expr(mC_mnl is not None):
1328
1351
  epi_pipeline.consumer_wait(epi_read_state)
1329
1352
  cute.copy(
@@ -1340,7 +1363,7 @@ class PersistentDenseGemmKernel:
1340
1363
  acc_vec = acc_vec + tRS_rC.load().to(self.acc_dtype)
1341
1364
  tRS_rD.store(acc_vec.to(self.d_dtype))
1342
1365
  # Store D to shared memory
1343
- d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage
1366
+ d_buffer = (num_prev_subtiles + subtile_idx) % self.epi_stage
1344
1367
  cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
1345
1368
  # Fence and barrier to make sure shared memory store is visible to TMA store
1346
1369
  cute.arch.fence_proxy(
@@ -1348,11 +1371,11 @@ class PersistentDenseGemmKernel:
1348
1371
  )
1349
1372
  epilogue_barrier.arrive_and_wait()
1350
1373
  # TMA store D to global memory
1351
- if warp_idx == self.epilog_warp_id[0]:
1374
+ if is_tma_warp:
1352
1375
  cute.copy(tma_atom_d, bSG_sD[None, d_buffer], bSG_gD[None, subtile_idx])
1353
1376
  # Fence and barrier to make sure shared memory store is visible to TMA store
1354
- d_pipeline.producer_commit()
1355
- d_pipeline.producer_acquire()
1377
+ epi_store_pipeline.producer_commit()
1378
+ epi_store_pipeline.producer_acquire()
1356
1379
  epilogue_barrier.arrive_and_wait()
1357
1380
 
1358
1381
  # Async arrive accumulator buffer empty
@@ -1369,7 +1392,7 @@ class PersistentDenseGemmKernel:
1369
1392
  cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
1370
1393
  epilogue_barrier.arrive_and_wait()
1371
1394
  if warp_idx == self.epilog_warp_id[0]:
1372
- if use_2cta_instrs:
1395
+ if const_expr(use_2cta_instrs):
1373
1396
  cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1)
1374
1397
  cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
1375
1398
  cute.arch.dealloc_tmem(
@@ -1377,7 +1400,8 @@ class PersistentDenseGemmKernel:
1377
1400
  )
1378
1401
 
1379
1402
  # Wait for D store complete
1380
- d_pipeline.producer_tail()
1403
+ if is_tma_warp:
1404
+ epi_store_pipeline.producer_tail()
1381
1405
 
1382
1406
  @cute.jit
1383
1407
  def load_AB(
@@ -1407,7 +1431,7 @@ class PersistentDenseGemmKernel:
1407
1431
  assert all(x is not None for x in (tma_atom_sfb, tBgSFB, tBsSFB))
1408
1432
  k_tile_cnt = cute.size(tAgA, mode=[1])
1409
1433
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
1410
- peek_ab_empty_status = cutlass.Boolean(True)
1434
+ peek_ab_empty_status = Boolean(True)
1411
1435
  if 0 < k_tile_cnt:
1412
1436
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1413
1437
  # /////////////////////////////////////////////////////////////////////////
@@ -1449,7 +1473,7 @@ class PersistentDenseGemmKernel:
1449
1473
  # Mainloop pipeline's producer commit is a NOP
1450
1474
  ab_pipeline.producer_commit(ab_producer_state)
1451
1475
  ab_producer_state.advance()
1452
- peek_ab_empty_status = cutlass.Boolean(True)
1476
+ peek_ab_empty_status = Boolean(True)
1453
1477
  if k_tile + 1 < k_tile_cnt:
1454
1478
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1455
1479
  return ab_producer_state
@@ -1466,7 +1490,7 @@ class PersistentDenseGemmKernel:
1466
1490
  tCrB: cute.Tensor,
1467
1491
  acc: cute.Tensor,
1468
1492
  k_tile_cnt: Int32,
1469
- is_leader_cta: cutlass.Boolean,
1493
+ is_leader_cta: Boolean,
1470
1494
  tiled_copy_s2t_sfa: Optional[cute.TiledCopy] = None,
1471
1495
  tiled_copy_s2t_sfb: Optional[cute.TiledCopy] = None,
1472
1496
  tCsSFA_compact_s2t: Optional[cute.Tensor] = None,
@@ -1480,7 +1504,7 @@ class PersistentDenseGemmKernel:
1480
1504
  assert all(x is not None for x in (tCsSFA_compact_s2t, tCsSFB_compact_s2t))
1481
1505
  assert all(x is not None for x in (tCtSFA_compact_s2t, tCtSFB_compact_s2t))
1482
1506
  # Peek (try_wait) AB buffer full for k_tile = 0
1483
- peek_ab_full_status = cutlass.Boolean(True)
1507
+ peek_ab_full_status = Boolean(True)
1484
1508
  if 0 < k_tile_cnt and is_leader_cta:
1485
1509
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
1486
1510
  # Wait for accumulator buffer empty
@@ -1509,7 +1533,7 @@ class PersistentDenseGemmKernel:
1509
1533
  ab_pipeline.consumer_release(ab_consumer_state)
1510
1534
  ab_consumer_state.advance()
1511
1535
  # Peek (try_wait) AB buffer full for k_tile = k_tile + 1
1512
- peek_ab_full_status = cutlass.Boolean(True)
1536
+ peek_ab_full_status = Boolean(True)
1513
1537
  if k_tile + 1 < k_tile_cnt and is_leader_cta:
1514
1538
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
1515
1539
  # Async arrive accumulator buffer full
@@ -1560,7 +1584,7 @@ class PersistentDenseGemmKernel:
1560
1584
  tidx: Int32,
1561
1585
  tAcc: cute.Tensor,
1562
1586
  epi_tile: cute.Tile,
1563
- use_2cta_instrs: Union[cutlass.Boolean, bool],
1587
+ use_2cta_instrs: Union[Boolean, bool],
1564
1588
  ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1565
1589
  """
1566
1590
  Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
@@ -1708,6 +1732,22 @@ class PersistentDenseGemmKernel:
1708
1732
  )
1709
1733
  return bSG_sD, bSG_gD
1710
1734
 
1735
+ def make_acc_pipeline(
1736
+ self, cluster_layout_vmnk: cute.Layout, acc_pipeline_mbar_ptr: cute.Pointer
1737
+ ):
1738
+ acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1739
+ num_acc_consumer_threads = self.num_epi_warps * (2 if self.use_2cta_instrs else 1)
1740
+ acc_pipeline_consumer_group = pipeline.CooperativeGroup(
1741
+ pipeline.Agent.Thread, num_acc_consumer_threads
1742
+ )
1743
+ return pipeline.PipelineUmmaAsync.create(
1744
+ barrier_storage=acc_pipeline_mbar_ptr,
1745
+ num_stages=self.num_acc_stage,
1746
+ producer_group=acc_pipeline_producer_group,
1747
+ consumer_group=acc_pipeline_consumer_group,
1748
+ cta_layout_vmnk=cluster_layout_vmnk,
1749
+ )
1750
+
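The new make_acc_pipeline helper packages the accumulator-pipeline setup: a single-thread producer agent for the MMA side, and a consumer group sized for all epilogue warps, doubled when 2-CTA UMMA instructions are used since the peer CTA's epilogue warps also drain the accumulator. A tiny sketch of the consumer-count arithmetic, with an assumed num_epi_warps of 4:

    # Mirrors the consumer-count expression above; num_epi_warps = 4 is an assumed
    # example value (the real one is a class attribute of the kernel).
    num_epi_warps = 4
    for use_2cta_instrs in (False, True):
        num_acc_consumer_threads = num_epi_warps * (2 if use_2cta_instrs else 1)
        print(f"use_2cta_instrs={use_2cta_instrs}: consumer count = {num_acc_consumer_threads}")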
1711
1751
  @staticmethod
1712
1752
  def _compute_stages(
1713
1753
  tiled_mma: cute.TiledMma,
@@ -1717,8 +1757,8 @@ class PersistentDenseGemmKernel:
1717
1757
  epi_tile: cute.Tile,
1718
1758
  d_dtype: Type[cutlass.Numeric],
1719
1759
  c_dtype: Optional[Type[cutlass.Numeric]],
1720
- d_layout: cutlass.utils.LayoutEnum,
1721
- c_layout: Optional[cutlass.utils.LayoutEnum],
1760
+ d_layout: LayoutEnum,
1761
+ c_layout: Optional[LayoutEnum],
1722
1762
  sf_dtype: Optional[Type[cutlass.Numeric]],
1723
1763
  sf_vec_size: Optional[int],
1724
1764
  smem_capacity: int,
@@ -1739,7 +1779,7 @@ class PersistentDenseGemmKernel:
1739
1779
  :param d_dtype: Data type of operand D (output).
1740
1780
  :type d_dtype: type[cutlass.Numeric]
1741
1781
  :param d_layout: Layout enum of operand D.
1742
- :type d_layout: cutlass.utils.LayoutEnum
1782
+ :type d_layout: LayoutEnum
1743
1783
  :param smem_capacity: Total available shared memory capacity in bytes.
1744
1784
  :type smem_capacity: int
1745
1785
  :param occupancy: Target number of CTAs per SM (occupancy).
@@ -1757,8 +1797,8 @@ class PersistentDenseGemmKernel:
1757
1797
  num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
1758
1798
 
1759
1799
  # Default D stages
1760
- num_d_stage = 2
1761
- num_c_stage = 2 if c_dtype is not None else 0
1800
+ epi_stage = 2
1801
+ epi_c_stage = 2 if c_dtype is not None else 0
1762
1802
 
1763
1803
  # Calculate smem layout and size for one stage of A, B, and C
1764
1804
  a_smem_layout_staged_one = sm100_utils.make_smem_layout_a(
@@ -1802,28 +1842,28 @@ class PersistentDenseGemmKernel:
1802
1842
  ) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
1803
1843
  mbar_helpers_bytes = 1024
1804
1844
  d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
1805
- epi_bytes = d_bytes_per_stage * num_d_stage
1845
+ epi_bytes = d_bytes_per_stage * epi_stage
1806
1846
  if const_expr(c_dtype is not None):
1807
1847
  c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
1808
- epi_bytes += c_bytes_per_stage * num_c_stage
1848
+ epi_bytes += c_bytes_per_stage * epi_c_stage
1809
1849
 
1810
1850
  # Calculate A/B/SFA/SFB stages:
1811
1851
  # Start with total smem per CTA (capacity / occupancy)
1812
1852
  # Subtract reserved bytes and initial C stages bytes
1813
1853
  # Divide remaining by bytes needed per A/B/SFA/SFB stage
1814
- num_ab_stage = (
1854
+ ab_stage = (
1815
1855
  smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)
1816
1856
  ) // ab_bytes_per_stage
1817
1857
 
1818
1858
  # Refine epilogue stages:
1819
1859
  # Calculate remaining smem after allocating for A/B stages and reserved bytes
1820
1860
  # Add remaining unused smem to epilogue
1821
- num_d_stage += (
1861
+ epi_stage += (
1822
1862
  smem_capacity
1823
- - occupancy * ab_bytes_per_stage * num_ab_stage
1863
+ - occupancy * ab_bytes_per_stage * ab_stage
1824
1864
  - occupancy * (mbar_helpers_bytes + epi_bytes)
1825
1865
  ) // (occupancy * d_bytes_per_stage)
1826
- return num_acc_stage, num_ab_stage, num_d_stage, num_c_stage
1866
+ return num_acc_stage, ab_stage, epi_stage, epi_c_stage
1827
1867
 
1828
1868
  @staticmethod
1829
1869
  def _compute_num_tmem_alloc_cols(
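The stage computation itself is unchanged apart from the renames (ab_stage, epi_stage, epi_c_stage): reserve the mbarrier helper bytes plus the initial epilogue stages, divide the remaining per-CTA smem by the per-stage A/B (plus scale-factor) footprint, then hand any leftover bytes back to extra D stages. A worked example with made-up byte counts (the real ones come from cute.size_in_bytes on the staged layouts):

    # Worked example of the smem stage budgeting; all byte counts are illustrative.
    smem_capacity = 232448          # assumed per-SM budget (~227 KiB)
    occupancy = 1                   # CTAs per SM
    ab_bytes_per_stage = 32768      # A + B (+ scale factors when blockscaled), made up
    d_bytes_per_stage = 8192        # D epilogue bytes per stage, made up
    c_bytes_per_stage = 16384       # C epilogue bytes per stage, made up
    mbar_helpers_bytes = 1024

    epi_stage = 2
    epi_c_stage = 2                 # 0 when there is no C operand
    epi_bytes = d_bytes_per_stage * epi_stage + c_bytes_per_stage * epi_c_stage

    ab_stage = (
        smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)
    ) // ab_bytes_per_stage

    # Hand leftover smem back to extra D stages, as in the refinement step above.
    epi_stage += (
        smem_capacity
        - occupancy * ab_bytes_per_stage * ab_stage
        - occupancy * (mbar_helpers_bytes + epi_bytes)
    ) // (occupancy * d_bytes_per_stage)

    print(ab_stage, epi_stage, epi_c_stage)   # -> 5 4 2 with these numbers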
@@ -1880,7 +1920,7 @@ class PersistentDenseGemmKernel:
1880
1920
  }:
1881
1921
  is_valid = False
1882
1922
  if (
1883
- acc_dtype not in {cutlass.Float32, cutlass.Float16, Int32}
1923
+ acc_dtype not in {Float32, cutlass.Float16, Int32}
1884
1924
  or acc_dtype == cutlass.Float16
1885
1925
  and ab_dtype not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}
1886
1926
  or acc_dtype == Int32
@@ -1888,10 +1928,10 @@ class PersistentDenseGemmKernel:
1888
1928
  ):
1889
1929
  is_valid = False
1890
1930
  if (
1891
- acc_dtype == cutlass.Float32
1931
+ acc_dtype == Float32
1892
1932
  and d_dtype
1893
1933
  not in {
1894
- cutlass.Float32,
1934
+ Float32,
1895
1935
  cutlass.Float16,
1896
1936
  cutlass.BFloat16,
1897
1937
  cutlass.Float8E4M3FN,
@@ -1911,7 +1951,7 @@ class PersistentDenseGemmKernel:
1911
1951
  not in {
1912
1952
  cutlass.BFloat16,
1913
1953
  cutlass.Float16,
1914
- cutlass.Float32,
1954
+ Float32,
1915
1955
  Int32,
1916
1956
  cutlass.Int8,
1917
1957
  cutlass.Uint8,
@@ -1964,7 +2004,7 @@ class PersistentDenseGemmKernel:
1964
2004
 
1965
2005
  # Check valid d_dtype
1966
2006
  if d_dtype not in {
1967
- cutlass.Float32,
2007
+ Float32,
1968
2008
  cutlass.Float16,
1969
2009
  cutlass.BFloat16,
1970
2010
  cutlass.Float8E5M2,
@@ -2004,7 +2044,6 @@ class PersistentDenseGemmKernel:
2004
2044
 
2005
2045
  @staticmethod
2006
2046
  def is_valid_mma_tiler_and_cluster_shape(
2007
- use_2cta_instrs: bool,
2008
2047
  mma_tiler_mn: Tuple[int, int],
2009
2048
  cluster_shape_mn: Tuple[int, int],
2010
2049
  blockscaled: bool,
@@ -2012,8 +2051,6 @@ class PersistentDenseGemmKernel:
2012
2051
  """
2013
2052
  Check if the mma tiler and cluster shape are valid
2014
2053
 
2015
- :param use_2cta_instrs: Whether to use 2 CTA groups
2016
- :type use_2cta_instrs: bool
2017
2054
  :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
2018
2055
  :type mma_tiler_mn: Tuple[int, int]
2019
2056
  :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
@@ -2024,10 +2061,7 @@ class PersistentDenseGemmKernel:
2024
2061
  """
2025
2062
  is_valid = True
2026
2063
  # Skip invalid mma tile shape
2027
- if not (
2028
- (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
2029
- or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
2030
- ):
2064
+ if mma_tiler_mn[0] not in [64, 128, 256]:
2031
2065
  is_valid = False
2032
2066
  if not blockscaled:
2033
2067
  if mma_tiler_mn[1] not in range(32, 257, 32):
@@ -2035,9 +2069,6 @@ class PersistentDenseGemmKernel:
2035
2069
  else:
2036
2070
  if mma_tiler_mn[1] not in [128, 256]:
2037
2071
  is_valid = False
2038
- # Skip illegal cluster shape
2039
- if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
2040
- is_valid = False
2041
2072
  # Skip invalid cluster shape
2042
2073
  is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
2043
2074
  if (
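With use_2cta_instrs removed, the tiler check collapses to a single membership test on the MMA M extent and the old 2-CTA divisibility requirement on the cluster disappears; only the power-of-two style constraints on the cluster shape remain (their condition continues in the next hunk). A rough sketch of the surviving predicate (the cluster-size upper bound and blockscaled details are abbreviated):

    def is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn, blockscaled):
        # Sketch only; the real method also bounds the total cluster size.
        is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
        if mma_tiler_mn[0] not in (64, 128, 256):
            return False
        if not blockscaled:
            if mma_tiler_mn[1] not in range(32, 257, 32):
                return False
        elif mma_tiler_mn[1] not in (128, 256):
            return False
        if not (is_power_of_2(cluster_shape_mn[0]) and is_power_of_2(cluster_shape_mn[1])):
            return False
        return True

    print(is_valid_mma_tiler_and_cluster_shape((256, 256), (2, 1), blockscaled=False))  # True
    print(is_valid_mma_tiler_and_cluster_shape((96, 128), (2, 1), blockscaled=False))   # False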
@@ -2113,7 +2144,6 @@ class PersistentDenseGemmKernel:
2113
2144
  ab_dtype: Type[cutlass.Numeric],
2114
2145
  acc_dtype: Type[cutlass.Numeric],
2115
2146
  d_dtype: Type[cutlass.Numeric],
2116
- use_2cta_instrs: bool,
2117
2147
  mma_tiler_mn: Tuple[int, int],
2118
2148
  cluster_shape_mn: Tuple[int, int],
2119
2149
  m: int,
@@ -2133,8 +2163,6 @@ class PersistentDenseGemmKernel:
2133
2163
  :type acc_dtype: Type[cutlass.Numeric]
2134
2164
  :param d_dtype: The data type of the output tensor
2135
2165
  :type d_dtype: Type[cutlass.Numeric]
2136
- :param use_2cta_instrs: Whether to use 2 CTA groups
2137
- :type use_2cta_instrs: bool
2138
2166
  :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
2139
2167
  :type mma_tiler_mn: Tuple[int, int]
2140
2168
  :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
@@ -2159,15 +2187,15 @@ class PersistentDenseGemmKernel:
2159
2187
  """
2160
2188
  can_implement = True
2161
2189
  # Skip unsupported types
2162
- if not PersistentDenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, d_dtype):
2190
+ if not GemmSm100.is_valid_dtypes(ab_dtype, acc_dtype, d_dtype):
2163
2191
  can_implement = False
2164
2192
  # Skip invalid mma tile shape and cluster shape
2165
- if not PersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(
2166
- use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, blockscaled=False
2193
+ if not GemmSm100.is_valid_mma_tiler_and_cluster_shape(
2194
+ mma_tiler_mn, cluster_shape_mn, blockscaled=False
2167
2195
  ):
2168
2196
  can_implement = False
2169
2197
  # Skip illegal problem shape for load/store alignment
2170
- if not PersistentDenseGemmKernel.is_valid_tensor_alignment(
2198
+ if not GemmSm100.is_valid_tensor_alignment(
2171
2199
  m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major
2172
2200
  ):
2173
2201
  can_implement = False
@@ -2186,7 +2214,6 @@ def run(
2186
2214
  c_major: str,
2187
2215
  mma_tiler_mn: Tuple[int, int] = (256, 256),
2188
2216
  cluster_shape_mn: Tuple[int, int] = (2, 1),
2189
- use_2cta_instrs: bool = True,
2190
2217
  tolerance: float = 1e-01,
2191
2218
  warmup_iterations: int = 0,
2192
2219
  iterations: int = 1,
@@ -2215,9 +2242,6 @@ def run(
2215
2242
  :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
2216
2243
  default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
2217
2244
  :type cluster_shape_mn: Tuple[int, int], optional
2218
- :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
2219
- will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
2220
- :type use_2cta_instrs: bool, optional
2221
2245
  :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
2222
2246
  :type tolerance: float, optional
2223
2247
  :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
@@ -2236,7 +2260,6 @@ def run(
2236
2260
  print(f"AB dtype: {ab_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
2237
2261
  print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
2238
2262
  print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
2239
- print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
2240
2263
  print(f"Tolerance: {tolerance}")
2241
2264
  print(f"Warmup iterations: {warmup_iterations}")
2242
2265
  print(f"Iterations: {iterations}")
@@ -2248,11 +2271,10 @@ def run(
2248
2271
  m, n, k, l = mnkl
2249
2272
 
2250
2273
  # Skip unsupported testcase
2251
- if not PersistentDenseGemmKernel.can_implement(
2274
+ if not GemmSm100.can_implement(
2252
2275
  ab_dtype,
2253
2276
  acc_dtype,
2254
2277
  d_dtype,
2255
- use_2cta_instrs,
2256
2278
  mma_tiler_mn,
2257
2279
  cluster_shape_mn,
2258
2280
  m,
@@ -2264,7 +2286,7 @@ def run(
2264
2286
  d_major,
2265
2287
  ):
2266
2288
  raise TypeError(
2267
- f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}"
2289
+ f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}"
2268
2290
  )
2269
2291
 
2270
2292
  if not torch.cuda.is_available():
@@ -2339,12 +2361,8 @@ def run(
2339
2361
  c, mC, c_torch = None, None, None
2340
2362
 
2341
2363
  # Configure gemm kernel
2342
- gemm = PersistentDenseGemmKernel(
2343
- acc_dtype,
2344
- use_2cta_instrs,
2345
- mma_tiler_mn,
2346
- cluster_shape_mn,
2347
- )
2364
+ cluster_shape_mnk = (*cluster_shape_mn, 1)
2365
+ gemm = GemmSm100(acc_dtype, mma_tiler_mn, cluster_shape_mnk)
2348
2366
 
2349
2367
  # Compute max active clusters on current device
2350
2368
  hardware_info = cutlass.utils.HardwareInfo()
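The kernel is now constructed without the explicit use_2cta_instrs flag, which presumably is derived inside the class (for instance from the MMA tiler's M extent), and the cluster shape gains a trivial K component. Restating the two changed lines with the same names that are in scope in run():

    # Restated from the hunk above; acc_dtype, mma_tiler_mn and cluster_shape_mn are
    # the values already parsed in run(), and GemmSm100 is the kernel class in this file.
    cluster_shape_mnk = (*cluster_shape_mn, 1)     # append the trivial K dimension
    gemm = GemmSm100(acc_dtype, mma_tiler_mn, cluster_shape_mnk)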
@@ -2356,6 +2374,17 @@ def run(
2356
2374
  else:
2357
2375
  tile_count_semaphore = None
2358
2376
 
2377
+ scheduler_args = TileSchedulerOptions(
2378
+ Int32(max_active_clusters),
2379
+ tile_count_semaphore=make_ptr(
2380
+ Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
2381
+ )
2382
+ if tile_count_semaphore is not None
2383
+ else None,
2384
+ )
2385
+ epi_args = gemm.EpilogueArguments()
2386
+ varlen_args = VarlenArguments()
2387
+
2359
2388
  # Get current CUDA stream from PyTorch
2360
2389
  torch_stream = torch.cuda.current_stream()
2361
2390
  # Get the raw stream pointer as a CUstream
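The loose tile_count_semaphore / max_active_clusters arguments are replaced by three bundles: scheduler options, epilogue arguments, and varlen arguments. A condensed restatement of the wiring above (same names as in run(); the empty bundles are the defaults for a plain dense GEMM):

    # Condensed restatement of the argument-bundle construction in this hunk.
    semaphore_ptr = (
        make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
        if tile_count_semaphore is not None
        else None
    )
    scheduler_args = TileSchedulerOptions(
        Int32(max_active_clusters), tile_count_semaphore=semaphore_ptr
    )
    epi_args = gemm.EpilogueArguments()     # defaults: no fused epilogue arguments here
    varlen_args = VarlenArguments()         # defaults: fixed-shape (non-varlen) problem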
@@ -2367,15 +2396,14 @@ def run(
2367
2396
  mB,
2368
2397
  mD,
2369
2398
  mC,
2370
- make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
2371
- if tile_count_semaphore is not None
2372
- else None,
2373
- max_active_clusters,
2399
+ epi_args,
2400
+ scheduler_args,
2401
+ varlen_args,
2374
2402
  current_stream,
2375
2403
  )
2376
2404
 
2377
2405
  if not skip_ref_check:
2378
- compiled_gemm(mA, mB, mD, mC, tile_count_semaphore, current_stream)
2406
+ compiled_gemm(mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream)
2379
2407
  if ab_dtype in {
2380
2408
  cutlass.Int8,
2381
2409
  cutlass.Uint8,
@@ -2393,7 +2421,7 @@ def run(
2393
2421
  gpu_d = d_torch.cpu()
2394
2422
 
2395
2423
  # Convert ref to c_type
2396
- if d_dtype == cutlass.Float32:
2424
+ if d_dtype == Float32:
2397
2425
  ref_d = ref
2398
2426
  elif d_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
2399
2427
  # m major: (l, n, m) -> (m, n, l)
@@ -2463,7 +2491,9 @@ def run(
2463
2491
  print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2464
2492
 
2465
2493
  time.sleep(0.5)
2466
- fn = lambda: compiled_gemm(mA, mB, mD, mC, tile_count_semaphore, current_stream)
2494
+ fn = lambda: compiled_gemm(
2495
+ mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream
2496
+ )
2467
2497
  timing = do_bench(fn, warmup=warmup, rep=repeats)
2468
2498
  tflops = flops / (timing * 1e9) # Convert to TFlops
2469
2499
  print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
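The benchmark lambda now threads the same argument bundles through every launch; the TFLOPS figure printed here is plain arithmetic on the measured time. A small restatement of that conversion (the problem size, timing, and the 2*m*n*k*l flop count are illustrative assumptions; the real flops value is computed earlier in run()):

    # Illustrative TFLOPS arithmetic matching the conversion above.
    m, n, k, l = 8192, 8192, 8192, 1        # made-up problem size
    flops = 2 * m * n * k * l               # standard dense GEMM flop count (assumed)
    timing = 0.55                           # made-up do_bench result in milliseconds
    tflops = flops / (timing * 1e9)         # ms->s (1e-3) and FLOP->TFLOP (1e12) combine to 1e9
    print(f"{tflops:.1f} TFLOPS")           # ~1999.1 with these numbers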
@@ -2505,12 +2535,7 @@ if __name__ == "__main__":
2505
2535
  parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
2506
2536
  parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
2507
2537
  parser.add_argument("--c_dtype", type=cutlass.dtype, default=None)
2508
- parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
2509
- parser.add_argument(
2510
- "--use_2cta_instrs",
2511
- action="store_true",
2512
- help="Enable 2CTA MMA instructions feature",
2513
- )
2538
+ parser.add_argument("--acc_dtype", type=cutlass.dtype, default=Float32)
2514
2539
  parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
2515
2540
  parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
2516
2541
  parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
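The --use_2cta_instrs flag is gone and --acc_dtype now defaults to the top-level Float32 import. The dtype and layout flags parse directly through cutlass.dtype and plain choices; a minimal standalone parser with only the flags whose definitions are visible in this hunk:

    # Minimal reproduction of the flags defined above (this hunk only; the full
    # script adds more options such as the shape, tiler, and cluster flags).
    import argparse
    import cutlass

    parser = argparse.ArgumentParser()
    parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
    parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
    parser.add_argument("--c_dtype", type=cutlass.dtype, default=None)
    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
    parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
    parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
    parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
    args = parser.parse_args([])            # defaults only; no GPU needed to parse
    print(args.ab_dtype, args.acc_dtype)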
@@ -2552,7 +2577,6 @@ if __name__ == "__main__":
2552
2577
  args.c_major,
2553
2578
  args.mma_tiler_mn,
2554
2579
  args.cluster_shape_mn,
2555
- args.use_2cta_instrs,
2556
2580
  args.tolerance,
2557
2581
  args.warmup_iterations,
2558
2582
  args.iterations,