quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/dense_gemm_sm90.py
CHANGED
@@ -1,47 +1,43 @@
-#
-#
-
-
-
-
-
-# list of conditions and the following disclaimer.
-
-# 2. Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-
-# 3. Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import argparse
-from typing import Tuple, Type
+# Based on the cute-dsl example:
+# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
+
+import enum
+from typing import Tuple, Type, Callable, Optional, Union
+from dataclasses import dataclass
+from functools import partial
 import math
-import cuda.bindings.driver as cuda
 
-import
+from torch import Tensor
+
+import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
-import cutlass.cute.testing as testing
-import cutlass.utils as utils
 import cutlass.pipeline as pipeline
-import cutlass.torch as cutlass_torch
-from cutlass.cute.runtime import from_dlpack
 from cutlass.cute.nvgpu import cpasync, warp, warpgroup
 import cutlass.utils.hopper_helpers as sm90_utils
+from cutlass import Int32, Float32, Boolean, const_expr
+from cutlass.utils import LayoutEnum
+import cutlass.torch as cutlass_torch
+from cutlass.cute.runtime import make_ptr
+
+
+from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
+from quack.tile_scheduler import (
+    TileSchedulerOptions,
+    TileSchedulerArguments,
+    TileScheduler,
+    VarlenMTileSchedulerArguments,
+    VarlenMTileScheduler,
+)
+from quack.varlen_utils import VarlenArguments
+from quack.tensormap_manager import TensorMapManagerSm90
+
+# return PipelineStateWAdvance instead of PipelineState
+from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
+import quack.utils as utils
+from quack.cute_dsl_utils import get_max_active_clusters
+from quack.gemm_wrapper_utils import GemmWrapperBase
 
 """
 A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -66,31 +62,6 @@ Hopper WGMMA instructions operate as follows:
     - Read matrix B from SMEM
     - Perform MMA operation and store the result in Accumulator(register)
 
-To run this example:
-
-.. code-block:: bash
-
-    python examples/hopper/dense_gemm.py \
-        --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
-        --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
-        --d_dtype Float16 --acc_dtype Float32 \
-        --a_major k --b_major k --d_major n
-
-The above example command compute batched gemm with M=8192, N=8192, K=8192,
-batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape
-is (1,1). The input, mma accumulator and output data type are set as fp16, fp32
-and fp16, respectively.
-
-To collect performance with NCU profiler:
-
-.. code-block:: bash
-
-    ncu python examples/hopper/dense_gemm.py \
-        --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
-        --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
-        --d_dtype Float16 --acc_dtype Float32 \
-        --a_major k --b_major k --d_major n
-
 Constraints:
 * Supported input data types: fp16, fp8 (e4m3fn, e5m2)
 * For fp16 types, A and B must have the same data type
@@ -103,107 +74,29 @@ Constraints:
 * Cluster shape M/N must be positive and power of 2, total cluster size <= 4
 * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
   i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively.
-* OOB tiles are not allowed when TMA store is disabled
 """
 
 
-
-#
-#
-
-
-
-
-
-
-
-def parse_arguments() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.")
-
-    parser.add_argument(
-        "--mnkl",
-        type=parse_comma_separated_ints,
-        default=(4096, 4096, 4096, 1),
-        help="mnkl dimensions (comma-separated)",
-    )
-    parser.add_argument(
-        "--tile_shape_mnk",
-        type=parse_comma_separated_ints,
-        default=(128, 256, 64),
-        help="Cta tile shape (comma-separated)",
-    )
-    parser.add_argument(
-        "--cluster_shape_mn",
-        type=parse_comma_separated_ints,
-        choices=[(1, 1), (2, 1), (1, 2), (2, 2)],
-        default=(1, 1),
-        help="Cluster shape (comma-separated)",
-    )
-    parser.add_argument(
-        "--a_dtype",
-        type=cutlass.dtype,
-        default=cutlass.BFloat16,
-    )
-    parser.add_argument(
-        "--b_dtype",
-        type=cutlass.dtype,
-        default=cutlass.BFloat16,
-    )
-    parser.add_argument(
-        "--d_dtype",
-        type=cutlass.dtype,
-        default=cutlass.BFloat16,
-    )
-    parser.add_argument(
-        "--acc_dtype",
-        type=cutlass.dtype,
-        default=cutlass.Float32,
-    )
-    parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
-    parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
-    parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
-    parser.add_argument("--tolerance", type=float, default=1e-01, help="Tolerance for validation")
-    parser.add_argument("--warmup_iterations", type=int, default=0, help="Warmup iterations")
-    parser.add_argument(
-        "--iterations",
-        type=int,
-        default=1,
-        help="Number of iterations to run the kernel",
-    )
-    parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
-    parser.add_argument(
-        "--use_cold_l2",
-        action="store_true",
-        default=False,
-        help="Use circular buffer tensor sets to ensure L2 cold cache",
-    )
-
-    args = parser.parse_args()
-
-    if len(args.mnkl) != 4:
-        parser.error("--mnkl must contain exactly 4 values")
-    if len(args.tile_shape_mnk) != 3:
-        parser.error("--tile_shape_mnk must contain exactly 3 values")
-    if len(args.cluster_shape_mn) != 2:
-        parser.error("--cluster_shape_mn must contain exactly 2 values")
-
-    return args
-
-
-# /////////////////////////////////////////////////////////////////////////////
-# Host setup and device kernel launch
-# /////////////////////////////////////////////////////////////////////////////
+class NamedBarrierGemm(enum.IntEnum):
+    Epilogue = enum.auto()  # starts from 1 as barrier 0 is reserved for sync_threads()
+    # For mainloop load warps to signal that the epilogue load warp can start.
+    # This is to avoid loading C too early, interfering with loading A and B.
+    EpilogueLoad = enum.auto()
+    MmaWG0 = enum.auto()
+    MmaWG1 = enum.auto()
+    EpiWG0 = enum.auto()
+    EpiWG1 = enum.auto()
 
 
-class
+class GemmSm90:
     """
     This class implements batched matrix multiplication (C = A x B) with support for various data types
    and architectural features specific to Hopper GPUs.
 
    :param acc_dtype: Data type for accumulation during computation
    :type acc_dtype: type[cutlass.Numeric]
-    :param
-    :type
+    :param tile_shape_mn: Shape of the CTA tile (M,N)
+    :type tile_shape_mn: Tuple[int, int, int]
    :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
    :type cluster_shape_mnk: Tuple[int, int, int]
 
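The 16-byte alignment constraint quoted above is straightforward to verify on the host before dispatching. The following is an illustrative sketch of my own (not code from the package), assuming PyTorch input tensors:

    import torch

    def contiguous_dim_ok(t: torch.Tensor) -> bool:
        # The innermost dimension must be contiguous and cover a multiple of
        # 16 bytes: 8 elements for 2-byte types (fp16/bf16), 16 for 1-byte fp8.
        elems_per_16_bytes = 16 // t.element_size()
        return t.stride(-1) == 1 and t.shape[-1] % elems_per_16_bytes == 0

    a = torch.randn(4096, 4096, dtype=torch.bfloat16)
    assert contiguous_dim_ok(a)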
@@ -221,25 +114,39 @@ class HopperWgmmaGemmKernel:
         - Float32 (for all floating point inputs)
 
     :note: Constraints:
-        - CTA tile M must be 64/128
-        - CTA tile N must be 64/128/256
-        - CTA tile K must be 64
         - Cluster shape M/N must be positive and power of 2, total cluster size <= 4
 
     Example:
-        >>> gemm =
+        >>> gemm = GemmSm90(
        ...     acc_dtype=cutlass.Float32,
-        ...
+        ...     tile_shape_mn=(128, 256),
        ...     cluster_shape_mnk=(1, 1, 1)
        ... )
        >>> gemm(a_tensor, b_tensor, c_tensor, stream)
    """
 
+    bytes_per_tensormap = 128
+
+    @dataclass
+    class EpilogueArguments(ArgumentsBase):
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+
+    @dataclass
+    class EpilogueParams(ParamsBase):
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+
    def __init__(
        self,
        acc_dtype: Type[cutlass.Numeric],
-
+        a_dtype: Type[cutlass.Numeric],
+        tile_shape_mn: Tuple[int, int],
        cluster_shape_mnk: Tuple[int, int, int],
+        pingpong: bool = False,
+        is_persistent: bool = True,
+        fp8_fast_accum: bool = False,
+        gather_A: bool = False,
    ):
        """
        Initializes the configuration for a Hopper dense GEMM kernel.
@@ -249,59 +156,106 @@ class HopperWgmmaGemmKernel:
 
        :param acc_dtype: Data type for accumulation during computation
        :type acc_dtype: type[cutlass.Numeric]
-        :param
-        :type
+        :param tile_shape_mn: Shape of the CTA tile (M,N)
+        :type tile_shape_mn: Tuple[int, int]
        :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
        :type cluster_shape_mnk: Tuple[int, int, int]
        """
 
        self.acc_dtype = acc_dtype
+        self.pingpong = pingpong
+        self.is_persistent = is_persistent
+        if self.pingpong:
+            assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
+        self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
+        self.gather_A = gather_A
+        if gather_A:
+            assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
 
        self.cluster_shape_mnk = cluster_shape_mnk
-
-
+        # K dimension is deferred in _setup_attributes
+        self.tile_shape_mnk = (*tile_shape_mn, 1)
+        tile_M, tile_N = self.tile_shape_mnk[0], self.tile_shape_mnk[1]
        # check the cta tile shape
-
-
-
-
-
-
-
-
-
+        if not self.pingpong:
+            if tile_M not in [64, 128, 192, 256, 320]:
+                raise ValueError("CTA tile shape M must be 64/128/192/256/320")
+            if tile_M in [192, 320]:  # special case
+                tile_N_max = 256 if tile_M == 192 else 160
+                if not (tile_N % 32 == 0 and tile_N <= tile_N_max):
+                    raise ValueError(
+                        f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}"
+                    )
+            else:
+                if not (
+                    (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
+                ):
+                    raise ValueError(
+                        "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
+                    )
        else:
-            if not
-                raise ValueError(
-
-
-
-
-
-
+            if tile_M not in [64, 128, 192]:
+                raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
+            tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
+            if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
+                raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
+
+        if not self.pingpong:
+            if tile_M == 320:  # tile_M / 64 is not even so we have to split along N
+                atom_layout_m, atom_layout_n = 1, 2
+            elif tile_M == 192:
+                if tile_N <= 128:
+                    atom_layout_m, atom_layout_n = 3, 1
+                else:
+                    atom_layout_m, atom_layout_n = 1, 2
+            else:
+                atom_layout_m = self.tile_shape_mnk[0] // 64 if self.tile_shape_mnk[0] < 256 else 2
+                atom_layout_n = 1
+            assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
        else:
-            atom_layout_m =
-            atom_layout_n = 1
-            assert atom_layout_m in [1, 2] and atom_layout_n in [1, 2]
+            atom_layout_m, atom_layout_n = 1, 1
        self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
 
-        self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
+        self.num_mcast_ctas_a = self.cluster_shape_mnk[1] if not self.gather_A else 1
        self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
        self.is_a_mcast = self.num_mcast_ctas_a > 1
        self.is_b_mcast = self.num_mcast_ctas_b > 1
 
        self.occupancy = 1
-        self.mma_warp_groups = math.prod(self.atom_layout_mnk)
+        self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
+        if self.pingpong:
+            assert self.mma_warp_groups == 2
+        assert self.mma_warp_groups in [1, 2, 3]
        self.num_threads_per_warp_group = 128
        self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
-        self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
-        self.
-
-
-
-        self.
-        self.
+        self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
+        self.num_epi_threads = (
+            self.mma_warp_groups if not self.pingpong else 1
+        ) * self.num_threads_per_warp_group
+        self.num_ab_load_warps = 1 if not self.gather_A else 4
+        self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
+        self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
+        self.ab_load_warp_id = self.mma_warp_groups * 4
+        self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
+
+        regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // (
+            math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
+        )
+        if self.fp8_slow_accum:
+            regs_per_thread *= 2
+        if not self.gather_A:
+            if self.mma_warp_groups == 3:
+                self.num_regs_load, self.num_regs_mma = 32, 160
+            else:
+                heavy_register_pressure = regs_per_thread >= 208
+                self.num_regs_load, self.num_regs_mma = (
+                    (40, 232) if not heavy_register_pressure else (24, 240)
+                )
+        else:
+            if self.mma_warp_groups == 3:
+                self.num_regs_load, self.num_regs_mma = 56, 152
+            else:
+                self.num_regs_load, self.num_regs_mma = (56, 224)
 
        self.ab_stage = None
        self.epi_stage = None
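For reference, here is an illustrative sketch (not code from the package) of driving the reworked constructor. Argument names follow the new __init__ shown in this hunk; the tensor/stream plumbing added in quack.gemm_interface and quack.gemm_wrapper_utils is not shown.

    import cutlass
    from quack.dense_gemm_sm90 import GemmSm90

    gemm = GemmSm90(
        acc_dtype=cutlass.Float32,
        a_dtype=cutlass.BFloat16,
        tile_shape_mn=(128, 256),
        cluster_shape_mnk=(1, 1, 1),
        pingpong=False,        # cooperative schedule: both MMA warp groups work on one tile
        is_persistent=True,    # persistent tile scheduler
        fp8_fast_accum=False,  # only relevant for 8-bit inputs
        gather_A=False,        # no row-gather on A
    )
    # Register budget per the constructor above: a 128x256 tile split across
    # 2 MMA warp groups of 128 threads needs 128 * 256 / (2 * 128) = 128
    # accumulator registers per thread, under the 208-register threshold,
    # so (num_regs_load, num_regs_mma) = (40, 232).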
@@ -314,7 +268,7 @@ class HopperWgmmaGemmKernel:
        self.shared_storage = None
        self.buffer_align_bytes = 1024
 
-    def _setup_attributes(self):
+    def _setup_attributes(self, epilogue_args: Optional[EpilogueArguments]):
        """Set up configurations that are dependent on GEMM inputs
 
        This method configures various attributes based on the input tensor properties
@@ -328,26 +282,67 @@ class HopperWgmmaGemmKernel:
        - Computing A/B/C shared memory layout
        """
 
-        self.
+        self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
+            self.a_dtype,
+            self.b_dtype,
+            self.a_layout.sm90_mma_major_mode(),
+            self.b_layout.sm90_mma_major_mode(),
+            self.acc_dtype,
+            self.atom_layout_mnk,
+            tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
+        )
+        if const_expr(self.atom_layout_mnk[1] > 1):
+            # If N dimension is split among 2 WGs, we need to permute the N dimension so
+            # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
+            # containing accumulators that are next to each other in the N dimension.
+            # Without permutation WG0 would write to epi smem of size (64, 16) and
+            # WG1 would write to a separate epi smem of size (64, 16) that's far away.
+            atom_n = self.atom_layout_mnk[1]
+            permutation_n = cute.make_ordered_layout(
+                (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
+            )
+            self.tiled_mma = cute.make_tiled_mma(
+                cute.make_mma_atom(self.tiled_mma.op),
+                self.atom_layout_mnk,
+                permutation_mnk=(None, permutation_n, None),
+            )
+        mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
+        mma_inst_tile_k = 4
+        self.tile_shape_mnk = (
+            self.tile_shape_mnk[0],
+            self.tile_shape_mnk[1],
+            mma_inst_shape_k * mma_inst_tile_k,
+        )
+
+        self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
 
-        is_cooperative = math.prod(self.atom_layout_mnk) > 1
        self.epi_tile = self._sm90_compute_tile_shape_or_override(
-            self.tile_shape_mnk,
+            self.tile_shape_mnk,
+            self.atom_layout_mnk,
+            self.d_dtype,
        )
 
        # Compute stage before compute smem layout
-        self.ab_stage, self.epi_stage = self._compute_stages(
+        self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
            self.tile_shape_mnk,
+            self.epi_tile,
            self.a_dtype,
            self.b_dtype,
+            self.d_dtype,
+            self.c_dtype,
+            epilogue_args,
            self.smem_capacity,
            self.occupancy,
+            # epi_smem will reuse smem ab if not persistent.
+            overlap_sD_sA=not self.is_persistent,
        )
+        self.sched_stage = 2 if self.pingpong else 1
 
        (
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
+            self.epi_c_smem_layout_staged,
        ) = self._make_smem_layouts(
            self.tile_shape_mnk,
            self.epi_tile,
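To see what the N-permutation above does, here is a small stand-alone sketch of my own (a hypothetical helper, not the cute API) that mimics the stride assignment of an ordered layout for the shape used in this hunk:

    # Illustrative only: mirrors how an ordered layout assigns strides,
    # with the mode whose order value is smallest varying fastest.
    def ordered_strides(shape, order):
        strides = [0] * len(shape)
        stride = 1
        for mode in sorted(range(len(shape)), key=lambda m: order[m]):
            strides[mode] = stride
            stride *= shape[mode]
        return tuple(strides)

    tile_N, atom_n = 256, 2
    shape = (8, tile_N // atom_n // 8, atom_n)                  # (8, 16, 2)
    print(shape, ordered_strides(shape, order=(0, 2, 1)))       # (8, 16, 2) (1, 16, 8)

With atom_n = 2, the two warp groups' 8-column chunks are interleaved along N, so their accumulators land next to each other in the epilogue smem tile, as the comment in the diff describes.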
@@ -359,6 +354,9 @@ class HopperWgmmaGemmKernel:
            self.d_dtype,
            self.d_layout,
            self.epi_stage,
+            self.c_dtype,
+            self.c_layout,
+            self.epi_c_stage,
        )
 
    @cute.jit
@@ -366,7 +364,12 @@ class HopperWgmmaGemmKernel:
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
-        mD: cute.Tensor,
+        mD: Optional[cute.Tensor],
+        mC: Optional[cute.Tensor],
+        epilogue_args: Optional[ArgumentsBase],
+        scheduler_args: TileSchedulerOptions,
+        varlen_args: Optional[VarlenArguments],
+        mAIdx: Optional[cute.Tensor],
        stream: cuda.CUstream,
    ):
        """Execute the GEMM operation in steps:
@@ -389,36 +392,44 @@ class HopperWgmmaGemmKernel:
        # setup static attributes before smem/grid/tma computation
        self.a_dtype = mA.element_type
        self.b_dtype = mB.element_type
-        self.d_dtype = mD.element_type
-        self.
-        self.
-        self.
-
-
+        self.d_dtype = mD.element_type if mD is not None else None
+        self.c_dtype = mC.element_type if mC is not None else None
+        self.a_layout = LayoutEnum.from_tensor(mA)
+        self.b_layout = LayoutEnum.from_tensor(mB)
+        self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
+        self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None
+
+        if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
            raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
-        if
+        if const_expr(self.a_dtype.width != self.b_dtype.width):
            raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
-        if
+        if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
            raise TypeError("a_dtype should be float16 or float8")
+        assert (mAIdx is not None) == self.gather_A
 
-
-
-
-
-            self.b_dtype,
-            self.a_layout.sm90_mma_major_mode(),
-            self.b_layout.sm90_mma_major_mode(),
-            self.acc_dtype,
-            self.atom_layout_mnk,
-            tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
-        )
-
-        tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
-            mA,
-            self.a_smem_layout_staged,
-            (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
-            self.cluster_shape_mnk[1],
+        # Assume all strides are divisible by 128 bits except the last stride
+        new_stride = lambda t: tuple(
+            cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
+            for s in t.stride
        )
+        mA, mD = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            if t is not None
+            else None
+            for t in (mA, mD)
+        ]
+
+        self._setup_attributes(epilogue_args)
+
+        if const_expr(not self.gather_A):
+            tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
+                mA,
+                self.a_smem_layout_staged,
+                (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
+                self.cluster_shape_mnk[1],
+            )
+        else:
+            tma_atom_a, tma_tensor_a = None, None
 
        tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
            mB,
@@ -427,17 +438,89 @@ class HopperWgmmaGemmKernel:
            self.cluster_shape_mnk[0],
        )
 
-
-
-
-
+        if const_expr(mD is not None):
+            tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
+                mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
+            )
+        else:
+            tma_atom_d, tma_tensor_d = None, None
+
+        if const_expr(mC is not None):
+            tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
+                mC, self.epi_c_smem_layout_staged, self.epi_tile, store_or_load="load"
+            )
+        else:
+            tma_atom_c, tma_tensor_c = None, None
+
+        epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
+
+        if const_expr(varlen_args is None):
+            varlen_args = VarlenArguments()
+        if const_expr(varlen_args.mCuSeqlensM is None):
+            num_problems = (
+                mD.shape[2]
+                if mD is not None
+                else (
+                    mB.shape[2]
+                    if varlen_args.mCuSeqlensK is None
+                    else varlen_args.mCuSeqlensK.shape[0] - 1
+                )
+            )
+            problem_shape_ntile_mnl = (
+                cute.ceil_div(mA.shape[0], self.tile_shape_mnk[0]),
+                cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
+                num_problems,
+            )
+            TileSchedulerCls = self.get_scheduler_class()
+            tile_sched_args = self.get_scheduler_arguments(problem_shape_ntile_mnl, scheduler_args)
+        else:
+            assert mD is not None or not self.gather_A
+            problem_shape_ntile_mnl = (
+                None,
+                cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
+                varlen_args.mCuSeqlensM.shape[0] - 1,
+            )
+            TileSchedulerCls = VarlenMTileScheduler
+            tile_sched_args = VarlenMTileSchedulerArguments(
+                problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+                total_m=mD.shape[0] if mD is not None else mAIdx.shape[0],
+                cu_seqlens_m=varlen_args.mCuSeqlensM,
+                raster_order=scheduler_args.raster_order,
+                group_size=scheduler_args.max_swizzle_size,
+                tile_shape_mn=self.tile_shape_mnk[:2],
+                cluster_shape_mnk=self.cluster_shape_mnk,
+                tile_count_semaphore=scheduler_args.tile_count_semaphore,
+                is_persistent=self.is_persistent,
+            )
+        tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
+        grid = TileSchedulerCls.get_grid_shape(
+            tile_sched_params, scheduler_args.max_active_clusters
        )
 
-
+        epi_smem_size = (
+            cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
+        )
+        epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
 
        @cute.struct
        class SharedStorage:
-
+            ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
+            epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
+            sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
+            tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
+            sD: cute.struct.Align[
+                cute.struct.MemRange[
+                    self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
+                ],
+                self.buffer_align_bytes,
+            ]
+            sC: cute.struct.Align[
+                cute.struct.MemRange[
+                    self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
+                ],
+                self.buffer_align_bytes,
+            ]
+            epi: self.epi_get_smem_struct(epilogue_params)
            sA: cute.struct.Align[
                cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
                self.buffer_align_bytes,
@@ -452,16 +535,26 @@ class HopperWgmmaGemmKernel:
        # Launch the kernel synchronously
        self.kernel(
            tma_atom_a,
-            tma_tensor_a,
+            tma_tensor_a if const_expr(not self.gather_A) else mA,
            tma_atom_b,
            tma_tensor_b,
            tma_atom_d,
            tma_tensor_d,
-
-
+            tma_atom_c,
+            tma_tensor_c,
+            epilogue_params,
+            mAIdx,
+            varlen_args.mCuSeqlensM,
+            varlen_args.mCuSeqlensK,
+            varlen_args.mTensormaps,
+            self.tiled_mma,
+            self.cluster_layout_mnk,
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
+            self.epi_c_smem_layout_staged,
+            tile_sched_params,
+            TileSchedulerCls,
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
@@ -476,17 +569,27 @@ class HopperWgmmaGemmKernel:
    @cute.kernel
    def kernel(
        self,
-        tma_atom_a: cute.CopyAtom,
+        tma_atom_a: Optional[cute.CopyAtom],
        mA_mkl: cute.Tensor,
        tma_atom_b: cute.CopyAtom,
        mB_nkl: cute.Tensor,
-        tma_atom_d: cute.CopyAtom,
-        mD_mnl: cute.Tensor,
+        tma_atom_d: Optional[cute.CopyAtom],
+        mD_mnl: Optional[cute.Tensor],
+        tma_atom_c: Optional[cute.CopyAtom],
+        mC_mnl: Optional[cute.Tensor],
+        epilogue_params: ParamsBase,
+        mAIdx: Optional[cute.Tensor],
+        cu_seqlens_m: Optional[cute.Tensor],
+        cu_seqlens_k: Optional[cute.Tensor],
+        tensormaps: Optional[cute.Tensor],
        tiled_mma: cute.TiledMma,
-
-
-
-
+        cluster_layout_mnk: cute.Layout,
+        a_smem_layout: cute.ComposedLayout,
+        b_smem_layout: cute.ComposedLayout,
+        epi_smem_layout: cute.ComposedLayout,
+        epi_c_smem_layout: cute.ComposedLayout,
+        tile_sched_params: ParamsBase,
+        TileSchedulerCls: cutlass.Constexpr[Callable],
    ):
        """
        GPU device kernel performing the batched GEMM computation.
@@ -505,32 +608,31 @@ class HopperWgmmaGemmKernel:
        :type mD_mnl: cute.Tensor
        :param tiled_mma: Tiled MMA object
        :type tiled_mma: cute.TiledMma
-        :param
-        :type
-        :param
-        :type
-        :param
-        :type
-        :param
-        :type
+        :param cluster_layout_mnk: CTA layout
+        :type cluster_layout_mnk: cute.Layout
+        :param a_smem_layout: Shared memory layout for A
+        :type a_smem_layout: cute.ComposedLayout
+        :param b_smem_layout: Shared memory layout for B
+        :type b_smem_layout: cute.ComposedLayout
+        :param epi_smem_layout: Shared memory layout for epilogue
+        :type epi_smem_layout: cute.ComposedLayout
        """
 
+        varlen_m = const_expr(cu_seqlens_m is not None)
+        varlen_k = const_expr(cu_seqlens_k is not None)
+        assert not (varlen_m and varlen_k)
+        has_D = const_expr(mD_mnl is not None)
+        has_C = const_expr(mC_mnl is not None)
+
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
 
        # /////////////////////////////////////////////////////////////////////////////
        # Prefetch Tma desc
        # /////////////////////////////////////////////////////////////////////////////
-
-
-
-
-        cpasync.prefetch_descriptor(tma_atom_d)
-
-        a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
-        b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
-        tma_copy_bytes = cute.size_in_bytes(self.a_dtype, a_smem_layout) + cute.size_in_bytes(
-            self.b_dtype, b_smem_layout
-        )
+        if warp_idx == self.ab_load_warp_id:
+            for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
+                if const_expr(tma_atom is not None):
+                    cpasync.prefetch_descriptor(tma_atom)
 
        # /////////////////////////////////////////////////////////////////////////////
        # Alloc and init AB full/empty + ACC full mbar (pipeline)
@@ -538,164 +640,321 @@ class HopperWgmmaGemmKernel:
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)
 
-
-
-
-
-
-
-            pipeline.Agent.Thread, consumer_arrive_cnt
-        )
-
-        cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
-        mainloop_pipeline = pipeline.PipelineTmaAsync.create(
-            barrier_storage=storage.mainloop_pipeline_array_ptr.data_ptr(),
-            num_stages=self.ab_stage,
-            producer_group=mainloop_pipeline_producer_group,
-            consumer_group=mainloop_pipeline_consumer_group,
-            tx_count=tma_copy_bytes,
-            cta_layout_vmnk=cta_layout_vmnk,
+        ab_pipeline = self.make_ab_pipeline(
+            a_smem_layout=cute.slice_(a_smem_layout, (None, None, 0)),
+            b_smem_layout=cute.slice_(b_smem_layout, (None, None, 0)),
+            tiled_mma=tiled_mma,
+            cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
+            ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
        )
+        epi_pipeline = None
+        if const_expr(has_C):
+            epi_pipeline = self.make_epi_pipeline(
+                c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
+                epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
+            )
+        sched_pipeline = None
+        tile_count = None
+        if const_expr(tile_sched_params.tile_count_semaphore is not None):
+            # Dynamic persistent scheduler
+            sched_pipeline = self.make_sched_pipeline(
+                cluster_layout_mnk,
+                sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
+                varlen_k=varlen_k,
+            )
+            tile_count = storage.tile_count.get_tensor((self.sched_stage,))
 
        # ///////////////////////////////////////////////////////////////////////////////
        # Generate smem tensor A/B
        # ///////////////////////////////////////////////////////////////////////////////
-        sA = storage.sA.get_tensor(
-        sB = storage.sB.get_tensor(
-
-
-
-
-
-
-
-
-
-
+        sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
+        sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
+        sD = None
+        if const_expr(has_D):
+            if const_expr(not self.is_persistent):
+                sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
+                sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
+            else:
+                sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
+        sC = None
+        if const_expr(has_C):
+            sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
+        epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
+
+        # Get tensormap buffer address
+        tensormap_manager = None
+        tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
+        if const_expr(varlen_m or varlen_k):
+            tensormap_manager = TensorMapManagerSm90(
+                cutlass.utils.TensorMapUpdateMode.GMEM, GemmSm90.bytes_per_tensormap
+            )
+            # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
+            tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
+            if const_expr(varlen_m):
+                tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
+                tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
+                    tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
+                )
+            else:
+                assert varlen_k
+                tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
+                    tensormaps[tensormap_workspace_idx, 0, None].iterator
+                )
+                tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
+                    tensormaps[tensormap_workspace_idx, 1, None].iterator
+                )
 
-
-
-        s_shape = (
-            (group_size_m, cdimx // group_size_m),
-            cdimy,
+        TileSchedulerCls = partial(
+            TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
        )
-
-
-        num_reg_cids = cute.size(s_shape)
-        cid_m, cid_n = s_layout.get_flat_coord(cluster_id % num_reg_cids)
-
-        # Deal with the tail part
-        if cluster_id >= num_reg_cids:
-            tail_size_m = cdimx % group_size_m
-            tail_layout = cute.make_layout((tail_size_m, cdimy), stride=(1, tail_size_m))
-            tail_cid = cluster_id - num_reg_cids
-            tail_cid_m, tail_cid_n = tail_layout.get_flat_coord(tail_cid)
-            cid_m = cute.size(s_shape, mode=[0]) + tail_cid_m
-            cid_n = tail_cid_n
-
-        # Get the pid from cluster id
-        bidx_in_cluster = cute.arch.block_in_cluster_idx()
-        pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0]
-        pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1]
-
-        _, _, bidz = cute.arch.block_idx()
-        tile_coord_mnkl = (pid_m, pid_n, None, bidz)
-        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
-        cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
-
-        k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
-
-        if warp_idx >= self.mma_warp_groups * 4:
+
+        if warp_idx >= self.ab_load_warp_id:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
-            if
+            if (
+                warp_idx >= self.ab_load_warp_id
+                and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
+            ):
+                is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
+                if const_expr(varlen_k):
+                    # initialize tensormap for A & B
+                    tensormap_manager.init_tensormap_from_atom(
+                        tma_atom_a,
+                        tensormap_a_ptr,
+                        is_tma_warp,
+                    )
+                    tensormap_manager.init_tensormap_from_atom(
+                        tma_atom_b,
+                        tensormap_b_ptr,
+                        is_tma_warp,
+                    )
                # ///////////////////////////////////////////////////////////////////////////////
                # Get mcast mask
                # ///////////////////////////////////////////////////////////////////////////////
+                cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
+                cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
                a_mcast_mask = cute.make_layout_image_mask(
-
+                    cluster_layout_mnk, cluster_coord_mnk, mode=1
                )
                b_mcast_mask = cute.make_layout_image_mask(
-
+                    cluster_layout_mnk, cluster_coord_mnk, mode=0
                )
                a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
                b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
-
+
+                # Persistent tile scheduling loop
+                is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
+                if const_expr(cute.size(cluster_layout_mnk) > 1):
+                    is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
+                tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
+                work_tile = tile_scheduler.initial_work_tile_info()
+                ab_producer_state = make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self.ab_stage
                )
-
-
-
-                #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                if const_expr(varlen_k):
+                    # wait tensormap initialization complete before update
+                    tensormap_manager.fence_tensormap_initialization()
+                # batch index of last tile
+                last_batch_idx = cutlass.Int32(-1)
+                while work_tile.is_valid_tile:
+                    tile_coord_mnkl = work_tile.tile_idx
+                    batch_idx = tile_coord_mnkl[3]
+                    if const_expr(varlen_k):
+                        is_group_changed = batch_idx != last_batch_idx
+                        last_batch_idx = batch_idx
+                        if is_group_changed:
+                            # construct tensor A/B based on real address, shape and stride information
+                            tensormap_manager.update_tensormap_shape(
+                                (tensormap_a_ptr, tensormap_b_ptr),
+                                is_manager_warp=is_tma_warp,
+                                shapes=(cu_seqlens_k[batch_idx + 1], cu_seqlens_k[batch_idx + 1]),
+                                orders=(
+                                    0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1,
+                                    0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1,
+                                ),
+                                tensormap_smem_ptr=None,
+                            )
+                    # ///////////////////////////////////////////////////////////////////////////
+                    # Local_tile partition global tensors
+                    # ///////////////////////////////////////////////////////////////////////////
+                    if const_expr(not self.gather_A):
+                        if const_expr(varlen_m):
+                            mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
+                        elif const_expr(varlen_k):
+                            mA_mk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mA_mkl)
+                        else:
+                            mA_mk = mA_mkl[None, None, batch_idx]
+                        # (bM, bK, RestK)
+                        gA_k = cute.local_tile(
+                            mA_mk,
+                            cute.select(self.tile_shape_mnk, [0, 2]),
+                            (tile_coord_mnkl[0], None),
+                        )
+                    else:
+                        mA_mk = mA_mkl
+                        if const_expr(varlen_m):
+                            mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
+                        elif const_expr(varlen_k):
+                            mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
+                        else:
+                            mAIdx_mk = mAIdx[None, batch_idx]
+                        gAIdx = cute.local_tile(
+                            mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
+                        )
+                    if const_expr(varlen_k):
+                        mB_nk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mB_nkl)
+                    else:
+                        mB_nk = mB_nkl[None, None, batch_idx]
+                    # (bN, bK, RestK)
+                    gB_k = cute.local_tile(
+                        mB_nk, cute.select(self.tile_shape_mnk, [1, 2]), (tile_coord_mnkl[1], None)
                    )
-
+                    # //////////////////////////////////////////////////////////////////////////
+                    # Partition shared tensor for TMA load A/B
+                    # //////////////////////////////////////////////////////////////////////////
+                    if const_expr(varlen_k):
+                        # ensure the update to tensormap has completed before using it
+                        if is_group_changed and is_tma_warp:
+                            tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
+                            tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
+                        tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
+                            tensormap_a_ptr, cute.AddressSpace.generic
+                        )
+                        tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
+                            tensormap_b_ptr, cute.AddressSpace.generic
+                        )
+                    else:
+                        tma_desc_a_ptr, tma_desc_b_ptr = None, None
+                    # TMA load A partition_S/D
+                    a_cta_layout = cute.make_layout(
+                        cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
+                    )
+                    a_cta_crd = cluster_coord_mnk[1]
+                    if const_expr(not self.gather_A):
+                        # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
+                        tAsA, tAgA_k = cpasync.tma_partition(
+                            tma_atom_a,
+                            a_cta_crd,
+                            a_cta_layout,
+                            cute.group_modes(sA, 0, 2),
+                            cute.group_modes(gA_k, 0, 2),
+                        )
+                        copy_A = partial(
+                            cute.copy,
+                            tma_atom_a,
+                            mcast_mask=a_mcast_mask,
+                            tma_desc_ptr=tma_desc_a_ptr,
+                        )
+                    else:
+                        tiled_copy_A = self._make_gmem_tiled_copy_A(
+                            mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
+                        )
+                        tidx = (
+                            cute.arch.thread_idx()[0]
+                            - self.mma_warp_groups * self.num_threads_per_warp_group
+                        )
+                        thr_copy_A = tiled_copy_A.get_slice(tidx)
+                        # (atom_v, CPY_M, 1, STAGE)
+                        tAsA = thr_copy_A.partition_D(sA)
+                        assert tAsA.shape[2] == 1
+                        tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
+                        copy_A = partial(cute.copy, tiled_copy_A)
+                    # TMA load B partition_S/D
+                    b_cta_layout = cute.make_layout(
+                        cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
+                    )
+                    b_cta_crd = cluster_coord_mnk[0]
+                    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
+                    tBsB, tBgB_k = cpasync.tma_partition(
                        tma_atom_b,
-
-
-
-
+                        b_cta_crd,
+                        b_cta_layout,
+                        cute.group_modes(sB, 0, 2),
+                        cute.group_modes(gB_k, 0, 2),
                    )
-
-
-
-
-
-
+                    copy_B = partial(
+                        cute.copy, tma_atom_b, mcast_mask=b_mcast_mask, tma_desc_ptr=tma_desc_b_ptr
+                    )
+                    k_len = (
+                        cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
+                        if const_expr(varlen_k)
+                        else mA_mkl.shape[1]
+                    )
+                    k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+                    if const_expr(not self.gather_A):
+                        ab_producer_state = self.load_AB(
+                            ab_pipeline,
+                            ab_producer_state,
+                            copy_A,
+                            tAgA_k,
+                            tAsA,
+                            copy_B,
+                            tBgB_k,
+                            tBsB,
+                            k_tile_cnt,
+                        )
+                    else:
+                        limit_m = (
+                            mAIdx.shape[0]
+                            if const_expr(cu_seqlens_m is None)
+                            else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
+                        )
+                        ab_producer_state = self.load_AB_gather_A(
+                            ab_pipeline,
+                            ab_producer_state,
+                            thr_copy_A,
+                            mA_mk,
+                            tAsA,
+                            gAIdx,
+                            copy_B,
+                            tBgB_k,
+                            tBsB,
+                            k_tile_cnt,
+                            limit_A=(
+                                limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
+                                mA_mk.shape[1],
+                            ),
+                        )
+                    tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
+                    tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
+                    tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+                    work_tile = tile_scheduler.get_current_work()
+                    # End of persistent scheduler loop
+                if const_expr(self.pingpong and not varlen_k):
+                    # Need to write the tile_idx to smem for the next WG in the pingpong mode
+                    # tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+                    tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
+                    tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+                ab_pipeline.producer_tail(ab_producer_state)
+                if is_scheduler_warp:
+                    tile_scheduler.producer_tail()
+
+        if warp_idx < self.ab_load_warp_id:
            cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
+            is_tma_warp = Boolean(
+                (not self.pingpong and warp_idx == 0)
+                or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
+            )
+            if const_expr(varlen_m):
+                # initialize tensormap for D
+                tensormap_manager.init_tensormap_from_atom(
+                    tma_atom_d,
+                    tensormap_d_ptr,
+                    is_manager_warp=is_tma_warp,
+                )
            # //////////////////////////////////////////////////////////////////////////////
            # Partition global tensor for TiledMMA_A/B/C
            # //////////////////////////////////////////////////////////////////////////////
            tidx, _, _ = cute.arch.thread_idx()
            warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
+            if const_expr(self.pingpong):
+                tidx = tidx % self.num_threads_per_warp_group
            warp_group_thread_layout = cute.make_layout(
-                self.mma_warp_groups
+                self.mma_warp_groups if not self.pingpong else 1,
+                stride=self.num_threads_per_warp_group,
+            )
+            thr_mma = tiled_mma.get_slice(
+                warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
            )
-            thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx))
 
            # //////////////////////////////////////////////////////////////////////////////
            # Make fragments
@@ -705,148 +964,818 @@ class HopperWgmmaGemmKernel:
 
            acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
            acc = cute.make_fragment(acc_shape, self.acc_dtype)
-
-
-
+            acc_slow = None
+            if const_expr(self.fp8_slow_accum):
+                acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
+
+            if const_expr(self.pingpong):
+                if warp_group_idx == 0:
+                    # WG0 needs a start signal at the very beginning
+                    self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
+                    self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
+
+            k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.tile_shape_mnk[2])
+            c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
+
+            ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
+            epi_read_state = make_pipeline_state(
+                pipeline.PipelineUserType.Consumer, self.epi_c_stage
            )
-
-                pipeline.PipelineUserType.
+            epi_producer_state = make_pipeline_state(
+                pipeline.PipelineUserType.Producer, self.epi_c_stage
            )
+            tile_scheduler = TileSchedulerCls()
+            work_tile = None
+            if const_expr(self.pingpong):
+                if const_expr(varlen_k):
+                    work_tile = tile_scheduler.initial_work_tile_info()
+                if warp_idx >= 4:
+                    # Advance 2nd Math WG pipeline states to the end of 1st Math WG
+                    epi_read_state.advance_iters(c_tile_cnt)
+                    epi_producer_state.advance_iters(c_tile_cnt)
+                    if const_expr(not varlen_k):
+                        ab_read_state.advance_iters(k_tile_cnt_static)
+                    else:
+                        batch_idx = work_tile.tile_idx[3]
+                        k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
+                        k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+                        ab_read_state.advance_iters(k_tile_cnt)
+                    tile_scheduler.advance_to_next_work()
+                    if const_expr(varlen_k):
+                        work_tile = tile_scheduler.get_current_work()
+                if const_expr(not varlen_k):
+                    work_tile = tile_scheduler.initial_work_tile_info()
+            else:
+                work_tile = tile_scheduler.initial_work_tile_info()
+            if const_expr(varlen_m):
+                # wait tensormap initialization complete before update
+                tensormap_manager.fence_tensormap_initialization()
+            # batch index of last tile
+            last_batch_idx = cutlass.Int32(-1)
+            while work_tile.is_valid_tile:
+                tile_coord_mnkl = work_tile.tile_idx
+                batch_idx = tile_coord_mnkl[3]
+                if const_expr(varlen_m):
+                    is_group_changed = batch_idx != last_batch_idx
+                    last_batch_idx = batch_idx
+                    if is_group_changed:
+                        # construct tensor D based on real address, shape and stride information
+                        tensormap_manager.update_tensormap_shape(
+                            (tensormap_d_ptr,),
+                            is_manager_warp=is_tma_warp,
+                            shapes=(cu_seqlens_m[batch_idx + 1],),
+                            orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
+                            tensormap_smem_ptr=None,
+                        )
+
+                k_len = (
+                    cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
+                    if const_expr(varlen_k)
+                    else mA_mkl.shape[1]
+                )
+                k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+                ab_read_state, tiled_mma = self.mma(
+                    ab_pipeline,
+                    ab_read_state,
+                    tiled_mma,
+                    tCrA,
+                    tCrB,
+                    acc,
+                    acc_slow,
+                    k_tile_cnt,
+                    warp_group_idx,
+                )
+                if const_expr(varlen_k):
+                    if k_tile_cnt == 0:
+                        acc.fill(0.0)
+
+                # /////////////////////////////////////////////////////////////////////////////
+                # EPILOGUE
+                # /////////////////////////////////////////////////////////////////////////////
+                if const_expr(self.pingpong):
+                    self.pingpong_barrier_sync(warp_group_idx, "epi")
 
-
-
-            # /////////////////////////////////////////////////////////////////////////////
-            k_pipe_mmas = 1
-            peek_ab_full_status = cutlass.Boolean(1)
-            if mainloop_consumer_read_state.count < k_tile_cnt:
-                peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
-                    mainloop_consumer_read_state
+                epilogue_barrier = pipeline.NamedBarrier(
+                    barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
                )
-            tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
-            num_k_blocks = cute.size(tCrA, mode=[2])
-            for k_tile in cutlass.range_constexpr(k_pipe_mmas):
-                # Wait for A/B buffer to be ready
-                mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status)
-                warpgroup.fence()
-                for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
-                    k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index)
-                    cute.gemm(tiled_mma, acc, tCrA[k_block_coord], tCrB[k_block_coord], acc)
-                    tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
-                warpgroup.commit_group()
-                mainloop_consumer_read_state.advance()
-                peek_ab_full_status = cutlass.Boolean(1)
-                if mainloop_consumer_read_state.count < k_tile_cnt:
-                    peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
-                        mainloop_consumer_read_state
-                    )
 
-
-
-
-
-
-                # WGMMA
-                warpgroup.fence()
-                for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
-                    k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index)
-                    cute.gemm(tiled_mma, acc, tCrA[k_block_coord], tCrB[k_block_coord], acc)
-                warpgroup.commit_group()
-                # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
-                warpgroup.wait_group(k_pipe_mmas)
-                mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
-                mainloop_consumer_read_state.advance()
-                mainloop_consumer_release_state.advance()
-                peek_ab_full_status = cutlass.Boolean(1)
-                if mainloop_consumer_read_state.count < k_tile_cnt:
|
|
762
|
-
peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
|
|
763
|
-
mainloop_consumer_read_state
|
|
1062
|
+
if const_expr(varlen_m):
|
|
1063
|
+
# Ensure the update to the tensormap has completed before using it
|
|
1064
|
+
if is_group_changed and is_tma_warp:
|
|
1065
|
+
tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
|
|
1066
|
+
tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
|
|
1067
|
+
tensormap_d_ptr, cute.AddressSpace.generic
|
|
764
1068
|
)
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
1069
|
+
else:
|
|
1070
|
+
tma_desc_d_ptr = None
|
|
1071
|
+
|
|
1072
|
+
if const_expr(has_D):
|
|
1073
|
+
bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(
|
|
1074
|
+
tma_atom_d,
|
|
1075
|
+
mD_mnl,
|
|
1076
|
+
self.tile_shape_mnk[:2],
|
|
1077
|
+
self.epi_tile,
|
|
1078
|
+
sD,
|
|
1079
|
+
tile_coord_mnkl,
|
|
1080
|
+
cu_seqlens_m,
|
|
1081
|
+
)
|
|
1082
|
+
copy_D = partial(cute.copy, tma_atom_d, tma_desc_ptr=tma_desc_d_ptr)
|
|
1083
|
+
else:
|
|
1084
|
+
bSG_sD, bSG_gD, copy_D = None, None, None
|
|
1085
|
+
if const_expr(has_C):
|
|
1086
|
+
bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
|
|
1087
|
+
tma_atom_c,
|
|
1088
|
+
mC_mnl,
|
|
1089
|
+
self.tile_shape_mnk[:2],
|
|
1090
|
+
self.epi_tile,
|
|
1091
|
+
sC,
|
|
1092
|
+
tile_coord_mnkl,
|
|
1093
|
+
cu_seqlens_m,
|
|
1094
|
+
)
|
|
1095
|
+
copy_C = partial(cute.copy, tma_atom_c)
|
|
1096
|
+
epi_load_g2s = partial(self.epi_load_g2s, epi_pipeline, copy_C, bGS_gC, bGS_sC)
|
|
1097
|
+
else:
|
|
1098
|
+
epi_load_g2s = None
|
|
1099
|
+
|
|
1100
|
+
d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
|
|
1101
|
+
tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
|
|
1102
|
+
tiled_mma, self.d_layout, d_dtype_for_layout, acc, sD, tidx
|
|
1103
|
+
)
|
|
1104
|
+
if const_expr(has_C):
|
|
1105
|
+
tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
|
|
1106
|
+
tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
|
|
1107
|
+
)
|
|
1108
|
+
else:
|
|
1109
|
+
tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
|
|
1110
|
+
|
|
1111
|
+
# Wait for all warp groups in the thread block to finish, because smem for tensor
|
|
1112
|
+
# A in the mainloop is reused in the epilogue if not persistent.
|
|
1113
|
+
if const_expr(not self.is_persistent):
|
|
1114
|
+
epilogue_barrier.arrive_and_wait()
|
|
1115
|
+
|
|
1116
|
+
self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
|
|
1117
|
+
|
|
1118
|
+
epi_read_state, epi_producer_state = self.epilogue(
|
|
1119
|
+
epilogue_params,
|
|
1120
|
+
epi_smem_tensors,
|
|
1121
|
+
epi_pipeline,
|
|
1122
|
+
epi_read_state,
|
|
1123
|
+
epi_producer_state,
|
|
1124
|
+
tiled_mma,
|
|
1125
|
+
tRS_rAcc,
|
|
1126
|
+
tRS_rD,
|
|
1127
|
+
tRS_rC,
|
|
1128
|
+
tiled_copy_r2s,
|
|
1129
|
+
tRS_sD,
|
|
1130
|
+
tiled_copy_s2r,
|
|
1131
|
+
tSR_rC,
|
|
1132
|
+
tSR_sC,
|
|
1133
|
+
copy_D,
|
|
1134
|
+
bSG_sD,
|
|
1135
|
+
bSG_gD,
|
|
1136
|
+
epi_load_g2s,
|
|
1137
|
+
tile_coord_mnkl,
|
|
1138
|
+
cu_seqlens_m,
|
|
1139
|
+
epilogue_barrier,
|
|
1140
|
+
tile_scheduler,
|
|
1141
|
+
tidx,
|
|
1142
|
+
is_tma_warp,
|
|
1143
|
+
)
|
|
1144
|
+
|
|
1145
|
+
if const_expr(self.pingpong):
|
|
1146
|
+
# With pingpong, 2 WGs write two different output tiles to the same smem,
|
|
1147
|
+
# so we have to make sure the smem content has been fully read before signaling
|
|
1148
|
+
# the next WG's epilogue.
|
|
1149
|
+
if is_tma_warp:
|
|
1150
|
+
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
|
1151
|
+
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
|
|
1152
|
+
|
|
1153
|
+
if const_expr(not self.pingpong):
|
|
1154
|
+
tile_scheduler.advance_to_next_work()
|
|
1155
|
+
work_tile = tile_scheduler.get_current_work()
|
|
1156
|
+
else: # Skip a tile for pingpong
|
|
1157
|
+
# Update starting load/store pipeline states for the next tile
|
|
1158
|
+
epi_read_state.advance_iters(c_tile_cnt)
|
|
1159
|
+
epi_producer_state.advance_iters(c_tile_cnt)
|
|
1160
|
+
# Update starting mainloop pipeline state for the next tile
|
|
1161
|
+
if const_expr(not varlen_k):
|
|
1162
|
+
ab_read_state.advance_iters(k_tile_cnt_static)
|
|
1163
|
+
tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups)
|
|
1164
|
+
work_tile = tile_scheduler.get_current_work()
|
|
1165
|
+
else:
|
|
1166
|
+
tile_scheduler.advance_to_next_work()
|
|
1167
|
+
work_tile = tile_scheduler.get_current_work()
|
|
1168
|
+
if work_tile.is_valid_tile:
|
|
1169
|
+
batch_idx = work_tile.tile_idx[3]
|
|
1170
|
+
k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
|
|
1171
|
+
k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
|
|
1172
|
+
ab_read_state.advance_iters(k_tile_cnt)
|
|
1173
|
+
tile_scheduler.advance_to_next_work()
|
|
1174
|
+
work_tile = tile_scheduler.get_current_work()
|
|
1175
|
+
# End of persistent scheduler loop
|
|
1176
|
+
|
|
1177
|
+
if const_expr(not self.pingpong):
|
|
1178
|
+
if is_tma_warp:
|
|
1179
|
+
cute.arch.cp_async_bulk_wait_group(0, read=True)
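
This closes the persistent scheduler loop: each warp group keeps requesting work tiles, runs the mainloop and then its epilogue, and with ping-pong it skips the tile its sibling warp group owns. A plain-Python control-flow sketch of that tile distribution, with all names invented for illustration:

def persistent_pingpong(tiles, my_wg, num_wgs=2):
    # each warp group starts at its own tile and then strides by the number of WGs
    results = []
    for tile_idx in range(my_wg, len(tiles), num_wgs):
        acc = sum(tiles[tile_idx])          # stand-in for the WGMMA mainloop
        results.append((tile_idx, acc))     # stand-in for the epilogue / TMA store
    return results

tiles = [[1, 2], [3, 4], [5, 6], [7, 8]]
print(persistent_pingpong(tiles, my_wg=0))  # [(0, 3), (2, 11)]
print(persistent_pingpong(tiles, my_wg=1))  # [(1, 7), (3, 15)]
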
|
|
1180
|
+
|
|
1181
|
+
@cute.jit
|
|
1182
|
+
def load_AB(
|
|
1183
|
+
self,
|
|
1184
|
+
ab_pipeline: cutlass.pipeline.PipelineAsync,
|
|
1185
|
+
ab_producer_state: cutlass.pipeline.PipelineState,
|
|
1186
|
+
copy_A: Callable,
|
|
1187
|
+
tAgA: cute.Tensor,
|
|
1188
|
+
tAsA: cute.Tensor,
|
|
1189
|
+
copy_B: Callable,
|
|
1190
|
+
tBgB: cute.Tensor,
|
|
1191
|
+
tBsB: cute.Tensor,
|
|
1192
|
+
k_tile_cnt: Int32,
|
|
1193
|
+
) -> cutlass.pipeline.PipelineState:
|
|
1194
|
+
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
|
|
1195
|
+
peek_ab_empty_status = Boolean(True)
|
|
1196
|
+
if 0 < k_tile_cnt:
|
|
1197
|
+
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
1198
|
+
# /////////////////////////////////////////////////////////////////////////
|
|
1199
|
+
# TMA load
|
|
1200
|
+
# /////////////////////////////////////////////////////////////////////////
|
|
1201
|
+
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
|
|
1202
|
+
# Wait for A/B buffers to be empty before loading into them
|
|
1203
|
+
# Also sets the transaction barrier for the A/B buffers
|
|
1204
|
+
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
|
|
1205
|
+
tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
|
|
1206
|
+
copy_A(tAgA[None, k_tile], tAsA[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
|
|
1207
|
+
copy_B(tBgB[None, k_tile], tBsB[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
|
|
1208
|
+
# Mainloop pipeline's producer commit is a NOP
|
|
1209
|
+
ab_pipeline.producer_commit(ab_producer_state)
|
|
1210
|
+
ab_producer_state.advance()
|
|
1211
|
+
peek_ab_empty_status = Boolean(True)
|
|
1212
|
+
if k_tile + 1 < k_tile_cnt:
|
|
1213
|
+
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
1214
|
+
return ab_producer_state
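
load_AB above is a classic bounded-buffer producer: peek whether a stage is free, acquire it, issue the TMA copies against that stage's mbarrier, then advance. A CPU-side analogue of that loop, assuming nothing more than a fixed number of stages; the deque and names are illustrative, not the cutlass.pipeline interface:

from collections import deque

def toy_producer(k_tile_cnt, num_stages):
    ring = deque()            # stand-in for the smem A/B stages
    consumed = []
    for k_tile in range(k_tile_cnt):
        if len(ring) == num_stages:           # "producer_acquire": wait for a free stage
            consumed.append(ring.popleft())   # in the kernel the consumer frees it
        ring.append(("A+B", k_tile))          # copy_A / copy_B into that stage
        # "producer_commit" is a no-op for TMA: the mbarrier tracks completion
    consumed.extend(ring)
    return consumed

print(len(toy_producer(k_tile_cnt=8, num_stages=3)))  # 8: every k-tile gets staged once
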
|
|
1215
|
+
|
|
1216
|
+
@cute.jit
|
|
1217
|
+
def load_AB_gather_A(
|
|
1218
|
+
self,
|
|
1219
|
+
ab_pipeline: cutlass.pipeline.PipelineAsync,
|
|
1220
|
+
ab_producer_state: cutlass.pipeline.PipelineState,
|
|
1221
|
+
thr_copy_A: cute.core.ThrCopy,
|
|
1222
|
+
mA: cute.Tensor,
|
|
1223
|
+
tAsA: cute.Tensor,
|
|
1224
|
+
gAIdx: cute.Tensor,
|
|
1225
|
+
copy_B: Callable,
|
|
1226
|
+
tBgB: cute.Tensor,
|
|
1227
|
+
tBsB: cute.Tensor,
|
|
1228
|
+
k_tile_cnt: Int32,
|
|
1229
|
+
limit_A: Tuple[Int32, Int32],
|
|
1230
|
+
) -> cutlass.pipeline.PipelineState:
|
|
1231
|
+
# (atom_v, CPY_M, 1, RestK)
|
|
1232
|
+
limit_m, limit_k = limit_A
|
|
1233
|
+
limit_m = min(limit_m, self.tile_shape_mnk[0]) # To avoid writing beyond smem limit
|
|
1234
|
+
cA = cute.make_identity_tensor(cute.select(self.tile_shape_mnk, [0, 2]))
|
|
1235
|
+
tAcA = thr_copy_A.partition_S(cA)
|
|
1236
|
+
t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
|
|
1237
|
+
# Instead of comparing tAcA to limit_m, we compare t0AcA to limit_m - tAcA[0][0],
|
|
1238
|
+
# since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
|
|
1239
|
+
# This is so that when we do the comparison, t0AcA is known at compile time.
|
|
1240
|
+
limit_m = limit_m - tAcA[0][0]
|
|
1241
|
+
# Read indices for A
|
|
1242
|
+
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
|
1243
|
+
m_idx = cute.make_fragment(rows_per_thread, Int32)
|
|
1244
|
+
for m in cutlass.range(rows_per_thread):
|
|
1245
|
+
row_idx = tAcA[0, m, 0][0]
|
|
1246
|
+
if t0AcA[0, m, 0][0] < limit_m:
|
|
1247
|
+
m_idx[m] = gAIdx[row_idx]
|
|
1248
|
+
else:
|
|
1249
|
+
m_idx[m] = -1
|
|
1250
|
+
elems_per_load = cute.size(tAsA.shape[0][0])
|
|
1251
|
+
# (m, (bK, RestK))
|
|
1252
|
+
mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
|
|
1253
|
+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
|
1254
|
+
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
|
|
1255
|
+
peek_ab_empty_status = Boolean(True)
|
|
1256
|
+
if 0 < k_tile_cnt:
|
|
1257
|
+
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
1258
|
+
# /////////////////////////////////////////////////////////////////////////
|
|
1259
|
+
# TMA load on B and cp.async on A
|
|
1260
|
+
# /////////////////////////////////////////////////////////////////////////
|
|
1261
|
+
copy_A = partial(cute.copy, thr_copy_A)
|
|
1262
|
+
for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
|
|
1263
|
+
# Wait for A/B buffers to be empty before loading into them
|
|
1264
|
+
# Also sets the transaction barrier for the A/B buffers
|
|
1265
|
+
ab_pipeline.producer_acquire(
|
|
1266
|
+
ab_producer_state,
|
|
1267
|
+
peek_ab_empty_status,
|
|
1268
|
+
# A tiny bit faster to rotate the warp that does TMA
|
|
1269
|
+
is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
|
|
797
1270
|
)
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
1271
|
+
# A bit faster to load B first while we calculate the predicate for A
|
|
1272
|
+
if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
|
|
1273
|
+
copy_B(
|
|
1274
|
+
tBgB[None, k_tile],
|
|
1275
|
+
tBsB[None, ab_producer_state.index],
|
|
1276
|
+
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
|
1277
|
+
)
|
|
1278
|
+
# (m, bK)
|
|
1279
|
+
mA_cur = mA_k[None, (None, k_tile)]
|
|
1280
|
+
for m in cutlass.range_constexpr(tAcA.shape[1]):
|
|
1281
|
+
# (elems_per_load, thread_per_row)
|
|
1282
|
+
mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
|
|
1283
|
+
if t0AcA[0, m, 0][0] < limit_m:
|
|
1284
|
+
# There's only 1 load per row
|
|
1285
|
+
assert cute.size(tAcA.shape, mode=[2]) == 1
|
|
1286
|
+
ki = tAcA[0, 0, 0][1] // elems_per_load
|
|
1287
|
+
copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
|
|
1288
|
+
# This tells mbarrier to track the completion of cp.async
|
|
1289
|
+
ab_pipeline.producer_commit(ab_producer_state)
|
|
1290
|
+
ab_producer_state.advance()
|
|
1291
|
+
peek_ab_empty_status = Boolean(True)
|
|
1292
|
+
if k_tile + 1 < k_tile_cnt:
|
|
1293
|
+
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
1294
|
+
# Bounds checking in the K dimension on the last k_tile
|
|
1295
|
+
if 0 < k_tile_cnt:
|
|
1296
|
+
k_tile = k_tile_cnt - 1
|
|
1297
|
+
ab_pipeline.producer_acquire(
|
|
1298
|
+
ab_producer_state,
|
|
1299
|
+
peek_ab_empty_status,
|
|
1300
|
+
is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
|
|
805
1301
|
)
|
|
1302
|
+
if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
|
|
1303
|
+
copy_B(
|
|
1304
|
+
tBgB[None, k_tile],
|
|
1305
|
+
tBsB[None, ab_producer_state.index],
|
|
1306
|
+
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
|
1307
|
+
)
|
|
1308
|
+
assert tAcA.shape[2] == 1 # there's only 1 load along the K dimension
|
|
1309
|
+
tApA = cute.make_fragment(1, Boolean)
|
|
1310
|
+
tApA[0] = tAcA[0, 0, 0][1] < limit_k
|
|
1311
|
+
# (m, bK)
|
|
1312
|
+
mA_cur = mA_k[None, (None, k_tile)]
|
|
1313
|
+
for m in cutlass.range_constexpr(tAcA.shape[1]):
|
|
1314
|
+
# (elems_per_load, thread_per_row)
|
|
1315
|
+
mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
|
|
1316
|
+
if t0AcA[0, m, 0][0] < limit_m:
|
|
1317
|
+
# There's only 1 load per row
|
|
1318
|
+
assert cute.size(tAcA.shape, mode=[2]) == 1
|
|
1319
|
+
ki = tAcA[0, 0, 0][1] // elems_per_load
|
|
1320
|
+
# copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA)
|
|
1321
|
+
# TODO
|
|
1322
|
+
copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
|
|
1323
|
+
ab_pipeline.producer_commit(ab_producer_state)
|
|
1324
|
+
ab_producer_state.advance()
|
|
1325
|
+
return ab_producer_state
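
The gather-A variant replaces the TMA load of A with per-row cp.async copies whose row indices come from gAIdx, predicated by limit_m/limit_k. Its net effect on the math is just an indexed GEMM; a pure-Python reference of that semantics (illustrative only):

def gather_a_gemm(A, B, gA_idx):
    # C[m] = A[gA_idx[m]] @ B, with -1 meaning "out of bounds / predicated off"
    K = len(B)
    N = len(B[0])
    C = []
    for idx in gA_idx:
        if idx < 0:
            C.append([0.0] * N)
            continue
        row = A[idx]
        C.append([sum(row[k] * B[k][n] for k in range(K)) for n in range(N)])
    return C

A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # (M=3, K=2)
B = [[1.0, 0.0], [0.0, 1.0]]               # (K=2, N=2), identity
print(gather_a_gemm(A, B, gA_idx=[2, 0, -1]))  # [[5.0, 6.0], [1.0, 2.0], [0.0, 0.0]]
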
|
|
806
1326
|
|
|
807
|
-
|
|
808
|
-
|
|
1327
|
+
@cute.jit
|
|
1328
|
+
def mma(
|
|
1329
|
+
self,
|
|
1330
|
+
ab_pipeline: cutlass.pipeline.PipelineAsync,
|
|
1331
|
+
ab_read_state: cutlass.pipeline.PipelineState,
|
|
1332
|
+
tiled_mma: cute.TiledMma,
|
|
1333
|
+
tCrA: cute.Tensor,
|
|
1334
|
+
tCrB: cute.Tensor,
|
|
1335
|
+
acc: cute.Tensor,
|
|
1336
|
+
acc_slow: Optional[cute.Tensor],
|
|
1337
|
+
k_tile_cnt: Int32,
|
|
1338
|
+
warp_group_idx: Int32,
|
|
1339
|
+
) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
|
|
1340
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
1341
|
+
# Prologue MMAs
|
|
1342
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
1343
|
+
k_pipe_mmas = 1
|
|
1344
|
+
ab_release_state = ab_read_state.clone()
|
|
1345
|
+
num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
|
|
1346
|
+
if const_expr(self.pingpong):
|
|
1347
|
+
self.pingpong_barrier_sync(warp_group_idx, stage="mma")
|
|
1348
|
+
peek_ab_full_status = Boolean(True)
|
|
1349
|
+
if 0 < k_tile_cnt:
|
|
1350
|
+
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
|
1351
|
+
tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
|
|
1352
|
+
num_k_blocks = cute.size(tCrA, mode=[2])
|
|
1353
|
+
for k_tile in cutlass.range(num_prologue_mma):
|
|
1354
|
+
# Wait for A/B buffer to be ready
|
|
1355
|
+
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
|
1356
|
+
warpgroup.fence()
|
|
1357
|
+
for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
|
|
1358
|
+
k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
|
|
1359
|
+
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
|
|
1360
|
+
tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
|
|
1361
|
+
warpgroup.commit_group()
|
|
1362
|
+
ab_read_state.advance()
|
|
1363
|
+
peek_ab_full_status = Boolean(True)
|
|
1364
|
+
if k_tile + 1 < k_tile_cnt:
|
|
1365
|
+
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
|
1366
|
+
# If k_tile_cnt == 0, this is not correct, but the caller will set acc to 0
|
|
1367
|
+
# in that case.
|
|
1368
|
+
if const_expr(self.fp8_slow_accum):
|
|
1369
|
+
warpgroup.wait_group(0)
|
|
1370
|
+
acc_slow.store(acc.load())
|
|
809
1371
|
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
1372
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
1373
|
+
# MAINLOOP
|
|
1374
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
1375
|
+
for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
|
|
1376
|
+
# Wait for TMA copies to complete
|
|
1377
|
+
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
|
1378
|
+
# WGMMA
|
|
1379
|
+
warpgroup.fence()
|
|
1380
|
+
if const_expr(self.fp8_slow_accum):
|
|
1381
|
+
tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
|
|
1382
|
+
for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
|
|
1383
|
+
k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
|
|
1384
|
+
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
|
|
1385
|
+
tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
|
|
1386
|
+
warpgroup.commit_group()
|
|
1387
|
+
# Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
|
|
1388
|
+
if const_expr(not self.fp8_slow_accum):
|
|
1389
|
+
warpgroup.wait_group(k_pipe_mmas)
|
|
1390
|
+
else:
|
|
1391
|
+
warpgroup.wait_group(0)
|
|
1392
|
+
acc_slow.store(acc_slow.load() + acc.load())
|
|
1393
|
+
ab_pipeline.consumer_release(ab_release_state)
|
|
1394
|
+
ab_read_state.advance()
|
|
1395
|
+
ab_release_state.advance()
|
|
1396
|
+
peek_ab_full_status = Boolean(True)
|
|
1397
|
+
if k_tile + 1 < k_tile_cnt:
|
|
1398
|
+
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
|
1399
|
+
if const_expr(self.pingpong):
|
|
1400
|
+
# Cue for next WG's MMA to start
|
|
1401
|
+
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
|
|
1402
|
+
if const_expr(not self.fp8_slow_accum):
|
|
1403
|
+
# fp8_slow_accum would have already called wait_group(0) inside the loop
|
|
1404
|
+
warpgroup.wait_group(0)
|
|
1405
|
+
for k_tile in cutlass.range(num_prologue_mma, unroll=1):
|
|
1406
|
+
ab_pipeline.consumer_release(ab_release_state)
|
|
1407
|
+
ab_release_state.advance()
|
|
1408
|
+
if const_expr(self.fp8_slow_accum):
|
|
1409
|
+
acc.store(acc_slow.load())
|
|
1410
|
+
# If we don't return the tiled_mma, we get the compiler error
|
|
1411
|
+
# "operand #0 does not dominate this use"
|
|
1412
|
+
return ab_read_state, tiled_mma
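
When fp8_slow_accum is enabled, the WGMMA accumulator is restarted every k-tile and drained into acc_slow, trading a few extra adds for better accumulation behaviour on FP8 inputs. A scalar sketch of that two-level accumulation pattern (illustrative; one "MMA" here is just a dot product over one k-tile):

def two_level_accumulate(a_tiles, b_tiles):
    acc_slow = 0.0
    for a_tile, b_tile in zip(a_tiles, b_tiles):
        acc = sum(x * y for x, y in zip(a_tile, b_tile))  # fresh per-tile "MMA" result
        acc_slow += acc                                    # drained after every k-tile
    return acc_slow

a_tiles = [[1.0, 2.0], [3.0, 4.0]]
b_tiles = [[5.0, 6.0], [7.0, 8.0]]
print(two_level_accumulate(a_tiles, b_tiles))  # 1*5 + 2*6 + 3*7 + 4*8 = 70.0
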
|
|
1413
|
+
|
|
1414
|
+
@cute.jit
|
|
1415
|
+
def epilogue(
|
|
1416
|
+
self,
|
|
1417
|
+
params: EpilogueParams,
|
|
1418
|
+
epi_smem_tensors: Tuple[cute.Tensor, ...],
|
|
1419
|
+
epi_pipeline: cutlass.pipeline.PipelineAsync,
|
|
1420
|
+
epi_read_state: cutlass.pipeline.PipelineState,
|
|
1421
|
+
epi_producer_state: cutlass.pipeline.PipelineState,
|
|
1422
|
+
tiled_mma: cute.TiledMma,
|
|
1423
|
+
tRS_rAcc: cute.Tensor,
|
|
1424
|
+
tRS_rD: cute.Tensor,
|
|
1425
|
+
tRS_rC: Optional[cute.Tensor],
|
|
1426
|
+
tiled_copy_r2s: cute.core.ThrCopy,
|
|
1427
|
+
tRS_sD: cute.Tensor,
|
|
1428
|
+
tiled_copy_s2r: Optional[cute.core.ThrCopy],
|
|
1429
|
+
tSR_rC: Optional[cute.Tensor],
|
|
1430
|
+
tSR_sC: Optional[cute.Tensor],
|
|
1431
|
+
copy_D: Optional[Callable],
|
|
1432
|
+
bSG_sD: cute.Tensor,
|
|
1433
|
+
bSG_gD: cute.Tensor,
|
|
1434
|
+
epi_load_g2s: Optional[Callable],
|
|
1435
|
+
tile_coord_mnkl: cute.Coord,
|
|
1436
|
+
cu_seqlens_m: Optional[cute.Tensor],
|
|
1437
|
+
epilogue_barrier: cutlass.pipeline.NamedBarrier,
|
|
1438
|
+
tile_scheduler,
|
|
1439
|
+
tidx: Int32,
|
|
1440
|
+
is_tma_warp: Boolean,
|
|
1441
|
+
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
|
|
1442
|
+
has_C = const_expr(tRS_rC is not None)
|
|
1443
|
+
has_D = const_expr(copy_D is not None)
|
|
1444
|
+
# We iterate over epi tiles along the N dimension first, then along the M dimension
|
|
1445
|
+
epi_tile_shape = cute.zipped_divide(
|
|
1446
|
+
cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
|
|
1447
|
+
).shape[1]
|
|
1448
|
+
epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
|
|
1449
|
+
epi_tile_num = cute.size(epi_tile_shape)
|
|
1450
|
+
num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
|
|
1451
|
+
|
|
1452
|
+
if const_expr(epi_load_g2s is not None):
|
|
1453
|
+
for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
|
|
1454
|
+
epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
|
|
1455
|
+
|
|
1456
|
+
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
|
1457
|
+
# Copy from acc to D registers
|
|
1458
|
+
for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
|
|
1459
|
+
tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
|
|
1460
|
+
if const_expr(has_C):
|
|
1461
|
+
epi_pipeline.consumer_wait(epi_read_state)
|
|
1462
|
+
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
|
1463
|
+
# Fence to make sure shared memory read is visible to TMA load
|
|
822
1464
|
cute.arch.fence_proxy(
|
|
823
1465
|
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
824
1466
|
)
|
|
825
|
-
|
|
826
|
-
cute.arch.
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
1467
|
+
cute.arch.sync_warp()
|
|
1468
|
+
with cute.arch.elect_one():
|
|
1469
|
+
epi_pipeline.consumer_release(epi_read_state)
|
|
1470
|
+
epi_read_state.advance()
|
|
1471
|
+
if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
|
|
1472
|
+
epi_producer_state = epi_load_g2s(
|
|
1473
|
+
epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
|
|
1474
|
+
)
|
|
1475
|
+
tRS_rEpi = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
|
|
1476
|
+
epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
|
|
1477
|
+
# Copy from D registers to shared memory
|
|
1478
|
+
if const_expr(has_D):
|
|
1479
|
+
# Type conversion
|
|
1480
|
+
tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
|
|
1481
|
+
tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
|
|
1482
|
+
cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
|
|
1483
|
+
# Fence and barrier to make sure shared memory store is visible to TMA store
|
|
1484
|
+
cute.arch.fence_proxy(
|
|
1485
|
+
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
1486
|
+
)
|
|
1487
|
+
epilogue_barrier.arrive_and_wait()
|
|
1488
|
+
# Get the global memory coordinate for the current epi tile
|
|
1489
|
+
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
|
1490
|
+
# Copy from shared memory to global memory
|
|
1491
|
+
if is_tma_warp:
|
|
1492
|
+
if const_expr(has_D):
|
|
1493
|
+
copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
|
|
1494
|
+
cute.arch.cp_async_bulk_commit_group()
|
|
1495
|
+
cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
|
|
1496
|
+
epilogue_barrier.arrive_and_wait()
|
|
1497
|
+
|
|
1498
|
+
return epi_read_state, epi_producer_state
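
The epilogue walks the epi sub-tiles of the CTA tile with N varying fastest, which matches the register layout of the accumulator. A small sketch of the epi_idx -> (m, n) mapping implied by the layout with stride (epi_tile_shape[1], 1) above (illustrative; get_hier_coord reduces to divmod for this particular layout):

def epi_subtile_coords(tile_mn, epi_tile_mn):
    num_m = tile_mn[0] // epi_tile_mn[0]
    num_n = tile_mn[1] // epi_tile_mn[1]
    return [divmod(i, num_n) for i in range(num_m * num_n)]  # (m, n) per epi_idx

# e.g. a 128x128 CTA tile with 64x32 epi tiles -> 2x4 sub-tiles, N varies fastest
print(epi_subtile_coords((128, 128), (64, 32)))
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]
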
|
|
1499
|
+
|
|
1500
|
+
@cute.jit
|
|
1501
|
+
def epi_load_g2s(
|
|
1502
|
+
self,
|
|
1503
|
+
epi_pipeline: cutlass.pipeline.PipelineAsync,
|
|
1504
|
+
copy_C: Callable,
|
|
1505
|
+
bGS_gC: cute.Tensor,
|
|
1506
|
+
bGS_sC: cute.Tensor,
|
|
1507
|
+
epi_producer_state: cutlass.pipeline.PipelineState,
|
|
1508
|
+
epi_idx: Int32,
|
|
1509
|
+
should_load: Boolean,
|
|
1510
|
+
) -> cutlass.pipeline.PipelineState:
|
|
1511
|
+
# We iterate over epi tiles along the N dimension first, then along the M dimension
|
|
1512
|
+
epi_tile_layout = cute.make_layout(bGS_gC.shape[1], stride=(bGS_gC.shape[1][1], 1))
|
|
1513
|
+
if should_load:
|
|
1514
|
+
epi_pipeline.producer_acquire(epi_producer_state)
|
|
1515
|
+
# Get the global memory coordinate for the current epi tile
|
|
1516
|
+
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
|
1517
|
+
copy_C(
|
|
1518
|
+
bGS_gC[None, gmem_coord],
|
|
1519
|
+
bGS_sC[None, epi_producer_state.index],
|
|
1520
|
+
tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
|
|
1521
|
+
)
|
|
1522
|
+
# Epi pipeline's producer commit is a NOP
|
|
1523
|
+
epi_pipeline.producer_commit(epi_producer_state)
|
|
1524
|
+
epi_producer_state.advance()
|
|
1525
|
+
return epi_producer_state
|
|
1526
|
+
|
|
1527
|
+
def epi_visit_acc_subtile(
|
|
1528
|
+
self,
|
|
1529
|
+
params: EpilogueParams,
|
|
1530
|
+
tRS_rD: cute.Tensor,
|
|
1531
|
+
tRS_rC: Optional[cute.Tensor] = None,
|
|
1532
|
+
) -> Optional[cute.Tensor]:
|
|
1533
|
+
# Apply alpha scaling to accumulator if alpha is provided (not None)
|
|
1534
|
+
if const_expr(hasattr(params, "alpha") and params.alpha is not None):
|
|
1535
|
+
alpha = utils.load_scalar_or_pointer(params.alpha)
|
|
1536
|
+
tRS_rD.store(tRS_rD.load() * alpha)
|
|
1537
|
+
# Apply C with beta scaling
|
|
1538
|
+
if const_expr(tRS_rC is not None):
|
|
1539
|
+
if const_expr(not hasattr(params, "beta") or params.beta is None):
|
|
1540
|
+
# beta is None, default behavior: add C (beta=1.0)
|
|
1541
|
+
tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
|
|
1542
|
+
else:
|
|
1543
|
+
beta = utils.load_scalar_or_pointer(params.beta)
|
|
1544
|
+
tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
|
|
1545
|
+
return None
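
Element-wise, epi_visit_acc_subtile computes the familiar GEMM epilogue D = alpha * acc + beta * C, where a missing alpha means "no scaling" and a missing beta defaults to 1.0. A plain-Python reference of that logic (illustrative only):

def epilogue_reference(acc, c=None, alpha=None, beta=None):
    d = acc if alpha is None else alpha * acc
    if c is not None:
        d = d + (1.0 if beta is None else beta) * c
    return d

print(epilogue_reference(2.0))                               # 2.0
print(epilogue_reference(2.0, c=3.0))                        # 5.0  (beta defaults to 1)
print(epilogue_reference(2.0, c=3.0, alpha=0.5, beta=2.0))   # 7.0
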
|
|
1546
|
+
|
|
1547
|
+
def get_scheduler_class(self):
|
|
1548
|
+
"""Return the scheduler class to use. Override in subclasses for custom schedulers."""
|
|
1549
|
+
return TileScheduler
|
|
1550
|
+
|
|
1551
|
+
def get_scheduler_arguments(self, problem_shape_ntile_mnl, scheduler_args):
|
|
1552
|
+
"""Create scheduler arguments. Override in subclasses for custom schedulers."""
|
|
1553
|
+
return TileSchedulerArguments(
|
|
1554
|
+
problem_shape_ntile_mnl=problem_shape_ntile_mnl,
|
|
1555
|
+
raster_order=scheduler_args.raster_order,
|
|
1556
|
+
group_size=scheduler_args.max_swizzle_size,
|
|
1557
|
+
cluster_shape_mnk=self.cluster_shape_mnk,
|
|
1558
|
+
tile_count_semaphore=scheduler_args.tile_count_semaphore,
|
|
1559
|
+
batch_idx_permute=scheduler_args.batch_idx_permute,
|
|
1560
|
+
is_persistent=self.is_persistent,
|
|
1561
|
+
)
|
|
1562
|
+
|
|
1563
|
+
def epi_visit_acc(
|
|
1564
|
+
self,
|
|
1565
|
+
params: EpilogueParams,
|
|
1566
|
+
acc: cute.Tensor,
|
|
1567
|
+
tiled_mma: cute.TiledMma,
|
|
1568
|
+
tile_coord_mnkl: cute.Coord,
|
|
1569
|
+
tidx: Int32,
|
|
1570
|
+
) -> None:
|
|
1571
|
+
pass
|
|
1572
|
+
|
|
1573
|
+
def epi_to_underlying_arguments(
|
|
1574
|
+
self, args: EpilogueArguments, *, loc=None, ip=None
|
|
1575
|
+
) -> EpilogueParams:
|
|
1576
|
+
return GemmSm90.EpilogueParams(alpha=args.alpha, beta=args.beta)
|
|
842
1577
|
|
|
843
1578
|
@staticmethod
|
|
1579
|
+
def epi_smem_bytes_per_stage(
|
|
1580
|
+
args: Optional[EpilogueArguments],
|
|
1581
|
+
tile_shape_mnk: Tuple[int, int, int],
|
|
1582
|
+
epi_tile: Tuple[int, int],
|
|
1583
|
+
) -> int:
|
|
1584
|
+
return 0
|
|
1585
|
+
|
|
1586
|
+
def epi_get_smem_struct(self, params: EpilogueParams):
|
|
1587
|
+
return cute.struct.MemRange[cutlass.Int32, 0] # Dummy struct
|
|
1588
|
+
|
|
1589
|
+
def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
|
|
1590
|
+
return tuple()
|
|
1591
|
+
|
|
1592
|
+
def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
|
|
1593
|
+
assert stage in ["mma", "epi"]
|
|
1594
|
+
barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
|
|
1595
|
+
cute.arch.barrier(
|
|
1596
|
+
barrier_id=int(barrier) + warp_group_idx,
|
|
1597
|
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
|
1598
|
+
)
|
|
1599
|
+
|
|
1600
|
+
def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
|
|
1601
|
+
assert stage in ["mma", "epi"]
|
|
1602
|
+
barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
|
|
1603
|
+
cute.arch.barrier_arrive(
|
|
1604
|
+
barrier_id=int(barrier) + warp_group_idx,
|
|
1605
|
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
|
1606
|
+
)
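
Both ping-pong barrier helpers derive the barrier id by adding warp_group_idx to a per-stage base id, so WG0 and WG1 each get their own named barrier for the "mma" and "epi" hand-offs. A toy sketch of that addressing; the numeric base values below are assumptions for illustration, not the real NamedBarrierGemm enum:

TOY_BARRIER_BASE = {"mma": 2, "epi": 4}   # e.g. MmaWG0=2, MmaWG1=3, EpiWG0=4, EpiWG1=5

def pingpong_barrier_id(stage: str, warp_group_idx: int) -> int:
    assert stage in ("mma", "epi") and warp_group_idx in (0, 1)
    return TOY_BARRIER_BASE[stage] + warp_group_idx

print([pingpong_barrier_id(s, wg) for s in ("mma", "epi") for wg in (0, 1)])
# [2, 3, 4, 5]
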
|
|
1607
|
+
|
|
1608
|
+
def epilog_smem_copy_atom(self, tiled_mma: cute.TiledMma) -> cute.TiledCopy:
|
|
1609
|
+
copy_atom_C = cute.make_copy_atom(
|
|
1610
|
+
warp.StMatrix8x8x16bOp(
|
|
1611
|
+
self.d_layout.is_m_major_c() if self.d_layout is not None else False,
|
|
1612
|
+
num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
|
|
1613
|
+
),
|
|
1614
|
+
cutlass.Float16, # this is just to get the right source layout
|
|
1615
|
+
)
|
|
1616
|
+
tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
|
|
1617
|
+
return tiled_copy_C_atom
|
|
1618
|
+
|
|
1619
|
+
def epilog_smem_store_and_partition(
|
|
1620
|
+
self,
|
|
1621
|
+
tiled_mma: cute.TiledMma,
|
|
1622
|
+
d_layout: Optional[LayoutEnum],
|
|
1623
|
+
dtype: Type[cutlass.Numeric],
|
|
1624
|
+
acc: cute.Tensor,
|
|
1625
|
+
sD: cute.Tensor,
|
|
1626
|
+
tidx: Int32,
|
|
1627
|
+
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
1628
|
+
if d_layout is None:
|
|
1629
|
+
d_layout = LayoutEnum.ROW_MAJOR
|
|
1630
|
+
tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
|
|
1631
|
+
# Doesn't work when tile_N % 8 == 0 but tile_N % 16 != 0, since this always
|
|
1632
|
+
# gets st.matrix with num_matrices=4
|
|
1633
|
+
copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
|
|
1634
|
+
d_layout, elem_ty_d=dtype, elem_ty_acc=self.acc_dtype
|
|
1635
|
+
)
|
|
1636
|
+
tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
|
|
1637
|
+
# (R2S, R2S_M, R2S_N, PIPE_D)
|
|
1638
|
+
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
|
1639
|
+
tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
|
|
1640
|
+
# (R2S, R2S_M, R2S_N)
|
|
1641
|
+
tRS_rAcc = tiled_copy_r2s.retile(acc)
|
|
1642
|
+
sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
|
|
1643
|
+
tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
|
|
1644
|
+
tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
|
|
1645
|
+
return tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD
|
|
1646
|
+
|
|
1647
|
+
def epilog_smem_load_and_partition(
|
|
1648
|
+
self,
|
|
1649
|
+
tiled_mma: cute.TiledMma,
|
|
1650
|
+
c_layout: LayoutEnum,
|
|
1651
|
+
dtype: Type[cutlass.Numeric],
|
|
1652
|
+
sC: cute.Tensor,
|
|
1653
|
+
tRS_rD_layout: cutlass.Layout,
|
|
1654
|
+
tidx: Int32,
|
|
1655
|
+
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
1656
|
+
tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
|
|
1657
|
+
copy_atom_s2r = utils.sm90_get_smem_load_op(c_layout, dtype)
|
|
1658
|
+
tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
|
|
1659
|
+
thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
|
|
1660
|
+
tSR_sC = thr_copy_s2r.partition_S(sC)
|
|
1661
|
+
tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
|
|
1662
|
+
tSR_rC = thr_copy_s2r.retile(tRS_rC)
|
|
1663
|
+
return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
|
|
1664
|
+
|
|
1665
|
+
def epilog_gmem_copy_and_partition(
|
|
1666
|
+
self,
|
|
1667
|
+
atom: Union[cute.CopyAtom, cute.TiledCopy],
|
|
1668
|
+
mD_mnl: cute.Tensor,
|
|
1669
|
+
tile_shape_mn: cute.Tile,
|
|
1670
|
+
epi_tile: cute.Tile,
|
|
1671
|
+
sD: cute.Tensor,
|
|
1672
|
+
tile_coord_mnkl: cute.Coord,
|
|
1673
|
+
cu_seqlens_m: Optional[cute.Tensor] = None,
|
|
1674
|
+
) -> Tuple[cute.Tensor, cute.Tensor]:
|
|
1675
|
+
batch_idx = tile_coord_mnkl[3]
|
|
1676
|
+
if const_expr(cu_seqlens_m is not None):
|
|
1677
|
+
mD_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl)
|
|
1678
|
+
else:
|
|
1679
|
+
mD_mn = mD_mnl[None, None, batch_idx]
|
|
1680
|
+
# (bM, bN)
|
|
1681
|
+
gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
|
|
1682
|
+
tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
|
|
1683
|
+
bSG_sD, bSG_gD = cpasync.tma_partition(
|
|
1684
|
+
atom,
|
|
1685
|
+
0,
|
|
1686
|
+
cute.make_layout(1),
|
|
1687
|
+
cute.group_modes(sD, 0, 2),
|
|
1688
|
+
tDgD_for_tma_partition,
|
|
1689
|
+
)
|
|
1690
|
+
return bSG_sD, bSG_gD
|
|
1691
|
+
|
|
1692
|
+
def make_ab_pipeline(
|
|
1693
|
+
self,
|
|
1694
|
+
a_smem_layout: cute.Layout | cute.ComposedLayout,
|
|
1695
|
+
b_smem_layout: cute.Layout | cute.ComposedLayout,
|
|
1696
|
+
tiled_mma: cute.TiledMma,
|
|
1697
|
+
cluster_layout_vmnk: cute.Layout,
|
|
1698
|
+
ab_pipeline_mbar_ptr: cute.Pointer,
|
|
1699
|
+
):
|
|
1700
|
+
# Threads/warps participating in this pipeline
|
|
1701
|
+
producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
|
|
1702
|
+
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
|
|
1703
|
+
# Each warp contributes the mcast size to the arrive count
|
|
1704
|
+
mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
|
|
1705
|
+
consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
|
|
1706
|
+
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
1707
|
+
pipeline.Agent.Thread, consumer_arrive_cnt
|
|
1708
|
+
)
|
|
1709
|
+
pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
|
|
1710
|
+
tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
|
|
1711
|
+
if const_expr(not self.gather_A):
|
|
1712
|
+
tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
|
|
1713
|
+
return pipeline_cls.create(
|
|
1714
|
+
barrier_storage=ab_pipeline_mbar_ptr,
|
|
1715
|
+
num_stages=self.ab_stage,
|
|
1716
|
+
producer_group=ab_pipeline_producer_group,
|
|
1717
|
+
consumer_group=ab_pipeline_consumer_group,
|
|
1718
|
+
tx_count=tma_copy_bytes,
|
|
1719
|
+
cta_layout_vmnk=cluster_layout_vmnk,
|
|
1720
|
+
)
|
|
1721
|
+
|
|
1722
|
+
def make_epi_pipeline(
|
|
1723
|
+
self, c_smem_layout: cute.Layout | cute.ComposedLayout, epi_pipeline_mbar_ptr: cute.Pointer
|
|
1724
|
+
):
|
|
1725
|
+
# Threads/warps participating in this pipeline
|
|
1726
|
+
epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
1727
|
+
# Each warp will contribute 1 to the arrive count
|
|
1728
|
+
consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
|
|
1729
|
+
epi_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
1730
|
+
pipeline.Agent.Thread, consumer_arrive_cnt
|
|
1731
|
+
)
|
|
1732
|
+
tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
|
|
1733
|
+
return pipeline.PipelineTmaAsync.create(
|
|
1734
|
+
barrier_storage=epi_pipeline_mbar_ptr,
|
|
1735
|
+
num_stages=self.epi_c_stage,
|
|
1736
|
+
producer_group=epi_pipeline_producer_group,
|
|
1737
|
+
consumer_group=epi_pipeline_consumer_group,
|
|
1738
|
+
tx_count=tma_copy_c_bytes,
|
|
1739
|
+
)
|
|
1740
|
+
|
|
1741
|
+
def make_sched_pipeline(
|
|
1742
|
+
self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
|
|
1743
|
+
):
|
|
1744
|
+
# Threads/warps participating in this pipeline
|
|
1745
|
+
sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
1746
|
+
cluster_size = cute.size(cluster_layout_mnk)
|
|
1747
|
+
# Each warp that is not the scheduler warp will contribute 1 to the arrive count
|
|
1748
|
+
# If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
|
|
1749
|
+
# at each round. If pingpong and not varlen_k, then only 4 mma warps will participate.
|
|
1750
|
+
consumer_arrive_cnt = (
|
|
1751
|
+
(self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
|
|
1752
|
+
+ self.num_ab_load_warps
|
|
1753
|
+
) * cluster_size - 1
|
|
1754
|
+
sched_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
1755
|
+
pipeline.Agent.Thread, consumer_arrive_cnt
|
|
1756
|
+
)
|
|
1757
|
+
return pipeline.PipelineAsync.create(
|
|
1758
|
+
barrier_storage=sched_pipeline_mbar_ptr,
|
|
1759
|
+
num_stages=self.sched_stage,
|
|
1760
|
+
producer_group=sched_pipeline_producer_group,
|
|
1761
|
+
consumer_group=sched_pipeline_consumer_group,
|
|
1762
|
+
# If there's a cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
|
|
1763
|
+
consumer_mask=None if const_expr(cluster_size == 1) else 0,
|
|
1764
|
+
)
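
The arrive count above says: every non-scheduler warp in every CTA of the cluster arrives once per round, and with ping-pong over a fixed K only one math warp group (4 warps) takes part each round. A worked example of that formula under assumed configurations (illustrative numbers only):

def sched_consumer_arrive_cnt(mma_warp_groups, pingpong, varlen_k,
                              num_ab_load_warps, cluster_size):
    participating_wgs = mma_warp_groups if not (pingpong and not varlen_k) else 1
    return (participating_wgs * 4 + num_ab_load_warps) * cluster_size - 1

# ping-pong (2 math warp groups), fixed K, 1 load warp, no cluster:
print(sched_consumer_arrive_cnt(2, True, False, 1, 1))  # (1*4 + 1)*1 - 1 = 4
# same config but variable-length K, so both math warp groups take part:
print(sched_consumer_arrive_cnt(2, True, True, 1, 1))   # (2*4 + 1)*1 - 1 = 8
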
|
|
1765
|
+
|
|
1766
|
+
@classmethod
|
|
844
1767
|
def _compute_stages(
|
|
1768
|
+
cls,
|
|
845
1769
|
tile_shape_mnk: Tuple[int, int, int],
|
|
1770
|
+
epi_tile: Tuple[int, int],
|
|
846
1771
|
a_dtype: Type[cutlass.Numeric],
|
|
847
1772
|
b_dtype: Type[cutlass.Numeric],
|
|
1773
|
+
d_dtype: Optional[Type[cutlass.Numeric]],
|
|
1774
|
+
c_dtype: Optional[Type[cutlass.Numeric]],
|
|
1775
|
+
epilogue_args: Optional[EpilogueArguments],
|
|
848
1776
|
smem_capacity: int,
|
|
849
1777
|
occupancy: int,
|
|
1778
|
+
overlap_sD_sA: bool,
|
|
850
1779
|
) -> Tuple[int, int]:
|
|
851
1780
|
"""Computes the number of stages for A/B/C operands based on heuristics.
|
|
852
1781
|
|
|
@@ -866,10 +1795,20 @@ class HopperWgmmaGemmKernel:
|
|
|
866
1795
|
:rtype: Tuple[int, int]
|
|
867
1796
|
"""
|
|
868
1797
|
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
1798
|
+
epi_stage = 4 if epi_tile[1] <= 16 else 2
|
|
1799
|
+
if overlap_sD_sA:
|
|
1800
|
+
epi_bytes = 0
|
|
1801
|
+
else:
|
|
1802
|
+
d_bytes_per_stage = (
|
|
1803
|
+
cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
|
|
1804
|
+
)
|
|
1805
|
+
epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
|
|
1806
|
+
epilogue_args, tile_shape_mnk, epi_tile
|
|
1807
|
+
)
|
|
1808
|
+
epi_bytes = epi_bytes_per_stage * epi_stage
|
|
1809
|
+
epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
|
|
1810
|
+
if c_dtype is not None:
|
|
1811
|
+
epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
|
|
873
1812
|
|
|
874
1813
|
a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
|
|
875
1814
|
b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
|
|
@@ -878,16 +1817,21 @@ class HopperWgmmaGemmKernel:
|
|
|
878
1817
|
)
|
|
879
1818
|
mbar_helpers_bytes = 1024
|
|
880
1819
|
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
1820
|
+
remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
|
|
1821
|
+
ab_stage = remaining_bytes // ab_bytes_per_stage
|
|
1822
|
+
|
|
1823
|
+
# Refine epilogue stages:
|
|
1824
|
+
# Calculate remaining smem after allocating for A/B stages and reserved bytes
|
|
1825
|
+
# Add remaining unused smem to epilogue
|
|
1826
|
+
if not overlap_sD_sA and epi_bytes_per_stage > 0:
|
|
1827
|
+
epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
|
|
1828
|
+
return ab_stage, epi_stage, epi_c_stage
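
A worked example of this stage heuristic with assumed numbers (it mirrors the non-overlapping branch above): a 128x256x64 bf16 tile costs 48 KiB per A+B stage, and with roughly 228 KiB of smem, 1 KiB of barrier slack and a 32 KiB epilogue budget that leaves room for 4 A/B stages; any smem still left over is folded back into extra epilogue stages.

def compute_stages(ab_bytes_per_stage, epi_bytes_per_stage, epi_stage,
                   smem_capacity, occupancy=1, mbar_helpers_bytes=1024):
    epi_bytes = epi_bytes_per_stage * epi_stage
    remaining = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
    ab_stage = remaining // ab_bytes_per_stage
    epi_stage += (remaining - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
    return ab_stage, epi_stage

print(compute_stages(ab_bytes_per_stage=48 * 1024,
                     epi_bytes_per_stage=16 * 1024,
                     epi_stage=2,
                     smem_capacity=228 * 1024))  # -> (4, 2)
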
|
|
885
1829
|
|
|
886
1830
|
@staticmethod
|
|
887
1831
|
def _sm90_compute_tile_shape_or_override(
|
|
888
1832
|
tile_shape_mnk: Tuple[int, int, int],
|
|
889
|
-
|
|
890
|
-
|
|
1833
|
+
atom_layout_mnk: Tuple[int, int, int],
|
|
1834
|
+
element_type: Optional[Type[cutlass.Numeric]] = None,
|
|
891
1835
|
epi_tile_override: Tuple[int, int] | None = None,
|
|
892
1836
|
) -> Tuple[int, int]:
|
|
893
1837
|
"""Compute the epilogue tile shape or use override if provided.
|
|
@@ -906,33 +1850,42 @@ class HopperWgmmaGemmKernel:
|
|
|
906
1850
|
"""
|
|
907
1851
|
if epi_tile_override is not None:
|
|
908
1852
|
return epi_tile_override
|
|
909
|
-
if
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
|
|
916
|
-
return (tile_m, tile_n)
|
|
1853
|
+
if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
|
|
1854
|
+
tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
|
|
1855
|
+
tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
|
|
1856
|
+
elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
|
|
1857
|
+
tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0]))
|
|
1858
|
+
tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
|
|
917
1859
|
else:
|
|
918
|
-
|
|
1860
|
+
# In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
|
|
1861
|
+
# epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
|
|
1862
|
+
# M dimension first, then move to the N dimension. But the accumulator in registers
|
|
1863
|
+
# iterates along the N dimension first, then moves to the M dimension.
|
|
1864
|
+
# We could change the epilogue to accommodate this,
|
|
1865
|
+
# but it's easier to just set epi_tile_m = 64.
|
|
1866
|
+
n_perf = 64 if element_type is not None and element_type.width == 8 else 32
|
|
919
1867
|
tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
|
|
920
1868
|
tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
|
|
921
|
-
|
|
1869
|
+
return (tile_m, tile_n)
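
The gcd calls above clamp the preferred epi tile so it always divides the CTA tile. A tiny worked example of just that clamping under the default 64x32 preference (illustrative; the real method also has the 128-/192-row and FP8 special cases):

import math

def pick_epi_tile(tile_m, tile_n, pref_m=64, pref_n=32):
    return math.gcd(pref_m, tile_m), math.gcd(pref_n, tile_n)

print(pick_epi_tile(128, 256))  # (64, 32)
print(pick_epi_tile(192, 208))  # (64, 16): 208 is not a multiple of 32
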
|
|
922
1870
|
|
|
923
1871
|
@staticmethod
|
|
924
1872
|
def _make_smem_layouts(
|
|
925
1873
|
tile_shape_mnk: Tuple[int, int, int],
|
|
926
1874
|
epi_tile: Tuple[int, int],
|
|
927
1875
|
a_dtype: Type[cutlass.Numeric],
|
|
928
|
-
a_layout:
|
|
1876
|
+
a_layout: LayoutEnum,
|
|
929
1877
|
b_dtype: Type[cutlass.Numeric],
|
|
930
|
-
b_layout:
|
|
1878
|
+
b_layout: LayoutEnum,
|
|
931
1879
|
ab_stage: int,
|
|
932
|
-
d_dtype: Type[cutlass.Numeric],
|
|
933
|
-
d_layout:
|
|
1880
|
+
d_dtype: Optional[Type[cutlass.Numeric]],
|
|
1881
|
+
d_layout: LayoutEnum,
|
|
934
1882
|
epi_stage: int,
|
|
935
|
-
|
|
1883
|
+
c_dtype: Optional[Type[cutlass.Numeric]],
|
|
1884
|
+
c_layout: Optional[LayoutEnum],
|
|
1885
|
+
epi_c_stage: int,
|
|
1886
|
+
) -> Tuple[
|
|
1887
|
+
cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
|
|
1888
|
+
]:
|
|
936
1889
|
"""Create shared memory layouts for A, B, and C tensors.
|
|
937
1890
|
|
|
938
1891
|
:param tile_shape_mnk: CTA tile shape (M,N,K)
|
|
@@ -942,17 +1895,17 @@ class HopperWgmmaGemmKernel:
|
|
|
942
1895
|
:param a_dtype: Data type for matrix A
|
|
943
1896
|
:type a_dtype: type[cutlass.Numeric]
|
|
944
1897
|
:param a_layout: Layout enum for matrix A
|
|
945
|
-
:type a_layout:
|
|
1898
|
+
:type a_layout: LayoutEnum
|
|
946
1899
|
:param b_dtype: Data type for matrix B
|
|
947
1900
|
:type b_dtype: type[cutlass.Numeric]
|
|
948
1901
|
:param b_layout: Layout enum for matrix B
|
|
949
|
-
:type b_layout:
|
|
1902
|
+
:type b_layout: LayoutEnum
|
|
950
1903
|
:param ab_stage: Number of stages for A/B tensors
|
|
951
1904
|
:type ab_stage: int
|
|
952
|
-
:param d_dtype: Data type for output matrix
|
|
1905
|
+
:param d_dtype: Data type for output matrix D
|
|
953
1906
|
:type d_dtype: type[cutlass.Numeric]
|
|
954
1907
|
:param d_layout: Layout enum for the output matrix D
|
|
955
|
-
:type d_layout:
|
|
1908
|
+
:type d_layout: LayoutEnum
|
|
956
1909
|
:param epi_stage: Number of epilogue stages
|
|
957
1910
|
:type epi_stage: int
|
|
958
1911
|
|
|
@@ -965,11 +1918,7 @@ class HopperWgmmaGemmKernel:
|
|
|
965
1918
|
b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
|
|
966
1919
|
a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
|
|
967
1920
|
a_smem_layout_atom = warpgroup.make_smem_layout_atom(
|
|
968
|
-
sm90_utils.get_smem_layout_atom(
|
|
969
|
-
a_layout,
|
|
970
|
-
a_dtype,
|
|
971
|
-
a_major_mode_size,
|
|
972
|
-
),
|
|
1921
|
+
sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
|
|
973
1922
|
a_dtype,
|
|
974
1923
|
)
|
|
975
1924
|
a_smem_layout_staged = cute.tile_to_shape(
|
|
@@ -982,11 +1931,7 @@ class HopperWgmmaGemmKernel:
|
|
|
982
1931
|
|
|
983
1932
|
b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
|
|
984
1933
|
b_smem_layout_atom = warpgroup.make_smem_layout_atom(
|
|
985
|
-
sm90_utils.get_smem_layout_atom(
|
|
986
|
-
b_layout,
|
|
987
|
-
b_dtype,
|
|
988
|
-
b_major_mode_size,
|
|
989
|
-
),
|
|
1934
|
+
sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
|
|
990
1935
|
b_dtype,
|
|
991
1936
|
)
|
|
992
1937
|
b_smem_layout_staged = cute.tile_to_shape(
|
|
@@ -995,56 +1940,52 @@ class HopperWgmmaGemmKernel:
|
|
|
995
1940
|
order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
|
|
996
1941
|
)
|
|
997
1942
|
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
d_layout,
|
|
1943
|
+
if d_dtype is not None:
|
|
1944
|
+
d_smem_shape = epi_tile
|
|
1945
|
+
d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
|
|
1946
|
+
d_smem_layout_atom = warpgroup.make_smem_layout_atom(
|
|
1947
|
+
sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
|
|
1003
1948
|
d_dtype,
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
:
|
|
1027
|
-
|
|
1028
|
-
:param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions.
|
|
1029
|
-
:type cluster_shape_mnk: Tuple[int, int, int]
|
|
1030
|
-
|
|
1031
|
-
:return: Grid shape for kernel launch.
|
|
1032
|
-
:rtype: Tuple[int, int, int]
|
|
1033
|
-
"""
|
|
1949
|
+
)
|
|
1950
|
+
epi_smem_layout_staged = cute.tile_to_shape(
|
|
1951
|
+
d_smem_layout_atom,
|
|
1952
|
+
cute.append(d_smem_shape, epi_stage),
|
|
1953
|
+
order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
|
|
1954
|
+
)
|
|
1955
|
+
else:
|
|
1956
|
+
epi_smem_layout_staged = None
|
|
1957
|
+
|
|
1958
|
+
if c_dtype is not None:
|
|
1959
|
+
assert c_layout is not None
|
|
1960
|
+
c_smem_shape = epi_tile
|
|
1961
|
+
c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0]
|
|
1962
|
+
c_smem_layout_atom = warpgroup.make_smem_layout_atom(
|
|
1963
|
+
sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
|
|
1964
|
+
c_dtype,
|
|
1965
|
+
)
|
|
1966
|
+
epi_c_smem_layout_staged = cute.tile_to_shape(
|
|
1967
|
+
c_smem_layout_atom,
|
|
1968
|
+
cute.append(c_smem_shape, epi_c_stage),
|
|
1969
|
+
order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
|
|
1970
|
+
)
|
|
1971
|
+
else:
|
|
1972
|
+
epi_c_smem_layout_staged = None
|
|
1034
1973
|
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1974
|
+
return (
|
|
1975
|
+
a_smem_layout_staged,
|
|
1976
|
+
b_smem_layout_staged,
|
|
1977
|
+
epi_smem_layout_staged,
|
|
1978
|
+
epi_c_smem_layout_staged,
|
|
1979
|
+
)
|
|
1040
1980
|
|
|
1041
1981
|
@staticmethod
|
|
1042
|
-
def
|
|
1982
|
+
def _make_tma_epi_atoms_and_tensors(
|
|
1043
1983
|
tensor_d: cute.Tensor,
|
|
1044
1984
|
epi_smem_layout_staged: cute.ComposedLayout,
|
|
1045
1985
|
epi_tile: Tuple[int, int],
|
|
1986
|
+
store_or_load: str,
|
|
1046
1987
|
) -> Tuple[cute.CopyAtom, cute.Tensor]:
|
|
1047
|
-
"""Create TMA atoms and tensors for
|
|
1988
|
+
"""Create TMA atoms and tensors for storing D or loading C.
|
|
1048
1989
|
|
|
1049
1990
|
:param tensor_d: Output tensor D
|
|
1050
1991
|
:type tensor_d: cute.Tensor
|
|
@@ -1056,15 +1997,17 @@ class HopperWgmmaGemmKernel:
|
|
|
1056
1997
|
:return: TMA atom and tensor for D or C
|
|
1057
1998
|
:rtype: Tuple[cute.CopyAtom, cute.Tensor]
|
|
1058
1999
|
"""
|
|
2000
|
+
assert store_or_load in ["load", "store"]
|
|
1059
2001
|
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
|
|
1060
|
-
|
|
2002
|
+
d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
|
|
2003
|
+
op = (
|
|
2004
|
+
cpasync.CopyBulkTensorTileG2SOp()
|
|
2005
|
+
if store_or_load == "load"
|
|
2006
|
+
else cpasync.CopyBulkTensorTileS2GOp()
|
|
2007
|
+
)
|
|
1061
2008
|
tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
|
|
1062
|
-
|
|
1063
|
-
tensor_d,
|
|
1064
|
-
epi_smem_layout,
|
|
1065
|
-
c_cta_v_layout,
|
|
2009
|
+
op, tensor_d, epi_smem_layout, d_cta_v_layout
|
|
1066
2010
|
)
|
|
1067
|
-
|
|
1068
2011
|
return tma_atom_d, tma_tensor_d
|
|
1069
2012
|
|
|
1070
2013
|
@staticmethod
|
|
@@ -1104,12 +2047,37 @@ class HopperWgmmaGemmKernel:
|
|
|
1104
2047
|
)
|
|
1105
2048
|
return tma_atom, tma_tensor
|
|
1106
2049
|
|
|
2050
|
+
def _make_gmem_tiled_copy_A(self, dtype, major_mode, num_threads, copy_bits=128):
|
|
2051
|
+
atom_async_copy = cute.make_copy_atom(
|
|
2052
|
+
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
|
|
2053
|
+
dtype,
|
|
2054
|
+
num_bits_per_copy=copy_bits,
|
|
2055
|
+
)
|
|
2056
|
+
copy_elems = copy_bits // dtype.width
|
|
2057
|
+
shape_dim_1 = cute.size(self.tile_shape_mnk[2]) // copy_elems
|
|
2058
|
+
# Thread layout for copy
|
|
2059
|
+
thread_layout = cute.make_layout(
|
|
2060
|
+
(num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
|
|
2061
|
+
)
|
|
2062
|
+
if major_mode != LayoutEnum.ROW_MAJOR:
|
|
2063
|
+
shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
|
|
2064
|
+
thread_layout = cute.make_layout(
|
|
2065
|
+
(shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
|
|
2066
|
+
)
|
|
2067
|
+
# Value layout for copy
|
|
2068
|
+
value_layout = (
|
|
2069
|
+
cute.make_layout((1, copy_elems))
|
|
2070
|
+
if major_mode == LayoutEnum.ROW_MAJOR
|
|
2071
|
+
else cute.make_layout((copy_elems, 1))
|
|
2072
|
+
)
|
|
2073
|
+
return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
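
A worked example of the thread/value layout arithmetic in the row-major branch above, under assumed numbers (128-bit copies, bf16, 128 threads, a K tile of 64); the helper below is illustrative and only reproduces the arithmetic, not the CuTe objects:

def gmem_copy_layout(dtype_bits, num_threads, tile_k, copy_bits=128):
    copy_elems = copy_bits // dtype_bits          # elements moved per 128-bit copy
    threads_per_row = tile_k // copy_elems        # threads along K for a row-major tile
    rows_per_wave = num_threads // threads_per_row
    return copy_elems, (rows_per_wave, threads_per_row)

print(gmem_copy_layout(dtype_bits=16, num_threads=128, tile_k=64))
# (8, (16, 8)): 8 bf16 elements per copy, thread layout 16 x 8
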
|
|
2074
|
+
|
|
1107
2075
|
@staticmethod
|
|
1108
2076
|
def is_valid_dtypes(
|
|
1109
2077
|
a_dtype: Type[cutlass.Numeric],
|
|
1110
2078
|
b_dtype: Type[cutlass.Numeric],
|
|
1111
2079
|
acc_dtype: Type[cutlass.Numeric],
|
|
1112
|
-
d_dtype: Type[cutlass.Numeric],
|
|
2080
|
+
d_dtype: Optional[Type[cutlass.Numeric]],
|
|
1113
2081
|
a_major: str,
|
|
1114
2082
|
b_major: str,
|
|
1115
2083
|
) -> bool:
|
|
@@ -1133,7 +2101,6 @@ class HopperWgmmaGemmKernel:
|
|
|
1133
2101
|
:rtype: bool
|
|
1134
2102
|
"""
|
|
1135
2103
|
is_valid = True
|
|
1136
|
-
# tested a_dtype
|
|
1137
2104
|
if a_dtype not in {
|
|
1138
2105
|
cutlass.Float16,
|
|
1139
2106
|
cutlass.BFloat16,
|
|
@@ -1149,11 +2116,11 @@ class HopperWgmmaGemmKernel:
|
|
|
1149
2116
|
cutlass.Float8E5M2,
|
|
1150
2117
|
}:
|
|
1151
2118
|
is_valid = False
|
|
1152
|
-
# tested acc_dtype
|
|
1153
2119
|
if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
|
|
1154
2120
|
             is_valid = False
         # tested d_dtype
         if d_dtype not in {
+            None,
             cutlass.Float32,
             cutlass.Float16,
             cutlass.BFloat16,
@@ -1171,260 +2138,108 @@ class HopperWgmmaGemmKernel:
         # for Float8 types, this implementation only supports k-major layout
         if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
             is_valid = False
-
         return is_valid


-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    :type cluster_shape_mn: Tuple[int, int]
-    :param tolerance: Tolerance value for reference validation comparison
-    :type tolerance: float
-    :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
-    :type warmup_iterations: int, optional
-    :param iterations: Number of benchmark iterations to run, defaults to 1
-    :type iterations: int, optional
-    :param skip_ref_check: Whether to skip reference result validation, defaults to False
-    :type skip_ref_check: bool, optional
-    :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
-    :type use_cold_l2: bool, optional
-    :return: Execution time of the GEMM kernel in microseconds
-    :rtype: float
-    """
-
-    print("Running Hopper Dense GEMM with:")
-    print(f"mnkl: {mnkl}")
-    print(f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
-    print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
-    print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
-    print(f"Tolerance: {tolerance}")
-    print(f"Warmup iterations: {warmup_iterations}")
-    print(f"Iterations: {iterations}")
-    print(f"Skip reference checking: {skip_ref_check}")
-    print(f"Use cold L2: {use_cold_l2}")
-
-    # Unpack parameters
-    m, n, k, l = mnkl
-    cluster_shape_mnk = (*cluster_shape_mn, 1)
-
-    # Skip unsupported types
-    if not HopperWgmmaGemmKernel.is_valid_dtypes(
-        a_dtype, b_dtype, acc_dtype, d_dtype, a_major, b_major
+def gemm_sm90(
+    A: Tensor,  # (l, m, k)
+    B: Tensor,  # (l, n, k)
+    D: Tensor,  # (l, m, n)
+    C: Optional[Tensor],  # (l, m, n)
+    tile_count_semaphore: Optional[Tensor],  # (1,)
+    tile_M: int,
+    tile_N: int,
+    cluster_M: int,
+    cluster_N: int,
+    pingpong: bool = False,
+    persistent: bool = True,
+    alpha: float | Tensor = 1.0,
+    beta: float | Tensor = 1.0,
+) -> None:
+    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(A, B, D, C)
+    GemmWrapperBase.permute_tensors(tensor_infos)
+    GemmWrapperBase.extract_dtypes(tensor_infos)
+    major_configs = {
+        "A": ("m", "k", "l"),
+        "B": ("n", "k", "l"),
+        "D": ("m", "n", "l"),
+        "C": ("m", "n", "l"),
+    }
+    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+    acc_dtype = cutlass.Float32
+    tile_shape_mn = (tile_M, tile_N)
+    cluster_shape_mnk = (cluster_M, cluster_N, 1)
+    if not GemmSm90.is_valid_dtypes(
+        tensor_infos["A"].dtype,
+        tensor_infos["B"].dtype,
+        acc_dtype,
+        tensor_infos["D"].dtype,
+        tensor_infos["A"].major,
+        tensor_infos["B"].major,
     ):
-        raise TypeError(
-            f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {d_dtype}, {a_major=}, {b_major=}"
-        )
+        raise TypeError("Skipping due to unsupported combination of types and majors")

-
-
-        raise RuntimeError("GPU is required to run this example!")
-
-    torch.manual_seed(1111)
-
-    # Create and permute tensor A/B/C
-    def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
-        # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
-        # else         : (l, mode0, mode1) -> (mode0, mode1, l)
-        shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
-        permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
-        is_unsigned = dtype in {cutlass.Uint8}
-        # Temporarily use uint8 as torch does not support fp8 type
-        torch_dtype = (
-            cutlass_torch.dtype(dtype)
-            if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
-            else torch.uint8
-        )
-
-        # Create dtype torch tensor (cpu)
-        torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
-            shape,
-            torch_dtype,
-            permute_order=permute_order,
-            # init_type=cutlass.torch.TensorInitType.RANDOM,
-            # init_config=cutlass.torch.RandomInitConfig(
-            #     min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
-            # ),
-            init_type=cutlass.torch.TensorInitType.GAUSSIAN,
-            init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
-        )
-        # Create dtype torch tensor (gpu)
-        torch_tensor = torch_tensor_cpu.cuda()
-
-        # Create f32 torch tensor (cpu)
-        f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
-
-        # Create dtype cute tensor (gpu)
-        cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
-        cute_tensor.element_type = dtype
-        if is_dynamic_layout:
-            cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
-        cute_tensor = cutlass.torch.convert_cute_tensor(
-            f32_torch_tensor,
-            cute_tensor,
-            dtype,
-            is_dynamic_layout=is_dynamic_layout,
-        )
+    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)

-
-
-
-    b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
-    c, mC, c_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
-
-    gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk)
-
-    torch_stream = torch.cuda.Stream()
-    stream = cuda.CUstream(torch_stream.cuda_stream)
-    # compile gemm kernel
-    compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
-
-    if not skip_ref_check:
-        # execution
-        compiled_gemm(mA, mB, mC, stream)
-
-        torch.cuda.synchronize()
-
-        # Ref check
-        ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
-
-        if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
-            # m major: (l, n, m) -> (m, n, l)
-            # n major: (l, m, n) -> (m, n, l)
-            permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
-            shape = (l, m, n) if d_major == "n" else (l, n, m)
-            f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
-                shape,
-                torch.uint8,
-                permute_order=permute_order,
-                init_type=cutlass_torch.TensorInitType.SKIP,
-            ).cuda()
-            # Create dtype cute tensor (gpu)
-            ref_c_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
-                leading_dim=(1 if d_major == "n" else 0)
-            )
-            ref_c_tensor.element_type = d_dtype
-            ref_c_tensor = cutlass_torch.convert_cute_tensor(
-                ref,
-                ref_c_tensor,
-                d_dtype,
-                is_dynamic_layout=True,
-            )
-            ref_c = f8_torch_tensor.cpu()
+    def scalar_arg(scalar: float | Tensor):
+        if isinstance(scalar, float):
+            return Float32(scalar) if scalar != 1.0 else None
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            assert isinstance(scalar, Tensor)
+            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
+
+    epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta))
+    scheduler_args = GemmWrapperBase.create_scheduler_args(
+        max_active_clusters, tile_count_semaphore
+    )
+    current_stream = cutlass_torch.current_stream()
+    compile_key = GemmWrapperBase.get_compile_key(
+        tensor_infos,
+        None,
+        tile_shape_mn,
+        cluster_shape_mnk,
+        pingpong,
+        persistent,
+        tile_count_semaphore is not None,
+        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
+        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
+        key_tensor_names=("A", "B", "D", "C"),
+    )
+    cache = gemm_sm90.compile_cache
+    if compile_key not in cache:
+        gemm = GemmSm90(
+            acc_dtype,
+            tensor_infos["A"].dtype,
+            tile_shape_mn,
+            cluster_shape_mnk,
+            pingpong=pingpong,
+            is_persistent=persistent,
         )
-
-
+        cache[compile_key] = cute.compile(
+            gemm,
+            tensor_infos["A"].cute_tensor,
+            tensor_infos["B"].cute_tensor,
+            tensor_infos["D"].cute_tensor,
+            tensor_infos["C"].cute_tensor,
+            epi_args,
+            scheduler_args,
+            None,  # varlen_args
+            None,  # mAIdx
+            current_stream,
        )
-
-
-
-
-
-
-
-
+    cache[compile_key](
+        tensor_infos["A"].cute_tensor,
+        tensor_infos["B"].cute_tensor,
+        tensor_infos["D"].cute_tensor,
+        tensor_infos["C"].cute_tensor,
+        epi_args,
+        scheduler_args,
+        None,
+        None,
+        current_stream,
    )

-
-
-    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
-    flops = 2 * m * n * k * l
-
-    repeats = 30
-    # repeats = 1
-    warmup = 5
-
-    import time
-
-    time.sleep(0.5)
-    fn = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
-    timing_cublas = do_bench(fn, warmup=warmup, rep=repeats)
-    tflops_cublas = flops / (timing_cublas * 1e9)  # Convert to TFlops
-    print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
-
-    time.sleep(0.5)
-    fn = lambda: compiled_gemm(mA, mB, mC, current_stream)
-    timing = do_bench(fn, warmup=warmup, rep=repeats)
-    tflops = flops / (timing * 1e9)  # Convert to TFlops
-    print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
-
-    time.sleep(0.5)
-    fn = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
-    timing_cublas = do_bench(fn, warmup=warmup, rep=repeats)
-    tflops_cublas = flops / (timing_cublas * 1e9)  # Convert to TFlops
-    print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
-
-    return exec_time  # Return execution time in microseconds
-
-
-if __name__ == "__main__":
-    args = parse_arguments()
-    run(
-        args.mnkl,
-        args.a_dtype,
-        args.b_dtype,
-        args.d_dtype,
-        args.acc_dtype,
-        args.a_major,
-        args.b_major,
-        args.d_major,
-        args.tile_shape_mnk,
-        args.cluster_shape_mn,
-        args.tolerance,
-        args.warmup_iterations,
-        args.iterations,
-        args.skip_ref_check,
-        args.use_cold_l2,
-    )
-    print("PASS")
+
+gemm_sm90.compile_cache = {}