quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/{dense_gemm_sm90.py → gemm_sm90.py}

@@ -3,11 +3,9 @@
 
 import enum
 from typing import Tuple, Type, Callable, Optional, Union, Literal
-from dataclasses import dataclass
 from functools import partial
 import math
 
-from torch import Tensor
 
 import cuda.bindings.driver as cuda
 
@@ -16,10 +14,9 @@ import cutlass.cute as cute
 import cutlass.pipeline as pipeline
 from cutlass.cute.nvgpu import cpasync, warp, warpgroup
 import cutlass.utils.hopper_helpers as sm90_utils
-from cutlass import Int32, Float32, Boolean, const_expr
+from cutlass import Int32, Float32, Float16, Boolean, const_expr
+from cutlass.cutlass_dsl import if_generate
 from cutlass.utils import LayoutEnum
-import cutlass.torch as cutlass_torch
-from cutlass.cute.runtime import make_ptr
 
 
 from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
@@ -30,14 +27,12 @@ from quack.tile_scheduler import (
     VarlenMTileSchedulerArguments,
     VarlenMTileScheduler,
 )
-from quack.varlen_utils import VarlenArguments
-from quack.tensormap_manager import TensorMapManagerSm90
+from quack.varlen_utils import VarlenArguments, VarlenManager
 
 # return PipelineStateWAdvance instead of PipelineState
 from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
-import quack.
-
-from quack.gemm_wrapper_utils import GemmWrapperBase
+import quack.copy_utils as copy_utils
+import quack.sm90_utils as quack_sm90_utils
 
 """
 A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -119,7 +114,7 @@ class GemmSm90:
 
     Example:
         >>> gemm = GemmSm90(
-        ...     acc_dtype=
+        ...     acc_dtype=Float32,
         ...     tile_shape_mn=(128, 256),
         ...     cluster_shape_mnk=(1, 1, 1)
         ... )
@@ -127,19 +122,10 @@ class GemmSm90:
     """
 
     arch = 90
-    bytes_per_tensormap = 128
     num_epi_tensormaps: int = 0
 
-    @dataclass
-    class EpilogueArguments(ArgumentsBase):
-        alpha: Optional[Float32 | cute.Tensor] = None
-        beta: Optional[Float32 | cute.Tensor] = None
-        add_to_output: bool = False
-
-    @dataclass
-    class EpilogueParams(ParamsBase):
-        alpha: Optional[Float32 | cute.Tensor] = None
-        beta: Optional[Float32 | cute.Tensor] = None
+    EpilogueArguments = ArgumentsBase
+    EpilogueParams = ParamsBase
 
     def __init__(
         self,
@@ -222,7 +208,9 @@ class GemmSm90:
             atom_layout_m, atom_layout_n = 1, 1
         self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
 
-        self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
+        self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
+        if self.gather_A:
+            assert self.num_mcast_ctas_a == 1
         self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
         self.is_a_mcast = self.num_mcast_ctas_a > 1
         self.is_b_mcast = self.num_mcast_ctas_b > 1
@@ -237,10 +225,9 @@ class GemmSm90:
         self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
         self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
         self.num_ab_load_warps = 1 if not self.gather_A else 4
-        self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
-        self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
         self.ab_load_warp_id = self.mma_warp_groups * 4
-        self.
+        # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
+        # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
 
         regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
             math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
@@ -335,7 +322,7 @@ class GemmSm90:
             self.d_dtype,
             self.c_dtype,
             epilogue_args,
-            self.smem_capacity
+            cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"),  # smem_capacity
             self.occupancy,
             # epi_smem will reuse smem ab if not persistent.
             overlap_sD_sA=not self.is_persistent,
@@ -465,6 +452,7 @@ class GemmSm90:
         )
 
         epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
+        varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
 
         TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
         tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
@@ -483,7 +471,7 @@ class GemmSm90:
             ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
             epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
             sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
-            tile_count: cute.struct.MemRange[
+            tile_count: cute.struct.MemRange[Int32, self.sched_stage]
             sD: cute.struct.Align[
                 cute.struct.MemRange[
                     self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
@@ -520,10 +508,7 @@ class GemmSm90:
             tma_atom_c,
             tma_tensor_c,
             epilogue_params,
-            varlen_args.mCuSeqlensM,
-            varlen_args.mCuSeqlensK,
-            varlen_args.mTensormaps,
-            varlen_args.mAIdx,
+            varlen_params,
             self.cluster_layout_mnk,
             self.a_smem_layout_staged,
             self.b_smem_layout_staged,
@@ -535,7 +520,6 @@ class GemmSm90:
             grid=grid,
             block=[self.threads_per_cta, 1, 1],
             cluster=self.cluster_shape_mnk,
-            smem=self.shared_storage.size_in_bytes(),
             stream=stream,
             min_blocks_per_mp=1,
         )
@@ -555,10 +539,7 @@
         tma_atom_c: Optional[cute.CopyAtom],
         mC_mnl: Optional[cute.Tensor],
         epilogue_params: ParamsBase,
-        cu_seqlens_m: Optional[cute.Tensor],
-        cu_seqlens_k: Optional[cute.Tensor],
-        tensormaps: Optional[cute.Tensor],
-        mAIdx: Optional[cute.Tensor],
+        varlen_params: VarlenManager.Params,
         cluster_layout_mnk: cute.Layout,
         a_smem_layout: cute.ComposedLayout,
         b_smem_layout: cute.ComposedLayout,
@@ -594,8 +575,8 @@ class GemmSm90:
         :type epi_smem_layout: cute.ComposedLayout
         """
 
-        varlen_m = const_expr(cu_seqlens_m is not None)
-        varlen_k = const_expr(cu_seqlens_k is not None)
+        varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
+        varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
         assert not (varlen_m and varlen_k)
         if const_expr(self.gather_A):
             assert varlen_m or varlen_k
@@ -657,9 +638,19 @@
         sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
         epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
 
-
-
-
+        varlen_manager = VarlenManager.create(
+            varlen_params,
+            has_D,
+            self.num_epi_tensormaps,
+            # Only used if not varlen_m
+            len_m_static=Int32(
+                mA_mkl.shape[0]
+                if varlen_k or varlen_params.mAIdx is None
+                else varlen_params.mAIdx.shape[0]
+            ),
+            len_k_static=Int32(mA_mkl.shape[1]),
+            pingpong=self.pingpong,
+            warp_idx=warp_idx,
         )
 
         TileSchedulerCls = partial(
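Note (illustrative, not part of the diff): the hunk above replaces loose tensormap locals with one VarlenManager created per kernel launch. A minimal plain-Python sketch of that consolidation pattern — per-batch extents derived from cumulative sequence lengths with a static fallback, mirroring the len_m/len_k calls that appear later in this diff; all names here are hypothetical stand-ins, not the real quack API:

from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class VarlenStateSketch:
    # Hypothetical stand-in for VarlenManager: one object owns the
    # variable-length bookkeeping instead of loose per-tensor locals.
    cu_seqlens_m: Optional[Sequence[int]]  # prefix sums of per-batch M extents
    cu_seqlens_k: Optional[Sequence[int]]  # prefix sums of per-batch K extents
    len_m_static: int
    len_k_static: int

    def len_m(self, batch_idx: int) -> int:
        if self.cu_seqlens_m is not None:
            return self.cu_seqlens_m[batch_idx + 1] - self.cu_seqlens_m[batch_idx]
        return self.len_m_static

    def len_k(self, batch_idx: int) -> int:
        if self.cu_seqlens_k is not None:
            return self.cu_seqlens_k[batch_idx + 1] - self.cu_seqlens_k[batch_idx]
        return self.len_k_static

mgr = VarlenStateSketch(cu_seqlens_m=[0, 3, 10], cu_seqlens_k=None,
                        len_m_static=10, len_k_static=64)
assert mgr.len_m(1) == 7 and mgr.len_k(1) == 64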
@@ -673,29 +664,20 @@
             and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
         ):
             is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
-
-
-
-
-                tma_atom_a,
-                tensormap_ab_ptrs[0],
-                is_tma_warp,
-            )
-            tensormap_manager.init_tensormap_from_atom(
-                tma_atom_b,
-                tensormap_ab_ptrs[1],
-                is_tma_warp,
-            )
+            # initialize tensormap for A & B
+            varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
+            tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
+            tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
             # ///////////////////////////////////////////////////////////////////////////////
             # Get mcast mask
             # ///////////////////////////////////////////////////////////////////////////////
             cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
-
+            block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
             a_mcast_mask = cute.make_layout_image_mask(
-                cluster_layout_mnk,
+                cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1
             )
             b_mcast_mask = cute.make_layout_image_mask(
-                cluster_layout_mnk,
+                cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0
             )
             a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
             b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
@@ -711,42 +693,30 @@
         )
         if const_expr(varlen_k):
             # wait tensormap initialization complete before update
-
-            # batch index of last tile
-            last_batch_idx = cutlass.Int32(-1)
+            varlen_manager.fence_tensormap_init()
         while work_tile.is_valid_tile:
            tile_coord_mnkl = work_tile.tile_idx
            batch_idx = tile_coord_mnkl[3]
-
-
-
-
-
-
-                    tensormap_ab_ptrs,
-                    cu_seqlens_k,
-                    batch_idx,
-                    is_tma_warp,
-                )
+            varlen_manager.update_tensormap_AB(
+                batch_idx,
+                self.a_layout,
+                self.b_layout,
+                is_tma_warp,
+            )
             # ///////////////////////////////////////////////////////////////////////////
             # Local_tile partition global tensors
             # ///////////////////////////////////////////////////////////////////////////
             if const_expr(not self.gather_A):
-
-                    mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
-                elif const_expr(varlen_k):
-                    mA_mk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mA_mkl)
-                else:
-                    mA_mk = mA_mkl[None, None, batch_idx]
+                mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
                 # (bM, bK, RestK)
-
+                gA_mk = cute.local_tile(
                     mA_mk,
                     cute.select(self.cta_tile_shape_mnk, [0, 2]),
                     (tile_coord_mnkl[0], None),
                 )
             else:
+                mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
                 if const_expr(varlen_m):
-                    mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
                     gAIdx = cute.local_tile(
                         mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
                     )
@@ -754,133 +724,90 @@
                     mA_mk = mA_mkl
                 else:
                     assert varlen_k
-                    mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
                     # (tile_K, RestK)
                     gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
                     # (tile_M, K)
                     mA_mk = cute.local_tile(
                         mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
                     )
-            if const_expr(varlen_k):
-                mB_nk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mB_nkl)
-            else:
-                mB_nk = mB_nkl[None, None, batch_idx]
             # (bN, bK, RestK)
-
-
+            gB_nk = cute.local_tile(
+                varlen_manager.offset_batch_B(mB_nkl, batch_idx),
                 cute.select(self.cta_tile_shape_mnk, [1, 2]),
                 (tile_coord_mnkl[1], None),
             )
             # //////////////////////////////////////////////////////////////////////////
             # Partition shared tensor for TMA load A/B
             # //////////////////////////////////////////////////////////////////////////
-
-
-
-            tensormap_a_ptr, tensormap_b_ptr = tensormap_ab_ptrs
-            if is_group_changed and is_tma_warp:
-                if const_expr(not self.gather_A):
-                    tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
-                tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
-            if const_expr(not self.gather_A):
-                tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
-                    tensormap_a_ptr, cute.AddressSpace.generic
-                )
-            tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
-                tensormap_b_ptr, cute.AddressSpace.generic
-            )
+            varlen_manager.fence_tensormap_update_AB(is_tma_warp)
+            len_m = varlen_manager.len_m(batch_idx)
+            len_k = varlen_manager.len_k(batch_idx)
             # TMA load A partition_S/D
-
-                cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
-            )
-            a_cta_crd = cluster_coord_mnk[1]
+            copy_A = None
             if const_expr(not self.gather_A):
-
-                tAsA, tAgA_k = cpasync.tma_partition(
-                    tma_atom_a,
-                    a_cta_crd,
-                    a_cta_layout,
-                    cute.group_modes(sA, 0, 2),
-                    cute.group_modes(gA_k, 0, 2),
-                )
-                copy_A = partial(
-                    cute.copy,
+                copy_A, _, _ = copy_utils.tma_get_copy_fn(
                     tma_atom_a,
+                    cta_coord=block_in_cluster_coord_mnk[1],
+                    cta_layout=cute.make_layout(
+                        cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
+                    ),
+                    src_tensor=gA_mk,
+                    dst_tensor=sA,
                     mcast_mask=a_mcast_mask,
                     tma_desc_ptr=tma_desc_a_ptr,
                 )
             else:
                 tiled_copy_A = self._make_gmem_tiled_copy_A(
-                    mA_mkl.element_type, self.a_layout, self.
+                    mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
                 )
                 tidx = (
-                    cute.arch.thread_idx()[0]
-                    - self.mma_warp_groups * self.num_threads_per_warp_group
+                    cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
                 )
                 thr_copy_A = tiled_copy_A.get_slice(tidx)
-
-
-
-
-
-
-
-
+                copy_A, prefetch_A = None, None
+                if const_expr(varlen_m):
+                    copy_A = copy_utils.gather_m_get_copy_fn(
+                        thr_copy_A,
+                        mA_mk,
+                        sA,
+                        gAIdx,
+                        limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+                        limit_k=len_k,
+                    )
+                else:
+                    copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
+                        thr_copy_A,
+                        mA_mk,
+                        sA,
+                        gAIdx,
+                        limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+                        limit_k=len_k,
+                    )
             # TMA load B partition_S/D
-
-                cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
-            )
-            b_cta_crd = cluster_coord_mnk[0]
-            # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
-            tBsB, tBgB_k = cpasync.tma_partition(
+            copy_B, _, _ = copy_utils.tma_get_copy_fn(
                 tma_atom_b,
-
-
-
-
-
-
-
+                cta_coord=block_in_cluster_coord_mnk[0],
+                cta_layout=cute.make_layout(
+                    cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
+                ),
+                src_tensor=gB_nk,
+                dst_tensor=sB,
+                mcast_mask=b_mcast_mask,
+                tma_desc_ptr=tma_desc_b_ptr,
             )
-
-                cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
-                if const_expr(varlen_k)
-                else Int32(mA_mkl.shape[1])
-            )
-            k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
+            k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
             if const_expr(not self.gather_A):
                 ab_producer_state = self.load_AB(
-                    ab_pipeline,
-                    ab_producer_state,
-                    copy_A,
-                    tAgA_k,
-                    tAsA,
-                    copy_B,
-                    tBgB_k,
-                    tBsB,
-                    k_tile_cnt,
+                    ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
                 )
             else:
-                limit_m = (
-                    Int32(mA_mkl.shape[0])
-                    if const_expr(cu_seqlens_m is None)
-                    else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
-                )
                 ab_producer_state = self.load_AB_gather_A(
                     ab_pipeline,
                     ab_producer_state,
-
-
-                    tAsA,
-                    gAIdx,
+                    copy_A,
+                    prefetch_A,
                     copy_B,
-                    tBgB_k,
-                    tBsB,
                     k_tile_cnt,
-                    limit_A=(
-                        limit_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
-                        k_len,
-                    ),
                     varlen_m=varlen_m,
                 )
             tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
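Note (illustrative, not part of the diff): in the hunk above, load_AB no longer receives partitioned tensor pairs (tAgA_k/tAsA, tBgB_k/tBsB); each copy is pre-bound into a callable later invoked as copy(k_tile, smem_idx, tma_bar_ptr=...). A rough plain-Python sketch of that closure pattern — the real copy_utils.tma_get_copy_fn also returns two extra values (discarded as `_` above), and everything here is a simplified stand-in:

def tma_get_copy_fn_sketch(src_tiles, dst_stages):
    # Hypothetical stand-in for copy_utils.tma_get_copy_fn: capture the
    # partitioned source/destination once, return a per-iteration callable.
    def copy(k_tile, smem_idx, tma_bar_ptr=None):
        # One "TMA" copy: global-memory tile k_tile -> smem stage smem_idx.
        dst_stages[smem_idx] = src_tiles[k_tile]
        return tma_bar_ptr
    return copy

gmem = ["tileA0", "tileA1", "tileA2"]  # stand-in for gA_mk slices
smem = [None, None]                    # stand-in for sA pipeline stages
copy_A = tma_get_copy_fn_sketch(gmem, smem)

for k_tile in range(3):
    copy_A(k_tile, k_tile % len(smem), tma_bar_ptr="mbar")
assert smem == ["tileA2", "tileA1"]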
@@ -900,22 +827,11 @@
             (not self.pingpong and warp_idx == 0)
             or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
         )
-
-
-
-
-
-                tensormap_d_ptr,
-                is_manager_warp=is_tma_warp,
-            )
-            for tma_atom, tensormap_epi_ptr in zip(
-                self.epi_get_tma_atoms(epilogue_params), tensormap_epi_ptrs
-            ):
-                tensormap_manager.init_tensormap_from_atom(
-                    tma_atom,
-                    tensormap_epi_ptr,
-                    is_manager_warp=is_tma_warp,
-                )
+            varlen_manager.init_tensormap_epi(
+                tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
+            )
+            tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
+            tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
         # //////////////////////////////////////////////////////////////////////////////
         # Partition global tensor for TiledMMA_A/B/C
         # //////////////////////////////////////////////////////////////////////////////
@@ -974,9 +890,8 @@
                 if const_expr(not varlen_k):
                     ab_read_state.advance_iters(k_tile_cnt_static)
                 else:
-
-
-                    k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
+                    len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
+                    k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                     ab_read_state.advance_iters(k_tile_cnt)
                     tile_scheduler.advance_to_next_work()
                     if const_expr(varlen_k):
@@ -987,32 +902,22 @@
             work_tile = tile_scheduler.initial_work_tile_info()
             if const_expr(varlen_m):
                 # wait tensormap initialization complete before update
-
-                # batch index of last tile
-                last_batch_idx = cutlass.Int32(-1)
+                varlen_manager.fence_tensormap_init()
             while work_tile.is_valid_tile:
                 tile_coord_mnkl = work_tile.tile_idx
                 batch_idx = tile_coord_mnkl[3]
-
-
-
-
-
-
-
-
-
-                        cu_seqlens_m,
-                        batch_idx,
-                        is_manager_warp=is_tma_warp,
-                    )
-
-                k_len = (
-                    cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
-                    if const_expr(varlen_k)
-                    else mA_mkl.shape[1]
+                epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
+                    epilogue_params, varlen_params.cu_seqlens_m, batch_idx
+                )
+                varlen_manager.update_tensormap_epi(
+                    batch_idx,
+                    self.d_layout,
+                    epi_shapes,
+                    epi_orders,
+                    is_tma_warp,
                 )
-
+                len_k = varlen_manager.len_k(batch_idx)
+                k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                 ab_read_state, tiled_mma = self.mma(
                     ab_pipeline,
                     ab_read_state,
@@ -1039,57 +944,38 @@
                 num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
             )
 
-
-            if const_expr(varlen_m):
-                # ensure the update to tensormap has completed before using it
-                if is_group_changed and is_tma_warp:
-                    if const_expr(has_D):
-                        tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
-                    for tensormap_epi_ptr in tensormap_epi_ptrs:
-                        tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
-                if const_expr(has_D):
-                    tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
-                        tensormap_d_ptr, cute.AddressSpace.generic
-                    )
-                tma_desc_epi_ptrs = [
-                    tensormap_manager.get_tensormap_ptr(
-                        tensormap_epi_ptr, cute.AddressSpace.generic
-                    )
-                    for tensormap_epi_ptr in tensormap_epi_ptrs
-                ]
+            varlen_manager.fence_tensormap_update_epi(is_tma_warp)
 
+            copy_D = None
             if const_expr(has_D):
-
+                copy_D, _, _ = self.epilog_gmem_copy_and_partition(
                     tma_atom_d,
-                    mD_mnl,
+                    varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
                     self.cta_tile_shape_mnk[:2],
                     self.epi_tile,
                     sD,
                     tile_coord_mnkl,
-                    cu_seqlens_m,
+                    tma_desc_ptr=tma_desc_d_ptr,
                 )
-
-            else:
-                bSG_sD, bSG_gD, copy_D = None, None, None
+            copy_C = None
             if const_expr(has_C):
-
+                copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
                     tma_atom_c,
-                    mC_mnl,
+                    varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
                     self.cta_tile_shape_mnk[:2],
                     self.epi_tile,
                     sC,
                     tile_coord_mnkl,
-                    cu_seqlens_m,
                 )
-                copy_C =
-                epi_load_g2s = partial(self.epi_load_g2s, epi_pipeline, copy_C, bGS_gC, bGS_sC)
-            else:
-                epi_load_g2s = None
+                copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
 
             d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
-            tiled_copy_r2s,
-                tiled_mma, self.d_layout, d_dtype_for_layout,
+            tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
+                tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
             )
+            # (R2S, R2S_M, R2S_N)
+            tRS_rAcc = tiled_copy_r2s.retile(acc)
+            load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
             if const_expr(has_C):
                 tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
                     tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
@@ -1112,21 +998,20 @@
                 epi_store_pipeline,
                 epi_read_state,
                 epi_producer_state,
-
-
+                self.epi_tile,
+                load_acc_subtile,
                 tRS_rD,
                 tRS_rC,
+                None,  # tiled_copy_t2r, for Sm100 only
                 tiled_copy_r2s,
                 tRS_sD,
                 tiled_copy_s2r,
                 tSR_rC,
                 tSR_sC,
                 copy_D,
-
-                bSG_gD,
-                epi_load_g2s,
+                copy_C,
                 tile_coord_mnkl,
-
+                varlen_manager,
                 epilogue_barrier,
                 tile_scheduler,
                 tidx,
@@ -1157,9 +1042,8 @@
                     tile_scheduler.advance_to_next_work()
                     work_tile = tile_scheduler.get_current_work()
                     if work_tile.is_valid_tile:
-
-
-                        k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
+                        len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
+                        k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                         ab_read_state.advance_iters(k_tile_cnt)
                         tile_scheduler.advance_to_next_work()
                         work_tile = tile_scheduler.get_current_work()
@@ -1175,14 +1059,16 @@
         self,
         ab_pipeline: cutlass.pipeline.PipelineAsync,
         ab_producer_state: cutlass.pipeline.PipelineState,
-        copy_A: Callable,
-        tAgA: cute.Tensor,
-        tAsA: cute.Tensor,
+        copy_A: Optional[Callable],
        copy_B: Callable,
-        tBgB: cute.Tensor,
-        tBsB: cute.Tensor,
        k_tile_cnt: Int32,
+        # These are for Sm100 blockscaled gemm
+        copy_SFA: Optional[Callable] = None,
+        copy_SFB: Optional[Callable] = None,
     ) -> cutlass.pipeline.PipelineState:
+        blockscaled = const_expr(copy_SFA is not None)
+        if const_expr(blockscaled):
+            assert copy_SFB is not None
         # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
         peek_ab_empty_status = Boolean(True)
         if 0 < k_tile_cnt:
@@ -1195,8 +1081,13 @@
             # Also sets the transaction barrier for the A/B buffers
             ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
             tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
-
-
+            smem_idx = ab_producer_state.index
+            if const_expr(copy_A is not None):
+                copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+            copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+            if const_expr(blockscaled):
+                copy_SFA(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+                copy_SFB(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
             # Mainloop pipeline's producer commit is a NOP
             ab_pipeline.producer_commit(ab_producer_state)
             ab_producer_state.advance()
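Note (illustrative, not part of the diff): the optional copy_SFA/copy_SFB hooks above let the same producer loop serve the Sm100 blockscaled path without duplicating it. A minimal plain-Python sketch of that guard-both-or-neither pattern — the DSL's compile-time const_expr branching is modeled here with ordinary ifs:

def load_ab_sketch(copy_A, copy_B, k_tile_cnt, copy_SFA=None, copy_SFB=None):
    # Scale-factor copies come as a pair: if SFA is supplied, SFB must be
    # too (mirrors the blockscaled assert in the hunk above).
    blockscaled = copy_SFA is not None
    if blockscaled:
        assert copy_SFB is not None
    for k_tile in range(k_tile_cnt):
        if copy_A is not None:  # copy_A may be None on the gather-A path
            copy_A(k_tile)
        copy_B(k_tile)
        if blockscaled:
            copy_SFA(k_tile)
            copy_SFB(k_tile)

calls = []
load_ab_sketch(lambda k: calls.append(("A", k)), lambda k: calls.append(("B", k)), 2)
assert calls == [("A", 0), ("B", 0), ("A", 1), ("B", 1)]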
@@ -1210,58 +1101,12 @@
         self,
         ab_pipeline: cutlass.pipeline.PipelineAsync,
         ab_producer_state: cutlass.pipeline.PipelineState,
-
-
-        tAsA: cute.Tensor,
-        gAIdx: cute.Tensor,  # (tile_M,) if varlen_m, (tile_K, RestK) if varlen_k
+        copy_A: Callable,
+        prefetch_A: Optional[Callable],
         copy_B: Callable,
-        tBgB: cute.Tensor,
-        tBsB: cute.Tensor,
         k_tile_cnt: Int32,
-
-        varlen_m: bool,
+        varlen_m: bool = True,
     ) -> cutlass.pipeline.PipelineState:
-        limit_m, limit_k = limit_A
-        # Do we need to check if we overshoot tile_M when we load A?
-        is_even_m_smem = self.cta_tile_shape_mnk[0] % thr_copy_A.tiler_mn[0].shape == 0
-        if const_expr(not is_even_m_smem):
-            limit_m = min(limit_m, self.cta_tile_shape_mnk[0])
-        elems_per_load = cute.size(tAsA.shape[0][0])
-        cA = cute.make_identity_tensor(cute.select(self.cta_tile_shape_mnk, [0, 2]))
-        tAcA = thr_copy_A.partition_S(cA)
-        t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
-        # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
-        # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
-        # This is so that when we do the comparison, t0AcA is known at compile time.
-        limit_m = limit_m - tAcA[0][0]
-        limit_k = limit_k - tAcA[0][1]
-        # Read indices for A
-        rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
-        cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
-        tApA_m = cute.make_fragment(rows_per_thread, Boolean)
-        for m in cutlass.range_constexpr(rows_per_thread):
-            tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
-        m_idx, k_idx, tAmA = None, None, None
-        if const_expr(varlen_m):
-            m_idx = cute.make_fragment(rows_per_thread, Int32)
-            for m in cutlass.range(rows_per_thread):
-                row_idx = tAcA[0, m, 0][0]
-                if tApA_m[m]:
-                    m_idx[m] = gAIdx[row_idx]
-                else:
-                    m_idx[m] = 0  # It's ok to load row 0 in the case of OOB
-        else:
-            k_idx = cute.make_fragment(cols_per_thread, Int32)  # Will be read later
-            threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
-            # This is very convoluted but idk a better way
-            # for tile_M=128, flat_divide gives (8, 16, K),
-            # then logical_divide gives ((8, 1), (8, 2), K).
-            tidx = thr_copy_A.thr_idx
-            tAmA = cute.logical_divide(
-                cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
-            )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)
-        # (m, (bK, RestK))
-        mA_k = cute.logical_divide(mA, (None, self.cta_tile_shape_mnk[2]))
         warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
         # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
         peek_ab_empty_status = Boolean(True)
@@ -1270,59 +1115,27 @@
         # /////////////////////////////////////////////////////////////////////////
         # TMA load on B and cp.async on A
         # /////////////////////////////////////////////////////////////////////////
-        copy_A = partial(cute.copy, thr_copy_A)
         for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
-
-
-
-                    col_idx = tAcA[0, 0, k][1]
-                    k_idx[k] = gAIdx_cur[col_idx]
+            prefetch_out = ()
+            if const_expr(prefetch_A is not None):  # Prefetch early, even before smem is free
+                prefetch_out = (prefetch_A(k_tile),)
             # Wait for A/B buffers to be empty before loading into them
             # Also sets the transaction barrier for the A/B buffers
             # A tiny bit faster to rotate the warp that does TMA
             # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id
             # since that's the warp that does the tensormap update.
-
+            is_tma_warp = warp_idx == self.ab_load_warp_id + (
                 (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
             )
-            ab_pipeline.producer_acquire(
-
-
-
-
-
-
-            copy_B(
-                tBgB[None, k_tile],
-                tBsB[None, ab_producer_state.index],
-                tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
-            )
-            # (m, bK)
-            if const_expr(varlen_m):
-                mA_cur = mA_k[None, (None, k_tile)]
-                for m in cutlass.range_constexpr(tAcA.shape[1]):
-                    # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
-                    # ((elems_per_load), thread_per_row)
-                    # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
-                    # So we append 1s to the last dimension and then do tiled_divide, then slice.
-                    mA_row = cute.tiled_divide(
-                        cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
-                    )[None, None, 0]
-                    if const_expr(is_even_m_smem) or tApA_m[m]:
-                        # There's only 1 load per row
-                        assert cute.size(tAcA.shape, mode=[2]) == 1
-                        ki = tAcA[0, 0, 0][1] // elems_per_load
-                        copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
-            else:
-                for k in cutlass.range_constexpr(tAcA.shape[2]):
-                    # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), ab_producer_state.index], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
-                    for m in cutlass.range_constexpr(tAcA.shape[1]):
-                        if tApA_m[m]:
-                            copy_A(
-                                tAmA[None, m, k_idx[k]], tAsA[(None, m, k), ab_producer_state.index]
-                            )
+            ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+            smem_idx = ab_producer_state.index
+            # A bit faster to load B first while we calculate the indices for A
+            if is_tma_warp:
+                tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
+                copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+            copy_A(k_tile, smem_idx, *prefetch_out)
             # This tells mbarrier to track the completion of cp.async
-            ab_pipeline.
+            ab_pipeline.producer_cpasync_commit(ab_producer_state)
             ab_producer_state.advance()
             peek_ab_empty_status = Boolean(True)
             if k_tile + 1 < k_tile_cnt:
@@ -1330,58 +1143,19 @@
         # bound checking in the K dimension on the last k_tile
         if 0 < k_tile_cnt:
             k_tile = k_tile_cnt - 1
-
-
-
-
-            if const_expr(not varlen_m):
-                gAIdx_cur = gAIdx[None, k_tile]
-                for k in cutlass.range(cols_per_thread):
-                    col_idx = tAcA[0, 0, k][1]
-                    if tApA_k[k]:
-                        k_idx[k] = gAIdx_cur[col_idx]
-                    else:
-                        k_idx[k] = -1
-            tma_warp_id = self.ab_load_warp_id + (
+            prefetch_out = ()
+            if const_expr(prefetch_A is not None):  # Prefetch early, even before smem is free
+                prefetch_out = (prefetch_A(k_tile, pred=True),)
+            is_tma_warp = warp_idx == self.ab_load_warp_id + (
                 (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
             )
-            ab_pipeline.producer_acquire(
-
-
-
-
-
-
-                tBgB[None, k_tile],
-                tBsB[None, ab_producer_state.index],
-                tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
-            )
-            if const_expr(varlen_m):
-                # (m, bK)
-                mA_cur = mA_k[None, (None, k_tile)]
-                for m in cutlass.range_constexpr(tAcA.shape[1]):
-                    # ((elems_per_load, 1), thread_per_row)
-                    mA_row = cute.tiled_divide(
-                        cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
-                    )[None, None, 0]
-                    if const_expr(is_even_m_smem) or tApA_m[k]:
-                        # There's only 1 load per row
-                        assert cute.size(tAcA.shape, mode=[2]) == 1
-                        ki = tAcA[0, 0, 0][1] // elems_per_load
-                        copy_A(
-                            mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA_k
-                        )
-            else:
-                tApA_k = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
-                for k in cutlass.range_constexpr(tAcA.shape[2]):
-                    for m in cutlass.range_constexpr(tAcA.shape[1]):
-                        if tApA_m[m]:
-                            copy_A(
-                                tAmA[None, m, k_idx[k]],
-                                tAsA[(None, m, k), ab_producer_state.index],
-                                pred=tApA_k[None, k],
-                            )
-            ab_pipeline.producer_commit(ab_producer_state)
+            ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+            smem_idx = ab_producer_state.index
+            if is_tma_warp:
+                tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
+                copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+            copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
+            ab_pipeline.producer_cpasync_commit(ab_producer_state)
             ab_producer_state.advance()
         return ab_producer_state
 
@@ -1481,22 +1255,21 @@
         epi_pipeline: cutlass.pipeline.PipelineAsync,
         epi_store_pipeline: cutlass.pipeline.PipelineAsync,
         epi_read_state: cutlass.pipeline.PipelineState,
-        epi_producer_state: cutlass.pipeline.PipelineState,
-
-
+        epi_producer_state: Optional[cutlass.pipeline.PipelineState],
+        epi_tile: cute.Tile,
+        load_acc_subtile: Callable,
         tRS_rD: cute.Tensor,
         tRS_rC: Optional[cute.Tensor],
-
+        tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
+        tiled_copy_r2s: cute.TiledCopy,
         tRS_sD: cute.Tensor,
-        tiled_copy_s2r: Optional[cute.
+        tiled_copy_s2r: Optional[cute.ThrCopy],
         tSR_rC: Optional[cute.Tensor],
         tSR_sC: Optional[cute.Tensor],
         copy_D: Optional[Callable],
-
-        bSG_gD: cute.Tensor,
-        epi_load_g2s: Optional[Callable],
+        copy_C: Optional[Callable],
         tile_coord_mnkl: cute.Coord,
-
+        varlen_manager: VarlenManager,
         epilogue_barrier: cutlass.pipeline.NamedBarrier,
         tile_scheduler,
         tidx: Int32,
@@ -1504,22 +1277,61 @@
     ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
         has_C = const_expr(tRS_rC is not None)
         has_D = const_expr(copy_D is not None)
-        # We iterate over epi tiles in the N dimension first before the M dimension
         epi_tile_shape = cute.zipped_divide(
-            cute.make_layout(self.cta_tile_shape_mnk[:2]),
+            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
         ).shape[1]
-
+        # We iterate over epi tiles in the N dimension first before the M dimension
+        epi_tile_layout = cute.make_ordered_layout(epi_tile_shape, order=(1, 0))
         epi_tile_num = cute.size(epi_tile_shape)
         num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
 
-
+        epi_tensors = self.epi_begin(
+            params,
+            epi_smem_tensors,
+            epi_tile,
+            tiled_copy_t2r,
+            tiled_copy_r2s,
+            tile_coord_mnkl,
+            varlen_manager,
+            epilogue_barrier,
+            tidx,
+        )
+
+        if const_expr(copy_C is not None):
            for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
-
+                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
+                if is_tma_warp:
+                    epi_pipeline.producer_acquire(epi_producer_state)
+                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                    epi_pipeline.producer_commit(epi_producer_state)
+                epi_producer_state.advance()
 
+        def tma_store_fn(src_idx, dst_idx):
+            # Fence and barrier to make sure shared memory store is visible to TMA store
+            cute.arch.fence_proxy(
+                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+            )
+            epilogue_barrier.arrive_and_wait()
+            # Copy from shared memory to global memory
+            if is_tma_warp:
+                if const_expr(has_D):
+                    copy_D(src_idx=src_idx, dst_idx=dst_idx)
+            # Can't use if statement here, epi_store_pipeline object isn't captured somehow
+            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
+            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
+            epilogue_barrier.arrive_and_wait()
+
+        # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops
+        # with the TMA store. However, currently this doesn't seem to improve perf.
+        delay_tma_store = False
+
+        src_idx_prev, dst_idx_prev = None, None
         for epi_idx in cutlass.range_constexpr(epi_tile_num):
+            # The global memory coordinate for the current epi tile
+            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
             # Copy from acc to D registers
-
-
+            load_acc_subtile(tRS_rD, epi_idx)
+            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
             if const_expr(has_C):
                 epi_pipeline.consumer_wait(epi_read_state)
                 cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
@@ -1531,190 +1343,40 @@
             with cute.arch.elect_one():
                 epi_pipeline.consumer_release(epi_read_state)
             epi_read_state.advance()
-            if const_expr(
-
-
-
-
+            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
+                if is_tma_warp:
+                    epi_pipeline.producer_acquire(epi_producer_state)
+                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                    epi_pipeline.producer_commit(epi_producer_state)
+                epi_producer_state.advance()
+            tRS_rEpi = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
             epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+            if const_expr(delay_tma_store):
+                if const_expr(epi_idx > 0):
+                    tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
+                src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
             # Copy from D registers to shared memory
             if const_expr(has_D):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                epi_store_pipeline.producer_acquire()
-            epilogue_barrier.arrive_and_wait()
-
-        return epi_read_state, epi_producer_state
-
-    @cute.jit
-    def epi_load_g2s(
-        self,
-        epi_pipeline: cutlass.pipeline.PipelineAsync,
-        copy_C: Callable,
-        bGS_gC: cute.Tensor,
-        bGS_sC: cute.Tensor,
-        epi_producer_state: cutlass.pipeline.PipelineState,
-        epi_idx: Int32,
-        should_load: Boolean,
-    ) -> cutlass.pipeline.PipelineState:
-        # We iterate over epi tiles in the N dimension first before the M dimension
-        epi_tile_layout = cute.make_layout(bGS_gC.shape[1], stride=(bGS_gC.shape[1][1], 1))
-        if should_load:
-            epi_pipeline.producer_acquire(epi_producer_state)
-            # Get the global memory coordinate for the current epi tile
-            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
-            copy_C(
-                bGS_gC[None, gmem_coord],
-                bGS_sC[None, epi_producer_state.index],
-                tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
-            )
-            # Epi pipeline's producer commit is a NOP
-            epi_pipeline.producer_commit(epi_producer_state)
-            epi_producer_state.advance()
-        return epi_producer_state
-
-    def epi_visit_acc_subtile(
-        self,
-        params: EpilogueParams,
-        tRS_rD: cute.Tensor,
-        tRS_rC: Optional[cute.Tensor] = None,
-    ) -> Optional[cute.Tensor]:
-        # Apply alpha scaling to accumulator if alpha is provided (not None)
-        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
-            alpha = utils.load_scalar_or_pointer(params.alpha)
-            tRS_rD.store(tRS_rD.load() * alpha)
-        # Apply C with beta scaling
-        if const_expr(tRS_rC is not None):
-            if const_expr(not hasattr(params, "beta") or params.beta is None):
-                # beta is None, default behavior: add C (beta=1.0)
-                tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
-            else:
-                beta = utils.load_scalar_or_pointer(params.beta)
-                tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
-        return None
-
-    def tensormap_init(
-        self,
-        tensormaps: Optional[cute.Tensor],
-        varlen_m: bool,
-        varlen_k: bool,
-        has_D: bool,
-        warp_idx: Int32,
-    ):
-        tensormap_manager = None
-        tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
-        tensormap_epi_ptrs = [None] * self.num_epi_tensormaps
-        if const_expr(varlen_m or varlen_k):
-            tensormap_manager = TensorMapManagerSm90(
-                cutlass.utils.TensorMapUpdateMode.GMEM, self.__class__.bytes_per_tensormap
-            )
-            # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
-            tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
-            if const_expr(varlen_m):
-                tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
-                tensormap_epi_offset = tensormap_d_idx
-                if const_expr(has_D):
-                    tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
-                        tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
-                    )
-                    tensormap_epi_offset += 1 if not self.pingpong else 2
-                tensormap_epi_ptrs = [
-                    tensormap_manager.get_tensormap_ptr(
-                        tensormaps[
-                            tensormap_workspace_idx,
-                            tensormap_epi_offset + i * (1 if not self.pingpong else 2),
-                            None,
-                        ].iterator
-                    )
-                    for i in range(self.num_epi_tensormaps)
-                ]
-            else:
-                assert varlen_k
-                if const_expr(not self.gather_A):
-                    tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
-                        tensormaps[tensormap_workspace_idx, 0, None].iterator
-                    )
-                tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
-                    tensormaps[
-                        tensormap_workspace_idx, 1 if not self.gather_A else 0, None
-                    ].iterator
-                )
-        tensormap_ab_ptrs = [tensormap_a_ptr, tensormap_b_ptr]
-        return (
-            tensormap_manager,
-            tensormap_ab_ptrs,
-            tensormap_d_ptr,
-            tensormap_epi_ptrs,
+                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
+            if const_expr(not delay_tma_store):
+                tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
+
+        if const_expr(delay_tma_store):
+            tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
+
+        self.epi_end(
+            params,
+            epi_tensors,
+            epi_tile,
+            tiled_copy_t2r,
+            tiled_copy_r2s,
+            tile_coord_mnkl,
+            varlen_manager,
+            tidx,
         )
 
-
-        self,
-        tensormap_manager,
-        tensormap_ab_ptrs,
-        cu_seqlens_k: cute.Tensor,
-        batch_idx: Int32,
-        is_manager_warp: bool | Boolean,
-    ) -> None:
-        # construct tensor A/B based on real address, shape and stride information
-        tensormap_a_ptr, tensormap_b_ptr = tensormap_ab_ptrs
-        tensormap_ptrs = [tensormap_b_ptr]
-        shapes = [cu_seqlens_k[batch_idx + 1]]
-        orders = [0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1]
-        if const_expr(not self.gather_A):
-            tensormap_ptrs.insert(0, tensormap_a_ptr)
-            shapes.insert(0, cu_seqlens_k[batch_idx + 1])
-            orders.insert(0, 0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1)
-        tensormap_manager.update_tensormap_shape(
-            tensormap_ptrs,
-            is_manager_warp=is_manager_warp,
-            shapes=shapes,
-            orders=orders,
-            tensormap_smem_ptr=None,
-        )
-
-    def tensormap_update_D_epi(
-        self,
-        tensormap_manager,
-        tensormap_d_ptr,
-        tensormap_epi_ptrs,
-        epilogue_params: EpilogueParams,
-        cu_seqlens_m: cute.Tensor,
-        batch_idx: Int32,
-        is_manager_warp: bool | Boolean,
-    ) -> None:
-        # construct tensor D based on real address, shape and stride information
-        tensormap_ptrs, shapes, orders = [], [], []
-        if const_expr(tensormap_d_ptr is not None):
-            tensormap_ptrs.append(tensormap_d_ptr)
-            shapes.append(cu_seqlens_m[batch_idx + 1])
-            orders.append(0 if const_expr(self.d_layout.is_m_major_c()) else 1)
-        epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
-            epilogue_params, cu_seqlens_m, batch_idx
-        )
-        tensormap_ptrs.extend(tensormap_epi_ptrs)
-        shapes.extend(epi_shapes)
-        orders.extend(epi_orders)
-        tensormap_manager.update_tensormap_shape(
-            tensormap_ptrs,
-            is_manager_warp=is_manager_warp,
-            shapes=shapes,
-            orders=orders,
-            tensormap_smem_ptr=None,
-        )
+        return epi_read_state, epi_producer_state
 
     def get_scheduler_class(self, varlen_m: bool = False):
         """Return the scheduler class to use. Override in subclasses for custom schedulers."""
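Note (illustrative, not part of the diff): tma_store_fn and the delay_tma_store switch above implement an optional one-tile software pipeline — with the delay on, tile i's store is issued while tile i+1 is being computed, and the last store is flushed after the loop. A tiny plain-Python sketch of the same idea, where store() is a hypothetical placeholder for the fence + TMA store sequence:

def run_epilogue(num_tiles, delay_store, store):
    # Defer each store by one iteration so compute for tile i can overlap
    # the (asynchronous) store of tile i-1; flush the last one at the end.
    prev = None
    for i in range(num_tiles):
        result = f"tile{i}"  # stand-in for compute + smem staging
        if delay_store:
            if prev is not None:
                store(prev)
            prev = result
        else:
            store(result)
    if delay_store and prev is not None:
        store(prev)

out = []
run_epilogue(3, delay_store=True, store=out.append)
assert out == ["tile0", "tile1", "tile2"]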
@@ -1773,6 +1435,40 @@
         )
         return tile_sched_args
 
+    @cute.jit
+    def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
+        for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
+            tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
+
+    @cute.jit
+    def epi_begin(
+        self,
+        params: EpilogueParams,
+        epi_smem_tensors: Tuple[cute.Tensor, ...],
+        epi_tile: cute.Tile,
+        tiled_copy_t2r: Optional[cute.TiledCopy],
+        tiled_copy_r2s: cute.TiledCopy,
+        tile_coord_mnkl: cute.Coord,
+        varlen_manager: VarlenManager,
+        epilogue_barrier: cutlass.pipeline.NamedBarrier,
+        tidx: Int32,
+    ) -> Tuple[cute.Tensor, ...]:
+        return ()
+
+    def epi_begin_loop(
+        self, params: EpilogueParams, epi_tensors: Tuple[cute.Tensor, ...], epi_coord: cute.Coord
+    ) -> Tuple[cute.Tensor, ...]:
+        return ()
+
+    def epi_visit_subtile(
+        self,
+        params: EpilogueParams,
+        epi_loop_tensors: Tuple[cute.Tensor, ...],
+        tRS_rD: cute.Tensor,
+        tRS_rC: Optional[cute.Tensor] = None,
+    ) -> Optional[cute.Tensor]:
+        return None
+
     def epi_visit_acc(
         self,
         params: EpilogueParams,
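Note (illustrative, not part of the diff): the added epi_begin / epi_begin_loop / epi_visit_subtile / epi_end defaults turn the epilogue into a template method — GemmSm90 drives the loop, and epilogue subclasses override only the hooks they need. A compact plain-Python sketch of the pattern; the hook names mirror the diff but the bodies are illustrative only:

class BaseEpilogueSketch:
    # Template method: the driver is fixed, the hooks default to no-ops.
    def epi_begin(self):
        return ()

    def epi_visit_subtile(self, state, tile):
        return tile  # default: pass the accumulator tile through unchanged

    def epi_end(self, state):
        pass

    def run(self, tiles):
        state = self.epi_begin()
        out = [self.epi_visit_subtile(state, t) for t in tiles]
        self.epi_end(state)
        return out

class ReluEpilogueSketch(BaseEpilogueSketch):
    # A subclass overrides just the per-subtile hook, e.g. an activation.
    def epi_visit_subtile(self, state, tile):
        return max(tile, 0)

assert ReluEpilogueSketch().run([-2, 3]) == [0, 3]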
@@ -1783,10 +1479,24 @@
     ) -> None:
         pass
 
+    @cute.jit
+    def epi_end(
+        self,
+        params: EpilogueParams,
+        epi_tensors: Tuple[cute.Tensor, ...],
+        epi_tile: cute.Tile,
+        tiled_copy_t2r: Optional[cute.TiledCopy],
+        tiled_copy_r2s: cute.TiledCopy,
+        tile_coord_mnkl: cute.Coord,
+        varlen_manager,
+        tidx,
+    ) -> None:
+        pass
+
     def epi_to_underlying_arguments(
         self, args: EpilogueArguments, *, loc=None, ip=None
     ) -> EpilogueParams:
-        return
+        return self.EpilogueParams()
 
     def epi_get_tma_atoms(
         self, params: EpilogueParams, *, loc=None, ip=None
@@ -1810,12 +1520,12 @@
     def epi_smem_bytes_per_stage(
         args: Optional[EpilogueArguments],
         cta_tile_shape_mnk: Tuple[int, int, int],
-        epi_tile:
+        epi_tile: cute.Tile,
     ) -> int:
         return 0
 
     def epi_get_smem_struct(self, params: EpilogueParams):
-        return cute.struct.MemRange[
+        return cute.struct.MemRange[Int32, 0]  # Dummy struct
 
     def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
         return tuple()
@@ -1842,7 +1552,7 @@ class GemmSm90:
                 self.d_layout.is_m_major_c() if self.d_layout is not None else False,
                 num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
             ),
-
+            Float16,  # this is just to get the right source layout
         )
         tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
         return tiled_copy_C_atom
@@ -1852,8 +1562,7 @@ class GemmSm90:
         tiled_mma: cute.TiledMma,
         d_layout: Optional[LayoutEnum],
         dtype: Type[cutlass.Numeric],
-
-        sD: cute.Tensor,
+        sD: Optional[cute.Tensor],
         tidx: Int32,
     ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
         if d_layout is None:
@@ -1868,12 +1577,10 @@ class GemmSm90:
         # (R2S, R2S_M, R2S_N, PIPE_D)
         thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
         tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
-        # (R2S, R2S_M, R2S_N)
-        tRS_rAcc = tiled_copy_r2s.retile(acc)
         sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
         tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
         tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
-        return tiled_copy_r2s,
+        return tiled_copy_r2s, tRS_rD, tRS_sD
 
     def epilog_smem_load_and_partition(
         self,
@@ -1885,7 +1592,7 @@ class GemmSm90:
         tidx: Int32,
     ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
         tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
-        copy_atom_s2r =
+        copy_atom_s2r = copy_utils.sm90_get_smem_load_op(c_layout, dtype)
         tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
         thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
         tSR_sC = thr_copy_s2r.partition_S(sC)
@@ -1896,29 +1603,30 @@ class GemmSm90:
     def epilog_gmem_copy_and_partition(
         self,
         atom: Union[cute.CopyAtom, cute.TiledCopy],
-
+        mD_mn: cute.Tensor,
         tile_shape_mn: cute.Tile,
         epi_tile: cute.Tile,
         sD: cute.Tensor,
         tile_coord_mnkl: cute.Coord,
-
+        tma_desc_ptr: Optional[cute.Pointer] = None,
     ) -> Tuple[cute.Tensor, cute.Tensor]:
-        batch_idx = tile_coord_mnkl[3]
-        if const_expr(cu_seqlens_m is not None):
-            mD_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl)
-        else:
-            mD_mn = mD_mnl[None, None, batch_idx]
         # (bM, bN)
         gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
         tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
-
+        is_s2g = isinstance(
+            atom.op, (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp)
+        )
+        src_tensor, dst_tensor = (
+            (sD, tDgD_for_tma_partition) if is_s2g else (tDgD_for_tma_partition, sD)
+        )
+        return copy_utils.tma_get_copy_fn(
             atom,
-            0,
-            cute.make_layout(1),
-
-
+            cta_coord=0,
+            cta_layout=cute.make_layout(1),
+            src_tensor=src_tensor,
+            dst_tensor=dst_tensor,
+            tma_desc_ptr=tma_desc_ptr,
         )
-        return bSG_sD, bSG_gD
 
     def make_ab_pipeline(
         self,
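Note that `epilog_gmem_copy_and_partition` now returns the result of `copy_utils.tma_get_copy_fn` (the `Tuple[cute.Tensor, cute.Tensor]` annotation is stale, left untouched by this diff), and that it infers the transfer direction from the atom's op type, so one helper serves both the C load (gmem to smem) and the D store (smem to gmem). A toy version of that dispatch, with stand-in op classes rather than the real cpasync types:

```python
# Stand-ins for the bulk-tensor op types used in the isinstance check above;
# only the dispatch pattern is shown, not the real cpasync classes.
class BulkTileS2GOp:  # smem -> gmem (store, or reduce-store)
    pass

class BulkTileG2SOp:  # gmem -> smem (load)
    pass

def pick_src_dst(op, smem_tensor, gmem_tiles):
    # Stores read smem and write gmem; loads go the other way.
    is_s2g = isinstance(op, BulkTileS2GOp)
    return (smem_tensor, gmem_tiles) if is_s2g else (gmem_tiles, smem_tensor)

assert pick_src_dst(BulkTileS2GOp(), "sD", "gD_tiles") == ("sD", "gD_tiles")
assert pick_src_dst(BulkTileG2SOp(), "sC", "gC_tiles") == ("gC_tiles", "sC")
```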
@@ -1927,21 +1635,15 @@ class GemmSm90:
         ab_pipeline_mbar_ptr: cute.Pointer,
     ):
         # Threads/warps participating in this pipeline
-        producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.
+        producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32
         ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
         # Each warp will contribute to the arrive count with the number of mcast size
         mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
-        consumer_arrive_cnt = mcast_size
-        if const_expr(self.arch != 100):
-            consumer_arrive_cnt *= tiled_mma.size // cute.arch.WARP_SIZE
+        consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE
         ab_pipeline_consumer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread, consumer_arrive_cnt
         )
-        if
-            pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
-        else:
-            # TODO: we need a pipeline class for TMACpAsyncUMMA
-            pipeline_cls = pipeline.PipelineTmaUmma if not self.gather_A else PipelineTmaCpAsync
+        pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
         return pipeline_cls.create(
             barrier_storage=ab_pipeline_mbar_ptr,
             num_stages=self.ab_stage,
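The consumer arrive count is also simplified here: the old `self.arch != 100` branch and the `PipelineTmaUmma` fallback are dropped, so every MMA warp arrives and the count collapses to a single product. A worked example with assumed values (A multicast to two CTAs, B to one, and a two-warpgroup tiled MMA of 256 threads):

```python
# Illustrative numbers only; real values depend on cluster shape and MMA size.
WARP_SIZE = 32
num_mcast_ctas_a, num_mcast_ctas_b = 2, 1
tiled_mma_size = 256  # threads in the tiled MMA (two warpgroups)

mcast_size = num_mcast_ctas_a + num_mcast_ctas_b - 1
consumer_arrive_cnt = mcast_size * tiled_mma_size // WARP_SIZE
assert (mcast_size, consumer_arrive_cnt) == (2, 16)
```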
@@ -1973,9 +1675,7 @@ class GemmSm90:
     def make_epi_store_pipeline(self):
         # Threads/warps participating in tma store pipeline
         num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
-        epi_store_producer_group = pipeline.CooperativeGroup(
-            pipeline.Agent.Thread, num_epi_threads, num_epi_threads
-        )
+        epi_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_epi_threads)
         return pipeline.PipelineTmaStore.create(
             num_stages=self.epi_stage, producer_group=epi_store_producer_group
         )
@@ -2182,36 +1882,18 @@ class GemmSm90:
             order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
         )
 
+        epi_smem_layout_staged = None
         if d_dtype is not None:
-
-
-            d_smem_layout_atom = warpgroup.make_smem_layout_atom(
-                sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
-                d_dtype,
-            )
-            epi_smem_layout_staged = cute.tile_to_shape(
-                d_smem_layout_atom,
-                cute.append(d_smem_shape, epi_stage),
-                order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
+            epi_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
+                d_dtype, d_layout, epi_tile, epi_stage
             )
-        else:
-            epi_smem_layout_staged = None
 
+        epi_c_smem_layout_staged = None
         if c_dtype is not None:
             assert c_layout is not None
-
-
-            c_smem_layout_atom = warpgroup.make_smem_layout_atom(
-                sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
-                c_dtype,
-            )
-            epi_c_smem_layout_staged = cute.tile_to_shape(
-                c_smem_layout_atom,
-                cute.append(c_smem_shape, epi_c_stage),
-                order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
+            epi_c_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
+                c_dtype, c_layout, epi_tile, epi_c_stage
             )
-        else:
-            epi_c_smem_layout_staged = None
 
         return (
             a_smem_layout_staged,
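Both epilogue smem layouts are now built by `quack_sm90_utils.make_smem_layout_epi`. The following is a reconstruction of what that helper plausibly contains, inferred from the inline code this hunk deletes; the shipped definition in quack/sm90_utils.py may derive the major-mode size and per-stage shape differently.

```python
# Reconstructed sketch, not the shipped helper. Swizzle-atom selection and
# tiling to a (tile_M, tile_N, STAGE) shape follow the removed inline code;
# the major_mode_size derivation is an assumption.
import cutlass.cute as cute
import cutlass.utils.hopper_helpers as sm90_utils
from cutlass.cute.nvgpu import warpgroup


def make_smem_layout_epi(epi_dtype, epi_layout, epi_tile, epi_stage):
    # Contiguous extent along the major mode picks the swizzle atom.
    major_mode_size = epi_tile[0] if epi_layout.is_m_major_c() else epi_tile[1]
    smem_layout_atom = warpgroup.make_smem_layout_atom(
        sm90_utils.get_smem_layout_atom(epi_layout, epi_dtype, major_mode_size),
        epi_dtype,
    )
    # Tile the atom over the epilogue subtile, then stack epi_stage stages.
    return cute.tile_to_shape(
        smem_layout_atom,
        cute.append(epi_tile, epi_stage),
        order=(1, 0, 2) if epi_layout.is_m_major_c() else (0, 1, 2),
    )
```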
@@ -2349,7 +2031,7 @@ class GemmSm90:
         """
         is_valid = True
         if a_dtype not in {
-
+            Float16,
            cutlass.BFloat16,
            cutlass.Float8E4M3FN,
            cutlass.Float8E5M2,
@@ -2357,19 +2039,19 @@ class GemmSm90:
             is_valid = False
         # tested b_dtype
         if b_dtype not in {
-
+            Float16,
             cutlass.BFloat16,
             cutlass.Float8E4M3FN,
             cutlass.Float8E5M2,
         }:
             is_valid = False
-        if acc_dtype not in {
+        if acc_dtype not in {Float32, Float16}:
             is_valid = False
         # tested d_dtype
         if d_dtype not in {
             None,
-
-
+            Float32,
+            Float16,
             cutlass.BFloat16,
             cutlass.Float8E4M3FN,
             cutlass.Float8E5M2,
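With the `Float16`/`Float32` literals restored, the policy reads: FP16/BF16/FP8 inputs, FP32 or FP16 accumulation, and (per the width check in the next hunk) k-major layout for any 8-bit operand. An illustrative check, assuming `GemmSm90` is importable from this module and that `is_valid_dtypes` takes the six positional arguments the removed wrapper below passes it:

```python
import cutlass
from cutlass import Float16, Float32
from quack.gemm_sm90 import GemmSm90  # assumed import path

# FP8 operands, k-major on both sides: accepted.
assert GemmSm90.is_valid_dtypes(
    cutlass.Float8E4M3FN, cutlass.Float8E4M3FN, Float32, cutlass.BFloat16, "k", "k"
)
# An 8-bit A that is not k-major is rejected by the width check.
assert not GemmSm90.is_valid_dtypes(
    cutlass.Float8E4M3FN, cutlass.Float8E4M3FN, Float32, cutlass.BFloat16, "m", "k"
)
```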
@@ -2386,155 +2068,3 @@ class GemmSm90:
         if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
             is_valid = False
         return is_valid
-
-
-def gemm_sm90(
-    # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k
-    A: Tensor,
-    B: Tensor,  # (l, n, k) or (n, total_k) if varlen_k
-    D: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
-    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
-    tile_count_semaphore: Optional[Tensor],  # (1,)
-    tile_M: int,
-    tile_N: int,
-    cluster_M: int,
-    cluster_N: int,
-    pingpong: bool = False,
-    persistent: bool = True,
-    alpha: float | Tensor = 1.0,
-    beta: float | Tensor = 1.0,
-    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
-    cu_seqlens_k: Optional[Tensor] = None,  # (l+1,) cumulative sum of k values for variable length
-    A_idx: Optional[Tensor] = None,  # (total_m,) or (total_k,) indices for gather_A when varlen
-    batch_idx_permute: Optional[Tensor] = None,  # (l,) permutation of batch indices for scheduler
-    add_to_output: bool = False,
-) -> None:
-    varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
-    assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
-        "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
-    )
-    gather_A = A_idx is not None
-    if gather_A:
-        assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
-        assert cluster_N == 1, "gather_A requires cluster_N=1"
-    if varlen:
-        assert persistent, "varlen requires persistent=True"
-    if add_to_output:
-        assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
-    if cu_seqlens_m is not None:
-        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
-        assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
-    if cu_seqlens_k is not None:
-        assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
-        assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
-
-    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
-        A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
-    )
-    GemmWrapperBase.permute_tensors(
-        tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
-    )
-    GemmWrapperBase.extract_dtypes(tensor_infos)
-    major_configs = {
-        "A": ("m", "k", "l"),
-        "B": ("n", "k", "l"),
-        "D": ("m", "n", "l"),
-        "C": ("m", "n", "l"),
-    }
-    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
-
-    acc_dtype = cutlass.Float32
-    tile_shape_mn = (tile_M, tile_N)
-    cluster_shape_mnk = (cluster_M, cluster_N, 1)
-    if not GemmSm90.is_valid_dtypes(
-        tensor_infos["A"].dtype,
-        tensor_infos["B"].dtype,
-        acc_dtype,
-        tensor_infos["D"].dtype,
-        tensor_infos["A"].major,
-        tensor_infos["B"].major,
-    ):
-        raise TypeError("Skipping due to unsupported combination of types and majors")
-
-    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
-    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
-
-    def scalar_arg(scalar: float | Tensor):
-        if isinstance(scalar, float):
-            return Float32(scalar) if scalar != 1.0 else None
-        else:
-            assert isinstance(scalar, Tensor)
-            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
-
-    epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta), add_to_output)
-    scheduler_args = GemmWrapperBase.create_scheduler_args(
-        max_active_clusters,
-        tile_count_semaphore,
-        batch_idx_permute,
-    )
-
-    # Create varlen arguments if needed (assumes persistent=True when varlen)
-    varlen_args = GemmWrapperBase.create_varlen_args(
-        cu_seqlens_m,
-        cu_seqlens_k,
-        A_idx,
-        max_active_clusters,
-        cluster_shape_mnk,
-        tensor_infos,
-        GemmSm90.num_epi_tensormaps,
-        pingpong,
-    )
-
-    current_stream = cutlass_torch.current_stream()
-    compile_key = GemmWrapperBase.get_compile_key(
-        tensor_infos,
-        None,
-        tile_shape_mn,
-        cluster_shape_mnk,
-        pingpong,
-        persistent,
-        tile_count_semaphore is not None,
-        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
-        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
-        add_to_output,
-        cu_seqlens_m is not None,
-        cu_seqlens_k is not None,
-        gather_A,
-        batch_idx_permute is not None,
-        key_tensor_names=("A", "B", "D", "C"),
-    )
-    cache = gemm_sm90.compile_cache
-    if compile_key not in cache:
-        gemm = GemmSm90(
-            acc_dtype,
-            tensor_infos["A"].dtype,
-            tile_shape_mn,
-            cluster_shape_mnk,
-            pingpong=pingpong,
-            is_persistent=persistent,
-            gather_A=gather_A,
-        )
-        cache[compile_key] = cute.compile(
-            gemm,
-            tensor_infos["A"].cute_tensor,
-            tensor_infos["B"].cute_tensor,
-            tensor_infos["D"].cute_tensor,
-            tensor_infos["C"].cute_tensor,
-            epi_args,
-            scheduler_args,
-            varlen_args,
-            current_stream,
-        )
-    cache[compile_key](
-        tensor_infos["A"].cute_tensor,
-        tensor_infos["B"].cute_tensor,
-        tensor_infos["D"].cute_tensor,
-        tensor_infos["C"].cute_tensor,
-        epi_args,
-        scheduler_args,
-        varlen_args,
-        current_stream,
-    )
-
-
-gemm_sm90.compile_cache = {}
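The final hunk removes the module-level `gemm_sm90` convenience wrapper and its compile cache, so in 0.2.3 this file no longer exposes a host-side entry point. For reference, the removed API was called roughly as below under 0.2.2; the tile and cluster sizes are illustrative, and the wrapper compiled the kernel on first use and then replayed the cached compilation.

```python
import torch

# Batched bf16 GEMM: D = A @ B^T per batch, shapes as documented in the
# removed signature above. Values here are example choices, not defaults.
L, M, N, K = 4, 1024, 1024, 512
A = torch.randn(L, M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(L, N, K, device="cuda", dtype=torch.bfloat16)
D = torch.empty(L, M, N, device="cuda", dtype=torch.bfloat16)

# 0.2.2-style call into the now-removed module-level wrapper.
gemm_sm90(
    A, B, D, C=None, tile_count_semaphore=None,
    tile_M=128, tile_N=192, cluster_M=2, cluster_N=1,
    pingpong=False, persistent=True,
)
```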