quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/gemm_act.py ADDED
@@ -0,0 +1,510 @@
1
+ # Copyright (c) 2025, Wentao Guo, Tri Dao.
2
+ from typing import Tuple, Optional, Callable
3
+ from functools import partial
4
+ from dataclasses import dataclass
5
+
6
+ from torch import Tensor
7
+
8
+ import cutlass
9
+ import cutlass.cute as cute
10
+ import cutlass.utils.hopper_helpers as sm90_utils_og
11
+ import cutlass.utils.blackwell_helpers as sm100_utils
12
+ from cutlass import Int32, Float32, Boolean, const_expr
13
+ from cutlass.cutlass_dsl import if_generate
14
+ import cutlass.torch as cutlass_torch
15
+ from cutlass.cute.runtime import from_dlpack
16
+
17
+ from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
18
+ from quack.varlen_utils import VarlenManager
19
+ from quack.gemm_sm90 import GemmSm90
20
+ from quack.gemm_sm100 import GemmSm100
21
+ from quack.gemm_default_epi import GemmDefaultEpiMixin
22
+ from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
23
+ from quack.gemm_wrapper_utils import GemmWrapperBase
24
+ import quack.sm90_utils as sm90_utils
25
+ import quack.copy_utils as copy_utils
26
+ import quack.activation
27
+
28
+
29
+ class GemmActMixin(GemmDefaultEpiMixin):
30
+ num_epi_tensormaps: int = 1
31
+
32
+ @dataclass
33
+ class EpilogueArguments(ArgumentsBase):
34
+ mPostAct: cute.Tensor
35
+ act_fn: cutlass.Constexpr[Optional[Callable]] = None
36
+ alpha: Optional[Float32 | cute.Tensor] = None
37
+ beta: Optional[Float32 | cute.Tensor] = None
38
+ mRowVecBroadcast: Optional[cute.Tensor] = None
39
+ mColVecBroadcast: Optional[cute.Tensor] = None
40
+
41
+ @dataclass
42
+ class EpilogueParams(ParamsBase):
43
+ tma_atom_postact: cute.CopyAtom
44
+ mPostAct_mnl: cute.Tensor
45
+ epi_postact_smem_layout_staged: cute.ComposedLayout
46
+ epi_tile_postact: cute.Tile
47
+ act_fn: cutlass.Constexpr[Optional[Callable]] = None
48
+ alpha: Optional[Float32 | cute.Tensor] = None
49
+ beta: Optional[Float32 | cute.Tensor] = None
50
+ mRowVecBroadcast: Optional[cute.Tensor] = None
51
+ mColVecBroadcast: Optional[cute.Tensor] = None
52
+
53
+ def epi_to_underlying_arguments(
54
+ self, args: EpilogueArguments, *, loc=None, ip=None
55
+ ) -> EpilogueParams:
56
+ self.postact_dtype = args.mPostAct.element_type
57
+ self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
58
+
59
+ self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
60
+ epi_tile_postact = self.epi_tile
61
+ utils_cls = sm100_utils if self.arch == 100 else sm90_utils
62
+ epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
63
+ self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
64
+ )
65
+ tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
66
+ args.mPostAct,
67
+ epi_postact_smem_layout_staged,
68
+ epi_tile_postact,
69
+ op_type="store",
70
+ )
71
+ # Assume all strides are divisible by 32 bits except the last stride
72
+ new_stride = lambda t: tuple(
73
+ cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
74
+ for s in t.stride
75
+ )
76
+ mRowVecBroadcast, mColVecBroadcast = [
77
+ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
78
+ if t is not None
79
+ else None
80
+ for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
81
+ ]
82
+ return self.EpilogueParams(
83
+ tma_atom_postact,
84
+ tma_tensor_postact,
85
+ epi_postact_smem_layout_staged,
86
+ epi_tile_postact,
87
+ args.act_fn,
88
+ alpha=args.alpha,
89
+ beta=args.beta,
90
+ mRowVecBroadcast=mRowVecBroadcast,
91
+ mColVecBroadcast=mColVecBroadcast,
92
+ )
93
+
94
+ def epi_get_tma_atoms(
95
+ self, params: EpilogueParams, *, loc=None, ip=None
96
+ ) -> list[cute.CopyAtom]:
97
+ return [params.tma_atom_postact]
98
+
99
+ def epi_get_tensormap_update_shapes_orders(
100
+ self,
101
+ params: EpilogueParams,
102
+ cu_seqlens_m: Optional[cute.Tensor],
103
+ batch_idx: Int32,
104
+ *,
105
+ loc=None,
106
+ ip=None,
107
+ ) -> tuple[list[Int32], list[int]]:
108
+ shapes = [cu_seqlens_m[batch_idx + 1] if cu_seqlens_m is not None else None]
109
+ orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1]
110
+ return shapes, orders
111
+
112
+ @staticmethod
113
+ def epi_smem_bytes_per_stage(
114
+ args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile
115
+ ) -> int:
116
+ postact_dtype = args.mPostAct.element_type
117
+ postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8)
118
+ rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
119
+ args, cta_tile_shape_mnk, epi_tile
120
+ )
121
+ return postact_bytes_per_stage + rowvec_colvec_bytes
122
+
123
+ def epi_get_smem_struct(self, params: EpilogueParams):
124
+ row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
125
+ col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
126
+ row_vec_dtype = (
127
+ params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
128
+ )
129
+ col_vec_dtype = (
130
+ params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
131
+ )
132
+
133
+ @cute.struct
134
+ class EpiSharedStorage:
135
+ sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
136
+ sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
137
+ sPostAct: cute.struct.Align[
138
+ cute.struct.MemRange[
139
+ self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
140
+ ],
141
+ self.buffer_align_bytes,
142
+ ]
143
+
144
+ return EpiSharedStorage
145
+
146
+ def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
147
+ sRowVec, sColVec = super().epi_get_smem_tensors(params, storage)
148
+ sPostAct = storage.epi.sPostAct.get_tensor(
149
+ params.epi_postact_smem_layout_staged.outer,
150
+ swizzle=params.epi_postact_smem_layout_staged.inner,
151
+ )
152
+ return (sRowVec, sColVec, sPostAct)
153
+
154
+ @cute.jit
155
+ def epilogue(
156
+ self,
157
+ params: EpilogueParams,
158
+ epi_smem_tensors: Tuple[cute.Tensor, ...],
159
+ tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
160
+ epi_pipeline: cutlass.pipeline.PipelineAsync,
161
+ epi_store_pipeline: cutlass.pipeline.PipelineAsync,
162
+ epi_read_state: cutlass.pipeline.PipelineState,
163
+ epi_producer_state: cutlass.pipeline.PipelineState,
164
+ epi_tile: cute.Tile,
165
+ load_acc_subtile: Callable,
166
+ tRS_rD: cute.Tensor,
167
+ tRS_rC: Optional[cute.Tensor],
168
+ tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
169
+ tiled_copy_r2s: cute.TiledCopy,
170
+ tRS_sD: cute.Tensor,
171
+ tiled_copy_s2r: Optional[cute.TiledCopy],
172
+ tSR_rC: Optional[cute.Tensor],
173
+ tSR_sC: Optional[cute.Tensor],
174
+ copy_D: Optional[Callable],
175
+ copy_C: Optional[Callable],
176
+ tile_coord_mnkl: cute.Coord,
177
+ varlen_manager: VarlenManager,
178
+ epilogue_barrier: cutlass.pipeline.NamedBarrier,
179
+ tile_scheduler,
180
+ tidx: Int32,
181
+ is_tma_warp: Boolean,
182
+ ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
183
+ has_C = const_expr(tRS_rC is not None)
184
+ has_D = const_expr(copy_D is not None)
185
+
186
+ tma_atom_postact = params.tma_atom_postact
187
+ mPostAct_mnl = params.mPostAct_mnl
188
+ sRowVec, sColVec, sPostAct = epi_smem_tensors
189
+ get_smem_store_op = (
190
+ partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
191
+ if self.arch == 100
192
+ else sm90_utils_og.sm90_get_smem_store_op
193
+ )
194
+ copy_atom_postact_r2s = get_smem_store_op(
195
+ self.postact_layout, self.postact_dtype, self.acc_dtype
196
+ )
197
+ # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
198
+ # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
199
+ tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
200
+ tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
201
+ (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
202
+ batch_idx = tile_coord_mnkl[3]
203
+ copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
204
+ tma_atom_postact,
205
+ varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
206
+ self.cta_tile_shape_postact_mn,
207
+ params.epi_tile_postact,
208
+ sPostAct,
209
+ tile_coord_mnkl,
210
+ tma_desc_ptr=tma_desc_postact_ptr,
211
+ )
212
+
213
+ # We iterate over epi tiles in the N dimension first before the M dimension
214
+ epi_tile_shape = cute.zipped_divide(
215
+ cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
216
+ ).shape[1]
217
+ epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
218
+ epi_tile_num = cute.size(epi_tile_shape)
219
+ num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
220
+
221
+ epi_tensors = self.epi_begin(
222
+ params,
223
+ epi_smem_tensors,
224
+ epi_tile,
225
+ tiled_copy_t2r,
226
+ tiled_copy_r2s,
227
+ tile_coord_mnkl,
228
+ varlen_manager,
229
+ epilogue_barrier,
230
+ tidx,
231
+ )
232
+
233
+ if const_expr(copy_C is not None):
234
+ for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
235
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
236
+ if is_tma_warp:
237
+ epi_pipeline.producer_acquire(epi_producer_state)
238
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
239
+ epi_pipeline.producer_commit(epi_producer_state)
240
+ epi_producer_state.advance()
241
+
242
+ def tma_store_fn(src_idx, dst_idx):
243
+ # Fence and barrier to make sure shared memory store is visible to TMA store
244
+ cute.arch.fence_proxy(
245
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
246
+ )
247
+ epilogue_barrier.arrive_and_wait()
248
+ # Copy from shared memory to global memory
249
+ if is_tma_warp:
250
+ if const_expr(has_D):
251
+ copy_D(src_idx=src_idx, dst_idx=dst_idx)
252
+ copy_postact(src_idx=src_idx, dst_idx=dst_idx)
253
+ # Can't use if statement here, epi_store_pipeline object isn't captured somehow
254
+ if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
255
+ if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
256
+ epilogue_barrier.arrive_and_wait()
257
+
258
+ delay_tma_store = True
259
+
260
+ src_idx_prev, dst_idx_prev = None, None
261
+ for epi_idx in cutlass.range_constexpr(epi_tile_num):
262
+ # The global memory coordinate for the current epi tile
263
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
264
+ # Copy from acc to D registers
265
+ load_acc_subtile(tRS_rD, epi_idx)
266
+ epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
267
+ if const_expr(has_C):
268
+ epi_pipeline.consumer_wait(epi_read_state)
269
+ cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
270
+ # Fence to make sure shared memory read is visible to TMA load
271
+ cute.arch.fence_proxy(
272
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
273
+ )
274
+ cute.arch.sync_warp()
275
+ with cute.arch.elect_one():
276
+ epi_pipeline.consumer_release(epi_read_state)
277
+ epi_read_state.advance()
278
+ if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
279
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
280
+ if is_tma_warp:
281
+ epi_pipeline.producer_acquire(epi_producer_state)
282
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
283
+ epi_pipeline.producer_commit(epi_producer_state)
284
+ epi_producer_state.advance()
285
+ tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
286
+ epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
287
+ if const_expr(delay_tma_store):
288
+ if const_expr(epi_idx > 0):
289
+ tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
290
+ src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
291
+ # Copy from D registers to shared memory
292
+ if const_expr(has_D):
293
+ copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
294
+ cute.copy(
295
+ tiled_copy_postact_r2s,
296
+ tiled_copy_postact_r2s.retile(tRS_rPostAct),
297
+ tRS_sPostAct[None, None, None, epi_buffer],
298
+ )
299
+ if const_expr(not delay_tma_store):
300
+ tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
301
+
302
+ if const_expr(delay_tma_store):
303
+ tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
304
+
305
+ self.epi_end(
306
+ params,
307
+ epi_tensors,
308
+ epi_tile,
309
+ tiled_copy_t2r,
310
+ tiled_copy_r2s,
311
+ tile_coord_mnkl,
312
+ varlen_manager,
313
+ tidx,
314
+ )
315
+
316
+ return epi_read_state, epi_producer_state
317
+
318
+ @cute.jit
319
+ def epi_visit_subtile(
320
+ self,
321
+ params: EpilogueParams,
322
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
323
+ tRS_rD: cute.Tensor,
324
+ tRS_rC: Optional[cute.Tensor] = None,
325
+ ) -> Optional[cute.Tensor]:
326
+ GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
327
+ # Apply activation function if provided
328
+ # If we don't have .shape here, the compiler generates local stores and loads
329
+ if const_expr(params.act_fn is not None):
330
+ tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
331
+ if const_expr(self.arch < 100):
332
+ for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
333
+ tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
334
+ else:
335
+ for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
336
+ tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
337
+ (tRS_rD[2 * i], tRS_rD[2 * i + 1])
338
+ )
339
+ else:
340
+ tRS_rPostAct = tRS_rD
341
+ # Type conversion
342
+ tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
343
+ tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
344
+ return tRS_rPostAct_out
345
+
346
+
347
+ class GemmActSm90(GemmActMixin, GemmSm90):
348
+ pass
349
+
350
+
351
+ class GemmActSm100(GemmActMixin, GemmSm100):
352
+ pass
353
+
354
+
355
+ act_fn_map = {
356
+ None: None,
357
+ "relu": quack.activation.relu,
358
+ "relu_sq": quack.activation.relu_sq,
359
+ "gelu_tanh_approx": quack.activation.gelu_tanh_approx,
360
+ }
361
+
362
+
363
+ def gemm_act(
364
+ A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
365
+ B: Tensor, # (l, n, k)
366
+ D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
367
+ C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
368
+ PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
369
+ tile_count_semaphore: Optional[Tensor], # (1,)
370
+ activation: Optional[str],
371
+ tile_M: int,
372
+ tile_N: int,
373
+ cluster_M: int,
374
+ cluster_N: int,
375
+ pingpong: bool = False,
376
+ persistent: bool = True,
377
+ max_swizzle_size: int = 8,
378
+ rowvec_bias: Optional[Tensor] = None, # (l, n)
379
+ colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
380
+ cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
381
+ A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
382
+ ) -> None:
383
+ if cu_seqlens_m is not None:
384
+ assert persistent, "varlen_m requires persistent=True"
385
+ assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
386
+ if D is not None:
387
+ assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
388
+ assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
389
+ gather_A = A_idx is not None
390
+ if gather_A:
391
+ assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
392
+ assert cluster_N == 1, "gather_A requires cluster_N=1"
393
+ assert activation in act_fn_map, f"Unsupported activation {activation}"
394
+
395
+ L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
396
+ A, B, D, C, additional_tensors={"PostAct": PostAct}, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
397
+ )
398
+ GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
399
+ GemmWrapperBase.extract_dtypes(tensor_infos)
400
+ major_configs = {
401
+ "A": ("m", "k", "l"),
402
+ "B": ("n", "k", "l"),
403
+ "D": ("m", "n", "l"),
404
+ "C": ("m", "n", "l"),
405
+ "PostAct": ("m", "n", "l"),
406
+ }
407
+ GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
408
+
409
+ device_capacity = get_device_capacity(A.device)
410
+ assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
411
+ GemmCls = GemmActSm100 if device_capacity[0] > 9 else GemmActSm90
412
+
413
+ acc_dtype = Float32
414
+ tile_shape_mn = (tile_M, tile_N)
415
+ cluster_shape_mnk = (cluster_M, cluster_N, 1)
416
+ if not GemmCls.is_valid_dtypes(
417
+ tensor_infos["A"].dtype,
418
+ tensor_infos["B"].dtype,
419
+ acc_dtype,
420
+ tensor_infos["D"].dtype,
421
+ tensor_infos["A"].major,
422
+ tensor_infos["B"].major,
423
+ ):
424
+ raise TypeError("Skipping due to unsupported combination of types and majors")
425
+
426
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
427
+ GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
428
+ act_fn = act_fn_map[activation]
429
+ epi_args = GemmCls.EpilogueArguments(
430
+ tensor_infos["PostAct"].cute_tensor,
431
+ act_fn,
432
+ mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
433
+ leading_dim=1
434
+ )
435
+ if rowvec_bias is not None
436
+ else None,
437
+ mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
438
+ leading_dim=1 if cu_seqlens_m is None else 0
439
+ )
440
+ if colvec_bias is not None
441
+ else None,
442
+ )
443
+ scheduler_args = GemmWrapperBase.create_scheduler_args(
444
+ max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
445
+ )
446
+
447
+ # Create varlen arguments if needed (assumes persistent=True when varlen_m)
448
+ varlen_args = GemmWrapperBase.create_varlen_args(
449
+ cu_seqlens_m,
450
+ None, # cu_seqlens_k
451
+ A_idx,
452
+ max_active_clusters,
453
+ cluster_shape_mnk,
454
+ tensor_infos,
455
+ GemmCls.num_epi_tensormaps,
456
+ pingpong,
457
+ )
458
+
459
+ current_stream = cutlass_torch.current_stream()
460
+ compile_key = GemmWrapperBase.get_compile_key(
461
+ tensor_infos,
462
+ activation,
463
+ tile_shape_mn,
464
+ cluster_shape_mnk,
465
+ pingpong,
466
+ persistent,
467
+ tile_count_semaphore is not None,
468
+ device_capacity,
469
+ max_swizzle_size,
470
+ rowvec_bias.dtype if rowvec_bias is not None else None,
471
+ colvec_bias.dtype if colvec_bias is not None else None,
472
+ cu_seqlens_m is not None,
473
+ A_idx is not None,
474
+ key_tensor_names=("A", "B", "D", "PostAct", "C"),
475
+ )
476
+ cache = gemm_act.compile_cache
477
+ if compile_key not in cache:
478
+ if device_capacity[0] == 9:
479
+ GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
480
+ gemm_obj = GemmCls(
481
+ acc_dtype,
482
+ tensor_infos["A"].dtype,
483
+ tile_shape_mn,
484
+ cluster_shape_mnk,
485
+ gather_A=gather_A,
486
+ )
487
+ cache[compile_key] = cute.compile(
488
+ gemm_obj,
489
+ tensor_infos["A"].cute_tensor,
490
+ tensor_infos["B"].cute_tensor,
491
+ tensor_infos["D"].cute_tensor,
492
+ tensor_infos["C"].cute_tensor,
493
+ epi_args,
494
+ scheduler_args,
495
+ varlen_args,
496
+ current_stream,
497
+ )
498
+ cache[compile_key](
499
+ tensor_infos["A"].cute_tensor,
500
+ tensor_infos["B"].cute_tensor,
501
+ tensor_infos["D"].cute_tensor,
502
+ tensor_infos["C"].cute_tensor,
503
+ epi_args,
504
+ scheduler_args,
505
+ varlen_args,
506
+ current_stream,
507
+ )
508
+
509
+
510
+ gemm_act.compile_cache = {}
quack/gemm_config.py CHANGED
@@ -1,6 +1,7 @@
1
1
  # Copyright (C) 2025, Fri Dao.
2
2
  import itertools
3
- from typing import Optional, List
3
+ from typing import Optional, List, Literal
4
+ from functools import partial
4
5
  from dataclasses import dataclass
5
6
 
6
7
 
@@ -13,57 +14,82 @@ class GemmConfig:
13
14
  cluster_n: int = 1
14
15
  swap_ab: bool = False
15
16
  # raster_order: int = 1
16
- # max_swizzle_size: int = 8
17
+ max_swizzle_size: int = 8
17
18
 
18
19
 
19
20
  def get_all_configs(
21
+ device_capacity: Literal[9, 10] = 9,
20
22
  epilogue: Optional[str] = None,
21
23
  tune_coop: bool = True,
22
24
  # tune_raster_order=True,
23
25
  ) -> List[GemmConfig]:
24
- tile_n_vals = [128, 144, 160, 176, 192, 208]
25
- tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
26
- (128, 224),
27
- (128, 256),
28
- # (192, 256), # Getting IOT instruction (core dumped) in the bwd
29
- ]
30
- tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
31
- if epilogue in ["gated"]:
32
- tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
33
- tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
34
- elif epilogue in ["lse"]:
35
- tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
36
- tile_mn_vals = []
37
- if tune_coop:
38
- tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
39
- tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
40
- cluster = [(1, 2), (2, 1)]
41
- # cluster = [(1, 1), (1, 2), (2, 1)]
42
- if epilogue in ["lse"]:
26
+ assert device_capacity in [9, 10]
27
+ if device_capacity == 9:
28
+ tile_n_vals = [128, 144, 160, 176, 192, 208]
29
+ tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
30
+ (128, 224),
31
+ (128, 256),
32
+ # (192, 256), # Getting IOT instruction (core dumped) in the bwd
33
+ ]
34
+ tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
35
+ if epilogue in ["gated"]:
36
+ tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
37
+ tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
38
+ elif epilogue in ["lse"]:
39
+ tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
40
+ tile_mn_vals = []
41
+ if tune_coop:
42
+ tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
43
+ tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
43
44
  cluster = [(1, 2), (2, 1)]
44
- swap_ab_vals = [False, True]
45
- if epilogue in ["lse", "gated"]:
46
- swap_ab_vals = [False]
47
- # raster_swizzle = (
48
- # [(0, 1)]
49
- # if not tune_raster_order
50
- # else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
51
- # )
52
- return [
53
- GemmConfig(
54
- tile_m=tile_m,
55
- tile_n=tile_n,
56
- pingpong=pingpong,
57
- cluster_m=cluster_m,
58
- cluster_n=cluster_n,
59
- swap_ab=swap_ab,
60
- # raster_order=raster_order,
61
- # max_swizzle_size=max_swizzle_size,
45
+ # cluster = [(1, 1), (1, 2), (2, 1)]
46
+ if epilogue in ["lse"]:
47
+ cluster = [(1, 2), (2, 1)]
48
+ swap_ab_vals = [False, True]
49
+ if epilogue in ["lse", "gated"]:
50
+ swap_ab_vals = [False]
51
+ # raster_swizzle = (
52
+ # [(0, 1)]
53
+ # if not tune_raster_order
54
+ # else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
55
+ # )
56
+ return [
57
+ GemmConfig(
58
+ tile_m=tile_m,
59
+ tile_n=tile_n,
60
+ pingpong=pingpong,
61
+ cluster_m=cluster_m,
62
+ cluster_n=cluster_n,
63
+ swap_ab=swap_ab,
64
+ # raster_order=raster_order,
65
+ # max_swizzle_size=max_swizzle_size,
66
+ )
67
+ for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
68
+ tile_mn_vals,
69
+ cluster,
70
+ swap_ab_vals,
71
+ # raster_swizzle,
72
+ )
73
+ ]
74
+ elif device_capacity == 10:
75
+ tile_n_vals = [128, 160, 192, 224, 256]
76
+ tile_n_64_vals = [128, 192, 256]
77
+ tile_mn_cluster_vals = (
78
+ [(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
79
+ # + [(128, tile_n, (2, 1)) for tile_n in tile_n_64_vals]
80
+ + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
81
+ + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
62
82
  )
63
- for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
64
- tile_mn_vals,
65
- cluster,
66
- swap_ab_vals,
67
- # raster_swizzle,
68
- )
69
- ]
83
+ swap_ab_vals = [False, True]
84
+ if epilogue in ["lse", "gated"]:
85
+ swap_ab_vals = [False]
86
+ max_swizzle_size_vals = [4, 8, 16]
87
+ GemmConfigCls = partial(GemmConfig, pingpong=False) # There's no pingpong on Sm100
88
+ return [
89
+ GemmConfigCls(
90
+ tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms
91
+ )
92
+ for (m, n, (cm, cn)), sab, ms in itertools.product(
93
+ tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals
94
+ )
95
+ ]