quack-kernels 0.2.2-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/gemm_config.py
CHANGED
@@ -1,6 +1,7 @@
 # Copyright (C) 2025, Tri Dao.
 import itertools
-from typing import Optional, List
+from typing import Optional, List, Literal
+from functools import partial
 from dataclasses import dataclass


@@ -13,57 +14,82 @@ class GemmConfig:
     cluster_n: int = 1
     swap_ab: bool = False
     # raster_order: int = 1
-
+    max_swizzle_size: int = 8


 def get_all_configs(
+    device_capacity: Literal[9, 10] = 9,
     epilogue: Optional[str] = None,
     tune_coop: bool = True,
     # tune_raster_order=True,
 ) -> List[GemmConfig]:
-
-
-
-    (
-
-
-
-
-
-
-
-
-
-
-    tile_mn_vals
-
-
-
-    if epilogue in ["lse"]:
+    assert device_capacity in [9, 10]
+    if device_capacity == 9:
+        tile_n_vals = [128, 144, 160, 176, 192, 208]
+        tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
+            (128, 224),
+            (128, 256),
+            # (192, 256),  # Getting IOT instruction (core dumped) in the bwd
+        ]
+        tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
+        if epilogue in ["gated"]:
+            tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
+            tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
+        elif epilogue in ["lse"]:
+            tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
+        tile_mn_vals = []
+        if tune_coop:
+            tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
+        tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
         cluster = [(1, 2), (2, 1)]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # cluster = [(1, 1), (1, 2), (2, 1)]
+        if epilogue in ["lse"]:
+            cluster = [(1, 2), (2, 1)]
+        swap_ab_vals = [False, True]
+        if epilogue in ["lse", "gated"]:
+            swap_ab_vals = [False]
+        # raster_swizzle = (
+        #     [(0, 1)]
+        #     if not tune_raster_order
+        #     else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
+        # )
+        return [
+            GemmConfig(
+                tile_m=tile_m,
+                tile_n=tile_n,
+                pingpong=pingpong,
+                cluster_m=cluster_m,
+                cluster_n=cluster_n,
+                swap_ab=swap_ab,
+                # raster_order=raster_order,
+                # max_swizzle_size=max_swizzle_size,
+            )
+            for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
+                tile_mn_vals,
+                cluster,
+                swap_ab_vals,
+                # raster_swizzle,
+            )
+        ]
+    elif device_capacity == 10:
+        tile_n_vals = [128, 160, 192, 224, 256]
+        tile_n_64_vals = [128, 192, 256]
+        tile_mn_cluster_vals = (
+            [(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
+            # + [(128, tile_n, (2, 1)) for tile_n in tile_n_64_vals]
+            + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
+            + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
         )
-
-
-
-
-
-
-
+        swap_ab_vals = [False, True]
+        if epilogue in ["lse", "gated"]:
+            swap_ab_vals = [False]
+        max_swizzle_size_vals = [4, 8, 16]
+        GemmConfigCls = partial(GemmConfig, pingpong=False)  # There's no pingpong on Sm100
+        return [
+            GemmConfigCls(
+                tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms
+            )
+            for (m, n, (cm, cn)), sab, ms in itertools.product(
+                tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals
+            )
+        ]
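A minimal sketch of how the new device_capacity parameter might be exercised (the function and field names are taken from the diff above; the call pattern is an illustration, not a call site from the package):

    import torch
    from quack.gemm_config import get_all_configs

    major, _ = torch.cuda.get_device_capability()  # e.g. (9, 0) on H100, (10, 0) on B200
    configs = get_all_configs(device_capacity=major, epilogue=None, tune_coop=True)
    # On SM90 this sweeps cooperative/pingpong tile shapes, cluster shapes, and swap_ab;
    # on SM100 pingpong is fixed to False and max_swizzle_size is swept over [4, 8, 16].
    print(len(configs), configs[0])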
quack/{gemm_dact_sm90.py → gemm_dact.py}
RENAMED
@@ -1,40 +1,57 @@
 # Copyright (c) 2025, Tri Dao.
-from typing import Optional
+from typing import Optional, Tuple
+from functools import partial

 from torch import Tensor

 import cutlass
 import cutlass.cute as cute
-from cutlass import const_expr
+from cutlass import Float32, const_expr
 import cutlass.torch as cutlass_torch

-from quack.
-from quack.
+from quack.gemm_sm90 import GemmSm90
+from quack.gemm_sm100 import GemmSm100
+from quack.gemm_default_epi import GemmDefaultEpiMixin
+from quack.gemm_act import GemmActMixin
+from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
 from quack.gemm_wrapper_utils import GemmWrapperBase
 import quack.activation


-class GemmDActSm90(GemmActSm90):
+class GemmDActMixin(GemmActMixin):
     # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
     # and return 2 arguments (dx, out)
-    EpilogueArguments =
-    EpilogueParams =
+    EpilogueArguments = GemmActMixin.EpilogueArguments
+    EpilogueParams = GemmActMixin.EpilogueParams

     @cute.jit
-    def
+    def epi_visit_subtile(
         self,
         params: EpilogueParams,
+        epi_loop_tensors: Tuple[cute.Tensor, ...],
         tRS_rD: cute.Tensor,
         tRS_rC: Optional[cute.Tensor] = None,
     ) -> Optional[cute.Tensor]:
         assert tRS_rC is not None
+        # We don't add C to the accumulator
+        GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
         tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
         tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
         # If we don't have .shape here, the compiler generates local stores and loads
         if const_expr(params.act_fn is not None):
             tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
-
-
+            if const_expr(self.arch < 100):
+                for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
+                    tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
+            else:
+                for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
+                    (
+                        (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
+                        (tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1]),
+                    ) = params.act_fn(
+                        (tRS_rC_acc[2 * i], tRS_rC_acc[2 * i + 1]),
+                        (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
+                    )
         else:
             tRS_rPostAct = tRS_rC_acc
         # Type conversion
@@ -43,6 +60,14 @@ class GemmDActSm90(GemmActSm90):
         return tRS_rPostAct_out


+class GemmDActSm90(GemmDActMixin, GemmSm90):
+    pass
+
+
+class GemmDActSm100(GemmDActMixin, GemmSm100):
+    pass
+
+
 dact_fn_map = {
     None: None,
     "relu": quack.activation.drelu,
@@ -51,7 +76,7 @@ dact_fn_map = {
 }


-def gemm_dact_sm90(
+def gemm_dact(
     A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
     B: Tensor,  # (l, n, k)
     Out: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
@@ -65,6 +90,7 @@ def gemm_dact_sm90(
     cluster_N: int,
     pingpong: bool = True,
     persistent: bool = True,
+    max_swizzle_size: int = 8,
     cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
     A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
 ) -> None:
@@ -100,10 +126,14 @@ def gemm_dact_sm90(
     }
     GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

-
+    device_capacity = get_device_capacity(A.device)
+    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
+    GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90
+
+    acc_dtype = Float32
     tile_shape_mn = (tile_M, tile_N)
     cluster_shape_mnk = (cluster_M, cluster_N, 1)
-    if not
+    if not GemmCls.is_valid_dtypes(
         tensor_infos["A"].dtype,
         tensor_infos["B"].dtype,
         acc_dtype,
@@ -116,9 +146,9 @@ def gemm_dact_sm90(
     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
     GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
     act_fn = dact_fn_map[activation]
-    epi_args =
+    epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
     scheduler_args = GemmWrapperBase.create_scheduler_args(
-        max_active_clusters, tile_count_semaphore
+        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
     )

     # Create varlen arguments if needed (assumes persistent=True when varlen_m)
@@ -129,7 +159,7 @@ def gemm_dact_sm90(
         max_active_clusters,
         cluster_shape_mnk,
         tensor_infos,
-
+        GemmCls.num_epi_tensormaps,
         pingpong,
     )

@@ -142,19 +172,21 @@ def gemm_dact_sm90(
         pingpong,
         persistent,
         tile_count_semaphore is not None,
+        device_capacity,
+        max_swizzle_size,
         cu_seqlens_m is not None,
         A_idx is not None,
         key_tensor_names=("A", "B", "D", "PostAct", "C"),
     )
-    cache =
+    cache = gemm_dact.compile_cache
     if compile_key not in cache:
-
+        if device_capacity[0] == 9:
+            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
+        gemm = GemmCls(
             acc_dtype,
             tensor_infos["A"].dtype,
             tile_shape_mn,
             cluster_shape_mnk,
-            pingpong=pingpong,
-            is_persistent=persistent,
             gather_A=gather_A,
         )
         cache[compile_key] = cute.compile(
@@ -180,4 +212,4 @@ def gemm_dact_sm90(
     )


-
+gemm_dact.compile_cache = {}
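The rename from gemm_dact_sm90 to gemm_dact reflects the dispatch pattern above: the epilogue logic lives in GemmDActMixin, and the architecture-specific base class is picked at runtime. A sketch of that selection, mirroring the lines in the diff (the tensor here is illustrative):

    import torch
    from quack.cute_dsl_utils import get_device_capacity
    from quack.gemm_dact import GemmDActSm90, GemmDActSm100

    A = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    device_capacity = get_device_capacity(A.device)  # (major, minor)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90
    # On SM90 the extra pingpong/is_persistent options are bound via functools.partial,
    # as gemm_dact does internally; SM100 has no pingpong schedule.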
quack/gemm_default_epi.py
ADDED
@@ -0,0 +1,259 @@
+# Copyright (c) 2025, Wentao Guo, Tri Dao.
+from typing import Optional, Tuple
+from functools import partial
+from dataclasses import dataclass
+
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Int32, Float32, Boolean, const_expr
+
+from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
+from quack.gemm_sm90 import GemmSm90
+from quack.gemm_sm100 import GemmSm100
+from quack.sm90_utils import partition_for_epilogue
+import quack.utils as utils
+import quack.copy_utils as copy_utils
+from quack.varlen_utils import VarlenManager
+
+
+class GemmDefaultEpiMixin:
+    num_epi_tensormaps: int = 0
+
+    @dataclass
+    class EpilogueArguments(ArgumentsBase):
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+        mRowVecBroadcast: Optional[cute.Tensor] = None
+        mColVecBroadcast: Optional[cute.Tensor] = None
+        add_to_output: bool = False
+
+    @dataclass
+    class EpilogueParams(ParamsBase):
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+        mRowVecBroadcast: Optional[cute.Tensor] = None
+        mColVecBroadcast: Optional[cute.Tensor] = None
+
+    def epi_to_underlying_arguments(
+        self, args: EpilogueArguments, *, loc=None, ip=None
+    ) -> EpilogueParams:
+        # Assume all strides are divisible by 32 bits except the last stride
+        new_stride = lambda t: tuple(
+            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
+            for s in t.stride
+        )
+        mRowVecBroadcast, mColVecBroadcast = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            if t is not None
+            else None
+            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
+        ]
+        return self.EpilogueParams(
+            alpha=args.alpha,
+            beta=args.beta,
+            mRowVecBroadcast=mRowVecBroadcast,
+            mColVecBroadcast=mColVecBroadcast,
+        )
+
+    @cute.jit
+    def epi_begin(
+        self,
+        params: EpilogueParams,
+        epi_smem_tensors: Tuple[cute.Tensor, ...],
+        epi_tile: cute.Tile,
+        tiled_copy_t2r: Optional[cute.TiledCopy],
+        tiled_copy_r2s: cute.TiledCopy,
+        tile_coord_mnkl: cute.Coord,
+        varlen_manager: VarlenManager,
+        epilogue_barrier: cutlass.pipeline.NamedBarrier,
+        tidx: Int32,
+    ):
+        alpha, beta = None, None
+        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
+            alpha = utils.load_scalar_or_pointer(params.alpha)
+        if const_expr(hasattr(params, "beta") and params.beta is not None):
+            beta = utils.load_scalar_or_pointer(params.beta)
+        sRowVec, sColVec, *rest = epi_smem_tensors
+        tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
+        batch_idx = tile_coord_mnkl[3]
+        num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
+        # Don't need sync as we assume the previous epilogue has finished
+
+        partition_for_epilogue_fn = partial(
+            partition_for_epilogue,
+            epi_tile=epi_tile,
+            tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
+            tidx=tidx,
+            reference_src=tiled_copy_t2r is None,
+        )
+
+        tDsRowVec = None
+        if const_expr(params.mRowVecBroadcast is not None):
+            rowvec_dtype = params.mRowVecBroadcast.element_type
+            num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width
+            thr_copy_RV = copy_utils.tiled_copy_1d(
+                params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
+            ).get_slice(tidx)
+            mRowVec = params.mRowVecBroadcast[batch_idx, None]
+            gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],))
+            tRVgRV = thr_copy_RV.partition_S(gRowVec)
+            tRVsRV = thr_copy_RV.partition_D(sRowVec)
+            tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
+            limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
+            tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean)
+            for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
+                tRVpRV[0, m] = tRVcRV[0, m] < limit_n
+            cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
+            # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
+            tDsRowVec = partition_for_epilogue_fn(
+                cute.make_tensor(
+                    sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1))
+                )
+            )
+            if const_expr(tiled_copy_t2r is not None):
+                tDsRowVec = tiled_copy_r2s.retile(tDsRowVec)
+
+        tDsColVec = None
+        if const_expr(params.mColVecBroadcast is not None):
+            colvec_dtype = params.mColVecBroadcast.element_type
+            num_copy_elems = const_expr(max(32, colvec_dtype.width)) // colvec_dtype.width
+            thr_copy_CV = copy_utils.tiled_copy_1d(
+                params.mColVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
+            ).get_slice(tidx)
+            if const_expr(not varlen_manager.varlen_m):
+                mColVec = params.mColVecBroadcast[batch_idx, None]
+            else:
+                mColVec = cute.domain_offset(
+                    (varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast
+                )
+            gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
+            tCVgCV = thr_copy_CV.partition_S(gColVec)
+            tCVsCV = thr_copy_CV.partition_D(sColVec)
+            tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
+            limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
+            tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean)
+            for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
+                tCVpCV[0, m] = tCVcCV[0, m] < limit_m
+            cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
+            tDsColVec = partition_for_epilogue_fn(
+                cute.make_tensor(
+                    sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0))
+                )
+            )
+            if const_expr(tiled_copy_t2r is not None):
+                tDsColVec = tiled_copy_r2s.retile(tDsColVec)
+
+        if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None):
+            cute.arch.cp_async_commit_group()
+            cute.arch.cp_async_wait_group(0)
+            epilogue_barrier.arrive_and_wait()
+        return alpha, beta, tDsRowVec, tDsColVec
+
+    def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
+        alpha, beta, tDsRowVec, tDsColVec = epi_tensors
+        tDrRowVec_cvt = None
+        if const_expr(tDsRowVec is not None):
+            tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[
+                None, None, None, epi_coord
+            ]
+            # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
+            tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
+            cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
+            tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
+            tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
+        tDrColVec_cvt = None
+        if const_expr(tDsColVec is not None):
+            tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[
+                None, None, None, epi_coord
+            ]
+            # This somehow doesn't work, some dim with stride 0 turns to non-zero stride
+            # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
+            tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type)
+            cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
+            tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
+            tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))
+        return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt
+
+    @cute.jit
+    def epi_visit_subtile(
+        self,
+        params: EpilogueParams,
+        epi_loop_tensors: Tuple[cute.Tensor, ...],
+        tRS_rD: cute.Tensor,
+        tRS_rC: Optional[cute.Tensor] = None,
+    ) -> Optional[cute.Tensor]:
+        alpha, beta, tDrRowVec, tDrColVec = epi_loop_tensors
+        rD = tRS_rD.load()
+        # Apply alpha scaling to accumulator if alpha is provided (not None)
+        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
+            alpha = utils.load_scalar_or_pointer(params.alpha)
+            rD *= alpha
+        # Apply C with beta scaling
+        if const_expr(tRS_rC is not None):
+            if const_expr(not hasattr(params, "beta") or params.beta is None):
+                # beta is None, default behavior: add C (beta=1.0)
+                rD += tRS_rC.load().to(tRS_rD.element_type)
+            else:
+                beta = utils.load_scalar_or_pointer(params.beta)
+                rD += beta * tRS_rC.load().to(tRS_rD.element_type)
+        tRS_rD.store(rD)
+        if const_expr(tDrRowVec is not None):
+            for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True):
+                tRS_rD[i] += tDrRowVec[i]
+        if const_expr(tDrColVec is not None):
+            for i in cutlass.range(cute.size(tDrColVec), unroll_full=True):
+                tRS_rD[i] += tDrColVec[i]
+        return None
+
+    @staticmethod
+    def epi_smem_bytes_per_stage(
+        args: Optional[EpilogueArguments],
+        cta_tile_shape_mnk: Tuple[int, int, int],
+        epi_tile: cute.Tile,
+    ) -> int:
+        row_vec_smem_size = 0 if args.mRowVecBroadcast is None else cta_tile_shape_mnk[1]
+        col_vec_smem_size = 0 if args.mColVecBroadcast is None else cta_tile_shape_mnk[0]
+        row_vec_dtype = (
+            args.mRowVecBroadcast.element_type if args.mRowVecBroadcast is not None else Float32
+        )
+        col_vec_dtype = (
+            args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32
+        )
+        return (
+            row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width
+        ) // 8
+
+    def epi_get_smem_struct(self, params: EpilogueParams):
+        row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
+        col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
+        row_vec_dtype = (
+            params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
+        )
+        col_vec_dtype = (
+            params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
+        )
+
+        @cute.struct
+        class EpiSharedStorage:
+            sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
+            sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
+
+        return EpiSharedStorage
+
+    def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
+        sRowVec = None
+        if const_expr(params.mRowVecBroadcast is not None):
+            sRowVec = storage.epi.sRowVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[1]))
+        sColVec = None
+        if const_expr(params.mColVecBroadcast is not None):
+            sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0]))
+        return (sRowVec, sColVec)
+
+
+class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
+    pass
+
+
+class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
+    pass