quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/dense_gemm_sm90.py CHANGED
@@ -1,47 +1,43 @@
1
- # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD-3-Clause
3
-
4
- # Redistribution and use in source and binary forms, with or without
5
- # modification, are permitted provided that the following conditions are met:
6
-
7
- # 1. Redistributions of source code must retain the above copyright notice, this
8
- # list of conditions and the following disclaimer.
9
-
10
- # 2. Redistributions in binary form must reproduce the above copyright notice,
11
- # this list of conditions and the following disclaimer in the documentation
12
- # and/or other materials provided with the distribution.
13
-
14
- # 3. Neither the name of the copyright holder nor the names of its
15
- # contributors may be used to endorse or promote products derived from
16
- # this software without specific prior written permission.
17
-
18
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
-
29
- import argparse
30
- from typing import Tuple, Type
1
+ # Based on the cute-dsl example:
2
+ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
3
+
4
+ import enum
5
+ from typing import Tuple, Type, Callable, Optional, Union
6
+ from dataclasses import dataclass
7
+ from functools import partial
31
8
  import math
32
- import cuda.bindings.driver as cuda
33
9
 
34
- import torch
10
+ from torch import Tensor
11
+
12
+ import cuda.bindings.driver as cuda
35
13
 
36
14
  import cutlass
37
15
  import cutlass.cute as cute
38
- import cutlass.cute.testing as testing
39
- import cutlass.utils as utils
40
16
  import cutlass.pipeline as pipeline
41
- import cutlass.torch as cutlass_torch
42
- from cutlass.cute.runtime import from_dlpack
43
17
  from cutlass.cute.nvgpu import cpasync, warp, warpgroup
44
18
  import cutlass.utils.hopper_helpers as sm90_utils
19
+ from cutlass import Int32, Float32, Boolean, const_expr
20
+ from cutlass.utils import LayoutEnum
21
+ import cutlass.torch as cutlass_torch
22
+ from cutlass.cute.runtime import make_ptr
23
+
24
+
25
+ from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
26
+ from quack.tile_scheduler import (
27
+ TileSchedulerOptions,
28
+ TileSchedulerArguments,
29
+ TileScheduler,
30
+ VarlenMTileSchedulerArguments,
31
+ VarlenMTileScheduler,
32
+ )
33
+ from quack.varlen_utils import VarlenArguments
34
+ from quack.tensormap_manager import TensorMapManagerSm90
35
+
36
+ # return PipelineStateWAdvance instead of PipelineState
37
+ from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
38
+ import quack.utils as utils
39
+ from quack.cute_dsl_utils import get_max_active_clusters
40
+ from quack.gemm_wrapper_utils import GemmWrapperBase
45
41
 
46
42
  """
47
43
  A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -66,31 +62,6 @@ Hopper WGMMA instructions operate as follows:
66
62
  - Read matrix B from SMEM
67
63
  - Perform MMA operation and store the result in Accumulator(register)
68
64
 
69
- To run this example:
70
-
71
- .. code-block:: bash
72
-
73
- python examples/hopper/dense_gemm.py \
74
- --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
75
- --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
76
- --d_dtype Float16 --acc_dtype Float32 \
77
- --a_major k --b_major k --d_major n
78
-
79
- The above example command compute batched gemm with M=8192, N=8192, K=8192,
80
- batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape
81
- is (1,1). The input, mma accumulator and output data type are set as fp16, fp32
82
- and fp16, respectively.
83
-
84
- To collect performance with NCU profiler:
85
-
86
- .. code-block:: bash
87
-
88
- ncu python examples/hopper/dense_gemm.py \
89
- --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
90
- --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
91
- --d_dtype Float16 --acc_dtype Float32 \
92
- --a_major k --b_major k --d_major n
93
-
94
65
  Constraints:
95
66
  * Supported input data types: fp16, fp8 (e4m3fn, e5m2)
96
67
  * For fp16 types, A and B must have the same data type
@@ -103,107 +74,29 @@ Constraints:
103
74
  * Cluster shape M/N must be positive and power of 2, total cluster size <= 4
104
75
  * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
105
76
  i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively.
106
- * OOB tiles are not allowed when TMA store is disabled
107
77
  """
108
78
 
109
79
 
110
- # /////////////////////////////////////////////////////////////////////////////
111
- # Helpers to parse args
112
- # /////////////////////////////////////////////////////////////////////////////
113
- def parse_comma_separated_ints(s: str):
114
- try:
115
- return tuple([int(x.strip()) for x in s.split(",")])
116
- except ValueError:
117
- raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")
118
-
119
-
120
- def parse_arguments() -> argparse.Namespace:
121
- parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.")
122
-
123
- parser.add_argument(
124
- "--mnkl",
125
- type=parse_comma_separated_ints,
126
- default=(4096, 4096, 4096, 1),
127
- help="mnkl dimensions (comma-separated)",
128
- )
129
- parser.add_argument(
130
- "--tile_shape_mnk",
131
- type=parse_comma_separated_ints,
132
- default=(128, 256, 64),
133
- help="Cta tile shape (comma-separated)",
134
- )
135
- parser.add_argument(
136
- "--cluster_shape_mn",
137
- type=parse_comma_separated_ints,
138
- choices=[(1, 1), (2, 1), (1, 2), (2, 2)],
139
- default=(1, 1),
140
- help="Cluster shape (comma-separated)",
141
- )
142
- parser.add_argument(
143
- "--a_dtype",
144
- type=cutlass.dtype,
145
- default=cutlass.BFloat16,
146
- )
147
- parser.add_argument(
148
- "--b_dtype",
149
- type=cutlass.dtype,
150
- default=cutlass.BFloat16,
151
- )
152
- parser.add_argument(
153
- "--d_dtype",
154
- type=cutlass.dtype,
155
- default=cutlass.BFloat16,
156
- )
157
- parser.add_argument(
158
- "--acc_dtype",
159
- type=cutlass.dtype,
160
- default=cutlass.Float32,
161
- )
162
- parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
163
- parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
164
- parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
165
- parser.add_argument("--tolerance", type=float, default=1e-01, help="Tolerance for validation")
166
- parser.add_argument("--warmup_iterations", type=int, default=0, help="Warmup iterations")
167
- parser.add_argument(
168
- "--iterations",
169
- type=int,
170
- default=1,
171
- help="Number of iterations to run the kernel",
172
- )
173
- parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
174
- parser.add_argument(
175
- "--use_cold_l2",
176
- action="store_true",
177
- default=False,
178
- help="Use circular buffer tensor sets to ensure L2 cold cache",
179
- )
180
-
181
- args = parser.parse_args()
182
-
183
- if len(args.mnkl) != 4:
184
- parser.error("--mnkl must contain exactly 4 values")
185
- if len(args.tile_shape_mnk) != 3:
186
- parser.error("--tile_shape_mnk must contain exactly 3 values")
187
- if len(args.cluster_shape_mn) != 2:
188
- parser.error("--cluster_shape_mn must contain exactly 2 values")
189
-
190
- return args
191
-
192
-
193
- # /////////////////////////////////////////////////////////////////////////////
194
- # Host setup and device kernel launch
195
- # /////////////////////////////////////////////////////////////////////////////
80
+ class NamedBarrierGemm(enum.IntEnum):
81
+ Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
82
+ # For mainloop load warps to signal that the epilogue load warp can start.
83
+ # This is to avoid loading C too early, interfering with loading A and B.
84
+ EpilogueLoad = enum.auto()
85
+ MmaWG0 = enum.auto()
86
+ MmaWG1 = enum.auto()
87
+ EpiWG0 = enum.auto()
88
+ EpiWG1 = enum.auto()
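Later in this diff these enum values are turned into named barriers via `pipeline.NamedBarrier`; a minimal sketch of that pattern (the thread count is a placeholder, and the `arrive_and_wait` usage is assumed):

.. code-block:: python

    # Sketch of how a NamedBarrierGemm value is consumed (see the epilogue code below).
    # num_epi_threads is a placeholder; the kernel derives it from its warp-group setup.
    epilogue_barrier = pipeline.NamedBarrier(
        barrier_id=int(NamedBarrierGemm.Epilogue),
        num_threads=num_epi_threads,
    )
    epilogue_barrier.arrive_and_wait()  # sync all epilogue threads before smem reuse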
196
89
 
197
90
 
198
- class HopperWgmmaGemmKernel:
91
+ class GemmSm90:
199
92
  """
200
93
  This class implements batched matrix multiplication (C = A x B) with support for various data types
201
94
  and architectural features specific to Hopper GPUs.
202
95
 
203
96
  :param acc_dtype: Data type for accumulation during computation
204
97
  :type acc_dtype: type[cutlass.Numeric]
205
- :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
206
- :type tile_shape_mnk: Tuple[int, int, int]
98
+ :param tile_shape_mn: Shape of the CTA tile (M,N)
99
+ :type tile_shape_mn: Tuple[int, int, int]
207
100
  :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
208
101
  :type cluster_shape_mnk: Tuple[int, int, int]
209
102
 
@@ -221,25 +114,39 @@ class HopperWgmmaGemmKernel:
221
114
  - Float32 (for all floating point inputs)
222
115
 
223
116
  :note: Constraints:
224
- - CTA tile M must be 64/128
225
- - CTA tile N must be 64/128/256
226
- - CTA tile K must be 64
227
117
  - Cluster shape M/N must be positive and power of 2, total cluster size <= 4
228
118
 
229
119
  Example:
230
- >>> gemm = HopperWgmmaGemmKernel(
120
+ >>> gemm = GemmSm90(
231
121
  ... acc_dtype=cutlass.Float32,
232
- ... tile_shape_mnk=(128, 256, 64),
122
+ ... tile_shape_mn=(128, 256),
233
123
  ... cluster_shape_mnk=(1, 1, 1)
234
124
  ... )
235
125
  >>> gemm(a_tensor, b_tensor, c_tensor, stream)
236
126
  """
237
127
 
128
+ bytes_per_tensormap = 128
129
+
130
+ @dataclass
131
+ class EpilogueArguments(ArgumentsBase):
132
+ alpha: Optional[Float32 | cute.Tensor] = None
133
+ beta: Optional[Float32 | cute.Tensor] = None
134
+
135
+ @dataclass
136
+ class EpilogueParams(ParamsBase):
137
+ alpha: Optional[Float32 | cute.Tensor] = None
138
+ beta: Optional[Float32 | cute.Tensor] = None
139
+
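The `alpha`/`beta` fields suggest the conventional epilogue D = alpha * (A @ B) + beta * C; treating `None` as alpha = 1 and beta = 0 is an assumption here, but it gives a handy reference computation for validating outputs:

.. code-block:: python

    # Reference-only sketch (PyTorch), assuming D = alpha * (A @ B) + beta * C
    # with None interpreted as alpha = 1 and beta = 0.
    import torch

    def reference_epilogue(A, B, C=None, alpha=None, beta=None):
        alpha = 1.0 if alpha is None else float(alpha)
        beta = 0.0 if beta is None else float(beta)
        acc = alpha * (A.float() @ B.float())
        if C is not None:
            acc = acc + beta * C.float()
        return acc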
238
140
  def __init__(
239
141
  self,
240
142
  acc_dtype: Type[cutlass.Numeric],
241
- tile_shape_mnk: Tuple[int, int, int],
143
+ a_dtype: Type[cutlass.Numeric],
144
+ tile_shape_mn: Tuple[int, int],
242
145
  cluster_shape_mnk: Tuple[int, int, int],
146
+ pingpong: bool = False,
147
+ is_persistent: bool = True,
148
+ fp8_fast_accum: bool = False,
149
+ gather_A: bool = False,
243
150
  ):
244
151
  """
245
152
  Initializes the configuration for a Hopper dense GEMM kernel.
@@ -249,59 +156,106 @@ class HopperWgmmaGemmKernel:
249
156
 
250
157
  :param acc_dtype: Data type for accumulation during computation
251
158
  :type acc_dtype: type[cutlass.Numeric]
252
- :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
253
- :type tile_shape_mnk: Tuple[int, int, int]
159
+ :param tile_shape_mn: Shape of the CTA tile (M,N)
160
+ :type tile_shape_mn: Tuple[int, int]
254
161
  :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
255
162
  :type cluster_shape_mnk: Tuple[int, int, int]
256
163
  """
257
164
 
258
165
  self.acc_dtype = acc_dtype
166
+ self.pingpong = pingpong
167
+ self.is_persistent = is_persistent
168
+ if self.pingpong:
169
+ assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
170
+ self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
171
+ self.gather_A = gather_A
172
+ if gather_A:
173
+ assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
259
174
 
260
175
  self.cluster_shape_mnk = cluster_shape_mnk
261
- self.tile_shape_mnk = tuple(tile_shape_mnk)
262
- tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1]
176
+ # K dimension is deferred in _setup_attributes
177
+ self.tile_shape_mnk = (*tile_shape_mn, 1)
178
+ tile_M, tile_N = self.tile_shape_mnk[0], self.tile_shape_mnk[1]
263
179
  # check the cta tile shape
264
- # if tile_M not in [64, 128, 192, 256]:
265
- # TODO: M=192 currently doesn't work
266
- if tile_M not in [64, 128, 256]:
267
- raise ValueError("CTA tile shape M must be 64/128/192/256")
268
- if tile_M == 192: # special case
269
- if not (tile_N % 32 == 0 and tile_N <= 288):
270
- raise ValueError(
271
- "If tile_m == 192, CTA tile shape N must be divisible by 32 and <= 288"
272
- )
180
+ if not self.pingpong:
181
+ if tile_M not in [64, 128, 192, 256, 320]:
182
+ raise ValueError("CTA tile shape M must be 64/128/192/256/320")
183
+ if tile_M in [192, 320]: # special case
184
+ tile_N_max = 256 if tile_M == 192 else 160
185
+ if not (tile_N % 32 == 0 and tile_N <= tile_N_max):
186
+ raise ValueError(
187
+ f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}"
188
+ )
189
+ else:
190
+ if not (
191
+ (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
192
+ ):
193
+ raise ValueError(
194
+ "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
195
+ )
273
196
  else:
274
- if not ((tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)):
275
- raise ValueError(
276
- "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
277
- )
278
- if not self.tile_shape_mnk[2] % 16 == 0:
279
- raise ValueError("CTA tile shape K must be divisible by 16")
280
-
281
- if tile_M == 192: # Special case
282
- atom_layout_m, atom_layout_n = 1, 2
197
+ if tile_M not in [64, 128, 192]:
198
+ raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
199
+ tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
200
+ if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
201
+ raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
202
+
203
+ if not self.pingpong:
204
+ if tile_M == 320: # tile_M / 64 is not even so we have to split along N
205
+ atom_layout_m, atom_layout_n = 1, 2
206
+ elif tile_M == 192:
207
+ if tile_N <= 128:
208
+ atom_layout_m, atom_layout_n = 3, 1
209
+ else:
210
+ atom_layout_m, atom_layout_n = 1, 2
211
+ else:
212
+ atom_layout_m = self.tile_shape_mnk[0] // 64 if self.tile_shape_mnk[0] < 256 else 2
213
+ atom_layout_n = 1
214
+ assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
283
215
  else:
284
- atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
285
- atom_layout_n = 1
286
- assert atom_layout_m in [1, 2] and atom_layout_n in [1, 2]
216
+ atom_layout_m, atom_layout_n = 1, 1
287
217
  self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
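To make the branchy selection above easier to follow, here is the same decision table for the non-pingpong path as a standalone sketch with a few worked cases (illustration only, not a helper exported by this module):

.. code-block:: python

    # Mirror of the non-pingpong atom-layout selection above (illustration only).
    def atom_layout_for(tile_m: int, tile_n: int) -> tuple:
        if tile_m == 320:                  # tile_M / 64 is odd, so split along N
            return (1, 2, 1)
        if tile_m == 192:
            return (3, 1, 1) if tile_n <= 128 else (1, 2, 1)
        return (tile_m // 64 if tile_m < 256 else 2, 1, 1)

    assert atom_layout_for(128, 256) == (2, 1, 1)
    assert atom_layout_for(192, 128) == (3, 1, 1)
    assert atom_layout_for(256, 192) == (2, 1, 1)
    assert atom_layout_for(320, 160) == (1, 2, 1)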
288
218
 
289
- self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
219
+ self.num_mcast_ctas_a = self.cluster_shape_mnk[1] if not self.gather_A else 1
290
220
  self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
291
221
  self.is_a_mcast = self.num_mcast_ctas_a > 1
292
222
  self.is_b_mcast = self.num_mcast_ctas_b > 1
293
223
 
294
224
  self.occupancy = 1
295
- self.mma_warp_groups = math.prod(self.atom_layout_mnk)
225
+ self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
226
+ if self.pingpong:
227
+ assert self.mma_warp_groups == 2
228
+ assert self.mma_warp_groups in [1, 2, 3]
296
229
  self.num_threads_per_warp_group = 128
297
230
  self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
298
- self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
299
- self.num_mma_threads = self.mma_warp_groups * self.num_threads_per_warp_group
300
-
301
- regs_per_thread = math.prod(self.tile_shape_mnk) // self.num_mma_threads
302
- heavy_register_pressure = regs_per_thread >= 208
303
- self.num_regs_load = 40 if not heavy_register_pressure else 24
304
- self.num_regs_mma = 232 if not heavy_register_pressure else 240
231
+ self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
232
+ self.num_epi_threads = (
233
+ self.mma_warp_groups if not self.pingpong else 1
234
+ ) * self.num_threads_per_warp_group
235
+ self.num_ab_load_warps = 1 if not self.gather_A else 4
236
+ self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
237
+ self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
238
+ self.ab_load_warp_id = self.mma_warp_groups * 4
239
+ self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
240
+
241
+ regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // (
242
+ math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
243
+ )
244
+ if self.fp8_slow_accum:
245
+ regs_per_thread *= 2
246
+ if not self.gather_A:
247
+ if self.mma_warp_groups == 3:
248
+ self.num_regs_load, self.num_regs_mma = 32, 160
249
+ else:
250
+ heavy_register_pressure = regs_per_thread >= 208
251
+ self.num_regs_load, self.num_regs_mma = (
252
+ (40, 232) if not heavy_register_pressure else (24, 240)
253
+ )
254
+ else:
255
+ if self.mma_warp_groups == 3:
256
+ self.num_regs_load, self.num_regs_mma = 56, 152
257
+ else:
258
+ self.num_regs_load, self.num_regs_mma = (56, 224)
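A worked example of the register-pressure heuristic above: a 128x256 tile split across two cooperative MMA warp groups gives 128*256 / (2*128) = 128 accumulator registers per thread, below the 208 threshold, so the (40, 232) load/MMA split applies; with fp8 slow accumulation the estimate doubles to 256 and (24, 240) is used instead. A compact restatement for the non-gather, non-3-warp-group case:

.. code-block:: python

    # Restatement of the heuristic above (non-gather, 1-2 MMA warp groups; sketch only).
    def reg_split(tile_m, tile_n, atom_layout_size, fp8_slow_accum):
        regs_per_thread = (tile_m * tile_n) // (atom_layout_size * 128)
        if fp8_slow_accum:
            regs_per_thread *= 2
        return (40, 232) if regs_per_thread < 208 else (24, 240)

    assert reg_split(128, 256, 2, False) == (40, 232)
    assert reg_split(128, 256, 2, True) == (24, 240)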
305
259
 
306
260
  self.ab_stage = None
307
261
  self.epi_stage = None
@@ -314,7 +268,7 @@ class HopperWgmmaGemmKernel:
314
268
  self.shared_storage = None
315
269
  self.buffer_align_bytes = 1024
316
270
 
317
- def _setup_attributes(self):
271
+ def _setup_attributes(self, epilogue_args: Optional[EpilogueArguments]):
318
272
  """Set up configurations that are dependent on GEMM inputs
319
273
 
320
274
  This method configures various attributes based on the input tensor properties
@@ -328,26 +282,67 @@ class HopperWgmmaGemmKernel:
328
282
  - Computing A/B/C shared memory layout
329
283
  """
330
284
 
331
- self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
285
+ self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
286
+ self.a_dtype,
287
+ self.b_dtype,
288
+ self.a_layout.sm90_mma_major_mode(),
289
+ self.b_layout.sm90_mma_major_mode(),
290
+ self.acc_dtype,
291
+ self.atom_layout_mnk,
292
+ tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
293
+ )
294
+ if const_expr(self.atom_layout_mnk[1] > 1):
295
+ # If N dimension is split among 2 WGs, we need to permute the N dimension so
296
+ # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
297
+ # containing accumulators that are next to each other in the N dimension.
298
+ # Without permutation WG0 would write to epi smem of size (64, 16) and
299
+ # WG1 would write to a separate epi smem of size (64, 16) that's far away.
300
+ atom_n = self.atom_layout_mnk[1]
301
+ permutation_n = cute.make_ordered_layout(
302
+ (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
303
+ )
304
+ self.tiled_mma = cute.make_tiled_mma(
305
+ cute.make_mma_atom(self.tiled_mma.op),
306
+ self.atom_layout_mnk,
307
+ permutation_mnk=(None, permutation_n, None),
308
+ )
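The comment above motivates the N-permutation; as a side note, here is a tiny plain-Python illustration (no CuTe) of how an ordered layout such as `(8, tile_N // atom_n // 8, atom_n)` with `order=(0, 2, 1)` assigns strides, under the usual convention that the mode with the smallest order value gets stride 1 (this is an illustrative reimplementation, not the CuTe API):

.. code-block:: python

    # Illustration only: compact strides implied by make_ordered_layout's order argument.
    def ordered_strides(shape, order):
        strides = [0] * len(shape)
        s = 1
        for dim in sorted(range(len(shape)), key=lambda d: order[d]):
            strides[dim] = s
            s *= shape[dim]
        return tuple(strides)

    # tile_N = 64, atom_n = 2  ->  shape (8, 4, 2), strides (1, 16, 8):
    # N advances in chunks of 8, and the chunks assigned to the two warp groups are
    # interleaved so that each WG writes a contiguous block of epi smem (per the
    # comment above).
    assert ordered_strides((8, 4, 2), (0, 2, 1)) == (1, 16, 8)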
309
+ mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
310
+ mma_inst_tile_k = 4
311
+ self.tile_shape_mnk = (
312
+ self.tile_shape_mnk[0],
313
+ self.tile_shape_mnk[1],
314
+ mma_inst_shape_k * mma_inst_tile_k,
315
+ )
316
+
317
+ self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
332
318
 
333
- is_cooperative = math.prod(self.atom_layout_mnk) > 1
334
319
  self.epi_tile = self._sm90_compute_tile_shape_or_override(
335
- self.tile_shape_mnk, self.d_dtype, is_cooperative=is_cooperative
320
+ self.tile_shape_mnk,
321
+ self.atom_layout_mnk,
322
+ self.d_dtype,
336
323
  )
337
324
 
338
325
  # Compute stage before compute smem layout
339
- self.ab_stage, self.epi_stage = self._compute_stages(
326
+ self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
340
327
  self.tile_shape_mnk,
328
+ self.epi_tile,
341
329
  self.a_dtype,
342
330
  self.b_dtype,
331
+ self.d_dtype,
332
+ self.c_dtype,
333
+ epilogue_args,
343
334
  self.smem_capacity,
344
335
  self.occupancy,
336
+ # epi_smem will reuse smem ab if not persistent.
337
+ overlap_sD_sA=not self.is_persistent,
345
338
  )
339
+ self.sched_stage = 2 if self.pingpong else 1
346
340
 
347
341
  (
348
342
  self.a_smem_layout_staged,
349
343
  self.b_smem_layout_staged,
350
344
  self.epi_smem_layout_staged,
345
+ self.epi_c_smem_layout_staged,
351
346
  ) = self._make_smem_layouts(
352
347
  self.tile_shape_mnk,
353
348
  self.epi_tile,
@@ -359,6 +354,9 @@ class HopperWgmmaGemmKernel:
359
354
  self.d_dtype,
360
355
  self.d_layout,
361
356
  self.epi_stage,
357
+ self.c_dtype,
358
+ self.c_layout,
359
+ self.epi_c_stage,
362
360
  )
363
361
 
364
362
  @cute.jit
@@ -366,7 +364,12 @@ class HopperWgmmaGemmKernel:
366
364
  self,
367
365
  mA: cute.Tensor,
368
366
  mB: cute.Tensor,
369
- mD: cute.Tensor,
367
+ mD: Optional[cute.Tensor],
368
+ mC: Optional[cute.Tensor],
369
+ epilogue_args: Optional[ArgumentsBase],
370
+ scheduler_args: TileSchedulerOptions,
371
+ varlen_args: Optional[VarlenArguments],
372
+ mAIdx: Optional[cute.Tensor],
370
373
  stream: cuda.CUstream,
371
374
  ):
372
375
  """Execute the GEMM operation in steps:
@@ -389,36 +392,44 @@ class HopperWgmmaGemmKernel:
389
392
  # setup static attributes before smem/grid/tma computation
390
393
  self.a_dtype = mA.element_type
391
394
  self.b_dtype = mB.element_type
392
- self.d_dtype = mD.element_type
393
- self.a_layout = utils.LayoutEnum.from_tensor(mA)
394
- self.b_layout = utils.LayoutEnum.from_tensor(mB)
395
- self.d_layout = utils.LayoutEnum.from_tensor(mD)
396
-
397
- if cutlass.const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
395
+ self.d_dtype = mD.element_type if mD is not None else None
396
+ self.c_dtype = mC.element_type if mC is not None else None
397
+ self.a_layout = LayoutEnum.from_tensor(mA)
398
+ self.b_layout = LayoutEnum.from_tensor(mB)
399
+ self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
400
+ self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None
401
+
402
+ if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
398
403
  raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
399
- if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width):
404
+ if const_expr(self.a_dtype.width != self.b_dtype.width):
400
405
  raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
401
- if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
406
+ if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
402
407
  raise TypeError("a_dtype should be float16 or float8")
408
+ assert (mAIdx is not None) == self.gather_A
403
409
 
404
- self._setup_attributes()
405
-
406
- tiled_mma = sm90_utils.make_trivial_tiled_mma(
407
- self.a_dtype,
408
- self.b_dtype,
409
- self.a_layout.sm90_mma_major_mode(),
410
- self.b_layout.sm90_mma_major_mode(),
411
- self.acc_dtype,
412
- self.atom_layout_mnk,
413
- tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
414
- )
415
-
416
- tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
417
- mA,
418
- self.a_smem_layout_staged,
419
- (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
420
- self.cluster_shape_mnk[1],
410
+ # Assume all strides are divisible by 128 bits except the last stride
411
+ new_stride = lambda t: tuple(
412
+ cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
413
+ for s in t.stride
421
414
  )
415
+ mA, mD = [
416
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
417
+ if t is not None
418
+ else None
419
+ for t in (mA, mD)
420
+ ]
421
+
422
+ self._setup_attributes(epilogue_args)
423
+
424
+ if const_expr(not self.gather_A):
425
+ tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
426
+ mA,
427
+ self.a_smem_layout_staged,
428
+ (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
429
+ self.cluster_shape_mnk[1],
430
+ )
431
+ else:
432
+ tma_atom_a, tma_tensor_a = None, None
422
433
 
423
434
  tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
424
435
  mB,
@@ -427,17 +438,89 @@ class HopperWgmmaGemmKernel:
427
438
  self.cluster_shape_mnk[0],
428
439
  )
429
440
 
430
- tma_atom_d, tma_tensor_d = self._make_tma_store_atoms_and_tensors(
431
- mD,
432
- self.epi_smem_layout_staged,
433
- self.epi_tile,
441
+ if const_expr(mD is not None):
442
+ tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
443
+ mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
444
+ )
445
+ else:
446
+ tma_atom_d, tma_tensor_d = None, None
447
+
448
+ if const_expr(mC is not None):
449
+ tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
450
+ mC, self.epi_c_smem_layout_staged, self.epi_tile, store_or_load="load"
451
+ )
452
+ else:
453
+ tma_atom_c, tma_tensor_c = None, None
454
+
455
+ epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
456
+
457
+ if const_expr(varlen_args is None):
458
+ varlen_args = VarlenArguments()
459
+ if const_expr(varlen_args.mCuSeqlensM is None):
460
+ num_problems = (
461
+ mD.shape[2]
462
+ if mD is not None
463
+ else (
464
+ mB.shape[2]
465
+ if varlen_args.mCuSeqlensK is None
466
+ else varlen_args.mCuSeqlensK.shape[0] - 1
467
+ )
468
+ )
469
+ problem_shape_ntile_mnl = (
470
+ cute.ceil_div(mA.shape[0], self.tile_shape_mnk[0]),
471
+ cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
472
+ num_problems,
473
+ )
474
+ TileSchedulerCls = self.get_scheduler_class()
475
+ tile_sched_args = self.get_scheduler_arguments(problem_shape_ntile_mnl, scheduler_args)
476
+ else:
477
+ assert mD is not None or not self.gather_A
478
+ problem_shape_ntile_mnl = (
479
+ None,
480
+ cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
481
+ varlen_args.mCuSeqlensM.shape[0] - 1,
482
+ )
483
+ TileSchedulerCls = VarlenMTileScheduler
484
+ tile_sched_args = VarlenMTileSchedulerArguments(
485
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
486
+ total_m=mD.shape[0] if mD is not None else mAIdx.shape[0],
487
+ cu_seqlens_m=varlen_args.mCuSeqlensM,
488
+ raster_order=scheduler_args.raster_order,
489
+ group_size=scheduler_args.max_swizzle_size,
490
+ tile_shape_mn=self.tile_shape_mnk[:2],
491
+ cluster_shape_mnk=self.cluster_shape_mnk,
492
+ tile_count_semaphore=scheduler_args.tile_count_semaphore,
493
+ is_persistent=self.is_persistent,
494
+ )
495
+ tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
496
+ grid = TileSchedulerCls.get_grid_shape(
497
+ tile_sched_params, scheduler_args.max_active_clusters
434
498
  )
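The `shape[0] - 1` arithmetic above follows the usual prefix-sum convention for `cu_seqlens` tensors (assumed here, and consistent with the `cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]` lengths computed later in this diff):

.. code-block:: python

    # Illustration of the cu_seqlens convention: prefix sums over per-problem lengths.
    cu_seqlens_m = [0, 3, 10, 14]            # three problems with M = 3, 7, 4
    num_problems = len(cu_seqlens_m) - 1     # matches mCuSeqlensM.shape[0] - 1
    lengths = [cu_seqlens_m[i + 1] - cu_seqlens_m[i] for i in range(num_problems)]
    assert lengths == [3, 7, 4]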
435
499
 
436
- grid = self._compute_grid(mD, self.tile_shape_mnk, self.cluster_shape_mnk)
500
+ epi_smem_size = (
501
+ cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
502
+ )
503
+ epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
437
504
 
438
505
  @cute.struct
439
506
  class SharedStorage:
440
- mainloop_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
507
+ ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
508
+ epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
509
+ sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
510
+ tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
511
+ sD: cute.struct.Align[
512
+ cute.struct.MemRange[
513
+ self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
514
+ ],
515
+ self.buffer_align_bytes,
516
+ ]
517
+ sC: cute.struct.Align[
518
+ cute.struct.MemRange[
519
+ self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
520
+ ],
521
+ self.buffer_align_bytes,
522
+ ]
523
+ epi: self.epi_get_smem_struct(epilogue_params)
441
524
  sA: cute.struct.Align[
442
525
  cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
443
526
  self.buffer_align_bytes,
@@ -452,16 +535,26 @@ class HopperWgmmaGemmKernel:
452
535
  # Launch the kernel synchronously
453
536
  self.kernel(
454
537
  tma_atom_a,
455
- tma_tensor_a,
538
+ tma_tensor_a if const_expr(not self.gather_A) else mA,
456
539
  tma_atom_b,
457
540
  tma_tensor_b,
458
541
  tma_atom_d,
459
542
  tma_tensor_d,
460
- tiled_mma,
461
- self.cta_layout_mnk,
543
+ tma_atom_c,
544
+ tma_tensor_c,
545
+ epilogue_params,
546
+ mAIdx,
547
+ varlen_args.mCuSeqlensM,
548
+ varlen_args.mCuSeqlensK,
549
+ varlen_args.mTensormaps,
550
+ self.tiled_mma,
551
+ self.cluster_layout_mnk,
462
552
  self.a_smem_layout_staged,
463
553
  self.b_smem_layout_staged,
464
554
  self.epi_smem_layout_staged,
555
+ self.epi_c_smem_layout_staged,
556
+ tile_sched_params,
557
+ TileSchedulerCls,
465
558
  ).launch(
466
559
  grid=grid,
467
560
  block=[self.threads_per_cta, 1, 1],
@@ -476,17 +569,27 @@ class HopperWgmmaGemmKernel:
476
569
  @cute.kernel
477
570
  def kernel(
478
571
  self,
479
- tma_atom_a: cute.CopyAtom,
572
+ tma_atom_a: Optional[cute.CopyAtom],
480
573
  mA_mkl: cute.Tensor,
481
574
  tma_atom_b: cute.CopyAtom,
482
575
  mB_nkl: cute.Tensor,
483
- tma_atom_d: cute.CopyAtom,
484
- mD_mnl: cute.Tensor,
576
+ tma_atom_d: Optional[cute.CopyAtom],
577
+ mD_mnl: Optional[cute.Tensor],
578
+ tma_atom_c: Optional[cute.CopyAtom],
579
+ mC_mnl: Optional[cute.Tensor],
580
+ epilogue_params: ParamsBase,
581
+ mAIdx: Optional[cute.Tensor],
582
+ cu_seqlens_m: Optional[cute.Tensor],
583
+ cu_seqlens_k: Optional[cute.Tensor],
584
+ tensormaps: Optional[cute.Tensor],
485
585
  tiled_mma: cute.TiledMma,
486
- cta_layout_mnk: cute.Layout,
487
- a_smem_layout_staged: cute.ComposedLayout,
488
- b_smem_layout_staged: cute.ComposedLayout,
489
- epi_smem_layout_staged: cute.ComposedLayout,
586
+ cluster_layout_mnk: cute.Layout,
587
+ a_smem_layout: cute.ComposedLayout,
588
+ b_smem_layout: cute.ComposedLayout,
589
+ epi_smem_layout: cute.ComposedLayout,
590
+ epi_c_smem_layout: cute.ComposedLayout,
591
+ tile_sched_params: ParamsBase,
592
+ TileSchedulerCls: cutlass.Constexpr[Callable],
490
593
  ):
491
594
  """
492
595
  GPU device kernel performing the batched GEMM computation.
@@ -505,32 +608,31 @@ class HopperWgmmaGemmKernel:
505
608
  :type mD_mnl: cute.Tensor
506
609
  :param tiled_mma: Tiled MMA object
507
610
  :type tiled_mma: cute.TiledMma
508
- :param cta_layout_mnk: CTA layout
509
- :type cta_layout_mnk: cute.Layout
510
- :param a_smem_layout_staged: Shared memory layout for A
511
- :type a_smem_layout_staged: cute.ComposedLayout
512
- :param b_smem_layout_staged: Shared memory layout for B
513
- :type b_smem_layout_staged: cute.ComposedLayout
514
- :param epi_smem_layout_staged: Shared memory layout for epilogue
515
- :type epi_smem_layout_staged: cute.ComposedLayout
611
+ :param cluster_layout_mnk: Layout of CTAs within the cluster
612
+ :type cluster_layout_mnk: cute.Layout
613
+ :param a_smem_layout: Shared memory layout for A
614
+ :type a_smem_layout: cute.ComposedLayout
615
+ :param b_smem_layout: Shared memory layout for B
616
+ :type b_smem_layout: cute.ComposedLayout
617
+ :param epi_smem_layout: Shared memory layout for epilogue
618
+ :type epi_smem_layout: cute.ComposedLayout
516
619
  """
517
620
 
621
+ varlen_m = const_expr(cu_seqlens_m is not None)
622
+ varlen_k = const_expr(cu_seqlens_k is not None)
623
+ assert not (varlen_m and varlen_k)
624
+ has_D = const_expr(mD_mnl is not None)
625
+ has_C = const_expr(mC_mnl is not None)
626
+
518
627
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
519
628
 
520
629
  # /////////////////////////////////////////////////////////////////////////////
521
630
  # Prefetch Tma desc
522
631
  # /////////////////////////////////////////////////////////////////////////////
523
- # if warp_idx == 0:
524
- if warp_idx == self.mma_warp_groups * 4:
525
- cpasync.prefetch_descriptor(tma_atom_a)
526
- cpasync.prefetch_descriptor(tma_atom_b)
527
- cpasync.prefetch_descriptor(tma_atom_d)
528
-
529
- a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
530
- b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
531
- tma_copy_bytes = cute.size_in_bytes(self.a_dtype, a_smem_layout) + cute.size_in_bytes(
532
- self.b_dtype, b_smem_layout
533
- )
632
+ if warp_idx == self.ab_load_warp_id:
633
+ for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
634
+ if const_expr(tma_atom is not None):
635
+ cpasync.prefetch_descriptor(tma_atom)
534
636
 
535
637
  # /////////////////////////////////////////////////////////////////////////////
536
638
  # Alloc and init AB full/empty + ACC full mbar (pipeline)
@@ -538,164 +640,321 @@ class HopperWgmmaGemmKernel:
538
640
  smem = cutlass.utils.SmemAllocator()
539
641
  storage = smem.allocate(self.shared_storage)
540
642
 
541
- # Threads/warps participating in this pipeline
542
- mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
543
- # Each warp will constribute to the arrive count with the number of mcast size
544
- mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
545
- consumer_arrive_cnt = mcast_size * (self.num_mma_threads // cute.arch.WARP_SIZE)
546
- mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
547
- pipeline.Agent.Thread, consumer_arrive_cnt
548
- )
549
-
550
- cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
551
- mainloop_pipeline = pipeline.PipelineTmaAsync.create(
552
- barrier_storage=storage.mainloop_pipeline_array_ptr.data_ptr(),
553
- num_stages=self.ab_stage,
554
- producer_group=mainloop_pipeline_producer_group,
555
- consumer_group=mainloop_pipeline_consumer_group,
556
- tx_count=tma_copy_bytes,
557
- cta_layout_vmnk=cta_layout_vmnk,
643
+ ab_pipeline = self.make_ab_pipeline(
644
+ a_smem_layout=cute.slice_(a_smem_layout, (None, None, 0)),
645
+ b_smem_layout=cute.slice_(b_smem_layout, (None, None, 0)),
646
+ tiled_mma=tiled_mma,
647
+ cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
648
+ ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
558
649
  )
650
+ epi_pipeline = None
651
+ if const_expr(has_C):
652
+ epi_pipeline = self.make_epi_pipeline(
653
+ c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
654
+ epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
655
+ )
656
+ sched_pipeline = None
657
+ tile_count = None
658
+ if const_expr(tile_sched_params.tile_count_semaphore is not None):
659
+ # Dynamic persistent scheduler
660
+ sched_pipeline = self.make_sched_pipeline(
661
+ cluster_layout_mnk,
662
+ sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
663
+ varlen_k=varlen_k,
664
+ )
665
+ tile_count = storage.tile_count.get_tensor((self.sched_stage,))
559
666
 
560
667
  # ///////////////////////////////////////////////////////////////////////////////
561
668
  # Generate smem tensor A/B
562
669
  # ///////////////////////////////////////////////////////////////////////////////
563
- sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
564
- sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
565
- sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
566
- sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
567
-
568
- # ///////////////////////////////////////////////////////////////////////////////
569
- # Get cta/warp/thread idx
570
- # ///////////////////////////////////////////////////////////////////////////////
571
-
572
- cidx, cidy, _ = cute.arch.cluster_idx()
573
- cdimx, cdimy, _ = cute.arch.cluster_dim()
574
- cluster_id = cidx + cdimx * cidy
670
+ sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
671
+ sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
672
+ sD = None
673
+ if const_expr(has_D):
674
+ if const_expr(not self.is_persistent):
675
+ sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
676
+ sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
677
+ else:
678
+ sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
679
+ sC = None
680
+ if const_expr(has_C):
681
+ sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
682
+ epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
683
+
684
+ # Get tensormap buffer address
685
+ tensormap_manager = None
686
+ tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
687
+ if const_expr(varlen_m or varlen_k):
688
+ tensormap_manager = TensorMapManagerSm90(
689
+ cutlass.utils.TensorMapUpdateMode.GMEM, GemmSm90.bytes_per_tensormap
690
+ )
691
+ # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
692
+ tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
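A quick plain-Python check of the comment above: a compact layout over `grid_dim` evaluated at `block_idx` linearizes the block coordinates column-major:

.. code-block:: python

    # Sanity check of the linearization noted in the comment above (illustration only).
    gx, gy, gz = 4, 3, 2                      # example grid dimensions
    def workspace_idx(bidx, bidy, bidz):
        return bidx + bidy * gx + bidz * gx * gy
    assert workspace_idx(1, 2, 1) == 1 + 2 * 4 + 1 * 4 * 3   # == 21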
693
+ if const_expr(varlen_m):
694
+ tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
695
+ tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
696
+ tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
697
+ )
698
+ else:
699
+ assert varlen_k
700
+ tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
701
+ tensormaps[tensormap_workspace_idx, 0, None].iterator
702
+ )
703
+ tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
704
+ tensormaps[tensormap_workspace_idx, 1, None].iterator
705
+ )
575
706
 
576
- # CTA Swizzle to promote L2 data reuse
577
- group_size_m = 8
578
- s_shape = (
579
- (group_size_m, cdimx // group_size_m),
580
- cdimy,
707
+ TileSchedulerCls = partial(
708
+ TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
581
709
  )
582
- s_stride = ((1, cdimy * group_size_m), group_size_m)
583
- s_layout = cute.make_layout(s_shape, stride=s_stride)
584
- num_reg_cids = cute.size(s_shape)
585
- cid_m, cid_n = s_layout.get_flat_coord(cluster_id % num_reg_cids)
586
-
587
- # Deal with the tail part
588
- if cluster_id >= num_reg_cids:
589
- tail_size_m = cdimx % group_size_m
590
- tail_layout = cute.make_layout((tail_size_m, cdimy), stride=(1, tail_size_m))
591
- tail_cid = cluster_id - num_reg_cids
592
- tail_cid_m, tail_cid_n = tail_layout.get_flat_coord(tail_cid)
593
- cid_m = cute.size(s_shape, mode=[0]) + tail_cid_m
594
- cid_n = tail_cid_n
595
-
596
- # Get the pid from cluster id
597
- bidx_in_cluster = cute.arch.block_in_cluster_idx()
598
- pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0]
599
- pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1]
600
-
601
- _, _, bidz = cute.arch.block_idx()
602
- tile_coord_mnkl = (pid_m, pid_n, None, bidz)
603
- cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
604
- cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
605
-
606
- k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
607
-
608
- if warp_idx >= self.mma_warp_groups * 4:
710
+
711
+ if warp_idx >= self.ab_load_warp_id:
609
712
  cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
610
- if warp_idx == self.mma_warp_groups * 4:
713
+ if (
714
+ warp_idx >= self.ab_load_warp_id
715
+ and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
716
+ ):
717
+ is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
718
+ if const_expr(varlen_k):
719
+ # initialize tensormap for A & B
720
+ tensormap_manager.init_tensormap_from_atom(
721
+ tma_atom_a,
722
+ tensormap_a_ptr,
723
+ is_tma_warp,
724
+ )
725
+ tensormap_manager.init_tensormap_from_atom(
726
+ tma_atom_b,
727
+ tensormap_b_ptr,
728
+ is_tma_warp,
729
+ )
611
730
  # ///////////////////////////////////////////////////////////////////////////////
612
731
  # Get mcast mask
613
732
  # ///////////////////////////////////////////////////////////////////////////////
733
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
734
+ cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
614
735
  a_mcast_mask = cute.make_layout_image_mask(
615
- cta_layout_mnk, cluster_coord_mnk, mode=1
736
+ cluster_layout_mnk, cluster_coord_mnk, mode=1
616
737
  )
617
738
  b_mcast_mask = cute.make_layout_image_mask(
618
- cta_layout_mnk, cluster_coord_mnk, mode=0
739
+ cluster_layout_mnk, cluster_coord_mnk, mode=0
619
740
  )
620
741
  a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
621
742
  b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
622
- mainloop_producer_state = pipeline.make_pipeline_state(
743
+
744
+ # Persistent tile scheduling loop
745
+ is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
746
+ if const_expr(cute.size(cluster_layout_mnk) > 1):
747
+ is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
748
+ tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
749
+ work_tile = tile_scheduler.initial_work_tile_info()
750
+ ab_producer_state = make_pipeline_state(
623
751
  pipeline.PipelineUserType.Producer, self.ab_stage
624
752
  )
625
- # ///////////////////////////////////////////////////////////////////////////////
626
- # Local_tile partition global tensors
627
- # ///////////////////////////////////////////////////////////////////////////////
628
- # (bM, bK, RestK)
629
- gA_mkl = cute.local_tile(
630
- mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1)
631
- )
632
- # (bN, bK, RestK)
633
- gB_nkl = cute.local_tile(
634
- mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
635
- )
636
- # //////////////////////////////////////////////////////////////////////////////
637
- # Partition shared tensor for TMA load A/B
638
- # //////////////////////////////////////////////////////////////////////////////
639
- # TMA load A partition_S/D
640
- a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
641
- a_cta_crd = cluster_coord_mnk[1]
642
- tAsA, tAgA_mkl = cpasync.tma_partition(
643
- tma_atom_a,
644
- a_cta_crd,
645
- a_cta_layout,
646
- cute.group_modes(sA, 0, 2),
647
- cute.group_modes(gA_mkl, 0, 2),
648
- )
649
- # TMA load B partition_S/D
650
- b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
651
- b_cta_crd = cluster_coord_mnk[0]
652
- tBsB, tBgB_nkl = cpasync.tma_partition(
653
- tma_atom_b,
654
- b_cta_crd,
655
- b_cta_layout,
656
- cute.group_modes(sB, 0, 2),
657
- cute.group_modes(gB_nkl, 0, 2),
658
- )
659
- # /////////////////////////////////////////////////////////////////////////////
660
- # TMA load
661
- # /////////////////////////////////////////////////////////////////////////////
662
- for k_tile in cutlass.range(k_tile_cnt, unroll=1):
663
- # Wait for A/B buffers to be empty before loading into them
664
- # Also sets the transaction barrier for the A/B buffers
665
- mainloop_pipeline.producer_acquire(mainloop_producer_state)
666
- # /////////////////////////////////////////////////////////////////////////////
667
- # TMA load A/B
668
- # /////////////////////////////////////////////////////////////////////////////
669
- cute.copy(
670
- tma_atom_a,
671
- tAgA_mkl[None, k_tile],
672
- tAsA[None, mainloop_producer_state.index],
673
- tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state),
674
- mcast_mask=a_mcast_mask,
753
+ if const_expr(varlen_k):
754
+ # wait tensormap initialization complete before update
755
+ tensormap_manager.fence_tensormap_initialization()
756
+ # batch index of last tile
757
+ last_batch_idx = cutlass.Int32(-1)
758
+ while work_tile.is_valid_tile:
759
+ tile_coord_mnkl = work_tile.tile_idx
760
+ batch_idx = tile_coord_mnkl[3]
761
+ if const_expr(varlen_k):
762
+ is_group_changed = batch_idx != last_batch_idx
763
+ last_batch_idx = batch_idx
764
+ if is_group_changed:
765
+ # construct tensor A/B based on real address, shape and stride information
766
+ tensormap_manager.update_tensormap_shape(
767
+ (tensormap_a_ptr, tensormap_b_ptr),
768
+ is_manager_warp=is_tma_warp,
769
+ shapes=(cu_seqlens_k[batch_idx + 1], cu_seqlens_k[batch_idx + 1]),
770
+ orders=(
771
+ 0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1,
772
+ 0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1,
773
+ ),
774
+ tensormap_smem_ptr=None,
775
+ )
776
+ # ///////////////////////////////////////////////////////////////////////////
777
+ # Local_tile partition global tensors
778
+ # ///////////////////////////////////////////////////////////////////////////
779
+ if const_expr(not self.gather_A):
780
+ if const_expr(varlen_m):
781
+ mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
782
+ elif const_expr(varlen_k):
783
+ mA_mk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mA_mkl)
784
+ else:
785
+ mA_mk = mA_mkl[None, None, batch_idx]
786
+ # (bM, bK, RestK)
787
+ gA_k = cute.local_tile(
788
+ mA_mk,
789
+ cute.select(self.tile_shape_mnk, [0, 2]),
790
+ (tile_coord_mnkl[0], None),
791
+ )
792
+ else:
793
+ mA_mk = mA_mkl
794
+ if const_expr(varlen_m):
795
+ mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
796
+ elif const_expr(varlen_k):
797
+ mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
798
+ else:
799
+ mAIdx_mk = mAIdx[None, batch_idx]
800
+ gAIdx = cute.local_tile(
801
+ mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
802
+ )
803
+ if const_expr(varlen_k):
804
+ mB_nk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mB_nkl)
805
+ else:
806
+ mB_nk = mB_nkl[None, None, batch_idx]
807
+ # (bN, bK, RestK)
808
+ gB_k = cute.local_tile(
809
+ mB_nk, cute.select(self.tile_shape_mnk, [1, 2]), (tile_coord_mnkl[1], None)
675
810
  )
676
- cute.copy(
811
+ # //////////////////////////////////////////////////////////////////////////
812
+ # Partition shared tensor for TMA load A/B
813
+ # //////////////////////////////////////////////////////////////////////////
814
+ if const_expr(varlen_k):
815
+ # ensure the update to tensormap has completed before using it
816
+ if is_group_changed and is_tma_warp:
817
+ tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
818
+ tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
819
+ tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
820
+ tensormap_a_ptr, cute.AddressSpace.generic
821
+ )
822
+ tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
823
+ tensormap_b_ptr, cute.AddressSpace.generic
824
+ )
825
+ else:
826
+ tma_desc_a_ptr, tma_desc_b_ptr = None, None
827
+ # TMA load A partition_S/D
828
+ a_cta_layout = cute.make_layout(
829
+ cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
830
+ )
831
+ a_cta_crd = cluster_coord_mnk[1]
832
+ if const_expr(not self.gather_A):
833
+ # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
834
+ tAsA, tAgA_k = cpasync.tma_partition(
835
+ tma_atom_a,
836
+ a_cta_crd,
837
+ a_cta_layout,
838
+ cute.group_modes(sA, 0, 2),
839
+ cute.group_modes(gA_k, 0, 2),
840
+ )
841
+ copy_A = partial(
842
+ cute.copy,
843
+ tma_atom_a,
844
+ mcast_mask=a_mcast_mask,
845
+ tma_desc_ptr=tma_desc_a_ptr,
846
+ )
847
+ else:
848
+ tiled_copy_A = self._make_gmem_tiled_copy_A(
849
+ mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
850
+ )
851
+ tidx = (
852
+ cute.arch.thread_idx()[0]
853
+ - self.mma_warp_groups * self.num_threads_per_warp_group
854
+ )
855
+ thr_copy_A = tiled_copy_A.get_slice(tidx)
856
+ # (atom_v, CPY_M, 1, STAGE)
857
+ tAsA = thr_copy_A.partition_D(sA)
858
+ assert tAsA.shape[2] == 1
859
+ tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
860
+ copy_A = partial(cute.copy, tiled_copy_A)
861
+ # TMA load B partition_S/D
862
+ b_cta_layout = cute.make_layout(
863
+ cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
864
+ )
865
+ b_cta_crd = cluster_coord_mnk[0]
866
+ # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
867
+ tBsB, tBgB_k = cpasync.tma_partition(
677
868
  tma_atom_b,
678
- tBgB_nkl[None, k_tile],
679
- tBsB[None, mainloop_producer_state.index],
680
- tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state),
681
- mcast_mask=b_mcast_mask,
869
+ b_cta_crd,
870
+ b_cta_layout,
871
+ cute.group_modes(sB, 0, 2),
872
+ cute.group_modes(gB_k, 0, 2),
682
873
  )
683
- # Mainloop pipeline's producer commit is a NOP
684
- mainloop_pipeline.producer_commit(mainloop_producer_state)
685
- mainloop_producer_state.advance()
686
- mainloop_pipeline.producer_tail(mainloop_producer_state)
687
-
688
- if warp_idx < self.mma_warp_groups * 4:
874
+ copy_B = partial(
875
+ cute.copy, tma_atom_b, mcast_mask=b_mcast_mask, tma_desc_ptr=tma_desc_b_ptr
876
+ )
877
+ k_len = (
878
+ cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
879
+ if const_expr(varlen_k)
880
+ else mA_mkl.shape[1]
881
+ )
882
+ k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
883
+ if const_expr(not self.gather_A):
884
+ ab_producer_state = self.load_AB(
885
+ ab_pipeline,
886
+ ab_producer_state,
887
+ copy_A,
888
+ tAgA_k,
889
+ tAsA,
890
+ copy_B,
891
+ tBgB_k,
892
+ tBsB,
893
+ k_tile_cnt,
894
+ )
895
+ else:
896
+ limit_m = (
897
+ mAIdx.shape[0]
898
+ if const_expr(cu_seqlens_m is None)
899
+ else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
900
+ )
901
+ ab_producer_state = self.load_AB_gather_A(
902
+ ab_pipeline,
903
+ ab_producer_state,
904
+ thr_copy_A,
905
+ mA_mk,
906
+ tAsA,
907
+ gAIdx,
908
+ copy_B,
909
+ tBgB_k,
910
+ tBsB,
911
+ k_tile_cnt,
912
+ limit_A=(
913
+ limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
914
+ mA_mk.shape[1],
915
+ ),
916
+ )
917
+ tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
918
+ tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
919
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
920
+ work_tile = tile_scheduler.get_current_work()
921
+ # End of persistent scheduler loop
922
+ if const_expr(self.pingpong and not varlen_k):
923
+ # Need to write the tile_idx to smem for the next WG in the pingpong mode
924
+ # tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
925
+ tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
926
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
927
+ ab_pipeline.producer_tail(ab_producer_state)
928
+ if is_scheduler_warp:
929
+ tile_scheduler.producer_tail()
930
+
931
+ if warp_idx < self.ab_load_warp_id:
689
932
  cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
933
+ is_tma_warp = Boolean(
934
+ (not self.pingpong and warp_idx == 0)
935
+ or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
936
+ )
937
+ if const_expr(varlen_m):
938
+ # initialize tensormap for D
939
+ tensormap_manager.init_tensormap_from_atom(
940
+ tma_atom_d,
941
+ tensormap_d_ptr,
942
+ is_manager_warp=is_tma_warp,
943
+ )
690
944
  # //////////////////////////////////////////////////////////////////////////////
691
945
  # Partition global tensor for TiledMMA_A/B/C
692
946
  # //////////////////////////////////////////////////////////////////////////////
693
947
  tidx, _, _ = cute.arch.thread_idx()
694
948
  warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
949
+ if const_expr(self.pingpong):
950
+ tidx = tidx % self.num_threads_per_warp_group
695
951
  warp_group_thread_layout = cute.make_layout(
696
- self.mma_warp_groups, stride=self.num_threads_per_warp_group
952
+ self.mma_warp_groups if not self.pingpong else 1,
953
+ stride=self.num_threads_per_warp_group,
954
+ )
955
+ thr_mma = tiled_mma.get_slice(
956
+ warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
697
957
  )
698
- thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx))
699
958
 
700
959
  # //////////////////////////////////////////////////////////////////////////////
701
960
  # Make fragments
@@ -705,148 +964,818 @@ class HopperWgmmaGemmKernel:
705
964
 
706
965
  acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
707
966
  acc = cute.make_fragment(acc_shape, self.acc_dtype)
708
-
709
- mainloop_consumer_read_state = pipeline.make_pipeline_state(
710
- pipeline.PipelineUserType.Consumer, self.ab_stage
967
+ acc_slow = None
968
+ if const_expr(self.fp8_slow_accum):
969
+ acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
970
+
971
+ if const_expr(self.pingpong):
972
+ if warp_group_idx == 0:
973
+ # WG0 needs a start signal at the very beginning
974
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
975
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
976
+
977
+ k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.tile_shape_mnk[2])
978
+ c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
979
+
980
+ ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
981
+ epi_read_state = make_pipeline_state(
982
+ pipeline.PipelineUserType.Consumer, self.epi_c_stage
711
983
  )
712
- mainloop_consumer_release_state = pipeline.make_pipeline_state(
713
- pipeline.PipelineUserType.Consumer, self.ab_stage
984
+ epi_producer_state = make_pipeline_state(
985
+ pipeline.PipelineUserType.Producer, self.epi_c_stage
714
986
  )
987
+ tile_scheduler = TileSchedulerCls()
988
+ work_tile = None
989
+ if const_expr(self.pingpong):
990
+ if const_expr(varlen_k):
991
+ work_tile = tile_scheduler.initial_work_tile_info()
992
+ if warp_idx >= 4:
993
+ # Advance 2nd Math WG pipeline states to the end of 1st Math WG
994
+ epi_read_state.advance_iters(c_tile_cnt)
995
+ epi_producer_state.advance_iters(c_tile_cnt)
996
+ if const_expr(not varlen_k):
997
+ ab_read_state.advance_iters(k_tile_cnt_static)
998
+ else:
999
+ batch_idx = work_tile.tile_idx[3]
1000
+ k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1001
+ k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
1002
+ ab_read_state.advance_iters(k_tile_cnt)
1003
+ tile_scheduler.advance_to_next_work()
1004
+ if const_expr(varlen_k):
1005
+ work_tile = tile_scheduler.get_current_work()
1006
+ if const_expr(not varlen_k):
1007
+ work_tile = tile_scheduler.initial_work_tile_info()
1008
+ else:
1009
+ work_tile = tile_scheduler.initial_work_tile_info()
1010
+ if const_expr(varlen_m):
1011
+ # wait tensormap initialization complete before update
1012
+ tensormap_manager.fence_tensormap_initialization()
1013
+ # batch index of last tile
1014
+ last_batch_idx = cutlass.Int32(-1)
1015
+ while work_tile.is_valid_tile:
1016
+ tile_coord_mnkl = work_tile.tile_idx
1017
+ batch_idx = tile_coord_mnkl[3]
1018
+ if const_expr(varlen_m):
1019
+ is_group_changed = batch_idx != last_batch_idx
1020
+ last_batch_idx = batch_idx
1021
+ if is_group_changed:
1022
+ # construct tensor D based on real address, shape and stride information
1023
+ tensormap_manager.update_tensormap_shape(
1024
+ (tensormap_d_ptr,),
1025
+ is_manager_warp=is_tma_warp,
1026
+ shapes=(cu_seqlens_m[batch_idx + 1],),
1027
+ orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
1028
+ tensormap_smem_ptr=None,
1029
+ )
1030
+
1031
+ k_len = (
1032
+ cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1033
+ if const_expr(varlen_k)
1034
+ else mA_mkl.shape[1]
1035
+ )
1036
+ k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
1037
+ ab_read_state, tiled_mma = self.mma(
1038
+ ab_pipeline,
1039
+ ab_read_state,
1040
+ tiled_mma,
1041
+ tCrA,
1042
+ tCrB,
1043
+ acc,
1044
+ acc_slow,
1045
+ k_tile_cnt,
1046
+ warp_group_idx,
1047
+ )
1048
+ if const_expr(varlen_k):
1049
+ if k_tile_cnt == 0:
1050
+ acc.fill(0.0)
1051
+
1052
+ # /////////////////////////////////////////////////////////////////////////////
1053
+ # EPILOGUE
1054
+ # /////////////////////////////////////////////////////////////////////////////
1055
+ if const_expr(self.pingpong):
1056
+ self.pingpong_barrier_sync(warp_group_idx, "epi")
715
1057
 
716
- # /////////////////////////////////////////////////////////////////////////////
717
- # Prologue MMAs
718
- # /////////////////////////////////////////////////////////////////////////////
719
- k_pipe_mmas = 1
720
- peek_ab_full_status = cutlass.Boolean(1)
721
- if mainloop_consumer_read_state.count < k_tile_cnt:
722
- peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
723
- mainloop_consumer_read_state
1058
+ epilogue_barrier = pipeline.NamedBarrier(
1059
+ barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
724
1060
  )
725
- tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
726
- num_k_blocks = cute.size(tCrA, mode=[2])
727
- for k_tile in cutlass.range_constexpr(k_pipe_mmas):
728
- # Wait for A/B buffer to be ready
729
- mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status)
730
- warpgroup.fence()
731
- for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
732
- k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index)
733
- cute.gemm(tiled_mma, acc, tCrA[k_block_coord], tCrB[k_block_coord], acc)
734
- tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
735
- warpgroup.commit_group()
736
- mainloop_consumer_read_state.advance()
737
- peek_ab_full_status = cutlass.Boolean(1)
738
- if mainloop_consumer_read_state.count < k_tile_cnt:
739
- peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
740
- mainloop_consumer_read_state
741
- )
742
1061
 
743
- # /////////////////////////////////////////////////////////////////////////////
744
- # MAINLOOP
745
- # /////////////////////////////////////////////////////////////////////////////
746
- for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, unroll=1):
747
- # Wait for TMA copies to complete
748
- mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status)
749
- # WGMMA
750
- warpgroup.fence()
751
- for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
752
- k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index)
753
- cute.gemm(tiled_mma, acc, tCrA[k_block_coord], tCrB[k_block_coord], acc)
754
- warpgroup.commit_group()
755
- # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
756
- warpgroup.wait_group(k_pipe_mmas)
757
- mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
758
- mainloop_consumer_read_state.advance()
759
- mainloop_consumer_release_state.advance()
760
- peek_ab_full_status = cutlass.Boolean(1)
761
- if mainloop_consumer_read_state.count < k_tile_cnt:
762
- peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
763
- mainloop_consumer_read_state
1062
+ if const_expr(varlen_m):
1063
+ # ensure the update to tensormap has completed before using it
1064
+ if is_group_changed and is_tma_warp:
1065
+ tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
1066
+ tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
1067
+ tensormap_d_ptr, cute.AddressSpace.generic
764
1068
  )
765
- warpgroup.wait_group(0)
766
- for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
767
- mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
768
- mainloop_consumer_release_state.advance()
769
-
770
- # /////////////////////////////////////////////////////////////////////////////
771
- # EPILOGUE
772
- # /////////////////////////////////////////////////////////////////////////////
773
-
774
- # Wait for all warp groups in the thread block to finish, because smem for tensor A in
775
- # the mainloop is reused in the epilogue.
776
- cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
777
-
778
- copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
779
- self.d_layout,
780
- elem_ty_d=self.d_dtype,
781
- elem_ty_acc=self.acc_dtype,
782
- )
783
- copy_atom_D = cute.make_copy_atom(
784
- warp.StMatrix8x8x16bOp(self.d_layout.is_m_major_c(), 4),
785
- self.d_dtype,
786
- )
787
- tiled_copy_D_Atom = cute.make_tiled_copy_C_atom(copy_atom_D, tiled_mma)
788
- tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_D_Atom)
789
- # (R2S, R2S_M, R2S_N, PIPE_D)
790
- tRS_sD = tiled_copy_r2s.get_slice(tidx).partition_D(sD)
791
- # (R2S, R2S_M, R2S_N)
792
- tRS_rAcc = tiled_copy_r2s.retile(acc)
793
-
794
- # (bM, bN)
795
- gD_mnl = cute.local_tile(
796
- mD_mnl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None)
1069
+ else:
1070
+ tma_desc_d_ptr = None
1071
+
1072
+ if const_expr(has_D):
1073
+ bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(
1074
+ tma_atom_d,
1075
+ mD_mnl,
1076
+ self.tile_shape_mnk[:2],
1077
+ self.epi_tile,
1078
+ sD,
1079
+ tile_coord_mnkl,
1080
+ cu_seqlens_m,
1081
+ )
1082
+ copy_D = partial(cute.copy, tma_atom_d, tma_desc_ptr=tma_desc_d_ptr)
1083
+ else:
1084
+ bSG_sD, bSG_gD, copy_D = None, None, None
1085
+ if const_expr(has_C):
1086
+ bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
1087
+ tma_atom_c,
1088
+ mC_mnl,
1089
+ self.tile_shape_mnk[:2],
1090
+ self.epi_tile,
1091
+ sC,
1092
+ tile_coord_mnkl,
1093
+ cu_seqlens_m,
1094
+ )
1095
+ copy_C = partial(cute.copy, tma_atom_c)
1096
+ epi_load_g2s = partial(self.epi_load_g2s, epi_pipeline, copy_C, bGS_gC, bGS_sC)
1097
+ else:
1098
+ epi_load_g2s = None
1099
+
1100
+ d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
1101
+ tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
1102
+ tiled_mma, self.d_layout, d_dtype_for_layout, acc, sD, tidx
1103
+ )
1104
+ if const_expr(has_C):
1105
+ tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
1106
+ tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
1107
+ )
1108
+ else:
1109
+ tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
1110
+
1111
+ # Wait for all warp groups in the thread block to finish, because smem for tensor
1112
+ # A in the mainloop is reused in the epilogue if not persistent.
1113
+ if const_expr(not self.is_persistent):
1114
+ epilogue_barrier.arrive_and_wait()
1115
+
1116
+ self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
1117
+
1118
+ epi_read_state, epi_producer_state = self.epilogue(
1119
+ epilogue_params,
1120
+ epi_smem_tensors,
1121
+ epi_pipeline,
1122
+ epi_read_state,
1123
+ epi_producer_state,
1124
+ tiled_mma,
1125
+ tRS_rAcc,
1126
+ tRS_rD,
1127
+ tRS_rC,
1128
+ tiled_copy_r2s,
1129
+ tRS_sD,
1130
+ tiled_copy_s2r,
1131
+ tSR_rC,
1132
+ tSR_sC,
1133
+ copy_D,
1134
+ bSG_sD,
1135
+ bSG_gD,
1136
+ epi_load_g2s,
1137
+ tile_coord_mnkl,
1138
+ cu_seqlens_m,
1139
+ epilogue_barrier,
1140
+ tile_scheduler,
1141
+ tidx,
1142
+ is_tma_warp,
1143
+ )
1144
+
1145
+ if const_expr(self.pingpong):
1146
+ # With pingpong, 2 WGs write two different output tiles to the same smem,
1147
+ # so we have to make sure the smem content has been fully read before signaling
1148
+ # the next WG's epilogue.
1149
+ if is_tma_warp:
1150
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
1151
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
1152
+
1153
+ if const_expr(not self.pingpong):
1154
+ tile_scheduler.advance_to_next_work()
1155
+ work_tile = tile_scheduler.get_current_work()
1156
+ else: # Skip a tile for pingpong
1157
+ # Update starting load/store pipeline states for the next tile
1158
+ epi_read_state.advance_iters(c_tile_cnt)
1159
+ epi_producer_state.advance_iters(c_tile_cnt)
1160
+ # Update starting mainloop pipeline state for the next tile
1161
+ if const_expr(not varlen_k):
1162
+ ab_read_state.advance_iters(k_tile_cnt_static)
1163
+ tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups)
1164
+ work_tile = tile_scheduler.get_current_work()
1165
+ else:
1166
+ tile_scheduler.advance_to_next_work()
1167
+ work_tile = tile_scheduler.get_current_work()
1168
+ if work_tile.is_valid_tile:
1169
+ batch_idx = work_tile.tile_idx[3]
1170
+ k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1171
+ k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
1172
+ ab_read_state.advance_iters(k_tile_cnt)
1173
+ tile_scheduler.advance_to_next_work()
1174
+ work_tile = tile_scheduler.get_current_work()
1175
+ # End of persistent scheduler loop
1176
+
1177
+ if const_expr(not self.pingpong):
1178
+ if is_tma_warp:
1179
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
1180
+
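
The pingpong path above keeps the two math warp groups one output tile apart: the 2nd WG starts with its epilogue and mainloop read states advanced by one tile's worth of iterations, and each WG then skips the other WG's tile when moving to its next one. A minimal pure-Python sketch of that bookkeeping (tile and stage counts are illustrative, not the kernel's defaults):

def pingpong_offsets(num_tiles, c_tile_cnt, k_tile_cnt):
    # Per-WG iteration counts into the epilogue (D subtiles) and A/B pipelines.
    # WG1 starts one tile behind WG0, so its states begin advanced by one tile.
    state = {0: {"epi": 0, "ab": 0}, 1: {"epi": c_tile_cnt, "ab": k_tile_cnt}}
    owner = []
    for tile in range(num_tiles):
        wg = tile % 2
        owner.append(wg)
        # Processing one tile and skipping the other WG's tile advances by 2 tiles.
        state[wg]["epi"] += 2 * c_tile_cnt
        state[wg]["ab"] += 2 * k_tile_cnt
    return owner, state

owner, state = pingpong_offsets(num_tiles=5, c_tile_cnt=4, k_tile_cnt=8)
print(owner)   # [0, 1, 0, 1, 0]
print(state)   # accumulated pipeline offsets per warp group
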
1181
+ @cute.jit
1182
+ def load_AB(
1183
+ self,
1184
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1185
+ ab_producer_state: cutlass.pipeline.PipelineState,
1186
+ copy_A: Callable,
1187
+ tAgA: cute.Tensor,
1188
+ tAsA: cute.Tensor,
1189
+ copy_B: Callable,
1190
+ tBgB: cute.Tensor,
1191
+ tBsB: cute.Tensor,
1192
+ k_tile_cnt: Int32,
1193
+ ) -> cutlass.pipeline.PipelineState:
1194
+ # Peek (try_wait) whether the A/B buffer for the first k_tile is empty
1195
+ peek_ab_empty_status = Boolean(True)
1196
+ if 0 < k_tile_cnt:
1197
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1198
+ # /////////////////////////////////////////////////////////////////////////
1199
+ # TMA load
1200
+ # /////////////////////////////////////////////////////////////////////////
1201
+ for k_tile in cutlass.range(k_tile_cnt, unroll=1):
1202
+ # Wait for A/B buffers to be empty before loading into them
1203
+ # Also sets the transaction barrier for the A/B buffers
1204
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
1205
+ tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
1206
+ copy_A(tAgA[None, k_tile], tAsA[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
1207
+ copy_B(tBgB[None, k_tile], tBsB[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
1208
+ # Mainloop pipeline's producer commit is a NOP
1209
+ ab_pipeline.producer_commit(ab_producer_state)
1210
+ ab_producer_state.advance()
1211
+ peek_ab_empty_status = Boolean(True)
1212
+ if k_tile + 1 < k_tile_cnt:
1213
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1214
+ return ab_producer_state
1215
+
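
load_AB above is a standard producer loop over a ring of ab_stage shared-memory buffers: peek whether the next buffer is free, acquire it (arming the TMA transaction barrier), issue the A and B copies, commit, and advance. A schematic pure-Python model of just the buffer indexing (the real synchronization is done by mbarriers, not Python state):

class PipelineStateModel:
    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.count = 0                  # total iterations issued so far
    @property
    def index(self):                    # which smem buffer this iteration uses
        return self.count % self.num_stages
    def advance(self):
        self.count += 1

def load_ab_model(k_tile_cnt, ab_stage=4):
    state = PipelineStateModel(ab_stage)
    schedule = []
    for k_tile in range(k_tile_cnt):
        # acquire buffer state.index, copy A/B for this k_tile into it, commit
        schedule.append((k_tile, state.index))
        state.advance()
    return schedule

print(load_ab_model(6))  # [(0, 0), (1, 1), (2, 2), (3, 3), (4, 0), (5, 1)]
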
1216
+ @cute.jit
1217
+ def load_AB_gather_A(
1218
+ self,
1219
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1220
+ ab_producer_state: cutlass.pipeline.PipelineState,
1221
+ thr_copy_A: cute.core.ThrCopy,
1222
+ mA: cute.Tensor,
1223
+ tAsA: cute.Tensor,
1224
+ gAIdx: cute.Tensor,
1225
+ copy_B: Callable,
1226
+ tBgB: cute.Tensor,
1227
+ tBsB: cute.Tensor,
1228
+ k_tile_cnt: Int32,
1229
+ limit_A: Tuple[Int32, Int32],
1230
+ ) -> cutlass.pipeline.PipelineState:
1231
+ # (atom_v, CPY_M, 1, RestK)
1232
+ limit_m, limit_k = limit_A
1233
+ limit_m = min(limit_m, self.tile_shape_mnk[0]) # To avoid writing beyond smem limit
1234
+ cA = cute.make_identity_tensor(cute.select(self.tile_shape_mnk, [0, 2]))
1235
+ tAcA = thr_copy_A.partition_S(cA)
1236
+ t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
1237
+ # Instead of comparing tAcA to limit_m, we compare t0AcA to limit_m - tAcA[0][0]
1238
+ # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
1239
+ # This is so that when we do the comparison, t0AcA is known at compile time.
1240
+ limit_m = limit_m - tAcA[0][0]
1241
+ # Read indices for A
1242
+ rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
1243
+ m_idx = cute.make_fragment(rows_per_thread, Int32)
1244
+ for m in cutlass.range(rows_per_thread):
1245
+ row_idx = tAcA[0, m, 0][0]
1246
+ if t0AcA[0, m, 0][0] < limit_m:
1247
+ m_idx[m] = gAIdx[row_idx]
1248
+ else:
1249
+ m_idx[m] = -1
1250
+ elems_per_load = cute.size(tAsA.shape[0][0])
1251
+ # (m, (bK, RestK))
1252
+ mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
1253
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1254
+ # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
1255
+ peek_ab_empty_status = Boolean(True)
1256
+ if 0 < k_tile_cnt:
1257
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1258
+ # /////////////////////////////////////////////////////////////////////////
1259
+ # TMA load on B and cp.async on A
1260
+ # /////////////////////////////////////////////////////////////////////////
1261
+ copy_A = partial(cute.copy, thr_copy_A)
1262
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
1263
+ # Wait for A/B buffers to be empty before loading into them
1264
+ # Also sets the transaction barrier for the A/B buffers
1265
+ ab_pipeline.producer_acquire(
1266
+ ab_producer_state,
1267
+ peek_ab_empty_status,
1268
+ # A tiny bit faster to rotate the warp that does TMA
1269
+ is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
797
1270
  )
798
- tcgc_for_tma_partition = cute.zipped_divide(gD_mnl, self.epi_tile)
799
- bSG_sD, bSG_gD = cpasync.tma_partition(
800
- tma_atom_d,
801
- 0,
802
- cute.make_layout(1),
803
- cute.group_modes(sD, 0, 2),
804
- tcgc_for_tma_partition,
1271
+ # A bit faster to load B first while we calculate the predicate for A
1272
+ if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
1273
+ copy_B(
1274
+ tBgB[None, k_tile],
1275
+ tBsB[None, ab_producer_state.index],
1276
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1277
+ )
1278
+ # (m, bK)
1279
+ mA_cur = mA_k[None, (None, k_tile)]
1280
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1281
+ # (elems_per_load, thread_per_row)
1282
+ mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
1283
+ if t0AcA[0, m, 0][0] < limit_m:
1284
+ # There's only 1 load per row
1285
+ assert cute.size(tAcA.shape, mode=[2]) == 1
1286
+ ki = tAcA[0, 0, 0][1] // elems_per_load
1287
+ copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1288
+ # This tells mbarrier to track the completion of cp.async
1289
+ ab_pipeline.producer_commit(ab_producer_state)
1290
+ ab_producer_state.advance()
1291
+ peek_ab_empty_status = Boolean(True)
1292
+ if k_tile + 1 < k_tile_cnt:
1293
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1294
+ # bound checking in the K dimension on the last k_tile
1295
+ if 0 < k_tile_cnt:
1296
+ k_tile = k_tile_cnt - 1
1297
+ ab_pipeline.producer_acquire(
1298
+ ab_producer_state,
1299
+ peek_ab_empty_status,
1300
+ is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
805
1301
  )
1302
+ if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
1303
+ copy_B(
1304
+ tBgB[None, k_tile],
1305
+ tBsB[None, ab_producer_state.index],
1306
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1307
+ )
1308
+ assert tAcA.shape[2] == 1 # there's only 1 load along the K dimension
1309
+ tApA = cute.make_fragment(1, Boolean)
1310
+ tApA[0] = tAcA[0, 0, 0][1] < limit_k
1311
+ # (m, bK)
1312
+ mA_cur = mA_k[None, (None, k_tile)]
1313
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1314
+ # (elems_per_load, thread_per_row)
1315
+ mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
1316
+ if t0AcA[0, m, 0][0] < limit_m:
1317
+ # There's only 1 load per row
1318
+ assert cute.size(tAcA.shape, mode=[2]) == 1
1319
+ ki = tAcA[0, 0, 0][1] // elems_per_load
1320
+ # copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA)
1321
+ # TODO
1322
+ copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1323
+ ab_pipeline.producer_commit(ab_producer_state)
1324
+ ab_producer_state.advance()
1325
+ return ab_producer_state
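
The row predicate above relies on the identity tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]: comparing thread 0's coordinates against limit_m - tAcA[0][0] is equivalent to comparing this thread's coordinates against limit_m, but keeps one side of the comparison known at compile time. A small arithmetic check with made-up values:

thread_offset = 48             # tAcA[0][0] for this particular thread (made up)
t0_rows = [0, 16, 32, 64]      # t0AcA[m][0]: thread 0's row coords (compile-time)
limit_m = 100
for r0 in t0_rows:
    original = (r0 + thread_offset) < limit_m     # tAcA[m][0] < limit_m
    rewritten = r0 < (limit_m - thread_offset)    # t0AcA[m][0] < adjusted limit
    assert original == rewritten
print("predicates agree")
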
806
1326
 
807
- epi_tile_num = cutlass.const_expr(cute.size(tcgc_for_tma_partition, mode=[1]))
808
- epi_tile_shape = tcgc_for_tma_partition.shape[1]
1327
+ @cute.jit
1328
+ def mma(
1329
+ self,
1330
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1331
+ ab_read_state: cutlass.pipeline.PipelineState,
1332
+ tiled_mma: cute.TiledMma,
1333
+ tCrA: cute.Tensor,
1334
+ tCrB: cute.Tensor,
1335
+ acc: cute.Tensor,
1336
+ acc_slow: Optional[cute.Tensor],
1337
+ k_tile_cnt: Int32,
1338
+ warp_group_idx: Int32,
1339
+ ) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
1340
+ # /////////////////////////////////////////////////////////////////////////////
1341
+ # Prologue MMAs
1342
+ # /////////////////////////////////////////////////////////////////////////////
1343
+ k_pipe_mmas = 1
1344
+ ab_release_state = ab_read_state.clone()
1345
+ num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
1346
+ if const_expr(self.pingpong):
1347
+ self.pingpong_barrier_sync(warp_group_idx, stage="mma")
1348
+ peek_ab_full_status = Boolean(True)
1349
+ if 0 < k_tile_cnt:
1350
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1351
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
1352
+ num_k_blocks = cute.size(tCrA, mode=[2])
1353
+ for k_tile in cutlass.range(num_prologue_mma):
1354
+ # Wait for A/B buffer to be ready
1355
+ ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
1356
+ warpgroup.fence()
1357
+ for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1358
+ k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
1359
+ cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1360
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
1361
+ warpgroup.commit_group()
1362
+ ab_read_state.advance()
1363
+ peek_ab_full_status = Boolean(True)
1364
+ if k_tile + 1 < k_tile_cnt:
1365
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1366
+ # If k_tile_cnt == 0, this is not correct. But we will set acc to 0 in the mainloop
1367
+ # in that case.
1368
+ if const_expr(self.fp8_slow_accum):
1369
+ warpgroup.wait_group(0)
1370
+ acc_slow.store(acc.load())
809
1371
 
810
- for epi_idx in cutlass.range_constexpr(epi_tile_num):
811
- # Copy from acc to D registers
812
- tRS_rD = cute.make_fragment_like(tRS_sD[None, None, None, 0], self.acc_dtype)
813
- for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
814
- tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
815
- # Type conversion
816
- tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
817
- tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
818
- # Copy from D registers to shared memory
819
- epi_buffer = epi_idx % cute.size(tRS_sD, mode=[3])
820
- # cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
821
- cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)])
1372
+ # /////////////////////////////////////////////////////////////////////////////
1373
+ # MAINLOOP
1374
+ # /////////////////////////////////////////////////////////////////////////////
1375
+ for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
1376
+ # Wait for TMA copies to complete
1377
+ ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
1378
+ # WGMMA
1379
+ warpgroup.fence()
1380
+ if const_expr(self.fp8_slow_accum):
1381
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
1382
+ for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1383
+ k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
1384
+ cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1385
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
1386
+ warpgroup.commit_group()
1387
+ # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
1388
+ if const_expr(not self.fp8_slow_accum):
1389
+ warpgroup.wait_group(k_pipe_mmas)
1390
+ else:
1391
+ warpgroup.wait_group(0)
1392
+ acc_slow.store(acc_slow.load() + acc.load())
1393
+ ab_pipeline.consumer_release(ab_release_state)
1394
+ ab_read_state.advance()
1395
+ ab_release_state.advance()
1396
+ peek_ab_full_status = Boolean(True)
1397
+ if k_tile + 1 < k_tile_cnt:
1398
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1399
+ if const_expr(self.pingpong):
1400
+ # Cue for next WG's MMA to start
1401
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
1402
+ if const_expr(not self.fp8_slow_accum):
1403
+ # fp8_slow_accum would already called wait_group(0) inside the loop
1404
+ warpgroup.wait_group(0)
1405
+ for k_tile in cutlass.range(num_prologue_mma, unroll=1):
1406
+ ab_pipeline.consumer_release(ab_release_state)
1407
+ ab_release_state.advance()
1408
+ if const_expr(self.fp8_slow_accum):
1409
+ acc.store(acc_slow.load())
1410
+ # If we don't return the tiled_mma, we get the compiler error
1411
+ # "operand #0 does not dominate this use"
1412
+ return ab_read_state, tiled_mma
1413
+
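
mma above keeps k_pipe_mmas = 1 WGMMA group in flight: the prologue issues the first group without waiting, the steady state waits on the group from one iteration earlier and only then releases that A/B buffer, and the tail drains the remaining releases. A pure-Python trace of that schedule (k_tile_cnt is illustrative):

def mma_trace(k_tile_cnt, k_pipe_mmas=1):
    events = []
    read, release = 0, 0
    for _ in range(min(k_pipe_mmas, k_tile_cnt)):   # prologue: issue without waiting
        events.append(f"issue wgmma on k_tile {read}")
        read += 1
    for _ in range(k_pipe_mmas, k_tile_cnt):        # steady state
        events.append(f"issue wgmma on k_tile {read}")
        events.append(f"release A/B buffer of k_tile {release}")
        read += 1
        release += 1
    while release < k_tile_cnt:                     # tail: drain remaining releases
        events.append(f"release A/B buffer of k_tile {release}")
        release += 1
    return events

for e in mma_trace(4):
    print(e)
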
1414
+ @cute.jit
1415
+ def epilogue(
1416
+ self,
1417
+ params: EpilogueParams,
1418
+ epi_smem_tensors: Tuple[cute.Tensor, ...],
1419
+ epi_pipeline: cutlass.pipeline.PipelineAsync,
1420
+ epi_read_state: cutlass.pipeline.PipelineState,
1421
+ epi_producer_state: cutlass.pipeline.PipelineState,
1422
+ tiled_mma: cute.TiledMma,
1423
+ tRS_rAcc: cute.Tensor,
1424
+ tRS_rD: cute.Tensor,
1425
+ tRS_rC: Optional[cute.Tensor],
1426
+ tiled_copy_r2s: cute.core.ThrCopy,
1427
+ tRS_sD: cute.Tensor,
1428
+ tiled_copy_s2r: Optional[cute.core.ThrCopy],
1429
+ tSR_rC: Optional[cute.Tensor],
1430
+ tSR_sC: Optional[cute.Tensor],
1431
+ copy_D: Optional[Callable],
1432
+ bSG_sD: cute.Tensor,
1433
+ bSG_gD: cute.Tensor,
1434
+ epi_load_g2s: Optional[Callable],
1435
+ tile_coord_mnkl: cute.Coord,
1436
+ cu_seqlens_m: Optional[cute.Tensor],
1437
+ epilogue_barrier: cutlass.pipeline.NamedBarrier,
1438
+ tile_scheduler,
1439
+ tidx: Int32,
1440
+ is_tma_warp: Boolean,
1441
+ ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
1442
+ has_C = const_expr(tRS_rC is not None)
1443
+ has_D = const_expr(copy_D is not None)
1444
+ # We iterate over epi tiles in the N dimension before the M dimension
1445
+ epi_tile_shape = cute.zipped_divide(
1446
+ cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
1447
+ ).shape[1]
1448
+ epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
1449
+ epi_tile_num = cute.size(epi_tile_shape)
1450
+ num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
1451
+
1452
+ if const_expr(epi_load_g2s is not None):
1453
+ for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
1454
+ epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
1455
+
1456
+ for epi_idx in cutlass.range_constexpr(epi_tile_num):
1457
+ # Copy from acc to D registers
1458
+ for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
1459
+ tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1460
+ if const_expr(has_C):
1461
+ epi_pipeline.consumer_wait(epi_read_state)
1462
+ cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
1463
+ # Fence to make sure shared memory read is visible to TMA load
822
1464
  cute.arch.fence_proxy(
823
1465
  cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
824
1466
  )
825
- # barrier for sync
826
- cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
827
- # Get the global memory coordinate for the current epi tile.
828
- epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
829
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
830
- # Copy from shared memory to global memory
831
- if warp_idx == 0:
832
- cute.copy(tma_atom_d, bSG_sD[(None, epi_buffer)], bSG_gD[(None, gmem_coord)])
833
- cute.arch.cp_async_bulk_commit_group()
834
- # TODO: when moving to persistent maybe we always need this wait_group
835
- if epi_idx >= self.epi_stage - 1:
836
- cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
837
- if epi_idx >= self.epi_stage - 1:
838
- cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
839
-
840
- if warp_idx == 0:
841
- cute.arch.cp_async_bulk_wait_group(0, read=True)
1467
+ cute.arch.sync_warp()
1468
+ with cute.arch.elect_one():
1469
+ epi_pipeline.consumer_release(epi_read_state)
1470
+ epi_read_state.advance()
1471
+ if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
1472
+ epi_producer_state = epi_load_g2s(
1473
+ epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
1474
+ )
1475
+ tRS_rEpi = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
1476
+ epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
1477
+ # Copy from D registers to shared memory
1478
+ if const_expr(has_D):
1479
+ # Type conversion
1480
+ tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
1481
+ tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
1482
+ cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
1483
+ # Fence and barrier to make sure shared memory store is visible to TMA store
1484
+ cute.arch.fence_proxy(
1485
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1486
+ )
1487
+ epilogue_barrier.arrive_and_wait()
1488
+ # Get the global memory coordinate for the current epi tile
1489
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1490
+ # Copy from shared memory to global memory
1491
+ if is_tma_warp:
1492
+ if const_expr(has_D):
1493
+ copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
1494
+ cute.arch.cp_async_bulk_commit_group()
1495
+ cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
1496
+ epilogue_barrier.arrive_and_wait()
1497
+
1498
+ return epi_read_state, epi_producer_state
1499
+
1500
+ @cute.jit
1501
+ def epi_load_g2s(
1502
+ self,
1503
+ epi_pipeline: cutlass.pipeline.PipelineAsync,
1504
+ copy_C: Callable,
1505
+ bGS_gC: cute.Tensor,
1506
+ bGS_sC: cute.Tensor,
1507
+ epi_producer_state: cutlass.pipeline.PipelineState,
1508
+ epi_idx: Int32,
1509
+ should_load: Boolean,
1510
+ ) -> cutlass.pipeline.PipelineState:
1511
+ # We iterate over epi tiles in the N dimension first before the M dimension
1512
+ epi_tile_layout = cute.make_layout(bGS_gC.shape[1], stride=(bGS_gC.shape[1][1], 1))
1513
+ if should_load:
1514
+ epi_pipeline.producer_acquire(epi_producer_state)
1515
+ # Get the global memory coordinate for the current epi tile
1516
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1517
+ copy_C(
1518
+ bGS_gC[None, gmem_coord],
1519
+ bGS_sC[None, epi_producer_state.index],
1520
+ tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1521
+ )
1522
+ # Epi pipeline's producer commit is a NOP
1523
+ epi_pipeline.producer_commit(epi_producer_state)
1524
+ epi_producer_state.advance()
1525
+ return epi_producer_state
1526
+
1527
+ def epi_visit_acc_subtile(
1528
+ self,
1529
+ params: EpilogueParams,
1530
+ tRS_rD: cute.Tensor,
1531
+ tRS_rC: Optional[cute.Tensor] = None,
1532
+ ) -> Optional[cute.Tensor]:
1533
+ # Apply alpha scaling to accumulator if alpha is provided (not None)
1534
+ if const_expr(hasattr(params, "alpha") and params.alpha is not None):
1535
+ alpha = utils.load_scalar_or_pointer(params.alpha)
1536
+ tRS_rD.store(tRS_rD.load() * alpha)
1537
+ # Apply C with beta scaling
1538
+ if const_expr(tRS_rC is not None):
1539
+ if const_expr(not hasattr(params, "beta") or params.beta is None):
1540
+ # beta is None, default behavior: add C (beta=1.0)
1541
+ tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
1542
+ else:
1543
+ beta = utils.load_scalar_or_pointer(params.beta)
1544
+ tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
1545
+ return None
1546
+
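
epi_visit_acc_subtile implements the usual GEMM epilogue scaling: D = alpha * acc when alpha is given, plus beta * C when C is present (beta defaults to 1.0 when omitted). A plain-Python reference of the same per-element arithmetic:

def epilogue_reference(acc, c=None, alpha=None, beta=None):
    # alpha omitted -> no scaling; C present with beta omitted -> beta = 1.0
    d = [a * alpha for a in acc] if alpha is not None else list(acc)
    if c is not None:
        b = 1.0 if beta is None else beta
        d = [x + b * y for x, y in zip(d, c)]
    return d

print(epilogue_reference([1.0, 2.0], c=[10.0, 20.0], alpha=2.0, beta=0.5))  # [7.0, 14.0]
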
1547
+ def get_scheduler_class(self):
1548
+ """Return the scheduler class to use. Override in subclasses for custom schedulers."""
1549
+ return TileScheduler
1550
+
1551
+ def get_scheduler_arguments(self, problem_shape_ntile_mnl, scheduler_args):
1552
+ """Create scheduler arguments. Override in subclasses for custom schedulers."""
1553
+ return TileSchedulerArguments(
1554
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
1555
+ raster_order=scheduler_args.raster_order,
1556
+ group_size=scheduler_args.max_swizzle_size,
1557
+ cluster_shape_mnk=self.cluster_shape_mnk,
1558
+ tile_count_semaphore=scheduler_args.tile_count_semaphore,
1559
+ batch_idx_permute=scheduler_args.batch_idx_permute,
1560
+ is_persistent=self.is_persistent,
1561
+ )
1562
+
1563
+ def epi_visit_acc(
1564
+ self,
1565
+ params: EpilogueParams,
1566
+ acc: cute.Tensor,
1567
+ tiled_mma: cute.TiledMma,
1568
+ tile_coord_mnkl: cute.Coord,
1569
+ tidx: Int32,
1570
+ ) -> None:
1571
+ pass
1572
+
1573
+ def epi_to_underlying_arguments(
1574
+ self, args: EpilogueArguments, *, loc=None, ip=None
1575
+ ) -> EpilogueParams:
1576
+ return GemmSm90.EpilogueParams(alpha=args.alpha, beta=args.beta)
842
1577
 
843
1578
  @staticmethod
1579
+ def epi_smem_bytes_per_stage(
1580
+ args: Optional[EpilogueArguments],
1581
+ tile_shape_mnk: Tuple[int, int, int],
1582
+ epi_tile: Tuple[int, int],
1583
+ ) -> int:
1584
+ return 0
1585
+
1586
+ def epi_get_smem_struct(self, params: EpilogueParams):
1587
+ return cute.struct.MemRange[cutlass.Int32, 0] # Dummy struct
1588
+
1589
+ def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
1590
+ return tuple()
1591
+
1592
+ def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
1593
+ assert stage in ["mma", "epi"]
1594
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1595
+ cute.arch.barrier(
1596
+ barrier_id=int(barrier) + warp_group_idx,
1597
+ number_of_threads=2 * self.num_threads_per_warp_group,
1598
+ )
1599
+
1600
+ def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
1601
+ assert stage in ["mma", "epi"]
1602
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1603
+ cute.arch.barrier_arrive(
1604
+ barrier_id=int(barrier) + warp_group_idx,
1605
+ number_of_threads=2 * self.num_threads_per_warp_group,
1606
+ )
1607
+
1608
+ def epilog_smem_copy_atom(self, tiled_mma: cute.TiledMma) -> cute.TiledCopy:
1609
+ copy_atom_C = cute.make_copy_atom(
1610
+ warp.StMatrix8x8x16bOp(
1611
+ self.d_layout.is_m_major_c() if self.d_layout is not None else False,
1612
+ num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
1613
+ ),
1614
+ cutlass.Float16, # this is just to get the right source layout
1615
+ )
1616
+ tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
1617
+ return tiled_copy_C_atom
1618
+
1619
+ def epilog_smem_store_and_partition(
1620
+ self,
1621
+ tiled_mma: cute.TiledMma,
1622
+ d_layout: Optional[LayoutEnum],
1623
+ dtype: Type[cutlass.Numeric],
1624
+ acc: cute.Tensor,
1625
+ sD: cute.Tensor,
1626
+ tidx: Int32,
1627
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, Optional[cute.Tensor]]:
1628
+ if d_layout is None:
1629
+ d_layout = LayoutEnum.ROW_MAJOR
1630
+ tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
1631
+ # Doesn't work when tile_N % 8 == 0 but tile_N % 16 != 0, since this always
1632
+ # gets st.matrix with num_matrices=4
1633
+ copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
1634
+ d_layout, elem_ty_d=dtype, elem_ty_acc=self.acc_dtype
1635
+ )
1636
+ tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
1637
+ # (R2S, R2S_M, R2S_N, PIPE_D)
1638
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
1639
+ tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
1640
+ # (R2S, R2S_M, R2S_N)
1641
+ tRS_rAcc = tiled_copy_r2s.retile(acc)
1642
+ sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
1643
+ tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
1644
+ tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
1645
+ return tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD
1646
+
1647
+ def epilog_smem_load_and_partition(
1648
+ self,
1649
+ tiled_mma: cute.TiledMma,
1650
+ c_layout: LayoutEnum,
1651
+ dtype: Type[cutlass.Numeric],
1652
+ sC: cute.Tensor,
1653
+ tRS_rD_layout: cutlass.Layout,
1654
+ tidx: Int32,
1655
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]:
1656
+ tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
1657
+ copy_atom_s2r = utils.sm90_get_smem_load_op(c_layout, dtype)
1658
+ tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
1659
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1660
+ tSR_sC = thr_copy_s2r.partition_S(sC)
1661
+ tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
1662
+ tSR_rC = thr_copy_s2r.retile(tRS_rC)
1663
+ return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
1664
+
1665
+ def epilog_gmem_copy_and_partition(
1666
+ self,
1667
+ atom: Union[cute.CopyAtom, cute.TiledCopy],
1668
+ mD_mnl: cute.Tensor,
1669
+ tile_shape_mn: cute.Tile,
1670
+ epi_tile: cute.Tile,
1671
+ sD: cute.Tensor,
1672
+ tile_coord_mnkl: cute.Coord,
1673
+ cu_seqlens_m: Optional[cute.Tensor] = None,
1674
+ ) -> Tuple[cute.Tensor, cute.Tensor]:
1675
+ batch_idx = tile_coord_mnkl[3]
1676
+ if const_expr(cu_seqlens_m is not None):
1677
+ mD_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl)
1678
+ else:
1679
+ mD_mn = mD_mnl[None, None, batch_idx]
1680
+ # (bM, bN)
1681
+ gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
1682
+ tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
1683
+ bSG_sD, bSG_gD = cpasync.tma_partition(
1684
+ atom,
1685
+ 0,
1686
+ cute.make_layout(1),
1687
+ cute.group_modes(sD, 0, 2),
1688
+ tDgD_for_tma_partition,
1689
+ )
1690
+ return bSG_sD, bSG_gD
1691
+
1692
+ def make_ab_pipeline(
1693
+ self,
1694
+ a_smem_layout: cute.Layout | cute.ComposedLayout,
1695
+ b_smem_layout: cute.Layout | cute.ComposedLayout,
1696
+ tiled_mma: cute.TiledMma,
1697
+ cluster_layout_vmnk: cute.Layout,
1698
+ ab_pipeline_mbar_ptr: cute.Pointer,
1699
+ ):
1700
+ # Threads/warps participating in this pipeline
1701
+ producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
1702
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
1703
+ # Each warp will contribute the mcast size to the arrive count
1704
+ mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
1705
+ consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
1706
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(
1707
+ pipeline.Agent.Thread, consumer_arrive_cnt
1708
+ )
1709
+ pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
1710
+ tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
1711
+ if const_expr(not self.gather_A):
1712
+ tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
1713
+ return pipeline_cls.create(
1714
+ barrier_storage=ab_pipeline_mbar_ptr,
1715
+ num_stages=self.ab_stage,
1716
+ producer_group=ab_pipeline_producer_group,
1717
+ consumer_group=ab_pipeline_consumer_group,
1718
+ tx_count=tma_copy_bytes,
1719
+ cta_layout_vmnk=cluster_layout_vmnk,
1720
+ )
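
make_ab_pipeline sets the TMA transaction count to the bytes of one B stage, plus one A stage when A is TMA-loaded rather than gathered with cp.async. A quick byte count with assumed tile sizes and 16-bit inputs:

def ab_tx_bytes(tile_m, tile_n, tile_k, a_bits, b_bits, gather_A=False):
    b_bytes = tile_n * tile_k * b_bits // 8                       # one B smem stage
    a_bytes = 0 if gather_A else tile_m * tile_k * a_bits // 8    # A only when TMA-loaded
    return a_bytes + b_bytes

print(ab_tx_bytes(128, 192, 64, a_bits=16, b_bits=16))                  # 40960
print(ab_tx_bytes(128, 192, 64, a_bits=16, b_bits=16, gather_A=True))   # 24576
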
1721
+
1722
+ def make_epi_pipeline(
1723
+ self, c_smem_layout: cute.Layout | cute.ComposedLayout, epi_pipeline_mbar_ptr: cute.Pointer
1724
+ ):
1725
+ # Threads/warps participating in this pipeline
1726
+ epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1727
+ # Each warp will contribute 1 to the arrive count
1728
+ consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
1729
+ epi_pipeline_consumer_group = pipeline.CooperativeGroup(
1730
+ pipeline.Agent.Thread, consumer_arrive_cnt
1731
+ )
1732
+ tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
1733
+ return pipeline.PipelineTmaAsync.create(
1734
+ barrier_storage=epi_pipeline_mbar_ptr,
1735
+ num_stages=self.epi_c_stage,
1736
+ producer_group=epi_pipeline_producer_group,
1737
+ consumer_group=epi_pipeline_consumer_group,
1738
+ tx_count=tma_copy_c_bytes,
1739
+ )
1740
+
1741
+ def make_sched_pipeline(
1742
+ self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
1743
+ ):
1744
+ # Threads/warps participating in this pipeline
1745
+ sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1746
+ cluster_size = cute.size(cluster_layout_mnk)
1747
+ # Each warp that is not the scheduler warp will contribute 1 to the arrive count
1748
+ # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
1749
+ # at each round. If pingpong and not varlen_k, then only 4 mma warps will participate.
1750
+ consumer_arrive_cnt = (
1751
+ (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
1752
+ + self.num_ab_load_warps
1753
+ ) * cluster_size - 1
1754
+ sched_pipeline_consumer_group = pipeline.CooperativeGroup(
1755
+ pipeline.Agent.Thread, consumer_arrive_cnt
1756
+ )
1757
+ return pipeline.PipelineAsync.create(
1758
+ barrier_storage=sched_pipeline_mbar_ptr,
1759
+ num_stages=self.sched_stage,
1760
+ producer_group=sched_pipeline_producer_group,
1761
+ consumer_group=sched_pipeline_consumer_group,
1762
+ # If there's a cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
1763
+ consumer_mask=None if const_expr(cluster_size == 1) else 0,
1764
+ )
1765
+
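
The consumer arrive count above counts every participating warp except the scheduler warp: 4 warps per consuming math warp group (only one group per round under pingpong without varlen_k) plus the A/B load warps, scaled by the cluster size. A small sketch of the same expression with assumed warp counts:

def sched_consumer_arrive_cnt(mma_warp_groups, num_ab_load_warps, cluster_size,
                              pingpong, varlen_k):
    wg_term = mma_warp_groups if not (pingpong and not varlen_k) else 1
    return (wg_term * 4 + num_ab_load_warps) * cluster_size - 1

# pingpong, no varlen_k: one math WG (4 warps) + 1 load warp - scheduler warp -> 4
print(sched_consumer_arrive_cnt(2, 1, 1, pingpong=True, varlen_k=False))
# pingpong with varlen_k: both math WGs (8 warps) + 1 load warp - scheduler warp -> 8
print(sched_consumer_arrive_cnt(2, 1, 1, pingpong=True, varlen_k=True))
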
1766
+ @classmethod
844
1767
  def _compute_stages(
1768
+ cls,
845
1769
  tile_shape_mnk: Tuple[int, int, int],
1770
+ epi_tile: Tuple[int, int],
846
1771
  a_dtype: Type[cutlass.Numeric],
847
1772
  b_dtype: Type[cutlass.Numeric],
1773
+ d_dtype: Optional[Type[cutlass.Numeric]],
1774
+ c_dtype: Optional[Type[cutlass.Numeric]],
1775
+ epilogue_args: Optional[EpilogueArguments],
848
1776
  smem_capacity: int,
849
1777
  occupancy: int,
1778
+ overlap_sD_sA: bool,
850
1779
  ) -> Tuple[int, int]:
851
1780
  """Computes the number of stages for A/B/C operands based on heuristics.
852
1781
 
@@ -866,10 +1795,20 @@ class HopperWgmmaGemmKernel:
866
1795
  :rtype: Tuple[int, int]
867
1796
  """
868
1797
 
869
- # epi_stage = 4 if tile_shape_mnk[1] % 32 == 0 else 8
870
- epi_stage = 4
871
- # epi_smem will reuse smem ab.
872
- epi_bytes = 0
1798
+ epi_stage = 4 if epi_tile[1] <= 16 else 2
1799
+ if overlap_sD_sA:
1800
+ epi_bytes = 0
1801
+ else:
1802
+ d_bytes_per_stage = (
1803
+ cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
1804
+ )
1805
+ epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
1806
+ epilogue_args, tile_shape_mnk, epi_tile
1807
+ )
1808
+ epi_bytes = epi_bytes_per_stage * epi_stage
1809
+ epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
1810
+ if c_dtype is not None:
1811
+ epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
873
1812
 
874
1813
  a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
875
1814
  b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
@@ -878,16 +1817,21 @@ class HopperWgmmaGemmKernel:
878
1817
  )
879
1818
  mbar_helpers_bytes = 1024
880
1819
 
881
- ab_stage = (
882
- (smem_capacity - occupancy * 1024) // occupancy - mbar_helpers_bytes - epi_bytes
883
- ) // ab_bytes_per_stage
884
- return ab_stage, epi_stage
1820
+ remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
1821
+ ab_stage = remaining_bytes // ab_bytes_per_stage
1822
+
1823
+ # Refine epilogue stages:
1824
+ # Calculate remaining smem after allocating for A/B stages and reserved bytes
1825
+ # Add remaining unused smem to epilogue
1826
+ if not overlap_sD_sA and epi_bytes_per_stage > 0:
1827
+ epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
1828
+ return ab_stage, epi_stage, epi_c_stage
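
A worked example of the stage heuristic above, with hypothetical Hopper-like numbers (228 KB of shared memory, occupancy 1, BF16 A/B/D, a 128x192x64 tile and a 64x32 epi tile); the real kernel plugs in its own dtypes and epilogue byte counts:

smem_capacity = 228 * 1024        # bytes per SM (assumed Hopper-class figure)
occupancy = 1
mbar_helpers_bytes = 1024
tile_m, tile_n, tile_k = 128, 192, 64
epi_tile = (64, 32)
ab_bytes_per_stage = (tile_m * tile_k + tile_n * tile_k) * 2   # BF16 A + B per stage
epi_stage = 2                                                  # epi_tile[1] > 16
epi_bytes_per_stage = epi_tile[0] * epi_tile[1] * 2            # BF16 D per stage
epi_bytes = epi_stage * epi_bytes_per_stage

remaining = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
ab_stage = remaining // ab_bytes_per_stage
# Any smem left after the A/B stages goes back to the epilogue as extra stages.
epi_stage += (remaining - ab_stage * ab_bytes_per_stage) // epi_bytes_per_stage
print(ab_stage, epi_stage)  # 5 6
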
885
1829
 
886
1830
  @staticmethod
887
1831
  def _sm90_compute_tile_shape_or_override(
888
1832
  tile_shape_mnk: Tuple[int, int, int],
889
- element_type: Type[cutlass.Numeric],
890
- is_cooperative: bool = False,
1833
+ atom_layout_mnk: Tuple[int, int, int],
1834
+ element_type: Optional[Type[cutlass.Numeric]] = None,
891
1835
  epi_tile_override: Tuple[int, int] | None = None,
892
1836
  ) -> Tuple[int, int]:
893
1837
  """Compute the epilogue tile shape or use override if provided.
@@ -906,33 +1850,42 @@ class HopperWgmmaGemmKernel:
906
1850
  """
907
1851
  if epi_tile_override is not None:
908
1852
  return epi_tile_override
909
- if is_cooperative:
910
- if cute.size(tile_shape_mnk, mode=[0]) == 192:
911
- tile_m = 192
912
- tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]) // 2)
913
- else:
914
- tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
915
- tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
916
- return (tile_m, tile_n)
1853
+ if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
1854
+ tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
1855
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
1856
+ elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
1857
+ tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0]))
1858
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
917
1859
  else:
918
- n_perf = 64 if element_type.width == 8 else 32
1860
+ # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
1861
+ # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
1862
+ # M dimension first, then move to the N dimension. But the accumulator in registers
1863
+ # iterates along the N dimension first, then moves to the M dimension.
1864
+ # We could change the epilogue to accommodate this,
1865
+ # but it's easier to just set epi_tile_m = 64.
1866
+ n_perf = 64 if element_type is not None and element_type.width == 8 else 32
919
1867
  tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
920
1868
  tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
921
- return (tile_m, tile_n)
1869
+ return (tile_m, tile_n)
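
The branches above reduce to a gcd-based pick of the epi tile. A compact mirror of that selection (element width in bits; values illustrative):

import math

def epi_tile_for(tile_shape_mnk, atom_layout_mnk, elem_width=16):
    m, n, _ = tile_shape_mnk
    if m % 128 == 0 and atom_layout_mnk[0] > 1:
        return math.gcd(128, m), math.gcd(32, n)
    if m % 192 == 0 and atom_layout_mnk[0] > 1:
        return math.gcd(192, m), math.gcd(32, n)
    n_perf = 64 if elem_width == 8 else 32        # 8-bit types prefer a wider N subtile
    return math.gcd(64, m), math.gcd(n_perf, n)

print(epi_tile_for((128, 192, 64), (2, 1, 1)))   # (128, 32)
print(epi_tile_for((128, 192, 64), (1, 2, 1)))   # (64, 32)
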
922
1870
 
923
1871
  @staticmethod
924
1872
  def _make_smem_layouts(
925
1873
  tile_shape_mnk: Tuple[int, int, int],
926
1874
  epi_tile: Tuple[int, int],
927
1875
  a_dtype: Type[cutlass.Numeric],
928
- a_layout: utils.LayoutEnum,
1876
+ a_layout: LayoutEnum,
929
1877
  b_dtype: Type[cutlass.Numeric],
930
- b_layout: utils.LayoutEnum,
1878
+ b_layout: LayoutEnum,
931
1879
  ab_stage: int,
932
- d_dtype: Type[cutlass.Numeric],
933
- d_layout: utils.LayoutEnum,
1880
+ d_dtype: Optional[Type[cutlass.Numeric]],
1881
+ d_layout: LayoutEnum,
934
1882
  epi_stage: int,
935
- ) -> Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]:
1883
+ c_dtype: Optional[Type[cutlass.Numeric]],
1884
+ c_layout: Optional[LayoutEnum],
1885
+ epi_c_stage: int,
1886
+ ) -> Tuple[
1887
+ cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
1888
+ ]:
936
1889
  """Create shared memory layouts for A, B, and C tensors.
937
1890
 
938
1891
  :param tile_shape_mnk: CTA tile shape (M,N,K)
@@ -942,17 +1895,17 @@ class HopperWgmmaGemmKernel:
942
1895
  :param a_dtype: Data type for matrix A
943
1896
  :type a_dtype: type[cutlass.Numeric]
944
1897
  :param a_layout: Layout enum for matrix A
945
- :type a_layout: utils.LayoutEnum
1898
+ :type a_layout: LayoutEnum
946
1899
  :param b_dtype: Data type for matrix B
947
1900
  :type b_dtype: type[cutlass.Numeric]
948
1901
  :param b_layout: Layout enum for matrix B
949
- :type b_layout: utils.LayoutEnum
1902
+ :type b_layout: LayoutEnum
950
1903
  :param ab_stage: Number of stages for A/B tensors
951
1904
  :type ab_stage: int
952
- :param d_dtype: Data type for output matrix C
1905
+ :param d_dtype: Data type for output matrix D
953
1906
  :type d_dtype: type[cutlass.Numeric]
954
1907
  :param d_layout: Layout enum for the output matrix C
955
- :type d_layout: utils.LayoutEnum
1908
+ :type d_layout: LayoutEnum
956
1909
  :param epi_stage: Number of epilogue stages
957
1910
  :type epi_stage: int
958
1911
 
@@ -965,11 +1918,7 @@ class HopperWgmmaGemmKernel:
965
1918
  b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
966
1919
  a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
967
1920
  a_smem_layout_atom = warpgroup.make_smem_layout_atom(
968
- sm90_utils.get_smem_layout_atom(
969
- a_layout,
970
- a_dtype,
971
- a_major_mode_size,
972
- ),
1921
+ sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
973
1922
  a_dtype,
974
1923
  )
975
1924
  a_smem_layout_staged = cute.tile_to_shape(
@@ -982,11 +1931,7 @@ class HopperWgmmaGemmKernel:
982
1931
 
983
1932
  b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
984
1933
  b_smem_layout_atom = warpgroup.make_smem_layout_atom(
985
- sm90_utils.get_smem_layout_atom(
986
- b_layout,
987
- b_dtype,
988
- b_major_mode_size,
989
- ),
1934
+ sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
990
1935
  b_dtype,
991
1936
  )
992
1937
  b_smem_layout_staged = cute.tile_to_shape(
@@ -995,56 +1940,52 @@ class HopperWgmmaGemmKernel:
995
1940
  order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
996
1941
  )
997
1942
 
998
- d_smem_shape = epi_tile
999
- d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
1000
- d_smem_layout_atom = warpgroup.make_smem_layout_atom(
1001
- sm90_utils.get_smem_layout_atom(
1002
- d_layout,
1943
+ if d_dtype is not None:
1944
+ d_smem_shape = epi_tile
1945
+ d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
1946
+ d_smem_layout_atom = warpgroup.make_smem_layout_atom(
1947
+ sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
1003
1948
  d_dtype,
1004
- d_major_mode_size,
1005
- ),
1006
- d_dtype,
1007
- )
1008
- epi_smem_layout_staged = cute.tile_to_shape(
1009
- d_smem_layout_atom,
1010
- cute.append(d_smem_shape, epi_stage),
1011
- order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
1012
- )
1013
-
1014
- return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged
1015
-
1016
- @staticmethod
1017
- def _compute_grid(
1018
- d: cute.Tensor,
1019
- tile_shape_mnk: Tuple[int, int, int],
1020
- cluster_shape_mnk: Tuple[int, int, int],
1021
- ) -> Tuple[int, int, int]:
1022
- """Compute grid shape for the output tensor C.
1023
-
1024
- :param d: The output tensor C
1025
- :type d: cute.Tensor
1026
- :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
1027
- :type tile_shape_mnk: Tuple[int, int, int]
1028
- :param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions.
1029
- :type cluster_shape_mnk: Tuple[int, int, int]
1030
-
1031
- :return: Grid shape for kernel launch.
1032
- :rtype: Tuple[int, int, int]
1033
- """
1949
+ )
1950
+ epi_smem_layout_staged = cute.tile_to_shape(
1951
+ d_smem_layout_atom,
1952
+ cute.append(d_smem_shape, epi_stage),
1953
+ order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
1954
+ )
1955
+ else:
1956
+ epi_smem_layout_staged = None
1957
+
1958
+ if c_dtype is not None:
1959
+ assert c_layout is not None
1960
+ c_smem_shape = epi_tile
1961
+ c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0]
1962
+ c_smem_layout_atom = warpgroup.make_smem_layout_atom(
1963
+ sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
1964
+ c_dtype,
1965
+ )
1966
+ epi_c_smem_layout_staged = cute.tile_to_shape(
1967
+ c_smem_layout_atom,
1968
+ cute.append(c_smem_shape, epi_c_stage),
1969
+ order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
1970
+ )
1971
+ else:
1972
+ epi_c_smem_layout_staged = None
1034
1973
 
1035
- c_shape = (tile_shape_mnk[0], tile_shape_mnk[1])
1036
- gc = cute.zipped_divide(d, tiler=c_shape)
1037
- clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnk)
1038
- grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnk))
1039
- return grid
1974
+ return (
1975
+ a_smem_layout_staged,
1976
+ b_smem_layout_staged,
1977
+ epi_smem_layout_staged,
1978
+ epi_c_smem_layout_staged,
1979
+ )
1040
1980
 
1041
1981
  @staticmethod
1042
- def _make_tma_store_atoms_and_tensors(
1982
+ def _make_tma_epi_atoms_and_tensors(
1043
1983
  tensor_d: cute.Tensor,
1044
1984
  epi_smem_layout_staged: cute.ComposedLayout,
1045
1985
  epi_tile: Tuple[int, int],
1986
+ store_or_load: str,
1046
1987
  ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1047
- """Create TMA atoms and tensors for C tensor storage.
1988
+ """Create TMA atoms and tensors for storing D or loading C.
1048
1989
 
1049
1990
  :param tensor_d: Output tensor D
1050
1991
  :type tensor_d: cute.Tensor
@@ -1056,15 +1997,17 @@ class HopperWgmmaGemmKernel:
1056
1997
  :return: TMA atom and tensor for C
1057
1998
  :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1058
1999
  """
2000
+ assert store_or_load in ["load", "store"]
1059
2001
  epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
1060
- c_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
2002
+ d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
2003
+ op = (
2004
+ cpasync.CopyBulkTensorTileG2SOp()
2005
+ if store_or_load == "load"
2006
+ else cpasync.CopyBulkTensorTileS2GOp()
2007
+ )
1061
2008
  tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
1062
- cpasync.CopyBulkTensorTileS2GOp(),
1063
- tensor_d,
1064
- epi_smem_layout,
1065
- c_cta_v_layout,
2009
+ op, tensor_d, epi_smem_layout, d_cta_v_layout
1066
2010
  )
1067
-
1068
2011
  return tma_atom_d, tma_tensor_d
1069
2012
 
1070
2013
  @staticmethod
@@ -1104,12 +2047,37 @@ class HopperWgmmaGemmKernel:
1104
2047
  )
1105
2048
  return tma_atom, tma_tensor
1106
2049
 
2050
+ def _make_gmem_tiled_copy_A(self, dtype, major_mode, num_threads, copy_bits=128):
2051
+ atom_async_copy = cute.make_copy_atom(
2052
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
2053
+ dtype,
2054
+ num_bits_per_copy=copy_bits,
2055
+ )
2056
+ copy_elems = copy_bits // dtype.width
2057
+ shape_dim_1 = cute.size(self.tile_shape_mnk[2]) // copy_elems
2058
+ # thread layout for copy
2059
+ thread_layout = cute.make_layout(
2060
+ (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
2061
+ )
2062
+ if major_mode != LayoutEnum.ROW_MAJOR:
2063
+ shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
2064
+ thread_layout = cute.make_layout(
2065
+ (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
2066
+ )
2067
+ # Value layout for copy
2068
+ value_layout = (
2069
+ cute.make_layout((1, copy_elems))
2070
+ if major_mode == LayoutEnum.ROW_MAJOR
2071
+ else cute.make_layout((copy_elems, 1))
2072
+ )
2073
+ return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
2074
+
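
_make_gmem_tiled_copy_A sizes one 128-bit cp.async per thread and arranges the threads as (rows) x (K / elements-per-copy) for K-major A, or the transposed arrangement for M-major. The shape arithmetic, with assumed tile and thread counts:

def gmem_copy_layout(tile_m, tile_k, dtype_bits, num_threads, copy_bits=128,
                     row_major=True):
    copy_elems = copy_bits // dtype_bits            # elements moved per cp.async
    if row_major:
        threads_k = tile_k // copy_elems            # threads along K cover one tile row
        thr_shape = (num_threads // threads_k, threads_k)
        val_shape = (1, copy_elems)
    else:
        threads_m = tile_m // copy_elems
        thr_shape = (threads_m, num_threads // threads_m)
        val_shape = (copy_elems, 1)
    return thr_shape, val_shape

print(gmem_copy_layout(tile_m=128, tile_k=64, dtype_bits=16, num_threads=128))
# ((16, 8), (1, 8))
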
1107
2075
  @staticmethod
1108
2076
  def is_valid_dtypes(
1109
2077
  a_dtype: Type[cutlass.Numeric],
1110
2078
  b_dtype: Type[cutlass.Numeric],
1111
2079
  acc_dtype: Type[cutlass.Numeric],
1112
- d_dtype: Type[cutlass.Numeric],
2080
+ d_dtype: Optional[Type[cutlass.Numeric]],
1113
2081
  a_major: str,
1114
2082
  b_major: str,
1115
2083
  ) -> bool:
@@ -1133,7 +2101,6 @@ class HopperWgmmaGemmKernel:
1133
2101
  :rtype: bool
1134
2102
  """
1135
2103
  is_valid = True
1136
- # tested a_dtype
1137
2104
  if a_dtype not in {
1138
2105
  cutlass.Float16,
1139
2106
  cutlass.BFloat16,
@@ -1149,11 +2116,11 @@ class HopperWgmmaGemmKernel:
1149
2116
  cutlass.Float8E5M2,
1150
2117
  }:
1151
2118
  is_valid = False
1152
- # tested acc_dtype
1153
2119
  if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
1154
2120
  is_valid = False
1155
2121
  # tested d_dtype
1156
2122
  if d_dtype not in {
2123
+ None,
1157
2124
  cutlass.Float32,
1158
2125
  cutlass.Float16,
1159
2126
  cutlass.BFloat16,
@@ -1171,260 +2138,108 @@ class HopperWgmmaGemmKernel:
1171
2138
  # for Float8 types, this implementation only supports k-major layout
1172
2139
  if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
1173
2140
  is_valid = False
1174
-
1175
2141
  return is_valid
1176
2142
 
1177
2143
 
1178
- def run(
1179
- mnkl: Tuple[int, int, int, int],
1180
- a_dtype: Type[cutlass.Numeric],
1181
- b_dtype: Type[cutlass.Numeric],
1182
- d_dtype: Type[cutlass.Numeric],
1183
- acc_dtype: Type[cutlass.Numeric],
1184
- a_major: str,
1185
- b_major: str,
1186
- d_major: str,
1187
- tile_shape_mnk: Tuple[int, int, int],
1188
- cluster_shape_mn: Tuple[int, int],
1189
- tolerance: float,
1190
- warmup_iterations: int,
1191
- iterations: int,
1192
- skip_ref_check: bool,
1193
- use_cold_l2: bool = False,
1194
- **kwargs,
1195
- ):
1196
- """
1197
- Prepare A/B/C tensors, launch GPU kernel, and reference checking.
1198
-
1199
- :param mnkl: Problem size (M, N, K, L)
1200
- :type mnkl: Tuple[int, int, int, int]
1201
- :param a_dtype: Data type for input tensor A
1202
- :type a_dtype: Type[cutlass.Numeric]
1203
- :param b_dtype: Data type for input tensor B
1204
- :type b_dtype: Type[cutlass.Numeric]
1205
- :param d_dtype: Data type for output tensor C
1206
- :type d_dtype: Type[cutlass.Numeric]
1207
- :param acc_dtype: Data type for accumulation during matrix multiplication
1208
- :type acc_dtype: Type[cutlass.Numeric]
1209
- :param a_major/b_major/d_major: Memory layout of tensor A/B/C
1210
- :type a_major/b_major/d_major: str
1211
- :param tile_shape_mnk: CTA tile shape (M, N, K)
1212
- :type tile_shape_mnk: Tuple[int, int, int]
1213
- :param cluster_shape_mn: Cluster shape (M, N)
1214
- :type cluster_shape_mn: Tuple[int, int]
1215
- :param tolerance: Tolerance value for reference validation comparison
1216
- :type tolerance: float
1217
- :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
1218
- :type warmup_iterations: int, optional
1219
- :param iterations: Number of benchmark iterations to run, defaults to 1
1220
- :type iterations: int, optional
1221
- :param skip_ref_check: Whether to skip reference result validation, defaults to False
1222
- :type skip_ref_check: bool, optional
1223
- :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
1224
- :type use_cold_l2: bool, optional
1225
- :return: Execution time of the GEMM kernel in microseconds
1226
- :rtype: float
1227
- """
1228
-
1229
- print("Running Hopper Dense GEMM with:")
1230
- print(f"mnkl: {mnkl}")
1231
- print(f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
1232
- print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
1233
- print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
1234
- print(f"Tolerance: {tolerance}")
1235
- print(f"Warmup iterations: {warmup_iterations}")
1236
- print(f"Iterations: {iterations}")
1237
- print(f"Skip reference checking: {skip_ref_check}")
1238
- print(f"Use cold L2: {use_cold_l2}")
1239
-
1240
- # Unpack parameters
1241
- m, n, k, l = mnkl
1242
- cluster_shape_mnk = (*cluster_shape_mn, 1)
1243
-
1244
- # Skip unsupported types
1245
- if not HopperWgmmaGemmKernel.is_valid_dtypes(
1246
- a_dtype, b_dtype, acc_dtype, d_dtype, a_major, b_major
2144
+ def gemm_sm90(
2145
+ A: Tensor, # (l, m, k)
2146
+ B: Tensor, # (l, n, k)
2147
+ D: Tensor, # (l, m, n)
2148
+ C: Optional[Tensor], # (l, m, n)
2149
+ tile_count_semaphore: Optional[Tensor], # (1,)
2150
+ tile_M: int,
2151
+ tile_N: int,
2152
+ cluster_M: int,
2153
+ cluster_N: int,
2154
+ pingpong: bool = False,
2155
+ persistent: bool = True,
2156
+ alpha: float | Tensor = 1.0,
2157
+ beta: float | Tensor = 1.0,
2158
+ ) -> None:
2159
+ L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(A, B, D, C)
2160
+ GemmWrapperBase.permute_tensors(tensor_infos)
2161
+ GemmWrapperBase.extract_dtypes(tensor_infos)
2162
+ major_configs = {
2163
+ "A": ("m", "k", "l"),
2164
+ "B": ("n", "k", "l"),
2165
+ "D": ("m", "n", "l"),
2166
+ "C": ("m", "n", "l"),
2167
+ }
2168
+ GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
2169
+
2170
+ acc_dtype = cutlass.Float32
2171
+ tile_shape_mn = (tile_M, tile_N)
2172
+ cluster_shape_mnk = (cluster_M, cluster_N, 1)
2173
+ if not GemmSm90.is_valid_dtypes(
2174
+ tensor_infos["A"].dtype,
2175
+ tensor_infos["B"].dtype,
2176
+ acc_dtype,
2177
+ tensor_infos["D"].dtype,
2178
+ tensor_infos["A"].major,
2179
+ tensor_infos["B"].major,
1247
2180
  ):
1248
- raise TypeError(
1249
- f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {d_dtype}, {a_major=}, {b_major=}"
1250
- )
2181
+ raise TypeError("Skipping due to unsupported combination of types and majors")
1251
2182
 
1252
- # Prepare pytorch tensors: A, B (random from 0 to 2) and C (all zero)
1253
- if not torch.cuda.is_available():
1254
- raise RuntimeError("GPU is required to run this example!")
1255
-
1256
- torch.manual_seed(1111)
1257
-
1258
- # Create and permute tensor A/B/C
1259
- def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
1260
- # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
1261
- # else : (l, mode0, mode1) -> (mode0, mode1, l)
1262
- shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
1263
- permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
1264
- is_unsigned = dtype in {cutlass.Uint8}
1265
- # Temporarily use uint8 as torch does not support fp8 type
1266
- torch_dtype = (
1267
- cutlass_torch.dtype(dtype)
1268
- if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
1269
- else torch.uint8
1270
- )
1271
-
1272
- # Create dtype torch tensor (cpu)
1273
- torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
1274
- shape,
1275
- torch_dtype,
1276
- permute_order=permute_order,
1277
- # init_type=cutlass.torch.TensorInitType.RANDOM,
1278
- # init_config=cutlass.torch.RandomInitConfig(
1279
- # min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
1280
- # ),
1281
- init_type=cutlass.torch.TensorInitType.GAUSSIAN,
1282
- init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
1283
- )
1284
- # Create dtype torch tensor (gpu)
1285
- torch_tensor = torch_tensor_cpu.cuda()
1286
-
1287
- # Create f32 torch tensor (cpu)
1288
- f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
1289
-
1290
- # Create dtype cute tensor (gpu)
1291
- cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
1292
- cute_tensor.element_type = dtype
1293
- if is_dynamic_layout:
1294
- cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
1295
- cute_tensor = cutlass.torch.convert_cute_tensor(
1296
- f32_torch_tensor,
1297
- cute_tensor,
1298
- dtype,
1299
- is_dynamic_layout=is_dynamic_layout,
1300
- )
2183
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
2184
+ GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
1301
2185
 
1302
- return f32_torch_tensor, cute_tensor, torch_tensor
1303
-
1304
- a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
1305
- b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
1306
- c, mC, c_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
1307
-
1308
- gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk)
1309
-
1310
- torch_stream = torch.cuda.Stream()
1311
- stream = cuda.CUstream(torch_stream.cuda_stream)
1312
- # compile gemm kernel
1313
- compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
1314
-
1315
- if not skip_ref_check:
1316
- # execution
1317
- compiled_gemm(mA, mB, mC, stream)
1318
-
1319
- torch.cuda.synchronize()
1320
-
1321
- # Ref check
1322
- ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
1323
-
1324
- if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
1325
- # m major: (l, n, m) -> (m, n, l)
1326
- # n major: (l, m, n) -> (m, n, l)
1327
- permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
1328
- shape = (l, m, n) if d_major == "n" else (l, n, m)
1329
- f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
1330
- shape,
1331
- torch.uint8,
1332
- permute_order=permute_order,
1333
- init_type=cutlass_torch.TensorInitType.SKIP,
1334
- ).cuda()
1335
- # Create dtype cute tensor (gpu)
1336
- ref_c_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
1337
- leading_dim=(1 if d_major == "n" else 0)
1338
- )
1339
- ref_c_tensor.element_type = d_dtype
1340
- ref_c_tensor = cutlass_torch.convert_cute_tensor(
1341
- ref,
1342
- ref_c_tensor,
1343
- d_dtype,
1344
- is_dynamic_layout=True,
1345
- )
1346
- ref_c = f8_torch_tensor.cpu()
2186
+ def scalar_arg(scalar: float | Tensor):
2187
+ if isinstance(scalar, float):
2188
+ return Float32(scalar) if scalar != 1.0 else None
1347
2189
  else:
1348
- ref_c = ref.to(cutlass_torch.dtype(d_dtype))
1349
-
1350
- torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
1351
-
1352
- def generate_tensors():
1353
- _, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
1354
- _, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
1355
- _, mC_workspace, _ = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
1356
- return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)
1357
-
1358
- workspace_count = 1
1359
- if use_cold_l2:
1360
- one_workspace_bytes = (
1361
- a_torch.numel() * a_torch.element_size()
1362
- + b_torch.numel() * b_torch.element_size()
1363
- + c_torch.numel() * c_torch.element_size()
2190
+ assert isinstance(scalar, Tensor)
2191
+ return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
2192
+
2193
+ epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta))
2194
+ scheduler_args = GemmWrapperBase.create_scheduler_args(
2195
+ max_active_clusters, tile_count_semaphore
2196
+ )
2197
+ current_stream = cutlass_torch.current_stream()
2198
+ compile_key = GemmWrapperBase.get_compile_key(
2199
+ tensor_infos,
2200
+ None,
2201
+ tile_shape_mn,
2202
+ cluster_shape_mnk,
2203
+ pingpong,
2204
+ persistent,
2205
+ tile_count_semaphore is not None,
2206
+ 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
2207
+ 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
2208
+ key_tensor_names=("A", "B", "D", "C"),
2209
+ )
2210
+ cache = gemm_sm90.compile_cache
2211
+ if compile_key not in cache:
2212
+ gemm = GemmSm90(
2213
+ acc_dtype,
2214
+ tensor_infos["A"].dtype,
2215
+ tile_shape_mn,
2216
+ cluster_shape_mnk,
2217
+ pingpong=pingpong,
2218
+ is_persistent=persistent,
1364
2219
  )
1365
- workspace_count = testing.get_workspace_count(
1366
- one_workspace_bytes, warmup_iterations, iterations
2220
+ cache[compile_key] = cute.compile(
2221
+ gemm,
2222
+ tensor_infos["A"].cute_tensor,
2223
+ tensor_infos["B"].cute_tensor,
2224
+ tensor_infos["D"].cute_tensor,
2225
+ tensor_infos["C"].cute_tensor,
2226
+ epi_args,
2227
+ scheduler_args,
2228
+ None, # varlen_args
2229
+ None, # mAIdx
2230
+ current_stream,
1367
2231
  )
1368
-
1369
- exec_time = testing.benchmark(
1370
- compiled_gemm,
1371
- workspace_generator=generate_tensors,
1372
- workspace_count=workspace_count,
1373
- stream=stream,
1374
- warmup_iterations=warmup_iterations,
1375
- iterations=iterations,
2232
+ cache[compile_key](
2233
+ tensor_infos["A"].cute_tensor,
2234
+ tensor_infos["B"].cute_tensor,
2235
+ tensor_infos["D"].cute_tensor,
2236
+ tensor_infos["C"].cute_tensor,
2237
+ epi_args,
2238
+ scheduler_args,
2239
+ None,
2240
+ None,
2241
+ current_stream,
1376
2242
  )
1377
2243
 
1378
- from triton.testing import do_bench
1379
-
1380
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
1381
-
1382
- flops = 2 * m * n * k * l
1383
-
1384
- repeats = 30
1385
- # repeats = 1
1386
- warmup = 5
1387
-
1388
- import time
1389
-
1390
- time.sleep(0.5)
1391
- fn = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
1392
- timing_cublas = do_bench(fn, warmup=warmup, rep=repeats)
1393
- tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
1394
- print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
1395
-
1396
- time.sleep(0.5)
1397
- fn = lambda: compiled_gemm(mA, mB, mC, current_stream)
1398
- timing = do_bench(fn, warmup=warmup, rep=repeats)
1399
- tflops = flops / (timing * 1e9) # Convert to TFlops
1400
- print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
1401
-
1402
- time.sleep(0.5)
1403
- fn = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
1404
- timing_cublas = do_bench(fn, warmup=warmup, rep=repeats)
1405
- tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
1406
- print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
1407
-
1408
- return exec_time # Return execution time in microseconds
1409
-
1410
-
1411
- if __name__ == "__main__":
1412
- args = parse_arguments()
1413
- run(
1414
- args.mnkl,
1415
- args.a_dtype,
1416
- args.b_dtype,
1417
- args.d_dtype,
1418
- args.acc_dtype,
1419
- args.a_major,
1420
- args.b_major,
1421
- args.d_major,
1422
- args.tile_shape_mnk,
1423
- args.cluster_shape_mn,
1424
- args.tolerance,
1425
- args.warmup_iterations,
1426
- args.iterations,
1427
- args.skip_ref_check,
1428
- args.use_cold_l2,
1429
- )
1430
- print("PASS")
2244
+
2245
+ gemm_sm90.compile_cache = {}
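
For reference, a minimal usage sketch of the new gemm_sm90 entry point shown in the diff above. This is an illustration based only on the signature and shape comments in the added code: the problem shape, the bfloat16 dtype, and the tile/cluster sizes are assumed example values rather than values taken from the package, and it assumes an SM90 (Hopper) GPU with the function imported directly from quack.dense_gemm_sm90.

# Illustrative sketch (not part of the package diff): batched BF16 GEMM via the new wrapper.
import torch
from quack.dense_gemm_sm90 import gemm_sm90

l, m, n, k = 1, 4096, 4096, 4096
A = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)  # (l, m, k)
B = torch.randn(l, n, k, device="cuda", dtype=torch.bfloat16)  # (l, n, k)
D = torch.empty(l, m, n, device="cuda", dtype=torch.bfloat16)  # (l, m, n) output

# With the defaults alpha=1.0, beta=1.0 and C=None, this computes roughly D = A @ B^T per batch.
gemm_sm90(
    A, B, D,
    None,  # C: optional epilogue input
    None,  # tile_count_semaphore: optional scheduler semaphore
    tile_M=128, tile_N=256,    # assumed example tile sizes
    cluster_M=2, cluster_N=1,  # assumed example cluster shape
)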