quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/fast_math.py ADDED
@@ -0,0 +1,80 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ from typing import Tuple
+ from dataclasses import dataclass
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Int32, Uint32
+ from cutlass.cutlass_dsl import T, dsl_user_op
+ from cutlass._mlir.dialects import llvm
+
+ from quack.cute_dsl_utils import ParamsBase
+
+
+ @cute.jit
+ def clz(x: Int32) -> Int32:
+     # for i in cutlass.range_constexpr(32):
+     #     if (1 << (31 - i)) & x:
+     #         return Int32(i)
+     # return Int32(32)
+     # Early exit is not supported yet
+     res = Int32(32)
+     done = False
+     for i in cutlass.range(32):
+         if ((1 << (31 - i)) & x) and not done:
+             res = Int32(i)
+             done = True
+     return res
+
+
+ def find_log2(x: Int32) -> Int32:
+     a: Int32 = Int32(31 - clz(x))
+     return a + ((x & (x - 1)) != 0)  # Round up, add 1 if not a power of 2.
+
+
+ @dsl_user_op
+ def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
+     return Uint32(
+         llvm.inline_asm(
+             T.i32(),
+             [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
+             "mul.hi.u32 $0, $1, $2;",
+             "=r,r,r",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @dataclass
+ class FastDivmod(ParamsBase):
+     divisor: Int32
+     multiplier: Uint32
+     shift_right: Uint32
+
+     # called by host
+     @staticmethod
+     def create(divisor: Int32) -> "FastDivmod":
+         """Construct the FastDivmod object, in host code.
+         This precomputes some values based on the divisor and is computationally expensive.
+         """
+         p = Uint32(31 + find_log2(divisor))
+         divisor_u32 = Uint32(divisor)
+         multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
+         shift_right = Uint32(p - 32)
+         return FastDivmod(divisor, multiplier, shift_right)
+
+     @cute.jit
+     def div(self, dividend: Int32) -> Int32:
+         return (
+             Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
+             if self.divisor != 1
+             else dividend
+         )
+
+     def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
+         quotient = self.div(dividend)
+         remainder = dividend - quotient * self.divisor
+         return quotient, remainder
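
FastDivmod is the classic magic-number division: the host precomputes m = ceil(2^p / d) with p = 31 + ceil(log2(d)), and the device computes the quotient as (x * m) >> p, where the ">> 32" part of the shift comes from mul.hi.u32 and the remaining (p - 32) bits from shift_right. The divisor == 1 case is special-cased in div() because shift_right would otherwise underflow. Below is a minimal pure-Python sketch of the same arithmetic for divisor >= 2, outside the CuTe DSL; the names here are illustrative and not part of the package.

def fast_divmod_ref(dividend: int, divisor: int) -> tuple[int, int]:
    """Reference for the magic-number division used by FastDivmod (divisor >= 2)."""
    log2_ceil = (divisor - 1).bit_length()  # same value as find_log2 for divisor >= 2
    p = 31 + log2_ceil
    multiplier = ((1 << p) + divisor - 1) // divisor  # ceil(2^p / divisor), fits in 32 bits
    # umulhi(x, m) >> (p - 32) equals (x * m) >> p in unbounded integer arithmetic
    quotient = (dividend * multiplier) >> p
    return quotient, dividend - quotient * divisor

for d in (2, 3, 7, 48, 1000):
    for x in (0, 1, d - 1, d, d + 1, 123456789, 2**31 - 1):
        assert fast_divmod_ref(x, d) == (x // d, x % d)
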
quack/gemm_act_sm90.py ADDED
@@ -0,0 +1,368 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Tuple, Optional, Callable
+ from dataclasses import dataclass
+
+ from torch import Tensor
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.nvgpu import warpgroup
+ import cutlass.utils.hopper_helpers as sm90_utils
+ from cutlass import Int32, Float32, Boolean, const_expr
+ import cutlass.torch as cutlass_torch
+
+ from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
+ from quack.dense_gemm_sm90 import GemmSm90
+ from quack.cute_dsl_utils import get_max_active_clusters
+ from quack.gemm_wrapper_utils import GemmWrapperBase
+ import quack.activation
+
+
+ class GemmActSm90(GemmSm90):
+     @dataclass
+     class EpilogueArguments(ArgumentsBase):
+         mPostAct: cute.Tensor
+         act_fn: cutlass.Constexpr[Optional[Callable]] = None
+         alpha: Optional[Float32] = None
+         beta: Optional[Float32] = None
+
+     @dataclass
+     class EpilogueParams(ParamsBase):
+         tma_atom_postact: cute.CopyAtom
+         mPostAct_mnl: cute.Tensor
+         epi_postact_smem_layout_staged: cute.ComposedLayout
+         act_fn: cutlass.Constexpr[Optional[Callable]] = None
+         alpha: Optional[Float32] = None
+         beta: Optional[Float32] = None
+
+     def epi_to_underlying_arguments(
+         self, args: EpilogueArguments, *, loc=None, ip=None
+     ) -> EpilogueParams:
+         self.postact_dtype = args.mPostAct.element_type
+         self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
+
+         self.tile_shape_postact_mn = self.tile_shape_mnk[:2]
+         self.epi_tile_postact = self.epi_tile
+         postact_major_mode_size = (
+             self.epi_tile_postact[1]
+             if self.postact_layout.is_n_major_c()
+             else self.epi_tile_postact[0]
+         )
+         postact_smem_layout_atom = warpgroup.make_smem_layout_atom(
+             sm90_utils.get_smem_layout_atom(
+                 self.postact_layout, self.postact_dtype, postact_major_mode_size
+             ),
+             self.postact_dtype,
+         )
+         epi_postact_smem_layout_staged = cute.tile_to_shape(
+             postact_smem_layout_atom,
+             cute.append(self.epi_tile_postact, self.epi_stage),
+             order=(0, 1, 2),
+         )
+         tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
+             args.mPostAct,
+             epi_postact_smem_layout_staged,
+             self.epi_tile_postact,
+             store_or_load="store",
+         )
+         return GemmActSm90.EpilogueParams(
+             tma_atom_postact,
+             tma_tensor_postact,
+             epi_postact_smem_layout_staged,
+             args.act_fn,
+             args.alpha,
+             args.beta,
+         )
+
+     @staticmethod
+     def epi_smem_bytes_per_stage(
+         args: EpilogueArguments,
+         tile_shape_mnk: Tuple[int, int, int],
+         epi_tile: Tuple[int, int],
+     ) -> int:
+         postact_dtype = args.mPostAct.element_type
+         postact_bytes_per_stage = cute.size(epi_tile) * (postact_dtype.width // 8)
+         return postact_bytes_per_stage
+
+     def epi_get_smem_struct(self, params: EpilogueParams):
+         @cute.struct
+         class EpiSharedStorage:
+             sPostAct: cute.struct.Align[
+                 cute.struct.MemRange[
+                     self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
+                 ],
+                 self.buffer_align_bytes,
+             ]
+
+         return EpiSharedStorage
+
+     def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
+         sPostAct = storage.epi.sPostAct.get_tensor(
+             params.epi_postact_smem_layout_staged.outer,
+             swizzle=params.epi_postact_smem_layout_staged.inner,
+         )
+         return (sPostAct,)
+
+     @cute.jit
+     def epilogue(
+         self,
+         params: EpilogueParams,
+         epi_smem_tensors: Tuple[cute.Tensor, ...],
+         epi_pipeline: cutlass.pipeline.PipelineAsync,
+         epi_read_state: cutlass.pipeline.PipelineState,
+         epi_producer_state: cutlass.pipeline.PipelineState,
+         tiled_mma: cute.TiledMma,
+         tRS_rAcc: cute.Tensor,
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor],
+         tiled_copy_r2s: cute.core.ThrCopy,
+         tRS_sD: cute.Tensor,
+         tiled_copy_s2r: Optional[cute.core.ThrCopy],
+         tSR_rC: Optional[cute.Tensor],
+         tSR_sC: Optional[cute.Tensor],
+         copy_D: Optional[Callable],
+         bSG_sD: cute.Tensor,
+         bSG_gD: cute.Tensor,
+         epi_load_g2s: Optional[Callable],
+         tile_coord_mnkl: cute.Coord,
+         cu_seqlens_m: Optional[cute.Tensor],
+         epilogue_barrier: cutlass.pipeline.NamedBarrier,
+         tile_scheduler,
+         tidx: Int32,
+         is_tma_warp: Boolean,
+     ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
+         has_C = const_expr(tRS_rC is not None)
+         has_D = const_expr(copy_D is not None)
+         assert cu_seqlens_m is None, "GemmActSm90 doesn't support varlen_m for now"
+
+         tma_atom_postact = params.tma_atom_postact
+         mPostAct_mnl = params.mPostAct_mnl
+         (sPostAct,) = epi_smem_tensors
+         tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
+         copy_atom_postact_r2s = sm90_utils.sm90_get_smem_store_op(
+             self.postact_layout, elem_ty_d=self.postact_dtype, elem_ty_acc=self.acc_dtype
+         )
+         tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
+         thr_copy_postact_r2s = tiled_copy_postact_r2s.get_slice(tidx)
+         tRS_sPostAct = thr_copy_postact_r2s.partition_D(sPostAct)
+         bSG_sPostAct, bSG_gPostAct = self.epilog_gmem_copy_and_partition(
+             tma_atom_postact,
+             mPostAct_mnl,
+             self.tile_shape_postact_mn,
+             self.epi_tile_postact,
+             sPostAct,
+             tile_coord_mnkl,
+             cu_seqlens_m,
+         )
+
+         # We iterate over epi tiles in the N dimension first before the M dimension
+         epi_tile_shape = cute.zipped_divide(
+             cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
+         ).shape[1]
+         epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
+         epi_tile_num = cute.size(epi_tile_shape)
+         num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
+
+         if const_expr(epi_load_g2s is not None):
+             for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
+                 epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
+
+         for epi_idx in cutlass.range_constexpr(epi_tile_num):
+             # Copy from acc to D registers
+             for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
+                 tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
+             if const_expr(has_C):
+                 epi_pipeline.consumer_wait(epi_read_state)
+                 cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
+                 # Fence to make sure shared memory read is visible to TMA load
+                 cute.arch.fence_proxy(
+                     cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                 )
+                 cute.arch.sync_warp()
+                 with cute.arch.elect_one():
+                     epi_pipeline.consumer_release(epi_read_state)
+                 epi_read_state.advance()
+             if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+                 epi_producer_state = epi_load_g2s(
+                     epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
+                 )
+             tRS_rPostAct = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
+             epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+             # Copy from D registers to shared memory
+             if const_expr(has_D):
+                 # Type conversion
+                 tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
+                 tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
+                 cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
+             cute.copy(
+                 tiled_copy_postact_r2s,
+                 tiled_copy_postact_r2s.retile(tRS_rPostAct),
+                 tRS_sPostAct[None, None, None, epi_buffer],
+             )
+             # Fence and barrier to make sure shared memory store is visible to TMA store
+             cute.arch.fence_proxy(
+                 cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+             )
+             epilogue_barrier.arrive_and_wait()
+             # Get the global memory coordinate for the current epi tile
+             gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
+             # Copy from shared memory to global memory
+             if is_tma_warp:
+                 if const_expr(has_D):
+                     copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
+                 cute.copy(
+                     tma_atom_postact,
+                     bSG_sPostAct[None, epi_buffer],
+                     bSG_gPostAct[None, gmem_coord],
+                 )
+                 cute.arch.cp_async_bulk_commit_group()
+                 cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
+             epilogue_barrier.arrive_and_wait()
+
+         return epi_read_state, epi_producer_state
+
+     @cute.jit
+     def epi_visit_acc_subtile(
+         self,
+         params: EpilogueParams,
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor] = None,
+     ) -> Optional[cute.Tensor]:
+         # Apply alpha scaling to accumulator if alpha is provided (not None)
+         if const_expr(params.alpha is not None):
+             tRS_rD.store(tRS_rD.load() * params.alpha)
+         # Apply C with beta scaling
+         if const_expr(tRS_rC is not None):
+             if const_expr(params.beta is None):
+                 # beta is None, default behavior: add C (beta=1.0)
+                 tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
+             else:
+                 tRS_rD.store(tRS_rD.load() + params.beta * tRS_rC.load().to(tRS_rD.element_type))
+         # Apply activation function if provided
+         # If we don't have .shape here, the compiler generates local stores and loads
+         if const_expr(params.act_fn is not None):
+             tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
+             for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
+                 tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
+         else:
+             tRS_rPostAct = tRS_rD
+         # Type conversion
+         tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
+         tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
+         return tRS_rPostAct_out
+
+
+ act_fn_map = {
+     None: None,
+     "relu": quack.activation.relu,
+     "relu_sq": quack.activation.relu_sq,
+     "gelu_tanh_approx": quack.activation.gelu_tanh_approx,
+ }
+
+
+ def gemm_act_sm90(
+     A: Tensor,  # (l, m, k)
+     B: Tensor,  # (l, n, k)
+     D: Optional[Tensor],  # (l, m, n)
+     C: Optional[Tensor],  # (l, m, n)
+     PostAct: Tensor,  # (l, m, n)
+     activation: Optional[str],
+     tile_M: int,
+     tile_N: int,
+     cluster_M: int,
+     cluster_N: int,
+     pingpong: bool = False,
+     persistent: bool = True,
+     alpha: float = 1.0,
+     beta: float = 1.0,
+ ) -> None:
+     tile_count_semaphore = None
+     assert activation in act_fn_map, f"Unsupported activation {activation}"
+     L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+         A, B, D, C, additional_tensors={"PostAct": PostAct}
+     )
+     GemmWrapperBase.permute_tensors(tensor_infos)
+     GemmWrapperBase.extract_dtypes(tensor_infos)
+     major_configs = {
+         "A": ("m", "k", "l"),
+         "B": ("n", "k", "l"),
+         "D": ("m", "n", "l"),
+         "C": ("m", "n", "l"),
+         "PostAct": ("m", "n", "l"),
+     }
+     GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+     acc_dtype = cutlass.Float32
+     tile_shape_mn = (tile_M, tile_N)
+     cluster_shape_mnk = (cluster_M, cluster_N, 1)
+     if not GemmActSm90.is_valid_dtypes(
+         tensor_infos["A"].dtype,
+         tensor_infos["B"].dtype,
+         acc_dtype,
+         tensor_infos["D"].dtype,
+         tensor_infos["A"].major,
+         tensor_infos["B"].major,
+     ):
+         raise TypeError("Skipping due to unsupported combination of types and majors")
+
+     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+     GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
+     act_fn = act_fn_map[activation]
+     epi_args = GemmActSm90.EpilogueArguments(
+         tensor_infos["PostAct"].cute_tensor,
+         act_fn,
+         alpha=Float32(alpha) if alpha != 1.0 else None,
+         beta=Float32(beta) if beta != 1.0 else None,
+     )
+     scheduler_args = GemmWrapperBase.create_scheduler_args(
+         max_active_clusters, tile_count_semaphore
+     )
+     current_stream = cutlass_torch.current_stream()
+     compile_key = GemmWrapperBase.get_compile_key(
+         tensor_infos,
+         activation,
+         tile_shape_mn,
+         cluster_shape_mnk,
+         pingpong,
+         persistent,
+         tile_count_semaphore is not None,
+         alpha != 1.0,
+         beta != 1.0,
+         key_tensor_names=("A", "B", "D", "PostAct", "C"),
+     )
+     cache = gemm_act_sm90.compile_cache
+     if compile_key not in cache:
+         gemm = GemmActSm90(
+             acc_dtype,
+             tensor_infos["A"].dtype,
+             tile_shape_mn,
+             cluster_shape_mnk,
+             pingpong=pingpong,
+             is_persistent=persistent,
+         )
+         cache[compile_key] = cute.compile(
+             gemm,
+             tensor_infos["A"].cute_tensor,
+             tensor_infos["B"].cute_tensor,
+             tensor_infos["D"].cute_tensor,
+             tensor_infos["C"].cute_tensor,
+             epi_args,
+             scheduler_args,
+             None,  # varlen_args
+             None,  # mAIdx
+             current_stream,
+         )
+     cache[compile_key](
+         tensor_infos["A"].cute_tensor,
+         tensor_infos["B"].cute_tensor,
+         tensor_infos["D"].cute_tensor,
+         tensor_infos["C"].cute_tensor,
+         epi_args,
+         scheduler_args,
+         None,
+         None,
+         current_stream,
+     )
+
+
+ gemm_act_sm90.compile_cache = {}
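
A rough sketch of how this wrapper is meant to be called. This is illustrative only: it assumes an SM90 (Hopper) GPU, assumes bf16 A/B with fp32 accumulation passes is_valid_dtypes, and uses an arbitrary rather than tuned tile/cluster shape.

import torch
from quack.gemm_act_sm90 import gemm_act_sm90

l, m, n, k = 1, 4096, 4096, 4096
A = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(l, n, k, device="cuda", dtype=torch.bfloat16)
D = torch.empty(l, m, n, device="cuda", dtype=torch.bfloat16)        # GEMM output (pre-activation)
PostAct = torch.empty(l, m, n, device="cuda", dtype=torch.bfloat16)  # activation applied to D

# Per batch index l: D is the GEMM of A against B (k contracted), PostAct = relu(D); C is omitted.
gemm_act_sm90(
    A, B, D, None, PostAct, "relu",
    tile_M=128, tile_N=192, cluster_M=2, cluster_N=1,
    pingpong=True, persistent=True,
)
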
quack/gemm_config.py ADDED
@@ -0,0 +1,69 @@
+ # Copyright (c) 2025, Tri Dao.
+ import itertools
+ from typing import Optional, List
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True)
+ class GemmConfig:
+     tile_m: int = 128
+     tile_n: int = 192
+     pingpong: bool = True
+     cluster_m: int = 2
+     cluster_n: int = 1
+     swap_ab: bool = False
+     # raster_order: int = 1
+     # max_swizzle_size: int = 8
+
+
+ def get_all_configs(
+     epilogue: Optional[str] = None,
+     tune_coop: bool = True,
+     # tune_raster_order=True,
+ ) -> List[GemmConfig]:
+     tile_n_vals = [128, 144, 160, 176, 192, 208]
+     tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
+         (128, 224),
+         (128, 256),
+         # (192, 256),  # Getting IOT instruction (core dumped) in the bwd
+     ]
+     tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
+     if epilogue in ["gated"]:
+         tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
+         tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
+     elif epilogue in ["lse"]:
+         tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
+     tile_mn_vals = []
+     if tune_coop:
+         tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
+     tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
+     cluster = [(1, 2), (2, 1)]
+     # cluster = [(1, 1), (1, 2), (2, 1)]
+     if epilogue in ["lse"]:
+         cluster = [(1, 2), (2, 1)]
+     swap_ab_vals = [False, True]
+     if epilogue in ["lse", "gated"]:
+         swap_ab_vals = [False]
+     # raster_swizzle = (
+     #     [(0, 1)]
+     #     if not tune_raster_order
+     #     else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
+     # )
+     return [
+         GemmConfig(
+             tile_m=tile_m,
+             tile_n=tile_n,
+             pingpong=pingpong,
+             cluster_m=cluster_m,
+             cluster_n=cluster_n,
+             swap_ab=swap_ab,
+             # raster_order=raster_order,
+             # max_swizzle_size=max_swizzle_size,
+         )
+         for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
+             tile_mn_vals,
+             cluster,
+             swap_ab_vals,
+             # raster_swizzle,
+         )
+     ]
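
get_all_configs simply enumerates the cross product of tile shapes, cooperative vs. pingpong scheduling, cluster shapes, and swap_ab. A sketch of how the sweep might feed an autotuning loop; run_with_config is a hypothetical user-supplied benchmark callable, not part of the package:

from quack.gemm_config import get_all_configs

def autotune(run_with_config, epilogue=None):
    """Return the fastest GemmConfig; run_with_config(cfg) runs the kernel with
    cfg (tile_m, tile_n, pingpong, cluster_m, cluster_n, swap_ab) and returns
    a time in milliseconds."""
    best_cfg, best_ms = None, float("inf")
    for cfg in get_all_configs(epilogue=epilogue, tune_coop=True):
        ms = run_with_config(cfg)
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg
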
quack/gemm_dact_sm90.py ADDED
@@ -0,0 +1,150 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Optional
+
+ from torch import Tensor
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import const_expr
+ import cutlass.torch as cutlass_torch
+
+ from quack.gemm_act_sm90 import GemmActSm90
+ from quack.cute_dsl_utils import get_max_active_clusters
+ from quack.gemm_wrapper_utils import GemmWrapperBase
+ import quack.activation
+
+
+ class GemmDActSm90(GemmActSm90):
+     # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
+     # and return 2 arguments (dx, out)
+     EpilogueArguments = GemmActSm90.EpilogueArguments
+     EpilogueParams = GemmActSm90.EpilogueParams
+
+     @cute.jit
+     def epi_visit_acc_subtile(
+         self,
+         params: EpilogueParams,
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor] = None,
+     ) -> Optional[cute.Tensor]:
+         assert tRS_rC is not None
+         tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
+         tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
+         # If we don't have .shape here, the compiler generates local stores and loads
+         if const_expr(params.act_fn is not None):
+             tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
+             for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
+                 tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
+         else:
+             tRS_rPostAct = tRS_rC_acc
+         # Type conversion
+         tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
+         tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
+         return tRS_rPostAct_out
+
+
+ dact_fn_map = {
+     None: None,
+     "relu": quack.activation.drelu,
+     "relu_sq": quack.activation.drelu_sq,
+     "gelu_tanh_approx": quack.activation.dgelu_tanh_approx,
+ }
+
+
+ def gemm_dact_sm90(
+     A: Tensor,  # (l, m, k)
+     B: Tensor,  # (l, n, k)
+     Out: Tensor,  # (l, m, n)
+     PreAct: Tensor,  # (l, m, n)
+     PostAct: Tensor,  # (l, m, n)
+     tile_count_semaphore: Optional[Tensor],  # (1,)
+     activation: Optional[str],
+     tile_M: int,
+     tile_N: int,
+     cluster_M: int,
+     cluster_N: int,
+     pingpong: bool = True,
+     persistent: bool = True,
+ ) -> None:
+     assert activation in dact_fn_map, f"Unsupported activation {activation}"
+     L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+         A, B, Out, PreAct, additional_tensors={"PostAct": PostAct}
+     )
+     GemmWrapperBase.permute_tensors(tensor_infos)
+     GemmWrapperBase.extract_dtypes(tensor_infos)
+     major_configs = {
+         "A": ("m", "k", "l"),
+         "B": ("n", "k", "l"),
+         "D": ("m", "n", "l"),
+         "C": ("m", "n", "l"),
+         "PostAct": ("m", "n", "l"),
+     }
+     GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+     acc_dtype = cutlass.Float32
+     tile_shape_mn = (tile_M, tile_N)
+     cluster_shape_mnk = (cluster_M, cluster_N, 1)
+     if not GemmDActSm90.is_valid_dtypes(
+         tensor_infos["A"].dtype,
+         tensor_infos["B"].dtype,
+         acc_dtype,
+         tensor_infos["D"].dtype,
+         tensor_infos["A"].major,
+         tensor_infos["B"].major,
+     ):
+         raise TypeError("Skipping due to unsupported combination of types and majors")
+
+     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+     GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
+     act_fn = dact_fn_map[activation]
+     epi_args = GemmDActSm90.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
+     scheduler_args = GemmWrapperBase.create_scheduler_args(
+         max_active_clusters, tile_count_semaphore
+     )
+     current_stream = cutlass_torch.current_stream()
+     compile_key = GemmWrapperBase.get_compile_key(
+         tensor_infos,
+         activation,
+         tile_shape_mn,
+         cluster_shape_mnk,
+         pingpong,
+         persistent,
+         tile_count_semaphore is not None,
+         key_tensor_names=("A", "B", "D", "PostAct", "C"),
+     )
+     cache = gemm_dact_sm90.compile_cache
+     if compile_key not in cache:
+         gemm = GemmDActSm90(
+             acc_dtype,
+             tensor_infos["A"].dtype,
+             tile_shape_mn,
+             cluster_shape_mnk,
+             pingpong=pingpong,
+             is_persistent=persistent,
+         )
+         cache[compile_key] = cute.compile(
+             gemm,
+             tensor_infos["A"].cute_tensor,
+             tensor_infos["B"].cute_tensor,
+             tensor_infos["D"].cute_tensor,
+             tensor_infos["C"].cute_tensor,
+             epi_args,
+             scheduler_args,
+             None,  # varlen_args
+             None,  # mAIdx
+             current_stream,
+         )
+     cache[compile_key](
+         tensor_infos["A"].cute_tensor,
+         tensor_infos["B"].cute_tensor,
+         tensor_infos["D"].cute_tensor,
+         tensor_infos["C"].cute_tensor,
+         epi_args,
+         scheduler_args,
+         None,
+         None,
+         current_stream,
+     )
+
+
+ gemm_dact_sm90.compile_cache = {}
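
A sketch of the intended backward-pass call, mirroring the forward wrapper: per the class comment, the epilogue feeds (PreAct, accumulator) through the activation-backward function and writes the gradient into Out and the recomputed activation into PostAct. Everything below is illustrative only (module name per the header above; an SM90 GPU and supported dtypes are assumed; the dOut/W naming reflects one typical backward use of a generic (l, m, k) x (l, n, k) GEMM, not a requirement of the API):

import torch
from quack.gemm_dact_sm90 import gemm_dact_sm90

l, m, n, k = 1, 4096, 4096, 4096
dOut = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)     # incoming gradient
W = torch.randn(l, n, k, device="cuda", dtype=torch.bfloat16)
PreAct = torch.randn(l, m, n, device="cuda", dtype=torch.bfloat16)   # saved pre-activation
dPreAct = torch.empty(l, m, n, device="cuda", dtype=torch.bfloat16)  # Out: gradient w.r.t. PreAct
PostAct = torch.empty(l, m, n, device="cuda", dtype=torch.bfloat16)  # recomputed activation output

gemm_dact_sm90(
    dOut, W, dPreAct, PreAct, PostAct, None, "relu",
    tile_M=128, tile_N=192, cluster_M=2, cluster_N=1,
    pingpong=True, persistent=True,
)
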