quack_kernels-0.2.5-py3-none-any.whl → quack_kernels-0.2.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +72 -64
- quack/broadcast_utils.py +1 -1
- quack/copy_utils.py +14 -18
- quack/fast_math.py +29 -76
- quack/gemm_act.py +296 -8
- quack/gemm_dact.py +520 -4
- quack/gemm_default_epi.py +4 -4
- quack/gemm_interface.py +363 -0
- quack/gemm_sm100.py +62 -88
- quack/gemm_sm90.py +68 -114
- quack/gemm_symmetric.py +2 -6
- quack/layout_utils.py +2 -4
- quack/linear.py +37 -0
- quack/pipeline.py +59 -89
- quack/reduce.py +2 -2
- quack/rmsnorm.py +1 -3
- quack/sm90_utils.py +5 -3
- quack/sort/bitonic_sort.py +3 -3
- quack/tile_scheduler.py +310 -256
- quack/topk.py +4 -4
- quack/utils.py +76 -40
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/METADATA +2 -2
- quack_kernels-0.2.6.dist-info/RECORD +45 -0
- quack_kernels-0.2.5.dist-info/RECORD +0 -45
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/top_level.txt +0 -0
quack/gemm_sm100.py
CHANGED
@@ -13,6 +13,7 @@ import cutlass.cute as cute
 from cutlass.cute.nvgpu import cpasync, tcgen05
 import cutlass.torch as cutlass_torch
 import cutlass.pipeline as pipeline
+from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass.cute.nvgpu.warp import (
@@ -479,10 +480,12 @@ class GemmSm100(GemmSm90):
         assert (varlen_args.mAIdx is not None) == self.gather_A
 
         # Assume all strides are divisible by 128 bits except the last stride
-        new_stride …
-        …
-        …
-        …
+        def new_stride(t: cute.Tensor):
+            return tuple(
+                cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
+                for s in t.stride
+            )
+
         mA, mD = [
             cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
             if t is not None
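
Note on the hunk above: the rewritten `new_stride` helper stamps each dynamic stride with a divisibility hint, `cute.assume(s, divby=128 // t.element_type.width)`, i.e. it promises the compiler that the stride is a multiple of 128 bits' worth of elements, which is what TMA and vectorized copies need to prove alignment (the removed right-hand side is truncated in this rendering). A minimal standalone sketch of just the `divby` arithmetic, with a hypothetical helper name:

    # Hypothetical illustration: how many elements make up 128 bits of storage.
    def divby_hint(element_bits: int) -> int:
        return 128 // element_bits

    assert divby_hint(16) == 8   # bf16/fp16: strides assumed divisible by 8 elements
    assert divby_hint(8) == 16   # fp8/int8:  strides assumed divisible by 16 elements
    assert divby_hint(32) == 4   # fp32/tf32: strides assumed divisible by 4 elements
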
@@ -631,7 +634,7 @@ class GemmSm100(GemmSm90):
             a_prefetch_pipeline_array_ptr: cute.struct.MemRange[
                 cutlass.Int64, self.a_prefetch_stage * 2
             ]
-            …
+            scheduler_data: cute.struct.MemRange[Int32, self.sched_stage * 4]
             tmem_dealloc_mbar_ptr: cutlass.Int64
             tmem_holding_buf: Int32
             sAIdx: cute.struct.Align[cute.struct.MemRange[Int32, a_idx_smem_size], 16]
@@ -758,9 +761,7 @@ class GemmSm100(GemmSm90):
 
         warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
 
-        #
-        # Prefetch Tma desc
-        # /////////////////////////////////////////////////////////////////////////////
+        # Prefetch Tma desc
         if warp_idx == self.ab_load_warp_id:
             for tma_atom in (
                 tma_atom_a,
@@ -775,9 +776,7 @@ class GemmSm100(GemmSm90):
 
         use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
 
-        #
         # Setup cta/thread coordinates
-        #
         # Coords inside cluster
         bidx, _, _ = cute.arch.block_idx()
         mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
@@ -786,21 +785,10 @@ class GemmSm100(GemmSm90):
         # Coord inside cta
         tidx, _, _ = cute.arch.thread_idx()
 
-        #
         # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
-        #
         smem = cutlass.utils.SmemAllocator()
         storage = smem.allocate(self.shared_storage)
 
-        tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
-        tmem_holding_buf = storage.tmem_holding_buf
-
-        # Tensor memory dealloc barrier init
-        if use_2cta_instrs:
-            if warp_idx == self.ab_load_warp_id:
-                num_tmem_dealloc_threads = 32
-                cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)
-
         # Initialize pipelines and states
         ab_pipeline = self.make_ab_pipeline(
             tiled_mma=tiled_mma,
@@ -819,21 +807,36 @@ class GemmSm100(GemmSm90):
             acc_pipeline_mbar_ptr=storage.acc_pipeline_array_ptr.data_ptr(),
         )
         sched_pipeline = None
-
-        if const_expr(…
-            # Dynamic persistent scheduler
+        scheduler_data = None
+        if const_expr(self.is_persistent):
             sched_pipeline = self.make_sched_pipeline(
                 self.cluster_shape_mnk,
                 sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
                 has_C=has_C,
             )
-
+            scheduler_data = storage.scheduler_data.get_tensor((4, self.sched_stage))
         a_prefetch_pipeline = None
         if const_expr(self.gather_A):
             a_prefetch_pipeline = self.make_a_prefetch_pipeline(
                 storage.a_prefetch_pipeline_array_ptr.data_ptr(),
             )
 
+        tmem_alloc_barrier = pipeline.NamedBarrier(
+            barrier_id=int(NamedBarrierGemm.TmemPtr),
+            num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)),
+        )
+        # Tensor memory dealloc barrier init
+        tmem = cutlass.utils.TmemAllocator(
+            storage.tmem_holding_buf,
+            barrier_for_retrieve=tmem_alloc_barrier,
+            allocator_warp_id=self.epilog_warp_id[0],
+            is_two_cta=use_2cta_instrs,
+            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
+        )
+
+        # Cluster arrive after barrier init
+        pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
+
         # Setup smem tensor A/B/D
         # (MMA, MMA_M, MMA_K, STAGE)
         sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
@@ -884,22 +887,19 @@ class GemmSm100(GemmSm90):
         )
 
         TileSchedulerCls = partial(
-            TileSchedulerCls.create, tile_sched_params, …
+            TileSchedulerCls.create, tile_sched_params, scheduler_data, sched_pipeline
         )
 
-        tmem_alloc_barrier = pipeline.NamedBarrier(
-            barrier_id=int(NamedBarrierGemm.TmemPtr),
-            num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)),
-        )
         epi_load_barrier = None
         if const_expr(has_C):
             epi_load_barrier = pipeline.NamedBarrier(
                 barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE
             )
 
-        #
+        # Cluster wait before tensor memory alloc
+        pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
+
         # Specialized AB load warps
-        #
         if warp_idx == self.ab_load_warp_id:
             is_tma_warp = True
             # initialize tensormap for A & B
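
Note on the hunks above: this release threads one synchronization change through the whole kernel. Pipelines are now constructed with their barrier setup deferred (the `defer_sync=True` arguments added to every `make_*_pipeline` later in this file), a single `pipeline_init_arrive(..., is_relaxed=True)` is issued once all mbarriers are initialized, and `pipeline_init_wait(...)` runs before tensor memory is allocated. One cluster-wide fence replaces a sync inside each pipeline constructor. A pure-Python analogy of the ordering (threading stands in for warps; this is not the CUTLASS API):

    import threading

    N_WARPS = 4
    init_fence = threading.Barrier(N_WARPS)  # stands in for pipeline_init_arrive/wait

    def warp(idx, pipelines):
        # Per-pipeline init with sync deferred (the defer_sync=True role).
        pipelines[idx] = {"mbarrier_initialized": True}
        init_fence.wait()  # one shared fence instead of a sync per constructor
        assert all(p is not None for p in pipelines)  # now safe to use any pipeline

    pipelines = [None] * N_WARPS
    threads = [threading.Thread(target=warp, args=(i, pipelines)) for i in range(N_WARPS)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
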
@@ -949,9 +949,7 @@ class GemmSm100(GemmSm90):
                 self.b_layout,
                 is_tma_warp,
             )
-            #
-            # Local_tile partition global tensors
-            # ///////////////////////////////////////////////////////////////////////////
+            # Local_tile partition global tensors
             mma_tile_coord_mnl = (
                 tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
                 tile_coord_mnkl[1],
@@ -1093,9 +1091,7 @@ class GemmSm100(GemmSm90):
             while work_tile.is_valid_tile:
                 tile_coord_mnkl = work_tile.tile_idx
                 batch_idx = tile_coord_mnkl[3]
-                #
-                # Local_tile partition global tensors
-                # ///////////////////////////////////////////////////////////////////////////
+                # Local_tile partition global tensors
                 mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
                 if const_expr(varlen_m):
                     # (M, K)
@@ -1153,10 +1149,8 @@ class GemmSm100(GemmSm90):
                 tile_scheduler.advance_to_next_work()
                 work_tile = tile_scheduler.get_current_work()
 
-        #
         # Specialized scheduler warp. Will also prefetch A indices if gatherA
-
-        if const_expr(tile_sched_params.tile_count_semaphore is not None or self.gather_A):
+        if const_expr(self.is_persistent or self.gather_A):
             if warp_idx == self.scheduler_warp_id:
                 is_scheduler_warp = True
                 if const_expr(cute.size(cluster_layout_vmnk) > 1):
@@ -1172,7 +1166,7 @@ class GemmSm100(GemmSm90):
                     cute.make_identity_tensor(tile_M if varlen_m else tile_K)
                 )
                 # Persistent tile scheduling loop
-                tile_scheduler = TileSchedulerCls(…
+                tile_scheduler = TileSchedulerCls()
                 work_tile = tile_scheduler.initial_work_tile_info()
                 a_prefetch_producer_state = None
                 if const_expr(self.gather_A):
@@ -1190,7 +1184,7 @@ class GemmSm100(GemmSm90):
                     tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
                     len_m = varlen_manager.len_m(batch_idx)
                     m_limit = len_m - tile_coord_mnkl[0] * tile_M
-                    tApAIdx_m = cute.…
+                    tApAIdx_m = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean)
                     for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
                         tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit
                     a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
@@ -1220,7 +1214,7 @@ class GemmSm100(GemmSm90):
                    if 0 < k_tile_cnt:
                        k_tile = k_tile_cnt - 1
                        k_limit = len_k - k_tile * tile_K
-                        tApAIdx_k = cute.…
+                        tApAIdx_k = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean)
                        for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
                            tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit
                        a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
@@ -1233,16 +1227,13 @@ class GemmSm100(GemmSm90):
                        a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
                        a_prefetch_producer_state.advance()
                 # Advance to next tile
-                tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
                 tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
                 work_tile = tile_scheduler.get_current_work()
                 # End of persistent scheduler loop
             if is_scheduler_warp:
                 tile_scheduler.producer_tail()
 
-        #
         # Specialized TMA epi load warp
-        #
         if const_expr(mC_mnl is not None):
             if warp_idx == self.epi_load_warp_id:
                 epi_producer_state = pipeline.make_pipeline_state(
@@ -1281,15 +1272,11 @@ class GemmSm100(GemmSm90):
                 # End of persistent scheduler loop
                 epi_pipeline.producer_tail(epi_producer_state)
 
-        #
         # Specialized MMA warp
-        #
         if warp_idx == self.mma_warp_id:
-            tmem_alloc_barrier.arrive_and_wait()
             # Retrieving tensor memory ptr and make accumulator tensor
-            …
-            …
-            )
+            tmem.wait_for_alloc()
+            acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
             # Partition shared/tensor memory tensor for TiledMMA_A/B/D
             # (MMA, MMA_M, MMA_K, STAGE)
             tCrA = tiled_mma.make_fragment_A(sA_mma)
@@ -1394,12 +1381,8 @@ class GemmSm100(GemmSm90):
         #
         if warp_idx < self.mma_warp_id:
             # Alloc tensor memory buffer
-            …
-            …
-                self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs
-            )
-            # Bar sync for retrieve tensor memory ptr from shared memory
-            tmem_alloc_barrier.arrive_and_wait()
+            tmem.allocate(self.num_tmem_alloc_cols)
+            tmem.wait_for_alloc()
 
             is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
             varlen_manager.init_tensormap_epi(
@@ -1409,9 +1392,7 @@ class GemmSm100(GemmSm90):
             tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
 
             # Retrieving tensor memory ptr and make accumulator tensor
-            acc_tmem_ptr = …
-                self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
-            )
+            acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
             # (MMA, MMA_M, MMA_N, STAGE)
             tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
 
@@ -1426,7 +1407,7 @@ class GemmSm100(GemmSm90):
                 epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
             )
 
-            tTR_rD = cute.…
+            tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.acc_dtype)
             tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
                 tiled_copy_t2r, self.d_layout, self.d_dtype, tTR_rD, sD, epi_tidx
             )
@@ -1535,22 +1516,14 @@ class GemmSm100(GemmSm90):
                 tile_scheduler.advance_to_next_work()
                 work_tile = tile_scheduler.get_current_work()
 
-            # Dealloc the tensor memory buffer
-            if warp_idx == self.epilog_warp_id[0]:
-                cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
-            epilogue_barrier.arrive_and_wait()
-            if warp_idx == self.epilog_warp_id[0]:
-                if const_expr(use_2cta_instrs):
-                    cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1)
-                    cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
-                cute.arch.dealloc_tmem(
-                    acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs
-                )
-
             # Wait for D store complete
             if is_tma_warp:
                 epi_store_pipeline.producer_tail()
 
+            # Dealloc the tensor memory buffer
+            tmem.relinquish_alloc_permit()
+            tmem.free(acc_tmem_ptr)
+
     @cute.jit
     def load_A_gather_A(
         self,
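
Note on the hunks above: together with the `TmemAllocator` construction earlier in this diff, they replace hand-rolled tensor-memory management (explicit `alloc_tmem`/`dealloc_tmem`, mbarrier handshakes, `tmem_holding_buf` plumbing) with one object that owns the lifecycle. The call order visible in this diff: the allocator warp calls `tmem.allocate(cols)`, each consumer warp calls `tmem.wait_for_alloc()` then `tmem.retrieve_ptr(dtype)`, and the epilogue finishes with `tmem.relinquish_alloc_permit()` and `tmem.free(ptr)`. A toy mock of that protocol, for illustration only (the real class is `cutlass.utils.TmemAllocator`; this models call order, not hardware behavior):

    class MockTmemAllocator:
        def __init__(self):
            self._cols = None
        def allocate(self, num_cols):        # allocator warp only
            self._cols = num_cols
        def wait_for_alloc(self):            # every warp that touches tmem
            assert self._cols is not None, "allocate() must happen first"
        def retrieve_ptr(self, dtype):       # valid only after wait_for_alloc()
            return ("tmem_ptr", dtype, self._cols)
        def relinquish_alloc_permit(self):   # epilogue, before freeing
            pass
        def free(self, ptr):
            self._cols = None

    tmem = MockTmemAllocator()
    tmem.allocate(512)                  # epilog_warp_id[0] in the kernel
    tmem.wait_for_alloc()               # MMA warp and epilogue warps
    acc_ptr = tmem.retrieve_ptr("f32")
    tmem.relinquish_alloc_permit()
    tmem.free(acc_ptr)
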
@@ -1565,9 +1538,7 @@ class GemmSm100(GemmSm90):
         peek_a_empty_status = Boolean(True)
         if 0 < k_tile_cnt:
             peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
-        # /////////////////////////////////////////////////////////////////////////
         # cp.async on A
-        # /////////////////////////////////////////////////////////////////////////
         is_tma_warp = False
         for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
             smem_idx = a_producer_state.index
@@ -1787,7 +1758,7 @@ class GemmSm100(GemmSm90):
         # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
         tTR_cAcc = thr_copy_t2r.partition_D(cAcc_epi)
         # (T2R, T2R_M, T2R_N)
-        tTR_rAcc = cute.…
+        tTR_rAcc = cute.make_rmem_tensor(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
         return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
 
     def epilog_smem_store_and_partition(
@@ -1860,7 +1831,7 @@ class GemmSm100(GemmSm90):
         thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
         # (R2S, R2S_M, R2S_N, PIPE_D)
         tSR_sC = thr_copy_s2r.partition_S(sC)
-        tRS_rC = cute.…
+        tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype)
         # (R2S, R2S_M, R2S_N)
         tSR_rC = tiled_copy_s2r.retile(tRS_rC)
         return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
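
Note: a recurring one-line change across these hunks (and the `tApAIdx` hunks earlier) is that register-backed fragments are now created through a single entry point whose apparent signature is `cute.make_rmem_tensor(shape_or_layout, element_type)`. The removed right-hand sides are truncated in this rendering, so the old spelling is not recoverable here. For reference, the added call sites gathered in one place (kernel-context code excerpted from this diff, not standalone):

    tApAIdx_m = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean)
    tApAIdx_k = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean)
    tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.acc_dtype)
    tTR_rAcc = cute.make_rmem_tensor(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
    tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype)
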
@@ -1901,6 +1872,7 @@ class GemmSm100(GemmSm90):
                 consumer_group=ab_pipeline_consumer_group,
                 tx_count=self.num_tma_load_bytes,
                 cta_layout_vmnk=cluster_layout_vmnk,
+                defer_sync=True,
             )
         else:
             pipeline_ab = PipelineTmaCpAsyncUmma.create(
@@ -1913,6 +1885,7 @@ class GemmSm100(GemmSm90):
                 producer_drop_count=None
                 if not self.use_2cta_instrs
                 else (2 if not is_leader_cta else 0),
+                defer_sync=True,
             )
         return pipeline_ab
 
@@ -1930,6 +1903,7 @@ class GemmSm100(GemmSm90):
             producer_group=acc_pipeline_producer_group,
             consumer_group=acc_pipeline_consumer_group,
             cta_layout_vmnk=cluster_layout_vmnk,
+            defer_sync=True,
         )
 
     def make_sched_pipeline(
@@ -1941,13 +1915,13 @@ class GemmSm100(GemmSm90):
         # Threads/warps participating in this pipeline
         sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
         cluster_size = cute.size(cluster_layout_mnk)
-        # Each warp …
+        # Each warp will contribute 1 to the arrive count
         warps_per_cta = self.num_ab_load_warps + len(
             (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id)
         )
         if has_C:
             warps_per_cta += 1
-        consumer_arrive_cnt = warps_per_cta * cluster_size
+        consumer_arrive_cnt = warps_per_cta * cluster_size
         sched_pipeline_consumer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread, consumer_arrive_cnt
         )
@@ -1958,6 +1932,7 @@ class GemmSm100(GemmSm90):
             consumer_group=sched_pipeline_consumer_group,
             # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
             consumer_mask=None if const_expr(cluster_size == 1) else 0,
+            defer_sync=True,
         )
 
     @cute.jit
@@ -1965,9 +1940,7 @@ class GemmSm100(GemmSm90):
         self, a_prefetch_pipeline_mbar_ptr: cute.Pointer
     ) -> pipeline.PipelineAsync:
         producer_cnt = 32
-        a_prefetch_producer_group = pipeline.CooperativeGroup(
-            pipeline.Agent.Thread, producer_cnt, alignment=producer_cnt
-        )
+        a_prefetch_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
         consumer_arrive_cnt = self.num_ab_load_warps - 1
         a_prefetch_consumer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread, consumer_arrive_cnt
@@ -1977,6 +1950,7 @@ class GemmSm100(GemmSm90):
             num_stages=self.a_prefetch_stage,
             producer_group=a_prefetch_producer_group,
             consumer_group=a_prefetch_consumer_group,
+            defer_sync=True,
        )
 
     @classmethod
@@ -2721,10 +2695,10 @@ def run(
     tflops = flops / (timing * 1e9)  # Convert to TFlops
     print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
 
-    …
-    …
-    …
-    …
+    time.sleep(0.5)
+    timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
+    tflops_cublas = flops / (timing_cublas * 1e9)  # Convert to TFlops
+    print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
 
 
 if __name__ == "__main__":
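
Note on the final hunk: the benchmark section of `run()` regains a cuBLAS comparison (the four removed lines are truncated in this rendering). Assuming `do_bench` here is Triton's `triton.testing.do_bench`, which returns a runtime in milliseconds, a self-contained sketch of the same comparison harness looks like this:

    import time
    from triton.testing import do_bench

    def compare(fn_cute, fn_cublas, flops, warmup=10, repeats=100):
        timing = do_bench(fn_cute, warmup=warmup, rep=repeats)
        print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {flops / (timing * 1e9):.1f}")
        time.sleep(0.5)  # brief pause so the second measurement starts from similar GPU clocks
        timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
        print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {flops / (timing_cublas * 1e9):.1f}")
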