quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

@@ -0,0 +1,937 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ from typing import Tuple, Optional
+ from dataclasses import dataclass
+ from enum import IntEnum
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Int32, Boolean, const_expr
+
+ import quack.utils as utils
+ from quack.fast_math import FastDivmod
+ from quack.pipeline import PipelineStateWAdvance
+ from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
+
+
+ class RasterOrderOption(IntEnum):
+     AlongM = 0
+     AlongN = 1
+     Heuristic = 2  # Pick AlongM if tiles_n > tiles_m, else AlongN
+
+
+ class RasterOrder(IntEnum):
+     AlongM = 0
+     AlongN = 1
+
+
+ @cute.jit
+ def get_raster_order_from_option(
+     raster_order_option: RasterOrderOption, problem_shape_ncluster_mn: cute.Shape, group_size: Int32
+ ) -> RasterOrder:
+     raster_order = (
+         RasterOrder.AlongM
+         if raster_order_option == RasterOrderOption.AlongM
+         else RasterOrder.AlongN
+     )
+     if raster_order_option == RasterOrderOption.Heuristic:
+         problem_blocks_m = cute.round_up(problem_shape_ncluster_mn[0], group_size)
+         problem_blocks_n = cute.round_up(problem_shape_ncluster_mn[1], group_size)
+         raster_order = (
+             RasterOrder.AlongM if problem_blocks_n > problem_blocks_m else RasterOrder.AlongN
+         )
+     return raster_order
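+ # Illustrative example: for a 4 x 9 cluster grid with group_size = 8, the
+ # heuristic rounds both extents up to the group size (4 -> 8, 9 -> 16) and,
+ # since 16 > 8, there are more tiles along N, so it picks RasterOrder.AlongM
+ # and lets M be the fast-moving raster dimension.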
+
+
+ # Grouping arguments together that should be passed to __call__
+ @dataclass
+ class TileSchedulerOptions(ArgumentsBase):
+     max_active_clusters: Int32
+     raster_order: cutlass.Constexpr[RasterOrderOption] = RasterOrderOption.Heuristic
+     max_swizzle_size: Int32 = Int32(8)
+     tile_count_semaphore: Optional[cute.Pointer] = None
+     batch_idx_permute: Optional[cute.Tensor] = None
+
+
+ @dataclass
+ class TileSchedulerArguments(ArgumentsBase):
+     problem_shape_ntile_mnl: cute.Shape
+     raster_order: cutlass.Constexpr[RasterOrderOption]
+     group_size: Int32
+     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
+     tile_count_semaphore: Optional[cute.Pointer] = None
+     batch_idx_permute: Optional[cute.Tensor] = None
+     is_persistent: cutlass.Constexpr[bool] = False
+
+
+ class TileScheduler:
+     @dataclass
+     class Params(ParamsBase):
+         problem_shape_ncluster_mnl: cute.Shape
+         raster_order: RasterOrder
+         num_clusters_per_problem_divmod: FastDivmod
+         num_groups_regular: Int32
+         group_size_divmod: FastDivmod
+         group_size_tail_divmod: FastDivmod
+         num_clusters_in_group_divmod: FastDivmod
+         tile_count_semaphore: Optional[cute.Pointer]
+         batch_idx_permute: Optional[cute.Tensor]
+         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
+         is_persistent: cutlass.Constexpr[bool]
+
+         @staticmethod
+         @cute.jit
+         def create(args: TileSchedulerArguments, *, loc=None, ip=None) -> "TileScheduler.Params":
+             assert args.cluster_shape_mnk[2] == 1
+             cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
+             problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
+             problem_shape_ncluster_mn = cute.ceil_div(problem_shape_ntile_mn, cluster_shape_mn)
+             problem_shape_ncluster_mnl = problem_shape_ncluster_mn + (
+                 args.problem_shape_ntile_mnl[2],
+             )
+             num_clusters_per_problem = cute.size(problem_shape_ncluster_mn)
+             raster_order = get_raster_order_from_option(
+                 args.raster_order, problem_shape_ncluster_mn, args.group_size
+             )
+             ncluster_fast = (
+                 problem_shape_ncluster_mn[0]
+                 if raster_order == RasterOrder.AlongM
+                 else problem_shape_ncluster_mn[1]
+             )
+             ncluster_slow = (
+                 problem_shape_ncluster_mn[1]
+                 if raster_order == RasterOrder.AlongM
+                 else problem_shape_ncluster_mn[0]
+             )
+             group_size = min(args.group_size, ncluster_fast)
+             group_size_tail = ncluster_fast % group_size
+             num_groups_regular = ncluster_fast // group_size
+             num_clusters_in_group = group_size * ncluster_slow
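+             # Worked example: ncluster_fast = 11, ncluster_slow = 4 with
+             # args.group_size = 4 gives group_size = 4, group_size_tail = 3,
+             # num_groups_regular = 2, num_clusters_in_group = 16: two full
+             # 4-wide groups cover fast indices 0-7, a 3-wide tail covers 8-10.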
+             return TileScheduler.Params(
+                 problem_shape_ncluster_mnl,
+                 raster_order,
+                 FastDivmod.create(num_clusters_per_problem),
+                 num_groups_regular,
+                 FastDivmod.create(group_size),
+                 # Don't divide by 0
+                 FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
+                 FastDivmod.create(num_clusters_in_group),
+                 args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+                 args.batch_idx_permute,
+                 cluster_shape_mn,
+                 args.is_persistent,
+             )
+
+     def __init__(
+         self,
+         current_work_linear_idx: Int32,
+         num_tiles_executed: Int32,
+         tile_count: Optional[cute.Tensor],
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
+         pipeline_state: PipelineStateWAdvance,
+         params: Params,
+         *,
+         loc=None,
+         ip=None,
+     ):
+         self._current_work_linear_idx = current_work_linear_idx
+         self._num_tiles_executed = num_tiles_executed
+         self._tile_count = tile_count
+         self._scheduler_pipeline = scheduler_pipeline
+         self._pipeline_state = pipeline_state
+         self.params = params
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return TileScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(
+         params: Params,
+         tile_count: Optional[cute.Tensor] = None,
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
+         is_scheduler_warp: bool | Boolean = False,
+         *,
+         loc=None,
+         ip=None,
+     ) -> "TileScheduler":
+         """is_scheduler_warp should only be true for one warp in the whole cluster"""
+         stages = 0
+         if const_expr(not params.is_persistent):
+             cidx, cidy, _ = cute.arch.cluster_idx()
+             cdimx, _, _ = cute.arch.cluster_dim()
+             cluster_id = cidx + cidy * cdimx
+             current_work_linear_idx = Int32(cluster_id)
+         else:
+             _, _, bidz = cute.arch.block_idx()
+             current_work_linear_idx = Int32(bidz)
+         if const_expr(params.tile_count_semaphore is not None):
+             assert tile_count is not None
+             assert scheduler_pipeline is not None
+             stages = const_expr(cute.size(tile_count))
+         return TileScheduler(
+             current_work_linear_idx,
+             Int32(0),  # num_tiles_executed
+             tile_count,
+             scheduler_pipeline,
+             PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
+             params,
+             loc=loc,
+             ip=ip,
+         )
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         max_active_clusters: Int32,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         num_ctas_mnl = tuple(
+             x * y for x, y in zip(params.problem_shape_ncluster_mnl, params.cluster_shape_mn)
+         ) + (params.problem_shape_ncluster_mnl[2],)
+         if const_expr(not params.is_persistent):
+             return num_ctas_mnl
+         else:
+             num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
+             num_ctas_per_cluster = cute.size(params.cluster_shape_mn, loc=loc, ip=ip)
+             # Total ctas that can run in one wave
+             num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
+             num_persistent_ctas = cutlass.min(num_ctas_in_problem, num_ctas_per_wave)
+             num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
+             return (*params.cluster_shape_mn, num_persistent_clusters)
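+     # Worked example: a 3 x 5 cluster grid with cluster_shape_mn = (2, 1) and one
+     # batch needs (6, 5, 1) CTAs. In persistent mode with max_active_clusters = 8,
+     # num_persistent_ctas = min(30, 8 * 2) = 16, so the launched grid is (2, 1, 8).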
+
+     @cute.jit
+     def _swizzle_cta(
+         self, cluster_id_in_problem: Int32, *, loc=None, ip=None
+     ) -> Tuple[Int32, Int32]:
+         # CTA Swizzle to promote L2 data reuse
+         params = self.params
+         group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(cluster_id_in_problem)
+         cid_fast_in_group, cid_slow = Int32(0), Int32(0)
+         if group_id < params.num_groups_regular:
+             cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
+         else:  # tail part
+             cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
+         if group_id % 2 == 1:  # serpentine order
+             ncluster_slow = (
+                 params.problem_shape_ncluster_mnl[1]
+                 if params.raster_order == RasterOrder.AlongM
+                 else params.problem_shape_ncluster_mnl[0]
+             )
+             cid_slow = ncluster_slow - 1 - cid_slow
+         cid_fast = group_id * params.group_size_divmod.divisor + cid_fast_in_group
+         cid_m, cid_n = cid_fast, cid_slow
+         if params.raster_order == RasterOrder.AlongN:
+             cid_m, cid_n = cid_slow, cid_fast
+         return cid_m, cid_n
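+     # Worked example (AlongM, 4 x 3 cluster grid, group_size = 2): ids 0..11 fall
+     # into groups of num_clusters_in_group = 2 * 3 = 6. Group 0 visits columns
+     # n = 0, 1, 2 left to right with m in {0, 1}; group 1 is odd, so it visits
+     # n = 2, 1, 0 in reverse with m in {2, 3}, a serpentine path over the grid.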
+
+     @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+         params = self.params
+         if const_expr(not params.is_persistent):
+             cluster_id_in_problem = self._current_work_linear_idx
+             _, _, bidz = cute.arch.block_idx()
+         else:
+             bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
+                 self._current_work_linear_idx
+             )
+         cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip)
+         # Get the pid from cluster id
+         bidx_in_cluster = cute.arch.block_in_cluster_idx()
+         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
+         batch_idx = (
+             bidz if const_expr(params.batch_idx_permute is None) else params.batch_idx_permute[bidz]
+         )
+         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
+         if const_expr(not params.is_persistent):
+             is_valid = self._num_tiles_executed == 0
+         else:
+             is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
+         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
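+     # Worked example (persistent): with 12 clusters per problem, work index 30
+     # decomposes as divmod(30, 12) = (batch 2, cluster 6); the swizzled cluster
+     # coordinate is then scaled by cluster_shape_mn and offset by this CTA's rank
+     # within its cluster to get the per-CTA tile (pid_m, pid_n).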
+
+     def initial_work_tile_info(self, *, loc=None, ip=None):
+         return self.get_current_work(loc=loc, ip=ip)
+
+     @cute.jit
+     def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
+         """is_scheduler_warp should only be true for one warp in the whole cluster"""
+         params = self.params
+         if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
+             current_work_linear_idx = self._current_work_linear_idx
+             if is_scheduler_warp:
+                 if cute.arch.lane_idx() == 0:
+                     num_persistent_clusters = cute.arch.grid_dim()[2]
+                     current_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32(
+                         cute.size(params.problem_shape_ncluster_mnl) - 1,
+                         params.tile_count_semaphore,
+                     )
+                 # lane 0 already has the right tile_idx, just need to broadcast
+                 current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
+                 self._current_work_linear_idx = current_work_linear_idx
+
+     # We have to split broadcast_next_work and advance_to_next_work into two functions
+     # due to a bug in cute-dsl 4.2: https://github.com/NVIDIA/cutlass/issues/2647
+     @cute.jit
+     def broadcast_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
+         """is_scheduler_warp should only be true for one warp in the whole cluster"""
+         params = self.params
+         if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
+             current_work_linear_idx = self._current_work_linear_idx
+             if is_scheduler_warp:
+                 self._scheduler_pipeline.producer_acquire(self._pipeline_state)
+                 lane_idx = cute.arch.lane_idx()
+                 if lane_idx < cute.size(params.cluster_shape_mn):
+                     # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                     if const_expr(cute.size(params.cluster_shape_mn) == 1):
+                         self._tile_count[self._pipeline_state.index] = current_work_linear_idx
+                         self._scheduler_pipeline.producer_commit(self._pipeline_state)
+                     else:
+                         peer_cta_rank_in_cluster = lane_idx
+                         mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
+                             self._pipeline_state
+                         )
+                         cute.arch.mbarrier_arrive_and_expect_tx(
+                             mbar_ptr, 4, peer_cta_rank_in_cluster
+                         )
+                         utils.store_shared_remote(
+                             val=current_work_linear_idx,
+                             smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
+                             mbar_ptr=mbar_ptr,
+                             peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+                         )
+                 # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
+
+     @cute.jit
+     def advance_to_next_work(
+         self,
+         is_scheduler_warp: bool | Boolean = False,
+         *,
+         advance_count: int = 1,
+         loc=None,
+         ip=None,
+     ):
+         tidx = cute.arch.thread_idx()[0]
+         bidx = cute.arch.block_idx()[0]
+         params = self.params
+         if const_expr(params.is_persistent):
+             num_persistent_clusters = cute.arch.grid_dim()[2]
+             if const_expr(params.tile_count_semaphore is None):  # Static persistent
+                 self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters)
+             else:  # Dynamic persistent
+                 if const_expr(advance_count > 1):
+                     self._pipeline_state.advance_iters(advance_count - 1)
+                 current_work_linear_idx = self._current_work_linear_idx
+                 if not is_scheduler_warp:
+                     # if tidx % 64 == 0: cute.printf("bidx = {},tidx = {}, before full wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                     self._scheduler_pipeline.consumer_wait(self._pipeline_state)
+                     # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after full wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                     current_work_linear_idx = self._tile_count[self._pipeline_state.index]
+                     # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after smem read, idx = {}", bidx, tidx, current_work_linear_idx)
+                     cute.arch.sync_warp()
+                     with cute.arch.elect_one():
+                         # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, before empty arrive", bidx, tidx)
+                         self._scheduler_pipeline.consumer_release(self._pipeline_state)
+                     # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after empty arrive", bidx, tidx)
+                 self._current_work_linear_idx = current_work_linear_idx
+                 self._pipeline_state.advance()
+         self._num_tiles_executed += Int32(advance_count)
+
+     def producer_tail(self):
+         if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None):
+             self._scheduler_pipeline.producer_tail(self._pipeline_state)
+
+     @property
+     def num_tiles_executed(self) -> Int32:
+         return self._num_tiles_executed
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [
+             self._current_work_linear_idx,
+             self._num_tiles_executed,
+             self._tile_count,
+             self._scheduler_pipeline,
+             self._pipeline_state,
+             self.params,
+         ]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip(
+             [
+                 self._current_work_linear_idx,
+                 self._num_tiles_executed,
+                 self._tile_count,
+                 self._scheduler_pipeline,
+                 self._pipeline_state,
+                 self.params,
+             ],
+             self._values_pos,
+         ):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return self.__class__(*(tuple(obj_list)), loc=self._loc)
+
+
+ @cute.jit
+ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
+     """
+     Convert a triangular index to 2D coordinates.
+     This is used to convert the linear index to 2D coordinates for triangular matrices.
+     """
+     row = utils.ceil(utils.sqrt(2 * idx + 2.25) - 0.5) - 1
+     col = idx - (row * (row + 1)) // 2
+     return row, col
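+ # For intuition: row r owns indices [r * (r + 1) // 2, (r + 1) * (r + 2) // 2),
+ # and sqrt(2 * idx + 2.25) = sqrt(8 * idx + 9) / 2, so the ceil expression picks
+ # the largest r with r * (r + 1) // 2 <= idx. E.g. idx 0 -> (0, 0),
+ # idx 1 -> (1, 0), idx 2 -> (1, 1), idx 3 -> (2, 0), idx 5 -> (2, 2).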
+
+
+ class TriangularTileScheduler(TileScheduler):
+     """We assume the tile size per cluster is square (e.g., 128 x 256 per CTA, with cluster 2 x 1)"""
+
+     @dataclass
+     class Params(ParamsBase):
+         problem_shape_ncluster_mnl: cute.Shape
+         num_clusters_per_problem_divmod: FastDivmod
+         group_size_inv_f32: cutlass.Float32
+         num_groups_regular: Int32
+         group_size_divmod: FastDivmod
+         group_size_tail_divmod: FastDivmod
+         group_size_mul_group_size_divmod: FastDivmod
+         group_size_tail_mul_group_size_divmod: FastDivmod
+         tile_count_semaphore: Optional[cute.Pointer]
+         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
+         is_persistent: cutlass.Constexpr[bool]
+
+         @staticmethod
+         @cute.jit
+         def create(
+             args: TileSchedulerArguments, *, loc=None, ip=None
+         ) -> "TriangularTileScheduler.Params":
+             assert args.cluster_shape_mnk[2] == 1
+             cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
+             problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
+             problem_shape_ncluster_mn = cute.ceil_div(problem_shape_ntile_mn, cluster_shape_mn)
+             problem_shape_ncluster_mnl = problem_shape_ncluster_mn + (
+                 args.problem_shape_ntile_mnl[2],
+             )
+             cluster_m = problem_shape_ncluster_mn[0]
+             # Assume that each cluster is responsible for a square tile
+             num_clusters_per_problem = cluster_m * (cluster_m + 1) // 2
+             group_size = min(args.group_size, cluster_m)
+             group_size_tail = cluster_m % group_size
+             num_groups_regular = cluster_m // group_size
+             return TriangularTileScheduler.Params(
+                 problem_shape_ncluster_mnl,
+                 FastDivmod.create(num_clusters_per_problem),
+                 cutlass.Float32(1.0 / group_size),
+                 num_groups_regular,
+                 FastDivmod.create(group_size),
+                 # Don't divide by 0
+                 FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
+                 FastDivmod.create(group_size * group_size),
+                 FastDivmod.create((group_size_tail if group_size_tail > 0 else 1) * group_size),
+                 args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+                 cluster_shape_mn,
+                 args.is_persistent,
+             )
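+         # Worked example: cluster_m = 5 gives num_clusters_per_problem = 15
+         # lower-triangular clusters; with args.group_size = 2 the divisors are
+         # group_size = 2, group_size_tail = 1, num_groups_regular = 2, plus the
+         # products 2 * 2 = 4 and 1 * 2 = 2 consumed by get_current_work below.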
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return TriangularTileScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(
+         params: Params,
+         tile_count: Optional[cute.Tensor] = None,
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
+         is_scheduler_warp: bool | Boolean = False,
+         *,
+         loc=None,
+         ip=None,
+     ) -> "TriangularTileScheduler":
+         stages = 0
+         if const_expr(not params.is_persistent):
+             cluster_id, _, _ = cute.arch.cluster_idx()
+             current_work_linear_idx = Int32(cluster_id)
+         else:
+             _, _, bidz = cute.arch.block_idx()
+             current_work_linear_idx = Int32(bidz)
+         if const_expr(params.tile_count_semaphore is not None):
+             assert tile_count is not None
+             assert scheduler_pipeline is not None
+             stages = const_expr(cute.size(tile_count))
+         return TriangularTileScheduler(
+             current_work_linear_idx,
+             Int32(0),  # num_tiles_executed
+             tile_count,
+             scheduler_pipeline,
+             PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
+             params,
+             loc=loc,
+             ip=ip,
+         )
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         max_active_clusters: Int32,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         clusters = (
+             params.num_clusters_per_problem_divmod.divisor,
+             1,
+             params.problem_shape_ncluster_mnl[2],
+         )
+         num_ctas_mnl = tuple(x * y for x, y in zip(clusters, params.cluster_shape_mn)) + (
+             params.problem_shape_ncluster_mnl[2],
+         )
+         if const_expr(not params.is_persistent):
+             return num_ctas_mnl
+         else:
+             num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
+             num_ctas_per_cluster = cute.size(params.cluster_shape_mn, loc=loc, ip=ip)
+             # Total ctas that can run in one wave
+             num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
+             num_persistent_ctas = cutlass.min(num_ctas_in_problem, num_ctas_per_wave)
+             num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
+             return (*params.cluster_shape_mn, num_persistent_clusters)
+
+     @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+         params = self.params
+         if const_expr(not params.is_persistent):
+             cluster_id_in_problem = self._current_work_linear_idx
+             _, _, bidz = cute.arch.block_idx()
+         else:
+             bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
+                 self._current_work_linear_idx
+             )
+         # CTA Swizzle to promote L2 data reuse
+         group_size = params.group_size_divmod.divisor
+         group_id = (
+             utils.ceil(
+                 (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
+             )
+             - 1
+         )
+         cid_m_start = group_id * group_size
+         id_in_group = cluster_id_in_problem - (cid_m_start * (cid_m_start + 1)) // 2
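+         # E.g. with group_size = 2 and cluster_id_in_problem = 7: group_id =
+         # ceil((sqrt(16.25) - 0.5) * 0.5) - 1 = 1, cid_m_start = 2, and rows 0-1
+         # hold 3 clusters, so id_in_group = 7 - 3 = 4 within the second row-group.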
+         group_size_actual = (
+             group_size
+             if group_id < params.num_groups_regular
+             else params.group_size_tail_divmod.divisor
+         )
+         group_col, group_remainder = Int32(0), Int32(0)
+         if group_id < params.num_groups_regular:
+             group_col, group_remainder = params.group_size_mul_group_size_divmod.divmod(id_in_group)
+         else:  # tail part
+             group_col, group_remainder = params.group_size_tail_mul_group_size_divmod.divmod(
+                 id_in_group
+             )
+         cid_m_in_group, cid_n_in_group = Int32(0), Int32(0)
+         if id_in_group >= group_size_actual * group_size * group_id:  # triangular tail
+             cid_m_in_group, cid_n_in_group = triangular_idx_to_coord(group_remainder)
+         else:
+             if group_id < params.num_groups_regular:
+                 cid_n_in_group, cid_m_in_group = params.group_size_divmod.divmod(group_remainder)
+             else:
+                 cid_n_in_group, cid_m_in_group = params.group_size_tail_divmod.divmod(
+                     group_remainder
+                 )
+         cid_m = cid_m_start + cid_m_in_group
+         cid_n = group_col * group_size + cid_n_in_group
+
+         # Get the pid from cluster id
+         bidx_in_cluster = cute.arch.block_in_cluster_idx()
+         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
+         tile_coord_mnkl = (pid_m, pid_n, None, bidz)
+         if const_expr(not params.is_persistent):
+             is_valid = self._num_tiles_executed == 0
+         else:
+             is_valid = (
+                 self._current_work_linear_idx
+                 < params.num_clusters_per_problem_divmod.divisor
+                 * params.problem_shape_ncluster_mnl[2]
+             )
+         # bidx, bidy, bidz = cute.arch.block_idx()
+         # tidx, _, _ = cute.arch.thread_idx()
+         # if tidx == 0:
+         #     cute.printf("bidx = {}, bidy = {}, group_id = {}, id_in_group = {}, group_size_actual = {}, group_col = {}, group_remainder = {}, cid_n_in_group = {}, cid_m_in_group = {}, cid_m = {}, cid_n = {}, is_valid = {}",
+         #                 bidx, bidy, group_id, id_in_group, group_size_actual, group_col, group_remainder, cid_n_in_group, cid_m_in_group, cid_m, cid_n, is_valid)
+         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
+
+
+ @dataclass
+ class VarlenMTileSchedulerArguments(ParamsBase):
+     problem_shape_ntile_mnl: cute.Shape
+     total_m: Int32
+     cu_seqlens_m: cute.Tensor
+     raster_order: cutlass.Constexpr[RasterOrderOption]
+     group_size: Int32
+     tile_shape_mn: cutlass.Constexpr[cute.Shape]
+     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
+     tile_count_semaphore: Optional[cute.Pointer] = None
+     is_persistent: cutlass.Constexpr[bool] = False
+
+
+ class VarlenMTileScheduler(TileScheduler):
+     @dataclass
+     class Params(ParamsBase):
+         problem_shape_ncluster_mnl: cute.Shape
+         total_m: Int32
+         cu_seqlens_m: cute.Tensor
+         raster_order: cutlass.Constexpr[RasterOrder]
+         group_size: Int32
+         group_size_divmod: Optional[FastDivmod]
+         group_size_tail_divmod: Optional[FastDivmod]
+         num_clusters_in_group_divmod: FastDivmod
+         tile_shape_mn: cutlass.Constexpr[cute.Shape]
+         tile_count_semaphore: Optional[cute.Pointer]
+         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
+         is_persistent: cutlass.Constexpr[bool]
+
+         @staticmethod
+         @cute.jit
+         def create(
+             args: TileSchedulerArguments, *, loc=None, ip=None
+         ) -> "VarlenMTileScheduler.Params":
+             assert args.cluster_shape_mnk[2] == 1
+             cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
+             # problem_shape_ntile_mnl[0] will be None for VarlenM
+             problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
+             problem_shape_ncluster_mn = (
+                 None,
+                 cute.ceil_div(problem_shape_ntile_mn[1], cluster_shape_mn[1]),
+             )
+             problem_shape_ncluster_mnl = problem_shape_ncluster_mn + (
+                 args.problem_shape_ntile_mnl[2],
+             )
+             raster_order = const_expr(
+                 RasterOrder.AlongM
+                 if args.raster_order == RasterOrderOption.AlongM
+                 else RasterOrder.AlongN  # For Heuristic we also use AlongN
+             )
+             ncluster_fast = (
+                 problem_shape_ncluster_mn[0]
+                 if raster_order == RasterOrder.AlongM
+                 else problem_shape_ncluster_mn[1]
+             )
+             ncluster_slow = (
+                 problem_shape_ncluster_mn[1]
+                 if raster_order == RasterOrder.AlongM
+                 else problem_shape_ncluster_mn[0]
+             )
+             if const_expr(ncluster_fast is not None):
+                 group_size = min(args.group_size, ncluster_fast)
+                 group_size_tail = ncluster_fast % group_size
+             else:
+                 group_size, group_size_tail = args.group_size, None
+             if const_expr(ncluster_slow is not None):
+                 num_clusters_in_group = group_size * ncluster_slow
+             else:
+                 num_clusters_in_group = None
+             return VarlenMTileScheduler.Params(
+                 problem_shape_ncluster_mnl,
+                 args.total_m,
+                 args.cu_seqlens_m,
+                 raster_order,
+                 group_size,
+                 FastDivmod.create(group_size) if ncluster_fast is not None else None,
+                 # Don't divide by 0
+                 FastDivmod.create(group_size_tail if group_size_tail > 0 else 1)
+                 if group_size_tail is not None
+                 else None,
+                 FastDivmod.create(num_clusters_in_group)
+                 if num_clusters_in_group is not None
+                 else None,
+                 args.tile_shape_mn,
+                 args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+                 cluster_shape_mn,
+                 args.is_persistent,
+             )
+
+     def __init__(
+         self,
+         current_work_linear_idx: Int32,
+         num_tiles_executed: Int32,
+         current_batch_idx: Int32,
+         num_work_idx_before_cur_batch: Int32,
+         tile_count: Optional[cute.Tensor],
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
+         pipeline_state: PipelineStateWAdvance,
+         params: Params,
+         *,
+         loc=None,
+         ip=None,
+     ):
+         self._current_work_linear_idx = current_work_linear_idx
+         self._num_tiles_executed = num_tiles_executed
+         self._current_batch_idx = current_batch_idx
+         self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
+         self._tile_count = tile_count
+         self._scheduler_pipeline = scheduler_pipeline
+         self._pipeline_state = pipeline_state
+         self.params = params
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return VarlenMTileScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(
+         params: Params,
+         tile_count: Optional[cute.Tensor] = None,
+         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
+         is_scheduler_warp: bool | Boolean = False,
+         *,
+         loc=None,
+         ip=None,
+     ) -> "VarlenMTileScheduler":
+         stages = 0
+         _, _, bidz = cute.arch.block_idx()
+         current_work_linear_idx = Int32(bidz)
+         if const_expr(params.tile_count_semaphore is not None):
+             assert tile_count is not None
+             assert scheduler_pipeline is not None
+             stages = const_expr(cute.size(tile_count))
+         return VarlenMTileScheduler(
+             current_work_linear_idx,
+             Int32(0),  # num_tiles_executed
+             Int32(0),  # current_batch_idx
+             Int32(0),  # num_work_idx_before_cur_batch
+             tile_count,
+             scheduler_pipeline,
+             PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
+             params,
+             loc=loc,
+             ip=ip,
+         )
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         max_active_clusters: Int32,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         block_size = params.tile_shape_mn[0] * params.cluster_shape_mn[0]
+         num_batch = params.problem_shape_ncluster_mnl[2]
+         total_clusters_m_max = (params.total_m + num_batch * (block_size - 1)) // block_size
+         total_clusters_max = total_clusters_m_max * params.problem_shape_ncluster_mnl[1]
+         if const_expr(not params.is_persistent):
+             return (*params.cluster_shape_mn, total_clusters_max)
+         else:
+             num_persistent_clusters = cutlass.min(max_active_clusters, total_clusters_max)
+             return (*params.cluster_shape_mn, num_persistent_clusters)
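+     # Worked example: total_m = 1000 rows over num_batch = 3 sequences with
+     # block_size = 128 yields at most (1000 + 3 * 127) // 128 = 10 m-blocks,
+     # since each batch can waste at most block_size - 1 rows to rounding; the
+     # grid is then sized for total_clusters_max = 10 * ncluster_n tiles.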
+
+     @cute.jit
+     def _get_num_m_blocks(
+         self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int]
+     ) -> Int32:
+         num_batch = self.params.problem_shape_ncluster_mnl[2]
+         batch_idx = lane + bidb_start
+         cur_cu_seqlen = Int32(0)
+         if batch_idx <= num_batch:
+             cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx]
+         next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
+         seqlen = next_cu_seqlen - cur_cu_seqlen
+         return (
+             cute.ceil_div(seqlen, block_size)
+             if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1
+             else Int32(0)
+         )
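+     # For intuition: lane i loads cu_seqlens_m[bidb_start + i], and a shuffle-down
+     # by one lane supplies each batch's next offset without a second gather. E.g.
+     # with cu_seqlens_m = [0, 300, 420, 1000] and block_size = 128, lanes 0-2 see
+     # seqlens 300, 120, 580 and return 3, 1, 5 m-blocks; the last lane returns 0
+     # because it has no neighbor to difference against.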
+
+     @cute.jit
+     def _swizzle_cta(
+         self, cluster_id_in_problem: Int32, num_clusters_m: Int32, *, loc=None, ip=None
+     ) -> Tuple[Int32, Int32]:
+         params = self.params
+         # CTA Swizzle to promote L2 data reuse
+         if const_expr(params.num_clusters_in_group_divmod is not None):
+             group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(
+                 cluster_id_in_problem
+             )
+             num_clusters_in_group = params.num_clusters_in_group_divmod.divisor
+         else:
+             assert params.raster_order == RasterOrder.AlongN
+             num_clusters_in_group = params.group_size * num_clusters_m
+             group_id = cluster_id_in_problem // num_clusters_in_group
+             id_in_group = cluster_id_in_problem - group_id * num_clusters_in_group
+         cid_fast_in_group, cid_slow = Int32(0), Int32(0)
+         if const_expr(
+             params.group_size_divmod is not None and params.group_size_tail_divmod is not None
+         ):
+             num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
+             if (group_id + 1) * num_clusters_in_group <= num_clusters:
+                 cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
+             else:  # tail part
+                 cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
+         else:
+             assert params.raster_order == RasterOrder.AlongM
+             group_size_actual = cutlass.min(
+                 params.group_size, num_clusters_m - group_id * params.group_size
+             )
+             cid_slow = id_in_group // group_size_actual
+             cid_fast_in_group = id_in_group - cid_slow * group_size_actual
+         if group_id % 2 == 1:  # serpentine order
+             ncluster_slow = (
+                 params.problem_shape_ncluster_mnl[1]
+                 if params.raster_order == RasterOrder.AlongM
+                 else num_clusters_m
+             )
+             cid_slow = ncluster_slow - 1 - cid_slow
+         cid_fast = group_id * params.group_size + cid_fast_in_group
+         cid_m, cid_n = cid_fast, cid_slow
+         if params.raster_order == RasterOrder.AlongN:
+             cid_m, cid_n = cid_slow, cid_fast
+         return cid_m, cid_n
+
+     @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+         params = self.params
+         lane_idx = cute.arch.lane_idx()
+         num_batch = self.params.problem_shape_ncluster_mnl[2]
+         block_size = params.tile_shape_mn[0] * params.cluster_shape_mn[0]
+         batch_idx = self._current_batch_idx
+         num_clusters_m = self._get_num_m_blocks(
+             lane_idx, bidb_start=batch_idx, block_size=block_size
+         )
+         num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
+         num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx)
+         # Total number of blocks for the next 31 problems, same for all lanes
+         clusters_in_problems = cute.arch.shuffle_sync(
+             num_clusters_cumulative, cute.arch.WARP_SIZE - 1
+         )
+         problems_end_tile = self._num_work_idx_before_cur_batch + clusters_in_problems
+         # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, num_clusters_cumulative = %d, problems_end_tile = %d", self._tile_idx, problems_end_tile, num_clusters_m, num_clusters_cumulative, problems_end_tile)
+         cid_m, cid_n = Int32(0), Int32(0)
+         next_tile_idx = self._current_work_linear_idx
+         while problems_end_tile <= next_tile_idx:
+             batch_idx += cute.arch.WARP_SIZE - 1
+             if batch_idx >= num_batch:
+                 batch_idx = Int32(num_batch)
+                 problems_end_tile = next_tile_idx + 1
+             else:
+                 num_clusters_m = self._get_num_m_blocks(
+                     lane_idx, bidb_start=batch_idx, block_size=block_size
+                 )
+                 num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
+                 num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx)
+                 clusters_in_problems = cute.arch.shuffle_sync(
+                     num_clusters_cumulative, cute.arch.WARP_SIZE - 1
+                 )
+                 problems_end_tile += clusters_in_problems
+         # Just a placeholder value in case batch_idx >= num_batch
+         num_work_idx_before_cur_batch = problems_end_tile - clusters_in_problems
+         if batch_idx >= num_batch:
+             cid_m, cid_n, batch_idx = Int32(0), Int32(0), Int32(num_batch)
+         else:
+             problems_start_tile = problems_end_tile - clusters_in_problems
+             # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, batch_idx = %d", self._tile_idx, problems_end_tile, num_clusters_m, batch_idx)
+             # The next problem to process is the first one whose ending tile
+             # position is greater than the tile index.
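+             # E.g. if problems_start_tile is 0 and the lanes' cumulative end tiles
+             # are 3, 7, 12, ..., then for next_tile_idx = 8 the ballot sets bits
+             # for the problems ending at 3 and 7, popc gives 2, and the search
+             # lands on the third problem in this 31-problem window.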
+             batch_idx_in_problems = cute.arch.popc(
+                 cute.arch.vote_ballot_sync(
+                     problems_start_tile + num_clusters_cumulative <= next_tile_idx
+                 )
+             )
+             batch_idx += batch_idx_in_problems
+             num_clusters_prev_lane = (
+                 0
+                 if batch_idx_in_problems == 0
+                 else cute.arch.shuffle_sync(num_clusters_cumulative, batch_idx_in_problems - 1)
+             )
+             num_clusters_m = cute.arch.shuffle_sync(num_clusters_m, batch_idx_in_problems)
+             num_work_idx_before_cur_batch = problems_start_tile + num_clusters_prev_lane
+             cluster_id_in_problem = next_tile_idx - num_work_idx_before_cur_batch
+             # cid_n = cluster_id_in_problem // num_clusters_m
+             # cid_m = cluster_id_in_problem - cid_n * num_clusters_m
+             # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, cid_n=%d, cid_m=%d, is_valid = %d", self._tile_idx, batch_idx, cid_n, cid_m, is_valid)
+             cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, num_clusters_m, loc=loc, ip=ip)
+         self._current_batch_idx = batch_idx
+         self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
+
+         # Get the pid from cluster id
+         bidx_in_cluster = cute.arch.block_in_cluster_idx()
+         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
+         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
+         if const_expr(not params.is_persistent):
+             is_valid = self._num_tiles_executed == 0 and batch_idx < num_batch
+         else:
+             is_valid = batch_idx < num_batch
+         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
+
+     @cute.jit
+     def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
+         """is_scheduler_warp should only be true for one warp in the whole cluster"""
+         if const_expr(self.params.tile_count_semaphore is not None):
+             params = self.params
+             current_work_linear_idx = self._current_work_linear_idx
+             if is_scheduler_warp:
+                 if cute.arch.lane_idx() == 0:
+                     # cute.printf("before atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                     num_persistent_clusters = cute.arch.grid_dim()[2]
+                     current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
+                         1, params.tile_count_semaphore
+                     )
+                     # cute.printf("after atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                 # lane 0 already has the right tile_idx, just need to broadcast
+                 current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
+                 self._current_work_linear_idx = current_work_linear_idx
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [
+             self._current_work_linear_idx,
+             self._num_tiles_executed,
+             self._current_batch_idx,
+             self._num_work_idx_before_cur_batch,
+             self._tile_count,
+             self._scheduler_pipeline,
+             self._pipeline_state,
+             self.params,
+         ]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip(
+             [
+                 self._current_work_linear_idx,
+                 self._num_tiles_executed,
+                 self._current_batch_idx,
+                 self._num_work_idx_before_cur_batch,
+                 self._tile_count,
+                 self._scheduler_pipeline,
+                 self._pipeline_state,
+                 self.params,
+             ],
+             self._values_pos,
+         ):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return self.__class__(*(tuple(obj_list)), loc=self._loc)