quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
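Note: several modules lose their architecture suffix in this release (dense_gemm_sm90.py → gemm_sm90.py, gemm_act_sm90.py → gemm_act.py, gemm_dact_sm90.py → gemm_dact.py), while layernorm.py and symmetric_dense_gemm_sm90.py are removed and gemm.py, gemm_symmetric.py, copy_utils.py, and layout_utils.py are added. Callers that import the old module paths will likely need the new names. A hedged caller-side sketch, assuming the class names themselves are unchanged (the new path is exactly what quack/gemm_sm100.py imports in the diff below; the old path is an assumption):

    # before (0.2.2) -- assumed old module path
    # from quack.dense_gemm_sm90 import GemmSm90
    # after (0.2.3) -- matches the import used inside gemm_sm100.py below
    from quack.gemm_sm90 import GemmSm90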
quack/gemm_sm100.py
CHANGED
@@ -2,7 +2,7 @@
 # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py
 
 import argparse
-from typing import Optional, Type, Tuple, Union, Callable
+from typing import Optional, Type, Tuple, Union, Callable, Literal
 from functools import partial
 
 import cuda.bindings.driver as cuda
@@ -15,14 +15,23 @@ import cutlass.torch as cutlass_torch
 import cutlass.pipeline as pipeline
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
+from cutlass.cute.nvgpu.warp import (
+    LdMatrix8x8x16bOp,
+    LdMatrix16x16x8bOp,
+    StMatrix8x8x16bOp,
+    StMatrix16x8x8bOp,
+)
 from cutlass import Int32, Float32, Boolean, const_expr
 from cutlass.utils import LayoutEnum
 from cutlass.cute.runtime import from_dlpack, make_ptr
 
+from quack.pipeline import PipelineTmaCpAsyncUmma
 from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
 from quack.tile_scheduler import TileSchedulerOptions
-from quack.varlen_utils import VarlenArguments
-from quack.
+from quack.varlen_utils import VarlenArguments, VarlenManager
+from quack.gemm_sm90 import GemmSm90, NamedBarrierGemm
+import quack.copy_utils as copy_utils
+import quack.sm100_utils as quack_sm100_utils
 
 # return PipelineStateWAdvance instead of PipelineState
 
@@ -148,6 +157,7 @@ class GemmSm100(GemmSm90):
     def __init__(
         self,
         acc_dtype: Type[cutlass.Numeric],
+        a_dtype: Type[cutlass.Numeric],  # ignored for now
         mma_tiler_mn: Tuple[int, int],
         cluster_shape_mnk: Tuple[int, int, int],
         sf_vec_size: Optional[int] = None,
@@ -175,7 +185,7 @@ class GemmSm100(GemmSm90):
         """
 
         self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
-        self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (
+        self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (256,)
         self.cluster_shape_mnk = cluster_shape_mnk
         assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1"
         # K dimension is deferred in _setup_attributes
@@ -190,19 +200,28 @@ class GemmSm100(GemmSm90):
 
         self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
 
+        self.num_ab_load_warps = 1 if not self.gather_A else 5
         self.occupancy = 1
         # Set specialized warp ids
         self.epilog_warp_id = (0, 1, 2, 3)
         self.mma_warp_id = 4
-        self.
-        self.
+        self.ab_load_warp_id = 5
+        self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
+        self.scheduler_warp_id = self.epi_load_warp_id + 1
         self.num_epi_warps = len(self.epilog_warp_id)
-        self.threads_per_cta =
-
+        self.threads_per_cta = cute.arch.WARP_SIZE * (
+            self.num_ab_load_warps
+            + len(
+                (
+                    self.mma_warp_id,
+                    self.epi_load_warp_id,
+                    self.scheduler_warp_id,
+                    *self.epilog_warp_id,
+                )
+            )
         )
-        self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_100")
 
-    def _setup_attributes(self):
+    def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments):
         """Set up configurations that are dependent on GEMM inputs
 
        This method configures various attributes based on the input tensor properties
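Note: with the extra load and scheduler warps above, the CTA thread count follows directly from the warp list. A minimal sketch of the arithmetic implied by this hunk (WARP_SIZE = 32; not part of the wheel):

    WARP_SIZE = 32

    def threads_per_cta(gather_A: bool) -> int:
        num_ab_load_warps = 5 if gather_A else 1
        # mma warp + epi-load warp + scheduler warp + 4 epilogue warps
        num_other_warps = 1 + 1 + 1 + 4
        return WARP_SIZE * (num_ab_load_warps + num_other_warps)

    assert threads_per_cta(False) == 256  # 8 warps
    assert threads_per_cta(True) == 384   # 12 warps when num_ab_load_warps == 5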
@@ -298,6 +317,8 @@ class GemmSm100(GemmSm90):
 
         # Compute number of multicast CTAs for A/B
         self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+        if self.gather_A:
+            assert self.num_mcast_ctas_a == 1
         self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
         self.is_a_mcast = self.num_mcast_ctas_a > 1
         self.is_b_mcast = self.num_mcast_ctas_b > 1
@@ -309,11 +330,18 @@ class GemmSm100(GemmSm90):
         self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
             self.cta_tile_shape_mnk,
             self.use_2cta_instrs,
-            self.d_layout,
-            self.d_dtype,
+            self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR,
+            self.d_dtype if self.d_dtype is not None else cutlass.BFloat16,
+            layout_c=self.c_layout,
+            elem_ty_c=self.c_dtype,
         )
 
         # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
+        prefetch_A_idx = (
+            None
+            if not self.gather_A
+            else ("varlen_m" if varlen_args.mCuSeqlensM is not None else "varlen_k")
+        )
         (
             self.num_acc_stage,
             self.ab_stage,
@@ -322,36 +350,50 @@ class GemmSm100(GemmSm90):
         ) = self._compute_stages(
             self.tiled_mma,
             self.mma_tiler,
+            self.cta_tile_shape_mnk,
+            self.epi_tile,
             self.a_dtype,
             self.b_dtype,
-            self.
+            self.sf_dtype,
+            self.sf_vec_size,
             self.d_dtype,
             self.c_dtype,
             self.d_layout,
             self.c_layout,
-
-
-            self.smem_capacity
+            epilogue_args,
+            prefetch_A_idx,
+            cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"),  # smem_capacity
             self.occupancy,
         )
-        self.sched_stage = 1
+        self.sched_stage = 1
+        self.a_prefetch_stage = (
+            0
+            if not self.gather_A
+            else (2 if varlen_args.mCuSeqlensM is not None else self.ab_stage)
+        )
 
         # Compute A/B/SFA/SFB/C shared memory layout
         self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
             self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
         )
+        self.a_smem_load_layout_staged = self.a_smem_layout_staged
+        if const_expr(self.gather_A):
+            self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_cpasync_a(
+                self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
+            )
         self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
             self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage
         )
-        self.epi_smem_layout_staged =
-
-
+        self.epi_smem_layout_staged = None
+        if const_expr(self.d_dtype is not None):
+            self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+                self.d_dtype, self.d_layout, self.epi_tile, self.epi_stage
+            )
+        self.epi_c_smem_layout_staged = None
         if const_expr(self.c_dtype is not None):
             self.epi_c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
                 self.c_dtype, self.c_layout, self.epi_tile, self.epi_c_stage
             )
-        else:
-            self.epi_c_smem_layout_staged = None
         if const_expr(self.blockscaled):
             self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
                 self.tiled_mma,
@@ -449,7 +491,7 @@ class GemmSm100(GemmSm90):
         ]
 
         # Setup attributes that dependent on gemm inputs
-        self._setup_attributes()
+        self._setup_attributes(epilogue_args, varlen_args)
 
         if const_expr(self.blockscaled):
             # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout
@@ -536,24 +578,22 @@ class GemmSm100(GemmSm90):
         # Setup TMA store for D
         tma_atom_d, tma_tensor_d = None, None
         if const_expr(mD is not None):
-
-            tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
-                cpasync.CopyBulkTensorTileS2GOp(),
+            tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
                 mD,
-
+                self.epi_smem_layout_staged,
                 self.epi_tile,
+                op_type="store"
+                if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
+                else "add",
             )
         tma_atom_c, tma_tensor_c = None, None
         if const_expr(mC is not None):
-
-
-                cpasync.CopyBulkTensorTileG2SOp(),
-                mC,
-                epi_c_smem_layout,
-                self.epi_tile,
+            tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
+                mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
             )
 
         epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
+        varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
 
         TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
         tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
@@ -573,6 +613,13 @@ class GemmSm100(GemmSm90):
         sfb_smem_size = (
             cute.cosize(self.sfb_smem_layout_staged) if const_expr(self.blockscaled) else 0
         )
+        a_idx_smem_size = 0
+        if const_expr(self.gather_A):
+            a_idx_smem_size = self.a_prefetch_stage * (
+                self.cta_tile_shape_mnk[0]
+                if varlen_args.mCuSeqlensM is not None
+                else self.cta_tile_shape_mnk[2]
+            )
 
         # Define shared storage for kernel
         @cute.struct
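Note: sAIdx holds Int32 row indices (varlen_m) or column indices (varlen_k), so the extra shared memory reserved by the hunk above is a_prefetch_stage * tile_extent * 4 bytes. A small sketch with illustrative tile sizes only; the real extents come from cta_tile_shape_mnk and the stage count chosen in _setup_attributes:

    def a_idx_smem_bytes(a_prefetch_stage: int, tile_extent: int) -> int:
        return a_prefetch_stage * tile_extent * 4  # Int32 indices

    a_idx_smem_bytes(2, 128)  # e.g. varlen_m with a 128-row CTA tile -> 1024 bytes
    a_idx_smem_bytes(4, 64)   # e.g. varlen_k with a 64-wide K tile   -> 1024 bytes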
@@ -581,9 +628,13 @@ class GemmSm100(GemmSm90):
             epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
             acc_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
             sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
-
+            a_prefetch_pipeline_array_ptr: cute.struct.MemRange[
+                cutlass.Int64, self.a_prefetch_stage * 2
+            ]
+            tile_count: cute.struct.MemRange[Int32, self.sched_stage]
             tmem_dealloc_mbar_ptr: cutlass.Int64
             tmem_holding_buf: Int32
+            sAIdx: cute.struct.Align[cute.struct.MemRange[Int32, a_idx_smem_size], 16]
             # (EPI_TILE_M, EPI_TILE_N, STAGE)
             sD: cute.struct.Align[
                 cute.struct.MemRange[
@@ -638,13 +689,11 @@ class GemmSm100(GemmSm90):
             tma_atom_c,
             tma_tensor_c,
             epilogue_params,
-
-            varlen_args.mCuSeqlensK,
-            varlen_args.mTensormaps,
-            varlen_args.mAIdx,
+            varlen_params,
             self.cluster_layout_vmnk,
             self.cluster_layout_sfb_vmnk,
             self.a_smem_layout_staged,
+            self.a_smem_load_layout_staged,
             self.b_smem_layout_staged,
             self.sfa_smem_layout_staged,
             self.sfb_smem_layout_staged,
@@ -657,7 +706,6 @@ class GemmSm100(GemmSm90):
             grid=grid,
             block=[self.threads_per_cta, 1, 1],
             cluster=self.cluster_shape_mnk,
-            smem=self.shared_storage.size_in_bytes(),
             stream=stream,
             min_blocks_per_mp=1,
         )
@@ -682,13 +730,11 @@ class GemmSm100(GemmSm90):
         tma_atom_c: Optional[cute.CopyAtom],
         mC_mnl: Optional[cute.Tensor],
         epilogue_params: ParamsBase,
-
-        cu_seqlens_k: Optional[cute.Tensor],
-        tensormaps: Optional[cute.Tensor],
-        mAIdx: Optional[cute.Tensor],
+        varlen_params: VarlenManager.Params,
         cluster_layout_vmnk: cute.Layout,
         cluster_layout_sfb_vmnk: Optional[cute.Layout],
         a_smem_layout: cute.ComposedLayout,
+        a_smem_load_layout: cute.ComposedLayout,
         b_smem_layout: cute.ComposedLayout,
         sfa_smem_layout: Optional[cute.Layout],
         sfb_smem_layout: Optional[cute.Layout],
@@ -702,8 +748,8 @@ class GemmSm100(GemmSm90):
         GPU device kernel performing the Persistent batched GEMM computation.
         """
 
-        varlen_m = const_expr(cu_seqlens_m is not None)
-        varlen_k = const_expr(cu_seqlens_k is not None)
+        varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
+        varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
         assert not (varlen_m and varlen_k)
         if const_expr(self.gather_A):
             assert varlen_m or varlen_k
@@ -715,7 +761,7 @@ class GemmSm100(GemmSm90):
         # /////////////////////////////////////////////////////////////////////////////
         # Prefetch Tma desc
         # /////////////////////////////////////////////////////////////////////////////
-        if warp_idx == self.
+        if warp_idx == self.ab_load_warp_id:
             for tma_atom in (
                 tma_atom_a,
                 tma_atom_b,
@@ -751,7 +797,7 @@ class GemmSm100(GemmSm90):
 
         # Tensor memory dealloc barrier init
         if use_2cta_instrs:
-            if warp_idx == self.
+            if warp_idx == self.ab_load_warp_id:
                 num_tmem_dealloc_threads = 32
                 cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)
 
@@ -760,6 +806,7 @@ class GemmSm100(GemmSm90):
             tiled_mma=tiled_mma,
             cluster_layout_vmnk=cluster_layout_vmnk,
             ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
+            is_leader_cta=is_leader_cta,
         )
         epi_pipeline = None
         if const_expr(has_C):
@@ -774,20 +821,30 @@ class GemmSm100(GemmSm90):
         sched_pipeline = None
         tile_count = None
         if const_expr(tile_sched_params.tile_count_semaphore is not None):
-            # TODO: Untested, not sure if this is right for Sm100
             # Dynamic persistent scheduler
             sched_pipeline = self.make_sched_pipeline(
                 self.cluster_shape_mnk,
                 sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
-
+                has_C=has_C,
             )
             tile_count = storage.tile_count.get_tensor((self.sched_stage,))
+        a_prefetch_pipeline = None
+        if const_expr(self.gather_A):
+            a_prefetch_pipeline = self.make_a_prefetch_pipeline(
+                storage.a_prefetch_pipeline_array_ptr.data_ptr(),
+            )
 
         # Setup smem tensor A/B/D
         # (MMA, MMA_M, MMA_K, STAGE)
-
+        sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
+        sA = storage.sA.get_tensor(a_smem_load_layout.outer, swizzle=a_smem_load_layout.inner)
         # (MMA, MMA_N, MMA_K, STAGE)
         sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
+        sAIdx = None
+        if const_expr(self.gather_A):
+            a_idx_smem_dim = self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2]
+            a_idx_smem_layout = cute.make_layout((a_idx_smem_dim, self.a_prefetch_stage))
+            sAIdx = storage.sAIdx.get_tensor(a_idx_smem_layout)
         sSFA, sSFB = None, None
         if const_expr(self.blockscaled):
             # (MMA, MMA_M, MMA_K, STAGE)
@@ -813,9 +870,17 @@ class GemmSm100(GemmSm90):
         # (MMA, MMA_M, MMA_N, STAGE)
         tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
 
-
-
-
+        varlen_manager = VarlenManager.create(
+            varlen_params,
+            has_D,
+            self.num_epi_tensormaps,
+            # Only used if not varlen_m
+            len_m_static=Int32(
+                mA_mkl.shape[0]
+                if varlen_k or varlen_params.mAIdx is None
+                else varlen_params.mAIdx.shape[0]
+            ),
+            len_k_static=Int32(mA_mkl.shape[1]),
         )
 
         TileSchedulerCls = partial(
@@ -833,22 +898,14 @@ class GemmSm100(GemmSm90):
         )
 
         #
-        # Specialized
+        # Specialized AB load warps
         #
-        if warp_idx == self.
-
-
-
-
-
-                    tensormap_ab_ptrs[0],
-                    is_manager_warp=True,
-                )
-                tensormap_manager.init_tensormap_from_atom(
-                    tma_atom_b,
-                    tensormap_ab_ptrs[1],
-                    is_manager_warp=True,
-                )
+        if warp_idx == self.ab_load_warp_id:
+            is_tma_warp = True
+            # initialize tensormap for A & B
+            varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
+            tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
+            tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
             # Compute multicast mask for A/B buffer full
             block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
             block_in_cluster_coord_sfb_vmnk = None
@@ -874,34 +931,24 @@ class GemmSm100(GemmSm90):
             )
 
             # Persistent tile scheduling loop
-
-            if const_expr(cute.size(cluster_layout_vmnk) > 1):
-                is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0
-            tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
+            tile_scheduler = TileSchedulerCls()
             work_tile = tile_scheduler.initial_work_tile_info()
             ab_producer_state = pipeline.make_pipeline_state(
                 pipeline.PipelineUserType.Producer, self.ab_stage
             )
             if const_expr(varlen_k):
                 # wait tensormap initialization complete before update
-
-                # batch index of last tile
-                last_batch_idx = cutlass.Int32(-1)
+                varlen_manager.fence_tensormap_init()
             do_epi_load_barrier_arrive = Boolean(True)
             while work_tile.is_valid_tile:
                 tile_coord_mnkl = work_tile.tile_idx
                 batch_idx = tile_coord_mnkl[3]
-
-
-
-
-
-
-                        tensormap_ab_ptrs,
-                        cu_seqlens_k,
-                        batch_idx,
-                        is_manager_warp=True,
-                    )
+                varlen_manager.update_tensormap_AB(
+                    batch_idx,
+                    self.a_layout,
+                    self.b_layout,
+                    is_tma_warp,
+                )
                 # ///////////////////////////////////////////////////////////////////////////
                 # Local_tile partition global tensors
                 # ///////////////////////////////////////////////////////////////////////////
@@ -910,120 +957,111 @@ class GemmSm100(GemmSm90):
                     tile_coord_mnkl[1],
                     tile_coord_mnkl[3],
                 )
-
-
-
-
-                    cute.
-
-
+                gA_mk = None
+                if const_expr(not self.gather_A):
+                    mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
+                    # (bM, bK, RestK)
+                    gA_mk = cute.local_tile(
+                        mA_mk,
+                        cute.select(self.mma_tiler, [0, 2]),
+                        (mma_tile_coord_mnl[0], None),
+                    )
                 # (bN, bK, RestK)
-
-                    mB_nkl,
-                    cute.
-                    (mma_tile_coord_mnl[1], None
+                gB_nk = cute.local_tile(
+                    varlen_manager.offset_batch_B(mB_nkl, batch_idx),
+                    cute.select(self.mma_tiler, [1, 2]),
+                    (mma_tile_coord_mnl[1], None),
                 )
                 if const_expr(self.blockscaled):
                     # (bM, bK)
                     gSFA_mkl = cute.local_tile(
-                        mSFA_mkl,
-                        cute.
-                        (mma_tile_coord_mnl[0], None
+                        varlen_manager.offset_batch_A(mSFA_mkl, batch_idx),
+                        cute.select(self.mma_tiler, [0, 2]),
+                        (mma_tile_coord_mnl[0], None),
                     )
                     # (bN, bK)
                     gSFB_nkl = cute.local_tile(
-                        mSFB_nkl,
-                        cute.
-                        (mma_tile_coord_mnl[1], None
+                        varlen_manager.offset_batch_B(mSFB_nkl, batch_idx),
+                        cute.select(self.mma_tiler, [1, 2]),
+                        (mma_tile_coord_mnl[1], None),
                     )
+
                 # Partition global tensor for TiledMMA_A/B/D
-                #
-
+                # Then partition global/shared tensor for TMA load A/B
+                varlen_manager.fence_tensormap_update_AB(is_tma_warp)
+                len_k = varlen_manager.len_k(batch_idx)
+                # TMA load A partition_S/D
+                a_cta_layout = cute.make_layout(
+                    cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
+                )
+                copy_A = None
+                if const_expr(not self.gather_A):
+                    # (MMA, MMA_M, MMA_K, RestK)
+                    tCgA = thr_mma.partition_A(gA_mk)
+                    copy_A, _, _ = copy_utils.tma_get_copy_fn(
+                        tma_atom_a,
+                        cta_coord=block_in_cluster_coord_vmnk[2],
+                        cta_layout=a_cta_layout,
+                        src_tensor=tCgA,
+                        dst_tensor=sA,
+                        mcast_mask=a_mcast_mask,
+                        tma_desc_ptr=tma_desc_a_ptr,
+                    )
                 # (MMA, MMA_N, MMA_K, RestK)
-                tCgB = thr_mma.partition_B(
+                tCgB = thr_mma.partition_B(gB_nk)
                 if const_expr(self.blockscaled):
                     # (MMA, MMA_M, MMA_K)
                     tCgSFA = thr_mma.partition_A(gSFA_mkl)
                     # (MMA, MMA_N, MMA_K)
                     tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
-                # Partition global/shared tensor for TMA load A/B
-                # TMA load A partition_S/D
-                a_cta_layout = cute.make_layout(
-                    cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
-                )
-                # ((atom_v, rest_v), STAGE)
-                # ((atom_v, rest_v), RestK)
-                tAsA, tAgA = cpasync.tma_partition(
-                    tma_atom_a,
-                    block_in_cluster_coord_vmnk[2],
-                    a_cta_layout,
-                    cute.group_modes(sA, 0, 3),
-                    cute.group_modes(tCgA, 0, 3),
-                )
                 # TMA load B partition_S/D
-
-                    cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
-                )
-                # ((atom_v, rest_v), STAGE)
-                # ((atom_v, rest_v), RestK)
-                tBsB, tBgB = cpasync.tma_partition(
+                copy_B, _, _ = copy_utils.tma_get_copy_fn(
                     tma_atom_b,
-                    block_in_cluster_coord_vmnk[1],
-
-
-
+                    cta_coord=block_in_cluster_coord_vmnk[1],
+                    cta_layout=cute.make_layout(
+                        cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
+                    ),
+                    src_tensor=tCgB,
+                    dst_tensor=sB,
+                    mcast_mask=b_mcast_mask,
+                    tma_desc_ptr=tma_desc_b_ptr,
                 )
+                copy_SFA, copy_SFB = None, None
                 if const_expr(self.blockscaled):
                     # TMA load SFA partition_S/D
-
-                    # ((atom_v, rest_v), STAGE)
-                    # ((atom_v, rest_v), RestK)
-                    tAsSFA, tAgSFA = cpasync.tma_partition(
+                    copy_SFA, _, _ = copy_utils.tma_get_copy_fn(
                         tma_atom_sfa,
-                        block_in_cluster_coord_vmnk[2],
-
-
-
+                        cta_coord=block_in_cluster_coord_vmnk[2],
+                        cta_layout=a_cta_layout,
+                        src_tensor=tCgSFA,
+                        dst_tensor=sSFA,
+                        filter_zeros=True,
+                        mcast_mask=sfa_mcast_mask,
+                        # tma_desc_ptr=tma_desc_sfa_ptr,
                     )
-                    tAsSFA = cute.filter_zeros(tAsSFA)
-                    tAgSFA = cute.filter_zeros(tAgSFA)
                     # TMA load SFB partition_S/D
                     sfb_cta_layout = cute.make_layout(
                         cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape
                     )
-
-                    # ((atom_v, rest_v), RestK)
-                    tBsSFB, tBgSFB = cpasync.tma_partition(
+                    copy_SFB, _, _ = copy_utils.tma_get_copy_fn(
                         tma_atom_sfb,
-                        block_in_cluster_coord_sfb_vmnk[1],
-                        sfb_cta_layout,
-
-
+                        cta_coord=block_in_cluster_coord_sfb_vmnk[1],
+                        cta_layout=sfb_cta_layout,
+                        src_tensor=tCgSFB,
+                        dst_tensor=sSFB,
+                        filter_zeros=True,
+                        mcast_mask=sfb_mcast_mask,
+                        # tma_desc_ptr=tma_desc_sfa_ptr,
                     )
-
-                    tBgSFB = cute.filter_zeros(tBgSFB)
-                else:
-                    tAsSFA, tAgSFA = None, None
-                    tBsSFB, tBgSFB = None, None
+                k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                 ab_producer_state = self.load_AB(
                     ab_pipeline,
                     ab_producer_state,
-
-
-
-
-
-                    tBgB,
-                    tBsB,
-                    b_mcast_mask,
-                    tma_atom_sfa,
-                    tAgSFA,
-                    tAsSFA,
-                    sfa_mcast_mask,
-                    tma_atom_sfb,
-                    tBgSFB,
-                    tBsSFB,
-                    sfb_mcast_mask,
+                    copy_A,
+                    copy_B,
+                    k_tile_cnt,
+                    copy_SFA,
+                    copy_SFB,
                 )
                 if const_expr(epi_load_barrier is not None):
                     # In the first work tile, the epi load warp will wait for the signal
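Note: the hunk above replaces the hand-rolled cpasync.tma_partition plus per-tile cute.copy plumbing with copy functions built by quack.copy_utils, so load_AB now receives ready-to-call copy_A/copy_B/copy_SFA/copy_SFB closures instead of a dozen partitioned tensors and masks. A sketch of that refactor pattern in general terms, using generic names rather than the real copy_utils API:

    # Before: every call site threads (atom, gmem view, smem view, mcast mask, ...) around.
    # After: partitioning happens once, and the per-tile copy is captured in a closure.
    def make_copy_fn(atom, gmem, smem, mcast_mask=None):
        # ... partition gmem/smem against the atom once here ...
        def copy_tile(k_tile, stage_idx):
            # each call then only names the source tile and the smem stage
            pass
        return copy_tile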
@@ -1033,19 +1071,180 @@ class GemmSm100(GemmSm90):
                         epi_load_barrier.arrive()
                     do_epi_load_barrier_arrive = Boolean(False)
                 # Advance to next tile
-                tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
                 tile_scheduler.advance_to_next_work()
                 work_tile = tile_scheduler.get_current_work()
             # Wait A/B buffer empty
             ab_pipeline.producer_tail(ab_producer_state)
-
-
+
+        if const_expr(self.gather_A):
+            if (
+                warp_idx >= self.ab_load_warp_id + 1
+                and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
+            ):
+                # Persistent tile scheduling loop
+                tile_scheduler = TileSchedulerCls()
+                work_tile = tile_scheduler.initial_work_tile_info()
+                ab_producer_state = pipeline.make_pipeline_state(
+                    pipeline.PipelineUserType.Producer, self.ab_stage
+                )
+                a_prefetch_consumer_state = pipeline.make_pipeline_state(
+                    pipeline.PipelineUserType.Consumer, self.a_prefetch_stage
+                )
+                while work_tile.is_valid_tile:
+                    tile_coord_mnkl = work_tile.tile_idx
+                    batch_idx = tile_coord_mnkl[3]
+                    # ///////////////////////////////////////////////////////////////////////////
+                    # Local_tile partition global tensors
+                    # ///////////////////////////////////////////////////////////////////////////
+                    mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
+                    if const_expr(varlen_m):
+                        # (M, K)
+                        mA_mk = mA_mkl
+                    else:
+                        assert varlen_k
+                        # (tile_M, K)
+                        mA_mk = cute.local_tile(
+                            mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
+                        )
+                    # Partition global tensor for TiledMMA_A/B/D
+                    len_m = varlen_manager.len_m(batch_idx)
+                    len_k = varlen_manager.len_k(batch_idx)
+                    # TMA load A partition_S/D
+                    tiled_copy_A = self._make_gmem_tiled_copy_A(
+                        mA_mkl.element_type, self.a_layout, (self.num_ab_load_warps - 1) * 32
+                    )
+                    tidx = cute.arch.thread_idx()[0] - (self.ab_load_warp_id + 1) * 32
+                    thr_copy_A = tiled_copy_A.get_slice(tidx)
+                    copy_A, prefetch_A = None, None
+                    if const_expr(varlen_m):
+                        a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
+                        copy_A = copy_utils.gather_m_get_copy_fn(
+                            thr_copy_A,
+                            mA_mk,
+                            sA,
+                            sAIdx[None, a_prefetch_consumer_state.index],
+                            limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+                            limit_k=len_k,
+                        )
+                        cute.arch.sync_warp()
+                        with cute.arch.elect_one():
+                            a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
+                        a_prefetch_consumer_state.advance()
+                    else:
+                        copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
+                            thr_copy_A,
+                            mA_mk,
+                            sA,
+                            sAIdx,
+                            limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+                            limit_k=len_k,
+                        )
+                        prefetch_A = partial(prefetch_A, a_prefetch_pipeline)
+                    k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
+                    ab_producer_state, a_prefetch_consumer_state = self.load_A_gather_A(
+                        ab_pipeline,
+                        ab_producer_state,
+                        a_prefetch_consumer_state,
+                        copy_A,
+                        prefetch_A,
+                        k_tile_cnt,
+                    )
+                    # Advance to next tile
+                    tile_scheduler.advance_to_next_work()
+                    work_tile = tile_scheduler.get_current_work()
+
+        #
+        # Specialized scheduler warp. Will also prefetch A indices if gatherA
+        #
+        if const_expr(tile_sched_params.tile_count_semaphore is not None or self.gather_A):
+            if warp_idx == self.scheduler_warp_id:
+                is_scheduler_warp = True
+                if const_expr(cute.size(cluster_layout_vmnk) > 1):
+                    is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0
+                tile_M = self.cta_tile_shape_mnk[0]
+                tile_K = self.cta_tile_shape_mnk[2]
+                thr_copy_AIdx, tAsAIdx, tAcAIdx = None, None, None
+                if const_expr(self.gather_A):
+                    tiled_copy_AIdx = copy_utils.tiled_copy_1d(Int32, num_threads=32, is_async=True)
+                    thr_copy_AIdx = tiled_copy_AIdx.get_slice(cute.arch.lane_idx())
+                    tAsAIdx = thr_copy_AIdx.partition_D(sAIdx)
+                    tAcAIdx = thr_copy_AIdx.partition_S(
+                        cute.make_identity_tensor(tile_M if varlen_m else tile_K)
+                    )
+                # Persistent tile scheduling loop
+                tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
+                work_tile = tile_scheduler.initial_work_tile_info()
+                a_prefetch_producer_state = None
+                if const_expr(self.gather_A):
+                    a_prefetch_producer_state = pipeline.make_pipeline_state(
+                        pipeline.PipelineUserType.Producer, self.a_prefetch_stage
+                    )
+                while work_tile.is_valid_tile:
+                    if const_expr(self.gather_A):
+                        tile_coord_mnkl = work_tile.tile_idx
+                        batch_idx = tile_coord_mnkl[3]
+                        mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
+                        if const_expr(varlen_m):
+                            # (tile_M,)
+                            gAIdx = cute.local_tile(mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],))
+                            tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
+                            len_m = varlen_manager.len_m(batch_idx)
+                            m_limit = len_m - tile_coord_mnkl[0] * tile_M
+                            tApAIdx_m = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
+                            for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
+                                tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit
+                            a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+                            cute.copy(
+                                thr_copy_AIdx,
+                                tAgAIdx,
+                                tAsAIdx[None, None, a_prefetch_producer_state.index],
+                                pred=tApAIdx_m,
+                            )
+                            a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+                            a_prefetch_producer_state.advance()
+                        else:
+                            # (tile_K, RestK)
+                            gAIdx = cute.flat_divide(mAIdx_mk, (tile_K,))
+                            tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
+                            len_k = varlen_manager.len_k(batch_idx)
+                            k_tile_cnt = cute.ceil_div(len_k, tile_K)
+                            for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
+                                a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+                                cute.copy(
+                                    thr_copy_AIdx,
+                                    tAgAIdx[None, None, k_tile],
+                                    tAsAIdx[None, None, a_prefetch_producer_state.index],
+                                )
+                                a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+                                a_prefetch_producer_state.advance()
+                            if 0 < k_tile_cnt:
+                                k_tile = k_tile_cnt - 1
+                                k_limit = len_k - k_tile * tile_K
+                                tApAIdx_k = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
+                                for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
+                                    tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit
+                                a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+                                cute.copy(
+                                    tiled_copy_AIdx,
+                                    tAgAIdx[None, None, k_tile],
+                                    tAsAIdx[None, None, a_prefetch_producer_state.index],
+                                    pred=tApAIdx_k,
+                                )
+                                a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+                                a_prefetch_producer_state.advance()
+                    # Advance to next tile
+                    tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
+                    tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+                    work_tile = tile_scheduler.get_current_work()
+                # End of persistent scheduler loop
+                if is_scheduler_warp:
+                    tile_scheduler.producer_tail()
 
         #
         # Specialized TMA epi load warp
         #
         if const_expr(mC_mnl is not None):
-            if warp_idx == self.
+            if warp_idx == self.epi_load_warp_id:
                 epi_producer_state = pipeline.make_pipeline_state(
                     pipeline.PipelineUserType.Producer, self.epi_c_stage
                 )
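Note: both index-prefetch paths above use the same discipline: every k- or m-tile except the last is copied without bounds checks, and only the final tile is predicated against the true length (m_limit / k_limit) taken from the varlen metadata. A plain-Python sketch of that control flow (illustrative only, not CuTe DSL):

    import numpy as np

    def copy_tiles(src: np.ndarray, tile: int) -> np.ndarray:
        k_tile_cnt = -(-len(src) // tile)  # ceil_div, as in the kernel
        dst = np.zeros(k_tile_cnt * tile, dtype=src.dtype)
        for t in range(k_tile_cnt - 1):  # full tiles, no predicate
            dst[t * tile:(t + 1) * tile] = src[t * tile:(t + 1) * tile]
        if 0 < k_tile_cnt:  # last tile, predicated on the remaining length
            t = k_tile_cnt - 1
            limit = len(src) - t * tile
            pred = np.arange(tile) < limit
            dst[t * tile:(t + 1) * tile][pred] = src[t * tile:]
        return dst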
@@ -1056,37 +1255,23 @@ class GemmSm100(GemmSm90):
                 while work_tile.is_valid_tile:
                     # Get tile coord from tile scheduler
                     tile_coord_mnkl = work_tile.tile_idx
-
-
-
-
-
-
-
-
-                    gC_mnl = cute.local_tile(
-                        mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
+                    batch_idx = tile_coord_mnkl[3]
+                    copy_C_fn, _, bGS_gC = self.epilog_gmem_copy_and_partition(
+                        tma_atom_c,
+                        varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
+                        self.cta_tile_shape_mnk[:2],
+                        epi_tile,
+                        sC,
+                        tile_coord_mnkl,
                     )
-
-                    # (MMA, MMA_M, MMA_N)
-                    tCgC = thr_mma.partition_C(gC_mnl)
-                    # bGS_gC has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
-                    bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
-                        tma_atom_c, tCgC, epi_tile, sC
-                    )
-                    bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
+                    copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
                     if do_epi_load_barrier_wait:
                         epi_load_barrier.arrive_and_wait()
                         do_epi_load_barrier_wait = Boolean(False)
                     epi_tile_num = const_expr(cute.size(bGS_gC, mode=[1]))
-                    for
+                    for epi_idx in cutlass.range(epi_tile_num, unroll=1):
                         epi_pipeline.producer_acquire(epi_producer_state)
-
-                            tma_atom_c,
-                            bGS_gC[None, subtile_idx],
-                            bGS_sC[None, epi_producer_state.index],
-                            tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
-                        )
+                        copy_C(src_idx=epi_idx, producer_state=epi_producer_state)
                         # Epi pipeline's producer commit is a NOP
                         epi_pipeline.producer_commit(epi_producer_state)
                         epi_producer_state.advance()
@@ -1107,7 +1292,7 @@ class GemmSm100(GemmSm90):
             )
             # Partition shared/tensor memory tensor for TiledMMA_A/B/D
             # (MMA, MMA_M, MMA_K, STAGE)
-            tCrA = tiled_mma.make_fragment_A(
+            tCrA = tiled_mma.make_fragment_A(sA_mma)
             # (MMA, MMA_N, MMA_K, STAGE)
             tCrB = tiled_mma.make_fragment_B(sB)
             # (MMA, MMA_M, MMA_N, STAGE)
@@ -1154,10 +1339,10 @@ class GemmSm100(GemmSm90):
                     tCtSFB_compact_s2t,
                 ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
             else:
+                tCtSFA, tCtSFB = None, None
                 tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None
                 tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None
 
-            k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.mma_tiler[2])
             # Persistent tile scheduling loop
             tile_scheduler = TileSchedulerCls()
             work_tile = tile_scheduler.initial_work_tile_info()
@@ -1170,6 +1355,9 @@ class GemmSm100(GemmSm90):
             while work_tile.is_valid_tile:
                 # Get tile coord from tile scheduler
                 tile_coord_mnkl = work_tile.tile_idx
+                batch_idx = tile_coord_mnkl[3]
+                k_len = varlen_manager.len_k(batch_idx)
+                k_tile_cnt = cute.ceil_div(k_len, self.mma_tiler[2])
                 # Set tensor memory buffer for current tile
                 # (MMA, MMA_M, MMA_N)
                 tCtAcc = tCtAcc_base[None, None, None, acc_producer_state.index]
@@ -1184,6 +1372,9 @@ class GemmSm100(GemmSm90):
                     tCtAcc,
                     k_tile_cnt,
                     is_leader_cta,
+                    cta_rank_in_cluster,
+                    tCtSFA,
+                    tCtSFB,
                     tiled_copy_s2t_sfa,
                     tiled_copy_s2t_sfb,
                     tCsSFA_compact_s2t,
@@ -1209,6 +1400,14 @@ class GemmSm100(GemmSm90):
             )
             # Bar sync for retrieve tensor memory ptr from shared memory
             tmem_alloc_barrier.arrive_and_wait()
+
+            is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
+            varlen_manager.init_tensormap_epi(
+                tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
+            )
+            tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
+            tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
+
             # Retrieving tensor memory ptr and make accumulator tensor
             acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
                 self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
@@ -1221,44 +1420,22 @@ class GemmSm100(GemmSm90):
                 num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
             )
 
-            is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
-            if const_expr(varlen_m):
-                # initialize tensormap for D
-                if const_expr(has_D):
-                    tensormap_manager.init_tensormap_from_atom(
-                        tma_atom_d,
-                        tensormap_d_ptr,
-                        is_manager_warp=is_tma_warp,
-                    )
-                for tma_atom, tensormap_epi_ptr in zip(
-                    self.epi_get_tma_atoms(epilogue_params), tensormap_epi_ptrs
-                ):
-                    tensormap_manager.init_tensormap_from_atom(
-                        tma_atom,
-                        tensormap_epi_ptr,
-                        is_manager_warp=is_tma_warp,
-                    )
-
             # Partition for epilogue
             epi_tidx = tidx
             tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
                 epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
             )
 
-            tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.
-            tiled_copy_r2s, tRS_rD, tRS_sD = self.
-                tiled_copy_t2r, tTR_rD,
+            tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.acc_dtype)
+            tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
+                tiled_copy_t2r, self.d_layout, self.d_dtype, tTR_rD, sD, epi_tidx
             )
-            tRS_rC, tSR_rC = None, None
+            tRS_rC, tSR_rC, tSR_sC = None, None, None
+            tiled_copy_s2r = None
             if const_expr(mC_mnl is not None):
-
-
-                    tiled_copy_t2r, tTR_rC, epi_tidx, sC
+                tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
+                    tiled_copy_t2r, self.c_layout, self.c_dtype, sC, tRS_rD.layout, epi_tidx
                 )
-                # TODO: for m major, D is being stored w STSM so we'd need LDSM here
-                # tRS_rC = tSR_rC  # TODO: retile?
-                tRS_rC = cute.make_fragment(tRS_rD.layout, self.c_dtype)
-                tSR_rC = tiled_copy_s2r.get_slice(epi_tidx).retile(tRS_rC)
 
             # Persistent tile scheduling loop
             tile_scheduler = TileSchedulerCls()
@@ -1272,42 +1449,21 @@ class GemmSm100(GemmSm90):
             )
             if const_expr(varlen_m):
                 # wait tensormap initialization complete before update
-
-                # batch index of last tile
-                last_batch_idx = cutlass.Int32(-1)
+                varlen_manager.fence_tensormap_init()
             while work_tile.is_valid_tile:
                 # Get tile coord from tile scheduler
                 tile_coord_mnkl = work_tile.tile_idx
                 batch_idx = tile_coord_mnkl[3]
-
-
-                    last_batch_idx = batch_idx
-                    if is_group_changed:
-                        self.tensormap_update_D_epi(
-                            tensormap_manager,
-                            tensormap_d_ptr,
-                            tensormap_epi_ptrs,
-                            epilogue_params,
-                            cu_seqlens_m,
-                            batch_idx,
-                            is_manager_warp=is_tma_warp,
-                        )
-
-                mma_tile_coord_mnl = (
-                    tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
-                    tile_coord_mnkl[1],
-                    tile_coord_mnkl[3],
+                epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
+                    epilogue_params, varlen_params.cu_seqlens_m, batch_idx
                 )
-
-
-
-
+                varlen_manager.update_tensormap_epi(
+                    batch_idx,
+                    self.d_layout,
+                    epi_shapes,
+                    epi_orders,
+                    is_tma_warp,
                 )
-                # Partition global tensor for TiledMMA_A/B/D
-                # (MMA, MMA_M, MMA_N)
-                tDgD = thr_mma.partition_C(gD_mnl)
-                # bSG_gD has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
-                bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(tma_atom_d, tDgD, epi_tile, sD)
 
                 # Set tensor memory buffer for current tile
                 # (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
@@ -1316,67 +1472,59 @@ class GemmSm100(GemmSm90):
                 # Wait for accumulator buffer full
                 acc_pipeline.consumer_wait(acc_consumer_state)
 
-
-                if const_expr(varlen_m):
-                    # ensure the update to tensormap has completed before using it
-                    if is_group_changed and is_tma_warp:
-                        if const_expr(has_D):
-                            tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
-                        for tensormap_epi_ptr in tensormap_epi_ptrs:
-                            tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
-                    if const_expr(has_D):
-                        tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
-                            tensormap_d_ptr, cute.AddressSpace.generic
-                        )
-                    tma_desc_epi_ptrs = [
-                        tensormap_manager.get_tensormap_ptr(
-                            tensormap_epi_ptr, cute.AddressSpace.generic
-                        )
-                        for tensormap_epi_ptr in tensormap_epi_ptrs
-                    ]
+                varlen_manager.fence_tensormap_update_epi(is_tma_warp)
 
-
-
-
-
-
-
-
-
-
-
-                        # Convert to D type
-                        acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
-                        if const_expr(mC_mnl is not None):
-                            epi_pipeline.consumer_wait(epi_read_state)
-                            cute.copy(
-                                tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
-                            )
-                            # Fence to make sure shared memory read is visible to TMA load
-                            cute.arch.fence_proxy(
-                                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-                            )
-                            cute.arch.sync_warp()
-                            with cute.arch.elect_one():
-                                epi_pipeline.consumer_release(epi_read_state)
-                            epi_read_state.advance()
-                            acc_vec = acc_vec + tRS_rC.load().to(self.acc_dtype)
-                        tRS_rD.store(acc_vec.to(self.d_dtype))
-                        # Store D to shared memory
-                        d_buffer = (num_prev_subtiles + subtile_idx) % self.epi_stage
-                        cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
-                        # Fence and barrier to make sure shared memory store is visible to TMA store
-                        cute.arch.fence_proxy(
-                            cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                copy_D = None
+                if const_expr(has_D):
+                    copy_D, _, _ = self.epilog_gmem_copy_and_partition(
+                        tma_atom_d,
+                        varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
+                        self.cta_tile_shape_mnk[:2],
+                        epi_tile,
+                        sD,
+                        tile_coord_mnkl,
+                        tma_desc_ptr=tma_desc_d_ptr,
                     )
-
-
-
-
-
-
-
-
+                copy_C = None  # We're using a separate warp to load C
+
+                tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
+                k_len = varlen_manager.len_k(batch_idx)
+                load_acc_subtile = partial(
+                    self.epi_load_acc_subtile,
+                    tiled_copy_t2r,
+                    tiled_copy_r2s,
+                    tTR_tAcc,
+                    tTR_rAcc,
+                    clear_acc=varlen_k and k_len == 0,
+                )
+
+                epi_read_state, _ = self.epilogue(
+                    epilogue_params,
+                    epi_smem_tensors,
+                    tma_desc_epi_ptrs,
+                    epi_pipeline,
+                    epi_store_pipeline,
+                    epi_read_state,
+                    None,  # epi_producer_state
+                    epi_tile,
+                    load_acc_subtile,
+                    tRS_rD,
+                    tRS_rC,
+                    tiled_copy_t2r,
+                    tiled_copy_r2s,
+                    tRS_sD,
+                    tiled_copy_s2r,
+                    tSR_rC,
+                    tSR_sC,
+                    copy_D,
+                    copy_C,
+                    tile_coord_mnkl,
+                    varlen_manager,
+                    epilogue_barrier,
+                    tile_scheduler,
+                    epi_tidx,
+                    is_tma_warp,
+                )
 
                 # Async arrive accumulator buffer empty
                 with cute.arch.elect_one():
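Note: in the rewritten epilogue above, the per-subtile accumulator load is bound once with functools.partial and handed to the shared self.epilogue path, so only the destination fragment and subtile index are supplied later. A minimal sketch of that binding pattern with placeholder arguments (the real arguments are the tiled copies and tensors shown above):

    from functools import partial

    def epi_load_acc_subtile(copy_t2r, copy_r2s, tTR_tAcc, tTR_rAcc, tRS_rD, epi_idx, clear_acc=False):
        ...  # load one accumulator subtile (see the method added at the end of this diff)

    load_acc_subtile = partial(
        epi_load_acc_subtile, "copy_t2r", "copy_r2s", "tTR_tAcc", "tTR_rAcc", clear_acc=False
    )
    load_acc_subtile("tRS_rD", 0)  # the epilogue supplies only tRS_rD and epi_idx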
@@ -1404,79 +1552,50 @@ class GemmSm100(GemmSm90):
         epi_store_pipeline.producer_tail()
 
     @cute.jit
-    def
+    def load_A_gather_A(
         self,
-
-
-
-
-
-
-
-        tBgB: cute.Tensor,
-        tBsB: cute.Tensor,
-        b_mcast_mask: cutlass.Int16,
-        tma_atom_sfa: Optional[cute.CopyAtom] = None,
-        tAgSFA: Optional[cute.Tensor] = None,
-        tAsSFA: Optional[cute.Tensor] = None,
-        sfa_mcast_mask: Optional[cutlass.Int16] = None,
-        tma_atom_sfb: Optional[cute.CopyAtom] = None,
-        tBgSFB: Optional[cute.Tensor] = None,
-        tBsSFB: Optional[cute.Tensor] = None,
-        sfb_mcast_mask: Optional[cutlass.Int16] = None,
-    ) -> cutlass.pipeline.PipelineState:
-        blockscaled = const_expr(tma_atom_sfa is not None)
-        if const_expr(blockscaled):
-            assert all(x is not None for x in (tma_atom_sfa, tAgSFA, tAsSFA))
-            assert all(x is not None for x in (tma_atom_sfb, tBgSFB, tBsSFB))
-        k_tile_cnt = cute.size(tAgA, mode=[1])
+        a_pipeline: cutlass.pipeline.PipelineAsync,
+        a_producer_state: cutlass.pipeline.PipelineState,
+        a_prefetch_consumer_state: Optional[cutlass.pipeline.PipelineState],
+        copy_A: Callable,
+        prefetch_A: Optional[Callable],
+        k_tile_cnt: Int32,
+    ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]:
         # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
-
+        peek_a_empty_status = Boolean(True)
         if 0 < k_tile_cnt:
-
+            peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
         # /////////////////////////////////////////////////////////////////////////
-        #
+        # cp.async on A
         # /////////////////////////////////////////////////////////////////////////
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-                tBgB[None, k_tile],
-                tBsB[None, ab_producer_state.index],
-                tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
-                mcast_mask=b_mcast_mask,
-            )
-            if const_expr(blockscaled):
-                cute.copy(
-                    tma_atom_sfa,
-                    tAgSFA[None, ab_producer_state.count],
-                    tAsSFA[None, ab_producer_state.index],
-                    tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
-                    mcast_mask=sfa_mcast_mask,
-                )
-                cute.copy(
-                    tma_atom_sfb,
-                    tBgSFB[None, ab_producer_state.count],
-                    tBsSFB[None, ab_producer_state.index],
-                    tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
-                    mcast_mask=sfb_mcast_mask,
-                )
-            # Mainloop pipeline's producer commit is a NOP
-            ab_pipeline.producer_commit(ab_producer_state)
-            ab_producer_state.advance()
-            peek_ab_empty_status = Boolean(True)
+        is_tma_warp = False
+        for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
+            smem_idx = a_producer_state.index
+            prefetch_out = ()
+            if const_expr(prefetch_A is not None):  # Prefetch early, even before smem is free
+                prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),)
+                a_prefetch_consumer_state.advance()
+            a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp)
+            copy_A(k_tile, smem_idx, *prefetch_out)
+            # This tells mbarrier to track the completion of cp.async
+            a_pipeline.producer_cpasync_commit(a_producer_state)
+            a_producer_state.advance()
+            peek_a_empty_status = Boolean(True)
             if k_tile + 1 < k_tile_cnt:
-
-
+                peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
+        # bound checking in the K dimension on the last k_tile
+        if 0 < k_tile_cnt:
+            k_tile = k_tile_cnt - 1
+            smem_idx = a_producer_state.index
+            prefetch_out = ()
+            if const_expr(prefetch_A is not None):  # Prefetch early, even before smem is free
+                prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True),)
+                a_prefetch_consumer_state.advance()
+            a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp)
+            copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
+            a_pipeline.producer_cpasync_commit(a_producer_state)
+            a_producer_state.advance()
+        return a_producer_state, a_prefetch_consumer_state
 
     @cute.jit
     def mma(
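Note: load_A_gather_A follows the usual producer-side ring-buffer discipline: try_acquire to peek, acquire before touching a stage, issue the cp.async copies, commit so the mbarrier tracks their completion, then advance the state. A host-side analogy with semaphores, just to make the handshake explicit (illustrative only; the real synchronization is mbarrier-based on the GPU):

    import threading

    class RingPipeline:
        def __init__(self, stages: int):
            self.empty = threading.Semaphore(stages)  # producer_acquire waits on this
            self.full = threading.Semaphore(0)        # consumer_wait waits on this

        def producer_acquire(self):  self.empty.acquire()
        def producer_commit(self):   self.full.release()
        def consumer_wait(self):     self.full.acquire()
        def consumer_release(self):  self.empty.release()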
@@ -1491,6 +1610,9 @@ class GemmSm100(GemmSm90):
         acc: cute.Tensor,
         k_tile_cnt: Int32,
         is_leader_cta: Boolean,
+        cta_rank_in_cluster: Int32,
+        tCtSFA: Optional[cute.Tensor] = None,
+        tCtSFB: Optional[cute.Tensor] = None,
         tiled_copy_s2t_sfa: Optional[cute.TiledCopy] = None,
         tiled_copy_s2t_sfb: Optional[cute.TiledCopy] = None,
         tCsSFA_compact_s2t: Optional[cute.Tensor] = None,
@@ -1500,12 +1622,17 @@ class GemmSm100(GemmSm90):
     ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma]:
         blockscaled = const_expr(tiled_copy_s2t_sfa is not None)
         if const_expr(blockscaled):
+            assert all(x is not None for x in (tCtSFA, tCtSFB))
             assert all(x is not None for x in (tiled_copy_s2t_sfa, tiled_copy_s2t_sfb))
             assert all(x is not None for x in (tCsSFA_compact_s2t, tCsSFB_compact_s2t))
             assert all(x is not None for x in (tCtSFA_compact_s2t, tCtSFB_compact_s2t))
+        # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will
+        # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader
+        # CTA will wait for that then arrive at the mbarrier on the leader CTA.
+        need_nonleader_cta = const_expr(self.gather_A and self.use_2cta_instrs)
         # Peek (try_wait) AB buffer full for k_tile = 0
         peek_ab_full_status = Boolean(True)
-        if 0 < k_tile_cnt and is_leader_cta:
+        if 0 < k_tile_cnt and (is_leader_cta or need_nonleader_cta):
             peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
         # Wait for accumulator buffer empty
         if is_leader_cta:
@@ -1515,6 +1642,14 @@ class GemmSm100(GemmSm90):
|
|
|
1515
1642
|
# Mma mainloop
|
|
1516
1643
|
num_k_blocks = cute.size(tCrA, mode=[2])
|
|
1517
1644
|
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
|
|
1645
|
+
if const_expr(need_nonleader_cta):
|
|
1646
|
+
if not is_leader_cta:
|
|
1647
|
+
ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
|
|
1648
|
+
with cute.arch.elect_one():
|
|
1649
|
+
# The odd CTA signals the even CTA
|
|
1650
|
+
ab_pipeline.sync_object_full.arrive_mbarrier(
|
|
1651
|
+
ab_consumer_state.index, dst_rank=cta_rank_in_cluster & 0xFE
|
|
1652
|
+
)
|
|
1518
1653
|
if is_leader_cta:
|
|
1519
1654
|
# Conditionally wait for AB buffer full
|
|
1520
1655
|
ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
|
|
@@ -1527,6 +1662,11 @@ class GemmSm100(GemmSm90):
|
|
|
1527
1662
|
cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t)
|
|
1528
1663
|
for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
|
|
1529
1664
|
k_blk_coord = (None, None, k_blk_idx, ab_consumer_state.index)
|
|
1665
|
+
if const_expr(blockscaled):
|
|
1666
|
+
# Set SFA/SFB tensor to tiled_mma
|
|
1667
|
+
sf_kblock_coord = (None, None, k_blk_idx)
|
|
1668
|
+
tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
|
|
1669
|
+
tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator)
|
|
1530
1670
|
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
|
|
1531
1671
|
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
|
1532
1672
|
# Async arrive AB buffer empty
|
|
@@ -1534,7 +1674,7 @@ class GemmSm100(GemmSm90):
|
|
|
1534
1674
|
ab_consumer_state.advance()
|
|
1535
1675
|
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
|
1536
1676
|
peek_ab_full_status = Boolean(True)
|
|
1537
|
-
if k_tile + 1 < k_tile_cnt and is_leader_cta:
|
|
1677
|
+
if k_tile + 1 < k_tile_cnt and (is_leader_cta or need_nonleader_cta):
|
|
1538
1678
|
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
|
|
1539
1679
|
# Async arrive accumulator buffer full
|
|
1540
1680
|
if is_leader_cta:
|
|
@@ -1544,6 +1684,25 @@ class GemmSm100(GemmSm90):
|
|
|
1544
1684
|
# "operand #0 does not dominate this use"
|
|
1545
1685
|
return ab_consumer_state, acc_producer_state, tiled_mma
|
|
1546
1686
|
|
|
1687
|
+
@cute.jit
|
|
1688
|
+
def epi_load_acc_subtile(
|
|
1689
|
+
self,
|
|
1690
|
+
tiled_copy_t2r: cute.TiledCopy,
|
|
1691
|
+
tiled_copy_r2s: cute.TiledCopy,
|
|
1692
|
+
tTR_tAcc: cute.Tensor,
|
|
1693
|
+
tTR_rAcc: cute.Tensor,
|
|
1694
|
+
tRS_rD: cute.Tensor,
|
|
1695
|
+
epi_idx: int,
|
|
1696
|
+
clear_acc: Boolean = False,
|
|
1697
|
+
):
|
|
1698
|
+
if not clear_acc:
|
|
1699
|
+
# Load accumulator from tensor memory buffer to register
|
|
1700
|
+
cute.copy(tiled_copy_t2r, tTR_tAcc[None, None, None, epi_idx], tTR_rAcc)
|
|
1701
|
+
tRS_rAcc = tiled_copy_r2s.retile(tTR_rAcc)
|
|
1702
|
+
tRS_rD.store(tRS_rAcc.load())
|
|
1703
|
+
else:
|
|
1704
|
+
tRS_rD.fill(0.0)
|
|
1705
|
+
|
|
1547
1706
|
def mainloop_s2t_copy_and_partition(
|
|
1548
1707
|
self,
|
|
1549
1708
|
sSF: cute.Tensor,
|
|
@@ -1607,8 +1766,8 @@ class GemmSm100(GemmSm90):
|
|
|
1607
1766
|
# Make tiledCopy for tensor memory load
|
|
1608
1767
|
copy_atom_t2r = sm100_utils.get_tmem_load_op(
|
|
1609
1768
|
self.cta_tile_shape_mnk,
|
|
1610
|
-
self.d_layout,
|
|
1611
|
-
self.d_dtype,
|
|
1769
|
+
self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR,
|
|
1770
|
+
self.d_dtype if self.d_dtype is not None else cutlass.BFloat16,
|
|
1612
1771
|
self.acc_dtype,
|
|
1613
1772
|
epi_tile,
|
|
1614
1773
|
use_2cta_instrs,
|
|
@@ -1631,12 +1790,14 @@ class GemmSm100(GemmSm90):
|
|
|
1631
1790
|
tTR_rAcc = cute.make_fragment(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
|
|
1632
1791
|
return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
|
|
1633
1792
|
|
|
1634
|
-
def
|
|
1793
|
+
def epilog_smem_store_and_partition(
|
|
1635
1794
|
self,
|
|
1636
1795
|
tiled_copy_t2r: cute.TiledCopy,
|
|
1796
|
+
d_layout: Optional[LayoutEnum],
|
|
1797
|
+
dtype: Optional[Type[cutlass.Numeric]],
|
|
1637
1798
|
tTR_rD: cute.Tensor,
|
|
1638
|
-
tidx: Int32,
|
|
1639
1799
|
sD: cute.Tensor,
|
|
1800
|
+
tidx: Int32,
|
|
1640
1801
|
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
1641
1802
|
"""
|
|
1642
1803
|
Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
|
|
@@ -1658,83 +1819,106 @@ class GemmSm100(GemmSm90):
|
|
|
1658
1819
|
:rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
|
|
1659
1820
|
"""
|
|
1660
1821
|
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
|
1661
|
-
|
|
1822
|
+
d_layout if d_layout is not None else LayoutEnum.ROW_MAJOR,
|
|
1823
|
+
dtype if dtype is not None else cutlass.BFloat16,
|
|
1824
|
+
self.acc_dtype,
|
|
1825
|
+
tiled_copy_t2r,
|
|
1662
1826
|
)
|
|
1663
1827
|
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
|
1664
1828
|
# (R2S, R2S_M, R2S_N, PIPE_D)
|
|
1665
1829
|
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
|
1666
|
-
tRS_sD = thr_copy_r2s.partition_D(sD)
|
|
1830
|
+
tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
|
|
1667
1831
|
# (R2S, R2S_M, R2S_N)
|
|
1668
1832
|
tRS_rD = tiled_copy_r2s.retile(tTR_rD)
|
|
1669
1833
|
return tiled_copy_r2s, tRS_rD, tRS_sD
|
|
1670
1834
|
|
|
1671
|
-
|
|
1672
|
-
# self,
|
|
1673
|
-
# tiled_copy_t2r: cute.TiledCopy,
|
|
1674
|
-
# tTR_rC: cute.Tensor,
|
|
1675
|
-
# tidx: Int32,
|
|
1676
|
-
# sC: cute.Tensor,
|
|
1677
|
-
# ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
1678
|
-
# copy_atom_s2r = cute.make_copy_atom(
|
|
1679
|
-
# warp.LdMatrix8x8x16bOp(self.c_layout.is_m_major_c(), num_matrices=4),
|
|
1680
|
-
# self.c_dtype, # TODO: this probably only works for f16 for now?
|
|
1681
|
-
# )
|
|
1682
|
-
# # copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
|
|
1683
|
-
# tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
|
|
1684
|
-
# # (R2S, R2S_M, R2S_N, PIPE_D)
|
|
1685
|
-
# thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
|
|
1686
|
-
# # (R2S, R2S_M, R2S_N)
|
|
1687
|
-
# tSR_sC = thr_copy_s2r.partition_S(sC)
|
|
1688
|
-
# return tiled_copy_s2r, tSR_sC
|
|
1689
|
-
|
|
1690
|
-
def epilog_gmem_copy_and_partition(
|
|
1835
|
+
def epilog_smem_load_and_partition(
|
|
1691
1836
|
self,
|
|
1692
|
-
|
|
1693
|
-
|
|
1694
|
-
|
|
1695
|
-
|
|
1696
|
-
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
|
|
1700
|
-
|
|
1701
|
-
|
|
1702
|
-
|
|
1703
|
-
|
|
1704
|
-
|
|
1705
|
-
|
|
1706
|
-
|
|
1707
|
-
|
|
1708
|
-
|
|
1837
|
+
tiled_copy_t2r: cute.TiledCopy,
|
|
1838
|
+
c_layout: LayoutEnum,
|
|
1839
|
+
dtype: Type[cutlass.Numeric],
|
|
1840
|
+
# tTR_rC: cute.Tensor,
|
|
1841
|
+
sC: cute.Tensor,
|
|
1842
|
+
tRS_rD_layout: cutlass.Layout,
|
|
1843
|
+
tidx: Int32,
|
|
1844
|
+
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
1845
|
+
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
|
1846
|
+
c_layout, dtype, self.acc_dtype, tiled_copy_t2r
|
|
1847
|
+
)
|
|
1848
|
+
store_op = copy_atom_r2s.op
|
|
1849
|
+
# m8n8 16-bit path
|
|
1850
|
+
if isinstance(store_op, StMatrix8x8x16bOp):
|
|
1851
|
+
op = LdMatrix8x8x16bOp(num_matrices=store_op.num_matrices, transpose=store_op.transpose)
|
|
1852
|
+
# m16n8 8-bit store -> m16n16 8-bit load
|
|
1853
|
+
elif isinstance(store_op, StMatrix16x8x8bOp) and store_op.num_matrices in [2, 4]:
|
|
1854
|
+
# transpose=True is enforced by the class
|
|
1855
|
+
op = LdMatrix16x16x8bOp(num_matrices=store_op.num_matrices // 2)
|
|
1856
|
+
else:
|
|
1857
|
+
op = cute.nvgpu.CopyUniversalOp()
|
|
1858
|
+
copy_atom_s2r = cute.make_copy_atom(op, dtype)
|
|
1859
|
+
tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
|
|
1860
|
+
thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
|
|
1861
|
+
# (R2S, R2S_M, R2S_N, PIPE_D)
|
|
1862
|
+
tSR_sC = thr_copy_s2r.partition_S(sC)
|
|
1863
|
+
tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
|
|
1864
|
+
# (R2S, R2S_M, R2S_N)
|
|
1865
|
+
tSR_rC = tiled_copy_s2r.retile(tRS_rC)
|
|
1866
|
+
return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
|
|
1709
1867
|
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
|
|
1720
|
-
|
|
1721
|
-
#
|
|
1722
|
-
|
|
1723
|
-
|
|
1724
|
-
|
|
1725
|
-
|
|
1726
|
-
|
|
1727
|
-
|
|
1728
|
-
|
|
1729
|
-
|
|
1730
|
-
|
|
1731
|
-
|
|
1868
|
+
@cute.jit
|
|
1869
|
+
def make_ab_pipeline(
|
|
1870
|
+
self,
|
|
1871
|
+
tiled_mma: cute.TiledMma,
|
|
1872
|
+
cluster_layout_vmnk: cute.Layout,
|
|
1873
|
+
ab_pipeline_mbar_ptr: cute.Pointer,
|
|
1874
|
+
is_leader_cta: Boolean,
|
|
1875
|
+
) -> pipeline.PipelineAsync:
|
|
1876
|
+
# If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will
|
|
1877
|
+
# arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader
|
|
1878
|
+
# CTA will wait for that then arrive at the mbarrier on the leader CTA.
|
|
1879
|
+
# The producer count for the leader CTA is 1 (TMA) + num_cpasync_threads
|
|
1880
|
+
# + 1 (from non-leader CTA).
|
|
1881
|
+
# The producer count for the non-leader CTA is num_cpasync_threads
|
|
1882
|
+
# (TMA doesn't arrive there).
|
|
1883
|
+
if const_expr(not self.gather_A):
|
|
1884
|
+
producer_cnt = 1
|
|
1885
|
+
else:
|
|
1886
|
+
producer_cnt = (self.num_ab_load_warps - 1) * 32 + (
|
|
1887
|
+
1 if const_expr(not self.use_2cta_instrs) else 2
|
|
1888
|
+
)
|
|
1889
|
+
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
|
|
1890
|
+
# Each warp will contribute to the arrive count with the number of mcast size
|
|
1891
|
+
mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
|
|
1892
|
+
consumer_arrive_cnt = mcast_size
|
|
1893
|
+
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
1894
|
+
pipeline.Agent.Thread, consumer_arrive_cnt
|
|
1732
1895
|
)
|
|
1733
|
-
|
|
1896
|
+
if const_expr(not self.gather_A):
|
|
1897
|
+
pipeline_ab = pipeline.PipelineTmaUmma.create(
|
|
1898
|
+
barrier_storage=ab_pipeline_mbar_ptr,
|
|
1899
|
+
num_stages=self.ab_stage,
|
|
1900
|
+
producer_group=ab_pipeline_producer_group,
|
|
1901
|
+
consumer_group=ab_pipeline_consumer_group,
|
|
1902
|
+
tx_count=self.num_tma_load_bytes,
|
|
1903
|
+
cta_layout_vmnk=cluster_layout_vmnk,
|
|
1904
|
+
)
|
|
1905
|
+
else:
|
|
1906
|
+
pipeline_ab = PipelineTmaCpAsyncUmma.create(
|
|
1907
|
+
barrier_storage=ab_pipeline_mbar_ptr,
|
|
1908
|
+
num_stages=self.ab_stage,
|
|
1909
|
+
producer_group=ab_pipeline_producer_group,
|
|
1910
|
+
consumer_group=ab_pipeline_consumer_group,
|
|
1911
|
+
tx_count=self.num_tma_load_bytes,
|
|
1912
|
+
cta_layout_vmnk=cluster_layout_vmnk,
|
|
1913
|
+
producer_drop_count=None
|
|
1914
|
+
if not self.use_2cta_instrs
|
|
1915
|
+
else (2 if not is_leader_cta else 0),
|
|
1916
|
+
)
|
|
1917
|
+
return pipeline_ab
|
|
1734
1918
|
|
|
1735
1919
|
def make_acc_pipeline(
|
|
1736
1920
|
self, cluster_layout_vmnk: cute.Layout, acc_pipeline_mbar_ptr: cute.Pointer
|
|
1737
|
-
):
|
|
1921
|
+
) -> pipeline.PipelineAsync:
|
|
1738
1922
|
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
1739
1923
|
num_acc_consumer_threads = self.num_epi_warps * (2 if self.use_2cta_instrs else 1)
|
|
1740
1924
|
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
@@ -1748,19 +1932,70 @@ class GemmSm100(GemmSm90):
|
|
|
1748
1932
|
cta_layout_vmnk=cluster_layout_vmnk,
|
|
1749
1933
|
)
|
|
1750
1934
|
|
|
1751
|
-
|
|
1935
|
+
def make_sched_pipeline(
|
|
1936
|
+
self,
|
|
1937
|
+
cluster_layout_mnk: cute.Layout,
|
|
1938
|
+
sched_pipeline_mbar_ptr: cute.Pointer,
|
|
1939
|
+
has_C: bool = False,
|
|
1940
|
+
) -> pipeline.PipelineAsync:
|
|
1941
|
+
# Threads/warps participating in this pipeline
|
|
1942
|
+
sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
1943
|
+
cluster_size = cute.size(cluster_layout_mnk)
|
|
1944
|
+
# Each warp that are not the scheduler warp will contribute 1 to the arrive count
|
|
1945
|
+
warps_per_cta = self.num_ab_load_warps + len(
|
|
1946
|
+
(self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id)
|
|
1947
|
+
)
|
|
1948
|
+
if has_C:
|
|
1949
|
+
warps_per_cta += 1
|
|
1950
|
+
consumer_arrive_cnt = warps_per_cta * cluster_size - 1
|
|
1951
|
+
sched_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
1952
|
+
pipeline.Agent.Thread, consumer_arrive_cnt
|
|
1953
|
+
)
|
|
1954
|
+
return pipeline.PipelineAsync.create(
|
|
1955
|
+
barrier_storage=sched_pipeline_mbar_ptr,
|
|
1956
|
+
num_stages=self.sched_stage,
|
|
1957
|
+
producer_group=sched_pipeline_producer_group,
|
|
1958
|
+
consumer_group=sched_pipeline_consumer_group,
|
|
1959
|
+
# If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
|
|
1960
|
+
consumer_mask=None if const_expr(cluster_size == 1) else 0,
|
|
1961
|
+
)
|
|
1962
|
+
|
|
1963
|
+
@cute.jit
|
|
1964
|
+
def make_a_prefetch_pipeline(
|
|
1965
|
+
self, a_prefetch_pipeline_mbar_ptr: cute.Pointer
|
|
1966
|
+
) -> pipeline.PipelineAsync:
|
|
1967
|
+
producer_cnt = 32
|
|
1968
|
+
a_prefetch_producer_group = pipeline.CooperativeGroup(
|
|
1969
|
+
pipeline.Agent.Thread, producer_cnt, alignment=producer_cnt
|
|
1970
|
+
)
|
|
1971
|
+
consumer_arrive_cnt = self.num_ab_load_warps - 1
|
|
1972
|
+
a_prefetch_consumer_group = pipeline.CooperativeGroup(
|
|
1973
|
+
pipeline.Agent.Thread, consumer_arrive_cnt
|
|
1974
|
+
)
|
|
1975
|
+
return pipeline.PipelineCpAsync.create(
|
|
1976
|
+
barrier_storage=a_prefetch_pipeline_mbar_ptr,
|
|
1977
|
+
num_stages=self.a_prefetch_stage,
|
|
1978
|
+
producer_group=a_prefetch_producer_group,
|
|
1979
|
+
consumer_group=a_prefetch_consumer_group,
|
|
1980
|
+
)
|
|
1981
|
+
|
|
1982
|
+
@classmethod
|
|
1752
1983
|
def _compute_stages(
|
|
1984
|
+
cls,
|
|
1753
1985
|
tiled_mma: cute.TiledMma,
|
|
1754
1986
|
mma_tiler_mnk: Tuple[int, int, int],
|
|
1987
|
+
cta_tile_shape_mnk: Tuple[int, int, int],
|
|
1988
|
+
epi_tile: cute.Tile,
|
|
1755
1989
|
a_dtype: Type[cutlass.Numeric],
|
|
1756
1990
|
b_dtype: Type[cutlass.Numeric],
|
|
1757
|
-
epi_tile: cute.Tile,
|
|
1758
|
-
d_dtype: Type[cutlass.Numeric],
|
|
1759
|
-
c_dtype: Optional[Type[cutlass.Numeric]],
|
|
1760
|
-
d_layout: LayoutEnum,
|
|
1761
|
-
c_layout: Optional[LayoutEnum],
|
|
1762
1991
|
sf_dtype: Optional[Type[cutlass.Numeric]],
|
|
1763
1992
|
sf_vec_size: Optional[int],
|
|
1993
|
+
d_dtype: Optional[Type[cutlass.Numeric]],
|
|
1994
|
+
c_dtype: Optional[Type[cutlass.Numeric]],
|
|
1995
|
+
d_layout: Optional[LayoutEnum],
|
|
1996
|
+
c_layout: Optional[LayoutEnum],
|
|
1997
|
+
epilogue_args: EpilogueArguments,
|
|
1998
|
+
prefetch_A_idx: Literal[None, "varlen_m", "varlen_k"],
|
|
1764
1999
|
smem_capacity: int,
|
|
1765
2000
|
occupancy: int,
|
|
1766
2001
|
) -> Tuple[int, int, int]:
|
|
@@ -1778,7 +2013,7 @@ class GemmSm100(GemmSm90):
|
|
|
1778
2013
|
:type epi_tile: cute.Tile
|
|
1779
2014
|
:param d_dtype: Data type of operand C (output).
|
|
1780
2015
|
:type d_dtype: type[cutlass.Numeric]
|
|
1781
|
-
:param d_layout: Layout enum of operand
|
|
2016
|
+
:param d_layout: Layout enum of operand D.
|
|
1782
2017
|
:type d_layout: LayoutEnum
|
|
1783
2018
|
:param smem_capacity: Total available shared memory capacity in bytes.
|
|
1784
2019
|
:type smem_capacity: int
|
|
@@ -1797,8 +2032,8 @@ class GemmSm100(GemmSm90):
|
|
|
1797
2032
|
num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
|
|
1798
2033
|
|
|
1799
2034
|
# Default D stages
|
|
1800
|
-
epi_stage = 2
|
|
1801
|
-
epi_c_stage =
|
|
2035
|
+
epi_stage = 4 if cute.size(epi_tile[1]) <= 16 else 2
|
|
2036
|
+
epi_c_stage = 0 if c_dtype is None else (4 if cute.size(epi_tile[1]) <= 16 else 2)
|
|
1802
2037
|
|
|
1803
2038
|
# Calculate smem layout and size for one stage of A, B, and C
|
|
1804
2039
|
a_smem_layout_staged_one = sm100_utils.make_smem_layout_a(
|
|
@@ -1813,7 +2048,11 @@ class GemmSm100(GemmSm90):
|
|
|
1813
2048
|
b_dtype,
|
|
1814
2049
|
1, # a tmp 1 stage is provided
|
|
1815
2050
|
)
|
|
1816
|
-
d_smem_layout_staged_one =
|
|
2051
|
+
d_smem_layout_staged_one = (
|
|
2052
|
+
sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
|
|
2053
|
+
if d_dtype is not None
|
|
2054
|
+
else None
|
|
2055
|
+
)
|
|
1817
2056
|
c_smem_layout_staged_one = (
|
|
1818
2057
|
sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1)
|
|
1819
2058
|
if c_dtype is not None
|
|
@@ -1836,13 +2075,22 @@ class GemmSm100(GemmSm90):
|
|
|
1836
2075
|
ab_bytes_per_stage = cute.size_in_bytes(
|
|
1837
2076
|
a_dtype, a_smem_layout_staged_one
|
|
1838
2077
|
) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
|
|
2078
|
+
if const_expr(prefetch_A_idx == "varlen_k"): # Need smem to prefetch A indices
|
|
2079
|
+
ab_bytes_per_stage += Int32.width // 8 * cta_tile_shape_mnk[2]
|
|
1839
2080
|
if const_expr(blockscaled):
|
|
1840
2081
|
ab_bytes_per_stage += cute.size_in_bytes(
|
|
1841
2082
|
sf_dtype, sfa_smem_layout_staged_one
|
|
1842
2083
|
) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
|
|
1843
2084
|
mbar_helpers_bytes = 1024
|
|
1844
|
-
|
|
1845
|
-
|
|
2085
|
+
if const_expr(prefetch_A_idx == "varlen_m"):
|
|
2086
|
+
mbar_helpers_bytes += Int32.width // 8 * cta_tile_shape_mnk[0] * 2
|
|
2087
|
+
d_bytes_per_stage = (
|
|
2088
|
+
cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) if d_dtype is not None else 0
|
|
2089
|
+
)
|
|
2090
|
+
epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
|
|
2091
|
+
epilogue_args, cta_tile_shape_mnk, epi_tile
|
|
2092
|
+
)
|
|
2093
|
+
epi_bytes = epi_bytes_per_stage * epi_stage
|
|
1846
2094
|
if const_expr(c_dtype is not None):
|
|
1847
2095
|
c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
|
|
1848
2096
|
epi_bytes += c_bytes_per_stage * epi_c_stage
|
|
@@ -1851,18 +2099,13 @@ class GemmSm100(GemmSm90):
|
|
|
1851
2099
|
# Start with total smem per CTA (capacity / occupancy)
|
|
1852
2100
|
# Subtract reserved bytes and initial C stages bytes
|
|
1853
2101
|
# Divide remaining by bytes needed per A/B/SFA/SFB stage
|
|
1854
|
-
|
|
1855
|
-
|
|
1856
|
-
) // ab_bytes_per_stage
|
|
2102
|
+
remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
|
|
2103
|
+
ab_stage = remaining_bytes // ab_bytes_per_stage
|
|
1857
2104
|
|
|
1858
2105
|
# Refine epilogue stages:
|
|
1859
2106
|
# Calculate remaining smem after allocating for A/B stages and reserved bytes
|
|
1860
2107
|
# Add remaining unused smem to epilogue
|
|
1861
|
-
epi_stage += (
|
|
1862
|
-
smem_capacity
|
|
1863
|
-
- occupancy * ab_bytes_per_stage * ab_stage
|
|
1864
|
-
- occupancy * (mbar_helpers_bytes + epi_bytes)
|
|
1865
|
-
) // (occupancy * d_bytes_per_stage)
|
|
2108
|
+
epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // (epi_bytes_per_stage)
|
|
1866
2109
|
return num_acc_stage, ab_stage, epi_stage, epi_c_stage
|
|
1867
2110
|
|
|
1868
2111
|
@staticmethod
|
|
@@ -1891,9 +2134,12 @@ class GemmSm100(GemmSm90):
|
|
|
1891
2134
|
|
|
1892
2135
|
@staticmethod
|
|
1893
2136
|
def is_valid_dtypes(
|
|
1894
|
-
|
|
2137
|
+
a_dtype: Type[cutlass.Numeric],
|
|
2138
|
+
b_dtype: Type[cutlass.Numeric],
|
|
1895
2139
|
acc_dtype: Type[cutlass.Numeric],
|
|
1896
|
-
d_dtype: Type[cutlass.Numeric],
|
|
2140
|
+
d_dtype: Optional[Type[cutlass.Numeric]],
|
|
2141
|
+
a_major: str,
|
|
2142
|
+
b_major: str,
|
|
1897
2143
|
) -> bool:
|
|
1898
2144
|
"""
|
|
1899
2145
|
Check if the dtypes are valid
|
|
@@ -1909,6 +2155,9 @@ class GemmSm100(GemmSm90):
|
|
|
1909
2155
|
:rtype: bool
|
|
1910
2156
|
"""
|
|
1911
2157
|
is_valid = True
|
|
2158
|
+
if b_dtype != a_dtype:
|
|
2159
|
+
is_valid = False
|
|
2160
|
+
ab_dtype = a_dtype
|
|
1912
2161
|
if ab_dtype not in {
|
|
1913
2162
|
cutlass.Float16,
|
|
1914
2163
|
cutlass.BFloat16,
|
|
@@ -1927,7 +2176,7 @@ class GemmSm100(GemmSm90):
|
|
|
1927
2176
|
and ab_dtype not in {cutlass.Uint8, cutlass.Int8}
|
|
1928
2177
|
):
|
|
1929
2178
|
is_valid = False
|
|
1930
|
-
if (
|
|
2179
|
+
if d_dtype is not None and (
|
|
1931
2180
|
acc_dtype == Float32
|
|
1932
2181
|
and d_dtype
|
|
1933
2182
|
not in {
|
|
@@ -1958,6 +2207,8 @@ class GemmSm100(GemmSm90):
|
|
|
1958
2207
|
}
|
|
1959
2208
|
):
|
|
1960
2209
|
is_valid = False
|
|
2210
|
+
if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
|
|
2211
|
+
is_valid = False
|
|
1961
2212
|
return is_valid
|
|
1962
2213
|
|
|
1963
2214
|
@staticmethod
|
|
@@ -2014,34 +2265,6 @@ class GemmSm100(GemmSm90):
|
|
|
2014
2265
|
|
|
2015
2266
|
return is_valid
|
|
2016
2267
|
|
|
2017
|
-
@staticmethod
|
|
2018
|
-
def is_valid_layouts(
|
|
2019
|
-
ab_dtype: Type[cutlass.Numeric],
|
|
2020
|
-
a_major: str,
|
|
2021
|
-
b_major: str,
|
|
2022
|
-
) -> bool:
|
|
2023
|
-
"""
|
|
2024
|
-
Check if the dtypes and sf_vec_size are valid combinations
|
|
2025
|
-
|
|
2026
|
-
:param ab_dtype: The data type of the A and B operands
|
|
2027
|
-
:type ab_dtype: Type[cutlass.Numeric]
|
|
2028
|
-
:param d_dtype: The data type of the output tensor
|
|
2029
|
-
:type d_dtype: Type[cutlass.Numeric]
|
|
2030
|
-
:param a_major: The major dimension of the A tensor
|
|
2031
|
-
:type a_major: str
|
|
2032
|
-
:param b_major: The major dimension of the B tensor
|
|
2033
|
-
:type b_major: str
|
|
2034
|
-
:param d_major: The major dimension of the C tensor
|
|
2035
|
-
:type d_major: str
|
|
2036
|
-
|
|
2037
|
-
:return: True if the layouts are valid, False otherwise
|
|
2038
|
-
:rtype: bool
|
|
2039
|
-
"""
|
|
2040
|
-
is_valid = True
|
|
2041
|
-
if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
|
|
2042
|
-
is_valid = False
|
|
2043
|
-
return is_valid
|
|
2044
|
-
|
|
2045
2268
|
@staticmethod
|
|
2046
2269
|
def is_valid_mma_tiler_and_cluster_shape(
|
|
2047
2270
|
mma_tiler_mn: Tuple[int, int],
|
|
@@ -2187,7 +2410,7 @@ class GemmSm100(GemmSm90):
|
|
|
2187
2410
|
"""
|
|
2188
2411
|
can_implement = True
|
|
2189
2412
|
# Skip unsupported types
|
|
2190
|
-
if not GemmSm100.is_valid_dtypes(ab_dtype, acc_dtype, d_dtype):
|
|
2413
|
+
if not GemmSm100.is_valid_dtypes(ab_dtype, ab_dtype, acc_dtype, d_dtype, a_major, b_major):
|
|
2191
2414
|
can_implement = False
|
|
2192
2415
|
# Skip invalid mma tile shape and cluster shape
|
|
2193
2416
|
if not GemmSm100.is_valid_mma_tiler_and_cluster_shape(
|
|
@@ -2362,7 +2585,7 @@ def run(
|
|
|
2362
2585
|
|
|
2363
2586
|
# Configure gemm kernel
|
|
2364
2587
|
cluster_shape_mnk = (*cluster_shape_mn, 1)
|
|
2365
|
-
gemm = GemmSm100(acc_dtype, mma_tiler_mn, cluster_shape_mnk)
|
|
2588
|
+
gemm = GemmSm100(acc_dtype, ab_dtype, mma_tiler_mn, cluster_shape_mnk)
|
|
2366
2589
|
|
|
2367
2590
|
# Compute max active clusters on current device
|
|
2368
2591
|
hardware_info = cutlass.utils.HardwareInfo()
|