quack-kernels 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2562 @@
1
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+
7
+ # 1. Redistributions of source code must retain the above copyright notice, this
8
+ # list of conditions and the following disclaimer.
9
+
10
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
11
+ # this list of conditions and the following disclaimer in the documentation
12
+ # and/or other materials provided with the distribution.
13
+
14
+ # 3. Neither the name of the copyright holder nor the names of its
15
+ # contributors may be used to endorse or promote products derived from
16
+ # this software without specific prior written permission.
17
+
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+ import argparse
30
+ from typing import Optional, Type, Tuple, Union, Callable
31
+ from functools import partial
32
+
33
+ import cuda.bindings.driver as cuda
34
+ import torch
35
+
36
+ import cutlass
37
+ import cutlass.cute as cute
38
+ from cutlass.cute.nvgpu import cpasync, tcgen05
39
+ import cutlass.torch as cutlass_torch
40
+ import cutlass.pipeline as pipeline
41
+ import cutlass.utils.blackwell_helpers as sm100_utils
42
+ import cutlass.utils.blockscaled_layout as blockscaled_utils
43
+ from cutlass.cute.runtime import from_dlpack, make_ptr
44
+ from cutlass import Int32, const_expr
45
+
46
+ from quack.tile_scheduler import (
47
+ TileSchedulerArguments,
48
+ TileScheduler,
49
+ ParamsBase,
50
+ RasterOrderOption,
51
+ )
52
+
53
+ """
54
+ A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
55
+ using CUTE DSL.
56
+ - Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
57
+ - Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
58
+ - Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
59
+
60
+ This GEMM kernel supports the following features:
61
+ - Utilizes Tensor Memory Access (TMA) for efficient memory operations
62
+ - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions)
63
+ - Implements TMA multicast with cluster to reduce L2 memory traffic
64
+ - Supports persistent tile scheduling to better overlap memory load/store with MMA between tiles
65
+ - Supports warp specialization to avoid explicit pipelining between mainloop load and MMA
66
+
67
+ This GEMM works as follows:
68
+ 1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
69
+ 2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
70
+ 3. EPILOGUE warp:
71
+ - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
72
+ - Type convert C matrix to output type.
73
+ - Optionally store C matrix from registers (RMEM) through shared memory (SMEM) to global memory (GMEM) with TMA operations,
74
+ or store C matrix directly from registers (RMEM) to global memory (GMEM) without TMA operations.
75
+ - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
76
+ e.g., for ReLU, set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
77
+
78
+ SM100 tcgen05.mma instructions operate as follows:
79
+ - Read matrix A from SMEM
80
+ - Read matrix B from SMEM
81
+ - Write accumulator to TMEM
82
+ The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
83
+
84
+ Input arguments to this example are the same as for dense_gemm.py.
85
+
86
+ .. code-block:: bash
87
+
88
+ python examples/blackwell/dense_gemm_persistent.py \
89
+ --ab_dtype Float16 --d_dtype Float16 --acc_dtype Float32 \
90
+ --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
91
+ --mnkl 8192,8192,8192,1 \
92
+ --use_2cta_instrs
93
+
94
+ To collect performance with NCU profiler:
95
+
96
+ .. code-block:: bash
97
+
98
+ ncu python examples/blackwell/dense_gemm_persistent.py \
99
+ --ab_dtype Float16 --d_dtype Float16 --acc_dtype Float32 \
100
+ --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
101
+ --mnkl 8192,8192,8192,1 \
102
+ --use_2cta_instrs \
103
+ --warmup_iterations 1 --iterations 10 --skip_ref_check
104
+
105
+
106
+ Constraints are the same as for dense_gemm.py:
107
+ * Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2),
108
+ see the detailed valid dtype combinations in the PersistentDenseGemmKernel class documentation below
109
+ * A/B tensors must have the same data type
110
+ * Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
111
+ * Mma tiler N must be 32-256, step 32
112
+ * Cluster shape M/N must be positive and power of 2, total cluster size <= 16
113
+ * Cluster shape M must be multiple of 2 if use_2cta_instrs=True
114
+ * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
115
+ i.e., the number of elements is a multiple of 4, 8, and 16 for TFloat32,
116
+ Float16/BFloat16, and Int8/Uint8/Float8, respectively.
117
+ * OOB tiles are not allowed when TMA store is disabled
118
+ """
119
+
120
+
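+ # Illustrative host-side usage sketch (an addition for documentation purposes, not part
+ # of the released kernel API): it shows one plausible way to wrap torch tensors as CuTe
+ # tensors and invoke PersistentDenseGemmKernel, mirroring the CLI example above.
+ # `max_active_clusters` and `stream` are assumed to be obtained the same way as in
+ # dense_gemm.py (hardware query and CUDA stream setup); dynamic layout marking and
+ # alignment handling are simplified here.
+ def _example_fp16_gemm(max_active_clusters, stream):
+     m, n, k, l = 4096, 4096, 4096, 1
+     # Allocate (L, ...) contiguously, then permute so that L is the batch mode and
+     # K (for A/B) resp. N (for D) is the contiguous mode, matching the layouts
+     # described in the module docstring.
+     a = torch.randn(l, m, k, dtype=torch.float16, device="cuda").permute(1, 2, 0)
+     b = torch.randn(l, n, k, dtype=torch.float16, device="cuda").permute(1, 2, 0)
+     d = torch.empty(l, m, n, dtype=torch.float16, device="cuda").permute(1, 2, 0)
+     mA, mB, mD = (from_dlpack(t, assumed_align=16) for t in (a, b, d))
+     gemm = PersistentDenseGemmKernel(
+         acc_dtype=cutlass.Float32,
+         use_2cta_instrs=True,
+         mma_tiler_mn=(256, 128),
+         cluster_shape_mn=(2, 1),
+     )
+     # No residual C tensor and no dynamic-scheduler semaphore in this sketch.
+     gemm(mA, mB, mD, None, None, max_active_clusters, stream)
+     return d
+
+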
121
+ class PersistentDenseGemmKernel:
122
+ """This class implements batched matrix multiplication (C = A x B) with support for various data types
123
+ and architectural features specific to Blackwell GPUs, using persistent tile scheduling and warp specialization.
124
+
125
+ :param acc_dtype: Data type for accumulation during computation
126
+ :type acc_dtype: type[cutlass.Numeric]
127
+ :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
128
+ :type use_2cta_instrs: bool
129
+ :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
130
+ :type mma_tiler_mn: Tuple[int, int]
131
+ :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
132
+ :type cluster_shape_mn: Tuple[int, int]
133
+
134
+ :note: In the current version, the A and B tensors must have the same data type
135
+ - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
136
+
137
+ :note: Supported A/B data types:
138
+ - TFloat32
139
+ - Float16/BFloat16
140
+ - Int8/Uint8
141
+ - Float8E4M3FN/Float8E5M2
142
+
143
+ :note: Supported accumulator data types:
144
+ - Float32 (for all floating point A/B data types)
145
+ - Float16 (only for fp16 and fp8 A/B data types)
146
+ - Int32 (only for uint8/int8 A/B data types)
147
+
148
+ :note: Supported C data types:
149
+ - Float32 (for float32 and int32 accumulator data types)
150
+ - Int32 (for float32 and int32 accumulator data types)
151
+ - Float16/BFloat16 (for fp16 and fp8 accumulator data types)
152
+ - Int8/Uint8 (for uint8/int8 accumulator data types)
153
+ - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
154
+
155
+ :note: Constraints:
156
+ - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
157
+ - MMA tiler N must be 32-256, step 32
158
+ - Cluster shape M must be multiple of 2 if use_2cta_instrs=True
159
+ - Cluster shape M/N must be positive and power of 2, total cluster size <= 16
160
+
161
+ Example:
162
+ >>> gemm = PersistentDenseGemmKernel(
163
+ ... acc_dtype=cutlass.Float32,
164
+ ... use_2cta_instrs=True,
165
+ ... mma_tiler_mn=(128, 128),
166
+ ... cluster_shape_mn=(2, 2)
167
+ ... )
168
+ >>> gemm(mA, mB, mD, None, None, max_active_clusters, stream)  # no C tensor, no scheduler semaphore
169
+ """
170
+
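+     # Validation sketch for the constraints listed in the docstring above. This is an
+     # illustrative helper added for documentation purposes (not part of the released
+     # kernel); the authoritative handling follows dense_gemm.py and _setup_attributes.
+     @staticmethod
+     def _check_static_constraints(mma_tiler_mn, cluster_shape_mn, use_2cta_instrs):
+         tile_m, tile_n = mma_tiler_mn
+         valid_tile_m = (128, 256) if use_2cta_instrs else (64, 128)
+         assert tile_m in valid_tile_m, f"MMA tiler M must be one of {valid_tile_m}"
+         assert 32 <= tile_n <= 256 and tile_n % 32 == 0, "MMA tiler N must be 32-256, step 32"
+         cluster_m, cluster_n = cluster_shape_mn
+         is_pow2 = lambda x: x > 0 and (x & (x - 1)) == 0
+         assert is_pow2(cluster_m) and is_pow2(cluster_n), "Cluster shape M/N must be powers of 2"
+         assert cluster_m * cluster_n <= 16, "Total cluster size must be <= 16"
+         if use_2cta_instrs:
+             assert cluster_m % 2 == 0, "Cluster shape M must be a multiple of 2 when use_2cta_instrs=True"
+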
171
+ def __init__(
172
+ self,
173
+ acc_dtype: Type[cutlass.Numeric],
174
+ use_2cta_instrs: bool,
175
+ mma_tiler_mn: Tuple[int, int],
176
+ cluster_shape_mn: Tuple[int, int],
177
+ sf_vec_size: Optional[int] = None,
178
+ ):
179
+ """Initializes the configuration for a Blackwell dense GEMM kernel.
180
+
181
+ This configuration includes several key aspects:
182
+
183
+ 1. MMA Instruction Settings (tcgen05):
184
+ - acc_dtype: Data types for MMA accumulator.
185
+ - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
186
+ - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant
187
+ with cta_group=2 should be used.
188
+
189
+ 2. Cluster Shape:
190
+ - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
191
+
192
+ :param acc_dtype: Data type of the accumulator.
193
+ :type acc_dtype: type[cutlass.Numeric]
194
+ :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
195
+ :type mma_tiler_mn: Tuple[int, int]
196
+ :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant.
197
+ :type use_2cta_instrs: bool
198
+ :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
199
+ :type cluster_shape_mn: Tuple[int, int]
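+ :param sf_vec_size: Optional scale-factor vector size; when set, the kernel runs in blockscaled mode.
+ :type sf_vec_size: Optional[int]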
200
+ """
201
+
202
+ self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
203
+ self.use_2cta_instrs = use_2cta_instrs
204
+ self.cluster_shape_mn = cluster_shape_mn
205
+ # K dimension is deferred in _setup_attributes
206
+ self.mma_tiler = (*mma_tiler_mn, 1)
207
+ self.sf_vec_size = sf_vec_size
208
+ self.blockscaled = sf_vec_size is not None
209
+
210
+ self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
211
+
212
+ self.occupancy = 1
213
+ # Set specialized warp ids
214
+ self.epilog_warp_id = (
215
+ 0,
216
+ 1,
217
+ 2,
218
+ 3,
219
+ )
220
+ self.mma_warp_id = 4
221
+ self.tma_warp_id = 5
222
+ self.tma_epi_warp_id = 6
223
+ self.threads_per_cta = 32 * len(
224
+ (self.mma_warp_id, self.tma_warp_id, self.tma_epi_warp_id, *self.epilog_warp_id)
225
+ )
226
+ # Set barrier id for cta sync, epilogue sync and tmem ptr sync
227
+ self.cta_sync_bar_id = 0
228
+ self.epilog_sync_bar_id = 1
229
+ self.tmem_ptr_sync_bar_id = 2
230
+ self.epilog_load_bar_id = 3
231
+ self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_100")
232
+
233
+ def _setup_attributes(self):
234
+ """Set up configurations that are dependent on GEMM inputs
235
+
236
+ This method configures various attributes based on the input tensor properties
237
+ (data types, leading dimensions) and kernel settings:
238
+ - Configuring tiled MMA
239
+ - Computing MMA/cluster/tile shapes
240
+ - Computing cluster layout
241
+ - Computing multicast CTAs for A/B
242
+ - Computing epilogue subtile
243
+ - Setting up A/B/D/C stage counts in shared memory and accumulator stages in tensor memory
244
+ - Computing A/B/SFA/SFB/D/C shared memory layouts
245
+ - Computing tensor memory allocation columns
246
+ """
247
+ # Compute mma instruction shapes
248
+ mma_inst_bits_k = 256
249
+ # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K)
250
+ self.mma_inst_shape_mnk = (
251
+ self.mma_tiler[0],
252
+ self.mma_tiler[1],
253
+ mma_inst_bits_k // self.a_dtype.width,
254
+ )
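+ # For example, 16-bit A/B (Float16/BFloat16) gives an instruction K of 256 // 16 = 16
+ # elements; the full K tile below multiplies this by mma_inst_tile_k (4), i.e. 64 elements.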
255
+ # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K)
256
+ self.mma_inst_shape_mnk_sfb = (
257
+ self.mma_inst_shape_mnk[0] // (2 if self.use_2cta_instrs else 1),
258
+ cute.round_up(self.mma_inst_shape_mnk[1], 128),
259
+ self.mma_inst_shape_mnk[2],
260
+ )
261
+
262
+ # Configure tiled mma
263
+ if const_expr(not self.blockscaled):
264
+ tiled_mma = sm100_utils.make_trivial_tiled_mma(
265
+ self.a_dtype,
266
+ self.a_major_mode,
267
+ self.b_major_mode,
268
+ self.acc_dtype,
269
+ self.cta_group,
270
+ self.mma_tiler[:2],
271
+ )
272
+ tiled_mma_sfb = None
273
+ else:
274
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
275
+ self.a_dtype,
276
+ self.a_major_mode,
277
+ self.b_major_mode,
278
+ self.sf_dtype,
279
+ self.sf_vec_size,
280
+ self.cta_group,
281
+ self.mma_inst_shape_mnk[:2],
282
+ )
283
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
284
+ self.a_dtype,
285
+ self.a_major_mode,
286
+ self.b_major_mode,
287
+ self.sf_dtype,
288
+ self.sf_vec_size,
289
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
290
+ self.mma_inst_shape_mnk_sfb[:2],
291
+ )
292
+
293
+ # Compute mma/cluster/tile shapes
294
+ mma_inst_tile_k = 4
295
+ self.mma_tiler = (
296
+ self.mma_inst_shape_mnk[0],
297
+ self.mma_inst_shape_mnk[1],
298
+ self.mma_inst_shape_mnk[2] * mma_inst_tile_k,
299
+ )
300
+ if const_expr(self.blockscaled):
301
+ self.mma_tiler_sfb = (
302
+ self.mma_inst_shape_mnk_sfb[0],
303
+ self.mma_inst_shape_mnk_sfb[1],
304
+ self.mma_inst_shape_mnk_sfb[2] * mma_inst_tile_k,
305
+ )
306
+ else:
307
+ self.mma_tiler_sfb = None
308
+ self.cta_tile_shape_mnk = (
309
+ self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
310
+ self.mma_tiler[1],
311
+ self.mma_tiler[2],
312
+ )
313
+
314
+ # Compute cluster layout
315
+ self.cluster_layout_vmnk = cute.tiled_divide(
316
+ cute.make_layout((*self.cluster_shape_mn, 1)),
317
+ (tiled_mma.thr_id.shape,),
318
+ )
319
+ if const_expr(self.blockscaled):
320
+ self.cluster_layout_sfb_vmnk = cute.tiled_divide(
321
+ cute.make_layout((*self.cluster_shape_mn, 1)),
322
+ (tiled_mma_sfb.thr_id.shape,),
323
+ )
324
+ else:
325
+ self.cluster_layout_sfb_vmnk = None
326
+
327
+ # Compute number of multicast CTAs for A/B
328
+ self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
329
+ self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
330
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
331
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
332
+ if const_expr(self.blockscaled):
333
+ self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
334
+ self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
335
+
336
+ # Compute epilogue subtile
337
+ self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
338
+ self.cta_tile_shape_mnk,
339
+ self.use_2cta_instrs,
340
+ self.d_layout,
341
+ self.d_dtype,
342
+ )
343
+
344
+ # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
345
+ (
346
+ self.num_acc_stage,
347
+ self.num_ab_stage,
348
+ self.num_d_stage,
349
+ self.num_c_stage,
350
+ ) = self._compute_stages(
351
+ tiled_mma,
352
+ self.mma_tiler,
353
+ self.a_dtype,
354
+ self.b_dtype,
355
+ self.epi_tile,
356
+ self.d_dtype,
357
+ self.c_dtype,
358
+ self.d_layout,
359
+ self.c_layout,
360
+ self.sf_dtype,
361
+ self.sf_vec_size,
362
+ self.smem_capacity,
363
+ self.occupancy,
364
+ )
365
+
366
+ # Compute A/B/SFA/SFB/C shared memory layout
367
+ self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
368
+ tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage
369
+ )
370
+ self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
371
+ tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage
372
+ )
373
+ self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
374
+ self.d_dtype, self.d_layout, self.epi_tile, self.num_d_stage
375
+ )
376
+ if const_expr(self.c_dtype is not None):
377
+ self.epi_c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
378
+ self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage
379
+ )
380
+ else:
381
+ self.epi_c_smem_layout_staged = None
382
+ if const_expr(self.blockscaled):
383
+ self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
384
+ tiled_mma,
385
+ self.mma_tiler,
386
+ self.sf_vec_size,
387
+ self.num_ab_stage,
388
+ )
389
+ self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
390
+ tiled_mma,
391
+ self.mma_tiler,
392
+ self.sf_vec_size,
393
+ self.num_ab_stage,
394
+ )
395
+ else:
396
+ self.sfa_smem_layout_staged, self.sfb_smem_layout_staged = None, None
397
+
398
+ # Compute the number of tensor memory allocation columns
399
+ if const_expr(not self.blockscaled):
400
+ self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(
401
+ tiled_mma, self.mma_tiler, self.num_acc_stage
402
+ )
403
+ else:
404
+ SM100_TMEM_CAPACITY_COLUMNS = 512
405
+ self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
406
+
407
+ @cute.jit
408
+ def __call__(
409
+ self,
410
+ mA: cute.Tensor,
411
+ mB: cute.Tensor,
412
+ mD: cute.Tensor,
413
+ mC: Optional[cute.Tensor],
414
+ tile_count_semaphore: Optional[cute.Pointer],
415
+ max_active_clusters: cutlass.Constexpr,
416
+ stream: cuda.CUstream,
417
+ mSFA: Optional[cute.Tensor] = None,
418
+ mSFB: Optional[cute.Tensor] = None,
419
+ epilogue_op: cutlass.Constexpr = lambda x: x,
420
+ ):
421
+ """Execute the GEMM operation in steps:
422
+ - Setup static attributes before smem/grid/tma computation
423
+ - Setup TMA load/store atoms and tensors
424
+ - Compute grid size with regard to hardware constraints
425
+ - Define shared storage for kernel
426
+ - Launch the kernel synchronously
427
+
428
+ :param mA: Input tensor A
429
+ :type mA: cute.Tensor
430
+ :param mB: Input tensor B
431
+ :type mB: cute.Tensor
432
+ :param mD: Output tensor D
433
+ :type mD: cute.Tensor
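+ :param mC: Optional source tensor C; when provided it is loaded via TMA and added to the accumulator in the epilogue
+ :type mC: Optional[cute.Tensor]
+ :param tile_count_semaphore: Optional semaphore pointer forwarded to the persistent tile scheduler
+ :type tile_count_semaphore: Optional[cute.Pointer]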
434
+ :param max_active_clusters: Maximum number of active clusters
435
+ :type max_active_clusters: cutlass.Constexpr
436
+ :param stream: CUDA stream for asynchronous execution
437
+ :type stream: cuda.CUstream
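+ :param mSFA: Optional scale factor tensor for A (blockscaled mode only)
+ :type mSFA: Optional[cute.Tensor]
+ :param mSFB: Optional scale factor tensor for B (blockscaled mode only)
+ :type mSFB: Optional[cute.Tensor]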
438
+ :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
439
+ :type epilogue_op: cutlass.Constexpr
440
+ :raises TypeError: If input data types are incompatible with the MMA instruction.
441
+ :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled.
442
+ """
443
+ if const_expr(self.blockscaled):
444
+ assert mSFA is not None and mSFB is not None
445
+ # Setup static attributes before smem/grid/tma computation
446
+ self.a_dtype: Type[cutlass.Numeric] = mA.element_type
447
+ self.b_dtype: Type[cutlass.Numeric] = mB.element_type
448
+ self.d_dtype: Type[cutlass.Numeric] = mD.element_type
449
+ self.c_dtype = mC.element_type if mC is not None else None
450
+ self.sf_dtype: Optional[Type[cutlass.Numeric]] = (
451
+ mSFA.element_type if mSFA is not None else None
452
+ )
453
+ self.a_major_mode = cutlass.utils.LayoutEnum.from_tensor(mA).mma_major_mode()
454
+ self.b_major_mode = cutlass.utils.LayoutEnum.from_tensor(mB).mma_major_mode()
455
+ self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
456
+ self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None
457
+
458
+ # Check if input data types are compatible with MMA instruction
459
+ if const_expr(self.a_dtype != self.b_dtype):
460
+ raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
461
+
462
+ # Setup attributes that depend on GEMM inputs
463
+ self._setup_attributes()
464
+
465
+ if const_expr(self.blockscaled):
466
+ # Setup SFA/SFB tensors by tiling the scale-factor atom layout over the A/B shapes
467
+ # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL)
468
+ sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(mA.shape, self.sf_vec_size)
469
+ mSFA = cute.make_tensor(mSFA.iterator, sfa_layout)
470
+ # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
471
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(mB.shape, self.sf_vec_size)
472
+ mSFB = cute.make_tensor(mSFB.iterator, sfb_layout)
473
+
474
+ if const_expr(not self.blockscaled):
475
+ tiled_mma = sm100_utils.make_trivial_tiled_mma(
476
+ self.a_dtype,
477
+ self.a_major_mode,
478
+ self.b_major_mode,
479
+ self.acc_dtype,
480
+ self.cta_group,
481
+ self.mma_tiler[:2],
482
+ )
483
+ tiled_mma_sfb = None
484
+ else:
485
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
486
+ self.a_dtype,
487
+ self.a_major_mode,
488
+ self.b_major_mode,
489
+ self.sf_dtype,
490
+ self.sf_vec_size,
491
+ self.cta_group,
492
+ self.mma_inst_shape_mnk[:2],
493
+ )
494
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
495
+ self.a_dtype,
496
+ self.a_major_mode,
497
+ self.b_major_mode,
498
+ self.sf_dtype,
499
+ self.sf_vec_size,
500
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
501
+ self.mma_inst_shape_mnk_sfb[:2],
502
+ )
503
+ atom_thr_size = cute.size(tiled_mma.thr_id.shape)
504
+
505
+ # Setup TMA load for A
506
+ a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
507
+ a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
508
+ tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
509
+ a_op,
510
+ mA,
511
+ a_smem_layout,
512
+ self.mma_tiler,
513
+ tiled_mma,
514
+ self.cluster_layout_vmnk.shape,
515
+ internal_type=(cutlass.TFloat32 if mA.element_type is cutlass.Float32 else None),
516
+ )
517
+
518
+ # Setup TMA load for B
519
+ b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
520
+ b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
521
+ tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
522
+ b_op,
523
+ mB,
524
+ b_smem_layout,
525
+ self.mma_tiler,
526
+ tiled_mma,
527
+ self.cluster_layout_vmnk.shape,
528
+ internal_type=(cutlass.TFloat32 if mB.element_type is cutlass.Float32 else None),
529
+ )
530
+
531
+ if const_expr(self.blockscaled):
532
+ # Setup TMA load for SFA
533
+ sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
534
+ self.cluster_shape_mn, tiled_mma.thr_id
535
+ )
536
+ sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
537
+ tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
538
+ sfa_op,
539
+ mSFA,
540
+ sfa_smem_layout,
541
+ self.mma_tiler,
542
+ tiled_mma,
543
+ self.cluster_layout_vmnk.shape,
544
+ internal_type=cutlass.Int16,
545
+ )
546
+ # Setup TMA load for SFB
547
+ sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
548
+ self.cluster_shape_mn, tiled_mma.thr_id
549
+ )
550
+ sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
551
+ tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
552
+ sfb_op,
553
+ mSFB,
554
+ sfb_smem_layout,
555
+ self.mma_tiler_sfb,
556
+ tiled_mma_sfb,
557
+ self.cluster_layout_sfb_vmnk.shape,
558
+ internal_type=cutlass.Int16,
559
+ )
560
+ else:
561
+ tma_atom_sfa, tma_tensor_sfa = None, None
562
+ tma_atom_sfb, tma_tensor_sfb = None, None
563
+
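+ # Bytes delivered per producer stage by the TMA loads below; used as the tx_count of
+ # the A/B pipeline's transaction barrier (scale-factor bytes are added when blockscaled).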
564
+ a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
565
+ b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
566
+ self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
567
+ if const_expr(self.blockscaled):
568
+ sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
569
+ sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
570
+ self.num_tma_load_bytes += (sfa_copy_size + sfb_copy_size) * atom_thr_size
571
+
572
+ # Setup TMA store for D
573
+ epi_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
574
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
575
+ cpasync.CopyBulkTensorTileS2GOp(),
576
+ mD,
577
+ epi_smem_layout,
578
+ self.epi_tile,
579
+ )
580
+ if const_expr(mC is not None):
581
+ epi_c_smem_layout = cute.slice_(self.epi_c_smem_layout_staged, (None, None, 0))
582
+ tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
583
+ cpasync.CopyBulkTensorTileG2SOp(),
584
+ mC,
585
+ epi_c_smem_layout,
586
+ self.epi_tile,
587
+ )
588
+ else:
589
+ tma_atom_c, tma_tensor_c = None, None
590
+
591
+ problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.cta_tile_shape_mnk[:2]) + (
592
+ mD.shape[2],
593
+ )
594
+ TileSchedulerCls = TileScheduler
595
+ tile_sched_args = TileSchedulerArguments(
596
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
597
+ raster_order=RasterOrderOption.Heuristic,
598
+ group_size=8,
599
+ cluster_shape_mnk=(*self.cluster_shape_mn, 1),
600
+ tile_count_semaphore=tile_count_semaphore,
601
+ is_persistent=True,
602
+ )
603
+ tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
604
+ grid = TileSchedulerCls.get_grid_shape(tile_sched_params, max_active_clusters)
605
+
606
+ self.buffer_align_bytes = 1024
607
+
608
+ epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
609
+ sf_dtype = self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU
610
+ sfa_smem_size = (
611
+ cute.cosize(self.sfa_smem_layout_staged) if const_expr(self.blockscaled) else 0
612
+ )
613
+ sfb_smem_size = (
614
+ cute.cosize(self.sfb_smem_layout_staged) if const_expr(self.blockscaled) else 0
615
+ )
616
+
617
+ # Define shared storage for kernel
618
+ @cute.struct
619
+ class SharedStorage:
620
+ ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
621
+ ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
622
+ epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage * 2]
623
+ acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
624
+ acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
625
+ tmem_dealloc_mbar_ptr: cutlass.Int64
626
+ tmem_holding_buf: Int32
627
+ sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, 2]
628
+ tile_count: cute.struct.MemRange[cutlass.Int32, 1]
629
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
630
+ sD: cute.struct.Align[
631
+ cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)],
632
+ self.buffer_align_bytes,
633
+ ]
634
+ sC: cute.struct.Align[
635
+ cute.struct.MemRange[
636
+ self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
637
+ ],
638
+ self.buffer_align_bytes,
639
+ ]
640
+ # (MMA, MMA_M, MMA_K, STAGE)
641
+ sA: cute.struct.Align[
642
+ cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
643
+ self.buffer_align_bytes,
644
+ ]
645
+ # (MMA, MMA_N, MMA_K, STAGE)
646
+ sB: cute.struct.Align[
647
+ cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
648
+ self.buffer_align_bytes,
649
+ ]
650
+ # (MMA, MMA_M, MMA_K, STAGE)
651
+ sSFA: cute.struct.Align[
652
+ cute.struct.MemRange[sf_dtype, sfa_smem_size],
653
+ self.buffer_align_bytes,
654
+ ]
655
+ # (MMA, MMA_N, MMA_K, STAGE)
656
+ sSFB: cute.struct.Align[
657
+ cute.struct.MemRange[sf_dtype, sfb_smem_size],
658
+ self.buffer_align_bytes,
659
+ ]
660
+
661
+ self.shared_storage = SharedStorage
662
+
663
+ # Launch the kernel synchronously
664
+ self.kernel(
665
+ tiled_mma,
666
+ tiled_mma_sfb,
667
+ tma_atom_a,
668
+ tma_tensor_a,
669
+ tma_atom_b,
670
+ tma_tensor_b,
671
+ tma_atom_sfa,
672
+ tma_tensor_sfa,
673
+ tma_atom_sfb,
674
+ tma_tensor_sfb,
675
+ tma_atom_d,
676
+ tma_tensor_d,
677
+ tma_atom_c,
678
+ tma_tensor_c,
679
+ self.cluster_layout_vmnk,
680
+ self.cluster_layout_sfb_vmnk,
681
+ self.a_smem_layout_staged,
682
+ self.b_smem_layout_staged,
683
+ self.sfa_smem_layout_staged,
684
+ self.sfb_smem_layout_staged,
685
+ self.d_smem_layout_staged,
686
+ self.epi_c_smem_layout_staged,
687
+ self.epi_tile,
688
+ tile_sched_params,
689
+ TileSchedulerCls,
690
+ epilogue_op,
691
+ ).launch(
692
+ grid=grid,
693
+ block=[self.threads_per_cta, 1, 1],
694
+ cluster=(*self.cluster_shape_mn, 1),
695
+ smem=self.shared_storage.size_in_bytes(),
696
+ stream=stream,
697
+ )
698
+ return
699
+
700
+ # GPU device kernel
701
+ @cute.kernel
702
+ def kernel(
703
+ self,
704
+ tiled_mma: cute.TiledMma,
705
+ tiled_mma_sfb: Optional[cute.TiledMma],
706
+ tma_atom_a: cute.CopyAtom,
707
+ mA_mkl: cute.Tensor,
708
+ tma_atom_b: cute.CopyAtom,
709
+ mB_nkl: cute.Tensor,
710
+ tma_atom_sfa: Optional[cute.CopyAtom],
711
+ mSFA_mkl: Optional[cute.Tensor],
712
+ tma_atom_sfb: Optional[cute.CopyAtom],
713
+ mSFB_nkl: Optional[cute.Tensor],
714
+ tma_atom_d: Optional[cute.CopyAtom],
715
+ mD_mnl: cute.Tensor,
716
+ tma_atom_c: Optional[cute.CopyAtom],
717
+ mC_mnl: Optional[cute.Tensor],
718
+ cluster_layout_vmnk: cute.Layout,
719
+ cluster_layout_sfb_vmnk: Optional[cute.Layout],
720
+ a_smem_layout_staged: cute.ComposedLayout,
721
+ b_smem_layout_staged: cute.ComposedLayout,
722
+ sfa_smem_layout_staged: Optional[cute.Layout],
723
+ sfb_smem_layout_staged: Optional[cute.Layout],
724
+ d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
725
+ epi_c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
726
+ epi_tile: cute.Tile,
727
+ tile_sched_params: ParamsBase,
728
+ TileSchedulerCls: cutlass.Constexpr[Callable],
729
+ epilogue_op: cutlass.Constexpr[Callable],
730
+ ):
731
+ """
732
+ GPU device kernel performing the Persistent batched GEMM computation.
733
+ """
734
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
735
+
736
+ #
737
+ # Prefetch tma desc
738
+ #
739
+ if warp_idx == self.tma_warp_id:
740
+ cpasync.prefetch_descriptor(tma_atom_a)
741
+ cpasync.prefetch_descriptor(tma_atom_b)
742
+ if const_expr(self.blockscaled):
743
+ cpasync.prefetch_descriptor(tma_atom_sfa)
744
+ cpasync.prefetch_descriptor(tma_atom_sfb)
745
+ cpasync.prefetch_descriptor(tma_atom_d)
746
+
747
+ use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
748
+
749
+ #
750
+ # Setup cta/thread coordinates
751
+ #
752
+ # Coords inside cluster
753
+ bidx, _, _ = cute.arch.block_idx()
754
+ mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
755
+ is_leader_cta = mma_tile_coord_v == 0
756
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
757
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
758
+ if const_expr(self.blockscaled):
759
+ block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
760
+ cta_rank_in_cluster
761
+ )
762
+ else:
763
+ block_in_cluster_coord_sfb_vmnk = None
764
+ # Coord inside cta
765
+ tidx, _, _ = cute.arch.thread_idx()
766
+
767
+ #
768
+ # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
769
+ #
770
+ smem = cutlass.utils.SmemAllocator()
771
+ storage = smem.allocate(self.shared_storage)
772
+
773
+ tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
774
+ tmem_holding_buf = storage.tmem_holding_buf
775
+
776
+ # Tensor memory dealloc barrier init
777
+ if use_2cta_instrs:
778
+ if warp_idx == self.tma_warp_id:
779
+ num_tmem_dealloc_threads = 32
780
+ cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)
781
+
782
+ # Initialize mainloop ab_pipeline (barrier) and states
783
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
784
+ num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
785
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(
786
+ pipeline.Agent.Thread, num_tma_producer
787
+ )
788
+ ab_pipeline = pipeline.PipelineTmaUmma.create(
789
+ barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
790
+ num_stages=self.num_ab_stage,
791
+ producer_group=ab_pipeline_producer_group,
792
+ consumer_group=ab_pipeline_consumer_group,
793
+ tx_count=self.num_tma_load_bytes,
794
+ cta_layout_vmnk=cluster_layout_vmnk,
795
+ )
796
+
797
+ if const_expr(mC_mnl is not None):
798
+ # Threads/warps participating in this pipeline
799
+ epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
800
+ # Each warp will contribute 1 to the arrive count
801
+ consumer_arrive_cnt = len(self.epilog_warp_id)
802
+ epi_pipeline_consumer_group = pipeline.CooperativeGroup(
803
+ pipeline.Agent.Thread, consumer_arrive_cnt
804
+ )
805
+ c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
806
+ tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
807
+ epi_pipeline = pipeline.PipelineTmaAsync.create(
808
+ barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
809
+ num_stages=self.num_c_stage,
810
+ producer_group=epi_pipeline_producer_group,
811
+ consumer_group=epi_pipeline_consumer_group,
812
+ tx_count=tma_copy_c_bytes,
813
+ )
814
+ else:
815
+ epi_pipeline = None
816
+
817
+ # Initialize acc_pipeline (barrier) and states
818
+ acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
819
+ num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
820
+ acc_pipeline_consumer_group = pipeline.CooperativeGroup(
821
+ pipeline.Agent.Thread, num_acc_consumer_threads
822
+ )
823
+ acc_pipeline = pipeline.PipelineUmmaAsync.create(
824
+ barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
825
+ num_stages=self.num_acc_stage,
826
+ producer_group=acc_pipeline_producer_group,
827
+ consumer_group=acc_pipeline_consumer_group,
828
+ cta_layout_vmnk=cluster_layout_vmnk,
829
+ )
830
+
831
+ # if const_expr(tile_sched_params.tile_count_semaphore is not None):
832
+ # # Dynamic persistent scheduler
833
+ # # Threads/warps participating in this pipeline
834
+ # sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
835
+ # cluster_size = cute.size(cluster_layout_vmnk)
836
+ # # Each warp that is not the scheduler warp will contribute 1 to the arrive count
837
+ # consumer_arrive_cnt = (
838
+ # (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
839
+ # ) * cluster_size - 1
840
+ # sched_pipeline_consumer_group = pipeline.CooperativeGroup(
841
+ # pipeline.Agent.Thread, consumer_arrive_cnt
842
+ # )
843
+ # sched_pipeline = pipeline.PipelineAsync.create(
844
+ # barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
845
+ # num_stages=self.sched_stage,
846
+ # producer_group=sched_pipeline_producer_group,
847
+ # consumer_group=sched_pipeline_consumer_group,
848
+ # # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
849
+ # consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
850
+ # )
851
+ # tile_count = storage.tile_count.get_tensor((self.sched_stage,))
852
+ # else:
853
+ # sched_pipeline = None
854
+ # tile_count = None
855
+
856
+ # Setup smem tensors for A/B/SFA/SFB/D/C
857
+ # (MMA, MMA_M, MMA_K, STAGE)
858
+ sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
859
+ # (MMA, MMA_N, MMA_K, STAGE)
860
+ sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
861
+ if const_expr(self.blockscaled):
862
+ # (MMA, MMA_M, MMA_K, STAGE)
863
+ sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
864
+ # (MMA, MMA_N, MMA_K, STAGE)
865
+ sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
866
+ else:
867
+ sSFA, sSFB = None, None
868
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
869
+ sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
870
+ if const_expr(mC_mnl is not None):
871
+ sC = storage.sC.get_tensor(
872
+ epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
873
+ )
874
+ else:
875
+ sC = None
876
+
877
+ thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
878
+ thr_mma_sfb = (
879
+ tiled_mma_sfb.get_slice(mma_tile_coord_v) if const_expr(self.blockscaled) else None
880
+ )
881
+
882
+ # (MMA, MMA_M, MMA_N)
883
+ acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
884
+ # (MMA, MMA_M, MMA_N, STAGE)
885
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
886
+
887
+ tmem_ptr_read_threads = cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id))
888
+ tmem_alloc_barrier = pipeline.NamedBarrier(
889
+ barrier_id=self.tmem_ptr_sync_bar_id, num_threads=tmem_ptr_read_threads
890
+ )
891
+
892
+ TileSchedulerCls = partial(TileSchedulerCls.create, tile_sched_params)
893
+ k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.mma_tiler[2])
894
+
895
+ if const_expr(mC_mnl is not None):
896
+ epi_load_barrier = pipeline.NamedBarrier(
897
+ barrier_id=int(self.epilog_load_bar_id), num_threads=2 * cute.arch.WARP_SIZE
898
+ )
899
+ else:
900
+ epi_load_barrier = None
901
+
902
+ #
903
+ # Specialized TMA load warp
904
+ #
905
+ if warp_idx == self.tma_warp_id:
906
+ # Compute multicast mask for A/B buffer full
907
+ if const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
908
+ a_mcast_mask = cpasync.create_tma_multicast_mask(
909
+ cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
910
+ )
911
+ b_mcast_mask = cpasync.create_tma_multicast_mask(
912
+ cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
913
+ )
914
+ if const_expr(self.blockscaled):
915
+ sfa_mcast_mask = cpasync.create_tma_multicast_mask(
916
+ cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
917
+ )
918
+ sfb_mcast_mask = cpasync.create_tma_multicast_mask(
919
+ cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1
920
+ )
921
+ else:
922
+ sfa_mcast_mask, sfb_mcast_mask = None, None
923
+ else:
924
+ a_mcast_mask, b_mcast_mask = None, None
925
+ sfa_mcast_mask, sfb_mcast_mask = None, None
926
+
927
+ # Persistent tile scheduling loop
928
+ tile_scheduler = TileSchedulerCls()
929
+ work_tile = tile_scheduler.initial_work_tile_info()
930
+ ab_producer_state = pipeline.make_pipeline_state(
931
+ pipeline.PipelineUserType.Producer, self.num_ab_stage
932
+ )
933
+ do_epi_load_barrier_arrive = cutlass.Boolean(True)
934
+ while work_tile.is_valid_tile:
935
+ # Get tile coord from tile scheduler
936
+ tile_coord_mnkl = work_tile.tile_idx
937
+ mma_tile_coord_mnl = (
938
+ tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
939
+ tile_coord_mnkl[1],
940
+ tile_coord_mnkl[3],
941
+ )
942
+ # Local_tile partition global tensors
943
+ # (bM, bK, RestK)
944
+ gA_mkl = cute.local_tile(
945
+ mA_mkl,
946
+ cute.slice_(self.mma_tiler, (None, 0, None)),
947
+ (mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]),
948
+ )
949
+ # (bN, bK, RestK)
950
+ gB_nkl = cute.local_tile(
951
+ mB_nkl,
952
+ cute.slice_(self.mma_tiler, (0, None, None)),
953
+ (mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]),
954
+ )
955
+ if const_expr(self.blockscaled):
956
+ # (bM, bK)
957
+ gSFA_mkl = cute.local_tile(
958
+ mSFA_mkl,
959
+ cute.slice_(self.mma_tiler, (None, 0, None)),
960
+ (mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]),
961
+ )
962
+ # (bN, bK)
963
+ gSFB_nkl = cute.local_tile(
964
+ mSFB_nkl,
965
+ cute.slice_(self.mma_tiler, (0, None, None)),
966
+ (mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]),
967
+ )
968
+ # Partition global tensor for TiledMMA_A/B/D
969
+ # (MMA, MMA_M, MMA_K, RestK)
970
+ tCgA = thr_mma.partition_A(gA_mkl)
971
+ # (MMA, MMA_N, MMA_K, RestK)
972
+ tCgB = thr_mma.partition_B(gB_nkl)
973
+ if const_expr(self.blockscaled):
974
+ # (MMA, MMA_M, MMA_K)
975
+ tCgSFA = thr_mma.partition_A(gSFA_mkl)
976
+ # (MMA, MMA_N, MMA_K)
977
+ tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
978
+ # Partition global/shared tensor for TMA load A/B
979
+ # TMA load A partition_S/D
980
+ a_cta_layout = cute.make_layout(
981
+ cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
982
+ )
983
+ # ((atom_v, rest_v), STAGE)
984
+ # ((atom_v, rest_v), RestK)
985
+ tAsA, tAgA = cpasync.tma_partition(
986
+ tma_atom_a,
987
+ block_in_cluster_coord_vmnk[2],
988
+ a_cta_layout,
989
+ cute.group_modes(sA, 0, 3),
990
+ cute.group_modes(tCgA, 0, 3),
991
+ )
992
+ # TMA load B partition_S/D
993
+ b_cta_layout = cute.make_layout(
994
+ cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
995
+ )
996
+ # ((atom_v, rest_v), STAGE)
997
+ # ((atom_v, rest_v), RestK)
998
+ tBsB, tBgB = cpasync.tma_partition(
999
+ tma_atom_b,
1000
+ block_in_cluster_coord_vmnk[1],
1001
+ b_cta_layout,
1002
+ cute.group_modes(sB, 0, 3),
1003
+ cute.group_modes(tCgB, 0, 3),
1004
+ )
1005
+ if const_expr(self.blockscaled):
1006
+ # TMA load SFA partition_S/D
1007
+ sfa_cta_layout = a_cta_layout
1008
+ # ((atom_v, rest_v), STAGE)
1009
+ # ((atom_v, rest_v), RestK)
1010
+ tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition(
1011
+ tma_atom_sfa,
1012
+ block_in_cluster_coord_vmnk[2],
1013
+ sfa_cta_layout,
1014
+ cute.group_modes(sSFA, 0, 3),
1015
+ cute.group_modes(tCgSFA, 0, 3),
1016
+ )
1017
+ tAsSFA = cute.filter_zeros(tAsSFA)
1018
+ tAgSFA = cute.filter_zeros(tAgSFA)
1019
+ # TMA load SFB partition_S/D
1020
+ sfb_cta_layout = cute.make_layout(
1021
+ cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape
1022
+ )
1023
+ # ((atom_v, rest_v), STAGE)
1024
+ # ((atom_v, rest_v), RestK)
1025
+ tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition(
1026
+ tma_atom_sfb,
1027
+ block_in_cluster_coord_sfb_vmnk[1],
1028
+ sfb_cta_layout,
1029
+ cute.group_modes(sSFB, 0, 3),
1030
+ cute.group_modes(tCgSFB, 0, 3),
1031
+ )
1032
+ tBsSFB = cute.filter_zeros(tBsSFB)
1033
+ tBgSFB = cute.filter_zeros(tBgSFB)
1034
+ else:
1035
+ tAsSFA, tAgSFA = None, None
1036
+ tBsSFB, tBgSFB = None, None
1037
+ ab_producer_state = self.load_AB(
1038
+ ab_pipeline,
1039
+ ab_producer_state,
1040
+ tma_atom_a,
1041
+ tAgA,
1042
+ tAsA,
1043
+ a_mcast_mask,
1044
+ tma_atom_b,
1045
+ tBgB,
1046
+ tBsB,
1047
+ b_mcast_mask,
1048
+ tma_atom_sfa,
1049
+ tAgSFA,
1050
+ tAsSFA,
1051
+ sfa_mcast_mask,
1052
+ tma_atom_sfb,
1053
+ tBgSFB,
1054
+ tBsSFB,
1055
+ sfb_mcast_mask,
1056
+ )
1057
+ if const_expr(epi_load_barrier is not None):
1058
+ # In the first work tile, the epi load warp will wait for the signal
1059
+ # from the mainloop load warp to start loading C, to avoid interfering
1060
+ # with loading A and B.
1061
+ if do_epi_load_barrier_arrive:
1062
+ epi_load_barrier.arrive()
1063
+ do_epi_load_barrier_arrive = cutlass.Boolean(False)
1064
+ # Advance to next tile
1065
+ tile_scheduler.advance_to_next_work()
1066
+ work_tile = tile_scheduler.get_current_work()
1067
+ # Wait A/B buffer empty
1068
+ ab_pipeline.producer_tail(ab_producer_state)
1069
+
1070
+ #
1071
+ # Specialized TMA epi load warp
1072
+ #
1073
+ if const_expr(mC_mnl is not None):
1074
+ if warp_idx == self.tma_epi_warp_id:
1075
+ epi_producer_state = pipeline.make_pipeline_state(
1076
+ pipeline.PipelineUserType.Producer, self.num_c_stage
1077
+ )
1078
+ do_epi_load_barrier_wait = cutlass.Boolean(True)
1079
+ # Persistent tile scheduling loop
1080
+ tile_scheduler = TileSchedulerCls()
1081
+ work_tile = tile_scheduler.initial_work_tile_info()
1082
+ while work_tile.is_valid_tile:
1083
+ # Get tile coord from tile scheduler
1084
+ tile_coord_mnkl = work_tile.tile_idx
1085
+ mma_tile_coord_mnl = (
1086
+ tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
1087
+ tile_coord_mnkl[1],
1088
+ tile_coord_mnkl[3],
1089
+ )
1090
+ # Local_tile partition global tensors
1091
+ # (bM, bN)
1092
+ gC_mnl = cute.local_tile(
1093
+ mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
1094
+ )
1095
+ # Partition global tensor for TiledMMA_A/B/D
1096
+ # (MMA, MMA_M, MMA_N)
1097
+ tCgC = thr_mma.partition_C(gC_mnl)
1098
+ # bGS_gC has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
1099
+ bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
1100
+ tma_atom_c, tCgC, epi_tile, sC
1101
+ )
1102
+ bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
1103
+ if do_epi_load_barrier_wait:
1104
+ epi_load_barrier.arrive_and_wait()
1105
+ do_epi_load_barrier_wait = cutlass.Boolean(False)
1106
+ epi_tile_num = const_expr(cute.size(bGS_gC, mode=[1]))
1107
+ for subtile_idx in cutlass.range(epi_tile_num, unroll=1):
1108
+ epi_pipeline.producer_acquire(epi_producer_state)
1109
+ cute.copy(
1110
+ tma_atom_c,
1111
+ bGS_gC[None, subtile_idx],
1112
+ bGS_sC[None, epi_producer_state.index],
1113
+ tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1114
+ )
1115
+ # Epi pipeline's producer commit is a NOP
1116
+ epi_pipeline.producer_commit(epi_producer_state)
1117
+ epi_producer_state.advance()
1118
+ # Advance to next tile
1119
+ tile_scheduler.advance_to_next_work()
1120
+ work_tile = tile_scheduler.get_current_work()
1121
+ # End of persistent scheduler loop
1122
+ epi_pipeline.producer_tail(epi_producer_state)
1123
+
1124
+ #
1125
+ # Specialized MMA warp
1126
+ #
1127
+ if warp_idx == self.mma_warp_id:
1128
+ tmem_alloc_barrier.arrive_and_wait()
1129
+ # Retrieve the tensor memory ptr and make the accumulator tensor
1130
+ acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
1131
+ self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
1132
+ )
1133
+ # Partition shared/tensor memory tensor for TiledMMA_A/B/D
1134
+ # (MMA, MMA_M, MMA_K, STAGE)
1135
+ tCrA = tiled_mma.make_fragment_A(sA)
1136
+ # (MMA, MMA_N, MMA_K, STAGE)
1137
+ tCrB = tiled_mma.make_fragment_B(sB)
1138
+ # (MMA, MMA_M, MMA_N, STAGE)
1139
+ tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
1140
+
1141
+ if const_expr(self.blockscaled):
1142
+ # Make SFA tmem tensor
1143
+ sfa_tmem_ptr = cute.recast_ptr(
1144
+ acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base),
1145
+ dtype=self.sf_dtype,
1146
+ )
1147
+ # (MMA, MMA_M, MMA_K)
1148
+ tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
1149
+ tiled_mma,
1150
+ self.mma_tiler,
1151
+ self.sf_vec_size,
1152
+ cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
1153
+ )
1154
+ tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
1155
+
1156
+ # Make SFB tmem tensor
1157
+ sfb_tmem_ptr = cute.recast_ptr(
1158
+ acc_tmem_ptr
1159
+ + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
1160
+ + tcgen05.find_tmem_tensor_col_offset(tCtSFA),
1161
+ dtype=self.sf_dtype,
1162
+ )
1163
+ # (MMA, MMA_N, MMA_K)
1164
+ tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
1165
+ tiled_mma,
1166
+ self.mma_tiler,
1167
+ self.sf_vec_size,
1168
+ cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
1169
+ )
1170
+ tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
1171
+ # Partition for S2T copy of SFA/SFB
1172
+ (
1173
+ tiled_copy_s2t_sfa,
1174
+ tCsSFA_compact_s2t,
1175
+ tCtSFA_compact_s2t,
1176
+ ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA)
1177
+ (
1178
+ tiled_copy_s2t_sfb,
1179
+ tCsSFB_compact_s2t,
1180
+ tCtSFB_compact_s2t,
1181
+ ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
1182
+ else:
1183
+ tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None
1184
+ tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None
1185
+
1186
+ # Persistent tile scheduling loop
1187
+ tile_scheduler = TileSchedulerCls()
1188
+ work_tile = tile_scheduler.initial_work_tile_info()
1189
+ ab_consumer_state = pipeline.make_pipeline_state(
1190
+ pipeline.PipelineUserType.Consumer, self.num_ab_stage
1191
+ )
1192
+ acc_producer_state = pipeline.make_pipeline_state(
1193
+ pipeline.PipelineUserType.Producer, self.num_acc_stage
1194
+ )
1195
+ while work_tile.is_valid_tile:
1196
+ # Get tile coord from tile scheduler
1197
+ tile_coord_mnkl = work_tile.tile_idx
1198
+ # Set tensor memory buffer for current tile
1199
+ # (MMA, MMA_M, MMA_N)
1200
+ tCtAcc = tCtAcc_base[None, None, None, acc_producer_state.index]
1201
+ ab_consumer_state, acc_producer_state, tiled_mma = self.mma(
1202
+ ab_pipeline,
1203
+ acc_pipeline,
1204
+ ab_consumer_state,
1205
+ acc_producer_state,
1206
+ tiled_mma,
1207
+ tCrA,
1208
+ tCrB,
1209
+ tCtAcc,
1210
+ k_tile_cnt,
1211
+ is_leader_cta,
1212
+ tiled_copy_s2t_sfa,
1213
+ tiled_copy_s2t_sfb,
1214
+ tCsSFA_compact_s2t,
1215
+ tCsSFB_compact_s2t,
1216
+ tCtSFA_compact_s2t,
1217
+ tCtSFB_compact_s2t,
1218
+ )
1219
+ # Advance to next tile
1220
+ tile_scheduler.advance_to_next_work()
1221
+ work_tile = tile_scheduler.get_current_work()
1222
+
1223
+ # Wait for accumulator buffer empty
1224
+ acc_pipeline.producer_tail(acc_producer_state)
1225
+
1226
+ #
1227
+ # Specialized epilogue warps
1228
+ #
1229
+ if warp_idx < self.mma_warp_id:
1230
+ # Alloc tensor memory buffer
1231
+ if warp_idx == self.epilog_warp_id[0]:
1232
+ cute.arch.alloc_tmem(
1233
+ self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs
1234
+ )
1235
+ # Barrier sync before retrieving the tensor memory ptr from shared memory
1236
+ tmem_alloc_barrier.arrive_and_wait()
1237
+ # Retrieve the tensor memory ptr and make the accumulator tensor
1238
+ acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
1239
+ self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
1240
+ )
1241
+ # (MMA, MMA_M, MMA_N, STAGE)
1242
+ tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
1243
+
1244
+ epilog_threads = cute.arch.WARP_SIZE * len(self.epilog_warp_id)
1245
+ epilogue_barrier = pipeline.NamedBarrier(
1246
+ barrier_id=self.epilog_sync_bar_id, num_threads=epilog_threads
1247
+ )
1248
+
1249
+ # Partition for epilogue
1250
+ epi_tidx = tidx
1251
+ tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
1252
+ epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
1253
+ )
1254
+
1255
+ tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.d_dtype)
1256
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
1257
+ tiled_copy_t2r, tTR_rD, epi_tidx, sD
1258
+ )
1259
+ if const_expr(mC_mnl is not None):
1260
+ tTR_rC = cute.make_fragment_like(tTR_rD, self.c_dtype)
1261
+ tiled_copy_s2r, tSR_rC, tSR_sC = self.epilog_smem_copy_and_partition(
1262
+ tiled_copy_t2r, tTR_rC, epi_tidx, sC
1263
+ )
1264
+ # TODO: for m-major, D is being stored with STSM so we'd need LDSM here
1265
+ # tRS_rC = tSR_rC # TODO: retile?
1266
+ tRS_rC = cute.make_fragment(tRS_rD.layout, self.c_dtype)
1267
+ tSR_rC = tiled_copy_s2r.get_slice(epi_tidx).retile(tRS_rC)
1268
+
1269
+ # Persistent tile scheduling loop
1270
+ tile_scheduler = TileSchedulerCls()
1271
+ work_tile = tile_scheduler.initial_work_tile_info()
1272
+ acc_consumer_state = pipeline.make_pipeline_state(
1273
+ pipeline.PipelineUserType.Consumer, self.num_acc_stage
1274
+ )
1275
+ # Threads/warps participating in tma store pipeline
1276
+ d_producer_group = pipeline.CooperativeGroup(
1277
+ pipeline.Agent.Thread,
1278
+ 32 * len(self.epilog_warp_id),
1279
+ 32 * len(self.epilog_warp_id),
1280
+ )
1281
+ d_pipeline = pipeline.PipelineTmaStore.create(
1282
+ num_stages=self.num_d_stage, producer_group=d_producer_group
1283
+ )
1284
+ epi_read_state = pipeline.make_pipeline_state(
1285
+ pipeline.PipelineUserType.Consumer, self.num_c_stage
1286
+ )
1287
+
1288
+ while work_tile.is_valid_tile:
1289
+ # Get tile coord from tile scheduler
1290
+ tile_coord_mnkl = work_tile.tile_idx
1291
+ mma_tile_coord_mnl = (
1292
+ tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
1293
+ tile_coord_mnkl[1],
1294
+ tile_coord_mnkl[3],
1295
+ )
1296
+ # Local_tile partition global tensors
1297
+ # (bM, bN)
1298
+ gD_mnl = cute.local_tile(
1299
+ mD_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), mma_tile_coord_mnl
1300
+ )
1301
+ # Partition global tensor for TiledMMA_A/B/D
1302
+ # (MMA, MMA_M, MMA_N)
1303
+ tDgD = thr_mma.partition_C(gD_mnl)
1304
+ # bSG_gD has shape ((ATOM_V, REST_V), EPI_M, EPI_N)
1305
+ bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(tma_atom_d, tDgD, epi_tile, sD)
1306
+
1307
+ # Set tensor memory buffer for current tile
1308
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
1309
+ tTR_tAcc = tTR_tAcc_base[None, None, None, None, None, acc_consumer_state.index]
1310
+
1311
+ # Wait for accumulator buffer full
1312
+ acc_pipeline.consumer_wait(acc_consumer_state)
1313
+
1314
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
1315
+ bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
1316
+
1317
+ # Store accumulator to global memory in subtiles
1318
+ subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
1319
+ num_prev_subtiles = tile_scheduler.num_tiles_executed * subtile_cnt
1320
+ for subtile_idx in cutlass.range(subtile_cnt):
1321
+ # Load accumulator from tensor memory buffer to register
1322
+ tTR_tAcc_mn = tTR_tAcc[None, None, None, subtile_idx]
1323
+ cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
1324
+ # Convert to D type
1325
+ acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
1326
+ acc_vec = epilogue_op(acc_vec)
1327
+ if const_expr(mC_mnl is not None):
1328
+ epi_pipeline.consumer_wait(epi_read_state)
1329
+ cute.copy(
1330
+ tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
1331
+ )
1332
+ # Fence to make sure shared memory read is visible to TMA load
1333
+ cute.arch.fence_proxy(
1334
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1335
+ )
1336
+ cute.arch.sync_warp()
1337
+ with cute.arch.elect_one():
1338
+ epi_pipeline.consumer_release(epi_read_state)
1339
+ epi_read_state.advance()
1340
+ acc_vec = acc_vec + tRS_rC.load().to(self.acc_dtype)
1341
+ tRS_rD.store(acc_vec.to(self.d_dtype))
1342
+ # Store D to shared memory
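+ # The D stage index advances round-robin over all subtiles this CTA has processed,
+ # so consecutive persistent tiles do not immediately reuse a buffer still in flight.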
1343
+ d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage
1344
+ cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
1345
+ # Fence and barrier to make sure shared memory store is visible to TMA store
1346
+ cute.arch.fence_proxy(
1347
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1348
+ )
1349
+ epilogue_barrier.arrive_and_wait()
1350
+ # TMA store D to global memory
1351
+ if warp_idx == self.epilog_warp_id[0]:
1352
+ cute.copy(tma_atom_d, bSG_sD[None, d_buffer], bSG_gD[None, subtile_idx])
1353
+ # Commit the TMA store group and wait until a D smem buffer is free for reuse
1354
+ d_pipeline.producer_commit()
1355
+ d_pipeline.producer_acquire()
1356
+ epilogue_barrier.arrive_and_wait()
1357
+
1358
+ # Async arrive accumulator buffer empty
1359
+ with cute.arch.elect_one():
1360
+ acc_pipeline.consumer_release(acc_consumer_state)
1361
+ acc_consumer_state.advance()
1362
+
1363
+ # Advance to next tile
1364
+ tile_scheduler.advance_to_next_work()
1365
+ work_tile = tile_scheduler.get_current_work()
1366
+
1367
+ # Dealloc the tensor memory buffer
1368
+ if warp_idx == self.epilog_warp_id[0]:
1369
+ cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
1370
+ epilogue_barrier.arrive_and_wait()
1371
+ if warp_idx == self.epilog_warp_id[0]:
1372
+ if use_2cta_instrs:
1373
+ cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1)
1374
+ cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
1375
+ cute.arch.dealloc_tmem(
1376
+ acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs
1377
+ )
1378
+
1379
+ # Wait for D store complete
1380
+ d_pipeline.producer_tail()
1381
+
1382
+ @cute.jit
1383
+ def load_AB(
1384
+ self,
1385
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1386
+ ab_producer_state: cutlass.pipeline.PipelineState,
1387
+ tma_atom_a: cute.CopyAtom,
1388
+ tAgA: cute.Tensor,
1389
+ tAsA: cute.Tensor,
1390
+ a_mcast_mask: cutlass.Int16,
1391
+ tma_atom_b: cute.CopyAtom,
1392
+ tBgB: cute.Tensor,
1393
+ tBsB: cute.Tensor,
1394
+ b_mcast_mask: cutlass.Int16,
1395
+ tma_atom_sfa: Optional[cute.CopyAtom] = None,
1396
+ tAgSFA: Optional[cute.Tensor] = None,
1397
+ tAsSFA: Optional[cute.Tensor] = None,
1398
+ sfa_mcast_mask: Optional[cutlass.Int16] = None,
1399
+ tma_atom_sfb: Optional[cute.CopyAtom] = None,
1400
+ tBgSFB: Optional[cute.Tensor] = None,
1401
+ tBsSFB: Optional[cute.Tensor] = None,
1402
+ sfb_mcast_mask: Optional[cutlass.Int16] = None,
1403
+ ) -> cutlass.pipeline.PipelineState:
1404
+ blockscaled = const_expr(tma_atom_sfa is not None)
1405
+ if const_expr(blockscaled):
1406
+ assert all(x is not None for x in (tma_atom_sfa, tAgSFA, tAsSFA))
1407
+ assert all(x is not None for x in (tma_atom_sfb, tBgSFB, tBsSFB))
1408
+ k_tile_cnt = cute.size(tAgA, mode=[1])
1409
+ # Peek (try_wait) AB buffer empty for k_tile = 0
1410
+ peek_ab_empty_status = cutlass.Boolean(True)
1411
+ if 0 < k_tile_cnt:
1412
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1413
+ # /////////////////////////////////////////////////////////////////////////
1414
+ # TMA load
1415
+ # /////////////////////////////////////////////////////////////////////////
1416
+ for k_tile in cutlass.range(k_tile_cnt, unroll=1):
1417
+ # Wait for A/B buffers to be empty before loading into them
1418
+ # Also sets the transaction barrier for the A/B buffers
1419
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
1420
+ cute.copy(
1421
+ tma_atom_a,
1422
+ tAgA[None, k_tile],
1423
+ tAsA[None, ab_producer_state.index],
1424
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1425
+ mcast_mask=a_mcast_mask,
1426
+ )
1427
+ cute.copy(
1428
+ tma_atom_b,
1429
+ tBgB[None, k_tile],
1430
+ tBsB[None, ab_producer_state.index],
1431
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1432
+ mcast_mask=b_mcast_mask,
1433
+ )
1434
+ if const_expr(blockscaled):
1435
+ cute.copy(
1436
+ tma_atom_sfa,
1437
+ tAgSFA[None, ab_producer_state.count],
1438
+ tAsSFA[None, ab_producer_state.index],
1439
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1440
+ mcast_mask=sfa_mcast_mask,
1441
+ )
1442
+ cute.copy(
1443
+ tma_atom_sfb,
1444
+ tBgSFB[None, ab_producer_state.count],
1445
+ tBsSFB[None, ab_producer_state.index],
1446
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1447
+ mcast_mask=sfb_mcast_mask,
1448
+ )
1449
+ # Mainloop pipeline's producer commit is a NOP
1450
+ ab_pipeline.producer_commit(ab_producer_state)
1451
+ ab_producer_state.advance()
1452
+ peek_ab_empty_status = cutlass.Boolean(True)
1453
+ if k_tile + 1 < k_tile_cnt:
1454
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1455
+ return ab_producer_state
1456
+
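+ # NOTE: illustrative sketch only (argument names assumed), showing how the TMA
+ # producer loop below might be driven from the load warp: acquire an empty A/B
+ # stage, issue TMA copies that complete on that stage's transaction barrier,
+ # then advance the producer state.
+ #
+ #     ab_producer_state = self.load_AB(
+ #         ab_pipeline, ab_producer_state,
+ #         tma_atom_a, tAgA, tAsA, a_mcast_mask,
+ #         tma_atom_b, tBgB, tBsB, b_mcast_mask,
+ #     )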
1457
+ @cute.jit
1458
+ def mma(
1459
+ self,
1460
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1461
+ acc_pipeline: cutlass.pipeline.PipelineAsync,
1462
+ ab_consumer_state: cutlass.pipeline.PipelineState,
1463
+ acc_producer_state: cutlass.pipeline.PipelineState,
1464
+ tiled_mma: cute.TiledMma,
1465
+ tCrA: cute.Tensor,
1466
+ tCrB: cute.Tensor,
1467
+ acc: cute.Tensor,
1468
+ k_tile_cnt: Int32,
1469
+ is_leader_cta: cutlass.Boolean,
1470
+ tiled_copy_s2t_sfa: Optional[cute.TiledCopy] = None,
1471
+ tiled_copy_s2t_sfb: Optional[cute.TiledCopy] = None,
1472
+ tCsSFA_compact_s2t: Optional[cute.Tensor] = None,
1473
+ tCsSFB_compact_s2t: Optional[cute.Tensor] = None,
1474
+ tCtSFA_compact_s2t: Optional[cute.Tensor] = None,
1475
+ tCtSFB_compact_s2t: Optional[cute.Tensor] = None,
1476
+ ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma]:
1477
+ blockscaled = const_expr(tiled_copy_s2t_sfa is not None)
1478
+ if const_expr(blockscaled):
1479
+ assert all(x is not None for x in (tiled_copy_s2t_sfa, tiled_copy_s2t_sfb))
1480
+ assert all(x is not None for x in (tCsSFA_compact_s2t, tCsSFB_compact_s2t))
1481
+ assert all(x is not None for x in (tCtSFA_compact_s2t, tCtSFB_compact_s2t))
1482
+ # Peek (try_wait) AB buffer full for k_tile = 0
1483
+ peek_ab_full_status = cutlass.Boolean(True)
1484
+ if 0 < k_tile_cnt and is_leader_cta:
1485
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
1486
+ # Wait for accumulator buffer empty
1487
+ if is_leader_cta:
1488
+ acc_pipeline.producer_acquire(acc_producer_state)
1489
+ # Reset the ACCUMULATE field for each tile
1490
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
1491
+ # Mma mainloop
1492
+ num_k_blocks = cute.size(tCrA, mode=[2])
1493
+ for k_tile in cutlass.range(k_tile_cnt, unroll=1):
1494
+ if is_leader_cta:
1495
+ # Conditionally wait for AB buffer full
1496
+ ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
1497
+ # Copy SFA/SFB from smem to tmem
1498
+ if const_expr(blockscaled):
1499
+ s2t_stage_coord = (None, None, None, None, ab_consumer_state.index)
1500
+ tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord]
1501
+ tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord]
1502
+ cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t_staged, tCtSFA_compact_s2t)
1503
+ cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t)
1504
+ for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1505
+ k_blk_coord = (None, None, k_blk_idx, ab_consumer_state.index)
1506
+ cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1507
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
1508
+ # Async arrive AB buffer empty
1509
+ ab_pipeline.consumer_release(ab_consumer_state)
1510
+ ab_consumer_state.advance()
1511
+ # Peek (try_wait) AB buffer full for k_tile = k_tile + 1
1512
+ peek_ab_full_status = cutlass.Boolean(True)
1513
+ if k_tile + 1 < k_tile_cnt and is_leader_cta:
1514
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
1515
+ # Async arrive accumulator buffer full
1516
+ if is_leader_cta:
1517
+ acc_pipeline.producer_commit(acc_producer_state)
1518
+ acc_producer_state.advance()
1519
+ # If we don't return the tiled_mma, we get compiler error
1520
+ # "operand #0 does not dominate this use"
1521
+ return ab_consumer_state, acc_producer_state, tiled_mma
1522
+
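+ # NOTE: illustrative sketch only (argument names assumed), showing how the MMA
+ # loop below might be invoked from the MMA warp. Only the leader CTA waits on
+ # A/B stages and commits the accumulator; the returned tiled_mma must be
+ # threaded back to the caller.
+ #
+ #     ab_consumer_state, acc_producer_state, tiled_mma = self.mma(
+ #         ab_pipeline, acc_pipeline, ab_consumer_state, acc_producer_state,
+ #         tiled_mma, tCrA, tCrB, acc, k_tile_cnt, is_leader_cta,
+ #     )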
1523
+ def mainloop_s2t_copy_and_partition(
1524
+ self,
1525
+ sSF: cute.Tensor,
1526
+ tSF: cute.Tensor,
1527
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1528
+ """
1529
+ Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination).
1530
+
1531
+ :param sSF: The scale factor tensor in smem
1532
+ :type sSF: cute.Tensor
1533
+ :param tSF: The scale factor tensor in tmem
1534
+ :type tSF: cute.Tensor
1535
+
1536
+ :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where:
1537
+ - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t)
1538
+ - tCsSF_compact_s2t: The partitioned scale factor tensor in smem
1539
+ - tCtSF_compact_s2t: The partitioned scale factor tensor in tmem
1540
+ :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
1541
+ """
1542
+ # (MMA, MMA_MN, MMA_K, STAGE)
1543
+ tCsSF_compact = cute.filter_zeros(sSF)
1544
+ # (MMA, MMA_MN, MMA_K)
1545
+ tCtSF_compact = cute.filter_zeros(tSF)
1546
+ # Make S2T CopyAtom and tiledCopy
1547
+ copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype)
1548
+ tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
1549
+ thr_copy_s2t = tiled_copy_s2t.get_slice(0)
1550
+ # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
1551
+ tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
1552
+ # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
1553
+ tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_)
1554
+ # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K)
1555
+ tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
1556
+ return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
1557
+
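+ # NOTE: illustrative sketch only (tensor names assumed), showing a hypothetical
+ # use of mainloop_s2t_copy_and_partition below for the SFA operand:
+ #
+ #     tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = (
+ #         self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA)
+ #     )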
1558
+ def epilog_tmem_copy_and_partition(
1559
+ self,
1560
+ tidx: Int32,
1561
+ tAcc: cute.Tensor,
1562
+ epi_tile: cute.Tile,
1563
+ use_2cta_instrs: Union[cutlass.Boolean, bool],
1564
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1565
+ """
1566
+ Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
1567
+
1568
+ :param tidx: The thread index in epilogue warp groups
1569
+ :type tidx: Int32
1570
+ :param tAcc: The accumulator tensor to be copied and partitioned
1571
+ :type tAcc: cute.Tensor
1572
+ :param epi_tile: The epilogue tiler
1573
+ :type epi_tile: cute.Tile
1574
+ :param use_2cta_instrs: Whether 2-CTA MMA instructions are used
1575
+ :type use_2cta_instrs: bool
1576
+
1577
+ :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where:
1578
+ - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
1579
+ - tTR_tAcc: The partitioned accumulator tensor
1580
+ - tTR_rAcc: The accumulated tensor in register used to hold t2r results
1581
+ :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
1582
+ """
1583
+ # Make tiledCopy for tensor memory load
1584
+ copy_atom_t2r = sm100_utils.get_tmem_load_op(
1585
+ self.cta_tile_shape_mnk,
1586
+ self.d_layout,
1587
+ self.d_dtype,
1588
+ self.acc_dtype,
1589
+ epi_tile,
1590
+ use_2cta_instrs,
1591
+ )
1592
+ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
1593
+ tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile)
1594
+ # (EPI_TILE_M, EPI_TILE_N)
1595
+ tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
1596
+
1597
+ thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
1598
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
1599
+ tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
1600
+
1601
+ cAcc = cute.make_identity_tensor((self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]))
1602
+ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
1603
+ cAcc_epi = cute.flat_divide(cAcc, epi_tile)
1604
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
1605
+ tTR_cAcc = thr_copy_t2r.partition_D(cAcc_epi)
1606
+ # (T2R, T2R_M, T2R_N)
1607
+ tTR_rAcc = cute.make_fragment(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
1608
+ return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
1609
+
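+ # NOTE: illustrative sketch only (tensor names assumed), showing a hypothetical
+ # use of epilog_tmem_copy_and_partition below in the epilogue warps:
+ #
+ #     tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition(
+ #         tidx, tAcc, epi_tile, use_2cta_instrs
+ #     )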
1610
+ def epilog_smem_copy_and_partition(
1611
+ self,
1612
+ tiled_copy_t2r: cute.TiledCopy,
1613
+ tTR_rD: cute.Tensor,
1614
+ tidx: Int32,
1615
+ sD: cute.Tensor,
1616
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1617
+ """
1618
+ Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
1619
+
1620
+ :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
1621
+ :type tiled_copy_t2r: cute.TiledCopy
1622
+ :param tTR_rD: The partitioned accumulator tensor
1623
+ :type tTR_rD: cute.Tensor
1624
+ :param tidx: The thread index in epilogue warp groups
1625
+ :type tidx: Int32
1626
+ :param sD: The shared memory tensor to be copied and partitioned
1627
+ :type sD: cute.Tensor
1629
+
1630
+ :return: A tuple containing (tiled_copy_r2s, tRS_rD, tRS_sD) where:
1631
+ - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s)
1632
+ - tRS_rD: The partitioned tensor D (register source)
1633
+ - tRS_sD: The partitioned tensor D (smem destination)
1634
+ :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
1635
+ """
1636
+ copy_atom_r2s = sm100_utils.get_smem_store_op(
1637
+ self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r
1638
+ )
1639
+ tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
1640
+ # (R2S, R2S_M, R2S_N, PIPE_D)
1641
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
1642
+ tRS_sD = thr_copy_r2s.partition_D(sD)
1643
+ # (R2S, R2S_M, R2S_N)
1644
+ tRS_rD = tiled_copy_r2s.retile(tTR_rD)
1645
+ return tiled_copy_r2s, tRS_rD, tRS_sD
1646
+
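+ # NOTE: illustrative sketch only (tensor names assumed), showing a hypothetical
+ # use of epilog_smem_copy_and_partition below, chained after the t2r partitioning:
+ #
+ #     tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
+ #         tiled_copy_t2r, tTR_rD, tidx, sD
+ #     )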
1647
+ # def epilog_smem_load_copy_and_partition(
1648
+ # self,
1649
+ # tiled_copy_t2r: cute.TiledCopy,
1650
+ # tTR_rC: cute.Tensor,
1651
+ # tidx: Int32,
1652
+ # sC: cute.Tensor,
1653
+ # ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
1654
+ # copy_atom_s2r = cute.make_copy_atom(
1655
+ # warp.LdMatrix8x8x16bOp(self.c_layout.is_m_major_c(), num_matrices=4),
1656
+ # self.c_dtype, # TODO: this probably only works for f16 for now?
1657
+ # )
1658
+ # # copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
1659
+ # tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
1660
+ # # (R2S, R2S_M, R2S_N, PIPE_D)
1661
+ # thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1662
+ # # (R2S, R2S_M, R2S_N)
1663
+ # tSR_sC = thr_copy_s2r.partition_S(sC)
1664
+ # return tiled_copy_s2r, tSR_sC
1665
+
1666
+ def epilog_gmem_copy_and_partition(
1667
+ self,
1668
+ atom: Union[cute.CopyAtom, cute.TiledCopy],
1669
+ gD_mnl: cute.Tensor,
1670
+ epi_tile: cute.Tile,
1671
+ sD: cute.Tensor,
1672
+ ) -> Tuple[cute.Tensor, cute.Tensor]:
1673
+ """Make tiledCopy for global memory store, then use it to:
1674
+ - partition register array (source) and global memory (destination) for none TMA store version;
1675
+ - partition shared memory (source) and global memory (destination) for TMA store version.
1676
+
1677
+ :param atom: The TMA copy atom (tma_atom_d) used to store D to global memory
1678
+ :type atom: cute.CopyAtom or cute.TiledCopy
1679
+ :param gD_mnl: The global output tensor D
1680
+ :type gD_mnl: cute.Tensor
1681
+ :param epi_tile: The epilogue tiler
1682
+ :type epi_tile: cute.Tile
1683
+ :param sD: The shared memory tensor to be copied and partitioned
1684
+ :type sD: cute.Tensor
1685
+
1686
+ :return: A tuple containing (bSG_sD, bSG_gD) where:
1687
+ - bSG_sD: The partitioned shared memory tensor D (TMA source)
1688
+ - bSG_gD: The partitioned global tensor D (TMA destination)
1689
+ :rtype: Tuple[cute.Tensor, cute.Tensor]
1696
+ """
1697
+ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
1698
+ gD_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0)], epi_tile)
1699
+ sD_for_tma_partition = cute.group_modes(sD, 0, 2)
1700
+ gD_for_tma_partition = cute.group_modes(gD_epi, 0, 2)
1701
+ # ((ATOM_V, REST_V), EPI_M, EPI_N)
1702
+ bSG_sD, bSG_gD = cpasync.tma_partition(
1703
+ atom,
1704
+ 0,
1705
+ cute.make_layout(1),
1706
+ sD_for_tma_partition,
1707
+ gD_for_tma_partition,
1708
+ )
1709
+ return bSG_sD, bSG_gD
1710
+
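+ # NOTE: illustrative sketch only (tensor/atom names assumed), showing a
+ # hypothetical use of epilog_gmem_copy_and_partition below with the D TMA atom:
+ #
+ #     bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(
+ #         tma_atom_d, gD_mnl, epi_tile, sD
+ #     )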
1711
+ @staticmethod
1712
+ def _compute_stages(
1713
+ tiled_mma: cute.TiledMma,
1714
+ mma_tiler_mnk: Tuple[int, int, int],
1715
+ a_dtype: Type[cutlass.Numeric],
1716
+ b_dtype: Type[cutlass.Numeric],
1717
+ epi_tile: cute.Tile,
1718
+ d_dtype: Type[cutlass.Numeric],
1719
+ c_dtype: Optional[Type[cutlass.Numeric]],
1720
+ d_layout: cutlass.utils.LayoutEnum,
1721
+ c_layout: Optional[cutlass.utils.LayoutEnum],
1722
+ sf_dtype: Optional[Type[cutlass.Numeric]],
1723
+ sf_vec_size: Optional[int],
1724
+ smem_capacity: int,
1725
+ occupancy: int,
1726
+ ) -> Tuple[int, int, int, int]:
1727
+ """Computes the number of stages for A/B/C operands based on heuristics.
1728
+
1729
+ :param tiled_mma: The tiled MMA object defining the core computation.
1730
+ :type tiled_mma: cute.TiledMma
1731
+ :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
1732
+ :type mma_tiler_mnk: tuple[int, int, int]
1733
+ :param a_dtype: Data type of operand A.
1734
+ :type a_dtype: type[cutlass.Numeric]
1735
+ :param b_dtype: Data type of operand B.
1736
+ :type b_dtype: type[cutlass.Numeric]
1737
+ :param epi_tile: The epilogue tile shape.
1738
+ :type epi_tile: cute.Tile
1739
+ :param d_dtype: Data type of the output operand D.
1741
+ :type d_dtype: type[cutlass.Numeric]
1742
+ :param d_layout: Layout enum of the output operand D.
1742
+ :type d_layout: cutlass.utils.LayoutEnum
1743
+ :param smem_capacity: Total available shared memory capacity in bytes.
1744
+ :type smem_capacity: int
1745
+ :param occupancy: Target number of CTAs per SM (occupancy).
1746
+ :type occupancy: int
1747
+
1748
+ :return: A tuple containing the computed number of stages for:
1749
+ (ACC stages, A/B operand stages, D stages, C stages)
1750
+ :rtype: tuple[int, int, int, int]
1751
+ """
1752
+ blockscaled = sf_dtype is not None
1753
+ # Default ACC stages
1754
+ if const_expr(not blockscaled):
1755
+ num_acc_stage = 2
1756
+ else:
1757
+ num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
1758
+
1759
+ # Default D stages
1760
+ num_d_stage = 2
1761
+ num_c_stage = 2 if c_dtype is not None else 0
1762
+
1763
+ # Calculate smem layout and size for one stage of A, B, and C
1764
+ a_smem_layout_staged_one = sm100_utils.make_smem_layout_a(
1765
+ tiled_mma,
1766
+ mma_tiler_mnk,
1767
+ a_dtype,
1768
+ 1, # temporary single stage, used only to compute the per-stage size
1769
+ )
1770
+ b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
1771
+ tiled_mma,
1772
+ mma_tiler_mnk,
1773
+ b_dtype,
1774
+ 1, # temporary single stage, used only to compute the per-stage size
1775
+ )
1776
+ d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
1777
+ c_smem_layout_staged_one = (
1778
+ sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1)
1779
+ if c_dtype is not None
1780
+ else None
1781
+ )
1782
+ if const_expr(blockscaled):
1783
+ sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(
1784
+ tiled_mma,
1785
+ mma_tiler_mnk,
1786
+ sf_vec_size,
1787
+ 1, # temporary single stage, used only to compute the per-stage size
1788
+ )
1789
+ sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(
1790
+ tiled_mma,
1791
+ mma_tiler_mnk,
1792
+ sf_vec_size,
1793
+ 1, # temporary single stage, used only to compute the per-stage size
1794
+ )
1795
+
1796
+ ab_bytes_per_stage = cute.size_in_bytes(
1797
+ a_dtype, a_smem_layout_staged_one
1798
+ ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
1799
+ if const_expr(blockscaled):
1800
+ ab_bytes_per_stage += cute.size_in_bytes(
1801
+ sf_dtype, sfa_smem_layout_staged_one
1802
+ ) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
1803
+ mbar_helpers_bytes = 1024
1804
+ d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
1805
+ epi_bytes = d_bytes_per_stage * num_d_stage
1806
+ if const_expr(c_dtype is not None):
1807
+ c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
1808
+ epi_bytes += c_bytes_per_stage * num_c_stage
1809
+
1810
+ # Calculate A/B/SFA/SFB stages:
1811
+ # Start with total smem per CTA (capacity / occupancy)
1812
+ # Subtract reserved bytes and initial C stages bytes
1813
+ # Divide remaining by bytes needed per A/B/SFA/SFB stage
1814
+ num_ab_stage = (
1815
+ smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)
1816
+ ) // ab_bytes_per_stage
1817
+
1818
+ # Refine epilogue stages:
1819
+ # Calculate remaining smem after allocating for A/B stages and reserved bytes
1820
+ # Add remaining unused smem to epilogue
1821
+ num_d_stage += (
1822
+ smem_capacity
1823
+ - occupancy * ab_bytes_per_stage * num_ab_stage
1824
+ - occupancy * (mbar_helpers_bytes + epi_bytes)
1825
+ ) // (occupancy * d_bytes_per_stage)
1826
+ return num_acc_stage, num_ab_stage, num_d_stage, num_c_stage
1827
+
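+ # NOTE: illustrative arithmetic only; all byte counts are assumptions picked for
+ # exposition, not measured hardware values. With a hypothetical per-CTA smem
+ # budget of 229376 B at occupancy 1, a 128x64 BF16 A tile plus a 128x64 BF16 B
+ # tile (ab_bytes_per_stage = 32768 B), a 128x64 BF16 D epi tile
+ # (d_bytes_per_stage = 16384 B, so epi_bytes = 32768 B for the 2 default D
+ # stages, no C operand) and mbar_helpers_bytes = 1024, the formulas above give
+ #     num_ab_stage = (229376 - (1024 + 32768)) // 32768 = 5
+ #     num_d_stage  = 2 + (229376 - 5 * 32768 - (1024 + 32768)) // 16384 = 3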
1828
+ @staticmethod
1829
+ def _compute_num_tmem_alloc_cols(
1830
+ tiled_mma: cute.TiledMma,
1831
+ mma_tiler: Tuple[int, int, int],
1832
+ num_acc_stage: int,
1833
+ ) -> int:
1834
+ """
1835
+ Compute the number of tensor memory allocation columns.
1836
+
1837
+ :param tiled_mma: The tiled MMA object defining the core computation.
1838
+ :type tiled_mma: cute.TiledMma
1839
+ :param mma_tiler: The shape (M, N, K) of the MMA tile.
1840
+ :type mma_tiler: tuple[int, int, int]
1841
+ :param num_acc_stage: The stage of the accumulator tensor.
1842
+ :type num_acc_stage: int
1843
+
1844
+ :return: The number of tensor memory allocation columns.
1845
+ :rtype: int
1846
+ """
1847
+ acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2])
1848
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage))
1849
+ num_tmem_alloc_cols = cutlass.utils.get_num_tmem_alloc_cols(tCtAcc_fake)
1850
+ return num_tmem_alloc_cols
1851
+
1852
+ @staticmethod
1853
+ def is_valid_dtypes(
1854
+ ab_dtype: Type[cutlass.Numeric],
1855
+ acc_dtype: Type[cutlass.Numeric],
1856
+ d_dtype: Type[cutlass.Numeric],
1857
+ ) -> bool:
1858
+ """
1859
+ Check if the dtypes are valid
1860
+
1861
+ :param ab_dtype: The data type of the A and B operands
1862
+ :type ab_dtype: Type[cutlass.Numeric]
1863
+ :param acc_dtype: The data type of the accumulator
1864
+ :type acc_dtype: Type[cutlass.Numeric]
1865
+ :param d_dtype: The data type of the output tensor
1866
+ :type d_dtype: Type[cutlass.Numeric]
1867
+
1868
+ :return: True if the dtypes are valid, False otherwise
1869
+ :rtype: bool
1870
+ """
1871
+ is_valid = True
1872
+ if ab_dtype not in {
1873
+ cutlass.Float16,
1874
+ cutlass.BFloat16,
1875
+ cutlass.TFloat32,
1876
+ cutlass.Uint8,
1877
+ cutlass.Int8,
1878
+ cutlass.Float8E4M3FN,
1879
+ cutlass.Float8E5M2,
1880
+ }:
1881
+ is_valid = False
1882
+ if (
1883
+ acc_dtype not in {cutlass.Float32, cutlass.Float16, Int32}
1884
+ or acc_dtype == cutlass.Float16
1885
+ and ab_dtype not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}
1886
+ or acc_dtype == Int32
1887
+ and ab_dtype not in {cutlass.Uint8, cutlass.Int8}
1888
+ ):
1889
+ is_valid = False
1890
+ if (
1891
+ acc_dtype == cutlass.Float32
1892
+ and d_dtype
1893
+ not in {
1894
+ cutlass.Float32,
1895
+ cutlass.Float16,
1896
+ cutlass.BFloat16,
1897
+ cutlass.Float8E4M3FN,
1898
+ cutlass.Float8E5M2,
1899
+ Int32,
1900
+ cutlass.Int8,
1901
+ cutlass.Uint8,
1902
+ }
1903
+ or acc_dtype == cutlass.Float16
1904
+ and d_dtype
1905
+ not in {
1906
+ cutlass.BFloat16,
1907
+ cutlass.Float16,
1908
+ }
1909
+ or acc_dtype == Int32
1910
+ and d_dtype
1911
+ not in {
1912
+ cutlass.BFloat16,
1913
+ cutlass.Float16,
1914
+ cutlass.Float32,
1915
+ Int32,
1916
+ cutlass.Int8,
1917
+ cutlass.Uint8,
1918
+ }
1919
+ ):
1920
+ is_valid = False
1921
+ return is_valid
1922
+
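+ # Illustrative spot checks of the rules above (for exposition, not exhaustive):
+ #     is_valid_dtypes(cutlass.BFloat16, cutlass.Float32, cutlass.BFloat16)  # True
+ #     is_valid_dtypes(cutlass.BFloat16, cutlass.Float16, cutlass.BFloat16)  # False: f16 accumulation needs f16/f8 inputs
+ #     is_valid_dtypes(cutlass.Int8, Int32, cutlass.Float16)                 # True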
1923
+ @staticmethod
1924
+ def is_valid_dtypes_and_scale_factor_vec_size(
1925
+ ab_dtype: Type[cutlass.Numeric],
1926
+ sf_dtype: Type[cutlass.Numeric],
1927
+ sf_vec_size: int,
1928
+ d_dtype: Type[cutlass.Numeric],
1929
+ ) -> bool:
1930
+ """
1931
+ Check if the dtypes and sf_vec_size are valid combinations
1932
+
1933
+ :param ab_dtype: The data type of the A and B operands
1934
+ :type ab_dtype: Type[cutlass.Numeric]
1935
+ :param sf_dtype: The data type of the scale factor
1936
+ :type sf_dtype: Type[cutlass.Numeric]
1937
+ :param sf_vec_size: The vector size of the scale factor
1938
+ :type sf_vec_size: int
1939
+ :param d_dtype: The data type of the output tensor
1940
+ :type d_dtype: Type[cutlass.Numeric]
1941
+
1942
+ :return: True if the dtypes and sf_vec_size are valid, False otherwise
1943
+ :rtype: bool
1944
+ """
1945
+ is_valid = True
1946
+
1947
+ # Check valid ab_dtype
1948
+ if ab_dtype not in {cutlass.Float4E2M1FN, cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
1949
+ is_valid = False
1950
+
1951
+ # Check valid sf_vec_size
1952
+ if sf_vec_size not in {16, 32}:
1953
+ is_valid = False
1954
+
1955
+ # Check valid sf_dtype
1956
+ if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}:
1957
+ is_valid = False
1958
+
1959
+ # Check valid sf_dtype and sf_vec_size combinations
1960
+ if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32:
1961
+ is_valid = False
1962
+ if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16:
1963
+ is_valid = False
1964
+
1965
+ # Check valid d_dtype
1966
+ if d_dtype not in {
1967
+ cutlass.Float32,
1968
+ cutlass.Float16,
1969
+ cutlass.BFloat16,
1970
+ cutlass.Float8E5M2,
1971
+ cutlass.Float8E4M3FN,
1972
+ }:
1973
+ is_valid = False
1974
+
1975
+ return is_valid
1976
+
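+ # Illustrative spot checks of the rules above (for exposition, not exhaustive):
+ #     (Float4E2M1FN, Float8E8M0FNU, sf_vec_size=16, BFloat16)  -> True
+ #     (Float8E4M3FN, Float8E8M0FNU, sf_vec_size=16, BFloat16)  -> False: fp8 inputs require sf_vec_size == 32
+ #     (Float4E2M1FN, Float8E4M3FN,  sf_vec_size=32, BFloat16)  -> False: E4M3 scale factors pair only with sf_vec_size == 16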
1977
+ @staticmethod
1978
+ def is_valid_layouts(
1979
+ ab_dtype: Type[cutlass.Numeric],
1980
+ a_major: str,
1981
+ b_major: str,
1982
+ ) -> bool:
1983
+ """
1984
+ Check if the dtypes and sf_vec_size are valid combinations
1985
+
1986
+ :param ab_dtype: The data type of the A and B operands
1987
+ :type ab_dtype: Type[cutlass.Numeric]
1990
+ :param a_major: The major dimension of the A tensor
1991
+ :type a_major: str
1992
+ :param b_major: The major dimension of the B tensor
1993
+ :type b_major: str
1996
+
1997
+ :return: True if the layouts are valid, False otherwise
1998
+ :rtype: bool
1999
+ """
2000
+ is_valid = True
2001
+ if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
2002
+ is_valid = False
2003
+ return is_valid
2004
+
2005
+ @staticmethod
2006
+ def is_valid_mma_tiler_and_cluster_shape(
2007
+ use_2cta_instrs: bool,
2008
+ mma_tiler_mn: Tuple[int, int],
2009
+ cluster_shape_mn: Tuple[int, int],
2010
+ blockscaled: bool,
2011
+ ) -> bool:
2012
+ """
2013
+ Check if the mma tiler and cluster shape are valid
2014
+
2015
+ :param use_2cta_instrs: Whether to use 2 CTA groups
2016
+ :type use_2cta_instrs: bool
2017
+ :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
2018
+ :type mma_tiler_mn: Tuple[int, int]
2019
+ :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
2020
+ :type cluster_shape_mn: Tuple[int, int]
2021
+
2022
+ :return: True if the mma tiler and cluster shape are valid, False otherwise
2023
+ :rtype: bool
2024
+ """
2025
+ is_valid = True
2026
+ # Skip invalid mma tile shape
2027
+ if not (
2028
+ (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
2029
+ or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
2030
+ ):
2031
+ is_valid = False
2032
+ if not blockscaled:
2033
+ if mma_tiler_mn[1] not in range(32, 257, 32):
2034
+ is_valid = False
2035
+ else:
2036
+ if mma_tiler_mn[1] not in [128, 256]:
2037
+ is_valid = False
2038
+ # Skip illegal cluster shape
2039
+ if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
2040
+ is_valid = False
2041
+ # Skip invalid cluster shape
2042
+ is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
2043
+ if (
2044
+ cluster_shape_mn[0] * cluster_shape_mn[1] > 16
2045
+ or cluster_shape_mn[0] <= 0
2046
+ or cluster_shape_mn[1] <= 0
2047
+ or not is_power_of_2(cluster_shape_mn[0])
2048
+ or not is_power_of_2(cluster_shape_mn[1])
2049
+ ):
2050
+ is_valid = False
2051
+ if blockscaled:
2052
+ # Special cluster shape check for scale factor multicasts.
2053
+ # Due to limited size of scale factors, we can't multicast among more than 4 CTAs.
2054
+ if cluster_shape_mn[0] > 4 or cluster_shape_mn[1] > 4:
2055
+ is_valid = False
2056
+ return is_valid
2057
+
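+ # Illustrative spot checks of the rules above (for exposition, not exhaustive):
+ #     use_2cta_instrs=False, mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1), blockscaled=False  -> True
+ #     use_2cta_instrs=True,  mma_tiler_mn=(256, 256), cluster_shape_mn=(1, 1), blockscaled=False  -> False (2-CTA needs an even ClusterM)
+ #     use_2cta_instrs=False, mma_tiler_mn=(128, 128), cluster_shape_mn=(4, 8), blockscaled=False  -> False (cluster limited to 16 CTAs)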
2058
+ @staticmethod
2059
+ def is_valid_tensor_alignment(
2060
+ m: int,
2061
+ n: int,
2062
+ k: int,
2063
+ l: int,
2064
+ ab_dtype: Type[cutlass.Numeric],
2065
+ d_dtype: Type[cutlass.Numeric],
2066
+ a_major: str,
2067
+ b_major: str,
2068
+ d_major: str,
2069
+ ) -> bool:
2070
+ """
2071
+ Check if the tensor alignment is valid
2072
+
2073
+ :param m: The number of rows in the A tensor
2074
+ :type m: int
2075
+ :param n: The number of columns in the B tensor
2076
+ :type n: int
2077
+ :param k: The number of columns in the A tensor
2078
+ :type k: int
2079
+ :param l: The batch dimension L
2080
+ :type l: int
2081
+ :param ab_dtype: The data type of the A and B operands
2082
+ :type ab_dtype: Type[cutlass.Numeric]
2083
+ :param d_dtype: The data type of the output tensor
2084
+ :type d_dtype: Type[cutlass.Numeric]
2085
+ :param a_major: The major axis of the A tensor
2086
+ :type a_major: str
2087
+ :param b_major: The major axis of the B tensor
2088
+ :type b_major: str
2089
+ :param d_major: The major axis of the C tensor
2090
+ :type d_major: str
2091
+
2092
+ :return: True if the problem shape is valid, False otherwise
2093
+ :rtype: bool
2094
+ """
2095
+ is_valid = True
2096
+
2097
+ def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape):
2098
+ major_mode_idx = 0 if is_mode0_major else 1
2099
+ num_major_elements = tensor_shape[major_mode_idx]
2100
+ num_contiguous_elements = 16 * 8 // dtype.width
2101
+ return num_major_elements % num_contiguous_elements == 0
2102
+
2103
+ if (
2104
+ not check_contiguous_16B_alignment(ab_dtype, a_major == "m", (m, k, l))
2105
+ or not check_contiguous_16B_alignment(ab_dtype, b_major == "n", (n, k, l))
2106
+ or not check_contiguous_16B_alignment(d_dtype, d_major == "m", (m, n, l))
2107
+ ):
2108
+ is_valid = False
2109
+ return is_valid
2110
+
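+ # Illustrative check (for exposition): 16-byte contiguous alignment means a
+ # 16-bit dtype such as BFloat16 needs its major extent to be a multiple of
+ # 16 * 8 // 16 = 8 elements (16 elements for an 8-bit dtype). For example,
+ # m = n = k = 256 with k-major A/B and n-major D passes, while k = 250 with a
+ # k-major BF16 A operand fails (250 % 8 != 0).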
2111
+ @staticmethod
2112
+ def can_implement(
2113
+ ab_dtype: Type[cutlass.Numeric],
2114
+ acc_dtype: Type[cutlass.Numeric],
2115
+ d_dtype: Type[cutlass.Numeric],
2116
+ use_2cta_instrs: bool,
2117
+ mma_tiler_mn: Tuple[int, int],
2118
+ cluster_shape_mn: Tuple[int, int],
2119
+ m: int,
2120
+ n: int,
2121
+ k: int,
2122
+ l: int,
2123
+ a_major: str,
2124
+ b_major: str,
2125
+ d_major: str,
2126
+ ) -> bool:
2127
+ """
2128
+ Check if the gemm can be implemented
2129
+
2130
+ :param ab_dtype: The data type of the A and B operands
2131
+ :type ab_dtype: Type[cutlass.Numeric]
2132
+ :param acc_dtype: The data type of the accumulator
2133
+ :type acc_dtype: Type[cutlass.Numeric]
2134
+ :param d_dtype: The data type of the output tensor
2135
+ :type d_dtype: Type[cutlass.Numeric]
2136
+ :param use_2cta_instrs: Whether to use 2 CTA groups
2137
+ :type use_2cta_instrs: bool
2138
+ :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
2139
+ :type mma_tiler_mn: Tuple[int, int]
2140
+ :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
2141
+ :type cluster_shape_mn: Tuple[int, int]
2142
+ :param m: The number of rows in the A tensor
2143
+ :type m: int
2144
+ :param n: The number of columns in the B tensor
2145
+ :type n: int
2146
+ :param k: The number of columns in the A tensor
2147
+ :type k: int
2148
+ :param l: The batch dimension L
2149
+ :type l: int
2150
+ :param a_major: The major axis of the A tensor
2151
+ :type a_major: str
2152
+ :param b_major: The major axis of the B tensor
2153
+ :type b_major: str
2154
+ :param d_major: The major axis of the C tensor
2155
+ :type d_major: str
2156
+
2157
+ :return: True if the gemm can be implemented, False otherwise
2158
+ :rtype: bool
2159
+ """
2160
+ can_implement = True
2161
+ # Skip unsupported types
2162
+ if not PersistentDenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, d_dtype):
2163
+ can_implement = False
2164
+ # Skip invalid mma tile shape and cluster shape
2165
+ if not PersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(
2166
+ use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, blockscaled=False
2167
+ ):
2168
+ can_implement = False
2169
+ # Skip illegal problem shape for load/store alignment
2170
+ if not PersistentDenseGemmKernel.is_valid_tensor_alignment(
2171
+ m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major
2172
+ ):
2173
+ can_implement = False
2174
+ return can_implement
2175
+
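+ # NOTE: illustrative sketch only; run() below gates the kernel launch on this
+ # check in exactly this form (the error message here is abbreviated):
+ #
+ #     if not PersistentDenseGemmKernel.can_implement(
+ #         ab_dtype, acc_dtype, d_dtype, use_2cta_instrs, mma_tiler_mn,
+ #         cluster_shape_mn, m, n, k, l, a_major, b_major, d_major,
+ #     ):
+ #         raise TypeError("unsupported configuration")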
2176
+
2177
+ def run(
2178
+ mnkl: Tuple[int, int, int, int],
2179
+ ab_dtype: Type[cutlass.Numeric],
2180
+ d_dtype: Type[cutlass.Numeric],
2181
+ c_dtype: Optional[Type[cutlass.Numeric]],
2182
+ acc_dtype: Type[cutlass.Numeric],
2183
+ a_major: str,
2184
+ b_major: str,
2185
+ d_major: str,
2186
+ c_major: str,
2187
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
2188
+ cluster_shape_mn: Tuple[int, int] = (2, 1),
2189
+ use_2cta_instrs: bool = True,
2190
+ tolerance: float = 1e-01,
2191
+ warmup_iterations: int = 0,
2192
+ iterations: int = 1,
2193
+ skip_ref_check: bool = False,
2194
+ dynamic_persistent: bool = False,
2195
+ **kwargs,
2196
+ ):
2197
+ """Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.
2198
+
2199
+ This function prepares input tensors, configures and launches the persistent GEMM kernel,
2200
+ optionally performs reference validation, and benchmarks the execution performance.
2201
+
2202
+ :param mnkl: Problem size (M, N, K, L)
2203
+ :type mnkl: Tuple[int, int, int, int]
2204
+ :param ab_dtype: Data type for input tensors A and B
2205
+ :type ab_dtype: Type[cutlass.Numeric]
2206
+ :param d_dtype: Data type for the output tensor D
2207
+ :type d_dtype: Type[cutlass.Numeric]
2208
+ :param acc_dtype: Data type for accumulation during matrix multiplication
2209
+ :type acc_dtype: Type[cutlass.Numeric]
2210
+ :param a_major/b_major/d_major: Memory layout (major mode) of tensors A, B, and D
2211
+ :type a_major/b_major/d_major: str
2212
+ :param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the
2213
+ default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters.
2214
+ :type mma_tiler_mn: Tuple[int, int], optional
2215
+ :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
2216
+ default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
2217
+ :type cluster_shape_mn: Tuple[int, int], optional
2218
+ :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
2219
+ will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
2220
+ :type use_2cta_instrs: bool, optional
2221
+ :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
2222
+ :type tolerance: float, optional
2223
+ :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
2224
+ :type warmup_iterations: int, optional
2225
+ :param iterations: Number of benchmark iterations to run, defaults to 1
2226
+ :type iterations: int, optional
2227
+ :param skip_ref_check: Whether to skip reference result validation, defaults to False
2228
+ :type skip_ref_check: bool, optional
2229
+ :raises RuntimeError: If CUDA GPU is not available
2230
+ :raises ValueError: If the configuration is invalid or unsupported by the kernel
2231
+ :return: None. Benchmark timings are printed to stdout.
2232
+ :rtype: None
2233
+ """
2234
+ print("Running Blackwell Persistent Dense GEMM test with:")
2235
+ print(f"mnkl: {mnkl}")
2236
+ print(f"AB dtype: {ab_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
2237
+ print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
2238
+ print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
2239
+ print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
2240
+ print(f"Tolerance: {tolerance}")
2241
+ print(f"Warmup iterations: {warmup_iterations}")
2242
+ print(f"Iterations: {iterations}")
2243
+ print(f"Skip reference checking: {skip_ref_check}")
2244
+
2245
+ assert not dynamic_persistent, "Dynamic persistent mode is not supported yet."
2246
+
2247
+ # Unpack parameters
2248
+ m, n, k, l = mnkl
2249
+
2250
+ # Skip unsupported testcase
2251
+ if not PersistentDenseGemmKernel.can_implement(
2252
+ ab_dtype,
2253
+ acc_dtype,
2254
+ d_dtype,
2255
+ use_2cta_instrs,
2256
+ mma_tiler_mn,
2257
+ cluster_shape_mn,
2258
+ m,
2259
+ n,
2260
+ k,
2261
+ l,
2262
+ a_major,
2263
+ b_major,
2264
+ d_major,
2265
+ ):
2266
+ raise TypeError(
2267
+ f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}"
2268
+ )
2269
+
2270
+ if not torch.cuda.is_available():
2271
+ raise RuntimeError("GPU is required to run this example!")
2272
+
2273
+ torch.manual_seed(1111)
2274
+
2275
+ # Create and permute tensor A/B/C
2276
+ def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True):
2277
+ # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
2278
+ # else: (l, mode0, mode1) -> (mode0, mode1, l)
2279
+ shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
2280
+ permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
2281
+ is_unsigned = dtype in {cutlass.Uint8}
2282
+ # Temporarily use uint8 as torch does not support fp8 type
2283
+ torch_dtype = cutlass_torch.dtype(dtype)
2284
+ gen_dtype = (
2285
+ torch_dtype
2286
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
2287
+ else torch.bfloat16
2288
+ )
2289
+
2290
+ # Create dtype torch tensor (cpu)
2291
+ torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor(
2292
+ shape,
2293
+ gen_dtype,
2294
+ permute_order=permute_order,
2295
+ # init_type=cutlass.torch.TensorInitType.RANDOM,
2296
+ # init_config=cutlass.torch.RandomInitConfig(
2297
+ # min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
2298
+ # ),
2299
+ init_type=cutlass.torch.TensorInitType.GAUSSIAN,
2300
+ init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
2301
+ ).to(torch_dtype)
2302
+ # Create dtype torch tensor (gpu)
2303
+ torch_tensor = torch_tensor_cpu.cuda()
2304
+
2305
+ # Create f32 torch tensor (cpu)
2306
+ f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
2307
+
2308
+ # Create dtype cute tensor (gpu)
2309
+ torch_tensor_view = (
2310
+ torch_tensor
2311
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
2312
+ else torch_tensor.view(torch.uint8)
2313
+ )
2314
+ cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
2315
+ cute_tensor.element_type = dtype
2316
+ if is_dynamic_layout:
2317
+ cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
2318
+ cute_tensor = cutlass_torch.convert_cute_tensor(
2319
+ f32_torch_tensor,
2320
+ cute_tensor,
2321
+ dtype,
2322
+ is_dynamic_layout=is_dynamic_layout,
2323
+ )
2324
+
2325
+ return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu
2326
+
2327
+ a_ref, mA, a_torch, a_torch_cpu = create_and_permute_tensor(
2328
+ l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
2329
+ )
2330
+ b_ref, mB, b_torch, b_torch_cpu = create_and_permute_tensor(
2331
+ l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
2332
+ )
2333
+ _, mD, d_torch, d_torch_cpu = create_and_permute_tensor(
2334
+ l, m, n, d_major == "m", d_dtype, is_dynamic_layout=True
2335
+ )
2336
+ if c_dtype is not None:
2337
+ c, mC, c_torch, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
2338
+ else:
2339
+ c, mC, c_torch = None, None, None
2340
+
2341
+ # Configure gemm kernel
2342
+ gemm = PersistentDenseGemmKernel(
2343
+ acc_dtype,
2344
+ use_2cta_instrs,
2345
+ mma_tiler_mn,
2346
+ cluster_shape_mn,
2347
+ )
2348
+
2349
+ # Compute max active clusters on current device
2350
+ hardware_info = cutlass.utils.HardwareInfo()
2351
+ max_active_clusters = hardware_info.get_max_active_clusters(
2352
+ cluster_shape_mn[0] * cluster_shape_mn[1]
2353
+ )
2354
+ if dynamic_persistent:
2355
+ tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda")
2356
+ else:
2357
+ tile_count_semaphore = None
2358
+
2359
+ # Get current CUDA stream from PyTorch
2360
+ torch_stream = torch.cuda.current_stream()
2361
+ # Get the raw stream pointer as a CUstream
2362
+ current_stream = cuda.CUstream(torch_stream.cuda_stream)
2363
+ # Compile gemm kernel
2364
+ compiled_gemm = cute.compile(
2365
+ gemm,
2366
+ mA,
2367
+ mB,
2368
+ mD,
2369
+ mC,
2370
+ make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
2371
+ if tile_count_semaphore is not None
2372
+ else None,
2373
+ max_active_clusters,
2374
+ current_stream,
2375
+ )
2376
+
2377
+ if not skip_ref_check:
2378
+ compiled_gemm(mA, mB, mD, mC, tile_count_semaphore, current_stream)
2379
+ if ab_dtype in {
2380
+ cutlass.Int8,
2381
+ cutlass.Uint8,
2382
+ cutlass.Float8E4M3FN,
2383
+ cutlass.Float8E5M2,
2384
+ }:
2385
+ ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu())
2386
+ else:
2387
+ ref = torch.einsum("mkl,nkl->mnl", a_ref, b_ref)
2388
+ if c is not None:
2389
+ ref = ref + c
2390
+ ref = ref.cpu()
2391
+
2392
+ # Copy gpu result back
2393
+ gpu_d = d_torch.cpu()
2394
+
2395
+ # Convert ref to c_type
2396
+ if d_dtype == cutlass.Float32:
2397
+ ref_d = ref
2398
+ elif d_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
2399
+ # m major: (l, n, m) -> (m, n, l)
2400
+ # n major: (l, m, n) -> (m, n, l)
2401
+ permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0)
2402
+ shape = (l, m, n) if d_major == "n" else (l, n, m)
2403
+ f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
2404
+ shape,
2405
+ torch.uint8,
2406
+ permute_order=permute_order,
2407
+ init_type=cutlass_torch.TensorInitType.SKIP,
2408
+ ).cuda()
2409
+ # Create dtype cute tensor (gpu)
2410
+ ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
2411
+ leading_dim=(1 if d_major == "n" else 0)
2412
+ )
2413
+ ref_d_tensor.element_type = d_dtype
2414
+ ref_d_tensor = cutlass_torch.convert_cute_tensor(
2415
+ ref,
2416
+ ref_d_tensor,
2417
+ d_dtype,
2418
+ is_dynamic_layout=True,
2419
+ )
2420
+
2421
+ ref_d = f8_torch_tensor.cpu()
2422
+ else:
2423
+ ref_d = ref.to(cutlass_torch.dtype(d_dtype))
2424
+
2425
+ # Reference checking ref_d and gpu_d
2426
+ torch.testing.assert_close(gpu_d, ref_d, atol=tolerance, rtol=1e-05)
2427
+
2428
+ from triton.testing import do_bench
2429
+
2430
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
2431
+
2432
+ flops = 2 * m * n * k * l
2433
+
2434
+ repeats = iterations
2435
+ warmup = warmup_iterations
2436
+
2437
+ import time
2438
+
2439
+ time.sleep(0.5)
2440
+ if ab_dtype.width == 8:
2441
+ assert l == 1
2442
+ scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda")
2443
+ fn_cublas = lambda: torch._scaled_mm(
2444
+ a_torch[:, :, 0],
2445
+ b_torch[:, :, 0].mT,
2446
+ scale_a=scale_ab,
2447
+ scale_b=scale_ab,
2448
+ out_dtype=torch.bfloat16,
2449
+ # use_fast_accum=fp8_fast_accum,
2450
+ )
2451
+ else:
2452
+ if c_torch is None:
2453
+ fn_cublas = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
2454
+ else:
2455
+ c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32
2456
+ fn_cublas = lambda: torch.baddbmm(
2457
+ c_torch_convert.permute(2, 0, 1),
2458
+ a_torch.permute(2, 0, 1),
2459
+ b_torch.permute(2, 0, 1).mT,
2460
+ )
2461
+ timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
2462
+ tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
2463
+ print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2464
+
2465
+ time.sleep(0.5)
2466
+ fn = lambda: compiled_gemm(mA, mB, mD, mC, tile_count_semaphore, current_stream)
2467
+ timing = do_bench(fn, warmup=warmup, rep=repeats)
2468
+ tflops = flops / (timing * 1e9) # Convert to TFlops
2469
+ print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
2470
+
2471
+ # time.sleep(0.5)
2472
+ # timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
2473
+ # tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
2474
+ # print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2475
+
2476
+
2477
+ if __name__ == "__main__":
2478
+
2479
+ def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
2480
+ try:
2481
+ return tuple(int(x.strip()) for x in s.split(","))
2482
+ except ValueError:
2483
+ raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")
2484
+
2485
+ parser = argparse.ArgumentParser(description="Example of Dense Persistent GEMM on Blackwell.")
2486
+
2487
+ parser.add_argument(
2488
+ "--mnkl",
2489
+ type=parse_comma_separated_ints,
2490
+ default=(256, 256, 512, 1),
2491
+ help="mnkl dimensions (comma-separated)",
2492
+ )
2493
+ parser.add_argument(
2494
+ "--mma_tiler_mn",
2495
+ type=parse_comma_separated_ints,
2496
+ default=(128, 128),
2497
+ help="Mma tile shape (comma-separated)",
2498
+ )
2499
+ parser.add_argument(
2500
+ "--cluster_shape_mn",
2501
+ type=parse_comma_separated_ints,
2502
+ default=(1, 1),
2503
+ help="Cluster shape (comma-separated)",
2504
+ )
2505
+ parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
2506
+ parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
2507
+ parser.add_argument("--c_dtype", type=cutlass.dtype, default=None)
2508
+ parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
2509
+ parser.add_argument(
2510
+ "--use_2cta_instrs",
2511
+ action="store_true",
2512
+ help="Enable 2CTA MMA instructions feature",
2513
+ )
2514
+ parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
2515
+ parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
2516
+ parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
2517
+ parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
2518
+
2519
+ parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation")
2520
+ parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
2521
+ parser.add_argument(
2522
+ "--iterations",
2523
+ type=int,
2524
+ default=30,
2525
+ help="Number of iterations to run the kernel",
2526
+ )
2527
+ parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
2528
+ parser.add_argument(
2529
+ "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel"
2530
+ )
2531
+
2532
+ args = parser.parse_args()
2533
+
2534
+ if len(args.mnkl) != 4:
2535
+ parser.error("--mnkl must contain exactly 4 values")
2536
+
2537
+ if len(args.mma_tiler_mn) != 2:
2538
+ parser.error("--mma_tiler_mn must contain exactly 2 values")
2539
+
2540
+ if len(args.cluster_shape_mn) != 2:
2541
+ parser.error("--cluster_shape_mn must contain exactly 2 values")
2542
+
2543
+ run(
2544
+ args.mnkl,
2545
+ args.ab_dtype,
2546
+ args.d_dtype,
2547
+ args.c_dtype,
2548
+ args.acc_dtype,
2549
+ args.a_major,
2550
+ args.b_major,
2551
+ args.d_major,
2552
+ args.c_major,
2553
+ args.mma_tiler_mn,
2554
+ args.cluster_shape_mn,
2555
+ args.use_2cta_instrs,
2556
+ args.tolerance,
2557
+ args.warmup_iterations,
2558
+ args.iterations,
2559
+ args.skip_ref_check,
2560
+ args.dynamic_persistent,
2561
+ )
2562
+ print("PASS")