quack-kernels 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/dense_gemm_sm90.py CHANGED
@@ -27,21 +27,37 @@
27
27
  # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
28
 
29
29
  import argparse
30
- from typing import Tuple, Type
30
+ import enum
31
+ from typing import Tuple, Type, Callable, Optional
32
+ from functools import partial
31
33
  import math
34
+
32
35
  import cuda.bindings.driver as cuda
33
36
 
34
37
  import torch
35
38
 
36
39
  import cutlass
37
40
  import cutlass.cute as cute
38
- import cutlass.cute.testing as testing
39
- import cutlass.utils as utils
40
41
  import cutlass.pipeline as pipeline
41
42
  import cutlass.torch as cutlass_torch
42
- from cutlass.cute.runtime import from_dlpack
43
+ from cutlass.cute.runtime import from_dlpack, make_ptr
43
44
  from cutlass.cute.nvgpu import cpasync, warp, warpgroup
44
45
  import cutlass.utils.hopper_helpers as sm90_utils
46
+ from cutlass import Int32, const_expr
47
+
48
+ from quack.tile_scheduler import (
49
+ TileSchedulerArguments,
50
+ TileScheduler,
51
+ VarlenMTileSchedulerArguments,
52
+ VarlenMTileScheduler,
53
+ ParamsBase,
54
+ RasterOrderOption,
55
+ )
56
+ from quack.tensormap_manager import TensorMapManagerSm90
57
+
58
+ # make_pipeline_state returns PipelineStateWAdvance instead of PipelineState
59
+ from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
60
+ import quack.utils as utils
45
61
 
46
62
  """
47
63
  A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -154,6 +170,11 @@ def parse_arguments() -> argparse.Namespace:
154
170
  type=cutlass.dtype,
155
171
  default=cutlass.BFloat16,
156
172
  )
173
+ parser.add_argument(
174
+ "--c_dtype",
175
+ type=cutlass.dtype,
176
+ default=None,
177
+ )
157
178
  parser.add_argument(
158
179
  "--acc_dtype",
159
180
  type=cutlass.dtype,
@@ -162,21 +183,24 @@ def parse_arguments() -> argparse.Namespace:
162
183
  parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
163
184
  parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
164
185
  parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
165
- parser.add_argument("--tolerance", type=float, default=1e-01, help="Tolerance for validation")
166
- parser.add_argument("--warmup_iterations", type=int, default=0, help="Warmup iterations")
186
+ parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
187
+ parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation")
188
+ parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
167
189
  parser.add_argument(
168
190
  "--iterations",
169
191
  type=int,
170
- default=1,
192
+ default=30,
171
193
  help="Number of iterations to run the kernel",
172
194
  )
173
- parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
195
+ parser.add_argument("--persistent", action="store_true", help="Persistent kernel")
174
196
  parser.add_argument(
175
- "--use_cold_l2",
176
- action="store_true",
177
- default=False,
178
- help="Use circular buffer tensor sets to ensure L2 cold cache",
197
+ "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel"
179
198
  )
199
+ parser.add_argument("--pingpong", action="store_true", help="Pingpong kernel")
200
+ parser.add_argument("--varlen_m", action="store_true", help="Variable length M dimension")
201
+ parser.add_argument("--gather_A", action="store_true", help="Gather A")
202
+ parser.add_argument("--fp8_fast_accum", action="store_true", help="FP8 fast accum")
203
+ parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
180
204
 
181
205
  args = parser.parse_args()
182
206
 
@@ -195,6 +219,17 @@ def parse_arguments() -> argparse.Namespace:
195
219
  # /////////////////////////////////////////////////////////////////////////////
196
220
 
197
221
 
222
+ class NamedBarrierGemm(enum.IntEnum):
223
+ Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
224
+ # For mainloop load warps to signal that the epilogue load warp can start.
225
+ # This is to avoid loading C too early, interfering with loading A and B.
226
+ EpilogueLoad = enum.auto()
227
+ MmaWG0 = enum.auto()
228
+ MmaWG1 = enum.auto()
229
+ EpiWG0 = enum.auto()
230
+ EpiWG1 = enum.auto()
231
+
232
+
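
The new enum relies on `enum.auto()` numbering from 1, which keeps hardware barrier 0 free for plain `sync_threads()`. A standalone pure-Python check of the resulting barrier IDs (no CUTLASS required):

```python
import enum

class NamedBarrierGemm(enum.IntEnum):
    Epilogue = enum.auto()      # 1
    EpilogueLoad = enum.auto()  # 2
    MmaWG0 = enum.auto()        # 3
    MmaWG1 = enum.auto()        # 4
    EpiWG0 = enum.auto()        # 5
    EpiWG1 = enum.auto()        # 6

# Barrier 0 stays reserved for sync_threads(); named barriers start at 1.
assert min(int(b) for b in NamedBarrierGemm) == 1
print({b.name: int(b) for b in NamedBarrierGemm})
```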
198
233
  class HopperWgmmaGemmKernel:
199
234
  """
200
235
  This class implements batched matrix multiplication (C = A x B) with support for various data types
@@ -221,9 +256,6 @@ class HopperWgmmaGemmKernel:
221
256
  - Float32 (for all floating point inputs)
222
257
 
223
258
  :note: Constraints:
224
- - CTA tile M must be 64/128
225
- - CTA tile N must be 64/128/256
226
- - CTA tile K must be 64
227
259
  - Cluster shape M/N must be positive and power of 2, total cluster size <= 4
228
260
 
229
261
  Example:
@@ -235,11 +267,19 @@ class HopperWgmmaGemmKernel:
235
267
  >>> gemm(a_tensor, b_tensor, c_tensor, stream)
236
268
  """
237
269
 
270
+ bytes_per_tensormap = 128
271
+ num_tensormaps = 1 # For D only
272
+
238
273
  def __init__(
239
274
  self,
240
275
  acc_dtype: Type[cutlass.Numeric],
276
+ a_dtype: Type[cutlass.Numeric],
241
277
  tile_shape_mnk: Tuple[int, int, int],
242
278
  cluster_shape_mnk: Tuple[int, int, int],
279
+ pingpong: bool = False,
280
+ is_persistent: bool = True,
281
+ fp8_fast_accum: bool = False,
282
+ gather_A: bool = False,
243
283
  ):
244
284
  """
245
285
  Initializes the configuration for a Hopper dense GEMM kernel.
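
A hedged sketch of constructing the kernel with the new options; the dtypes and tile/cluster shapes below are illustrative only, not recommendations:

```python
import cutlass

# Illustrative configuration only; see parse_arguments() for the full set of knobs.
gemm = HopperWgmmaGemmKernel(
    acc_dtype=cutlass.Float32,
    a_dtype=cutlass.BFloat16,   # new: lets the kernel decide on the FP8 slow-accum path
    tile_shape_mnk=(128, 256, 64),
    cluster_shape_mnk=(1, 1, 1),
    pingpong=False,             # new: pingpong schedule (requires is_persistent=True)
    is_persistent=True,         # new: persistent tile scheduler
    fp8_fast_accum=False,       # new: keep the FP32 shadow accumulator for FP8 inputs
    gather_A=False,             # new: gather rows of A through an index tensor (mAIdx)
)
```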
@@ -256,52 +296,101 @@ class HopperWgmmaGemmKernel:
256
296
  """
257
297
 
258
298
  self.acc_dtype = acc_dtype
299
+ self.pingpong = pingpong
300
+ self.is_persistent = is_persistent
301
+ if self.pingpong:
302
+ assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
303
+ self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
304
+ self.gather_A = gather_A
305
+ if gather_A:
306
+ assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
307
+ self.tensormap_update_mode = cutlass.utils.TensorMapUpdateMode.SMEM
259
308
 
260
309
  self.cluster_shape_mnk = cluster_shape_mnk
261
310
  self.tile_shape_mnk = tuple(tile_shape_mnk)
262
311
  tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1]
263
312
  # check the cta tile shape
264
- # if tile_M not in [64, 128, 192, 256]:
265
- # TODO: M=192 currently doesn't work
266
- if tile_M not in [64, 128, 256]:
267
- raise ValueError("CTA tile shape M must be 64/128/192/256")
268
- if tile_M == 192: # special case
269
- if not (tile_N % 32 == 0 and tile_N <= 288):
270
- raise ValueError(
271
- "If tile_m == 192, CTA tile shape N must be divisible by 32 and <= 288"
272
- )
313
+ if not self.pingpong:
314
+ if tile_M not in [64, 128, 192, 256, 320]:
315
+ raise ValueError("CTA tile shape M must be 64/128/192/256/320")
316
+ if tile_M in [192, 320]: # special case
317
+ tile_N_max = 256 if tile_M == 192 else 160
318
+ if not (tile_N % 32 == 0 and tile_N <= tile_N_max):
319
+ raise ValueError(
320
+ f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}"
321
+ )
322
+ else:
323
+ if not (
324
+ (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
325
+ ):
326
+ raise ValueError(
327
+ "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
328
+ )
273
329
  else:
274
- if not ((tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)):
275
- raise ValueError(
276
- "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
277
- )
330
+ if tile_M not in [64, 128, 192]:
331
+ raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
332
+ tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
333
+ if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
334
+ raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
278
335
  if not self.tile_shape_mnk[2] % 16 == 0:
279
336
  raise ValueError("CTA tile shape K must be divisible by 16")
280
337
 
281
- if tile_M == 192: # Special case
282
- atom_layout_m, atom_layout_n = 1, 2
338
+ if not self.pingpong:
339
+ if tile_M == 320: # tile_M / 64 is not even so we have to split along N
340
+ atom_layout_m, atom_layout_n = 1, 2
341
+ elif tile_M == 192:
342
+ if tile_N <= 128:
343
+ atom_layout_m, atom_layout_n = 3, 1
344
+ else:
345
+ atom_layout_m, atom_layout_n = 1, 2
346
+ else:
347
+ atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
348
+ atom_layout_n = 1
349
+ assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
283
350
  else:
284
- atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
285
- atom_layout_n = 1
286
- assert atom_layout_m in [1, 2] and atom_layout_n in [1, 2]
351
+ atom_layout_m, atom_layout_n = 1, 1
287
352
  self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
288
353
 
289
- self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
354
+ self.num_mcast_ctas_a = self.cluster_shape_mnk[1] if not self.gather_A else 1
290
355
  self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
291
356
  self.is_a_mcast = self.num_mcast_ctas_a > 1
292
357
  self.is_b_mcast = self.num_mcast_ctas_b > 1
293
358
 
294
359
  self.occupancy = 1
295
- self.mma_warp_groups = math.prod(self.atom_layout_mnk)
360
+ self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
361
+ if self.pingpong:
362
+ assert self.mma_warp_groups == 2
363
+ assert self.mma_warp_groups in [1, 2, 3]
296
364
  self.num_threads_per_warp_group = 128
297
365
  self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
298
- self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
299
- self.num_mma_threads = self.mma_warp_groups * self.num_threads_per_warp_group
300
-
301
- regs_per_thread = math.prod(self.tile_shape_mnk) // self.num_mma_threads
302
- heavy_register_pressure = regs_per_thread >= 208
303
- self.num_regs_load = 40 if not heavy_register_pressure else 24
304
- self.num_regs_mma = 232 if not heavy_register_pressure else 240
366
+ self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
367
+ self.num_epi_threads = (
368
+ self.mma_warp_groups if not self.pingpong else 1
369
+ ) * self.num_threads_per_warp_group
370
+ self.num_ab_load_warps = 1 if not self.gather_A else 4
371
+ self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
372
+ self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
373
+ self.ab_load_warp_id = self.mma_warp_groups * 4
374
+ self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
375
+
376
+ regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // (
377
+ math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
378
+ )
379
+ if self.fp8_slow_accum:
380
+ regs_per_thread *= 2
381
+ if not self.gather_A:
382
+ if self.mma_warp_groups == 3:
383
+ self.num_regs_load, self.num_regs_mma = 32, 160
384
+ else:
385
+ heavy_register_pressure = regs_per_thread >= 208
386
+ self.num_regs_load, self.num_regs_mma = (
387
+ (40, 232) if not heavy_register_pressure else (24, 240)
388
+ )
389
+ else:
390
+ if self.mma_warp_groups == 3:
391
+ self.num_regs_load, self.num_regs_mma = 56, 152
392
+ else:
393
+ self.num_regs_load, self.num_regs_mma = (56, 224)
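
The producer/consumer register split above can be sanity-checked with a plain-Python restatement of the arithmetic (a sketch that mirrors, but is not, the constructor code):

```python
import math

def pick_register_split(tile_shape_mnk, atom_layout_mnk, fp8_slow_accum,
                        gather_A, mma_warp_groups):
    """Illustrative restatement of the register budgeting in __init__."""
    # Accumulator registers per thread: one fp32 value per element of the
    # (tile_M x tile_N) output owned by each of the 128 threads in a warp group.
    regs_per_thread = math.prod(tile_shape_mnk[:2]) // (
        math.prod(atom_layout_mnk) * 128
    )
    if fp8_slow_accum:
        regs_per_thread *= 2  # extra fp32 shadow accumulator
    if not gather_A:
        if mma_warp_groups == 3:
            return 32, 160
        heavy = regs_per_thread >= 208
        return (24, 240) if heavy else (40, 232)
    else:
        return (56, 152) if mma_warp_groups == 3 else (56, 224)

# e.g. a 128x256 tile with a (2, 1, 1) atom layout and fp16/bf16 inputs:
print(pick_register_split((128, 256, 64), (2, 1, 1), False, False, 2))  # (40, 232)
```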
305
394
 
306
395
  self.ab_stage = None
307
396
  self.epi_stage = None
@@ -328,26 +417,34 @@ class HopperWgmmaGemmKernel:
328
417
  - Computing A/B/C shared memory layout
329
418
  """
330
419
 
331
- self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
420
+ self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
332
421
 
333
- is_cooperative = math.prod(self.atom_layout_mnk) > 1
334
422
  self.epi_tile = self._sm90_compute_tile_shape_or_override(
335
- self.tile_shape_mnk, self.d_dtype, is_cooperative=is_cooperative
423
+ self.tile_shape_mnk,
424
+ self.atom_layout_mnk,
425
+ self.d_dtype,
336
426
  )
337
427
 
338
428
  # Compute stage before compute smem layout
339
- self.ab_stage, self.epi_stage = self._compute_stages(
429
+ self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
340
430
  self.tile_shape_mnk,
431
+ self.epi_tile,
341
432
  self.a_dtype,
342
433
  self.b_dtype,
434
+ self.d_dtype,
435
+ self.c_dtype,
343
436
  self.smem_capacity,
344
437
  self.occupancy,
438
+ # epi_smem will reuse the A/B smem if not persistent.
439
+ overlap_sD_sA=not self.is_persistent,
345
440
  )
441
+ self.sched_stage = 2 if self.pingpong else 1
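
How many pipeline stages fit is decided inside `_compute_stages` (not shown in this diff); the comment above only notes that the non-persistent path aliases the D staging buffer onto A's smem. A purely illustrative byte accounting of that trade-off, assuming bf16 inputs and a 128x256x64 tile:

```python
# Hypothetical accounting only; the real kernel calls _compute_stages and
# cutlass.utils.get_smem_capacity_in_bytes("sm_90") instead of these constants.
tile_m, tile_n, tile_k = 128, 256, 64
a_stage_bytes = tile_m * tile_k * 2     # one bf16 A tile
b_stage_bytes = tile_n * tile_k * 2     # one bf16 B tile
epi_stage_bytes = 64 * 64 * 2           # one hypothetical (64, 64) bf16 D sub-tile
smem_capacity = 228 * 1024              # nominal SM90 shared memory per CTA

# Non-persistent: sD aliases sA, so the whole budget goes to A/B staging.
print(smem_capacity // (a_stage_bytes + b_stage_bytes))
# Persistent: the D (and optional C) epilogue buffers need their own carve-out.
print((smem_capacity - 2 * epi_stage_bytes) // (a_stage_bytes + b_stage_bytes))
```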
346
442
 
347
443
  (
348
444
  self.a_smem_layout_staged,
349
445
  self.b_smem_layout_staged,
350
446
  self.epi_smem_layout_staged,
447
+ self.epi_c_smem_layout_staged,
351
448
  ) = self._make_smem_layouts(
352
449
  self.tile_shape_mnk,
353
450
  self.epi_tile,
@@ -359,6 +456,9 @@ class HopperWgmmaGemmKernel:
359
456
  self.d_dtype,
360
457
  self.d_layout,
361
458
  self.epi_stage,
459
+ self.c_dtype,
460
+ self.c_layout,
461
+ self.epi_c_stage,
362
462
  )
363
463
 
364
464
  @cute.jit
@@ -367,6 +467,12 @@ class HopperWgmmaGemmKernel:
367
467
  mA: cute.Tensor,
368
468
  mB: cute.Tensor,
369
469
  mD: cute.Tensor,
470
+ mC: Optional[cute.Tensor],
471
+ mAIdx: Optional[cute.Tensor],
472
+ mCuSeqlensM: Optional[cute.Tensor],
473
+ mTensormaps: Optional[cute.Tensor],
474
+ tile_count_semaphore: Optional[cute.Pointer],
475
+ max_active_clusters: Int32,
370
476
  stream: cuda.CUstream,
371
477
  ):
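
A hedged sketch of invoking the widened entry point for a plain dense GEMM: every new optional argument is None, and `max_active_clusters` is assumed to come from the host-side occupancy query (the tensor names here are hypothetical):

```python
# Plain dense GEMM: no epilogue C, no gather, no varlen, no dynamic scheduling.
gemm(
    a_tensor, b_tensor, d_tensor,
    None,                 # mC: optional epilogue input added into D
    None,                 # mAIdx: row indices, only used with gather_A=True
    None,                 # mCuSeqlensM: cu_seqlens for variable-length M
    None,                 # mTensormaps: tensormap workspace for the varlen path
    None,                 # tile_count_semaphore: dynamic persistent scheduling only
    max_active_clusters,  # Int32 from the occupancy/cluster query
    stream,               # cuda.CUstream
)
```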
372
478
  """Execute the GEMM operation in steps:
@@ -390,16 +496,29 @@ class HopperWgmmaGemmKernel:
390
496
  self.a_dtype = mA.element_type
391
497
  self.b_dtype = mB.element_type
392
498
  self.d_dtype = mD.element_type
393
- self.a_layout = utils.LayoutEnum.from_tensor(mA)
394
- self.b_layout = utils.LayoutEnum.from_tensor(mB)
395
- self.d_layout = utils.LayoutEnum.from_tensor(mD)
499
+ self.c_dtype = mC.element_type if mC is not None else None
500
+ self.a_layout = cutlass.utils.LayoutEnum.from_tensor(mA)
501
+ self.b_layout = cutlass.utils.LayoutEnum.from_tensor(mB)
502
+ self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
503
+ self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None
396
504
 
397
- if cutlass.const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
505
+ if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
398
506
  raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
399
- if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width):
507
+ if const_expr(self.a_dtype.width != self.b_dtype.width):
400
508
  raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
401
- if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
509
+ if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
402
510
  raise TypeError("a_dtype should be float16 or float8")
511
+ assert (mAIdx is not None) == self.gather_A
512
+
513
+ # Assume all strides are divisible by 128 bits except the last stride
514
+ new_stride = lambda t: tuple(
515
+ cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
516
+ for s in t.stride
517
+ )
518
+ mA, mD = [
519
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
520
+ for t in (mA, mD)
521
+ ]
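
The `divby` hint above is expressed in elements, so the same 128-bit (16-byte) alignment guarantee shrinks as the element type widens; a quick check of the arithmetic:

```python
# 128-bit alignment expressed in elements of each dtype width.
for name, width_bits in [("float8", 8), ("bf16/fp16", 16), ("float32", 32)]:
    print(f"{name}: strides assumed divisible by {128 // width_bits} elements")
# float8 -> 16, bf16/fp16 -> 8, float32 -> 4
```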
403
522
 
404
523
  self._setup_attributes()
405
524
 
@@ -412,13 +531,31 @@ class HopperWgmmaGemmKernel:
412
531
  self.atom_layout_mnk,
413
532
  tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
414
533
  )
534
+ if const_expr(self.atom_layout_mnk[1] > 1):
535
+ # If N dimension is split among 2 WGs, we need to permute the N dimension so
536
+ # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
537
+ # containing accumulators that are next to each other in the N dimension.
538
+ # Without permutation WG0 would write to epi smem of size (64, 16) and
539
+ # WG1 would write to a separate epi smem of size (64, 16) that's far away.
540
+ atom_n = self.atom_layout_mnk[1]
541
+ permutation_n = cute.make_ordered_layout(
542
+ (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
543
+ )
544
+ tiled_mma = cute.make_tiled_mma(
545
+ cute.make_mma_atom(tiled_mma.op),
546
+ self.atom_layout_mnk,
547
+ permutation_mnk=(None, permutation_n, None),
548
+ )
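
A plain-Python model of the `permutation_n` index map, under the assumption that `make_ordered_layout` assigns compact strides in the stated order (0, 2, 1) and that the 1-D column index is split column-major over the shape; it shows how the two warp groups' columns end up interleaved in 8-wide chunks:

```python
def permute_n(n, tile_n, atom_n):
    # Model of make_ordered_layout((8, tile_n // atom_n // 8, atom_n), order=(0, 2, 1)):
    # strides are 1 for mode 0, 8 for mode 2, and 8 * atom_n for mode 1.
    i = n % 8                                # position within an 8-wide chunk
    j = (n // 8) % (tile_n // atom_n // 8)   # chunk index inside one warp group
    w = n // (tile_n // atom_n)              # which warp group owns the column
    return i + w * 8 + j * 8 * atom_n

tile_n, atom_n = 64, 2                                          # illustrative sizes
print([permute_n(n, tile_n, atom_n) for n in range(16)])        # WG0 -> 0-7, 16-23
print([permute_n(n, tile_n, atom_n) for n in range(32, 48)])    # WG1 -> 8-15, 24-31
```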
415
549
 
416
- tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
417
- mA,
418
- self.a_smem_layout_staged,
419
- (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
420
- self.cluster_shape_mnk[1],
421
- )
550
+ if const_expr(not self.gather_A):
551
+ tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
552
+ mA,
553
+ self.a_smem_layout_staged,
554
+ (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
555
+ self.cluster_shape_mnk[1],
556
+ )
557
+ else:
558
+ tma_atom_a, tma_tensor_a = None, None
422
559
 
423
560
  tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
424
561
  mB,
@@ -427,17 +564,84 @@ class HopperWgmmaGemmKernel:
427
564
  self.cluster_shape_mnk[0],
428
565
  )
429
566
 
430
- tma_atom_d, tma_tensor_d = self._make_tma_store_atoms_and_tensors(
431
- mD,
432
- self.epi_smem_layout_staged,
433
- self.epi_tile,
567
+ tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
568
+ mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
434
569
  )
435
570
 
436
- grid = self._compute_grid(mD, self.tile_shape_mnk, self.cluster_shape_mnk)
571
+ if const_expr(mC is not None):
572
+ tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
573
+ mC, self.epi_c_smem_layout_staged, self.epi_tile, store_or_load="load"
574
+ )
575
+ else:
576
+ tma_atom_c, tma_tensor_c = None, None
577
+
578
+ if const_expr(mCuSeqlensM is None):
579
+ problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (
580
+ mD.shape[2],
581
+ )
582
+ TileSchedulerCls = TileScheduler
583
+ tile_sched_args = TileSchedulerArguments(
584
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
585
+ raster_order=RasterOrderOption.Heuristic,
586
+ group_size=8,
587
+ cluster_shape_mnk=self.cluster_shape_mnk,
588
+ tile_count_semaphore=tile_count_semaphore,
589
+ is_persistent=self.is_persistent,
590
+ )
591
+ else:
592
+ assert mTensormaps is not None
593
+ problem_shape_ntile_mnl = (
594
+ None,
595
+ cute.ceil_div(mD.shape[1], self.tile_shape_mnk[1]),
596
+ mCuSeqlensM.shape[0] - 1,
597
+ )
598
+ TileSchedulerCls = VarlenMTileScheduler
599
+ tile_sched_args = VarlenMTileSchedulerArguments(
600
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
601
+ total_m=mD.shape[0],
602
+ cu_seqlens_m=mCuSeqlensM,
603
+ raster_order=RasterOrderOption.Heuristic,
604
+ group_size=8,
605
+ tile_shape_mnk=self.tile_shape_mnk,
606
+ cluster_shape_mnk=self.cluster_shape_mnk,
607
+ tile_count_semaphore=tile_count_semaphore,
608
+ is_persistent=self.is_persistent,
609
+ )
610
+ tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
611
+ grid = TileSchedulerCls.get_grid_shape(tile_sched_params, max_active_clusters)
612
+
613
+ epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if self.is_persistent else 0
614
+ epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
615
+
616
+ size_tensormap_in_i64 = (
617
+ 0
618
+ if mCuSeqlensM is None
619
+ or self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.GMEM
620
+ else HopperWgmmaGemmKernel.num_tensormaps
621
+ * HopperWgmmaGemmKernel.bytes_per_tensormap
622
+ // 8
623
+ ) * (1 if not self.pingpong else 2)
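
Two pieces of the host-side arithmetic above, restated in plain Python with made-up problem sizes:

```python
import math

# Number of output tiles handed to the tile scheduler (dense, non-varlen case).
M, N, L = 4096, 4096, 4                     # illustrative problem shape
tile_m, tile_n = 128, 256
problem_shape_ntile_mnl = (math.ceil(M / tile_m), math.ceil(N / tile_n), L)
print(problem_shape_ntile_mnl, math.prod(problem_shape_ntile_mnl))   # (32, 16, 4) 2048

# Int64 words of smem reserved per CTA for tensormap staging (varlen + SMEM mode only).
bytes_per_tensormap, num_tensormaps, pingpong = 128, 1, False
print(num_tensormaps * bytes_per_tensormap // 8 * (2 if pingpong else 1))  # 16
```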
437
624
 
438
625
  @cute.struct
439
626
  class SharedStorage:
440
- mainloop_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
627
+ tensormap_buffer: cute.struct.Align[
628
+ cute.struct.MemRange[cutlass.Int64, size_tensormap_in_i64],
629
+ 64,
630
+ ]
631
+ ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
632
+ epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
633
+ sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
634
+ tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
635
+ sD: cute.struct.Align[
636
+ cute.struct.MemRange[self.d_dtype, epi_smem_size],
637
+ self.buffer_align_bytes,
638
+ ]
639
+ sC: cute.struct.Align[
640
+ cute.struct.MemRange[
641
+ self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
642
+ ],
643
+ self.buffer_align_bytes,
644
+ ]
441
645
  sA: cute.struct.Align[
442
646
  cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
443
647
  self.buffer_align_bytes,
@@ -452,16 +656,25 @@ class HopperWgmmaGemmKernel:
452
656
  # Launch the kernel synchronously
453
657
  self.kernel(
454
658
  tma_atom_a,
455
- tma_tensor_a,
659
+ tma_tensor_a if const_expr(not self.gather_A) else mA,
456
660
  tma_atom_b,
457
661
  tma_tensor_b,
458
662
  tma_atom_d,
459
663
  tma_tensor_d,
664
+ mD,
665
+ tma_atom_c,
666
+ tma_tensor_c,
667
+ mAIdx,
668
+ mCuSeqlensM,
669
+ mTensormaps,
460
670
  tiled_mma,
461
- self.cta_layout_mnk,
671
+ self.cluster_layout_mnk,
462
672
  self.a_smem_layout_staged,
463
673
  self.b_smem_layout_staged,
464
674
  self.epi_smem_layout_staged,
675
+ self.epi_c_smem_layout_staged,
676
+ tile_sched_params,
677
+ TileSchedulerCls,
465
678
  ).launch(
466
679
  grid=grid,
467
680
  block=[self.threads_per_cta, 1, 1],
@@ -476,17 +689,26 @@ class HopperWgmmaGemmKernel:
476
689
  @cute.kernel
477
690
  def kernel(
478
691
  self,
479
- tma_atom_a: cute.CopyAtom,
692
+ tma_atom_a: Optional[cute.CopyAtom],
480
693
  mA_mkl: cute.Tensor,
481
694
  tma_atom_b: cute.CopyAtom,
482
695
  mB_nkl: cute.Tensor,
483
696
  tma_atom_d: cute.CopyAtom,
697
+ mD_mnl_tma: cute.Tensor,
484
698
  mD_mnl: cute.Tensor,
699
+ tma_atom_c: Optional[cute.CopyAtom],
700
+ mC_mnl: Optional[cute.Tensor],
701
+ mAIdx: Optional[cute.Tensor],
702
+ cu_seqlens_m: Optional[cute.Tensor],
703
+ tensormaps: Optional[cute.Tensor],
485
704
  tiled_mma: cute.TiledMma,
486
- cta_layout_mnk: cute.Layout,
705
+ cluster_layout_mnk: cute.Layout,
487
706
  a_smem_layout_staged: cute.ComposedLayout,
488
707
  b_smem_layout_staged: cute.ComposedLayout,
489
708
  epi_smem_layout_staged: cute.ComposedLayout,
709
+ epi_c_smem_layout_staged: cute.ComposedLayout,
710
+ tile_sched_params: ParamsBase,
711
+ TileSchedulerCls: cutlass.Constexpr[Callable],
490
712
  ):
491
713
  """
492
714
  GPU device kernel performing the batched GEMM computation.
@@ -501,12 +723,12 @@ class HopperWgmmaGemmKernel:
501
723
  :type mB_nkl: cute.Tensor
502
724
  :param tma_atom_d: TMA copy atom for D tensor
503
725
  :type tma_atom_d: cute.CopyAtom
504
- :param mD_mnl: Output tensor D
505
- :type mD_mnl: cute.Tensor
726
+ :param mD_mnl_tma: Output tensor D
727
+ :type mD_mnl_tma: cute.Tensor
506
728
  :param tiled_mma: Tiled MMA object
507
729
  :type tiled_mma: cute.TiledMma
508
- :param cta_layout_mnk: CTA layout
509
- :type cta_layout_mnk: cute.Layout
730
+ :param cluster_layout_mnk: CTA layout
731
+ :type cluster_layout_mnk: cute.Layout
510
732
  :param a_smem_layout_staged: Shared memory layout for A
511
733
  :type a_smem_layout_staged: cute.ComposedLayout
512
734
  :param b_smem_layout_staged: Shared memory layout for B
@@ -515,22 +737,25 @@ class HopperWgmmaGemmKernel:
515
737
  :type epi_smem_layout_staged: cute.ComposedLayout
516
738
  """
517
739
 
740
+ varlen = const_expr(cu_seqlens_m is not None)
518
741
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
519
742
 
520
743
  # /////////////////////////////////////////////////////////////////////////////
521
744
  # Prefetch Tma desc
522
745
  # /////////////////////////////////////////////////////////////////////////////
523
- # if warp_idx == 0:
524
- if warp_idx == self.mma_warp_groups * 4:
525
- cpasync.prefetch_descriptor(tma_atom_a)
746
+ if warp_idx == self.ab_load_warp_id:
747
+ if const_expr(tma_atom_a is not None):
748
+ cpasync.prefetch_descriptor(tma_atom_a)
526
749
  cpasync.prefetch_descriptor(tma_atom_b)
527
750
  cpasync.prefetch_descriptor(tma_atom_d)
751
+ if const_expr(tma_atom_c is not None):
752
+ cpasync.prefetch_descriptor(tma_atom_c)
528
753
 
529
754
  a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
530
755
  b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
531
- tma_copy_bytes = cute.size_in_bytes(self.a_dtype, a_smem_layout) + cute.size_in_bytes(
532
- self.b_dtype, b_smem_layout
533
- )
756
+ tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
757
+ if const_expr(not self.gather_A):
758
+ tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
534
759
 
535
760
  # /////////////////////////////////////////////////////////////////////////////
536
761
  # Alloc and init AB full/empty + ACC full mbar (pipeline)
@@ -539,163 +764,378 @@ class HopperWgmmaGemmKernel:
539
764
  storage = smem.allocate(self.shared_storage)
540
765
 
541
766
  # Threads/warps participating in this pipeline
542
- mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
543
- # Each warp will constribute to the arrive count with the number of mcast size
767
+ producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
768
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
769
+ # Each consumer warp contributes its mcast size to the arrive count
544
770
  mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
545
- consumer_arrive_cnt = mcast_size * (self.num_mma_threads // cute.arch.WARP_SIZE)
546
- mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
771
+ consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
772
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(
547
773
  pipeline.Agent.Thread, consumer_arrive_cnt
548
774
  )
549
775
 
550
- cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
551
- mainloop_pipeline = pipeline.PipelineTmaAsync.create(
552
- barrier_storage=storage.mainloop_pipeline_array_ptr.data_ptr(),
776
+ cta_layout_vmnk = cute.make_layout((1, *cluster_layout_mnk.shape))
777
+ pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
778
+ ab_pipeline = pipeline_cls.create(
779
+ barrier_storage=storage.ab_pipeline_array_ptr.data_ptr(),
553
780
  num_stages=self.ab_stage,
554
- producer_group=mainloop_pipeline_producer_group,
555
- consumer_group=mainloop_pipeline_consumer_group,
781
+ producer_group=ab_pipeline_producer_group,
782
+ consumer_group=ab_pipeline_consumer_group,
556
783
  tx_count=tma_copy_bytes,
557
784
  cta_layout_vmnk=cta_layout_vmnk,
558
785
  )
559
786
 
787
+ if const_expr(mC_mnl is not None):
788
+ # Threads/warps participating in this pipeline
789
+ epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
790
+ # Each warp will contribute 1 to the arrive count
791
+ consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
792
+ epi_pipeline_consumer_group = pipeline.CooperativeGroup(
793
+ pipeline.Agent.Thread, consumer_arrive_cnt
794
+ )
795
+ c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
796
+ tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
797
+ epi_pipeline = pipeline.PipelineTmaAsync.create(
798
+ barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
799
+ num_stages=self.epi_c_stage,
800
+ producer_group=epi_pipeline_producer_group,
801
+ consumer_group=epi_pipeline_consumer_group,
802
+ tx_count=tma_copy_c_bytes,
803
+ )
804
+ else:
805
+ epi_pipeline = None
806
+
807
+ if const_expr(tile_sched_params.tile_count_semaphore is not None):
808
+ # Dynamic persistent scheduler
809
+ # Threads/warps participating in this pipeline
810
+ sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
811
+ cluster_size = cute.size(cluster_layout_mnk)
812
+ # Each warp that is not the scheduler warp will contribute 1 to the arrive count
813
+ consumer_arrive_cnt = (
814
+ (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
815
+ ) * cluster_size - 1
816
+ sched_pipeline_consumer_group = pipeline.CooperativeGroup(
817
+ pipeline.Agent.Thread, consumer_arrive_cnt
818
+ )
819
+ sched_pipeline = pipeline.PipelineAsync.create(
820
+ barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
821
+ num_stages=self.sched_stage,
822
+ producer_group=sched_pipeline_producer_group,
823
+ consumer_group=sched_pipeline_consumer_group,
824
+ # If there's a cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
825
+ consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
826
+ )
827
+ tile_count = storage.tile_count.get_tensor((self.sched_stage,))
828
+ else:
829
+ sched_pipeline = None
830
+ tile_count = None
831
+
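
A quick numeric check of the scheduler-pipeline consumer count above, using one load warp, two MMA warp groups, and no cluster (values are illustrative):

```python
mma_warp_groups, pingpong, num_ab_load_warps, cluster_size = 2, False, 1, 1
consumer_arrive_cnt = (
    (mma_warp_groups if not pingpong else 1) * 4 + num_ab_load_warps
) * cluster_size - 1
# 8 MMA warps + 1 load warp = 9 warps, minus the scheduler warp itself.
print(consumer_arrive_cnt)  # 8
```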
560
832
  # ///////////////////////////////////////////////////////////////////////////////
561
833
  # Generate smem tensor A/B
562
834
  # ///////////////////////////////////////////////////////////////////////////////
563
835
  sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
564
836
  sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
565
- sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
566
- sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
567
-
568
- # ///////////////////////////////////////////////////////////////////////////////
569
- # Get cta/warp/thread idx
570
- # ///////////////////////////////////////////////////////////////////////////////
571
-
572
- cidx, cidy, _ = cute.arch.cluster_idx()
573
- cdimx, cdimy, _ = cute.arch.cluster_dim()
574
- cluster_id = cidx + cdimx * cidy
837
+ if const_expr(not self.is_persistent):
838
+ sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
839
+ sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
840
+ else:
841
+ sD = storage.sD.get_tensor(
842
+ epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
843
+ )
844
+ if const_expr(mC_mnl is not None):
845
+ sC = storage.sC.get_tensor(
846
+ epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
847
+ )
848
+ else:
849
+ sC = None
850
+
851
+ # Get tensormap buffer address
852
+ if const_expr(varlen):
853
+ grid_dim = cute.arch.grid_dim()
854
+ bid = cute.arch.block_idx()
855
+ tensormap_workspace_idx = (
856
+ bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0]
857
+ )
858
+ # TODO: this is only for D, not for A/B
859
+ if const_expr(self.pingpong):
860
+ tensormap_workspace_idx = tensormap_workspace_idx * 2 + warp_idx // 4
861
+ tensormap_manager = TensorMapManagerSm90(
862
+ self.tensormap_update_mode, HopperWgmmaGemmKernel.bytes_per_tensormap
863
+ )
864
+ tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
865
+ tensormaps[tensormap_workspace_idx, None].iterator
866
+ )
867
+ if const_expr(self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.SMEM):
868
+ tensormap_smem_ptr = storage.tensormap_buffer.data_ptr()
869
+ tensormap_d_smem_ptr = tensormap_smem_ptr + (warp_idx // 4) * (
870
+ HopperWgmmaGemmKernel.bytes_per_tensormap // 8
871
+ )
872
+ # Need this, otherwise "expected tma descriptor pointer to have alignment at least 64, but got 8"
873
+ tensormap_d_smem_ptr = cute.make_ptr(
874
+ cutlass.Int64,
875
+ tensormap_d_smem_ptr.toint(),
876
+ cute.AddressSpace.smem,
877
+ assumed_align=64,
878
+ )
879
+ tensormap_d_init_ptr = tensormap_d_smem_ptr
880
+ else:
881
+ tensormap_d_smem_ptr = None
882
+ tensormap_d_init_ptr = tensormap_d_ptr
883
+ else:
884
+ tensormap_d_smem_ptr = None
885
+ tensormap_manager, tensormap_d_ptr, tensormap_d_init_ptr = None, None, None
575
886
 
576
- # CTA Swizzle to promote L2 data reuse
577
- group_size_m = 8
578
- s_shape = (
579
- (group_size_m, cdimx // group_size_m),
580
- cdimy,
887
+ TileSchedulerCls = partial(
888
+ TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
581
889
  )
582
- s_stride = ((1, cdimy * group_size_m), group_size_m)
583
- s_layout = cute.make_layout(s_shape, stride=s_stride)
584
- num_reg_cids = cute.size(s_shape)
585
- cid_m, cid_n = s_layout.get_flat_coord(cluster_id % num_reg_cids)
586
-
587
- # Deal with the tail part
588
- if cluster_id >= num_reg_cids:
589
- tail_size_m = cdimx % group_size_m
590
- tail_layout = cute.make_layout((tail_size_m, cdimy), stride=(1, tail_size_m))
591
- tail_cid = cluster_id - num_reg_cids
592
- tail_cid_m, tail_cid_n = tail_layout.get_flat_coord(tail_cid)
593
- cid_m = cute.size(s_shape, mode=[0]) + tail_cid_m
594
- cid_n = tail_cid_n
595
-
596
- # Get the pid from cluster id
597
- bidx_in_cluster = cute.arch.block_in_cluster_idx()
598
- pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0]
599
- pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1]
600
-
601
- _, _, bidz = cute.arch.block_idx()
602
- tile_coord_mnkl = (pid_m, pid_n, None, bidz)
603
- cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
604
- cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
605
890
 
606
891
  k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
892
+ c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
607
893
 
608
- if warp_idx >= self.mma_warp_groups * 4:
894
+ if warp_idx >= self.ab_load_warp_id:
609
895
  cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
610
- if warp_idx == self.mma_warp_groups * 4:
896
+ if const_expr(mC_mnl is not None):
897
+ epi_load_barrier = pipeline.NamedBarrier(
898
+ barrier_id=int(NamedBarrierGemm.EpilogueLoad),
899
+ num_threads=self.num_ab_load_threads + self.num_epi_load_threads,
900
+ )
901
+ else:
902
+ epi_load_barrier = None
903
+ if (
904
+ warp_idx >= self.ab_load_warp_id
905
+ and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
906
+ ):
611
907
  # ///////////////////////////////////////////////////////////////////////////////
612
908
  # Get mcast mask
613
909
  # ///////////////////////////////////////////////////////////////////////////////
910
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
911
+ cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
614
912
  a_mcast_mask = cute.make_layout_image_mask(
615
- cta_layout_mnk, cluster_coord_mnk, mode=1
913
+ cluster_layout_mnk, cluster_coord_mnk, mode=1
616
914
  )
617
915
  b_mcast_mask = cute.make_layout_image_mask(
618
- cta_layout_mnk, cluster_coord_mnk, mode=0
916
+ cluster_layout_mnk, cluster_coord_mnk, mode=0
619
917
  )
620
918
  a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
621
919
  b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
622
- mainloop_producer_state = pipeline.make_pipeline_state(
920
+
921
+ # Persistent tile scheduling loop
922
+ is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
923
+ if const_expr(cute.size(cluster_layout_mnk) > 1):
924
+ is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
925
+ tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
926
+ work_tile = tile_scheduler.initial_work_tile_info()
927
+ ab_producer_state = make_pipeline_state(
623
928
  pipeline.PipelineUserType.Producer, self.ab_stage
624
929
  )
625
- # ///////////////////////////////////////////////////////////////////////////////
626
- # Local_tile partition global tensors
627
- # ///////////////////////////////////////////////////////////////////////////////
628
- # (bM, bK, RestK)
629
- gA_mkl = cute.local_tile(
630
- mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1)
631
- )
632
- # (bN, bK, RestK)
633
- gB_nkl = cute.local_tile(
634
- mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
635
- )
636
- # //////////////////////////////////////////////////////////////////////////////
637
- # Partition shared tensor for TMA load A/B
638
- # //////////////////////////////////////////////////////////////////////////////
639
- # TMA load A partition_S/D
640
- a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
641
- a_cta_crd = cluster_coord_mnk[1]
642
- tAsA, tAgA_mkl = cpasync.tma_partition(
643
- tma_atom_a,
644
- a_cta_crd,
645
- a_cta_layout,
646
- cute.group_modes(sA, 0, 2),
647
- cute.group_modes(gA_mkl, 0, 2),
648
- )
649
- # TMA load B partition_S/D
650
- b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
651
- b_cta_crd = cluster_coord_mnk[0]
652
- tBsB, tBgB_nkl = cpasync.tma_partition(
653
- tma_atom_b,
654
- b_cta_crd,
655
- b_cta_layout,
656
- cute.group_modes(sB, 0, 2),
657
- cute.group_modes(gB_nkl, 0, 2),
658
- )
659
- # /////////////////////////////////////////////////////////////////////////////
660
- # TMA load
661
- # /////////////////////////////////////////////////////////////////////////////
662
- for k_tile in cutlass.range(k_tile_cnt, unroll=1):
663
- # Wait for A/B buffers to be empty before loading into them
664
- # Also sets the transaction barrier for the A/B buffers
665
- mainloop_pipeline.producer_acquire(mainloop_producer_state)
666
- # /////////////////////////////////////////////////////////////////////////////
667
- # TMA load A/B
668
- # /////////////////////////////////////////////////////////////////////////////
669
- cute.copy(
670
- tma_atom_a,
671
- tAgA_mkl[None, k_tile],
672
- tAsA[None, mainloop_producer_state.index],
673
- tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state),
674
- mcast_mask=a_mcast_mask,
930
+ do_epi_load_barrier_arrive = cutlass.Boolean(True)
931
+ while work_tile.is_valid_tile:
932
+ tile_coord_mnkl = work_tile.tile_idx
933
+ batch_idx = tile_coord_mnkl[3]
934
+ # ///////////////////////////////////////////////////////////////////////////
935
+ # Local_tile partition global tensors
936
+ # ///////////////////////////////////////////////////////////////////////////
937
+ if const_expr(not self.gather_A):
938
+ if const_expr(cu_seqlens_m is not None):
939
+ mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
940
+ else:
941
+ mA_mk = mA_mkl[None, None, batch_idx]
942
+ # (bM, bK, RestK)
943
+ gA_k = cute.local_tile(
944
+ mA_mk,
945
+ cute.select(self.tile_shape_mnk, [0, 2]),
946
+ (tile_coord_mnkl[0], None),
947
+ )
948
+ else:
949
+ mA_mk = mA_mkl
950
+ if const_expr(cu_seqlens_m is not None):
951
+ mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
952
+ else:
953
+ mAIdx_mk = mAIdx[None, batch_idx]
954
+ gAIdx = cute.local_tile(
955
+ mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
956
+ )
957
+ # (bN, bK, RestK)
958
+ gB_k = cute.local_tile(
959
+ mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
960
+ )
961
+ # //////////////////////////////////////////////////////////////////////////
962
+ # Partition shared tensor for TMA load A/B
963
+ # //////////////////////////////////////////////////////////////////////////
964
+ # TMA load A partition_S/D
965
+ a_cta_layout = cute.make_layout(
966
+ cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
675
967
  )
676
- cute.copy(
968
+ a_cta_crd = cluster_coord_mnk[1]
969
+ if const_expr(not self.gather_A):
970
+ # ((atom_v, rest_v), STAGE)
971
+ # ((atom_v, rest_v), RestK)
972
+ tAsA, tAgA_k = cpasync.tma_partition(
973
+ tma_atom_a,
974
+ a_cta_crd,
975
+ a_cta_layout,
976
+ cute.group_modes(sA, 0, 2),
977
+ cute.group_modes(gA_k, 0, 2),
978
+ )
979
+ copy_A = partial(cute.copy, tma_atom_a, mcast_mask=a_mcast_mask)
980
+ else:
981
+ tiled_copy_A = self._make_gmem_tiled_copy_A(
982
+ mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
983
+ )
984
+ tidx = (
985
+ cute.arch.thread_idx()[0]
986
+ - self.mma_warp_groups * self.num_threads_per_warp_group
987
+ )
988
+ thr_copy_A = tiled_copy_A.get_slice(tidx)
989
+ # (atom_v, CPY_M, 1, STAGE)
990
+ tAsA = thr_copy_A.partition_D(sA)
991
+ assert tAsA.shape[2] == 1
992
+ tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
993
+ copy_A = partial(cute.copy, tiled_copy_A)
994
+ # TMA load B partition_S/D
995
+ b_cta_layout = cute.make_layout(
996
+ cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
997
+ )
998
+ b_cta_crd = cluster_coord_mnk[0]
999
+ # ((atom_v, rest_v), STAGE)
1000
+ # ((atom_v, rest_v), RestK)
1001
+ tBsB, tBgB_k = cpasync.tma_partition(
677
1002
  tma_atom_b,
678
- tBgB_nkl[None, k_tile],
679
- tBsB[None, mainloop_producer_state.index],
680
- tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state),
681
- mcast_mask=b_mcast_mask,
1003
+ b_cta_crd,
1004
+ b_cta_layout,
1005
+ cute.group_modes(sB, 0, 2),
1006
+ cute.group_modes(gB_k, 0, 2),
682
1007
  )
683
- # Mainloop pipeline's producer commit is a NOP
684
- mainloop_pipeline.producer_commit(mainloop_producer_state)
685
- mainloop_producer_state.advance()
686
- mainloop_pipeline.producer_tail(mainloop_producer_state)
687
-
688
- if warp_idx < self.mma_warp_groups * 4:
1008
+ copy_B = partial(cute.copy, tma_atom_b, mcast_mask=b_mcast_mask)
1009
+ if const_expr(not self.gather_A):
1010
+ ab_producer_state = self.load_AB(
1011
+ ab_pipeline,
1012
+ ab_producer_state,
1013
+ copy_A,
1014
+ tAgA_k,
1015
+ tAsA,
1016
+ copy_B,
1017
+ tBgB_k,
1018
+ tBsB,
1019
+ )
1020
+ else:
1021
+ limit_m = (
1022
+ mAIdx.shape[0]
1023
+ if const_expr(cu_seqlens_m is None)
1024
+ else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
1025
+ )
1026
+ ab_producer_state = self.load_AB_gather_A(
1027
+ ab_pipeline,
1028
+ ab_producer_state,
1029
+ thr_copy_A,
1030
+ mA_mk,
1031
+ tAsA,
1032
+ gAIdx,
1033
+ copy_B,
1034
+ tBgB_k,
1035
+ tBsB,
1036
+ limit_A=(
1037
+ limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
1038
+ mA_mk.shape[1],
1039
+ ),
1040
+ )
1041
+ if const_expr(epi_load_barrier is not None):
1042
+ # In the first work tile, the epi load warp will wait for the signal
1043
+ # from the mainloop load warp to start loading C, to avoid interfering
1044
+ # with loading A and B.
1045
+ if do_epi_load_barrier_arrive:
1046
+ epi_load_barrier.arrive()
1047
+ do_epi_load_barrier_arrive = cutlass.Boolean(False)
1048
+ tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
1049
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
1050
+ work_tile = tile_scheduler.get_current_work()
1051
+ # End of persistent scheduler loop
1052
+ if const_expr(self.pingpong):
1053
+ # Need to write the tile_idx to smem for the next WG in the pingpong mode
1054
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
1055
+ ab_pipeline.producer_tail(ab_producer_state)
1056
+ if is_scheduler_warp:
1057
+ tile_scheduler.producer_tail()
1058
+
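
A CUDA-free mock of the load-warp control flow above; the real `TileScheduler` coordinates warps through the sched pipeline (and, when dynamic, a global tile-count semaphore), so the method signatures here are deliberately simplified:

```python
class MockTileScheduler:
    """Hands out tile coordinates in order; a stand-in for TileScheduler.create(...)."""
    def __init__(self, tiles):
        self.tiles, self.pos = list(tiles), 0

    def initial_work_tile_info(self):
        return self.get_current_work()

    def fetch_next_work(self):
        pass  # real version: the scheduler warp polls/bumps the tile-count semaphore

    def advance_to_next_work(self):
        self.pos += 1

    def get_current_work(self):
        valid = self.pos < len(self.tiles)
        idx = self.tiles[self.pos] if valid else None
        return type("WorkTile", (), {"is_valid_tile": valid, "tile_idx": idx})()

sched = MockTileScheduler([(0, 0, None, 0), (1, 0, None, 0), (0, 1, None, 0)])
work = sched.initial_work_tile_info()
while work.is_valid_tile:
    print("load A/B tiles for", work.tile_idx)   # stand-in for load_AB(...)
    sched.fetch_next_work()
    sched.advance_to_next_work()
    work = sched.get_current_work()
```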
1059
+ # if const_expr(mC_mnl is not None):
1060
+ # if warp_idx == self.epi_load_warp_id:
1061
+ # epi_producer_state = make_pipeline_state(
1062
+ # pipeline.PipelineUserType.Producer, self.epi_c_stage
1063
+ # )
1064
+ # do_epi_load_barrier_wait = cutlass.Boolean(True)
1065
+ # tile_scheduler = TileSchedulerCls()
1066
+ # work_tile = tile_scheduler.initial_work_tile_info()
1067
+ # while work_tile.is_valid_tile:
1068
+ # tile_coord_mnkl = work_tile.tile_idx
1069
+ # batch_idx = tile_coord_mnkl[3]
1070
+ # if const_expr(cu_seqlens_m is not None):
1071
+ # mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
1072
+ # else:
1073
+ # mC_mn = mC_mnl[None, None, batch_idx]
1074
+ # # (bM, bN)
1075
+ # gC = cute.local_tile(
1076
+ # mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
1077
+ # )
1078
+ # tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
1079
+ # bGS_sC, bGS_gC = cpasync.tma_partition(
1080
+ # tma_atom_c,
1081
+ # 0,
1082
+ # cute.make_layout(1),
1083
+ # cute.group_modes(sC, 0, 2),
1084
+ # tCgC_for_tma_partition,
1085
+ # )
1086
+ # if do_epi_load_barrier_wait:
1087
+ # epi_load_barrier.arrive_and_wait()
1088
+ # do_epi_load_barrier_wait = cutlass.Boolean(False)
1089
+ # epi_tile_num = const_expr(cute.size(tCgC_for_tma_partition, mode=[1]))
1090
+ # epi_tile_shape = tCgC_for_tma_partition.shape[1]
1091
+ # for epi_idx in cutlass.range(epi_tile_num, unroll=1):
1092
+ # epi_pipeline.producer_acquire(epi_producer_state)
1093
+ # # Get the global memory coordinate for the current epi tile
1094
+ # epi_tile_layout = cute.make_layout(
1095
+ # epi_tile_shape, stride=(epi_tile_shape[1], 1)
1096
+ # )
1097
+ # gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1098
+ # cute.copy(
1099
+ # tma_atom_c,
1100
+ # bGS_gC[None, gmem_coord],
1101
+ # bGS_sC[None, epi_producer_state.index],
1102
+ # tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1103
+ # )
1104
+ # # Epi pipeline's producer commit is a NOP
1105
+ # epi_pipeline.producer_commit(epi_producer_state)
1106
+ # epi_producer_state.advance()
1107
+ # tile_scheduler.advance_to_next_work()
1108
+ # work_tile = tile_scheduler.get_current_work()
1109
+ # # End of persistent scheduler loop
1110
+ # epi_pipeline.producer_tail(epi_producer_state)
1111
+
1112
+ if warp_idx < self.ab_load_warp_id:
689
1113
  cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
1114
+ is_tma_warp = cutlass.Boolean(
1115
+ (not self.pingpong and warp_idx == 0)
1116
+ or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
1117
+ )
1118
+ if const_expr(varlen):
1119
+ # initialize tensormap for D
1120
+ tensormap_manager.init_tensormap_from_atom(
1121
+ tma_atom_d,
1122
+ tensormap_d_init_ptr,
1123
+ is_manager_warp=is_tma_warp,
1124
+ )
690
1125
  # //////////////////////////////////////////////////////////////////////////////
691
1126
  # Partition global tensor for TiledMMA_A/B/C
692
1127
  # //////////////////////////////////////////////////////////////////////////////
693
1128
  tidx, _, _ = cute.arch.thread_idx()
694
1129
  warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
1130
+ if const_expr(self.pingpong):
1131
+ tidx = tidx % self.num_threads_per_warp_group
695
1132
  warp_group_thread_layout = cute.make_layout(
696
- self.mma_warp_groups, stride=self.num_threads_per_warp_group
1133
+ self.mma_warp_groups if not self.pingpong else 1,
1134
+ stride=self.num_threads_per_warp_group,
1135
+ )
1136
+ thr_mma = tiled_mma.get_slice(
1137
+ warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
697
1138
  )
698
- thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx))
699
1139
 
700
1140
  # //////////////////////////////////////////////////////////////////////////////
701
1141
  # Make fragments
@@ -705,148 +1145,537 @@ class HopperWgmmaGemmKernel:
705
1145
 
706
1146
  acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
707
1147
  acc = cute.make_fragment(acc_shape, self.acc_dtype)
1148
+ if const_expr(self.fp8_slow_accum):
1149
+ acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
1150
+ else:
1151
+ acc_slow = None
1152
+
1153
+ if const_expr(self.pingpong):
1154
+ if warp_group_idx == 0:
1155
+ # WG0 needs a start signal at the very beginning
1156
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
1157
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
708
1158
 
709
- mainloop_consumer_read_state = pipeline.make_pipeline_state(
710
- pipeline.PipelineUserType.Consumer, self.ab_stage
1159
+ ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
1160
+ epi_read_state = make_pipeline_state(
1161
+ pipeline.PipelineUserType.Consumer, self.epi_c_stage
711
1162
  )
712
- mainloop_consumer_release_state = pipeline.make_pipeline_state(
713
- pipeline.PipelineUserType.Consumer, self.ab_stage
1163
+ epi_producer_state = make_pipeline_state(
1164
+ pipeline.PipelineUserType.Producer, self.epi_c_stage
714
1165
  )
1166
+ tile_scheduler = TileSchedulerCls()
1167
+ if const_expr(self.pingpong):
1168
+ if warp_idx >= 4:
1169
+ # Advance 2nd Math WG to the next work tile for the startup
1170
+ tile_scheduler.advance_to_next_work()
1171
+ # Advance 2nd Math WG pipeline states to the end of 1st Math WG
1172
+ ab_read_state.advance_iters(k_tile_cnt)
1173
+ epi_read_state.advance_iters(c_tile_cnt)
1174
+ epi_producer_state.advance_iters(c_tile_cnt)
1175
+ work_tile = tile_scheduler.initial_work_tile_info()
1176
+ if const_expr(varlen):
1177
+ # wait tensormap initialization complete before update
1178
+ tensormap_manager.fence_tensormap_initialization()
1179
+ # batch index of last tile
1180
+ last_batch_idx = cutlass.Int32(-1)
1181
+ while work_tile.is_valid_tile:
1182
+ tile_coord_mnkl = work_tile.tile_idx
1183
+ batch_idx = tile_coord_mnkl[3]
1184
+ if const_expr(varlen):
1185
+ is_group_changed = batch_idx != last_batch_idx
1186
+ last_batch_idx = batch_idx
1187
+ if is_group_changed:
1188
+ # construct tensor D based on real address, shape and stride information
1189
+ tensormap_manager.update_tensormap_shape(
1190
+ ((tensormap_d_ptr),),
1191
+ is_manager_warp=is_tma_warp,
1192
+ tensormap_smem_ptr=(tensormap_d_smem_ptr,),
1193
+ shapes=(cu_seqlens_m[batch_idx + 1],),
1194
+ orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
1195
+ )
1196
+
1197
+ ab_read_state, tiled_mma = self.mma(
1198
+ ab_pipeline,
1199
+ ab_read_state,
1200
+ tiled_mma,
1201
+ tCrA,
1202
+ tCrB,
1203
+ acc,
1204
+ acc_slow,
1205
+ k_tile_cnt,
1206
+ warp_group_idx,
1207
+ )
1208
+ if const_expr(self.pingpong):
1209
+ # Update starting mainloop pipeline state for the next tile
1210
+ ab_read_state.advance_iters(k_tile_cnt)
715
1211
 
716
- # /////////////////////////////////////////////////////////////////////////////
717
- # Prologue MMAs
718
- # /////////////////////////////////////////////////////////////////////////////
719
- k_pipe_mmas = 1
720
- peek_ab_full_status = cutlass.Boolean(1)
721
- if mainloop_consumer_read_state.count < k_tile_cnt:
722
- peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
723
- mainloop_consumer_read_state
1212
+ # /////////////////////////////////////////////////////////////////////////////
1213
+ # EPILOGUE
1214
+ # /////////////////////////////////////////////////////////////////////////////
1215
+ if const_expr(self.pingpong):
1216
+ self.pingpong_barrier_sync(warp_group_idx, "epi")
1217
+
1218
+ epilogue_barrier = pipeline.NamedBarrier(
1219
+ barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
724
1220
  )
725
- tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
726
- num_k_blocks = cute.size(tCrA, mode=[2])
727
- for k_tile in cutlass.range_constexpr(k_pipe_mmas):
728
- # Wait for A/B buffer to be ready
729
- mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status)
730
- warpgroup.fence()
731
- for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
732
- k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index)
733
- cute.gemm(tiled_mma, acc, tCrA[k_block_coord], tCrB[k_block_coord], acc)
734
- tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
735
- warpgroup.commit_group()
736
- mainloop_consumer_read_state.advance()
737
- peek_ab_full_status = cutlass.Boolean(1)
738
- if mainloop_consumer_read_state.count < k_tile_cnt:
739
- peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
740
- mainloop_consumer_read_state
1221
+
1222
+ # Wait for all warp groups in the thread block to finish, because smem for tensor
1223
+ # A in the mainloop is reused in the epilogue if not persistent.
1224
+ if const_expr(not self.is_persistent):
1225
+ epilogue_barrier.arrive_and_wait()
1226
+
1227
+ if const_expr(varlen):
1228
+ # ensure the update to tensormap has completed before using it
1229
+ if is_group_changed:
1230
+ if is_tma_warp:
1231
+ tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
1232
+
1233
+ # Doesn't work when tile_N % 8 == 0 but tile_N % 16 != 0, since this always
1234
+ # gets an st.matrix with num_matrices=4
1235
+ copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
1236
+ self.d_layout, elem_ty_d=self.d_dtype, elem_ty_acc=self.acc_dtype
1237
+ )
1238
+ copy_atom_C = cute.make_copy_atom(
1239
+ warp.StMatrix8x8x16bOp(
1240
+ self.d_layout.is_m_major_c(),
1241
+ num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
1242
+ ),
1243
+ cutlass.Float16, # this is just to get the right source layout
1244
+ )
1245
+ tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
1246
+ tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
1247
+ # (R2S, R2S_M, R2S_N, PIPE_D)
1248
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
1249
+ tRS_sD = thr_copy_r2s.partition_D(sD)
1250
+ # (R2S, R2S_M, R2S_N)
1251
+ tRS_rAcc = tiled_copy_r2s.retile(acc)
1252
+
1253
+ # Allocate D registers.
1254
+ tRS_rD_layout = cute.make_layout(thr_copy_r2s.partition_S(sD).shape[:3])
1255
+ tRS_rD = cute.make_fragment(tRS_rD_layout, self.acc_dtype)
1256
+
1257
+ if const_expr(mC_mnl is not None):
1258
+ copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
1259
+ tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
1260
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1261
+ tSR_sC = thr_copy_s2r.partition_S(sC)
1262
+ tRS_rC = cute.make_fragment(tRS_rD_layout, self.c_dtype)
1263
+ tSR_rC = thr_copy_s2r.retile(tRS_rC)
1264
+ else:
1265
+ thr_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
1266
+
1267
+ if const_expr(cu_seqlens_m is not None):
1268
+ mD_mn_tma = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl_tma)
1269
+ else:
1270
+ mD_mn_tma = mD_mnl_tma[None, None, batch_idx]
1271
+ # (bM, bN)
1272
+ gD = cute.local_tile(
1273
+ mD_mn_tma, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
1274
+ )
1275
+ tDgD_for_tma_partition = cute.zipped_divide(gD, self.epi_tile)
1276
+ bSG_sD, bSG_gD = cpasync.tma_partition(
1277
+ tma_atom_d,
1278
+ 0,
1279
+ cute.make_layout(1),
1280
+ cute.group_modes(sD, 0, 2),
1281
+ tDgD_for_tma_partition,
1282
+ )
1283
+
1284
+ if const_expr(mC_mnl is not None):
1285
+ if const_expr(cu_seqlens_m is not None):
1286
+ mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
1287
+ else:
1288
+ mC_mn = mC_mnl[None, None, batch_idx]
1289
+ # (bM, bN)
1290
+ gC = cute.local_tile(
1291
+ mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
1292
+ )
1293
+ tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
1294
+ bGS_sC, bGS_gC = cpasync.tma_partition(
1295
+ tma_atom_c,
1296
+ 0,
1297
+ cute.make_layout(1),
1298
+ cute.group_modes(sC, 0, 2),
1299
+ tCgC_for_tma_partition,
741
1300
  )
742
1301
 
743
- # /////////////////////////////////////////////////////////////////////////////
744
- # MAINLOOP
745
- # /////////////////////////////////////////////////////////////////////////////
746
- for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, unroll=1):
747
- # Wait for TMA copies to complete
748
- mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status)
749
- # WGMMA
750
- warpgroup.fence()
751
- for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
752
- k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index)
753
- cute.gemm(tiled_mma, acc, tCrA[k_block_coord], tCrB[k_block_coord], acc)
754
- warpgroup.commit_group()
755
- # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
756
- warpgroup.wait_group(k_pipe_mmas)
757
- mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
758
- mainloop_consumer_read_state.advance()
759
- mainloop_consumer_release_state.advance()
760
- peek_ab_full_status = cutlass.Boolean(1)
761
- if mainloop_consumer_read_state.count < k_tile_cnt:
762
- peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
763
- mainloop_consumer_read_state
1302
+ epi_tile_num = const_expr(cute.size(tDgD_for_tma_partition, mode=[1]))
1303
+ epi_tile_shape = tDgD_for_tma_partition.shape[1]
1304
+ num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
1305
+ epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
1306
+
1307
+ if const_expr(mC_mnl is not None):
1308
+ for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
1309
+ if is_tma_warp:
1310
+ epi_pipeline.producer_acquire(epi_producer_state)
1311
+ # Get the global memory coordinate for the current epi tile
1312
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1313
+ cute.copy(
1314
+ tma_atom_c,
1315
+ bGS_gC[None, gmem_coord],
1316
+ bGS_sC[None, epi_producer_state.index],
1317
+ tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
1318
+ )
1319
+ # Epi pipeline's producer commit is a NOP
1320
+ epi_pipeline.producer_commit(epi_producer_state)
1321
+ epi_producer_state.advance()
1322
+
1323
+ for epi_idx in cutlass.range_constexpr(epi_tile_num):
1324
+ # Copy from acc to D registers
1325
+ for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
1326
+ tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1327
+ if const_expr(mC_mnl is not None):
1328
+ epi_pipeline.consumer_wait(epi_read_state)
1329
+ cute.copy(
1330
+ thr_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
1331
+ )
1332
+ # Fence to make sure the smem reads of C complete before the buffer is reused by the next TMA load
1333
+ cute.arch.fence_proxy(
1334
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1335
+ )
1336
+ cute.arch.sync_warp()
1337
+ with cute.arch.elect_one():
1338
+ epi_pipeline.consumer_release(epi_read_state)
1339
+ epi_read_state.advance()
1340
+ if const_expr(epi_idx + self.epi_c_stage < epi_tile_num):
1341
+ if is_tma_warp:
1342
+ epi_pipeline.producer_acquire(epi_producer_state)
1343
+ # Get the global memory coordinate for the current epi tile
1344
+ gmem_coord = epi_tile_layout.get_hier_coord(
1345
+ epi_idx + self.epi_c_stage
1346
+ )
1347
+ cute.copy(
1348
+ tma_atom_c,
1349
+ bGS_gC[None, gmem_coord],
1350
+ bGS_sC[None, epi_producer_state.index],
1351
+ tma_bar_ptr=epi_pipeline.producer_get_barrier(
1352
+ epi_producer_state
1353
+ ),
1354
+ )
1355
+ # Epi pipeline's producer commit is a NOP
1356
+ epi_pipeline.producer_commit(epi_producer_state)
1357
+ epi_producer_state.advance()
1358
+ tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(self.acc_dtype))
1359
+ # Type conversion
1360
+ tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
1361
+ tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
1362
+ # Copy from D registers to shared memory
1363
+ epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sD, mode=[3])
1364
+ cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
1365
+ # Fence and barrier to make sure shared memory store is visible to TMA store
1366
+ cute.arch.fence_proxy(
1367
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
764
1368
  )
765
- warpgroup.wait_group(0)
766
- for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
767
- mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
768
- mainloop_consumer_release_state.advance()
769
-
770
- # /////////////////////////////////////////////////////////////////////////////
771
- # EPILOGUE
772
- # /////////////////////////////////////////////////////////////////////////////
773
-
774
- # Wait for all warp groups in the thread block to finish, because smem for tensor A in
775
- # the mainloop is reused in the epilogue.
776
- cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
777
-
778
- copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
779
- self.d_layout,
780
- elem_ty_d=self.d_dtype,
781
- elem_ty_acc=self.acc_dtype,
1369
+ epilogue_barrier.arrive_and_wait()
1370
+ # Get the global memory coordinate for the current epi tile
1371
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
1372
+ # Copy from shared memory to global memory
1373
+ if is_tma_warp:
1374
+ if const_expr(varlen):
1375
+ tma_desc_ptr = tensormap_manager.get_tensormap_ptr(
1376
+ tensormap_d_ptr,
1377
+ cute.AddressSpace.generic,
1378
+ )
1379
+ else:
1380
+ tma_desc_ptr = None
1381
+ cute.copy(
1382
+ tma_atom_d,
1383
+ bSG_sD[None, epi_buffer],
1384
+ bSG_gD[None, gmem_coord],
1385
+ tma_desc_ptr=tma_desc_ptr,
1386
+ )
1387
+ cute.arch.cp_async_bulk_commit_group()
1388
+ cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
1389
+ epilogue_barrier.arrive_and_wait()
1390
+
1391
+ if const_expr(self.pingpong):
1392
+ # Update starting load/store pipeline states for the next tile
1393
+ epi_read_state.advance_iters(c_tile_cnt)
1394
+ epi_producer_state.advance_iters(c_tile_cnt)
1395
+ # With pingpong, 2 WGs write two different output tiles to the same smem,
1396
+ # so we have to make sure the smem content has been fully read by the TMA store before signaling
1397
+ # the next WG's epilogue.
1398
+ if warp_idx == 0 or warp_idx == 4:
1399
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
1400
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
1401
+
1402
+ tile_scheduler.advance_to_next_work(
1403
+ advance_count=1 if not self.pingpong else self.mma_warp_groups
1404
+ )
1405
+ work_tile = tile_scheduler.get_current_work()
1406
+ # End of persistent scheduler loop
1407
+
1408
+ if const_expr(not self.pingpong):
1409
+ if warp_idx == 0:
1410
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
1411
+
1412
+ @cute.jit
1413
+ def load_AB(
1414
+ self,
1415
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1416
+ ab_producer_state: cutlass.pipeline.PipelineState,
1417
+ copy_A: Callable,
1418
+ tAgA: cute.Tensor,
1419
+ tAsA: cute.Tensor,
1420
+ copy_B: Callable,
1421
+ tBgB: cute.Tensor,
1422
+ tBsB: cute.Tensor,
1423
+ ) -> cutlass.pipeline.PipelineState:
1424
+ k_tile_cnt = cute.size(tAgA, mode=[1])
1425
+ # Peek (try_wait) AB buffer empty for the first k_tile
1426
+ peek_ab_empty_status = cutlass.Boolean(True)
1427
+ if 0 < k_tile_cnt:
1428
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1429
+ # /////////////////////////////////////////////////////////////////////////
1430
+ # TMA load
1431
+ # /////////////////////////////////////////////////////////////////////////
1432
+ for k_tile in cutlass.range(k_tile_cnt, unroll=1):
1433
+ # Wait for A/B buffers to be empty before loading into them
1434
+ # Also sets the transaction barrier for the A/B buffers
1435
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
1436
+ tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
1437
+ copy_A(
1438
+ tAgA[None, k_tile],
1439
+ tAsA[None, ab_producer_state.index],
1440
+ tma_bar_ptr=tma_bar_ptr,
782
1441
  )
783
- copy_atom_D = cute.make_copy_atom(
784
- warp.StMatrix8x8x16bOp(self.d_layout.is_m_major_c(), 4),
785
- self.d_dtype,
1442
+ copy_B(
1443
+ tBgB[None, k_tile],
1444
+ tBsB[None, ab_producer_state.index],
1445
+ tma_bar_ptr=tma_bar_ptr,
786
1446
  )
787
- tiled_copy_D_Atom = cute.make_tiled_copy_C_atom(copy_atom_D, tiled_mma)
788
- tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_D_Atom)
789
- # (R2S, R2S_M, R2S_N, PIPE_D)
790
- tRS_sD = tiled_copy_r2s.get_slice(tidx).partition_D(sD)
791
- # (R2S, R2S_M, R2S_N)
792
- tRS_rAcc = tiled_copy_r2s.retile(acc)
793
-
794
- # (bM, bN)
795
- gD_mnl = cute.local_tile(
796
- mD_mnl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None)
1447
+ # Mainloop pipeline's producer commit is a NOP
1448
+ ab_pipeline.producer_commit(ab_producer_state)
1449
+ ab_producer_state.advance()
1450
+ peek_ab_empty_status = cutlass.Boolean(True)
1451
+ if k_tile + 1 < k_tile_cnt:
1452
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1453
+ return ab_producer_state
1454
+
1455
+ @cute.jit
1456
+ def load_AB_gather_A(
1457
+ self,
1458
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1459
+ ab_producer_state: cutlass.pipeline.PipelineState,
1460
+ thr_copy_A: cute.core.ThrCopy,
1461
+ mA: cute.Tensor,
1462
+ tAsA: cute.Tensor,
1463
+ gAIdx: cute.Tensor,
1464
+ copy_B: Callable,
1465
+ tBgB: cute.Tensor,
1466
+ tBsB: cute.Tensor,
1467
+ limit_A: Tuple[Int32, Int32],
1468
+ ) -> cutlass.pipeline.PipelineState:
1469
+ # (atom_v, CPY_M, 1, RestK)
1470
+ limit_m, limit_k = limit_A
1471
+ limit_m = min(limit_m, self.tile_shape_mnk[0]) # To avoid writing beyond smem limit
1472
+ cA = cute.make_identity_tensor(cute.select(self.tile_shape_mnk, [0, 2]))
1473
+ tAcA = thr_copy_A.partition_S(cA)
1474
+ t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
1475
+ # Instead of comparing tAcA to limit_m, we compare t0AcA to limit_m - tAcA[0][0],
1476
+ # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
1477
+ # This is so that when we do the comparison, t0AcA is known at compile time.
1478
+ limit_m = limit_m - tAcA[0][0]
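+ # Illustration with assumed values: if tAcA[0][0] = 16 and the original limit_m = 100,
+ # then checking t0AcA[m][0] < 84 is equivalent to checking tAcA[m][0] < 100, with the
+ # left-hand side known at compile time as noted above.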
1479
+ # Read indices for A
1480
+ rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
1481
+ m_idx = cute.make_fragment(rows_per_thread, Int32)
1482
+ for m in cutlass.range(rows_per_thread):
1483
+ row_idx = tAcA[0, m, 0][0]
1484
+ if t0AcA[0, m, 0][0] < limit_m:
1485
+ m_idx[m] = gAIdx[row_idx]
1486
+ else:
1487
+ m_idx[m] = -1
1488
+ elems_per_load = cute.size(tAsA.shape[0][0])
1489
+ # (m, (bK, RestK))
1490
+ mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
1491
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1492
+ k_tile_cnt = cute.size(tBgB, mode=[1])
1493
+ # Peek (try_wait) AB buffer empty for the first k_tile
1494
+ peek_ab_empty_status = cutlass.Boolean(True)
1495
+ if 0 < k_tile_cnt:
1496
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1497
+ # /////////////////////////////////////////////////////////////////////////
1498
+ # TMA load on B and cp.async on A
1499
+ # /////////////////////////////////////////////////////////////////////////
1500
+ copy_A = partial(cute.copy, thr_copy_A)
1501
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
1502
+ # Wait for A/B buffers to be empty before loading into them
1503
+ # Also sets the transaction barrier for the A/B buffers
1504
+ ab_pipeline.producer_acquire(
1505
+ ab_producer_state,
1506
+ peek_ab_empty_status,
1507
+ # A tiny bit faster to rotate the warp that does TMA
1508
+ is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
797
1509
  )
798
- tcgc_for_tma_partition = cute.zipped_divide(gD_mnl, self.epi_tile)
799
- bSG_sD, bSG_gD = cpasync.tma_partition(
800
- tma_atom_d,
801
- 0,
802
- cute.make_layout(1),
803
- cute.group_modes(sD, 0, 2),
804
- tcgc_for_tma_partition,
1510
+ # A bit faster to load B first while we calculate the predicate for A
1511
+ if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
1512
+ copy_B(
1513
+ tBgB[None, k_tile],
1514
+ tBsB[None, ab_producer_state.index],
1515
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
1516
+ )
1517
+ # (m, bK)
1518
+ mA_cur = mA_k[None, (None, k_tile)]
1519
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1520
+ # (elems_per_load, thread_per_row)
1521
+ mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
1522
+ if t0AcA[0, m, 0][0] < limit_m:
1523
+ # There's only 1 load per row
1524
+ assert cute.size(tAcA.shape, mode=[2]) == 1
1525
+ ki = tAcA[0, 0, 0][1] // elems_per_load
1526
+ copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1527
+ # This tells mbarrier to track the completion of cp.async
1528
+ ab_pipeline.producer_commit(ab_producer_state)
1529
+ ab_producer_state.advance()
1530
+ peek_ab_empty_status = cutlass.Boolean(True)
1531
+ if k_tile + 1 < k_tile_cnt:
1532
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1533
+ # bound checking in the K dimension on the last k_tile
1534
+ if 0 < k_tile_cnt:
1535
+ k_tile = k_tile_cnt - 1
1536
+ ab_pipeline.producer_acquire(
1537
+ ab_producer_state,
1538
+ peek_ab_empty_status,
1539
+ is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
805
1540
  )
806
-
807
- epi_tile_num = cutlass.const_expr(cute.size(tcgc_for_tma_partition, mode=[1]))
808
- epi_tile_shape = tcgc_for_tma_partition.shape[1]
809
-
810
- for epi_idx in cutlass.range_constexpr(epi_tile_num):
811
- # Copy from acc to D registers
812
- tRS_rD = cute.make_fragment_like(tRS_sD[None, None, None, 0], self.acc_dtype)
813
- for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
814
- tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
815
- # Type conversion
816
- tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
817
- tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
818
- # Copy from D registers to shared memory
819
- epi_buffer = epi_idx % cute.size(tRS_sD, mode=[3])
820
- # cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
821
- cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)])
822
- cute.arch.fence_proxy(
823
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1541
+ if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
1542
+ copy_B(
1543
+ tBgB[None, k_tile],
1544
+ tBsB[None, ab_producer_state.index],
1545
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
824
1546
  )
825
- # barrier for sync
826
- cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
827
- # Get the global memory coordinate for the current epi tile.
828
- epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
829
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
830
- # Copy from shared memory to global memory
831
- if warp_idx == 0:
832
- cute.copy(tma_atom_d, bSG_sD[(None, epi_buffer)], bSG_gD[(None, gmem_coord)])
833
- cute.arch.cp_async_bulk_commit_group()
834
- # TODO: when moving to persistent maybe we always need this wait_group
835
- if epi_idx >= self.epi_stage - 1:
836
- cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
837
- if epi_idx >= self.epi_stage - 1:
838
- cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
1547
+ assert tAcA.shape[2] == 1 # there's only 1 load along the K dimension
1548
+ tApA = cute.make_fragment(1, cutlass.Boolean)
1549
+ tApA[0] = tAcA[0, 0, 0][1] < limit_k
1550
+ # (m, bK)
1551
+ mA_cur = mA_k[None, (None, k_tile)]
1552
+ for m in cutlass.range_constexpr(tAcA.shape[1]):
1553
+ # (elems_per_load, thread_per_row)
1554
+ mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
1555
+ if t0AcA[0, m, 0][0] < limit_m:
1556
+ # There's only 1 load per row
1557
+ assert cute.size(tAcA.shape, mode=[2]) == 1
1558
+ ki = tAcA[0, 0, 0][1] // elems_per_load
1559
+ # copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA)
1560
+ # TODO: apply the K-boundary predicate tApA to this copy (the predicated call above is disabled)
1561
+ copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
1562
+ ab_pipeline.producer_commit(ab_producer_state)
1563
+ ab_producer_state.advance()
1564
+ return ab_producer_state
1565
+
1566
+ @cute.jit
1567
+ def mma(
1568
+ self,
1569
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
1570
+ ab_read_state: cutlass.pipeline.PipelineState,
1571
+ tiled_mma: cute.TiledMma,
1572
+ tCrA: cute.Tensor,
1573
+ tCrB: cute.Tensor,
1574
+ acc: cute.Tensor,
1575
+ acc_slow: Optional[cute.Tensor],
1576
+ k_tile_cnt: Int32,
1577
+ warp_group_idx: Int32,
1578
+ ) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
1579
+ # /////////////////////////////////////////////////////////////////////////////
1580
+ # Prologue MMAs
1581
+ # /////////////////////////////////////////////////////////////////////////////
1582
+ k_pipe_mmas = 1
1583
+ ab_release_state = ab_read_state.clone()
1584
+ num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
1585
+ if const_expr(self.pingpong):
1586
+ self.pingpong_barrier_sync(warp_group_idx, stage="mma")
1587
+ peek_ab_full_status = cutlass.Boolean(True)
1588
+ if 0 < k_tile_cnt:
1589
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1590
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
1591
+ num_k_blocks = cute.size(tCrA, mode=[2])
1592
+ # TODO: this is probably not correct if k_tile_cnt == 0
1593
+ for k_tile in cutlass.range(num_prologue_mma):
1594
+ # Wait for A/B buffer to be ready
1595
+ ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
1596
+ warpgroup.fence()
1597
+ for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1598
+ k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
1599
+ cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1600
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
1601
+ warpgroup.commit_group()
1602
+ ab_read_state.advance()
1603
+ peek_ab_full_status = cutlass.Boolean(True)
1604
+ if k_tile + 1 < k_tile_cnt:
1605
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1606
+ if const_expr(self.fp8_slow_accum):
1607
+ warpgroup.wait_group(0)
1608
+ acc_slow.store(acc.load())
1609
+
1610
+ # /////////////////////////////////////////////////////////////////////////////
1611
+ # MAINLOOP
1612
+ # /////////////////////////////////////////////////////////////////////////////
1613
+ for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
1614
+ # Wait for TMA copies to complete
1615
+ ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
1616
+ # WGMMA
1617
+ warpgroup.fence()
1618
+ if const_expr(self.fp8_slow_accum):
1619
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
1620
+ for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1621
+ k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
1622
+ cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1623
+ tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
1624
+ warpgroup.commit_group()
1625
+ # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
1626
+ if const_expr(not self.fp8_slow_accum):
1627
+ warpgroup.wait_group(k_pipe_mmas)
1628
+ else:
1629
+ warpgroup.wait_group(0)
1630
+ acc_slow.store(acc_slow.load() + acc.load())
1631
+ ab_pipeline.consumer_release(ab_release_state)
1632
+ ab_read_state.advance()
1633
+ ab_release_state.advance()
1634
+ peek_ab_full_status = cutlass.Boolean(True)
1635
+ if k_tile + 1 < k_tile_cnt:
1636
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1637
+ if const_expr(self.pingpong):
1638
+ # Cue for next WG's MMA to start
1639
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
1640
+ if const_expr(not self.fp8_slow_accum):
1641
+ # fp8_slow_accum would have already called wait_group(0) inside the loop
1642
+ warpgroup.wait_group(0)
1643
+ for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
1644
+ ab_pipeline.consumer_release(ab_release_state)
1645
+ ab_release_state.advance()
1646
+ if const_expr(self.fp8_slow_accum):
1647
+ acc.store(acc_slow.load())
1648
+ # If we don't return the tiled_mma, we get a compiler error
1649
+ # "operand #0 does not dominate this use"
1650
+ return ab_read_state, tiled_mma
1651
+
1652
+ def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
1653
+ assert stage in ["mma", "epi"]
1654
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1655
+ cute.arch.barrier(
1656
+ barrier_id=int(barrier) + warp_group_idx,
1657
+ number_of_threads=2 * self.num_threads_per_warp_group,
1658
+ )
839
1659
 
840
- if warp_idx == 0:
841
- cute.arch.cp_async_bulk_wait_group(0, read=True)
1660
+ def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
1661
+ assert stage in ["mma", "epi"]
1662
+ barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
1663
+ cute.arch.barrier_arrive(
1664
+ barrier_id=int(barrier) + warp_group_idx,
1665
+ number_of_threads=2 * self.num_threads_per_warp_group,
1666
+ )
842
1667
 
843
1668
  @staticmethod
844
1669
  def _compute_stages(
845
1670
  tile_shape_mnk: Tuple[int, int, int],
1671
+ epi_tile: Optional[Tuple[int, int]],
846
1672
  a_dtype: Type[cutlass.Numeric],
847
1673
  b_dtype: Type[cutlass.Numeric],
1674
+ d_dtype: Type[cutlass.Numeric],
1675
+ c_dtype: Optional[Type[cutlass.Numeric]],
848
1676
  smem_capacity: int,
849
1677
  occupancy: int,
1678
+ overlap_sD_sA: bool,
850
1679
  ) -> Tuple[int, int]:
851
1680
  """Computes the number of stages for A/B/C operands based on heuristics.
852
1681
 
@@ -866,10 +1695,15 @@ class HopperWgmmaGemmKernel:
866
1695
  :rtype: Tuple[int, int]
867
1696
  """
868
1697
 
869
- # epi_stage = 4 if tile_shape_mnk[1] % 32 == 0 else 8
870
- epi_stage = 4
871
- # epi_smem will reuse smem ab.
872
- epi_bytes = 0
1698
+ epi_stage = 2
1699
+ if overlap_sD_sA:
1700
+ epi_bytes = 0
1701
+ else:
1702
+ d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8
1703
+ epi_bytes = d_bytes_per_stage * epi_stage
1704
+ epi_c_stage = 0 if c_dtype is None else 2
1705
+ if c_dtype is not None:
1706
+ epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
873
1707
 
874
1708
  a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
875
1709
  b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
@@ -878,16 +1712,23 @@ class HopperWgmmaGemmKernel:
878
1712
  )
879
1713
  mbar_helpers_bytes = 1024
880
1714
 
881
- ab_stage = (
1715
+ remaining_bytes = (
882
1716
  (smem_capacity - occupancy * 1024) // occupancy - mbar_helpers_bytes - epi_bytes
883
- ) // ab_bytes_per_stage
884
- return ab_stage, epi_stage
1717
+ )
1718
+ ab_stage = remaining_bytes // ab_bytes_per_stage
1719
+
1720
+ # Refine epilogue stages:
1721
+ # Calculate remaining smem after allocating for A/B stages and reserved bytes
1722
+ # Add remaining unused smem to epilogue
1723
+ if not overlap_sD_sA:
1724
+ epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // d_bytes_per_stage
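+ # Worked example with assumed inputs: a 128x256x64 BF16 tile, a (64, 32) BF16 epi_tile,
+ # no C, occupancy 1, smem_capacity 233472 and no sD/sA overlap give
+ # ab_bytes_per_stage = 49152, d_bytes_per_stage = 4096, epi_bytes = 8192,
+ # remaining_bytes = 223232, hence ab_stage = 4 and epi_stage = 2 + 26624 // 4096 = 8.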
1725
+ return ab_stage, epi_stage, epi_c_stage
885
1726
 
886
1727
  @staticmethod
887
1728
  def _sm90_compute_tile_shape_or_override(
888
1729
  tile_shape_mnk: Tuple[int, int, int],
1730
+ atom_layout_mnk: Tuple[int, int, int],
889
1731
  element_type: Type[cutlass.Numeric],
890
- is_cooperative: bool = False,
891
1732
  epi_tile_override: Tuple[int, int] | None = None,
892
1733
  ) -> Tuple[int, int]:
893
1734
  """Compute the epilogue tile shape or use override if provided.
@@ -906,33 +1747,42 @@ class HopperWgmmaGemmKernel:
906
1747
  """
907
1748
  if epi_tile_override is not None:
908
1749
  return epi_tile_override
909
- if is_cooperative:
910
- if cute.size(tile_shape_mnk, mode=[0]) == 192:
911
- tile_m = 192
912
- tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]) // 2)
913
- else:
914
- tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
915
- tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
916
- return (tile_m, tile_n)
1750
+ if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
1751
+ tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
1752
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
1753
+ elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
1754
+ tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0]))
1755
+ tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
917
1756
  else:
1757
+ # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
1758
+ # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
1759
+ # M dimension first, then move to the N dimension. But the accumulator in registers
1760
+ # iterates along the N dimension first, then moves to the M dimension.
1761
+ # We could change the epilogue to accommodate this,
1762
+ # but it's easier to just set epi_tile_m = 64.
918
1763
  n_perf = 64 if element_type.width == 8 else 32
919
1764
  tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
920
1765
  tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
921
- return (tile_m, tile_n)
1766
+ return (tile_m, tile_n)
922
1767
 
923
1768
  @staticmethod
924
1769
  def _make_smem_layouts(
925
1770
  tile_shape_mnk: Tuple[int, int, int],
926
1771
  epi_tile: Tuple[int, int],
927
1772
  a_dtype: Type[cutlass.Numeric],
928
- a_layout: utils.LayoutEnum,
1773
+ a_layout: cutlass.utils.LayoutEnum,
929
1774
  b_dtype: Type[cutlass.Numeric],
930
- b_layout: utils.LayoutEnum,
1775
+ b_layout: cutlass.utils.LayoutEnum,
931
1776
  ab_stage: int,
932
1777
  d_dtype: Type[cutlass.Numeric],
933
- d_layout: utils.LayoutEnum,
1778
+ d_layout: cutlass.utils.LayoutEnum,
934
1779
  epi_stage: int,
935
- ) -> Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]:
1780
+ c_dtype: Optional[Type[cutlass.Numeric]],
1781
+ c_layout: Optional[cutlass.utils.LayoutEnum],
1782
+ epi_c_stage: int,
1783
+ ) -> Tuple[
1784
+ cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
1785
+ ]:
936
1786
  """Create shared memory layouts for A, B, and C tensors.
937
1787
 
938
1788
  :param tile_shape_mnk: CTA tile shape (M,N,K)
@@ -942,17 +1792,17 @@ class HopperWgmmaGemmKernel:
942
1792
  :param a_dtype: Data type for matrix A
943
1793
  :type a_dtype: type[cutlass.Numeric]
944
1794
  :param a_layout: Layout enum for matrix A
945
- :type a_layout: utils.LayoutEnum
1795
+ :type a_layout: cutlass.utils.LayoutEnum
946
1796
  :param b_dtype: Data type for matrix B
947
1797
  :type b_dtype: type[cutlass.Numeric]
948
1798
  :param b_layout: Layout enum for matrix B
949
- :type b_layout: utils.LayoutEnum
1799
+ :type b_layout: cutlass.utils.LayoutEnum
950
1800
  :param ab_stage: Number of stages for A/B tensors
951
1801
  :type ab_stage: int
952
1802
  :param d_dtype: Data type for output matrix C
953
1803
  :type d_dtype: type[cutlass.Numeric]
954
1804
  :param d_layout: Layout enum for the output matrix C
955
- :type d_layout: utils.LayoutEnum
1805
+ :type d_layout: cutlass.utils.LayoutEnum
956
1806
  :param epi_stage: Number of epilogue stages
957
1807
  :type epi_stage: int
958
1808
 
@@ -998,11 +1848,7 @@ class HopperWgmmaGemmKernel:
998
1848
  d_smem_shape = epi_tile
999
1849
  d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
1000
1850
  d_smem_layout_atom = warpgroup.make_smem_layout_atom(
1001
- sm90_utils.get_smem_layout_atom(
1002
- d_layout,
1003
- d_dtype,
1004
- d_major_mode_size,
1005
- ),
1851
+ sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
1006
1852
  d_dtype,
1007
1853
  )
1008
1854
  epi_smem_layout_staged = cute.tile_to_shape(
@@ -1011,40 +1857,37 @@ class HopperWgmmaGemmKernel:
1011
1857
  order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
1012
1858
  )
1013
1859
 
1014
- return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged
1015
-
1016
- @staticmethod
1017
- def _compute_grid(
1018
- d: cute.Tensor,
1019
- tile_shape_mnk: Tuple[int, int, int],
1020
- cluster_shape_mnk: Tuple[int, int, int],
1021
- ) -> Tuple[int, int, int]:
1022
- """Compute grid shape for the output tensor C.
1023
-
1024
- :param d: The output tensor C
1025
- :type d: cute.Tensor
1026
- :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
1027
- :type tile_shape_mnk: Tuple[int, int, int]
1028
- :param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions.
1029
- :type cluster_shape_mnk: Tuple[int, int, int]
1030
-
1031
- :return: Grid shape for kernel launch.
1032
- :rtype: Tuple[int, int, int]
1033
- """
1860
+ if c_dtype is not None:
1861
+ assert c_layout is not None
1862
+ c_smem_shape = epi_tile
1863
+ c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0]
1864
+ c_smem_layout_atom = warpgroup.make_smem_layout_atom(
1865
+ sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
1866
+ c_dtype,
1867
+ )
1868
+ epi_c_smem_layout_staged = cute.tile_to_shape(
1869
+ c_smem_layout_atom,
1870
+ cute.append(c_smem_shape, epi_c_stage),
1871
+ order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
1872
+ )
1873
+ else:
1874
+ epi_c_smem_layout_staged = None
1034
1875
 
1035
- c_shape = (tile_shape_mnk[0], tile_shape_mnk[1])
1036
- gc = cute.zipped_divide(d, tiler=c_shape)
1037
- clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnk)
1038
- grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnk))
1039
- return grid
1876
+ return (
1877
+ a_smem_layout_staged,
1878
+ b_smem_layout_staged,
1879
+ epi_smem_layout_staged,
1880
+ epi_c_smem_layout_staged,
1881
+ )
1040
1882
 
1041
1883
  @staticmethod
1042
- def _make_tma_store_atoms_and_tensors(
1884
+ def _make_tma_epi_atoms_and_tensors(
1043
1885
  tensor_d: cute.Tensor,
1044
1886
  epi_smem_layout_staged: cute.ComposedLayout,
1045
1887
  epi_tile: Tuple[int, int],
1888
+ store_or_load: str,
1046
1889
  ) -> Tuple[cute.CopyAtom, cute.Tensor]:
1047
- """Create TMA atoms and tensors for C tensor storage.
1890
+ """Create TMA atoms and tensors for storing D or loading C.
1048
1891
 
1049
1892
  :param tensor_d: Output tensor D
1050
1893
  :type tensor_d: cute.Tensor
@@ -1056,15 +1899,17 @@ class HopperWgmmaGemmKernel:
1056
1899
  :return: TMA atom and tensor for C
1057
1900
  :rtype: Tuple[cute.CopyAtom, cute.Tensor]
1058
1901
  """
1902
+ assert store_or_load in ["load", "store"]
1059
1903
  epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
1060
- c_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
1904
+ d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
1905
+ op = (
1906
+ cpasync.CopyBulkTensorTileG2SOp()
1907
+ if store_or_load == "load"
1908
+ else cpasync.CopyBulkTensorTileS2GOp()
1909
+ )
1061
1910
  tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
1062
- cpasync.CopyBulkTensorTileS2GOp(),
1063
- tensor_d,
1064
- epi_smem_layout,
1065
- c_cta_v_layout,
1911
+ op, tensor_d, epi_smem_layout, d_cta_v_layout
1066
1912
  )
1067
-
1068
1913
  return tma_atom_d, tma_tensor_d
1069
1914
 
1070
1915
  @staticmethod
@@ -1104,6 +1949,31 @@ class HopperWgmmaGemmKernel:
1104
1949
  )
1105
1950
  return tma_atom, tma_tensor
1106
1951
 
1952
+ def _make_gmem_tiled_copy_A(self, dtype, major_mode, num_threads, copy_bits=128):
1953
+ atom_async_copy = cute.make_copy_atom(
1954
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
1955
+ dtype,
1956
+ num_bits_per_copy=copy_bits,
1957
+ )
1958
+ copy_elems = copy_bits // dtype.width
1959
+ shape_dim_1 = cute.size(self.tile_shape_mnk[2]) // copy_elems
1960
+ # Thread layout for copy
1961
+ thread_layout = cute.make_layout(
1962
+ (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
1963
+ )
1964
+ if major_mode != cutlass.utils.LayoutEnum.ROW_MAJOR:
1965
+ shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
1966
+ thread_layout = cute.make_layout(
1967
+ (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
1968
+ )
1969
+ # Value layout for copy
1970
+ value_layout = (
1971
+ cute.make_layout((1, copy_elems))
1972
+ if major_mode == cutlass.utils.LayoutEnum.ROW_MAJOR
1973
+ else cute.make_layout((copy_elems, 1))
1974
+ )
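+ # Example under assumed parameters (BF16, bK = 64, copy_bits = 128, 32 threads,
+ # row-major A): copy_elems = 8, thread_layout = (4, 8):(8, 1), value_layout = (1, 8),
+ # i.e. each thread issues one 16-byte cp.async per row it covers.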
1975
+ return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
1976
+
1107
1977
  @staticmethod
1108
1978
  def is_valid_dtypes(
1109
1979
  a_dtype: Type[cutlass.Numeric],
@@ -1133,7 +2003,6 @@ class HopperWgmmaGemmKernel:
1133
2003
  :rtype: bool
1134
2004
  """
1135
2005
  is_valid = True
1136
- # tested a_dtype
1137
2006
  if a_dtype not in {
1138
2007
  cutlass.Float16,
1139
2008
  cutlass.BFloat16,
@@ -1149,7 +2018,6 @@ class HopperWgmmaGemmKernel:
1149
2018
  cutlass.Float8E5M2,
1150
2019
  }:
1151
2020
  is_valid = False
1152
- # tested acc_dtype
1153
2021
  if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
1154
2022
  is_valid = False
1155
2023
  # tested d_dtype
@@ -1180,21 +2048,28 @@ def run(
1180
2048
  a_dtype: Type[cutlass.Numeric],
1181
2049
  b_dtype: Type[cutlass.Numeric],
1182
2050
  d_dtype: Type[cutlass.Numeric],
2051
+ c_dtype: Optional[Type[cutlass.Numeric]],
1183
2052
  acc_dtype: Type[cutlass.Numeric],
1184
2053
  a_major: str,
1185
2054
  b_major: str,
1186
2055
  d_major: str,
2056
+ c_major: str,
1187
2057
  tile_shape_mnk: Tuple[int, int, int],
1188
2058
  cluster_shape_mn: Tuple[int, int],
1189
2059
  tolerance: float,
1190
2060
  warmup_iterations: int,
1191
2061
  iterations: int,
1192
2062
  skip_ref_check: bool,
1193
- use_cold_l2: bool = False,
2063
+ persistent: bool,
2064
+ dynamic_persistent: bool,
2065
+ pingpong: bool,
2066
+ varlen_m: bool,
2067
+ gather_A: bool,
2068
+ fp8_fast_accum: bool,
1194
2069
  **kwargs,
1195
2070
  ):
1196
2071
  """
1197
- Prepare A/B/C tensors, launch GPU kernel, and reference checking.
2072
+ Prepare A/B/D/C tensors, launch GPU kernel, and reference checking.
1198
2073
 
1199
2074
  :param mnkl: Problem size (M, N, K, L)
1200
2075
  :type mnkl: Tuple[int, int, int, int]
@@ -1220,22 +2095,22 @@ def run(
1220
2095
  :type iterations: int, optional
1221
2096
  :param skip_ref_check: Whether to skip reference result validation, defaults to False
1222
2097
  :type skip_ref_check: bool, optional
1223
- :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
1224
- :type use_cold_l2: bool, optional
1225
- :return: Execution time of the GEMM kernel in microseconds
1226
- :rtype: float
1227
2098
  """
1228
2099
 
2100
+ if dynamic_persistent:
2101
+ persistent = True
2102
+
1229
2103
  print("Running Hopper Dense GEMM with:")
1230
2104
  print(f"mnkl: {mnkl}")
1231
- print(f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}")
1232
- print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}")
2105
+ print(
2106
+ f"A dtype: {a_dtype}, B dtype: {b_dtype}, D dtype: {d_dtype}, C_dtype: {c_dtype}, Acc dtype: {acc_dtype}"
2107
+ )
2108
+ print(f"Matrix majors - A: {a_major}, B: {b_major}, D: {d_major}")
1233
2109
  print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
1234
2110
  print(f"Tolerance: {tolerance}")
1235
2111
  print(f"Warmup iterations: {warmup_iterations}")
1236
2112
  print(f"Iterations: {iterations}")
1237
2113
  print(f"Skip reference checking: {skip_ref_check}")
1238
- print(f"Use cold L2: {use_cold_l2}")
1239
2114
 
1240
2115
  # Unpack parameters
1241
2116
  m, n, k, l = mnkl
@@ -1263,16 +2138,17 @@ def run(
1263
2138
  permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
1264
2139
  is_unsigned = dtype in {cutlass.Uint8}
1265
2140
  # Temporarily use uint8 as torch does not support fp8 type
1266
- torch_dtype = (
1267
- cutlass_torch.dtype(dtype)
2141
+ torch_dtype = cutlass_torch.dtype(dtype)
2142
+ gen_dtype = (
2143
+ torch_dtype
1268
2144
  if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
1269
- else torch.uint8
2145
+ else torch.bfloat16
1270
2146
  )
1271
2147
 
1272
2148
  # Create dtype torch tensor (cpu)
1273
2149
  torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
1274
2150
  shape,
1275
- torch_dtype,
2151
+ gen_dtype,
1276
2152
  permute_order=permute_order,
1277
2153
  # init_type=cutlass.torch.TensorInitType.RANDOM,
1278
2154
  # init_config=cutlass.torch.RandomInitConfig(
@@ -1280,7 +2156,7 @@ def run(
1280
2156
  # ),
1281
2157
  init_type=cutlass.torch.TensorInitType.GAUSSIAN,
1282
2158
  init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
1283
- )
2159
+ ).to(torch_dtype)
1284
2160
  # Create dtype torch tensor (gpu)
1285
2161
  torch_tensor = torch_tensor_cpu.cuda()
1286
2162
 
@@ -1288,10 +2164,20 @@ def run(
1288
2164
  f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
1289
2165
 
1290
2166
  # Create dtype cute tensor (gpu)
1291
- cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
2167
+ torch_tensor_view = (
2168
+ torch_tensor
2169
+ if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
2170
+ else torch_tensor.view(torch.uint8)
2171
+ )
2172
+ cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
1292
2173
  cute_tensor.element_type = dtype
1293
2174
  if is_dynamic_layout:
1294
2175
  cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
2176
+ cute_tensor = cute_tensor.mark_compact_shape_dynamic(
2177
+ mode=(1 if not is_mode0_major else 0),
2178
+ stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
2179
+ divisibility=(128 // dtype.width),
2180
+ )
1295
2181
  cute_tensor = cutlass.torch.convert_cute_tensor(
1296
2182
  f32_torch_tensor,
1297
2183
  cute_tensor,
@@ -1302,24 +2188,142 @@ def run(
1302
2188
  return f32_torch_tensor, cute_tensor, torch_tensor
1303
2189
 
1304
2190
  a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
2191
+ if gather_A:
2192
+ assert a_major == "k"
2193
+ a_idx = torch.randperm(l * m, dtype=torch.int32, device="cuda")
2194
+ from einops import rearrange
2195
+
2196
+ a = rearrange(rearrange(a, "m k l -> (m l) k")[a_idx.cpu()], "(m l) k -> m k l", m=m)
2197
+ a_torch = rearrange(a_torch, "m k l -> (m l) k")
2198
+ mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2199
+ a_idx_reshaped = rearrange(a_idx, "(m l) -> l m", m=m).contiguous().transpose(0, 1)
2200
+ mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
2201
+ else:
2202
+ mAIdx = None
1305
2203
  b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
1306
- c, mC, c_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
2204
+ _, mD, d_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
2205
+ if c_dtype is not None:
2206
+ c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
2207
+ else:
2208
+ c, mC, c_torch = None, None, None
2209
+ if varlen_m:
2210
+ assert a_major == "k"
2211
+ assert d_major == "n"
2212
+ from einops import rearrange
2213
+
2214
+ a, d_torch = [rearrange(t, "m x l -> (l m) x") for t in (a, d_torch)]
2215
+ if not gather_A:
2216
+ (a_torch,) = [rearrange(t, "m x l -> (l m) x") for t in (a_torch,)]
2217
+ if c_dtype is not None:
2218
+ c, c_torch = [rearrange(t, "m x l -> (l m) x") for t in (c, c_torch)]
2219
+ mC = from_dlpack(c_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2220
+ mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2221
+ mD = from_dlpack(d_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
2222
+ # TODO: generate random cu_seqlens_m
2223
+ cu_seqlens_m = torch.arange(0, l + 1, dtype=torch.int32, device="cuda") * m
2224
+ mCuSeqlensM = from_dlpack(cu_seqlens_m, assumed_align=64).mark_layout_dynamic(leading_dim=0)
2225
+ if gather_A:
2226
+ a_idx_reshaped = rearrange(a_idx_reshaped, "m l -> (l m)")
2227
+ mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
2228
+ else:
2229
+ cu_seqlens_m, mCuSeqlensM = None, None
2230
+
2231
+ if varlen_m: # Need to allocate space in gmem to store tensormaps
2232
+ if not persistent:
2233
+ total_m = m * l
2234
+ block_size_m = tile_shape_mnk[0] * cluster_shape_mnk[0]
2235
+ block_size_n = tile_shape_mnk[1] * cluster_shape_mnk[1]
2236
+ total_clusters_m_max = (total_m + l * (block_size_m - 1)) // block_size_m
2237
+ total_clusters_max = total_clusters_m_max * ((n + block_size_n - 1) // block_size_n)
2238
+ total_ctas = total_clusters_max * cluster_shape_mnk[0] * cluster_shape_mnk[1]
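+ # Sizing sketch with assumed numbers: m = 750, l = 4 (total_m = 3000) and
+ # block_size_m = 256 give total_clusters_m_max = (3000 + 4 * 255) // 256 = 15, an upper
+ # bound on the M-clusters that any per-batch split of the 3000 rows can produce.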
2239
+ else:
2240
+ total_ctas = cutlass.utils.HardwareInfo().get_device_multiprocessor_count()
2241
+ if pingpong:
2242
+ total_ctas *= 2
2243
+ # 128 bytes per tensormap
2244
+ tensormaps_torch = torch.empty(total_ctas, 128 // 8, dtype=torch.int64, device="cuda")
2245
+ tensormaps_tensor = from_dlpack(
2246
+ tensormaps_torch, assumed_align=128
2247
+ ).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
2248
+ else:
2249
+ tensormaps_tensor = None
2250
+
2251
+ gemm = HopperWgmmaGemmKernel(
2252
+ acc_dtype,
2253
+ a_dtype,
2254
+ tile_shape_mnk,
2255
+ cluster_shape_mnk,
2256
+ pingpong=pingpong,
2257
+ is_persistent=persistent,
2258
+ fp8_fast_accum=fp8_fast_accum,
2259
+ gather_A=gather_A,
2260
+ )
1307
2261
 
1308
- gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk)
2262
+ # Compute max active clusters on current device
2263
+ if persistent:
2264
+ max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
2265
+ cluster_shape_mn[0] * cluster_shape_mn[1]
2266
+ )
2267
+ if dynamic_persistent:
2268
+ tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda")
2269
+ else:
2270
+ tile_count_semaphore = None
2271
+ # max_active_clusters = 1
2272
+ else:
2273
+ max_active_clusters = 0
2274
+ tile_count_semaphore = None
1309
2275
 
1310
- torch_stream = torch.cuda.Stream()
1311
- stream = cuda.CUstream(torch_stream.cuda_stream)
2276
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
1312
2277
  # compile gemm kernel
1313
- compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
2278
+ compiled_gemm = cute.compile(
2279
+ gemm,
2280
+ mA,
2281
+ mB,
2282
+ mD,
2283
+ mC,
2284
+ mAIdx,
2285
+ mCuSeqlensM,
2286
+ tensormaps_tensor,
2287
+ make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
2288
+ if tile_count_semaphore is not None
2289
+ else None,
2290
+ max_active_clusters,
2291
+ current_stream,
2292
+ )
1314
2293
 
1315
2294
  if not skip_ref_check:
1316
2295
  # execution
1317
- compiled_gemm(mA, mB, mC, stream)
2296
+ compiled_gemm(
2297
+ mA,
2298
+ mB,
2299
+ mD,
2300
+ mC,
2301
+ mAIdx,
2302
+ mCuSeqlensM,
2303
+ tensormaps_tensor,
2304
+ tile_count_semaphore,
2305
+ max_active_clusters,
2306
+ current_stream,
2307
+ )
2308
+ if tile_count_semaphore is not None and varlen_m:
2309
+ tile_count_semaphore.zero_()
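+ # Presumably the dynamic scheduler leaves this counter at its final value, so it is
+ # reset here to let the next launch start scheduling from tile 0 again.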
1318
2310
 
1319
2311
  torch.cuda.synchronize()
1320
2312
 
1321
2313
  # Ref check
1322
- ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
2314
+ if not varlen_m:
2315
+ ref = torch.einsum("mkl,nkl->mnl", a, b)
2316
+ else:
2317
+ ref = torch.cat(
2318
+ [
2319
+ torch.einsum("mk,nk->mn", a[cu_seqlens_m[i] : cu_seqlens_m[i + 1]], b[:, :, i])
2320
+ for i in range(l)
2321
+ ],
2322
+ dim=0,
2323
+ )
2324
+ if c is not None:
2325
+ ref = ref + c
2326
+ ref = ref.cpu()
1323
2327
 
1324
2328
  if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
1325
2329
  # m major: (l, n, m) -> (m, n, l)
@@ -1333,79 +2337,112 @@ def run(
1333
2337
  init_type=cutlass_torch.TensorInitType.SKIP,
1334
2338
  ).cuda()
1335
2339
  # Create dtype cute tensor (gpu)
1336
- ref_c_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
2340
+ ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
1337
2341
  leading_dim=(1 if d_major == "n" else 0)
1338
2342
  )
1339
- ref_c_tensor.element_type = d_dtype
1340
- ref_c_tensor = cutlass_torch.convert_cute_tensor(
2343
+ ref_d_tensor.element_type = d_dtype
2344
+ ref_d_tensor = cutlass_torch.convert_cute_tensor(
1341
2345
  ref,
1342
- ref_c_tensor,
2346
+ ref_d_tensor,
1343
2347
  d_dtype,
1344
2348
  is_dynamic_layout=True,
1345
2349
  )
1346
- ref_c = f8_torch_tensor.cpu()
2350
+ ref_d = f8_torch_tensor.cpu()
1347
2351
  else:
1348
- ref_c = ref.to(cutlass_torch.dtype(d_dtype))
1349
-
1350
- torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
1351
-
1352
- def generate_tensors():
1353
- _, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
1354
- _, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
1355
- _, mC_workspace, _ = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
1356
- return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)
1357
-
1358
- workspace_count = 1
1359
- if use_cold_l2:
1360
- one_workspace_bytes = (
1361
- a_torch.numel() * a_torch.element_size()
1362
- + b_torch.numel() * b_torch.element_size()
1363
- + c_torch.numel() * c_torch.element_size()
1364
- )
1365
- workspace_count = testing.get_workspace_count(
1366
- one_workspace_bytes, warmup_iterations, iterations
1367
- )
2352
+ ref_d = ref.to(cutlass_torch.dtype(d_dtype))
1368
2353
 
1369
- exec_time = testing.benchmark(
1370
- compiled_gemm,
1371
- workspace_generator=generate_tensors,
1372
- workspace_count=workspace_count,
1373
- stream=stream,
1374
- warmup_iterations=warmup_iterations,
1375
- iterations=iterations,
1376
- )
2354
+ out = d_torch.cpu().squeeze()
2355
+ out_ref = ref_d.squeeze()
2356
+ # breakpoint()
2357
+ torch.testing.assert_close(d_torch.cpu(), ref_d, atol=tolerance, rtol=1e-03)
1377
2358
 
1378
- from triton.testing import do_bench
2359
+ # return
1379
2360
 
1380
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
2361
+ from triton.testing import do_bench
1381
2362
 
1382
2363
  flops = 2 * m * n * k * l
2364
+ # Calculate memory bandwidth
2365
+ bytes_A = m * k * l * (a_dtype.width // 8) # A tensor: (m, k, l)
2366
+ bytes_B = n * k * l * (b_dtype.width // 8) # B tensor: (n, k, l)
2367
+ bytes_D = m * n * l * (d_dtype.width // 8) # D tensor: (m, n, l)
2368
+ bytes_C = m * n * l * (c_dtype.width // 8) if c_dtype is not None else 0 # C tensor: (m, n, l)
2369
+ total_bytes = bytes_A + bytes_B + bytes_D + bytes_C # Read A, B, C; Write D
1383
2370
 
1384
- repeats = 30
1385
- # repeats = 1
1386
- warmup = 5
2371
+ repeats = iterations
2372
+ warmup = warmup_iterations
1387
2373
 
1388
2374
  import time
1389
2375
 
2376
+ if not varlen_m and not gather_A:
2377
+ time.sleep(0.5)
2378
+ if a_dtype.width == 8:
2379
+ assert l == 1
2380
+ scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda")
2381
+ fn_cublas = lambda: torch._scaled_mm(
2382
+ a_torch[:, :, 0],
2383
+ b_torch[:, :, 0].mT,
2384
+ scale_a=scale_ab,
2385
+ scale_b=scale_ab,
2386
+ out_dtype=torch.bfloat16,
2387
+ use_fast_accum=fp8_fast_accum,
2388
+ )
2389
+ else:
2390
+ if c_torch is None:
2391
+ fn_cublas = lambda: torch.matmul(
2392
+ a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT
2393
+ )
2394
+ else:
2395
+ c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32
2396
+ fn_cublas = lambda: torch.baddbmm(
2397
+ c_torch_convert.permute(2, 0, 1),
2398
+ a_torch.permute(2, 0, 1),
2399
+ b_torch.permute(2, 0, 1).mT,
2400
+ )
2401
+ timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
2402
+ tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
2403
+ print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2404
+
1390
2405
  time.sleep(0.5)
1391
- fn = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
1392
- timing_cublas = do_bench(fn, warmup=warmup, rep=repeats)
1393
- tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
1394
- print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
1395
2406
 
2407
+ def fn():
2408
+ compiled_gemm(
2409
+ mA,
2410
+ mB,
2411
+ mD,
2412
+ mC,
2413
+ mAIdx,
2414
+ mCuSeqlensM,
2415
+ tensormaps_tensor,
2416
+ tile_count_semaphore,
2417
+ max_active_clusters,
2418
+ current_stream,
2419
+ )
2420
+ if tile_count_semaphore is not None and varlen_m:
2421
+ tile_count_semaphore.zero_()
2422
+
2423
+ timing = do_bench(fn, warmup=warmup, rep=repeats)
2424
+ # Not sure why, but in some cases the first run is much slower
1396
2425
  time.sleep(0.5)
1397
- fn = lambda: compiled_gemm(mA, mB, mC, current_stream)
1398
2426
  timing = do_bench(fn, warmup=warmup, rep=repeats)
1399
2427
  tflops = flops / (timing * 1e9) # Convert to TFlops
1400
- print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
2428
+ gbps = total_bytes / (timing * 1e6) # Convert to GB/s: 1e9 B/GB divided by 1e3 ms/s gives the 1e6 factor
2429
+ print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}, GB/s: {gbps:.0f}")
2430
+ fn()
1401
2431
 
1402
- time.sleep(0.5)
1403
- fn = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
1404
- timing_cublas = do_bench(fn, warmup=warmup, rep=repeats)
1405
- tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
1406
- print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2432
+ if not varlen_m:
2433
+ time.sleep(0.5)
2434
+ timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
2435
+ tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
2436
+ print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
2437
+
2438
+ from flash_attn.utils.benchmark import pytorch_profiler
1407
2439
 
1408
- return exec_time # Return execution time in microseconds
2440
+ pytorch_profiler(fn_cublas)
2441
+ # pytorch_profiler(torch.sort, d_torch.squeeze(), dim=-1)
2442
+ # pytorch_profiler(torch.compile(torch.sort), d_torch.squeeze(), dim=-1)
2443
+ # pytorch_profiler(torch.topk, d_torch.squeeze(), dim=-1, k=1)
2444
+ # pytorch_profiler(torch.compile(torch.topk), d_torch.squeeze(), dim=-1, k=1)
2445
+ # pytorch_profiler(torch.square, d_torch.squeeze())
1409
2446
 
1410
2447
 
1411
2448
  if __name__ == "__main__":
@@ -1415,16 +2452,23 @@ if __name__ == "__main__":
1415
2452
  args.a_dtype,
1416
2453
  args.b_dtype,
1417
2454
  args.d_dtype,
2455
+ args.c_dtype,
1418
2456
  args.acc_dtype,
1419
2457
  args.a_major,
1420
2458
  args.b_major,
1421
2459
  args.d_major,
2460
+ args.c_major,
1422
2461
  args.tile_shape_mnk,
1423
2462
  args.cluster_shape_mn,
1424
2463
  args.tolerance,
1425
2464
  args.warmup_iterations,
1426
2465
  args.iterations,
1427
2466
  args.skip_ref_check,
1428
- args.use_cold_l2,
2467
+ args.persistent,
2468
+ args.dynamic_persistent,
2469
+ args.pingpong,
2470
+ args.varlen_m,
2471
+ args.gather_A,
2472
+ args.fp8_fast_accum,
1429
2473
  )
1430
2474
  print("PASS")