quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- quack/__init__.py +1 -1
- quack/activation.py +16 -25
- quack/autotuner.py +64 -5
- quack/cross_entropy.py +6 -10
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/layernorm.py +1 -1
- quack/reduce.py +6 -7
- quack/rmsnorm.py +126 -158
- quack/softmax.py +1 -1
- quack/tile_scheduler.py +37 -49
- quack/utils.py +61 -71
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +3 -3
- quack_kernels-0.2.2.dist-info/RECORD +37 -0
- quack_kernels-0.2.0.dist-info/RECORD +0 -37
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
quack/dense_gemm_sm90.py
CHANGED
@@ -2,7 +2,7 @@
 # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py

 import enum
-from typing import Tuple, Type, Callable, Optional, Union
+from typing import Tuple, Type, Callable, Optional, Union, Literal
 from dataclasses import dataclass
 from functools import partial
 import math
@@ -86,12 +86,13 @@ class NamedBarrierGemm(enum.IntEnum):
 MmaWG1 = enum.auto()
 EpiWG0 = enum.auto()
 EpiWG1 = enum.auto()
+TmemPtr = enum.auto()


 class GemmSm90:
 """
 This class implements batched matrix multiplication (C = A x B) with support for various data types
-and architectural features specific to Hopper GPUs.
+and architectural features specific to Hopper GPUs with persistent tile scheduling and warp specialization.

 :param acc_dtype: Data type for accumulation during computation
 :type acc_dtype: type[cutlass.Numeric]
@@ -125,12 +126,15 @@ class GemmSm90:
 >>> gemm(a_tensor, b_tensor, c_tensor, stream)
 """

+arch = 90
 bytes_per_tensormap = 128
+num_epi_tensormaps: int = 0

 @dataclass
 class EpilogueArguments(ArgumentsBase):
 alpha: Optional[Float32 | cute.Tensor] = None
 beta: Optional[Float32 | cute.Tensor] = None
+add_to_output: bool = False

 @dataclass
 class EpilogueParams(ParamsBase):
@@ -174,8 +178,8 @@ class GemmSm90:

 self.cluster_shape_mnk = cluster_shape_mnk
 # K dimension is deferred in _setup_attributes
-self.
-tile_M, tile_N = self.
+self.cta_tile_shape_mnk = (*tile_shape_mn, 1)
+tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
 # check the cta tile shape
 if not self.pingpong:
 if tile_M not in [64, 128, 192, 256, 320]:
@@ -209,7 +213,9 @@ class GemmSm90:
 else:
 atom_layout_m, atom_layout_n = 1, 2
 else:
-atom_layout_m =
+atom_layout_m = (
+self.cta_tile_shape_mnk[0] // 64 if self.cta_tile_shape_mnk[0] < 256 else 2
+)
 atom_layout_n = 1
 assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
 else:
@@ -229,16 +235,14 @@ class GemmSm90:
 self.num_threads_per_warp_group = 128
 self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
 self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
-self.
-self.mma_warp_groups if not self.pingpong else 1
-) * self.num_threads_per_warp_group
+self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
 self.num_ab_load_warps = 1 if not self.gather_A else 4
 self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
 self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
 self.ab_load_warp_id = self.mma_warp_groups * 4
 self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps

-regs_per_thread = math.prod(self.
+regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
 math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
 )
 if self.fp8_slow_accum:
@@ -268,7 +272,7 @@ class GemmSm90:
 self.shared_storage = None
 self.buffer_align_bytes = 1024

-def _setup_attributes(self, epilogue_args:
+def _setup_attributes(self, epilogue_args: EpilogueArguments):
 """Set up configurations that are dependent on GEMM inputs

 This method configures various attributes based on the input tensor properties
@@ -289,7 +293,7 @@ class GemmSm90:
 self.b_layout.sm90_mma_major_mode(),
 self.acc_dtype,
 self.atom_layout_mnk,
-tiler_mn=(64, self.
+tiler_mn=(64, self.cta_tile_shape_mnk[1] // self.atom_layout_mnk[1]),
 )
 if const_expr(self.atom_layout_mnk[1] > 1):
 # If N dimension is split among 2 WGs, we need to permute the N dimension so
@@ -299,7 +303,7 @@ class GemmSm90:
 # WG1 would write to a separate epi smem of size (64, 16) that's far away.
 atom_n = self.atom_layout_mnk[1]
 permutation_n = cute.make_ordered_layout(
-(8, self.
+(8, self.cta_tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
 )
 self.tiled_mma = cute.make_tiled_mma(
 cute.make_mma_atom(self.tiled_mma.op),
@@ -308,23 +312,23 @@ class GemmSm90:
 )
 mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
 mma_inst_tile_k = 4
-self.
-self.
-self.
+self.cta_tile_shape_mnk = (
+self.cta_tile_shape_mnk[0],
+self.cta_tile_shape_mnk[1],
 mma_inst_shape_k * mma_inst_tile_k,
 )

 self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)

 self.epi_tile = self._sm90_compute_tile_shape_or_override(
-self.
+self.cta_tile_shape_mnk,
 self.atom_layout_mnk,
 self.d_dtype,
 )

 # Compute stage before compute smem layout
 self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
-self.
+self.cta_tile_shape_mnk,
 self.epi_tile,
 self.a_dtype,
 self.b_dtype,
@@ -344,7 +348,7 @@ class GemmSm90:
 self.epi_smem_layout_staged,
 self.epi_c_smem_layout_staged,
 ) = self._make_smem_layouts(
-self.
+self.cta_tile_shape_mnk,
 self.epi_tile,
 self.a_dtype,
 self.a_layout,
@@ -366,10 +370,9 @@ class GemmSm90:
 mB: cute.Tensor,
 mD: Optional[cute.Tensor],
 mC: Optional[cute.Tensor],
-epilogue_args:
+epilogue_args: ArgumentsBase,
 scheduler_args: TileSchedulerOptions,
 varlen_args: Optional[VarlenArguments],
-mAIdx: Optional[cute.Tensor],
 stream: cuda.CUstream,
 ):
 """Execute the GEMM operation in steps:
@@ -405,7 +408,10 @@ class GemmSm90:
 raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
 if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
 raise TypeError("a_dtype should be float16 or float8")
-
+
+if const_expr(varlen_args is None):
+varlen_args = VarlenArguments()
+assert (varlen_args.mAIdx is not None) == self.gather_A

 # Assume all strides are divisible by 128 bits except the last stride
 new_stride = lambda t: tuple(
@@ -421,77 +427,47 @@ class GemmSm90:

 self._setup_attributes(epilogue_args)

+a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0))
+b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0))
+tma_atom_a, tma_tensor_a = None, None
 if const_expr(not self.gather_A):
 tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
 mA,
-
-(self.
+a_smem_layout,
+(self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
 self.cluster_shape_mnk[1],
 )
-else:
-tma_atom_a, tma_tensor_a = None, None
-
 tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
 mB,
-
-(self.
+b_smem_layout,
+(self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
 self.cluster_shape_mnk[0],
 )

+self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+if const_expr(not self.gather_A):
+self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
+
+tma_atom_d, tma_tensor_d = None, None
 if const_expr(mD is not None):
 tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
-mD,
+mD,
+self.epi_smem_layout_staged,
+self.epi_tile,
+op_type="store"
+if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
+else "add",
 )
-
-tma_atom_d, tma_tensor_d = None, None
-
+tma_atom_c, tma_tensor_c = None, None
 if const_expr(mC is not None):
 tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
-mC, self.epi_c_smem_layout_staged, self.epi_tile,
+mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
 )
-else:
-tma_atom_c, tma_tensor_c = None, None

 epilogue_params = self.epi_to_underlying_arguments(epilogue_args)

-
-
-if const_expr(varlen_args.mCuSeqlensM is None):
-num_problems = (
-mD.shape[2]
-if mD is not None
-else (
-mB.shape[2]
-if varlen_args.mCuSeqlensK is None
-else varlen_args.mCuSeqlensK.shape[0] - 1
-)
-)
-problem_shape_ntile_mnl = (
-cute.ceil_div(mA.shape[0], self.tile_shape_mnk[0]),
-cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
-num_problems,
-)
-TileSchedulerCls = self.get_scheduler_class()
-tile_sched_args = self.get_scheduler_arguments(problem_shape_ntile_mnl, scheduler_args)
-else:
-assert mD is not None or not self.gather_A
-problem_shape_ntile_mnl = (
-None,
-cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
-varlen_args.mCuSeqlensM.shape[0] - 1,
-)
-TileSchedulerCls = VarlenMTileScheduler
-tile_sched_args = VarlenMTileSchedulerArguments(
-problem_shape_ntile_mnl=problem_shape_ntile_mnl,
-total_m=mD.shape[0] if mD is not None else mAIdx.shape[0],
-cu_seqlens_m=varlen_args.mCuSeqlensM,
-raster_order=scheduler_args.raster_order,
-group_size=scheduler_args.max_swizzle_size,
-tile_shape_mn=self.tile_shape_mnk[:2],
-cluster_shape_mnk=self.cluster_shape_mnk,
-tile_count_semaphore=scheduler_args.tile_count_semaphore,
-is_persistent=self.is_persistent,
-)
+TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
+tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
 tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
 grid = TileSchedulerCls.get_grid_shape(
 tile_sched_params, scheduler_args.max_active_clusters
@@ -534,6 +510,7 @@ class GemmSm90:

 # Launch the kernel synchronously
 self.kernel(
+self.tiled_mma,
 tma_atom_a,
 tma_tensor_a if const_expr(not self.gather_A) else mA,
 tma_atom_b,
@@ -543,11 +520,10 @@ class GemmSm90:
 tma_atom_c,
 tma_tensor_c,
 epilogue_params,
-mAIdx,
 varlen_args.mCuSeqlensM,
 varlen_args.mCuSeqlensK,
 varlen_args.mTensormaps,
-
+varlen_args.mAIdx,
 self.cluster_layout_mnk,
 self.a_smem_layout_staged,
 self.b_smem_layout_staged,
@@ -569,6 +545,7 @@ class GemmSm90:
 @cute.kernel
 def kernel(
 self,
+tiled_mma: cute.TiledMma,
 tma_atom_a: Optional[cute.CopyAtom],
 mA_mkl: cute.Tensor,
 tma_atom_b: cute.CopyAtom,
@@ -578,11 +555,10 @@ class GemmSm90:
 tma_atom_c: Optional[cute.CopyAtom],
 mC_mnl: Optional[cute.Tensor],
 epilogue_params: ParamsBase,
-mAIdx: Optional[cute.Tensor],
 cu_seqlens_m: Optional[cute.Tensor],
 cu_seqlens_k: Optional[cute.Tensor],
 tensormaps: Optional[cute.Tensor],
-
+mAIdx: Optional[cute.Tensor],
 cluster_layout_mnk: cute.Layout,
 a_smem_layout: cute.ComposedLayout,
 b_smem_layout: cute.ComposedLayout,
@@ -621,6 +597,8 @@ class GemmSm90:
 varlen_m = const_expr(cu_seqlens_m is not None)
 varlen_k = const_expr(cu_seqlens_k is not None)
 assert not (varlen_m and varlen_k)
+if const_expr(self.gather_A):
+assert varlen_m or varlen_k
 has_D = const_expr(mD_mnl is not None)
 has_C = const_expr(mC_mnl is not None)

@@ -641,8 +619,6 @@ class GemmSm90:
 storage = smem.allocate(self.shared_storage)

 ab_pipeline = self.make_ab_pipeline(
-a_smem_layout=cute.slice_(a_smem_layout, (None, None, 0)),
-b_smem_layout=cute.slice_(b_smem_layout, (None, None, 0)),
 tiled_mma=tiled_mma,
 cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
 ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
@@ -682,27 +658,9 @@ class GemmSm90:
 epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)

 # Get tensormap buffer address
-tensormap_manager =
-
-
-tensormap_manager = TensorMapManagerSm90(
-cutlass.utils.TensorMapUpdateMode.GMEM, GemmSm90.bytes_per_tensormap
-)
-# equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
-tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
-if const_expr(varlen_m):
-tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
-tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
-tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
-)
-else:
-assert varlen_k
-tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
-tensormaps[tensormap_workspace_idx, 0, None].iterator
-)
-tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
-tensormaps[tensormap_workspace_idx, 1, None].iterator
-)
+tensormap_manager, tensormap_ab_ptrs, tensormap_d_ptr, tensormap_epi_ptrs = (
+self.tensormap_init(tensormaps, varlen_m, varlen_k, has_D, warp_idx)
+)

 TileSchedulerCls = partial(
 TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
@@ -717,14 +675,15 @@ class GemmSm90:
 is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
 if const_expr(varlen_k):
 # initialize tensormap for A & B
-
-
-
-
-
+if const_expr(not self.gather_A):
+tensormap_manager.init_tensormap_from_atom(
+tma_atom_a,
+tensormap_ab_ptrs[0],
+is_tma_warp,
+)
 tensormap_manager.init_tensormap_from_atom(
 tma_atom_b,
-
+tensormap_ab_ptrs[1],
 is_tma_warp,
 )
 # ///////////////////////////////////////////////////////////////////////////////
@@ -762,16 +721,12 @@ class GemmSm90:
 is_group_changed = batch_idx != last_batch_idx
 last_batch_idx = batch_idx
 if is_group_changed:
-
-
-
-
-
-
-0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1,
-0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1,
-),
-tensormap_smem_ptr=None,
+self.tensormap_update_AB(
+tensormap_manager,
+tensormap_ab_ptrs,
+cu_seqlens_k,
+batch_idx,
+is_tma_warp,
 )
 # ///////////////////////////////////////////////////////////////////////////
 # Local_tile partition global tensors
@@ -786,44 +741,54 @@ class GemmSm90:
 # (bM, bK, RestK)
 gA_k = cute.local_tile(
 mA_mk,
-cute.select(self.
+cute.select(self.cta_tile_shape_mnk, [0, 2]),
 (tile_coord_mnkl[0], None),
 )
 else:
-mA_mk = mA_mkl
 if const_expr(varlen_m):
 mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
-
-
+gAIdx = cute.local_tile(
+mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
+)
+# (M, K)
+mA_mk = mA_mkl
 else:
-
-
-
-
+assert varlen_k
+mAIdx_mk = cute.domain_offset((cu_seqlens_k[batch_idx],), mAIdx)
+# (tile_K, RestK)
+gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
+# (tile_M, K)
+mA_mk = cute.local_tile(
+mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
+)
 if const_expr(varlen_k):
 mB_nk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mB_nkl)
 else:
 mB_nk = mB_nkl[None, None, batch_idx]
 # (bN, bK, RestK)
 gB_k = cute.local_tile(
-mB_nk,
+mB_nk,
+cute.select(self.cta_tile_shape_mnk, [1, 2]),
+(tile_coord_mnkl[1], None),
 )
 # //////////////////////////////////////////////////////////////////////////
 # Partition shared tensor for TMA load A/B
 # //////////////////////////////////////////////////////////////////////////
+tma_desc_a_ptr, tma_desc_b_ptr = None, None
 if const_expr(varlen_k):
 # ensure the update to tensormap has completed before using it
+tensormap_a_ptr, tensormap_b_ptr = tensormap_ab_ptrs
 if is_group_changed and is_tma_warp:
-
+if const_expr(not self.gather_A):
+tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
 tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
-
-
-
+if const_expr(not self.gather_A):
+tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
+tensormap_a_ptr, cute.AddressSpace.generic
+)
 tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
 tensormap_b_ptr, cute.AddressSpace.generic
 )
-else:
-tma_desc_a_ptr, tma_desc_b_ptr = None, None
 # TMA load A partition_S/D
 a_cta_layout = cute.make_layout(
 cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
@@ -855,8 +820,11 @@ class GemmSm90:
 thr_copy_A = tiled_copy_A.get_slice(tidx)
 # (atom_v, CPY_M, 1, STAGE)
 tAsA = thr_copy_A.partition_D(sA)
-
-
+if const_expr(varlen_m): # k-major
+assert tAsA.shape[2] == 1
+tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
+else: # varlen_k, m-major
+tAsA = cute.group_modes(tAsA, 0, 3)
 copy_A = partial(cute.copy, tiled_copy_A)
 # TMA load B partition_S/D
 b_cta_layout = cute.make_layout(
@@ -877,9 +845,9 @@ class GemmSm90:
 k_len = (
 cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
 if const_expr(varlen_k)
-else mA_mkl.shape[1]
+else Int32(mA_mkl.shape[1])
 )
-k_tile_cnt = cute.ceil_div(k_len, self.
+k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
 if const_expr(not self.gather_A):
 ab_producer_state = self.load_AB(
 ab_pipeline,
@@ -894,7 +862,7 @@ class GemmSm90:
 )
 else:
 limit_m = (
-
+Int32(mA_mkl.shape[0])
 if const_expr(cu_seqlens_m is None)
 else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
 )
@@ -910,19 +878,17 @@ class GemmSm90:
 tBsB,
 k_tile_cnt,
 limit_A=(
-limit_m - tile_coord_mnkl[0] * self.
-
+limit_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+k_len,
 ),
+varlen_m=varlen_m,
 )
 tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
-tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
 tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
 work_tile = tile_scheduler.get_current_work()
 # End of persistent scheduler loop
 if const_expr(self.pingpong and not varlen_k):
 # Need to write the tile_idx to smem for the next WG in the pingpong mode
-# tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
-tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
 tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
 ab_pipeline.producer_tail(ab_producer_state)
 if is_scheduler_warp:
@@ -936,11 +902,20 @@ class GemmSm90:
 )
 if const_expr(varlen_m):
 # initialize tensormap for D
-
-
-
-
-
+if const_expr(has_D):
+tensormap_manager.init_tensormap_from_atom(
+tma_atom_d,
+tensormap_d_ptr,
+is_manager_warp=is_tma_warp,
+)
+for tma_atom, tensormap_epi_ptr in zip(
+self.epi_get_tma_atoms(epilogue_params), tensormap_epi_ptrs
+):
+tensormap_manager.init_tensormap_from_atom(
+tma_atom,
+tensormap_epi_ptr,
+is_manager_warp=is_tma_warp,
+)
 # //////////////////////////////////////////////////////////////////////////////
 # Partition global tensor for TiledMMA_A/B/C
 # //////////////////////////////////////////////////////////////////////////////
@@ -962,7 +937,9 @@ class GemmSm90:
 tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
 tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))

-acc_shape = tiled_mma.partition_shape_C(
+acc_shape = tiled_mma.partition_shape_C(
+cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
+)
 acc = cute.make_fragment(acc_shape, self.acc_dtype)
 acc_slow = None
 if const_expr(self.fp8_slow_accum):
@@ -974,10 +951,11 @@ class GemmSm90:
 self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
 self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")

-k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.
-c_tile_cnt = cute.size(cute.ceil_div(self.
+k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2])
+c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))

 ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
+epi_store_pipeline = self.make_epi_store_pipeline()
 epi_read_state = make_pipeline_state(
 pipeline.PipelineUserType.Consumer, self.epi_c_stage
 )
@@ -998,7 +976,7 @@ class GemmSm90:
 else:
 batch_idx = work_tile.tile_idx[3]
 k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
-k_tile_cnt = cute.ceil_div(k_len, self.
+k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
 ab_read_state.advance_iters(k_tile_cnt)
 tile_scheduler.advance_to_next_work()
 if const_expr(varlen_k):
@@ -1019,13 +997,14 @@ class GemmSm90:
 is_group_changed = batch_idx != last_batch_idx
 last_batch_idx = batch_idx
 if is_group_changed:
-
-
-
+self.tensormap_update_D_epi(
+tensormap_manager,
+tensormap_d_ptr,
+tensormap_epi_ptrs,
+epilogue_params,
+cu_seqlens_m,
+batch_idx,
 is_manager_warp=is_tma_warp,
-shapes=(cu_seqlens_m[batch_idx + 1],),
-orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
-tensormap_smem_ptr=None,
 )

 k_len = (
@@ -1033,7 +1012,7 @@ class GemmSm90:
 if const_expr(varlen_k)
 else mA_mkl.shape[1]
 )
-k_tile_cnt = cute.ceil_div(k_len, self.
+k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
 ab_read_state, tiled_mma = self.mma(
 ab_pipeline,
 ab_read_state,
@@ -1056,24 +1035,34 @@ class GemmSm90:
 self.pingpong_barrier_sync(warp_group_idx, "epi")

 epilogue_barrier = pipeline.NamedBarrier(
-barrier_id=int(NamedBarrierGemm.Epilogue),
+barrier_id=int(NamedBarrierGemm.Epilogue),
+num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
 )

+tma_desc_d_ptr, tma_desc_epi_ptrs = None, [None] * self.num_epi_tensormaps
 if const_expr(varlen_m):
 # ensure the update to tensormap has completed before using it
 if is_group_changed and is_tma_warp:
-
-
-
-
-
-
+if const_expr(has_D):
+tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
+for tensormap_epi_ptr in tensormap_epi_ptrs:
+tensormap_manager.fence_tensormap_update(tensormap_epi_ptr)
+if const_expr(has_D):
+tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
+tensormap_d_ptr, cute.AddressSpace.generic
+)
+tma_desc_epi_ptrs = [
+tensormap_manager.get_tensormap_ptr(
+tensormap_epi_ptr, cute.AddressSpace.generic
+)
+for tensormap_epi_ptr in tensormap_epi_ptrs
+]

 if const_expr(has_D):
 bSG_sD, bSG_gD = self.epilog_gmem_copy_and_partition(
 tma_atom_d,
 mD_mnl,
-self.
+self.cta_tile_shape_mnk[:2],
 self.epi_tile,
 sD,
 tile_coord_mnkl,
@@ -1086,7 +1075,7 @@ class GemmSm90:
 bGS_sC, bGS_gC = self.epilog_gmem_copy_and_partition(
 tma_atom_c,
 mC_mnl,
-self.
+self.cta_tile_shape_mnk[:2],
 self.epi_tile,
 sC,
 tile_coord_mnkl,
@@ -1118,7 +1107,9 @@ class GemmSm90:
 epi_read_state, epi_producer_state = self.epilogue(
 epilogue_params,
 epi_smem_tensors,
+tma_desc_epi_ptrs,
 epi_pipeline,
+epi_store_pipeline,
 epi_read_state,
 epi_producer_state,
 tiled_mma,
@@ -1147,7 +1138,7 @@ class GemmSm90:
 # so we have to make sure the smem content is done reading before signaling
 # the next WG's epilogue.
 if is_tma_warp:
-
+epi_store_pipeline.producer_tail()
 self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")

 if const_expr(not self.pingpong):
@@ -1168,15 +1159,16 @@ class GemmSm90:
 if work_tile.is_valid_tile:
 batch_idx = work_tile.tile_idx[3]
 k_len = cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
-k_tile_cnt = cute.ceil_div(k_len, self.
+k_tile_cnt = cute.ceil_div(k_len, self.cta_tile_shape_mnk[2])
 ab_read_state.advance_iters(k_tile_cnt)
 tile_scheduler.advance_to_next_work()
 work_tile = tile_scheduler.get_current_work()
 # End of persistent scheduler loop

+# Wait for D store complete
 if const_expr(not self.pingpong):
 if is_tma_warp:
-
+epi_store_pipeline.producer_tail()

 @cute.jit
 def load_AB(
@@ -1219,37 +1211,57 @@ class GemmSm90:
 ab_pipeline: cutlass.pipeline.PipelineAsync,
 ab_producer_state: cutlass.pipeline.PipelineState,
 thr_copy_A: cute.core.ThrCopy,
-mA: cute.Tensor,
+mA: cute.Tensor, # (M, K) if varlen_m, (tile_M, K) if varlen_k
 tAsA: cute.Tensor,
-gAIdx: cute.Tensor,
+gAIdx: cute.Tensor, # (tile_M,) if varlen_m, (tile_K, RestK) if varlen_k
 copy_B: Callable,
 tBgB: cute.Tensor,
 tBsB: cute.Tensor,
 k_tile_cnt: Int32,
 limit_A: Tuple[Int32, Int32],
+varlen_m: bool,
 ) -> cutlass.pipeline.PipelineState:
-# (atom_v, CPY_M, 1, RestK)
 limit_m, limit_k = limit_A
-
-
+# Do we need to check if we overshoot tile_M when we load A?
+is_even_m_smem = self.cta_tile_shape_mnk[0] % thr_copy_A.tiler_mn[0].shape == 0
+if const_expr(not is_even_m_smem):
+limit_m = min(limit_m, self.cta_tile_shape_mnk[0])
+elems_per_load = cute.size(tAsA.shape[0][0])
+cA = cute.make_identity_tensor(cute.select(self.cta_tile_shape_mnk, [0, 2]))
 tAcA = thr_copy_A.partition_S(cA)
 t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
 # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
 # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
 # This is so that when we do the comparison, t0AcA is known at compile time.
 limit_m = limit_m - tAcA[0][0]
+limit_k = limit_k - tAcA[0][1]
 # Read indices for A
 rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
-
-
-
-
-
-
-
-
+cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
+tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+for m in cutlass.range_constexpr(rows_per_thread):
+tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
+m_idx, k_idx, tAmA = None, None, None
+if const_expr(varlen_m):
+m_idx = cute.make_fragment(rows_per_thread, Int32)
+for m in cutlass.range(rows_per_thread):
+row_idx = tAcA[0, m, 0][0]
+if tApA_m[m]:
+m_idx[m] = gAIdx[row_idx]
+else:
+m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
+else:
+k_idx = cute.make_fragment(cols_per_thread, Int32) # Will be read later
+threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
+# This is very convoluted but idk a better way
+# for tile_M=128, flat_divide gives (8, 16, K),
+# then logical_divide gives ((8, 1), (8, 2), K).
+tidx = thr_copy_A.thr_idx
+tAmA = cute.logical_divide(
+cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
+)[None, (tidx % threads_per_col, None), None] # ((8, 1), 2, K)
 # (m, (bK, RestK))
-mA_k = cute.logical_divide(mA, (None, self.
+mA_k = cute.logical_divide(mA, (None, self.cta_tile_shape_mnk[2]))
 warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
 # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
 peek_ab_empty_status = Boolean(True)
@@ -1260,31 +1272,55 @@ class GemmSm90:
 # /////////////////////////////////////////////////////////////////////////
 copy_A = partial(cute.copy, thr_copy_A)
 for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
+if const_expr(not varlen_m): # Prefetch mAIdx early, even before smem is free
+gAIdx_cur = gAIdx[None, k_tile]
+for k in cutlass.range(cols_per_thread):
+col_idx = tAcA[0, 0, k][1]
+k_idx[k] = gAIdx_cur[col_idx]
 # Wait for A/B buffers to be empty before loading into them
 # Also sets the transaction barrier for the A/B buffers
+# A tiny bit faster to rotate the warp that does TMA
+# However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id
+# since that's the warp that does the tensormap update.
+tma_warp_id = self.ab_load_warp_id + (
+(k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
+)
 ab_pipeline.producer_acquire(
 ab_producer_state,
 peek_ab_empty_status,
-
-is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
+is_tma_warp=warp_idx == tma_warp_id,
 )
 # A bit faster to load B first while we calculate the predicate for A
-if warp_idx ==
+if warp_idx == tma_warp_id:
 copy_B(
 tBgB[None, k_tile],
 tBsB[None, ab_producer_state.index],
 tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
 )
 # (m, bK)
-
-
-
-
-
-#
-
-
-
+if const_expr(varlen_m):
+mA_cur = mA_k[None, (None, k_tile)]
+for m in cutlass.range_constexpr(tAcA.shape[1]):
+# cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
+# ((elems_per_load), thread_per_row)
+# But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
+# So we append 1s to the last dimension and then do tiled_divide, then slice.
+mA_row = cute.tiled_divide(
+cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
+)[None, None, 0]
+if const_expr(is_even_m_smem) or tApA_m[m]:
+# There's only 1 load per row
+assert cute.size(tAcA.shape, mode=[2]) == 1
+ki = tAcA[0, 0, 0][1] // elems_per_load
+copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
+else:
+for k in cutlass.range_constexpr(tAcA.shape[2]):
+# copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), ab_producer_state.index], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
+for m in cutlass.range_constexpr(tAcA.shape[1]):
+if tApA_m[m]:
+copy_A(
+tAmA[None, m, k_idx[k]], tAsA[(None, m, k), ab_producer_state.index]
+)
 # This tells mbarrier to track the completion of cp.async
 ab_pipeline.producer_commit(ab_producer_state)
 ab_producer_state.advance()
@@ -1294,32 +1330,57 @@ class GemmSm90:
 # bound checking in the K dimension on the last k_tile
 if 0 < k_tile_cnt:
 k_tile = k_tile_cnt - 1
+tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+limit_k -= k_tile * self.cta_tile_shape_mnk[2]
+for k in cutlass.range_constexpr(cols_per_thread):
+tApA_k[k] = t0AcA[0, 0, k][1] < limit_k
+if const_expr(not varlen_m):
+gAIdx_cur = gAIdx[None, k_tile]
+for k in cutlass.range(cols_per_thread):
+col_idx = tAcA[0, 0, k][1]
+if tApA_k[k]:
+k_idx[k] = gAIdx_cur[col_idx]
+else:
+k_idx[k] = -1
+tma_warp_id = self.ab_load_warp_id + (
+(k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
+)
 ab_pipeline.producer_acquire(
 ab_producer_state,
 peek_ab_empty_status,
-is_tma_warp=warp_idx ==
+is_tma_warp=warp_idx == tma_warp_id,
 )
-if warp_idx ==
+if warp_idx == tma_warp_id:
 copy_B(
 tBgB[None, k_tile],
 tBsB[None, ab_producer_state.index],
 tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+if const_expr(varlen_m):
+# (m, bK)
+mA_cur = mA_k[None, (None, k_tile)]
+for m in cutlass.range_constexpr(tAcA.shape[1]):
+# ((elems_per_load, 1), thread_per_row)
+mA_row = cute.tiled_divide(
+cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
+)[None, None, 0]
+if const_expr(is_even_m_smem) or tApA_m[k]:
+# There's only 1 load per row
+assert cute.size(tAcA.shape, mode=[2]) == 1
+ki = tAcA[0, 0, 0][1] // elems_per_load
+copy_A(
+mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA_k
+)
+else:
+tApA_k = cute.prepend_ones(tApA_k, up_to_rank=2) # (1, cols_per_thread)
+for k in cutlass.range_constexpr(tAcA.shape[2]):
+for m in cutlass.range_constexpr(tAcA.shape[1]):
+if tApA_m[m]:
+copy_A(
+tAmA[None, m, k_idx[k]],
+tAsA[(None, m, k), ab_producer_state.index],
+pred=tApA_k[None, k],
+)
 ab_pipeline.producer_commit(ab_producer_state)
 ab_producer_state.advance()
 return ab_producer_state
@@ -1416,7 +1477,9 @@ class GemmSm90:
 self,
 params: EpilogueParams,
 epi_smem_tensors: Tuple[cute.Tensor, ...],
+tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
 epi_pipeline: cutlass.pipeline.PipelineAsync,
+epi_store_pipeline: cutlass.pipeline.PipelineAsync,
 epi_read_state: cutlass.pipeline.PipelineState,
 epi_producer_state: cutlass.pipeline.PipelineState,
 tiled_mma: cute.TiledMma,
@@ -1443,7 +1506,7 @@ class GemmSm90:
 has_D = const_expr(copy_D is not None)
 # We iterate over epi tiles in the N dimension first before the M dimension
 epi_tile_shape = cute.zipped_divide(
-cute.make_layout(self.
+cute.make_layout(self.cta_tile_shape_mnk[:2]), self.epi_tile
 ).shape[1]
 epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
 epi_tile_num = cute.size(epi_tile_shape)
@@ -1491,8 +1554,8 @@ class GemmSm90:
 if is_tma_warp:
 if const_expr(has_D):
 copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
-
-
+epi_store_pipeline.producer_commit()
+epi_store_pipeline.producer_acquire()
 epilogue_barrier.arrive_and_wait()

 return epi_read_state, epi_producer_state
@@ -1544,21 +1607,171 @@ class GemmSm90:
 tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
 return None

-def
+def tensormap_init(
+self,
+tensormaps: Optional[cute.Tensor],
+varlen_m: bool,
+varlen_k: bool,
+has_D: bool,
+warp_idx: Int32,
+):
+tensormap_manager = None
+tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None
+tensormap_epi_ptrs = [None] * self.num_epi_tensormaps
+if const_expr(varlen_m or varlen_k):
+tensormap_manager = TensorMapManagerSm90(
+cutlass.utils.TensorMapUpdateMode.GMEM, self.__class__.bytes_per_tensormap
+)
+# equivalent to bidx + bidy * gridDim.x + bidxz * gridDim.x * gridDim.y
+tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx())
+if const_expr(varlen_m):
+tensormap_d_idx = warp_idx // 4 if const_expr(self.pingpong) else 0
+tensormap_epi_offset = tensormap_d_idx
+if const_expr(has_D):
+tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
+tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator
+)
+tensormap_epi_offset += 1 if not self.pingpong else 2
+tensormap_epi_ptrs = [
+tensormap_manager.get_tensormap_ptr(
+tensormaps[
+tensormap_workspace_idx,
+tensormap_epi_offset + i * (1 if not self.pingpong else 2),
+None,
+].iterator
+)
+for i in range(self.num_epi_tensormaps)
+]
+else:
+assert varlen_k
+if const_expr(not self.gather_A):
+tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
+tensormaps[tensormap_workspace_idx, 0, None].iterator
+)
+tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
+tensormaps[
+tensormap_workspace_idx, 1 if not self.gather_A else 0, None
+].iterator
+)
+tensormap_ab_ptrs = [tensormap_a_ptr, tensormap_b_ptr]
+return (
+tensormap_manager,
+tensormap_ab_ptrs,
+tensormap_d_ptr,
+tensormap_epi_ptrs,
+)
+
+def tensormap_update_AB(
+self,
+tensormap_manager,
+tensormap_ab_ptrs,
+cu_seqlens_k: cute.Tensor,
+batch_idx: Int32,
+is_manager_warp: bool | Boolean,
+) -> None:
+# construct tensor A/B based on real address, shape and stride information
+tensormap_a_ptr, tensormap_b_ptr = tensormap_ab_ptrs
+tensormap_ptrs = [tensormap_b_ptr]
+shapes = [cu_seqlens_k[batch_idx + 1]]
+orders = [0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1]
+if const_expr(not self.gather_A):
+tensormap_ptrs.insert(0, tensormap_a_ptr)
+shapes.insert(0, cu_seqlens_k[batch_idx + 1])
+orders.insert(0, 0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1)
+tensormap_manager.update_tensormap_shape(
+tensormap_ptrs,
+is_manager_warp=is_manager_warp,
+shapes=shapes,
+orders=orders,
+tensormap_smem_ptr=None,
+)
+
+def tensormap_update_D_epi(
+self,
+tensormap_manager,
+tensormap_d_ptr,
+tensormap_epi_ptrs,
+epilogue_params: EpilogueParams,
+cu_seqlens_m: cute.Tensor,
+batch_idx: Int32,
+is_manager_warp: bool | Boolean,
+) -> None:
+# construct tensor D based on real address, shape and stride information
+tensormap_ptrs, shapes, orders = [], [], []
+if const_expr(tensormap_d_ptr is not None):
+tensormap_ptrs.append(tensormap_d_ptr)
+shapes.append(cu_seqlens_m[batch_idx + 1])
+orders.append(0 if const_expr(self.d_layout.is_m_major_c()) else 1)
+epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
+epilogue_params, cu_seqlens_m, batch_idx
+)
+tensormap_ptrs.extend(tensormap_epi_ptrs)
+shapes.extend(epi_shapes)
+orders.extend(epi_orders)
+tensormap_manager.update_tensormap_shape(
+tensormap_ptrs,
+is_manager_warp=is_manager_warp,
+shapes=shapes,
+orders=orders,
+tensormap_smem_ptr=None,
+)
+
+def get_scheduler_class(self, varlen_m: bool = False):
 """Return the scheduler class to use. Override in subclasses for custom schedulers."""
-return TileScheduler
+return TileScheduler if not varlen_m else VarlenMTileScheduler

-def get_scheduler_arguments(
+def get_scheduler_arguments(
+self,
+mA: cute.Tensor,
+mB: cute.Tensor,
+mD: Optional[cute.Tensor],
+scheduler_args,
+varlen_args,
+):
 """Create scheduler arguments. Override in subclasses for custom schedulers."""
-
-
-
-
-
-
-
-
-
+if const_expr(varlen_args.mCuSeqlensM is None):
+num_problems = (
+mD.shape[2]
+if mD is not None
+else (
+mB.shape[2]
+if varlen_args.mCuSeqlensK is None
+else varlen_args.mCuSeqlensK.shape[0] - 1
+)
+)
+problem_shape_ntile_mnl = (
+cute.ceil_div(mA.shape[0], self.cta_tile_shape_mnk[0]),
+cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
+num_problems,
+)
+tile_sched_args = TileSchedulerArguments(
+problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+raster_order=scheduler_args.raster_order,
+group_size=scheduler_args.max_swizzle_size,
+cluster_shape_mnk=self.cluster_shape_mnk,
+tile_count_semaphore=scheduler_args.tile_count_semaphore,
+batch_idx_permute=scheduler_args.batch_idx_permute,
+is_persistent=self.is_persistent,
+)
+else:
+assert mD is not None or not self.gather_A
+problem_shape_ntile_mnl = (
+None,
+cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
+varlen_args.mCuSeqlensM.shape[0] - 1,
+)
+tile_sched_args = VarlenMTileSchedulerArguments(
+problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+total_m=mD.shape[0] if mD is not None else varlen_args.mAIdx.shape[0],
+cu_seqlens_m=varlen_args.mCuSeqlensM,
+raster_order=scheduler_args.raster_order,
+group_size=scheduler_args.max_swizzle_size,
+tile_shape_mn=self.cta_tile_shape_mnk[:2],
+cluster_shape_mnk=self.cluster_shape_mnk,
+tile_count_semaphore=scheduler_args.tile_count_semaphore,
+is_persistent=self.is_persistent,
+)
+return tile_sched_args

 def epi_visit_acc(
 self,
@@ -1575,10 +1788,28 @@ class GemmSm90:
|
|
|
1575
1788
|
) -> EpilogueParams:
|
|
1576
1789
|
return GemmSm90.EpilogueParams(alpha=args.alpha, beta=args.beta)
|
|
1577
1790
|
|
|
1791
|
+
def epi_get_tma_atoms(
|
|
1792
|
+
self, params: EpilogueParams, *, loc=None, ip=None
|
|
1793
|
+
) -> list[cute.CopyAtom]:
|
|
1794
|
+
"""Subclasses can override this"""
|
|
1795
|
+
return []
|
|
1796
|
+
|
|
1797
|
+
def epi_get_tensormap_update_shapes_orders(
|
|
1798
|
+
self,
|
|
1799
|
+
params: EpilogueParams,
|
|
1800
|
+
cu_seqlens_m: cute.Tensor,
|
|
1801
|
+
batch_idx: Int32,
|
|
1802
|
+
*,
|
|
1803
|
+
loc=None,
|
|
1804
|
+
ip=None,
|
|
1805
|
+
) -> tuple[list[Int32], list[int]]:
|
|
1806
|
+
"""Subclasses can override this"""
|
|
1807
|
+
return [], []
|
|
1808
|
+
|
|
1578
1809
|
@staticmethod
|
|
1579
1810
|
def epi_smem_bytes_per_stage(
|
|
1580
1811
|
args: Optional[EpilogueArguments],
|
|
1581
|
-
|
|
1812
|
+
cta_tile_shape_mnk: Tuple[int, int, int],
|
|
1582
1813
|
epi_tile: Tuple[int, int],
|
|
1583
1814
|
) -> int:
|
|
1584
1815
|
return 0
|
|
@@ -1589,7 +1820,7 @@ class GemmSm90:
|
|
|
1589
1820
|
def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
|
|
1590
1821
|
return tuple()
|
|
1591
1822
|
|
|
1592
|
-
def pingpong_barrier_sync(self, warp_group_idx: Int32, stage:
|
|
1823
|
+
def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
|
|
1593
1824
|
assert stage in ["mma", "epi"]
|
|
1594
1825
|
barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
|
|
1595
1826
|
cute.arch.barrier(
|
|
@@ -1597,7 +1828,7 @@ class GemmSm90:
|
|
|
1597
1828
|
number_of_threads=2 * self.num_threads_per_warp_group,
|
|
1598
1829
|
)
|
|
1599
1830
|
|
|
1600
|
-
def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage:
|
|
1831
|
+
def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
|
|
1601
1832
|
assert stage in ["mma", "epi"]
|
|
1602
1833
|
barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
|
|
1603
1834
|
cute.arch.barrier_arrive(
|
|
@@ -1691,8 +1922,6 @@ class GemmSm90:
|
|
|
1691
1922
|
|
|
1692
1923
|
def make_ab_pipeline(
|
|
1693
1924
|
self,
|
|
1694
|
-
a_smem_layout: cute.Layout | cute.ComposedLayout,
|
|
1695
|
-
b_smem_layout: cute.Layout | cute.ComposedLayout,
|
|
1696
1925
|
tiled_mma: cute.TiledMma,
|
|
1697
1926
|
cluster_layout_vmnk: cute.Layout,
|
|
1698
1927
|
ab_pipeline_mbar_ptr: cute.Pointer,
|
|
@@ -1702,20 +1931,23 @@ class GemmSm90:
|
|
|
1702
1931
|
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
|
|
1703
1932
|
# Each warp will contribute to the arrive count with the number of mcast size
|
|
1704
1933
|
mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
|
|
1705
|
-
consumer_arrive_cnt = mcast_size
|
|
1934
|
+
consumer_arrive_cnt = mcast_size
|
|
1935
|
+
if const_expr(self.arch != 100):
|
|
1936
|
+
consumer_arrive_cnt *= tiled_mma.size // cute.arch.WARP_SIZE
|
|
1706
1937
|
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
1707
1938
|
pipeline.Agent.Thread, consumer_arrive_cnt
|
|
1708
1939
|
)
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
|
|
1940
|
+
if const_expr(self.arch != 100):
|
|
1941
|
+
pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
|
|
1942
|
+
else:
|
|
1943
|
+
# TODO: we need a pipeline class for TMACpAsyncUMMA
|
|
1944
|
+
pipeline_cls = pipeline.PipelineTmaUmma if not self.gather_A else PipelineTmaCpAsync
|
|
1713
1945
|
return pipeline_cls.create(
|
|
1714
1946
|
barrier_storage=ab_pipeline_mbar_ptr,
|
|
1715
1947
|
num_stages=self.ab_stage,
|
|
1716
1948
|
producer_group=ab_pipeline_producer_group,
|
|
1717
1949
|
consumer_group=ab_pipeline_consumer_group,
|
|
1718
|
-
tx_count=
|
|
1950
|
+
tx_count=self.num_tma_load_bytes,
|
|
1719
1951
|
cta_layout_vmnk=cluster_layout_vmnk,
|
|
1720
1952
|
)
|
|
1721
1953
|
|
|
@@ -1725,7 +1957,7 @@ class GemmSm90:
|
|
|
1725
1957
|
# Threads/warps participating in this pipeline
|
|
1726
1958
|
epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
1727
1959
|
# Each warp will contribute 1 to the arrive count
|
|
1728
|
-
consumer_arrive_cnt = self.
|
|
1960
|
+
consumer_arrive_cnt = self.num_epi_warps
|
|
1729
1961
|
epi_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
1730
1962
|
pipeline.Agent.Thread, consumer_arrive_cnt
|
|
1731
1963
|
)
|
|
@@ -1738,6 +1970,16 @@ class GemmSm90:
|
|
|
1738
1970
|
tx_count=tma_copy_c_bytes,
|
|
1739
1971
|
)
|
|
1740
1972
|
|
|
1973
|
+
def make_epi_store_pipeline(self):
|
|
1974
|
+
# Threads/warps participating in tma store pipeline
|
|
1975
|
+
num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
|
|
1976
|
+
epi_store_producer_group = pipeline.CooperativeGroup(
|
|
1977
|
+
pipeline.Agent.Thread, num_epi_threads, num_epi_threads
|
|
1978
|
+
)
|
|
1979
|
+
return pipeline.PipelineTmaStore.create(
|
|
1980
|
+
num_stages=self.epi_stage, producer_group=epi_store_producer_group
|
|
1981
|
+
)
|
|
1982
|
+
|
|
1741
1983
|
def make_sched_pipeline(
|
|
1742
1984
|
self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
|
|
1743
1985
|
):
|
|
@@ -1766,21 +2008,21 @@ class GemmSm90:
     @classmethod
     def _compute_stages(
         cls,
-
+        cta_tile_shape_mnk: Tuple[int, int, int],
         epi_tile: Tuple[int, int],
         a_dtype: Type[cutlass.Numeric],
         b_dtype: Type[cutlass.Numeric],
         d_dtype: Optional[Type[cutlass.Numeric]],
         c_dtype: Optional[Type[cutlass.Numeric]],
-        epilogue_args:
+        epilogue_args: EpilogueArguments,
         smem_capacity: int,
         occupancy: int,
-        overlap_sD_sA: bool,
+        overlap_sD_sA: bool = False,
     ) -> Tuple[int, int]:
         """Computes the number of stages for A/B/C operands based on heuristics.
 
-        :param
-        :type
+        :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
+        :type cta_tile_shape_mnk: Tuple[int, int, int]
         :param a_dtype: Data type of operand A.
         :type a_dtype: type[cutlass.Numeric]
         :param b_dtype: Data type of operand B.
@@ -1803,15 +2045,15 @@ class GemmSm90:
             cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
         )
         epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
-            epilogue_args,
+            epilogue_args, cta_tile_shape_mnk, epi_tile
         )
         epi_bytes = epi_bytes_per_stage * epi_stage
         epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
         if c_dtype is not None:
             epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
 
-        a_shape = cute.slice_(
-        b_shape = cute.slice_(
+        a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
+        b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
         ab_bytes_per_stage = (
             cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
         )
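The hunk above switches `_compute_stages` to the renamed `cta_tile_shape_mnk` and now passes the CTA tile and epilogue tile into `epi_smem_bytes_per_stage`. The stage count itself is plain byte arithmetic: shared memory left after the epilogue reservation divided by the per-stage footprint of one A tile plus one B tile. A rough standalone sketch in plain Python; the tile, dtype widths, and reserved-byte figures are illustrative assumptions, not values hard-coded in the kernel:

def ab_stages(cta_tile_mnk, a_bits, b_bits, smem_capacity, reserved_bytes):
    m, n, k = cta_tile_mnk
    # Bytes one pipeline stage of A plus one stage of B occupy in SMEM.
    ab_bytes_per_stage = m * k * a_bits // 8 + n * k * b_bits // 8
    return (smem_capacity - reserved_bytes) // ab_bytes_per_stage

# 128x256x64 BF16 tile, ~228 KB of SMEM, ~32 KB kept for epilogue + barriers:
print(ab_stages((128, 256, 64), 16, 16, 228 * 1024, 32 * 1024))  # -> 4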
@@ -1829,15 +2071,15 @@ class GemmSm90:
 
     @staticmethod
     def _sm90_compute_tile_shape_or_override(
-
+        cta_tile_shape_mnk: Tuple[int, int, int],
         atom_layout_mnk: Tuple[int, int, int],
         element_type: Optional[Type[cutlass.Numeric]] = None,
         epi_tile_override: Tuple[int, int] | None = None,
     ) -> Tuple[int, int]:
         """Compute the epilogue tile shape or use override if provided.
 
-        :param
-        :type
+        :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
+        :type cta_tile_shape_mnk: Tuple[int, int, int]
         :param element_type: Data type of elements
         :type element_type: type[cutlass.Numeric]
         :param is_cooperative: Whether to use cooperative approach
@@ -1850,12 +2092,12 @@ class GemmSm90:
         """
         if epi_tile_override is not None:
            return epi_tile_override
-        if
-            tile_m = math.gcd(128, cute.size(
-            tile_n = math.gcd(32, cute.size(
-        elif
-            tile_m = math.gcd(192, cute.size(
-            tile_n = math.gcd(32, cute.size(
+        if cta_tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
+            tile_m = math.gcd(128, cute.size(cta_tile_shape_mnk, mode=[0]))
+            tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
+        elif cta_tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
+            tile_m = math.gcd(192, cute.size(cta_tile_shape_mnk, mode=[0]))
+            tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
         else:
             # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
             # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
@@ -1864,13 +2106,13 @@ class GemmSm90:
             # We could change the epilogue to accommodate this,
             # but it's easier to just set epi_tile_m = 64.
             n_perf = 64 if element_type is not None and element_type.width == 8 else 32
-            tile_m = math.gcd(64, cute.size(
-            tile_n = math.gcd(n_perf, cute.size(
+            tile_m = math.gcd(64, cute.size(cta_tile_shape_mnk, mode=[0]))
+            tile_n = math.gcd(n_perf, cute.size(cta_tile_shape_mnk, mode=[1]))
         return (tile_m, tile_n)
 
     @staticmethod
     def _make_smem_layouts(
-
+        cta_tile_shape_mnk: Tuple[int, int, int],
         epi_tile: Tuple[int, int],
         a_dtype: Type[cutlass.Numeric],
         a_layout: LayoutEnum,
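The two hunks above are the gcd-based epilogue-tile heuristic rewritten against `cta_tile_shape_mnk`: prefer a 128- or 192-row epilogue tile when the CTA tile M is a multiple of it and more than one MMA atom is stacked along M, otherwise fall back to 64 rows (and 64 columns for 8-bit outputs). The same selection, restated as a standalone function over plain tuples (illustrative only, not the CuTe DSL code itself):

import math

def epi_tile_for(cta_tile_shape_mnk, atom_layout_mnk, out_dtype_width=16):
    m, n, _ = cta_tile_shape_mnk
    if m % 128 == 0 and atom_layout_mnk[0] > 1:
        return math.gcd(128, m), math.gcd(32, n)
    elif m % 192 == 0 and atom_layout_mnk[0] > 1:
        return math.gcd(192, m), math.gcd(32, n)
    else:
        # 8-bit outputs prefer a wider epilogue tile along N.
        n_perf = 64 if out_dtype_width == 8 else 32
        return math.gcd(64, m), math.gcd(n_perf, n)

print(epi_tile_for((128, 192, 64), (2, 1, 1)))  # (128, 32)
print(epi_tile_for((128, 192, 64), (1, 2, 1)))  # (64, 32) -- the 1 x 2 atom case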
@@ -1888,8 +2130,8 @@ class GemmSm90:
     ]:
         """Create shared memory layouts for A, B, and C tensors.
 
-        :param
-        :type
+        :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
+        :type cta_tile_shape_mnk: Tuple[int, int, int]
         :param epi_tile: Epilogue tile shape
         :type epi_tile: Tuple[int, int]
         :param a_dtype: Data type for matrix A
@@ -1912,11 +2154,11 @@ class GemmSm90:
         :return: Tuple of shared memory layouts for A, B, and C
         :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]
         """
-        a_smem_shape = cute.slice_(
+        a_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
 
         a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
         b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
-        a_major_mode_size =
+        a_major_mode_size = cta_tile_shape_mnk[2 if a_is_k_major else 0]
         a_smem_layout_atom = warpgroup.make_smem_layout_atom(
             sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
             a_dtype,
@@ -1927,9 +2169,9 @@ class GemmSm90:
             order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
         )
 
-        b_smem_shape = cute.slice_(
+        b_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
 
-        b_major_mode_size =
+        b_major_mode_size = cta_tile_shape_mnk[2 if b_is_k_major else 1]
         b_smem_layout_atom = warpgroup.make_smem_layout_atom(
             sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
             b_dtype,
@@ -1983,7 +2225,7 @@ class GemmSm90:
         tensor_d: cute.Tensor,
         epi_smem_layout_staged: cute.ComposedLayout,
         epi_tile: Tuple[int, int],
-
+        op_type: Literal["store", "load", "add"],
     ) -> Tuple[cute.CopyAtom, cute.Tensor]:
         """Create TMA atoms and tensors for storing D or loading C.
 
@@ -1997,13 +2239,15 @@ class GemmSm90:
         :return: TMA atom and tensor for C
         :rtype: Tuple[cute.CopyAtom, cute.Tensor]
         """
-        assert
+        assert op_type in ["load", "store", "add"]
         epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
         d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
         op = (
             cpasync.CopyBulkTensorTileG2SOp()
-            if
+            if op_type == "load"
             else cpasync.CopyBulkTensorTileS2GOp()
+            if op_type == "store"
+            else cpasync.CopyReduceBulkTensorTileS2GOp(cute.ReductionOp.ADD)
         )
         tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
             op, tensor_d, epi_smem_layout, d_cta_v_layout
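The `op_type` argument above picks between three bulk-tensor ops for the C/D tensormaps: a G2S load for C, an S2G store for D, and an S2G reduce-add when `add_to_output` is requested. In reference PyTorch terms the intended epilogue semantics look roughly like the sketch below; the helper name and the alpha/beta handling are illustrative assumptions, not the kernel's actual code path:

import torch

def epilogue_reference(acc, C=None, D_prev=None, op_type="store", alpha=1.0, beta=1.0):
    # "load" corresponds to reading C here; "store"/"add" decide how D is written.
    out = alpha * acc + (beta * C if C is not None else 0)
    if op_type == "add":        # TMA reduce-add into the existing output
        return D_prev + out
    return out                  # op_type == "store": plain TMA store of D

acc = torch.randn(128, 64)
prev = torch.zeros(128, 64)
assert torch.allclose(epilogue_reference(acc, D_prev=prev, op_type="add"), acc)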
@@ -2013,7 +2257,7 @@ class GemmSm90:
     @staticmethod
     def _make_tma_atoms_and_tensors(
         tensor: cute.Tensor,
-
+        smem_layout: cute.ComposedLayout,
         smem_tile: Tuple[int, int],
         mcast_dim: int,
     ) -> Tuple[cute.CopyAtom, cute.Tensor]:
@@ -2021,8 +2265,8 @@ class GemmSm90:
 
         :param tensor: Input tensor (A or B)
         :type tensor: cute.Tensor
-        :param
-        :type
+        :param smem_layout: Shared memory layout for the tensor
+        :type smem_layout: cute.ComposedLayout
         :param smem_tile: Shared memory tile shape
         :type smem_tile: Tuple[int, int]
         :param mcast_dim: Multicast dimension
@@ -2036,8 +2280,6 @@ class GemmSm90:
             if mcast_dim == 1
             else cpasync.CopyBulkTensorTileG2SMulticastOp()
         )
-
-        smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
         tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
             op,
             tensor,
@@ -2054,13 +2296,18 @@ class GemmSm90:
             num_bits_per_copy=copy_bits,
         )
         copy_elems = copy_bits // dtype.width
-
+        loads_per_cache_line = 128 * 8 // copy_bits  # 128 bytes per cache line
+        shape_dim_1 = cute.size(self.cta_tile_shape_mnk[2]) // copy_elems
+        if shape_dim_1 > loads_per_cache_line:
+            shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line)
         # thread layout for copy
         thread_layout = cute.make_layout(
             (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
         )
         if major_mode != LayoutEnum.ROW_MAJOR:
-            shape_dim_0 = cute.size(self.
+            shape_dim_0 = cute.size(self.cta_tile_shape_mnk[0]) // copy_elems
+            if shape_dim_0 > loads_per_cache_line:
+                shape_dim_0 = math.gcd(shape_dim_0, loads_per_cache_line)
             thread_layout = cute.make_layout(
                 (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
             )
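The added `loads_per_cache_line` clamp keeps the fast dimension of the cp.async thread layout from spanning more than one 128-byte cache line per row of threads. Worked arithmetic for an assumed configuration (128-bit copies of a 16-bit dtype, 128 threads, K tile of 256); plain Python, for illustration only:

import math

copy_bits = 128                                  # bits moved per cp.async copy
dtype_width = 16                                 # e.g. BF16 operand
num_threads = 128
tile_k = 256                                     # K extent of the CTA tile

copy_elems = copy_bits // dtype_width            # 8 elements per copy
loads_per_cache_line = 128 * 8 // copy_bits      # 8 copies cover one 128 B line
shape_dim_1 = tile_k // copy_elems               # 32 copies along K
if shape_dim_1 > loads_per_cache_line:
    shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line)   # clamp to 8
print((num_threads // shape_dim_1, shape_dim_1))                # thread layout (16, 8)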
@@ -2142,10 +2389,11 @@ class GemmSm90:
 
 
 def gemm_sm90(
-
-
-
-
+    # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k
+    A: Tensor,
+    B: Tensor,  # (l, n, k) or (n, total_k) if varlen_k
+    D: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
     tile_count_semaphore: Optional[Tensor],  # (1,)
     tile_M: int,
     tile_N: int,
@@ -2155,9 +2403,37 @@ def gemm_sm90(
     persistent: bool = True,
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
+    cu_seqlens_k: Optional[Tensor] = None,  # (l+1,) cumulative sum of k values for variable length
+    A_idx: Optional[Tensor] = None,  # (total_m,) or (total_k,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (l,) permutation of batch indices for scheduler
+    add_to_output: bool = False,
 ) -> None:
-
-
+    varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
+    assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
+        "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
+    )
+    gather_A = A_idx is not None
+    if gather_A:
+        assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
+        assert cluster_N == 1, "gather_A requires cluster_N=1"
+    if varlen:
+        assert persistent, "varlen requires persistent=True"
+    if add_to_output:
+        assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
+    if cu_seqlens_m is not None:
+        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
+        assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
+    if cu_seqlens_k is not None:
+        assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
+        assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
+
+    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+        A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
+    )
+    GemmWrapperBase.permute_tensors(
+        tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
+    )
     GemmWrapperBase.extract_dtypes(tensor_infos)
     major_configs = {
         "A": ("m", "k", "l"),
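The new wrapper arguments above wire variable-length batching into `gemm_sm90`: exactly one of `cu_seqlens_m` / `cu_seqlens_k` may be set, `A_idx` enables gather_A, and the stride asserts pin down the required majorness of A, B, and D. A PyTorch shape sketch for the varlen_m case; the batch sizes, dtypes, and shapes are illustrative assumptions:

import torch
import torch.nn.functional as F

# Three "batches" of 128, 64 and 256 rows packed into one A/D, as varlen_m expects.
seqlens_m = torch.tensor([128, 64, 256], dtype=torch.int32, device="cuda")
cu_seqlens_m = F.pad(torch.cumsum(seqlens_m, 0, dtype=torch.int32), (1, 0))
# cu_seqlens_m == [0, 128, 192, 448]

total_m, K, N, L = int(cu_seqlens_m[-1]), 512, 1024, len(seqlens_m)
A = torch.randn(total_m, K, device="cuda", dtype=torch.bfloat16)   # (total_m, k), k-major
B = torch.randn(L, N, K, device="cuda", dtype=torch.bfloat16)      # (l, n, k)
D = torch.empty(total_m, N, device="cuda", dtype=torch.bfloat16)   # (total_m, n), n-major
# A.stride(-1) == 1 and D.stride(-1) == 1, matching the varlen_m asserts above;
# these tensors are then passed to gemm_sm90(..., cu_seqlens_m=cu_seqlens_m).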
@@ -2190,10 +2466,25 @@ def gemm_sm90(
         assert isinstance(scalar, Tensor)
         return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
 
-    epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta))
+    epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta), add_to_output)
     scheduler_args = GemmWrapperBase.create_scheduler_args(
-        max_active_clusters,
+        max_active_clusters,
+        tile_count_semaphore,
+        batch_idx_permute,
+    )
+
+    # Create varlen arguments if needed (assumes persistent=True when varlen)
+    varlen_args = GemmWrapperBase.create_varlen_args(
+        cu_seqlens_m,
+        cu_seqlens_k,
+        A_idx,
+        max_active_clusters,
+        cluster_shape_mnk,
+        tensor_infos,
+        GemmSm90.num_epi_tensormaps,
+        pingpong,
     )
+
     current_stream = cutlass_torch.current_stream()
     compile_key = GemmWrapperBase.get_compile_key(
         tensor_infos,
@@ -2205,6 +2496,11 @@ def gemm_sm90(
         tile_count_semaphore is not None,
         2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
         2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
+        add_to_output,
+        cu_seqlens_m is not None,
+        cu_seqlens_k is not None,
+        gather_A,
+        batch_idx_permute is not None,
         key_tensor_names=("A", "B", "D", "C"),
     )
     cache = gemm_sm90.compile_cache
@@ -2216,6 +2512,7 @@ def gemm_sm90(
             cluster_shape_mnk,
             pingpong=pingpong,
             is_persistent=persistent,
+            gather_A=gather_A,
         )
         cache[compile_key] = cute.compile(
             gemm,
@@ -2225,8 +2522,7 @@ def gemm_sm90(
             tensor_infos["C"].cute_tensor,
             epi_args,
             scheduler_args,
-
-            None,  # mAIdx
+            varlen_args,
             current_stream,
         )
     cache[compile_key](
@@ -2236,8 +2532,7 @@ def gemm_sm90(
         tensor_infos["C"].cute_tensor,
         epi_args,
         scheduler_args,
-
-        None,
+        varlen_args,
         current_stream,
     )
 