quack-kernels 0.2.5 (py3-none-any.whl) → 0.2.6 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +72 -64
- quack/broadcast_utils.py +1 -1
- quack/copy_utils.py +14 -18
- quack/fast_math.py +29 -76
- quack/gemm_act.py +296 -8
- quack/gemm_dact.py +520 -4
- quack/gemm_default_epi.py +4 -4
- quack/gemm_interface.py +363 -0
- quack/gemm_sm100.py +62 -88
- quack/gemm_sm90.py +68 -114
- quack/gemm_symmetric.py +2 -6
- quack/layout_utils.py +2 -4
- quack/linear.py +37 -0
- quack/pipeline.py +59 -89
- quack/reduce.py +2 -2
- quack/rmsnorm.py +1 -3
- quack/sm90_utils.py +5 -3
- quack/sort/bitonic_sort.py +3 -3
- quack/tile_scheduler.py +310 -256
- quack/topk.py +4 -4
- quack/utils.py +76 -40
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/METADATA +2 -2
- quack_kernels-0.2.6.dist-info/RECORD +45 -0
- quack_kernels-0.2.5.dist-info/RECORD +0 -45
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/top_level.txt +0 -0
quack/tile_scheduler.py
CHANGED
@@ -25,6 +25,13 @@ class RasterOrder(IntEnum):
     AlongN = 1
 
 
+class PersistenceMode(IntEnum):
+    NONE = 0
+    STATIC = 1
+    DYNAMIC = 2
+    CLC = 3
+
+
 @cute.jit
 def get_raster_order_from_option(
     raster_order_option: RasterOrderOption, problem_shape_ncluster_mn: cute.Shape, group_size: Int32
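Reviewer's note: `PersistenceMode` replaces the boolean `is_persistent` flag that the hunks below delete. Reading how the rest of the diff branches on it (this mapping is our inference, not stated in the package): `NONE` launches one cluster per tile and marks work invalid after the first executed tile; `STATIC` keeps clusters resident and strides the work index by the persistent grid size; `DYNAMIC` keeps clusters resident and pulls the next index from `tile_count_semaphore` with an atomic; `CLC` appears to target cluster launch control, sizing the grid like the non-persistent case while sharing the smem/pipeline plumbing of the persistent modes.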
@@ -61,7 +68,7 @@ class TileSchedulerArguments(ArgumentsBase):
     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
     tile_count_semaphore: Optional[cute.Pointer] = None
     batch_idx_permute: Optional[cute.Tensor] = None
-    is_persistent: cutlass.Constexpr[bool] = False
+    persistence_mode: cutlass.Constexpr[PersistenceMode] = PersistenceMode.NONE
 
 
 class TileScheduler:
@@ -69,15 +76,15 @@ class TileScheduler:
     class Params(ParamsBase):
         problem_shape_ncluster_mnl: cute.Shape
         raster_order: RasterOrder
-        num_clusters_per_problem_divmod: FastDivmod
+        num_clusters_per_problem_fdd: FastDivmod
         num_groups_regular: Int32
-        group_size_divmod: FastDivmod
-        group_size_tail_divmod: FastDivmod
-        num_clusters_in_group_divmod: FastDivmod
+        group_size_fdd: FastDivmod
+        group_size_tail_fdd: FastDivmod
+        num_clusters_in_group_fdd: FastDivmod
         tile_count_semaphore: Optional[cute.Pointer]
         batch_idx_permute: Optional[cute.Tensor]
         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
-        is_persistent: cutlass.Constexpr[bool]
+        persistence_mode: cutlass.Constexpr[PersistenceMode]
 
         @staticmethod
         @cute.jit
@@ -107,26 +114,30 @@ class TileScheduler:
         group_size_tail = ncluster_fast % group_size
         num_groups_regular = ncluster_fast // group_size
         num_clusters_in_group = group_size * ncluster_slow
+        if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC):
+            assert args.tile_count_semaphore is not None
         return TileScheduler.Params(
             problem_shape_ncluster_mnl,
             raster_order,
-            FastDivmod
+            FastDivmod(num_clusters_per_problem),
             num_groups_regular,
-            FastDivmod
+            FastDivmod(group_size),
             # Don't divide by 0
-            FastDivmod
-            FastDivmod
-            args.tile_count_semaphore,
+            FastDivmod(group_size_tail if group_size_tail > 0 else 1),
+            FastDivmod(num_clusters_in_group),
+            args.tile_count_semaphore
+            if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC)
+            else None,
             args.batch_idx_permute,
             cluster_shape_mn,
-            args.is_persistent,
+            args.persistence_mode,
         )
 
     def __init__(
         self,
         current_work_linear_idx: Int32,
         num_tiles_executed: Int32,
-        tile_count: Optional[cute.Tensor],
+        sched_smem: Optional[cute.Tensor],
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
         pipeline_state: PipelineStateWAdvance,
         params: Params,
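Reviewer's note: a recurring mechanical change in this file is that `params.x_divmod.divmod(v)` becomes the builtin `divmod(v, params.x_fdd)`. A minimal pure-Python sketch of how a FastDivmod-like type can hook the builtin (illustrative only; quack's real `FastDivmod` in `quack/fast_math.py` is a CuTe DSL type that precomputes a magic-number reciprocal so the GPU divides without an integer-divide instruction):

```python
class FastDivmodSketch:
    """Pure-Python stand-in for quack's FastDivmod (hypothetical, for illustration).

    Only the interface is modeled: divmod(x, fdd) falls through to
    __rdivmod__ once int.__divmod__ returns NotImplemented.
    """

    def __init__(self, divisor: int):
        assert divisor > 0, "don't divide by 0 (cf. the group_size_tail guard)"
        self.divisor = divisor

    def __rdivmod__(self, dividend: int):
        # (quotient, remainder), exactly what the kernel unpacks.
        return dividend // self.divisor, dividend % self.divisor


fdd = FastDivmodSketch(96)
bidz, cluster_id_in_problem = divmod(1000, fdd)
assert (bidz, cluster_id_in_problem) == (10, 40)
```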
@@ -136,7 +147,7 @@ class TileScheduler:
     ):
         self._current_work_linear_idx = current_work_linear_idx
         self.num_tiles_executed = num_tiles_executed
-        self._tile_count = tile_count
+        self._sched_smem = sched_smem
         self._scheduler_pipeline = scheduler_pipeline
         self._pipeline_state = pipeline_state
         self.params = params
@@ -151,16 +162,14 @@ class TileScheduler:
     @cute.jit
     def create(
         params: Params,
-        tile_count: Optional[cute.Tensor] = None,
+        sched_smem: Optional[cute.Tensor] = None,
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
-        is_scheduler_warp: bool | Boolean = False,
         *,
         loc=None,
         ip=None,
     ) -> "TileScheduler":
         """is_scheduler_warp should only be true for one warp in the whole cluster"""
-
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
            cidx, cidy, _ = cute.arch.cluster_idx()
             cdimx, _, _ = cute.arch.cluster_dim()
             cluster_id = cidx + cidy * cdimx
@@ -168,16 +177,20 @@ class TileScheduler:
         else:
             _, _, bidz = cute.arch.block_idx()
             current_work_linear_idx = Int32(bidz)
-        stages = 0
-        if const_expr(params.is_persistent):
-            assert scheduler_pipeline is not None
-            stages = const_expr(cute.size(tile_count))
+        stages = 0
+        if const_expr(
+            params.persistence_mode
+            in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC, PersistenceMode.CLC]
+        ):
+            assert sched_smem is not None
+            assert scheduler_pipeline is not None
+            stages = const_expr(cute.size(sched_smem, mode=[1]))
         return TileScheduler(
             current_work_linear_idx,
             Int32(0),  # num_tiles_executed
-            tile_count,
+            sched_smem,
             scheduler_pipeline,
-            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(
+            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)),
             params,
             loc=loc,
             ip=ip,
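Reviewer's note: `create` now takes the scheduler smem tensor itself and derives the pipeline stage count from it; `cute.size(sched_smem, mode=[1])` reads the extent of the second mode. Combined with the `sched_smem[i, stage]` indexing later in this diff, the implied layout is one four-int32 record `(pid_m, pid_n, batch_idx, is_valid)` per stage. A sketch of that layout (NumPy stand-in; the stage count here is an arbitrary assumption):

```python
import numpy as np

STAGES = 2  # illustrative stage count (an assumption, not from this diff)

# One (pid_m, pid_n, batch_idx, is_valid) int32 record per pipeline stage:
# mode 0 has extent 4, mode 1 has extent STAGES, matching
# stages = const_expr(cute.size(sched_smem, mode=[1])).
sched_smem = np.zeros((4, STAGES), dtype=np.int32)

def publish(stage: int, pid_m: int, pid_n: int, batch_idx: int, is_valid: bool):
    # Producer side: what write_work_tile_to_smem stores per stage.
    sched_smem[:, stage] = (pid_m, pid_n, batch_idx, int(is_valid))

def consume(stage: int):
    # Consumer side: what get_current_work unpacks with sched_smem[i, stage].
    pid_m, pid_n, batch_idx, is_valid = (int(v) for v in sched_smem[:, stage])
    return (pid_m, pid_n, None, batch_idx), bool(is_valid)
```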
@@ -195,7 +208,7 @@ class TileScheduler:
         num_ctas_mnl = tuple(
             x * y for x, y in zip(params.problem_shape_ncluster_mnl, params.cluster_shape_mn)
         ) + (params.problem_shape_ncluster_mnl[2],)
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             return num_ctas_mnl
         else:
             num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
@@ -212,12 +225,12 @@ class TileScheduler:
     ) -> Tuple[Int32, Int32]:
         # CTA Swizzle to promote L2 data reuse
         params = self.params
-        group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(cluster_id_in_problem)
+        group_id, id_in_group = divmod(cluster_id_in_problem, params.num_clusters_in_group_fdd)
         cid_fast_in_group, cid_slow = Int32(0), Int32(0)
         if group_id < params.num_groups_regular:
-            cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
+            cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_fdd)
         else:  # tail part
-            cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
+            cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_tail_fdd)
         if group_id % 2 == 1:  # serpentine order
             ncluster_slow = (
                 params.problem_shape_ncluster_mnl[1]
@@ -225,56 +238,151 @@ class TileScheduler:
                 else params.problem_shape_ncluster_mnl[0]
             )
             cid_slow = ncluster_slow - 1 - cid_slow
-        cid_fast = group_id * params.group_size_divmod.divisor + cid_fast_in_group
+        cid_fast = group_id * params.group_size_fdd.divisor + cid_fast_in_group
         cid_m, cid_n = cid_fast, cid_slow
         if params.raster_order == RasterOrder.AlongN:
             cid_m, cid_n = cid_slow, cid_fast
         return cid_m, cid_n
 
     @cute.jit
-    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+    def _delinearize_work_idx(
+        self, *, block_zero_only: bool = False, loc=None, ip=None
+    ) -> cutlass.utils.WorkTileInfo:
         params = self.params
-        if const_expr(not params.is_persistent):
-            cluster_id_in_problem = self._current_work_linear_idx
-            _, _, bidz = cute.arch.block_idx()
-        else:
-            bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
-                self._current_work_linear_idx
-            )
-        cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip)
-        # Get the pid from cluster id
-        bidx_in_cluster = cute.arch.block_in_cluster_idx()
-        pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
-        pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
-        batch_idx = (
-            bidz if const_expr(params.batch_idx_permute is None) else params.batch_idx_permute[bidz]
-        )
-        tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode == PersistenceMode.NONE):
             is_valid = self.num_tiles_executed == 0
         else:
             is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
+        pid_m, pid_n, batch_idx = Int32(0), Int32(0), Int32(0)
+        if is_valid:
+            if const_expr(params.persistence_mode == PersistenceMode.NONE):
+                cluster_id_in_problem = self._current_work_linear_idx
+                _, _, bidz = cute.arch.block_idx()
+            else:
+                bidz, cluster_id_in_problem = divmod(
+                    self._current_work_linear_idx, params.num_clusters_per_problem_fdd
+                )
+            cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip)
+            if const_expr(block_zero_only):
+                bidx_in_cluster = (Int32(0), Int32(0))
+            else:
+                # Get the pid from cluster id
+                bidx_in_cluster = cute.arch.block_in_cluster_idx()
+            pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+            pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
+            batch_idx = (
+                bidz
+                if const_expr(params.batch_idx_permute is None)
+                else params.batch_idx_permute[bidz]
+            )
+        tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
 
+    @cute.jit
+    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+        params = self.params
+        pid_m, pid_n, batch_idx, is_valid = Int32(0), Int32(0), Int32(0), Boolean(False)
+        if const_expr(params.persistence_mode == PersistenceMode.NONE):
+            pass
+        # elif const_expr(params.persistence_mode == PersistenceMode.STATIC):
+        #     return self._delinearize_work_idx(loc=loc, ip=ip)
+        else:
+            self._scheduler_pipeline.consumer_wait(self._pipeline_state)
+            pid_m, pid_n, batch_idx, is_valid_i32 = [
+                self._sched_smem[i, self._pipeline_state.index] for i in range(4)
+            ]
+            # Need this fence since the STAS from the producer is using the async proxy.
+            # Without this, we get race condition / deadlock.
+            if const_expr(cute.size(params.cluster_shape_mn) > 1):
+                cute.arch.fence_view_async_shared()
+            cute.arch.sync_warp()
+            with cute.arch.elect_one():
+                self._scheduler_pipeline.consumer_release(self._pipeline_state)
+            self._pipeline_state.advance()
+            is_valid = Boolean(is_valid_i32)
+        tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
+        return cutlass.utils.WorkTileInfo(tile_coord_mnkl, Boolean(is_valid))
+
+    # @cute.jit
     def initial_work_tile_info(self, *, loc=None, ip=None):
-        return self.get_current_work(loc=loc, ip=ip)
+        return self._delinearize_work_idx(loc=loc, ip=ip)
+        # if is_scheduler_warp:
+        #     work_tile_info = self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip)
+        #     self.write_work_tile_to_smem(work_tile_info, loc=loc, ip=ip)
+        # self.write_work_tile_to_smem(self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip), loc=loc, ip=ip)
 
     @cute.jit
-    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
+    def _fetch_next_work_idx(self, *, loc=None, ip=None) -> Int32:
         """is_scheduler_warp should only be true for one warp in the whole cluster"""
         params = self.params
-        if const_expr(params.tile_count_semaphore is not None):
-            current_work_linear_idx = self._current_work_linear_idx
-            if is_scheduler_warp:
-                if cute.arch.lane_idx() == 0:
-                    num_persistent_clusters = cute.arch.grid_dim()[2]
-                    current_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32(
+        num_persistent_clusters = Int32(cute.arch.grid_dim()[2])
+        if const_expr(params.persistence_mode == PersistenceMode.STATIC):
+            return self._current_work_linear_idx + num_persistent_clusters
+        elif const_expr(params.persistence_mode == PersistenceMode.DYNAMIC):
+            next_work_linear_idx = Int32(0)
+            if cute.arch.lane_idx() == 0:
+                # If varlen_m, problem_shape_ncluster_mnl[0] is None, so we use atomic_add
+                # instead of atomic_inc, and at the end of the kernel must reset the semaphore to 0.
+                # # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
+                if const_expr(params.problem_shape_ncluster_mnl[0] is not None):
+                    next_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32(
                         cute.size(params.problem_shape_ncluster_mnl) - 1,
                         params.tile_count_semaphore,
                     )
-                    # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
-                current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
-            self._current_work_linear_idx = current_work_linear_idx
+                else:  # varlen_m
+                    next_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
+                        1, params.tile_count_semaphore
+                    )
+            # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
+            return cute.arch.shuffle_sync(next_work_linear_idx, 0)
+        else:
+            return Int32(0)
+
+    @cute.jit
+    def write_work_tile_to_smem(
+        self, work_tile_info: cutlass.utils.WorkTileInfo, *, loc=None, ip=None
+    ):
+        params = self.params
+        if const_expr(self._sched_smem is not None):
+            # producer phase is always consumer_phase ^ 1
+            pipeline_state_producer = PipelineStateWAdvance(
+                self._pipeline_state.stages,
+                self._pipeline_state.count,
+                self._pipeline_state.index,
+                self._pipeline_state.phase ^ 1,
+            )
+            self._scheduler_pipeline.producer_acquire(pipeline_state_producer)
+            sched_data = [
+                work_tile_info.tile_idx[0],
+                work_tile_info.tile_idx[1],
+                work_tile_info.tile_idx[3],
+                Int32(work_tile_info.is_valid_tile),
+            ]
+            lane_idx = cute.arch.lane_idx()
+            if lane_idx < cute.size(params.cluster_shape_mn):
+                # cute.printf("Producer pid_m = {}, pid_n = {}, batch_idx = {}, is_valid = {}, after empty wait, idx = {}", sched_data[0], sched_data[1], sched_data[2], sched_data[3], self._current_work_linear_idx)
+                pipeline_idx = self._pipeline_state.index
+                if const_expr(cute.size(params.cluster_shape_mn) == 1):
+                    for i in cutlass.range_constexpr(4):
+                        self._sched_smem[i, pipeline_idx] = sched_data[i]
+                    self._scheduler_pipeline.producer_commit(self._pipeline_state)
+                else:
+                    peer_cta_rank_in_cluster = lane_idx
+                    # Here we assume that the block idx in cluster is linearized such that
+                    # x is the fastest moving direction.
+                    bidx_in_cluster = peer_cta_rank_in_cluster % params.cluster_shape_mn[0]
+                    bidy_in_cluster = peer_cta_rank_in_cluster // params.cluster_shape_mn[0]
+                    mbar_ptr = self._scheduler_pipeline.producer_get_barrier(self._pipeline_state)
+                    cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, 16, peer_cta_rank_in_cluster)
+                    utils.store_shared_remote_x4(
+                        sched_data[0] + bidx_in_cluster,
+                        sched_data[1] + bidy_in_cluster,
+                        sched_data[2],
+                        sched_data[3],
+                        smem_ptr=self._sched_smem[None, pipeline_idx].iterator,
+                        mbar_ptr=mbar_ptr,
+                        peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+                    )
 
     @cute.jit
     def advance_to_next_work(
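Reviewer's note: the hunk above also shows the consumer side of the new scheduler pipeline, where `get_current_work` reads the four-int record out of `sched_smem`, fences (the producer writes through the async proxy), releases the buffer, and advances. The CTA swizzle feeding those coordinates is unchanged in spirit; a plain-Python mirror of `_swizzle_cta` for reference (Int32/FastDivmod types elided):

```python
def swizzle_cta(cluster_id, ncluster_fast, ncluster_slow, group_size):
    """Plain-Python mirror of TileScheduler._swizzle_cta (types elided).

    Linear ids are laid out in bands of `group_size` along the fast axis;
    odd bands traverse the slow axis backwards ("serpentine") so consecutive
    work items touch adjacent tiles and reuse L2.
    """
    num_groups_regular = ncluster_fast // group_size
    num_clusters_in_group = group_size * ncluster_slow
    group_id, id_in_group = divmod(cluster_id, num_clusters_in_group)
    # Regular bands are group_size wide; the last band gets the remainder.
    width = group_size if group_id < num_groups_regular else ncluster_fast % group_size
    cid_slow, cid_fast_in_group = divmod(id_in_group, width)
    if group_id % 2 == 1:  # serpentine order
        cid_slow = ncluster_slow - 1 - cid_slow
    return group_id * group_size + cid_fast_in_group, cid_slow

# ncluster_fast=4, ncluster_slow=3, group_size=2: ids 0..5 walk band 0 forward,
# ids 6..11 walk band 1 backward along the slow axis.
assert swizzle_cta(6, 4, 3, 2) == (2, 2) and swizzle_cta(11, 4, 3, 2) == (3, 0)
```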
@@ -285,73 +393,37 @@ class TileScheduler:
         loc=None,
         ip=None,
     ):
-        """is_scheduler_warp should only be true for one warp in the whole cluster.
-        If calling with is_scheduler_warp = True, advance_count must be 1.
-        """
+        """is_scheduler_warp should only be true for one warp in the whole cluster.
+        Moreover, we assume that only block zero in the cluster is calling this function.
+        If calling with is_scheduler_warp = True, advance_count must be 1.
+        """
         params = self.params
-        if const_expr(params.is_persistent):
-            num_persistent_clusters = cute.arch.grid_dim()[2]
-            if const_expr(params.tile_count_semaphore is None):  # Static persistent
-                self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters)
-            else:  # Dynamic persistent
-                if const_expr(advance_count > 1):
-                    self._pipeline_state.advance_iters(advance_count - 1)
-                current_work_linear_idx = self._current_work_linear_idx
-                if is_scheduler_warp:
-                    self._scheduler_pipeline.producer_acquire(self._pipeline_state)
-                    lane_idx = cute.arch.lane_idx()
-                    if lane_idx < cute.size(params.cluster_shape_mn):
-                        # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after empty wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
-                        if const_expr(cute.size(params.cluster_shape_mn) == 1):
-                            self._tile_count[self._pipeline_state.index] = current_work_linear_idx
-                            self._scheduler_pipeline.producer_commit(self._pipeline_state)
-                        else:
-                            peer_cta_rank_in_cluster = lane_idx
-                            mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
-                                self._pipeline_state
-                            )
-                            cute.arch.mbarrier_arrive_and_expect_tx(
-                                mbar_ptr, 4, peer_cta_rank_in_cluster
-                            )
-                            utils.store_shared_remote(
-                                val=current_work_linear_idx,
-                                smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
-                                mbar_ptr=mbar_ptr,
-                                peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
-                            )
-                    # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after full arrive", bidx, bidz, tidx)
-                else:
-                    # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
-                    self._scheduler_pipeline.consumer_wait(self._pipeline_state)
-                    # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
-                    current_work_linear_idx = self._tile_count[self._pipeline_state.index]
-                    # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after smem read, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
-                    # Need this fence since the STAS from the producer is using the async proxy.
-                    # Without this, we get race condition / deadlock.
-                    if const_expr(cute.size(params.cluster_shape_mn) > 1):
-                        cute.arch.fence_proxy(
-                            cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-                        )
-                    cute.arch.sync_warp()
-                    with cute.arch.elect_one():
-                        # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before empty arrive", bidx, bidz, tidx)
-                        self._scheduler_pipeline.consumer_release(self._pipeline_state)
-                    # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
-                    # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
-                    self._current_work_linear_idx = current_work_linear_idx
-                self._pipeline_state.advance()
         self.num_tiles_executed += Int32(advance_count)
+        if const_expr(self._pipeline_state is not None and advance_count > 1):
+            self._pipeline_state.advance_iters(advance_count - 1)
+        if const_expr(params.persistence_mode in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC]):
+            # We assume here that advance_count is 1 for scheduler_warp
+            if is_scheduler_warp:
+                self._current_work_linear_idx = self._fetch_next_work_idx(loc=loc, ip=ip)
+                work_tile_info = self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip)
+                self.write_work_tile_to_smem(work_tile_info, loc=loc, ip=ip)
 
     def producer_tail(self):
-        if const_expr(self.params.tile_count_semaphore is not None):
-            self._scheduler_pipeline.producer_tail(self._pipeline_state)
+        if const_expr(self._scheduler_pipeline is not None):
+            pipeline_state_producer = PipelineStateWAdvance(
+                self._pipeline_state.stages,
+                self._pipeline_state.count,
+                self._pipeline_state.index,
+                self._pipeline_state.phase ^ 1,
+            )
+            self._scheduler_pipeline.producer_tail(pipeline_state_producer)
 
     def __extract_mlir_values__(self):
         values, self._values_pos = [], []
         for obj in [
             self._current_work_linear_idx,
             self.num_tiles_executed,
-            self._tile_count,
+            self._sched_smem,
             self._scheduler_pipeline,
             self._pipeline_state,
             self.params,
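Reviewer's note: `advance_to_next_work` is now thin. The scheduler warp fetches the next linear index (`STATIC`: stride by the persistent grid size; `DYNAMIC`: atomic on the semaphore), delinearizes it with `block_zero_only=True`, and publishes it via `write_work_tile_to_smem`. A host-side model of the two index progressions (a sketch; the lock stands in for the GPU atomic):

```python
import threading

def static_next_idx(current_idx: int, num_persistent_clusters: int) -> int:
    # STATIC: every resident cluster advances by the persistent grid size,
    # mirroring `self._current_work_linear_idx + num_persistent_clusters`.
    return current_idx + num_persistent_clusters

class TileCountSemaphore:
    """DYNAMIC (sketch): stands in for the global tile_count_semaphore.

    Indices 0..G-1 are implicitly consumed by each cluster's initial
    blockIdx.z, so handed-out indices start at G; the lock models
    utils.atomic_add_i32(1, semaphore) done by lane 0 of the scheduler warp.
    """
    def __init__(self, num_persistent_clusters: int):
        self._g = num_persistent_clusters
        self._count = 0
        self._lock = threading.Lock()

    def next_idx(self) -> int:
        with self._lock:
            fetched = self._count
            self._count += 1
        return self._g + fetched
```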
@@ -367,7 +439,7 @@ class TileScheduler:
         [
             self._current_work_linear_idx,
             self.num_tiles_executed,
-            self._tile_count,
+            self._sched_smem,
             self._scheduler_pipeline,
             self._pipeline_state,
             self.params,
@@ -396,16 +468,16 @@ class TriangularTileScheduler(TileScheduler):
     @dataclass
     class Params(ParamsBase):
         problem_shape_ncluster_mnl: cute.Shape
-        num_clusters_per_problem_divmod: FastDivmod
+        num_clusters_per_problem_fdd: FastDivmod
         group_size_inv_f32: Float32
         num_groups_regular: Int32
-        group_size_divmod: FastDivmod
-        group_size_tail_divmod: FastDivmod
-        group_size_mul_group_size_divmod: FastDivmod
-        group_size_tail_mul_group_size_divmod: FastDivmod
+        group_size_fdd: FastDivmod
+        group_size_tail_fdd: FastDivmod
+        group_size_mul_group_size_fdd: FastDivmod
+        group_size_tail_mul_group_size_fdd: FastDivmod
         tile_count_semaphore: Optional[cute.Pointer]
         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
-        is_persistent: cutlass.Constexpr[bool]
+        persistence_mode: cutlass.Constexpr[PersistenceMode]
 
         @staticmethod
         @cute.jit
@@ -425,19 +497,23 @@ class TriangularTileScheduler(TileScheduler):
         group_size = min(args.group_size, cluster_m)
         group_size_tail = cluster_m % group_size
         num_groups_regular = cluster_m // group_size
+        if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC):
+            assert args.tile_count_semaphore is not None
         return TriangularTileScheduler.Params(
             problem_shape_ncluster_mnl,
-            FastDivmod
+            FastDivmod(num_clusters_per_problem),
             Float32(1.0 / group_size),
             num_groups_regular,
-            FastDivmod
+            FastDivmod(group_size),
             # Don't divide by 0
-            FastDivmod
-            FastDivmod
-            FastDivmod
-            args.tile_count_semaphore,
+            FastDivmod(group_size_tail if group_size_tail > 0 else 1),
+            FastDivmod(group_size * group_size),
+            FastDivmod((group_size_tail if group_size_tail > 0 else 1) * group_size),
+            args.tile_count_semaphore
+            if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC)
+            else None,
             cluster_shape_mn,
-            args.is_persistent,
+            args.persistence_mode,
         )
 
     @staticmethod
@@ -448,30 +524,32 @@ class TriangularTileScheduler(TileScheduler):
     @cute.jit
     def create(
         params: Params,
-        tile_count: Optional[cute.Tensor] = None,
+        sched_smem: Optional[cute.Tensor] = None,
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
-        is_scheduler_warp: bool | Boolean = False,
         *,
         loc=None,
         ip=None,
     ) -> "TriangularTileScheduler":
         stages = 0
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
            cluster_id, _, _ = cute.arch.cluster_idx()
             current_work_linear_idx = Int32(cluster_id)
         else:
             _, _, bidz = cute.arch.block_idx()
             current_work_linear_idx = Int32(bidz)
-        if const_expr(params.is_persistent):
-            assert tile_count is not None
-            assert scheduler_pipeline is not None
-            stages = const_expr(cute.size(tile_count))
+        if const_expr(
+            params.persistence_mode
+            in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC, PersistenceMode.CLC]
+        ):
+            assert sched_smem is not None
+            assert scheduler_pipeline is not None
+            stages = const_expr(cute.size(sched_smem))
         return TriangularTileScheduler(
             current_work_linear_idx,
             Int32(0),  # num_tiles_executed
-            tile_count,
+            sched_smem,
             scheduler_pipeline,
-            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(
+            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)),
             params,
             loc=loc,
             ip=ip,
@@ -486,15 +564,11 @@ class TriangularTileScheduler(TileScheduler):
         loc=None,
         ip=None,
     ) -> Tuple[Int32, Int32, Int32]:
-        clusters = (
-            params.num_clusters_per_problem_divmod.divisor,
-            1,
-            params.problem_shape_ncluster_mnl[2],
-        )
+        clusters = (params.num_clusters_per_problem_fdd.divisor, 1)
         num_ctas_mnl = tuple(x * y for x, y in zip(clusters, params.cluster_shape_mn)) + (
             params.problem_shape_ncluster_mnl[2],
         )
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             return num_ctas_mnl
         else:
             num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
@@ -506,17 +580,19 @@ class TriangularTileScheduler(TileScheduler):
         return (*params.cluster_shape_mn, num_persistent_clusters)
 
     @cute.jit
-    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+    def _delinearize_work_idx(
+        self, *, block_zero_only: bool = False, loc=None, ip=None
+    ) -> cutlass.utils.WorkTileInfo:
         params = self.params
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             cluster_id_in_problem = self._current_work_linear_idx
             _, _, bidz = cute.arch.block_idx()
         else:
-            bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
-                self._current_work_linear_idx
+            bidz, cluster_id_in_problem = divmod(
+                self._current_work_linear_idx, params.num_clusters_per_problem_fdd
             )
         # CTA Swizzle to promote L2 data reuse
-        group_size = params.group_size_divmod.divisor
+        group_size = params.group_size_fdd.divisor
         group_id = (
             utils.ceil(
                 (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
@@ -528,40 +604,40 @@ class TriangularTileScheduler(TileScheduler):
         group_size_actual = (
             group_size
             if group_id < params.num_groups_regular
-            else params.group_size_tail_divmod.divisor
+            else params.group_size_tail_fdd.divisor
         )
         group_col, group_remainder = Int32(0), Int32(0)
         if group_id < params.num_groups_regular:
-            group_col, group_remainder = params.group_size_mul_group_size_divmod.divmod(id_in_group)
+            group_col, group_remainder = divmod(id_in_group, params.group_size_mul_group_size_fdd)
         else:  # tail part
-            group_col, group_remainder = params.group_size_tail_mul_group_size_divmod.divmod(
-                id_in_group
+            group_col, group_remainder = divmod(
+                id_in_group, params.group_size_tail_mul_group_size_fdd
             )
         cid_m_in_group, cid_n_in_group = Int32(0), Int32(0)
         if id_in_group >= group_size_actual * group_size * group_id:  # triangular tail
             cid_m_in_group, cid_n_in_group = triangular_idx_to_coord(group_remainder)
         else:
             if group_id < params.num_groups_regular:
-                cid_n_in_group, cid_m_in_group = params.group_size_divmod.divmod(group_remainder)
+                cid_n_in_group, cid_m_in_group = divmod(group_remainder, params.group_size_fdd)
             else:
-                cid_n_in_group, cid_m_in_group = params.group_size_tail_divmod.divmod(
-                    group_remainder
-                )
+                cid_n_in_group, cid_m_in_group = divmod(group_remainder, params.group_size_tail_fdd)
         cid_m = cid_m_start + cid_m_in_group
         cid_n = group_col * group_size + cid_n_in_group
 
-        # Get the pid from cluster id
-        bidx_in_cluster = cute.arch.block_in_cluster_idx()
+        if const_expr(block_zero_only):
+            bidx_in_cluster = (Int32(0), Int32(0))
+        else:
+            # Get the pid from cluster id
+            bidx_in_cluster = cute.arch.block_in_cluster_idx()
         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, bidz)
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode == PersistenceMode.NONE):
             is_valid = self.num_tiles_executed == 0
         else:
             is_valid = (
                 self._current_work_linear_idx
-                < params.num_clusters_per_problem_divmod.divisor
-                * params.problem_shape_ncluster_mnl[2]
+                < params.num_clusters_per_problem_fdd.divisor * params.problem_shape_ncluster_mnl[2]
             )
         # bidx, bidy, bidz = cute.arch.block_idx()
         # tidx, _, _ = cute.arch.thread_idx()
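Reviewer's note: `TriangularTileScheduler` enumerates only the lower triangle of the cluster grid and inverts the triangular numbering with a square root; `ceil((sqrt(2*i + 2.25) - 0.5) * (1/group_size))` is that inversion rescaled by the group size. A sketch of the underlying index-to-coordinate map; we assume `triangular_idx_to_coord` follows the standard row-major triangle convention (its definition is not shown in this diff):

```python
import math

def triangular_idx_to_coord(k: int):
    # Assumed convention: k enumerates the lower triangle row-major,
    # (0,0), (1,0), (1,1), (2,0), ... -> (row, col).
    row = (math.isqrt(8 * k + 1) - 1) // 2
    col = k - row * (row + 1) // 2
    return row, col

# Sanity check: the first six indices cover the first three rows.
assert [triangular_idx_to_coord(k) for k in range(6)] == [
    (0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)
]
```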
@@ -581,7 +657,7 @@ class VarlenMTileSchedulerArguments(ParamsBase):
     tile_shape_mn: cutlass.Constexpr[cute.Shape]
     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
     tile_count_semaphore: Optional[cute.Pointer] = None
-    is_persistent: cutlass.Constexpr[bool] = False
+    persistence_mode: cutlass.Constexpr[PersistenceMode] = PersistenceMode.NONE
 
 
 class VarlenMTileScheduler(TileScheduler):
@@ -592,13 +668,13 @@ class VarlenMTileScheduler(TileScheduler):
         cu_seqlens_m: cute.Tensor
         raster_order: cutlass.Constexpr[RasterOrder]
         group_size: Int32
-        group_size_divmod: Optional[FastDivmod]
-        group_size_tail_divmod: Optional[FastDivmod]
-        num_clusters_in_group_divmod: Optional[FastDivmod]
+        group_size_fdd: Optional[FastDivmod]
+        group_size_tail_fdd: Optional[FastDivmod]
+        num_clusters_in_group_fdd: FastDivmod
         tile_shape_mn: cutlass.Constexpr[cute.Shape]
         tile_count_semaphore: Optional[cute.Pointer]
         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
-        is_persistent: cutlass.Constexpr[bool]
+        persistence_mode: cutlass.Constexpr[PersistenceMode]
 
         @staticmethod
         @cute.jit
@@ -621,43 +697,40 @@ class VarlenMTileScheduler(TileScheduler):
             if args.raster_order == RasterOrderOption.AlongM
             else RasterOrder.AlongN  # For Heuristic we also use AlongN
         )
-        ncluster_fast = (
-            problem_shape_ncluster_mn[0]
-            if raster_order == RasterOrder.AlongM
-            else problem_shape_ncluster_mn[1]
-        )
-        ncluster_slow = (
-            problem_shape_ncluster_mn[1]
-            if raster_order == RasterOrder.AlongM
-            else problem_shape_ncluster_mn[0]
-        )
+        ncluster_fast = problem_shape_ncluster_mn[
+            0 if raster_order == RasterOrder.AlongM else 1
+        ]
+        ncluster_slow = problem_shape_ncluster_mn[
+            1 if raster_order == RasterOrder.AlongM else 0
+        ]
         if const_expr(ncluster_fast is not None):
             group_size = min(args.group_size, ncluster_fast)
             group_size_tail = ncluster_fast % group_size
         else:
             group_size, group_size_tail = args.group_size, None
+        num_clusters_in_group = None
         if const_expr(ncluster_slow is not None):
             num_clusters_in_group = group_size * ncluster_slow
-        else:
-            num_clusters_in_group = None
+        if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC):
+            assert args.tile_count_semaphore is not None
         return VarlenMTileScheduler.Params(
             problem_shape_ncluster_mnl,
             args.total_m,
             args.cu_seqlens_m,
             raster_order,
             group_size,
-            FastDivmod
+            FastDivmod(group_size) if ncluster_fast is not None else None,
             # Don't divide by 0
-            FastDivmod
+            FastDivmod(group_size_tail if group_size_tail > 0 else 1)
             if group_size_tail is not None
             else None,
-            FastDivmod
-            if num_clusters_in_group is not None
-            else None,
+            FastDivmod(num_clusters_in_group) if num_clusters_in_group is not None else None,
             args.tile_shape_mn,
-            args.tile_count_semaphore,
+            args.tile_count_semaphore
+            if const_expr(args.persistence_mode == PersistenceMode.DYNAMIC)
+            else None,
             cluster_shape_mn,
-            args.is_persistent,
+            args.persistence_mode,
         )
 
     def __init__(
@@ -666,7 +739,7 @@ class VarlenMTileScheduler(TileScheduler):
         num_tiles_executed: Int32,
         current_batch_idx: Int32,
         num_work_idx_before_cur_batch: Int32,
-        tile_count: Optional[cute.Tensor],
+        sched_smem: Optional[cute.Tensor],
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
         pipeline_state: PipelineStateWAdvance,
         params: Params,
@@ -678,7 +751,7 @@ class VarlenMTileScheduler(TileScheduler):
         self.num_tiles_executed = num_tiles_executed
         self._current_batch_idx = current_batch_idx
         self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
-        self._tile_count = tile_count
+        self._sched_smem = sched_smem
         self._scheduler_pipeline = scheduler_pipeline
         self._pipeline_state = pipeline_state
         self.params = params
@@ -693,9 +766,8 @@ class VarlenMTileScheduler(TileScheduler):
     @cute.jit
     def create(
         params: Params,
-        tile_count: Optional[cute.Tensor] = None,
+        sched_smem: Optional[cute.Tensor] = None,
         scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
-        is_scheduler_warp: bool | Boolean = False,
         *,
         loc=None,
         ip=None,
@@ -703,18 +775,21 @@ class VarlenMTileScheduler(TileScheduler):
         stages = 0
         _, _, bidz = cute.arch.block_idx()
         current_work_linear_idx = Int32(bidz)
-        if const_expr(params.is_persistent):
-            assert tile_count is not None
+        if const_expr(
+            params.persistence_mode
+            in [PersistenceMode.STATIC, PersistenceMode.DYNAMIC, PersistenceMode.CLC]
+        ):
+            assert sched_smem is not None
             assert scheduler_pipeline is not None
-            stages = const_expr(cute.size(tile_count))
+            stages = const_expr(cute.size(sched_smem, mode=[1]))
         return VarlenMTileScheduler(
             current_work_linear_idx,
             Int32(0),  # num_tiles_executed
             Int32(0),  # current_batch_idx
             Int32(0),  # num_work_idx_before_cur_batch
-            tile_count,
+            sched_smem,
             scheduler_pipeline,
-            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(
+            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0)),
             params,
             loc=loc,
             ip=ip,
@@ -733,54 +808,33 @@ class VarlenMTileScheduler(TileScheduler):
         num_batch = params.problem_shape_ncluster_mnl[2]
         total_clusters_m_max = (params.total_m + num_batch * (block_size - 1)) // block_size
         total_clusters_max = total_clusters_m_max * params.problem_shape_ncluster_mnl[1]
-        if const_expr(not params.is_persistent):
+        if const_expr(params.persistence_mode in [PersistenceMode.NONE, PersistenceMode.CLC]):
             return (*params.cluster_shape_mn, total_clusters_max)
         else:
             num_persistent_clusters = cutlass.min(max_active_clusters, total_clusters_max)
             return (*params.cluster_shape_mn, num_persistent_clusters)
 
-    @cute.jit
-    def _get_num_m_blocks(
-        self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int]
-    ) -> Int32:
-        num_batch = self.params.problem_shape_ncluster_mnl[2]
-        batch_idx = lane + bidb_start
-        cur_cu_seqlen = Int32(0)
-        if batch_idx <= num_batch:
-            cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx]
-        next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
-        seqlen = next_cu_seqlen - cur_cu_seqlen
-        return (
-            cute.ceil_div(seqlen, block_size)
-            if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1
-            else Int32(0)
-        )
-
     @cute.jit
     def _swizzle_cta(
         self, cluster_id_in_problem: Int32, num_clusters_m: Int32, *, loc=None, ip=None
     ) -> Tuple[Int32, Int32]:
         params = self.params
         # CTA Swizzle to promote L2 data reuse
-        if const_expr(params.num_clusters_in_group_divmod is not None):
-            group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(
-                cluster_id_in_problem
-            )
-            num_clusters_in_group = params.num_clusters_in_group_divmod.divisor
+        if const_expr(params.num_clusters_in_group_fdd is not None):
+            group_id, id_in_group = divmod(cluster_id_in_problem, params.num_clusters_in_group_fdd)
+            num_clusters_in_group = params.num_clusters_in_group_fdd.divisor
         else:
             assert params.raster_order == RasterOrder.AlongN
             num_clusters_in_group = params.group_size * num_clusters_m
             group_id = cluster_id_in_problem // num_clusters_in_group
             id_in_group = cluster_id_in_problem - group_id * num_clusters_in_group
         cid_fast_in_group, cid_slow = Int32(0), Int32(0)
-        if const_expr(
-            params.group_size_divmod is not None and params.group_size_tail_divmod is not None
-        ):
+        if const_expr(params.group_size_fdd is not None and params.group_size_tail_fdd is not None):
             num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
             if (group_id + 1) * num_clusters_in_group <= num_clusters:
-                cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
+                cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_fdd)
             else:  # tail part
-                cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
+                cid_slow, cid_fast_in_group = divmod(id_in_group, params.group_size_tail_fdd)
         else:
             assert params.raster_order == RasterOrder.AlongM
             group_size_actual = cutlass.min(
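Reviewer's note: `_get_num_m_blocks` (moved later in the class by this diff) converts consecutive `cu_seqlens_m` entries into per-batch m-block counts; on the GPU each of the warp's lanes holds one batch and fetches its neighbor's prefix sum with `shuffle_sync_down` rather than issuing a second global load. A sequential sketch of the same computation:

```python
import math

def num_m_blocks_per_batch(cu_seqlens_m, block_size):
    """Sequential sketch of _get_num_m_blocks: batch b owns rows
    cu_seqlens_m[b]..cu_seqlens_m[b+1] and needs ceil(seqlen / block) m-blocks."""
    return [
        math.ceil((cu_seqlens_m[b + 1] - cu_seqlens_m[b]) / block_size)
        for b in range(len(cu_seqlens_m) - 1)
    ]

assert num_m_blocks_per_batch([0, 300, 300, 700], 128) == [3, 0, 4]
```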
@@ -802,7 +856,26 @@ class VarlenMTileScheduler(TileScheduler):
         return cid_m, cid_n
 
     @cute.jit
-    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+    def _get_num_m_blocks(
+        self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int]
+    ) -> Int32:
+        num_batch = self.params.problem_shape_ncluster_mnl[2]
+        batch_idx = lane + bidb_start
+        cur_cu_seqlen = Int32(0)
+        if batch_idx <= num_batch:
+            cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx]
+        next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
+        seqlen = next_cu_seqlen - cur_cu_seqlen
+        return (
+            cute.ceil_div(seqlen, block_size)
+            if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1
+            else Int32(0)
+        )
+
+    @cute.jit
+    def _delinearize_work_idx(
+        self, *, block_zero_only: bool = False, loc=None, ip=None
+    ) -> cutlass.utils.WorkTileInfo:
         params = self.params
         lane_idx = cute.arch.lane_idx()
         num_batch = self.params.problem_shape_ncluster_mnl[2]
@@ -819,7 +892,6 @@ class VarlenMTileScheduler(TileScheduler):
         )
         problems_end_tile = self._num_work_idx_before_cur_batch + clusters_in_problems
         # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, num_clusters_cumulative = %d, problems_end_tile = %d", self._tile_idx, problems_end_tile, num_clusters_m, num_clusters_cumulative, problems_end_tile)
-        cid_m, cid_n = Int32(0), Int32(0)
         next_tile_idx = self._current_work_linear_idx
         while problems_end_tile <= next_tile_idx:
             batch_idx += cute.arch.WARP_SIZE - 1
@@ -836,11 +908,14 @@ class VarlenMTileScheduler(TileScheduler):
                 num_clusters_cumulative, cute.arch.WARP_SIZE - 1
             )
             problems_end_tile += clusters_in_problems
+        if const_expr(params.persistence_mode == PersistenceMode.NONE):
+            is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch
+        else:
+            is_valid = batch_idx < num_batch
         # Just a placeholer value in case batch_idx >= num_batch
         num_work_idx_before_cur_batch = problems_end_tile - clusters_in_problems
-
-
-        else:
+        cid_m, cid_n = Int32(0), Int32(0)
+        if is_valid:
             problems_start_tile = problems_end_tile - clusters_in_problems
             # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, batch_idx = %d", self._tile_idx, problems_end_tile, num_clusters_m, batch_idx)
             # The next problem to process is the first one that does not have ending tile position
@@ -859,42 +934,21 @@ class VarlenMTileScheduler(TileScheduler):
             num_clusters_m = cute.arch.shuffle_sync(num_clusters_m, batch_idx_in_problems)
             num_work_idx_before_cur_batch = problems_start_tile + num_clusters_prev_lane
             cluster_id_in_problem = next_tile_idx - num_work_idx_before_cur_batch
-            # cid_n = cluster_id_in_problem // num_clusters_m
-            # cid_m = cluster_id_in_problem - cid_n * num_clusters_m
             # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, cid_n=%d, cid_m=%d, is_valid = %d", self._tile_idx, batch_idx, cid_n, cid_m, is_valid)
             cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, num_clusters_m, loc=loc, ip=ip)
             self._current_batch_idx = batch_idx
             self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
 
-        # Get the pid from cluster id
-        bidx_in_cluster = cute.arch.block_in_cluster_idx()
+        if const_expr(block_zero_only):
+            bidx_in_cluster = (Int32(0), Int32(0))
+        else:
+            # Get the pid from cluster id
+            bidx_in_cluster = cute.arch.block_in_cluster_idx()
         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
-        if const_expr(not params.is_persistent):
-            is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch
-        else:
-            is_valid = batch_idx < num_batch
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
 
-    @cute.jit
-    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
-        """is_scheduler_warp should only be true for one warp in the whole cluster"""
-        if const_expr(self.params.tile_count_semaphore is not None):
-            params = self.params
-            current_work_linear_idx = self._current_work_linear_idx
-            if is_scheduler_warp:
-                if cute.arch.lane_idx() == 0:
-                    # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
-                    num_persistent_clusters = cute.arch.grid_dim()[2]
-                    current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
-                        1, params.tile_count_semaphore
-                    )
-                    # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
-                # lane 0 already has the right tile_idx, just need to broadcast
-                current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
-            self._current_work_linear_idx = current_work_linear_idx
-
     def __extract_mlir_values__(self):
         values, self._values_pos = [], []
         for obj in [
@@ -902,7 +956,7 @@ class VarlenMTileScheduler(TileScheduler):
         self.num_tiles_executed,
         self._current_batch_idx,
         self._num_work_idx_before_cur_batch,
-        self._tile_count,
+        self._sched_smem,
         self._scheduler_pipeline,
         self._pipeline_state,
         self.params,
@@ -920,7 +974,7 @@ class VarlenMTileScheduler(TileScheduler):
         self.num_tiles_executed,
         self._current_batch_idx,
         self._num_work_idx_before_cur_batch,
-        self._tile_count,
+        self._sched_smem,
         self._scheduler_pipeline,
         self._pipeline_state,
         self.params,