quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/tile_scheduler.py
ADDED
|
@@ -0,0 +1,937 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
from typing import Tuple, Optional
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from enum import IntEnum
|
|
6
|
+
|
|
7
|
+
import cutlass
|
|
8
|
+
import cutlass.cute as cute
|
|
9
|
+
from cutlass import Int32, Boolean, const_expr
|
|
10
|
+
|
|
11
|
+
import quack.utils as utils
|
|
12
|
+
from quack.fast_math import FastDivmod
|
|
13
|
+
from quack.pipeline import PipelineStateWAdvance
|
|
14
|
+
from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RasterOrderOption(IntEnum):
    """Raster order as requested by the caller.

    ``AlongM`` / ``AlongN`` are taken literally. ``Heuristic`` is resolved per problem by
    ``get_raster_order_from_option``: it picks AlongM when the (group-rounded) cluster count
    along N exceeds the one along M, and AlongN otherwise.
    """

    AlongM = 0  # M cluster index varies fastest
    AlongN = 1  # N cluster index varies fastest
    Heuristic = 2  # Pick AlongM if tiles_n > tiles_m, else AlongN
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RasterOrder(IntEnum):
    """Concrete raster order used by the schedulers (``Heuristic`` already resolved)."""

    AlongM = 0  # consecutive work indices walk along M
    AlongN = 1  # consecutive work indices walk along N
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@cute.jit
def get_raster_order_from_option(
    raster_order_option: RasterOrderOption, problem_shape_ncluster_mn: cute.Shape, group_size: Int32
) -> RasterOrder:
    """Resolve a ``RasterOrderOption`` into a concrete ``RasterOrder``.

    Explicit options map one-to-one. For ``Heuristic``, the cluster counts along M and N are each
    rounded up to a multiple of ``group_size``. The result is AlongM when the rounded N count is
    strictly larger than the rounded M count, and AlongN otherwise (ties go to AlongN).

    Args:
        raster_order_option: The requested option.
        problem_shape_ncluster_mn: ``(clusters along M, clusters along N)``.
        group_size: Swizzle group width in clusters, used only for the heuristic rounding.
    """
    # Anything other than an explicit AlongM request starts out as AlongN.
    raster_order = (
        RasterOrder.AlongN
        if raster_order_option != RasterOrderOption.AlongM
        else RasterOrder.AlongM
    )
    if raster_order_option == RasterOrderOption.Heuristic:
        rounded_m, rounded_n = (
            cute.round_up(problem_shape_ncluster_mn[0], group_size),
            cute.round_up(problem_shape_ncluster_mn[1], group_size),
        )
        raster_order = RasterOrder.AlongM if rounded_n > rounded_m else RasterOrder.AlongN
    return raster_order
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class TileSchedulerOptions(ArgumentsBase):
    """Scheduler-related options grouped together so they can be passed to ``__call__``.

    Only ``max_active_clusters`` is required; every other field has a default.
    """

    # Upper bound on concurrently resident clusters (presumably used to size the persistent grid;
    # confirm against the caller).
    max_active_clusters: Int32
    # Compile-time raster order request; Heuristic lets the problem shape decide.
    raster_order: cutlass.Constexpr[RasterOrderOption] = RasterOrderOption.Heuristic
    # Maximum swizzle group width (NOTE(review): presumably becomes the scheduler's group_size).
    max_swizzle_size: Int32 = Int32(8)
    # Optional atomic counter enabling dynamic persistent scheduling.
    tile_count_semaphore: Optional[cute.Pointer] = None
    # Optional batch-index indirection table.
    batch_idx_permute: Optional[cute.Tensor] = None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass
class TileSchedulerArguments(ArgumentsBase):
    """Host-facing description of the work to schedule.

    Converted into device parameters by ``<Scheduler>.to_underlying_arguments`` /
    ``<Scheduler>.Params.create``.
    """

    # (tiles along M, tiles along N, L); mode 2 is used as the batch index.
    problem_shape_ntile_mnl: cute.Shape
    # Requested raster order, resolved by get_raster_order_from_option.
    raster_order: cutlass.Constexpr[RasterOrderOption]
    # Swizzle group width in clusters (clamped to the fast-dimension cluster count).
    group_size: Int32
    # Cluster shape (M, N, K); the K extent must be 1.
    cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
    # Atomic work counter; only honoured when is_persistent is True.
    tile_count_semaphore: Optional[cute.Pointer] = None
    # Optional table mapping the scheduled batch index to the batch index reported in tile coords.
    batch_idx_permute: Optional[cute.Tensor] = None
    # Whether a reduced grid of persistent clusters loops over all work.
    is_persistent: cutlass.Constexpr[bool] = False
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class TileScheduler:
    """Cluster-granularity tile scheduler with grouped, serpentine swizzling for L2 reuse.

    Work is enumerated by a linear cluster index over ``(ncluster_m, ncluster_n, L)`` and mapped
    to per-CTA tile coordinates ``(pid_m, pid_n, None, batch_idx)``. Three modes exist:

    * non-persistent (``is_persistent=False``): the grid covers every CTA of the problem and each
      cluster handles exactly one work item;
    * static persistent (``is_persistent=True``, no semaphore): ``grid_dim()[2]`` persistent
      clusters stride through the linear work indices;
    * dynamic persistent (``is_persistent=True`` with ``tile_count_semaphore``): a single scheduler
      warp per cluster claims the next index from an atomic counter and broadcasts it to every CTA
      of the cluster through the shared-memory ``tile_count`` buffer guarded by
      ``scheduler_pipeline``.
    """

    @dataclass
    class Params(ParamsBase):
        """Device-side scheduler parameters, built by ``Params.create``."""

        # (clusters along M, clusters along N, L).
        problem_shape_ncluster_mnl: cute.Shape
        # Concrete raster order resolved from the requested RasterOrderOption.
        raster_order: RasterOrder
        # Splits a linear work index into (batch index, cluster index within the problem).
        num_clusters_per_problem_divmod: FastDivmod
        # Number of full-width swizzle groups along the fast raster dimension.
        num_groups_regular: Int32
        # Width (in clusters) of a regular swizzle group.
        group_size_divmod: FastDivmod
        # Width of the trailing partial group (1 when there is none, to avoid dividing by 0).
        group_size_tail_divmod: FastDivmod
        # Clusters per regular group: group_size * (clusters along the slow dimension).
        num_clusters_in_group_divmod: FastDivmod
        # Atomic work counter; None unless persistent (None => static scheduling).
        tile_count_semaphore: Optional[cute.Pointer]
        # Optional indirection applied to the batch index.
        batch_idx_permute: Optional[cute.Tensor]
        cluster_shape_mn: cutlass.Constexpr[cute.Shape]
        is_persistent: cutlass.Constexpr[bool]

        @staticmethod
        @cute.jit
        def create(args: TileSchedulerArguments, *, loc=None, ip=None) -> "TileScheduler.Params":
            """Derive cluster counts, raster order and swizzle-group divisors from ``args``."""
            # Clusters are 2D in (M, N); the K extent of the cluster must be 1.
            assert args.cluster_shape_mnk[2] == 1
            cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
            problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
            problem_shape_ncluster_mn = cute.ceil_div(problem_shape_ntile_mn, cluster_shape_mn)
            problem_shape_ncluster_mnl = problem_shape_ncluster_mn + (
                args.problem_shape_ntile_mnl[2],
            )
            num_clusters_per_problem = cute.size(problem_shape_ncluster_mn)
            raster_order = get_raster_order_from_option(
                args.raster_order, problem_shape_ncluster_mn, args.group_size
            )
            # "fast" is the dimension that varies fastest within a swizzle group.
            ncluster_fast = (
                problem_shape_ncluster_mn[0]
                if raster_order == RasterOrder.AlongM
                else problem_shape_ncluster_mn[1]
            )
            ncluster_slow = (
                problem_shape_ncluster_mn[1]
                if raster_order == RasterOrder.AlongM
                else problem_shape_ncluster_mn[0]
            )
            # Groups are group_size clusters wide along the fast dimension and span the whole
            # slow dimension; the last group may be narrower (group_size_tail).
            group_size = min(args.group_size, ncluster_fast)
            group_size_tail = ncluster_fast % group_size
            num_groups_regular = ncluster_fast // group_size
            num_clusters_in_group = group_size * ncluster_slow
            return TileScheduler.Params(
                problem_shape_ncluster_mnl,
                raster_order,
                FastDivmod.create(num_clusters_per_problem),
                num_groups_regular,
                FastDivmod.create(group_size),
                # Don't divide by 0
                FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
                FastDivmod.create(num_clusters_in_group),
                args.tile_count_semaphore if const_expr(args.is_persistent) else None,
                args.batch_idx_permute,
                cluster_shape_mn,
                args.is_persistent,
            )

    def __init__(
        self,
        current_work_linear_idx: Int32,
        num_tiles_executed: Int32,
        tile_count: Optional[cute.Tensor],
        scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
        pipeline_state: PipelineStateWAdvance,
        params: Params,
        *,
        loc=None,
        ip=None,
    ):
        """Store scheduler state; use ``TileScheduler.create`` inside kernels instead."""
        self._current_work_linear_idx = current_work_linear_idx
        self._num_tiles_executed = num_tiles_executed
        self._tile_count = tile_count
        self._scheduler_pipeline = scheduler_pipeline
        self._pipeline_state = pipeline_state
        self.params = params
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        """Convert host arguments into device ``Params``."""
        return TileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(
        params: Params,
        tile_count: Optional[cute.Tensor] = None,
        scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
        is_scheduler_warp: bool | Boolean = False,
        *,
        loc=None,
        ip=None,
    ) -> "TileScheduler":
        """Build the per-CTA scheduler state at kernel start.

        is_scheduler_warp should only be true for one warp in the whole cluster.
        ``tile_count`` and ``scheduler_pipeline`` are required when the params carry a
        ``tile_count_semaphore`` (dynamic persistent mode); ``cute.size(tile_count)`` sets the
        number of pipeline stages.
        """
        stages = 0
        if const_expr(not params.is_persistent):
            # Linearize the 2D cluster index of the launch grid.
            cidx, cidy, _ = cute.arch.cluster_idx()
            cdimx, _, _ = cute.arch.cluster_dim()
            cluster_id = cidx + cidy * cdimx
            current_work_linear_idx = Int32(cluster_id)
        else:
            # Persistent clusters are laid out along grid z; start at our own z index.
            _, _, bidz = cute.arch.block_idx()
            current_work_linear_idx = Int32(bidz)
        if const_expr(params.tile_count_semaphore is not None):
            assert tile_count is not None
            assert scheduler_pipeline is not None
            stages = const_expr(cute.size(tile_count))
        return TileScheduler(
            current_work_linear_idx,
            Int32(0),  # num_tiles_executed
            tile_count,
            scheduler_pipeline,
            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
            params,
            loc=loc,
            ip=ip,
        )

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        max_active_clusters: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Return the launch grid in CTAs.

        Non-persistent: every CTA of the problem, ``(ctas_m, ctas_n, L)``.
        Persistent: ``(cluster_m, cluster_n, num_persistent_clusters)``, where the number of
        persistent clusters is bounded by both the problem size and ``max_active_clusters``.
        """
        num_ctas_mnl = tuple(
            x * y for x, y in zip(params.problem_shape_ncluster_mnl, params.cluster_shape_mn)
        ) + (params.problem_shape_ncluster_mnl[2],)
        if const_expr(not params.is_persistent):
            return num_ctas_mnl
        else:
            num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
            num_ctas_per_cluster = cute.size(params.cluster_shape_mn, loc=loc, ip=ip)
            # Total ctas that can run in one wave
            num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
            num_persistent_ctas = cutlass.min(num_ctas_in_problem, num_ctas_per_wave)
            num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
            return (*params.cluster_shape_mn, num_persistent_clusters)

    @cute.jit
    def _swizzle_cta(
        self, cluster_id_in_problem: Int32, *, loc=None, ip=None
    ) -> Tuple[Int32, Int32]:
        """Map a cluster index within one problem to ``(cid_m, cid_n)`` cluster coordinates.

        Clusters are walked group by group (``group_size`` wide along the fast dimension). Within a
        group the fast index varies fastest, and odd groups traverse the slow dimension in reverse
        (serpentine order).
        """
        # CTA Swizzle to promote L2 data reuse
        params = self.params
        group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(cluster_id_in_problem)
        cid_fast_in_group, cid_slow = Int32(0), Int32(0)
        if group_id < params.num_groups_regular:
            cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
        else:  # tail part
            cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
        if group_id % 2 == 1:  # serpentine order
            ncluster_slow = (
                params.problem_shape_ncluster_mnl[1]
                if params.raster_order == RasterOrder.AlongM
                else params.problem_shape_ncluster_mnl[0]
            )
            cid_slow = ncluster_slow - 1 - cid_slow
        cid_fast = group_id * params.group_size_divmod.divisor + cid_fast_in_group
        cid_m, cid_n = cid_fast, cid_slow
        if params.raster_order == RasterOrder.AlongN:
            cid_m, cid_n = cid_slow, cid_fast
        return cid_m, cid_n

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
        """Return the tile coordinates ``(pid_m, pid_n, None, batch_idx)`` and their validity."""
        params = self.params
        if const_expr(not params.is_persistent):
            cluster_id_in_problem = self._current_work_linear_idx
            _, _, bidz = cute.arch.block_idx()
        else:
            bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
                self._current_work_linear_idx
            )
        cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip)
        # Get the pid from cluster id
        bidx_in_cluster = cute.arch.block_in_cluster_idx()
        pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
        pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
        batch_idx = (
            bidz if const_expr(params.batch_idx_permute is None) else params.batch_idx_permute[bidz]
        )
        tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
        if const_expr(not params.is_persistent):
            # Non-persistent clusters handle exactly one tile.
            is_valid = self._num_tiles_executed == 0
        else:
            is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
        return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        """Alias of ``get_current_work`` for the first tile."""
        return self.get_current_work(loc=loc, ip=ip)

    @cute.jit
    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
        """Claim the next work index from the atomic counter (dynamic persistent mode only).

        is_scheduler_warp should only be true for one warp in the whole cluster. Lane 0 of that
        warp performs the atomic increment and the result is shuffled to the rest of the warp.
        The increment wraps at the total number of work clusters, so after every cluster has
        fetched once past the end, the counter is back at 0.
        """
        params = self.params
        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
            current_work_linear_idx = self._current_work_linear_idx
            if is_scheduler_warp:
                if cute.arch.lane_idx() == 0:
                    num_persistent_clusters = cute.arch.grid_dim()[2]
                    # The first num_persistent_clusters indices were handed out by block index.
                    current_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32(
                        cute.size(params.problem_shape_ncluster_mnl) - 1,
                        params.tile_count_semaphore,
                    )
                # lane 0 already has the right tile_idx, just need to broadcast
                current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
            self._current_work_linear_idx = current_work_linear_idx

    # We have to split broadcast_next_work and advance_to_next_work into two functions
    # due to a bug in cute-dsl 4.2: https://github.com/NVIDIA/cutlass/issues/2647
    @cute.jit
    def broadcast_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
        """Publish the fetched work index to every CTA of the cluster (dynamic persistent only).

        is_scheduler_warp should only be true for one warp in the whole cluster. Lane ``r`` of the
        scheduler warp writes the index into CTA ``r``'s ``tile_count`` slot. For a 1-CTA cluster
        this is a local store plus producer_commit; otherwise it is a remote shared-memory store
        that completes a 4-byte transaction on the pipeline's mbarrier.
        """
        params = self.params
        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
            current_work_linear_idx = self._current_work_linear_idx
            if is_scheduler_warp:
                self._scheduler_pipeline.producer_acquire(self._pipeline_state)
                lane_idx = cute.arch.lane_idx()
                if lane_idx < cute.size(params.cluster_shape_mn):
                    if const_expr(cute.size(params.cluster_shape_mn) == 1):
                        self._tile_count[self._pipeline_state.index] = current_work_linear_idx
                        self._scheduler_pipeline.producer_commit(self._pipeline_state)
                    else:
                        peer_cta_rank_in_cluster = lane_idx
                        mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
                            self._pipeline_state
                        )
                        # 4 bytes == one Int32 work index.
                        cute.arch.mbarrier_arrive_and_expect_tx(
                            mbar_ptr, 4, peer_cta_rank_in_cluster
                        )
                        utils.store_shared_remote(
                            val=current_work_linear_idx,
                            smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
                            mbar_ptr=mbar_ptr,
                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
                        )

    @cute.jit
    def advance_to_next_work(
        self,
        is_scheduler_warp: bool | Boolean = False,
        *,
        advance_count: int = 1,
        loc=None,
        ip=None,
    ):
        """Move this scheduler ``advance_count`` tiles forward.

        Static persistent: stride by ``advance_count * grid_dim()[2]``. Dynamic persistent:
        non-scheduler warps wait on the pipeline, read the broadcast index from ``tile_count``
        and release the slot; the scheduler warp keeps the index it fetched itself.
        ``num_tiles_executed`` is incremented in every mode.
        """
        params = self.params
        if const_expr(params.is_persistent):
            num_persistent_clusters = cute.arch.grid_dim()[2]
            if const_expr(params.tile_count_semaphore is None):  # Static persistent
                self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters)
            else:  # Dynamic persistent
                if const_expr(advance_count > 1):
                    self._pipeline_state.advance_iters(advance_count - 1)
                current_work_linear_idx = self._current_work_linear_idx
                if not is_scheduler_warp:
                    self._scheduler_pipeline.consumer_wait(self._pipeline_state)
                    current_work_linear_idx = self._tile_count[self._pipeline_state.index]
                    # All lanes must have read the slot before one lane releases it.
                    cute.arch.sync_warp()
                    with cute.arch.elect_one():
                        self._scheduler_pipeline.consumer_release(self._pipeline_state)
                self._current_work_linear_idx = current_work_linear_idx
                self._pipeline_state.advance()
        self._num_tiles_executed += Int32(advance_count)

    def producer_tail(self):
        """Drain the scheduler pipeline (dynamic persistent mode only)."""
        if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None):
            self._scheduler_pipeline.producer_tail(self._pipeline_state)

    @property
    def num_tiles_executed(self) -> Int32:
        """Number of tiles this scheduler has advanced past."""
        return self._num_tiles_executed

    def __extract_mlir_values__(self):
        """Flatten the dynamic state into MLIR values, remembering each field's value count."""
        values, self._values_pos = [], []
        for obj in [
            self._current_work_linear_idx,
            self._num_tiles_executed,
            self._tile_count,
            self._scheduler_pipeline,
            self._pipeline_state,
            self.params,
        ]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild a scheduler from values produced by ``__extract_mlir_values__``."""
        obj_list = []
        for obj, n_items in zip(
            [
                self._current_work_linear_idx,
                self._num_tiles_executed,
                self._tile_count,
                self._scheduler_pipeline,
                self._pipeline_state,
                self.params,
            ],
            self._values_pos,
        ):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        # Propagate both loc and ip so the rebuilt object carries the same context.
        return self.__class__(*(tuple(obj_list)), loc=self._loc, ip=self._ip)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
@cute.jit
def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
    """
    Convert a linear index into a row-major lower triangle (diagonal included) to ``(row, col)``.

    Row ``r`` owns indices ``r*(r+1)/2 .. r*(r+1)/2 + r``, so ``col <= row``. For example,
    0 -> (0, 0), 1 -> (1, 0), 2 -> (1, 1), 3 -> (2, 0).
    """
    # ceil(sqrt(2*idx + 9/4) - 1/2) - 1 is the largest r with r*(r+1)/2 <= idx.
    shifted_root = utils.sqrt(2 * idx + 2.25) - 0.5
    row = utils.ceil(shifted_root) - 1
    # Linear index of the first element of this row.
    row_first_idx = row * (row + 1) // 2
    col = idx - row_first_idx
    return row, col
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class TriangularTileScheduler(TileScheduler):
    """Scheduler that visits only the lower triangle (``cid_n <= cid_m``) of the cluster grid.

    We assume the tile size per cluster is square (e.g., 128 x 256 per CTA, with cluster 2 x 1).
    Each problem has ``ncluster_m * (ncluster_m + 1) / 2`` work clusters. They are grouped into
    bands of ``group_size`` cluster rows. Within a band, the rectangular part left of the diagonal
    is walked in ``group_size``-wide column blocks, followed by the triangle on the diagonal.
    """

    @dataclass
    class Params(ParamsBase):
        """Device-side parameters for the triangular scheduler."""

        # (clusters along M, clusters along N, L).
        problem_shape_ncluster_mnl: cute.Shape
        # Divisor = triangular work clusters per problem (ncluster_m * (ncluster_m + 1) / 2).
        num_clusters_per_problem_divmod: FastDivmod
        # 1 / group_size, used to find the row band from the float row estimate.
        group_size_inv_f32: cutlass.Float32
        # Number of full-height row bands.
        num_groups_regular: Int32
        group_size_divmod: FastDivmod
        # Height of the trailing partial band (1 when there is none, to avoid dividing by 0).
        group_size_tail_divmod: FastDivmod
        # Clusters in one group_size-wide column block of a regular band.
        group_size_mul_group_size_divmod: FastDivmod
        # Clusters in one group_size-wide column block of the tail band.
        group_size_tail_mul_group_size_divmod: FastDivmod
        tile_count_semaphore: Optional[cute.Pointer]
        cluster_shape_mn: cutlass.Constexpr[cute.Shape]
        is_persistent: cutlass.Constexpr[bool]

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "TriangularTileScheduler.Params":
            """Derive triangular work counts and band divisors from ``args``."""
            assert args.cluster_shape_mnk[2] == 1
            cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
            problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
            problem_shape_ncluster_mn = cute.ceil_div(problem_shape_ntile_mn, cluster_shape_mn)
            problem_shape_ncluster_mnl = problem_shape_ncluster_mn + (
                args.problem_shape_ntile_mnl[2],
            )
            cluster_m = problem_shape_ncluster_mn[0]
            # Assume that each cluster is responsible for a square tile
            num_clusters_per_problem = cluster_m * (cluster_m + 1) // 2
            group_size = min(args.group_size, cluster_m)
            group_size_tail = cluster_m % group_size
            num_groups_regular = cluster_m // group_size
            return TriangularTileScheduler.Params(
                problem_shape_ncluster_mnl,
                FastDivmod.create(num_clusters_per_problem),
                cutlass.Float32(1.0 / group_size),
                num_groups_regular,
                FastDivmod.create(group_size),
                # Don't divide by 0
                FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
                FastDivmod.create(group_size * group_size),
                FastDivmod.create((group_size_tail if group_size_tail > 0 else 1) * group_size),
                args.tile_count_semaphore if const_expr(args.is_persistent) else None,
                cluster_shape_mn,
                args.is_persistent,
            )

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        """Convert host arguments into device ``Params``."""
        return TriangularTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(
        params: Params,
        tile_count: Optional[cute.Tensor] = None,
        scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
        is_scheduler_warp: bool | Boolean = False,
        *,
        loc=None,
        ip=None,
    ) -> "TriangularTileScheduler":
        """Build the per-CTA scheduler state at kernel start.

        is_scheduler_warp should only be true for one warp in the whole cluster.
        """
        stages = 0
        if const_expr(not params.is_persistent):
            # The non-persistent grid is 1D in clusters (see get_grid_shape).
            cluster_id, _, _ = cute.arch.cluster_idx()
            current_work_linear_idx = Int32(cluster_id)
        else:
            _, _, bidz = cute.arch.block_idx()
            current_work_linear_idx = Int32(bidz)
        if const_expr(params.tile_count_semaphore is not None):
            assert tile_count is not None
            assert scheduler_pipeline is not None
            stages = const_expr(cute.size(tile_count))
        return TriangularTileScheduler(
            current_work_linear_idx,
            Int32(0),  # num_tiles_executed
            tile_count,
            scheduler_pipeline,
            PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)),
            params,
            loc=loc,
            ip=ip,
        )

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        max_active_clusters: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Return the launch grid in CTAs.

        Non-persistent: ``(n_tri * cluster_m, cluster_n, L)``, with the triangular work clusters
        laid out along x. Persistent: ``(cluster_m, cluster_n, num_persistent_clusters)``.
        """
        clusters = (
            params.num_clusters_per_problem_divmod.divisor,
            1,
            params.problem_shape_ncluster_mnl[2],
        )
        num_ctas_mnl = tuple(x * y for x, y in zip(clusters, params.cluster_shape_mn)) + (
            params.problem_shape_ncluster_mnl[2],
        )
        if const_expr(not params.is_persistent):
            return num_ctas_mnl
        else:
            num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
            num_ctas_per_cluster = cute.size(params.cluster_shape_mn, loc=loc, ip=ip)
            # Total ctas that can run in one wave
            num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
            num_persistent_ctas = cutlass.min(num_ctas_in_problem, num_ctas_per_wave)
            num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
            return (*params.cluster_shape_mn, num_persistent_clusters)

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
        """Map the current linear work index to lower-triangular tile coordinates."""
        params = self.params
        if const_expr(not params.is_persistent):
            cluster_id_in_problem = self._current_work_linear_idx
            _, _, bidz = cute.arch.block_idx()
        else:
            bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod(
                self._current_work_linear_idx
            )
        # CTA Swizzle to promote L2 data reuse
        group_size = params.group_size_divmod.divisor
        # Row band containing this cluster: the row is recovered in closed form (same formula as
        # triangular_idx_to_coord), then divided by group_size.
        group_id = (
            utils.ceil(
                (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
            )
            - 1
        )
        cid_m_start = group_id * group_size
        # Rows above this band hold cid_m_start * (cid_m_start + 1) / 2 clusters.
        id_in_group = cluster_id_in_problem - (cid_m_start * (cid_m_start + 1)) // 2
        group_size_actual = (
            group_size
            if group_id < params.num_groups_regular
            else params.group_size_tail_divmod.divisor
        )
        group_col, group_remainder = Int32(0), Int32(0)
        if group_id < params.num_groups_regular:
            group_col, group_remainder = params.group_size_mul_group_size_divmod.divmod(id_in_group)
        else:  # tail part
            group_col, group_remainder = params.group_size_tail_mul_group_size_divmod.divmod(
                id_in_group
            )
        cid_m_in_group, cid_n_in_group = Int32(0), Int32(0)
        # The rectangular part left of the diagonal holds group_size_actual * cid_m_start clusters;
        # past that lies the triangle on the diagonal.
        if id_in_group >= group_size_actual * group_size * group_id:  # triangular tail
            cid_m_in_group, cid_n_in_group = triangular_idx_to_coord(group_remainder)
        else:
            if group_id < params.num_groups_regular:
                cid_n_in_group, cid_m_in_group = params.group_size_divmod.divmod(group_remainder)
            else:
                cid_n_in_group, cid_m_in_group = params.group_size_tail_divmod.divmod(
                    group_remainder
                )
        cid_m = cid_m_start + cid_m_in_group
        cid_n = group_col * group_size + cid_n_in_group

        # Get the pid from cluster id
        bidx_in_cluster = cute.arch.block_in_cluster_idx()
        pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
        pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
        tile_coord_mnkl = (pid_m, pid_n, None, bidz)
        if const_expr(not params.is_persistent):
            is_valid = self._num_tiles_executed == 0
        else:
            is_valid = (
                self._current_work_linear_idx
                < params.num_clusters_per_problem_divmod.divisor
                * params.problem_shape_ncluster_mnl[2]
            )
        return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)

    @cute.jit
    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
        """Claim the next work index from the atomic counter (dynamic persistent mode only).

        is_scheduler_warp should only be true for one warp in the whole cluster.
        This overrides the base implementation, which bounds the wrapping increment by the
        rectangular cluster count ``cute.size(problem_shape_ncluster_mnl)``. Here the work count
        is the triangular ``num_clusters_per_problem * L``, the same bound ``get_current_work``
        uses for validity. This way the counter wraps back to 0 after exactly one pass.
        """
        params = self.params
        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
            current_work_linear_idx = self._current_work_linear_idx
            if is_scheduler_warp:
                if cute.arch.lane_idx() == 0:
                    num_persistent_clusters = cute.arch.grid_dim()[2]
                    num_work_clusters = (
                        params.num_clusters_per_problem_divmod.divisor
                        * params.problem_shape_ncluster_mnl[2]
                    )
                    current_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32(
                        num_work_clusters - 1,
                        params.tile_count_semaphore,
                    )
                # lane 0 already has the right tile_idx, just need to broadcast
                current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
            self._current_work_linear_idx = current_work_linear_idx
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
@dataclass
class VarlenMTileSchedulerArguments(ParamsBase):
    """Arguments for ``VarlenMTileScheduler``, where the M extent varies per batch.

    Mode 0 of ``problem_shape_ntile_mnl`` is not used as a fixed tile count (the scheduler's
    ``Params.create`` replaces the M cluster count with ``None``).
    """

    # (tiles along M -- unused for varlen, tiles along N, L).
    problem_shape_ntile_mnl: cute.Shape
    # Total rows across all batches (NOTE(review): presumably sum of per-batch M; confirm).
    total_m: Int32
    # Presumably cumulative per-batch row offsets along M -- TODO confirm layout with caller.
    cu_seqlens_m: cute.Tensor
    # Requested raster order; Heuristic is treated as AlongN by this scheduler.
    raster_order: cutlass.Constexpr[RasterOrderOption]
    # Swizzle group width in clusters.
    group_size: Int32
    # Per-CTA tile shape (M, N).
    tile_shape_mn: cutlass.Constexpr[cute.Shape]
    # Cluster shape (M, N, K); the K extent must be 1.
    cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
    # Atomic work counter; only honoured when is_persistent is True.
    tile_count_semaphore: Optional[cute.Pointer] = None
    is_persistent: cutlass.Constexpr[bool] = False
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
class VarlenMTileScheduler(TileScheduler):
|
|
593
|
+
@dataclass
|
|
594
|
+
class Params(ParamsBase):
|
|
595
|
+
problem_shape_ncluster_mnl: cute.Shape
|
|
596
|
+
total_m: Int32
|
|
597
|
+
cu_seqlens_m: cute.Tensor
|
|
598
|
+
raster_order: cutlass.Constexpr[RasterOrder]
|
|
599
|
+
group_size: Int32
|
|
600
|
+
group_size_divmod: Optional[FastDivmod]
|
|
601
|
+
group_size_tail_divmod: Optional[FastDivmod]
|
|
602
|
+
num_clusters_in_group_divmod: FastDivmod
|
|
603
|
+
tile_shape_mn: cutlass.Constexpr[cute.Shape]
|
|
604
|
+
tile_count_semaphore: Optional[cute.Pointer]
|
|
605
|
+
cluster_shape_mn: cutlass.Constexpr[cute.Shape]
|
|
606
|
+
is_persistent: cutlass.Constexpr[bool]
|
|
607
|
+
|
|
608
|
+
@staticmethod
|
|
609
|
+
@cute.jit
|
|
610
|
+
def create(
|
|
611
|
+
args: TileSchedulerArguments, *, loc=None, ip=None
|
|
612
|
+
) -> "VarlenMTileScheduler.Params":
|
|
613
|
+
assert args.cluster_shape_mnk[2] == 1
|
|
614
|
+
cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
|
|
615
|
+
# problem_shape_ntile_mnl[0] will be None for VarlenM
|
|
616
|
+
problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
|
|
617
|
+
problem_shape_ncluster_mn = (
|
|
618
|
+
None,
|
|
619
|
+
cute.ceil_div(problem_shape_ntile_mn[1], cluster_shape_mn[1]),
|
|
620
|
+
)
|
|
621
|
+
problem_shape_ncluster_mnl = problem_shape_ncluster_mn + (
|
|
622
|
+
args.problem_shape_ntile_mnl[2],
|
|
623
|
+
)
|
|
624
|
+
raster_order = const_expr(
|
|
625
|
+
RasterOrder.AlongM
|
|
626
|
+
if args.raster_order == RasterOrderOption.AlongM
|
|
627
|
+
else RasterOrder.AlongN # For Heuristic we also use AlongN
|
|
628
|
+
)
|
|
629
|
+
ncluster_fast = (
|
|
630
|
+
problem_shape_ncluster_mn[0]
|
|
631
|
+
if raster_order == RasterOrder.AlongM
|
|
632
|
+
else problem_shape_ncluster_mn[1]
|
|
633
|
+
)
|
|
634
|
+
ncluster_slow = (
|
|
635
|
+
problem_shape_ncluster_mn[1]
|
|
636
|
+
if raster_order == RasterOrder.AlongM
|
|
637
|
+
else problem_shape_ncluster_mn[0]
|
|
638
|
+
)
|
|
639
|
+
if const_expr(ncluster_fast is not None):
|
|
640
|
+
group_size = min(args.group_size, ncluster_fast)
|
|
641
|
+
group_size_tail = ncluster_fast % group_size
|
|
642
|
+
else:
|
|
643
|
+
group_size, group_size_tail = args.group_size, None
|
|
644
|
+
if const_expr(ncluster_slow is not None):
|
|
645
|
+
num_clusters_in_group = group_size * ncluster_slow
|
|
646
|
+
else:
|
|
647
|
+
num_clusters_in_group = None
|
|
648
|
+
return VarlenMTileScheduler.Params(
|
|
649
|
+
problem_shape_ncluster_mnl,
|
|
650
|
+
args.total_m,
|
|
651
|
+
args.cu_seqlens_m,
|
|
652
|
+
raster_order,
|
|
653
|
+
group_size,
|
|
654
|
+
FastDivmod.create(group_size) if ncluster_fast is not None else None,
|
|
655
|
+
# Don't divide by 0
|
|
656
|
+
FastDivmod.create(group_size_tail if group_size_tail > 0 else 1)
|
|
657
|
+
if group_size_tail is not None
|
|
658
|
+
else None,
|
|
659
|
+
FastDivmod.create(num_clusters_in_group)
|
|
660
|
+
if num_clusters_in_group is not None
|
|
661
|
+
else None,
|
|
662
|
+
args.tile_shape_mn,
|
|
663
|
+
args.tile_count_semaphore if const_expr(args.is_persistent) else None,
|
|
664
|
+
cluster_shape_mn,
|
|
665
|
+
args.is_persistent,
|
|
666
|
+
)
|
|
667
|
+
|
|
668
|
+
def __init__(
    self,
    current_work_linear_idx: Int32,
    num_tiles_executed: Int32,
    current_batch_idx: Int32,
    num_work_idx_before_cur_batch: Int32,
    tile_count: Optional[cute.Tensor],
    scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync],
    pipeline_state: PipelineStateWAdvance,
    params: Params,
    *,
    loc=None,
    ip=None,
):
    """Hold the varlen-M scheduler's static params and per-CTA dynamic state.

    Normally built through ``create`` rather than called directly.

    Args:
        current_work_linear_idx: linear work (cluster) index to decode next.
        num_tiles_executed: tiles executed so far; ``get_current_work`` treats a
            non-zero value as "exhausted" in the non-persistent case.
        current_batch_idx: first batch of the cached search window used by
            ``get_current_work``.
        num_work_idx_before_cur_batch: number of work indices that precede
            ``current_batch_idx``.
        tile_count: optional tensor; ``create`` requires it together with
            ``scheduler_pipeline`` when a tile-count semaphore is in use.
        scheduler_pipeline: optional pipeline paired with ``tile_count``.
        pipeline_state: state object for ``scheduler_pipeline``.
        params: ``VarlenMTileScheduler.Params`` for this launch.
        loc, ip: MLIR location / insertion point stored on the instance.
    """
    # Static launch parameters.
    self.params = params
    # Work-index progress.
    self._current_work_linear_idx = current_work_linear_idx
    self._num_tiles_executed = num_tiles_executed
    # Cached position of the batch search window.
    self._current_batch_idx = current_batch_idx
    self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
    # Optional dynamic-scheduling plumbing.
    self._tile_count = tile_count
    self._scheduler_pipeline = scheduler_pipeline
    self._pipeline_state = pipeline_state
    # MLIR builder context.
    self._loc = loc
    self._ip = ip
|
|
692
|
+
|
|
693
|
+
@staticmethod
def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
    """Convert scheduler arguments into ``VarlenMTileScheduler.Params``.

    Thin forwarder to ``VarlenMTileScheduler.Params.create``; ``loc`` and ``ip``
    are passed through unchanged.
    """
    params_cls = VarlenMTileScheduler.Params
    return params_cls.create(args, loc=loc, ip=ip)
|
|
696
|
+
|
|
697
|
+
@staticmethod
@cute.jit
def create(
    params: Params,
    tile_count: Optional[cute.Tensor] = None,
    scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None,
    is_scheduler_warp: bool | Boolean = False,
    *,
    loc=None,
    ip=None,
) -> "VarlenMTileScheduler":
    """Build the scheduler for the calling CTA.

    The first work index is this CTA's ``block_idx()[2]``. The executed-tile
    count, batch index and preceding-work count all start at zero.

    When ``params.tile_count_semaphore`` is set, both ``tile_count`` and
    ``scheduler_pipeline`` must be provided. The pipeline state then gets one
    stage per element of ``tile_count``; otherwise it has zero stages. The last
    pipeline-state field is 1 for the scheduler warp and 0 for all others.
    """
    _, _, bidz = cute.arch.block_idx()
    num_stages = 0
    if const_expr(params.tile_count_semaphore is not None):
        # Dynamic scheduling needs both the tile-count tensor and the pipeline.
        assert tile_count is not None
        assert scheduler_pipeline is not None
        num_stages = const_expr(cute.size(tile_count))
    pipeline_state = PipelineStateWAdvance(
        num_stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)
    )
    return VarlenMTileScheduler(
        Int32(bidz),  # current_work_linear_idx
        Int32(0),  # num_tiles_executed
        Int32(0),  # current_batch_idx
        Int32(0),  # num_work_idx_before_cur_batch
        tile_count,
        scheduler_pipeline,
        pipeline_state,
        params,
        loc=loc,
        ip=ip,
    )
|
|
727
|
+
|
|
728
|
+
# called by host
@staticmethod
def get_grid_shape(
    params: Params,
    max_active_clusters: Int32,
    *,
    loc=None,
    ip=None,
) -> Tuple[Int32, Int32, Int32]:
    """Return the launch grid ``(*cluster_shape_mn, num_clusters)``.

    ``num_clusters`` is an upper bound on the total cluster work over all
    batches. With ``B`` the M extent of one cluster, and assuming ``total_m``
    is the sum of the batch lengths, sum_b ceil(len_b / B) is at most
    ``(total_m + num_batch * (B - 1)) // B``. That bound is multiplied by the
    number of N clusters. Persistent launches cap it at ``max_active_clusters``.
    """
    num_batch = params.problem_shape_ncluster_mnl[2]
    m_per_cluster = params.tile_shape_mn[0] * params.cluster_shape_mn[0]
    # Each batch rounds up by at most (m_per_cluster - 1) rows.
    max_clusters_m = (params.total_m + num_batch * (m_per_cluster - 1)) // m_per_cluster
    max_clusters = max_clusters_m * params.problem_shape_ncluster_mnl[1]
    num_clusters = max_clusters
    if const_expr(params.is_persistent):
        num_clusters = cutlass.min(max_active_clusters, max_clusters)
    return (*params.cluster_shape_mn, num_clusters)
|
|
746
|
+
|
|
747
|
+
@cute.jit
def _get_num_m_blocks(
    self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int]
) -> Int32:
    """Per-lane number of M clusters for one batch in a window of batches.

    Lane ``lane`` handles batch ``bidb_start + lane``. Each lane loads
    ``cu_seqlens_m[batch_idx]`` and reads the next lane's value with a
    shuffle-down; the difference is its batch's sequence length. Indices up to
    ``num_batch`` are read, so ``cu_seqlens_m`` presumably holds
    ``num_batch + 1`` prefix offsets (TODO confirm).

    The function uses a warp shuffle, so every lane of the warp should call it.

    Args:
        lane: this thread's lane index within the warp.
        bidb_start: batch index handled by lane 0.
        block_size: M extent of one cluster (callers in this file pass
            ``tile_shape_mn[0] * cluster_shape_mn[0]``).

    Returns:
        ``ceil_div(seqlen, block_size)`` for lanes with a valid batch. Returns 0
        for lanes past ``num_batch``, and for the last lane of the warp, which
        has no higher lane to supply the next offset. As a result, one call
        covers at most ``WARP_SIZE - 1`` batches.
    """
    num_batch = self.params.problem_shape_ncluster_mnl[2]
    batch_idx = lane + bidb_start
    cur_cu_seqlen = Int32(0)
    # The lane at batch_idx == num_batch loads the final offset so that the
    # lane below it can form the last batch's length.
    if batch_idx <= num_batch:
        cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx]
    next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
    seqlen = next_cu_seqlen - cur_cu_seqlen
    return (
        cute.ceil_div(seqlen, block_size)
        if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1
        else Int32(0)
    )
|
|
763
|
+
|
|
764
|
+
@cute.jit
def _swizzle_cta(
    self, cluster_id_in_problem: Int32, num_clusters_m: Int32, *, loc=None, ip=None
) -> Tuple[Int32, Int32]:
    """Map a linear cluster id within one batch to ``(cid_m, cid_n)``.

    Clusters are visited in groups of ``params.group_size`` along the fast
    raster dimension: M for ``RasterOrder.AlongM``, N otherwise. Inside a
    group, the fast index varies quickest. Every odd group walks the slow
    dimension in reverse (serpentine order).

    ``num_clusters_m`` is supplied per call, since it differs per batch. When
    a divisor is absent from ``Params`` (``None``), it is computed here at
    runtime from ``num_clusters_m``. The asserts record which raster order
    each runtime fallback expects.

    Args:
        cluster_id_in_problem: linear cluster index relative to the first
            cluster of the current batch.
        num_clusters_m: number of M clusters in the current batch.

    Returns:
        ``(cid_m, cid_n)`` cluster coordinates within the batch.
    """
    params = self.params
    # CTA Swizzle to promote L2 data reuse
    # Split the linear id into (group index, position within group).
    if const_expr(params.num_clusters_in_group_divmod is not None):
        group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(
            cluster_id_in_problem
        )
        num_clusters_in_group = params.num_clusters_in_group_divmod.divisor
    else:
        # Group size in clusters depends on this batch's num_clusters_m,
        # so it is divided at runtime.
        assert params.raster_order == RasterOrder.AlongN
        num_clusters_in_group = params.group_size * num_clusters_m
        group_id = cluster_id_in_problem // num_clusters_in_group
        id_in_group = cluster_id_in_problem - group_id * num_clusters_in_group
    # Initialized up front, presumably so both paths below assign
    # already-defined dynamic variables.
    cid_fast_in_group, cid_slow = Int32(0), Int32(0)
    if const_expr(
        params.group_size_divmod is not None and params.group_size_tail_divmod is not None
    ):
        num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
        # Full groups are group_size wide along the fast dimension; the last
        # group may be narrower (group_size_tail).
        if (group_id + 1) * num_clusters_in_group <= num_clusters:
            cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group)
        else:  # tail part
            cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group)
    else:
        assert params.raster_order == RasterOrder.AlongM
        # The last group along M may hold fewer than group_size clusters.
        group_size_actual = cutlass.min(
            params.group_size, num_clusters_m - group_id * params.group_size
        )
        cid_slow = id_in_group // group_size_actual
        cid_fast_in_group = id_in_group - cid_slow * group_size_actual
    if group_id % 2 == 1:  # serpentine order
        # Reverse the slow index on odd groups.
        ncluster_slow = (
            params.problem_shape_ncluster_mnl[1]
            if params.raster_order == RasterOrder.AlongM
            else num_clusters_m
        )
        cid_slow = ncluster_slow - 1 - cid_slow
    cid_fast = group_id * params.group_size + cid_fast_in_group
    # Translate (fast, slow) back to (m, n) according to the raster order.
    cid_m, cid_n = cid_fast, cid_slow
    if params.raster_order == RasterOrder.AlongN:
        cid_m, cid_n = cid_slow, cid_fast
    return cid_m, cid_n
|
|
808
|
+
|
|
809
|
+
@cute.jit
def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
    """Decode ``_current_work_linear_idx`` into a work tile.

    The method is warp-collective. It uses warp shuffles, a ballot and a warp
    prefix sum, so all lanes of the calling warp should execute it together.

    Batches are scanned ``WARP_SIZE - 1`` at a time, one per lane (see
    ``_get_num_m_blocks``). The scan starts from the cached
    ``_current_batch_idx`` / ``_num_work_idx_before_cur_batch`` and slides the
    window forward until it contains the work index. The ballot then selects
    the batch within the window. The cached position is written back so the
    next call resumes from there. Because the scan only moves forward, the
    work index is assumed to be non-decreasing across calls.

    Returns:
        ``WorkTileInfo`` with coordinate ``(pid_m, pid_n, None, batch_idx)``
        in per-CTA tile units. It is invalid once the index is past the last
        batch. In the non-persistent case it is also invalid once
        ``_num_tiles_executed`` is non-zero.
    """
    params = self.params
    lane_idx = cute.arch.lane_idx()
    num_batch = self.params.problem_shape_ncluster_mnl[2]
    # M extent covered by one cluster.
    block_size = params.tile_shape_mn[0] * params.cluster_shape_mn[0]
    batch_idx = self._current_batch_idx
    # Each lane counts the clusters of one batch in the window starting at batch_idx.
    num_clusters_m = self._get_num_m_blocks(
        lane_idx, bidb_start=batch_idx, block_size=block_size
    )
    num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
    # Per-lane running total over the window (inclusive, as the uses below assume).
    num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx)
    # Total number of clusters for the next 31 problems, same for all lanes
    clusters_in_problems = cute.arch.shuffle_sync(
        num_clusters_cumulative, cute.arch.WARP_SIZE - 1
    )
    # One past the last work index covered by the current window.
    problems_end_tile = self._num_work_idx_before_cur_batch + clusters_in_problems
    # Initialized up front, presumably so the dynamic branches below assign
    # already-defined variables.
    cid_m, cid_n = Int32(0), Int32(0)
    next_tile_idx = self._current_work_linear_idx
    # Slide the window forward until it contains next_tile_idx.
    while problems_end_tile <= next_tile_idx:
        batch_idx += cute.arch.WARP_SIZE - 1
        if batch_idx >= num_batch:
            # No batches left: the index is past the end. Force the loop to exit.
            batch_idx = Int32(num_batch)
            problems_end_tile = next_tile_idx + 1
        else:
            num_clusters_m = self._get_num_m_blocks(
                lane_idx, bidb_start=batch_idx, block_size=block_size
            )
            num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1]
            num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx)
            clusters_in_problems = cute.arch.shuffle_sync(
                num_clusters_cumulative, cute.arch.WARP_SIZE - 1
            )
            problems_end_tile += clusters_in_problems
    # Just a placeholder value in case batch_idx >= num_batch
    num_work_idx_before_cur_batch = problems_end_tile - clusters_in_problems
    if batch_idx >= num_batch:
        cid_m, cid_n, batch_idx = Int32(0), Int32(0), Int32(num_batch)
    else:
        problems_start_tile = problems_end_tile - clusters_in_problems
        # Lanes whose batch ends at or before next_tile_idx lie entirely before it.
        # Counting them gives the target batch's offset within the window: the
        # first batch whose end tile is greater than next_tile_idx.
        batch_idx_in_problems = cute.arch.popc(
            cute.arch.vote_ballot_sync(
                problems_start_tile + num_clusters_cumulative <= next_tile_idx
            )
        )
        batch_idx += batch_idx_in_problems
        # Clusters in the window before the target batch (0 if it is the first lane).
        num_clusters_prev_lane = (
            0
            if batch_idx_in_problems == 0
            else cute.arch.shuffle_sync(num_clusters_cumulative, batch_idx_in_problems - 1)
        )
        num_clusters_m = cute.arch.shuffle_sync(num_clusters_m, batch_idx_in_problems)
        num_work_idx_before_cur_batch = problems_start_tile + num_clusters_prev_lane
        cluster_id_in_problem = next_tile_idx - num_work_idx_before_cur_batch
        cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, num_clusters_m, loc=loc, ip=ip)
    # Cache the search position for the next call.
    self._current_batch_idx = batch_idx
    self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch

    # Get the pid from cluster id
    bidx_in_cluster = cute.arch.block_in_cluster_idx()
    pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
    pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
    tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
    if const_expr(not params.is_persistent):
        is_valid = self._num_tiles_executed == 0 and batch_idx < num_batch
    else:
        is_valid = batch_idx < num_batch
    return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
|
|
884
|
+
|
|
885
|
+
@cute.jit
def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
    """Advance ``_current_work_linear_idx`` from the global tile counter.

    is_scheduler_warp should only be true for one warp in the whole cluster.

    This is a no-op unless ``params.tile_count_semaphore`` is set. In the
    scheduler warp, lane 0 atomically adds 1 to the semaphore. It then offsets
    the returned value (presumably the pre-increment count) by
    ``grid_dim()[2]``, and the result is broadcast to the whole warp. Other
    warps keep their current index here; presumably they receive the new
    value through ``scheduler_pipeline`` / ``tile_count`` elsewhere (TODO
    confirm).
    """
    if const_expr(self.params.tile_count_semaphore is not None):
        params = self.params
        current_work_linear_idx = self._current_work_linear_idx
        if is_scheduler_warp:
            if cute.arch.lane_idx() == 0:
                # Indices [0, grid_dim z) were assigned statically in create()
                # from block_idx z, so dynamically fetched indices start after them.
                num_persistent_clusters = cute.arch.grid_dim()[2]
                current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
                    1, params.tile_count_semaphore
                )
            # lane 0 already has the right tile_idx, just need to broadcast
            current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
        self._current_work_linear_idx = current_work_linear_idx
|
|
902
|
+
|
|
903
|
+
def __extract_mlir_values__(self):
    """Flatten the scheduler state into a single list of MLIR values.

    Components are visited in constructor-argument order. The number of
    values each one contributes is recorded in ``self._values_pos`` so that
    ``__new_from_mlir_values__`` can split the flat list back apart.
    """
    self._values_pos = []
    flat_values = []
    for component in (
        self._current_work_linear_idx,
        self._num_tiles_executed,
        self._current_batch_idx,
        self._num_work_idx_before_cur_batch,
        self._tile_count,
        self._scheduler_pipeline,
        self._pipeline_state,
        self.params,
    ):
        component_values = cutlass.extract_mlir_values(component)
        self._values_pos.append(len(component_values))
        flat_values.extend(component_values)
    return flat_values
|
|
919
|
+
|
|
920
|
+
def __new_from_mlir_values__(self, values):
    """Rebuild a scheduler of the same class from a flat list of MLIR values.

    ``values`` is cut into consecutive slices using the per-component counts
    in ``self._values_pos``. Each slice rebuilds the matching component, with
    this instance's component as the template. Only ``loc`` is forwarded to
    the new instance.
    """
    templates = (
        self._current_work_linear_idx,
        self._num_tiles_executed,
        self._current_batch_idx,
        self._num_work_idx_before_cur_batch,
        self._tile_count,
        self._scheduler_pipeline,
        self._pipeline_state,
        self.params,
    )
    rebuilt = []
    offset = 0
    for template, count in zip(templates, self._values_pos):
        rebuilt.append(cutlass.new_from_mlir_values(template, values[offset : offset + count]))
        offset += count
    return self.__class__(*rebuilt, loc=self._loc)
|