quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/dense_gemm_sm90.py CHANGED
@@ -1,63 +1,43 @@
1
- # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD-3-Clause
3
-
4
- # Redistribution and use in source and binary forms, with or without
5
- # modification, are permitted provided that the following conditions are met:
6
-
7
- # 1. Redistributions of source code must retain the above copyright notice, this
8
- # list of conditions and the following disclaimer.
9
-
10
- # 2. Redistributions in binary form must reproduce the above copyright notice,
11
- # this list of conditions and the following disclaimer in the documentation
12
- # and/or other materials provided with the distribution.
13
-
14
- # 3. Neither the name of the copyright holder nor the names of its
15
- # contributors may be used to endorse or promote products derived from
16
- # this software without specific prior written permission.
17
-
18
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
-
29
- import argparse
1
+ # Based on the cute-dsl example:
2
+ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
3
+
30
4
  import enum
31
- from typing import Tuple, Type, Callable, Optional
5
+ from typing import Tuple, Type, Callable, Optional, Union
6
+ from dataclasses import dataclass
32
7
  from functools import partial
33
8
  import math
34
9
 
35
- import cuda.bindings.driver as cuda
10
+ from torch import Tensor
36
11
 
37
- import torch
12
+ import cuda.bindings.driver as cuda
38
13
 
39
14
  import cutlass
40
15
  import cutlass.cute as cute
41
16
  import cutlass.pipeline as pipeline
42
- import cutlass.torch as cutlass_torch
43
- from cutlass.cute.runtime import from_dlpack, make_ptr
44
17
  from cutlass.cute.nvgpu import cpasync, warp, warpgroup
45
18
  import cutlass.utils.hopper_helpers as sm90_utils
46
- from cutlass import Int32, const_expr
19
+ from cutlass import Int32, Float32, Boolean, const_expr
20
+ from cutlass.utils import LayoutEnum
21
+ import cutlass.torch as cutlass_torch
22
+ from cutlass.cute.runtime import make_ptr
23
+
47
24
 
25
+ from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
48
26
  from quack.tile_scheduler import (
27
+ TileSchedulerOptions,
49
28
  TileSchedulerArguments,
50
29
  TileScheduler,
51
30
  VarlenMTileSchedulerArguments,
52
31
  VarlenMTileScheduler,
53
- ParamsBase,
54
- RasterOrderOption,
55
32
  )
33
+ from quack.varlen_utils import VarlenArguments
56
34
  from quack.tensormap_manager import TensorMapManagerSm90
57
35
 
58
36
  # return PipelineStateWAdvance instead of PipelineState
59
37
  from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
60
38
  import quack.utils as utils
39
+ from quack.cute_dsl_utils import get_max_active_clusters
40
+ from quack.gemm_wrapper_utils import GemmWrapperBase
61
41
 
62
42
  """
63
43
  A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -82,31 +62,6 @@ Hopper WGMMA instructions operate as follows:
82
62
  - Read matrix B from SMEM
83
63
  - Perform MMA operation and store the result in Accumulator(register)
84
64
 
85
- To run this example:
86
-
87
- .. code-block:: bash
88
-
89
- python examples/hopper/dense_gemm.py \
90
- --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
91
- --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
92
- --d_dtype Float16 --acc_dtype Float32 \
93
- --a_major k --b_major k --d_major n
94
-
95
- The above example command compute batched gemm with M=8192, N=8192, K=8192,
96
- batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape
97
- is (1,1). The input, mma accumulator and output data type are set as fp16, fp32
98
- and fp16, respectively.
99
-
100
- To collect performance with NCU profiler:
101
-
102
- .. code-block:: bash
103
-
104
- ncu python examples/hopper/dense_gemm.py \
105
- --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
106
- --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
107
- --d_dtype Float16 --acc_dtype Float32 \
108
- --a_major k --b_major k --d_major n
109
-
110
65
  Constraints:
111
66
  * Supported input data types: fp16, fp8 (e4m3fn, e5m2)
112
67
  * For fp16 types, A and B must have the same data type
@@ -119,106 +74,9 @@ Constraints:
119
74
  * Cluster shape M/N must be positive and power of 2, total cluster size <= 4
120
75
  * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
121
76
  i.e., the number of elements is a multiple of 8 for Float16 and 16 for Float8, respectively.
122
- * OOB tiles are not allowed when TMA store is disabled
123
77
  """
124
78
 
125
79
 
126
- # /////////////////////////////////////////////////////////////////////////////
127
- # Helpers to parse args
128
- # /////////////////////////////////////////////////////////////////////////////
129
- def parse_comma_separated_ints(s: str):
130
- try:
131
- return tuple([int(x.strip()) for x in s.split(",")])
132
- except ValueError:
133
- raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")
134
-
135
-
136
- def parse_arguments() -> argparse.Namespace:
137
- parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.")
138
-
139
- parser.add_argument(
140
- "--mnkl",
141
- type=parse_comma_separated_ints,
142
- default=(4096, 4096, 4096, 1),
143
- help="mnkl dimensions (comma-separated)",
144
- )
145
- parser.add_argument(
146
- "--tile_shape_mnk",
147
- type=parse_comma_separated_ints,
148
- default=(128, 256, 64),
149
- help="Cta tile shape (comma-separated)",
150
- )
151
- parser.add_argument(
152
- "--cluster_shape_mn",
153
- type=parse_comma_separated_ints,
154
- choices=[(1, 1), (2, 1), (1, 2), (2, 2)],
155
- default=(1, 1),
156
- help="Cluster shape (comma-separated)",
157
- )
158
- parser.add_argument(
159
- "--a_dtype",
160
- type=cutlass.dtype,
161
- default=cutlass.BFloat16,
162
- )
163
- parser.add_argument(
164
- "--b_dtype",
165
- type=cutlass.dtype,
166
- default=cutlass.BFloat16,
167
- )
168
- parser.add_argument(
169
- "--d_dtype",
170
- type=cutlass.dtype,
171
- default=cutlass.BFloat16,
172
- )
173
- parser.add_argument(
174
- "--c_dtype",
175
- type=cutlass.dtype,
176
- default=None,
177
- )
178
- parser.add_argument(
179
- "--acc_dtype",
180
- type=cutlass.dtype,
181
- default=cutlass.Float32,
182
- )
183
- parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
184
- parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
185
- parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
186
- parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
187
- parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation")
188
- parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
189
- parser.add_argument(
190
- "--iterations",
191
- type=int,
192
- default=30,
193
- help="Number of iterations to run the kernel",
194
- )
195
- parser.add_argument("--persistent", action="store_true", help="Persistent kernel")
196
- parser.add_argument(
197
- "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel"
198
- )
199
- parser.add_argument("--pingpong", action="store_true", help="Pingpong kernel")
200
- parser.add_argument("--varlen_m", action="store_true", help="Variable length M dimension")
201
- parser.add_argument("--gather_A", action="store_true", help="Gather A")
202
- parser.add_argument("--fp8_fast_accum", action="store_true", help="FP8 fast accum")
203
- parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
204
-
205
- args = parser.parse_args()
206
-
207
- if len(args.mnkl) != 4:
208
- parser.error("--mnkl must contain exactly 4 values")
209
- if len(args.tile_shape_mnk) != 3:
210
- parser.error("--tile_shape_mnk must contain exactly 3 values")
211
- if len(args.cluster_shape_mn) != 2:
212
- parser.error("--cluster_shape_mn must contain exactly 2 values")
213
-
214
- return args
215
-
216
-
217
- # /////////////////////////////////////////////////////////////////////////////
218
- # Host setup and device kernel launch
219
- # /////////////////////////////////////////////////////////////////////////////
220
-
221
-
222
80
  class NamedBarrierGemm(enum.IntEnum):
223
81
  Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
224
82
  # For mainloop load warps to signal that the epilogue load warp can start.
@@ -230,15 +88,15 @@ class NamedBarrierGemm(enum.IntEnum):
230
88
  EpiWG1 = enum.auto()
231
89
 
232
90
 
233
- class HopperWgmmaGemmKernel:
91
+ class GemmSm90:
234
92
  """
235
93
  This class implements batched matrix multiplication (C = A x B) with support for various data types
236
94
  and architectural features specific to Hopper GPUs.
237
95
 
238
96
  :param acc_dtype: Data type for accumulation during computation
239
97
  :type acc_dtype: type[cutlass.Numeric]
240
- :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
241
- :type tile_shape_mnk: Tuple[int, int, int]
98
+ :param tile_shape_mn: Shape of the CTA tile (M,N)
99
+ :type tile_shape_mn: Tuple[int, int]
242
100
  :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
243
101
  :type cluster_shape_mnk: Tuple[int, int, int]
244
102
 
@@ -259,22 +117,31 @@ class HopperWgmmaGemmKernel:
259
117
  - Cluster shape M/N must be positive and power of 2, total cluster size <= 4
260
118
 
261
119
  Example:
262
- >>> gemm = HopperWgmmaGemmKernel(
120
+ >>> gemm = GemmSm90(
263
121
  ... acc_dtype=cutlass.Float32,
264
- ... tile_shape_mnk=(128, 256, 64),
122
+ ... tile_shape_mn=(128, 256),
265
123
  ... cluster_shape_mnk=(1, 1, 1)
266
124
  ... )
267
125
  >>> gemm(a_tensor, b_tensor, c_tensor, stream)
268
126
  """
269
127
 
270
128
  bytes_per_tensormap = 128
271
- num_tensormaps = 1 # For D only
129
+
130
+ @dataclass
131
+ class EpilogueArguments(ArgumentsBase):
132
+ alpha: Optional[Float32 | cute.Tensor] = None
133
+ beta: Optional[Float32 | cute.Tensor] = None
134
+
135
+ @dataclass
136
+ class EpilogueParams(ParamsBase):
137
+ alpha: Optional[Float32 | cute.Tensor] = None
138
+ beta: Optional[Float32 | cute.Tensor] = None
272
139
 
273
140
  def __init__(
274
141
  self,
275
142
  acc_dtype: Type[cutlass.Numeric],
276
143
  a_dtype: Type[cutlass.Numeric],
277
- tile_shape_mnk: Tuple[int, int, int],
144
+ tile_shape_mn: Tuple[int, int],
278
145
  cluster_shape_mnk: Tuple[int, int, int],
279
146
  pingpong: bool = False,
280
147
  is_persistent: bool = True,
@@ -289,8 +156,8 @@ class HopperWgmmaGemmKernel:
289
156
 
290
157
  :param acc_dtype: Data type for accumulation during computation
291
158
  :type acc_dtype: type[cutlass.Numeric]
292
- :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
293
- :type tile_shape_mnk: Tuple[int, int, int]
159
+ :param tile_shape_mn: Shape of the CTA tile (M,N)
160
+ :type tile_shape_mn: Tuple[int, int]
294
161
  :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
295
162
  :type cluster_shape_mnk: Tuple[int, int, int]
296
163
  """
@@ -304,11 +171,11 @@ class HopperWgmmaGemmKernel:
304
171
  self.gather_A = gather_A
305
172
  if gather_A:
306
173
  assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
307
- self.tensormap_update_mode = cutlass.utils.TensorMapUpdateMode.SMEM
308
174
 
309
175
  self.cluster_shape_mnk = cluster_shape_mnk
310
- self.tile_shape_mnk = tuple(tile_shape_mnk)
311
- tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1]
176
+ # The K tile size is filled in later, in _setup_attributes (the 1 here is a placeholder)
177
+ self.tile_shape_mnk = (*tile_shape_mn, 1)
178
+ tile_M, tile_N = self.tile_shape_mnk[0], self.tile_shape_mnk[1]
312
179
  # check the cta tile shape
313
180
  if not self.pingpong:
314
181
  if tile_M not in [64, 128, 192, 256, 320]:
@@ -332,8 +199,6 @@ class HopperWgmmaGemmKernel:
332
199
  tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
333
200
  if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
334
201
  raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
335
- if not self.tile_shape_mnk[2] % 16 == 0:
336
- raise ValueError("CTA tile shape K must be divisible by 16")
337
202
 
338
203
  if not self.pingpong:
339
204
  if tile_M == 320: # tile_M / 64 is not even so we have to split along N
@@ -344,7 +209,7 @@ class HopperWgmmaGemmKernel:
344
209
  else:
345
210
  atom_layout_m, atom_layout_n = 1, 2
346
211
  else:
347
- atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
212
+ atom_layout_m = self.tile_shape_mnk[0] // 64 if self.tile_shape_mnk[0] < 256 else 2
348
213
  atom_layout_n = 1
349
214
  assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
350
215
  else:
@@ -403,7 +268,7 @@ class HopperWgmmaGemmKernel:
403
268
  self.shared_storage = None
404
269
  self.buffer_align_bytes = 1024
405
270
 
406
- def _setup_attributes(self):
271
+ def _setup_attributes(self, epilogue_args: Optional[EpilogueArguments]):
407
272
  """Set up configurations that are dependent on GEMM inputs
408
273
 
409
274
  This method configures various attributes based on the input tensor properties
@@ -417,6 +282,38 @@ class HopperWgmmaGemmKernel:
417
282
  - Computing A/B/C shared memory layout
418
283
  """
419
284
 
285
+ self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
286
+ self.a_dtype,
287
+ self.b_dtype,
288
+ self.a_layout.sm90_mma_major_mode(),
289
+ self.b_layout.sm90_mma_major_mode(),
290
+ self.acc_dtype,
291
+ self.atom_layout_mnk,
292
+ tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
293
+ )
294
+ if const_expr(self.atom_layout_mnk[1] > 1):
295
+ # If N dimension is split among 2 WGs, we need to permute the N dimension so
296
+ # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
297
+ # containing accumulators that are next to each other in the N dimension.
298
+ # Without permutation WG0 would write to epi smem of size (64, 16) and
299
+ # WG1 would write to a separate epi smem of size (64, 16) that's far away.
300
+ atom_n = self.atom_layout_mnk[1]
301
+ permutation_n = cute.make_ordered_layout(
302
+ (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
303
+ )
304
+ self.tiled_mma = cute.make_tiled_mma(
305
+ cute.make_mma_atom(self.tiled_mma.op),
306
+ self.atom_layout_mnk,
307
+ permutation_mnk=(None, permutation_n, None),
308
+ )
309
+ mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
310
+ mma_inst_tile_k = 4
311
+ self.tile_shape_mnk = (
312
+ self.tile_shape_mnk[0],
313
+ self.tile_shape_mnk[1],
314
+ mma_inst_shape_k * mma_inst_tile_k,
315
+ )
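
Rough arithmetic behind the deferred K tile size, assuming the Hopper WGMMA instruction K extent is 16 for 16-bit inputs and 32 for fp8 inputs (a sketch of the computation above, not additional package code):

    mma_inst_shape_k = 16                          # 32 for Float8E4M3FN / Float8E5M2
    mma_inst_tile_k = 4
    tile_k = mma_inst_shape_k * mma_inst_tile_k    # 64 for fp16/bf16, 128 for fp8
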
316
+
420
317
  self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
421
318
 
422
319
  self.epi_tile = self._sm90_compute_tile_shape_or_override(
@@ -433,6 +330,7 @@ class HopperWgmmaGemmKernel:
433
330
  self.b_dtype,
434
331
  self.d_dtype,
435
332
  self.c_dtype,
333
+ epilogue_args,
436
334
  self.smem_capacity,
437
335
  self.occupancy,
438
336
  # epi_smem will reuse smem ab if not persistent.
@@ -466,13 +364,12 @@ class HopperWgmmaGemmKernel:
466
364
  self,
467
365
  mA: cute.Tensor,
468
366
  mB: cute.Tensor,
469
- mD: cute.Tensor,
367
+ mD: Optional[cute.Tensor],
470
368
  mC: Optional[cute.Tensor],
369
+ epilogue_args: Optional[ArgumentsBase],
370
+ scheduler_args: TileSchedulerOptions,
371
+ varlen_args: Optional[VarlenArguments],
471
372
  mAIdx: Optional[cute.Tensor],
472
- mCuSeqlensM: Optional[cute.Tensor],
473
- mTensormaps: Optional[cute.Tensor],
474
- tile_count_semaphore: Optional[cute.Pointer],
475
- max_active_clusters: Int32,
476
373
  stream: cuda.CUstream,
477
374
  ):
478
375
  """Execute the GEMM operation in steps:
@@ -495,12 +392,12 @@ class HopperWgmmaGemmKernel:
495
392
  # setup static attributes before smem/grid/tma computation
496
393
  self.a_dtype = mA.element_type
497
394
  self.b_dtype = mB.element_type
498
- self.d_dtype = mD.element_type
395
+ self.d_dtype = mD.element_type if mD is not None else None
499
396
  self.c_dtype = mC.element_type if mC is not None else None
500
- self.a_layout = cutlass.utils.LayoutEnum.from_tensor(mA)
501
- self.b_layout = cutlass.utils.LayoutEnum.from_tensor(mB)
502
- self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
503
- self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None
397
+ self.a_layout = LayoutEnum.from_tensor(mA)
398
+ self.b_layout = LayoutEnum.from_tensor(mB)
399
+ self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None
400
+ self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None
504
401
 
505
402
  if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
506
403
  raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
@@ -517,35 +414,12 @@ class HopperWgmmaGemmKernel:
517
414
  )
518
415
  mA, mD = [
519
416
  cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
417
+ if t is not None
418
+ else None
520
419
  for t in (mA, mD)
521
420
  ]
522
421
 
523
- self._setup_attributes()
524
-
525
- tiled_mma = sm90_utils.make_trivial_tiled_mma(
526
- self.a_dtype,
527
- self.b_dtype,
528
- self.a_layout.sm90_mma_major_mode(),
529
- self.b_layout.sm90_mma_major_mode(),
530
- self.acc_dtype,
531
- self.atom_layout_mnk,
532
- tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
533
- )
534
- if const_expr(self.atom_layout_mnk[1] > 1):
535
- # If N dimension is split among 2 WGs, we need to permute the N dimension so
536
- # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
537
- # containing accumulators that are next to each other in the N dimension.
538
- # Without permutation WG0 would write to epi smem of size (64, 16) and
539
- # WG1 would write to a separate epi smem of size (64, 16) that's far away.
540
- atom_n = self.atom_layout_mnk[1]
541
- permutation_n = cute.make_ordered_layout(
542
- (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
543
- )
544
- tiled_mma = cute.make_tiled_mma(
545
- cute.make_mma_atom(tiled_mma.op),
546
- self.atom_layout_mnk,
547
- permutation_mnk=(None, permutation_n, None),
548
- )
422
+ self._setup_attributes(epilogue_args)
549
423
 
550
424
  if const_expr(not self.gather_A):
551
425
  tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
@@ -564,9 +438,12 @@ class HopperWgmmaGemmKernel:
564
438
  self.cluster_shape_mnk[0],
565
439
  )
566
440
 
567
- tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
568
- mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
569
- )
441
+ if const_expr(mD is not None):
442
+ tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
443
+ mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
444
+ )
445
+ else:
446
+ tma_atom_d, tma_tensor_d = None, None
570
447
 
571
448
  if const_expr(mC is not None):
572
449
  tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
@@ -575,65 +452,66 @@ class HopperWgmmaGemmKernel:
575
452
  else:
576
453
  tma_atom_c, tma_tensor_c = None, None
577
454
 
578
- if const_expr(mCuSeqlensM is None):
579
- problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (
580
- mD.shape[2],
455
+ epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
456
+
457
+ if const_expr(varlen_args is None):
458
+ varlen_args = VarlenArguments()
459
+ if const_expr(varlen_args.mCuSeqlensM is None):
460
+ num_problems = (
461
+ mD.shape[2]
462
+ if mD is not None
463
+ else (
464
+ mB.shape[2]
465
+ if varlen_args.mCuSeqlensK is None
466
+ else varlen_args.mCuSeqlensK.shape[0] - 1
467
+ )
581
468
  )
582
- TileSchedulerCls = TileScheduler
583
- tile_sched_args = TileSchedulerArguments(
584
- problem_shape_ntile_mnl=problem_shape_ntile_mnl,
585
- raster_order=RasterOrderOption.Heuristic,
586
- group_size=8,
587
- cluster_shape_mnk=self.cluster_shape_mnk,
588
- tile_count_semaphore=tile_count_semaphore,
589
- is_persistent=self.is_persistent,
469
+ problem_shape_ntile_mnl = (
470
+ cute.ceil_div(mA.shape[0], self.tile_shape_mnk[0]),
471
+ cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
472
+ num_problems,
590
473
  )
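
For a concrete feel of the tile grid computed above (illustrative shapes only): with M = N = 8192, tile_shape_mn = (128, 256) and 4 problems,

    # ceil(8192 / 128) = 64 M-tiles, ceil(8192 / 256) = 32 N-tiles
    problem_shape_ntile_mnl = (64, 32, 4)
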
474
+ TileSchedulerCls = self.get_scheduler_class()
475
+ tile_sched_args = self.get_scheduler_arguments(problem_shape_ntile_mnl, scheduler_args)
591
476
  else:
592
- assert mTensormaps is not None
477
+ assert mD is not None or not self.gather_A
593
478
  problem_shape_ntile_mnl = (
594
479
  None,
595
- cute.ceil_div(mD.shape[1], self.tile_shape_mnk[1]),
596
- mCuSeqlensM.shape[0] - 1,
480
+ cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
481
+ varlen_args.mCuSeqlensM.shape[0] - 1,
597
482
  )
598
483
  TileSchedulerCls = VarlenMTileScheduler
599
484
  tile_sched_args = VarlenMTileSchedulerArguments(
600
485
  problem_shape_ntile_mnl=problem_shape_ntile_mnl,
601
- total_m=mD.shape[0],
602
- cu_seqlens_m=mCuSeqlensM,
603
- raster_order=RasterOrderOption.Heuristic,
604
- group_size=8,
605
- tile_shape_mnk=self.tile_shape_mnk,
486
+ total_m=mD.shape[0] if mD is not None else mAIdx.shape[0],
487
+ cu_seqlens_m=varlen_args.mCuSeqlensM,
488
+ raster_order=scheduler_args.raster_order,
489
+ group_size=scheduler_args.max_swizzle_size,
490
+ tile_shape_mn=self.tile_shape_mnk[:2],
606
491
  cluster_shape_mnk=self.cluster_shape_mnk,
607
- tile_count_semaphore=tile_count_semaphore,
492
+ tile_count_semaphore=scheduler_args.tile_count_semaphore,
608
493
  is_persistent=self.is_persistent,
609
494
  )
610
495
  tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
611
- grid = TileSchedulerCls.get_grid_shape(tile_sched_params, max_active_clusters)
496
+ grid = TileSchedulerCls.get_grid_shape(
497
+ tile_sched_params, scheduler_args.max_active_clusters
498
+ )
612
499
 
613
- epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if self.is_persistent else 0
500
+ epi_smem_size = (
501
+ cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
502
+ )
614
503
  epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
615
504
 
616
- size_tensormap_in_i64 = (
617
- 0
618
- if mCuSeqlensM is None
619
- or self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.GMEM
620
- else HopperWgmmaGemmKernel.num_tensormaps
621
- * HopperWgmmaGemmKernel.bytes_per_tensormap
622
- // 8
623
- ) * (1 if not self.pingpong else 2)
624
-
625
505
  @cute.struct
626
506
  class SharedStorage:
627
- tensormap_buffer: cute.struct.Align[
628
- cute.struct.MemRange[cutlass.Int64, size_tensormap_in_i64],
629
- 64,
630
- ]
631
507
  ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
632
508
  epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
633
509
  sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
634
510
  tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
635
511
  sD: cute.struct.Align[
636
- cute.struct.MemRange[self.d_dtype, epi_smem_size],
512
+ cute.struct.MemRange[
513
+ self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
514
+ ],
637
515
  self.buffer_align_bytes,
638
516
  ]
639
517
  sC: cute.struct.Align[
@@ -642,6 +520,7 @@ class HopperWgmmaGemmKernel:
642
520
  ],
643
521
  self.buffer_align_bytes,
644
522
  ]
523
+ epi: self.epi_get_smem_struct(epilogue_params)
645
524
  sA: cute.struct.Align[
646
525
  cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
647
526
  self.buffer_align_bytes,
@@ -661,13 +540,14 @@ class HopperWgmmaGemmKernel:
661
540
  tma_tensor_b,
662
541
  tma_atom_d,
663
542
  tma_tensor_d,
664
- mD,
665
543
  tma_atom_c,
666
544
  tma_tensor_c,
545
+ epilogue_params,
667
546
  mAIdx,
668
- mCuSeqlensM,
669
- mTensormaps,
670
- tiled_mma,
547
+ varlen_args.mCuSeqlensM,
548
+ varlen_args.mCuSeqlensK,
549
+ varlen_args.mTensormaps,
550
+ self.tiled_mma,
671
551
  self.cluster_layout_mnk,
672
552
  self.a_smem_layout_staged,
673
553
  self.b_smem_layout_staged,
@@ -693,20 +573,21 @@ class HopperWgmmaGemmKernel:
693
573
  mA_mkl: cute.Tensor,
694
574
  tma_atom_b: cute.CopyAtom,
695
575
  mB_nkl: cute.Tensor,
696
- tma_atom_d: cute.CopyAtom,
697
- mD_mnl_tma: cute.Tensor,
698
- mD_mnl: cute.Tensor,
576
+ tma_atom_d: Optional[cute.CopyAtom],
577
+ mD_mnl: Optional[cute.Tensor],
699
578
  tma_atom_c: Optional[cute.CopyAtom],
700
579
  mC_mnl: Optional[cute.Tensor],
580
+ epilogue_params: ParamsBase,
701
581
  mAIdx: Optional[cute.Tensor],
702
582
  cu_seqlens_m: Optional[cute.Tensor],
583
+ cu_seqlens_k: Optional[cute.Tensor],
703
584
  tensormaps: Optional[cute.Tensor],
704
585
  tiled_mma: cute.TiledMma,
705
586
  cluster_layout_mnk: cute.Layout,
706
- a_smem_layout_staged: cute.ComposedLayout,
707
- b_smem_layout_staged: cute.ComposedLayout,
708
- epi_smem_layout_staged: cute.ComposedLayout,
709
- epi_c_smem_layout_staged: cute.ComposedLayout,
587
+ a_smem_layout: cute.ComposedLayout,
588
+ b_smem_layout: cute.ComposedLayout,
589
+ epi_smem_layout: cute.ComposedLayout,
590
+ epi_c_smem_layout: cute.ComposedLayout,
710
591
  tile_sched_params: ParamsBase,
711
592
  TileSchedulerCls: cutlass.Constexpr[Callable],
712
593
  ):
@@ -723,39 +604,35 @@ class HopperWgmmaGemmKernel:
723
604
  :type mB_nkl: cute.Tensor
724
605
  :param tma_atom_d: TMA copy atom for D tensor
725
606
  :type tma_atom_d: cute.CopyAtom
726
- :param mD_mnl_tma: Output tensor D
727
- :type mD_mnl_tma: cute.Tensor
607
+ :param mD_mnl: Output tensor D
608
+ :type mD_mnl: cute.Tensor
728
609
  :param tiled_mma: Tiled MMA object
729
610
  :type tiled_mma: cute.TiledMma
730
611
  :param cluster_layout_mnk: CTA layout
731
612
  :type cluster_layout_mnk: cute.Layout
732
- :param a_smem_layout_staged: Shared memory layout for A
733
- :type a_smem_layout_staged: cute.ComposedLayout
734
- :param b_smem_layout_staged: Shared memory layout for B
735
- :type b_smem_layout_staged: cute.ComposedLayout
736
- :param epi_smem_layout_staged: Shared memory layout for epilogue
737
- :type epi_smem_layout_staged: cute.ComposedLayout
613
+ :param a_smem_layout: Shared memory layout for A
614
+ :type a_smem_layout: cute.ComposedLayout
615
+ :param b_smem_layout: Shared memory layout for B
616
+ :type b_smem_layout: cute.ComposedLayout
617
+ :param epi_smem_layout: Shared memory layout for epilogue
618
+ :type epi_smem_layout: cute.ComposedLayout
738
619
  """
739
620
 
740
- varlen = const_expr(cu_seqlens_m is not None)
621
+ varlen_m = const_expr(cu_seqlens_m is not None)
622
+ varlen_k = const_expr(cu_seqlens_k is not None)
623
+ assert not (varlen_m and varlen_k)
624
+ has_D = const_expr(mD_mnl is not None)
625
+ has_C = const_expr(mC_mnl is not None)
626
+
741
627
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
742
628
 
743
629
  # /////////////////////////////////////////////////////////////////////////////
744
630
  # Prefetch Tma desc
745
631
  # /////////////////////////////////////////////////////////////////////////////
746
632
  if warp_idx == self.ab_load_warp_id:
747
- if const_expr(tma_atom_a is not None):
748
- cpasync.prefetch_descriptor(tma_atom_a)
749
- cpasync.prefetch_descriptor(tma_atom_b)
750
- cpasync.prefetch_descriptor(tma_atom_d)
751
- if const_expr(tma_atom_c is not None):
752
- cpasync.prefetch_descriptor(tma_atom_c)
753
-
754
- a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
755
- b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
756
- tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
757
- if const_expr(not self.gather_A):
758
- tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
633
+ for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
634
+ if const_expr(tma_atom is not None):
635
+ cpasync.prefetch_descriptor(tma_atom)
759
636
 
760
637
  # /////////////////////////////////////////////////////////////////////////////
761
638
  # Alloc and init AB full/empty + ACC full mbar (pipeline)
@@ -763,147 +640,93 @@ class HopperWgmmaGemmKernel:
763
640
  smem = cutlass.utils.SmemAllocator()
764
641
  storage = smem.allocate(self.shared_storage)
765
642
 
766
- # Threads/warps participating in this pipeline
767
- producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
768
- ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
769
- # Each warp will contribute to the arrive count with the number of mcast size
770
- mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
771
- consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
772
- ab_pipeline_consumer_group = pipeline.CooperativeGroup(
773
- pipeline.Agent.Thread, consumer_arrive_cnt
774
- )
775
-
776
- cta_layout_vmnk = cute.make_layout((1, *cluster_layout_mnk.shape))
777
- pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
778
- ab_pipeline = pipeline_cls.create(
779
- barrier_storage=storage.ab_pipeline_array_ptr.data_ptr(),
780
- num_stages=self.ab_stage,
781
- producer_group=ab_pipeline_producer_group,
782
- consumer_group=ab_pipeline_consumer_group,
783
- tx_count=tma_copy_bytes,
784
- cta_layout_vmnk=cta_layout_vmnk,
643
+ ab_pipeline = self.make_ab_pipeline(
644
+ a_smem_layout=cute.slice_(a_smem_layout, (None, None, 0)),
645
+ b_smem_layout=cute.slice_(b_smem_layout, (None, None, 0)),
646
+ tiled_mma=tiled_mma,
647
+ cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
648
+ ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
785
649
  )
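
The factored-out make_ab_pipeline helper is not shown in this hunk; judging from the block it replaces, it presumably wraps roughly the following logic (a sketch reconstructed from the deleted lines, not the actual helper):

    def make_ab_pipeline(self, a_smem_layout, b_smem_layout, tiled_mma,
                         cluster_layout_vmnk, ab_pipeline_mbar_ptr):
        # One TMA producer thread; with gather_A the cp.async load threads also produce.
        producer_cnt = 1 if not self.gather_A else 1 + self.num_ab_load_threads
        producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
        # Each consumer warp arrives once per CTA the load is multicast to.
        mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
        consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
        consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt)
        # Bytes per TMA transaction: B always, plus A unless A is gathered via cp.async.
        tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
        if not self.gather_A:
            tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
        pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
        return pipeline_cls.create(
            barrier_storage=ab_pipeline_mbar_ptr,
            num_stages=self.ab_stage,
            producer_group=producer_group,
            consumer_group=consumer_group,
            tx_count=tma_copy_bytes,
            cta_layout_vmnk=cluster_layout_vmnk,
        )
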
786
-
787
- if const_expr(mC_mnl is not None):
788
- # Threads/warps participating in this pipeline
789
- epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
790
- # Each warp will contribute 1 to the arrive count
791
- consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
792
- epi_pipeline_consumer_group = pipeline.CooperativeGroup(
793
- pipeline.Agent.Thread, consumer_arrive_cnt
650
+ epi_pipeline = None
651
+ if const_expr(has_C):
652
+ epi_pipeline = self.make_epi_pipeline(
653
+ c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
654
+ epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
794
655
  )
795
- c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
796
- tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
797
- epi_pipeline = pipeline.PipelineTmaAsync.create(
798
- barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
799
- num_stages=self.epi_c_stage,
800
- producer_group=epi_pipeline_producer_group,
801
- consumer_group=epi_pipeline_consumer_group,
802
- tx_count=tma_copy_c_bytes,
803
- )
804
- else:
805
- epi_pipeline = None
806
-
656
+ sched_pipeline = None
657
+ tile_count = None
807
658
  if const_expr(tile_sched_params.tile_count_semaphore is not None):
808
659
  # Dynamic persistent scheduler
809
- # Threads/warps participating in this pipeline
810
- sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
811
- cluster_size = cute.size(cluster_layout_mnk)
812
- # Each warp that are not the scheduler warp will contribute 1 to the arrive count
813
- consumer_arrive_cnt = (
814
- (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
815
- ) * cluster_size - 1
816
- sched_pipeline_consumer_group = pipeline.CooperativeGroup(
817
- pipeline.Agent.Thread, consumer_arrive_cnt
818
- )
819
- sched_pipeline = pipeline.PipelineAsync.create(
820
- barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
821
- num_stages=self.sched_stage,
822
- producer_group=sched_pipeline_producer_group,
823
- consumer_group=sched_pipeline_consumer_group,
824
- # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
825
- consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
660
+ sched_pipeline = self.make_sched_pipeline(
661
+ cluster_layout_mnk,
662
+ sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
663
+ varlen_k=varlen_k,
826
664
  )
827
665
  tile_count = storage.tile_count.get_tensor((self.sched_stage,))
828
- else:
829
- sched_pipeline = None
830
- tile_count = None
831
666
 
832
667
  # ///////////////////////////////////////////////////////////////////////////////
833
668
  # Generate smem tensor A/B
834
669
  # ///////////////////////////////////////////////////////////////////////////////
835
- sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
836
- sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
837
- if const_expr(not self.is_persistent):
838
- sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
839
- sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
840
- else:
841
- sD = storage.sD.get_tensor(
842
- epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
843
- )
844
- if const_expr(mC_mnl is not None):
845
- sC = storage.sC.get_tensor(
846
- epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
847
- )
848
- else:
849
- sC = None
670
+ sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
671
+ sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
672
+ sD = None
673
+ if const_expr(has_D):
674
+ if const_expr(not self.is_persistent):
675
+ sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
676
+ sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
677
+ else:
678
+ sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
679
+ sC = None
680
+ if const_expr(has_C):
681
+ sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
682
+ epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
850
683
 
851
684
  # Get tensormap buffer address
852
- if const_expr(varlen):
853
- grid_dim = cute.arch.grid_dim()
854
- bid = cute.arch.block_idx()
855
- tensormap_workspace_idx = (
856
- bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0]
857
- )
858
- # TODO: this is only for D, not for A/B
859
- if const_expr(self.pingpong):
860
- tensormap_workspace_idx = tensormap_workspace_idx * 2 + warp_idx // 4
685
+ tensormap_manager = None
686
+ tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
687
+ if const_expr(varlen_m or varlen_k):
861
688
  tensormap_manager = TensorMapManagerSm90(
862
- self.tensormap_update_mode, HopperWgmmaGemmKernel.bytes_per_tensormap
863
- )
864
- tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
865
- tensormaps[tensormap_workspace_idx, None].iterator
689
+ cutlass.utils.TensorMapUpdateMode.GMEM, GemmSm90.bytes_per_tensormap
866
690
  )
867
- if const_expr(self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.SMEM):
868
- tensormap_smem_ptr = storage.tensormap_buffer.data_ptr()
869
- tensormap_d_smem_ptr = tensormap_smem_ptr + (warp_idx // 4) * (
870
- HopperWgmmaGemmKernel.bytes_per_tensormap // 8
691
+ # equivalent to bidx + bidy * gridDim.x + bidz * gridDim.x * gridDim.y
692
+ tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
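
A quick sanity check of the linearization above with small, illustrative numbers:

    # A layout with shape (gx, gy, gz) has column-major strides (1, gx, gx*gy),
    # so for grid_dim = (4, 2, 3) and block_idx = (1, 1, 2):
    #   index = 1*1 + 1*4 + 2*(4*2) = 21
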
693
+ if const_expr(varlen_m):
694
+ tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
695
+ tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
696
+ tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
871
697
  )
872
- # Need this, otherwise "expected tma descriptor pointer to have alignment at least 64, but got 8"
873
- tensormap_d_smem_ptr = cute.make_ptr(
874
- cutlass.Int64,
875
- tensormap_d_smem_ptr.toint(),
876
- cute.AddressSpace.smem,
877
- assumed_align=64,
878
- )
879
- tensormap_d_init_ptr = tensormap_d_smem_ptr
880
698
  else:
881
- tensormap_d_smem_ptr = None
882
- tensormap_d_init_ptr = tensormap_d_ptr
883
- else:
884
- tensormap_d_smem_ptr = None
885
- tensormap_manager, tensormap_d_ptr, tensormap_d_init_ptr = None, None, None
699
+ assert varlen_k
700
+ tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
701
+ tensormaps[tensormap_workspace_idx, 0, None].iterator
702
+ )
703
+ tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
704
+ tensormaps[tensormap_workspace_idx, 1, None].iterator
705
+ )
886
706
 
887
707
  TileSchedulerCls = partial(
888
708
  TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
889
709
  )
890
710
 
891
- k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
892
- c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
893
-
894
711
  if warp_idx >= self.ab_load_warp_id:
895
712
  cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
896
- if const_expr(mC_mnl is not None):
897
- epi_load_barrier = pipeline.NamedBarrier(
898
- barrier_id=int(NamedBarrierGemm.EpilogueLoad),
899
- num_threads=self.num_ab_load_threads + self.num_epi_load_threads,
900
- )
901
- else:
902
- epi_load_barrier = None
903
713
  if (
904
714
  warp_idx >= self.ab_load_warp_id
905
715
  and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
906
716
  ):
717
+ is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
718
+ if const_expr(varlen_k):
719
+ # initialize tensormap for A & B
720
+ tensormap_manager.init_tensormap_from_atom(
721
+ tma_atom_a,
722
+ tensormap_a_ptr,
723
+ is_tma_warp,
724
+ )
725
+ tensormap_manager.init_tensormap_from_atom(
726
+ tma_atom_b,
727
+ tensormap_b_ptr,
728
+ is_tma_warp,
729
+ )
907
730
  # ///////////////////////////////////////////////////////////////////////////////
908
731
  # Get mcast mask
909
732
  # ///////////////////////////////////////////////////////////////////////////////
@@ -927,16 +750,37 @@ class HopperWgmmaGemmKernel:
927
750
  ab_producer_state = make_pipeline_state(
928
751
  pipeline.PipelineUserType.Producer, self.ab_stage
929
752
  )
930
- do_epi_load_barrier_arrive = cutlass.Boolean(True)
753
+ if const_expr(varlen_k):
754
+ # wait tensormap initialization complete before update
755
+ tensormap_manager.fence_tensormap_initialization()
756
+ # batch index of last tile
757
+ last_batch_idx = cutlass.Int32(-1)
931
758
  while work_tile.is_valid_tile:
932
759
  tile_coord_mnkl = work_tile.tile_idx
933
760
  batch_idx = tile_coord_mnkl[3]
761
+ if const_expr(varlen_k):
762
+ is_group_changed = batch_idx != last_batch_idx
763
+ last_batch_idx = batch_idx
764
+ if is_group_changed:
765
+ # construct tensor A/B based on real address, shape and stride information
766
+ tensormap_manager.update_tensormap_shape(
767
+ (tensormap_a_ptr, tensormap_b_ptr),
768
+ is_manager_warp=is_tma_warp,
769
+ shapes=(cu_seqlens_k[batch_idx + 1], cu_seqlens_k[batch_idx + 1]),
770
+ orders=(
771
+ 0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1,
772
+ 0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1,
773
+ ),
774
+ tensormap_smem_ptr=None,
775
+ )
934
776
  # ///////////////////////////////////////////////////////////////////////////
935
777
  # Local_tile partition global tensors
936
778
  # ///////////////////////////////////////////////////////////////////////////
937
779
  if const_expr(not self.gather_A):
938
- if const_expr(cu_seqlens_m is not None):
780
+ if const_expr(varlen_m):
939
781
  mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
782
+ elif const_expr(varlen_k):
783
+ mA_mk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mA_mkl)
940
784
  else:
941
785
  mA_mk = mA_mkl[None, None, batch_idx]
942
786
  # (bM, bK, RestK)
@@ -947,28 +791,46 @@ class HopperWgmmaGemmKernel:
947
791
  )
948
792
  else:
949
793
  mA_mk = mA_mkl
950
- if const_expr(cu_seqlens_m is not None):
794
+ if const_expr(varlen_m):
951
795
  mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
796
+ elif const_expr(varlen_k):
797
+ mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
952
798
  else:
953
799
  mAIdx_mk = mAIdx[None, batch_idx]
954
800
  gAIdx = cute.local_tile(
955
801
  mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
956
802
  )
803
+ if const_expr(varlen_k):
804
+ mB_nk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mB_nkl)
805
+ else:
806
+ mB_nk = mB_nkl[None, None, batch_idx]
957
807
  # (bN, bK, RestK)
958
808
  gB_k = cute.local_tile(
959
- mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
809
+ mB_nk, cute.select(self.tile_shape_mnk, [1, 2]), (tile_coord_mnkl[1], None)
960
810
  )
961
811
  # //////////////////////////////////////////////////////////////////////////
962
812
  # Partition shared tensor for TMA load A/B
963
813
  # //////////////////////////////////////////////////////////////////////////
814
+ if const_expr(varlen_k):
815
+ # ensure the update to tensormap has completed before using it
816
+ if is_group_changed and is_tma_warp:
817
+ tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
818
+ tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
819
+ tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
820
+ tensormap_a_ptr, cute.AddressSpace.generic
821
+ )
822
+ tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
823
+ tensormap_b_ptr, cute.AddressSpace.generic
824
+ )
825
+ else:
826
+ tma_desc_a_ptr, tma_desc_b_ptr = None, None
964
827
  # TMA load A partition_S/D
965
828
  a_cta_layout = cute.make_layout(
966
829
  cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
967
830
  )
968
831
  a_cta_crd = cluster_coord_mnk[1]
969
832
  if const_expr(not self.gather_A):
970
- # ((atom_v, rest_v), STAGE)
971
- # ((atom_v, rest_v), RestK)
833
+ # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
972
834
  tAsA, tAgA_k = cpasync.tma_partition(
973
835
  tma_atom_a,
974
836
  a_cta_crd,
@@ -976,7 +838,12 @@ class HopperWgmmaGemmKernel:
976
838
  cute.group_modes(sA, 0, 2),
977
839
  cute.group_modes(gA_k, 0, 2),
978
840
  )
979
- copy_A = partial(cute.copy, tma_atom_a, mcast_mask=a_mcast_mask)
841
+ copy_A = partial(
842
+ cute.copy,
843
+ tma_atom_a,
844
+ mcast_mask=a_mcast_mask,
845
+ tma_desc_ptr=tma_desc_a_ptr,
846
+ )
980
847
  else:
981
848
  tiled_copy_A = self._make_gmem_tiled_copy_A(
982
849
  mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
@@ -996,8 +863,7 @@ class HopperWgmmaGemmKernel:
996
863
  cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
997
864
  )
998
865
  b_cta_crd = cluster_coord_mnk[0]
999
- # ((atom_v, rest_v), STAGE)
1000
- # ((atom_v, rest_v), RestK)
866
+ # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
1001
867
  tBsB, tBgB_k = cpasync.tma_partition(
1002
868
  tma_atom_b,
1003
869
  b_cta_crd,
@@ -1005,7 +871,15 @@ class HopperWgmmaGemmKernel:
1005
871
  cute.group_modes(sB, 0, 2),
1006
872
  cute.group_modes(gB_k, 0, 2),
1007
873
  )
1008
- copy_B = partial(cute.copy, tma_atom_b, mcast_mask=b_mcast_mask)
874
+ copy_B = partial(
875
+ cute.copy, tma_atom_b, mcast_mask=b_mcast_mask, tma_desc_ptr=tma_desc_b_ptr
876
+ )
877
+ k_len = (
878
+ cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
879
+ if const_expr(varlen_k)
880
+ else mA_mkl.shape[1]
881
+ )
882
+ k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
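
A small numeric check of the per-batch K tiling above (illustrative values only):

    # cu_seqlens_k = [0, 100, 228], tile K = 64
    # batch 0: k_len = 100 -> ceil(100 / 64) = 2 K tiles
    # batch 1: k_len = 128 -> ceil(128 / 64) = 2 K tiles
    # without varlen_k, k_len is simply the static K extent of A
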
1009
883
  if const_expr(not self.gather_A):
1010
884
  ab_producer_state = self.load_AB(
1011
885
  ab_pipeline,
@@ -1016,6 +890,7 @@ class HopperWgmmaGemmKernel:
1016
890
  copy_B,
1017
891
  tBgB_k,
1018
892
  tBsB,
893
+ k_tile_cnt,
1019
894
  )
1020
895
  else:
1021
896
  limit_m = (
@@ -1033,93 +908,37 @@ class HopperWgmmaGemmKernel:
1033
908
  copy_B,
1034
909
  tBgB_k,
1035
910
  tBsB,
911
+ k_tile_cnt,
1036
912
  limit_A=(
1037
913
  limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
1038
914
  mA_mk.shape[1],
1039
915
  ),
1040
916
  )
1041
- if const_expr(epi_load_barrier is not None):
1042
- # In the first work tile, the epi load warp will wait for the signal
1043
- # from the mainloop load warp to start loading C, to avoid interfering
1044
- # with loading A and B.
1045
- if do_epi_load_barrier_arrive:
1046
- epi_load_barrier.arrive()
1047
- do_epi_load_barrier_arrive = cutlass.Boolean(False)
1048
917
  tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
918
+ tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
1049
919
  tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
1050
920
  work_tile = tile_scheduler.get_current_work()
1051
921
  # End of persistent scheduler loop
1052
- if const_expr(self.pingpong):
922
+ if const_expr(self.pingpong and not varlen_k):
1053
923
  # Need to write the tile_idx to smem for the next WG in the pingpong mode
924
+ # tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
925
+ tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
1054
926
  tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
1055
927
  ab_pipeline.producer_tail(ab_producer_state)
1056
928
  if is_scheduler_warp:
1057
929
  tile_scheduler.producer_tail()
1058
930
 
1059
- # if const_expr(mC_mnl is not None):
1060
- # if warp_idx == self.epi_load_warp_id:
1061
- # epi_producer_state = make_pipeline_state(
1062
- # pipeline.PipelineUserType.Producer, self.epi_c_stage
1063
- # )
1064
- # do_epi_load_barrier_wait = cutlass.Boolean(True)
1065
- # tile_scheduler = TileSchedulerCls()
1066
- # work_tile = tile_scheduler.initial_work_tile_info()
1067
- # while work_tile.is_valid_tile:
1068
- # tile_coord_mnkl = work_tile.tile_idx
1069
- # batch_idx = tile_coord_mnkl[3]
1070
- # if const_expr(cu_seqlens_m is not None):
1071
- # mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
1072
- # else:
1073
- # mC_mn = mC_mnl[None, None, batch_idx]
1074
- # # (bM, bN)
1075
- # gC = cute.local_tile(
1076
- # mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
1077
- # )
1078
- # tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
1079
- # bGS_sC, bGS_gC = cpasync.tma_partition(
1080
- # tma_atom_c,
1081
- # 0,
1082
- # cute.make_layout(1),
1083
- # cute.group_modes(sC, 0, 2),
1084
- # tCgC_for_tma_partition,
1085
- # )
1086
- # if do_epi_load_barrier_wait:
1087
- # epi_load_barrier.arrive_and_wait()
1088
- # do_epi_load_barrier_wait = cutlass.Boolean(False)
1089
- # epi_tile_num = const_expr(cute.size(tCgC_for_tma_partition, mode=[1]))
1090
- # epi_tile_shape = tCgC_for_tma_partition.shape[1]
1091
- # for epi_idx in cutlass.range(epi_tile_num, unroll=1):
1092
- # epi_pipeline.producer_acquire(epi_producer_state)
1093
- # # Get the global memory coordinate for the current epi tile
1094
- # epi_tile_layout = cute.make_layout(
1095
- # epi_tile_shape, stride=(epi_tile_shape[1], 1)
1096
- # )
1097
- # gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1098
- # cute.copy(
1099
- # tma_atom_c,
1100
- # bGS_gC[None, gmem_coord],
1101
- # bGS_sC[None, epi_producer_state.index],
1102
- # tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1103
- # )
1104
- # # Epi pipeline's producer commit is a NOP
1105
- # epi_pipeline.producer_commit(epi_producer_state)
1106
- # epi_producer_state.advance()
1107
- # tile_scheduler.advance_to_next_work()
1108
- # work_tile = tile_scheduler.get_current_work()
1109
- # # End of persistent scheduler loop
1110
- # epi_pipeline.producer_tail(epi_producer_state)
1111
-
1112
931
  if warp_idx < self.ab_load_warp_id:
1113
932
  cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
1114
- is_tma_warp = cutlass.Boolean(
933
+ is_tma_warp = Boolean(
1115
934
  (not self.pingpong and warp_idx == 0)
1116
935
  or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
1117
936
  )
1118
- if const_expr(varlen):
937
+ if const_expr(varlen_m):
1119
938
  # initialize tensormap for D
1120
939
  tensormap_manager.init_tensormap_from_atom(
1121
940
  tma_atom_d,
1122
- tensormap_d_init_ptr,
941
+ tensormap_d_ptr,
1123
942
  is_manager_warp=is_tma_warp,
1124
943
  )
1125
944
  # //////////////////////////////////////////////////////////////////////////////
@@ -1145,10 +964,9 @@ class HopperWgmmaGemmKernel:
1145
964
 
1146
965
  acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
1147
966
  acc = cute.make_fragment(acc_shape, self.acc_dtype)
967
+ acc_slow = None
1148
968
  if const_expr(self.fp8_slow_accum):
1149
969
  acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
1150
- else:
1151
- acc_slow = None
1152
970
 
1153
971
  if const_expr(self.pingpong):
1154
972
  if warp_group_idx == 0:
@@ -1156,6 +974,9 @@ class HopperWgmmaGemmKernel:
1156
974
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
1157
975
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
1158
976
 
977
+ k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.tile_shape_mnk[2])
978
+ c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
979
+
1159
980
  ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
1160
981
  epi_read_state = make_pipeline_state(
1161
982
  pipeline.PipelineUserType.Consumer, self.epi_c_stage
@@ -1164,16 +985,29 @@ class HopperWgmmaGemmKernel:
1164
985
  pipeline.PipelineUserType.Producer, self.epi_c_stage
1165
986
  )
1166
987
  tile_scheduler = TileSchedulerCls()
988
+ work_tile = None
1167
989
  if const_expr(self.pingpong):
990
+ if const_expr(varlen_k):
991
+ work_tile = tile_scheduler.initial_work_tile_info()
1168
992
  if warp_idx >= 4:
1169
- # Advance 2nd Math WG to the next work tile for the startup
1170
- tile_scheduler.advance_to_next_work()
1171
993
  # Advance 2nd Math WG pipeline states to the end of 1st Math WG
1172
- ab_read_state.advance_iters(k_tile_cnt)
1173
994
  epi_read_state.advance_iters(c_tile_cnt)
1174
995
  epi_producer_state.advance_iters(c_tile_cnt)
1175
- work_tile = tile_scheduler.initial_work_tile_info()
1176
- if const_expr(varlen):
996
+ if const_expr(not varlen_k):
997
+ ab_read_state.advance_iters(k_tile_cnt_static)
998
+ else:
999
+ batch_idx = work_tile.tile_idx[3]
1000
+ k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1001
+ k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
1002
+ ab_read_state.advance_iters(k_tile_cnt)
1003
+ tile_scheduler.advance_to_next_work()
1004
+ if const_expr(varlen_k):
1005
+ work_tile = tile_scheduler.get_current_work()
1006
+ if const_expr(not varlen_k):
1007
+ work_tile = tile_scheduler.initial_work_tile_info()
1008
+ else:
1009
+ work_tile = tile_scheduler.initial_work_tile_info()
1010
+ if const_expr(varlen_m):
1177
1011
  # wait tensormap initialization complete before update
1178
1012
  tensormap_manager.fence_tensormap_initialization()
1179
1013
  # batch index of last tile
@@ -1181,19 +1015,25 @@ class HopperWgmmaGemmKernel:
1181
1015
  while work_tile.is_valid_tile:
1182
1016
  tile_coord_mnkl = work_tile.tile_idx
1183
1017
  batch_idx = tile_coord_mnkl[3]
1184
- if const_expr(varlen):
1018
+ if const_expr(varlen_m):
1185
1019
  is_group_changed = batch_idx != last_batch_idx
1186
1020
  last_batch_idx = batch_idx
1187
1021
  if is_group_changed:
1188
1022
  # construct tensor D based on real address, shape and stride information
1189
1023
  tensormap_manager.update_tensormap_shape(
1190
- ((tensormap_d_ptr),),
1024
+ (tensormap_d_ptr,),
1191
1025
  is_manager_warp=is_tma_warp,
1192
- tensormap_smem_ptr=(tensormap_d_smem_ptr,),
1193
1026
  shapes=(cu_seqlens_m[batch_idx + 1],),
1194
1027
  orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
1028
+ tensormap_smem_ptr=None,
1195
1029
  )
1196
1030
 
1031
+ k_len = (
1032
+ cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1033
+ if const_expr(varlen_k)
1034
+ else mA_mkl.shape[1]
1035
+ )
1036
+ k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
1197
1037
  ab_read_state, tiled_mma = self.mma(
1198
1038
  ab_pipeline,
1199
1039
  ab_read_state,
@@ -1205,9 +1045,9 @@ class HopperWgmmaGemmKernel:
1205
1045
  k_tile_cnt,
1206
1046
  warp_group_idx,
1207
1047
  )
1208
- if const_expr(self.pingpong):
1209
- # Update starting mainloop pipeline state for the next tile
1210
- ab_read_state.advance_iters(k_tile_cnt)
1048
+ if const_expr(varlen_k):
1049
+ if k_tile_cnt == 0:
1050
+ acc.fill(0.0)
1211
1051
 
1212
1052
  # /////////////////////////////////////////////////////////////////////////////
1213
1053
  # EPILOGUE
@@ -1219,194 +1059,123 @@ class HopperWgmmaGemmKernel:
1219
1059
  barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
1220
1060
  )
1221
1061
 
1222
- # Wait for all warp groups in the thread block to finish, because smem for tensor
1223
- # A in the mainloop is reused in the epilogue if not persistent.
1224
- if const_expr(not self.is_persistent):
1225
- epilogue_barrier.arrive_and_wait()
1226
-
1227
- if const_expr(varlen):
1062
+ if const_expr(varlen_m):
1228
1063
  # ensure the update to tensormap has completed before using it
1229
- if is_group_changed:
1230
- if is_tma_warp:
1231
- tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
1232
-
1233
- # Doesn't work with tile_N % 8 == 0 but tile_n % 16 != since this always
1234
- # get st.matrix with num_matrices=4
1235
- copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
1236
- self.d_layout, elem_ty_d=self.d_dtype, elem_ty_acc=self.acc_dtype
1237
- )
1238
- copy_atom_C = cute.make_copy_atom(
1239
- warp.StMatrix8x8x16bOp(
1240
- self.d_layout.is_m_major_c(),
1241
- num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
1242
- ),
1243
- cutlass.Float16, # this is just to get the right source layout
1244
- )
1245
- tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
1246
- tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
1247
- # (R2S, R2S_M, R2S_N, PIPE_D)
1248
- thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
1249
- tRS_sD = thr_copy_r2s.partition_D(sD)
1250
- # (R2S, R2S_M, R2S_N)
1251
- tRS_rAcc = tiled_copy_r2s.retile(acc)
1252
-
1253
- # Allocate D registers.
1254
- tRS_rD_layout = cute.make_layout(thr_copy_r2s.partition_S(sD).shape[:3])
1255
- tRS_rD = cute.make_fragment(tRS_rD_layout, self.acc_dtype)
1256
-
1257
- if const_expr(mC_mnl is not None):
1258
- copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
1259
- tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
1260
- thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1261
- tSR_sC = thr_copy_s2r.partition_S(sC)
1262
- tRS_rC = cute.make_fragment(tRS_rD_layout, self.c_dtype)
1263
- tSR_rC = thr_copy_s2r.retile(tRS_rC)
1264
- else:
1265
- thr_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
1266
-
1267
- if const_expr(cu_seqlens_m is not None):
1268
- mD_mn_tma = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl_tma)
1064
+ if is_group_changed and is_tma_warp:
1065
+ tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
1066
+ tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
1067
+ tensormap_d_ptr, cute.AddressSpace.generic
1068
+ )
1269
1069
  else:
1270
- mD_mn_tma = mD_mnl_tma[None, None, batch_idx]
1271
- # (bM, bN)
1272
- gD = cute.local_tile(
1273
- mD_mn_tma, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
1274
- )
1275
- tDgD_for_tma_partition = cute.zipped_divide(gD, self.epi_tile)
1276
- bSG_sD, bSG_gD = cpasync.tma_partition(
1277
- tma_atom_d,
1278
- 0,
1279
- cute.make_layout(1),
1280
- cute.group_modes(sD, 0, 2),
1281
- tDgD_for_tma_partition,
1282
- )
1283
-
1284
- if const_expr(mC_mnl is not None):
1285
- if const_expr(cu_seqlens_m is not None):
1286
- mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
1287
- else:
1288
- mC_mn = mC_mnl[None, None, batch_idx]
1289
- # (bM, bN)
1290
- gC = cute.local_tile(
1291
- mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
1070
+ tma_desc_d_ptr = None
1071
+
1072
+ if const_expr(has_D):
1073
+ bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(
1074
+ tma_atom_d,
1075
+ mD_mnl,
1076
+ self.tile_shape_mnk[:2],
1077
+ self.epi_tile,
1078
+ sD,
1079
+ tile_coord_mnkl,
1080
+ cu_seqlens_m,
1292
1081
  )
1293
- tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
1294
- bGS_sC, bGS_gC = cpasync.tma_partition(
1082
+ copy_D = partial(cute.copy, tma_atom_d, tma_desc_ptr=tma_desc_d_ptr)
1083
+ else:
1084
+ bSG_sD, bSG_gD, copy_D = None, None, None
1085
+ if const_expr(has_C):
1086
+ bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
1295
1087
  tma_atom_c,
1296
- 0,
1297
- cute.make_layout(1),
1298
- cute.group_modes(sC, 0, 2),
1299
- tCgC_for_tma_partition,
1088
+ mC_mnl,
1089
+ self.tile_shape_mnk[:2],
1090
+ self.epi_tile,
1091
+ sC,
1092
+ tile_coord_mnkl,
1093
+ cu_seqlens_m,
1300
1094
  )
1095
+ copy_C = partial(cute.copy, tma_atom_c)
1096
+ epi_load_g2s = partial(self.epi_load_g2s, epi_pipeline, copy_C, bGS_gC, bGS_sC)
1097
+ else:
1098
+ epi_load_g2s = None
1301
1099
 
1302
- epi_tile_num = const_expr(cute.size(tDgD_for_tma_partition, mode=[1]))
1303
- epi_tile_shape = tDgD_for_tma_partition.shape[1]
1304
- num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
1305
- epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
1306
-
1307
- if const_expr(mC_mnl is not None):
1308
- for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
1309
- if is_tma_warp:
1310
- epi_pipeline.producer_acquire(epi_producer_state)
1311
- # Get the global memory coordinate for the current epi tile
1312
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1313
- cute.copy(
1314
- tma_atom_c,
1315
- bGS_gC[None, gmem_coord],
1316
- bGS_sC[None, epi_producer_state.index],
1317
- tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1318
- )
1319
- # Epi pipeline's producer commit is a NOP
1320
- epi_pipeline.producer_commit(epi_producer_state)
1321
- epi_producer_state.advance()
1322
-
1323
- for epi_idx in cutlass.range_constexpr(epi_tile_num):
1324
- # Copy from acc to D registers
1325
- for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
1326
- tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1327
- if const_expr(mC_mnl is not None):
1328
- epi_pipeline.consumer_wait(epi_read_state)
1329
- cute.copy(
1330
- thr_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
1331
- )
1332
- # Fence to make sure shared memory read is visible to TMA load
1333
- cute.arch.fence_proxy(
1334
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1335
- )
1336
- cute.arch.sync_warp()
1337
- with cute.arch.elect_one():
1338
- epi_pipeline.consumer_release(epi_read_state)
1339
- epi_read_state.advance()
1340
- if const_expr(epi_idx + self.epi_c_stage < epi_tile_num):
1341
- if is_tma_warp:
1342
- epi_pipeline.producer_acquire(epi_producer_state)
1343
- # Get the global memory coordinate for the current epi tile
1344
- gmem_coord = epi_tile_layout.get_hier_coord(
1345
- epi_idx + self.epi_c_stage
1346
- )
1347
- cute.copy(
1348
- tma_atom_c,
1349
- bGS_gC[None, gmem_coord],
1350
- bGS_sC[None, epi_producer_state.index],
1351
- tma_bar_ptr=epi_pipeline.producer_get_barrier(
1352
- epi_producer_state
1353
- ),
1354
- )
1355
- # Epi pipeline's producer commit is a NOP
1356
- epi_pipeline.producer_commit(epi_producer_state)
1357
- epi_producer_state.advance()
1358
- tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(self.acc_dtype))
1359
- # Type conversion
1360
- tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
1361
- tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
1362
- # Copy from D registers to shared memory
1363
- epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sD, mode=[3])
1364
- cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
1365
- # Fence and barrier to make sure shared memory store is visible to TMA store
1366
- cute.arch.fence_proxy(
1367
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1100
+ d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
1101
+ tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
1102
+ tiled_mma, self.d_layout, d_dtype_for_layout, acc, sD, tidx
1103
+ )
1104
+ if const_expr(has_C):
1105
+ tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
1106
+ tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
1368
1107
  )
1108
+ else:
1109
+ tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
1110
+
1111
+ # Wait for all warp groups in the thread block to finish, because smem for tensor
1112
+ # A in the mainloop is reused in the epilogue if not persistent.
1113
+ if const_expr(not self.is_persistent):
1369
1114
  epilogue_barrier.arrive_and_wait()
1370
- # Get the global memory coordinate for the current epi tile
1371
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1372
- # Copy from shared memory to global memory
1373
- if is_tma_warp:
1374
- if const_expr(varlen):
1375
- tma_desc_ptr = tensormap_manager.get_tensormap_ptr(
1376
- tensormap_d_ptr,
1377
- cute.AddressSpace.generic,
1378
- )
1379
- else:
1380
- tma_desc_ptr = None
1381
- cute.copy(
1382
- tma_atom_d,
1383
- bSG_sD[None, epi_buffer],
1384
- bSG_gD[None, gmem_coord],
1385
- tma_desc_ptr=tma_desc_ptr,
1386
- )
1387
- cute.arch.cp_async_bulk_commit_group()
1388
- cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
1389
- epilogue_barrier.arrive_and_wait()
1115
+
1116
+ self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
1117
+
1118
+ epi_read_state, epi_producer_state = self.epilogue(
1119
+ epilogue_params,
1120
+ epi_smem_tensors,
1121
+ epi_pipeline,
1122
+ epi_read_state,
1123
+ epi_producer_state,
1124
+ tiled_mma,
1125
+ tRS_rAcc,
1126
+ tRS_rD,
1127
+ tRS_rC,
1128
+ tiled_copy_r2s,
1129
+ tRS_sD,
1130
+ tiled_copy_s2r,
1131
+ tSR_rC,
1132
+ tSR_sC,
1133
+ copy_D,
1134
+ bSG_sD,
1135
+ bSG_gD,
1136
+ epi_load_g2s,
1137
+ tile_coord_mnkl,
1138
+ cu_seqlens_m,
1139
+ epilogue_barrier,
1140
+ tile_scheduler,
1141
+ tidx,
1142
+ is_tma_warp,
1143
+ )
1390
1144
 
1391
1145
  if const_expr(self.pingpong):
1392
- # Update starting load/store pipeline states for the next tile
1393
- epi_read_state.advance_iters(c_tile_cnt)
1394
- epi_producer_state.advance_iters(c_tile_cnt)
1395
1146
  # With pingpong, 2 WGs write two different output tiles to the same smem,
1396
1147
  # so we have to make sure the smem content is done reading before signaling
1397
1148
  # the next WG's epilogue.
1398
- if warp_idx == 0 or warp_idx == 4:
1149
+ if is_tma_warp:
1399
1150
  cute.arch.cp_async_bulk_wait_group(0, read=True)
1400
1151
  self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
1401
1152
 
1402
- tile_scheduler.advance_to_next_work(
1403
- advance_count=1 if not self.pingpong else self.mma_warp_groups
1404
- )
1405
- work_tile = tile_scheduler.get_current_work()
1153
+ if const_expr(not self.pingpong):
1154
+ tile_scheduler.advance_to_next_work()
1155
+ work_tile = tile_scheduler.get_current_work()
1156
+ else: # Skip a tile for pingpong
1157
+ # Update starting load/store pipeline states for the next tile
1158
+ epi_read_state.advance_iters(c_tile_cnt)
1159
+ epi_producer_state.advance_iters(c_tile_cnt)
1160
+ # Update starting mainloop pipeline state for the next tile
1161
+ if const_expr(not varlen_k):
1162
+ ab_read_state.advance_iters(k_tile_cnt_static)
1163
+ tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups)
1164
+ work_tile = tile_scheduler.get_current_work()
1165
+ else:
1166
+ tile_scheduler.advance_to_next_work()
1167
+ work_tile = tile_scheduler.get_current_work()
1168
+ if work_tile.is_valid_tile:
1169
+ batch_idx = work_tile.tile_idx[3]
1170
+ k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
1171
+ k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
1172
+ ab_read_state.advance_iters(k_tile_cnt)
1173
+ tile_scheduler.advance_to_next_work()
1174
+ work_tile = tile_scheduler.get_current_work()
1406
1175
  # End of persistent scheduler loop
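In the pingpong path above, the scheduler is advanced by mma_warp_groups, so this warp group skips the tile handled by its peer; with varlen_k it must also advance the A/B read pipeline state by the skipped tile's k-tile count. A minimal sketch of that count, using plain-Python stand-ins for cu_seqlens_k and the K tile size (illustrative only, not the kernel's API):

    def k_tiles_for_batch(batch_idx, cu_seqlens_k, tile_k):
        # K extent of this batch in the packed varlen layout.
        k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
        # Same ceil-division as cute.ceil_div(k_len, tile_shape_mnk[2]) above.
        return (k_len + tile_k - 1) // tile_k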
1407
1176
 
1408
1177
  if const_expr(not self.pingpong):
1409
- if warp_idx == 0:
1178
+ if is_tma_warp:
1410
1179
  cute.arch.cp_async_bulk_wait_group(0, read=True)
1411
1180
 
1412
1181
  @cute.jit
@@ -1420,10 +1189,10 @@ class HopperWgmmaGemmKernel:
1420
1189
  copy_B: Callable,
1421
1190
  tBgB: cute.Tensor,
1422
1191
  tBsB: cute.Tensor,
1192
+ k_tile_cnt: Int32,
1423
1193
  ) -> cutlass.pipeline.PipelineState:
1424
- k_tile_cnt = cute.size(tAgA, mode=[1])
1425
1194
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
1426
- peek_ab_empty_status = cutlass.Boolean(True)
1195
+ peek_ab_empty_status = Boolean(True)
1427
1196
  if 0 < k_tile_cnt:
1428
1197
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1429
1198
  # /////////////////////////////////////////////////////////////////////////
@@ -1434,20 +1203,12 @@ class HopperWgmmaGemmKernel:
1434
1203
  # Also sets the transaction barrier for the A/B buffers
1435
1204
  ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
1436
1205
  tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
1437
- copy_A(
1438
- tAgA[None, k_tile],
1439
- tAsA[None, ab_producer_state.index],
1440
- tma_bar_ptr=tma_bar_ptr,
1441
- )
1442
- copy_B(
1443
- tBgB[None, k_tile],
1444
- tBsB[None, ab_producer_state.index],
1445
- tma_bar_ptr=tma_bar_ptr,
1446
- )
1206
+ copy_A(tAgA[None, k_tile], tAsA[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
1207
+ copy_B(tBgB[None, k_tile], tBsB[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
1447
1208
  # Mainloop pipeline's producer commit is a NOP
1448
1209
  ab_pipeline.producer_commit(ab_producer_state)
1449
1210
  ab_producer_state.advance()
1450
- peek_ab_empty_status = cutlass.Boolean(True)
1211
+ peek_ab_empty_status = Boolean(True)
1451
1212
  if k_tile + 1 < k_tile_cnt:
1452
1213
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1453
1214
  return ab_producer_state
@@ -1464,6 +1225,7 @@ class HopperWgmmaGemmKernel:
1464
1225
  copy_B: Callable,
1465
1226
  tBgB: cute.Tensor,
1466
1227
  tBsB: cute.Tensor,
1228
+ k_tile_cnt: Int32,
1467
1229
  limit_A: Tuple[Int32, Int32],
1468
1230
  ) -> cutlass.pipeline.PipelineState:
1469
1231
  # (atom_v, CPY_M, 1, RestK)
@@ -1489,9 +1251,8 @@ class HopperWgmmaGemmKernel:
1489
1251
  # (m, (bK, RestK))
1490
1252
  mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
1491
1253
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1492
- k_tile_cnt = cute.size(tBgB, mode=[1])
1493
1254
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
1494
- peek_ab_empty_status = cutlass.Boolean(True)
1255
+ peek_ab_empty_status = Boolean(True)
1495
1256
  if 0 < k_tile_cnt:
1496
1257
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1497
1258
  # /////////////////////////////////////////////////////////////////////////
@@ -1527,7 +1288,7 @@ class HopperWgmmaGemmKernel:
1527
1288
  # This tells mbarrier to track the completion of cp.async
1528
1289
  ab_pipeline.producer_commit(ab_producer_state)
1529
1290
  ab_producer_state.advance()
1530
- peek_ab_empty_status = cutlass.Boolean(True)
1291
+ peek_ab_empty_status = Boolean(True)
1531
1292
  if k_tile + 1 < k_tile_cnt:
1532
1293
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1533
1294
  # bound checking in the K dimension on the last k_tile
@@ -1545,7 +1306,7 @@ class HopperWgmmaGemmKernel:
1545
1306
  tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1546
1307
  )
1547
1308
  assert tAcA.shape[2] == 1 # there's only 1 load along the K dimension
1548
- tApA = cute.make_fragment(1, cutlass.Boolean)
1309
+ tApA = cute.make_fragment(1, Boolean)
1549
1310
  tApA[0] = tAcA[0, 0, 0][1] < limit_k
1550
1311
  # (m, bK)
1551
1312
  mA_cur = mA_k[None, (None, k_tile)]
@@ -1584,12 +1345,11 @@ class HopperWgmmaGemmKernel:
1584
1345
  num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
1585
1346
  if const_expr(self.pingpong):
1586
1347
  self.pingpong_barrier_sync(warp_group_idx, stage="mma")
1587
- peek_ab_full_status = cutlass.Boolean(True)
1348
+ peek_ab_full_status = Boolean(True)
1588
1349
  if 0 < k_tile_cnt:
1589
1350
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1590
1351
  tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
1591
1352
  num_k_blocks = cute.size(tCrA, mode=[2])
1592
- # TODO: this is probably not correct if k_tile_cnt == 0
1593
1353
  for k_tile in cutlass.range(num_prologue_mma):
1594
1354
  # Wait for A/B buffer to be ready
1595
1355
  ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
@@ -1600,9 +1360,11 @@ class HopperWgmmaGemmKernel:
1600
1360
  tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
1601
1361
  warpgroup.commit_group()
1602
1362
  ab_read_state.advance()
1603
- peek_ab_full_status = cutlass.Boolean(True)
1363
+ peek_ab_full_status = Boolean(True)
1604
1364
  if k_tile + 1 < k_tile_cnt:
1605
1365
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1366
+ # If k_tile_cnt == 0, the accumulator is never initialized here, but in that case
1367
+ # acc is zero-filled right before the epilogue (varlen_k path).
1606
1368
  if const_expr(self.fp8_slow_accum):
1607
1369
  warpgroup.wait_group(0)
1608
1370
  acc_slow.store(acc.load())
@@ -1631,7 +1393,7 @@ class HopperWgmmaGemmKernel:
1631
1393
  ab_pipeline.consumer_release(ab_release_state)
1632
1394
  ab_read_state.advance()
1633
1395
  ab_release_state.advance()
1634
- peek_ab_full_status = cutlass.Boolean(True)
1396
+ peek_ab_full_status = Boolean(True)
1635
1397
  if k_tile + 1 < k_tile_cnt:
1636
1398
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1637
1399
  if const_expr(self.pingpong):
@@ -1640,7 +1402,7 @@ class HopperWgmmaGemmKernel:
1640
1402
  if const_expr(not self.fp8_slow_accum):
1641
1403
  # fp8_slow_accum would have already called wait_group(0) inside the loop
1642
1404
  warpgroup.wait_group(0)
1643
- for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
1405
+ for k_tile in cutlass.range(num_prologue_mma, unroll=1):
1644
1406
  ab_pipeline.consumer_release(ab_release_state)
1645
1407
  ab_release_state.advance()
1646
1408
  if const_expr(self.fp8_slow_accum):
@@ -1649,6 +1411,184 @@ class HopperWgmmaGemmKernel:
1649
1411
  # "operand #0 does not dominate this use"
1650
1412
  return ab_read_state, tiled_mma
1651
1413
 
1414
+ @cute.jit
1415
+ def epilogue(
1416
+ self,
1417
+ params: EpilogueParams,
1418
+ epi_smem_tensors: Tuple[cute.Tensor, ...],
1419
+ epi_pipeline: cutlass.pipeline.PipelineAsync,
1420
+ epi_read_state: cutlass.pipeline.PipelineState,
1421
+ epi_producer_state: cutlass.pipeline.PipelineState,
1422
+ tiled_mma: cute.TiledMma,
1423
+ tRS_rAcc: cute.Tensor,
1424
+ tRS_rD: cute.Tensor,
1425
+ tRS_rC: Optional[cute.Tensor],
1426
+ tiled_copy_r2s: cute.core.ThrCopy,
1427
+ tRS_sD: cute.Tensor,
1428
+ tiled_copy_s2r: Optional[cute.core.ThrCopy],
1429
+ tSR_rC: Optional[cute.Tensor],
1430
+ tSR_sC: Optional[cute.Tensor],
1431
+ copy_D: Optional[Callable],
1432
+ bSG_sD: cute.Tensor,
1433
+ bSG_gD: cute.Tensor,
1434
+ epi_load_g2s: Optional[Callable],
1435
+ tile_coord_mnkl: cute.Coord,
1436
+ cu_seqlens_m: Optional[cute.Tensor],
1437
+ epilogue_barrier: cutlass.pipeline.NamedBarrier,
1438
+ tile_scheduler,
1439
+ tidx: Int32,
1440
+ is_tma_warp: Boolean,
1441
+ ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
1442
+ has_C = const_expr(tRS_rC is not None)
1443
+ has_D = const_expr(copy_D is not None)
1444
+ # We iterate over epi tiles along the N dimension first, then the M dimension
1445
+ epi_tile_shape = cute.zipped_divide(
1446
+ cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
1447
+ ).shape[1]
1448
+ epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
1449
+ epi_tile_num = cute.size(epi_tile_shape)
1450
+ num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
1451
+
1452
+ if const_expr(epi_load_g2s is not None):
1453
+ for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
1454
+ epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
1455
+
1456
+ for epi_idx in cutlass.range_constexpr(epi_tile_num):
1457
+ # Copy from acc to D registers
1458
+ for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
1459
+ tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1460
+ if const_expr(has_C):
1461
+ epi_pipeline.consumer_wait(epi_read_state)
1462
+ cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
1463
+ # Fence to make sure shared memory read is visible to TMA load
1464
+ cute.arch.fence_proxy(
1465
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1466
+ )
1467
+ cute.arch.sync_warp()
1468
+ with cute.arch.elect_one():
1469
+ epi_pipeline.consumer_release(epi_read_state)
1470
+ epi_read_state.advance()
1471
+ if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
1472
+ epi_producer_state = epi_load_g2s(
1473
+ epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
1474
+ )
1475
+ tRS_rEpi = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
1476
+ epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
1477
+ # Copy from D registers to shared memory
1478
+ if const_expr(has_D):
1479
+ # Type conversion
1480
+ tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
1481
+ tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
1482
+ cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
1483
+ # Fence and barrier to make sure shared memory store is visible to TMA store
1484
+ cute.arch.fence_proxy(
1485
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1486
+ )
1487
+ epilogue_barrier.arrive_and_wait()
1488
+ # Get the global memory coordinate for the current epi tile
1489
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1490
+ # Copy from shared memory to global memory
1491
+ if is_tma_warp:
1492
+ if const_expr(has_D):
1493
+ copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
1494
+ cute.arch.cp_async_bulk_commit_group()
1495
+ cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
1496
+ epilogue_barrier.arrive_and_wait()
1497
+
1498
+ return epi_read_state, epi_producer_state
1499
+
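The epilogue above visits the epi sub-tiles through a layout with stride (epi_tile_shape[1], 1), so the N coordinate varies fastest. An illustrative plain-Python version (not the CuTe API) of the index-to-coordinate mapping that get_hier_coord performs under that ordering:

    def epi_subtile_coord(epi_idx: int, num_n_subtiles: int):
        # N-fastest order: consecutive indices walk along N, then wrap to the next M row.
        return epi_idx // num_n_subtiles, epi_idx % num_n_subtiles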
1500
+ @cute.jit
1501
+ def epi_load_g2s(
1502
+ self,
1503
+ epi_pipeline: cutlass.pipeline.PipelineAsync,
1504
+ copy_C: Callable,
1505
+ bGS_gC: cute.Tensor,
1506
+ bGS_sC: cute.Tensor,
1507
+ epi_producer_state: cutlass.pipeline.PipelineState,
1508
+ epi_idx: Int32,
1509
+ should_load: Boolean,
1510
+ ) -> cutlass.pipeline.PipelineState:
1511
+ # We iterate over epi tiles along the N dimension first, then the M dimension
1512
+ epi_tile_layout = cute.make_layout(bGS_gC.shape[1], stride=(bGS_gC.shape[1][1], 1))
1513
+ if should_load:
1514
+ epi_pipeline.producer_acquire(epi_producer_state)
1515
+ # Get the global memory coordinate for the current epi tile
1516
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1517
+ copy_C(
1518
+ bGS_gC[None, gmem_coord],
1519
+ bGS_sC[None, epi_producer_state.index],
1520
+ tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1521
+ )
1522
+ # Epi pipeline's producer commit is a NOP
1523
+ epi_pipeline.producer_commit(epi_producer_state)
1524
+ epi_producer_state.advance()
1525
+ return epi_producer_state
1526
+
1527
+ def epi_visit_acc_subtile(
1528
+ self,
1529
+ params: EpilogueParams,
1530
+ tRS_rD: cute.Tensor,
1531
+ tRS_rC: Optional[cute.Tensor] = None,
1532
+ ) -> Optional[cute.Tensor]:
1533
+ # Apply alpha scaling to accumulator if alpha is provided (not None)
1534
+ if const_expr(hasattr(params, "alpha") and params.alpha is not None):
1535
+ alpha = utils.load_scalar_or_pointer(params.alpha)
1536
+ tRS_rD.store(tRS_rD.load() * alpha)
1537
+ # Apply C with beta scaling
1538
+ if const_expr(tRS_rC is not None):
1539
+ if const_expr(not hasattr(params, "beta") or params.beta is None):
1540
+ # beta is None, default behavior: add C (beta=1.0)
1541
+ tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
1542
+ else:
1543
+ beta = utils.load_scalar_or_pointer(params.beta)
1544
+ tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
1545
+ return None
1546
+
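Per sub-tile, epi_visit_acc_subtile above amounts to D = alpha * acc + beta * C, with a missing alpha or beta treated as 1.0. A host-side reference of the same semantics (PyTorch sketch, illustrative only):

    import torch

    def reference_epilogue(acc, C=None, alpha=None, beta=None):
        D = acc.clone().float()
        if alpha is not None:
            D = D * alpha          # alpha scaling, skipped when alpha is None (i.e. 1.0)
        if C is not None:
            beta_val = beta if beta is not None else 1.0
            D = D + beta_val * C.float()
        return D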
1547
+ def get_scheduler_class(self):
1548
+ """Return the scheduler class to use. Override in subclasses for custom schedulers."""
1549
+ return TileScheduler
1550
+
1551
+ def get_scheduler_arguments(self, problem_shape_ntile_mnl, scheduler_args):
1552
+ """Create scheduler arguments. Override in subclasses for custom schedulers."""
1553
+ return TileSchedulerArguments(
1554
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
1555
+ raster_order=scheduler_args.raster_order,
1556
+ group_size=scheduler_args.max_swizzle_size,
1557
+ cluster_shape_mnk=self.cluster_shape_mnk,
1558
+ tile_count_semaphore=scheduler_args.tile_count_semaphore,
1559
+ batch_idx_permute=scheduler_args.batch_idx_permute,
1560
+ is_persistent=self.is_persistent,
1561
+ )
1562
+
1563
+ def epi_visit_acc(
1564
+ self,
1565
+ params: EpilogueParams,
1566
+ acc: cute.Tensor,
1567
+ tiled_mma: cute.TiledMma,
1568
+ tile_coord_mnkl: cute.Coord,
1569
+ tidx: Int32,
1570
+ ) -> None:
1571
+ pass
1572
+
1573
+ def epi_to_underlying_arguments(
1574
+ self, args: EpilogueArguments, *, loc=None, ip=None
1575
+ ) -> EpilogueParams:
1576
+ return GemmSm90.EpilogueParams(alpha=args.alpha, beta=args.beta)
1577
+
1578
+ @staticmethod
1579
+ def epi_smem_bytes_per_stage(
1580
+ args: Optional[EpilogueArguments],
1581
+ tile_shape_mnk: Tuple[int, int, int],
1582
+ epi_tile: Tuple[int, int],
1583
+ ) -> int:
1584
+ return 0
1585
+
1586
+ def epi_get_smem_struct(self, params: EpilogueParams):
1587
+ return cute.struct.MemRange[cutlass.Int32, 0] # Dummy struct
1588
+
1589
+ def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
1590
+ return tuple()
1591
+
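The epi_* methods above are the kernel's extension hooks: the base versions implement the plain alpha/beta epilogue, reserve no extra epilogue smem, and use the default tile scheduler. A hypothetical subclass sketch (the class name and the extra scaling are assumptions, not part of the package; it assumes the kernel class is exposed as GemmSm90, as referenced in the wrapper further below):

    class GemmScaledOutput(GemmSm90):
        # Hypothetical: keep the stock alpha/beta handling, then apply an extra 2x scale.
        def epi_visit_acc_subtile(self, params, tRS_rD, tRS_rC=None):
            super().epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
            tRS_rD.store(tRS_rD.load() * 2.0)
            return None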
1652
1592
  def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
1653
1593
  assert stage in ["mma", "epi"]
1654
1594
  barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
@@ -1665,14 +1605,174 @@ class HopperWgmmaGemmKernel:
1665
1605
  number_of_threads=2 * self.num_threads_per_warp_group,
1666
1606
  )
1667
1607
 
1668
- @staticmethod
1608
+ def epilog_smem_copy_atom(self, tiled_mma: cute.TiledMma) -> cute.TiledCopy:
1609
+ copy_atom_C = cute.make_copy_atom(
1610
+ warp.StMatrix8x8x16bOp(
1611
+ self.d_layout.is_m_major_c() if self.d_layout is not None else False,
1612
+ num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
1613
+ ),
1614
+ cutlass.Float16, # this is just to get the right source layout
1615
+ )
1616
+ tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
1617
+ return tiled_copy_C_atom
1618
+
1619
+ def epilog_smem_store_and_partition(
1620
+ self,
1621
+ tiled_mma: cute.TiledMma,
1622
+ d_layout: Optional[LayoutEnum],
1623
+ dtype: Type[cutlass.Numeric],
1624
+ acc: cute.Tensor,
1625
+ sD: cute.Tensor,
1626
+ tidx: Int32,
1627
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, Optional[cute.Tensor]]:
1628
+ if d_layout is None:
1629
+ d_layout = LayoutEnum.ROW_MAJOR
1630
+ tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
1631
+ # Doesn't work when tile_N % 8 == 0 but tile_N % 16 != 0, since this always
1632
+ # gets st.matrix with num_matrices=4
1633
+ copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
1634
+ d_layout, elem_ty_d=dtype, elem_ty_acc=self.acc_dtype
1635
+ )
1636
+ tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
1637
+ # (R2S, R2S_M, R2S_N, PIPE_D)
1638
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
1639
+ tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
1640
+ # (R2S, R2S_M, R2S_N)
1641
+ tRS_rAcc = tiled_copy_r2s.retile(acc)
1642
+ sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
1643
+ tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
1644
+ tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
1645
+ return tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD
1646
+
1647
+ def epilog_smem_load_and_partition(
1648
+ self,
1649
+ tiled_mma: cute.TiledMma,
1650
+ c_layout: LayoutEnum,
1651
+ dtype: Type[cutlass.Numeric],
1652
+ sC: cute.Tensor,
1653
+ tRS_rD_layout: cutlass.Layout,
1654
+ tidx: Int32,
1655
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]:
1656
+ tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
1657
+ copy_atom_s2r = utils.sm90_get_smem_load_op(c_layout, dtype)
1658
+ tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
1659
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1660
+ tSR_sC = thr_copy_s2r.partition_S(sC)
1661
+ tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
1662
+ tSR_rC = thr_copy_s2r.retile(tRS_rC)
1663
+ return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
1664
+
1665
+ def epilog_gmem_copy_and_partition(
1666
+ self,
1667
+ atom: Union[cute.CopyAtom, cute.TiledCopy],
1668
+ mD_mnl: cute.Tensor,
1669
+ tile_shape_mn: cute.Tile,
1670
+ epi_tile: cute.Tile,
1671
+ sD: cute.Tensor,
1672
+ tile_coord_mnkl: cute.Coord,
1673
+ cu_seqlens_m: Optional[cute.Tensor] = None,
1674
+ ) -> Tuple[cute.Tensor, cute.Tensor]:
1675
+ batch_idx = tile_coord_mnkl[3]
1676
+ if const_expr(cu_seqlens_m is not None):
1677
+ mD_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl)
1678
+ else:
1679
+ mD_mn = mD_mnl[None, None, batch_idx]
1680
+ # (bM, bN)
1681
+ gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
1682
+ tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
1683
+ bSG_sD, bSG_gD = cpasync.tma_partition(
1684
+ atom,
1685
+ 0,
1686
+ cute.make_layout(1),
1687
+ cute.group_modes(sD, 0, 2),
1688
+ tDgD_for_tma_partition,
1689
+ )
1690
+ return bSG_sD, bSG_gD
1691
+
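In epilog_gmem_copy_and_partition above, the varlen-M case treats the output as one packed (total_m, n) tensor in which batch b occupies rows [cu_seqlens_m[b], cu_seqlens_m[b+1]), which is why the tile's domain is offset by cu_seqlens_m[batch_idx] instead of indexing a batch mode. Illustrative row arithmetic (plain Python, assumed names):

    def varlen_m_tile_rows(cu_seqlens_m, batch_idx, tile_coord_m, tile_m):
        # First and one-past-last packed row written by this (batch, tile_coord_m) tile.
        row_start = cu_seqlens_m[batch_idx] + tile_coord_m * tile_m
        return row_start, row_start + tile_m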
1692
+ def make_ab_pipeline(
1693
+ self,
1694
+ a_smem_layout: cute.Layout | cute.ComposedLayout,
1695
+ b_smem_layout: cute.Layout | cute.ComposedLayout,
1696
+ tiled_mma: cute.TiledMma,
1697
+ cluster_layout_vmnk: cute.Layout,
1698
+ ab_pipeline_mbar_ptr: cute.Pointer,
1699
+ ):
1700
+ # Threads/warps participating in this pipeline
1701
+ producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
1702
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
1703
+ # Each warp will contribute to the arrive count with the number of mcast size
1704
+ mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
1705
+ consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
1706
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(
1707
+ pipeline.Agent.Thread, consumer_arrive_cnt
1708
+ )
1709
+ pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
1710
+ tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
1711
+ if const_expr(not self.gather_A):
1712
+ tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
1713
+ return pipeline_cls.create(
1714
+ barrier_storage=ab_pipeline_mbar_ptr,
1715
+ num_stages=self.ab_stage,
1716
+ producer_group=ab_pipeline_producer_group,
1717
+ consumer_group=ab_pipeline_consumer_group,
1718
+ tx_count=tma_copy_bytes,
1719
+ cta_layout_vmnk=cluster_layout_vmnk,
1720
+ )
1721
+
1722
+ def make_epi_pipeline(
1723
+ self, c_smem_layout: cute.Layout | cute.ComposedLayout, epi_pipeline_mbar_ptr: cute.Pointer
1724
+ ):
1725
+ # Threads/warps participating in this pipeline
1726
+ epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1727
+ # Each warp will contribute 1 to the arrive count
1728
+ consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
1729
+ epi_pipeline_consumer_group = pipeline.CooperativeGroup(
1730
+ pipeline.Agent.Thread, consumer_arrive_cnt
1731
+ )
1732
+ tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
1733
+ return pipeline.PipelineTmaAsync.create(
1734
+ barrier_storage=epi_pipeline_mbar_ptr,
1735
+ num_stages=self.epi_c_stage,
1736
+ producer_group=epi_pipeline_producer_group,
1737
+ consumer_group=epi_pipeline_consumer_group,
1738
+ tx_count=tma_copy_c_bytes,
1739
+ )
1740
+
1741
+ def make_sched_pipeline(
1742
+ self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
1743
+ ):
1744
+ # Threads/warps participating in this pipeline
1745
+ sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1746
+ cluster_size = cute.size(cluster_layout_mnk)
1747
+ # Each warp that is not the scheduler warp will contribute 1 to the arrive count.
1748
+ # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
1749
+ # at each round. If pingpong and not varlen_k, then only 4 mma warps will participate.
1750
+ consumer_arrive_cnt = (
1751
+ (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
1752
+ + self.num_ab_load_warps
1753
+ ) * cluster_size - 1
1754
+ sched_pipeline_consumer_group = pipeline.CooperativeGroup(
1755
+ pipeline.Agent.Thread, consumer_arrive_cnt
1756
+ )
1757
+ return pipeline.PipelineAsync.create(
1758
+ barrier_storage=sched_pipeline_mbar_ptr,
1759
+ num_stages=self.sched_stage,
1760
+ producer_group=sched_pipeline_producer_group,
1761
+ consumer_group=sched_pipeline_consumer_group,
1762
+ # If there's a cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
1763
+ consumer_mask=None if const_expr(cluster_size == 1) else 0,
1764
+ )
1765
+
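To make the arrive-count formula above concrete, a worked instance with assumed values (pingpong=True, varlen_k=False, one A/B load warp, 2 CTAs per cluster):

    mma_factor = 1           # (mma_warp_groups if not (pingpong and not varlen_k) else 1)
    num_ab_load_warps = 1    # assumption for illustration
    cluster_size = 2         # assumption: 2x1 cluster
    consumer_arrive_cnt = (mma_factor * 4 + num_ab_load_warps) * cluster_size - 1  # = 9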
1766
+ @classmethod
1669
1767
  def _compute_stages(
1768
+ cls,
1670
1769
  tile_shape_mnk: Tuple[int, int, int],
1671
- epi_tile: Optional[Tuple[int, int]],
1770
+ epi_tile: Tuple[int, int],
1672
1771
  a_dtype: Type[cutlass.Numeric],
1673
1772
  b_dtype: Type[cutlass.Numeric],
1674
- d_dtype: Type[cutlass.Numeric],
1773
+ d_dtype: Optional[Type[cutlass.Numeric]],
1675
1774
  c_dtype: Optional[Type[cutlass.Numeric]],
1775
+ epilogue_args: Optional[EpilogueArguments],
1676
1776
  smem_capacity: int,
1677
1777
  occupancy: int,
1678
1778
  overlap_sD_sA: bool,
@@ -1695,13 +1795,18 @@ class HopperWgmmaGemmKernel:
1695
1795
  :rtype: Tuple[int, int, int]
1696
1796
  """
1697
1797
 
1698
- epi_stage = 2
1798
+ epi_stage = 4 if epi_tile[1] <= 16 else 2
1699
1799
  if overlap_sD_sA:
1700
1800
  epi_bytes = 0
1701
1801
  else:
1702
- d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8
1703
- epi_bytes = d_bytes_per_stage * epi_stage
1704
- epi_c_stage = 0 if c_dtype is None else 2
1802
+ d_bytes_per_stage = (
1803
+ cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
1804
+ )
1805
+ epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
1806
+ epilogue_args, tile_shape_mnk, epi_tile
1807
+ )
1808
+ epi_bytes = epi_bytes_per_stage * epi_stage
1809
+ epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
1705
1810
  if c_dtype is not None:
1706
1811
  epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
1707
1812
 
@@ -1712,23 +1817,21 @@ class HopperWgmmaGemmKernel:
1712
1817
  )
1713
1818
  mbar_helpers_bytes = 1024
1714
1819
 
1715
- remaining_bytes = (
1716
- (smem_capacity - occupancy * 1024) // occupancy - mbar_helpers_bytes - epi_bytes
1717
- )
1820
+ remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
1718
1821
  ab_stage = remaining_bytes // ab_bytes_per_stage
1719
1822
 
1720
1823
  # Refine epilogue stages:
1721
1824
  # Calculate remaining smem after allocating for A/B stages and reserved bytes
1722
1825
  # Add remaining unused smem to epilogue
1723
- if not overlap_sD_sA:
1724
- epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // d_bytes_per_stage
1826
+ if not overlap_sD_sA and epi_bytes_per_stage > 0:
1827
+ epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
1725
1828
  return ab_stage, epi_stage, epi_c_stage
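A worked instance of the stage budgeting above, with assumed numbers rather than any particular configuration:

    # Assumed: ~228 KB dynamic smem per SM, occupancy 1, BF16 A/B, a 128x256x64 tile,
    # a (64, 32) epi tile with epi_stage = 2 and no C tensor.
    occupancy = 1
    smem_capacity = 228 * 1024
    ab_bytes_per_stage = (128 * 64 + 256 * 64) * 2                    # A tile + B tile, 2 bytes/elem
    epi_bytes = (64 * 32 * 2) * 2                                     # D sub-tile bytes * epi_stage
    remaining_bytes = smem_capacity // occupancy - 1024 - epi_bytes   # minus mbar helpers, epilogue
    ab_stage = remaining_bytes // ab_bytes_per_stage                  # -> 4 A/B stages with these numbers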
1726
1829
 
1727
1830
  @staticmethod
1728
1831
  def _sm90_compute_tile_shape_or_override(
1729
1832
  tile_shape_mnk: Tuple[int, int, int],
1730
1833
  atom_layout_mnk: Tuple[int, int, int],
1731
- element_type: Type[cutlass.Numeric],
1834
+ element_type: Optional[Type[cutlass.Numeric]] = None,
1732
1835
  epi_tile_override: Tuple[int, int] | None = None,
1733
1836
  ) -> Tuple[int, int]:
1734
1837
  """Compute the epilogue tile shape or use override if provided.
@@ -1760,7 +1863,7 @@ class HopperWgmmaGemmKernel:
1760
1863
  # iterate along the N dimension first, then move to the M dimension.
1761
1864
  # We could change the epilogue to accommodate this,
1762
1865
  # but it's easier to just set epi_tile_m = 64.
1763
- n_perf = 64 if element_type.width == 8 else 32
1866
+ n_perf = 64 if element_type is not None and element_type.width == 8 else 32
1764
1867
  tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
1765
1868
  tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
1766
1869
  return (tile_m, tile_n)
@@ -1770,15 +1873,15 @@ class HopperWgmmaGemmKernel:
1770
1873
  tile_shape_mnk: Tuple[int, int, int],
1771
1874
  epi_tile: Tuple[int, int],
1772
1875
  a_dtype: Type[cutlass.Numeric],
1773
- a_layout: cutlass.utils.LayoutEnum,
1876
+ a_layout: LayoutEnum,
1774
1877
  b_dtype: Type[cutlass.Numeric],
1775
- b_layout: cutlass.utils.LayoutEnum,
1878
+ b_layout: LayoutEnum,
1776
1879
  ab_stage: int,
1777
- d_dtype: Type[cutlass.Numeric],
1778
- d_layout: cutlass.utils.LayoutEnum,
1880
+ d_dtype: Optional[Type[cutlass.Numeric]],
1881
+ d_layout: LayoutEnum,
1779
1882
  epi_stage: int,
1780
1883
  c_dtype: Optional[Type[cutlass.Numeric]],
1781
- c_layout: Optional[cutlass.utils.LayoutEnum],
1884
+ c_layout: Optional[LayoutEnum],
1782
1885
  epi_c_stage: int,
1783
1886
  ) -> Tuple[
1784
1887
  cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
@@ -1792,17 +1895,17 @@ class HopperWgmmaGemmKernel:
1792
1895
  :param a_dtype: Data type for matrix A
1793
1896
  :type a_dtype: type[cutlass.Numeric]
1794
1897
  :param a_layout: Layout enum for matrix A
1795
- :type a_layout: cutlass.utils.LayoutEnum
1898
+ :type a_layout: LayoutEnum
1796
1899
  :param b_dtype: Data type for matrix B
1797
1900
  :type b_dtype: type[cutlass.Numeric]
1798
1901
  :param b_layout: Layout enum for matrix B
1799
- :type b_layout: cutlass.utils.LayoutEnum
1902
+ :type b_layout: LayoutEnum
1800
1903
  :param ab_stage: Number of stages for A/B tensors
1801
1904
  :type ab_stage: int
1802
- :param d_dtype: Data type for output matrix C
1905
+ :param d_dtype: Data type for output matrix D
1803
1906
  :type d_dtype: type[cutlass.Numeric]
1804
1907
  :param d_layout: Layout enum for the output matrix D
1805
- :type d_layout: cutlass.utils.LayoutEnum
1908
+ :type d_layout: LayoutEnum
1806
1909
  :param epi_stage: Number of epilogue stages
1807
1910
  :type epi_stage: int
1808
1911
 
@@ -1815,11 +1918,7 @@ class HopperWgmmaGemmKernel:
1815
1918
  b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
1816
1919
  a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
1817
1920
  a_smem_layout_atom = warpgroup.make_smem_layout_atom(
1818
- sm90_utils.get_smem_layout_atom(
1819
- a_layout,
1820
- a_dtype,
1821
- a_major_mode_size,
1822
- ),
1921
+ sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
1823
1922
  a_dtype,
1824
1923
  )
1825
1924
  a_smem_layout_staged = cute.tile_to_shape(
@@ -1832,11 +1931,7 @@ class HopperWgmmaGemmKernel:
1832
1931
 
1833
1932
  b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
1834
1933
  b_smem_layout_atom = warpgroup.make_smem_layout_atom(
1835
- sm90_utils.get_smem_layout_atom(
1836
- b_layout,
1837
- b_dtype,
1838
- b_major_mode_size,
1839
- ),
1934
+ sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
1840
1935
  b_dtype,
1841
1936
  )
1842
1937
  b_smem_layout_staged = cute.tile_to_shape(
@@ -1845,17 +1940,20 @@ class HopperWgmmaGemmKernel:
1845
1940
  order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
1846
1941
  )
1847
1942
 
1848
- d_smem_shape = epi_tile
1849
- d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
1850
- d_smem_layout_atom = warpgroup.make_smem_layout_atom(
1851
- sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
1852
- d_dtype,
1853
- )
1854
- epi_smem_layout_staged = cute.tile_to_shape(
1855
- d_smem_layout_atom,
1856
- cute.append(d_smem_shape, epi_stage),
1857
- order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
1858
- )
1943
+ if d_dtype is not None:
1944
+ d_smem_shape = epi_tile
1945
+ d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
1946
+ d_smem_layout_atom = warpgroup.make_smem_layout_atom(
1947
+ sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
1948
+ d_dtype,
1949
+ )
1950
+ epi_smem_layout_staged = cute.tile_to_shape(
1951
+ d_smem_layout_atom,
1952
+ cute.append(d_smem_shape, epi_stage),
1953
+ order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
1954
+ )
1955
+ else:
1956
+ epi_smem_layout_staged = None
1859
1957
 
1860
1958
  if c_dtype is not None:
1861
1959
  assert c_layout is not None
@@ -1961,7 +2059,7 @@ class HopperWgmmaGemmKernel:
1961
2059
  thread_layout = cute.make_layout(
1962
2060
  (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
1963
2061
  )
1964
- if major_mode != cutlass.utils.LayoutEnum.ROW_MAJOR:
2062
+ if major_mode != LayoutEnum.ROW_MAJOR:
1965
2063
  shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
1966
2064
  thread_layout = cute.make_layout(
1967
2065
  (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
@@ -1969,7 +2067,7 @@ class HopperWgmmaGemmKernel:
1969
2067
  # Value layout for copy
1970
2068
  value_layout = (
1971
2069
  cute.make_layout((1, copy_elems))
1972
- if major_mode == cutlass.utils.LayoutEnum.ROW_MAJOR
2070
+ if major_mode == LayoutEnum.ROW_MAJOR
1973
2071
  else cute.make_layout((copy_elems, 1))
1974
2072
  )
1975
2073
  return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
@@ -1979,7 +2077,7 @@ class HopperWgmmaGemmKernel:
1979
2077
  a_dtype: Type[cutlass.Numeric],
1980
2078
  b_dtype: Type[cutlass.Numeric],
1981
2079
  acc_dtype: Type[cutlass.Numeric],
1982
- d_dtype: Type[cutlass.Numeric],
2080
+ d_dtype: Optional[Type[cutlass.Numeric]],
1983
2081
  a_major: str,
1984
2082
  b_major: str,
1985
2083
  ) -> bool:
@@ -2022,6 +2120,7 @@ class HopperWgmmaGemmKernel:
2022
2120
  is_valid = False
2023
2121
  # tested d_dtype
2024
2122
  if d_dtype not in {
2123
+ None,
2025
2124
  cutlass.Float32,
2026
2125
  cutlass.Float16,
2027
2126
  cutlass.BFloat16,
@@ -2039,436 +2138,108 @@ class HopperWgmmaGemmKernel:
2039
2138
  # for Float8 types, this implementation only supports k-major layout
2040
2139
  if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
2041
2140
  is_valid = False
2042
-
2043
2141
  return is_valid
2044
2142
 
2045
2143
 
2046
- def run(
2047
- mnkl: Tuple[int, int, int, int],
2048
- a_dtype: Type[cutlass.Numeric],
2049
- b_dtype: Type[cutlass.Numeric],
2050
- d_dtype: Type[cutlass.Numeric],
2051
- c_dtype: Optional[Type[cutlass.Numeric]],
2052
- acc_dtype: Type[cutlass.Numeric],
2053
- a_major: str,
2054
- b_major: str,
2055
- d_major: str,
2056
- c_major: str,
2057
- tile_shape_mnk: Tuple[int, int, int],
2058
- cluster_shape_mn: Tuple[int, int],
2059
- tolerance: float,
2060
- warmup_iterations: int,
2061
- iterations: int,
2062
- skip_ref_check: bool,
2063
- persistent: bool,
2064
- dynamic_persistent: bool,
2065
- pingpong: bool,
2066
- varlen_m: bool,
2067
- gather_A: bool,
2068
- fp8_fast_accum: bool,
2069
- **kwargs,
2070
- ):
2071
- """
2072
- Prepare A/B/D/C tensors, launch GPU kernel, and reference checking.
2073
-
2074
- :param mnkl: Problem size (M, N, K, L)
2075
- :type mnkl: Tuple[int, int, int, int]
2076
- :param a_dtype: Data type for input tensor A
2077
- :type a_dtype: Type[cutlass.Numeric]
2078
- :param b_dtype: Data type for input tensor B
2079
- :type b_dtype: Type[cutlass.Numeric]
2080
- :param d_dtype: Data type for output tensor C
2081
- :type d_dtype: Type[cutlass.Numeric]
2082
- :param acc_dtype: Data type for accumulation during matrix multiplication
2083
- :type acc_dtype: Type[cutlass.Numeric]
2084
- :param a_major/b_major/d_major: Memory layout of tensor A/B/C
2085
- :type a_major/b_major/d_major: str
2086
- :param tile_shape_mnk: CTA tile shape (M, N, K)
2087
- :type tile_shape_mnk: Tuple[int, int, int]
2088
- :param cluster_shape_mn: Cluster shape (M, N)
2089
- :type cluster_shape_mn: Tuple[int, int]
2090
- :param tolerance: Tolerance value for reference validation comparison
2091
- :type tolerance: float
2092
- :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
2093
- :type warmup_iterations: int, optional
2094
- :param iterations: Number of benchmark iterations to run, defaults to 1
2095
- :type iterations: int, optional
2096
- :param skip_ref_check: Whether to skip reference result validation, defaults to False
2097
- :type skip_ref_check: bool, optional
2098
- """
2099
-
2100
- if dynamic_persistent:
2101
- persistent = True
2102
-
2103
- print("Running Hopper Dense GEMM with:")
2104
- print(f"mnkl: {mnkl}")
2105
- print(
2106
- f"A dtype: {a_dtype}, B dtype: {b_dtype}, D dtype: {d_dtype}, C_dtype: {c_dtype}, Acc dtype: {acc_dtype}"
2107
- )
2108
- print(f"Matrix majors - A: {a_major}, B: {b_major}, D: {d_major}")
2109
- print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
2110
- print(f"Tolerance: {tolerance}")
2111
- print(f"Warmup iterations: {warmup_iterations}")
2112
- print(f"Iterations: {iterations}")
2113
- print(f"Skip reference checking: {skip_ref_check}")
2114
-
2115
- # Unpack parameters
2116
- m, n, k, l = mnkl
2117
- cluster_shape_mnk = (*cluster_shape_mn, 1)
2118
-
2119
- # Skip unsupported types
2120
- if not HopperWgmmaGemmKernel.is_valid_dtypes(
2121
- a_dtype, b_dtype, acc_dtype, d_dtype, a_major, b_major
2144
+ def gemm_sm90(
2145
+ A: Tensor, # (l, m, k)
2146
+ B: Tensor, # (l, n, k)
2147
+ D: Tensor, # (l, m, n)
2148
+ C: Optional[Tensor], # (l, m, n)
2149
+ tile_count_semaphore: Optional[Tensor], # (1,)
2150
+ tile_M: int,
2151
+ tile_N: int,
2152
+ cluster_M: int,
2153
+ cluster_N: int,
2154
+ pingpong: bool = False,
2155
+ persistent: bool = True,
2156
+ alpha: float | Tensor = 1.0,
2157
+ beta: float | Tensor = 1.0,
2158
+ ) -> None:
2159
+ L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(A, B, D, C)
2160
+ GemmWrapperBase.permute_tensors(tensor_infos)
2161
+ GemmWrapperBase.extract_dtypes(tensor_infos)
2162
+ major_configs = {
2163
+ "A": ("m", "k", "l"),
2164
+ "B": ("n", "k", "l"),
2165
+ "D": ("m", "n", "l"),
2166
+ "C": ("m", "n", "l"),
2167
+ }
2168
+ GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
2169
+
2170
+ acc_dtype = cutlass.Float32
2171
+ tile_shape_mn = (tile_M, tile_N)
2172
+ cluster_shape_mnk = (cluster_M, cluster_N, 1)
2173
+ if not GemmSm90.is_valid_dtypes(
2174
+ tensor_infos["A"].dtype,
2175
+ tensor_infos["B"].dtype,
2176
+ acc_dtype,
2177
+ tensor_infos["D"].dtype,
2178
+ tensor_infos["A"].major,
2179
+ tensor_infos["B"].major,
2122
2180
  ):
2123
- raise TypeError(
2124
- f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {d_dtype}, {a_major=}, {b_major=}"
2125
- )
2181
+ raise TypeError("Skipping due to unsupported combination of types and majors")
2126
2182
 
2127
- # Prepare pytorch tensors: A, B (random from 0 to 2) and C (all zero)
2128
- if not torch.cuda.is_available():
2129
- raise RuntimeError("GPU is required to run this example!")
2130
-
2131
- torch.manual_seed(1111)
2132
-
2133
- # Create and permute tensor A/B/C
2134
- def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
2135
- # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
2136
- # else : (l, mode0, mode1) -> (mode0, mode1, l)
2137
- shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
2138
- permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
2139
- is_unsigned = dtype in {cutlass.Uint8}
2140
- # Temporarily use uint8 as torch does not support fp8 type
2141
- torch_dtype = cutlass_torch.dtype(dtype)
2142
- gen_dtype = (
2143
- torch_dtype
2144
- if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
2145
- else torch.bfloat16
2146
- )
2147
-
2148
- # Create dtype torch tensor (cpu)
2149
- torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
2150
- shape,
2151
- gen_dtype,
2152
- permute_order=permute_order,
2153
- # init_type=cutlass.torch.TensorInitType.RANDOM,
2154
- # init_config=cutlass.torch.RandomInitConfig(
2155
- # min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
2156
- # ),
2157
- init_type=cutlass.torch.TensorInitType.GAUSSIAN,
2158
- init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
2159
- ).to(torch_dtype)
2160
- # Create dtype torch tensor (gpu)
2161
- torch_tensor = torch_tensor_cpu.cuda()
2162
-
2163
- # Create f32 torch tensor (cpu)
2164
- f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
2165
-
2166
- # Create dtype cute tensor (gpu)
2167
- torch_tensor_view = (
2168
- torch_tensor
2169
- if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
2170
- else torch_tensor.view(torch.uint8)
2171
- )
2172
- cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
2173
- cute_tensor.element_type = dtype
2174
- if is_dynamic_layout:
2175
- cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
2176
- cute_tensor = cute_tensor.mark_compact_shape_dynamic(
2177
- mode=(1 if not is_mode0_major else 0),
2178
- stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
2179
- divisibility=(128 // dtype.width),
2180
- )
2181
- cute_tensor = cutlass.torch.convert_cute_tensor(
2182
- f32_torch_tensor,
2183
- cute_tensor,
2184
- dtype,
2185
- is_dynamic_layout=is_dynamic_layout,
2186
- )
2183
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
2184
+ GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
2187
2185
 
2188
- return f32_torch_tensor, cute_tensor, torch_tensor
2189
-
2190
- a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
2191
- if gather_A:
2192
- assert a_major == "k"
2193
- a_idx = torch.randperm(l * m, dtype=torch.int32, device="cuda")
2194
- from einops import rearrange
2195
-
2196
- a = rearrange(rearrange(a, "m k l -> (m l) k")[a_idx.cpu()], "(m l) k -> m k l", m=m)
2197
- a_torch = rearrange(a_torch, "m k l -> (m l) k")
2198
- mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2199
- a_idx_reshaped = rearrange(a_idx, "(m l) -> l m", m=m).contiguous().transpose(0, 1)
2200
- mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
2201
- else:
2202
- mAIdx = None
2203
- b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
2204
- _, mD, d_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
2205
- if c_dtype is not None:
2206
- c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
2207
- else:
2208
- c, mC, c_torch = None, None, None
2209
- if varlen_m:
2210
- assert a_major == "k"
2211
- assert d_major == "n"
2212
- from einops import rearrange
2213
-
2214
- a, d_torch = [rearrange(t, "m x l -> (l m) x") for t in (a, d_torch)]
2215
- if not gather_A:
2216
- (a_torch,) = [rearrange(t, "m x l -> (l m) x") for t in (a_torch,)]
2217
- if c_dtype is not None:
2218
- c, c_torch = [rearrange(t, "m x l -> (l m) x") for t in (c, c_torch)]
2219
- mC = from_dlpack(c_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2220
- mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2221
- mD = from_dlpack(d_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2222
- # TODO: generate random cu_seqlens_m
2223
- cu_seqlens_m = torch.arange(0, l + 1, dtype=torch.int32, device="cuda") * m
2224
- mCuSeqlensM = from_dlpack(cu_seqlens_m, assumed_align=64).mark_layout_dynamic(leading_dim=0)
2225
- if gather_A:
2226
- a_idx_reshaped = rearrange(a_idx_reshaped, "m l -> (l m)")
2227
- mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
2228
- else:
2229
- cu_seqlens_m, mCuSeqlensM = None, None
2230
-
2231
- if varlen_m: # Need to allocate space in gmem to store tensormaps
2232
- if not persistent:
2233
- total_m = m * l
2234
- block_size_m = tile_shape_mnk[0] * cluster_shape_mnk[0]
2235
- block_size_n = tile_shape_mnk[1] * cluster_shape_mnk[1]
2236
- total_clusters_m_max = (total_m + l * (block_size_m - 1)) // block_size_m
2237
- total_clusters_max = total_clusters_m_max * ((n + block_size_n - 1) // block_size_n)
2238
- total_ctas = total_clusters_max * cluster_shape_mnk[0] * cluster_shape_mnk[1]
2186
+ def scalar_arg(scalar: float | Tensor):
2187
+ if isinstance(scalar, float):
2188
+ return Float32(scalar) if scalar != 1.0 else None
2239
2189
  else:
2240
- total_ctas = cutlass.utils.HardwareInfo().get_device_multiprocessor_count()
2241
- if pingpong:
2242
- total_ctas *= 2
2243
- # 128 bytes per tensormap
2244
- tensormaps_torch = torch.empty(total_ctas, 128 // 8, dtype=torch.int64, device="cuda")
2245
- tensormaps_tensor = from_dlpack(
2246
- tensormaps_torch, assumed_align=128
2247
- ).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
2248
- else:
2249
- tensormaps_tensor = None
2250
-
2251
- gemm = HopperWgmmaGemmKernel(
2252
- acc_dtype,
2253
- a_dtype,
2254
- tile_shape_mnk,
2190
+ assert isinstance(scalar, Tensor)
2191
+ return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
2192
+
2193
+ epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta))
2194
+ scheduler_args = GemmWrapperBase.create_scheduler_args(
2195
+ max_active_clusters, tile_count_semaphore
2196
+ )
2197
+ current_stream = cutlass_torch.current_stream()
2198
+ compile_key = GemmWrapperBase.get_compile_key(
2199
+ tensor_infos,
2200
+ None,
2201
+ tile_shape_mn,
2255
2202
  cluster_shape_mnk,
2256
- pingpong=pingpong,
2257
- is_persistent=persistent,
2258
- fp8_fast_accum=fp8_fast_accum,
2259
- gather_A=gather_A,
2203
+ pingpong,
2204
+ persistent,
2205
+ tile_count_semaphore is not None,
2206
+ 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
2207
+ 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
2208
+ key_tensor_names=("A", "B", "D", "C"),
2260
2209
  )
2261
-
2262
- # Compute max active clusters on current device
2263
- if persistent:
2264
- max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
2265
- cluster_shape_mn[0] * cluster_shape_mn[1]
2210
+ cache = gemm_sm90.compile_cache
2211
+ if compile_key not in cache:
2212
+ gemm = GemmSm90(
2213
+ acc_dtype,
2214
+ tensor_infos["A"].dtype,
2215
+ tile_shape_mn,
2216
+ cluster_shape_mnk,
2217
+ pingpong=pingpong,
2218
+ is_persistent=persistent,
2266
2219
  )
2267
- if dynamic_persistent:
2268
- tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda")
2269
- else:
2270
- tile_count_semaphore = None
2271
- # max_active_clusters = 1
2272
- else:
2273
- max_active_clusters = 0
2274
- tile_count_semaphore = None
2275
-
2276
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
2277
- # compile gemm kernel
2278
- compiled_gemm = cute.compile(
2279
- gemm,
2280
- mA,
2281
- mB,
2282
- mD,
2283
- mC,
2284
- mAIdx,
2285
- mCuSeqlensM,
2286
- tensormaps_tensor,
2287
- make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
2288
- if tile_count_semaphore is not None
2289
- else None,
2290
- max_active_clusters,
2291
- current_stream,
2292
- )
2293
-
2294
- if not skip_ref_check:
2295
- # execution
2296
- compiled_gemm(
2297
- mA,
2298
- mB,
2299
- mD,
2300
- mC,
2301
- mAIdx,
2302
- mCuSeqlensM,
2303
- tensormaps_tensor,
2304
- tile_count_semaphore,
2305
- max_active_clusters,
2220
+ cache[compile_key] = cute.compile(
2221
+ gemm,
2222
+ tensor_infos["A"].cute_tensor,
2223
+ tensor_infos["B"].cute_tensor,
2224
+ tensor_infos["D"].cute_tensor,
2225
+ tensor_infos["C"].cute_tensor,
2226
+ epi_args,
2227
+ scheduler_args,
2228
+ None, # varlen_args
2229
+ None, # mAIdx
2306
2230
  current_stream,
2307
2231
  )
2308
- if tile_count_semaphore is not None and varlen_m:
2309
- tile_count_semaphore.zero_()
2310
-
2311
- torch.cuda.synchronize()
2312
-
2313
- # Ref check
2314
- if not varlen_m:
2315
- ref = torch.einsum("mkl,nkl->mnl", a, b)
2316
- else:
2317
- ref = torch.cat(
2318
- [
2319
- torch.einsum("mk,nk->mn", a[cu_seqlens_m[i] : cu_seqlens_m[i + 1]], b[:, :, i])
2320
- for i in range(l)
2321
- ],
2322
- dim=0,
2323
- )
2324
- if c is not None:
2325
- ref = ref + c
2326
- ref = ref.cpu()
2327
-
2328
- if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
2329
- # m major: (l, n, m) -> (m, n, l)
2330
- # n major: (l, m, n) -> (m, n, l)
2331
- permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
2332
- shape = (l, m, n) if d_major == "n" else (l, n, m)
2333
- f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
2334
- shape,
2335
- torch.uint8,
2336
- permute_order=permute_order,
2337
- init_type=cutlass_torch.TensorInitType.SKIP,
2338
- ).cuda()
2339
- # Create dtype cute tensor (gpu)
2340
- ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
2341
- leading_dim=(1 if d_major == "n" else 0)
2342
- )
2343
- ref_d_tensor.element_type = d_dtype
2344
- ref_d_tensor = cutlass_torch.convert_cute_tensor(
2345
- ref,
2346
- ref_d_tensor,
2347
- d_dtype,
2348
- is_dynamic_layout=True,
2349
- )
2350
- ref_d = f8_torch_tensor.cpu()
2351
- else:
2352
- ref_d = ref.to(cutlass_torch.dtype(d_dtype))
2353
-
2354
- out = d_torch.cpu().squeeze()
2355
- out_ref = ref_d.squeeze()
2356
- # breakpoint()
2357
- torch.testing.assert_close(d_torch.cpu(), ref_d, atol=tolerance, rtol=1e-03)
2358
-
2359
- # return
2360
-
2361
- from triton.testing import do_bench
2362
-
2363
- flops = 2 * m * n * k * l
2364
- # Calculate memory bandwidth
2365
- bytes_A = m * k * l * (a_dtype.width // 8) # A tensor: (m, k, l)
2366
- bytes_B = n * k * l * (b_dtype.width // 8) # B tensor: (n, k, l)
2367
- bytes_D = m * n * l * (d_dtype.width // 8) # D tensor: (m, n, l)
2368
- bytes_C = m * n * l * (c_dtype.width // 8) if c_dtype is not None else 0 # C tensor: (m, n, l)
2369
- total_bytes = bytes_A + bytes_B + bytes_D + bytes_C # Read A, B, C; Write D
2370
-
2371
- repeats = iterations
2372
- warmup = warmup_iterations
2373
-
2374
- import time
2375
-
2376
- if not varlen_m and not gather_A:
2377
- time.sleep(0.5)
2378
- if a_dtype.width == 8:
2379
- assert l == 1
2380
- scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda")
2381
- fn_cublas = lambda: torch._scaled_mm(
2382
- a_torch[:, :, 0],
2383
- b_torch[:, :, 0].mT,
2384
- scale_a=scale_ab,
2385
- scale_b=scale_ab,
2386
- out_dtype=torch.bfloat16,
2387
- use_fast_accum=fp8_fast_accum,
2388
- )
2389
- else:
2390
- if c_torch is None:
2391
- fn_cublas = lambda: torch.matmul(
2392
- a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT
2393
- )
2394
- else:
2395
- c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32
2396
- fn_cublas = lambda: torch.baddbmm(
2397
- c_torch_convert.permute(2, 0, 1),
2398
- a_torch.permute(2, 0, 1),
2399
- b_torch.permute(2, 0, 1).mT,
2400
- )
2401
- timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
2402
- tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
2403
- print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2232
+ cache[compile_key](
2233
+ tensor_infos["A"].cute_tensor,
2234
+ tensor_infos["B"].cute_tensor,
2235
+ tensor_infos["D"].cute_tensor,
2236
+ tensor_infos["C"].cute_tensor,
2237
+ epi_args,
2238
+ scheduler_args,
2239
+ None,
2240
+ None,
2241
+ current_stream,
2242
+ )
2404
2243
 
2405
- time.sleep(0.5)
2406
2244
 
2407
- def fn():
2408
- compiled_gemm(
2409
- mA,
2410
- mB,
2411
- mD,
2412
- mC,
2413
- mAIdx,
2414
- mCuSeqlensM,
2415
- tensormaps_tensor,
2416
- tile_count_semaphore,
2417
- max_active_clusters,
2418
- current_stream,
2419
- )
2420
- if tile_count_semaphore is not None and varlen_m:
2421
- tile_count_semaphore.zero_()
2422
-
2423
- timing = do_bench(fn, warmup=warmup, rep=repeats)
2424
- # Idk why but for some cases the 1st run is much slower
2425
- time.sleep(0.5)
2426
- timing = do_bench(fn, warmup=warmup, rep=repeats)
2427
- tflops = flops / (timing * 1e9) # Convert to TFlops
2428
- gbps = total_bytes / (timing * 1e6) # Convert to GB/s (1e9 for ms->s, 1e9 for B->GB)
2429
- print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}, GB/s: {gbps:.0f}")
2430
- fn()
2431
-
2432
- if not varlen_m:
2433
- time.sleep(0.5)
2434
- timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
2435
- tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
2436
- print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2437
-
2438
- from flash_attn.utils.benchmark import pytorch_profiler
2439
-
2440
- pytorch_profiler(fn_cublas)
2441
- # pytorch_profiler(torch.sort, d_torch.squeeze(), dim=-1)
2442
- # pytorch_profiler(torch.compile(torch.sort), d_torch.squeeze(), dim=-1)
2443
- # pytorch_profiler(torch.topk, d_torch.squeeze(), dim=-1, k=1)
2444
- # pytorch_profiler(torch.compile(torch.topk), d_torch.squeeze(), dim=-1, k=1)
2445
- # pytorch_profiler(torch.square, d_torch.squeeze())
2446
-
2447
-
2448
- if __name__ == "__main__":
2449
- args = parse_arguments()
2450
- run(
2451
- args.mnkl,
2452
- args.a_dtype,
2453
- args.b_dtype,
2454
- args.d_dtype,
2455
- args.c_dtype,
2456
- args.acc_dtype,
2457
- args.a_major,
2458
- args.b_major,
2459
- args.d_major,
2460
- args.c_major,
2461
- args.tile_shape_mnk,
2462
- args.cluster_shape_mn,
2463
- args.tolerance,
2464
- args.warmup_iterations,
2465
- args.iterations,
2466
- args.skip_ref_check,
2467
- args.persistent,
2468
- args.dynamic_persistent,
2469
- args.pingpong,
2470
- args.varlen_m,
2471
- args.gather_A,
2472
- args.fp8_fast_accum,
2473
- )
2474
- print("PASS")
2245
+ gemm_sm90.compile_cache = {}
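A minimal call-site sketch for the gemm_sm90 wrapper above (shapes, dtypes, and tile/cluster sizes are assumptions; per the signature A is (l, m, k), B is (l, n, k), D is (l, m, n)):

    import torch

    L, M, N, K = 1, 4096, 4096, 4096
    A = torch.randn(L, M, K, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(L, N, K, device="cuda", dtype=torch.bfloat16)
    D = torch.empty(L, M, N, device="cuda", dtype=torch.bfloat16)
    gemm_sm90(A, B, D, C=None, tile_count_semaphore=None,
              tile_M=128, tile_N=256, cluster_M=2, cluster_N=1,
              pingpong=False, persistent=True)

The first call compiles the kernel and stores it under a key derived from the dtypes, majors, and tile/cluster choices; subsequent calls with the same key reuse the cached compilation.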