quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,935 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ from typing import Tuple, Optional
+ from dataclasses import dataclass, fields
+ from enum import IntEnum
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Int32, Boolean, const_expr
+
+ import quack.utils as utils
+ from quack.fast_math import FastDivmod
+ from quack.pipeline import PipelineStateWAdvance
+
+
+ @dataclass
+ class ParamsBase:
+     def __extract_mlir_values__(self):
+         all_fields = [getattr(self, field.name) for field in fields(self)]
+         non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)]
+         values, self._values_pos = [], []
+         for obj in non_constexpr_fields:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
+         constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)}
+         non_constexpr_fields = {
+             n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr)
+         }
+         for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
+             non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
+             values = values[n_items:]
+         return self.__class__(**non_constexpr_fields, **constexpr_fields)
+
+
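For intuition, here is a minimal pure-Python sketch of the flatten/unflatten contract that ParamsBase implements for the JIT boundary. The toy ToyParams class and the flatten/unflatten helpers are hypothetical stand-ins for cutlass.extract_mlir_values / cutlass.new_from_mlir_values; only the bookkeeping via _values_pos mirrors the code above.

    from dataclasses import dataclass

    def flatten(obj):
        # Stand-in for cutlass.extract_mlir_values: a list already is "dynamic values".
        return list(obj) if isinstance(obj, list) else [obj]

    def unflatten(template, values):
        # Stand-in for cutlass.new_from_mlir_values: rebuild in the template's shape.
        return list(values) if isinstance(template, list) else values[0]

    @dataclass
    class ToyParams:  # hypothetical example, not part of quack
        offsets: list           # "dynamic" field contributing several scalars
        scale: float            # "dynamic" field contributing one scalar
        stages: str = "static"  # plays the role of a cutlass.Constexpr field

        def extract(self):
            dynamic = [self.offsets, self.scale]
            self._values_pos = [len(flatten(f)) for f in dynamic]
            return [v for f in dynamic for v in flatten(f)]

        def rebuild(self, values):
            rebuilt = []
            for field, n in zip([self.offsets, self.scale], self._values_pos):
                rebuilt.append(unflatten(field, values[:n]))
                values = values[n:]
            return ToyParams(*rebuilt, stages=self.stages)

    p = ToyParams([1, 2, 3], 0.5)
    vals = p.extract()           # [1, 2, 3, 0.5]
    assert p.rebuild(vals) == p  # round-trips; the constexpr field passes through untouched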
+ class RasterOrderOption(IntEnum):
+     AlongM = 0
+     AlongN = 1
+     Heuristic = 2  # Pick AlongM if tiles_n > tiles_m, else AlongN
+
+
+ class RasterOrder(IntEnum):
+     AlongM = 0
+     AlongN = 1
+
+
+ @cute.jit
+ def get_raster_order_from_option(
+     raster_order_option: RasterOrderOption, problem_shape_ncluster_mn: cute.Shape, group_size: Int32
+ ) -> RasterOrder:
+     raster_order = (
+         RasterOrder.AlongM
+         if raster_order_option == RasterOrderOption.AlongM
+         else RasterOrder.AlongN
+     )
+     if raster_order_option == RasterOrderOption.Heuristic:
+         problem_blocks_m = cute.round_up(problem_shape_ncluster_mn[0], group_size)
+         problem_blocks_n = cute.round_up(problem_shape_ncluster_mn[1], group_size)
+         raster_order = (
+             RasterOrder.AlongM if problem_blocks_n > problem_blocks_m else RasterOrder.AlongN
+         )
+     return raster_order
+
+
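A quick plain-Python model of the Heuristic branch (the helper and the numbers are illustrative only): both extents are rounded up to the group size before comparing, so a 3 x 8 cluster grid with group_size 4 rasterizes along M because the rounded N extent (8) exceeds the rounded M extent (4).

    def raster_order_heuristic(ncluster_m, ncluster_n, group_size):
        # Mirrors the Heuristic branch above: round both extents up to the
        # group size, then go along M when the N extent is the larger one.
        round_up = lambda x, m: (x + m - 1) // m * m
        blocks_m = round_up(ncluster_m, group_size)
        blocks_n = round_up(ncluster_n, group_size)
        return "AlongM" if blocks_n > blocks_m else "AlongN"

    assert raster_order_heuristic(3, 8, 4) == "AlongM"  # 8 > 4: fast axis is M
    assert raster_order_heuristic(8, 3, 4) == "AlongN"  # 4 <= 8: fast axis is N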
+ @dataclass
+ class TileSchedulerArguments(ParamsBase):
+     problem_shape_ntile_mnl: cute.Shape
+     raster_order: RasterOrderOption
+     group_size: Int32
+     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
+     tile_count_semaphore: Optional[cute.Pointer] = None
+     is_persistent: cutlass.Constexpr[bool] = False
+
+
+ class TileScheduler:
+     @dataclass
+     class Params(ParamsBase):
+         problem_shape_ncluster_mnl: cute.Shape
+         raster_order: RasterOrder
+         num_clusters_per_problem_divmod: FastDivmod
+         num_groups_regular: Int32
+         group_size_divmod: FastDivmod
+         group_size_tail_divmod: FastDivmod
+         num_clusters_in_group_divmod: FastDivmod
+         tile_count_semaphore: Optional[cute.Pointer]
+         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
+         is_persistent: cutlass.Constexpr[bool]
+
+         @staticmethod
+         @cute.jit
+         def create(args: TileSchedulerArguments, *, loc=None, ip=None) -> "TileScheduler.Params":
+             assert args.cluster_shape_mnk[2] == 1
+             cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
+             problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
+             problem_shape_ncluster_mn = cute.ceil_div(problem_shape_ntile_mn, cluster_shape_mn)
+             problem_shape_ncluster_mnl = problem_shape_ncluster_mn + (
+                 args.problem_shape_ntile_mnl[2],
+             )
+             num_clusters_per_problem = cute.size(problem_shape_ncluster_mn)
+             raster_order = get_raster_order_from_option(
+                 args.raster_order, problem_shape_ncluster_mn, args.group_size
+             )
+             ncluster_fast = (
+                 problem_shape_ncluster_mn[0]
+                 if raster_order == RasterOrder.AlongM
+                 else problem_shape_ncluster_mn[1]
+             )
+             ncluster_slow = (
+                 problem_shape_ncluster_mn[1]
+                 if raster_order == RasterOrder.AlongM
+                 else problem_shape_ncluster_mn[0]
+             )
+             group_size = min(args.group_size, ncluster_fast)
+             group_size_tail = ncluster_fast % group_size
+             num_groups_regular = ncluster_fast // group_size
+             num_clusters_in_group = group_size * ncluster_slow
+             return TileScheduler.Params(
+                 problem_shape_ncluster_mnl,
+                 raster_order,
+                 FastDivmod.create(num_clusters_per_problem),
+                 num_groups_regular,
+                 FastDivmod.create(group_size),
+                 # Don't divide by 0
+                 FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
+                 FastDivmod.create(num_clusters_in_group),
+                 args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+                 cluster_shape_mn,
+                 args.is_persistent,
+             )
+
+     def __init__(
+         self,
+         current_work_linear_idx: Int32,
+         num_tiles_executed: Int32,
+         tile_count: Optional[cute.Tensor],
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
+         pipeline_state: PipelineStateWAdvance,
+         params: Params,
+         *,
+         loc=None,
+         ip=None,
+     ):
+         self._current_work_linear_idx = current_work_linear_idx
+         self._num_tiles_executed = num_tiles_executed
+         self._tile_count = tile_count
+         self._scheduler_pipeline = scheduler_pipeline
+         self._pipeline_state = pipeline_state
+         self.params = params
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return TileScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(
+         params: Params,
+         tile_count: Optional[cute.Tensor] = None,
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
+         is_scheduler_warp: bool | Boolean = False,
+         *,
+         loc=None,
+         ip=None,
+     ) -> "TileScheduler":
+         """is_scheduler_warp should only be true for one warp in the whole cluster"""
+         stages = 0
+         if const_expr(not params.is_persistent):
+             cidx, cidy, _ = cute.arch.cluster_idx()
+             cdimx, _, _ = cute.arch.cluster_dim()
+             cluster_id = cidx + cidy * cdimx
+             current_work_linear_idx = Int32(cluster_id)
+         else:
+             _, _, bidz = cute.arch.block_idx()
+             current_work_linear_idx = Int32(bidz)
+         if const_expr(params.tile_count_semaphore is not None):
+             assert tile_count is not None
+             assert scheduler_pipeline is not None
+             stages = const_expr(cute.size(tile_count))
+         return TileScheduler(
+             current_work_linear_idx,
+             Int32(0),  # num_tiles_executed
+             tile_count,
+             scheduler_pipeline,
+             PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
+             params,
+             loc=loc,
+             ip=ip,
+         )
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         max_active_clusters: Int32,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         num_ctas_mnl = tuple(
+             x * y for x, y in zip(params.problem_shape_ncluster_mnl, params.cluster_shape_mn)
+         ) + (params.problem_shape_ncluster_mnl[2],)
+         if const_expr(not params.is_persistent):
+             return num_ctas_mnl
+         else:
+             num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
+             num_ctas_per_cluster = cute.size(params.cluster_shape_mn, loc=loc, ip=ip)
+             # Total ctas that can run in one wave
+             num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
+             num_persistent_ctas = cutlass.min(num_ctas_in_problem, num_ctas_per_wave)
+             num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
+             return (*params.cluster_shape_mn, num_persistent_clusters)
+
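The persistent branch launches only as many clusters as fit in one wave. A hedged arithmetic sketch with made-up numbers (plain Python in place of the cute/cutlass calls):

    def persistent_grid(ncluster_mnl, cluster_shape_mn, max_active_clusters):
        # Mirrors get_grid_shape: CTAs in the whole problem vs. CTAs in one wave.
        num_ctas_in_problem = (
            ncluster_mnl[0] * cluster_shape_mn[0]
            * ncluster_mnl[1] * cluster_shape_mn[1]
            * ncluster_mnl[2]
        )
        num_ctas_per_cluster = cluster_shape_mn[0] * cluster_shape_mn[1]
        num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
        num_persistent_ctas = min(num_ctas_in_problem, num_ctas_per_wave)
        return (*cluster_shape_mn, num_persistent_ctas // num_ctas_per_cluster)

    # 6 x 4 x 1 clusters of 2 CTAs = 48 CTAs, but one wave holds 20 clusters = 40 CTAs,
    # so 20 persistent clusters are launched and loop over the 24 clusters' worth of tiles.
    assert persistent_grid((6, 4, 1), (2, 1), 20) == (2, 1, 20)
    # A small problem is capped by its own tile count instead.
    assert persistent_grid((2, 2, 1), (2, 1), 20) == (2, 1, 4)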
+     @cute.jit
+     def _swizzle_cta(
+         self, cluster_id_in_problem: Int32, *, loc=None, ip=None
+     ) -> Tuple[Int32, Int32]:
+         # CTA Swizzle to promote L2 data reuse
+         params = self.params
+         group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(cluster_id_in_problem)
+         cid_fast_in_group, cid_slow = Int32(0), Int32(0)
+         if group_id < params.num_groups_regular:
+             cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
+         else:  # tail part
+             cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
+         if group_id % 2 == 1:  # serpentine order
+             ncluster_slow = (
+                 params.problem_shape_ncluster_mnl[1]
+                 if params.raster_order == RasterOrder.AlongM
+                 else params.problem_shape_ncluster_mnl[0]
+             )
+             cid_slow = ncluster_slow - 1 - cid_slow
+         cid_fast = group_id * params.group_size_divmod.divisor + cid_fast_in_group
+         cid_m, cid_n = cid_fast, cid_slow
+         if params.raster_order == RasterOrder.AlongN:
+             cid_m, cid_n = cid_slow, cid_fast
+         return cid_m, cid_n
+
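A pure-Python model of this swizzle for checking the traversal order, under the simplifying assumption that group_size divides the fast extent (no tail group); names and sizes are illustrative. With group_size 2 on a 4 x 3 cluster grid and AlongM order, consecutive linear ids sweep a 2-wide strip of M across all of N, and odd groups walk N backwards (serpentine), so the operand tiles touched at a group boundary stay hot in L2.

    def swizzle(idx, ncluster_m, ncluster_n, group_size):
        # AlongM: fast axis is M, slow axis is N.
        clusters_in_group = group_size * ncluster_n
        group_id, id_in_group = divmod(idx, clusters_in_group)
        cid_slow, cid_fast_in_group = divmod(id_in_group, group_size)
        if group_id % 2 == 1:  # serpentine: odd groups traverse N in reverse
            cid_slow = ncluster_n - 1 - cid_slow
        return group_id * group_size + cid_fast_in_group, cid_slow  # (cid_m, cid_n)

    order = [swizzle(i, 4, 3, 2) for i in range(12)]
    # First group sweeps N forward over a 2-wide strip of M...
    assert order[:6] == [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
    # ...then the next group comes back with N reversed.
    assert order[6:] == [(2, 2), (3, 2), (2, 1), (3, 1), (2, 0), (3, 0)]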
+     @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+         params = self.params
+         if const_expr(not params.is_persistent):
+             cluster_id_in_problem = self._current_work_linear_idx
+             _, _, bidz = cute.arch.block_idx()
+         else:
+             bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
+                 self._current_work_linear_idx
+             )
+         cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip)
+         # Get the pid from cluster id
+         bidx_in_cluster = cute.arch.block_in_cluster_idx()
+         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
+         tile_coord_mnkl = (pid_m, pid_n, None, bidz)
+         if const_expr(not params.is_persistent):
+             is_valid = self._num_tiles_executed == 0
+         else:
+             is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
+         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
+
+     def initial_work_tile_info(self, *, loc=None, ip=None):
+         return self.get_current_work(loc=loc, ip=ip)
+
+     @cute.jit
+     def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
+         """is_scheduler_warp should only be true for one warp in the whole cluster"""
+         if const_expr(self.params.tile_count_semaphore is not None):
+             params = self.params
+             current_work_linear_idx = self._current_work_linear_idx
+             if is_scheduler_warp:
+                 if cute.arch.lane_idx() == 0:
+                     num_persistent_clusters = cute.arch.grid_dim()[2]
+                     current_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32(
+                         cute.size(params.problem_shape_ncluster_mnl) - 1,
+                         params.tile_count_semaphore,
+                     )
+                 # lane 0 already has the right tile_idx, just need to broadcast
+                 current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
+                 self._current_work_linear_idx = current_work_linear_idx
+
+     @cute.jit
+     def advance_to_next_work(
+         self,
+         is_scheduler_warp: bool | Boolean = False,
+         *,
+         advance_count: int = 1,
+         loc=None,
+         ip=None,
+     ):
+         tidx = cute.arch.thread_idx()[0]
+         bidx = cute.arch.block_idx()[0]
+         params = self.params
+         if const_expr(params.is_persistent):
+             num_persistent_clusters = cute.arch.grid_dim()[2]
+             if const_expr(params.tile_count_semaphore is None):  # Static persistent
+                 self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters)
+             else:  # Dynamic persistent
+                 self._pipeline_state.advance_iters(advance_count - 1)
+                 current_work_linear_idx = self._current_work_linear_idx
+                 if is_scheduler_warp:
+                     self._scheduler_pipeline.producer_acquire(self._pipeline_state)
+                     lane_idx = cute.arch.lane_idx()
+                     if lane_idx < cute.size(params.cluster_shape_mn):
+                         # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                         if const_expr(cute.size(params.cluster_shape_mn) == 1):
+                             self._tile_count[self._pipeline_state.index] = current_work_linear_idx
+                             self._scheduler_pipeline.producer_commit(self._pipeline_state)
+                         else:
+                             peer_cta_rank_in_cluster = lane_idx
+                             mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
+                                 self._pipeline_state
+                             )
+                             cute.arch.mbarrier_arrive_and_expect_tx(
+                                 mbar_ptr, 4, peer_cta_rank_in_cluster
+                             )
+                             utils.store_shared_remote(
+                                 val=current_work_linear_idx,
+                                 smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
+                                 mbar_ptr=mbar_ptr,
+                                 peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+                             )
+                     # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
+                 else:
+                     # if tidx % 64 == 0: cute.printf("bidx = {},tidx = {}, before full wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                     self._scheduler_pipeline.consumer_wait(self._pipeline_state)
+                     # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after full wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                     current_work_linear_idx = self._tile_count[self._pipeline_state.index]
+                     # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after smem read, idx = {}", bidx, tidx, current_work_linear_idx)
+                     cute.arch.sync_warp()
+                     with cute.arch.elect_one():
+                         # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, before empty arrive", bidx, tidx)
+                         self._scheduler_pipeline.consumer_release(self._pipeline_state)
+                         # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after empty arrive", bidx, tidx)
+                 self._current_work_linear_idx = current_work_linear_idx
+                 self._pipeline_state.advance()
+         self._num_tiles_executed += Int32(advance_count)
+
+     def producer_tail(self):
+         if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None):
+             self._scheduler_pipeline.producer_tail(self._pipeline_state)
+
+     @property
+     def num_tiles_executed(self) -> Int32:
+         return self._num_tiles_executed
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [
+             self._current_work_linear_idx,
+             self._num_tiles_executed,
+             self._tile_count,
+             self._scheduler_pipeline,
+             self._pipeline_state,
+             self.params,
+         ]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip(
+             [
+                 self._current_work_linear_idx,
+                 self._num_tiles_executed,
+                 self._tile_count,
+                 self._scheduler_pipeline,
+                 self._pipeline_state,
+                 self.params,
+             ],
+             self._values_pos,
+         ):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return self.__class__(*(tuple(obj_list)), loc=self._loc)
+
+
+ @cute.jit
+ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
+     """Convert a linear index into (row, col) coordinates of a lower-triangular
+     matrix enumerated row by row, with col <= row.
+     """
+     row = utils.ceil((utils.sqrt(2 * idx + 2.25) - 0.5)) - 1
+     col = idx - (row * (row + 1)) // 2
+     return row, col
+
+
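The closed form inverts the triangular numbering T(row) = row * (row + 1) / 2: the row containing idx is the largest row with T(row) <= idx, which works out to ceil(sqrt(2 * idx + 2.25) - 0.5) - 1. A plain-Python reference to sanity-check it (math.* standing in for utils.*):

    import math

    def triangular_idx_to_coord_ref(idx: int) -> tuple:
        row = math.ceil(math.sqrt(2 * idx + 2.25) - 0.5) - 1
        col = idx - row * (row + 1) // 2
        return row, col

    # Enumerates the lower triangle row by row: 0 -> (0, 0), 1 -> (1, 0), 2 -> (1, 1), ...
    expected = [(r, c) for r in range(64) for c in range(r + 1)]
    assert [triangular_idx_to_coord_ref(i) for i in range(len(expected))] == expected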
+ class TriangularTileScheduler(TileScheduler):
+     """We assume the tile size per cluster is square (e.g., 128 x 256 per CTA, with cluster 2 x 1)"""
+
+     @dataclass
+     class Params(ParamsBase):
+         problem_shape_ncluster_mnl: cute.Shape
+         num_clusters_per_problem_divmod: FastDivmod
+         group_size_inv_f32: cutlass.Float32
+         num_groups_regular: Int32
+         group_size_divmod: FastDivmod
+         group_size_tail_divmod: FastDivmod
+         group_size_mul_group_size_divmod: FastDivmod
+         group_size_tail_mul_group_size_divmod: FastDivmod
+         tile_count_semaphore: Optional[cute.Pointer]
+         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
+         is_persistent: cutlass.Constexpr[bool]
+
+         @staticmethod
+         @cute.jit
+         def create(
+             args: TileSchedulerArguments, *, loc=None, ip=None
+         ) -> "TriangularTileScheduler.Params":
+             assert args.cluster_shape_mnk[2] == 1
+             cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
+             problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
+             problem_shape_ncluster_mn = cute.ceil_div(problem_shape_ntile_mn, cluster_shape_mn)
+             problem_shape_ncluster_mnl = problem_shape_ncluster_mn + (
+                 args.problem_shape_ntile_mnl[2],
+             )
+             cluster_m = problem_shape_ncluster_mn[0]
+             # Assume that each cluster is responsible for a square tile
+             num_clusters_per_problem = cluster_m * (cluster_m + 1) // 2
+             group_size = min(args.group_size, cluster_m)
+             group_size_tail = cluster_m % group_size
+             num_groups_regular = cluster_m // group_size
+             return TriangularTileScheduler.Params(
+                 problem_shape_ncluster_mnl,
+                 FastDivmod.create(num_clusters_per_problem),
+                 cutlass.Float32(1.0 / group_size),
+                 num_groups_regular,
+                 FastDivmod.create(group_size),
+                 # Don't divide by 0
+                 FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
+                 FastDivmod.create(group_size * group_size),
+                 FastDivmod.create((group_size_tail if group_size_tail > 0 else 1) * group_size),
+                 args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+                 cluster_shape_mn,
+                 args.is_persistent,
+             )
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return TriangularTileScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(
+         params: Params,
+         tile_count: Optional[cute.Tensor] = None,
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
+         is_scheduler_warp: bool | Boolean = False,
+         *,
+         loc=None,
+         ip=None,
+     ) -> "TriangularTileScheduler":
+         stages = 0
+         if const_expr(not params.is_persistent):
+             cluster_id, _, _ = cute.arch.cluster_idx()
+             current_work_linear_idx = Int32(cluster_id)
+         else:
+             _, _, bidz = cute.arch.block_idx()
+             current_work_linear_idx = Int32(bidz)
+         if const_expr(params.tile_count_semaphore is not None):
+             assert tile_count is not None
+             assert scheduler_pipeline is not None
+             stages = const_expr(cute.size(tile_count))
+         return TriangularTileScheduler(
+             current_work_linear_idx,
+             Int32(0),  # num_tiles_executed
+             tile_count,
+             scheduler_pipeline,
+             PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
+             params,
+             loc=loc,
+             ip=ip,
+         )
+
482
+ # called by host
483
+ @staticmethod
484
+ def get_grid_shape(
485
+ params: Params,
486
+ max_active_clusters: Int32,
487
+ *,
488
+ loc=None,
489
+ ip=None,
490
+ ) -> Tuple[Int32, Int32, Int32]:
491
+ clusters = (
492
+ params.num_clusters_per_problem_divmod.divisor,
493
+ 1,
494
+ params.problem_shape_ncluster_mnl[2],
495
+ )
496
+ num_ctas_mnl = tuple(x * y for x, y in zip(clusters, params.cluster_shape_mn)) + (
497
+ params.problem_shape_ncluster_mnl[2],
498
+ )
499
+ if const_expr(not params.is_persistent):
500
+ return num_ctas_mnl
501
+ else:
502
+ num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
503
+ num_ctas_per_cluster = cute.size(params.cluster_shape_mn, loc=loc, ip=ip)
504
+ # Total ctas that can run in one wave
505
+ num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
506
+ num_persistent_ctas = cutlass.min(num_ctas_in_problem, num_ctas_per_wave)
507
+ num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
508
+ return (*params.cluster_shape_mn, num_persistent_clusters)
509
+
+     @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+         params = self.params
+         if const_expr(not params.is_persistent):
+             cluster_id_in_problem = self._current_work_linear_idx
+             _, _, bidz = cute.arch.block_idx()
+         else:
+             bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
+                 self._current_work_linear_idx
+             )
+         # CTA Swizzle to promote L2 data reuse
+         group_size = params.group_size_divmod.divisor
+         group_id = (
+             utils.ceil(
+                 (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
+             )
+             - 1
+         )
+         cid_m_start = group_id * group_size
+         id_in_group = cluster_id_in_problem - (cid_m_start * (cid_m_start + 1)) // 2
+         group_size_actual = (
+             group_size
+             if group_id < params.num_groups_regular
+             else params.group_size_tail_divmod.divisor
+         )
+         group_col, group_remainder = Int32(0), Int32(0)
+         if group_id < params.num_groups_regular:
+             group_col, group_remainder = params.group_size_mul_group_size_divmod.divmod(id_in_group)
+         else:  # tail part
+             group_col, group_remainder = params.group_size_tail_mul_group_size_divmod.divmod(
+                 id_in_group
+             )
+         cid_m_in_group, cid_n_in_group = Int32(0), Int32(0)
+         if id_in_group >= group_size_actual * group_size * group_id:  # triangular tail
+             cid_m_in_group, cid_n_in_group = triangular_idx_to_coord(group_remainder)
+         else:
+             if group_id < params.num_groups_regular:
+                 cid_n_in_group, cid_m_in_group = params.group_size_divmod.divmod(group_remainder)
+             else:
+                 cid_n_in_group, cid_m_in_group = params.group_size_tail_divmod.divmod(
+                     group_remainder
+                 )
+         cid_m = cid_m_start + cid_m_in_group
+         cid_n = group_col * group_size + cid_n_in_group
+
+         # Get the pid from cluster id
+         bidx_in_cluster = cute.arch.block_in_cluster_idx()
+         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
+         tile_coord_mnkl = (pid_m, pid_n, None, bidz)
+         if const_expr(not params.is_persistent):
+             is_valid = self._num_tiles_executed == 0
+         else:
+             is_valid = (
+                 self._current_work_linear_idx
+                 < params.num_clusters_per_problem_divmod.divisor
+                 * params.problem_shape_ncluster_mnl[2]
+             )
+         # bidx, bidy, bidz = cute.arch.block_idx()
+         # tidx, _, _ = cute.arch.thread_idx()
+         # if tidx == 0:
+         #     cute.printf("bidx = {}, bidy = {}, group_id = {}, id_in_group = {}, group_size_actual = {}, group_col = {}, group_remainder = {}, cid_n_in_group = {}, cid_m_in_group = {}, cid_m = {}, cid_n = {}, is_valid = {}",
+         #         bidx, bidy, group_id, id_in_group, group_size_actual, group_col, group_remainder, cid_n_in_group, cid_m_in_group, cid_m, cid_n, is_valid)
+         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
+
+
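One subtle point in get_current_work above: the group_id expression folds the row recovery and the division by group_size into a single ceil. With f = sqrt(2 * idx + 2.25) - 0.5 one has row = ceil(f) - 1, and since f lies in (row, row + 1], ceil(f / g) - 1 equals row // g. An exhaustive pure-Python check of that identity over small sizes (my own verification, not from the package):

    import math

    for group_size in range(1, 8):
        for idx in range(2000):
            f = math.sqrt(2 * idx + 2.25) - 0.5
            row = math.ceil(f) - 1
            group_id = math.ceil(f / group_size) - 1
            assert group_id == row // group_size
            # cid_m_start = group_id * group_size is then the first row of the group:
            assert group_id * group_size <= row < (group_id + 1) * group_size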
+ @dataclass
+ class VarlenMTileSchedulerArguments(ParamsBase):
+     problem_shape_ntile_mnl: cute.Shape
+     total_m: Int32
+     cu_seqlens_m: cute.Tensor
+     raster_order: cutlass.Constexpr[RasterOrderOption]
+     group_size: Int32
+     tile_shape_mnk: cutlass.Constexpr[cute.Shape]
+     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
+     tile_count_semaphore: Optional[cute.Pointer] = None
+     is_persistent: cutlass.Constexpr[bool] = False
+
+
+ class VarlenMTileScheduler(TileScheduler):
+     @dataclass
+     class Params(ParamsBase):
+         problem_shape_ncluster_mnl: cute.Shape
+         total_m: Int32
+         cu_seqlens_m: cute.Tensor
+         raster_order: cutlass.Constexpr[RasterOrder]
+         group_size: Int32
+         group_size_divmod: Optional[FastDivmod]
+         group_size_tail_divmod: Optional[FastDivmod]
+         num_clusters_in_group_divmod: Optional[FastDivmod]
+         tile_shape_mn: cutlass.Constexpr[cute.Shape]
+         tile_count_semaphore: Optional[cute.Pointer]
+         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
+         is_persistent: cutlass.Constexpr[bool]
+
+         @staticmethod
+         @cute.jit
+         def create(
+             args: VarlenMTileSchedulerArguments, *, loc=None, ip=None
+         ) -> "VarlenMTileScheduler.Params":
+             assert args.cluster_shape_mnk[2] == 1
+             cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
+             tile_shape_mn = const_expr(cute.select(args.tile_shape_mnk, mode=[0, 1]))
+             # problem_shape_ntile_mnl[0] will be None for VarlenM
+             problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
+             problem_shape_ncluster_mn = (
+                 None,
+                 cute.ceil_div(problem_shape_ntile_mn[1], cluster_shape_mn[1]),
+             )
+             problem_shape_ncluster_mnl = problem_shape_ncluster_mn + (
+                 args.problem_shape_ntile_mnl[2],
+             )
+             raster_order = const_expr(
+                 RasterOrder.AlongM
+                 if args.raster_order == RasterOrderOption.AlongM
+                 else RasterOrder.AlongN  # For Heuristic we also use AlongN
+             )
+             ncluster_fast = (
+                 problem_shape_ncluster_mn[0]
+                 if raster_order == RasterOrder.AlongM
+                 else problem_shape_ncluster_mn[1]
+             )
+             ncluster_slow = (
+                 problem_shape_ncluster_mn[1]
+                 if raster_order == RasterOrder.AlongM
+                 else problem_shape_ncluster_mn[0]
+             )
+             if const_expr(ncluster_fast is not None):
+                 group_size = min(args.group_size, ncluster_fast)
+                 group_size_tail = ncluster_fast % group_size
+             else:
+                 group_size, group_size_tail = args.group_size, None
+             if const_expr(ncluster_slow is not None):
+                 num_clusters_in_group = group_size * ncluster_slow
+             else:
+                 num_clusters_in_group = None
+             return VarlenMTileScheduler.Params(
+                 problem_shape_ncluster_mnl,
+                 args.total_m,
+                 args.cu_seqlens_m,
+                 raster_order,
+                 group_size,
+                 FastDivmod.create(group_size) if ncluster_fast is not None else None,
+                 # Don't divide by 0
+                 FastDivmod.create(group_size_tail if group_size_tail > 0 else 1)
+                 if group_size_tail is not None
+                 else None,
+                 FastDivmod.create(num_clusters_in_group)
+                 if num_clusters_in_group is not None
+                 else None,
+                 tile_shape_mn,
+                 args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+                 cluster_shape_mn,
+                 args.is_persistent,
+             )
+
+     def __init__(
+         self,
+         current_work_linear_idx: Int32,
+         num_tiles_executed: Int32,
+         current_batch_idx: Int32,
+         num_work_idx_before_cur_batch: Int32,
+         tile_count: Optional[cute.Tensor],
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
+         pipeline_state: PipelineStateWAdvance,
+         params: Params,
+         *,
+         loc=None,
+         ip=None,
+     ):
+         self._current_work_linear_idx = current_work_linear_idx
+         self._num_tiles_executed = num_tiles_executed
+         self._current_batch_idx = current_batch_idx
+         self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
+         self._tile_count = tile_count
+         self._scheduler_pipeline = scheduler_pipeline
+         self._pipeline_state = pipeline_state
+         self.params = params
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: VarlenMTileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return VarlenMTileScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(
+         params: Params,
+         tile_count: Optional[cute.Tensor] = None,
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
+         is_scheduler_warp: bool | Boolean = False,
+         *,
+         loc=None,
+         ip=None,
+     ) -> "VarlenMTileScheduler":
+         stages = 0
+         _, _, bidz = cute.arch.block_idx()
+         current_work_linear_idx = Int32(bidz)
+         if const_expr(params.tile_count_semaphore is not None):
+             assert tile_count is not None
+             assert scheduler_pipeline is not None
+             stages = const_expr(cute.size(tile_count))
+         return VarlenMTileScheduler(
+             current_work_linear_idx,
+             Int32(0),  # num_tiles_executed
+             Int32(0),  # current_batch_idx
+             Int32(0),  # num_work_idx_before_cur_batch
+             tile_count,
+             scheduler_pipeline,
+             PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
+             params,
+             loc=loc,
+             ip=ip,
+         )
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         max_active_clusters: Int32,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         block_size = params.tile_shape_mn[0] * params.cluster_shape_mn[0]
+         num_batch = params.problem_shape_ncluster_mnl[2]
+         total_clusters_m_max = (params.total_m + num_batch * (block_size - 1)) // block_size
+         total_clusters_max = total_clusters_m_max * params.problem_shape_ncluster_mnl[1]
+         if const_expr(not params.is_persistent):
+             return (*params.cluster_shape_mn, total_clusters_max)
+         else:
+             num_persistent_clusters = cutlass.min(max_active_clusters, total_clusters_max)
+             return (*params.cluster_shape_mn, num_persistent_clusters)
+
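The grid has to be sized before the per-batch seqlens are read on device, so get_grid_shape uses a worst-case bound: each of the num_batch sequences wastes at most block_size - 1 rows to rounding, hence sum_i ceil(s_i / b) <= (sum_i s_i + num_batch * (b - 1)) // b. A small check with hypothetical seqlens:

    import math

    def max_clusters_m(total_m, num_batch, block_size):
        # Same worst-case bound as get_grid_shape above.
        return (total_m + num_batch * (block_size - 1)) // block_size

    seqlens = [1, 127, 128, 300]  # hypothetical varlen batch
    actual = sum(math.ceil(s / 128) for s in seqlens)      # 1 + 1 + 1 + 3 = 6
    bound = max_clusters_m(sum(seqlens), len(seqlens), 128)
    assert actual <= bound == 8                            # (556 + 4 * 127) // 128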
+     @cute.jit
+     def _get_num_m_blocks(
+         self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int]
+     ) -> Int32:
+         num_batch = self.params.problem_shape_ncluster_mnl[2]
+         batch_idx = lane + bidb_start
+         cur_cu_seqlen = Int32(0)
+         if batch_idx <= num_batch:
+             cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx]
+         next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
+         seqlen = next_cu_seqlen - cur_cu_seqlen
+         return (
+             cute.ceil_div(seqlen, block_size)
+             if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1
+             else Int32(0)
+         )
+
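_get_num_m_blocks is a warp-cooperative read: lane i loads cu_seqlens_m[bidb_start + i], receives lane i + 1's value via shuffle-down, and converts the difference into an M-block count. A sequential sketch of what the 32 lanes compute together (list indexing stands in for the shuffles; the helper and sample cu_seqlens are assumptions for the example):

    WARP_SIZE = 32

    def get_num_m_blocks_ref(cu_seqlens, bidb_start, block_size):
        # One entry per lane: ceil(seqlen / block_size) for batches
        # bidb_start .. bidb_start + 30, and 0 for lanes past the last batch.
        num_batch = len(cu_seqlens) - 1
        out = []
        for lane in range(WARP_SIZE):
            batch_idx = bidb_start + lane
            # Lane 31 has nobody to shuffle from, so it always reports 0; that is
            # why the caller advances the batch index by WARP_SIZE - 1 per step.
            if batch_idx < num_batch and lane < WARP_SIZE - 1:
                seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
                out.append(-(-seqlen // block_size))  # ceil_div
            else:
                out.append(0)
        return out

    blocks = get_num_m_blocks_ref([0, 1, 128, 256, 556], bidb_start=0, block_size=128)
    assert blocks[:4] == [1, 1, 1, 3] and sum(blocks) == 6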
+     @cute.jit
+     def _swizzle_cta(
+         self, cluster_id_in_problem: Int32, num_clusters_m: Int32, *, loc=None, ip=None
+     ) -> Tuple[Int32, Int32]:
+         params = self.params
+         # CTA Swizzle to promote L2 data reuse
+         if const_expr(params.num_clusters_in_group_divmod is not None):
+             group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(
+                 cluster_id_in_problem
+             )
+             num_clusters_in_group = params.num_clusters_in_group_divmod.divisor
+         else:
+             assert params.raster_order == RasterOrder.AlongN
+             num_clusters_in_group = params.group_size * num_clusters_m
+             group_id = cluster_id_in_problem // num_clusters_in_group
+             id_in_group = cluster_id_in_problem - group_id * num_clusters_in_group
+         cid_fast_in_group, cid_slow = Int32(0), Int32(0)
+         if const_expr(
+             params.group_size_divmod is not None and params.group_size_tail_divmod is not None
+         ):
+             num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
+             if (group_id + 1) * num_clusters_in_group <= num_clusters:
+                 cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
+             else:  # tail part
+                 cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
+         else:
+             assert params.raster_order == RasterOrder.AlongM
+             group_size_actual = cutlass.min(
+                 params.group_size, num_clusters_m - group_id * params.group_size
+             )
+             cid_slow = id_in_group // group_size_actual
+             cid_fast_in_group = id_in_group - cid_slow * group_size_actual
+         if group_id % 2 == 1:  # serpentine order
+             ncluster_slow = (
+                 params.problem_shape_ncluster_mnl[1]
+                 if params.raster_order == RasterOrder.AlongM
+                 else num_clusters_m
+             )
+             cid_slow = ncluster_slow - 1 - cid_slow
+         cid_fast = group_id * params.group_size + cid_fast_in_group
+         cid_m, cid_n = cid_fast, cid_slow
+         if params.raster_order == RasterOrder.AlongN:
+             cid_m, cid_n = cid_slow, cid_fast
+         return cid_m, cid_n
+
+     @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+         params = self.params
+         lane_idx = cute.arch.lane_idx()
+         num_batch = self.params.problem_shape_ncluster_mnl[2]
+         block_size = params.tile_shape_mn[0] * params.cluster_shape_mn[0]
+         batch_idx = self._current_batch_idx
+         num_clusters_m = self._get_num_m_blocks(
+             lane_idx, bidb_start=batch_idx, block_size=block_size
+         )
+         num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
+         num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx)
+         # Total number of blocks for the next 31 problems, same for all lanes
+         clusters_in_problems = cute.arch.shuffle_sync(
+             num_clusters_cumulative, cute.arch.WARP_SIZE - 1
+         )
+         problems_end_tile = self._num_work_idx_before_cur_batch + clusters_in_problems
+         # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, num_clusters_cumulative = %d, problems_end_tile = %d", self._tile_idx, problems_end_tile, num_clusters_m, num_clusters_cumulative, problems_end_tile)
+         cid_m, cid_n = Int32(0), Int32(0)
+         next_tile_idx = self._current_work_linear_idx
+         while problems_end_tile <= next_tile_idx:
+             batch_idx += cute.arch.WARP_SIZE - 1
+             if batch_idx >= num_batch:
+                 batch_idx = Int32(num_batch)
+                 problems_end_tile = next_tile_idx + 1
+             else:
+                 num_clusters_m = self._get_num_m_blocks(
+                     lane_idx, bidb_start=batch_idx, block_size=block_size
+                 )
+                 num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
+                 num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx)
+                 clusters_in_problems = cute.arch.shuffle_sync(
+                     num_clusters_cumulative, cute.arch.WARP_SIZE - 1
+                 )
+                 problems_end_tile += clusters_in_problems
+         # Just a placeholder value in case batch_idx >= num_batch
+         num_work_idx_before_cur_batch = problems_end_tile - clusters_in_problems
+         if batch_idx >= num_batch:
+             cid_m, cid_n, batch_idx = Int32(0), Int32(0), Int32(num_batch)
+         else:
+             problems_start_tile = problems_end_tile - clusters_in_problems
+             # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, batch_idx = %d", self._tile_idx, problems_end_tile, num_clusters_m, batch_idx)
+             # The batch to process is the first one whose ending tile position is
+             # strictly greater than the current tile index.
+             batch_idx_in_problems = cute.arch.popc(
+                 cute.arch.vote_ballot_sync(
+                     problems_start_tile + num_clusters_cumulative <= next_tile_idx
+                 )
+             )
+             batch_idx += batch_idx_in_problems
+             num_clusters_prev_lane = (
+                 0
+                 if batch_idx_in_problems == 0
+                 else cute.arch.shuffle_sync(num_clusters_cumulative, batch_idx_in_problems - 1)
+             )
+             num_clusters_m = cute.arch.shuffle_sync(num_clusters_m, batch_idx_in_problems)
+             num_work_idx_before_cur_batch = problems_start_tile + num_clusters_prev_lane
+             cluster_id_in_problem = next_tile_idx - num_work_idx_before_cur_batch
+             # cid_n = cluster_id_in_problem // num_clusters_m
+             # cid_m = cluster_id_in_problem - cid_n * num_clusters_m
+             # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, cid_n=%d, cid_m=%d, is_valid = %d", self._tile_idx, batch_idx, cid_n, cid_m, is_valid)
+             cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, num_clusters_m, loc=loc, ip=ip)
+         self._current_batch_idx = batch_idx
+         self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
+
+         # Get the pid from cluster id
+         bidx_in_cluster = cute.arch.block_in_cluster_idx()
+         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
+         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
+         if const_expr(not params.is_persistent):
+             is_valid = self._num_tiles_executed == 0 and batch_idx < num_batch
+         else:
+             is_valid = batch_idx < num_batch
+         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
+
+     @cute.jit
+     def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
+         """is_scheduler_warp should only be true for one warp in the whole cluster"""
+         if const_expr(self.params.tile_count_semaphore is not None):
+             params = self.params
+             current_work_linear_idx = self._current_work_linear_idx
+             if is_scheduler_warp:
+                 if cute.arch.lane_idx() == 0:
+                     # cute.printf("before atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                     num_persistent_clusters = cute.arch.grid_dim()[2]
+                     current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
+                         1, params.tile_count_semaphore
+                     )
+                     # cute.printf("after atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                 # lane 0 already has the right tile_idx, just need to broadcast
+                 current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
+                 self._current_work_linear_idx = current_work_linear_idx
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [
+             self._current_work_linear_idx,
+             self._num_tiles_executed,
+             self._current_batch_idx,
+             self._num_work_idx_before_cur_batch,
+             self._tile_count,
+             self._scheduler_pipeline,
+             self._pipeline_state,
+             self.params,
+         ]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip(
+             [
+                 self._current_work_linear_idx,
+                 self._num_tiles_executed,
+                 self._current_batch_idx,
+                 self._num_work_idx_before_cur_batch,
+                 self._tile_count,
+                 self._scheduler_pipeline,
+                 self._pipeline_state,
+                 self.params,
+             ],
+             self._values_pos,
+         ):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return self.__class__(*(tuple(obj_list)), loc=self._loc)