quack-kernels 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/autotuner.py +64 -5
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/rmsnorm.py +83 -149
- quack/tile_scheduler.py +34 -47
- quack/utils.py +61 -8
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +2 -2
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/RECORD +18 -18
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
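The most visible change in this release is the rename of `quack/dense_gemm_sm100.py` to `quack/gemm_sm100.py`, with the kernel class renamed from `PersistentDenseGemmKernel` to `GemmSm100` (now derived from `GemmSm90`), as the diff below shows. A minimal sketch of how a downstream import might need to change, assuming the module is imported directly rather than through `quack.gemm_interface` (constructor arguments are taken from the new `__init__` signature in the diff; the exact values are illustrative only):

```python
# quack-kernels 0.2.1 (old module and class name):
# from quack.dense_gemm_sm100 import PersistentDenseGemmKernel

# quack-kernels 0.2.2 (renamed module and class):
from cutlass import Float32
from quack.gemm_sm100 import GemmSm100

gemm = GemmSm100(
    acc_dtype=Float32,
    mma_tiler_mn=(128, 128),
    cluster_shape_mnk=(2, 2, 1),  # cluster shape now carries an explicit K=1
)
```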
quack/{dense_gemm_sm100.py → gemm_sm100.py}

@@ -1,30 +1,5 @@
-#
-#
-
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-
-# 1. Redistributions of source code must retain the above copyright notice, this
-# list of conditions and the following disclaimer.
-
-# 2. Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-
-# 3. Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# Based on the cute-dsl example:
+# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py

 import argparse
 from typing import Optional, Type, Tuple, Union, Callable
@@ -40,15 +15,16 @@ import cutlass.torch as cutlass_torch
 import cutlass.pipeline as pipeline
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
+from cutlass import Int32, Float32, Boolean, const_expr
+from cutlass.utils import LayoutEnum
 from cutlass.cute.runtime import from_dlpack, make_ptr
-from cutlass import Int32, const_expr

-from quack.cute_dsl_utils import ParamsBase
-from quack.tile_scheduler import
-
-
-
-
+from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
+from quack.tile_scheduler import TileSchedulerOptions
+from quack.varlen_utils import VarlenArguments
+from quack.dense_gemm_sm90 import GemmSm90, NamedBarrierGemm
+
+# return PipelineStateWAdvance instead of PipelineState

 """
 A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
@@ -72,8 +48,6 @@ This GEMM works as follows:
     - Type convert C matrix to output type.
     - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
       or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
-    - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
-      e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))

 SM100 tcgen05.mma instructions operate as follows:
     - Read matrix A from SMEM
@@ -105,7 +79,7 @@ To collect performance with NCU profiler:

 Constraints are same as dense_gemm.py:
 * Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2),
-  see detailed valid dtype combinations in below
+  see detailed valid dtype combinations in below GemmSm100 class documentation
 * A/B tensor must have the same data type
 * Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
 * Mma tiler N must be 32-256, step 32
@@ -118,14 +92,12 @@ Constraints are same as dense_gemm.py:
 """


-class PersistentDenseGemmKernel:
+class GemmSm100(GemmSm90):
     """This class implements batched matrix multiplication (C = A x B) with support for various data types
     and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.

     :param acc_dtype: Data type for accumulation during computation
     :type acc_dtype: type[cutlass.Numeric]
-    :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
-    :type use_2cta_instrs: bool
     :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
     :type mma_tiler_mn: Tuple[int, int]
     :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
@@ -159,22 +131,27 @@ class PersistentDenseGemmKernel:
     - Cluster shape M/N must be positive and power of 2, total cluster size <= 16

     Example:
-        >>> gemm =
-        ...     acc_dtype=
-        ...     use_2cta_instrs=True,
+        >>> gemm = GemmSm100(
+        ...     acc_dtype=Float32,
         ...     mma_tiler_mn=(128, 128),
         ...     cluster_shape_mn=(2, 2)
         ... )
         >>> gemm(mA, mB, mD, max_active_clusters, stream)
     """

+    arch = 100
+    num_epi_tensormaps = GemmSm90.num_epi_tensormaps
+
+    EpilogueArguments = GemmSm90.EpilogueArguments
+    EpilogueParams = GemmSm90.EpilogueParams
+
     def __init__(
         self,
         acc_dtype: Type[cutlass.Numeric],
-        use_2cta_instrs: bool,
         mma_tiler_mn: Tuple[int, int],
-
+        cluster_shape_mnk: Tuple[int, int, int],
         sf_vec_size: Optional[int] = None,
+        gather_A: bool = False,
     ):
         """Initializes the configuration for a Blackwell dense GEMM kernel.

@@ -187,47 +164,42 @@ class PersistentDenseGemmKernel:
           with cta_group=2 should be used.

         2. Cluster Shape:
-            -
+            - cluster_shape_mnk: The (ClusterM, ClusterN) shape of the CTA cluster.

         :param acc_dtype: Data type of the accumulator.
         :type acc_dtype: type[cutlass.Numeric]
         :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
         :type mma_tiler_mn: Tuple[int, int]
-        :param
-        :type
-        :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
-        :type cluster_shape_mn: Tuple[int, int]
+        :param cluster_shape_mnk: Tuple (ClusterM, ClusterN) shape of the cluster.
+        :type cluster_shape_mnk: Tuple[int, int]
         """

         self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
-        self.use_2cta_instrs =
-        self.
+        self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (128, 256)
+        self.cluster_shape_mnk = cluster_shape_mnk
+        assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1"
         # K dimension is deferred in _setup_attributes
         self.mma_tiler = (*mma_tiler_mn, 1)
         self.sf_vec_size = sf_vec_size
         self.blockscaled = sf_vec_size is not None
+        self.is_persistent = True
+        self.pingpong = False  # for compatibility with GemmSm90
+        self.gather_A = gather_A
+        if gather_A:
+            assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "

-        self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
+        self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE

         self.occupancy = 1
         # Set specialized warp ids
-        self.epilog_warp_id = (
-            0,
-            1,
-            2,
-            3,
-        )
+        self.epilog_warp_id = (0, 1, 2, 3)
         self.mma_warp_id = 4
         self.tma_warp_id = 5
         self.tma_epi_warp_id = 6
+        self.num_epi_warps = len(self.epilog_warp_id)
         self.threads_per_cta = 32 * len(
             (self.mma_warp_id, self.tma_warp_id, self.tma_epi_warp_id, *self.epilog_warp_id)
         )
-        # Set barrier id for cta sync, epilogue sync and tmem ptr sync
-        self.cta_sync_bar_id = 0
-        self.epilog_sync_bar_id = 1
-        self.tmem_ptr_sync_bar_id = 2
-        self.epilog_load_bar_id = 3
         self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_100")

     def _setup_attributes(self):
@@ -261,7 +233,7 @@ class PersistentDenseGemmKernel:

         # Configure tiled mma
         if const_expr(not self.blockscaled):
-            tiled_mma = sm100_utils.make_trivial_tiled_mma(
+            self.tiled_mma = sm100_utils.make_trivial_tiled_mma(
                 self.a_dtype,
                 self.a_major_mode,
                 self.b_major_mode,
@@ -269,9 +241,9 @@ class PersistentDenseGemmKernel:
                 self.cta_group,
                 self.mma_tiler[:2],
             )
-            tiled_mma_sfb = None
+            self.tiled_mma_sfb = None
         else:
-            tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+            self.tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
                 self.a_dtype,
                 self.a_major_mode,
                 self.b_major_mode,
@@ -280,13 +252,13 @@ class PersistentDenseGemmKernel:
                 self.cta_group,
                 self.mma_inst_shape_mnk[:2],
             )
-            tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+            self.tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
                 self.a_dtype,
                 self.a_major_mode,
                 self.b_major_mode,
                 self.sf_dtype,
                 self.sf_vec_size,
-
+                tcgen05.CtaGroup.ONE,
                 self.mma_inst_shape_mnk_sfb[:2],
             )

@@ -306,20 +278,20 @@ class PersistentDenseGemmKernel:
         else:
             self.mma_tiler_sfb = None
         self.cta_tile_shape_mnk = (
-            self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
+            self.mma_tiler[0] // cute.size(self.tiled_mma.thr_id.shape),
             self.mma_tiler[1],
             self.mma_tiler[2],
         )

         # Compute cluster layout
         self.cluster_layout_vmnk = cute.tiled_divide(
-            cute.make_layout(
-            (tiled_mma.thr_id.shape,),
+            cute.make_layout(self.cluster_shape_mnk),
+            (self.tiled_mma.thr_id.shape,),
         )
         if const_expr(self.blockscaled):
             self.cluster_layout_sfb_vmnk = cute.tiled_divide(
-                cute.make_layout(
-                (tiled_mma_sfb.thr_id.shape,),
+                cute.make_layout(self.cluster_shape_mnk),
+                (self.tiled_mma_sfb.thr_id.shape,),
             )
         else:
             self.cluster_layout_sfb_vmnk = None
@@ -344,11 +316,11 @@ class PersistentDenseGemmKernel:
         # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
         (
             self.num_acc_stage,
-            self.
-            self.
-            self.
+            self.ab_stage,
+            self.epi_stage,
+            self.epi_c_stage,
         ) = self._compute_stages(
-            tiled_mma,
+            self.tiled_mma,
             self.mma_tiler,
             self.a_dtype,
             self.b_dtype,
@@ -362,35 +334,36 @@ class PersistentDenseGemmKernel:
             self.smem_capacity,
             self.occupancy,
         )
+        self.sched_stage = 1  # For compatibility with GemmSm90

         # Compute A/B/SFA/SFB/C shared memory layout
         self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
-            tiled_mma, self.mma_tiler, self.a_dtype, self.
+            self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
         )
         self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
-            tiled_mma, self.mma_tiler, self.b_dtype, self.
+            self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage
         )
-        self.
-            self.d_dtype, self.d_layout, self.epi_tile, self.
+        self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+            self.d_dtype, self.d_layout, self.epi_tile, self.epi_stage
         )
         if const_expr(self.c_dtype is not None):
             self.epi_c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
-                self.c_dtype, self.c_layout, self.epi_tile, self.
+                self.c_dtype, self.c_layout, self.epi_tile, self.epi_c_stage
             )
         else:
             self.epi_c_smem_layout_staged = None
         if const_expr(self.blockscaled):
             self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
-                tiled_mma,
+                self.tiled_mma,
                 self.mma_tiler,
                 self.sf_vec_size,
-                self.
+                self.ab_stage,
             )
             self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
-                tiled_mma,
+                self.tiled_mma,
                 self.mma_tiler,
                 self.sf_vec_size,
-                self.
+                self.ab_stage,
             )
         else:
             self.sfa_smem_layout_staged, self.sfb_smem_layout_staged = None, None
@@ -398,7 +371,7 @@ class PersistentDenseGemmKernel:
         # Compute the number of tensor memory allocation columns
         if const_expr(not self.blockscaled):
             self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(
-                tiled_mma, self.mma_tiler, self.num_acc_stage
+                self.tiled_mma, self.mma_tiler, self.num_acc_stage
             )
         else:
             SM100_TMEM_CAPACITY_COLUMNS = 512
@@ -409,14 +382,14 @@ class PersistentDenseGemmKernel:
         self,
         mA: cute.Tensor,
         mB: cute.Tensor,
-        mD: cute.Tensor,
+        mD: Optional[cute.Tensor],
         mC: Optional[cute.Tensor],
-
-
+        epilogue_args: ArgumentsBase,
+        scheduler_args: TileSchedulerOptions,
+        varlen_args: Optional[VarlenArguments],
         stream: cuda.CUstream,
         mSFA: Optional[cute.Tensor] = None,
         mSFB: Optional[cute.Tensor] = None,
-        epilogue_op: cutlass.Constexpr = lambda x: x,
     ):
         """Execute the GEMM operation in steps:
         - Setup static attributes before smem/grid/tma computation
@@ -435,30 +408,46 @@ class PersistentDenseGemmKernel:
         :type max_active_clusters: cutlass.Constexpr
         :param stream: CUDA stream for asynchronous execution
         :type stream: cuda.CUstream
-        :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
-        :type epilogue_op: cutlass.Constexpr
         :raises TypeError: If input data types are incompatible with the MMA instruction.
         :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled.
         """
         if const_expr(self.blockscaled):
             assert mSFA is not None and mSFB is not None
         # Setup static attributes before smem/grid/tma computation
-        self.a_dtype
-        self.b_dtype
-        self.d_dtype
+        self.a_dtype = mA.element_type
+        self.b_dtype = mB.element_type
+        self.d_dtype = mD.element_type if mD is not None else None
         self.c_dtype = mC.element_type if mC is not None else None
         self.sf_dtype: Optional[Type[cutlass.Numeric]] = (
             mSFA.element_type if mSFA is not None else None
         )
-        self.
-        self.
-        self.d_layout =
-        self.c_layout =
+        self.a_layout = LayoutEnum.from_tensor(mA)
+        self.b_layout = LayoutEnum.from_tensor(mB)
+        self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
+        self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None
+        self.a_major_mode = LayoutEnum.from_tensor(mA).mma_major_mode()
+        self.b_major_mode = LayoutEnum.from_tensor(mB).mma_major_mode()

         # Check if input data types are compatible with MMA instruction
         if const_expr(self.a_dtype != self.b_dtype):
             raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")

+        if const_expr(varlen_args is None):
+            varlen_args = VarlenArguments()
+        assert (varlen_args.mAIdx is not None) == self.gather_A
+
+        # Assume all strides are divisible by 128 bits except the last stride
+        new_stride = lambda t: tuple(
+            cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
+            for s in t.stride
+        )
+        mA, mD = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            if t is not None
+            else None
+            for t in (mA, mD)
+        ]
+
         # Setup attributes that dependent on gemm inputs
         self._setup_attributes()

@@ -471,67 +460,44 @@ class PersistentDenseGemmKernel:
             sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(mB.shape, self.sf_vec_size)
             mSFB = cute.make_tensor(mSFB.iterator, sfb_layout)

-
-            tiled_mma = sm100_utils.make_trivial_tiled_mma(
-                self.a_dtype,
-                self.a_major_mode,
-                self.b_major_mode,
-                self.acc_dtype,
-                self.cta_group,
-                self.mma_tiler[:2],
-            )
-            tiled_mma_sfb = None
-        else:
-            tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
-                self.a_dtype,
-                self.a_major_mode,
-                self.b_major_mode,
-                self.sf_dtype,
-                self.sf_vec_size,
-                self.cta_group,
-                self.mma_inst_shape_mnk[:2],
-            )
-            tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
-                self.a_dtype,
-                self.a_major_mode,
-                self.b_major_mode,
-                self.sf_dtype,
-                self.sf_vec_size,
-                cute.nvgpu.tcgen05.CtaGroup.ONE,
-                self.mma_inst_shape_mnk_sfb[:2],
-            )
-        atom_thr_size = cute.size(tiled_mma.thr_id.shape)
+        atom_thr_size = cute.size(self.tiled_mma.thr_id.shape)

-        # Setup TMA load for A
-        a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+        # Setup TMA load for A & B
         a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
-        tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
-            a_op,
-            mA,
-            a_smem_layout,
-            self.mma_tiler,
-            tiled_mma,
-            self.cluster_layout_vmnk.shape,
-            internal_type=(cutlass.TFloat32 if mA.element_type is cutlass.Float32 else None),
-        )
-
-        # Setup TMA load for B
-        b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
         b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
+        tma_atom_a, tma_tensor_a = None, None
+        if const_expr(not self.gather_A):
+            a_op = sm100_utils.cluster_shape_to_tma_atom_A(
+                self.cluster_shape_mnk, self.tiled_mma.thr_id
+            )
+            tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
+                a_op,
+                mA,
+                a_smem_layout,
+                self.mma_tiler,
+                self.tiled_mma,
+                self.cluster_layout_vmnk.shape,
+                internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None),
+            )
+        b_op = sm100_utils.cluster_shape_to_tma_atom_B(
+            self.cluster_shape_mnk, self.tiled_mma.thr_id
+        )
         tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
             b_op,
             mB,
             b_smem_layout,
             self.mma_tiler,
-            tiled_mma,
+            self.tiled_mma,
             self.cluster_layout_vmnk.shape,
-            internal_type=(cutlass.TFloat32 if mB.element_type is
+            internal_type=(cutlass.TFloat32 if mB.element_type is Float32 else None),
         )

+        tma_atom_sfa, tma_tensor_sfa = None, None
+        tma_atom_sfb, tma_tensor_sfb = None, None
         if const_expr(self.blockscaled):
             # Setup TMA load for SFA
             sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
-                self.
+                self.cluster_shape_mnk, self.tiled_mma.thr_id
             )
             sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
             tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
@@ -539,13 +505,13 @@ class PersistentDenseGemmKernel:
                 mSFA,
                 sfa_smem_layout,
                 self.mma_tiler,
-                tiled_mma,
+                self.tiled_mma,
                 self.cluster_layout_vmnk.shape,
                 internal_type=cutlass.Int16,
             )
             # Setup TMA load for SFB
             sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
-                self.
+                self.cluster_shape_mnk, self.tiled_mma.thr_id
             )
             sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
             tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
@@ -553,30 +519,31 @@ class PersistentDenseGemmKernel:
                 mSFB,
                 sfb_smem_layout,
                 self.mma_tiler_sfb,
-                tiled_mma_sfb,
+                self.tiled_mma_sfb,
                 self.cluster_layout_sfb_vmnk.shape,
                 internal_type=cutlass.Int16,
             )
-        else:
-            tma_atom_sfa, tma_tensor_sfa = None, None
-            tma_atom_sfb, tma_tensor_sfb = None, None

-
-
-
+        self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+        if const_expr(not self.gather_A):
+            self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
         if const_expr(self.blockscaled):
             sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
             sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
-            self.num_tma_load_bytes +=
+            self.num_tma_load_bytes += sfa_copy_size + sfb_copy_size
+        self.num_tma_load_bytes *= atom_thr_size

         # Setup TMA store for D
-
-
-
-
-
-
-
+        tma_atom_d, tma_tensor_d = None, None
+        if const_expr(mD is not None):
+            epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0))
+            tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
+                cpasync.CopyBulkTensorTileS2GOp(),
+                mD,
+                epi_smem_layout,
+                self.epi_tile,
+            )
+        tma_atom_c, tma_tensor_c = None, None
         if const_expr(mC is not None):
             epi_c_smem_layout = cute.slice_(self.epi_c_smem_layout_staged, (None, None, 0))
             tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
@@ -585,26 +552,19 @@ class PersistentDenseGemmKernel:
                 epi_c_smem_layout,
                 self.epi_tile,
             )
-        else:
-            tma_atom_c, tma_tensor_c = None, None

-
-
-        )
-
-        tile_sched_args = TileSchedulerArguments(
-            problem_shape_ntile_mnl=problem_shape_ntile_mnl,
-            raster_order=RasterOrderOption.Heuristic,
-            group_size=8,
-            cluster_shape_mnk=(*self.cluster_shape_mn, 1),
-            tile_count_semaphore=tile_count_semaphore,
-            is_persistent=True,
-        )
+        epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
+
+        TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
+        tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
         tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
-        grid = TileSchedulerCls.get_grid_shape(
+        grid = TileSchedulerCls.get_grid_shape(
+            tile_sched_params, scheduler_args.max_active_clusters
+        )

         self.buffer_align_bytes = 1024

+        epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
         epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
         sf_dtype = self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU
         sfa_smem_size = (
@@ -617,18 +577,18 @@ class PersistentDenseGemmKernel:
         # Define shared storage for kernel
         @cute.struct
         class SharedStorage:
-
-
-
-
-
+            ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
+            epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
+            acc_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
+            sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
+            tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
             tmem_dealloc_mbar_ptr: cutlass.Int64
             tmem_holding_buf: Int32
-            sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, 2]
-            tile_count: cute.struct.MemRange[cutlass.Int32, 1]
             # (EPI_TILE_M, EPI_TILE_N, STAGE)
             sD: cute.struct.Align[
-                cute.struct.MemRange[
+                cute.struct.MemRange[
+                    self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
+                ],
                 self.buffer_align_bytes,
             ]
             sC: cute.struct.Align[
@@ -637,6 +597,7 @@ class PersistentDenseGemmKernel:
                 ],
                 self.buffer_align_bytes,
             ]
+            epi: self.epi_get_smem_struct(epilogue_params)
             # (MMA, MMA_M, MMA_K, STAGE)
             sA: cute.struct.Align[
                 cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
@@ -662,10 +623,10 @@ class PersistentDenseGemmKernel:

         # Launch the kernel synchronously
         self.kernel(
-            tiled_mma,
-            tiled_mma_sfb,
+            self.tiled_mma,
+            self.tiled_mma_sfb,
             tma_atom_a,
-            tma_tensor_a,
+            tma_tensor_a if const_expr(not self.gather_A) else mA,
             tma_atom_b,
             tma_tensor_b,
             tma_atom_sfa,
@@ -676,24 +637,29 @@ class PersistentDenseGemmKernel:
             tma_tensor_d,
             tma_atom_c,
             tma_tensor_c,
+            epilogue_params,
+            varlen_args.mCuSeqlensM,
+            varlen_args.mCuSeqlensK,
+            varlen_args.mTensormaps,
+            varlen_args.mAIdx,
             self.cluster_layout_vmnk,
             self.cluster_layout_sfb_vmnk,
             self.a_smem_layout_staged,
             self.b_smem_layout_staged,
             self.sfa_smem_layout_staged,
             self.sfb_smem_layout_staged,
-            self.
+            self.epi_smem_layout_staged,
             self.epi_c_smem_layout_staged,
             self.epi_tile,
             tile_sched_params,
             TileSchedulerCls,
-            epilogue_op,
         ).launch(
             grid=grid,
             block=[self.threads_per_cta, 1, 1],
-            cluster=
+            cluster=self.cluster_shape_mnk,
             smem=self.shared_storage.size_in_bytes(),
             stream=stream,
+            min_blocks_per_mp=1,
         )
         return

@@ -703,7 +669,7 @@ class PersistentDenseGemmKernel:
         self,
         tiled_mma: cute.TiledMma,
         tiled_mma_sfb: Optional[cute.TiledMma],
-        tma_atom_a: cute.CopyAtom,
+        tma_atom_a: Optional[cute.CopyAtom],
         mA_mkl: cute.Tensor,
         tma_atom_b: cute.CopyAtom,
         mB_nkl: cute.Tensor,
@@ -712,37 +678,54 @@ class PersistentDenseGemmKernel:
         tma_atom_sfb: Optional[cute.CopyAtom],
         mSFB_nkl: Optional[cute.Tensor],
         tma_atom_d: Optional[cute.CopyAtom],
-        mD_mnl: cute.Tensor,
+        mD_mnl: Optional[cute.Tensor],
         tma_atom_c: Optional[cute.CopyAtom],
         mC_mnl: Optional[cute.Tensor],
+        epilogue_params: ParamsBase,
+        cu_seqlens_m: Optional[cute.Tensor],
+        cu_seqlens_k: Optional[cute.Tensor],
+        tensormaps: Optional[cute.Tensor],
+        mAIdx: Optional[cute.Tensor],
         cluster_layout_vmnk: cute.Layout,
         cluster_layout_sfb_vmnk: Optional[cute.Layout],
-
-
-
-
-
-
+        a_smem_layout: cute.ComposedLayout,
+        b_smem_layout: cute.ComposedLayout,
+        sfa_smem_layout: Optional[cute.Layout],
+        sfb_smem_layout: Optional[cute.Layout],
+        epi_smem_layout: Union[cute.Layout, cute.ComposedLayout, None],
+        epi_c_smem_layout: Union[cute.Layout, cute.ComposedLayout, None],
         epi_tile: cute.Tile,
         tile_sched_params: ParamsBase,
         TileSchedulerCls: cutlass.Constexpr[Callable],
-        epilogue_op: cutlass.Constexpr[Callable],
     ):
         """
         GPU device kernel performing the Persistent batched GEMM computation.
         """
+
+        varlen_m = const_expr(cu_seqlens_m is not None)
+        varlen_k = const_expr(cu_seqlens_k is not None)
+        assert not (varlen_m and varlen_k)
+        if const_expr(self.gather_A):
+            assert varlen_m or varlen_k
+        has_D = const_expr(mD_mnl is not None)
+        has_C = const_expr(mC_mnl is not None)
+
         warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

-        #
-        #
-        #
+        # /////////////////////////////////////////////////////////////////////////////
+        # Prefetch Tma desc
+        # /////////////////////////////////////////////////////////////////////////////
         if warp_idx == self.tma_warp_id:
-
-
-
-
-
-
+            for tma_atom in (
+                tma_atom_a,
+                tma_atom_b,
+                tma_atom_sfa,
+                tma_atom_sfb,
+                tma_atom_d,
+                tma_atom_c,
+            ):
+                if const_expr(tma_atom is not None):
+                    cpasync.prefetch_descriptor(tma_atom)

         use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2

@@ -754,13 +737,6 @@ class PersistentDenseGemmKernel:
         mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
         is_leader_cta = mma_tile_coord_v == 0
         cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
-        block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
-        if const_expr(self.blockscaled):
-            block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
-                cta_rank_in_cluster
-            )
-        else:
-            block_in_cluster_coord_sfb_vmnk = None
         # Coord inside cta
         tidx, _, _ = cute.arch.thread_idx()

@@ -779,100 +755,53 @@ class PersistentDenseGemmKernel:
         num_tmem_dealloc_threads = 32
         cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)

-        # Initialize
-
-
-
-
-        )
-        ab_pipeline = pipeline.PipelineTmaUmma.create(
-            barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
-            num_stages=self.num_ab_stage,
-            producer_group=ab_pipeline_producer_group,
-            consumer_group=ab_pipeline_consumer_group,
-            tx_count=self.num_tma_load_bytes,
-            cta_layout_vmnk=cluster_layout_vmnk,
+        # Initialize pipelines and states
+        ab_pipeline = self.make_ab_pipeline(
+            tiled_mma=tiled_mma,
+            cluster_layout_vmnk=cluster_layout_vmnk,
+            ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
         )
-
-        if const_expr(
-
-
-
-            consumer_arrive_cnt = len(self.epilog_warp_id)
-            epi_pipeline_consumer_group = pipeline.CooperativeGroup(
-                pipeline.Agent.Thread, consumer_arrive_cnt
-            )
-            c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
-            tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
-            epi_pipeline = pipeline.PipelineTmaAsync.create(
-                barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
-                num_stages=self.num_c_stage,
-                producer_group=epi_pipeline_producer_group,
-                consumer_group=epi_pipeline_consumer_group,
-                tx_count=tma_copy_c_bytes,
+        epi_pipeline = None
+        if const_expr(has_C):
+            epi_pipeline = self.make_epi_pipeline(
+                c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
+                epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
             )
-
-
-
-        # Initialize acc_pipeline (barrier) and states
-        acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
-        num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
-        acc_pipeline_consumer_group = pipeline.CooperativeGroup(
-            pipeline.Agent.Thread, num_acc_consumer_threads
-        )
-        acc_pipeline = pipeline.PipelineUmmaAsync.create(
-            barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
-            num_stages=self.num_acc_stage,
-            producer_group=acc_pipeline_producer_group,
-            consumer_group=acc_pipeline_consumer_group,
-            cta_layout_vmnk=cluster_layout_vmnk,
+        acc_pipeline = self.make_acc_pipeline(
+            cluster_layout_vmnk=cluster_layout_vmnk,
+            acc_pipeline_mbar_ptr=storage.acc_pipeline_array_ptr.data_ptr(),
         )
-
-
-
-
-
-
-
-
-
-
-
-        #     pipeline.Agent.Thread, consumer_arrive_cnt
-        # )
-        # sched_pipeline = pipeline.PipelineAsync.create(
-        #     barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
-        #     num_stages=self.sched_stage,
-        #     producer_group=sched_pipeline_producer_group,
-        #     consumer_group=sched_pipeline_consumer_group,
-        #     # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
-        #     consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
-        # )
-        # tile_count = storage.tile_count.get_tensor((self.sched_stage,))
-        # else:
-        #     sched_pipeline = None
-        #     tile_count = None
+        sched_pipeline = None
+        tile_count = None
+        if const_expr(tile_sched_params.tile_count_semaphore is not None):
+            # TODO: Untested, not sure if this is right for Sm100
+            # Dynamic persistent scheduler
+            sched_pipeline = self.make_sched_pipeline(
+                self.cluster_shape_mnk,
+                sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
+                varlen_k=varlen_k,
+            )
+            tile_count = storage.tile_count.get_tensor((self.sched_stage,))

         # Setup smem tensor A/B/D
         # (MMA, MMA_M, MMA_K, STAGE)
-        sA = storage.sA.get_tensor(
+        sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
         # (MMA, MMA_N, MMA_K, STAGE)
-        sB = storage.sB.get_tensor(
+        sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
+        sSFA, sSFB = None, None
         if const_expr(self.blockscaled):
             # (MMA, MMA_M, MMA_K, STAGE)
-            sSFA = storage.sSFA.get_tensor(
+            sSFA = storage.sSFA.get_tensor(sfa_smem_layout)
             # (MMA, MMA_N, MMA_K, STAGE)
-            sSFB = storage.sSFB.get_tensor(
-
-
-
-
-
-
-
-
-        else:
-            sC = None
+            sSFB = storage.sSFB.get_tensor(sfb_smem_layout)
+        sD = None
+        if const_expr(has_D):
+            # (EPI_TILE_M, EPI_TILE_N, STAGE)
+            sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
+        sC = None
+        if const_expr(has_C):
+            sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
+        epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)

         thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
         thr_mma_sfb = (
@@ -884,26 +813,51 @@ class PersistentDenseGemmKernel:
         # (MMA, MMA_M, MMA_N, STAGE)
         tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))

-
-
-
+        # Get tensormap buffer address
+        tensormap_manager, tensormap_ab_ptrs, tensormap_d_ptr, tensormap_epi_ptrs = (
+            self.tensormap_init(tensormaps, varlen_m, varlen_k, has_D, warp_idx)
         )

-        TileSchedulerCls = partial(
-
+        TileSchedulerCls = partial(
+            TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
+        )

-
+        tmem_alloc_barrier = pipeline.NamedBarrier(
+            barrier_id=int(NamedBarrierGemm.TmemPtr),
+            num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)),
+        )
+        epi_load_barrier = None
+        if const_expr(has_C):
             epi_load_barrier = pipeline.NamedBarrier(
-                barrier_id=int(
+                barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE
             )
-        else:
-            epi_load_barrier = None

         #
         # Specialized TMA load warp
         #
         if warp_idx == self.tma_warp_id:
+            if const_expr(varlen_k):
+                # initialize tensormap for A & B
+                if const_expr(not self.gather_A):
+                    tensormap_manager.init_tensormap_from_atom(
+                        tma_atom_a,
+                        tensormap_ab_ptrs[0],
+                        is_manager_warp=True,
+                    )
+                tensormap_manager.init_tensormap_from_atom(
+                    tma_atom_b,
+                    tensormap_ab_ptrs[1],
+                    is_manager_warp=True,
+                )
             # Compute multicast mask for A/B buffer full
+            block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
+            block_in_cluster_coord_sfb_vmnk = None
+            if const_expr(self.blockscaled):
+                block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
+                    cta_rank_in_cluster
+                )
+            a_mcast_mask, b_mcast_mask = None, None
+            sfa_mcast_mask, sfb_mcast_mask = None, None
             if const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
                 a_mcast_mask = cpasync.create_tma_multicast_mask(
                     cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
@@ -918,28 +872,45 @@ class PersistentDenseGemmKernel:
                 sfb_mcast_mask = cpasync.create_tma_multicast_mask(
                     cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1
                 )
-                else:
-                    sfa_mcast_mask, sfb_mcast_mask = None, None
-            else:
-                a_mcast_mask, b_mcast_mask = None, None
-                sfa_mcast_mask, sfb_mcast_mask = None, None

             # Persistent tile scheduling loop
-
+            is_scheduler_warp = True
+            if const_expr(cute.size(cluster_layout_vmnk) > 1):
+                is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0
+            tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
             work_tile = tile_scheduler.initial_work_tile_info()
             ab_producer_state = pipeline.make_pipeline_state(
-                pipeline.PipelineUserType.Producer, self.
+                pipeline.PipelineUserType.Producer, self.ab_stage
             )
-
+            if const_expr(varlen_k):
+                # wait tensormap initialization complete before update
+                tensormap_manager.fence_tensormap_initialization()
+            # batch index of last tile
+            last_batch_idx = cutlass.Int32(-1)
+            do_epi_load_barrier_arrive = Boolean(True)
             while work_tile.is_valid_tile:
-                # Get tile coord from tile scheduler
                 tile_coord_mnkl = work_tile.tile_idx
+                batch_idx = tile_coord_mnkl[3]
+                if const_expr(varlen_k):
+                    is_group_changed = batch_idx != last_batch_idx
+                    last_batch_idx = batch_idx
+                    if is_group_changed:
+                        self.tensormap_update_AB(
+                            tensormap_manager,
+                            tensormap_ab_ptrs,
+                            cu_seqlens_k,
+                            batch_idx,
+                            is_manager_warp=True,
+                        )
+                # ///////////////////////////////////////////////////////////////////////////
+                # Local_tile partition global tensors
+                # ///////////////////////////////////////////////////////////////////////////
                 mma_tile_coord_mnl = (
                     tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
                     tile_coord_mnkl[1],
                     tile_coord_mnkl[3],
                 )
-                #
+                # TODO: varlen_m
                 # (bM, bK, RestK)
                 gA_mkl = cute.local_tile(
                     mA_mkl,
@@ -1007,7 +978,7 @@ class PersistentDenseGemmKernel:
                 sfa_cta_layout = a_cta_layout
                 # ((atom_v, rest_v), STAGE)
                 # ((atom_v, rest_v), RestK)
-                tAsSFA, tAgSFA =
+                tAsSFA, tAgSFA = cpasync.tma_partition(
                     tma_atom_sfa,
                     block_in_cluster_coord_vmnk[2],
                     sfa_cta_layout,
@@ -1022,7 +993,7 @@ class PersistentDenseGemmKernel:
                 )
                 # ((atom_v, rest_v), STAGE)
                 # ((atom_v, rest_v), RestK)
-                tBsSFB, tBgSFB =
+                tBsSFB, tBgSFB = cpasync.tma_partition(
                     tma_atom_sfb,
                     block_in_cluster_coord_sfb_vmnk[1],
                     sfb_cta_layout,
@@ -1060,12 +1031,15 @@ class PersistentDenseGemmKernel:
                     # with loading A and B.
                     if do_epi_load_barrier_arrive:
                         epi_load_barrier.arrive()
-                        do_epi_load_barrier_arrive =
+                        do_epi_load_barrier_arrive = Boolean(False)
                 # Advance to next tile
+                tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
                 tile_scheduler.advance_to_next_work()
                 work_tile = tile_scheduler.get_current_work()
             # Wait A/B buffer empty
             ab_pipeline.producer_tail(ab_producer_state)
+            if is_scheduler_warp:
+                tile_scheduler.producer_tail()

         #
         # Specialized TMA epi load warp
@@ -1073,15 +1047,16 @@ class PersistentDenseGemmKernel:
         if const_expr(mC_mnl is not None):
             if warp_idx == self.tma_epi_warp_id:
                 epi_producer_state = pipeline.make_pipeline_state(
-                    pipeline.PipelineUserType.Producer, self.
+                    pipeline.PipelineUserType.Producer, self.epi_c_stage
                 )
-                do_epi_load_barrier_wait =
+                do_epi_load_barrier_wait = Boolean(True)
                 # Persistent tile scheduling loop
                 tile_scheduler = TileSchedulerCls()
                 work_tile = tile_scheduler.initial_work_tile_info()
                 while work_tile.is_valid_tile:
                     # Get tile coord from tile scheduler
                     tile_coord_mnkl = work_tile.tile_idx
+                    # TODO: varlen_m
                     mma_tile_coord_mnl = (
                         tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
                         tile_coord_mnkl[1],
@@ -1102,7 +1077,7 @@ class PersistentDenseGemmKernel:
                     bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
                     if do_epi_load_barrier_wait:
                         epi_load_barrier.arrive_and_wait()
-                        do_epi_load_barrier_wait =
+                        do_epi_load_barrier_wait = Boolean(False)
                     epi_tile_num = const_expr(cute.size(bGS_gC, mode=[1]))
                     for subtile_idx in cutlass.range(epi_tile_num, unroll=1):
                         epi_pipeline.producer_acquire(epi_producer_state)
@@ -1149,10 +1124,9 @@ class PersistentDenseGemmKernel:
                 tiled_mma,
                 self.mma_tiler,
                 self.sf_vec_size,
-                cute.slice_(
+                cute.slice_(sfa_smem_layout, (None, None, None, 0)),
             )
             tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
-
             # Make SFB tmem tensor
             sfb_tmem_ptr = cute.recast_ptr(
                 acc_tmem_ptr
@@ -1165,7 +1139,7 @@ class PersistentDenseGemmKernel:
                 tiled_mma,
                 self.mma_tiler,
                 self.sf_vec_size,
-                cute.slice_(
+                cute.slice_(sfb_smem_layout, (None, None, None, 0)),
             )
             tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
             # Partition for S2T copy of SFA/SFB
@@ -1183,11 +1157,12 @@ class PersistentDenseGemmKernel:
             tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None
             tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None

+        k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.mma_tiler[2])
         # Persistent tile scheduling loop
         tile_scheduler = TileSchedulerCls()
         work_tile = tile_scheduler.initial_work_tile_info()
         ab_consumer_state = pipeline.make_pipeline_state(
-            pipeline.PipelineUserType.Consumer, self.
+            pipeline.PipelineUserType.Consumer, self.ab_stage
         )
         acc_producer_state = pipeline.make_pipeline_state(
             pipeline.PipelineUserType.Producer, self.num_acc_stage
@@ -1241,11 +1216,29 @@ class PersistentDenseGemmKernel:
|
|
|
1241
1216
|
# (MMA, MMA_M, MMA_N, STAGE)
|
|
1242
1217
|
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
|
|
1243
1218
|
|
|
1244
|
-
epilog_threads = cute.arch.WARP_SIZE * len(self.epilog_warp_id)
|
|
1245
1219
|
epilogue_barrier = pipeline.NamedBarrier(
|
|
1246
|
-
barrier_id=
|
|
1220
|
+
barrier_id=int(NamedBarrierGemm.Epilogue),
|
|
1221
|
+
num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
|
|
1247
1222
|
)
|
|
1248
1223
|
|
|
1224
|
+
is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
|
|
1225
|
+
if const_expr(varlen_m):
|
|
1226
|
+
# initialize tensormap for D
|
|
1227
|
+
if const_expr(has_D):
|
|
1228
|
+
tensormap_manager.init_tensormap_from_atom(
|
|
1229
|
+
tma_atom_d,
|
|
1230
|
+
tensormap_d_ptr,
|
|
1231
|
+
is_manager_warp=is_tma_warp,
|
|
1232
|
+
)
|
|
1233
|
+
for tma_atom, tensormap_epi_ptr in zip(
|
|
1234
|
+
self.epi_get_tma_atoms(epilogue_params), tensormap_epi_ptrs
|
|
1235
|
+
):
|
|
1236
|
+
tensormap_manager.init_tensormap_from_atom(
|
|
1237
|
+
tma_atom,
|
|
1238
|
+
tensormap_epi_ptr,
|
|
1239
|
+
is_manager_warp=is_tma_warp,
|
|
1240
|
+
)
|
|
1241
|
+
|
|
1249
1242
|
# Partition for epilogue
|
|
1250
1243
|
epi_tidx = tidx
|
|
1251
1244
|
tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
|
|
@@ -1256,6 +1249,7 @@ class PersistentDenseGemmKernel:
|
|
|
1256
1249
|
tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
|
|
1257
1250
|
tiled_copy_t2r, tTR_rD, epi_tidx, sD
|
|
1258
1251
|
)
|
|
1252
|
+
tRS_rC, tSR_rC = None, None
|
|
1259
1253
|
if const_expr(mC_mnl is not None):
|
|
1260
1254
|
tTR_rC = cute.make_fragment_like(tTR_rD, self.c_dtype)
|
|
1261
1255
|
tiled_copy_s2r, tSR_rC, tSR_sC = self.epilog_smem_copy_and_partition(
|
|
@@ -1272,22 +1266,33 @@ class PersistentDenseGemmKernel:
|
|
|
1272
1266
|
acc_consumer_state = pipeline.make_pipeline_state(
|
|
1273
1267
|
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
|
1274
1268
|
)
|
|
1275
|
-
|
|
1276
|
-
d_producer_group = pipeline.CooperativeGroup(
|
|
1277
|
-
pipeline.Agent.Thread,
|
|
1278
|
-
32 * len(self.epilog_warp_id),
|
|
1279
|
-
32 * len(self.epilog_warp_id),
|
|
1280
|
-
)
|
|
1281
|
-
d_pipeline = pipeline.PipelineTmaStore.create(
|
|
1282
|
-
num_stages=self.num_d_stage, producer_group=d_producer_group
|
|
1283
|
-
)
|
|
1269
|
+
epi_store_pipeline = self.make_epi_store_pipeline()
|
|
1284
1270
|
epi_read_state = pipeline.make_pipeline_state(
|
|
1285
|
-
pipeline.PipelineUserType.Consumer, self.
|
|
1271
|
+
pipeline.PipelineUserType.Consumer, self.epi_c_stage
|
|
1286
1272
|
)
|
|
1287
|
-
|
|
1273
|
+
if const_expr(varlen_m):
|
|
1274
|
+
# wait tensormap initialization complete before update
|
|
1275
|
+
tensormap_manager.fence_tensormap_initialization()
|
|
1276
|
+
# batch index of last tile
|
|
1277
|
+
last_batch_idx = cutlass.Int32(-1)
|
|
1288
1278
|
while work_tile.is_valid_tile:
|
|
1289
1279
|
# Get tile coord from tile scheduler
|
|
1290
1280
|
tile_coord_mnkl = work_tile.tile_idx
|
|
1281
|
+
batch_idx = tile_coord_mnkl[3]
|
|
1282
|
+
if const_expr(varlen_m):
|
|
1283
|
+
is_group_changed = batch_idx != last_batch_idx
|
|
1284
|
+
+            last_batch_idx = batch_idx
+            if is_group_changed:
+                self.tensormap_update_D_epi(
+                    tensormap_manager,
+                    tensormap_d_ptr,
+                    tensormap_epi_ptrs,
+                    epilogue_params,
+                    cu_seqlens_m,
+                    batch_idx,
+                    is_manager_warp=is_tma_warp,
+                )
+
             mma_tile_coord_mnl = (
                 tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
                 tile_coord_mnkl[1],
@@ -1311,6 +1316,25 @@ class PersistentDenseGemmKernel:
             # Wait for accumulator buffer full
             acc_pipeline.consumer_wait(acc_consumer_state)

+            tma_desc_d_ptr, tma_desc_epi_ptrs = None, [None] * self.num_epi_tensormaps
+            if const_expr(varlen_m):
+                # ensure the update to tensormap has completed before using it
+                if is_group_changed and is_tma_warp:
+                    if const_expr(has_D):
+                        tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
+                    for tensormap_epi_ptr in tensormap_epi_ptrs:
+                        tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
+                if const_expr(has_D):
+                    tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
+                        tensormap_d_ptr, cute.AddressSpace.generic
+                    )
+                tma_desc_epi_ptrs = [
+                    tensormap_manager.get_tensormap_ptr(
+                        tensormap_epi_ptr, cute.AddressSpace.generic
+                    )
+                    for tensormap_epi_ptr in tensormap_epi_ptrs
+                ]
+
             tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
             bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))

@@ -1323,7 +1347,6 @@ class PersistentDenseGemmKernel:
                 cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
                 # Convert to D type
                 acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
-                acc_vec = epilogue_op(acc_vec)
                 if const_expr(mC_mnl is not None):
                     epi_pipeline.consumer_wait(epi_read_state)
                     cute.copy(
@@ -1340,7 +1363,7 @@ class PersistentDenseGemmKernel:
                     acc_vec = acc_vec + tRS_rC.load().to(self.acc_dtype)
                 tRS_rD.store(acc_vec.to(self.d_dtype))
                 # Store D to shared memory
-                d_buffer = (num_prev_subtiles + subtile_idx) % self.
+                d_buffer = (num_prev_subtiles + subtile_idx) % self.epi_stage
                 cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
                 # Fence and barrier to make sure shared memory store is visible to TMA store
                 cute.arch.fence_proxy(
@@ -1348,11 +1371,11 @@ class PersistentDenseGemmKernel:
                 )
                 epilogue_barrier.arrive_and_wait()
                 # TMA store D to global memory
-                if
+                if is_tma_warp:
                     cute.copy(tma_atom_d, bSG_sD[None, d_buffer], bSG_gD[None, subtile_idx])
                     # Fence and barrier to make sure shared memory store is visible to TMA store
-
-
+                    epi_store_pipeline.producer_commit()
+                    epi_store_pipeline.producer_acquire()
                 epilogue_barrier.arrive_and_wait()

                 # Async arrive accumulator buffer empty
@@ -1369,7 +1392,7 @@ class PersistentDenseGemmKernel:
         cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
         epilogue_barrier.arrive_and_wait()
         if warp_idx == self.epilog_warp_id[0]:
-            if use_2cta_instrs:
+            if const_expr(use_2cta_instrs):
                 cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1)
                 cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
             cute.arch.dealloc_tmem(
@@ -1377,7 +1400,8 @@ class PersistentDenseGemmKernel:
             )

        # Wait for D store complete
-
+        if is_tma_warp:
+            epi_store_pipeline.producer_tail()

    @cute.jit
    def load_AB(
@@ -1407,7 +1431,7 @@ class PersistentDenseGemmKernel:
            assert all(x is not None for x in (tma_atom_sfb, tBgSFB, tBsSFB))
        k_tile_cnt = cute.size(tAgA, mode=[1])
        # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
-        peek_ab_empty_status =
+        peek_ab_empty_status = Boolean(True)
        if 0 < k_tile_cnt:
            peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
        # /////////////////////////////////////////////////////////////////////////
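
Note on the `Boolean(True)` initializers introduced here and in the following hunks: they make the peek-then-acquire pattern explicit. When no k-tile remains to prefetch, the peek status stays "already satisfied" and the later blocking acquire has nothing to wait on. A plain-Python sketch of that control flow (the `TinyPipeline` class below is an illustrative stand-in, not the cutlass.pipeline API):

# Illustrative stand-in for the peek-then-acquire flow; TinyPipeline is NOT the
# cutlass.pipeline API, just a plain-Python model of the same control flow.
class TinyPipeline:
    def __init__(self, free_slots: int):
        self.free_slots = free_slots

    def producer_try_acquire(self) -> bool:
        # non-blocking probe: is a buffer slot already free?
        return self.free_slots > 0

    def producer_acquire(self, peek_status: bool) -> None:
        # the blocking wait is skipped when the earlier probe already succeeded
        if not peek_status:
            pass  # a real pipeline would wait on the mbarrier here
        self.free_slots -= 1


def load_k_tiles(pipeline: TinyPipeline, k_tile_cnt: int) -> None:
    peek_ab_empty_status = True          # nothing to wait for when k_tile_cnt == 0
    if 0 < k_tile_cnt:
        peek_ab_empty_status = pipeline.producer_try_acquire()
    for k_tile in range(k_tile_cnt):
        pipeline.producer_acquire(peek_ab_empty_status)
        # ... issue TMA loads for this k_tile ...
        peek_ab_empty_status = True
        if k_tile + 1 < k_tile_cnt:
            peek_ab_empty_status = pipeline.producer_try_acquire()


load_k_tiles(TinyPipeline(free_slots=4), k_tile_cnt=3)
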
@@ -1449,7 +1473,7 @@ class PersistentDenseGemmKernel:
            # Mainloop pipeline's producer commit is a NOP
            ab_pipeline.producer_commit(ab_producer_state)
            ab_producer_state.advance()
-            peek_ab_empty_status =
+            peek_ab_empty_status = Boolean(True)
            if k_tile + 1 < k_tile_cnt:
                peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
        return ab_producer_state
@@ -1466,7 +1490,7 @@ class PersistentDenseGemmKernel:
        tCrB: cute.Tensor,
        acc: cute.Tensor,
        k_tile_cnt: Int32,
-        is_leader_cta:
+        is_leader_cta: Boolean,
        tiled_copy_s2t_sfa: Optional[cute.TiledCopy] = None,
        tiled_copy_s2t_sfb: Optional[cute.TiledCopy] = None,
        tCsSFA_compact_s2t: Optional[cute.Tensor] = None,
@@ -1480,7 +1504,7 @@ class PersistentDenseGemmKernel:
            assert all(x is not None for x in (tCsSFA_compact_s2t, tCsSFB_compact_s2t))
            assert all(x is not None for x in (tCtSFA_compact_s2t, tCtSFB_compact_s2t))
        # Peek (try_wait) AB buffer full for k_tile = 0
-        peek_ab_full_status =
+        peek_ab_full_status = Boolean(True)
        if 0 < k_tile_cnt and is_leader_cta:
            peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
        # Wait for accumulator buffer empty
@@ -1509,7 +1533,7 @@ class PersistentDenseGemmKernel:
            ab_pipeline.consumer_release(ab_consumer_state)
            ab_consumer_state.advance()
            # Peek (try_wait) AB buffer full for k_tile = k_tile + 1
-            peek_ab_full_status =
+            peek_ab_full_status = Boolean(True)
            if k_tile + 1 < k_tile_cnt and is_leader_cta:
                peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
        # Async arrive accumulator buffer full
@@ -1560,7 +1584,7 @@ class PersistentDenseGemmKernel:
        tidx: Int32,
        tAcc: cute.Tensor,
        epi_tile: cute.Tile,
-        use_2cta_instrs: Union[
+        use_2cta_instrs: Union[Boolean, bool],
    ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
        """
        Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
@@ -1708,6 +1732,22 @@ class PersistentDenseGemmKernel:
        )
        return bSG_sD, bSG_gD

+    def make_acc_pipeline(
+        self, cluster_layout_vmnk: cute.Layout, acc_pipeline_mbar_ptr: cute.Pointer
+    ):
+        acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+        num_acc_consumer_threads = self.num_epi_warps * (2 if self.use_2cta_instrs else 1)
+        acc_pipeline_consumer_group = pipeline.CooperativeGroup(
+            pipeline.Agent.Thread, num_acc_consumer_threads
+        )
+        return pipeline.PipelineUmmaAsync.create(
+            barrier_storage=acc_pipeline_mbar_ptr,
+            num_stages=self.num_acc_stage,
+            producer_group=acc_pipeline_producer_group,
+            consumer_group=acc_pipeline_consumer_group,
+            cta_layout_vmnk=cluster_layout_vmnk,
+        )
+
    @staticmethod
    def _compute_stages(
        tiled_mma: cute.TiledMma,
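
The new `make_acc_pipeline` helper centralizes construction of the UMMA-async accumulator pipeline. Its consumer group is sized from the expression shown above; a quick standalone evaluation of that sizing (with an assumed `num_epi_warps = 4`, which is not stated in this hunk):

# Evaluate the consumer-group sizing used by make_acc_pipeline.
# num_epi_warps = 4 is an assumed example value, not taken from the hunk.
num_epi_warps = 4
for use_2cta_instrs in (False, True):
    num_acc_consumer_threads = num_epi_warps * (2 if use_2cta_instrs else 1)
    print(use_2cta_instrs, num_acc_consumer_threads)  # 4 without 2-CTA MMA, 8 with it
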
@@ -1717,8 +1757,8 @@ class PersistentDenseGemmKernel:
        epi_tile: cute.Tile,
        d_dtype: Type[cutlass.Numeric],
        c_dtype: Optional[Type[cutlass.Numeric]],
-        d_layout:
-        c_layout: Optional[
+        d_layout: LayoutEnum,
+        c_layout: Optional[LayoutEnum],
        sf_dtype: Optional[Type[cutlass.Numeric]],
        sf_vec_size: Optional[int],
        smem_capacity: int,
@@ -1739,7 +1779,7 @@ class PersistentDenseGemmKernel:
        :param d_dtype: Data type of operand C (output).
        :type d_dtype: type[cutlass.Numeric]
        :param d_layout: Layout enum of operand C.
-        :type d_layout:
+        :type d_layout: LayoutEnum
        :param smem_capacity: Total available shared memory capacity in bytes.
        :type smem_capacity: int
        :param occupancy: Target number of CTAs per SM (occupancy).
@@ -1757,8 +1797,8 @@ class PersistentDenseGemmKernel:
        num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2

        # Default D stages
-
-
+        epi_stage = 2
+        epi_c_stage = 2 if c_dtype is not None else 0

        # Calculate smem layout and size for one stage of A, B, and C
        a_smem_layout_staged_one = sm100_utils.make_smem_layout_a(
@@ -1802,28 +1842,28 @@ class PersistentDenseGemmKernel:
        ) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
        mbar_helpers_bytes = 1024
        d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
-        epi_bytes = d_bytes_per_stage *
+        epi_bytes = d_bytes_per_stage * epi_stage
        if const_expr(c_dtype is not None):
            c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
-            epi_bytes += c_bytes_per_stage *
+            epi_bytes += c_bytes_per_stage * epi_c_stage

        # Calculate A/B/SFA/SFB stages:
        # Start with total smem per CTA (capacity / occupancy)
        # Subtract reserved bytes and initial C stages bytes
        # Divide remaining by bytes needed per A/B/SFA/SFB stage
-
+        ab_stage = (
            smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)
        ) // ab_bytes_per_stage

        # Refine epilogue stages:
        # Calculate remaining smem after allocating for A/B stages and reserved bytes
        # Add remaining unused smem to epilogue
-
+        epi_stage += (
            smem_capacity
-            - occupancy * ab_bytes_per_stage *
+            - occupancy * ab_bytes_per_stage * ab_stage
            - occupancy * (mbar_helpers_bytes + epi_bytes)
        ) // (occupancy * d_bytes_per_stage)
-        return num_acc_stage,
+        return num_acc_stage, ab_stage, epi_stage, epi_c_stage

    @staticmethod
    def _compute_num_tmem_alloc_cols(
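
The refined stage bookkeeping in `_compute_stages` can be sanity-checked in isolation. The snippet below reruns the same arithmetic with made-up byte sizes (the capacity and per-stage sizes are illustrative assumptions, not the values the kernel actually computes from its smem layouts):

# Standalone recheck of the stage bookkeeping above with made-up byte sizes.
smem_capacity = 232448          # assumed per-SM shared memory budget in bytes
occupancy = 1
mbar_helpers_bytes = 1024
ab_bytes_per_stage = 32768      # one A+B stage (assumed)
d_bytes_per_stage = 16384       # one D epilogue stage (assumed)
epi_stage, epi_c_stage = 2, 0   # defaults from the code above, no C operand

epi_bytes = d_bytes_per_stage * epi_stage
ab_stage = (smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)) // ab_bytes_per_stage
epi_stage += (
    smem_capacity
    - occupancy * ab_bytes_per_stage * ab_stage
    - occupancy * (mbar_helpers_bytes + epi_bytes)
) // (occupancy * d_bytes_per_stage)
print(ab_stage, epi_stage)      # -> 6 2: six A/B stages; the 2 KiB left over cannot fit another D stage
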
@@ -1880,7 +1920,7 @@ class PersistentDenseGemmKernel:
        }:
            is_valid = False
        if (
-            acc_dtype not in {
+            acc_dtype not in {Float32, cutlass.Float16, Int32}
            or acc_dtype == cutlass.Float16
            and ab_dtype not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}
            or acc_dtype == Int32
@@ -1888,10 +1928,10 @@ class PersistentDenseGemmKernel:
        ):
            is_valid = False
        if (
-            acc_dtype ==
+            acc_dtype == Float32
            and d_dtype
            not in {
-
+                Float32,
                cutlass.Float16,
                cutlass.BFloat16,
                cutlass.Float8E4M3FN,
@@ -1911,7 +1951,7 @@ class PersistentDenseGemmKernel:
            not in {
                cutlass.BFloat16,
                cutlass.Float16,
-
+                Float32,
                Int32,
                cutlass.Int8,
                cutlass.Uint8,
@@ -1964,7 +2004,7 @@ class PersistentDenseGemmKernel:

        # Check valid d_dtype
        if d_dtype not in {
-
+            Float32,
            cutlass.Float16,
            cutlass.BFloat16,
            cutlass.Float8E5M2,
@@ -2004,7 +2044,6 @@ class PersistentDenseGemmKernel:

    @staticmethod
    def is_valid_mma_tiler_and_cluster_shape(
-        use_2cta_instrs: bool,
        mma_tiler_mn: Tuple[int, int],
        cluster_shape_mn: Tuple[int, int],
        blockscaled: bool,
@@ -2012,8 +2051,6 @@ class PersistentDenseGemmKernel:
        """
        Check if the mma tiler and cluster shape are valid

-        :param use_2cta_instrs: Whether to use 2 CTA groups
-        :type use_2cta_instrs: bool
        :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
        :type mma_tiler_mn: Tuple[int, int]
        :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
@@ -2024,10 +2061,7 @@ class PersistentDenseGemmKernel:
        """
        is_valid = True
        # Skip invalid mma tile shape
-        if not
-            (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
-            or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
-        ):
+        if mma_tiler_mn[0] not in [64, 128, 256]:
            is_valid = False
        if not blockscaled:
            if mma_tiler_mn[1] not in range(32, 257, 32):
@@ -2035,9 +2069,6 @@ class PersistentDenseGemmKernel:
        else:
            if mma_tiler_mn[1] not in [128, 256]:
                is_valid = False
-        # Skip illegal cluster shape
-        if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
-            is_valid = False
        # Skip invalid cluster shape
        is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
        if (
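
With `use_2cta_instrs` removed from the signature, the MMA-M check collapses to a single membership test and the 2-CTA coupling is presumably derived elsewhere. A standalone comparison of the old and new predicates for a few candidate MMA-M values:

# Compare the old and new MMA-M validity predicates for a few candidate tiles.
def old_ok(m: int, use_2cta_instrs: bool) -> bool:
    return (not use_2cta_instrs and m in [64, 128]) or (use_2cta_instrs and m in [128, 256])

def new_ok(m: int) -> bool:
    return m in [64, 128, 256]

for m in (64, 128, 192, 256):
    print(m, old_ok(m, False), old_ok(m, True), new_ok(m))
# M = 256 was only valid together with 2-CTA MMA before; the new check accepts it unconditionally.
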
@@ -2113,7 +2144,6 @@ class PersistentDenseGemmKernel:
        ab_dtype: Type[cutlass.Numeric],
        acc_dtype: Type[cutlass.Numeric],
        d_dtype: Type[cutlass.Numeric],
-        use_2cta_instrs: bool,
        mma_tiler_mn: Tuple[int, int],
        cluster_shape_mn: Tuple[int, int],
        m: int,
@@ -2133,8 +2163,6 @@ class PersistentDenseGemmKernel:
        :type acc_dtype: Type[cutlass.Numeric]
        :param d_dtype: The data type of the output tensor
        :type d_dtype: Type[cutlass.Numeric]
-        :param use_2cta_instrs: Whether to use 2 CTA groups
-        :type use_2cta_instrs: bool
        :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
        :type mma_tiler_mn: Tuple[int, int]
        :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
@@ -2159,15 +2187,15 @@ class PersistentDenseGemmKernel:
        """
        can_implement = True
        # Skip unsupported types
-        if not
+        if not GemmSm100.is_valid_dtypes(ab_dtype, acc_dtype, d_dtype):
            can_implement = False
        # Skip invalid mma tile shape and cluster shape
-        if not
-
+        if not GemmSm100.is_valid_mma_tiler_and_cluster_shape(
+            mma_tiler_mn, cluster_shape_mn, blockscaled=False
        ):
            can_implement = False
        # Skip illegal problem shape for load/store alignment
-        if not
+        if not GemmSm100.is_valid_tensor_alignment(
            m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major
        ):
            can_implement = False
@@ -2186,7 +2214,6 @@ def run(
    c_major: str,
    mma_tiler_mn: Tuple[int, int] = (256, 256),
    cluster_shape_mn: Tuple[int, int] = (2, 1),
-    use_2cta_instrs: bool = True,
    tolerance: float = 1e-01,
    warmup_iterations: int = 0,
    iterations: int = 1,
@@ -2215,9 +2242,6 @@ def run(
    :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
        default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
    :type cluster_shape_mn: Tuple[int, int], optional
-    :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
-        will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
-    :type use_2cta_instrs: bool, optional
    :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
    :type tolerance: float, optional
    :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
@@ -2236,7 +2260,6 @@ def run(
    print(f"AB dtype: {ab_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
    print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
    print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
-    print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
    print(f"Tolerance: {tolerance}")
    print(f"Warmup iterations: {warmup_iterations}")
    print(f"Iterations: {iterations}")
@@ -2248,11 +2271,10 @@ def run(
    m, n, k, l = mnkl

    # Skip unsupported testcase
-    if not
+    if not GemmSm100.can_implement(
        ab_dtype,
        acc_dtype,
        d_dtype,
-        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
        m,
@@ -2264,7 +2286,7 @@ def run(
        d_major,
    ):
        raise TypeError(
-            f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {
+            f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}"
        )

    if not torch.cuda.is_available():
@@ -2339,12 +2361,8 @@ def run(
        c, mC, c_torch = None, None, None

    # Configure gemm kernel
-
-
-        use_2cta_instrs,
-        mma_tiler_mn,
-        cluster_shape_mn,
-    )
+    cluster_shape_mnk = (*cluster_shape_mn, 1)
+    gemm = GemmSm100(acc_dtype, mma_tiler_mn, cluster_shape_mnk)

    # Compute max active clusters on current device
    hardware_info = cutlass.utils.HardwareInfo()
@@ -2356,6 +2374,17 @@ def run(
    else:
        tile_count_semaphore = None

+    scheduler_args = TileSchedulerOptions(
+        Int32(max_active_clusters),
+        tile_count_semaphore=make_ptr(
+            Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
+        )
+        if tile_count_semaphore is not None
+        else None,
+    )
+    epi_args = gemm.EpilogueArguments()
+    varlen_args = VarlenArguments()
+
    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
@@ -2367,15 +2396,14 @@ def run(
        mB,
        mD,
        mC,
-
-
-
-        max_active_clusters,
+        epi_args,
+        scheduler_args,
+        varlen_args,
        current_stream,
    )

    if not skip_ref_check:
-        compiled_gemm(mA, mB, mD, mC,
+        compiled_gemm(mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream)
    if ab_dtype in {
        cutlass.Int8,
        cutlass.Uint8,
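
Taken together, the `run()` changes replace the loose `max_active_clusters` argument with three explicit bundles: `scheduler_args`, `epi_args`, and `varlen_args`. A condensed sketch of the resulting host-side flow, assembled from the hunks above (the `cute.compile` call and the `mA`/`mB`/`mD`/`mC` tensors are assumed from surrounding context rather than shown verbatim in this hunk):

# Condensed host-side flow for the 0.2.2 interface, assembled from the hunks above.
# cute.compile(...) and the mA/mB/mD/mC tensors are assumed from context.
gemm = GemmSm100(acc_dtype, mma_tiler_mn, (*cluster_shape_mn, 1))

scheduler_args = TileSchedulerOptions(
    Int32(max_active_clusters),
    tile_count_semaphore=make_ptr(
        Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
    )
    if tile_count_semaphore is not None
    else None,
)
epi_args = gemm.EpilogueArguments()   # default/empty epilogue arguments, as in the hunk above
varlen_args = VarlenArguments()       # default/empty variable-length arguments

compiled_gemm = cute.compile(
    gemm, mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream
)
compiled_gemm(mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream)
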
@@ -2393,7 +2421,7 @@ def run(
        gpu_d = d_torch.cpu()

        # Convert ref to c_type
-        if d_dtype ==
+        if d_dtype == Float32:
            ref_d = ref
        elif d_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
            # m major: (l, n, m) -> (m, n, l)
@@ -2463,7 +2491,9 @@ def run(
    print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")

    time.sleep(0.5)
-    fn = lambda: compiled_gemm(
+    fn = lambda: compiled_gemm(
+        mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream
+    )
    timing = do_bench(fn, warmup=warmup, rep=repeats)
    tflops = flops / (timing * 1e9) # Convert to TFlops
    print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
@@ -2505,12 +2535,7 @@ if __name__ == "__main__":
    parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
    parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
    parser.add_argument("--c_dtype", type=cutlass.dtype, default=None)
-    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=
-    parser.add_argument(
-        "--use_2cta_instrs",
-        action="store_true",
-        help="Enable 2CTA MMA instructions feature",
-    )
+    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=Float32)
    parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
    parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
    parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
@@ -2552,7 +2577,6 @@ if __name__ == "__main__":
        args.c_major,
        args.mma_tiler_mn,
        args.cluster_shape_mn,
-        args.use_2cta_instrs,
        args.tolerance,
        args.warmup_iterations,
        args.iterations,