quack-kernels 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +72 -64
- quack/broadcast_utils.py +1 -1
- quack/copy_utils.py +14 -18
- quack/fast_math.py +29 -76
- quack/gemm_act.py +296 -8
- quack/gemm_dact.py +520 -4
- quack/gemm_default_epi.py +4 -4
- quack/gemm_interface.py +363 -0
- quack/gemm_sm100.py +62 -88
- quack/gemm_sm90.py +68 -114
- quack/gemm_symmetric.py +2 -6
- quack/layout_utils.py +2 -4
- quack/linear.py +37 -0
- quack/pipeline.py +59 -89
- quack/reduce.py +2 -2
- quack/rmsnorm.py +1 -3
- quack/sm90_utils.py +5 -3
- quack/sort/bitonic_sort.py +3 -3
- quack/tile_scheduler.py +310 -256
- quack/topk.py +4 -4
- quack/utils.py +76 -40
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/METADATA +2 -2
- quack_kernels-0.2.6.dist-info/RECORD +45 -0
- quack_kernels-0.2.5.dist-info/RECORD +0 -45
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/top_level.txt +0 -0
quack/gemm_sm90.py
CHANGED

@@ -12,6 +12,7 @@ import cuda.bindings.driver as cuda
 import cutlass
 import cutlass.cute as cute
 import cutlass.pipeline as pipeline
+from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
 from cutlass.cute.nvgpu import cpasync, warp, warpgroup
 import cutlass.utils.hopper_helpers as sm90_utils
 from cutlass import Int32, Float32, Float16, Boolean, const_expr
@@ -26,6 +27,7 @@ from quack.tile_scheduler import (
     TileScheduler,
     VarlenMTileSchedulerArguments,
     VarlenMTileScheduler,
+    PersistenceMode,
 )
 from quack.varlen_utils import VarlenArguments, VarlenManager

@@ -226,8 +228,6 @@ class GemmSm90:
         self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
         self.num_ab_load_warps = 1 if not self.gather_A else 4
         self.ab_load_warp_id = self.mma_warp_groups * 4
-        # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
-        # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps

         regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
             math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
@@ -324,8 +324,6 @@ class GemmSm90:
             epilogue_args,
             cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"),  # smem_capacity
             self.occupancy,
-            # epi_smem will reuse smem ab if not persistent.
-            overlap_sD_sA=not self.is_persistent,
         )
         self.sched_stage = 2 if self.pingpong else 1

@@ -401,10 +399,12 @@ class GemmSm90:
         assert (varlen_args.mAIdx is not None) == self.gather_A

         # Assume all strides are divisible by 128 bits except the last stride
-        new_stride
-
-
-
+        def new_stride(t: cute.Tensor):
+            return tuple(
+                cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
+                for s in t.stride
+            )
+
         mA, mD = [
             cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
             if t is not None
@@ -461,9 +461,7 @@ class GemmSm90:
             tile_sched_params, scheduler_args.max_active_clusters
         )

-        epi_smem_size = (
-            cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
-        )
+        epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
        epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0

         @cute.struct
@@ -471,7 +469,7 @@ class GemmSm90:
             ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
             epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
             sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
-
+            scheduler_data: cute.struct.MemRange[Int32, self.sched_stage * 4]
             sD: cute.struct.Align[
                 cute.struct.MemRange[
                     self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
@@ -585,17 +583,13 @@ class GemmSm90:

         warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

-        #
-        # Prefetch Tma desc
-        # /////////////////////////////////////////////////////////////////////////////
+        # Prefetch Tma desc
         if warp_idx == self.ab_load_warp_id:
             for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
                 if const_expr(tma_atom is not None):
                     cpasync.prefetch_descriptor(tma_atom)

-        #
-        # Alloc and init AB full/empty + ACC full mbar (pipeline)
-        # /////////////////////////////////////////////////////////////////////////////
+        # Alloc and init AB full/empty + ACC full mbar (pipeline)
         smem = cutlass.utils.SmemAllocator()
         storage = smem.allocate(self.shared_storage)

@@ -611,28 +605,24 @@ class GemmSm90:
             epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
         )
         sched_pipeline = None
-
-        if const_expr(
-            # Dynamic persistent scheduler
+        scheduler_data = None
+        if const_expr(self.is_persistent):
             sched_pipeline = self.make_sched_pipeline(
                 cluster_layout_mnk,
                 sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
                 varlen_k=varlen_k,
             )
-
+            scheduler_data = storage.scheduler_data.get_tensor((4, self.sched_stage))
+
+        # Cluster arrive after barrier init
+        pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True)

-        #
-        # Generate smem tensor A/B
-        # ///////////////////////////////////////////////////////////////////////////////
+        # Generate smem tensor A/B
         sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
         sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
         sD = None
         if const_expr(has_D):
-
-            sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
-            sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
-        else:
-            sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
+            sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
         sC = None
         if const_expr(has_C):
             sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
@@ -654,11 +644,14 @@ class GemmSm90:
         )

         TileSchedulerCls = partial(
-            TileSchedulerCls.create, tile_sched_params,
+            TileSchedulerCls.create, tile_sched_params, scheduler_data, sched_pipeline
         )

+        # Cluster wait for barrier init
+        pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1])
+
         if warp_idx >= self.ab_load_warp_id:
-            cute.arch.
+            cute.arch.setmaxregister_decrease(self.num_regs_load)
             if (
                 warp_idx >= self.ab_load_warp_id
                 and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
@@ -668,9 +661,7 @@ class GemmSm90:
             varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
             tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
             tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
-            # ///////////////////////////////////////////////////////////////////////////////
             # Get mcast mask
-            # ///////////////////////////////////////////////////////////////////////////////
             cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
             block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
             a_mcast_mask = cute.make_layout_image_mask(
@@ -686,7 +677,7 @@ class GemmSm90:
             is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
             if const_expr(cute.size(cluster_layout_mnk) > 1):
                 is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
-            tile_scheduler = TileSchedulerCls(
+            tile_scheduler = TileSchedulerCls()
             work_tile = tile_scheduler.initial_work_tile_info()
             ab_producer_state = make_pipeline_state(
                 pipeline.PipelineUserType.Producer, self.ab_stage
@@ -698,14 +689,9 @@ class GemmSm90:
                 tile_coord_mnkl = work_tile.tile_idx
                 batch_idx = tile_coord_mnkl[3]
                 varlen_manager.update_tensormap_AB(
-                    batch_idx,
-                    self.a_layout,
-                    self.b_layout,
-                    is_tma_warp,
+                    batch_idx, self.a_layout, self.b_layout, is_tma_warp
                 )
-                #
-                # Local_tile partition global tensors
-                # ///////////////////////////////////////////////////////////////////////////
+                # Local_tile partition global tensors
                 if const_expr(not self.gather_A):
                     mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
                     # (bM, bK, RestK)
@@ -736,9 +722,7 @@ class GemmSm90:
                     cute.select(self.cta_tile_shape_mnk, [1, 2]),
                     (tile_coord_mnkl[1], None),
                 )
-                #
-                # Partition shared tensor for TMA load A/B
-                # //////////////////////////////////////////////////////////////////////////
+                # Partition shared tensor for TMA load A/B
                 varlen_manager.fence_tensormap_update_AB(is_tma_warp)
                 len_m = varlen_manager.len_m(batch_idx)
                 len_k = varlen_manager.len_k(batch_idx)
@@ -810,19 +794,20 @@ class GemmSm90:
                     k_tile_cnt,
                     varlen_m=varlen_m,
                 )
-                tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
                 tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
                 work_tile = tile_scheduler.get_current_work()
                 # End of persistent scheduler loop
             if const_expr(self.pingpong and not varlen_k):
                 # Need to write the tile_idx to smem for the next WG in the pingpong mode
-
+                if is_scheduler_warp:
+                    tile_scheduler.write_work_tile_to_smem(work_tile)
+                    work_tile = tile_scheduler.get_current_work()
             ab_pipeline.producer_tail(ab_producer_state)
             if is_scheduler_warp:
                 tile_scheduler.producer_tail()

         if warp_idx < self.ab_load_warp_id:
-            cute.arch.
+            cute.arch.setmaxregister_increase(self.num_regs_mma)
             is_tma_warp = Boolean(
                 (not self.pingpong and warp_idx == 0)
                 or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
@@ -832,34 +817,30 @@ class GemmSm90:
             )
             tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
             tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
-            #
-            # Partition global tensor for TiledMMA_A/B/C
-            # //////////////////////////////////////////////////////////////////////////////
+            # Partition global tensor for TiledMMA_A/B/C
             tidx, _, _ = cute.arch.thread_idx()
             warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
             if const_expr(self.pingpong):
                 tidx = tidx % self.num_threads_per_warp_group
             warp_group_thread_layout = cute.make_layout(
-                self.mma_warp_groups if not self.pingpong else 1,
+                self.mma_warp_groups if const_expr(not self.pingpong) else 1,
                 stride=self.num_threads_per_warp_group,
             )
             thr_mma = tiled_mma.get_slice(
                 warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
             )

-            #
-            # Make fragments
-            # //////////////////////////////////////////////////////////////////////////////
+            # Make fragments
             tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
             tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))

             acc_shape = tiled_mma.partition_shape_C(
                 cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
             )
-            acc = cute.
+            acc = cute.make_rmem_tensor(acc_shape, self.acc_dtype)
             acc_slow = None
             if const_expr(self.fp8_slow_accum):
-                acc_slow = cute.
+                acc_slow = cute.make_rmem_tensor(acc_shape, self.acc_dtype)

             if const_expr(self.pingpong):
                 if warp_group_idx == 0:
@@ -879,10 +860,8 @@ class GemmSm90:
                 pipeline.PipelineUserType.Producer, self.epi_c_stage
             )
             tile_scheduler = TileSchedulerCls()
-            work_tile =
+            work_tile = tile_scheduler.initial_work_tile_info()
             if const_expr(self.pingpong):
-                if const_expr(varlen_k):
-                    work_tile = tile_scheduler.initial_work_tile_info()
                 if warp_idx >= 4:
                     # Advance 2nd Math WG pipeline states to the end of 1st Math WG
                     epi_read_state.advance_iters(c_tile_cnt)
@@ -893,13 +872,9 @@ class GemmSm90:
                        len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
                        k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                        ab_read_state.advance_iters(k_tile_cnt)
+                    # TODO: do we need to check if work_tile is valid?
                     tile_scheduler.advance_to_next_work()
-
-                    work_tile = tile_scheduler.get_current_work()
-                if const_expr(not varlen_k):
-                    work_tile = tile_scheduler.initial_work_tile_info()
-            else:
-                work_tile = tile_scheduler.initial_work_tile_info()
+                    work_tile = tile_scheduler.get_current_work()
             if const_expr(varlen_m):
                 # wait tensormap initialization complete before update
                 varlen_manager.fence_tensormap_init()
@@ -910,11 +885,7 @@ class GemmSm90:
                        epilogue_params, varlen_params.cu_seqlens_m, batch_idx
                    )
                    varlen_manager.update_tensormap_epi(
-                        batch_idx,
-                        self.d_layout,
-                        epi_shapes,
-                        epi_orders,
-                        is_tma_warp,
+                        batch_idx, self.d_layout, epi_shapes, epi_orders, is_tma_warp
                    )
                len_k = varlen_manager.len_k(batch_idx)
                k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
@@ -933,9 +904,7 @@ class GemmSm90:
                 if k_tile_cnt == 0:
                     acc.fill(0.0)

-                #
-                # EPILOGUE
-                # /////////////////////////////////////////////////////////////////////////////
+                # EPILOGUE
                 if const_expr(self.pingpong):
                     self.pingpong_barrier_sync(warp_group_idx, "epi")

@@ -983,11 +952,6 @@ class GemmSm90:
                 else:
                     tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None

-                # Wait for all warp groups in the thread block to finish, because smem for tensor
-                # A in the mainloop is reused in the epilogue if not persistent.
-                if const_expr(not self.is_persistent):
-                    epilogue_barrier.arrive_and_wait()
-
                 self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)

                 epi_read_state, epi_producer_state = self.epilogue(
@@ -1073,9 +1037,7 @@ class GemmSm90:
         peek_ab_empty_status = Boolean(True)
         if 0 < k_tile_cnt:
             peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
-        # /////////////////////////////////////////////////////////////////////////
         # TMA load
-        # /////////////////////////////////////////////////////////////////////////
         for k_tile in cutlass.range(k_tile_cnt, unroll=1):
             # Wait for A/B buffers to be empty before loading into them
             # Also sets the transaction barrier for the A/B buffers
@@ -1112,9 +1074,7 @@ class GemmSm90:
         peek_ab_empty_status = Boolean(True)
         if 0 < k_tile_cnt:
             peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
-        # /////////////////////////////////////////////////////////////////////////
         # TMA load on B and cp.async on A
-        # /////////////////////////////////////////////////////////////////////////
         for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
             prefetch_out = ()
             if const_expr(prefetch_A is not None):  # Prefetch early, even before smem is free
@@ -1172,9 +1132,7 @@ class GemmSm90:
         k_tile_cnt: Int32,
         warp_group_idx: Int32,
     ) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
-        #
-        # Prologue MMAs
-        # /////////////////////////////////////////////////////////////////////////////
+        # Prologue MMAs
         k_pipe_mmas = 1
         ab_release_state = ab_read_state.clone()
         num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
@@ -1204,13 +1162,10 @@ class GemmSm90:
             warpgroup.wait_group(0)
             acc_slow.store(acc.load())

-        #
-        # MAINLOOP
-        # /////////////////////////////////////////////////////////////////////////////
+        # MAINLOOP
         for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
             # Wait for TMA copies to complete
             ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
-            # WGMMA
             warpgroup.fence()
             if const_expr(self.fp8_slow_accum):
                 tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
@@ -1308,9 +1263,7 @@ class GemmSm90:

         def tma_store_fn(src_idx, dst_idx):
             # Fence and barrier to make sure shared memory store is visible to TMA store
-            cute.arch.
-                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-            )
+            cute.arch.fence_view_async_shared()
             epilogue_barrier.arrive_and_wait()
             # Copy from shared memory to global memory
             if is_tma_warp:
@@ -1336,9 +1289,7 @@ class GemmSm90:
                 epi_pipeline.consumer_wait(epi_read_state)
                 cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
                 # Fence to make sure shared memory read is visible to TMA load
-                cute.arch.
-                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-                )
+                cute.arch.fence_view_async_shared()
                 cute.arch.sync_warp()
                 with cute.arch.elect_one():
                     epi_pipeline.consumer_release(epi_read_state)
@@ -1391,6 +1342,12 @@ class GemmSm90:
         varlen_args,
     ):
         """Create scheduler arguments. Override in subclasses for custom schedulers."""
+        if const_expr(not self.is_persistent):
+            persistence_mode = PersistenceMode.NONE
+        elif const_expr(scheduler_args.tile_count_semaphore is not None):
+            persistence_mode = PersistenceMode.DYNAMIC
+        else:
+            persistence_mode = PersistenceMode.STATIC
         if const_expr(varlen_args.mCuSeqlensM is None):
             num_problems = (
                 mD.shape[2]
@@ -1413,7 +1370,7 @@ class GemmSm90:
                 cluster_shape_mnk=self.cluster_shape_mnk,
                 tile_count_semaphore=scheduler_args.tile_count_semaphore,
                 batch_idx_permute=scheduler_args.batch_idx_permute,
-
+                persistence_mode=persistence_mode,
             )
         else:
             assert mD is not None or not self.gather_A
@@ -1431,7 +1388,7 @@ class GemmSm90:
                 tile_shape_mn=self.cta_tile_shape_mnk[:2],
                 cluster_shape_mnk=self.cluster_shape_mnk,
                 tile_count_semaphore=scheduler_args.tile_count_semaphore,
-
+                persistence_mode=persistence_mode,
             )
         return tile_sched_args

@@ -1579,7 +1536,7 @@ class GemmSm90:
         tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
         sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
         tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
-        tRS_rD = cute.
+        tRS_rD = cute.make_rmem_tensor(tRS_rD_shape, self.acc_dtype)
         return tiled_copy_r2s, tRS_rD, tRS_sD

     def epilog_smem_load_and_partition(
@@ -1596,7 +1553,7 @@ class GemmSm90:
         tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
         thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
         tSR_sC = thr_copy_s2r.partition_S(sC)
-        tRS_rC = cute.
+        tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype)
         tSR_rC = thr_copy_s2r.retile(tRS_rC)
         return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC

@@ -1651,6 +1608,7 @@ class GemmSm90:
             consumer_group=ab_pipeline_consumer_group,
             tx_count=self.num_tma_load_bytes,
             cta_layout_vmnk=cluster_layout_vmnk,
+            defer_sync=True,
         )

     def make_epi_pipeline(
@@ -1670,6 +1628,7 @@ class GemmSm90:
             producer_group=epi_pipeline_producer_group,
             consumer_group=epi_pipeline_consumer_group,
             tx_count=tma_copy_c_bytes,
+            defer_sync=True,
         )

     def make_epi_store_pipeline(self):
@@ -1686,13 +1645,13 @@ class GemmSm90:
         # Threads/warps participating in this pipeline
         sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
         cluster_size = cute.size(cluster_layout_mnk)
-        # Each warp
+        # Each warp will contribute 1 to the arrive count
         # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
         # at each round. If pingpong and not varlen_k, then only 4 mma warp will participate.
         consumer_arrive_cnt = (
             (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
             + self.num_ab_load_warps
-        ) * cluster_size
+        ) * cluster_size
         sched_pipeline_consumer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread, consumer_arrive_cnt
         )
@@ -1703,6 +1662,7 @@ class GemmSm90:
             consumer_group=sched_pipeline_consumer_group,
             # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
             consumer_mask=None if const_expr(cluster_size == 1) else 0,
+            defer_sync=True,
         )

     @classmethod
@@ -1717,7 +1677,6 @@ class GemmSm90:
         epilogue_args: EpilogueArguments,
         smem_capacity: int,
         occupancy: int,
-        overlap_sD_sA: bool = False,
     ) -> Tuple[int, int]:
         """Computes the number of stages for A/B/C operands based on heuristics.

@@ -1738,16 +1697,11 @@ class GemmSm90:
         """

         epi_stage = 4 if epi_tile[1] <= 16 else 2
-        if
-
-
-
-
-        )
-        epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
-            epilogue_args, cta_tile_shape_mnk, epi_tile
-        )
-        epi_bytes = epi_bytes_per_stage * epi_stage
+        d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
+        epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
+            epilogue_args, cta_tile_shape_mnk, epi_tile
+        )
+        epi_bytes = epi_bytes_per_stage * epi_stage
         epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
         if c_dtype is not None:
             epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
@@ -1765,7 +1719,7 @@ class GemmSm90:
         # Refine epilogue stages:
         # Calculate remaining smem after allocating for A/B stages and reserved bytes
         # Add remaining unused smem to epilogue
-        if
+        if epi_bytes_per_stage > 0:
             epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
         return ab_stage, epi_stage, epi_c_stage
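The recurring change across the hunks above is a move to deferred pipeline initialization: make_ab_pipeline, make_epi_pipeline, and make_sched_pipeline now pass defer_sync=True when creating their pipelines, and the kernel issues a single pipeline_init_arrive once all mbarriers are initialized, followed by pipeline_init_wait before the first pipeline operation. A minimal sketch of that ordering, kept as comments because the surrounding @cute.jit kernel body is elided here; only the two pipeline_init_* calls and their arguments are taken from the diff, everything else is placeholder:

# Sketch of the ordering established by this release (names follow the hunks above):
#
#   1. Create every pipeline with defer_sync=True, so construction no longer
#      issues its own barrier synchronization:
#          ab_pipeline    = self.make_ab_pipeline(...)      # defer_sync=True inside
#          epi_pipeline   = self.make_epi_pipeline(...)     # defer_sync=True inside
#          sched_pipeline = self.make_sched_pipeline(...)   # defer_sync=True inside
#   2. After all mbarriers are initialized, issue one relaxed cluster-wide arrive:
#          pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True)
#   3. Do the remaining setup that does not touch the pipelines (smem tensors,
#      TileScheduler construction), then wait before the first pipeline use:
#          pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1])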
quack/gemm_symmetric.py
CHANGED

@@ -115,9 +115,7 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
            pid_m = tile_coord_mnkl[0]
            pid_n = tile_coord_mnkl[1]
            # Fence and barrier to make sure shared memory store is visible to TMA store
-            cute.arch.
-                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-            )
+            cute.arch.fence_view_async_shared()
            epilogue_barrier.arrive_and_wait()
            # Copy from shared memory to global memory
            if is_tma_warp:
@@ -145,9 +143,7 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
                epi_pipeline.consumer_wait(epi_read_state)
                cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
                # Fence to make sure shared memory read is visible to TMA load
-                cute.arch.
-                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-                )
+                cute.arch.fence_view_async_shared()
                cute.arch.sync_warp()
                with cute.arch.elect_one():
                    epi_pipeline.consumer_release(epi_read_state)
quack/layout_utils.py
CHANGED

@@ -6,8 +6,6 @@ import cutlass.cute as cute

 from cutlass import Int32, const_expr

-from quack.utils import prmt
-

 def transpose_view(a: cute.Tensor) -> cute.Tensor:
     """Transpose the first two dimensions of a tensor on smem."""
@@ -55,8 +53,8 @@ def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
         lower0 = lower if lane_03 else upper
         upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
         lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
-        t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
-        t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
+        t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper)
+        t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower)


 @cute.jit
quack/linear.py
CHANGED

@@ -9,6 +9,7 @@ from torch.amp import custom_fwd, custom_bwd


 from quack.gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact
+from quack.gemm_interface import gemm_gated, gemm_dgated


 def linear_fwd_convert_type(*tensors):
@@ -228,6 +229,42 @@ def act_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=
     return fn_cls.apply(preact, weight, x, activation, fuse_grad_accum)


+class LinearGatedFunc(LinearActFunc):
+    matmul_fwd_fn = gemm_gated
+
+
+class LinearGatedUntunedFunc(LinearActFunc):
+    # Passing in tuned=False to disable tuning at runtime
+    matmul_fwd_fn = partial(gemm_gated, tuned=False)
+    matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
+
+
+def linear_gated_func(
+    x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True
+):
+    fn_cls = LinearGatedFunc if tuned else LinearGatedUntunedFunc
+    return fn_cls.apply(x, weight, activation, bias, store_preact, fuse_grad_accum)
+
+
+class DGatedLinearFunc(DActLinearFunc):
+    matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True)
+
+
+class DGatedLinearUntunedFunc(DActLinearFunc):
+    # Passing in tuned=False to disable tuning at runtime
+    matmul_fwd_fn = partial(gemm, tuned=False)
+    matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
+
+
+def gated_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=True):
+    fn_cls = DGatedLinearFunc if tuned else DGatedLinearUntunedFunc
+    return fn_cls.apply(preact, weight, x, activation, fuse_grad_accum)
+
+
 class Linear(nn.Linear):
     def __init__(
         self,