quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/sm100_utils.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Type, Union
4
+
5
+ import cutlass.cute as cute
6
+ import cutlass.utils.blackwell_helpers as sm100_utils_og
7
+ from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
8
+ from cutlass.cutlass_dsl import Numeric, dsl_user_op
9
+
10
+
11
@dsl_user_op
def make_smem_layout_cpasync_a(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    a_dtype: Type[Numeric],
    num_stages: int,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """Build the multi-stage shared-memory layout for operand A of a tiled MMA.

    :param tiled_mma: The tiled MMA used to partition tensor A
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The MMA tile shape (M, N, K)
    :type mma_tiler_mnk: cute.Tile
    :param a_dtype: The element type for tensor A
    :type a_dtype: Type[Numeric]
    :param num_stages: The number of pipeline stages for tensor A (appended as the last mode)
    :type num_stages: int

    :return: SMEM layout for tensor A
    :rtype: Union[cute.Layout, cute.ComposedLayout]
    """
    major_mode = tiled_mma.op.a_major_mode

    # Select the M and K modes of the MMA tiler (the N mode is diced away),
    # then let the MMA partition that (M, K) tile for operand A.
    mk_tiler = cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip)
    part_shape = tiled_mma.partition_shape_A(mk_tiler)

    # Multiply each sub-mode of mode 0 by the matching repeat count (modes 1 and 2)
    # to obtain a flat (MN, K) shape for the A tile.
    mn_extent = cute.size(part_shape[0][0], loc=loc, ip=ip) * part_shape[1]
    k_extent = cute.size(part_shape[0][1], loc=loc, ip=ip) * part_shape[2]
    smem_shape_mn_k = (mn_extent, k_extent)

    atom_kind = sm100_utils_og.get_smem_layout_atom_ab(
        major_mode,
        a_dtype,
        smem_shape_mn_k,
        loc=loc,
        ip=ip,
    )
    layout_atom = sm100_utils_og.make_smem_layout_atom(atom_kind, a_dtype, loc=loc, ip=ip)

    # The mode order used when tiling the atom depends on whether A is K-major.
    if major_mode == OperandMajorMode.K:
        tile_order = (0, 1, 2)
    else:
        tile_order = (1, 0, 2)

    staged_shape = cute.append(smem_shape_mn_k, num_stages, loc=loc, ip=ip)
    return cute.tile_to_shape(layout_atom, staged_shape, order=tile_order, loc=loc, ip=ip)
quack/sm90_utils.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Type, Union, Optional
4
+
5
+ import cutlass
6
+ import cutlass.cute as cute
7
+ import cutlass.utils.hopper_helpers as sm90_utils_og
8
+ from cutlass.cute.nvgpu import warpgroup
9
+ from cutlass.cutlass_dsl import Numeric, dsl_user_op
10
+ from cutlass import Float32, Int32, Boolean, const_expr
11
+ from cutlass.utils import LayoutEnum
12
+
13
+
14
@dsl_user_op
def make_smem_layout(
    dtype: Type[Numeric],
    layout: LayoutEnum,
    tile: cute.Tile,
    stage: Optional[int] = None,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """Build a shared-memory layout for a 2D tile, optionally with a trailing stage mode.

    :param dtype: Element type stored in shared memory; used to pick the layout atom.
    :param layout: Major-ness of the tile (queried via ``is_n_major_c`` / ``is_m_major_c``).
    :param tile: Tile whose per-mode sizes (``product_each`` of its shape) give the target shape.
    :param stage: If not None, appended as an extra trailing mode (number of stages).
    :param loc: Optional MLIR location, forwarded to the cute ops that accept it.
    :param ip: Optional MLIR insertion point, forwarded to the cute ops that accept it.
    :return: The (possibly staged) shared-memory layout.
    """
    shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip)
    # The atom is selected from the extent of the major mode:
    # shape[1] for an N-major layout, otherwise shape[0].
    major_mode_size = shape[1] if layout.is_n_major_c() else shape[0]
    smem_layout_atom = warpgroup.make_smem_layout_atom(
        sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
        dtype,
    )
    # Only add a stage mode when staging was requested.
    target_shape = (
        cute.append(shape, stage, loc=loc, ip=ip) if const_expr(stage is not None) else shape
    )
    # Forward loc/ip (as the sm100 helper does) so the generated IR keeps source locations.
    smem_layout_staged = cute.tile_to_shape(
        smem_layout_atom,
        target_shape,
        order=(1, 0, 2) if layout.is_m_major_c() else (0, 1, 2),
        loc=loc,
        ip=ip,
    )
    return smem_layout_staged


# For compatibility with blackwell_helpers.py
make_smem_layout_epi = make_smem_layout
40
+
41
+
42
@dsl_user_op
def partition_for_epilogue(
    cT: cute.Tensor,
    epi_tile: cute.Tile,
    tiled_copy: cute.TiledCopy,
    tidx: Int32,
    reference_src: bool,  # do register tensors reference the src or dst layout of the tiled copy
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Split ``cT`` into ``epi_tile``-sized pieces and partition them for thread ``tidx``.

    :param cT: Tensor to partition.
    :param epi_tile: Epilogue tile used with ``cute.flat_divide``.
    :param tiled_copy: Tiled copy whose per-thread slice does the partitioning.
    :param tidx: Thread index used to take the slice of ``tiled_copy``.
    :param reference_src: If true, partition with the copy's source side (``partition_S``);
        otherwise with its destination side (``partition_D``).
    :return: The thread's partition; shape (CPY, CPY_M, CPY_N, EPI_M, EPI_N).
    """
    thread_slice = tiled_copy.get_slice(tidx)
    epi_tiles = cute.flat_divide(cT, epi_tile)
    partition_fn = (
        thread_slice.partition_S if const_expr(reference_src) else thread_slice.partition_D
    )
    # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
    return partition_fn(epi_tiles, loc=loc, ip=ip)
60
+
61
+
62
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: cutlass.Constexpr[bool] = False,
    wg_wait: cutlass.Constexpr[int] = 0,
    # A_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """Issue one warpgroup MMA per K block of ``tCrA``/``tCrB``, accumulating into ``acc``.

    :param tiled_mma: Tiled MMA whose ``op`` is used to build a fresh MMA atom.
    :param acc: Accumulator tensor; it is both an input and the output of every ``cute.gemm``.
    :param tCrA: A operand fragments, indexed as ``tCrA[None, None, k]``. The number of
        K blocks is ``cute.size(tCrA.shape[2])``.
    :param tCrB: B operand fragments, indexed as ``tCrB[None, None, k]``.
    :param zero_init: If true, the first MMA runs with ACCUMULATE disabled, so the
        previous contents of ``acc`` are not added in. Every later K block accumulates.
    :param wg_wait: If >= 0, ``warpgroup.wait_group(wg_wait)`` is called after the commit.
        Presumably this waits until at most ``wg_wait`` committed groups remain pending,
        per wgmma.wait_group semantics. If negative, no wait is issued here, and the
        caller is responsible for waiting.
    :param swap_AB: If true, recurse once with ``tCrA`` and ``tCrB`` exchanged. ``acc``
        is passed unchanged.
    """
    if const_expr(swap_AB):
        gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False)
    else:
        warpgroup.fence()
        # We make a new mma_atom since we'll be modifying its attribute (accumulate).
        # Otherwise the compiler complains "operand #0 does not dominate this use"
        mma_atom = cute.make_mma_atom(tiled_mma.op)
        mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init)
        # Unrolled at trace time (range_constexpr) over the K blocks.
        for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
            cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            # After the first K block, always accumulate into acc.
            mma_atom.set(warpgroup.Field.ACCUMULATE, True)
        warpgroup.commit_group()
        if const_expr(wg_wait >= 0):
            warpgroup.wait_group(wg_wait)
87
+
88
+
89
def gemm_zero_init(
    tiled_mma: cute.TiledMma,
    shape: cute.Shape,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> cute.Tensor:
    """Allocate a Float32 accumulator and run ``gemm`` into it with zero initialization.

    :param tiled_mma: Tiled MMA; ``partition_shape_C(shape)`` sizes the accumulator.
    :param shape: Shape passed to ``partition_shape_C``. It is reversed when ``swap_AB``.
    :param tCrA: A operand fragments.
    :param tCrB: B operand fragments.
    :param A_idx: If given, selects ``tCrA[None, None, None, A_idx]``. This presumably
        picks a pipeline stage; confirm against callers.
    :param B_idx: If given, selects ``tCrB[None, None, None, B_idx]``.
    :param wg_wait: Forwarded to ``gemm``. The default of -1 means no wait is issued.
    :param swap_AB: If true, exchange the A/B operands and their indices and reverse ``shape``.
    :return: The newly allocated Float32 accumulator fragment.
    """
    if const_expr(swap_AB):
        # Exchange operand roles in place; this is equivalent to re-entering with
        # (shape[::-1], tCrB, tCrA, B_idx, A_idx, swap_AB=False).
        shape = shape[::-1]
        tCrA, tCrB = tCrB, tCrA
        A_idx, B_idx = B_idx, A_idx
    accumulator = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
    if const_expr(A_idx is None):
        a_frag = tCrA
    else:
        a_frag = tCrA[None, None, None, A_idx]
    if const_expr(B_idx is None):
        b_frag = tCrB
    else:
        b_frag = tCrB[None, None, None, B_idx]
    gemm(tiled_mma, accumulator, a_frag, b_frag, zero_init=True, wg_wait=wg_wait)
    return accumulator
109
+
110
+
111
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: Boolean,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> None:
    """Run ``gemm`` into ``acc``, optionally selecting operand slices by index first.

    :param tiled_mma: Tiled MMA forwarded to ``gemm``.
    :param acc: Accumulator tensor that ``gemm`` updates.
    :param tCrA: A operand fragments.
    :param tCrB: B operand fragments.
    :param zero_init: Forwarded to ``gemm`` as ``zero_init``.
    :param A_idx: If given, selects ``tCrA[None, None, None, A_idx]``. This presumably
        picks a pipeline stage; confirm against callers.
    :param B_idx: If given, selects ``tCrB[None, None, None, B_idx]``.
    :param wg_wait: Forwarded to ``gemm``. The default of -1 means no wait is issued.
    :param swap_AB: If true, exchange the A/B operands and their indices before calling ``gemm``.
    """
    if const_expr(swap_AB):
        # Exchange operand roles in place; this is equivalent to re-entering with
        # (tCrB, tCrA, B_idx, A_idx, swap_AB=False).
        tCrA, tCrB = tCrB, tCrA
        A_idx, B_idx = B_idx, A_idx
    if const_expr(A_idx is None):
        a_frag = tCrA
    else:
        a_frag = tCrA[None, None, None, A_idx]
    if const_expr(B_idx is None):
        b_frag = tCrB
    else:
        b_frag = tCrB[None, None, None, B_idx]
    gemm(tiled_mma, acc, a_frag, b_frag, zero_init=zero_init, wg_wait=wg_wait)