quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.4.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/gemm_sm100.py CHANGED
@@ -2,7 +2,7 @@
  # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py
 
  import argparse
- from typing import Optional, Type, Tuple, Union, Callable
+ from typing import Optional, Type, Tuple, Union, Callable, Literal
  from functools import partial
 
  import cuda.bindings.driver as cuda
@@ -15,14 +15,23 @@ import cutlass.torch as cutlass_torch
  import cutlass.pipeline as pipeline
  import cutlass.utils.blackwell_helpers as sm100_utils
  import cutlass.utils.blockscaled_layout as blockscaled_utils
+ from cutlass.cute.nvgpu.warp import (
+ LdMatrix8x8x16bOp,
+ LdMatrix16x16x8bOp,
+ StMatrix8x8x16bOp,
+ StMatrix16x8x8bOp,
+ )
  from cutlass import Int32, Float32, Boolean, const_expr
  from cutlass.utils import LayoutEnum
  from cutlass.cute.runtime import from_dlpack, make_ptr
 
+ from quack.pipeline import PipelineTmaCpAsyncUmma
  from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
  from quack.tile_scheduler import TileSchedulerOptions
- from quack.varlen_utils import VarlenArguments
- from quack.dense_gemm_sm90 import GemmSm90, NamedBarrierGemm
+ from quack.varlen_utils import VarlenArguments, VarlenManager
+ from quack.gemm_sm90 import GemmSm90, NamedBarrierGemm
+ import quack.copy_utils as copy_utils
+ import quack.sm100_utils as quack_sm100_utils
 
  # return PipelineStateWAdvance instead of PipelineState
 
@@ -148,6 +157,7 @@ class GemmSm100(GemmSm90):
  def __init__(
  self,
  acc_dtype: Type[cutlass.Numeric],
+ a_dtype: Type[cutlass.Numeric], # ignored for now
  mma_tiler_mn: Tuple[int, int],
  cluster_shape_mnk: Tuple[int, int, int],
  sf_vec_size: Optional[int] = None,
@@ -175,7 +185,7 @@ class GemmSm100(GemmSm90):
  """
 
  self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
- self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (128, 256)
+ self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (256,)
  self.cluster_shape_mnk = cluster_shape_mnk
  assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1"
  # K dimension is deferred in _setup_attributes
@@ -190,19 +200,28 @@ class GemmSm100(GemmSm90):
 
  self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
 
+ self.num_ab_load_warps = 1 if not self.gather_A else 5
  self.occupancy = 1
  # Set specialized warp ids
  self.epilog_warp_id = (0, 1, 2, 3)
  self.mma_warp_id = 4
- self.tma_warp_id = 5
- self.tma_epi_warp_id = 6
+ self.ab_load_warp_id = 5
+ self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
+ self.scheduler_warp_id = self.epi_load_warp_id + 1
  self.num_epi_warps = len(self.epilog_warp_id)
- self.threads_per_cta = 32 * len(
- (self.mma_warp_id, self.tma_warp_id, self.tma_epi_warp_id, *self.epilog_warp_id)
+ self.threads_per_cta = cute.arch.WARP_SIZE * (
+ self.num_ab_load_warps
+ + len(
+ (
+ self.mma_warp_id,
+ self.epi_load_warp_id,
+ self.scheduler_warp_id,
+ *self.epilog_warp_id,
+ )
+ )
  )
- self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_100")
 
- def _setup_attributes(self):
+ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments):
  """Set up configurations that are dependent on GEMM inputs
 
  This method configures various attributes based on the input tensor properties
@@ -298,6 +317,8 @@ class GemmSm100(GemmSm90):
 
  # Compute number of multicast CTAs for A/B
  self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+ if self.gather_A:
+ assert self.num_mcast_ctas_a == 1
  self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
  self.is_a_mcast = self.num_mcast_ctas_a > 1
  self.is_b_mcast = self.num_mcast_ctas_b > 1
@@ -309,11 +330,18 @@ class GemmSm100(GemmSm90):
  self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
  self.cta_tile_shape_mnk,
  self.use_2cta_instrs,
- self.d_layout,
- self.d_dtype,
+ self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR,
+ self.d_dtype if self.d_dtype is not None else cutlass.BFloat16,
+ layout_c=self.c_layout,
+ elem_ty_c=self.c_dtype,
  )
 
  # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
+ prefetch_A_idx = (
+ None
+ if not self.gather_A
+ else ("varlen_m" if varlen_args.mCuSeqlensM is not None else "varlen_k")
+ )
  (
  self.num_acc_stage,
  self.ab_stage,
@@ -322,36 +350,50 @@ class GemmSm100(GemmSm90):
  ) = self._compute_stages(
  self.tiled_mma,
  self.mma_tiler,
+ self.cta_tile_shape_mnk,
+ self.epi_tile,
  self.a_dtype,
  self.b_dtype,
- self.epi_tile,
+ self.sf_dtype,
+ self.sf_vec_size,
  self.d_dtype,
  self.c_dtype,
  self.d_layout,
  self.c_layout,
- self.sf_dtype,
- self.sf_vec_size,
- self.smem_capacity,
+ epilogue_args,
+ prefetch_A_idx,
+ cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
  self.occupancy,
  )
- self.sched_stage = 1 # For compatibility with GemmSm90
+ self.sched_stage = 1
+ self.a_prefetch_stage = (
+ 0
+ if not self.gather_A
+ else (2 if varlen_args.mCuSeqlensM is not None else self.ab_stage)
+ )
 
  # Compute A/B/SFA/SFB/C shared memory layout
  self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
  self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
  )
+ self.a_smem_load_layout_staged = self.a_smem_layout_staged
+ if const_expr(self.gather_A):
+ self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_cpasync_a(
+ self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
+ )
  self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
  self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage
  )
- self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi(
- self.d_dtype, self.d_layout, self.epi_tile, self.epi_stage
- )
+ self.epi_smem_layout_staged = None
+ if const_expr(self.d_dtype is not None):
+ self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.d_dtype, self.d_layout, self.epi_tile, self.epi_stage
+ )
+ self.epi_c_smem_layout_staged = None
  if const_expr(self.c_dtype is not None):
  self.epi_c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
  self.c_dtype, self.c_layout, self.epi_tile, self.epi_c_stage
  )
- else:
- self.epi_c_smem_layout_staged = None
  if const_expr(self.blockscaled):
  self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
  self.tiled_mma,
@@ -449,7 +491,7 @@ class GemmSm100(GemmSm90):
  ]
 
  # Setup attributes that dependent on gemm inputs
- self._setup_attributes()
+ self._setup_attributes(epilogue_args, varlen_args)
 
  if const_expr(self.blockscaled):
  # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout
@@ -536,24 +578,22 @@ class GemmSm100(GemmSm90):
  # Setup TMA store for D
  tma_atom_d, tma_tensor_d = None, None
  if const_expr(mD is not None):
- epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0))
- tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
- cpasync.CopyBulkTensorTileS2GOp(),
+ tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
  mD,
- epi_smem_layout,
+ self.epi_smem_layout_staged,
  self.epi_tile,
+ op_type="store"
+ if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
+ else "add",
  )
  tma_atom_c, tma_tensor_c = None, None
  if const_expr(mC is not None):
- epi_c_smem_layout = cute.slice_(self.epi_c_smem_layout_staged, (None, None, 0))
- tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
- cpasync.CopyBulkTensorTileG2SOp(),
- mC,
- epi_c_smem_layout,
- self.epi_tile,
+ tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
+ mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
  )
 
  epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
+ varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
 
  TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
  tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
@@ -573,6 +613,13 @@ class GemmSm100(GemmSm90):
  sfb_smem_size = (
  cute.cosize(self.sfb_smem_layout_staged) if const_expr(self.blockscaled) else 0
  )
+ a_idx_smem_size = 0
+ if const_expr(self.gather_A):
+ a_idx_smem_size = self.a_prefetch_stage * (
+ self.cta_tile_shape_mnk[0]
+ if varlen_args.mCuSeqlensM is not None
+ else self.cta_tile_shape_mnk[2]
+ )
 
  # Define shared storage for kernel
  @cute.struct
@@ -581,9 +628,13 @@ class GemmSm100(GemmSm90):
  epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
  acc_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
  sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
- tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
+ a_prefetch_pipeline_array_ptr: cute.struct.MemRange[
+ cutlass.Int64, self.a_prefetch_stage * 2
+ ]
+ tile_count: cute.struct.MemRange[Int32, self.sched_stage]
  tmem_dealloc_mbar_ptr: cutlass.Int64
  tmem_holding_buf: Int32
+ sAIdx: cute.struct.Align[cute.struct.MemRange[Int32, a_idx_smem_size], 16]
  # (EPI_TILE_M, EPI_TILE_N, STAGE)
  sD: cute.struct.Align[
  cute.struct.MemRange[
@@ -638,13 +689,11 @@ class GemmSm100(GemmSm90):
  tma_atom_c,
  tma_tensor_c,
  epilogue_params,
- varlen_args.mCuSeqlensM,
- varlen_args.mCuSeqlensK,
- varlen_args.mTensormaps,
- varlen_args.mAIdx,
+ varlen_params,
  self.cluster_layout_vmnk,
  self.cluster_layout_sfb_vmnk,
  self.a_smem_layout_staged,
+ self.a_smem_load_layout_staged,
  self.b_smem_layout_staged,
  self.sfa_smem_layout_staged,
  self.sfb_smem_layout_staged,
@@ -657,7 +706,6 @@ class GemmSm100(GemmSm90):
  grid=grid,
  block=[self.threads_per_cta, 1, 1],
  cluster=self.cluster_shape_mnk,
- smem=self.shared_storage.size_in_bytes(),
  stream=stream,
  min_blocks_per_mp=1,
  )
@@ -682,13 +730,11 @@ class GemmSm100(GemmSm90):
  tma_atom_c: Optional[cute.CopyAtom],
  mC_mnl: Optional[cute.Tensor],
  epilogue_params: ParamsBase,
- cu_seqlens_m: Optional[cute.Tensor],
- cu_seqlens_k: Optional[cute.Tensor],
- tensormaps: Optional[cute.Tensor],
- mAIdx: Optional[cute.Tensor],
+ varlen_params: VarlenManager.Params,
  cluster_layout_vmnk: cute.Layout,
  cluster_layout_sfb_vmnk: Optional[cute.Layout],
  a_smem_layout: cute.ComposedLayout,
+ a_smem_load_layout: cute.ComposedLayout,
  b_smem_layout: cute.ComposedLayout,
  sfa_smem_layout: Optional[cute.Layout],
  sfb_smem_layout: Optional[cute.Layout],
@@ -702,8 +748,8 @@ class GemmSm100(GemmSm90):
  GPU device kernel performing the Persistent batched GEMM computation.
  """
 
- varlen_m = const_expr(cu_seqlens_m is not None)
- varlen_k = const_expr(cu_seqlens_k is not None)
+ varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
+ varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
  assert not (varlen_m and varlen_k)
  if const_expr(self.gather_A):
  assert varlen_m or varlen_k
@@ -715,7 +761,7 @@ class GemmSm100(GemmSm90):
  # /////////////////////////////////////////////////////////////////////////////
  # Prefetch Tma desc
  # /////////////////////////////////////////////////////////////////////////////
- if warp_idx == self.tma_warp_id:
+ if warp_idx == self.ab_load_warp_id:
  for tma_atom in (
  tma_atom_a,
  tma_atom_b,
@@ -751,7 +797,7 @@ class GemmSm100(GemmSm90):
 
  # Tensor memory dealloc barrier init
  if use_2cta_instrs:
- if warp_idx == self.tma_warp_id:
+ if warp_idx == self.ab_load_warp_id:
  num_tmem_dealloc_threads = 32
  cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)
 
@@ -760,6 +806,7 @@ class GemmSm100(GemmSm90):
  tiled_mma=tiled_mma,
  cluster_layout_vmnk=cluster_layout_vmnk,
  ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
+ is_leader_cta=is_leader_cta,
  )
  epi_pipeline = None
  if const_expr(has_C):
@@ -774,20 +821,30 @@ class GemmSm100(GemmSm90):
  sched_pipeline = None
  tile_count = None
  if const_expr(tile_sched_params.tile_count_semaphore is not None):
- # TODO: Untested, not sure if this is right for Sm100
  # Dynamic persistent scheduler
  sched_pipeline = self.make_sched_pipeline(
  self.cluster_shape_mnk,
  sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
- varlen_k=varlen_k,
+ has_C=has_C,
  )
  tile_count = storage.tile_count.get_tensor((self.sched_stage,))
+ a_prefetch_pipeline = None
+ if const_expr(self.gather_A):
+ a_prefetch_pipeline = self.make_a_prefetch_pipeline(
+ storage.a_prefetch_pipeline_array_ptr.data_ptr(),
+ )
 
  # Setup smem tensor A/B/D
  # (MMA, MMA_M, MMA_K, STAGE)
- sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
+ sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
+ sA = storage.sA.get_tensor(a_smem_load_layout.outer, swizzle=a_smem_load_layout.inner)
  # (MMA, MMA_N, MMA_K, STAGE)
  sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
+ sAIdx = None
+ if const_expr(self.gather_A):
+ a_idx_smem_dim = self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2]
+ a_idx_smem_layout = cute.make_layout((a_idx_smem_dim, self.a_prefetch_stage))
+ sAIdx = storage.sAIdx.get_tensor(a_idx_smem_layout)
  sSFA, sSFB = None, None
  if const_expr(self.blockscaled):
  # (MMA, MMA_M, MMA_K, STAGE)
@@ -813,9 +870,17 @@ class GemmSm100(GemmSm90):
  # (MMA, MMA_M, MMA_N, STAGE)
  tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
 
- # Get tensormap buffer address
- tensormap_manager, tensormap_ab_ptrs, tensormap_d_ptr, tensormap_epi_ptrs = (
- self.tensormap_init(tensormaps, varlen_m, varlen_k, has_D, warp_idx)
+ varlen_manager = VarlenManager.create(
+ varlen_params,
+ has_D,
+ self.num_epi_tensormaps,
+ # Only used if not varlen_m
+ len_m_static=Int32(
+ mA_mkl.shape[0]
+ if varlen_k or varlen_params.mAIdx is None
+ else varlen_params.mAIdx.shape[0]
+ ),
+ len_k_static=Int32(mA_mkl.shape[1]),
  )
 
  TileSchedulerCls = partial(
@@ -833,22 +898,14 @@ class GemmSm100(GemmSm90):
  )
 
  #
- # Specialized TMA load warp
+ # Specialized AB load warps
  #
- if warp_idx == self.tma_warp_id:
- if const_expr(varlen_k):
- # initialize tensormap for A & B
- if const_expr(not self.gather_A):
- tensormap_manager.init_tensormap_from_atom(
- tma_atom_a,
- tensormap_ab_ptrs[0],
- is_manager_warp=True,
- )
- tensormap_manager.init_tensormap_from_atom(
- tma_atom_b,
- tensormap_ab_ptrs[1],
- is_manager_warp=True,
- )
+ if warp_idx == self.ab_load_warp_id:
+ is_tma_warp = True
+ # initialize tensormap for A & B
+ varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
+ tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
+ tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
  # Compute multicast mask for A/B buffer full
  block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
  block_in_cluster_coord_sfb_vmnk = None
@@ -874,34 +931,24 @@ class GemmSm100(GemmSm90):
  )
 
  # Persistent tile scheduling loop
- is_scheduler_warp = True
- if const_expr(cute.size(cluster_layout_vmnk) > 1):
- is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0
- tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
+ tile_scheduler = TileSchedulerCls()
  work_tile = tile_scheduler.initial_work_tile_info()
  ab_producer_state = pipeline.make_pipeline_state(
  pipeline.PipelineUserType.Producer, self.ab_stage
  )
  if const_expr(varlen_k):
  # wait tensormap initialization complete before update
- tensormap_manager.fence_tensormap_initialization()
- # batch index of last tile
- last_batch_idx = cutlass.Int32(-1)
+ varlen_manager.fence_tensormap_init()
  do_epi_load_barrier_arrive = Boolean(True)
  while work_tile.is_valid_tile:
  tile_coord_mnkl = work_tile.tile_idx
  batch_idx = tile_coord_mnkl[3]
- if const_expr(varlen_k):
- is_group_changed = batch_idx != last_batch_idx
- last_batch_idx = batch_idx
- if is_group_changed:
- self.tensormap_update_AB(
- tensormap_manager,
- tensormap_ab_ptrs,
- cu_seqlens_k,
- batch_idx,
- is_manager_warp=True,
- )
+ varlen_manager.update_tensormap_AB(
+ batch_idx,
+ self.a_layout,
+ self.b_layout,
+ is_tma_warp,
+ )
  # ///////////////////////////////////////////////////////////////////////////
  # Local_tile partition global tensors
  # ///////////////////////////////////////////////////////////////////////////
@@ -910,120 +957,111 @@ class GemmSm100(GemmSm90):
  tile_coord_mnkl[1],
  tile_coord_mnkl[3],
  )
- # TODO: varlen_m
- # (bM, bK, RestK)
- gA_mkl = cute.local_tile(
- mA_mkl,
- cute.slice_(self.mma_tiler, (None, 0, None)),
- (mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]),
- )
+ gA_mk = None
+ if const_expr(not self.gather_A):
+ mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
+ # (bM, bK, RestK)
+ gA_mk = cute.local_tile(
+ mA_mk,
+ cute.select(self.mma_tiler, [0, 2]),
+ (mma_tile_coord_mnl[0], None),
+ )
  # (bN, bK, RestK)
- gB_nkl = cute.local_tile(
- mB_nkl,
- cute.slice_(self.mma_tiler, (0, None, None)),
- (mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]),
+ gB_nk = cute.local_tile(
+ varlen_manager.offset_batch_B(mB_nkl, batch_idx),
+ cute.select(self.mma_tiler, [1, 2]),
+ (mma_tile_coord_mnl[1], None),
  )
  if const_expr(self.blockscaled):
  # (bM, bK)
  gSFA_mkl = cute.local_tile(
- mSFA_mkl,
- cute.slice_(self.mma_tiler, (None, 0, None)),
- (mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]),
+ varlen_manager.offset_batch_A(mSFA_mkl, batch_idx),
+ cute.select(self.mma_tiler, [0, 2]),
+ (mma_tile_coord_mnl[0], None),
  )
  # (bN, bK)
  gSFB_nkl = cute.local_tile(
- mSFB_nkl,
- cute.slice_(self.mma_tiler, (0, None, None)),
- (mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]),
+ varlen_manager.offset_batch_B(mSFB_nkl, batch_idx),
+ cute.select(self.mma_tiler, [1, 2]),
+ (mma_tile_coord_mnl[1], None),
  )
+
  # Partition global tensor for TiledMMA_A/B/D
- # (MMA, MMA_M, MMA_K, RestK)
- tCgA = thr_mma.partition_A(gA_mkl)
+ # Then partition global/shared tensor for TMA load A/B
+ varlen_manager.fence_tensormap_update_AB(is_tma_warp)
+ len_k = varlen_manager.len_k(batch_idx)
+ # TMA load A partition_S/D
+ a_cta_layout = cute.make_layout(
+ cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
+ )
+ copy_A = None
+ if const_expr(not self.gather_A):
+ # (MMA, MMA_M, MMA_K, RestK)
+ tCgA = thr_mma.partition_A(gA_mk)
+ copy_A, _, _ = copy_utils.tma_get_copy_fn(
+ tma_atom_a,
+ cta_coord=block_in_cluster_coord_vmnk[2],
+ cta_layout=a_cta_layout,
+ src_tensor=tCgA,
+ dst_tensor=sA,
+ mcast_mask=a_mcast_mask,
+ tma_desc_ptr=tma_desc_a_ptr,
+ )
  # (MMA, MMA_N, MMA_K, RestK)
- tCgB = thr_mma.partition_B(gB_nkl)
+ tCgB = thr_mma.partition_B(gB_nk)
  if const_expr(self.blockscaled):
  # (MMA, MMA_M, MMA_K)
  tCgSFA = thr_mma.partition_A(gSFA_mkl)
  # (MMA, MMA_N, MMA_K)
  tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
- # Partition global/shared tensor for TMA load A/B
- # TMA load A partition_S/D
- a_cta_layout = cute.make_layout(
- cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
- )
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tAsA, tAgA = cpasync.tma_partition(
- tma_atom_a,
- block_in_cluster_coord_vmnk[2],
- a_cta_layout,
- cute.group_modes(sA, 0, 3),
- cute.group_modes(tCgA, 0, 3),
- )
  # TMA load B partition_S/D
- b_cta_layout = cute.make_layout(
- cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
- )
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tBsB, tBgB = cpasync.tma_partition(
+ copy_B, _, _ = copy_utils.tma_get_copy_fn(
  tma_atom_b,
- block_in_cluster_coord_vmnk[1],
- b_cta_layout,
- cute.group_modes(sB, 0, 3),
- cute.group_modes(tCgB, 0, 3),
+ cta_coord=block_in_cluster_coord_vmnk[1],
+ cta_layout=cute.make_layout(
+ cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
+ ),
+ src_tensor=tCgB,
+ dst_tensor=sB,
+ mcast_mask=b_mcast_mask,
+ tma_desc_ptr=tma_desc_b_ptr,
  )
+ copy_SFA, copy_SFB = None, None
  if const_expr(self.blockscaled):
  # TMA load SFA partition_S/D
- sfa_cta_layout = a_cta_layout
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tAsSFA, tAgSFA = cpasync.tma_partition(
+ copy_SFA, _, _ = copy_utils.tma_get_copy_fn(
  tma_atom_sfa,
- block_in_cluster_coord_vmnk[2],
- sfa_cta_layout,
- cute.group_modes(sSFA, 0, 3),
- cute.group_modes(tCgSFA, 0, 3),
+ cta_coord=block_in_cluster_coord_vmnk[2],
+ cta_layout=a_cta_layout,
+ src_tensor=tCgSFA,
+ dst_tensor=sSFA,
+ filter_zeros=True,
+ mcast_mask=sfa_mcast_mask,
+ # tma_desc_ptr=tma_desc_sfa_ptr,
  )
- tAsSFA = cute.filter_zeros(tAsSFA)
- tAgSFA = cute.filter_zeros(tAgSFA)
  # TMA load SFB partition_S/D
  sfb_cta_layout = cute.make_layout(
  cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape
  )
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tBsSFB, tBgSFB = cpasync.tma_partition(
+ copy_SFB, _, _ = copy_utils.tma_get_copy_fn(
  tma_atom_sfb,
- block_in_cluster_coord_sfb_vmnk[1],
- sfb_cta_layout,
- cute.group_modes(sSFB, 0, 3),
- cute.group_modes(tCgSFB, 0, 3),
+ cta_coord=block_in_cluster_coord_sfb_vmnk[1],
+ cta_layout=sfb_cta_layout,
+ src_tensor=tCgSFB,
+ dst_tensor=sSFB,
+ filter_zeros=True,
+ mcast_mask=sfb_mcast_mask,
+ # tma_desc_ptr=tma_desc_sfa_ptr,
  )
- tBsSFB = cute.filter_zeros(tBsSFB)
- tBgSFB = cute.filter_zeros(tBgSFB)
- else:
- tAsSFA, tAgSFA = None, None
- tBsSFB, tBgSFB = None, None
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
  ab_producer_state = self.load_AB(
  ab_pipeline,
  ab_producer_state,
- tma_atom_a,
- tAgA,
- tAsA,
- a_mcast_mask,
- tma_atom_b,
- tBgB,
- tBsB,
- b_mcast_mask,
- tma_atom_sfa,
- tAgSFA,
- tAsSFA,
- sfa_mcast_mask,
- tma_atom_sfb,
- tBgSFB,
- tBsSFB,
- sfb_mcast_mask,
+ copy_A,
+ copy_B,
+ k_tile_cnt,
+ copy_SFA,
+ copy_SFB,
  )
  if const_expr(epi_load_barrier is not None):
  # In the first work tile, the epi load warp will wait for the signal
@@ -1033,19 +1071,180 @@ class GemmSm100(GemmSm90):
  epi_load_barrier.arrive()
  do_epi_load_barrier_arrive = Boolean(False)
  # Advance to next tile
- tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
  tile_scheduler.advance_to_next_work()
  work_tile = tile_scheduler.get_current_work()
  # Wait A/B buffer empty
  ab_pipeline.producer_tail(ab_producer_state)
- if is_scheduler_warp:
- tile_scheduler.producer_tail()
+
+ if const_expr(self.gather_A):
+ if (
+ warp_idx >= self.ab_load_warp_id + 1
+ and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
+ ):
+ # Persistent tile scheduling loop
+ tile_scheduler = TileSchedulerCls()
+ work_tile = tile_scheduler.initial_work_tile_info()
+ ab_producer_state = pipeline.make_pipeline_state(
+ pipeline.PipelineUserType.Producer, self.ab_stage
+ )
+ a_prefetch_consumer_state = pipeline.make_pipeline_state(
+ pipeline.PipelineUserType.Consumer, self.a_prefetch_stage
+ )
+ while work_tile.is_valid_tile:
+ tile_coord_mnkl = work_tile.tile_idx
+ batch_idx = tile_coord_mnkl[3]
+ # ///////////////////////////////////////////////////////////////////////////
+ # Local_tile partition global tensors
+ # ///////////////////////////////////////////////////////////////////////////
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
+ if const_expr(varlen_m):
+ # (M, K)
+ mA_mk = mA_mkl
+ else:
+ assert varlen_k
+ # (tile_M, K)
+ mA_mk = cute.local_tile(
+ mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
+ )
+ # Partition global tensor for TiledMMA_A/B/D
+ len_m = varlen_manager.len_m(batch_idx)
+ len_k = varlen_manager.len_k(batch_idx)
+ # TMA load A partition_S/D
+ tiled_copy_A = self._make_gmem_tiled_copy_A(
+ mA_mkl.element_type, self.a_layout, (self.num_ab_load_warps - 1) * 32
+ )
+ tidx = cute.arch.thread_idx()[0] - (self.ab_load_warp_id + 1) * 32
+ thr_copy_A = tiled_copy_A.get_slice(tidx)
+ copy_A, prefetch_A = None, None
+ if const_expr(varlen_m):
+ a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
+ copy_A = copy_utils.gather_m_get_copy_fn(
+ thr_copy_A,
+ mA_mk,
+ sA,
+ sAIdx[None, a_prefetch_consumer_state.index],
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+ limit_k=len_k,
+ )
+ cute.arch.sync_warp()
+ with cute.arch.elect_one():
+ a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
+ a_prefetch_consumer_state.advance()
+ else:
+ copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
+ thr_copy_A,
+ mA_mk,
+ sA,
+ sAIdx,
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+ limit_k=len_k,
+ )
+ prefetch_A = partial(prefetch_A, a_prefetch_pipeline)
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
+ ab_producer_state, a_prefetch_consumer_state = self.load_A_gather_A(
+ ab_pipeline,
+ ab_producer_state,
+ a_prefetch_consumer_state,
+ copy_A,
+ prefetch_A,
+ k_tile_cnt,
+ )
+ # Advance to next tile
+ tile_scheduler.advance_to_next_work()
+ work_tile = tile_scheduler.get_current_work()
+
+ #
+ # Specialized scheduler warp. Will also prefetch A indices if gatherA
+ #
+ if const_expr(tile_sched_params.tile_count_semaphore is not None or self.gather_A):
+ if warp_idx == self.scheduler_warp_id:
+ is_scheduler_warp = True
+ if const_expr(cute.size(cluster_layout_vmnk) > 1):
+ is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0
+ tile_M = self.cta_tile_shape_mnk[0]
+ tile_K = self.cta_tile_shape_mnk[2]
+ thr_copy_AIdx, tAsAIdx, tAcAIdx = None, None, None
+ if const_expr(self.gather_A):
+ tiled_copy_AIdx = copy_utils.tiled_copy_1d(Int32, num_threads=32, is_async=True)
+ thr_copy_AIdx = tiled_copy_AIdx.get_slice(cute.arch.lane_idx())
+ tAsAIdx = thr_copy_AIdx.partition_D(sAIdx)
+ tAcAIdx = thr_copy_AIdx.partition_S(
+ cute.make_identity_tensor(tile_M if varlen_m else tile_K)
+ )
+ # Persistent tile scheduling loop
+ tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
+ work_tile = tile_scheduler.initial_work_tile_info()
+ a_prefetch_producer_state = None
+ if const_expr(self.gather_A):
+ a_prefetch_producer_state = pipeline.make_pipeline_state(
+ pipeline.PipelineUserType.Producer, self.a_prefetch_stage
+ )
+ while work_tile.is_valid_tile:
+ if const_expr(self.gather_A):
+ tile_coord_mnkl = work_tile.tile_idx
+ batch_idx = tile_coord_mnkl[3]
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
+ if const_expr(varlen_m):
+ # (tile_M,)
+ gAIdx = cute.local_tile(mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],))
+ tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
+ len_m = varlen_manager.len_m(batch_idx)
+ m_limit = len_m - tile_coord_mnkl[0] * tile_M
+ tApAIdx_m = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
+ for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
+ tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit
+ a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+ cute.copy(
+ thr_copy_AIdx,
+ tAgAIdx,
+ tAsAIdx[None, None, a_prefetch_producer_state.index],
+ pred=tApAIdx_m,
+ )
+ a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+ a_prefetch_producer_state.advance()
+ else:
+ # (tile_K, RestK)
+ gAIdx = cute.flat_divide(mAIdx_mk, (tile_K,))
+ tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
+ len_k = varlen_manager.len_k(batch_idx)
+ k_tile_cnt = cute.ceil_div(len_k, tile_K)
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
+ a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+ cute.copy(
+ thr_copy_AIdx,
+ tAgAIdx[None, None, k_tile],
+ tAsAIdx[None, None, a_prefetch_producer_state.index],
+ )
+ a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+ a_prefetch_producer_state.advance()
+ if 0 < k_tile_cnt:
+ k_tile = k_tile_cnt - 1
+ k_limit = len_k - k_tile * tile_K
+ tApAIdx_k = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
+ for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
+ tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit
+ a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+ cute.copy(
+ tiled_copy_AIdx,
+ tAgAIdx[None, None, k_tile],
+ tAsAIdx[None, None, a_prefetch_producer_state.index],
+ pred=tApAIdx_k,
+ )
+ a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+ a_prefetch_producer_state.advance()
+ # Advance to next tile
+ tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+ work_tile = tile_scheduler.get_current_work()
+ # End of persistent scheduler loop
+ if is_scheduler_warp:
+ tile_scheduler.producer_tail()
 
  #
  # Specialized TMA epi load warp
  #
  if const_expr(mC_mnl is not None):
- if warp_idx == self.tma_epi_warp_id:
+ if warp_idx == self.epi_load_warp_id:
  epi_producer_state = pipeline.make_pipeline_state(
  pipeline.PipelineUserType.Producer, self.epi_c_stage
  )
@@ -1056,37 +1255,23 @@ class GemmSm100(GemmSm90):
  while work_tile.is_valid_tile:
  # Get tile coord from tile scheduler
  tile_coord_mnkl = work_tile.tile_idx
- # TODO: varlen_m
- mma_tile_coord_mnl = (
- tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
- tile_coord_mnkl[1],
- tile_coord_mnkl[3],
- )
- # Local_tile partition global tensors
- # (bM, bN)
- gC_mnl = cute.local_tile(
- mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
+ batch_idx = tile_coord_mnkl[3]
+ copy_C_fn, _, bGS_gC = self.epilog_gmem_copy_and_partition(
+ tma_atom_c,
+ varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
+ self.cta_tile_shape_mnk[:2],
+ epi_tile,
+ sC,
+ tile_coord_mnkl,
  )
- # Partition global tensor for TiledMMA_A/B/D
- # (MMA, MMA_M, MMA_N)
- tCgC = thr_mma.partition_C(gC_mnl)
- # bGS_gC has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
- bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
- tma_atom_c, tCgC, epi_tile, sC
- )
- bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
+ copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
  if do_epi_load_barrier_wait:
  epi_load_barrier.arrive_and_wait()
  do_epi_load_barrier_wait = Boolean(False)
  epi_tile_num = const_expr(cute.size(bGS_gC, mode=[1]))
- for subtile_idx in cutlass.range(epi_tile_num, unroll=1):
+ for epi_idx in cutlass.range(epi_tile_num, unroll=1):
  epi_pipeline.producer_acquire(epi_producer_state)
- cute.copy(
- tma_atom_c,
- bGS_gC[None, subtile_idx],
- bGS_sC[None, epi_producer_state.index],
- tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
- )
+ copy_C(src_idx=epi_idx, producer_state=epi_producer_state)
  # Epi pipeline's producer commit is a NOP
  epi_pipeline.producer_commit(epi_producer_state)
  epi_producer_state.advance()
@@ -1107,7 +1292,7 @@ class GemmSm100(GemmSm90):
  )
  # Partition shared/tensor memory tensor for TiledMMA_A/B/D
  # (MMA, MMA_M, MMA_K, STAGE)
- tCrA = tiled_mma.make_fragment_A(sA)
+ tCrA = tiled_mma.make_fragment_A(sA_mma)
  # (MMA, MMA_N, MMA_K, STAGE)
  tCrB = tiled_mma.make_fragment_B(sB)
  # (MMA, MMA_M, MMA_N, STAGE)
@@ -1154,10 +1339,10 @@ class GemmSm100(GemmSm90):
  tCtSFB_compact_s2t,
  ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
  else:
+ tCtSFA, tCtSFB = None, None
  tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None
  tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None
 
- k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.mma_tiler[2])
  # Persistent tile scheduling loop
  tile_scheduler = TileSchedulerCls()
  work_tile = tile_scheduler.initial_work_tile_info()
@@ -1170,6 +1355,9 @@ class GemmSm100(GemmSm90):
  while work_tile.is_valid_tile:
  # Get tile coord from tile scheduler
  tile_coord_mnkl = work_tile.tile_idx
+ batch_idx = tile_coord_mnkl[3]
+ k_len = varlen_manager.len_k(batch_idx)
+ k_tile_cnt = cute.ceil_div(k_len, self.mma_tiler[2])
  # Set tensor memory buffer for current tile
  # (MMA, MMA_M, MMA_N)
  tCtAcc = tCtAcc_base[None, None, None, acc_producer_state.index]
@@ -1184,6 +1372,9 @@ class GemmSm100(GemmSm90):
  tCtAcc,
  k_tile_cnt,
  is_leader_cta,
+ cta_rank_in_cluster,
+ tCtSFA,
+ tCtSFB,
  tiled_copy_s2t_sfa,
  tiled_copy_s2t_sfb,
  tCsSFA_compact_s2t,
@@ -1209,6 +1400,14 @@ class GemmSm100(GemmSm90):
  )
  # Bar sync for retrieve tensor memory ptr from shared memory
  tmem_alloc_barrier.arrive_and_wait()
+
+ is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
+ varlen_manager.init_tensormap_epi(
+ tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
+ )
+ tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
+ tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
+
  # Retrieving tensor memory ptr and make accumulator tensor
  acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
  self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
@@ -1221,44 +1420,22 @@ class GemmSm100(GemmSm90):
  num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
  )
 
- is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
- if const_expr(varlen_m):
- # initialize tensormap for D
- if const_expr(has_D):
- tensormap_manager.init_tensormap_from_atom(
- tma_atom_d,
- tensormap_d_ptr,
- is_manager_warp=is_tma_warp,
- )
- for tma_atom, tensormap_epi_ptr in zip(
- self.epi_get_tma_atoms(epilogue_params), tensormap_epi_ptrs
- ):
- tensormap_manager.init_tensormap_from_atom(
- tma_atom,
- tensormap_epi_ptr,
- is_manager_warp=is_tma_warp,
- )
-
  # Partition for epilogue
  epi_tidx = tidx
  tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
  epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
  )
 
- tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.d_dtype)
- tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
- tiled_copy_t2r, tTR_rD, epi_tidx, sD
+ tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.acc_dtype)
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
+ tiled_copy_t2r, self.d_layout, self.d_dtype, tTR_rD, sD, epi_tidx
  )
- tRS_rC, tSR_rC = None, None
+ tRS_rC, tSR_rC, tSR_sC = None, None, None
+ tiled_copy_s2r = None
  if const_expr(mC_mnl is not None):
- tTR_rC = cute.make_fragment_like(tTR_rD, self.c_dtype)
- tiled_copy_s2r, tSR_rC, tSR_sC = self.epilog_smem_copy_and_partition(
- tiled_copy_t2r, tTR_rC, epi_tidx, sC
+ tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
+ tiled_copy_t2r, self.c_layout, self.c_dtype, sC, tRS_rD.layout, epi_tidx
  )
- # TODO: for m major, D is being stored w STSM so we'd need LDSM here
- # tRS_rC = tSR_rC # TODO: retile?
- tRS_rC = cute.make_fragment(tRS_rD.layout, self.c_dtype)
- tSR_rC = tiled_copy_s2r.get_slice(epi_tidx).retile(tRS_rC)
 
  # Persistent tile scheduling loop
  tile_scheduler = TileSchedulerCls()
@@ -1272,42 +1449,21 @@ class GemmSm100(GemmSm90):
  )
  if const_expr(varlen_m):
  # wait tensormap initialization complete before update
- tensormap_manager.fence_tensormap_initialization()
- # batch index of last tile
- last_batch_idx = cutlass.Int32(-1)
+ varlen_manager.fence_tensormap_init()
  while work_tile.is_valid_tile:
  # Get tile coord from tile scheduler
  tile_coord_mnkl = work_tile.tile_idx
  batch_idx = tile_coord_mnkl[3]
- if const_expr(varlen_m):
- is_group_changed = batch_idx != last_batch_idx
- last_batch_idx = batch_idx
- if is_group_changed:
- self.tensormap_update_D_epi(
- tensormap_manager,
- tensormap_d_ptr,
- tensormap_epi_ptrs,
- epilogue_params,
- cu_seqlens_m,
- batch_idx,
- is_manager_warp=is_tma_warp,
- )
-
- mma_tile_coord_mnl = (
- tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
- tile_coord_mnkl[1],
- tile_coord_mnkl[3],
+ epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
+ epilogue_params, varlen_params.cu_seqlens_m, batch_idx
  )
- # Local_tile partition global tensors
- # (bM, bN)
- gD_mnl = cute.local_tile(
- mD_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
+ varlen_manager.update_tensormap_epi(
+ batch_idx,
+ self.d_layout,
+ epi_shapes,
+ epi_orders,
+ is_tma_warp,
  )
- # Partition global tensor for TiledMMA_A/B/D
- # (MMA, MMA_M, MMA_N)
- tDgD = thr_mma.partition_C(gD_mnl)
- # bSG_gD has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
- bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(tma_atom_d, tDgD, epi_tile, sD)
 
  # Set tensor memory buffer for current tile
  # (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
@@ -1316,67 +1472,59 @@ class GemmSm100(GemmSm90):
  # Wait for accumulator buffer full
  acc_pipeline.consumer_wait(acc_consumer_state)
 
- tma_desc_d_ptr, tma_desc_epi_ptrs = None, [None] * self.num_epi_tensormaps
- if const_expr(varlen_m):
- # ensure the update to tensormap has completed before using it
- if is_group_changed and is_tma_warp:
- if const_expr(has_D):
- tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
- for tensormap_epi_ptr in tensormap_epi_ptrs:
- tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
- if const_expr(has_D):
- tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
- tensormap_d_ptr, cute.AddressSpace.generic
- )
- tma_desc_epi_ptrs = [
- tensormap_manager.get_tensormap_ptr(
- tensormap_epi_ptr, cute.AddressSpace.generic
- )
- for tensormap_epi_ptr in tensormap_epi_ptrs
- ]
+ varlen_manager.fence_tensormap_update_epi(is_tma_warp)
 
- tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
- bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
-
- # Store accumulator to global memory in subtiles
- subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
- num_prev_subtiles = tile_scheduler.num_tiles_executed * subtile_cnt
- for subtile_idx in cutlass.range(subtile_cnt):
- # Load accumulator from tensor memory buffer to register
- tTR_tAcc_mn = tTR_tAcc[None, None, None, subtile_idx]
- cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
- # Convert to D type
- acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
- if const_expr(mC_mnl is not None):
- epi_pipeline.consumer_wait(epi_read_state)
- cute.copy(
- tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
- )
- # Fence to make sure shared memory read is visible to TMA load
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
- )
- cute.arch.sync_warp()
- with cute.arch.elect_one():
- epi_pipeline.consumer_release(epi_read_state)
- epi_read_state.advance()
- acc_vec = acc_vec + tRS_rC.load().to(self.acc_dtype)
- tRS_rD.store(acc_vec.to(self.d_dtype))
- # Store D to shared memory
- d_buffer = (num_prev_subtiles + subtile_idx) % self.epi_stage
- cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
- # Fence and barrier to make sure shared memory store is visible to TMA store
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+ copy_D = None
+ if const_expr(has_D):
+ copy_D, _, _ = self.epilog_gmem_copy_and_partition(
+ tma_atom_d,
+ varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
+ self.cta_tile_shape_mnk[:2],
+ epi_tile,
+ sD,
+ tile_coord_mnkl,
+ tma_desc_ptr=tma_desc_d_ptr,
  )
- epilogue_barrier.arrive_and_wait()
- # TMA store D to global memory
- if is_tma_warp:
- cute.copy(tma_atom_d, bSG_sD[None, d_buffer], bSG_gD[None, subtile_idx])
- # Fence and barrier to make sure shared memory store is visible to TMA store
- epi_store_pipeline.producer_commit()
- epi_store_pipeline.producer_acquire()
- epilogue_barrier.arrive_and_wait()
+ copy_C = None # We're using a separate warp to load C
+
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
+ k_len = varlen_manager.len_k(batch_idx)
+ load_acc_subtile = partial(
+ self.epi_load_acc_subtile,
+ tiled_copy_t2r,
+ tiled_copy_r2s,
+ tTR_tAcc,
+ tTR_rAcc,
+ clear_acc=varlen_k and k_len == 0,
+ )
+
+ epi_read_state, _ = self.epilogue(
+ epilogue_params,
+ epi_smem_tensors,
+ tma_desc_epi_ptrs,
+ epi_pipeline,
+ epi_store_pipeline,
+ epi_read_state,
+ None, # epi_producer_state
+ epi_tile,
+ load_acc_subtile,
+ tRS_rD,
+ tRS_rC,
+ tiled_copy_t2r,
+ tiled_copy_r2s,
+ tRS_sD,
+ tiled_copy_s2r,
+ tSR_rC,
+ tSR_sC,
+ copy_D,
+ copy_C,
+ tile_coord_mnkl,
+ varlen_manager,
+ epilogue_barrier,
+ tile_scheduler,
+ epi_tidx,
+ is_tma_warp,
+ )
 
  # Async arrive accumulator buffer empty
  with cute.arch.elect_one():
@@ -1404,79 +1552,50 @@ class GemmSm100(GemmSm90):
  epi_store_pipeline.producer_tail()
 
  @cute.jit
- def load_AB(
+ def load_A_gather_A(
  self,
- ab_pipeline: cutlass.pipeline.PipelineAsync,
- ab_producer_state: cutlass.pipeline.PipelineState,
- tma_atom_a: cute.CopyAtom,
- tAgA: cute.Tensor,
- tAsA: cute.Tensor,
- a_mcast_mask: cutlass.Int16,
- tma_atom_b: cute.CopyAtom,
- tBgB: cute.Tensor,
- tBsB: cute.Tensor,
- b_mcast_mask: cutlass.Int16,
- tma_atom_sfa: Optional[cute.CopyAtom] = None,
- tAgSFA: Optional[cute.Tensor] = None,
- tAsSFA: Optional[cute.Tensor] = None,
- sfa_mcast_mask: Optional[cutlass.Int16] = None,
- tma_atom_sfb: Optional[cute.CopyAtom] = None,
- tBgSFB: Optional[cute.Tensor] = None,
- tBsSFB: Optional[cute.Tensor] = None,
- sfb_mcast_mask: Optional[cutlass.Int16] = None,
- ) -> cutlass.pipeline.PipelineState:
- blockscaled = const_expr(tma_atom_sfa is not None)
- if const_expr(blockscaled):
- assert all(x is not None for x in (tma_atom_sfa, tAgSFA, tAsSFA))
- assert all(x is not None for x in (tma_atom_sfb, tBgSFB, tBsSFB))
- k_tile_cnt = cute.size(tAgA, mode=[1])
+ a_pipeline: cutlass.pipeline.PipelineAsync,
+ a_producer_state: cutlass.pipeline.PipelineState,
+ a_prefetch_consumer_state: Optional[cutlass.pipeline.PipelineState],
+ copy_A: Callable,
+ prefetch_A: Optional[Callable],
+ k_tile_cnt: Int32,
+ ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]:
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
- peek_ab_empty_status = Boolean(True)
+ peek_a_empty_status = Boolean(True)
  if 0 < k_tile_cnt:
- peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+ peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
  # /////////////////////////////////////////////////////////////////////////
- # TMA load
+ # cp.async on A
  # /////////////////////////////////////////////////////////////////////////
- for k_tile in cutlass.range(k_tile_cnt, unroll=1):
- # Wait for A/B buffers to be empty before loading into them
- # Also sets the transaction barrier for the A/B buffers
- ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
- cute.copy(
- tma_atom_a,
- tAgA[None, k_tile],
- tAsA[None, ab_producer_state.index],
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
- mcast_mask=a_mcast_mask,
- )
- cute.copy(
- tma_atom_b,
- tBgB[None, k_tile],
- tBsB[None, ab_producer_state.index],
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
- mcast_mask=b_mcast_mask,
- )
- if const_expr(blockscaled):
- cute.copy(
- tma_atom_sfa,
- tAgSFA[None, ab_producer_state.count],
- tAsSFA[None, ab_producer_state.index],
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
- mcast_mask=sfa_mcast_mask,
- )
- cute.copy(
- tma_atom_sfb,
- tBgSFB[None, ab_producer_state.count],
- tBsSFB[None, ab_producer_state.index],
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
- mcast_mask=sfb_mcast_mask,
- )
- # Mainloop pipeline's producer commit is a NOP
- ab_pipeline.producer_commit(ab_producer_state)
- ab_producer_state.advance()
- peek_ab_empty_status = Boolean(True)
+ is_tma_warp = False
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
+ smem_idx = a_producer_state.index
+ prefetch_out = ()
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+ prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),)
+ a_prefetch_consumer_state.advance()
+ a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp)
+ copy_A(k_tile, smem_idx, *prefetch_out)
+ # This tells mbarrier to track the completion of cp.async
+ a_pipeline.producer_cpasync_commit(a_producer_state)
+ a_producer_state.advance()
+ peek_a_empty_status = Boolean(True)
  if k_tile + 1 < k_tile_cnt:
- peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
- return ab_producer_state
+ peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
+ # bound checking in the K dimension on the last k_tile
+ if 0 < k_tile_cnt:
+ k_tile = k_tile_cnt - 1
+ smem_idx = a_producer_state.index
+ prefetch_out = ()
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+ prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True),)
+ a_prefetch_consumer_state.advance()
+ a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp)
+ copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
+ a_pipeline.producer_cpasync_commit(a_producer_state)
+ a_producer_state.advance()
+ return a_producer_state, a_prefetch_consumer_state
 
  @cute.jit
  def mma(
@@ -1491,6 +1610,9 @@ class GemmSm100(GemmSm90):
  acc: cute.Tensor,
  k_tile_cnt: Int32,
  is_leader_cta: Boolean,
+ cta_rank_in_cluster: Int32,
+ tCtSFA: Optional[cute.Tensor] = None,
+ tCtSFB: Optional[cute.Tensor] = None,
  tiled_copy_s2t_sfa: Optional[cute.TiledCopy] = None,
  tiled_copy_s2t_sfb: Optional[cute.TiledCopy] = None,
  tCsSFA_compact_s2t: Optional[cute.Tensor] = None,
@@ -1500,12 +1622,17 @@ class GemmSm100(GemmSm90):
  ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma]:
  blockscaled = const_expr(tiled_copy_s2t_sfa is not None)
  if const_expr(blockscaled):
+ assert all(x is not None for x in (tCtSFA, tCtSFB))
  assert all(x is not None for x in (tiled_copy_s2t_sfa, tiled_copy_s2t_sfb))
  assert all(x is not None for x in (tCsSFA_compact_s2t, tCsSFB_compact_s2t))
  assert all(x is not None for x in (tCtSFA_compact_s2t, tCtSFB_compact_s2t))
+ # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will
+ # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader
+ # CTA will wait for that then arrive at the mbarrier on the leader CTA.
+ need_nonleader_cta = const_expr(self.gather_A and self.use_2cta_instrs)
  # Peek (try_wait) AB buffer full for k_tile = 0
  peek_ab_full_status = Boolean(True)
- if 0 < k_tile_cnt and is_leader_cta:
+ if 0 < k_tile_cnt and (is_leader_cta or need_nonleader_cta):
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
  # Wait for accumulator buffer empty
  if is_leader_cta:
@@ -1515,6 +1642,14 @@ class GemmSm100(GemmSm90):
  # Mma mainloop
  num_k_blocks = cute.size(tCrA, mode=[2])
  for k_tile in cutlass.range(k_tile_cnt, unroll=1):
+ if const_expr(need_nonleader_cta):
+ if not is_leader_cta:
+ ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
+ with cute.arch.elect_one():
+ # The odd CTA signals the even CTA
+ ab_pipeline.sync_object_full.arrive_mbarrier(
+ ab_consumer_state.index, dst_rank=cta_rank_in_cluster & 0xFE
+ )
  if is_leader_cta:
  # Conditionally wait for AB buffer full
  ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
@@ -1527,6 +1662,11 @@ class GemmSm100(GemmSm90):
  cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t)
  for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
  k_blk_coord = (None, None, k_blk_idx, ab_consumer_state.index)
+ if const_expr(blockscaled):
+ # Set SFA/SFB tensor to tiled_mma
+ sf_kblock_coord = (None, None, k_blk_idx)
+ tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
+ tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator)
  cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
  tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
  # Async arrive AB buffer empty
@@ -1534,7 +1674,7 @@ class GemmSm100(GemmSm90):
  ab_consumer_state.advance()
  # Peek (try_wait) AB buffer full for k_tile = k_tile + 1
  peek_ab_full_status = Boolean(True)
- if k_tile + 1 < k_tile_cnt and is_leader_cta:
+ if k_tile + 1 < k_tile_cnt and (is_leader_cta or need_nonleader_cta):
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
  # Async arrive accumulator buffer full
  if is_leader_cta:
@@ -1544,6 +1684,25 @@ class GemmSm100(GemmSm90):
  # "operand #0 does not dominate this use"
  return ab_consumer_state, acc_producer_state, tiled_mma

+ @cute.jit
+ def epi_load_acc_subtile(
+ self,
+ tiled_copy_t2r: cute.TiledCopy,
+ tiled_copy_r2s: cute.TiledCopy,
+ tTR_tAcc: cute.Tensor,
+ tTR_rAcc: cute.Tensor,
+ tRS_rD: cute.Tensor,
+ epi_idx: int,
+ clear_acc: Boolean = False,
+ ):
+ if not clear_acc:
+ # Load accumulator from tensor memory buffer to register
+ cute.copy(tiled_copy_t2r, tTR_tAcc[None, None, None, epi_idx], tTR_rAcc)
+ tRS_rAcc = tiled_copy_r2s.retile(tTR_rAcc)
+ tRS_rD.store(tRS_rAcc.load())
+ else:
+ tRS_rD.fill(0.0)
+
  def mainloop_s2t_copy_and_partition(
  self,
  sSF: cute.Tensor,
@@ -1607,8 +1766,8 @@ class GemmSm100(GemmSm90):
  # Make tiledCopy for tensor memory load
  copy_atom_t2r = sm100_utils.get_tmem_load_op(
  self.cta_tile_shape_mnk,
- self.d_layout,
- self.d_dtype,
+ self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR,
+ self.d_dtype if self.d_dtype is not None else cutlass.BFloat16,
  self.acc_dtype,
  epi_tile,
  use_2cta_instrs,
@@ -1631,12 +1790,14 @@ class GemmSm100(GemmSm90):
  tTR_rAcc = cute.make_fragment(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
  return tiled_copy_t2r, tTR_tAcc, tTR_rAcc

- def epilog_smem_copy_and_partition(
+ def epilog_smem_store_and_partition(
  self,
  tiled_copy_t2r: cute.TiledCopy,
+ d_layout: Optional[LayoutEnum],
+ dtype: Optional[Type[cutlass.Numeric]],
  tTR_rD: cute.Tensor,
- tidx: Int32,
  sD: cute.Tensor,
+ tidx: Int32,
  ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
  """
  Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
@@ -1658,83 +1819,106 @@ class GemmSm100(GemmSm90):
  :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
  """
  copy_atom_r2s = sm100_utils.get_smem_store_op(
- self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r
+ d_layout if d_layout is not None else LayoutEnum.ROW_MAJOR,
+ dtype if dtype is not None else cutlass.BFloat16,
+ self.acc_dtype,
+ tiled_copy_t2r,
  )
  tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
  # (R2S, R2S_M, R2S_N, PIPE_D)
  thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
- tRS_sD = thr_copy_r2s.partition_D(sD)
+ tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
  # (R2S, R2S_M, R2S_N)
  tRS_rD = tiled_copy_r2s.retile(tTR_rD)
  return tiled_copy_r2s, tRS_rD, tRS_sD

- # def epilog_smem_load_copy_and_partition(
- # self,
- # tiled_copy_t2r: cute.TiledCopy,
- # tTR_rC: cute.Tensor,
- # tidx: Int32,
- # sC: cute.Tensor,
- # ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
- # copy_atom_s2r = cute.make_copy_atom(
- # warp.LdMatrix8x8x16bOp(self.c_layout.is_m_major_c(), num_matrices=4),
- # self.c_dtype, # TODO: this probably only works for f16 for now?
- # )
- # copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
- # tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
- # # (R2S, R2S_M, R2S_N, PIPE_D)
- # thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
- # # (R2S, R2S_M, R2S_N)
- # tSR_sC = thr_copy_s2r.partition_S(sC)
- # return tiled_copy_s2r, tSR_sC
-
- def epilog_gmem_copy_and_partition(
+ def epilog_smem_load_and_partition(
  self,
- atom: Union[cute.CopyAtom, cute.TiledCopy],
- gD_mnl: cute.Tensor,
- epi_tile: cute.Tile,
- sD: cute.Tensor,
- ) -> Tuple[cute.Tensor, cute.Tensor]:
- """Make tiledCopy for global memory store, then use it to:
- - partition register array (source) and global memory (destination) for none TMA store version;
- - partition shared memory (source) and global memory (destination) for TMA store version.
-
- :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version
- :type atom: cute.CopyAtom or cute.TiledCopy
- :param gD_mnl: The global tensor C
- :type gD_mnl: cute.Tensor
- :param epi_tile: The epilogue tiler
- :type epi_tile: cute.Tile
- :param sD: The shared memory tensor to be copied and partitioned
- :type sD: cute.Tensor
+ tiled_copy_t2r: cute.TiledCopy,
+ c_layout: LayoutEnum,
+ dtype: Type[cutlass.Numeric],
+ # tTR_rC: cute.Tensor,
+ sC: cute.Tensor,
+ tRS_rD_layout: cutlass.Layout,
+ tidx: Int32,
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+ copy_atom_r2s = sm100_utils.get_smem_store_op(
+ c_layout, dtype, self.acc_dtype, tiled_copy_t2r
+ )
+ store_op = copy_atom_r2s.op
+ # m8n8 16-bit path
+ if isinstance(store_op, StMatrix8x8x16bOp):
+ op = LdMatrix8x8x16bOp(num_matrices=store_op.num_matrices, transpose=store_op.transpose)
+ # m16n8 8-bit store -> m16n16 8-bit load
+ elif isinstance(store_op, StMatrix16x8x8bOp) and store_op.num_matrices in [2, 4]:
+ # transpose=True is enforced by the class
+ op = LdMatrix16x16x8bOp(num_matrices=store_op.num_matrices // 2)
+ else:
+ op = cute.nvgpu.CopyUniversalOp()
+ copy_atom_s2r = cute.make_copy_atom(op, dtype)
+ tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
+ # (R2S, R2S_M, R2S_N, PIPE_D)
+ tSR_sC = thr_copy_s2r.partition_S(sC)
+ tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
+ # (R2S, R2S_M, R2S_N)
+ tSR_rC = tiled_copy_s2r.retile(tRS_rC)
+ return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC

- :return: A tuple containing either:
- - For TMA store: (tma_atom_d, bSG_sD, bSG_gD) where:
- - tma_atom_d: The TMA copy atom
- - bSG_sD: The partitioned shared memory tensor C
- - bSG_gD: The partitioned global tensor C
- - For non-TMA store: (simt_atom, tTR_rD, tTR_gD) where:
- - simt_atom: The SIMT copy atom
- - tTR_rD: The register tensor C
- - tTR_gD: The partitioned global tensor C
- :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
- """
- # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
- gD_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0)], epi_tile)
- sD_for_tma_partition = cute.group_modes(sD, 0, 2)
- gD_for_tma_partition = cute.group_modes(gD_epi, 0, 2)
- # ((ATOM_V, REST_V), EPI_M, EPI_N)
- bSG_sD, bSG_gD = cpasync.tma_partition(
- atom,
- 0,
- cute.make_layout(1),
- sD_for_tma_partition,
- gD_for_tma_partition,
+ @cute.jit
+ def make_ab_pipeline(
+ self,
+ tiled_mma: cute.TiledMma,
+ cluster_layout_vmnk: cute.Layout,
+ ab_pipeline_mbar_ptr: cute.Pointer,
+ is_leader_cta: Boolean,
+ ) -> pipeline.PipelineAsync:
+ # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will
+ # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader
+ # CTA will wait for that then arrive at the mbarrier on the leader CTA.
+ # The producer count for the leader CTA is 1 (TMA) + num_cpasync_threads
+ # + 1 (from non-leader CTA).
+ # The producer count for the non-leader CTA is num_cpasync_threads
+ # (TMA doesn't arrive there).
+ if const_expr(not self.gather_A):
+ producer_cnt = 1
+ else:
+ producer_cnt = (self.num_ab_load_warps - 1) * 32 + (
+ 1 if const_expr(not self.use_2cta_instrs) else 2
+ )
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
+ # Each consumer warp contributes its mcast size to the arrive count
+ mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
+ consumer_arrive_cnt = mcast_size
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, consumer_arrive_cnt
  )
- return bSG_sD, bSG_gD
+ if const_expr(not self.gather_A):
+ pipeline_ab = pipeline.PipelineTmaUmma.create(
+ barrier_storage=ab_pipeline_mbar_ptr,
+ num_stages=self.ab_stage,
+ producer_group=ab_pipeline_producer_group,
+ consumer_group=ab_pipeline_consumer_group,
+ tx_count=self.num_tma_load_bytes,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ )
+ else:
+ pipeline_ab = PipelineTmaCpAsyncUmma.create(
+ barrier_storage=ab_pipeline_mbar_ptr,
+ num_stages=self.ab_stage,
+ producer_group=ab_pipeline_producer_group,
+ consumer_group=ab_pipeline_consumer_group,
+ tx_count=self.num_tma_load_bytes,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ producer_drop_count=None
+ if not self.use_2cta_instrs
+ else (2 if not is_leader_cta else 0),
+ )
+ return pipeline_ab
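
To make the producer counts in the comment above concrete, a hedged arithmetic check; the warp count below is an assumption for illustration only, not a value taken from this kernel:

num_ab_load_warps = 5                          # assumed: 1 TMA warp + 4 cp.async gather warps
num_cpasync_threads = (num_ab_load_warps - 1) * 32
producer_cnt_1cta = num_cpasync_threads + 1    # cp.async threads + the TMA arrive
producer_cnt_2cta = num_cpasync_threads + 2    # plus the arrive forwarded from the peer CTA
assert (producer_cnt_1cta, producer_cnt_2cta) == (129, 130)
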
  def make_acc_pipeline(
  self, cluster_layout_vmnk: cute.Layout, acc_pipeline_mbar_ptr: cute.Pointer
- ):
+ ) -> pipeline.PipelineAsync:
  acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
  num_acc_consumer_threads = self.num_epi_warps * (2 if self.use_2cta_instrs else 1)
  acc_pipeline_consumer_group = pipeline.CooperativeGroup(
@@ -1748,19 +1932,70 @@ class GemmSm100(GemmSm90):
  cta_layout_vmnk=cluster_layout_vmnk,
  )

- @staticmethod
+ def make_sched_pipeline(
+ self,
+ cluster_layout_mnk: cute.Layout,
+ sched_pipeline_mbar_ptr: cute.Pointer,
+ has_C: bool = False,
+ ) -> pipeline.PipelineAsync:
+ # Threads/warps participating in this pipeline
+ sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ cluster_size = cute.size(cluster_layout_mnk)
+ # Each warp that is not the scheduler warp will contribute 1 to the arrive count
+ warps_per_cta = self.num_ab_load_warps + len(
+ (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id)
+ )
+ if has_C:
+ warps_per_cta += 1
+ consumer_arrive_cnt = warps_per_cta * cluster_size - 1
+ sched_pipeline_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, consumer_arrive_cnt
+ )
+ return pipeline.PipelineAsync.create(
+ barrier_storage=sched_pipeline_mbar_ptr,
+ num_stages=self.sched_stage,
+ producer_group=sched_pipeline_producer_group,
+ consumer_group=sched_pipeline_consumer_group,
+ # If there's a cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
+ consumer_mask=None if const_expr(cluster_size == 1) else 0,
+ )
+
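
A hedged check of the consumer arrive count above: every warp in every CTA of the cluster arrives except the scheduler warp that produces the tile. The per-CTA warp counts below are assumptions for illustration:

num_ab_load_warps, num_epilog_warps = 1, 4                  # assumed per-CTA warp counts
warps_per_cta = num_ab_load_warps + num_epilog_warps + 2    # + mma warp + scheduler warp
cluster_size = 2
consumer_arrive_cnt = warps_per_cta * cluster_size - 1      # all warps except the producing scheduler warp
assert consumer_arrive_cnt == 13
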
+ @cute.jit
+ def make_a_prefetch_pipeline(
+ self, a_prefetch_pipeline_mbar_ptr: cute.Pointer
+ ) -> pipeline.PipelineAsync:
+ producer_cnt = 32
+ a_prefetch_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, producer_cnt, alignment=producer_cnt
+ )
+ consumer_arrive_cnt = self.num_ab_load_warps - 1
+ a_prefetch_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, consumer_arrive_cnt
+ )
+ return pipeline.PipelineCpAsync.create(
+ barrier_storage=a_prefetch_pipeline_mbar_ptr,
+ num_stages=self.a_prefetch_stage,
+ producer_group=a_prefetch_producer_group,
+ consumer_group=a_prefetch_consumer_group,
+ )
+
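
For the prefetch pipeline above, one full warp of cp.async threads produces and the remaining A/B load warps consume; a small sketch with an assumed warp count:

num_ab_load_warps = 5                          # assumed, as in the earlier sketch
producer_cnt = 32                              # one warp of cp.async prefetch threads
consumer_arrive_cnt = num_ab_load_warps - 1    # the gather warps that read the prefetched indices
assert (producer_cnt, consumer_arrive_cnt) == (32, 4)
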
+ @classmethod
  def _compute_stages(
+ cls,
  tiled_mma: cute.TiledMma,
  mma_tiler_mnk: Tuple[int, int, int],
+ cta_tile_shape_mnk: Tuple[int, int, int],
+ epi_tile: cute.Tile,
  a_dtype: Type[cutlass.Numeric],
  b_dtype: Type[cutlass.Numeric],
- epi_tile: cute.Tile,
- d_dtype: Type[cutlass.Numeric],
- c_dtype: Optional[Type[cutlass.Numeric]],
- d_layout: LayoutEnum,
- c_layout: Optional[LayoutEnum],
  sf_dtype: Optional[Type[cutlass.Numeric]],
  sf_vec_size: Optional[int],
+ d_dtype: Optional[Type[cutlass.Numeric]],
+ c_dtype: Optional[Type[cutlass.Numeric]],
+ d_layout: Optional[LayoutEnum],
+ c_layout: Optional[LayoutEnum],
+ epilogue_args: EpilogueArguments,
+ prefetch_A_idx: Literal[None, "varlen_m", "varlen_k"],
  smem_capacity: int,
  occupancy: int,
  ) -> Tuple[int, int, int]:
@@ -1778,7 +2013,7 @@ class GemmSm100(GemmSm90):
  :type epi_tile: cute.Tile
  :param d_dtype: Data type of operand C (output).
  :type d_dtype: type[cutlass.Numeric]
- :param d_layout: Layout enum of operand C.
+ :param d_layout: Layout enum of operand D.
  :type d_layout: LayoutEnum
  :param smem_capacity: Total available shared memory capacity in bytes.
  :type smem_capacity: int
@@ -1797,8 +2032,8 @@ class GemmSm100(GemmSm90):
  num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2

  # Default D stages
- epi_stage = 2
- epi_c_stage = 2 if c_dtype is not None else 0
+ epi_stage = 4 if cute.size(epi_tile[1]) <= 16 else 2
+ epi_c_stage = 0 if c_dtype is None else (4 if cute.size(epi_tile[1]) <= 16 else 2)

  # Calculate smem layout and size for one stage of A, B, and C
  a_smem_layout_staged_one = sm100_utils.make_smem_layout_a(
@@ -1813,7 +2048,11 @@ class GemmSm100(GemmSm90):
  b_dtype,
  1, # a tmp 1 stage is provided
  )
- d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
+ d_smem_layout_staged_one = (
+ sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
+ if d_dtype is not None
+ else None
+ )
  c_smem_layout_staged_one = (
  sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1)
  if c_dtype is not None
@@ -1836,13 +2075,22 @@ class GemmSm100(GemmSm90):
  ab_bytes_per_stage = cute.size_in_bytes(
  a_dtype, a_smem_layout_staged_one
  ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
+ if const_expr(prefetch_A_idx == "varlen_k"): # Need smem to prefetch A indices
+ ab_bytes_per_stage += Int32.width // 8 * cta_tile_shape_mnk[2]
  if const_expr(blockscaled):
  ab_bytes_per_stage += cute.size_in_bytes(
  sf_dtype, sfa_smem_layout_staged_one
  ) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
  mbar_helpers_bytes = 1024
- d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
- epi_bytes = d_bytes_per_stage * epi_stage
+ if const_expr(prefetch_A_idx == "varlen_m"):
+ mbar_helpers_bytes += Int32.width // 8 * cta_tile_shape_mnk[0] * 2
+ d_bytes_per_stage = (
+ cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) if d_dtype is not None else 0
+ )
+ epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
+ epilogue_args, cta_tile_shape_mnk, epi_tile
+ )
+ epi_bytes = epi_bytes_per_stage * epi_stage
  if const_expr(c_dtype is not None):
  c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
  epi_bytes += c_bytes_per_stage * epi_c_stage
@@ -1851,18 +2099,13 @@ class GemmSm100(GemmSm90):
  # Start with total smem per CTA (capacity / occupancy)
  # Subtract reserved bytes and initial C stages bytes
  # Divide remaining by bytes needed per A/B/SFA/SFB stage
- ab_stage = (
- smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)
- ) // ab_bytes_per_stage
+ remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
+ ab_stage = remaining_bytes // ab_bytes_per_stage

  # Refine epilogue stages:
  # Calculate remaining smem after allocating for A/B stages and reserved bytes
  # Add remaining unused smem to epilogue
- epi_stage += (
- smem_capacity
- - occupancy * ab_bytes_per_stage * ab_stage
- - occupancy * (mbar_helpers_bytes + epi_bytes)
- ) // (occupancy * d_bytes_per_stage)
+ epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // (epi_bytes_per_stage)
  return num_acc_stage, ab_stage, epi_stage, epi_c_stage

  @staticmethod
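
To make the stage budgeting above concrete, a worked example with assumed byte counts; none of these numbers are taken from the kernel:

smem_capacity, occupancy = 232448, 1      # assumed ~227 KiB usable smem per SM, 1 CTA per SM
mbar_helpers_bytes = 1024
epi_bytes_per_stage, epi_stage = 8192, 2  # assumed bytes per D stage, default epilogue stage count
epi_bytes = epi_bytes_per_stage * epi_stage
ab_bytes_per_stage = 32768                # assumed bytes for one A tile + one B tile

remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
ab_stage = remaining_bytes // ab_bytes_per_stage                                       # -> 6 A/B stages
epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage  # leftover -> 2 extra
assert (ab_stage, epi_stage) == (6, 4)
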
@@ -1891,9 +2134,12 @@ class GemmSm100(GemmSm90):

  @staticmethod
  def is_valid_dtypes(
- ab_dtype: Type[cutlass.Numeric],
+ a_dtype: Type[cutlass.Numeric],
+ b_dtype: Type[cutlass.Numeric],
  acc_dtype: Type[cutlass.Numeric],
- d_dtype: Type[cutlass.Numeric],
+ d_dtype: Optional[Type[cutlass.Numeric]],
+ a_major: str,
+ b_major: str,
  ) -> bool:
  """
  Check if the dtypes are valid
@@ -1909,6 +2155,9 @@ class GemmSm100(GemmSm90):
  :rtype: bool
  """
  is_valid = True
+ if b_dtype != a_dtype:
+ is_valid = False
+ ab_dtype = a_dtype
  if ab_dtype not in {
  cutlass.Float16,
  cutlass.BFloat16,
@@ -1927,7 +2176,7 @@ class GemmSm100(GemmSm90):
  and ab_dtype not in {cutlass.Uint8, cutlass.Int8}
  ):
  is_valid = False
- if (
+ if d_dtype is not None and (
  acc_dtype == Float32
  and d_dtype
  not in {
@@ -1958,6 +2207,8 @@ class GemmSm100(GemmSm90):
  }
  ):
  is_valid = False
+ if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
+ is_valid = False
  return is_valid

  @staticmethod
@@ -2014,34 +2265,6 @@ class GemmSm100(GemmSm90):

  return is_valid

- @staticmethod
- def is_valid_layouts(
- ab_dtype: Type[cutlass.Numeric],
- a_major: str,
- b_major: str,
- ) -> bool:
- """
- Check if the dtypes and sf_vec_size are valid combinations
-
- :param ab_dtype: The data type of the A and B operands
- :type ab_dtype: Type[cutlass.Numeric]
- :param d_dtype: The data type of the output tensor
- :type d_dtype: Type[cutlass.Numeric]
- :param a_major: The major dimension of the A tensor
- :type a_major: str
- :param b_major: The major dimension of the B tensor
- :type b_major: str
- :param d_major: The major dimension of the C tensor
- :type d_major: str
-
- :return: True if the layouts are valid, False otherwise
- :rtype: bool
- """
- is_valid = True
- if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
- is_valid = False
- return is_valid
-
  @staticmethod
  def is_valid_mma_tiler_and_cluster_shape(
  mma_tiler_mn: Tuple[int, int],
@@ -2187,7 +2410,7 @@ class GemmSm100(GemmSm90):
  """
  can_implement = True
  # Skip unsupported types
- if not GemmSm100.is_valid_dtypes(ab_dtype, acc_dtype, d_dtype):
+ if not GemmSm100.is_valid_dtypes(ab_dtype, ab_dtype, acc_dtype, d_dtype, a_major, b_major):
  can_implement = False
  # Skip invalid mma tile shape and cluster shape
  if not GemmSm100.is_valid_mma_tiler_and_cluster_shape(
@@ -2362,7 +2585,7 @@ def run(

  # Configure gemm kernel
  cluster_shape_mnk = (*cluster_shape_mn, 1)
- gemm = GemmSm100(acc_dtype, mma_tiler_mn, cluster_shape_mnk)
+ gemm = GemmSm100(acc_dtype, ab_dtype, mma_tiler_mn, cluster_shape_mnk)

  # Compute max active clusters on current device
  hardware_info = cutlass.utils.HardwareInfo()