quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.4.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
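Several modules are renamed in this release (for example `quack/{dense_gemm_sm90.py → gemm_sm90.py}` and `quack/{gemm_act_sm90.py → gemm_act.py}`), so import paths differ between 0.2.2 and 0.2.4. A minimal, hypothetical compatibility sketch, assuming only what the diff shows (the `GemmSm90` class lives in the renamed `gemm_sm90` module in 0.2.4):

    # Hedged sketch: import GemmSm90 across the 0.2.2 -> 0.2.4 module rename.
    # Assumes the class name is unchanged; not an official compatibility shim.
    try:
        from quack.gemm_sm90 import GemmSm90        # quack-kernels >= 0.2.4
    except ImportError:
        from quack.dense_gemm_sm90 import GemmSm90  # quack-kernels == 0.2.2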
@@ -3,11 +3,9 @@
3
3
 
4
4
  import enum
5
5
  from typing import Tuple, Type, Callable, Optional, Union, Literal
6
- from dataclasses import dataclass
7
6
  from functools import partial
8
7
  import math
9
8
 
10
- from torch import Tensor
11
9
 
12
10
  import cuda.bindings.driver as cuda
13
11
 
@@ -16,10 +14,9 @@ import cutlass.cute as cute
16
14
  import cutlass.pipeline as pipeline
17
15
  from cutlass.cute.nvgpu import cpasync, warp, warpgroup
18
16
  import cutlass.utils.hopper_helpers as sm90_utils
19
- from cutlass import Int32, Float32, Boolean, const_expr
17
+ from cutlass import Int32, Float32, Float16, Boolean, const_expr
18
+ from cutlass.cutlass_dsl import if_generate
20
19
  from cutlass.utils import LayoutEnum
21
- import cutlass.torch as cutlass_torch
22
- from cutlass.cute.runtime import make_ptr
23
20
 
24
21
 
25
22
  from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
@@ -30,14 +27,12 @@ from quack.tile_scheduler import (
30
27
  VarlenMTileSchedulerArguments,
31
28
  VarlenMTileScheduler,
32
29
  )
33
- from quack.varlen_utils import VarlenArguments
34
- from quack.tensormap_manager import TensorMapManagerSm90
30
+ from quack.varlen_utils import VarlenArguments, VarlenManager
35
31
 
36
32
  # return PipelineStateWAdvance instead of PipelineState
37
33
  from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
38
- import quack.utils as utils
39
- from quack.cute_dsl_utils import get_max_active_clusters
40
- from quack.gemm_wrapper_utils import GemmWrapperBase
34
+ import quack.copy_utils as copy_utils
35
+ import quack.sm90_utils as quack_sm90_utils
41
36
 
42
37
  """
43
38
  A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -119,7 +114,7 @@ class GemmSm90:
119
114
 
120
115
  Example:
121
116
  >>> gemm = GemmSm90(
122
- ... acc_dtype=cutlass.Float32,
117
+ ... acc_dtype=Float32,
123
118
  ... tile_shape_mn=(128, 256),
124
119
  ... cluster_shape_mnk=(1, 1, 1)
125
120
  ... )
@@ -127,19 +122,10 @@ class GemmSm90:
127
122
  """
128
123
 
129
124
  arch = 90
130
- bytes_per_tensormap = 128
131
125
  num_epi_tensormaps: int = 0
132
126
 
133
- @dataclass
134
- class EpilogueArguments(ArgumentsBase):
135
- alpha: Optional[Float32 | cute.Tensor] = None
136
- beta: Optional[Float32 | cute.Tensor] = None
137
- add_to_output: bool = False
138
-
139
- @dataclass
140
- class EpilogueParams(ParamsBase):
141
- alpha: Optional[Float32 | cute.Tensor] = None
142
- beta: Optional[Float32 | cute.Tensor] = None
127
+ EpilogueArguments = ArgumentsBase
128
+ EpilogueParams = ParamsBase
143
129
 
144
130
  def __init__(
145
131
  self,
@@ -222,7 +208,9 @@ class GemmSm90:
222
208
  atom_layout_m, atom_layout_n = 1, 1
223
209
  self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
224
210
 
225
- self.num_mcast_ctas_a = self.cluster_shape_mnk[1] if not self.gather_A else 1
211
+ self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
212
+ if self.gather_A:
213
+ assert self.num_mcast_ctas_a == 1
226
214
  self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
227
215
  self.is_a_mcast = self.num_mcast_ctas_a > 1
228
216
  self.is_b_mcast = self.num_mcast_ctas_b > 1
@@ -237,10 +225,9 @@ class GemmSm90:
237
225
  self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
238
226
  self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
239
227
  self.num_ab_load_warps = 1 if not self.gather_A else 4
240
- self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
241
- self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
242
228
  self.ab_load_warp_id = self.mma_warp_groups * 4
243
- self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
229
+ # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
230
+ # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
244
231
 
245
232
  regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
246
233
  math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
@@ -335,7 +322,7 @@ class GemmSm90:
335
322
  self.d_dtype,
336
323
  self.c_dtype,
337
324
  epilogue_args,
338
- self.smem_capacity,
325
+ cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
339
326
  self.occupancy,
340
327
  # epi_smem will reuse smem ab if not persistent.
341
328
  overlap_sD_sA=not self.is_persistent,
@@ -465,6 +452,7 @@ class GemmSm90:
465
452
  )
466
453
 
467
454
  epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
455
+ varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
468
456
 
469
457
  TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
470
458
  tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
@@ -483,7 +471,7 @@ class GemmSm90:
483
471
  ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
484
472
  epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
485
473
  sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
486
- tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
474
+ tile_count: cute.struct.MemRange[Int32, self.sched_stage]
487
475
  sD: cute.struct.Align[
488
476
  cute.struct.MemRange[
489
477
  self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
@@ -520,10 +508,7 @@ class GemmSm90:
520
508
  tma_atom_c,
521
509
  tma_tensor_c,
522
510
  epilogue_params,
523
- varlen_args.mCuSeqlensM,
524
- varlen_args.mCuSeqlensK,
525
- varlen_args.mTensormaps,
526
- varlen_args.mAIdx,
511
+ varlen_params,
527
512
  self.cluster_layout_mnk,
528
513
  self.a_smem_layout_staged,
529
514
  self.b_smem_layout_staged,
@@ -535,7 +520,6 @@ class GemmSm90:
535
520
  grid=grid,
536
521
  block=[self.threads_per_cta, 1, 1],
537
522
  cluster=self.cluster_shape_mnk,
538
- smem=self.shared_storage.size_in_bytes(),
539
523
  stream=stream,
540
524
  min_blocks_per_mp=1,
541
525
  )
@@ -555,10 +539,7 @@ class GemmSm90:
555
539
  tma_atom_c: Optional[cute.CopyAtom],
556
540
  mC_mnl: Optional[cute.Tensor],
557
541
  epilogue_params: ParamsBase,
558
- cu_seqlens_m: Optional[cute.Tensor],
559
- cu_seqlens_k: Optional[cute.Tensor],
560
- tensormaps: Optional[cute.Tensor],
561
- mAIdx: Optional[cute.Tensor],
542
+ varlen_params: VarlenManager.Params,
562
543
  cluster_layout_mnk: cute.Layout,
563
544
  a_smem_layout: cute.ComposedLayout,
564
545
  b_smem_layout: cute.ComposedLayout,
@@ -594,8 +575,8 @@ class GemmSm90:
594
575
  :type epi_smem_layout: cute.ComposedLayout
595
576
  """
596
577
 
597
- varlen_m = const_expr(cu_seqlens_m is not None)
598
- varlen_k = const_expr(cu_seqlens_k is not None)
578
+ varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
579
+ varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
599
580
  assert not (varlen_m and varlen_k)
600
581
  if const_expr(self.gather_A):
601
582
  assert varlen_m or varlen_k
@@ -657,9 +638,19 @@ class GemmSm90:
657
638
  sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
658
639
  epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
659
640
 
660
- # Get tensormap buffer address
661
- tensormap_manager, tensormap_ab_ptrs, tensormap_d_ptr, tensormap_epi_ptrs = (
662
- self.tensormap_init(tensormaps, varlen_m, varlen_k, has_D, warp_idx)
641
+ varlen_manager = VarlenManager.create(
642
+ varlen_params,
643
+ has_D,
644
+ self.num_epi_tensormaps,
645
+ # Only used if not varlen_m
646
+ len_m_static=Int32(
647
+ mA_mkl.shape[0]
648
+ if varlen_k or varlen_params.mAIdx is None
649
+ else varlen_params.mAIdx.shape[0]
650
+ ),
651
+ len_k_static=Int32(mA_mkl.shape[1]),
652
+ pingpong=self.pingpong,
653
+ warp_idx=warp_idx,
663
654
  )
664
655
 
665
656
  TileSchedulerCls = partial(
@@ -673,29 +664,20 @@ class GemmSm90:
673
664
  and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
674
665
  ):
675
666
  is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
676
- if const_expr(varlen_k):
677
- # initialize tensormap for A & B
678
- if const_expr(not self.gather_A):
679
- tensormap_manager.init_tensormap_from_atom(
680
- tma_atom_a,
681
- tensormap_ab_ptrs[0],
682
- is_tma_warp,
683
- )
684
- tensormap_manager.init_tensormap_from_atom(
685
- tma_atom_b,
686
- tensormap_ab_ptrs[1],
687
- is_tma_warp,
688
- )
667
+ # initialize tensormap for A & B
668
+ varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
669
+ tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
670
+ tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
689
671
  # ///////////////////////////////////////////////////////////////////////////////
690
672
  # Get mcast mask
691
673
  # ///////////////////////////////////////////////////////////////////////////////
692
674
  cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
693
- cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
675
+ block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
694
676
  a_mcast_mask = cute.make_layout_image_mask(
695
- cluster_layout_mnk, cluster_coord_mnk, mode=1
677
+ cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1
696
678
  )
697
679
  b_mcast_mask = cute.make_layout_image_mask(
698
- cluster_layout_mnk, cluster_coord_mnk, mode=0
680
+ cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0
699
681
  )
700
682
  a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
701
683
  b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
@@ -711,42 +693,30 @@ class GemmSm90:
711
693
  )
712
694
  if const_expr(varlen_k):
713
695
  # wait tensormap initialization complete before update
714
- tensormap_manager.fence_tensormap_initialization()
715
- # batch index of last tile
716
- last_batch_idx = cutlass.Int32(-1)
696
+ varlen_manager.fence_tensormap_init()
717
697
  while work_tile.is_valid_tile:
718
698
  tile_coord_mnkl = work_tile.tile_idx
719
699
  batch_idx = tile_coord_mnkl[3]
720
- if const_expr(varlen_k):
721
- is_group_changed = batch_idx != last_batch_idx
722
- last_batch_idx = batch_idx
723
- if is_group_changed:
724
- self.tensormap_update_AB(
725
- tensormap_manager,
726
- tensormap_ab_ptrs,
727
- cu_seqlens_k,
728
- batch_idx,
729
- is_tma_warp,
730
- )
700
+ varlen_manager.update_tensormap_AB(
701
+ batch_idx,
702
+ self.a_layout,
703
+ self.b_layout,
704
+ is_tma_warp,
705
+ )
731
706
  # ///////////////////////////////////////////////////////////////////////////
732
707
  # Local_tile partition global tensors
733
708
  # ///////////////////////////////////////////////////////////////////////////
734
709
  if const_expr(not self.gather_A):
735
- if const_expr(varlen_m):
736
- mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
737
- elif const_expr(varlen_k):
738
- mA_mk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mA_mkl)
739
- else:
740
- mA_mk = mA_mkl[None, None, batch_idx]
710
+ mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
741
711
  # (bM, bK, RestK)
742
- gA_k = cute.local_tile(
712
+ gA_mk = cute.local_tile(
743
713
  mA_mk,
744
714
  cute.select(self.cta_tile_shape_mnk, [0, 2]),
745
715
  (tile_coord_mnkl[0], None),
746
716
  )
747
717
  else:
718
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
748
719
  if const_expr(varlen_m):
749
- mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
750
720
  gAIdx = cute.local_tile(
751
721
  mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
752
722
  )
@@ -754,133 +724,90 @@ class GemmSm90:
754
724
  mA_mk = mA_mkl
755
725
  else:
756
726
  assert varlen_k
757
- mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
758
727
  # (tile_K, RestK)
759
728
  gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
760
729
  # (tile_M, K)
761
730
  mA_mk = cute.local_tile(
762
731
  mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
763
732
  )
764
- if const_expr(varlen_k):
765
- mB_nk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mB_nkl)
766
- else:
767
- mB_nk = mB_nkl[None, None, batch_idx]
768
733
  # (bN, bK, RestK)
769
- gB_k = cute.local_tile(
770
- mB_nk,
734
+ gB_nk = cute.local_tile(
735
+ varlen_manager.offset_batch_B(mB_nkl, batch_idx),
771
736
  cute.select(self.cta_tile_shape_mnk, [1, 2]),
772
737
  (tile_coord_mnkl[1], None),
773
738
  )
774
739
  # //////////////////////////////////////////////////////////////////////////
775
740
  # Partition shared tensor for TMA load A/B
776
741
  # //////////////////////////////////////////////////////////////////////////
777
- tma_desc_a_ptr, tma_desc_b_ptr = None, None
778
- if const_expr(varlen_k):
779
- # ensure the update to tensormap has completed before using it
780
- tensormap_a_ptr, tensormap_b_ptr = tensormap_ab_ptrs
781
- if is_group_changed and is_tma_warp:
782
- if const_expr(not self.gather_A):
783
- tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
784
- tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
785
- if const_expr(not self.gather_A):
786
- tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
787
- tensormap_a_ptr, cute.AddressSpace.generic
788
- )
789
- tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
790
- tensormap_b_ptr, cute.AddressSpace.generic
791
- )
742
+ varlen_manager.fence_tensormap_update_AB(is_tma_warp)
743
+ len_m = varlen_manager.len_m(batch_idx)
744
+ len_k = varlen_manager.len_k(batch_idx)
792
745
  # TMA load A partition_S/D
793
- a_cta_layout = cute.make_layout(
794
- cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
795
- )
796
- a_cta_crd = cluster_coord_mnk[1]
746
+ copy_A = None
797
747
  if const_expr(not self.gather_A):
798
- # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
799
- tAsA, tAgA_k = cpasync.tma_partition(
800
- tma_atom_a,
801
- a_cta_crd,
802
- a_cta_layout,
803
- cute.group_modes(sA, 0, 2),
804
- cute.group_modes(gA_k, 0, 2),
805
- )
806
- copy_A = partial(
807
- cute.copy,
748
+ copy_A, _, _ = copy_utils.tma_get_copy_fn(
808
749
  tma_atom_a,
750
+ cta_coord=block_in_cluster_coord_mnk[1],
751
+ cta_layout=cute.make_layout(
752
+ cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
753
+ ),
754
+ src_tensor=gA_mk,
755
+ dst_tensor=sA,
809
756
  mcast_mask=a_mcast_mask,
810
757
  tma_desc_ptr=tma_desc_a_ptr,
811
758
  )
812
759
  else:
813
760
  tiled_copy_A = self._make_gmem_tiled_copy_A(
814
- mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
761
+ mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
815
762
  )
816
763
  tidx = (
817
- cute.arch.thread_idx()[0]
818
- - self.mma_warp_groups * self.num_threads_per_warp_group
764
+ cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
819
765
  )
820
766
  thr_copy_A = tiled_copy_A.get_slice(tidx)
821
- # (atom_v, CPY_M, 1, STAGE)
822
- tAsA = thr_copy_A.partition_D(sA)
823
- if const_expr(varlen_m): # k-major
824
- assert tAsA.shape[2] == 1
825
- tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
826
- else: # varlen_k, m-major
827
- tAsA = cute.group_modes(tAsA, 0, 3)
828
- copy_A = partial(cute.copy, tiled_copy_A)
767
+ copy_A, prefetch_A = None, None
768
+ if const_expr(varlen_m):
769
+ copy_A = copy_utils.gather_m_get_copy_fn(
770
+ thr_copy_A,
771
+ mA_mk,
772
+ sA,
773
+ gAIdx,
774
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
775
+ limit_k=len_k,
776
+ )
777
+ else:
778
+ copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
779
+ thr_copy_A,
780
+ mA_mk,
781
+ sA,
782
+ gAIdx,
783
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
784
+ limit_k=len_k,
785
+ )
829
786
  # TMA load B partition_S/D
830
- b_cta_layout = cute.make_layout(
831
- cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
832
- )
833
- b_cta_crd = cluster_coord_mnk[0]
834
- # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
835
- tBsB, tBgB_k = cpasync.tma_partition(
787
+ copy_B, _, _ = copy_utils.tma_get_copy_fn(
836
788
  tma_atom_b,
837
- b_cta_crd,
838
- b_cta_layout,
839
- cute.group_modes(sB, 0, 2),
840
- cute.group_modes(gB_k, 0, 2),
841
- )
842
- copy_B = partial(
843
- cute.copy, tma_atom_b, mcast_mask=b_mcast_mask, tma_desc_ptr=tma_desc_b_ptr
789
+ cta_coord=block_in_cluster_coord_mnk[0],
790
+ cta_layout=cute.make_layout(
791
+ cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
792
+ ),
793
+ src_tensor=gB_nk,
794
+ dst_tensor=sB,
795
+ mcast_mask=b_mcast_mask,
796
+ tma_desc_ptr=tma_desc_b_ptr,
844
797
  )
845
- k_len = (
846
- cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
847
- if const_expr(varlen_k)
848
- else Int32(mA_mkl.shape[1])
849
- )
850
- k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
798
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
851
799
  if const_expr(not self.gather_A):
852
800
  ab_producer_state = self.load_AB(
853
- ab_pipeline,
854
- ab_producer_state,
855
- copy_A,
856
- tAgA_k,
857
- tAsA,
858
- copy_B,
859
- tBgB_k,
860
- tBsB,
861
- k_tile_cnt,
801
+ ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
862
802
  )
863
803
  else:
864
- limit_m = (
865
- Int32(mA_mkl.shape[0])
866
- if const_expr(cu_seqlens_m is None)
867
- else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
868
- )
869
804
  ab_producer_state = self.load_AB_gather_A(
870
805
  ab_pipeline,
871
806
  ab_producer_state,
872
- thr_copy_A,
873
- mA_mk,
874
- tAsA,
875
- gAIdx,
807
+ copy_A,
808
+ prefetch_A,
876
809
  copy_B,
877
- tBgB_k,
878
- tBsB,
879
810
  k_tile_cnt,
880
- limit_A=(
881
- limit_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
882
- k_len,
883
- ),
884
811
  varlen_m=varlen_m,
885
812
  )
886
813
  tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
@@ -900,22 +827,11 @@ class GemmSm90:
900
827
  (not self.pingpong and warp_idx == 0)
901
828
  or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
902
829
  )
903
- if const_expr(varlen_m):
904
- # initialize tensormap for D
905
- if const_expr(has_D):
906
- tensormap_manager.init_tensormap_from_atom(
907
- tma_atom_d,
908
- tensormap_d_ptr,
909
- is_manager_warp=is_tma_warp,
910
- )
911
- for tma_atom, tensormap_epi_ptr in zip(
912
- self.epi_get_tma_atoms(epilogue_params), tensormap_epi_ptrs
913
- ):
914
- tensormap_manager.init_tensormap_from_atom(
915
- tma_atom,
916
- tensormap_epi_ptr,
917
- is_manager_warp=is_tma_warp,
918
- )
830
+ varlen_manager.init_tensormap_epi(
831
+ tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
832
+ )
833
+ tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
834
+ tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
919
835
  # //////////////////////////////////////////////////////////////////////////////
920
836
  # Partition global tensor for TiledMMA_A/B/C
921
837
  # //////////////////////////////////////////////////////////////////////////////
@@ -974,9 +890,8 @@ class GemmSm90:
974
890
  if const_expr(not varlen_k):
975
891
  ab_read_state.advance_iters(k_tile_cnt_static)
976
892
  else:
977
- batch_idx = work_tile.tile_idx[3]
978
- k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
979
- k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
893
+ len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
894
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
980
895
  ab_read_state.advance_iters(k_tile_cnt)
981
896
  tile_scheduler.advance_to_next_work()
982
897
  if const_expr(varlen_k):
@@ -987,32 +902,22 @@ class GemmSm90:
987
902
  work_tile = tile_scheduler.initial_work_tile_info()
988
903
  if const_expr(varlen_m):
989
904
  # wait tensormap initialization complete before update
990
- tensormap_manager.fence_tensormap_initialization()
991
- # batch index of last tile
992
- last_batch_idx = cutlass.Int32(-1)
905
+ varlen_manager.fence_tensormap_init()
993
906
  while work_tile.is_valid_tile:
994
907
  tile_coord_mnkl = work_tile.tile_idx
995
908
  batch_idx = tile_coord_mnkl[3]
996
- if const_expr(varlen_m):
997
- is_group_changed = batch_idx != last_batch_idx
998
- last_batch_idx = batch_idx
999
- if is_group_changed:
1000
- self.tensormap_update_D_epi(
1001
- tensormap_manager,
1002
- tensormap_d_ptr,
1003
- tensormap_epi_ptrs,
1004
- epilogue_params,
1005
- cu_seqlens_m,
1006
- batch_idx,
1007
- is_manager_warp=is_tma_warp,
1008
- )
1009
-
1010
- k_len = (
1011
- cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1012
- if const_expr(varlen_k)
1013
- else mA_mkl.shape[1]
909
+ epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
910
+ epilogue_params, varlen_params.cu_seqlens_m, batch_idx
911
+ )
912
+ varlen_manager.update_tensormap_epi(
913
+ batch_idx,
914
+ self.d_layout,
915
+ epi_shapes,
916
+ epi_orders,
917
+ is_tma_warp,
1014
918
  )
1015
- k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
919
+ len_k = varlen_manager.len_k(batch_idx)
920
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
1016
921
  ab_read_state, tiled_mma = self.mma(
1017
922
  ab_pipeline,
1018
923
  ab_read_state,
@@ -1039,57 +944,38 @@ class GemmSm90:
1039
944
  num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
1040
945
  )
1041
946
 
1042
- tma_desc_d_ptr, tma_desc_epi_ptrs = None, [None] * self.num_epi_tensormaps
1043
- if const_expr(varlen_m):
1044
- # ensure the update to tensormap has completed before using it
1045
- if is_group_changed and is_tma_warp:
1046
- if const_expr(has_D):
1047
- tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
1048
- for tensormap_epi_ptr in tensormap_epi_ptrs:
1049
- tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
1050
- if const_expr(has_D):
1051
- tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
1052
- tensormap_d_ptr, cute.AddressSpace.generic
1053
- )
1054
- tma_desc_epi_ptrs = [
1055
- tensormap_manager.get_tensormap_ptr(
1056
- tensormap_epi_ptr, cute.AddressSpace.generic
1057
- )
1058
- for tensormap_epi_ptr in tensormap_epi_ptrs
1059
- ]
947
+ varlen_manager.fence_tensormap_update_epi(is_tma_warp)
1060
948
 
949
+ copy_D = None
1061
950
  if const_expr(has_D):
1062
- bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(
951
+ copy_D, _, _ = self.epilog_gmem_copy_and_partition(
1063
952
  tma_atom_d,
1064
- mD_mnl,
953
+ varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
1065
954
  self.cta_tile_shape_mnk[:2],
1066
955
  self.epi_tile,
1067
956
  sD,
1068
957
  tile_coord_mnkl,
1069
- cu_seqlens_m,
958
+ tma_desc_ptr=tma_desc_d_ptr,
1070
959
  )
1071
- copy_D = partial(cute.copy, tma_atom_d, tma_desc_ptr=tma_desc_d_ptr)
1072
- else:
1073
- bSG_sD, bSG_gD, copy_D = None, None, None
960
+ copy_C = None
1074
961
  if const_expr(has_C):
1075
- bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
962
+ copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
1076
963
  tma_atom_c,
1077
- mC_mnl,
964
+ varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
1078
965
  self.cta_tile_shape_mnk[:2],
1079
966
  self.epi_tile,
1080
967
  sC,
1081
968
  tile_coord_mnkl,
1082
- cu_seqlens_m,
1083
969
  )
1084
- copy_C = partial(cute.copy, tma_atom_c)
1085
- epi_load_g2s = partial(self.epi_load_g2s, epi_pipeline, copy_C, bGS_gC, bGS_sC)
1086
- else:
1087
- epi_load_g2s = None
970
+ copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
1088
971
 
1089
972
  d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
1090
- tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
1091
- tiled_mma, self.d_layout, d_dtype_for_layout, acc, sD, tidx
973
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
974
+ tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
1092
975
  )
976
+ # (R2S, R2S_M, R2S_N)
977
+ tRS_rAcc = tiled_copy_r2s.retile(acc)
978
+ load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
1093
979
  if const_expr(has_C):
1094
980
  tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
1095
981
  tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
@@ -1112,21 +998,20 @@ class GemmSm90:
1112
998
  epi_store_pipeline,
1113
999
  epi_read_state,
1114
1000
  epi_producer_state,
1115
- tiled_mma,
1116
- tRS_rAcc,
1001
+ self.epi_tile,
1002
+ load_acc_subtile,
1117
1003
  tRS_rD,
1118
1004
  tRS_rC,
1005
+ None, # tiled_copy_t2r, for Sm100 only
1119
1006
  tiled_copy_r2s,
1120
1007
  tRS_sD,
1121
1008
  tiled_copy_s2r,
1122
1009
  tSR_rC,
1123
1010
  tSR_sC,
1124
1011
  copy_D,
1125
- bSG_sD,
1126
- bSG_gD,
1127
- epi_load_g2s,
1012
+ copy_C,
1128
1013
  tile_coord_mnkl,
1129
- cu_seqlens_m,
1014
+ varlen_manager,
1130
1015
  epilogue_barrier,
1131
1016
  tile_scheduler,
1132
1017
  tidx,
@@ -1157,9 +1042,8 @@ class GemmSm90:
1157
1042
  tile_scheduler.advance_to_next_work()
1158
1043
  work_tile = tile_scheduler.get_current_work()
1159
1044
  if work_tile.is_valid_tile:
1160
- batch_idx = work_tile.tile_idx[3]
1161
- k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1162
- k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
1045
+ len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
1046
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
1163
1047
  ab_read_state.advance_iters(k_tile_cnt)
1164
1048
  tile_scheduler.advance_to_next_work()
1165
1049
  work_tile = tile_scheduler.get_current_work()
@@ -1175,14 +1059,16 @@ class GemmSm90:
1175
1059
  self,
1176
1060
  ab_pipeline: cutlass.pipeline.PipelineAsync,
1177
1061
  ab_producer_state: cutlass.pipeline.PipelineState,
1178
- copy_A: Callable,
1179
- tAgA: cute.Tensor,
1180
- tAsA: cute.Tensor,
1062
+ copy_A: Optional[Callable],
1181
1063
  copy_B: Callable,
1182
- tBgB: cute.Tensor,
1183
- tBsB: cute.Tensor,
1184
1064
  k_tile_cnt: Int32,
1065
+ # These are for Sm100 blockscaled gemm
1066
+ copy_SFA: Optional[Callable] = None,
1067
+ copy_SFB: Optional[Callable] = None,
1185
1068
  ) -> cutlass.pipeline.PipelineState:
1069
+ blockscaled = const_expr(copy_SFA is not None)
1070
+ if const_expr(blockscaled):
1071
+ assert copy_SFB is not None
1186
1072
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
1187
1073
  peek_ab_empty_status = Boolean(True)
1188
1074
  if 0 < k_tile_cnt:
@@ -1195,8 +1081,13 @@ class GemmSm90:
1195
1081
  # Also sets the transaction barrier for the A/B buffers
1196
1082
  ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
1197
1083
  tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
1198
- copy_A(tAgA[None, k_tile], tAsA[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
1199
- copy_B(tBgB[None, k_tile], tBsB[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
1084
+ smem_idx = ab_producer_state.index
1085
+ if const_expr(copy_A is not None):
1086
+ copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
1087
+ copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
1088
+ if const_expr(blockscaled):
1089
+ copy_SFA(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
1090
+ copy_SFB(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
1200
1091
  # Mainloop pipeline's producer commit is a NOP
1201
1092
  ab_pipeline.producer_commit(ab_producer_state)
1202
1093
  ab_producer_state.advance()
@@ -1210,58 +1101,12 @@ class GemmSm90:
1210
1101
  self,
1211
1102
  ab_pipeline: cutlass.pipeline.PipelineAsync,
1212
1103
  ab_producer_state: cutlass.pipeline.PipelineState,
1213
- thr_copy_A: cute.core.ThrCopy,
1214
- mA: cute.Tensor, # (M, K) if varlen_m, (tile_M, K) if varlen_k
1215
- tAsA: cute.Tensor,
1216
- gAIdx: cute.Tensor, # (tile_M,) if varlen_m, (tile_K, RestK) if varlen_k
1104
+ copy_A: Callable,
1105
+ prefetch_A: Optional[Callable],
1217
1106
  copy_B: Callable,
1218
- tBgB: cute.Tensor,
1219
- tBsB: cute.Tensor,
1220
1107
  k_tile_cnt: Int32,
1221
- limit_A: Tuple[Int32, Int32],
1222
- varlen_m: bool,
1108
+ varlen_m: bool = True,
1223
1109
  ) -> cutlass.pipeline.PipelineState:
1224
- limit_m, limit_k = limit_A
1225
- # Do we need to check if we overshoot tile_M when we load A?
1226
- is_even_m_smem = self.cta_tile_shape_mnk[0] % thr_copy_A.tiler_mn[0].shape == 0
1227
- if const_expr(not is_even_m_smem):
1228
- limit_m = min(limit_m, self.cta_tile_shape_mnk[0])
1229
- elems_per_load = cute.size(tAsA.shape[0][0])
1230
- cA = cute.make_identity_tensor(cute.select(self.cta_tile_shape_mnk, [0, 2]))
1231
- tAcA = thr_copy_A.partition_S(cA)
1232
- t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
1233
- # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
1234
- # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
1235
- # This is so that when we do the comparison, t0AcA is known at compile time.
1236
- limit_m = limit_m - tAcA[0][0]
1237
- limit_k = limit_k - tAcA[0][1]
1238
- # Read indices for A
1239
- rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
1240
- cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
1241
- tApA_m = cute.make_fragment(rows_per_thread, Boolean)
1242
- for m in cutlass.range_constexpr(rows_per_thread):
1243
- tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
1244
- m_idx, k_idx, tAmA = None, None, None
1245
- if const_expr(varlen_m):
1246
- m_idx = cute.make_fragment(rows_per_thread, Int32)
1247
- for m in cutlass.range(rows_per_thread):
1248
- row_idx = tAcA[0, m, 0][0]
1249
- if tApA_m[m]:
1250
- m_idx[m] = gAIdx[row_idx]
1251
- else:
1252
- m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
1253
- else:
1254
- k_idx = cute.make_fragment(cols_per_thread, Int32) # Will be read later
1255
- threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
1256
- # This is very convoluted but idk a better way
1257
- # for tile_M=128, flat_divide gives (8, 16, K),
1258
- # then logical_divide gives ((8, 1), (8, 2), K).
1259
- tidx = thr_copy_A.thr_idx
1260
- tAmA = cute.logical_divide(
1261
- cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
1262
- )[None, (tidx % threads_per_col, None), None] # ((8, 1), 2, K)
1263
- # (m, (bK, RestK))
1264
- mA_k = cute.logical_divide(mA, (None, self.cta_tile_shape_mnk[2]))
1265
1110
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1266
1111
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
1267
1112
  peek_ab_empty_status = Boolean(True)
@@ -1270,59 +1115,27 @@ class GemmSm90:
1270
1115
  # /////////////////////////////////////////////////////////////////////////
1271
1116
  # TMA load on B and cp.async on A
1272
1117
  # /////////////////////////////////////////////////////////////////////////
1273
- copy_A = partial(cute.copy, thr_copy_A)
1274
1118
  for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
1275
- if const_expr(not varlen_m): # Prefetch mAIdx early, even before smem is free
1276
- gAIdx_cur = gAIdx[None, k_tile]
1277
- for k in cutlass.range(cols_per_thread):
1278
- col_idx = tAcA[0, 0, k][1]
1279
- k_idx[k] = gAIdx_cur[col_idx]
1119
+ prefetch_out = ()
1120
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
1121
+ prefetch_out = (prefetch_A(k_tile),)
1280
1122
  # Wait for A/B buffers to be empty before loading into them
1281
1123
  # Also sets the transaction barrier for the A/B buffers
1282
1124
  # A tiny bit faster to rotate the warp that does TMA
1283
1125
  # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id
1284
1126
  # since that's the warp that does the tensormap update.
1285
- tma_warp_id = self.ab_load_warp_id + (
1127
+ is_tma_warp = warp_idx == self.ab_load_warp_id + (
1286
1128
  (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
1287
1129
  )
1288
- ab_pipeline.producer_acquire(
1289
- ab_producer_state,
1290
- peek_ab_empty_status,
1291
- is_tma_warp=warp_idx == tma_warp_id,
1292
- )
1293
- # A bit faster to load B first while we calculate the predicate for A
1294
- if warp_idx == tma_warp_id:
1295
- copy_B(
1296
- tBgB[None, k_tile],
1297
- tBsB[None, ab_producer_state.index],
1298
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1299
- )
1300
- # (m, bK)
1301
- if const_expr(varlen_m):
1302
- mA_cur = mA_k[None, (None, k_tile)]
1303
- for m in cutlass.range_constexpr(tAcA.shape[1]):
1304
- # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
1305
- # ((elems_per_load), thread_per_row)
1306
- # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
1307
- # So we append 1s to the last dimension and then do tiled_divide, then slice.
1308
- mA_row = cute.tiled_divide(
1309
- cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
1310
- )[None, None, 0]
1311
- if const_expr(is_even_m_smem) or tApA_m[m]:
1312
- # There's only 1 load per row
1313
- assert cute.size(tAcA.shape, mode=[2]) == 1
1314
- ki = tAcA[0, 0, 0][1] // elems_per_load
1315
- copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1316
- else:
1317
- for k in cutlass.range_constexpr(tAcA.shape[2]):
1318
- # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), ab_producer_state.index], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
1319
- for m in cutlass.range_constexpr(tAcA.shape[1]):
1320
- if tApA_m[m]:
1321
- copy_A(
1322
- tAmA[None, m, k_idx[k]], tAsA[(None, m, k), ab_producer_state.index]
1323
- )
1130
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
1131
+ smem_idx = ab_producer_state.index
1132
+ # A bit faster to load B first while we calculate the indices for A
1133
+ if is_tma_warp:
1134
+ tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
1135
+ copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
1136
+ copy_A(k_tile, smem_idx, *prefetch_out)
1324
1137
  # This tells mbarrier to track the completion of cp.async
1325
- ab_pipeline.producer_commit(ab_producer_state)
1138
+ ab_pipeline.producer_cpasync_commit(ab_producer_state)
1326
1139
  ab_producer_state.advance()
1327
1140
  peek_ab_empty_status = Boolean(True)
1328
1141
  if k_tile + 1 < k_tile_cnt:
@@ -1330,58 +1143,19 @@ class GemmSm90:
1330
1143
  # bound checking in the K dimension on the last k_tile
1331
1144
  if 0 < k_tile_cnt:
1332
1145
  k_tile = k_tile_cnt - 1
1333
- tApA_k = cute.make_fragment(cols_per_thread, Boolean)
1334
- limit_k -= k_tile * self.cta_tile_shape_mnk[2]
1335
- for k in cutlass.range_constexpr(cols_per_thread):
1336
- tApA_k[k] = t0AcA[0, 0, k][1] < limit_k
1337
- if const_expr(not varlen_m):
1338
- gAIdx_cur = gAIdx[None, k_tile]
1339
- for k in cutlass.range(cols_per_thread):
1340
- col_idx = tAcA[0, 0, k][1]
1341
- if tApA_k[k]:
1342
- k_idx[k] = gAIdx_cur[col_idx]
1343
- else:
1344
- k_idx[k] = -1
1345
- tma_warp_id = self.ab_load_warp_id + (
1146
+ prefetch_out = ()
1147
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
1148
+ prefetch_out = (prefetch_A(k_tile, pred=True),)
1149
+ is_tma_warp = warp_idx == self.ab_load_warp_id + (
1346
1150
  (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
1347
1151
  )
1348
- ab_pipeline.producer_acquire(
1349
- ab_producer_state,
1350
- peek_ab_empty_status,
1351
- is_tma_warp=warp_idx == tma_warp_id,
1352
- )
1353
- if warp_idx == tma_warp_id:
1354
- copy_B(
1355
- tBgB[None, k_tile],
1356
- tBsB[None, ab_producer_state.index],
1357
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1358
- )
1359
- if const_expr(varlen_m):
1360
- # (m, bK)
1361
- mA_cur = mA_k[None, (None, k_tile)]
1362
- for m in cutlass.range_constexpr(tAcA.shape[1]):
1363
- # ((elems_per_load, 1), thread_per_row)
1364
- mA_row = cute.tiled_divide(
1365
- cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
1366
- )[None, None, 0]
1367
- if const_expr(is_even_m_smem) or tApA_m[k]:
1368
- # There's only 1 load per row
1369
- assert cute.size(tAcA.shape, mode=[2]) == 1
1370
- ki = tAcA[0, 0, 0][1] // elems_per_load
1371
- copy_A(
1372
- mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA_k
1373
- )
1374
- else:
1375
- tApA_k = cute.prepend_ones(tApA_k, up_to_rank=2) # (1, cols_per_thread)
1376
- for k in cutlass.range_constexpr(tAcA.shape[2]):
1377
- for m in cutlass.range_constexpr(tAcA.shape[1]):
1378
- if tApA_m[m]:
1379
- copy_A(
1380
- tAmA[None, m, k_idx[k]],
1381
- tAsA[(None, m, k), ab_producer_state.index],
1382
- pred=tApA_k[None, k],
1383
- )
1384
- ab_pipeline.producer_commit(ab_producer_state)
1152
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
1153
+ smem_idx = ab_producer_state.index
1154
+ if is_tma_warp:
1155
+ tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
1156
+ copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
1157
+ copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
1158
+ ab_pipeline.producer_cpasync_commit(ab_producer_state)
1385
1159
  ab_producer_state.advance()
1386
1160
  return ab_producer_state
1387
1161
 
@@ -1481,22 +1255,21 @@ class GemmSm90:
1481
1255
  epi_pipeline: cutlass.pipeline.PipelineAsync,
1482
1256
  epi_store_pipeline: cutlass.pipeline.PipelineAsync,
1483
1257
  epi_read_state: cutlass.pipeline.PipelineState,
1484
- epi_producer_state: cutlass.pipeline.PipelineState,
1485
- tiled_mma: cute.TiledMma,
1486
- tRS_rAcc: cute.Tensor,
1258
+ epi_producer_state: Optional[cutlass.pipeline.PipelineState],
1259
+ epi_tile: cute.Tile,
1260
+ load_acc_subtile: Callable,
1487
1261
  tRS_rD: cute.Tensor,
1488
1262
  tRS_rC: Optional[cute.Tensor],
1489
- tiled_copy_r2s: cute.core.ThrCopy,
1263
+ tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
1264
+ tiled_copy_r2s: cute.TiledCopy,
1490
1265
  tRS_sD: cute.Tensor,
1491
- tiled_copy_s2r: Optional[cute.core.ThrCopy],
1266
+ tiled_copy_s2r: Optional[cute.ThrCopy],
1492
1267
  tSR_rC: Optional[cute.Tensor],
1493
1268
  tSR_sC: Optional[cute.Tensor],
1494
1269
  copy_D: Optional[Callable],
1495
- bSG_sD: cute.Tensor,
1496
- bSG_gD: cute.Tensor,
1497
- epi_load_g2s: Optional[Callable],
1270
+ copy_C: Optional[Callable],
1498
1271
  tile_coord_mnkl: cute.Coord,
1499
- cu_seqlens_m: Optional[cute.Tensor],
1272
+ varlen_manager: VarlenManager,
1500
1273
  epilogue_barrier: cutlass.pipeline.NamedBarrier,
1501
1274
  tile_scheduler,
1502
1275
  tidx: Int32,
@@ -1504,22 +1277,61 @@ class GemmSm90:
1504
1277
  ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
1505
1278
  has_C = const_expr(tRS_rC is not None)
1506
1279
  has_D = const_expr(copy_D is not None)
1507
- # We iterate over epi tiles in the N dimension first before the M dimension
1508
1280
  epi_tile_shape = cute.zipped_divide(
1509
- cute.make_layout(self.cta_tile_shape_mnk[:2]), self.epi_tile
1281
+ cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
1510
1282
  ).shape[1]
1511
- epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
1283
+ # We iterate over epi tiles in the N dimension first before the M dimension
1284
+ epi_tile_layout = cute.make_ordered_layout(epi_tile_shape, order=(1, 0))
1512
1285
  epi_tile_num = cute.size(epi_tile_shape)
1513
1286
  num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
1514
1287
 
1515
- if const_expr(epi_load_g2s is not None):
1288
+ epi_tensors = self.epi_begin(
1289
+ params,
1290
+ epi_smem_tensors,
1291
+ epi_tile,
1292
+ tiled_copy_t2r,
1293
+ tiled_copy_r2s,
1294
+ tile_coord_mnkl,
1295
+ varlen_manager,
1296
+ epilogue_barrier,
1297
+ tidx,
1298
+ )
1299
+
1300
+ if const_expr(copy_C is not None):
1516
1301
  for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
1517
- epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
1302
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
1303
+ if is_tma_warp:
1304
+ epi_pipeline.producer_acquire(epi_producer_state)
1305
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
1306
+ epi_pipeline.producer_commit(epi_producer_state)
1307
+ epi_producer_state.advance()
1518
1308
 
1309
+ def tma_store_fn(src_idx, dst_idx):
1310
+ # Fence and barrier to make sure shared memory store is visible to TMA store
1311
+ cute.arch.fence_proxy(
1312
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1313
+ )
1314
+ epilogue_barrier.arrive_and_wait()
1315
+ # Copy from shared memory to global memory
1316
+ if is_tma_warp:
1317
+ if const_expr(has_D):
1318
+ copy_D(src_idx=src_idx, dst_idx=dst_idx)
1319
+ # Can't use if statement here, epi_store_pipeline object isn't captured somehow
1320
+ if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
1321
+ if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
1322
+ epilogue_barrier.arrive_and_wait()
1323
+
1324
+ # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops
1325
+ # with the TMA store. However, currently this doesn't seem to improve perf.
1326
+ delay_tma_store = False
1327
+
1328
+ src_idx_prev, dst_idx_prev = None, None
1519
1329
  for epi_idx in cutlass.range_constexpr(epi_tile_num):
1330
+ # The global memory coordinate for the current epi tile
1331
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1520
1332
  # Copy from acc to D registers
1521
- for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
1522
- tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1333
+ load_acc_subtile(tRS_rD, epi_idx)
1334
+ epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
1523
1335
  if const_expr(has_C):
1524
1336
  epi_pipeline.consumer_wait(epi_read_state)
1525
1337
  cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
@@ -1531,190 +1343,40 @@ class GemmSm90:
1531
1343
  with cute.arch.elect_one():
1532
1344
  epi_pipeline.consumer_release(epi_read_state)
1533
1345
  epi_read_state.advance()
1534
- if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
1535
- epi_producer_state = epi_load_g2s(
1536
- epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
1537
- )
1538
- tRS_rEpi = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
1346
+ if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
1347
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
1348
+ if is_tma_warp:
1349
+ epi_pipeline.producer_acquire(epi_producer_state)
1350
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
1351
+ epi_pipeline.producer_commit(epi_producer_state)
1352
+ epi_producer_state.advance()
1353
+ tRS_rEpi = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
1539
1354
  epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
1355
+ if const_expr(delay_tma_store):
1356
+ if const_expr(epi_idx > 0):
1357
+ tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
1358
+ src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
1540
1359
  # Copy from D registers to shared memory
1541
1360
  if const_expr(has_D):
1542
- # Type conversion
1543
- tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
1544
- tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
1545
- cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
1546
- # Fence and barrier to make sure shared memory store is visible to TMA store
1547
- cute.arch.fence_proxy(
1548
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1549
- )
1550
- epilogue_barrier.arrive_and_wait()
1551
- # Get the global memory coordinate for the current epi tile
1552
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1553
- # Copy from shared memory to global memory
1554
- if is_tma_warp:
1555
- if const_expr(has_D):
1556
- copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
1557
- epi_store_pipeline.producer_commit()
1558
- epi_store_pipeline.producer_acquire()
1559
- epilogue_barrier.arrive_and_wait()
1560
-
1561
- return epi_read_state, epi_producer_state
1562
-
1563
- @cute.jit
1564
- def epi_load_g2s(
1565
- self,
1566
- epi_pipeline: cutlass.pipeline.PipelineAsync,
1567
- copy_C: Callable,
1568
- bGS_gC: cute.Tensor,
1569
- bGS_sC: cute.Tensor,
1570
- epi_producer_state: cutlass.pipeline.PipelineState,
1571
- epi_idx: Int32,
1572
- should_load: Boolean,
1573
- ) -> cutlass.pipeline.PipelineState:
1574
- # We iterate over epi tiles in the N dimension first before the M dimension
1575
- epi_tile_layout = cute.make_layout(bGS_gC.shape[1], stride=(bGS_gC.shape[1][1], 1))
1576
- if should_load:
1577
- epi_pipeline.producer_acquire(epi_producer_state)
1578
- # Get the global memory coordinate for the current epi tile
1579
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1580
- copy_C(
1581
- bGS_gC[None, gmem_coord],
1582
- bGS_sC[None, epi_producer_state.index],
1583
- tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1584
- )
1585
- # Epi pipeline's producer commit is a NOP
1586
- epi_pipeline.producer_commit(epi_producer_state)
1587
- epi_producer_state.advance()
1588
- return epi_producer_state
1589
-
1590
- def epi_visit_acc_subtile(
1591
- self,
1592
- params: EpilogueParams,
1593
- tRS_rD: cute.Tensor,
1594
- tRS_rC: Optional[cute.Tensor] = None,
1595
- ) -> Optional[cute.Tensor]:
1596
- # Apply alpha scaling to accumulator if alpha is provided (not None)
1597
- if const_expr(hasattr(params, "alpha") and params.alpha is not None):
1598
- alpha = utils.load_scalar_or_pointer(params.alpha)
1599
- tRS_rD.store(tRS_rD.load() * alpha)
1600
- # Apply C with beta scaling
1601
- if const_expr(tRS_rC is not None):
1602
- if const_expr(not hasattr(params, "beta") or params.beta is None):
1603
- # beta is None, default behavior: add C (beta=1.0)
1604
- tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
1605
- else:
1606
- beta = utils.load_scalar_or_pointer(params.beta)
1607
- tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
1608
- return None
1609
-
1610
- def tensormap_init(
1611
- self,
1612
- tensormaps: Optional[cute.Tensor],
1613
- varlen_m: bool,
1614
- varlen_k: bool,
1615
- has_D: bool,
1616
- warp_idx: Int32,
1617
- ):
1618
- tensormap_manager = None
1619
- tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
1620
- tensormap_epi_ptrs = [None] * self.num_epi_tensormaps
1621
- if const_expr(varlen_m or varlen_k):
1622
- tensormap_manager = TensorMapManagerSm90(
1623
- cutlass.utils.TensorMapUpdateMode.GMEM, self.__class__.bytes_per_tensormap
1624
- )
1625
- # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
1626
- tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
1627
- if const_expr(varlen_m):
1628
- tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
1629
- tensormap_epi_offset = tensormap_d_idx
1630
- if const_expr(has_D):
1631
- tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
1632
- tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
1633
- )
1634
- tensormap_epi_offset += 1 if not self.pingpong else 2
1635
- tensormap_epi_ptrs = [
1636
- tensormap_manager.get_tensormap_ptr(
1637
- tensormaps[
1638
- tensormap_workspace_idx,
1639
- tensormap_epi_offset + i * (1 if not self.pingpong else 2),
1640
- None,
1641
- ].iterator
1642
- )
1643
- for i in range(self.num_epi_tensormaps)
1644
- ]
1645
- else:
1646
- assert varlen_k
1647
- if const_expr(not self.gather_A):
1648
- tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
1649
- tensormaps[tensormap_workspace_idx, 0, None].iterator
1650
- )
1651
- tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
1652
- tensormaps[
1653
- tensormap_workspace_idx, 1 if not self.gather_A else 0, None
1654
- ].iterator
1655
- )
1656
- tensormap_ab_ptrs = [tensormap_a_ptr, tensormap_b_ptr]
1657
- return (
1658
- tensormap_manager,
1659
- tensormap_ab_ptrs,
1660
- tensormap_d_ptr,
1661
- tensormap_epi_ptrs,
1361
+ copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
1362
+ if const_expr(not delay_tma_store):
1363
+ tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
1364
+
1365
+ if const_expr(delay_tma_store):
1366
+ tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
1367
+
1368
+ self.epi_end(
1369
+ params,
1370
+ epi_tensors,
1371
+ epi_tile,
1372
+ tiled_copy_t2r,
1373
+ tiled_copy_r2s,
1374
+ tile_coord_mnkl,
1375
+ varlen_manager,
1376
+ tidx,
1662
1377
  )
1663
1378
 
1664
- def tensormap_update_AB(
1665
- self,
1666
- tensormap_manager,
1667
- tensormap_ab_ptrs,
1668
- cu_seqlens_k: cute.Tensor,
1669
- batch_idx: Int32,
1670
- is_manager_warp: bool | Boolean,
1671
- ) -> None:
1672
- # construct tensor A/B based on real address, shape and stride information
1673
- tensormap_a_ptr, tensormap_b_ptr = tensormap_ab_ptrs
1674
- tensormap_ptrs = [tensormap_b_ptr]
1675
- shapes = [cu_seqlens_k[batch_idx + 1]]
1676
- orders = [0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1]
1677
- if const_expr(not self.gather_A):
1678
- tensormap_ptrs.insert(0, tensormap_a_ptr)
1679
- shapes.insert(0, cu_seqlens_k[batch_idx + 1])
1680
- orders.insert(0, 0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1)
1681
- tensormap_manager.update_tensormap_shape(
1682
- tensormap_ptrs,
1683
- is_manager_warp=is_manager_warp,
1684
- shapes=shapes,
1685
- orders=orders,
1686
- tensormap_smem_ptr=None,
1687
- )
1688
-
1689
- def tensormap_update_D_epi(
1690
- self,
1691
- tensormap_manager,
1692
- tensormap_d_ptr,
1693
- tensormap_epi_ptrs,
1694
- epilogue_params: EpilogueParams,
1695
- cu_seqlens_m: cute.Tensor,
1696
- batch_idx: Int32,
1697
- is_manager_warp: bool | Boolean,
1698
- ) -> None:
1699
- # construct tensor D based on real address, shape and stride information
1700
- tensormap_ptrs, shapes, orders = [], [], []
1701
- if const_expr(tensormap_d_ptr is not None):
1702
- tensormap_ptrs.append(tensormap_d_ptr)
1703
- shapes.append(cu_seqlens_m[batch_idx + 1])
1704
- orders.append(0 if const_expr(self.d_layout.is_m_major_c()) else 1)
1705
- epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
1706
- epilogue_params, cu_seqlens_m, batch_idx
1707
- )
1708
- tensormap_ptrs.extend(tensormap_epi_ptrs)
1709
- shapes.extend(epi_shapes)
1710
- orders.extend(epi_orders)
1711
- tensormap_manager.update_tensormap_shape(
1712
- tensormap_ptrs,
1713
- is_manager_warp=is_manager_warp,
1714
- shapes=shapes,
1715
- orders=orders,
1716
- tensormap_smem_ptr=None,
1717
- )
1379
+ return epi_read_state, epi_producer_state
1718
1380
 
1719
1381
  def get_scheduler_class(self, varlen_m: bool = False):
1720
1382
  """Return the scheduler class to use. Override in subclasses for custom schedulers."""
@@ -1773,6 +1435,40 @@ class GemmSm90:
1773
1435
  )
1774
1436
  return tile_sched_args
1775
1437
 
1438
+ @cute.jit
1439
+ def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
1440
+ for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
1441
+ tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1442
+
1443
+ @cute.jit
1444
+ def epi_begin(
1445
+ self,
1446
+ params: EpilogueParams,
1447
+ epi_smem_tensors: Tuple[cute.Tensor, ...],
1448
+ epi_tile: cute.Tile,
1449
+ tiled_copy_t2r: Optional[cute.TiledCopy],
1450
+ tiled_copy_r2s: cute.TiledCopy,
1451
+ tile_coord_mnkl: cute.Coord,
1452
+ varlen_manager: VarlenManager,
1453
+ epilogue_barrier: cutlass.pipeline.NamedBarrier,
1454
+ tidx: Int32,
1455
+ ) -> Tuple[cute.Tensor, ...]:
1456
+ return ()
1457
+
1458
+ def epi_begin_loop(
1459
+ self, params: EpilogueParams, epi_tensors: Tuple[cute.Tensor, ...], epi_coord: cute.Coord
1460
+ ) -> Tuple[cute.Tensor, ...]:
1461
+ return ()
1462
+
1463
+ def epi_visit_subtile(
1464
+ self,
1465
+ params: EpilogueParams,
1466
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
1467
+ tRS_rD: cute.Tensor,
1468
+ tRS_rC: Optional[cute.Tensor] = None,
1469
+ ) -> Optional[cute.Tensor]:
1470
+ return None
1471
+
1776
1472
  def epi_visit_acc(
1777
1473
  self,
1778
1474
  params: EpilogueParams,
@@ -1783,10 +1479,24 @@ class GemmSm90:
1783
1479
  ) -> None:
1784
1480
  pass
1785
1481
 
1482
+ @cute.jit
1483
+ def epi_end(
1484
+ self,
1485
+ params: EpilogueParams,
1486
+ epi_tensors: Tuple[cute.Tensor, ...],
1487
+ epi_tile: cute.Tile,
1488
+ tiled_copy_t2r: Optional[cute.TiledCopy],
1489
+ tiled_copy_r2s: cute.TiledCopy,
1490
+ tile_coord_mnkl: cute.Coord,
1491
+ varlen_manager,
1492
+ tidx,
1493
+ ) -> None:
1494
+ pass
1495
+
1786
1496
  def epi_to_underlying_arguments(
1787
1497
  self, args: EpilogueArguments, *, loc=None, ip=None
1788
1498
  ) -> EpilogueParams:
1789
- return GemmSm90.EpilogueParams(alpha=args.alpha, beta=args.beta)
1499
+ return self.EpilogueParams()
1790
1500
 
1791
1501
  def epi_get_tma_atoms(
1792
1502
  self, params: EpilogueParams, *, loc=None, ip=None
@@ -1810,12 +1520,12 @@ class GemmSm90:
1810
1520
  def epi_smem_bytes_per_stage(
1811
1521
  args: Optional[EpilogueArguments],
1812
1522
  cta_tile_shape_mnk: Tuple[int, int, int],
1813
- epi_tile: Tuple[int, int],
1523
+ epi_tile: cute.Tile,
1814
1524
  ) -> int:
1815
1525
  return 0
1816
1526
 
1817
1527
  def epi_get_smem_struct(self, params: EpilogueParams):
1818
- return cute.struct.MemRange[cutlass.Int32, 0] # Dummy struct
1528
+ return cute.struct.MemRange[Int32, 0] # Dummy struct
1819
1529
 
1820
1530
  def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
1821
1531
  return tuple()
@@ -1842,7 +1552,7 @@ class GemmSm90:
1842
1552
  self.d_layout.is_m_major_c() if self.d_layout is not None else False,
1843
1553
  num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
1844
1554
  ),
1845
- cutlass.Float16, # this is just to get the right source layout
1555
+ Float16, # this is just to get the right source layout
1846
1556
  )
1847
1557
  tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
1848
1558
  return tiled_copy_C_atom
@@ -1852,8 +1562,7 @@ class GemmSm90:
1852
1562
  tiled_mma: cute.TiledMma,
1853
1563
  d_layout: Optional[LayoutEnum],
1854
1564
  dtype: Type[cutlass.Numeric],
1855
- acc: cute.Tensor,
1856
- sD: cute.Tensor,
1565
+ sD: Optional[cute.Tensor],
1857
1566
  tidx: Int32,
1858
1567
  ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1859
1568
  if d_layout is None:
@@ -1868,12 +1577,10 @@ class GemmSm90:
         # (R2S, R2S_M, R2S_N, PIPE_D)
         thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
         tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
-        # (R2S, R2S_M, R2S_N)
-        tRS_rAcc = tiled_copy_r2s.retile(acc)
         sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
         tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
         tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
-        return tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD
+        return tiled_copy_r2s, tRS_rD, tRS_sD
 
     def epilog_smem_load_and_partition(
         self,
@@ -1885,7 +1592,7 @@ class GemmSm90:
         tidx: Int32,
     ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
         tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
-        copy_atom_s2r = utils.sm90_get_smem_load_op(c_layout, dtype)
+        copy_atom_s2r = copy_utils.sm90_get_smem_load_op(c_layout, dtype)
         tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
         thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
         tSR_sC = thr_copy_s2r.partition_S(sC)
@@ -1896,29 +1603,30 @@ class GemmSm90:
     def epilog_gmem_copy_and_partition(
         self,
         atom: Union[cute.CopyAtom, cute.TiledCopy],
-        mD_mnl: cute.Tensor,
+        mD_mn: cute.Tensor,
         tile_shape_mn: cute.Tile,
         epi_tile: cute.Tile,
         sD: cute.Tensor,
         tile_coord_mnkl: cute.Coord,
-        cu_seqlens_m: Optional[cute.Tensor] = None,
+        tma_desc_ptr: Optional[cute.Pointer] = None,
     ) -> Tuple[cute.Tensor, cute.Tensor]:
-        batch_idx = tile_coord_mnkl[3]
-        if const_expr(cu_seqlens_m is not None):
-            mD_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl)
-        else:
-            mD_mn = mD_mnl[None, None, batch_idx]
         # (bM, bN)
         gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
         tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
-        bSG_sD, bSG_gD = cpasync.tma_partition(
+        is_s2g = isinstance(
+            atom.op, (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp)
+        )
+        src_tensor, dst_tensor = (
+            (sD, tDgD_for_tma_partition) if is_s2g else (tDgD_for_tma_partition, sD)
+        )
+        return copy_utils.tma_get_copy_fn(
             atom,
-            0,
-            cute.make_layout(1),
-            cute.group_modes(sD, 0, 2),
-            tDgD_for_tma_partition,
+            cta_coord=0,
+            cta_layout=cute.make_layout(1),
+            src_tensor=src_tensor,
+            dst_tensor=dst_tensor,
+            tma_desc_ptr=tma_desc_ptr,
         )
-        return bSG_sD, bSG_gD
 
     def make_ab_pipeline(
         self,
@@ -1927,21 +1635,15 @@ class GemmSm90:
         ab_pipeline_mbar_ptr: cute.Pointer,
     ):
         # Threads/warps participating in this pipeline
-        producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
+        producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32
         ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
         # Each warp will contribute to the arrive count with the number of mcast size
         mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
-        consumer_arrive_cnt = mcast_size
-        if const_expr(self.arch != 100):
-            consumer_arrive_cnt *= tiled_mma.size // cute.arch.WARP_SIZE
+        consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE
         ab_pipeline_consumer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread, consumer_arrive_cnt
         )
-        if const_expr(self.arch != 100):
-            pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
-        else:
-            # TODO: we need a pipeline class for TMACpAsyncUMMA
-            pipeline_cls = pipeline.PipelineTmaUmma if not self.gather_A else PipelineTmaCpAsync
+        pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
         return pipeline_cls.create(
             barrier_storage=ab_pipeline_mbar_ptr,
             num_stages=self.ab_stage,
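For orientation, the arrive-count bookkeeping in the rewritten make_ab_pipeline above can be sanity-checked with plain arithmetic. The cluster split, load-warp count, and warpgroup count below are assumed example values, not taken from this diff:

    # Hedged sketch of the producer/consumer counts computed in make_ab_pipeline above.
    # All concrete values here are assumptions chosen only to make the arithmetic visible.
    WARP_SIZE = 32

    gather_A = True
    num_ab_load_warps = 4                       # assumed cp.async load warps when gather_A
    producer_cnt = 1 if not gather_A else 1 + num_ab_load_warps * 32   # TMA thread + load threads

    num_mcast_ctas_a, num_mcast_ctas_b = 2, 1   # assumed multicast extents for a 2x1 cluster
    mcast_size = num_mcast_ctas_a + num_mcast_ctas_b - 1               # = 2
    tiled_mma_size = 2 * 128                    # assumed two consumer warpgroups (256 threads)
    consumer_arrive_cnt = mcast_size * tiled_mma_size // WARP_SIZE     # = 16 warp-level arrivals

    print(producer_cnt, consumer_arrive_cnt)    # 129 16
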
@@ -1973,9 +1675,7 @@ class GemmSm90:
     def make_epi_store_pipeline(self):
         # Threads/warps participating in tma store pipeline
         num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
-        epi_store_producer_group = pipeline.CooperativeGroup(
-            pipeline.Agent.Thread, num_epi_threads, num_epi_threads
-        )
+        epi_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_epi_threads)
         return pipeline.PipelineTmaStore.create(
             num_stages=self.epi_stage, producer_group=epi_store_producer_group
         )
@@ -2182,36 +1882,18 @@ class GemmSm90:
             order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
         )
 
+        epi_smem_layout_staged = None
         if d_dtype is not None:
-            d_smem_shape = epi_tile
-            d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
-            d_smem_layout_atom = warpgroup.make_smem_layout_atom(
-                sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
-                d_dtype,
-            )
-            epi_smem_layout_staged = cute.tile_to_shape(
-                d_smem_layout_atom,
-                cute.append(d_smem_shape, epi_stage),
-                order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
+            epi_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
+                d_dtype, d_layout, epi_tile, epi_stage
             )
-        else:
-            epi_smem_layout_staged = None
 
+        epi_c_smem_layout_staged = None
         if c_dtype is not None:
             assert c_layout is not None
-            c_smem_shape = epi_tile
-            c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0]
-            c_smem_layout_atom = warpgroup.make_smem_layout_atom(
-                sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
-                c_dtype,
-            )
-            epi_c_smem_layout_staged = cute.tile_to_shape(
-                c_smem_layout_atom,
-                cute.append(c_smem_shape, epi_c_stage),
-                order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
+            epi_c_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
+                c_dtype, c_layout, epi_tile, epi_c_stage
             )
-        else:
-            epi_c_smem_layout_staged = None
 
         return (
             a_smem_layout_staged,
@@ -2349,7 +2031,7 @@ class GemmSm90:
         """
         is_valid = True
         if a_dtype not in {
-            cutlass.Float16,
+            Float16,
             cutlass.BFloat16,
             cutlass.Float8E4M3FN,
             cutlass.Float8E5M2,
@@ -2357,19 +2039,19 @@ class GemmSm90:
             is_valid = False
         # tested b_dtype
         if b_dtype not in {
-            cutlass.Float16,
+            Float16,
             cutlass.BFloat16,
             cutlass.Float8E4M3FN,
             cutlass.Float8E5M2,
         }:
             is_valid = False
-        if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
+        if acc_dtype not in {Float32, Float16}:
             is_valid = False
         # tested d_dtype
         if d_dtype not in {
             None,
-            cutlass.Float32,
-            cutlass.Float16,
+            Float32,
+            Float16,
             cutlass.BFloat16,
             cutlass.Float8E4M3FN,
             cutlass.Float8E5M2,
@@ -2386,155 +2068,3 @@ class GemmSm90:
     if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
         is_valid = False
     return is_valid
-
-
-def gemm_sm90(
-    # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k
-    A: Tensor,
-    B: Tensor,  # (l, n, k) or (n, total_k) if varlen_k
-    D: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
-    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
-    tile_count_semaphore: Optional[Tensor],  # (1,)
-    tile_M: int,
-    tile_N: int,
-    cluster_M: int,
-    cluster_N: int,
-    pingpong: bool = False,
-    persistent: bool = True,
-    alpha: float | Tensor = 1.0,
-    beta: float | Tensor = 1.0,
-    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
-    cu_seqlens_k: Optional[Tensor] = None,  # (l+1,) cumulative sum of k values for variable length
-    A_idx: Optional[Tensor] = None,  # (total_m,) or (total_k,) indices for gather_A when varlen
-    batch_idx_permute: Optional[Tensor] = None,  # (l,) permutation of batch indices for scheduler
-    add_to_output: bool = False,
-) -> None:
-    varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
-    assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
-        "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
-    )
-    gather_A = A_idx is not None
-    if gather_A:
-        assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
-        assert cluster_N == 1, "gather_A requires cluster_N=1"
-    if varlen:
-        assert persistent, "varlen requires persistent=True"
-    if add_to_output:
-        assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
-    if cu_seqlens_m is not None:
-        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
-        assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
-    if cu_seqlens_k is not None:
-        assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
-        assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
-
-    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
-        A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
-    )
-    GemmWrapperBase.permute_tensors(
-        tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
-    )
-    GemmWrapperBase.extract_dtypes(tensor_infos)
-    major_configs = {
-        "A": ("m", "k", "l"),
-        "B": ("n", "k", "l"),
-        "D": ("m", "n", "l"),
-        "C": ("m", "n", "l"),
-    }
-    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
-
-    acc_dtype = cutlass.Float32
-    tile_shape_mn = (tile_M, tile_N)
-    cluster_shape_mnk = (cluster_M, cluster_N, 1)
-    if not GemmSm90.is_valid_dtypes(
-        tensor_infos["A"].dtype,
-        tensor_infos["B"].dtype,
-        acc_dtype,
-        tensor_infos["D"].dtype,
-        tensor_infos["A"].major,
-        tensor_infos["B"].major,
-    ):
-        raise TypeError("Skipping due to unsupported combination of types and majors")
-
-    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
-    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
-
-    def scalar_arg(scalar: float | Tensor):
-        if isinstance(scalar, float):
-            return Float32(scalar) if scalar != 1.0 else None
-        else:
-            assert isinstance(scalar, Tensor)
-            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
-
-    epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta), add_to_output)
-    scheduler_args = GemmWrapperBase.create_scheduler_args(
-        max_active_clusters,
-        tile_count_semaphore,
-        batch_idx_permute,
-    )
-
-    # Create varlen arguments if needed (assumes persistent=True when varlen)
-    varlen_args = GemmWrapperBase.create_varlen_args(
-        cu_seqlens_m,
-        cu_seqlens_k,
-        A_idx,
-        max_active_clusters,
-        cluster_shape_mnk,
-        tensor_infos,
-        GemmSm90.num_epi_tensormaps,
-        pingpong,
-    )
-
-    current_stream = cutlass_torch.current_stream()
-    compile_key = GemmWrapperBase.get_compile_key(
-        tensor_infos,
-        None,
-        tile_shape_mn,
-        cluster_shape_mnk,
-        pingpong,
-        persistent,
-        tile_count_semaphore is not None,
-        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
-        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
-        add_to_output,
-        cu_seqlens_m is not None,
-        cu_seqlens_k is not None,
-        gather_A,
-        batch_idx_permute is not None,
-        key_tensor_names=("A", "B", "D", "C"),
-    )
-    cache = gemm_sm90.compile_cache
-    if compile_key not in cache:
-        gemm = GemmSm90(
-            acc_dtype,
-            tensor_infos["A"].dtype,
-            tile_shape_mn,
-            cluster_shape_mnk,
-            pingpong=pingpong,
-            is_persistent=persistent,
-            gather_A=gather_A,
-        )
-        cache[compile_key] = cute.compile(
-            gemm,
-            tensor_infos["A"].cute_tensor,
-            tensor_infos["B"].cute_tensor,
-            tensor_infos["D"].cute_tensor,
-            tensor_infos["C"].cute_tensor,
-            epi_args,
-            scheduler_args,
-            varlen_args,
-            current_stream,
-        )
-    cache[compile_key](
-        tensor_infos["A"].cute_tensor,
-        tensor_infos["B"].cute_tensor,
-        tensor_infos["D"].cute_tensor,
-        tensor_infos["C"].cute_tensor,
-        epi_args,
-        scheduler_args,
-        varlen_args,
-        current_stream,
-    )
-
-
-gemm_sm90.compile_cache = {}
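The is_valid_dtypes checks above can be exercised directly, as the removed gemm_sm90 wrapper did before constructing the kernel. The import path and the specific dtype/major combinations below are assumptions for illustration:

    # Hedged sketch: probing GemmSm90.is_valid_dtypes with sample arguments
    # (a_dtype, b_dtype, acc_dtype, d_dtype, a_major, b_major), as the removed wrapper called it.
    import cutlass
    from cutlass import Float16, Float32

    from quack.gemm_sm90 import GemmSm90  # assumed module path

    # fp16 inputs, fp32 accumulation, bf16 output, both operands k-major:
    # expected to pass the dtype checks shown above.
    print(GemmSm90.is_valid_dtypes(Float16, Float16, Float32, cutlass.BFloat16, "k", "k"))

    # An 8-bit (fp8) A operand that is not k-major is rejected by the width == 8 check above.
    print(GemmSm90.is_valid_dtypes(
        cutlass.Float8E4M3FN, cutlass.Float8E4M3FN, Float32, cutlass.BFloat16, "m", "k"
    ))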