quack-kernels 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +4 -1
- quack/autotuner.py +309 -0
- quack/cute_dsl_utils.py +40 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1494 -450
- quack/fast_math.py +97 -0
- quack/gemm_config.py +61 -0
- quack/gemm_interface.py +321 -0
- quack/linear.py +176 -0
- quack/lse.py +62 -0
- quack/mlp.py +204 -0
- quack/pipeline.py +166 -0
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2088 -0
- quack/tensormap_manager.py +114 -0
- quack/tile_scheduler.py +935 -0
- quack/topk.py +221 -0
- quack/utils.py +236 -18
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.1.11.dist-info}/METADATA +1 -1
- quack_kernels-0.1.11.dist-info/RECORD +31 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.1.11.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.1.11.dist-info}/top_level.txt +0 -0
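The largest change below is to quack/dense_gemm_sm90.py, whose HopperWgmmaGemmKernel constructor gains a_dtype, pingpong, is_persistent, fp8_fast_accum, and gather_A parameters. As a rough, unverified sketch of how the 0.1.11 signature might be instantiated (argument names taken from the diff below; the dtype and tile/cluster shape values are illustrative assumptions, not values from this release):

    import cutlass
    from quack.dense_gemm_sm90 import HopperWgmmaGemmKernel

    # Hypothetical configuration; only the parameter names are taken from the diff.
    kernel = HopperWgmmaGemmKernel(
        acc_dtype=cutlass.Float32,
        a_dtype=cutlass.BFloat16,
        tile_shape_mnk=(128, 256, 64),   # assumed example tile shape
        cluster_shape_mnk=(1, 1, 1),     # assumed example cluster shape
        pingpong=False,
        is_persistent=True,
        fp8_fast_accum=False,
        gather_A=False,
    )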
quack/dense_gemm_sm90.py
CHANGED
@@ -27,21 +27,37 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import argparse
-
+import enum
+from typing import Tuple, Type, Callable, Optional
+from functools import partial
 import math
+
 import cuda.bindings.driver as cuda
 
 import torch
 
 import cutlass
 import cutlass.cute as cute
-import cutlass.cute.testing as testing
-import cutlass.utils as utils
 import cutlass.pipeline as pipeline
 import cutlass.torch as cutlass_torch
-from cutlass.cute.runtime import from_dlpack
+from cutlass.cute.runtime import from_dlpack, make_ptr
 from cutlass.cute.nvgpu import cpasync, warp, warpgroup
 import cutlass.utils.hopper_helpers as sm90_utils
+from cutlass import Int32, const_expr
+
+from quack.tile_scheduler import (
+    TileSchedulerArguments,
+    TileScheduler,
+    VarlenMTileSchedulerArguments,
+    VarlenMTileScheduler,
+    ParamsBase,
+    RasterOrderOption,
+)
+from quack.tensormap_manager import TensorMapManagerSm90
+
+# return PipelineStateWAdvance instead of PipelineState
+from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
+import quack.utils as utils
 
 """
 A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -154,6 +170,11 @@ def parse_arguments() -> argparse.Namespace:
         type=cutlass.dtype,
         default=cutlass.BFloat16,
     )
+    parser.add_argument(
+        "--c_dtype",
+        type=cutlass.dtype,
+        default=None,
+    )
     parser.add_argument(
         "--acc_dtype",
         type=cutlass.dtype,
@@ -162,21 +183,24 @@ def parse_arguments() -> argparse.Namespace:
     parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
     parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
     parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
-    parser.add_argument("--
-    parser.add_argument("--
+    parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
+    parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation")
+    parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations")
     parser.add_argument(
         "--iterations",
         type=int,
-        default=
+        default=30,
         help="Number of iterations to run the kernel",
     )
-    parser.add_argument("--
+    parser.add_argument("--persistent", action="store_true", help="Persistent kernel")
     parser.add_argument(
-        "--
-        action="store_true",
-        default=False,
-        help="Use circular buffer tensor sets to ensure L2 cold cache",
+        "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel"
     )
+    parser.add_argument("--pingpong", action="store_true", help="Pingpong kernel")
+    parser.add_argument("--varlen_m", action="store_true", help="Variable length M dimension")
+    parser.add_argument("--gather_A", action="store_true", help="Gather A")
+    parser.add_argument("--fp8_fast_accum", action="store_true", help="FP8 fast accum")
+    parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
 
     args = parser.parse_args()
 
@@ -195,6 +219,17 @@ def parse_arguments() -> argparse.Namespace:
 # /////////////////////////////////////////////////////////////////////////////
 
 
+class NamedBarrierGemm(enum.IntEnum):
+    Epilogue = enum.auto()  # starts from 1 as barrier 0 is reserved for sync_threads()
+    # For mainloop load warps to signal that the epilogue load warp can start.
+    # This is to avoid loading C too early, interfering with loading A and B.
+    EpilogueLoad = enum.auto()
+    MmaWG0 = enum.auto()
+    MmaWG1 = enum.auto()
+    EpiWG0 = enum.auto()
+    EpiWG1 = enum.auto()
+
+
 class HopperWgmmaGemmKernel:
     """
     This class implements batched matrix multiplication (C = A x B) with support for various data types
@@ -221,9 +256,6 @@ class HopperWgmmaGemmKernel:
         - Float32 (for all floating point inputs)
 
     :note: Constraints:
-        - CTA tile M must be 64/128
-        - CTA tile N must be 64/128/256
-        - CTA tile K must be 64
         - Cluster shape M/N must be positive and power of 2, total cluster size <= 4
 
     Example:
@@ -235,11 +267,19 @@ class HopperWgmmaGemmKernel:
         >>> gemm(a_tensor, b_tensor, c_tensor, stream)
     """
 
+    bytes_per_tensormap = 128
+    num_tensormaps = 1  # For D only
+
     def __init__(
         self,
         acc_dtype: Type[cutlass.Numeric],
+        a_dtype: Type[cutlass.Numeric],
         tile_shape_mnk: Tuple[int, int, int],
         cluster_shape_mnk: Tuple[int, int, int],
+        pingpong: bool = False,
+        is_persistent: bool = True,
+        fp8_fast_accum: bool = False,
+        gather_A: bool = False,
     ):
         """
         Initializes the configuration for a Hopper dense GEMM kernel.
@@ -256,52 +296,101 @@ class HopperWgmmaGemmKernel:
         """
 
         self.acc_dtype = acc_dtype
+        self.pingpong = pingpong
+        self.is_persistent = is_persistent
+        if self.pingpong:
+            assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
+        self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
+        self.gather_A = gather_A
+        if gather_A:
+            assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
+        self.tensormap_update_mode = cutlass.utils.TensorMapUpdateMode.SMEM
 
         self.cluster_shape_mnk = cluster_shape_mnk
         self.tile_shape_mnk = tuple(tile_shape_mnk)
         tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1]
         # check the cta tile shape
-
-
-
-
-
-
-
-
-
+        if not self.pingpong:
+            if tile_M not in [64, 128, 192, 256, 320]:
+                raise ValueError("CTA tile shape M must be 64/128/192/256/320")
+            if tile_M in [192, 320]:  # special case
+                tile_N_max = 256 if tile_M == 192 else 160
+                if not (tile_N % 32 == 0 and tile_N <= tile_N_max):
+                    raise ValueError(
+                        f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}"
+                    )
+            else:
+                if not (
+                    (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
+                ):
+                    raise ValueError(
+                        "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
+                    )
         else:
-            if not
-                raise ValueError(
-
-
+            if tile_M not in [64, 128, 192]:
+                raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
+            tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
+            if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
+                raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
         if not self.tile_shape_mnk[2] % 16 == 0:
             raise ValueError("CTA tile shape K must be divisible by 16")
 
-        if
-
+        if not self.pingpong:
+            if tile_M == 320:  # tile_M / 64 is not even so we have to split along N
+                atom_layout_m, atom_layout_n = 1, 2
+            elif tile_M == 192:
+                if tile_N <= 128:
+                    atom_layout_m, atom_layout_n = 3, 1
+                else:
+                    atom_layout_m, atom_layout_n = 1, 2
+            else:
+                atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2
+                atom_layout_n = 1
+            assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
         else:
-            atom_layout_m =
-            atom_layout_n = 1
-        assert atom_layout_m in [1, 2] and atom_layout_n in [1, 2]
+            atom_layout_m, atom_layout_n = 1, 1
         self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
 
-        self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
+        self.num_mcast_ctas_a = self.cluster_shape_mnk[1] if not self.gather_A else 1
         self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
         self.is_a_mcast = self.num_mcast_ctas_a > 1
         self.is_b_mcast = self.num_mcast_ctas_b > 1
 
         self.occupancy = 1
-        self.mma_warp_groups = math.prod(self.atom_layout_mnk)
+        self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
+        if self.pingpong:
+            assert self.mma_warp_groups == 2
+        assert self.mma_warp_groups in [1, 2, 3]
         self.num_threads_per_warp_group = 128
         self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
-        self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
-        self.
-
-
-
-        self.
-        self.
+        self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
+        self.num_epi_threads = (
+            self.mma_warp_groups if not self.pingpong else 1
+        ) * self.num_threads_per_warp_group
+        self.num_ab_load_warps = 1 if not self.gather_A else 4
+        self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
+        self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
+        self.ab_load_warp_id = self.mma_warp_groups * 4
+        self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
+
+        regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // (
+            math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
+        )
+        if self.fp8_slow_accum:
+            regs_per_thread *= 2
+        if not self.gather_A:
+            if self.mma_warp_groups == 3:
+                self.num_regs_load, self.num_regs_mma = 32, 160
+            else:
+                heavy_register_pressure = regs_per_thread >= 208
+                self.num_regs_load, self.num_regs_mma = (
+                    (40, 232) if not heavy_register_pressure else (24, 240)
+                )
+        else:
+            if self.mma_warp_groups == 3:
+                self.num_regs_load, self.num_regs_mma = 56, 152
+            else:
+                self.num_regs_load, self.num_regs_mma = (56, 224)
 
         self.ab_stage = None
         self.epi_stage = None
@@ -328,26 +417,34 @@ class HopperWgmmaGemmKernel:
         - Computing A/B/C shared memory layout
         """
 
-        self.
+        self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
 
-        is_cooperative = math.prod(self.atom_layout_mnk) > 1
         self.epi_tile = self._sm90_compute_tile_shape_or_override(
-            self.tile_shape_mnk,
+            self.tile_shape_mnk,
+            self.atom_layout_mnk,
+            self.d_dtype,
         )
 
         # Compute stage before compute smem layout
-        self.ab_stage, self.epi_stage = self._compute_stages(
+        self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
            self.tile_shape_mnk,
+            self.epi_tile,
            self.a_dtype,
            self.b_dtype,
+            self.d_dtype,
+            self.c_dtype,
            self.smem_capacity,
            self.occupancy,
+            # epi_smem will reuse smem ab if not persistent.
+            overlap_sD_sA=not self.is_persistent,
         )
+        self.sched_stage = 2 if self.pingpong else 1
 
         (
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
+            self.epi_c_smem_layout_staged,
         ) = self._make_smem_layouts(
            self.tile_shape_mnk,
            self.epi_tile,
@@ -359,6 +456,9 @@ class HopperWgmmaGemmKernel:
            self.d_dtype,
            self.d_layout,
            self.epi_stage,
+            self.c_dtype,
+            self.c_layout,
+            self.epi_c_stage,
         )
 
     @cute.jit
@@ -367,6 +467,12 @@ class HopperWgmmaGemmKernel:
         mA: cute.Tensor,
         mB: cute.Tensor,
         mD: cute.Tensor,
+        mC: Optional[cute.Tensor],
+        mAIdx: Optional[cute.Tensor],
+        mCuSeqlensM: Optional[cute.Tensor],
+        mTensormaps: Optional[cute.Tensor],
+        tile_count_semaphore: Optional[cute.Pointer],
+        max_active_clusters: Int32,
         stream: cuda.CUstream,
     ):
         """Execute the GEMM operation in steps:
@@ -390,16 +496,29 @@ class HopperWgmmaGemmKernel:
         self.a_dtype = mA.element_type
         self.b_dtype = mB.element_type
         self.d_dtype = mD.element_type
-        self.
-        self.
-        self.
+        self.c_dtype = mC.element_type if mC is not None else None
+        self.a_layout = cutlass.utils.LayoutEnum.from_tensor(mA)
+        self.b_layout = cutlass.utils.LayoutEnum.from_tensor(mB)
+        self.d_layout = cutlass.utils.LayoutEnum.from_tensor(mD)
+        self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None
 
-        if
+        if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
            raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
-        if
+        if const_expr(self.a_dtype.width != self.b_dtype.width):
            raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
-        if
+        if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
            raise TypeError("a_dtype should be float16 or float8")
+        assert (mAIdx is not None) == self.gather_A
+
+        # Assume all strides are divisible by 128 bits except the last stride
+        new_stride = lambda t: tuple(
+            cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
+            for s in t.stride
+        )
+        mA, mD = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            for t in (mA, mD)
+        ]
 
         self._setup_attributes()
 
@@ -412,13 +531,31 @@ class HopperWgmmaGemmKernel:
            self.atom_layout_mnk,
            tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]),
         )
+        if const_expr(self.atom_layout_mnk[1] > 1):
+            # If N dimension is split among 2 WGs, we need to permute the N dimension so
+            # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32)
+            # containing accumulators that are next to each other in the N dimension.
+            # Without permutation WG0 would write to epi smem of size (64, 16) and
+            # WG1 would write to a separate epi smem of size (64, 16) that's far away.
+            atom_n = self.atom_layout_mnk[1]
+            permutation_n = cute.make_ordered_layout(
+                (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1)
+            )
+            tiled_mma = cute.make_tiled_mma(
+                cute.make_mma_atom(tiled_mma.op),
+                self.atom_layout_mnk,
+                permutation_mnk=(None, permutation_n, None),
+            )
 
-
-
-
-
-
-
+        if const_expr(not self.gather_A):
+            tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
+                mA,
+                self.a_smem_layout_staged,
+                (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
+                self.cluster_shape_mnk[1],
+            )
+        else:
+            tma_atom_a, tma_tensor_a = None, None
 
         tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
            mB,
@@ -427,17 +564,84 @@ class HopperWgmmaGemmKernel:
            self.cluster_shape_mnk[0],
         )
 
-        tma_atom_d, tma_tensor_d = self.
-            mD,
-            self.epi_smem_layout_staged,
-            self.epi_tile,
+        tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
+            mD, self.epi_smem_layout_staged, self.epi_tile, store_or_load="store"
         )
 
-
+        if const_expr(mC is not None):
+            tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors(
+                mC, self.epi_c_smem_layout_staged, self.epi_tile, store_or_load="load"
+            )
+        else:
+            tma_atom_c, tma_tensor_c = None, None
+
+        if const_expr(mCuSeqlensM is None):
+            problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (
+                mD.shape[2],
+            )
+            TileSchedulerCls = TileScheduler
+            tile_sched_args = TileSchedulerArguments(
+                problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+                raster_order=RasterOrderOption.Heuristic,
+                group_size=8,
+                cluster_shape_mnk=self.cluster_shape_mnk,
+                tile_count_semaphore=tile_count_semaphore,
+                is_persistent=self.is_persistent,
+            )
+        else:
+            assert mTensormaps is not None
+            problem_shape_ntile_mnl = (
+                None,
+                cute.ceil_div(mD.shape[1], self.tile_shape_mnk[1]),
+                mCuSeqlensM.shape[0] - 1,
+            )
+            TileSchedulerCls = VarlenMTileScheduler
+            tile_sched_args = VarlenMTileSchedulerArguments(
+                problem_shape_ntile_mnl=problem_shape_ntile_mnl,
+                total_m=mD.shape[0],
+                cu_seqlens_m=mCuSeqlensM,
+                raster_order=RasterOrderOption.Heuristic,
+                group_size=8,
+                tile_shape_mnk=self.tile_shape_mnk,
+                cluster_shape_mnk=self.cluster_shape_mnk,
+                tile_count_semaphore=tile_count_semaphore,
+                is_persistent=self.is_persistent,
+            )
+        tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
+        grid = TileSchedulerCls.get_grid_shape(tile_sched_params, max_active_clusters)
+
+        epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if self.is_persistent else 0
+        epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
+
+        size_tensormap_in_i64 = (
+            0
+            if mCuSeqlensM is None
+            or self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.GMEM
+            else HopperWgmmaGemmKernel.num_tensormaps
+            * HopperWgmmaGemmKernel.bytes_per_tensormap
+            // 8
+        ) * (1 if not self.pingpong else 2)
 
         @cute.struct
         class SharedStorage:
-
+            tensormap_buffer: cute.struct.Align[
+                cute.struct.MemRange[cutlass.Int64, size_tensormap_in_i64],
+                64,
+            ]
+            ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
+            epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
+            sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
+            tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage]
+            sD: cute.struct.Align[
+                cute.struct.MemRange[self.d_dtype, epi_smem_size],
+                self.buffer_align_bytes,
+            ]
+            sC: cute.struct.Align[
+                cute.struct.MemRange[
+                    self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
+                ],
+                self.buffer_align_bytes,
+            ]
             sA: cute.struct.Align[
                cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
                self.buffer_align_bytes,
@@ -452,16 +656,25 @@ class HopperWgmmaGemmKernel:
         # Launch the kernel synchronously
         self.kernel(
            tma_atom_a,
-            tma_tensor_a,
+            tma_tensor_a if const_expr(not self.gather_A) else mA,
            tma_atom_b,
            tma_tensor_b,
            tma_atom_d,
            tma_tensor_d,
+            mD,
+            tma_atom_c,
+            tma_tensor_c,
+            mAIdx,
+            mCuSeqlensM,
+            mTensormaps,
            tiled_mma,
-            self.
+            self.cluster_layout_mnk,
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
+            self.epi_c_smem_layout_staged,
+            tile_sched_params,
+            TileSchedulerCls,
         ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
@@ -476,17 +689,26 @@ class HopperWgmmaGemmKernel:
     @cute.kernel
     def kernel(
         self,
-        tma_atom_a: cute.CopyAtom,
+        tma_atom_a: Optional[cute.CopyAtom],
         mA_mkl: cute.Tensor,
         tma_atom_b: cute.CopyAtom,
         mB_nkl: cute.Tensor,
         tma_atom_d: cute.CopyAtom,
+        mD_mnl_tma: cute.Tensor,
         mD_mnl: cute.Tensor,
+        tma_atom_c: Optional[cute.CopyAtom],
+        mC_mnl: Optional[cute.Tensor],
+        mAIdx: Optional[cute.Tensor],
+        cu_seqlens_m: Optional[cute.Tensor],
+        tensormaps: Optional[cute.Tensor],
         tiled_mma: cute.TiledMma,
-
+        cluster_layout_mnk: cute.Layout,
         a_smem_layout_staged: cute.ComposedLayout,
         b_smem_layout_staged: cute.ComposedLayout,
         epi_smem_layout_staged: cute.ComposedLayout,
+        epi_c_smem_layout_staged: cute.ComposedLayout,
+        tile_sched_params: ParamsBase,
+        TileSchedulerCls: cutlass.Constexpr[Callable],
     ):
         """
         GPU device kernel performing the batched GEMM computation.
@@ -501,12 +723,12 @@ class HopperWgmmaGemmKernel:
         :type mB_nkl: cute.Tensor
         :param tma_atom_d: TMA copy atom for D tensor
         :type tma_atom_d: cute.CopyAtom
-        :param
-        :type
+        :param mD_mnl_tma: Output tensor D
+        :type mD_mnl_tma: cute.Tensor
         :param tiled_mma: Tiled MMA object
         :type tiled_mma: cute.TiledMma
-        :param
-        :type
+        :param cluster_layout_mnk: CTA layout
+        :type cluster_layout_mnk: cute.Layout
         :param a_smem_layout_staged: Shared memory layout for A
         :type a_smem_layout_staged: cute.ComposedLayout
         :param b_smem_layout_staged: Shared memory layout for B
@@ -515,22 +737,25 @@ class HopperWgmmaGemmKernel:
         :type epi_smem_layout_staged: cute.ComposedLayout
         """
 
+        varlen = const_expr(cu_seqlens_m is not None)
         warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
 
         # /////////////////////////////////////////////////////////////////////////////
         # Prefetch Tma desc
         # /////////////////////////////////////////////////////////////////////////////
-
-
-
+        if warp_idx == self.ab_load_warp_id:
+            if const_expr(tma_atom_a is not None):
+                cpasync.prefetch_descriptor(tma_atom_a)
             cpasync.prefetch_descriptor(tma_atom_b)
             cpasync.prefetch_descriptor(tma_atom_d)
+            if const_expr(tma_atom_c is not None):
+                cpasync.prefetch_descriptor(tma_atom_c)
 
         a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
         b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
-        tma_copy_bytes = cute.size_in_bytes(self.
-
-
+        tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+        if const_expr(not self.gather_A):
+            tma_copy_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout)
 
         # /////////////////////////////////////////////////////////////////////////////
         # Alloc and init AB full/empty + ACC full mbar (pipeline)
@@ -539,163 +764,378 @@ class HopperWgmmaGemmKernel:
         storage = smem.allocate(self.shared_storage)
 
         # Threads/warps participating in this pipeline
-
-
+        producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_threads
+        ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
+        # Each warp will contribute to the arrive count with the number of mcast size
         mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
-        consumer_arrive_cnt = mcast_size * (
-
+        consumer_arrive_cnt = mcast_size * (tiled_mma.size // cute.arch.WARP_SIZE)
+        ab_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, consumer_arrive_cnt
         )
 
-        cta_layout_vmnk = cute.make_layout((1, *
-
-
+        cta_layout_vmnk = cute.make_layout((1, *cluster_layout_mnk.shape))
+        pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
+        ab_pipeline = pipeline_cls.create(
+            barrier_storage=storage.ab_pipeline_array_ptr.data_ptr(),
            num_stages=self.ab_stage,
-            producer_group=
-            consumer_group=
+            producer_group=ab_pipeline_producer_group,
+            consumer_group=ab_pipeline_consumer_group,
            tx_count=tma_copy_bytes,
            cta_layout_vmnk=cta_layout_vmnk,
         )
 
+        if const_expr(mC_mnl is not None):
+            # Threads/warps participating in this pipeline
+            epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+            # Each warp will contribute 1 to the arrive count
+            consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE
+            epi_pipeline_consumer_group = pipeline.CooperativeGroup(
+                pipeline.Agent.Thread, consumer_arrive_cnt
+            )
+            c_smem_layout = cute.slice_(epi_c_smem_layout_staged, (None, None, 0))
+            tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
+            epi_pipeline = pipeline.PipelineTmaAsync.create(
+                barrier_storage=storage.epi_pipeline_array_ptr.data_ptr(),
+                num_stages=self.epi_c_stage,
+                producer_group=epi_pipeline_producer_group,
+                consumer_group=epi_pipeline_consumer_group,
+                tx_count=tma_copy_c_bytes,
+            )
+        else:
+            epi_pipeline = None
+
+        if const_expr(tile_sched_params.tile_count_semaphore is not None):
+            # Dynamic persistent scheduler
+            # Threads/warps participating in this pipeline
+            sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+            cluster_size = cute.size(cluster_layout_mnk)
+            # Each warp that are not the scheduler warp will contribute 1 to the arrive count
+            consumer_arrive_cnt = (
+                (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps
+            ) * cluster_size - 1
+            sched_pipeline_consumer_group = pipeline.CooperativeGroup(
+                pipeline.Agent.Thread, consumer_arrive_cnt
+            )
+            sched_pipeline = pipeline.PipelineAsync.create(
+                barrier_storage=storage.sched_pipeline_array_ptr.data_ptr(),
+                num_stages=self.sched_stage,
+                producer_group=sched_pipeline_producer_group,
+                consumer_group=sched_pipeline_consumer_group,
+                # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
+                consumer_mask=None if const_expr(cute.size(cluster_layout_mnk) == 1) else 0,
+            )
+            tile_count = storage.tile_count.get_tensor((self.sched_stage,))
+        else:
+            sched_pipeline = None
+            tile_count = None
+
         # ///////////////////////////////////////////////////////////////////////////////
         # Generate smem tensor A/B
         # ///////////////////////////////////////////////////////////////////////////////
         sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
         sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
-
-
-
-
-
-
-
-
-
-
+        if const_expr(not self.is_persistent):
+            sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
+            sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
+        else:
+            sD = storage.sD.get_tensor(
+                epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
+            )
+        if const_expr(mC_mnl is not None):
+            sC = storage.sC.get_tensor(
+                epi_c_smem_layout_staged.outer, swizzle=epi_c_smem_layout_staged.inner
+            )
+        else:
+            sC = None
+
+        # Get tensormap buffer address
+        if const_expr(varlen):
+            grid_dim = cute.arch.grid_dim()
+            bid = cute.arch.block_idx()
+            tensormap_workspace_idx = (
+                bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0]
+            )
+            # TODO: this is only for D, not for A/B
+            if const_expr(self.pingpong):
+                tensormap_workspace_idx = tensormap_workspace_idx * 2 + warp_idx // 4
+            tensormap_manager = TensorMapManagerSm90(
+                self.tensormap_update_mode, HopperWgmmaGemmKernel.bytes_per_tensormap
+            )
+            tensormap_d_ptr = tensormap_manager.get_tensormap_ptr(
+                tensormaps[tensormap_workspace_idx, None].iterator
+            )
+            if const_expr(self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.SMEM):
+                tensormap_smem_ptr = storage.tensormap_buffer.data_ptr()
+                tensormap_d_smem_ptr = tensormap_smem_ptr + (warp_idx // 4) * (
+                    HopperWgmmaGemmKernel.bytes_per_tensormap // 8
+                )
+                # Need this, otherwise "expected tma descriptor pointer to have alignment at least 64, but got 8"
+                tensormap_d_smem_ptr = cute.make_ptr(
+                    cutlass.Int64,
+                    tensormap_d_smem_ptr.toint(),
+                    cute.AddressSpace.smem,
+                    assumed_align=64,
+                )
+                tensormap_d_init_ptr = tensormap_d_smem_ptr
+            else:
+                tensormap_d_smem_ptr = None
+                tensormap_d_init_ptr = tensormap_d_ptr
+        else:
+            tensormap_d_smem_ptr = None
+            tensormap_manager, tensormap_d_ptr, tensormap_d_init_ptr = None, None, None
 
-
-
-        s_shape = (
-            (group_size_m, cdimx // group_size_m),
-            cdimy,
+        TileSchedulerCls = partial(
+            TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
         )
-        s_stride = ((1, cdimy * group_size_m), group_size_m)
-        s_layout = cute.make_layout(s_shape, stride=s_stride)
-        num_reg_cids = cute.size(s_shape)
-        cid_m, cid_n = s_layout.get_flat_coord(cluster_id % num_reg_cids)
-
-        # Deal with the tail part
-        if cluster_id >= num_reg_cids:
-            tail_size_m = cdimx % group_size_m
-            tail_layout = cute.make_layout((tail_size_m, cdimy), stride=(1, tail_size_m))
-            tail_cid = cluster_id - num_reg_cids
-            tail_cid_m, tail_cid_n = tail_layout.get_flat_coord(tail_cid)
-            cid_m = cute.size(s_shape, mode=[0]) + tail_cid_m
-            cid_n = tail_cid_n
-
-        # Get the pid from cluster id
-        bidx_in_cluster = cute.arch.block_in_cluster_idx()
-        pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0]
-        pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1]
-
-        _, _, bidz = cute.arch.block_idx()
-        tile_coord_mnkl = (pid_m, pid_n, None, bidz)
-        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
-        cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
 
         k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2])
+        c_tile_cnt = cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.epi_tile))
 
-        if warp_idx >= self.
+        if warp_idx >= self.ab_load_warp_id:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
-            if
+            if const_expr(mC_mnl is not None):
+                epi_load_barrier = pipeline.NamedBarrier(
+                    barrier_id=int(NamedBarrierGemm.EpilogueLoad),
+                    num_threads=self.num_ab_load_threads + self.num_epi_load_threads,
+                )
+            else:
+                epi_load_barrier = None
+            if (
+                warp_idx >= self.ab_load_warp_id
+                and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
+            ):
                # ///////////////////////////////////////////////////////////////////////////////
                # Get mcast mask
                # ///////////////////////////////////////////////////////////////////////////////
+                cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
+                cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
                a_mcast_mask = cute.make_layout_image_mask(
-
+                    cluster_layout_mnk, cluster_coord_mnk, mode=1
                )
                b_mcast_mask = cute.make_layout_image_mask(
-
+                    cluster_layout_mnk, cluster_coord_mnk, mode=0
                )
                a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
                b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
-
+
+                # Persistent tile scheduling loop
+                is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
+                if const_expr(cute.size(cluster_layout_mnk) > 1):
+                    is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
+                tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
+                work_tile = tile_scheduler.initial_work_tile_info()
+                ab_producer_state = make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self.ab_stage
                )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            for k_tile in cutlass.range(k_tile_cnt, unroll=1):
-                # Wait for A/B buffers to be empty before loading into them
-                # Also sets the transaction barrier for the A/B buffers
-                mainloop_pipeline.producer_acquire(mainloop_producer_state)
-                # /////////////////////////////////////////////////////////////////////////////
-                # TMA load A/B
-                # /////////////////////////////////////////////////////////////////////////////
-                cute.copy(
-                    tma_atom_a,
-                    tAgA_mkl[None, k_tile],
-                    tAsA[None, mainloop_producer_state.index],
-                    tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state),
-                    mcast_mask=a_mcast_mask,
+                do_epi_load_barrier_arrive = cutlass.Boolean(True)
+                while work_tile.is_valid_tile:
+                    tile_coord_mnkl = work_tile.tile_idx
+                    batch_idx = tile_coord_mnkl[3]
+                    # ///////////////////////////////////////////////////////////////////////////
+                    # Local_tile partition global tensors
+                    # ///////////////////////////////////////////////////////////////////////////
+                    if const_expr(not self.gather_A):
+                        if const_expr(cu_seqlens_m is not None):
+                            mA_mk = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mA_mkl)
+                        else:
+                            mA_mk = mA_mkl[None, None, batch_idx]
+                        # (bM, bK, RestK)
+                        gA_k = cute.local_tile(
+                            mA_mk,
+                            cute.select(self.tile_shape_mnk, [0, 2]),
+                            (tile_coord_mnkl[0], None),
+                        )
+                    else:
+                        mA_mk = mA_mkl
+                        if const_expr(cu_seqlens_m is not None):
+                            mAIdx_mk = cute.domain_offset((cu_seqlens_m[batch_idx],), mAIdx)
+                        else:
+                            mAIdx_mk = mAIdx[None, batch_idx]
+                        gAIdx = cute.local_tile(
+                            mAIdx_mk, (self.tile_shape_mnk[0],), (tile_coord_mnkl[0],)
+                        )
+                    # (bN, bK, RestK)
+                    gB_k = cute.local_tile(
+                        mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
+                    )
+                    # //////////////////////////////////////////////////////////////////////////
+                    # Partition shared tensor for TMA load A/B
+                    # //////////////////////////////////////////////////////////////////////////
+                    # TMA load A partition_S/D
+                    a_cta_layout = cute.make_layout(
+                        cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
                    )
-
+                    a_cta_crd = cluster_coord_mnk[1]
+                    if const_expr(not self.gather_A):
+                        # ((atom_v, rest_v), STAGE)
+                        # ((atom_v, rest_v), RestK)
+                        tAsA, tAgA_k = cpasync.tma_partition(
+                            tma_atom_a,
+                            a_cta_crd,
+                            a_cta_layout,
+                            cute.group_modes(sA, 0, 2),
+                            cute.group_modes(gA_k, 0, 2),
+                        )
+                        copy_A = partial(cute.copy, tma_atom_a, mcast_mask=a_mcast_mask)
+                    else:
+                        tiled_copy_A = self._make_gmem_tiled_copy_A(
+                            mA_mkl.element_type, self.a_layout, self.num_ab_load_threads
+                        )
+                        tidx = (
+                            cute.arch.thread_idx()[0]
+                            - self.mma_warp_groups * self.num_threads_per_warp_group
+                        )
+                        thr_copy_A = tiled_copy_A.get_slice(tidx)
+                        # (atom_v, CPY_M, 1, STAGE)
+                        tAsA = thr_copy_A.partition_D(sA)
+                        assert tAsA.shape[2] == 1
+                        tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
+                        copy_A = partial(cute.copy, tiled_copy_A)
+                    # TMA load B partition_S/D
+                    b_cta_layout = cute.make_layout(
+                        cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
+                    )
+                    b_cta_crd = cluster_coord_mnk[0]
+                    # ((atom_v, rest_v), STAGE)
+                    # ((atom_v, rest_v), RestK)
+                    tBsB, tBgB_k = cpasync.tma_partition(
                        tma_atom_b,
-
-
-
-
+                        b_cta_crd,
+                        b_cta_layout,
+                        cute.group_modes(sB, 0, 2),
+                        cute.group_modes(gB_k, 0, 2),
                    )
-
-
-
-
-
-
+                    copy_B = partial(cute.copy, tma_atom_b, mcast_mask=b_mcast_mask)
+                    if const_expr(not self.gather_A):
+                        ab_producer_state = self.load_AB(
+                            ab_pipeline,
+                            ab_producer_state,
+                            copy_A,
+                            tAgA_k,
+                            tAsA,
+                            copy_B,
+                            tBgB_k,
+                            tBsB,
+                        )
+                    else:
+                        limit_m = (
+                            mAIdx.shape[0]
+                            if const_expr(cu_seqlens_m is None)
+                            else cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
+                        )
+                        ab_producer_state = self.load_AB_gather_A(
+                            ab_pipeline,
+                            ab_producer_state,
+                            thr_copy_A,
+                            mA_mk,
+                            tAsA,
+                            gAIdx,
+                            copy_B,
+                            tBgB_k,
+                            tBsB,
+                            limit_A=(
+                                limit_m - tile_coord_mnkl[0] * self.tile_shape_mnk[0],
+                                mA_mk.shape[1],
+                            ),
+                        )
+                    if const_expr(epi_load_barrier is not None):
+                        # In the first work tile, the epi load warp will wait for the signal
+                        # from the mainloop load warp to start loading C, to avoid interfering
+                        # with loading A and B.
+                        if do_epi_load_barrier_arrive:
+                            epi_load_barrier.arrive()
+                            do_epi_load_barrier_arrive = cutlass.Boolean(False)
+                    tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
+                    tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+                    work_tile = tile_scheduler.get_current_work()
+                # End of persistent scheduler loop
+                if const_expr(self.pingpong):
+                    # Need to write the tile_idx to smem for the next WG in the pingpong mode
+                    tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+                ab_pipeline.producer_tail(ab_producer_state)
+                if is_scheduler_warp:
+                    tile_scheduler.producer_tail()
+
+            # if const_expr(mC_mnl is not None):
+            #     if warp_idx == self.epi_load_warp_id:
+            #         epi_producer_state = make_pipeline_state(
+            #             pipeline.PipelineUserType.Producer, self.epi_c_stage
+            #         )
+            #         do_epi_load_barrier_wait = cutlass.Boolean(True)
+            #         tile_scheduler = TileSchedulerCls()
+            #         work_tile = tile_scheduler.initial_work_tile_info()
+            #         while work_tile.is_valid_tile:
+            #             tile_coord_mnkl = work_tile.tile_idx
+            #             batch_idx = tile_coord_mnkl[3]
+            #             if const_expr(cu_seqlens_m is not None):
+            #                 mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
+            #             else:
+            #                 mC_mn = mC_mnl[None, None, batch_idx]
+            #             # (bM, bN)
+            #             gC = cute.local_tile(
+            #                 mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
+            #             )
+            #             tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
+            #             bGS_sC, bGS_gC = cpasync.tma_partition(
+            #                 tma_atom_c,
+            #                 0,
+            #                 cute.make_layout(1),
+            #                 cute.group_modes(sC, 0, 2),
+            #                 tCgC_for_tma_partition,
+            #             )
+            #             if do_epi_load_barrier_wait:
+            #                 epi_load_barrier.arrive_and_wait()
+            #                 do_epi_load_barrier_wait = cutlass.Boolean(False)
+            #             epi_tile_num = const_expr(cute.size(tCgC_for_tma_partition, mode=[1]))
+            #             epi_tile_shape = tCgC_for_tma_partition.shape[1]
+            #             for epi_idx in cutlass.range(epi_tile_num, unroll=1):
+            #                 epi_pipeline.producer_acquire(epi_producer_state)
+            #                 # Get the global memory coordinate for the current epi tile
+            #                 epi_tile_layout = cute.make_layout(
+            #                     epi_tile_shape, stride=(epi_tile_shape[1], 1)
+            #                 )
+            #                 gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
+            #                 cute.copy(
+            #                     tma_atom_c,
+            #                     bGS_gC[None, gmem_coord],
+            #                     bGS_sC[None, epi_producer_state.index],
+            #                     tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
+            #                 )
+            #                 # Epi pipeline's producer commit is a NOP
+            #                 epi_pipeline.producer_commit(epi_producer_state)
+            #                 epi_producer_state.advance()
+            #             tile_scheduler.advance_to_next_work()
+            #             work_tile = tile_scheduler.get_current_work()
+            #         # End of persistent scheduler loop
+            #         epi_pipeline.producer_tail(epi_producer_state)
+
+        if warp_idx < self.ab_load_warp_id:
            cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
+            is_tma_warp = cutlass.Boolean(
+                (not self.pingpong and warp_idx == 0)
+                or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
+            )
+            if const_expr(varlen):
+                # initialize tensormap for D
+                tensormap_manager.init_tensormap_from_atom(
+                    tma_atom_d,
+                    tensormap_d_init_ptr,
+                    is_manager_warp=is_tma_warp,
+                )
            # //////////////////////////////////////////////////////////////////////////////
            # Partition global tensor for TiledMMA_A/B/C
            # //////////////////////////////////////////////////////////////////////////////
            tidx, _, _ = cute.arch.thread_idx()
            warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
+            if const_expr(self.pingpong):
+                tidx = tidx % self.num_threads_per_warp_group
            warp_group_thread_layout = cute.make_layout(
-                self.mma_warp_groups
+                self.mma_warp_groups if not self.pingpong else 1,
+                stride=self.num_threads_per_warp_group,
+            )
+            thr_mma = tiled_mma.get_slice(
+                warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
            )
-            thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx))
 
            # //////////////////////////////////////////////////////////////////////////////
            # Make fragments
@@ -705,148 +1145,537 @@ class HopperWgmmaGemmKernel:
 
            acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
            acc = cute.make_fragment(acc_shape, self.acc_dtype)
+            if const_expr(self.fp8_slow_accum):
+                acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
+            else:
+                acc_slow = None
+
+            if const_expr(self.pingpong):
+                if warp_group_idx == 0:
+                    # WG0 needs a start signal at the very beginning
+                    self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
+                    self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
 
-
-
+            ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
+            epi_read_state = make_pipeline_state(
+                pipeline.PipelineUserType.Consumer, self.epi_c_stage
            )
-
-                pipeline.PipelineUserType.
+            epi_producer_state = make_pipeline_state(
+                pipeline.PipelineUserType.Producer, self.epi_c_stage
            )
+            tile_scheduler = TileSchedulerCls()
+            if const_expr(self.pingpong):
+                if warp_idx >= 4:
+                    # Advance 2nd Math WG to the next work tile for the startup
+                    tile_scheduler.advance_to_next_work()
+                    # Advance 2nd Math WG pipeline states to the end of 1st Math WG
+                    ab_read_state.advance_iters(k_tile_cnt)
+                    epi_read_state.advance_iters(c_tile_cnt)
+                    epi_producer_state.advance_iters(c_tile_cnt)
+            work_tile = tile_scheduler.initial_work_tile_info()
+            if const_expr(varlen):
+                # wait tensormap initialization complete before update
+                tensormap_manager.fence_tensormap_initialization()
+            # batch index of last tile
+            last_batch_idx = cutlass.Int32(-1)
+            while work_tile.is_valid_tile:
+                tile_coord_mnkl = work_tile.tile_idx
+                batch_idx = tile_coord_mnkl[3]
+                if const_expr(varlen):
+                    is_group_changed = batch_idx != last_batch_idx
+                    last_batch_idx = batch_idx
+                    if is_group_changed:
+                        # construct tensor D based on real address, shape and stride information
+                        tensormap_manager.update_tensormap_shape(
+                            ((tensormap_d_ptr),),
+                            is_manager_warp=is_tma_warp,
+                            tensormap_smem_ptr=(tensormap_d_smem_ptr,),
+                            shapes=(cu_seqlens_m[batch_idx + 1],),
+                            orders=(0 if const_expr(self.d_layout.is_m_major_c()) else 1,),
+                        )
+
+                ab_read_state, tiled_mma = self.mma(
+                    ab_pipeline,
+                    ab_read_state,
+                    tiled_mma,
+                    tCrA,
+                    tCrB,
+                    acc,
+                    acc_slow,
+                    k_tile_cnt,
+                    warp_group_idx,
+                )
+                if const_expr(self.pingpong):
+                    # Update starting mainloop pipeline state for the next tile
+                    ab_read_state.advance_iters(k_tile_cnt)
 
-
-
-
-
-
-
-
-
+                # /////////////////////////////////////////////////////////////////////////////
+                # EPILOGUE
+                # /////////////////////////////////////////////////////////////////////////////
+                if const_expr(self.pingpong):
+                    self.pingpong_barrier_sync(warp_group_idx, "epi")
+
+                epilogue_barrier = pipeline.NamedBarrier(
+                    barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads
                )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+                # Wait for all warp groups in the thread block to finish, because smem for tensor
+                # A in the mainloop is reused in the epilogue if not persistent.
+                if const_expr(not self.is_persistent):
+                    epilogue_barrier.arrive_and_wait()
+
+                if const_expr(varlen):
+                    # ensure the update to tensormap has completed before using it
+                    if is_group_changed:
+                        if is_tma_warp:
+                            tensormap_manager.fence_tensormap_update(tensormap_d_ptr)
+
+                # Doesn't work with tile_N % 8 == 0 but tile_n % 16 != since this always
+                # get st.matrix with num_matrices=4
+                copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
+                    self.d_layout, elem_ty_d=self.d_dtype, elem_ty_acc=self.acc_dtype
+                )
+                copy_atom_C = cute.make_copy_atom(
+                    warp.StMatrix8x8x16bOp(
+                        self.d_layout.is_m_major_c(),
+                        num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2,
+                    ),
+                    cutlass.Float16,  # this is just to get the right source layout
+                )
+                tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
+                tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom)
+                # (R2S, R2S_M, R2S_N, PIPE_D)
+                thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
+                tRS_sD = thr_copy_r2s.partition_D(sD)
+                # (R2S, R2S_M, R2S_N)
+                tRS_rAcc = tiled_copy_r2s.retile(acc)
+
+                # Allocate D registers.
+                tRS_rD_layout = cute.make_layout(thr_copy_r2s.partition_S(sD).shape[:3])
+                tRS_rD = cute.make_fragment(tRS_rD_layout, self.acc_dtype)
+
+                if const_expr(mC_mnl is not None):
+                    copy_atom_s2r = utils.sm90_get_smem_load_op(self.c_layout, self.c_dtype)
+                    tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
+                    thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
+                    tSR_sC = thr_copy_s2r.partition_S(sC)
+                    tRS_rC = cute.make_fragment(tRS_rD_layout, self.c_dtype)
+                    tSR_rC = thr_copy_s2r.retile(tRS_rC)
+                else:
+                    thr_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
+
+                if const_expr(cu_seqlens_m is not None):
+                    mD_mn_tma = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mD_mnl_tma)
+                else:
+                    mD_mn_tma = mD_mnl_tma[None, None, batch_idx]
+                # (bM, bN)
+                gD = cute.local_tile(
+                    mD_mn_tma, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
+                )
+                tDgD_for_tma_partition = cute.zipped_divide(gD, self.epi_tile)
+                bSG_sD, bSG_gD = cpasync.tma_partition(
+                    tma_atom_d,
+                    0,
+                    cute.make_layout(1),
+                    cute.group_modes(sD, 0, 2),
+                    tDgD_for_tma_partition,
+                )
+
+                if const_expr(mC_mnl is not None):
+                    if const_expr(cu_seqlens_m is not None):
+                        mC_mn = cute.domain_offset((cu_seqlens_m[batch_idx], 0), mC_mnl)
+                    else:
+                        mC_mn = mC_mnl[None, None, batch_idx]
+                    # (bM, bN)
+                    gC = cute.local_tile(
+                        mC_mn, cute.select(self.tile_shape_mnk, [0, 1]), tile_coord_mnkl[:2]
+                    )
+                    tCgC_for_tma_partition = cute.zipped_divide(gC, self.epi_tile)
+                    bGS_sC, bGS_gC = cpasync.tma_partition(
+                        tma_atom_c,
+                        0,
+                        cute.make_layout(1),
+                        cute.group_modes(sC, 0, 2),
+                        tCgC_for_tma_partition,
                    )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                epi_tile_num = const_expr(cute.size(tDgD_for_tma_partition, mode=[1]))
+                epi_tile_shape = tDgD_for_tma_partition.shape[1]
+                num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
+                epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
+
+                if const_expr(mC_mnl is not None):
+                    for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
+                        if is_tma_warp:
+                            epi_pipeline.producer_acquire(epi_producer_state)
+                            # Get the global memory coordinate for the current epi tile
+                            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
+                            cute.copy(
+                                tma_atom_c,
+                                bGS_gC[None, gmem_coord],
+                                bGS_sC[None, epi_producer_state.index],
+                                tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state),
+                            )
+                        # Epi pipeline's producer commit is a NOP
+                        epi_pipeline.producer_commit(epi_producer_state)
+                        epi_producer_state.advance()
+
+                for epi_idx in cutlass.range_constexpr(epi_tile_num):
+                    # Copy from acc to D registers
+                    for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
+                        tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
+                    if const_expr(mC_mnl is not None):
+                        epi_pipeline.consumer_wait(epi_read_state)
+                        cute.copy(
+                            thr_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC
+                        )
+                        # Fence to make sure shared memory read is visible to TMA load
+                        cute.arch.fence_proxy(
+                            cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                        )
+                        cute.arch.sync_warp()
+                        with cute.arch.elect_one():
+                            epi_pipeline.consumer_release(epi_read_state)
+                        epi_read_state.advance()
+                        if const_expr(epi_idx + self.epi_c_stage < epi_tile_num):
+                            if is_tma_warp:
+                                epi_pipeline.producer_acquire(epi_producer_state)
+                                # Get the global memory coordinate for the current epi tile
+                                gmem_coord = epi_tile_layout.get_hier_coord(
+                                    epi_idx + self.epi_c_stage
+                                )
+                                cute.copy(
+                                    tma_atom_c,
+                                    bGS_gC[None, gmem_coord],
+                                    bGS_sC[None, epi_producer_state.index],
+                                    tma_bar_ptr=epi_pipeline.producer_get_barrier(
+                                        epi_producer_state
+                                    ),
+                                )
+                            # Epi pipeline's producer commit is a NOP
+                            epi_pipeline.producer_commit(epi_producer_state)
+                            epi_producer_state.advance()
+                        tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(self.acc_dtype))
+                    # Type conversion
+                    tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
+                    tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
+                    # Copy from D registers to shared memory
+                    epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sD, mode=[3])
+                    cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
+                    # Fence and barrier to make sure shared memory store is visible to TMA store
+                    cute.arch.fence_proxy(
+                        cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                    )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
|
-
|
|
781
|
-
|
|
1369
|
+
epilogue_barrier.arrive_and_wait()
|
|
1370
|
+
# Get the global memory coordinate for the current epi tile
|
|
1371
|
+
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
|
1372
|
+
# Copy from shared memory to global memory
|
|
1373
|
+
if is_tma_warp:
|
|
1374
|
+
if const_expr(varlen):
|
|
1375
|
+
tma_desc_ptr = tensormap_manager.get_tensormap_ptr(
|
|
1376
|
+
tensormap_d_ptr,
|
|
1377
|
+
cute.AddressSpace.generic,
|
|
1378
|
+
)
|
|
1379
|
+
else:
|
|
1380
|
+
tma_desc_ptr = None
|
|
1381
|
+
cute.copy(
|
|
1382
|
+
tma_atom_d,
|
|
1383
|
+
bSG_sD[None, epi_buffer],
|
|
1384
|
+
bSG_gD[None, gmem_coord],
|
|
1385
|
+
tma_desc_ptr=tma_desc_ptr,
|
|
1386
|
+
)
|
|
1387
|
+
cute.arch.cp_async_bulk_commit_group()
|
|
1388
|
+
cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
|
|
1389
|
+
epilogue_barrier.arrive_and_wait()
|
|
1390
|
+
|
|
1391
|
+
if const_expr(self.pingpong):
|
|
1392
|
+
# Update starting load/store pipeline states for the next tile
|
|
1393
|
+
epi_read_state.advance_iters(c_tile_cnt)
|
|
1394
|
+
epi_producer_state.advance_iters(c_tile_cnt)
|
|
1395
|
+
# With pingpong, 2 WGs write two different output tiles to the same smem,
|
|
1396
|
+
# so we have to make sure the smem content is done reading before signaling
|
|
1397
|
+
# the next WG's epilogue.
|
|
1398
|
+
if warp_idx == 0 or warp_idx == 4:
|
|
1399
|
+
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
|
1400
|
+
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
|
|
1401
|
+
|
|
1402
|
+
tile_scheduler.advance_to_next_work(
|
|
1403
|
+
advance_count=1 if not self.pingpong else self.mma_warp_groups
|
|
1404
|
+
)
|
|
1405
|
+
work_tile = tile_scheduler.get_current_work()
|
|
1406
|
+
# End of persistent scheduler loop
|
|
1407
|
+
|
|
1408
|
+
if const_expr(not self.pingpong):
|
|
1409
|
+
if warp_idx == 0:
|
|
1410
|
+
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
|
1411
|
+
|
|
1412
|
+
@cute.jit
|
|
1413
|
+
def load_AB(
|
|
1414
|
+
self,
|
|
1415
|
+
ab_pipeline: cutlass.pipeline.PipelineAsync,
|
|
1416
|
+
ab_producer_state: cutlass.pipeline.PipelineState,
|
|
1417
|
+
copy_A: Callable,
|
|
1418
|
+
tAgA: cute.Tensor,
|
|
1419
|
+
tAsA: cute.Tensor,
|
|
1420
|
+
copy_B: Callable,
|
|
1421
|
+
tBgB: cute.Tensor,
|
|
1422
|
+
tBsB: cute.Tensor,
|
|
1423
|
+
) -> cutlass.pipeline.PipelineState:
|
|
1424
|
+
k_tile_cnt = cute.size(tAgA, mode=[1])
|
|
1425
|
+
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
|
|
1426
|
+
peek_ab_empty_status = cutlass.Boolean(True)
|
|
1427
|
+
if 0 < k_tile_cnt:
|
|
1428
|
+
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
1429
|
+
# /////////////////////////////////////////////////////////////////////////
|
|
1430
|
+
# TMA load
|
|
1431
|
+
# /////////////////////////////////////////////////////////////////////////
|
|
1432
|
+
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
|
|
1433
|
+
# Wait for A/B buffers to be empty before loading into them
|
|
1434
|
+
# Also sets the transaction barrier for the A/B buffers
|
|
1435
|
+
ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
|
|
1436
|
+
tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
|
|
1437
|
+
copy_A(
|
|
1438
|
+
tAgA[None, k_tile],
|
|
1439
|
+
tAsA[None, ab_producer_state.index],
|
|
1440
|
+
tma_bar_ptr=tma_bar_ptr,
|
|
782
1441
|
)
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
1442
|
+
copy_B(
|
|
1443
|
+
tBgB[None, k_tile],
|
|
1444
|
+
tBsB[None, ab_producer_state.index],
|
|
1445
|
+
tma_bar_ptr=tma_bar_ptr,
|
|
786
1446
|
)
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
1447
|
+
# Mainloop pipeline's producer commit is a NOP
|
|
1448
|
+
ab_pipeline.producer_commit(ab_producer_state)
|
|
1449
|
+
ab_producer_state.advance()
|
|
1450
|
+
peek_ab_empty_status = cutlass.Boolean(True)
|
|
1451
|
+
if k_tile + 1 < k_tile_cnt:
|
|
1452
|
+
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
1453
|
+
return ab_producer_state
|
|
1454
|
+
|
|
1455
|
+
@cute.jit
|
|
1456
|
+
def load_AB_gather_A(
|
|
1457
|
+
self,
|
|
1458
|
+
ab_pipeline: cutlass.pipeline.PipelineAsync,
|
|
1459
|
+
ab_producer_state: cutlass.pipeline.PipelineState,
|
|
1460
|
+
thr_copy_A: cute.core.ThrCopy,
|
|
1461
|
+
mA: cute.Tensor,
|
|
1462
|
+
tAsA: cute.Tensor,
|
|
1463
|
+
gAIdx: cute.Tensor,
|
|
1464
|
+
copy_B: Callable,
|
|
1465
|
+
tBgB: cute.Tensor,
|
|
1466
|
+
tBsB: cute.Tensor,
|
|
1467
|
+
limit_A: Tuple[Int32, Int32],
|
|
1468
|
+
) -> cutlass.pipeline.PipelineState:
|
|
1469
|
+
# (atom_v, CPY_M, 1, RestK)
|
|
1470
|
+
limit_m, limit_k = limit_A
|
|
1471
|
+
limit_m = min(limit_m, self.tile_shape_mnk[0]) # To avoid writing beyond smem limit
|
|
1472
|
+
cA = cute.make_identity_tensor(cute.select(self.tile_shape_mnk, [0, 2]))
|
|
1473
|
+
tAcA = thr_copy_A.partition_S(cA)
|
|
1474
|
+
t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
|
|
1475
|
+
# Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
|
|
1476
|
+
# since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
|
|
1477
|
+
# This is so that when we do the comparison, t0AcA is known at compile time.
|
|
1478
|
+
limit_m = limit_m - tAcA[0][0]
|
|
1479
|
+
# Read indices for A
|
|
1480
|
+
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
|
1481
|
+
m_idx = cute.make_fragment(rows_per_thread, Int32)
|
|
1482
|
+
for m in cutlass.range(rows_per_thread):
|
|
1483
|
+
row_idx = tAcA[0, m, 0][0]
|
|
1484
|
+
if t0AcA[0, m, 0][0] < limit_m:
|
|
1485
|
+
m_idx[m] = gAIdx[row_idx]
|
|
1486
|
+
else:
|
|
1487
|
+
m_idx[m] = -1
|
|
1488
|
+
elems_per_load = cute.size(tAsA.shape[0][0])
|
|
1489
|
+
# (m, (bK, RestK))
|
|
1490
|
+
mA_k = cute.logical_divide(mA, (None, self.tile_shape_mnk[2]))
|
|
1491
|
+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
|
1492
|
+
k_tile_cnt = cute.size(tBgB, mode=[1])
|
|
1493
|
+
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
|
|
1494
|
+
peek_ab_empty_status = cutlass.Boolean(True)
|
|
1495
|
+
if 0 < k_tile_cnt:
|
|
1496
|
+
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
1497
|
+
# /////////////////////////////////////////////////////////////////////////
|
|
1498
|
+
# TMA load on B and cp.async on A
|
|
1499
|
+
# /////////////////////////////////////////////////////////////////////////
|
|
1500
|
+
copy_A = partial(cute.copy, thr_copy_A)
|
|
1501
|
+
for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
|
|
1502
|
+
# Wait for A/B buffers to be empty before loading into them
|
|
1503
|
+
# Also sets the transaction barrier for the A/B buffers
|
|
1504
|
+
ab_pipeline.producer_acquire(
|
|
1505
|
+
ab_producer_state,
|
|
1506
|
+
peek_ab_empty_status,
|
|
1507
|
+
# A tiny bit faster to rotate the warp that does TMA
|
|
1508
|
+
is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
|
|
797
1509
|
)
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
1510
|
+
# A bit faster to load B first while we calculate the predicate for A
|
|
1511
|
+
if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
|
|
1512
|
+
copy_B(
|
|
1513
|
+
tBgB[None, k_tile],
|
|
1514
|
+
tBsB[None, ab_producer_state.index],
|
|
1515
|
+
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
|
1516
|
+
)
|
|
1517
|
+
# (m, bK)
|
|
1518
|
+
mA_cur = mA_k[None, (None, k_tile)]
|
|
1519
|
+
for m in cutlass.range_constexpr(tAcA.shape[1]):
|
|
1520
|
+
# (elems_per_load, thread_per_row)
|
|
1521
|
+
mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
|
|
1522
|
+
if t0AcA[0, m, 0][0] < limit_m:
|
|
1523
|
+
# There's only 1 load per row
|
|
1524
|
+
assert cute.size(tAcA.shape, mode=[2]) == 1
|
|
1525
|
+
ki = tAcA[0, 0, 0][1] // elems_per_load
|
|
1526
|
+
copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
|
|
1527
|
+
# This tells mbarrier to track the completion of cp.async
|
|
1528
|
+
ab_pipeline.producer_commit(ab_producer_state)
|
|
1529
|
+
ab_producer_state.advance()
|
|
1530
|
+
peek_ab_empty_status = cutlass.Boolean(True)
|
|
1531
|
+
if k_tile + 1 < k_tile_cnt:
|
|
1532
|
+
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
|
|
1533
|
+
# bound checking in the K dimension on the last k_tile
|
|
1534
|
+
if 0 < k_tile_cnt:
|
|
1535
|
+
k_tile = k_tile_cnt - 1
|
|
1536
|
+
ab_pipeline.producer_acquire(
|
|
1537
|
+
ab_producer_state,
|
|
1538
|
+
peek_ab_empty_status,
|
|
1539
|
+
is_tma_warp=warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps),
|
|
805
1540
|
)
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
# Copy from acc to D registers
|
|
812
|
-
tRS_rD = cute.make_fragment_like(tRS_sD[None, None, None, 0], self.acc_dtype)
|
|
813
|
-
for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
|
|
814
|
-
tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
|
|
815
|
-
# Type conversion
|
|
816
|
-
tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
|
|
817
|
-
tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
|
|
818
|
-
# Copy from D registers to shared memory
|
|
819
|
-
epi_buffer = epi_idx % cute.size(tRS_sD, mode=[3])
|
|
820
|
-
# cute.arch.barrier(barrier_id=1, number_of_threads=self.num_mma_threads)
|
|
821
|
-
cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)])
|
|
822
|
-
cute.arch.fence_proxy(
|
|
823
|
-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
1541
|
+
if warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps):
|
|
1542
|
+
copy_B(
|
|
1543
|
+
tBgB[None, k_tile],
|
|
1544
|
+
tBsB[None, ab_producer_state.index],
|
|
1545
|
+
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
|
824
1546
|
)
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
#
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
1547
|
+
assert tAcA.shape[2] == 1 # there's only 1 load along the K dimension
|
|
1548
|
+
tApA = cute.make_fragment(1, cutlass.Boolean)
|
|
1549
|
+
tApA[0] = tAcA[0, 0, 0][1] < limit_k
|
|
1550
|
+
# (m, bK)
|
|
1551
|
+
mA_cur = mA_k[None, (None, k_tile)]
|
|
1552
|
+
for m in cutlass.range_constexpr(tAcA.shape[1]):
|
|
1553
|
+
# (elems_per_load, thread_per_row)
|
|
1554
|
+
mA_row = cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,))
|
|
1555
|
+
if t0AcA[0, m, 0][0] < limit_m:
|
|
1556
|
+
# There's only 1 load per row
|
|
1557
|
+
assert cute.size(tAcA.shape, mode=[2]) == 1
|
|
1558
|
+
ki = tAcA[0, 0, 0][1] // elems_per_load
|
|
1559
|
+
# copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index], pred=tApA)
|
|
1560
|
+
# TODO
|
|
1561
|
+
copy_A(mA_row[None, ki], tAsA[(None, m), ab_producer_state.index])
|
|
1562
|
+
ab_pipeline.producer_commit(ab_producer_state)
|
|
1563
|
+
ab_producer_state.advance()
|
|
1564
|
+
return ab_producer_state
|
|
1565
|
+
|
|
1566
|
+
@cute.jit
|
|
1567
|
+
def mma(
|
|
1568
|
+
self,
|
|
1569
|
+
ab_pipeline: cutlass.pipeline.PipelineAsync,
|
|
1570
|
+
ab_read_state: cutlass.pipeline.PipelineState,
|
|
1571
|
+
tiled_mma: cute.TiledMma,
|
|
1572
|
+
tCrA: cute.Tensor,
|
|
1573
|
+
tCrB: cute.Tensor,
|
|
1574
|
+
acc: cute.Tensor,
|
|
1575
|
+
acc_slow: Optional[cute.Tensor],
|
|
1576
|
+
k_tile_cnt: Int32,
|
|
1577
|
+
warp_group_idx: Int32,
|
|
1578
|
+
) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
|
|
1579
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
1580
|
+
# Prologue MMAs
|
|
1581
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
1582
|
+
k_pipe_mmas = 1
|
|
1583
|
+
ab_release_state = ab_read_state.clone()
|
|
1584
|
+
num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
|
|
1585
|
+
if const_expr(self.pingpong):
|
|
1586
|
+
self.pingpong_barrier_sync(warp_group_idx, stage="mma")
|
|
1587
|
+
peek_ab_full_status = cutlass.Boolean(True)
|
|
1588
|
+
if 0 < k_tile_cnt:
|
|
1589
|
+
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
|
1590
|
+
tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
|
|
1591
|
+
num_k_blocks = cute.size(tCrA, mode=[2])
|
|
1592
|
+
# TODO: this is probably not correct if k_tile_cnt == 0
|
|
1593
|
+
for k_tile in cutlass.range(num_prologue_mma):
|
|
1594
|
+
# Wait for A/B buffer to be ready
|
|
1595
|
+
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
|
1596
|
+
warpgroup.fence()
|
|
1597
|
+
for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
|
|
1598
|
+
k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
|
|
1599
|
+
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
|
|
1600
|
+
tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
|
|
1601
|
+
warpgroup.commit_group()
|
|
1602
|
+
ab_read_state.advance()
|
|
1603
|
+
peek_ab_full_status = cutlass.Boolean(True)
|
|
1604
|
+
if k_tile + 1 < k_tile_cnt:
|
|
1605
|
+
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
|
1606
|
+
if const_expr(self.fp8_slow_accum):
|
|
1607
|
+
warpgroup.wait_group(0)
|
|
1608
|
+
acc_slow.store(acc.load())
|
|
1609
|
+
|
|
1610
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
1611
|
+
# MAINLOOP
|
|
1612
|
+
# /////////////////////////////////////////////////////////////////////////////
|
|
1613
|
+
for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
|
|
1614
|
+
# Wait for TMA copies to complete
|
|
1615
|
+
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
|
|
1616
|
+
# WGMMA
|
|
1617
|
+
warpgroup.fence()
|
|
1618
|
+
if const_expr(self.fp8_slow_accum):
|
|
1619
|
+
tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
|
|
1620
|
+
for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
|
|
1621
|
+
k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
|
|
1622
|
+
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
|
|
1623
|
+
tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
|
|
1624
|
+
warpgroup.commit_group()
|
|
1625
|
+
# Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
|
|
1626
|
+
if const_expr(not self.fp8_slow_accum):
|
|
1627
|
+
warpgroup.wait_group(k_pipe_mmas)
|
|
1628
|
+
else:
|
|
1629
|
+
warpgroup.wait_group(0)
|
|
1630
|
+
acc_slow.store(acc_slow.load() + acc.load())
|
|
1631
|
+
ab_pipeline.consumer_release(ab_release_state)
|
|
1632
|
+
ab_read_state.advance()
|
|
1633
|
+
ab_release_state.advance()
|
|
1634
|
+
peek_ab_full_status = cutlass.Boolean(True)
|
|
1635
|
+
if k_tile + 1 < k_tile_cnt:
|
|
1636
|
+
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
|
|
1637
|
+
if const_expr(self.pingpong):
|
|
1638
|
+
# Cue for next WG's MMA to start
|
|
1639
|
+
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
|
|
1640
|
+
if const_expr(not self.fp8_slow_accum):
|
|
1641
|
+
# fp8_slow_accum would already called wait_group(0) inside the loop
|
|
1642
|
+
warpgroup.wait_group(0)
|
|
1643
|
+
for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
|
|
1644
|
+
ab_pipeline.consumer_release(ab_release_state)
|
|
1645
|
+
ab_release_state.advance()
|
|
1646
|
+
if const_expr(self.fp8_slow_accum):
|
|
1647
|
+
acc.store(acc_slow.load())
|
|
1648
|
+
# If we don't return the tiled_mma, we get compiler error
|
|
1649
|
+
# "operand #0 does not dominate this use"
|
|
1650
|
+
return ab_read_state, tiled_mma
|
|
1651
|
+
|
|
1652
|
+
def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str):
|
|
1653
|
+
assert stage in ["mma", "epi"]
|
|
1654
|
+
barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
|
|
1655
|
+
cute.arch.barrier(
|
|
1656
|
+
barrier_id=int(barrier) + warp_group_idx,
|
|
1657
|
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
|
1658
|
+
)
|
|
839
1659
|
|
|
840
|
-
|
|
841
|
-
|
|
1660
|
+
def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
|
|
1661
|
+
assert stage in ["mma", "epi"]
|
|
1662
|
+
barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
|
|
1663
|
+
cute.arch.barrier_arrive(
|
|
1664
|
+
barrier_id=int(barrier) + warp_group_idx,
|
|
1665
|
+
number_of_threads=2 * self.num_threads_per_warp_group,
|
|
1666
|
+
)
|
|
842
1667
|
|
|
843
1668
|
@staticmethod
|
|
844
1669
|
def _compute_stages(
|
|
845
1670
|
tile_shape_mnk: Tuple[int, int, int],
|
|
1671
|
+
epi_tile: Optional[Tuple[int, int]],
|
|
846
1672
|
a_dtype: Type[cutlass.Numeric],
|
|
847
1673
|
b_dtype: Type[cutlass.Numeric],
|
|
1674
|
+
d_dtype: Type[cutlass.Numeric],
|
|
1675
|
+
c_dtype: Optional[Type[cutlass.Numeric]],
|
|
848
1676
|
smem_capacity: int,
|
|
849
1677
|
occupancy: int,
|
|
1678
|
+
overlap_sD_sA: bool,
|
|
850
1679
|
) -> Tuple[int, int]:
|
|
851
1680
|
"""Computes the number of stages for A/B/C operands based on heuristics.
|
|
852
1681
|
|
|
@@ -866,10 +1695,15 @@ class HopperWgmmaGemmKernel:
|
|
|
866
1695
|
:rtype: Tuple[int, int]
|
|
867
1696
|
"""
|
|
868
1697
|
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
1698
|
+
epi_stage = 2
|
|
1699
|
+
if overlap_sD_sA:
|
|
1700
|
+
epi_bytes = 0
|
|
1701
|
+
else:
|
|
1702
|
+
d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8
|
|
1703
|
+
epi_bytes = d_bytes_per_stage * epi_stage
|
|
1704
|
+
epi_c_stage = 0 if c_dtype is None else 2
|
|
1705
|
+
if c_dtype is not None:
|
|
1706
|
+
epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
|
|
873
1707
|
|
|
874
1708
|
a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
|
|
875
1709
|
b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
|
|
@@ -878,16 +1712,23 @@ class HopperWgmmaGemmKernel:
|
|
|
878
1712
|
)
|
|
879
1713
|
mbar_helpers_bytes = 1024
|
|
880
1714
|
|
|
881
|
-
|
|
1715
|
+
remaining_bytes = (
|
|
882
1716
|
(smem_capacity - occupancy * 1024) // occupancy - mbar_helpers_bytes - epi_bytes
|
|
883
|
-
)
|
|
884
|
-
|
|
1717
|
+
)
|
|
1718
|
+
ab_stage = remaining_bytes // ab_bytes_per_stage
|
|
1719
|
+
|
|
1720
|
+
# Refine epilogue stages:
|
|
1721
|
+
# Calculate remaining smem after allocating for A/B stages and reserved bytes
|
|
1722
|
+
# Add remaining unused smem to epilogue
|
|
1723
|
+
if not overlap_sD_sA:
|
|
1724
|
+
epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // d_bytes_per_stage
|
|
1725
|
+
return ab_stage, epi_stage, epi_c_stage
|
|
885
1726
|
|
|
886
1727
|
@staticmethod
|
|
887
1728
|
def _sm90_compute_tile_shape_or_override(
|
|
888
1729
|
tile_shape_mnk: Tuple[int, int, int],
|
|
1730
|
+
atom_layout_mnk: Tuple[int, int, int],
|
|
889
1731
|
element_type: Type[cutlass.Numeric],
|
|
890
|
-
is_cooperative: bool = False,
|
|
891
1732
|
epi_tile_override: Tuple[int, int] | None = None,
|
|
892
1733
|
) -> Tuple[int, int]:
|
|
893
1734
|
"""Compute the epilogue tile shape or use override if provided.
|
|
@@ -906,33 +1747,42 @@ class HopperWgmmaGemmKernel:
|
|
|
906
1747
|
"""
|
|
907
1748
|
if epi_tile_override is not None:
|
|
908
1749
|
return epi_tile_override
|
|
909
|
-
if
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
|
|
916
|
-
return (tile_m, tile_n)
|
|
1750
|
+
if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1:
|
|
1751
|
+
tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0]))
|
|
1752
|
+
tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
|
|
1753
|
+
elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1:
|
|
1754
|
+
tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0]))
|
|
1755
|
+
tile_n = math.gcd(32, cute.size(tile_shape_mnk, mode=[1]))
|
|
917
1756
|
else:
|
|
1757
|
+
# In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set
|
|
1758
|
+
# epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the
|
|
1759
|
+
# M dimension first, then move to the N dimension. But the accumulator in registers
|
|
1760
|
+
# iterate along the N dimension first, then move to the M dimension.
|
|
1761
|
+
# We could change the epilogue to accommodate this,
|
|
1762
|
+
# but it's easier to just set epi_tile_m = 64.
|
|
918
1763
|
n_perf = 64 if element_type.width == 8 else 32
|
|
919
1764
|
tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0]))
|
|
920
1765
|
tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1]))
|
|
921
|
-
|
|
1766
|
+
return (tile_m, tile_n)
|
|
922
1767
|
|
|
923
1768
|
@staticmethod
|
|
924
1769
|
def _make_smem_layouts(
|
|
925
1770
|
tile_shape_mnk: Tuple[int, int, int],
|
|
926
1771
|
epi_tile: Tuple[int, int],
|
|
927
1772
|
a_dtype: Type[cutlass.Numeric],
|
|
928
|
-
a_layout: utils.LayoutEnum,
|
|
1773
|
+
a_layout: cutlass.utils.LayoutEnum,
|
|
929
1774
|
b_dtype: Type[cutlass.Numeric],
|
|
930
|
-
b_layout: utils.LayoutEnum,
|
|
1775
|
+
b_layout: cutlass.utils.LayoutEnum,
|
|
931
1776
|
ab_stage: int,
|
|
932
1777
|
d_dtype: Type[cutlass.Numeric],
|
|
933
|
-
d_layout: utils.LayoutEnum,
|
|
1778
|
+
d_layout: cutlass.utils.LayoutEnum,
|
|
934
1779
|
epi_stage: int,
|
|
935
|
-
|
|
1780
|
+
c_dtype: Optional[Type[cutlass.Numeric]],
|
|
1781
|
+
c_layout: Optional[cutlass.utils.LayoutEnum],
|
|
1782
|
+
epi_c_stage: int,
|
|
1783
|
+
) -> Tuple[
|
|
1784
|
+
cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
|
|
1785
|
+
]:
|
|
936
1786
|
"""Create shared memory layouts for A, B, and C tensors.
|
|
937
1787
|
|
|
938
1788
|
:param tile_shape_mnk: CTA tile shape (M,N,K)
|
|
@@ -942,17 +1792,17 @@ class HopperWgmmaGemmKernel:
|
|
|
942
1792
|
:param a_dtype: Data type for matrix A
|
|
943
1793
|
:type a_dtype: type[cutlass.Numeric]
|
|
944
1794
|
:param a_layout: Layout enum for matrix A
|
|
945
|
-
:type a_layout: utils.LayoutEnum
|
|
1795
|
+
:type a_layout: cutlass.utils.LayoutEnum
|
|
946
1796
|
:param b_dtype: Data type for matrix B
|
|
947
1797
|
:type b_dtype: type[cutlass.Numeric]
|
|
948
1798
|
:param b_layout: Layout enum for matrix B
|
|
949
|
-
:type b_layout: utils.LayoutEnum
|
|
1799
|
+
:type b_layout: cutlass.utils.LayoutEnum
|
|
950
1800
|
:param ab_stage: Number of stages for A/B tensors
|
|
951
1801
|
:type ab_stage: int
|
|
952
1802
|
:param d_dtype: Data type for output matrix C
|
|
953
1803
|
:type d_dtype: type[cutlass.Numeric]
|
|
954
1804
|
:param d_layout: Layout enum for the output matrix C
|
|
955
|
-
:type d_layout: utils.LayoutEnum
|
|
1805
|
+
:type d_layout: cutlass.utils.LayoutEnum
|
|
956
1806
|
:param epi_stage: Number of epilogue stages
|
|
957
1807
|
:type epi_stage: int
|
|
958
1808
|
|
|
@@ -998,11 +1848,7 @@ class HopperWgmmaGemmKernel:
|
|
|
998
1848
|
d_smem_shape = epi_tile
|
|
999
1849
|
d_major_mode_size = epi_tile[1] if d_layout.is_n_major_c() else epi_tile[0]
|
|
1000
1850
|
d_smem_layout_atom = warpgroup.make_smem_layout_atom(
|
|
1001
|
-
sm90_utils.get_smem_layout_atom(
|
|
1002
|
-
d_layout,
|
|
1003
|
-
d_dtype,
|
|
1004
|
-
d_major_mode_size,
|
|
1005
|
-
),
|
|
1851
|
+
sm90_utils.get_smem_layout_atom(d_layout, d_dtype, d_major_mode_size),
|
|
1006
1852
|
d_dtype,
|
|
1007
1853
|
)
|
|
1008
1854
|
epi_smem_layout_staged = cute.tile_to_shape(
|
|
@@ -1011,40 +1857,37 @@ class HopperWgmmaGemmKernel:
|
|
|
1011
1857
|
order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2),
|
|
1012
1858
|
)
|
|
1013
1859
|
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
:
|
|
1028
|
-
|
|
1029
|
-
:type cluster_shape_mnk: Tuple[int, int, int]
|
|
1030
|
-
|
|
1031
|
-
:return: Grid shape for kernel launch.
|
|
1032
|
-
:rtype: Tuple[int, int, int]
|
|
1033
|
-
"""
|
|
1860
|
+
if c_dtype is not None:
|
|
1861
|
+
assert c_layout is not None
|
|
1862
|
+
c_smem_shape = epi_tile
|
|
1863
|
+
c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0]
|
|
1864
|
+
c_smem_layout_atom = warpgroup.make_smem_layout_atom(
|
|
1865
|
+
sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size),
|
|
1866
|
+
c_dtype,
|
|
1867
|
+
)
|
|
1868
|
+
epi_c_smem_layout_staged = cute.tile_to_shape(
|
|
1869
|
+
c_smem_layout_atom,
|
|
1870
|
+
cute.append(c_smem_shape, epi_c_stage),
|
|
1871
|
+
order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
|
|
1872
|
+
)
|
|
1873
|
+
else:
|
|
1874
|
+
epi_c_smem_layout_staged = None
|
|
1034
1875
|
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1876
|
+
return (
|
|
1877
|
+
a_smem_layout_staged,
|
|
1878
|
+
b_smem_layout_staged,
|
|
1879
|
+
epi_smem_layout_staged,
|
|
1880
|
+
epi_c_smem_layout_staged,
|
|
1881
|
+
)
|
|
1040
1882
|
|
|
1041
1883
|
@staticmethod
|
|
1042
|
-
def
|
|
1884
|
+
def _make_tma_epi_atoms_and_tensors(
|
|
1043
1885
|
tensor_d: cute.Tensor,
|
|
1044
1886
|
epi_smem_layout_staged: cute.ComposedLayout,
|
|
1045
1887
|
epi_tile: Tuple[int, int],
|
|
1888
|
+
store_or_load: str,
|
|
1046
1889
|
) -> Tuple[cute.CopyAtom, cute.Tensor]:
|
|
1047
|
-
"""Create TMA atoms and tensors for
|
|
1890
|
+
"""Create TMA atoms and tensors for storing D or loading C.
|
|
1048
1891
|
|
|
1049
1892
|
:param tensor_d: Output tensor D
|
|
1050
1893
|
:type tensor_d: cute.Tensor
|
|
@@ -1056,15 +1899,17 @@ class HopperWgmmaGemmKernel:
|
|
|
1056
1899
|
:return: TMA atom and tensor for C
|
|
1057
1900
|
:rtype: Tuple[cute.CopyAtom, cute.Tensor]
|
|
1058
1901
|
"""
|
|
1902
|
+
assert store_or_load in ["load", "store"]
|
|
1059
1903
|
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
|
|
1060
|
-
|
|
1904
|
+
d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
|
|
1905
|
+
op = (
|
|
1906
|
+
cpasync.CopyBulkTensorTileG2SOp()
|
|
1907
|
+
if store_or_load == "load"
|
|
1908
|
+
else cpasync.CopyBulkTensorTileS2GOp()
|
|
1909
|
+
)
|
|
1061
1910
|
tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
|
|
1062
|
-
|
|
1063
|
-
tensor_d,
|
|
1064
|
-
epi_smem_layout,
|
|
1065
|
-
c_cta_v_layout,
|
|
1911
|
+
op, tensor_d, epi_smem_layout, d_cta_v_layout
|
|
1066
1912
|
)
|
|
1067
|
-
|
|
1068
1913
|
return tma_atom_d, tma_tensor_d
|
|
1069
1914
|
|
|
1070
1915
|
@staticmethod
|
|
@@ -1104,6 +1949,31 @@ class HopperWgmmaGemmKernel:
|
|
|
1104
1949
|
)
|
|
1105
1950
|
return tma_atom, tma_tensor
|
|
1106
1951
|
|
|
1952
|
+
def _make_gmem_tiled_copy_A(self, dtype, major_mode, num_threads, copy_bits=128):
|
|
1953
|
+
atom_async_copy = cute.make_copy_atom(
|
|
1954
|
+
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
|
|
1955
|
+
dtype,
|
|
1956
|
+
num_bits_per_copy=copy_bits,
|
|
1957
|
+
)
|
|
1958
|
+
copy_elems = copy_bits // dtype.width
|
|
1959
|
+
shape_dim_1 = cute.size(self.tile_shape_mnk[2]) // copy_elems
|
|
1960
|
+
# thread layout for copy
|
|
1961
|
+
thread_layout = cute.make_layout(
|
|
1962
|
+
(num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
|
|
1963
|
+
)
|
|
1964
|
+
if major_mode != cutlass.utils.LayoutEnum.ROW_MAJOR:
|
|
1965
|
+
shape_dim_0 = cute.size(self.tile_shape_mnk[0]) // copy_elems
|
|
1966
|
+
thread_layout = cute.make_layout(
|
|
1967
|
+
(shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0)
|
|
1968
|
+
)
|
|
1969
|
+
# Value layout for copy
|
|
1970
|
+
value_layout = (
|
|
1971
|
+
cute.make_layout((1, copy_elems))
|
|
1972
|
+
if major_mode == cutlass.utils.LayoutEnum.ROW_MAJOR
|
|
1973
|
+
else cute.make_layout((copy_elems, 1))
|
|
1974
|
+
)
|
|
1975
|
+
return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout)
|
|
1976
|
+
|
|
1107
1977
|
@staticmethod
|
|
1108
1978
|
def is_valid_dtypes(
|
|
1109
1979
|
a_dtype: Type[cutlass.Numeric],
|
|
@@ -1133,7 +2003,6 @@ class HopperWgmmaGemmKernel:
|
|
|
1133
2003
|
:rtype: bool
|
|
1134
2004
|
"""
|
|
1135
2005
|
is_valid = True
|
|
1136
|
-
# tested a_dtype
|
|
1137
2006
|
if a_dtype not in {
|
|
1138
2007
|
cutlass.Float16,
|
|
1139
2008
|
cutlass.BFloat16,
|
|
@@ -1149,7 +2018,6 @@ class HopperWgmmaGemmKernel:
|
|
|
1149
2018
|
cutlass.Float8E5M2,
|
|
1150
2019
|
}:
|
|
1151
2020
|
is_valid = False
|
|
1152
|
-
# tested acc_dtype
|
|
1153
2021
|
if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
|
|
1154
2022
|
is_valid = False
|
|
1155
2023
|
# tested d_dtype
|
|
@@ -1180,21 +2048,28 @@ def run(
|
|
|
1180
2048
|
a_dtype: Type[cutlass.Numeric],
|
|
1181
2049
|
b_dtype: Type[cutlass.Numeric],
|
|
1182
2050
|
d_dtype: Type[cutlass.Numeric],
|
|
2051
|
+
c_dtype: Optional[Type[cutlass.Numeric]],
|
|
1183
2052
|
acc_dtype: Type[cutlass.Numeric],
|
|
1184
2053
|
a_major: str,
|
|
1185
2054
|
b_major: str,
|
|
1186
2055
|
d_major: str,
|
|
2056
|
+
c_major: str,
|
|
1187
2057
|
tile_shape_mnk: Tuple[int, int, int],
|
|
1188
2058
|
cluster_shape_mn: Tuple[int, int],
|
|
1189
2059
|
tolerance: float,
|
|
1190
2060
|
warmup_iterations: int,
|
|
1191
2061
|
iterations: int,
|
|
1192
2062
|
skip_ref_check: bool,
|
|
1193
|
-
|
|
2063
|
+
persistent: bool,
|
|
2064
|
+
dynamic_persistent: bool,
|
|
2065
|
+
pingpong: bool,
|
|
2066
|
+
varlen_m: bool,
|
|
2067
|
+
gather_A: bool,
|
|
2068
|
+
fp8_fast_accum: bool,
|
|
1194
2069
|
**kwargs,
|
|
1195
2070
|
):
|
|
1196
2071
|
"""
|
|
1197
|
-
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
|
|
2072
|
+
Prepare A/B/D/C tensors, launch GPU kernel, and reference checking.
|
|
1198
2073
|
|
|
1199
2074
|
:param mnkl: Problem size (M, N, K, L)
|
|
1200
2075
|
:type mnkl: Tuple[int, int, int, int]
|
|
@@ -1220,22 +2095,22 @@ def run(
|
|
|
1220
2095
|
:type iterations: int, optional
|
|
1221
2096
|
:param skip_ref_check: Whether to skip reference result validation, defaults to False
|
|
1222
2097
|
:type skip_ref_check: bool, optional
|
|
1223
|
-
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
|
1224
|
-
:type use_cold_l2: bool, optional
|
|
1225
|
-
:return: Execution time of the GEMM kernel in microseconds
|
|
1226
|
-
:rtype: float
|
|
1227
2098
|
"""
|
|
1228
2099
|
|
|
2100
|
+
if dynamic_persistent:
|
|
2101
|
+
persistent = True
|
|
2102
|
+
|
|
1229
2103
|
print("Running Hopper Dense GEMM with:")
|
|
1230
2104
|
print(f"mnkl: {mnkl}")
|
|
1231
|
-
print(
|
|
1232
|
-
|
|
2105
|
+
print(
|
|
2106
|
+
f"A dtype: {a_dtype}, B dtype: {b_dtype}, D dtype: {d_dtype}, C_dtype: {c_dtype}, Acc dtype: {acc_dtype}"
|
|
2107
|
+
)
|
|
2108
|
+
print(f"Matrix majors - A: {a_major}, B: {b_major}, D: {d_major}")
|
|
1233
2109
|
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
|
|
1234
2110
|
print(f"Tolerance: {tolerance}")
|
|
1235
2111
|
print(f"Warmup iterations: {warmup_iterations}")
|
|
1236
2112
|
print(f"Iterations: {iterations}")
|
|
1237
2113
|
print(f"Skip reference checking: {skip_ref_check}")
|
|
1238
|
-
print(f"Use cold L2: {use_cold_l2}")
|
|
1239
2114
|
|
|
1240
2115
|
# Unpack parameters
|
|
1241
2116
|
m, n, k, l = mnkl
|
|
@@ -1263,16 +2138,17 @@ def run(
|
|
|
1263
2138
|
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
|
|
1264
2139
|
is_unsigned = dtype in {cutlass.Uint8}
|
|
1265
2140
|
# Temporarily use uint8 as torch does not support fp8 type
|
|
1266
|
-
torch_dtype = (
|
|
1267
|
-
|
|
2141
|
+
torch_dtype = cutlass_torch.dtype(dtype)
|
|
2142
|
+
gen_dtype = (
|
|
2143
|
+
torch_dtype
|
|
1268
2144
|
if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
|
|
1269
|
-
else torch.
|
|
2145
|
+
else torch.bfloat16
|
|
1270
2146
|
)
|
|
1271
2147
|
|
|
1272
2148
|
# Create dtype torch tensor (cpu)
|
|
1273
2149
|
torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor(
|
|
1274
2150
|
shape,
|
|
1275
|
-
|
|
2151
|
+
gen_dtype,
|
|
1276
2152
|
permute_order=permute_order,
|
|
1277
2153
|
# init_type=cutlass.torch.TensorInitType.RANDOM,
|
|
1278
2154
|
# init_config=cutlass.torch.RandomInitConfig(
|
|
@@ -1280,7 +2156,7 @@ def run(
|
|
|
1280
2156
|
# ),
|
|
1281
2157
|
init_type=cutlass.torch.TensorInitType.GAUSSIAN,
|
|
1282
2158
|
init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1),
|
|
1283
|
-
)
|
|
2159
|
+
).to(torch_dtype)
|
|
1284
2160
|
# Create dtype torch tensor (gpu)
|
|
1285
2161
|
torch_tensor = torch_tensor_cpu.cuda()
|
|
1286
2162
|
|
|
@@ -1288,10 +2164,20 @@ def run(
|
|
|
1288
2164
|
f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)
|
|
1289
2165
|
|
|
1290
2166
|
# Create dtype cute tensor (gpu)
|
|
1291
|
-
|
|
2167
|
+
torch_tensor_view = (
|
|
2168
|
+
torch_tensor
|
|
2169
|
+
if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
|
|
2170
|
+
else torch_tensor.view(torch.uint8)
|
|
2171
|
+
)
|
|
2172
|
+
cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16)
|
|
1292
2173
|
cute_tensor.element_type = dtype
|
|
1293
2174
|
if is_dynamic_layout:
|
|
1294
2175
|
cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1))
|
|
2176
|
+
cute_tensor = cute_tensor.mark_compact_shape_dynamic(
|
|
2177
|
+
mode=(1 if not is_mode0_major else 0),
|
|
2178
|
+
stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
|
|
2179
|
+
divisibility=(128 // dtype.width),
|
|
2180
|
+
)
|
|
1295
2181
|
cute_tensor = cutlass.torch.convert_cute_tensor(
|
|
1296
2182
|
f32_torch_tensor,
|
|
1297
2183
|
cute_tensor,
|
|
@@ -1302,24 +2188,142 @@ def run(
|
|
|
1302
2188
|
return f32_torch_tensor, cute_tensor, torch_tensor
|
|
1303
2189
|
|
|
1304
2190
|
a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
|
|
2191
|
+
if gather_A:
|
|
2192
|
+
assert a_major == "k"
|
|
2193
|
+
a_idx = torch.randperm(l * m, dtype=torch.int32, device="cuda")
|
|
2194
|
+
from einops import rearrange
|
|
2195
|
+
|
|
2196
|
+
a = rearrange(rearrange(a, "m k l -> (m l) k")[a_idx.cpu()], "(m l) k -> m k l", m=m)
|
|
2197
|
+
a_torch = rearrange(a_torch, "m k l -> (m l) k")
|
|
2198
|
+
mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
2199
|
+
a_idx_reshaped = rearrange(a_idx, "(m l) -> l m", m=m).contiguous().transpose(0, 1)
|
|
2200
|
+
mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
2201
|
+
else:
|
|
2202
|
+
mAIdx = None
|
|
1305
2203
|
b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
|
|
1306
|
-
|
|
2204
|
+
_, mD, d_torch = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
|
|
2205
|
+
if c_dtype is not None:
|
|
2206
|
+
c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
|
|
2207
|
+
else:
|
|
2208
|
+
c, mC, c_torch = None, None, None
|
|
2209
|
+
if varlen_m:
|
|
2210
|
+
assert a_major == "k"
|
|
2211
|
+
assert d_major == "n"
|
|
2212
|
+
from einops import rearrange
|
|
2213
|
+
|
|
2214
|
+
a, d_torch = [rearrange(t, "m x l -> (l m) x") for t in (a, d_torch)]
|
|
2215
|
+
if not gather_A:
|
|
2216
|
+
(a_torch,) = [rearrange(t, "m x l -> (l m) x") for t in (a_torch,)]
|
|
2217
|
+
if c_dtype is not None:
|
|
2218
|
+
c, c_torch = [rearrange(t, "m x l -> (l m) x") for t in (c, c_torch)]
|
|
2219
|
+
mC = from_dlpack(c_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
2220
|
+
mA = from_dlpack(a_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
2221
|
+
mD = from_dlpack(d_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
2222
|
+
# TODO: generate random cu_seqlens_m
|
|
2223
|
+
cu_seqlens_m = torch.arange(0, l + 1, dtype=torch.int32, device="cuda") * m
|
|
2224
|
+
mCuSeqlensM = from_dlpack(cu_seqlens_m, assumed_align=64).mark_layout_dynamic(leading_dim=0)
|
|
2225
|
+
if gather_A:
|
|
2226
|
+
a_idx_reshaped = rearrange(a_idx_reshaped, "m l -> (l m)")
|
|
2227
|
+
mAIdx = from_dlpack(a_idx_reshaped, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
2228
|
+
else:
|
|
2229
|
+
cu_seqlens_m, mCuSeqlensM = None, None
|
|
2230
|
+
|
|
2231
|
+
if varlen_m: # Need to allocate space in gmem to store tensormaps
|
|
2232
|
+
if not persistent:
|
|
2233
|
+
total_m = m * l
|
|
2234
|
+
block_size_m = tile_shape_mnk[0] * cluster_shape_mnk[0]
|
|
2235
|
+
block_size_n = tile_shape_mnk[1] * cluster_shape_mnk[1]
|
|
2236
|
+
total_clusters_m_max = (total_m + l * (block_size_m - 1)) // block_size_m
|
|
2237
|
+
total_clusters_max = total_clusters_m_max * ((n + block_size_n - 1) // block_size_n)
|
|
2238
|
+
total_ctas = total_clusters_max * cluster_shape_mnk[0] * cluster_shape_mnk[1]
|
|
2239
|
+
else:
|
|
2240
|
+
total_ctas = cutlass.utils.HardwareInfo().get_device_multiprocessor_count()
|
|
2241
|
+
if pingpong:
|
|
2242
|
+
total_ctas *= 2
|
|
2243
|
+
# 128 bytes per tensormap
|
|
2244
|
+
tensormaps_torch = torch.empty(total_ctas, 128 // 8, dtype=torch.int64, device="cuda")
|
|
2245
|
+
tensormaps_tensor = from_dlpack(
|
|
2246
|
+
tensormaps_torch, assumed_align=128
|
|
2247
|
+
).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
|
|
2248
|
+
else:
|
|
2249
|
+
tensormaps_tensor = None
|
|
2250
|
+
|
|
2251
|
+
gemm = HopperWgmmaGemmKernel(
|
|
2252
|
+
acc_dtype,
|
|
2253
|
+
a_dtype,
|
|
2254
|
+
tile_shape_mnk,
|
|
2255
|
+
cluster_shape_mnk,
|
|
2256
|
+
pingpong=pingpong,
|
|
2257
|
+
is_persistent=persistent,
|
|
2258
|
+
fp8_fast_accum=fp8_fast_accum,
|
|
2259
|
+
gather_A=gather_A,
|
|
2260
|
+
)
|
|
1307
2261
|
|
|
1308
|
-
|
|
2262
|
+
# Compute max active clusters on current device
|
|
2263
|
+
if persistent:
|
|
2264
|
+
max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
|
|
2265
|
+
cluster_shape_mn[0] * cluster_shape_mn[1]
|
|
2266
|
+
)
|
|
2267
|
+
if dynamic_persistent:
|
|
2268
|
+
tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda")
|
|
2269
|
+
else:
|
|
2270
|
+
tile_count_semaphore = None
|
|
2271
|
+
# max_active_clusters = 1
|
|
2272
|
+
else:
|
|
2273
|
+
max_active_clusters = 0
|
|
2274
|
+
tile_count_semaphore = None
|
|
1309
2275
|
|
|
1310
|
-
|
|
1311
|
-
stream = cuda.CUstream(torch_stream.cuda_stream)
|
|
2276
|
+
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
1312
2277
|
# compile gemm kernel
|
|
1313
|
-
compiled_gemm = cute.compile(
|
|
2278
|
+
compiled_gemm = cute.compile(
|
|
2279
|
+
gemm,
|
|
2280
|
+
mA,
|
|
2281
|
+
mB,
|
|
2282
|
+
mD,
|
|
2283
|
+
mC,
|
|
2284
|
+
mAIdx,
|
|
2285
|
+
mCuSeqlensM,
|
|
2286
|
+
tensormaps_tensor,
|
|
2287
|
+
make_ptr(Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
|
|
2288
|
+
if tile_count_semaphore is not None
|
|
2289
|
+
else None,
|
|
2290
|
+
max_active_clusters,
|
|
2291
|
+
current_stream,
|
|
2292
|
+
)
|
|
1314
2293
|
|
|
1315
2294
|
if not skip_ref_check:
|
|
1316
2295
|
# execution
|
|
1317
|
-
compiled_gemm(
|
|
2296
|
+
compiled_gemm(
|
|
2297
|
+
mA,
|
|
2298
|
+
mB,
|
|
2299
|
+
mD,
|
|
2300
|
+
mC,
|
|
2301
|
+
mAIdx,
|
|
2302
|
+
mCuSeqlensM,
|
|
2303
|
+
tensormaps_tensor,
|
|
2304
|
+
tile_count_semaphore,
|
|
2305
|
+
max_active_clusters,
|
|
2306
|
+
current_stream,
|
|
2307
|
+
)
|
|
2308
|
+
if tile_count_semaphore is not None and varlen_m:
|
|
2309
|
+
tile_count_semaphore.zero_()
|
|
1318
2310
|
|
|
1319
2311
|
torch.cuda.synchronize()
|
|
1320
2312
|
|
|
1321
2313
|
# Ref check
|
|
1322
|
-
|
|
2314
|
+
if not varlen_m:
|
|
2315
|
+
ref = torch.einsum("mkl,nkl->mnl", a, b)
|
|
2316
|
+
else:
|
|
2317
|
+
ref = torch.cat(
|
|
2318
|
+
[
|
|
2319
|
+
torch.einsum("mk,nk->mn", a[cu_seqlens_m[i] : cu_seqlens_m[i + 1]], b[:, :, i])
|
|
2320
|
+
for i in range(l)
|
|
2321
|
+
],
|
|
2322
|
+
dim=0,
|
|
2323
|
+
)
|
|
2324
|
+
if c is not None:
|
|
2325
|
+
ref = ref + c
|
|
2326
|
+
ref = ref.cpu()
|
|
1323
2327
|
|
|
1324
2328
|
if d_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
|
|
1325
2329
|
# m major: (l, n, m) -> (m, n, l)
|
|
@@ -1333,79 +2337,112 @@ def run(
|
|
|
1333
2337
|
init_type=cutlass_torch.TensorInitType.SKIP,
|
|
1334
2338
|
).cuda()
|
|
1335
2339
|
# Create dtype cute tensor (gpu)
|
|
1336
|
-
|
|
2340
|
+
ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic(
|
|
1337
2341
|
leading_dim=(1 if d_major == "n" else 0)
|
|
1338
2342
|
)
|
|
1339
|
-
|
|
1340
|
-
|
|
2343
|
+
ref_d_tensor.element_type = d_dtype
|
|
2344
|
+
ref_d_tensor = cutlass_torch.convert_cute_tensor(
|
|
1341
2345
|
ref,
|
|
1342
|
-
|
|
2346
|
+
ref_d_tensor,
|
|
1343
2347
|
d_dtype,
|
|
1344
2348
|
is_dynamic_layout=True,
|
|
1345
2349
|
)
|
|
1346
|
-
|
|
2350
|
+
ref_d = f8_torch_tensor.cpu()
|
|
1347
2351
|
else:
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
|
|
1351
|
-
|
|
1352
|
-
def generate_tensors():
|
|
1353
|
-
_, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
|
|
1354
|
-
_, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
|
|
1355
|
-
_, mC_workspace, _ = create_and_permute_tensor(l, m, n, d_major == "m", d_dtype)
|
|
1356
|
-
return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)
|
|
1357
|
-
|
|
1358
|
-
workspace_count = 1
|
|
1359
|
-
if use_cold_l2:
|
|
1360
|
-
one_workspace_bytes = (
|
|
1361
|
-
a_torch.numel() * a_torch.element_size()
|
|
1362
|
-
+ b_torch.numel() * b_torch.element_size()
|
|
1363
|
-
+ c_torch.numel() * c_torch.element_size()
|
|
1364
|
-
)
|
|
1365
|
-
workspace_count = testing.get_workspace_count(
|
|
1366
|
-
one_workspace_bytes, warmup_iterations, iterations
|
|
1367
|
-
)
|
|
2352
|
+
ref_d = ref.to(cutlass_torch.dtype(d_dtype))
|
|
1368
2353
|
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
stream=stream,
|
|
1374
|
-
warmup_iterations=warmup_iterations,
|
|
1375
|
-
iterations=iterations,
|
|
1376
|
-
)
|
|
2354
|
+
out = d_torch.cpu().squeeze()
|
|
2355
|
+
out_ref = ref_d.squeeze()
|
|
2356
|
+
# breakpoint()
|
|
2357
|
+
torch.testing.assert_close(d_torch.cpu(), ref_d, atol=tolerance, rtol=1e-03)
|
|
1377
2358
|
|
|
1378
|
-
|
|
2359
|
+
# return
|
|
1379
2360
|
|
|
1380
|
-
|
|
2361
|
+
from triton.testing import do_bench
|
|
1381
2362
|
|
|
1382
2363
|
flops = 2 * m * n * k * l
|
|
2364
|
+
# Calculate memory bandwidth
|
|
2365
|
+
bytes_A = m * k * l * (a_dtype.width // 8) # A tensor: (m, k, l)
|
|
2366
|
+
bytes_B = n * k * l * (b_dtype.width // 8) # B tensor: (n, k, l)
|
|
2367
|
+
bytes_D = m * n * l * (d_dtype.width // 8) # D tensor: (m, n, l)
|
|
2368
|
+
bytes_C = m * n * l * (c_dtype.width // 8) if c_dtype is not None else 0 # C tensor: (m, n, l)
|
|
2369
|
+
total_bytes = bytes_A + bytes_B + bytes_D + bytes_C # Read A, B, C; Write D
|
|
1383
2370
|
|
|
1384
|
-
repeats =
|
|
1385
|
-
|
|
1386
|
-
warmup = 5
|
|
2371
|
+
repeats = iterations
|
|
2372
|
+
warmup = warmup_iterations
|
|
1387
2373
|
|
|
1388
2374
|
import time
|
|
1389
2375
|
|
|
2376
|
+
if not varlen_m and not gather_A:
|
|
2377
|
+
time.sleep(0.5)
|
|
2378
|
+
if a_dtype.width == 8:
|
|
2379
|
+
assert l == 1
|
|
2380
|
+
scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda")
|
|
2381
|
+
fn_cublas = lambda: torch._scaled_mm(
|
|
2382
|
+
a_torch[:, :, 0],
|
|
2383
|
+
b_torch[:, :, 0].mT,
|
|
2384
|
+
scale_a=scale_ab,
|
|
2385
|
+
scale_b=scale_ab,
|
|
2386
|
+
out_dtype=torch.bfloat16,
|
|
2387
|
+
use_fast_accum=fp8_fast_accum,
|
|
2388
|
+
)
|
|
2389
|
+
else:
|
|
2390
|
+
if c_torch is None:
|
|
2391
|
+
fn_cublas = lambda: torch.matmul(
|
|
2392
|
+
a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT
|
|
2393
|
+
)
|
|
2394
|
+
else:
|
|
2395
|
+
c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32
|
|
2396
|
+
fn_cublas = lambda: torch.baddbmm(
|
|
2397
|
+
c_torch_convert.permute(2, 0, 1),
|
|
2398
|
+
a_torch.permute(2, 0, 1),
|
|
2399
|
+
b_torch.permute(2, 0, 1).mT,
|
|
2400
|
+
)
|
|
2401
|
+
timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
|
|
2402
|
+
tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
|
|
2403
|
+
print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
|
|
2404
|
+
|
|
1390
2405
|
time.sleep(0.5)
|
|
1391
|
-
fn = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT)
|
|
1392
|
-
timing_cublas = do_bench(fn, warmup=warmup, rep=repeats)
|
|
1393
|
-
tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
|
|
1394
|
-
print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
|
|
1395
2406
|
|
|
2407
|
+
def fn():
|
|
2408
|
+
compiled_gemm(
|
|
2409
|
+
mA,
|
|
2410
|
+
mB,
|
|
2411
|
+
mD,
|
|
2412
|
+
mC,
|
|
2413
|
+
mAIdx,
|
|
2414
|
+
mCuSeqlensM,
|
|
2415
|
+
tensormaps_tensor,
|
|
2416
|
+
tile_count_semaphore,
|
|
2417
|
+
max_active_clusters,
|
|
2418
|
+
current_stream,
|
|
2419
|
+
)
|
|
2420
|
+
if tile_count_semaphore is not None and varlen_m:
|
|
2421
|
+
tile_count_semaphore.zero_()
|
|
2422
|
+
|
|
2423
|
+
timing = do_bench(fn, warmup=warmup, rep=repeats)
|
|
2424
|
+
# Idk why but for some cases the 1st run is much slower
|
|
1396
2425
|
time.sleep(0.5)
|
|
1397
|
-
fn = lambda: compiled_gemm(mA, mB, mC, current_stream)
|
|
1398
2426
|
timing = do_bench(fn, warmup=warmup, rep=repeats)
|
|
1399
2427
|
tflops = flops / (timing * 1e9) # Convert to TFlops
|
|
1400
|
-
|
|
2428
|
+
gbps = total_bytes / (timing * 1e6) # Convert to GB/s (1e9 for ms->s, 1e9 for B->GB)
|
|
2429
|
+
print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}, GB/s: {gbps:.0f}")
|
|
2430
|
+
fn()
|
|
1401
2431
|
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
|
|
2432
|
+
if not varlen_m:
|
|
2433
|
+
time.sleep(0.5)
|
|
2434
|
+
timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
|
|
2435
|
+
tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops
|
|
2436
|
+
print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
|
|
2437
|
+
|
|
2438
|
+
from flash_attn.utils.benchmark import pytorch_profiler
|
|
1407
2439
|
|
|
1408
|
-
|
|
2440
|
+
pytorch_profiler(fn_cublas)
|
|
2441
|
+
# pytorch_profiler(torch.sort, d_torch.squeeze(), dim=-1)
|
|
2442
|
+
# pytorch_profiler(torch.compile(torch.sort), d_torch.squeeze(), dim=-1)
|
|
2443
|
+
# pytorch_profiler(torch.topk, d_torch.squeeze(), dim=-1, k=1)
|
|
2444
|
+
# pytorch_profiler(torch.compile(torch.topk), d_torch.squeeze(), dim=-1, k=1)
|
|
2445
|
+
# pytorch_profiler(torch.square, d_torch.squeeze())
|
|
1409
2446
|
|
|
1410
2447
|
|
|
1411
2448
|
if __name__ == "__main__":
|
|
@@ -1415,16 +2452,23 @@ if __name__ == "__main__":
|
|
|
1415
2452
|
args.a_dtype,
|
|
1416
2453
|
args.b_dtype,
|
|
1417
2454
|
args.d_dtype,
|
|
2455
|
+
args.c_dtype,
|
|
1418
2456
|
args.acc_dtype,
|
|
1419
2457
|
args.a_major,
|
|
1420
2458
|
args.b_major,
|
|
1421
2459
|
args.d_major,
|
|
2460
|
+
args.c_major,
|
|
1422
2461
|
args.tile_shape_mnk,
|
|
1423
2462
|
args.cluster_shape_mn,
|
|
1424
2463
|
args.tolerance,
|
|
1425
2464
|
args.warmup_iterations,
|
|
1426
2465
|
args.iterations,
|
|
1427
2466
|
args.skip_ref_check,
|
|
1428
|
-
args.
|
|
2467
|
+
args.persistent,
|
|
2468
|
+
args.dynamic_persistent,
|
|
2469
|
+
args.pingpong,
|
|
2470
|
+
args.varlen_m,
|
|
2471
|
+
args.gather_A,
|
|
2472
|
+
args.fp8_fast_accum,
|
|
1429
2473
|
)
|
|
1430
2474
|
print("PASS")
|