mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0

mslk/attention/fmha/split_blocks_fairinternal.py
@@ -0,0 +1,329 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from .attn_bias import (
+    _GappySeqInfo,
+    _PaddedSeqLenInfo,
+    _SeqLenInfo,
+    AttentionBias,
+    BlockDiagonalCausalWithOffsetGappyKeysMask,
+    BlockDiagonalCausalWithOffsetPaddedKeysMask,
+    BlockDiagonalGappyKeysMask,
+    BlockDiagonalPaddedKeysMask,
+    PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
+    PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+    PagedBlockDiagonalGappyKeysMask,
+    PagedBlockDiagonalPaddedKeysMask,
+)
+
+
+def split_blocks_for_decoding_gpu_part(
+    input_bias: Union[
+        BlockDiagonalPaddedKeysMask, BlockDiagonalCausalWithOffsetPaddedKeysMask
+    ],
+    batchify_len: Optional[int],
+    block_tables: Optional[torch.Tensor] = None,
+    page_size: Optional[int] = None,
+) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    This is the gpu part of split_blocks_for_decoding,
+    which can be called in advance.
+    """
+    if batchify_len is None:
+        return None
+    assert batchify_len > 0
+    assert input_bias.q_seqinfo.min_seqlen == input_bias.q_seqinfo.max_seqlen
+
+    seqstart = input_bias.k_seqinfo.seqstart  # (B+1,)
+    seqlen = input_bias.k_seqinfo.seqlen  # (B,)
+
+    # compute raw block boundaries
+    k_ends = seqstart[:-1] + seqlen  # (B,)
+    # For non-speculative decoding, we have a causal bias here,
+    # which will always be from-bottom-right style.
+    # Q and K are aligned so that their last tokens are at the same position.
+    # If seqlen == batchify_len, the first token of the query is at position batchify_len - 1,
+    # and it can attend to all keys from the previous iRoPE chunk.
+    # The diagram shows that when seqlen == batchify_len == N and the bias is causal,
+    # Q can still attend to K from the previous chunk.
+    # -----------iRoPE chunk 0---------|---------iRoPE chunk 1---------------
+    #                             Q[0] |
+    # K[0] K[1] K[2] ... K[N-2] K[N-1] |
+
+    # For speculative decoding, we use this function for the prefix bias only.
+    # We are called with a non-causal bias.
+    # The query is positioned after the keys, and so when seqlen == batchify_len,
+    # the first token of the query is at position batchify_len.
+    # So it can't attend to any key from the previous chunk,
+    # so we want k_starts == k_ends => k_lens == 0.
+    # The diagram shows that when seqlen == batchify_len == N and the bias is non-causal,
+    # Q is located entirely in the next iRoPE chunk and can't attend to K[0] ... K[N-1].
+    # ------------iRoPE chunk 0---------------|---------iRoPE chunk 1---------
+    #                                         | Q[0] Q[1] Q[2]
+    # K[0] K[1] K[2] ... K[N-3] K[N-2] K[N-1] |
+
+    shift = int(isinstance(input_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask))
+    k_starts = (k_ends - shift) // batchify_len * batchify_len
+    k_starts = torch.where(seqlen == 0, k_ends, k_starts)
+    k_lens = k_ends - k_starts
+
+    if block_tables is None:
+        k_seqstarts = torch.cat([k_starts, seqstart[-1:]])
+    else:
+        k_seqstarts = (k_starts - seqstart[:-1]).clamp(min=0)
+        k_lens = k_lens + k_seqstarts
+
+    return k_seqstarts, k_lens
+
+
+def split_blocks_for_decoding(
+    input_bias: Union[
+        BlockDiagonalPaddedKeysMask, BlockDiagonalCausalWithOffsetPaddedKeysMask
+    ],
+    batchify_len: Optional[int],
+    block_tables: Optional[torch.Tensor] = None,
+    page_size: Optional[int] = None,
+    gpu_data: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+) -> Optional[Union[BlockDiagonalGappyKeysMask, PagedBlockDiagonalGappyKeysMask]]:
+    """
+    For decoding, when query length is 1, we can represent iRoPE-batchified bias as a gappy bias.
+    This function can also be applied for speculative decoding, when query length is > 1,
+    but same across all batch elements. In this case we assume that query (draft) lies entirely
+    in one block/subsequence, not crossing the boundary. Cases when the query crosses the boundary
+    need to be handled separately by the caller.
+    """
+    if batchify_len is None:
+        return None
+    assert batchify_len > 0
+    assert input_bias.q_seqinfo.min_seqlen == input_bias.q_seqinfo.max_seqlen
+
+    if gpu_data is None:
+        gpu_data = split_blocks_for_decoding_gpu_part(
+            input_bias, batchify_len, block_tables, page_size
+        )
+    assert gpu_data is not None
+    k_seqstarts, k_lens = gpu_data
+
+    k_seqstarts_list = []
+    k_seqlens_list = []
+    k_seqlens_list_actual = []
+    B = len(input_bias.k_seqinfo.seqlen_py)
+    # About the shift, see the comment in split_blocks_for_decoding_gpu_part.
+    shift = int(isinstance(input_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask))
+    for i in range(B):
+        input_k_start_ = input_bias.k_seqinfo.seqstart_py[i]
+        input_k_len_ = input_bias.k_seqinfo.seqlen_py[i]
+        input_k_end_ = input_k_start_ + input_k_len_
+        k_seqstart = (input_k_end_ - shift) // batchify_len * batchify_len
+        if input_k_len_ == 0:
+            k_seqstart = input_k_end_
+        k_seqend = min(k_seqstart + batchify_len, input_k_end_)
+        k_len = k_seqend - k_seqstart
+        # NOTE: With chunked, `k_len` cannot exceed the original length `input_k_len_`, so we clamp it here.
+        k_len = min(k_len, input_k_len_)
+
+        if k_seqstart < 0:
+            k_len = k_seqstart = 0
+        k_seqstart = (
+            k_seqstart if block_tables is None else max(k_seqstart - input_k_start_, 0)
+        )
+        k_seqstarts_list.append(k_seqstart)
+        k_seqlens_list_actual.append(k_len)
+        k_seqlens_list.append(k_len if block_tables is None else k_len + k_seqstart)
+
+    OutBiasType = (
+        BlockDiagonalCausalWithOffsetGappyKeysMask
+        if isinstance(input_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+        else BlockDiagonalGappyKeysMask
+    )
+    PagedOutBiasType = (
+        PagedBlockDiagonalCausalWithOffsetGappyKeysMask
+        if isinstance(input_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+        else PagedBlockDiagonalGappyKeysMask
+    )
+    if block_tables is None:
+        k_seqstarts_list.append(input_bias.k_seqinfo.seqstart_py[-1])
+        return OutBiasType(
+            q_seqinfo=input_bias.q_seqinfo,
+            k_seqinfo=_GappySeqInfo(
+                seqstart_py=k_seqstarts_list,
+                seqstart=k_seqstarts,
+                seqlen=k_lens,
+                seqlen_py=k_seqlens_list,
+                min_seqlen=min(k_seqlens_list),
+                max_seqlen=max(k_seqlens_list),
+            ),
+        )
+    assert page_size is not None
+    return PagedOutBiasType(
+        q_seqinfo=input_bias.q_seqinfo,
+        k_seqinfo=_GappySeqInfo(
+            seqstart_py=k_seqstarts_list,
+            seqstart=k_seqstarts,
+            seqlen=k_lens,
+            seqlen_py=k_seqlens_list,
+            min_seqlen=min(k_seqlens_list_actual),
+            max_seqlen=max(k_seqlens_list_actual),
+        ),
+        block_tables=block_tables,
+        page_size=page_size,
+    )
+
+
+def split_blocks_for_prefill(
+    input_bias: BlockDiagonalPaddedKeysMask, batchify_len: Optional[int]
+) -> Optional[BlockDiagonalPaddedKeysMask]:
+    """
+    From
+    https://github.com/fairinternal/llm_inference/blob/11bbb2/llm_inference/models/disagg_transformer.py#L1955
+    """
+    if batchify_len is None:
+        return None
+    padding = input_bias.k_seqinfo.padding
+    assert padding % batchify_len == 0, f"{padding} % {batchify_len} != 0"
+    split_factor = padding // batchify_len
+    batch_size = len(input_bias.q_seqinfo.seqstart_py) - 1
+    new_batch_size = batch_size * split_factor
+    k_seqlen = input_bias.k_seqinfo.seqlen
+    q_seqlen = input_bias.q_seqinfo.seqstart[1:] - input_bias.q_seqinfo.seqstart[:-1]
+    k_seqlen_each = k_seqlen.repeat_interleave(split_factor, output_size=new_batch_size)
+    q_seqlen_each = q_seqlen.repeat_interleave(split_factor, output_size=new_batch_size)
+    res_seqlen_each = k_seqlen_each - q_seqlen_each
+    seqpos = torch.arange(
+        0, padding, batchify_len, device=k_seqlen.device, dtype=k_seqlen.dtype
+    )
+    seqpos_start = seqpos.repeat(batch_size)
+    k_lengths = (k_seqlen_each - seqpos_start).clamp(min=0, max=batchify_len)
+    res_lengths = (res_seqlen_each - seqpos_start).clamp(min=0, max=batchify_len)
+
+    k_seqstart = torch.arange(
+        0,
+        new_batch_size * batchify_len + 1,
+        batchify_len,
+        device=k_seqlen.device,
+        dtype=k_seqlen.dtype,
+    )
+    k_seqstart_py = list(range(0, new_batch_size * batchify_len + 1, batchify_len))
+    q_seqstart = torch.zeros_like(k_seqstart)
+    torch.cumsum(k_lengths - res_lengths, 0, out=q_seqstart[1:])
+
+    # start at 2 to avoid reshaping issues with
+    # https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp#L602
+    max_q_len = 2
+    min_q_len = 2
+    max_k_len = 0
+    q_seqstart_list: List[int] = [0]
+    k_seqlen_list: List[int] = []
+    for i in range(len(input_bias.k_seqinfo.seqlen)):
+        q_seqlen_ = (
+            input_bias.q_seqinfo.seqstart_py[i + 1]
+            - input_bias.q_seqinfo.seqstart_py[i]
+        )
+        k_seqlen_ = input_bias.k_seqinfo.seqlen_py[i]
+        res_seqlen_ = k_seqlen_ - q_seqlen_
+        for seqpos_ in range(0, padding, batchify_len):
+            k_chunk_size = max(min(k_seqlen_ - seqpos_, batchify_len), 0)
+            res_chunk_size = max(min(res_seqlen_ - seqpos_, batchify_len), 0)
+            q_chunk_size = k_chunk_size - res_chunk_size
+
+            q_seqstart_list.append(q_seqstart_list[-1] + q_chunk_size)
+            k_seqlen_list.append(k_chunk_size)
+            if q_chunk_size > max_q_len:
+                max_q_len = q_chunk_size
+            if q_chunk_size < min_q_len:
+                min_q_len = q_chunk_size
+            if k_chunk_size > max_k_len:
+                max_k_len = k_chunk_size
+
+    batchify_attn_bias = input_bias.__class__(
+        q_seqinfo=_SeqLenInfo(
+            seqstart=q_seqstart,
+            max_seqlen=max_q_len,
+            min_seqlen=min_q_len,
+            seqstart_py=q_seqstart_list,
+        ),
+        k_seqinfo=_PaddedSeqLenInfo(
+            seqstart=k_seqstart,
+            seqlen_py=k_seqlen_list,
+            seqlen=k_lengths,
+            padding=batchify_len,
+            seqstart_py=k_seqstart_py,
+            min_seqlen=0,
+            max_seqlen=max_k_len,
+        ),
+    )
+    return batchify_attn_bias
+
+
+def maybe_make_paged(
+    attn_bias: Optional[
+        Union[
+            BlockDiagonalPaddedKeysMask,
+            BlockDiagonalGappyKeysMask,
+        ]
+    ],
+    block_tables: Optional[torch.Tensor],
+    page_size: int,
+    notional_padding: Optional[int],
+) -> Optional[AttentionBias]:
+    """
+    Convert attention bias into its paged version if block_tables is not None.
+    Args:
+        attn_bias: input attention bias.
+        block_tables: table of shape [batch_size, max_pages_per_lane]
+            redirecting from logical to physical pages.
+        page_size: number of tokens per page.
+        notional_padding: if input attention bias is gappy, it has
+            no notion of padding, sequence starts are arbitrary.
+            However, we need to know how to divide logical sequence space
+            into lanes corresponding to each row of block tables.
+            In other words, where is 0th block in i-th row of block table
+            located in the logical space?
+            This function assumes that it's located at i * notional_padding.
+            The value of notional_padding needs to be consistent with the
+            padding used when block_tables was created.
+            For example, if a gappy bias was created from a padded bias
+            using split_blocks* functions, notional padding
+            should be equal to the padding of the original bias.
+    Returns:
+        Paged version of the original attention bias.
+    """
+    if attn_bias is None:
+        return None
+    if block_tables is None:
+        return attn_bias
+
+    attn_batch_size = len(attn_bias.k_seqinfo.seqlen)
+    if attn_batch_size != block_tables.shape[0]:
+        # In case of iRoPE each batch lane has been split into smaller chunks,
+        # so we need to reshape the block tables accordingly.
+        block_tables = block_tables.view(attn_batch_size, -1)
+    if isinstance(attn_bias, BlockDiagonalGappyKeysMask):
+        assert notional_padding is not None, (
+            "Notional padding must be specified to create gappy paged biases."
+        )
+        return attn_bias.make_paged(
+            block_tables=block_tables,
+            page_size=page_size,
+            notional_padding=notional_padding,
+            paged_type=PagedBlockDiagonalGappyKeysMask,
+        )
+    if isinstance(attn_bias, PagedBlockDiagonalGappyKeysMask):
+        return attn_bias
+    paged_type = (
+        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
+        if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+        else PagedBlockDiagonalPaddedKeysMask
+    )
+    assert isinstance(attn_bias, BlockDiagonalPaddedKeysMask)
+    return attn_bias.make_paged(
+        block_tables=block_tables, page_size=page_size, paged_type=paged_type
+    )
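The hunk above (329 added lines, matching split_blocks_fairinternal.py in the file list) centers on the chunk-boundary arithmetic `k_start = (k_end - shift) // batchify_len * batchify_len`, with `shift = 1` for the causal from-bottom-right bias and `shift = 0` otherwise. The sketch below replays that arithmetic with concrete numbers; it is an illustration only, and `batchified_k_window` is a hypothetical helper name, not part of the wheel.

```python
# Illustration only (not from the package): the iRoPE chunk-boundary arithmetic
# used by split_blocks_for_decoding_gpu_part, replayed on plain Python ints.


def batchified_k_window(k_end: int, batchify_len: int, causal: bool) -> tuple[int, int]:
    """Return (k_start, k_len) for a lane whose keys end at absolute position k_end."""
    # shift = 1 for the causal (from-bottom-right) bias: the last query token sits at
    # position k_end - 1, so a key range ending exactly on a chunk boundary still
    # belongs to the previous chunk. shift = 0 for the non-causal speculative prefix.
    shift = 1 if causal else 0
    k_start = (k_end - shift) // batchify_len * batchify_len
    return k_start, k_end - k_start


N = 8  # batchify_len, i.e. the iRoPE chunk size
# Decoding (causal): seqlen == batchify_len keeps the whole previous chunk visible.
assert batchified_k_window(k_end=8, batchify_len=N, causal=True) == (0, 8)
# Speculative-decoding prefix (non-causal): the query starts a new chunk,
# so the visible key window collapses to zero.
assert batchified_k_window(k_end=8, batchify_len=N, causal=False) == (8, 0)
# Partially filled chunk: only the keys since the last boundary remain visible.
assert batchified_k_window(k_end=13, batchify_len=N, causal=True) == (8, 5)
```

With `shift = 1` a key range ending exactly on a chunk boundary keeps the whole previous chunk visible to the decoded query token, while `shift = 0` collapses the window to empty, matching the two diagrams in the comments above.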

mslk/attention/fmha/torch_attention_compat.py
@@ -0,0 +1,154 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-strict
+
+from typing import Optional
+
+import torch
+from torch._C import parse_schema
+
+
+def is_pt_cutlass_compatible(force: bool = False) -> bool:
+    if torch.version.hip is not None:
+        if force:
+            raise ImportError("CUTLASS is not supported on ROCm")
+        return False
+    compatible = True
+
+    fwd_schema_str = (
+        "aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, "
+        "Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, "
+        "SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, "
+        "float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> "
+        "(Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, "
+        "SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)"
+    )
+    expected_fwd_schema = parse_schema(fwd_schema_str)
+
+    current_schema = torch.ops.aten._efficient_attention_forward.default._schema
+    if not current_schema.is_backward_compatible_with(expected_fwd_schema):
+        compatible = False
+
+        if force:
+            raise ImportError(
+                f"Current Torch CUTLASS doesn't have a compatible aten::_efficient_attention_forward schema\n"
+                f"EXPECTED:\n{expected_fwd_schema}\n"
+                f"but GOT:\n{current_schema}"
+            )
+
+    bwd_schema_str = (
+        "aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, "
+        "Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, "
+        "SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, "
+        "int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, "
+        "int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor)"
+    )
+
+    expected_bwd_schema = parse_schema(bwd_schema_str)
+
+    current_schema = torch.ops.aten._efficient_attention_backward.default._schema
+    if not current_schema.is_backward_compatible_with(expected_bwd_schema):
+        compatible = False
+
+        if force:
+            raise ImportError(
+                f"Current Torch CUTLASS doesn't have a compatible aten::_efficient_attention_backward schema\n"
+                f"EXPECTED:\n{expected_bwd_schema}\n"
+                f"but GOT:\n{current_schema}"
+            )
+
+    return compatible
+
+
+def is_pt_flash_old(force: bool) -> Optional[bool]:
+    """
+    Returns True if the current PyTorch version has the old Flash-Attention
+    ops instead of the new ones.
+    If it has none at all, raises an ImportError or returns None.
+    """
+    if not torch.backends.cuda.is_flash_attention_available():
+        if force:
+            raise ImportError("Flash SDP backend is disabled")
+        return None
+
+    if not hasattr(torch.nn, "attention") or not hasattr(
+        torch.nn.attention, "_get_flash_version"
+    ):
+        if force:
+            raise ImportError(
+                f"Current Torch {torch.__version__} doesn't implement "
+                "torch.nn.attention._get_flash_version()"
+            )
+        return None
+
+    FLASH_VERSION = torch.nn.attention._get_flash_version()
+
+    compatible = True
+
+    # old = before 25/2/2025
+    # https://github.com/pytorch/pytorch/commit/3ecfe6be256c585bcadf4c845d7119545444a222
+    old_fwd_schema_str = (
+        "aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, "
+        "Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, "
+        "bool is_causal, bool return_debug_mask, *, float? scale=None, "
+        "SymInt? window_size_left=None, SymInt? window_size_right=None, "
+        "Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, "
+        "Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)"
+    )
+    fwd_schema_str = (
+        "aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, "
+        "Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, "
+        "bool is_causal, bool return_debug_mask, *, float? scale=None, "
+        "SymInt? window_size_left=None, SymInt? window_size_right=None, "
+        "Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, "
+        "Tensor rng_state, Tensor unused, Tensor debug_attn_mask)"
+    )
+    expected_fwd_schema = parse_schema(fwd_schema_str)
+    expected_old_fwd_schema = parse_schema(old_fwd_schema_str)
+
+    current_schema = torch.ops.aten._flash_attention_forward.default._schema
+    old = current_schema.is_backward_compatible_with(expected_old_fwd_schema)
+    if not old and not current_schema.is_backward_compatible_with(expected_fwd_schema):
+        compatible = False
+
+        if force:
+            raise ImportError(
+                f"Current Torch with Flash-Attention {FLASH_VERSION} doesn't have "
+                "a compatible aten::_flash_attention_forward schema\n"
+                f"EXPECTED:\n{expected_old_fwd_schema}\n"
+                f"or:\n{expected_fwd_schema}\n"
+                f"but GOT:\n{current_schema}"
+            )
+
+    bwd_schema_old_str = (
+        "aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, "
+        "Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, "
+        "float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, "
+        "SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)"
+    )
+    bwd_schema_str = (
+        "aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, "
+        "Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, "
+        "float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, "
+        "SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)"
+    )
+    expected_bwd_schema = parse_schema(bwd_schema_old_str if old else bwd_schema_str)
+
+    current_schema = torch.ops.aten._flash_attention_backward.default._schema
+    if not current_schema.is_backward_compatible_with(expected_bwd_schema):
+        compatible = False
+
+        if force:
+            raise ImportError(
+                f"Current Torch with Flash-Attention {FLASH_VERSION} doesn't have "
+                "a compatible aten::_flash_attention_backward schema\n"
+                f"EXPECTED:\n{expected_bwd_schema}\n"
+                f"but GOT:\n{current_schema}"
+            )
+
+    if not compatible:
+        return None
+    return old
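Assuming the 154-line hunk above corresponds to mslk/attention/fmha/torch_attention_compat.py from the file list, the sketch below shows one way a caller might use the two checks as soft gates before enabling the corresponding attention backends. The import path and print messages are illustrative assumptions, not code from the package.

```python
# Illustrative usage sketch, not part of the wheel. Assumes the hunk above is
# importable as mslk.attention.fmha.torch_attention_compat and PyTorch is installed.
from mslk.attention.fmha.torch_attention_compat import (
    is_pt_cutlass_compatible,
    is_pt_flash_old,
)

# Soft check: returns False (instead of raising) on ROCm or on a schema mismatch.
if is_pt_cutlass_compatible(force=False):
    print("aten::_efficient_attention_{forward,backward} match the expected schemas")

# is_pt_flash_old(False) returns True for the pre-2025-02-25 flash op signatures
# (philox_seed/philox_offset outputs), False for the newer rng_state-based ones,
# and None when flash SDP is unavailable or the schemas are incompatible.
flash_is_old = is_pt_flash_old(False)
if flash_is_old is None:
    print("PyTorch flash-attention ops are unavailable or incompatible")
else:
    print(f"old flash op signatures: {flash_is_old}")
```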