quack-kernels 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +72 -64
- quack/broadcast_utils.py +1 -1
- quack/copy_utils.py +14 -18
- quack/fast_math.py +29 -76
- quack/gemm_act.py +296 -8
- quack/gemm_dact.py +520 -4
- quack/gemm_default_epi.py +4 -4
- quack/gemm_interface.py +363 -0
- quack/gemm_sm100.py +62 -88
- quack/gemm_sm90.py +68 -114
- quack/gemm_symmetric.py +2 -6
- quack/layout_utils.py +2 -4
- quack/linear.py +37 -0
- quack/pipeline.py +59 -89
- quack/reduce.py +2 -2
- quack/rmsnorm.py +1 -3
- quack/sm90_utils.py +5 -3
- quack/sort/bitonic_sort.py +3 -3
- quack/tile_scheduler.py +310 -256
- quack/topk.py +4 -4
- quack/utils.py +76 -40
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/METADATA +2 -2
- quack_kernels-0.2.6.dist-info/RECORD +45 -0
- quack_kernels-0.2.5.dist-info/RECORD +0 -45
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/top_level.txt +0 -0
quack/gemm_act.py
CHANGED
|
@@ -20,7 +20,8 @@ from quack.gemm_sm90 import GemmSm90
|
|
|
20
20
|
from quack.gemm_sm100 import GemmSm100
|
|
21
21
|
from quack.gemm_default_epi import GemmDefaultEpiMixin
|
|
22
22
|
from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
|
23
|
-
from quack.gemm_wrapper_utils import GemmWrapperBase
|
|
23
|
+
from quack.gemm_wrapper_utils import GemmTensorInfo, GemmWrapperBase
|
|
24
|
+
from quack.layout_utils import permute_gated_Cregs_b16
|
|
24
25
|
import quack.sm90_utils as sm90_utils
|
|
25
26
|
import quack.copy_utils as copy_utils
|
|
26
27
|
import quack.activation
|
|
@@ -241,9 +242,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
|
|
|
241
242
|
|
|
242
243
|
def tma_store_fn(src_idx, dst_idx):
|
|
243
244
|
# Fence and barrier to make sure shared memory store is visible to TMA store
|
|
244
|
-
cute.arch.
|
|
245
|
-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
246
|
-
)
|
|
245
|
+
cute.arch.fence_view_async_shared()
|
|
247
246
|
epilogue_barrier.arrive_and_wait()
|
|
248
247
|
# Copy from shared memory to global memory
|
|
249
248
|
if is_tma_warp:
|
|
@@ -268,9 +267,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
|
|
|
268
267
|
epi_pipeline.consumer_wait(epi_read_state)
|
|
269
268
|
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
|
270
269
|
# Fence to make sure shared memory read is visible to TMA load
|
|
271
|
-
cute.arch.
|
|
272
|
-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
273
|
-
)
|
|
270
|
+
cute.arch.fence_view_async_shared()
|
|
274
271
|
cute.arch.sync_warp()
|
|
275
272
|
with cute.arch.elect_one():
|
|
276
273
|
epi_pipeline.consumer_release(epi_read_state)
|
|
@@ -327,7 +324,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
|
|
|
327
324
|
# Apply activation function if provided
|
|
328
325
|
# If we don't have .shape here, the compiler generates local stores and loads
|
|
329
326
|
if const_expr(params.act_fn is not None):
|
|
330
|
-
tRS_rPostAct = cute.
|
|
327
|
+
tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
|
|
331
328
|
if const_expr(self.arch < 100):
|
|
332
329
|
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
|
333
330
|
tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
|
|
@@ -508,3 +505,294 @@ def gemm_act(
|
|
|
508
505
|
|
|
509
506
|
|
|
510
507
|
# Module-level cache attached to gemm_act, reset to empty at import time.
# NOTE(review): presumably holds compiled kernels keyed by compile key, mirroring
# gemm_gated.compile_cache; gemm_act's body is outside this view — confirm.
gemm_act.compile_cache = {}
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
class GemmGatedMixin(GemmActMixin):
    """Epilogue mixin for gated activations (e.g. SwiGLU / GeGLU).

    Unlike GemmActMixin, which applies ``act_fn`` to each D element, this mixin
    combines *pairs* of D values into a single PostAct value. PostAct therefore
    has half as many columns as D (``cta_tile_shape_mnk[1] // 2`` per tile).
    PostAct must be 16-bit and N-major; D, if present, must also be N-major.
    """

    def epi_to_underlying_arguments(
        self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None
    ) -> GemmActMixin.EpilogueParams:
        """Translate user epilogue arguments into kernel-side EpilogueParams.

        Records the PostAct dtype/layout on ``self`` and validates them (16-bit,
        N-major). It then derives the half-width PostAct CTA tile and epilogue
        tile, and builds the staged PostAct shared-memory layout plus the TMA
        store atom/tensor for it. The optional row/column broadcast vectors are
        rewrapped with stride divisibility hints. ``loc``/``ip`` are accepted
        for interface compatibility and are not used in this body.
        """
        self.postact_dtype = args.mPostAct.element_type
        self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
        assert self.postact_dtype.width == 16, "GemmGated only supports 16bit postact for now"
        assert self.d_layout is None or self.d_layout.is_n_major_c()
        assert self.postact_layout.is_n_major_c()
        if self.arch == 90:
            assert self.cta_tile_shape_mnk[1] % 32 == 0, (
                "GemmGatedSm90 requires tileN to be divisible by 32"
            )

        # Each PostAct column is produced from a pair of D columns, so the
        # PostAct tile is half as wide as the D tile along N.
        self.cta_tile_shape_postact_mn = (
            self.cta_tile_shape_mnk[0],
            self.cta_tile_shape_mnk[1] // 2,
        )
        # Halve the epilogue tile along N. epi_tile[1] may be a cute.Layout
        # (recast to halve its extent) or a plain integer.
        if isinstance(self.epi_tile[1], cute.Layout):
            epi_tile_postact_1 = cute.recast_layout(2, 1, self.epi_tile[1])
        else:
            epi_tile_postact_1 = self.epi_tile[1] // 2
        epi_tile_postact = (self.epi_tile[0], epi_tile_postact_1)
        # sm100_utils when arch == 100, sm90_utils otherwise.
        utils_cls = sm100_utils if self.arch == 100 else sm90_utils
        epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
            self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
        )
        tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
            args.mPostAct,
            epi_postact_smem_layout_staged,
            epi_tile_postact,
            op_type="store",
        )
        # Assume all strides are divisible by 32 bits except the last stride
        # (static strides are kept as-is; only dynamic ones get the divisibility hint).
        new_stride = lambda t: tuple(
            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
            for s in t.stride
        )
        # Rewrap the optional bias vectors with the hinted strides; None stays None.
        mRowVecBroadcast, mColVecBroadcast = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
        ]
        return self.EpilogueParams(
            tma_atom_postact,
            tma_tensor_postact,
            epi_postact_smem_layout_staged,
            epi_tile_postact,
            args.act_fn,
            alpha=args.alpha,
            beta=args.beta,
            mRowVecBroadcast=mRowVecBroadcast,
            mColVecBroadcast=mColVecBroadcast,
        )

    @staticmethod
    def epi_smem_bytes_per_stage(
        args: GemmActMixin.EpilogueArguments,
        cta_tile_shape_mnk: Tuple[int, int, int],
        epi_tile: cute.Tile,
    ) -> int:
        """Return the extra shared-memory bytes this epilogue needs per stage.

        The total is a half-size PostAct subtile (``size(epi_tile) // 2``
        elements of the PostAct dtype) plus whatever
        GemmDefaultEpiMixin.epi_smem_bytes_per_stage reports for the
        row/column broadcast vectors.
        NOTE(review): D's own epilogue buffer is not counted here; presumably it
        is accounted for by the caller — confirm.
        """
        postact_dtype = args.mPostAct.element_type
        # Half as many PostAct elements as D elements per epilogue subtile.
        postact_bytes_per_stage = (cute.size(cute.shape(epi_tile)) // 2) * (
            postact_dtype.width // 8
        )
        rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
            args, cta_tile_shape_mnk, epi_tile
        )
        return postact_bytes_per_stage + rowvec_colvec_bytes

    @cute.jit
    def epi_visit_subtile(
        self,
        params: GemmActMixin.EpilogueParams,
        epi_loop_tensors: Tuple[cute.Tensor, ...],
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor] = None,
    ) -> Optional[cute.Tensor]:
        # Gated epilogue for one subtile: run the default epilogue on tRS_rD, then
        # combine pairs of D values with params.act_fn into a half-size PostAct
        # register fragment, converted to self.postact_dtype and returned.
        #
        # GemmDefaultEpiMixin is called explicitly (not via super()), which skips
        # GemmActMixin's elementwise act_fn. Its return value is discarded here;
        # presumably it updates tRS_rD in place (alpha/beta/C/bias) — confirm.
        GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
        # Layout with half as many elements as tRS_rD (one output per D pair).
        tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout)
        # If we don't have .shape here, the compiler generates local stores and loads
        tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype)
        if const_expr(self.arch < 100):
            # Scalar form: output i = act_fn(D[2i], D[2i+1]).
            for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
                tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1])
        else:
            # Paired form: act_fn produces two outputs per call. Its first argument
            # packs the first operands of outputs 2i and 2i+1 (D[4i], D[4i+2]); its
            # second argument packs their second operands (D[4i+1], D[4i+3]). Output j
            # still uses D[2j] and D[2j+1], matching the scalar branch.
            for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
                tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
                    (tRS_rD[4 * i], tRS_rD[4 * i + 2]), (tRS_rD[4 * i + 1], tRS_rD[4 * i + 3])
                )
        # Type conversion
        tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype)
        tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
        if const_expr(self.arch == 90):
            # Only need this if we're using STSM
            permute_gated_Cregs_b16(tRS_rPostAct_out)
        return tRS_rPostAct_out
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
class GemmGatedSm90(GemmGatedMixin, GemmSm90):
    """SM90 gated-activation GEMM.

    Adds no members of its own. Because GemmGatedMixin precedes GemmSm90 in the
    bases, the mixin's epilogue hooks override GemmSm90's.
    """
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
class GemmGatedSm100(GemmGatedMixin, GemmSm100):
    """SM100 gated-activation GEMM.

    Adds no members of its own. Because GemmGatedMixin precedes GemmSm100 in the
    bases, the mixin's epilogue hooks override GemmSm100's.
    """
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
# Activation name accepted by gemm_gated -> gated activation function from
# quack.activation. Each function combines a pair of GEMM outputs into one value.
gate_fn_map = dict(
    swiglu=quack.activation.swiglu,
    swiglu_oai=quack.activation.swiglu_oai,
    reglu=quack.activation.reglu,
    geglu=quack.activation.geglu,
    glu=quack.activation.glu,
)
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def gemm_gated(
    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
    B: Tensor,  # (l, n, k)
    D: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
    PostAct: Tensor,  # (l, m, n//2) or (total_m, n//2) if varlen_m
    tile_count_semaphore: Optional[Tensor],  # (1,)
    activation: Optional[str],
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = False,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    rowvec_bias: Optional[Tensor] = None,  # (l, n)
    colvec_bias: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
) -> None:
    """Launch a GEMM whose epilogue applies a gated activation.

    The GEMM of ``A`` (m x k) and ``B`` (n x k) is computed per batch. When ``D``
    is given, the pre-activation result is stored to ``D``. ``gate_fn_map[activation]``
    is applied to pairs of output values, and the result is written to ``PostAct``,
    which has N // 2 columns (see GemmGatedMixin.epi_visit_subtile for the pairing).
    The kernel is compiled once per compile key and cached in
    ``gemm_gated.compile_cache``.

    Args:
        A, B, D, C, PostAct: Operand/output tensors (shapes in the inline comments).
        tile_count_semaphore: Optional (1,) tensor passed to the tile scheduler.
        activation: Must be a key of ``gate_fn_map``; ``None`` is rejected.
        tile_M, tile_N: CTA tile shape.
        cluster_M, cluster_N: Cluster shape (cluster K is always 1).
        pingpong, persistent: Scheduling options (only forwarded to the SM90 class).
        max_swizzle_size: Forwarded to the tile scheduler arguments.
        rowvec_bias, colvec_bias: Optional broadcast bias vectors.
        cu_seqlens_m: Enables varlen_m; requires ``persistent=True``.
        A_idx: Row-gather indices for A; requires varlen_m and ``cluster_N == 1``.

    Raises:
        AssertionError: On invalid arguments or shapes, including an odd N (each
            PostAct column needs a pair of GEMM output columns).
        TypeError: If the dtype/major combination is unsupported by the kernel.
    """
    varlen_m = cu_seqlens_m is not None
    if varlen_m:
        assert persistent, "varlen_m requires persistent=True"
        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
        if D is not None:
            assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
    gather_A = A_idx is not None
    if gather_A:
        assert varlen_m, "gather_A requires varlen (cu_seqlens_m must be specified)"
        assert cluster_N == 1, "gather_A requires cluster_N=1"
    assert activation in gate_fn_map, f"Unsupported activation {activation}"

    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A, B, D, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
    )
    # The gated epilogue consumes output values in pairs. With an odd N, the last
    # column has no partner, yet PostAct of width N // 2 would still pass the
    # shape check below, silently dropping that column.
    assert N % 2 == 0, f"gemm_gated requires an even N, got N={N}"

    # PostAct shape validation depends on varlen_m
    if varlen_m:
        # varlen_m case: PostAct is 2D (total_m, n//2)
        assert PostAct.dim() == 2 and PostAct.is_cuda, (
            "PostAct must be a 2D CUDA tensor for varlen_m"
        )
        assert PostAct.shape == (
            M,
            N // 2,
        ), f"PostAct must have shape {(M, N // 2)}, got {PostAct.shape}"
    else:
        # Normal case: PostAct is 3D (l, m, n//2)
        assert PostAct.dim() == 3 and PostAct.is_cuda, "PostAct must be a 3D CUDA tensor"
        assert PostAct.shape == (
            L,
            M,
            N // 2,
        ), f"PostAct must have shape {(L, M, N // 2)}, got {PostAct.shape}"

    tensor_infos["PostAct"] = GemmTensorInfo(PostAct)
    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=varlen_m)
    GemmWrapperBase.extract_dtypes(tensor_infos)
    major_configs = {
        "A": ("m", "k", "l"),
        "B": ("n", "k", "l"),
        "D": ("m", "n", "l"),
        "C": ("m", "n", "l"),
        "PostAct": ("m", "n", "l"),  # PostAct has shape (m, n//2, l) after permute
    }
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmGatedSm100 if device_capacity[0] > 9 else GemmGatedSm90

    acc_dtype = cutlass.Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    if not GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    ):
        raise TypeError("Skipping due to unsupported combination of types and majors")

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
    act_fn = gate_fn_map[activation]
    epi_args = GemmCls.EpilogueArguments(
        tensor_infos["PostAct"].cute_tensor,
        act_fn,
        mRowVecBroadcast=(
            from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1)
            if rowvec_bias is not None
            else None
        ),
        mColVecBroadcast=(
            # colvec_bias is (l, m) normally but 1D (total_m,) under varlen_m.
            from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
                leading_dim=0 if varlen_m else 1
            )
            if colvec_bias is not None
            else None
        ),
    )
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
    )

    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
    varlen_args = GemmWrapperBase.create_varlen_args(
        cu_seqlens_m,
        None,  # cu_seqlens_k
        A_idx,
        max_active_clusters,
        cluster_shape_mnk,
        tensor_infos,
        GemmCls.num_epi_tensormaps,
        pingpong,
    )

    current_stream = cutlass_torch.current_stream()
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        activation,
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        max_swizzle_size,
        rowvec_bias.dtype if rowvec_bias is not None else None,
        colvec_bias.dtype if colvec_bias is not None else None,
        varlen_m,
        gather_A,
        key_tensor_names=("A", "B", "D", "PostAct", "C"),
    )
    cache = gemm_gated.compile_cache
    if compile_key not in cache:
        # Only the SM90 class takes the pingpong / persistence constructor options.
        gemm_ctor = (
            partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
            if device_capacity[0] == 9
            else GemmCls
        )
        gemm_obj = gemm_ctor(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(
            gemm_obj,
            tensor_infos["A"].cute_tensor,
            tensor_infos["B"].cute_tensor,
            tensor_infos["D"].cute_tensor,
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
            varlen_args,
            current_stream,
        )
    cache[compile_key](
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )


# Compiled-kernel cache for gemm_gated, keyed by GemmWrapperBase.get_compile_key(...).
gemm_gated.compile_cache = {}
|