quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/{gemm_act_sm90.py → gemm_act.py}
@@ -1,41 +1,54 @@
-# Copyright (c) 2025, Tri Dao.
+# Copyright (c) 2025, Wentao Guo, Tri Dao.
 from typing import Tuple, Optional, Callable
+from functools import partial
 from dataclasses import dataclass

 from torch import Tensor

 import cutlass
 import cutlass.cute as cute
-from cutlass.cute.nvgpu import warpgroup
-import cutlass.utils.hopper_helpers as sm90_utils
+import cutlass.utils.hopper_helpers as sm90_utils_og
+import cutlass.utils.blackwell_helpers as sm100_utils
 from cutlass import Int32, Float32, Boolean, const_expr
+from cutlass.cutlass_dsl import if_generate
 import cutlass.torch as cutlass_torch
+from cutlass.cute.runtime import from_dlpack

 from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
-from quack.dense_gemm_sm90 import GemmSm90
-from quack.cute_dsl_utils import get_max_active_clusters
+from quack.varlen_utils import VarlenManager
+from quack.gemm_sm90 import GemmSm90
+from quack.gemm_sm100 import GemmSm100
+from quack.gemm_default_epi import GemmDefaultEpiMixin
+from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
 from quack.gemm_wrapper_utils import GemmWrapperBase
+import quack.sm90_utils as sm90_utils
+import quack.copy_utils as copy_utils
 import quack.activation


-class GemmActSm90(GemmSm90):
+class GemmActMixin(GemmDefaultEpiMixin):
     num_epi_tensormaps: int = 1

     @dataclass
     class EpilogueArguments(ArgumentsBase):
         mPostAct: cute.Tensor
         act_fn: cutlass.Constexpr[Optional[Callable]] = None
-        alpha: Optional[Float32] = None
-        beta: Optional[Float32] = None
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+        mRowVecBroadcast: Optional[cute.Tensor] = None
+        mColVecBroadcast: Optional[cute.Tensor] = None

     @dataclass
     class EpilogueParams(ParamsBase):
         tma_atom_postact: cute.CopyAtom
         mPostAct_mnl: cute.Tensor
         epi_postact_smem_layout_staged: cute.ComposedLayout
+        epi_tile_postact: cute.Tile
         act_fn: cutlass.Constexpr[Optional[Callable]] = None
-        alpha: Optional[Float32] = None
-        beta: Optional[Float32] = None
+        alpha: Optional[Float32 | cute.Tensor] = None
+        beta: Optional[Float32 | cute.Tensor] = None
+        mRowVecBroadcast: Optional[cute.Tensor] = None
+        mColVecBroadcast: Optional[cute.Tensor] = None

     def epi_to_underlying_arguments(
         self, args: EpilogueArguments, *, loc=None, ip=None
@@ -44,36 +57,38 @@ class GemmActSm90(GemmSm90):
         self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)

         self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
-        self.epi_tile_postact = self.epi_tile
-        postact_major_mode_size = (
-            self.epi_tile_postact[1]
-            if self.postact_layout.is_n_major_c()
-            else self.epi_tile_postact[0]
-        )
-        postact_smem_layout_atom = warpgroup.make_smem_layout_atom(
-            sm90_utils.get_smem_layout_atom(
-                self.postact_layout, self.postact_dtype, postact_major_mode_size
-            ),
-            self.postact_dtype,
-        )
-        epi_postact_smem_layout_staged = cute.tile_to_shape(
-            postact_smem_layout_atom,
-            cute.append(self.epi_tile_postact, self.epi_stage),
-            order=(0, 1, 2),
+        epi_tile_postact = self.epi_tile
+        utils_cls = sm100_utils if self.arch == 100 else sm90_utils
+        epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
+            self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
         )
         tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
             args.mPostAct,
             epi_postact_smem_layout_staged,
-            self.epi_tile_postact,
+            epi_tile_postact,
             op_type="store",
         )
-        return GemmActSm90.EpilogueParams(
+        # Assume all strides are divisible by 32 bits except the last stride
+        new_stride = lambda t: tuple(
+            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
+            for s in t.stride
+        )
+        mRowVecBroadcast, mColVecBroadcast = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            if t is not None
+            else None
+            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
+        ]
+        return self.EpilogueParams(
             tma_atom_postact,
             tma_tensor_postact,
             epi_postact_smem_layout_staged,
+            epi_tile_postact,
             args.act_fn,
-            args.alpha,
-            args.beta,
+            alpha=args.alpha,
+            beta=args.beta,
+            mRowVecBroadcast=mRowVecBroadcast,
+            mColVecBroadcast=mColVecBroadcast,
         )

     def epi_get_tma_atoms(
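Annotation: the new `new_stride` helper wraps every dynamic stride in `cute.assume(s, divby=32 // t.element_type.width)`, i.e. it promises the compiler that strides are 32-bit aligned, presumably so it can emit aligned vectorized accesses for the broadcast tensors. A standalone sanity check of that arithmetic (dtype widths here are illustrative, not taken from a specific call site):

# divby is the number of elements that make up 32 bits, so asserting
# divisibility by it guarantees a dynamic stride is 32-bit aligned.
for width_bits, expected in [(32, 1), (16, 2), (8, 4)]:  # fp32, fp16/bf16, fp8/int8
    divby = 32 // width_bits
    assert divby == expected  # e.g. a fp16 tensor may only have even strides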
@@ -84,29 +99,41 @@ class GemmActSm90(GemmSm90):
     def epi_get_tensormap_update_shapes_orders(
         self,
         params: EpilogueParams,
-        cu_seqlens_m: cute.Tensor,
+        cu_seqlens_m: Optional[cute.Tensor],
         batch_idx: Int32,
         *,
         loc=None,
         ip=None,
     ) -> tuple[list[Int32], list[int]]:
-        shapes = [cu_seqlens_m[batch_idx + 1]]
+        shapes = [cu_seqlens_m[batch_idx + 1] if cu_seqlens_m is not None else None]
         orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1]
         return shapes, orders

     @staticmethod
     def epi_smem_bytes_per_stage(
-        args: EpilogueArguments,
-        cta_tile_shape_mnk: Tuple[int, int, int],
-        epi_tile: Tuple[int, int],
+        args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile
     ) -> int:
         postact_dtype = args.mPostAct.element_type
-        postact_bytes_per_stage = cute.size(epi_tile) * (postact_dtype.width // 8)
-        return postact_bytes_per_stage
+        postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8)
+        rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
+            args, cta_tile_shape_mnk, epi_tile
+        )
+        return postact_bytes_per_stage + rowvec_colvec_bytes

     def epi_get_smem_struct(self, params: EpilogueParams):
+        row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
+        col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
+        row_vec_dtype = (
+            params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
+        )
+        col_vec_dtype = (
+            params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
+        )
+
         @cute.struct
         class EpiSharedStorage:
+            sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
+            sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
             sPostAct: cute.struct.Align[
                 cute.struct.MemRange[
                     self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
@@ -117,11 +144,12 @@ class GemmActSm90(GemmSm90):
         return EpiSharedStorage

     def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
+        sRowVec, sColVec = super().epi_get_smem_tensors(params, storage)
         sPostAct = storage.epi.sPostAct.get_tensor(
             params.epi_postact_smem_layout_staged.outer,
             swizzle=params.epi_postact_smem_layout_staged.inner,
         )
-        return (sPostAct,)
+        return (sRowVec, sColVec, sPostAct)

     @cute.jit
     def epilogue(
@@ -133,21 +161,20 @@ class GemmActSm90(GemmSm90):
         epi_store_pipeline: cutlass.pipeline.PipelineAsync,
         epi_read_state: cutlass.pipeline.PipelineState,
         epi_producer_state: cutlass.pipeline.PipelineState,
-        tiled_mma: cute.TiledMma,
-        tRS_rAcc: cute.Tensor,
+        epi_tile: cute.Tile,
+        load_acc_subtile: Callable,
         tRS_rD: cute.Tensor,
         tRS_rC: Optional[cute.Tensor],
-        tiled_copy_r2s: cute.core.ThrCopy,
+        tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
+        tiled_copy_r2s: cute.TiledCopy,
         tRS_sD: cute.Tensor,
-        tiled_copy_s2r: Optional[cute.core.ThrCopy],
+        tiled_copy_s2r: Optional[cute.TiledCopy],
         tSR_rC: Optional[cute.Tensor],
         tSR_sC: Optional[cute.Tensor],
         copy_D: Optional[Callable],
-        bSG_sD: cute.Tensor,
-        bSG_gD: cute.Tensor,
-        epi_load_g2s: Optional[Callable],
+        copy_C: Optional[Callable],
         tile_coord_mnkl: cute.Coord,
-        cu_seqlens_m: Optional[cute.Tensor],
+        varlen_manager: VarlenManager,
         epilogue_barrier: cutlass.pipeline.NamedBarrier,
         tile_scheduler,
         tidx: Int32,
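Annotation: the signature change above replaces the raw accumulator fragment `tRS_rAcc` with a `load_acc_subtile` callback, so the same epilogue loop can serve SM90 (accumulator held in registers) and SM100 (accumulator staged out of tensor memory via `tiled_copy_t2r`). A minimal sketch of what such a callback could look like for the register case, modeled on the loop this diff removes (names and the plain-list stand-ins are illustrative, not the kernel's code):

def make_register_load_acc_subtile(tRS_rAcc):
    """Return a callback that copies one epi subtile out of a register fragment."""
    def load_acc_subtile(tRS_rD, epi_idx):
        # Same indexing as the removed SM90 loop: subtile epi_idx occupies a
        # contiguous len(tRS_rD)-sized slice of the accumulator.
        for epi_v in range(len(tRS_rD)):
            tRS_rD[epi_v] = tRS_rAcc[epi_idx * len(tRS_rD) + epi_v]
    return load_acc_subtile

# Toy usage with plain lists standing in for cute fragments:
acc = list(range(8))
subtile = [0] * 4
make_register_load_acc_subtile(acc)(subtile, 1)
assert subtile == [4, 5, 6, 7]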
@@ -158,41 +185,85 @@

         tma_atom_postact = params.tma_atom_postact
         mPostAct_mnl = params.mPostAct_mnl
-        (sPostAct,) = epi_smem_tensors
-        tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
-        copy_atom_postact_r2s = sm90_utils.sm90_get_smem_store_op(
-            self.postact_layout, elem_ty_d=self.postact_dtype, elem_ty_acc=self.acc_dtype
+        sRowVec, sColVec, sPostAct = epi_smem_tensors
+        get_smem_store_op = (
+            partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
+            if self.arch == 100
+            else sm90_utils_og.sm90_get_smem_store_op
+        )
+        copy_atom_postact_r2s = get_smem_store_op(
+            self.postact_layout, self.postact_dtype, self.acc_dtype
         )
-        tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
-        thr_copy_postact_r2s = tiled_copy_postact_r2s.get_slice(tidx)
-        tRS_sPostAct = thr_copy_postact_r2s.partition_D(sPostAct)
-        bSG_sPostAct, bSG_gPostAct = self.epilog_gmem_copy_and_partition(
+        # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
+        # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
+        tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
+        tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
+        (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
+        batch_idx = tile_coord_mnkl[3]
+        copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
             tma_atom_postact,
-            mPostAct_mnl,
+            varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
             self.cta_tile_shape_postact_mn,
-            self.epi_tile_postact,
+            params.epi_tile_postact,
             sPostAct,
             tile_coord_mnkl,
-            cu_seqlens_m,
+            tma_desc_ptr=tma_desc_postact_ptr,
         )
-        (tma_desc_postact_ptr,) = tma_desc_epi_ptrs

         # We iterate over epi tiles in the N dimension first before the M dimension
         epi_tile_shape = cute.zipped_divide(
-            cute.make_layout(self.cta_tile_shape_mnk[:2]), self.epi_tile
+            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
         ).shape[1]
         epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
         epi_tile_num = cute.size(epi_tile_shape)
         num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

-        if const_expr(epi_load_g2s is not None):
+        epi_tensors = self.epi_begin(
+            params,
+            epi_smem_tensors,
+            epi_tile,
+            tiled_copy_t2r,
+            tiled_copy_r2s,
+            tile_coord_mnkl,
+            varlen_manager,
+            epilogue_barrier,
+            tidx,
+        )
+
+        if const_expr(copy_C is not None):
             for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
-                epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
+                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
+                if is_tma_warp:
+                    epi_pipeline.producer_acquire(epi_producer_state)
+                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                    epi_pipeline.producer_commit(epi_producer_state)
+                epi_producer_state.advance()
+
+        def tma_store_fn(src_idx, dst_idx):
+            # Fence and barrier to make sure shared memory store is visible to TMA store
+            cute.arch.fence_proxy(
+                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+            )
+            epilogue_barrier.arrive_and_wait()
+            # Copy from shared memory to global memory
+            if is_tma_warp:
+                if const_expr(has_D):
+                    copy_D(src_idx=src_idx, dst_idx=dst_idx)
+                copy_postact(src_idx=src_idx, dst_idx=dst_idx)
+            # Can't use if statement here, epi_store_pipeline object isn't captured somehow
+            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
+            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
+            epilogue_barrier.arrive_and_wait()
+
+        delay_tma_store = True

+        src_idx_prev, dst_idx_prev = None, None
         for epi_idx in cutlass.range_constexpr(epi_tile_num):
+            # The global memory coordinate for the current epi tile
+            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
             # Copy from acc to D registers
-            for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
-                tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
+            load_acc_subtile(tRS_rD, epi_idx)
+            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
             if const_expr(has_C):
                 epi_pipeline.consumer_wait(epi_read_state)
                 cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
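Annotation: note the `delay_tma_store` flag introduced in this hunk. Instead of issuing each subtile's TMA store as soon as its shared-memory staging completes, the store is deferred by one iteration (`src_idx_prev`/`dst_idx_prev`), so the smem-to-gmem copy of subtile i overlaps with computing and staging subtile i+1, and the last store is drained after the loop (visible in the next hunk). A plain-Python sketch of that schedule, with all names illustrative:

def delayed_store_loop(subtiles, stage_into_smem, tma_store):
    """Issue each subtile's store one iteration late so it overlaps compute."""
    prev = None
    for coord, tile in enumerate(subtiles):
        buf = stage_into_smem(tile)  # registers -> shared memory for this subtile
        if prev is not None:
            tma_store(*prev)         # store subtile i-1 while subtile i was staged
        prev = (buf, coord)
    if prev is not None:
        tma_store(*prev)             # drain the final subtile after the loop

# Toy usage: "staging" doubles the value, "storing" prints it.
delayed_store_loop([1, 2, 3], lambda t: 2 * t, lambda buf, coord: print(coord, buf))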
@@ -204,69 +275,67 @@ class GemmActSm90(GemmSm90):
                 with cute.arch.elect_one():
                     epi_pipeline.consumer_release(epi_read_state)
                 epi_read_state.advance()
-            if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
-                epi_producer_state = epi_load_g2s(
-                    epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
-                )
-            tRS_rPostAct = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
+            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
+                if is_tma_warp:
+                    epi_pipeline.producer_acquire(epi_producer_state)
+                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                    epi_pipeline.producer_commit(epi_producer_state)
+                epi_producer_state.advance()
+            tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
             epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+            if const_expr(delay_tma_store):
+                if const_expr(epi_idx > 0):
+                    tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
+                src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
             # Copy from D registers to shared memory
             if const_expr(has_D):
-                # Type conversion
-                tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
-                tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
-                cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
+                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
             cute.copy(
                 tiled_copy_postact_r2s,
                 tiled_copy_postact_r2s.retile(tRS_rPostAct),
                 tRS_sPostAct[None, None, None, epi_buffer],
             )
-            # Fence and barrier to make sure shared memory store is visible to TMA store
-            cute.arch.fence_proxy(
-                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-            )
-            epilogue_barrier.arrive_and_wait()
-            # Get the global memory coordinate for the current epi tile
-            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
-            # Copy from shared memory to global memory
-            if is_tma_warp:
-                if const_expr(has_D):
-                    copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
-                cute.copy(
-                    tma_atom_postact,
-                    bSG_sPostAct[None, epi_buffer],
-                    bSG_gPostAct[None, gmem_coord],
-                    tma_desc_ptr=tma_desc_postact_ptr,
-                )
-                epi_store_pipeline.producer_commit()
-                epi_store_pipeline.producer_acquire()
-            epilogue_barrier.arrive_and_wait()
+            if const_expr(not delay_tma_store):
+                tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
+
+        if const_expr(delay_tma_store):
+            tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
+
+        self.epi_end(
+            params,
+            epi_tensors,
+            epi_tile,
+            tiled_copy_t2r,
+            tiled_copy_r2s,
+            tile_coord_mnkl,
+            varlen_manager,
+            tidx,
+        )

         return epi_read_state, epi_producer_state

     @cute.jit
-    def epi_visit_acc_subtile(
+    def epi_visit_subtile(
         self,
         params: EpilogueParams,
+        epi_loop_tensors: Tuple[cute.Tensor, ...],
         tRS_rD: cute.Tensor,
         tRS_rC: Optional[cute.Tensor] = None,
     ) -> Optional[cute.Tensor]:
-        # Apply alpha scaling to accumulator if alpha is provided (not None)
-        if const_expr(params.alpha is not None):
-            tRS_rD.store(tRS_rD.load() * params.alpha)
-        # Apply C with beta scaling
-        if const_expr(tRS_rC is not None):
-            if const_expr(params.beta is None):
-                # beta is None, default behavior: add C (beta=1.0)
-                tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
-            else:
-                tRS_rD.store(tRS_rD.load() + params.beta * tRS_rC.load().to(tRS_rD.element_type))
+        GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
         # Apply activation function if provided
         # If we don't have .shape here, the compiler generates local stores and loads
         if const_expr(params.act_fn is not None):
             tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
-            for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
-                tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
+            if const_expr(self.arch < 100):
+                for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
+                    tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
+            else:
+                for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
+                    tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
+                        (tRS_rD[2 * i], tRS_rD[2 * i + 1])
+                    )
         else:
             tRS_rPostAct = tRS_rD
         # Type conversion
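Annotation: the activation loop now branches on architecture. SM90 applies `act_fn` one element at a time, while the SM100 path feeds it two adjacent elements per call, so the activation can lower to 2-wide packed math. A plain-Python sketch of the pairwise calling convention (a relu written for pairs; purely illustrative, not the quack.activation implementation):

def relu_pair(pair):
    """Activation in the pairwise form the SM100 branch expects: 2-tuple in, 2-tuple out."""
    a, b = pair
    return (max(a, 0.0), max(b, 0.0))

frag = [0.5, -1.0, 2.0, -3.0]
out = [0.0] * len(frag)
for i in range(len(frag) // 2):
    out[2 * i], out[2 * i + 1] = relu_pair((frag[2 * i], frag[2 * i + 1]))
assert out == [0.5, 0.0, 2.0, 0.0]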
@@ -275,6 +344,14 @@ class GemmActSm90(GemmSm90):
         return tRS_rPostAct_out


+class GemmActSm90(GemmActMixin, GemmSm90):
+    pass
+
+
+class GemmActSm100(GemmActMixin, GemmSm100):
+    pass
+
+
 act_fn_map = {
     None: None,
     "relu": quack.activation.relu,
@@ -283,7 +360,7 @@ act_fn_map = {
 }


-def gemm_act_sm90(
+def gemm_act(
     A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
     B: Tensor,  # (l, n, k)
     D: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
@@ -297,6 +374,9 @@ def gemm_act_sm90(
     cluster_N: int,
     pingpong: bool = False,
     persistent: bool = True,
+    max_swizzle_size: int = 8,
+    rowvec_bias: Optional[Tensor] = None,  # (l, n)
+    colvec_bias: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
     cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
     A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
 ) -> None:
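Annotation: two of the new parameters, `rowvec_bias` (shape `(l, n)`) and `colvec_bias` (shape `(l, m)`), feed the `mRowVecBroadcast`/`mColVecBroadcast` epilogue tensors added earlier in this diff. A PyTorch reference for what the fused epilogue computes, inferred from these shapes (an assumption for illustration, not the kernel's actual code path):

import torch

def gemm_act_reference(A, B, rowvec_bias=None, colvec_bias=None, act=torch.relu):
    # A: (l, m, k), B: (l, n, k) -> D: (l, m, n), matching the shape comments above
    D = torch.einsum("lmk,lnk->lmn", A, B)
    if rowvec_bias is not None:        # (l, n): one value per output column
        D = D + rowvec_bias[:, None, :]
    if colvec_bias is not None:        # (l, m): one value per output row
        D = D + colvec_bias[:, :, None]
    return act(D)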
@@ -326,10 +406,14 @@ def gemm_act_sm90(
     }
     GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

-    acc_dtype = cutlass.Float32
+    device_capacity = get_device_capacity(A.device)
+    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
+    GemmCls = GemmActSm100 if device_capacity[0] > 9 else GemmActSm90
+
+    acc_dtype = Float32
     tile_shape_mn = (tile_M, tile_N)
     cluster_shape_mnk = (cluster_M, cluster_N, 1)
-    if not GemmActSm90.is_valid_dtypes(
+    if not GemmCls.is_valid_dtypes(
         tensor_infos["A"].dtype,
         tensor_infos["B"].dtype,
         acc_dtype,
@@ -342,9 +426,22 @@ def gemm_act_sm90(
     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
     GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
     act_fn = act_fn_map[activation]
-    epi_args = GemmActSm90.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
+    epi_args = GemmCls.EpilogueArguments(
+        tensor_infos["PostAct"].cute_tensor,
+        act_fn,
+        mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
+            leading_dim=1
+        )
+        if rowvec_bias is not None
+        else None,
+        mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
+            leading_dim=1 if cu_seqlens_m is None else 0
+        )
+        if colvec_bias is not None
+        else None,
+    )
     scheduler_args = GemmWrapperBase.create_scheduler_args(
-        max_active_clusters, tile_count_semaphore
+        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
     )

     # Create varlen arguments if needed (assumes persistent=True when varlen_m)
@@ -355,7 +452,7 @@ def gemm_act_sm90(
         max_active_clusters,
         cluster_shape_mnk,
         tensor_infos,
-        GemmActSm90.num_epi_tensormaps,
+        GemmCls.num_epi_tensormaps,
         pingpong,
     )

@@ -368,23 +465,27 @@ def gemm_act_sm90(
         pingpong,
         persistent,
         tile_count_semaphore is not None,
+        device_capacity,
+        max_swizzle_size,
+        rowvec_bias.dtype if rowvec_bias is not None else None,
+        colvec_bias.dtype if colvec_bias is not None else None,
         cu_seqlens_m is not None,
         A_idx is not None,
         key_tensor_names=("A", "B", "D", "PostAct", "C"),
     )
-    cache = gemm_act_sm90.compile_cache
+    cache = gemm_act.compile_cache
     if compile_key not in cache:
-        gemm = GemmActSm90(
+        if device_capacity[0] == 9:
+            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
+        gemm_obj = GemmCls(
             acc_dtype,
             tensor_infos["A"].dtype,
             tile_shape_mn,
             cluster_shape_mnk,
-            pingpong=pingpong,
-            is_persistent=persistent,
             gather_A=gather_A,
         )
         cache[compile_key] = cute.compile(
-            gemm,
+            gemm_obj,
             tensor_infos["A"].cute_tensor,
             tensor_infos["B"].cute_tensor,
             tensor_infos["D"].cute_tensor,
@@ -406,4 +507,4 @@ def gemm_act_sm90(
     )


-gemm_act_sm90.compile_cache = {}
+gemm_act.compile_cache = {}
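Annotation: the rename from `gemm_act_sm90.compile_cache` to `gemm_act.compile_cache` keeps the same memoization idiom: compiled kernels are stored as an attribute on the wrapper function, keyed by everything that affects code generation (dtypes, tile and cluster shapes, device capacity, and the new swizzle size and bias dtypes). The pattern in miniature, with a stand-in for `cute.compile` (illustrative only):

def launch(key, *args):
    cache = launch.compile_cache
    if key not in cache:
        # Stand-in for cute.compile(...): build once per distinct key
        print(f"compiling for {key}")
        cache[key] = lambda *a: sum(a)
    return cache[key](*args)

launch.compile_cache = {}
launch(("fp16", 128, 256), 1, 2)  # compiles
launch(("fp16", 128, 256), 3, 4)  # cache hit, no recompile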