quack-kernels 0.1.9__py3-none-any.whl → 0.1.10__py3-none-any.whl

@@ -0,0 +1,1430 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+
7
+ # 1. Redistributions of source code must retain the above copyright notice, this
8
+ # list of conditions and the following disclaimer.
9
+
10
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
11
+ # this list of conditions and the following disclaimer in the documentation
12
+ # and/or other materials provided with the distribution.
13
+
14
+ # 3. Neither the name of the copyright holder nor the names of its
15
+ # contributors may be used to endorse or promote products derived from
16
+ # this software without specific prior written permission.
17
+
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+ import argparse
30
+ from typing import Tuple, Type
31
+ import math
32
+ import cuda.bindings.driver as cuda
33
+
34
+ import torch
35
+
36
+ import cutlass
37
+ import cutlass.cute as cute
38
+ import cutlass.cute.testing as testing
39
+ import cutlass.utils as utils
40
+ import cutlass.pipeline as pipeline
41
+ import cutlass.torch as cutlass_torch
42
+ from cutlass.cute.runtime import from_dlpack
43
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
44
+ import cutlass.utils.hopper_helpers as sm90_utils
45
+
46
+ """
47
+ A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
48
+ using CUTE DSL.
49
+ - Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
50
+ - Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
51
+ - Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
52
+
53
+ This GEMM kernel supports the following features:
54
+ - Utilizes Tensor Memory Access (TMA) for efficient memory operations
55
+ - Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations
56
+ - Implements TMA multicast with cluster to reduce L2 memory traffic
57
+ - Supports multi-stage pipeline to overlap computation and memory access
58
+
59
+ This GEMM works as follows:
60
+ 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
61
+ 2. Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction.
62
+ 3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations.
63
+
64
+ Hopper WGMMA instructions operate as follows:
65
+ - Read matrix A from SMEM
66
+ - Read matrix B from SMEM
67
+ - Perform the MMA operation and store the result in the accumulator (registers)
68
+
69
+ To run this example:
70
+
71
+ .. code-block:: bash
72
+
73
+ python examples/hopper/dense_gemm.py \
74
+ --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
75
+ --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
76
+ --d_dtype Float16 --acc_dtype Float32 \
77
+ --a_major k --b_major k --d_major n
78
+
79
+ The above example command computes a batched GEMM with M=8192, N=8192, K=8192,
80
+ batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape
81
+ is (1,1). The input, MMA accumulator, and output data types are set to fp16, fp32,
82
+ and fp16, respectively.
83
+
84
+ To collect performance with NCU profiler:
85
+
86
+ .. code-block:: bash
87
+
88
+ ncu python examples/hopper/dense_gemm.py \
89
+ --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
90
+ --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
91
+ --d_dtype Float16 --acc_dtype Float32 \
92
+ --a_major k --b_major k --d_major n
93
+
94
+ Constraints:
95
+ * Supported input data types: fp16, fp8 (e4m3fn, e5m2)
96
+ * For fp16 types, A and B must have the same data type
97
+ * For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit
98
+ * Fp8 types only support k-major layout
99
+ * Only fp32 accumulation is supported in this example
100
+ * CTA tile shape M must be 64/128
101
+ * CTA tile shape N must be 64/128/256
102
+ * CTA tile shape K must be 64
103
+ * Cluster shape M/N must be positive and power of 2, total cluster size <= 4
104
+ * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
105
+ i.e., the number of elements must be a multiple of 8 for Float16 and of 16 for Float8.
106
+ * OOB tiles are not allowed when TMA store is disabled
107
+ """
108
+
109
+
110
+ # /////////////////////////////////////////////////////////////////////////////
111
+ # Helpers to parse args
112
+ # /////////////////////////////////////////////////////////////////////////////
113
+ def parse_comma_separated_ints(s: str):
114
+ try:
115
+ return tuple([int(x.strip()) for x in s.split(",")])
116
+ except ValueError:
117
+ raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")
118
+
119
+
120
+ def parse_arguments() -> argparse.Namespace:
121
+ parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.")
122
+
123
+ parser.add_argument(
124
+ "--mnkl",
125
+ type=parse_comma_separated_ints,
126
+ default=(4096, 4096, 4096, 1),
127
+ help="mnkl dimensions (comma-separated)",
128
+ )
129
+ parser.add_argument(
130
+ "--tile_shape_mnk",
131
+ type=parse_comma_separated_ints,
132
+ default=(128, 256, 64),
133
+ help="Cta tile shape (comma-separated)",
134
+ )
135
+ parser.add_argument(
136
+ "--cluster_shape_mn",
137
+ type=parse_comma_separated_ints,
138
+ choices=[(1, 1), (2, 1), (1, 2), (2, 2)],
139
+ default=(1, 1),
140
+ help="Cluster shape (comma-separated)",
141
+ )
142
+ parser.add_argument(
143
+ "--a_dtype",
144
+ type=cutlass.dtype,
145
+ default=cutlass.BFloat16,
146
+ )
147
+ parser.add_argument(
148
+ "--b_dtype",
149
+ type=cutlass.dtype,
150
+ default=cutlass.BFloat16,
151
+ )
152
+ parser.add_argument(
153
+ "--d_dtype",
154
+ type=cutlass.dtype,
155
+ default=cutlass.BFloat16,
156
+ )
157
+ parser.add_argument(
158
+ "--acc_dtype",
159
+ type=cutlass.dtype,
160
+ default=cutlass.Float32,
161
+ )
162
+ parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
163
+ parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
164
+ parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
165
+ parser.add_argument("--tolerance", type=float, default=1e-01, help="Tolerance for validation")
166
+ parser.add_argument("--warmup_iterations", type=int, default=0, help="Warmup iterations")
167
+ parser.add_argument(
168
+ "--iterations",
169
+ type=int,
170
+ default=1,
171
+ help="Number of iterations to run the kernel",
172
+ )
173
+ parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
174
+ parser.add_argument(
175
+ "--use_cold_l2",
176
+ action="store_true",
177
+ default=False,
178
+ help="Use circular buffer tensor sets to ensure L2 cold cache",
179
+ )
180
+
181
+ args = parser.parse_args()
182
+
183
+ if len(args.mnkl) != 4:
184
+ parser.error("--mnkl must contain exactly 4 values")
185
+ if len(args.tile_shape_mnk) != 3:
186
+ parser.error("--tile_shape_mnk must contain exactly 3 values")
187
+ if len(args.cluster_shape_mn) != 2:
188
+ parser.error("--cluster_shape_mn must contain exactly 2 values")
189
+
190
+ return args
191
+
192
+
193
+ # /////////////////////////////////////////////////////////////////////////////
194
+ # Host setup and device kernel launch
195
+ # /////////////////////////////////////////////////////////////////////////////
196
+
197
+
198
+ class HopperWgmmaGemmKernel:
199
+ """
200
+ This class implements batched matrix multiplication (C = A x B) with support for various data types
201
+ and architectural features specific to Hopper GPUs.
202
+
203
+ :param acc_dtype: Data type for accumulation during computation
204
+ :type acc_dtype: type[cutlass.Numeric]
205
+ :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
206
+ :type tile_shape_mnk: Tuple[int, int, int]
207
+ :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
208
+ :type cluster_shape_mnk: Tuple[int, int, int]
209
+
210
+ :note: Data type requirements:
211
+ - For 16-bit types: A and B must have the same data type
212
+ - For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit
213
+ - Float8 types only support k-major layout
214
+
215
+ :note: Supported data types:
216
+ - Float16
217
+ - BFloat16
218
+ - Float8E4M3FN/Float8E5M2
219
+
220
+ :note: Supported accumulation types:
221
+ - Float32 (for all floating point inputs)
222
+
223
+ :note: Constraints:
224
+ - CTA tile M must be 64/128
225
+ - CTA tile N must be 64/128/256
226
+ - CTA tile K must be 64
227
+ - Cluster shape M/N must be positive and power of 2, total cluster size <= 4
228
+
229
+ Example:
230
+ >>> gemm = HopperWgmmaGemmKernel(
231
+ ... acc_dtype=cutlass.Float32,
232
+ ... tile_shape_mnk=(128, 256, 64),
233
+ ... cluster_shape_mnk=(1, 1, 1)
234
+ ... )
235
+ >>> gemm(a_tensor, b_tensor, c_tensor, stream)
236
+ """
237
+
238
+ def __init__(
239
+ self,
240
+ acc_dtype: Type[cutlass.Numeric],
241
+ tile_shape_mnk: Tuple[int, int, int],
242
+ cluster_shape_mnk: Tuple[int, int, int],
243
+ ):
244
+ """
245
+ Initializes the configuration for a Hopper dense GEMM kernel.
246
+
247
+ This configuration includes data types for operands, tile shape, cluster configuration,
248
+ and thread layout.
249
+
250
+ :param acc_dtype: Data type for accumulation during computation
251
+ :type acc_dtype: type[cutlass.Numeric]
252
+ :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
253
+ :type tile_shape_mnk: Tuple[int, int, int]
254
+ :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
255
+ :type cluster_shape_mnk: Tuple[int, int, int]
256
+ """
257
+
258
+ self.acc_dtype = acc_dtype
259
+
260
+ self.cluster_shape_mnk = cluster_shape_mnk
261
+ self.tile_shape_mnk = tuple(tile_shape_mnk)
262
+ tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1]
263
+ # check the cta tile shape
264
+ # if tile_M not in [64, 128, 192, 256]:
265
+ # TODO: M=192 currently doesn't work
266
+ if tile_M not in [64, 128, 256]:
267
+ raise ValueError("CTA tile shape M must be 64/128/192/256")
268
+ if tile_M == 192: # special case
269
+ if not (tile_N % 32 == 0 and tile_N <= 288):
270
+ raise ValueError(
271
+ "If tile_m == 192, CTA tile shape N must be divisible by 32 and <= 288"
272
+ )
273
+ else:
274
+ if not ((tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)):
275
+ raise ValueError(
276
+ "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
277
+ )
278
+ if not self.tile_shape_mnk[2] % 16 == 0:
279
+ raise ValueError("CTA tile shape K must be divisible by 16")
280
+
281
+ if tile_M == 192: # Special case
282
+ atom_layout_m, atom_layout_n = 1, 2
283
+ else:
284
+ atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
285
+ atom_layout_n = 1
286
+ assert atom_layout_m in [1, 2] and atom_layout_n in [1, 2]
287
+ self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
288
+
289
+ self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
290
+ self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
291
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
292
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
293
+
294
+ self.occupancy = 1
295
+ self.mma_warp_groups = math.prod(self.atom_layout_mnk)
296
+ self.num_threads_per_warp_group = 128
297
+ self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
298
+ self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
299
+ self.num_mma_threads = self.mma_warp_groups * self.num_threads_per_warp_group
300
+
301
+ regs_per_thread = math.prod(self.tile_shape_mnk) // self.num_mma_threads
302
+ heavy_register_pressure = regs_per_thread >= 208
303
+ self.num_regs_load = 40 if not heavy_register_pressure else 24
304
+ self.num_regs_mma = 232 if not heavy_register_pressure else 240
305
+
306
+ self.ab_stage = None
307
+ self.epi_stage = None
308
+
309
+ self.a_smem_layout_staged = None
310
+ self.b_smem_layout_staged = None
311
+ self.epi_smem_layout_staged = None
312
+ self.epi_tile = None
313
+
314
+ self.shared_storage = None
315
+ self.buffer_align_bytes = 1024
316
+
317
+ def _setup_attributes(self):
318
+ """Set up configurations that are dependent on GEMM inputs
319
+
320
+ This method configures various attributes based on the input tensor properties
321
+ (data types, leading dimensions) and kernel settings:
322
+ - Configuring tiled MMA
323
+ - Computing MMA/cluster/tile shapes
324
+ - Computing cluster layout
325
+ - Computing multicast CTAs for A/B
326
+ - Computing epilogue subtile
327
+ - Setting up A/B/C stage counts in shared memory
328
+ - Computing A/B/C shared memory layout
329
+ """
330
+
331
+ self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
332
+
333
+ is_cooperative = math.prod(self.atom_layout_mnk) > 1
334
+ self.epi_tile = self._sm90_compute_tile_shape_or_override(
335
+ self.tile_shape_mnk, self.d_dtype, is_cooperative=is_cooperative
336
+ )
337
+
338
+ # Compute stage before compute smem layout
339
+ self.ab_stage, self.epi_stage = self._compute_stages(
340
+ self.tile_shape_mnk,
341
+ self.a_dtype,
342
+ self.b_dtype,
343
+ self.smem_capacity,
344
+ self.occupancy,
345
+ )
346
+
347
+ (
348
+ self.a_smem_layout_staged,
349
+ self.b_smem_layout_staged,
350
+ self.epi_smem_layout_staged,
351
+ ) = self._make_smem_layouts(
352
+ self.tile_shape_mnk,
353
+ self.epi_tile,
354
+ self.a_dtype,
355
+ self.a_layout,
356
+ self.b_dtype,
357
+ self.b_layout,
358
+ self.ab_stage,
359
+ self.d_dtype,
360
+ self.d_layout,
361
+ self.epi_stage,
362
+ )
363
+
364
+ @cute.jit
365
+ def __call__(
366
+ self,
367
+ mA: cute.Tensor,
368
+ mB: cute.Tensor,
369
+ mD: cute.Tensor,
370
+ stream: cuda.CUstream,
371
+ ):
372
+ """Execute the GEMM operation in steps:
373
+ - Setup static attributes
374
+ - Setup TMA load/store atoms and tensors
375
+ - Compute grid size
376
+ - Define shared storage for kernel
377
+ - Launch the kernel synchronously
378
+
379
+ :param mA: Input tensor A
380
+ :type mA: cute.Tensor
381
+ :param mB: Input tensor B
382
+ :type mB: cute.Tensor
383
+ :param mD: Output tensor D
384
+ :type mD: cute.Tensor
385
+ :param stream: CUDA stream for asynchronous execution
386
+ :type stream: cuda.CUstream
387
+ """
388
+
389
+ # setup static attributes before smem/grid/tma computation
390
+ self.a_dtype = mA.element_type
391
+ self.b_dtype = mB.element_type
392
+ self.d_dtype = mD.element_type
393
+ self.a_layout = utils.LayoutEnum.from_tensor(mA)
394
+ self.b_layout = utils.LayoutEnum.from_tensor(mB)
395
+ self.d_layout = utils.LayoutEnum.from_tensor(mD)
396
+
397
+ if cutlass.const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
398
+ raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
399
+ if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width):
400
+ raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
401
+ if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
402
+ raise TypeError("a_dtype should be float16 or float8")
403
+
404
+ self._setup_attributes()
405
+
406
+ tiled_mma = sm90_utils.make_trivial_tiled_mma(
407
+ self.a_dtype,
408
+ self.b_dtype,
409
+ self.a_layout.sm90_mma_major_mode(),
410
+ self.b_layout.sm90_mma_major_mode(),
411
+ self.acc_dtype,
412
+ self.atom_layout_mnk,
413
+ tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
414
+ )
415
+
416
+ tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
417
+ mA,
418
+ self.a_smem_layout_staged,
419
+ (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
420
+ self.cluster_shape_mnk[1],
421
+ )
422
+
423
+ tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
424
+ mB,
425
+ self.b_smem_layout_staged,
426
+ (self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
427
+ self.cluster_shape_mnk[0],
428
+ )
429
+
430
+ tma_atom_d, tma_tensor_d = self._make_tma_store_atoms_and_tensors(
431
+ mD,
432
+ self.epi_smem_layout_staged,
433
+ self.epi_tile,
434
+ )
435
+
436
+ grid = self._compute_grid(mD, self.tile_shape_mnk, self.cluster_shape_mnk)
437
+
438
+ @cute.struct
439
+ class SharedStorage:
440
+ mainloop_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
441
+ sA: cute.struct.Align[
442
+ cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
443
+ self.buffer_align_bytes,
444
+ ]
445
+ sB: cute.struct.Align[
446
+ cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)],
447
+ self.buffer_align_bytes,
448
+ ]
449
+
450
+ self.shared_storage = SharedStorage
451
+
452
+ # Launch the kernel synchronously
453
+ self.kernel(
454
+ tma_atom_a,
455
+ tma_tensor_a,
456
+ tma_atom_b,
457
+ tma_tensor_b,
458
+ tma_atom_d,
459
+ tma_tensor_d,
460
+ tiled_mma,
461
+ self.cta_layout_mnk,
462
+ self.a_smem_layout_staged,
463
+ self.b_smem_layout_staged,
464
+ self.epi_smem_layout_staged,
465
+ ).launch(
466
+ grid=grid,
467
+ block=[self.threads_per_cta, 1, 1],
468
+ cluster=self.cluster_shape_mnk,
469
+ smem=self.shared_storage.size_in_bytes(),
470
+ stream=stream,
471
+ min_blocks_per_mp=1,
472
+ )
473
+ return
474
+
475
+ # GPU device kernel
476
+ @cute.kernel
477
+ def kernel(
478
+ self,
479
+ tma_atom_a: cute.CopyAtom,
480
+ mA_mkl: cute.Tensor,
481
+ tma_atom_b: cute.CopyAtom,
482
+ mB_nkl: cute.Tensor,
483
+ tma_atom_d: cute.CopyAtom,
484
+ mD_mnl: cute.Tensor,
485
+ tiled_mma: cute.TiledMma,
486
+ cta_layout_mnk: cute.Layout,
487
+ a_smem_layout_staged: cute.ComposedLayout,
488
+ b_smem_layout_staged: cute.ComposedLayout,
489
+ epi_smem_layout_staged: cute.ComposedLayout,
490
+ ):
491
+ """
492
+ GPU device kernel performing the batched GEMM computation.
493
+
494
+ :param tma_atom_a: TMA copy atom for A tensor
495
+ :type tma_atom_a: cute.CopyAtom
496
+ :param mA_mkl: Input tensor A
497
+ :type mA_mkl: cute.Tensor
498
+ :param tma_atom_b: TMA copy atom for B tensor
499
+ :type tma_atom_b: cute.CopyAtom
500
+ :param mB_nkl: Input tensor B
501
+ :type mB_nkl: cute.Tensor
502
+ :param tma_atom_d: TMA copy atom for D tensor
503
+ :type tma_atom_d: cute.CopyAtom
504
+ :param mD_mnl: Output tensor D
505
+ :type mD_mnl: cute.Tensor
506
+ :param tiled_mma: Tiled MMA object
507
+ :type tiled_mma: cute.TiledMma
508
+ :param cta_layout_mnk: CTA layout
509
+ :type cta_layout_mnk: cute.Layout
510
+ :param a_smem_layout_staged: Shared memory layout for A
511
+ :type a_smem_layout_staged: cute.ComposedLayout
512
+ :param b_smem_layout_staged: Shared memory layout for B
513
+ :type b_smem_layout_staged: cute.ComposedLayout
514
+ :param epi_smem_layout_staged: Shared memory layout for epilogue
515
+ :type epi_smem_layout_staged: cute.ComposedLayout
516
+ """
517
+
518
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
519
+
520
+ # /////////////////////////////////////////////////////////////////////////////
521
+ # Prefetch TMA descriptors
522
+ # /////////////////////////////////////////////////////////////////////////////
523
+ # if warp_idx == 0:
524
+ if warp_idx == self.mma_warp_groups * 4:
525
+ cpasync.prefetch_descriptor(tma_atom_a)
526
+ cpasync.prefetch_descriptor(tma_atom_b)
527
+ cpasync.prefetch_descriptor(tma_atom_d)
528
+
529
+ a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
530
+ b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
531
+ tma_copy_bytes = cute.size_in_bytes(self.a_dtype, a_smem_layout) + cute.size_in_bytes(
532
+ self.b_dtype, b_smem_layout
533
+ )
534
+
535
+ # /////////////////////////////////////////////////////////////////////////////
536
+ # Alloc and init AB full/empty + ACC full mbar (pipeline)
537
+ # /////////////////////////////////////////////////////////////////////////////
538
+ smem = cutlass.utils.SmemAllocator()
539
+ storage = smem.allocate(self.shared_storage)
540
+
541
+ # Threads/warps participating in this pipeline
542
+ mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
543
+ # Each warp contributes to the arrive count, scaled by the mcast size
544
+ mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
545
+ consumer_arrive_cnt = mcast_size * (self.num_mma_threads // cute.arch.WARP_SIZE)
546
+ mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
547
+ pipeline.Agent.Thread, consumer_arrive_cnt
548
+ )
549
+
550
+ cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
551
+ mainloop_pipeline = pipeline.PipelineTmaAsync.create(
552
+ barrier_storage=storage.mainloop_pipeline_array_ptr.data_ptr(),
553
+ num_stages=self.ab_stage,
554
+ producer_group=mainloop_pipeline_producer_group,
555
+ consumer_group=mainloop_pipeline_consumer_group,
556
+ tx_count=tma_copy_bytes,
557
+ cta_layout_vmnk=cta_layout_vmnk,
558
+ )
559
+
560
+ # ///////////////////////////////////////////////////////////////////////////////
561
+ # Generate smem tensor A/B
562
+ # ///////////////////////////////////////////////////////////////////////////////
563
+ sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
564
+ sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
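+ # The epilogue D buffers alias the A operand's SMEM (see the barrier before the
+ # epilogue below), so no separate SMEM is reserved for the store stage.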
565
+ sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
566
+ sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
567
+
568
+ # ///////////////////////////////////////////////////////////////////////////////
569
+ # Get cta/warp/thread idx
570
+ # ///////////////////////////////////////////////////////////////////////////////
571
+
572
+ cidx, cidy, _ = cute.arch.cluster_idx()
573
+ cdimx, cdimy, _ = cute.arch.cluster_dim()
574
+ cluster_id = cidx + cdimx * cidy
575
+
576
+ # CTA Swizzle to promote L2 data reuse
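+ # Clusters are rasterized in group_size_m-tall, column-major blocks: consecutive
+ # cluster ids first walk down M within a group (sharing the same B tiles), then step
+ # across N (re-using the same A tiles), which keeps the L2 working set small.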
577
+ group_size_m = 8
578
+ s_shape = (
579
+ (group_size_m, cdimx // group_size_m),
580
+ cdimy,
581
+ )
582
+ s_stride = ((1, cdimy * group_size_m), group_size_m)
583
+ s_layout = cute.make_layout(s_shape, stride=s_stride)
584
+ num_reg_cids = cute.size(s_shape)
585
+ cid_m, cid_n = s_layout.get_flat_coord(cluster_id % num_reg_cids)
586
+
587
+ # Deal with the tail part
588
+ if cluster_id >= num_reg_cids:
589
+ tail_size_m = cdimx % group_size_m
590
+ tail_layout = cute.make_layout((tail_size_m, cdimy), stride=(1, tail_size_m))
591
+ tail_cid = cluster_id - num_reg_cids
592
+ tail_cid_m, tail_cid_n = tail_layout.get_flat_coord(tail_cid)
593
+ cid_m = cute.size(s_shape, mode=[0]) + tail_cid_m
594
+ cid_n = tail_cid_n
595
+
596
+ # Get the pid from cluster id
597
+ bidx_in_cluster = cute.arch.block_in_cluster_idx()
598
+ pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0]
599
+ pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1]
600
+
601
+ _, _, bidz = cute.arch.block_idx()
602
+ tile_coord_mnkl = (pid_m, pid_n, None, bidz)
603
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
604
+ cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
605
+
606
+ k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
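+ # Warp specialization: warps [0, mma_warp_groups * 4) form the consumer (WGMMA)
+ # warp groups, while the last warp group shrinks its register allotment and its
+ # first warp acts as the TMA producer; both sides rebalance registers via
+ # warpgroup_reg_dealloc/warpgroup_reg_alloc below.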
607
+
608
+ if warp_idx >= self.mma_warp_groups * 4:
609
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
610
+ if warp_idx == self.mma_warp_groups * 4:
611
+ # ///////////////////////////////////////////////////////////////////////////////
612
+ # Get mcast mask
613
+ # ///////////////////////////////////////////////////////////////////////////////
614
+ a_mcast_mask = cute.make_layout_image_mask(
615
+ cta_layout_mnk, cluster_coord_mnk, mode=1
616
+ )
617
+ b_mcast_mask = cute.make_layout_image_mask(
618
+ cta_layout_mnk, cluster_coord_mnk, mode=0
619
+ )
620
+ a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
621
+ b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
622
+ mainloop_producer_state = pipeline.make_pipeline_state(
623
+ pipeline.PipelineUserType.Producer, self.ab_stage
624
+ )
625
+ # ///////////////////////////////////////////////////////////////////////////////
626
+ # Local_tile partition global tensors
627
+ # ///////////////////////////////////////////////////////////////////////////////
628
+ # (bM, bK, RestK)
629
+ gA_mkl = cute.local_tile(
630
+ mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1)
631
+ )
632
+ # (bN, bK, RestK)
633
+ gB_nkl = cute.local_tile(
634
+ mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
635
+ )
636
+ # //////////////////////////////////////////////////////////////////////////////
637
+ # Partition shared tensor for TMA load A/B
638
+ # //////////////////////////////////////////////////////////////////////////////
639
+ # TMA load A partition_S/D
640
+ a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
641
+ a_cta_crd = cluster_coord_mnk[1]
642
+ tAsA, tAgA_mkl = cpasync.tma_partition(
643
+ tma_atom_a,
644
+ a_cta_crd,
645
+ a_cta_layout,
646
+ cute.group_modes(sA, 0, 2),
647
+ cute.group_modes(gA_mkl, 0, 2),
648
+ )
649
+ # TMA load B partition_S/D
650
+ b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
651
+ b_cta_crd = cluster_coord_mnk[0]
652
+ tBsB, tBgB_nkl = cpasync.tma_partition(
653
+ tma_atom_b,
654
+ b_cta_crd,
655
+ b_cta_layout,
656
+ cute.group_modes(sB, 0, 2),
657
+ cute.group_modes(gB_nkl, 0, 2),
658
+ )
659
+ # /////////////////////////////////////////////////////////////////////////////
660
+ # TMA load
661
+ # /////////////////////////////////////////////////////////////////////////////
662
+ for k_tile in cutlass.range(k_tile_cnt, unroll=1):
663
+ # Wait for A/B buffers to be empty before loading into them
664
+ # Also sets the transaction barrier for the A/B buffers
665
+ mainloop_pipeline.producer_acquire(mainloop_producer_state)
666
+ # /////////////////////////////////////////////////////////////////////////////
667
+ # TMA load A/B
668
+ # /////////////////////////////////////////////////////////////////////////////
669
+ cute.copy(
670
+ tma_atom_a,
671
+ tAgA_mkl[None, k_tile],
672
+ tAsA[None, mainloop_producer_state.index],
673
+ tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state),
674
+ mcast_mask=a_mcast_mask,
675
+ )
676
+ cute.copy(
677
+ tma_atom_b,
678
+ tBgB_nkl[None, k_tile],
679
+ tBsB[None, mainloop_producer_state.index],
680
+ tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state),
681
+ mcast_mask=b_mcast_mask,
682
+ )
683
+ # Mainloop pipeline's producer commit is a NOP
684
+ mainloop_pipeline.producer_commit(mainloop_producer_state)
685
+ mainloop_producer_state.advance()
686
+ mainloop_pipeline.producer_tail(mainloop_producer_state)
687
+
688
+ if warp_idx < self.mma_warp_groups * 4:
689
+ cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
690
+ # //////////////////////////////////////////////////////////////////////////////
691
+ # Partition global tensor for TiledMMA_A/B/C
692
+ # //////////////////////////////////////////////////////////////////////////////
693
+ tidx, _, _ = cute.arch.thread_idx()
694
+ warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
695
+ warp_group_thread_layout = cute.make_layout(
696
+ self.mma_warp_groups, stride=self.num_threads_per_warp_group
697
+ )
698
+ thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx))
699
+
700
+ # //////////////////////////////////////////////////////////////////////////////
701
+ # Make fragments
702
+ # //////////////////////////////////////////////////////////////////////////////
703
+ tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
704
+ tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))
705
+
706
+ acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
707
+ acc = cute.make_fragment(acc_shape, self.acc_dtype)
708
+
709
+ mainloop_consumer_read_state = pipeline.make_pipeline_state(
710
+ pipeline.PipelineUserType.Consumer, self.ab_stage
711
+ )
712
+ mainloop_consumer_release_state = pipeline.make_pipeline_state(
713
+ pipeline.PipelineUserType.Consumer, self.ab_stage
714
+ )
715
+
716
+ # /////////////////////////////////////////////////////////////////////////////
717
+ # Prologue MMAs
718
+ # /////////////////////////////////////////////////////////////////////////////
719
+ k_pipe_mmas = 1
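+ # One k-tile of WGMMAs is issued up front so that, in the steady-state loop below,
+ # wait_group(k_pipe_mmas) overlaps the previous tile's WGMMAs with the next TMA
+ # load; buffer release therefore trails the read state by k_pipe_mmas stages.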
720
+ peek_ab_full_status = cutlass.Boolean(1)
721
+ if mainloop_consumer_read_state.count < k_tile_cnt:
722
+ peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
723
+ mainloop_consumer_read_state
724
+ )
725
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
726
+ num_k_blocks = cute.size(tCrA, mode=[2])
727
+ for k_tile in cutlass.range_constexpr(k_pipe_mmas):
728
+ # Wait for A/B buffer to be ready
729
+ mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status)
730
+ warpgroup.fence()
731
+ for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
732
+ k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index)
733
+ cute.gemm(tiled_mma, acc, tCrA[k_block_coord], tCrB[k_block_coord], acc)
734
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
735
+ warpgroup.commit_group()
736
+ mainloop_consumer_read_state.advance()
737
+ peek_ab_full_status = cutlass.Boolean(1)
738
+ if mainloop_consumer_read_state.count < k_tile_cnt:
739
+ peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
740
+ mainloop_consumer_read_state
741
+ )
742
+
743
+ # /////////////////////////////////////////////////////////////////////////////
744
+ # MAINLOOP
745
+ # /////////////////////////////////////////////////////////////////////////////
746
+ for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, unroll=1):
747
+ # Wait for TMA copies to complete
748
+ mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status)
749
+ # WGMMA
750
+ warpgroup.fence()
751
+ for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
752
+ k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index)
753
+ cute.gemm(tiled_mma, acc, tCrA[k_block_coord], tCrB[k_block_coord], acc)
754
+ warpgroup.commit_group()
755
+ # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
756
+ warpgroup.wait_group(k_pipe_mmas)
757
+ mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
758
+ mainloop_consumer_read_state.advance()
759
+ mainloop_consumer_release_state.advance()
760
+ peek_ab_full_status = cutlass.Boolean(1)
761
+ if mainloop_consumer_read_state.count < k_tile_cnt:
762
+ peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
763
+ mainloop_consumer_read_state
764
+ )
765
+ warpgroup.wait_group(0)
766
+ for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
767
+ mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
768
+ mainloop_consumer_release_state.advance()
769
+
770
+ # /////////////////////////////////////////////////////////////////////////////
771
+ # EPILOGUE
772
+ # /////////////////////////////////////////////////////////////////////////////
773
+
774
+ # Wait for all warp groups in the thread block to finish, because smem for tensor A in
775
+ # the mainloop is reused in the epilogue.
776
+ cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
777
+
778
+ copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
779
+ self.d_layout,
780
+ elem_ty_d=self.d_dtype,
781
+ elem_ty_acc=self.acc_dtype,
782
+ )
783
+ copy_atom_D = cute.make_copy_atom(
784
+ warp.StMatrix8x8x16bOp(self.d_layout.is_m_major_c(), 4),
785
+ self.d_dtype,
786
+ )
787
+ tiled_copy_D_Atom = cute.make_tiled_copy_C_atom(copy_atom_D, tiled_mma)
788
+ tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_D_Atom)
789
+ # (R2S, R2S_M, R2S_N, PIPE_D)
790
+ tRS_sD = tiled_copy_r2s.get_slice(tidx).partition_D(sD)
791
+ # (R2S, R2S_M, R2S_N)
792
+ tRS_rAcc = tiled_copy_r2s.retile(acc)
793
+
794
+ # (bM, bN)
795
+ gD_mnl = cute.local_tile(
796
+ mD_mnl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None)
797
+ )
798
+ tcgc_for_tma_partition = cute.zipped_divide(gD_mnl, self.epi_tile)
799
+ bSG_sD, bSG_gD = cpasync.tma_partition(
800
+ tma_atom_d,
801
+ 0,
802
+ cute.make_layout(1),
803
+ cute.group_modes(sD, 0, 2),
804
+ tcgc_for_tma_partition,
805
+ )
806
+
807
+ epi_tile_num = cutlass.const_expr(cute.size(tcgc_for_tma_partition, mode=[1]))
808
+ epi_tile_shape = tcgc_for_tma_partition.shape[1]
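+ # The (bM, bN) output tile is written out in epi_tile-sized subtiles: each subtile
+ # is converted to d_dtype in registers, staged through one of the epi_stage SMEM
+ # buffers (round-robin via epi_buffer), and then TMA-stored to GMEM by warp 0.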
809
+
810
+ for epi_idx in cutlass.range_constexpr(epi_tile_num):
811
+ # Copy from acc to D registers
812
+ tRS_rD = cute.make_fragment_like(tRS_sD[None, None, None, 0], self.acc_dtype)
813
+ for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
814
+ tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
815
+ # Type conversion
816
+ tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
817
+ tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
818
+ # Copy from D registers to shared memory
819
+ epi_buffer = epi_idx % cute.size(tRS_sD, mode=[3])
820
+ # cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
821
+ cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)])
822
+ cute.arch.fence_proxy(
823
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
824
+ )
825
+ # barrier for sync
826
+ cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
827
+ # Get the global memory coordinate for the current epi tile.
828
+ epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
829
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
830
+ # Copy from shared memory to global memory
831
+ if warp_idx == 0:
832
+ cute.copy(tma_atom_d, bSG_sD[(None, epi_buffer)], bSG_gD[(None, gmem_coord)])
833
+ cute.arch.cp_async_bulk_commit_group()
834
+ # TODO: when moving to persistent maybe we always need this wait_group
835
+ if epi_idx >= self.epi_stage - 1:
836
+ cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
837
+ if epi_idx >= self.epi_stage - 1:
838
+ cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
839
+
840
+ if warp_idx == 0:
841
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
842
+
843
+ @staticmethod
844
+ def _compute_stages(
845
+ tile_shape_mnk: Tuple[int, int, int],
846
+ a_dtype: Type[cutlass.Numeric],
847
+ b_dtype: Type[cutlass.Numeric],
848
+ smem_capacity: int,
849
+ occupancy: int,
850
+ ) -> Tuple[int, int]:
851
+ """Computes the number of stages for A/B/C operands based on heuristics.
852
+
853
+ :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
854
+ :type tile_shape_mnk: Tuple[int, int, int]
855
+ :param a_dtype: Data type of operand A.
856
+ :type a_dtype: type[cutlass.Numeric]
857
+ :param b_dtype: Data type of operand B.
858
+ :type b_dtype: type[cutlass.Numeric]
859
+ :param smem_capacity: Total available shared memory capacity in bytes.
860
+ :type smem_capacity: int
861
+ :param occupancy: Target number of CTAs per SM (occupancy).
862
+ :type occupancy: int
863
+
864
+ :return: A tuple containing the computed number of stages for:
865
+ (A/B operand stages, epilogue stages)
866
+ :rtype: Tuple[int, int]
867
+ """
868
+
869
+ # epi_stage = 4 if tile_shape_mnk[1] % 32 == 0 else 8
870
+ epi_stage = 4
871
+ # epi smem reuses the A/B smem buffers, so it adds no extra bytes here.
872
+ epi_bytes = 0
873
+
874
+ a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
875
+ b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
876
+ ab_bytes_per_stage = (
877
+ cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
878
+ )
879
+ mbar_helpers_bytes = 1024
880
+
881
+ ab_stage = (
882
+ (smem_capacity - occupancy * 1024) // occupancy - mbar_helpers_bytes - epi_bytes
883
+ ) // ab_bytes_per_stage
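+ # Example (128x256x64 tile, 16-bit A/B): ab_bytes_per_stage = 128*64*2 + 256*64*2
+ # = 48 KiB; assuming get_smem_capacity_in_bytes("sm_90") reports roughly 228 KiB
+ # and occupancy = 1, the formula above yields ab_stage = 4.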
884
+ return ab_stage, epi_stage
885
+
886
+ @staticmethod
887
+ def _sm90_compute_tile_shape_or_override(
888
+ tile_shape_mnk: Tuple[int, int, int],
889
+ element_type: Type[cutlass.Numeric],
890
+ is_cooperative: bool = False,
891
+ epi_tile_override: Tuple[int, int] | None = None,
892
+ ) -> Tuple[int, int]:
893
+ """Compute the epilogue tile shape or use override if provided.
894
+
895
+ :param tile_shape_mnk: CTA tile shape (M,N,K)
896
+ :type tile_shape_mnk: Tuple[int, int, int]
897
+ :param element_type: Data type of elements
898
+ :type element_type: type[cutlass.Numeric]
899
+ :param is_cooperative: Whether to use cooperative approach
900
+ :type is_cooperative: bool
901
+ :param epi_tile_override: Optional override for epilogue tile shape
902
+ :type epi_tile_override: Tuple[int, int] or None
903
+
904
+ :return: Computed epilogue tile shape
905
+ :rtype: Tuple[int, int]
906
+ """
907
+ if epi_tile_override is not None:
908
+ return epi_tile_override
909
+ if is_cooperative:
910
+ if cute.size(tile_shape_mnk, mode=[0]) == 192:
911
+ tile_m = 192
912
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]) // 2)
913
+ else:
914
+ tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
915
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
916
+ return (tile_m, tile_n)
917
+ else:
918
+ n_perf = 64 if element_type.width == 8 else 32
919
+ tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
920
+ tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
921
+ return (tile_m, tile_n)
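+ # Example: for a cooperative 128x256 tile the epilogue subtile is
+ # (gcd(128, 128), gcd(32, 256)) = (128, 32); the non-cooperative path caps the M
+ # extent at 64 and widens N to 64 for 8-bit outputs.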
922
+
923
+ @staticmethod
924
+ def _make_smem_layouts(
925
+ tile_shape_mnk: Tuple[int, int, int],
926
+ epi_tile: Tuple[int, int],
927
+ a_dtype: Type[cutlass.Numeric],
928
+ a_layout: utils.LayoutEnum,
929
+ b_dtype: Type[cutlass.Numeric],
930
+ b_layout: utils.LayoutEnum,
931
+ ab_stage: int,
932
+ d_dtype: Type[cutlass.Numeric],
933
+ d_layout: utils.LayoutEnum,
934
+ epi_stage: int,
935
+ ) -> Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]:
936
+ """Create shared memory layouts for A, B, and C tensors.
937
+
938
+ :param tile_shape_mnk: CTA tile shape (M,N,K)
939
+ :type tile_shape_mnk: Tuple[int, int, int]
940
+ :param epi_tile: Epilogue tile shape
941
+ :type epi_tile: Tuple[int, int]
942
+ :param a_dtype: Data type for matrix A
943
+ :type a_dtype: type[cutlass.Numeric]
944
+ :param a_layout: Layout enum for matrix A
945
+ :type a_layout: utils.LayoutEnum
946
+ :param b_dtype: Data type for matrix B
947
+ :type b_dtype: type[cutlass.Numeric]
948
+ :param b_layout: Layout enum for matrix B
949
+ :type b_layout: utils.LayoutEnum
950
+ :param ab_stage: Number of stages for A/B tensors
951
+ :type ab_stage: int
952
+ :param d_dtype: Data type for output matrix C
953
+ :type d_dtype: type[cutlass.Numeric]
954
+ :param d_layout: Layout enum for the output matrix C
955
+ :type d_layout: utils.LayoutEnum
956
+ :param epi_stage: Number of epilogue stages
957
+ :type epi_stage: int
958
+
959
+ :return: Tuple of shared memory layouts for A, B, and C
960
+ :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]
961
+ """
962
+ a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
963
+
964
+ a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
965
+ b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
966
+ a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
967
+ a_smem_layout_atom = warpgroup.make_smem_layout_atom(
968
+ sm90_utils.get_smem_layout_atom(
969
+ a_layout,
970
+ a_dtype,
971
+ a_major_mode_size,
972
+ ),
973
+ a_dtype,
974
+ )
975
+ a_smem_layout_staged = cute.tile_to_shape(
976
+ a_smem_layout_atom,
977
+ cute.append(a_smem_shape, ab_stage),
978
+ order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
979
+ )
980
+
981
+ b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None))
982
+
983
+ b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
984
+ b_smem_layout_atom = warpgroup.make_smem_layout_atom(
985
+ sm90_utils.get_smem_layout_atom(
986
+ b_layout,
987
+ b_dtype,
988
+ b_major_mode_size,
989
+ ),
990
+ b_dtype,
991
+ )
992
+ b_smem_layout_staged = cute.tile_to_shape(
993
+ b_smem_layout_atom,
994
+ cute.append(b_smem_shape, ab_stage),
995
+ order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
996
+ )
997
+
998
+ d_smem_shape = epi_tile
999
+ d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
1000
+ d_smem_layout_atom = warpgroup.make_smem_layout_atom(
1001
+ sm90_utils.get_smem_layout_atom(
1002
+ d_layout,
1003
+ d_dtype,
1004
+ d_major_mode_size,
1005
+ ),
1006
+ d_dtype,
1007
+ )
1008
+ epi_smem_layout_staged = cute.tile_to_shape(
1009
+ d_smem_layout_atom,
1010
+ cute.append(d_smem_shape, epi_stage),
1011
+ order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
1012
+ )
1013
+
1014
+ return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged
1015
+
1016
+ @staticmethod
1017
+ def _compute_grid(
1018
+ d: cute.Tensor,
1019
+ tile_shape_mnk: Tuple[int, int, int],
1020
+ cluster_shape_mnk: Tuple[int, int, int],
1021
+ ) -> Tuple[int, int, int]:
1022
+ """Compute grid shape for the output tensor C.
1023
+
1024
+ :param d: The output tensor C
1025
+ :type d: cute.Tensor
1026
+ :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
1027
+ :type tile_shape_mnk: Tuple[int, int, int]
1028
+ :param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions.
1029
+ :type cluster_shape_mnk: Tuple[int, int, int]
1030
+
1031
+ :return: Grid shape for kernel launch.
1032
+ :rtype: Tuple[int, int, int]
1033
+ """
1034
+
1035
+ c_shape = (tile_shape_mnk[0], tile_shape_mnk[1])
1036
+ gc = cute.zipped_divide(d, tiler=c_shape)
1037
+ clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnk)
1038
+ grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnk))
1039
+ return grid
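+ # Example: M = N = 8192, L = 1, a 128x256 tile, and a (1, 1, 1) cluster give
+ # clusters = (64, 32, 1) and hence grid = (64, 32, 1); with a (2, 1, 1) cluster the
+ # number of CTAs along M is first rounded up to a multiple of 2.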
1040
+
1041
+ @staticmethod
1042
+ def _make_tma_store_atoms_and_tensors(
1043
+ tensor_d: cute.Tensor,
1044
+ epi_smem_layout_staged: cute.ComposedLayout,
1045
+ epi_tile: Tuple[int, int],
1046
+ ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1047
+ """Create TMA atoms and tensors for C tensor storage.
1048
+
1049
+ :param tensor_d: Output tensor D
1050
+ :type tensor_d: cute.Tensor
1051
+ :param epi_smem_layout_staged: Shared memory layout for epilogue
1052
+ :type epi_smem_layout_staged: cute.ComposedLayout
1053
+ :param epi_tile: Epilogue tile shape
1054
+ :type epi_tile: Tuple[int, int]
1055
+
1056
+ :return: TMA atom and tensor for C
1057
+ :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1058
+ """
1059
+ epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
1060
+ c_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
1061
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
1062
+ cpasync.CopyBulkTensorTileS2GOp(),
1063
+ tensor_d,
1064
+ epi_smem_layout,
1065
+ c_cta_v_layout,
1066
+ )
1067
+
1068
+ return tma_atom_d, tma_tensor_d
1069
+
1070
+ @staticmethod
1071
+ def _make_tma_atoms_and_tensors(
1072
+ tensor: cute.Tensor,
1073
+ smem_layout_staged: cute.ComposedLayout,
1074
+ smem_tile: Tuple[int, int],
1075
+ mcast_dim: int,
1076
+ ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1077
+ """Create TMA atoms and tensors for input tensors.
1078
+
1079
+ :param tensor: Input tensor (A or B)
1080
+ :type tensor: cute.Tensor
1081
+ :param smem_layout_staged: Shared memory layout for the tensor
1082
+ :type smem_layout_staged: cute.ComposedLayout
1083
+ :param smem_tile: Shared memory tile shape
1084
+ :type smem_tile: Tuple[int, int]
1085
+ :param mcast_dim: Multicast dimension
1086
+ :type mcast_dim: int
1087
+
1088
+ :return: TMA atom and tensor
1089
+ :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1090
+ """
1091
+ op = (
1092
+ cpasync.CopyBulkTensorTileG2SOp()
1093
+ if mcast_dim == 1
1094
+ else cpasync.CopyBulkTensorTileG2SMulticastOp()
1095
+ )
1096
+
1097
+ smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
1098
+ tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
1099
+ op,
1100
+ tensor,
1101
+ smem_layout,
1102
+ smem_tile,
1103
+ num_multicast=mcast_dim,
1104
+ )
1105
+ return tma_atom, tma_tensor
1106
+
1107
+ @staticmethod
1108
+ def is_valid_dtypes(
1109
+ a_dtype: Type[cutlass.Numeric],
1110
+ b_dtype: Type[cutlass.Numeric],
1111
+ acc_dtype: Type[cutlass.Numeric],
1112
+ d_dtype: Type[cutlass.Numeric],
1113
+ a_major: str,
1114
+ b_major: str,
1115
+ ) -> bool:
1116
+ """
1117
+ Check if the dtypes are valid
1118
+
1119
+ :param a_dtype: The data type of tensor A
1120
+ :type a_dtype: Type[cutlass.Numeric]
1121
+ :param b_dtype: The data type of tensor B
1122
+ :type b_dtype: Type[cutlass.Numeric]
1123
+ :param acc_dtype: The data type of the accumulator
1124
+ :type acc_dtype: Type[cutlass.Numeric]
1125
+ :param d_dtype: The data type of the output tensor
1126
+ :type d_dtype: Type[cutlass.Numeric]
1127
+ :param a_major: major mode of tensor A
1128
+ :type a_major: str
1129
+ :param b_major: major mode of tensor B
1130
+ :type b_major: str
1131
+
1132
+ :return: True if the dtypes are valid, False otherwise
1133
+ :rtype: bool
1134
+ """
1135
+ is_valid = True
1136
+ # tested a_dtype
1137
+ if a_dtype not in {
1138
+ cutlass.Float16,
1139
+ cutlass.BFloat16,
1140
+ cutlass.Float8E4M3FN,
1141
+ cutlass.Float8E5M2,
1142
+ }:
1143
+ is_valid = False
1144
+ # tested b_dtype
1145
+ if b_dtype not in {
1146
+ cutlass.Float16,
1147
+ cutlass.BFloat16,
1148
+ cutlass.Float8E4M3FN,
1149
+ cutlass.Float8E5M2,
1150
+ }:
1151
+ is_valid = False
1152
+ # tested acc_dtype
1153
+ if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
1154
+ is_valid = False
1155
+ # tested d_dtype
1156
+ if d_dtype not in {
1157
+ cutlass.Float32,
1158
+ cutlass.Float16,
1159
+ cutlass.BFloat16,
1160
+ cutlass.Float8E4M3FN,
1161
+ cutlass.Float8E5M2,
1162
+ }:
1163
+ is_valid = False
1164
+ # make sure a_dtype == b_dtype for Float16
1165
+ if a_dtype.width == 16 and a_dtype != b_dtype:
1166
+ is_valid = False
1167
+ # make sure a_dtype.width == b_dtype.width (i.e, Float8E4M3FN or Float8E5M2)
1168
+ if a_dtype.width != b_dtype.width:
1169
+ is_valid = False
1170
+
1171
+ # for Float8 types, this implementation only supports k-major layout
1172
+ if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
1173
+ is_valid = False
1174
+
1175
+ return is_valid
1176
+
1177
+
1178
+ def run(
1179
+ mnkl: Tuple[int, int, int, int],
1180
+ a_dtype: Type[cutlass.Numeric],
1181
+ b_dtype: Type[cutlass.Numeric],
1182
+ d_dtype: Type[cutlass.Numeric],
1183
+ acc_dtype: Type[cutlass.Numeric],
1184
+ a_major: str,
1185
+ b_major: str,
1186
+ d_major: str,
1187
+ tile_shape_mnk: Tuple[int, int, int],
1188
+ cluster_shape_mn: Tuple[int, int],
1189
+ tolerance: float,
1190
+ warmup_iterations: int,
1191
+ iterations: int,
1192
+ skip_ref_check: bool,
1193
+ use_cold_l2: bool = False,
1194
+ **kwargs,
1195
+ ):
1196
+ """
1197
+ Prepare A/B/C tensors, launch the GPU kernel, and run the reference check.
1198
+
1199
+ :param mnkl: Problem size (M, N, K, L)
1200
+ :type mnkl: Tuple[int, int, int, int]
1201
+ :param a_dtype: Data type for input tensor A
1202
+ :type a_dtype: Type[cutlass.Numeric]
1203
+ :param b_dtype: Data type for input tensor B
1204
+ :type b_dtype: Type[cutlass.Numeric]
1205
+ :param d_dtype: Data type for output tensor C
1206
+ :type d_dtype: Type[cutlass.Numeric]
1207
+ :param acc_dtype: Data type for accumulation during matrix multiplication
1208
+ :type acc_dtype: Type[cutlass.Numeric]
1209
+ :param a_major/b_major/d_major: Memory layout of tensor A/B/C
1210
+ :type a_major/b_major/d_major: str
1211
+ :param tile_shape_mnk: CTA tile shape (M, N, K)
1212
+ :type tile_shape_mnk: Tuple[int, int, int]
1213
+ :param cluster_shape_mn: Cluster shape (M, N)
1214
+ :type cluster_shape_mn: Tuple[int, int]
1215
+ :param tolerance: Tolerance value for reference validation comparison
1216
+ :type tolerance: float
1217
+ :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
1218
+ :type warmup_iterations: int, optional
1219
+ :param iterations: Number of benchmark iterations to run, defaults to 1
1220
+ :type iterations: int, optional
1221
+ :param skip_ref_check: Whether to skip reference result validation, defaults to False
1222
+ :type skip_ref_check: bool, optional
1223
+ :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
1224
+ :type use_cold_l2: bool, optional
1225
+ :return: Execution time of the GEMM kernel in microseconds
1226
+ :rtype: float
1227
+ """
1228
+
1229
+ print("Running Hopper Dense GEMM with:")
1230
+ print(f"mnkl: {mnkl}")
1231
+ print(f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
1232
+ print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
1233
+ print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
1234
+ print(f"Tolerance: {tolerance}")
1235
+ print(f"Warmup iterations: {warmup_iterations}")
1236
+ print(f"Iterations: {iterations}")
1237
+ print(f"Skip reference checking: {skip_ref_check}")
1238
+ print(f"Use cold L2: {use_cold_l2}")
1239
+
1240
+ # Unpack parameters
1241
+ m, n, k, l = mnkl
1242
+ cluster_shape_mnk = (*cluster_shape_mn, 1)
1243
+
1244
+ # Skip unsupported types
1245
+ if not HopperWgmmaGemmKernel.is_valid_dtypes(
1246
+ a_dtype, b_dtype, acc_dtype, d_dtype, a_major, b_major
1247
+ ):
1248
+ raise TypeError(
1249
+ f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {d_dtype}, {a_major=}, {b_major=}"
1250
+ )
1251
+
1252
+ # Prepare pytorch tensors: A, B, and C (Gaussian-initialized; C is overwritten by the kernel)
1253
+ if not torch.cuda.is_available():
1254
+ raise RuntimeError("GPU is required to run this example!")
1255
+
1256
+ torch.manual_seed(1111)
1257
+
1258
+ # Create and permute tensor A/B/C
1259
+ def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
1260
+ # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
1261
+ # else : (l, mode0, mode1) -> (mode0, mode1, l)
1262
+ shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
1263
+ permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
1264
+ is_unsigned = dtype in {cutlass.Uint8}
1265
+ # Temporarily use uint8 as torch does not support fp8 type
1266
+ torch_dtype = (
1267
+ cutlass_torch.dtype(dtype)
1268
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
1269
+ else torch.uint8
1270
+ )
1271
+
1272
+ # Create dtype torch tensor (cpu)
1273
+ torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
1274
+ shape,
1275
+ torch_dtype,
1276
+ permute_order=permute_order,
1277
+ # init_type=cutlass.torch.TensorInitType.RANDOM,
1278
+ # init_config=cutlass.torch.RandomInitConfig(
1279
+ # min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
1280
+ # ),
1281
+ init_type=cutlass.torch.TensorInitType.GAUSSIAN,
1282
+ init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
1283
+ )
1284
+ # Create dtype torch tensor (gpu)
1285
+ torch_tensor = torch_tensor_cpu.cuda()
1286
+
1287
+ # Create f32 torch tensor (cpu)
1288
+ f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
1289
+
1290
+ # Create dtype cute tensor (gpu)
1291
+ cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
1292
+ cute_tensor.element_type = dtype
1293
+ if is_dynamic_layout:
1294
+ cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
1295
+ cute_tensor = cutlass.torch.convert_cute_tensor(
1296
+ f32_torch_tensor,
1297
+ cute_tensor,
1298
+ dtype,
1299
+ is_dynamic_layout=is_dynamic_layout,
1300
+ )
1301
+
1302
+ return f32_torch_tensor, cute_tensor, torch_tensor
1303
+
1304
+ a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
1305
+ b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
1306
+ c, mC, c_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
1307
+
1308
+ gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk)
1309
+
1310
+ torch_stream = torch.cuda.Stream()
1311
+ stream = cuda.CUstream(torch_stream.cuda_stream)
1312
+ # compile gemm kernel
1313
+ compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
1314
+
1315
+ if not skip_ref_check:
1316
+ # execution
1317
+ compiled_gemm(mA, mB, mC, stream)
1318
+
1319
+ torch.cuda.synchronize()
1320
+
1321
+ # Ref check
1322
+ ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
1323
+
1324
+ if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
1325
+ # m major: (l, n, m) -> (m, n, l)
1326
+ # n major: (l, m, n) -> (m, n, l)
1327
+ permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
1328
+ shape = (l, m, n) if d_major == "n" else (l, n, m)
1329
+ f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
1330
+ shape,
1331
+ torch.uint8,
1332
+ permute_order=permute_order,
1333
+ init_type=cutlass_torch.TensorInitType.SKIP,
1334
+ ).cuda()
1335
+ # Create dtype cute tensor (gpu)
1336
+ ref_c_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
1337
+ leading_dim=(1 if d_major == "n" else 0)
1338
+ )
1339
+ ref_c_tensor.element_type = d_dtype
1340
+ ref_c_tensor = cutlass_torch.convert_cute_tensor(
1341
+ ref,
1342
+ ref_c_tensor,
1343
+ d_dtype,
1344
+ is_dynamic_layout=True,
1345
+ )
1346
+ ref_c = f8_torch_tensor.cpu()
1347
+ else:
1348
+ ref_c = ref.to(cutlass_torch.dtype(d_dtype))
1349
+
1350
+ torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
1351
+
1352
+ def generate_tensors():
1353
+ _, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
1354
+ _, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
1355
+ _, mC_workspace, _ = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
1356
+ return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)
1357
+
1358
+ workspace_count = 1
1359
+ if use_cold_l2:
1360
+ one_workspace_bytes = (
1361
+ a_torch.numel() * a_torch.element_size()
1362
+ + b_torch.numel() * b_torch.element_size()
1363
+ + c_torch.numel() * c_torch.element_size()
1364
+ )
1365
+ workspace_count = testing.get_workspace_count(
1366
+ one_workspace_bytes, warmup_iterations, iterations
1367
+ )
1368
+
1369
+ exec_time = testing.benchmark(
1370
+ compiled_gemm,
1371
+ workspace_generator=generate_tensors,
1372
+ workspace_count=workspace_count,
1373
+ stream=stream,
1374
+ warmup_iterations=warmup_iterations,
1375
+ iterations=iterations,
1376
+ )
1377
+
1378
+ from triton.testing import do_bench
1379
+
1380
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
1381
+
1382
+ flops = 2 * m * n * k * l
1383
+
1384
+ repeats = 30
1385
+ # repeats = 1
1386
+ warmup = 5
1387
+
1388
+ import time
1389
+
1390
+ time.sleep(0.5)
1391
+ fn = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
1392
+ timing_cublas = do_bench(fn, warmup=warmup, rep=repeats)
1393
+ tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
1394
+ print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
1395
+
1396
+ time.sleep(0.5)
1397
+ fn = lambda: compiled_gemm(mA, mB, mC, current_stream)
1398
+ timing = do_bench(fn, warmup=warmup, rep=repeats)
1399
+ tflops = flops / (timing * 1e9) # Convert to TFlops
1400
+ print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
1401
+
1402
+ time.sleep(0.5)
1403
+ fn = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
1404
+ timing_cublas = do_bench(fn, warmup=warmup, rep=repeats)
1405
+ tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
1406
+ print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
1407
+
1408
+ return exec_time # Return execution time in microseconds
1409
+
1410
+
1411
+ if __name__ == "__main__":
1412
+ args = parse_arguments()
1413
+ run(
1414
+ args.mnkl,
1415
+ args.a_dtype,
1416
+ args.b_dtype,
1417
+ args.d_dtype,
1418
+ args.acc_dtype,
1419
+ args.a_major,
1420
+ args.b_major,
1421
+ args.d_major,
1422
+ args.tile_shape_mnk,
1423
+ args.cluster_shape_mn,
1424
+ args.tolerance,
1425
+ args.warmup_iterations,
1426
+ args.iterations,
1427
+ args.skip_ref_check,
1428
+ args.use_cold_l2,
1429
+ )
1430
+ print("PASS")