quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/gemm_dact.py
ADDED
@@ -0,0 +1,215 @@
+# Copyright (c) 2025, Tri Dao.
+from typing import Optional, Tuple
+from functools import partial
+
+from torch import Tensor
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32, const_expr
+import cutlass.torch as cutlass_torch
+
+from quack.gemm_sm90 import GemmSm90
+from quack.gemm_sm100 import GemmSm100
+from quack.gemm_default_epi import GemmDefaultEpiMixin
+from quack.gemm_act import GemmActMixin
+from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
+from quack.gemm_wrapper_utils import GemmWrapperBase
+import quack.activation
+
+
+class GemmDActMixin(GemmActMixin):
+    # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
+    # and return 2 arguments (dx, out)
+    EpilogueArguments = GemmActMixin.EpilogueArguments
+    EpilogueParams = GemmActMixin.EpilogueParams
+
+    @cute.jit
+    def epi_visit_subtile(
+        self,
+        params: EpilogueParams,
+        epi_loop_tensors: Tuple[cute.Tensor, ...],
+        tRS_rD: cute.Tensor,
+        tRS_rC: Optional[cute.Tensor] = None,
+    ) -> Optional[cute.Tensor]:
+        assert tRS_rC is not None
+        # We don't add C to the accumulator
+        GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
+        tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
+        tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
+        # If we don't have .shape here, the compiler generates local stores and loads
+        if const_expr(params.act_fn is not None):
+            tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
+            if const_expr(self.arch < 100):
+                for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
+                    tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
+            else:
+                for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
+                    (
+                        (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
+                        (tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1]),
+                    ) = params.act_fn(
+                        (tRS_rC_acc[2 * i], tRS_rC_acc[2 * i + 1]),
+                        (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
+                    )
+        else:
+            tRS_rPostAct = tRS_rC_acc
+        # Type conversion
+        tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
+        tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
+        return tRS_rPostAct_out
+
+
+class GemmDActSm90(GemmDActMixin, GemmSm90):
+    pass
+
+
+class GemmDActSm100(GemmDActMixin, GemmSm100):
+    pass
+
+
+dact_fn_map = {
+    None: None,
+    "relu": quack.activation.drelu,
+    "relu_sq": quack.activation.drelu_sq,
+    "gelu_tanh_approx": quack.activation.dgelu_tanh_approx,
+}
+
+
+def gemm_dact(
+    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
+    B: Tensor,  # (l, n, k)
+    Out: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+    PreAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+    PostAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+    tile_count_semaphore: Optional[Tensor],  # (1,)
+    activation: Optional[str],
+    tile_M: int,
+    tile_N: int,
+    cluster_M: int,
+    cluster_N: int,
+    pingpong: bool = True,
+    persistent: bool = True,
+    max_swizzle_size: int = 8,
+    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
+    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
+) -> None:
+    if cu_seqlens_m is not None:
+        assert persistent, "varlen_m requires persistent=True"
+        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
+        assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
+        assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
+        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
+    gather_A = A_idx is not None
+    if gather_A:
+        assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
+        assert cluster_N == 1, "gather_A requires cluster_N=1"
+    assert activation in dact_fn_map, f"Unsupported activation {activation}"
+
+    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+        A,
+        B,
+        Out,
+        PreAct,
+        additional_tensors={"PostAct": PostAct},
+        cu_seqlens_m=cu_seqlens_m,
+        A_idx=A_idx,
+    )
+    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
+    GemmWrapperBase.extract_dtypes(tensor_infos)
+    major_configs = {
+        "A": ("m", "k", "l"),
+        "B": ("n", "k", "l"),
+        "D": ("m", "n", "l"),
+        "C": ("m", "n", "l"),
+        "PostAct": ("m", "n", "l"),
+    }
+    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+    device_capacity = get_device_capacity(A.device)
+    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
+    GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90
+
+    acc_dtype = Float32
+    tile_shape_mn = (tile_M, tile_N)
+    cluster_shape_mnk = (cluster_M, cluster_N, 1)
+    if not GemmCls.is_valid_dtypes(
+        tensor_infos["A"].dtype,
+        tensor_infos["B"].dtype,
+        acc_dtype,
+        tensor_infos["D"].dtype,
+        tensor_infos["A"].major,
+        tensor_infos["B"].major,
+    ):
+        raise TypeError("Skipping due to unsupported combination of types and majors")
+
+    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
+    act_fn = dact_fn_map[activation]
+    epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
+    scheduler_args = GemmWrapperBase.create_scheduler_args(
+        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
+    )
+
+    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
+    varlen_args = GemmWrapperBase.create_varlen_args(
+        cu_seqlens_m,
+        None,  # cu_seqlens_k
+        A_idx,
+        max_active_clusters,
+        cluster_shape_mnk,
+        tensor_infos,
+        GemmCls.num_epi_tensormaps,
+        pingpong,
+    )
+
+    current_stream = cutlass_torch.current_stream()
+    compile_key = GemmWrapperBase.get_compile_key(
+        tensor_infos,
+        activation,
+        tile_shape_mn,
+        cluster_shape_mnk,
+        pingpong,
+        persistent,
+        tile_count_semaphore is not None,
+        device_capacity,
+        max_swizzle_size,
+        cu_seqlens_m is not None,
+        A_idx is not None,
+        key_tensor_names=("A", "B", "D", "PostAct", "C"),
+    )
+    cache = gemm_dact.compile_cache
+    if compile_key not in cache:
+        if device_capacity[0] == 9:
+            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
+        gemm = GemmCls(
+            acc_dtype,
+            tensor_infos["A"].dtype,
+            tile_shape_mn,
+            cluster_shape_mnk,
+            gather_A=gather_A,
+        )
+        cache[compile_key] = cute.compile(
+            gemm,
+            tensor_infos["A"].cute_tensor,
+            tensor_infos["B"].cute_tensor,
+            tensor_infos["D"].cute_tensor,
+            tensor_infos["C"].cute_tensor,
+            epi_args,
+            scheduler_args,
+            varlen_args,
+            current_stream,
+        )
+    cache[compile_key](
+        tensor_infos["A"].cute_tensor,
+        tensor_infos["B"].cute_tensor,
+        tensor_infos["D"].cute_tensor,
+        tensor_infos["C"].cute_tensor,
+        epi_args,
+        scheduler_args,
+        varlen_args,
+        current_stream,
+    )
+
+
+gemm_dact.compile_cache = {}
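
Read end to end, the wrapper multiplies A (l, m, k) with B (l, n, k) transposed into an (l, m, n) accumulator, and the epilogue applies the activation backward against the saved PreAct: Out receives the gradient and PostAct receives the recomputed activation. The sketch below is a minimal, untested driver for this entry point; the bf16 dtype, problem sizes, and tile/cluster values are illustrative assumptions, not tuned settings from the package.

import torch
from quack.gemm_dact import gemm_dact

l, m, n, k = 1, 4096, 4096, 4096
dev, dtype = "cuda", torch.bfloat16                       # assumed supported dtype
A = torch.randn(l, m, k, device=dev, dtype=dtype)         # incoming gradient, k-major
B = torch.randn(l, n, k, device=dev, dtype=dtype)         # second operand, k-major
PreAct = torch.randn(l, m, n, device=dev, dtype=dtype)    # saved pre-activation
Out = torch.empty(l, m, n, device=dev, dtype=dtype)       # dX = dact(PreAct, A @ B^T)
PostAct = torch.empty(l, m, n, device=dev, dtype=dtype)   # act(PreAct), recomputed

gemm_dact(
    A, B, Out, PreAct, PostAct,
    tile_count_semaphore=None,   # optional (1,) tensor for the persistent scheduler
    activation="relu",           # any key of dact_fn_map
    tile_M=128, tile_N=192,      # illustrative tile shape
    cluster_M=2, cluster_N=1,    # illustrative cluster shape
)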
quack/gemm_default_epi.py
ADDED
@@ -0,0 +1,259 @@
+# Copyright (c) 2025, Wentao Guo, Tri Dao.
+from typing import Optional, Tuple
+from functools import partial
+from dataclasses import dataclass
+
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Int32, Float32, Boolean, const_expr
+
+from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
+from quack.gemm_sm90 import GemmSm90
+from quack.gemm_sm100 import GemmSm100
+from quack.sm90_utils import partition_for_epilogue
+import quack.utils as utils
+import quack.copy_utils as copy_utils
+from quack.varlen_utils import VarlenManager
+
+
+class GemmDefaultEpiMixin:
+    num_epi_tensormaps: int = 0
+
+    @dataclass
+    class EpilogueArguments(ArgumentsBase):
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+        mRowVecBroadcast: Optional[cute.Tensor] = None
+        mColVecBroadcast: Optional[cute.Tensor] = None
+        add_to_output: bool = False
+
+    @dataclass
+    class EpilogueParams(ParamsBase):
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+        mRowVecBroadcast: Optional[cute.Tensor] = None
+        mColVecBroadcast: Optional[cute.Tensor] = None
+
+    def epi_to_underlying_arguments(
+        self, args: EpilogueArguments, *, loc=None, ip=None
+    ) -> EpilogueParams:
+        # Assume all strides are divisible by 32 bits except the last stride
+        new_stride = lambda t: tuple(
+            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
+            for s in t.stride
+        )
+        mRowVecBroadcast, mColVecBroadcast = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            if t is not None
+            else None
+            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
+        ]
+        return self.EpilogueParams(
+            alpha=args.alpha,
+            beta=args.beta,
+            mRowVecBroadcast=mRowVecBroadcast,
+            mColVecBroadcast=mColVecBroadcast,
+        )
+
+    @cute.jit
+    def epi_begin(
+        self,
+        params: EpilogueParams,
+        epi_smem_tensors: Tuple[cute.Tensor, ...],
+        epi_tile: cute.Tile,
+        tiled_copy_t2r: Optional[cute.TiledCopy],
+        tiled_copy_r2s: cute.TiledCopy,
+        tile_coord_mnkl: cute.Coord,
+        varlen_manager: VarlenManager,
+        epilogue_barrier: cutlass.pipeline.NamedBarrier,
+        tidx: Int32,
+    ):
+        alpha, beta = None, None
+        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
+            alpha = utils.load_scalar_or_pointer(params.alpha)
+        if const_expr(hasattr(params, "beta") and params.beta is not None):
+            beta = utils.load_scalar_or_pointer(params.beta)
+        sRowVec, sColVec, *rest = epi_smem_tensors
+        tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
+        batch_idx = tile_coord_mnkl[3]
+        num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
+        # Don't need sync as we assume the previous epilogue has finished
+
+        partition_for_epilogue_fn = partial(
+            partition_for_epilogue,
+            epi_tile=epi_tile,
+            tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
+            tidx=tidx,
+            reference_src=tiled_copy_t2r is None,
+        )
+
+        tDsRowVec = None
+        if const_expr(params.mRowVecBroadcast is not None):
+            rowvec_dtype = params.mRowVecBroadcast.element_type
+            num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width
+            thr_copy_RV = copy_utils.tiled_copy_1d(
+                params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
+            ).get_slice(tidx)
+            mRowVec = params.mRowVecBroadcast[batch_idx, None]
+            gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],))
+            tRVgRV = thr_copy_RV.partition_S(gRowVec)
+            tRVsRV = thr_copy_RV.partition_D(sRowVec)
+            tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
+            limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
+            tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean)
+            for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
+                tRVpRV[0, m] = tRVcRV[0, m] < limit_n
+            cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
+            # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
+            tDsRowVec = partition_for_epilogue_fn(
+                cute.make_tensor(
+                    sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1))
+                )
+            )
+            if const_expr(tiled_copy_t2r is not None):
+                tDsRowVec = tiled_copy_r2s.retile(tDsRowVec)
+
+        tDsColVec = None
+        if const_expr(params.mColVecBroadcast is not None):
+            colvec_dtype = params.mColVecBroadcast.element_type
+            num_copy_elems = const_expr(max(32, colvec_dtype.width)) // colvec_dtype.width
+            thr_copy_CV = copy_utils.tiled_copy_1d(
+                params.mColVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
+            ).get_slice(tidx)
+            if const_expr(not varlen_manager.varlen_m):
+                mColVec = params.mColVecBroadcast[batch_idx, None]
+            else:
+                mColVec = cute.domain_offset(
+                    (varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast
+                )
+            gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
+            tCVgCV = thr_copy_CV.partition_S(gColVec)
+            tCVsCV = thr_copy_CV.partition_D(sColVec)
+            tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
+            limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
+            tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean)
+            for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
+                tCVpCV[0, m] = tCVcCV[0, m] < limit_m
+            cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
+            tDsColVec = partition_for_epilogue_fn(
+                cute.make_tensor(
+                    sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0))
+                )
+            )
+            if const_expr(tiled_copy_t2r is not None):
+                tDsColVec = tiled_copy_r2s.retile(tDsColVec)
+
+        if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None):
+            cute.arch.cp_async_commit_group()
+            cute.arch.cp_async_wait_group(0)
+            epilogue_barrier.arrive_and_wait()
+        return alpha, beta, tDsRowVec, tDsColVec
+
+    def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
+        alpha, beta, tDsRowVec, tDsColVec = epi_tensors
+        tDrRowVec_cvt = None
+        if const_expr(tDsRowVec is not None):
+            tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[
+                None, None, None, epi_coord
+            ]
+            # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
+            tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
+            cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
+            tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
+            tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
+        tDrColVec_cvt = None
+        if const_expr(tDsColVec is not None):
+            tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[
+                None, None, None, epi_coord
+            ]
+            # This somehow doesn't work, some dim with stride 0 turns to non-zero stride
+            # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
+            tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type)
+            cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
+            tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
+            tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))
+        return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt
+
+    @cute.jit
+    def epi_visit_subtile(
+        self,
+        params: EpilogueParams,
+        epi_loop_tensors: Tuple[cute.Tensor, ...],
+        tRS_rD: cute.Tensor,
+        tRS_rC: Optional[cute.Tensor] = None,
+    ) -> Optional[cute.Tensor]:
+        alpha, beta, tDrRowVec, tDrColVec = epi_loop_tensors
+        rD = tRS_rD.load()
+        # Apply alpha scaling to accumulator if alpha is provided (not None)
+        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
+            alpha = utils.load_scalar_or_pointer(params.alpha)
+            rD *= alpha
+        # Apply C with beta scaling
+        if const_expr(tRS_rC is not None):
+            if const_expr(not hasattr(params, "beta") or params.beta is None):
+                # beta is None, default behavior: add C (beta=1.0)
+                rD += tRS_rC.load().to(tRS_rD.element_type)
+            else:
+                beta = utils.load_scalar_or_pointer(params.beta)
+                rD += beta * tRS_rC.load().to(tRS_rD.element_type)
+        tRS_rD.store(rD)
+        if const_expr(tDrRowVec is not None):
+            for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True):
+                tRS_rD[i] += tDrRowVec[i]
+        if const_expr(tDrColVec is not None):
+            for i in cutlass.range(cute.size(tDrColVec), unroll_full=True):
+                tRS_rD[i] += tDrColVec[i]
+        return None
+
+    @staticmethod
+    def epi_smem_bytes_per_stage(
+        args: Optional[EpilogueArguments],
+        cta_tile_shape_mnk: Tuple[int, int, int],
+        epi_tile: cute.Tile,
+    ) -> int:
+        row_vec_smem_size = 0 if args.mRowVecBroadcast is None else cta_tile_shape_mnk[1]
+        col_vec_smem_size = 0 if args.mColVecBroadcast is None else cta_tile_shape_mnk[0]
+        row_vec_dtype = (
+            args.mRowVecBroadcast.element_type if args.mRowVecBroadcast is not None else Float32
+        )
+        col_vec_dtype = (
+            args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32
+        )
+        return (
+            row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width
+        ) // 8
+
+    def epi_get_smem_struct(self, params: EpilogueParams):
+        row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
+        col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
+        row_vec_dtype = (
+            params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
+        )
+        col_vec_dtype = (
+            params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
+        )
+
+        @cute.struct
+        class EpiSharedStorage:
+            sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
+            sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
+
+        return EpiSharedStorage
+
+    def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
+        sRowVec = None
+        if const_expr(params.mRowVecBroadcast is not None):
+            sRowVec = storage.epi.sRowVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[1]))
+        sColVec = None
+        if const_expr(params.mColVecBroadcast is not None):
+            sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0]))
+        return (sRowVec, sColVec)
+
+
+class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
+    pass
+
+
+class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
+    pass