quack-kernels 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/gemm_dact.py CHANGED
@@ -1,20 +1,35 @@
1
1
  # Copyright (c) 2025, Tri Dao.
2
- from typing import Optional, Tuple
2
+ from typing import Optional, Tuple, Callable, Type
3
3
  from functools import partial
4
+ from dataclasses import dataclass
5
+ import operator
4
6
 
7
+ import torch
5
8
  from torch import Tensor
6
9
 
7
10
  import cutlass
8
11
  import cutlass.cute as cute
9
- from cutlass import Float32, const_expr
12
+ from cutlass import Int32, Float32, const_expr
10
13
  import cutlass.torch as cutlass_torch
14
+ from cutlass.cute.runtime import from_dlpack
15
+ import cutlass.utils.blackwell_helpers as sm100_utils
11
16
 
17
+ import quack.sm90_utils as sm90_utils
18
+ from quack.sm90_utils import partition_for_epilogue
12
19
  from quack.gemm_sm90 import GemmSm90
13
20
  from quack.gemm_sm100 import GemmSm100
14
21
  from quack.gemm_default_epi import GemmDefaultEpiMixin
15
22
  from quack.gemm_act import GemmActMixin
16
- from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
23
+ from quack.cute_dsl_utils import (
24
+ ArgumentsBase,
25
+ ParamsBase,
26
+ torch2cute_dtype_map,
27
+ get_device_capacity,
28
+ get_max_active_clusters,
29
+ )
17
30
  from quack.gemm_wrapper_utils import GemmWrapperBase
31
+ from quack.varlen_utils import VarlenManager
32
+ import quack.layout_utils as layout_utils
18
33
  import quack.activation
19
34
 
20
35
 
@@ -39,7 +54,7 @@ class GemmDActMixin(GemmActMixin):
39
54
  tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
40
55
  # If we don't have .shape here, the compiler generates local stores and loads
41
56
  if const_expr(params.act_fn is not None):
42
- tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
57
+ tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
43
58
  if const_expr(self.arch < 100):
44
59
  for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
45
60
  tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
@@ -213,3 +228,504 @@ def gemm_dact(
213
228
 
214
229
 
215
230
  gemm_dact.compile_cache = {}
231
+
232
+
233
class GemmDGatedMixin(GemmActMixin):
    """Epilogue mixin that applies a gated-activation backward function inside a GEMM.

    The C operand is read as pairs of 16-bit values (x, y) packed into 32-bit
    elements, the accumulator D plays the role of dout, and ``act_bwd_fn`` turns
    them into (dx, dy, out).  dx/dy are packed back into D as 16-bit pairs and
    ``out`` is returned from ``epi_visit_subtile`` in the PostAct dtype.
    """

    # Different from GemmActMixin, here act_bwd_fn must take in 3 arguments (x, y, dout)
    # and return 3 arguments (dx, dy, out)
    @dataclass
    class EpilogueArguments(ArgumentsBase):
        # Tensor the PostAct TMA store atom is built for (see epi_to_underlying_arguments).
        mPostAct: cute.Tensor
        # Compile-time callable: (x, y, dout) -> (dx, dy, out).
        act_bwd_fn: cutlass.Constexpr[Callable]
        # 16-bit element type that C and D really hold, two per 32-bit storage element.
        implicit_dtype: Type[cutlass.Numeric] = cute.BFloat16
        # We don't use alpha, beta, mRowVecBroadcast for now
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        mRowVecBroadcast: Optional[cute.Tensor] = None
        # Optional per-row vector; presumably surfaces as tDrColVec in epi_visit_subtile,
        # where it scales D before act_bwd_fn and `out` afterwards -- TODO confirm in base mixin.
        mColVecBroadcast: Optional[cute.Tensor] = None
        # Optional output that epi_end fills with per-row sums of out * dout for each N tile.
        mColVecReduce: Optional[cute.Tensor] = None

    @dataclass
    class EpilogueParams(ParamsBase):
        # TMA store atom / tensor, smem layout and epilogue tile for PostAct.
        tma_atom_postact: cute.CopyAtom
        mPostAct_mnl: cute.Tensor
        epi_postact_smem_layout_staged: cute.ComposedLayout
        epi_tile_postact: cute.Tile
        # Forwarded unchanged from EpilogueArguments.
        act_bwd_fn: cutlass.Constexpr[Callable]
        implicit_dtype: Type[cutlass.Numeric]
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        # The three vectors below are rebuilt with stride-divisibility assumptions attached.
        mRowVecBroadcast: Optional[cute.Tensor] = None
        mColVecBroadcast: Optional[cute.Tensor] = None
        mColVecReduce: Optional[cute.Tensor] = None

    def epi_to_underlying_arguments(
        self, args: EpilogueArguments, *, loc=None, ip=None
    ) -> EpilogueParams:
        """Convert user-facing ``EpilogueArguments`` into ``EpilogueParams``.

        Records the PostAct dtype/layout and the PostAct CTA tile shape on ``self``,
        checks the 16-bit-in-32-bit packing assumptions, builds the shared-memory
        layout and TMA store atom for PostAct, and re-wraps the optional vector
        tensors with divisibility assumptions on their dynamic strides.
        """
        self.postact_dtype = args.mPostAct.element_type
        self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
        # C and D are implicitly 2 16-bit elements packed into 32 bits, simply for the purpose
        # of reusing the existing load/store code.
        assert args.implicit_dtype.width == 16, "GemmDGated only supports 16bit for now"
        assert self.d_dtype.width == 32, "D storage type must be 32 bit"
        assert self.c_dtype.width == 32, "C storage type must be 32 bit"

        self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
        epi_tile_postact = self.epi_tile
        # Only arch == 100 uses the Blackwell helpers; every other arch falls back to SM90.
        utils_cls = sm100_utils if self.arch == 100 else sm90_utils
        epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
            self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
        )
        tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
            args.mPostAct,
            epi_postact_smem_layout_staged,
            epi_tile_postact,
            op_type="store",
        )
        # Assume all strides are divisible by 32 bits except the last stride
        # (static strides are passed through untouched; only dynamic ones get cute.assume).
        new_stride = lambda t: tuple(
            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
            for s in t.stride
        )
        mRowVecBroadcast, mColVecBroadcast, mColVecReduce = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (args.mRowVecBroadcast, args.mColVecBroadcast, args.mColVecReduce)
        ]
        return self.EpilogueParams(
            tma_atom_postact,
            tma_tensor_postact,
            epi_postact_smem_layout_staged,
            epi_tile_postact,
            args.act_bwd_fn,
            args.implicit_dtype,
            alpha=args.alpha,
            beta=args.beta,
            mRowVecBroadcast=mRowVecBroadcast,
            mColVecBroadcast=mColVecBroadcast,
            mColVecReduce=mColVecReduce,
        )

    @cute.jit
    def epi_begin(
        self,
        params: EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        epi_tile: cute.Tile,
        tiled_copy_t2r: Optional[cute.TiledCopy],
        tiled_copy_r2s: cute.TiledCopy,
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tidx: Int32,
    ) -> Tuple[cute.Tensor, ...]:
        """Run the default ``epi_begin`` and append the column-vector-reduce accumulator.

        Returns the tuple produced by ``GemmDefaultEpiMixin.epi_begin`` with one extra
        trailing element: a zero-filled Float32 rmem tensor when ``mColVecReduce`` is
        given, otherwise ``None``.
        """
        epi_tensors = GemmDefaultEpiMixin.epi_begin(
            self,
            params,
            epi_smem_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            epilogue_barrier,
            tidx,
        )
        # Partition with the t2r copy when one exists, otherwise with the r2s copy
        # (in which case the source side is used as the reference).
        partition_for_epilogue_fn = partial(
            partition_for_epilogue,
            epi_tile=epi_tile,
            tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
            tidx=tidx,
            reference_src=tiled_copy_t2r is None,
        )
        tDrColVecReduce = None
        if const_expr(params.mColVecReduce is not None):
            # Stride (1, 0): one value per row of the CTA tile, broadcast along N.
            colvec_mma_layout = cute.make_layout(self.cta_tile_shape_mnk[:2], stride=(1, 0))
            tDrColVec_layout = partition_for_epilogue_fn(
                cute.make_rmem_tensor(colvec_mma_layout, Float32)
            ).layout
            tDrColVecReduce = cute.make_rmem_tensor(tDrColVec_layout, Float32)
            # Zero only the distinct (non-broadcast) elements.
            cute.filter_zeros(tDrColVecReduce).fill(0.0)
        return (*epi_tensors, tDrColVecReduce)

    def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
        """Select the per-subtile view of each epilogue tensor for ``epi_coord``.

        The trailing reduce accumulator added by ``epi_begin`` is split off before
        delegating to the parent implementation, then its slice for the current
        subtile (or ``None``) is appended to the parent's result.
        """
        epi_tensors, tDrColVecReduce = epi_tensors[:-1], epi_tensors[-1]
        epi_loop_tensors = super().epi_begin_loop(params, epi_tensors, epi_coord)
        tDrColVecReduce_cur = None
        if const_expr(tDrColVecReduce is not None):
            # Fold modes 3.. into one so a single epi_coord indexes the subtile.
            tDrColVecReduce_cur = cute.group_modes(tDrColVecReduce, 3, cute.rank(tDrColVecReduce))[
                None, None, None, epi_coord
            ]
        return (*epi_loop_tensors, tDrColVecReduce_cur)

    @cute.jit
    def epi_visit_subtile(
        self,
        params: EpilogueParams,
        epi_loop_tensors: Tuple[cute.Tensor, ...],
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor] = None,
    ) -> Optional[cute.Tensor]:
        """Apply ``act_bwd_fn`` to one epilogue subtile.

        ``tRS_rC`` is reinterpreted as interleaved 16-bit (x, y) pairs and widened to
        Float32; ``tRS_rD`` (optionally scaled by the column vector) is passed as
        dout.  On return ``tRS_rD`` holds (dx, dy) narrowed to ``implicit_dtype`` and
        re-packed into 32-bit elements, and the returned tensor holds ``out``
        (optionally scaled by the column vector) converted to ``self.postact_dtype``.
        If a reduce accumulator is present it is incremented by out * dout using the
        unscaled values.
        """
        alpha, beta, tDrRowVec, tDrColVec, tDrColVecReduce = epi_loop_tensors
        assert alpha is None and beta is None and tDrRowVec is None  # We don't use these for now
        assert tRS_rC is not None
        implicit_dtype = params.implicit_dtype
        assert implicit_dtype.width == 16, "GemmDGatedMixin only supports 16bit for now"
        # View each 32-bit C element as two 16-bit values, then widen to fp32 for the math.
        tRS_rXY_f16x2 = cute.recast_tensor(tRS_rC, implicit_dtype)
        tRS_rXY_f32x2 = cute.make_rmem_tensor(tRS_rXY_f16x2.layout, Float32)
        tRS_rXY_f32x2.store(tRS_rXY_f16x2.load().to(Float32))
        tRS_rdXY_f32x2 = cute.make_rmem_tensor_like(tRS_rXY_f32x2, Float32)
        tRS_rOut = cute.make_rmem_tensor_like(tRS_rD, Float32)
        tRS_rD_scaled = cute.make_rmem_tensor_like(tRS_rD)
        if const_expr(tDrColVec is not None):  # Scale D by colvec
            if const_expr(self.arch < 100):
                tRS_rD_scaled.store(tRS_rD.load() * tDrColVec.load().to(tRS_rD.element_type))
            else:
                tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout)
                tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVec.layout)
                tRS_rD_scaled_mn = layout_utils.convert_layout_zero_stride(
                    tRS_rD_scaled, tDrColVec.layout
                )
                # Two adjacent columns per packed multiply; the same row scale is used for both.
                for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True):
                    for n in cutlass.range(
                        cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True
                    ):
                        (
                            tRS_rD_scaled_mn[m, 2 * n],
                            tRS_rD_scaled_mn[m, 2 * n + 1],
                        ) = cute.arch.mul_packed_f32x2(
                            (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]),
                            (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]),
                        )
        else:
            tRS_rD_scaled.store(tRS_rD.load())
        if const_expr(self.arch < 100):
            # Scalar path: element i of D pairs with x = XY[2i], y = XY[2i + 1].
            for i in cutlass.range(cute.size(tRS_rD)):
                (
                    tRS_rdXY_f32x2[2 * i],
                    tRS_rdXY_f32x2[2 * i + 1],
                    tRS_rOut[i],
                ) = params.act_bwd_fn(
                    tRS_rXY_f32x2[2 * i], tRS_rXY_f32x2[2 * i + 1], tRS_rD_scaled[i]
                )
        else:
            # Paired path: act_bwd_fn receives 2-tuples, covering two D elements per call
            # (x at XY[4i], XY[4i + 2]; y at XY[4i + 1], XY[4i + 3]).
            for i in cutlass.range(cute.size(tRS_rD) // 2):
                (
                    (tRS_rdXY_f32x2[4 * i], tRS_rdXY_f32x2[4 * i + 2]),
                    (tRS_rdXY_f32x2[4 * i + 1], tRS_rdXY_f32x2[4 * i + 3]),
                    (tRS_rOut[2 * i], tRS_rOut[2 * i + 1]),
                ) = params.act_bwd_fn(
                    (tRS_rXY_f32x2[4 * i], tRS_rXY_f32x2[4 * i + 2]),
                    (tRS_rXY_f32x2[4 * i + 1], tRS_rXY_f32x2[4 * i + 3]),
                    (tRS_rD_scaled[2 * i], tRS_rD_scaled[2 * i + 1]),
                )
        if const_expr(tDrColVecReduce is not None):
            # Need to multiply before D is scaled by colvec_scale
            # (this uses the unscaled tRS_rD and the not-yet-scaled tRS_rOut).
            if const_expr(self.arch < 100):
                for i in cutlass.range(cute.size(tDrColVecReduce), unroll_full=True):
                    tDrColVecReduce[i] += tRS_rOut[i] * tRS_rD[i]
            else:
                tDrColVecReduce_mn = layout_utils.convert_layout_zero_stride(
                    tDrColVecReduce, tDrColVecReduce.layout
                )
                tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVecReduce.layout)
                tRS_rOut_mn = layout_utils.convert_layout_zero_stride(
                    tRS_rOut, tDrColVecReduce.layout
                )
                for m in cutlass.range(cute.size(tDrColVecReduce_mn, mode=[0]), unroll_full=True):
                    # Seed the packed accumulator with columns 0 and 1, then FMA the rest.
                    row_sum = cute.arch.mul_packed_f32x2(
                        (tRS_rD_mn[m, 0], tRS_rD_mn[m, 1]), (tRS_rOut_mn[m, 0], tRS_rOut_mn[m, 1])
                    )
                    for n in cutlass.range(
                        1, cute.size(tDrColVecReduce_mn, mode=[1]) // 2, unroll_full=True
                    ):
                        row_sum = cute.arch.fma_packed_f32x2(
                            (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]),
                            (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]),
                            row_sum,
                        )
                    # Collapse the two packed lanes into the single per-row accumulator.
                    tDrColVecReduce_mn[m, 0] += row_sum[0] + row_sum[1]

        if const_expr(tDrColVec is not None):  # Scale Out by colvec
            if const_expr(self.arch < 100):
                tRS_rOut.store(tRS_rOut.load() * tDrColVec.load().to(tRS_rD.element_type))
            else:
                tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout)
                tRS_rOut_mn = layout_utils.convert_layout_zero_stride(tRS_rOut, tDrColVec.layout)
                for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True):
                    for n in cutlass.range(
                        cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True
                    ):
                        tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1] = (
                            cute.arch.mul_packed_f32x2(
                                (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]),
                                (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]),
                            )
                        )
        # Type conversion
        # dx/dy: narrow to the implicit 16-bit type, then reinterpret pairs as the 32-bit D.
        tRS_rdXY_f16x2 = cute.make_rmem_tensor(tRS_rdXY_f32x2.layout, implicit_dtype)
        tRS_rdXY_f16x2.store(tRS_rdXY_f32x2.load().to(implicit_dtype))
        tRS_rD.store(cute.recast_tensor(tRS_rdXY_f16x2, Float32).load())
        tRS_rOut_cvt = cute.make_fragment_like(tRS_rOut, self.postact_dtype)
        tRS_rOut_cvt.store(tRS_rOut.load().to(self.postact_dtype))
        return tRS_rOut_cvt

    @cute.jit
    def epi_end(
        self,
        params: EpilogueParams,
        epi_tensors: Tuple[cute.Tensor, ...],
        epi_tile: cute.Tile,
        tiled_copy_t2r: Optional[cute.TiledCopy],
        tiled_copy_r2s: cute.TiledCopy,
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        tidx: Int32,
    ) -> None:
        """Write the accumulated per-row reduction for this tile to ``mColVecReduce``.

        Does nothing when ``params.mColVecReduce`` is ``None``.  Otherwise the
        accumulator (last element of ``epi_tensors``) is summed across groups of 4
        threads on non-SM100 archs and stored at N-tile index ``tile_coord_mnkl[1]``,
        with rows beyond the valid M extent skipped.
        """
        partition_for_epilogue_fn = partial(
            partition_for_epilogue,
            epi_tile=epi_tile,
            tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
            tidx=tidx,
            reference_src=tiled_copy_t2r is None,
        )
        tDrColVecReduce = epi_tensors[-1]
        tile_M, tile_N = self.cta_tile_shape_mnk[:2]
        if const_expr(params.mColVecReduce is not None):
            tDrCVR_flt = cute.filter_zeros(tDrColVecReduce)
            if const_expr(self.arch != 100):
                for i in cutlass.range(cute.size(tDrCVR_flt), unroll_full=True):
                    tDrCVR_flt[i] = cute.arch.warp_reduction(
                        tDrCVR_flt[i], operator.add, threads_in_group=4
                    )
            else:
                # Don't need warp_reduce since we load from tmem with one thread per row
                assert self.d_layout.is_n_major_c(), (
                    "GemmDGated only supports n-major output for now"
                )
            batch_idx = tile_coord_mnkl[3]
            # Number of N tiles lives in the last mode of mColVecReduce
            # (mode 2 for the batched layout, mode 1 for varlen_m).
            limit_n = (
                params.mColVecReduce.shape[2]
                if not varlen_manager.varlen_m
                else params.mColVecReduce.shape[1]
            )
            if tile_coord_mnkl[1] < limit_n:
                if const_expr(not varlen_manager.varlen_m):
                    mColVec = params.mColVecReduce[batch_idx, None, tile_coord_mnkl[1]]
                else:
                    # varlen_m: rows of all sequences are concatenated, so shift to this
                    # sequence's starting row.
                    mColVec = cute.domain_offset(
                        (varlen_manager.params.cu_seqlens_m[batch_idx],),
                        params.mColVecReduce[None, tile_coord_mnkl[1]],
                    )
                gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
                limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
                # Identity tensor gives each thread the (row, col) coordinates it owns.
                tDcCV = partition_for_epilogue_fn(cute.make_identity_tensor((tile_M, tile_N)))
                tDrColVecReduce_m = layout_utils.convert_layout_zero_stride(
                    tDrColVecReduce, tDrColVecReduce.layout
                )[None, 0]
                tDcCV_m = layout_utils.convert_layout_zero_stride(tDcCV, tDrColVecReduce.layout)[
                    None, 0
                ]
                # Only threads whose first owned coordinate is in column 0 perform the store.
                if tDcCV_m[0][1] == 0:
                    for m in cutlass.range(cute.size(tDcCV_m, mode=[0])):
                        row_idx = tDcCV_m[m][0]
                        if row_idx < limit_m:
                            gColVec[row_idx] = tDrColVecReduce_m[m]
536
+
537
class GemmDGatedSm90(GemmDGatedMixin, GemmSm90):
    """GemmSm90 combined with the GemmDGatedMixin epilogue; adds no members of its own."""
539
+
540
+
541
class GemmDGatedSm100(GemmDGatedMixin, GemmSm100):
    """GemmSm100 combined with the GemmDGatedMixin epilogue; adds no members of its own."""
543
+
544
+
545
# Activation name accepted by gemm_dgated -> backward function handed to the epilogue.
dgate_fn_map = dict(
    swiglu=quack.activation.dswiglu,
    swiglu_oai=quack.activation.dswiglu_oai,
    reglu=quack.activation.dreglu,
    geglu=quack.activation.dgeglu,
    glu=quack.activation.dglu,
)
552
+
553
+
554
def gemm_dgated(
    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
    B: Tensor,  # (l, n, k)
    Out: Tensor,  # (l, m, 2*n) if n_major or (l, 2*m, n) if m_major, or (total_m, 2*n) if varlen_m
    PreAct: Tensor,  # (l, m, 2*n) if n_major or (l, 2*m, n) if m_major, or (total_m, 2*n) if varlen_m
    PostAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
    tile_count_semaphore: Optional[Tensor],  # (1,)
    activation: Optional[str],
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = True,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    colvec_scale: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
    # (l, m, ceildiv(n, tile_n)), or (total_m, ceildiv(n, tile_n)) if varlen_m
    colvec_reduce: Optional[Tensor] = None,
    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
) -> None:
    """Compile (once per configuration) and launch the GEMM + gated-activation backward kernel.

    ``PreAct`` is passed to the kernel as its C input and ``Out`` as its D output.
    Both must be 16-bit tensors of the same dtype; they are reinterpreted as
    float32 so that every 32-bit element carries two 16-bit values.  ``PostAct``
    is the extra epilogue output.  Compiled kernels are memoized in
    ``gemm_dgated.compile_cache``.

    If tile_count_semaphore is provided, it must already be zero'ed out.

    Args:
        A, B: GEMM operands (shapes as annotated in the signature).
        Out: written by the kernel (as D).
        PreAct: read by the kernel (as C).
        PostAct: written by the epilogue.
        tile_count_semaphore: optional scheduler semaphore, pre-zeroed by the caller.
        activation: key of ``dgate_fn_map`` selecting the backward function.
        tile_M, tile_N: CTA tile shape.
        cluster_M, cluster_N: cluster shape (the K extent is fixed to 1).
        pingpong, persistent: forwarded to the kernel class on SM90 only.
        max_swizzle_size: only contributes to the compile-cache key here.
        colvec_scale: optional per-row scale, passed as ``mColVecBroadcast``.
        colvec_reduce: optional per-row reduction output, passed as ``mColVecReduce``.
        cu_seqlens_m: enables the variable-length-M path when given.
        A_idx: enables gather_A; requires ``cu_seqlens_m`` and ``cluster_N == 1``.

    Raises:
        AssertionError: if a layout, dtype, activation or device precondition fails.
        TypeError: if the kernel class rejects the dtype/major combination.
    """
    if cu_seqlens_m is not None:
        assert persistent, "varlen_m requires persistent=True"
        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
        assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
        assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
    gather_A = A_idx is not None
    if gather_A:
        assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
        assert cluster_N == 1, "gather_A requires cluster_N=1"
    assert activation in dgate_fn_map, f"Unsupported activation {activation}"

    # Special handling for Out and PreAct
    AB_swapped = Out.stride(-1) != 1
    assert Out.dtype == PreAct.dtype
    implicit_dtype = torch2cute_dtype_map[Out.dtype]
    assert Out.element_size() == 2, "Out dtype must be fp16 or bf16"
    assert PreAct.element_size() == 2, "Preact dtype must be fp16 or bf16"
    # We pretend that Out is (M, N, L) of type fp32 instead of (M, 2N, L) of type f16.
    # Similarly we pretend that PreAct is (M, N, L) of type fp32 instead of (M, 2N, L) of type f16
    if cu_seqlens_m is not None or not AB_swapped:
        # varlen_m (always AB_swapped=False) or normal case with AB_swapped=False
        Out = Out.view(torch.float32)
        PreAct = PreAct.view(torch.float32)
    else:
        # Normal case with AB_swapped=True
        # (transpose first so the dtype view pairs elements along the contiguous dimension)
        Out = Out.mT.view(torch.float32).mT
        PreAct = PreAct.mT.view(torch.float32).mT

    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A,
        B,
        Out,
        PreAct,
        additional_tensors={"PostAct": PostAct},
        cu_seqlens_m=cu_seqlens_m,
        A_idx=A_idx,
    )
    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
    GemmWrapperBase.extract_dtypes(tensor_infos)
    major_configs = {
        "A": ("m", "k", "l"),
        "B": ("n", "k", "l"),
        "D": ("m", "n", "l"),
        "C": ("m", "n", "l"),
        "PostAct": ("m", "n", "l"),
    }
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmDGatedSm100 if device_capacity[0] > 9 else GemmDGatedSm90

    acc_dtype = Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    if not GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    ):
        raise TypeError("Skipping due to unsupported combination of types and majors")

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
    act_fn = dgate_fn_map[activation]
    epi_args = GemmCls.EpilogueArguments(
        tensor_infos["PostAct"].cute_tensor,
        act_fn,
        implicit_dtype=implicit_dtype,
        mColVecBroadcast=(
            from_dlpack(colvec_scale.detach(), assumed_align=4).mark_layout_dynamic(
                leading_dim=1 if cu_seqlens_m is None else 0
            )
            if colvec_scale is not None
            else None
        ),
        mColVecReduce=(
            from_dlpack(colvec_reduce.detach(), assumed_align=4).mark_layout_dynamic(
                leading_dim=2 if cu_seqlens_m is None else 1
            )
            if colvec_reduce is not None
            else None
        ),
    )
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters, tile_count_semaphore
    )

    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
    varlen_args = GemmWrapperBase.create_varlen_args(
        cu_seqlens_m,
        None,  # cu_seqlens_k
        A_idx,
        max_active_clusters,
        cluster_shape_mnk,
        tensor_infos,
        GemmCls.num_epi_tensormaps,
        pingpong,
    )

    current_stream = cutlass_torch.current_stream()
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        activation,
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        max_swizzle_size,
        colvec_scale.dtype if colvec_scale is not None else None,
        colvec_reduce.dtype if colvec_reduce is not None else None,
        cu_seqlens_m is not None,
        A_idx is not None,
        # Out/PreAct were viewed as float32 above, so the D/C dtypes in tensor_infos
        # cannot tell fp16 from bf16; key on the real 16-bit type explicitly so a
        # kernel specialized for one is never reused for the other.
        implicit_dtype,
        key_tensor_names=("A", "B", "D", "PostAct", "C"),
    )
    cache = gemm_dgated.compile_cache
    if compile_key not in cache:
        if device_capacity[0] == 9:
            # Only the SM90 kernel class takes pingpong / is_persistent here.
            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
        gemm_obj = GemmCls(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(
            gemm_obj,
            tensor_infos["A"].cute_tensor,
            tensor_infos["B"].cute_tensor,
            tensor_infos["D"].cute_tensor,  # Out
            tensor_infos["C"].cute_tensor,  # PreAct
            epi_args,
            scheduler_args,
            varlen_args,
            current_stream,
        )
    cache[compile_key](
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,  # Out
        tensor_infos["C"].cute_tensor,  # PreAct
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )


# Compile-key -> compiled kernel, shared across all calls to gemm_dgated.
gemm_dgated.compile_cache = {}
quack/gemm_default_epi.py CHANGED
@@ -101,7 +101,7 @@ class GemmDefaultEpiMixin:
101
101
  tRVsRV = thr_copy_RV.partition_D(sRowVec)
102
102
  tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
103
103
  limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
104
- tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean)
104
+ tRVpRV = cute.make_rmem_tensor((1, cute.size(tRVsRV.shape[1])), Boolean)
105
105
  for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
106
106
  tRVpRV[0, m] = tRVcRV[0, m] < limit_n
107
107
  cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
@@ -132,7 +132,7 @@ class GemmDefaultEpiMixin:
132
132
  tCVsCV = thr_copy_CV.partition_D(sColVec)
133
133
  tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
134
134
  limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
135
- tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean)
135
+ tCVpCV = cute.make_rmem_tensor((1, cute.size(tCVsCV.shape[1])), Boolean)
136
136
  for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
137
137
  tCVpCV[0, m] = tCVcCV[0, m] < limit_m
138
138
  cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
@@ -158,7 +158,7 @@ class GemmDefaultEpiMixin:
158
158
  None, None, None, epi_coord
159
159
  ]
160
160
  # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
161
- tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
161
+ tDrRowVec = cute.make_rmem_tensor(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
162
162
  cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
163
163
  tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
164
164
  tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
@@ -169,7 +169,7 @@ class GemmDefaultEpiMixin:
169
169
  ]
170
170
  # This somehow doesn't work, some dim with stride 0 turns to non-zero stride
171
171
  # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
172
- tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type)
172
+ tDrColVec = cute.make_rmem_tensor(tDsColVec_cur.layout, tDsColVec_cur.element_type)
173
173
  cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
174
174
  tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
175
175
  tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))