quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2474 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+
7
+ # 1. Redistributions of source code must retain the above copyright notice, this
8
+ # list of conditions and the following disclaimer.
9
+
10
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
11
+ # this list of conditions and the following disclaimer in the documentation
12
+ # and/or other materials provided with the distribution.
13
+
14
+ # 3. Neither the name of the copyright holder nor the names of its
15
+ # contributors may be used to endorse or promote products derived from
16
+ # this software without specific prior written permission.
17
+
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+ import argparse
30
+ import enum
31
+ from typing import Tuple, Type, Callable, Optional
32
+ from functools import partial
33
+ import math
34
+
35
+ import cuda.bindings.driver as cuda
36
+
37
+ import torch
38
+
39
+ import cutlass
40
+ import cutlass.cute as cute
41
+ import cutlass.pipeline as pipeline
42
+ import cutlass.torch as cutlass_torch
43
+ from cutlass.cute.runtime import from_dlpack, make_ptr
44
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
45
+ import cutlass.utils.hopper_helpers as sm90_utils
46
+ from cutlass import Int32, const_expr
47
+
48
+ from quack.tile_scheduler import (
49
+ TileSchedulerArguments,
50
+ TileScheduler,
51
+ VarlenMTileSchedulerArguments,
52
+ VarlenMTileScheduler,
53
+ ParamsBase,
54
+ RasterOrderOption,
55
+ )
56
+ from quack.tensormap_manager import TensorMapManagerSm90
57
+
58
+ # return PipelineStateWAdvance instead of PipelineState
59
+ from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
60
+ import quack.utils as utils
61
+
62
+ """
63
+ A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
64
+ using CUTE DSL.
65
+ - Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
66
+ - Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
67
+ - Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
68
+
69
+ This GEMM kernel supports the following features:
70
+ - Utilizes Tensor Memory Access (TMA) for efficient memory operations
71
+ - Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations
72
+ - Implements TMA multicast with cluster to reduce L2 memory traffic
73
+ - Supports multi-stage pipeline to overlap computation and memory access
74
+
75
+ This GEMM works as follows:
76
+ 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
77
+ 2. Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction.
78
+ 3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations.
79
+
80
+ Hopper WGMMA instructions operate as follows:
81
+ - Read matrix A from SMEM
82
+ - Read matrix B from SMEM
83
+ - Perform the MMA operation and store the result in the accumulator (registers)
84
+
85
+ To run this example:
86
+
87
+ .. code-block:: bash
88
+
89
+ python examples/hopper/dense_gemm.py \
90
+ --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
91
+ --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
92
+ --d_dtype Float16 --acc_dtype Float32 \
93
+ --a_major k --b_major k --d_major n
94
+
95
+ The above example command computes a batched GEMM with M=8192, N=8192, K=8192,
96
+ batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape
97
+ is (1,1). The input, MMA accumulator, and output data types are fp16, fp32,
98
+ and fp16, respectively.
99
+
100
+ To collect performance with NCU profiler:
101
+
102
+ .. code-block:: bash
103
+
104
+ ncu python examples/hopper/dense_gemm.py \
105
+ --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
106
+ --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
107
+ --d_dtype Float16 --acc_dtype Float32 \
108
+ --a_major k --b_major k --d_major n
109
+
110
+ Constraints:
111
+ * Supported input data types: fp16, bf16, fp8 (e4m3fn, e5m2)
112
+ * For 16-bit types, A and B must have the same data type
113
+ * For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit
114
+ * Fp8 types only support k-major layout
115
+ * Only fp32 accumulation is supported in this example
116
+ * CTA tile shape M must be 64/128/192/256/320 (64/128/192 when pingpong is enabled)
117
+ * CTA tile shape N must be a multiple of 16; the maximum depends on tile shape M (see the checks in HopperWgmmaGemmKernel)
118
+ * CTA tile shape K must be divisible by 16
119
+ * Cluster shape M/N must be positive and power of 2, total cluster size <= 4
120
+ * The contiguous dimension of A/B/C tensors must be aligned to at least 16 bytes,
121
+ i.e., the number of elements must be a multiple of 8 for Float16 and 16 for Float8, respectively.
122
+ * OOB tiles are not allowed when TMA store is disabled
123
+ """
124
+
125
+
126
+ # /////////////////////////////////////////////////////////////////////////////
127
+ # Helpers to parse args
128
+ # /////////////////////////////////////////////////////////////////////////////
129
+ def parse_comma_separated_ints(s: str):
130
+ try:
131
+ return tuple([int(x.strip()) for x in s.split(",")])
132
+ except ValueError:
133
+ raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")
134
+
135
+
136
+ def parse_arguments() -> argparse.Namespace:
137
+ parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.")
138
+
139
+ parser.add_argument(
140
+ "--mnkl",
141
+ type=parse_comma_separated_ints,
142
+ default=(4096, 4096, 4096, 1),
143
+ help="mnkl dimensions (comma-separated)",
144
+ )
145
+ parser.add_argument(
146
+ "--tile_shape_mnk",
147
+ type=parse_comma_separated_ints,
148
+ default=(128, 256, 64),
149
+ help="Cta tile shape (comma-separated)",
150
+ )
151
+ parser.add_argument(
152
+ "--cluster_shape_mn",
153
+ type=parse_comma_separated_ints,
154
+ choices=[(1, 1), (2, 1), (1, 2), (2, 2)],
155
+ default=(1, 1),
156
+ help="Cluster shape (comma-separated)",
157
+ )
158
+ parser.add_argument(
159
+ "--a_dtype",
160
+ type=cutlass.dtype,
161
+ default=cutlass.BFloat16,
162
+ )
163
+ parser.add_argument(
164
+ "--b_dtype",
165
+ type=cutlass.dtype,
166
+ default=cutlass.BFloat16,
167
+ )
168
+ parser.add_argument(
169
+ "--d_dtype",
170
+ type=cutlass.dtype,
171
+ default=cutlass.BFloat16,
172
+ )
173
+ parser.add_argument(
174
+ "--c_dtype",
175
+ type=cutlass.dtype,
176
+ default=None,
177
+ )
178
+ parser.add_argument(
179
+ "--acc_dtype",
180
+ type=cutlass.dtype,
181
+ default=cutlass.Float32,
182
+ )
183
+ parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
184
+ parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
185
+ parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
186
+ parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
187
+ parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation")
188
+ parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
189
+ parser.add_argument(
190
+ "--iterations",
191
+ type=int,
192
+ default=30,
193
+ help="Number of iterations to run the kernel",
194
+ )
195
+ parser.add_argument("--persistent", action="store_true", help="Persistent kernel")
196
+ parser.add_argument(
197
+ "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel"
198
+ )
199
+ parser.add_argument("--pingpong", action="store_true", help="Pingpong kernel")
200
+ parser.add_argument("--varlen_m", action="store_true", help="Variable length M dimension")
201
+ parser.add_argument("--gather_A", action="store_true", help="Gather A")
202
+ parser.add_argument("--fp8_fast_accum", action="store_true", help="FP8 fast accum")
203
+ parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
204
+
205
+ args = parser.parse_args()
206
+
207
+ if len(args.mnkl) != 4:
208
+ parser.error("--mnkl must contain exactly 4 values")
209
+ if len(args.tile_shape_mnk) != 3:
210
+ parser.error("--tile_shape_mnk must contain exactly 3 values")
211
+ if len(args.cluster_shape_mn) != 2:
212
+ parser.error("--cluster_shape_mn must contain exactly 2 values")
213
+
214
+ return args
215
+
216
+
217
+ # /////////////////////////////////////////////////////////////////////////////
218
+ # Host setup and device kernel launch
219
+ # /////////////////////////////////////////////////////////////////////////////
220
+
221
+
222
+ class NamedBarrierGemm(enum.IntEnum):
223
+ Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
224
+ # For mainloop load warps to signal that the epilogue load warp can start.
225
+ # This is to avoid loading C too early, interfering with loading A and B.
226
+ EpilogueLoad = enum.auto()
227
+ MmaWG0 = enum.auto()
228
+ MmaWG1 = enum.auto()
229
+ EpiWG0 = enum.auto()
230
+ EpiWG1 = enum.auto()
231
+
232
+
233
+ class HopperWgmmaGemmKernel:
234
+ """
235
+ This class implements batched matrix multiplication (C = A x B) with support for various data types
236
+ and architectural features specific to Hopper GPUs.
237
+
238
+ :param acc_dtype: Data type for accumulation during computation
239
+ :type acc_dtype: type[cutlass.Numeric]
240
+ :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
241
+ :type tile_shape_mnk: Tuple[int, int, int]
242
+ :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
243
+ :type cluster_shape_mnk: Tuple[int, int, int]
244
+
245
+ :note: Data type requirements:
246
+ - For 16-bit types: A and B must have the same data type
247
+ - For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit
248
+ - Float8 types only support k-major layout
249
+
250
+ :note: Supported data types:
251
+ - Float16
252
+ - BFloat16
253
+ - Float8E4M3FN/Float8E5M2
254
+
255
+ :note: Supported accumulation types:
256
+ - Float32 (for all floating point inputs)
257
+
258
+ :note: Constraints:
259
+ - Cluster shape M/N must be positive and power of 2, total cluster size <= 4
260
+
261
+ Example:
262
+ >>> gemm = HopperWgmmaGemmKernel(
263
+ ... acc_dtype=cutlass.Float32,
264
+ ... tile_shape_mnk=(128, 256, 64),
265
+ ... cluster_shape_mnk=(1, 1, 1)
266
+ ... )
267
+ >>> gemm(a_tensor, b_tensor, c_tensor, stream)
268
+ """
269
+
270
+ bytes_per_tensormap = 128
271
+ num_tensormaps = 1 # For D only
272
+
273
+ def __init__(
274
+ self,
275
+ acc_dtype: Type[cutlass.Numeric],
276
+ a_dtype: Type[cutlass.Numeric],
277
+ tile_shape_mnk: Tuple[int, int, int],
278
+ cluster_shape_mnk: Tuple[int, int, int],
279
+ pingpong: bool = False,
280
+ is_persistent: bool = True,
281
+ fp8_fast_accum: bool = False,
282
+ gather_A: bool = False,
283
+ ):
284
+ """
285
+ Initializes the configuration for a Hopper dense GEMM kernel.
286
+
287
+ This configuration includes data types for operands, tile shape, cluster configuration,
288
+ and thread layout.
289
+
290
+ :param acc_dtype: Data type for accumulation during computation
291
+ :type acc_dtype: type[cutlass.Numeric]
292
+ :param tile_shape_mnk: Shape of the CTA tile (M,N,K)
293
+ :type tile_shape_mnk: Tuple[int, int, int]
294
+ :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
295
+ :type cluster_shape_mnk: Tuple[int, int, int]
296
+ """
297
+
298
+ self.acc_dtype = acc_dtype
299
+ self.pingpong = pingpong
300
+ self.is_persistent = is_persistent
301
+ if self.pingpong:
302
+ assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
303
+ self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
304
+ self.gather_A = gather_A
305
+ if gather_A:
306
+ assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
307
+ self.tensormap_update_mode = cutlass.utils.TensorMapUpdateMode.SMEM
308
+
309
+ self.cluster_shape_mnk = cluster_shape_mnk
310
+ self.tile_shape_mnk = tuple(tile_shape_mnk)
311
+ tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1]
312
+ # check the cta tile shape
313
+ if not self.pingpong:
314
+ if tile_M not in [64, 128, 192, 256, 320]:
315
+ raise ValueError("CTA tile shape M must be 64/128/192/256/320")
316
+ if tile_M in [192, 320]: # special case
317
+ tile_N_max = 256 if tile_M == 192 else 160
318
+ if not (tile_N % 32 == 0 and tile_N <= tile_N_max):
319
+ raise ValueError(
320
+ f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}"
321
+ )
322
+ else:
323
+ if not (
324
+ (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
325
+ ):
326
+ raise ValueError(
327
+ "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
328
+ )
329
+ else:
330
+ if tile_M not in [64, 128, 192]:
331
+ raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
332
+ tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
333
+ if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
334
+ raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
335
+ if not self.tile_shape_mnk[2] % 16 == 0:
336
+ raise ValueError("CTA tile shape K must be divisible by 16")
337
+
338
+ if not self.pingpong:
339
+ if tile_M == 320: # tile_M / 64 is not even so we have to split along N
340
+ atom_layout_m, atom_layout_n = 1, 2
341
+ elif tile_M == 192:
342
+ if tile_N <= 128:
343
+ atom_layout_m, atom_layout_n = 3, 1
344
+ else:
345
+ atom_layout_m, atom_layout_n = 1, 2
346
+ else:
347
+ atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
348
+ atom_layout_n = 1
349
+ assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
350
+ else:
351
+ atom_layout_m, atom_layout_n = 1, 1
352
+ self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
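+ # For example, under the rules above: tile_M=128 gives atom_layout (2, 1, 1);
+ # tile_M=192 with tile_N<=128 gives (3, 1, 1); tile_M=320 gives (1, 2, 1);
+ # and pingpong always uses (1, 1, 1).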
353
+
354
+ self.num_mcast_ctas_a = self.cluster_shape_mnk[1] if not self.gather_A else 1
355
+ self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
356
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
357
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
358
+
359
+ self.occupancy = 1
360
+ self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
361
+ if self.pingpong:
362
+ assert self.mma_warp_groups == 2
363
+ assert self.mma_warp_groups in [1, 2, 3]
364
+ self.num_threads_per_warp_group = 128
365
+ self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
366
+ self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
367
+ self.num_epi_threads = (
368
+ self.mma_warp_groups if not self.pingpong else 1
369
+ ) * self.num_threads_per_warp_group
370
+ self.num_ab_load_warps = 1 if not self.gather_A else 4
371
+ self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
372
+ self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
373
+ self.ab_load_warp_id = self.mma_warp_groups * 4
374
+ self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
375
+
376
+ regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // (
377
+ math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
378
+ )
379
+ if self.fp8_slow_accum:
380
+ regs_per_thread *= 2
381
+ if not self.gather_A:
382
+ if self.mma_warp_groups == 3:
383
+ self.num_regs_load, self.num_regs_mma = 32, 160
384
+ else:
385
+ heavy_register_pressure = regs_per_thread >= 208
386
+ self.num_regs_load, self.num_regs_mma = (
387
+ (40, 232) if not heavy_register_pressure else (24, 240)
388
+ )
389
+ else:
390
+ if self.mma_warp_groups == 3:
391
+ self.num_regs_load, self.num_regs_mma = 56, 152
392
+ else:
393
+ self.num_regs_load, self.num_regs_mma = (56, 224)
394
+
395
+ self.ab_stage = None
396
+ self.epi_stage = None
397
+
398
+ self.a_smem_layout_staged = None
399
+ self.b_smem_layout_staged = None
400
+ self.epi_smem_layout_staged = None
401
+ self.epi_tile = None
402
+
403
+ self.shared_storage = None
404
+ self.buffer_align_bytes = 1024
405
+
406
+ def _setup_attributes(self):
407
+ """Set up configurations that are dependent on GEMM inputs
408
+
409
+ This method configures various attributes based on the input tensor properties
410
+ (data types, leading dimensions) and kernel settings:
411
+ - Configuring tiled MMA
412
+ - Computing MMA/cluster/tile shapes
413
+ - Computing cluster layout
414
+ - Computing multicast CTAs for A/B
415
+ - Computing epilogue subtile
416
+ - Setting up A/B/C stage counts in shared memory
417
+ - Computing A/B/C shared memory layout
418
+ """
419
+
420
+ self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
421
+
422
+ self.epi_tile = self._sm90_compute_tile_shape_or_override(
423
+ self.tile_shape_mnk,
424
+ self.atom_layout_mnk,
425
+ self.d_dtype,
426
+ )
427
+
428
+ # Compute stage before compute smem layout
429
+ self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
430
+ self.tile_shape_mnk,
431
+ self.epi_tile,
432
+ self.a_dtype,
433
+ self.b_dtype,
434
+ self.d_dtype,
435
+ self.c_dtype,
436
+ self.smem_capacity,
437
+ self.occupancy,
438
+ # epi smem reuses the A/B smem if not persistent.
439
+ overlap_sD_sA=not self.is_persistent,
440
+ )
441
+ self.sched_stage = 2 if self.pingpong else 1
442
+
443
+ (
444
+ self.a_smem_layout_staged,
445
+ self.b_smem_layout_staged,
446
+ self.epi_smem_layout_staged,
447
+ self.epi_c_smem_layout_staged,
448
+ ) = self._make_smem_layouts(
449
+ self.tile_shape_mnk,
450
+ self.epi_tile,
451
+ self.a_dtype,
452
+ self.a_layout,
453
+ self.b_dtype,
454
+ self.b_layout,
455
+ self.ab_stage,
456
+ self.d_dtype,
457
+ self.d_layout,
458
+ self.epi_stage,
459
+ self.c_dtype,
460
+ self.c_layout,
461
+ self.epi_c_stage,
462
+ )
463
+
464
+ @cute.jit
465
+ def __call__(
466
+ self,
467
+ mA: cute.Tensor,
468
+ mB: cute.Tensor,
469
+ mD: cute.Tensor,
470
+ mC: Optional[cute.Tensor],
471
+ mAIdx: Optional[cute.Tensor],
472
+ mCuSeqlensM: Optional[cute.Tensor],
473
+ mTensormaps: Optional[cute.Tensor],
474
+ tile_count_semaphore: Optional[cute.Pointer],
475
+ max_active_clusters: Int32,
476
+ stream: cuda.CUstream,
477
+ ):
478
+ """Execute the GEMM operation in steps:
479
+ - Setup static attributes
480
+ - Setup TMA load/store atoms and tensors
481
+ - Compute grid size
482
+ - Define shared storage for kernel
483
+ - Launch the kernel synchronously
484
+
485
+ :param mA: Input tensor A
486
+ :type mA: cute.Tensor
487
+ :param mB: Input tensor B
488
+ :type mB: cute.Tensor
489
+ :param mD: Output tensor D
490
+ :type mD: cute.Tensor
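+ :param mC: Optional input tensor C, added to the accumulator in the epilogue
+ :type mC: cute.Tensor, optional
+ :param mAIdx: Optional row-gather indices for A (used when gather_A is enabled)
+ :type mAIdx: cute.Tensor, optional
+ :param mCuSeqlensM: Optional cumulative sequence lengths for variable-length M
+ :type mCuSeqlensM: cute.Tensor, optional
+ :param mTensormaps: Optional tensormap workspace (required for variable-length M)
+ :type mTensormaps: cute.Tensor, optional
+ :param tile_count_semaphore: Optional semaphore used by the dynamic persistent scheduler
+ :type tile_count_semaphore: cute.Pointer, optional
+ :param max_active_clusters: Maximum number of active clusters for the persistent grid
+ :type max_active_clusters: Int32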
491
+ :param stream: CUDA stream for asynchronous execution
492
+ :type stream: cuda.CUstream
493
+ """
494
+
495
+ # setup static attributes before smem/grid/tma computation
496
+ self.a_dtype = mA.element_type
497
+ self.b_dtype = mB.element_type
498
+ self.d_dtype = mD.element_type
499
+ self.c_dtype = mC.element_type if mC is not None else None
500
+ self.a_layout = cutlass.utils.LayoutEnum.from_tensor(mA)
501
+ self.b_layout = cutlass.utils.LayoutEnum.from_tensor(mB)
502
+ self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
503
+ self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None
504
+
505
+ if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
506
+ raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
507
+ if const_expr(self.a_dtype.width != self.b_dtype.width):
508
+ raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
509
+ if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
510
+ raise TypeError("a_dtype should be float16 or float8")
511
+ assert (mAIdx is not None) == self.gather_A
512
+
513
+ # Assume all strides are divisible by 128 bits except the last stride
514
+ new_stride = lambda t: tuple(
515
+ cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
516
+ for s in t.stride
517
+ )
518
+ mA, mD = [
519
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
520
+ for t in (mA, mD)
521
+ ]
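+ # E.g., for a 16-bit element type, 128 // 16 = 8, so non-static strides are assumed
+ # to be multiples of 8 elements (16 bytes); for 8-bit types the assumed multiple is 16.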
522
+
523
+ self._setup_attributes()
524
+
525
+ tiled_mma = sm90_utils.make_trivial_tiled_mma(
526
+ self.a_dtype,
527
+ self.b_dtype,
528
+ self.a_layout.sm90_mma_major_mode(),
529
+ self.b_layout.sm90_mma_major_mode(),
530
+ self.acc_dtype,
531
+ self.atom_layout_mnk,
532
+ tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
533
+ )
534
+ if const_expr(self.atom_layout_mnk[1] > 1):
535
+ # If N dimension is split among 2 WGs, we need to permute the N dimension so
536
+ # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
537
+ # containing accumulators that are next to each other in the N dimension.
538
+ # Without permutation WG0 would write to epi smem of size (64, 16) and
539
+ # WG1 would write to a separate epi smem of size (64, 16) that's far away.
540
+ atom_n = self.atom_layout_mnk[1]
541
+ permutation_n = cute.make_ordered_layout(
542
+ (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
543
+ )
544
+ tiled_mma = cute.make_tiled_mma(
545
+ cute.make_mma_atom(tiled_mma.op),
546
+ self.atom_layout_mnk,
547
+ permutation_mnk=(None, permutation_n, None),
548
+ )
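+ # Illustrative example: with tile_N=256 and atom_n=2, permutation_n is an ordered layout
+ # of shape (8, 16, 2) with order (0, 2, 1), i.e. the N columns of the two WGs are
+ # interleaved in chunks of 8, so each epi subtile holds adjacent chunks from both WGs.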
549
+
550
+ if const_expr(not self.gather_A):
551
+ tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
552
+ mA,
553
+ self.a_smem_layout_staged,
554
+ (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
555
+ self.cluster_shape_mnk[1],
556
+ )
557
+ else:
558
+ tma_atom_a, tma_tensor_a = None, None
559
+
560
+ tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
561
+ mB,
562
+ self.b_smem_layout_staged,
563
+ (self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
564
+ self.cluster_shape_mnk[0],
565
+ )
566
+
567
+ tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
568
+ mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
569
+ )
570
+
571
+ if const_expr(mC is not None):
572
+ tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
573
+ mC, self.epi_c_smem_layout_staged, self.epi_tile, store_or_load="load"
574
+ )
575
+ else:
576
+ tma_atom_c, tma_tensor_c = None, None
577
+
578
+ if const_expr(mCuSeqlensM is None):
579
+ problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (
580
+ mD.shape[2],
581
+ )
582
+ TileSchedulerCls = TileScheduler
583
+ tile_sched_args = TileSchedulerArguments(
584
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
585
+ raster_order=RasterOrderOption.Heuristic,
586
+ group_size=8,
587
+ cluster_shape_mnk=self.cluster_shape_mnk,
588
+ tile_count_semaphore=tile_count_semaphore,
589
+ is_persistent=self.is_persistent,
590
+ )
591
+ else:
592
+ assert mTensormaps is not None
593
+ problem_shape_ntile_mnl = (
594
+ None,
595
+ cute.ceil_div(mD.shape[1], self.tile_shape_mnk[1]),
596
+ mCuSeqlensM.shape[0] - 1,
597
+ )
598
+ TileSchedulerCls = VarlenMTileScheduler
599
+ tile_sched_args = VarlenMTileSchedulerArguments(
600
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
601
+ total_m=mD.shape[0],
602
+ cu_seqlens_m=mCuSeqlensM,
603
+ raster_order=RasterOrderOption.Heuristic,
604
+ group_size=8,
605
+ tile_shape_mnk=self.tile_shape_mnk,
606
+ cluster_shape_mnk=self.cluster_shape_mnk,
607
+ tile_count_semaphore=tile_count_semaphore,
608
+ is_persistent=self.is_persistent,
609
+ )
610
+ tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
611
+ grid = TileSchedulerCls.get_grid_shape(tile_sched_params, max_active_clusters)
612
+
613
+ epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if self.is_persistent else 0
614
+ epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
615
+
616
+ size_tensormap_in_i64 = (
617
+ 0
618
+ if mCuSeqlensM is None
619
+ or self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.GMEM
620
+ else HopperWgmmaGemmKernel.num_tensormaps
621
+ * HopperWgmmaGemmKernel.bytes_per_tensormap
622
+ // 8
623
+ ) * (1 if not self.pingpong else 2)
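+ # E.g., with varlen and SMEM update mode: 1 tensormap * 128 bytes / 8 = 16 Int64 words
+ # per math warp group, doubled to 32 with pingpong; otherwise the buffer size is 0.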
624
+
625
+ @cute.struct
626
+ class SharedStorage:
627
+ tensormap_buffer: cute.struct.Align[
628
+ cute.struct.MemRange[cutlass.Int64, size_tensormap_in_i64],
629
+ 64,
630
+ ]
631
+ ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
632
+ epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
633
+ sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
634
+ tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
635
+ sD: cute.struct.Align[
636
+ cute.struct.MemRange[self.d_dtype, epi_smem_size],
637
+ self.buffer_align_bytes,
638
+ ]
639
+ sC: cute.struct.Align[
640
+ cute.struct.MemRange[
641
+ self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
642
+ ],
643
+ self.buffer_align_bytes,
644
+ ]
645
+ sA: cute.struct.Align[
646
+ cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
647
+ self.buffer_align_bytes,
648
+ ]
649
+ sB: cute.struct.Align[
650
+ cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)],
651
+ self.buffer_align_bytes,
652
+ ]
653
+
654
+ self.shared_storage = SharedStorage
655
+
656
+ # Launch the kernel synchronously
657
+ self.kernel(
658
+ tma_atom_a,
659
+ tma_tensor_a if const_expr(not self.gather_A) else mA,
660
+ tma_atom_b,
661
+ tma_tensor_b,
662
+ tma_atom_d,
663
+ tma_tensor_d,
664
+ mD,
665
+ tma_atom_c,
666
+ tma_tensor_c,
667
+ mAIdx,
668
+ mCuSeqlensM,
669
+ mTensormaps,
670
+ tiled_mma,
671
+ self.cluster_layout_mnk,
672
+ self.a_smem_layout_staged,
673
+ self.b_smem_layout_staged,
674
+ self.epi_smem_layout_staged,
675
+ self.epi_c_smem_layout_staged,
676
+ tile_sched_params,
677
+ TileSchedulerCls,
678
+ ).launch(
679
+ grid=grid,
680
+ block=[self.threads_per_cta, 1, 1],
681
+ cluster=self.cluster_shape_mnk,
682
+ smem=self.shared_storage.size_in_bytes(),
683
+ stream=stream,
684
+ min_blocks_per_mp=1,
685
+ )
686
+ return
687
+
688
+ # GPU device kernel
689
+ @cute.kernel
690
+ def kernel(
691
+ self,
692
+ tma_atom_a: Optional[cute.CopyAtom],
693
+ mA_mkl: cute.Tensor,
694
+ tma_atom_b: cute.CopyAtom,
695
+ mB_nkl: cute.Tensor,
696
+ tma_atom_d: cute.CopyAtom,
697
+ mD_mnl_tma: cute.Tensor,
698
+ mD_mnl: cute.Tensor,
699
+ tma_atom_c: Optional[cute.CopyAtom],
700
+ mC_mnl: Optional[cute.Tensor],
701
+ mAIdx: Optional[cute.Tensor],
702
+ cu_seqlens_m: Optional[cute.Tensor],
703
+ tensormaps: Optional[cute.Tensor],
704
+ tiled_mma: cute.TiledMma,
705
+ cluster_layout_mnk: cute.Layout,
706
+ a_smem_layout_staged: cute.ComposedLayout,
707
+ b_smem_layout_staged: cute.ComposedLayout,
708
+ epi_smem_layout_staged: cute.ComposedLayout,
709
+ epi_c_smem_layout_staged: cute.ComposedLayout,
710
+ tile_sched_params: ParamsBase,
711
+ TileSchedulerCls: cutlass.Constexpr[Callable],
712
+ ):
713
+ """
714
+ GPU device kernel performing the batched GEMM computation.
715
+
716
+ :param tma_atom_a: TMA copy atom for A tensor
717
+ :type tma_atom_a: cute.CopyAtom
718
+ :param mA_mkl: Input tensor A
719
+ :type mA_mkl: cute.Tensor
720
+ :param tma_atom_b: TMA copy atom for B tensor
721
+ :type tma_atom_b: cute.CopyAtom
722
+ :param mB_nkl: Input tensor B
723
+ :type mB_nkl: cute.Tensor
724
+ :param tma_atom_d: TMA copy atom for D tensor
725
+ :type tma_atom_d: cute.CopyAtom
726
+ :param mD_mnl_tma: Output tensor D
727
+ :type mD_mnl_tma: cute.Tensor
728
+ :param tiled_mma: Tiled MMA object
729
+ :type tiled_mma: cute.TiledMma
730
+ :param cluster_layout_mnk: CTA layout
731
+ :type cluster_layout_mnk: cute.Layout
732
+ :param a_smem_layout_staged: Shared memory layout for A
733
+ :type a_smem_layout_staged: cute.ComposedLayout
734
+ :param b_smem_layout_staged: Shared memory layout for B
735
+ :type b_smem_layout_staged: cute.ComposedLayout
736
+ :param epi_smem_layout_staged: Shared memory layout for epilogue
737
+ :type epi_smem_layout_staged: cute.ComposedLayout
738
+ """
739
+
740
+ varlen = const_expr(cu_seqlens_m is not None)
741
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
742
+
743
+ # /////////////////////////////////////////////////////////////////////////////
744
+ # Prefetch Tma desc
745
+ # /////////////////////////////////////////////////////////////////////////////
746
+ if warp_idx == self.ab_load_warp_id:
747
+ if const_expr(tma_atom_a is not None):
748
+ cpasync.prefetch_descriptor(tma_atom_a)
749
+ cpasync.prefetch_descriptor(tma_atom_b)
750
+ cpasync.prefetch_descriptor(tma_atom_d)
751
+ if const_expr(tma_atom_c is not None):
752
+ cpasync.prefetch_descriptor(tma_atom_c)
753
+
754
+ a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
755
+ b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
756
+ tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
757
+ if const_expr(not self.gather_A):
758
+ tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
759
+
760
+ # /////////////////////////////////////////////////////////////////////////////
761
+ # Alloc and init AB full/empty + ACC full mbar (pipeline)
762
+ # /////////////////////////////////////////////////////////////////////////////
763
+ smem = cutlass.utils.SmemAllocator()
764
+ storage = smem.allocate(self.shared_storage)
765
+
766
+ # Threads/warps participating in this pipeline
767
+ producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
768
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
769
+ # Each consumer warp contributes the mcast size to the arrive count
770
+ mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
771
+ consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
772
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(
773
+ pipeline.Agent.Thread, consumer_arrive_cnt
774
+ )
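+ # E.g., with a (1, 1) cluster (mcast_size = 1) and two math warp groups
+ # (tiled_mma.size = 256 threads = 8 warps), the consumer arrive count is 8.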
775
+
776
+ cta_layout_vmnk = cute.make_layout((1, *cluster_layout_mnk.shape))
777
+ pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
778
+ ab_pipeline = pipeline_cls.create(
779
+ barrier_storage=storage.ab_pipeline_array_ptr.data_ptr(),
780
+ num_stages=self.ab_stage,
781
+ producer_group=ab_pipeline_producer_group,
782
+ consumer_group=ab_pipeline_consumer_group,
783
+ tx_count=tma_copy_bytes,
784
+ cta_layout_vmnk=cta_layout_vmnk,
785
+ )
786
+
787
+ if const_expr(mC_mnl is not None):
788
+ # Threads/warps participating in this pipeline
789
+ epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
790
+ # Each warp will contribute 1 to the arrive count
791
+ consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
792
+ epi_pipeline_consumer_group = pipeline.CooperativeGroup(
793
+ pipeline.Agent.Thread, consumer_arrive_cnt
794
+ )
795
+ c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
796
+ tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
797
+ epi_pipeline = pipeline.PipelineTmaAsync.create(
798
+ barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
799
+ num_stages=self.epi_c_stage,
800
+ producer_group=epi_pipeline_producer_group,
801
+ consumer_group=epi_pipeline_consumer_group,
802
+ tx_count=tma_copy_c_bytes,
803
+ )
804
+ else:
805
+ epi_pipeline = None
806
+
807
+ if const_expr(tile_sched_params.tile_count_semaphore is not None):
808
+ # Dynamic persistent scheduler
809
+ # Threads/warps participating in this pipeline
810
+ sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
811
+ cluster_size = cute.size(cluster_layout_mnk)
812
+ # Each warp that is not the scheduler warp contributes 1 to the arrive count
813
+ consumer_arrive_cnt = (
814
+ (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
815
+ ) * cluster_size - 1
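+ # E.g., non-pingpong with 2 math warp groups, 1 load warp, and cluster size 1:
+ # (2 * 4 + 1) * 1 - 1 = 8 consumer warps (every participating warp except the scheduler warp).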
816
+ sched_pipeline_consumer_group = pipeline.CooperativeGroup(
817
+ pipeline.Agent.Thread, consumer_arrive_cnt
818
+ )
819
+ sched_pipeline = pipeline.PipelineAsync.create(
820
+ barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
821
+ num_stages=self.sched_stage,
822
+ producer_group=sched_pipeline_producer_group,
823
+ consumer_group=sched_pipeline_consumer_group,
824
+ # If there's a cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
825
+ consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
826
+ )
827
+ tile_count = storage.tile_count.get_tensor((self.sched_stage,))
828
+ else:
829
+ sched_pipeline = None
830
+ tile_count = None
831
+
832
+ # ///////////////////////////////////////////////////////////////////////////////
833
+ # Generate smem tensor A/B
834
+ # ///////////////////////////////////////////////////////////////////////////////
835
+ sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
836
+ sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
837
+ if const_expr(not self.is_persistent):
838
+ sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
839
+ sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
840
+ else:
841
+ sD = storage.sD.get_tensor(
842
+ epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
843
+ )
844
+ if const_expr(mC_mnl is not None):
845
+ sC = storage.sC.get_tensor(
846
+ epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
847
+ )
848
+ else:
849
+ sC = None
850
+
851
+ # Get tensormap buffer address
852
+ if const_expr(varlen):
853
+ grid_dim = cute.arch.grid_dim()
854
+ bid = cute.arch.block_idx()
855
+ tensormap_workspace_idx = (
856
+ bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0]
857
+ )
858
+ # TODO: this is only for D, not for A/B
859
+ if const_expr(self.pingpong):
860
+ tensormap_workspace_idx = tensormap_workspace_idx * 2 + warp_idx // 4
861
+ tensormap_manager = TensorMapManagerSm90(
862
+ self.tensormap_update_mode, HopperWgmmaGemmKernel.bytes_per_tensormap
863
+ )
864
+ tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
865
+ tensormaps[tensormap_workspace_idx, None].iterator
866
+ )
867
+ if const_expr(self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.SMEM):
868
+ tensormap_smem_ptr = storage.tensormap_buffer.data_ptr()
869
+ tensormap_d_smem_ptr = tensormap_smem_ptr + (warp_idx // 4) * (
870
+ HopperWgmmaGemmKernel.bytes_per_tensormap // 8
871
+ )
872
+ # Need this, otherwise "expected tma descriptor pointer to have alignment at least 64, but got 8"
873
+ tensormap_d_smem_ptr = cute.make_ptr(
874
+ cutlass.Int64,
875
+ tensormap_d_smem_ptr.toint(),
876
+ cute.AddressSpace.smem,
877
+ assumed_align=64,
878
+ )
879
+ tensormap_d_init_ptr = tensormap_d_smem_ptr
880
+ else:
881
+ tensormap_d_smem_ptr = None
882
+ tensormap_d_init_ptr = tensormap_d_ptr
883
+ else:
884
+ tensormap_d_smem_ptr = None
885
+ tensormap_manager, tensormap_d_ptr, tensormap_d_init_ptr = None, None, None
886
+
887
+ TileSchedulerCls = partial(
888
+ TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
889
+ )
890
+
891
+ k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
892
+ c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
893
+
894
+ if warp_idx >= self.ab_load_warp_id:
895
+ cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
896
+ if const_expr(mC_mnl is not None):
897
+ epi_load_barrier = pipeline.NamedBarrier(
898
+ barrier_id=int(NamedBarrierGemm.EpilogueLoad),
899
+ num_threads=self.num_ab_load_threads + self.num_epi_load_threads,
900
+ )
901
+ else:
902
+ epi_load_barrier = None
903
+ if (
904
+ warp_idx >= self.ab_load_warp_id
905
+ and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
906
+ ):
907
+ # ///////////////////////////////////////////////////////////////////////////////
908
+ # Get mcast mask
909
+ # ///////////////////////////////////////////////////////////////////////////////
910
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
911
+ cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
912
+ a_mcast_mask = cute.make_layout_image_mask(
913
+ cluster_layout_mnk, cluster_coord_mnk, mode=1
914
+ )
915
+ b_mcast_mask = cute.make_layout_image_mask(
916
+ cluster_layout_mnk, cluster_coord_mnk, mode=0
917
+ )
918
+ a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
919
+ b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
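+ # E.g., for a (2, 2, 1) cluster, the A mask covers the CTAs sharing this CTA's M
+ # coordinate (only the cluster N coordinate varies) and the B mask covers those sharing
+ # its N coordinate; with a (1, 1, 1) cluster both masks are 0.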
920
+
921
+ # Persistent tile scheduling loop
922
+ is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
923
+ if const_expr(cute.size(cluster_layout_mnk) > 1):
924
+ is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
925
+ tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
926
+ work_tile = tile_scheduler.initial_work_tile_info()
927
+ ab_producer_state = make_pipeline_state(
928
+ pipeline.PipelineUserType.Producer, self.ab_stage
929
+ )
930
+ do_epi_load_barrier_arrive = cutlass.Boolean(True)
931
+ while work_tile.is_valid_tile:
932
+ tile_coord_mnkl = work_tile.tile_idx
933
+ batch_idx = tile_coord_mnkl[3]
934
+ # ///////////////////////////////////////////////////////////////////////////
935
+ # Local_tile partition global tensors
936
+ # ///////////////////////////////////////////////////////////////////////////
937
+ if const_expr(not self.gather_A):
938
+ if const_expr(cu_seqlens_m is not None):
939
+ mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
940
+ else:
941
+ mA_mk = mA_mkl[None, None, batch_idx]
942
+ # (bM, bK, RestK)
943
+ gA_k = cute.local_tile(
944
+ mA_mk,
945
+ cute.select(self.tile_shape_mnk, [0, 2]),
946
+ (tile_coord_mnkl[0], None),
947
+ )
948
+ else:
949
+ mA_mk = mA_mkl
950
+ if const_expr(cu_seqlens_m is not None):
951
+ mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
952
+ else:
953
+ mAIdx_mk = mAIdx[None, batch_idx]
954
+ gAIdx = cute.local_tile(
955
+ mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
956
+ )
957
+ # (bN, bK, RestK)
958
+ gB_k = cute.local_tile(
959
+ mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
960
+ )
961
+ # //////////////////////////////////////////////////////////////////////////
962
+ # Partition shared tensor for TMA load A/B
963
+ # //////////////////////////////////////////////////////////////////////////
964
+ # TMA load A partition_S/D
965
+ a_cta_layout = cute.make_layout(
966
+ cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
967
+ )
968
+ a_cta_crd = cluster_coord_mnk[1]
969
+ if const_expr(not self.gather_A):
970
+ # ((atom_v, rest_v), STAGE)
971
+ # ((atom_v, rest_v), RestK)
972
+ tAsA, tAgA_k = cpasync.tma_partition(
973
+ tma_atom_a,
974
+ a_cta_crd,
975
+ a_cta_layout,
976
+ cute.group_modes(sA, 0, 2),
977
+ cute.group_modes(gA_k, 0, 2),
978
+ )
979
+ copy_A = partial(cute.copy, tma_atom_a, mcast_mask=a_mcast_mask)
980
+ else:
981
+ tiled_copy_A = self._make_gmem_tiled_copy_A(
982
+ mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
983
+ )
984
+ tidx = (
985
+ cute.arch.thread_idx()[0]
986
+ - self.mma_warp_groups * self.num_threads_per_warp_group
987
+ )
988
+ thr_copy_A = tiled_copy_A.get_slice(tidx)
989
+ # (atom_v, CPY_M, 1, STAGE)
990
+ tAsA = thr_copy_A.partition_D(sA)
991
+ assert tAsA.shape[2] == 1
992
+ tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
993
+ copy_A = partial(cute.copy, tiled_copy_A)
994
+ # TMA load B partition_S/D
995
+ b_cta_layout = cute.make_layout(
996
+ cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
997
+ )
998
+ b_cta_crd = cluster_coord_mnk[0]
999
+ # ((atom_v, rest_v), STAGE)
1000
+ # ((atom_v, rest_v), RestK)
1001
+ tBsB, tBgB_k = cpasync.tma_partition(
1002
+ tma_atom_b,
1003
+ b_cta_crd,
1004
+ b_cta_layout,
1005
+ cute.group_modes(sB, 0, 2),
1006
+ cute.group_modes(gB_k, 0, 2),
1007
+ )
1008
+ copy_B = partial(cute.copy, tma_atom_b, mcast_mask=b_mcast_mask)
1009
+ if const_expr(not self.gather_A):
1010
+ ab_producer_state = self.load_AB(
1011
+ ab_pipeline,
1012
+ ab_producer_state,
1013
+ copy_A,
1014
+ tAgA_k,
1015
+ tAsA,
1016
+ copy_B,
1017
+ tBgB_k,
1018
+ tBsB,
1019
+ )
1020
+ else:
1021
+ limit_m = (
1022
+ mAIdx.shape[0]
1023
+ if const_expr(cu_seqlens_m is None)
1024
+ else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
1025
+ )
1026
+ ab_producer_state = self.load_AB_gather_A(
1027
+ ab_pipeline,
1028
+ ab_producer_state,
1029
+ thr_copy_A,
1030
+ mA_mk,
1031
+ tAsA,
1032
+ gAIdx,
1033
+ copy_B,
1034
+ tBgB_k,
1035
+ tBsB,
1036
+ limit_A=(
1037
+ limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
1038
+ mA_mk.shape[1],
1039
+ ),
1040
+ )
1041
+ if const_expr(epi_load_barrier is not None):
1042
+ # In the first work tile, the epi load warp will wait for the signal
1043
+ # from the mainloop load warp to start loading C, to avoid interfering
1044
+ # with loading A and B.
1045
+ if do_epi_load_barrier_arrive:
1046
+ epi_load_barrier.arrive()
1047
+ do_epi_load_barrier_arrive = cutlass.Boolean(False)
1048
+ tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
1049
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
1050
+ work_tile = tile_scheduler.get_current_work()
1051
+ # End of persistent scheduler loop
1052
+ if const_expr(self.pingpong):
1053
+ # Need to write the tile_idx to smem for the next WG in the pingpong mode
1054
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
1055
+ ab_pipeline.producer_tail(ab_producer_state)
1056
+ if is_scheduler_warp:
1057
+ tile_scheduler.producer_tail()
1058
+
1059
+ # if const_expr(mC_mnl is not None):
1060
+ # if warp_idx == self.epi_load_warp_id:
1061
+ # epi_producer_state = make_pipeline_state(
1062
+ # pipeline.PipelineUserType.Producer, self.epi_c_stage
1063
+ # )
1064
+ # do_epi_load_barrier_wait = cutlass.Boolean(True)
1065
+ # tile_scheduler = TileSchedulerCls()
1066
+ # work_tile = tile_scheduler.initial_work_tile_info()
1067
+ # while work_tile.is_valid_tile:
1068
+ # tile_coord_mnkl = work_tile.tile_idx
1069
+ # batch_idx = tile_coord_mnkl[3]
1070
+ # if const_expr(cu_seqlens_m is not None):
1071
+ # mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
1072
+ # else:
1073
+ # mC_mn = mC_mnl[None, None, batch_idx]
1074
+ # # (bM, bN)
1075
+ # gC = cute.local_tile(
1076
+ # mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
1077
+ # )
1078
+ # tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
1079
+ # bGS_sC, bGS_gC = cpasync.tma_partition(
1080
+ # tma_atom_c,
1081
+ # 0,
1082
+ # cute.make_layout(1),
1083
+ # cute.group_modes(sC, 0, 2),
1084
+ # tCgC_for_tma_partition,
1085
+ # )
1086
+ # if do_epi_load_barrier_wait:
1087
+ # epi_load_barrier.arrive_and_wait()
1088
+ # do_epi_load_barrier_wait = cutlass.Boolean(False)
1089
+ # epi_tile_num = const_expr(cute.size(tCgC_for_tma_partition, mode=[1]))
1090
+ # epi_tile_shape = tCgC_for_tma_partition.shape[1]
1091
+ # for epi_idx in cutlass.range(epi_tile_num, unroll=1):
1092
+ # epi_pipeline.producer_acquire(epi_producer_state)
1093
+ # # Get the global memory coordinate for the current epi tile
1094
+ # epi_tile_layout = cute.make_layout(
1095
+ # epi_tile_shape, stride=(epi_tile_shape[1], 1)
1096
+ # )
1097
+ # gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1098
+ # cute.copy(
1099
+ # tma_atom_c,
1100
+ # bGS_gC[None, gmem_coord],
1101
+ # bGS_sC[None, epi_producer_state.index],
1102
+ # tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1103
+ # )
1104
+ # # Epi pipeline's producer commit is a NOP
1105
+ # epi_pipeline.producer_commit(epi_producer_state)
1106
+ # epi_producer_state.advance()
1107
+ # tile_scheduler.advance_to_next_work()
1108
+ # work_tile = tile_scheduler.get_current_work()
1109
+ # # End of persistent scheduler loop
1110
+ # epi_pipeline.producer_tail(epi_producer_state)
1111
+
1112
+ if warp_idx < self.ab_load_warp_id:
1113
+ cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
1114
+ is_tma_warp = cutlass.Boolean(
1115
+ (not self.pingpong and warp_idx == 0)
1116
+ or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
1117
+ )
1118
+ if const_expr(varlen):
1119
+ # initialize tensormap for D
1120
+ tensormap_manager.init_tensormap_from_atom(
1121
+ tma_atom_d,
1122
+ tensormap_d_init_ptr,
1123
+ is_manager_warp=is_tma_warp,
1124
+ )
1125
+ # //////////////////////////////////////////////////////////////////////////////
1126
+ # Partition global tensor for TiledMMA_A/B/C
1127
+ # //////////////////////////////////////////////////////////////////////////////
1128
+ tidx, _, _ = cute.arch.thread_idx()
1129
+ warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
1130
+ if const_expr(self.pingpong):
1131
+ tidx = tidx % self.num_threads_per_warp_group
1132
+ warp_group_thread_layout = cute.make_layout(
1133
+ self.mma_warp_groups if not self.pingpong else 1,
1134
+ stride=self.num_threads_per_warp_group,
1135
+ )
1136
+ thr_mma = tiled_mma.get_slice(
1137
+ warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
1138
+ )
1139
+
1140
+ # //////////////////////////////////////////////////////////////////////////////
1141
+ # Make fragments
1142
+ # //////////////////////////////////////////////////////////////////////////////
1143
+ tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
1144
+ tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))
1145
+
1146
+ acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
1147
+ acc = cute.make_fragment(acc_shape, self.acc_dtype)
1148
+ if const_expr(self.fp8_slow_accum):
1149
+ acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
1150
+ else:
1151
+ acc_slow = None
1152
+
1153
+ if const_expr(self.pingpong):
1154
+ if warp_group_idx == 0:
1155
+ # WG0 needs a start signal at the very beginning
1156
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
1157
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
1158
+
1159
+ ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
1160
+ epi_read_state = make_pipeline_state(
1161
+ pipeline.PipelineUserType.Consumer, self.epi_c_stage
1162
+ )
1163
+ epi_producer_state = make_pipeline_state(
1164
+ pipeline.PipelineUserType.Producer, self.epi_c_stage
1165
+ )
1166
+ tile_scheduler = TileSchedulerCls()
1167
+ if const_expr(self.pingpong):
1168
+ if warp_idx >= 4:
1169
+ # Advance 2nd Math WG to the next work tile for the startup
1170
+ tile_scheduler.advance_to_next_work()
1171
+ # Advance 2nd Math WG pipeline states to the end of 1st Math WG
1172
+ ab_read_state.advance_iters(k_tile_cnt)
1173
+ epi_read_state.advance_iters(c_tile_cnt)
1174
+ epi_producer_state.advance_iters(c_tile_cnt)
1175
+ work_tile = tile_scheduler.initial_work_tile_info()
1176
+ if const_expr(varlen):
1177
+ # wait tensormap initialization complete before update
1178
+ tensormap_manager.fence_tensormap_initialization()
1179
+ # batch index of last tile
1180
+ last_batch_idx = cutlass.Int32(-1)
1181
+ while work_tile.is_valid_tile:
1182
+ tile_coord_mnkl = work_tile.tile_idx
1183
+ batch_idx = tile_coord_mnkl[3]
1184
+ if const_expr(varlen):
1185
+ is_group_changed = batch_idx != last_batch_idx
1186
+ last_batch_idx = batch_idx
1187
+ if is_group_changed:
1188
+ # construct tensor D based on real address, shape and stride information
1189
+ tensormap_manager.update_tensormap_shape(
1190
+ ((tensormap_d_ptr),),
1191
+ is_manager_warp=is_tma_warp,
1192
+ tensormap_smem_ptr=(tensormap_d_smem_ptr,),
1193
+ shapes=(cu_seqlens_m[batch_idx + 1],),
1194
+ orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
1195
+ )
1196
+
1197
+ ab_read_state, tiled_mma = self.mma(
1198
+ ab_pipeline,
1199
+ ab_read_state,
1200
+ tiled_mma,
1201
+ tCrA,
1202
+ tCrB,
1203
+ acc,
1204
+ acc_slow,
1205
+ k_tile_cnt,
1206
+ warp_group_idx,
1207
+ )
1208
+ if const_expr(self.pingpong):
1209
+ # Update starting mainloop pipeline state for the next tile
1210
+ ab_read_state.advance_iters(k_tile_cnt)
1211
+
1212
+ # /////////////////////////////////////////////////////////////////////////////
1213
+ # EPILOGUE
1214
+ # /////////////////////////////////////////////////////////////////////////////
1215
+ if const_expr(self.pingpong):
1216
+ self.pingpong_barrier_sync(warp_group_idx, "epi")
1217
+
1218
+ epilogue_barrier = pipeline.NamedBarrier(
1219
+ barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
1220
+ )
1221
+
1222
+ # Wait for all warp groups in the thread block to finish, because smem for tensor
1223
+ # A in the mainloop is reused in the epilogue if not persistent.
1224
+ if const_expr(not self.is_persistent):
1225
+ epilogue_barrier.arrive_and_wait()
1226
+
1227
+ if const_expr(varlen):
1228
+ # ensure the update to tensormap has completed before using it
1229
+ if is_group_changed:
1230
+ if is_tma_warp:
1231
+ tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
1232
+
1233
+ # Doesn't work when tile_N % 8 == 0 but tile_N % 16 != 0, since this always
1234
+ # gets st.matrix with num_matrices=4
1235
+ copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
1236
+ self.d_layout, elem_ty_d=self.d_dtype, elem_ty_acc=self.acc_dtype
1237
+ )
1238
+ copy_atom_C = cute.make_copy_atom(
1239
+ warp.StMatrix8x8x16bOp(
1240
+ self.d_layout.is_m_major_c(),
1241
+ num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
1242
+ ),
1243
+ cutlass.Float16, # this is just to get the right source layout
1244
+ )
1245
+ tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
1246
+ tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
1247
+ # (R2S, R2S_M, R2S_N, PIPE_D)
1248
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
1249
+ tRS_sD = thr_copy_r2s.partition_D(sD)
1250
+ # (R2S, R2S_M, R2S_N)
1251
+ tRS_rAcc = tiled_copy_r2s.retile(acc)
1252
+
1253
+ # Allocate D registers.
1254
+ tRS_rD_layout = cute.make_layout(thr_copy_r2s.partition_S(sD).shape[:3])
1255
+ tRS_rD = cute.make_fragment(tRS_rD_layout, self.acc_dtype)
1256
+
1257
+ if const_expr(mC_mnl is not None):
1258
+ copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
1259
+ tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
1260
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1261
+ tSR_sC = thr_copy_s2r.partition_S(sC)
1262
+ tRS_rC = cute.make_fragment(tRS_rD_layout, self.c_dtype)
1263
+ tSR_rC = thr_copy_s2r.retile(tRS_rC)
1264
+ else:
1265
+ thr_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
1266
+
1267
+ if const_expr(cu_seqlens_m is not None):
1268
+ mD_mn_tma = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl_tma)
1269
+ else:
1270
+ mD_mn_tma = mD_mnl_tma[None, None, batch_idx]
1271
+ # (bM, bN)
1272
+ gD = cute.local_tile(
1273
+ mD_mn_tma, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
1274
+ )
1275
+ tDgD_for_tma_partition = cute.zipped_divide(gD, self.epi_tile)
1276
+ bSG_sD, bSG_gD = cpasync.tma_partition(
1277
+ tma_atom_d,
1278
+ 0,
1279
+ cute.make_layout(1),
1280
+ cute.group_modes(sD, 0, 2),
1281
+ tDgD_for_tma_partition,
1282
+ )
1283
+
1284
+ if const_expr(mC_mnl is not None):
1285
+ if const_expr(cu_seqlens_m is not None):
1286
+ mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
1287
+ else:
1288
+ mC_mn = mC_mnl[None, None, batch_idx]
1289
+ # (bM, bN)
1290
+ gC = cute.local_tile(
1291
+ mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
1292
+ )
1293
+ tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
1294
+ bGS_sC, bGS_gC = cpasync.tma_partition(
1295
+ tma_atom_c,
1296
+ 0,
1297
+ cute.make_layout(1),
1298
+ cute.group_modes(sC, 0, 2),
1299
+ tCgC_for_tma_partition,
1300
+ )
1301
+
1302
+ epi_tile_num = const_expr(cute.size(tDgD_for_tma_partition, mode=[1]))
1303
+ epi_tile_shape = tDgD_for_tma_partition.shape[1]
1304
+ num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
1305
+ epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
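+ # epi_tile_layout enumerates the epi subtiles of this CTA tile in row-major order
+ # (the N mode is fastest), matching the stride (epi_tile_shape[1], 1) above.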
1306
+
1307
+ if const_expr(mC_mnl is not None):
1308
+ for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
1309
+ if is_tma_warp:
1310
+ epi_pipeline.producer_acquire(epi_producer_state)
1311
+ # Get the global memory coordinate for the current epi tile
1312
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1313
+ cute.copy(
1314
+ tma_atom_c,
1315
+ bGS_gC[None, gmem_coord],
1316
+ bGS_sC[None, epi_producer_state.index],
1317
+ tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1318
+ )
1319
+ # Epi pipeline's producer commit is a NOP
1320
+ epi_pipeline.producer_commit(epi_producer_state)
1321
+ epi_producer_state.advance()
1322
+
1323
+ for epi_idx in cutlass.range_constexpr(epi_tile_num):
1324
+ # Copy from acc to D registers
1325
+ for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
1326
+ tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1327
+ if const_expr(mC_mnl is not None):
1328
+ epi_pipeline.consumer_wait(epi_read_state)
1329
+ cute.copy(
1330
+ thr_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
1331
+ )
1332
+ # Fence to make sure shared memory read is visible to TMA load
1333
+ cute.arch.fence_proxy(
1334
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1335
+ )
1336
+ cute.arch.sync_warp()
1337
+ with cute.arch.elect_one():
1338
+ epi_pipeline.consumer_release(epi_read_state)
1339
+ epi_read_state.advance()
1340
+ if const_expr(epi_idx + self.epi_c_stage < epi_tile_num):
1341
+ if is_tma_warp:
1342
+ epi_pipeline.producer_acquire(epi_producer_state)
1343
+ # Get the global memory coordinate for the current epi tile
1344
+ gmem_coord = epi_tile_layout.get_hier_coord(
1345
+ epi_idx + self.epi_c_stage
1346
+ )
1347
+ cute.copy(
1348
+ tma_atom_c,
1349
+ bGS_gC[None, gmem_coord],
1350
+ bGS_sC[None, epi_producer_state.index],
1351
+ tma_bar_ptr=epi_pipeline.producer_get_barrier(
1352
+ epi_producer_state
1353
+ ),
1354
+ )
1355
+ # Epi pipeline's producer commit is a NOP
1356
+ epi_pipeline.producer_commit(epi_producer_state)
1357
+ epi_producer_state.advance()
1358
+ tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(self.acc_dtype))
1359
+ # Type conversion
1360
+ tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
1361
+ tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
1362
+ # Copy from D registers to shared memory
1363
+ epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sD, mode=[3])
1364
+ cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
1365
+ # Fence and barrier to make sure shared memory store is visible to TMA store
1366
+ cute.arch.fence_proxy(
1367
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1368
+ )
1369
+ epilogue_barrier.arrive_and_wait()
1370
+ # Get the global memory coordinate for the current epi tile
1371
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1372
+ # Copy from shared memory to global memory
1373
+ if is_tma_warp:
1374
+ if const_expr(varlen):
1375
+ tma_desc_ptr = tensormap_manager.get_tensormap_ptr(
1376
+ tensormap_d_ptr,
1377
+ cute.AddressSpace.generic,
1378
+ )
1379
+ else:
1380
+ tma_desc_ptr = None
1381
+ cute.copy(
1382
+ tma_atom_d,
1383
+ bSG_sD[None, epi_buffer],
1384
+ bSG_gD[None, gmem_coord],
1385
+ tma_desc_ptr=tma_desc_ptr,
1386
+ )
1387
+ cute.arch.cp_async_bulk_commit_group()
1388
+ cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
1389
+ epilogue_barrier.arrive_and_wait()
1390
+
1391
+ if const_expr(self.pingpong):
1392
+ # Update starting load/store pipeline states for the next tile
1393
+ epi_read_state.advance_iters(c_tile_cnt)
1394
+ epi_producer_state.advance_iters(c_tile_cnt)
1395
+ # With pingpong, 2 WGs write two different output tiles to the same smem,
1396
+ # so we have to make sure the smem content has been fully read before signaling
1397
+ # the next WG's epilogue.
1398
+ if warp_idx == 0 or warp_idx == 4:
1399
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
1400
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
1401
+
1402
+ tile_scheduler.advance_to_next_work(
1403
+ advance_count=1 if not self.pingpong else self.mma_warp_groups
1404
+ )
1405
+ work_tile = tile_scheduler.get_current_work()
1406
+ # End of persistent scheduler loop
1407
+
1408
+ if const_expr(not self.pingpong):
1409
+ if warp_idx == 0:
1410
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
1411
+
1412
+ @cute.jit
1413
+ def load_AB(
1414
+ self,
1415
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1416
+ ab_producer_state: cutlass.pipeline.PipelineState,
1417
+ copy_A: Callable,
1418
+ tAgA: cute.Tensor,
1419
+ tAsA: cute.Tensor,
1420
+ copy_B: Callable,
1421
+ tBgB: cute.Tensor,
1422
+ tBsB: cute.Tensor,
1423
+ ) -> cutlass.pipeline.PipelineState:
1424
+ k_tile_cnt = cute.size(tAgA, mode=[1])
1425
+ # Peek (try_wait) AB buffer empty status for the first k_tile
1426
+ peek_ab_empty_status = cutlass.Boolean(True)
1427
+ if 0 < k_tile_cnt:
1428
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1429
+ # /////////////////////////////////////////////////////////////////////////
1430
+ # TMA load
1431
+ # /////////////////////////////////////////////////////////////////////////
1432
+ for k_tile in cutlass.range(k_tile_cnt, unroll=1):
1433
+ # Wait for A/B buffers to be empty before loading into them
1434
+ # Also sets the transaction barrier for the A/B buffers
1435
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
1436
+ tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
1437
+ copy_A(
1438
+ tAgA[None, k_tile],
1439
+ tAsA[None, ab_producer_state.index],
1440
+ tma_bar_ptr=tma_bar_ptr,
1441
+ )
1442
+ copy_B(
1443
+ tBgB[None, k_tile],
1444
+ tBsB[None, ab_producer_state.index],
1445
+ tma_bar_ptr=tma_bar_ptr,
1446
+ )
1447
+ # Mainloop pipeline's producer commit is a NOP
1448
+ ab_pipeline.producer_commit(ab_producer_state)
1449
+ ab_producer_state.advance()
1450
+ peek_ab_empty_status = cutlass.Boolean(True)
1451
+ if k_tile + 1 < k_tile_cnt:
1452
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1453
+ return ab_producer_state
1454
+
1455
+ @cute.jit
1456
+ def load_AB_gather_A(
1457
+ self,
1458
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1459
+ ab_producer_state: cutlass.pipeline.PipelineState,
1460
+ thr_copy_A: cute.core.ThrCopy,
1461
+ mA: cute.Tensor,
1462
+ tAsA: cute.Tensor,
1463
+ gAIdx: cute.Tensor,
1464
+ copy_B: Callable,
1465
+ tBgB: cute.Tensor,
1466
+ tBsB: cute.Tensor,
1467
+ limit_A: Tuple[Int32, Int32],
1468
+ ) -> cutlass.pipeline.PipelineState:
1469
+ # (atom_v, CPY_M, 1, RestK)
1470
+ limit_m, limit_k = limit_A
1471
+ limit_m = min(limit_m, self.tile_shape_mnk[0]) # To avoid writing beyond smem limit
1472
+ cA = cute.make_identity_tensor(cute.select(self.tile_shape_mnk, [0, 2]))
1473
+ tAcA = thr_copy_A.partition_S(cA)
1474
+ t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
1475
+ # Instead of comparing tAcA to limit_m, we compare t0AcA to limit_m - tAcA[0][0]
1476
+ # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
1477
+ # This is so that when we do the comparison, t0AcA is known at compile time.
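+ # Illustrative example (values assumed, not from a real launch): if tAcA[0][0] = 8
+ # and limit_m = 100, then instead of checking tAcA[m][0] < 100 per thread, we check
+ # t0AcA[m][0] < 92, where t0AcA[m][0] is a compile-time constant for each m.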
1478
+ limit_m = limit_m - tAcA[0][0]
1479
+ # Read indices for A
1480
+ rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
1481
+ m_idx = cute.make_fragment(rows_per_thread, Int32)
1482
+ for m in cutlass.range(rows_per_thread):
1483
+ row_idx = tAcA[0, m, 0][0]
1484
+ if t0AcA[0, m, 0][0] < limit_m:
1485
+ m_idx[m] = gAIdx[row_idx]
1486
+ else:
1487
+ m_idx[m] = -1
1488
+ elems_per_load = cute.size(tAsA.shape[0][0])
1489
+ # (m, (bK, RestK))
1490
+ mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
1491
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1492
+ k_tile_cnt = cute.size(tBgB, mode=[1])
1493
+ # Peek (try_wait) AB buffer empty status for the first k_tile
1494
+ peek_ab_empty_status = cutlass.Boolean(True)
1495
+ if 0 < k_tile_cnt:
1496
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1497
+ # /////////////////////////////////////////////////////////////////////////
1498
+ # TMA load on B and cp.async on A
1499
+ # /////////////////////////////////////////////////////////////////////////
1500
+ copy_A = partial(cute.copy, thr_copy_A)
1501
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
1502
+ # Wait for A/B buffers to be empty before loading into them
1503
+ # Also sets the transaction barrier for the A/B buffers
1504
+ ab_pipeline.producer_acquire(
1505
+ ab_producer_state,
1506
+ peek_ab_empty_status,
1507
+ # A tiny bit faster to rotate the warp that does TMA
1508
+ is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
1509
+ )
1510
+ # A bit faster to load B first while we calculate the predicate for A
1511
+ if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
1512
+ copy_B(
1513
+ tBgB[None, k_tile],
1514
+ tBsB[None, ab_producer_state.index],
1515
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1516
+ )
1517
+ # (m, bK)
1518
+ mA_cur = mA_k[None, (None, k_tile)]
1519
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1520
+ # (elems_per_load, thread_per_row)
1521
+ mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
1522
+ if t0AcA[0, m, 0][0] < limit_m:
1523
+ # There's only 1 load per row
1524
+ assert cute.size(tAcA.shape, mode=[2]) == 1
1525
+ ki = tAcA[0, 0, 0][1] // elems_per_load
1526
+ copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1527
+ # This tells mbarrier to track the completion of cp.async
1528
+ ab_pipeline.producer_commit(ab_producer_state)
1529
+ ab_producer_state.advance()
1530
+ peek_ab_empty_status = cutlass.Boolean(True)
1531
+ if k_tile + 1 < k_tile_cnt:
1532
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1533
+ # bound checking in the K dimension on the last k_tile
1534
+ if 0 < k_tile_cnt:
1535
+ k_tile = k_tile_cnt - 1
1536
+ ab_pipeline.producer_acquire(
1537
+ ab_producer_state,
1538
+ peek_ab_empty_status,
1539
+ is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
1540
+ )
1541
+ if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
1542
+ copy_B(
1543
+ tBgB[None, k_tile],
1544
+ tBsB[None, ab_producer_state.index],
1545
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1546
+ )
1547
+ assert tAcA.shape[2] == 1 # there's only 1 load along the K dimension
1548
+ tApA = cute.make_fragment(1, cutlass.Boolean)
1549
+ tApA[0] = tAcA[0, 0, 0][1] < limit_k
1550
+ # (m, bK)
1551
+ mA_cur = mA_k[None, (None, k_tile)]
1552
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1553
+ # (elems_per_load, thread_per_row)
1554
+ mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
1555
+ if t0AcA[0, m, 0][0] < limit_m:
1556
+ # There's only 1 load per row
1557
+ assert cute.size(tAcA.shape, mode=[2]) == 1
1558
+ ki = tAcA[0, 0, 0][1] // elems_per_load
1559
+ # copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA)
1560
+ # TODO
1561
+ copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1562
+ ab_pipeline.producer_commit(ab_producer_state)
1563
+ ab_producer_state.advance()
1564
+ return ab_producer_state
1565
+
1566
+ @cute.jit
1567
+ def mma(
1568
+ self,
1569
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1570
+ ab_read_state: cutlass.pipeline.PipelineState,
1571
+ tiled_mma: cute.TiledMma,
1572
+ tCrA: cute.Tensor,
1573
+ tCrB: cute.Tensor,
1574
+ acc: cute.Tensor,
1575
+ acc_slow: Optional[cute.Tensor],
1576
+ k_tile_cnt: Int32,
1577
+ warp_group_idx: Int32,
1578
+ ) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
1579
+ # /////////////////////////////////////////////////////////////////////////////
1580
+ # Prologue MMAs
1581
+ # /////////////////////////////////////////////////////////////////////////////
1582
+ k_pipe_mmas = 1
1583
+ ab_release_state = ab_read_state.clone()
1584
+ num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
1585
+ if const_expr(self.pingpong):
1586
+ self.pingpong_barrier_sync(warp_group_idx, stage="mma")
1587
+ peek_ab_full_status = cutlass.Boolean(True)
1588
+ if 0 < k_tile_cnt:
1589
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1590
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
1591
+ num_k_blocks = cute.size(tCrA, mode=[2])
1592
+ # TODO: this is probably not correct if k_tile_cnt == 0
1593
+ for k_tile in cutlass.range(num_prologue_mma):
1594
+ # Wait for A/B buffer to be ready
1595
+ ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
1596
+ warpgroup.fence()
1597
+ for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1598
+ k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
1599
+ cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1600
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
1601
+ warpgroup.commit_group()
1602
+ ab_read_state.advance()
1603
+ peek_ab_full_status = cutlass.Boolean(True)
1604
+ if k_tile + 1 < k_tile_cnt:
1605
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1606
+ if const_expr(self.fp8_slow_accum):
1607
+ warpgroup.wait_group(0)
1608
+ acc_slow.store(acc.load())
1609
+
1610
+ # /////////////////////////////////////////////////////////////////////////////
1611
+ # MAINLOOP
1612
+ # /////////////////////////////////////////////////////////////////////////////
1613
+ for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
1614
+ # Wait for TMA copies to complete
1615
+ ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
1616
+ # WGMMA
1617
+ warpgroup.fence()
1618
+ if const_expr(self.fp8_slow_accum):
1619
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
1620
+ for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1621
+ k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
1622
+ cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1623
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
1624
+ warpgroup.commit_group()
1625
+ # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
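+ # With k_pipe_mmas = 1, wait_group(1) returns once at most one wgmma group is
+ # still in flight, i.e. the group issued for the previous k_tile has completed,
+ # so its A/B smem buffers (tracked by ab_release_state, which trails
+ # ab_read_state by k_pipe_mmas) can safely be released below.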
1626
+ if const_expr(not self.fp8_slow_accum):
1627
+ warpgroup.wait_group(k_pipe_mmas)
1628
+ else:
1629
+ warpgroup.wait_group(0)
1630
+ acc_slow.store(acc_slow.load() + acc.load())
1631
+ ab_pipeline.consumer_release(ab_release_state)
1632
+ ab_read_state.advance()
1633
+ ab_release_state.advance()
1634
+ peek_ab_full_status = cutlass.Boolean(True)
1635
+ if k_tile + 1 < k_tile_cnt:
1636
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1637
+ if const_expr(self.pingpong):
1638
+ # Cue for next WG's MMA to start
1639
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
1640
+ if const_expr(not self.fp8_slow_accum):
1641
+ # fp8_slow_accum would have already called wait_group(0) inside the loop
1642
+ warpgroup.wait_group(0)
1643
+ for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
1644
+ ab_pipeline.consumer_release(ab_release_state)
1645
+ ab_release_state.advance()
1646
+ if const_expr(self.fp8_slow_accum):
1647
+ acc.store(acc_slow.load())
1648
+ # If we don't return the tiled_mma, we get compiler error
1649
+ # "operand #0 does not dominate this use"
1650
+ return ab_read_state, tiled_mma
1651
+
1652
+ def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
1653
+ assert stage in ["mma", "epi"]
1654
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1655
+ cute.arch.barrier(
1656
+ barrier_id=int(barrier) + warp_group_idx,
1657
+ number_of_threads=2 * self.num_threads_per_warp_group,
1658
+ )
1659
+
1660
+ def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
1661
+ assert stage in ["mma", "epi"]
1662
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1663
+ cute.arch.barrier_arrive(
1664
+ barrier_id=int(barrier) + warp_group_idx,
1665
+ number_of_threads=2 * self.num_threads_per_warp_group,
1666
+ )
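+ # The sync/arrive pair implements the ping-pong handoff: a warp group blocks in
+ # pingpong_barrier_sync on its own named barrier (MmaWG0/EpiWG0 offset by its
+ # warp group index) until the other warp group releases it by calling
+ # pingpong_barrier_arrive on that same barrier id.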
1667
+
1668
+ @staticmethod
1669
+ def _compute_stages(
1670
+ tile_shape_mnk: Tuple[int, int, int],
1671
+ epi_tile: Optional[Tuple[int, int]],
1672
+ a_dtype: Type[cutlass.Numeric],
1673
+ b_dtype: Type[cutlass.Numeric],
1674
+ d_dtype: Type[cutlass.Numeric],
1675
+ c_dtype: Optional[Type[cutlass.Numeric]],
1676
+ smem_capacity: int,
1677
+ occupancy: int,
1678
+ overlap_sD_sA: bool,
1679
+ ) -> Tuple[int, int, int]:
1680
+ """Computes the number of stages for A/B/C operands based on heuristics.
1681
+
1682
+ :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
1683
+ :type tile_shape_mnk: Tuple[int, int, int]
1684
+ :param a_dtype: Data type of operand A.
1685
+ :type a_dtype: type[cutlass.Numeric]
1686
+ :param b_dtype: Data type of operand B.
1687
+ :type b_dtype: type[cutlass.Numeric]
1688
+ :param smem_capacity: Total available shared memory capacity in bytes.
1689
+ :type smem_capacity: int
1690
+ :param occupancy: Target number of CTAs per SM (occupancy).
1691
+ :type occupancy: int
1692
+
1693
+ :return: A tuple containing the computed number of stages for:
1694
+ (A/B operand stages, epilogue D stages, epilogue C stages)
1695
+ :rtype: Tuple[int, int, int]
1696
+ """
1697
+
1698
+ epi_stage = 2
1699
+ if overlap_sD_sA:
1700
+ epi_bytes = 0
1701
+ else:
1702
+ d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8
1703
+ epi_bytes = d_bytes_per_stage * epi_stage
1704
+ epi_c_stage = 0 if c_dtype is None else 2
1705
+ if c_dtype is not None:
1706
+ epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
1707
+
1708
+ a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
1709
+ b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
1710
+ ab_bytes_per_stage = (
1711
+ cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
1712
+ )
1713
+ mbar_helpers_bytes = 1024
1714
+
1715
+ remaining_bytes = (
1716
+ (smem_capacity - occupancy * 1024) // occupancy - mbar_helpers_bytes - epi_bytes
1717
+ )
1718
+ ab_stage = remaining_bytes // ab_bytes_per_stage
1719
+
1720
+ # Refine epilogue stages:
1721
+ # Calculate remaining smem after allocating for A/B stages and reserved bytes
1722
+ # Add remaining unused smem to epilogue
1723
+ if not overlap_sD_sA:
1724
+ epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // d_bytes_per_stage
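+ # Worked example with assumed values (not a guaranteed configuration): for a
+ # (128, 256, 64) fp16 tile, bf16 D, no C, epi_tile (64, 32), occupancy 1 and
+ # smem_capacity = 232 * 1024:
+ #   epi_bytes          = 2 * 64 * 32 * 2      = 8192
+ #   ab_bytes_per_stage = (128 + 256) * 64 * 2 = 49152
+ #   remaining_bytes    = 232 * 1024 - 1024 - 1024 - 8192 = 227328
+ #   ab_stage           = 227328 // 49152      = 4
+ #   epi_stage          = 2 + (227328 - 4 * 49152) // 4096 = 9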
1725
+ return ab_stage, epi_stage, epi_c_stage
1726
+
1727
+ @staticmethod
1728
+ def _sm90_compute_tile_shape_or_override(
1729
+ tile_shape_mnk: Tuple[int, int, int],
1730
+ atom_layout_mnk: Tuple[int, int, int],
1731
+ element_type: Type[cutlass.Numeric],
1732
+ epi_tile_override: Tuple[int, int] | None = None,
1733
+ ) -> Tuple[int, int]:
1734
+ """Compute the epilogue tile shape or use override if provided.
1735
+
1736
+ :param tile_shape_mnk: CTA tile shape (M,N,K)
1737
+ :type tile_shape_mnk: Tuple[int, int, int]
1738
+ :param element_type: Data type of elements
1739
+ :type element_type: type[cutlass.Numeric]
1740
+ :param atom_layout_mnk: Atom layout (M,N,K) of the tiled MMA
1741
+ :type atom_layout_mnk: Tuple[int, int, int]
1742
+ :param epi_tile_override: Optional override for epilogue tile shape
1743
+ :type epi_tile_override: Tuple[int, int] or None
1744
+
1745
+ :return: Computed epilogue tile shape
1746
+ :rtype: Tuple[int, int]
1747
+ """
1748
+ if epi_tile_override is not None:
1749
+ return epi_tile_override
1750
+ if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
1751
+ tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
1752
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
1753
+ elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
1754
+ tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0]))
1755
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
1756
+ else:
1757
+ # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
1758
+ # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
1759
+ # M dimension first, then move to the N dimension. But the accumulator in registers
1760
+ # iterates along the N dimension first, then moves to the M dimension.
1761
+ # We could change the epilogue to accommodate this,
1762
+ # but it's easier to just set epi_tile_m = 64.
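+ # For illustration (assumed shapes): a (128, 256, 64) tile with atom_layout
+ # (2, 1, 1) takes the first branch and gets epi_tile (128, 32), while the same
+ # tile with atom_layout (1, 2, 1) falls through to this branch and gets
+ # epi_tile (64, 32) for a 16-bit output (or (64, 64) for an 8-bit output).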
1763
+ n_perf = 64 if element_type.width == 8 else 32
1764
+ tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
1765
+ tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
1766
+ return (tile_m, tile_n)
1767
+
1768
+ @staticmethod
1769
+ def _make_smem_layouts(
1770
+ tile_shape_mnk: Tuple[int, int, int],
1771
+ epi_tile: Tuple[int, int],
1772
+ a_dtype: Type[cutlass.Numeric],
1773
+ a_layout: cutlass.utils.LayoutEnum,
1774
+ b_dtype: Type[cutlass.Numeric],
1775
+ b_layout: cutlass.utils.LayoutEnum,
1776
+ ab_stage: int,
1777
+ d_dtype: Type[cutlass.Numeric],
1778
+ d_layout: cutlass.utils.LayoutEnum,
1779
+ epi_stage: int,
1780
+ c_dtype: Optional[Type[cutlass.Numeric]],
1781
+ c_layout: Optional[cutlass.utils.LayoutEnum],
1782
+ epi_c_stage: int,
1783
+ ) -> Tuple[
1784
+ cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
1785
+ ]:
1786
+ """Create shared memory layouts for A, B, and C tensors.
1787
+
1788
+ :param tile_shape_mnk: CTA tile shape (M,N,K)
1789
+ :type tile_shape_mnk: Tuple[int, int, int]
1790
+ :param epi_tile: Epilogue tile shape
1791
+ :type epi_tile: Tuple[int, int]
1792
+ :param a_dtype: Data type for matrix A
1793
+ :type a_dtype: type[cutlass.Numeric]
1794
+ :param a_layout: Layout enum for matrix A
1795
+ :type a_layout: cutlass.utils.LayoutEnum
1796
+ :param b_dtype: Data type for matrix B
1797
+ :type b_dtype: type[cutlass.Numeric]
1798
+ :param b_layout: Layout enum for matrix B
1799
+ :type b_layout: cutlass.utils.LayoutEnum
1800
+ :param ab_stage: Number of stages for A/B tensors
1801
+ :type ab_stage: int
1802
+ :param d_dtype: Data type for output matrix D
1803
+ :type d_dtype: type[cutlass.Numeric]
1804
+ :param d_layout: Layout enum for output matrix D
1805
+ :type d_layout: cutlass.utils.LayoutEnum
1806
+ :param epi_stage: Number of epilogue stages
1807
+ :type epi_stage: int
1808
+
1809
+ :return: Tuple of shared memory layouts for A, B, D, and C (None when C is unused)
1810
+ :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]]
1811
+ """
1812
+ a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
1813
+
1814
+ a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
1815
+ b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
1816
+ a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
1817
+ a_smem_layout_atom = warpgroup.make_smem_layout_atom(
1818
+ sm90_utils.get_smem_layout_atom(
1819
+ a_layout,
1820
+ a_dtype,
1821
+ a_major_mode_size,
1822
+ ),
1823
+ a_dtype,
1824
+ )
1825
+ a_smem_layout_staged = cute.tile_to_shape(
1826
+ a_smem_layout_atom,
1827
+ cute.append(a_smem_shape, ab_stage),
1828
+ order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
1829
+ )
1830
+
1831
+ b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None))
1832
+
1833
+ b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
1834
+ b_smem_layout_atom = warpgroup.make_smem_layout_atom(
1835
+ sm90_utils.get_smem_layout_atom(
1836
+ b_layout,
1837
+ b_dtype,
1838
+ b_major_mode_size,
1839
+ ),
1840
+ b_dtype,
1841
+ )
1842
+ b_smem_layout_staged = cute.tile_to_shape(
1843
+ b_smem_layout_atom,
1844
+ cute.append(b_smem_shape, ab_stage),
1845
+ order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
1846
+ )
1847
+
1848
+ d_smem_shape = epi_tile
1849
+ d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
1850
+ d_smem_layout_atom = warpgroup.make_smem_layout_atom(
1851
+ sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
1852
+ d_dtype,
1853
+ )
1854
+ epi_smem_layout_staged = cute.tile_to_shape(
1855
+ d_smem_layout_atom,
1856
+ cute.append(d_smem_shape, epi_stage),
1857
+ order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
1858
+ )
1859
+
1860
+ if c_dtype is not None:
1861
+ assert c_layout is not None
1862
+ c_smem_shape = epi_tile
1863
+ c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0]
1864
+ c_smem_layout_atom = warpgroup.make_smem_layout_atom(
1865
+ sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
1866
+ c_dtype,
1867
+ )
1868
+ epi_c_smem_layout_staged = cute.tile_to_shape(
1869
+ c_smem_layout_atom,
1870
+ cute.append(c_smem_shape, epi_c_stage),
1871
+ order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
1872
+ )
1873
+ else:
1874
+ epi_c_smem_layout_staged = None
1875
+
1876
+ return (
1877
+ a_smem_layout_staged,
1878
+ b_smem_layout_staged,
1879
+ epi_smem_layout_staged,
1880
+ epi_c_smem_layout_staged,
1881
+ )
1882
+
1883
+ @staticmethod
1884
+ def _make_tma_epi_atoms_and_tensors(
1885
+ tensor_d: cute.Tensor,
1886
+ epi_smem_layout_staged: cute.ComposedLayout,
1887
+ epi_tile: Tuple[int, int],
1888
+ store_or_load: str,
1889
+ ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1890
+ """Create TMA atoms and tensors for storing D or loading C.
1891
+
1892
+ :param tensor_d: Output tensor D (when storing) or source tensor C (when loading)
1893
+ :type tensor_d: cute.Tensor
1894
+ :param epi_smem_layout_staged: Shared memory layout for epilogue
1895
+ :type epi_smem_layout_staged: cute.ComposedLayout
1896
+ :param epi_tile: Epilogue tile shape
1897
+ :type epi_tile: Tuple[int, int]
1898
+
1899
+ :return: TMA atom and TMA tensor for D (store) or C (load)
1900
+ :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1901
+ """
1902
+ assert store_or_load in ["load", "store"]
1903
+ epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
1904
+ d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
1905
+ op = (
1906
+ cpasync.CopyBulkTensorTileG2SOp()
1907
+ if store_or_load == "load"
1908
+ else cpasync.CopyBulkTensorTileS2GOp()
1909
+ )
1910
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
1911
+ op, tensor_d, epi_smem_layout, d_cta_v_layout
1912
+ )
1913
+ return tma_atom_d, tma_tensor_d
1914
+
1915
+ @staticmethod
1916
+ def _make_tma_atoms_and_tensors(
1917
+ tensor: cute.Tensor,
1918
+ smem_layout_staged: cute.ComposedLayout,
1919
+ smem_tile: Tuple[int, int],
1920
+ mcast_dim: int,
1921
+ ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1922
+ """Create TMA atoms and tensors for input tensors.
1923
+
1924
+ :param tensor: Input tensor (A or B)
1925
+ :type tensor: cute.Tensor
1926
+ :param smem_layout_staged: Shared memory layout for the tensor
1927
+ :type smem_layout_staged: cute.ComposedLayout
1928
+ :param smem_tile: Shared memory tile shape
1929
+ :type smem_tile: Tuple[int, int]
1930
+ :param mcast_dim: Multicast dimension
1931
+ :type mcast_dim: int
1932
+
1933
+ :return: TMA atom and tensor
1934
+ :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1935
+ """
1936
+ op = (
1937
+ cpasync.CopyBulkTensorTileG2SOp()
1938
+ if mcast_dim == 1
1939
+ else cpasync.CopyBulkTensorTileG2SMulticastOp()
1940
+ )
1941
+
1942
+ smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
1943
+ tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
1944
+ op,
1945
+ tensor,
1946
+ smem_layout,
1947
+ smem_tile,
1948
+ num_multicast=mcast_dim,
1949
+ )
1950
+ return tma_atom, tma_tensor
1951
+
1952
+ def _make_gmem_tiled_copy_A(self, dtype, major_mode, num_threads, copy_bits=128):
1953
+ atom_async_copy = cute.make_copy_atom(
1954
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
1955
+ dtype,
1956
+ num_bits_per_copy=copy_bits,
1957
+ )
1958
+ copy_elems = copy_bits // dtype.width
1959
+ shape_dim_1 = cute.size(self.tile_shape_mnk[2]) // copy_elems
1960
+ # thread layout for copy
1961
+ thread_layout = cute.make_layout(
1962
+ (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
1963
+ )
1964
+ if major_mode != cutlass.utils.LayoutEnum.ROW_MAJOR:
1965
+ shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
1966
+ thread_layout = cute.make_layout(
1967
+ (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
1968
+ )
1969
+ # Value layout for copy
1970
+ value_layout = (
1971
+ cute.make_layout((1, copy_elems))
1972
+ if major_mode == cutlass.utils.LayoutEnum.ROW_MAJOR
1973
+ else cute.make_layout((copy_elems, 1))
1974
+ )
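+ # Example layout (assumed values): with fp16, copy_bits = 128 gives
+ # copy_elems = 8; for a K-tile of 64 and 128 threads this yields a (16, 8)
+ # row-major thread layout with a (1, 8) value layout, i.e. each thread issues
+ # one 16-byte cp.async per (row, 8-element K chunk).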
1975
+ return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
1976
+
1977
+ @staticmethod
1978
+ def is_valid_dtypes(
1979
+ a_dtype: Type[cutlass.Numeric],
1980
+ b_dtype: Type[cutlass.Numeric],
1981
+ acc_dtype: Type[cutlass.Numeric],
1982
+ d_dtype: Type[cutlass.Numeric],
1983
+ a_major: str,
1984
+ b_major: str,
1985
+ ) -> bool:
1986
+ """
1987
+ Check if the dtypes are valid
1988
+
1989
+ :param a_dtype: The data type of tensor A
1990
+ :type a_dtype: Type[cutlass.Numeric]
1991
+ :param b_dtype: The data type of tensor B
1992
+ :type b_dtype: Type[cutlass.Numeric]
1993
+ :param acc_dtype: The data type of the accumulator
1994
+ :type acc_dtype: Type[cutlass.Numeric]
1995
+ :param d_dtype: The data type of the output tensor
1996
+ :type d_dtype: Type[cutlass.Numeric]
1997
+ :param a_major: major mode of tensor A
1998
+ :type a_major: str
1999
+ :param b_major: major mode of tensor B
2000
+ :type b_major: str
2001
+
2002
+ :return: True if the dtypes are valid, False otherwise
2003
+ :rtype: bool
2004
+ """
2005
+ is_valid = True
2006
+ if a_dtype not in {
2007
+ cutlass.Float16,
2008
+ cutlass.BFloat16,
2009
+ cutlass.Float8E4M3FN,
2010
+ cutlass.Float8E5M2,
2011
+ }:
2012
+ is_valid = False
2013
+ # supported (tested) b_dtype values
2014
+ if b_dtype not in {
2015
+ cutlass.Float16,
2016
+ cutlass.BFloat16,
2017
+ cutlass.Float8E4M3FN,
2018
+ cutlass.Float8E5M2,
2019
+ }:
2020
+ is_valid = False
2021
+ if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
2022
+ is_valid = False
2023
+ # supported (tested) d_dtype values
2024
+ if d_dtype not in {
2025
+ cutlass.Float32,
2026
+ cutlass.Float16,
2027
+ cutlass.BFloat16,
2028
+ cutlass.Float8E4M3FN,
2029
+ cutlass.Float8E5M2,
2030
+ }:
2031
+ is_valid = False
2032
+ # make sure a_dtype == b_dtype for 16-bit types (Float16 / BFloat16)
2033
+ if a_dtype.width == 16 and a_dtype != b_dtype:
2034
+ is_valid = False
2035
+ # make sure a_dtype.width == b_dtype.width (e.g., Float8E4M3FN or Float8E5M2)
2036
+ if a_dtype.width != b_dtype.width:
2037
+ is_valid = False
2038
+
2039
+ # for Float8 types, this implementation only supports k-major layout
2040
+ if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
2041
+ is_valid = False
2042
+
2043
+ return is_valid
2044
+
2045
+
2046
+ def run(
2047
+ mnkl: Tuple[int, int, int, int],
2048
+ a_dtype: Type[cutlass.Numeric],
2049
+ b_dtype: Type[cutlass.Numeric],
2050
+ d_dtype: Type[cutlass.Numeric],
2051
+ c_dtype: Optional[Type[cutlass.Numeric]],
2052
+ acc_dtype: Type[cutlass.Numeric],
2053
+ a_major: str,
2054
+ b_major: str,
2055
+ d_major: str,
2056
+ c_major: str,
2057
+ tile_shape_mnk: Tuple[int, int, int],
2058
+ cluster_shape_mn: Tuple[int, int],
2059
+ tolerance: float,
2060
+ warmup_iterations: int,
2061
+ iterations: int,
2062
+ skip_ref_check: bool,
2063
+ persistent: bool,
2064
+ dynamic_persistent: bool,
2065
+ pingpong: bool,
2066
+ varlen_m: bool,
2067
+ gather_A: bool,
2068
+ fp8_fast_accum: bool,
2069
+ **kwargs,
2070
+ ):
2071
+ """
2072
+ Prepare A/B/D/C tensors, launch the GPU kernel, and run the reference check.
2073
+
2074
+ :param mnkl: Problem size (M, N, K, L)
2075
+ :type mnkl: Tuple[int, int, int, int]
2076
+ :param a_dtype: Data type for input tensor A
2077
+ :type a_dtype: Type[cutlass.Numeric]
2078
+ :param b_dtype: Data type for input tensor B
2079
+ :type b_dtype: Type[cutlass.Numeric]
2080
+ :param d_dtype: Data type for output tensor D
2081
+ :type d_dtype: Type[cutlass.Numeric]
2082
+ :param acc_dtype: Data type for accumulation during matrix multiplication
2083
+ :type acc_dtype: Type[cutlass.Numeric]
2084
+ :param a_major/b_major/d_major: Memory layout of tensors A/B/D
2085
+ :type a_major/b_major/d_major: str
2086
+ :param tile_shape_mnk: CTA tile shape (M, N, K)
2087
+ :type tile_shape_mnk: Tuple[int, int, int]
2088
+ :param cluster_shape_mn: Cluster shape (M, N)
2089
+ :type cluster_shape_mn: Tuple[int, int]
2090
+ :param tolerance: Tolerance value for reference validation comparison
2091
+ :type tolerance: float
2092
+ :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
2093
+ :type warmup_iterations: int, optional
2094
+ :param iterations: Number of benchmark iterations to run, defaults to 1
2095
+ :type iterations: int, optional
2096
+ :param skip_ref_check: Whether to skip reference result validation, defaults to False
2097
+ :type skip_ref_check: bool, optional
2098
+ """
2099
+
2100
+ if dynamic_persistent:
2101
+ persistent = True
2102
+
2103
+ print("Running Hopper Dense GEMM with:")
2104
+ print(f"mnkl: {mnkl}")
2105
+ print(
2106
+ f"A dtype: {a_dtype}, B dtype: {b_dtype}, D dtype: {d_dtype}, C_dtype: {c_dtype}, Acc dtype: {acc_dtype}"
2107
+ )
2108
+ print(f"Matrix majors - A: {a_major}, B: {b_major}, D: {d_major}")
2109
+ print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
2110
+ print(f"Tolerance: {tolerance}")
2111
+ print(f"Warmup iterations: {warmup_iterations}")
2112
+ print(f"Iterations: {iterations}")
2113
+ print(f"Skip reference checking: {skip_ref_check}")
2114
+
2115
+ # Unpack parameters
2116
+ m, n, k, l = mnkl
2117
+ cluster_shape_mnk = (*cluster_shape_mn, 1)
2118
+
2119
+ # Skip unsupported types
2120
+ if not HopperWgmmaGemmKernel.is_valid_dtypes(
2121
+ a_dtype, b_dtype, acc_dtype, d_dtype, a_major, b_major
2122
+ ):
2123
+ raise TypeError(
2124
+ f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {d_dtype}, {a_major=}, {b_major=}"
2125
+ )
2126
+
2127
+ # Prepare pytorch tensors: A, B, and optionally C as inputs, D as output
2128
+ if not torch.cuda.is_available():
2129
+ raise RuntimeError("GPU is required to run this example!")
2130
+
2131
+ torch.manual_seed(1111)
2132
+
2133
+ # Create and permute tensor A/B/C
2134
+ def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
2135
+ # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
2136
+ # else : (l, mode0, mode1) -> (mode0, mode1, l)
2137
+ shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
2138
+ permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
2139
+ is_unsigned = dtype in {cutlass.Uint8}
2140
+ # Temporarily use uint8 as torch does not support fp8 type
2141
+ torch_dtype = cutlass_torch.dtype(dtype)
2142
+ gen_dtype = (
2143
+ torch_dtype
2144
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
2145
+ else torch.bfloat16
2146
+ )
2147
+
2148
+ # Create dtype torch tensor (cpu)
2149
+ torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
2150
+ shape,
2151
+ gen_dtype,
2152
+ permute_order=permute_order,
2153
+ # init_type=cutlass.torch.TensorInitType.RANDOM,
2154
+ # init_config=cutlass.torch.RandomInitConfig(
2155
+ # min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
2156
+ # ),
2157
+ init_type=cutlass.torch.TensorInitType.GAUSSIAN,
2158
+ init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
2159
+ ).to(torch_dtype)
2160
+ # Create dtype torch tensor (gpu)
2161
+ torch_tensor = torch_tensor_cpu.cuda()
2162
+
2163
+ # Create f32 torch tensor (cpu)
2164
+ f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
2165
+
2166
+ # Create dtype cute tensor (gpu)
2167
+ torch_tensor_view = (
2168
+ torch_tensor
2169
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
2170
+ else torch_tensor.view(torch.uint8)
2171
+ )
2172
+ cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
2173
+ cute_tensor.element_type = dtype
2174
+ if is_dynamic_layout:
2175
+ cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
2176
+ cute_tensor = cute_tensor.mark_compact_shape_dynamic(
2177
+ mode=(1 if not is_mode0_major else 0),
2178
+ stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
2179
+ divisibility=(128 // dtype.width),
2180
+ )
2181
+ cute_tensor = cutlass.torch.convert_cute_tensor(
2182
+ f32_torch_tensor,
2183
+ cute_tensor,
2184
+ dtype,
2185
+ is_dynamic_layout=is_dynamic_layout,
2186
+ )
2187
+
2188
+ return f32_torch_tensor, cute_tensor, torch_tensor
2189
+
2190
+ a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
2191
+ if gather_A:
2192
+ assert a_major == "k"
2193
+ a_idx = torch.randperm(l * m, dtype=torch.int32, device="cuda")
2194
+ from einops import rearrange
2195
+
2196
+ a = rearrange(rearrange(a, "m k l -> (m l) k")[a_idx.cpu()], "(m l) k -> m k l", m=m)
2197
+ a_torch = rearrange(a_torch, "m k l -> (m l) k")
2198
+ mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2199
+ a_idx_reshaped = rearrange(a_idx, "(m l) -> l m", m=m).contiguous().transpose(0, 1)
2200
+ mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
2201
+ else:
2202
+ mAIdx = None
2203
+ b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
2204
+ _, mD, d_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
2205
+ if c_dtype is not None:
2206
+ c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
2207
+ else:
2208
+ c, mC, c_torch = None, None, None
2209
+ if varlen_m:
2210
+ assert a_major == "k"
2211
+ assert d_major == "n"
2212
+ from einops import rearrange
2213
+
2214
+ a, d_torch = [rearrange(t, "m x l -> (l m) x") for t in (a, d_torch)]
2215
+ if not gather_A:
2216
+ (a_torch,) = [rearrange(t, "m x l -> (l m) x") for t in (a_torch,)]
2217
+ if c_dtype is not None:
2218
+ c, c_torch = [rearrange(t, "m x l -> (l m) x") for t in (c, c_torch)]
2219
+ mC = from_dlpack(c_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2220
+ mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2221
+ mD = from_dlpack(d_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2222
+ # TODO: generate random cu_seqlens_m
2223
+ cu_seqlens_m = torch.arange(0, l + 1, dtype=torch.int32, device="cuda") * m
2224
+ mCuSeqlensM = from_dlpack(cu_seqlens_m, assumed_align=64).mark_layout_dynamic(leading_dim=0)
2225
+ if gather_A:
2226
+ a_idx_reshaped = rearrange(a_idx_reshaped, "m l -> (l m)")
2227
+ mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
2228
+ else:
2229
+ cu_seqlens_m, mCuSeqlensM = None, None
2230
+
2231
+ if varlen_m: # Need to allocate space in gmem to store tensormaps
2232
+ if not persistent:
2233
+ total_m = m * l
2234
+ block_size_m = tile_shape_mnk[0] * cluster_shape_mnk[0]
2235
+ block_size_n = tile_shape_mnk[1] * cluster_shape_mnk[1]
2236
+ total_clusters_m_max = (total_m + l * (block_size_m - 1)) // block_size_m
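+ # Upper bound on clusters along M: each of the l variable-length segments can
+ # waste up to block_size_m - 1 rows of padding, hence the l * (block_size_m - 1)
+ # term. E.g. (assumed sizes) total_m = 1000, l = 4, block_size_m = 128 gives
+ # (1000 + 4 * 127) // 128 = 11 clusters at most along M.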
2237
+ total_clusters_max = total_clusters_m_max * ((n + block_size_n - 1) // block_size_n)
2238
+ total_ctas = total_clusters_max * cluster_shape_mnk[0] * cluster_shape_mnk[1]
2239
+ else:
2240
+ total_ctas = cutlass.utils.HardwareInfo().get_device_multiprocessor_count()
2241
+ if pingpong:
2242
+ total_ctas *= 2
2243
+ # 128 bytes per tensormap
2244
+ tensormaps_torch = torch.empty(total_ctas, 128 // 8, dtype=torch.int64, device="cuda")
2245
+ tensormaps_tensor = from_dlpack(
2246
+ tensormaps_torch, assumed_align=128
2247
+ ).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
2248
+ else:
2249
+ tensormaps_tensor = None
2250
+
2251
+ gemm = HopperWgmmaGemmKernel(
2252
+ acc_dtype,
2253
+ a_dtype,
2254
+ tile_shape_mnk,
2255
+ cluster_shape_mnk,
2256
+ pingpong=pingpong,
2257
+ is_persistent=persistent,
2258
+ fp8_fast_accum=fp8_fast_accum,
2259
+ gather_A=gather_A,
2260
+ )
2261
+
2262
+ # Compute max active clusters on current device
2263
+ if persistent:
2264
+ max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
2265
+ cluster_shape_mn[0] * cluster_shape_mn[1]
2266
+ )
2267
+ if dynamic_persistent:
2268
+ tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda")
2269
+ else:
2270
+ tile_count_semaphore = None
2271
+ # max_active_clusters = 1
2272
+ else:
2273
+ max_active_clusters = 0
2274
+ tile_count_semaphore = None
2275
+
2276
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
2277
+ # compile gemm kernel
2278
+ compiled_gemm = cute.compile(
2279
+ gemm,
2280
+ mA,
2281
+ mB,
2282
+ mD,
2283
+ mC,
2284
+ mAIdx,
2285
+ mCuSeqlensM,
2286
+ tensormaps_tensor,
2287
+ make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
2288
+ if tile_count_semaphore is not None
2289
+ else None,
2290
+ max_active_clusters,
2291
+ current_stream,
2292
+ )
2293
+
2294
+ if not skip_ref_check:
2295
+ # execution
2296
+ compiled_gemm(
2297
+ mA,
2298
+ mB,
2299
+ mD,
2300
+ mC,
2301
+ mAIdx,
2302
+ mCuSeqlensM,
2303
+ tensormaps_tensor,
2304
+ tile_count_semaphore,
2305
+ max_active_clusters,
2306
+ current_stream,
2307
+ )
2308
+ if tile_count_semaphore is not None and varlen_m:
2309
+ tile_count_semaphore.zero_()
2310
+
2311
+ torch.cuda.synchronize()
2312
+
2313
+ # Ref check
2314
+ if not varlen_m:
2315
+ ref = torch.einsum("mkl,nkl->mnl", a, b)
2316
+ else:
2317
+ ref = torch.cat(
2318
+ [
2319
+ torch.einsum("mk,nk->mn", a[cu_seqlens_m[i] : cu_seqlens_m[i + 1]], b[:, :, i])
2320
+ for i in range(l)
2321
+ ],
2322
+ dim=0,
2323
+ )
2324
+ if c is not None:
2325
+ ref = ref + c
2326
+ ref = ref.cpu()
2327
+
2328
+ if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
2329
+ # m major: (l, n, m) -> (m, n, l)
2330
+ # n major: (l, m, n) -> (m, n, l)
2331
+ permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
2332
+ shape = (l, m, n) if d_major == "n" else (l, n, m)
2333
+ f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
2334
+ shape,
2335
+ torch.uint8,
2336
+ permute_order=permute_order,
2337
+ init_type=cutlass_torch.TensorInitType.SKIP,
2338
+ ).cuda()
2339
+ # Create dtype cute tensor (gpu)
2340
+ ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
2341
+ leading_dim=(1 if d_major == "n" else 0)
2342
+ )
2343
+ ref_d_tensor.element_type = d_dtype
2344
+ ref_d_tensor = cutlass_torch.convert_cute_tensor(
2345
+ ref,
2346
+ ref_d_tensor,
2347
+ d_dtype,
2348
+ is_dynamic_layout=True,
2349
+ )
2350
+ ref_d = f8_torch_tensor.cpu()
2351
+ else:
2352
+ ref_d = ref.to(cutlass_torch.dtype(d_dtype))
2353
+
2354
+ out = d_torch.cpu().squeeze()
2355
+ out_ref = ref_d.squeeze()
2356
+ # breakpoint()
2357
+ torch.testing.assert_close(d_torch.cpu(), ref_d, atol=tolerance, rtol=1e-03)
2358
+
2359
+ # return
2360
+
2361
+ from triton.testing import do_bench
2362
+
2363
+ flops = 2 * m * n * k * l
2364
+ # Calculate memory bandwidth
2365
+ bytes_A = m * k * l * (a_dtype.width // 8) # A tensor: (m, k, l)
2366
+ bytes_B = n * k * l * (b_dtype.width // 8) # B tensor: (n, k, l)
2367
+ bytes_D = m * n * l * (d_dtype.width // 8) # D tensor: (m, n, l)
2368
+ bytes_C = m * n * l * (c_dtype.width // 8) if c_dtype is not None else 0 # C tensor: (m, n, l)
2369
+ total_bytes = bytes_A + bytes_B + bytes_D + bytes_C # Read A, B, C; Write D
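+ # Sanity check with assumed sizes: for m = n = k = 8192, l = 1 and 16-bit A/B/D
+ # (no C), flops = 2 * 8192**3 ≈ 1.1e12 and total_bytes ≈ 0.4 GB, so a 2 ms run
+ # corresponds to ≈ 550 TFLOPS and ≈ 200 GB/s.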
2370
+
2371
+ repeats = iterations
2372
+ warmup = warmup_iterations
2373
+
2374
+ import time
2375
+
2376
+ if not varlen_m and not gather_A:
2377
+ time.sleep(0.5)
2378
+ if a_dtype.width == 8:
2379
+ assert l == 1
2380
+ scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda")
2381
+ fn_cublas = lambda: torch._scaled_mm(
2382
+ a_torch[:, :, 0],
2383
+ b_torch[:, :, 0].mT,
2384
+ scale_a=scale_ab,
2385
+ scale_b=scale_ab,
2386
+ out_dtype=torch.bfloat16,
2387
+ use_fast_accum=fp8_fast_accum,
2388
+ )
2389
+ else:
2390
+ if c_torch is None:
2391
+ fn_cublas = lambda: torch.matmul(
2392
+ a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT
2393
+ )
2394
+ else:
2395
+ c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32
2396
+ fn_cublas = lambda: torch.baddbmm(
2397
+ c_torch_convert.permute(2, 0, 1),
2398
+ a_torch.permute(2, 0, 1),
2399
+ b_torch.permute(2, 0, 1).mT,
2400
+ )
2401
+ timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
2402
+ tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
2403
+ print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2404
+
2405
+ time.sleep(0.5)
2406
+
2407
+ def fn():
2408
+ compiled_gemm(
2409
+ mA,
2410
+ mB,
2411
+ mD,
2412
+ mC,
2413
+ mAIdx,
2414
+ mCuSeqlensM,
2415
+ tensormaps_tensor,
2416
+ tile_count_semaphore,
2417
+ max_active_clusters,
2418
+ current_stream,
2419
+ )
2420
+ if tile_count_semaphore is not None and varlen_m:
2421
+ tile_count_semaphore.zero_()
2422
+
2423
+ timing = do_bench(fn, warmup=warmup, rep=repeats)
2424
+ # For some unknown reason the 1st run is sometimes much slower, so benchmark again
2425
+ time.sleep(0.5)
2426
+ timing = do_bench(fn, warmup=warmup, rep=repeats)
2427
+ tflops = flops / (timing * 1e9) # Convert to TFlops
2428
+ gbps = total_bytes / (timing * 1e6) # Convert to GB/s (1e-3 for ms->s, 1e9 for B->GB, net factor 1e6)
2429
+ print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}, GB/s: {gbps:.0f}")
2430
+ fn()
2431
+
2432
+ if not varlen_m:
2433
+ time.sleep(0.5)
2434
+ timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
2435
+ tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
2436
+ print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2437
+
2438
+ from flash_attn.utils.benchmark import pytorch_profiler
2439
+
2440
+ pytorch_profiler(fn_cublas)
2441
+ # pytorch_profiler(torch.sort, d_torch.squeeze(), dim=-1)
2442
+ # pytorch_profiler(torch.compile(torch.sort), d_torch.squeeze(), dim=-1)
2443
+ # pytorch_profiler(torch.topk, d_torch.squeeze(), dim=-1, k=1)
2444
+ # pytorch_profiler(torch.compile(torch.topk), d_torch.squeeze(), dim=-1, k=1)
2445
+ # pytorch_profiler(torch.square, d_torch.squeeze())
2446
+
2447
+
2448
+ if __name__ == "__main__":
2449
+ args = parse_arguments()
2450
+ run(
2451
+ args.mnkl,
2452
+ args.a_dtype,
2453
+ args.b_dtype,
2454
+ args.d_dtype,
2455
+ args.c_dtype,
2456
+ args.acc_dtype,
2457
+ args.a_major,
2458
+ args.b_major,
2459
+ args.d_major,
2460
+ args.c_major,
2461
+ args.tile_shape_mnk,
2462
+ args.cluster_shape_mn,
2463
+ args.tolerance,
2464
+ args.warmup_iterations,
2465
+ args.iterations,
2466
+ args.skip_ref_check,
2467
+ args.persistent,
2468
+ args.dynamic_persistent,
2469
+ args.pingpong,
2470
+ args.varlen_m,
2471
+ args.gather_A,
2472
+ args.fp8_fast_accum,
2473
+ )
2474
+ print("PASS")