quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
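The largest change in this release is the GEMM stack: quack/dense_gemm_sm90.py becomes quack/gemm_sm90.py (diffed below), dense_gemm_sm100.py becomes gemm_sm100.py, and new wrapper modules (gemm.py, gemm_act.py, gemm_dact.py, gemm_default_epi.py, gemm_symmetric.py) replace the per-variant *_sm90 kernels. For orientation, the sketch below simply repeats the usage example from the GemmSm90 docstring further down in this diff; a_tensor, b_tensor, c_tensor, and stream are placeholders for caller-supplied CuTe tensors and a CUDA stream, not names defined by the package.

from cutlass import Float32

# Minimal usage sketch mirroring the GemmSm90 docstring example in this diff.
# a_tensor/b_tensor/c_tensor and stream are assumed to be prepared by the caller
# (CuTe tensors on the GPU and a cuda.CUstream); they are not provided by quack.
gemm = GemmSm90(
    acc_dtype=Float32,
    tile_shape_mn=(128, 256),
    cluster_shape_mnk=(1, 1, 1),
)
gemm(a_tensor, b_tensor, c_tensor, stream)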
--- quack/dense_gemm_sm90.py
+++ quack/gemm_sm90.py
@@ -2,12 +2,10 @@
 # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py

 import enum
-from typing import Tuple, Type, Callable, Optional, Union
-from dataclasses import dataclass
+from typing import Tuple, Type, Callable, Optional, Union, Literal
 from functools import partial
 import math

-from torch import Tensor

 import cuda.bindings.driver as cuda

@@ -16,10 +14,9 @@ import cutlass.cute as cute
 import cutlass.pipeline as pipeline
 from cutlass.cute.nvgpu import cpasync, warp, warpgroup
 import cutlass.utils.hopper_helpers as sm90_utils
-from cutlass import Int32, Float32, Boolean, const_expr
+from cutlass import Int32, Float32, Float16, Boolean, const_expr
+from cutlass.cutlass_dsl import if_generate
 from cutlass.utils import LayoutEnum
-import cutlass.torch as cutlass_torch
-from cutlass.cute.runtime import make_ptr


 from quack.cute_dsl_utils import ParamsBase, ArgumentsBase
@@ -30,14 +27,12 @@ from quack.tile_scheduler import (
     VarlenMTileSchedulerArguments,
     VarlenMTileScheduler,
 )
-from quack.varlen_utils import VarlenArguments
-from quack.tensormap_manager import TensorMapManagerSm90
+from quack.varlen_utils import VarlenArguments, VarlenManager

 # return PipelineStateWAdvance instead of PipelineState
 from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
-import quack.
-
-from quack.gemm_wrapper_utils import GemmWrapperBase
+import quack.copy_utils as copy_utils
+import quack.sm90_utils as quack_sm90_utils

 """
 A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -86,12 +81,13 @@ class NamedBarrierGemm(enum.IntEnum):
     MmaWG1 = enum.auto()
     EpiWG0 = enum.auto()
     EpiWG1 = enum.auto()
+    TmemPtr = enum.auto()


 class GemmSm90:
     """
     This class implements batched matrix multiplication (C = A x B) with support for various data types
-    and architectural features specific to Hopper GPUs.
+    and architectural features specific to Hopper GPUs with persistent tile scheduling and warp specialization.

     :param acc_dtype: Data type for accumulation during computation
     :type acc_dtype: type[cutlass.Numeric]
@@ -118,24 +114,18 @@ class GemmSm90:

     Example:
     >>> gemm = GemmSm90(
-    ...     acc_dtype=
+    ...     acc_dtype=Float32,
     ...     tile_shape_mn=(128, 256),
     ...     cluster_shape_mnk=(1, 1, 1)
     ... )
     >>> gemm(a_tensor, b_tensor, c_tensor, stream)
     """

-
+    arch = 90
+    num_epi_tensormaps: int = 0

-
-
-        alpha: Optional[Float32 | cute.Tensor] = None
-        beta: Optional[Float32 | cute.Tensor] = None
-
-    @dataclass
-    class EpilogueParams(ParamsBase):
-        alpha: Optional[Float32 | cute.Tensor] = None
-        beta: Optional[Float32 | cute.Tensor] = None
+    EpilogueArguments = ArgumentsBase
+    EpilogueParams = ParamsBase

     def __init__(
         self,
@@ -174,8 +164,8 @@ class GemmSm90:

         self.cluster_shape_mnk = cluster_shape_mnk
         # K dimension is deferred in _setup_attributes
-        self.
-        tile_M, tile_N = self.
+        self.cta_tile_shape_mnk = (*tile_shape_mn, 1)
+        tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
         # check the cta tile shape
         if not self.pingpong:
             if tile_M not in [64, 128, 192, 256, 320]:
@@ -209,14 +199,18 @@ class GemmSm90:
                 else:
                     atom_layout_m, atom_layout_n = 1, 2
             else:
-                atom_layout_m =
+                atom_layout_m = (
+                    self.cta_tile_shape_mnk[0] // 64 if self.cta_tile_shape_mnk[0] < 256 else 2
+                )
                 atom_layout_n = 1
             assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
         else:
             atom_layout_m, atom_layout_n = 1, 1
         self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)

-        self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
+        self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
+        if self.gather_A:
+            assert self.num_mcast_ctas_a == 1
         self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
         self.is_a_mcast = self.num_mcast_ctas_a > 1
         self.is_b_mcast = self.num_mcast_ctas_b > 1
@@ -229,16 +223,13 @@ class GemmSm90:
         self.num_threads_per_warp_group = 128
         self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
         self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
-        self.
-            self.mma_warp_groups if not self.pingpong else 1
-        ) * self.num_threads_per_warp_group
+        self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
         self.num_ab_load_warps = 1 if not self.gather_A else 4
-        self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
-        self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
         self.ab_load_warp_id = self.mma_warp_groups * 4
-        self.
+        # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
+        # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps

-        regs_per_thread = math.prod(self.
+        regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
             math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
         )
         if self.fp8_slow_accum:
@@ -268,7 +259,7 @@ class GemmSm90:
         self.shared_storage = None
         self.buffer_align_bytes = 1024

-    def _setup_attributes(self, epilogue_args:
+    def _setup_attributes(self, epilogue_args: EpilogueArguments):
         """Set up configurations that are dependent on GEMM inputs

         This method configures various attributes based on the input tensor properties
@@ -289,7 +280,7 @@ class GemmSm90:
             self.b_layout.sm90_mma_major_mode(),
             self.acc_dtype,
             self.atom_layout_mnk,
-            tiler_mn=(64, self.
+            tiler_mn=(64, self.cta_tile_shape_mnk[1] // self.atom_layout_mnk[1]),
         )
         if const_expr(self.atom_layout_mnk[1] > 1):
             # If N dimension is split among 2 WGs, we need to permute the N dimension so
@@ -299,7 +290,7 @@ class GemmSm90:
             # WG1 would write to a separate epi smem of size (64, 16) that's far away.
             atom_n = self.atom_layout_mnk[1]
             permutation_n = cute.make_ordered_layout(
-                (8, self.
+                (8, self.cta_tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
             )
             self.tiled_mma = cute.make_tiled_mma(
                 cute.make_mma_atom(self.tiled_mma.op),
@@ -308,30 +299,30 @@ class GemmSm90:
             )
         mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
         mma_inst_tile_k = 4
-        self.
-            self.
-            self.
+        self.cta_tile_shape_mnk = (
+            self.cta_tile_shape_mnk[0],
+            self.cta_tile_shape_mnk[1],
             mma_inst_shape_k * mma_inst_tile_k,
         )

         self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)

         self.epi_tile = self._sm90_compute_tile_shape_or_override(
-            self.
+            self.cta_tile_shape_mnk,
             self.atom_layout_mnk,
             self.d_dtype,
         )

         # Compute stage before compute smem layout
         self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
-            self.
+            self.cta_tile_shape_mnk,
             self.epi_tile,
             self.a_dtype,
             self.b_dtype,
             self.d_dtype,
             self.c_dtype,
             epilogue_args,
-            self.smem_capacity
+            cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
             self.occupancy,
             # epi_smem will reuse smem ab if not persistent.
             overlap_sD_sA=not self.is_persistent,
@@ -344,7 +335,7 @@ class GemmSm90:
             self.epi_smem_layout_staged,
             self.epi_c_smem_layout_staged,
         ) = self._make_smem_layouts(
-            self.
+            self.cta_tile_shape_mnk,
             self.epi_tile,
             self.a_dtype,
             self.a_layout,
@@ -366,10 +357,9 @@ class GemmSm90:
         mB: cute.Tensor,
         mD: Optional[cute.Tensor],
         mC: Optional[cute.Tensor],
-        epilogue_args:
+        epilogue_args: ArgumentsBase,
         scheduler_args: TileSchedulerOptions,
         varlen_args: Optional[VarlenArguments],
-        mAIdx: Optional[cute.Tensor],
         stream: cuda.CUstream,
     ):
         """Execute the GEMM operation in steps:
@@ -405,7 +395,10 @@ class GemmSm90:
             raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
         if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
             raise TypeError("a_dtype should be float16 or float8")
-
+
+        if const_expr(varlen_args is None):
+            varlen_args = VarlenArguments()
+        assert (varlen_args.mAIdx is not None) == self.gather_A

         # Assume all strides are divisible by 128 bits except the last stride
         new_stride = lambda t: tuple(
@@ -421,77 +414,48 @@ class GemmSm90:

         self._setup_attributes(epilogue_args)

+        a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0))
+        b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0))
+        tma_atom_a, tma_tensor_a = None, None
         if const_expr(not self.gather_A):
             tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
                 mA,
-
-                (self.
+                a_smem_layout,
+                (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
                 self.cluster_shape_mnk[1],
             )
-        else:
-            tma_atom_a, tma_tensor_a = None, None
-
         tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
             mB,
-
-            (self.
+            b_smem_layout,
+            (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
             self.cluster_shape_mnk[0],
         )

+        self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+        if const_expr(not self.gather_A):
+            self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
+
+        tma_atom_d, tma_tensor_d = None, None
         if const_expr(mD is not None):
             tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
-                mD,
+                mD,
+                self.epi_smem_layout_staged,
+                self.epi_tile,
+                op_type="store"
+                if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
+                else "add",
             )
-
-        tma_atom_d, tma_tensor_d = None, None
-
+        tma_atom_c, tma_tensor_c = None, None
         if const_expr(mC is not None):
             tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
-                mC, self.epi_c_smem_layout_staged, self.epi_tile,
+                mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load"
             )
-        else:
-            tma_atom_c, tma_tensor_c = None, None

         epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
+        varlen_params = VarlenManager.to_underlying_arguments(varlen_args)

-
-
-        if const_expr(varlen_args.mCuSeqlensM is None):
-            num_problems = (
-                mD.shape[2]
-                if mD is not None
-                else (
-                    mB.shape[2]
-                    if varlen_args.mCuSeqlensK is None
-                    else varlen_args.mCuSeqlensK.shape[0] - 1
-                )
-            )
-            problem_shape_ntile_mnl = (
-                cute.ceil_div(mA.shape[0], self.tile_shape_mnk[0]),
-                cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
-                num_problems,
-            )
-            TileSchedulerCls = self.get_scheduler_class()
-            tile_sched_args = self.get_scheduler_arguments(problem_shape_ntile_mnl, scheduler_args)
-        else:
-            assert mD is not None or not self.gather_A
-            problem_shape_ntile_mnl = (
-                None,
-                cute.ceil_div(mB.shape[0], self.tile_shape_mnk[1]),
-                varlen_args.mCuSeqlensM.shape[0] - 1,
-            )
-            TileSchedulerCls = VarlenMTileScheduler
-            tile_sched_args = VarlenMTileSchedulerArguments(
-                problem_shape_ntile_mnl=problem_shape_ntile_mnl,
-                total_m=mD.shape[0] if mD is not None else mAIdx.shape[0],
-                cu_seqlens_m=varlen_args.mCuSeqlensM,
-                raster_order=scheduler_args.raster_order,
-                group_size=scheduler_args.max_swizzle_size,
-                tile_shape_mn=self.tile_shape_mnk[:2],
-                cluster_shape_mnk=self.cluster_shape_mnk,
-                tile_count_semaphore=scheduler_args.tile_count_semaphore,
-                is_persistent=self.is_persistent,
-            )
+        TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
+        tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
         tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
         grid = TileSchedulerCls.get_grid_shape(
             tile_sched_params, scheduler_args.max_active_clusters
@@ -507,7 +471,7 @@ class GemmSm90:
             ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
             epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
             sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
-            tile_count: cute.struct.MemRange[
+            tile_count: cute.struct.MemRange[Int32, self.sched_stage]
             sD: cute.struct.Align[
                 cute.struct.MemRange[
                     self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
@@ -534,6 +498,7 @@ class GemmSm90:

         # Launch the kernel synchronously
         self.kernel(
+            self.tiled_mma,
             tma_atom_a,
             tma_tensor_a if const_expr(not self.gather_A) else mA,
             tma_atom_b,
@@ -543,11 +508,7 @@ class GemmSm90:
             tma_atom_c,
             tma_tensor_c,
             epilogue_params,
-
-            varlen_args.mCuSeqlensM,
-            varlen_args.mCuSeqlensK,
-            varlen_args.mTensormaps,
-            self.tiled_mma,
+            varlen_params,
             self.cluster_layout_mnk,
             self.a_smem_layout_staged,
             self.b_smem_layout_staged,
@@ -559,7 +520,6 @@ class GemmSm90:
             grid=grid,
             block=[self.threads_per_cta, 1, 1],
             cluster=self.cluster_shape_mnk,
-            smem=self.shared_storage.size_in_bytes(),
             stream=stream,
             min_blocks_per_mp=1,
         )
@@ -569,6 +529,7 @@ class GemmSm90:
     @cute.kernel
     def kernel(
         self,
+        tiled_mma: cute.TiledMma,
         tma_atom_a: Optional[cute.CopyAtom],
         mA_mkl: cute.Tensor,
         tma_atom_b: cute.CopyAtom,
@@ -578,11 +539,7 @@ class GemmSm90:
         tma_atom_c: Optional[cute.CopyAtom],
         mC_mnl: Optional[cute.Tensor],
         epilogue_params: ParamsBase,
-
-        cu_seqlens_m: Optional[cute.Tensor],
-        cu_seqlens_k: Optional[cute.Tensor],
-        tensormaps: Optional[cute.Tensor],
-        tiled_mma: cute.TiledMma,
+        varlen_params: VarlenManager.Params,
         cluster_layout_mnk: cute.Layout,
         a_smem_layout: cute.ComposedLayout,
         b_smem_layout: cute.ComposedLayout,
@@ -618,9 +575,11 @@ class GemmSm90:
         :type epi_smem_layout: cute.ComposedLayout
         """

-        varlen_m = const_expr(cu_seqlens_m is not None)
-        varlen_k = const_expr(cu_seqlens_k is not None)
+        varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
+        varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
         assert not (varlen_m and varlen_k)
+        if const_expr(self.gather_A):
+            assert varlen_m or varlen_k
         has_D = const_expr(mD_mnl is not None)
         has_C = const_expr(mC_mnl is not None)

@@ -641,8 +600,6 @@ class GemmSm90:
         storage = smem.allocate(self.shared_storage)

         ab_pipeline = self.make_ab_pipeline(
-            a_smem_layout=cute.slice_(a_smem_layout, (None, None, 0)),
-            b_smem_layout=cute.slice_(b_smem_layout, (None, None, 0)),
             tiled_mma=tiled_mma,
             cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
             ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
@@ -681,28 +638,20 @@ class GemmSm90:
         sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
         epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            assert varlen_k
-            tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
-                tensormaps[tensormap_workspace_idx, 0, None].iterator
-            )
-            tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
-                tensormaps[tensormap_workspace_idx, 1, None].iterator
-            )
+        varlen_manager = VarlenManager.create(
+            varlen_params,
+            has_D,
+            self.num_epi_tensormaps,
+            # Only used if not varlen_m
+            len_m_static=Int32(
+                mA_mkl.shape[0]
+                if varlen_k or varlen_params.mAIdx is None
+                else varlen_params.mAIdx.shape[0]
+            ),
+            len_k_static=Int32(mA_mkl.shape[1]),
+            pingpong=self.pingpong,
+            warp_idx=warp_idx,
+        )

         TileSchedulerCls = partial(
             TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
@@ -715,28 +664,20 @@ class GemmSm90:
             and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
         ):
             is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
-
-
-
-
-                    tensormap_a_ptr,
-                    is_tma_warp,
-                )
-                tensormap_manager.init_tensormap_from_atom(
-                    tma_atom_b,
-                    tensormap_b_ptr,
-                    is_tma_warp,
-                )
+            # initialize tensormap for A & B
+            varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
+            tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
+            tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
             # ///////////////////////////////////////////////////////////////////////////////
             # Get mcast mask
             # ///////////////////////////////////////////////////////////////////////////////
             cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
-
+            block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
             a_mcast_mask = cute.make_layout_image_mask(
-                cluster_layout_mnk,
+                cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1
             )
             b_mcast_mask = cute.make_layout_image_mask(
-                cluster_layout_mnk,
+                cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0
             )
             a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
             b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
@@ -752,177 +693,129 @@ class GemmSm90:
             )
             if const_expr(varlen_k):
                 # wait tensormap initialization complete before update
-
-                # batch index of last tile
-                last_batch_idx = cutlass.Int32(-1)
+                varlen_manager.fence_tensormap_init()
             while work_tile.is_valid_tile:
                 tile_coord_mnkl = work_tile.tile_idx
                 batch_idx = tile_coord_mnkl[3]
-
-
-
-
-
-
-                        (tensormap_a_ptr, tensormap_b_ptr),
-                        is_manager_warp=is_tma_warp,
-                        shapes=(cu_seqlens_k[batch_idx + 1], cu_seqlens_k[batch_idx + 1]),
-                        orders=(
-                            0 if const_expr(self.a_layout == LayoutEnum.ROW_MAJOR) else 1,
-                            0 if const_expr(self.b_layout == LayoutEnum.ROW_MAJOR) else 1,
-                        ),
-                        tensormap_smem_ptr=None,
-                    )
+                varlen_manager.update_tensormap_AB(
+                    batch_idx,
+                    self.a_layout,
+                    self.b_layout,
+                    is_tma_warp,
+                )
                 # ///////////////////////////////////////////////////////////////////////////
                 # Local_tile partition global tensors
                 # ///////////////////////////////////////////////////////////////////////////
                 if const_expr(not self.gather_A):
-
-                        mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
-                    elif const_expr(varlen_k):
-                        mA_mk = cute.domain_offset((0, cu_seqlens_k[batch_idx]), mA_mkl)
-                    else:
-                        mA_mk = mA_mkl[None, None, batch_idx]
+                    mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
                     # (bM, bK, RestK)
-
+                    gA_mk = cute.local_tile(
                         mA_mk,
-                        cute.select(self.
+                        cute.select(self.cta_tile_shape_mnk, [0, 2]),
                         (tile_coord_mnkl[0], None),
                     )
                 else:
-
+                    mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
                     if const_expr(varlen_m):
-
-
-
+                        gAIdx = cute.local_tile(
+                            mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
+                        )
+                        # (M, K)
+                        mA_mk = mA_mkl
                     else:
-
-
-                            mAIdx_mk, (self.
-
-
-
-
-                mB_nk = mB_nkl[None, None, batch_idx]
+                        assert varlen_k
+                        # (tile_K, RestK)
+                        gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
+                        # (tile_M, K)
+                        mA_mk = cute.local_tile(
+                            mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
+                        )
                 # (bN, bK, RestK)
-
-
+                gB_nk = cute.local_tile(
+                    varlen_manager.offset_batch_B(mB_nkl, batch_idx),
+                    cute.select(self.cta_tile_shape_mnk, [1, 2]),
+                    (tile_coord_mnkl[1], None),
                 )
                 # //////////////////////////////////////////////////////////////////////////
                 # Partition shared tensor for TMA load A/B
                 # //////////////////////////////////////////////////////////////////////////
-
-
-
-                    tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
-                    tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
-                    tma_desc_a_ptr = tensormap_manager.get_tensormap_ptr(
-                        tensormap_a_ptr, cute.AddressSpace.generic
-                    )
-                    tma_desc_b_ptr = tensormap_manager.get_tensormap_ptr(
-                        tensormap_b_ptr, cute.AddressSpace.generic
-                    )
-                else:
-                    tma_desc_a_ptr, tma_desc_b_ptr = None, None
+                varlen_manager.fence_tensormap_update_AB(is_tma_warp)
+                len_m = varlen_manager.len_m(batch_idx)
+                len_k = varlen_manager.len_k(batch_idx)
                 # TMA load A partition_S/D
-
-                    cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
-                )
-                a_cta_crd = cluster_coord_mnk[1]
+                copy_A = None
                 if const_expr(not self.gather_A):
-
-                    tAsA, tAgA_k = cpasync.tma_partition(
-                        tma_atom_a,
-                        a_cta_crd,
-                        a_cta_layout,
-                        cute.group_modes(sA, 0, 2),
-                        cute.group_modes(gA_k, 0, 2),
-                    )
-                    copy_A = partial(
-                        cute.copy,
+                    copy_A, _, _ = copy_utils.tma_get_copy_fn(
                         tma_atom_a,
+                        cta_coord=block_in_cluster_coord_mnk[1],
+                        cta_layout=cute.make_layout(
+                            cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
+                        ),
+                        src_tensor=gA_mk,
+                        dst_tensor=sA,
                         mcast_mask=a_mcast_mask,
                         tma_desc_ptr=tma_desc_a_ptr,
                     )
                 else:
                     tiled_copy_A = self._make_gmem_tiled_copy_A(
-                        mA_mkl.element_type, self.a_layout, self.
+                        mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
                     )
                     tidx = (
-                        cute.arch.thread_idx()[0]
-                        - self.mma_warp_groups * self.num_threads_per_warp_group
+                        cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
                     )
                     thr_copy_A = tiled_copy_A.get_slice(tidx)
-
-
-
-
-
+                    copy_A, prefetch_A = None, None
+                    if const_expr(varlen_m):
+                        copy_A = copy_utils.gather_m_get_copy_fn(
+                            thr_copy_A,
+                            mA_mk,
+                            sA,
+                            gAIdx,
+                            limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+                            limit_k=len_k,
+                        )
+                    else:
+                        copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
+                            thr_copy_A,
+                            mA_mk,
+                            sA,
+                            gAIdx,
+                            limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
+                            limit_k=len_k,
+                        )
                 # TMA load B partition_S/D
-
-                    cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
-                )
-                b_cta_crd = cluster_coord_mnk[0]
-                # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
-                tBsB, tBgB_k = cpasync.tma_partition(
+                copy_B, _, _ = copy_utils.tma_get_copy_fn(
                     tma_atom_b,
-
-
-
-
-
-
-
+                    cta_coord=block_in_cluster_coord_mnk[0],
+                    cta_layout=cute.make_layout(
+                        cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
+                    ),
+                    src_tensor=gB_nk,
+                    dst_tensor=sB,
+                    mcast_mask=b_mcast_mask,
+                    tma_desc_ptr=tma_desc_b_ptr,
                 )
-
-                    cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
-                    if const_expr(varlen_k)
-                    else mA_mkl.shape[1]
-                )
-                k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+                k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                 if const_expr(not self.gather_A):
                     ab_producer_state = self.load_AB(
-                        ab_pipeline,
-                        ab_producer_state,
-                        copy_A,
-                        tAgA_k,
-                        tAsA,
-                        copy_B,
-                        tBgB_k,
-                        tBsB,
-                        k_tile_cnt,
+                        ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
                     )
                 else:
-                    limit_m = (
-                        mAIdx.shape[0]
-                        if const_expr(cu_seqlens_m is None)
-                        else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
-                    )
                     ab_producer_state = self.load_AB_gather_A(
                         ab_pipeline,
                         ab_producer_state,
-
-
-                        tAsA,
-                        gAIdx,
+                        copy_A,
+                        prefetch_A,
                         copy_B,
-                        tBgB_k,
-                        tBsB,
                         k_tile_cnt,
-
-                        limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
-                        mA_mk.shape[1],
-                        ),
+                        varlen_m=varlen_m,
                     )
                 tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
-                tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
                 tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
                 work_tile = tile_scheduler.get_current_work()
             # End of persistent scheduler loop
             if const_expr(self.pingpong and not varlen_k):
                 # Need to write the tile_idx to smem for the next WG in the pingpong mode
-                # tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
-                tile_scheduler.broadcast_next_work(is_scheduler_warp=is_scheduler_warp)
                 tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
             ab_pipeline.producer_tail(ab_producer_state)
             if is_scheduler_warp:
@@ -934,13 +827,11 @@ class GemmSm90:
                 (not self.pingpong and warp_idx == 0)
                 or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
             )
-
-
-
-
-
-                is_manager_warp=is_tma_warp,
-            )
+            varlen_manager.init_tensormap_epi(
+                tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
+            )
+            tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
+            tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
             # //////////////////////////////////////////////////////////////////////////////
             # Partition global tensor for TiledMMA_A/B/C
             # //////////////////////////////////////////////////////////////////////////////
@@ -962,7 +853,9 @@ class GemmSm90:
             tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
             tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))

-            acc_shape = tiled_mma.partition_shape_C(
+            acc_shape = tiled_mma.partition_shape_C(
+                cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
+            )
             acc = cute.make_fragment(acc_shape, self.acc_dtype)
             acc_slow = None
             if const_expr(self.fp8_slow_accum):
@@ -974,10 +867,11 @@ class GemmSm90:
                 self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
                 self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")

-            k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.
-            c_tile_cnt = cute.size(cute.ceil_div(self.
+            k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2])
+            c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))

             ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
+            epi_store_pipeline = self.make_epi_store_pipeline()
             epi_read_state = make_pipeline_state(
                 pipeline.PipelineUserType.Consumer, self.epi_c_stage
             )
@@ -996,9 +890,8 @@ class GemmSm90:
                 if const_expr(not varlen_k):
                     ab_read_state.advance_iters(k_tile_cnt_static)
                 else:
-
-
-                    k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+                    len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
+                    k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                     ab_read_state.advance_iters(k_tile_cnt)
                 tile_scheduler.advance_to_next_work()
                 if const_expr(varlen_k):
@@ -1009,31 +902,22 @@ class GemmSm90:
             work_tile = tile_scheduler.initial_work_tile_info()
             if const_expr(varlen_m):
                 # wait tensormap initialization complete before update
-
-                # batch index of last tile
-                last_batch_idx = cutlass.Int32(-1)
+                varlen_manager.fence_tensormap_init()
             while work_tile.is_valid_tile:
                 tile_coord_mnkl = work_tile.tile_idx
                 batch_idx = tile_coord_mnkl[3]
-
-
-
-
-
-
-
-
-
-                        orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
-                        tensormap_smem_ptr=None,
-                    )
-
-                k_len = (
-                    cu_seqlens_k[batch_idx + 1] - cu_seqlens_k[batch_idx]
-                    if const_expr(varlen_k)
-                    else mA_mkl.shape[1]
+                epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
+                    epilogue_params, varlen_params.cu_seqlens_m, batch_idx
+                )
+                varlen_manager.update_tensormap_epi(
+                    batch_idx,
+                    self.d_layout,
+                    epi_shapes,
+                    epi_orders,
+                    is_tma_warp,
                 )
-
+                len_k = varlen_manager.len_k(batch_idx)
+                k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                 ab_read_state, tiled_mma = self.mma(
                     ab_pipeline,
                     ab_read_state,
@@ -1056,51 +940,42 @@ class GemmSm90:
                 self.pingpong_barrier_sync(warp_group_idx, "epi")

             epilogue_barrier = pipeline.NamedBarrier(
-                barrier_id=int(NamedBarrierGemm.Epilogue),
+                barrier_id=int(NamedBarrierGemm.Epilogue),
+                num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
             )

-
-            # ensure the update to tensormap has completed before using it
-            if is_group_changed and is_tma_warp:
-                tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
-                tma_desc_d_ptr = tensormap_manager.get_tensormap_ptr(
-                    tensormap_d_ptr, cute.AddressSpace.generic
-                )
-            else:
-                tma_desc_d_ptr = None
+            varlen_manager.fence_tensormap_update_epi(is_tma_warp)

+            copy_D = None
             if const_expr(has_D):
-
+                copy_D, _, _ = self.epilog_gmem_copy_and_partition(
                     tma_atom_d,
-                    mD_mnl,
-                    self.
+                    varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
+                    self.cta_tile_shape_mnk[:2],
                     self.epi_tile,
                     sD,
                     tile_coord_mnkl,
-
+                    tma_desc_ptr=tma_desc_d_ptr,
                 )
-
-            else:
-                bSG_sD, bSG_gD, copy_D = None, None, None
+            copy_C = None
             if const_expr(has_C):
-
+                copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
                     tma_atom_c,
-                    mC_mnl,
-                    self.
+                    varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
+                    self.cta_tile_shape_mnk[:2],
                     self.epi_tile,
                     sC,
                     tile_coord_mnkl,
-                    cu_seqlens_m,
                 )
-                copy_C =
-                epi_load_g2s = partial(self.epi_load_g2s, epi_pipeline, copy_C, bGS_gC, bGS_sC)
-            else:
-                epi_load_g2s = None
+                copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)

             d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
-            tiled_copy_r2s,
-                tiled_mma, self.d_layout, d_dtype_for_layout,
+            tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
+                tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
             )
+            # (R2S, R2S_M, R2S_N)
+            tRS_rAcc = tiled_copy_r2s.retile(acc)
+            load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
             if const_expr(has_C):
                 tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
                     tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
@@ -1118,24 +993,25 @@ class GemmSm90:
             epi_read_state, epi_producer_state = self.epilogue(
                 epilogue_params,
                 epi_smem_tensors,
+                tma_desc_epi_ptrs,
                 epi_pipeline,
+                epi_store_pipeline,
                 epi_read_state,
                 epi_producer_state,
-
-
+                self.epi_tile,
+                load_acc_subtile,
                 tRS_rD,
                 tRS_rC,
+                None, # tiled_copy_t2r, for Sm100 only
                 tiled_copy_r2s,
                 tRS_sD,
                 tiled_copy_s2r,
                 tSR_rC,
                 tSR_sC,
                 copy_D,
-
-                bSG_gD,
-                epi_load_g2s,
+                copy_C,
                 tile_coord_mnkl,
-
+                varlen_manager,
                 epilogue_barrier,
                 tile_scheduler,
                 tidx,
@@ -1147,7 +1023,7 @@ class GemmSm90:
             # so we have to make sure the smem content is done reading before signaling
             # the next WG's epilogue.
             if is_tma_warp:
-
+                epi_store_pipeline.producer_tail()
             self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")

             if const_expr(not self.pingpong):
@@ -1166,31 +1042,33 @@ class GemmSm90:
                     tile_scheduler.advance_to_next_work()
                     work_tile = tile_scheduler.get_current_work()
                     if work_tile.is_valid_tile:
-
-
-                        k_tile_cnt = cute.ceil_div(k_len, self.tile_shape_mnk[2])
+                        len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
+                        k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
                         ab_read_state.advance_iters(k_tile_cnt)
                         tile_scheduler.advance_to_next_work()
                         work_tile = tile_scheduler.get_current_work()
             # End of persistent scheduler loop

+        # Wait for D store complete
         if const_expr(not self.pingpong):
             if is_tma_warp:
-
+                epi_store_pipeline.producer_tail()

     @cute.jit
     def load_AB(
         self,
         ab_pipeline: cutlass.pipeline.PipelineAsync,
         ab_producer_state: cutlass.pipeline.PipelineState,
-        copy_A: Callable,
-        tAgA: cute.Tensor,
-        tAsA: cute.Tensor,
+        copy_A: Optional[Callable],
         copy_B: Callable,
-        tBgB: cute.Tensor,
-        tBsB: cute.Tensor,
         k_tile_cnt: Int32,
+        # These are for Sm100 blockscaled gemm
+        copy_SFA: Optional[Callable] = None,
+        copy_SFB: Optional[Callable] = None,
     ) -> cutlass.pipeline.PipelineState:
+        blockscaled = const_expr(copy_SFA is not None)
+        if const_expr(blockscaled):
+            assert copy_SFB is not None
         # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
         peek_ab_empty_status = Boolean(True)
         if 0 < k_tile_cnt:
@@ -1203,8 +1081,13 @@ class GemmSm90:
             # Also sets the transaction barrier for the A/B buffers
             ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
             tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
-
-
+            smem_idx = ab_producer_state.index
+            if const_expr(copy_A is not None):
+                copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+            copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+            if const_expr(blockscaled):
+                copy_SFA(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+                copy_SFB(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
             # Mainloop pipeline's producer commit is a NOP
             ab_pipeline.producer_commit(ab_producer_state)
             ab_producer_state.advance()
@@ -1218,38 +1101,12 @@ class GemmSm90:
         self,
         ab_pipeline: cutlass.pipeline.PipelineAsync,
         ab_producer_state: cutlass.pipeline.PipelineState,
-
-
-        tAsA: cute.Tensor,
-        gAIdx: cute.Tensor,
+        copy_A: Callable,
+        prefetch_A: Optional[Callable],
         copy_B: Callable,
-        tBgB: cute.Tensor,
-        tBsB: cute.Tensor,
         k_tile_cnt: Int32,
-
+        varlen_m: bool = True,
     ) -> cutlass.pipeline.PipelineState:
-        # (atom_v, CPY_M, 1, RestK)
-        limit_m, limit_k = limit_A
-        limit_m = min(limit_m, self.tile_shape_mnk[0]) # To avoid writing beyond smem limit
-        cA = cute.make_identity_tensor(cute.select(self.tile_shape_mnk, [0, 2]))
-        tAcA = thr_copy_A.partition_S(cA)
-        t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
-        # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
-        # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
-        # This is so that when we do the comparison, t0AcA is known at compile time.
-        limit_m = limit_m - tAcA[0][0]
-        # Read indices for A
-        rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
-        m_idx = cute.make_fragment(rows_per_thread, Int32)
-        for m in cutlass.range(rows_per_thread):
-            row_idx = tAcA[0, m, 0][0]
-            if t0AcA[0, m, 0][0] < limit_m:
-                m_idx[m] = gAIdx[row_idx]
-            else:
-                m_idx[m] = -1
-        elems_per_load = cute.size(tAsA.shape[0][0])
-        # (m, (bK, RestK))
-        mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
         warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
         # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
         peek_ab_empty_status = Boolean(True)
@@ -1258,35 +1115,27 @@ class GemmSm90:
         # /////////////////////////////////////////////////////////////////////////
         # TMA load on B and cp.async on A
         # /////////////////////////////////////////////////////////////////////////
-        copy_A = partial(cute.copy, thr_copy_A)
         for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
+            prefetch_out = ()
+            if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+                prefetch_out = (prefetch_A(k_tile),)
             # Wait for A/B buffers to be empty before loading into them
             # Also sets the transaction barrier for the A/B buffers
-
-
-
-
-
+            # A tiny bit faster to rotate the warp that does TMA
+            # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id
+            # since that's the warp that does the tensormap update.
+            is_tma_warp = warp_idx == self.ab_load_warp_id + (
+                (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
             )
-
-
-
-
-
-
-
-            # (m, bK)
-            mA_cur = mA_k[None, (None, k_tile)]
-            for m in cutlass.range_constexpr(tAcA.shape[1]):
-                # (elems_per_load, thread_per_row)
-                mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
-                if t0AcA[0, m, 0][0] < limit_m:
-                    # There's only 1 load per row
-                    assert cute.size(tAcA.shape, mode=[2]) == 1
-                    ki = tAcA[0, 0, 0][1] // elems_per_load
-                    copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
+            ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+            smem_idx = ab_producer_state.index
+            # A bit faster to load B first while we calculate the indices for A
+            if is_tma_warp:
+                tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
+                copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+            copy_A(k_tile, smem_idx, *prefetch_out)
             # This tells mbarrier to track the completion of cp.async
-            ab_pipeline.
+            ab_pipeline.producer_cpasync_commit(ab_producer_state)
             ab_producer_state.advance()
             peek_ab_empty_status = Boolean(True)
             if k_tile + 1 < k_tile_cnt:
@@ -1294,33 +1143,19 @@ class GemmSm90:
         # bound checking in the K dimension on the last k_tile
         if 0 < k_tile_cnt:
             k_tile = k_tile_cnt - 1
-
-
-
-
+            prefetch_out = ()
+            if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+                prefetch_out = (prefetch_A(k_tile, pred=True),)
+            is_tma_warp = warp_idx == self.ab_load_warp_id + (
+                (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
             )
-
-
-
-
-
-
-
-            tApA = cute.make_fragment(1, Boolean)
-            tApA[0] = tAcA[0, 0, 0][1] < limit_k
-            # (m, bK)
-            mA_cur = mA_k[None, (None, k_tile)]
-            for m in cutlass.range_constexpr(tAcA.shape[1]):
-                # (elems_per_load, thread_per_row)
-                mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
-                if t0AcA[0, m, 0][0] < limit_m:
-                    # There's only 1 load per row
-                    assert cute.size(tAcA.shape, mode=[2]) == 1
-                    ki = tAcA[0, 0, 0][1] // elems_per_load
-                    # copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA)
-                    # TODO
-                    copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
-            ab_pipeline.producer_commit(ab_producer_state)
+            ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+            smem_idx = ab_producer_state.index
+            if is_tma_warp:
+                tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
+                copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
+            copy_A(k_tile, smem_idx, *prefetch_out, pred=True)
+            ab_pipeline.producer_cpasync_commit(ab_producer_state)
             ab_producer_state.advance()
         return ab_producer_state

@@ -1416,24 +1251,25 @@ class GemmSm90:
         self,
         params: EpilogueParams,
         epi_smem_tensors: Tuple[cute.Tensor, ...],
+        tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
         epi_pipeline: cutlass.pipeline.PipelineAsync,
+        epi_store_pipeline: cutlass.pipeline.PipelineAsync,
         epi_read_state: cutlass.pipeline.PipelineState,
-        epi_producer_state: cutlass.pipeline.PipelineState,
-
-
+        epi_producer_state: Optional[cutlass.pipeline.PipelineState],
+        epi_tile: cute.Tile,
+        load_acc_subtile: Callable,
         tRS_rD: cute.Tensor,
         tRS_rC: Optional[cute.Tensor],
-
+        tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
+        tiled_copy_r2s: cute.TiledCopy,
         tRS_sD: cute.Tensor,
-        tiled_copy_s2r: Optional[cute.
+        tiled_copy_s2r: Optional[cute.ThrCopy],
         tSR_rC: Optional[cute.Tensor],
         tSR_sC: Optional[cute.Tensor],
         copy_D: Optional[Callable],
-
-        bSG_gD: cute.Tensor,
-        epi_load_g2s: Optional[Callable],
+        copy_C: Optional[Callable],
         tile_coord_mnkl: cute.Coord,
-
+        varlen_manager: VarlenManager,
         epilogue_barrier: cutlass.pipeline.NamedBarrier,
         tile_scheduler,
         tidx: Int32,
@@ -1441,22 +1277,61 @@ class GemmSm90:
     ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
         has_C = const_expr(tRS_rC is not None)
         has_D = const_expr(copy_D is not None)
-        # We iterate over epi tiles in the N dimension first before the M dimension
         epi_tile_shape = cute.zipped_divide(
-            cute.make_layout(self.
+            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
         ).shape[1]
-
+        # We iterate over epi tiles in the N dimension first before the M dimension
+        epi_tile_layout = cute.make_ordered_layout(epi_tile_shape, order=(1, 0))
         epi_tile_num = cute.size(epi_tile_shape)
         num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

-
+        epi_tensors = self.epi_begin(
+            params,
+            epi_smem_tensors,
+            epi_tile,
+            tiled_copy_t2r,
+            tiled_copy_r2s,
+            tile_coord_mnkl,
+            varlen_manager,
+            epilogue_barrier,
+            tidx,
+        )
+
+        if const_expr(copy_C is not None):
             for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
-
+                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
+                if is_tma_warp:
+                    epi_pipeline.producer_acquire(epi_producer_state)
+                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                    epi_pipeline.producer_commit(epi_producer_state)
+                epi_producer_state.advance()
+
+        def tma_store_fn(src_idx, dst_idx):
+            # Fence and barrier to make sure shared memory store is visible to TMA store
+            cute.arch.fence_proxy(
+                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+            )
+            epilogue_barrier.arrive_and_wait()
+            # Copy from shared memory to global memory
+            if is_tma_warp:
+                if const_expr(has_D):
+                    copy_D(src_idx=src_idx, dst_idx=dst_idx)
+            # Can't use if statement here, epi_store_pipeline object isn't captured somehow
+            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
+            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
+            epilogue_barrier.arrive_and_wait()
+
+        # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops
+        # with the TMA store. However, currently this doesn't seem to improve perf.
+        delay_tma_store = False

+        src_idx_prev, dst_idx_prev = None, None
         for epi_idx in cutlass.range_constexpr(epi_tile_num):
+            # The global memory coordinate for the current epi tile
+            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
             # Copy from acc to D registers
-
-
+            load_acc_subtile(tRS_rD, epi_idx)
+            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
             if const_expr(has_C):
                 epi_pipeline.consumer_wait(epi_read_state)
                 cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
@@ -1468,98 +1343,132 @@ class GemmSm90:
                 with cute.arch.elect_one():
                     epi_pipeline.consumer_release(epi_read_state)
                 epi_read_state.advance()
-            if const_expr(
-
-
-
-
+            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
+                if is_tma_warp:
+                    epi_pipeline.producer_acquire(epi_producer_state)
+                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                    epi_pipeline.producer_commit(epi_producer_state)
+                epi_producer_state.advance()
+            tRS_rEpi = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
             epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+            if const_expr(delay_tma_store):
+                if const_expr(epi_idx > 0):
+                    tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
+                src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
             # Copy from D registers to shared memory
             if const_expr(has_D):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            epilogue_barrier.arrive_and_wait()
+                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
+            if const_expr(not delay_tma_store):
+                tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
+
+        if const_expr(delay_tma_store):
+            tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
+
+        self.epi_end(
+            params,
+            epi_tensors,
+            epi_tile,
+            tiled_copy_t2r,
+            tiled_copy_r2s,
+            tile_coord_mnkl,
+            varlen_manager,
+            tidx,
+        )

         return epi_read_state, epi_producer_state

-
-
+    def get_scheduler_class(self, varlen_m: bool = False):
+        """Return the scheduler class to use. Override in subclasses for custom schedulers."""
+        return TileScheduler if not varlen_m else VarlenMTileScheduler
+
+    def get_scheduler_arguments(
         self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        mA: cute.Tensor,
+        mB: cute.Tensor,
+        mD: Optional[cute.Tensor],
+        scheduler_args,
+        varlen_args,
+    ):
+        """Create scheduler arguments. Override in subclasses for custom schedulers."""
+        if const_expr(varlen_args.mCuSeqlensM is None):
+            num_problems = (
+                mD.shape[2]
+                if mD is not None
+                else (
+                    mB.shape[2]
+                    if varlen_args.mCuSeqlensK is None
+                    else varlen_args.mCuSeqlensK.shape[0] - 1
+                )
+            )
+            problem_shape_ntile_mnl = (
+                cute.ceil_div(mA.shape[0], self.cta_tile_shape_mnk[0]),
+                cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
+                num_problems,
+            )
+            tile_sched_args = TileSchedulerArguments(
+                problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+                raster_order=scheduler_args.raster_order,
+                group_size=scheduler_args.max_swizzle_size,
+                cluster_shape_mnk=self.cluster_shape_mnk,
+                tile_count_semaphore=scheduler_args.tile_count_semaphore,
+                batch_idx_permute=scheduler_args.batch_idx_permute,
+                is_persistent=self.is_persistent,
             )
-
-
-
-
+        else:
+            assert mD is not None or not self.gather_A
+            problem_shape_ntile_mnl = (
+                None,
+                cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
+                varlen_args.mCuSeqlensM.shape[0] - 1,
+            )
+            tile_sched_args = VarlenMTileSchedulerArguments(
+                problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+                total_m=mD.shape[0] if mD is not None else varlen_args.mAIdx.shape[0],
+                cu_seqlens_m=varlen_args.mCuSeqlensM,
+                raster_order=scheduler_args.raster_order,
+                group_size=scheduler_args.max_swizzle_size,
+                tile_shape_mn=self.cta_tile_shape_mnk[:2],
+                cluster_shape_mnk=self.cluster_shape_mnk,
+                tile_count_semaphore=scheduler_args.tile_count_semaphore,
+                is_persistent=self.is_persistent,
+            )
+        return tile_sched_args
+
+    @cute.jit
+    def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
+        for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
+            tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
+
+    @cute.jit
+    def epi_begin(
+        self,
+        params: EpilogueParams,
+        epi_smem_tensors: Tuple[cute.Tensor, ...],
+        epi_tile: cute.Tile,
+        tiled_copy_t2r: Optional[cute.TiledCopy],
+        tiled_copy_r2s: cute.TiledCopy,
+        tile_coord_mnkl: cute.Coord,
+        varlen_manager: VarlenManager,
+        epilogue_barrier: cutlass.pipeline.NamedBarrier,
+        tidx: Int32,
+    ) -> Tuple[cute.Tensor, ...]:
+        return ()

-    def
+    def epi_begin_loop(
+        self, params: EpilogueParams, epi_tensors: Tuple[cute.Tensor, ...], epi_coord: cute.Coord
+    ) -> Tuple[cute.Tensor, ...]:
+        return ()
+
+    def epi_visit_subtile(
         self,
         params: EpilogueParams,
+        epi_loop_tensors: Tuple[cute.Tensor, ...],
         tRS_rD: cute.Tensor,
         tRS_rC: Optional[cute.Tensor] = None,
     ) -> Optional[cute.Tensor]:
-        # Apply alpha scaling to accumulator if alpha is provided (not None)
-        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
-            alpha = utils.load_scalar_or_pointer(params.alpha)
-            tRS_rD.store(tRS_rD.load() * alpha)
-        # Apply C with beta scaling
-        if const_expr(tRS_rC is not None):
-            if const_expr(not hasattr(params, "beta") or params.beta is None):
-                # beta is None, default behavior: add C (beta=1.0)
-                tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
-            else:
-                beta = utils.load_scalar_or_pointer(params.beta)
-                tRS_rD.store(tRS_rD.load() + beta * tRS_rC.load().to(tRS_rD.element_type))
         return None

-    def get_scheduler_class(self):
-        """Return the scheduler class to use. Override in subclasses for custom schedulers."""
-        return TileScheduler
-
-    def get_scheduler_arguments(self, problem_shape_ntile_mnl, scheduler_args):
-        """Create scheduler arguments. Override in subclasses for custom schedulers."""
-        return TileSchedulerArguments(
-            problem_shape_ntile_mnl=problem_shape_ntile_mnl,
-            raster_order=scheduler_args.raster_order,
-            group_size=scheduler_args.max_swizzle_size,
-            cluster_shape_mnk=self.cluster_shape_mnk,
-            tile_count_semaphore=scheduler_args.tile_count_semaphore,
-            batch_idx_permute=scheduler_args.batch_idx_permute,
-            is_persistent=self.is_persistent,
-        )
-
     def epi_visit_acc(
         self,
         params: EpilogueParams,
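The reworked get_scheduler_arguments above sizes the scheduler's grid by rounding each problem dimension up to whole CTA tiles. The same arithmetic as a tiny standalone Python sketch (ceil_div and the tuple name mirror the diff, but this is not the library API):

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    def tile_grid(m: int, n: int, num_problems: int, cta_tile_mnk=(128, 192, 64)):
        tile_m, tile_n, _ = cta_tile_mnk
        # problem_shape_ntile_mnl: number of CTA tiles along M and N, plus the batch count
        return (ceil_div(m, tile_m), ceil_div(n, tile_n), num_problems)

    # e.g. M = N = 4096 with a 128x192 tile and 2 batches -> a (32, 22, 2) tile grid
    assert tile_grid(4096, 4096, 2) == (32, 22, 2)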
@@ -1570,26 +1479,58 @@ class GemmSm90:
     ) -> None:
         pass

+    @cute.jit
+    def epi_end(
+        self,
+        params: EpilogueParams,
+        epi_tensors: Tuple[cute.Tensor, ...],
+        epi_tile: cute.Tile,
+        tiled_copy_t2r: Optional[cute.TiledCopy],
+        tiled_copy_r2s: cute.TiledCopy,
+        tile_coord_mnkl: cute.Coord,
+        varlen_manager,
+        tidx,
+    ) -> None:
+        pass
+
     def epi_to_underlying_arguments(
         self, args: EpilogueArguments, *, loc=None, ip=None
     ) -> EpilogueParams:
-        return
+        return self.EpilogueParams()
+
+    def epi_get_tma_atoms(
+        self, params: EpilogueParams, *, loc=None, ip=None
+    ) -> list[cute.CopyAtom]:
+        """Subclasses can override this"""
+        return []
+
+    def epi_get_tensormap_update_shapes_orders(
+        self,
+        params: EpilogueParams,
+        cu_seqlens_m: cute.Tensor,
+        batch_idx: Int32,
+        *,
+        loc=None,
+        ip=None,
+    ) -> tuple[list[Int32], list[int]]:
+        """Subclasses can override this"""
+        return [], []

     @staticmethod
     def epi_smem_bytes_per_stage(
         args: Optional[EpilogueArguments],
-
-        epi_tile:
+        cta_tile_shape_mnk: Tuple[int, int, int],
+        epi_tile: cute.Tile,
     ) -> int:
         return 0

     def epi_get_smem_struct(self, params: EpilogueParams):
-        return cute.struct.MemRange[
+        return cute.struct.MemRange[Int32, 0] # Dummy struct

     def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
         return tuple()

-    def pingpong_barrier_sync(self, warp_group_idx: Int32, stage:
+    def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
         assert stage in ["mma", "epi"]
         barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
         cute.arch.barrier(
@@ -1597,7 +1538,7 @@ class GemmSm90:
             number_of_threads=2 * self.num_threads_per_warp_group,
         )

-    def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage:
+    def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
         assert stage in ["mma", "epi"]
         barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
         cute.arch.barrier_arrive(
@@ -1611,7 +1552,7 @@ class GemmSm90:
                 self.d_layout.is_m_major_c() if self.d_layout is not None else False,
                 num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
             ),
-
+            Float16, # this is just to get the right source layout
         )
         tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
         return tiled_copy_C_atom
@@ -1621,8 +1562,7 @@ class GemmSm90:
         tiled_mma: cute.TiledMma,
         d_layout: Optional[LayoutEnum],
         dtype: Type[cutlass.Numeric],
-
-        sD: cute.Tensor,
+        sD: Optional[cute.Tensor],
         tidx: Int32,
     ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
         if d_layout is None:
@@ -1637,12 +1577,10 @@ class GemmSm90:
         # (R2S, R2S_M, R2S_N, PIPE_D)
         thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
         tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
-        # (R2S, R2S_M, R2S_N)
-        tRS_rAcc = tiled_copy_r2s.retile(acc)
         sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
         tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
         tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
-        return tiled_copy_r2s,
+        return tiled_copy_r2s, tRS_rD, tRS_sD

     def epilog_smem_load_and_partition(
         self,
@@ -1654,7 +1592,7 @@ class GemmSm90:
         tidx: Int32,
     ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
         tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
-        copy_atom_s2r =
+        copy_atom_s2r = copy_utils.sm90_get_smem_load_op(c_layout, dtype)
         tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
         thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
         tSR_sC = thr_copy_s2r.partition_S(sC)
@@ -1665,57 +1603,53 @@ class GemmSm90:
     def epilog_gmem_copy_and_partition(
         self,
         atom: Union[cute.CopyAtom, cute.TiledCopy],
-
+        mD_mn: cute.Tensor,
         tile_shape_mn: cute.Tile,
         epi_tile: cute.Tile,
         sD: cute.Tensor,
         tile_coord_mnkl: cute.Coord,
-
+        tma_desc_ptr: Optional[cute.Pointer] = None,
     ) -> Tuple[cute.Tensor, cute.Tensor]:
-        batch_idx = tile_coord_mnkl[3]
-        if const_expr(cu_seqlens_m is not None):
-            mD_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl)
-        else:
-            mD_mn = mD_mnl[None, None, batch_idx]
         # (bM, bN)
         gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
         tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
-
+        is_s2g = isinstance(
+            atom.op, (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp)
+        )
+        src_tensor, dst_tensor = (
+            (sD, tDgD_for_tma_partition) if is_s2g else (tDgD_for_tma_partition, sD)
+        )
+        return copy_utils.tma_get_copy_fn(
             atom,
-            0,
-            cute.make_layout(1),
-
-
+            cta_coord=0,
+            cta_layout=cute.make_layout(1),
+            src_tensor=src_tensor,
+            dst_tensor=dst_tensor,
+            tma_desc_ptr=tma_desc_ptr,
         )
-        return bSG_sD, bSG_gD

     def make_ab_pipeline(
         self,
-        a_smem_layout: cute.Layout | cute.ComposedLayout,
-        b_smem_layout: cute.Layout | cute.ComposedLayout,
         tiled_mma: cute.TiledMma,
         cluster_layout_vmnk: cute.Layout,
         ab_pipeline_mbar_ptr: cute.Pointer,
     ):
         # Threads/warps participating in this pipeline
-        producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.
+        producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32
         ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
         # Each warp will contribute to the arrive count with the number of mcast size
         mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
-        consumer_arrive_cnt = mcast_size *
+        consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE
         ab_pipeline_consumer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread, consumer_arrive_cnt
         )
         pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
-        tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
-        if const_expr(not self.gather_A):
-            tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
         return pipeline_cls.create(
             barrier_storage=ab_pipeline_mbar_ptr,
             num_stages=self.ab_stage,
             producer_group=ab_pipeline_producer_group,
             consumer_group=ab_pipeline_consumer_group,
-            tx_count=
+            tx_count=self.num_tma_load_bytes,
             cta_layout_vmnk=cluster_layout_vmnk,
         )

@@ -1725,7 +1659,7 @@ class GemmSm90:
         # Threads/warps participating in this pipeline
         epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
         # Each warp will contribute 1 to the arrive count
-        consumer_arrive_cnt = self.
+        consumer_arrive_cnt = self.num_epi_warps
         epi_pipeline_consumer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread, consumer_arrive_cnt
         )
@@ -1738,6 +1672,14 @@ class GemmSm90:
             tx_count=tma_copy_c_bytes,
         )

+    def make_epi_store_pipeline(self):
+        # Threads/warps participating in tma store pipeline
+        num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
+        epi_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_epi_threads)
+        return pipeline.PipelineTmaStore.create(
+            num_stages=self.epi_stage, producer_group=epi_store_producer_group
+        )
+
     def make_sched_pipeline(
         self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
     ):
@@ -1766,21 +1708,21 @@ class GemmSm90:
     @classmethod
     def _compute_stages(
         cls,
-
+        cta_tile_shape_mnk: Tuple[int, int, int],
         epi_tile: Tuple[int, int],
         a_dtype: Type[cutlass.Numeric],
         b_dtype: Type[cutlass.Numeric],
         d_dtype: Optional[Type[cutlass.Numeric]],
         c_dtype: Optional[Type[cutlass.Numeric]],
-        epilogue_args:
+        epilogue_args: EpilogueArguments,
         smem_capacity: int,
         occupancy: int,
-        overlap_sD_sA: bool,
+        overlap_sD_sA: bool = False,
     ) -> Tuple[int, int]:
         """Computes the number of stages for A/B/C operands based on heuristics.

-        :param
-        :type
+        :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
+        :type cta_tile_shape_mnk: Tuple[int, int, int]
         :param a_dtype: Data type of operand A.
         :type a_dtype: type[cutlass.Numeric]
         :param b_dtype: Data type of operand B.
@@ -1803,15 +1745,15 @@ class GemmSm90:
             cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
         )
         epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
-            epilogue_args,
+            epilogue_args, cta_tile_shape_mnk, epi_tile
         )
         epi_bytes = epi_bytes_per_stage * epi_stage
         epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
         if c_dtype is not None:
             epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage

-        a_shape = cute.slice_(
-        b_shape = cute.slice_(
+        a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
+        b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
         ab_bytes_per_stage = (
             cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
         )
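The _compute_stages hunk above sizes one pipeline stage of A and B in shared memory and reserves the epilogue buffers before deciding how many A/B stages fit. The same bookkeeping as a self-contained sketch (the capacity and reservation constants here are illustrative, not the kernel's exact rule):

    def ab_stages(cta_tile_mnk, a_bits, b_bits, epi_bytes, smem_capacity=228 * 1024, reserved=1024):
        m, n, k = cta_tile_mnk
        # one stage holds an (m x k) slice of A and an (n x k) slice of B
        ab_bytes_per_stage = m * k * a_bits // 8 + n * k * b_bits // 8
        return (smem_capacity - reserved - epi_bytes) // ab_bytes_per_stage

    # a 128x256x64 BF16 tile with ~32 KiB of epilogue smem leaves room for 4 A/B stages
    assert ab_stages((128, 256, 64), 16, 16, epi_bytes=32 * 1024) == 4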
@@ -1829,15 +1771,15 @@ class GemmSm90:

     @staticmethod
     def _sm90_compute_tile_shape_or_override(
-
+        cta_tile_shape_mnk: Tuple[int, int, int],
         atom_layout_mnk: Tuple[int, int, int],
         element_type: Optional[Type[cutlass.Numeric]] = None,
         epi_tile_override: Tuple[int, int] | None = None,
     ) -> Tuple[int, int]:
         """Compute the epilogue tile shape or use override if provided.

-        :param
-        :type
+        :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
+        :type cta_tile_shape_mnk: Tuple[int, int, int]
         :param element_type: Data type of elements
         :type element_type: type[cutlass.Numeric]
         :param is_cooperative: Whether to use cooperative approach
@@ -1850,12 +1792,12 @@ class GemmSm90:
         """
         if epi_tile_override is not None:
             return epi_tile_override
-        if
-            tile_m = math.gcd(128, cute.size(
-            tile_n = math.gcd(32, cute.size(
-        elif
-            tile_m = math.gcd(192, cute.size(
-            tile_n = math.gcd(32, cute.size(
+        if cta_tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
+            tile_m = math.gcd(128, cute.size(cta_tile_shape_mnk, mode=[0]))
+            tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
+        elif cta_tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
+            tile_m = math.gcd(192, cute.size(cta_tile_shape_mnk, mode=[0]))
+            tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1]))
         else:
             # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
             # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
@@ -1864,13 +1806,13 @@ class GemmSm90:
             # We could change the epilogue to accommodate this,
             # but it's easier to just set epi_tile_m = 64.
             n_perf = 64 if element_type is not None and element_type.width == 8 else 32
-            tile_m = math.gcd(64, cute.size(
-            tile_n = math.gcd(n_perf, cute.size(
+            tile_m = math.gcd(64, cute.size(cta_tile_shape_mnk, mode=[0]))
+            tile_n = math.gcd(n_perf, cute.size(cta_tile_shape_mnk, mode=[1]))
         return (tile_m, tile_n)

     @staticmethod
     def _make_smem_layouts(
-
+        cta_tile_shape_mnk: Tuple[int, int, int],
         epi_tile: Tuple[int, int],
         a_dtype: Type[cutlass.Numeric],
         a_layout: LayoutEnum,
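_sm90_compute_tile_shape_or_override (reworked above to take cta_tile_shape_mnk) picks the epilogue tile by gcd so it always divides the CTA tile. A standalone restatement of that heuristic, with illustrative defaults for the atom layout and element width:

    import math

    def epi_tile_for(cta_tile_mnk, atom_layout_m=2, elem_width=16):
        m, n, _ = cta_tile_mnk
        if m % 128 == 0 and atom_layout_m > 1:
            return math.gcd(128, m), math.gcd(32, n)
        if m % 192 == 0 and atom_layout_m > 1:
            return math.gcd(192, m), math.gcd(32, n)
        n_perf = 64 if elem_width == 8 else 32  # a wider epi tile pays off for 8-bit outputs
        return math.gcd(64, m), math.gcd(n_perf, n)

    assert epi_tile_for((128, 256, 64)) == (128, 32)
    assert epi_tile_for((128, 192, 64), atom_layout_m=1) == (64, 32)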
@@ -1888,8 +1830,8 @@ class GemmSm90:
     ]:
         """Create shared memory layouts for A, B, and C tensors.

-        :param
-        :type
+        :param cta_tile_shape_mnk: CTA tile shape (M,N,K)
+        :type cta_tile_shape_mnk: Tuple[int, int, int]
         :param epi_tile: Epilogue tile shape
         :type epi_tile: Tuple[int, int]
         :param a_dtype: Data type for matrix A
@@ -1912,11 +1854,11 @@ class GemmSm90:
         :return: Tuple of shared memory layouts for A, B, and C
         :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]
         """
-        a_smem_shape = cute.slice_(
+        a_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))

         a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
         b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
-        a_major_mode_size =
+        a_major_mode_size = cta_tile_shape_mnk[2 if a_is_k_major else 0]
         a_smem_layout_atom = warpgroup.make_smem_layout_atom(
             sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size),
             a_dtype,
@@ -1927,9 +1869,9 @@ class GemmSm90:
             order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
         )

-        b_smem_shape = cute.slice_(
+        b_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))

-        b_major_mode_size =
+        b_major_mode_size = cta_tile_shape_mnk[2 if b_is_k_major else 1]
         b_smem_layout_atom = warpgroup.make_smem_layout_atom(
             sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size),
             b_dtype,
@@ -1940,36 +1882,18 @@ class GemmSm90:
             order=(0, 1, 2) if b_is_k_major else (1, 0, 2),
         )

+        epi_smem_layout_staged = None
         if d_dtype is not None:
-
-
-            d_smem_layout_atom = warpgroup.make_smem_layout_atom(
-                sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
-                d_dtype,
-            )
-            epi_smem_layout_staged = cute.tile_to_shape(
-                d_smem_layout_atom,
-                cute.append(d_smem_shape, epi_stage),
-                order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
+            epi_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
+                d_dtype, d_layout, epi_tile, epi_stage
             )
-        else:
-            epi_smem_layout_staged = None

+        epi_c_smem_layout_staged = None
         if c_dtype is not None:
             assert c_layout is not None
-
-
-            c_smem_layout_atom = warpgroup.make_smem_layout_atom(
-                sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
-                c_dtype,
-            )
-            epi_c_smem_layout_staged = cute.tile_to_shape(
-                c_smem_layout_atom,
-                cute.append(c_smem_shape, epi_c_stage),
-                order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
+            epi_c_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi(
+                c_dtype, c_layout, epi_tile, epi_c_stage
             )
-        else:
-            epi_c_smem_layout_staged = None

         return (
             a_smem_layout_staged,
@@ -1983,7 +1907,7 @@ class GemmSm90:
         tensor_d: cute.Tensor,
         epi_smem_layout_staged: cute.ComposedLayout,
         epi_tile: Tuple[int, int],
-
+        op_type: Literal["store", "load", "add"],
     ) -> Tuple[cute.CopyAtom, cute.Tensor]:
         """Create TMA atoms and tensors for storing D or loading C.

@@ -1997,13 +1921,15 @@ class GemmSm90:
         :return: TMA atom and tensor for C
         :rtype: Tuple[cute.CopyAtom, cute.Tensor]
         """
-        assert
+        assert op_type in ["load", "store", "add"]
         epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
         d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
         op = (
             cpasync.CopyBulkTensorTileG2SOp()
-            if
+            if op_type == "load"
             else cpasync.CopyBulkTensorTileS2GOp()
+            if op_type == "store"
+            else cpasync.CopyReduceBulkTensorTileS2GOp(cute.ReductionOp.ADD)
         )
         tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
             op, tensor_d, epi_smem_layout, d_cta_v_layout
@@ -2013,7 +1939,7 @@ class GemmSm90:
     @staticmethod
     def _make_tma_atoms_and_tensors(
         tensor: cute.Tensor,
-
+        smem_layout: cute.ComposedLayout,
         smem_tile: Tuple[int, int],
         mcast_dim: int,
     ) -> Tuple[cute.CopyAtom, cute.Tensor]:
@@ -2021,8 +1947,8 @@ class GemmSm90:

         :param tensor: Input tensor (A or B)
         :type tensor: cute.Tensor
-        :param
-        :type
+        :param smem_layout: Shared memory layout for the tensor
+        :type smem_layout: cute.ComposedLayout
         :param smem_tile: Shared memory tile shape
         :type smem_tile: Tuple[int, int]
         :param mcast_dim: Multicast dimension
@@ -2036,8 +1962,6 @@ class GemmSm90:
             if mcast_dim == 1
             else cpasync.CopyBulkTensorTileG2SMulticastOp()
         )
-
-        smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
         tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
             op,
             tensor,
@@ -2054,13 +1978,18 @@ class GemmSm90:
             num_bits_per_copy=copy_bits,
         )
         copy_elems = copy_bits // dtype.width
-
+        loads_per_cache_line = 128 * 8 // copy_bits # 128 bytes per cache line
+        shape_dim_1 = cute.size(self.cta_tile_shape_mnk[2]) // copy_elems
+        if shape_dim_1 > loads_per_cache_line:
+            shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line)
         # thread layout for copy
         thread_layout = cute.make_layout(
             (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
         )
         if major_mode != LayoutEnum.ROW_MAJOR:
-            shape_dim_0 = cute.size(self.
+            shape_dim_0 = cute.size(self.cta_tile_shape_mnk[0]) // copy_elems
+            if shape_dim_0 > loads_per_cache_line:
+                shape_dim_0 = math.gcd(shape_dim_0, loads_per_cache_line)
             thread_layout = cute.make_layout(
                 (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
             )
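The gmem-copy thread-layout change above caps the fast dimension of the thread layout so one row of vector loads stays within a 128-byte cache line. The arithmetic, restated as a hypothetical standalone helper (not library code):

    import math

    def copy_thread_layout(num_threads, tile_k, dtype_bits, copy_bits=128):
        copy_elems = copy_bits // dtype_bits         # elements moved per copy instruction
        loads_per_cache_line = 128 * 8 // copy_bits  # 128-byte cache line
        fast_dim = tile_k // copy_elems
        if fast_dim > loads_per_cache_line:
            fast_dim = math.gcd(fast_dim, loads_per_cache_line)
        return (num_threads // fast_dim, fast_dim)   # (slow, fast), fast dim has stride 1

    # 128 threads copying BF16 with 128-bit vectors over K=64: 8 loads span a cache line
    assert copy_thread_layout(128, 64, 16) == (16, 8)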
@@ -2102,7 +2031,7 @@ class GemmSm90:
         """
         is_valid = True
         if a_dtype not in {
-
+            Float16,
             cutlass.BFloat16,
             cutlass.Float8E4M3FN,
             cutlass.Float8E5M2,
@@ -2110,19 +2039,19 @@ class GemmSm90:
             is_valid = False
         # tested b_dtype
         if b_dtype not in {
-
+            Float16,
             cutlass.BFloat16,
             cutlass.Float8E4M3FN,
             cutlass.Float8E5M2,
         }:
             is_valid = False
-        if acc_dtype not in {
+        if acc_dtype not in {Float32, Float16}:
             is_valid = False
         # tested d_dtype
         if d_dtype not in {
             None,
-
-
+            Float32,
+            Float16,
             cutlass.BFloat16,
             cutlass.Float8E4M3FN,
             cutlass.Float8E5M2,
@@ -2139,107 +2068,3 @@ class GemmSm90:
         if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
             is_valid = False
         return is_valid
-
-
-def gemm_sm90(
-    A: Tensor, # (l, m, k)
-    B: Tensor, # (l, n, k)
-    D: Tensor, # (l, m, n)
-    C: Optional[Tensor], # (l, m, n)
-    tile_count_semaphore: Optional[Tensor], # (1,)
-    tile_M: int,
-    tile_N: int,
-    cluster_M: int,
-    cluster_N: int,
-    pingpong: bool = False,
-    persistent: bool = True,
-    alpha: float | Tensor = 1.0,
-    beta: float | Tensor = 1.0,
-) -> None:
-    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(A, B, D, C)
-    GemmWrapperBase.permute_tensors(tensor_infos)
-    GemmWrapperBase.extract_dtypes(tensor_infos)
-    major_configs = {
-        "A": ("m", "k", "l"),
-        "B": ("n", "k", "l"),
-        "D": ("m", "n", "l"),
-        "C": ("m", "n", "l"),
-    }
-    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
-
-    acc_dtype = cutlass.Float32
-    tile_shape_mn = (tile_M, tile_N)
-    cluster_shape_mnk = (cluster_M, cluster_N, 1)
-    if not GemmSm90.is_valid_dtypes(
-        tensor_infos["A"].dtype,
-        tensor_infos["B"].dtype,
-        acc_dtype,
-        tensor_infos["D"].dtype,
-        tensor_infos["A"].major,
-        tensor_infos["B"].major,
-    ):
-        raise TypeError("Skipping due to unsupported combination of types and majors")
-
-    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
-    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
-
-    def scalar_arg(scalar: float | Tensor):
-        if isinstance(scalar, float):
-            return Float32(scalar) if scalar != 1.0 else None
-        else:
-            assert isinstance(scalar, Tensor)
-            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
-
-    epi_args = GemmSm90.EpilogueArguments(scalar_arg(alpha), scalar_arg(beta))
-    scheduler_args = GemmWrapperBase.create_scheduler_args(
-        max_active_clusters, tile_count_semaphore
-    )
-    current_stream = cutlass_torch.current_stream()
-    compile_key = GemmWrapperBase.get_compile_key(
-        tensor_infos,
-        None,
-        tile_shape_mn,
-        cluster_shape_mnk,
-        pingpong,
-        persistent,
-        tile_count_semaphore is not None,
-        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
-        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
-        key_tensor_names=("A", "B", "D", "C"),
-    )
-    cache = gemm_sm90.compile_cache
-    if compile_key not in cache:
-        gemm = GemmSm90(
-            acc_dtype,
-            tensor_infos["A"].dtype,
-            tile_shape_mn,
-            cluster_shape_mnk,
-            pingpong=pingpong,
-            is_persistent=persistent,
-        )
-        cache[compile_key] = cute.compile(
-            gemm,
-            tensor_infos["A"].cute_tensor,
-            tensor_infos["B"].cute_tensor,
-            tensor_infos["D"].cute_tensor,
-            tensor_infos["C"].cute_tensor,
-            epi_args,
-            scheduler_args,
-            None, # varlen_args
-            None, # mAIdx
-            current_stream,
-        )
-    cache[compile_key](
-        tensor_infos["A"].cute_tensor,
-        tensor_infos["B"].cute_tensor,
-        tensor_infos["D"].cute_tensor,
-        tensor_infos["C"].cute_tensor,
-        epi_args,
-        scheduler_args,
-        None,
-        None,
-        current_stream,
-    )
-
-
-gemm_sm90.compile_cache = {}