mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-unsafe
|
|
7
|
+
|
|
8
|
+
import math
|
|
9
|
+
import random
|
|
10
|
+
from typing import List, Optional, Sequence, Tuple, Type
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from .. import fmha
|
|
15
|
+
from .attn_bias import AttentionBias
|
|
16
|
+
from .common import AttentionOpBase
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _create_aligned_bias(*shape: int, **kwargs) -> torch.Tensor:
    """Return a random bias of ``shape`` backed by a last-dim-padded buffer.

    A tensor whose last dimension is rounded up to the next multiple of 8 is
    sampled with ``torch.randn`` and multiplied by 3; the result is then
    narrowed back to ``shape[-1]`` entries along the last dimension. The
    returned tensor is therefore a view into the padded buffer and keeps the
    padded buffer's strides.

    Args:
        *shape: Sizes of the returned tensor (at least one dimension).
        **kwargs: Forwarded unchanged to ``torch.randn`` (e.g. ``device``,
            ``dtype``).
    """
    align_to = 8
    last_dim = shape[-1]
    # Ceiling division by `align_to`, scaled back up to get the padded width.
    padded_last_dim = -(-last_dim // align_to) * align_to
    padded = torch.randn((*shape[:-1], padded_last_dim), **kwargs)
    scaled = padded * 3
    return scaled.narrow(-1, 0, last_dim)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def create_attn_bias(  # noqa: C901
    bias_type,
    batch_size: int,
    num_heads: int,
    num_heads_groups: int,
    q_len: int,
    kv_len: int,
    device,
    dtype,
    requires_grad: bool,
    fmt: str,
    op: Optional[Type[AttentionOpBase]] = None,
    page_size: Optional[int] = None,
):
    """Build a randomized attention bias of the requested type.

    Sequence lengths, window sizes and block tables are drawn from a
    ``random.Random`` seeded with a string built from ``batch_size``,
    ``q_len``, ``kv_len``, ``dtype`` and ``fmt``. Dense tensor biases are
    sampled with ``torch.randn`` instead.

    Args:
        bias_type: ``None``, ``torch.Tensor`` or one of the
            ``fmha.attn_bias`` classes handled below.
        batch_size: Batch size used for tensor shapes and for the total
            number of queries/keys of block-diagonal masks.
        num_heads: Number of heads (folded into the batch for ``"BMK"``
            tensor biases).
        num_heads_groups: Size of the groups dimension of tensor biases.
        q_len: Query length per batch element.
        kv_len: Key/value length per batch element.
        device: Device for created tensors.
        dtype: Dtype for created tensor biases.
        requires_grad: Whether tensor biases should require grad.
        fmt: Layout string; the code checks for ``"BMK"``, ``"BMHK"`` and
            ``"BMGHK"``.
        op: Optional attention operator; only used to detect
            ``fmha.triton_splitk.FwOp`` subclasses for tensor biases.
        page_size: Page size; asserted to be set for paged bias types.

    Returns:
        ``None``, a ``torch.Tensor`` or an attention bias object, depending
        on ``bias_type``.

    Raises:
        AssertionError: If ``bias_type`` is not handled, if ``fmt`` is not
            accepted for the chosen bias type, or if ``page_size`` is
            missing for a paged bias type.
    """
    # `isinstance(None, bias_type)` is true when bias_type is a type that
    # None is an instance of (e.g. type(None)).
    if bias_type is None or isinstance(None, bias_type):
        return None
    r = random.Random("-".join(map(str, [batch_size, q_len, kv_len, dtype, fmt])))
    window_size = {0: 3, 1: 128, 2: 300}[r.randint(0, 2)]
    if bias_type is torch.Tensor:
        if fmt == "BMK":
            # Fold the heads into the batch dimension.
            batch_size *= num_heads
            num_heads = 1
        if op is not None and issubclass(op, fmha.triton_splitk.FwOp):
            # This operator gets a plain (not padded) random bias.
            attn_bias = (
                torch.randn(
                    (batch_size, num_heads_groups, num_heads, q_len, kv_len),
                    device=device,
                    dtype=dtype,
                )
                * 3
            )
            if fmt in ["BMK", "BMHK"]:
                # Drop the groups dimension.
                attn_bias = attn_bias[:, 0]
        else:
            attn_bias = _create_aligned_bias(
                batch_size,
                num_heads_groups,
                num_heads,
                q_len,
                kv_len,
                device=device,
                dtype=dtype,
            )

            # make sure it also works if the first columns/rows are partially masked out
            attn_bias[0, 0, 0, : q_len - 1, : kv_len - 1] = -math.inf
            if fmt in ["BMK", "BMHK"]:
                # Drop the groups dimension.
                attn_bias = attn_bias[:, 0]

        if requires_grad:
            attn_bias.requires_grad_(True)
        if fmt == "BMK":
            # Drop the (size-1) heads dimension as well.
            attn_bias = attn_bias[:, 0]
        return attn_bias
    if bias_type is fmha.attn_bias.LowerTriangularMask:
        return bias_type()
    if bias_type is fmha.attn_bias.LowerTriangularFromBottomRightMask:
        return bias_type()
    if bias_type is fmha.attn_bias.LowerTriangularFromBottomRightLocalAttentionMask:
        return bias_type(window_size)
    if bias_type is fmha.attn_bias.LowerTriangularMaskWithTensorBias:
        attn_bias = _create_aligned_bias(
            batch_size,
            num_heads_groups,
            num_heads,
            q_len,
            kv_len,
            device=device,
            dtype=dtype,
        )
        if fmt in ["BMK", "BMHK"]:
            attn_bias = attn_bias[:, 0]
        if fmt == "BMK":
            attn_bias = attn_bias[:, 0]
        if requires_grad:
            attn_bias.requires_grad_(True)
        return fmha.attn_bias.LowerTriangularMaskWithTensorBias(attn_bias)
    if bias_type in [
        fmha.attn_bias.BlockDiagonalMask,
        fmha.attn_bias.BlockDiagonalCausalMask,
        fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
        fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
        fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
    ]:
        # These bias types are not supported in BMK format
        assert fmt in ["BMGHK", "BMHK"]
        # Bound on (num_queries - num_keys) handed to _rand_seqlens;
        # None means unconstrained.
        max_q_minus_k = None
        if bias_type in {
            fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
            fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        }:
            max_q_minus_k = 0
        elif bias_type == fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask:
            assert window_size is not None
            max_q_minus_k = window_size - 1

        block_diag = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            *_rand_seqlens(
                r,
                batch_size,
                q_len,
                kv_len,
                max_q_minus_k=max_q_minus_k,
            )
        )
        if bias_type is fmha.attn_bias.BlockDiagonalCausalMask:
            block_diag = block_diag.make_causal()
        if bias_type in {
            fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
            fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        }:
            block_diag = fmha.attn_bias.BlockDiagonalMask(
                q_seqinfo=block_diag.q_seqinfo,
                k_seqinfo=block_diag.k_seqinfo,
                _batch_sizes=block_diag._batch_sizes,
            )
            assert window_size is not None
            if bias_type is fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask:
                block_diag = block_diag.make_local_attention(window_size)
            else:
                block_diag = block_diag.make_local_attention_from_bottomright(
                    window_size
                )
        if bias_type is fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask:
            block_diag = block_diag.make_causal_from_bottomright()
        return block_diag
    if bias_type in [
        fmha.attn_bias.BlockDiagonalPaddedKeysMask,
        fmha.attn_bias.BlockDiagonalLocalAttentionPaddedKeysMask,
        fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask,
        fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
        fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
        fmha.attn_bias.PagedBlockDiagonalCausalLocalPaddedKeysMask,
        fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
    ]:
        assert fmt in ["BMHK", "BMGHK"]
        q, k = _rand_seqlens_padded_k(r, batch_size, q_len, kv_len)
        # Paged types are built from their unpaged counterpart and converted
        # with make_paged() below.
        block_diag_type = (
            bias_type._UNPAGED_TYPE
            if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask)
            else bias_type
        )
        if bias_type in [
            fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask,
            fmha.attn_bias.PagedBlockDiagonalCausalLocalPaddedKeysMask,
        ]:
            g_block_diag = block_diag_type.from_seqlens_local(  # type: ignore
                q_seqlen=q,
                kv_padding=kv_len,
                kv_seqlen=k,
                window_size=min(window_size, min(k)),
            )
        elif bias_type is fmha.attn_bias.BlockDiagonalLocalAttentionPaddedKeysMask:
            g_block_diag = block_diag_type.from_seqlens_local(
                q_seqlen=q,
                kv_padding=kv_len,
                kv_seqlen=k,
                window_left=max(window_size, max(q)) + 1,
                window_right=max(window_size, max(q)) + 1,
            )
        else:
            g_block_diag = block_diag_type.from_seqlens(  # type: ignore
                q_seqlen=q,
                kv_padding=kv_len,  # type: ignore
                kv_seqlen=k,
            )
        if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask):
            assert page_size is not None
            pages_per_row = (kv_len + page_size - 1) // page_size
            # Random permutation of all page indices, one row per batch element.
            block_tables = torch.tensor(
                r.sample(range(batch_size * pages_per_row), batch_size * pages_per_row),
                device=device,
                dtype=torch.int32,
            ).reshape(batch_size, pages_per_row)
            return g_block_diag.make_paged(
                block_tables=block_tables, page_size=page_size, paged_type=bias_type
            )
        return g_block_diag
    if bias_type in [
        fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask,
        fmha.attn_bias.BlockDiagonalGappyKeysMask,
        fmha.attn_bias.BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask,
    ]:
        assert fmt in ["BMHK", "BMGHK"]
        max_q_minus_k = (
            None if bias_type is fmha.attn_bias.BlockDiagonalGappyKeysMask else 0
        )
        q, k = _rand_seqlens(r, batch_size, q_len, kv_len, max_q_minus_k)
        total_kv_len = kv_len * batch_size
        # Each key block gets an independent random start that keeps it inside
        # [0, total_kv_len]; the final entry is the total length.
        starts = [r.randint(0, total_kv_len - ki) for ki in k] + [total_kv_len]
        if (
            bias_type
            is fmha.attn_bias.BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask
        ):
            return bias_type.from_seqlens_local_gappy(
                q_seqlen=q,
                kv_seqstarts=starts,
                kv_seqlen=k,
                window_left=r.randint(0, 5),
                window_right=r.randint(0, 5),
                device=device,
            )

        return bias_type.from_seqlens(
            q_seqlen=q,
            kv_seqstarts=starts,
            kv_seqlen=k,
        )
    if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalGappyKeysMask):
        assert fmt in ["BMHK", "BMGHK"]
        assert page_size is not None
        pages_per_row = (kv_len + page_size - 1) // page_size
        total_queries = q_len * batch_size
        if issubclass(
            bias_type, fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetGappyKeysMask
        ):
            q, k = _rand_seqlens_padded_k(r, batch_size, q_len, kv_len)
        else:
            q = _rand_maxed_partition(
                r, total_queries, batch_size, total_queries, False
            )
            k = [r.randint(1, kv_len) for _ in range(batch_size)]
        row_size = pages_per_row * page_size
        # Batch element i gets a random start inside its own row of
        # `row_size` slots.
        starts = [row_size * i + r.randint(0, row_size - ki) for i, ki in enumerate(k)]
        starts.append(pages_per_row * batch_size * page_size)
        block_diag_type = bias_type._UNPAGED_TYPE  # type: ignore
        g_block_diag = block_diag_type.from_seqlens(
            q_seqlen=q,
            kv_seqstarts=starts,
            kv_seqlen=k,
        )
        # Random permutation of all page indices, one row per batch element.
        block_tables = torch.tensor(
            r.sample(range(batch_size * pages_per_row), batch_size * pages_per_row),
            device=device,
            dtype=torch.int32,
        ).reshape(batch_size, pages_per_row)
        return g_block_diag.make_paged(
            block_tables=block_tables,
            page_size=page_size,
            paged_type=bias_type,
            notional_padding=page_size * pages_per_row,
        )
    if bias_type == fmha.attn_bias.LocalAttentionFromBottomRightMask:
        return bias_type(
            window_left=r.randint(0, 5),
            window_right=r.randint(0, 5),
        )

    raise AssertionError(f"Unsupported bias type: {bias_type}")
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _rand_seqlens(
    r: random.Random,
    bs: int,
    q_len: int,
    kv_len: int,
    max_q_minus_k: Optional[int],
) -> Tuple[Sequence[int], Sequence[int]]:
    """
    Randomly split the queries and keys of a batch into matching blocks.

    Returns two lists of equal length: the number of queries and the number
    of keys of each block. The query counts sum to ``bs * q_len`` and the key
    counts sum to ``bs * kv_len``.

    max_q_minus_k: upper bound on ``num_queries - num_keys`` of every block.
        ``None`` means the two counts are drawn independently.
        "Bottom-right" masks use 0: each block needs at least as many keys as
        queries, otherwise some queries have no keys to attend to.
        For BlockDiagonalCausalLocalAttentionMask the caller derives the
        bound from the window size.
    """
    if max_q_minus_k == 0:
        # For max_q_minus_k > 0 the exact requirement would be
        # kv_len >= q_len - max_q_minus_k * batch_size, but the number of
        # blocks is only known once the loop below has run, so only the
        # zero case is checked up front.
        assert kv_len >= q_len
    q_len *= bs
    kv_len *= bs
    seqlens_q: List[int] = []
    seqlens_k: List[int] = []

    q_lo, q_hi = max(1, q_len // 10), max(2, q_len // 2)
    k_lo, k_hi = max(1, kv_len // 10), max(2, kv_len // 2)
    # Running totals of what has been handed out so far.
    used_q = 0
    used_k = 0
    while used_q < q_len and used_k < kv_len:
        if max_q_minus_k is None:
            # Unconstrained: draw both block sizes independently.
            take_q = r.randrange(q_lo, q_hi)
            take_k = r.randrange(k_lo, k_hi)
        else:
            # Constrained: keep take_q <= take_k + max_q_minus_k for this
            # block, and keep the same inequality satisfiable for whatever
            # remains afterwards.
            keys_left = kv_len - used_k
            queries_left = q_len - used_q

            assert keys_left >= queries_left - max_q_minus_k, (
                f"{keys_left=} {queries_left=} {max_q_minus_k=} {kv_len=} {q_len=} {seqlens_k=} {seqlens_q=}"
            )
            # Taking more than keys_left + max_q_minus_k queries could not be
            # matched even by using every remaining key.
            take_q = r.randrange(
                1, min(queries_left, keys_left + max_q_minus_k) + 1
            )

            # Number of choices for how many keys beyond take_q this block
            # may use while leaving enough keys for the remaining queries.
            extra_keys_available = keys_left - queries_left + max_q_minus_k + 1
            assert extra_keys_available >= 0
            take_k = take_q
            if extra_keys_available > 0:
                take_k += r.randrange(0, extra_keys_available)
        seqlens_q.append(take_q)
        seqlens_k.append(take_k)
        used_q += take_q
        used_k += take_k
    # The last block absorbs whatever is needed to hit the exact totals.
    seqlens_q[-1] = q_len - sum(seqlens_q[:-1])
    seqlens_k[-1] = kv_len - sum(seqlens_k[:-1])
    return seqlens_q, seqlens_k
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def _rand_maxed_partition(
    r: random.Random, total: int, n: int, mx: int, positive: bool = True
) -> List[int]:
    """Randomly split ``total`` into ``n`` integer parts, each at most ``mx``.

    Every part owns a fixed number of unit slots; ``r.sample`` picks which
    slots are filled and a part's value is the number of its filled slots.
    With ``positive`` set, one unit per part is reserved up front, so every
    part is at least 1.

    NB: this is biased towards evenly split parts.
    """
    reserved = 1 if positive else 0
    slots_per_part = mx - reserved
    units_to_place = total - n * reserved
    chosen = r.sample(range(n * slots_per_part), units_to_place)
    slots = torch.zeros(n * slots_per_part, dtype=torch.int32)
    slots[chosen] = 1
    # Row i of the (n, slots_per_part) view holds the slots of part i.
    per_part = slots.view(n, slots_per_part).sum(dim=1) + reserved
    return per_part.tolist()
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def _rand_seqlens_padded_k(
    r: random.Random, bs: int, q_len: int, kv_len: int
) -> Tuple[Sequence[int], Sequence[int]]:
    """Draw per-batch-element query and key lengths for padded-keys masks.

    Both returned lists have ``bs`` entries. Every key length is at least
    the matching query length and at most ``kv_len``; the query lengths sum
    to ``bs * q_len``.

    The original comment motivates this with
    BlockDiagonalCausalWithOffsetPaddedKeysMask: that bias type is
    "bottom right", so any queries in excess of the keys would attend to
    nothing and have an undefined result.

    Raises:
        ValueError: If ``q_len > kv_len``, since the constraints above could
            not be met.
    """
    if q_len > kv_len:
        # The previous message ("need more queries than keys") stated the
        # opposite of the condition being enforced here.
        raise ValueError("need at least as many keys as queries")
    if q_len == kv_len:
        # all key slots are needed so we cannot have padding
        # Two separate lists, so mutating one result cannot alter the other.
        return [kv_len] * bs, [kv_len] * bs
    q_seqlens = _rand_maxed_partition(r, q_len * bs, bs, kv_len)
    k_seqlens = [r.randint(i, kv_len) for i in q_seqlens]
    return q_seqlens, k_seqlens
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None):
    """Reference attention computed with plain PyTorch ops.

    Dispatches on ``q.ndim``:

    * 5: the third dimension is looped over; each slice goes through
      ``ref_attention_bmhk`` and the results are stacked back on dim 2.
      Tensor biases (and the tensor inside a
      ``LowerTriangularMaskWithTensorBias``) are indexed along dim 1 per
      slice. ``drop_mask`` and ``p`` are not used on this path.
    * 4: forwarded to ``ref_attention_bmhk``; ``p`` must be 0.0.
    * otherwise: computes ``softmax(q * scale @ k^T + bias) @ v`` in float32,
      then applies ``drop_mask / (1 - p)`` to the probabilities if
      ``drop_mask`` is given.

    ``scale`` defaults to ``1 / sqrt(q.shape[-1])``.
    """
    if q.ndim == 5:

        def _bias_for_group(group: int):
            # Per-slice view of the bias; non-tensor biases are shared as is.
            if isinstance(attn_bias, torch.Tensor):
                return attn_bias[:, group]
            if isinstance(attn_bias, fmha.attn_bias.LowerTriangularMaskWithTensorBias):
                return fmha.attn_bias.LowerTriangularMaskWithTensorBias(
                    attn_bias._bias[:, group]
                )
            return attn_bias

        per_group = []
        for g in range(q.shape[2]):
            per_group.append(
                ref_attention_bmhk(
                    q[:, :, g],
                    k[:, :, g],
                    v[:, :, g],
                    scale=scale,
                    attn_bias=_bias_for_group(g),
                )
            )
        return torch.stack(per_group, dim=2)

    if q.ndim == 4:
        assert p == 0.0
        return ref_attention_bmhk(q, k, v, scale=scale, attn_bias=attn_bias)

    q, k, v = q.float(), k.float(), v.float()

    if scale is None:
        scale = 1 / q.shape[-1] ** 0.5
    scores = (q * scale) @ k.transpose(-2, -1)

    if attn_bias is not None:
        if isinstance(attn_bias, AttentionBias):
            # Always create in B,H,Mq,Mk format
            bias = attn_bias.materialize(
                (q.shape[0], 1, q.shape[1], k.shape[1]),
                device=q.device,
                dtype=torch.float32,
            )
        else:
            bias = attn_bias
        if bias.ndim == 4:
            # Merge the two leading dims so the bias lines up with q's batch.
            assert q.shape[0] == bias.shape[0] * bias.shape[1]
            bias = bias.reshape([-1, *bias.shape[2:]])
        scores = scores + bias.float()

    probs = scores.softmax(-1)
    if drop_mask is not None:
        probs = probs * (drop_mask / (1 - p))
    return probs @ v
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def ref_attention_bmhk(q, k, v, attn_bias, scale=None) -> torch.Tensor:
    """Reference attention for 4-D inputs, delegating to ``ref_attention``.

    Dim 2 of ``q``, ``k`` and ``v`` is folded into dim 0, the resulting 3-D
    tensors are passed to ``ref_attention``, and the output is unfolded back
    to the input layout. An ``AttentionBias`` is first materialized as a
    float32 tensor and folded the same way; any other bias is passed through
    unchanged.
    """
    assert q.ndim == 4
    batch, q_tokens, heads = q.shape[0], q.shape[1], q.shape[2]
    kv_tokens = k.shape[1]

    def _fold_heads(t):
        # (d0, d1, d2, d3) -> (d0 * d2, d1, d3)
        swapped = t.permute((0, 2, 1, 3))
        return swapped.reshape([t.shape[0] * t.shape[2], t.shape[1], t.shape[3]])

    if isinstance(attn_bias, AttentionBias):
        dense_bias = attn_bias.materialize(
            (batch, heads, q_tokens, kv_tokens),
            device=q.device,
            dtype=torch.float32,
        )
        attn_bias = dense_bias.reshape([batch * heads, q_tokens, kv_tokens])

    folded_out = ref_attention(
        _fold_heads(q), _fold_heads(k), _fold_heads(v), attn_bias, scale=scale
    )
    unfolded = folded_out.reshape([batch, heads, q_tokens, v.shape[3]])
    return unfolded.permute((0, 2, 1, 3))
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def pack_kv_cache(
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    kv_seqlens: List[int],
    BLOCK_N: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create block tables and paged K/V caches for testing paged attention.

    Args:
        cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D].
            Note that these tensors are unexpanded,
            i.e. for multiquery case cache_k.shape[2] = 1
        kv_seqlens: list of K/V sequence lengths, one per batch element
        BLOCK_N: number of tokens per paged attention block

    Returns:
        block_tables: int32 tensor of shape [B, MAX_BLOCKS], created on
            cache_k's device, where B = cache_k.shape[0] and
            MAX_BLOCKS = ceil(MAX_T / BLOCK_N). Row b holds consecutive block
            indices starting at the first block of sequence b.
        packed_cache_k: [1, total_len_rounded, H_kv, D]
        packed_cache_v: [1, total_len_rounded, H_kv, D]
        where total_len_rounded is a sum of K/V seqlens, each rounded up
        to a multiple of BLOCK_N. The padding rows between sequences are
        left uninitialized.
    """
    kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens]
    total_len_rounded = sum(kv_seqlens_rounded)

    B, MAX_T, H, D = cache_k.shape
    # Everything is created on the devices of the inputs rather than a
    # hard-coded "cuda", so the helper also works for CPU tensors and for
    # non-default GPUs.
    device = cache_k.device

    packed_cache_k = torch.empty(
        total_len_rounded, H, D, device=device, dtype=cache_k.dtype
    )
    # The packed V cache takes its trailing shape, dtype and device from
    # cache_v (previously copied from cache_k).
    packed_cache_v = torch.empty(
        total_len_rounded,
        *cache_v.shape[2:],
        device=cache_v.device,
        dtype=cache_v.dtype,
    )
    seqstart = 0
    for b in range(B):
        seqlen = kv_seqlens[b]
        packed_cache_k[seqstart : seqstart + seqlen] = cache_k[b, :seqlen]
        packed_cache_v[seqstart : seqstart + seqlen] = cache_v[b, :seqlen]
        # Each sequence starts on a block boundary.
        seqstart += kv_seqlens_rounded[b]

    num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N
    block_tables = (
        torch.arange(num_blocks_per_row, device=device, dtype=torch.int32)
        .unsqueeze(0)
        .expand(B, num_blocks_per_row)
    )
    # Exclusive prefix sum of the rounded lengths, in units of blocks: the
    # index of the first block of every sequence.
    rounded = torch.tensor(kv_seqlens_rounded)
    seqstarts = (
        (rounded.cumsum(dim=0) - rounded).to(device=device).unsqueeze(1)
    ) // BLOCK_N
    block_tables = (block_tables + seqstarts).contiguous().to(dtype=torch.int32)
    return (
        block_tables,
        packed_cache_k.unsqueeze(0),
        packed_cache_v.unsqueeze(0),
    )
|