quack-kernels 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/gemm_sm90.py CHANGED
@@ -12,6 +12,7 @@ import cuda.bindings.driver as cuda
  import cutlass
  import cutlass.cute as cute
  import cutlass.pipeline as pipeline
+ from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
  from cutlass.cute.nvgpu import cpasync, warp, warpgroup
  import cutlass.utils.hopper_helpers as sm90_utils
  from cutlass import Int32, Float32, Float16, Boolean, const_expr
@@ -26,6 +27,7 @@ from quack.tile_scheduler import (
  TileScheduler,
  VarlenMTileSchedulerArguments,
  VarlenMTileScheduler,
+ PersistenceMode,
  )
  from quack.varlen_utils import VarlenArguments, VarlenManager

@@ -226,8 +228,6 @@ class GemmSm90:
  self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
  self.num_ab_load_warps = 1 if not self.gather_A else 4
  self.ab_load_warp_id = self.mma_warp_groups * 4
- # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
- # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps

  regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
  math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
@@ -324,8 +324,6 @@ class GemmSm90:
  epilogue_args,
  cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
  self.occupancy,
- # epi_smem will reuse smem ab if not persistent.
- overlap_sD_sA=not self.is_persistent,
  )
  self.sched_stage = 2 if self.pingpong else 1

@@ -401,10 +399,12 @@ class GemmSm90:
  assert (varlen_args.mAIdx is not None) == self.gather_A

  # Assume all strides are divisible by 128 bits except the last stride
- new_stride = lambda t: tuple(
- cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
- for s in t.stride
- )
+ def new_stride(t: cute.Tensor):
+ return tuple(
+ cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
+ for s in t.stride
+ )
+
  mA, mD = [
  cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
  if t is not None
@@ -461,9 +461,7 @@ class GemmSm90:
  tile_sched_params, scheduler_args.max_active_clusters
  )

- epi_smem_size = (
- cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
- )
+ epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
  epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0

  @cute.struct
@@ -471,7 +469,7 @@ class GemmSm90:
  ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
  epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
  sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
- tile_count: cute.struct.MemRange[Int32, self.sched_stage]
+ scheduler_data: cute.struct.MemRange[Int32, self.sched_stage * 4]
  sD: cute.struct.Align[
  cute.struct.MemRange[
  self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
@@ -585,17 +583,13 @@ class GemmSm90:

  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

- # /////////////////////////////////////////////////////////////////////////////
- # Prefetch Tma desc
- # /////////////////////////////////////////////////////////////////////////////
+ # Prefetch Tma desc
  if warp_idx == self.ab_load_warp_id:
  for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
  if const_expr(tma_atom is not None):
  cpasync.prefetch_descriptor(tma_atom)

- # /////////////////////////////////////////////////////////////////////////////
- # Alloc and init AB full/empty + ACC full mbar (pipeline)
- # /////////////////////////////////////////////////////////////////////////////
+ # Alloc and init AB full/empty + ACC full mbar (pipeline)
  smem = cutlass.utils.SmemAllocator()
  storage = smem.allocate(self.shared_storage)

@@ -611,28 +605,24 @@ class GemmSm90:
  epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
  )
  sched_pipeline = None
- tile_count = None
- if const_expr(tile_sched_params.tile_count_semaphore is not None):
- # Dynamic persistent scheduler
+ scheduler_data = None
+ if const_expr(self.is_persistent):
  sched_pipeline = self.make_sched_pipeline(
  cluster_layout_mnk,
  sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
  varlen_k=varlen_k,
  )
- tile_count = storage.tile_count.get_tensor((self.sched_stage,))
+ scheduler_data = storage.scheduler_data.get_tensor((4, self.sched_stage))
+
+ # Cluster arrive after barrier init
+ pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True)

- # ///////////////////////////////////////////////////////////////////////////////
- # Generate smem tensor A/B
- # ///////////////////////////////////////////////////////////////////////////////
+ # Generate smem tensor A/B
  sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
  sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
  sD = None
  if const_expr(has_D):
- if const_expr(not self.is_persistent):
- sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
- sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
- else:
- sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
+ sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
  sC = None
  if const_expr(has_C):
  sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
@@ -654,11 +644,14 @@ class GemmSm90:
  )

  TileSchedulerCls = partial(
- TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
+ TileSchedulerCls.create, tile_sched_params, scheduler_data, sched_pipeline
  )

+ # Cluster wait for barrier init
+ pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1])
+
  if warp_idx >= self.ab_load_warp_id:
- cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
+ cute.arch.setmaxregister_decrease(self.num_regs_load)
  if (
  warp_idx >= self.ab_load_warp_id
  and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
@@ -668,9 +661,7 @@ class GemmSm90:
  varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
  tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
  tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
- # ///////////////////////////////////////////////////////////////////////////////
  # Get mcast mask
- # ///////////////////////////////////////////////////////////////////////////////
  cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
  block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
  a_mcast_mask = cute.make_layout_image_mask(
@@ -686,7 +677,7 @@ class GemmSm90:
  is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
  if const_expr(cute.size(cluster_layout_mnk) > 1):
  is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
- tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
+ tile_scheduler = TileSchedulerCls()
  work_tile = tile_scheduler.initial_work_tile_info()
  ab_producer_state = make_pipeline_state(
  pipeline.PipelineUserType.Producer, self.ab_stage
@@ -698,14 +689,9 @@ class GemmSm90:
  tile_coord_mnkl = work_tile.tile_idx
  batch_idx = tile_coord_mnkl[3]
  varlen_manager.update_tensormap_AB(
- batch_idx,
- self.a_layout,
- self.b_layout,
- is_tma_warp,
+ batch_idx, self.a_layout, self.b_layout, is_tma_warp
  )
- # ///////////////////////////////////////////////////////////////////////////
- # Local_tile partition global tensors
- # ///////////////////////////////////////////////////////////////////////////
+ # Local_tile partition global tensors
  if const_expr(not self.gather_A):
  mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
  # (bM, bK, RestK)
@@ -736,9 +722,7 @@ class GemmSm90:
  cute.select(self.cta_tile_shape_mnk, [1, 2]),
  (tile_coord_mnkl[1], None),
  )
- # //////////////////////////////////////////////////////////////////////////
- # Partition shared tensor for TMA load A/B
- # //////////////////////////////////////////////////////////////////////////
+ # Partition shared tensor for TMA load A/B
  varlen_manager.fence_tensormap_update_AB(is_tma_warp)
  len_m = varlen_manager.len_m(batch_idx)
  len_k = varlen_manager.len_k(batch_idx)
@@ -810,19 +794,20 @@ class GemmSm90:
  k_tile_cnt,
  varlen_m=varlen_m,
  )
- tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
  tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
  work_tile = tile_scheduler.get_current_work()
  # End of persistent scheduler loop
  if const_expr(self.pingpong and not varlen_k):
  # Need to write the tile_idx to smem for the next WG in the pingpong mode
- tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+ if is_scheduler_warp:
+ tile_scheduler.write_work_tile_to_smem(work_tile)
+ work_tile = tile_scheduler.get_current_work()
  ab_pipeline.producer_tail(ab_producer_state)
  if is_scheduler_warp:
  tile_scheduler.producer_tail()

  if warp_idx < self.ab_load_warp_id:
- cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
+ cute.arch.setmaxregister_increase(self.num_regs_mma)
  is_tma_warp = Boolean(
  (not self.pingpong and warp_idx == 0)
  or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
@@ -832,34 +817,30 @@ class GemmSm90:
  )
  tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
  tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
- # //////////////////////////////////////////////////////////////////////////////
- # Partition global tensor for TiledMMA_A/B/C
- # //////////////////////////////////////////////////////////////////////////////
+ # Partition global tensor for TiledMMA_A/B/C
  tidx, _, _ = cute.arch.thread_idx()
  warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
  if const_expr(self.pingpong):
  tidx = tidx % self.num_threads_per_warp_group
  warp_group_thread_layout = cute.make_layout(
- self.mma_warp_groups if not self.pingpong else 1,
+ self.mma_warp_groups if const_expr(not self.pingpong) else 1,
  stride=self.num_threads_per_warp_group,
  )
  thr_mma = tiled_mma.get_slice(
  warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
  )

- # //////////////////////////////////////////////////////////////////////////////
- # Make fragments
- # //////////////////////////////////////////////////////////////////////////////
+ # Make fragments
  tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
  tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))

  acc_shape = tiled_mma.partition_shape_C(
  cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
  )
- acc = cute.make_fragment(acc_shape, self.acc_dtype)
+ acc = cute.make_rmem_tensor(acc_shape, self.acc_dtype)
  acc_slow = None
  if const_expr(self.fp8_slow_accum):
- acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
+ acc_slow = cute.make_rmem_tensor(acc_shape, self.acc_dtype)

  if const_expr(self.pingpong):
  if warp_group_idx == 0:
@@ -879,10 +860,8 @@ class GemmSm90:
  pipeline.PipelineUserType.Producer, self.epi_c_stage
  )
  tile_scheduler = TileSchedulerCls()
- work_tile = None
+ work_tile = tile_scheduler.initial_work_tile_info()
  if const_expr(self.pingpong):
- if const_expr(varlen_k):
- work_tile = tile_scheduler.initial_work_tile_info()
  if warp_idx >= 4:
  # Advance 2nd Math WG pipeline states to the end of 1st Math WG
  epi_read_state.advance_iters(c_tile_cnt)
@@ -893,13 +872,9 @@ class GemmSm90:
  len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
  k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
  ab_read_state.advance_iters(k_tile_cnt)
+ # TODO: do we need to check if work_tile is valid?
  tile_scheduler.advance_to_next_work()
- if const_expr(varlen_k):
- work_tile = tile_scheduler.get_current_work()
- if const_expr(not varlen_k):
- work_tile = tile_scheduler.initial_work_tile_info()
- else:
- work_tile = tile_scheduler.initial_work_tile_info()
+ work_tile = tile_scheduler.get_current_work()
  if const_expr(varlen_m):
  # wait tensormap initialization complete before update
  varlen_manager.fence_tensormap_init()
@@ -910,11 +885,7 @@ class GemmSm90:
  epilogue_params, varlen_params.cu_seqlens_m, batch_idx
  )
  varlen_manager.update_tensormap_epi(
- batch_idx,
- self.d_layout,
- epi_shapes,
- epi_orders,
- is_tma_warp,
+ batch_idx, self.d_layout, epi_shapes, epi_orders, is_tma_warp
  )
  len_k = varlen_manager.len_k(batch_idx)
  k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
@@ -933,9 +904,7 @@ class GemmSm90:
  if k_tile_cnt == 0:
  acc.fill(0.0)

- # /////////////////////////////////////////////////////////////////////////////
- # EPILOGUE
- # /////////////////////////////////////////////////////////////////////////////
+ # EPILOGUE
  if const_expr(self.pingpong):
  self.pingpong_barrier_sync(warp_group_idx, "epi")

@@ -983,11 +952,6 @@ class GemmSm90:
  else:
  tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None

- # Wait for all warp groups in the thread block to finish, because smem for tensor
- # A in the mainloop is reused in the epilogue if not persistent.
- if const_expr(not self.is_persistent):
- epilogue_barrier.arrive_and_wait()
-
  self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)

  epi_read_state, epi_producer_state = self.epilogue(
@@ -1073,9 +1037,7 @@ class GemmSm90:
  peek_ab_empty_status = Boolean(True)
  if 0 < k_tile_cnt:
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
- # /////////////////////////////////////////////////////////////////////////
  # TMA load
- # /////////////////////////////////////////////////////////////////////////
  for k_tile in cutlass.range(k_tile_cnt, unroll=1):
  # Wait for A/B buffers to be empty before loading into them
  # Also sets the transaction barrier for the A/B buffers
@@ -1112,9 +1074,7 @@ class GemmSm90:
  peek_ab_empty_status = Boolean(True)
  if 0 < k_tile_cnt:
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
- # /////////////////////////////////////////////////////////////////////////
  # TMA load on B and cp.async on A
- # /////////////////////////////////////////////////////////////////////////
  for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
  prefetch_out = ()
  if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
@@ -1172,9 +1132,7 @@ class GemmSm90:
  k_tile_cnt: Int32,
  warp_group_idx: Int32,
  ) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
- # /////////////////////////////////////////////////////////////////////////////
- # Prologue MMAs
- # /////////////////////////////////////////////////////////////////////////////
+ # Prologue MMAs
  k_pipe_mmas = 1
  ab_release_state = ab_read_state.clone()
  num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
@@ -1204,13 +1162,10 @@ class GemmSm90:
  warpgroup.wait_group(0)
  acc_slow.store(acc.load())

- # /////////////////////////////////////////////////////////////////////////////
- # MAINLOOP
- # /////////////////////////////////////////////////////////////////////////////
+ # MAINLOOP
  for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
  # Wait for TMA copies to complete
  ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
- # WGMMA
  warpgroup.fence()
  if const_expr(self.fp8_slow_accum):
  tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
@@ -1308,9 +1263,7 @@ class GemmSm90:

  def tma_store_fn(src_idx, dst_idx):
  # Fence and barrier to make sure shared memory store is visible to TMA store
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
- )
+ cute.arch.fence_view_async_shared()
  epilogue_barrier.arrive_and_wait()
  # Copy from shared memory to global memory
  if is_tma_warp:
@@ -1336,9 +1289,7 @@ class GemmSm90:
  epi_pipeline.consumer_wait(epi_read_state)
  cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
  # Fence to make sure shared memory read is visible to TMA load
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
- )
+ cute.arch.fence_view_async_shared()
  cute.arch.sync_warp()
  with cute.arch.elect_one():
  epi_pipeline.consumer_release(epi_read_state)
@@ -1391,6 +1342,12 @@ class GemmSm90:
  varlen_args,
  ):
  """Create scheduler arguments. Override in subclasses for custom schedulers."""
+ if const_expr(not self.is_persistent):
+ persistence_mode = PersistenceMode.NONE
+ elif const_expr(scheduler_args.tile_count_semaphore is not None):
+ persistence_mode = PersistenceMode.DYNAMIC
+ else:
+ persistence_mode = PersistenceMode.STATIC
  if const_expr(varlen_args.mCuSeqlensM is None):
  num_problems = (
  mD.shape[2]
@@ -1413,7 +1370,7 @@ class GemmSm90:
  cluster_shape_mnk=self.cluster_shape_mnk,
  tile_count_semaphore=scheduler_args.tile_count_semaphore,
  batch_idx_permute=scheduler_args.batch_idx_permute,
- is_persistent=self.is_persistent,
+ persistence_mode=persistence_mode,
  )
  else:
  assert mD is not None or not self.gather_A
@@ -1431,7 +1388,7 @@ class GemmSm90:
  tile_shape_mn=self.cta_tile_shape_mnk[:2],
  cluster_shape_mnk=self.cluster_shape_mnk,
  tile_count_semaphore=scheduler_args.tile_count_semaphore,
- is_persistent=self.is_persistent,
+ persistence_mode=persistence_mode,
  )
  return tile_sched_args

@@ -1579,7 +1536,7 @@ class GemmSm90:
  tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
  sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
  tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
- tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
+ tRS_rD = cute.make_rmem_tensor(tRS_rD_shape, self.acc_dtype)
  return tiled_copy_r2s, tRS_rD, tRS_sD

  def epilog_smem_load_and_partition(
@@ -1596,7 +1553,7 @@ class GemmSm90:
  tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
  thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
  tSR_sC = thr_copy_s2r.partition_S(sC)
- tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
+ tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype)
  tSR_rC = thr_copy_s2r.retile(tRS_rC)
  return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC

@@ -1651,6 +1608,7 @@ class GemmSm90:
  consumer_group=ab_pipeline_consumer_group,
  tx_count=self.num_tma_load_bytes,
  cta_layout_vmnk=cluster_layout_vmnk,
+ defer_sync=True,
  )

  def make_epi_pipeline(
@@ -1670,6 +1628,7 @@ class GemmSm90:
  producer_group=epi_pipeline_producer_group,
  consumer_group=epi_pipeline_consumer_group,
  tx_count=tma_copy_c_bytes,
+ defer_sync=True,
  )

  def make_epi_store_pipeline(self):
@@ -1686,13 +1645,13 @@ class GemmSm90:
  # Threads/warps participating in this pipeline
  sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
  cluster_size = cute.size(cluster_layout_mnk)
- # Each warp that are not the scheduler warp will contribute 1 to the arrive count
+ # Each warp will contribute 1 to the arrive count
  # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
  # at each round. If pingpong and not varlen_k, then only 4 mma warp will participate.
  consumer_arrive_cnt = (
  (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
  + self.num_ab_load_warps
- ) * cluster_size - 1
+ ) * cluster_size
  sched_pipeline_consumer_group = pipeline.CooperativeGroup(
  pipeline.Agent.Thread, consumer_arrive_cnt
  )
@@ -1703,6 +1662,7 @@ class GemmSm90:
  consumer_group=sched_pipeline_consumer_group,
  # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
  consumer_mask=None if const_expr(cluster_size == 1) else 0,
+ defer_sync=True,
  )

  @classmethod
@@ -1717,7 +1677,6 @@ class GemmSm90:
  epilogue_args: EpilogueArguments,
  smem_capacity: int,
  occupancy: int,
- overlap_sD_sA: bool = False,
  ) -> Tuple[int, int]:
  """Computes the number of stages for A/B/C operands based on heuristics.

@@ -1738,16 +1697,11 @@ class GemmSm90:
  """

  epi_stage = 4 if epi_tile[1] <= 16 else 2
- if overlap_sD_sA:
- epi_bytes = 0
- else:
- d_bytes_per_stage = (
- cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
- )
- epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
- epilogue_args, cta_tile_shape_mnk, epi_tile
- )
- epi_bytes = epi_bytes_per_stage * epi_stage
+ d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
+ epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
+ epilogue_args, cta_tile_shape_mnk, epi_tile
+ )
+ epi_bytes = epi_bytes_per_stage * epi_stage
  epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
  if c_dtype is not None:
  epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
@@ -1765,7 +1719,7 @@ class GemmSm90:
  # Refine epilogue stages:
  # Calculate remaining smem after allocating for A/B stages and reserved bytes
  # Add remaining unused smem to epilogue
- if not overlap_sD_sA and epi_bytes_per_stage > 0:
+ if epi_bytes_per_stage > 0:
  epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
  return ab_stage, epi_stage, epi_c_stage
 
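The hunks above replace the per-pipeline cluster sync during mbarrier setup with a single deferred arrive/wait pair. The sketch below only illustrates that ordering, assuming a cutlass cute-DSL kernel context; make_ab_pipeline and make_sched_pipeline are stand-ins for the kernel's own helpers (which now pass defer_sync=True), and only the two pipeline_init_* calls are taken verbatim from the diff.

from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait

def init_pipelines(kernel, storage, cluster_shape_mn):
    # Constructing the pipelines with defer_sync=True only initializes their
    # mbarriers; it no longer performs a cluster sync per pipeline.
    ab_pipeline = kernel.make_ab_pipeline(storage)
    sched_pipeline = kernel.make_sched_pipeline(storage)
    # One relaxed cluster arrive once every barrier has been initialized ...
    pipeline_init_arrive(cluster_shape_mn=cluster_shape_mn, is_relaxed=True)
    # ... and one matching wait before any producer or consumer touches them.
    pipeline_init_wait(cluster_shape_mn=cluster_shape_mn)
    return ab_pipeline, sched_pipeline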
quack/gemm_symmetric.py CHANGED
@@ -115,9 +115,7 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
  pid_m = tile_coord_mnkl[0]
  pid_n = tile_coord_mnkl[1]
  # Fence and barrier to make sure shared memory store is visible to TMA store
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
- )
+ cute.arch.fence_view_async_shared()
  epilogue_barrier.arrive_and_wait()
  # Copy from shared memory to global memory
  if is_tma_warp:
@@ -145,9 +143,7 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
  epi_pipeline.consumer_wait(epi_read_state)
  cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
  # Fence to make sure shared memory read is visible to TMA load
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
- )
+ cute.arch.fence_view_async_shared()
  cute.arch.sync_warp()
  with cute.arch.elect_one():
  epi_pipeline.consumer_release(epi_read_state)
quack/layout_utils.py CHANGED
@@ -6,8 +6,6 @@ import cutlass.cute as cute

  from cutlass import Int32, const_expr

- from quack.utils import prmt
-

  def transpose_view(a: cute.Tensor) -> cute.Tensor:
  """Transpose the first two dimensions of a tensor on smem."""
@@ -55,8 +53,8 @@ def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
  lower0 = lower if lane_03 else upper
  upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
  lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
- t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
- t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
+ t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper)
+ t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower)


  @cute.jit
@@ -187,6 +185,10 @@ def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
  return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))


+ def reshape_acc_to_mn(acc: cute.Tensor) -> cute.Tensor:
+ return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
+
+
  @cute.jit
  def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
  # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
@@ -227,6 +229,10 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
  return rA_mma_view


+ def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor:
+ return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout))
+
+
  def convert_layout_zero_stride(
  input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
  ) -> cute.Layout:
quack/linear.py CHANGED
@@ -9,6 +9,7 @@ from torch.amp import custom_fwd, custom_bwd


  from quack.gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact
+ from quack.gemm_interface import gemm_gated, gemm_dgated


  def linear_fwd_convert_type(*tensors):
@@ -228,6 +229,42 @@ def act_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=
  return fn_cls.apply(preact, weight, x, activation, fuse_grad_accum)


+ class LinearGatedFunc(LinearActFunc):
+ matmul_fwd_fn = gemm_gated
+
+
+ class LinearGatedUntunedFunc(LinearActFunc):
+ # Passing in tuned=False to disable tuning at runtime
+ matmul_fwd_fn = partial(gemm_gated, tuned=False)
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
+ matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+ matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
+
+
+ def linear_gated_func(
+ x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True
+ ):
+ fn_cls = LinearGatedFunc if tuned else LinearGatedUntunedFunc
+ return fn_cls.apply(x, weight, activation, bias, store_preact, fuse_grad_accum)
+
+
+ class DGatedLinearFunc(DActLinearFunc):
+ matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True)
+
+
+ class DGatedLinearUntunedFunc(DActLinearFunc):
+ # Passing in tuned=False to disable tuning at runtime
+ matmul_fwd_fn = partial(gemm, tuned=False)
+ matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True, tuned=False)
+ matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+ matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
+
+
+ def gated_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=True):
+ fn_cls = DGatedLinearFunc if tuned else DGatedLinearUntunedFunc
+ return fn_cls.apply(preact, weight, x, activation, fuse_grad_accum)
+
+
  class Linear(nn.Linear):
  def __init__(
  self,
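For orientation, a hypothetical call into the gated-linear wrappers added above. The shapes, dtype, and the "swiglu" activation name are illustrative assumptions, not values taken from the package; only the function name, argument order, and keyword names come from the diff.

import torch
from quack.linear import linear_gated_func

# Gated layers pack the value and gate projections together, so the weight's
# leading dimension is doubled here (an assumption about the expected layout).
x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(2 * 11008, 4096, device="cuda", dtype=torch.bfloat16)

# tuned=False routes through LinearGatedUntunedFunc and skips runtime autotuning;
# store_preact=True keeps the pre-activation around for the backward pass.
out = linear_gated_func(x, w, "swiglu", store_preact=True, tuned=False)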