mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/attn_bias_utils.py
@@ -0,0 +1,536 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-unsafe
+
+ import math
+ import random
+ from typing import List, Optional, Sequence, Tuple, Type
+
+ import torch
+
+ from .. import fmha
+ from .attn_bias import AttentionBias
+ from .common import AttentionOpBase
+
+
+ def _create_aligned_bias(*shape: int, **kwargs) -> torch.Tensor:
+     align_to = 8
+     return (
+         torch.randn(
+             (
+                 *shape[:-1],
+                 align_to * ((shape[-1] + align_to - 1) // align_to),
+             ),
+             **kwargs,
+         )
+         * 3
+     ).narrow(-1, 0, shape[-1])
+
+
+ def create_attn_bias(  # noqa: C901
+     bias_type,
+     batch_size: int,
+     num_heads: int,
+     num_heads_groups: int,
+     q_len: int,
+     kv_len: int,
+     device,
+     dtype,
+     requires_grad: bool,
+     fmt: str,
+     op: Optional[Type[AttentionOpBase]] = None,
+     page_size: Optional[int] = None,
+ ):
+     if bias_type is None or isinstance(None, bias_type):
+         return None
+     r = random.Random("-".join(map(str, [batch_size, q_len, kv_len, dtype, fmt])))
+     window_size = {0: 3, 1: 128, 2: 300}[r.randint(0, 2)]
+     if bias_type is torch.Tensor:
+         if fmt == "BMK":
+             batch_size *= num_heads
+             num_heads = 1
+         if op is not None and issubclass(op, fmha.triton_splitk.FwOp):
+             attn_bias = (
+                 torch.randn(
+                     (batch_size, num_heads_groups, num_heads, q_len, kv_len),
+                     device=device,
+                     dtype=dtype,
+                 )
+                 * 3
+             )
+             if fmt in ["BMK", "BMHK"]:
+                 attn_bias = attn_bias[:, 0]
+         else:
+             attn_bias = _create_aligned_bias(
+                 batch_size,
+                 num_heads_groups,
+                 num_heads,
+                 q_len,
+                 kv_len,
+                 device=device,
+                 dtype=dtype,
+             )
+
+             # make sure it also works if the first columns/rows are partially masked out
+             attn_bias[0, 0, 0, : q_len - 1, : kv_len - 1] = -math.inf
+             if fmt in ["BMK", "BMHK"]:
+                 attn_bias = attn_bias[:, 0]
+
+         if requires_grad:
+             attn_bias.requires_grad_(True)
+         if fmt == "BMK":
+             attn_bias = attn_bias[:, 0]
+         return attn_bias
+     if bias_type is fmha.attn_bias.LowerTriangularMask:
+         return bias_type()
+     if bias_type is fmha.attn_bias.LowerTriangularFromBottomRightMask:
+         return bias_type()
+     if bias_type is fmha.attn_bias.LowerTriangularFromBottomRightLocalAttentionMask:
+         return bias_type(window_size)
+     if bias_type is fmha.attn_bias.LowerTriangularMaskWithTensorBias:
+         attn_bias = _create_aligned_bias(
+             batch_size,
+             num_heads_groups,
+             num_heads,
+             q_len,
+             kv_len,
+             device=device,
+             dtype=dtype,
+         )
+         if fmt in ["BMK", "BMHK"]:
+             attn_bias = attn_bias[:, 0]
+         if fmt == "BMK":
+             attn_bias = attn_bias[:, 0]
+         if requires_grad:
+             attn_bias.requires_grad_(True)
+         return fmha.attn_bias.LowerTriangularMaskWithTensorBias(attn_bias)
+     if bias_type in [
+         fmha.attn_bias.BlockDiagonalMask,
+         fmha.attn_bias.BlockDiagonalCausalMask,
+         fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
+         fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
+         fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+     ]:
+         # These bias types are not supported in BMK format
+         assert fmt in ["BMGHK", "BMHK"]
+         max_q_minus_k = None
+         if bias_type in {
+             fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
+             fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+         }:
+             max_q_minus_k = 0
+         elif bias_type == fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask:
+             assert window_size is not None
+             max_q_minus_k = window_size - 1
+
+         block_diag = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
+             *_rand_seqlens(
+                 r,
+                 batch_size,
+                 q_len,
+                 kv_len,
+                 max_q_minus_k=max_q_minus_k,
+             )
+         )
+         if bias_type is fmha.attn_bias.BlockDiagonalCausalMask:
+             block_diag = block_diag.make_causal()
+         if bias_type in {
+             fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
+             fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+         }:
+             block_diag = fmha.attn_bias.BlockDiagonalMask(
+                 q_seqinfo=block_diag.q_seqinfo,
+                 k_seqinfo=block_diag.k_seqinfo,
+                 _batch_sizes=block_diag._batch_sizes,
+             )
+             assert window_size is not None
+             if bias_type is fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask:
+                 block_diag = block_diag.make_local_attention(window_size)
+             else:
+                 block_diag = block_diag.make_local_attention_from_bottomright(
+                     window_size
+                 )
+         if bias_type is fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask:
+             block_diag = block_diag.make_causal_from_bottomright()
+         return block_diag
+     if bias_type in [
+         fmha.attn_bias.BlockDiagonalPaddedKeysMask,
+         fmha.attn_bias.BlockDiagonalLocalAttentionPaddedKeysMask,
+         fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+         fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
+         fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
+         fmha.attn_bias.PagedBlockDiagonalCausalLocalPaddedKeysMask,
+         fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+     ]:
+         assert fmt in ["BMHK", "BMGHK"]
+         q, k = _rand_seqlens_padded_k(r, batch_size, q_len, kv_len)
+         block_diag_type = (
+             bias_type._UNPAGED_TYPE
+             if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask)
+             else bias_type
+         )
+         if bias_type in [
+             fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+             fmha.attn_bias.PagedBlockDiagonalCausalLocalPaddedKeysMask,
+         ]:
+             g_block_diag = block_diag_type.from_seqlens_local(  # type: ignore
+                 q_seqlen=q,
+                 kv_padding=kv_len,
+                 kv_seqlen=k,
+                 window_size=min(window_size, min(k)),
+             )
+         elif bias_type is fmha.attn_bias.BlockDiagonalLocalAttentionPaddedKeysMask:
+             g_block_diag = block_diag_type.from_seqlens_local(
+                 q_seqlen=q,
+                 kv_padding=kv_len,
+                 kv_seqlen=k,
+                 window_left=max(window_size, max(q)) + 1,
+                 window_right=max(window_size, max(q)) + 1,
+             )
+         else:
+             g_block_diag = block_diag_type.from_seqlens(  # type: ignore
+                 q_seqlen=q,
+                 kv_padding=kv_len,  # type: ignore
+                 kv_seqlen=k,
+             )
+         if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask):
+             assert page_size is not None
+             pages_per_row = (kv_len + page_size - 1) // page_size
+             block_tables = torch.tensor(
+                 r.sample(range(batch_size * pages_per_row), batch_size * pages_per_row),
+                 device=device,
+                 dtype=torch.int32,
+             ).reshape(batch_size, pages_per_row)
+             return g_block_diag.make_paged(
+                 block_tables=block_tables, page_size=page_size, paged_type=bias_type
+             )
+         return g_block_diag
+     if bias_type in [
+         fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask,
+         fmha.attn_bias.BlockDiagonalGappyKeysMask,
+         fmha.attn_bias.BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask,
+     ]:
+         assert fmt in ["BMHK", "BMGHK"]
+         max_q_minus_k = (
+             None if bias_type is fmha.attn_bias.BlockDiagonalGappyKeysMask else 0
+         )
+         q, k = _rand_seqlens(r, batch_size, q_len, kv_len, max_q_minus_k)
+         total_kv_len = kv_len * batch_size
+         starts = [r.randint(0, total_kv_len - ki) for ki in k] + [total_kv_len]
+         if (
+             bias_type
+             is fmha.attn_bias.BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask
+         ):
+             return bias_type.from_seqlens_local_gappy(
+                 q_seqlen=q,
+                 kv_seqstarts=starts,
+                 kv_seqlen=k,
+                 window_left=r.randint(0, 5),
+                 window_right=r.randint(0, 5),
+                 device=device,
+             )
+
+         return bias_type.from_seqlens(
+             q_seqlen=q,
+             kv_seqstarts=starts,
+             kv_seqlen=k,
+         )
+     if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalGappyKeysMask):
+         assert fmt in ["BMHK", "BMGHK"]
+         assert page_size is not None
+         pages_per_row = (kv_len + page_size - 1) // page_size
+         total_queries = q_len * batch_size
+         if issubclass(
+             bias_type, fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetGappyKeysMask
+         ):
+             q, k = _rand_seqlens_padded_k(r, batch_size, q_len, kv_len)
+         else:
+             q = _rand_maxed_partition(
+                 r, total_queries, batch_size, total_queries, False
+             )
+             k = [r.randint(1, kv_len) for _ in range(batch_size)]
+         row_size = pages_per_row * page_size
+         starts = [row_size * i + r.randint(0, row_size - ki) for i, ki in enumerate(k)]
+         starts.append(pages_per_row * batch_size * page_size)
+         block_diag_type = bias_type._UNPAGED_TYPE  # type: ignore
+         g_block_diag = block_diag_type.from_seqlens(
+             q_seqlen=q,
+             kv_seqstarts=starts,
+             kv_seqlen=k,
+         )
+         block_tables = torch.tensor(
+             r.sample(range(batch_size * pages_per_row), batch_size * pages_per_row),
+             device=device,
+             dtype=torch.int32,
+         ).reshape(batch_size, pages_per_row)
+         return g_block_diag.make_paged(
+             block_tables=block_tables,
+             page_size=page_size,
+             paged_type=bias_type,
+             notional_padding=page_size * pages_per_row,
+         )
+     if bias_type == fmha.attn_bias.LocalAttentionFromBottomRightMask:
+         return bias_type(
+             window_left=r.randint(0, 5),
+             window_right=r.randint(0, 5),
+         )
+
+     raise AssertionError(f"Unsupported bias type: {bias_type}")
+
+
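As an illustration of the dispatcher above, a usage sketch follows. These lines are not part of the packaged file; the module path is assumed from the file list, and the printed type is only indicative.

import torch

from mslk.attention import fmha
from mslk.attention.fmha.attn_bias_utils import create_attn_bias  # assumed module path

# BlockDiagonalCausalMask goes through the _rand_seqlens branch and returns a
# block-diagonal mask object carrying q/k sequence metadata, not a dense tensor.
bias = create_attn_bias(
    fmha.attn_bias.BlockDiagonalCausalMask,
    batch_size=2,
    num_heads=4,
    num_heads_groups=1,
    q_len=32,
    kv_len=32,
    device="cpu",
    dtype=torch.float16,
    requires_grad=False,
    fmt="BMHK",
)
print(type(bias))  # a block-diagonal causal mask object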
+ def _rand_seqlens(
+     r: random.Random,
+     bs: int,
+     q_len: int,
+     kv_len: int,
+     max_q_minus_k: Optional[int],
+ ) -> Tuple[Sequence[int], Sequence[int]]:
+     """
+     Generates lists of lengths of query blocks and corresponding key blocks.
+     The total number of queries will be bs * q_len and the
+     total number of keys will be bs * kv_len.
+     max_q_minus_k: maximum allowed num_queries - num_keys.
+         For "bottom-right" masks it's 0: we need at least as many keys as
+         queries, otherwise some queries have no keys to attend to.
+         For BlockDiagonalCausalMask it's None, there is no constraint
+         on num_queries - num_keys.
+         For BlockDiagonalCausalLocalAttentionMask it's equal
+         to the window size.
+     """
+     if max_q_minus_k == 0:
+         # In case max_q_minus_k > 0 the exact condition is
+         # kv_len >= q_len - max_q_minus_k * batch_size,
+         # but we can't check it without knowing the actual batch size,
+         # which is determined in the loop below.
+         assert kv_len >= q_len
+     q_len *= bs
+     kv_len *= bs
+     seqlens_q: List[int] = []
+     seqlens_k: List[int] = []
+
+     step_q = [max(1, q_len // 10), max(2, q_len // 2)]
+     step_k = [max(1, kv_len // 10), max(2, kv_len // 2)]
+     while sum(seqlens_q) < q_len and sum(seqlens_k) < kv_len:
+         if max_q_minus_k is None:
+             # Simple case - no constraint on the number of queries and keys.
+             num_queries = r.randrange(*step_q)
+             seqlens_q.append(num_queries)
+             seqlens_k.append(r.randrange(*step_k))
+         else:
+             # In this case we need to make sure num_queries - num_keys < max_q_minus_k holds for every batch element.
+             # To do this, when choosing num_queries and num_keys at a given step,
+             # we ensure two conditions are satisfied:
+             # 1) num_queries <= num_keys + max_q_minus_k for the current batch element
+             # 2) Same holds for the remaining keys and queries, i.e.
+             #    queries_left - num_queries <= keys_left - num_keys + max_q_minus_k
+             keys_left = kv_len - sum(seqlens_k, 0)
+             queries_left = q_len - sum(seqlens_q, 0)
+
+             assert keys_left >= queries_left - max_q_minus_k, (
+                 f"{keys_left=} {queries_left=} {max_q_minus_k=} {kv_len=} {q_len=} {seqlens_k=} {seqlens_q=}"
+             )
+             # Limit num_queries from above: if num_queries > keys_left + max_q_minus_k,
+             # condition num_queries <= num_keys + max_q_minus_k can't be satisfied even if we take
+             # all the remaining keys
+             max_queries_to_take = min(queries_left, keys_left + max_q_minus_k)
+             num_queries = r.randrange(1, max_queries_to_take + 1)
+             seqlens_q.append(num_queries)
+
+             # Now we know num_queries, let's select num_keys.
+             # How many keys can we use for the current batch element so that
+             # for the remaining keys and values the constraint
+             # num_queries - num_keys < max_q_minus_k holds on the next step?
+             extra_keys_available = keys_left - queries_left + max_q_minus_k + 1
+             assert extra_keys_available >= 0
+             if extra_keys_available > 0:
+                 seqlens_k.append(num_queries + r.randrange(0, extra_keys_available))
+             else:
+                 seqlens_k.append(num_queries)
+     seqlens_q[-1] = q_len - sum(seqlens_q[:-1])
+     seqlens_k[-1] = kv_len - sum(seqlens_k[:-1])
+     return seqlens_q, seqlens_k
+
+
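The docstring above promises that the generated blocks tile exactly bs * q_len queries and bs * kv_len keys. A quick sanity sketch, calling the private helper directly under the same assumed module path (not part of the packaged file):

import random

from mslk.attention.fmha.attn_bias_utils import _rand_seqlens  # assumed module path

r = random.Random(0)
seqlens_q, seqlens_k = _rand_seqlens(r, bs=3, q_len=16, kv_len=32, max_q_minus_k=0)
# One query block and one key block per "batch element", tiling the totals exactly.
assert len(seqlens_q) == len(seqlens_k)
assert sum(seqlens_q) == 3 * 16
assert sum(seqlens_k) == 3 * 32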
+ def _rand_maxed_partition(
+     r: random.Random, total: int, n: int, mx: int, positive: bool = True
+ ) -> List[int]:
+     # returns list of n nonnegative integers less than mx summing to total
+     # NB: This is unfortunately biased towards evenly-split bins.
+     # If `positive`, outputs are positive
+     if positive:
+         total -= n
+         mx -= 1
+     idxs = r.sample(range(n * mx), total)
+     y = torch.zeros(n, mx, dtype=torch.int32)
+     y.flatten()[idxs] = 1
+     z = y.sum(1)
+     if positive:
+         z += 1
+     return z.tolist()
+
+
+ def _rand_seqlens_padded_k(
+     r: random.Random, bs: int, q_len: int, kv_len: int
+ ) -> Tuple[Sequence[int], Sequence[int]]:
+     # This is for BlockDiagonalCausalWithOffsetPaddedKeysMask.
+     # We need q_seqlens and k_seqlens to be of len bsz.
+     # For each "batch element" there must be at least as many keys as queries,
+     # because this bias type is "bottom right" and so any extra queries
+     # will attend to nothing and have undefined result.
+     # In addition every element of k_seqlens must be <= kv_len
+     if q_len > kv_len:
+         raise ValueError("need at least as many keys as queries")
+     if q_len == kv_len:
+         # all key slots are needed so we cannot have padding
+         q_seqlens = k_seqlens = [kv_len] * bs
+     else:
+         q_seqlens = _rand_maxed_partition(r, q_len * bs, bs, kv_len)
+         k_seqlens = [r.randint(i, kv_len) for i in q_seqlens]
+     return q_seqlens, k_seqlens
+
+
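A small check of the partition helper's contract, again a sketch under the same assumed import path: with positive=True each of the n entries is at least 1, at most mx, and the entries sum to total.

import random

from mslk.attention.fmha.attn_bias_utils import _rand_maxed_partition  # assumed module path

r = random.Random(0)
parts = _rand_maxed_partition(r, total=10, n=4, mx=5)
assert len(parts) == 4
assert sum(parts) == 10
assert all(1 <= p <= 5 for p in parts)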
+ def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None):
+     if q.ndim == 5:
+
+         def attn_bias_group(group: int):
+             if isinstance(attn_bias, torch.Tensor):
+                 return attn_bias[:, group]
+             if isinstance(attn_bias, fmha.attn_bias.LowerTriangularMaskWithTensorBias):
+                 return fmha.attn_bias.LowerTriangularMaskWithTensorBias(
+                     attn_bias._bias[:, group]
+                 )
+             return attn_bias
+
+         return torch.stack(
+             [
+                 ref_attention_bmhk(
+                     q[:, :, g],
+                     k[:, :, g],
+                     v[:, :, g],
+                     scale=scale,
+                     attn_bias=attn_bias_group(g),
+                 )
+                 for g in range(q.shape[2])
+             ],
+             dim=2,
+         )
+     if q.ndim == 4:
+         assert p == 0.0
+         return ref_attention_bmhk(q, k, v, scale=scale, attn_bias=attn_bias)
+     q = q.float()
+     k = k.float()
+     v = v.float()
+
+     scale = scale if scale is not None else (1 / q.shape[-1] ** 0.5)
+     q = q * scale
+
+     attn = q @ k.transpose(-2, -1)
+     if attn_bias is not None:
+         if isinstance(attn_bias, AttentionBias):
+             # Always create in B,H,Mq,Mk format
+             attn_bias_tensor = attn_bias.materialize(
+                 (q.shape[0], 1, q.shape[1], k.shape[1]),
+                 device=q.device,
+                 dtype=torch.float32,
+             )
+         else:
+             attn_bias_tensor = attn_bias
+         if attn_bias_tensor.ndim == 4:
+             assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1]
+             attn_bias_tensor = attn_bias_tensor.reshape(
+                 [-1, *attn_bias_tensor.shape[2:]]
+             )
+         attn = attn + attn_bias_tensor.float()
+     attn = attn.softmax(-1)
+     if drop_mask is not None:
+         attn = attn * (drop_mask / (1 - p))
+     return attn @ v
+
+
+ def ref_attention_bmhk(q, k, v, attn_bias, scale=None) -> torch.Tensor:
+     assert q.ndim == 4
+
+     def T(t):
+         return t.permute((0, 2, 1, 3)).reshape(
+             [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
+         )
+
+     if isinstance(attn_bias, AttentionBias):
+         attn_bias = attn_bias.materialize(
+             (q.shape[0], q.shape[2], q.shape[1], k.shape[1]),
+             device=q.device,
+             dtype=torch.float32,
+         ).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]])
+     out = ref_attention(T(q), T(k), T(v), attn_bias, scale=scale)
+     out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
+     return out.permute((0, 2, 1, 3))
+
+
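For plain BMHK inputs with no bias, ref_attention reduces to softmax(q k^T / sqrt(d)) v, so it should agree with PyTorch's scaled_dot_product_attention up to floating-point noise. A comparison sketch, not part of the packaged file (import path assumed; SDPA expects a B, H, M, K layout, hence the transposes):

import torch

from mslk.attention.fmha.attn_bias_utils import ref_attention  # assumed module path

q, k, v = (torch.randn(2, 64, 4, 32) for _ in range(3))  # B, M, H, K
out_ref = ref_attention(q, k, v)
out_sdpa = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
torch.testing.assert_close(out_ref, out_sdpa, atol=1e-4, rtol=1e-4)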
+ def pack_kv_cache(
+     cache_k: torch.Tensor,
+     cache_v: torch.Tensor,
+     kv_seqlens: List[int],
+     BLOCK_N: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Create block tables and a paged K/V cache for testing paged attention.
+     Args:
+         cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D].
+             Note that these tensors are unexpanded,
+             i.e. for the multiquery case cache_k.shape[2] = 1
+         kv_seqlens: list of K/V sequence lengths
+         BLOCK_N: number of tokens per paged attention block
+         B: batch size
+     Returns:
+         block_tables: [B, MAX_BLOCKS]
+         packed_cache_k: [1, total_len_rounded, H_kv, D]
+         packed_cache_v: [1, total_len_rounded, H_kv, D]
+         where total_len_rounded is a sum of K/V seqlens, each rounded up
+         to a multiple of BLOCK_N.
+     """
+
+     kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens]
+
+     total_len_rounded = sum(kv_seqlens_rounded)
+
+     B, MAX_T, H, D = cache_k.shape
+
+     packed_cache_k = torch.empty(
+         total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
+     )
+     packed_cache_v = torch.empty(
+         total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
+     )
+     seqstart = 0
+     for b in range(B):
+         packed_cache_k[seqstart : seqstart + kv_seqlens[b]] = cache_k[
+             b, : kv_seqlens[b]
+         ].clone()
+         packed_cache_v[seqstart : seqstart + kv_seqlens[b]] = cache_v[
+             b, : kv_seqlens[b]
+         ].clone()
+         seqstart += kv_seqlens_rounded[b]
+
+     num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N
+     block_tables = (
+         torch.arange(num_blocks_per_row, device="cuda", dtype=torch.int32)
+         .unsqueeze(0)
+         .expand(B, num_blocks_per_row)
+     )
+     seqstarts = (
+         (
+             torch.tensor(kv_seqlens_rounded).cumsum(dim=0)
+             - torch.tensor(kv_seqlens_rounded)
+         )
+         .to(device="cuda")
+         .unsqueeze(1)
+     ) // BLOCK_N
+     block_tables = (block_tables + seqstarts).contiguous().to(dtype=torch.int32)
+     return (
+         block_tables,
+         packed_cache_k.unsqueeze(0),
+         packed_cache_v.unsqueeze(0),
+     )
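To make the shapes in the docstring concrete, a small usage sketch (not part of the packaged file; import path assumed; a CUDA device is needed because block_tables is built on "cuda"):

import torch

from mslk.attention.fmha.attn_bias_utils import pack_kv_cache  # assumed module path

B, MAX_T, H_kv, D, BLOCK_N = 2, 32, 1, 16, 16
cache_k = torch.randn(B, MAX_T, H_kv, D, device="cuda")
cache_v = torch.randn(B, MAX_T, H_kv, D, device="cuda")
kv_seqlens = [20, 9]  # rounded up per row to 32 and 16, so 48 packed slots in total

block_tables, packed_k, packed_v = pack_kv_cache(cache_k, cache_v, kv_seqlens, BLOCK_N)
assert packed_k.shape == (1, 48, H_kv, D)
assert packed_v.shape == (1, 48, H_kv, D)
assert block_tables.shape == (B, (MAX_T + BLOCK_N - 1) // BLOCK_N)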