quack-kernels 0.1.11-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +288 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +61 -59
- quack/topk.py +14 -8
- quack/utils.py +14 -259
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/dense_gemm_sm90.py
CHANGED
@@ -1,63 +1,43 @@
-#
-#
-
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-
-# 1. Redistributions of source code must retain the above copyright notice, this
-# list of conditions and the following disclaimer.
-
-# 2. Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-
-# 3. Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import argparse
+# Based on the cute-dsl example:
+# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
+
 import enum
-from typing import Tuple, Type, Callable, Optional
+from typing import Tuple, Type, Callable, Optional, Union
+from dataclasses import dataclass
 from functools import partial
 import math
 
-
+from torch import Tensor
 
-import
+import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
 import cutlass.pipeline as pipeline
-import cutlass.torch as cutlass_torch
-from cutlass.cute.runtime import from_dlpack, make_ptr
 from cutlass.cute.nvgpu import cpasync, warp, warpgroup
 import cutlass.utils.hopper_helpers as sm90_utils
-from cutlass import Int32, const_expr
+from cutlass import Int32, Float32, Boolean, const_expr
+from cutlass.utils import LayoutEnum
+import cutlass.torch as cutlass_torch
+from cutlass.cute.runtime import make_ptr
+
 
+from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
 from quack.tile_scheduler import (
+    TileSchedulerOptions,
     TileSchedulerArguments,
     TileScheduler,
     VarlenMTileSchedulerArguments,
     VarlenMTileScheduler,
-    ParamsBase,
-    RasterOrderOption,
 )
+from quack.varlen_utils import VarlenArguments
 from quack.tensormap_manager import TensorMapManagerSm90
 
 # return PipelineStateWAdvance instead of PipelineState
 from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
 import quack.utils as utils
+from quack.cute_dsl_utils import get_max_active_clusters
+from quack.gemm_wrapper_utils import GemmWrapperBase
 
 """
 A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -82,31 +62,6 @@ Hopper WGMMA instructions operate as follows:
   - Read matrix B from SMEM
   - Perform MMA operation and store the result in Accumulator(register)
 
-To run this example:
-
-.. code-block:: bash
-
-    python examples/hopper/dense_gemm.py \
-        --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
-        --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
-        --d_dtype Float16 --acc_dtype Float32 \
-        --a_major k --b_major k --d_major n
-
-The above example command compute batched gemm with M=8192, N=8192, K=8192,
-batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape
-is (1,1). The input, mma accumulator and output data type are set as fp16, fp32
-and fp16, respectively.
-
-To collect performance with NCU profiler:
-
-.. code-block:: bash
-
-    ncu python examples/hopper/dense_gemm.py \
-        --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
-        --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
-        --d_dtype Float16 --acc_dtype Float32 \
-        --a_major k --b_major k --d_major n
-
 Constraints:
 * Supported input data types: fp16, fp8 (e4m3fn, e5m2)
 * For fp16 types, A and B must have the same data type
@@ -119,106 +74,9 @@ Constraints:
 * Cluster shape M/N must be positive and power of 2, total cluster size <= 4
 * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
   i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively.
-* OOB tiles are not allowed when TMA store is disabled
 """
 
 
-# /////////////////////////////////////////////////////////////////////////////
-# Helpers to parse args
-# /////////////////////////////////////////////////////////////////////////////
-def parse_comma_separated_ints(s: str):
-    try:
-        return tuple([int(x.strip()) for x in s.split(",")])
-    except ValueError:
-        raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")
-
-
-def parse_arguments() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.")
-
-    parser.add_argument(
-        "--mnkl",
-        type=parse_comma_separated_ints,
-        default=(4096, 4096, 4096, 1),
-        help="mnkl dimensions (comma-separated)",
-    )
-    parser.add_argument(
-        "--tile_shape_mnk",
-        type=parse_comma_separated_ints,
-        default=(128, 256, 64),
-        help="Cta tile shape (comma-separated)",
-    )
-    parser.add_argument(
-        "--cluster_shape_mn",
-        type=parse_comma_separated_ints,
-        choices=[(1, 1), (2, 1), (1, 2), (2, 2)],
-        default=(1, 1),
-        help="Cluster shape (comma-separated)",
-    )
-    parser.add_argument(
-        "--a_dtype",
-        type=cutlass.dtype,
-        default=cutlass.BFloat16,
-    )
-    parser.add_argument(
-        "--b_dtype",
-        type=cutlass.dtype,
-        default=cutlass.BFloat16,
-    )
-    parser.add_argument(
-        "--d_dtype",
-        type=cutlass.dtype,
-        default=cutlass.BFloat16,
-    )
-    parser.add_argument(
-        "--c_dtype",
-        type=cutlass.dtype,
-        default=None,
-    )
-    parser.add_argument(
-        "--acc_dtype",
-        type=cutlass.dtype,
-        default=cutlass.Float32,
-    )
-    parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
-    parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
-    parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
-    parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
-    parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation")
-    parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
-    parser.add_argument(
-        "--iterations",
-        type=int,
-        default=30,
-        help="Number of iterations to run the kernel",
-    )
-    parser.add_argument("--persistent", action="store_true", help="Persistent kernel")
-    parser.add_argument(
-        "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel"
-    )
-    parser.add_argument("--pingpong", action="store_true", help="Pingpong kernel")
-    parser.add_argument("--varlen_m", action="store_true", help="Variable length M dimension")
-    parser.add_argument("--gather_A", action="store_true", help="Gather A")
-    parser.add_argument("--fp8_fast_accum", action="store_true", help="FP8 fast accum")
-    parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
-
-    args = parser.parse_args()
-
-    if len(args.mnkl) != 4:
-        parser.error("--mnkl must contain exactly 4 values")
-    if len(args.tile_shape_mnk) != 3:
-        parser.error("--tile_shape_mnk must contain exactly 3 values")
-    if len(args.cluster_shape_mn) != 2:
-        parser.error("--cluster_shape_mn must contain exactly 2 values")
-
-    return args
-
-
-# /////////////////////////////////////////////////////////////////////////////
-# Host setup and device kernel launch
-# /////////////////////////////////////////////////////////////////////////////
-
-
 class NamedBarrierGemm(enum.IntEnum):
     Epilogue = enum.auto()  # starts from 1 as barrier 0 is reserved for sync_threads()
     # For mainloop load warps to signal that the epilogue load warp can start.
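The 16-byte alignment constraint kept in the docstring above translates into a minimum number of contiguous elements per data type. A quick plain-Python illustration (the helper name is ours, not from the file):

```python
# Minimum multiple of contiguous elements implied by a 16-byte alignment requirement.
def min_contiguous_elems(dtype_bytes: int, alignment_bytes: int = 16) -> int:
    return alignment_bytes // dtype_bytes

print(min_contiguous_elems(2))  # Float16 -> 8 elements
print(min_contiguous_elems(1))  # Float8  -> 16 elements
```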
@@ -230,15 +88,15 @@ class NamedBarrierGemm(enum.IntEnum):
     EpiWG1 = enum.auto()
 
 
-class HopperWgmmaGemmKernel:
+class GemmSm90:
     """
     This class implements batched matrix multiplication (C = A x B) with support for various data types
     and architectural features specific to Hopper GPUs.
 
     :param acc_dtype: Data type for accumulation during computation
     :type acc_dtype: type[cutlass.Numeric]
-    :param
-    :type
+    :param tile_shape_mn: Shape of the CTA tile (M,N)
+    :type tile_shape_mn: Tuple[int, int, int]
     :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
     :type cluster_shape_mnk: Tuple[int, int, int]
 
@@ -259,22 +117,31 @@ class HopperWgmmaGemmKernel:
     - Cluster shape M/N must be positive and power of 2, total cluster size <= 4
 
     Example:
-        >>> gemm = HopperWgmmaGemmKernel(
+        >>> gemm = GemmSm90(
         ...     acc_dtype=cutlass.Float32,
-        ...
+        ...     tile_shape_mn=(128, 256),
         ...     cluster_shape_mnk=(1, 1, 1)
         ... )
         >>> gemm(a_tensor, b_tensor, c_tensor, stream)
     """
 
     bytes_per_tensormap = 128
-
+
+    @dataclass
+    class EpilogueArguments(ArgumentsBase):
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+
+    @dataclass
+    class EpilogueParams(ParamsBase):
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
 
     def __init__(
         self,
         acc_dtype: Type[cutlass.Numeric],
         a_dtype: Type[cutlass.Numeric],
-
+        tile_shape_mn: Tuple[int, int],
         cluster_shape_mnk: Tuple[int, int, int],
         pingpong: bool = False,
         is_persistent: bool = True,
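The nested `EpilogueArguments` dataclass added above carries optional `alpha`/`beta` scaling factors into the epilogue. A hypothetical construction sketch, assuming scalar `Float32` values are accepted as the `Float32 | cute.Tensor` annotations suggest; the wrapper call that consumes these arguments is not shown in this hunk:

```python
# Hypothetical usage sketch, not taken from the diff: build the new epilogue
# arguments with scalar alpha/beta. Passing cutlass Float32 scalars here is an
# assumption based on the field annotations (Float32 | cute.Tensor).
from cutlass import Float32
from quack.dense_gemm_sm90 import GemmSm90

epi_args = GemmSm90.EpilogueArguments(alpha=Float32(1.0), beta=Float32(0.0))
```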
@@ -289,8 +156,8 @@ class HopperWgmmaGemmKernel:
 
         :param acc_dtype: Data type for accumulation during computation
         :type acc_dtype: type[cutlass.Numeric]
-        :param
-        :type
+        :param tile_shape_mn: Shape of the CTA tile (M,N)
+        :type tile_shape_mn: Tuple[int, int]
         :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
         :type cluster_shape_mnk: Tuple[int, int, int]
         """
@@ -304,11 +171,11 @@ class HopperWgmmaGemmKernel:
         self.gather_A = gather_A
         if gather_A:
             assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
-            self.tensormap_update_mode = cutlass.utils.TensorMapUpdateMode.SMEM
 
         self.cluster_shape_mnk = cluster_shape_mnk
-
-
+        # K dimension is deferred in _setup_attributes
+        self.tile_shape_mnk = (*tile_shape_mn, 1)
+        tile_M, tile_N = self.tile_shape_mnk[0], self.tile_shape_mnk[1]
         # check the cta tile shape
         if not self.pingpong:
             if tile_M not in [64, 128, 192, 256, 320]:
@@ -332,8 +199,6 @@ class HopperWgmmaGemmKernel:
             tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
             if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
                 raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
-        if not self.tile_shape_mnk[2] % 16 == 0:
-            raise ValueError("CTA tile shape K must be divisible by 16")
 
         if not self.pingpong:
             if tile_M == 320:  # tile_M / 64 is not even so we have to split along N
@@ -344,7 +209,7 @@ class HopperWgmmaGemmKernel:
             else:
                 atom_layout_m, atom_layout_n = 1, 2
         else:
-            atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
+            atom_layout_m = self.tile_shape_mnk[0] // 64 if self.tile_shape_mnk[0] < 256 else 2
             atom_layout_n = 1
         assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
         else:
@@ -403,7 +268,7 @@ class HopperWgmmaGemmKernel:
         self.shared_storage = None
         self.buffer_align_bytes = 1024
 
-    def _setup_attributes(self):
+    def _setup_attributes(self, epilogue_args: Optional[EpilogueArguments]):
         """Set up configurations that are dependent on GEMM inputs
 
        This method configures various attributes based on the input tensor properties
@@ -417,6 +282,38 @@ class HopperWgmmaGemmKernel:
         - Computing A/B/C shared memory layout
         """
 
+        self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
+            self.a_dtype,
+            self.b_dtype,
+            self.a_layout.sm90_mma_major_mode(),
+            self.b_layout.sm90_mma_major_mode(),
+            self.acc_dtype,
+            self.atom_layout_mnk,
+            tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
+        )
+        if const_expr(self.atom_layout_mnk[1] > 1):
+            # If N dimension is split among 2 WGs, we need to permute the N dimension so
+            # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
+            # containing accumulators that are next to each other in the N dimension.
+            # Without permutation WG0 would write to epi smem of size (64, 16) and
+            # WG1 would write to a separate epi smem of size (64, 16) that's far away.
+            atom_n = self.atom_layout_mnk[1]
+            permutation_n = cute.make_ordered_layout(
+                (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
+            )
+            self.tiled_mma = cute.make_tiled_mma(
+                cute.make_mma_atom(self.tiled_mma.op),
+                self.atom_layout_mnk,
+                permutation_mnk=(None, permutation_n, None),
+            )
+        mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
+        mma_inst_tile_k = 4
+        self.tile_shape_mnk = (
+            self.tile_shape_mnk[0],
+            self.tile_shape_mnk[1],
+            mma_inst_shape_k * mma_inst_tile_k,
+        )
+
         self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
 
         self.epi_tile = self._sm90_compute_tile_shape_or_override(
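The permutation built with `cute.make_ordered_layout` above reorders the N mode so that the two warp groups' accumulator chunks end up adjacent in epilogue shared memory. A plain-Python illustration of the stride rule an "ordered" compact layout follows (this is not the CuTe API, just the arithmetic behind it; the helper name is ours):

```python
# Illustrative only: a compact "ordered layout" assigns strides so that the mode with
# the smallest entry in `order` varies fastest. For shape (8, tile_n // atom_n // 8,
# atom_n) with order (0, 2, 1), the atom_n mode becomes the next-fastest mode after
# the inner 8, which places the two warp groups' N-chunks next to each other.
def ordered_strides(shape, order):
    strides = [0] * len(shape)
    running = 1
    for mode in sorted(range(len(shape)), key=lambda i: order[i]):
        strides[mode] = running
        running *= shape[mode]
    return tuple(strides)

print(ordered_strides((8, 16, 2), order=(0, 2, 1)))  # -> (1, 16, 8)
```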
@@ -433,6 +330,7 @@ class HopperWgmmaGemmKernel:
             self.b_dtype,
             self.d_dtype,
             self.c_dtype,
+            epilogue_args,
             self.smem_capacity,
             self.occupancy,
             # epi_smem will reuse smem ab if not persistent.
@@ -466,13 +364,12 @@ class HopperWgmmaGemmKernel:
         self,
         mA: cute.Tensor,
         mB: cute.Tensor,
-        mD: cute.Tensor,
+        mD: Optional[cute.Tensor],
         mC: Optional[cute.Tensor],
+        epilogue_args: Optional[ArgumentsBase],
+        scheduler_args: TileSchedulerOptions,
+        varlen_args: Optional[VarlenArguments],
         mAIdx: Optional[cute.Tensor],
-        mCuSeqlensM: Optional[cute.Tensor],
-        mTensormaps: Optional[cute.Tensor],
-        tile_count_semaphore: Optional[cute.Pointer],
-        max_active_clusters: Int32,
         stream: cuda.CUstream,
     ):
         """Execute the GEMM operation in steps:
@@ -495,12 +392,12 @@ class HopperWgmmaGemmKernel:
         # setup static attributes before smem/grid/tma computation
         self.a_dtype = mA.element_type
         self.b_dtype = mB.element_type
-        self.d_dtype = mD.element_type
+        self.d_dtype = mD.element_type if mD is not None else None
         self.c_dtype = mC.element_type if mC is not None else None
-        self.a_layout =
-        self.b_layout =
-        self.d_layout =
-        self.c_layout =
+        self.a_layout = LayoutEnum.from_tensor(mA)
+        self.b_layout = LayoutEnum.from_tensor(mB)
+        self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
+        self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None
 
         if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
             raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
@@ -517,35 +414,12 @@ class HopperWgmmaGemmKernel:
         )
         mA, mD = [
             cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            if t is not None
+            else None
             for t in (mA, mD)
         ]
 
-        self._setup_attributes()
-
-        tiled_mma = sm90_utils.make_trivial_tiled_mma(
-            self.a_dtype,
-            self.b_dtype,
-            self.a_layout.sm90_mma_major_mode(),
-            self.b_layout.sm90_mma_major_mode(),
-            self.acc_dtype,
-            self.atom_layout_mnk,
-            tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
-        )
-        if const_expr(self.atom_layout_mnk[1] > 1):
-            # If N dimension is split among 2 WGs, we need to permute the N dimension so
-            # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
-            # containing accumulators that are next to each other in the N dimension.
-            # Without permutation WG0 would write to epi smem of size (64, 16) and
-            # WG1 would write to a separate epi smem of size (64, 16) that's far away.
-            atom_n = self.atom_layout_mnk[1]
-            permutation_n = cute.make_ordered_layout(
-                (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
-            )
-            tiled_mma = cute.make_tiled_mma(
-                cute.make_mma_atom(tiled_mma.op),
-                self.atom_layout_mnk,
-                permutation_mnk=(None, permutation_n, None),
-            )
+        self._setup_attributes(epilogue_args)
 
         if const_expr(not self.gather_A):
             tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
@@ -564,9 +438,12 @@ class HopperWgmmaGemmKernel:
                 self.cluster_shape_mnk[0],
             )
 
-
-
-
+        if const_expr(mD is not None):
+            tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
+                mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
+            )
+        else:
+            tma_atom_d, tma_tensor_d = None, None
 
         if const_expr(mC is not None):
             tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
@@ -575,65 +452,66 @@ class HopperWgmmaGemmKernel:
         else:
             tma_atom_c, tma_tensor_c = None, None
 
-
-
-
+        epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
+
+        if const_expr(varlen_args is None):
+            varlen_args = VarlenArguments()
+        if const_expr(varlen_args.mCuSeqlensM is None):
+            num_problems = (
+                mD.shape[2]
+                if mD is not None
+                else (
+                    mB.shape[2]
+                    if varlen_args.mCuSeqlensK is None
+                    else varlen_args.mCuSeqlensK.shape[0] - 1
+                )
            )
-
-
-
-
-                group_size=8,
-                cluster_shape_mnk=self.cluster_shape_mnk,
-                tile_count_semaphore=tile_count_semaphore,
-                is_persistent=self.is_persistent,
+            problem_shape_ntile_mnl = (
+                cute.ceil_div(mA.shape[0], self.tile_shape_mnk[0]),
+                cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
+                num_problems,
             )
+            TileSchedulerCls = self.get_scheduler_class()
+            tile_sched_args = self.get_scheduler_arguments(problem_shape_ntile_mnl, scheduler_args)
         else:
-            assert
+            assert mD is not None or not self.gather_A
             problem_shape_ntile_mnl = (
                 None,
-                cute.ceil_div(
-                mCuSeqlensM.shape[0] - 1,
+                cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
+                varlen_args.mCuSeqlensM.shape[0] - 1,
             )
             TileSchedulerCls = VarlenMTileScheduler
             tile_sched_args = VarlenMTileSchedulerArguments(
                 problem_shape_ntile_mnl=problem_shape_ntile_mnl,
-                total_m=mD.shape[0],
-                cu_seqlens_m=mCuSeqlensM,
-                raster_order=
-                group_size=
-
+                total_m=mD.shape[0] if mD is not None else mAIdx.shape[0],
+                cu_seqlens_m=varlen_args.mCuSeqlensM,
+                raster_order=scheduler_args.raster_order,
+                group_size=scheduler_args.max_swizzle_size,
+                tile_shape_mn=self.tile_shape_mnk[:2],
                 cluster_shape_mnk=self.cluster_shape_mnk,
-                tile_count_semaphore=tile_count_semaphore,
+                tile_count_semaphore=scheduler_args.tile_count_semaphore,
                 is_persistent=self.is_persistent,
             )
         tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
-        grid = TileSchedulerCls.get_grid_shape(
+        grid = TileSchedulerCls.get_grid_shape(
+            tile_sched_params, scheduler_args.max_active_clusters
+        )
 
-        epi_smem_size =
+        epi_smem_size = (
+            cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
+        )
         epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
 
-        size_tensormap_in_i64 = (
-            0
-            if mCuSeqlensM is None
-            or self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.GMEM
-            else HopperWgmmaGemmKernel.num_tensormaps
-            * HopperWgmmaGemmKernel.bytes_per_tensormap
-            // 8
-        ) * (1 if not self.pingpong else 2)
-
         @cute.struct
         class SharedStorage:
-            tensormap_buffer: cute.struct.Align[
-                cute.struct.MemRange[cutlass.Int64, size_tensormap_in_i64],
-                64,
-            ]
             ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
             epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
             sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
             tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
             sD: cute.struct.Align[
-                cute.struct.MemRange[
+                cute.struct.MemRange[
+                    self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
+                ],
                 self.buffer_align_bytes,
             ]
             sC: cute.struct.Align[
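The `problem_shape_ntile_mnl` tuple built above is just the per-dimension tile counts plus the batch/problem count. The same arithmetic in plain Python (helper names are ours, not from the file):

```python
# Sketch of the tile-count computation mirrored by the cute.ceil_div calls above.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def problem_shape_ntile_mnl(m: int, n: int, num_problems: int, tile_m: int, tile_n: int):
    return (ceil_div(m, tile_m), ceil_div(n, tile_n), num_problems)

print(problem_shape_ntile_mnl(8192, 8192, 1, 128, 256))  # (64, 32, 1)
```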
@@ -642,6 +520,7 @@ class HopperWgmmaGemmKernel:
                 ],
                 self.buffer_align_bytes,
             ]
+            epi: self.epi_get_smem_struct(epilogue_params)
             sA: cute.struct.Align[
                 cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
                 self.buffer_align_bytes,
@@ -661,13 +540,14 @@ class HopperWgmmaGemmKernel:
             tma_tensor_b,
             tma_atom_d,
             tma_tensor_d,
-            mD,
             tma_atom_c,
             tma_tensor_c,
+            epilogue_params,
             mAIdx,
-            mCuSeqlensM,
-
-
+            varlen_args.mCuSeqlensM,
+            varlen_args.mCuSeqlensK,
+            varlen_args.mTensormaps,
+            self.tiled_mma,
             self.cluster_layout_mnk,
             self.a_smem_layout_staged,
             self.b_smem_layout_staged,
@@ -693,20 +573,21 @@ class HopperWgmmaGemmKernel:
         mA_mkl: cute.Tensor,
         tma_atom_b: cute.CopyAtom,
         mB_nkl: cute.Tensor,
-        tma_atom_d: cute.CopyAtom,
-
-        mD_mnl: cute.Tensor,
+        tma_atom_d: Optional[cute.CopyAtom],
+        mD_mnl: Optional[cute.Tensor],
         tma_atom_c: Optional[cute.CopyAtom],
         mC_mnl: Optional[cute.Tensor],
+        epilogue_params: ParamsBase,
         mAIdx: Optional[cute.Tensor],
         cu_seqlens_m: Optional[cute.Tensor],
+        cu_seqlens_k: Optional[cute.Tensor],
         tensormaps: Optional[cute.Tensor],
         tiled_mma: cute.TiledMma,
         cluster_layout_mnk: cute.Layout,
-
-
-
-
+        a_smem_layout: cute.ComposedLayout,
+        b_smem_layout: cute.ComposedLayout,
+        epi_smem_layout: cute.ComposedLayout,
+        epi_c_smem_layout: cute.ComposedLayout,
         tile_sched_params: ParamsBase,
         TileSchedulerCls: cutlass.Constexpr[Callable],
     ):
@@ -723,39 +604,35 @@ class HopperWgmmaGemmKernel:
         :type mB_nkl: cute.Tensor
         :param tma_atom_d: TMA copy atom for D tensor
         :type tma_atom_d: cute.CopyAtom
-        :param
-        :type
+        :param mD_mnl: Output tensor D
+        :type mD_mnl: cute.Tensor
         :param tiled_mma: Tiled MMA object
         :type tiled_mma: cute.TiledMma
         :param cluster_layout_mnk: CTA layout
         :type cluster_layout_mnk: cute.Layout
-        :param
-        :type
-        :param
-        :type
-        :param
-        :type
+        :param a_smem_layout: Shared memory layout for A
+        :type a_smem_layout: cute.ComposedLayout
+        :param b_smem_layout: Shared memory layout for B
+        :type b_smem_layout: cute.ComposedLayout
+        :param epi_smem_layout: Shared memory layout for epilogue
+        :type epi_smem_layout: cute.ComposedLayout
         """
 
-
+        varlen_m = const_expr(cu_seqlens_m is not None)
+        varlen_k = const_expr(cu_seqlens_k is not None)
+        assert not (varlen_m and varlen_k)
+        has_D = const_expr(mD_mnl is not None)
+        has_C = const_expr(mC_mnl is not None)
+
         warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
 
         # /////////////////////////////////////////////////////////////////////////////
         # Prefetch Tma desc
         # /////////////////////////////////////////////////////////////////////////////
         if warp_idx == self.ab_load_warp_id:
-
-
-
-            cpasync.prefetch_descriptor(tma_atom_d)
-            if const_expr(tma_atom_c is not None):
-                cpasync.prefetch_descriptor(tma_atom_c)
-
-        a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
-        b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
-        tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
-        if const_expr(not self.gather_A):
-            tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
+            for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
+                if const_expr(tma_atom is not None):
+                    cpasync.prefetch_descriptor(tma_atom)
 
         # /////////////////////////////////////////////////////////////////////////////
         # Alloc and init AB full/empty + ACC full mbar (pipeline)
@@ -763,147 +640,93 @@ class HopperWgmmaGemmKernel:
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)
 
-
-
-
-
-
-
-        ab_pipeline_consumer_group = pipeline.CooperativeGroup(
-            pipeline.Agent.Thread, consumer_arrive_cnt
-        )
-
-        cta_layout_vmnk = cute.make_layout((1, *cluster_layout_mnk.shape))
-        pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
-        ab_pipeline = pipeline_cls.create(
-            barrier_storage=storage.ab_pipeline_array_ptr.data_ptr(),
-            num_stages=self.ab_stage,
-            producer_group=ab_pipeline_producer_group,
-            consumer_group=ab_pipeline_consumer_group,
-            tx_count=tma_copy_bytes,
-            cta_layout_vmnk=cta_layout_vmnk,
+        ab_pipeline = self.make_ab_pipeline(
+            a_smem_layout=cute.slice_(a_smem_layout, (None, None, 0)),
+            b_smem_layout=cute.slice_(b_smem_layout, (None, None, 0)),
+            tiled_mma=tiled_mma,
+            cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
+            ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
        )
-
-        if const_expr(
-
-
-
-            consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
-            epi_pipeline_consumer_group = pipeline.CooperativeGroup(
-                pipeline.Agent.Thread, consumer_arrive_cnt
+        epi_pipeline = None
+        if const_expr(has_C):
+            epi_pipeline = self.make_epi_pipeline(
+                c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
+                epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
            )
-
-
-            epi_pipeline = pipeline.PipelineTmaAsync.create(
-                barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
-                num_stages=self.epi_c_stage,
-                producer_group=epi_pipeline_producer_group,
-                consumer_group=epi_pipeline_consumer_group,
-                tx_count=tma_copy_c_bytes,
-            )
-        else:
-            epi_pipeline = None
-
+        sched_pipeline = None
+        tile_count = None
        if const_expr(tile_sched_params.tile_count_semaphore is not None):
            # Dynamic persistent scheduler
-
-
-
-
-            consumer_arrive_cnt = (
-                (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
-            ) * cluster_size - 1
-            sched_pipeline_consumer_group = pipeline.CooperativeGroup(
-                pipeline.Agent.Thread, consumer_arrive_cnt
-            )
-            sched_pipeline = pipeline.PipelineAsync.create(
-                barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
-                num_stages=self.sched_stage,
-                producer_group=sched_pipeline_producer_group,
-                consumer_group=sched_pipeline_consumer_group,
-                # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
-                consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
+            sched_pipeline = self.make_sched_pipeline(
+                cluster_layout_mnk,
+                sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
+                varlen_k=varlen_k,
            )
            tile_count = storage.tile_count.get_tensor((self.sched_stage,))
-        else:
-            sched_pipeline = None
-            tile_count = None
 
        # ///////////////////////////////////////////////////////////////////////////////
        # Generate smem tensor A/B
        # ///////////////////////////////////////////////////////////////////////////////
-        sA = storage.sA.get_tensor(
-        sB = storage.sB.get_tensor(
-
-
-
-
-
-
-
-
-
-
-        else:
-            sC = None
+        sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
+        sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
+        sD = None
+        if const_expr(has_D):
+            if const_expr(not self.is_persistent):
+                sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
+                sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
+            else:
+                sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
+        sC = None
+        if const_expr(has_C):
+            sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
+        epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
 
        # Get tensormap buffer address
-
-
-
-        tensormap_workspace_idx = (
-            bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0]
-        )
-        # TODO: this is only for D, not for A/B
-        if const_expr(self.pingpong):
-            tensormap_workspace_idx = tensormap_workspace_idx * 2 + warp_idx // 4
+        tensormap_manager = None
+        tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
+        if const_expr(varlen_m or varlen_k):
            tensormap_manager = TensorMapManagerSm90(
-
-            )
-            tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
-                tensormaps[tensormap_workspace_idx, None].iterator
+                cutlass.utils.TensorMapUpdateMode.GMEM, GemmSm90.bytes_per_tensormap
            )
-
-
-
-
+            # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
+            tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
+            if const_expr(varlen_m):
+                tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
+                tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
+                    tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
                )
-                # Need this, otherwise "expected tma descriptor pointer to have alignment at least 64, but got 8"
-                tensormap_d_smem_ptr = cute.make_ptr(
-                    cutlass.Int64,
-                    tensormap_d_smem_ptr.toint(),
-                    cute.AddressSpace.smem,
-                    assumed_align=64,
-                )
-                tensormap_d_init_ptr = tensormap_d_smem_ptr
            else:
-
-
-
-
-
+                assert varlen_k
+                tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
+                    tensormaps[tensormap_workspace_idx, 0, None].iterator
+                )
+                tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
+                    tensormaps[tensormap_workspace_idx, 1, None].iterator
+                )
 
        TileSchedulerCls = partial(
            TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
        )
 
-        k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
-        c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
-
        if warp_idx >= self.ab_load_warp_id:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
-            if const_expr(mC_mnl is not None):
-                epi_load_barrier = pipeline.NamedBarrier(
-                    barrier_id=int(NamedBarrierGemm.EpilogueLoad),
-                    num_threads=self.num_ab_load_threads + self.num_epi_load_threads,
-                )
-            else:
-                epi_load_barrier = None
            if (
                warp_idx >= self.ab_load_warp_id
                and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
            ):
+                is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
+                if const_expr(varlen_k):
+                    # initialize tensormap for A & B
+                    tensormap_manager.init_tensormap_from_atom(
+                        tma_atom_a,
+                        tensormap_a_ptr,
+                        is_tma_warp,
+                    )
+                    tensormap_manager.init_tensormap_from_atom(
+                        tma_atom_b,
+                        tensormap_b_ptr,
+                        is_tma_warp,
+                    )
                # ///////////////////////////////////////////////////////////////////////////////
                # Get mcast mask
                # ///////////////////////////////////////////////////////////////////////////////
@@ -927,16 +750,37 @@ class HopperWgmmaGemmKernel:
                ab_producer_state = make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self.ab_stage
                )
-
+                if const_expr(varlen_k):
+                    # wait tensormap initialization complete before update
+                    tensormap_manager.fence_tensormap_initialization()
+                    # batch index of last tile
+                    last_batch_idx = cutlass.Int32(-1)
                while work_tile.is_valid_tile:
                    tile_coord_mnkl = work_tile.tile_idx
                    batch_idx = tile_coord_mnkl[3]
+                    if const_expr(varlen_k):
+                        is_group_changed = batch_idx != last_batch_idx
+                        last_batch_idx = batch_idx
+                        if is_group_changed:
+                            # construct tensor A/B based on real address, shape and stride information
+                            tensormap_manager.update_tensormap_shape(
+                                (tensormap_a_ptr, tensormap_b_ptr),
+                                is_manager_warp=is_tma_warp,
+                                shapes=(cu_seqlens_k[batch_idx + 1], cu_seqlens_k[batch_idx + 1]),
+                                orders=(
+                                    0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1,
+                                    0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1,
+                                ),
+                                tensormap_smem_ptr=None,
+                            )
                    # ///////////////////////////////////////////////////////////////////////////
                    # Local_tile partition global tensors
                    # ///////////////////////////////////////////////////////////////////////////
                    if const_expr(not self.gather_A):
-                        if const_expr(
+                        if const_expr(varlen_m):
                            mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
+                        elif const_expr(varlen_k):
+                            mA_mk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mA_mkl)
                        else:
                            mA_mk = mA_mkl[None, None, batch_idx]
                        # (bM, bK, RestK)
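The `is_group_changed` guard above is the usual pattern for variable-length batches: the relatively expensive tensormap shape update only runs when the scheduler hands this warp a tile from a new batch. A stripped-down illustration in plain Python (the callback name is ours, standing in for `TensorMapManagerSm90.update_tensormap_shape`):

```python
# Illustrative only: fire the (expensive) update callback only on batch transitions.
def update_fn(batch_idx: int) -> None:  # hypothetical stand-in for the tensormap update
    print(f"update tensormap shapes for batch {batch_idx}")

last_batch_idx = -1
for batch_idx in (0, 0, 1, 1, 2):      # batch indices of successive work tiles
    if batch_idx != last_batch_idx:    # is_group_changed
        last_batch_idx = batch_idx
        update_fn(batch_idx)           # fires only for 0, 1, 2
```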
@@ -947,28 +791,46 @@ class HopperWgmmaGemmKernel:
                        )
                    else:
                        mA_mk = mA_mkl
-                        if const_expr(
+                        if const_expr(varlen_m):
                            mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
+                        elif const_expr(varlen_k):
+                            mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
                        else:
                            mAIdx_mk = mAIdx[None, batch_idx]
                        gAIdx = cute.local_tile(
                            mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
                        )
+                    if const_expr(varlen_k):
+                        mB_nk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mB_nkl)
+                    else:
+                        mB_nk = mB_nkl[None, None, batch_idx]
                    # (bN, bK, RestK)
                    gB_k = cute.local_tile(
-
+                        mB_nk, cute.select(self.tile_shape_mnk, [1, 2]), (tile_coord_mnkl[1], None)
                    )
                    # //////////////////////////////////////////////////////////////////////////
                    # Partition shared tensor for TMA load A/B
                    # //////////////////////////////////////////////////////////////////////////
+                    if const_expr(varlen_k):
+                        # ensure the update to tensormap has completed before using it
+                        if is_group_changed and is_tma_warp:
+                            tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
+                            tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
+                        tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
+                            tensormap_a_ptr, cute.AddressSpace.generic
+                        )
+                        tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
+                            tensormap_b_ptr, cute.AddressSpace.generic
+                        )
+                    else:
+                        tma_desc_a_ptr, tma_desc_b_ptr = None, None
                    # TMA load A partition_S/D
                    a_cta_layout = cute.make_layout(
                        cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
                    )
                    a_cta_crd = cluster_coord_mnk[1]
                    if const_expr(not self.gather_A):
-                        # ((atom_v, rest_v), STAGE)
-                        # ((atom_v, rest_v), RestK)
+                        # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
                        tAsA, tAgA_k = cpasync.tma_partition(
                            tma_atom_a,
                            a_cta_crd,
@@ -976,7 +838,12 @@ class HopperWgmmaGemmKernel:
                            cute.group_modes(sA, 0, 2),
                            cute.group_modes(gA_k, 0, 2),
                        )
-                        copy_A = partial(
+                        copy_A = partial(
+                            cute.copy,
+                            tma_atom_a,
+                            mcast_mask=a_mcast_mask,
+                            tma_desc_ptr=tma_desc_a_ptr,
+                        )
                    else:
                        tiled_copy_A = self._make_gmem_tiled_copy_A(
                            mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
@@ -996,8 +863,7 @@ class HopperWgmmaGemmKernel:
                        cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
                    )
                    b_cta_crd = cluster_coord_mnk[0]
-                    # ((atom_v, rest_v), STAGE)
-                    # ((atom_v, rest_v), RestK)
+                    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
                    tBsB, tBgB_k = cpasync.tma_partition(
                        tma_atom_b,
                        b_cta_crd,
@@ -1005,7 +871,15 @@ class HopperWgmmaGemmKernel:
                        cute.group_modes(sB, 0, 2),
                        cute.group_modes(gB_k, 0, 2),
                    )
-                    copy_B = partial(
+                    copy_B = partial(
+                        cute.copy, tma_atom_b, mcast_mask=b_mcast_mask, tma_desc_ptr=tma_desc_b_ptr
+                    )
+                    k_len = (
+                        cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
+                        if const_expr(varlen_k)
+                        else mA_mkl.shape[1]
+                    )
+                    k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
                    if const_expr(not self.gather_A):
                        ab_producer_state = self.load_AB(
                            ab_pipeline,
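With variable-length K, the number of K tiles is recomputed per work tile from the `cu_seqlens_k` prefix sums rather than taken from a single static K. The same arithmetic in plain Python (helper name is ours):

```python
# Per-batch K-tile count when K is variable length: cu_seqlens_k holds prefix sums,
# so batch b spans cu_seqlens_k[b + 1] - cu_seqlens_k[b] rows of K.
def k_tiles_for_batch(cu_seqlens_k, batch_idx: int, tile_k: int) -> int:
    k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
    return (k_len + tile_k - 1) // tile_k   # ceil_div

print(k_tiles_for_batch([0, 100, 292], 1, 64))  # K = 192 -> 3 tiles
```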
@@ -1016,6 +890,7 @@ class HopperWgmmaGemmKernel:
                            copy_B,
                            tBgB_k,
                            tBsB,
+                            k_tile_cnt,
                        )
                    else:
                        limit_m = (
@@ -1033,93 +908,37 @@ class HopperWgmmaGemmKernel:
                            copy_B,
                            tBgB_k,
                            tBsB,
+                            k_tile_cnt,
                            limit_A=(
                                limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
                                mA_mk.shape[1],
                            ),
                        )
-                    if const_expr(epi_load_barrier is not None):
-                        # In the first work tile, the epi load warp will wait for the signal
-                        # from the mainloop load warp to start loading C, to avoid interfering
-                        # with loading A and B.
-                        if do_epi_load_barrier_arrive:
-                            epi_load_barrier.arrive()
-                            do_epi_load_barrier_arrive = cutlass.Boolean(False)
                    tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
+                    tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
                    tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
                    work_tile = tile_scheduler.get_current_work()
                    # End of persistent scheduler loop
-                if const_expr(self.pingpong):
+                if const_expr(self.pingpong and not varlen_k):
                    # Need to write the tile_idx to smem for the next WG in the pingpong mode
+                    # tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+                    tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
                    tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
                ab_pipeline.producer_tail(ab_producer_state)
                if is_scheduler_warp:
                    tile_scheduler.producer_tail()
 
-            # if const_expr(mC_mnl is not None):
-            #     if warp_idx == self.epi_load_warp_id:
-            #         epi_producer_state = make_pipeline_state(
-            #             pipeline.PipelineUserType.Producer, self.epi_c_stage
-            #         )
-            #         do_epi_load_barrier_wait = cutlass.Boolean(True)
-            #         tile_scheduler = TileSchedulerCls()
-            #         work_tile = tile_scheduler.initial_work_tile_info()
-            #         while work_tile.is_valid_tile:
-            #             tile_coord_mnkl = work_tile.tile_idx
-            #             batch_idx = tile_coord_mnkl[3]
-            #             if const_expr(cu_seqlens_m is not None):
-            #                 mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
-            #             else:
-            #                 mC_mn = mC_mnl[None, None, batch_idx]
-            #             # (bM, bN)
-            #             gC = cute.local_tile(
-            #                 mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
-            #             )
-            #             tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
-            #             bGS_sC, bGS_gC = cpasync.tma_partition(
-            #                 tma_atom_c,
-            #                 0,
-            #                 cute.make_layout(1),
-            #                 cute.group_modes(sC, 0, 2),
-            #                 tCgC_for_tma_partition,
-            #             )
-            #             if do_epi_load_barrier_wait:
-            #                 epi_load_barrier.arrive_and_wait()
-            #                 do_epi_load_barrier_wait = cutlass.Boolean(False)
-            #             epi_tile_num = const_expr(cute.size(tCgC_for_tma_partition, mode=[1]))
-            #             epi_tile_shape = tCgC_for_tma_partition.shape[1]
-            #             for epi_idx in cutlass.range(epi_tile_num, unroll=1):
-            #                 epi_pipeline.producer_acquire(epi_producer_state)
-            #                 # Get the global memory coordinate for the current epi tile
-            #                 epi_tile_layout = cute.make_layout(
-            #                     epi_tile_shape, stride=(epi_tile_shape[1], 1)
-            #                 )
-            #                 gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
-            #                 cute.copy(
-            #                     tma_atom_c,
-            #                     bGS_gC[None, gmem_coord],
-            #                     bGS_sC[None, epi_producer_state.index],
-            #                     tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
-            #                 )
-            #                 # Epi pipeline's producer commit is a NOP
-            #                 epi_pipeline.producer_commit(epi_producer_state)
-            #                 epi_producer_state.advance()
-            #             tile_scheduler.advance_to_next_work()
-            #             work_tile = tile_scheduler.get_current_work()
-            #         # End of persistent scheduler loop
-            #         epi_pipeline.producer_tail(epi_producer_state)
-
        if warp_idx < self.ab_load_warp_id:
            cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
-            is_tma_warp =
+            is_tma_warp = Boolean(
                (not self.pingpong and warp_idx == 0)
                or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
            )
-            if const_expr(
+            if const_expr(varlen_m):
                # initialize tensormap for D
                tensormap_manager.init_tensormap_from_atom(
                    tma_atom_d,
-
+                    tensormap_d_ptr,
                    is_manager_warp=is_tma_warp,
                )
            # //////////////////////////////////////////////////////////////////////////////
@@ -1145,10 +964,9 @@ class HopperWgmmaGemmKernel:
 
            acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
            acc = cute.make_fragment(acc_shape, self.acc_dtype)
+            acc_slow = None
            if const_expr(self.fp8_slow_accum):
                acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
-            else:
-                acc_slow = None
 
            if const_expr(self.pingpong):
                if warp_group_idx == 0:
@@ -1156,6 +974,9 @@ class HopperWgmmaGemmKernel:
                    self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
                    self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
 
+            k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.tile_shape_mnk[2])
+            c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
+
            ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
            epi_read_state = make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.epi_c_stage
@@ -1164,16 +985,29 @@ class HopperWgmmaGemmKernel:
                pipeline.PipelineUserType.Producer, self.epi_c_stage
            )
            tile_scheduler = TileSchedulerCls()
+            work_tile = None
            if const_expr(self.pingpong):
+                if const_expr(varlen_k):
+                    work_tile = tile_scheduler.initial_work_tile_info()
                if warp_idx >= 4:
-                    # Advance 2nd Math WG to the next work tile for the startup
-                    tile_scheduler.advance_to_next_work()
                    # Advance 2nd Math WG pipeline states to the end of 1st Math WG
-                    ab_read_state.advance_iters(k_tile_cnt)
                    epi_read_state.advance_iters(c_tile_cnt)
                    epi_producer_state.advance_iters(c_tile_cnt)
-
-
+                    if const_expr(not varlen_k):
+                        ab_read_state.advance_iters(k_tile_cnt_static)
+                    else:
+                        batch_idx = work_tile.tile_idx[3]
+                        k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
+                        k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+                        ab_read_state.advance_iters(k_tile_cnt)
+                    tile_scheduler.advance_to_next_work()
+                    if const_expr(varlen_k):
+                        work_tile = tile_scheduler.get_current_work()
+                if const_expr(not varlen_k):
+                    work_tile = tile_scheduler.initial_work_tile_info()
+            else:
+                work_tile = tile_scheduler.initial_work_tile_info()
+            if const_expr(varlen_m):
                # wait tensormap initialization complete before update
                tensormap_manager.fence_tensormap_initialization()
            # batch index of last tile
@@ -1181,19 +1015,25 @@ class HopperWgmmaGemmKernel:
            while work_tile.is_valid_tile:
                tile_coord_mnkl = work_tile.tile_idx
                batch_idx = tile_coord_mnkl[3]
-                if const_expr(
+                if const_expr(varlen_m):
                    is_group_changed = batch_idx != last_batch_idx
                    last_batch_idx = batch_idx
                    if is_group_changed:
                        # construct tensor D based on real address, shape and stride information
                        tensormap_manager.update_tensormap_shape(
-                            (
+                            (tensormap_d_ptr,),
                            is_manager_warp=is_tma_warp,
-                            tensormap_smem_ptr=(tensormap_d_smem_ptr,),
                            shapes=(cu_seqlens_m[batch_idx + 1],),
                            orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
+                            tensormap_smem_ptr=None,
                        )
 
+                k_len = (
+                    cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
+                    if const_expr(varlen_k)
+                    else mA_mkl.shape[1]
+                )
+                k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
                ab_read_state, tiled_mma = self.mma(
                    ab_pipeline,
                    ab_read_state,
@@ -1205,9 +1045,9 @@ class HopperWgmmaGemmKernel:
                    k_tile_cnt,
                    warp_group_idx,
                )
-                if const_expr(
-
-
+                if const_expr(varlen_k):
+                    if k_tile_cnt == 0:
+                        acc.fill(0.0)
 
                # /////////////////////////////////////////////////////////////////////////////
                # EPILOGUE
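The `acc.fill(0.0)` above handles the corner case where a variable-length batch has zero K tiles: the K reduction is then an empty sum, so the accumulator must be zeroed explicitly instead of holding stale register contents. A plain NumPy analogue of the same reasoning (illustrative only):

```python
import numpy as np

def gemm_batch(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # a: (M, K), b: (K, N); K may be 0 for an empty variable-length batch.
    acc = np.zeros((a.shape[0], b.shape[1]), dtype=np.float32)  # explicit zero init
    for k in range(a.shape[1]):                                 # body never runs when K == 0
        acc += np.outer(a[:, k], b[k, :])
    return acc

print(gemm_batch(np.ones((2, 0), np.float32), np.ones((0, 3), np.float32)))  # all zeros
```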
@@ -1219,194 +1059,123 @@ class HopperWgmmaGemmKernel:
                 barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
             )

-
-            # A in the mainloop is reused in the epilogue if not persistent.
-            if const_expr(not self.is_persistent):
-                epilogue_barrier.arrive_and_wait()
-
-            if const_expr(varlen):
+            if const_expr(varlen_m):
                 # ensure the update to tensormap has completed before using it
-                if is_group_changed:
-
-
-
-
-            # get st.matrix with num_matrices=4
-            copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
-                self.d_layout, elem_ty_d=self.d_dtype, elem_ty_acc=self.acc_dtype
-            )
-            copy_atom_C = cute.make_copy_atom(
-                warp.StMatrix8x8x16bOp(
-                    self.d_layout.is_m_major_c(),
-                    num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
-                ),
-                cutlass.Float16,  # this is just to get the right source layout
-            )
-            tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
-            tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
-            # (R2S, R2S_M, R2S_N, PIPE_D)
-            thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
-            tRS_sD = thr_copy_r2s.partition_D(sD)
-            # (R2S, R2S_M, R2S_N)
-            tRS_rAcc = tiled_copy_r2s.retile(acc)
-
-            # Allocate D registers.
-            tRS_rD_layout = cute.make_layout(thr_copy_r2s.partition_S(sD).shape[:3])
-            tRS_rD = cute.make_fragment(tRS_rD_layout, self.acc_dtype)
-
-            if const_expr(mC_mnl is not None):
-                copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
-                tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
-                thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
-                tSR_sC = thr_copy_s2r.partition_S(sC)
-                tRS_rC = cute.make_fragment(tRS_rD_layout, self.c_dtype)
-                tSR_rC = thr_copy_s2r.retile(tRS_rC)
-            else:
-                thr_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
-
-            if const_expr(cu_seqlens_m is not None):
-                mD_mn_tma = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl_tma)
+                if is_group_changed and is_tma_warp:
+                    tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
+                tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
+                    tensormap_d_ptr, cute.AddressSpace.generic
+                )
             else:
-
-
-
-
-
-
-
-
-
-
-
-                tDgD_for_tma_partition,
-            )
-
-            if const_expr(mC_mnl is not None):
-                if const_expr(cu_seqlens_m is not None):
-                    mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
-                else:
-                    mC_mn = mC_mnl[None, None, batch_idx]
-                # (bM, bN)
-                gC = cute.local_tile(
-                    mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
+                tma_desc_d_ptr = None
+
+            if const_expr(has_D):
+                bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(
+                    tma_atom_d,
+                    mD_mnl,
+                    self.tile_shape_mnk[:2],
+                    self.epi_tile,
+                    sD,
+                    tile_coord_mnkl,
+                    cu_seqlens_m,
                 )
-
-
+                copy_D = partial(cute.copy, tma_atom_d, tma_desc_ptr=tma_desc_d_ptr)
+            else:
+                bSG_sD, bSG_gD, copy_D = None, None, None
+            if const_expr(has_C):
+                bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
                     tma_atom_c,
-
-
-
-
+                    mC_mnl,
+                    self.tile_shape_mnk[:2],
+                    self.epi_tile,
+                    sC,
+                    tile_coord_mnkl,
+                    cu_seqlens_m,
                 )
+                copy_C = partial(cute.copy, tma_atom_c)
+                epi_load_g2s = partial(self.epi_load_g2s, epi_pipeline, copy_C, bGS_gC, bGS_sC)
+            else:
+                epi_load_g2s = None

-
-
-
-
-
-
-
-            if is_tma_warp:
-                epi_pipeline.producer_acquire(epi_producer_state)
-                # Get the global memory coordinate for the current epi tile
-                gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
-                cute.copy(
-                    tma_atom_c,
-                    bGS_gC[None, gmem_coord],
-                    bGS_sC[None, epi_producer_state.index],
-                    tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
-                )
-                # Epi pipeline's producer commit is a NOP
-                epi_pipeline.producer_commit(epi_producer_state)
-                epi_producer_state.advance()
-
-            for epi_idx in cutlass.range_constexpr(epi_tile_num):
-                # Copy from acc to D registers
-                for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
-                    tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
-                if const_expr(mC_mnl is not None):
-                    epi_pipeline.consumer_wait(epi_read_state)
-                    cute.copy(
-                        thr_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
-                    )
-                    # Fence to make sure shared memory read is visible to TMA load
-                    cute.arch.fence_proxy(
-                        cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-                    )
-                    cute.arch.sync_warp()
-                    with cute.arch.elect_one():
-                        epi_pipeline.consumer_release(epi_read_state)
-                    epi_read_state.advance()
-                    if const_expr(epi_idx + self.epi_c_stage < epi_tile_num):
-                        if is_tma_warp:
-                            epi_pipeline.producer_acquire(epi_producer_state)
-                            # Get the global memory coordinate for the current epi tile
-                            gmem_coord = epi_tile_layout.get_hier_coord(
-                                epi_idx + self.epi_c_stage
-                            )
-                            cute.copy(
-                                tma_atom_c,
-                                bGS_gC[None, gmem_coord],
-                                bGS_sC[None, epi_producer_state.index],
-                                tma_bar_ptr=epi_pipeline.producer_get_barrier(
-                                    epi_producer_state
-                                ),
-                            )
-                            # Epi pipeline's producer commit is a NOP
-                            epi_pipeline.producer_commit(epi_producer_state)
-                            epi_producer_state.advance()
-                    tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(self.acc_dtype))
-                # Type conversion
-                tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
-                tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
-                # Copy from D registers to shared memory
-                epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sD, mode=[3])
-                cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
-                # Fence and barrier to make sure shared memory store is visible to TMA store
-                cute.arch.fence_proxy(
-                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+            d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
+            tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
+                tiled_mma, self.d_layout, d_dtype_for_layout, acc, sD, tidx
+            )
+            if const_expr(has_C):
+                tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
+                    tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
                 )
+            else:
+                tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
+
+            # Wait for all warp groups in the thread block to finish, because smem for tensor
+            # A in the mainloop is reused in the epilogue if not persistent.
+            if const_expr(not self.is_persistent):
                 epilogue_barrier.arrive_and_wait()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
+
+            epi_read_state, epi_producer_state = self.epilogue(
+                epilogue_params,
+                epi_smem_tensors,
+                epi_pipeline,
+                epi_read_state,
+                epi_producer_state,
+                tiled_mma,
+                tRS_rAcc,
+                tRS_rD,
+                tRS_rC,
+                tiled_copy_r2s,
+                tRS_sD,
+                tiled_copy_s2r,
+                tSR_rC,
+                tSR_sC,
+                copy_D,
+                bSG_sD,
+                bSG_gD,
+                epi_load_g2s,
+                tile_coord_mnkl,
+                cu_seqlens_m,
+                epilogue_barrier,
+                tile_scheduler,
+                tidx,
+                is_tma_warp,
+            )

             if const_expr(self.pingpong):
-                # Update starting load/store pipeline states for the next tile
-                epi_read_state.advance_iters(c_tile_cnt)
-                epi_producer_state.advance_iters(c_tile_cnt)
                 # With pingpong, 2 WGs write two different output tiles to the same smem,
                 # so we have to make sure the smem content is done reading before signaling
                 # the next WG's epilogue.
-                if
+                if is_tma_warp:
                     cute.arch.cp_async_bulk_wait_group(0, read=True)
                 self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")

-
-
-
-
+            if const_expr(not self.pingpong):
+                tile_scheduler.advance_to_next_work()
+                work_tile = tile_scheduler.get_current_work()
+            else:  # Skip a tile for pingpong
+                # Update starting load/store pipeline states for the next tile
+                epi_read_state.advance_iters(c_tile_cnt)
+                epi_producer_state.advance_iters(c_tile_cnt)
+                # Update starting mainloop pipeline state for the next tile
+                if const_expr(not varlen_k):
+                    ab_read_state.advance_iters(k_tile_cnt_static)
+                    tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups)
+                    work_tile = tile_scheduler.get_current_work()
+                else:
+                    tile_scheduler.advance_to_next_work()
+                    work_tile = tile_scheduler.get_current_work()
+                    if work_tile.is_valid_tile:
+                        batch_idx = work_tile.tile_idx[3]
+                        k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
+                        k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+                        ab_read_state.advance_iters(k_tile_cnt)
+                    tile_scheduler.advance_to_next_work()
+                    work_tile = tile_scheduler.get_current_work()
         # End of persistent scheduler loop

         if const_expr(not self.pingpong):
-            if
+            if is_tma_warp:
                 cute.arch.cp_async_bulk_wait_group(0, read=True)

     @cute.jit
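In the ping-pong path reconstructed above, each of the two math warp groups now owns every other output tile, so after finishing a tile a warp group advances the scheduler by self.mma_warp_groups rather than by one. A small sketch of the resulting interleaving, assuming two warp groups (the helper below is illustration only, not part of the kernel):

    def tiles_for_warp_group(wg_idx, num_tiles, mma_warp_groups=2):
        # Warp group wg_idx starts wg_idx tiles in and then skips its peer's tiles.
        return list(range(wg_idx, num_tiles, mma_warp_groups))

    assert tiles_for_warp_group(0, 8) == [0, 2, 4, 6]
    assert tiles_for_warp_group(1, 8) == [1, 3, 5, 7]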
@@ -1420,10 +1189,10 @@ class HopperWgmmaGemmKernel:
         copy_B: Callable,
         tBgB: cute.Tensor,
         tBsB: cute.Tensor,
+        k_tile_cnt: Int32,
     ) -> cutlass.pipeline.PipelineState:
-        k_tile_cnt = cute.size(tAgA, mode=[1])
         # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
-        peek_ab_empty_status =
+        peek_ab_empty_status = Boolean(True)
         if 0 < k_tile_cnt:
             peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
         # /////////////////////////////////////////////////////////////////////////
@@ -1434,20 +1203,12 @@ class HopperWgmmaGemmKernel:
             # Also sets the transaction barrier for the A/B buffers
             ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
             tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
-            copy_A(
-
-                tAsA[None, ab_producer_state.index],
-                tma_bar_ptr=tma_bar_ptr,
-            )
-            copy_B(
-                tBgB[None, k_tile],
-                tBsB[None, ab_producer_state.index],
-                tma_bar_ptr=tma_bar_ptr,
-            )
+            copy_A(tAgA[None, k_tile], tAsA[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
+            copy_B(tBgB[None, k_tile], tBsB[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
             # Mainloop pipeline's producer commit is a NOP
             ab_pipeline.producer_commit(ab_producer_state)
             ab_producer_state.advance()
-            peek_ab_empty_status =
+            peek_ab_empty_status = Boolean(True)
             if k_tile + 1 < k_tile_cnt:
                 peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
         return ab_producer_state
@@ -1464,6 +1225,7 @@ class HopperWgmmaGemmKernel:
         copy_B: Callable,
         tBgB: cute.Tensor,
         tBsB: cute.Tensor,
+        k_tile_cnt: Int32,
         limit_A: Tuple[Int32, Int32],
     ) -> cutlass.pipeline.PipelineState:
         # (atom_v, CPY_M, 1, RestK)
@@ -1489,9 +1251,8 @@ class HopperWgmmaGemmKernel:
         # (m, (bK, RestK))
         mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
         warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
-        k_tile_cnt = cute.size(tBgB, mode=[1])
         # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
-        peek_ab_empty_status =
+        peek_ab_empty_status = Boolean(True)
         if 0 < k_tile_cnt:
             peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
         # /////////////////////////////////////////////////////////////////////////
@@ -1527,7 +1288,7 @@ class HopperWgmmaGemmKernel:
             # This tells mbarrier to track the completion of cp.async
             ab_pipeline.producer_commit(ab_producer_state)
             ab_producer_state.advance()
-            peek_ab_empty_status =
+            peek_ab_empty_status = Boolean(True)
             if k_tile + 1 < k_tile_cnt:
                 peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
             # bound checking in the K dimension on the last k_tile
@@ -1545,7 +1306,7 @@ class HopperWgmmaGemmKernel:
                 tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
             )
             assert tAcA.shape[2] == 1  # there's only 1 load along the K dimension
-            tApA = cute.make_fragment(1,
+            tApA = cute.make_fragment(1, Boolean)
             tApA[0] = tAcA[0, 0, 0][1] < limit_k
             # (m, bK)
             mA_cur = mA_k[None, (None, k_tile)]
@@ -1584,12 +1345,11 @@ class HopperWgmmaGemmKernel:
         num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
         if const_expr(self.pingpong):
             self.pingpong_barrier_sync(warp_group_idx, stage="mma")
-        peek_ab_full_status =
+        peek_ab_full_status = Boolean(True)
         if 0 < k_tile_cnt:
             peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
         tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
         num_k_blocks = cute.size(tCrA, mode=[2])
-        # TODO: this is probably not correct if k_tile_cnt == 0
         for k_tile in cutlass.range(num_prologue_mma):
             # Wait for A/B buffer to be ready
             ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
@@ -1600,9 +1360,11 @@ class HopperWgmmaGemmKernel:
             tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
             warpgroup.commit_group()
             ab_read_state.advance()
-            peek_ab_full_status =
+            peek_ab_full_status = Boolean(True)
             if k_tile + 1 < k_tile_cnt:
                 peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
+        # If k_tile_cnt == 0, this is not correct. But we will set acc to 0 in the mainloop
+        # in that case.
         if const_expr(self.fp8_slow_accum):
             warpgroup.wait_group(0)
             acc_slow.store(acc.load())
@@ -1631,7 +1393,7 @@ class HopperWgmmaGemmKernel:
             ab_pipeline.consumer_release(ab_release_state)
             ab_read_state.advance()
             ab_release_state.advance()
-            peek_ab_full_status =
+            peek_ab_full_status = Boolean(True)
             if k_tile + 1 < k_tile_cnt:
                 peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
         if const_expr(self.pingpong):
@@ -1640,7 +1402,7 @@ class HopperWgmmaGemmKernel:
         if const_expr(not self.fp8_slow_accum):
             # fp8_slow_accum would already called wait_group(0) inside the loop
             warpgroup.wait_group(0)
-        for k_tile in cutlass.range(
+        for k_tile in cutlass.range(num_prologue_mma, unroll=1):
             ab_pipeline.consumer_release(ab_release_state)
             ab_release_state.advance()
         if const_expr(self.fp8_slow_accum):
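The repeated `peek_ab_*_status = Boolean(True)` edits above all follow the same pattern: default the peek flag to "ready" and only query the mbarrier when another K tile actually follows, so the final iteration never issues a pointless try-wait. A plain-Python sketch of that loop shape (the helpers are stand-ins for the pipeline's consumer_try_wait / consumer_wait, not real APIs):

    def try_wait():
        # Stand-in for ab_pipeline.consumer_try_wait(...): did the buffer already fill?
        return False

    def wait_for_buffer(already_full):
        if not already_full:
            pass  # the blocking mbarrier wait would go here

    k_tile_cnt = 4
    peek = True  # mirrors peek_ab_full_status = Boolean(True) before the loop
    for k_tile in range(k_tile_cnt):
        wait_for_buffer(peek)
        # ... issue the WGMMAs for this K tile ...
        peek = try_wait() if k_tile + 1 < k_tile_cnt else True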
@@ -1649,6 +1411,184 @@ class HopperWgmmaGemmKernel:
         # "operand #0 does not dominate this use"
         return ab_read_state, tiled_mma

+    @cute.jit
+    def epilogue(
+        self,
+        params: EpilogueParams,
+        epi_smem_tensors: Tuple[cute.Tensor, ...],
+        epi_pipeline: cutlass.pipeline.PipelineAsync,
+        epi_read_state: cutlass.pipeline.PipelineState,
+        epi_producer_state: cutlass.pipeline.PipelineState,
+        tiled_mma: cute.TiledMma,
+        tRS_rAcc: cute.Tensor,
+        tRS_rD: cute.Tensor,
+        tRS_rC: Optional[cute.Tensor],
+        tiled_copy_r2s: cute.core.ThrCopy,
+        tRS_sD: cute.Tensor,
+        tiled_copy_s2r: Optional[cute.core.ThrCopy],
+        tSR_rC: Optional[cute.Tensor],
+        tSR_sC: Optional[cute.Tensor],
+        copy_D: Optional[Callable],
+        bSG_sD: cute.Tensor,
+        bSG_gD: cute.Tensor,
+        epi_load_g2s: Optional[Callable],
+        tile_coord_mnkl: cute.Coord,
+        cu_seqlens_m: Optional[cute.Tensor],
+        epilogue_barrier: cutlass.pipeline.NamedBarrier,
+        tile_scheduler,
+        tidx: Int32,
+        is_tma_warp: Boolean,
+    ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
+        has_C = const_expr(tRS_rC is not None)
+        has_D = const_expr(copy_D is not None)
+        # We iterate over epi tiles in the N dimension first before the M dimension
+        epi_tile_shape = cute.zipped_divide(
+            cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
+        ).shape[1]
+        epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
+        epi_tile_num = cute.size(epi_tile_shape)
+        num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
+
+        if const_expr(epi_load_g2s is not None):
+            for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
+                epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
+
+        for epi_idx in cutlass.range_constexpr(epi_tile_num):
+            # Copy from acc to D registers
+            for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
+                tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
+            if const_expr(has_C):
+                epi_pipeline.consumer_wait(epi_read_state)
+                cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
+                # Fence to make sure shared memory read is visible to TMA load
+                cute.arch.fence_proxy(
+                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                )
+                cute.arch.sync_warp()
+                with cute.arch.elect_one():
+                    epi_pipeline.consumer_release(epi_read_state)
+                epi_read_state.advance()
+            if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+                epi_producer_state = epi_load_g2s(
+                    epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
+                )
+            tRS_rEpi = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
+            epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+            # Copy from D registers to shared memory
+            if const_expr(has_D):
+                # Type conversion
+                tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
+                tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
+                cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
+            # Fence and barrier to make sure shared memory store is visible to TMA store
+            cute.arch.fence_proxy(
+                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+            )
+            epilogue_barrier.arrive_and_wait()
+            # Get the global memory coordinate for the current epi tile
+            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
+            # Copy from shared memory to global memory
+            if is_tma_warp:
+                if const_expr(has_D):
+                    copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
+                    cute.arch.cp_async_bulk_commit_group()
+                    cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
+            epilogue_barrier.arrive_and_wait()
+
+        return epi_read_state, epi_producer_state
+
+    @cute.jit
+    def epi_load_g2s(
+        self,
+        epi_pipeline: cutlass.pipeline.PipelineAsync,
+        copy_C: Callable,
+        bGS_gC: cute.Tensor,
+        bGS_sC: cute.Tensor,
+        epi_producer_state: cutlass.pipeline.PipelineState,
+        epi_idx: Int32,
+        should_load: Boolean,
+    ) -> cutlass.pipeline.PipelineState:
+        # We iterate over epi tiles in the N dimension first before the M dimension
+        epi_tile_layout = cute.make_layout(bGS_gC.shape[1], stride=(bGS_gC.shape[1][1], 1))
+        if should_load:
+            epi_pipeline.producer_acquire(epi_producer_state)
+            # Get the global memory coordinate for the current epi tile
+            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
+            copy_C(
+                bGS_gC[None, gmem_coord],
+                bGS_sC[None, epi_producer_state.index],
+                tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
+            )
+        # Epi pipeline's producer commit is a NOP
+        epi_pipeline.producer_commit(epi_producer_state)
+        epi_producer_state.advance()
+        return epi_producer_state
+
+    def epi_visit_acc_subtile(
+        self,
+        params: EpilogueParams,
+        tRS_rD: cute.Tensor,
+        tRS_rC: Optional[cute.Tensor] = None,
+    ) -> Optional[cute.Tensor]:
+        # Apply alpha scaling to accumulator if alpha is provided (not None)
+        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
+            alpha = utils.load_scalar_or_pointer(params.alpha)
+            tRS_rD.store(tRS_rD.load() * alpha)
+        # Apply C with beta scaling
+        if const_expr(tRS_rC is not None):
+            if const_expr(not hasattr(params, "beta") or params.beta is None):
+                # beta is None, default behavior: add C (beta=1.0)
+                tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
+            else:
+                beta = utils.load_scalar_or_pointer(params.beta)
+                tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
+        return None
+
+    def get_scheduler_class(self):
+        """Return the scheduler class to use. Override in subclasses for custom schedulers."""
+        return TileScheduler
+
+    def get_scheduler_arguments(self, problem_shape_ntile_mnl, scheduler_args):
+        """Create scheduler arguments. Override in subclasses for custom schedulers."""
+        return TileSchedulerArguments(
+            problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+            raster_order=scheduler_args.raster_order,
+            group_size=scheduler_args.max_swizzle_size,
+            cluster_shape_mnk=self.cluster_shape_mnk,
+            tile_count_semaphore=scheduler_args.tile_count_semaphore,
+            batch_idx_permute=scheduler_args.batch_idx_permute,
+            is_persistent=self.is_persistent,
+        )
+
+    def epi_visit_acc(
+        self,
+        params: EpilogueParams,
+        acc: cute.Tensor,
+        tiled_mma: cute.TiledMma,
+        tile_coord_mnkl: cute.Coord,
+        tidx: Int32,
+    ) -> None:
+        pass
+
+    def epi_to_underlying_arguments(
+        self, args: EpilogueArguments, *, loc=None, ip=None
+    ) -> EpilogueParams:
+        return GemmSm90.EpilogueParams(alpha=args.alpha, beta=args.beta)
+
+    @staticmethod
+    def epi_smem_bytes_per_stage(
+        args: Optional[EpilogueArguments],
+        tile_shape_mnk: Tuple[int, int, int],
+        epi_tile: Tuple[int, int],
+    ) -> int:
+        return 0
+
+    def epi_get_smem_struct(self, params: EpilogueParams):
+        return cute.struct.MemRange[cutlass.Int32, 0]  # Dummy struct
+
+    def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
+        return tuple()
+
     def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
         assert stage in ["mma", "epi"]
         barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
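The new epi_visit_acc_subtile hook is where alpha/beta scaling lands: a missing alpha or beta behaves like 1.0, so the default path is still plain D = acc + C. A plain-Python sketch of the same arithmetic on flat lists (not the CuTe fragment API):

    def visit_acc_subtile(acc, c=None, alpha=None, beta=None):
        # Missing alpha/beta behave like 1.0, matching the kernel's defaults.
        d = [x * alpha for x in acc] if alpha is not None else list(acc)
        if c is not None:
            scale = 1.0 if beta is None else beta
            d = [x + scale * y for x, y in zip(d, c)]
        return d

    assert visit_acc_subtile([2.0, 4.0], c=[1.0, 1.0], alpha=0.5, beta=2.0) == [3.0, 4.0]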
@@ -1665,14 +1605,174 @@ class HopperWgmmaGemmKernel:
             number_of_threads=2 * self.num_threads_per_warp_group,
         )

-
+    def epilog_smem_copy_atom(self, tiled_mma: cute.TiledMma) -> cute.TiledCopy:
+        copy_atom_C = cute.make_copy_atom(
+            warp.StMatrix8x8x16bOp(
+                self.d_layout.is_m_major_c() if self.d_layout is not None else False,
+                num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
+            ),
+            cutlass.Float16,  # this is just to get the right source layout
+        )
+        tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
+        return tiled_copy_C_atom
+
+    def epilog_smem_store_and_partition(
+        self,
+        tiled_mma: cute.TiledMma,
+        d_layout: Optional[LayoutEnum],
+        dtype: Type[cutlass.Numeric],
+        acc: cute.Tensor,
+        sD: cute.Tensor,
+        tidx: Int32,
+    ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+        if d_layout is None:
+            d_layout = LayoutEnum.ROW_MAJOR
+        tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
+        # Doesn't work with tile_N % 8 == 0 but tile_n % 16 != since this always
+        # get st.matrix with num_matrices=4
+        copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
+            d_layout, elem_ty_d=dtype, elem_ty_acc=self.acc_dtype
+        )
+        tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
+        # (R2S, R2S_M, R2S_N, PIPE_D)
+        thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
+        tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
+        # (R2S, R2S_M, R2S_N)
+        tRS_rAcc = tiled_copy_r2s.retile(acc)
+        sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
+        tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
+        tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
+        return tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD
+
+    def epilog_smem_load_and_partition(
+        self,
+        tiled_mma: cute.TiledMma,
+        c_layout: LayoutEnum,
+        dtype: Type[cutlass.Numeric],
+        sC: cute.Tensor,
+        tRS_rD_layout: cutlass.Layout,
+        tidx: Int32,
+    ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+        tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
+        copy_atom_s2r = utils.sm90_get_smem_load_op(c_layout, dtype)
+        tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
+        thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
+        tSR_sC = thr_copy_s2r.partition_S(sC)
+        tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
+        tSR_rC = thr_copy_s2r.retile(tRS_rC)
+        return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
+
+    def epilog_gmem_copy_and_partition(
+        self,
+        atom: Union[cute.CopyAtom, cute.TiledCopy],
+        mD_mnl: cute.Tensor,
+        tile_shape_mn: cute.Tile,
+        epi_tile: cute.Tile,
+        sD: cute.Tensor,
+        tile_coord_mnkl: cute.Coord,
+        cu_seqlens_m: Optional[cute.Tensor] = None,
+    ) -> Tuple[cute.Tensor, cute.Tensor]:
+        batch_idx = tile_coord_mnkl[3]
+        if const_expr(cu_seqlens_m is not None):
+            mD_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl)
+        else:
+            mD_mn = mD_mnl[None, None, batch_idx]
+        # (bM, bN)
+        gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
+        tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
+        bSG_sD, bSG_gD = cpasync.tma_partition(
+            atom,
+            0,
+            cute.make_layout(1),
+            cute.group_modes(sD, 0, 2),
+            tDgD_for_tma_partition,
+        )
+        return bSG_sD, bSG_gD
+
+    def make_ab_pipeline(
+        self,
+        a_smem_layout: cute.Layout | cute.ComposedLayout,
+        b_smem_layout: cute.Layout | cute.ComposedLayout,
+        tiled_mma: cute.TiledMma,
+        cluster_layout_vmnk: cute.Layout,
+        ab_pipeline_mbar_ptr: cute.Pointer,
+    ):
+        # Threads/warps participating in this pipeline
+        producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
+        ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
+        # Each warp will contribute to the arrive count with the number of mcast size
+        mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
+        consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
+        ab_pipeline_consumer_group = pipeline.CooperativeGroup(
+            pipeline.Agent.Thread, consumer_arrive_cnt
+        )
+        pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
+        tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+        if const_expr(not self.gather_A):
+            tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
+        return pipeline_cls.create(
+            barrier_storage=ab_pipeline_mbar_ptr,
+            num_stages=self.ab_stage,
+            producer_group=ab_pipeline_producer_group,
+            consumer_group=ab_pipeline_consumer_group,
+            tx_count=tma_copy_bytes,
+            cta_layout_vmnk=cluster_layout_vmnk,
+        )
+
+    def make_epi_pipeline(
+        self, c_smem_layout: cute.Layout | cute.ComposedLayout, epi_pipeline_mbar_ptr: cute.Pointer
+    ):
+        # Threads/warps participating in this pipeline
+        epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+        # Each warp will contribute 1 to the arrive count
+        consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
+        epi_pipeline_consumer_group = pipeline.CooperativeGroup(
+            pipeline.Agent.Thread, consumer_arrive_cnt
+        )
+        tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
+        return pipeline.PipelineTmaAsync.create(
+            barrier_storage=epi_pipeline_mbar_ptr,
+            num_stages=self.epi_c_stage,
+            producer_group=epi_pipeline_producer_group,
+            consumer_group=epi_pipeline_consumer_group,
+            tx_count=tma_copy_c_bytes,
+        )
+
+    def make_sched_pipeline(
+        self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
+    ):
+        # Threads/warps participating in this pipeline
+        sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+        cluster_size = cute.size(cluster_layout_mnk)
+        # Each warp that are not the scheduler warp will contribute 1 to the arrive count
+        # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
+        # at each round. If pingpong and not varlen_k, then only 4 mma warp will participate.
+        consumer_arrive_cnt = (
+            (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
+            + self.num_ab_load_warps
+        ) * cluster_size - 1
+        sched_pipeline_consumer_group = pipeline.CooperativeGroup(
+            pipeline.Agent.Thread, consumer_arrive_cnt
+        )
+        return pipeline.PipelineAsync.create(
+            barrier_storage=sched_pipeline_mbar_ptr,
+            num_stages=self.sched_stage,
+            producer_group=sched_pipeline_producer_group,
+            consumer_group=sched_pipeline_consumer_group,
+            # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
+            consumer_mask=None if const_expr(cluster_size == 1) else 0,
+        )
+
+    @classmethod
     def _compute_stages(
+        cls,
         tile_shape_mnk: Tuple[int, int, int],
-        epi_tile:
+        epi_tile: Tuple[int, int],
         a_dtype: Type[cutlass.Numeric],
         b_dtype: Type[cutlass.Numeric],
-        d_dtype: Type[cutlass.Numeric],
+        d_dtype: Optional[Type[cutlass.Numeric]],
         c_dtype: Optional[Type[cutlass.Numeric]],
+        epilogue_args: Optional[EpilogueArguments],
         smem_capacity: int,
         occupancy: int,
         overlap_sD_sA: bool,
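make_ab_pipeline sets the TMA transaction count to the byte size of one B stage, plus one A stage when A is loaded by TMA (with gather_A the A side goes through cp.async instead). A rough sketch of that byte budget, assuming fp16 operands and illustrative tile sizes:

    def ab_tx_count_bytes(tile_m, tile_n, tile_k, elem_bytes=2, gather_A=False):
        b_bytes = tile_n * tile_k * elem_bytes
        a_bytes = 0 if gather_A else tile_m * tile_k * elem_bytes
        return a_bytes + b_bytes

    assert ab_tx_count_bytes(128, 256, 64) == (128 * 64 + 256 * 64) * 2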
@@ -1695,13 +1795,18 @@ class HopperWgmmaGemmKernel:
        :rtype: Tuple[int, int]
        """

-        epi_stage = 2
+        epi_stage = 4 if epi_tile[1] <= 16 else 2
        if overlap_sD_sA:
            epi_bytes = 0
        else:
-            d_bytes_per_stage =
-
-
+            d_bytes_per_stage = (
+                cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
+            )
+            epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
+                epilogue_args, tile_shape_mnk, epi_tile
+            )
+            epi_bytes = epi_bytes_per_stage * epi_stage
+        epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
        if c_dtype is not None:
            epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage

@@ -1712,23 +1817,21 @@ class HopperWgmmaGemmKernel:
        )
        mbar_helpers_bytes = 1024

-        remaining_bytes =
-            (smem_capacity - occupancy * 1024) // occupancy - mbar_helpers_bytes - epi_bytes
-        )
+        remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
        ab_stage = remaining_bytes // ab_bytes_per_stage

        # Refine epilogue stages:
        # Calculate remaining smem after allocating for A/B stages and reserved bytes
        # Add remaining unused smem to epilogue
-        if not overlap_sD_sA:
-            epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) //
+        if not overlap_sD_sA and epi_bytes_per_stage > 0:
+            epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
        return ab_stage, epi_stage, epi_c_stage

    @staticmethod
    def _sm90_compute_tile_shape_or_override(
        tile_shape_mnk: Tuple[int, int, int],
        atom_layout_mnk: Tuple[int, int, int],
-        element_type: Type[cutlass.Numeric],
+        element_type: Optional[Type[cutlass.Numeric]] = None,
        epi_tile_override: Tuple[int, int] | None = None,
    ) -> Tuple[int, int]:
        """Compute the epilogue tile shape or use override if provided.
@@ -1760,7 +1863,7 @@ class HopperWgmmaGemmKernel:
        # iterate along the N dimension first, then move to the M dimension.
        # We could change the epilogue to accommodate this,
        # but it's easier to just set epi_tile_m = 64.
-        n_perf = 64 if element_type.width == 8 else 32
+        n_perf = 64 if element_type is not None and element_type.width == 8 else 32
        tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
        tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
        return (tile_m, tile_n)
@@ -1770,15 +1873,15 @@ class HopperWgmmaGemmKernel:
        tile_shape_mnk: Tuple[int, int, int],
        epi_tile: Tuple[int, int],
        a_dtype: Type[cutlass.Numeric],
-        a_layout:
+        a_layout: LayoutEnum,
        b_dtype: Type[cutlass.Numeric],
-        b_layout:
+        b_layout: LayoutEnum,
        ab_stage: int,
-        d_dtype: Type[cutlass.Numeric],
-        d_layout:
+        d_dtype: Optional[Type[cutlass.Numeric]],
+        d_layout: LayoutEnum,
        epi_stage: int,
        c_dtype: Optional[Type[cutlass.Numeric]],
-        c_layout: Optional[
+        c_layout: Optional[LayoutEnum],
        epi_c_stage: int,
    ) -> Tuple[
        cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
@@ -1792,17 +1895,17 @@ class HopperWgmmaGemmKernel:
        :param a_dtype: Data type for matrix A
        :type a_dtype: type[cutlass.Numeric]
        :param a_layout: Layout enum for matrix A
-        :type a_layout:
+        :type a_layout: LayoutEnum
        :param b_dtype: Data type for matrix B
        :type b_dtype: type[cutlass.Numeric]
        :param b_layout: Layout enum for matrix B
-        :type b_layout:
+        :type b_layout: LayoutEnum
        :param ab_stage: Number of stages for A/B tensors
        :type ab_stage: int
-        :param d_dtype: Data type for output matrix
+        :param d_dtype: Data type for output matrix D
        :type d_dtype: type[cutlass.Numeric]
        :param d_layout: Layout enum for the output matrix C
-        :type d_layout:
+        :type d_layout: LayoutEnum
        :param epi_stage: Number of epilogue stages
        :type epi_stage: int

@@ -1815,11 +1918,7 @@ class HopperWgmmaGemmKernel:
        b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
        a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
        a_smem_layout_atom = warpgroup.make_smem_layout_atom(
-            sm90_utils.get_smem_layout_atom(
-                a_layout,
-                a_dtype,
-                a_major_mode_size,
-            ),
+            sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
            a_dtype,
        )
        a_smem_layout_staged = cute.tile_to_shape(
@@ -1832,11 +1931,7 @@ class HopperWgmmaGemmKernel:

        b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
        b_smem_layout_atom = warpgroup.make_smem_layout_atom(
-            sm90_utils.get_smem_layout_atom(
-                b_layout,
-                b_dtype,
-                b_major_mode_size,
-            ),
+            sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
            b_dtype,
        )
        b_smem_layout_staged = cute.tile_to_shape(
@@ -1845,17 +1940,20 @@ class HopperWgmmaGemmKernel:
            order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
        )

-
-
-
-
-
-
-
-
-
-
-
+        if d_dtype is not None:
+            d_smem_shape = epi_tile
+            d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
+            d_smem_layout_atom = warpgroup.make_smem_layout_atom(
+                sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
+                d_dtype,
+            )
+            epi_smem_layout_staged = cute.tile_to_shape(
+                d_smem_layout_atom,
+                cute.append(d_smem_shape, epi_stage),
+                order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
+            )
+        else:
+            epi_smem_layout_staged = None

        if c_dtype is not None:
            assert c_layout is not None
@@ -1961,7 +2059,7 @@ class HopperWgmmaGemmKernel:
        thread_layout = cute.make_layout(
            (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
        )
-        if major_mode !=
+        if major_mode != LayoutEnum.ROW_MAJOR:
            shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
            thread_layout = cute.make_layout(
                (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
@@ -1969,7 +2067,7 @@ class HopperWgmmaGemmKernel:
        # Value layout for copy
        value_layout = (
            cute.make_layout((1, copy_elems))
-            if major_mode ==
+            if major_mode == LayoutEnum.ROW_MAJOR
            else cute.make_layout((copy_elems, 1))
        )
        return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
@@ -1979,7 +2077,7 @@ class HopperWgmmaGemmKernel:
        a_dtype: Type[cutlass.Numeric],
        b_dtype: Type[cutlass.Numeric],
        acc_dtype: Type[cutlass.Numeric],
-        d_dtype: Type[cutlass.Numeric],
+        d_dtype: Optional[Type[cutlass.Numeric]],
        a_major: str,
        b_major: str,
    ) -> bool:
@@ -2022,6 +2120,7 @@ class HopperWgmmaGemmKernel:
            is_valid = False
        # tested d_dtype
        if d_dtype not in {
+            None,
            cutlass.Float32,
            cutlass.Float16,
            cutlass.BFloat16,
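_compute_stages now budgets shared memory with a per-stage epilogue byte count that can be zero (no D output): take smem per CTA, subtract the mbarrier helper bytes and the reserved epilogue stages, fill the rest with A/B stages, then fold any leftover back into extra epilogue stages. An illustrative back-of-the-envelope version with assumed sizes (228 KB smem, fp16 128x256x64 A/B tiles, a bf16 64x32 epi tile):

    smem_capacity = 228 * 1024
    occupancy = 1
    mbar_helpers_bytes = 1024
    ab_bytes_per_stage = (128 * 64 + 256 * 64) * 2
    epi_bytes_per_stage = 64 * 32 * 2
    epi_stage = 2
    epi_bytes = epi_bytes_per_stage * epi_stage

    remaining = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
    ab_stage = remaining // ab_bytes_per_stage
    if epi_bytes_per_stage > 0:
        epi_stage += (remaining - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
    print(ab_stage, epi_stage)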
@@ -2039,436 +2138,108 @@ class HopperWgmmaGemmKernel:
        # for Float8 types, this implementation only supports k-major layout
        if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
            is_valid = False
-
        return is_valid


-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-)
-
-
-
-
-
-
-
-
-
-
-
-    :param acc_dtype: Data type for accumulation during matrix multiplication
-    :type acc_dtype: Type[cutlass.Numeric]
-    :param a_major/b_major/d_major: Memory layout of tensor A/B/C
-    :type a_major/b_major/d_major: str
-    :param tile_shape_mnk: CTA tile shape (M, N, K)
-    :type tile_shape_mnk: Tuple[int, int, int]
-    :param cluster_shape_mn: Cluster shape (M, N)
-    :type cluster_shape_mn: Tuple[int, int]
-    :param tolerance: Tolerance value for reference validation comparison
-    :type tolerance: float
-    :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
-    :type warmup_iterations: int, optional
-    :param iterations: Number of benchmark iterations to run, defaults to 1
-    :type iterations: int, optional
-    :param skip_ref_check: Whether to skip reference result validation, defaults to False
-    :type skip_ref_check: bool, optional
-    """
-
-    if dynamic_persistent:
-        persistent = True
-
-    print("Running Hopper Dense GEMM with:")
-    print(f"mnkl: {mnkl}")
-    print(
-        f"A dtype: {a_dtype}, B dtype: {b_dtype}, D dtype: {d_dtype}, C_dtype: {c_dtype}, Acc dtype: {acc_dtype}"
-    )
-    print(f"Matrix majors - A: {a_major}, B: {b_major}, D: {d_major}")
-    print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
-    print(f"Tolerance: {tolerance}")
-    print(f"Warmup iterations: {warmup_iterations}")
-    print(f"Iterations: {iterations}")
-    print(f"Skip reference checking: {skip_ref_check}")
-
-    # Unpack parameters
-    m, n, k, l = mnkl
-    cluster_shape_mnk = (*cluster_shape_mn, 1)
-
-    # Skip unsupported types
-    if not HopperWgmmaGemmKernel.is_valid_dtypes(
-        a_dtype, b_dtype, acc_dtype, d_dtype, a_major, b_major
+def gemm_sm90(
+    A: Tensor,  # (l, m, k)
+    B: Tensor,  # (l, n, k)
+    D: Tensor,  # (l, m, n)
+    C: Optional[Tensor],  # (l, m, n)
+    tile_count_semaphore: Optional[Tensor],  # (1,)
+    tile_M: int,
+    tile_N: int,
+    cluster_M: int,
+    cluster_N: int,
+    pingpong: bool = False,
+    persistent: bool = True,
+    alpha: float | Tensor = 1.0,
+    beta: float | Tensor = 1.0,
+) -> None:
+    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(A, B, D, C)
+    GemmWrapperBase.permute_tensors(tensor_infos)
+    GemmWrapperBase.extract_dtypes(tensor_infos)
+    major_configs = {
+        "A": ("m", "k", "l"),
+        "B": ("n", "k", "l"),
+        "D": ("m", "n", "l"),
+        "C": ("m", "n", "l"),
+    }
+    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+    acc_dtype = cutlass.Float32
+    tile_shape_mn = (tile_M, tile_N)
+    cluster_shape_mnk = (cluster_M, cluster_N, 1)
+    if not GemmSm90.is_valid_dtypes(
+        tensor_infos["A"].dtype,
+        tensor_infos["B"].dtype,
+        acc_dtype,
+        tensor_infos["D"].dtype,
+        tensor_infos["A"].major,
+        tensor_infos["B"].major,
    ):
-        raise TypeError(
-            f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {d_dtype}, {a_major=}, {b_major=}"
-        )
+        raise TypeError("Skipping due to unsupported combination of types and majors")

-
-
-        raise RuntimeError("GPU is required to run this example!")
-
-    torch.manual_seed(1111)
-
-    # Create and permute tensor A/B/C
-    def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
-        # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
-        # else          : (l, mode0, mode1) -> (mode0, mode1, l)
-        shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
-        permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
-        is_unsigned = dtype in {cutlass.Uint8}
-        # Temporarily use uint8 as torch does not support fp8 type
-        torch_dtype = cutlass_torch.dtype(dtype)
-        gen_dtype = (
-            torch_dtype
-            if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
-            else torch.bfloat16
-        )
-
-        # Create dtype torch tensor (cpu)
-        torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
-            shape,
-            gen_dtype,
-            permute_order=permute_order,
-            # init_type=cutlass.torch.TensorInitType.RANDOM,
-            # init_config=cutlass.torch.RandomInitConfig(
-            #     min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
-            # ),
-            init_type=cutlass.torch.TensorInitType.GAUSSIAN,
-            init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
-        ).to(torch_dtype)
-        # Create dtype torch tensor (gpu)
-        torch_tensor = torch_tensor_cpu.cuda()
-
-        # Create f32 torch tensor (cpu)
-        f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
-
-        # Create dtype cute tensor (gpu)
-        torch_tensor_view = (
-            torch_tensor
-            if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
-            else torch_tensor.view(torch.uint8)
-        )
-        cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
-        cute_tensor.element_type = dtype
-        if is_dynamic_layout:
-            cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
-            cute_tensor = cute_tensor.mark_compact_shape_dynamic(
-                mode=(1 if not is_mode0_major else 0),
-                stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
-                divisibility=(128 // dtype.width),
-            )
-        cute_tensor = cutlass.torch.convert_cute_tensor(
-            f32_torch_tensor,
-            cute_tensor,
-            dtype,
-            is_dynamic_layout=is_dynamic_layout,
-        )
+    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)

-
-
-
-    if gather_A:
-        assert a_major == "k"
-        a_idx = torch.randperm(l * m, dtype=torch.int32, device="cuda")
-        from einops import rearrange
-
-        a = rearrange(rearrange(a, "m k l -> (m l) k")[a_idx.cpu()], "(m l) k -> m k l", m=m)
-        a_torch = rearrange(a_torch, "m k l -> (m l) k")
-        mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
-        a_idx_reshaped = rearrange(a_idx, "(m l) -> l m", m=m).contiguous().transpose(0, 1)
-        mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
-    else:
-        mAIdx = None
-    b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
-    _, mD, d_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
-    if c_dtype is not None:
-        c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
-    else:
-        c, mC, c_torch = None, None, None
-    if varlen_m:
-        assert a_major == "k"
-        assert d_major == "n"
-        from einops import rearrange
-
-        a, d_torch = [rearrange(t, "m x l -> (l m) x") for t in (a, d_torch)]
-        if not gather_A:
-            (a_torch,) = [rearrange(t, "m x l -> (l m) x") for t in (a_torch,)]
-        if c_dtype is not None:
-            c, c_torch = [rearrange(t, "m x l -> (l m) x") for t in (c, c_torch)]
-            mC = from_dlpack(c_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
-        mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
-        mD = from_dlpack(d_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
-        # TODO: generate random cu_seqlens_m
-        cu_seqlens_m = torch.arange(0, l + 1, dtype=torch.int32, device="cuda") * m
-        mCuSeqlensM = from_dlpack(cu_seqlens_m, assumed_align=64).mark_layout_dynamic(leading_dim=0)
-        if gather_A:
-            a_idx_reshaped = rearrange(a_idx_reshaped, "m l -> (l m)")
-            mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
-    else:
-        cu_seqlens_m, mCuSeqlensM = None, None
-
-    if varlen_m:  # Need to allocate space in gmem to store tensormaps
-        if not persistent:
-            total_m = m * l
-            block_size_m = tile_shape_mnk[0] * cluster_shape_mnk[0]
-            block_size_n = tile_shape_mnk[1] * cluster_shape_mnk[1]
-            total_clusters_m_max = (total_m + l * (block_size_m - 1)) // block_size_m
-            total_clusters_max = total_clusters_m_max * ((n + block_size_n - 1) // block_size_n)
-            total_ctas = total_clusters_max * cluster_shape_mnk[0] * cluster_shape_mnk[1]
+    def scalar_arg(scalar: float | Tensor):
+        if isinstance(scalar, float):
+            return Float32(scalar) if scalar != 1.0 else None
        else:
-
-
-
-
-
-
-
-
-
-
-
-
-            acc_dtype,
-            a_dtype,
-            tile_shape_mnk,
+            assert isinstance(scalar, Tensor)
+            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
+
+    epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta))
+    scheduler_args = GemmWrapperBase.create_scheduler_args(
+        max_active_clusters, tile_count_semaphore
+    )
+    current_stream = cutlass_torch.current_stream()
+    compile_key = GemmWrapperBase.get_compile_key(
+        tensor_infos,
+        None,
+        tile_shape_mn,
        cluster_shape_mnk,
-        pingpong
-
-
-
+        pingpong,
+        persistent,
+        tile_count_semaphore is not None,
+        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
+        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
+        key_tensor_names=("A", "B", "D", "C"),
    )
-
-
-
-
-
+    cache = gemm_sm90.compile_cache
+    if compile_key not in cache:
+        gemm = GemmSm90(
+            acc_dtype,
+            tensor_infos["A"].dtype,
+            tile_shape_mn,
+            cluster_shape_mnk,
+            pingpong=pingpong,
+            is_persistent=persistent,
        )
-
-
-
-
-
-
-
-
-
-
-        # compile gemm kernel
-        compiled_gemm = cute.compile(
-            gemm,
-            mA,
-            mB,
-            mD,
-            mC,
-            mAIdx,
-            mCuSeqlensM,
-            tensormaps_tensor,
-            make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
-            if tile_count_semaphore is not None
-            else None,
-            max_active_clusters,
-            current_stream,
-        )
-
-    if not skip_ref_check:
-        # execution
-        compiled_gemm(
-            mA,
-            mB,
-            mD,
-            mC,
-            mAIdx,
-            mCuSeqlensM,
-            tensormaps_tensor,
-            tile_count_semaphore,
-            max_active_clusters,
+        cache[compile_key] = cute.compile(
+            gemm,
+            tensor_infos["A"].cute_tensor,
+            tensor_infos["B"].cute_tensor,
+            tensor_infos["D"].cute_tensor,
+            tensor_infos["C"].cute_tensor,
+            epi_args,
+            scheduler_args,
+            None,  # varlen_args
+            None,  # mAIdx
            current_stream,
        )
-
-
-
-
-
-
-
-
-
-
-
-                torch.einsum("mk,nk->mn", a[cu_seqlens_m[i] : cu_seqlens_m[i + 1]], b[:, :, i])
-                for i in range(l)
-            ],
-            dim=0,
-        )
-        if c is not None:
-            ref = ref + c
-        ref = ref.cpu()
-
-        if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
-            # m major: (l, n, m) -> (m, n, l)
-            # n major: (l, m, n) -> (m, n, l)
-            permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
-            shape = (l, m, n) if d_major == "n" else (l, n, m)
-            f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
-                shape,
-                torch.uint8,
-                permute_order=permute_order,
-                init_type=cutlass_torch.TensorInitType.SKIP,
-            ).cuda()
-            # Create dtype cute tensor (gpu)
-            ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
-                leading_dim=(1 if d_major == "n" else 0)
-            )
-            ref_d_tensor.element_type = d_dtype
-            ref_d_tensor = cutlass_torch.convert_cute_tensor(
-                ref,
-                ref_d_tensor,
-                d_dtype,
-                is_dynamic_layout=True,
-            )
-            ref_d = f8_torch_tensor.cpu()
-        else:
-            ref_d = ref.to(cutlass_torch.dtype(d_dtype))
-
-        out = d_torch.cpu().squeeze()
-        out_ref = ref_d.squeeze()
-        # breakpoint()
-        torch.testing.assert_close(d_torch.cpu(), ref_d, atol=tolerance, rtol=1e-03)
-
-        # return
-
-    from triton.testing import do_bench
-
-    flops = 2 * m * n * k * l
-    # Calculate memory bandwidth
-    bytes_A = m * k * l * (a_dtype.width // 8)  # A tensor: (m, k, l)
-    bytes_B = n * k * l * (b_dtype.width // 8)  # B tensor: (n, k, l)
-    bytes_D = m * n * l * (d_dtype.width // 8)  # D tensor: (m, n, l)
|
|
2368
|
-
bytes_C = m * n * l * (c_dtype.width // 8) if c_dtype is not None else 0 # C tensor: (m, n, l)
|
|
2369
|
-
total_bytes = bytes_A + bytes_B + bytes_D + bytes_C # Read A, B, C; Write D
|
|
2370
|
-
|
|
2371
|
-
repeats = iterations
|
|
2372
|
-
warmup = warmup_iterations
|
|
2373
|
-
|
|
2374
|
-
import time
|
|
2375
|
-
|
|
2376
|
-
if not varlen_m and not gather_A:
|
|
2377
|
-
time.sleep(0.5)
|
|
2378
|
-
if a_dtype.width == 8:
|
|
2379
|
-
assert l == 1
|
|
2380
|
-
scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda")
|
|
2381
|
-
fn_cublas = lambda: torch._scaled_mm(
|
|
2382
|
-
a_torch[:, :, 0],
|
|
2383
|
-
b_torch[:, :, 0].mT,
|
|
2384
|
-
scale_a=scale_ab,
|
|
2385
|
-
scale_b=scale_ab,
|
|
2386
|
-
out_dtype=torch.bfloat16,
|
|
2387
|
-
use_fast_accum=fp8_fast_accum,
|
|
2388
|
-
)
|
|
2389
|
-
else:
|
|
2390
|
-
if c_torch is None:
|
|
2391
|
-
fn_cublas = lambda: torch.matmul(
|
|
2392
|
-
a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT
|
|
2393
|
-
)
|
|
2394
|
-
else:
|
|
2395
|
-
c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32
|
|
2396
|
-
fn_cublas = lambda: torch.baddbmm(
|
|
2397
|
-
c_torch_convert.permute(2, 0, 1),
|
|
2398
|
-
a_torch.permute(2, 0, 1),
|
|
2399
|
-
b_torch.permute(2, 0, 1).mT,
|
|
2400
|
-
)
|
|
2401
|
-
timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
|
|
2402
|
-
tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
|
|
2403
|
-
print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
|
|
2232
|
+
cache[compile_key](
|
|
2233
|
+
tensor_infos["A"].cute_tensor,
|
|
2234
|
+
tensor_infos["B"].cute_tensor,
|
|
2235
|
+
tensor_infos["D"].cute_tensor,
|
|
2236
|
+
tensor_infos["C"].cute_tensor,
|
|
2237
|
+
epi_args,
|
|
2238
|
+
scheduler_args,
|
|
2239
|
+
None,
|
|
2240
|
+
None,
|
|
2241
|
+
current_stream,
|
|
2242
|
+
)
|
|
2404
2243
|
|
|
2405
|
-
time.sleep(0.5)
|
|
2406
2244
|
|
|
2407
|
-
|
|
2408
|
-
compiled_gemm(
|
|
2409
|
-
mA,
|
|
2410
|
-
mB,
|
|
2411
|
-
mD,
|
|
2412
|
-
mC,
|
|
2413
|
-
mAIdx,
|
|
2414
|
-
mCuSeqlensM,
|
|
2415
|
-
tensormaps_tensor,
|
|
2416
|
-
tile_count_semaphore,
|
|
2417
|
-
max_active_clusters,
|
|
2418
|
-
current_stream,
|
|
2419
|
-
)
|
|
2420
|
-
if tile_count_semaphore is not None and varlen_m:
|
|
2421
|
-
tile_count_semaphore.zero_()
|
|
2422
|
-
|
|
2423
|
-
timing = do_bench(fn, warmup=warmup, rep=repeats)
|
|
2424
|
-
# Idk why but for some cases the 1st run is much slower
|
|
2425
|
-
time.sleep(0.5)
|
|
2426
|
-
timing = do_bench(fn, warmup=warmup, rep=repeats)
|
|
2427
|
-
tflops = flops / (timing * 1e9) # Convert to TFlops
|
|
2428
|
-
gbps = total_bytes / (timing * 1e6) # Convert to GB/s (1e9 for ms->s, 1e9 for B->GB)
|
|
2429
|
-
print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}, GB/s: {gbps:.0f}")
|
|
2430
|
-
fn()
|
|
2431
|
-
|
|
2432
|
-
if not varlen_m:
|
|
2433
|
-
time.sleep(0.5)
|
|
2434
|
-
timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
|
|
2435
|
-
tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
|
|
2436
|
-
print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
|
|
2437
|
-
|
|
2438
|
-
from flash_attn.utils.benchmark import pytorch_profiler
|
|
2439
|
-
|
|
2440
|
-
pytorch_profiler(fn_cublas)
|
|
2441
|
-
# pytorch_profiler(torch.sort, d_torch.squeeze(), dim=-1)
|
|
2442
|
-
# pytorch_profiler(torch.compile(torch.sort), d_torch.squeeze(), dim=-1)
|
|
2443
|
-
# pytorch_profiler(torch.topk, d_torch.squeeze(), dim=-1, k=1)
|
|
2444
|
-
# pytorch_profiler(torch.compile(torch.topk), d_torch.squeeze(), dim=-1, k=1)
|
|
2445
|
-
# pytorch_profiler(torch.square, d_torch.squeeze())
|
|
2446
|
-
|
|
2447
|
-
|
|
2448
|
-
if __name__ == "__main__":
|
|
2449
|
-
args = parse_arguments()
|
|
2450
|
-
run(
|
|
2451
|
-
args.mnkl,
|
|
2452
|
-
args.a_dtype,
|
|
2453
|
-
args.b_dtype,
|
|
2454
|
-
args.d_dtype,
|
|
2455
|
-
args.c_dtype,
|
|
2456
|
-
args.acc_dtype,
|
|
2457
|
-
args.a_major,
|
|
2458
|
-
args.b_major,
|
|
2459
|
-
args.d_major,
|
|
2460
|
-
args.c_major,
|
|
2461
|
-
args.tile_shape_mnk,
|
|
2462
|
-
args.cluster_shape_mn,
|
|
2463
|
-
args.tolerance,
|
|
2464
|
-
args.warmup_iterations,
|
|
2465
|
-
args.iterations,
|
|
2466
|
-
args.skip_ref_check,
|
|
2467
|
-
args.persistent,
|
|
2468
|
-
args.dynamic_persistent,
|
|
2469
|
-
args.pingpong,
|
|
2470
|
-
args.varlen_m,
|
|
2471
|
-
args.gather_A,
|
|
2472
|
-
args.fp8_fast_accum,
|
|
2473
|
-
)
|
|
2474
|
-
print("PASS")
|
|
2245
|
+
gemm_sm90.compile_cache = {}
|
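
The trailing `gemm_sm90.compile_cache = {}` attaches the cache to the wrapper function object itself, so compiled kernels persist for the lifetime of the process and are looked up by `compile_key` on every call. A minimal sketch of that memoized-compile pattern, with illustrative names (`run_kernel`, `shape_key`, `flags_key`) rather than quack's actual signature:

```python
# A dict stored as a function attribute acts as a process-level compile cache:
# build the specialized callable on the first call for a given key, reuse it after.
def run_kernel(shape_key: tuple, flags_key: tuple, *args):
    compile_key = (shape_key, flags_key)
    cache = run_kernel.compile_cache
    if compile_key not in cache:
        # Stand-in for constructing GemmSm90(...) and calling cute.compile(...);
        # this branch runs only on a cache miss.
        cache[compile_key] = lambda *call_args: ("compiled for", compile_key, call_args)
    return cache[compile_key](*args)


run_kernel.compile_cache = {}

print(run_kernel(("bf16", (128, 192)), ("pingpong", True), "A0", "B0"))
print(run_kernel(("bf16", (128, 192)), ("pingpong", True), "A1", "B1"))
print(len(run_kernel.compile_cache))  # 1 -- the second call reused the cached entry
```

On a hit, the wrapper skips both `GemmSm90` construction and `cute.compile` and dispatches directly to the cached callable, which keeps repeated launches with the same configuration cheap.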