quack-kernels 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2088 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+
7
+ # 1. Redistributions of source code must retain the above copyright notice, this
8
+ # list of conditions and the following disclaimer.
9
+
10
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
11
+ # this list of conditions and the following disclaimer in the documentation
12
+ # and/or other materials provided with the distribution.
13
+
14
+ # 3. Neither the name of the copyright holder nor the names of its
15
+ # contributors may be used to endorse or promote products derived from
16
+ # this software without specific prior written permission.
17
+
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+ import argparse
30
+ import enum
31
+ from typing import Tuple, Type, Callable, Optional
32
+ from functools import partial
33
+ import math
34
+
35
+ import cuda.bindings.driver as cuda
36
+
37
+ import torch
38
+
39
+ import cutlass
40
+ import cutlass.cute as cute
41
+ import cutlass.pipeline as pipeline
42
+ import cutlass.torch as cutlass_torch
43
+ from cutlass.cute.runtime import from_dlpack
44
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
45
+ import cutlass.utils.hopper_helpers as sm90_utils
46
+ from cutlass import Int32, const_expr
47
+
48
+ from quack.tile_scheduler import (
49
+ TileSchedulerArguments,
50
+ ParamsBase,
51
+ RasterOrderOption,
52
+ TriangularTileScheduler,
53
+ )
54
+ from quack.reduction_base import torch2cute_dtype_map
55
+
56
+ # return PipelineStateWAdvance instead of PipelineState
57
+ from quack.pipeline import make_pipeline_state
58
+ import quack.utils as utils
59
+
60
+ from functools import lru_cache
61
+
62
+
63
+ # /////////////////////////////////////////////////////////////////////////////
64
+ # Helpers to parse args
65
+ # /////////////////////////////////////////////////////////////////////////////
66
+ def parse_comma_separated_ints(s: str):
67
+ try:
68
+ return tuple([int(x.strip()) for x in s.split(",")])
69
+ except ValueError:
70
+ raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")
71
+
72
+
73
+ def parse_arguments() -> argparse.Namespace:
74
+ parser = argparse.ArgumentParser(description="Example of MxNxKxL symmetric GEMM (D = alpha * A @ A^T + beta * C) on Hopper.")
75
+
76
+ parser.add_argument(
77
+ "--mnkl",
78
+ type=parse_comma_separated_ints,
79
+ default=(4096, 4096, 4096, 1),
80
+ help="mnkl dimensions (comma-separated)",
81
+ )
82
+ parser.add_argument("--alpha", type=float, default=1.0, help="Scalar multiplier for A @ B")
83
+ parser.add_argument("--beta", type=float, default=1.0, help="Scalar multiplier for C")
84
+ parser.add_argument(
85
+ "--alpha_dtype",
86
+ type=cutlass.dtype,
87
+ default=cutlass.Float32,
88
+ help="Data type for alpha scalar",
89
+ )
90
+ parser.add_argument(
91
+ "--beta_dtype",
92
+ type=cutlass.dtype,
93
+ default=cutlass.Float32,
94
+ help="Data type for beta scalar",
95
+ )
96
+ parser.add_argument(
97
+ "--tile_shape_mnk",
98
+ type=parse_comma_separated_ints,
99
+ default=(128, 256, 64),
100
+ help="Cta tile shape (comma-separated)",
101
+ )
102
+ parser.add_argument(
103
+ "--cluster_shape_mn",
104
+ type=parse_comma_separated_ints,
105
+ choices=[(2, 1)],
106
+ default=(2, 1),
107
+ help="Cluster shape (comma-separated)",
108
+ )
109
+ parser.add_argument(
110
+ "--a_dtype",
111
+ type=cutlass.dtype,
112
+ default=cutlass.BFloat16,
113
+ )
114
+ parser.add_argument(
115
+ "--b_dtype",
116
+ type=cutlass.dtype,
117
+ default=cutlass.BFloat16,
118
+ )
119
+ parser.add_argument(
120
+ "--d_dtype",
121
+ type=cutlass.dtype,
122
+ default=cutlass.BFloat16,
123
+ )
124
+ parser.add_argument(
125
+ "--c_dtype",
126
+ type=cutlass.dtype,
127
+ default=None,
128
+ )
129
+ parser.add_argument(
130
+ "--acc_dtype",
131
+ type=cutlass.dtype,
132
+ default=cutlass.Float32,
133
+ )
134
+ parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
135
+ parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
136
+ parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
137
+ parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
138
+ parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation")
139
+ parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
140
+ parser.add_argument(
141
+ "--iterations",
142
+ type=int,
143
+ default=30,
144
+ help="Number of iterations to run the kernel",
145
+ )
146
+ parser.add_argument("--persistent", action="store_true", help="Persistent kernel")
147
+ parser.add_argument("--pingpong", action="store_true", help="Pingpong kernel")
148
+ parser.add_argument("--fp8_fast_accum", action="store_true", help="FP8 fast accum")
149
+ parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
150
+
151
+ args = parser.parse_args()
152
+
153
+ if len(args.mnkl) != 4:
154
+ parser.error("--mnkl must contain exactly 4 values")
155
+ if len(args.tile_shape_mnk) != 3:
156
+ parser.error("--tile_shape_mnk must contain exactly 3 values")
157
+ if len(args.cluster_shape_mn) != 2:
158
+ parser.error("--cluster_shape_mn must contain exactly 2 values")
159
+
160
+ return args
161
+
162
+
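+ # Illustrative invocation of the parser above (a sketch; the script name is a
+ # placeholder and only flags defined in parse_arguments() are used):
+ #   python hopper_symmetric_gemm.py --mnkl 4096,4096,4096,1 \
+ #       --tile_shape_mnk 128,256,64 --cluster_shape_mn 2,1 --persistent
+ # parse_comma_separated_ints("4096,4096,4096,1") returns (4096, 4096, 4096, 1), and
+ # parse_arguments() rejects the run unless mnkl has exactly 4 entries,
+ # tile_shape_mnk has 3, and cluster_shape_mn has 2.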
163
+ # /////////////////////////////////////////////////////////////////////////////
164
+ # Host setup and device kernel launch
165
+ # /////////////////////////////////////////////////////////////////////////////
166
+
167
+
168
+ class NamedBarrierGemm(enum.IntEnum):
169
+ Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
170
+ # For the mainloop load warp to signal that the epilogue load warp can start.
171
+ # This is to avoid loading C too early, interfering with loading A and B.
172
+ EpilogueLoad = enum.auto()
173
+ MmaWG0 = enum.auto()
174
+ MmaWG1 = enum.auto()
175
+ EpiWG0 = enum.auto()
176
+ EpiWG1 = enum.auto()
177
+
178
+
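+ # Resulting barrier ids, assuming enum.auto() numbering from 1 (a reading of the
+ # enum above, not an authoritative table): Epilogue=1, EpilogueLoad=2, MmaWG0=3,
+ # MmaWG1=4, EpiWG0=5, EpiWG1=6. pingpong_barrier_sync/arrive below add the warp
+ # group index to MmaWG0/EpiWG0, so warp group 1 lands on MmaWG1/EpiWG1.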
179
+ class HopperSymmetricGemmKernel:
180
+ """
181
+ This class implements batched matrix multiplication whose output is guaranteed to be
182
+ symmetric, with optional C addition (D = alpha * A @ B + beta * C, where B = A^T).
183
+
184
+ :param acc_dtype: Data type for accumulation during computation
185
+ :type acc_dtype: type[cutlass.Numeric]
+ :param a_dtype: Data type of operand A (B uses the same type, since B = A^T)
+ :type a_dtype: type[cutlass.Numeric]
186
+ :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
187
+ :type tile_shape_mnk: Tuple[int, int, int]
188
+ :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
189
+ :type cluster_shape_mnk: Tuple[int, int, int]
190
+
191
+ :note: Data type requirements:
192
+ - For 16-bit types: A and B must have the same data type
193
+ - For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit
194
+ - Float8 types only support k-major layout
195
+
196
+ :note: Supported data types:
197
+ - Float16
198
+ - BFloat16
199
+ - Float8E4M3FN/Float8E5M2
200
+
201
+ :note: Supported accumulation types:
202
+ - Float32 (for all floating point inputs)
203
+
204
+ :note: Constraints:
205
+ - Cluster shape M/N must be positive powers of 2, with total cluster size <= 4
206
+
207
+ Example:
208
+ >>> gemm = HopperSymmetricGemmKernel(
209
+ ... acc_dtype=cutlass.Float32,
210
+ ... tile_shape_mnk=(128, 256, 64),
211
+ ... cluster_shape_mnk=(2, 1, 1)
212
+ ... )
213
+ >>> gemm(a_tensor, b_tensor, c_tensor, stream)
214
+ """
215
+
216
+ def __init__(
217
+ self,
218
+ acc_dtype: Type[cutlass.Numeric],
219
+ a_dtype: Type[cutlass.Numeric],
220
+ tile_shape_mnk: Tuple[int, int, int],
221
+ cluster_shape_mnk: Tuple[int, int, int],
222
+ pingpong: bool = False,
223
+ is_persistent: bool = True,
224
+ fp8_fast_accum: bool = False,
225
+ ):
226
+ """
227
+ Initializes the configuration for the Hopper symmetric dense GEMM kernel.
228
+
229
+ This configuration includes data types for operands, tile shape, cluster configuration,
230
+ and thread layout.
231
+
232
+ :param acc_dtype: Data type for accumulation during computation
233
+ :type acc_dtype: type[cutlass.Numeric]
+ :param a_dtype: Data type of operand A (B uses the same type, since B = A^T)
+ :type a_dtype: type[cutlass.Numeric]
234
+ :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
235
+ :type tile_shape_mnk: Tuple[int, int, int]
236
+ :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
237
+ :type cluster_shape_mnk: Tuple[int, int, int]
238
+ """
239
+
240
+ self.acc_dtype = acc_dtype
241
+ self.pingpong = pingpong
242
+ self.is_persistent = is_persistent
243
+ if self.pingpong:
244
+ assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
245
+ self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
246
+
247
+ self.cluster_shape_mnk = cluster_shape_mnk
248
+ self.tile_shape_mnk = tuple(tile_shape_mnk)
249
+ tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1]
250
+ # check the cta tile shape
251
+ if not self.pingpong:
252
+ if tile_M not in [64, 128, 192, 256, 320]:
253
+ raise ValueError("CTA tile shape M must be 64/128/192/256/320")
254
+ if tile_M in [192, 320]: # special case
255
+ tile_N_max = 256 if tile_M == 192 else 160
256
+ if not (tile_N % 32 == 0 and tile_N <= tile_N_max):
257
+ raise ValueError(
258
+ f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}"
259
+ )
260
+ else:
261
+ if not (
262
+ (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
263
+ ):
264
+ raise ValueError(
265
+ "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
266
+ )
267
+ else:
268
+ if tile_M not in [64, 128, 192]:
269
+ raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
270
+ tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
271
+ if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
272
+ raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
273
+ if not self.tile_shape_mnk[2] % 16 == 0:
274
+ raise ValueError("CTA tile shape K must be divisible by 16")
275
+
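+ # A few concrete cases of the checks above (worked out from the rules, not an
+ # exhaustive list): without pingpong, (128, 256, 64) passes; (192, 288, 64) is
+ # rejected because tile_N 288 > 256 for tile_M 192; (320, 192, 64) is rejected
+ # because tile_N must be <= 160 when tile_M is 320. With pingpong, (128, 224, 64)
+ # is rejected since tile_N_max is 208 for tile_M 128, while (128, 192, 64) passes.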
276
+ if not self.pingpong:
277
+ if tile_M == 320: # tile_M / 64 is not even so we have to split along N
278
+ atom_layout_m, atom_layout_n = 1, 2
279
+ elif tile_M == 192:
280
+ if tile_N <= 128:
281
+ atom_layout_m, atom_layout_n = 3, 1
282
+ else:
283
+ atom_layout_m, atom_layout_n = 1, 2
284
+ else:
285
+ atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
286
+ atom_layout_n = 1
287
+ assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
288
+ else:
289
+ atom_layout_m, atom_layout_n = 1, 1
290
+ self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
291
+
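+ # Examples of the atom layout heuristic above (derived from the branches, shown
+ # for reference): tile_M 64 -> (1, 1, 1); tile_M 128 -> (2, 1, 1); tile_M 256 ->
+ # (2, 1, 1); tile_M 192 with tile_N <= 128 -> (3, 1, 1), otherwise (1, 2, 1);
+ # tile_M 320 -> (1, 2, 1). With pingpong the layout is always (1, 1, 1).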
292
+ self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
293
+ self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
294
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
295
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
296
+
297
+ self.occupancy = 1
298
+ self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
299
+ if self.pingpong:
300
+ assert self.mma_warp_groups == 2
301
+ assert self.mma_warp_groups in [1, 2, 3]
302
+ self.num_threads_per_warp_group = 128
303
+ self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
304
+ self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
305
+ self.num_mma_threads = (
306
+ self.mma_warp_groups if not self.pingpong else 1
307
+ ) * self.num_threads_per_warp_group
308
+ self.num_epi_threads = (
309
+ self.mma_warp_groups if not self.pingpong else 1
310
+ ) * self.num_threads_per_warp_group
311
+ self.num_mainloop_load_threads = cute.arch.WARP_SIZE * 1
312
+ self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
313
+ self.mainloop_load_warp_id = self.mma_warp_groups * 4
314
+ self.epi_load_warp_id = self.mainloop_load_warp_id + 1
315
+
316
+ regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // self.num_mma_threads
317
+ if self.fp8_slow_accum:
318
+ regs_per_thread *= 2
319
+ if self.mma_warp_groups == 3:
320
+ self.num_regs_load, self.num_regs_mma = 32, 160
321
+ else:
322
+ heavy_register_pressure = regs_per_thread >= 208
323
+ self.num_regs_load, self.num_regs_mma = (
324
+ (40, 232) if not heavy_register_pressure else (24, 240)
325
+ )
326
+
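+ # Worked example of the register split above (a back-of-the-envelope check, not
+ # measured): for a 128x256 tile without pingpong there are 2 MMA warp groups, so
+ # 256 MMA threads and 128 * 256 // 256 = 128 accumulator registers per thread.
+ # 128 < 208, so the load/mma register split is (40, 232). With fp8 slow
+ # accumulation the count doubles to 256 >= 208 and the split becomes (24, 240).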
327
+ self.ab_stage = None
328
+ self.epi_stage = None
329
+
330
+ self.a_smem_layout_staged = None
331
+ self.b_smem_layout_staged = None
332
+ self.epi_smem_layout_staged = None
333
+ self.epi_tile = None
334
+
335
+ self.shared_storage = None
336
+ self.buffer_align_bytes = 1024
337
+
338
+ def _setup_attributes(self):
339
+ """Set up configurations that are dependent on GEMM inputs
340
+
341
+ This method configures various attributes based on the input tensor properties
342
+ (data types, leading dimensions) and kernel settings:
343
+ - Configuring tiled MMA
344
+ - Computing MMA/cluster/tile shapes
345
+ - Computing cluster layout
346
+ - Computing multicast CTAs for A/B
347
+ - Computing epilogue subtile
348
+ - Setting up A/B/C stage counts in shared memory
349
+ - Computing A/B/C shared memory layout
350
+ """
351
+
352
+ self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
353
+
354
+ self.epi_tile = self._sm90_compute_tile_shape_or_override(
355
+ self.tile_shape_mnk,
356
+ self.atom_layout_mnk,
357
+ self.d_dtype,
358
+ )
359
+
360
+ # Compute stage before compute smem layout
361
+ self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
362
+ self.tile_shape_mnk,
363
+ self.epi_tile,
364
+ self.a_dtype,
365
+ self.b_dtype,
366
+ self.d_dtype,
367
+ self.c_dtype,
368
+ self.smem_capacity,
369
+ self.occupancy,
370
+ # epi smem reuses the A/B smem if the kernel is not persistent.
371
+ overlap_sD_sA=not self.is_persistent,
372
+ )
373
+
374
+ (
375
+ self.a_smem_layout_staged,
376
+ self.b_smem_layout_staged,
377
+ self.epi_smem_layout_staged,
378
+ self.epi_t_smem_layout_staged,
379
+ self.epi_c_smem_layout_staged,
380
+ ) = self._make_smem_layouts(
381
+ self.tile_shape_mnk,
382
+ self.epi_tile,
383
+ self.a_dtype,
384
+ self.a_layout,
385
+ self.b_dtype,
386
+ self.b_layout,
387
+ self.ab_stage,
388
+ self.d_dtype,
389
+ self.d_layout,
390
+ self.epi_stage,
391
+ self.c_dtype,
392
+ self.c_layout,
393
+ self.epi_c_stage,
394
+ )
395
+
396
+ @cute.jit
397
+ def __call__(
398
+ self,
399
+ mA: cute.Tensor,
400
+ mB: cute.Tensor,
401
+ mD: cute.Tensor,
402
+ mC: Optional[cute.Tensor],
403
+ alpha: cutlass.Numeric,
404
+ beta: cutlass.Numeric,
405
+ max_active_clusters: Int32,
406
+ stream: cuda.CUstream,
407
+ ):
408
+ """Execute the GEMM operation in steps:
409
+ - Setup static attributes
410
+ - Setup TMA load/store atoms and tensors
411
+ - Compute grid size
412
+ - Define shared storage for kernel
413
+ - Launch the kernel synchronously
414
+
415
+ :param mA: Input tensor A
416
+ :type mA: cute.Tensor
417
+ :param mB: Input tensor B
418
+ :type mB: cute.Tensor
419
+ :param mD: Output tensor D
420
+ :type mD: cute.Tensor
+ :param mC: Optional input tensor C, scaled by beta and added to alpha * A @ B
+ :type mC: Optional[cute.Tensor]
+ :param alpha: Scalar multiplier for A @ B
+ :type alpha: cutlass.Numeric
+ :param beta: Scalar multiplier for C
+ :type beta: cutlass.Numeric
+ :param max_active_clusters: Maximum number of concurrently active clusters, used to size the persistent grid
+ :type max_active_clusters: Int32
421
+ :param stream: CUDA stream for asynchronous execution
422
+ :type stream: cuda.CUstream
423
+ """
424
+ mDt = cute.make_tensor(
425
+ mD.iterator, cute.make_layout(mD.shape, stride=cute.select(mD.stride, mode=[1, 0, 2]))
426
+ )
427
+
428
+ # setup static attributes before smem/grid/tma computation
429
+ self.a_dtype = mA.element_type
430
+ self.b_dtype = mB.element_type
431
+ self.d_dtype = mD.element_type
432
+ self.c_dtype = mC.element_type if mC is not None else None
433
+ self.a_layout = cutlass.utils.LayoutEnum.from_tensor(mA)
434
+ self.b_layout = cutlass.utils.LayoutEnum.from_tensor(mB)
435
+ self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
436
+ self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None
437
+
438
+ if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
439
+ raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
440
+ if const_expr(self.a_dtype.width != self.b_dtype.width):
441
+ raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
442
+ if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
443
+ raise TypeError("a_dtype should be float16 or float8")
444
+
445
+ self._setup_attributes()
446
+
447
+ tiled_mma = sm90_utils.make_trivial_tiled_mma(
448
+ self.a_dtype,
449
+ self.b_dtype,
450
+ self.a_layout.sm90_mma_major_mode(),
451
+ self.b_layout.sm90_mma_major_mode(),
452
+ self.acc_dtype,
453
+ self.atom_layout_mnk,
454
+ tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
455
+ )
456
+ if const_expr(self.atom_layout_mnk[1] > 1):
457
+ # If N dimension is split among 2 WGs, we need to permute the N dimension so
458
+ # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
459
+ # containing accumulators that are next to each other in the N dimension.
460
+ # Without permutation WG0 would write to epi smem of size (64, 16) and
461
+ # WG1 would write to a separate epi smem of size (64, 16) that's far away.
462
+ atom_n = self.atom_layout_mnk[1]
463
+ permutation_n = cute.make_ordered_layout(
464
+ (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
465
+ )
466
+ tiled_mma = cute.make_tiled_mma(
467
+ cute.make_mma_atom(tiled_mma.op),
468
+ self.atom_layout_mnk,
469
+ permutation_mnk=(None, permutation_n, None),
470
+ )
471
+
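+ # Rough shape check for the permutation above (based on the usual CuTe
+ # make_ordered_layout semantics, so treat it as a sketch): with tile_N = 256 and
+ # atom_n = 2 the permutation layout has shape (8, 16, 2) with order (0, 2, 1),
+ # i.e. 8-column blocks from the two warp groups interleave along N, which is what
+ # lets both WGs fill one shared (64, 32)-style epi smem tile instead of two
+ # far-apart (64, 16) halves, as described in the comment above.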
472
+ tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
473
+ mA,
474
+ self.a_smem_layout_staged,
475
+ (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
476
+ self.cluster_shape_mnk[1],
477
+ )
478
+
479
+ tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
480
+ mB,
481
+ self.b_smem_layout_staged,
482
+ (self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
483
+ self.cluster_shape_mnk[0],
484
+ )
485
+
486
+ tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
487
+ mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
488
+ )
489
+
490
+ tma_atom_dt, tma_tensor_dt = self._make_tma_epi_atoms_and_tensors(
491
+ mDt, self.epi_t_smem_layout_staged, self.epi_tile, store_or_load="store"
492
+ )
493
+
494
+ if const_expr(mC is not None):
495
+ tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
496
+ mC, self.epi_c_smem_layout_staged, self.epi_tile, store_or_load="load"
497
+ )
498
+ else:
499
+ tma_atom_c, tma_tensor_c = None, None
500
+
501
+ problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (
502
+ mD.shape[2],
503
+ )
504
+ TileScheduler = TriangularTileScheduler
505
+ tile_sched_args = TileSchedulerArguments(
506
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
507
+ raster_order=RasterOrderOption.Heuristic,
508
+ group_size=8,
509
+ cluster_shape_mnk=self.cluster_shape_mnk,
510
+ is_persistent=self.is_persistent,
511
+ )
512
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
513
+ grid = TileScheduler.get_grid_shape(tile_sched_params, max_active_clusters)
514
+
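+ # Scale of the tile grid for the default arguments (plain arithmetic; the exact
+ # schedule is up to TriangularTileScheduler): with mnkl = (4096, 4096, 4096, 1)
+ # and a 128 x 256 tile, ceil_div gives (32, 16) tiles, so
+ # problem_shape_ntile_mnl = (32, 16, 1). Because D is symmetric, the triangular
+ # scheduler only needs to visit roughly half of those 512 (m, n) tiles.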
515
+ epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if self.is_persistent else 0
516
+ epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
517
+
518
+ @cute.struct
519
+ class SharedStorage:
520
+ mainloop_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
521
+ epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
522
+ sD: cute.struct.Align[
523
+ cute.struct.MemRange[self.d_dtype, epi_smem_size],
524
+ self.buffer_align_bytes,
525
+ ]
526
+ sDt: cute.struct.Align[
527
+ cute.struct.MemRange[self.d_dtype, epi_smem_size],
528
+ self.buffer_align_bytes,
529
+ ]
530
+ sC: cute.struct.Align[
531
+ cute.struct.MemRange[
532
+ self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
533
+ ],
534
+ self.buffer_align_bytes,
535
+ ]
536
+ sA: cute.struct.Align[
537
+ cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
538
+ self.buffer_align_bytes,
539
+ ]
540
+ sB: cute.struct.Align[
541
+ cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)],
542
+ self.buffer_align_bytes,
543
+ ]
544
+
545
+ self.shared_storage = SharedStorage
546
+
547
+ # Launch the kernel synchronously
548
+ self.kernel(
549
+ tma_atom_a,
550
+ tma_tensor_a,
551
+ tma_atom_b,
552
+ tma_tensor_b,
553
+ tma_atom_d,
554
+ tma_tensor_d,
555
+ tma_atom_dt,
556
+ tma_tensor_dt,
557
+ tma_atom_c,
558
+ tma_tensor_c,
559
+ tiled_mma,
560
+ self.cta_layout_mnk,
561
+ self.a_smem_layout_staged,
562
+ self.b_smem_layout_staged,
563
+ self.epi_smem_layout_staged,
564
+ self.epi_t_smem_layout_staged,
565
+ self.epi_c_smem_layout_staged,
566
+ tile_sched_params,
567
+ TileScheduler,
568
+ alpha,
569
+ beta,
570
+ ).launch(
571
+ grid=grid,
572
+ block=[self.threads_per_cta, 1, 1],
573
+ cluster=self.cluster_shape_mnk,
574
+ smem=self.shared_storage.size_in_bytes(),
575
+ stream=stream,
576
+ min_blocks_per_mp=1,
577
+ )
578
+ return
579
+
580
+ # GPU device kernel
581
+ @cute.kernel
582
+ def kernel(
583
+ self,
584
+ tma_atom_a: cute.CopyAtom,
585
+ mA_mkl: cute.Tensor,
586
+ tma_atom_b: cute.CopyAtom,
587
+ mB_nkl: cute.Tensor,
588
+ tma_atom_d: cute.CopyAtom,
589
+ mD_mnl: cute.Tensor,
590
+ tma_atom_dt: cute.CopyAtom,
591
+ mDt_mnl: cute.Tensor,
592
+ tma_atom_c: Optional[cute.CopyAtom],
593
+ mC_mnl: Optional[cute.Tensor],
594
+ tiled_mma: cute.TiledMma,
595
+ cta_layout_mnk: cute.Layout,
596
+ a_smem_layout_staged: cute.ComposedLayout,
597
+ b_smem_layout_staged: cute.ComposedLayout,
598
+ epi_smem_layout_staged: cute.ComposedLayout,
599
+ epi_t_smem_layout_staged: cute.ComposedLayout,
600
+ epi_c_smem_layout_staged: cute.ComposedLayout,
601
+ tile_sched_params: ParamsBase,
602
+ TileScheduler: cutlass.Constexpr[Callable],
603
+ alpha: cutlass.Numeric,
604
+ beta: cutlass.Numeric,
605
+ ):
606
+ """
607
+ GPU device kernel performing the batched GEMM computation.
608
+
609
+ :param tma_atom_a: TMA copy atom for A tensor
610
+ :type tma_atom_a: cute.CopyAtom
611
+ :param mA_mkl: Input tensor A
612
+ :type mA_mkl: cute.Tensor
613
+ :param tma_atom_b: TMA copy atom for B tensor
614
+ :type tma_atom_b: cute.CopyAtom
615
+ :param mB_nkl: Input tensor B
616
+ :type mB_nkl: cute.Tensor
617
+ :param tma_atom_d: TMA copy atom for D tensor
618
+ :type tma_atom_d: cute.CopyAtom
619
+ :param mD_mnl: Output tensor D
620
+ :type mD_mnl: cute.Tensor
621
+ :param tiled_mma: Tiled MMA object
622
+ :type tiled_mma: cute.TiledMma
623
+ :param cta_layout_mnk: CTA layout
624
+ :type cta_layout_mnk: cute.Layout
625
+ :param a_smem_layout_staged: Shared memory layout for A
626
+ :type a_smem_layout_staged: cute.ComposedLayout
627
+ :param b_smem_layout_staged: Shared memory layout for B
628
+ :type b_smem_layout_staged: cute.ComposedLayout
629
+ :param epi_smem_layout_staged: Shared memory layout for epilogue
630
+ :type epi_smem_layout_staged: cute.ComposedLayout
631
+ """
632
+
633
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
634
+
635
+ # /////////////////////////////////////////////////////////////////////////////
636
+ # Prefetch Tma desc
637
+ # /////////////////////////////////////////////////////////////////////////////
638
+ if warp_idx == self.mainloop_load_warp_id:
639
+ cpasync.prefetch_descriptor(tma_atom_a)
640
+ cpasync.prefetch_descriptor(tma_atom_b)
641
+ cpasync.prefetch_descriptor(tma_atom_d)
642
+ cpasync.prefetch_descriptor(tma_atom_dt)
643
+ if const_expr(tma_atom_c is not None):
644
+ cpasync.prefetch_descriptor(tma_atom_c)
645
+
646
+ a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
647
+ b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
648
+ tma_copy_bytes = cute.size_in_bytes(self.a_dtype, a_smem_layout) + cute.size_in_bytes(
649
+ self.b_dtype, b_smem_layout
650
+ )
651
+
652
+ # /////////////////////////////////////////////////////////////////////////////
653
+ # Alloc and init AB full/empty + ACC full mbar (pipeline)
654
+ # /////////////////////////////////////////////////////////////////////////////
655
+ smem = cutlass.utils.SmemAllocator()
656
+ storage = smem.allocate(self.shared_storage)
657
+
658
+ # Threads/warps participating in this pipeline
659
+ mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
660
+ # Each consumer warp contributes the mcast size to the arrive count
661
+ mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
662
+ consumer_arrive_cnt = mcast_size * (self.num_mma_threads // cute.arch.WARP_SIZE)
663
+ mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
664
+ pipeline.Agent.Thread, consumer_arrive_cnt
665
+ )
666
+
667
+ cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
668
+ mainloop_pipeline = pipeline.PipelineTmaAsync.create(
669
+ barrier_storage=storage.mainloop_pipeline_array_ptr.data_ptr(),
670
+ num_stages=self.ab_stage,
671
+ producer_group=mainloop_pipeline_producer_group,
672
+ consumer_group=mainloop_pipeline_consumer_group,
673
+ tx_count=tma_copy_bytes,
674
+ cta_layout_vmnk=cta_layout_vmnk,
675
+ )
676
+
677
+ if const_expr(mC_mnl is not None):
678
+ # Threads/warps participating in this pipeline
679
+ epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
680
+ # Each warp will contribute 1 to the arrive count
681
+ consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
682
+ epi_pipeline_consumer_group = pipeline.CooperativeGroup(
683
+ pipeline.Agent.Thread, consumer_arrive_cnt
684
+ )
685
+ c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
686
+ tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
687
+ epi_pipeline = pipeline.PipelineTmaAsync.create(
688
+ barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
689
+ num_stages=self.epi_c_stage,
690
+ producer_group=epi_pipeline_producer_group,
691
+ consumer_group=epi_pipeline_consumer_group,
692
+ tx_count=tma_copy_c_bytes,
693
+ )
694
+ else:
695
+ epi_pipeline = None
696
+
697
+ # ///////////////////////////////////////////////////////////////////////////////
698
+ # Generate smem tensor A/B
699
+ # ///////////////////////////////////////////////////////////////////////////////
700
+ sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
701
+ sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
702
+ if const_expr(not self.is_persistent):
703
+ sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
704
+ sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
705
+ else:
706
+ sD = storage.sD.get_tensor(
707
+ epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
708
+ )
709
+ if cutlass.const_expr(not self.is_persistent):
710
+ sD_size_bytes = cute.size_in_bytes(self.d_dtype, epi_smem_layout_staged.outer)
711
+ sDt_ptr = cute.recast_ptr(
712
+ sA.iterator + sD_size_bytes, epi_t_smem_layout_staged.inner, dtype=self.d_dtype
713
+ )
714
+ sDt = cute.make_tensor(sDt_ptr, epi_t_smem_layout_staged.outer)
715
+ else:
716
+ sDt = storage.sDt.get_tensor(
717
+ epi_t_smem_layout_staged.outer, swizzle=epi_t_smem_layout_staged.inner
718
+ )
719
+ if const_expr(mC_mnl is not None):
720
+ sC = storage.sC.get_tensor(
721
+ epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
722
+ )
723
+ else:
724
+ sC = None
725
+
726
+ TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
727
+
728
+ k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
729
+ c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
730
+
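+ # Example counts (simple arithmetic for the default shapes, assuming the epi
+ # tile works out to (128, 32)): K = 4096 with tile_K = 64 gives k_tile_cnt = 64,
+ # and a 128 x 256 tile split into (128, 32) epi tiles gives c_tile_cnt = 8.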
731
+ if warp_idx >= self.mainloop_load_warp_id:
732
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
733
+ if const_expr(mC_mnl is not None):
734
+ epi_load_barrier = pipeline.NamedBarrier(
735
+ barrier_id=int(NamedBarrierGemm.EpilogueLoad),
736
+ num_threads=self.num_mainloop_load_threads + self.num_epi_load_threads,
737
+ )
738
+ else:
739
+ epi_load_barrier = None
740
+ if warp_idx == self.mainloop_load_warp_id:
741
+ # ///////////////////////////////////////////////////////////////////////////////
742
+ # Get mcast mask
743
+ # ///////////////////////////////////////////////////////////////////////////////
744
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
745
+ cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
746
+ a_mcast_mask = cute.make_layout_image_mask(
747
+ cta_layout_mnk, cluster_coord_mnk, mode=1
748
+ )
749
+ b_mcast_mask = cute.make_layout_image_mask(
750
+ cta_layout_mnk, cluster_coord_mnk, mode=0
751
+ )
752
+ a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
753
+ b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
754
+ mainloop_producer_state = make_pipeline_state(
755
+ pipeline.PipelineUserType.Producer, self.ab_stage
756
+ )
757
+ do_epi_load_barrier_arrive = cutlass.Boolean(True)
758
+ tile_scheduler = TileSchedulerCls()
759
+ work_tile = tile_scheduler.initial_work_tile_info()
760
+ while work_tile.is_valid_tile:
761
+ tile_coord_mnkl = work_tile.tile_idx
762
+ # ///////////////////////////////////////////////////////////////////////////
763
+ # Local_tile partition global tensors
764
+ # ///////////////////////////////////////////////////////////////////////////
765
+ # (bM, bK, RestK)
766
+ gA_mkl = cute.local_tile(
767
+ mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1)
768
+ )
769
+ # (bN, bK, RestK)
770
+ gB_nkl = cute.local_tile(
771
+ mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
772
+ )
773
+ # //////////////////////////////////////////////////////////////////////////
774
+ # Partition shared tensor for TMA load A/B
775
+ # //////////////////////////////////////////////////////////////////////////
776
+ # TMA load A partition_S/D
777
+ a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
778
+ a_cta_crd = cluster_coord_mnk[1]
779
+ tAsA, tAgA_mkl = cpasync.tma_partition(
780
+ tma_atom_a,
781
+ a_cta_crd,
782
+ a_cta_layout,
783
+ cute.group_modes(sA, 0, 2),
784
+ cute.group_modes(gA_mkl, 0, 2),
785
+ )
786
+ # TMA load B partition_S/D
787
+ b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
788
+ b_cta_crd = cluster_coord_mnk[0]
789
+ tBsB, tBgB_nkl = cpasync.tma_partition(
790
+ tma_atom_b,
791
+ b_cta_crd,
792
+ b_cta_layout,
793
+ cute.group_modes(sB, 0, 2),
794
+ cute.group_modes(gB_nkl, 0, 2),
795
+ )
796
+ # /////////////////////////////////////////////////////////////////////////
797
+ # TMA load
798
+ # /////////////////////////////////////////////////////////////////////////
799
+ for k_tile in cutlass.range(k_tile_cnt, unroll=1):
800
+ # Wait for A/B buffers to be empty before loading into them
801
+ # Also sets the transaction barrier for the A/B buffers
802
+ mainloop_pipeline.producer_acquire(mainloop_producer_state)
803
+ cute.copy(
804
+ tma_atom_a,
805
+ tAgA_mkl[None, k_tile],
806
+ tAsA[None, mainloop_producer_state.index],
807
+ tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
808
+ mainloop_producer_state
809
+ ),
810
+ mcast_mask=a_mcast_mask,
811
+ )
812
+ cute.copy(
813
+ tma_atom_b,
814
+ tBgB_nkl[None, k_tile],
815
+ tBsB[None, mainloop_producer_state.index],
816
+ tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
817
+ mainloop_producer_state
818
+ ),
819
+ mcast_mask=b_mcast_mask,
820
+ )
821
+ # Mainloop pipeline's producer commit is a NOP
822
+ mainloop_pipeline.producer_commit(mainloop_producer_state)
823
+ mainloop_producer_state.advance()
824
+ if const_expr(epi_load_barrier is not None):
825
+ # In the first work tile, the epi load warp will wait for the signal
826
+ # from the mainloop load warp to start loading C, to avoid interfering
827
+ # with loading A and B.
828
+ if do_epi_load_barrier_arrive:
829
+ epi_load_barrier.arrive()
830
+ do_epi_load_barrier_arrive = cutlass.Boolean(False)
831
+ tile_scheduler.fetch_next_work()
832
+ tile_scheduler.advance_to_next_work()
833
+ work_tile = tile_scheduler.get_current_work()
834
+ # End of persistent scheduler loop
835
+ mainloop_pipeline.producer_tail(mainloop_producer_state)
836
+
837
+ if const_expr(mC_mnl is not None):
838
+ if warp_idx == self.epi_load_warp_id:
839
+ epi_producer_state = make_pipeline_state(
840
+ pipeline.PipelineUserType.Producer, self.epi_c_stage
841
+ )
842
+ do_epi_load_barrier_wait = cutlass.Boolean(True)
843
+ tile_scheduler = TileSchedulerCls()
844
+ work_tile = tile_scheduler.initial_work_tile_info()
845
+ while work_tile.is_valid_tile:
846
+ tile_coord_mnkl = work_tile.tile_idx
847
+ # (bM, bN)
848
+ gC_mnl = cute.local_tile(
849
+ mC_mnl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None)
850
+ )
851
+ tCgC_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile)
852
+ bGS_sC, bGS_gC = cpasync.tma_partition(
853
+ tma_atom_c,
854
+ 0,
855
+ cute.make_layout(1),
856
+ cute.group_modes(sC, 0, 2),
857
+ tCgC_for_tma_partition,
858
+ )
859
+ if do_epi_load_barrier_wait:
860
+ epi_load_barrier.arrive_and_wait()
861
+ do_epi_load_barrier_wait = cutlass.Boolean(False)
862
+ epi_tile_num = const_expr(cute.size(tCgC_for_tma_partition, mode=[1]))
863
+ epi_tile_shape = tCgC_for_tma_partition.shape[1]
864
+ for epi_idx in cutlass.range(epi_tile_num, unroll=1):
865
+ epi_pipeline.producer_acquire(epi_producer_state)
866
+ # Get the global memory coordinate for the current epi tile
867
+ epi_tile_layout = cute.make_layout(
868
+ epi_tile_shape, stride=(epi_tile_shape[1], 1)
869
+ )
870
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
871
+ cute.copy(
872
+ tma_atom_c,
873
+ bGS_gC[None, gmem_coord],
874
+ bGS_sC[None, epi_producer_state.index],
875
+ tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
876
+ )
877
+ # Epi pipeline's producer commit is a NOP
878
+ epi_pipeline.producer_commit(epi_producer_state)
879
+ epi_producer_state.advance()
880
+ tile_scheduler.advance_to_next_work()
881
+ work_tile = tile_scheduler.get_current_work()
882
+ # End of persistent scheduler loop
883
+ epi_pipeline.producer_tail(epi_producer_state)
884
+
885
+ if warp_idx < self.mainloop_load_warp_id:
886
+ cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
887
+ # //////////////////////////////////////////////////////////////////////////////
888
+ # Partition global tensor for TiledMMA_A/B/C
889
+ # //////////////////////////////////////////////////////////////////////////////
890
+ tidx, _, _ = cute.arch.thread_idx()
891
+ warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
892
+ if const_expr(self.pingpong):
893
+ tidx = tidx % self.num_threads_per_warp_group
894
+ warp_group_thread_layout = cute.make_layout(
895
+ self.mma_warp_groups if not self.pingpong else 1,
896
+ stride=self.num_threads_per_warp_group,
897
+ )
898
+ thr_mma = tiled_mma.get_slice(
899
+ warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
900
+ )
901
+
902
+ # //////////////////////////////////////////////////////////////////////////////
903
+ # Make fragments
904
+ # //////////////////////////////////////////////////////////////////////////////
905
+ tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
906
+ tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))
907
+
908
+ acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
909
+ acc = cute.make_fragment(acc_shape, self.acc_dtype)
910
+ if const_expr(self.fp8_slow_accum):
911
+ acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
912
+
913
+ if const_expr(self.pingpong):
914
+ if warp_group_idx == 0:
915
+ # WG0 needs a start signal at the very beginning
916
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
917
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
918
+
919
+ mainloop_read_state = make_pipeline_state(
920
+ pipeline.PipelineUserType.Consumer, self.ab_stage
921
+ )
922
+ epi_read_state = make_pipeline_state(
923
+ pipeline.PipelineUserType.Consumer, self.epi_c_stage
924
+ )
925
+ tile_scheduler = TileSchedulerCls()
926
+ if const_expr(self.pingpong):
927
+ if warp_idx >= 4:
928
+ # Advance 2nd Math WG to the next work tile for the startup
929
+ tile_scheduler.advance_to_next_work()
930
+ # Advance 2nd Math WG pipeline states to the end of 1st Math WG
931
+ mainloop_read_state.advance_iters(k_tile_cnt)
932
+ epi_read_state.advance_iters(c_tile_cnt)
933
+ work_tile = tile_scheduler.initial_work_tile_info()
934
+ while work_tile.is_valid_tile:
935
+ tile_coord_mnkl = work_tile.tile_idx
936
+ # /////////////////////////////////////////////////////////////////////////////
937
+ # Prologue MMAs
938
+ # /////////////////////////////////////////////////////////////////////////////
939
+ k_pipe_mmas = 1
940
+ mainloop_release_state = mainloop_read_state.clone()
941
+ num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
942
+ if const_expr(self.pingpong):
943
+ self.pingpong_barrier_sync(warp_group_idx, stage="mma")
944
+ peek_ab_full_status = cutlass.Boolean(1)
945
+ if 0 < k_tile_cnt:
946
+ peek_ab_full_status = mainloop_pipeline.consumer_try_wait(mainloop_read_state)
947
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
948
+ num_k_blocks = cute.size(tCrA, mode=[2])
949
+ # TODO: this is probably not correct if k_tile_cnt == 0
950
+ for k_tile in cutlass.range(num_prologue_mma):
951
+ # Wait for A/B buffer to be ready
952
+ mainloop_pipeline.consumer_wait(mainloop_read_state, peek_ab_full_status)
953
+ warpgroup.fence()
954
+ for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
955
+ k_blk_coord = (None, None, k_blk_idx, mainloop_read_state.index)
956
+ cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
957
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
958
+ warpgroup.commit_group()
959
+ mainloop_read_state.advance()
960
+ peek_ab_full_status = cutlass.Boolean(1)
961
+ if k_tile + 1 < k_tile_cnt:
962
+ peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
963
+ mainloop_read_state
964
+ )
965
+ if const_expr(self.fp8_slow_accum):
966
+ warpgroup.wait_group(0)
967
+ acc_slow.store(acc.load())
968
+
969
+ # /////////////////////////////////////////////////////////////////////////////
970
+ # MAINLOOP
971
+ # /////////////////////////////////////////////////////////////////////////////
972
+ for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
973
+ # Wait for TMA copies to complete
974
+ mainloop_pipeline.consumer_wait(mainloop_read_state, peek_ab_full_status)
975
+ # WGMMA
976
+ warpgroup.fence()
977
+ if const_expr(self.fp8_slow_accum):
978
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
979
+ for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
980
+ k_blk_coord = (None, None, k_blk_idx, mainloop_read_state.index)
981
+ cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
982
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
983
+ warpgroup.commit_group()
984
+ # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
985
+ if const_expr(not self.fp8_slow_accum):
986
+ warpgroup.wait_group(k_pipe_mmas)
987
+ else:
988
+ warpgroup.wait_group(0)
989
+ acc_slow.store(acc_slow.load() + acc.load())
990
+ mainloop_pipeline.consumer_release(mainloop_release_state)
991
+ mainloop_read_state.advance()
992
+ mainloop_release_state.advance()
993
+ peek_ab_full_status = cutlass.Boolean(1)
994
+ if k_tile + 1 < k_tile_cnt:
995
+ peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
996
+ mainloop_read_state
997
+ )
998
+ if const_expr(self.pingpong):
999
+ # Cue for next WG's MMA to start
1000
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
1001
+ if const_expr(not self.fp8_slow_accum):
1002
+ # fp8_slow_accum has already called wait_group(0) inside the loop
1003
+ warpgroup.wait_group(0)
1004
+ for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
1005
+ mainloop_pipeline.consumer_release(mainloop_release_state)
1006
+ mainloop_release_state.advance()
1007
+ if const_expr(self.fp8_slow_accum):
1008
+ acc.store(acc_slow.load())
1009
+ if const_expr(self.pingpong):
1010
+ # Update starting mainloop pipeline state for the next tile
1011
+ mainloop_read_state.advance_iters(k_tile_cnt)
1012
+
1013
+ # /////////////////////////////////////////////////////////////////////////////
1014
+ # EPILOGUE
1015
+ # /////////////////////////////////////////////////////////////////////////////
1016
+ if const_expr(self.pingpong):
1017
+ self.pingpong_barrier_sync(warp_group_idx, "epi")
1018
+
1019
+ epilogue_barrier = pipeline.NamedBarrier(
1020
+ barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
1021
+ )
1022
+
1023
+ # Wait for all warp groups in the thread block to finish, because smem for tensor
1024
+ # A in the mainloop is reused in the epilogue if not persistent.
1025
+ if const_expr(not self.is_persistent):
1026
+ epilogue_barrier.arrive_and_wait()
1027
+
1028
+ copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
1029
+ self.d_layout,
1030
+ elem_ty_d=self.d_dtype,
1031
+ elem_ty_acc=self.acc_dtype,
1032
+ )
1033
+ copy_atom_D = cute.make_copy_atom(
1034
+ warp.StMatrix8x8x16bOp(self.d_layout.is_m_major_c(), 4),
1035
+ self.d_dtype,
1036
+ )
1037
+ dt_layout = (
1038
+ cutlass.utils.LayoutEnum.COL_MAJOR
1039
+ if self.d_layout == cutlass.utils.LayoutEnum.ROW_MAJOR
1040
+ else cutlass.utils.LayoutEnum.ROW_MAJOR
1041
+ )
1042
+ copy_atom_r2s_t = sm90_utils.sm90_get_smem_store_op(
1043
+ dt_layout,
1044
+ elem_ty_d=self.d_dtype,
1045
+ elem_ty_acc=self.acc_dtype,
1046
+ )
1047
+ copy_atom_Dt = cute.make_copy_atom(
1048
+ cute.nvgpu.warp.StMatrix8x8x16bOp(
1049
+ not self.d_layout.is_m_major_c(),
1050
+ 4,
1051
+ ),
1052
+ self.d_dtype,
1053
+ )
1054
+ tiled_copy_D_atom = cute.make_tiled_copy_C_atom(copy_atom_D, tiled_mma)
1055
+ tiled_copy_Dt_atom = cute.make_tiled_copy_C_atom(copy_atom_Dt, tiled_mma)
1056
+
1057
+ tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_D_atom)
1058
+ tiled_copy_r2s_t = cute.make_tiled_copy_S(
1059
+ copy_atom_r2s_t,
1060
+ tiled_copy_Dt_atom,
1061
+ )
1062
+
1063
+ # (R2S, R2S_M, R2S_N, PIPE_D)
1064
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
1065
+ tRS_sD = thr_copy_r2s.partition_D(sD)
1066
+ thr_copy_r2s_t = tiled_copy_r2s_t.get_slice(tidx)
1067
+ tRS_sDt = thr_copy_r2s_t.partition_D(sDt)
1068
+ # (R2S, R2S_M, R2S_N)
1069
+ tRS_rAcc = tiled_copy_r2s.retile(acc)
1070
+
1071
+ if const_expr(mC_mnl is not None):
1072
+ copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
1073
+ tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_D_atom)
1074
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1075
+ tRS_sC = thr_copy_s2r.partition_S(sC)
1076
+ else:
1077
+ thr_copy_s2r, tRS_sC = None, None
1078
+
1079
+ # (bM, bN)
1080
+ gD_mnl = cute.local_tile(
1081
+ mD_mnl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None)
1082
+ )
1083
+ tDgD_for_tma_partition = cute.zipped_divide(gD_mnl, self.epi_tile)
1084
+ bSG_sD, bSG_gD = cpasync.tma_partition(
1085
+ tma_atom_d,
1086
+ 0,
1087
+ cute.make_layout(1),
1088
+ cute.group_modes(sD, 0, 2),
1089
+ tDgD_for_tma_partition,
1090
+ )
1091
+
1092
+ gDt_mnl = cute.local_tile(
1093
+ mDt_mnl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None)
1094
+ )
1095
+ tDgDt_for_tma_partition = cute.zipped_divide(gDt_mnl, self.epi_tile)
1096
+ bSG_sDt, bSG_gDt = cute.nvgpu.cpasync.tma_partition(
1097
+ tma_atom_dt,
1098
+ 0,
1099
+ cute.make_layout(1),
1100
+ cute.group_modes(sDt, 0, 2),
1101
+ tDgDt_for_tma_partition,
1102
+ )
1103
+
1104
+ epi_tile_num = const_expr(cute.size(tDgD_for_tma_partition, mode=[1]))
1105
+ epi_tile_shape = tDgD_for_tma_partition.shape[1]
1106
+ num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
1107
+ for epi_idx in cutlass.range_constexpr(epi_tile_num):
1108
+ # Copy from acc to D registers
1109
+ tRS_rD = cute.make_fragment_like(tRS_sD[None, None, None, 0], self.acc_dtype)
1110
+ for epi_v in cutlass.range(cute.size(tRS_rD), unroll_full=True):
1111
+ tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1112
+ if const_expr(mC_mnl is not None):
1113
+ epi_pipeline.consumer_wait(epi_read_state)
1114
+ tRS_rC = cute.make_fragment_like(tRS_sC[None, None, None, 0], self.c_dtype)
1115
+ cute.copy(
1116
+ thr_copy_s2r, tRS_sC[None, None, None, epi_read_state.index], tRS_rC
1117
+ )
1118
+ # Fence to make sure shared memory read is visible to TMA load
1119
+ cute.arch.fence_proxy(
1120
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1121
+ )
1122
+ cute.arch.sync_warp()
1123
+ with cute.arch.elect_one():
1124
+ epi_pipeline.consumer_release(epi_read_state)
1125
+ epi_read_state.advance()
1126
+ result_vec = alpha.to(self.acc_dtype) * tRS_rD.load() + beta.to(
1127
+ self.acc_dtype
1128
+ ) * tRS_rC.load().to(self.acc_dtype)
1129
+ tRS_rD.store(result_vec)
1130
+ else:
1131
+ result_vec = alpha.to(self.acc_dtype) * tRS_rD.load()
1132
+ tRS_rD.store(result_vec)
1133
+ # Type conversion
1134
+ tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
1135
+ tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
1136
+ # Copy from D registers to shared memory
1137
+ epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sD, mode=[3])
1138
+ cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)])
1139
+ tRS_rDt_out = tiled_copy_r2s_t.retile(tRS_rD_out)
1140
+ cute.copy(
1141
+ tiled_copy_r2s_t, tRS_rDt_out, tRS_sDt[(None, None, None, epi_buffer)]
1142
+ )
1143
+
1144
+ # Fence and barrier to make sure shared memory store is visible to TMA store
1145
+ cute.arch.fence_proxy(
1146
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1147
+ )
1148
+ epilogue_barrier.arrive_and_wait()
1149
+ # Get the global memory coordinate for the current epi tile
1150
+ epi_tile_layout = cute.make_layout(
1151
+ epi_tile_shape, stride=(epi_tile_shape[1], 1)
1152
+ )
1153
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1154
+ # Copy from shared memory to global memory
1155
+ if (not self.pingpong and warp_idx == 0) or (
1156
+ self.pingpong and (warp_idx == 0 or warp_idx == 4)
1157
+ ):
1158
+ cute.copy(tma_atom_d, bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
1159
+ cute.copy(
1160
+ tma_atom_dt, bSG_sDt[(None, epi_buffer)], bSG_gDt[(None, gmem_coord)]
1161
+ )
1162
+ cute.arch.cp_async_bulk_commit_group()
1163
+ cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
1164
+ epilogue_barrier.arrive_and_wait()
1165
+
1166
+ if const_expr(self.pingpong):
1167
+ # Update starting load/store pipeline states for the next tile
1168
+ epi_read_state.advance_iters(c_tile_cnt)
1169
+ # With pingpong, 2 WGs write two different output tiles to the same smem,
1170
+ # so we have to make sure the smem content has been fully read before signalling
1171
+ # the next WG's epilogue.
1172
+ if warp_idx == 0 or warp_idx == 4:
1173
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
1174
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
1175
+
1176
+ tile_scheduler.advance_to_next_work(
1177
+ advance_count=1 if not self.pingpong else self.mma_warp_groups
1178
+ )
1179
+ work_tile = tile_scheduler.get_current_work()
1180
+ # End of persistent scheduler loop
1181
+
1182
+ if const_expr(not self.pingpong):
1183
+ if warp_idx == 0:
1184
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
1185
+
1186
+ def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
1187
+ assert stage in ["mma", "epi"]
1188
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1189
+ cute.arch.barrier(
1190
+ barrier_id=int(barrier) + warp_group_idx,
1191
+ number_of_threads=2 * self.num_threads_per_warp_group,
1192
+ )
1193
+
1194
+ def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
1195
+ assert stage in ["mma", "epi"]
1196
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1197
+ cute.arch.barrier_arrive(
1198
+ barrier_id=int(barrier) + warp_group_idx,
1199
+ number_of_threads=2 * self.num_threads_per_warp_group,
1200
+ )
1201
+
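+ # How the two helpers above resolve (following the NamedBarrierGemm values, shown
+ # for orientation): warp group 0 uses barrier MmaWG0/EpiWG0 and warp group 1 uses
+ # MmaWG1/EpiWG1, each counting 2 * 128 = 256 threads. In pingpong mode each math
+ # warp group arrives on the other group's barrier when it leaves the mainloop or
+ # epilogue, which is what hands off tensor-core use and epi smem between the WGs.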
1202
+ @staticmethod
1203
+ def _compute_stages(
1204
+ tile_shape_mnk: Tuple[int, int, int],
1205
+ epi_tile: Optional[Tuple[int, int]],
1206
+ a_dtype: Type[cutlass.Numeric],
1207
+ b_dtype: Type[cutlass.Numeric],
1208
+ d_dtype: Type[cutlass.Numeric],
1209
+ c_dtype: Optional[Type[cutlass.Numeric]],
1210
+ smem_capacity: int,
1211
+ occupancy: int,
1212
+ overlap_sD_sA: bool,
1213
+ ) -> Tuple[int, int]:
1214
+ """Computes the number of stages for A/B/C operands based on heuristics.
1215
+
1216
+ :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
1217
+ :type tile_shape_mnk: Tuple[int, int, int]
1218
+ :param a_dtype: Data type of operand A.
1219
+ :type a_dtype: type[cutlass.Numeric]
1220
+ :param b_dtype: Data type of operand B.
1221
+ :type b_dtype: type[cutlass.Numeric]
1222
+ :param smem_capacity: Total available shared memory capacity in bytes.
1223
+ :type smem_capacity: int
1224
+ :param occupancy: Target number of CTAs per SM (occupancy).
1225
+ :type occupancy: int
1226
+
1227
+ :return: A tuple containing the computed number of stages for:
1228
+ (A/B operand stages, epilogue D stages, epilogue C stages)
1229
+ :rtype: Tuple[int, int, int]
1230
+ """
1231
+
1232
+ epi_stage = 2
1233
+ if overlap_sD_sA:
1234
+ epi_bytes = 0
1235
+ else:
1236
+ d_bytes_per_stage = 2 * cute.size(epi_tile) * d_dtype.width // 8 # x2: both sD and its transpose sDt are staged
1237
+ epi_bytes = d_bytes_per_stage * epi_stage
1238
+ epi_c_stage = 0 if c_dtype is None else 2
1239
+ if c_dtype is not None:
1240
+ epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
1241
+
1242
+ a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
1243
+ b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
1244
+ ab_bytes_per_stage = (
1245
+ cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
1246
+ )
1247
+ mbar_helpers_bytes = 1024
1248
+
1249
+ ab_stage = (
1250
+ (smem_capacity - occupancy * 1024) // occupancy - mbar_helpers_bytes - epi_bytes
1251
+ ) // ab_bytes_per_stage
1252
+
1253
+ # Refine epilogue stages:
1254
+ # Calculate remaining smem after allocating for A/B stages and reserved bytes
1255
+ # Add remaining unused smem to epilogue
1256
+ if not overlap_sD_sA:
1257
+ epi_stage += (
1258
+ (smem_capacity - occupancy * 1024) // occupancy
1259
+ - mbar_helpers_bytes
1260
+ - epi_bytes
1261
+ - ab_bytes_per_stage * ab_stage
1262
+ ) // (d_bytes_per_stage)
1263
+ return ab_stage, epi_stage, epi_c_stage
1264
+
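+ # Numeric sketch of the heuristic above (assumes Hopper's ~228 KB of smem as
+ # reported by get_smem_capacity_in_bytes; the exact figure may differ): for BF16
+ # A/B/D, no C, a (128, 256, 64) tile and a (128, 32) epi tile with
+ # overlap_sD_sA=False, one A/B stage needs (128*64 + 256*64) * 2 = 48 KB and the
+ # two D(+D^T) epilogue stages reserve 32 KB, so after the ~2 KB of reserved/mbar
+ # bytes the budget fits ab_stage = 4, and the leftover smem is too small to raise
+ # epi_stage above 2. The returned tuple is then (4, 2, 0).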
1265
+ @staticmethod
1266
+ def _sm90_compute_tile_shape_or_override(
1267
+ tile_shape_mnk: Tuple[int, int, int],
1268
+ atom_layout_mnk: Tuple[int, int, int],
1269
+ element_type: Type[cutlass.Numeric],
1270
+ epi_tile_override: Tuple[int, int] | None = None,
1271
+ ) -> Tuple[int, int]:
1272
+ """Compute the epilogue tile shape or use override if provided.
1273
+
1274
+ :param tile_shape_mnk: CTA tile shape (M,N,K)
1275
+ :type tile_shape_mnk: Tuple[int, int, int]
1276
+ :param element_type: Data type of elements
1277
+ :type element_type: type[cutlass.Numeric]
1278
+ :param atom_layout_mnk: Atom layout (M,N,K) of the tiled MMA
1280
+ :type atom_layout_mnk: Tuple[int, int, int]
1280
+ :param epi_tile_override: Optional override for epilogue tile shape
1281
+ :type epi_tile_override: Tuple[int, int] or None
1282
+
1283
+ :return: Computed epilogue tile shape
1284
+ :rtype: Tuple[int, int]
1285
+ """
1286
+ if epi_tile_override is not None:
1287
+ return epi_tile_override
1288
+ if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
1289
+ tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
1290
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
1291
+ elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
1292
+ tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0]))
1293
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
1294
+ else:
1295
+ # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
1296
+ # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
1297
+ # M dimension first, then move to the N dimension. But the accumulator in registers
1298
+ # iterate along the N dimension first, then move to the M dimension.
1299
+ # We could change the epilogue to accommodate this,
1300
+ # but it's easier to just set epi_tile_m = 64.
1301
+ n_perf = 64 if element_type.width == 8 else 32
1302
+ tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
1303
+ tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
1304
+ return (tile_m, tile_n)
1305
+
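+ # Sample outputs of the epi-tile heuristic above (traced through the branches):
+ # tile (128, 256, 64) with atom layout (2, 1, 1) -> (128, 32); tile (320, 160, 64)
+ # with atom layout (1, 2, 1) -> (64, 32); an 8-bit element type falling into the
+ # last branch, e.g. tile (64, 256, 64) with atom layout (1, 1, 1) -> (64, 64).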
1306
+ @staticmethod
1307
+ def _make_smem_layouts(
1308
+ tile_shape_mnk: Tuple[int, int, int],
1309
+ epi_tile: Tuple[int, int],
1310
+ a_dtype: Type[cutlass.Numeric],
1311
+ a_layout: cutlass.utils.LayoutEnum,
1312
+ b_dtype: Type[cutlass.Numeric],
1313
+ b_layout: cutlass.utils.LayoutEnum,
1314
+ ab_stage: int,
1315
+ d_dtype: Type[cutlass.Numeric],
1316
+ d_layout: cutlass.utils.LayoutEnum,
1317
+ epi_stage: int,
1318
+ c_dtype: Optional[Type[cutlass.Numeric]],
1319
+ c_layout: Optional[cutlass.utils.LayoutEnum],
1320
+ epi_c_stage: int,
1321
+ ) -> Tuple[
1322
+ cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
1323
+ ]:
1324
+ """Create shared memory layouts for A, B, and C tensors.
1325
+
1326
+ :param tile_shape_mnk: CTA tile shape (M,N,K)
1327
+ :type tile_shape_mnk: Tuple[int, int, int]
1328
+ :param epi_tile: Epilogue tile shape
1329
+ :type epi_tile: Tuple[int, int]
1330
+ :param a_dtype: Data type for matrix A
1331
+ :type a_dtype: type[cutlass.Numeric]
1332
+ :param a_layout: Layout enum for matrix A
1333
+ :type a_layout: cutlass.utils.LayoutEnum
1334
+ :param b_dtype: Data type for matrix B
1335
+ :type b_dtype: type[cutlass.Numeric]
1336
+ :param b_layout: Layout enum for matrix B
1337
+ :type b_layout: cutlass.utils.LayoutEnum
1338
+ :param ab_stage: Number of stages for A/B tensors
1339
+ :type ab_stage: int
1340
+ :param d_dtype: Data type for output matrix D
1341
+ :type d_dtype: type[cutlass.Numeric]
1342
+ :param d_layout: Layout enum for the output matrix D
1343
+ :type d_layout: cutlass.utils.LayoutEnum
1344
+ :param epi_stage: Number of epilogue stages
1345
+ :type epi_stage: int
1346
+
1347
+ :return: Tuple of shared memory layouts for A, B, D epilogue, D^T epilogue, and (optionally) C epilogue
1349
+ :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]]
1349
+ """
1350
+ a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
1351
+
1352
+ a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
1353
+ b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
1354
+ a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
1355
+ a_smem_layout_atom = warpgroup.make_smem_layout_atom(
1356
+ sm90_utils.get_smem_layout_atom(
1357
+ a_layout,
1358
+ a_dtype,
1359
+ a_major_mode_size,
1360
+ ),
1361
+ a_dtype,
1362
+ )
1363
+ a_smem_layout_staged = cute.tile_to_shape(
1364
+ a_smem_layout_atom,
1365
+ cute.append(a_smem_shape, ab_stage),
1366
+ order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
1367
+ )
1368
+
1369
+ b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None))
1370
+
1371
+ b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
1372
+ b_smem_layout_atom = warpgroup.make_smem_layout_atom(
1373
+ sm90_utils.get_smem_layout_atom(
1374
+ b_layout,
1375
+ b_dtype,
1376
+ b_major_mode_size,
1377
+ ),
1378
+ b_dtype,
1379
+ )
1380
+ b_smem_layout_staged = cute.tile_to_shape(
1381
+ b_smem_layout_atom,
1382
+ cute.append(b_smem_shape, ab_stage),
1383
+ order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
1384
+ )
1385
+
1386
+ d_smem_shape = epi_tile
1387
+ d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
1388
+ d_smem_layout_atom = warpgroup.make_smem_layout_atom(
1389
+ sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
1390
+ d_dtype,
1391
+ )
1392
+ epi_smem_layout_staged = cute.tile_to_shape(
1393
+ d_smem_layout_atom,
1394
+ cute.append(d_smem_shape, epi_stage),
1395
+ order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
1396
+ )
1397
+
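+ # Build a second epilogue smem layout with the majorness of D flipped. This appears to be
+ # used for writing the transposed copy of each output tile, so the symmetric kernel can
+ # fill both the (m, n) tile and its mirrored (n, m) tile from a single computation
+ # (an inference from the triangular tile scheduler, not stated explicitly here).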
1398
+ dt_layout = (
1399
+ cutlass.utils.LayoutEnum.COL_MAJOR
1400
+ if d_layout == cutlass.utils.LayoutEnum.ROW_MAJOR
1401
+ else cutlass.utils.LayoutEnum.ROW_MAJOR
1402
+ )
1403
+ dt_major_mode_size = epi_tile[1] if dt_layout.is_n_major_c() else epi_tile[0]
1404
+ dt_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
1405
+ sm90_utils.get_smem_layout_atom(
1406
+ dt_layout,
1407
+ d_dtype,
1408
+ dt_major_mode_size,
1409
+ ),
1410
+ d_dtype,
1411
+ )
1412
+ epi_t_smem_layout_staged = cute.tile_to_shape(
1413
+ dt_smem_layout_atom,
1414
+ cute.append(d_smem_shape, epi_stage),
1415
+ order=(1, 0, 2) if dt_layout.is_m_major_c() else (0, 1, 2),
1416
+ )
1417
+
1418
+ if c_dtype is not None:
1419
+ assert c_layout is not None
1420
+ c_smem_shape = epi_tile
1421
+ c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0]
1422
+ c_smem_layout_atom = warpgroup.make_smem_layout_atom(
1423
+ sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
1424
+ d_dtype,  # c_dtype is required to equal d_dtype (asserted in run()), so the atom element type matches
1425
+ )
1426
+ epi_c_smem_layout_staged = cute.tile_to_shape(
1427
+ c_smem_layout_atom,
1428
+ cute.append(c_smem_shape, epi_c_stage),
1429
+ order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
1430
+ )
1431
+ else:
1432
+ epi_c_smem_layout_staged = None
1433
+
1434
+ return (
1435
+ a_smem_layout_staged,
1436
+ b_smem_layout_staged,
1437
+ epi_smem_layout_staged,
1438
+ epi_t_smem_layout_staged,
1439
+ epi_c_smem_layout_staged,
1440
+ )
1441
+
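+ # Worked example of the A-operand path above (hypothetical values): for a K-major
+ # BFloat16 A with tile_shape_mnk = (128, 256, 64) and ab_stage = 4, a_smem_shape is
+ # (128, 64); cute.tile_to_shape(a_smem_layout_atom, (128, 64, 4), order=(0, 1, 2))
+ # then replicates the swizzled atom across the tile and the 4 pipeline stages to
+ # produce a_smem_layout_staged. The B and D/C paths follow the same pattern.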
1442
+ @staticmethod
1443
+ def _make_tma_epi_atoms_and_tensors(
1444
+ tensor_d: cute.Tensor,
1445
+ epi_smem_layout_staged: cute.ComposedLayout,
1446
+ epi_tile: Tuple[int, int],
1447
+ store_or_load: str,
1448
+ ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1449
+ """Create TMA atoms and tensors for storing D or loading C.
1450
+
1451
+ :param tensor_d: Output tensor D
1452
+ :type tensor_d: cute.Tensor
1453
+ :param epi_smem_layout_staged: Shared memory layout for epilogue
1454
+ :type epi_smem_layout_staged: cute.ComposedLayout
1455
+ :param epi_tile: Epilogue tile shape
1456
+ :type epi_tile: Tuple[int, int]
+ :param store_or_load: "store" to write D to global memory, "load" to read C from global memory
+ :type store_or_load: str
1457
+
1458
+ :return: TMA atom and tensor for D (store) or C (load)
1459
+ :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1460
+ """
1461
+ assert store_or_load in ["load", "store"]
1462
+ epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
1463
+ d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
1464
+ op = (
1465
+ cpasync.CopyBulkTensorTileG2SOp()
1466
+ if store_or_load == "load"
1467
+ else cpasync.CopyBulkTensorTileS2GOp()
1468
+ )
1469
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
1470
+ op, tensor_d, epi_smem_layout, d_cta_v_layout
1471
+ )
1472
+ return tma_atom_d, tma_tensor_d
1473
+
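+ # Usage sketch (illustrative, variable names are not from this file): the epilogue
+ # typically builds one atom per direction, e.g.
+ #   tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(mD, epi_smem_layout, epi_tile, "store")
+ #   tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(mC, epi_c_smem_layout, epi_tile, "load")
+ # so D is written back with a bulk S2G copy while C is brought in with a bulk G2S copy.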
1474
+ @staticmethod
1475
+ def _make_tma_atoms_and_tensors(
1476
+ tensor: cute.Tensor,
1477
+ smem_layout_staged: cute.ComposedLayout,
1478
+ smem_tile: Tuple[int, int],
1479
+ mcast_dim: int,
1480
+ ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1481
+ """Create TMA atoms and tensors for input tensors.
1482
+
1483
+ :param tensor: Input tensor (A or B)
1484
+ :type tensor: cute.Tensor
1485
+ :param smem_layout_staged: Shared memory layout for the tensor
1486
+ :type smem_layout_staged: cute.ComposedLayout
1487
+ :param smem_tile: Shared memory tile shape
1488
+ :type smem_tile: Tuple[int, int]
1489
+ :param mcast_dim: Multicast dimension
1490
+ :type mcast_dim: int
1491
+
1492
+ :return: TMA atom and tensor
1493
+ :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1494
+ """
1495
+ op = (
1496
+ cpasync.CopyBulkTensorTileG2SOp()
1497
+ if mcast_dim == 1
1498
+ else cpasync.CopyBulkTensorTileG2SMulticastOp()
1499
+ )
1500
+
1501
+ smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
1502
+ tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
1503
+ op,
1504
+ tensor,
1505
+ smem_layout,
1506
+ smem_tile,
1507
+ num_multicast=mcast_dim,
1508
+ )
1509
+ return tma_atom, tma_tensor
1510
+
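+ # Illustrative note: with cluster_shape_mn = (2, 1) the two CTAs in a cluster have
+ # different M indices, so they read different A tiles but share the same B tile.
+ # In that case this helper would typically be called with mcast_dim = 1 for A
+ # (plain CopyBulkTensorTileG2SOp) and mcast_dim = 2 for B (the multicast op).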
1511
+ @staticmethod
1512
+ def is_valid_dtypes(
1513
+ a_dtype: Type[cutlass.Numeric],
1514
+ b_dtype: Type[cutlass.Numeric],
1515
+ acc_dtype: Type[cutlass.Numeric],
1516
+ d_dtype: Type[cutlass.Numeric],
1517
+ a_major: str,
1518
+ b_major: str,
1519
+ ) -> bool:
1520
+ """
1521
+ Check if the dtypes are valid
1522
+
1523
+ :param a_dtype: The data type of tensor A
1524
+ :type a_dtype: Type[cutlass.Numeric]
1525
+ :param b_dtype: The data type of tensor B
1526
+ :type b_dtype: Type[cutlass.Numeric]
1527
+ :param acc_dtype: The data type of the accumulator
1528
+ :type acc_dtype: Type[cutlass.Numeric]
1529
+ :param d_dtype: The data type of the output tensor
1530
+ :type d_dtype: Type[cutlass.Numeric]
1531
+ :param a_major: major mode of tensor A
1532
+ :type a_major: str
1533
+ :param b_major: major mode of tensor B
1534
+ :type b_major: str
1535
+
1536
+ :return: True if the dtypes are valid, False otherwise
1537
+ :rtype: bool
1538
+ """
1539
+ is_valid = True
1540
+ if a_dtype not in {
1541
+ cutlass.Float16,
1542
+ cutlass.BFloat16,
1543
+ cutlass.Float8E4M3FN,
1544
+ cutlass.Float8E5M2,
1545
+ }:
1546
+ is_valid = False
1547
+ # supported b_dtype values
1548
+ if b_dtype not in {
1549
+ cutlass.Float16,
1550
+ cutlass.BFloat16,
1551
+ cutlass.Float8E4M3FN,
1552
+ cutlass.Float8E5M2,
1553
+ }:
1554
+ is_valid = False
1555
+ if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
1556
+ is_valid = False
1557
+ # supported d_dtype values
1558
+ if d_dtype not in {
1559
+ cutlass.Float32,
1560
+ cutlass.Float16,
1561
+ cutlass.BFloat16,
1562
+ cutlass.Float8E4M3FN,
1563
+ cutlass.Float8E5M2,
1564
+ }:
1565
+ is_valid = False
1566
+ # make sure a_dtype == b_dtype for 16-bit types (Float16 / BFloat16)
1567
+ if a_dtype.width == 16 and a_dtype != b_dtype:
1568
+ is_valid = False
1569
+ # make sure a_dtype.width == b_dtype.width (e.g., Float8E4M3FN may pair with Float8E5M2)
1570
+ if a_dtype.width != b_dtype.width:
1571
+ is_valid = False
1572
+
1573
+ # for Float8 types, this implementation only supports k-major layout
1574
+ if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
1575
+ is_valid = False
1576
+
1577
+ return is_valid
1578
+
1579
+
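+ # Examples of the validity rules above (illustrative): (BFloat16, BFloat16, Float32, BFloat16)
+ # is accepted for any majors; (Float16, BFloat16, ...) is rejected because 16-bit A and B must
+ # match exactly; (Float8E4M3FN, Float8E5M2, Float32, BFloat16) is accepted only when both A and
+ # B are k-major, since the FP8 path requires k-major inputs.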
1580
+ def run(
1581
+ mnkl: Tuple[int, int, int, int],
1582
+ a_dtype: Type[cutlass.Numeric],
1583
+ b_dtype: Type[cutlass.Numeric],
1584
+ d_dtype: Type[cutlass.Numeric],
1585
+ c_dtype: Optional[Type[cutlass.Numeric]],
1586
+ acc_dtype: Type[cutlass.Numeric],
1587
+ a_major: str,
1588
+ b_major: str,
1589
+ d_major: str,
1590
+ c_major: str,
1591
+ tile_shape_mnk: Tuple[int, int, int],
1592
+ cluster_shape_mn: Tuple[int, int],
1593
+ tolerance: float,
1594
+ warmup_iterations: int,
1595
+ iterations: int,
1596
+ skip_ref_check: bool,
1597
+ persistent: bool,
1598
+ pingpong: bool,
1599
+ fp8_fast_accum: bool,
1600
+ alpha: float = 1.0,
1601
+ beta: float = 1.0,
1602
+ alpha_dtype: Type[cutlass.Numeric] = cutlass.Float32,
1603
+ beta_dtype: Type[cutlass.Numeric] = cutlass.Float32,
1604
+ **kwargs,
1605
+ ):
1606
+ """
1607
+ Prepare A/B/D/C tensors, launch the GPU kernel, and run the reference check.
1608
+
1609
+ :param mnkl: Problem size (M, N, K, L)
1610
+ :type mnkl: Tuple[int, int, int, int]
1611
+ :param a_dtype: Data type for input tensor A
1612
+ :type a_dtype: Type[cutlass.Numeric]
1613
+ :param b_dtype: Data type for input tensor B
1614
+ :type b_dtype: Type[cutlass.Numeric]
1615
+ :param d_dtype: Data type for output tensor D
1616
+ :type d_dtype: Type[cutlass.Numeric]
1617
+ :param acc_dtype: Data type for accumulation during matrix multiplication
1618
+ :type acc_dtype: Type[cutlass.Numeric]
1619
+ :param a_major/b_major/d_major: Memory layout of tensor A/B/D
1620
+ :type a_major/b_major/d_major: str
1621
+ :param tile_shape_mnk: CTA tile shape (M, N, K)
1622
+ :type tile_shape_mnk: Tuple[int, int, int]
1623
+ :param cluster_shape_mn: Cluster shape (M, N)
1624
+ :type cluster_shape_mn: Tuple[int, int]
1625
+ :param tolerance: Tolerance value for reference validation comparison
1626
+ :type tolerance: float
1627
+ :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
1628
+ :type warmup_iterations: int, optional
1629
+ :param iterations: Number of benchmark iterations to run, defaults to 1
1630
+ :type iterations: int, optional
1631
+ :param skip_ref_check: Whether to skip reference result validation, defaults to False
1632
+ :type skip_ref_check: bool, optional
1633
+ """
1634
+
1635
+ print("Running Hopper Dense GEMM with:")
1636
+ print(f"mnkl: {mnkl}")
1637
+ print(
1638
+ f"A dtype: {a_dtype}, B dtype: {b_dtype}, D dtype: {d_dtype}, C_dtype: {c_dtype}, Acc dtype: {acc_dtype}"
1639
+ )
1640
+ print(f"Matrix majors - A: {a_major}, B: {b_major}, D: {d_major}")
1641
+ print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
1642
+ print(f"Tolerance: {tolerance}")
1643
+ print(f"Alpha: {alpha}, Beta: {beta}")
1644
+ print(f"Alpha dtype: {alpha_dtype}, Beta dtype: {beta_dtype}")
1645
+ print(f"Warmup iterations: {warmup_iterations}")
1646
+ print(f"Iterations: {iterations}")
1647
+ print(f"Skip reference checking: {skip_ref_check}")
1648
+ # TODO: relax this
1649
+ assert c_dtype is None or c_dtype == d_dtype, "C dtype must match output dtype"
1650
+
1651
+ # Unpack parameters
1652
+ m, n, k, l = mnkl
1653
+ cluster_shape_mnk = (*cluster_shape_mn, 1)
1654
+
1655
+ # Skip unsupported types
1656
+ if not HopperSymmetricGemmKernel.is_valid_dtypes(
1657
+ a_dtype, b_dtype, acc_dtype, d_dtype, a_major, b_major
1658
+ ):
1659
+ raise TypeError(
1660
+ f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {d_dtype}, {a_major=}, {b_major=}"
1661
+ )
1662
+
1663
+ # Prepare pytorch tensors: A, B, D (random values) and an optional symmetric C
1664
+ if not torch.cuda.is_available():
1665
+ raise RuntimeError("GPU is required to run this example!")
1666
+
1667
+ torch.manual_seed(1111)
1668
+
1669
+ # Create and permute tensor A/B/C
1670
+ def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
1671
+ # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
1672
+ # else : (l, mode0, mode1) -> (mode0, mode1, l)
1673
+ shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
1674
+ permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
1675
+ is_unsigned = dtype in {cutlass.Uint8}
1676
+ # Temporarily use uint8 as torch does not support fp8 type
1677
+ torch_dtype = cutlass_torch.dtype(dtype)
1678
+ gen_dtype = (
1679
+ torch_dtype
1680
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
1681
+ else torch.bfloat16
1682
+ )
1683
+
1684
+ # Create dtype torch tensor (cpu)
1685
+ torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
1686
+ shape,
1687
+ gen_dtype,
1688
+ permute_order=permute_order,
1689
+ init_type=cutlass.torch.TensorInitType.RANDOM,
1690
+ init_config=cutlass.torch.RandomInitConfig(
1691
+ min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
1692
+ ),
1693
+ ).to(torch_dtype)
1694
+ # Create dtype torch tensor (gpu)
1695
+ torch_tensor = torch_tensor_cpu.cuda()
1696
+
1697
+ # Create f32 torch tensor (cpu)
1698
+ f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
1699
+
1700
+ # Create dtype cute tensor (gpu)
1701
+ torch_tensor_view = (
1702
+ torch_tensor
1703
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
1704
+ else torch_tensor.view(torch.uint8)
1705
+ )
1706
+ cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
1707
+ cute_tensor.element_type = dtype
1708
+ if is_dynamic_layout:
1709
+ cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
1710
+ cute_tensor = cutlass.torch.convert_cute_tensor(
1711
+ f32_torch_tensor,
1712
+ cute_tensor,
1713
+ dtype,
1714
+ is_dynamic_layout=is_dynamic_layout,
1715
+ )
1716
+
1717
+ return f32_torch_tensor, cute_tensor, torch_tensor
1718
+
1719
+ # Create symmetric C matrix
1720
+ def create_and_permute_tensor_C(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
1721
+ # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
1722
+ # else : (l, mode0, mode1) -> (mode0, mode1, l)
1723
+ shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
1724
+ permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
1725
+ is_unsigned = dtype in {cutlass.Uint8}
1726
+ # Temporarily use uint8 as torch does not support fp8 type
1727
+ torch_dtype = cutlass_torch.dtype(dtype)
1728
+ gen_dtype = (
1729
+ torch_dtype
1730
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
1731
+ else torch.bfloat16
1732
+ )
1733
+
1734
+ # Create dtype torch tensor (cpu) - generate random matrix first
1735
+ base_tensor = cutlass.torch.create_and_permute_torch_tensor(
1736
+ shape,
1737
+ gen_dtype,
1738
+ permute_order=permute_order,
1739
+ init_type=cutlass.torch.TensorInitType.RANDOM,
1740
+ init_config=cutlass.torch.RandomInitConfig(
1741
+ min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
1742
+ ),
1743
+ ).to(torch_dtype)
1744
+
1745
+ # Create symmetric matrix
1746
+ assert mode0 == mode1, f"For symmetric C, mode0 ({mode0}) must equal mode1 ({mode1})"
1747
+ torch_tensor_cpu = base_tensor + base_tensor.transpose(0, 1)
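+ # (after the permute the tensor is (mode0, mode1, l), so transposing dims 0 and 1 and
+ # adding symmetrizes every batch slice)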
1748
+
1749
+ # Create dtype torch tensor (gpu)
1750
+ torch_tensor = torch_tensor_cpu.cuda()
1751
+
1752
+ # Create f32 torch tensor (cpu)
1753
+ f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
1754
+
1755
+ # Create dtype cute tensor (gpu)
1756
+ torch_tensor_view = (
1757
+ torch_tensor
1758
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
1759
+ else torch_tensor.view(torch.uint8)
1760
+ )
1761
+ cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
1762
+ cute_tensor.element_type = dtype
1763
+ if is_dynamic_layout:
1764
+ cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
1765
+ cute_tensor = cutlass.torch.convert_cute_tensor(
1766
+ f32_torch_tensor,
1767
+ cute_tensor,
1768
+ dtype,
1769
+ is_dynamic_layout=is_dynamic_layout,
1770
+ )
1771
+
1772
+ return f32_torch_tensor, cute_tensor, torch_tensor
1773
+
1774
+ a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
1775
+ b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
1776
+ d, mD, d_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
1777
+ if c_dtype is not None:
1778
+ c, mC, c_torch = create_and_permute_tensor_C(l, m, n, c_major == "m", c_dtype)
1779
+ else:
1780
+ c, mC, c_torch = None, None, None
1781
+ # Force B = A so that A @ B^T is symmetric, matching the kernel's assumption
+ b_torch.copy_(a_torch)
1782
+ b.copy_(a)
1783
+
1784
+ gemm = HopperSymmetricGemmKernel(
1785
+ acc_dtype,
1786
+ a_dtype,
1787
+ tile_shape_mnk,
1788
+ cluster_shape_mnk,
1789
+ pingpong=pingpong,
1790
+ is_persistent=persistent,
1791
+ fp8_fast_accum=fp8_fast_accum,
1792
+ )
1793
+
1794
+ # Compute max active clusters on current device
1795
+ if persistent:
1796
+ max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
1797
+ cluster_shape_mn[0] * cluster_shape_mn[1]
1798
+ )
1799
+ else:
1800
+ max_active_clusters = 0
1801
+
1802
+ torch_stream = torch.cuda.Stream()
1803
+ stream = cuda.CUstream(torch_stream.cuda_stream)
1804
+
1805
+ # Create alpha and beta as scalars with specified dtypes
1806
+ alpha_scalar = alpha_dtype(alpha)
1807
+ beta_scalar = beta_dtype(beta)
1808
+
1809
+ # compile gemm kernel
1810
+ compiled_gemm = cute.compile(
1811
+ gemm, mA, mB, mD, mC, alpha_scalar, beta_scalar, max_active_clusters, stream
1812
+ )
1813
+
1814
+ if not skip_ref_check:
1815
+ # execution
1816
+ compiled_gemm(mA, mB, mD, mC, alpha_scalar, beta_scalar, max_active_clusters, stream)
1817
+
1818
+ torch.cuda.synchronize()
1819
+
1820
+ # Ref check
1821
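+ # Per batch l this computes ref[:, :, l] = a[:, :, l] @ b[:, :, l].T, i.e. A @ B^T with K contracted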
+ ref = torch.einsum("mkl,nkl->mnl", a, b)
1822
+ ref = alpha * ref
1823
+ if c is not None:
1824
+ ref = ref + beta * c
1825
+ ref = ref.cpu()
1826
+
1827
+ if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
1828
+ # m major: (l, n, m) -> (m, n, l)
1829
+ # n major: (l, m, n) -> (m, n, l)
1830
+ permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
1831
+ shape = (l, m, n) if d_major == "n" else (l, n, m)
1832
+ f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
1833
+ shape,
1834
+ torch.uint8,
1835
+ permute_order=permute_order,
1836
+ init_type=cutlass_torch.TensorInitType.SKIP,
1837
+ ).cuda()
1838
+ # Create dtype cute tensor (gpu)
1839
+ ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
1840
+ leading_dim=(1 if d_major == "n" else 0)
1841
+ )
1842
+ ref_d_tensor.element_type = d_dtype
1843
+ ref_d_tensor = cutlass_torch.convert_cute_tensor(
1844
+ ref,
1845
+ ref_d_tensor,
1846
+ d_dtype,
1847
+ is_dynamic_layout=True,
1848
+ )
1849
+ ref_d = f8_torch_tensor.cpu()
1850
+ else:
1851
+ ref_d = ref.to(cutlass_torch.dtype(d_dtype))
1852
+
1853
+ torch.testing.assert_close(d_torch.cpu(), ref_d, atol=tolerance, rtol=1e-03)
1854
+
1855
+ from triton.testing import do_bench
1856
+
1857
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
1858
+
1859
+ flops = 2 * m * n * k * l
1860
+
1861
+ repeats = iterations
1862
+ warmup = warmup_iterations
1863
+
1864
+ import time
1865
+
1866
+ time.sleep(0.5)
1867
+ if a_dtype.width == 8:
1868
+ assert l == 1  # the torch._scaled_mm reference path only supports a single batch
1869
+ scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda")
1870
+ fn_cublas = lambda: torch._scaled_mm(
1871
+ a_torch[:, :, 0],
1872
+ b_torch[:, :, 0].mT,
1873
+ scale_a=scale_ab,
1874
+ scale_b=scale_ab,
1875
+ out_dtype=torch.bfloat16,
1876
+ use_fast_accum=fp8_fast_accum,
1877
+ )
1878
+ else:
1879
+ if c_torch is None:
1880
+ fn_cublas = lambda: alpha * torch.matmul(
1881
+ a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT
1882
+ )
1883
+ else:
1884
+ c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32
1885
+ fn_cublas = lambda: alpha * torch.matmul(
1886
+ a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT
1887
+ ) + beta * c_torch_convert.permute(2, 0, 1)
1888
+ timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
1889
+ tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
1890
+ print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
1891
+
1892
+ time.sleep(0.5)
1893
+ fn = lambda: compiled_gemm(
1894
+ mA, mB, mD, mC, alpha_scalar, beta_scalar, max_active_clusters, current_stream
1895
+ )
1896
+ timing = do_bench(fn, warmup=warmup, rep=repeats)
1897
+ tflops = flops / (timing * 1e9) # Convert to TFlops
1898
+ print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
1899
+
1900
+ time.sleep(0.5)
1901
+ timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
1902
+ tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
1903
+ print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
1904
+
1905
+
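+ # Example of calling run() directly, bypassing argparse (all values are illustrative):
+ #   run(
+ #       mnkl=(4096, 4096, 4096, 1),
+ #       a_dtype=cutlass.BFloat16, b_dtype=cutlass.BFloat16, d_dtype=cutlass.BFloat16,
+ #       c_dtype=None, acc_dtype=cutlass.Float32,
+ #       a_major="k", b_major="k", d_major="n", c_major="n",
+ #       tile_shape_mnk=(128, 256, 64), cluster_shape_mn=(2, 1),
+ #       tolerance=1e-02, warmup_iterations=5, iterations=100,
+ #       skip_ref_check=False, persistent=True, pingpong=False, fp8_fast_accum=False,
+ #   )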
1906
+ @lru_cache(maxsize=32)
1907
+ def get_max_active_clusters_cached(cluster_shape_mn_tuple):
1908
+ cluster_shape_mn = tuple(cluster_shape_mn_tuple)
1909
+ return cutlass.utils.HardwareInfo().get_max_active_clusters(
1910
+ cluster_shape_mn[0] * cluster_shape_mn[1]
1911
+ )
1912
+
1913
+
1914
+ def _symmetric_dense_gemm(
1915
+ a: torch.Tensor,
1916
+ b: torch.Tensor,
1917
+ c: torch.Tensor = None,
1918
+ alpha: float = 1.0,
1919
+ beta: float = 1.0,
1920
+ ) -> torch.Tensor:
1921
+ assert a.dim() == 3 and a.is_cuda, "A must be a 3D CUDA tensor"
1922
+ M, K, L = a.shape
1923
+ assert a.dtype in torch2cute_dtype_map, "Unsupported dtype for A"
1924
+ assert b.dim() == 3 and b.is_cuda, "B must be a 3D CUDA tensor"
1925
+ assert b.shape == (M, K, L), "B must have the same shape as A"
1926
+ assert b.dtype in torch2cute_dtype_map, "Unsupported dtype for B"
1927
+ if c is not None:
1928
+ assert c.shape == (M, M, L) and c.is_cuda, "C must be (M,M,L) CUDA"
1929
+
1930
+ dtype = a.dtype
1931
+ cutlass_dtype = torch2cute_dtype_map[dtype]
1932
+
1933
+ def make_cute_tensor(x: torch.Tensor):
1934
+ x_fp32 = x.to(torch.float32)
1935
+ t = from_dlpack(x, assumed_align=16)
1936
+ t.element_type = cutlass_dtype
1937
+ if x.stride()[0] == 1:
1938
+ leading_dim = 0
1939
+ elif x.stride()[1] == 1:
1940
+ leading_dim = 1
1941
+ else:
1942
+ raise ValueError(
1943
+ f"Input tesnor should have stride 1 along either dim 0 or 1. Strides: {x.stride()}"
1944
+ )
1945
+ t = t.mark_layout_dynamic(leading_dim=leading_dim)
1946
+ return cutlass_torch.convert_cute_tensor(x_fp32, t, cutlass_dtype, is_dynamic_layout=True)
1947
+
1948
+ mA = make_cute_tensor(a)
1949
+ mB = make_cute_tensor(b)
1950
+
1951
+ if c is not None:
1952
+ mC = make_cute_tensor(c)
1953
+ else:
1954
+ mC = None
1955
+
1956
+ # Kernel requires output tensor with stride 1 along dim 0 or 1 as opposed to dim 2
1957
+ d = torch.empty_strided((M, M, L), (M, 1, M * M), dtype=a.dtype, device=a.device)
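+ # With strides (M, 1, M * M) each (M, M) batch slice is a contiguous n-major block
+ # (stride 1 along dim 1), which make_cute_tensor detects via x.stride()[1] == 1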
1958
+ mD = make_cute_tensor(d)
1959
+
1960
+ tile_shape_mnk = (128, 256, 64)
1961
+ cluster_shape_mn = (2, 1)
1962
+ persistent = True
1963
+ cluster_shape_mnk = (*cluster_shape_mn, 1)
1964
+
1965
+ compile_key = (
1966
+ cutlass_dtype,
1967
+ tile_shape_mnk,
1968
+ cluster_shape_mnk,
1969
+ c is not None,
1970
+ persistent,
1971
+ (M, K, L),
1972
+ (
1973
+ a.stride(1) == 1,
1974
+ b.stride(1) == 1,
1975
+ c.stride(1) == 1 if c is not None else None,
1976
+ d.stride(1) == 1,
1977
+ ),
1978
+ )
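+ # The key captures everything the compiled kernel is specialized on (dtype, tile/cluster
+ # shapes, presence of C, persistence, problem shape, and which dim of each tensor is
+ # contiguous), so a new (M, K, L) problem size triggers a fresh cute.compile.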
1979
+
1980
+ if persistent:
1981
+ max_active = get_max_active_clusters_cached(cluster_shape_mn)
1982
+ else:
1983
+ max_active = 0
1984
+
1985
+ alpha_s = cutlass.Float32(alpha)
1986
+ beta_s = cutlass.Float32(beta)
1987
+
1988
+ cache = _symmetric_dense_gemm.compile_cache
1989
+ if compile_key not in cache:
1990
+ gemm = HopperSymmetricGemmKernel(
1991
+ acc_dtype=cutlass.Float32,
1992
+ a_dtype=cutlass_dtype,
1993
+ tile_shape_mnk=tile_shape_mnk,
1994
+ cluster_shape_mnk=cluster_shape_mnk,
1995
+ pingpong=False,
1996
+ is_persistent=persistent,
1997
+ fp8_fast_accum=False,
1998
+ )
1999
+ cache[compile_key] = cute.compile(
2000
+ gemm,
2001
+ mA,
2002
+ mB,
2003
+ mD,
2004
+ mC,
2005
+ alpha_s,
2006
+ beta_s,
2007
+ max_active,
2008
+ cuda.CUstream(torch.cuda.current_stream().cuda_stream),
2009
+ )
2010
+ cache[compile_key](
2011
+ mA,
2012
+ mB,
2013
+ mD,
2014
+ mC,
2015
+ alpha_s,
2016
+ beta_s,
2017
+ max_active,
2018
+ cuda.CUstream(torch.cuda.current_stream().cuda_stream),
2019
+ )
2020
+
2021
+ return d
2022
+
2023
+
2024
+ _symmetric_dense_gemm.compile_cache = {}
2025
+
2026
+
2027
+ def symmetric_dense_gemm(
2028
+ a: torch.Tensor,
2029
+ b: torch.Tensor,
2030
+ c: torch.Tensor = None,
2031
+ alpha: float = 1.0,
2032
+ beta: float = 1.0,
2033
+ ) -> torch.Tensor:
2034
+ """High-performance batched symmetric dense GEMM.
2035
+
2036
+ Computes D = alpha * A @ B^T + beta * C (contracting over K) using the symmetric dense GEMM kernel,
2037
+ with the assumption that A @ B^T is symmetric and C is symmetric.
2038
+
2039
+ Args:
2040
+ a: Input tensor A of shape (L, M, K) where L is batch dimension
2041
+ b: Input tensor B of shape (L, M, K) where L is batch dimension
2042
+ c: Optional tensor C of shape (L, M, M), defaults to None - MUST BE SYMMETRIC
2043
+ alpha: Scaling factor for A @ B, defaults to 1.0
2044
+ beta: Scaling factor for C (ignored if c is None), defaults to 1.0
2045
+
2046
+ Returns:
2047
+ Symmetric output tensor D of shape (L, M, M)
2048
+ """
2049
+ a_permuted = a.permute(1, 2, 0)
2050
+ b_permuted = b.permute(1, 2, 0)
2051
+
2052
+ c_permuted = None
2053
+ if c is not None:
2054
+ c_permuted = c.permute(1, 2, 0)
2055
+
2056
+ d = _symmetric_dense_gemm(a_permuted, b_permuted, c_permuted, alpha, beta)
2057
+
2058
+ return d.permute(2, 0, 1)
2059
+
2060
+
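+ # Usage sketch for the public entry point (shapes and dtypes are illustrative only):
+ #   L, M, K = 4, 1024, 512
+ #   a = torch.randn(L, M, K, dtype=torch.bfloat16, device="cuda")
+ #   d = symmetric_dense_gemm(a, a)  # D = A @ A^T, symmetric by construction, shape (L, M, M)
+ #   d2 = symmetric_dense_gemm(a, a, c=d, alpha=2.0, beta=0.5)  # any symmetric (L, M, M) C works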
2061
+ if __name__ == "__main__":
2062
+ args = parse_arguments()
2063
+ run(
2064
+ args.mnkl,
2065
+ args.a_dtype,
2066
+ args.b_dtype,
2067
+ args.d_dtype,
2068
+ args.c_dtype,
2069
+ args.acc_dtype,
2070
+ args.a_major,
2071
+ args.b_major,
2072
+ args.d_major,
2073
+ args.c_major,
2074
+ args.tile_shape_mnk,
2075
+ args.cluster_shape_mn,
2076
+ args.tolerance,
2077
+ args.warmup_iterations,
2078
+ args.iterations,
2079
+ args.skip_ref_check,
2080
+ args.persistent,
2081
+ args.pingpong,
2082
+ args.fp8_fast_accum,
2083
+ alpha=args.alpha,
2084
+ beta=args.beta,
2085
+ alpha_dtype=args.alpha_dtype,
2086
+ beta_dtype=args.beta_dtype,
2087
+ )
2088
+ print("PASS")