quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/gemm_act_sm90.py CHANGED
@@ -19,6 +19,8 @@ import quack.activation
19
19
 
20
20
 
21
21
  class GemmActSm90(GemmSm90):
22
+ num_epi_tensormaps: int = 1
23
+
22
24
  @dataclass
23
25
  class EpilogueArguments(ArgumentsBase):
24
26
  mPostAct: cute.Tensor
@@ -41,7 +43,7 @@ class GemmActSm90(GemmSm90):
41
43
  self.postact_dtype = args.mPostAct.element_type
42
44
  self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
43
45
 
44
- self.tile_shape_postact_mn = self.tile_shape_mnk[:2]
46
+ self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
45
47
  self.epi_tile_postact = self.epi_tile
46
48
  postact_major_mode_size = (
47
49
  self.epi_tile_postact[1]
@@ -63,7 +65,7 @@ class GemmActSm90(GemmSm90):
63
65
  args.mPostAct,
64
66
  epi_postact_smem_layout_staged,
65
67
  self.epi_tile_postact,
66
- store_or_load="store",
68
+ op_type="store",
67
69
  )
68
70
  return GemmActSm90.EpilogueParams(
69
71
  tma_atom_postact,
@@ -74,10 +76,28 @@ class GemmActSm90(GemmSm90):
74
76
  args.beta,
75
77
  )
76
78
 
79
+ def epi_get_tma_atoms(
80
+ self, params: EpilogueParams, *, loc=None, ip=None
81
+ ) -> list[cute.CopyAtom]:
82
+ return [params.tma_atom_postact]
83
+
84
+ def epi_get_tensormap_update_shapes_orders(
85
+ self,
86
+ params: EpilogueParams,
87
+ cu_seqlens_m: cute.Tensor,
88
+ batch_idx: Int32,
89
+ *,
90
+ loc=None,
91
+ ip=None,
92
+ ) -> tuple[list[Int32], list[int]]:
93
+ shapes = [cu_seqlens_m[batch_idx + 1]]
94
+ orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1]
95
+ return shapes, orders
96
+
77
97
  @staticmethod
78
98
  def epi_smem_bytes_per_stage(
79
99
  args: EpilogueArguments,
80
- tile_shape_mnk: Tuple[int, int, int],
100
+ cta_tile_shape_mnk: Tuple[int, int, int],
81
101
  epi_tile: Tuple[int, int],
82
102
  ) -> int:
83
103
  postact_dtype = args.mPostAct.element_type
@@ -108,7 +128,9 @@ class GemmActSm90(GemmSm90):
108
128
  self,
109
129
  params: EpilogueParams,
110
130
  epi_smem_tensors: Tuple[cute.Tensor, ...],
131
+ tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
111
132
  epi_pipeline: cutlass.pipeline.PipelineAsync,
133
+ epi_store_pipeline: cutlass.pipeline.PipelineAsync,
112
134
  epi_read_state: cutlass.pipeline.PipelineState,
113
135
  epi_producer_state: cutlass.pipeline.PipelineState,
114
136
  tiled_mma: cute.TiledMma,
@@ -133,7 +155,6 @@ class GemmActSm90(GemmSm90):
133
155
  ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
134
156
  has_C = const_expr(tRS_rC is not None)
135
157
  has_D = const_expr(copy_D is not None)
136
- assert cu_seqlens_m is None, "GemmActSm90 doesn't support varlen_m for now"
137
158
 
138
159
  tma_atom_postact = params.tma_atom_postact
139
160
  mPostAct_mnl = params.mPostAct_mnl
@@ -148,16 +169,17 @@ class GemmActSm90(GemmSm90):
148
169
  bSG_sPostAct, bSG_gPostAct = self.epilog_gmem_copy_and_partition(
149
170
  tma_atom_postact,
150
171
  mPostAct_mnl,
151
- self.tile_shape_postact_mn,
172
+ self.cta_tile_shape_postact_mn,
152
173
  self.epi_tile_postact,
153
174
  sPostAct,
154
175
  tile_coord_mnkl,
155
176
  cu_seqlens_m,
156
177
  )
178
+ (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
157
179
 
158
180
  # We iterate over epi tiles in the N dimension first before the M dimension
159
181
  epi_tile_shape = cute.zipped_divide(
160
- cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
182
+ cute.make_layout(self.cta_tile_shape_mnk[:2]), self.epi_tile
161
183
  ).shape[1]
162
184
  epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
163
185
  epi_tile_num = cute.size(epi_tile_shape)
@@ -214,9 +236,10 @@ class GemmActSm90(GemmSm90):
214
236
  tma_atom_postact,
215
237
  bSG_sPostAct[None, epi_buffer],
216
238
  bSG_gPostAct[None, gmem_coord],
239
+ tma_desc_ptr=tma_desc_postact_ptr,
217
240
  )
218
- cute.arch.cp_async_bulk_commit_group()
219
- cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
241
+ epi_store_pipeline.producer_commit()
242
+ epi_store_pipeline.producer_acquire()
220
243
  epilogue_barrier.arrive_and_wait()
221
244
 
222
245
  return epi_read_state, epi_producer_state
@@ -261,11 +284,12 @@ act_fn_map = {
261
284
 
262
285
 
263
286
  def gemm_act_sm90(
264
- A: Tensor, # (l, m, k)
287
+ A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
265
288
  B: Tensor, # (l, n, k)
266
- D: Optional[Tensor], # (l, m, n)
267
- C: Optional[Tensor], # (l, m, n)
268
- PostAct: Tensor, # (l, m, n)
289
+ D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
290
+ C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
291
+ PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
292
+ tile_count_semaphore: Optional[Tensor], # (1,)
269
293
  activation: Optional[str],
270
294
  tile_M: int,
271
295
  tile_N: int,
@@ -273,15 +297,25 @@ def gemm_act_sm90(
273
297
  cluster_N: int,
274
298
  pingpong: bool = False,
275
299
  persistent: bool = True,
276
- alpha: float = 1.0,
277
- beta: float = 1.0,
300
+ cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
301
+ A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
278
302
  ) -> None:
279
- tile_count_semaphore = None
303
+ if cu_seqlens_m is not None:
304
+ assert persistent, "varlen_m requires persistent=True"
305
+ assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
306
+ if D is not None:
307
+ assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
308
+ assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
309
+ gather_A = A_idx is not None
310
+ if gather_A:
311
+ assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
312
+ assert cluster_N == 1, "gather_A requires cluster_N=1"
280
313
  assert activation in act_fn_map, f"Unsupported activation {activation}"
314
+
281
315
  L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
282
- A, B, D, C, additional_tensors={"PostAct": PostAct}
316
+ A, B, D, C, additional_tensors={"PostAct": PostAct}, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
283
317
  )
284
- GemmWrapperBase.permute_tensors(tensor_infos)
318
+ GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
285
319
  GemmWrapperBase.extract_dtypes(tensor_infos)
286
320
  major_configs = {
287
321
  "A": ("m", "k", "l"),
@@ -308,15 +342,23 @@ def gemm_act_sm90(
308
342
  max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
309
343
  GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
310
344
  act_fn = act_fn_map[activation]
311
- epi_args = GemmActSm90.EpilogueArguments(
312
- tensor_infos["PostAct"].cute_tensor,
313
- act_fn,
314
- alpha=Float32(alpha) if alpha != 1.0 else None,
315
- beta=Float32(beta) if beta != 1.0 else None,
316
- )
345
+ epi_args = GemmActSm90.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
317
346
  scheduler_args = GemmWrapperBase.create_scheduler_args(
318
347
  max_active_clusters, tile_count_semaphore
319
348
  )
349
+
350
+ # Create varlen arguments if needed (assumes persistent=True when varlen_m)
351
+ varlen_args = GemmWrapperBase.create_varlen_args(
352
+ cu_seqlens_m,
353
+ None, # cu_seqlens_k
354
+ A_idx,
355
+ max_active_clusters,
356
+ cluster_shape_mnk,
357
+ tensor_infos,
358
+ GemmActSm90.num_epi_tensormaps,
359
+ pingpong,
360
+ )
361
+
320
362
  current_stream = cutlass_torch.current_stream()
321
363
  compile_key = GemmWrapperBase.get_compile_key(
322
364
  tensor_infos,
@@ -326,8 +368,8 @@ def gemm_act_sm90(
326
368
  pingpong,
327
369
  persistent,
328
370
  tile_count_semaphore is not None,
329
- alpha != 1.0,
330
- beta != 1.0,
371
+ cu_seqlens_m is not None,
372
+ A_idx is not None,
331
373
  key_tensor_names=("A", "B", "D", "PostAct", "C"),
332
374
  )
333
375
  cache = gemm_act_sm90.compile_cache
@@ -339,6 +381,7 @@ def gemm_act_sm90(
339
381
  cluster_shape_mnk,
340
382
  pingpong=pingpong,
341
383
  is_persistent=persistent,
384
+ gather_A=gather_A,
342
385
  )
343
386
  cache[compile_key] = cute.compile(
344
387
  gemm,
@@ -348,8 +391,7 @@ def gemm_act_sm90(
348
391
  tensor_infos["C"].cute_tensor,
349
392
  epi_args,
350
393
  scheduler_args,
351
- None, # varlen_args
352
- None, # mAIdx
394
+ varlen_args,
353
395
  current_stream,
354
396
  )
355
397
  cache[compile_key](
@@ -359,8 +401,7 @@ def gemm_act_sm90(
359
401
  tensor_infos["C"].cute_tensor,
360
402
  epi_args,
361
403
  scheduler_args,
362
- None,
363
- None,
404
+ varlen_args,
364
405
  current_stream,
365
406
  )
366
407
 
quack/gemm_dact_sm90.py CHANGED
@@ -52,11 +52,11 @@ dact_fn_map = {
52
52
 
53
53
 
54
54
  def gemm_dact_sm90(
55
- A: Tensor, # (l, m, k)
55
+ A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
56
56
  B: Tensor, # (l, n, k)
57
- Out: Tensor, # (l, m, n)
58
- PreAct: Tensor, # (l, m, n)
59
- PostAct: Tensor, # (l, m, n)
57
+ Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m
58
+ PreAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
59
+ PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
60
60
  tile_count_semaphore: Optional[Tensor], # (1,)
61
61
  activation: Optional[str],
62
62
  tile_M: int,
@@ -65,12 +65,31 @@ def gemm_dact_sm90(
65
65
  cluster_N: int,
66
66
  pingpong: bool = True,
67
67
  persistent: bool = True,
68
+ cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
69
+ A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
68
70
  ) -> None:
71
+ if cu_seqlens_m is not None:
72
+ assert persistent, "varlen_m requires persistent=True"
73
+ assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
74
+ assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
75
+ assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
76
+ assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
77
+ gather_A = A_idx is not None
78
+ if gather_A:
79
+ assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
80
+ assert cluster_N == 1, "gather_A requires cluster_N=1"
69
81
  assert activation in dact_fn_map, f"Unsupported activation {activation}"
82
+
70
83
  L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
71
- A, B, Out, PreAct, additional_tensors={"PostAct": PostAct}
84
+ A,
85
+ B,
86
+ Out,
87
+ PreAct,
88
+ additional_tensors={"PostAct": PostAct},
89
+ cu_seqlens_m=cu_seqlens_m,
90
+ A_idx=A_idx,
72
91
  )
73
- GemmWrapperBase.permute_tensors(tensor_infos)
92
+ GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
74
93
  GemmWrapperBase.extract_dtypes(tensor_infos)
75
94
  major_configs = {
76
95
  "A": ("m", "k", "l"),
@@ -101,6 +120,19 @@ def gemm_dact_sm90(
101
120
  scheduler_args = GemmWrapperBase.create_scheduler_args(
102
121
  max_active_clusters, tile_count_semaphore
103
122
  )
123
+
124
+ # Create varlen arguments if needed (assumes persistent=True when varlen_m)
125
+ varlen_args = GemmWrapperBase.create_varlen_args(
126
+ cu_seqlens_m,
127
+ None, # cu_seqlens_k
128
+ A_idx,
129
+ max_active_clusters,
130
+ cluster_shape_mnk,
131
+ tensor_infos,
132
+ GemmDActSm90.num_epi_tensormaps,
133
+ pingpong,
134
+ )
135
+
104
136
  current_stream = cutlass_torch.current_stream()
105
137
  compile_key = GemmWrapperBase.get_compile_key(
106
138
  tensor_infos,
@@ -110,6 +142,8 @@ def gemm_dact_sm90(
110
142
  pingpong,
111
143
  persistent,
112
144
  tile_count_semaphore is not None,
145
+ cu_seqlens_m is not None,
146
+ A_idx is not None,
113
147
  key_tensor_names=("A", "B", "D", "PostAct", "C"),
114
148
  )
115
149
  cache = gemm_dact_sm90.compile_cache
@@ -121,6 +155,7 @@ def gemm_dact_sm90(
121
155
  cluster_shape_mnk,
122
156
  pingpong=pingpong,
123
157
  is_persistent=persistent,
158
+ gather_A=gather_A,
124
159
  )
125
160
  cache[compile_key] = cute.compile(
126
161
  gemm,
@@ -130,8 +165,7 @@ def gemm_dact_sm90(
130
165
  tensor_infos["C"].cute_tensor,
131
166
  epi_args,
132
167
  scheduler_args,
133
- None, # varlen_args
134
- None, # mAIdx
168
+ varlen_args,
135
169
  current_stream,
136
170
  )
137
171
  cache[compile_key](
@@ -141,8 +175,7 @@ def gemm_dact_sm90(
141
175
  tensor_infos["C"].cute_tensor,
142
176
  epi_args,
143
177
  scheduler_args,
144
- None,
145
- None,
178
+ varlen_args,
146
179
  current_stream,
147
180
  )
148
181