quack-kernels 0.2.5-py3-none-any.whl → 0.2.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +72 -64
- quack/broadcast_utils.py +1 -1
- quack/copy_utils.py +14 -18
- quack/fast_math.py +29 -76
- quack/gemm_act.py +296 -8
- quack/gemm_dact.py +520 -4
- quack/gemm_default_epi.py +4 -4
- quack/gemm_interface.py +363 -0
- quack/gemm_sm100.py +62 -88
- quack/gemm_sm90.py +68 -114
- quack/gemm_symmetric.py +2 -6
- quack/layout_utils.py +2 -4
- quack/linear.py +37 -0
- quack/pipeline.py +59 -89
- quack/reduce.py +2 -2
- quack/rmsnorm.py +1 -3
- quack/sm90_utils.py +5 -3
- quack/sort/bitonic_sort.py +3 -3
- quack/tile_scheduler.py +310 -256
- quack/topk.py +4 -4
- quack/utils.py +76 -40
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/METADATA +2 -2
- quack_kernels-0.2.6.dist-info/RECORD +45 -0
- quack_kernels-0.2.5.dist-info/RECORD +0 -45
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/top_level.txt +0 -0
quack/gemm_dact.py
CHANGED
@@ -1,20 +1,35 @@
 # Copyright (c) 2025, Tri Dao.
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Callable, Type
 from functools import partial
+from dataclasses import dataclass
+import operator
 
+import torch
 from torch import Tensor
 
 import cutlass
 import cutlass.cute as cute
-from cutlass import Float32, const_expr
+from cutlass import Int32, Float32, const_expr
 import cutlass.torch as cutlass_torch
+from cutlass.cute.runtime import from_dlpack
+import cutlass.utils.blackwell_helpers as sm100_utils
 
+import quack.sm90_utils as sm90_utils
+from quack.sm90_utils import partition_for_epilogue
 from quack.gemm_sm90 import GemmSm90
 from quack.gemm_sm100 import GemmSm100
 from quack.gemm_default_epi import GemmDefaultEpiMixin
 from quack.gemm_act import GemmActMixin
-from quack.cute_dsl_utils import
+from quack.cute_dsl_utils import (
+    ArgumentsBase,
+    ParamsBase,
+    torch2cute_dtype_map,
+    get_device_capacity,
+    get_max_active_clusters,
+)
 from quack.gemm_wrapper_utils import GemmWrapperBase
+from quack.varlen_utils import VarlenManager
+import quack.layout_utils as layout_utils
 import quack.activation
 
 
@@ -39,7 +54,7 @@ class GemmDActMixin(GemmActMixin):
         tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
         # If we don't have .shape here, the compiler generates local stores and loads
         if const_expr(params.act_fn is not None):
-            tRS_rPostAct = cute.
+            tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
             if const_expr(self.arch < 100):
                 for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
                     tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
@@ -213,3 +228,504 @@ def gemm_dact(
 
 
 gemm_dact.compile_cache = {}
+
+
+class GemmDGatedMixin(GemmActMixin):
+    # Different from GemmActMixin, here act_bwd_fn must take in 3 arguments (x, y, dout)
+    # and return 3 arguments (dx, dy, out)
+    @dataclass
+    class EpilogueArguments(ArgumentsBase):
+        mPostAct: cute.Tensor
+        act_bwd_fn: cutlass.Constexpr[Callable]
+        implicit_dtype: Type[cutlass.Numeric] = cute.BFloat16
+        # We don't use alpha, beta, mRowVecBroadcast for now
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+        mRowVecBroadcast: Optional[cute.Tensor] = None
+        mColVecBroadcast: Optional[cute.Tensor] = None
+        mColVecReduce: Optional[cute.Tensor] = None
+
+    @dataclass
+    class EpilogueParams(ParamsBase):
+        tma_atom_postact: cute.CopyAtom
+        mPostAct_mnl: cute.Tensor
+        epi_postact_smem_layout_staged: cute.ComposedLayout
+        epi_tile_postact: cute.Tile
+        act_bwd_fn: cutlass.Constexpr[Callable]
+        implicit_dtype: Type[cutlass.Numeric]
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+        mRowVecBroadcast: Optional[cute.Tensor] = None
+        mColVecBroadcast: Optional[cute.Tensor] = None
+        mColVecReduce: Optional[cute.Tensor] = None
+
+    def epi_to_underlying_arguments(
+        self, args: EpilogueArguments, *, loc=None, ip=None
+    ) -> EpilogueParams:
+        self.postact_dtype = args.mPostAct.element_type
+        self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
+        # C and D are implicitly 2 16-bit elements packed into 32 bits, simply for the purpose
+        # for reusing the existing load/store code.
+        assert args.implicit_dtype.width == 16, "GemmDGated only supports 16bit for now"
+        assert self.d_dtype.width == 32, "D storage type must be 32 bit"
+        assert self.c_dtype.width == 32, "C storage type must be 32 bit"
+
+        self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
+        epi_tile_postact = self.epi_tile
+        utils_cls = sm100_utils if self.arch == 100 else sm90_utils
+        epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
+            self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
+        )
+        tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
+            args.mPostAct,
+            epi_postact_smem_layout_staged,
+            epi_tile_postact,
+            op_type="store",
+        )
+        # Assume all strides are divisible by 32 bits except the last stride
+        new_stride = lambda t: tuple(
+            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
+            for s in t.stride
+        )
+        mRowVecBroadcast, mColVecBroadcast, mColVecReduce = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            if t is not None
+            else None
+            for t in (args.mRowVecBroadcast, args.mColVecBroadcast, args.mColVecReduce)
+        ]
+        return self.EpilogueParams(
+            tma_atom_postact,
+            tma_tensor_postact,
+            epi_postact_smem_layout_staged,
+            epi_tile_postact,
+            args.act_bwd_fn,
+            args.implicit_dtype,
+            alpha=args.alpha,
+            beta=args.beta,
+            mRowVecBroadcast=mRowVecBroadcast,
+            mColVecBroadcast=mColVecBroadcast,
+            mColVecReduce=mColVecReduce,
+        )
+
+    @cute.jit
+    def epi_begin(
+        self,
+        params: EpilogueParams,
+        epi_smem_tensors: Tuple[cute.Tensor, ...],
+        epi_tile: cute.Tile,
+        tiled_copy_t2r: Optional[cute.TiledCopy],
+        tiled_copy_r2s: cute.TiledCopy,
+        tile_coord_mnkl: cute.Coord,
+        varlen_manager: VarlenManager,
+        epilogue_barrier: cutlass.pipeline.NamedBarrier,
+        tidx: Int32,
+    ) -> Tuple[cute.Tensor, ...]:
+        epi_tensors = GemmDefaultEpiMixin.epi_begin(
+            self,
+            params,
+            epi_smem_tensors,
+            epi_tile,
+            tiled_copy_t2r,
+            tiled_copy_r2s,
+            tile_coord_mnkl,
+            varlen_manager,
+            epilogue_barrier,
+            tidx,
+        )
+        partition_for_epilogue_fn = partial(
+            partition_for_epilogue,
+            epi_tile=epi_tile,
+            tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
+            tidx=tidx,
+            reference_src=tiled_copy_t2r is None,
+        )
+        tDrColVecReduce = None
+        if const_expr(params.mColVecReduce is not None):
+            colvec_mma_layout = cute.make_layout(self.cta_tile_shape_mnk[:2], stride=(1, 0))
+            tDrColVec_layout = partition_for_epilogue_fn(
+                cute.make_rmem_tensor(colvec_mma_layout, Float32)
+            ).layout
+            tDrColVecReduce = cute.make_rmem_tensor(tDrColVec_layout, Float32)
+            cute.filter_zeros(tDrColVecReduce).fill(0.0)
+        return (*epi_tensors, tDrColVecReduce)
+
+    def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
+        epi_tensors, tDrColVecReduce = epi_tensors[:-1], epi_tensors[-1]
+        epi_loop_tensors = super().epi_begin_loop(params, epi_tensors, epi_coord)
+        tDrColVecReduce_cur = None
+        if const_expr(tDrColVecReduce is not None):
+            tDrColVecReduce_cur = cute.group_modes(tDrColVecReduce, 3, cute.rank(tDrColVecReduce))[
+                None, None, None, epi_coord
+            ]
+        return (*epi_loop_tensors, tDrColVecReduce_cur)
+
+    @cute.jit
+    def epi_visit_subtile(
+        self,
+        params: EpilogueParams,
+        epi_loop_tensors: Tuple[cute.Tensor, ...],
+        tRS_rD: cute.Tensor,
+        tRS_rC: Optional[cute.Tensor] = None,
+    ) -> Optional[cute.Tensor]:
+        alpha, beta, tDrRowVec, tDrColVec, tDrColVecReduce = epi_loop_tensors
+        assert alpha is None and beta is None and tDrRowVec is None  # We don't use these for now
+        assert tRS_rC is not None
+        implicit_dtype = params.implicit_dtype
+        assert implicit_dtype.width == 16, "GemmDGatedMixin only supports 16bit for now"
+        tRS_rXY_f16x2 = cute.recast_tensor(tRS_rC, implicit_dtype)
+        tRS_rXY_f32x2 = cute.make_rmem_tensor(tRS_rXY_f16x2.layout, Float32)
+        tRS_rXY_f32x2.store(tRS_rXY_f16x2.load().to(Float32))
+        tRS_rdXY_f32x2 = cute.make_rmem_tensor_like(tRS_rXY_f32x2, Float32)
+        tRS_rOut = cute.make_rmem_tensor_like(tRS_rD, Float32)
+        tRS_rD_scaled = cute.make_rmem_tensor_like(tRS_rD)
+        if const_expr(tDrColVec is not None):  # Scale D by colvec
+            if const_expr(self.arch < 100):
+                tRS_rD_scaled.store(tRS_rD.load() * tDrColVec.load().to(tRS_rD.element_type))
+            else:
+                tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout)
+                tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVec.layout)
+                tRS_rD_scaled_mn = layout_utils.convert_layout_zero_stride(
+                    tRS_rD_scaled, tDrColVec.layout
+                )
+                for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True):
+                    for n in cutlass.range(
+                        cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True
+                    ):
+                        (
+                            tRS_rD_scaled_mn[m, 2 * n],
+                            tRS_rD_scaled_mn[m, 2 * n + 1],
+                        ) = cute.arch.mul_packed_f32x2(
+                            (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]),
+                            (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]),
+                        )
+        else:
+            tRS_rD_scaled.store(tRS_rD.load())
+        if const_expr(self.arch < 100):
+            for i in cutlass.range(cute.size(tRS_rD)):
+                (
+                    tRS_rdXY_f32x2[2 * i],
+                    tRS_rdXY_f32x2[2 * i + 1],
+                    tRS_rOut[i],
+                ) = params.act_bwd_fn(
+                    tRS_rXY_f32x2[2 * i], tRS_rXY_f32x2[2 * i + 1], tRS_rD_scaled[i]
+                )
+        else:
+            for i in cutlass.range(cute.size(tRS_rD) // 2):
+                (
+                    (tRS_rdXY_f32x2[4 * i], tRS_rdXY_f32x2[4 * i + 2]),
+                    (tRS_rdXY_f32x2[4 * i + 1], tRS_rdXY_f32x2[4 * i + 3]),
+                    (tRS_rOut[2 * i], tRS_rOut[2 * i + 1]),
+                ) = params.act_bwd_fn(
+                    (tRS_rXY_f32x2[4 * i], tRS_rXY_f32x2[4 * i + 2]),
+                    (tRS_rXY_f32x2[4 * i + 1], tRS_rXY_f32x2[4 * i + 3]),
+                    (tRS_rD_scaled[2 * i], tRS_rD_scaled[2 * i + 1]),
+                )
+        if const_expr(tDrColVecReduce is not None):
+            # Need to multiply before D is scaled by colvec_scale
+            if const_expr(self.arch < 100):
+                for i in cutlass.range(cute.size(tDrColVecReduce), unroll_full=True):
+                    tDrColVecReduce[i] += tRS_rOut[i] * tRS_rD[i]
+            else:
+                tDrColVecReduce_mn = layout_utils.convert_layout_zero_stride(
+                    tDrColVecReduce, tDrColVecReduce.layout
+                )
+                tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVecReduce.layout)
+                tRS_rOut_mn = layout_utils.convert_layout_zero_stride(
+                    tRS_rOut, tDrColVecReduce.layout
+                )
+                for m in cutlass.range(cute.size(tDrColVecReduce_mn, mode=[0]), unroll_full=True):
+                    row_sum = cute.arch.mul_packed_f32x2(
+                        (tRS_rD_mn[m, 0], tRS_rD_mn[m, 1]), (tRS_rOut_mn[m, 0], tRS_rOut_mn[m, 1])
+                    )
+                    for n in cutlass.range(
+                        1, cute.size(tDrColVecReduce_mn, mode=[1]) // 2, unroll_full=True
+                    ):
+                        row_sum = cute.arch.fma_packed_f32x2(
+                            (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]),
+                            (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]),
+                            row_sum,
+                        )
+                    tDrColVecReduce_mn[m, 0] += row_sum[0] + row_sum[1]
+
+        if const_expr(tDrColVec is not None):  # Scale Out by colvec
+            if const_expr(self.arch < 100):
+                tRS_rOut.store(tRS_rOut.load() * tDrColVec.load().to(tRS_rD.element_type))
+            else:
+                tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout)
+                tRS_rOut_mn = layout_utils.convert_layout_zero_stride(tRS_rOut, tDrColVec.layout)
+                for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True):
+                    for n in cutlass.range(
+                        cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True
+                    ):
+                        tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1] = (
+                            cute.arch.mul_packed_f32x2(
+                                (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]),
+                                (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]),
+                            )
+                        )
+        # Type conversion
+        tRS_rdXY_f16x2 = cute.make_rmem_tensor(tRS_rdXY_f32x2.layout, implicit_dtype)
+        tRS_rdXY_f16x2.store(tRS_rdXY_f32x2.load().to(implicit_dtype))
+        tRS_rD.store(cute.recast_tensor(tRS_rdXY_f16x2, Float32).load())
+        tRS_rOut_cvt = cute.make_fragment_like(tRS_rOut, self.postact_dtype)
+        tRS_rOut_cvt.store(tRS_rOut.load().to(self.postact_dtype))
+        return tRS_rOut_cvt
+
+    @cute.jit
+    def epi_end(
+        self,
+        params: EpilogueParams,
+        epi_tensors: Tuple[cute.Tensor, ...],
+        epi_tile: cute.Tile,
+        tiled_copy_t2r: Optional[cute.TiledCopy],
+        tiled_copy_r2s: cute.TiledCopy,
+        tile_coord_mnkl: cute.Coord,
+        varlen_manager: VarlenManager,
+        tidx: Int32,
+    ) -> None:
+        partition_for_epilogue_fn = partial(
+            partition_for_epilogue,
+            epi_tile=epi_tile,
+            tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
+            tidx=tidx,
+            reference_src=tiled_copy_t2r is None,
+        )
+        tDrColVecReduce = epi_tensors[-1]
+        tile_M, tile_N = self.cta_tile_shape_mnk[:2]
+        if const_expr(params.mColVecReduce is not None):
+            tDrCVR_flt = cute.filter_zeros(tDrColVecReduce)
+            if const_expr(self.arch != 100):
+                for i in cutlass.range(cute.size(tDrCVR_flt), unroll_full=True):
+                    tDrCVR_flt[i] = cute.arch.warp_reduction(
+                        tDrCVR_flt[i], operator.add, threads_in_group=4
+                    )
+            else:
+                # Don't need warp_reduce since we load from tmem with one thread per row
+                assert self.d_layout.is_n_major_c(), (
+                    "GemmDGated only supports n-major output for now"
+                )
+            batch_idx = tile_coord_mnkl[3]
+            limit_n = (
+                params.mColVecReduce.shape[2]
+                if not varlen_manager.varlen_m
+                else params.mColVecReduce.shape[1]
+            )
+            if tile_coord_mnkl[1] < limit_n:
+                if const_expr(not varlen_manager.varlen_m):
+                    mColVec = params.mColVecReduce[batch_idx, None, tile_coord_mnkl[1]]
+                else:
+                    mColVec = cute.domain_offset(
+                        (varlen_manager.params.cu_seqlens_m[batch_idx],),
+                        params.mColVecReduce[None, tile_coord_mnkl[1]],
+                    )
+                gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
+                limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
+                tDcCV = partition_for_epilogue_fn(cute.make_identity_tensor((tile_M, tile_N)))
+                tDrColVecReduce_m = layout_utils.convert_layout_zero_stride(
+                    tDrColVecReduce, tDrColVecReduce.layout
+                )[None, 0]
+                tDcCV_m = layout_utils.convert_layout_zero_stride(tDcCV, tDrColVecReduce.layout)[
+                    None, 0
+                ]
+                if tDcCV_m[0][1] == 0:
+                    for m in cutlass.range(cute.size(tDcCV_m, mode=[0])):
+                        row_idx = tDcCV_m[m][0]
+                        if row_idx < limit_m:
+                            gColVec[row_idx] = tDrColVecReduce_m[m]
+
+
+class GemmDGatedSm90(GemmDGatedMixin, GemmSm90):
+    pass
+
+
+class GemmDGatedSm100(GemmDGatedMixin, GemmSm100):
+    pass
+
+
+dgate_fn_map = {
+    "swiglu": quack.activation.dswiglu,
+    "swiglu_oai": quack.activation.dswiglu_oai,
+    "reglu": quack.activation.dreglu,
+    "geglu": quack.activation.dgeglu,
+    "glu": quack.activation.dglu,
+}
+
+
+def gemm_dgated(
+    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
+    B: Tensor,  # (l, n, k)
+    Out: Tensor,  # (l, m, 2*n) if n_major or (l, 2*m, n) if m_major, or (total_m, 2*n) if varlen_m
+    PreAct: Tensor,  # (l, m, 2*n) if n_major or (l, 2*m, n) if m_major, or (total_m, 2*n) if varlen_m
+    PostAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+    tile_count_semaphore: Optional[Tensor],  # (1,)
+    activation: Optional[str],
+    tile_M: int,
+    tile_N: int,
+    cluster_M: int,
+    cluster_N: int,
+    pingpong: bool = True,
+    persistent: bool = True,
+    max_swizzle_size: int = 8,
+    colvec_scale: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
+    # (l, m, ceildiv(n, tile_n)), or (total_m, ceildiv(n, tile_n)) if varlen_m
+    colvec_reduce: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
+    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
+) -> None:
+    """If tile_count_semaphore is provided, it must already be zero'ed out."""
+    if cu_seqlens_m is not None:
+        assert persistent, "varlen_m requires persistent=True"
+        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
+        assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
+        assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
+        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
+    gather_A = A_idx is not None
+    if gather_A:
+        assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
+        assert cluster_N == 1, "gather_A requires cluster_N=1"
+    assert activation in dgate_fn_map, f"Unsupported activation {activation}"
+
+    # Special handling for Out and PreAct
+    AB_swapped = not Out.stride(-1) == 1
+    assert Out.dtype == PreAct.dtype
+    implicit_dtype = torch2cute_dtype_map[Out.dtype]
+    assert Out.element_size() == 2, "Out dtype must be fp16 or bf16"
+    assert PreAct.element_size() == 2, "Preact dtype must be fp16 or bf16"
+    # We pretend that Out is (M, N, L) of type fp32 instead of (M, 2N, L) of type f16.
+    # Similarly we pretend that PreAct is (M, N, L) of type fp32 instead of (M, 2N, L) of type f16
+    if cu_seqlens_m is not None or not AB_swapped:
+        # varlen_m (always AB_swapped=False) or normal case with AB_swapped=False
+        Out = Out.view(torch.float32)
+        PreAct = PreAct.view(torch.float32)
+    else:
+        # Normal case with AB_swapped=True
+        Out = Out.mT.view(torch.float32).mT
+        PreAct = PreAct.mT.view(torch.float32).mT
+
+    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+        A,
+        B,
+        Out,
+        PreAct,
+        additional_tensors={"PostAct": PostAct},
+        cu_seqlens_m=cu_seqlens_m,
+        A_idx=A_idx,
+    )
+    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
+    GemmWrapperBase.extract_dtypes(tensor_infos)
+    major_configs = {
+        "A": ("m", "k", "l"),
+        "B": ("n", "k", "l"),
+        "D": ("m", "n", "l"),
+        "C": ("m", "n", "l"),
+        "PostAct": ("m", "n", "l"),
+    }
+    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+    device_capacity = get_device_capacity(A.device)
+    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
+    GemmCls = GemmDGatedSm100 if device_capacity[0] > 9 else GemmDGatedSm90
+
+    acc_dtype = Float32
+    tile_shape_mn = (tile_M, tile_N)
+    cluster_shape_mnk = (cluster_M, cluster_N, 1)
+    if not GemmCls.is_valid_dtypes(
+        tensor_infos["A"].dtype,
+        tensor_infos["B"].dtype,
+        acc_dtype,
+        tensor_infos["D"].dtype,
+        tensor_infos["A"].major,
+        tensor_infos["B"].major,
+    ):
+        raise TypeError("Skipping due to unsupported combination of types and majors")
+
+    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
+    act_fn = dgate_fn_map[activation]
+    epi_args = GemmCls.EpilogueArguments(
+        tensor_infos["PostAct"].cute_tensor,
+        act_fn,
+        implicit_dtype=implicit_dtype,
+        mColVecBroadcast=(
+            from_dlpack(colvec_scale.detach(), assumed_align=4).mark_layout_dynamic(
+                leading_dim=1 if cu_seqlens_m is None else 0
+            )
+            if colvec_scale is not None
+            else None
+        ),
+        mColVecReduce=(
+            from_dlpack(colvec_reduce.detach(), assumed_align=4).mark_layout_dynamic(
+                leading_dim=2 if cu_seqlens_m is None else 1
+            )
+            if colvec_reduce is not None
+            else None
+        ),
+    )
+    scheduler_args = GemmWrapperBase.create_scheduler_args(
+        max_active_clusters, tile_count_semaphore
+    )
+
+    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
+    varlen_args = GemmWrapperBase.create_varlen_args(
+        cu_seqlens_m,
+        None,  # cu_seqlens_k
+        A_idx,
+        max_active_clusters,
+        cluster_shape_mnk,
+        tensor_infos,
+        GemmCls.num_epi_tensormaps,
+        pingpong,
+    )
+
+    current_stream = cutlass_torch.current_stream()
+    compile_key = GemmWrapperBase.get_compile_key(
+        tensor_infos,
+        activation,
+        tile_shape_mn,
+        cluster_shape_mnk,
+        pingpong,
+        persistent,
+        tile_count_semaphore is not None,
+        device_capacity,
+        max_swizzle_size,
+        colvec_scale.dtype if colvec_scale is not None else None,
+        colvec_reduce.dtype if colvec_reduce is not None else None,
+        cu_seqlens_m is not None,
+        A_idx is not None,
+        key_tensor_names=("A", "B", "D", "PostAct", "C"),
+    )
+    cache = gemm_dgated.compile_cache
+    if compile_key not in cache:
+        if device_capacity[0] == 9:
+            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
+        gemm_obj = GemmCls(
+            acc_dtype,
+            tensor_infos["A"].dtype,
+            tile_shape_mn,
+            cluster_shape_mnk,
+            gather_A=gather_A,
+        )
+        cache[compile_key] = cute.compile(
+            gemm_obj,
+            tensor_infos["A"].cute_tensor,
+            tensor_infos["B"].cute_tensor,
+            tensor_infos["D"].cute_tensor,  # Out
+            tensor_infos["C"].cute_tensor,  # PreAct
+            epi_args,
+            scheduler_args,
+            varlen_args,
+            current_stream,
+        )
+    cache[compile_key](
+        tensor_infos["A"].cute_tensor,
+        tensor_infos["B"].cute_tensor,
+        tensor_infos["D"].cute_tensor,  # Out
+        tensor_infos["C"].cute_tensor,  # PreAct
+        epi_args,
+        scheduler_args,
+        varlen_args,
+        current_stream,
+    )
+
+
+gemm_dgated.compile_cache = {}
quack/gemm_default_epi.py
CHANGED
@@ -101,7 +101,7 @@ class GemmDefaultEpiMixin:
             tRVsRV = thr_copy_RV.partition_D(sRowVec)
             tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
             limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
-            tRVpRV = cute.
+            tRVpRV = cute.make_rmem_tensor((1, cute.size(tRVsRV.shape[1])), Boolean)
             for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
                 tRVpRV[0, m] = tRVcRV[0, m] < limit_n
             cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
@@ -132,7 +132,7 @@ class GemmDefaultEpiMixin:
             tCVsCV = thr_copy_CV.partition_D(sColVec)
             tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
             limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
-            tCVpCV = cute.
+            tCVpCV = cute.make_rmem_tensor((1, cute.size(tCVsCV.shape[1])), Boolean)
             for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
                 tCVpCV[0, m] = tCVcCV[0, m] < limit_m
             cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
@@ -158,7 +158,7 @@ class GemmDefaultEpiMixin:
                 None, None, None, epi_coord
             ]
             # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
-            tDrRowVec = cute.
+            tDrRowVec = cute.make_rmem_tensor(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
             cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
             tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
             tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
@@ -169,7 +169,7 @@ class GemmDefaultEpiMixin:
             ]
             # This somehow doesn't work, some dim with stride 0 turns to non-zero stride
             # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
-            tDrColVec = cute.
+            tDrColVec = cute.make_rmem_tensor(tDsColVec_cur.layout, tDsColVec_cur.element_type)
             cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
             tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
             tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))
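All four hunks make the same change: the predicate and staging fragments are now allocated with cute.make_rmem_tensor(...) before the bounds-checked copies. The masking logic itself is unchanged; the snippet below is a plain-Python sketch (not the CuTe DSL, with made-up names and sizes) of what the limit_n predicate does for a partial tile at the edge of the broadcast row vector.

# Plain-Python sketch of the bounds predicate built in the hunks above.
tile_N = 8
n_total = 13                 # hypothetical full length of the row vector
tile_idx = 1                 # second tile along N; only 5 columns remain in range
limit_n = min(n_total - tile_idx * tile_N, tile_N)

row_vec = [float(i) for i in range(n_total)]   # stand-in for the gmem row vector
smem_tile = [0.0] * tile_N                     # stand-in for the smem staging tile

for m in range(tile_N):
    pred = m < limit_n                         # mirrors: tRVpRV[0, m] = tRVcRV[0, m] < limit_n
    if pred:                                   # predicated copy skips out-of-range elements
        smem_tile[m] = row_vec[tile_idx * tile_N + m]

print(smem_tile)                               # the last three entries stay zero for this partial tile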
|