mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,720 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Tri Dao.
|
|
3
|
+
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
from dataclasses import dataclass, fields
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from typing import override
|
|
9
|
+
except ImportError: # Python < 3.12
|
|
10
|
+
from typing_extensions import override
|
|
11
|
+
|
|
12
|
+
import cutlass
|
|
13
|
+
from cutlass._mlir import ir
|
|
14
|
+
import cutlass.cute as cute
|
|
15
|
+
from cutlass import Int32, const_expr
|
|
16
|
+
|
|
17
|
+
import mslk.attention.flash_attn.utils as utils
|
|
18
|
+
from mslk.attention.flash_attn.fast_math import clz
|
|
19
|
+
from cutlass.cute import FastDivmodDivisor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class WorkTileInfo(cutlass.utils.WorkTileInfo):
    """Work-tile descriptor whose tile index has four axes: (block, head, batch, split).

    Only MLIR reconstruction is customized: the flat value list is split into
    the four tile-index values followed by one validity value.
    """

    @override
    def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
        # Layout of `values`: [block, head, batch, split, is_valid].
        assert len(values) == 5
        index_part = values[:-1]
        valid_part = [values[-1]]
        rebuilt_index = cutlass.new_from_mlir_values(self._tile_idx, index_part)
        rebuilt_valid = cutlass.new_from_mlir_values(self._is_valid_tile, valid_part)
        return WorkTileInfo(rebuilt_index, rebuilt_valid)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class ParamsBase:
    """Dataclass base that lets scheduler parameter structs round-trip through MLIR values.

    Any field whose *value* is a ``cutlass.Constexpr`` instance is passed through
    unchanged. Every other field is flattened into MLIR values, in field
    declaration order, and rebuilt from them.
    """

    def __extract_mlir_values__(self):
        # Flatten the non-Constexpr fields in declaration order. Record how many
        # MLIR values each field produced so __new_from_mlir_values__ can slice
        # the flat list back apart.
        flat = []
        counts = []
        self._values_pos = counts
        for f in fields(self):
            member = getattr(self, f.name)
            if isinstance(member, cutlass.Constexpr):
                continue
            member_values = cutlass.extract_mlir_values(member)
            flat += member_values
            counts.append(len(member_values))
        return flat

    def __new_from_mlir_values__(self, values):
        # Build a fresh instance. Constexpr fields are copied as-is. Dynamic
        # fields are rebuilt from consecutive slices of `values`, sized by the
        # counts recorded in __extract_mlir_values__.
        members = {f.name: getattr(self, f.name) for f in fields(self)}
        static_kwargs = {}
        dynamic_kwargs = {}
        for name, member in members.items():
            if isinstance(member, cutlass.Constexpr):
                static_kwargs[name] = member
            else:
                dynamic_kwargs[name] = member
        offset = 0
        for (name, member), count in zip(list(dynamic_kwargs.items()), self._values_pos):
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(
                member, values[offset : offset + count]
            )
            offset += count
        return self.__class__(**dynamic_kwargs, **static_kwargs)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class TileSchedulerArguments(ParamsBase):
    """Scheduler-agnostic arguments from which each tile scheduler builds its Params.

    Each scheduler's ``to_underlying_arguments`` / ``Params.create`` reads only
    the subset of fields it needs. Via ``ParamsBase``, any field whose value is a
    ``cutlass.Constexpr`` instance is passed through MLIR reconstruction
    unchanged; all other fields are flattened into MLIR values.
    """

    # Number of tiles along the block (first) axis of the work-tile index.
    num_block: Int32
    num_head: Int32
    num_batch: Int32
    # Number of KV splits. SingleTileScheduler packs it with heads on grid axis y;
    # the LPT/varlen schedulers use it directly as grid axis y.
    num_splits: Int32
    # In the schedulers visible in this file, seqlen_k/headdim/headdim_v feed only
    # the L2 working-set size estimates.
    seqlen_k: Int32
    headdim: Int32
    headdim_v: Int32
    # Total number of query rows; the varlen scheduler uses it to bound the grid size.
    total_q: Int32
    # Tile sizes; the varlen scheduler divides query lengths by tile_shape_mn[0] and
    # sizes KV blocks with tile_shape_mn[1].
    tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
    # Only cluster_shape_mn[1] == 1 is supported by SingleTileScheduler.get_grid_shape.
    cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
    # Varlen only: cumulative query sequence lengths; the per-batch length is
    # mCuSeqlensQ[b + 1] - mCuSeqlensQ[b].
    mCuSeqlensQ: Optional[cute.Tensor] = None
    # Varlen only: per-batch query lengths; preferred over mCuSeqlensQ when provided.
    mSeqUsedQ: Optional[cute.Tensor] = None
    # Varlen scheduler multiplies the query length by this factor when it is > 1.
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
    # Element size used in the L2 size estimates (presumably bytes per element —
    # the estimates are compared against byte-sized L2 budgets).
    element_size: cutlass.Constexpr[int] = 2
    # NOTE(review): not read by any scheduler visible in this file.
    is_persistent: cutlass.Constexpr[bool] = False
    # Forwarded to the varlen scheduler, and as `spt` to the backward LPT scheduler.
    lpt: cutlass.Constexpr[bool] = False
    is_split_kv: cutlass.Constexpr[bool] = False
    # Forwarded to the varlen scheduler's Params.
    head_swizzle: cutlass.Constexpr[bool] = False
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class SingleTileScheduler:
    """Non-persistent scheduler: each CTA handles exactly one work tile.

    Tile coordinates come straight from the CTA's block index
    ``(block, head[, split], batch)``. With split-KV enabled, grid axis y packs
    ``head * num_splits + split``, which is unpacked with a fast divmod.
    """

    @dataclass
    class Params(ParamsBase):
        num_block: Int32
        num_head: Int32
        num_batch: Int32
        num_splits: Int32
        num_splits_divmod: FastDivmodDivisor
        is_split_kv: cutlass.Constexpr[bool] = False
        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)

        @staticmethod
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileScheduler.Params":
            """Copy the relevant fields out of the generic scheduler arguments."""
            return SingleTileScheduler.Params(
                num_block=args.num_block,
                num_head=args.num_head,
                num_batch=args.num_batch,
                num_splits=args.num_splits,
                num_splits_divmod=FastDivmodDivisor(args.num_splits),
                is_split_kv=args.is_split_kv,
                cluster_shape_mn=args.cluster_shape_mn,
            )

    def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
        self.params = params
        self._blk_coord = blk_coord
        # The single work tile stays valid until advance_to_next_work() runs.
        self._is_first_block = True
        self._loc, self._ip = loc, ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler":
        """Build a scheduler bound to the calling CTA's block index."""
        return SingleTileScheduler(params, cute.arch.block_idx(), loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Return the launch grid: (blocks rounded up to the cluster size, heads * splits, batch)."""
        # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
        assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
        grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0])
        grid_y = params.num_head * params.num_splits
        return (grid_x, grid_y, params.num_batch)

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Return this CTA's tile; it is valid only before advance_to_next_work()."""
        block_idx, packed_head, batch_idx = self._blk_coord
        if const_expr(self.params.is_split_kv):
            # Grid axis y holds head * num_splits + split.
            head_idx, split_idx = divmod(packed_head, self.params.num_splits_divmod)
        else:
            head_idx, split_idx = packed_head, Int32(0)
        tile_coord = (block_idx, head_idx, batch_idx, split_idx)
        return WorkTileInfo(tile_coord, self._is_first_block)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        # Nothing to prefetch: there is only one tile per CTA.
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Mark the single tile as consumed; subsequent work tiles are invalid.
        self._is_first_block = False

    def __extract_mlir_values__(self):
        # Flatten params, then the block coordinate, recording each part's size.
        self._values_pos = []
        flat = []
        for part in (self.params, self._blk_coord):
            part_values = cutlass.extract_mlir_values(part)
            self._values_pos.append(len(part_values))
            flat += part_values
        return flat

    def __new_from_mlir_values__(self, values):
        # NOTE(review): only params/blk_coord are rebuilt; _is_first_block (a
        # Python bool) resets to True and `ip` is not forwarded — confirm intended.
        rebuilt = []
        start = 0
        for part, count in zip((self.params, self._blk_coord), self._values_pos):
            rebuilt.append(cutlass.new_from_mlir_values(part, values[start : start + count]))
            start += count
        return SingleTileScheduler(*rebuilt, loc=self._loc)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class StaticPersistentTileScheduler:
    """Persistent scheduler that strides over a linear tile index.

    The grid holds at most one CTA per SM. Each CTA starts at its block index
    and advances by the grid size. The linear index is decoded as
    ``(batch * num_head + head) * num_block + block``; the split index is always 0.
    """

    @dataclass
    class Params(ParamsBase):
        num_block_divmod: FastDivmodDivisor
        num_head_divmod: FastDivmodDivisor
        total_blocks: Int32

        @staticmethod
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "StaticPersistentTileScheduler.Params":
            """Precompute fast divisors and the total number of tiles."""
            total_blocks = args.num_block * args.num_head * args.num_batch
            return StaticPersistentTileScheduler.Params(
                num_block_divmod=FastDivmodDivisor(args.num_block),
                num_head_divmod=FastDivmodDivisor(args.num_head),
                total_blocks=total_blocks,
            )

    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._loc, self._ip = loc, ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler":
        """Start this CTA at the tile equal to its x block index."""
        start_tile = cute.arch.block_idx()[0]
        return StaticPersistentTileScheduler(params, start_tile, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Launch min(SM count, total tiles) CTAs along x."""
        sm_count = cutlass.utils.HardwareInfo().get_device_multiprocessor_count()
        num_ctas = cutlass.min(sm_count, params.total_blocks)
        return (num_ctas, Int32(1), Int32(1))

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Decode the current linear tile index into (block, head, batch, split=0)."""
        head_batch_idx, block_idx = divmod(self._tile_idx, self.params.num_block_divmod)
        batch_idx, head_idx = divmod(head_batch_idx, self.params.num_head_divmod)
        is_valid = self._tile_idx < self.params.total_blocks
        tile_coord = (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0))
        return WorkTileInfo(tile_coord, is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        # Static striding needs no prefetch.
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Jump ahead by the number of CTAs launched along x.
        grid_x = cute.arch.grid_dim()[0]
        self._tile_idx += grid_x

    def __extract_mlir_values__(self):
        # Flatten params, then the current tile index, recording each part's size.
        self._values_pos = []
        flat = []
        for part in (self.params, self._tile_idx):
            part_values = cutlass.extract_mlir_values(part)
            self._values_pos.append(len(part_values))
            flat += part_values
        return flat

    def __new_from_mlir_values__(self, values):
        rebuilt = []
        start = 0
        for part, count in zip((self.params, self._tile_idx), self._values_pos):
            rebuilt.append(cutlass.new_from_mlir_values(part, values[start : start + count]))
            start += count
        return StaticPersistentTileScheduler(*rebuilt, loc=self._loc)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class SingleTileLPTScheduler:
    """Single-tile scheduler with L2-aware (swizzled) ordering and reversed block order.

    Grid axis x is a linear tile index over num_block * num_head * num_batch, and
    grid axis y is the split index. (head, batch) pairs are grouped into "sections"
    of `swizzle` pairs (l2_minor). Consecutive tile indices walk the pairs of one
    section for a given block before moving to the next block. The section size
    is chosen so that the section's estimated K/V footprint fits the assumed L2
    budget. Blocks are then visited in reverse order (longest-processing-time-first).
    """

    @dataclass
    class Params(ParamsBase):
        total_blocks: Int32  # num_block * num_head * num_batch
        num_splits: Int32
        num_block: Int32
        l2_minor: Int32  # swizzle: (head, batch) pairs per section
        num_block_divmod: FastDivmodDivisor
        num_head_divmod: FastDivmodDivisor
        l2_minor_divmod: FastDivmodDivisor  # divides by swizzle
        l2_major_divmod: FastDivmodDivisor  # divides by swizzle * num_block (tiles per section)
        l2_minor_residual_divmod: FastDivmodDivisor  # divides by the last section's pair count
        num_hb_quotient: Int32  # number of full sections
        # NOTE(review): not read by get_current_work here; split_idx comes from grid axis y.
        is_split_kv: cutlass.Constexpr[bool] = False

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileLPTScheduler.Params":
            """Size the L2 sections from the per-head K/V footprint and precompute divisors."""
            # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size)
            size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
            size_one_head = size_one_kv_head
            size_l2 = 50 * 1024 * 1024  # 50 MB budget for K & V
            # Swizzle is the size of each "section". Round swizzle to a power of 2
            # Need to be careful about the case where only one head will fit
            # swizzle is how many heads can fit in L2
            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
            # Seems faster if swizzle is a power of 2
            log2_floor = lambda n: 31 - clz(n)
            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
            # If we're in the last section (called residual), we don't want to divide by
            # swizzle. Instead we want to divide by the remainder.
            num_hb_quotient = (args.num_head * args.num_batch) // swizzle
            num_hb_remainder = (args.num_head * args.num_batch) % swizzle
            return SingleTileLPTScheduler.Params(
                total_blocks=args.num_block * args.num_head * args.num_batch,
                num_block=args.num_block,
                l2_minor=Int32(swizzle),
                num_block_divmod=FastDivmodDivisor(args.num_block),
                num_head_divmod=FastDivmodDivisor(args.num_head),
                l2_minor_divmod=FastDivmodDivisor(swizzle),
                l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
                l2_minor_residual_divmod=FastDivmodDivisor(
                    max(num_hb_remainder, 1)
                ),  # don't divide by 0
                num_hb_quotient=Int32(num_hb_quotient),
                num_splits=args.num_splits,
                is_split_kv=args.is_split_kv,
            )

    def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._split_idx = split_idx
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler":
        # Grid axis x is the linear tile index, axis y the split index.
        tile_idx, split_idx, _ = cute.arch.block_idx()
        return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Launch one CTA per (tile, split)."""
        return (params.total_blocks, params.num_splits, Int32(1))

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Decode the linear tile index into (block, head, batch, split)."""
        params = self.params
        # Implement LPT scheduling coordinate calculation
        # bidhb: section index; l2_mod: offset of the tile within its section.
        bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
        # If we're in the last section (called residual), we don't want to divide by
        # swizzle. Instead we want to divide by the remainder.
        block, bidhb_residual = 0, 0
        if bidhb < params.num_hb_quotient:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
        else:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
        # Flattened (head, batch) index = section * swizzle + position within the section.
        bidhb_actual = bidhb * params.l2_minor + bidhb_residual
        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
        # Longest-processing-time-first
        block = params.num_block - 1 - block
        is_valid = self._tile_idx < params.total_blocks
        return WorkTileInfo(
            (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
        )

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._tile_idx = self.params.total_blocks

    def __extract_mlir_values__(self):
        # Flatten params, tile index and split index, recording each part's size.
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx, self._split_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*(tuple(obj_list)), loc=self._loc)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
class SingleTileLPTBwdScheduler:
    """Single-tile scheduler with L2-aware (swizzled) ordering, cluster-aware, for the backward pass.

    Grid axis x is a linear CTA index. It is first divided by the cluster size
    along the block axis to get a cluster index. (head, batch) pairs are grouped
    into sections of `swizzle` pairs (l2_minor); consecutive cluster indices walk
    the pairs of one section for a given block-cluster before moving to the next.
    When `spt` (set from ``args.lpt``) is true, block-clusters are visited in
    reverse order.
    """

    @dataclass
    class Params(ParamsBase):
        total_blocks: Int32  # padded block count (num_block * cluster) * num_head * num_batch
        num_block: Int32  # number of clusters along the block axis
        l2_minor: Int32  # swizzle: (head, batch) pairs per section
        num_head_divmod: FastDivmodDivisor
        l2_minor_divmod: FastDivmodDivisor  # divides by swizzle
        l2_major_divmod: FastDivmodDivisor  # divides by swizzle * num_block (clusters per section)
        l2_minor_residual_divmod: FastDivmodDivisor  # divides by the last section's pair count
        num_hb_quotient: Int32  # number of full sections
        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
        spt: cutlass.Constexpr[bool] = True

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileLPTBwdScheduler.Params":
            """Size the L2 sections from the per-head footprint and precompute divisors."""
            size_l2 = 50 * 1024 * 1024  # 50 MB L2 budget
            # NOTE(review): named "qdo" but sized from args.seqlen_k; presumably the
            # caller passes the query length in this field for the backward pass — confirm.
            size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
            # The dQ-accumulator footprint (seqlen_k * headdim * 4 bytes) is currently
            # left out of the estimate.
            size_one_dqaccum_head = 0
            size_one_head = size_one_qdo_head + size_one_dqaccum_head
            log2_floor = lambda n: 31 - clz(n)
            # Pairs per section: largest power of 2 that fits in size_l2 (at least 1).
            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
            # If we're in the last section (called residual), we don't want to divide by
            # swizzle. Instead we want to divide by the remainder.
            num_hb_quotient = (args.num_head * args.num_batch) // swizzle
            num_hb_remainder = (args.num_head * args.num_batch) % swizzle
            # Count blocks in clusters; the grid is padded to a multiple of the cluster size.
            num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0])
            return SingleTileLPTBwdScheduler.Params(
                total_blocks=(num_block * args.cluster_shape_mn[0])
                * args.num_head
                * args.num_batch,
                num_block=num_block,
                l2_minor=Int32(swizzle),
                num_head_divmod=FastDivmodDivisor(args.num_head),
                l2_minor_divmod=FastDivmodDivisor(swizzle),
                l2_major_divmod=FastDivmodDivisor(swizzle * num_block),
                l2_minor_residual_divmod=FastDivmodDivisor(
                    max(num_hb_remainder, 1)
                ),  # don't divide by 0
                num_hb_quotient=Int32(num_hb_quotient),
                cluster_shape_mn=args.cluster_shape_mn,
                spt=args.lpt,
            )

    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdScheduler":
        tile_idx = cute.arch.block_idx()[0]
        return SingleTileLPTBwdScheduler(params, tile_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Launch one CTA per (padded) tile along x."""
        return (params.total_blocks, Int32(1), Int32(1))

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Decode the linear CTA index into (block, head, batch, split=0)."""
        cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0]
        params = self.params
        # LPT coordinate calculation, in units of clusters along the block axis.
        bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod)
        # If we're in the last section (called residual), we don't want to divide by
        # swizzle. Instead we want to divide by the remainder.
        block, bidhb_residual = 0, 0
        if bidhb < params.num_hb_quotient:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
        else:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
        bidhb_actual = bidhb * params.l2_minor + bidhb_residual
        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
        is_valid = self._tile_idx < params.total_blocks
        # Here `block` is a cluster index in [0, params.num_block). Reverse it
        # before expanding to per-CTA units: params.num_block counts clusters,
        # so reversing after the expansion would go negative when
        # cluster_shape_mn[0] > 1. Reversing first also keeps
        # block % cluster == rank within the cluster.
        if cutlass.const_expr(params.spt):
            block = params.num_block - 1 - block
        bidx_in_cluster = cute.arch.block_in_cluster_idx()
        block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0]
        return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._tile_idx = self.params.total_blocks

    def __extract_mlir_values__(self):
        # Flatten params and the tile index, recording each part's size.
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.params, self._tile_idx], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*(tuple(obj_list)), loc=self._loc)
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
class SingleTileVarlenScheduler:
|
|
504
|
+
@dataclass
|
|
505
|
+
class Params(ParamsBase):
|
|
506
|
+
num_head: Int32
|
|
507
|
+
num_batch: Int32
|
|
508
|
+
total_q: Int32
|
|
509
|
+
num_splits: Int32
|
|
510
|
+
max_kvblock_in_l2: Int32
|
|
511
|
+
tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
|
|
512
|
+
mCuSeqlensQ: Optional[cute.Tensor] = None
|
|
513
|
+
mSeqUsedQ: Optional[cute.Tensor] = None
|
|
514
|
+
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
|
|
515
|
+
lpt: cutlass.Constexpr[bool] = False
|
|
516
|
+
is_split_kv: cutlass.Constexpr[bool] = False
|
|
517
|
+
head_swizzle: cutlass.Constexpr[bool] = False
|
|
518
|
+
|
|
519
|
+
@staticmethod
|
|
520
|
+
@cute.jit
|
|
521
|
+
def create(
|
|
522
|
+
args: TileSchedulerArguments, *, loc=None, ip=None
|
|
523
|
+
) -> "SingleTileVarlenScheduler.Params":
|
|
524
|
+
size_l2 = 50 * 1024 * 1024 # 50 MB for K & V
|
|
525
|
+
max_kvblock_in_l2 = size_l2 // (
|
|
526
|
+
(args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
|
|
527
|
+
)
|
|
528
|
+
assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
|
|
529
|
+
"At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
|
|
530
|
+
)
|
|
531
|
+
return SingleTileVarlenScheduler.Params(
|
|
532
|
+
num_head=args.num_head,
|
|
533
|
+
num_batch=args.num_batch,
|
|
534
|
+
total_q=args.total_q,
|
|
535
|
+
num_splits=args.num_splits,
|
|
536
|
+
max_kvblock_in_l2=max_kvblock_in_l2,
|
|
537
|
+
tile_shape_mn=args.tile_shape_mn,
|
|
538
|
+
mCuSeqlensQ=args.mCuSeqlensQ,
|
|
539
|
+
mSeqUsedQ=args.mSeqUsedQ,
|
|
540
|
+
qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,
|
|
541
|
+
lpt=args.lpt,
|
|
542
|
+
is_split_kv=args.is_split_kv,
|
|
543
|
+
head_swizzle=args.head_swizzle,
|
|
544
|
+
)
|
|
545
|
+
|
|
546
|
+
def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
|
|
547
|
+
self.params = params
|
|
548
|
+
self._tile_idx = tile_idx
|
|
549
|
+
self._split_idx = split_idx
|
|
550
|
+
self._is_first_block = True
|
|
551
|
+
self._loc = loc
|
|
552
|
+
self._ip = ip
|
|
553
|
+
|
|
554
|
+
@staticmethod
|
|
555
|
+
def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
|
|
556
|
+
return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip)
|
|
557
|
+
|
|
558
|
+
@staticmethod
|
|
559
|
+
def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler":
|
|
560
|
+
tile_idx, split_idx, _ = cute.arch.block_idx()
|
|
561
|
+
return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
|
|
562
|
+
|
|
563
|
+
# called by host
|
|
564
|
+
@staticmethod
|
|
565
|
+
def get_grid_shape(
|
|
566
|
+
params: Params,
|
|
567
|
+
*,
|
|
568
|
+
loc=None,
|
|
569
|
+
ip=None,
|
|
570
|
+
) -> Tuple[Int32, Int32, Int32]:
|
|
571
|
+
total_blocks_max = (
|
|
572
|
+
params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1)
|
|
573
|
+
) // params.tile_shape_mn[0]
|
|
574
|
+
return (total_blocks_max * params.num_head, params.num_splits, Int32(1))
|
|
575
|
+
|
|
576
|
+
@cute.jit
def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32:
    """Return the m-block (query tile) count of batch ``bidb_start + lane``.

    Intended to be called warp-wide, one batch per lane; it uses
    ``shuffle_sync_down``. The last lane always returns 0, so one call covers
    WARP_SIZE - 1 batches. That lets lane WARP_SIZE - 2 read the next cu_seqlens
    entry from the last lane.

    Args:
        lane: Lane index within the warp.
        bidb_start: First batch index covered by this warp-wide call.

    Returns:
        ceil(seqlen / tile_shape_mn[0]) for in-range batches on lanes
        0..WARP_SIZE-2, otherwise 0. When qhead_per_kvhead_packgqa > 1,
        seqlen is multiplied by that factor first.
    """
    params = self.params
    batch_idx = lane + bidb_start
    if cutlass.const_expr(params.mSeqUsedQ is not None):
        # Per-batch lengths are given directly; guard the out-of-range read.
        seqlen = Int32(0)
        if batch_idx < params.num_batch:
            seqlen = params.mSeqUsedQ[batch_idx]
    else:
        assert params.mCuSeqlensQ is not None
        # `<=` (not `<`): cu_seqlens presumably holds num_batch + 1 cumulative
        # offsets, so entry num_batch is readable.
        cur_cu_seqlen = Int32(0)
        if batch_idx <= params.num_batch:
            cur_cu_seqlen = params.mCuSeqlensQ[batch_idx]
        # Each lane fetches the next lane's offset: seqlen = cu[b + 1] - cu[b].
        # The last lane has no successor; its result is zeroed in the return below.
        next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
        seqlen = next_cu_seqlen - cur_cu_seqlen
    if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1):
        # Pack-GQA: the M extent is scaled by the number of query heads packed
        # per KV head (presumably those heads are folded into M — confirm).
        seqlen *= params.qhead_per_kvhead_packgqa
    return (
        cute.ceil_div(seqlen, params.tile_shape_mn[0])
        if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1
        else Int32(0)
    )
|
|
598
|
+
|
|
599
|
+
@cute.jit
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
    """Decode this CTA's linear tile index into (m_block, head, batch, split) work.

    Tiles are laid out batch after batch, with num_head * num_m_blocks(batch)
    tiles per batch. Because num_m_blocks varies per batch, the warp scans
    batches in groups of WARP_SIZE - 1 (one per lane; see ``_get_num_m_blocks``):
      1. A warp prefix sum gives each group's tile extent. Skip groups until
         the one containing ``self._tile_idx``.
      2. A ballot/popc selects the batch within that group.
      3. The offset inside the batch (``mh_block``) is split into
         (head_idx, block). With lpt/head_swizzle, heads are grouped by
         ``max_kvblock_in_l2``; with lpt, m-blocks are also visited in
         reverse order.
    Uses warp-wide shuffle/ballot intrinsics, so presumably every lane of the
    warp must call this together.

    Returns:
        WorkTileInfo((block, head_idx, batch_idx, split_idx), is_valid).
        is_valid is true only before ``advance_to_next_work`` and only when
        the tile maps to a real batch. Tiles past the end get
        batch_idx = num_batch. split_idx is 0 unless ``is_split_kv``.
    """
    params = self.params
    lane_idx = cute.arch.lane_idx()
    num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
    # Indexing below treats this as an inclusive prefix sum: lane i holds the
    # end position (in m-blocks) of batch i within the group.
    num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
    # Total number of blocks for the next 31 (= WARP_SIZE - 1) batches
    m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1)
    # Same for all lanes
    group_end_tile = m_blocks_in_group * params.num_head
    # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group)
    block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0)
    next_tile_idx = self._tile_idx
    # Advance group by group until the group's end passes this CTA's tile.
    while group_end_tile <= next_tile_idx:
        batch_idx += cute.arch.WARP_SIZE - 1
        if batch_idx >= params.num_batch:
            # Ran out of batches: this CTA has no work. Force loop exit.
            batch_idx = Int32(params.num_batch)
            group_end_tile = next_tile_idx + 1
        else:
            num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx)
            num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
            m_blocks_in_group = cute.arch.shuffle_sync(
                num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1
            )
            group_end_tile += m_blocks_in_group * params.num_head
    is_valid = False
    if batch_idx >= params.num_batch:
        block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch)
    else:
        group_start_tile = group_end_tile - m_blocks_in_group * params.num_head
        # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx)
        # The next problem to process is the first one whose ending tile position
        # is greater than the tile index. Counting the lanes whose end is
        # <= the tile index gives that batch's lane within the group.
        batch_idx_in_group = cute.arch.popc(
            cute.arch.vote_ballot_sync(
                group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx
            )
        )
        batch_idx += batch_idx_in_group
        # Start of the selected batch = end of the previous lane's batch.
        num_m_blocks_prev_lane = (
            0
            if batch_idx_in_group == 0
            else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1)
        )
        num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)
        # Tile offset inside the selected batch, in [0, num_head * num_m_blocks).
        mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head
        if cutlass.const_expr(params.lpt or params.head_swizzle):
            # This is a version of the SingleTileLPTScheduler, complicated by the fact that
            # the seqlen can vary per batch.
            # TODO: is there any case where num_m_blocks is 0?
            # NOTE(review): the prefix sums are non-decreasing, so the ballot selects
            # the first lane whose end > tile while its start <= tile. That batch
            # therefore appears to have num_m_blocks > 0. Confirm
            # warp_prefix_sum is inclusive before removing the TODO.
            # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here
            num_n_blocks = (
                num_m_blocks
                * params.tile_shape_mn[0]
                // params.qhead_per_kvhead_packgqa
                // params.tile_shape_mn[1]
            )
            # nheads_in_l2 = min(max(params.max_kvblock_in_l2 // num_n_blocks, 1), params.num_head)
            # Seems faster to have this be a power of 2
            # Largest of 16/8/4/2/1 heads whose KV blocks fit within max_kvblock_in_l2.
            nheads_in_l2 = (
                16
                if num_n_blocks * 16 <= params.max_kvblock_in_l2
                else (
                    8
                    if num_n_blocks * 8 <= params.max_kvblock_in_l2
                    else (
                        4
                        if num_n_blocks * 4 <= params.max_kvblock_in_l2
                        else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)
                    )
                )
            )
            nheads_in_l2 = min(nheads_in_l2, params.num_head)
            # Tiles are grouped in sections of nheads_in_l2 heads x num_m_blocks blocks.
            mh_in_l2 = nheads_in_l2 * num_m_blocks
            section_idx = mh_block // mh_in_l2
            l2_mod = mh_block - section_idx * mh_in_l2
            # Deal with tail section
            # (the last section may hold fewer than nheads_in_l2 heads)
            nheads_in_this_section = (
                nheads_in_l2
                if nheads_in_l2 * (section_idx + 1) <= params.num_head
                else params.num_head - section_idx * nheads_in_l2
            )
            # Within a section, the head index varies fastest.
            block = l2_mod // nheads_in_this_section
            head_idx_residual = l2_mod - block * nheads_in_this_section
            head_idx = section_idx * nheads_in_l2 + head_idx_residual
            if cutlass.const_expr(params.lpt):
                # Reverse m-block order within the batch.
                block = num_m_blocks - 1 - block
        else:
            # Plain layout: head-major, m-block fastest.
            head_idx = mh_block // num_m_blocks
            block = mh_block - head_idx * num_m_blocks
        # Valid only on the first call (advance_to_next_work clears the flag).
        is_valid = self._is_first_block and batch_idx < params.num_batch
    # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid)
    split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
    return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)
|
|
693
|
+
|
|
694
|
+
def initial_work_tile_info(self, *, loc=None, ip=None):
    """Return the first work tile.

    A single-tile scheduler has no separate initial state, so this is the
    current work.
    """
    first_work = self.get_current_work(loc=loc, ip=ip)
    return first_work
|
|
696
|
+
|
|
697
|
+
def prefetch_next_work(self, *, loc=None, ip=None):
    """No-op: each CTA processes exactly one tile, so there is nothing to prefetch."""
|
|
699
|
+
|
|
700
|
+
def advance_to_next_work(self, *, loc=None, ip=None):
    """Mark this CTA's single tile as consumed.

    The tile index is left untouched. Clearing ``_is_first_block`` makes every
    later ``get_current_work`` call report ``is_valid = False``, which ends the
    work loop.
    """
    self._is_first_block = False
|
|
703
|
+
|
|
704
|
+
def __extract_mlir_values__(self):
    """Flatten the dynamic state into one list of MLIR values.

    Only ``params``, ``_tile_idx`` and ``_split_idx`` are flattened. The number
    of values each contributes is recorded in ``self._values_pos`` so that
    ``__new_from_mlir_values__`` can split the list back apart.
    """
    flat_values = []
    counts = self._values_pos = []
    for component in (self.params, self._tile_idx, self._split_idx):
        component_values = cutlass.extract_mlir_values(component)
        flat_values += component_values
        counts.append(len(component_values))
    return flat_values
|
|
711
|
+
|
|
712
|
+
def __new_from_mlir_values__(self, values):
    """Rebuild a scheduler from the flat values produced by ``__extract_mlir_values__``.

    ``self._values_pos``, recorded by ``__extract_mlir_values__``, gives how many
    leading values belong to ``params``, ``_tile_idx`` and ``_split_idx`` in that
    order. The new instance goes through ``__init__``, so its ``_is_first_block``
    starts as True.
    """
    obj_list = []
    for obj, n_items in zip(
        [self.params, self._tile_idx, self._split_idx],
        self._values_pos,
    ):
        obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
        values = values[n_items:]
    # Forward both loc and ip, as create() does. Previously ip was dropped, so the
    # rebuilt scheduler's _ip was silently reset to None while _loc was kept.
    return SingleTileVarlenScheduler(*obj_list, loc=self._loc, ip=self._ip)
|