mslk_cuda_nightly-2026.1.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/tile_scheduler.py
@@ -0,0 +1,720 @@
+ # @nolint # fbcode
+ # Copyright (c) 2025, Tri Dao.
+
+ from typing import Optional, Tuple
+ from dataclasses import dataclass, fields
+
+ try:
+     from typing import override
+ except ImportError:  # Python < 3.12
+     from typing_extensions import override
+
+ import cutlass
+ from cutlass._mlir import ir
+ import cutlass.cute as cute
+ from cutlass import Int32, const_expr
+
+ import mslk.attention.flash_attn.utils as utils
+ from mslk.attention.flash_attn.fast_math import clz
+ from cutlass.cute import FastDivmodDivisor
+
+
+ class WorkTileInfo(cutlass.utils.WorkTileInfo):
+     """Altered WorkTileInfo which includes four axes: (block, head, batch, split)"""
+
+     @override
+     def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
+         assert len(values) == 5
+         new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1])
+         new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]])
+         return WorkTileInfo(new_tile_idx, new_is_valid_tile)
+
+
+ @dataclass
+ class ParamsBase:
+     def __extract_mlir_values__(self):
+         all_fields = [getattr(self, field.name) for field in fields(self)]
+         non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)]
+         values, self._values_pos = [], []
+         for obj in non_constexpr_fields:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
+         constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)}
+         non_constexpr_fields = {
+             n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr)
+         }
+         for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
+             non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
+             values = values[n_items:]
+         return self.__class__(**non_constexpr_fields, **constexpr_fields)
+
+
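The two hooks above implement a flatten/rebuild protocol for crossing the host/JIT boundary: __extract_mlir_values__ flattens every non-constexpr field into one flat list of MLIR values, recording in _values_pos how many values each field contributed, while __new_from_mlir_values__ consumes the list in the same field order and passes constexpr fields through unchanged. A minimal pure-Python analogue of the same pattern, with toy Flag/flatten/rebuild names that are not part of this package:

    from dataclasses import dataclass, fields

    class Flag(int):
        """Toy stand-in for cutlass.Constexpr: such fields skip flattening."""

    @dataclass
    class Params:
        num_block: int
        tile_shape: tuple  # flattens to two scalars
        persistent: Flag   # "constexpr": carried by the type, not by values

        def flatten(self):
            vals, self._pos = [], []
            for f in fields(self):
                v = getattr(self, f.name)
                if isinstance(v, Flag):
                    continue
                part = list(v) if isinstance(v, tuple) else [v]
                vals += part
                self._pos.append(len(part))
            return vals

        def rebuild(self, vals):
            kwargs, pos = {}, iter(self._pos)
            for f in fields(self):
                v = getattr(self, f.name)
                if isinstance(v, Flag):
                    kwargs[f.name] = v  # constexpr field passes through
                else:
                    n = next(pos)
                    head, vals = vals[:n], vals[n:]
                    kwargs[f.name] = tuple(head) if isinstance(v, tuple) else head[0]
            return type(self)(**kwargs)

    p = Params(num_block=8, tile_shape=(128, 64), persistent=Flag(1))
    assert p.rebuild(p.flatten()) == p  # the round trip preserves every field
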
+ @dataclass
+ class TileSchedulerArguments(ParamsBase):
+     num_block: Int32
+     num_head: Int32
+     num_batch: Int32
+     num_splits: Int32
+     seqlen_k: Int32
+     headdim: Int32
+     headdim_v: Int32
+     total_q: Int32
+     tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
+     cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
+     mCuSeqlensQ: Optional[cute.Tensor] = None
+     mSeqUsedQ: Optional[cute.Tensor] = None
+     qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
+     element_size: cutlass.Constexpr[int] = 2
+     is_persistent: cutlass.Constexpr[bool] = False
+     lpt: cutlass.Constexpr[bool] = False
+     is_split_kv: cutlass.Constexpr[bool] = False
+     head_swizzle: cutlass.Constexpr[bool] = False
+
+
+ class SingleTileScheduler:
+     @dataclass
+     class Params(ParamsBase):
+         num_block: Int32
+         num_head: Int32
+         num_batch: Int32
+         num_splits: Int32
+         num_splits_divmod: FastDivmodDivisor
+         is_split_kv: cutlass.Constexpr[bool] = False
+         cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
+
+         @staticmethod
+         def create(
+             args: TileSchedulerArguments, *, loc=None, ip=None
+         ) -> "SingleTileScheduler.Params":
+             return SingleTileScheduler.Params(
+                 args.num_block,
+                 args.num_head,
+                 args.num_batch,
+                 args.num_splits,
+                 FastDivmodDivisor(args.num_splits),
+                 args.is_split_kv,
+                 args.cluster_shape_mn,
+             )
+
+     def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
+         self.params = params
+         self._blk_coord = blk_coord
+         self._is_first_block = True
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler":
+         blk_coord = cute.arch.block_idx()
+         return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
+         assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
+         return (
+             cute.round_up(params.num_block, params.cluster_shape_mn[0]),
+             params.num_head * params.num_splits,
+             params.num_batch,
+         )
+
+     def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
+         block_idx, head_idx, batch_idx = self._blk_coord
+         if const_expr(self.params.is_split_kv):
+             head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod)
+         else:
+             split_idx = Int32(0)
+         return WorkTileInfo(
+             (block_idx, head_idx, batch_idx, split_idx),
+             self._is_first_block,
+         )
+
+     def initial_work_tile_info(self, *, loc=None, ip=None):
+         return self.get_current_work(loc=loc, ip=ip)
+
+     def prefetch_next_work(self, *, loc=None, ip=None):
+         pass
+
+     def advance_to_next_work(self, *, loc=None, ip=None):
+         self._is_first_block = False
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [self.params, self._blk_coord]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip([self.params, self._blk_coord], self._values_pos):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc)
+
+
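Every scheduler in this file exposes the same device-side protocol: initial_work_tile_info() yields the first WorkTileInfo, the kernel processes the (block, head, batch, split) coordinate while the tile is valid, then calls advance_to_next_work() and get_current_work() again. The consuming loop lives in the attention kernels (presumably flash_fwd.py and friends, not shown in this diff); the sketch below mocks the interface on the CPU, with get_current_work returning a plain (coords, is_valid) pair in place of WorkTileInfo:

    # Hypothetical CPU mock of the scheduler protocol used by the kernels.
    class MockSingleTileScheduler:
        def __init__(self, blk_coord):
            self._blk_coord = blk_coord  # (block, head, batch) from block_idx()
            self._is_first_block = True

        def initial_work_tile_info(self):
            return self.get_current_work()

        def get_current_work(self):
            block, head, batch = self._blk_coord
            return (block, head, batch, 0), self._is_first_block

        def advance_to_next_work(self):
            self._is_first_block = False  # one tile per CTA, so no second tile

    def kernel_body(scheduler):
        tile, is_valid = scheduler.initial_work_tile_info()
        while is_valid:  # persistent schedulers iterate here many times
            print("processing tile", tile)  # the attention math would go here
            scheduler.advance_to_next_work()
            tile, is_valid = scheduler.get_current_work()

    kernel_body(MockSingleTileScheduler(blk_coord=(3, 1, 0)))  # one tile, then done
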
+ class StaticPersistentTileScheduler:
+     @dataclass
+     class Params(ParamsBase):
+         num_block_divmod: FastDivmodDivisor
+         num_head_divmod: FastDivmodDivisor
+         total_blocks: Int32
+
+         @staticmethod
+         def create(
+             args: TileSchedulerArguments, *, loc=None, ip=None
+         ) -> "StaticPersistentTileScheduler.Params":
+             total_blocks = args.num_block * args.num_head * args.num_batch
+             return StaticPersistentTileScheduler.Params(
+                 FastDivmodDivisor(args.num_block), FastDivmodDivisor(args.num_head), total_blocks
+             )
+
+     def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
+         self.params = params
+         self._tile_idx = tile_idx
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler":
+         tile_idx = cute.arch.block_idx()[0]
+         return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         hardware_info = cutlass.utils.HardwareInfo()
+         sm_count = hardware_info.get_device_multiprocessor_count()
+         return (cutlass.min(sm_count, params.total_blocks), Int32(1), Int32(1))
+
+     # @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
+         hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_divmod)
+         batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)
+         is_valid = self._tile_idx < self.params.total_blocks
+         # if cute.arch.thread_idx()[0] == 0:
+         #     cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid)
+         return WorkTileInfo(
+             (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
+         )
+
+     def initial_work_tile_info(self, *, loc=None, ip=None):
+         return self.get_current_work(loc=loc, ip=ip)
+
+     def prefetch_next_work(self, *, loc=None, ip=None):
+         pass
+
+     def advance_to_next_work(self, *, loc=None, ip=None):
+         self._tile_idx += cute.arch.grid_dim()[0]
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [self.params, self._tile_idx]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip(
+             [self.params, self._tile_idx],
+             self._values_pos,
+         ):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc)
+
+
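The persistent variant launches at most one CTA per SM and has each CTA stride the flat tile index by the grid size; the two FastDivmodDivisor divisions unpack tile_idx = (batch * num_head + head) * num_block + block. A plain-Python check of that decomposition, with illustrative sizes:

    num_block, num_head, num_batch, grid = 7, 3, 2, 5  # illustrative sizes
    total = num_block * num_head * num_batch

    seen = set()
    for start in range(grid):        # one persistent CTA per grid slot
        tile = start
        while tile < total:          # is_valid: tile_idx < total_blocks
            hn, block = divmod(tile, num_block)   # num_block_divmod
            batch, head = divmod(hn, num_head)    # num_head_divmod
            seen.add((block, head, batch))
            tile += grid             # advance_to_next_work: += grid_dim()[0]

    assert len(seen) == total  # every (block, head, batch) visited exactly once
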
+ class SingleTileLPTScheduler:
+     @dataclass
+     class Params(ParamsBase):
+         total_blocks: Int32
+         num_splits: Int32
+         num_block: Int32
+         l2_minor: Int32
+         num_block_divmod: FastDivmodDivisor
+         num_head_divmod: FastDivmodDivisor
+         l2_minor_divmod: FastDivmodDivisor
+         l2_major_divmod: FastDivmodDivisor
+         l2_minor_residual_divmod: FastDivmodDivisor
+         num_hb_quotient: Int32
+         is_split_kv: cutlass.Constexpr[bool] = False
+
+         @staticmethod
+         @cute.jit
+         def create(
+             args: TileSchedulerArguments, *, loc=None, ip=None
+         ) -> "SingleTileLPTScheduler.Params":
+             # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size)
+             size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
+             size_one_head = size_one_kv_head
+             size_l2 = 50 * 1024 * 1024  # 50 MB for K & V
+             # Swizzle is the size of each "section". Round swizzle to a power of 2
+             # Need to be careful about the case where only one head will fit
+             # swizzle is how many heads can fit in L2
+             # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
+             # Seems faster if swizzle is a power of 2
+             log2_floor = lambda n: 31 - clz(n)
+             swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
+             # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
+             # If we're in the last section (called residual), we don't want to divide by
+             # swizzle. Instead we want to divide by the remainder.
+             num_hb_quotient = (args.num_head * args.num_batch) // swizzle
+             num_hb_remainder = (args.num_head * args.num_batch) % swizzle
+             return SingleTileLPTScheduler.Params(
+                 total_blocks=args.num_block * args.num_head * args.num_batch,
+                 num_block=args.num_block,
+                 l2_minor=Int32(swizzle),
+                 num_block_divmod=FastDivmodDivisor(args.num_block),
+                 num_head_divmod=FastDivmodDivisor(args.num_head),
+                 l2_minor_divmod=FastDivmodDivisor(swizzle),
+                 l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
+                 l2_minor_residual_divmod=FastDivmodDivisor(
+                     max(num_hb_remainder, 1)
+                 ),  # don't divide by 0
+                 num_hb_quotient=Int32(num_hb_quotient),
+                 num_splits=args.num_splits,
+                 is_split_kv=args.is_split_kv,
+             )
+
+     def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
+         self.params = params
+         self._tile_idx = tile_idx
+         self._split_idx = split_idx
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler":
+         tile_idx, split_idx, _ = cute.arch.block_idx()
+         return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         return (params.total_blocks, params.num_splits, Int32(1))
+
+     @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
+         params = self.params
+         # Implement LPT scheduling coordinate calculation
+         bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
+         # If we're in the last section (called residual), we don't want to divide by
+         # swizzle. Instead we want to divide by the remainder.
+         block, bidhb_residual = 0, 0
+         if bidhb < params.num_hb_quotient:
+             block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
+         else:
+             block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
+         bidhb_actual = bidhb * params.l2_minor + bidhb_residual
+         batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
+         # Longest-processing-time-first
+         block = params.num_block - 1 - block
+         is_valid = self._tile_idx < params.total_blocks
+         return WorkTileInfo(
+             (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
+         )
+
+     def initial_work_tile_info(self, *, loc=None, ip=None):
+         return self.get_current_work(loc=loc, ip=ip)
+
+     def prefetch_next_work(self, *, loc=None, ip=None):
+         pass
+
+     def advance_to_next_work(self, *, loc=None, ip=None):
+         # Single tile scheduler - set to invalid tile_idx to indicate no more work
+         self._tile_idx = self.params.total_blocks
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [self.params, self._tile_idx, self._split_idx]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return self.__class__(*(tuple(obj_list)), loc=self._loc)
+
+
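The swizzle computed above is the number of (head, batch) pairs whose K and V working set fits in the assumed 50 MB L2 budget, rounded down to a power of two (31 - clz(n) is floor(log2 n)). Worked numbers under assumed shapes (seqlen_k=8192, headdim=headdim_v=128, fp16):

    # Mirrors the swizzle arithmetic in SingleTileLPTScheduler.Params.create;
    # int.bit_length() - 1 plays the role of 31 - clz(n).
    seqlen_k, headdim, headdim_v, element_size = 8192, 128, 128, 2  # assumed shapes

    size_one_head = seqlen_k * (headdim + headdim_v) * element_size  # 4 MiB of K+V
    size_l2 = 50 * 1024 * 1024

    log2_floor = lambda n: n.bit_length() - 1
    swizzle = 1 if size_l2 < size_one_head else 1 << log2_floor(size_l2 // size_one_head)
    assert swizzle == 8  # 12 heads would fit; rounded down to a power of two
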
+ class SingleTileLPTBwdScheduler:
+     @dataclass
+     class Params(ParamsBase):
+         total_blocks: Int32
+         num_block: Int32
+         l2_minor: Int32
+         num_head_divmod: FastDivmodDivisor
+         l2_minor_divmod: FastDivmodDivisor
+         l2_major_divmod: FastDivmodDivisor
+         l2_minor_residual_divmod: FastDivmodDivisor
+         num_hb_quotient: Int32
+         cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
+         spt: cutlass.Constexpr[bool] = True
+
+         @staticmethod
+         @cute.jit
+         def create(
+             args: TileSchedulerArguments, *, loc=None, ip=None
+         ) -> "SingleTileLPTBwdScheduler.Params":
+             size_l2 = 50 * 1024 * 1024
+             size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
+             # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
+             size_one_dqaccum_head = 0
+             size_one_head = size_one_qdo_head + size_one_dqaccum_head
+             log2_floor = lambda n: 31 - clz(n)
+             swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
+             # swizzle = 8
+             # If we're in the last section (called residual), we don't want to divide by
+             # swizzle. Instead we want to divide by the remainder.
+             num_hb_quotient = (args.num_head * args.num_batch) // swizzle
+             num_hb_remainder = (args.num_head * args.num_batch) % swizzle
+             num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0])
+             return SingleTileLPTBwdScheduler.Params(
+                 total_blocks=(num_block * args.cluster_shape_mn[0])
+                 * args.num_head
+                 * args.num_batch,
+                 num_block=num_block,
+                 l2_minor=Int32(swizzle),
+                 num_head_divmod=FastDivmodDivisor(args.num_head),
+                 l2_minor_divmod=FastDivmodDivisor(swizzle),
+                 l2_major_divmod=FastDivmodDivisor(swizzle * num_block),
+                 l2_minor_residual_divmod=FastDivmodDivisor(
+                     max(num_hb_remainder, 1)
+                 ),  # don't divide by 0
+                 num_hb_quotient=Int32(num_hb_quotient),
+                 cluster_shape_mn=args.cluster_shape_mn,
+                 spt=args.lpt,
+             )
+
+     def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
+         self.params = params
+         self._tile_idx = tile_idx
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     @cute.jit
+     def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdScheduler":
+         tile_idx = cute.arch.block_idx()[0]
+         return SingleTileLPTBwdScheduler(params, tile_idx, loc=loc, ip=ip)
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         return (params.total_blocks, Int32(1), Int32(1))
+
+     @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+         cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0]
+         params = self.params
+         # Implement LPT scheduling coordinate calculation
+         bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod)
+         # If we're in the last section (called residual), we don't want to divide by
+         # swizzle. Instead we want to divide by the remainder.
+         block, bidhb_residual = 0, 0
+         if bidhb < params.num_hb_quotient:
+             block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
+         else:
+             block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
+         bidhb_actual = bidhb * params.l2_minor + bidhb_residual
+         batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
+         is_valid = self._tile_idx < params.total_blocks
+         bidx_in_cluster = cute.arch.block_in_cluster_idx()
+         block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0]
+         if cutlass.const_expr(params.spt):
+             block = params.num_block - 1 - block
+         return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid)
+
+     def initial_work_tile_info(self, *, loc=None, ip=None):
+         return self.get_current_work(loc=loc, ip=ip)
+
+     def prefetch_next_work(self, *, loc=None, ip=None):
+         pass
+
+     def advance_to_next_work(self, *, loc=None, ip=None):
+         # Single tile scheduler - set to invalid tile_idx to indicate no more work
+         self._tile_idx = self.params.total_blocks
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [self.params, self._tile_idx]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip([self.params, self._tile_idx], self._values_pos):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return self.__class__(*(tuple(obj_list)), loc=self._loc)
+
+
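The backward variant differs mainly in that it rounds the block count up to the cluster width, so each cluster owns cluster_shape_mn[0] consecutive blocks and the padded tail must be masked inside the kernel. A small illustration of the rounding and per-CTA block recovery, with arbitrary values:

    # ceil_div rounding as in SingleTileLPTBwdScheduler.Params.create, then the
    # per-CTA block recovery from get_current_work, for a cluster width of 2.
    def ceil_div(a, b):
        return -(-a // b)

    num_block, cluster_m = 9, 2  # K/V blocks and cluster width
    num_cluster_tiles = ceil_div(num_block, cluster_m)  # 5 cluster-wide tiles
    grid = num_cluster_tiles * cluster_m                # grid padded to 10 CTAs

    covered = set()
    for tile_idx in range(grid):
        cluster_idx = tile_idx // cluster_m  # which cluster-wide tile
        lane = tile_idx % cluster_m          # block_in_cluster_idx()[0]
        covered.add(cluster_idx * cluster_m + lane)

    # Blocks 0..9 are covered, but block 9 exceeds num_block - 1 == 8 and must
    # be masked out inside the kernel; the padding keeps clusters aligned.
    assert covered == set(range(grid))
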
+ class SingleTileVarlenScheduler:
+     @dataclass
+     class Params(ParamsBase):
+         num_head: Int32
+         num_batch: Int32
+         total_q: Int32
+         num_splits: Int32
+         max_kvblock_in_l2: Int32
+         tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
+         mCuSeqlensQ: Optional[cute.Tensor] = None
+         mSeqUsedQ: Optional[cute.Tensor] = None
+         qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
+         lpt: cutlass.Constexpr[bool] = False
+         is_split_kv: cutlass.Constexpr[bool] = False
+         head_swizzle: cutlass.Constexpr[bool] = False
+
+         @staticmethod
+         @cute.jit
+         def create(
+             args: TileSchedulerArguments, *, loc=None, ip=None
+         ) -> "SingleTileVarlenScheduler.Params":
+             size_l2 = 50 * 1024 * 1024  # 50 MB for K & V
+             max_kvblock_in_l2 = size_l2 // (
+                 (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
+             )
+             assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
+                 "At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
+             )
+             return SingleTileVarlenScheduler.Params(
+                 num_head=args.num_head,
+                 num_batch=args.num_batch,
+                 total_q=args.total_q,
+                 num_splits=args.num_splits,
+                 max_kvblock_in_l2=max_kvblock_in_l2,
+                 tile_shape_mn=args.tile_shape_mn,
+                 mCuSeqlensQ=args.mCuSeqlensQ,
+                 mSeqUsedQ=args.mSeqUsedQ,
+                 qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,
+                 lpt=args.lpt,
+                 is_split_kv=args.is_split_kv,
+                 head_swizzle=args.head_swizzle,
+             )
+
+     def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
+         self.params = params
+         self._tile_idx = tile_idx
+         self._split_idx = split_idx
+         self._is_first_block = True
+         self._loc = loc
+         self._ip = ip
+
+     @staticmethod
+     def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
+         return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip)
+
+     @staticmethod
+     def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler":
+         tile_idx, split_idx, _ = cute.arch.block_idx()
+         return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
+
+     # called by host
+     @staticmethod
+     def get_grid_shape(
+         params: Params,
+         *,
+         loc=None,
+         ip=None,
+     ) -> Tuple[Int32, Int32, Int32]:
+         total_blocks_max = (
+             params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1)
+         ) // params.tile_shape_mn[0]
+         return (total_blocks_max * params.num_head, params.num_splits, Int32(1))
+
+     @cute.jit
+     def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32:
+         params = self.params
+         batch_idx = lane + bidb_start
+         if cutlass.const_expr(params.mSeqUsedQ is not None):
+             seqlen = Int32(0)
+             if batch_idx < params.num_batch:
+                 seqlen = params.mSeqUsedQ[batch_idx]
+         else:
+             assert params.mCuSeqlensQ is not None
+             cur_cu_seqlen = Int32(0)
+             if batch_idx <= params.num_batch:
+                 cur_cu_seqlen = params.mCuSeqlensQ[batch_idx]
+             next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
+             seqlen = next_cu_seqlen - cur_cu_seqlen
+         if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1):
+             seqlen *= params.qhead_per_kvhead_packgqa
+         return (
+             cute.ceil_div(seqlen, params.tile_shape_mn[0])
+             if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1
+             else Int32(0)
+         )
+
+     @cute.jit
+     def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
+         params = self.params
+         lane_idx = cute.arch.lane_idx()
+         num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
+         num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
+         # Total number of blocks for the next 31 batches
+         m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1)
+         # Same for all lanes
+         group_end_tile = m_blocks_in_group * params.num_head
+         # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group)
+         block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0)
+         next_tile_idx = self._tile_idx
+         while group_end_tile <= next_tile_idx:
+             batch_idx += cute.arch.WARP_SIZE - 1
+             if batch_idx >= params.num_batch:
+                 batch_idx = Int32(params.num_batch)
+                 group_end_tile = next_tile_idx + 1
+             else:
+                 num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx)
+                 num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
+                 m_blocks_in_group = cute.arch.shuffle_sync(
+                     num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1
+                 )
+                 group_end_tile += m_blocks_in_group * params.num_head
+         is_valid = False
+         if batch_idx >= params.num_batch:
+             block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch)
+         else:
+             group_start_tile = group_end_tile - m_blocks_in_group * params.num_head
+             # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx)
+             # The next problem to process is the first one whose ending tile position
+             # is greater than the tile index.
+             batch_idx_in_group = cute.arch.popc(
+                 cute.arch.vote_ballot_sync(
+                     group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx
+                 )
+             )
+             batch_idx += batch_idx_in_group
+             num_m_blocks_prev_lane = (
+                 0
+                 if batch_idx_in_group == 0
+                 else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1)
+             )
+             num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)
+             mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head
+             if cutlass.const_expr(params.lpt or params.head_swizzle):
+                 # This is a version of the SingleTileLPTScheduler, complicated by the fact that
+                 # the seqlen can vary per batch.
+                 # TODO: is there any case where num_m_blocks is 0?
+                 # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here
+                 num_n_blocks = (
+                     num_m_blocks
+                     * params.tile_shape_mn[0]
+                     // params.qhead_per_kvhead_packgqa
+                     // params.tile_shape_mn[1]
+                 )
+                 # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head)
+                 # Seems faster to have this be a power of 2
+                 nheads_in_l2 = (
+                     16
+                     if num_n_blocks * 16 <= params.max_kvblock_in_l2
+                     else (
+                         8
+                         if num_n_blocks * 8 <= params.max_kvblock_in_l2
+                         else (
+                             4
+                             if num_n_blocks * 4 <= params.max_kvblock_in_l2
+                             else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)
+                         )
+                     )
+                 )
+                 nheads_in_l2 = min(nheads_in_l2, params.num_head)
+                 mh_in_l2 = nheads_in_l2 * num_m_blocks
+                 section_idx = mh_block // mh_in_l2
+                 l2_mod = mh_block - section_idx * mh_in_l2
+                 # Deal with tail section
+                 nheads_in_this_section = (
+                     nheads_in_l2
+                     if nheads_in_l2 * (section_idx + 1) <= params.num_head
+                     else params.num_head - section_idx * nheads_in_l2
+                 )
+                 block = l2_mod // nheads_in_this_section
+                 head_idx_residual = l2_mod - block * nheads_in_this_section
+                 head_idx = section_idx * nheads_in_l2 + head_idx_residual
+                 if cutlass.const_expr(params.lpt):
+                     block = num_m_blocks - 1 - block
+             else:
+                 head_idx = mh_block // num_m_blocks
+                 block = mh_block - head_idx * num_m_blocks
+             is_valid = self._is_first_block and batch_idx < params.num_batch
+             # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid)
+         split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
+         return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)
+
+     def initial_work_tile_info(self, *, loc=None, ip=None):
+         return self.get_current_work(loc=loc, ip=ip)
+
+     def prefetch_next_work(self, *, loc=None, ip=None):
+         pass
+
+     def advance_to_next_work(self, *, loc=None, ip=None):
+         # Single tile scheduler - mark the first block as done to indicate no more work
+         self._is_first_block = False
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [self.params, self._tile_idx, self._split_idx]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip(
+             [self.params, self._tile_idx, self._split_idx],
+             self._values_pos,
+         ):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc)
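
With variable sequence lengths the flat tile index cannot be unpacked by a fixed divmod, so get_current_work gives each warp lane one batch, prefix-sums the per-batch m-block counts, and locates the owning batch with a ballot/popcount. A serial pure-Python reference of the same mapping for the non-LPT path (the decode_varlen_tile helper is hypothetical, for illustration only):

    from itertools import accumulate

    def decode_varlen_tile(tile_idx, cu_seqlens, num_head, tile_m):
        """Serial stand-in for the warp ballot search in get_current_work
        (lpt=False); returns (block, head, batch) or None if out of range."""
        seqlens = [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]
        m_blocks = [-(-s // tile_m) for s in seqlens]  # ceil_div per batch
        ends = list(accumulate(m_blocks))              # utils.warp_prefix_sum
        for batch, end in enumerate(ends):
            if tile_idx < end * num_head:              # first end past the tile
                start = end - m_blocks[batch]
                mh_block = tile_idx - start * num_head
                head, block = divmod(mh_block, m_blocks[batch])
                return block, head, batch
        return None

    cu_seqlens, num_head, tile_m = [0, 300, 300, 1000], 2, 128
    tiles = [decode_varlen_tile(i, cu_seqlens, num_head, tile_m) for i in range(18)]
    # m-blocks per batch: [3, 0, 6] -> 9 * num_head = 18 tiles; the empty
    # batch 1 is skipped naturally because its prefix sum does not advance.
    assert None not in tiles and len(set(tiles)) == 18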