quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,44 @@
1
+ quack/__init__.py,sha256=iM_lvTpHS-Yxfxm8YP4MMfuP9esJpxI8karP2Dw7sFg,203
2
+ quack/activation.py,sha256=-lZgojraqdyLjOzgOXBehoVeRBhBq30UX7kOkXsCpGI,20855
3
+ quack/autotuner.py,sha256=atw0ntedi22RPwSdjWOoge4S56S8VFvRocJQcYhpAlo,13454
4
+ quack/broadcast_utils.py,sha256=X5vWg2RtIIWU9Z7nEUW6m0EP0Cfd9XtCKxp4tSyp4Mg,1283
5
+ quack/compile_utils.py,sha256=qJ3oTsDlbAiddrJHtEO7LPYVqn_s-neNfiw-_KvfXZU,591
6
+ quack/copy_utils.py,sha256=J1Hcw18iNHHpOP2wNFhF8Lz16NEmXtoQMu59mmLrRCs,18761
7
+ quack/cross_entropy.py,sha256=w6fjHC_vXt5ji2KfoLrSOdAvpLrQszrYU9rmRij2yY8,24899
8
+ quack/cute_dsl_utils.py,sha256=4uQx5aYDG9UvVzbWwJTjjJLrnoympz70_CD8b37FQWo,3854
9
+ quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
10
+ quack/gemm.py,sha256=8V23MPq49QbV3csv-_AxjfE9qf8R3NIqFK9Q9db6t2c,7417
11
+ quack/gemm_act.py,sha256=Y8HJKfw3tCoFKecwhwhd5xpXd9jCQCGZT_V2xXf-CnU,20823
12
+ quack/gemm_config.py,sha256=94o3g9x7H0wi7aBbsb7H67H8nSzTurwL2zgvKDtQUas,3575
13
+ quack/gemm_dact.py,sha256=l__UhCrFbPjD9a1TAVgP7_C7p5lLfX5DkRcM6z0ofOw,7789
14
+ quack/gemm_default_epi.py,sha256=6qO8Ovtcw8sQQ_kXTBTTQ5IHh1lS6RBCGZG0lgLHNrs,11916
15
+ quack/gemm_interface.py,sha256=AF5PYTNgEHjb3MNXcNvvEpOcShAHtak0Xu12l1zrOAw,44804
16
+ quack/gemm_sm100.py,sha256=U9jmzpST_d1W6CBFf1ZHhTtr0K8hENCsUz7dXvHaMZc,122344
17
+ quack/gemm_sm90.py,sha256=u-Q3fN6DPm1fEdz0LcMecMbGTBcRunUCWopufwO8cHU,92015
18
+ quack/gemm_symmetric.py,sha256=mqx7wgOCY6Dh9hjL6gR9PBstMD476GhpA_NkGeaEtik,13349
19
+ quack/gemm_wrapper_utils.py,sha256=EaPyR3Lq19z_RkdB2_xxRj0IPSJMgyfpkrTXyvY3B6M,12775
20
+ quack/layout_utils.py,sha256=QjFFlvDcLiyGGfA2FKWKI75twHIkOJ2AotE0cIpBAlI,11923
21
+ quack/linear.py,sha256=mhN2A98w7H7X4MS63XCCK3gpOm1eS8H7a4WO9ovkt5U,9791
22
+ quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
23
+ quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
24
+ quack/pipeline.py,sha256=mMdIlpUaHdRDOkvQzgKdCdJydJq6C2eYrny5Bui4KFs,11311
25
+ quack/reduce.py,sha256=ySKT2xh1_pIlbJX29BPmwH6yJ7MxIrRZyxHIPPYVpm0,12698
26
+ quack/reduction_base.py,sha256=QqlPs5L2VCxwDrO4CHPq-KY6f_BAYRbvsR6k81LPzTU,3180
27
+ quack/rmsnorm.py,sha256=esy18s5JtT7KBPRPhWf_anLRTrtromwqeJmg2yzOm60,44678
28
+ quack/sm100_utils.py,sha256=-p5qj3Wi9n4WDLy2sl-fApYpGp5rH3JvZQb712OTxPs,1901
29
+ quack/sm90_utils.py,sha256=hg8qq7S8NODZlUSaxNpdZcsnxcR0jM921rMn1VmBo7o,4278
30
+ quack/softmax.py,sha256=ZqeVbnGfzwkro1LfWBHagbS7B7ug7b9SLZWuGx_Y3Kc,14367
31
+ quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
32
+ quack/tile_scheduler.py,sha256=vbKq0xp94eII0uJ63yY_3sgvJkQI7Irc8y1OttO6cRA,42514
33
+ quack/topk.py,sha256=43xHpRGbwZCSRsulmfrG4WA_r2eLHc3sniaUFU7wn-o,22522
34
+ quack/utils.py,sha256=WIttE1iiwyPIwR1NpaeO26Pn9YkZb361TDxFTUDH-IE,7354
35
+ quack/varlen_utils.py,sha256=SOYkomxX2FoqjYlybg99CqNhS9IARM6F9ba2AkIVvT4,15811
36
+ quack/sort/bitonic_sort.py,sha256=VJPVjPulW_jEr3myBE7AiBYGtsc5T9FEy3sjXFukF7s,4831
37
+ quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
38
+ quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
39
+ quack/sort/utils.py,sha256=RbubEY1GcEpsjiz_6o5o2WB47IeMOzaajW6Jis0s444,1059
40
+ quack_kernels-0.2.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
41
+ quack_kernels-0.2.3.dist-info/METADATA,sha256=-WFp4n_2_bB8KMrDsO2AStm5bx4Av8gZE2wWeEEfcwQ,361
42
+ quack_kernels-0.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
43
+ quack_kernels-0.2.3.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
44
+ quack_kernels-0.2.3.dist-info/RECORD,,
quack/gemm_act_sm90.py DELETED
@@ -1,368 +0,0 @@
1
- # Copyright (c) 2025, Tri Dao.
2
- from typing import Tuple, Optional, Callable
3
- from dataclasses import dataclass
4
-
5
- from torch import Tensor
6
-
7
- import cutlass
8
- import cutlass.cute as cute
9
- from cutlass.cute.nvgpu import warpgroup
10
- import cutlass.utils.hopper_helpers as sm90_utils
11
- from cutlass import Int32, Float32, Boolean, const_expr
12
- import cutlass.torch as cutlass_torch
13
-
14
- from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
15
- from quack.dense_gemm_sm90 import GemmSm90
16
- from quack.cute_dsl_utils import get_max_active_clusters
17
- from quack.gemm_wrapper_utils import GemmWrapperBase
18
- import quack.activation
19
-
20
-
21
- class GemmActSm90(GemmSm90):
22
- @dataclass
23
- class EpilogueArguments(ArgumentsBase):
24
- mPostAct: cute.Tensor
25
- act_fn: cutlass.Constexpr[Optional[Callable]] = None
26
- alpha: Optional[Float32] = None
27
- beta: Optional[Float32] = None
28
-
29
- @dataclass
30
- class EpilogueParams(ParamsBase):
31
- tma_atom_postact: cute.CopyAtom
32
- mPostAct_mnl: cute.Tensor
33
- epi_postact_smem_layout_staged: cute.ComposedLayout
34
- act_fn: cutlass.Constexpr[Optional[Callable]] = None
35
- alpha: Optional[Float32] = None
36
- beta: Optional[Float32] = None
37
-
38
- def epi_to_underlying_arguments(
39
- self, args: EpilogueArguments, *, loc=None, ip=None
40
- ) -> EpilogueParams:
41
- self.postact_dtype = args.mPostAct.element_type
42
- self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
43
-
44
- self.tile_shape_postact_mn = self.tile_shape_mnk[:2]
45
- self.epi_tile_postact = self.epi_tile
46
- postact_major_mode_size = (
47
- self.epi_tile_postact[1]
48
- if self.postact_layout.is_n_major_c()
49
- else self.epi_tile_postact[0]
50
- )
51
- postact_smem_layout_atom = warpgroup.make_smem_layout_atom(
52
- sm90_utils.get_smem_layout_atom(
53
- self.postact_layout, self.postact_dtype, postact_major_mode_size
54
- ),
55
- self.postact_dtype,
56
- )
57
- epi_postact_smem_layout_staged = cute.tile_to_shape(
58
- postact_smem_layout_atom,
59
- cute.append(self.epi_tile_postact, self.epi_stage),
60
- order=(0, 1, 2),
61
- )
62
- tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
63
- args.mPostAct,
64
- epi_postact_smem_layout_staged,
65
- self.epi_tile_postact,
66
- store_or_load="store",
67
- )
68
- return GemmActSm90.EpilogueParams(
69
- tma_atom_postact,
70
- tma_tensor_postact,
71
- epi_postact_smem_layout_staged,
72
- args.act_fn,
73
- args.alpha,
74
- args.beta,
75
- )
76
-
77
- @staticmethod
78
- def epi_smem_bytes_per_stage(
79
- args: EpilogueArguments,
80
- tile_shape_mnk: Tuple[int, int, int],
81
- epi_tile: Tuple[int, int],
82
- ) -> int:
83
- postact_dtype = args.mPostAct.element_type
84
- postact_bytes_per_stage = cute.size(epi_tile) * (postact_dtype.width // 8)
85
- return postact_bytes_per_stage
86
-
87
- def epi_get_smem_struct(self, params: EpilogueParams):
88
- @cute.struct
89
- class EpiSharedStorage:
90
- sPostAct: cute.struct.Align[
91
- cute.struct.MemRange[
92
- self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
93
- ],
94
- self.buffer_align_bytes,
95
- ]
96
-
97
- return EpiSharedStorage
98
-
99
- def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
100
- sPostAct = storage.epi.sPostAct.get_tensor(
101
- params.epi_postact_smem_layout_staged.outer,
102
- swizzle=params.epi_postact_smem_layout_staged.inner,
103
- )
104
- return (sPostAct,)
105
-
106
- @cute.jit
107
- def epilogue(
108
- self,
109
- params: EpilogueParams,
110
- epi_smem_tensors: Tuple[cute.Tensor, ...],
111
- epi_pipeline: cutlass.pipeline.PipelineAsync,
112
- epi_read_state: cutlass.pipeline.PipelineState,
113
- epi_producer_state: cutlass.pipeline.PipelineState,
114
- tiled_mma: cute.TiledMma,
115
- tRS_rAcc: cute.Tensor,
116
- tRS_rD: cute.Tensor,
117
- tRS_rC: Optional[cute.Tensor],
118
- tiled_copy_r2s: cute.core.ThrCopy,
119
- tRS_sD: cute.Tensor,
120
- tiled_copy_s2r: Optional[cute.core.ThrCopy],
121
- tSR_rC: Optional[cute.Tensor],
122
- tSR_sC: Optional[cute.Tensor],
123
- copy_D: Optional[Callable],
124
- bSG_sD: cute.Tensor,
125
- bSG_gD: cute.Tensor,
126
- epi_load_g2s: Optional[Callable],
127
- tile_coord_mnkl: cute.Coord,
128
- cu_seqlens_m: Optional[cute.Tensor],
129
- epilogue_barrier: cutlass.pipeline.NamedBarrier,
130
- tile_scheduler,
131
- tidx: Int32,
132
- is_tma_warp: Boolean,
133
- ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
134
- has_C = const_expr(tRS_rC is not None)
135
- has_D = const_expr(copy_D is not None)
136
- assert cu_seqlens_m is None, "GemmActSm90 doesn't support varlen_m for now"
137
-
138
- tma_atom_postact = params.tma_atom_postact
139
- mPostAct_mnl = params.mPostAct_mnl
140
- (sPostAct,) = epi_smem_tensors
141
- tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
142
- copy_atom_postact_r2s = sm90_utils.sm90_get_smem_store_op(
143
- self.postact_layout, elem_ty_d=self.postact_dtype, elem_ty_acc=self.acc_dtype
144
- )
145
- tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
146
- thr_copy_postact_r2s = tiled_copy_postact_r2s.get_slice(tidx)
147
- tRS_sPostAct = thr_copy_postact_r2s.partition_D(sPostAct)
148
- bSG_sPostAct, bSG_gPostAct = self.epilog_gmem_copy_and_partition(
149
- tma_atom_postact,
150
- mPostAct_mnl,
151
- self.tile_shape_postact_mn,
152
- self.epi_tile_postact,
153
- sPostAct,
154
- tile_coord_mnkl,
155
- cu_seqlens_m,
156
- )
157
-
158
- # We iterate over epi tiles in the N dimension first before the M dimension
159
- epi_tile_shape = cute.zipped_divide(
160
- cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
161
- ).shape[1]
162
- epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
163
- epi_tile_num = cute.size(epi_tile_shape)
164
- num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
165
-
166
- if const_expr(epi_load_g2s is not None):
167
- for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
168
- epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
169
-
170
- for epi_idx in cutlass.range_constexpr(epi_tile_num):
171
- # Copy from acc to D registers
172
- for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
173
- tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
174
- if const_expr(has_C):
175
- epi_pipeline.consumer_wait(epi_read_state)
176
- cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
177
- # Fence to make sure shared memory read is visible to TMA load
178
- cute.arch.fence_proxy(
179
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
180
- )
181
- cute.arch.sync_warp()
182
- with cute.arch.elect_one():
183
- epi_pipeline.consumer_release(epi_read_state)
184
- epi_read_state.advance()
185
- if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
186
- epi_producer_state = epi_load_g2s(
187
- epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
188
- )
189
- tRS_rPostAct = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
190
- epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
191
- # Copy from D registers to shared memory
192
- if const_expr(has_D):
193
- # Type conversion
194
- tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
195
- tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
196
- cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
197
- cute.copy(
198
- tiled_copy_postact_r2s,
199
- tiled_copy_postact_r2s.retile(tRS_rPostAct),
200
- tRS_sPostAct[None, None, None, epi_buffer],
201
- )
202
- # Fence and barrier to make sure shared memory store is visible to TMA store
203
- cute.arch.fence_proxy(
204
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
205
- )
206
- epilogue_barrier.arrive_and_wait()
207
- # Get the global memory coordinate for the current epi tile
208
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
209
- # Copy from shared memory to global memory
210
- if is_tma_warp:
211
- if const_expr(has_D):
212
- copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
213
- cute.copy(
214
- tma_atom_postact,
215
- bSG_sPostAct[None, epi_buffer],
216
- bSG_gPostAct[None, gmem_coord],
217
- )
218
- cute.arch.cp_async_bulk_commit_group()
219
- cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
220
- epilogue_barrier.arrive_and_wait()
221
-
222
- return epi_read_state, epi_producer_state
223
-
224
- @cute.jit
225
- def epi_visit_acc_subtile(
226
- self,
227
- params: EpilogueParams,
228
- tRS_rD: cute.Tensor,
229
- tRS_rC: Optional[cute.Tensor] = None,
230
- ) -> Optional[cute.Tensor]:
231
- # Apply alpha scaling to accumulator if alpha is provided (not None)
232
- if const_expr(params.alpha is not None):
233
- tRS_rD.store(tRS_rD.load() * params.alpha)
234
- # Apply C with beta scaling
235
- if const_expr(tRS_rC is not None):
236
- if const_expr(params.beta is None):
237
- # beta is None, default behavior: add C (beta=1.0)
238
- tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
239
- else:
240
- tRS_rD.store(tRS_rD.load() + params.beta * tRS_rC.load().to(tRS_rD.element_type))
241
- # Apply activation function if provided
242
- # If we don't have .shape here, the compiler generates local stores and loads
243
- if const_expr(params.act_fn is not None):
244
- tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
245
- for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
246
- tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
247
- else:
248
- tRS_rPostAct = tRS_rD
249
- # Type conversion
250
- tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
251
- tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
252
- return tRS_rPostAct_out
253
-
254
-
255
- act_fn_map = {
256
- None: None,
257
- "relu": quack.activation.relu,
258
- "relu_sq": quack.activation.relu_sq,
259
- "gelu_tanh_approx": quack.activation.gelu_tanh_approx,
260
- }
261
-
262
-
263
- def gemm_act_sm90(
264
- A: Tensor, # (l, m, k)
265
- B: Tensor, # (l, n, k)
266
- D: Optional[Tensor], # (l, m, n)
267
- C: Optional[Tensor], # (l, m, n)
268
- PostAct: Tensor, # (l, m, n)
269
- activation: Optional[str],
270
- tile_M: int,
271
- tile_N: int,
272
- cluster_M: int,
273
- cluster_N: int,
274
- pingpong: bool = False,
275
- persistent: bool = True,
276
- alpha: float = 1.0,
277
- beta: float = 1.0,
278
- ) -> None:
279
- tile_count_semaphore = None
280
- assert activation in act_fn_map, f"Unsupported activation {activation}"
281
- L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
282
- A, B, D, C, additional_tensors={"PostAct": PostAct}
283
- )
284
- GemmWrapperBase.permute_tensors(tensor_infos)
285
- GemmWrapperBase.extract_dtypes(tensor_infos)
286
- major_configs = {
287
- "A": ("m", "k", "l"),
288
- "B": ("n", "k", "l"),
289
- "D": ("m", "n", "l"),
290
- "C": ("m", "n", "l"),
291
- "PostAct": ("m", "n", "l"),
292
- }
293
- GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
294
-
295
- acc_dtype = cutlass.Float32
296
- tile_shape_mn = (tile_M, tile_N)
297
- cluster_shape_mnk = (cluster_M, cluster_N, 1)
298
- if not GemmActSm90.is_valid_dtypes(
299
- tensor_infos["A"].dtype,
300
- tensor_infos["B"].dtype,
301
- acc_dtype,
302
- tensor_infos["D"].dtype,
303
- tensor_infos["A"].major,
304
- tensor_infos["B"].major,
305
- ):
306
- raise TypeError("Skipping due to unsupported combination of types and majors")
307
-
308
- max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
309
- GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
310
- act_fn = act_fn_map[activation]
311
- epi_args = GemmActSm90.EpilogueArguments(
312
- tensor_infos["PostAct"].cute_tensor,
313
- act_fn,
314
- alpha=Float32(alpha) if alpha != 1.0 else None,
315
- beta=Float32(beta) if beta != 1.0 else None,
316
- )
317
- scheduler_args = GemmWrapperBase.create_scheduler_args(
318
- max_active_clusters, tile_count_semaphore
319
- )
320
- current_stream = cutlass_torch.current_stream()
321
- compile_key = GemmWrapperBase.get_compile_key(
322
- tensor_infos,
323
- activation,
324
- tile_shape_mn,
325
- cluster_shape_mnk,
326
- pingpong,
327
- persistent,
328
- tile_count_semaphore is not None,
329
- alpha != 1.0,
330
- beta != 1.0,
331
- key_tensor_names=("A", "B", "D", "PostAct", "C"),
332
- )
333
- cache = gemm_act_sm90.compile_cache
334
- if compile_key not in cache:
335
- gemm = GemmActSm90(
336
- acc_dtype,
337
- tensor_infos["A"].dtype,
338
- tile_shape_mn,
339
- cluster_shape_mnk,
340
- pingpong=pingpong,
341
- is_persistent=persistent,
342
- )
343
- cache[compile_key] = cute.compile(
344
- gemm,
345
- tensor_infos["A"].cute_tensor,
346
- tensor_infos["B"].cute_tensor,
347
- tensor_infos["D"].cute_tensor,
348
- tensor_infos["C"].cute_tensor,
349
- epi_args,
350
- scheduler_args,
351
- None, # varlen_args
352
- None, # mAIdx
353
- current_stream,
354
- )
355
- cache[compile_key](
356
- tensor_infos["A"].cute_tensor,
357
- tensor_infos["B"].cute_tensor,
358
- tensor_infos["D"].cute_tensor,
359
- tensor_infos["C"].cute_tensor,
360
- epi_args,
361
- scheduler_args,
362
- None,
363
- None,
364
- current_stream,
365
- )
366
-
367
-
368
- gemm_act_sm90.compile_cache = {}
quack/gemm_dact_sm90.py DELETED
@@ -1,150 +0,0 @@
1
- # Copyright (c) 2025, Tri Dao.
2
- from typing import Optional
3
-
4
- from torch import Tensor
5
-
6
- import cutlass
7
- import cutlass.cute as cute
8
- from cutlass import const_expr
9
- import cutlass.torch as cutlass_torch
10
-
11
- from quack.gemm_act_sm90 import GemmActSm90
12
- from quack.cute_dsl_utils import get_max_active_clusters
13
- from quack.gemm_wrapper_utils import GemmWrapperBase
14
- import quack.activation
15
-
16
-
17
- class GemmDActSm90(GemmActSm90):
18
- # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
19
- # and return 2 arguments (dx, out)
20
- EpilogueArguments = GemmActSm90.EpilogueArguments
21
- EpilogueParams = GemmActSm90.EpilogueParams
22
-
23
- @cute.jit
24
- def epi_visit_acc_subtile(
25
- self,
26
- params: EpilogueParams,
27
- tRS_rD: cute.Tensor,
28
- tRS_rC: Optional[cute.Tensor] = None,
29
- ) -> Optional[cute.Tensor]:
30
- assert tRS_rC is not None
31
- tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
32
- tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
33
- # If we don't have .shape here, the compiler generates local stores and loads
34
- if const_expr(params.act_fn is not None):
35
- tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
36
- for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
37
- tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
38
- else:
39
- tRS_rPostAct = tRS_rC_acc
40
- # Type conversion
41
- tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
42
- tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
43
- return tRS_rPostAct_out
44
-
45
-
46
- dact_fn_map = {
47
- None: None,
48
- "relu": quack.activation.drelu,
49
- "relu_sq": quack.activation.drelu_sq,
50
- "gelu_tanh_approx": quack.activation.dgelu_tanh_approx,
51
- }
52
-
53
-
54
- def gemm_dact_sm90(
55
- A: Tensor, # (l, m, k)
56
- B: Tensor, # (l, n, k)
57
- Out: Tensor, # (l, m, n)
58
- PreAct: Tensor, # (l, m, n)
59
- PostAct: Tensor, # (l, m, n)
60
- tile_count_semaphore: Optional[Tensor], # (1,)
61
- activation: Optional[str],
62
- tile_M: int,
63
- tile_N: int,
64
- cluster_M: int,
65
- cluster_N: int,
66
- pingpong: bool = True,
67
- persistent: bool = True,
68
- ) -> None:
69
- assert activation in dact_fn_map, f"Unsupported activation {activation}"
70
- L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
71
- A, B, Out, PreAct, additional_tensors={"PostAct": PostAct}
72
- )
73
- GemmWrapperBase.permute_tensors(tensor_infos)
74
- GemmWrapperBase.extract_dtypes(tensor_infos)
75
- major_configs = {
76
- "A": ("m", "k", "l"),
77
- "B": ("n", "k", "l"),
78
- "D": ("m", "n", "l"),
79
- "C": ("m", "n", "l"),
80
- "PostAct": ("m", "n", "l"),
81
- }
82
- GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
83
-
84
- acc_dtype = cutlass.Float32
85
- tile_shape_mn = (tile_M, tile_N)
86
- cluster_shape_mnk = (cluster_M, cluster_N, 1)
87
- if not GemmDActSm90.is_valid_dtypes(
88
- tensor_infos["A"].dtype,
89
- tensor_infos["B"].dtype,
90
- acc_dtype,
91
- tensor_infos["D"].dtype,
92
- tensor_infos["A"].major,
93
- tensor_infos["B"].major,
94
- ):
95
- raise TypeError("Skipping due to unsupported combination of types and majors")
96
-
97
- max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
98
- GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
99
- act_fn = dact_fn_map[activation]
100
- epi_args = GemmDActSm90.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
101
- scheduler_args = GemmWrapperBase.create_scheduler_args(
102
- max_active_clusters, tile_count_semaphore
103
- )
104
- current_stream = cutlass_torch.current_stream()
105
- compile_key = GemmWrapperBase.get_compile_key(
106
- tensor_infos,
107
- activation,
108
- tile_shape_mn,
109
- cluster_shape_mnk,
110
- pingpong,
111
- persistent,
112
- tile_count_semaphore is not None,
113
- key_tensor_names=("A", "B", "D", "PostAct", "C"),
114
- )
115
- cache = gemm_dact_sm90.compile_cache
116
- if compile_key not in cache:
117
- gemm = GemmDActSm90(
118
- acc_dtype,
119
- tensor_infos["A"].dtype,
120
- tile_shape_mn,
121
- cluster_shape_mnk,
122
- pingpong=pingpong,
123
- is_persistent=persistent,
124
- )
125
- cache[compile_key] = cute.compile(
126
- gemm,
127
- tensor_infos["A"].cute_tensor,
128
- tensor_infos["B"].cute_tensor,
129
- tensor_infos["D"].cute_tensor,
130
- tensor_infos["C"].cute_tensor,
131
- epi_args,
132
- scheduler_args,
133
- None, # varlen_args
134
- None, # mAIdx
135
- current_stream,
136
- )
137
- cache[compile_key](
138
- tensor_infos["A"].cute_tensor,
139
- tensor_infos["B"].cute_tensor,
140
- tensor_infos["D"].cute_tensor,
141
- tensor_infos["C"].cute_tensor,
142
- epi_args,
143
- scheduler_args,
144
- None,
145
- None,
146
- current_stream,
147
- )
148
-
149
-
150
- gemm_dact_sm90.compile_cache = {}