quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
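The rename of the dense GEMM modules (quack/dense_gemm_sm90.py → gemm_sm90.py, quack/dense_gemm_sm100.py → gemm_sm100.py) is the most visible change in this release. A minimal, hypothetical import-migration sketch is shown below; the old class name is taken from the hunk context ("class PersistentDenseGemmKernel:") in the diff that follows, and everything else is illustrative rather than a documented quack API guarantee.

# Hypothetical migration sketch for callers of the renamed SM100 module.
# Names come from the 0.2.1/0.2.3 file lists and the diff below; not an official API note.
# quack-kernels 0.2.1:
#   from quack.dense_gemm_sm100 import PersistentDenseGemmKernel
# quack-kernels 0.2.3:
from quack.gemm_sm100 import GemmSm100  # now derived from quack.gemm_sm90.GemmSm90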
|
@@ -1,33 +1,8 @@
|
|
|
1
|
-
#
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
# Redistribution and use in source and binary forms, with or without
|
|
5
|
-
# modification, are permitted provided that the following conditions are met:
|
|
6
|
-
|
|
7
|
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
8
|
-
# list of conditions and the following disclaimer.
|
|
9
|
-
|
|
10
|
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
11
|
-
# this list of conditions and the following disclaimer in the documentation
|
|
12
|
-
# and/or other materials provided with the distribution.
|
|
13
|
-
|
|
14
|
-
# 3. Neither the name of the copyright holder nor the names of its
|
|
15
|
-
# contributors may be used to endorse or promote products derived from
|
|
16
|
-
# this software without specific prior written permission.
|
|
17
|
-
|
|
18
|
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
19
|
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
20
|
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
21
|
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
22
|
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
23
|
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
24
|
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
25
|
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
26
|
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
27
|
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
1
|
+
# Based on the cute-dsl example:
|
|
2
|
+
# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py
|
|
28
3
|
|
|
29
4
|
import argparse
|
|
30
|
-
from typing import Optional, Type, Tuple, Union, Callable
|
|
5
|
+
from typing import Optional, Type, Tuple, Union, Callable, Literal
|
|
31
6
|
from functools import partial
|
|
32
7
|
|
|
33
8
|
import cuda.bindings.driver as cuda
|
|
@@ -40,15 +15,25 @@ import cutlass.torch as cutlass_torch
|
|
|
40
15
|
import cutlass.pipeline as pipeline
|
|
41
16
|
import cutlass.utils.blackwell_helpers as sm100_utils
|
|
42
17
|
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
|
18
|
+
from cutlass.cute.nvgpu.warp import (
|
|
19
|
+
LdMatrix8x8x16bOp,
|
|
20
|
+
LdMatrix16x16x8bOp,
|
|
21
|
+
StMatrix8x8x16bOp,
|
|
22
|
+
StMatrix16x8x8bOp,
|
|
23
|
+
)
|
|
24
|
+
from cutlass import Int32, Float32, Boolean, const_expr
|
|
25
|
+
from cutlass.utils import LayoutEnum
|
|
43
26
|
from cutlass.cute.runtime import from_dlpack, make_ptr
|
|
44
|
-
from cutlass import Int32, const_expr
|
|
45
27
|
|
|
46
|
-
from quack.
|
|
47
|
-
from quack.
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
28
|
+
from quack.pipeline import PipelineTmaCpAsyncUmma
|
|
29
|
+
from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
|
|
30
|
+
from quack.tile_scheduler import TileSchedulerOptions
|
|
31
|
+
from quack.varlen_utils import VarlenArguments, VarlenManager
|
|
32
|
+
from quack.gemm_sm90 import GemmSm90, NamedBarrierGemm
|
|
33
|
+
import quack.copy_utils as copy_utils
|
|
34
|
+
import quack.sm100_utils as quack_sm100_utils
|
|
35
|
+
|
|
36
|
+
# return PipelineStateWAdvance instead of PipelineState
|
|
52
37
|
|
|
53
38
|
"""
|
|
54
39
|
A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
|
|
@@ -72,8 +57,6 @@ This GEMM works as follows:
|
|
|
72
57
|
- Type convert C matrix to output type.
|
|
73
58
|
- Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
|
|
74
59
|
or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
|
|
75
|
-
- Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
|
|
76
|
-
e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
|
|
77
60
|
|
|
78
61
|
SM100 tcgen05.mma instructions operate as follows:
|
|
79
62
|
- Read matrix A from SMEM
|
|
@@ -105,7 +88,7 @@ To collect performance with NCU profiler:
|
|
|
105
88
|
|
|
106
89
|
Constraints are same as dense_gemm.py:
|
|
107
90
|
* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2),
|
|
108
|
-
see detailed valid dtype combinations in below
|
|
91
|
+
see detailed valid dtype combinations in below GemmSm100 class documentation
|
|
109
92
|
* A/B tensor must have the same data type
|
|
110
93
|
* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
|
|
111
94
|
* Mma tiler N must be 32-256, step 32
|
|
@@ -118,14 +101,12 @@ Constraints are same as dense_gemm.py:
|
|
|
118
101
|
"""
|
|
119
102
|
|
|
120
103
|
|
|
121
|
-
class
|
|
104
|
+
class GemmSm100(GemmSm90):
|
|
122
105
|
"""This class implements batched matrix multiplication (C = A x B) with support for various data types
|
|
123
106
|
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
|
|
124
107
|
|
|
125
108
|
:param acc_dtype: Data type for accumulation during computation
|
|
126
109
|
:type acc_dtype: type[cutlass.Numeric]
|
|
127
|
-
:param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
|
|
128
|
-
:type use_2cta_instrs: bool
|
|
129
110
|
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
|
|
130
111
|
:type mma_tiler_mn: Tuple[int, int]
|
|
131
112
|
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
|
|
@@ -159,22 +140,28 @@ class PersistentDenseGemmKernel:
|
|
|
159
140
|
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
|
160
141
|
|
|
161
142
|
Example:
|
|
162
|
-
>>> gemm =
|
|
163
|
-
... acc_dtype=
|
|
164
|
-
... use_2cta_instrs=True,
|
|
143
|
+
>>> gemm = GemmSm100(
|
|
144
|
+
... acc_dtype=Float32,
|
|
165
145
|
... mma_tiler_mn=(128, 128),
|
|
166
146
|
... cluster_shape_mn=(2, 2)
|
|
167
147
|
... )
|
|
168
148
|
>>> gemm(mA, mB, mD, max_active_clusters, stream)
|
|
169
149
|
"""
|
|
170
150
|
|
|
151
|
+
arch = 100
|
|
152
|
+
num_epi_tensormaps = GemmSm90.num_epi_tensormaps
|
|
153
|
+
|
|
154
|
+
EpilogueArguments = GemmSm90.EpilogueArguments
|
|
155
|
+
EpilogueParams = GemmSm90.EpilogueParams
|
|
156
|
+
|
|
171
157
|
def __init__(
|
|
172
158
|
self,
|
|
173
159
|
acc_dtype: Type[cutlass.Numeric],
|
|
174
|
-
|
|
160
|
+
a_dtype: Type[cutlass.Numeric], # ignored for now
|
|
175
161
|
mma_tiler_mn: Tuple[int, int],
|
|
176
|
-
|
|
162
|
+
cluster_shape_mnk: Tuple[int, int, int],
|
|
177
163
|
sf_vec_size: Optional[int] = None,
|
|
164
|
+
gather_A: bool = False,
|
|
178
165
|
):
|
|
179
166
|
"""Initializes the configuration for a Blackwell dense GEMM kernel.
|
|
180
167
|
|
|
@@ -187,50 +174,54 @@ class PersistentDenseGemmKernel:
|
|
|
187
174
|
with cta_group=2 should be used.
|
|
188
175
|
|
|
189
176
|
2. Cluster Shape:
|
|
190
|
-
-
|
|
177
|
+
- cluster_shape_mnk: The (ClusterM, ClusterN) shape of the CTA cluster.
|
|
191
178
|
|
|
192
179
|
:param acc_dtype: Data type of the accumulator.
|
|
193
180
|
:type acc_dtype: type[cutlass.Numeric]
|
|
194
181
|
:param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
|
|
195
182
|
:type mma_tiler_mn: Tuple[int, int]
|
|
196
|
-
:param
|
|
197
|
-
:type
|
|
198
|
-
:param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
|
|
199
|
-
:type cluster_shape_mn: Tuple[int, int]
|
|
183
|
+
:param cluster_shape_mnk: Tuple (ClusterM, ClusterN) shape of the cluster.
|
|
184
|
+
:type cluster_shape_mnk: Tuple[int, int]
|
|
200
185
|
"""
|
|
201
186
|
|
|
202
187
|
self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
|
|
203
|
-
self.use_2cta_instrs =
|
|
204
|
-
self.
|
|
188
|
+
self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (256,)
|
|
189
|
+
self.cluster_shape_mnk = cluster_shape_mnk
|
|
190
|
+
assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1"
|
|
205
191
|
# K dimension is deferred in _setup_attributes
|
|
206
192
|
self.mma_tiler = (*mma_tiler_mn, 1)
|
|
207
193
|
self.sf_vec_size = sf_vec_size
|
|
208
194
|
self.blockscaled = sf_vec_size is not None
|
|
195
|
+
self.is_persistent = True
|
|
196
|
+
self.pingpong = False # for compatibility with GemmSm90
|
|
197
|
+
self.gather_A = gather_A
|
|
198
|
+
if gather_A:
|
|
199
|
+
assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
|
|
209
200
|
|
|
210
|
-
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
|
201
|
+
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
|
|
211
202
|
|
|
203
|
+
self.num_ab_load_warps = 1 if not self.gather_A else 5
|
|
212
204
|
self.occupancy = 1
|
|
213
205
|
# Set specialized warp ids
|
|
214
|
-
self.epilog_warp_id = (
|
|
215
|
-
0,
|
|
216
|
-
1,
|
|
217
|
-
2,
|
|
218
|
-
3,
|
|
219
|
-
)
|
|
206
|
+
self.epilog_warp_id = (0, 1, 2, 3)
|
|
220
207
|
self.mma_warp_id = 4
|
|
221
|
-
self.
|
|
222
|
-
self.
|
|
223
|
-
self.
|
|
224
|
-
|
|
208
|
+
self.ab_load_warp_id = 5
|
|
209
|
+
self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
|
|
210
|
+
self.scheduler_warp_id = self.epi_load_warp_id + 1
|
|
211
|
+
self.num_epi_warps = len(self.epilog_warp_id)
|
|
212
|
+
self.threads_per_cta = cute.arch.WARP_SIZE * (
|
|
213
|
+
self.num_ab_load_warps
|
|
214
|
+
+ len(
|
|
215
|
+
(
|
|
216
|
+
self.mma_warp_id,
|
|
217
|
+
self.epi_load_warp_id,
|
|
218
|
+
self.scheduler_warp_id,
|
|
219
|
+
*self.epilog_warp_id,
|
|
220
|
+
)
|
|
221
|
+
)
|
|
225
222
|
)
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
self.epilog_sync_bar_id = 1
|
|
229
|
-
self.tmem_ptr_sync_bar_id = 2
|
|
230
|
-
self.epilog_load_bar_id = 3
|
|
231
|
-
self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_100")
|
|
232
|
-
|
|
233
|
-
def _setup_attributes(self):
|
|
223
|
+
|
|
224
|
+
def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments):
|
|
234
225
|
"""Set up configurations that are dependent on GEMM inputs
|
|
235
226
|
|
|
236
227
|
This method configures various attributes based on the input tensor properties
|
|
@@ -261,7 +252,7 @@ class PersistentDenseGemmKernel:
|
|
|
261
252
|
|
|
262
253
|
# Configure tiled mma
|
|
263
254
|
if const_expr(not self.blockscaled):
|
|
264
|
-
tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
|
255
|
+
self.tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
|
265
256
|
self.a_dtype,
|
|
266
257
|
self.a_major_mode,
|
|
267
258
|
self.b_major_mode,
|
|
@@ -269,9 +260,9 @@ class PersistentDenseGemmKernel:
|
|
|
269
260
|
self.cta_group,
|
|
270
261
|
self.mma_tiler[:2],
|
|
271
262
|
)
|
|
272
|
-
tiled_mma_sfb = None
|
|
263
|
+
self.tiled_mma_sfb = None
|
|
273
264
|
else:
|
|
274
|
-
tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
|
265
|
+
self.tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
|
275
266
|
self.a_dtype,
|
|
276
267
|
self.a_major_mode,
|
|
277
268
|
self.b_major_mode,
|
|
@@ -280,13 +271,13 @@ class PersistentDenseGemmKernel:
|
|
|
280
271
|
self.cta_group,
|
|
281
272
|
self.mma_inst_shape_mnk[:2],
|
|
282
273
|
)
|
|
283
|
-
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
|
274
|
+
self.tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
|
284
275
|
self.a_dtype,
|
|
285
276
|
self.a_major_mode,
|
|
286
277
|
self.b_major_mode,
|
|
287
278
|
self.sf_dtype,
|
|
288
279
|
self.sf_vec_size,
|
|
289
|
-
|
|
280
|
+
tcgen05.CtaGroup.ONE,
|
|
290
281
|
self.mma_inst_shape_mnk_sfb[:2],
|
|
291
282
|
)
|
|
292
283
|
|
|
@@ -306,26 +297,28 @@ class PersistentDenseGemmKernel:
|
|
|
306
297
|
else:
|
|
307
298
|
self.mma_tiler_sfb = None
|
|
308
299
|
self.cta_tile_shape_mnk = (
|
|
309
|
-
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
|
300
|
+
self.mma_tiler[0] // cute.size(self.tiled_mma.thr_id.shape),
|
|
310
301
|
self.mma_tiler[1],
|
|
311
302
|
self.mma_tiler[2],
|
|
312
303
|
)
|
|
313
304
|
|
|
314
305
|
# Compute cluster layout
|
|
315
306
|
self.cluster_layout_vmnk = cute.tiled_divide(
|
|
316
|
-
cute.make_layout(
|
|
317
|
-
(tiled_mma.thr_id.shape,),
|
|
307
|
+
cute.make_layout(self.cluster_shape_mnk),
|
|
308
|
+
(self.tiled_mma.thr_id.shape,),
|
|
318
309
|
)
|
|
319
310
|
if const_expr(self.blockscaled):
|
|
320
311
|
self.cluster_layout_sfb_vmnk = cute.tiled_divide(
|
|
321
|
-
cute.make_layout(
|
|
322
|
-
(tiled_mma_sfb.thr_id.shape,),
|
|
312
|
+
cute.make_layout(self.cluster_shape_mnk),
|
|
313
|
+
(self.tiled_mma_sfb.thr_id.shape,),
|
|
323
314
|
)
|
|
324
315
|
else:
|
|
325
316
|
self.cluster_layout_sfb_vmnk = None
|
|
326
317
|
|
|
327
318
|
# Compute number of multicast CTAs for A/B
|
|
328
319
|
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
|
|
320
|
+
if self.gather_A:
|
|
321
|
+
assert self.num_mcast_ctas_a == 1
|
|
329
322
|
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
|
|
330
323
|
self.is_a_mcast = self.num_mcast_ctas_a > 1
|
|
331
324
|
self.is_b_mcast = self.num_mcast_ctas_b > 1
|
|
@@ -337,60 +330,82 @@ class PersistentDenseGemmKernel:
|
|
|
337
330
|
self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
|
|
338
331
|
self.cta_tile_shape_mnk,
|
|
339
332
|
self.use_2cta_instrs,
|
|
340
|
-
self.d_layout,
|
|
341
|
-
self.d_dtype,
|
|
333
|
+
self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR,
|
|
334
|
+
self.d_dtype if self.d_dtype is not None else cutlass.BFloat16,
|
|
335
|
+
layout_c=self.c_layout,
|
|
336
|
+
elem_ty_c=self.c_dtype,
|
|
342
337
|
)
|
|
343
338
|
|
|
344
339
|
# Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
|
|
340
|
+
prefetch_A_idx = (
|
|
341
|
+
None
|
|
342
|
+
if not self.gather_A
|
|
343
|
+
else ("varlen_m" if varlen_args.mCuSeqlensM is not None else "varlen_k")
|
|
344
|
+
)
|
|
345
345
|
(
|
|
346
346
|
self.num_acc_stage,
|
|
347
|
-
self.
|
|
348
|
-
self.
|
|
349
|
-
self.
|
|
347
|
+
self.ab_stage,
|
|
348
|
+
self.epi_stage,
|
|
349
|
+
self.epi_c_stage,
|
|
350
350
|
) = self._compute_stages(
|
|
351
|
-
tiled_mma,
|
|
351
|
+
self.tiled_mma,
|
|
352
352
|
self.mma_tiler,
|
|
353
|
+
self.cta_tile_shape_mnk,
|
|
354
|
+
self.epi_tile,
|
|
353
355
|
self.a_dtype,
|
|
354
356
|
self.b_dtype,
|
|
355
|
-
self.
|
|
357
|
+
self.sf_dtype,
|
|
358
|
+
self.sf_vec_size,
|
|
356
359
|
self.d_dtype,
|
|
357
360
|
self.c_dtype,
|
|
358
361
|
self.d_layout,
|
|
359
362
|
self.c_layout,
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
self.smem_capacity
|
|
363
|
+
epilogue_args,
|
|
364
|
+
prefetch_A_idx,
|
|
365
|
+
cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
|
|
363
366
|
self.occupancy,
|
|
364
367
|
)
|
|
368
|
+
self.sched_stage = 1
|
|
369
|
+
self.a_prefetch_stage = (
|
|
370
|
+
0
|
|
371
|
+
if not self.gather_A
|
|
372
|
+
else (2 if varlen_args.mCuSeqlensM is not None else self.ab_stage)
|
|
373
|
+
)
|
|
365
374
|
|
|
366
375
|
# Compute A/B/SFA/SFB/C shared memory layout
|
|
367
376
|
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
|
|
368
|
-
tiled_mma, self.mma_tiler, self.a_dtype, self.
|
|
377
|
+
self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
|
|
369
378
|
)
|
|
379
|
+
self.a_smem_load_layout_staged = self.a_smem_layout_staged
|
|
380
|
+
if const_expr(self.gather_A):
|
|
381
|
+
self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_cpasync_a(
|
|
382
|
+
self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
|
|
383
|
+
)
|
|
370
384
|
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
|
|
371
|
-
tiled_mma, self.mma_tiler, self.b_dtype, self.
|
|
372
|
-
)
|
|
373
|
-
self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
|
|
374
|
-
self.d_dtype, self.d_layout, self.epi_tile, self.num_d_stage
|
|
385
|
+
self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage
|
|
375
386
|
)
|
|
387
|
+
self.epi_smem_layout_staged = None
|
|
388
|
+
if const_expr(self.d_dtype is not None):
|
|
389
|
+
self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi(
|
|
390
|
+
self.d_dtype, self.d_layout, self.epi_tile, self.epi_stage
|
|
391
|
+
)
|
|
392
|
+
self.epi_c_smem_layout_staged = None
|
|
376
393
|
if const_expr(self.c_dtype is not None):
|
|
377
394
|
self.epi_c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
|
|
378
|
-
self.c_dtype, self.c_layout, self.epi_tile, self.
|
|
395
|
+
self.c_dtype, self.c_layout, self.epi_tile, self.epi_c_stage
|
|
379
396
|
)
|
|
380
|
-
else:
|
|
381
|
-
self.epi_c_smem_layout_staged = None
|
|
382
397
|
if const_expr(self.blockscaled):
|
|
383
398
|
self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
|
|
384
|
-
tiled_mma,
|
|
399
|
+
self.tiled_mma,
|
|
385
400
|
self.mma_tiler,
|
|
386
401
|
self.sf_vec_size,
|
|
387
|
-
self.
|
|
402
|
+
self.ab_stage,
|
|
388
403
|
)
|
|
389
404
|
self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
|
|
390
|
-
tiled_mma,
|
|
405
|
+
self.tiled_mma,
|
|
391
406
|
self.mma_tiler,
|
|
392
407
|
self.sf_vec_size,
|
|
393
|
-
self.
|
|
408
|
+
self.ab_stage,
|
|
394
409
|
)
|
|
395
410
|
else:
|
|
396
411
|
self.sfa_smem_layout_staged, self.sfb_smem_layout_staged = None, None
|
|
@@ -398,7 +413,7 @@ class PersistentDenseGemmKernel:
|
|
|
398
413
|
# Compute the number of tensor memory allocation columns
|
|
399
414
|
if const_expr(not self.blockscaled):
|
|
400
415
|
self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(
|
|
401
|
-
tiled_mma, self.mma_tiler, self.num_acc_stage
|
|
416
|
+
self.tiled_mma, self.mma_tiler, self.num_acc_stage
|
|
402
417
|
)
|
|
403
418
|
else:
|
|
404
419
|
SM100_TMEM_CAPACITY_COLUMNS = 512
|
|
@@ -409,14 +424,14 @@ class PersistentDenseGemmKernel:
|
|
|
409
424
|
self,
|
|
410
425
|
mA: cute.Tensor,
|
|
411
426
|
mB: cute.Tensor,
|
|
412
|
-
mD: cute.Tensor,
|
|
427
|
+
mD: Optional[cute.Tensor],
|
|
413
428
|
mC: Optional[cute.Tensor],
|
|
414
|
-
|
|
415
|
-
|
|
429
|
+
epilogue_args: ArgumentsBase,
|
|
430
|
+
scheduler_args: TileSchedulerOptions,
|
|
431
|
+
varlen_args: Optional[VarlenArguments],
|
|
416
432
|
stream: cuda.CUstream,
|
|
417
433
|
mSFA: Optional[cute.Tensor] = None,
|
|
418
434
|
mSFB: Optional[cute.Tensor] = None,
|
|
419
|
-
epilogue_op: cutlass.Constexpr = lambda x: x,
|
|
420
435
|
):
|
|
421
436
|
"""Execute the GEMM operation in steps:
|
|
422
437
|
- Setup static attributes before smem/grid/tma computation
|
|
@@ -435,32 +450,48 @@ class PersistentDenseGemmKernel:
|
|
|
435
450
|
:type max_active_clusters: cutlass.Constexpr
|
|
436
451
|
:param stream: CUDA stream for asynchronous execution
|
|
437
452
|
:type stream: cuda.CUstream
|
|
438
|
-
:param epilogue_op: Optional elementwise lambda function to apply to the output tensor
|
|
439
|
-
:type epilogue_op: cutlass.Constexpr
|
|
440
453
|
:raises TypeError: If input data types are incompatible with the MMA instruction.
|
|
441
454
|
:raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled.
|
|
442
455
|
"""
|
|
443
456
|
if const_expr(self.blockscaled):
|
|
444
457
|
assert mSFA is not None and mSFB is not None
|
|
445
458
|
# Setup static attributes before smem/grid/tma computation
|
|
446
|
-
self.a_dtype
|
|
447
|
-
self.b_dtype
|
|
448
|
-
self.d_dtype
|
|
459
|
+
self.a_dtype = mA.element_type
|
|
460
|
+
self.b_dtype = mB.element_type
|
|
461
|
+
self.d_dtype = mD.element_type if mD is not None else None
|
|
449
462
|
self.c_dtype = mC.element_type if mC is not None else None
|
|
450
463
|
self.sf_dtype: Optional[Type[cutlass.Numeric]] = (
|
|
451
464
|
mSFA.element_type if mSFA is not None else None
|
|
452
465
|
)
|
|
453
|
-
self.
|
|
454
|
-
self.
|
|
455
|
-
self.d_layout =
|
|
456
|
-
self.c_layout =
|
|
466
|
+
self.a_layout = LayoutEnum.from_tensor(mA)
|
|
467
|
+
self.b_layout = LayoutEnum.from_tensor(mB)
|
|
468
|
+
self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
|
|
469
|
+
self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None
|
|
470
|
+
self.a_major_mode = LayoutEnum.from_tensor(mA).mma_major_mode()
|
|
471
|
+
self.b_major_mode = LayoutEnum.from_tensor(mB).mma_major_mode()
|
|
457
472
|
|
|
458
473
|
# Check if input data types are compatible with MMA instruction
|
|
459
474
|
if const_expr(self.a_dtype != self.b_dtype):
|
|
460
475
|
raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
|
|
461
476
|
|
|
477
|
+
if const_expr(varlen_args is None):
|
|
478
|
+
varlen_args = VarlenArguments()
|
|
479
|
+
assert (varlen_args.mAIdx is not None) == self.gather_A
|
|
480
|
+
|
|
481
|
+
# Assume all strides are divisible by 128 bits except the last stride
|
|
482
|
+
new_stride = lambda t: tuple(
|
|
483
|
+
cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
|
|
484
|
+
for s in t.stride
|
|
485
|
+
)
|
|
486
|
+
mA, mD = [
|
|
487
|
+
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
|
488
|
+
if t is not None
|
|
489
|
+
else None
|
|
490
|
+
for t in (mA, mD)
|
|
491
|
+
]
|
|
492
|
+
|
|
462
493
|
# Setup attributes that dependent on gemm inputs
|
|
463
|
-
self._setup_attributes()
|
|
494
|
+
self._setup_attributes(epilogue_args, varlen_args)
|
|
464
495
|
|
|
465
496
|
if const_expr(self.blockscaled):
|
|
466
497
|
# Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout
|
|
@@ -471,67 +502,44 @@ class PersistentDenseGemmKernel:
|
|
|
471
502
|
sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(mB.shape, self.sf_vec_size)
|
|
472
503
|
mSFB = cute.make_tensor(mSFB.iterator, sfb_layout)
|
|
473
504
|
|
|
474
|
-
|
|
475
|
-
tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
|
476
|
-
self.a_dtype,
|
|
477
|
-
self.a_major_mode,
|
|
478
|
-
self.b_major_mode,
|
|
479
|
-
self.acc_dtype,
|
|
480
|
-
self.cta_group,
|
|
481
|
-
self.mma_tiler[:2],
|
|
482
|
-
)
|
|
483
|
-
tiled_mma_sfb = None
|
|
484
|
-
else:
|
|
485
|
-
tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
|
486
|
-
self.a_dtype,
|
|
487
|
-
self.a_major_mode,
|
|
488
|
-
self.b_major_mode,
|
|
489
|
-
self.sf_dtype,
|
|
490
|
-
self.sf_vec_size,
|
|
491
|
-
self.cta_group,
|
|
492
|
-
self.mma_inst_shape_mnk[:2],
|
|
493
|
-
)
|
|
494
|
-
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
|
495
|
-
self.a_dtype,
|
|
496
|
-
self.a_major_mode,
|
|
497
|
-
self.b_major_mode,
|
|
498
|
-
self.sf_dtype,
|
|
499
|
-
self.sf_vec_size,
|
|
500
|
-
cute.nvgpu.tcgen05.CtaGroup.ONE,
|
|
501
|
-
self.mma_inst_shape_mnk_sfb[:2],
|
|
502
|
-
)
|
|
503
|
-
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
|
505
|
+
atom_thr_size = cute.size(self.tiled_mma.thr_id.shape)
|
|
504
506
|
|
|
505
|
-
# Setup TMA load for A
|
|
506
|
-
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
|
507
|
+
# Setup TMA load for A & B
|
|
507
508
|
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
|
508
|
-
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
|
509
|
-
a_op,
|
|
510
|
-
mA,
|
|
511
|
-
a_smem_layout,
|
|
512
|
-
self.mma_tiler,
|
|
513
|
-
tiled_mma,
|
|
514
|
-
self.cluster_layout_vmnk.shape,
|
|
515
|
-
internal_type=(cutlass.TFloat32 if mA.element_type is cutlass.Float32 else None),
|
|
516
|
-
)
|
|
517
|
-
|
|
518
|
-
# Setup TMA load for B
|
|
519
|
-
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
|
|
520
509
|
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
|
510
|
+
tma_atom_a, tma_tensor_a = None, None
|
|
511
|
+
if const_expr(not self.gather_A):
|
|
512
|
+
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
|
|
513
|
+
self.cluster_shape_mnk, self.tiled_mma.thr_id
|
|
514
|
+
)
|
|
515
|
+
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
|
516
|
+
a_op,
|
|
517
|
+
mA,
|
|
518
|
+
a_smem_layout,
|
|
519
|
+
self.mma_tiler,
|
|
520
|
+
self.tiled_mma,
|
|
521
|
+
self.cluster_layout_vmnk.shape,
|
|
522
|
+
internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None),
|
|
523
|
+
)
|
|
524
|
+
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
|
|
525
|
+
self.cluster_shape_mnk, self.tiled_mma.thr_id
|
|
526
|
+
)
|
|
521
527
|
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
|
522
528
|
b_op,
|
|
523
529
|
mB,
|
|
524
530
|
b_smem_layout,
|
|
525
531
|
self.mma_tiler,
|
|
526
|
-
tiled_mma,
|
|
532
|
+
self.tiled_mma,
|
|
527
533
|
self.cluster_layout_vmnk.shape,
|
|
528
|
-
internal_type=(cutlass.TFloat32 if mB.element_type is
|
|
534
|
+
internal_type=(cutlass.TFloat32 if mB.element_type is Float32 else None),
|
|
529
535
|
)
|
|
530
536
|
|
|
537
|
+
tma_atom_sfa, tma_tensor_sfa = None, None
|
|
538
|
+
tma_atom_sfb, tma_tensor_sfb = None, None
|
|
531
539
|
if const_expr(self.blockscaled):
|
|
532
540
|
# Setup TMA load for SFA
|
|
533
541
|
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
|
|
534
|
-
self.
|
|
542
|
+
self.cluster_shape_mnk, self.tiled_mma.thr_id
|
|
535
543
|
)
|
|
536
544
|
sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
|
|
537
545
|
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
|
|
@@ -539,13 +547,13 @@ class PersistentDenseGemmKernel:
|
|
|
539
547
|
mSFA,
|
|
540
548
|
sfa_smem_layout,
|
|
541
549
|
self.mma_tiler,
|
|
542
|
-
tiled_mma,
|
|
550
|
+
self.tiled_mma,
|
|
543
551
|
self.cluster_layout_vmnk.shape,
|
|
544
552
|
internal_type=cutlass.Int16,
|
|
545
553
|
)
|
|
546
554
|
# Setup TMA load for SFB
|
|
547
555
|
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
|
|
548
|
-
self.
|
|
556
|
+
self.cluster_shape_mnk, self.tiled_mma.thr_id
|
|
549
557
|
)
|
|
550
558
|
sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
|
|
551
559
|
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
|
|
@@ -553,58 +561,50 @@ class PersistentDenseGemmKernel:
|
|
|
553
561
|
mSFB,
|
|
554
562
|
sfb_smem_layout,
|
|
555
563
|
self.mma_tiler_sfb,
|
|
556
|
-
tiled_mma_sfb,
|
|
564
|
+
self.tiled_mma_sfb,
|
|
557
565
|
self.cluster_layout_sfb_vmnk.shape,
|
|
558
566
|
internal_type=cutlass.Int16,
|
|
559
567
|
)
|
|
560
|
-
else:
|
|
561
|
-
tma_atom_sfa, tma_tensor_sfa = None, None
|
|
562
|
-
tma_atom_sfb, tma_tensor_sfb = None, None
|
|
563
568
|
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
569
|
+
self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
|
|
570
|
+
if const_expr(not self.gather_A):
|
|
571
|
+
self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
|
|
567
572
|
if const_expr(self.blockscaled):
|
|
568
573
|
sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
|
|
569
574
|
sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
|
|
570
|
-
self.num_tma_load_bytes +=
|
|
575
|
+
self.num_tma_load_bytes += sfa_copy_size + sfb_copy_size
|
|
576
|
+
self.num_tma_load_bytes *= atom_thr_size
|
|
571
577
|
|
|
572
578
|
# Setup TMA store for D
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
self.epi_tile,
|
|
579
|
-
)
|
|
580
|
-
if const_expr(mC is not None):
|
|
581
|
-
epi_c_smem_layout = cute.slice_(self.epi_c_smem_layout_staged, (None, None, 0))
|
|
582
|
-
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
|
|
583
|
-
cpasync.CopyBulkTensorTileG2SOp(),
|
|
584
|
-
mC,
|
|
585
|
-
epi_c_smem_layout,
|
|
579
|
+
tma_atom_d, tma_tensor_d = None, None
|
|
580
|
+
if const_expr(mD is not None):
|
|
581
|
+
tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
|
|
582
|
+
mD,
|
|
583
|
+
self.epi_smem_layout_staged,
|
|
586
584
|
self.epi_tile,
|
|
585
|
+
op_type="store"
|
|
586
|
+
if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
|
|
587
|
+
else "add",
|
|
588
|
+
)
|
|
589
|
+
tma_atom_c, tma_tensor_c = None, None
|
|
590
|
+
if const_expr(mC is not None):
|
|
591
|
+
tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
|
|
592
|
+
mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
|
|
587
593
|
)
|
|
588
|
-
else:
|
|
589
|
-
tma_atom_c, tma_tensor_c = None, None
|
|
590
594
|
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
TileSchedulerCls =
|
|
595
|
-
tile_sched_args =
|
|
596
|
-
problem_shape_ntile_mnl=problem_shape_ntile_mnl,
|
|
597
|
-
raster_order=RasterOrderOption.Heuristic,
|
|
598
|
-
group_size=8,
|
|
599
|
-
cluster_shape_mnk=(*self.cluster_shape_mn, 1),
|
|
600
|
-
tile_count_semaphore=tile_count_semaphore,
|
|
601
|
-
is_persistent=True,
|
|
602
|
-
)
|
|
595
|
+
epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
|
|
596
|
+
varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
|
|
597
|
+
|
|
598
|
+
TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
|
|
599
|
+
tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
|
|
603
600
|
tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
|
|
604
|
-
grid = TileSchedulerCls.get_grid_shape(
|
|
601
|
+
grid = TileSchedulerCls.get_grid_shape(
|
|
602
|
+
tile_sched_params, scheduler_args.max_active_clusters
|
|
603
|
+
)
|
|
605
604
|
|
|
606
605
|
self.buffer_align_bytes = 1024
|
|
607
606
|
|
|
607
|
+
epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
|
|
608
608
|
epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
|
|
609
609
|
sf_dtype = self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU
|
|
610
610
|
sfa_smem_size = (
|
|
@@ -613,22 +613,33 @@ class PersistentDenseGemmKernel:
|
|
|
613
613
|
sfb_smem_size = (
|
|
614
614
|
cute.cosize(self.sfb_smem_layout_staged) if const_expr(self.blockscaled) else 0
|
|
615
615
|
)
|
|
616
|
+
a_idx_smem_size = 0
|
|
617
|
+
if const_expr(self.gather_A):
|
|
618
|
+
a_idx_smem_size = self.a_prefetch_stage * (
|
|
619
|
+
self.cta_tile_shape_mnk[0]
|
|
620
|
+
if varlen_args.mCuSeqlensM is not None
|
|
621
|
+
else self.cta_tile_shape_mnk[2]
|
|
622
|
+
)
|
|
616
623
|
|
|
617
624
|
# Define shared storage for kernel
|
|
618
625
|
@cute.struct
|
|
619
626
|
class SharedStorage:
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
627
|
+
ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
|
|
628
|
+
epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
|
|
629
|
+
acc_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
|
630
|
+
sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
|
|
631
|
+
a_prefetch_pipeline_array_ptr: cute.struct.MemRange[
|
|
632
|
+
cutlass.Int64, self.a_prefetch_stage * 2
|
|
633
|
+
]
|
|
634
|
+
tile_count: cute.struct.MemRange[Int32, self.sched_stage]
|
|
625
635
|
tmem_dealloc_mbar_ptr: cutlass.Int64
|
|
626
636
|
tmem_holding_buf: Int32
|
|
627
|
-
|
|
628
|
-
tile_count: cute.struct.MemRange[cutlass.Int32, 1]
|
|
637
|
+
sAIdx: cute.struct.Align[cute.struct.MemRange[Int32, a_idx_smem_size], 16]
|
|
629
638
|
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
|
630
639
|
sD: cute.struct.Align[
|
|
631
|
-
cute.struct.MemRange[
|
|
640
|
+
cute.struct.MemRange[
|
|
641
|
+
self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
|
|
642
|
+
],
|
|
632
643
|
self.buffer_align_bytes,
|
|
633
644
|
]
|
|
634
645
|
sC: cute.struct.Align[
|
|
@@ -637,6 +648,7 @@ class PersistentDenseGemmKernel:
|
|
|
637
648
|
],
|
|
638
649
|
self.buffer_align_bytes,
|
|
639
650
|
]
|
|
651
|
+
epi: self.epi_get_smem_struct(epilogue_params)
|
|
640
652
|
# (MMA, MMA_M, MMA_K, STAGE)
|
|
641
653
|
sA: cute.struct.Align[
|
|
642
654
|
cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
|
|
@@ -662,10 +674,10 @@ class PersistentDenseGemmKernel:
|
|
|
662
674
|
|
|
663
675
|
# Launch the kernel synchronously
|
|
664
676
|
self.kernel(
|
|
665
|
-
tiled_mma,
|
|
666
|
-
tiled_mma_sfb,
|
|
677
|
+
self.tiled_mma,
|
|
678
|
+
self.tiled_mma_sfb,
|
|
667
679
|
tma_atom_a,
|
|
668
|
-
tma_tensor_a,
|
|
680
|
+
tma_tensor_a if const_expr(not self.gather_A) else mA,
|
|
669
681
|
tma_atom_b,
|
|
670
682
|
tma_tensor_b,
|
|
671
683
|
tma_atom_sfa,
|
|
@@ -676,24 +688,26 @@ class PersistentDenseGemmKernel:
|
|
|
676
688
|
tma_tensor_d,
|
|
677
689
|
tma_atom_c,
|
|
678
690
|
tma_tensor_c,
|
|
691
|
+
epilogue_params,
|
|
692
|
+
varlen_params,
|
|
679
693
|
self.cluster_layout_vmnk,
|
|
680
694
|
self.cluster_layout_sfb_vmnk,
|
|
681
695
|
self.a_smem_layout_staged,
|
|
696
|
+
self.a_smem_load_layout_staged,
|
|
682
697
|
self.b_smem_layout_staged,
|
|
683
698
|
self.sfa_smem_layout_staged,
|
|
684
699
|
self.sfb_smem_layout_staged,
|
|
685
|
-
self.
|
|
700
|
+
self.epi_smem_layout_staged,
|
|
686
701
|
self.epi_c_smem_layout_staged,
|
|
687
702
|
self.epi_tile,
|
|
688
703
|
tile_sched_params,
|
|
689
704
|
TileSchedulerCls,
|
|
690
|
-
epilogue_op,
|
|
691
705
|
).launch(
|
|
692
706
|
grid=grid,
|
|
693
707
|
block=[self.threads_per_cta, 1, 1],
|
|
694
|
-
cluster=
|
|
695
|
-
smem=self.shared_storage.size_in_bytes(),
|
|
708
|
+
cluster=self.cluster_shape_mnk,
|
|
696
709
|
stream=stream,
|
|
710
|
+
min_blocks_per_mp=1,
|
|
697
711
|
)
|
|
698
712
|
return
|
|
699
713
|
|
|
@@ -703,7 +717,7 @@ class PersistentDenseGemmKernel:
|
|
|
703
717
|
self,
|
|
704
718
|
tiled_mma: cute.TiledMma,
|
|
705
719
|
tiled_mma_sfb: Optional[cute.TiledMma],
|
|
706
|
-
tma_atom_a: cute.CopyAtom,
|
|
720
|
+
tma_atom_a: Optional[cute.CopyAtom],
|
|
707
721
|
mA_mkl: cute.Tensor,
|
|
708
722
|
tma_atom_b: cute.CopyAtom,
|
|
709
723
|
mB_nkl: cute.Tensor,
|
|
@@ -712,37 +726,52 @@ class PersistentDenseGemmKernel:
|
|
|
712
726
|
tma_atom_sfb: Optional[cute.CopyAtom],
|
|
713
727
|
mSFB_nkl: Optional[cute.Tensor],
|
|
714
728
|
tma_atom_d: Optional[cute.CopyAtom],
|
|
715
|
-
mD_mnl: cute.Tensor,
|
|
729
|
+
mD_mnl: Optional[cute.Tensor],
|
|
716
730
|
tma_atom_c: Optional[cute.CopyAtom],
|
|
717
731
|
mC_mnl: Optional[cute.Tensor],
|
|
732
|
+
epilogue_params: ParamsBase,
|
|
733
|
+
varlen_params: VarlenManager.Params,
|
|
718
734
|
cluster_layout_vmnk: cute.Layout,
|
|
719
735
|
cluster_layout_sfb_vmnk: Optional[cute.Layout],
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
736
|
+
a_smem_layout: cute.ComposedLayout,
|
|
737
|
+
a_smem_load_layout: cute.ComposedLayout,
|
|
738
|
+
b_smem_layout: cute.ComposedLayout,
|
|
739
|
+
sfa_smem_layout: Optional[cute.Layout],
|
|
740
|
+
sfb_smem_layout: Optional[cute.Layout],
|
|
741
|
+
epi_smem_layout: Union[cute.Layout, cute.ComposedLayout, None],
|
|
742
|
+
epi_c_smem_layout: Union[cute.Layout, cute.ComposedLayout, None],
|
|
726
743
|
epi_tile: cute.Tile,
|
|
727
744
|
tile_sched_params: ParamsBase,
|
|
728
745
|
TileSchedulerCls: cutlass.Constexpr[Callable],
|
|
729
|
-
epilogue_op: cutlass.Constexpr[Callable],
|
|
730
746
|
):
|
|
731
747
|
"""
|
|
732
748
|
GPU device kernel performing the Persistent batched GEMM computation.
|
|
733
749
|
"""
|
|
750
|
+
|
|
751
|
+
varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
|
|
752
|
+
varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
|
|
753
|
+
assert not (varlen_m and varlen_k)
|
|
754
|
+
if const_expr(self.gather_A):
|
|
755
|
+
assert varlen_m or varlen_k
|
|
756
|
+
has_D = const_expr(mD_mnl is not None)
|
|
757
|
+
has_C = const_expr(mC_mnl is not None)
|
|
758
|
+
|
|
734
759
|
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
|
735
760
|
|
|
736
|
-
#
|
|
737
|
-
#
|
|
738
|
-
#
|
|
739
|
-
if warp_idx == self.
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
761
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
762
|
+
# Prefetch Tma desc
|
|
763
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
764
|
+
if warp_idx == self.ab_load_warp_id:
|
|
765
|
+
for tma_atom in (
|
|
766
|
+
tma_atom_a,
|
|
767
|
+
tma_atom_b,
|
|
768
|
+
tma_atom_sfa,
|
|
769
|
+
tma_atom_sfb,
|
|
770
|
+
tma_atom_d,
|
|
771
|
+
tma_atom_c,
|
|
772
|
+
):
|
|
773
|
+
if const_expr(tma_atom is not None):
|
|
774
|
+
cpasync.prefetch_descriptor(tma_atom)
|
|
746
775
|
|
|
747
776
|
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
|
|
748
777
|
|
|
@@ -754,13 +783,6 @@ class PersistentDenseGemmKernel:
|
|
|
754
783
|
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
|
|
755
784
|
is_leader_cta = mma_tile_coord_v == 0
|
|
756
785
|
cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
|
|
757
|
-
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
|
|
758
|
-
if const_expr(self.blockscaled):
|
|
759
|
-
block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
|
|
760
|
-
cta_rank_in_cluster
|
|
761
|
-
)
|
|
762
|
-
else:
|
|
763
|
-
block_in_cluster_coord_sfb_vmnk = None
|
|
764
786
|
# Coord inside cta
|
|
765
787
|
tidx, _, _ = cute.arch.thread_idx()
|
|
766
788
|
|
|
@@ -775,104 +797,68 @@ class PersistentDenseGemmKernel:
|
|
|
775
797
|
|
|
776
798
|
# Tensor memory dealloc barrier init
|
|
777
799
|
if use_2cta_instrs:
|
|
778
|
-
if warp_idx == self.
|
|
800
|
+
if warp_idx == self.ab_load_warp_id:
|
|
779
801
|
num_tmem_dealloc_threads = 32
|
|
780
802
|
cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)
|
|
781
803
|
|
|
782
|
-
# Initialize
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
804
|
+
# Initialize pipelines and states
|
|
805
|
+
ab_pipeline = self.make_ab_pipeline(
|
|
806
|
+
tiled_mma=tiled_mma,
|
|
807
|
+
cluster_layout_vmnk=cluster_layout_vmnk,
|
|
808
|
+
ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
|
|
809
|
+
is_leader_cta=is_leader_cta,
|
|
787
810
|
)
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
811
|
+
epi_pipeline = None
|
|
812
|
+
if const_expr(has_C):
|
|
813
|
+
epi_pipeline = self.make_epi_pipeline(
|
|
814
|
+
c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
|
|
815
|
+
epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
|
|
816
|
+
)
|
|
817
|
+
acc_pipeline = self.make_acc_pipeline(
|
|
818
|
+
cluster_layout_vmnk=cluster_layout_vmnk,
|
|
819
|
+
acc_pipeline_mbar_ptr=storage.acc_pipeline_array_ptr.data_ptr(),
|
|
795
820
|
)
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
821
|
+
sched_pipeline = None
|
|
822
|
+
tile_count = None
|
|
823
|
+
if const_expr(tile_sched_params.tile_count_semaphore is not None):
|
|
824
|
+
# Dynamic persistent scheduler
|
|
825
|
+
sched_pipeline = self.make_sched_pipeline(
|
|
826
|
+
self.cluster_shape_mnk,
|
|
827
|
+
sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
|
|
828
|
+
has_C=has_C,
|
|
804
829
|
)
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
producer_group=epi_pipeline_producer_group,
|
|
811
|
-
consumer_group=epi_pipeline_consumer_group,
|
|
812
|
-
tx_count=tma_copy_c_bytes,
|
|
830
|
+
tile_count = storage.tile_count.get_tensor((self.sched_stage,))
|
|
831
|
+
a_prefetch_pipeline = None
|
|
832
|
+
if const_expr(self.gather_A):
|
|
833
|
+
a_prefetch_pipeline = self.make_a_prefetch_pipeline(
|
|
834
|
+
storage.a_prefetch_pipeline_array_ptr.data_ptr(),
|
|
813
835
|
)
|
|
814
|
-
else:
|
|
815
|
-
epi_pipeline = None
|
|
816
|
-
|
|
817
|
-
# Initialize acc_pipeline (barrier) and states
|
|
818
|
-
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
819
|
-
num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
|
|
820
|
-
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
821
|
-
pipeline.Agent.Thread, num_acc_consumer_threads
|
|
822
|
-
)
|
|
823
|
-
acc_pipeline = pipeline.PipelineUmmaAsync.create(
|
|
824
|
-
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
|
|
825
|
-
num_stages=self.num_acc_stage,
|
|
826
|
-
producer_group=acc_pipeline_producer_group,
|
|
827
|
-
consumer_group=acc_pipeline_consumer_group,
|
|
828
|
-
cta_layout_vmnk=cluster_layout_vmnk,
|
|
829
|
-
)
|
|
830
|
-
|
|
831
|
-
# if const_expr(tile_sched_params.tile_count_semaphore is not None):
|
|
832
|
-
# # Dynamic persistent scheduler
|
|
833
|
-
# # Threads/warps participating in this pipeline
|
|
834
|
-
# sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
835
|
-
# cluster_size = cute.size(cluster_layout_vmnk)
|
|
836
|
-
# # Each warp that are not the scheduler warp will contribute 1 to the arrive count
|
|
837
|
-
# consumer_arrive_cnt = (
|
|
838
|
-
# (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
|
|
839
|
-
# ) * cluster_size - 1
|
|
840
|
-
# sched_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
841
|
-
# pipeline.Agent.Thread, consumer_arrive_cnt
|
|
842
|
-
# )
|
|
843
|
-
# sched_pipeline = pipeline.PipelineAsync.create(
|
|
844
|
-
# barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
|
|
845
|
-
# num_stages=self.sched_stage,
|
|
846
|
-
# producer_group=sched_pipeline_producer_group,
|
|
847
|
-
# consumer_group=sched_pipeline_consumer_group,
|
|
848
|
-
# # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
|
|
849
|
-
# consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
|
|
850
|
-
# )
|
|
851
|
-
# tile_count = storage.tile_count.get_tensor((self.sched_stage,))
|
|
852
|
-
# else:
|
|
853
|
-
# sched_pipeline = None
|
|
854
|
-
# tile_count = None
|
|
855
836
|
|
|
856
837
|
# Setup smem tensor A/B/D
|
|
857
838
|
# (MMA, MMA_M, MMA_K, STAGE)
|
|
858
|
-
|
|
839
|
+
sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
|
|
840
|
+
sA = storage.sA.get_tensor(a_smem_load_layout.outer, swizzle=a_smem_load_layout.inner)
|
|
859
841
|
# (MMA, MMA_N, MMA_K, STAGE)
|
|
860
|
-
sB = storage.sB.get_tensor(
|
|
842
|
+
sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
|
|
843
|
+
sAIdx = None
|
|
844
|
+
if const_expr(self.gather_A):
|
|
845
|
+
a_idx_smem_dim = self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2]
|
|
846
|
+
a_idx_smem_layout = cute.make_layout((a_idx_smem_dim, self.a_prefetch_stage))
|
|
847
|
+
sAIdx = storage.sAIdx.get_tensor(a_idx_smem_layout)
|
|
848
|
+
sSFA, sSFB = None, None
|
|
861
849
|
if const_expr(self.blockscaled):
|
|
862
850
|
# (MMA, MMA_M, MMA_K, STAGE)
|
|
863
|
-
sSFA = storage.sSFA.get_tensor(
|
|
851
|
+
sSFA = storage.sSFA.get_tensor(sfa_smem_layout)
|
|
864
852
|
# (MMA, MMA_N, MMA_K, STAGE)
|
|
865
|
-
sSFB = storage.sSFB.get_tensor(
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
else:
|
|
875
|
-
sC = None
|
|
853
|
+
sSFB = storage.sSFB.get_tensor(sfb_smem_layout)
|
|
854
|
+
sD = None
|
|
855
|
+
if const_expr(has_D):
|
|
856
|
+
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
|
857
|
+
sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
|
|
858
|
+
sC = None
|
|
859
|
+
if const_expr(has_C):
|
|
860
|
+
sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
|
|
861
|
+
epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
|
|
876
862
|
|
|
877
863
|
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
|
|
878
864
|
thr_mma_sfb = (
|
|
@@ -884,26 +870,51 @@ class PersistentDenseGemmKernel:
|
|
|
884
870
|
# (MMA, MMA_M, MMA_N, STAGE)
|
|
885
871
|
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
|
886
872
|
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
873
|
+
varlen_manager = VarlenManager.create(
|
|
874
|
+
varlen_params,
|
|
875
|
+
has_D,
|
|
876
|
+
self.num_epi_tensormaps,
|
|
877
|
+
# Only used if not varlen_m
|
|
878
|
+
len_m_static=Int32(
|
|
879
|
+
mA_mkl.shape[0]
|
|
880
|
+
if varlen_k or varlen_params.mAIdx is None
|
|
881
|
+
else varlen_params.mAIdx.shape[0]
|
|
882
|
+
),
|
|
883
|
+
len_k_static=Int32(mA_mkl.shape[1]),
|
|
890
884
|
)
|
|
891
885
|
|
|
892
|
-
TileSchedulerCls = partial(
|
|
893
|
-
|
|
886
|
+
TileSchedulerCls = partial(
|
|
887
|
+
TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
|
|
888
|
+
)
|
|
894
889
|
|
|
895
|
-
|
|
890
|
+
tmem_alloc_barrier = pipeline.NamedBarrier(
|
|
891
|
+
barrier_id=int(NamedBarrierGemm.TmemPtr),
|
|
892
|
+
num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)),
|
|
893
|
+
)
|
|
894
|
+
epi_load_barrier = None
|
|
895
|
+
if const_expr(has_C):
|
|
896
896
|
epi_load_barrier = pipeline.NamedBarrier(
|
|
897
|
-
barrier_id=int(
|
|
897
|
+
barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE
|
|
898
898
|
)
|
|
899
|
-
else:
|
|
900
|
-
epi_load_barrier = None
|
|
901
899
|
|
|
902
900
|
#
|
|
903
|
-
# Specialized
|
|
901
|
+
# Specialized AB load warps
|
|
904
902
|
#
|
|
905
|
-
if warp_idx == self.
|
|
903
|
+
if warp_idx == self.ab_load_warp_id:
|
|
904
|
+
is_tma_warp = True
|
|
905
|
+
# initialize tensormap for A & B
|
|
906
|
+
varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
|
|
907
|
+
tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
|
|
908
|
+
tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
|
|
906
909
|
# Compute multicast mask for A/B buffer full
|
|
910
|
+
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
|
|
911
|
+
block_in_cluster_coord_sfb_vmnk = None
|
|
912
|
+
if const_expr(self.blockscaled):
|
|
913
|
+
block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
|
|
914
|
+
cta_rank_in_cluster
|
|
915
|
+
)
|
|
916
|
+
a_mcast_mask, b_mcast_mask = None, None
|
|
917
|
+
sfa_mcast_mask, sfb_mcast_mask = None, None
|
|
907
918
|
if const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
|
|
908
919
|
a_mcast_mask = cpasync.create_tma_multicast_mask(
|
|
909
920
|
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
|
|
@@ -918,141 +929,139 @@ class PersistentDenseGemmKernel:
|
|
|
918
929
|
sfb_mcast_mask = cpasync.create_tma_multicast_mask(
|
|
919
930
|
cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1
|
|
920
931
|
)
|
|
921
|
-
else:
|
|
922
|
-
sfa_mcast_mask, sfb_mcast_mask = None, None
|
|
923
|
-
else:
|
|
924
|
-
a_mcast_mask, b_mcast_mask = None, None
|
|
925
|
-
sfa_mcast_mask, sfb_mcast_mask = None, None
|
|
926
932
|
|
|
927
933
|
# Persistent tile scheduling loop
|
|
928
934
|
tile_scheduler = TileSchedulerCls()
|
|
929
935
|
work_tile = tile_scheduler.initial_work_tile_info()
|
|
930
936
|
ab_producer_state = pipeline.make_pipeline_state(
|
|
931
|
-
pipeline.PipelineUserType.Producer, self.
|
|
937
|
+
pipeline.PipelineUserType.Producer, self.ab_stage
|
|
932
938
|
)
|
|
933
|
-
|
|
939
|
+
if const_expr(varlen_k):
|
|
940
|
+
# wait tensormap initialization complete before update
|
|
941
|
+
varlen_manager.fence_tensormap_init()
|
|
942
|
+
do_epi_load_barrier_arrive = Boolean(True)
|
|
934
943
|
while work_tile.is_valid_tile:
|
|
935
|
-
# Get tile coord from tile scheduler
|
|
936
944
|
tile_coord_mnkl = work_tile.tile_idx
|
|
945
|
+
batch_idx = tile_coord_mnkl[3]
|
|
946
|
+
varlen_manager.update_tensormap_AB(
|
|
947
|
+
batch_idx,
|
|
948
|
+
self.a_layout,
|
|
949
|
+
self.b_layout,
|
|
950
|
+
is_tma_warp,
|
|
951
|
+
)
|
|
952
|
+
# ///////////////////////////////////////////////////////////////////////////
|
|
953
|
+
# Local_tile partition global tensors
|
|
954
|
+
# ///////////////////////////////////////////////////////////////////////////
|
|
937
955
|
mma_tile_coord_mnl = (
|
|
938
956
|
tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
|
|
939
957
|
tile_coord_mnkl[1],
|
|
940
958
|
tile_coord_mnkl[3],
|
|
941
959
|
)
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
cute.
|
|
947
|
-
|
|
948
|
-
|
|
960
|
+
gA_mk = None
|
|
961
|
+
if const_expr(not self.gather_A):
|
|
962
|
+
mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
|
|
963
|
+
# (bM, bK, RestK)
|
|
964
|
+
gA_mk = cute.local_tile(
|
|
965
|
+
mA_mk,
|
|
966
|
+
cute.select(self.mma_tiler, [0, 2]),
|
|
967
|
+
(mma_tile_coord_mnl[0], None),
|
|
968
|
+
)
|
|
949
969
|
# (bN, bK, RestK)
|
|
950
|
-
|
|
951
|
-
mB_nkl,
|
|
952
|
-
cute.
|
|
953
|
-
(mma_tile_coord_mnl[1], None
|
|
970
|
+
gB_nk = cute.local_tile(
|
|
971
|
+
varlen_manager.offset_batch_B(mB_nkl, batch_idx),
|
|
972
|
+
cute.select(self.mma_tiler, [1, 2]),
|
|
973
|
+
(mma_tile_coord_mnl[1], None),
|
|
954
974
|
)
|
|
955
975
|
if const_expr(self.blockscaled):
|
|
956
976
|
# (bM, bK)
|
|
957
977
|
gSFA_mkl = cute.local_tile(
|
|
958
|
-
mSFA_mkl,
|
|
959
|
-
cute.
|
|
960
|
-
(mma_tile_coord_mnl[0], None
|
|
978
|
+
varlen_manager.offset_batch_A(mSFA_mkl, batch_idx),
|
|
979
|
+
cute.select(self.mma_tiler, [0, 2]),
|
|
980
|
+
(mma_tile_coord_mnl[0], None),
|
|
961
981
|
)
|
|
962
982
|
# (bN, bK)
|
|
963
983
|
gSFB_nkl = cute.local_tile(
|
|
964
|
-
mSFB_nkl,
|
|
965
|
-
cute.
|
|
966
|
-
(mma_tile_coord_mnl[1], None
|
|
984
|
+
varlen_manager.offset_batch_B(mSFB_nkl, batch_idx),
|
|
985
|
+
cute.select(self.mma_tiler, [1, 2]),
|
|
986
|
+
(mma_tile_coord_mnl[1], None),
|
|
967
987
|
)
|
|
988
|
+
|
|
968
989
|
# Partition global tensor for TiledMMA_A/B/D
|
|
969
|
-
#
|
|
970
|
-
|
|
990
|
+
# Then partition global/shared tensor for TMA load A/B
|
|
991
|
+
varlen_manager.fence_tensormap_update_AB(is_tma_warp)
|
|
992
|
+
len_k = varlen_manager.len_k(batch_idx)
|
|
993
|
+
# TMA load A partition_S/D
|
|
994
|
+
a_cta_layout = cute.make_layout(
|
|
995
|
+
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
|
|
996
|
+
)
|
|
997
|
+
copy_A = None
|
|
998
|
+
if const_expr(not self.gather_A):
|
|
999
|
+
# (MMA, MMA_M, MMA_K, RestK)
|
|
1000
|
+
tCgA = thr_mma.partition_A(gA_mk)
|
|
1001
|
+
copy_A, _, _ = copy_utils.tma_get_copy_fn(
|
|
1002
|
+
tma_atom_a,
|
|
1003
|
+
cta_coord=block_in_cluster_coord_vmnk[2],
|
|
1004
|
+
cta_layout=a_cta_layout,
|
|
1005
|
+
src_tensor=tCgA,
|
|
1006
|
+
dst_tensor=sA,
|
|
1007
|
+
mcast_mask=a_mcast_mask,
|
|
1008
|
+
tma_desc_ptr=tma_desc_a_ptr,
|
|
1009
|
+
)
|
|
971
1010
|
# (MMA, MMA_N, MMA_K, RestK)
|
|
972
|
-
tCgB = thr_mma.partition_B(
|
|
1011
|
+
tCgB = thr_mma.partition_B(gB_nk)
|
|
973
1012
|
if const_expr(self.blockscaled):
|
|
974
1013
|
# (MMA, MMA_M, MMA_K)
|
|
975
1014
|
tCgSFA = thr_mma.partition_A(gSFA_mkl)
|
|
976
1015
|
# (MMA, MMA_N, MMA_K)
|
|
977
1016
|
tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
|
|
978
|
-
# Partition global/shared tensor for TMA load A/B
|
|
979
|
-
# TMA load A partition_S/D
|
|
980
|
-
a_cta_layout = cute.make_layout(
|
|
981
|
-
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
|
|
982
|
-
)
|
|
983
|
-
# ((atom_v, rest_v), STAGE)
|
|
984
|
-
- # ((atom_v, rest_v), RestK)
- tAsA, tAgA = cpasync.tma_partition(
- tma_atom_a,
- block_in_cluster_coord_vmnk[2],
- a_cta_layout,
- cute.group_modes(sA, 0, 3),
- cute.group_modes(tCgA, 0, 3),
- )
# TMA load B partition_S/D
-
- cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
- )
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tBsB, tBgB = cpasync.tma_partition(
+ copy_B, _, _ = copy_utils.tma_get_copy_fn(
tma_atom_b,
- block_in_cluster_coord_vmnk[1],
-
-
-
+ cta_coord=block_in_cluster_coord_vmnk[1],
+ cta_layout=cute.make_layout(
+ cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
+ ),
+ src_tensor=tCgB,
+ dst_tensor=sB,
+ mcast_mask=b_mcast_mask,
+ tma_desc_ptr=tma_desc_b_ptr,
)
+ copy_SFA, copy_SFB = None, None
if const_expr(self.blockscaled):
# TMA load SFA partition_S/D
-
- # ((atom_v, rest_v), STAGE)
- # ((atom_v, rest_v), RestK)
- tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition(
+ copy_SFA, _, _ = copy_utils.tma_get_copy_fn(
tma_atom_sfa,
- block_in_cluster_coord_vmnk[2],
-
-
-
+ cta_coord=block_in_cluster_coord_vmnk[2],
+ cta_layout=a_cta_layout,
+ src_tensor=tCgSFA,
+ dst_tensor=sSFA,
+ filter_zeros=True,
+ mcast_mask=sfa_mcast_mask,
+ # tma_desc_ptr=tma_desc_sfa_ptr,
)
- tAsSFA = cute.filter_zeros(tAsSFA)
- tAgSFA = cute.filter_zeros(tAgSFA)
# TMA load SFB partition_S/D
sfb_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape
)
-
- # ((atom_v, rest_v), RestK)
- tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition(
+ copy_SFB, _, _ = copy_utils.tma_get_copy_fn(
tma_atom_sfb,
- block_in_cluster_coord_sfb_vmnk[1],
- sfb_cta_layout,
-
-
+ cta_coord=block_in_cluster_coord_sfb_vmnk[1],
+ cta_layout=sfb_cta_layout,
+ src_tensor=tCgSFB,
+ dst_tensor=sSFB,
+ filter_zeros=True,
+ mcast_mask=sfb_mcast_mask,
+ # tma_desc_ptr=tma_desc_sfa_ptr,
)
-
- tBgSFB = cute.filter_zeros(tBgSFB)
- else:
- tAsSFA, tAgSFA = None, None
- tBsSFB, tBgSFB = None, None
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
ab_producer_state = self.load_AB(
ab_pipeline,
ab_producer_state,
-
-
-
-
-
- tBgB,
- tBsB,
- b_mcast_mask,
- tma_atom_sfa,
- tAgSFA,
- tAsSFA,
- sfa_mcast_mask,
- tma_atom_sfb,
- tBgSFB,
- tBsSFB,
- sfb_mcast_mask,
+ copy_A,
+ copy_B,
+ k_tile_cnt,
+ copy_SFA,
+ copy_SFB,
)
if const_expr(epi_load_barrier is not None):
# In the first work tile, the epi load warp will wait for the signal
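
The hunk above replaces the per-operand `cpasync.tma_partition` calls and their explicit copy arguments with closures returned by `copy_utils.tma_get_copy_fn`, so `load_AB` now only receives `copy_A`/`copy_B`/`copy_SFA`/`copy_SFB`. A minimal sketch of that "partition once, hand back a copy closure" pattern, using plain Python lists in place of the real cute tensors (all names below are made up for illustration, not the package's API):

```python
from typing import Callable, List

def make_copy_fn(src_tiles: List[list], dst_stages: List[list]) -> Callable[[int, int], None]:
    """Return a closure that copies one source k-tile into one smem stage."""
    def copy(k_tile: int, stage: int) -> None:
        dst_stages[stage][:] = src_tiles[k_tile]
    return copy

# toy usage: 4 k-tiles of data, 2 smem stages
src = [[i] * 8 for i in range(4)]
smem = [[0] * 8 for _ in range(2)]
copy_A = make_copy_fn(src, smem)
copy_A(0, 0)          # load k-tile 0 into stage 0
copy_A(1, 1)          # load k-tile 1 into stage 1
assert smem[0] == [0] * 8 and smem[1] == [1] * 8
```

The point is only the shape of the refactor: partitioning state is captured once, and the mainloop calls the closure with a k-tile index and a stage index.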
@@ -1060,58 +1069,209 @@ class PersistentDenseGemmKernel:
# with loading A and B.
if do_epi_load_barrier_arrive:
epi_load_barrier.arrive()
- do_epi_load_barrier_arrive =
+ do_epi_load_barrier_arrive = Boolean(False)
# Advance to next tile
tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()
# Wait A/B buffer empty
ab_pipeline.producer_tail(ab_producer_state)

+ if const_expr(self.gather_A):
+ if (
+ warp_idx >= self.ab_load_warp_id + 1
+ and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
+ ):
+ # Persistent tile scheduling loop
+ tile_scheduler = TileSchedulerCls()
+ work_tile = tile_scheduler.initial_work_tile_info()
+ ab_producer_state = pipeline.make_pipeline_state(
+ pipeline.PipelineUserType.Producer, self.ab_stage
+ )
+ a_prefetch_consumer_state = pipeline.make_pipeline_state(
+ pipeline.PipelineUserType.Consumer, self.a_prefetch_stage
+ )
+ while work_tile.is_valid_tile:
+ tile_coord_mnkl = work_tile.tile_idx
+ batch_idx = tile_coord_mnkl[3]
+ # ///////////////////////////////////////////////////////////////////////////
+ # Local_tile partition global tensors
+ # ///////////////////////////////////////////////////////////////////////////
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
+ if const_expr(varlen_m):
+ # (M, K)
+ mA_mk = mA_mkl
+ else:
+ assert varlen_k
+ # (tile_M, K)
+ mA_mk = cute.local_tile(
+ mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
+ )
+ # Partition global tensor for TiledMMA_A/B/D
+ len_m = varlen_manager.len_m(batch_idx)
+ len_k = varlen_manager.len_k(batch_idx)
+ # TMA load A partition_S/D
+ tiled_copy_A = self._make_gmem_tiled_copy_A(
+ mA_mkl.element_type, self.a_layout, (self.num_ab_load_warps - 1) * 32
+ )
+ tidx = cute.arch.thread_idx()[0] - (self.ab_load_warp_id + 1) * 32
+ thr_copy_A = tiled_copy_A.get_slice(tidx)
+ copy_A, prefetch_A = None, None
+ if const_expr(varlen_m):
+ a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
+ copy_A = copy_utils.gather_m_get_copy_fn(
+ thr_copy_A,
+ mA_mk,
+ sA,
+ sAIdx[None, a_prefetch_consumer_state.index],
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+ limit_k=len_k,
+ )
+ cute.arch.sync_warp()
+ with cute.arch.elect_one():
+ a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
+ a_prefetch_consumer_state.advance()
+ else:
+ copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
+ thr_copy_A,
+ mA_mk,
+ sA,
+ sAIdx,
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+ limit_k=len_k,
+ )
+ prefetch_A = partial(prefetch_A, a_prefetch_pipeline)
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
+ ab_producer_state, a_prefetch_consumer_state = self.load_A_gather_A(
+ ab_pipeline,
+ ab_producer_state,
+ a_prefetch_consumer_state,
+ copy_A,
+ prefetch_A,
+ k_tile_cnt,
+ )
+ # Advance to next tile
+ tile_scheduler.advance_to_next_work()
+ work_tile = tile_scheduler.get_current_work()
+
+ #
+ # Specialized scheduler warp. Will also prefetch A indices if gatherA
+ #
+ if const_expr(tile_sched_params.tile_count_semaphore is not None or self.gather_A):
+ if warp_idx == self.scheduler_warp_id:
+ is_scheduler_warp = True
+ if const_expr(cute.size(cluster_layout_vmnk) > 1):
+ is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0
+ tile_M = self.cta_tile_shape_mnk[0]
+ tile_K = self.cta_tile_shape_mnk[2]
+ thr_copy_AIdx, tAsAIdx, tAcAIdx = None, None, None
+ if const_expr(self.gather_A):
+ tiled_copy_AIdx = copy_utils.tiled_copy_1d(Int32, num_threads=32, is_async=True)
+ thr_copy_AIdx = tiled_copy_AIdx.get_slice(cute.arch.lane_idx())
+ tAsAIdx = thr_copy_AIdx.partition_D(sAIdx)
+ tAcAIdx = thr_copy_AIdx.partition_S(
+ cute.make_identity_tensor(tile_M if varlen_m else tile_K)
+ )
+ # Persistent tile scheduling loop
+ tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
+ work_tile = tile_scheduler.initial_work_tile_info()
+ a_prefetch_producer_state = None
+ if const_expr(self.gather_A):
+ a_prefetch_producer_state = pipeline.make_pipeline_state(
+ pipeline.PipelineUserType.Producer, self.a_prefetch_stage
+ )
+ while work_tile.is_valid_tile:
+ if const_expr(self.gather_A):
+ tile_coord_mnkl = work_tile.tile_idx
+ batch_idx = tile_coord_mnkl[3]
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
+ if const_expr(varlen_m):
+ # (tile_M,)
+ gAIdx = cute.local_tile(mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],))
+ tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
+ len_m = varlen_manager.len_m(batch_idx)
+ m_limit = len_m - tile_coord_mnkl[0] * tile_M
+ tApAIdx_m = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
+ for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
+ tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit
+ a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+ cute.copy(
+ thr_copy_AIdx,
+ tAgAIdx,
+ tAsAIdx[None, None, a_prefetch_producer_state.index],
+ pred=tApAIdx_m,
+ )
+ a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+ a_prefetch_producer_state.advance()
+ else:
+ # (tile_K, RestK)
+ gAIdx = cute.flat_divide(mAIdx_mk, (tile_K,))
+ tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
+ len_k = varlen_manager.len_k(batch_idx)
+ k_tile_cnt = cute.ceil_div(len_k, tile_K)
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
+ a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+ cute.copy(
+ thr_copy_AIdx,
+ tAgAIdx[None, None, k_tile],
+ tAsAIdx[None, None, a_prefetch_producer_state.index],
+ )
+ a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+ a_prefetch_producer_state.advance()
+ if 0 < k_tile_cnt:
+ k_tile = k_tile_cnt - 1
+ k_limit = len_k - k_tile * tile_K
+ tApAIdx_k = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
+ for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
+ tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit
+ a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+ cute.copy(
+ tiled_copy_AIdx,
+ tAgAIdx[None, None, k_tile],
+ tAsAIdx[None, None, a_prefetch_producer_state.index],
+ pred=tApAIdx_k,
+ )
+ a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+ a_prefetch_producer_state.advance()
+ # Advance to next tile
+ tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+ work_tile = tile_scheduler.get_current_work()
+ # End of persistent scheduler loop
+ if is_scheduler_warp:
+ tile_scheduler.producer_tail()
+
#
# Specialized TMA epi load warp
#
if const_expr(mC_mnl is not None):
- if warp_idx == self.
+ if warp_idx == self.epi_load_warp_id:
epi_producer_state = pipeline.make_pipeline_state(
- pipeline.PipelineUserType.Producer, self.
+ pipeline.PipelineUserType.Producer, self.epi_c_stage
)
- do_epi_load_barrier_wait =
+ do_epi_load_barrier_wait = Boolean(True)
# Persistent tile scheduling loop
tile_scheduler = TileSchedulerCls()
work_tile = tile_scheduler.initial_work_tile_info()
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
tile_coord_mnkl = work_tile.tile_idx
-
-
-
-
-
-
-
-
- mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
+ batch_idx = tile_coord_mnkl[3]
+ copy_C_fn, _, bGS_gC = self.epilog_gmem_copy_and_partition(
+ tma_atom_c,
+ varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
+ self.cta_tile_shape_mnk[:2],
+ epi_tile,
+ sC,
+ tile_coord_mnkl,
)
-
- # (MMA, MMA_M, MMA_N)
- tCgC = thr_mma.partition_C(gC_mnl)
- # bGS_gC has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
- bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
- tma_atom_c, tCgC, epi_tile, sC
- )
- bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
+ copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
if do_epi_load_barrier_wait:
epi_load_barrier.arrive_and_wait()
- do_epi_load_barrier_wait =
+ do_epi_load_barrier_wait = Boolean(False)
epi_tile_num = const_expr(cute.size(bGS_gC, mode=[1]))
- for
+ for epi_idx in cutlass.range(epi_tile_num, unroll=1):
epi_pipeline.producer_acquire(epi_producer_state)
-
- tma_atom_c,
- bGS_gC[None, subtile_idx],
- bGS_sC[None, epi_producer_state.index],
- tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
- )
+ copy_C(src_idx=epi_idx, producer_state=epi_producer_state)
# Epi pipeline's producer commit is a NOP
epi_pipeline.producer_commit(epi_producer_state)
epi_producer_state.advance()
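
The scheduler-warp branch above prefetches the gathered A indices into shared memory and, for the last (possibly partial) tile, masks the copy with a predicate of the form `tAcAIdx[...] < limit`. A list-based sketch of that predicated tail copy (names and shapes are illustrative only, not the cute API):

```python
# Only positions whose coordinate falls inside `limit` are copied; the rest of the
# smem buffer keeps whatever it held before.
def prefetch_indices(gmem_idx, smem_idx, tile_start, tile_size, limit):
    for i in range(tile_size):
        if tile_start + i < limit:          # pred = coord < limit
            smem_idx[i] = gmem_idx[tile_start + i]
    return smem_idx

rows = list(range(100, 110))                  # 10 gathered row indices, len_m = 10
smem = [-1] * 8                               # tile_M = 8, so the second tile is partial
print(prefetch_indices(rows, smem, 8, 8, 10)) # [108, 109, -1, -1, -1, -1, -1, -1]
```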
@@ -1132,7 +1292,7 @@ class PersistentDenseGemmKernel:
)
# Partition shared/tensor memory tensor for TiledMMA_A/B/D
# (MMA, MMA_M, MMA_K, STAGE)
- tCrA = tiled_mma.make_fragment_A(
+ tCrA = tiled_mma.make_fragment_A(sA_mma)
# (MMA, MMA_N, MMA_K, STAGE)
tCrB = tiled_mma.make_fragment_B(sB)
# (MMA, MMA_M, MMA_N, STAGE)
@@ -1149,10 +1309,9 @@ class PersistentDenseGemmKernel:
tiled_mma,
self.mma_tiler,
self.sf_vec_size,
- cute.slice_(
+ cute.slice_(sfa_smem_layout, (None, None, None, 0)),
)
tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
-
# Make SFB tmem tensor
sfb_tmem_ptr = cute.recast_ptr(
acc_tmem_ptr
|
tiled_mma,
self.mma_tiler,
self.sf_vec_size,
- cute.slice_(
+ cute.slice_(sfb_smem_layout, (None, None, None, 0)),
)
tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
# Partition for S2T copy of SFA/SFB
@@ -1180,6 +1339,7 @@ class PersistentDenseGemmKernel:
tCtSFB_compact_s2t,
) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
else:
+ tCtSFA, tCtSFB = None, None
tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None
tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None

@@ -1187,7 +1347,7 @@ class PersistentDenseGemmKernel:
tile_scheduler = TileSchedulerCls()
work_tile = tile_scheduler.initial_work_tile_info()
ab_consumer_state = pipeline.make_pipeline_state(
- pipeline.PipelineUserType.Consumer, self.
+ pipeline.PipelineUserType.Consumer, self.ab_stage
)
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage
@@ -1195,6 +1355,9 @@ class PersistentDenseGemmKernel:
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
tile_coord_mnkl = work_tile.tile_idx
+ batch_idx = tile_coord_mnkl[3]
+ k_len = varlen_manager.len_k(batch_idx)
+ k_tile_cnt = cute.ceil_div(k_len, self.mma_tiler[2])
# Set tensor memory buffer for current tile
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[None, None, None, acc_producer_state.index]
@@ -1209,6 +1372,9 @@ class PersistentDenseGemmKernel:
tCtAcc,
k_tile_cnt,
is_leader_cta,
+ cta_rank_in_cluster,
+ tCtSFA,
+ tCtSFB,
tiled_copy_s2t_sfa,
tiled_copy_s2t_sfb,
tCsSFA_compact_s2t,
@@ -1234,6 +1400,14 @@ class PersistentDenseGemmKernel:
)
# Bar sync for retrieve tensor memory ptr from shared memory
tmem_alloc_barrier.arrive_and_wait()
+
+ is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
+ varlen_manager.init_tensormap_epi(
+ tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
+ )
+ tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
+ tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
+
# Retrieving tensor memory ptr and make accumulator tensor
acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
@@ -1241,9 +1415,9 @@ class PersistentDenseGemmKernel:
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)

- epilog_threads = cute.arch.WARP_SIZE * len(self.epilog_warp_id)
epilogue_barrier = pipeline.NamedBarrier(
- barrier_id=
+ barrier_id=int(NamedBarrierGemm.Epilogue),
+ num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
)

# Partition for epilogue
@@ -1252,19 +1426,16 @@ class PersistentDenseGemmKernel:
epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
)

- tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.
- tiled_copy_r2s, tRS_rD, tRS_sD = self.
- tiled_copy_t2r, tTR_rD,
+ tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.acc_dtype)
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
+ tiled_copy_t2r, self.d_layout, self.d_dtype, tTR_rD, sD, epi_tidx
)
+ tRS_rC, tSR_rC, tSR_sC = None, None, None
+ tiled_copy_s2r = None
if const_expr(mC_mnl is not None):
-
-
- tiled_copy_t2r, tTR_rC, epi_tidx, sC
+ tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
+ tiled_copy_t2r, self.c_layout, self.c_dtype, sC, tRS_rD.layout, epi_tidx
)
- # TODO: for m major, D is being stored w STSM so we'd need LDSM here
- # tRS_rC = tSR_rC # TODO: retile?
- tRS_rC = cute.make_fragment(tRS_rD.layout, self.c_dtype)
- tSR_rC = tiled_copy_s2r.get_slice(epi_tidx).retile(tRS_rC)

# Persistent tile scheduling loop
tile_scheduler = TileSchedulerCls()
@@ -1272,37 +1443,27 @@ class PersistentDenseGemmKernel:
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
-
- d_producer_group = pipeline.CooperativeGroup(
- pipeline.Agent.Thread,
- 32 * len(self.epilog_warp_id),
- 32 * len(self.epilog_warp_id),
- )
- d_pipeline = pipeline.PipelineTmaStore.create(
- num_stages=self.num_d_stage, producer_group=d_producer_group
- )
+ epi_store_pipeline = self.make_epi_store_pipeline()
epi_read_state = pipeline.make_pipeline_state(
- pipeline.PipelineUserType.Consumer, self.
+ pipeline.PipelineUserType.Consumer, self.epi_c_stage
)
-
+ if const_expr(varlen_m):
+ # wait tensormap initialization complete before update
+ varlen_manager.fence_tensormap_init()
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
tile_coord_mnkl = work_tile.tile_idx
-
-
-
- tile_coord_mnkl[3],
+ batch_idx = tile_coord_mnkl[3]
+ epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
+ epilogue_params, varlen_params.cu_seqlens_m, batch_idx
)
-
-
-
-
+ varlen_manager.update_tensormap_epi(
+ batch_idx,
+ self.d_layout,
+ epi_shapes,
+ epi_orders,
+ is_tma_warp,
)
- # Partition global tensor for TiledMMA_A/B/D
- # (MMA, MMA_M, MMA_N)
- tDgD = thr_mma.partition_C(gD_mnl)
- # bSG_gD has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
- bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(tma_atom_d, tDgD, epi_tile, sD)

# Set tensor memory buffer for current tile
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
@@ -1311,49 +1472,59 @@ class PersistentDenseGemmKernel:
# Wait for accumulator buffer full
acc_pipeline.consumer_wait(acc_consumer_state)

-
-
-
-
-
-
-
-
-
-
-
-
- acc_vec = epilogue_op(acc_vec)
- if const_expr(mC_mnl is not None):
- epi_pipeline.consumer_wait(epi_read_state)
- cute.copy(
- tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
- )
- # Fence to make sure shared memory read is visible to TMA load
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
- )
- cute.arch.sync_warp()
- with cute.arch.elect_one():
- epi_pipeline.consumer_release(epi_read_state)
- epi_read_state.advance()
- acc_vec = acc_vec + tRS_rC.load().to(self.acc_dtype)
- tRS_rD.store(acc_vec.to(self.d_dtype))
- # Store D to shared memory
- d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage
- cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
- # Fence and barrier to make sure shared memory store is visible to TMA store
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+ varlen_manager.fence_tensormap_update_epi(is_tma_warp)
+
+ copy_D = None
+ if const_expr(has_D):
+ copy_D, _, _ = self.epilog_gmem_copy_and_partition(
+ tma_atom_d,
+ varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
+ self.cta_tile_shape_mnk[:2],
+ epi_tile,
+ sD,
+ tile_coord_mnkl,
+ tma_desc_ptr=tma_desc_d_ptr,
)
-
-
-
-
-
-
-
-
+ copy_C = None # We're using a separate warp to load C
+
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
+ k_len = varlen_manager.len_k(batch_idx)
+ load_acc_subtile = partial(
+ self.epi_load_acc_subtile,
+ tiled_copy_t2r,
+ tiled_copy_r2s,
+ tTR_tAcc,
+ tTR_rAcc,
+ clear_acc=varlen_k and k_len == 0,
+ )
+
+ epi_read_state, _ = self.epilogue(
+ epilogue_params,
+ epi_smem_tensors,
+ tma_desc_epi_ptrs,
+ epi_pipeline,
+ epi_store_pipeline,
+ epi_read_state,
+ None, # epi_producer_state
+ epi_tile,
+ load_acc_subtile,
+ tRS_rD,
+ tRS_rC,
+ tiled_copy_t2r,
+ tiled_copy_r2s,
+ tRS_sD,
+ tiled_copy_s2r,
+ tSR_rC,
+ tSR_sC,
+ copy_D,
+ copy_C,
+ tile_coord_mnkl,
+ varlen_manager,
+ epilogue_barrier,
+ tile_scheduler,
+ epi_tidx,
+ is_tma_warp,
+ )

# Async arrive accumulator buffer empty
with cute.arch.elect_one():
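
The epilogue call above pre-binds the accumulator-subtile loader with `functools.partial` (`load_acc_subtile = partial(self.epi_load_acc_subtile, ...)`), so the shared epilogue only has to supply the per-subtile arguments. A toy stand-in with lists; the signature is only loosely modeled on the real method:

```python
from functools import partial

def load_acc_subtile(copy_t2r, copy_r2s, acc_src, acc_frag, dst, epi_idx, clear_acc=False):
    # either clear the destination fragment or pull in the requested accumulator subtile
    dst[:] = [0.0] * len(dst) if clear_acc else acc_src[epi_idx]

acc_src = [[1.0, 2.0], [3.0, 4.0]]
dst = [0.0, 0.0]
loader = partial(load_acc_subtile, None, None, acc_src, None, clear_acc=False)
loader(dst, 1)   # only the per-subtile arguments remain at the call site
print(dst)       # [3.0, 4.0]
```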
@@ -1369,7 +1540,7 @@ class PersistentDenseGemmKernel:
cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
epilogue_barrier.arrive_and_wait()
if warp_idx == self.epilog_warp_id[0]:
- if use_2cta_instrs:
+ if const_expr(use_2cta_instrs):
cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1)
cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
cute.arch.dealloc_tmem(
@@ -1377,82 +1548,54 @@ class PersistentDenseGemmKernel:
)

# Wait for D store complete
-
+ if is_tma_warp:
+ epi_store_pipeline.producer_tail()

@cute.jit
- def
+ def load_A_gather_A(
self,
-
-
-
-
-
-
-
- tBgB: cute.Tensor,
- tBsB: cute.Tensor,
- b_mcast_mask: cutlass.Int16,
- tma_atom_sfa: Optional[cute.CopyAtom] = None,
- tAgSFA: Optional[cute.Tensor] = None,
- tAsSFA: Optional[cute.Tensor] = None,
- sfa_mcast_mask: Optional[cutlass.Int16] = None,
- tma_atom_sfb: Optional[cute.CopyAtom] = None,
- tBgSFB: Optional[cute.Tensor] = None,
- tBsSFB: Optional[cute.Tensor] = None,
- sfb_mcast_mask: Optional[cutlass.Int16] = None,
- ) -> cutlass.pipeline.PipelineState:
- blockscaled = const_expr(tma_atom_sfa is not None)
- if const_expr(blockscaled):
- assert all(x is not None for x in (tma_atom_sfa, tAgSFA, tAsSFA))
- assert all(x is not None for x in (tma_atom_sfb, tBgSFB, tBsSFB))
- k_tile_cnt = cute.size(tAgA, mode=[1])
+ a_pipeline: cutlass.pipeline.PipelineAsync,
+ a_producer_state: cutlass.pipeline.PipelineState,
+ a_prefetch_consumer_state: Optional[cutlass.pipeline.PipelineState],
+ copy_A: Callable,
+ prefetch_A: Optional[Callable],
+ k_tile_cnt: Int32,
+ ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]:
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
-
+ peek_a_empty_status = Boolean(True)
if 0 < k_tile_cnt:
-
+ peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
# /////////////////////////////////////////////////////////////////////////
- #
+ # cp.async on A
# /////////////////////////////////////////////////////////////////////////
-
-
-
-
-
-
-
-
-
-
- )
-
-
- tBgB[None, k_tile],
- tBsB[None, ab_producer_state.index],
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
- mcast_mask=b_mcast_mask,
- )
- if const_expr(blockscaled):
- cute.copy(
- tma_atom_sfa,
- tAgSFA[None, ab_producer_state.count],
- tAsSFA[None, ab_producer_state.index],
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
- mcast_mask=sfa_mcast_mask,
- )
- cute.copy(
- tma_atom_sfb,
- tBgSFB[None, ab_producer_state.count],
- tBsSFB[None, ab_producer_state.index],
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
- mcast_mask=sfb_mcast_mask,
- )
- # Mainloop pipeline's producer commit is a NOP
- ab_pipeline.producer_commit(ab_producer_state)
- ab_producer_state.advance()
- peek_ab_empty_status = cutlass.Boolean(True)
+ is_tma_warp = False
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
+ smem_idx = a_producer_state.index
+ prefetch_out = ()
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+ prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),)
+ a_prefetch_consumer_state.advance()
+ a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp)
+ copy_A(k_tile, smem_idx, *prefetch_out)
+ # This tells mbarrier to track the completion of cp.async
+ a_pipeline.producer_cpasync_commit(a_producer_state)
+ a_producer_state.advance()
+ peek_a_empty_status = Boolean(True)
if k_tile + 1 < k_tile_cnt:
-
-
+ peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
+ # bound checking in the K dimension on the last k_tile
+ if 0 < k_tile_cnt:
+ k_tile = k_tile_cnt - 1
+ smem_idx = a_producer_state.index
+ prefetch_out = ()
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+ prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True),)
+ a_prefetch_consumer_state.advance()
+ a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp)
+ copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
+ a_pipeline.producer_cpasync_commit(a_producer_state)
+ a_producer_state.advance()
+ return a_producer_state, a_prefetch_consumer_state

@cute.jit
def mma(
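
`load_A_gather_A` above issues all but the last k-tile without predication and then handles the final, possibly partial, k-tile with `pred=True`. The same "steady state plus predicated tail" loop shape in isolation, with stand-in callables for the pipeline and copy calls:

```python
def load_tiles(k_tile_cnt: int, num_stages: int, copy, copy_predicated):
    stage = 0
    for k_tile in range(k_tile_cnt - 1):      # full tiles: no bound check needed
        copy(k_tile, stage)
        stage = (stage + 1) % num_stages
    if 0 < k_tile_cnt:                        # last tile may be partial in K
        copy_predicated(k_tile_cnt - 1, stage)

issued = []
load_tiles(4, 2,
           lambda k, s: issued.append(("full", k, s)),
           lambda k, s: issued.append(("pred", k, s)))
print(issued)   # [('full', 0, 0), ('full', 1, 1), ('full', 2, 0), ('pred', 3, 1)]
```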
@@ -1466,7 +1609,10 @@ class PersistentDenseGemmKernel:
tCrB: cute.Tensor,
acc: cute.Tensor,
k_tile_cnt: Int32,
- is_leader_cta:
+ is_leader_cta: Boolean,
+ cta_rank_in_cluster: Int32,
+ tCtSFA: Optional[cute.Tensor] = None,
+ tCtSFB: Optional[cute.Tensor] = None,
tiled_copy_s2t_sfa: Optional[cute.TiledCopy] = None,
tiled_copy_s2t_sfb: Optional[cute.TiledCopy] = None,
tCsSFA_compact_s2t: Optional[cute.Tensor] = None,
@@ -1476,12 +1622,17 @@ class PersistentDenseGemmKernel:
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma]:
blockscaled = const_expr(tiled_copy_s2t_sfa is not None)
if const_expr(blockscaled):
+ assert all(x is not None for x in (tCtSFA, tCtSFB))
assert all(x is not None for x in (tiled_copy_s2t_sfa, tiled_copy_s2t_sfb))
assert all(x is not None for x in (tCsSFA_compact_s2t, tCsSFB_compact_s2t))
assert all(x is not None for x in (tCtSFA_compact_s2t, tCtSFB_compact_s2t))
+ # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will
+ # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader
+ # CTA will wait for that then arrive at the mbarrier on the leader CTA.
+ need_nonleader_cta = const_expr(self.gather_A and self.use_2cta_instrs)
# Peek (try_wait) AB buffer full for k_tile = 0
- peek_ab_full_status =
- if 0 < k_tile_cnt and is_leader_cta:
+ peek_ab_full_status = Boolean(True)
+ if 0 < k_tile_cnt and (is_leader_cta or need_nonleader_cta):
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
# Wait for accumulator buffer empty
if is_leader_cta:
@@ -1491,6 +1642,14 @@ class PersistentDenseGemmKernel:
# Mma mainloop
num_k_blocks = cute.size(tCrA, mode=[2])
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
+ if const_expr(need_nonleader_cta):
+ if not is_leader_cta:
+ ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
+ with cute.arch.elect_one():
+ # The odd CTA signals the even CTA
+ ab_pipeline.sync_object_full.arrive_mbarrier(
+ ab_consumer_state.index, dst_rank=cta_rank_in_cluster & 0xFE
+ )
if is_leader_cta:
# Conditionally wait for AB buffer full
ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
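
The `dst_rank=cta_rank_in_cluster & 0xFE` in the hunk above clears the low bit of the CTA rank, so an odd-ranked CTA signals the even-ranked peer it forms a 2-CTA MMA pair with:

```python
for cta_rank_in_cluster in range(6):
    print(cta_rank_in_cluster, "->", cta_rank_in_cluster & 0xFE)
# 0 -> 0, 1 -> 0, 2 -> 2, 3 -> 2, 4 -> 4, 5 -> 4
```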
@@ -1503,14 +1662,19 @@ class PersistentDenseGemmKernel:
cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t)
for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
k_blk_coord = (None, None, k_blk_idx, ab_consumer_state.index)
+ if const_expr(blockscaled):
+ # Set SFA/SFB tensor to tiled_mma
+ sf_kblock_coord = (None, None, k_blk_idx)
+ tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
+ tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator)
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
ab_pipeline.consumer_release(ab_consumer_state)
ab_consumer_state.advance()
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
- peek_ab_full_status =
- if k_tile + 1 < k_tile_cnt and is_leader_cta:
+ peek_ab_full_status = Boolean(True)
+ if k_tile + 1 < k_tile_cnt and (is_leader_cta or need_nonleader_cta):
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
# Async arrive accumulator buffer full
if is_leader_cta:
@@ -1520,6 +1684,25 @@ class PersistentDenseGemmKernel:
# "operand #0 does not dominate this use"
return ab_consumer_state, acc_producer_state, tiled_mma

+ @cute.jit
+ def epi_load_acc_subtile(
+ self,
+ tiled_copy_t2r: cute.TiledCopy,
+ tiled_copy_r2s: cute.TiledCopy,
+ tTR_tAcc: cute.Tensor,
+ tTR_rAcc: cute.Tensor,
+ tRS_rD: cute.Tensor,
+ epi_idx: int,
+ clear_acc: Boolean = False,
+ ):
+ if not clear_acc:
+ # Load accumulator from tensor memory buffer to register
+ cute.copy(tiled_copy_t2r, tTR_tAcc[None, None, None, epi_idx], tTR_rAcc)
+ tRS_rAcc = tiled_copy_r2s.retile(tTR_rAcc)
+ tRS_rD.store(tRS_rAcc.load())
+ else:
+ tRS_rD.fill(0.0)
+
def mainloop_s2t_copy_and_partition(
self,
sSF: cute.Tensor,
@@ -1560,7 +1743,7 @@ class PersistentDenseGemmKernel:
tidx: Int32,
tAcc: cute.Tensor,
epi_tile: cute.Tile,
- use_2cta_instrs: Union[
+ use_2cta_instrs: Union[Boolean, bool],
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
@@ -1583,8 +1766,8 @@ class PersistentDenseGemmKernel:
# Make tiledCopy for tensor memory load
copy_atom_t2r = sm100_utils.get_tmem_load_op(
self.cta_tile_shape_mnk,
- self.d_layout,
- self.d_dtype,
+ self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR,
+ self.d_dtype if self.d_dtype is not None else cutlass.BFloat16,
self.acc_dtype,
epi_tile,
use_2cta_instrs,
@@ -1607,12 +1790,14 @@ class PersistentDenseGemmKernel:
tTR_rAcc = cute.make_fragment(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
return tiled_copy_t2r, tTR_tAcc, tTR_rAcc

- def
+ def epilog_smem_store_and_partition(
self,
tiled_copy_t2r: cute.TiledCopy,
+ d_layout: Optional[LayoutEnum],
+ dtype: Optional[Type[cutlass.Numeric]],
tTR_rD: cute.Tensor,
- tidx: Int32,
sD: cute.Tensor,
+ tidx: Int32,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
@@ -1634,93 +1819,183 @@ class PersistentDenseGemmKernel:
:rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
"""
copy_atom_r2s = sm100_utils.get_smem_store_op(
-
+ d_layout if d_layout is not None else LayoutEnum.ROW_MAJOR,
+ dtype if dtype is not None else cutlass.BFloat16,
+ self.acc_dtype,
+ tiled_copy_t2r,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
- tRS_sD = thr_copy_r2s.partition_D(sD)
+ tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
# (R2S, R2S_M, R2S_N)
tRS_rD = tiled_copy_r2s.retile(tTR_rD)
return tiled_copy_r2s, tRS_rD, tRS_sD

-
- # self,
- # tiled_copy_t2r: cute.TiledCopy,
- # tTR_rC: cute.Tensor,
- # tidx: Int32,
- # sC: cute.Tensor,
- # ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
- # copy_atom_s2r = cute.make_copy_atom(
- # warp.LdMatrix8x8x16bOp(self.c_layout.is_m_major_c(), num_matrices=4),
- # self.c_dtype, # TODO: this probably only works for f16 for now?
- # )
- # # copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
- # tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
- # # (R2S, R2S_M, R2S_N, PIPE_D)
- # thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
- # # (R2S, R2S_M, R2S_N)
- # tSR_sC = thr_copy_s2r.partition_S(sC)
- # return tiled_copy_s2r, tSR_sC
-
- def epilog_gmem_copy_and_partition(
+ def epilog_smem_load_and_partition(
self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ tiled_copy_t2r: cute.TiledCopy,
+ c_layout: LayoutEnum,
+ dtype: Type[cutlass.Numeric],
+ # tTR_rC: cute.Tensor,
+ sC: cute.Tensor,
+ tRS_rD_layout: cutlass.Layout,
+ tidx: Int32,
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+ copy_atom_r2s = sm100_utils.get_smem_store_op(
+ c_layout, dtype, self.acc_dtype, tiled_copy_t2r
+ )
+ store_op = copy_atom_r2s.op
+ # m8n8 16-bit path
+ if isinstance(store_op, StMatrix8x8x16bOp):
+ op = LdMatrix8x8x16bOp(num_matrices=store_op.num_matrices, transpose=store_op.transpose)
+ # m16n8 8-bit store -> m16n16 8-bit load
+ elif isinstance(store_op, StMatrix16x8x8bOp) and store_op.num_matrices in [2, 4]:
+ # transpose=True is enforced by the class
+ op = LdMatrix16x16x8bOp(num_matrices=store_op.num_matrices // 2)
+ else:
+ op = cute.nvgpu.CopyUniversalOp()
+ copy_atom_s2r = cute.make_copy_atom(op, dtype)
+ tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
+ # (R2S, R2S_M, R2S_N, PIPE_D)
+ tSR_sC = thr_copy_s2r.partition_S(sC)
+ tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
+ # (R2S, R2S_M, R2S_N)
+ tSR_rC = tiled_copy_s2r.retile(tRS_rC)
+ return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC

-
-
-
-
-
-
-
-
-
-
-
- #
-
-
-
-
-
-
-
-
-
-
+ @cute.jit
+ def make_ab_pipeline(
+ self,
+ tiled_mma: cute.TiledMma,
+ cluster_layout_vmnk: cute.Layout,
+ ab_pipeline_mbar_ptr: cute.Pointer,
+ is_leader_cta: Boolean,
+ ) -> pipeline.PipelineAsync:
+ # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will
+ # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader
+ # CTA will wait for that then arrive at the mbarrier on the leader CTA.
+ # The producer count for the leader CTA is 1 (TMA) + num_cpasync_threads
+ # + 1 (from non-leader CTA).
+ # The producer count for the non-leader CTA is num_cpasync_threads
+ # (TMA doesn't arrive there).
+ if const_expr(not self.gather_A):
+ producer_cnt = 1
+ else:
+ producer_cnt = (self.num_ab_load_warps - 1) * 32 + (
+ 1 if const_expr(not self.use_2cta_instrs) else 2
+ )
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
+ # Each warp will contribute to the arrive count with the number of mcast size
+ mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
+ consumer_arrive_cnt = mcast_size
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, consumer_arrive_cnt
)
-
+ if const_expr(not self.gather_A):
+ pipeline_ab = pipeline.PipelineTmaUmma.create(
+ barrier_storage=ab_pipeline_mbar_ptr,
+ num_stages=self.ab_stage,
+ producer_group=ab_pipeline_producer_group,
+ consumer_group=ab_pipeline_consumer_group,
+ tx_count=self.num_tma_load_bytes,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ )
+ else:
+ pipeline_ab = PipelineTmaCpAsyncUmma.create(
+ barrier_storage=ab_pipeline_mbar_ptr,
+ num_stages=self.ab_stage,
+ producer_group=ab_pipeline_producer_group,
+ consumer_group=ab_pipeline_consumer_group,
+ tx_count=self.num_tma_load_bytes,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ producer_drop_count=None
+ if not self.use_2cta_instrs
+ else (2 if not is_leader_cta else 0),
+ )
+ return pipeline_ab

-
+ def make_acc_pipeline(
+ self, cluster_layout_vmnk: cute.Layout, acc_pipeline_mbar_ptr: cute.Pointer
+ ) -> pipeline.PipelineAsync:
+ acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_acc_consumer_threads = self.num_epi_warps * (2 if self.use_2cta_instrs else 1)
+ acc_pipeline_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, num_acc_consumer_threads
+ )
+ return pipeline.PipelineUmmaAsync.create(
+ barrier_storage=acc_pipeline_mbar_ptr,
+ num_stages=self.num_acc_stage,
+ producer_group=acc_pipeline_producer_group,
+ consumer_group=acc_pipeline_consumer_group,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ )
+
+ def make_sched_pipeline(
+ self,
+ cluster_layout_mnk: cute.Layout,
+ sched_pipeline_mbar_ptr: cute.Pointer,
+ has_C: bool = False,
+ ) -> pipeline.PipelineAsync:
+ # Threads/warps participating in this pipeline
+ sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ cluster_size = cute.size(cluster_layout_mnk)
+ # Each warp that are not the scheduler warp will contribute 1 to the arrive count
+ warps_per_cta = self.num_ab_load_warps + len(
+ (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id)
+ )
+ if has_C:
+ warps_per_cta += 1
+ consumer_arrive_cnt = warps_per_cta * cluster_size - 1
+ sched_pipeline_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, consumer_arrive_cnt
+ )
+ return pipeline.PipelineAsync.create(
+ barrier_storage=sched_pipeline_mbar_ptr,
+ num_stages=self.sched_stage,
+ producer_group=sched_pipeline_producer_group,
+ consumer_group=sched_pipeline_consumer_group,
+ # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
+ consumer_mask=None if const_expr(cluster_size == 1) else 0,
+ )
+
+ @cute.jit
+ def make_a_prefetch_pipeline(
+ self, a_prefetch_pipeline_mbar_ptr: cute.Pointer
+ ) -> pipeline.PipelineAsync:
+ producer_cnt = 32
+ a_prefetch_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, producer_cnt, alignment=producer_cnt
+ )
+ consumer_arrive_cnt = self.num_ab_load_warps - 1
+ a_prefetch_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, consumer_arrive_cnt
+ )
+ return pipeline.PipelineCpAsync.create(
+ barrier_storage=a_prefetch_pipeline_mbar_ptr,
+ num_stages=self.a_prefetch_stage,
+ producer_group=a_prefetch_producer_group,
+ consumer_group=a_prefetch_consumer_group,
+ )
+
+ @classmethod
def _compute_stages(
+ cls,
tiled_mma: cute.TiledMma,
mma_tiler_mnk: Tuple[int, int, int],
+ cta_tile_shape_mnk: Tuple[int, int, int],
+ epi_tile: cute.Tile,
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
- epi_tile: cute.Tile,
- d_dtype: Type[cutlass.Numeric],
- c_dtype: Optional[Type[cutlass.Numeric]],
- d_layout: cutlass.utils.LayoutEnum,
- c_layout: Optional[cutlass.utils.LayoutEnum],
sf_dtype: Optional[Type[cutlass.Numeric]],
sf_vec_size: Optional[int],
+ d_dtype: Optional[Type[cutlass.Numeric]],
+ c_dtype: Optional[Type[cutlass.Numeric]],
+ d_layout: Optional[LayoutEnum],
+ c_layout: Optional[LayoutEnum],
+ epilogue_args: EpilogueArguments,
+ prefetch_A_idx: Literal[None, "varlen_m", "varlen_k"],
smem_capacity: int,
occupancy: int,
) -> Tuple[int, int, int]:
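
The producer count chosen in `make_ab_pipeline` above, restated as standalone arithmetic (plain ints and bools here; the `- 1` mirrors the `(self.num_ab_load_warps - 1) * 32` thread count used for the gather copy of A earlier in this diff):

```python
def ab_producer_count(gather_A: bool, num_ab_load_warps: int, use_2cta_instrs: bool) -> int:
    if not gather_A:
        return 1                                    # TMA is the only producer
    cpasync_threads = (num_ab_load_warps - 1) * 32  # warps issuing cp.async for A
    return cpasync_threads + (1 if not use_2cta_instrs else 2)

print(ab_producer_count(False, 4, True))   # 1
print(ab_producer_count(True, 4, False))   # 97
print(ab_producer_count(True, 4, True))    # 98
```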
@@ -1738,8 +2013,8 @@ class PersistentDenseGemmKernel:
:type epi_tile: cute.Tile
:param d_dtype: Data type of operand C (output).
:type d_dtype: type[cutlass.Numeric]
- :param d_layout: Layout enum of operand
- :type d_layout:
+ :param d_layout: Layout enum of operand D.
+ :type d_layout: LayoutEnum
:param smem_capacity: Total available shared memory capacity in bytes.
:type smem_capacity: int
:param occupancy: Target number of CTAs per SM (occupancy).
@@ -1757,8 +2032,8 @@ class PersistentDenseGemmKernel:
num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2

# Default D stages
-
-
+ epi_stage = 4 if cute.size(epi_tile[1]) <= 16 else 2
+ epi_c_stage = 0 if c_dtype is None else (4 if cute.size(epi_tile[1]) <= 16 else 2)

# Calculate smem layout and size for one stage of A, B, and C
a_smem_layout_staged_one = sm100_utils.make_smem_layout_a(
@@ -1773,7 +2048,11 @@ class PersistentDenseGemmKernel:
b_dtype,
1, # a tmp 1 stage is provided
)
- d_smem_layout_staged_one =
+ d_smem_layout_staged_one = (
+ sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
+ if d_dtype is not None
+ else None
+ )
c_smem_layout_staged_one = (
sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1)
if c_dtype is not None
@@ -1796,34 +2075,38 @@ class PersistentDenseGemmKernel:
ab_bytes_per_stage = cute.size_in_bytes(
a_dtype, a_smem_layout_staged_one
) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
+ if const_expr(prefetch_A_idx == "varlen_k"): # Need smem to prefetch A indices
+ ab_bytes_per_stage += Int32.width // 8 * cta_tile_shape_mnk[2]
if const_expr(blockscaled):
ab_bytes_per_stage += cute.size_in_bytes(
sf_dtype, sfa_smem_layout_staged_one
) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
mbar_helpers_bytes = 1024
-
-
+ if const_expr(prefetch_A_idx == "varlen_m"):
+ mbar_helpers_bytes += Int32.width // 8 * cta_tile_shape_mnk[0] * 2
+ d_bytes_per_stage = (
+ cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) if d_dtype is not None else 0
+ )
+ epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
+ epilogue_args, cta_tile_shape_mnk, epi_tile
+ )
+ epi_bytes = epi_bytes_per_stage * epi_stage
if const_expr(c_dtype is not None):
c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
- epi_bytes += c_bytes_per_stage *
+ epi_bytes += c_bytes_per_stage * epi_c_stage

# Calculate A/B/SFA/SFB stages:
# Start with total smem per CTA (capacity / occupancy)
# Subtract reserved bytes and initial C stages bytes
# Divide remaining by bytes needed per A/B/SFA/SFB stage
-
-
- ) // ab_bytes_per_stage
+ remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
+ ab_stage = remaining_bytes // ab_bytes_per_stage

# Refine epilogue stages:
# Calculate remaining smem after allocating for A/B stages and reserved bytes
# Add remaining unused smem to epilogue
-
-
- - occupancy * ab_bytes_per_stage * num_ab_stage
- - occupancy * (mbar_helpers_bytes + epi_bytes)
- ) // (occupancy * d_bytes_per_stage)
- return num_acc_stage, num_ab_stage, num_d_stage, num_c_stage
+ epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // (epi_bytes_per_stage)
+ return num_acc_stage, ab_stage, epi_stage, epi_c_stage

@staticmethod
def _compute_num_tmem_alloc_cols(
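
The reworked `_compute_stages` above budgets shared memory in two passes: fit as many A/B stages as the per-CTA budget allows, then hand the leftover back to extra epilogue stages. The same arithmetic as a standalone sketch; the byte counts below are made-up example values, not the kernel's real sizes:

```python
def compute_stages(smem_capacity, occupancy, mbar_helpers_bytes,
                   ab_bytes_per_stage, epi_bytes_per_stage, epi_stage):
    epi_bytes = epi_bytes_per_stage * epi_stage
    remaining = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
    ab_stage = remaining // ab_bytes_per_stage                      # as many A/B stages as fit
    epi_stage += (remaining - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
    return ab_stage, epi_stage

print(compute_stages(232448, 1, 1024, 32768, 16384, 2))   # (6, 2) with these example sizes
```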
@@ -1851,9 +2134,12 @@ class PersistentDenseGemmKernel:

@staticmethod
def is_valid_dtypes(
-
+ a_dtype: Type[cutlass.Numeric],
+ b_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
- d_dtype: Type[cutlass.Numeric],
+ d_dtype: Optional[Type[cutlass.Numeric]],
+ a_major: str,
+ b_major: str,
) -> bool:
"""
Check if the dtypes are valid
@@ -1869,6 +2155,9 @@ class PersistentDenseGemmKernel:
:rtype: bool
"""
is_valid = True
+ if b_dtype != a_dtype:
+ is_valid = False
+ ab_dtype = a_dtype
if ab_dtype not in {
cutlass.Float16,
cutlass.BFloat16,
@@ -1880,18 +2169,18 @@ class PersistentDenseGemmKernel:
}:
is_valid = False
if (
- acc_dtype not in {
+ acc_dtype not in {Float32, cutlass.Float16, Int32}
or acc_dtype == cutlass.Float16
and ab_dtype not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}
or acc_dtype == Int32
and ab_dtype not in {cutlass.Uint8, cutlass.Int8}
):
is_valid = False
- if (
- acc_dtype ==
+ if d_dtype is not None and (
+ acc_dtype == Float32
and d_dtype
not in {
-
+ Float32,
cutlass.Float16,
cutlass.BFloat16,
cutlass.Float8E4M3FN,
|
not in {
cutlass.BFloat16,
cutlass.Float16,
-
+ Float32,
Int32,
cutlass.Int8,
cutlass.Uint8,
}
):
is_valid = False
+ if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
+ is_valid = False
return is_valid

@staticmethod
@@ -1964,7 +2255,7 @@ class PersistentDenseGemmKernel:

# Check valid d_dtype
if d_dtype not in {
-
+ Float32,
cutlass.Float16,
cutlass.BFloat16,
cutlass.Float8E5M2,
@@ -1974,37 +2265,8 @@ class PersistentDenseGemmKernel:

return is_valid

- @staticmethod
- def is_valid_layouts(
- ab_dtype: Type[cutlass.Numeric],
- a_major: str,
- b_major: str,
- ) -> bool:
- """
- Check if the dtypes and sf_vec_size are valid combinations
-
- :param ab_dtype: The data type of the A and B operands
- :type ab_dtype: Type[cutlass.Numeric]
- :param d_dtype: The data type of the output tensor
- :type d_dtype: Type[cutlass.Numeric]
- :param a_major: The major dimension of the A tensor
- :type a_major: str
- :param b_major: The major dimension of the B tensor
- :type b_major: str
- :param d_major: The major dimension of the C tensor
- :type d_major: str
-
- :return: True if the layouts are valid, False otherwise
- :rtype: bool
- """
- is_valid = True
- if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
- is_valid = False
- return is_valid
-
@staticmethod
def is_valid_mma_tiler_and_cluster_shape(
- use_2cta_instrs: bool,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
blockscaled: bool,
@@ -2012,8 +2274,6 @@ class PersistentDenseGemmKernel:
"""
Check if the mma tiler and cluster shape are valid

- :param use_2cta_instrs: Whether to use 2 CTA groups
- :type use_2cta_instrs: bool
:param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
:type mma_tiler_mn: Tuple[int, int]
:param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
@@ -2024,10 +2284,7 @@ class PersistentDenseGemmKernel:
"""
is_valid = True
# Skip invalid mma tile shape
- if not
- (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
- or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
- ):
+ if mma_tiler_mn[0] not in [64, 128, 256]:
is_valid = False
if not blockscaled:
if mma_tiler_mn[1] not in range(32, 257, 32):
@@ -2035,9 +2292,6 @@ class PersistentDenseGemmKernel:
else:
if mma_tiler_mn[1] not in [128, 256]:
is_valid = False
- # Skip illegal cluster shape
- if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
- is_valid = False
# Skip invalid cluster shape
is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
if (
|
ab_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
d_dtype: Type[cutlass.Numeric],
- use_2cta_instrs: bool,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
m: int,
@@ -2133,8 +2386,6 @@ class PersistentDenseGemmKernel:
:type acc_dtype: Type[cutlass.Numeric]
:param d_dtype: The data type of the output tensor
:type d_dtype: Type[cutlass.Numeric]
- :param use_2cta_instrs: Whether to use 2 CTA groups
- :type use_2cta_instrs: bool
:param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
:type mma_tiler_mn: Tuple[int, int]
:param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
@@ -2159,15 +2410,15 @@ class PersistentDenseGemmKernel:
         """
         can_implement = True
         # Skip unsupported types
-        if not
+        if not GemmSm100.is_valid_dtypes(ab_dtype, ab_dtype, acc_dtype, d_dtype, a_major, b_major):
             can_implement = False
         # Skip invalid mma tile shape and cluster shape
-        if not
-
+        if not GemmSm100.is_valid_mma_tiler_and_cluster_shape(
+            mma_tiler_mn, cluster_shape_mn, blockscaled=False
         ):
             can_implement = False
         # Skip illegal problem shape for load/store alignment
-        if not
+        if not GemmSm100.is_valid_tensor_alignment(
             m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major
         ):
             can_implement = False
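With the `use_2cta_instrs` parameter gone, the capability check is driven purely by dtypes, the MMA tiler, the cluster shape, and the problem shape. A hedged usage sketch follows; the `quack.gemm_sm100` import path and the arguments between `m` and `d_major` (shown as `n, k, l, a_major, b_major` below) are inferred from this diff rather than copied from it:

```python
import cutlass
from quack.gemm_sm100 import GemmSm100  # assumed module path for the renamed file

ok = GemmSm100.can_implement(
    cutlass.BFloat16,     # ab_dtype
    cutlass.Float32,      # acc_dtype
    cutlass.BFloat16,     # d_dtype
    (128, 256),           # mma_tiler_mn
    (2, 1),               # cluster_shape_mn
    8192, 8192, 8192, 1,  # m, n, k, l (middle arguments inferred, not shown in these hunks)
    "k", "k", "n",        # a_major, b_major, d_major (inferred ordering)
)
if not ok:
    raise TypeError("Unsupported GEMM configuration")
```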
@@ -2186,7 +2437,6 @@ def run(
     c_major: str,
     mma_tiler_mn: Tuple[int, int] = (256, 256),
     cluster_shape_mn: Tuple[int, int] = (2, 1),
-    use_2cta_instrs: bool = True,
     tolerance: float = 1e-01,
     warmup_iterations: int = 0,
     iterations: int = 1,
@@ -2215,9 +2465,6 @@ def run(
     :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
         default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
     :type cluster_shape_mn: Tuple[int, int], optional
-    :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
-        will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
-    :type use_2cta_instrs: bool, optional
     :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
     :type tolerance: float, optional
     :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
@@ -2236,7 +2483,6 @@ def run(
     print(f"AB dtype: {ab_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
     print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
     print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
-    print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
     print(f"Tolerance: {tolerance}")
     print(f"Warmup iterations: {warmup_iterations}")
     print(f"Iterations: {iterations}")
@@ -2248,11 +2494,10 @@ def run(
     m, n, k, l = mnkl
 
     # Skip unsupported testcase
-    if not
+    if not GemmSm100.can_implement(
         ab_dtype,
         acc_dtype,
         d_dtype,
-        use_2cta_instrs,
         mma_tiler_mn,
         cluster_shape_mn,
         m,
@@ -2264,7 +2509,7 @@ def run(
         d_major,
     ):
         raise TypeError(
-            f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {
+            f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}"
         )
 
     if not torch.cuda.is_available():
@@ -2339,12 +2584,8 @@ def run(
         c, mC, c_torch = None, None, None
 
     # Configure gemm kernel
-
-
-        use_2cta_instrs,
-        mma_tiler_mn,
-        cluster_shape_mn,
-    )
+    cluster_shape_mnk = (*cluster_shape_mn, 1)
+    gemm = GemmSm100(acc_dtype, ab_dtype, mma_tiler_mn, cluster_shape_mnk)
 
     # Compute max active clusters on current device
     hardware_info = cutlass.utils.HardwareInfo()
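The kernel object itself is now built from just the accumulator dtype, the A/B dtype, the MMA tiler, and a 3-D cluster shape; `use_2cta_instrs` is no longer a constructor knob. A minimal construction sketch under the same assumed import path:

```python
import cutlass
from quack.gemm_sm100 import GemmSm100  # assumed module path

mma_tiler_mn = (128, 256)
cluster_shape_mn = (2, 1)
# The constructor takes an (M, N, K) cluster shape; K is padded to 1 as in run().
cluster_shape_mnk = (*cluster_shape_mn, 1)
gemm = GemmSm100(cutlass.Float32, cutlass.BFloat16, mma_tiler_mn, cluster_shape_mnk)
```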
@@ -2356,6 +2597,17 @@ def run(
     else:
         tile_count_semaphore = None
 
+    scheduler_args = TileSchedulerOptions(
+        Int32(max_active_clusters),
+        tile_count_semaphore=make_ptr(
+            Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
+        )
+        if tile_count_semaphore is not None
+        else None,
+    )
+    epi_args = gemm.EpilogueArguments()
+    varlen_args = VarlenArguments()
+
     # Get current CUDA stream from PyTorch
     torch_stream = torch.cuda.current_stream()
     # Get the raw stream pointer as a CUstream
@@ -2367,15 +2619,14 @@ def run(
         mB,
         mD,
         mC,
-
-
-
-        max_active_clusters,
+        epi_args,
+        scheduler_args,
+        varlen_args,
         current_stream,
     )
 
     if not skip_ref_check:
-        compiled_gemm(mA, mB, mD, mC,
+        compiled_gemm(mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream)
         if ab_dtype in {
             cutlass.Int8,
             cutlass.Uint8,
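The hunks above replace the old `max_active_clusters` positional argument with three explicit bundles: scheduler options, epilogue arguments, and varlen arguments, which are also threaded through the benchmark lambda further down. A hedged end-to-end sketch of assembling them (the `quack.*`, `make_ptr`, and `Int32` import locations are assumptions, and the compilation step that produces `compiled_gemm` lies outside these hunks):

```python
import torch
import cutlass
import cutlass.cute as cute
from cutlass import Int32                              # assumed import location
from cutlass.cute.runtime import make_ptr              # assumed import location
from quack.tile_scheduler import TileSchedulerOptions  # assumed import location
from quack.varlen_utils import VarlenArguments         # assumed import location
from quack.gemm_sm100 import GemmSm100                 # assumed module path

gemm = GemmSm100(cutlass.Float32, cutlass.BFloat16, (128, 256), (2, 1, 1))
max_active_clusters = 148  # illustrative; run() queries cutlass.utils.HardwareInfo()
tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda")

scheduler_args = TileSchedulerOptions(
    Int32(max_active_clusters),
    tile_count_semaphore=make_ptr(
        Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
    ),
)
epi_args = gemm.EpilogueArguments()
varlen_args = VarlenArguments()

# The compiled kernel (compilation not shown in these hunks) is then called as:
#   compiled_gemm(mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream)
```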
@@ -2393,7 +2644,7 @@ def run(
         gpu_d = d_torch.cpu()
 
         # Convert ref to c_type
-        if d_dtype ==
+        if d_dtype == Float32:
             ref_d = ref
         elif d_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
             # m major: (l, n, m) -> (m, n, l)
@@ -2463,7 +2714,9 @@ def run(
         print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
 
         time.sleep(0.5)
-        fn = lambda: compiled_gemm(
+        fn = lambda: compiled_gemm(
+            mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream
+        )
         timing = do_bench(fn, warmup=warmup, rep=repeats)
         tflops = flops / (timing * 1e9)  # Convert to TFlops
         print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
@@ -2505,12 +2758,7 @@ if __name__ == "__main__":
     parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
     parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
     parser.add_argument("--c_dtype", type=cutlass.dtype, default=None)
-    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=
-    parser.add_argument(
-        "--use_2cta_instrs",
-        action="store_true",
-        help="Enable 2CTA MMA instructions feature",
-    )
+    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=Float32)
     parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
     parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
     parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
@@ -2552,7 +2800,6 @@ if __name__ == "__main__":
         args.c_major,
         args.mma_tiler_mn,
         args.cluster_shape_mn,
-        args.use_2cta_instrs,
         args.tolerance,
         args.warmup_iterations,
         args.iterations,