quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,2562 @@
|
|
|
1
|
+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
3
|
+
|
|
4
|
+
# Redistribution and use in source and binary forms, with or without
|
|
5
|
+
# modification, are permitted provided that the following conditions are met:
|
|
6
|
+
|
|
7
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
8
|
+
# list of conditions and the following disclaimer.
|
|
9
|
+
|
|
10
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
11
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
12
|
+
# and/or other materials provided with the distribution.
|
|
13
|
+
|
|
14
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
15
|
+
# contributors may be used to endorse or promote products derived from
|
|
16
|
+
# this software without specific prior written permission.
|
|
17
|
+
|
|
18
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
19
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
20
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
21
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
22
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
23
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
24
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
25
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
26
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
27
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
28
|
+
|
|
29
|
+
import argparse
|
|
30
|
+
from typing import Optional, Type, Tuple, Union, Callable
|
|
31
|
+
from functools import partial
|
|
32
|
+
|
|
33
|
+
import cuda.bindings.driver as cuda
|
|
34
|
+
import torch
|
|
35
|
+
|
|
36
|
+
import cutlass
|
|
37
|
+
import cutlass.cute as cute
|
|
38
|
+
from cutlass.cute.nvgpu import cpasync, tcgen05
|
|
39
|
+
import cutlass.torch as cutlass_torch
|
|
40
|
+
import cutlass.pipeline as pipeline
|
|
41
|
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
|
42
|
+
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
|
43
|
+
from cutlass.cute.runtime import from_dlpack, make_ptr
|
|
44
|
+
from cutlass import Int32, const_expr
|
|
45
|
+
|
|
46
|
+
from quack.cute_dsl_utils import ParamsBase
|
|
47
|
+
from quack.tile_scheduler import (
|
|
48
|
+
TileSchedulerArguments,
|
|
49
|
+
TileScheduler,
|
|
50
|
+
RasterOrderOption,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
"""
|
|
54
|
+
A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
|
|
55
|
+
using CUTE DSL.
|
|
56
|
+
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
|
57
|
+
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
|
58
|
+
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
|
59
|
+
|
|
60
|
+
This GEMM kernel supports the following features:
|
|
61
|
+
- Utilizes Tensor Memory Access (TMA) for efficient memory operations
|
|
62
|
+
- Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions)
|
|
63
|
+
- Implements TMA multicast with cluster to reduce L2 memory traffic
|
|
64
|
+
- Supports persistent tile scheduling to better overlap memory load/store with MMA between tiles
|
|
65
|
+
- Supports warp specialization to avoid explicit pipelining between mainloop load and MMA
|
|
66
|
+
|
|
67
|
+
This GEMM works as follows:
|
|
68
|
+
1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
|
|
69
|
+
2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
|
|
70
|
+
3. EPILOGUE warp:
|
|
71
|
+
- Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
|
72
|
+
- Type convert C matrix to output type.
|
|
73
|
+
- Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
|
|
74
|
+
or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
|
|
75
|
+
- Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
|
|
76
|
+
e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
|
|
77
|
+
|
|
78
|
+
SM100 tcgen05.mma instructions operate as follows:
|
|
79
|
+
- Read matrix A from SMEM
|
|
80
|
+
- Read matrix B from SMEM
|
|
81
|
+
- Write accumulator to TMEM
|
|
82
|
+
The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
|
83
|
+
|
|
84
|
+
Input arguments to this example are the same as those of dense_gemm.py.
|
|
85
|
+
|
|
86
|
+
.. code-block:: bash
|
|
87
|
+
|
|
88
|
+
python examples/blackwell/dense_gemm_persistent.py \
|
|
89
|
+
--ab_dtype Float16 --d_dtype Float16 --acc_dtype Float32 \
|
|
90
|
+
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
|
|
91
|
+
--mnkl 8192,8192,8192,1 \
|
|
92
|
+
--use_2cta_instrs
|
|
93
|
+
|
|
94
|
+
To collect performance with NCU profiler:
|
|
95
|
+
|
|
96
|
+
.. code-block:: bash
|
|
97
|
+
|
|
98
|
+
ncu python examples/blackwell/dense_gemm_persistent.py \
|
|
99
|
+
--ab_dtype Float16 --d_dtype Float16 --acc_dtype Float32 \
|
|
100
|
+
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
|
|
101
|
+
--mnkl 8192,8192,8192,1 \
|
|
102
|
+
--use_2cta_instrs \
|
|
103
|
+
--warmup_iterations 1 --iterations 10 --skip_ref_check
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
Constraints are the same as for dense_gemm.py:
|
|
107
|
+
* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2),
|
|
108
|
+
see detailed valid dtype combinations in below PersistentDenseGemmKernel class documentation
|
|
109
|
+
* A/B tensor must have the same data type
|
|
110
|
+
* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
|
|
111
|
+
* Mma tiler N must be 32-256, step 32
|
|
112
|
+
* Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
|
113
|
+
* Cluster shape M must be multiple of 2 if use_2cta_instrs=True
|
|
114
|
+
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
|
115
|
+
i.e., the number of elements is a multiple of 4, 8, and 16 for TFloat32,
|
|
116
|
+
Float16/BFloat16, and Int8/Uint8/Float8, respectively.
|
|
117
|
+
* OOB tiles are not allowed when TMA store is disabled
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class PersistentDenseGemmKernel:
|
|
122
|
+
"""This class implements batched matrix multiplication (C = A x B) with support for various data types
|
|
123
|
+
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
|
|
124
|
+
|
|
125
|
+
:param acc_dtype: Data type for accumulation during computation
|
|
126
|
+
:type acc_dtype: type[cutlass.Numeric]
|
|
127
|
+
:param use_2cta_instrs: Whether to use the 2-CTA (cta_group=2) variant of the tcgen05 MMA instructions
|
|
128
|
+
:type use_2cta_instrs: bool
|
|
129
|
+
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
|
|
130
|
+
:type mma_tiler_mn: Tuple[int, int]
|
|
131
|
+
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
|
|
132
|
+
:type cluster_shape_mn: Tuple[int, int]
|
|
133
|
+
|
|
134
|
+
:note: In current version, A and B tensor must have the same data type
|
|
135
|
+
- i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
|
|
136
|
+
|
|
137
|
+
:note: Supported A/B data types:
|
|
138
|
+
- TFloat32
|
|
139
|
+
- Float16/BFloat16
|
|
140
|
+
- Int8/Uint8
|
|
141
|
+
- Float8E4M3FN/Float8E5M2
|
|
142
|
+
|
|
143
|
+
:note: Supported accumulator data types:
|
|
144
|
+
- Float32 (for all floating point A/B data types)
|
|
145
|
+
- Float16 (only for fp16 and fp8 A/B data types)
|
|
146
|
+
- Int32 (only for uint8/int8 A/B data types)
|
|
147
|
+
|
|
148
|
+
:note: Supported C data types:
|
|
149
|
+
- Float32 (for float32 and int32 accumulator data types)
|
|
150
|
+
- Int32 (for float32 and int32 accumulator data types)
|
|
151
|
+
- Float16/BFloat16 (for fp16 and fp8 accumulator data types)
|
|
152
|
+
- Int8/Uint8 (for uint8/int8 accumulator data types)
|
|
153
|
+
- Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
|
|
154
|
+
|
|
155
|
+
:note: Constraints:
|
|
156
|
+
- MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
|
|
157
|
+
- MMA tiler N must be 32-256, step 32
|
|
158
|
+
- Cluster shape M must be multiple of 2 if use_2cta_instrs=True
|
|
159
|
+
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
|
160
|
+
|
|
161
|
+
Example:
|
|
162
|
+
>>> gemm = PersistentDenseGemmKernel(
|
|
163
|
+
... acc_dtype=cutlass.Float32,
|
|
164
|
+
... use_2cta_instrs=True,
|
|
165
|
+
... mma_tiler_mn=(128, 128),
|
|
166
|
+
... cluster_shape_mn=(2, 2)
|
|
167
|
+
... )
|
|
168
|
+
>>> gemm(mA, mB, mD, None, None, max_active_clusters, stream)  # mC=None, tile_count_semaphore=None
|
|
169
|
+
"""
|
|
170
|
+
|
|
171
|
+
def __init__(
    self,
    acc_dtype: Type[cutlass.Numeric],
    use_2cta_instrs: bool,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    sf_vec_size: Optional[int] = None,
):
    """Record the static (input-independent) configuration of the kernel.

    Everything that depends on the actual operands (dtypes, major modes, the
    K extent of the MMA tiler, stage counts, shared-memory layouts, ...) is
    derived later in ``_setup_attributes``.

    :param acc_dtype: Data type of the MMA accumulator.
    :type acc_dtype: type[cutlass.Numeric]
    :param use_2cta_instrs: True selects the tcgen05 MMA variant with
        cta_group=2, False the cta_group=1 variant.
    :type use_2cta_instrs: bool
    :param mma_tiler_mn: (M, N) extent of the MMA tiler. A placeholder K of 1
        is stored until ``_setup_attributes`` computes the real value.
    :type mma_tiler_mn: Tuple[int, int]
    :param cluster_shape_mn: (ClusterM, ClusterN) shape of the CTA cluster.
    :type cluster_shape_mn: Tuple[int, int]
    :param sf_vec_size: Scale-factor vector size. Any non-None value turns on
        the block-scaled path (``self.blockscaled``); None keeps a plain GEMM.
    :type sf_vec_size: Optional[int]
    """
    # ---- MMA / cluster configuration ----
    self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
    self.use_2cta_instrs = use_2cta_instrs
    if use_2cta_instrs:
        self.cta_group = tcgen05.CtaGroup.TWO
    else:
        self.cta_group = tcgen05.CtaGroup.ONE
    self.cluster_shape_mn = cluster_shape_mn
    # Placeholder K = 1; the real K extent is filled in by _setup_attributes.
    self.mma_tiler = tuple(mma_tiler_mn) + (1,)

    # ---- Block-scaled (scale factor) configuration ----
    self.sf_vec_size = sf_vec_size
    self.blockscaled = sf_vec_size is not None

    self.occupancy = 1

    # ---- Warp specialization: warps 0-3 run the epilogue ----
    self.epilog_warp_id = tuple(range(4))
    self.mma_warp_id = 4
    self.tma_warp_id = 5
    self.tma_epi_warp_id = 6
    every_warp = (
        *self.epilog_warp_id,
        self.mma_warp_id,
        self.tma_warp_id,
        self.tma_epi_warp_id,
    )
    self.threads_per_cta = 32 * len(every_warp)

    # ---- Named barrier ids (CTA sync, epilogue sync, TMEM ptr sync, epilogue load) ----
    (
        self.cta_sync_bar_id,
        self.epilog_sync_bar_id,
        self.tmem_ptr_sync_bar_id,
        self.epilog_load_bar_id,
    ) = range(4)

    self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_100")
|
|
232
|
+
|
|
233
|
+
def _setup_attributes(self):
    """Set up the configuration values that depend on the GEMM inputs.

    Reads the operand properties that ``__call__`` records before calling this
    method (``a_dtype``, ``b_dtype``, ``d_dtype``, ``c_dtype``, ``sf_dtype``,
    ``a_major_mode``, ``b_major_mode``, ``d_layout``, ``c_layout``) together
    with the settings stored by ``__init__``, and derives:

    - the tiled MMA(s) and the MMA instruction / tiler / CTA tile shapes
      (the K extent of ``self.mma_tiler`` is fixed here)
    - the cluster layout(s)
    - the number of multicast CTAs for A/B (and SFB)
    - the epilogue subtile
    - the ACC/AB/D/C stage counts
    - the staged shared-memory layouts for A/B/D/C (and SFA/SFB)
    - the number of tensor-memory columns to allocate

    When the kernel is not block-scaled, every scale-factor attribute is
    still assigned a neutral value (``None`` for tilers/layouts, "no
    multicast" for SFB). Attribute access therefore never depends on the mode.
    """
    # ---- MMA instruction shapes ----
    # The instruction K extent is 256 bits worth of A elements.
    mma_inst_bits_k = 256
    # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K)
    self.mma_inst_shape_mnk = (
        self.mma_tiler[0],
        self.mma_tiler[1],
        mma_inst_bits_k // self.a_dtype.width,
    )
    # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K)
    # Used by the block-scaled SFB tiled MMA / tiler below.
    self.mma_inst_shape_mnk_sfb = (
        self.mma_inst_shape_mnk[0] // (2 if self.use_2cta_instrs else 1),
        cute.round_up(self.mma_inst_shape_mnk[1], 128),
        self.mma_inst_shape_mnk[2],
    )

    # ---- Tiled MMA(s) ----
    if const_expr(not self.blockscaled):
        tiled_mma = sm100_utils.make_trivial_tiled_mma(
            self.a_dtype,
            self.a_major_mode,
            self.b_major_mode,
            self.acc_dtype,
            self.cta_group,
            self.mma_tiler[:2],
        )
        tiled_mma_sfb = None
    else:
        tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
            self.a_dtype,
            self.a_major_mode,
            self.b_major_mode,
            self.sf_dtype,
            self.sf_vec_size,
            self.cta_group,
            self.mma_inst_shape_mnk[:2],
        )
        # The SFB tiled MMA always uses a single-CTA group.
        tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
            self.a_dtype,
            self.a_major_mode,
            self.b_major_mode,
            self.sf_dtype,
            self.sf_vec_size,
            cute.nvgpu.tcgen05.CtaGroup.ONE,
            self.mma_inst_shape_mnk_sfb[:2],
        )

    # ---- MMA / CTA tile shapes ----
    # The K extent of the MMA tiler spans 4 instruction-K blocks.
    mma_inst_tile_k = 4
    self.mma_tiler = (
        self.mma_inst_shape_mnk[0],
        self.mma_inst_shape_mnk[1],
        self.mma_inst_shape_mnk[2] * mma_inst_tile_k,
    )
    if const_expr(self.blockscaled):
        self.mma_tiler_sfb = (
            self.mma_inst_shape_mnk_sfb[0],
            self.mma_inst_shape_mnk_sfb[1],
            self.mma_inst_shape_mnk_sfb[2] * mma_inst_tile_k,
        )
    else:
        self.mma_tiler_sfb = None
    # With 2-CTA MMA the M extent of the MMA tile is split across the CTAs
    # of the MMA atom (size of tiled_mma.thr_id).
    self.cta_tile_shape_mnk = (
        self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
        self.mma_tiler[1],
        self.mma_tiler[2],
    )

    # ---- Cluster layout(s) ----
    self.cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout((*self.cluster_shape_mn, 1)),
        (tiled_mma.thr_id.shape,),
    )
    if const_expr(self.blockscaled):
        self.cluster_layout_sfb_vmnk = cute.tiled_divide(
            cute.make_layout((*self.cluster_shape_mn, 1)),
            (tiled_mma_sfb.thr_id.shape,),
        )
    else:
        self.cluster_layout_sfb_vmnk = None

    # ---- Multicast CTA counts ----
    # A is shared along the cluster N mode, B along the cluster M mode.
    self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
    self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
    self.is_a_mcast = self.num_mcast_ctas_a > 1
    self.is_b_mcast = self.num_mcast_ctas_b > 1
    if const_expr(self.blockscaled):
        self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
        self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
    else:
        # No scale factors: keep the attributes defined (like the other SFB
        # attributes) with "no multicast" values instead of leaving them unset.
        self.num_mcast_ctas_sfb = 1
        self.is_sfb_mcast = False

    # ---- Epilogue subtile ----
    self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
        self.cta_tile_shape_mnk,
        self.use_2cta_instrs,
        self.d_layout,
        self.d_dtype,
    )

    # ---- Stage counts: A/B/D/C in shared memory, ACC in tensor memory ----
    (
        self.num_acc_stage,
        self.num_ab_stage,
        self.num_d_stage,
        self.num_c_stage,
    ) = self._compute_stages(
        tiled_mma,
        self.mma_tiler,
        self.a_dtype,
        self.b_dtype,
        self.epi_tile,
        self.d_dtype,
        self.c_dtype,
        self.d_layout,
        self.c_layout,
        self.sf_dtype,
        self.sf_vec_size,
        self.smem_capacity,
        self.occupancy,
    )

    # ---- Staged shared-memory layouts for A/B/D/C/SFA/SFB ----
    self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
        tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage
    )
    self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
        tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage
    )
    self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
        self.d_dtype, self.d_layout, self.epi_tile, self.num_d_stage
    )
    if const_expr(self.c_dtype is not None):
        self.epi_c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
            self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage
        )
    else:
        self.epi_c_smem_layout_staged = None
    if const_expr(self.blockscaled):
        self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
            tiled_mma,
            self.mma_tiler,
            self.sf_vec_size,
            self.num_ab_stage,
        )
        self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
            tiled_mma,
            self.mma_tiler,
            self.sf_vec_size,
            self.num_ab_stage,
        )
    else:
        self.sfa_smem_layout_staged, self.sfb_smem_layout_staged = None, None

    # ---- Tensor-memory allocation columns ----
    if const_expr(not self.blockscaled):
        self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(
            tiled_mma, self.mma_tiler, self.num_acc_stage
        )
    else:
        # The block-scaled path allocates all 512 TMEM columns.
        SM100_TMEM_CAPACITY_COLUMNS = 512
        self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
|
|
406
|
+
|
|
407
|
+
@cute.jit
|
|
408
|
+
def __call__(
|
|
409
|
+
self,
|
|
410
|
+
mA: cute.Tensor,
|
|
411
|
+
mB: cute.Tensor,
|
|
412
|
+
mD: cute.Tensor,
|
|
413
|
+
mC: Optional[cute.Tensor],
|
|
414
|
+
tile_count_semaphore: Optional[cute.Pointer],
|
|
415
|
+
max_active_clusters: cutlass.Constexpr,
|
|
416
|
+
stream: cuda.CUstream,
|
|
417
|
+
mSFA: Optional[cute.Tensor] = None,
|
|
418
|
+
mSFB: Optional[cute.Tensor] = None,
|
|
419
|
+
epilogue_op: cutlass.Constexpr = lambda x: x,
|
|
420
|
+
):
|
|
421
|
+
"""Execute the GEMM operation in steps:
|
|
422
|
+
- Setup static attributes before smem/grid/tma computation
|
|
423
|
+
- Setup TMA load/store atoms and tensors
|
|
424
|
+
- Compute grid size with regard to hardware constraints
|
|
425
|
+
- Define shared storage for kernel
|
|
426
|
+
- Launch the kernel synchronously
|
|
427
|
+
|
|
428
|
+
:param mA: Input tensor A
|
|
429
|
+
:type mA: cute.Tensor
|
|
430
|
+
:param mB: Input tensor B
|
|
431
|
+
:type mB: cute.Tensor
|
|
432
|
+
:param mD: Output tensor D
|
|
433
|
+
:type mD: cute.Tensor
|
|
434
|
+
:param max_active_clusters: Maximum number of active clusters
|
|
435
|
+
:type max_active_clusters: cutlass.Constexpr
|
|
436
|
+
:param stream: CUDA stream for asynchronous execution
|
|
437
|
+
:type stream: cuda.CUstream
|
|
438
|
+
:param epilogue_op: Optional elementwise lambda function to apply to the output tensor
|
|
439
|
+
:type epilogue_op: cutlass.Constexpr
|
|
440
|
+
:raises TypeError: If input data types are incompatible with the MMA instruction.
|
|
441
|
+
:raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled.
|
|
442
|
+
"""
|
|
443
|
+
if const_expr(self.blockscaled):
|
|
444
|
+
assert mSFA is not None and mSFB is not None
|
|
445
|
+
# Setup static attributes before smem/grid/tma computation
|
|
446
|
+
self.a_dtype: Type[cutlass.Numeric] = mA.element_type
|
|
447
|
+
self.b_dtype: Type[cutlass.Numeric] = mB.element_type
|
|
448
|
+
self.d_dtype: Type[cutlass.Numeric] = mD.element_type
|
|
449
|
+
self.c_dtype = mC.element_type if mC is not None else None
|
|
450
|
+
self.sf_dtype: Optional[Type[cutlass.Numeric]] = (
|
|
451
|
+
mSFA.element_type if mSFA is not None else None
|
|
452
|
+
)
|
|
453
|
+
self.a_major_mode = cutlass.utils.LayoutEnum.from_tensor(mA).mma_major_mode()
|
|
454
|
+
self.b_major_mode = cutlass.utils.LayoutEnum.from_tensor(mB).mma_major_mode()
|
|
455
|
+
self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
|
|
456
|
+
self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None
|
|
457
|
+
|
|
458
|
+
# Check if input data types are compatible with MMA instruction
|
|
459
|
+
if const_expr(self.a_dtype != self.b_dtype):
|
|
460
|
+
raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
|
|
461
|
+
|
|
462
|
+
# Setup attributes that dependent on gemm inputs
|
|
463
|
+
self._setup_attributes()
|
|
464
|
+
|
|
465
|
+
if const_expr(self.blockscaled):
|
|
466
|
+
# Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout
|
|
467
|
+
# ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL)
|
|
468
|
+
sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(mA.shape, self.sf_vec_size)
|
|
469
|
+
mSFA = cute.make_tensor(mSFA.iterator, sfa_layout)
|
|
470
|
+
# ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
|
|
471
|
+
sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(mB.shape, self.sf_vec_size)
|
|
472
|
+
mSFB = cute.make_tensor(mSFB.iterator, sfb_layout)
|
|
473
|
+
|
|
474
|
+
if const_expr(not self.blockscaled):
|
|
475
|
+
tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
|
476
|
+
self.a_dtype,
|
|
477
|
+
self.a_major_mode,
|
|
478
|
+
self.b_major_mode,
|
|
479
|
+
self.acc_dtype,
|
|
480
|
+
self.cta_group,
|
|
481
|
+
self.mma_tiler[:2],
|
|
482
|
+
)
|
|
483
|
+
tiled_mma_sfb = None
|
|
484
|
+
else:
|
|
485
|
+
tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
|
486
|
+
self.a_dtype,
|
|
487
|
+
self.a_major_mode,
|
|
488
|
+
self.b_major_mode,
|
|
489
|
+
self.sf_dtype,
|
|
490
|
+
self.sf_vec_size,
|
|
491
|
+
self.cta_group,
|
|
492
|
+
self.mma_inst_shape_mnk[:2],
|
|
493
|
+
)
|
|
494
|
+
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
|
495
|
+
self.a_dtype,
|
|
496
|
+
self.a_major_mode,
|
|
497
|
+
self.b_major_mode,
|
|
498
|
+
self.sf_dtype,
|
|
499
|
+
self.sf_vec_size,
|
|
500
|
+
cute.nvgpu.tcgen05.CtaGroup.ONE,
|
|
501
|
+
self.mma_inst_shape_mnk_sfb[:2],
|
|
502
|
+
)
|
|
503
|
+
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
|
504
|
+
|
|
505
|
+
# Setup TMA load for A
|
|
506
|
+
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
|
507
|
+
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
|
508
|
+
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
|
509
|
+
a_op,
|
|
510
|
+
mA,
|
|
511
|
+
a_smem_layout,
|
|
512
|
+
self.mma_tiler,
|
|
513
|
+
tiled_mma,
|
|
514
|
+
self.cluster_layout_vmnk.shape,
|
|
515
|
+
internal_type=(cutlass.TFloat32 if mA.element_type is cutlass.Float32 else None),
|
|
516
|
+
)
|
|
517
|
+
|
|
518
|
+
# Setup TMA load for B
|
|
519
|
+
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
|
|
520
|
+
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
|
521
|
+
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
|
522
|
+
b_op,
|
|
523
|
+
mB,
|
|
524
|
+
b_smem_layout,
|
|
525
|
+
self.mma_tiler,
|
|
526
|
+
tiled_mma,
|
|
527
|
+
self.cluster_layout_vmnk.shape,
|
|
528
|
+
internal_type=(cutlass.TFloat32 if mB.element_type is cutlass.Float32 else None),
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
if const_expr(self.blockscaled):
|
|
532
|
+
# Setup TMA load for SFA
|
|
533
|
+
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
|
|
534
|
+
self.cluster_shape_mn, tiled_mma.thr_id
|
|
535
|
+
)
|
|
536
|
+
sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
|
|
537
|
+
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
|
|
538
|
+
sfa_op,
|
|
539
|
+
mSFA,
|
|
540
|
+
sfa_smem_layout,
|
|
541
|
+
self.mma_tiler,
|
|
542
|
+
tiled_mma,
|
|
543
|
+
self.cluster_layout_vmnk.shape,
|
|
544
|
+
internal_type=cutlass.Int16,
|
|
545
|
+
)
|
|
546
|
+
# Setup TMA load for SFB
|
|
547
|
+
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
|
|
548
|
+
self.cluster_shape_mn, tiled_mma.thr_id
|
|
549
|
+
)
|
|
550
|
+
sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
|
|
551
|
+
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
|
|
552
|
+
sfb_op,
|
|
553
|
+
mSFB,
|
|
554
|
+
sfb_smem_layout,
|
|
555
|
+
self.mma_tiler_sfb,
|
|
556
|
+
tiled_mma_sfb,
|
|
557
|
+
self.cluster_layout_sfb_vmnk.shape,
|
|
558
|
+
internal_type=cutlass.Int16,
|
|
559
|
+
)
|
|
560
|
+
else:
|
|
561
|
+
tma_atom_sfa, tma_tensor_sfa = None, None
|
|
562
|
+
tma_atom_sfb, tma_tensor_sfb = None, None
|
|
563
|
+
|
|
564
|
+
a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
|
|
565
|
+
b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
|
|
566
|
+
self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
|
|
567
|
+
if const_expr(self.blockscaled):
|
|
568
|
+
sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
|
|
569
|
+
sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
|
|
570
|
+
self.num_tma_load_bytes += (sfa_copy_size + sfb_copy_size) * atom_thr_size
|
|
571
|
+
|
|
572
|
+
# Setup TMA store for D
|
|
573
|
+
epi_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
|
|
574
|
+
tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
|
|
575
|
+
cpasync.CopyBulkTensorTileS2GOp(),
|
|
576
|
+
mD,
|
|
577
|
+
epi_smem_layout,
|
|
578
|
+
self.epi_tile,
|
|
579
|
+
)
|
|
580
|
+
if const_expr(mC is not None):
|
|
581
|
+
epi_c_smem_layout = cute.slice_(self.epi_c_smem_layout_staged, (None, None, 0))
|
|
582
|
+
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
|
|
583
|
+
cpasync.CopyBulkTensorTileG2SOp(),
|
|
584
|
+
mC,
|
|
585
|
+
epi_c_smem_layout,
|
|
586
|
+
self.epi_tile,
|
|
587
|
+
)
|
|
588
|
+
else:
|
|
589
|
+
tma_atom_c, tma_tensor_c = None, None
|
|
590
|
+
|
|
591
|
+
problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.cta_tile_shape_mnk[:2]) + (
|
|
592
|
+
mD.shape[2],
|
|
593
|
+
)
|
|
594
|
+
TileSchedulerCls = TileScheduler
|
|
595
|
+
tile_sched_args = TileSchedulerArguments(
|
|
596
|
+
problem_shape_ntile_mnl=problem_shape_ntile_mnl,
|
|
597
|
+
raster_order=RasterOrderOption.Heuristic,
|
|
598
|
+
group_size=8,
|
|
599
|
+
cluster_shape_mnk=(*self.cluster_shape_mn, 1),
|
|
600
|
+
tile_count_semaphore=tile_count_semaphore,
|
|
601
|
+
is_persistent=True,
|
|
602
|
+
)
|
|
603
|
+
tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
|
|
604
|
+
grid = TileSchedulerCls.get_grid_shape(tile_sched_params, max_active_clusters)
|
|
605
|
+
|
|
606
|
+
self.buffer_align_bytes = 1024
|
|
607
|
+
|
|
608
|
+
epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
|
|
609
|
+
sf_dtype = self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU
|
|
610
|
+
sfa_smem_size = (
|
|
611
|
+
cute.cosize(self.sfa_smem_layout_staged) if const_expr(self.blockscaled) else 0
|
|
612
|
+
)
|
|
613
|
+
sfb_smem_size = (
|
|
614
|
+
cute.cosize(self.sfb_smem_layout_staged) if const_expr(self.blockscaled) else 0
|
|
615
|
+
)
|
|
616
|
+
|
|
617
|
+
# Define shared storage for kernel
|
|
618
|
+
@cute.struct
|
|
619
|
+
class SharedStorage:
|
|
620
|
+
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
|
621
|
+
ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
|
622
|
+
epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage * 2]
|
|
623
|
+
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
|
624
|
+
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
|
625
|
+
tmem_dealloc_mbar_ptr: cutlass.Int64
|
|
626
|
+
tmem_holding_buf: Int32
|
|
627
|
+
sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
|
628
|
+
tile_count: cute.struct.MemRange[cutlass.Int32, 1]
|
|
629
|
+
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
|
630
|
+
sD: cute.struct.Align[
|
|
631
|
+
cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)],
|
|
632
|
+
self.buffer_align_bytes,
|
|
633
|
+
]
|
|
634
|
+
sC: cute.struct.Align[
|
|
635
|
+
cute.struct.MemRange[
|
|
636
|
+
self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
|
|
637
|
+
],
|
|
638
|
+
self.buffer_align_bytes,
|
|
639
|
+
]
|
|
640
|
+
# (MMA, MMA_M, MMA_K, STAGE)
|
|
641
|
+
sA: cute.struct.Align[
|
|
642
|
+
cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
|
|
643
|
+
self.buffer_align_bytes,
|
|
644
|
+
]
|
|
645
|
+
# (MMA, MMA_N, MMA_K, STAGE)
|
|
646
|
+
sB: cute.struct.Align[
|
|
647
|
+
cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
|
|
648
|
+
self.buffer_align_bytes,
|
|
649
|
+
]
|
|
650
|
+
# (MMA, MMA_M, MMA_K, STAGE)
|
|
651
|
+
sSFA: cute.struct.Align[
|
|
652
|
+
cute.struct.MemRange[sf_dtype, sfa_smem_size],
|
|
653
|
+
self.buffer_align_bytes,
|
|
654
|
+
]
|
|
655
|
+
# (MMA, MMA_N, MMA_K, STAGE)
|
|
656
|
+
sSFB: cute.struct.Align[
|
|
657
|
+
cute.struct.MemRange[sf_dtype, sfb_smem_size],
|
|
658
|
+
self.buffer_align_bytes,
|
|
659
|
+
]
|
|
660
|
+
|
|
661
|
+
self.shared_storage = SharedStorage
|
|
662
|
+
|
|
663
|
+
# Launch the kernel synchronously
|
|
664
|
+
self.kernel(
|
|
665
|
+
tiled_mma,
|
|
666
|
+
tiled_mma_sfb,
|
|
667
|
+
tma_atom_a,
|
|
668
|
+
tma_tensor_a,
|
|
669
|
+
tma_atom_b,
|
|
670
|
+
tma_tensor_b,
|
|
671
|
+
tma_atom_sfa,
|
|
672
|
+
tma_tensor_sfa,
|
|
673
|
+
tma_atom_sfb,
|
|
674
|
+
tma_tensor_sfb,
|
|
675
|
+
tma_atom_d,
|
|
676
|
+
tma_tensor_d,
|
|
677
|
+
tma_atom_c,
|
|
678
|
+
tma_tensor_c,
|
|
679
|
+
self.cluster_layout_vmnk,
|
|
680
|
+
self.cluster_layout_sfb_vmnk,
|
|
681
|
+
self.a_smem_layout_staged,
|
|
682
|
+
self.b_smem_layout_staged,
|
|
683
|
+
self.sfa_smem_layout_staged,
|
|
684
|
+
self.sfb_smem_layout_staged,
|
|
685
|
+
self.d_smem_layout_staged,
|
|
686
|
+
self.epi_c_smem_layout_staged,
|
|
687
|
+
self.epi_tile,
|
|
688
|
+
tile_sched_params,
|
|
689
|
+
TileSchedulerCls,
|
|
690
|
+
epilogue_op,
|
|
691
|
+
).launch(
|
|
692
|
+
grid=grid,
|
|
693
|
+
block=[self.threads_per_cta, 1, 1],
|
|
694
|
+
cluster=(*self.cluster_shape_mn, 1),
|
|
695
|
+
smem=self.shared_storage.size_in_bytes(),
|
|
696
|
+
stream=stream,
|
|
697
|
+
)
|
|
698
|
+
return
|
|
699
|
+
|
|
700
|
+
# GPU device kernel
|
|
701
|
+
@cute.kernel
def kernel(
    self,
    tiled_mma: cute.TiledMma,
    tiled_mma_sfb: Optional[cute.TiledMma],
    tma_atom_a: cute.CopyAtom,
    mA_mkl: cute.Tensor,
    tma_atom_b: cute.CopyAtom,
    mB_nkl: cute.Tensor,
    tma_atom_sfa: Optional[cute.CopyAtom],
    mSFA_mkl: Optional[cute.Tensor],
    tma_atom_sfb: Optional[cute.CopyAtom],
    mSFB_nkl: Optional[cute.Tensor],
    tma_atom_d: Optional[cute.CopyAtom],
    mD_mnl: cute.Tensor,
    tma_atom_c: Optional[cute.CopyAtom],
    mC_mnl: Optional[cute.Tensor],
    cluster_layout_vmnk: cute.Layout,
    cluster_layout_sfb_vmnk: Optional[cute.Layout],
    a_smem_layout_staged: cute.ComposedLayout,
    b_smem_layout_staged: cute.ComposedLayout,
    sfa_smem_layout_staged: Optional[cute.Layout],
    sfb_smem_layout_staged: Optional[cute.Layout],
    d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
    epi_c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
    epi_tile: cute.Tile,
    tile_sched_params: ParamsBase,
    TileSchedulerCls: cutlass.Constexpr[Callable],
    epilogue_op: cutlass.Constexpr[Callable],
):
    """GPU device kernel performing the persistent batched GEMM computation.

    The kernel is warp-specialized; each role below loops over the tiles handed
    out by the persistent tile scheduler:

    - ``self.tma_warp_id``: TMA-loads A/B (and SFA/SFB when ``self.blockscaled``)
      into shared memory, producing into ``ab_pipeline``.
    - ``self.tma_epi_warp_id`` (only when ``mC_mnl`` is not None): TMA-loads the
      C epilogue subtiles into shared memory, producing into ``epi_pipeline``.
    - ``self.mma_warp_id``: consumes ``ab_pipeline`` and runs ``self.mma`` into
      tensor-memory accumulators, producing into ``acc_pipeline``.
    - warps with ``warp_idx < self.mma_warp_id`` (the epilogue branch): allocate
      and free tensor memory, consume ``acc_pipeline``, apply ``epilogue_op``
      (and add C when given), convert to ``self.d_dtype`` and TMA-store D.

    ``tma_atom_sfa``/``tma_atom_sfb``/``mSFA_mkl``/``mSFB_nkl`` and the SF smem
    layouts are only used when ``self.blockscaled``; ``tma_atom_c``/``mC_mnl``/
    ``epi_c_smem_layout_staged`` only when ``mC_mnl`` is not None.
    """
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

    #
    # Prefetch TMA descriptors
    #
    if warp_idx == self.tma_warp_id:
        cpasync.prefetch_descriptor(tma_atom_a)
        cpasync.prefetch_descriptor(tma_atom_b)
        if const_expr(self.blockscaled):
            cpasync.prefetch_descriptor(tma_atom_sfa)
            cpasync.prefetch_descriptor(tma_atom_sfb)
        cpasync.prefetch_descriptor(tma_atom_d)
        # The C descriptor is used by the epi-load warp; prefetch it as well so
        # it is treated like every other TMA descriptor of this kernel.
        if const_expr(mC_mnl is not None):
            cpasync.prefetch_descriptor(tma_atom_c)

    # 2-CTA MMA instructions are used when the tiled MMA spans two CTAs.
    use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2

    #
    # Setup cta/thread coordinates
    #
    # Coords inside cluster
    bidx, _, _ = cute.arch.block_idx()
    mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
    is_leader_cta = mma_tile_coord_v == 0
    cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
    block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
    if const_expr(self.blockscaled):
        block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
            cta_rank_in_cluster
        )
    else:
        block_in_cluster_coord_sfb_vmnk = None
    # Coord inside cta
    tidx, _, _ = cute.arch.thread_idx()

    #
    # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
    #
    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(self.shared_storage)

    tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
    tmem_holding_buf = storage.tmem_holding_buf

    # Tensor memory dealloc barrier init (only needed for the 2-CTA handshake
    # performed before dealloc_tmem at the end of the epilogue branch).
    if use_2cta_instrs:
        if warp_idx == self.tma_warp_id:
            num_tmem_dealloc_threads = 32
            # A single lane initializes the mbarrier: re-initializing a live
            # mbarrier from every lane of the warp is not allowed.
            with cute.arch.elect_one():
                cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)
    # NOTE(review): no explicit cluster-wide sync follows the mbarrier inits in
    # this kernel; presumably pipeline .create() / the tile scheduler cover it
    # — TODO confirm for cluster sizes > 1.

    # Initialize mainloop ab_pipeline (barrier) and states
    ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
    num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
    ab_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, num_tma_producer
    )
    ab_pipeline = pipeline.PipelineTmaUmma.create(
        barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
        num_stages=self.num_ab_stage,
        producer_group=ab_pipeline_producer_group,
        consumer_group=ab_pipeline_consumer_group,
        tx_count=self.num_tma_load_bytes,
        cta_layout_vmnk=cluster_layout_vmnk,
    )

    if const_expr(mC_mnl is not None):
        # Threads/warps participating in this pipeline
        epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
        # Each epilogue warp contributes 1 to the arrive count (one elected lane
        # calls consumer_release per warp in the epilogue loop).
        consumer_arrive_cnt = len(self.epilog_warp_id)
        epi_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, consumer_arrive_cnt
        )
        c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
        tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
        epi_pipeline = pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
            num_stages=self.num_c_stage,
            producer_group=epi_pipeline_producer_group,
            consumer_group=epi_pipeline_consumer_group,
            tx_count=tma_copy_c_bytes,
        )
    else:
        epi_pipeline = None

    # Initialize acc_pipeline (barrier) and states. With 2-CTA instructions the
    # epilogue warps of both CTAs release the accumulator, hence the factor 2.
    acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
    num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
    acc_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, num_acc_consumer_threads
    )
    acc_pipeline = pipeline.PipelineUmmaAsync.create(
        barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
        num_stages=self.num_acc_stage,
        producer_group=acc_pipeline_producer_group,
        consumer_group=acc_pipeline_consumer_group,
        cta_layout_vmnk=cluster_layout_vmnk,
    )

    # TODO: a dynamic persistent scheduler (driven by a tile-count semaphore)
    # would need a scheduler pipeline backed by storage.sched_pipeline_array_ptr
    # and storage.tile_count; neither is set up or used by this kernel.

    # Setup smem tensor A/B/D
    # (MMA, MMA_M, MMA_K, STAGE)
    sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
    # (MMA, MMA_N, MMA_K, STAGE)
    sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
    if const_expr(self.blockscaled):
        # (MMA, MMA_M, MMA_K, STAGE)
        sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
        # (MMA, MMA_N, MMA_K, STAGE)
        sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
    else:
        sSFA, sSFB = None, None
    # (EPI_TILE_M, EPI_TILE_N, STAGE)
    sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
    if const_expr(mC_mnl is not None):
        sC = storage.sC.get_tensor(
            epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
        )
    else:
        sC = None

    thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
    thr_mma_sfb = (
        tiled_mma_sfb.get_slice(mma_tile_coord_v) if const_expr(self.blockscaled) else None
    )

    # (MMA, MMA_M, MMA_N)
    acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
    # (MMA, MMA_M, MMA_N, STAGE) -- layout only; rebuilt on the real TMEM pointer later
    tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))

    # The MMA warp and the epilogue warps sync here so that the TMEM address
    # written into tmem_holding_buf by alloc_tmem is visible before
    # retrieve_tmem_ptr reads it.
    tmem_ptr_read_threads = cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id))
    tmem_alloc_barrier = pipeline.NamedBarrier(
        barrier_id=self.tmem_ptr_sync_bar_id, num_threads=tmem_ptr_read_threads
    )

    # Every warp role builds its own scheduler instance from the same params.
    make_tile_scheduler = partial(TileSchedulerCls.create, tile_sched_params)
    k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.mma_tiler[2])

    # Two-warp barrier (A/B load warp + C load warp): the C load warp waits for
    # the A/B loads of the first tile to be issued before it starts loading C.
    if const_expr(mC_mnl is not None):
        epi_load_barrier = pipeline.NamedBarrier(
            barrier_id=int(self.epilog_load_bar_id), num_threads=2 * cute.arch.WARP_SIZE
        )
    else:
        epi_load_barrier = None

    #
    # Specialized TMA load warp
    #
    if warp_idx == self.tma_warp_id:
        # Compute multicast mask for A/B buffer full
        if const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
            a_mcast_mask = cpasync.create_tma_multicast_mask(
                cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
            )
            b_mcast_mask = cpasync.create_tma_multicast_mask(
                cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
            )
            if const_expr(self.blockscaled):
                sfa_mcast_mask = cpasync.create_tma_multicast_mask(
                    cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
                )
                sfb_mcast_mask = cpasync.create_tma_multicast_mask(
                    cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1
                )
            else:
                sfa_mcast_mask, sfb_mcast_mask = None, None
        else:
            a_mcast_mask, b_mcast_mask = None, None
            sfa_mcast_mask, sfb_mcast_mask = None, None

        # Persistent tile scheduling loop
        tile_scheduler = make_tile_scheduler()
        work_tile = tile_scheduler.initial_work_tile_info()
        ab_producer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.num_ab_stage
        )
        do_epi_load_barrier_arrive = cutlass.Boolean(True)
        while work_tile.is_valid_tile:
            # Get tile coord from tile scheduler
            tile_coord_mnkl = work_tile.tile_idx
            mma_tile_coord_mnl = (
                tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
                tile_coord_mnkl[1],
                tile_coord_mnkl[3],
            )
            # Local_tile partition global tensors
            # (bM, bK, RestK)
            gA_mkl = cute.local_tile(
                mA_mkl,
                cute.slice_(self.mma_tiler, (None, 0, None)),
                (mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]),
            )
            # (bN, bK, RestK)
            gB_nkl = cute.local_tile(
                mB_nkl,
                cute.slice_(self.mma_tiler, (0, None, None)),
                (mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]),
            )
            if const_expr(self.blockscaled):
                # (bM, bK)
                gSFA_mkl = cute.local_tile(
                    mSFA_mkl,
                    cute.slice_(self.mma_tiler, (None, 0, None)),
                    (mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]),
                )
                # (bN, bK)
                gSFB_nkl = cute.local_tile(
                    mSFB_nkl,
                    cute.slice_(self.mma_tiler, (0, None, None)),
                    (mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]),
                )
            # Partition global tensor for TiledMMA_A/B/D
            # (MMA, MMA_M, MMA_K, RestK)
            tCgA = thr_mma.partition_A(gA_mkl)
            # (MMA, MMA_N, MMA_K, RestK)
            tCgB = thr_mma.partition_B(gB_nkl)
            if const_expr(self.blockscaled):
                # (MMA, MMA_M, MMA_K)
                tCgSFA = thr_mma.partition_A(gSFA_mkl)
                # (MMA, MMA_N, MMA_K)
                tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
            # Partition global/shared tensor for TMA load A/B
            # TMA load A partition_S/D
            a_cta_layout = cute.make_layout(
                cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
            )
            # ((atom_v, rest_v), STAGE)
            # ((atom_v, rest_v), RestK)
            tAsA, tAgA = cpasync.tma_partition(
                tma_atom_a,
                block_in_cluster_coord_vmnk[2],
                a_cta_layout,
                cute.group_modes(sA, 0, 3),
                cute.group_modes(tCgA, 0, 3),
            )
            # TMA load B partition_S/D
            b_cta_layout = cute.make_layout(
                cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
            )
            # ((atom_v, rest_v), STAGE)
            # ((atom_v, rest_v), RestK)
            tBsB, tBgB = cpasync.tma_partition(
                tma_atom_b,
                block_in_cluster_coord_vmnk[1],
                b_cta_layout,
                cute.group_modes(sB, 0, 3),
                cute.group_modes(tCgB, 0, 3),
            )
            if const_expr(self.blockscaled):
                # TMA load SFA partition_S/D
                sfa_cta_layout = a_cta_layout
                # ((atom_v, rest_v), STAGE)
                # ((atom_v, rest_v), RestK)
                tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition(
                    tma_atom_sfa,
                    block_in_cluster_coord_vmnk[2],
                    sfa_cta_layout,
                    cute.group_modes(sSFA, 0, 3),
                    cute.group_modes(tCgSFA, 0, 3),
                )
                tAsSFA = cute.filter_zeros(tAsSFA)
                tAgSFA = cute.filter_zeros(tAgSFA)
                # TMA load SFB partition_S/D
                sfb_cta_layout = cute.make_layout(
                    cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape
                )
                # ((atom_v, rest_v), STAGE)
                # ((atom_v, rest_v), RestK)
                tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition(
                    tma_atom_sfb,
                    block_in_cluster_coord_sfb_vmnk[1],
                    sfb_cta_layout,
                    cute.group_modes(sSFB, 0, 3),
                    cute.group_modes(tCgSFB, 0, 3),
                )
                tBsSFB = cute.filter_zeros(tBsSFB)
                tBgSFB = cute.filter_zeros(tBgSFB)
            else:
                tAsSFA, tAgSFA = None, None
                tBsSFB, tBgSFB = None, None
            ab_producer_state = self.load_AB(
                ab_pipeline,
                ab_producer_state,
                tma_atom_a,
                tAgA,
                tAsA,
                a_mcast_mask,
                tma_atom_b,
                tBgB,
                tBsB,
                b_mcast_mask,
                tma_atom_sfa,
                tAgSFA,
                tAsSFA,
                sfa_mcast_mask,
                tma_atom_sfb,
                tBgSFB,
                tBsSFB,
                sfb_mcast_mask,
            )
            if const_expr(epi_load_barrier is not None):
                # In the first work tile, the epi load warp will wait for the signal
                # from the mainloop load warp to start loading C, to avoid interfering
                # with loading A and B.
                if do_epi_load_barrier_arrive:
                    epi_load_barrier.arrive()
                    do_epi_load_barrier_arrive = cutlass.Boolean(False)
            # Advance to next tile
            tile_scheduler.advance_to_next_work()
            work_tile = tile_scheduler.get_current_work()
        # Wait A/B buffer empty
        ab_pipeline.producer_tail(ab_producer_state)

    #
    # Specialized TMA epi load warp
    #
    if const_expr(mC_mnl is not None):
        if warp_idx == self.tma_epi_warp_id:
            epi_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.num_c_stage
            )
            do_epi_load_barrier_wait = cutlass.Boolean(True)
            # Persistent tile scheduling loop
            tile_scheduler = make_tile_scheduler()
            work_tile = tile_scheduler.initial_work_tile_info()
            while work_tile.is_valid_tile:
                # Get tile coord from tile scheduler
                tile_coord_mnkl = work_tile.tile_idx
                mma_tile_coord_mnl = (
                    tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
                    tile_coord_mnkl[1],
                    tile_coord_mnkl[3],
                )
                # Local_tile partition global tensors
                # (bM, bN)
                gC_mnl = cute.local_tile(
                    mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
                )
                # Partition global tensor for TiledMMA_A/B/D
                # (MMA, MMA_M, MMA_N)
                tCgC = thr_mma.partition_C(gC_mnl)
                # bGS_gC has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
                bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
                    tma_atom_c, tCgC, epi_tile, sC
                )
                bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
                if do_epi_load_barrier_wait:
                    epi_load_barrier.arrive_and_wait()
                    do_epi_load_barrier_wait = cutlass.Boolean(False)
                epi_tile_num = const_expr(cute.size(bGS_gC, mode=[1]))
                for subtile_idx in cutlass.range(epi_tile_num, unroll=1):
                    epi_pipeline.producer_acquire(epi_producer_state)
                    cute.copy(
                        tma_atom_c,
                        bGS_gC[None, subtile_idx],
                        bGS_sC[None, epi_producer_state.index],
                        tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
                    )
                    # Epi pipeline's producer commit is a NOP
                    epi_pipeline.producer_commit(epi_producer_state)
                    epi_producer_state.advance()
                # Advance to next tile
                tile_scheduler.advance_to_next_work()
                work_tile = tile_scheduler.get_current_work()
            # End of persistent scheduler loop
            epi_pipeline.producer_tail(epi_producer_state)

    #
    # Specialized MMA warp
    #
    if warp_idx == self.mma_warp_id:
        tmem_alloc_barrier.arrive_and_wait()
        # Retrieving tensor memory ptr and make accumulator tensor
        acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
            self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
        )
        # Partition shared/tensor memory tensor for TiledMMA_A/B/D
        # (MMA, MMA_M, MMA_K, STAGE)
        tCrA = tiled_mma.make_fragment_A(sA)
        # (MMA, MMA_N, MMA_K, STAGE)
        tCrB = tiled_mma.make_fragment_B(sB)
        # (MMA, MMA_M, MMA_N, STAGE)
        tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)

        if const_expr(self.blockscaled):
            # Make SFA tmem tensor: placed right after the accumulator columns
            sfa_tmem_ptr = cute.recast_ptr(
                acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base),
                dtype=self.sf_dtype,
            )
            # (MMA, MMA_M, MMA_K)
            tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
                tiled_mma,
                self.mma_tiler,
                self.sf_vec_size,
                cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
            )
            tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)

            # Make SFB tmem tensor: placed right after the SFA columns
            sfb_tmem_ptr = cute.recast_ptr(
                acc_tmem_ptr
                + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
                + tcgen05.find_tmem_tensor_col_offset(tCtSFA),
                dtype=self.sf_dtype,
            )
            # (MMA, MMA_N, MMA_K)
            tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
                tiled_mma,
                self.mma_tiler,
                self.sf_vec_size,
                cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
            )
            tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
            # Partition for S2T copy of SFA/SFB
            (
                tiled_copy_s2t_sfa,
                tCsSFA_compact_s2t,
                tCtSFA_compact_s2t,
            ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA)
            (
                tiled_copy_s2t_sfb,
                tCsSFB_compact_s2t,
                tCtSFB_compact_s2t,
            ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
        else:
            tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None
            tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None

        # Persistent tile scheduling loop
        tile_scheduler = make_tile_scheduler()
        work_tile = tile_scheduler.initial_work_tile_info()
        ab_consumer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.num_ab_stage
        )
        acc_producer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.num_acc_stage
        )
        while work_tile.is_valid_tile:
            # Set tensor memory buffer for current tile
            # (MMA, MMA_M, MMA_N)
            tCtAcc = tCtAcc_base[None, None, None, acc_producer_state.index]
            ab_consumer_state, acc_producer_state, tiled_mma = self.mma(
                ab_pipeline,
                acc_pipeline,
                ab_consumer_state,
                acc_producer_state,
                tiled_mma,
                tCrA,
                tCrB,
                tCtAcc,
                k_tile_cnt,
                is_leader_cta,
                tiled_copy_s2t_sfa,
                tiled_copy_s2t_sfb,
                tCsSFA_compact_s2t,
                tCsSFB_compact_s2t,
                tCtSFA_compact_s2t,
                tCtSFB_compact_s2t,
            )
            # Advance to next tile
            tile_scheduler.advance_to_next_work()
            work_tile = tile_scheduler.get_current_work()

        # Wait for accumulator buffer empty
        acc_pipeline.producer_tail(acc_producer_state)

    #
    # Specialized epilogue warps
    #
    if warp_idx < self.mma_warp_id:
        # Alloc tensor memory buffer (address is written to tmem_holding_buf)
        if warp_idx == self.epilog_warp_id[0]:
            cute.arch.alloc_tmem(
                self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs
            )
        # Bar sync for retrieve tensor memory ptr from shared memory
        tmem_alloc_barrier.arrive_and_wait()
        # Retrieving tensor memory ptr and make accumulator tensor
        acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
            self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
        )
        # (MMA, MMA_M, MMA_N, STAGE)
        tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)

        epilog_threads = cute.arch.WARP_SIZE * len(self.epilog_warp_id)
        epilogue_barrier = pipeline.NamedBarrier(
            barrier_id=self.epilog_sync_bar_id, num_threads=epilog_threads
        )

        # Partition for epilogue
        epi_tidx = tidx
        tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
            epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
        )

        tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.d_dtype)
        tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
            tiled_copy_t2r, tTR_rD, epi_tidx, sD
        )
        if const_expr(mC_mnl is not None):
            tTR_rC = cute.make_fragment_like(tTR_rD, self.c_dtype)
            tiled_copy_s2r, tSR_rC, tSR_sC = self.epilog_smem_copy_and_partition(
                tiled_copy_t2r, tTR_rC, epi_tidx, sC
            )
            # TODO: for m major, D is being stored w STSM so we'd need LDSM here
            # C is read into a fragment with the same layout as tRS_rD so it can
            # be added elementwise to the accumulator vector below.
            tRS_rC = cute.make_fragment(tRS_rD.layout, self.c_dtype)
            tSR_rC = tiled_copy_s2r.get_slice(epi_tidx).retile(tRS_rC)

        # Persistent tile scheduling loop
        tile_scheduler = make_tile_scheduler()
        work_tile = tile_scheduler.initial_work_tile_info()
        acc_consumer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.num_acc_stage
        )
        # Threads/warps participating in tma store pipeline
        d_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            epilog_threads,
            epilog_threads,
        )
        d_pipeline = pipeline.PipelineTmaStore.create(
            num_stages=self.num_d_stage, producer_group=d_producer_group
        )
        epi_read_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.num_c_stage
        )

        while work_tile.is_valid_tile:
            # Get tile coord from tile scheduler
            tile_coord_mnkl = work_tile.tile_idx
            mma_tile_coord_mnl = (
                tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
                tile_coord_mnkl[1],
                tile_coord_mnkl[3],
            )
            # Local_tile partition global tensors
            # (bM, bN)
            gD_mnl = cute.local_tile(
                mD_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
            )
            # Partition global tensor for TiledMMA_A/B/D
            # (MMA, MMA_M, MMA_N)
            tDgD = thr_mma.partition_C(gD_mnl)
            # bSG_gD has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
            bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(tma_atom_d, tDgD, epi_tile, sD)

            # Set tensor memory buffer for current tile
            # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
            tTR_tAcc = tTR_tAcc_base[None, None, None, None, None, acc_consumer_state.index]

            # Wait for accumulator buffer full
            acc_pipeline.consumer_wait(acc_consumer_state)

            tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
            bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))

            # Store accumulator to global memory in subtiles
            subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
            # D smem stages are used round-robin across all subtiles of all tiles
            # processed so far by this CTA.
            num_prev_subtiles = tile_scheduler.num_tiles_executed * subtile_cnt
            for subtile_idx in cutlass.range(subtile_cnt):
                # Load accumulator from tensor memory buffer to register
                tTR_tAcc_mn = tTR_tAcc[None, None, None, subtile_idx]
                cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
                # Apply the epilogue op on the accumulator values
                acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
                acc_vec = epilogue_op(acc_vec)
                if const_expr(mC_mnl is not None):
                    epi_pipeline.consumer_wait(epi_read_state)
                    cute.copy(
                        tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
                    )
                    # Order the generic-proxy smem reads of C above before the
                    # async-proxy (TMA) write that may refill this stage once it
                    # is released below.
                    cute.arch.fence_proxy(
                        cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                    )
                    cute.arch.sync_warp()
                    # One arrive per warp, matching the epi_pipeline consumer count.
                    with cute.arch.elect_one():
                        epi_pipeline.consumer_release(epi_read_state)
                    epi_read_state.advance()
                    acc_vec = acc_vec + tRS_rC.load().to(self.acc_dtype)
                # Convert to D type
                tRS_rD.store(acc_vec.to(self.d_dtype))
                # Store D to shared memory
                d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage
                cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
                # Fence and barrier to make sure shared memory store is visible to TMA store
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                )
                epilogue_barrier.arrive_and_wait()
                # TMA store D to global memory
                if warp_idx == self.epilog_warp_id[0]:
                    cute.copy(tma_atom_d, bSG_sD[None, d_buffer], bSG_gD[None, subtile_idx])
                    # Commit the TMA store just issued, then wait until a D smem
                    # stage is free for reuse (presumably: at most num_d_stage - 1
                    # stores still reading smem — TODO confirm PipelineTmaStore).
                    d_pipeline.producer_commit()
                    d_pipeline.producer_acquire()
                # No epilogue warp may overwrite a D smem stage before the store
                # warp has finished waiting above.
                epilogue_barrier.arrive_and_wait()

            # Async arrive accumulator buffer empty
            with cute.arch.elect_one():
                acc_pipeline.consumer_release(acc_consumer_state)
            acc_consumer_state.advance()

            # Advance to next tile
            tile_scheduler.advance_to_next_work()
            work_tile = tile_scheduler.get_current_work()

        # Dealloc the tensor memory buffer
        if warp_idx == self.epilog_warp_id[0]:
            cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
        epilogue_barrier.arrive_and_wait()
        if warp_idx == self.epilog_warp_id[0]:
            if use_2cta_instrs:
                # Handshake with the peer CTA (rank ^ 1) before freeing the
                # 2-CTA tensor-memory allocation.
                cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1)
                cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
            cute.arch.dealloc_tmem(
                acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs
            )

        # Wait for D store complete
        d_pipeline.producer_tail()
|
|
1381
|
+
|
|
1382
|
+
@cute.jit
def load_AB(
    self,
    ab_pipeline: cutlass.pipeline.PipelineAsync,
    ab_producer_state: cutlass.pipeline.PipelineState,
    tma_atom_a: cute.CopyAtom,
    tAgA: cute.Tensor,
    tAsA: cute.Tensor,
    a_mcast_mask: cutlass.Int16,
    tma_atom_b: cute.CopyAtom,
    tBgB: cute.Tensor,
    tBsB: cute.Tensor,
    b_mcast_mask: cutlass.Int16,
    tma_atom_sfa: Optional[cute.CopyAtom] = None,
    tAgSFA: Optional[cute.Tensor] = None,
    tAsSFA: Optional[cute.Tensor] = None,
    sfa_mcast_mask: Optional[cutlass.Int16] = None,
    tma_atom_sfb: Optional[cute.CopyAtom] = None,
    tBgSFB: Optional[cute.Tensor] = None,
    tBsSFB: Optional[cute.Tensor] = None,
    sfb_mcast_mask: Optional[cutlass.Int16] = None,
) -> cutlass.pipeline.PipelineState:
    """Issue the TMA loads of A and B (plus SFA/SFB when block-scaled) for one tile.

    For every k-tile (mode 1 of ``tAgA``) this acquires the next free stage of
    ``ab_pipeline``, issues the TMA copies into that stage, commits, and advances
    the producer state.

    Block-scaled mode is enabled by passing ``tma_atom_sfa``; the SFA and SFB
    atom/gmem/smem arguments are then all required.

    :param ab_pipeline: Pipeline guarding the A/B(/SFA/SFB) shared-memory stages.
    :param ab_producer_state: Producer state to start from.
    :param tma_atom_a: TMA copy atom for A.
    :param tAgA: TMA-partitioned A source; mode 1 indexes k-tiles.
    :param tAsA: TMA-partitioned A destination, indexed by pipeline stage.
    :param a_mcast_mask: Multicast mask passed to the A copy.
    :param tma_atom_b: TMA copy atom for B.
    :param tBgB: TMA-partitioned B source; mode 1 indexes k-tiles.
    :param tBsB: TMA-partitioned B destination, indexed by pipeline stage.
    :param b_mcast_mask: Multicast mask passed to the B copy.
    :param tma_atom_sfa: TMA copy atom for SFA (block-scaled only).
    :param tAgSFA: TMA-partitioned SFA source; mode 1 indexes k-tiles.
    :param tAsSFA: TMA-partitioned SFA destination, indexed by pipeline stage.
    :param sfa_mcast_mask: Multicast mask passed to the SFA copy.
    :param tma_atom_sfb: TMA copy atom for SFB (block-scaled only).
    :param tBgSFB: TMA-partitioned SFB source; mode 1 indexes k-tiles.
    :param tBsSFB: TMA-partitioned SFB destination, indexed by pipeline stage.
    :param sfb_mcast_mask: Multicast mask passed to the SFB copy.
    :return: The advanced producer state, to be carried into the next call.
    """
    blockscaled = const_expr(tma_atom_sfa is not None)
    if const_expr(blockscaled):
        assert all(x is not None for x in (tma_atom_sfa, tAgSFA, tAsSFA))
        assert all(x is not None for x in (tma_atom_sfb, tBgSFB, tBsSFB))
    k_tile_cnt = cute.size(tAgA, mode=[1])
    # Peek (try_wait) AB buffer empty for the first k-tile
    peek_ab_empty_status = cutlass.Boolean(True)
    if 0 < k_tile_cnt:
        peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    # /////////////////////////////////////////////////////////////////////////
    # TMA load
    # /////////////////////////////////////////////////////////////////////////
    for k_tile in cutlass.range(k_tile_cnt, unroll=1):
        # Wait for A/B buffers to be empty before loading into them
        # Also sets the transaction barrier for the A/B buffers
        ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
        cute.copy(
            tma_atom_a,
            tAgA[None, k_tile],
            tAsA[None, ab_producer_state.index],
            tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
            mcast_mask=a_mcast_mask,
        )
        cute.copy(
            tma_atom_b,
            tBgB[None, k_tile],
            tBsB[None, ab_producer_state.index],
            tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
            mcast_mask=b_mcast_mask,
        )
        if const_expr(blockscaled):
            # Scale factors are indexed by the same k-tile as A/B. The producer
            # state's `count` is a running counter that is not reset here and
            # keeps growing across tiles, so it must not be used as the k-index.
            cute.copy(
                tma_atom_sfa,
                tAgSFA[None, k_tile],
                tAsSFA[None, ab_producer_state.index],
                tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
                mcast_mask=sfa_mcast_mask,
            )
            cute.copy(
                tma_atom_sfb,
                tBgSFB[None, k_tile],
                tBsSFB[None, ab_producer_state.index],
                tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
                mcast_mask=sfb_mcast_mask,
            )
        # Mainloop pipeline's producer commit is a NOP
        ab_pipeline.producer_commit(ab_producer_state)
        ab_producer_state.advance()
        # Peek the stage for the next k-tile; default to "empty" when there is none
        peek_ab_empty_status = cutlass.Boolean(True)
        if k_tile + 1 < k_tile_cnt:
            peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    return ab_producer_state
|
|
1456
|
+
|
|
1457
|
+
@cute.jit
def mma(
    self,
    ab_pipeline: cutlass.pipeline.PipelineAsync,
    acc_pipeline: cutlass.pipeline.PipelineAsync,
    ab_consumer_state: cutlass.pipeline.PipelineState,
    acc_producer_state: cutlass.pipeline.PipelineState,
    tiled_mma: cute.TiledMma,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    acc: cute.Tensor,
    k_tile_cnt: Int32,
    is_leader_cta: cutlass.Boolean,
    tiled_copy_s2t_sfa: Optional[cute.TiledCopy] = None,
    tiled_copy_s2t_sfb: Optional[cute.TiledCopy] = None,
    tCsSFA_compact_s2t: Optional[cute.Tensor] = None,
    tCsSFB_compact_s2t: Optional[cute.Tensor] = None,
    tCtSFA_compact_s2t: Optional[cute.Tensor] = None,
    tCtSFB_compact_s2t: Optional[cute.Tensor] = None,
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma]:
    """Run the MMA mainloop for one output tile, accumulating into ``acc``.

    Consumes ``k_tile_cnt`` A/B stages from ``ab_pipeline`` and produces one
    accumulator stage on ``acc_pipeline``. Only the leader CTA waits on the
    pipelines and issues MMAs. Every CTA advances both pipeline states.

    Block-scaled mode is enabled by passing ``tiled_copy_s2t_sfa``; all six
    scale-factor arguments are then required. For each k-tile, that stage's
    SFA/SFB are copied from shared to tensor memory before the MMAs.

    :return: ``(ab_consumer_state, acc_producer_state, tiled_mma)``. ``tiled_mma``
        is returned because its ACCUMULATE field is mutated here; not returning
        it triggers the compiler error "operand #0 does not dominate this use".
    """
    blockscaled = const_expr(tiled_copy_s2t_sfa is not None)
    if const_expr(blockscaled):
        sf_args = (
            tiled_copy_s2t_sfa,
            tiled_copy_s2t_sfb,
            tCsSFA_compact_s2t,
            tCsSFB_compact_s2t,
            tCtSFA_compact_s2t,
            tCtSFB_compact_s2t,
        )
        assert all(arg is not None for arg in sf_args)

    # Non-blocking probe of the first A/B stage
    ab_full_peek = cutlass.Boolean(True)
    if 0 < k_tile_cnt and is_leader_cta:
        ab_full_peek = ab_pipeline.consumer_try_wait(ab_consumer_state)

    # The leader waits until an accumulator stage is free
    if is_leader_cta:
        acc_pipeline.producer_acquire(acc_producer_state)

    # The first MMA of the tile overwrites acc instead of adding to it
    tiled_mma.set(tcgen05.Field.ACCUMULATE, False)

    kblocks_per_tile = cute.size(tCrA, mode=[2])
    for k_tile in cutlass.range(k_tile_cnt, unroll=1):
        if is_leader_cta:
            # Conditional wait: the peeked status is passed along
            ab_pipeline.consumer_wait(ab_consumer_state, ab_full_peek)
            if const_expr(blockscaled):
                # Move this stage's scale factors from smem into tmem
                sf_stage_crd = (None, None, None, None, ab_consumer_state.index)
                cute.copy(
                    tiled_copy_s2t_sfa, tCsSFA_compact_s2t[sf_stage_crd], tCtSFA_compact_s2t
                )
                cute.copy(
                    tiled_copy_s2t_sfb, tCsSFB_compact_s2t[sf_stage_crd], tCtSFB_compact_s2t
                )
            # NOTE(review): the tcgen05.Field.SFA / SFB tmem addresses are not set on
            # tiled_mma per k-block here (CUTLASS's block-scaled example does so inside
            # this loop) — confirm they are configured elsewhere.
            for kblk in cutlass.range(kblocks_per_tile, unroll_full=True):
                ab_crd = (None, None, kblk, ab_consumer_state.index)
                cute.gemm(tiled_mma, acc, tCrA[ab_crd], tCrB[ab_crd], acc)
                tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
            # Async arrive: this A/B stage may now be refilled
            ab_pipeline.consumer_release(ab_consumer_state)
        ab_consumer_state.advance()
        # Probe the stage for the next k-tile; default to "full" when there is none
        ab_full_peek = cutlass.Boolean(True)
        if k_tile + 1 < k_tile_cnt and is_leader_cta:
            ab_full_peek = ab_pipeline.consumer_try_wait(ab_consumer_state)

    # Async arrive: the accumulator stage now holds the finished tile
    if is_leader_cta:
        acc_pipeline.producer_commit(acc_producer_state)
    acc_producer_state.advance()
    return ab_consumer_state, acc_producer_state, tiled_mma
|
|
1522
|
+
|
|
1523
|
+
def mainloop_s2t_copy_and_partition(
    self,
    sSF: cute.Tensor,
    tSF: cute.Tensor,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
    """Build the smem -> tmem (S2T) copy for a scale-factor tensor and partition both sides.

    Both tensors are first compacted with ``cute.filter_zeros``. A
    ``tcgen05.Cp4x32x128bOp`` copy atom (built for ``self.cta_group`` and
    ``self.sf_dtype``) is then tiled over the tmem tensor, and both sides are
    partitioned with the slice for thread 0.

    :param sSF: Scale-factor tensor in shared memory (with a trailing STAGE mode).
    :type sSF: cute.Tensor
    :param tSF: Scale-factor tensor in tensor memory.
    :type tSF: cute.Tensor
    :return: ``(tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t)``:
        - tiled_copy_s2t: the S2T tiled copy
        - tCsSF_compact_s2t: the smem source partition, converted to smem descriptors
        - tCtSF_compact_s2t: the tmem destination partition
    :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
    """
    # (MMA, MMA_MN, MMA_K, STAGE)
    smem_sf = cute.filter_zeros(sSF)
    # (MMA, MMA_MN, MMA_K)
    tmem_sf = cute.filter_zeros(tSF)
    s2t_atom = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype)
    s2t_tiled = tcgen05.make_s2t_copy(s2t_atom, tmem_sf)
    s2t_thr = s2t_tiled.get_slice(0)
    # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
    smem_desc = tcgen05.get_s2t_smem_desc_tensor(s2t_tiled, s2t_thr.partition_S(smem_sf))
    # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K)
    tmem_part = s2t_thr.partition_D(tmem_sf)
    return s2t_tiled, smem_desc, tmem_part
|
|
1557
|
+
|
|
1558
|
+
def epilog_tmem_copy_and_partition(
    self,
    tidx: Int32,
    tAcc: cute.Tensor,
    epi_tile: cute.Tile,
    use_2cta_instrs: Union[cutlass.Boolean, bool],
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
    """Build the tmem -> register (T2R) copy for the accumulator and partition it.

    The accumulator is split into epilogue subtiles of shape ``epi_tile``. The
    copy is built for a single subtile, and a register fragment large enough for
    one subtile is allocated for thread ``tidx``.

    :param tidx: The thread index in the epilogue warp groups.
    :type tidx: Int32
    :param tAcc: The accumulator tensor in tensor memory.
    :type tAcc: cute.Tensor
    :param epi_tile: The epilogue tiler.
    :type epi_tile: cute.Tile
    :param use_2cta_instrs: Whether 2-CTA MMA instructions are in use.
    :type use_2cta_instrs: bool
    :return: ``(tiled_copy_t2r, tTR_tAcc, tTR_rAcc)``:
        - tiled_copy_t2r: the T2R tiled copy
        - tTR_tAcc: this thread's partition of the accumulator (source)
        - tTR_rAcc: register fragment (``self.acc_dtype``) for one subtile
    :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
    """
    t2r_atom = sm100_utils.get_tmem_load_op(
        self.cta_tile_shape_mnk,
        self.d_layout,
        self.d_dtype,
        self.acc_dtype,
        epi_tile,
        use_2cta_instrs,
    )
    # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
    acc_by_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile)
    # The copy is built for one (EPI_TILE_M, EPI_TILE_N) subtile
    t2r_tiled = tcgen05.make_tmem_copy(t2r_atom, acc_by_epi[(None, None, 0, 0, 0)])
    t2r_thr = t2r_tiled.get_slice(tidx)
    # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
    tmem_src = t2r_thr.partition_S(acc_by_epi)

    # Coordinate tensor over the CTA's M x N tile; only used to derive the
    # shape of the per-thread register fragment
    tile_coords = cute.make_identity_tensor(
        (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1])
    )
    # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
    coords_part = t2r_thr.partition_D(cute.flat_divide(tile_coords, epi_tile))
    # (T2R, T2R_M, T2R_N)
    reg_acc = cute.make_fragment(coords_part[None, None, None, 0, 0].shape, self.acc_dtype)
    return t2r_tiled, tmem_src, reg_acc
|
|
1609
|
+
|
|
1610
|
+
def epilog_smem_copy_and_partition(
    self,
    tiled_copy_t2r: cute.TiledCopy,
    tTR_rD: cute.Tensor,
    tidx: Int32,
    sD: cute.Tensor,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
    """Build the register -> smem (R2S) copy for D and partition both sides.

    The R2S copy is tiled with ``cute.make_tiled_copy_D`` against the T2R copy,
    so the register fragment produced by the T2R load can be retiled directly.

    :param tiled_copy_t2r: The tmem -> register tiled copy.
    :type tiled_copy_t2r: cute.TiledCopy
    :param tTR_rD: Register fragment of D laid out for the T2R copy.
    :type tTR_rD: cute.Tensor
    :param tidx: The thread index in the epilogue warp groups.
    :type tidx: Int32
    :param sD: The shared-memory D buffer (with a trailing pipeline-stage mode).
    :type sD: cute.Tensor
    :return: ``(tiled_copy_r2s, tRS_rD, tRS_sD)``:
        - tiled_copy_r2s: the R2S tiled copy
        - tRS_rD: ``tTR_rD`` retiled as the R2S source
        - tRS_sD: this thread's partition of ``sD`` (R2S destination)
    :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
    """
    r2s_atom = sm100_utils.get_smem_store_op(
        self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r
    )
    r2s_tiled = cute.make_tiled_copy_D(r2s_atom, tiled_copy_t2r)
    # (R2S, R2S_M, R2S_N, PIPE_D)
    smem_dst = r2s_tiled.get_slice(tidx).partition_D(sD)
    # (R2S, R2S_M, R2S_N)
    reg_src = r2s_tiled.retile(tTR_rD)
    return r2s_tiled, reg_src, smem_dst
|
|
1646
|
+
|
|
1647
|
+
# def epilog_smem_load_copy_and_partition(
|
|
1648
|
+
# self,
|
|
1649
|
+
# tiled_copy_t2r: cute.TiledCopy,
|
|
1650
|
+
# tTR_rC: cute.Tensor,
|
|
1651
|
+
# tidx: Int32,
|
|
1652
|
+
# sC: cute.Tensor,
|
|
1653
|
+
# ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
1654
|
+
# copy_atom_s2r = cute.make_copy_atom(
|
|
1655
|
+
# warp.LdMatrix8x8x16bOp(self.c_layout.is_m_major_c(), num_matrices=4),
|
|
1656
|
+
# self.c_dtype, # TODO: this probably only works for f16 for now?
|
|
1657
|
+
# )
|
|
1658
|
+
# # copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
|
|
1659
|
+
# tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
|
|
1660
|
+
# # (R2S, R2S_M, R2S_N, PIPE_D)
|
|
1661
|
+
# thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
|
|
1662
|
+
# # (R2S, R2S_M, R2S_N)
|
|
1663
|
+
# tSR_sC = thr_copy_s2r.partition_S(sC)
|
|
1664
|
+
# return tiled_copy_s2r, tSR_sC
|
|
1665
|
+
|
|
1666
|
+
def epilog_gmem_copy_and_partition(
    self,
    atom: Union[cute.CopyAtom, cute.TiledCopy],
    gD_mnl: cute.Tensor,
    epi_tile: cute.Tile,
    sD: cute.Tensor,
) -> Tuple[cute.Tensor, cute.Tensor]:
    """Partition the D tile in shared and global memory for the TMA store.

    The global tile is divided into ``epi_tile`` subtiles. The two intra-subtile
    modes of both the smem buffer and the global tile are grouped, and both are
    partitioned with ``cpasync.tma_partition`` using a single-CTA layout.

    :param atom: TMA copy atom used to store D.
    :type atom: cute.CopyAtom or cute.TiledCopy
    :param gD_mnl: The global D tile; its ``((None, None), 0, 0)`` slice is used.
    :type gD_mnl: cute.Tensor
    :param epi_tile: The epilogue tiler.
    :type epi_tile: cute.Tile
    :param sD: The shared-memory D buffer; its first two modes hold one subtile.
    :type sD: cute.Tensor
    :return: ``(bSG_sD, bSG_gD)`` — the TMA partitions of the smem buffer
        (source) and the global tile (destination).
    :rtype: Tuple[cute.Tensor, cute.Tensor]
    """
    # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
    gD_subtiles = cute.flat_divide(gD_mnl[((None, None), 0, 0)], epi_tile)
    smem_grouped = cute.group_modes(sD, 0, 2)
    gmem_grouped = cute.group_modes(gD_subtiles, 0, 2)
    # ((ATOM_V, REST_V), EPI_M, EPI_N)
    smem_part, gmem_part = cpasync.tma_partition(
        atom,
        0,  # CTA coordinate within the layout below
        cute.make_layout(1),  # single-CTA layout, i.e. no multicast
        smem_grouped,
        gmem_grouped,
    )
    return smem_part, gmem_part
|
|
1710
|
+
|
|
1711
|
+
@staticmethod
def _compute_stages(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: Tuple[int, int, int],
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    epi_tile: cute.Tile,
    d_dtype: Type[cutlass.Numeric],
    c_dtype: Optional[Type[cutlass.Numeric]],
    d_layout: cutlass.utils.LayoutEnum,
    c_layout: Optional[cutlass.utils.LayoutEnum],
    sf_dtype: Optional[Type[cutlass.Numeric]],
    sf_vec_size: Optional[int],
    smem_capacity: int,
    occupancy: int,
) -> Tuple[int, int, int, int]:
    """Compute stage counts for the accumulator, A/B mainloop and D/C epilogue pipelines.

    The A/B stage count fills the shared memory available to one CTA
    (``smem_capacity // occupancy``) after reserving 1024 bytes for barriers and
    the initial epilogue buffers. Any shared memory still left over goes to
    extra D stages.

    :param tiled_mma: The tiled MMA object defining the core computation.
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
    :type mma_tiler_mnk: tuple[int, int, int]
    :param a_dtype: Data type of operand A.
    :type a_dtype: type[cutlass.Numeric]
    :param b_dtype: Data type of operand B.
    :type b_dtype: type[cutlass.Numeric]
    :param epi_tile: The epilogue tile shape.
    :type epi_tile: cute.Tile
    :param d_dtype: Data type of the output D.
    :type d_dtype: type[cutlass.Numeric]
    :param c_dtype: Data type of the optional epilogue input C, or None when there is no C.
    :type c_dtype: Optional[type[cutlass.Numeric]]
    :param d_layout: Layout enum of D.
    :type d_layout: cutlass.utils.LayoutEnum
    :param c_layout: Layout enum of C (only used when ``c_dtype`` is not None).
    :type c_layout: Optional[cutlass.utils.LayoutEnum]
    :param sf_dtype: Scale-factor data type; None selects the non-block-scaled path.
    :type sf_dtype: Optional[type[cutlass.Numeric]]
    :param sf_vec_size: Scale-factor vector size (block-scaled path only).
    :type sf_vec_size: Optional[int]
    :param smem_capacity: Total available shared memory capacity in bytes.
    :type smem_capacity: int
    :param occupancy: Target number of CTAs per SM (occupancy).
    :type occupancy: int

    :return: ``(num_acc_stage, num_ab_stage, num_d_stage, num_c_stage)``:
        - num_acc_stage: 2, or 1 for block-scaled GEMM with MMA tile N == 256
        - num_ab_stage: A/B(/SFA/SFB) stages that fit in the per-CTA budget; not
          validated here, so it can be < 1 if the epilogue alone exhausts smem
        - num_d_stage: 2 plus any extra stages that fit in the leftover smem
        - num_c_stage: 2 if ``c_dtype`` is given, else 0
    :rtype: tuple[int, int, int, int]
    """
    blockscaled = sf_dtype is not None
    # Default ACC stages
    if const_expr(not blockscaled):
        num_acc_stage = 2
    else:
        num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2

    # Default D stages
    num_d_stage = 2
    num_c_stage = 2 if c_dtype is not None else 0

    # Calculate smem layout and size for one stage of A, B, D and (optionally) C
    a_smem_layout_staged_one = sm100_utils.make_smem_layout_a(
        tiled_mma,
        mma_tiler_mnk,
        a_dtype,
        1,  # a tmp 1 stage is provided
    )
    b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
        tiled_mma,
        mma_tiler_mnk,
        b_dtype,
        1,  # a tmp 1 stage is provided
    )
    d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
    c_smem_layout_staged_one = (
        sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1)
        if c_dtype is not None
        else None
    )
    if const_expr(blockscaled):
        sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(
            tiled_mma,
            mma_tiler_mnk,
            sf_vec_size,
            1,  # a tmp 1 stage is provided
        )
        sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(
            tiled_mma,
            mma_tiler_mnk,
            sf_vec_size,
            1,  # a tmp 1 stage is provided
        )

    ab_bytes_per_stage = cute.size_in_bytes(
        a_dtype, a_smem_layout_staged_one
    ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
    if const_expr(blockscaled):
        # Scale factors share the A/B pipeline stages
        ab_bytes_per_stage += cute.size_in_bytes(
            sf_dtype, sfa_smem_layout_staged_one
        ) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
    mbar_helpers_bytes = 1024
    d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
    epi_bytes = d_bytes_per_stage * num_d_stage
    if const_expr(c_dtype is not None):
        c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
        epi_bytes += c_bytes_per_stage * num_c_stage

    # Calculate A/B/SFA/SFB stages:
    # Start with total smem per CTA (capacity / occupancy)
    # Subtract reserved bytes and initial epilogue (D/C) stage bytes
    # Divide remaining by bytes needed per A/B/SFA/SFB stage
    num_ab_stage = (
        smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)
    ) // ab_bytes_per_stage

    # Refine epilogue stages:
    # Calculate remaining smem after allocating for A/B stages and reserved bytes
    # Add remaining unused smem to epilogue D stages
    num_d_stage += (
        smem_capacity
        - occupancy * ab_bytes_per_stage * num_ab_stage
        - occupancy * (mbar_helpers_bytes + epi_bytes)
    ) // (occupancy * d_bytes_per_stage)
    return num_acc_stage, num_ab_stage, num_d_stage, num_c_stage
|
|
1827
|
+
|
|
1828
|
+
@staticmethod
def _compute_num_tmem_alloc_cols(
    tiled_mma: cute.TiledMma,
    mma_tiler: Tuple[int, int, int],
    num_acc_stage: int,
) -> int:
    """Return how many tensor-memory columns to allocate for the staged accumulator.

    A placeholder accumulator fragment is built with the per-stage C shape of
    ``mma_tiler[:2]`` plus a trailing mode of ``num_acc_stage`` stages. Its
    column count is then queried with ``cutlass.utils.get_num_tmem_alloc_cols``.

    :param tiled_mma: The tiled MMA object defining the core computation.
    :type tiled_mma: cute.TiledMma
    :param mma_tiler: The shape (M, N, K) of the MMA tile.
    :type mma_tiler: tuple[int, int, int]
    :param num_acc_stage: Number of accumulator stages.
    :type num_acc_stage: int

    :return: The number of tensor memory allocation columns.
    :rtype: int
    """
    per_stage_shape = tiled_mma.partition_shape_C(mma_tiler[:2])
    staged_acc = tiled_mma.make_fragment_C(cute.append(per_stage_shape, num_acc_stage))
    return cutlass.utils.get_num_tmem_alloc_cols(staged_acc)
|
|
1851
|
+
|
|
1852
|
+
@staticmethod
def is_valid_dtypes(
    ab_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    d_dtype: Type[cutlass.Numeric],
) -> bool:
    """Check whether an (A/B, accumulator, output) dtype combination is supported.

    Accepted A/B dtypes: Float16, BFloat16, TFloat32, Uint8, Int8, Float8E4M3FN,
    Float8E5M2. Per accumulator dtype:

    - Float32: any accepted A/B dtype; D in {Float32, Float16, BFloat16,
      Float8E4M3FN, Float8E5M2, Int32, Int8, Uint8}.
    - Float16: A/B in {Float16, Float8E4M3FN, Float8E5M2}; D in {BFloat16, Float16}.
    - Int32: A/B in {Uint8, Int8}; D in {BFloat16, Float16, Float32, Int32, Int8, Uint8}.
    - Any other accumulator dtype is rejected.

    :param ab_dtype: The data type of the A and B operands
    :type ab_dtype: Type[cutlass.Numeric]
    :param acc_dtype: The data type of the accumulator
    :type acc_dtype: Type[cutlass.Numeric]
    :param d_dtype: The data type of the output tensor
    :type d_dtype: Type[cutlass.Numeric]

    :return: True if the dtypes are valid, False otherwise
    :rtype: bool
    """
    fp8_types = {cutlass.Float8E4M3FN, cutlass.Float8E5M2}
    int8_types = {cutlass.Uint8, cutlass.Int8}
    supported_ab = {cutlass.Float16, cutlass.BFloat16, cutlass.TFloat32, *int8_types, *fp8_types}
    # accumulator dtype -> (allowed A/B dtypes, or None for "any supported"; allowed D dtypes)
    rules_by_acc = {
        cutlass.Float32: (
            None,
            {cutlass.Float32, cutlass.Float16, cutlass.BFloat16, *fp8_types, Int32, *int8_types},
        ),
        cutlass.Float16: (
            {cutlass.Float16, *fp8_types},
            {cutlass.BFloat16, cutlass.Float16},
        ),
        Int32: (
            int8_types,
            {cutlass.BFloat16, cutlass.Float16, cutlass.Float32, Int32, *int8_types},
        ),
    }

    ab_ok = ab_dtype in supported_ab
    rule = rules_by_acc.get(acc_dtype)
    if rule is None:
        # Unsupported accumulator dtype
        return False
    ab_for_acc, d_for_acc = rule
    acc_ok = ab_for_acc is None or ab_dtype in ab_for_acc
    d_ok = d_dtype in d_for_acc
    return ab_ok and acc_ok and d_ok
|
|
1922
|
+
|
|
1923
|
+
@staticmethod
def is_valid_dtypes_and_scale_factor_vec_size(
    ab_dtype: Type[cutlass.Numeric],
    sf_dtype: Type[cutlass.Numeric],
    sf_vec_size: int,
    d_dtype: Type[cutlass.Numeric],
) -> bool:
    """Check whether the block-scaled dtypes and scale-factor vector size are compatible.

    All of the following must hold:

    - A/B dtype is Float4E2M1FN, Float8E5M2 or Float8E4M3FN.
    - ``sf_vec_size`` is 16 or 32.
    - Scale-factor dtype is Float8E8M0FNU or Float8E4M3FN.
    - Float8E4M3FN scale factors are not combined with ``sf_vec_size == 32``.
    - FP8 A/B operands are not combined with ``sf_vec_size == 16``.
    - Output dtype is Float32, Float16, BFloat16, Float8E5M2 or Float8E4M3FN.

    :param ab_dtype: The data type of the A and B operands
    :type ab_dtype: Type[cutlass.Numeric]
    :param sf_dtype: The data type of the scale factor
    :type sf_dtype: Type[cutlass.Numeric]
    :param sf_vec_size: The vector size of the scale factor
    :type sf_vec_size: int
    :param d_dtype: The data type of the output tensor
    :type d_dtype: Type[cutlass.Numeric]

    :return: True if the dtypes and sf_vec_size are valid, False otherwise
    :rtype: bool
    """
    fp8_types = {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
    # Every check is evaluated, in order, before combining the results
    checks = (
        ab_dtype in {cutlass.Float4E2M1FN, *fp8_types},
        sf_vec_size in {16, 32},
        sf_dtype in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN},
        not (sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32),
        not (ab_dtype in fp8_types and sf_vec_size == 16),
        d_dtype
        in {
            cutlass.Float32,
            cutlass.Float16,
            cutlass.BFloat16,
            cutlass.Float8E5M2,
            cutlass.Float8E4M3FN,
        },
    )
    return all(checks)
|
|
1976
|
+
|
|
1977
|
+
@staticmethod
def is_valid_layouts(
    ab_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
) -> bool:
    """Check whether the A/B major modes are supported for the given operand dtype.

    Only Float4E2M1FN operands are restricted: they require both A and B to be
    K-major. All other dtypes are accepted with any layout.

    :param ab_dtype: The data type of the A and B operands
    :type ab_dtype: Type[cutlass.Numeric]
    :param a_major: The major mode of the A tensor (e.g. ``"k"`` or ``"m"``)
    :type a_major: str
    :param b_major: The major mode of the B tensor (e.g. ``"k"`` or ``"n"``)
    :type b_major: str

    :return: True if the layouts are valid, False otherwise
    :rtype: bool
    """
    if ab_dtype is not cutlass.Float4E2M1FN:
        return True
    return bool(a_major == "k" and b_major == "k")
|
|
2004
|
+
|
|
2005
|
+
@staticmethod
def is_valid_mma_tiler_and_cluster_shape(
    use_2cta_instrs: bool,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    blockscaled: bool,
) -> bool:
    """Check whether the MMA tiler and cluster shape are supported.

    Rules:

    - Tile M is 64 or 128 for 1-CTA MMA, and 128 or 256 for 2-CTA MMA.
    - Tile N is a multiple of 32 in [32, 256], or exactly 128 or 256 when block-scaled.
    - With 2-CTA MMA, cluster M must be even.
    - Both cluster dims are positive powers of two, and their product is at most 16.
    - Block-scaled kernels allow at most 4 CTAs per cluster dim, because the
      scale factors are too small to multicast among more CTAs.

    :param use_2cta_instrs: Whether to use 2 CTA groups
    :type use_2cta_instrs: bool
    :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
    :type mma_tiler_mn: Tuple[int, int]
    :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
    :type cluster_shape_mn: Tuple[int, int]
    :param blockscaled: Whether the kernel is block-scaled
    :type blockscaled: bool

    :return: True if the mma tiler and cluster shape are valid, False otherwise
    :rtype: bool
    """
    tile_m, tile_n = mma_tiler_mn[0], mma_tiler_mn[1]
    cluster_m, cluster_n = cluster_shape_mn[0], cluster_shape_mn[1]

    def _is_pow2(x):
        return x > 0 and (x & (x - 1)) == 0

    tile_m_ok = tile_m in ((128, 256) if use_2cta_instrs else (64, 128))
    tile_n_ok = tile_n in ((128, 256) if blockscaled else range(32, 257, 32))
    cta_pair_ok = cluster_m % (2 if use_2cta_instrs else 1) == 0
    cluster_ok = (
        cluster_m * cluster_n <= 16
        and cluster_m > 0
        and cluster_n > 0
        and _is_pow2(cluster_m)
        and _is_pow2(cluster_n)
    )
    sf_mcast_ok = not blockscaled or (cluster_m <= 4 and cluster_n <= 4)
    return tile_m_ok and tile_n_ok and cta_pair_ok and cluster_ok and sf_mcast_ok
|
|
2057
|
+
|
|
2058
|
+
@staticmethod
def is_valid_tensor_alignment(
    m: int,
    n: int,
    k: int,
    l: int,
    ab_dtype: Type[cutlass.Numeric],
    d_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
    d_major: str,
) -> bool:
    """Check that the contiguous extent of A, B and D is a multiple of 16 bytes.

    A has shape (m, k, l), B (n, k, l) and D (m, n, l). For each tensor, the
    contiguous extent is mode 0 when that tensor is mode-0 major (``a_major == "m"``,
    ``b_major == "n"``, ``d_major == "m"``) and mode 1 otherwise. That extent must
    be divisible by the number of elements that fit in 16 bytes.

    :param m: Size of the M mode (A and D)
    :type m: int
    :param n: Size of the N mode (B and D)
    :type n: int
    :param k: Size of the K mode (A and B)
    :type k: int
    :param l: Size of the L (batch) mode shared by A, B and D
    :type l: int
    :param ab_dtype: The data type of the A and B operands
    :type ab_dtype: Type[cutlass.Numeric]
    :param d_dtype: The data type of the output tensor
    :type d_dtype: Type[cutlass.Numeric]
    :param a_major: The major axis of the A tensor
    :type a_major: str
    :param b_major: The major axis of the B tensor
    :type b_major: str
    :param d_major: The major axis of the D tensor
    :type d_major: str

    :return: True if the problem shape is valid, False otherwise
    :rtype: bool
    """

    def _is_16B_aligned(dtype, mode0_is_major, shape):
        contiguous_extent = shape[0] if mode0_is_major else shape[1]
        elems_per_16B = 16 * 8 // dtype.width
        return contiguous_extent % elems_per_16B == 0

    operands = (
        (ab_dtype, a_major == "m", (m, k, l)),
        (ab_dtype, b_major == "n", (n, k, l)),
        (d_dtype, d_major == "m", (m, n, l)),
    )
    return all(_is_16B_aligned(*operand) for operand in operands)
|
|
2110
|
+
|
|
2111
|
+
@staticmethod
def can_implement(
    ab_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    d_dtype: Type[cutlass.Numeric],
    use_2cta_instrs: bool,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    m: int,
    n: int,
    k: int,
    l: int,
    a_major: str,
    b_major: str,
    d_major: str,
) -> bool:
    """
    Decide whether this kernel supports the given GEMM configuration.

    Three independent checks are run, in order: dtype support, MMA tiler /
    cluster shape validity (non-blockscaled), and 16-byte alignment of the
    contiguous dimension of A, B and D.

    :param ab_dtype: Data type of the A and B operands
    :type ab_dtype: Type[cutlass.Numeric]
    :param acc_dtype: Data type of the accumulator
    :type acc_dtype: Type[cutlass.Numeric]
    :param d_dtype: Data type of the output tensor D
    :type d_dtype: Type[cutlass.Numeric]
    :param use_2cta_instrs: Whether 2-CTA MMA instructions are used
    :type use_2cta_instrs: bool
    :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
    :type mma_tiler_mn: Tuple[int, int]
    :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
    :type cluster_shape_mn: Tuple[int, int]
    :param m: Extent of mode 0 of A and D
    :type m: int
    :param n: Extent of mode 0 of B and mode 1 of D
    :type n: int
    :param k: Reduction extent shared by A and B
    :type k: int
    :param l: Batch extent
    :type l: int
    :param a_major: Major (contiguous) mode of A, "m" or "k"
    :type a_major: str
    :param b_major: Major (contiguous) mode of B, "n" or "k"
    :type b_major: str
    :param d_major: Major (contiguous) mode of D, "m" or "n"
    :type d_major: str

    :return: True if every check passes, False otherwise
    :rtype: bool
    """
    kernel_cls = PersistentDenseGemmKernel
    # Supported (A/B, accumulator, output) dtype combination.
    dtypes_supported = kernel_cls.is_valid_dtypes(ab_dtype, acc_dtype, d_dtype)
    # MMA tile shape and cluster shape constraints for the plain (non-blockscaled) GEMM.
    tiling_supported = kernel_cls.is_valid_mma_tiler_and_cluster_shape(
        use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, blockscaled=False
    )
    # Load/store alignment of the contiguous dimension of every operand.
    alignment_supported = kernel_cls.is_valid_tensor_alignment(
        m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major
    )
    return all((dtypes_supported, tiling_supported, alignment_supported))
|
|
2175
|
+
|
|
2176
|
+
|
|
2177
|
+
def run(
    mnkl: Tuple[int, int, int, int],
    ab_dtype: Type[cutlass.Numeric],
    d_dtype: Type[cutlass.Numeric],
    c_dtype: Optional[Type[cutlass.Numeric]],
    acc_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
    d_major: str,
    c_major: str,
    mma_tiler_mn: Tuple[int, int] = (256, 256),
    cluster_shape_mn: Tuple[int, int] = (2, 1),
    use_2cta_instrs: bool = True,
    tolerance: float = 1e-01,
    warmup_iterations: int = 0,
    iterations: int = 1,
    skip_ref_check: bool = False,
    dynamic_persistent: bool = False,
    **kwargs,
):
    """Execute, validate and benchmark the persistent batched dense GEMM on Blackwell.

    Builds random A and B operands (and an optional addend C). It then compiles
    ``PersistentDenseGemmKernel`` for the requested configuration. Unless
    ``skip_ref_check`` is set, it checks D against a CPU fp32 ``einsum``
    reference (plus C when given). Finally it times a cuBLAS baseline and the
    CuTe DSL kernel with ``triton.testing.do_bench``.

    :param mnkl: Problem size (M, N, K, L); L is the batch extent
    :type mnkl: Tuple[int, int, int, int]
    :param ab_dtype: Data type for input tensors A and B
    :type ab_dtype: Type[cutlass.Numeric]
    :param d_dtype: Data type for the output tensor D
    :type d_dtype: Type[cutlass.Numeric]
    :param c_dtype: Data type of the optional addend C, or None to run without C
    :type c_dtype: Optional[Type[cutlass.Numeric]]
    :param acc_dtype: Data type for accumulation during matrix multiplication
    :type acc_dtype: Type[cutlass.Numeric]
    :param a_major: Major mode of A, "m" or "k"
    :param b_major: Major mode of B, "n" or "k"
    :param d_major: Major mode of D, "m" or "n"
    :param c_major: Major mode of C, "m" or "n" (ignored when c_dtype is None)
    :type a_major/b_major/d_major/c_major: str
    :param mma_tiler_mn: MMA tiler (M, N) shape, defaults to (256, 256)
    :type mma_tiler_mn: Tuple[int, int], optional
    :param cluster_shape_mn: CTA cluster shape (M, N), defaults to (2, 1)
    :type cluster_shape_mn: Tuple[int, int], optional
    :param use_2cta_instrs: Whether to use 2-CTA MMA instructions, defaults to True
    :type use_2cta_instrs: bool, optional
    :param tolerance: Absolute tolerance (``atol``) of the reference comparison;
        ``rtol`` is fixed at 1e-05. Defaults to 1e-01
    :type tolerance: float, optional
    :param warmup_iterations: Forwarded to ``do_bench`` as ``warmup``. Triton
        interprets this as a warmup duration in milliseconds, not an iteration count
    :type warmup_iterations: int, optional
    :param iterations: Forwarded to ``do_bench`` as ``rep`` (a duration in milliseconds)
    :type iterations: int, optional
    :param skip_ref_check: Whether to skip reference result validation, defaults to False
    :type skip_ref_check: bool, optional
    :param dynamic_persistent: Must be False; dynamic persistent scheduling is not supported yet
    :type dynamic_persistent: bool, optional
    :param kwargs: Accepted for call-site compatibility and ignored
    :raises AssertionError: If ``dynamic_persistent`` is True
    :raises TypeError: If ``PersistentDenseGemmKernel.can_implement`` rejects the configuration
    :raises RuntimeError: If no CUDA GPU is available
    :return: Time of the CuTe DSL kernel in milliseconds, as returned by ``do_bench``
    :rtype: float
    """
    print("Running Blackwell Persistent Dense GEMM test with:")
    print(f"mnkl: {mnkl}")
    print(f"AB dtype: {ab_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
    print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
    print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
    print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
    print(f"Tolerance: {tolerance}")
    print(f"Warmup iterations: {warmup_iterations}")
    print(f"Iterations: {iterations}")
    print(f"Skip reference checking: {skip_ref_check}")

    assert not dynamic_persistent, "Dynamic persistent mode is not supported yet."

    # Unpack parameters
    m, n, k, l = mnkl

    # Skip unsupported testcase
    if not PersistentDenseGemmKernel.can_implement(
        ab_dtype,
        acc_dtype,
        d_dtype,
        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
        m,
        n,
        k,
        l,
        a_major,
        b_major,
        d_major,
    ):
        raise TypeError(
            f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}"
        )

    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run this example!")

    torch.manual_seed(1111)

    def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
        """Create a Gaussian-initialised operand with logical shape (mode0, mode1, l).

        Returns (fp32 CPU copy, CuTe tensor on GPU, torch GPU tensor, torch CPU tensor).
        """
        # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
        # else:           (l, mode0, mode1) -> (mode0, mode1, l)
        shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
        permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
        is_fp8 = dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
        torch_dtype = cutlass_torch.dtype(dtype)
        # FP8 operands are sampled in bf16 and then cast to torch_dtype.
        gen_dtype = torch.bfloat16 if is_fp8 else torch_dtype

        # Create dtype torch tensor (cpu); std = 1/sqrt(K) keeps dot products O(1).
        torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor(
            shape,
            gen_dtype,
            permute_order=permute_order,
            init_type=cutlass.torch.TensorInitType.GAUSSIAN,
            init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
        ).to(torch_dtype)
        # Create dtype torch tensor (gpu)
        torch_tensor = torch_tensor_cpu.cuda()
        # Create f32 torch tensor (cpu), used for the reference computation.
        f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)

        # FP8 data is exported through DLPack as a uint8 view; element_type restores the real dtype.
        torch_tensor_view = torch_tensor.view(torch.uint8) if is_fp8 else torch_tensor
        cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
        cute_tensor.element_type = dtype
        if is_dynamic_layout:
            cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
        cute_tensor = cutlass_torch.convert_cute_tensor(
            f32_torch_tensor,
            cute_tensor,
            dtype,
            is_dynamic_layout=is_dynamic_layout,
        )

        return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu

    a_ref, mA, a_torch, _ = create_and_permute_tensor(
        l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
    )
    b_ref, mB, b_torch, _ = create_and_permute_tensor(
        l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
    )
    _, mD, d_torch, _ = create_and_permute_tensor(
        l, m, n, d_major == "m", d_dtype, is_dynamic_layout=True
    )
    if c_dtype is not None:
        c, mC, c_torch, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
    else:
        c, mC, c_torch = None, None, None

    # Configure gemm kernel
    gemm = PersistentDenseGemmKernel(
        acc_dtype,
        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
    )

    # Compute max active clusters on current device
    hardware_info = cutlass.utils.HardwareInfo()
    max_active_clusters = hardware_info.get_max_active_clusters(
        cluster_shape_mn[0] * cluster_shape_mn[1]
    )
    if dynamic_persistent:
        # The torch buffer stays referenced so it outlives every launch that uses its pointer.
        tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda")
        tile_count_ptr = make_ptr(
            Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
        )
    else:
        tile_count_semaphore = None
        tile_count_ptr = None

    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)
    # Compile gemm kernel; launches below pass the same semaphore argument that was compiled with.
    compiled_gemm = cute.compile(
        gemm,
        mA,
        mB,
        mD,
        mC,
        tile_count_ptr,
        max_active_clusters,
        current_stream,
    )

    if not skip_ref_check:
        compiled_gemm(mA, mB, mD, mC, tile_count_ptr, current_stream)
        # a_ref / b_ref (and c) are fp32 CPU tensors, so the reference is an fp32 CPU einsum
        # for every input dtype, including the 8-bit ones.
        ref = torch.einsum("mkl,nkl->mnl", a_ref, b_ref)
        if c is not None:
            ref = ref + c
        ref = ref.cpu()

        # Copy gpu result back
        gpu_d = d_torch.cpu()

        # Convert ref to d_dtype
        if d_dtype == cutlass.Float32:
            ref_d = ref
        elif d_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
            # m major: (l, n, m) -> (m, n, l)
            # n major: (l, m, n) -> (m, n, l)
            permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
            shape = (l, m, n) if d_major == "n" else (l, n, m)
            f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
                shape,
                torch.uint8,
                permute_order=permute_order,
                init_type=cutlass_torch.TensorInitType.SKIP,
            ).cuda()
            # Create dtype cute tensor (gpu)
            ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
                leading_dim=(1 if d_major == "n" else 0)
            )
            ref_d_tensor.element_type = d_dtype
            ref_d_tensor = cutlass_torch.convert_cute_tensor(
                ref,
                ref_d_tensor,
                d_dtype,
                is_dynamic_layout=True,
            )
            # NOTE(review): ref_d is a uint8 tensor while gpu_d has cutlass_torch.dtype(d_dtype);
            # confirm assert_close sees matching dtypes for fp8 outputs.
            ref_d = f8_torch_tensor.cpu()
        else:
            ref_d = ref.to(cutlass_torch.dtype(d_dtype))

        # Reference checking ref_d and gpu_d
        torch.testing.assert_close(gpu_d, ref_d, atol=tolerance, rtol=1e-05)

    import time

    from triton.testing import do_bench

    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    flops = 2 * m * n * k * l

    # triton's do_bench interprets `warmup` and `rep` as durations in milliseconds.
    repeats = iterations
    warmup = warmup_iterations

    time.sleep(0.5)
    if ab_dtype.width == 8:
        assert l == 1
        scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda")
        fn_cublas = lambda: torch._scaled_mm(
            a_torch[:, :, 0],
            b_torch[:, :, 0].mT,
            scale_a=scale_ab,
            scale_b=scale_ab,
            out_dtype=torch.bfloat16,
        )
    else:
        if c_torch is None:
            fn_cublas = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
        else:
            c_torch_convert = c_torch.to(a_torch.dtype)  # In case C is in FP32
            fn_cublas = lambda: torch.baddbmm(
                c_torch_convert.permute(2, 0, 1),
                a_torch.permute(2, 0, 1),
                b_torch.permute(2, 0, 1).mT,
            )
    timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
    tflops_cublas = flops / (timing_cublas * 1e9)  # ms -> TFLOPS
    print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")

    time.sleep(0.5)
    fn = lambda: compiled_gemm(mA, mB, mD, mC, tile_count_ptr, current_stream)
    timing = do_bench(fn, warmup=warmup, rep=repeats)
    tflops = flops / (timing * 1e9)  # ms -> TFLOPS
    print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
    return timing
|
|
2475
|
+
|
|
2476
|
+
|
|
2477
|
+
if __name__ == "__main__":
    # Command-line entry point: parse options, validate tuple lengths, then run the benchmark.

    def parse_comma_separated_ints(text: str) -> Tuple[int, ...]:
        """Turn ``"a,b,c"`` into ``(a, b, c)``; whitespace around each item is ignored."""
        try:
            return tuple([int(piece.strip()) for piece in text.split(",")])
        except ValueError:
            raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")

    parser = argparse.ArgumentParser(description="Example of Dense Persistent GEMM on Blackwell.")

    # Tuple-valued shape options.
    for flag, default_value, help_text in (
        ("--mnkl", (256, 256, 512, 1), "mnkl dimensions (comma-separated)"),
        ("--mma_tiler_mn", (128, 128), "Mma tile shape (comma-separated)"),
        ("--cluster_shape_mn", (1, 1), "Cluster shape (comma-separated)"),
    ):
        parser.add_argument(
            flag, type=parse_comma_separated_ints, default=default_value, help=help_text
        )

    # Data types.
    for flag, default_value in (
        ("--ab_dtype", cutlass.BFloat16),
        ("--d_dtype", cutlass.BFloat16),
        ("--c_dtype", None),
        ("--acc_dtype", cutlass.Float32),
    ):
        parser.add_argument(flag, type=cutlass.dtype, default=default_value)

    parser.add_argument(
        "--use_2cta_instrs",
        action="store_true",
        help="Enable 2CTA MMA instructions feature",
    )

    # Operand layouts (which mode is contiguous).
    for flag, allowed, default_value in (
        ("--a_major", ["k", "m"], "k"),
        ("--b_major", ["k", "n"], "k"),
        ("--d_major", ["n", "m"], "n"),
        ("--c_major", ["n", "m"], "n"),
    ):
        parser.add_argument(flag, choices=allowed, type=str, default=default_value)

    parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation")
    parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
    parser.add_argument(
        "--iterations",
        type=int,
        default=30,
        help="Number of iterations to run the kernel",
    )
    parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
    parser.add_argument(
        "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel"
    )

    args = parser.parse_args()

    # parser.error exits, so only the first offending option is reported.
    for option_name, expected_len in (("mnkl", 4), ("mma_tiler_mn", 2), ("cluster_shape_mn", 2)):
        if len(getattr(args, option_name)) != expected_len:
            parser.error(f"--{option_name} must contain exactly {expected_len} values")

    run(
        mnkl=args.mnkl,
        ab_dtype=args.ab_dtype,
        d_dtype=args.d_dtype,
        c_dtype=args.c_dtype,
        acc_dtype=args.acc_dtype,
        a_major=args.a_major,
        b_major=args.b_major,
        d_major=args.d_major,
        c_major=args.c_major,
        mma_tiler_mn=args.mma_tiler_mn,
        cluster_shape_mn=args.cluster_shape_mn,
        use_2cta_instrs=args.use_2cta_instrs,
        tolerance=args.tolerance,
        warmup_iterations=args.warmup_iterations,
        iterations=args.iterations,
        skip_ref_check=args.skip_ref_check,
        dynamic_persistent=args.dynamic_persistent,
    )
    print("PASS")
|