mslk-cuda-nightly 2026.1.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
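
The listing above also shows the importable module layout of the wheel. Below is a minimal import smoke-test sketch, assuming the wheel is installed in the current environment; the role comments are inferred from the file paths, not from any package documentation.

# Hedged sketch: verify the module layout listed above imports cleanly.
# Assumption: the wheel is installed and a CUDA-enabled torch is available.
import importlib

modules = [
    "mslk",                                # top-level package (mslk/__init__.py)
    "mslk.attention.fmha",                 # memory-efficient attention ops
    "mslk.attention.flash_attn",           # flash-attention kernels
    "mslk.gemm.triton.fp8_gemm",           # Triton FP8 GEMM
    "mslk.quantize.triton.fp8_quantize",   # Triton FP8 quantization
]

for name in modules:
    mod = importlib.import_module(name)
    print(f"imported {name} from {mod.__file__}")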

mslk/attention/fmha/ck.py (new file, +508 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe


from enum import Enum
from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple, Union

import torch

from . import attn_bias
from .attn_bias import (
    AttentionBias,
    BlockDiagonalCausalLocalAttentionFromBottomRightMask,
    BlockDiagonalCausalLocalAttentionMask,
    BlockDiagonalCausalMask,
    BlockDiagonalCausalWithOffsetGappyKeysMask,
    BlockDiagonalCausalWithOffsetPaddedKeysMask,
    BlockDiagonalGappyKeysMask,
    BlockDiagonalMask,
    BlockDiagonalPaddedKeysMask,
    LowerTriangularFromBottomRightLocalAttentionMask,
    LowerTriangularFromBottomRightMask,
    LowerTriangularMask,
    LowerTriangularMaskWithTensorBias,
    PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
    PagedBlockDiagonalGappyKeysMask,
    PagedBlockDiagonalPaddedKeysMask,
)
from .common import (
    AttentionBwOpBase,
    AttentionFwOpBase,
    check_lastdim_alignment_stride1,
    Context,
    Gradients,
    Inputs,
)
from .utils.op_common import get_operator, register_operator


def _minimum_gemm_alignment(inp: Inputs) -> int:
    return 1


def _get_seqlen_info(
    inp: Inputs,
) -> Tuple[
    Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], int, int
]:
    attn_bias = inp.attn_bias
    if isinstance(
        attn_bias,
        (
            BlockDiagonalMask,
            BlockDiagonalGappyKeysMask,
            BlockDiagonalPaddedKeysMask,
            BlockDiagonalCausalWithOffsetGappyKeysMask,
            BlockDiagonalCausalWithOffsetPaddedKeysMask,
            PagedBlockDiagonalPaddedKeysMask,
            PagedBlockDiagonalGappyKeysMask,
        ),
    ):
        attn_bias.k_seqinfo.to(inp.query.device)
        attn_bias.q_seqinfo.to(inp.query.device)
        seqstart_k = attn_bias.k_seqinfo.seqstart
        seqstart_q = attn_bias.q_seqinfo.seqstart
        max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
        max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
        seqlen = (
            None
            if isinstance(attn_bias, BlockDiagonalMask)
            else attn_bias.k_seqinfo.seqlen
        )
    else:
        seqstart_k = None
        seqstart_q = None
        max_seqlen_q = -1
        max_seqlen_k = -1
        seqlen = None

    if isinstance(attn_bias, PagedBlockDiagonalGappyKeysMask):
        assert seqstart_k is not None
        assert seqlen is not None
        seqstart_k = seqstart_k[:-1]
        seqlen = seqlen - seqstart_k

    return seqstart_k, seqstart_q, seqlen, max_seqlen_q, max_seqlen_k


def _get_tensor_bias(
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]],
) -> Optional[torch.Tensor]:
    if isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
        return attn_bias._bias
    if isinstance(attn_bias, torch.Tensor):
        return attn_bias
    return None


def _check_bias_alignment(
    reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> None:
    attn_bias_tensor = _get_tensor_bias(attn_bias)
    if attn_bias_tensor is not None:
        alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits
        show_padding_hint = False
        for d in range(attn_bias_tensor.ndim - 1):
            if attn_bias_tensor.stride(d) % alignment != 0:
                reasons.append(
                    f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})"
                )
                show_padding_hint = True
        if show_padding_hint:
            reasons.append(
                """\
HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \
you need to ensure memory is aligned by slicing a bigger tensor. \
Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`"""
            )
        # We can have stride=0 sometimes if dimension=1
        if attn_bias_tensor.stride(-1) > 1:
            reasons.append(
                f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - "
                "you should call `.contiguous()` on the bias"
            )


class _CustomMaskType(int, Enum):
    """
    (Matches CustomMaskType in C++.)
    """

    NoCustomMask = 0
    CausalFromTopLeft = 1
    CausalFromBottomRight = 2


def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
    if isinstance(
        bias,
        (
            LowerTriangularMask,
            BlockDiagonalCausalMask,
            BlockDiagonalCausalLocalAttentionMask,
        ),
    ):
        return int(_CustomMaskType.CausalFromTopLeft)
    if isinstance(
        bias,
        (
            LowerTriangularFromBottomRightMask,
            LowerTriangularFromBottomRightLocalAttentionMask,
            attn_bias.BlockDiagonalCausalFromBottomRightMask,
            BlockDiagonalCausalWithOffsetGappyKeysMask,
            BlockDiagonalCausalWithOffsetPaddedKeysMask,
            BlockDiagonalCausalLocalAttentionFromBottomRightMask,
            PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
        ),
    ):
        return int(_CustomMaskType.CausalFromBottomRight)
    return int(_CustomMaskType.NoCustomMask)


@register_operator
class FwOp(AttentionFwOpBase):
    """xFormers' MHA kernel based on Composable Kernel."""

    OPERATOR = get_operator("xformers", "efficient_attention_forward_ck")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 512

    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
        type(None),
        torch.Tensor,
        LowerTriangularMask,
        LowerTriangularFromBottomRightMask,
        LowerTriangularFromBottomRightLocalAttentionMask,
        LowerTriangularMaskWithTensorBias,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalWithOffsetGappyKeysMask,
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
        BlockDiagonalGappyKeysMask,
        BlockDiagonalPaddedKeysMask,
        attn_bias.BlockDiagonalCausalFromBottomRightMask,
        attn_bias.BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        PagedBlockDiagonalPaddedKeysMask,
        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
        PagedBlockDiagonalGappyKeysMask,
    )

    SUPPORTS_DROPOUT = True
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_DIFFERENT_VALUE_EMBED = True
    SUPPORTS_PARTIAL = True
    SUPPORTS_BMGHK = True
    VARLEN_LSE_PACKED = True
    NAME = "ckF"

    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 3e-4,
        torch.half: 6e-3,
        torch.bfloat16: 2.8e-2,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 2e-5,
        torch.half: 3e-3,
        torch.bfloat16: 2e-2,
    }

    _TEST_K: List[int] = [
        32,  # 64x64 kernel
        96,
        128,  # 64x128 kernel
        256,  # 64x128 with accumulation in gmem
        512,
    ]

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
            raise NotImplementedError("Unsupported attn_bias type")
        if inp.query.ndim in [1, 2, 3]:
            raise NotImplementedError("Unsupported number of dimensions")
        if inp.query.ndim in [4]:
            return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
        assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"
        ctx: Optional[Context] = None

        # when the input is expanded 5-D, the group dimension has zero stride
        if inp.key.stride()[3] == 0:
            assert inp.value.stride()[3] == 0, (
                "key and value should be expanded in the same way"
            )
            k_shape = inp.key.size()
            k_stride = inp.key.stride()
            key = inp.key.as_strided(
                (k_shape[0], k_shape[1], k_shape[2], k_shape[4]),
                (k_stride[0], k_stride[1], k_stride[2], k_stride[4]),
            )
            v_shape = inp.value.size()
            v_stride = inp.value.stride()
            value = inp.value.as_strided(
                (v_shape[0], v_shape[1], v_shape[2], v_shape[4]),
                (v_stride[0], v_stride[1], v_stride[2], v_stride[4]),
            )
        else:
            key = inp.key.flatten(2, 3)
            value = inp.value.flatten(2, 3)

        [_, _, G, Hq, _] = inp.query.shape
        attn_bias_replace = inp.attn_bias
        if isinstance(inp.attn_bias, LowerTriangularMaskWithTensorBias):
            bias_tensor = _get_tensor_bias(inp.attn_bias)
            if bias_tensor is not None and bias_tensor.ndim == 5:
                attn_bias_replace = LowerTriangularMaskWithTensorBias(
                    bias_tensor.flatten(1, 2)
                )
        elif isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.ndim == 5:
            attn_bias_replace = inp.attn_bias.flatten(1, 2)
        inp = Inputs(
            query=inp.query.flatten(2, 3),
            key=key,
            value=value,
            attn_bias=attn_bias_replace,
            p=inp.p,
            scale=inp.scale,
            output_dtype=inp.output_dtype,
            is_partial=inp.is_partial,
        )
        out, ctx = cls.apply_bmhk(inp, needs_gradient=needs_gradient)
        out = out.unflatten(2, (G, Hq))
        if ctx is not None:
            lse = ctx.lse.unflatten(1, (G, Hq))
            ctx = Context(
                lse=lse,
                out=out,
                op_bw=ctx.op_bw,
                rng_state=ctx.rng_state,
                qkv_share_storage=ctx.qkv_share_storage,
            )

        return out, ctx

    @classmethod
    def apply_bmhk(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
            raise NotImplementedError("Unsupported attn_bias type")
        seqstart_k, seqstart_q, seqlen_k, max_seqlen_q, _ = _get_seqlen_info(inp)
        out, lse, rng_seed, rng_offset = cls.OPERATOR(
            query=inp.query,
            key=inp.key,
            value=inp.value,
            attn_bias=_get_tensor_bias(inp.attn_bias),
            seqstart_q=seqstart_q,
            seqstart_k=seqstart_k,
            max_seqlen_q=max_seqlen_q,
            dropout_p=inp.p,
            compute_logsumexp=needs_gradient,
            custom_mask_type=_custom_mask_type(inp.attn_bias),
            scale=inp.scale,
            seqlen_k=seqlen_k,
            window_size=(
                inp.attn_bias._window_size
                if isinstance(
                    inp.attn_bias,
                    (
                        BlockDiagonalCausalLocalAttentionMask,
                        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
                        LowerTriangularFromBottomRightLocalAttentionMask,
                    ),
                )
                else None
            ),
            block_tables=(
                inp.attn_bias.block_tables
                if isinstance(
                    inp.attn_bias,
                    (
                        PagedBlockDiagonalPaddedKeysMask,
                        PagedBlockDiagonalGappyKeysMask,
                    ),
                )
                else None
            ),
            page_size=(
                inp.attn_bias.page_size
                if isinstance(
                    inp.attn_bias,
                    (
                        PagedBlockDiagonalPaddedKeysMask,
                        PagedBlockDiagonalGappyKeysMask,
                    ),
                )
                else None
            ),
        )

        ctx: Optional[Context] = None
        if needs_gradient:
            ctx = Context(
                out=out,
                # lse=_post_process_lse(lse, inp, tuple(original_query_shape)),
                lse=lse,
                # cutlass forward is only compatible with cutlass backward if
                # dropout is used (because of the way RNG states are passed and the
                # way random numbers are generated during backward)
                op_bw=BwOp if inp.p != 0 else None,
            )
            if inp.p != 0:
                ctx.rng_state = torch.tensor(
                    [rng_seed, rng_offset], dtype=torch.int64, device="cpu"
                )
        return out, ctx

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        reasons = super(FwOp, cls).not_supported_reasons(d)
        matmul_alignment_mn = _minimum_gemm_alignment(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
        check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
        _check_bias_alignment(reasons, d.attn_bias)
        return reasons


@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = get_operator("xformers", "efficient_attention_backward_ck")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = 256
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
        type(None),
        torch.Tensor,
        LowerTriangularMask,
        LowerTriangularFromBottomRightMask,
        LowerTriangularFromBottomRightLocalAttentionMask,
        # TODO: Fix handling of gradient through the fMHA autograd function
        # LowerTriangularMaskWithTensorBias,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        attn_bias.BlockDiagonalCausalFromBottomRightMask,
        attn_bias.BlockDiagonalCausalLocalAttentionMask,
    )
    SUPPORTS_ATTN_BIAS_GRAD = True
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    SUPPORTS_UNPADDED_LSE = True
    NAME = "ckB"

    _TEST_K: List[int] = [
        32,  # 64x64 kernel
        64,
        96,
        128,  # 64x128/128x128 kernel
        256,
    ]

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        reasons = super(BwOp, cls).not_supported_reasons(d)
        matmul_alignment_mn = _minimum_gemm_alignment(d)

        check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
        check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn)
        check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
        _check_bias_alignment(reasons, d.attn_bias)
        attn_bias_tensor = _get_tensor_bias(d.attn_bias)

        # Backprop of gradient through broadcasted bias is not supported
        if attn_bias_tensor is not None and attn_bias_tensor.requires_grad:
            # Don't forget that inputs are either in BMK or BMHK!
            if d.query.ndim == 3 and attn_bias_tensor.ndim == 3:
                expected_bias_shape = (*d.query.shape[:2], d.key.shape[1])
            else:
                # bias is B H Mq Mk
                expected_bias_shape = (
                    d.query.shape[0],
                    d.query.shape[2] if d.query.ndim == 4 else 1,
                    d.query.shape[1],
                    d.key.shape[1],
                )
            if tuple(attn_bias_tensor.shape) != expected_bias_shape:
                reasons.append(
                    "Broadcasting the `attn_bias` tensor is not supported "
                    f"(shape: {tuple(attn_bias_tensor.shape)}"
                    f"/ expected: {expected_bias_shape})"
                )

        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES:
            raise NotImplementedError("Unsupported attn_bias type")

        seqstart_k, seqstart_q, seqlen_k, max_seqlen_q, max_seqlen_k = _get_seqlen_info(
            inp
        )
        dtype = inp.query.dtype

        rng_seed = rng_offset = 0
        if inp.p != 0.0:
            if (
                ctx.rng_state is None
                or ctx.rng_state.dtype != torch.int64
                or ctx.rng_state.device.type != "cpu"
                or ctx.rng_state.shape != (2,)
            ):
                raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
            rng_seed, rng_offset = ctx.rng_state.tolist()

        (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR(
            grad.to(dtype),
            inp.query,
            inp.key,
            inp.value,
            attn_bias=_get_tensor_bias(inp.attn_bias),
            seqstart_q=seqstart_q,
            seqstart_k=seqstart_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            seqlen_k=seqlen_k,
            logsumexp=ctx.lse,
            output=ctx.out.to(dtype),
            dropout_p=inp.p,
            # if not using dropout, seed and offset are irrelevant but still expected
            # in function signature so just pass 0
            # seed and offset could be None if a different FW op other than cutlass
            # was used.
            rng_seed=rng_seed,
            rng_offset=rng_offset,
            custom_mask_type=_custom_mask_type(inp.attn_bias),
            scale=inp.scale,
            window_size=(
                inp.attn_bias._window_size
                if isinstance(
                    inp.attn_bias,
                    (
                        BlockDiagonalCausalLocalAttentionMask,
                        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
                        LowerTriangularFromBottomRightLocalAttentionMask,
                    ),
                )
                else None
            ),
        )

        # c++/CUDA implementation returns an uninitialized tensor if bias doesn't
        # require grad
        if not (
            isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad
        ):
            grad_bias = None

        return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)
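
For context, here is a minimal sketch of how the FwOp/BwOp pair above is typically driven. It is not taken from package documentation: it assumes a GPU build of the underlying "xformers" C++ operators is loaded, that `Inputs` accepts the keyword arguments used inside `FwOp.apply` above, and that `LowerTriangularMask` can be constructed with no arguments as in xformers.

# Hedged sketch: drive the CK forward op directly with BMHK tensors.
import torch

from mslk.attention.fmha.ck import FwOp
from mslk.attention.fmha.common import Inputs
from mslk.attention.fmha.attn_bias import LowerTriangularMask

B, M, H, K = 2, 128, 8, 64  # batch, sequence length, heads, head dim
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.half)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.half)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.half)

inp = Inputs(
    query=q,
    key=k,
    value=v,
    attn_bias=LowerTriangularMask(),  # causal mask, listed in SUPPORTED_ATTN_BIAS_TYPES above
    p=0.0,                            # no dropout
    scale=None,                       # default 1/sqrt(K) scaling
    output_dtype=None,
    is_partial=False,
)

# Check applicability first, the same way a dispatcher would.
reasons = FwOp.not_supported_reasons(inp)
if reasons:
    raise RuntimeError(f"ckF not usable here: {reasons}")

out, ctx = FwOp.apply(inp, needs_gradient=False)
print(out.shape)  # expected: (B, M, H, K)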

mslk/attention/fmha/ck_decoder.py (new file, +141 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe

from typing import Any, Iterable, List, Optional, Set, Tuple

import torch

from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
from .common import AttentionFwOpBase, Context, Inputs
from .utils.op_common import get_operator, register_operator


@register_operator
class FwOp(AttentionFwOpBase):
    """
    An operator optimized for K=256 (so the contiguous dim fits into registers).
    Tested to work on MI250x.
    """

    OPERATOR = get_operator("xformers", "efficient_attention_forward_decoder_ck")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16, torch.float}
    SUPPORTED_MAX_K: int = 256
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
        type(None),
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
    )
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_BMGHK = True
    NAME = "ck_decoderF"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:  # noqa: C901
        reasons = super(FwOp, cls).not_supported_reasons(d)

        attn_bias = d.attn_bias
        if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
            if d.query.shape[0] != 1:
                reasons.append(
                    f"One formal batch element expected; got {d.query.shape[0]}"
                )

            if d.query.shape[-1] > cls.SUPPORTED_MAX_K:
                reasons.append(
                    f"Got head_dim={d.query.shape[-1]}; only head_dim<={cls.SUPPORTED_MAX_K} is supported for now."
                )

            threads_per_warp = 64  # TODO: ideally query the platform here
            required_alignment = 0
            head_dim = d.query.shape[-1]
            for vec_size in (4, 2, 1):
                if head_dim <= vec_size * threads_per_warp:
                    required_alignment = vec_size

            if not required_alignment:
                reasons.append(f"Got head_dim={head_dim} which is too large")

            if head_dim % required_alignment != 0:
                reasons.append(
                    f"Got head_dim={head_dim}; it needs to be divisible by {required_alignment}"
                )

            if d.key.stride(-1) != 1:
                reasons.append("expect keys to have last dim contiguous")

            if d.value.stride(-1) != 1:
                reasons.append("expect values to have last dim contiguous")

            q_starts = attn_bias.q_seqinfo.seqstart_py
            padding = attn_bias.k_seqinfo.padding
            bsz = d.key.shape[1] // padding
            num_queries = d.query.shape[1] // bsz

            if q_starts != list(range(0, 1 + bsz, num_queries)):
                reasons.append("expect to have same num_queries in each batch")
            if bsz != len(q_starts) - 1:
                reasons.append("empty lanes not supported yet")

            if attn_bias.k_seqinfo.padding > 8192:
                reasons.append("key padding exceeds 8192")

        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        if needs_gradient:
            raise NotImplementedError("backward pass is not supported")
        attn_bias = inp.attn_bias
        q, k, v = inp.get_qkv_in_bmghk()
        if attn_bias is not None:
            assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
            attn_bias.k_seqinfo.to(k.device)
            attn_bias.q_seqinfo.to(q.device)
            padding = attn_bias.k_seqinfo.padding
            seq_positions_gpu = attn_bias.k_seqinfo.seqlen
        else:
            padding = k.shape[1]
            seq_positions_gpu = None

        if attn_bias is not None:
            # key: (1, B * padding, G, 1 if multiquery else Hkv, D)
            # value: like key
            # query: (1, B * q_seqlen, G, Hq, D)
            multiquery = k.stride(3) == 0
            if multiquery:
                key = k[0, :, :, :1].unflatten(0, (-1, padding))
                value = v[0, :, :, :1].unflatten(0, (-1, padding))
            else:
                key = k[0].unflatten(0, (-1, padding))
                value = v[0].unflatten(0, (-1, padding))
            query = q[0].unflatten(0, (key.shape[0], -1))
        else:
            # key: (B, padding, G, 1 if multiquery else Hkv, D)
            # value: like key
            # query: (B, q_seqlen, G, Hq, D)
            key = k
            query = q
            value = v

        if inp.scale is not None:
            qk_scale = inp.scale
        else:
            qk_scale = torch.rsqrt(
                torch.tensor(key.shape[-1], dtype=torch.float32)
            ).item()

        out = cls.OPERATOR(
            query=query,
            key=key,
            value=value,
            seq_positions=seq_positions_gpu,
            scale=qk_scale,
        )
        return out, None
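
A matching sketch for the decoder op above, using the `attn_bias=None` path it supports. This is an illustration only: it assumes the C++ decoder operator is loaded and that `Inputs.get_qkv_in_bmghk()` passes 5-D BMGHK tensors through unchanged.

# Hedged sketch: single-step decode through the CK decoder op, no attention bias.
import torch

from mslk.attention.fmha.ck_decoder import FwOp as DecoderFwOp
from mslk.attention.fmha.common import Inputs

B, padding, G, Hq, D = 4, 1024, 1, 8, 128  # batch, KV padding, groups, heads, head dim
q = torch.randn(B, 1, G, Hq, D, device="cuda", dtype=torch.half)  # one new token per lane
k = torch.randn(B, padding, G, Hq, D, device="cuda", dtype=torch.half)
v = torch.randn(B, padding, G, Hq, D, device="cuda", dtype=torch.half)

inp = Inputs(query=q, key=k, value=v, attn_bias=None, p=0.0, scale=None)

reasons = DecoderFwOp.not_supported_reasons(inp)
if reasons:
    raise RuntimeError(f"ck_decoderF not usable here: {reasons}")

# needs_gradient must be False: the decoder op has no backward pass.
out, _ = DecoderFwOp.apply(inp, needs_gradient=False)
print(out.shape)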