quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/sm100_utils.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
from typing import Type, Union
|
|
4
|
+
|
|
5
|
+
import cutlass.cute as cute
|
|
6
|
+
import cutlass.utils.blackwell_helpers as sm100_utils_og
|
|
7
|
+
from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
|
|
8
|
+
from cutlass.cutlass_dsl import Numeric, dsl_user_op
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dsl_user_op
def make_smem_layout_cpasync_a(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    a_dtype: Type[Numeric],
    num_stages: int,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """Build the staged SMEM layout for operand A when A is loaded via cp.async.

    :param tiled_mma: The tiled MMA used to partition tensor A
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The MMA tile shape
    :type mma_tiler_mnk: cute.Tile
    :param a_dtype: The element type for tensor A
    :type a_dtype: Type[Numeric]
    :param num_stages: The number of pipeline stages for tensor A
    :type num_stages: int

    :return: SMEM layout for tensor A
    :rtype: Union[cute.Layout, cute.ComposedLayout]
    """
    k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K
    # Slice out the (M, K) portion of the MMA tiler and partition it for A.
    partitioned = tiled_mma.partition_shape_A(
        cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip)
    )
    # Flatten the partitioned shape to a plain (MN, K) extent.
    mn_k = (
        cute.size(partitioned[0][0], loc=loc, ip=ip) * partitioned[1],
        cute.size(partitioned[0][1], loc=loc, ip=ip) * partitioned[2],
    )
    atom = sm100_utils_og.make_smem_layout_atom(
        sm100_utils_og.get_smem_layout_atom_ab(
            tiled_mma.op.a_major_mode,
            a_dtype,
            mn_k,
            loc=loc,
            ip=ip,
        ),
        a_dtype,
        loc=loc,
        ip=ip,
    )
    # Append the stage count and tile the atom over the full staged extent;
    # the iteration order follows A's major mode.
    staged_shape = cute.append(mn_k, num_stages, loc=loc, ip=ip)
    tiling_order = (0, 1, 2) if k_major else (1, 0, 2)
    return cute.tile_to_shape(
        atom,
        staged_shape,
        order=tiling_order,
        loc=loc,
        ip=ip,
    )
|
quack/sm90_utils.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
from typing import Type, Union, Optional
|
|
4
|
+
|
|
5
|
+
import cutlass
|
|
6
|
+
import cutlass.cute as cute
|
|
7
|
+
import cutlass.utils.hopper_helpers as sm90_utils_og
|
|
8
|
+
from cutlass.cute.nvgpu import warpgroup
|
|
9
|
+
from cutlass.cutlass_dsl import Numeric, dsl_user_op
|
|
10
|
+
from cutlass import Float32, Int32, Boolean, const_expr
|
|
11
|
+
from cutlass.utils import LayoutEnum
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dsl_user_op
def make_smem_layout(
    dtype: Type[Numeric],
    layout: LayoutEnum,
    tile: cute.Tile,
    stage: Optional[int] = None,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """Build an SMEM layout for ``tile``, optionally staged along a trailing mode.

    :param dtype: element type held in shared memory
    :param layout: the tensor's layout enum (decides the major mode)
    :param tile: the tile whose per-mode extents define the SMEM shape
    :param stage: pipeline stage count appended as an extra mode, or ``None``
        for an unstaged layout
    :return: the (possibly swizzled) SMEM layout
    """
    flat_shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip)
    # The swizzle atom is chosen from the extent along the contiguous dimension.
    contiguous_extent = flat_shape[1] if layout.is_n_major_c() else flat_shape[0]
    atom = warpgroup.make_smem_layout_atom(
        sm90_utils_og.get_smem_layout_atom(layout, dtype, contiguous_extent),
        dtype,
    )
    if const_expr(stage is not None):
        target_shape = cute.append(flat_shape, stage)
    else:
        target_shape = flat_shape
    tiling_order = (1, 0, 2) if layout.is_m_major_c() else (0, 1, 2)
    return cute.tile_to_shape(atom, target_shape, order=tiling_order)


# For compatibility with blackwell_helpers.py
make_smem_layout_epi = make_smem_layout
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dsl_user_op
def partition_for_epilogue(
    cT: cute.Tensor,
    epi_tile: cute.Tile,
    tiled_copy: cute.TiledCopy,
    tidx: Int32,
    reference_src: bool,  # do register tensors reference the src or dst layout of the tiled copy
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Tile ``cT`` by ``epi_tile`` and slice out this thread's epilogue partition.

    The result has modes (CPY, CPY_M, CPY_N, EPI_M, EPI_N); whether the source
    or destination side of ``tiled_copy`` is partitioned is picked at trace
    time via ``reference_src``.
    """
    copy_slice = tiled_copy.get_slice(tidx)
    tiled_cT = cute.flat_divide(cT, epi_tile)
    if const_expr(reference_src):
        return copy_slice.partition_S(tiled_cT, loc=loc, ip=ip)
    else:
        return copy_slice.partition_D(tiled_cT, loc=loc, ip=ip)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: cutlass.Constexpr[bool] = False,
    wg_wait: cutlass.Constexpr[int] = 0,
    # A_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """Issue warpgroup MMAs over every K-block of ``tCrA`` x ``tCrB`` into ``acc``.

    :param tiled_mma: tiled MMA whose op is instantiated per call
    :param acc: accumulator tensor; overwritten when ``zero_init`` else accumulated into
    :param tCrA: partitioned A fragments, K-blocks along mode 2
    :param tCrB: partitioned B fragments, K-blocks along mode 2
    :param zero_init: if True, the first K-block write does not accumulate
    :param wg_wait: group index passed to ``warpgroup.wait_group``; negative skips the wait
    :param swap_AB: if True, re-issues with A and B exchanged (computes the transposed product)
    """
    if const_expr(swap_AB):
        # Swap the operand roles once, then take the regular path below.
        gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False)
    else:
        # Order matters here: fence before issuing, commit after the K-loop,
        # then (optionally) wait. Do not reorder these calls.
        warpgroup.fence()
        # We make a new mma_atom since we'll be modifying its attribute (accumulate).
        # Otherwise the compiler complains "operand #0 does not dominate this use"
        mma_atom = cute.make_mma_atom(tiled_mma.op)
        # First K-block: accumulate only if the caller did not ask for a zero-init.
        mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init)
        for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
            cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            # After the first K-block, partial products must always accumulate.
            mma_atom.set(warpgroup.Field.ACCUMULATE, True)
        warpgroup.commit_group()
        if const_expr(wg_wait >= 0):
            warpgroup.wait_group(wg_wait)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def gemm_zero_init(
    tiled_mma: cute.TiledMma,
    shape: cute.Shape,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> cute.Tensor:
    """Allocate a fresh Float32 accumulator for ``shape`` and run a zero-initialized gemm.

    ``A_idx`` / ``B_idx``, when given, select one slice along the last mode of
    the corresponding operand before issuing the MMA. Returns the accumulator
    fragment.
    """
    if const_expr(swap_AB):
        # Reverse the C shape and exchange operand roles, then recurse once.
        return gemm_zero_init(
            tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
        )
    accum = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
    operand_a = tCrA[None, None, None, A_idx] if const_expr(A_idx is not None) else tCrA
    operand_b = tCrB[None, None, None, B_idx] if const_expr(B_idx is not None) else tCrB
    gemm(tiled_mma, accum, operand_a, operand_b, zero_init=True, wg_wait=wg_wait)
    return accum
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: Boolean,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> None:
    """Run ``gemm`` into ``acc``, optionally indexing one slice of each operand.

    ``A_idx`` / ``B_idx``, when given, pick a slice along the last mode of the
    corresponding fragment; ``zero_init`` and ``wg_wait`` are forwarded to
    :func:`gemm` unchanged.
    """
    if const_expr(swap_AB):
        # Exchange operand roles (and their indices), then recurse once.
        gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False)
        return
    operand_a = tCrA[None, None, None, A_idx] if const_expr(A_idx is not None) else tCrA
    operand_b = tCrB[None, None, None, B_idx] if const_expr(B_idx is not None) else tCrB
    gemm(tiled_mma, acc, operand_a, operand_b, zero_init=zero_init, wg_wait=wg_wait)
|