quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff compares the contents of two publicly available versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in that registry.
quack/dense_gemm_sm90.py CHANGED
@@ -2,7 +2,7 @@
2
2
  # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
3
3
 
4
4
  import enum
5
- from typing import Tuple, Type, Callable, Optional, Union
5
+ from typing import Tuple, Type, Callable, Optional, Union, Literal
6
6
  from dataclasses import dataclass
7
7
  from functools import partial
8
8
  import math
@@ -86,12 +86,13 @@ class NamedBarrierGemm(enum.IntEnum):
86
86
  MmaWG1 = enum.auto()
87
87
  EpiWG0 = enum.auto()
88
88
  EpiWG1 = enum.auto()
89
+ TmemPtr = enum.auto()
89
90
 
90
91
 
91
92
  class GemmSm90:
92
93
  """
93
94
  This class implements batched matrix multiplication (C = A x B) with support for various data types
94
- and architectural features specific to Hopper GPUs.
95
+ and architectural features specific to Hopper GPUs with persistent tile scheduling and warp specialization.
95
96
 
96
97
  :param acc_dtype: Data type for accumulation during computation
97
98
  :type acc_dtype: type[cutlass.Numeric]
@@ -125,12 +126,15 @@ class GemmSm90:
125
126
  >>> gemm(a_tensor, b_tensor, c_tensor, stream)
126
127
  """
127
128
 
129
+ arch = 90
128
130
  bytes_per_tensormap = 128
131
+ num_epi_tensormaps: int = 0
129
132
 
130
133
  @dataclass
131
134
  class EpilogueArguments(ArgumentsBase):
132
135
  alpha: Optional[Float32 | cute.Tensor] = None
133
136
  beta: Optional[Float32 | cute.Tensor] = None
137
+ add_to_output: bool = False
134
138
 
135
139
  @dataclass
136
140
  class EpilogueParams(ParamsBase):
@@ -174,8 +178,8 @@ class GemmSm90:
174
178
 
175
179
  self.cluster_shape_mnk = cluster_shape_mnk
176
180
  # K dimension is deferred in _setup_attributes
177
- self.tile_shape_mnk = (*tile_shape_mn, 1)
178
- tile_M, tile_N = self.tile_shape_mnk[0], self.tile_shape_mnk[1]
181
+ self.cta_tile_shape_mnk = (*tile_shape_mn, 1)
182
+ tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
179
183
  # check the cta tile shape
180
184
  if not self.pingpong:
181
185
  if tile_M not in [64, 128, 192, 256, 320]:
@@ -209,7 +213,9 @@ class GemmSm90:
209
213
  else:
210
214
  atom_layout_m, atom_layout_n = 1, 2
211
215
  else:
212
- atom_layout_m = self.tile_shape_mnk[0] // 64 if self.tile_shape_mnk[0] < 256 else 2
216
+ atom_layout_m = (
217
+ self.cta_tile_shape_mnk[0] // 64 if self.cta_tile_shape_mnk[0] < 256 else 2
218
+ )
213
219
  atom_layout_n = 1
214
220
  assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
215
221
  else:
@@ -229,16 +235,14 @@ class GemmSm90:
229
235
  self.num_threads_per_warp_group = 128
230
236
  self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
231
237
  self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
232
- self.num_epi_threads = (
233
- self.mma_warp_groups if not self.pingpong else 1
234
- ) * self.num_threads_per_warp_group
238
+ self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
235
239
  self.num_ab_load_warps = 1 if not self.gather_A else 4
236
240
  self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
237
241
  self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
238
242
  self.ab_load_warp_id = self.mma_warp_groups * 4
239
243
  self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
240
244
 
241
- regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // (
245
+ regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
242
246
  math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
243
247
  )
244
248
  if self.fp8_slow_accum:
@@ -268,7 +272,7 @@ class GemmSm90:
268
272
  self.shared_storage = None
269
273
  self.buffer_align_bytes = 1024
270
274
 
271
- def _setup_attributes(self, epilogue_args: Optional[EpilogueArguments]):
275
+ def _setup_attributes(self, epilogue_args: EpilogueArguments):
272
276
  """Set up configurations that are dependent on GEMM inputs
273
277
 
274
278
  This method configures various attributes based on the input tensor properties
@@ -289,7 +293,7 @@ class GemmSm90:
289
293
  self.b_layout.sm90_mma_major_mode(),
290
294
  self.acc_dtype,
291
295
  self.atom_layout_mnk,
292
- tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
296
+ tiler_mn=(64, self.cta_tile_shape_mnk[1] // self.atom_layout_mnk[1]),
293
297
  )
294
298
  if const_expr(self.atom_layout_mnk[1] > 1):
295
299
  # If N dimension is split among 2 WGs, we need to permute the N dimension so
@@ -299,7 +303,7 @@ class GemmSm90:
299
303
  # WG1 would write to a separate epi smem of size (64, 16) that's far away.
300
304
  atom_n = self.atom_layout_mnk[1]
301
305
  permutation_n = cute.make_ordered_layout(
302
- (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
306
+ (8, self.cta_tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
303
307
  )
304
308
  self.tiled_mma = cute.make_tiled_mma(
305
309
  cute.make_mma_atom(self.tiled_mma.op),
@@ -308,23 +312,23 @@ class GemmSm90:
308
312
  )
309
313
  mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
310
314
  mma_inst_tile_k = 4
311
- self.tile_shape_mnk = (
312
- self.tile_shape_mnk[0],
313
- self.tile_shape_mnk[1],
315
+ self.cta_tile_shape_mnk = (
316
+ self.cta_tile_shape_mnk[0],
317
+ self.cta_tile_shape_mnk[1],
314
318
  mma_inst_shape_k * mma_inst_tile_k,
315
319
  )
316
320
 
317
321
  self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
318
322
 
319
323
  self.epi_tile = self._sm90_compute_tile_shape_or_override(
320
- self.tile_shape_mnk,
324
+ self.cta_tile_shape_mnk,
321
325
  self.atom_layout_mnk,
322
326
  self.d_dtype,
323
327
  )
324
328
 
325
329
  # Compute stage before compute smem layout
326
330
  self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
327
- self.tile_shape_mnk,
331
+ self.cta_tile_shape_mnk,
328
332
  self.epi_tile,
329
333
  self.a_dtype,
330
334
  self.b_dtype,
@@ -344,7 +348,7 @@ class GemmSm90:
344
348
  self.epi_smem_layout_staged,
345
349
  self.epi_c_smem_layout_staged,
346
350
  ) = self._make_smem_layouts(
347
- self.tile_shape_mnk,
351
+ self.cta_tile_shape_mnk,
348
352
  self.epi_tile,
349
353
  self.a_dtype,
350
354
  self.a_layout,
@@ -366,10 +370,9 @@ class GemmSm90:
366
370
  mB: cute.Tensor,
367
371
  mD: Optional[cute.Tensor],
368
372
  mC: Optional[cute.Tensor],
369
- epilogue_args: Optional[ArgumentsBase],
373
+ epilogue_args: ArgumentsBase,
370
374
  scheduler_args: TileSchedulerOptions,
371
375
  varlen_args: Optional[VarlenArguments],
372
- mAIdx: Optional[cute.Tensor],
373
376
  stream: cuda.CUstream,
374
377
  ):
375
378
  """Execute the GEMM operation in steps:
@@ -405,7 +408,10 @@ class GemmSm90:
405
408
  raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
406
409
  if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
407
410
  raise TypeError("a_dtype should be float16 or float8")
408
- assert (mAIdx is not None) == self.gather_A
411
+
412
+ if const_expr(varlen_args is None):
413
+ varlen_args = VarlenArguments()
414
+ assert (varlen_args.mAIdx is not None) == self.gather_A
409
415
 
410
416
  # Assume all strides are divisible by 128 bits except the last stride
411
417
  new_stride = lambda t: tuple(
@@ -421,77 +427,47 @@ class GemmSm90:
421
427
 
422
428
  self._setup_attributes(epilogue_args)
423
429
 
430
+ a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0))
431
+ b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0))
432
+ tma_atom_a, tma_tensor_a = None, None
424
433
  if const_expr(not self.gather_A):
425
434
  tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
426
435
  mA,
427
- self.a_smem_layout_staged,
428
- (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
436
+ a_smem_layout,
437
+ (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
429
438
  self.cluster_shape_mnk[1],
430
439
  )
431
- else:
432
- tma_atom_a, tma_tensor_a = None, None
433
-
434
440
  tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
435
441
  mB,
436
- self.b_smem_layout_staged,
437
- (self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
442
+ b_smem_layout,
443
+ (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
438
444
  self.cluster_shape_mnk[0],
439
445
  )
440
446
 
447
+ self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
448
+ if const_expr(not self.gather_A):
449
+ self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
450
+
451
+ tma_atom_d, tma_tensor_d = None, None
441
452
  if const_expr(mD is not None):
442
453
  tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
443
- mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
454
+ mD,
455
+ self.epi_smem_layout_staged,
456
+ self.epi_tile,
457
+ op_type="store"
458
+ if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
459
+ else "add",
444
460
  )
445
- else:
446
- tma_atom_d, tma_tensor_d = None, None
447
-
461
+ tma_atom_c, tma_tensor_c = None, None
448
462
  if const_expr(mC is not None):
449
463
  tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
450
- mC, self.epi_c_smem_layout_staged, self.epi_tile, store_or_load="load"
464
+ mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
451
465
  )
452
- else:
453
- tma_atom_c, tma_tensor_c = None, None
454
466
 
455
467
  epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
456
468
 
457
- if const_expr(varlen_args is None):
458
- varlen_args = VarlenArguments()
459
- if const_expr(varlen_args.mCuSeqlensM is None):
460
- num_problems = (
461
- mD.shape[2]
462
- if mD is not None
463
- else (
464
- mB.shape[2]
465
- if varlen_args.mCuSeqlensK is None
466
- else varlen_args.mCuSeqlensK.shape[0] - 1
467
- )
468
- )
469
- problem_shape_ntile_mnl = (
470
- cute.ceil_div(mA.shape[0], self.tile_shape_mnk[0]),
471
- cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
472
- num_problems,
473
- )
474
- TileSchedulerCls = self.get_scheduler_class()
475
- tile_sched_args = self.get_scheduler_arguments(problem_shape_ntile_mnl, scheduler_args)
476
- else:
477
- assert mD is not None or not self.gather_A
478
- problem_shape_ntile_mnl = (
479
- None,
480
- cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
481
- varlen_args.mCuSeqlensM.shape[0] - 1,
482
- )
483
- TileSchedulerCls = VarlenMTileScheduler
484
- tile_sched_args = VarlenMTileSchedulerArguments(
485
- problem_shape_ntile_mnl=problem_shape_ntile_mnl,
486
- total_m=mD.shape[0] if mD is not None else mAIdx.shape[0],
487
- cu_seqlens_m=varlen_args.mCuSeqlensM,
488
- raster_order=scheduler_args.raster_order,
489
- group_size=scheduler_args.max_swizzle_size,
490
- tile_shape_mn=self.tile_shape_mnk[:2],
491
- cluster_shape_mnk=self.cluster_shape_mnk,
492
- tile_count_semaphore=scheduler_args.tile_count_semaphore,
493
- is_persistent=self.is_persistent,
494
- )
469
+ TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
470
+ tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
495
471
  tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
496
472
  grid = TileSchedulerCls.get_grid_shape(
497
473
  tile_sched_params, scheduler_args.max_active_clusters
@@ -534,6 +510,7 @@ class GemmSm90:
534
510
 
535
511
  # Launch the kernel synchronously
536
512
  self.kernel(
513
+ self.tiled_mma,
537
514
  tma_atom_a,
538
515
  tma_tensor_a if const_expr(not self.gather_A) else mA,
539
516
  tma_atom_b,
@@ -543,11 +520,10 @@ class GemmSm90:
543
520
  tma_atom_c,
544
521
  tma_tensor_c,
545
522
  epilogue_params,
546
- mAIdx,
547
523
  varlen_args.mCuSeqlensM,
548
524
  varlen_args.mCuSeqlensK,
549
525
  varlen_args.mTensormaps,
550
- self.tiled_mma,
526
+ varlen_args.mAIdx,
551
527
  self.cluster_layout_mnk,
552
528
  self.a_smem_layout_staged,
553
529
  self.b_smem_layout_staged,
@@ -569,6 +545,7 @@ class GemmSm90:
569
545
  @cute.kernel
570
546
  def kernel(
571
547
  self,
548
+ tiled_mma: cute.TiledMma,
572
549
  tma_atom_a: Optional[cute.CopyAtom],
573
550
  mA_mkl: cute.Tensor,
574
551
  tma_atom_b: cute.CopyAtom,
@@ -578,11 +555,10 @@ class GemmSm90:
578
555
  tma_atom_c: Optional[cute.CopyAtom],
579
556
  mC_mnl: Optional[cute.Tensor],
580
557
  epilogue_params: ParamsBase,
581
- mAIdx: Optional[cute.Tensor],
582
558
  cu_seqlens_m: Optional[cute.Tensor],
583
559
  cu_seqlens_k: Optional[cute.Tensor],
584
560
  tensormaps: Optional[cute.Tensor],
585
- tiled_mma: cute.TiledMma,
561
+ mAIdx: Optional[cute.Tensor],
586
562
  cluster_layout_mnk: cute.Layout,
587
563
  a_smem_layout: cute.ComposedLayout,
588
564
  b_smem_layout: cute.ComposedLayout,
@@ -621,6 +597,8 @@ class GemmSm90:
621
597
  varlen_m = const_expr(cu_seqlens_m is not None)
622
598
  varlen_k = const_expr(cu_seqlens_k is not None)
623
599
  assert not (varlen_m and varlen_k)
600
+ if const_expr(self.gather_A):
601
+ assert varlen_m or varlen_k
624
602
  has_D = const_expr(mD_mnl is not None)
625
603
  has_C = const_expr(mC_mnl is not None)
626
604
 
@@ -641,8 +619,6 @@ class GemmSm90:
641
619
  storage = smem.allocate(self.shared_storage)
642
620
 
643
621
  ab_pipeline = self.make_ab_pipeline(
644
- a_smem_layout=cute.slice_(a_smem_layout, (None, None, 0)),
645
- b_smem_layout=cute.slice_(b_smem_layout, (None, None, 0)),
646
622
  tiled_mma=tiled_mma,
647
623
  cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
648
624
  ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
@@ -682,27 +658,9 @@ class GemmSm90:
682
658
  epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
683
659
 
684
660
  # Get tensormap buffer address
685
- tensormap_manager = None
686
- tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
687
- if const_expr(varlen_m or varlen_k):
688
- tensormap_manager = TensorMapManagerSm90(
689
- cutlass.utils.TensorMapUpdateMode.GMEM, GemmSm90.bytes_per_tensormap
690
- )
691
- # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
692
- tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
693
- if const_expr(varlen_m):
694
- tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
695
- tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
696
- tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
697
- )
698
- else:
699
- assert varlen_k
700
- tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
701
- tensormaps[tensormap_workspace_idx, 0, None].iterator
702
- )
703
- tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
704
- tensormaps[tensormap_workspace_idx, 1, None].iterator
705
- )
661
+ tensormap_manager, tensormap_ab_ptrs, tensormap_d_ptr, tensormap_epi_ptrs = (
662
+ self.tensormap_init(tensormaps, varlen_m, varlen_k, has_D, warp_idx)
663
+ )
706
664
 
707
665
  TileSchedulerCls = partial(
708
666
  TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
@@ -717,14 +675,15 @@ class GemmSm90:
717
675
  is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
718
676
  if const_expr(varlen_k):
719
677
  # initialize tensormap for A & B
720
- tensormap_manager.init_tensormap_from_atom(
721
- tma_atom_a,
722
- tensormap_a_ptr,
723
- is_tma_warp,
724
- )
678
+ if const_expr(not self.gather_A):
679
+ tensormap_manager.init_tensormap_from_atom(
680
+ tma_atom_a,
681
+ tensormap_ab_ptrs[0],
682
+ is_tma_warp,
683
+ )
725
684
  tensormap_manager.init_tensormap_from_atom(
726
685
  tma_atom_b,
727
- tensormap_b_ptr,
686
+ tensormap_ab_ptrs[1],
728
687
  is_tma_warp,
729
688
  )
730
689
  # ///////////////////////////////////////////////////////////////////////////////
@@ -762,16 +721,12 @@ class GemmSm90:
762
721
  is_group_changed = batch_idx != last_batch_idx
763
722
  last_batch_idx = batch_idx
764
723
  if is_group_changed:
765
- # construct tensor A/B based on real address, shape and stride information
766
- tensormap_manager.update_tensormap_shape(
767
- (tensormap_a_ptr, tensormap_b_ptr),
768
- is_manager_warp=is_tma_warp,
769
- shapes=(cu_seqlens_k[batch_idx + 1], cu_seqlens_k[batch_idx + 1]),
770
- orders=(
771
- 0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1,
772
- 0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1,
773
- ),
774
- tensormap_smem_ptr=None,
724
+ self.tensormap_update_AB(
725
+ tensormap_manager,
726
+ tensormap_ab_ptrs,
727
+ cu_seqlens_k,
728
+ batch_idx,
729
+ is_tma_warp,
775
730
  )
776
731
  # ///////////////////////////////////////////////////////////////////////////
777
732
  # Local_tile partition global tensors
@@ -786,44 +741,54 @@ class GemmSm90:
786
741
  # (bM, bK, RestK)
787
742
  gA_k = cute.local_tile(
788
743
  mA_mk,
789
- cute.select(self.tile_shape_mnk, [0, 2]),
744
+ cute.select(self.cta_tile_shape_mnk, [0, 2]),
790
745
  (tile_coord_mnkl[0], None),
791
746
  )
792
747
  else:
793
- mA_mk = mA_mkl
794
748
  if const_expr(varlen_m):
795
749
  mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
796
- elif const_expr(varlen_k):
797
- mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
750
+ gAIdx = cute.local_tile(
751
+ mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
752
+ )
753
+ # (M, K)
754
+ mA_mk = mA_mkl
798
755
  else:
799
- mAIdx_mk = mAIdx[None, batch_idx]
800
- gAIdx = cute.local_tile(
801
- mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
802
- )
756
+ assert varlen_k
757
+ mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
758
+ # (tile_K, RestK)
759
+ gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
760
+ # (tile_M, K)
761
+ mA_mk = cute.local_tile(
762
+ mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
763
+ )
803
764
  if const_expr(varlen_k):
804
765
  mB_nk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mB_nkl)
805
766
  else:
806
767
  mB_nk = mB_nkl[None, None, batch_idx]
807
768
  # (bN, bK, RestK)
808
769
  gB_k = cute.local_tile(
809
- mB_nk, cute.select(self.tile_shape_mnk, [1, 2]), (tile_coord_mnkl[1], None)
770
+ mB_nk,
771
+ cute.select(self.cta_tile_shape_mnk, [1, 2]),
772
+ (tile_coord_mnkl[1], None),
810
773
  )
811
774
  # //////////////////////////////////////////////////////////////////////////
812
775
  # Partition shared tensor for TMA load A/B
813
776
  # //////////////////////////////////////////////////////////////////////////
777
+ tma_desc_a_ptr, tma_desc_b_ptr = None, None
814
778
  if const_expr(varlen_k):
815
779
  # ensure the update to tensormap has completed before using it
780
+ tensormap_a_ptr, tensormap_b_ptr = tensormap_ab_ptrs
816
781
  if is_group_changed and is_tma_warp:
817
- tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
782
+ if const_expr(not self.gather_A):
783
+ tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
818
784
  tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
819
- tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
820
- tensormap_a_ptr, cute.AddressSpace.generic
821
- )
785
+ if const_expr(not self.gather_A):
786
+ tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
787
+ tensormap_a_ptr, cute.AddressSpace.generic
788
+ )
822
789
  tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
823
790
  tensormap_b_ptr, cute.AddressSpace.generic
824
791
  )
825
- else:
826
- tma_desc_a_ptr, tma_desc_b_ptr = None, None
827
792
  # TMA load A partition_S/D
828
793
  a_cta_layout = cute.make_layout(
829
794
  cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
@@ -855,8 +820,11 @@ class GemmSm90:
855
820
  thr_copy_A = tiled_copy_A.get_slice(tidx)
856
821
  # (atom_v, CPY_M, 1, STAGE)
857
822
  tAsA = thr_copy_A.partition_D(sA)
858
- assert tAsA.shape[2] == 1
859
- tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
823
+ if const_expr(varlen_m): # k-major
824
+ assert tAsA.shape[2] == 1
825
+ tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
826
+ else: # varlen_k, m-major
827
+ tAsA = cute.group_modes(tAsA, 0, 3)
860
828
  copy_A = partial(cute.copy, tiled_copy_A)
861
829
  # TMA load B partition_S/D
862
830
  b_cta_layout = cute.make_layout(
@@ -877,9 +845,9 @@ class GemmSm90:
877
845
  k_len = (
878
846
  cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
879
847
  if const_expr(varlen_k)
880
- else mA_mkl.shape[1]
848
+ else Int32(mA_mkl.shape[1])
881
849
  )
882
- k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
850
+ k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
883
851
  if const_expr(not self.gather_A):
884
852
  ab_producer_state = self.load_AB(
885
853
  ab_pipeline,
@@ -894,7 +862,7 @@ class GemmSm90:
894
862
  )
895
863
  else:
896
864
  limit_m = (
897
- mAIdx.shape[0]
865
+ Int32(mA_mkl.shape[0])
898
866
  if const_expr(cu_seqlens_m is None)
899
867
  else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
900
868
  )
@@ -910,19 +878,17 @@ class GemmSm90:
910
878
  tBsB,
911
879
  k_tile_cnt,
912
880
  limit_A=(
913
- limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
914
- mA_mk.shape[1],
881
+ limit_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
882
+ k_len,
915
883
  ),
884
+ varlen_m=varlen_m,
916
885
  )
917
886
  tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
918
- tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
919
887
  tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
920
888
  work_tile = tile_scheduler.get_current_work()
921
889
  # End of persistent scheduler loop
922
890
  if const_expr(self.pingpong and not varlen_k):
923
891
  # Need to write the tile_idx to smem for the next WG in the pingpong mode
924
- # tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
925
- tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
926
892
  tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
927
893
  ab_pipeline.producer_tail(ab_producer_state)
928
894
  if is_scheduler_warp:
@@ -936,11 +902,20 @@ class GemmSm90:
936
902
  )
937
903
  if const_expr(varlen_m):
938
904
  # initialize tensormap for D
939
- tensormap_manager.init_tensormap_from_atom(
940
- tma_atom_d,
941
- tensormap_d_ptr,
942
- is_manager_warp=is_tma_warp,
943
- )
905
+ if const_expr(has_D):
906
+ tensormap_manager.init_tensormap_from_atom(
907
+ tma_atom_d,
908
+ tensormap_d_ptr,
909
+ is_manager_warp=is_tma_warp,
910
+ )
911
+ for tma_atom, tensormap_epi_ptr in zip(
912
+ self.epi_get_tma_atoms(epilogue_params), tensormap_epi_ptrs
913
+ ):
914
+ tensormap_manager.init_tensormap_from_atom(
915
+ tma_atom,
916
+ tensormap_epi_ptr,
917
+ is_manager_warp=is_tma_warp,
918
+ )
944
919
  # //////////////////////////////////////////////////////////////////////////////
945
920
  # Partition global tensor for TiledMMA_A/B/C
946
921
  # //////////////////////////////////////////////////////////////////////////////
@@ -962,7 +937,9 @@ class GemmSm90:
962
937
  tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
963
938
  tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))
964
939
 
965
- acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
940
+ acc_shape = tiled_mma.partition_shape_C(
941
+ cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
942
+ )
966
943
  acc = cute.make_fragment(acc_shape, self.acc_dtype)
967
944
  acc_slow = None
968
945
  if const_expr(self.fp8_slow_accum):
@@ -974,10 +951,11 @@ class GemmSm90:
974
951
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
975
952
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
976
953
 
977
- k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.tile_shape_mnk[2])
978
- c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
954
+ k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2])
955
+ c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))
979
956
 
980
957
  ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
958
+ epi_store_pipeline = self.make_epi_store_pipeline()
981
959
  epi_read_state = make_pipeline_state(
982
960
  pipeline.PipelineUserType.Consumer, self.epi_c_stage
983
961
  )
@@ -998,7 +976,7 @@ class GemmSm90:
998
976
  else:
999
977
  batch_idx = work_tile.tile_idx[3]
1000
978
  k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1001
- k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
979
+ k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
1002
980
  ab_read_state.advance_iters(k_tile_cnt)
1003
981
  tile_scheduler.advance_to_next_work()
1004
982
  if const_expr(varlen_k):
@@ -1019,13 +997,14 @@ class GemmSm90:
1019
997
  is_group_changed = batch_idx != last_batch_idx
1020
998
  last_batch_idx = batch_idx
1021
999
  if is_group_changed:
1022
- # construct tensor D based on real address, shape and stride information
1023
- tensormap_manager.update_tensormap_shape(
1024
- (tensormap_d_ptr,),
1000
+ self.tensormap_update_D_epi(
1001
+ tensormap_manager,
1002
+ tensormap_d_ptr,
1003
+ tensormap_epi_ptrs,
1004
+ epilogue_params,
1005
+ cu_seqlens_m,
1006
+ batch_idx,
1025
1007
  is_manager_warp=is_tma_warp,
1026
- shapes=(cu_seqlens_m[batch_idx + 1],),
1027
- orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
1028
- tensormap_smem_ptr=None,
1029
1008
  )
1030
1009
 
1031
1010
  k_len = (
@@ -1033,7 +1012,7 @@ class GemmSm90:
1033
1012
  if const_expr(varlen_k)
1034
1013
  else mA_mkl.shape[1]
1035
1014
  )
1036
- k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
1015
+ k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
1037
1016
  ab_read_state, tiled_mma = self.mma(
1038
1017
  ab_pipeline,
1039
1018
  ab_read_state,
@@ -1056,24 +1035,34 @@ class GemmSm90:
1056
1035
  self.pingpong_barrier_sync(warp_group_idx, "epi")
1057
1036
 
1058
1037
  epilogue_barrier = pipeline.NamedBarrier(
1059
- barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
1038
+ barrier_id=int(NamedBarrierGemm.Epilogue),
1039
+ num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
1060
1040
  )
1061
1041
 
1042
+ tma_desc_d_ptr, tma_desc_epi_ptrs = None, [None] * self.num_epi_tensormaps
1062
1043
  if const_expr(varlen_m):
1063
1044
  # ensure the update to tensormap has completed before using it
1064
1045
  if is_group_changed and is_tma_warp:
1065
- tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
1066
- tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
1067
- tensormap_d_ptr, cute.AddressSpace.generic
1068
- )
1069
- else:
1070
- tma_desc_d_ptr = None
1046
+ if const_expr(has_D):
1047
+ tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
1048
+ for tensormap_epi_ptr in tensormap_epi_ptrs:
1049
+ tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
1050
+ if const_expr(has_D):
1051
+ tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
1052
+ tensormap_d_ptr, cute.AddressSpace.generic
1053
+ )
1054
+ tma_desc_epi_ptrs = [
1055
+ tensormap_manager.get_tensormap_ptr(
1056
+ tensormap_epi_ptr, cute.AddressSpace.generic
1057
+ )
1058
+ for tensormap_epi_ptr in tensormap_epi_ptrs
1059
+ ]
1071
1060
 
1072
1061
  if const_expr(has_D):
1073
1062
  bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(
1074
1063
  tma_atom_d,
1075
1064
  mD_mnl,
1076
- self.tile_shape_mnk[:2],
1065
+ self.cta_tile_shape_mnk[:2],
1077
1066
  self.epi_tile,
1078
1067
  sD,
1079
1068
  tile_coord_mnkl,
@@ -1086,7 +1075,7 @@ class GemmSm90:
1086
1075
  bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
1087
1076
  tma_atom_c,
1088
1077
  mC_mnl,
1089
- self.tile_shape_mnk[:2],
1078
+ self.cta_tile_shape_mnk[:2],
1090
1079
  self.epi_tile,
1091
1080
  sC,
1092
1081
  tile_coord_mnkl,
@@ -1118,7 +1107,9 @@ class GemmSm90:
1118
1107
  epi_read_state, epi_producer_state = self.epilogue(
1119
1108
  epilogue_params,
1120
1109
  epi_smem_tensors,
1110
+ tma_desc_epi_ptrs,
1121
1111
  epi_pipeline,
1112
+ epi_store_pipeline,
1122
1113
  epi_read_state,
1123
1114
  epi_producer_state,
1124
1115
  tiled_mma,
@@ -1147,7 +1138,7 @@ class GemmSm90:
1147
1138
  # so we have to make sure the smem content is done reading before signaling
1148
1139
  # the next WG's epilogue.
1149
1140
  if is_tma_warp:
1150
- cute.arch.cp_async_bulk_wait_group(0, read=True)
1141
+ epi_store_pipeline.producer_tail()
1151
1142
  self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
1152
1143
 
1153
1144
  if const_expr(not self.pingpong):
@@ -1168,15 +1159,16 @@ class GemmSm90:
1168
1159
  if work_tile.is_valid_tile:
1169
1160
  batch_idx = work_tile.tile_idx[3]
1170
1161
  k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1171
- k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
1162
+ k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
1172
1163
  ab_read_state.advance_iters(k_tile_cnt)
1173
1164
  tile_scheduler.advance_to_next_work()
1174
1165
  work_tile = tile_scheduler.get_current_work()
1175
1166
  # End of persistent scheduler loop
1176
1167
 
1168
+ # Wait for D store complete
1177
1169
  if const_expr(not self.pingpong):
1178
1170
  if is_tma_warp:
1179
- cute.arch.cp_async_bulk_wait_group(0, read=True)
1171
+ epi_store_pipeline.producer_tail()
1180
1172
 
1181
1173
  @cute.jit
1182
1174
  def load_AB(
@@ -1219,37 +1211,57 @@ class GemmSm90:
1219
1211
  ab_pipeline: cutlass.pipeline.PipelineAsync,
1220
1212
  ab_producer_state: cutlass.pipeline.PipelineState,
1221
1213
  thr_copy_A: cute.core.ThrCopy,
1222
- mA: cute.Tensor,
1214
+ mA: cute.Tensor, # (M, K) if varlen_m, (tile_M, K) if varlen_k
1223
1215
  tAsA: cute.Tensor,
1224
- gAIdx: cute.Tensor,
1216
+ gAIdx: cute.Tensor, # (tile_M,) if varlen_m, (tile_K, RestK) if varlen_k
1225
1217
  copy_B: Callable,
1226
1218
  tBgB: cute.Tensor,
1227
1219
  tBsB: cute.Tensor,
1228
1220
  k_tile_cnt: Int32,
1229
1221
  limit_A: Tuple[Int32, Int32],
1222
+ varlen_m: bool,
1230
1223
  ) -> cutlass.pipeline.PipelineState:
1231
- # (atom_v, CPY_M, 1, RestK)
1232
1224
  limit_m, limit_k = limit_A
1233
- limit_m = min(limit_m, self.tile_shape_mnk[0]) # To avoid writing beyond smem limit
1234
- cA = cute.make_identity_tensor(cute.select(self.tile_shape_mnk, [0, 2]))
1225
+ # Do we need to check if we overshoot tile_M when we load A?
1226
+ is_even_m_smem = self.cta_tile_shape_mnk[0] % thr_copy_A.tiler_mn[0].shape == 0
1227
+ if const_expr(not is_even_m_smem):
1228
+ limit_m = min(limit_m, self.cta_tile_shape_mnk[0])
1229
+ elems_per_load = cute.size(tAsA.shape[0][0])
1230
+ cA = cute.make_identity_tensor(cute.select(self.cta_tile_shape_mnk, [0, 2]))
1235
1231
  tAcA = thr_copy_A.partition_S(cA)
1236
1232
  t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
1237
1233
  # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
1238
1234
  # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
1239
1235
  # This is so that when we do the comparison, t0AcA is known at compile time.
1240
1236
  limit_m = limit_m - tAcA[0][0]
1237
+ limit_k = limit_k - tAcA[0][1]
1241
1238
  # Read indices for A
1242
1239
  rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
1243
- m_idx = cute.make_fragment(rows_per_thread, Int32)
1244
- for m in cutlass.range(rows_per_thread):
1245
- row_idx = tAcA[0, m, 0][0]
1246
- if t0AcA[0, m, 0][0] < limit_m:
1247
- m_idx[m] = gAIdx[row_idx]
1248
- else:
1249
- m_idx[m] = -1
1250
- elems_per_load = cute.size(tAsA.shape[0][0])
1240
+ cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
1241
+ tApA_m = cute.make_fragment(rows_per_thread, Boolean)
1242
+ for m in cutlass.range_constexpr(rows_per_thread):
1243
+ tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
1244
+ m_idx, k_idx, tAmA = None, None, None
1245
+ if const_expr(varlen_m):
1246
+ m_idx = cute.make_fragment(rows_per_thread, Int32)
1247
+ for m in cutlass.range(rows_per_thread):
1248
+ row_idx = tAcA[0, m, 0][0]
1249
+ if tApA_m[m]:
1250
+ m_idx[m] = gAIdx[row_idx]
1251
+ else:
1252
+ m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
1253
+ else:
1254
+ k_idx = cute.make_fragment(cols_per_thread, Int32) # Will be read later
1255
+ threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
1256
+ # This is very convoluted but idk a better way
1257
+ # for tile_M=128, flat_divide gives (8, 16, K),
1258
+ # then logical_divide gives ((8, 1), (8, 2), K).
1259
+ tidx = thr_copy_A.thr_idx
1260
+ tAmA = cute.logical_divide(
1261
+ cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
1262
+ )[None, (tidx % threads_per_col, None), None] # ((8, 1), 2, K)
1251
1263
  # (m, (bK, RestK))
1252
- mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
1264
+ mA_k = cute.logical_divide(mA, (None, self.cta_tile_shape_mnk[2]))
1253
1265
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1254
1266
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
1255
1267
  peek_ab_empty_status = Boolean(True)
@@ -1260,31 +1272,55 @@ class GemmSm90:
1260
1272
  # /////////////////////////////////////////////////////////////////////////
1261
1273
  copy_A = partial(cute.copy, thr_copy_A)
1262
1274
  for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
1275
+ if const_expr(not varlen_m): # Prefetch mAIdx early, even before smem is free
1276
+ gAIdx_cur = gAIdx[None, k_tile]
1277
+ for k in cutlass.range(cols_per_thread):
1278
+ col_idx = tAcA[0, 0, k][1]
1279
+ k_idx[k] = gAIdx_cur[col_idx]
1263
1280
  # Wait for A/B buffers to be empty before loading into them
1264
1281
  # Also sets the transaction barrier for the A/B buffers
1282
+ # A tiny bit faster to rotate the warp that does TMA
1283
+ # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id
1284
+ # since that's the warp that does the tensormap update.
1285
+ tma_warp_id = self.ab_load_warp_id + (
1286
+ (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
1287
+ )
1265
1288
  ab_pipeline.producer_acquire(
1266
1289
  ab_producer_state,
1267
1290
  peek_ab_empty_status,
1268
- # A tiny bit faster to rotate the warp that does TMA
1269
- is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
1291
+ is_tma_warp=warp_idx == tma_warp_id,
1270
1292
  )
1271
1293
  # A bit faster to load B first while we calculate the predicate for A
1272
- if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
1294
+ if warp_idx == tma_warp_id:
1273
1295
  copy_B(
1274
1296
  tBgB[None, k_tile],
1275
1297
  tBsB[None, ab_producer_state.index],
1276
1298
  tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1277
1299
  )
1278
1300
  # (m, bK)
1279
- mA_cur = mA_k[None, (None, k_tile)]
1280
- for m in cutlass.range_constexpr(tAcA.shape[1]):
1281
- # (elems_per_load, thread_per_row)
1282
- mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
1283
- if t0AcA[0, m, 0][0] < limit_m:
1284
- # There's only 1 load per row
1285
- assert cute.size(tAcA.shape, mode=[2]) == 1
1286
- ki = tAcA[0, 0, 0][1] // elems_per_load
1287
- copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1301
+ if const_expr(varlen_m):
1302
+ mA_cur = mA_k[None, (None, k_tile)]
1303
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1304
+ # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
1305
+ # ((elems_per_load), thread_per_row)
1306
+ # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
1307
+ # So we append 1s to the last dimension and then do tiled_divide, then slice.
1308
+ mA_row = cute.tiled_divide(
1309
+ cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
1310
+ )[None, None, 0]
1311
+ if const_expr(is_even_m_smem) or tApA_m[m]:
1312
+ # There's only 1 load per row
1313
+ assert cute.size(tAcA.shape, mode=[2]) == 1
1314
+ ki = tAcA[0, 0, 0][1] // elems_per_load
1315
+ copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1316
+ else:
1317
+ for k in cutlass.range_constexpr(tAcA.shape[2]):
1318
+ # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), ab_producer_state.index], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
1319
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1320
+ if tApA_m[m]:
1321
+ copy_A(
1322
+ tAmA[None, m, k_idx[k]], tAsA[(None, m, k), ab_producer_state.index]
1323
+ )
1288
1324
  # This tells mbarrier to track the completion of cp.async
1289
1325
  ab_pipeline.producer_commit(ab_producer_state)
1290
1326
  ab_producer_state.advance()
@@ -1294,32 +1330,57 @@ class GemmSm90:
1294
1330
  # bound checking in the K dimension on the last k_tile
1295
1331
  if 0 < k_tile_cnt:
1296
1332
  k_tile = k_tile_cnt - 1
1333
+ tApA_k = cute.make_fragment(cols_per_thread, Boolean)
1334
+ limit_k -= k_tile * self.cta_tile_shape_mnk[2]
1335
+ for k in cutlass.range_constexpr(cols_per_thread):
1336
+ tApA_k[k] = t0AcA[0, 0, k][1] < limit_k
1337
+ if const_expr(not varlen_m):
1338
+ gAIdx_cur = gAIdx[None, k_tile]
1339
+ for k in cutlass.range(cols_per_thread):
1340
+ col_idx = tAcA[0, 0, k][1]
1341
+ if tApA_k[k]:
1342
+ k_idx[k] = gAIdx_cur[col_idx]
1343
+ else:
1344
+ k_idx[k] = -1
1345
+ tma_warp_id = self.ab_load_warp_id + (
1346
+ (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
1347
+ )
1297
1348
  ab_pipeline.producer_acquire(
1298
1349
  ab_producer_state,
1299
1350
  peek_ab_empty_status,
1300
- is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
1351
+ is_tma_warp=warp_idx == tma_warp_id,
1301
1352
  )
1302
- if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
1353
+ if warp_idx == tma_warp_id:
1303
1354
  copy_B(
1304
1355
  tBgB[None, k_tile],
1305
1356
  tBsB[None, ab_producer_state.index],
1306
1357
  tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1307
1358
  )
1308
- assert tAcA.shape[2] == 1 # there's only 1 load along the K dimension
1309
- tApA = cute.make_fragment(1, Boolean)
1310
- tApA[0] = tAcA[0, 0, 0][1] < limit_k
1311
- # (m, bK)
1312
- mA_cur = mA_k[None, (None, k_tile)]
1313
- for m in cutlass.range_constexpr(tAcA.shape[1]):
1314
- # (elems_per_load, thread_per_row)
1315
- mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
1316
- if t0AcA[0, m, 0][0] < limit_m:
1317
- # There's only 1 load per row
1318
- assert cute.size(tAcA.shape, mode=[2]) == 1
1319
- ki = tAcA[0, 0, 0][1] // elems_per_load
1320
- # copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA)
1321
- # TODO
1322
- copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1359
+ if const_expr(varlen_m):
1360
+ # (m, bK)
1361
+ mA_cur = mA_k[None, (None, k_tile)]
1362
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1363
+ # ((elems_per_load, 1), thread_per_row)
1364
+ mA_row = cute.tiled_divide(
1365
+ cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
1366
+ )[None, None, 0]
1367
+ if const_expr(is_even_m_smem) or tApA_m[k]:
1368
+ # There's only 1 load per row
1369
+ assert cute.size(tAcA.shape, mode=[2]) == 1
1370
+ ki = tAcA[0, 0, 0][1] // elems_per_load
1371
+ copy_A(
1372
+ mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA_k
1373
+ )
1374
+ else:
1375
+ tApA_k = cute.prepend_ones(tApA_k, up_to_rank=2) # (1, cols_per_thread)
1376
+ for k in cutlass.range_constexpr(tAcA.shape[2]):
1377
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1378
+ if tApA_m[m]:
1379
+ copy_A(
1380
+ tAmA[None, m, k_idx[k]],
1381
+ tAsA[(None, m, k), ab_producer_state.index],
1382
+ pred=tApA_k[None, k],
1383
+ )
1323
1384
  ab_pipeline.producer_commit(ab_producer_state)
1324
1385
  ab_producer_state.advance()
1325
1386
  return ab_producer_state
@@ -1416,7 +1477,9 @@ class GemmSm90:
1416
1477
  self,
1417
1478
  params: EpilogueParams,
1418
1479
  epi_smem_tensors: Tuple[cute.Tensor, ...],
1480
+ tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
1419
1481
  epi_pipeline: cutlass.pipeline.PipelineAsync,
1482
+ epi_store_pipeline: cutlass.pipeline.PipelineAsync,
1420
1483
  epi_read_state: cutlass.pipeline.PipelineState,
1421
1484
  epi_producer_state: cutlass.pipeline.PipelineState,
1422
1485
  tiled_mma: cute.TiledMma,
@@ -1443,7 +1506,7 @@ class GemmSm90:
1443
1506
  has_D = const_expr(copy_D is not None)
1444
1507
  # We iterate over epi tiles in the N dimension first before the M dimension
1445
1508
  epi_tile_shape = cute.zipped_divide(
1446
- cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
1509
+ cute.make_layout(self.cta_tile_shape_mnk[:2]), self.epi_tile
1447
1510
  ).shape[1]
1448
1511
  epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
1449
1512
  epi_tile_num = cute.size(epi_tile_shape)
@@ -1491,8 +1554,8 @@ class GemmSm90:
1491
1554
  if is_tma_warp:
1492
1555
  if const_expr(has_D):
1493
1556
  copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
1494
- cute.arch.cp_async_bulk_commit_group()
1495
- cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
1557
+ epi_store_pipeline.producer_commit()
1558
+ epi_store_pipeline.producer_acquire()
1496
1559
  epilogue_barrier.arrive_and_wait()
1497
1560
 
1498
1561
  return epi_read_state, epi_producer_state
@@ -1544,21 +1607,171 @@ class GemmSm90:
1544
1607
  tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
1545
1608
  return None
1546
1609
 
1547
- def get_scheduler_class(self):
1610
+ def tensormap_init(
1611
+ self,
1612
+ tensormaps: Optional[cute.Tensor],
1613
+ varlen_m: bool,
1614
+ varlen_k: bool,
1615
+ has_D: bool,
1616
+ warp_idx: Int32,
1617
+ ):
1618
+ tensormap_manager = None
1619
+ tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
1620
+ tensormap_epi_ptrs = [None] * self.num_epi_tensormaps
1621
+ if const_expr(varlen_m or varlen_k):
1622
+ tensormap_manager = TensorMapManagerSm90(
1623
+ cutlass.utils.TensorMapUpdateMode.GMEM, self.__class__.bytes_per_tensormap
1624
+ )
1625
+ # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
1626
+ tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
1627
+ if const_expr(varlen_m):
1628
+ tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
1629
+ tensormap_epi_offset = tensormap_d_idx
1630
+ if const_expr(has_D):
1631
+ tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
1632
+ tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
1633
+ )
1634
+ tensormap_epi_offset += 1 if not self.pingpong else 2
1635
+ tensormap_epi_ptrs = [
1636
+ tensormap_manager.get_tensormap_ptr(
1637
+ tensormaps[
1638
+ tensormap_workspace_idx,
1639
+ tensormap_epi_offset + i * (1 if not self.pingpong else 2),
1640
+ None,
1641
+ ].iterator
1642
+ )
1643
+ for i in range(self.num_epi_tensormaps)
1644
+ ]
1645
+ else:
1646
+ assert varlen_k
1647
+ if const_expr(not self.gather_A):
1648
+ tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
1649
+ tensormaps[tensormap_workspace_idx, 0, None].iterator
1650
+ )
1651
+ tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
1652
+ tensormaps[
1653
+ tensormap_workspace_idx, 1 if not self.gather_A else 0, None
1654
+ ].iterator
1655
+ )
1656
+ tensormap_ab_ptrs = [tensormap_a_ptr, tensormap_b_ptr]
1657
+ return (
1658
+ tensormap_manager,
1659
+ tensormap_ab_ptrs,
1660
+ tensormap_d_ptr,
1661
+ tensormap_epi_ptrs,
1662
+ )
1663
+
1664
+ def tensormap_update_AB(
1665
+ self,
1666
+ tensormap_manager,
1667
+ tensormap_ab_ptrs,
1668
+ cu_seqlens_k: cute.Tensor,
1669
+ batch_idx: Int32,
1670
+ is_manager_warp: bool | Boolean,
1671
+ ) -> None:
1672
+ # construct tensor A/B based on real address, shape and stride information
1673
+ tensormap_a_ptr, tensormap_b_ptr = tensormap_ab_ptrs
1674
+ tensormap_ptrs = [tensormap_b_ptr]
1675
+ shapes = [cu_seqlens_k[batch_idx + 1]]
1676
+ orders = [0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1]
1677
+ if const_expr(not self.gather_A):
1678
+ tensormap_ptrs.insert(0, tensormap_a_ptr)
1679
+ shapes.insert(0, cu_seqlens_k[batch_idx + 1])
1680
+ orders.insert(0, 0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1)
1681
+ tensormap_manager.update_tensormap_shape(
1682
+ tensormap_ptrs,
1683
+ is_manager_warp=is_manager_warp,
1684
+ shapes=shapes,
1685
+ orders=orders,
1686
+ tensormap_smem_ptr=None,
1687
+ )
1688
+
1689
+ def tensormap_update_D_epi(
1690
+ self,
1691
+ tensormap_manager,
1692
+ tensormap_d_ptr,
1693
+ tensormap_epi_ptrs,
1694
+ epilogue_params: EpilogueParams,
1695
+ cu_seqlens_m: cute.Tensor,
1696
+ batch_idx: Int32,
1697
+ is_manager_warp: bool | Boolean,
1698
+ ) -> None:
1699
+ # construct tensor D based on real address, shape and stride information
1700
+ tensormap_ptrs, shapes, orders = [], [], []
1701
+ if const_expr(tensormap_d_ptr is not None):
1702
+ tensormap_ptrs.append(tensormap_d_ptr)
1703
+ shapes.append(cu_seqlens_m[batch_idx + 1])
1704
+ orders.append(0 if const_expr(self.d_layout.is_m_major_c()) else 1)
1705
+ epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
1706
+ epilogue_params, cu_seqlens_m, batch_idx
1707
+ )
1708
+ tensormap_ptrs.extend(tensormap_epi_ptrs)
1709
+ shapes.extend(epi_shapes)
1710
+ orders.extend(epi_orders)
1711
+ tensormap_manager.update_tensormap_shape(
1712
+ tensormap_ptrs,
1713
+ is_manager_warp=is_manager_warp,
1714
+ shapes=shapes,
1715
+ orders=orders,
1716
+ tensormap_smem_ptr=None,
1717
+ )
1718
+
1719
+ def get_scheduler_class(self, varlen_m: bool = False):
1548
1720
  """Return the scheduler class to use. Override in subclasses for custom schedulers."""
1549
- return TileScheduler
1721
+ return TileScheduler if not varlen_m else VarlenMTileScheduler
1550
1722
 
1551
- def get_scheduler_arguments(self, problem_shape_ntile_mnl, scheduler_args):
1723
+ def get_scheduler_arguments(
1724
+ self,
1725
+ mA: cute.Tensor,
1726
+ mB: cute.Tensor,
1727
+ mD: Optional[cute.Tensor],
1728
+ scheduler_args,
1729
+ varlen_args,
1730
+ ):
1552
1731
  """Create scheduler arguments. Override in subclasses for custom schedulers."""
1553
- return TileSchedulerArguments(
1554
- problem_shape_ntile_mnl=problem_shape_ntile_mnl,
1555
- raster_order=scheduler_args.raster_order,
1556
- group_size=scheduler_args.max_swizzle_size,
1557
- cluster_shape_mnk=self.cluster_shape_mnk,
1558
- tile_count_semaphore=scheduler_args.tile_count_semaphore,
1559
- batch_idx_permute=scheduler_args.batch_idx_permute,
1560
- is_persistent=self.is_persistent,
1561
- )
1732
+ if const_expr(varlen_args.mCuSeqlensM is None):
1733
+ num_problems = (
1734
+ mD.shape[2]
1735
+ if mD is not None
1736
+ else (
1737
+ mB.shape[2]
1738
+ if varlen_args.mCuSeqlensK is None
1739
+ else varlen_args.mCuSeqlensK.shape[0] - 1
1740
+ )
1741
+ )
1742
+ problem_shape_ntile_mnl = (
1743
+ cute.ceil_div(mA.shape[0], self.cta_tile_shape_mnk[0]),
1744
+ cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
1745
+ num_problems,
1746
+ )
1747
+ tile_sched_args = TileSchedulerArguments(
1748
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
1749
+ raster_order=scheduler_args.raster_order,
1750
+ group_size=scheduler_args.max_swizzle_size,
1751
+ cluster_shape_mnk=self.cluster_shape_mnk,
1752
+ tile_count_semaphore=scheduler_args.tile_count_semaphore,
1753
+ batch_idx_permute=scheduler_args.batch_idx_permute,
1754
+ is_persistent=self.is_persistent,
1755
+ )
1756
+ else:
1757
+ assert mD is not None or not self.gather_A
1758
+ problem_shape_ntile_mnl = (
1759
+ None,
1760
+ cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
1761
+ varlen_args.mCuSeqlensM.shape[0] - 1,
1762
+ )
1763
+ tile_sched_args = VarlenMTileSchedulerArguments(
1764
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
1765
+ total_m=mD.shape[0] if mD is not None else varlen_args.mAIdx.shape[0],
1766
+ cu_seqlens_m=varlen_args.mCuSeqlensM,
1767
+ raster_order=scheduler_args.raster_order,
1768
+ group_size=scheduler_args.max_swizzle_size,
1769
+ tile_shape_mn=self.cta_tile_shape_mnk[:2],
1770
+ cluster_shape_mnk=self.cluster_shape_mnk,
1771
+ tile_count_semaphore=scheduler_args.tile_count_semaphore,
1772
+ is_persistent=self.is_persistent,
1773
+ )
1774
+ return tile_sched_args
1562
1775
 
1563
1776
  def epi_visit_acc(
1564
1777
  self,
@@ -1575,10 +1788,28 @@ class GemmSm90:
1575
1788
  ) -> EpilogueParams:
1576
1789
  return GemmSm90.EpilogueParams(alpha=args.alpha, beta=args.beta)
1577
1790
 
1791
+ def epi_get_tma_atoms(
1792
+ self, params: EpilogueParams, *, loc=None, ip=None
1793
+ ) -> list[cute.CopyAtom]:
1794
+ """Subclasses can override this"""
1795
+ return []
1796
+
1797
+ def epi_get_tensormap_update_shapes_orders(
1798
+ self,
1799
+ params: EpilogueParams,
1800
+ cu_seqlens_m: cute.Tensor,
1801
+ batch_idx: Int32,
1802
+ *,
1803
+ loc=None,
1804
+ ip=None,
1805
+ ) -> tuple[list[Int32], list[int]]:
1806
+ """Subclasses can override this"""
1807
+ return [], []
1808
+
1578
1809
  @staticmethod
1579
1810
  def epi_smem_bytes_per_stage(
1580
1811
  args: Optional[EpilogueArguments],
1581
- tile_shape_mnk: Tuple[int, int, int],
1812
+ cta_tile_shape_mnk: Tuple[int, int, int],
1582
1813
  epi_tile: Tuple[int, int],
1583
1814
  ) -> int:
1584
1815
  return 0
@@ -1589,7 +1820,7 @@ class GemmSm90:
1589
1820
  def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
1590
1821
  return tuple()
1591
1822
 
1592
- def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
1823
+ def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
1593
1824
  assert stage in ["mma", "epi"]
1594
1825
  barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1595
1826
  cute.arch.barrier(
@@ -1597,7 +1828,7 @@ class GemmSm90:
1597
1828
  number_of_threads=2 * self.num_threads_per_warp_group,
1598
1829
  )
1599
1830
 
1600
- def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
1831
+ def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
1601
1832
  assert stage in ["mma", "epi"]
1602
1833
  barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1603
1834
  cute.arch.barrier_arrive(
@@ -1691,8 +1922,6 @@ class GemmSm90:
1691
1922
 
1692
1923
  def make_ab_pipeline(
1693
1924
  self,
1694
- a_smem_layout: cute.Layout | cute.ComposedLayout,
1695
- b_smem_layout: cute.Layout | cute.ComposedLayout,
1696
1925
  tiled_mma: cute.TiledMma,
1697
1926
  cluster_layout_vmnk: cute.Layout,
1698
1927
  ab_pipeline_mbar_ptr: cute.Pointer,
@@ -1702,20 +1931,23 @@ class GemmSm90:
1702
1931
  ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
1703
1932
  # Each warp will contribute to the arrive count with the number of mcast size
1704
1933
  mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
1705
- consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
1934
+ consumer_arrive_cnt = mcast_size
1935
+ if const_expr(self.arch != 100):
1936
+ consumer_arrive_cnt *= tiled_mma.size // cute.arch.WARP_SIZE
1706
1937
  ab_pipeline_consumer_group = pipeline.CooperativeGroup(
1707
1938
  pipeline.Agent.Thread, consumer_arrive_cnt
1708
1939
  )
1709
- pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
1710
- tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
1711
- if const_expr(not self.gather_A):
1712
- tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
1940
+ if const_expr(self.arch != 100):
1941
+ pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
1942
+ else:
1943
+ # TODO: we need a pipeline class for TMACpAsyncUMMA
1944
+ pipeline_cls = pipeline.PipelineTmaUmma if not self.gather_A else PipelineTmaCpAsync
1713
1945
  return pipeline_cls.create(
1714
1946
  barrier_storage=ab_pipeline_mbar_ptr,
1715
1947
  num_stages=self.ab_stage,
1716
1948
  producer_group=ab_pipeline_producer_group,
1717
1949
  consumer_group=ab_pipeline_consumer_group,
1718
- tx_count=tma_copy_bytes,
1950
+ tx_count=self.num_tma_load_bytes,
1719
1951
  cta_layout_vmnk=cluster_layout_vmnk,
1720
1952
  )
1721
1953
 
@@ -1725,7 +1957,7 @@ class GemmSm90:
1725
1957
  # Threads/warps participating in this pipeline
1726
1958
  epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1727
1959
  # Each warp will contribute 1 to the arrive count
1728
- consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
1960
+ consumer_arrive_cnt = self.num_epi_warps
1729
1961
  epi_pipeline_consumer_group = pipeline.CooperativeGroup(
1730
1962
  pipeline.Agent.Thread, consumer_arrive_cnt
1731
1963
  )
@@ -1738,6 +1970,16 @@ class GemmSm90:
1738
1970
  tx_count=tma_copy_c_bytes,
1739
1971
  )
1740
1972
 
1973
+ def make_epi_store_pipeline(self):
1974
+ # Threads/warps participating in tma store pipeline
1975
+ num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
1976
+ epi_store_producer_group = pipeline.CooperativeGroup(
1977
+ pipeline.Agent.Thread, num_epi_threads, num_epi_threads
1978
+ )
1979
+ return pipeline.PipelineTmaStore.create(
1980
+ num_stages=self.epi_stage, producer_group=epi_store_producer_group
1981
+ )
1982
+
1741
1983
  def make_sched_pipeline(
1742
1984
  self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
1743
1985
  ):
@@ -1766,21 +2008,21 @@ class GemmSm90:
1766
2008
  @classmethod
1767
2009
  def _compute_stages(
1768
2010
  cls,
1769
- tile_shape_mnk: Tuple[int, int, int],
2011
+ cta_tile_shape_mnk: Tuple[int, int, int],
1770
2012
  epi_tile: Tuple[int, int],
1771
2013
  a_dtype: Type[cutlass.Numeric],
1772
2014
  b_dtype: Type[cutlass.Numeric],
1773
2015
  d_dtype: Optional[Type[cutlass.Numeric]],
1774
2016
  c_dtype: Optional[Type[cutlass.Numeric]],
1775
- epilogue_args: Optional[EpilogueArguments],
2017
+ epilogue_args: EpilogueArguments,
1776
2018
  smem_capacity: int,
1777
2019
  occupancy: int,
1778
- overlap_sD_sA: bool,
2020
+ overlap_sD_sA: bool = False,
1779
2021
  ) -> Tuple[int, int]:
1780
2022
  """Computes the number of stages for A/B/C operands based on heuristics.
1781
2023
 
1782
- :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
1783
- :type tile_shape_mnk: Tuple[int, int, int]
2024
+ :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
2025
+ :type cta_tile_shape_mnk: Tuple[int, int, int]
1784
2026
  :param a_dtype: Data type of operand A.
1785
2027
  :type a_dtype: type[cutlass.Numeric]
1786
2028
  :param b_dtype: Data type of operand B.
@@ -1803,15 +2045,15 @@ class GemmSm90:
1803
2045
  cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
1804
2046
  )
1805
2047
  epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
1806
- epilogue_args, tile_shape_mnk, epi_tile
2048
+ epilogue_args, cta_tile_shape_mnk, epi_tile
1807
2049
  )
1808
2050
  epi_bytes = epi_bytes_per_stage * epi_stage
1809
2051
  epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
1810
2052
  if c_dtype is not None:
1811
2053
  epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
1812
2054
 
1813
- a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
1814
- b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
2055
+ a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
2056
+ b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
1815
2057
  ab_bytes_per_stage = (
1816
2058
  cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
1817
2059
  )
@@ -1829,15 +2071,15 @@ class GemmSm90:
1829
2071
 
1830
2072
  @staticmethod
1831
2073
  def _sm90_compute_tile_shape_or_override(
1832
- tile_shape_mnk: Tuple[int, int, int],
2074
+ cta_tile_shape_mnk: Tuple[int, int, int],
1833
2075
  atom_layout_mnk: Tuple[int, int, int],
1834
2076
  element_type: Optional[Type[cutlass.Numeric]] = None,
1835
2077
  epi_tile_override: Tuple[int, int] | None = None,
1836
2078
  ) -> Tuple[int, int]:
1837
2079
  """Compute the epilogue tile shape or use override if provided.
1838
2080
 
1839
- :param tile_shape_mnk: CTA tile shape (M,N,K)
1840
- :type tile_shape_mnk: Tuple[int, int, int]
2081
+ :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
2082
+ :type cta_tile_shape_mnk: Tuple[int, int, int]
1841
2083
  :param element_type: Data type of elements
1842
2084
  :type element_type: type[cutlass.Numeric]
1843
2085
  :param is_cooperative: Whether to use cooperative approach
@@ -1850,12 +2092,12 @@ class GemmSm90:
1850
2092
  """
1851
2093
  if epi_tile_override is not None:
1852
2094
  return epi_tile_override
1853
- if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
1854
- tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
1855
- tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
1856
- elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
1857
- tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0]))
1858
- tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
2095
+ if cta_tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
2096
+ tile_m = math.gcd(128, cute.size(cta_tile_shape_mnk, mode=[0]))
2097
+ tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
2098
+ elif cta_tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
2099
+ tile_m = math.gcd(192, cute.size(cta_tile_shape_mnk, mode=[0]))
2100
+ tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
1859
2101
  else:
1860
2102
  # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
1861
2103
  # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
@@ -1864,13 +2106,13 @@ class GemmSm90:
1864
2106
  # We could change the epilogue to accommodate this,
1865
2107
  # but it's easier to just set epi_tile_m = 64.
1866
2108
  n_perf = 64 if element_type is not None and element_type.width == 8 else 32
1867
- tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
1868
- tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
2109
+ tile_m = math.gcd(64, cute.size(cta_tile_shape_mnk, mode=[0]))
2110
+ tile_n = math.gcd(n_perf, cute.size(cta_tile_shape_mnk, mode=[1]))
1869
2111
  return (tile_m, tile_n)
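As a standalone illustration of the gcd heuristic above, the sketch below mirrors its three branches in plain Python; the helper name epi_tile_for, the elem_width parameter, and the concrete tiles are illustrative choices, not part of the kernel's API.

import math

def epi_tile_for(cta_tile_shape_mnk, atom_layout_mnk, elem_width=16):
    m, n = cta_tile_shape_mnk[0], cta_tile_shape_mnk[1]
    if m % 128 == 0 and atom_layout_mnk[0] > 1:
        return math.gcd(128, m), math.gcd(32, n)
    if m % 192 == 0 and atom_layout_mnk[0] > 1:
        return math.gcd(192, m), math.gcd(32, n)
    n_perf = 64 if elem_width == 8 else 32     # wider epi_tile_n for 8-bit outputs
    return math.gcd(64, m), math.gcd(n_perf, n)

print(epi_tile_for((256, 192, 64), (2, 1, 1)))  # (128, 32)
print(epi_tile_for((128, 128, 64), (1, 2, 1)))  # (64, 32) -- the 128 x N, 1x2-atom case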
1870
2112
 
1871
2113
  @staticmethod
1872
2114
  def _make_smem_layouts(
1873
- tile_shape_mnk: Tuple[int, int, int],
2115
+ cta_tile_shape_mnk: Tuple[int, int, int],
1874
2116
  epi_tile: Tuple[int, int],
1875
2117
  a_dtype: Type[cutlass.Numeric],
1876
2118
  a_layout: LayoutEnum,
@@ -1888,8 +2130,8 @@ class GemmSm90:
1888
2130
  ]:
1889
2131
  """Create shared memory layouts for A, B, and C tensors.
1890
2132
 
1891
- :param tile_shape_mnk: CTA tile shape (M,N,K)
1892
- :type tile_shape_mnk: Tuple[int, int, int]
2133
+ :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
2134
+ :type cta_tile_shape_mnk: Tuple[int, int, int]
1893
2135
  :param epi_tile: Epilogue tile shape
1894
2136
  :type epi_tile: Tuple[int, int]
1895
2137
  :param a_dtype: Data type for matrix A
@@ -1912,11 +2154,11 @@ class GemmSm90:
1912
2154
  :return: Tuple of shared memory layouts for A, B, and C
1913
2155
  :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]
1914
2156
  """
1915
- a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
2157
+ a_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
1916
2158
 
1917
2159
  a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
1918
2160
  b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
1919
- a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
2161
+ a_major_mode_size = cta_tile_shape_mnk[2 if a_is_k_major else 0]
1920
2162
  a_smem_layout_atom = warpgroup.make_smem_layout_atom(
1921
2163
  sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
1922
2164
  a_dtype,
@@ -1927,9 +2169,9 @@ class GemmSm90:
1927
2169
  order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
1928
2170
  )
1929
2171
 
1930
- b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None))
2172
+ b_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
1931
2173
 
1932
- b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
2174
+ b_major_mode_size = cta_tile_shape_mnk[2 if b_is_k_major else 1]
1933
2175
  b_smem_layout_atom = warpgroup.make_smem_layout_atom(
1934
2176
  sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
1935
2177
  b_dtype,
@@ -1983,7 +2225,7 @@ class GemmSm90:
1983
2225
  tensor_d: cute.Tensor,
1984
2226
  epi_smem_layout_staged: cute.ComposedLayout,
1985
2227
  epi_tile: Tuple[int, int],
1986
- store_or_load: str,
2228
+ op_type: Literal["store", "load", "add"],
1987
2229
  ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1988
2230
  """Create TMA atoms and tensors for storing D or loading C.
1989
2231
 
@@ -1997,13 +2239,15 @@ class GemmSm90:
1997
2239
  :return: TMA atom and tensor for C
1998
2240
  :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1999
2241
  """
2000
- assert store_or_load in ["load", "store"]
2242
+ assert op_type in ["load", "store", "add"]
2001
2243
  epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
2002
2244
  d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
2003
2245
  op = (
2004
2246
  cpasync.CopyBulkTensorTileG2SOp()
2005
- if store_or_load == "load"
2247
+ if op_type == "load"
2006
2248
  else cpasync.CopyBulkTensorTileS2GOp()
2249
+ if op_type == "store"
2250
+ else cpasync.CopyReduceBulkTensorTileS2GOp(cute.ReductionOp.ADD)
2007
2251
  )
2008
2252
  tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
2009
2253
  op, tensor_d, epi_smem_layout, d_cta_v_layout
@@ -2013,7 +2257,7 @@ class GemmSm90:
2013
2257
  @staticmethod
2014
2258
  def _make_tma_atoms_and_tensors(
2015
2259
  tensor: cute.Tensor,
2016
- smem_layout_staged: cute.ComposedLayout,
2260
+ smem_layout: cute.ComposedLayout,
2017
2261
  smem_tile: Tuple[int, int],
2018
2262
  mcast_dim: int,
2019
2263
  ) -> Tuple[cute.CopyAtom, cute.Tensor]:
@@ -2021,8 +2265,8 @@ class GemmSm90:
2021
2265
 
2022
2266
  :param tensor: Input tensor (A or B)
2023
2267
  :type tensor: cute.Tensor
2024
- :param smem_layout_staged: Shared memory layout for the tensor
2025
- :type smem_layout_staged: cute.ComposedLayout
2268
+ :param smem_layout: Shared memory layout for the tensor
2269
+ :type smem_layout: cute.ComposedLayout
2026
2270
  :param smem_tile: Shared memory tile shape
2027
2271
  :type smem_tile: Tuple[int, int]
2028
2272
  :param mcast_dim: Multicast dimension
@@ -2036,8 +2280,6 @@ class GemmSm90:
2036
2280
  if mcast_dim == 1
2037
2281
  else cpasync.CopyBulkTensorTileG2SMulticastOp()
2038
2282
  )
2039
-
2040
- smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
2041
2283
  tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
2042
2284
  op,
2043
2285
  tensor,
@@ -2054,13 +2296,18 @@ class GemmSm90:
2054
2296
  num_bits_per_copy=copy_bits,
2055
2297
  )
2056
2298
  copy_elems = copy_bits // dtype.width
2057
- shape_dim_1 = cute.size(self.tile_shape_mnk[2]) // copy_elems
2299
+ loads_per_cache_line = 128 * 8 // copy_bits # 128 bytes per cache line
2300
+ shape_dim_1 = cute.size(self.cta_tile_shape_mnk[2]) // copy_elems
2301
+ if shape_dim_1 > loads_per_cache_line:
2302
+ shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line)
2058
2303
  # thread layout for copy
2059
2304
  thread_layout = cute.make_layout(
2060
2305
  (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
2061
2306
  )
2062
2307
  if major_mode != LayoutEnum.ROW_MAJOR:
2063
- shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
2308
+ shape_dim_0 = cute.size(self.cta_tile_shape_mnk[0]) // copy_elems
2309
+ if shape_dim_0 > loads_per_cache_line:
2310
+ shape_dim_0 = math.gcd(shape_dim_0, loads_per_cache_line)
2064
2311
  thread_layout = cute.make_layout(
2065
2312
  (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
2066
2313
  )
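The cache-line clamp above keeps each row of the copy thread layout within one 128-byte line. A plain-Python sketch with assumed values (128-bit copies of a 16-bit dtype, one 128-thread warp group, tile_K = 128):

import math

copy_bits = 128                                 # per-thread vectorized copy width (assumption)
dtype_width = 16                                # e.g. bf16
num_threads = 128                               # e.g. one warp group of loader threads
tile_K = 128                                    # hypothetical CTA tile K

copy_elems = copy_bits // dtype_width           # 8 elements per copy
loads_per_cache_line = 128 * 8 // copy_bits     # 8 copies span one 128-byte cache line

shape_dim_1 = tile_K // copy_elems              # 16 threads would sit along K ...
if shape_dim_1 > loads_per_cache_line:
    shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line)  # ... clamped to 8

# Row-major thread layout shape: each row of 8 threads covers exactly one
# 128-byte cache line, and the remaining threads stack along the other mode.
thread_layout_shape = (num_threads // shape_dim_1, shape_dim_1)
print(copy_elems, loads_per_cache_line, shape_dim_1, thread_layout_shape)  # 8 8 8 (16, 8)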
@@ -2142,10 +2389,11 @@ class GemmSm90:
2142
2389
 
2143
2390
 
2144
2391
  def gemm_sm90(
2145
- A: Tensor, # (l, m, k)
2146
- B: Tensor, # (l, n, k)
2147
- D: Tensor, # (l, m, n)
2148
- C: Optional[Tensor], # (l, m, n)
2392
+ # (l, m, k); (total_m, k) if varlen_m; (m, total_k) if varlen_k; with gather_A, the m extent (varlen_m) or k extent (varlen_k) of A is arbitrary and rows/columns are selected via A_idx
2393
+ A: Tensor,
2394
+ B: Tensor, # (l, n, k) or (n, total_k) if varlen_k
2395
+ D: Tensor, # (l, m, n) or (total_m, n) if varlen_m
2396
+ C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
2149
2397
  tile_count_semaphore: Optional[Tensor], # (1,)
2150
2398
  tile_M: int,
2151
2399
  tile_N: int,
@@ -2155,9 +2403,37 @@ def gemm_sm90(
2155
2403
  persistent: bool = True,
2156
2404
  alpha: float | Tensor = 1.0,
2157
2405
  beta: float | Tensor = 1.0,
2406
+ cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
2407
+ cu_seqlens_k: Optional[Tensor] = None, # (l+1,) cumulative sum of k values for variable length
2408
+ A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
2409
+ batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler
2410
+ add_to_output: bool = False,
2158
2411
  ) -> None:
2159
- L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(A, B, D, C)
2160
- GemmWrapperBase.permute_tensors(tensor_infos)
2412
+ varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
2413
+ assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
2414
+ "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
2415
+ )
2416
+ gather_A = A_idx is not None
2417
+ if gather_A:
2418
+ assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
2419
+ assert cluster_N == 1, "gather_A requires cluster_N=1"
2420
+ if varlen:
2421
+ assert persistent, "varlen requires persistent=True"
2422
+ if add_to_output:
2423
+ assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
2424
+ if cu_seqlens_m is not None:
2425
+ assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
2426
+ assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
2427
+ if cu_seqlens_k is not None:
2428
+ assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
2429
+ assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
2430
+
2431
+ L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
2432
+ A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
2433
+ )
2434
+ GemmWrapperBase.permute_tensors(
2435
+ tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
2436
+ )
2161
2437
  GemmWrapperBase.extract_dtypes(tensor_infos)
2162
2438
  major_configs = {
2163
2439
  "A": ("m", "k", "l"),
@@ -2190,10 +2466,25 @@ def gemm_sm90(
2190
2466
  assert isinstance(scalar, Tensor)
2191
2467
  return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
2192
2468
 
2193
- epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta))
2469
+ epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta), add_to_output)
2194
2470
  scheduler_args = GemmWrapperBase.create_scheduler_args(
2195
- max_active_clusters, tile_count_semaphore
2471
+ max_active_clusters,
2472
+ tile_count_semaphore,
2473
+ batch_idx_permute,
2474
+ )
2475
+
2476
+ # Create varlen arguments if needed (assumes persistent=True when varlen)
2477
+ varlen_args = GemmWrapperBase.create_varlen_args(
2478
+ cu_seqlens_m,
2479
+ cu_seqlens_k,
2480
+ A_idx,
2481
+ max_active_clusters,
2482
+ cluster_shape_mnk,
2483
+ tensor_infos,
2484
+ GemmSm90.num_epi_tensormaps,
2485
+ pingpong,
2196
2486
  )
2487
+
2197
2488
  current_stream = cutlass_torch.current_stream()
2198
2489
  compile_key = GemmWrapperBase.get_compile_key(
2199
2490
  tensor_infos,
@@ -2205,6 +2496,11 @@ def gemm_sm90(
2205
2496
  tile_count_semaphore is not None,
2206
2497
  2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
2207
2498
  2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
2499
+ add_to_output,
2500
+ cu_seqlens_m is not None,
2501
+ cu_seqlens_k is not None,
2502
+ gather_A,
2503
+ batch_idx_permute is not None,
2208
2504
  key_tensor_names=("A", "B", "D", "C"),
2209
2505
  )
2210
2506
  cache = gemm_sm90.compile_cache
@@ -2216,6 +2512,7 @@ def gemm_sm90(
2216
2512
  cluster_shape_mnk,
2217
2513
  pingpong=pingpong,
2218
2514
  is_persistent=persistent,
2515
+ gather_A=gather_A,
2219
2516
  )
2220
2517
  cache[compile_key] = cute.compile(
2221
2518
  gemm,
@@ -2225,8 +2522,7 @@ def gemm_sm90(
2225
2522
  tensor_infos["C"].cute_tensor,
2226
2523
  epi_args,
2227
2524
  scheduler_args,
2228
- None, # varlen_args
2229
- None, # mAIdx
2525
+ varlen_args,
2230
2526
  current_stream,
2231
2527
  )
2232
2528
  cache[compile_key](
@@ -2236,8 +2532,7 @@ def gemm_sm90(
2236
2532
  tensor_infos["C"].cute_tensor,
2237
2533
  epi_args,
2238
2534
  scheduler_args,
2239
- None,
2240
- None,
2535
+ varlen_args,
2241
2536
  current_stream,
2242
2537
  )
2243
2538
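To ground the new varlen_m arguments and assertions in the wrapper above, here is a hedged input-construction sketch in plain PyTorch. The sequence lengths, dtypes, and the int32 choice for cu_seqlens_m are illustrative assumptions, and the actual gemm_sm90 call additionally needs the tiling/cluster arguments elided from this hunk, so the call itself is only indicated in a comment.

import torch

seqlens = torch.tensor([3, 5, 2], device="cuda")                  # l = 3 variable m-extents
cu_seqlens_m = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)  # (l+1,)
l, total_m, k, n = seqlens.numel(), int(seqlens.sum()), 64, 128

A = torch.randn(total_m, k, device="cuda", dtype=torch.bfloat16)  # (total_m, k), k-major
B = torch.randn(l, n, k, device="cuda", dtype=torch.bfloat16)     # (l, n, k)
D = torch.zeros(total_m, n, device="cuda", dtype=torch.bfloat16)  # (total_m, n), n-major

# These are then passed as gemm_sm90(A, B, D, C=None, ..., cu_seqlens_m=cu_seqlens_m);
# per the assertions above, varlen_m requires persistent=True, a k-major A, and an
# n-major D, and add_to_output is not supported together with varlen_m.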