quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
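The diff below covers the renamed SM90 GEMM module (quack/dense_gemm_sm90.py → quack/gemm_sm90.py). As a point of reference while reading it, here is a minimal usage sketch assembled from the GemmSm90 docstring that appears in the diff; the import path is an assumption based on the new file name, and a_tensor, b_tensor, c_tensor, and stream stand for CuTe DSL tensors and a CUDA stream that the caller must construct — this is a sketch, not a verified example from the package.

    # Hypothetical sketch; mirrors the doctest in the GemmSm90 docstring shown below.
    from cutlass import Float32
    from quack.gemm_sm90 import GemmSm90  # assumed path after the 0.2.3 rename

    gemm = GemmSm90(
        acc_dtype=Float32,           # accumulate in fp32
        tile_shape_mn=(128, 256),    # CTA tile (M, N); K is filled in later by _setup_attributes
        cluster_shape_mnk=(1, 1, 1),
    )
    gemm(a_tensor, b_tensor, c_tensor, stream)  # tensors and stream supplied by the caller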
@@ -2,12 +2,10 @@
  # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py

  import enum
- from typing import Tuple, Type, Callable, Optional, Union
- from dataclasses import dataclass
+ from typing import Tuple, Type, Callable, Optional, Union, Literal
  from functools import partial
  import math

- from torch import Tensor

  import cuda.bindings.driver as cuda

@@ -16,10 +14,9 @@ import cutlass.cute as cute
  import cutlass.pipeline as pipeline
  from cutlass.cute.nvgpu import cpasync, warp, warpgroup
  import cutlass.utils.hopper_helpers as sm90_utils
- from cutlass import Int32, Float32, Boolean, const_expr
+ from cutlass import Int32, Float32, Float16, Boolean, const_expr
+ from cutlass.cutlass_dsl import if_generate
  from cutlass.utils import LayoutEnum
- import cutlass.torch as cutlass_torch
- from cutlass.cute.runtime import make_ptr


  from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
@@ -30,14 +27,12 @@ from quack.tile_scheduler import (
  VarlenMTileSchedulerArguments,
  VarlenMTileScheduler,
  )
- from quack.varlen_utils import VarlenArguments
- from quack.tensormap_manager import TensorMapManagerSm90
+ from quack.varlen_utils import VarlenArguments, VarlenManager

  # return PipelineStateWAdvance instead of PipelineState
  from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
- import quack.utils as utils
- from quack.cute_dsl_utils import get_max_active_clusters
- from quack.gemm_wrapper_utils import GemmWrapperBase
+ import quack.copy_utils as copy_utils
+ import quack.sm90_utils as quack_sm90_utils

  """
  A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -86,12 +81,13 @@ class NamedBarrierGemm(enum.IntEnum):
  MmaWG1 = enum.auto()
  EpiWG0 = enum.auto()
  EpiWG1 = enum.auto()
+ TmemPtr = enum.auto()


  class GemmSm90:
  """
  This class implements batched matrix multiplication (C = A x B) with support for various data types
- and architectural features specific to Hopper GPUs.
+ and architectural features specific to Hopper GPUs with persistent tile scheduling and warp specialization.

  :param acc_dtype: Data type for accumulation during computation
  :type acc_dtype: type[cutlass.Numeric]
@@ -118,24 +114,18 @@ class GemmSm90:

  Example:
  >>> gemm = GemmSm90(
- ... acc_dtype=cutlass.Float32,
+ ... acc_dtype=Float32,
  ... tile_shape_mn=(128, 256),
  ... cluster_shape_mnk=(1, 1, 1)
  ... )
  >>> gemm(a_tensor, b_tensor, c_tensor, stream)
  """

- bytes_per_tensormap = 128
+ arch = 90
+ num_epi_tensormaps: int = 0

- @dataclass
- class EpilogueArguments(ArgumentsBase):
- alpha: Optional[Float32 | cute.Tensor] = None
- beta: Optional[Float32 | cute.Tensor] = None
-
- @dataclass
- class EpilogueParams(ParamsBase):
- alpha: Optional[Float32 | cute.Tensor] = None
- beta: Optional[Float32 | cute.Tensor] = None
+ EpilogueArguments = ArgumentsBase
+ EpilogueParams = ParamsBase

  def __init__(
  self,
@@ -174,8 +164,8 @@ class GemmSm90:

  self.cluster_shape_mnk = cluster_shape_mnk
  # K dimension is deferred in _setup_attributes
- self.tile_shape_mnk = (*tile_shape_mn, 1)
- tile_M, tile_N = self.tile_shape_mnk[0], self.tile_shape_mnk[1]
+ self.cta_tile_shape_mnk = (*tile_shape_mn, 1)
+ tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
  # check the cta tile shape
  if not self.pingpong:
  if tile_M not in [64, 128, 192, 256, 320]:
@@ -209,14 +199,18 @@ class GemmSm90:
  else:
  atom_layout_m, atom_layout_n = 1, 2
  else:
- atom_layout_m = self.tile_shape_mnk[0] // 64 if self.tile_shape_mnk[0] < 256 else 2
+ atom_layout_m = (
+ self.cta_tile_shape_mnk[0] // 64 if self.cta_tile_shape_mnk[0] < 256 else 2
+ )
  atom_layout_n = 1
  assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
  else:
  atom_layout_m, atom_layout_n = 1, 1
  self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)

- self.num_mcast_ctas_a = self.cluster_shape_mnk[1] if not self.gather_A else 1
+ self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
+ if self.gather_A:
+ assert self.num_mcast_ctas_a == 1
  self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
  self.is_a_mcast = self.num_mcast_ctas_a > 1
  self.is_b_mcast = self.num_mcast_ctas_b > 1
@@ -229,16 +223,13 @@ class GemmSm90:
  self.num_threads_per_warp_group = 128
  self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
  self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
- self.num_epi_threads = (
- self.mma_warp_groups if not self.pingpong else 1
- ) * self.num_threads_per_warp_group
+ self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
  self.num_ab_load_warps = 1 if not self.gather_A else 4
- self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
- self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
  self.ab_load_warp_id = self.mma_warp_groups * 4
- self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
+ # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
+ # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps

- regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // (
+ regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
  math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
  )
  if self.fp8_slow_accum:
@@ -268,7 +259,7 @@ class GemmSm90:
  self.shared_storage = None
  self.buffer_align_bytes = 1024

- def _setup_attributes(self, epilogue_args: Optional[EpilogueArguments]):
+ def _setup_attributes(self, epilogue_args: EpilogueArguments):
  """Set up configurations that are dependent on GEMM inputs

  This method configures various attributes based on the input tensor properties
@@ -289,7 +280,7 @@ class GemmSm90:
  self.b_layout.sm90_mma_major_mode(),
  self.acc_dtype,
  self.atom_layout_mnk,
- tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
+ tiler_mn=(64, self.cta_tile_shape_mnk[1] // self.atom_layout_mnk[1]),
  )
  if const_expr(self.atom_layout_mnk[1] > 1):
  # If N dimension is split among 2 WGs, we need to permute the N dimension so
@@ -299,7 +290,7 @@ class GemmSm90:
  # WG1 would write to a separate epi smem of size (64, 16) that's far away.
  atom_n = self.atom_layout_mnk[1]
  permutation_n = cute.make_ordered_layout(
- (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
+ (8, self.cta_tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
  )
  self.tiled_mma = cute.make_tiled_mma(
  cute.make_mma_atom(self.tiled_mma.op),
@@ -308,30 +299,30 @@ class GemmSm90:
  )
  mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
  mma_inst_tile_k = 4
- self.tile_shape_mnk = (
- self.tile_shape_mnk[0],
- self.tile_shape_mnk[1],
+ self.cta_tile_shape_mnk = (
+ self.cta_tile_shape_mnk[0],
+ self.cta_tile_shape_mnk[1],
  mma_inst_shape_k * mma_inst_tile_k,
  )

  self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)

  self.epi_tile = self._sm90_compute_tile_shape_or_override(
- self.tile_shape_mnk,
+ self.cta_tile_shape_mnk,
  self.atom_layout_mnk,
  self.d_dtype,
  )

  # Compute stage before compute smem layout
  self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
- self.tile_shape_mnk,
+ self.cta_tile_shape_mnk,
  self.epi_tile,
  self.a_dtype,
  self.b_dtype,
  self.d_dtype,
  self.c_dtype,
  epilogue_args,
- self.smem_capacity,
+ cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
  self.occupancy,
  # epi_smem will reuse smem ab if not persistent.
  overlap_sD_sA=not self.is_persistent,
@@ -344,7 +335,7 @@ class GemmSm90:
  self.epi_smem_layout_staged,
  self.epi_c_smem_layout_staged,
  ) = self._make_smem_layouts(
- self.tile_shape_mnk,
+ self.cta_tile_shape_mnk,
  self.epi_tile,
  self.a_dtype,
  self.a_layout,
@@ -366,10 +357,9 @@ class GemmSm90:
  mB: cute.Tensor,
  mD: Optional[cute.Tensor],
  mC: Optional[cute.Tensor],
- epilogue_args: Optional[ArgumentsBase],
+ epilogue_args: ArgumentsBase,
  scheduler_args: TileSchedulerOptions,
  varlen_args: Optional[VarlenArguments],
- mAIdx: Optional[cute.Tensor],
  stream: cuda.CUstream,
  ):
  """Execute the GEMM operation in steps:
@@ -405,7 +395,10 @@ class GemmSm90:
  raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
  if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
  raise TypeError("a_dtype should be float16 or float8")
- assert (mAIdx is not None) == self.gather_A
+
+ if const_expr(varlen_args is None):
+ varlen_args = VarlenArguments()
+ assert (varlen_args.mAIdx is not None) == self.gather_A

  # Assume all strides are divisible by 128 bits except the last stride
  new_stride = lambda t: tuple(
@@ -421,77 +414,48 @@ class GemmSm90:

  self._setup_attributes(epilogue_args)

+ a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0))
+ b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0))
+ tma_atom_a, tma_tensor_a = None, None
  if const_expr(not self.gather_A):
  tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
  mA,
- self.a_smem_layout_staged,
- (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
+ a_smem_layout,
+ (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
  self.cluster_shape_mnk[1],
  )
- else:
- tma_atom_a, tma_tensor_a = None, None
-
  tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
  mB,
- self.b_smem_layout_staged,
- (self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
+ b_smem_layout,
+ (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
  self.cluster_shape_mnk[0],
  )

+ self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+ if const_expr(not self.gather_A):
+ self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
+
+ tma_atom_d, tma_tensor_d = None, None
  if const_expr(mD is not None):
  tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
- mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
+ mD,
+ self.epi_smem_layout_staged,
+ self.epi_tile,
+ op_type="store"
+ if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
+ else "add",
  )
- else:
- tma_atom_d, tma_tensor_d = None, None
-
+ tma_atom_c, tma_tensor_c = None, None
  if const_expr(mC is not None):
  tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
- mC, self.epi_c_smem_layout_staged, self.epi_tile, store_or_load="load"
+ mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
  )
- else:
- tma_atom_c, tma_tensor_c = None, None

  epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
+ varlen_params = VarlenManager.to_underlying_arguments(varlen_args)

- if const_expr(varlen_args is None):
- varlen_args = VarlenArguments()
- if const_expr(varlen_args.mCuSeqlensM is None):
- num_problems = (
- mD.shape[2]
- if mD is not None
- else (
- mB.shape[2]
- if varlen_args.mCuSeqlensK is None
- else varlen_args.mCuSeqlensK.shape[0] - 1
- )
- )
- problem_shape_ntile_mnl = (
- cute.ceil_div(mA.shape[0], self.tile_shape_mnk[0]),
- cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
- num_problems,
- )
- TileSchedulerCls = self.get_scheduler_class()
- tile_sched_args = self.get_scheduler_arguments(problem_shape_ntile_mnl, scheduler_args)
- else:
- assert mD is not None or not self.gather_A
- problem_shape_ntile_mnl = (
- None,
- cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
- varlen_args.mCuSeqlensM.shape[0] - 1,
- )
- TileSchedulerCls = VarlenMTileScheduler
- tile_sched_args = VarlenMTileSchedulerArguments(
- problem_shape_ntile_mnl=problem_shape_ntile_mnl,
- total_m=mD.shape[0] if mD is not None else mAIdx.shape[0],
- cu_seqlens_m=varlen_args.mCuSeqlensM,
- raster_order=scheduler_args.raster_order,
- group_size=scheduler_args.max_swizzle_size,
- tile_shape_mn=self.tile_shape_mnk[:2],
- cluster_shape_mnk=self.cluster_shape_mnk,
- tile_count_semaphore=scheduler_args.tile_count_semaphore,
- is_persistent=self.is_persistent,
- )
+ TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
+ tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
  tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
  grid = TileSchedulerCls.get_grid_shape(
  tile_sched_params, scheduler_args.max_active_clusters
@@ -507,7 +471,7 @@ class GemmSm90:
  ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
  epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
  sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
- tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
+ tile_count: cute.struct.MemRange[Int32, self.sched_stage]
  sD: cute.struct.Align[
  cute.struct.MemRange[
  self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
@@ -534,6 +498,7 @@ class GemmSm90:

  # Launch the kernel synchronously
  self.kernel(
+ self.tiled_mma,
  tma_atom_a,
  tma_tensor_a if const_expr(not self.gather_A) else mA,
  tma_atom_b,
@@ -543,11 +508,7 @@ class GemmSm90:
  tma_atom_c,
  tma_tensor_c,
  epilogue_params,
- mAIdx,
- varlen_args.mCuSeqlensM,
- varlen_args.mCuSeqlensK,
- varlen_args.mTensormaps,
- self.tiled_mma,
+ varlen_params,
  self.cluster_layout_mnk,
  self.a_smem_layout_staged,
  self.b_smem_layout_staged,
@@ -559,7 +520,6 @@ class GemmSm90:
  grid=grid,
  block=[self.threads_per_cta, 1, 1],
  cluster=self.cluster_shape_mnk,
- smem=self.shared_storage.size_in_bytes(),
  stream=stream,
  min_blocks_per_mp=1,
  )
@@ -569,6 +529,7 @@ class GemmSm90:
  @cute.kernel
  def kernel(
  self,
+ tiled_mma: cute.TiledMma,
  tma_atom_a: Optional[cute.CopyAtom],
  mA_mkl: cute.Tensor,
  tma_atom_b: cute.CopyAtom,
@@ -578,11 +539,7 @@ class GemmSm90:
  tma_atom_c: Optional[cute.CopyAtom],
  mC_mnl: Optional[cute.Tensor],
  epilogue_params: ParamsBase,
- mAIdx: Optional[cute.Tensor],
- cu_seqlens_m: Optional[cute.Tensor],
- cu_seqlens_k: Optional[cute.Tensor],
- tensormaps: Optional[cute.Tensor],
- tiled_mma: cute.TiledMma,
+ varlen_params: VarlenManager.Params,
  cluster_layout_mnk: cute.Layout,
  a_smem_layout: cute.ComposedLayout,
  b_smem_layout: cute.ComposedLayout,
@@ -618,9 +575,11 @@ class GemmSm90:
  :type epi_smem_layout: cute.ComposedLayout
  """

- varlen_m = const_expr(cu_seqlens_m is not None)
- varlen_k = const_expr(cu_seqlens_k is not None)
+ varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
+ varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
  assert not (varlen_m and varlen_k)
+ if const_expr(self.gather_A):
+ assert varlen_m or varlen_k
  has_D = const_expr(mD_mnl is not None)
  has_C = const_expr(mC_mnl is not None)

@@ -641,8 +600,6 @@ class GemmSm90:
  storage = smem.allocate(self.shared_storage)

  ab_pipeline = self.make_ab_pipeline(
- a_smem_layout=cute.slice_(a_smem_layout, (None, None, 0)),
- b_smem_layout=cute.slice_(b_smem_layout, (None, None, 0)),
  tiled_mma=tiled_mma,
  cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
  ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
@@ -681,28 +638,20 @@ class GemmSm90:
  sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
  epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)

- # Get tensormap buffer address
- tensormap_manager = None
- tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
- if const_expr(varlen_m or varlen_k):
- tensormap_manager = TensorMapManagerSm90(
- cutlass.utils.TensorMapUpdateMode.GMEM, GemmSm90.bytes_per_tensormap
- )
- # equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
- tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
- if const_expr(varlen_m):
- tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
- tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
- tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
- )
- else:
- assert varlen_k
- tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
- tensormaps[tensormap_workspace_idx, 0, None].iterator
- )
- tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
- tensormaps[tensormap_workspace_idx, 1, None].iterator
- )
+ varlen_manager = VarlenManager.create(
+ varlen_params,
+ has_D,
+ self.num_epi_tensormaps,
+ # Only used if not varlen_m
+ len_m_static=Int32(
+ mA_mkl.shape[0]
+ if varlen_k or varlen_params.mAIdx is None
+ else varlen_params.mAIdx.shape[0]
+ ),
+ len_k_static=Int32(mA_mkl.shape[1]),
+ pingpong=self.pingpong,
+ warp_idx=warp_idx,
+ )

  TileSchedulerCls = partial(
  TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
@@ -715,28 +664,20 @@ class GemmSm90:
  and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
  ):
  is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
- if const_expr(varlen_k):
- # initialize tensormap for A & B
- tensormap_manager.init_tensormap_from_atom(
- tma_atom_a,
- tensormap_a_ptr,
- is_tma_warp,
- )
- tensormap_manager.init_tensormap_from_atom(
- tma_atom_b,
- tensormap_b_ptr,
- is_tma_warp,
- )
+ # initialize tensormap for A & B
+ varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
+ tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
+ tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
  # ///////////////////////////////////////////////////////////////////////////////
  # Get mcast mask
  # ///////////////////////////////////////////////////////////////////////////////
  cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
- cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
+ block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
  a_mcast_mask = cute.make_layout_image_mask(
- cluster_layout_mnk, cluster_coord_mnk, mode=1
+ cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1
  )
  b_mcast_mask = cute.make_layout_image_mask(
- cluster_layout_mnk, cluster_coord_mnk, mode=0
+ cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0
  )
  a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
  b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
@@ -752,177 +693,129 @@ class GemmSm90:
  )
  if const_expr(varlen_k):
  # wait tensormap initialization complete before update
- tensormap_manager.fence_tensormap_initialization()
- # batch index of last tile
- last_batch_idx = cutlass.Int32(-1)
+ varlen_manager.fence_tensormap_init()
  while work_tile.is_valid_tile:
  tile_coord_mnkl = work_tile.tile_idx
  batch_idx = tile_coord_mnkl[3]
- if const_expr(varlen_k):
- is_group_changed = batch_idx != last_batch_idx
- last_batch_idx = batch_idx
- if is_group_changed:
- # construct tensor A/B based on real address, shape and stride information
- tensormap_manager.update_tensormap_shape(
- (tensormap_a_ptr, tensormap_b_ptr),
- is_manager_warp=is_tma_warp,
- shapes=(cu_seqlens_k[batch_idx + 1], cu_seqlens_k[batch_idx + 1]),
- orders=(
- 0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1,
- 0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1,
- ),
- tensormap_smem_ptr=None,
- )
+ varlen_manager.update_tensormap_AB(
+ batch_idx,
+ self.a_layout,
+ self.b_layout,
+ is_tma_warp,
+ )
  # ///////////////////////////////////////////////////////////////////////////
  # Local_tile partition global tensors
  # ///////////////////////////////////////////////////////////////////////////
  if const_expr(not self.gather_A):
- if const_expr(varlen_m):
- mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
- elif const_expr(varlen_k):
- mA_mk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mA_mkl)
- else:
- mA_mk = mA_mkl[None, None, batch_idx]
+ mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
  # (bM, bK, RestK)
- gA_k = cute.local_tile(
+ gA_mk = cute.local_tile(
  mA_mk,
- cute.select(self.tile_shape_mnk, [0, 2]),
+ cute.select(self.cta_tile_shape_mnk, [0, 2]),
  (tile_coord_mnkl[0], None),
  )
  else:
- mA_mk = mA_mkl
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
  if const_expr(varlen_m):
- mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
- elif const_expr(varlen_k):
- mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
+ gAIdx = cute.local_tile(
+ mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
+ )
+ # (M, K)
+ mA_mk = mA_mkl
  else:
- mAIdx_mk = mAIdx[None, batch_idx]
- gAIdx = cute.local_tile(
- mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
- )
- if const_expr(varlen_k):
- mB_nk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mB_nkl)
- else:
- mB_nk = mB_nkl[None, None, batch_idx]
+ assert varlen_k
+ # (tile_K, RestK)
+ gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
+ # (tile_M, K)
+ mA_mk = cute.local_tile(
+ mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
+ )
  # (bN, bK, RestK)
- gB_k = cute.local_tile(
- mB_nk, cute.select(self.tile_shape_mnk, [1, 2]), (tile_coord_mnkl[1], None)
+ gB_nk = cute.local_tile(
+ varlen_manager.offset_batch_B(mB_nkl, batch_idx),
+ cute.select(self.cta_tile_shape_mnk, [1, 2]),
+ (tile_coord_mnkl[1], None),
  )
  # //////////////////////////////////////////////////////////////////////////
  # Partition shared tensor for TMA load A/B
  # //////////////////////////////////////////////////////////////////////////
- if const_expr(varlen_k):
- # ensure the update to tensormap has completed before using it
- if is_group_changed and is_tma_warp:
- tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
- tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
- tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
- tensormap_a_ptr, cute.AddressSpace.generic
- )
- tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
- tensormap_b_ptr, cute.AddressSpace.generic
- )
- else:
- tma_desc_a_ptr, tma_desc_b_ptr = None, None
+ varlen_manager.fence_tensormap_update_AB(is_tma_warp)
+ len_m = varlen_manager.len_m(batch_idx)
+ len_k = varlen_manager.len_k(batch_idx)
  # TMA load A partition_S/D
- a_cta_layout = cute.make_layout(
- cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
- )
- a_cta_crd = cluster_coord_mnk[1]
+ copy_A = None
  if const_expr(not self.gather_A):
- # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
- tAsA, tAgA_k = cpasync.tma_partition(
- tma_atom_a,
- a_cta_crd,
- a_cta_layout,
- cute.group_modes(sA, 0, 2),
- cute.group_modes(gA_k, 0, 2),
- )
- copy_A = partial(
- cute.copy,
+ copy_A, _, _ = copy_utils.tma_get_copy_fn(
  tma_atom_a,
+ cta_coord=block_in_cluster_coord_mnk[1],
+ cta_layout=cute.make_layout(
+ cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
+ ),
+ src_tensor=gA_mk,
+ dst_tensor=sA,
  mcast_mask=a_mcast_mask,
  tma_desc_ptr=tma_desc_a_ptr,
  )
  else:
  tiled_copy_A = self._make_gmem_tiled_copy_A(
- mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
+ mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
  )
  tidx = (
- cute.arch.thread_idx()[0]
- - self.mma_warp_groups * self.num_threads_per_warp_group
+ cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
  )
  thr_copy_A = tiled_copy_A.get_slice(tidx)
- # (atom_v, CPY_M, 1, STAGE)
- tAsA = thr_copy_A.partition_D(sA)
- assert tAsA.shape[2] == 1
- tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
- copy_A = partial(cute.copy, tiled_copy_A)
+ copy_A, prefetch_A = None, None
+ if const_expr(varlen_m):
+ copy_A = copy_utils.gather_m_get_copy_fn(
+ thr_copy_A,
+ mA_mk,
+ sA,
+ gAIdx,
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+ limit_k=len_k,
+ )
+ else:
+ copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
+ thr_copy_A,
+ mA_mk,
+ sA,
+ gAIdx,
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+ limit_k=len_k,
+ )
  # TMA load B partition_S/D
- b_cta_layout = cute.make_layout(
- cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
- )
- b_cta_crd = cluster_coord_mnk[0]
- # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
- tBsB, tBgB_k = cpasync.tma_partition(
+ copy_B, _, _ = copy_utils.tma_get_copy_fn(
  tma_atom_b,
- b_cta_crd,
- b_cta_layout,
- cute.group_modes(sB, 0, 2),
- cute.group_modes(gB_k, 0, 2),
- )
- copy_B = partial(
- cute.copy, tma_atom_b, mcast_mask=b_mcast_mask, tma_desc_ptr=tma_desc_b_ptr
+ cta_coord=block_in_cluster_coord_mnk[0],
+ cta_layout=cute.make_layout(
+ cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
+ ),
+ src_tensor=gB_nk,
+ dst_tensor=sB,
+ mcast_mask=b_mcast_mask,
+ tma_desc_ptr=tma_desc_b_ptr,
  )
- k_len = (
- cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
- if const_expr(varlen_k)
- else mA_mkl.shape[1]
- )
- k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
  if const_expr(not self.gather_A):
  ab_producer_state = self.load_AB(
- ab_pipeline,
- ab_producer_state,
- copy_A,
- tAgA_k,
- tAsA,
- copy_B,
- tBgB_k,
- tBsB,
- k_tile_cnt,
+ ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
  )
  else:
- limit_m = (
- mAIdx.shape[0]
- if const_expr(cu_seqlens_m is None)
- else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
- )
  ab_producer_state = self.load_AB_gather_A(
  ab_pipeline,
  ab_producer_state,
- thr_copy_A,
- mA_mk,
- tAsA,
- gAIdx,
+ copy_A,
+ prefetch_A,
  copy_B,
- tBgB_k,
- tBsB,
  k_tile_cnt,
- limit_A=(
- limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
- mA_mk.shape[1],
- ),
+ varlen_m=varlen_m,
  )
  tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
- tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
  tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
  work_tile = tile_scheduler.get_current_work()
  # End of persistent scheduler loop
  if const_expr(self.pingpong and not varlen_k):
  # Need to write the tile_idx to smem for the next WG in the pingpong mode
- # tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
- tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
  tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
  ab_pipeline.producer_tail(ab_producer_state)
  if is_scheduler_warp:
@@ -934,13 +827,11 @@ class GemmSm90:
  (not self.pingpong and warp_idx == 0)
  or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
  )
- if const_expr(varlen_m):
- # initialize tensormap for D
- tensormap_manager.init_tensormap_from_atom(
- tma_atom_d,
- tensormap_d_ptr,
- is_manager_warp=is_tma_warp,
- )
+ varlen_manager.init_tensormap_epi(
+ tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
+ )
+ tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
+ tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
  # //////////////////////////////////////////////////////////////////////////////
  # Partition global tensor for TiledMMA_A/B/C
  # //////////////////////////////////////////////////////////////////////////////
@@ -962,7 +853,9 @@ class GemmSm90:
  tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
  tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))

- acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
+ acc_shape = tiled_mma.partition_shape_C(
+ cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
+ )
  acc = cute.make_fragment(acc_shape, self.acc_dtype)
  acc_slow = None
  if const_expr(self.fp8_slow_accum):
@@ -974,10 +867,11 @@ class GemmSm90:
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")

- k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.tile_shape_mnk[2])
- c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
+ k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2])
+ c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))

  ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
+ epi_store_pipeline = self.make_epi_store_pipeline()
  epi_read_state = make_pipeline_state(
  pipeline.PipelineUserType.Consumer, self.epi_c_stage
  )
@@ -996,9 +890,8 @@ class GemmSm90:
  if const_expr(not varlen_k):
  ab_read_state.advance_iters(k_tile_cnt_static)
  else:
- batch_idx = work_tile.tile_idx[3]
- k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
- k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+ len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
  ab_read_state.advance_iters(k_tile_cnt)
  tile_scheduler.advance_to_next_work()
  if const_expr(varlen_k):
@@ -1009,31 +902,22 @@ class GemmSm90:
  work_tile = tile_scheduler.initial_work_tile_info()
  if const_expr(varlen_m):
  # wait tensormap initialization complete before update
- tensormap_manager.fence_tensormap_initialization()
- # batch index of last tile
- last_batch_idx = cutlass.Int32(-1)
+ varlen_manager.fence_tensormap_init()
  while work_tile.is_valid_tile:
  tile_coord_mnkl = work_tile.tile_idx
  batch_idx = tile_coord_mnkl[3]
- if const_expr(varlen_m):
- is_group_changed = batch_idx != last_batch_idx
- last_batch_idx = batch_idx
- if is_group_changed:
- # construct tensor D based on real address, shape and stride information
- tensormap_manager.update_tensormap_shape(
- (tensormap_d_ptr,),
- is_manager_warp=is_tma_warp,
- shapes=(cu_seqlens_m[batch_idx + 1],),
- orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
- tensormap_smem_ptr=None,
- )
-
- k_len = (
- cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
- if const_expr(varlen_k)
- else mA_mkl.shape[1]
+ epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
+ epilogue_params, varlen_params.cu_seqlens_m, batch_idx
+ )
+ varlen_manager.update_tensormap_epi(
+ batch_idx,
+ self.d_layout,
+ epi_shapes,
+ epi_orders,
+ is_tma_warp,
  )
- k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+ len_k = varlen_manager.len_k(batch_idx)
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
  ab_read_state, tiled_mma = self.mma(
  ab_pipeline,
  ab_read_state,
@@ -1056,51 +940,42 @@ class GemmSm90:
  self.pingpong_barrier_sync(warp_group_idx, "epi")

  epilogue_barrier = pipeline.NamedBarrier(
- barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
+ barrier_id=int(NamedBarrierGemm.Epilogue),
+ num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
  )

- if const_expr(varlen_m):
- # ensure the update to tensormap has completed before using it
- if is_group_changed and is_tma_warp:
- tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
- tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
- tensormap_d_ptr, cute.AddressSpace.generic
- )
- else:
- tma_desc_d_ptr = None
+ varlen_manager.fence_tensormap_update_epi(is_tma_warp)

+ copy_D = None
  if const_expr(has_D):
- bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(
+ copy_D, _, _ = self.epilog_gmem_copy_and_partition(
  tma_atom_d,
- mD_mnl,
- self.tile_shape_mnk[:2],
+ varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
+ self.cta_tile_shape_mnk[:2],
  self.epi_tile,
  sD,
  tile_coord_mnkl,
- cu_seqlens_m,
+ tma_desc_ptr=tma_desc_d_ptr,
  )
- copy_D = partial(cute.copy, tma_atom_d, tma_desc_ptr=tma_desc_d_ptr)
- else:
- bSG_sD, bSG_gD, copy_D = None, None, None
+ copy_C = None
  if const_expr(has_C):
- bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
+ copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
  tma_atom_c,
- mC_mnl,
- self.tile_shape_mnk[:2],
+ varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
+ self.cta_tile_shape_mnk[:2],
  self.epi_tile,
  sC,
  tile_coord_mnkl,
- cu_seqlens_m,
  )
- copy_C = partial(cute.copy, tma_atom_c)
- epi_load_g2s = partial(self.epi_load_g2s, epi_pipeline, copy_C, bGS_gC, bGS_sC)
- else:
- epi_load_g2s = None
+ copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)

  d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
- tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
- tiled_mma, self.d_layout, d_dtype_for_layout, acc, sD, tidx
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
+ tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
  )
+ # (R2S, R2S_M, R2S_N)
+ tRS_rAcc = tiled_copy_r2s.retile(acc)
+ load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
  if const_expr(has_C):
  tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
  tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
@@ -1118,24 +993,25 @@ class GemmSm90:
  epi_read_state, epi_producer_state = self.epilogue(
  epilogue_params,
  epi_smem_tensors,
+ tma_desc_epi_ptrs,
  epi_pipeline,
+ epi_store_pipeline,
  epi_read_state,
  epi_producer_state,
- tiled_mma,
- tRS_rAcc,
+ self.epi_tile,
+ load_acc_subtile,
  tRS_rD,
  tRS_rC,
+ None, # tiled_copy_t2r, for Sm100 only
  tiled_copy_r2s,
  tRS_sD,
  tiled_copy_s2r,
  tSR_rC,
  tSR_sC,
  copy_D,
- bSG_sD,
- bSG_gD,
- epi_load_g2s,
+ copy_C,
  tile_coord_mnkl,
- cu_seqlens_m,
+ varlen_manager,
  epilogue_barrier,
  tile_scheduler,
  tidx,
@@ -1147,7 +1023,7 @@ class GemmSm90:
  # so we have to make sure the smem content is done reading before signaling
  # the next WG's epilogue.
  if is_tma_warp:
- cute.arch.cp_async_bulk_wait_group(0, read=True)
+ epi_store_pipeline.producer_tail()
  self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")

  if const_expr(not self.pingpong):
@@ -1166,31 +1042,33 @@ class GemmSm90:
  tile_scheduler.advance_to_next_work()
  work_tile = tile_scheduler.get_current_work()
  if work_tile.is_valid_tile:
- batch_idx = work_tile.tile_idx[3]
- k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
- k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+ len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
  ab_read_state.advance_iters(k_tile_cnt)
  tile_scheduler.advance_to_next_work()
  work_tile = tile_scheduler.get_current_work()
  # End of persistent scheduler loop

+ # Wait for D store complete
  if const_expr(not self.pingpong):
  if is_tma_warp:
- cute.arch.cp_async_bulk_wait_group(0, read=True)
+ epi_store_pipeline.producer_tail()

  @cute.jit
  def load_AB(
  self,
  ab_pipeline: cutlass.pipeline.PipelineAsync,
  ab_producer_state: cutlass.pipeline.PipelineState,
- copy_A: Callable,
- tAgA: cute.Tensor,
- tAsA: cute.Tensor,
+ copy_A: Optional[Callable],
  copy_B: Callable,
- tBgB: cute.Tensor,
- tBsB: cute.Tensor,
  k_tile_cnt: Int32,
+ # These are for Sm100 blockscaled gemm
+ copy_SFA: Optional[Callable] = None,
+ copy_SFB: Optional[Callable] = None,
  ) -> cutlass.pipeline.PipelineState:
+ blockscaled = const_expr(copy_SFA is not None)
+ if const_expr(blockscaled):
+ assert copy_SFB is not None
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
  peek_ab_empty_status = Boolean(True)
  if 0 < k_tile_cnt:
@@ -1203,8 +1081,13 @@ class GemmSm90:
  # Also sets the transaction barrier for the A/B buffers
  ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
  tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
- copy_A(tAgA[None, k_tile], tAsA[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
- copy_B(tBgB[None, k_tile], tBsB[None, ab_producer_state.index], tma_bar_ptr=tma_bar_ptr)
+ smem_idx = ab_producer_state.index
+ if const_expr(copy_A is not None):
+ copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+ copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+ if const_expr(blockscaled):
+ copy_SFA(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+ copy_SFB(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
  # Mainloop pipeline's producer commit is a NOP
  ab_pipeline.producer_commit(ab_producer_state)
  ab_producer_state.advance()
@@ -1218,38 +1101,12 @@ class GemmSm90:
  self,
  ab_pipeline: cutlass.pipeline.PipelineAsync,
  ab_producer_state: cutlass.pipeline.PipelineState,
- thr_copy_A: cute.core.ThrCopy,
- mA: cute.Tensor,
- tAsA: cute.Tensor,
- gAIdx: cute.Tensor,
+ copy_A: Callable,
+ prefetch_A: Optional[Callable],
  copy_B: Callable,
- tBgB: cute.Tensor,
- tBsB: cute.Tensor,
  k_tile_cnt: Int32,
- limit_A: Tuple[Int32, Int32],
+ varlen_m: bool = True,
  ) -> cutlass.pipeline.PipelineState:
- # (atom_v, CPY_M, 1, RestK)
- limit_m, limit_k = limit_A
- limit_m = min(limit_m, self.tile_shape_mnk[0]) # To avoid writing beyond smem limit
- cA = cute.make_identity_tensor(cute.select(self.tile_shape_mnk, [0, 2]))
- tAcA = thr_copy_A.partition_S(cA)
- t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
- # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
- # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
- # This is so that when we do the comparison, t0AcA is known at compile time.
- limit_m = limit_m - tAcA[0][0]
- # Read indices for A
- rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
- m_idx = cute.make_fragment(rows_per_thread, Int32)
- for m in cutlass.range(rows_per_thread):
- row_idx = tAcA[0, m, 0][0]
- if t0AcA[0, m, 0][0] < limit_m:
- m_idx[m] = gAIdx[row_idx]
- else:
- m_idx[m] = -1
- elems_per_load = cute.size(tAsA.shape[0][0])
- # (m, (bK, RestK))
- mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
  # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
  peek_ab_empty_status = Boolean(True)
@@ -1258,35 +1115,27 @@ class GemmSm90:
  # /////////////////////////////////////////////////////////////////////////
  # TMA load on B and cp.async on A
  # /////////////////////////////////////////////////////////////////////////
- copy_A = partial(cute.copy, thr_copy_A)
  for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
+ prefetch_out = ()
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+ prefetch_out = (prefetch_A(k_tile),)
  # Wait for A/B buffers to be empty before loading into them
  # Also sets the transaction barrier for the A/B buffers
- ab_pipeline.producer_acquire(
- ab_producer_state,
- peek_ab_empty_status,
- # A tiny bit faster to rotate the warp that does TMA
- is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
+ # A tiny bit faster to rotate the warp that does TMA
+ # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id
+ # since that's the warp that does the tensormap update.
+ is_tma_warp = warp_idx == self.ab_load_warp_id + (
+ (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
  )
- # A bit faster to load B first while we calculate the predicate for A
- if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
- copy_B(
- tBgB[None, k_tile],
- tBsB[None, ab_producer_state.index],
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
- )
- # (m, bK)
- mA_cur = mA_k[None, (None, k_tile)]
- for m in cutlass.range_constexpr(tAcA.shape[1]):
- # (elems_per_load, thread_per_row)
- mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
- if t0AcA[0, m, 0][0] < limit_m:
- # There's only 1 load per row
- assert cute.size(tAcA.shape, mode=[2]) == 1
- ki = tAcA[0, 0, 0][1] // elems_per_load
- copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+ smem_idx = ab_producer_state.index
+ # A bit faster to load B first while we calculate the indices for A
+ if is_tma_warp:
+ tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
+ copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+ copy_A(k_tile, smem_idx, *prefetch_out)
  # This tells mbarrier to track the completion of cp.async
- ab_pipeline.producer_commit(ab_producer_state)
+ ab_pipeline.producer_cpasync_commit(ab_producer_state)
  ab_producer_state.advance()
  peek_ab_empty_status = Boolean(True)
  if k_tile + 1 < k_tile_cnt:
@@ -1294,33 +1143,19 @@ class GemmSm90:
  # bound checking in the K dimension on the last k_tile
  if 0 < k_tile_cnt:
  k_tile = k_tile_cnt - 1
- ab_pipeline.producer_acquire(
- ab_producer_state,
- peek_ab_empty_status,
- is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
+ prefetch_out = ()
+ if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+ prefetch_out = (prefetch_A(k_tile, pred=True),)
+ is_tma_warp = warp_idx == self.ab_load_warp_id + (
+ (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
  )
- if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
- copy_B(
- tBgB[None, k_tile],
- tBsB[None, ab_producer_state.index],
- tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
- )
- assert tAcA.shape[2] == 1 # there's only 1 load along the K dimension
- tApA = cute.make_fragment(1, Boolean)
- tApA[0] = tAcA[0, 0, 0][1] < limit_k
- # (m, bK)
- mA_cur = mA_k[None, (None, k_tile)]
- for m in cutlass.range_constexpr(tAcA.shape[1]):
- # (elems_per_load, thread_per_row)
- mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
- if t0AcA[0, m, 0][0] < limit_m:
- # There's only 1 load per row
- assert cute.size(tAcA.shape, mode=[2]) == 1
- ki = tAcA[0, 0, 0][1] // elems_per_load
- # copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA)
- # TODO
- copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
- ab_pipeline.producer_commit(ab_producer_state)
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+ smem_idx = ab_producer_state.index
+ if is_tma_warp:
+ tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
+ copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+ copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
+ ab_pipeline.producer_cpasync_commit(ab_producer_state)
  ab_producer_state.advance()
  return ab_producer_state

@@ -1416,24 +1251,25 @@ class GemmSm90:
  self,
  params: EpilogueParams,
  epi_smem_tensors: Tuple[cute.Tensor, ...],
+ tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
  epi_pipeline: cutlass.pipeline.PipelineAsync,
+ epi_store_pipeline: cutlass.pipeline.PipelineAsync,
  epi_read_state: cutlass.pipeline.PipelineState,
- epi_producer_state: cutlass.pipeline.PipelineState,
- tiled_mma: cute.TiledMma,
- tRS_rAcc: cute.Tensor,
+ epi_producer_state: Optional[cutlass.pipeline.PipelineState],
+ epi_tile: cute.Tile,
+ load_acc_subtile: Callable,
  tRS_rD: cute.Tensor,
  tRS_rC: Optional[cute.Tensor],
- tiled_copy_r2s: cute.core.ThrCopy,
+ tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
+ tiled_copy_r2s: cute.TiledCopy,
  tRS_sD: cute.Tensor,
- tiled_copy_s2r: Optional[cute.core.ThrCopy],
+ tiled_copy_s2r: Optional[cute.ThrCopy],
  tSR_rC: Optional[cute.Tensor],
  tSR_sC: Optional[cute.Tensor],
  copy_D: Optional[Callable],
- bSG_sD: cute.Tensor,
- bSG_gD: cute.Tensor,
- epi_load_g2s: Optional[Callable],
+ copy_C: Optional[Callable],
  tile_coord_mnkl: cute.Coord,
- cu_seqlens_m: Optional[cute.Tensor],
+ varlen_manager: VarlenManager,
  epilogue_barrier: cutlass.pipeline.NamedBarrier,
  tile_scheduler,
  tidx: Int32,
@@ -1441,22 +1277,61 @@ class GemmSm90:
  ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
  has_C = const_expr(tRS_rC is not None)
  has_D = const_expr(copy_D is not None)
- # We iterate over epi tiles in the N dimension first before the M dimension
  epi_tile_shape = cute.zipped_divide(
- cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
+ cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
  ).shape[1]
- epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
+ # We iterate over epi tiles in the N dimension first before the M dimension
+ epi_tile_layout = cute.make_ordered_layout(epi_tile_shape, order=(1, 0))
  epi_tile_num = cute.size(epi_tile_shape)
  num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

- if const_expr(epi_load_g2s is not None):
+ epi_tensors = self.epi_begin(
+ params,
+ epi_smem_tensors,
+ epi_tile,
+ tiled_copy_t2r,
+ tiled_copy_r2s,
+ tile_coord_mnkl,
+ varlen_manager,
+ epilogue_barrier,
+ tidx,
+ )
+
+ if const_expr(copy_C is not None):
  for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
- epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
+ if is_tma_warp:
+ epi_pipeline.producer_acquire(epi_producer_state)
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+ epi_pipeline.producer_commit(epi_producer_state)
+ epi_producer_state.advance()
+
+ def tma_store_fn(src_idx, dst_idx):
+ # Fence and barrier to make sure shared memory store is visible to TMA store
+ cute.arch.fence_proxy(
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+ )
+ epilogue_barrier.arrive_and_wait()
+ # Copy from shared memory to global memory
+ if is_tma_warp:
+ if const_expr(has_D):
+ copy_D(src_idx=src_idx, dst_idx=dst_idx)
+ # Can't use if statement here, epi_store_pipeline object isn't captured somehow
+ if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
+ if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
+ epilogue_barrier.arrive_and_wait()
+
+ # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops
+ # with the TMA store. However, currently this doesn't seem to improve perf.
+ delay_tma_store = False

+ src_idx_prev, dst_idx_prev = None, None
  for epi_idx in cutlass.range_constexpr(epi_tile_num):
+ # The global memory coordinate for the current epi tile
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
  # Copy from acc to D registers
- for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
- tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
+ load_acc_subtile(tRS_rD, epi_idx)
+ epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
  if const_expr(has_C):
  epi_pipeline.consumer_wait(epi_read_state)
  cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
@@ -1468,98 +1343,132 @@ class GemmSm90:
  with cute.arch.elect_one():
  epi_pipeline.consumer_release(epi_read_state)
  epi_read_state.advance()
- if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
- epi_producer_state = epi_load_g2s(
- epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
- )
- tRS_rEpi = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
+ if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
+ if is_tma_warp:
+ epi_pipeline.producer_acquire(epi_producer_state)
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+ epi_pipeline.producer_commit(epi_producer_state)
+ epi_producer_state.advance()
+ tRS_rEpi = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
  epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+ if const_expr(delay_tma_store):
+ if const_expr(epi_idx > 0):
+ tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
+ src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
  # Copy from D registers to shared memory
  if const_expr(has_D):
- # Type conversion
- tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
- tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
- cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
- # Fence and barrier to make sure shared memory store is visible to TMA store
- cute.arch.fence_proxy(
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
- )
- epilogue_barrier.arrive_and_wait()
- # Get the global memory coordinate for the current epi tile
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
- # Copy from shared memory to global memory
- if is_tma_warp:
- if const_expr(has_D):
- copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
- cute.arch.cp_async_bulk_commit_group()
- cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
- epilogue_barrier.arrive_and_wait()
+ copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
+ if const_expr(not delay_tma_store):
+ tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
+
+ if const_expr(delay_tma_store):
+ tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
+
+ self.epi_end(
+ params,
+ epi_tensors,
+ epi_tile,
+ tiled_copy_t2r,
+ tiled_copy_r2s,
+ tile_coord_mnkl,
+ varlen_manager,
+ tidx,
+ )

  return epi_read_state, epi_producer_state

- @cute.jit
- def epi_load_g2s(
+ def get_scheduler_class(self, varlen_m: bool = False):
+ """Return the scheduler class to use. Override in subclasses for custom schedulers."""
+ return TileScheduler if not varlen_m else VarlenMTileScheduler
+
+ def get_scheduler_arguments(
  self,
- epi_pipeline: cutlass.pipeline.PipelineAsync,
- copy_C: Callable,
- bGS_gC: cute.Tensor,
- bGS_sC: cute.Tensor,
- epi_producer_state: cutlass.pipeline.PipelineState,
- epi_idx: Int32,
- should_load: Boolean,
- ) -> cutlass.pipeline.PipelineState:
- # We iterate over epi tiles in the N dimension first before the M dimension
- epi_tile_layout = cute.make_layout(bGS_gC.shape[1], stride=(bGS_gC.shape[1][1], 1))
- if should_load:
- epi_pipeline.producer_acquire(epi_producer_state)
- # Get the global memory coordinate for the current epi tile
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
- copy_C(
- bGS_gC[None, gmem_coord],
- bGS_sC[None, epi_producer_state.index],
- tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
+ mA: cute.Tensor,
+ mB: cute.Tensor,
+ mD: Optional[cute.Tensor],
+ scheduler_args,
+ varlen_args,
+ ):
+ """Create scheduler arguments. Override in subclasses for custom schedulers."""
+ if const_expr(varlen_args.mCuSeqlensM is None):
+ num_problems = (
+ mD.shape[2]
+ if mD is not None
+ else (
+ mB.shape[2]
+ if varlen_args.mCuSeqlensK is None
+ else varlen_args.mCuSeqlensK.shape[0] - 1
+ )
+ )
+ problem_shape_ntile_mnl = (
+ cute.ceil_div(mA.shape[0], self.cta_tile_shape_mnk[0]),
+ cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
+ num_problems,
+ )
+ tile_sched_args = TileSchedulerArguments(
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+ raster_order=scheduler_args.raster_order,
+ group_size=scheduler_args.max_swizzle_size,
+ cluster_shape_mnk=self.cluster_shape_mnk,
+ tile_count_semaphore=scheduler_args.tile_count_semaphore,
+ batch_idx_permute=scheduler_args.batch_idx_permute,
+ is_persistent=self.is_persistent,
  )
- # Epi pipeline's producer commit is a NOP
- epi_pipeline.producer_commit(epi_producer_state)
- epi_producer_state.advance()
- return epi_producer_state
+ else:
+ assert mD is not None or not self.gather_A
+ problem_shape_ntile_mnl = (
+ None,
+ cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
+ varlen_args.mCuSeqlensM.shape[0] - 1,
+ )
+ tile_sched_args = VarlenMTileSchedulerArguments(
+ problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+ total_m=mD.shape[0] if mD is not None else varlen_args.mAIdx.shape[0],
+ cu_seqlens_m=varlen_args.mCuSeqlensM,
+ raster_order=scheduler_args.raster_order,
+ group_size=scheduler_args.max_swizzle_size,
+ tile_shape_mn=self.cta_tile_shape_mnk[:2],
+ cluster_shape_mnk=self.cluster_shape_mnk,
+ tile_count_semaphore=scheduler_args.tile_count_semaphore,
+ is_persistent=self.is_persistent,
+ )
+ return tile_sched_args
+
+ @cute.jit
+ def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
+ for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
+ tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
+
+ @cute.jit
+ def epi_begin(
+ self,
+ params: EpilogueParams,
+ epi_smem_tensors: Tuple[cute.Tensor, ...],
+ epi_tile: cute.Tile,
+ tiled_copy_t2r: Optional[cute.TiledCopy],
+ tiled_copy_r2s: cute.TiledCopy,
+ tile_coord_mnkl: cute.Coord,
+ varlen_manager: VarlenManager,
+ epilogue_barrier: cutlass.pipeline.NamedBarrier,
+ tidx: Int32,
+ ) -> Tuple[cute.Tensor, ...]:
+ return ()

- def epi_visit_acc_subtile(
+ def epi_begin_loop(
+ self, params: EpilogueParams, epi_tensors: Tuple[cute.Tensor, ...], epi_coord: cute.Coord
+ ) -> Tuple[cute.Tensor, ...]:
+ return ()
+
+ def epi_visit_subtile(
  self,
  params: EpilogueParams,
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
  tRS_rD: cute.Tensor,
  tRS_rC: Optional[cute.Tensor] = None,
  ) -> Optional[cute.Tensor]:
- # Apply alpha scaling to accumulator if alpha is provided (not None)
- if const_expr(hasattr(params, "alpha") and params.alpha is not None):
- alpha = utils.load_scalar_or_pointer(params.alpha)
- tRS_rD.store(tRS_rD.load() * alpha)
- # Apply C with beta scaling
- if const_expr(tRS_rC is not None):
- if const_expr(not hasattr(params, "beta") or params.beta is None):
- # beta is None, default behavior: add C (beta=1.0)
- tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
- else:
- beta = utils.load_scalar_or_pointer(params.beta)
- tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
  return None

- def get_scheduler_class(self):
- """Return the scheduler class to use. Override in subclasses for custom schedulers."""
- return TileScheduler
-
- def get_scheduler_arguments(self, problem_shape_ntile_mnl, scheduler_args):
- """Create scheduler arguments. Override in subclasses for custom schedulers."""
- return TileSchedulerArguments(
- problem_shape_ntile_mnl=problem_shape_ntile_mnl,
- raster_order=scheduler_args.raster_order,
- group_size=scheduler_args.max_swizzle_size,
- cluster_shape_mnk=self.cluster_shape_mnk,
- tile_count_semaphore=scheduler_args.tile_count_semaphore,
- batch_idx_permute=scheduler_args.batch_idx_permute,
- is_persistent=self.is_persistent,
- )
-
  def epi_visit_acc(
  self,
  params: EpilogueParams,
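get_scheduler_arguments above sizes the tile scheduler's grid by rounding the problem extents up to whole CTA tiles. A worked example in plain Python with assumed sizes (the tile shape and problem extents are illustrative, not defaults):

    # ceil-division of the problem extents by the CTA tile, as cute.ceil_div does
    ceil_div = lambda a, b: (a + b - 1) // b
    M, N, L = 4096, 4096, 8          # assumed problem extents (rows of A, rows of B, batch)
    tile_m, tile_n = 128, 192        # assumed cta_tile_shape_mnk[:2]
    problem_shape_ntile_mnl = (ceil_div(M, tile_m), ceil_div(N, tile_n), L)
    print(problem_shape_ntile_mnl)   # (32, 22, 8): 32 * 22 M/N tiles per batch entry

In the varlen-M branch the M tile count is left as None and the scheduler derives it per batch from cu_seqlens_m instead.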
@@ -1570,26 +1479,58 @@ class GemmSm90:
  ) -> None:
  pass

+ @cute.jit
+ def epi_end(
+ self,
+ params: EpilogueParams,
+ epi_tensors: Tuple[cute.Tensor, ...],
+ epi_tile: cute.Tile,
+ tiled_copy_t2r: Optional[cute.TiledCopy],
+ tiled_copy_r2s: cute.TiledCopy,
+ tile_coord_mnkl: cute.Coord,
+ varlen_manager,
+ tidx,
+ ) -> None:
+ pass
+
  def epi_to_underlying_arguments(
  self, args: EpilogueArguments, *, loc=None, ip=None
  ) -> EpilogueParams:
- return GemmSm90.EpilogueParams(alpha=args.alpha, beta=args.beta)
+ return self.EpilogueParams()
+
+ def epi_get_tma_atoms(
+ self, params: EpilogueParams, *, loc=None, ip=None
+ ) -> list[cute.CopyAtom]:
+ """Subclasses can override this"""
+ return []
+
+ def epi_get_tensormap_update_shapes_orders(
+ self,
+ params: EpilogueParams,
+ cu_seqlens_m: cute.Tensor,
+ batch_idx: Int32,
+ *,
+ loc=None,
+ ip=None,
+ ) -> tuple[list[Int32], list[int]]:
+ """Subclasses can override this"""
+ return [], []

  @staticmethod
  def epi_smem_bytes_per_stage(
  args: Optional[EpilogueArguments],
- tile_shape_mnk: Tuple[int, int, int],
- epi_tile: Tuple[int, int],
+ cta_tile_shape_mnk: Tuple[int, int, int],
+ epi_tile: cute.Tile,
  ) -> int:
  return 0

  def epi_get_smem_struct(self, params: EpilogueParams):
- return cute.struct.MemRange[cutlass.Int32, 0] # Dummy struct
+ return cute.struct.MemRange[Int32, 0] # Dummy struct

  def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
  return tuple()

- def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
+ def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
  assert stage in ["mma", "epi"]
  barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
  cute.arch.barrier(
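epi_begin, epi_begin_loop, epi_visit_subtile and epi_end form the hook surface that epilogue subclasses (for example the gemm_act / gemm_dact variants added elsewhere in this release) can override; the base class keeps them as no-ops. A minimal sketch of a custom epilogue that only overrides the per-subtile hook, reusing nothing beyond the fragment load/store pattern already used in this file; the class name and the 0.5 scale factor are made up for illustration:

    class GemmHalved(GemmSm90):
        def epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None):
            # Scale the accumulator subtile in registers before the caller
            # converts it and copies it to shared memory / global memory.
            tRS_rD.store(tRS_rD.load() * Float32(0.5))
            return None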
@@ -1597,7 +1538,7 @@ class GemmSm90:
  number_of_threads=2 * self.num_threads_per_warp_group,
  )

- def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
+ def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
  assert stage in ["mma", "epi"]
  barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
  cute.arch.barrier_arrive(
@@ -1611,7 +1552,7 @@ class GemmSm90:
  self.d_layout.is_m_major_c() if self.d_layout is not None else False,
  num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
  ),
- cutlass.Float16, # this is just to get the right source layout
+ Float16, # this is just to get the right source layout
  )
  tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
  return tiled_copy_C_atom
@@ -1621,8 +1562,7 @@ class GemmSm90:
  tiled_mma: cute.TiledMma,
  d_layout: Optional[LayoutEnum],
  dtype: Type[cutlass.Numeric],
- acc: cute.Tensor,
- sD: cute.Tensor,
+ sD: Optional[cute.Tensor],
  tidx: Int32,
  ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
  if d_layout is None:
@@ -1637,12 +1577,10 @@ class GemmSm90:
  # (R2S, R2S_M, R2S_N, PIPE_D)
  thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
  tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
- # (R2S, R2S_M, R2S_N)
- tRS_rAcc = tiled_copy_r2s.retile(acc)
  sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
  tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
  tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
- return tiled_copy_r2s, tRS_rAcc, tRS_rD, tRS_sD
+ return tiled_copy_r2s, tRS_rD, tRS_sD

  def epilog_smem_load_and_partition(
  self,
@@ -1654,7 +1592,7 @@ class GemmSm90:
  tidx: Int32,
  ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
  tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
- copy_atom_s2r = utils.sm90_get_smem_load_op(c_layout, dtype)
+ copy_atom_s2r = copy_utils.sm90_get_smem_load_op(c_layout, dtype)
  tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
  thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
  tSR_sC = thr_copy_s2r.partition_S(sC)
@@ -1665,57 +1603,53 @@ class GemmSm90:
  def epilog_gmem_copy_and_partition(
  self,
  atom: Union[cute.CopyAtom, cute.TiledCopy],
- mD_mnl: cute.Tensor,
+ mD_mn: cute.Tensor,
  tile_shape_mn: cute.Tile,
  epi_tile: cute.Tile,
  sD: cute.Tensor,
  tile_coord_mnkl: cute.Coord,
- cu_seqlens_m: Optional[cute.Tensor] = None,
+ tma_desc_ptr: Optional[cute.Pointer] = None,
  ) -> Tuple[cute.Tensor, cute.Tensor]:
- batch_idx = tile_coord_mnkl[3]
- if const_expr(cu_seqlens_m is not None):
- mD_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl)
- else:
- mD_mn = mD_mnl[None, None, batch_idx]
  # (bM, bN)
  gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
  tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
- bSG_sD, bSG_gD = cpasync.tma_partition(
+ is_s2g = isinstance(
+ atom.op, (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp)
+ )
+ src_tensor, dst_tensor = (
+ (sD, tDgD_for_tma_partition) if is_s2g else (tDgD_for_tma_partition, sD)
+ )
+ return copy_utils.tma_get_copy_fn(
  atom,
- 0,
- cute.make_layout(1),
- cute.group_modes(sD, 0, 2),
- tDgD_for_tma_partition,
+ cta_coord=0,
+ cta_layout=cute.make_layout(1),
+ src_tensor=src_tensor,
+ dst_tensor=dst_tensor,
+ tma_desc_ptr=tma_desc_ptr,
  )
- return bSG_sD, bSG_gD

  def make_ab_pipeline(
  self,
- a_smem_layout: cute.Layout | cute.ComposedLayout,
- b_smem_layout: cute.Layout | cute.ComposedLayout,
  tiled_mma: cute.TiledMma,
  cluster_layout_vmnk: cute.Layout,
  ab_pipeline_mbar_ptr: cute.Pointer,
  ):
  # Threads/warps participating in this pipeline
- producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
+ producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32
  ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
  # Each warp will contribute to the arrive count with the number of mcast size
  mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
- consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
+ consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE
  ab_pipeline_consumer_group = pipeline.CooperativeGroup(
  pipeline.Agent.Thread, consumer_arrive_cnt
  )
  pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
- tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
- if const_expr(not self.gather_A):
- tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
  return pipeline_cls.create(
  barrier_storage=ab_pipeline_mbar_ptr,
  num_stages=self.ab_stage,
  producer_group=ab_pipeline_producer_group,
  consumer_group=ab_pipeline_consumer_group,
- tx_count=tma_copy_bytes,
+ tx_count=self.num_tma_load_bytes,
  cta_layout_vmnk=cluster_layout_vmnk,
  )
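The A/B pipeline's tx_count is the number of bytes each TMA transaction is expected to deposit per stage; the mbarrier uses it to decide when a stage is complete. With gather_A the A tile arrives via cp.async instead, so only the B tile counts toward the TMA transaction. A rough worked example with assumed tile and dtype sizes:

    # Assumed: bf16 A and B with a 128 x 256 x 64 CTA tile (illustrative numbers).
    tile_m, tile_n, tile_k, bytes_per_elem = 128, 256, 64, 2
    a_bytes = tile_m * tile_k * bytes_per_elem      # 16384 bytes per A stage
    b_bytes = tile_n * tile_k * bytes_per_elem      # 32768 bytes per B stage
    num_tma_load_bytes = a_bytes + b_bytes          # 49152: expected tx_count per stage
    num_tma_load_bytes_gather_a = b_bytes           # only B is TMA-loaded when gathering A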
 
@@ -1725,7 +1659,7 @@ class GemmSm90:
  # Threads/warps participating in this pipeline
  epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
  # Each warp will contribute 1 to the arrive count
- consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
+ consumer_arrive_cnt = self.num_epi_warps
  epi_pipeline_consumer_group = pipeline.CooperativeGroup(
  pipeline.Agent.Thread, consumer_arrive_cnt
  )
@@ -1738,6 +1672,14 @@ class GemmSm90:
  tx_count=tma_copy_c_bytes,
  )

+ def make_epi_store_pipeline(self):
+ # Threads/warps participating in tma store pipeline
+ num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
+ epi_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_epi_threads)
+ return pipeline.PipelineTmaStore.create(
+ num_stages=self.epi_stage, producer_group=epi_store_producer_group
+ )
+
  def make_sched_pipeline(
  self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
  ):
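PipelineTmaStore (make_epi_store_pipeline above) tracks in-flight bulk stores so a shared-memory epilogue stage is only reused after its TMA store has drained; every epilogue thread is counted as a producer. The intended call pattern, mirroring tma_store_fn earlier in this file (a sketch of usage, not new API):

    # per epilogue subtile, on the warp that issues TMA:
    #   copy_D(src_idx=..., dst_idx=...)          # bulk S2G store from the smem stage
    #   epi_store_pipeline.producer_commit()      # group the store just issued
    #   epi_store_pipeline.producer_acquire()     # block until a stage can be reused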
@@ -1766,21 +1708,21 @@ class GemmSm90:
  @classmethod
  def _compute_stages(
  cls,
- tile_shape_mnk: Tuple[int, int, int],
+ cta_tile_shape_mnk: Tuple[int, int, int],
  epi_tile: Tuple[int, int],
  a_dtype: Type[cutlass.Numeric],
  b_dtype: Type[cutlass.Numeric],
  d_dtype: Optional[Type[cutlass.Numeric]],
  c_dtype: Optional[Type[cutlass.Numeric]],
- epilogue_args: Optional[EpilogueArguments],
+ epilogue_args: EpilogueArguments,
  smem_capacity: int,
  occupancy: int,
- overlap_sD_sA: bool,
+ overlap_sD_sA: bool = False,
  ) -> Tuple[int, int]:
  """Computes the number of stages for A/B/C operands based on heuristics.

- :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
- :type tile_shape_mnk: Tuple[int, int, int]
+ :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
+ :type cta_tile_shape_mnk: Tuple[int, int, int]
  :param a_dtype: Data type of operand A.
  :type a_dtype: type[cutlass.Numeric]
  :param b_dtype: Data type of operand B.
@@ -1803,15 +1745,15 @@ class GemmSm90:
  cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
  )
  epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
- epilogue_args, tile_shape_mnk, epi_tile
+ epilogue_args, cta_tile_shape_mnk, epi_tile
  )
  epi_bytes = epi_bytes_per_stage * epi_stage
  epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
  if c_dtype is not None:
  epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage

- a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
- b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
+ a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
+ b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
  ab_bytes_per_stage = (
  cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
  )
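The stage count ultimately comes from how many A/B stages fit in the shared-memory budget after the epilogue stages are reserved. A worked example with assumed values (Hopper-class budget of roughly 228 KiB, occupancy 1, bf16 operands, no C, epi_stage = 2; the exact budget and rounding the kernel applies may differ):

    smem_capacity = 228 * 1024          # assumed per-SM budget in bytes
    tile_m, tile_n, tile_k = 128, 256, 64
    epi_tile_m, epi_tile_n, epi_stage = 64, 32, 2
    bf16 = 2                            # bytes per element
    ab_bytes_per_stage = (tile_m * tile_k + tile_n * tile_k) * bf16   # 49152
    epi_bytes = epi_tile_m * epi_tile_n * bf16 * epi_stage            # 8192 (D only)
    ab_stage = (smem_capacity - epi_bytes) // ab_bytes_per_stage      # 4 stages
    print(ab_stage)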
@@ -1829,15 +1771,15 @@ class GemmSm90:

  @staticmethod
  def _sm90_compute_tile_shape_or_override(
- tile_shape_mnk: Tuple[int, int, int],
+ cta_tile_shape_mnk: Tuple[int, int, int],
  atom_layout_mnk: Tuple[int, int, int],
  element_type: Optional[Type[cutlass.Numeric]] = None,
  epi_tile_override: Tuple[int, int] | None = None,
  ) -> Tuple[int, int]:
  """Compute the epilogue tile shape or use override if provided.

- :param tile_shape_mnk: CTA tile shape (M,N,K)
- :type tile_shape_mnk: Tuple[int, int, int]
+ :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
+ :type cta_tile_shape_mnk: Tuple[int, int, int]
  :param element_type: Data type of elements
  :type element_type: type[cutlass.Numeric]
  :param is_cooperative: Whether to use cooperative approach
@@ -1850,12 +1792,12 @@ class GemmSm90:
  """
  if epi_tile_override is not None:
  return epi_tile_override
- if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
- tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
- tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
- elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
- tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0]))
- tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
+ if cta_tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
+ tile_m = math.gcd(128, cute.size(cta_tile_shape_mnk, mode=[0]))
+ tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
+ elif cta_tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
+ tile_m = math.gcd(192, cute.size(cta_tile_shape_mnk, mode=[0]))
+ tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
  else:
  # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
  # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
@@ -1864,13 +1806,13 @@ class GemmSm90:
  # We could change the epilogue to accommodate this,
  # but it's easier to just set epi_tile_m = 64.
  n_perf = 64 if element_type is not None and element_type.width == 8 else 32
- tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
- tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
+ tile_m = math.gcd(64, cute.size(cta_tile_shape_mnk, mode=[0]))
+ tile_n = math.gcd(n_perf, cute.size(cta_tile_shape_mnk, mode=[1]))
  return (tile_m, tile_n)

  @staticmethod
  def _make_smem_layouts(
- tile_shape_mnk: Tuple[int, int, int],
+ cta_tile_shape_mnk: Tuple[int, int, int],
  epi_tile: Tuple[int, int],
  a_dtype: Type[cutlass.Numeric],
  a_layout: LayoutEnum,
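The gcd-based selection in _sm90_compute_tile_shape_or_override picks an epilogue tile that evenly divides the CTA tile while staying small enough to stage in shared memory. Two worked cases with assumed shapes:

    import math
    # Assumed CTA tile 256 x 192 with a multi-warp-group atom layout in M:
    print(math.gcd(128, 256), math.gcd(32, 192))   # epi_tile = (128, 32)
    # Assumed CTA tile 128 x 224 with atom layout 1 x 2 (the else branch, n_perf = 32):
    print(math.gcd(64, 128), math.gcd(32, 224))    # epi_tile = (64, 32)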
@@ -1888,8 +1830,8 @@ class GemmSm90:
  ]:
  """Create shared memory layouts for A, B, and C tensors.

- :param tile_shape_mnk: CTA tile shape (M,N,K)
- :type tile_shape_mnk: Tuple[int, int, int]
+ :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
+ :type cta_tile_shape_mnk: Tuple[int, int, int]
  :param epi_tile: Epilogue tile shape
  :type epi_tile: Tuple[int, int]
  :param a_dtype: Data type for matrix A
@@ -1912,11 +1854,11 @@ class GemmSm90:
  :return: Tuple of shared memory layouts for A, B, and C
  :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]
  """
- a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
+ a_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))

  a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
  b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
- a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
+ a_major_mode_size = cta_tile_shape_mnk[2 if a_is_k_major else 0]
  a_smem_layout_atom = warpgroup.make_smem_layout_atom(
  sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
  a_dtype,
@@ -1927,9 +1869,9 @@ class GemmSm90:
  order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
  )

- b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None))
+ b_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))

- b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1]
+ b_major_mode_size = cta_tile_shape_mnk[2 if b_is_k_major else 1]
  b_smem_layout_atom = warpgroup.make_smem_layout_atom(
  sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
  b_dtype,
@@ -1940,36 +1882,18 @@ class GemmSm90:
  order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
  )

+ epi_smem_layout_staged = None
  if d_dtype is not None:
- d_smem_shape = epi_tile
- d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
- d_smem_layout_atom = warpgroup.make_smem_layout_atom(
- sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
- d_dtype,
- )
- epi_smem_layout_staged = cute.tile_to_shape(
- d_smem_layout_atom,
- cute.append(d_smem_shape, epi_stage),
- order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
+ epi_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
+ d_dtype, d_layout, epi_tile, epi_stage
  )
- else:
- epi_smem_layout_staged = None

+ epi_c_smem_layout_staged = None
  if c_dtype is not None:
  assert c_layout is not None
- c_smem_shape = epi_tile
- c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0]
- c_smem_layout_atom = warpgroup.make_smem_layout_atom(
- sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
- c_dtype,
- )
- epi_c_smem_layout_staged = cute.tile_to_shape(
- c_smem_layout_atom,
- cute.append(c_smem_shape, epi_c_stage),
- order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
+ epi_c_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
+ c_dtype, c_layout, epi_tile, epi_c_stage
  )
- else:
- epi_c_smem_layout_staged = None

  return (
  a_smem_layout_staged,
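make_smem_layout_epi (from quack.sm90_utils, new in this release) factors out the staged epilogue smem layout that used to be built inline. Judging only from the code it replaces here, it presumably does something like the following; this is a reconstruction for orientation, not the helper's actual source:

    def make_smem_layout_epi(dtype, layout, epi_tile, stage):
        # Pick the swizzled atom for the epi tile's major mode, then tile it
        # across the epi tile and append the stage count as the last mode.
        major_mode_size = epi_tile[1] if layout.is_n_major_c() else epi_tile[0]
        atom = warpgroup.make_smem_layout_atom(
            sm90_utils.get_smem_layout_atom(layout, dtype, major_mode_size), dtype
        )
        return cute.tile_to_shape(
            atom,
            cute.append(epi_tile, stage),
            order=(1, 0, 2) if layout.is_m_major_c() else (0, 1, 2),
        )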
@@ -1983,7 +1907,7 @@ class GemmSm90:
  tensor_d: cute.Tensor,
  epi_smem_layout_staged: cute.ComposedLayout,
  epi_tile: Tuple[int, int],
- store_or_load: str,
+ op_type: Literal["store", "load", "add"],
  ) -> Tuple[cute.CopyAtom, cute.Tensor]:
  """Create TMA atoms and tensors for storing D or loading C.

@@ -1997,13 +1921,15 @@ class GemmSm90:
  :return: TMA atom and tensor for C
  :rtype: Tuple[cute.CopyAtom, cute.Tensor]
  """
- assert store_or_load in ["load", "store"]
+ assert op_type in ["load", "store", "add"]
  epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
  d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
  op = (
  cpasync.CopyBulkTensorTileG2SOp()
- if store_or_load == "load"
+ if op_type == "load"
  else cpasync.CopyBulkTensorTileS2GOp()
+ if op_type == "store"
+ else cpasync.CopyReduceBulkTensorTileS2GOp(cute.ReductionOp.ADD)
  )
  tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
  op, tensor_d, epi_smem_layout, d_cta_v_layout
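The nested conditional above maps op_type to one of three bulk-tensor ops; "add" builds a TMA store that reduces into the destination tile instead of overwriting it, which lets an epilogue accumulate into an existing tensor in global memory. The same mapping written as a plain if-chain, purely for readability (names as in this file):

    def _epi_tma_op(op_type):
        if op_type == "load":
            return cpasync.CopyBulkTensorTileG2SOp()
        if op_type == "store":
            return cpasync.CopyBulkTensorTileS2GOp()
        # "add": bulk store that accumulates into the destination tile
        return cpasync.CopyReduceBulkTensorTileS2GOp(cute.ReductionOp.ADD)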
@@ -2013,7 +1939,7 @@ class GemmSm90:
  @staticmethod
  def _make_tma_atoms_and_tensors(
  tensor: cute.Tensor,
- smem_layout_staged: cute.ComposedLayout,
+ smem_layout: cute.ComposedLayout,
  smem_tile: Tuple[int, int],
  mcast_dim: int,
  ) -> Tuple[cute.CopyAtom, cute.Tensor]:
@@ -2021,8 +1947,8 @@ class GemmSm90:
2021
1947
 
2022
1948
  :param tensor: Input tensor (A or B)
2023
1949
  :type tensor: cute.Tensor
2024
- :param smem_layout_staged: Shared memory layout for the tensor
2025
- :type smem_layout_staged: cute.ComposedLayout
1950
+ :param smem_layout: Shared memory layout for the tensor
1951
+ :type smem_layout: cute.ComposedLayout
2026
1952
  :param smem_tile: Shared memory tile shape
2027
1953
  :type smem_tile: Tuple[int, int]
2028
1954
  :param mcast_dim: Multicast dimension
@@ -2036,8 +1962,6 @@ class GemmSm90:
  if mcast_dim == 1
  else cpasync.CopyBulkTensorTileG2SMulticastOp()
  )
-
- smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
  tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
  op,
  tensor,
@@ -2054,13 +1978,18 @@ class GemmSm90:
  num_bits_per_copy=copy_bits,
  )
  copy_elems = copy_bits // dtype.width
- shape_dim_1 = cute.size(self.tile_shape_mnk[2]) // copy_elems
+ loads_per_cache_line = 128 * 8 // copy_bits # 128 bytes per cache line
+ shape_dim_1 = cute.size(self.cta_tile_shape_mnk[2]) // copy_elems
+ if shape_dim_1 > loads_per_cache_line:
+ shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line)
  # thread layout for copy
  thread_layout = cute.make_layout(
  (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
  )
  if major_mode != LayoutEnum.ROW_MAJOR:
- shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
+ shape_dim_0 = cute.size(self.cta_tile_shape_mnk[0]) // copy_elems
+ if shape_dim_0 > loads_per_cache_line:
+ shape_dim_0 = math.gcd(shape_dim_0, loads_per_cache_line)
  thread_layout = cute.make_layout(
  (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
  )
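The gcd clamp above keeps the contiguous dimension of the cp.async thread layout within one 128-byte cache line, so with wide K (or M) tiles the extra threads stack along the other dimension instead of striding past the line. A worked example with assumed parameters:

    import math
    copy_bits, dtype_width, num_threads, tile_k = 128, 16, 128, 128   # assumed values
    copy_elems = copy_bits // dtype_width              # 8 bf16 elements per 16-byte load
    loads_per_cache_line = 128 * 8 // copy_bits        # 8 loads cover a 128-byte line
    shape_dim_1 = tile_k // copy_elems                 # 16 before clamping
    if shape_dim_1 > loads_per_cache_line:
        shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line)     # clamp to 8
    print((num_threads // shape_dim_1, shape_dim_1))   # thread layout shape (16, 8)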
@@ -2102,7 +2031,7 @@ class GemmSm90:
  """
  is_valid = True
  if a_dtype not in {
- cutlass.Float16,
+ Float16,
  cutlass.BFloat16,
  cutlass.Float8E4M3FN,
  cutlass.Float8E5M2,
@@ -2110,19 +2039,19 @@ class GemmSm90:
  is_valid = False
  # tested b_dtype
  if b_dtype not in {
- cutlass.Float16,
+ Float16,
  cutlass.BFloat16,
  cutlass.Float8E4M3FN,
  cutlass.Float8E5M2,
  }:
  is_valid = False
- if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
+ if acc_dtype not in {Float32, Float16}:
  is_valid = False
  # tested d_dtype
  if d_dtype not in {
  None,
- cutlass.Float32,
- cutlass.Float16,
+ Float32,
+ Float16,
  cutlass.BFloat16,
  cutlass.Float8E4M3FN,
  cutlass.Float8E5M2,
@@ -2139,107 +2068,3 @@ class GemmSm90:
  if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
  is_valid = False
  return is_valid
-
-
- def gemm_sm90(
- A: Tensor, # (l, m, k)
- B: Tensor, # (l, n, k)
- D: Tensor, # (l, m, n)
- C: Optional[Tensor], # (l, m, n)
- tile_count_semaphore: Optional[Tensor], # (1,)
- tile_M: int,
- tile_N: int,
- cluster_M: int,
- cluster_N: int,
- pingpong: bool = False,
- persistent: bool = True,
- alpha: float | Tensor = 1.0,
- beta: float | Tensor = 1.0,
- ) -> None:
- L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(A, B, D, C)
- GemmWrapperBase.permute_tensors(tensor_infos)
- GemmWrapperBase.extract_dtypes(tensor_infos)
- major_configs = {
- "A": ("m", "k", "l"),
- "B": ("n", "k", "l"),
- "D": ("m", "n", "l"),
- "C": ("m", "n", "l"),
- }
- GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
-
- acc_dtype = cutlass.Float32
- tile_shape_mn = (tile_M, tile_N)
- cluster_shape_mnk = (cluster_M, cluster_N, 1)
- if not GemmSm90.is_valid_dtypes(
- tensor_infos["A"].dtype,
- tensor_infos["B"].dtype,
- acc_dtype,
- tensor_infos["D"].dtype,
- tensor_infos["A"].major,
- tensor_infos["B"].major,
- ):
- raise TypeError("Skipping due to unsupported combination of types and majors")
-
- max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
- GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
-
- def scalar_arg(scalar: float | Tensor):
- if isinstance(scalar, float):
- return Float32(scalar) if scalar != 1.0 else None
- else:
- assert isinstance(scalar, Tensor)
- return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
-
- epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta))
- scheduler_args = GemmWrapperBase.create_scheduler_args(
- max_active_clusters, tile_count_semaphore
- )
- current_stream = cutlass_torch.current_stream()
- compile_key = GemmWrapperBase.get_compile_key(
- tensor_infos,
- None,
- tile_shape_mn,
- cluster_shape_mnk,
- pingpong,
- persistent,
- tile_count_semaphore is not None,
- 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
- 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
- key_tensor_names=("A", "B", "D", "C"),
- )
- cache = gemm_sm90.compile_cache
- if compile_key not in cache:
- gemm = GemmSm90(
- acc_dtype,
- tensor_infos["A"].dtype,
- tile_shape_mn,
- cluster_shape_mnk,
- pingpong=pingpong,
- is_persistent=persistent,
- )
- cache[compile_key] = cute.compile(
- gemm,
- tensor_infos["A"].cute_tensor,
- tensor_infos["B"].cute_tensor,
- tensor_infos["D"].cute_tensor,
- tensor_infos["C"].cute_tensor,
- epi_args,
- scheduler_args,
- None, # varlen_args
- None, # mAIdx
- current_stream,
- )
- cache[compile_key](
- tensor_infos["A"].cute_tensor,
- tensor_infos["B"].cute_tensor,
- tensor_infos["D"].cute_tensor,
- tensor_infos["C"].cute_tensor,
- epi_args,
- scheduler_args,
- None,
- None,
- current_stream,
- )
-
-
- gemm_sm90.compile_cache = {}