quack-kernels 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

quack/tile_scheduler.py CHANGED
@@ -25,6 +25,13 @@ class RasterOrder(IntEnum):
     AlongN = 1
 
 
+class PersistenceMode(IntEnum):
+    NONE = 0
+    STATIC = 1
+    DYNAMIC = 2
+    CLC = 3
+
+
 @cute.jit
 def get_raster_order_from_option(
     raster_order_option: RasterOrderOption, problem_shape_ncluster_mn: cute.Shape, group_size: Int32
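
The new PersistenceMode enum replaces the boolean is_persistent flag that the later hunks delete. Judging from how the modes are used below, NONE launches one cluster per output tile, STATIC strides a fixed persistent grid over the tile space, DYNAMIC hands out tiles through a global atomic counter (and is the only mode that requires tile_count_semaphore), and CLC appears reserved for hardware cluster-launch-control scheduling. A hypothetical caller-side sketch, not package code:

```python
from enum import IntEnum

class PersistenceMode(IntEnum):
    NONE = 0     # one cluster per tile, no persistence
    STATIC = 1   # persistent grid, tile index strided by grid size
    DYNAMIC = 2  # persistent grid, tile index from a global atomic counter
    CLC = 3      # reserved for hardware cluster-launch control

def needs_tile_count_semaphore(mode: PersistenceMode) -> bool:
    # Mirrors the asserts added in the Params.create hunks below:
    # only DYNAMIC consumes the global counter.
    return mode == PersistenceMode.DYNAMIC
```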
@@ -61,7 +68,7 @@ class TileSchedulerArguments(ArgumentsBase):
     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
     tile_count_semaphore: Optional[cute.Pointer] = None
     batch_idx_permute: Optional[cute.Tensor] = None
-    is_persistent: cutlass.Constexpr[bool] = False
+    persistence_mode: cutlass.Constexpr[PersistenceMode] = PersistenceMode.NONE
 
 
 class TileScheduler:
@@ -69,15 +76,15 @@ class TileScheduler:
     class Params(ParamsBase):
         problem_shape_ncluster_mnl: cute.Shape
         raster_order: RasterOrder
-        num_clusters_per_problem_divmod: FastDivmod
+        num_clusters_per_problem_fdd: FastDivmod
         num_groups_regular: Int32
-        group_size_divmod: FastDivmod
-        group_size_tail_divmod: FastDivmod
-        num_clusters_in_group_divmod: FastDivmod
+        group_size_fdd: FastDivmod
+        group_size_tail_fdd: FastDivmod
+        num_clusters_in_group_fdd: FastDivmod
         tile_count_semaphore: Optional[cute.Pointer]
         batch_idx_permute: Optional[cute.Tensor]
         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
-        is_persistent: cutlass.Constexpr[bool]
+        persistence_mode: cutlass.Constexpr[PersistenceMode]
 
         @staticmethod
         @cute.jit
@@ -107,26 +114,30 @@ class TileScheduler:
             group_size_tail = ncluster_fast % group_size
             num_groups_regular = ncluster_fast // group_size
             num_clusters_in_group = group_size * ncluster_slow
+            if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC):
+                assert args.tile_count_semaphore is not None
             return TileScheduler.Params(
                 problem_shape_ncluster_mnl,
                 raster_order,
-                FastDivmod.create(num_clusters_per_problem),
+                FastDivmod(num_clusters_per_problem),
                 num_groups_regular,
-                FastDivmod.create(group_size),
+                FastDivmod(group_size),
                 # Don't divide by 0
-                FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
-                FastDivmod.create(num_clusters_in_group),
-                args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+                FastDivmod(group_size_tail if group_size_tail > 0 else 1),
+                FastDivmod(num_clusters_in_group),
+                args.tile_count_semaphore
+                if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC)
+                else None,
                 args.batch_idx_permute,
                 cluster_shape_mn,
-                args.is_persistent,
+                args.persistence_mode,
             )
 
     def __init__(
         self,
         current_work_linear_idx: Int32,
         num_tiles_executed: Int32,
-        tile_count: Optional[cute.Tensor],
+        sched_smem: Optional[cute.Tensor],
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
         pipeline_state: PipelineStateWAdvance,
         params: Params,
@@ -136,7 +147,7 @@ class TileScheduler:
     ):
         self._current_work_linear_idx = current_work_linear_idx
         self.num_tiles_executed = num_tiles_executed
-        self._tile_count = tile_count
+        self._sched_smem = sched_smem
         self._scheduler_pipeline = scheduler_pipeline
         self._pipeline_state = pipeline_state
         self.params = params
@@ -151,16 +162,14 @@ class TileScheduler:
     @cute.jit
     def create(
         params: Params,
-        tile_count: Optional[cute.Tensor] = None,
+        sched_smem: Optional[cute.Tensor] = None,
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
-        is_scheduler_warp: bool | Boolean = False,
         *,
         loc=None,
         ip=None,
     ) -> "TileScheduler":
         """is_scheduler_warp should only be true for one warp in the whole cluster"""
-        stages = 0
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             cidx, cidy, _ = cute.arch.cluster_idx()
             cdimx, _, _ = cute.arch.cluster_dim()
             cluster_id = cidx + cidy * cdimx
@@ -168,16 +177,20 @@ class TileScheduler:
         else:
             _, _, bidz = cute.arch.block_idx()
             current_work_linear_idx = Int32(bidz)
-        if const_expr(params.tile_count_semaphore is not None):
-            assert tile_count is not None
-            assert scheduler_pipeline is not None
-            stages = const_expr(cute.size(tile_count))
+        stages = 0
+        if const_expr(
+            params.persistence_mode
+            in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC, PersistenceMode.CLC]
+        ):
+            assert sched_smem is not None
+            assert scheduler_pipeline is not None
+            stages = const_expr(cute.size(sched_smem, mode=[1]))
         return TileScheduler(
             current_work_linear_idx,
             Int32(0),  # num_tiles_executed
-            tile_count,
+            sched_smem,
             scheduler_pipeline,
-            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
+            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)),
             params,
             loc=loc,
             ip=ip,
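
Note the new stages = const_expr(cute.size(sched_smem, mode=[1])): the scheduler's shared-memory buffer is no longer a 1-D ring of tile counts but, as the reads in get_current_work and the writes in write_work_tile_to_smem below suggest, a (4, stages) ring of fully resolved work-tile records. A NumPy model of that assumed layout, illustrative only:

```python
import numpy as np

stages = 2  # ring depth; the second mode of the smem tensor
# rows: pid_m, pid_n, batch_idx, is_valid
sched_smem = np.zeros((4, stages), dtype=np.int32)

def write_slot(slot: int, pid_m: int, pid_n: int, batch_idx: int, is_valid: bool) -> None:
    # what write_work_tile_to_smem publishes into one pipeline stage
    sched_smem[:, slot] = (pid_m, pid_n, batch_idx, int(is_valid))

def read_slot(slot: int):
    # what get_current_work unpacks on the consumer side
    pid_m, pid_n, batch_idx, is_valid = sched_smem[:, slot]
    return (pid_m, pid_n, batch_idx), bool(is_valid)
```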
@@ -195,7 +208,7 @@ class TileScheduler:
         num_ctas_mnl = tuple(
             x * y for x, y in zip(params.problem_shape_ncluster_mnl, params.cluster_shape_mn)
         ) + (params.problem_shape_ncluster_mnl[2],)
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             return num_ctas_mnl
         else:
             num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
@@ -212,12 +225,12 @@ class TileScheduler:
     ) -> Tuple[Int32, Int32]:
         # CTA Swizzle to promote L2 data reuse
         params = self.params
-        group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(cluster_id_in_problem)
+        group_id, id_in_group = divmod(cluster_id_in_problem, params.num_clusters_in_group_fdd)
         cid_fast_in_group, cid_slow = Int32(0), Int32(0)
         if group_id < params.num_groups_regular:
-            cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
+            cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_fdd)
         else:  # tail part
-            cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
+            cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_tail_fdd)
         if group_id % 2 == 1:  # serpentine order
             ncluster_slow = (
                 params.problem_shape_ncluster_mnl[1]
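
This release also switches every call site from params.x_divmod.divmod(v) to the builtin divmod(v, params.x_fdd), and construction from FastDivmod.create(d) to FastDivmod(d). That only works if FastDivmod now exposes a plain constructor and the reflected-divmod protocol; a plain-Python sketch of the assumed interface (the real class lives in the CuTe DSL utilities and precomputes a magic-multiplier division):

```python
class FastDivmod:
    def __init__(self, divisor: int):
        assert divisor > 0, "divisor must be positive"
        self.divisor = divisor
        # The real implementation precomputes a magic multiplier and shift here
        # so the GPU divides with a multiply-high instead of an integer divide.

    def __rdivmod__(self, value: int):
        # Invoked by the builtin divmod(value, self); returns (quotient, remainder).
        return value // self.divisor, value % self.divisor

quotient, remainder = divmod(37, FastDivmod(8))
assert (quotient, remainder) == (4, 5)
```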
@@ -225,56 +238,151 @@ class TileScheduler:
                 else params.problem_shape_ncluster_mnl[0]
             )
             cid_slow = ncluster_slow - 1 - cid_slow
-        cid_fast = group_id * params.group_size_divmod.divisor + cid_fast_in_group
+        cid_fast = group_id * params.group_size_fdd.divisor + cid_fast_in_group
         cid_m, cid_n = cid_fast, cid_slow
         if params.raster_order == RasterOrder.AlongN:
             cid_m, cid_n = cid_slow, cid_fast
         return cid_m, cid_n
 
     @cute.jit
-    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+    def _delinearize_work_idx(
+        self, *, block_zero_only: bool = False, loc=None, ip=None
+    ) -> cutlass.utils.WorkTileInfo:
         params = self.params
-        if const_expr(not params.is_persistent):
-            cluster_id_in_problem = self._current_work_linear_idx
-            _, _, bidz = cute.arch.block_idx()
-        else:
-            bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
-                self._current_work_linear_idx
-            )
-        cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip)
-        # Get the pid from cluster id
-        bidx_in_cluster = cute.arch.block_in_cluster_idx()
-        pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
-        pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
-        batch_idx = (
-            bidz if const_expr(params.batch_idx_permute is None) else params.batch_idx_permute[bidz]
-        )
-        tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode == PersistenceMode.NONE):
             is_valid = self.num_tiles_executed == 0
         else:
             is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
+        pid_m, pid_n, batch_idx = Int32(0), Int32(0), Int32(0)
+        if is_valid:
+            if const_expr(params.persistence_mode == PersistenceMode.NONE):
+                cluster_id_in_problem = self._current_work_linear_idx
+                _, _, bidz = cute.arch.block_idx()
+            else:
+                bidz, cluster_id_in_problem = divmod(
+                    self._current_work_linear_idx, params.num_clusters_per_problem_fdd
+                )
+            cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip)
+            if const_expr(block_zero_only):
+                bidx_in_cluster = (Int32(0), Int32(0))
+            else:
+                # Get the pid from cluster id
+                bidx_in_cluster = cute.arch.block_in_cluster_idx()
+            pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+            pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
+            batch_idx = (
+                bidz
+                if const_expr(params.batch_idx_permute is None)
+                else params.batch_idx_permute[bidz]
+            )
+        tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
 
+    @cute.jit
+    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+        params = self.params
+        pid_m, pid_n, batch_idx, is_valid = Int32(0), Int32(0), Int32(0), Boolean(False)
+        if const_expr(params.persistence_mode == PersistenceMode.NONE):
+            pass
+        # elif const_expr(params.persistence_mode == PersistenceMode.STATIC):
+        #     return self._delinearize_work_idx(loc=loc, ip=ip)
+        else:
+            self._scheduler_pipeline.consumer_wait(self._pipeline_state)
+            pid_m, pid_n, batch_idx, is_valid_i32 = [
+                self._sched_smem[i, self._pipeline_state.index] for i in range(4)
+            ]
+            # Need this fence since the STAS from the producer is using the async proxy.
+            # Without this, we get race condition / deadlock.
+            if const_expr(cute.size(params.cluster_shape_mn) > 1):
+                cute.arch.fence_view_async_shared()
+            cute.arch.sync_warp()
+            with cute.arch.elect_one():
+                self._scheduler_pipeline.consumer_release(self._pipeline_state)
+            self._pipeline_state.advance()
+            is_valid = Boolean(is_valid_i32)
+        tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
+        return cutlass.utils.WorkTileInfo(tile_coord_mnkl, Boolean(is_valid))
+
+    # @cute.jit
     def initial_work_tile_info(self, *, loc=None, ip=None):
-        return self.get_current_work(loc=loc, ip=ip)
+        return self._delinearize_work_idx(loc=loc, ip=ip)
+        # if is_scheduler_warp:
+        #     work_tile_info = self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip)
+        #     self.write_work_tile_to_smem(work_tile_info, loc=loc, ip=ip)
+        # self.write_work_tile_to_smem(self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip), loc=loc, ip=ip)
 
     @cute.jit
-    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
+    def _fetch_next_work_idx(self, *, loc=None, ip=None) -> Int32:
         """is_scheduler_warp should only be true for one warp in the whole cluster"""
         params = self.params
-        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
-            current_work_linear_idx = self._current_work_linear_idx
-            if is_scheduler_warp:
-                if cute.arch.lane_idx() == 0:
-                    num_persistent_clusters = cute.arch.grid_dim()[2]
-                    current_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32(
+        num_persistent_clusters = Int32(cute.arch.grid_dim()[2])
+        if const_expr(params.persistence_mode == PersistenceMode.STATIC):
+            return self._current_work_linear_idx + num_persistent_clusters
+        elif const_expr(params.persistence_mode == PersistenceMode.DYNAMIC):
+            next_work_linear_idx = Int32(0)
+            if cute.arch.lane_idx() == 0:
+                # If varlen_m, problem_shape_ncluster_mnl[0] is None, so we use atomic_add
+                # instead of atomic_inc, and at the end of the kernel must reset the semaphore to 0.
+                # # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
+                if const_expr(params.problem_shape_ncluster_mnl[0] is not None):
+                    next_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32(
                         cute.size(params.problem_shape_ncluster_mnl) - 1,
                         params.tile_count_semaphore,
                     )
-                # lane 0 already has the right tile_idx, just need to broadcast
-                current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
-            self._current_work_linear_idx = current_work_linear_idx
+                else:  # varlen_m
+                    next_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
+                        1, params.tile_count_semaphore
+                    )
+                # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
+            return cute.arch.shuffle_sync(next_work_linear_idx, 0)
+        else:
+            return Int32(0)
+
+    @cute.jit
+    def write_work_tile_to_smem(
+        self, work_tile_info: cutlass.utils.WorkTileInfo, *, loc=None, ip=None
+    ):
+        params = self.params
+        if const_expr(self._sched_smem is not None):
+            # producer phase is always consumer_phase ^ 1
+            pipeline_state_producer = PipelineStateWAdvance(
+                self._pipeline_state.stages,
+                self._pipeline_state.count,
+                self._pipeline_state.index,
+                self._pipeline_state.phase ^ 1,
+            )
+            self._scheduler_pipeline.producer_acquire(pipeline_state_producer)
+            sched_data = [
+                work_tile_info.tile_idx[0],
+                work_tile_info.tile_idx[1],
+                work_tile_info.tile_idx[3],
+                Int32(work_tile_info.is_valid_tile),
+            ]
+            lane_idx = cute.arch.lane_idx()
+            if lane_idx < cute.size(params.cluster_shape_mn):
+                # cute.printf("Producer pid_m = {}, pid_n = {}, batch_idx = {}, is_valid = {}, after empty wait, idx = {}", sched_data[0], sched_data[1], sched_data[2], sched_data[3], self._current_work_linear_idx)
+                pipeline_idx = self._pipeline_state.index
+                if const_expr(cute.size(params.cluster_shape_mn) == 1):
+                    for i in cutlass.range_constexpr(4):
+                        self._sched_smem[i, pipeline_idx] = sched_data[i]
+                    self._scheduler_pipeline.producer_commit(self._pipeline_state)
+                else:
+                    peer_cta_rank_in_cluster = lane_idx
+                    # Here we assume that the block idx in cluster is linearized such that
+                    # x is the fastest moving direction.
+                    bidx_in_cluster = peer_cta_rank_in_cluster % params.cluster_shape_mn[0]
+                    bidy_in_cluster = peer_cta_rank_in_cluster // params.cluster_shape_mn[0]
+                    mbar_ptr = self._scheduler_pipeline.producer_get_barrier(self._pipeline_state)
+                    cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, 16, peer_cta_rank_in_cluster)
+                    utils.store_shared_remote_x4(
+                        sched_data[0] + bidx_in_cluster,
+                        sched_data[1] + bidy_in_cluster,
+                        sched_data[2],
+                        sched_data[3],
+                        smem_ptr=self._sched_smem[None, pipeline_idx].iterator,
+                        mbar_ptr=mbar_ptr,
+                        peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+                    )
 
     @cute.jit
     def advance_to_next_work(
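
The rewritten _fetch_next_work_idx splits the two persistent strategies: STATIC simply strides the current index by the grid size, while DYNAMIC has lane 0 bump a device-global counter and broadcast the result across the warp. The varlen_m comment matters because CUDA's atomicInc wraps the counter back to zero at its bound (so the semaphore self-resets when the tile count is known statically), whereas atomicAdd does not, forcing the kernel to reset the semaphore at the end. A single-threaded scalar model of the two hand-out paths, illustrative only:

```python
num_persistent_clusters = 4  # stand-in for cute.arch.grid_dim()[2]
semaphore = 0                # stand-in for the device-global tile_count_semaphore

def atomic_add_i32(val: int) -> int:
    global semaphore
    old, semaphore = semaphore, semaphore + val
    return old

def atomic_inc_i32(bound: int) -> int:
    # CUDA atomicInc semantics: return the old value, wrap to 0 past `bound`
    global semaphore
    old = semaphore
    semaphore = 0 if old >= bound else old + 1
    return old

def fetch_next_work_idx(known_tile_count: bool, num_tiles: int) -> int:
    # The first sweep of tiles [0, num_persistent_clusters) is assigned
    # statically by blockIdx.z, so dynamically fetched indices start after it.
    old = atomic_inc_i32(num_tiles - 1) if known_tile_count else atomic_add_i32(1)
    return num_persistent_clusters + old
```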
@@ -285,73 +393,37 @@ class TileScheduler:
         loc=None,
         ip=None,
     ):
-        tidx = cute.arch.thread_idx()[0]
-        bidx = cute.arch.block_idx()[0]
-        bidz = cute.arch.block_idx()[2]
+        """is_scheduler_warp should only be true for one warp in the whole cluster.
+        Moreover, we assume that only block zero in the cluster is calling this function.
+        If calling with is_scheduler_warp = True, advance_count must be 1.
+        """
         params = self.params
-        if const_expr(params.is_persistent):
-            num_persistent_clusters = cute.arch.grid_dim()[2]
-            if const_expr(params.tile_count_semaphore is None):  # Static persistent
-                self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters)
-            else:  # Dynamic persistent
-                if const_expr(advance_count > 1):
-                    self._pipeline_state.advance_iters(advance_count - 1)
-                current_work_linear_idx = self._current_work_linear_idx
-                if is_scheduler_warp:
-                    self._scheduler_pipeline.producer_acquire(self._pipeline_state)
-                    lane_idx = cute.arch.lane_idx()
-                    if lane_idx < cute.size(params.cluster_shape_mn):
-                        # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after empty wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
-                        if const_expr(cute.size(params.cluster_shape_mn) == 1):
-                            self._tile_count[self._pipeline_state.index] = current_work_linear_idx
-                            self._scheduler_pipeline.producer_commit(self._pipeline_state)
-                        else:
-                            peer_cta_rank_in_cluster = lane_idx
-                            mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
-                                self._pipeline_state
-                            )
-                            cute.arch.mbarrier_arrive_and_expect_tx(
-                                mbar_ptr, 4, peer_cta_rank_in_cluster
-                            )
-                            utils.store_shared_remote(
-                                val=current_work_linear_idx,
-                                smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
-                                mbar_ptr=mbar_ptr,
-                                peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
-                            )
-                    # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after full arrive", bidx, bidz, tidx)
-                else:
-                    # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
-                    self._scheduler_pipeline.consumer_wait(self._pipeline_state)
-                    # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
-                    current_work_linear_idx = self._tile_count[self._pipeline_state.index]
-                    # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after smem read, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
-                    # Need this fence since the STAS from the producer is using the async proxy.
-                    # Without this, we get race condition / deadlock.
-                    if const_expr(cute.size(params.cluster_shape_mn) > 1):
-                        cute.arch.fence_proxy(
-                            cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-                        )
-                    cute.arch.sync_warp()
-                    with cute.arch.elect_one():
-                        # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before empty arrive", bidx, bidz, tidx)
-                        self._scheduler_pipeline.consumer_release(self._pipeline_state)
-                    # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
-                    # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
-                self._current_work_linear_idx = current_work_linear_idx
-                self._pipeline_state.advance()
         self.num_tiles_executed += Int32(advance_count)
+        if const_expr(self._pipeline_state is not None and advance_count > 1):
+            self._pipeline_state.advance_iters(advance_count - 1)
+        if const_expr(params.persistence_mode in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC]):
+            # We assume here that advance_count is 1 for scheduler_warp
+            if is_scheduler_warp:
+                self._current_work_linear_idx = self._fetch_next_work_idx(loc=loc, ip=ip)
+                work_tile_info = self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip)
+                self.write_work_tile_to_smem(work_tile_info, loc=loc, ip=ip)
 
     def producer_tail(self):
-        if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None):
-            self._scheduler_pipeline.producer_tail(self._pipeline_state)
+        if const_expr(self._scheduler_pipeline is not None):
+            pipeline_state_producer = PipelineStateWAdvance(
+                self._pipeline_state.stages,
+                self._pipeline_state.count,
+                self._pipeline_state.index,
+                self._pipeline_state.phase ^ 1,
+            )
+            self._scheduler_pipeline.producer_tail(pipeline_state_producer)
 
     def __extract_mlir_values__(self):
         values, self._values_pos = [], []
         for obj in [
             self._current_work_linear_idx,
             self.num_tiles_executed,
-            self._tile_count,
+            self._sched_smem,
             self._scheduler_pipeline,
             self._pipeline_state,
             self.params,
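
advance_to_next_work is now much thinner: the single scheduler warp fetches the next linear index, delinearizes it once for block zero, and publishes the finished record, while every other warp only dequeues it in get_current_work. producer_tail likewise derives the producer's pipeline state from the consumer's by flipping the phase bit, since in a two-phase mbarrier ring the producer view of a slot is always the consumer phase XOR 1. A CPU analogy of the handshake using a bounded queue (not package code; the mbarrier acquire/commit and wait/release pairs map onto blocking put/get):

```python
import queue
import threading

NUM_STAGES = 2  # matches the smem ring depth ("stages")
work_queue: "queue.Queue[tuple | None]" = queue.Queue(maxsize=NUM_STAGES)

def scheduler(num_tiles: int) -> None:
    # single producer: resolves each work tile once, then publishes it
    for idx in range(num_tiles):
        pid_m, pid_n, batch_idx = idx, 0, 0   # stand-in for _delinearize_work_idx
        work_queue.put((pid_m, pid_n, batch_idx, True))  # producer_commit
    work_queue.put(None)  # sentinel: no more valid tiles

def worker() -> None:
    # consumer: never recomputes the tile coordinates itself
    while (tile := work_queue.get()) is not None:  # consumer_wait
        pid_m, pid_n, batch_idx, is_valid = tile
        # ... process tile (pid_m, pid_n, batch_idx) ...

t_sched = threading.Thread(target=scheduler, args=(8,))
t_work = threading.Thread(target=worker)
t_sched.start(); t_work.start()
t_sched.join(); t_work.join()
```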
@@ -367,7 +439,7 @@ class TileScheduler:
         [
             self._current_work_linear_idx,
             self.num_tiles_executed,
-            self._tile_count,
+            self._sched_smem,
             self._scheduler_pipeline,
             self._pipeline_state,
             self.params,
@@ -396,16 +468,16 @@ class TriangularTileScheduler(TileScheduler):
     @dataclass
     class Params(ParamsBase):
         problem_shape_ncluster_mnl: cute.Shape
-        num_clusters_per_problem_divmod: FastDivmod
+        num_clusters_per_problem_fdd: FastDivmod
         group_size_inv_f32: Float32
         num_groups_regular: Int32
-        group_size_divmod: FastDivmod
-        group_size_tail_divmod: FastDivmod
-        group_size_mul_group_size_divmod: FastDivmod
-        group_size_tail_mul_group_size_divmod: FastDivmod
+        group_size_fdd: FastDivmod
+        group_size_tail_fdd: FastDivmod
+        group_size_mul_group_size_fdd: FastDivmod
+        group_size_tail_mul_group_size_fdd: FastDivmod
         tile_count_semaphore: Optional[cute.Pointer]
         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
-        is_persistent: cutlass.Constexpr[bool]
+        persistence_mode: cutlass.Constexpr[PersistenceMode]
 
         @staticmethod
         @cute.jit
@@ -425,19 +497,23 @@ class TriangularTileScheduler(TileScheduler):
             group_size = min(args.group_size, cluster_m)
             group_size_tail = cluster_m % group_size
             num_groups_regular = cluster_m // group_size
+            if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC):
+                assert args.tile_count_semaphore is not None
             return TriangularTileScheduler.Params(
                 problem_shape_ncluster_mnl,
-                FastDivmod.create(num_clusters_per_problem),
+                FastDivmod(num_clusters_per_problem),
                 Float32(1.0 / group_size),
                 num_groups_regular,
-                FastDivmod.create(group_size),
+                FastDivmod(group_size),
                 # Don't divide by 0
-                FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
-                FastDivmod.create(group_size * group_size),
-                FastDivmod.create((group_size_tail if group_size_tail > 0 else 1) * group_size),
-                args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+                FastDivmod(group_size_tail if group_size_tail > 0 else 1),
+                FastDivmod(group_size * group_size),
+                FastDivmod((group_size_tail if group_size_tail > 0 else 1) * group_size),
+                args.tile_count_semaphore
+                if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC)
+                else None,
                 cluster_shape_mn,
-                args.is_persistent,
+                args.persistence_mode,
             )
 
     @staticmethod
@@ -448,30 +524,32 @@ class TriangularTileScheduler(TileScheduler):
     @cute.jit
     def create(
         params: Params,
-        tile_count: Optional[cute.Tensor] = None,
+        sched_smem: Optional[cute.Tensor] = None,
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
-        is_scheduler_warp: bool | Boolean = False,
         *,
         loc=None,
         ip=None,
     ) -> "TriangularTileScheduler":
         stages = 0
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             cluster_id, _, _ = cute.arch.cluster_idx()
             current_work_linear_idx = Int32(cluster_id)
         else:
             _, _, bidz = cute.arch.block_idx()
             current_work_linear_idx = Int32(bidz)
-        if const_expr(params.tile_count_semaphore is not None):
-            assert tile_count is not None
-            assert scheduler_pipeline is not None
-            stages = const_expr(cute.size(tile_count))
+        if const_expr(
+            params.persistence_mode
+            in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC, PersistenceMode.CLC]
+        ):
+            assert sched_smem is not None
+            assert scheduler_pipeline is not None
+            stages = const_expr(cute.size(sched_smem))
         return TriangularTileScheduler(
             current_work_linear_idx,
             Int32(0),  # num_tiles_executed
-            tile_count,
+            sched_smem,
             scheduler_pipeline,
-            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
+            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)),
             params,
             loc=loc,
             ip=ip,
@@ -486,15 +564,11 @@ class TriangularTileScheduler(TileScheduler):
         loc=None,
         ip=None,
     ) -> Tuple[Int32, Int32, Int32]:
-        clusters = (
-            params.num_clusters_per_problem_divmod.divisor,
-            1,
-            params.problem_shape_ncluster_mnl[2],
-        )
+        clusters = (params.num_clusters_per_problem_fdd.divisor, 1)
         num_ctas_mnl = tuple(x * y for x, y in zip(clusters, params.cluster_shape_mn)) + (
             params.problem_shape_ncluster_mnl[2],
         )
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             return num_ctas_mnl
         else:
             num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
@@ -506,17 +580,19 @@ class TriangularTileScheduler(TileScheduler):
         return (*params.cluster_shape_mn, num_persistent_clusters)
 
     @cute.jit
-    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+    def _delinearize_work_idx(
+        self, *, block_zero_only: bool = False, loc=None, ip=None
+    ) -> cutlass.utils.WorkTileInfo:
         params = self.params
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             cluster_id_in_problem = self._current_work_linear_idx
             _, _, bidz = cute.arch.block_idx()
         else:
-            bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
-                self._current_work_linear_idx
+            bidz, cluster_id_in_problem = divmod(
+                self._current_work_linear_idx, params.num_clusters_per_problem_fdd
             )
         # CTA Swizzle to promote L2 data reuse
-        group_size = params.group_size_divmod.divisor
+        group_size = params.group_size_fdd.divisor
         group_id = (
             utils.ceil(
                 (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
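
The sqrt expression is the standard closed form for inverting a triangular number: for linear index i in row r of a lower-triangular launch, r(r+1)/2 <= i < (r+1)(r+2)/2, and ceil(sqrt(2i + 2.25) - 0.5) equals r + 1 exactly at both ends of the row; multiplying by group_size_inv_f32 folds rows into bands of group_size clusters. A quick self-check of the group_size = 1 case (reference math, not package code):

```python
import math

def row_of(i: int) -> int:
    # inverse triangular-number formula used by the scheduler, for group_size == 1
    return math.ceil(math.sqrt(2 * i + 2.25) - 0.5) - 1

for r in range(100):
    start, end = r * (r + 1) // 2, (r + 1) * (r + 2) // 2
    assert all(row_of(i) == r for i in range(start, end))
```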
@@ -528,40 +604,40 @@ class TriangularTileScheduler(TileScheduler):
         group_size_actual = (
             group_size
             if group_id < params.num_groups_regular
-            else params.group_size_tail_divmod.divisor
+            else params.group_size_tail_fdd.divisor
         )
         group_col, group_remainder = Int32(0), Int32(0)
         if group_id < params.num_groups_regular:
-            group_col, group_remainder = params.group_size_mul_group_size_divmod.divmod(id_in_group)
+            group_col, group_remainder = divmod(id_in_group, params.group_size_mul_group_size_fdd)
         else:  # tail part
-            group_col, group_remainder = params.group_size_tail_mul_group_size_divmod.divmod(
-                id_in_group
+            group_col, group_remainder = divmod(
+                id_in_group, params.group_size_tail_mul_group_size_fdd
             )
         cid_m_in_group, cid_n_in_group = Int32(0), Int32(0)
         if id_in_group >= group_size_actual * group_size * group_id:  # triangular tail
             cid_m_in_group, cid_n_in_group = triangular_idx_to_coord(group_remainder)
         else:
             if group_id < params.num_groups_regular:
-                cid_n_in_group, cid_m_in_group = params.group_size_divmod.divmod(group_remainder)
+                cid_n_in_group, cid_m_in_group = divmod(group_remainder, params.group_size_fdd)
             else:
-                cid_n_in_group, cid_m_in_group = params.group_size_tail_divmod.divmod(
-                    group_remainder
-                )
+                cid_n_in_group, cid_m_in_group = divmod(group_remainder, params.group_size_tail_fdd)
         cid_m = cid_m_start + cid_m_in_group
         cid_n = group_col * group_size + cid_n_in_group
 
-        # Get the pid from cluster id
-        bidx_in_cluster = cute.arch.block_in_cluster_idx()
+        if const_expr(block_zero_only):
+            bidx_in_cluster = (Int32(0), Int32(0))
+        else:
+            # Get the pid from cluster id
+            bidx_in_cluster = cute.arch.block_in_cluster_idx()
         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, bidz)
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode == PersistenceMode.NONE):
             is_valid = self.num_tiles_executed == 0
         else:
             is_valid = (
                 self._current_work_linear_idx
-                < params.num_clusters_per_problem_divmod.divisor
-                * params.problem_shape_ncluster_mnl[2]
+                < params.num_clusters_per_problem_fdd.divisor * params.problem_shape_ncluster_mnl[2]
             )
         # bidx, bidy, bidz = cute.arch.block_idx()
         # tidx, _, _ = cute.arch.thread_idx()
@@ -581,7 +657,7 @@ class VarlenMTileSchedulerArguments(ParamsBase):
     tile_shape_mn: cutlass.Constexpr[cute.Shape]
     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
     tile_count_semaphore: Optional[cute.Pointer] = None
-    is_persistent: cutlass.Constexpr[bool] = False
+    persistence_mode: cutlass.Constexpr[PersistenceMode] = PersistenceMode.NONE
 
 
 class VarlenMTileScheduler(TileScheduler):
@@ -592,13 +668,13 @@ class VarlenMTileScheduler(TileScheduler):
         cu_seqlens_m: cute.Tensor
         raster_order: cutlass.Constexpr[RasterOrder]
         group_size: Int32
-        group_size_divmod: Optional[FastDivmod]
-        group_size_tail_divmod: Optional[FastDivmod]
-        num_clusters_in_group_divmod: FastDivmod
+        group_size_fdd: Optional[FastDivmod]
+        group_size_tail_fdd: Optional[FastDivmod]
+        num_clusters_in_group_fdd: FastDivmod
         tile_shape_mn: cutlass.Constexpr[cute.Shape]
         tile_count_semaphore: Optional[cute.Pointer]
         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
-        is_persistent: cutlass.Constexpr[bool]
+        persistence_mode: cutlass.Constexpr[PersistenceMode]
 
         @staticmethod
         @cute.jit
@@ -621,43 +697,40 @@ class VarlenMTileScheduler(TileScheduler):
                 if args.raster_order == RasterOrderOption.AlongM
                 else RasterOrder.AlongN  # For Heuristic we also use AlongN
             )
-            ncluster_fast = (
-                problem_shape_ncluster_mn[0]
-                if raster_order == RasterOrder.AlongM
-                else problem_shape_ncluster_mn[1]
-            )
-            ncluster_slow = (
-                problem_shape_ncluster_mn[1]
-                if raster_order == RasterOrder.AlongM
-                else problem_shape_ncluster_mn[0]
-            )
+            ncluster_fast = problem_shape_ncluster_mn[
+                0 if raster_order == RasterOrder.AlongM else 1
+            ]
+            ncluster_slow = problem_shape_ncluster_mn[
+                1 if raster_order == RasterOrder.AlongM else 0
+            ]
             if const_expr(ncluster_fast is not None):
                 group_size = min(args.group_size, ncluster_fast)
                 group_size_tail = ncluster_fast % group_size
             else:
                 group_size, group_size_tail = args.group_size, None
+            num_clusters_in_group = None
             if const_expr(ncluster_slow is not None):
                 num_clusters_in_group = group_size * ncluster_slow
-            else:
-                num_clusters_in_group = None
+            if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC):
+                assert args.tile_count_semaphore is not None
             return VarlenMTileScheduler.Params(
                 problem_shape_ncluster_mnl,
                 args.total_m,
                 args.cu_seqlens_m,
                 raster_order,
                 group_size,
-                FastDivmod.create(group_size) if ncluster_fast is not None else None,
+                FastDivmod(group_size) if ncluster_fast is not None else None,
                 # Don't divide by 0
-                FastDivmod.create(group_size_tail if group_size_tail > 0 else 1)
+                FastDivmod(group_size_tail if group_size_tail > 0 else 1)
                 if group_size_tail is not None
                 else None,
-                FastDivmod.create(num_clusters_in_group)
-                if num_clusters_in_group is not None
-                else None,
+                FastDivmod(num_clusters_in_group) if num_clusters_in_group is not None else None,
                 args.tile_shape_mn,
-                args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+                args.tile_count_semaphore
+                if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC)
+                else None,
                 cluster_shape_mn,
-                args.is_persistent,
+                args.persistence_mode,
             )
 
     def __init__(
@@ -666,7 +739,7 @@ class VarlenMTileScheduler(TileScheduler):
         num_tiles_executed: Int32,
         current_batch_idx: Int32,
         num_work_idx_before_cur_batch: Int32,
-        tile_count: Optional[cute.Tensor],
+        sched_smem: Optional[cute.Tensor],
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
         pipeline_state: PipelineStateWAdvance,
         params: Params,
@@ -678,7 +751,7 @@ class VarlenMTileScheduler(TileScheduler):
         self.num_tiles_executed = num_tiles_executed
         self._current_batch_idx = current_batch_idx
         self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
-        self._tile_count = tile_count
+        self._sched_smem = sched_smem
         self._scheduler_pipeline = scheduler_pipeline
         self._pipeline_state = pipeline_state
         self.params = params
@@ -693,9 +766,8 @@ class VarlenMTileScheduler(TileScheduler):
     @cute.jit
     def create(
         params: Params,
-        tile_count: Optional[cute.Tensor] = None,
+        sched_smem: Optional[cute.Tensor] = None,
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
-        is_scheduler_warp: bool | Boolean = False,
         *,
         loc=None,
         ip=None,
@@ -703,18 +775,21 @@ class VarlenMTileScheduler(TileScheduler):
         stages = 0
         _, _, bidz = cute.arch.block_idx()
         current_work_linear_idx = Int32(bidz)
-        if const_expr(params.tile_count_semaphore is not None):
-            assert tile_count is not None
+        if const_expr(
+            params.persistence_mode
+            in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC, PersistenceMode.CLC]
+        ):
+            assert sched_smem is not None
             assert scheduler_pipeline is not None
-            stages = const_expr(cute.size(tile_count))
+            stages = const_expr(cute.size(sched_smem, mode=[1]))
         return VarlenMTileScheduler(
             current_work_linear_idx,
             Int32(0),  # num_tiles_executed
             Int32(0),  # current_batch_idx
             Int32(0),  # num_work_idx_before_cur_batch
-            tile_count,
+            sched_smem,
             scheduler_pipeline,
-            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
+            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)),
             params,
             loc=loc,
             ip=ip,
@@ -733,54 +808,33 @@ class VarlenMTileScheduler(TileScheduler):
         num_batch = params.problem_shape_ncluster_mnl[2]
         total_clusters_m_max = (params.total_m + num_batch * (block_size - 1)) // block_size
         total_clusters_max = total_clusters_m_max * params.problem_shape_ncluster_mnl[1]
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             return (*params.cluster_shape_mn, total_clusters_max)
         else:
             num_persistent_clusters = cutlass.min(max_active_clusters, total_clusters_max)
             return (*params.cluster_shape_mn, num_persistent_clusters)
 
-    @cute.jit
-    def _get_num_m_blocks(
-        self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int]
-    ) -> Int32:
-        num_batch = self.params.problem_shape_ncluster_mnl[2]
-        batch_idx = lane + bidb_start
-        cur_cu_seqlen = Int32(0)
-        if batch_idx <= num_batch:
-            cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx]
-        next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
-        seqlen = next_cu_seqlen - cur_cu_seqlen
-        return (
-            cute.ceil_div(seqlen, block_size)
-            if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1
-            else Int32(0)
-        )
-
     @cute.jit
     def _swizzle_cta(
         self, cluster_id_in_problem: Int32, num_clusters_m: Int32, *, loc=None, ip=None
     ) -> Tuple[Int32, Int32]:
         params = self.params
         # CTA Swizzle to promote L2 data reuse
-        if const_expr(params.num_clusters_in_group_divmod is not None):
-            group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(
-                cluster_id_in_problem
-            )
-            num_clusters_in_group = params.num_clusters_in_group_divmod.divisor
+        if const_expr(params.num_clusters_in_group_fdd is not None):
+            group_id, id_in_group = divmod(cluster_id_in_problem, params.num_clusters_in_group_fdd)
+            num_clusters_in_group = params.num_clusters_in_group_fdd.divisor
         else:
             assert params.raster_order == RasterOrder.AlongN
             num_clusters_in_group = params.group_size * num_clusters_m
             group_id = cluster_id_in_problem // num_clusters_in_group
             id_in_group = cluster_id_in_problem - group_id * num_clusters_in_group
         cid_fast_in_group, cid_slow = Int32(0), Int32(0)
-        if const_expr(
-            params.group_size_divmod is not None and params.group_size_tail_divmod is not None
-        ):
+        if const_expr(params.group_size_fdd is not None and params.group_size_tail_fdd is not None):
             num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
             if (group_id + 1) * num_clusters_in_group <= num_clusters:
-                cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
+                cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_fdd)
             else:  # tail part
-                cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
+                cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_tail_fdd)
         else:
             assert params.raster_order == RasterOrder.AlongM
             group_size_actual = cutlass.min(
@@ -802,7 +856,26 @@ class VarlenMTileScheduler(TileScheduler):
         return cid_m, cid_n
 
     @cute.jit
-    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+    def _get_num_m_blocks(
+        self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int]
+    ) -> Int32:
+        num_batch = self.params.problem_shape_ncluster_mnl[2]
+        batch_idx = lane + bidb_start
+        cur_cu_seqlen = Int32(0)
+        if batch_idx <= num_batch:
+            cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx]
+        next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
+        seqlen = next_cu_seqlen - cur_cu_seqlen
+        return (
+            cute.ceil_div(seqlen, block_size)
+            if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1
+            else Int32(0)
+        )
+
+    @cute.jit
+    def _delinearize_work_idx(
+        self, *, block_zero_only: bool = False, loc=None, ip=None
+    ) -> cutlass.utils.WorkTileInfo:
         params = self.params
         lane_idx = cute.arch.lane_idx()
         num_batch = self.params.problem_shape_ncluster_mnl[2]
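
_get_num_m_blocks (relocated above _delinearize_work_idx by this hunk) computes, one batch per warp lane, how many m-blocks a variable-length batch needs: each lane loads one entry of the cu_seqlens_m prefix sums, a shuffle-down by one fetches the neighbor's entry, and the difference is that batch's sequence length, rounded up to blocks. A scalar reference of the same computation (illustrative, not package code):

```python
def num_m_blocks(cu_seqlens_m: list[int], batch_idx: int, block_size: int) -> int:
    # cu_seqlens_m holds prefix sums; adjacent entries bracket one batch
    seqlen = cu_seqlens_m[batch_idx + 1] - cu_seqlens_m[batch_idx]
    return -(-seqlen // block_size)  # ceil_div

# batch 1 spans tokens [100, 250): 150 tokens -> 3 blocks of 64
assert num_m_blocks([0, 100, 250, 251], batch_idx=1, block_size=64) == 3
```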
@@ -819,7 +892,6 @@ class VarlenMTileScheduler(TileScheduler):
         )
         problems_end_tile = self._num_work_idx_before_cur_batch + clusters_in_problems
         # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, num_clusters_cumulative = %d, problems_end_tile = %d", self._tile_idx, problems_end_tile, num_clusters_m, num_clusters_cumulative, problems_end_tile)
-        cid_m, cid_n = Int32(0), Int32(0)
         next_tile_idx = self._current_work_linear_idx
         while problems_end_tile <= next_tile_idx:
             batch_idx += cute.arch.WARP_SIZE - 1
@@ -836,11 +908,14 @@ class VarlenMTileScheduler(TileScheduler):
                 num_clusters_cumulative, cute.arch.WARP_SIZE - 1
             )
             problems_end_tile += clusters_in_problems
+        if const_expr(params.persistence_mode == PersistenceMode.NONE):
+            is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch
+        else:
+            is_valid = batch_idx < num_batch
         # Just a placeholer value in case batch_idx >= num_batch
         num_work_idx_before_cur_batch = problems_end_tile - clusters_in_problems
-        if batch_idx >= num_batch:
-            cid_m, cid_n, batch_idx = Int32(0), Int32(0), Int32(num_batch)
-        else:
+        cid_m, cid_n = Int32(0), Int32(0)
+        if is_valid:
             problems_start_tile = problems_end_tile - clusters_in_problems
             # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, batch_idx = %d", self._tile_idx, problems_end_tile, num_clusters_m, batch_idx)
             # The next problem to process is the first one that does not have ending tile position
@@ -859,42 +934,21 @@ class VarlenMTileScheduler(TileScheduler):
             num_clusters_m = cute.arch.shuffle_sync(num_clusters_m, batch_idx_in_problems)
             num_work_idx_before_cur_batch = problems_start_tile + num_clusters_prev_lane
             cluster_id_in_problem = next_tile_idx - num_work_idx_before_cur_batch
-            # cid_n = cluster_id_in_problem // num_clusters_m
-            # cid_m = cluster_id_in_problem - cid_n * num_clusters_m
             # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, cid_n=%d, cid_m=%d, is_valid = %d", self._tile_idx, batch_idx, cid_n, cid_m, is_valid)
             cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, num_clusters_m, loc=loc, ip=ip)
             self._current_batch_idx = batch_idx
             self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
 
-        # Get the pid from cluster id
-        bidx_in_cluster = cute.arch.block_in_cluster_idx()
+        if const_expr(block_zero_only):
+            bidx_in_cluster = (Int32(0), Int32(0))
+        else:
+            # Get the pid from cluster id
+            bidx_in_cluster = cute.arch.block_in_cluster_idx()
         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
-        if const_expr(not params.is_persistent):
-            is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch
-        else:
-            is_valid = batch_idx < num_batch
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
 
-    @cute.jit
-    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
-        """is_scheduler_warp should only be true for one warp in the whole cluster"""
-        if const_expr(self.params.tile_count_semaphore is not None):
-            params = self.params
-            current_work_linear_idx = self._current_work_linear_idx
-            if is_scheduler_warp:
-                if cute.arch.lane_idx() == 0:
-                    # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
-                    num_persistent_clusters = cute.arch.grid_dim()[2]
-                    current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
-                        1, params.tile_count_semaphore
-                    )
-                    # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
-                # lane 0 already has the right tile_idx, just need to broadcast
-                current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
-            self._current_work_linear_idx = current_work_linear_idx
-
     def __extract_mlir_values__(self):
         values, self._values_pos = [], []
         for obj in [
@@ -902,7 +956,7 @@ class VarlenMTileScheduler(TileScheduler):
             self.num_tiles_executed,
             self._current_batch_idx,
             self._num_work_idx_before_cur_batch,
-            self._tile_count,
+            self._sched_smem,
             self._scheduler_pipeline,
             self._pipeline_state,
             self.params,
@@ -920,7 +974,7 @@ class VarlenMTileScheduler(TileScheduler):
             self.num_tiles_executed,
             self._current_batch_idx,
             self._num_work_idx_before_cur_batch,
-            self._tile_count,
+            self._sched_smem,
             self._scheduler_pipeline,
             self._pipeline_state,
             self.params,