quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/fast_math.py CHANGED
@@ -1,6 +1,7 @@
  # Copyright (c) 2025, Tri Dao.
 
  from typing import Tuple
+ from dataclasses import dataclass
 
  import cutlass
  import cutlass.cute as cute
@@ -8,6 +9,8 @@ from cutlass import Int32, Uint32
  from cutlass.cutlass_dsl import T, dsl_user_op
  from cutlass._mlir.dialects import llvm
 
+ from quack.cute_dsl_utils import ParamsBase
+
 
  @cute.jit
  def clz(x: Int32) -> Int32:
@@ -45,18 +48,15 @@ def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
      )
 
 
- class FastDivmod:
-     def __init__(
-         self, divisor: Int32, multipler: Uint32, shift_right: Uint32, *, loc=None, ip=None
-     ):
-         self.divisor = divisor
-         self.multiplier = multipler
-         self.shift_right = shift_right
-         self._loc = loc
+ @dataclass
+ class FastDivmod(ParamsBase):
+     divisor: Int32
+     multiplier: Uint32
+     shift_right: Uint32
 
      # called by host
      @staticmethod
-     def create(divisor: Int32, *, loc=None, ip=None) -> "FastDivmod":
+     def create(divisor: Int32) -> "FastDivmod":
          """Construct the FastDivmod object, in host code.
          This precomputes some values based on the divisor and is computationally expensive.
          """
@@ -64,7 +64,7 @@ class FastDivmod:
          divisor_u32 = Uint32(divisor)
          multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
          shift_right = Uint32(p - 32)
-         return FastDivmod(divisor, multiplier, shift_right, loc=loc, ip=ip)
+         return FastDivmod(divisor, multiplier, shift_right)
 
      @cute.jit
      def div(self, dividend: Int32) -> Int32:
@@ -78,20 +78,3 @@ class FastDivmod:
          quotient = self.div(dividend)
          remainder = dividend - quotient * self.divisor
          return quotient, remainder
-
-     def __extract_mlir_values__(self):
-         values, self._values_pos = [], []
-         for obj in [self.divisor, self.multiplier, self.shift_right]:
-             obj_values = cutlass.extract_mlir_values(obj)
-             values += obj_values
-             self._values_pos.append(len(obj_values))
-         return values
-
-     def __new_from_mlir_values__(self, values):
-         obj_list = []
-         for obj, n_items in zip(
-             [self.divisor, self.multiplier, self.shift_right], self._values_pos
-         ):
-             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
-             values = values[n_items:]
-         return FastDivmod(*(tuple(obj_list)), loc=self._loc)
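This refactor swaps the hand-written `__extract_mlir_values__`/`__new_from_mlir_values__` plumbing for the shared `ParamsBase` implementation; the multiply-high division trick itself is unchanged. A minimal plain-Python sketch of that trick follows (the hunk elides the line computing `p`, so the choice below follows CUTLASS's `find_divisor` convention and is an assumption):

```python
# Illustrative only: the kernel version works on cutlass Int32/Uint32 via umulhi.
def fast_divmod_create(divisor: int):
    # Pick p so that multiplier = ceil(2^p / divisor) makes the
    # multiply-high-then-shift below exact (CUTLASS find_divisor convention).
    p = 31 + divisor.bit_length()
    multiplier = ((1 << p) + divisor - 1) // divisor
    shift_right = p - 32
    return divisor, multiplier, shift_right

def fast_divmod(dividend: int, divisor: int, multiplier: int, shift_right: int):
    # (dividend * multiplier) >> 32 is umulhi: the high 32 bits of the 64-bit product.
    quotient = ((dividend * multiplier) >> 32) >> shift_right
    remainder = dividend - quotient * divisor
    return quotient, remainder

d, m, s = fast_divmod_create(7)
assert fast_divmod(45, d, m, s) == (6, 3)  # one multiply + shifts, no hardware divide
```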
quack/gemm_act_sm90.py ADDED
@@ -0,0 +1,368 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Tuple, Optional, Callable
+ from dataclasses import dataclass
+
+ from torch import Tensor
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.nvgpu import warpgroup
+ import cutlass.utils.hopper_helpers as sm90_utils
+ from cutlass import Int32, Float32, Boolean, const_expr
+ import cutlass.torch as cutlass_torch
+
+ from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
+ from quack.dense_gemm_sm90 import GemmSm90
+ from quack.cute_dsl_utils import get_max_active_clusters
+ from quack.gemm_wrapper_utils import GemmWrapperBase
+ import quack.activation
+
+
+ class GemmActSm90(GemmSm90):
+     @dataclass
+     class EpilogueArguments(ArgumentsBase):
+         mPostAct: cute.Tensor
+         act_fn: cutlass.Constexpr[Optional[Callable]] = None
+         alpha: Optional[Float32] = None
+         beta: Optional[Float32] = None
+
+     @dataclass
+     class EpilogueParams(ParamsBase):
+         tma_atom_postact: cute.CopyAtom
+         mPostAct_mnl: cute.Tensor
+         epi_postact_smem_layout_staged: cute.ComposedLayout
+         act_fn: cutlass.Constexpr[Optional[Callable]] = None
+         alpha: Optional[Float32] = None
+         beta: Optional[Float32] = None
+
+     def epi_to_underlying_arguments(
+         self, args: EpilogueArguments, *, loc=None, ip=None
+     ) -> EpilogueParams:
+         self.postact_dtype = args.mPostAct.element_type
+         self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
+
+         self.tile_shape_postact_mn = self.tile_shape_mnk[:2]
+         self.epi_tile_postact = self.epi_tile
+         postact_major_mode_size = (
+             self.epi_tile_postact[1]
+             if self.postact_layout.is_n_major_c()
+             else self.epi_tile_postact[0]
+         )
+         postact_smem_layout_atom = warpgroup.make_smem_layout_atom(
+             sm90_utils.get_smem_layout_atom(
+                 self.postact_layout, self.postact_dtype, postact_major_mode_size
+             ),
+             self.postact_dtype,
+         )
+         epi_postact_smem_layout_staged = cute.tile_to_shape(
+             postact_smem_layout_atom,
+             cute.append(self.epi_tile_postact, self.epi_stage),
+             order=(0, 1, 2),
+         )
+         tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
+             args.mPostAct,
+             epi_postact_smem_layout_staged,
+             self.epi_tile_postact,
+             store_or_load="store",
+         )
+         return GemmActSm90.EpilogueParams(
+             tma_atom_postact,
+             tma_tensor_postact,
+             epi_postact_smem_layout_staged,
+             args.act_fn,
+             args.alpha,
+             args.beta,
+         )
+
+     @staticmethod
+     def epi_smem_bytes_per_stage(
+         args: EpilogueArguments,
+         tile_shape_mnk: Tuple[int, int, int],
+         epi_tile: Tuple[int, int],
+     ) -> int:
+         postact_dtype = args.mPostAct.element_type
+         postact_bytes_per_stage = cute.size(epi_tile) * (postact_dtype.width // 8)
+         return postact_bytes_per_stage
+
+     def epi_get_smem_struct(self, params: EpilogueParams):
+         @cute.struct
+         class EpiSharedStorage:
+             sPostAct: cute.struct.Align[
+                 cute.struct.MemRange[
+                     self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
+                 ],
+                 self.buffer_align_bytes,
+             ]
+
+         return EpiSharedStorage
+
+     def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
+         sPostAct = storage.epi.sPostAct.get_tensor(
+             params.epi_postact_smem_layout_staged.outer,
+             swizzle=params.epi_postact_smem_layout_staged.inner,
+         )
+         return (sPostAct,)
+
+     @cute.jit
+     def epilogue(
+         self,
+         params: EpilogueParams,
+         epi_smem_tensors: Tuple[cute.Tensor, ...],
+         epi_pipeline: cutlass.pipeline.PipelineAsync,
+         epi_read_state: cutlass.pipeline.PipelineState,
+         epi_producer_state: cutlass.pipeline.PipelineState,
+         tiled_mma: cute.TiledMma,
+         tRS_rAcc: cute.Tensor,
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor],
+         tiled_copy_r2s: cute.core.ThrCopy,
+         tRS_sD: cute.Tensor,
+         tiled_copy_s2r: Optional[cute.core.ThrCopy],
+         tSR_rC: Optional[cute.Tensor],
+         tSR_sC: Optional[cute.Tensor],
+         copy_D: Optional[Callable],
+         bSG_sD: cute.Tensor,
+         bSG_gD: cute.Tensor,
+         epi_load_g2s: Optional[Callable],
+         tile_coord_mnkl: cute.Coord,
+         cu_seqlens_m: Optional[cute.Tensor],
+         epilogue_barrier: cutlass.pipeline.NamedBarrier,
+         tile_scheduler,
+         tidx: Int32,
+         is_tma_warp: Boolean,
+     ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
+         has_C = const_expr(tRS_rC is not None)
+         has_D = const_expr(copy_D is not None)
+         assert cu_seqlens_m is None, "GemmActSm90 doesn't support varlen_m for now"
+
+         tma_atom_postact = params.tma_atom_postact
+         mPostAct_mnl = params.mPostAct_mnl
+         (sPostAct,) = epi_smem_tensors
+         tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
+         copy_atom_postact_r2s = sm90_utils.sm90_get_smem_store_op(
+             self.postact_layout, elem_ty_d=self.postact_dtype, elem_ty_acc=self.acc_dtype
+         )
+         tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
+         thr_copy_postact_r2s = tiled_copy_postact_r2s.get_slice(tidx)
+         tRS_sPostAct = thr_copy_postact_r2s.partition_D(sPostAct)
+         bSG_sPostAct, bSG_gPostAct = self.epilog_gmem_copy_and_partition(
+             tma_atom_postact,
+             mPostAct_mnl,
+             self.tile_shape_postact_mn,
+             self.epi_tile_postact,
+             sPostAct,
+             tile_coord_mnkl,
+             cu_seqlens_m,
+         )
+
+         # We iterate over epi tiles in the N dimension first before the M dimension
+         epi_tile_shape = cute.zipped_divide(
+             cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
+         ).shape[1]
+         epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
+         epi_tile_num = cute.size(epi_tile_shape)
+         num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
+
+         if const_expr(epi_load_g2s is not None):
+             for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
+                 epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
+
+         for epi_idx in cutlass.range_constexpr(epi_tile_num):
+             # Copy from acc to D registers
+             for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
+                 tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
+             if const_expr(has_C):
+                 epi_pipeline.consumer_wait(epi_read_state)
+                 cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
+                 # Fence to make sure shared memory read is visible to TMA load
+                 cute.arch.fence_proxy(
+                     cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                 )
+                 cute.arch.sync_warp()
+                 with cute.arch.elect_one():
+                     epi_pipeline.consumer_release(epi_read_state)
+                 epi_read_state.advance()
+             if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+                 epi_producer_state = epi_load_g2s(
+                     epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
+                 )
+             tRS_rPostAct = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
+             epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+             # Copy from D registers to shared memory
+             if const_expr(has_D):
+                 # Type conversion
+                 tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
+                 tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
+                 cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
+             cute.copy(
+                 tiled_copy_postact_r2s,
+                 tiled_copy_postact_r2s.retile(tRS_rPostAct),
+                 tRS_sPostAct[None, None, None, epi_buffer],
+             )
+             # Fence and barrier to make sure shared memory store is visible to TMA store
+             cute.arch.fence_proxy(
+                 cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+             )
+             epilogue_barrier.arrive_and_wait()
+             # Get the global memory coordinate for the current epi tile
+             gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
+             # Copy from shared memory to global memory
+             if is_tma_warp:
+                 if const_expr(has_D):
+                     copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
+                 cute.copy(
+                     tma_atom_postact,
+                     bSG_sPostAct[None, epi_buffer],
+                     bSG_gPostAct[None, gmem_coord],
+                 )
+                 cute.arch.cp_async_bulk_commit_group()
+                 cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
+             epilogue_barrier.arrive_and_wait()
+
+         return epi_read_state, epi_producer_state
+
+     @cute.jit
+     def epi_visit_acc_subtile(
+         self,
+         params: EpilogueParams,
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor] = None,
+     ) -> Optional[cute.Tensor]:
+         # Apply alpha scaling to accumulator if alpha is provided (not None)
+         if const_expr(params.alpha is not None):
+             tRS_rD.store(tRS_rD.load() * params.alpha)
+         # Apply C with beta scaling
+         if const_expr(tRS_rC is not None):
+             if const_expr(params.beta is None):
+                 # beta is None, default behavior: add C (beta=1.0)
+                 tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
+             else:
+                 tRS_rD.store(tRS_rD.load() + params.beta * tRS_rC.load().to(tRS_rD.element_type))
+         # Apply activation function if provided
+         # If we don't have .shape here, the compiler generates local stores and loads
+         if const_expr(params.act_fn is not None):
+             tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
+             for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
+                 tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
+         else:
+             tRS_rPostAct = tRS_rD
+         # Type conversion
+         tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
+         tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
+         return tRS_rPostAct_out
+
+
+ act_fn_map = {
+     None: None,
+     "relu": quack.activation.relu,
+     "relu_sq": quack.activation.relu_sq,
+     "gelu_tanh_approx": quack.activation.gelu_tanh_approx,
+ }
+
+
+ def gemm_act_sm90(
+     A: Tensor,  # (l, m, k)
+     B: Tensor,  # (l, n, k)
+     D: Optional[Tensor],  # (l, m, n)
+     C: Optional[Tensor],  # (l, m, n)
+     PostAct: Tensor,  # (l, m, n)
+     activation: Optional[str],
+     tile_M: int,
+     tile_N: int,
+     cluster_M: int,
+     cluster_N: int,
+     pingpong: bool = False,
+     persistent: bool = True,
+     alpha: float = 1.0,
+     beta: float = 1.0,
+ ) -> None:
+     tile_count_semaphore = None
+     assert activation in act_fn_map, f"Unsupported activation {activation}"
+     L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+         A, B, D, C, additional_tensors={"PostAct": PostAct}
+     )
+     GemmWrapperBase.permute_tensors(tensor_infos)
+     GemmWrapperBase.extract_dtypes(tensor_infos)
+     major_configs = {
+         "A": ("m", "k", "l"),
+         "B": ("n", "k", "l"),
+         "D": ("m", "n", "l"),
+         "C": ("m", "n", "l"),
+         "PostAct": ("m", "n", "l"),
+     }
+     GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+     acc_dtype = cutlass.Float32
+     tile_shape_mn = (tile_M, tile_N)
+     cluster_shape_mnk = (cluster_M, cluster_N, 1)
+     if not GemmActSm90.is_valid_dtypes(
+         tensor_infos["A"].dtype,
+         tensor_infos["B"].dtype,
+         acc_dtype,
+         tensor_infos["D"].dtype,
+         tensor_infos["A"].major,
+         tensor_infos["B"].major,
+     ):
+         raise TypeError("Skipping due to unsupported combination of types and majors")
+
+     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+     GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
+     act_fn = act_fn_map[activation]
+     epi_args = GemmActSm90.EpilogueArguments(
+         tensor_infos["PostAct"].cute_tensor,
+         act_fn,
+         alpha=Float32(alpha) if alpha != 1.0 else None,
+         beta=Float32(beta) if beta != 1.0 else None,
+     )
+     scheduler_args = GemmWrapperBase.create_scheduler_args(
+         max_active_clusters, tile_count_semaphore
+     )
+     current_stream = cutlass_torch.current_stream()
+     compile_key = GemmWrapperBase.get_compile_key(
+         tensor_infos,
+         activation,
+         tile_shape_mn,
+         cluster_shape_mnk,
+         pingpong,
+         persistent,
+         tile_count_semaphore is not None,
+         alpha != 1.0,
+         beta != 1.0,
+         key_tensor_names=("A", "B", "D", "PostAct", "C"),
+     )
+     cache = gemm_act_sm90.compile_cache
+     if compile_key not in cache:
+         gemm = GemmActSm90(
+             acc_dtype,
+             tensor_infos["A"].dtype,
+             tile_shape_mn,
+             cluster_shape_mnk,
+             pingpong=pingpong,
+             is_persistent=persistent,
+         )
+         cache[compile_key] = cute.compile(
+             gemm,
+             tensor_infos["A"].cute_tensor,
+             tensor_infos["B"].cute_tensor,
+             tensor_infos["D"].cute_tensor,
+             tensor_infos["C"].cute_tensor,
+             epi_args,
+             scheduler_args,
+             None,  # varlen_args
+             None,  # mAIdx
+             current_stream,
+         )
+     cache[compile_key](
+         tensor_infos["A"].cute_tensor,
+         tensor_infos["B"].cute_tensor,
+         tensor_infos["D"].cute_tensor,
+         tensor_infos["C"].cute_tensor,
+         epi_args,
+         scheduler_args,
+         None,
+         None,
+         current_stream,
+     )
+
+
+ gemm_act_sm90.compile_cache = {}
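The epilogue here effectively computes `PostAct = act_fn(alpha * acc + beta * C)` per subtile, with `D` optionally receiving the pre-activation value; `gemm_act_sm90` compiles one kernel per distinct key (dtypes, tile/cluster shape, pingpong, persistence, whether alpha/beta differ from 1.0) and replays it from `compile_cache` afterwards. A hypothetical call sketch, assuming an SM90 (Hopper) GPU and that bf16 with these layouts passes `is_valid_dtypes`; tile and cluster sizes below are illustrative, not tuned:

```python
import torch
from quack.gemm_act_sm90 import gemm_act_sm90

l, m, n, k = 1, 4096, 4096, 4096
A = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(l, n, k, device="cuda", dtype=torch.bfloat16)
D = torch.empty(l, m, n, device="cuda", dtype=torch.bfloat16)        # pre-activation result
PostAct = torch.empty(l, m, n, device="cuda", dtype=torch.bfloat16)  # act_fn applied to it

gemm_act_sm90(
    A, B, D, C=None, PostAct=PostAct,
    activation="relu",       # must be a key of act_fn_map
    tile_M=128, tile_N=192,
    cluster_M=2, cluster_N=1,
    pingpong=True,
)
```

The first call with a new key pays the `cute.compile` cost; later calls with the same shapes, dtypes, and config hit the cache.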
quack/gemm_config.py CHANGED
@@ -1,61 +1,69 @@
  # Copyright (C) 2025, Tri Dao.
  import itertools
- from typing import Optional
- from pydantic import BaseModel
+ from typing import Optional, List
+ from dataclasses import dataclass
 
 
- class GemmConfig(BaseModel, frozen=True):
-     tile_m: int = 256
-     tile_n: int = 128
+ @dataclass(frozen=True)
+ class GemmConfig:
+     tile_m: int = 128
+     tile_n: int = 192
+     pingpong: bool = True
      cluster_m: int = 2
      cluster_n: int = 1
      swap_ab: bool = False
-     pingpong: bool = False
-     raster_order: int = 2
-     max_swizzle_size: int = 1
+     # raster_order: int = 1
+     # max_swizzle_size: int = 8
 
 
  def get_all_configs(
-     epilogue: Optional[str],
-     tune_pingpong=True,
-     tune_raster_order=True,
- ) -> list[GemmConfig]:
+     epilogue: Optional[str] = None,
+     tune_coop: bool = True,
+     # tune_raster_order=True,
+ ) -> List[GemmConfig]:
      tile_n_vals = [128, 144, 160, 176, 192, 208]
-     tile_mn_vals = [(256, tile_n) for tile_n in tile_n_vals]
-     if epilogue in ["swiglu"]:
-         tile_mn_vals = [(m, n) for m, n in tile_mn_vals if n % 32 == 0]
-     cluster = [(1, 1), (1, 2), (2, 1)]
-     # cluster = [(1, 2), (2, 1)]
+     tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
+         (128, 224),
+         (128, 256),
+         # (192, 256),  # Getting IOT instruction (core dumped) in the bwd
+     ]
+     tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
+     if epilogue in ["gated"]:
+         tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
+         tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
+     elif epilogue in ["lse"]:
+         tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
+     tile_mn_vals = []
+     if tune_coop:
+         tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
+     tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
+     cluster = [(1, 2), (2, 1)]
+     # cluster = [(1, 1), (1, 2), (2, 1)]
      if epilogue in ["lse"]:
          cluster = [(1, 2), (2, 1)]
      swap_ab_vals = [False, True]
-     if epilogue in ["lse", "swiglu"]:
+     if epilogue in ["lse", "gated"]:
          swap_ab_vals = [False]
-     pingpong_vals = [False, True] if tune_pingpong else [False]
-     raster_swizzle = (
-         [(0, 1)]
-         if not tune_raster_order
-         else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
-     )
+     # raster_swizzle = (
+     #     [(0, 1)]
+     #     if not tune_raster_order
+     #     else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
+     # )
      return [
          GemmConfig(
-             tile_m=tile_m if not pingpong else 128,
+             tile_m=tile_m,
              tile_n=tile_n,
+             pingpong=pingpong,
              cluster_m=cluster_m,
              cluster_n=cluster_n,
              swap_ab=swap_ab,
-             pingpong=pingpong,
-             raster_order=raster_order,
-             max_swizzle_size=max_swizzle_size,
+             # raster_order=raster_order,
+             # max_swizzle_size=max_swizzle_size,
          )
-         for (tile_m, tile_n), (cluster_m, cluster_n), swap_ab, pingpong, (
-             raster_order,
-             max_swizzle_size,
-         ) in itertools.product(
+         for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
              tile_mn_vals,
              cluster,
              swap_ab_vals,
-             pingpong_vals,
-             raster_swizzle,
+             # raster_swizzle,
          )
      ]
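`get_all_configs` now enumerates `(tile_m, tile_n, pingpong)` triples (cooperative tiles only when `tune_coop` is set, pingpong tiles always) crossed with cluster shapes and `swap_ab`; raster order and swizzle size have been dropped from the search space. A sketch of how a caller might sweep the list, where `run_and_time` is a hypothetical benchmark helper supplied by the caller:

```python
from quack.gemm_config import get_all_configs

best_cfg, best_ms = None, float("inf")
for cfg in get_all_configs(epilogue=None, tune_coop=True):
    # Each GemmConfig now carries pingpong directly instead of the caller
    # rewriting tile_m to 128 for pingpong schedules.
    ms = run_and_time(  # hypothetical: benchmarks one config, returns milliseconds
        tile_M=cfg.tile_m, tile_N=cfg.tile_n,
        cluster_M=cfg.cluster_m, cluster_N=cfg.cluster_n,
        swap_ab=cfg.swap_ab, pingpong=cfg.pingpong,
    )
    if ms < best_ms:
        best_cfg, best_ms = cfg, ms
```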
quack/gemm_dact_sm90.py ADDED
@@ -0,0 +1,150 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Optional
+
+ from torch import Tensor
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import const_expr
+ import cutlass.torch as cutlass_torch
+
+ from quack.gemm_act_sm90 import GemmActSm90
+ from quack.cute_dsl_utils import get_max_active_clusters
+ from quack.gemm_wrapper_utils import GemmWrapperBase
+ import quack.activation
+
+
+ class GemmDActSm90(GemmActSm90):
+     # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
+     # and return 2 arguments (dx, out)
+     EpilogueArguments = GemmActSm90.EpilogueArguments
+     EpilogueParams = GemmActSm90.EpilogueParams
+
+     @cute.jit
+     def epi_visit_acc_subtile(
+         self,
+         params: EpilogueParams,
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor] = None,
+     ) -> Optional[cute.Tensor]:
+         assert tRS_rC is not None
+         tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
+         tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
+         # If we don't have .shape here, the compiler generates local stores and loads
+         if const_expr(params.act_fn is not None):
+             tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
+             for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
+                 tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
+         else:
+             tRS_rPostAct = tRS_rC_acc
+         # Type conversion
+         tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
+         tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
+         return tRS_rPostAct_out
+
+
+ dact_fn_map = {
+     None: None,
+     "relu": quack.activation.drelu,
+     "relu_sq": quack.activation.drelu_sq,
+     "gelu_tanh_approx": quack.activation.dgelu_tanh_approx,
+ }
+
+
+ def gemm_dact_sm90(
+     A: Tensor,  # (l, m, k)
+     B: Tensor,  # (l, n, k)
+     Out: Tensor,  # (l, m, n)
+     PreAct: Tensor,  # (l, m, n)
+     PostAct: Tensor,  # (l, m, n)
+     tile_count_semaphore: Optional[Tensor],  # (1,)
+     activation: Optional[str],
+     tile_M: int,
+     tile_N: int,
+     cluster_M: int,
+     cluster_N: int,
+     pingpong: bool = True,
+     persistent: bool = True,
+ ) -> None:
+     assert activation in dact_fn_map, f"Unsupported activation {activation}"
+     L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+         A, B, Out, PreAct, additional_tensors={"PostAct": PostAct}
+     )
+     GemmWrapperBase.permute_tensors(tensor_infos)
+     GemmWrapperBase.extract_dtypes(tensor_infos)
+     major_configs = {
+         "A": ("m", "k", "l"),
+         "B": ("n", "k", "l"),
+         "D": ("m", "n", "l"),
+         "C": ("m", "n", "l"),
+         "PostAct": ("m", "n", "l"),
+     }
+     GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+     acc_dtype = cutlass.Float32
+     tile_shape_mn = (tile_M, tile_N)
+     cluster_shape_mnk = (cluster_M, cluster_N, 1)
+     if not GemmDActSm90.is_valid_dtypes(
+         tensor_infos["A"].dtype,
+         tensor_infos["B"].dtype,
+         acc_dtype,
+         tensor_infos["D"].dtype,
+         tensor_infos["A"].major,
+         tensor_infos["B"].major,
+     ):
+         raise TypeError("Skipping due to unsupported combination of types and majors")
+
+     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+     GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
+     act_fn = dact_fn_map[activation]
+     epi_args = GemmDActSm90.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
+     scheduler_args = GemmWrapperBase.create_scheduler_args(
+         max_active_clusters, tile_count_semaphore
+     )
+     current_stream = cutlass_torch.current_stream()
+     compile_key = GemmWrapperBase.get_compile_key(
+         tensor_infos,
+         activation,
+         tile_shape_mn,
+         cluster_shape_mnk,
+         pingpong,
+         persistent,
+         tile_count_semaphore is not None,
+         key_tensor_names=("A", "B", "D", "PostAct", "C"),
+     )
+     cache = gemm_dact_sm90.compile_cache
+     if compile_key not in cache:
+         gemm = GemmDActSm90(
+             acc_dtype,
+             tensor_infos["A"].dtype,
+             tile_shape_mn,
+             cluster_shape_mnk,
+             pingpong=pingpong,
+             is_persistent=persistent,
+         )
+         cache[compile_key] = cute.compile(
+             gemm,
+             tensor_infos["A"].cute_tensor,
+             tensor_infos["B"].cute_tensor,
+             tensor_infos["D"].cute_tensor,
+             tensor_infos["C"].cute_tensor,
+             epi_args,
+             scheduler_args,
+             None,  # varlen_args
+             None,  # mAIdx
+             current_stream,
+         )
+     cache[compile_key](
+         tensor_infos["A"].cute_tensor,
+         tensor_infos["B"].cute_tensor,
+         tensor_infos["D"].cute_tensor,
+         tensor_infos["C"].cute_tensor,
+         epi_args,
+         scheduler_args,
+         None,
+         None,
+         current_stream,
+     )
+
+
+ gemm_dact_sm90.compile_cache = {}
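Per the class comment, the backward activation here is a two-in/two-out function, `act_fn(x, dout) -> (dx, out)`: `PreAct` is bound as `C` (the saved pre-activation `x`), the GEMM accumulator plays the role of `dout`, `Out` receives `dx`, and `PostAct` receives the recomputed `out`. A plain-PyTorch reference for the ReLU case (a hedged sketch of the contract; `quack.activation.drelu` is the fused element-wise device version):

```python
import torch

def drelu_reference(preact: torch.Tensor, dout: torch.Tensor):
    # out: recomputed forward activation; dx: upstream gradient gated by the ReLU mask
    out = torch.relu(preact)
    dx = torch.where(preact > 0, dout, torch.zeros_like(dout))
    return dx, out
```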