quack-kernels 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/gemm_act.py CHANGED
@@ -20,7 +20,8 @@ from quack.gemm_sm90 import GemmSm90
 from quack.gemm_sm100 import GemmSm100
 from quack.gemm_default_epi import GemmDefaultEpiMixin
 from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
-from quack.gemm_wrapper_utils import GemmWrapperBase
+from quack.gemm_wrapper_utils import GemmTensorInfo, GemmWrapperBase
+from quack.layout_utils import permute_gated_Cregs_b16
 import quack.sm90_utils as sm90_utils
 import quack.copy_utils as copy_utils
 import quack.activation
@@ -241,9 +242,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
 
         def tma_store_fn(src_idx, dst_idx):
             # Fence and barrier to make sure shared memory store is visible to TMA store
-            cute.arch.fence_proxy(
-                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-            )
+            cute.arch.fence_view_async_shared()
             epilogue_barrier.arrive_and_wait()
             # Copy from shared memory to global memory
             if is_tma_warp:
@@ -268,9 +267,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
             epi_pipeline.consumer_wait(epi_read_state)
             cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
             # Fence to make sure shared memory read is visible to TMA load
-            cute.arch.fence_proxy(
-                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-            )
+            cute.arch.fence_view_async_shared()
             cute.arch.sync_warp()
             with cute.arch.elect_one():
                 epi_pipeline.consumer_release(epi_read_state)
@@ -327,7 +324,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
         # Apply activation function if provided
         # If we don't have .shape here, the compiler generates local stores and loads
         if const_expr(params.act_fn is not None):
-            tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
+            tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
             if const_expr(self.arch < 100):
                 for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
                     tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
@@ -508,3 +505,294 @@ def gemm_act(
 
 
 gemm_act.compile_cache = {}
+
+
+class GemmGatedMixin(GemmActMixin):
+    def epi_to_underlying_arguments(
+        self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None
+    ) -> GemmActMixin.EpilogueParams:
+        self.postact_dtype = args.mPostAct.element_type
+        self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
+        assert self.postact_dtype.width == 16, "GemmGated only supports 16bit postact for now"
+        assert self.d_layout is None or self.d_layout.is_n_major_c()
+        assert self.postact_layout.is_n_major_c()
+        if self.arch == 90:
+            assert self.cta_tile_shape_mnk[1] % 32 == 0, (
+                "GemmGatedSm90 requires tileN to be divisible by 32"
+            )
+
+        self.cta_tile_shape_postact_mn = (
+            self.cta_tile_shape_mnk[0],
+            self.cta_tile_shape_mnk[1] // 2,
+        )
+        if isinstance(self.epi_tile[1], cute.Layout):
+            epi_tile_postact_1 = cute.recast_layout(2, 1, self.epi_tile[1])
+        else:
+            epi_tile_postact_1 = self.epi_tile[1] // 2
+        epi_tile_postact = (self.epi_tile[0], epi_tile_postact_1)
+        utils_cls = sm100_utils if self.arch == 100 else sm90_utils
+        epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
+            self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
+        )
+        tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
+            args.mPostAct,
+            epi_postact_smem_layout_staged,
+            epi_tile_postact,
+            op_type="store",
+        )
+        # Assume all strides are divisible by 32 bits except the last stride
+        new_stride = lambda t: tuple(
+            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
+            for s in t.stride
+        )
+        mRowVecBroadcast, mColVecBroadcast = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            if t is not None
+            else None
+            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
+        ]
+        return self.EpilogueParams(
+            tma_atom_postact,
+            tma_tensor_postact,
+            epi_postact_smem_layout_staged,
+            epi_tile_postact,
+            args.act_fn,
+            alpha=args.alpha,
+            beta=args.beta,
+            mRowVecBroadcast=mRowVecBroadcast,
+            mColVecBroadcast=mColVecBroadcast,
+        )
+
+    @staticmethod
+    def epi_smem_bytes_per_stage(
+        args: GemmActMixin.EpilogueArguments,
+        cta_tile_shape_mnk: Tuple[int, int, int],
+        epi_tile: cute.Tile,
+    ) -> int:
+        postact_dtype = args.mPostAct.element_type
+        postact_bytes_per_stage = (cute.size(cute.shape(epi_tile)) // 2) * (
+            postact_dtype.width // 8
+        )
+        rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
+            args, cta_tile_shape_mnk, epi_tile
+        )
+        return postact_bytes_per_stage + rowvec_colvec_bytes
+
+    @cute.jit
+    def epi_visit_subtile(
+        self,
+        params: GemmActMixin.EpilogueParams,
+        epi_loop_tensors: Tuple[cute.Tensor, ...],
+        tRS_rD: cute.Tensor,
+        tRS_rC: Optional[cute.Tensor] = None,
+    ) -> Optional[cute.Tensor]:
+        GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
+        tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout)
+        # If we don't have .shape here, the compiler generates local stores and loads
+        tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype)
+        if const_expr(self.arch < 100):
+            for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
+                tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1])
+        else:
+            for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
+                tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
+                    (tRS_rD[4 * i], tRS_rD[4 * i + 2]), (tRS_rD[4 * i + 1], tRS_rD[4 * i + 3])
+                )
+        # Type conversion
+        tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype)
+        tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
+        if const_expr(self.arch == 90):
+            # Only need this if we're using STSM
+            permute_gated_Cregs_b16(tRS_rPostAct_out)
+        return tRS_rPostAct_out
+
+
+class GemmGatedSm90(GemmGatedMixin, GemmSm90):
+    pass
+
+
+class GemmGatedSm100(GemmGatedMixin, GemmSm100):
+    pass
+
+
+gate_fn_map = {
+    "swiglu": quack.activation.swiglu,
+    "swiglu_oai": quack.activation.swiglu_oai,
+    "reglu": quack.activation.reglu,
+    "geglu": quack.activation.geglu,
+    "glu": quack.activation.glu,
+}
+
+
+def gemm_gated(
+    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
+    B: Tensor,  # (l, n, k)
+    D: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
+    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
+    PostAct: Tensor,  # (l, m, n//2) or (total_m, n//2) if varlen_m
+    tile_count_semaphore: Optional[Tensor],  # (1,)
+    activation: Optional[str],
+    tile_M: int,
+    tile_N: int,
+    cluster_M: int,
+    cluster_N: int,
+    pingpong: bool = False,
+    persistent: bool = True,
+    max_swizzle_size: int = 8,
+    rowvec_bias: Optional[Tensor] = None,  # (l, n)
+    colvec_bias: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
+    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
+    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
+) -> None:
+    if cu_seqlens_m is not None:
+        assert persistent, "varlen_m requires persistent=True"
+        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
+        if D is not None:
+            assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
+        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
+    gather_A = A_idx is not None
+    if gather_A:
+        assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
+        assert cluster_N == 1, "gather_A requires cluster_N=1"
+    assert activation in gate_fn_map, f"Unsupported activation {activation}"
+
+    # Special validation for PostAct shape
+    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+        A, B, D, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
+    )
+
+    # PostAct shape validation depends on varlen_m
+    if cu_seqlens_m is not None:
+        # varlen_m case: PostAct is 2D (total_m, n//2)
+        assert PostAct.dim() == 2 and PostAct.is_cuda, (
+            "PostAct must be a 2D CUDA tensor for varlen_m"
+        )
+        assert PostAct.shape == (
+            M,
+            N // 2,
+        ), f"PostAct must have shape {(M, N // 2)}, got {PostAct.shape}"
+    else:
+        # Normal case: PostAct is 3D (l, m, n//2)
+        assert PostAct.dim() == 3 and PostAct.is_cuda, "PostAct must be a 3D CUDA tensor"
+        assert PostAct.shape == (
+            L,
+            M,
+            N // 2,
+        ), f"PostAct must have shape {(L, M, N // 2)}, got {PostAct.shape}"
+
+    tensor_infos["PostAct"] = GemmTensorInfo(PostAct)
+    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
+    GemmWrapperBase.extract_dtypes(tensor_infos)
+    major_configs = {
+        "A": ("m", "k", "l"),
+        "B": ("n", "k", "l"),
+        "D": ("m", "n", "l"),
+        "C": ("m", "n", "l"),
+        "PostAct": ("m", "n", "l"),  # PostAct has shape (m, n//2, l) after permute
+    }
+    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+    device_capacity = get_device_capacity(A.device)
+    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
+    GemmCls = GemmGatedSm100 if device_capacity[0] > 9 else GemmGatedSm90
+
+    acc_dtype = cutlass.Float32
+    tile_shape_mn = (tile_M, tile_N)
+    cluster_shape_mnk = (cluster_M, cluster_N, 1)
+    if not GemmCls.is_valid_dtypes(
+        tensor_infos["A"].dtype,
+        tensor_infos["B"].dtype,
+        acc_dtype,
+        tensor_infos["D"].dtype,
+        tensor_infos["A"].major,
+        tensor_infos["B"].major,
+    ):
+        raise TypeError("Skipping due to unsupported combination of types and majors")
+
+    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
+    act_fn = gate_fn_map[activation]
+    epi_args = GemmCls.EpilogueArguments(
+        tensor_infos["PostAct"].cute_tensor,
+        act_fn,
+        mRowVecBroadcast=(
+            from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1)
+            if rowvec_bias is not None
+            else None
+        ),
+        mColVecBroadcast=(
+            from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
+                leading_dim=1 if cu_seqlens_m is None else 0
+            )
+            if colvec_bias is not None
+            else None
+        ),
+    )
+    scheduler_args = GemmWrapperBase.create_scheduler_args(
+        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
+    )
+
+    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
+    varlen_args = GemmWrapperBase.create_varlen_args(
+        cu_seqlens_m,
+        None,  # cu_seqlens_k
+        A_idx,
+        max_active_clusters,
+        cluster_shape_mnk,
+        tensor_infos,
+        GemmCls.num_epi_tensormaps,
+        pingpong,
+    )
+
+    current_stream = cutlass_torch.current_stream()
+    compile_key = GemmWrapperBase.get_compile_key(
+        tensor_infos,
+        activation,
+        tile_shape_mn,
+        cluster_shape_mnk,
+        pingpong,
+        persistent,
+        tile_count_semaphore is not None,
+        device_capacity,
+        max_swizzle_size,
+        rowvec_bias.dtype if rowvec_bias is not None else None,
+        colvec_bias.dtype if colvec_bias is not None else None,
+        cu_seqlens_m is not None,
+        A_idx is not None,
+        key_tensor_names=("A", "B", "D", "PostAct", "C"),
+    )
+    cache = gemm_gated.compile_cache
+    if compile_key not in cache:
+        if device_capacity[0] == 9:
+            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
+        gemm_obj = GemmCls(
+            acc_dtype,
+            tensor_infos["A"].dtype,
+            tile_shape_mn,
+            cluster_shape_mnk,
+            gather_A=gather_A,
+        )
+        cache[compile_key] = cute.compile(
+            gemm_obj,
+            tensor_infos["A"].cute_tensor,
+            tensor_infos["B"].cute_tensor,
+            tensor_infos["D"].cute_tensor,
+            tensor_infos["C"].cute_tensor,
+            epi_args,
+            scheduler_args,
+            varlen_args,
+            current_stream,
+        )
+    cache[compile_key](
+        tensor_infos["A"].cute_tensor,
+        tensor_infos["B"].cute_tensor,
+        tensor_infos["D"].cute_tensor,
+        tensor_infos["C"].cute_tensor,
+        epi_args,
+        scheduler_args,
+        varlen_args,
+        current_stream,
+    )
+
+
+gemm_gated.compile_cache = {}
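
For orientation, a minimal usage sketch of the new gemm_gated entry point in the plain (non-varlen) case. The argument order follows the signature added above; the shapes, dtypes, and tile/cluster sizes below are illustrative assumptions, not values prescribed by the package.

import torch
from quack.gemm_act import gemm_gated

l, m, n, k = 1, 4096, 8192, 4096  # n must be even; PostAct holds n // 2 columns
A = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(l, n, k, device="cuda", dtype=torch.bfloat16)
D = torch.empty(l, m, n, device="cuda", dtype=torch.bfloat16)
PostAct = torch.empty(l, m, n // 2, device="cuda", dtype=torch.bfloat16)  # 16-bit, n-major

gemm_gated(
    A, B, D, C=None, PostAct=PostAct,
    tile_count_semaphore=None,
    activation="swiglu",  # one of: swiglu, swiglu_oai, reglu, geglu, glu
    tile_M=128, tile_N=256,  # assumed tile sizes; SM90 requires tile_N % 32 == 0
    cluster_M=2, cluster_N=1,
)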