quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/gemm_config.py CHANGED
@@ -1,6 +1,7 @@
1
1
  # Copyright (C) 2025, Tri Dao.
2
2
  import itertools
3
- from typing import Optional, List
3
+ from typing import Optional, List, Literal
4
+ from functools import partial
4
5
  from dataclasses import dataclass
5
6
 
6
7
 
@@ -13,57 +14,82 @@ class GemmConfig:
13
14
  cluster_n: int = 1
14
15
  swap_ab: bool = False
15
16
  # raster_order: int = 1
16
- # max_swizzle_size: int = 8
17
+ max_swizzle_size: int = 8
17
18
 
18
19
 
def get_all_configs(
    device_capacity: Literal[9, 10] = 9,
    epilogue: Optional[str] = None,
    tune_coop: bool = True,
    # tune_raster_order=True,
) -> List[GemmConfig]:
    """Enumerate the GEMM configurations to autotune over for one GPU architecture.

    Args:
        device_capacity: Major compute capability: 9 (SM90) or 10 (SM100).
        epilogue: Optional epilogue name. ``"gated"`` restricts SM90 tiles to
            ``tile_n % 32 == 0``; both ``"lse"`` and ``"gated"`` disable ``swap_ab``
            (and, on SM90, drop cooperative tiles with ``tile_m == 192``).
        tune_coop: SM90 only. When False, cooperative (non-pingpong) tiles are
            excluded and only pingpong tiles are returned.

    Returns:
        A list of ``GemmConfig`` covering the cartesian product of the candidate
        tile shapes, cluster shapes, ``swap_ab`` values and (SM100 only)
        ``max_swizzle_size`` values.

    Raises:
        ValueError: If ``device_capacity`` is not 9 or 10.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # (the function would then silently return None).
    if device_capacity not in (9, 10):
        raise ValueError(f"device_capacity must be 9 or 10, got {device_capacity}")

    if device_capacity == 9:
        tile_n_vals = [128, 144, 160, 176, 192, 208]
        tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
            (128, 224),
            (128, 256),
            # (192, 256),  # Getting IOT instruction (core dumped) in the bwd
        ]
        tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
        if epilogue in ["gated"]:
            tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
            tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
        elif epilogue in ["lse"]:
            tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
        # Each entry is (tile_m, tile_n, pingpong).
        tile_mn_vals = []
        if tune_coop:
            tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
        tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
        # Same candidate clusters for every epilogue (a wider alternative would be
        # [(1, 1), (1, 2), (2, 1)]).
        cluster = [(1, 2), (2, 1)]
        swap_ab_vals = [False] if epilogue in ["lse", "gated"] else [False, True]
        # raster_order / max_swizzle_size are not tuned on SM90; GemmConfig defaults apply.
        return [
            GemmConfig(
                tile_m=tile_m,
                tile_n=tile_n,
                pingpong=pingpong,
                cluster_m=cluster_m,
                cluster_n=cluster_n,
                swap_ab=swap_ab,
            )
            for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
                tile_mn_vals,
                cluster,
                swap_ab_vals,
            )
        ]
    else:  # device_capacity == 10
        tile_n_vals = [128, 160, 192, 224, 256]
        # Each entry is (tile_m, tile_n, (cluster_m, cluster_n)). A narrower alternative
        # would use (2, 1) clusters with tile_m=128 only for tile_n in (128, 192, 256).
        tile_mn_cluster_vals = (
            [(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
            + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
            + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
        )
        swap_ab_vals = [False] if epilogue in ["lse", "gated"] else [False, True]
        max_swizzle_size_vals = [4, 8, 16]
        GemmConfigCls = partial(GemmConfig, pingpong=False)  # There's no pingpong on Sm100
        return [
            GemmConfigCls(
                tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms
            )
            for (m, n, (cm, cn)), sab, ms in itertools.product(
                tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals
            )
        ]
@@ -1,40 +1,57 @@
1
1
  # Copyright (c) 2025, Tri Dao.
2
- from typing import Optional
2
+ from typing import Optional, Tuple
3
+ from functools import partial
3
4
 
4
5
  from torch import Tensor
5
6
 
6
7
  import cutlass
7
8
  import cutlass.cute as cute
8
- from cutlass import const_expr
9
+ from cutlass import Float32, const_expr
9
10
  import cutlass.torch as cutlass_torch
10
11
 
11
- from quack.gemm_act_sm90 import GemmActSm90
12
- from quack.cute_dsl_utils import get_max_active_clusters
12
+ from quack.gemm_sm90 import GemmSm90
13
+ from quack.gemm_sm100 import GemmSm100
14
+ from quack.gemm_default_epi import GemmDefaultEpiMixin
15
+ from quack.gemm_act import GemmActMixin
16
+ from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
13
17
  from quack.gemm_wrapper_utils import GemmWrapperBase
14
18
  import quack.activation
15
19
 
16
20
 
17
- class GemmDActSm90(GemmActSm90):
21
+ class GemmDActMixin(GemmActMixin):
18
22
  # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
19
23
  # and return 2 arguments (dx, out)
20
- EpilogueArguments = GemmActSm90.EpilogueArguments
21
- EpilogueParams = GemmActSm90.EpilogueParams
24
+ EpilogueArguments = GemmActMixin.EpilogueArguments
25
+ EpilogueParams = GemmActMixin.EpilogueParams
22
26
 
23
27
  @cute.jit
24
- def epi_visit_acc_subtile(
28
+ def epi_visit_subtile(
25
29
  self,
26
30
  params: EpilogueParams,
31
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
27
32
  tRS_rD: cute.Tensor,
28
33
  tRS_rC: Optional[cute.Tensor] = None,
29
34
  ) -> Optional[cute.Tensor]:
30
35
  assert tRS_rC is not None
36
+ # We don't add C to the accumulator
37
+ GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
31
38
  tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
32
39
  tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
33
40
  # If we don't have .shape here, the compiler generates local stores and loads
34
41
  if const_expr(params.act_fn is not None):
35
42
  tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
36
- for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
37
- tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
43
+ if const_expr(self.arch < 100):
44
+ for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
45
+ tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
46
+ else:
47
+ for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
48
+ (
49
+ (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
50
+ (tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1]),
51
+ ) = params.act_fn(
52
+ (tRS_rC_acc[2 * i], tRS_rC_acc[2 * i + 1]),
53
+ (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
54
+ )
38
55
  else:
39
56
  tRS_rPostAct = tRS_rC_acc
40
57
  # Type conversion
@@ -43,6 +60,14 @@ class GemmDActSm90(GemmActSm90):
43
60
  return tRS_rPostAct_out
44
61
 
45
62
 
63
class GemmDActSm90(GemmDActMixin, GemmSm90):
    """GemmSm90 combined with the activation-backward epilogue of GemmDActMixin."""
65
+
66
+
67
class GemmDActSm100(GemmDActMixin, GemmSm100):
    """GemmSm100 combined with the activation-backward epilogue of GemmDActMixin."""
69
+
70
+
46
71
  dact_fn_map = {
47
72
  None: None,
48
73
  "relu": quack.activation.drelu,
@@ -51,7 +76,7 @@ dact_fn_map = {
51
76
  }
52
77
 
53
78
 
54
- def gemm_dact_sm90(
79
+ def gemm_dact(
55
80
  A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
56
81
  B: Tensor, # (l, n, k)
57
82
  Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m
@@ -65,6 +90,7 @@ def gemm_dact_sm90(
65
90
  cluster_N: int,
66
91
  pingpong: bool = True,
67
92
  persistent: bool = True,
93
+ max_swizzle_size: int = 8,
68
94
  cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
69
95
  A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
70
96
  ) -> None:
@@ -100,10 +126,14 @@ def gemm_dact_sm90(
100
126
  }
101
127
  GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
102
128
 
103
- acc_dtype = cutlass.Float32
129
+ device_capacity = get_device_capacity(A.device)
130
+ assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
131
+ GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90
132
+
133
+ acc_dtype = Float32
104
134
  tile_shape_mn = (tile_M, tile_N)
105
135
  cluster_shape_mnk = (cluster_M, cluster_N, 1)
106
- if not GemmDActSm90.is_valid_dtypes(
136
+ if not GemmCls.is_valid_dtypes(
107
137
  tensor_infos["A"].dtype,
108
138
  tensor_infos["B"].dtype,
109
139
  acc_dtype,
@@ -116,9 +146,9 @@ def gemm_dact_sm90(
116
146
  max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
117
147
  GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
118
148
  act_fn = dact_fn_map[activation]
119
- epi_args = GemmDActSm90.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
149
+ epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
120
150
  scheduler_args = GemmWrapperBase.create_scheduler_args(
121
- max_active_clusters, tile_count_semaphore
151
+ max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
122
152
  )
123
153
 
124
154
  # Create varlen arguments if needed (assumes persistent=True when varlen_m)
@@ -129,7 +159,7 @@ def gemm_dact_sm90(
129
159
  max_active_clusters,
130
160
  cluster_shape_mnk,
131
161
  tensor_infos,
132
- GemmDActSm90.num_epi_tensormaps,
162
+ GemmCls.num_epi_tensormaps,
133
163
  pingpong,
134
164
  )
135
165
 
@@ -142,19 +172,21 @@ def gemm_dact_sm90(
142
172
  pingpong,
143
173
  persistent,
144
174
  tile_count_semaphore is not None,
175
+ device_capacity,
176
+ max_swizzle_size,
145
177
  cu_seqlens_m is not None,
146
178
  A_idx is not None,
147
179
  key_tensor_names=("A", "B", "D", "PostAct", "C"),
148
180
  )
149
- cache = gemm_dact_sm90.compile_cache
181
+ cache = gemm_dact.compile_cache
150
182
  if compile_key not in cache:
151
- gemm = GemmDActSm90(
183
+ if device_capacity[0] == 9:
184
+ GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
185
+ gemm = GemmCls(
152
186
  acc_dtype,
153
187
  tensor_infos["A"].dtype,
154
188
  tile_shape_mn,
155
189
  cluster_shape_mnk,
156
- pingpong=pingpong,
157
- is_persistent=persistent,
158
190
  gather_A=gather_A,
159
191
  )
160
192
  cache[compile_key] = cute.compile(
@@ -180,4 +212,4 @@ def gemm_dact_sm90(
180
212
  )
181
213
 
182
214
 
183
- gemm_dact_sm90.compile_cache = {}
215
# Compiled kernels keyed by compile_key; filled lazily by gemm_dact on first use of a key.
gemm_dact.compile_cache = dict()
@@ -0,0 +1,259 @@
1
+ # Copyright (c) 2025, Wentao Guo, Tri Dao.
2
+ from typing import Optional, Tuple
3
+ from functools import partial
4
+ from dataclasses import dataclass
5
+
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ from cutlass import Int32, Float32, Boolean, const_expr
10
+
11
+ from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
12
+ from quack.gemm_sm90 import GemmSm90
13
+ from quack.gemm_sm100 import GemmSm100
14
+ from quack.sm90_utils import partition_for_epilogue
15
+ import quack.utils as utils
16
+ import quack.copy_utils as copy_utils
17
+ from quack.varlen_utils import VarlenManager
18
+
19
+
20
class GemmDefaultEpiMixin:
    """Default GEMM epilogue: alpha/beta scaling, optional C, row/column broadcast vectors.

    Meant to be combined with an architecture-specific GEMM class, e.g.
    ``class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90)``. For each epilogue
    subtile the accumulator ``D`` becomes::

        D = alpha * D + beta * C + row_vec + col_vec

    Every term is optional: a missing ``alpha`` skips the scaling, a missing ``beta``
    adds ``C`` unscaled, and a missing ``C`` / broadcast vector is skipped.
    ``row_vec`` holds one value per column (broadcast along M) and ``col_vec`` one
    value per row (broadcast along N).
    """

    # Extra epilogue tensormaps this epilogue needs (none); read by the GEMM wrappers.
    num_epi_tensormaps: int = 0

    @dataclass
    class EpilogueArguments(ArgumentsBase):
        # Scalar or tensor; read in-kernel via utils.load_scalar_or_pointer.
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        # Indexed as [batch, n] in the kernel.
        mRowVecBroadcast: Optional[cute.Tensor] = None
        # Indexed as [batch, m], or offset by cu_seqlens_m (1-D over total_m) when varlen_m.
        mColVecBroadcast: Optional[cute.Tensor] = None
        # Not forwarded into EpilogueParams by this mixin.
        add_to_output: bool = False

    @dataclass
    class EpilogueParams(ParamsBase):
        # Kernel-side counterpart of EpilogueArguments, built by epi_to_underlying_arguments.
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        mRowVecBroadcast: Optional[cute.Tensor] = None
        mColVecBroadcast: Optional[cute.Tensor] = None

    def epi_to_underlying_arguments(
        self, args: EpilogueArguments, *, loc=None, ip=None
    ) -> EpilogueParams:
        """Convert host-side epilogue arguments into kernel epilogue params.

        The broadcast vectors are re-wrapped with every dynamic stride annotated as a
        multiple of 32 bits (``cute.assume``); static strides are kept as-is.
        """
        # Every dynamic stride is assumed divisible by 32 bits. max(1, ...) keeps the
        # divisor valid for element types wider than 32 bits (32 // 64 would be 0).
        new_stride = lambda t: tuple(
            cute.assume(s, divby=max(1, 32 // t.element_type.width))
            if not cute.is_static(s)
            else s
            for s in t.stride
        )
        mRowVecBroadcast, mColVecBroadcast = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
        ]
        return self.EpilogueParams(
            alpha=args.alpha,
            beta=args.beta,
            mRowVecBroadcast=mRowVecBroadcast,
            mColVecBroadcast=mColVecBroadcast,
        )

    @cute.jit
    def epi_begin(
        self,
        params: EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        epi_tile: cute.Tile,
        tiled_copy_t2r: Optional[cute.TiledCopy],
        tiled_copy_r2s: cute.TiledCopy,
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tidx: Int32,
    ):
        """Per-tile epilogue setup, run before iterating over the epilogue subtiles.

        Loads ``alpha``/``beta`` when present, then copies this tile's slice of the
        row/column broadcast vectors into the epilogue smem buffers with predicated
        async copies. Those buffers are partitioned the same way as the epilogue
        fragments.

        Returns:
            ``(alpha, beta, tDsRowVec, tDsColVec)``; entries are None when the
            corresponding input is absent.
        """
        alpha, beta = None, None
        # hasattr guards params types that lack alpha/beta fields (presumably params
        # defined by subclasses).
        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
            alpha = utils.load_scalar_or_pointer(params.alpha)
        if const_expr(hasattr(params, "beta") and params.beta is not None):
            beta = utils.load_scalar_or_pointer(params.beta)
        sRowVec, sColVec, *rest = epi_smem_tensors
        tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
        batch_idx = tile_coord_mnkl[3]
        num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
        # Don't need sync as we assume the previous epilogue has finished

        # Partition by tiled_copy_t2r when one is given, otherwise by tiled_copy_r2s.
        partition_for_epilogue_fn = partial(
            partition_for_epilogue,
            epi_tile=epi_tile,
            tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
            tidx=tidx,
            reference_src=tiled_copy_t2r is None,
        )

        tDsRowVec = None
        if const_expr(params.mRowVecBroadcast is not None):
            rowvec_dtype = params.mRowVecBroadcast.element_type
            # Elements per copy: 32 bits' worth, or a single element for wider dtypes.
            num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width
            thr_copy_RV = copy_utils.tiled_copy_1d(
                params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
            ).get_slice(tidx)
            mRowVec = params.mRowVecBroadcast[batch_idx, None]
            gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],))
            tRVgRV = thr_copy_RV.partition_S(gRowVec)
            tRVsRV = thr_copy_RV.partition_D(sRowVec)
            tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
            # Predicate off columns past the end of the row vector (partial last N tile).
            limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
            tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean)
            for n in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
                tRVpRV[0, n] = tRVcRV[0, n] < limit_n
            cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
            # View the 1-D row vector as (tile_M, tile_N) with stride (0, 1): broadcast over M.
            # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
            tDsRowVec = partition_for_epilogue_fn(
                cute.make_tensor(
                    sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1))
                )
            )
            if const_expr(tiled_copy_t2r is not None):
                tDsRowVec = tiled_copy_r2s.retile(tDsRowVec)

        tDsColVec = None
        if const_expr(params.mColVecBroadcast is not None):
            colvec_dtype = params.mColVecBroadcast.element_type
            # Elements per copy: 32 bits' worth, or a single element for wider dtypes.
            num_copy_elems = const_expr(max(32, colvec_dtype.width)) // colvec_dtype.width
            thr_copy_CV = copy_utils.tiled_copy_1d(
                params.mColVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
            ).get_slice(tidx)
            if const_expr(not varlen_manager.varlen_m):
                mColVec = params.mColVecBroadcast[batch_idx, None]
            else:
                # Varlen: the column vector is packed over total_m; shift to this batch's start.
                mColVec = cute.domain_offset(
                    (varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast
                )
            gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
            tCVgCV = thr_copy_CV.partition_S(gColVec)
            tCVsCV = thr_copy_CV.partition_D(sColVec)
            tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
            # Predicate off rows past this batch's M extent (partial last M tile).
            limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
            tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean)
            for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
                tCVpCV[0, m] = tCVcCV[0, m] < limit_m
            cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
            # View the 1-D column vector as (tile_M, tile_N) with stride (1, 0): broadcast over N.
            tDsColVec = partition_for_epilogue_fn(
                cute.make_tensor(
                    sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0))
                )
            )
            if const_expr(tiled_copy_t2r is not None):
                tDsColVec = tiled_copy_r2s.retile(tDsColVec)

        if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None):
            # Wait for this thread's async copies, then sync the epilogue threads so every
            # thread sees the complete smem vectors.
            cute.arch.cp_async_commit_group()
            cute.arch.cp_async_wait_group(0)
            epilogue_barrier.arrive_and_wait()
        return alpha, beta, tDsRowVec, tDsColVec

    def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
        """Load the broadcast-vector values for one epilogue subtile into registers.

        Takes the tuple returned by ``epi_begin`` and selects subtile ``epi_coord`` of
        each smem partition. That subtile is copied to registers and converted to
        ``self.acc_dtype``.

        Returns:
            ``(alpha, beta, tDrRowVec_cvt, tDrColVec_cvt)``; vector entries are None
            when the corresponding broadcast vector is absent.
        """
        alpha, beta, tDsRowVec, tDsColVec = epi_tensors
        tDrRowVec_cvt = None
        if const_expr(tDsRowVec is not None):
            # Group the trailing (EPI_M, EPI_N) modes so epi_coord selects one subtile.
            tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[
                None, None, None, epi_coord
            ]
            # make_fragment_like(tDsRowVec_cur) is avoided: some stride-0 dims come back with
            # a non-zero stride, so the fragment is built from the exact layout instead.
            tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
            # filter_zeros drops the stride-0 (broadcast) modes, so only distinct values move.
            cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
            tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
            tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
        tDrColVec_cvt = None
        if const_expr(tDsColVec is not None):
            tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[
                None, None, None, epi_coord
            ]
            # This somehow doesn't work, some dim with stride 0 turns to non-zero stride:
            # tDrColVec = cute.make_fragment_like(tDsColVec_cur)
            tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type)
            cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
            tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
            tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))
        return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt

    @cute.jit
    def epi_visit_subtile(
        self,
        params: EpilogueParams,
        epi_loop_tensors: Tuple[cute.Tensor, ...],
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor] = None,
    ) -> Optional[cute.Tensor]:
        """Apply the epilogue to one accumulator subtile, in place in ``tRS_rD``.

        Computes ``alpha * D + beta * C`` (each scaling term optional) and then adds
        the row/column broadcast values prepared by ``epi_begin_loop``.
        Always returns None.
        """
        alpha, beta, tDrRowVec, tDrColVec = epi_loop_tensors
        rD = tRS_rD.load()
        # Apply alpha scaling to accumulator if alpha is provided (not None).
        # NOTE(review): alpha/beta are re-read from params here even though epi_begin
        # already loaded them into epi_loop_tensors.
        if const_expr(hasattr(params, "alpha") and params.alpha is not None):
            alpha = utils.load_scalar_or_pointer(params.alpha)
            rD *= alpha
        # Apply C with beta scaling
        if const_expr(tRS_rC is not None):
            if const_expr(not hasattr(params, "beta") or params.beta is None):
                # beta is None, default behavior: add C (beta=1.0)
                rD += tRS_rC.load().to(tRS_rD.element_type)
            else:
                beta = utils.load_scalar_or_pointer(params.beta)
                rD += beta * tRS_rC.load().to(tRS_rD.element_type)
        tRS_rD.store(rD)
        if const_expr(tDrRowVec is not None):
            for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True):
                tRS_rD[i] += tDrRowVec[i]
        if const_expr(tDrColVec is not None):
            for i in cutlass.range(cute.size(tDrColVec), unroll_full=True):
                tRS_rD[i] += tDrColVec[i]
        return None

    @staticmethod
    def epi_smem_bytes_per_stage(
        args: Optional[EpilogueArguments],
        cta_tile_shape_mnk: Tuple[int, int, int],
        epi_tile: cute.Tile,
    ) -> int:
        """Return the shared-memory bytes used by the broadcast-vector buffers.

        The row buffer holds ``tile_N`` elements and the column buffer ``tile_M``
        elements of their respective dtypes; ``epi_tile`` is not used. Returns 0 when
        ``args`` is None, as the ``Optional`` annotation permits.
        """
        if args is None:
            return 0
        row_vec_smem_size = 0 if args.mRowVecBroadcast is None else cta_tile_shape_mnk[1]
        col_vec_smem_size = 0 if args.mColVecBroadcast is None else cta_tile_shape_mnk[0]
        row_vec_dtype = (
            args.mRowVecBroadcast.element_type if args.mRowVecBroadcast is not None else Float32
        )
        col_vec_dtype = (
            args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32
        )
        return (
            row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width
        ) // 8

    def epi_get_smem_struct(self, params: EpilogueParams):
        """Build the shared-storage struct holding the row and column broadcast buffers.

        Each buffer is 16-byte aligned and has zero length when its vector is absent.
        """
        row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
        col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
        row_vec_dtype = (
            params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
        )
        col_vec_dtype = (
            params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
        )

        @cute.struct
        class EpiSharedStorage:
            sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
            sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]

        return EpiSharedStorage

    def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
        """Return ``(sRowVec, sColVec)`` views of ``storage.epi``.

        Each view is a 1-D tensor of ``tile_N`` / ``tile_M`` elements, or None when
        the corresponding broadcast vector is absent.
        """
        sRowVec = None
        if const_expr(params.mRowVecBroadcast is not None):
            sRowVec = storage.epi.sRowVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[1]))
        sColVec = None
        if const_expr(params.mColVecBroadcast is not None):
            sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0]))
        return (sRowVec, sColVec)
252
+
253
+
254
class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
    """GemmSm90 with the default epilogue (alpha/beta, C, row/column broadcast vectors)."""
256
+
257
+
258
class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
    """GemmSm100 with the default epilogue (alpha/beta, C, row/column broadcast vectors)."""