quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/gemm_dact.py ADDED
@@ -0,0 +1,215 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Optional, Tuple
+ from functools import partial
+
+ from torch import Tensor
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Float32, const_expr
+ import cutlass.torch as cutlass_torch
+
+ from quack.gemm_sm90 import GemmSm90
+ from quack.gemm_sm100 import GemmSm100
+ from quack.gemm_default_epi import GemmDefaultEpiMixin
+ from quack.gemm_act import GemmActMixin
+ from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
+ from quack.gemm_wrapper_utils import GemmWrapperBase
+ import quack.activation
+
+
+ class GemmDActMixin(GemmActMixin):
+     # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
+     # and return 2 arguments (dx, out)
+     EpilogueArguments = GemmActMixin.EpilogueArguments
+     EpilogueParams = GemmActMixin.EpilogueParams
+
+     @cute.jit
+     def epi_visit_subtile(
+         self,
+         params: EpilogueParams,
+         epi_loop_tensors: Tuple[cute.Tensor, ...],
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor] = None,
+     ) -> Optional[cute.Tensor]:
+         assert tRS_rC is not None
+         # We don't add C to the accumulator
+         GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
+         tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
+         tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
+         # If we don't have .shape here, the compiler generates local stores and loads
+         if const_expr(params.act_fn is not None):
+             tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
+             if const_expr(self.arch < 100):
+                 for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
+                     tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
+             else:
+                 for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
+                     (
+                         (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
+                         (tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1]),
+                     ) = params.act_fn(
+                         (tRS_rC_acc[2 * i], tRS_rC_acc[2 * i + 1]),
+                         (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
+                     )
+         else:
+             tRS_rPostAct = tRS_rC_acc
+         # Type conversion
+         tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
+         tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
+         return tRS_rPostAct_out
+
+
+ class GemmDActSm90(GemmDActMixin, GemmSm90):
+     pass
+
+
+ class GemmDActSm100(GemmDActMixin, GemmSm100):
+     pass
+
+
+ dact_fn_map = {
+     None: None,
+     "relu": quack.activation.drelu,
+     "relu_sq": quack.activation.drelu_sq,
+     "gelu_tanh_approx": quack.activation.dgelu_tanh_approx,
+ }
+
+
+ def gemm_dact(
+     A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
+     B: Tensor,  # (l, n, k)
+     Out: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+     PreAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+     PostAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+     tile_count_semaphore: Optional[Tensor],  # (1,)
+     activation: Optional[str],
+     tile_M: int,
+     tile_N: int,
+     cluster_M: int,
+     cluster_N: int,
+     pingpong: bool = True,
+     persistent: bool = True,
+     max_swizzle_size: int = 8,
+     cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
+     A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
+ ) -> None:
+     if cu_seqlens_m is not None:
+         assert persistent, "varlen_m requires persistent=True"
+         assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
+         assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
+         assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
+         assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
+     gather_A = A_idx is not None
+     if gather_A:
+         assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
+         assert cluster_N == 1, "gather_A requires cluster_N=1"
+     assert activation in dact_fn_map, f"Unsupported activation {activation}"
+
+     L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+         A,
+         B,
+         Out,
+         PreAct,
+         additional_tensors={"PostAct": PostAct},
+         cu_seqlens_m=cu_seqlens_m,
+         A_idx=A_idx,
+     )
+     GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
+     GemmWrapperBase.extract_dtypes(tensor_infos)
+     major_configs = {
+         "A": ("m", "k", "l"),
+         "B": ("n", "k", "l"),
+         "D": ("m", "n", "l"),
+         "C": ("m", "n", "l"),
+         "PostAct": ("m", "n", "l"),
+     }
+     GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+     device_capacity = get_device_capacity(A.device)
+     assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
+     GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90
+
+     acc_dtype = Float32
+     tile_shape_mn = (tile_M, tile_N)
+     cluster_shape_mnk = (cluster_M, cluster_N, 1)
+     if not GemmCls.is_valid_dtypes(
+         tensor_infos["A"].dtype,
+         tensor_infos["B"].dtype,
+         acc_dtype,
+         tensor_infos["D"].dtype,
+         tensor_infos["A"].major,
+         tensor_infos["B"].major,
+     ):
+         raise TypeError("Skipping due to unsupported combination of types and majors")
+
+     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+     GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
+     act_fn = dact_fn_map[activation]
+     epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
+     scheduler_args = GemmWrapperBase.create_scheduler_args(
+         max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
+     )
+
+     # Create varlen arguments if needed (assumes persistent=True when varlen_m)
+     varlen_args = GemmWrapperBase.create_varlen_args(
+         cu_seqlens_m,
+         None,  # cu_seqlens_k
+         A_idx,
+         max_active_clusters,
+         cluster_shape_mnk,
+         tensor_infos,
+         GemmCls.num_epi_tensormaps,
+         pingpong,
+     )
+
+     current_stream = cutlass_torch.current_stream()
+     compile_key = GemmWrapperBase.get_compile_key(
+         tensor_infos,
+         activation,
+         tile_shape_mn,
+         cluster_shape_mnk,
+         pingpong,
+         persistent,
+         tile_count_semaphore is not None,
+         device_capacity,
+         max_swizzle_size,
+         cu_seqlens_m is not None,
+         A_idx is not None,
+         key_tensor_names=("A", "B", "D", "PostAct", "C"),
+     )
+     cache = gemm_dact.compile_cache
+     if compile_key not in cache:
+         if device_capacity[0] == 9:
+             GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
+         gemm = GemmCls(
+             acc_dtype,
+             tensor_infos["A"].dtype,
+             tile_shape_mn,
+             cluster_shape_mnk,
+             gather_A=gather_A,
+         )
+         cache[compile_key] = cute.compile(
+             gemm,
+             tensor_infos["A"].cute_tensor,
+             tensor_infos["B"].cute_tensor,
+             tensor_infos["D"].cute_tensor,
+             tensor_infos["C"].cute_tensor,
+             epi_args,
+             scheduler_args,
+             varlen_args,
+             current_stream,
+         )
+     cache[compile_key](
+         tensor_infos["A"].cute_tensor,
+         tensor_infos["B"].cute_tensor,
+         tensor_infos["D"].cute_tensor,
+         tensor_infos["C"].cute_tensor,
+         epi_args,
+         scheduler_args,
+         varlen_args,
+         current_stream,
+     )
+
+
+ gemm_dact.compile_cache = {}
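
For orientation, a minimal call sketch of the new gemm_dact entry point, pieced together from the signature and shape comments above. It assumes an SM90/SM100 GPU and bf16 tensors; the tile, cluster, and activation choices are illustrative placeholders, not tuned defaults.

import torch
from quack.gemm_dact import gemm_dact

l, m, n, k = 1, 1024, 1024, 512
dtype = torch.bfloat16
A = torch.randn(l, m, k, device="cuda", dtype=dtype)
B = torch.randn(l, n, k, device="cuda", dtype=dtype)
PreAct = torch.randn(l, m, n, device="cuda", dtype=dtype)   # x, the saved pre-activation (wired in as C)
Out = torch.empty(l, m, n, device="cuda", dtype=dtype)      # receives dx
PostAct = torch.empty(l, m, n, device="cuda", dtype=dtype)  # receives the recomputed act(x)

# Per the GemmDActMixin comment above, act_bwd_fn takes (x, dout) and returns (dx, out):
# the GEMM accumulator plays the role of dout and PreAct supplies x.
gemm_dact(
    A, B, Out, PreAct, PostAct,
    tile_count_semaphore=None,
    activation="relu",
    tile_M=128, tile_N=128,
    cluster_M=1, cluster_N=1,
)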
quack/gemm_default_epi.py ADDED
@@ -0,0 +1,259 @@
+ # Copyright (c) 2025, Wentao Guo, Tri Dao.
+ from typing import Optional, Tuple
+ from functools import partial
+ from dataclasses import dataclass
+
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Int32, Float32, Boolean, const_expr
+
+ from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
+ from quack.gemm_sm90 import GemmSm90
+ from quack.gemm_sm100 import GemmSm100
+ from quack.sm90_utils import partition_for_epilogue
+ import quack.utils as utils
+ import quack.copy_utils as copy_utils
+ from quack.varlen_utils import VarlenManager
+
+
+ class GemmDefaultEpiMixin:
+     num_epi_tensormaps: int = 0
+
+     @dataclass
+     class EpilogueArguments(ArgumentsBase):
+         alpha: Optional[Float32 | cute.Tensor] = None
+         beta: Optional[Float32 | cute.Tensor] = None
+         mRowVecBroadcast: Optional[cute.Tensor] = None
+         mColVecBroadcast: Optional[cute.Tensor] = None
+         add_to_output: bool = False
+
+     @dataclass
+     class EpilogueParams(ParamsBase):
+         alpha: Optional[Float32 | cute.Tensor] = None
+         beta: Optional[Float32 | cute.Tensor] = None
+         mRowVecBroadcast: Optional[cute.Tensor] = None
+         mColVecBroadcast: Optional[cute.Tensor] = None
+
+     def epi_to_underlying_arguments(
+         self, args: EpilogueArguments, *, loc=None, ip=None
+     ) -> EpilogueParams:
+         # Assume all strides are divisible by 32 bits except the last stride
+         new_stride = lambda t: tuple(
+             cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
+             for s in t.stride
+         )
+         mRowVecBroadcast, mColVecBroadcast = [
+             cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+             if t is not None
+             else None
+             for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
+         ]
+         return self.EpilogueParams(
+             alpha=args.alpha,
+             beta=args.beta,
+             mRowVecBroadcast=mRowVecBroadcast,
+             mColVecBroadcast=mColVecBroadcast,
+         )
+
+     @cute.jit
+     def epi_begin(
+         self,
+         params: EpilogueParams,
+         epi_smem_tensors: Tuple[cute.Tensor, ...],
+         epi_tile: cute.Tile,
+         tiled_copy_t2r: Optional[cute.TiledCopy],
+         tiled_copy_r2s: cute.TiledCopy,
+         tile_coord_mnkl: cute.Coord,
+         varlen_manager: VarlenManager,
+         epilogue_barrier: cutlass.pipeline.NamedBarrier,
+         tidx: Int32,
+     ):
+         alpha, beta = None, None
+         if const_expr(hasattr(params, "alpha") and params.alpha is not None):
+             alpha = utils.load_scalar_or_pointer(params.alpha)
+         if const_expr(hasattr(params, "beta") and params.beta is not None):
+             beta = utils.load_scalar_or_pointer(params.beta)
+         sRowVec, sColVec, *rest = epi_smem_tensors
+         tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
+         batch_idx = tile_coord_mnkl[3]
+         num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
+         # Don't need sync as we assume the previous epilogue has finished
+
+         partition_for_epilogue_fn = partial(
+             partition_for_epilogue,
+             epi_tile=epi_tile,
+             tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
+             tidx=tidx,
+             reference_src=tiled_copy_t2r is None,
+         )
+
+         tDsRowVec = None
+         if const_expr(params.mRowVecBroadcast is not None):
+             rowvec_dtype = params.mRowVecBroadcast.element_type
+             num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width
+             thr_copy_RV = copy_utils.tiled_copy_1d(
+                 params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
+             ).get_slice(tidx)
+             mRowVec = params.mRowVecBroadcast[batch_idx, None]
+             gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],))
+             tRVgRV = thr_copy_RV.partition_S(gRowVec)
+             tRVsRV = thr_copy_RV.partition_D(sRowVec)
+             tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
+             limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
+             tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean)
+             for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
+                 tRVpRV[0, m] = tRVcRV[0, m] < limit_n
+             cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
+             # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
+             tDsRowVec = partition_for_epilogue_fn(
+                 cute.make_tensor(
+                     sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1))
+                 )
+             )
+             if const_expr(tiled_copy_t2r is not None):
+                 tDsRowVec = tiled_copy_r2s.retile(tDsRowVec)
+
+         tDsColVec = None
+         if const_expr(params.mColVecBroadcast is not None):
+             colvec_dtype = params.mColVecBroadcast.element_type
+             num_copy_elems = const_expr(max(32, colvec_dtype.width)) // colvec_dtype.width
+             thr_copy_CV = copy_utils.tiled_copy_1d(
+                 params.mColVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
+             ).get_slice(tidx)
+             if const_expr(not varlen_manager.varlen_m):
+                 mColVec = params.mColVecBroadcast[batch_idx, None]
+             else:
+                 mColVec = cute.domain_offset(
+                     (varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast
+                 )
+             gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
+             tCVgCV = thr_copy_CV.partition_S(gColVec)
+             tCVsCV = thr_copy_CV.partition_D(sColVec)
+             tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
+             limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
+             tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean)
+             for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
+                 tCVpCV[0, m] = tCVcCV[0, m] < limit_m
+             cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
+             tDsColVec = partition_for_epilogue_fn(
+                 cute.make_tensor(
+                     sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0))
+                 )
+             )
+             if const_expr(tiled_copy_t2r is not None):
+                 tDsColVec = tiled_copy_r2s.retile(tDsColVec)
+
+         if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None):
+             cute.arch.cp_async_commit_group()
+             cute.arch.cp_async_wait_group(0)
+             epilogue_barrier.arrive_and_wait()
+         return alpha, beta, tDsRowVec, tDsColVec
+
+     def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
+         alpha, beta, tDsRowVec, tDsColVec = epi_tensors
+         tDrRowVec_cvt = None
+         if const_expr(tDsRowVec is not None):
+             tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[
+                 None, None, None, epi_coord
+             ]
+             # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
+             tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
+             cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
+             tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
+             tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
+         tDrColVec_cvt = None
+         if const_expr(tDsColVec is not None):
+             tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[
+                 None, None, None, epi_coord
+             ]
+             # This somehow doesn't work, some dim with stride 0 turns to non-zero stride
+             # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
+             tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type)
+             cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
+             tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
+             tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))
+         return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt
+
+     @cute.jit
+     def epi_visit_subtile(
+         self,
+         params: EpilogueParams,
+         epi_loop_tensors: Tuple[cute.Tensor, ...],
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor] = None,
+     ) -> Optional[cute.Tensor]:
+         alpha, beta, tDrRowVec, tDrColVec = epi_loop_tensors
+         rD = tRS_rD.load()
+         # Apply alpha scaling to accumulator if alpha is provided (not None)
+         if const_expr(hasattr(params, "alpha") and params.alpha is not None):
+             alpha = utils.load_scalar_or_pointer(params.alpha)
+             rD *= alpha
+         # Apply C with beta scaling
+         if const_expr(tRS_rC is not None):
+             if const_expr(not hasattr(params, "beta") or params.beta is None):
+                 # beta is None, default behavior: add C (beta=1.0)
+                 rD += tRS_rC.load().to(tRS_rD.element_type)
+             else:
+                 beta = utils.load_scalar_or_pointer(params.beta)
+                 rD += beta * tRS_rC.load().to(tRS_rD.element_type)
+         tRS_rD.store(rD)
+         if const_expr(tDrRowVec is not None):
+             for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True):
+                 tRS_rD[i] += tDrRowVec[i]
+         if const_expr(tDrColVec is not None):
+             for i in cutlass.range(cute.size(tDrColVec), unroll_full=True):
+                 tRS_rD[i] += tDrColVec[i]
+         return None
+
+     @staticmethod
+     def epi_smem_bytes_per_stage(
+         args: Optional[EpilogueArguments],
+         cta_tile_shape_mnk: Tuple[int, int, int],
+         epi_tile: cute.Tile,
+     ) -> int:
+         row_vec_smem_size = 0 if args.mRowVecBroadcast is None else cta_tile_shape_mnk[1]
+         col_vec_smem_size = 0 if args.mColVecBroadcast is None else cta_tile_shape_mnk[0]
+         row_vec_dtype = (
+             args.mRowVecBroadcast.element_type if args.mRowVecBroadcast is not None else Float32
+         )
+         col_vec_dtype = (
+             args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32
+         )
+         return (
+             row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width
+         ) // 8
+
+     def epi_get_smem_struct(self, params: EpilogueParams):
+         row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
+         col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
+         row_vec_dtype = (
+             params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
+         )
+         col_vec_dtype = (
+             params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
+         )
+
+         @cute.struct
+         class EpiSharedStorage:
+             sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
+             sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
+
+         return EpiSharedStorage
+
+     def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
+         sRowVec = None
+         if const_expr(params.mRowVecBroadcast is not None):
+             sRowVec = storage.epi.sRowVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[1]))
+         sColVec = None
+         if const_expr(params.mColVecBroadcast is not None):
+             sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0]))
+         return (sRowVec, sColVec)
+
+
+ class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
+     pass
+
+
+ class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
+     pass
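
Taken together, GemmDefaultEpiMixin's per-tile epilogue computes D = alpha * acc, optionally adds beta * C (with beta treated as 1.0 when C is present but beta is None), and then adds the optional row- and column-vector broadcasts. A plain-PyTorch sketch of that math, included here only as a summary of the code above and not as part of the package:

import torch

def default_epilogue_reference(acc, C=None, alpha=None, beta=None, row_vec=None, col_vec=None):
    # Mirrors epi_visit_subtile: scale by alpha, add (beta-scaled) C, then broadcast adds.
    D = acc.clone()
    if alpha is not None:
        D *= alpha
    if C is not None:
        D += C if beta is None else beta * C  # beta=None behaves like beta == 1.0
    if row_vec is not None:  # mRowVecBroadcast: one value per column (n), broadcast down rows
        D += row_vec.view(1, -1)
    if col_vec is not None:  # mColVecBroadcast: one value per row (m), broadcast across columns
        D += col_vec.view(-1, 1)
    return D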