quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +4 -1
- quack/autotuner.py +309 -0
- quack/cross_entropy.py +2 -5
- quack/cute_dsl_utils.py +40 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +2474 -0
- quack/fast_math.py +97 -0
- quack/gemm_config.py +61 -0
- quack/gemm_interface.py +321 -0
- quack/linear.py +176 -0
- quack/lse.py +62 -0
- quack/mlp.py +204 -0
- quack/pipeline.py +166 -0
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2088 -0
- quack/tensormap_manager.py +114 -0
- quack/tile_scheduler.py +935 -0
- quack/topk.py +221 -0
- quack/utils.py +237 -19
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/METADATA +3 -3
- quack_kernels-0.1.11.dist-info/RECORD +31 -0
- quack_kernels-0.1.9.dist-info/RECORD +0 -12
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/top_level.txt +0 -0
quack/dense_gemm_sm90.py
ADDED
|
@@ -0,0 +1,2474 @@
|
|
|
1
|
+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
3
|
+
|
|
4
|
+
# Redistribution and use in source and binary forms, with or without
|
|
5
|
+
# modification, are permitted provided that the following conditions are met:
|
|
6
|
+
|
|
7
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
8
|
+
# list of conditions and the following disclaimer.
|
|
9
|
+
|
|
10
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
11
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
12
|
+
# and/or other materials provided with the distribution.
|
|
13
|
+
|
|
14
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
15
|
+
# contributors may be used to endorse or promote products derived from
|
|
16
|
+
# this software without specific prior written permission.
|
|
17
|
+
|
|
18
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
19
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
20
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
21
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
22
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
23
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
24
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
25
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
26
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
27
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
28
|
+
|
|
29
|
+
import argparse
|
|
30
|
+
import enum
|
|
31
|
+
from typing import Tuple, Type, Callable, Optional
|
|
32
|
+
from functools import partial
|
|
33
|
+
import math
|
|
34
|
+
|
|
35
|
+
import cuda.bindings.driver as cuda
|
|
36
|
+
|
|
37
|
+
import torch
|
|
38
|
+
|
|
39
|
+
import cutlass
|
|
40
|
+
import cutlass.cute as cute
|
|
41
|
+
import cutlass.pipeline as pipeline
|
|
42
|
+
import cutlass.torch as cutlass_torch
|
|
43
|
+
from cutlass.cute.runtime import from_dlpack, make_ptr
|
|
44
|
+
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
|
45
|
+
import cutlass.utils.hopper_helpers as sm90_utils
|
|
46
|
+
from cutlass import Int32, const_expr
|
|
47
|
+
|
|
48
|
+
from quack.tile_scheduler import (
|
|
49
|
+
TileSchedulerArguments,
|
|
50
|
+
TileScheduler,
|
|
51
|
+
VarlenMTileSchedulerArguments,
|
|
52
|
+
VarlenMTileScheduler,
|
|
53
|
+
ParamsBase,
|
|
54
|
+
RasterOrderOption,
|
|
55
|
+
)
|
|
56
|
+
from quack.tensormap_manager import TensorMapManagerSm90
|
|
57
|
+
|
|
58
|
+
# return PipelineStateWAdvance instead of PipelineState
|
|
59
|
+
from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
|
|
60
|
+
import quack.utils as utils
|
|
61
|
+
|
|
62
|
+
"""
|
|
63
|
+
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
|
|
64
|
+
using CUTE DSL.
|
|
65
|
+
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
|
66
|
+
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
|
67
|
+
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
|
68
|
+
|
|
69
|
+
This GEMM kernel supports the following features:
|
|
70
|
+
- Utilizes the Tensor Memory Accelerator (TMA) for efficient memory operations
|
|
71
|
+
- Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations
|
|
72
|
+
- Implements TMA multicast with cluster to reduce L2 memory traffic
|
|
73
|
+
- Supports multi-stage pipeline to overlap computation and memory access
|
|
74
|
+
|
|
75
|
+
This GEMM works as follows:
|
|
76
|
+
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
|
|
77
|
+
2. Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction.
|
|
78
|
+
3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations.
|
|
79
|
+
|
|
80
|
+
Hopper WGMMA instructions operate as follows:
|
|
81
|
+
- Read matrix A from SMEM
|
|
82
|
+
- Read matrix B from SMEM
|
|
83
|
+
- Perform MMA operation and store the result in Accumulator(register)
|
|
84
|
+
|
|
85
|
+
To run this example:
|
|
86
|
+
|
|
87
|
+
.. code-block:: bash
|
|
88
|
+
|
|
89
|
+
python quack/dense_gemm_sm90.py \
|
|
90
|
+
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
|
|
91
|
+
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
|
|
92
|
+
--d_dtype Float16 --acc_dtype Float32 \
|
|
93
|
+
--a_major k --b_major k --d_major n
|
|
94
|
+
|
|
95
|
+
The above example command computes a batched GEMM with M=8192, N=8192, K=8192,
|
|
96
|
+
batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape
|
|
97
|
+
is (1,1). The input, mma accumulator and output data type are set as fp16, fp32
|
|
98
|
+
and fp16, respectively.
|
|
99
|
+
|
|
100
|
+
To collect performance with NCU profiler:
|
|
101
|
+
|
|
102
|
+
.. code-block:: bash
|
|
103
|
+
|
|
104
|
+
ncu python quack/dense_gemm_sm90.py \
|
|
105
|
+
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
|
|
106
|
+
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
|
|
107
|
+
--d_dtype Float16 --acc_dtype Float32 \
|
|
108
|
+
--a_major k --b_major k --d_major n
|
|
109
|
+
|
|
110
|
+
Constraints:
|
|
111
|
+
* Supported input data types: fp16, bf16, fp8 (e4m3fn, e5m2)
|
|
112
|
+
* For 16-bit types (fp16/bf16), A and B must have the same data type
|
|
113
|
+
* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit
|
|
114
|
+
* Fp8 types only support k-major layout
|
|
115
|
+
* Only fp32 accumulation is supported in this example
|
|
116
|
+
* CTA tile shape M must be 64/128/192/256/320 (64/128/192 when pingpong is enabled)
|
|
117
|
+
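* CTA tile shape N must be a multiple of 16 (a multiple of 32 when tile M is 192/320 or N > 256); its upper bound depends on tile M and pingpong mode, and is at most 512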
|
|
118
|
+
* CTA tile shape K must be divisible by 16
|
|
119
|
+
* Cluster shape M/N must be positive and power of 2, total cluster size <= 4
|
|
120
|
+
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
|
121
|
+
i.e., the number of elements must be a multiple of 8 for 16-bit types and of 16 for Float8.
|
|
122
|
+
* OOB tiles are not allowed when TMA store is disabled
|
|
123
|
+
"""
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
127
|
+
# Helpers to parse args
|
|
128
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
129
|
+
def parse_comma_separated_ints(s: str):
    """Convert a string such as ``"128,256,64"`` into a tuple of ints.

    Whitespace around each comma-separated item is ignored. Meant to be used as
    an argparse ``type=`` callable.

    :param s: Comma-separated list of integers.
    :return: Tuple of the parsed integers, in input order.
    :raises argparse.ArgumentTypeError: if any item is not a valid integer
        (including empty items such as in ``"1,,2"``).
    """
    pieces = s.split(",")
    try:
        values = tuple(int(piece.strip()) for piece in pieces)
    except ValueError:
        raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")
    return values
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def parse_arguments(argv: Optional[list] = None) -> argparse.Namespace:
    """Build the command-line parser for the Hopper GEMM example and parse arguments.

    :param argv: Arguments to parse, excluding the program name. ``None`` (the
        default) makes argparse read ``sys.argv[1:]``, exactly as before; passing
        a list allows calling this function programmatically (e.g. from tests).
    :return: The parsed namespace. ``mnkl``, ``tile_shape_mnk`` and
        ``cluster_shape_mn`` are tuples of ints; the ``*_dtype`` options are
        converted with ``cutlass.dtype``.
    :raises SystemExit: via ``parser.error`` / argparse when an option is
        malformed or a shape tuple has the wrong number of entries.
    """
    parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.")

    parser.add_argument(
        "--mnkl",
        type=parse_comma_separated_ints,
        default=(4096, 4096, 4096, 1),
        help="mnkl dimensions (comma-separated)",
    )
    parser.add_argument(
        "--tile_shape_mnk",
        type=parse_comma_separated_ints,
        default=(128, 256, 64),
        help="Cta tile shape (comma-separated)",
    )
    parser.add_argument(
        "--cluster_shape_mn",
        type=parse_comma_separated_ints,
        # argparse applies ``choices`` after ``type``, so the parsed tuple is checked here.
        choices=[(1, 1), (2, 1), (1, 2), (2, 2)],
        default=(1, 1),
        help="Cluster shape (comma-separated)",
    )
    parser.add_argument(
        "--a_dtype",
        type=cutlass.dtype,
        default=cutlass.BFloat16,
    )
    parser.add_argument(
        "--b_dtype",
        type=cutlass.dtype,
        default=cutlass.BFloat16,
    )
    parser.add_argument(
        "--d_dtype",
        type=cutlass.dtype,
        default=cutlass.BFloat16,
    )
    parser.add_argument(
        "--c_dtype",
        type=cutlass.dtype,
        default=None,
    )
    parser.add_argument(
        "--acc_dtype",
        type=cutlass.dtype,
        default=cutlass.Float32,
    )
    parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
    parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
    parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
    parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
    parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation")
    parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
    parser.add_argument(
        "--iterations",
        type=int,
        default=30,
        help="Number of iterations to run the kernel",
    )
    parser.add_argument("--persistent", action="store_true", help="Persistent kernel")
    parser.add_argument(
        "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel"
    )
    parser.add_argument("--pingpong", action="store_true", help="Pingpong kernel")
    parser.add_argument("--varlen_m", action="store_true", help="Variable length M dimension")
    parser.add_argument("--gather_A", action="store_true", help="Gather A")
    parser.add_argument("--fp8_fast_accum", action="store_true", help="FP8 fast accum")
    parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")

    # ``None`` means "use sys.argv[1:]", preserving the original zero-argument behavior.
    args = parser.parse_args(argv)

    # argparse cannot constrain tuple lengths itself, so validate them after parsing.
    if len(args.mnkl) != 4:
        parser.error("--mnkl must contain exactly 4 values")
    if len(args.tile_shape_mnk) != 3:
        parser.error("--tile_shape_mnk must contain exactly 3 values")
    # Already guaranteed by ``choices`` above; kept as a defensive check.
    if len(args.cluster_shape_mn) != 2:
        parser.error("--cluster_shape_mn must contain exactly 2 values")

    return args
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
218
|
+
# Host setup and device kernel launch
|
|
219
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class NamedBarrierGemm(enum.IntEnum):
    """Named-barrier IDs used by the GEMM kernel.

    IDs start at 1 because barrier 0 is reserved for ``sync_threads()``.
    """

    Epilogue = 1
    # For mainloop load warps to signal that the epilogue load warp can start,
    # so that C is not loaded too early and does not interfere with loading A and B.
    EpilogueLoad = 2
    MmaWG0 = 3
    MmaWG1 = 4
    EpiWG0 = 5
    EpiWG1 = 6
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class HopperWgmmaGemmKernel:
|
|
234
|
+
"""
|
|
235
|
+
This class implements batched matrix multiplication (C = A x B) with support for various data types
|
|
236
|
+
and architectural features specific to Hopper GPUs.
|
|
237
|
+
|
|
238
|
+
:param acc_dtype: Data type for accumulation during computation
|
|
239
|
+
:type acc_dtype: type[cutlass.Numeric]
|
|
240
|
+
:param tile_shape_mnk: Shape of the CTA tile (M,N,K)
|
|
241
|
+
:type tile_shape_mnk: Tuple[int, int, int]
|
|
242
|
+
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
|
|
243
|
+
:type cluster_shape_mnk: Tuple[int, int, int]
|
|
244
|
+
|
|
245
|
+
:note: Data type requirements:
|
|
246
|
+
- For 16-bit types: A and B must have the same data type
|
|
247
|
+
- For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit
|
|
248
|
+
- Float8 types only support k-major layout
|
|
249
|
+
|
|
250
|
+
:note: Supported data types:
|
|
251
|
+
- Float16
|
|
252
|
+
- BFloat16
|
|
253
|
+
- Float8E4M3FN/Float8E5M2
|
|
254
|
+
|
|
255
|
+
:note: Supported accumulation types:
|
|
256
|
+
- Float32 (for all floating point inputs)
|
|
257
|
+
|
|
258
|
+
:note: Constraints:
|
|
259
|
+
- Cluster shape M/N must be positive and power of 2, total cluster size <= 4
|
|
260
|
+
|
|
261
|
+
Example:
|
|
262
|
+
>>> gemm = HopperWgmmaGemmKernel(
|
|
263
|
+
... acc_dtype=cutlass.Float32,
... a_dtype=cutlass.BFloat16,
|
|
264
|
+
... tile_shape_mnk=(128, 256, 64),
|
|
265
|
+
... cluster_shape_mnk=(1, 1, 1)
|
|
266
|
+
... )
|
|
267
|
+
>>> gemm(a_tensor, b_tensor, d_tensor, None, None, None, None, None, max_active_clusters, stream)
|
|
268
|
+
"""
|
|
269
|
+
|
|
270
|
+
bytes_per_tensormap = 128
|
|
271
|
+
num_tensormaps = 1 # For D only
|
|
272
|
+
|
|
273
|
+
def __init__(
    self,
    acc_dtype: Type[cutlass.Numeric],
    a_dtype: Type[cutlass.Numeric],
    tile_shape_mnk: Tuple[int, int, int],
    cluster_shape_mnk: Tuple[int, int, int],
    pingpong: bool = False,
    is_persistent: bool = True,
    fp8_fast_accum: bool = False,
    gather_A: bool = False,
):
    """
    Initializes the configuration for a Hopper dense GEMM kernel.

    This configuration includes data types for operands, tile shape, cluster configuration,
    and thread layout. Only host-side Python bookkeeping happens here; stage counts and
    shared-memory layouts stay ``None`` until ``_setup_attributes`` runs (from ``__call__``)
    once the operand tensors are known.

    :param acc_dtype: Data type for accumulation during computation
    :type acc_dtype: type[cutlass.Numeric]
    :param a_dtype: Element type of A; only its bit width is used here (an 8-bit type
        enables the fp8 slow-accumulation path unless ``fp8_fast_accum`` is set)
    :type a_dtype: type[cutlass.Numeric]
    :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
    :type tile_shape_mnk: Tuple[int, int, int]
    :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
    :type cluster_shape_mnk: Tuple[int, int, int]
    :param pingpong: Pingpong mode: two MMA warp groups, each with a 1x1 atom layout.
        Requires ``is_persistent``.
    :type pingpong: bool
    :param is_persistent: Use a persistent tile scheduler
    :type is_persistent: bool
    :param fp8_fast_accum: For 8-bit inputs, disable the slow-accumulation path
    :type fp8_fast_accum: bool
    :param gather_A: Gather rows of A through an index tensor instead of a TMA load of A;
        uses 4 A/B load warps and requires cluster shape N == 1
    :type gather_A: bool
    :raises ValueError: if the CTA tile shape is not supported for the selected mode
    :raises AssertionError: if ``pingpong`` is set without ``is_persistent``, or
        ``gather_A`` is set with cluster shape N != 1
    """

    self.acc_dtype = acc_dtype
    self.pingpong = pingpong
    self.is_persistent = is_persistent
    if self.pingpong:
        assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
    # 8-bit inputs take the slow-accumulation path unless fp8_fast_accum is requested.
    self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
    self.gather_A = gather_A
    if gather_A:
        assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
    self.tensormap_update_mode = cutlass.utils.TensorMapUpdateMode.SMEM

    self.cluster_shape_mnk = cluster_shape_mnk
    self.tile_shape_mnk = tuple(tile_shape_mnk)
    tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1]
    # check the cta tile shape
    if not self.pingpong:
        if tile_M not in [64, 128, 192, 256, 320]:
            raise ValueError("CTA tile shape M must be 64/128/192/256/320")
        if tile_M in [192, 320]:  # special case
            tile_N_max = 256 if tile_M == 192 else 160
            if not (tile_N % 32 == 0 and tile_N <= tile_N_max):
                raise ValueError(
                    f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}"
                )
        else:
            if not (
                (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
            ):
                raise ValueError(
                    "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
                )
    else:
        if tile_M not in [64, 128, 192]:
            raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
        # Larger M tiles leave less room for N in pingpong mode.
        tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
        if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
            raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
    if not self.tile_shape_mnk[2] % 16 == 0:
        raise ValueError("CTA tile shape K must be divisible by 16")

    # Arrangement (M, N) of the MMA warp groups over the CTA tile.
    if not self.pingpong:
        if tile_M == 320:  # tile_M / 64 is not even so we have to split along N
            atom_layout_m, atom_layout_n = 1, 2
        elif tile_M == 192:
            if tile_N <= 128:
                atom_layout_m, atom_layout_n = 3, 1
            else:
                atom_layout_m, atom_layout_n = 1, 2
        else:
            # 64 -> 1, 128 -> 2, 256 -> capped at 2 warp groups along M.
            atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
            atom_layout_n = 1
        assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
    else:
        atom_layout_m, atom_layout_n = 1, 1
    self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)

    # Multicast fan-out: A along the cluster N extent (never when A is gathered),
    # B along the cluster M extent.
    self.num_mcast_ctas_a = self.cluster_shape_mnk[1] if not self.gather_A else 1
    self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
    self.is_a_mcast = self.num_mcast_ctas_a > 1
    self.is_b_mcast = self.num_mcast_ctas_b > 1

    self.occupancy = 1
    # Pingpong doubles the MMA warp groups (its atom layout is 1x1, so this is 2).
    self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
    if self.pingpong:
        assert self.mma_warp_groups == 2
    assert self.mma_warp_groups in [1, 2, 3]
    self.num_threads_per_warp_group = 128
    # One extra warp group beyond the MMA warp groups holds the load warps (see warp ids below).
    self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
    self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
    # In pingpong mode the epilogue thread count is a single warp group.
    self.num_epi_threads = (
        self.mma_warp_groups if not self.pingpong else 1
    ) * self.num_threads_per_warp_group
    self.num_ab_load_warps = 1 if not self.gather_A else 4
    self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
    self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
    # Load warps are numbered right after the MMA warps (4 warps per warp group):
    # first the A/B load warp(s), then the epilogue (C) load warp.
    self.ab_load_warp_id = self.mma_warp_groups * 4
    self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps

    # Accumulator elements per MMA thread: (tile_M * tile_N) / (threads in the tiled MMA).
    regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // (
        math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
    )
    if self.fp8_slow_accum:
        # The slow-accumulation path is budgeted at twice the accumulator size.
        regs_per_thread *= 2
    # Per-thread register budgets for the load and MMA warp groups.
    # NOTE(review): presumably applied via a setmaxnreg-style call in kernel() — not visible here.
    if not self.gather_A:
        if self.mma_warp_groups == 3:
            self.num_regs_load, self.num_regs_mma = 32, 160
        else:
            heavy_register_pressure = regs_per_thread >= 208
            self.num_regs_load, self.num_regs_mma = (
                (40, 232) if not heavy_register_pressure else (24, 240)
            )
    else:
        # gather_A uses more load warps, so they get a larger register budget.
        if self.mma_warp_groups == 3:
            self.num_regs_load, self.num_regs_mma = 56, 152
        else:
            self.num_regs_load, self.num_regs_mma = (56, 224)

    # Populated later by _setup_attributes() (stages, layouts, epi tile) and
    # __call__ (shared_storage) once the operand tensors are known.
    self.ab_stage = None
    self.epi_stage = None

    self.a_smem_layout_staged = None
    self.b_smem_layout_staged = None
    self.epi_smem_layout_staged = None
    self.epi_tile = None

    self.shared_storage = None
    # Byte alignment applied to the sD/sC/sA/sB shared-memory buffers.
    self.buffer_align_bytes = 1024
|
|
405
|
+
|
|
406
|
+
def _setup_attributes(self):
    """Derive the input-dependent kernel configuration.

    Must run after the ``a/b/c/d_dtype`` and ``a/b/c/d_layout`` attributes have
    been set (``__call__`` sets them right before calling this). In order, it sets:

    - the cluster layout
    - the epilogue subtile (``epi_tile``)
    - the A/B and epilogue pipeline stage counts
    - the tile-scheduler stage count
    - the staged shared-memory layouts for A, B, D and C
    """
    tile_mnk = self.tile_shape_mnk

    self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)

    epi_tile = self._sm90_compute_tile_shape_or_override(
        tile_mnk,
        self.atom_layout_mnk,
        self.d_dtype,
    )
    self.epi_tile = epi_tile

    # Stage counts have to be known before the staged smem layouts can be built.
    # Without a persistent scheduler the epilogue smem reuses the A/B smem.
    stages = self._compute_stages(
        tile_mnk,
        epi_tile,
        self.a_dtype,
        self.b_dtype,
        self.d_dtype,
        self.c_dtype,
        self.smem_capacity,
        self.occupancy,
        overlap_sD_sA=not self.is_persistent,
    )
    self.ab_stage, self.epi_stage, self.epi_c_stage = stages
    # Pingpong uses two scheduler stages, otherwise one.
    self.sched_stage = 1 if not self.pingpong else 2

    smem_layouts = self._make_smem_layouts(
        tile_mnk,
        epi_tile,
        self.a_dtype,
        self.a_layout,
        self.b_dtype,
        self.b_layout,
        self.ab_stage,
        self.d_dtype,
        self.d_layout,
        self.epi_stage,
        self.c_dtype,
        self.c_layout,
        self.epi_c_stage,
    )
    (
        self.a_smem_layout_staged,
        self.b_smem_layout_staged,
        self.epi_smem_layout_staged,
        self.epi_c_smem_layout_staged,
    ) = smem_layouts
|
|
463
|
+
|
|
464
|
+
@cute.jit
def __call__(
    self,
    mA: cute.Tensor,
    mB: cute.Tensor,
    mD: cute.Tensor,
    mC: Optional[cute.Tensor],
    mAIdx: Optional[cute.Tensor],
    mCuSeqlensM: Optional[cute.Tensor],
    mTensormaps: Optional[cute.Tensor],
    tile_count_semaphore: Optional[cute.Pointer],
    max_active_clusters: Int32,
    stream: cuda.CUstream,
):
    """Execute the GEMM operation in steps:
    - Setup static attributes (dtypes, layouts, stages, smem layouts)
    - Setup TMA load/store atoms and tensors
    - Build the tile scheduler and compute the grid size
    - Define shared storage for kernel
    - Launch the kernel on ``stream``

    :param mA: Input tensor A
    :type mA: cute.Tensor
    :param mB: Input tensor B
    :type mB: cute.Tensor
    :param mD: Output tensor D
    :type mD: cute.Tensor
    :param mC: Optional tensor C; when given it gets a TMA *load* atom for the epilogue
    :type mC: Optional[cute.Tensor]
    :param mAIdx: Row-index tensor for A; must be provided iff the kernel was built
        with ``gather_A=True`` (asserted below)
    :type mAIdx: Optional[cute.Tensor]
    :param mCuSeqlensM: Optional cumulative sequence lengths along M; when given,
        the variable-length-M tile scheduler is used and ``mTensormaps`` is required
    :type mCuSeqlensM: Optional[cute.Tensor]
    :param mTensormaps: Tensormap storage forwarded to the kernel (required for varlen M)
    :type mTensormaps: Optional[cute.Tensor]
    :param tile_count_semaphore: Optional pointer forwarded to the tile scheduler arguments
    :type tile_count_semaphore: Optional[cute.Pointer]
    :param max_active_clusters: Used by the tile scheduler to size the launch grid
    :type max_active_clusters: Int32
    :param stream: CUDA stream for asynchronous execution
    :type stream: cuda.CUstream
    :raises TypeError: if A/B element types are incompatible or not 8/16-bit wide
    """

    # setup static attributes before smem/grid/tma computation
    self.a_dtype = mA.element_type
    self.b_dtype = mB.element_type
    self.d_dtype = mD.element_type
    self.c_dtype = mC.element_type if mC is not None else None
    self.a_layout = cutlass.utils.LayoutEnum.from_tensor(mA)
    self.b_layout = cutlass.utils.LayoutEnum.from_tensor(mB)
    self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
    self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None

    # 16-bit inputs must match exactly; 8-bit inputs may differ but must be equally wide.
    if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
        raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
    if const_expr(self.a_dtype.width != self.b_dtype.width):
        raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
    if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
        raise TypeError("a_dtype should be float16 or float8")
    assert (mAIdx is not None) == self.gather_A

    # Assume every dynamic stride is a multiple of 128 bits (i.e. of 128 // element-width
    # elements); static strides (such as a unit stride) are kept unchanged.
    # Applied to A and D only.
    new_stride = lambda t: tuple(
        cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
        for s in t.stride
    )
    mA, mD = [
        cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
        for t in (mA, mD)
    ]

    self._setup_attributes()

    tiled_mma = sm90_utils.make_trivial_tiled_mma(
        self.a_dtype,
        self.b_dtype,
        self.a_layout.sm90_mma_major_mode(),
        self.b_layout.sm90_mma_major_mode(),
        self.acc_dtype,
        self.atom_layout_mnk,
        tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
    )
    if const_expr(self.atom_layout_mnk[1] > 1):
        # If N dimension is split among 2 WGs, we need to permute the N dimension so
        # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
        # containing accumulators that are next to each other in the N dimension.
        # Without permutation WG0 would write to epi smem of size (64, 16) and
        # WG1 would write to a separate epi smem of size (64, 16) that's far away.
        atom_n = self.atom_layout_mnk[1]
        permutation_n = cute.make_ordered_layout(
            (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
        )
        tiled_mma = cute.make_tiled_mma(
            cute.make_mma_atom(tiled_mma.op),
            self.atom_layout_mnk,
            permutation_mnk=(None, permutation_n, None),
        )

    # A gets a TMA atom (multicast over cluster N) unless its rows are gathered via mAIdx,
    # in which case the raw mA tensor is passed to the kernel instead.
    if const_expr(not self.gather_A):
        tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
            mA,
            self.a_smem_layout_staged,
            (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
            self.cluster_shape_mnk[1],
        )
    else:
        tma_atom_a, tma_tensor_a = None, None

    # B's TMA atom is multicast over cluster M.
    tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
        mB,
        self.b_smem_layout_staged,
        (self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
        self.cluster_shape_mnk[0],
    )

    tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
        mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
    )

    if const_expr(mC is not None):
        tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
            mC, self.epi_c_smem_layout_staged, self.epi_tile, store_or_load="load"
        )
    else:
        tma_atom_c, tma_tensor_c = None, None

    if const_expr(mCuSeqlensM is None):
        # Fixed-shape problem: tile counts along (M, N) plus the batch extent L.
        problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (
            mD.shape[2],
        )
        TileSchedulerCls = TileScheduler
        tile_sched_args = TileSchedulerArguments(
            problem_shape_ntile_mnl=problem_shape_ntile_mnl,
            raster_order=RasterOrderOption.Heuristic,
            group_size=8,
            cluster_shape_mnk=self.cluster_shape_mnk,
            tile_count_semaphore=tile_count_semaphore,
            is_persistent=self.is_persistent,
        )
    else:
        assert mTensormaps is not None
        # Variable-length M: the M tile count is left as None here; the batch count
        # is len(cu_seqlens_m) - 1 and total_m is the full M extent of D.
        problem_shape_ntile_mnl = (
            None,
            cute.ceil_div(mD.shape[1], self.tile_shape_mnk[1]),
            mCuSeqlensM.shape[0] - 1,
        )
        TileSchedulerCls = VarlenMTileScheduler
        tile_sched_args = VarlenMTileSchedulerArguments(
            problem_shape_ntile_mnl=problem_shape_ntile_mnl,
            total_m=mD.shape[0],
            cu_seqlens_m=mCuSeqlensM,
            raster_order=RasterOrderOption.Heuristic,
            group_size=8,
            tile_shape_mnk=self.tile_shape_mnk,
            cluster_shape_mnk=self.cluster_shape_mnk,
            tile_count_semaphore=tile_count_semaphore,
            is_persistent=self.is_persistent,
        )
    tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
    grid = TileSchedulerCls.get_grid_shape(tile_sched_params, max_active_clusters)

    # Non-persistent kernels reserve no separate sD buffer (epi smem overlaps A/B smem,
    # see overlap_sD_sA in _setup_attributes); sC is only sized when C is given.
    epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if self.is_persistent else 0
    epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0

    # Shared-memory tensormap space (in Int64 words): only needed for varlen M when
    # tensormaps are updated in SMEM; doubled for pingpong.
    size_tensormap_in_i64 = (
        0
        if mCuSeqlensM is None
        or self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.GMEM
        else HopperWgmmaGemmKernel.num_tensormaps
        * HopperWgmmaGemmKernel.bytes_per_tensormap
        // 8
    ) * (1 if not self.pingpong else 2)

    @cute.struct
    class SharedStorage:
        tensormap_buffer: cute.struct.Align[
            cute.struct.MemRange[cutlass.Int64, size_tensormap_in_i64],
            64,
        ]
        # Two Int64 slots per pipeline stage (presumably full/empty barriers — confirm in kernel()).
        ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
        epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
        sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
        tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
        sD: cute.struct.Align[
            cute.struct.MemRange[self.d_dtype, epi_smem_size],
            self.buffer_align_bytes,
        ]
        # Int32 is only a placeholder element type when there is no C (size is then 0).
        sC: cute.struct.Align[
            cute.struct.MemRange[
                self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
            ],
            self.buffer_align_bytes,
        ]
        sA: cute.struct.Align[
            cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
            self.buffer_align_bytes,
        ]
        sB: cute.struct.Align[
            cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)],
            self.buffer_align_bytes,
        ]

    self.shared_storage = SharedStorage

    # Launch the kernel on the given stream
    self.kernel(
        tma_atom_a,
        tma_tensor_a if const_expr(not self.gather_A) else mA,
        tma_atom_b,
        tma_tensor_b,
        tma_atom_d,
        tma_tensor_d,
        mD,
        tma_atom_c,
        tma_tensor_c,
        mAIdx,
        mCuSeqlensM,
        mTensormaps,
        tiled_mma,
        self.cluster_layout_mnk,
        self.a_smem_layout_staged,
        self.b_smem_layout_staged,
        self.epi_smem_layout_staged,
        self.epi_c_smem_layout_staged,
        tile_sched_params,
        TileSchedulerCls,
    ).launch(
        grid=grid,
        block=[self.threads_per_cta, 1, 1],
        cluster=self.cluster_shape_mnk,
        smem=self.shared_storage.size_in_bytes(),
        stream=stream,
        min_blocks_per_mp=1,
    )
    return
|
|
687
|
+
|
|
688
|
+
# GPU device kernel
@cute.kernel
def kernel(
    self,
    tma_atom_a: Optional[cute.CopyAtom],
    mA_mkl: cute.Tensor,
    tma_atom_b: cute.CopyAtom,
    mB_nkl: cute.Tensor,
    tma_atom_d: cute.CopyAtom,
    mD_mnl_tma: cute.Tensor,
    mD_mnl: cute.Tensor,
    tma_atom_c: Optional[cute.CopyAtom],
    mC_mnl: Optional[cute.Tensor],
    mAIdx: Optional[cute.Tensor],
    cu_seqlens_m: Optional[cute.Tensor],
    tensormaps: Optional[cute.Tensor],
    tiled_mma: cute.TiledMma,
    cluster_layout_mnk: cute.Layout,
    a_smem_layout_staged: cute.ComposedLayout,
    b_smem_layout_staged: cute.ComposedLayout,
    epi_smem_layout_staged: cute.ComposedLayout,
    epi_c_smem_layout_staged: cute.ComposedLayout,
    tile_sched_params: ParamsBase,
    TileSchedulerCls: cutlass.Constexpr[Callable],
):
    """
    GPU device kernel performing the batched GEMM computation D = A @ B (+ C).

    Warps are split by role:

    * warps with index >= ``self.ab_load_warp_id`` release registers and act as
      producers, loading A/B tiles into the staged shared-memory pipeline;
    * warps below ``self.ab_load_warp_id`` grow their register budget, run the
      MMA mainloop (``self.mma``) and then the epilogue. The epilogue optionally
      loads C, adds it to the accumulator, converts to ``self.d_dtype``, and
      TMA-stores D.

    Tiles are handed out by a (possibly persistent) tile scheduler. In pingpong
    mode, each MMA warp group starts one tile apart and advances by
    ``self.mma_warp_groups`` tiles. The two warp groups take turns in the
    epilogue via the pingpong barriers.

    :param tma_atom_a: TMA copy atom for A. May be None; it is only used when
        A is not gathered (``self.gather_A`` is False).
    :type tma_atom_a: Optional[cute.CopyAtom]
    :param mA_mkl: Input tensor A. This is the TMA view when not gathering,
        otherwise the raw A tensor that is indexed through ``mAIdx``.
    :type mA_mkl: cute.Tensor
    :param tma_atom_b: TMA copy atom for B tensor
    :type tma_atom_b: cute.CopyAtom
    :param mB_nkl: Input tensor B (TMA view)
    :type mB_nkl: cute.Tensor
    :param tma_atom_d: TMA copy atom for D tensor
    :type tma_atom_d: cute.CopyAtom
    :param mD_mnl_tma: Output tensor D (TMA view used for the epilogue stores)
    :type mD_mnl_tma: cute.Tensor
    :param mD_mnl: Output tensor D (not referenced in this kernel body)
    :type mD_mnl: cute.Tensor
    :param tma_atom_c: TMA copy atom for the optional C operand, or None
    :type tma_atom_c: Optional[cute.CopyAtom]
    :param mC_mnl: Optional C tensor. When given, it is TMA-loaded per epilogue
        subtile and added to the accumulator before the conversion to D.
    :type mC_mnl: Optional[cute.Tensor]
    :param mAIdx: Row indices into A, used only when ``self.gather_A`` is True
    :type mAIdx: Optional[cute.Tensor]
    :param cu_seqlens_m: Cumulative row offsets for variable-length M. If not
        None, batch ``b`` covers rows ``cu_seqlens_m[b]`` up to
        ``cu_seqlens_m[b + 1]``, and the kernel runs in "varlen" mode.
    :type cu_seqlens_m: Optional[cute.Tensor]
    :param tensormaps: Global workspace holding one TMA descriptor for D per CTA
        (two per CTA in pingpong mode). Used only in varlen mode.
    :type tensormaps: Optional[cute.Tensor]
    :param tiled_mma: Tiled MMA object
    :type tiled_mma: cute.TiledMma
    :param cluster_layout_mnk: CTA layout within the cluster
    :type cluster_layout_mnk: cute.Layout
    :param a_smem_layout_staged: Shared memory layout for A (stage is mode 2)
    :type a_smem_layout_staged: cute.ComposedLayout
    :param b_smem_layout_staged: Shared memory layout for B (stage is mode 2)
    :type b_smem_layout_staged: cute.ComposedLayout
    :param epi_smem_layout_staged: Shared memory layout for the D epilogue buffers
    :type epi_smem_layout_staged: cute.ComposedLayout
    :param epi_c_smem_layout_staged: Shared memory layout for the C epilogue buffers
    :type epi_c_smem_layout_staged: cute.ComposedLayout
    :param tile_sched_params: Parameters for the tile scheduler. If
        ``tile_count_semaphore`` is not None, a dynamic scheduler pipeline is
        created.
    :type tile_sched_params: ParamsBase
    :param TileSchedulerCls: Scheduler class. Its ``create`` is called as
        ``create(tile_sched_params, tile_count, sched_pipeline, ...)``.
    :type TileSchedulerCls: cutlass.Constexpr[Callable]
    """

    varlen = const_expr(cu_seqlens_m is not None)
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

    # /////////////////////////////////////////////////////////////////////////////
    #  Prefetch Tma desc (done once, by the first load warp)
    # /////////////////////////////////////////////////////////////////////////////
    if warp_idx == self.ab_load_warp_id:
        if const_expr(tma_atom_a is not None):
            cpasync.prefetch_descriptor(tma_atom_a)
        cpasync.prefetch_descriptor(tma_atom_b)
        cpasync.prefetch_descriptor(tma_atom_d)
        if const_expr(tma_atom_c is not None):
            cpasync.prefetch_descriptor(tma_atom_c)

    # Bytes delivered into one A/B stage by TMA: this is the transaction count
    # for the "full" barrier. With gather_A, A is copied by a regular tiled copy
    # (not TMA), so only B's bytes are counted.
    a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
    b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
    tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
    if const_expr(not self.gather_A):
        tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)

    # /////////////////////////////////////////////////////////////////////////////
    #  Alloc and init AB full/empty + ACC full mbar (pipeline)
    # /////////////////////////////////////////////////////////////////////////////
    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(self.shared_storage)

    # Threads/warps participating in this pipeline. With gather_A, every AB load
    # thread presumably also arrives as a producer (cp.async copy of A) in
    # addition to the single TMA-issuing thread.
    producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
    ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
    # Consumer arrive count = mcast_size * (number of MMA warps). Presumably each
    # MMA warp releases the stage on every CTA that multicasts into this CTA's A/B buffers.
    mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
    consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
    ab_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, consumer_arrive_cnt
    )

    cta_layout_vmnk = cute.make_layout((1, *cluster_layout_mnk.shape))
    pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
    ab_pipeline = pipeline_cls.create(
        barrier_storage=storage.ab_pipeline_array_ptr.data_ptr(),
        num_stages=self.ab_stage,
        producer_group=ab_pipeline_producer_group,
        consumer_group=ab_pipeline_consumer_group,
        tx_count=tma_copy_bytes,
        cta_layout_vmnk=cta_layout_vmnk,
    )

    # Optional C pipeline: one TMA producer thread, epilogue warps as consumers.
    if const_expr(mC_mnl is not None):
        # Threads/warps participating in this pipeline
        epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
        # Each epilogue warp contributes 1 to the arrive count
        consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
        epi_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, consumer_arrive_cnt
        )
        c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
        tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
        epi_pipeline = pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
            num_stages=self.epi_c_stage,
            producer_group=epi_pipeline_producer_group,
            consumer_group=epi_pipeline_consumer_group,
            tx_count=tma_copy_c_bytes,
        )
    else:
        epi_pipeline = None

    if const_expr(tile_sched_params.tile_count_semaphore is not None):
        # Dynamic persistent scheduler: the scheduler warp publishes tile indices
        # through `tile_count` in smem, guarded by `sched_pipeline`.
        # Threads/warps participating in this pipeline
        sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
        cluster_size = cute.size(cluster_layout_mnk)
        # Every consumer warp in every CTA of the cluster, except the scheduler warp
        # itself (hence the -1), contributes 1 to the arrive count.
        consumer_arrive_cnt = (
            (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
        ) * cluster_size - 1
        sched_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, consumer_arrive_cnt
        )
        sched_pipeline = pipeline.PipelineAsync.create(
            barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
            num_stages=self.sched_stage,
            producer_group=sched_pipeline_producer_group,
            consumer_group=sched_pipeline_consumer_group,
            # If there's a cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
            consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
        )
        tile_count = storage.tile_count.get_tensor((self.sched_stage,))
    else:
        sched_pipeline = None
        tile_count = None

    # ///////////////////////////////////////////////////////////////////////////////
    #  Generate smem tensor A/B (and the D / C epilogue buffers)
    # ///////////////////////////////////////////////////////////////////////////////
    sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
    sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
    if const_expr(not self.is_persistent):
        # Non-persistent: the D epilogue buffer aliases A's smem (the mainloop is
        # finished by then), so no separate sD storage is used.
        sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
        sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
    else:
        sD = storage.sD.get_tensor(
            epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
        )
    if const_expr(mC_mnl is not None):
        sC = storage.sC.get_tensor(
            epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
        )
    else:
        sC = None

    # Get tensormap buffer address (varlen only: D's TMA descriptor is rewritten per batch)
    if const_expr(varlen):
        grid_dim = cute.arch.grid_dim()
        bid = cute.arch.block_idx()
        # Linear CTA index into the per-CTA descriptor workspace
        tensormap_workspace_idx = (
            bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0]
        )
        # TODO: this is only for D, not for A/B
        if const_expr(self.pingpong):
            # Each of the two MMA warp groups (warp_idx // 4) gets its own descriptor
            tensormap_workspace_idx = tensormap_workspace_idx * 2 + warp_idx // 4
        tensormap_manager = TensorMapManagerSm90(
            self.tensormap_update_mode, HopperWgmmaGemmKernel.bytes_per_tensormap
        )
        tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
            tensormaps[tensormap_workspace_idx, None].iterator
        )
        if const_expr(self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.SMEM):
            # Descriptor is staged in smem (one slot per warp group) before being
            # written back to the global workspace.
            tensormap_smem_ptr = storage.tensormap_buffer.data_ptr()
            tensormap_d_smem_ptr = tensormap_smem_ptr + (warp_idx // 4) * (
                HopperWgmmaGemmKernel.bytes_per_tensormap // 8
            )
            # Need this, otherwise "expected tma descriptor pointer to have alignment at least 64, but got 8"
            tensormap_d_smem_ptr = cute.make_ptr(
                cutlass.Int64,
                tensormap_d_smem_ptr.toint(),
                cute.AddressSpace.smem,
                assumed_align=64,
            )
            tensormap_d_init_ptr = tensormap_d_smem_ptr
        else:
            tensormap_d_smem_ptr = None
            tensormap_d_init_ptr = tensormap_d_ptr
    else:
        tensormap_d_smem_ptr = None
        tensormap_manager, tensormap_d_ptr, tensormap_d_init_ptr = None, None, None

    # Bind the shared scheduler arguments; each role constructs its own scheduler below.
    TileSchedulerCls = partial(
        TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
    )

    # Number of K tiles per output tile, and number of epilogue subtiles per output tile.
    k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
    c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))

    # ===================== Producer (load) warps =====================
    if warp_idx >= self.ab_load_warp_id:
        cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
        if const_expr(mC_mnl is not None):
            epi_load_barrier = pipeline.NamedBarrier(
                barrier_id=int(NamedBarrierGemm.EpilogueLoad),
                num_threads=self.num_ab_load_threads + self.num_epi_load_threads,
            )
        else:
            epi_load_barrier = None
        if (
            warp_idx >= self.ab_load_warp_id
            and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
        ):
            # ///////////////////////////////////////////////////////////////////////////////
            # Get mcast mask
            # ///////////////////////////////////////////////////////////////////////////////
            cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
            cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
            # A is shared along the cluster's N mode (mode=1), B along the M mode (mode=0)
            a_mcast_mask = cute.make_layout_image_mask(
                cluster_layout_mnk, cluster_coord_mnk, mode=1
            )
            b_mcast_mask = cute.make_layout_image_mask(
                cluster_layout_mnk, cluster_coord_mnk, mode=0
            )
            a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
            b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0

            # Persistent tile scheduling loop. Only one warp per cluster (in CTA 0)
            # drives the scheduler.
            is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
            if const_expr(cute.size(cluster_layout_mnk) > 1):
                is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
            tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
            work_tile = tile_scheduler.initial_work_tile_info()
            ab_producer_state = make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.ab_stage
            )
            do_epi_load_barrier_arrive = cutlass.Boolean(True)
            while work_tile.is_valid_tile:
                tile_coord_mnkl = work_tile.tile_idx
                batch_idx = tile_coord_mnkl[3]
                # ///////////////////////////////////////////////////////////////////////////
                #  Local_tile partition global tensors
                # ///////////////////////////////////////////////////////////////////////////
                if const_expr(not self.gather_A):
                    if const_expr(cu_seqlens_m is not None):
                        # varlen: shift A so that row 0 is this batch's first row
                        mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
                    else:
                        mA_mk = mA_mkl[None, None, batch_idx]
                    # (bM, bK, RestK)
                    gA_k = cute.local_tile(
                        mA_mk,
                        cute.select(self.tile_shape_mnk, [0, 2]),
                        (tile_coord_mnkl[0], None),
                    )
                else:
                    # gather_A: A rows are looked up through mAIdx, tiled along M
                    mA_mk = mA_mkl
                    if const_expr(cu_seqlens_m is not None):
                        mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
                    else:
                        mAIdx_mk = mAIdx[None, batch_idx]
                    gAIdx = cute.local_tile(
                        mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
                    )
                # (bN, bK, RestK)
                gB_k = cute.local_tile(
                    mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
                )
                # //////////////////////////////////////////////////////////////////////////
                #  Partition shared tensor for TMA load A/B
                # //////////////////////////////////////////////////////////////////////////
                #  TMA load A partition_S/D
                a_cta_layout = cute.make_layout(
                    cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
                )
                a_cta_crd = cluster_coord_mnk[1]
                if const_expr(not self.gather_A):
                    # ((atom_v, rest_v), STAGE)
                    # ((atom_v, rest_v), RestK)
                    tAsA, tAgA_k = cpasync.tma_partition(
                        tma_atom_a,
                        a_cta_crd,
                        a_cta_layout,
                        cute.group_modes(sA, 0, 2),
                        cute.group_modes(gA_k, 0, 2),
                    )
                    copy_A = partial(cute.copy, tma_atom_a, mcast_mask=a_mcast_mask)
                else:
                    tiled_copy_A = self._make_gmem_tiled_copy_A(
                        mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
                    )
                    # Thread index relative to the first load thread
                    tidx = (
                        cute.arch.thread_idx()[0]
                        - self.mma_warp_groups * self.num_threads_per_warp_group
                    )
                    thr_copy_A = tiled_copy_A.get_slice(tidx)
                    # (atom_v, CPY_M, 1, STAGE)
                    tAsA = thr_copy_A.partition_D(sA)
                    # Trace-time check: the gather copy must cover the whole K tile in one step
                    assert tAsA.shape[2] == 1
                    tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
                    copy_A = partial(cute.copy, tiled_copy_A)
                # TMA load B partition_S/D
                b_cta_layout = cute.make_layout(
                    cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
                )
                b_cta_crd = cluster_coord_mnk[0]
                # ((atom_v, rest_v), STAGE)
                # ((atom_v, rest_v), RestK)
                tBsB, tBgB_k = cpasync.tma_partition(
                    tma_atom_b,
                    b_cta_crd,
                    b_cta_layout,
                    cute.group_modes(sB, 0, 2),
                    cute.group_modes(gB_k, 0, 2),
                )
                copy_B = partial(cute.copy, tma_atom_b, mcast_mask=b_mcast_mask)
                if const_expr(not self.gather_A):
                    ab_producer_state = self.load_AB(
                        ab_pipeline,
                        ab_producer_state,
                        copy_A,
                        tAgA_k,
                        tAsA,
                        copy_B,
                        tBgB_k,
                        tBsB,
                    )
                else:
                    # Number of valid A rows in this batch (bounds the gather)
                    limit_m = (
                        mAIdx.shape[0]
                        if const_expr(cu_seqlens_m is None)
                        else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
                    )
                    ab_producer_state = self.load_AB_gather_A(
                        ab_pipeline,
                        ab_producer_state,
                        thr_copy_A,
                        mA_mk,
                        tAsA,
                        gAIdx,
                        copy_B,
                        tBgB_k,
                        tBsB,
                        limit_A=(
                            limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
                            mA_mk.shape[1],
                        ),
                    )
                if const_expr(epi_load_barrier is not None):
                    # In the first work tile, the epi load warp will wait for the signal
                    # from the mainloop load warp to start loading C, to avoid interfering
                    # with loading A and B.
                    # NOTE(review): with the C-load warp path below commented out, nothing
                    # visible in this kernel waits on epi_load_barrier; confirm that
                    # num_epi_load_threads accounts for that.
                    if do_epi_load_barrier_arrive:
                        epi_load_barrier.arrive()
                        do_epi_load_barrier_arrive = cutlass.Boolean(False)
                tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
                tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
                work_tile = tile_scheduler.get_current_work()
                # End of persistent scheduler loop
            if const_expr(self.pingpong):
                # Need to write the tile_idx to smem for the next WG in the pingpong mode
                tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
            # Wait until consumers have released all in-flight A/B stages before exiting
            ab_pipeline.producer_tail(ab_producer_state)
            if is_scheduler_warp:
                tile_scheduler.producer_tail()

        # NOTE(review): disabled dedicated epilogue-load (C) warp path, kept for reference.
        # C is currently loaded by the TMA warp of each MMA warp group inside the epilogue.
        # if const_expr(mC_mnl is not None):
        #     if warp_idx == self.epi_load_warp_id:
        #         epi_producer_state = make_pipeline_state(
        #             pipeline.PipelineUserType.Producer, self.epi_c_stage
        #         )
        #         do_epi_load_barrier_wait = cutlass.Boolean(True)
        #         tile_scheduler = TileSchedulerCls()
        #         work_tile = tile_scheduler.initial_work_tile_info()
        #         while work_tile.is_valid_tile:
        #             tile_coord_mnkl = work_tile.tile_idx
        #             batch_idx = tile_coord_mnkl[3]
        #             if const_expr(cu_seqlens_m is not None):
        #                 mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
        #             else:
        #                 mC_mn = mC_mnl[None, None, batch_idx]
        #             # (bM, bN)
        #             gC = cute.local_tile(
        #                 mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
        #             )
        #             tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
        #             bGS_sC, bGS_gC = cpasync.tma_partition(
        #                 tma_atom_c,
        #                 0,
        #                 cute.make_layout(1),
        #                 cute.group_modes(sC, 0, 2),
        #                 tCgC_for_tma_partition,
        #             )
        #             if do_epi_load_barrier_wait:
        #                 epi_load_barrier.arrive_and_wait()
        #                 do_epi_load_barrier_wait = cutlass.Boolean(False)
        #             epi_tile_num = const_expr(cute.size(tCgC_for_tma_partition, mode=[1]))
        #             epi_tile_shape = tCgC_for_tma_partition.shape[1]
        #             for epi_idx in cutlass.range(epi_tile_num, unroll=1):
        #                 epi_pipeline.producer_acquire(epi_producer_state)
        #                 # Get the global memory coordinate for the current epi tile
        #                 epi_tile_layout = cute.make_layout(
        #                     epi_tile_shape, stride=(epi_tile_shape[1], 1)
        #                 )
        #                 gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
        #                 cute.copy(
        #                     tma_atom_c,
        #                     bGS_gC[None, gmem_coord],
        #                     bGS_sC[None, epi_producer_state.index],
        #                     tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
        #                 )
        #                 # Epi pipeline's producer commit is a NOP
        #                 epi_pipeline.producer_commit(epi_producer_state)
        #                 epi_producer_state.advance()
        #             tile_scheduler.advance_to_next_work()
        #             work_tile = tile_scheduler.get_current_work()
        #             # End of persistent scheduler loop
        #         epi_pipeline.producer_tail(epi_producer_state)

    # ===================== MMA + epilogue warps =====================
    if warp_idx < self.ab_load_warp_id:
        cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
        # One warp per MMA warp group issues TMA stores/loads and manages descriptors
        is_tma_warp = cutlass.Boolean(
            (not self.pingpong and warp_idx == 0)
            or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
        )
        if const_expr(varlen):
            # initialize tensormap for D
            tensormap_manager.init_tensormap_from_atom(
                tma_atom_d,
                tensormap_d_init_ptr,
                is_manager_warp=is_tma_warp,
            )
        # //////////////////////////////////////////////////////////////////////////////
        #  Partition global tensor for TiledMMA_A/B/C
        # //////////////////////////////////////////////////////////////////////////////
        tidx, _, _ = cute.arch.thread_idx()
        warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
        if const_expr(self.pingpong):
            # In pingpong each warp group computes a whole tile, so use the
            # thread index within the warp group.
            tidx = tidx % self.num_threads_per_warp_group
        warp_group_thread_layout = cute.make_layout(
            self.mma_warp_groups if not self.pingpong else 1,
            stride=self.num_threads_per_warp_group,
        )
        thr_mma = tiled_mma.get_slice(
            warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
        )

        # //////////////////////////////////////////////////////////////////////////////
        #  Make fragments
        # //////////////////////////////////////////////////////////////////////////////
        tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
        tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))

        acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
        acc = cute.make_fragment(acc_shape, self.acc_dtype)
        if const_expr(self.fp8_slow_accum):
            # Second accumulator used by self.mma when fp8_slow_accum is enabled
            acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
        else:
            acc_slow = None

        if const_expr(self.pingpong):
            if warp_group_idx == 0:
                # WG0 needs a start signal at the very beginning
                self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
                self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")

        ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
        epi_read_state = make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.epi_c_stage
        )
        epi_producer_state = make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.epi_c_stage
        )
        tile_scheduler = TileSchedulerCls()
        if const_expr(self.pingpong):
            if warp_idx >= 4:
                # Advance 2nd Math WG to the next work tile for the startup
                tile_scheduler.advance_to_next_work()
                # Advance 2nd Math WG pipeline states to the end of 1st Math WG
                ab_read_state.advance_iters(k_tile_cnt)
                epi_read_state.advance_iters(c_tile_cnt)
                epi_producer_state.advance_iters(c_tile_cnt)
        work_tile = tile_scheduler.initial_work_tile_info()
        if const_expr(varlen):
            # wait tensormap initialization complete before update
            tensormap_manager.fence_tensormap_initialization()
        # batch index of last tile
        last_batch_idx = cutlass.Int32(-1)
        while work_tile.is_valid_tile:
            tile_coord_mnkl = work_tile.tile_idx
            batch_idx = tile_coord_mnkl[3]
            if const_expr(varlen):
                is_group_changed = batch_idx != last_batch_idx
                last_batch_idx = batch_idx
                if is_group_changed:
                    # Rebuild D's descriptor for the new batch. The M extent is set to
                    # cu_seqlens_m[batch_idx + 1], which presumably clips TMA stores at
                    # the end of this batch's rows. D itself is offset by
                    # cu_seqlens_m[batch_idx] further below.
                    tensormap_manager.update_tensormap_shape(
                        ((tensormap_d_ptr),),
                        is_manager_warp=is_tma_warp,
                        tensormap_smem_ptr=(tensormap_d_smem_ptr,),
                        shapes=(cu_seqlens_m[batch_idx + 1],),
                        orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
                    )

            # Mainloop: consume k_tile_cnt A/B stages into acc
            ab_read_state, tiled_mma = self.mma(
                ab_pipeline,
                ab_read_state,
                tiled_mma,
                tCrA,
                tCrB,
                acc,
                acc_slow,
                k_tile_cnt,
                warp_group_idx,
            )
            if const_expr(self.pingpong):
                # Update starting mainloop pipeline state for the next tile
                # (skip over the stages consumed by the other warp group)
                ab_read_state.advance_iters(k_tile_cnt)

            # /////////////////////////////////////////////////////////////////////////////
            #  EPILOGUE
            # /////////////////////////////////////////////////////////////////////////////
            if const_expr(self.pingpong):
                self.pingpong_barrier_sync(warp_group_idx, "epi")

            epilogue_barrier = pipeline.NamedBarrier(
                barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
            )

            # Wait for all warp groups in the thread block to finish, because smem for tensor
            # A in the mainloop is reused in the epilogue if not persistent.
            if const_expr(not self.is_persistent):
                epilogue_barrier.arrive_and_wait()

            if const_expr(varlen):
                # ensure the update to tensormap has completed before using it
                if is_group_changed:
                    if is_tma_warp:
                        tensormap_manager.fence_tensormap_update(tensormap_d_ptr)

            # NOTE: tile_N % 8 == 0 with tile_N % 16 != 0 is not supported here, since
            # sm90_get_smem_store_op presumably always picks st.matrix with num_matrices=4.
            copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
                self.d_layout, elem_ty_d=self.d_dtype, elem_ty_acc=self.acc_dtype
            )
            copy_atom_C = cute.make_copy_atom(
                warp.StMatrix8x8x16bOp(
                    self.d_layout.is_m_major_c(),
                    num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
                ),
                cutlass.Float16,  # this is just to get the right source layout
            )
            tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
            tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
            # (R2S, R2S_M, R2S_N, PIPE_D)
            thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
            tRS_sD = thr_copy_r2s.partition_D(sD)
            # (R2S, R2S_M, R2S_N)
            tRS_rAcc = tiled_copy_r2s.retile(acc)

            # Allocate D registers (one epilogue subtile's worth, in acc_dtype).
            tRS_rD_layout = cute.make_layout(thr_copy_r2s.partition_S(sD).shape[:3])
            tRS_rD = cute.make_fragment(tRS_rD_layout, self.acc_dtype)

            if const_expr(mC_mnl is not None):
                # smem -> register copy for C, same thread mapping as the D stores
                copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
                tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
                thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
                tSR_sC = thr_copy_s2r.partition_S(sC)
                tRS_rC = cute.make_fragment(tRS_rD_layout, self.c_dtype)
                tSR_rC = thr_copy_s2r.retile(tRS_rC)
            else:
                thr_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None

            if const_expr(cu_seqlens_m is not None):
                mD_mn_tma = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl_tma)
            else:
                mD_mn_tma = mD_mnl_tma[None, None, batch_idx]
            # (bM, bN)
            gD = cute.local_tile(
                mD_mn_tma, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
            )
            tDgD_for_tma_partition = cute.zipped_divide(gD, self.epi_tile)
            bSG_sD, bSG_gD = cpasync.tma_partition(
                tma_atom_d,
                0,
                cute.make_layout(1),
                cute.group_modes(sD, 0, 2),
                tDgD_for_tma_partition,
            )

            if const_expr(mC_mnl is not None):
                if const_expr(cu_seqlens_m is not None):
                    mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
                else:
                    mC_mn = mC_mnl[None, None, batch_idx]
                # (bM, bN)
                gC = cute.local_tile(
                    mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
                )
                tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
                bGS_sC, bGS_gC = cpasync.tma_partition(
                    tma_atom_c,
                    0,
                    cute.make_layout(1),
                    cute.group_modes(sC, 0, 2),
                    tCgC_for_tma_partition,
                )

            epi_tile_num = const_expr(cute.size(tDgD_for_tma_partition, mode=[1]))
            epi_tile_shape = tDgD_for_tma_partition.shape[1]
            # Subtiles already stored by this scheduler; used to rotate through the sD buffers
            num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
            # Row-major (N fastest) ordering of the epilogue subtiles
            epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))

            if const_expr(mC_mnl is not None):
                # Prefetch C for the first min(epi_tile_num, epi_c_stage) subtiles
                for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
                    if is_tma_warp:
                        epi_pipeline.producer_acquire(epi_producer_state)
                        # Get the global memory coordinate for the current epi tile
                        gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
                        cute.copy(
                            tma_atom_c,
                            bGS_gC[None, gmem_coord],
                            bGS_sC[None, epi_producer_state.index],
                            tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
                        )
                        # Epi pipeline's producer commit is a NOP
                        epi_pipeline.producer_commit(epi_producer_state)
                    # All threads advance so the state stays in sync across the warp group
                    epi_producer_state.advance()

            for epi_idx in cutlass.range_constexpr(epi_tile_num):
                # Copy from acc to D registers
                for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
                    tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
                if const_expr(mC_mnl is not None):
                    # Read this subtile's C from smem, then release the smem slot
                    epi_pipeline.consumer_wait(epi_read_state)
                    cute.copy(
                        thr_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
                    )
                    # Fence to make sure shared memory read is visible to TMA load
                    cute.arch.fence_proxy(
                        cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                    )
                    cute.arch.sync_warp()
                    with cute.arch.elect_one():
                        epi_pipeline.consumer_release(epi_read_state)
                    epi_read_state.advance()
                    # Refill the released slot with C for subtile epi_idx + epi_c_stage
                    if const_expr(epi_idx + self.epi_c_stage < epi_tile_num):
                        if is_tma_warp:
                            epi_pipeline.producer_acquire(epi_producer_state)
                            # Get the global memory coordinate for the current epi tile
                            gmem_coord = epi_tile_layout.get_hier_coord(
                                epi_idx + self.epi_c_stage
                            )
                            cute.copy(
                                tma_atom_c,
                                bGS_gC[None, gmem_coord],
                                bGS_sC[None, epi_producer_state.index],
                                tma_bar_ptr=epi_pipeline.producer_get_barrier(
                                    epi_producer_state
                                ),
                            )
                            # Epi pipeline's producer commit is a NOP
                            epi_pipeline.producer_commit(epi_producer_state)
                        epi_producer_state.advance()
                    # D += C (in accumulator precision)
                    tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(self.acc_dtype))
                # Type conversion
                tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
                tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
                # Copy from D registers to shared memory (buffers rotate across subtiles)
                epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sD, mode=[3])
                cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
                # Fence and barrier to make sure shared memory store is visible to TMA store
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                )
                epilogue_barrier.arrive_and_wait()
                # Get the global memory coordinate for the current epi tile
                gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
                # Copy from shared memory to global memory
                if is_tma_warp:
                    if const_expr(varlen):
                        # Use the per-batch descriptor rewritten above
                        tma_desc_ptr = tensormap_manager.get_tensormap_ptr(
                            tensormap_d_ptr,
                            cute.AddressSpace.generic,
                        )
                    else:
                        tma_desc_ptr = None
                    cute.copy(
                        tma_atom_d,
                        bSG_sD[None, epi_buffer],
                        bSG_gD[None, gmem_coord],
                        tma_desc_ptr=tma_desc_ptr,
                    )
                    cute.arch.cp_async_bulk_commit_group()
                    # Allow at most epi_stage - 1 stores whose smem source is still being
                    # read, so the next sD buffer to be written is free.
                    cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
                # Keep other threads from overwriting sD before the TMA warp has waited
                epilogue_barrier.arrive_and_wait()

            if const_expr(self.pingpong):
                # Update starting load/store pipeline states for the next tile
                epi_read_state.advance_iters(c_tile_cnt)
                epi_producer_state.advance_iters(c_tile_cnt)
                # With pingpong, 2 WGs write two different output tiles to the same smem,
                # so we have to make sure the smem content is done reading before signaling
                # the next WG's epilogue.
                if warp_idx == 0 or warp_idx == 4:
                    cute.arch.cp_async_bulk_wait_group(0, read=True)
                self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")

            tile_scheduler.advance_to_next_work(
                advance_count=1 if not self.pingpong else self.mma_warp_groups
            )
            work_tile = tile_scheduler.get_current_work()
            # End of persistent scheduler loop

        if const_expr(not self.pingpong):
            # Drain outstanding TMA stores' smem reads before the CTA exits
            if warp_idx == 0:
                cute.arch.cp_async_bulk_wait_group(0, read=True)
|
|
1411
|
+
|
|
1412
|
+
@cute.jit
def load_AB(
    self,
    ab_pipeline: cutlass.pipeline.PipelineAsync,
    ab_producer_state: cutlass.pipeline.PipelineState,
    copy_A: Callable,
    tAgA: cute.Tensor,
    tAsA: cute.Tensor,
    copy_B: Callable,
    tBgB: cute.Tensor,
    tBsB: cute.Tensor,
) -> cutlass.pipeline.PipelineState:
    """Producer loop: load every K tile of A and B into the A/B smem pipeline.

    For each K tile this acquires the next pipeline stage, issues ``copy_A`` and
    ``copy_B`` for that tile into the stage's smem slot (both given the stage's
    barrier via ``tma_bar_ptr``), commits the stage and advances the producer state.

    :param ab_pipeline: A/B pipeline this warp produces into.
    :param ab_producer_state: Producer state to start from; advanced once per K tile.
    :param copy_A: Copy callable for A, invoked as ``copy_A(src, dst, tma_bar_ptr=...)``.
    :param tAgA: Partitioned global A; mode 1 is indexed by the K tile (its size is the
        number of K tiles).
    :param tAsA: Partitioned smem A; mode 1 is indexed by the pipeline stage index.
    :param copy_B: Copy callable for B, invoked like ``copy_A``.
    :param tBgB: Partitioned global B, indexed like ``tAgA``.
    :param tBsB: Partitioned smem B, indexed like ``tAsA``.
    :return: The producer state after the last K tile, for use by the next tile.
    """
    k_tile_cnt = cute.size(tAgA, mode=[1])
    # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
    # NOTE(review): the peek result is handed to producer_acquire below; presumably it
    # lets acquire skip the wait when the slot was already free -- confirm in pipeline API.
    peek_ab_empty_status = cutlass.Boolean(True)
    if 0 < k_tile_cnt:
        peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    # /////////////////////////////////////////////////////////////////////////
    # TMA load
    # /////////////////////////////////////////////////////////////////////////
    for k_tile in cutlass.range(k_tile_cnt, unroll=1):
        # Wait for A/B buffers to be empty before loading into them
        # Also sets the transaction barrier for the A/B buffers
        ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
        # A and B loads for this stage signal the same barrier.
        tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
        copy_A(
            tAgA[None, k_tile],
            tAsA[None, ab_producer_state.index],
            tma_bar_ptr=tma_bar_ptr,
        )
        copy_B(
            tBgB[None, k_tile],
            tBsB[None, ab_producer_state.index],
            tma_bar_ptr=tma_bar_ptr,
        )
        # Mainloop pipeline's producer commit is a NOP
        ab_pipeline.producer_commit(ab_producer_state)
        ab_producer_state.advance()
        # Peek the next stage only if another K tile follows.
        peek_ab_empty_status = cutlass.Boolean(True)
        if k_tile + 1 < k_tile_cnt:
            peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    return ab_producer_state
|
|
1454
|
+
|
|
1455
|
+
@cute.jit
def load_AB_gather_A(
    self,
    ab_pipeline: cutlass.pipeline.PipelineAsync,
    ab_producer_state: cutlass.pipeline.PipelineState,
    thr_copy_A: cute.core.ThrCopy,
    mA: cute.Tensor,
    tAsA: cute.Tensor,
    gAIdx: cute.Tensor,
    copy_B: Callable,
    tBgB: cute.Tensor,
    tBsB: cute.Tensor,
    limit_A: Tuple[Int32, Int32],
) -> cutlass.pipeline.PipelineState:
    """Producer loop where A rows are gathered through an index tensor and B is copied.

    A is copied with ``cute.copy`` through ``thr_copy_A``: each row this thread
    handles is looked up as ``gAIdx[row]`` before the loop, and rows at or beyond
    ``limit_m`` are skipped. B is copied with ``copy_B`` by a single warp, rotating
    among ``self.num_ab_load_warps`` warps starting at ``self.ab_load_warp_id``.

    :param ab_pipeline: A/B pipeline this warp produces into.
    :param ab_producer_state: Producer state to start from; advanced once per K tile.
    :param thr_copy_A: Per-thread slice of the tiled copy used for A.
    :param mA: Full A tensor; indexed as ``mA_k[row, (None, k_tile)]`` after a K split.
    :param tAsA: Partitioned smem A; the last index is the pipeline stage index.
    :param gAIdx: Row-index tensor; ``gAIdx[r]`` gives the A row loaded for tile row r.
    :param copy_B: Copy callable for B, invoked as ``copy_B(src, dst, tma_bar_ptr=...)``.
    :param tBgB: Partitioned global B; mode 1 is indexed by K tile.
    :param tBsB: Partitioned smem B; mode 1 is indexed by the pipeline stage index.
    :param limit_A: ``(limit_m, limit_k)`` bounds for the A tile.
    :return: The producer state after the last K tile.
    """
    # (atom_v, CPY_M, 1, RestK)
    limit_m, limit_k = limit_A
    limit_m = min(limit_m, self.tile_shape_mnk[0])  # To avoid writing beyond smem limit
    # Identity (coordinate) tensor over the (M, K) tile, partitioned like the A copy.
    cA = cute.make_identity_tensor(cute.select(self.tile_shape_mnk, [0, 2]))
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    # Read indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    m_idx = cute.make_fragment(rows_per_thread, Int32)
    for m in cutlass.range(rows_per_thread):
        row_idx = tAcA[0, m, 0][0]
        if t0AcA[0, m, 0][0] < limit_m:
            m_idx[m] = gAIdx[row_idx]
        else:
            # Out-of-range row: the -1 index is never used for a copy (guarded below).
            m_idx[m] = -1
    elems_per_load = cute.size(tAsA.shape[0][0])
    # (m, (bK, RestK))
    mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
    k_tile_cnt = cute.size(tBgB, mode=[1])
    # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
    peek_ab_empty_status = cutlass.Boolean(True)
    if 0 < k_tile_cnt:
        peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    # /////////////////////////////////////////////////////////////////////////
    # TMA load on B and cp.async on A
    # /////////////////////////////////////////////////////////////////////////
    copy_A = partial(cute.copy, thr_copy_A)
    # All K tiles except the last; the last one is peeled off below for K bound checking.
    for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
        # Wait for A/B buffers to be empty before loading into them
        # Also sets the transaction barrier for the A/B buffers
        ab_pipeline.producer_acquire(
            ab_producer_state,
            peek_ab_empty_status,
            # A tiny bit faster to rotate the warp that does TMA
            is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
        )
        # A bit faster to load B first while we calculate the predicate for A
        if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
            copy_B(
                tBgB[None, k_tile],
                tBsB[None, ab_producer_state.index],
                tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
            )
        # (m, bK)
        mA_cur = mA_k[None, (None, k_tile)]
        for m in cutlass.range_constexpr(tAcA.shape[1]):
            # (elems_per_load, thread_per_row)
            mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
            if t0AcA[0, m, 0][0] < limit_m:
                # There's only 1 load per row
                assert cute.size(tAcA.shape, mode=[2]) == 1
                ki = tAcA[0, 0, 0][1] // elems_per_load
                copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
        # This tells mbarrier to track the completion of cp.async
        ab_pipeline.producer_commit(ab_producer_state)
        ab_producer_state.advance()
        peek_ab_empty_status = cutlass.Boolean(True)
        if k_tile + 1 < k_tile_cnt:
            peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
    # bound checking in the K dimension on the last k_tile
    if 0 < k_tile_cnt:
        k_tile = k_tile_cnt - 1
        ab_pipeline.producer_acquire(
            ab_producer_state,
            peek_ab_empty_status,
            is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
        )
        if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
            copy_B(
                tBgB[None, k_tile],
                tBsB[None, ab_producer_state.index],
                tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
            )
        assert tAcA.shape[2] == 1  # there's only 1 load along the K dimension
        # NOTE(review): tApA (the K-bound predicate) is computed but never passed to
        # copy_A below (the predicated call is commented out, see TODO), so the last
        # K tile of A is currently loaded without a K bound check -- confirm intended.
        tApA = cute.make_fragment(1, cutlass.Boolean)
        tApA[0] = tAcA[0, 0, 0][1] < limit_k
        # (m, bK)
        mA_cur = mA_k[None, (None, k_tile)]
        for m in cutlass.range_constexpr(tAcA.shape[1]):
            # (elems_per_load, thread_per_row)
            mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
            if t0AcA[0, m, 0][0] < limit_m:
                # There's only 1 load per row
                assert cute.size(tAcA.shape, mode=[2]) == 1
                ki = tAcA[0, 0, 0][1] // elems_per_load
                # copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA)
                # TODO
                copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
        ab_pipeline.producer_commit(ab_producer_state)
        ab_producer_state.advance()
    return ab_producer_state
|
|
1565
|
+
|
|
1566
|
+
@cute.jit
def mma(
    self,
    ab_pipeline: cutlass.pipeline.PipelineAsync,
    ab_read_state: cutlass.pipeline.PipelineState,
    tiled_mma: cute.TiledMma,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    acc: cute.Tensor,
    acc_slow: Optional[cute.Tensor],
    k_tile_cnt: Int32,
    warp_group_idx: Int32,
) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
    """Consumer loop: issue the WGMMAs for all ``k_tile_cnt`` K tiles of one output tile.

    Each K tile waits for its A/B pipeline stage, issues one ``cute.gemm`` per
    k-block of that stage into ``acc`` and commits the group. Stages are released
    ``k_pipe_mmas`` (= 1) tiles behind the read state in the main loop; the
    remaining stage(s) are released after the final ``wait_group(0)``.

    With ``self.fp8_slow_accum``, every main-loop K tile computes its product into
    ``acc`` from scratch (ACCUMULATE reset to False) and adds it into ``acc_slow``;
    ``acc`` receives ``acc_slow`` at the end. With ``self.pingpong``, the warp group
    first syncs on its own "mma" ping-pong barrier and, after the main loop,
    arrives on the other warp group's "mma" barrier.

    :param ab_pipeline: A/B smem pipeline this warp group consumes.
    :param ab_read_state: Consumer state of the first stage to read; advanced per K tile.
    :param tiled_mma: Tiled MMA; its ACCUMULATE field is toggled here.
    :param tCrA: MMA view of A; mode 2 indexes k-blocks, mode 3 the pipeline stage.
    :param tCrB: MMA view of B, indexed like ``tCrA``.
    :param acc: Accumulator; overwritten (the first MMA runs with ACCUMULATE=False).
    :param acc_slow: Second accumulator, only used when ``self.fp8_slow_accum``
        (presumably None otherwise -- confirm at call site).
    :param k_tile_cnt: Number of K tiles to consume.
    :param warp_group_idx: This warp group's index (used as ``1 - warp_group_idx`` for the
        peer, so assumed to be 0 or 1 in pingpong mode).
    :return: ``(ab_read_state, tiled_mma)``: the advanced read state and the tiled MMA.
    """
    # /////////////////////////////////////////////////////////////////////////////
    # Prologue MMAs
    # /////////////////////////////////////////////////////////////////////////////
    # Number of committed WGMMA groups allowed in flight before a stage is released.
    k_pipe_mmas = 1
    # Trails ab_read_state; points at the next stage to hand back to the producer.
    ab_release_state = ab_read_state.clone()
    num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
    if const_expr(self.pingpong):
        self.pingpong_barrier_sync(warp_group_idx, stage="mma")
    peek_ab_full_status = cutlass.Boolean(True)
    if 0 < k_tile_cnt:
        peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
    # First MMA of the tile overwrites acc instead of accumulating into it.
    tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
    num_k_blocks = cute.size(tCrA, mode=[2])
    # TODO: this is probably not correct if k_tile_cnt == 0
    for k_tile in cutlass.range(num_prologue_mma):
        # Wait for A/B buffer to be ready
        ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
        warpgroup.fence()
        for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
            k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
            cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
            tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
        warpgroup.commit_group()
        ab_read_state.advance()
        peek_ab_full_status = cutlass.Boolean(True)
        if k_tile + 1 < k_tile_cnt:
            peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
    if const_expr(self.fp8_slow_accum):
        # Seed acc_slow with the prologue result once it has completed.
        warpgroup.wait_group(0)
        acc_slow.store(acc.load())

    # /////////////////////////////////////////////////////////////////////////////
    # MAINLOOP
    # /////////////////////////////////////////////////////////////////////////////
    for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
        # Wait for TMA copies to complete
        ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
        # WGMMA
        warpgroup.fence()
        if const_expr(self.fp8_slow_accum):
            # Each K tile is computed into acc from scratch, then summed into acc_slow.
            tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
        for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
            k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
            cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
            tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
        warpgroup.commit_group()
        # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
        if const_expr(not self.fp8_slow_accum):
            warpgroup.wait_group(k_pipe_mmas)
        else:
            warpgroup.wait_group(0)
            acc_slow.store(acc_slow.load() + acc.load())
        # The stage read k_pipe_mmas tiles ago is no longer in use by WGMMA.
        ab_pipeline.consumer_release(ab_release_state)
        ab_read_state.advance()
        ab_release_state.advance()
        peek_ab_full_status = cutlass.Boolean(True)
        if k_tile + 1 < k_tile_cnt:
            peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
    if const_expr(self.pingpong):
        # Cue for next WG's MMA to start
        self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
    if const_expr(not self.fp8_slow_accum):
        # fp8_slow_accum would already called wait_group(0) inside the loop
        warpgroup.wait_group(0)
    # Release the stage(s) still held from the last k_pipe_mmas tiles.
    # NOTE(review): this runs even when k_tile_cnt == 0 (see TODO above).
    for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
        ab_pipeline.consumer_release(ab_release_state)
        ab_release_state.advance()
    if const_expr(self.fp8_slow_accum):
        acc.store(acc_slow.load())
    # If we don't return the tiled_mma, we get compiler error
    # "operand #0 does not dominate this use"
    return ab_read_state, tiled_mma
|
|
1651
|
+
|
|
1652
|
+
def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
    """Synchronize on the ping-pong named barrier belonging to ``warp_group_idx``.

    The barrier id is the stage's base id (``NamedBarrierGemm.MmaWG0`` for
    ``"mma"``, ``NamedBarrierGemm.EpiWG0`` for ``"epi"``) offset by
    ``warp_group_idx``. The barrier counts the threads of two warp groups.

    :param warp_group_idx: Index of the warp group whose barrier is used.
    :param stage: Either ``"mma"`` or ``"epi"``.
    """
    assert stage in ["mma", "epi"]
    if stage == "mma":
        base_barrier = NamedBarrierGemm.MmaWG0
    else:
        base_barrier = NamedBarrierGemm.EpiWG0
    target_barrier_id = int(base_barrier) + warp_group_idx
    cute.arch.barrier(
        barrier_id=target_barrier_id,
        number_of_threads=2 * self.num_threads_per_warp_group,
    )
|
|
1659
|
+
|
|
1660
|
+
def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
    """Arrive on the ping-pong named barrier belonging to ``warp_group_idx``.

    Uses ``cute.arch.barrier_arrive`` (rather than ``cute.arch.barrier``) on the
    barrier id formed from the stage's base id (``NamedBarrierGemm.MmaWG0`` for
    ``"mma"``, ``NamedBarrierGemm.EpiWG0`` for ``"epi"``) plus ``warp_group_idx``.
    The barrier counts the threads of two warp groups.

    :param warp_group_idx: Index of the warp group whose barrier is signalled.
    :param stage: Either ``"mma"`` or ``"epi"``.
    """
    assert stage in ["mma", "epi"]
    if stage == "mma":
        base_barrier = NamedBarrierGemm.MmaWG0
    else:
        base_barrier = NamedBarrierGemm.EpiWG0
    target_barrier_id = int(base_barrier) + warp_group_idx
    cute.arch.barrier_arrive(
        barrier_id=target_barrier_id,
        number_of_threads=2 * self.num_threads_per_warp_group,
    )
|
|
1667
|
+
|
|
1668
|
+
@staticmethod
def _compute_stages(
    tile_shape_mnk: Tuple[int, int, int],
    epi_tile: Optional[Tuple[int, int]],
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    d_dtype: Type[cutlass.Numeric],
    c_dtype: Optional[Type[cutlass.Numeric]],
    smem_capacity: int,
    occupancy: int,
    overlap_sD_sA: bool,
) -> Tuple[int, int, int]:
    """Computes the number of smem pipeline stages for A/B, the D epilogue and C.

    Two D epilogue stages are reserved first (none if D shares A's smem), plus two
    C stages when C is present. The smem left per CTA then holds as many A/B stages
    as fit. Any leftover bytes go to extra D epilogue stages, unless D overlaps A.

    :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
    :type tile_shape_mnk: Tuple[int, int, int]
    :param epi_tile: Epilogue tile shape; its size sets the bytes of one D/C stage.
        Only read when D does not overlap A or when ``c_dtype`` is given.
    :type epi_tile: Optional[Tuple[int, int]]
    :param a_dtype: Data type of operand A.
    :type a_dtype: type[cutlass.Numeric]
    :param b_dtype: Data type of operand B.
    :type b_dtype: type[cutlass.Numeric]
    :param d_dtype: Data type of the output D.
    :type d_dtype: type[cutlass.Numeric]
    :param c_dtype: Data type of the optional epilogue input C, or None.
    :type c_dtype: Optional[type[cutlass.Numeric]]
    :param smem_capacity: Total available shared memory capacity in bytes.
    :type smem_capacity: int
    :param occupancy: Target number of CTAs per SM (occupancy).
    :type occupancy: int
    :param overlap_sD_sA: If True, D's epilogue smem is not budgeted separately
        (it is assumed to reuse A's smem) and the D stage count stays at 2.
    :type overlap_sD_sA: bool

    :return: ``(ab_stage, epi_stage, epi_c_stage)``: number of A/B stages, D epilogue
        stages, and C epilogue stages (0 when ``c_dtype`` is None).
    :rtype: Tuple[int, int, int]
    """

    epi_stage = 2
    if overlap_sD_sA:
        epi_bytes = 0
    else:
        d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8
        epi_bytes = d_bytes_per_stage * epi_stage
    epi_c_stage = 0 if c_dtype is None else 2
    if c_dtype is not None:
        epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage

    a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
    b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
    ab_bytes_per_stage = (
        cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
    )
    mbar_helpers_bytes = 1024

    # Per-CTA budget: 1024 bytes per CTA are held back before splitting by occupancy
    # (presumably a reserved smem region -- confirm), then mbarrier helpers and the
    # epilogue buffers are subtracted.
    remaining_bytes = (
        (smem_capacity - occupancy * 1024) // occupancy - mbar_helpers_bytes - epi_bytes
    )
    ab_stage = remaining_bytes // ab_bytes_per_stage

    # Refine epilogue stages:
    # Calculate remaining smem after allocating for A/B stages and reserved bytes
    # Add remaining unused smem to epilogue
    if not overlap_sD_sA:
        epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // d_bytes_per_stage
    return ab_stage, epi_stage, epi_c_stage
|
|
1726
|
+
|
|
1727
|
+
@staticmethod
def _sm90_compute_tile_shape_or_override(
    tile_shape_mnk: Tuple[int, int, int],
    atom_layout_mnk: Tuple[int, int, int],
    element_type: Type[cutlass.Numeric],
    epi_tile_override: Tuple[int, int] | None = None,
) -> Tuple[int, int]:
    """Choose the epilogue tile shape ``(epi_m, epi_n)``, or return the override.

    The first matching rule wins (M, N are the CTA tile extents):

    * ``epi_tile_override`` is not None: it is returned unchanged.
    * M divisible by 128 and ``atom_layout_mnk[0] > 1``: ``(gcd(128, M), gcd(32, N))``.
    * M divisible by 192 and ``atom_layout_mnk[0] > 1``: ``(gcd(192, M), gcd(32, N))``.
    * otherwise: ``(gcd(64, M), gcd(n, N))`` with ``n = 64`` for 8-bit element types
      and ``n = 32`` for everything else.

    :param tile_shape_mnk: CTA tile shape (M, N, K).
    :type tile_shape_mnk: Tuple[int, int, int]
    :param atom_layout_mnk: Atom layout over (M, N, K); only the M entry is read
        (assumed to be the number of MMA warp groups along M -- confirm at call site).
    :type atom_layout_mnk: Tuple[int, int, int]
    :param element_type: Element type; only its ``width`` is read (fallback rule only).
    :type element_type: type[cutlass.Numeric]
    :param epi_tile_override: Optional explicit epilogue tile shape.
    :type epi_tile_override: Tuple[int, int] or None

    :return: Epilogue tile shape.
    :rtype: Tuple[int, int]
    """
    if epi_tile_override is not None:
        return epi_tile_override
    for m_multiple in (128, 192):
        if tile_shape_mnk[0] % m_multiple == 0 and atom_layout_mnk[0] > 1:
            return (
                math.gcd(m_multiple, cute.size(tile_shape_mnk, mode=[0])),
                math.gcd(32, cute.size(tile_shape_mnk, mode=[1])),
            )
    # With a tile of 128 x N but atom layout 1 x 2, epi_tile_m must be 64. With
    # epi_tile_m = 128 the epilogue would walk M first and then N, whereas the
    # register accumulator walks N first and then M. Using 64 avoids having to
    # adapt the epilogue to that ordering.
    n_preferred = 64 if element_type.width == 8 else 32
    return (
        math.gcd(64, cute.size(tile_shape_mnk, mode=[0])),
        math.gcd(n_preferred, cute.size(tile_shape_mnk, mode=[1])),
    )
|
|
1767
|
+
|
|
1768
|
+
@staticmethod
def _make_smem_layouts(
    tile_shape_mnk: Tuple[int, int, int],
    epi_tile: Tuple[int, int],
    a_dtype: Type[cutlass.Numeric],
    a_layout: cutlass.utils.LayoutEnum,
    b_dtype: Type[cutlass.Numeric],
    b_layout: cutlass.utils.LayoutEnum,
    ab_stage: int,
    d_dtype: Type[cutlass.Numeric],
    d_layout: cutlass.utils.LayoutEnum,
    epi_stage: int,
    c_dtype: Optional[Type[cutlass.Numeric]],
    c_layout: Optional[cutlass.utils.LayoutEnum],
    epi_c_stage: int,
) -> Tuple[
    cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
]:
    """Create the staged shared memory layouts for A, B, the output D and optional C.

    :param tile_shape_mnk: CTA tile shape (M,N,K)
    :type tile_shape_mnk: Tuple[int, int, int]
    :param epi_tile: Epilogue tile shape, used as the per-stage shape of D and C
    :type epi_tile: Tuple[int, int]
    :param a_dtype: Data type for matrix A
    :type a_dtype: type[cutlass.Numeric]
    :param a_layout: Layout enum for matrix A
    :type a_layout: cutlass.utils.LayoutEnum
    :param b_dtype: Data type for matrix B
    :type b_dtype: type[cutlass.Numeric]
    :param b_layout: Layout enum for matrix B
    :type b_layout: cutlass.utils.LayoutEnum
    :param ab_stage: Number of stages for A/B tensors
    :type ab_stage: int
    :param d_dtype: Data type for the output matrix D
    :type d_dtype: type[cutlass.Numeric]
    :param d_layout: Layout enum for the output matrix D
    :type d_layout: cutlass.utils.LayoutEnum
    :param epi_stage: Number of epilogue stages for D
    :type epi_stage: int
    :param c_dtype: Data type for the optional epilogue input C, or None
    :type c_dtype: Optional[type[cutlass.Numeric]]
    :param c_layout: Layout enum for C; must be given when ``c_dtype`` is given
    :type c_layout: Optional[cutlass.utils.LayoutEnum]
    :param epi_c_stage: Number of epilogue stages for C
    :type epi_c_stage: int

    :return: ``(a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged,
        epi_c_smem_layout_staged)``; the last entry is None when ``c_dtype`` is None.
    :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout,
        Optional[cute.ComposedLayout]]
    """

    def make_epi_smem_layout_staged(layout, dtype, num_stages):
        # Shared by D and C: (epi_tile, num_stages) smem layout for an epilogue operand.
        # The major-mode extent is epi_tile[1] (N) for N-major layouts, else epi_tile[0].
        major_mode_size = epi_tile[1] if layout.is_n_major_c() else epi_tile[0]
        smem_layout_atom = warpgroup.make_smem_layout_atom(
            sm90_utils.get_smem_layout_atom(layout, dtype, major_mode_size),
            dtype,
        )
        return cute.tile_to_shape(
            smem_layout_atom,
            cute.append(epi_tile, num_stages),
            order=(1, 0, 2) if layout.is_m_major_c() else (0, 1, 2),
        )

    a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None))

    a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
    b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
    a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
    a_smem_layout_atom = warpgroup.make_smem_layout_atom(
        sm90_utils.get_smem_layout_atom(
            a_layout,
            a_dtype,
            a_major_mode_size,
        ),
        a_dtype,
    )
    a_smem_layout_staged = cute.tile_to_shape(
        a_smem_layout_atom,
        cute.append(a_smem_shape, ab_stage),
        order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
    )

    b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None))

    b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
    b_smem_layout_atom = warpgroup.make_smem_layout_atom(
        sm90_utils.get_smem_layout_atom(
            b_layout,
            b_dtype,
            b_major_mode_size,
        ),
        b_dtype,
    )
    b_smem_layout_staged = cute.tile_to_shape(
        b_smem_layout_atom,
        cute.append(b_smem_shape, ab_stage),
        order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
    )

    epi_smem_layout_staged = make_epi_smem_layout_staged(d_layout, d_dtype, epi_stage)

    if c_dtype is not None:
        assert c_layout is not None
        epi_c_smem_layout_staged = make_epi_smem_layout_staged(c_layout, c_dtype, epi_c_stage)
    else:
        epi_c_smem_layout_staged = None

    return (
        a_smem_layout_staged,
        b_smem_layout_staged,
        epi_smem_layout_staged,
        epi_c_smem_layout_staged,
    )
|
|
1882
|
+
|
|
1883
|
+
@staticmethod
def _make_tma_epi_atoms_and_tensors(
    tensor_d: cute.Tensor,
    epi_smem_layout_staged: cute.ComposedLayout,
    epi_tile: Tuple[int, int],
    store_or_load: str,
) -> Tuple[cute.CopyAtom, cute.Tensor]:
    """Build the TMA copy atom and TMA tensor for an epilogue operand (D store or C load).

    The atom is built for one epilogue stage (stage 0 of the staged smem layout)
    and tiles ``tensor_d`` by ``epi_tile``.

    :param tensor_d: Global tensor being stored to (D) or loaded from (C)
    :type tensor_d: cute.Tensor
    :param epi_smem_layout_staged: Staged shared memory layout of the epilogue buffer
    :type epi_smem_layout_staged: cute.ComposedLayout
    :param epi_tile: Epilogue tile shape
    :type epi_tile: Tuple[int, int]
    :param store_or_load: ``"store"`` for an smem-to-gmem TMA op, ``"load"`` for gmem-to-smem
    :type store_or_load: str

    :return: TMA copy atom and the matching TMA tensor
    :rtype: Tuple[cute.CopyAtom, cute.Tensor]
    """
    assert store_or_load in ["load", "store"]
    single_stage_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
    tile_coord_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
    if store_or_load == "load":
        bulk_copy_op = cpasync.CopyBulkTensorTileG2SOp()
    else:
        bulk_copy_op = cpasync.CopyBulkTensorTileS2GOp()
    copy_atom, gmem_tma_tensor = cpasync.make_tiled_tma_atom(
        bulk_copy_op, tensor_d, single_stage_layout, tile_coord_layout
    )
    return copy_atom, gmem_tma_tensor
|
|
1914
|
+
|
|
1915
|
+
@staticmethod
def _make_tma_atoms_and_tensors(
    tensor: cute.Tensor,
    smem_layout_staged: cute.ComposedLayout,
    smem_tile: Tuple[int, int],
    mcast_dim: int,
) -> Tuple[cute.CopyAtom, cute.Tensor]:
    """Build the TMA load atom and TMA tensor for a mainloop operand (A or B).

    A multicast TMA load is used whenever ``mcast_dim`` differs from 1; the atom
    describes a single pipeline stage (stage 0 of ``smem_layout_staged``).

    :param tensor: Global input tensor (A or B)
    :type tensor: cute.Tensor
    :param smem_layout_staged: Staged shared memory layout for the tensor
    :type smem_layout_staged: cute.ComposedLayout
    :param smem_tile: Shared memory tile shape
    :type smem_tile: Tuple[int, int]
    :param mcast_dim: Number of CTAs the load is multicast to (1 = no multicast)
    :type mcast_dim: int

    :return: TMA copy atom and the matching TMA tensor
    :rtype: Tuple[cute.CopyAtom, cute.Tensor]
    """
    if mcast_dim == 1:
        g2s_op = cpasync.CopyBulkTensorTileG2SOp()
    else:
        g2s_op = cpasync.CopyBulkTensorTileG2SMulticastOp()

    single_stage_layout = cute.slice_(smem_layout_staged, (None, None, 0))
    copy_atom, gmem_tma_tensor = cpasync.make_tiled_tma_atom(
        g2s_op,
        tensor,
        single_stage_layout,
        smem_tile,
        num_multicast=mcast_dim,
    )
    return copy_atom, gmem_tma_tensor
|
|
1951
|
+
|
|
1952
|
+
def _make_gmem_tiled_copy_A(self, dtype, major_mode, num_threads, copy_bits=128):
    """Build the cp.async (global -> shared) tiled copy used for operand A.

    Each thread copies ``copy_bits // dtype.width`` contiguous elements per copy.
    For a row-major A the vector runs along K (tile K / vector threads per row).
    Otherwise it runs along M (tile M / vector threads per column).

    Only the layout for the requested major mode is built. The row-major thread
    layout used to be built unconditionally and then discarded for column-major A,
    which could raise ZeroDivisionError when the K tile was smaller than one vector.

    :param dtype: Element type of A.
    :param major_mode: ``cutlass.utils.LayoutEnum`` of A.
    :param num_threads: Number of threads participating in the copy.
    :param copy_bits: Bits moved per copy instruction (default 128).
    :return: The tiled copy built by ``cute.make_tiled_copy_tv``.
    """
    atom_async_copy = cute.make_copy_atom(
        cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
        dtype,
        num_bits_per_copy=copy_bits,
    )
    copy_elems = copy_bits // dtype.width
    if major_mode == cutlass.utils.LayoutEnum.ROW_MAJOR:
        # Threads per row: how many vectors of copy_elems span the K tile.
        shape_dim_1 = cute.size(self.tile_shape_mnk[2]) // copy_elems
        thread_layout = cute.make_layout(
            (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
        )
        value_layout = cute.make_layout((1, copy_elems))
    else:
        # Threads per column: how many vectors of copy_elems span the M tile.
        shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
        thread_layout = cute.make_layout(
            (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
        )
        value_layout = cute.make_layout((copy_elems, 1))
    return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
|
|
1976
|
+
|
|
1977
|
+
@staticmethod
def is_valid_dtypes(
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    d_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
) -> bool:
    """
    Check if the dtypes (and majors) are supported by this kernel.

    Rules:
    * A and B must be Float16, BFloat16, Float8E4M3FN or Float8E5M2.
    * The accumulator must be Float32 or Float16; bf16 operands additionally
      require a Float32 accumulator (WGMMA has no f16 accumulator for bf16).
    * D must be Float32, Float16, BFloat16, Float8E4M3FN or Float8E5M2.
    * 16-bit A and B must be the same type; A and B must have the same width.
    * 8-bit operands must be K-major (``"k"``).

    :param a_dtype: The data type of tensor A
    :type a_dtype: Type[cutlass.Numeric]
    :param b_dtype: The data type of tensor B
    :type b_dtype: Type[cutlass.Numeric]
    :param acc_dtype: The data type of the accumulator
    :type acc_dtype: Type[cutlass.Numeric]
    :param d_dtype: The data type of the output tensor
    :type d_dtype: Type[cutlass.Numeric]
    :param a_major: major mode of tensor A
    :type a_major: str
    :param b_major: major mode of tensor B
    :type b_major: str

    :return: True if the dtypes are valid, False otherwise
    :rtype: bool
    """
    ab_dtypes = {
        cutlass.Float16,
        cutlass.BFloat16,
        cutlass.Float8E4M3FN,
        cutlass.Float8E5M2,
    }
    if a_dtype not in ab_dtypes or b_dtype not in ab_dtypes:
        return False
    if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
        return False
    # WGMMA only supports an f32 accumulator for bf16 operands.
    if cutlass.BFloat16 in (a_dtype, b_dtype) and acc_dtype != cutlass.Float32:
        return False
    if d_dtype not in {
        cutlass.Float32,
        cutlass.Float16,
        cutlass.BFloat16,
        cutlass.Float8E4M3FN,
        cutlass.Float8E5M2,
    }:
        return False
    # 16-bit operands must match exactly (no Float16 x BFloat16 mixing).
    if a_dtype.width == 16 and a_dtype != b_dtype:
        return False
    # Operand widths must agree (e.g. Float8E4M3FN x Float8E5M2 is allowed).
    if a_dtype.width != b_dtype.width:
        return False
    # for Float8 types, this implementation only supports k-major layout
    if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
        return False
    return True
|
|
2044
|
+
|
|
2045
|
+
|
|
2046
|
+
def run(
|
|
2047
|
+
mnkl: Tuple[int, int, int, int],
|
|
2048
|
+
a_dtype: Type[cutlass.Numeric],
|
|
2049
|
+
b_dtype: Type[cutlass.Numeric],
|
|
2050
|
+
d_dtype: Type[cutlass.Numeric],
|
|
2051
|
+
c_dtype: Optional[Type[cutlass.Numeric]],
|
|
2052
|
+
acc_dtype: Type[cutlass.Numeric],
|
|
2053
|
+
a_major: str,
|
|
2054
|
+
b_major: str,
|
|
2055
|
+
d_major: str,
|
|
2056
|
+
c_major: str,
|
|
2057
|
+
tile_shape_mnk: Tuple[int, int, int],
|
|
2058
|
+
cluster_shape_mn: Tuple[int, int],
|
|
2059
|
+
tolerance: float,
|
|
2060
|
+
warmup_iterations: int,
|
|
2061
|
+
iterations: int,
|
|
2062
|
+
skip_ref_check: bool,
|
|
2063
|
+
persistent: bool,
|
|
2064
|
+
dynamic_persistent: bool,
|
|
2065
|
+
pingpong: bool,
|
|
2066
|
+
varlen_m: bool,
|
|
2067
|
+
gather_A: bool,
|
|
2068
|
+
fp8_fast_accum: bool,
|
|
2069
|
+
**kwargs,
|
|
2070
|
+
):
|
|
2071
|
+
"""
|
|
2072
|
+
Prepare A/B/D/C tensors, launch GPU kernel, and reference checking.
|
|
2073
|
+
|
|
2074
|
+
:param mnkl: Problem size (M, N, K, L)
|
|
2075
|
+
:type mnkl: Tuple[int, int, int, int]
|
|
2076
|
+
:param a_dtype: Data type for input tensor A
|
|
2077
|
+
:type a_dtype: Type[cutlass.Numeric]
|
|
2078
|
+
:param b_dtype: Data type for input tensor B
|
|
2079
|
+
:type b_dtype: Type[cutlass.Numeric]
|
|
2080
|
+
:param d_dtype: Data type for output tensor C
|
|
2081
|
+
:type d_dtype: Type[cutlass.Numeric]
|
|
2082
|
+
:param acc_dtype: Data type for accumulation during matrix multiplication
|
|
2083
|
+
:type acc_dtype: Type[cutlass.Numeric]
|
|
2084
|
+
:param a_major/b_major/d_major: Memory layout of tensor A/B/C
|
|
2085
|
+
:type a_major/b_major/d_major: str
|
|
2086
|
+
:param tile_shape_mnk: CTA tile shape (M, N, K)
|
|
2087
|
+
:type tile_shape_mnk: Tuple[int, int, int]
|
|
2088
|
+
:param cluster_shape_mn: Cluster shape (M, N)
|
|
2089
|
+
:type cluster_shape_mn: Tuple[int, int]
|
|
2090
|
+
:param tolerance: Tolerance value for reference validation comparison
|
|
2091
|
+
:type tolerance: float
|
|
2092
|
+
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
|
|
2093
|
+
:type warmup_iterations: int, optional
|
|
2094
|
+
:param iterations: Number of benchmark iterations to run, defaults to 1
|
|
2095
|
+
:type iterations: int, optional
|
|
2096
|
+
:param skip_ref_check: Whether to skip reference result validation, defaults to False
|
|
2097
|
+
:type skip_ref_check: bool, optional
|
|
2098
|
+
"""
|
|
2099
|
+
|
|
2100
|
+
if dynamic_persistent:
|
|
2101
|
+
persistent = True
|
|
2102
|
+
|
|
2103
|
+
print("Running Hopper Dense GEMM with:")
|
|
2104
|
+
print(f"mnkl: {mnkl}")
|
|
2105
|
+
print(
|
|
2106
|
+
f"A dtype: {a_dtype}, B dtype: {b_dtype}, D dtype: {d_dtype}, C_dtype: {c_dtype}, Acc dtype: {acc_dtype}"
|
|
2107
|
+
)
|
|
2108
|
+
print(f"Matrix majors - A: {a_major}, B: {b_major}, D: {d_major}")
|
|
2109
|
+
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
|
|
2110
|
+
print(f"Tolerance: {tolerance}")
|
|
2111
|
+
print(f"Warmup iterations: {warmup_iterations}")
|
|
2112
|
+
print(f"Iterations: {iterations}")
|
|
2113
|
+
print(f"Skip reference checking: {skip_ref_check}")
|
|
2114
|
+
|
|
2115
|
+
# Unpack parameters
|
|
2116
|
+
m, n, k, l = mnkl
|
|
2117
|
+
cluster_shape_mnk = (*cluster_shape_mn, 1)
|
|
2118
|
+
|
|
2119
|
+
# Skip unsupported types
|
|
2120
|
+
if not HopperWgmmaGemmKernel.is_valid_dtypes(
|
|
2121
|
+
a_dtype, b_dtype, acc_dtype, d_dtype, a_major, b_major
|
|
2122
|
+
):
|
|
2123
|
+
raise TypeError(
|
|
2124
|
+
f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {d_dtype}, {a_major=}, {b_major=}"
|
|
2125
|
+
)
|
|
2126
|
+
|
|
2127
|
+
# Prepare pytorch tensors: A, B (random from 0 to 2) and C (all zero)
|
|
2128
|
+
if not torch.cuda.is_available():
|
|
2129
|
+
raise RuntimeError("GPU is required to run this example!")
|
|
2130
|
+
|
|
2131
|
+
torch.manual_seed(1111)
|
|
2132
|
+
|
|
2133
|
+
# Create and permute tensor A/B/C
|
|
2134
|
+
def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
|
|
2135
|
+
# is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
|
|
2136
|
+
# else : (l, mode0, mode1) -> (mode0, mode1, l)
|
|
2137
|
+
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
|
|
2138
|
+
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
|
|
2139
|
+
is_unsigned = dtype in {cutlass.Uint8}
|
|
2140
|
+
# Temporarily use uint8 as torch does not support fp8 type
|
|
2141
|
+
torch_dtype = cutlass_torch.dtype(dtype)
|
|
2142
|
+
gen_dtype = (
|
|
2143
|
+
torch_dtype
|
|
2144
|
+
if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
|
|
2145
|
+
else torch.bfloat16
|
|
2146
|
+
)
|
|
2147
|
+
|
|
2148
|
+
# Create dtype torch tensor (cpu)
|
|
2149
|
+
torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
|
|
2150
|
+
shape,
|
|
2151
|
+
gen_dtype,
|
|
2152
|
+
permute_order=permute_order,
|
|
2153
|
+
# init_type=cutlass.torch.TensorInitType.RANDOM,
|
|
2154
|
+
# init_config=cutlass.torch.RandomInitConfig(
|
|
2155
|
+
# min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
|
|
2156
|
+
# ),
|
|
2157
|
+
init_type=cutlass.torch.TensorInitType.GAUSSIAN,
|
|
2158
|
+
init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
|
|
2159
|
+
).to(torch_dtype)
|
|
2160
|
+
# Create dtype torch tensor (gpu)
|
|
2161
|
+
torch_tensor = torch_tensor_cpu.cuda()
|
|
2162
|
+
|
|
2163
|
+
# Create f32 torch tensor (cpu)
|
|
2164
|
+
f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
|
|
2165
|
+
|
|
2166
|
+
# Create dtype cute tensor (gpu)
|
|
2167
|
+
torch_tensor_view = (
|
|
2168
|
+
torch_tensor
|
|
2169
|
+
if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
|
|
2170
|
+
else torch_tensor.view(torch.uint8)
|
|
2171
|
+
)
|
|
2172
|
+
cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
|
|
2173
|
+
cute_tensor.element_type = dtype
|
|
2174
|
+
if is_dynamic_layout:
|
|
2175
|
+
cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
|
|
2176
|
+
cute_tensor = cute_tensor.mark_compact_shape_dynamic(
|
|
2177
|
+
mode=(1 if not is_mode0_major else 0),
|
|
2178
|
+
stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
|
|
2179
|
+
divisibility=(128 // dtype.width),
|
|
2180
|
+
)
|
|
2181
|
+
cute_tensor = cutlass.torch.convert_cute_tensor(
|
|
2182
|
+
f32_torch_tensor,
|
|
2183
|
+
cute_tensor,
|
|
2184
|
+
dtype,
|
|
2185
|
+
is_dynamic_layout=is_dynamic_layout,
|
|
2186
|
+
)
|
|
2187
|
+
|
|
2188
|
+
return f32_torch_tensor, cute_tensor, torch_tensor
|
|
2189
|
+
|
|
2190
|
+
a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
|
|
2191
|
+
if gather_A:
|
|
2192
|
+
assert a_major == "k"
|
|
2193
|
+
a_idx = torch.randperm(l * m, dtype=torch.int32, device="cuda")
|
|
2194
|
+
from einops import rearrange
|
|
2195
|
+
|
|
2196
|
+
a = rearrange(rearrange(a, "m k l -> (m l) k")[a_idx.cpu()], "(m l) k -> m k l", m=m)
|
|
2197
|
+
a_torch = rearrange(a_torch, "m k l -> (m l) k")
|
|
2198
|
+
mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
2199
|
+
a_idx_reshaped = rearrange(a_idx, "(m l) -> l m", m=m).contiguous().transpose(0, 1)
|
|
2200
|
+
mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
2201
|
+
else:
|
|
2202
|
+
mAIdx = None
|
|
2203
|
+
b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
|
|
2204
|
+
_, mD, d_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
|
|
2205
|
+
if c_dtype is not None:
|
|
2206
|
+
c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
|
|
2207
|
+
else:
|
|
2208
|
+
c, mC, c_torch = None, None, None
|
|
2209
|
+
if varlen_m:
|
|
2210
|
+
assert a_major == "k"
|
|
2211
|
+
assert d_major == "n"
|
|
2212
|
+
from einops import rearrange
|
|
2213
|
+
|
|
2214
|
+
a, d_torch = [rearrange(t, "m x l -> (l m) x") for t in (a, d_torch)]
|
|
2215
|
+
if not gather_A:
|
|
2216
|
+
(a_torch,) = [rearrange(t, "m x l -> (l m) x") for t in (a_torch,)]
|
|
2217
|
+
if c_dtype is not None:
|
|
2218
|
+
c, c_torch = [rearrange(t, "m x l -> (l m) x") for t in (c, c_torch)]
|
|
2219
|
+
mC = from_dlpack(c_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
2220
|
+
mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
2221
|
+
mD = from_dlpack(d_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
2222
|
+
# TODO: generate random cu_seqlens_m
|
|
2223
|
+
cu_seqlens_m = torch.arange(0, l + 1, dtype=torch.int32, device="cuda") * m
|
|
2224
|
+
mCuSeqlensM = from_dlpack(cu_seqlens_m, assumed_align=64).mark_layout_dynamic(leading_dim=0)
|
|
2225
|
+
if gather_A:
|
|
2226
|
+
a_idx_reshaped = rearrange(a_idx_reshaped, "m l -> (l m)")
|
|
2227
|
+
mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
2228
|
+
else:
|
|
2229
|
+
cu_seqlens_m, mCuSeqlensM = None, None
|
|
2230
|
+
|
|
2231
|
+
if varlen_m: # Need to allocate space in gmem to store tensormaps
|
|
2232
|
+
if not persistent:
|
|
2233
|
+
total_m = m * l
|
|
2234
|
+
block_size_m = tile_shape_mnk[0] * cluster_shape_mnk[0]
|
|
2235
|
+
block_size_n = tile_shape_mnk[1] * cluster_shape_mnk[1]
|
|
2236
|
+
total_clusters_m_max = (total_m + l * (block_size_m - 1)) // block_size_m
|
|
2237
|
+
total_clusters_max = total_clusters_m_max * ((n + block_size_n - 1) // block_size_n)
|
|
2238
|
+
total_ctas = total_clusters_max * cluster_shape_mnk[0] * cluster_shape_mnk[1]
|
|
2239
|
+
else:
|
|
2240
|
+
total_ctas = cutlass.utils.HardwareInfo().get_device_multiprocessor_count()
|
|
2241
|
+
if pingpong:
|
|
2242
|
+
total_ctas *= 2
|
|
2243
|
+
# 128 bytes per tensormap
|
|
2244
|
+
tensormaps_torch = torch.empty(total_ctas, 128 // 8, dtype=torch.int64, device="cuda")
|
|
2245
|
+
tensormaps_tensor = from_dlpack(
|
|
2246
|
+
tensormaps_torch, assumed_align=128
|
|
2247
|
+
).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
|
|
2248
|
+
else:
|
|
2249
|
+
tensormaps_tensor = None
|
|
2250
|
+
|
|
2251
|
+
gemm = HopperWgmmaGemmKernel(
|
|
2252
|
+
acc_dtype,
|
|
2253
|
+
a_dtype,
|
|
2254
|
+
tile_shape_mnk,
|
|
2255
|
+
cluster_shape_mnk,
|
|
2256
|
+
pingpong=pingpong,
|
|
2257
|
+
is_persistent=persistent,
|
|
2258
|
+
fp8_fast_accum=fp8_fast_accum,
|
|
2259
|
+
gather_A=gather_A,
|
|
2260
|
+
)
|
|
2261
|
+
|
|
2262
|
+
# Compute max active clusters on current device
|
|
2263
|
+
if persistent:
|
|
2264
|
+
max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
|
|
2265
|
+
cluster_shape_mn[0] * cluster_shape_mn[1]
|
|
2266
|
+
)
|
|
2267
|
+
if dynamic_persistent:
|
|
2268
|
+
tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda")
|
|
2269
|
+
else:
|
|
2270
|
+
tile_count_semaphore = None
|
|
2271
|
+
# max_active_clusters = 1
|
|
2272
|
+
else:
|
|
2273
|
+
max_active_clusters = 0
|
|
2274
|
+
tile_count_semaphore = None
|
|
2275
|
+
|
|
2276
|
+
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
2277
|
+
# compile gemm kernel
|
|
2278
|
+
compiled_gemm = cute.compile(
|
|
2279
|
+
gemm,
|
|
2280
|
+
mA,
|
|
2281
|
+
mB,
|
|
2282
|
+
mD,
|
|
2283
|
+
mC,
|
|
2284
|
+
mAIdx,
|
|
2285
|
+
mCuSeqlensM,
|
|
2286
|
+
tensormaps_tensor,
|
|
2287
|
+
make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
|
|
2288
|
+
if tile_count_semaphore is not None
|
|
2289
|
+
else None,
|
|
2290
|
+
max_active_clusters,
|
|
2291
|
+
current_stream,
|
|
2292
|
+
)
|
|
2293
|
+
|
|
2294
|
+
if not skip_ref_check:
|
|
2295
|
+
# execution
|
|
2296
|
+
compiled_gemm(
|
|
2297
|
+
mA,
|
|
2298
|
+
mB,
|
|
2299
|
+
mD,
|
|
2300
|
+
mC,
|
|
2301
|
+
mAIdx,
|
|
2302
|
+
mCuSeqlensM,
|
|
2303
|
+
tensormaps_tensor,
|
|
2304
|
+
tile_count_semaphore,
|
|
2305
|
+
max_active_clusters,
|
|
2306
|
+
current_stream,
|
|
2307
|
+
)
|
|
2308
|
+
if tile_count_semaphore is not None and varlen_m:
|
|
2309
|
+
tile_count_semaphore.zero_()
|
|
2310
|
+
|
|
2311
|
+
torch.cuda.synchronize()
|
|
2312
|
+
|
|
2313
|
+
# Ref check
|
|
2314
|
+
if not varlen_m:
|
|
2315
|
+
ref = torch.einsum("mkl,nkl->mnl", a, b)
|
|
2316
|
+
else:
|
|
2317
|
+
ref = torch.cat(
|
|
2318
|
+
[
|
|
2319
|
+
torch.einsum("mk,nk->mn", a[cu_seqlens_m[i] : cu_seqlens_m[i + 1]], b[:, :, i])
|
|
2320
|
+
for i in range(l)
|
|
2321
|
+
],
|
|
2322
|
+
dim=0,
|
|
2323
|
+
)
|
|
2324
|
+
if c is not None:
|
|
2325
|
+
ref = ref + c
|
|
2326
|
+
ref = ref.cpu()
|
|
2327
|
+
|
|
2328
|
+
if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
|
|
2329
|
+
# m major: (l, n, m) -> (m, n, l)
|
|
2330
|
+
# n major: (l, m, n) -> (m, n, l)
|
|
2331
|
+
permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
|
|
2332
|
+
shape = (l, m, n) if d_major == "n" else (l, n, m)
|
|
2333
|
+
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
|
|
2334
|
+
shape,
|
|
2335
|
+
torch.uint8,
|
|
2336
|
+
permute_order=permute_order,
|
|
2337
|
+
init_type=cutlass_torch.TensorInitType.SKIP,
|
|
2338
|
+
).cuda()
|
|
2339
|
+
# Create dtype cute tensor (gpu)
|
|
2340
|
+
ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
|
|
2341
|
+
leading_dim=(1 if d_major == "n" else 0)
|
|
2342
|
+
)
|
|
2343
|
+
ref_d_tensor.element_type = d_dtype
|
|
2344
|
+
ref_d_tensor = cutlass_torch.convert_cute_tensor(
|
|
2345
|
+
ref,
|
|
2346
|
+
ref_d_tensor,
|
|
2347
|
+
d_dtype,
|
|
2348
|
+
is_dynamic_layout=True,
|
|
2349
|
+
)
|
|
2350
|
+
ref_d = f8_torch_tensor.cpu()
|
|
2351
|
+
else:
|
|
2352
|
+
ref_d = ref.to(cutlass_torch.dtype(d_dtype))
|
|
2353
|
+
|
|
2354
|
+
out = d_torch.cpu().squeeze()
|
|
2355
|
+
out_ref = ref_d.squeeze()
|
|
2356
|
+
# breakpoint()
|
|
2357
|
+
torch.testing.assert_close(d_torch.cpu(), ref_d, atol=tolerance, rtol=1e-03)
|
|
2358
|
+
|
|
2359
|
+
# return
|
|
2360
|
+
|
|
2361
|
+
from triton.testing import do_bench
|
|
2362
|
+
|
|
2363
|
+
flops = 2 * m * n * k * l
|
|
2364
|
+
# Calculate memory bandwidth
|
|
2365
|
+
bytes_A = m * k * l * (a_dtype.width // 8) # A tensor: (m, k, l)
|
|
2366
|
+
bytes_B = n * k * l * (b_dtype.width // 8) # B tensor: (n, k, l)
|
|
2367
|
+
bytes_D = m * n * l * (d_dtype.width // 8) # D tensor: (m, n, l)
|
|
2368
|
+
bytes_C = m * n * l * (c_dtype.width // 8) if c_dtype is not None else 0 # C tensor: (m, n, l)
|
|
2369
|
+
total_bytes = bytes_A + bytes_B + bytes_D + bytes_C # Read A, B, C; Write D
|
|
2370
|
+
|
|
2371
|
+
repeats = iterations
|
|
2372
|
+
warmup = warmup_iterations
|
|
2373
|
+
|
|
2374
|
+
import time
|
|
2375
|
+
|
|
2376
|
+
if not varlen_m and not gather_A:
|
|
2377
|
+
time.sleep(0.5)
|
|
2378
|
+
if a_dtype.width == 8:
|
|
2379
|
+
assert l == 1
|
|
2380
|
+
scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda")
|
|
2381
|
+
fn_cublas = lambda: torch._scaled_mm(
|
|
2382
|
+
a_torch[:, :, 0],
|
|
2383
|
+
b_torch[:, :, 0].mT,
|
|
2384
|
+
scale_a=scale_ab,
|
|
2385
|
+
scale_b=scale_ab,
|
|
2386
|
+
out_dtype=torch.bfloat16,
|
|
2387
|
+
use_fast_accum=fp8_fast_accum,
|
|
2388
|
+
)
|
|
2389
|
+
else:
|
|
2390
|
+
if c_torch is None:
|
|
2391
|
+
fn_cublas = lambda: torch.matmul(
|
|
2392
|
+
a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT
|
|
2393
|
+
)
|
|
2394
|
+
else:
|
|
2395
|
+
c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32
|
|
2396
|
+
fn_cublas = lambda: torch.baddbmm(
|
|
2397
|
+
c_torch_convert.permute(2, 0, 1),
|
|
2398
|
+
a_torch.permute(2, 0, 1),
|
|
2399
|
+
b_torch.permute(2, 0, 1).mT,
|
|
2400
|
+
)
|
|
2401
|
+
timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
|
|
2402
|
+
tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
|
|
2403
|
+
print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
|
|
2404
|
+
|
|
2405
|
+
time.sleep(0.5)
|
|
2406
|
+
|
|
2407
|
+
def fn():
|
|
2408
|
+
compiled_gemm(
|
|
2409
|
+
mA,
|
|
2410
|
+
mB,
|
|
2411
|
+
mD,
|
|
2412
|
+
mC,
|
|
2413
|
+
mAIdx,
|
|
2414
|
+
mCuSeqlensM,
|
|
2415
|
+
tensormaps_tensor,
|
|
2416
|
+
tile_count_semaphore,
|
|
2417
|
+
max_active_clusters,
|
|
2418
|
+
current_stream,
|
|
2419
|
+
)
|
|
2420
|
+
if tile_count_semaphore is not None and varlen_m:
|
|
2421
|
+
tile_count_semaphore.zero_()
|
|
2422
|
+
|
|
2423
|
+
timing = do_bench(fn, warmup=warmup, rep=repeats)
|
|
2424
|
+
# Idk why but for some cases the 1st run is much slower
|
|
2425
|
+
time.sleep(0.5)
|
|
2426
|
+
timing = do_bench(fn, warmup=warmup, rep=repeats)
|
|
2427
|
+
tflops = flops / (timing * 1e9) # Convert to TFlops
|
|
2428
|
+
gbps = total_bytes / (timing * 1e6) # Convert to GB/s (1e9 for ms->s, 1e9 for B->GB)
|
|
2429
|
+
print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}, GB/s: {gbps:.0f}")
|
|
2430
|
+
fn()
|
|
2431
|
+
|
|
2432
|
+
if not varlen_m:
|
|
2433
|
+
time.sleep(0.5)
|
|
2434
|
+
timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
|
|
2435
|
+
tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
|
|
2436
|
+
print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
|
|
2437
|
+
|
|
2438
|
+
from flash_attn.utils.benchmark import pytorch_profiler
|
|
2439
|
+
|
|
2440
|
+
pytorch_profiler(fn_cublas)
|
|
2441
|
+
# pytorch_profiler(torch.sort, d_torch.squeeze(), dim=-1)
|
|
2442
|
+
# pytorch_profiler(torch.compile(torch.sort), d_torch.squeeze(), dim=-1)
|
|
2443
|
+
# pytorch_profiler(torch.topk, d_torch.squeeze(), dim=-1, k=1)
|
|
2444
|
+
# pytorch_profiler(torch.compile(torch.topk), d_torch.squeeze(), dim=-1, k=1)
|
|
2445
|
+
# pytorch_profiler(torch.square, d_torch.squeeze())
|
|
2446
|
+
|
|
2447
|
+
|
|
2448
|
+
if __name__ == "__main__":
    # Command-line entry point: parse the CLI options and hand them to run().
    # The options are forwarded positionally, so the order of the names in the
    # tuple below must stay in sync with run()'s parameter order.
    args = parse_arguments()
    run(
        *[
            getattr(args, option)
            for option in (
                "mnkl",
                "a_dtype",
                "b_dtype",
                "d_dtype",
                "c_dtype",
                "acc_dtype",
                "a_major",
                "b_major",
                "d_major",
                "c_major",
                "tile_shape_mnk",
                "cluster_shape_mn",
                "tolerance",
                "warmup_iterations",
                "iterations",
                "skip_ref_check",
                "persistent",
                "dynamic_persistent",
                "pingpong",
                "varlen_m",
                "gather_A",
                "fp8_fast_accum",
            )
        ]
    )
    # Reached only when run() returned without raising (run() performs a
    # torch.testing.assert_close reference check unless skip_ref_check is set).
    print("PASS")
|