mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/common.py
@@ -0,0 +1,598 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe

import math
from dataclasses import dataclass
from functools import partial
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import torch

from .attn_bias import (
    AttentionBias,
    BlockDiagonalGappyKeysMask,
    BlockDiagonalMask,
    BlockDiagonalPaddedKeysMask,
    LowerTriangularMask,
    LowerTriangularMaskWithTensorBias,
    PagedBlockDiagonalGappyKeysMask,
    PagedBlockDiagonalPaddedKeysMask,
)
from .utils.cpp_lib import _built_with_cuda
from .utils.op_common import BaseOperator


def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
    # NoneType
    if isinstance(None, attn_bias_type):
        return True
    if attn_bias_type in [LowerTriangularMask, torch.Tensor]:
        return True
    return False


def _attn_bias_apply(
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]],
    op: Callable[[torch.Tensor], torch.Tensor],
) -> Optional[Union[torch.Tensor, AttentionBias]]:
    if isinstance(attn_bias, torch.Tensor):
        return op(attn_bias)
    if isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
        return LowerTriangularMaskWithTensorBias(op(attn_bias._bias))
    return attn_bias

class ScaledTensor(torch.Tensor):
    __slots__ = ["scale", "dequant_func", "original_dtype"]

    # Disabling custom torch function handling for this class
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(
        cls,
        data: torch.Tensor,
        scale: torch.Tensor,
        dequant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        original_dtype: torch.dtype,
        require_grad: bool = False,
    ) -> "ScaledTensor":
        """
        Creates a new ScaledTensor subclass instance.

        Parameters:
        - data: The underlying quantized tensor (e.g., int8, int4).
        - scale: The scale tensor or scalar to be used for dequantization.
        - dequant_func: A callable that applies dequantization, which takes both the data and scale as input.
        - original_dtype: The data type before quantization (e.g., float32, float16).
        - require_grad: Whether or not to track gradients (default: False for inference use).
        """
        # Use _make_subclass to create a new ScaledTensor instance, which is a subclass of torch.Tensor.
        instance = torch.Tensor._make_subclass(cls, data, require_grad)

        # Store the dequantization scale and function as attributes.
        instance.scale = scale  # type: ignore
        instance.dequant_func = dequant_func  # type: ignore

        # Store the original data type of the tensor, so we can cast it back after dequantization.
        instance.original_dtype = original_dtype  # type: ignore

        # Return the new instance of ScaledTensor.
        return instance

    def dequantize(self) -> torch.Tensor:
        """
        Applies the custom dequantization function provided at the tensor's creation.
        After dequantization, the data is cast back to its original data type.
        """
        # Explicitly create a new torch.Tensor to ensure the return type is torch.Tensor, not ScaledTensor.
        data = torch.Tensor(self.float())

        # Call the dequantization function, passing in the data and the scale.
        dequantized_data = self.dequant_func(data, self.scale)  # type: ignore

        # Cast the dequantized data back to the original data type.
        return dequantized_data.to(self.original_dtype)  # type: ignore

    def unpack(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Unpacks the ScaledTensor by returning its data and scale as a tuple.
        Returns:
        - A tuple of (data, scale), both of which are torch.Tensor objects.
        """
        return self.data, self.scale  # type: ignore

    def __repr__(self):
        """
        Custom string representation for ScaledTensor.
        """
        return f"ScaledTensor(data={self.data}, scale={self.scale}, original_dtype={self.original_dtype})"


def pack_fp8_tensorwise_per_head(
    x: torch.Tensor, scale: Union[torch.Tensor, float], original_dtype
) -> ScaledTensor:
    """
    Pack a tensor into a tensorwise fp8 ScaledTensor.
    """
    if isinstance(scale, float):
        scale = torch.tensor([scale], device=x.device)

    def dequant_func(x, scale):
        return x * scale[:, None, :, None]

    return ScaledTensor(
        data=x,
        scale=scale,
        dequant_func=dequant_func,
        original_dtype=original_dtype,
    )

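The ScaledTensor wrapper and pack_fp8_tensorwise_per_head above only fix a calling convention: a quantized payload plus a scale and a dequantization callback. A minimal usage sketch follows; it is not part of the wheel and assumes the hunk ships as mslk.attention.fmha.common, a torch build with float8_e4m3fn, and a per-batch, per-head scale of shape [B, H], which is what the scale[:, None, :, None] broadcast in dequant_func expects for BMHK data.

# Illustrative sketch (not part of the wheel): pack a simulated fp8 tensor with
# a per-batch, per-head scale and recover it via dequantize().
import torch

from mslk.attention.fmha.common import pack_fp8_tensorwise_per_head

B, M, H, K = 2, 128, 4, 64
x_hp = torch.randn(B, M, H, K, dtype=torch.bfloat16)
scale = x_hp.float().abs().amax(dim=(1, 3)) / 448.0            # [B, H]; fp8 e4m3 max is 448
x_fp8 = (x_hp.float() / scale[:, None, :, None]).to(torch.float8_e4m3fn)

packed = pack_fp8_tensorwise_per_head(x_fp8, scale, original_dtype=torch.bfloat16)
data, s = packed.unpack()              # raw fp8 payload and its scale
x_back = packed.dequantize()           # bfloat16 again, approximately equal to x_hp
print(x_back.dtype, (x_back.float() - x_hp.float()).abs().max())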
@dataclass
class Inputs:
    """
    Stores inputs to the `memory_efficient_attention` operators
    """

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
    p: float = 0.0
    scale: Optional[float] = None
    output_dtype: Optional[torch.dtype] = None
    is_partial: bool = False
    quantize_pv_to_fp8: bool = False
    quantize_qk_to_fp8: bool = False
    use_fp32_scales: bool = False

    @property
    def device(self) -> torch.device:
        return self.query.device

    @property
    def scale_float(self) -> float:
        return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale

    def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.query.ndim == 5:
            return self.query, self.key, self.value
        if self.query.ndim == 4:
            return (
                self.query.unsqueeze(2),
                self.key.unsqueeze(2),
                self.value.unsqueeze(2),
            )
        if self.value.ndim == 3:
            return (
                self.query[:, :, None, None],
                self.key[:, :, None, None],
                self.value[:, :, None, None],
            )
        raise AssertionError

    def normalize_bmhk(self) -> Tuple[int, ...]:
        if self.query.ndim not in [3, 4, 5]:
            raise ValueError(
                f"Invalid shape for query: {self.query.shape}. "
                "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
                ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
            )
        if self.value.dtype == torch.int32:
            # Quantized K/V case, in which the last dims of Q and K are different.
            # NB we currently don't have any implementations for quantized KV with
            # SUPPORTS_DIFFERENT_VALUE_EMBED.
            output_shape: Tuple[int, ...] = tuple(self.query.shape)
        else:
            output_shape = tuple(self.query.shape)[:-1] + (self.value.shape[-1],)
        # Convert from legacy format
        if self.query.ndim == 3:
            self.query = self.query.unsqueeze(2)
            self.key = self.key.unsqueeze(2)
            self.value = self.value.unsqueeze(2)
            self.attn_bias = _attn_bias_apply(
                self.attn_bias, partial(torch.unsqueeze, dim=1)
            )
        return output_shape

    def validate_inputs(self) -> None:  # noqa: C901
        qkv = (self.query, self.key, self.value)
        if self.query.ndim not in (3, 4, 5) or any(
            x.ndim != self.query.ndim for x in qkv
        ):
            raise ValueError(
                f"Query/Key/Value should all have BMGHK, BMHK or BMK shape.\n"
                f"  query.shape: {self.query.shape}\n"
                f"  key.shape  : {self.key.shape}\n"
                f"  value.shape: {self.value.shape}"
            )
        if any(x.device != self.query.device for x in qkv):
            raise ValueError("Query/Key/Value should all be on the same device")
        if isinstance(
            self.attn_bias,
            (
                BlockDiagonalMask,
                BlockDiagonalPaddedKeysMask,
                PagedBlockDiagonalPaddedKeysMask,
                BlockDiagonalGappyKeysMask,
                PagedBlockDiagonalGappyKeysMask,
            ),
        ):
            bias_device = self.attn_bias.q_seqinfo.seqstart.device
            if bias_device != self.query.device:
                raise ValueError(
                    f"Attention bias and Query/Key/Value should be on the same device\n"
                    f"  query.device: {self.query.device}\n"
                    f"  attn_bias   : {bias_device}\n"
                )

        quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
        non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
        if not (quantized_dtypes or non_quantized_dtypes):
            raise ValueError(
                "Query/Key/Value should either all have the same dtype, or "
                "(in the quantized case) Key/Value should have dtype torch.int32\n"
                f"  query.dtype: {self.query.dtype}\n"
                f"  key.dtype  : {self.key.dtype}\n"
                f"  value.dtype: {self.value.dtype}"
            )
        # Biases with tensors attached are meant to be in BMHK format
        # This would require to permute biases/gradients which can be expensive,
        # so let's just forbid it - BMK is a legacy format anyway
        if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
            type(self.attn_bias)
        ):
            raise ValueError(
                f"Please provide inputs in BMHK format rather "
                f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
            )
        attn_bias_t: Optional[torch.Tensor] = None
        if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
            attn_bias_t = self.attn_bias._bias
        elif isinstance(self.attn_bias, torch.Tensor):
            attn_bias_t = self.attn_bias
        if self.query.ndim == 4 and attn_bias_t is not None:
            expected_shape = (
                self.query.shape[0],
                self.query.shape[2],
                self.query.shape[1],
                self.key.shape[1],
            )
            if attn_bias_t.shape != expected_shape:
                raise ValueError(
                    f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
                    f"  query.shape: {self.query.shape}\n"
                    f"  key.shape  : {self.key.shape}\n"
                    f"  value.shape: {self.value.shape}"
                )
        if isinstance(self.attn_bias, BlockDiagonalMask):
            if any(x.shape[0] != 1 for x in qkv):
                raise ValueError(
                    f"Expected batch_size=1 when using block-diagonal bias\n"
                    f"  query.shape: {self.query.shape}\n"
                    f"  key.shape  : {self.key.shape}\n"
                    f"  value.shape: {self.value.shape}"
                )
        if self.p < 0.0 or self.p > 1.0:
            raise ValueError(f"Invalid dropout probability: p={self.p}")
        # Check that shapes match between inputs
        B, Mq = self.query.shape[:2]
        K = self.query.shape[-1]
        B, Mkv = self.key.shape[:2]
        Kv = self.value.shape[-1]
        quantized_kv_cache = self.value.dtype == torch.int32
        key_embed_dim = Kv if quantized_kv_cache else K

        valid_shapes = True
        if self.query.ndim == 3:  # BMK
            valid_shapes = (
                self.query.shape == (B, Mq, K)
                and self.key.shape == (B, Mkv, K)
                and self.value.shape == (B, Mkv, Kv)
            )
        H = self.query.shape[-2]
        if self.query.ndim == 4:  # BMHK
            valid_shapes = (
                self.query.shape == (B, Mq, H, K)
                and self.key.shape == (B, Mkv, H, key_embed_dim)
                and self.value.shape == (B, Mkv, H, Kv)
            )
        G = self.query.shape[2]
        if self.query.ndim == 5:  # BMNHK
            valid_shapes = (
                self.query.shape == (B, Mq, G, H, K)
                and self.key.shape == (B, Mkv, G, H, key_embed_dim)
                and self.value.shape == (B, Mkv, G, H, Kv)
            )
        if not valid_shapes:
            raise ValueError(
                f"Incompatible shapes for attention inputs:\n"
                f"  query.shape: {self.query.shape}\n"
                f"  key.shape  : {self.key.shape}\n"
                f"  value.shape: {self.value.shape}\n"
                "HINT: We don't support broadcasting, please use `expand` "
                "yourself before calling `memory_efficient_attention` if you need to"
            )

    def get_output_dtype(self) -> torch.dtype:
        if self.output_dtype is None:
            if self.is_partial and self.query.dtype is not torch.float64:
                return torch.float32
            return self.query.dtype
        return self.output_dtype

    @property
    def nbytes(self) -> int:
        """
        Number of bytes in the input, not counting the attention bias.
        """
        return sum(
            x.untyped_storage().nbytes() for x in [self.query, self.key, self.value]
        )

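An illustrative sketch (not from the diff) of the Inputs container above: validate_inputs() rejects inconsistent shapes, dtypes, and devices, and normalize_bmhk() promotes legacy BMK tensors to BMHK while returning the output shape the caller should allocate. It assumes the hunk ships as mslk.attention.fmha.common.

# Illustrative sketch (not part of the wheel): BMK inputs are promoted to BMHK
# by normalize_bmhk(), which also returns the expected output shape.
import torch

from mslk.attention.fmha.common import Inputs

q = torch.randn(2, 16, 64, dtype=torch.float16)   # BMK (legacy, no head dim)
k = torch.randn(2, 32, 64, dtype=torch.float16)
v = torch.randn(2, 32, 64, dtype=torch.float16)

inp = Inputs(query=q, key=k, value=v, p=0.0)
inp.validate_inputs()                 # raises ValueError on mismatched shapes/dtypes/devices
out_shape = inp.normalize_bmhk()      # query/key/value are now BMHK
print(out_shape, inp.query.shape)     # (2, 16, 64) torch.Size([2, 16, 1, 64])
print(inp.scale_float)                # defaults to 1/sqrt(K) = 0.125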
@dataclass
class Context:
    lse: torch.Tensor
    out: torch.Tensor
    # NOTE: If `rng_state` is set, `op_bw` should be set as well
    # as the randomness is backend-dependant
    op_bw: Optional[Type["AttentionBwOpBase"]] = None
    rng_state: Optional[Any] = None
    qkv_share_storage: bool = False

    def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
        pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to
        lse = self.lse
        if pad_amount > 0:
            if force_pad_inf:
                lse = lse[:, :, : self.out.shape[1]]
                pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to
            lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
        elif force_pad_inf and self.out.shape[1] != lse.shape[2]:
            lse[:, :, self.out.shape[1] :].fill_(math.inf)
        return lse

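The padding arithmetic in get_padded_lse() is easiest to see on a concrete shape. A plain-torch sketch of the same computation (illustrative only, not from the wheel):

# pad_amount rounds lse.shape[2] up to the next multiple of pad_to;
# the new columns are filled with +inf, as in get_padded_lse() above.
import math
import torch

lse = torch.randn(2, 4, 100)                                # [B, H, Mq]
pad_to = 32
pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to    # 28
padded = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
print(padded.shape)                                         # torch.Size([2, 4, 128])
print(torch.isinf(padded[..., 100:]).all())                 # tensor(True)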
@dataclass
class Gradients:
    dq: torch.Tensor
    dk: torch.Tensor
    dv: torch.Tensor
    # bias gradient. None if there is no tensor bias or if it doesn't require grad
    db: Optional[torch.Tensor] = None

class AttentionOpBase(BaseOperator):
    """Base class for any attention operator in xFormers

    See:

    - :attr:`xformers.ops.fmha.cutlass.FwOp`
    - :attr:`xformers.ops.fmha.cutlass.BwOp`
    - :attr:`xformers.ops.fmha.flash.FwOp`
    - :attr:`xformers.ops.fmha.flash.BwOp`
    - :attr:`xformers.ops.fmha.triton.FwOp`
    - :attr:`xformers.ops.fmha.triton.BwOp`
    """

    OPERATOR: Any  # pyre-ignore[13]
    SUPPORTED_DEVICES: Set[str]  # pyre-ignore[13]
    CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
    CUDA_MAXIMUM_COMPUTE_CAPABILITY: Optional[Tuple[int, int]] = None
    SUPPORTED_DTYPES: Set[torch.dtype]  # pyre-ignore[13]
    SUPPORTED_MAX_K: float  # pyre-ignore[13]
    SUPPORTED_MIN_K: int = 0
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (type(None),)
    SUPPORTS_DROPOUT: bool  # pyre-ignore[13]
    SUPPORTS_CUSTOM_SCALE: bool = False
    SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
    SUPPORTS_OUTPUT_DTYPE: bool = False
    SUPPORTS_PARTIAL: bool = False
    IS_DETERMINISTIC: bool = True
    SUPPORTS_BMGHK: bool = False
    NAME: str  # pyre-ignore[13]
    OPERATOR_CATEGORY = "memory_efficient_attention"
    # Format for the LSE computed in the FW pass, and accepted in the BW pass,
    # for BlockDiagonalMask and children.
    # When using a varlen bias, both the FW and BW operators must have the
    # same value for `VARLEN_LSE_PACKED`
    VARLEN_LSE_PACKED: bool = True

    _TEST_BATCH_SIZES: List[int] = [1, 300]
    _TEST_K: List[int] = [32, 128]

    @classmethod
    def supports(cls, d: Inputs) -> bool:
        return not cls.not_supported_reasons(d)

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        reasons = []
        if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
            reasons.append("query.shape[-1] != value.shape[-1]")
        if max(K, Kv) > cls.SUPPORTED_MAX_K:
            reasons.append(
                f"max(query.shape[-1], value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
            )
        if min(K, Kv) < cls.SUPPORTED_MIN_K:
            reasons.append(
                f"min(query.shape[-1], value.shape[-1]) < {cls.SUPPORTED_MIN_K}"
            )
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:  # noqa: C901
        """
        Returns a list of reasons why this is not supported.
        The kernel can run these inputs only if the returned list is empty
        """
        query_shape = d.query.shape
        reasons = cls.shape_not_supported_reasons(
            Mq=query_shape[1],
            Mkv=d.key.shape[1],
            K=query_shape[-1],
            Kv=query_shape[-1] if d.value.dtype == torch.int32 else d.value.shape[-1],
        )
        device_type = d.query.device.type
        dtype = d.query.dtype
        if device_type not in cls.SUPPORTED_DEVICES:
            reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
        if (
            device_type == "cuda"
            and not _built_with_cuda
            and (torch.version.hip is None)
        ):
            reasons.append("xFormers wasn't build with CUDA support")
        if device_type == "cuda" and (torch.version.hip is None):
            device_capability = torch.cuda.get_device_capability(d.device)
            if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
                reasons.append(
                    f"requires device with capability >= {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
                    f"but your GPU has capability {device_capability} (too old)"
                )
            elif (
                cls.CUDA_MAXIMUM_COMPUTE_CAPABILITY is not None
                and device_capability > cls.CUDA_MAXIMUM_COMPUTE_CAPABILITY
            ):
                reasons.append(
                    f"requires device with capability <= {cls.CUDA_MAXIMUM_COMPUTE_CAPABILITY} "
                    f"but your GPU has capability {device_capability} (too new)"
                )
        if dtype not in cls.SUPPORTED_DTYPES:
            reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
        if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
            reasons.append(f"attn_bias type is {type(d.attn_bias)}")
        if not cls.SUPPORTS_OUTPUT_DTYPE:
            if d.output_dtype is not None and d.output_dtype is not dtype:
                reasons.append("Custom output dtype not supported")
        if d.is_partial and not cls.SUPPORTS_PARTIAL:
            reasons.append("Partial attention not supported")
        if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
            reasons.append("dropout > 0.0")
        if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
            reasons.append("has custom scale")
        # bfloat16 is only supported on A100+ and MTIA
        # ... although the kernels can still run and give the
        # correct result
        supports_bf16 = (
            device_type.startswith("cuda")
            and torch.cuda.get_device_capability(d.query.device)[0] >= 8
        ) or device_type.startswith("mtia")
        if dtype is torch.bfloat16 and not supports_bf16:
            reasons.append("bf16 is only supported on A100+ GPUs and MTIA")
        if not cls.is_available():
            reasons.append(
                "operator wasn't built - see `python -m xformers.info` for more info"
            )
        if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
            reasons.append(
                "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
            )
        if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
            reasons.append("operator does not support BMGHK format")
        return reasons

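supports() and not_supported_reasons() above only read class attributes, so a toy subclass is enough to see the dispatch contract. A hedged sketch with a hypothetical ToyFwOp (not shipped in the wheel), again assuming the module path mslk.attention.fmha.common:

# Hedged sketch (hypothetical operator, not from the wheel): class attributes
# declare capabilities, and not_supported_reasons() explains any mismatch.
import torch

from mslk.attention.fmha import common


class ToyFwOp(common.AttentionFwOpBase):
    OPERATOR = None                       # no real kernel behind this sketch
    NAME = "toy_fw"
    SUPPORTED_DEVICES = {"cpu", "cuda"}
    SUPPORTED_DTYPES = {torch.float16, torch.bfloat16}
    SUPPORTED_MAX_K = 128
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True


q = k = v = torch.randn(1, 16, 8, 32, dtype=torch.float16)   # BMHK
inp = common.Inputs(query=q, key=k, value=v, p=0.1)          # dropout requested

# Dropout is not supported by this toy op, so it must be rejected:
print(ToyFwOp.supports(inp))                # False (is_available() may add further reasons)
print(ToyFwOp.not_supported_reasons(inp))   # includes 'dropout > 0.0'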
class AttentionFwOpBase(AttentionOpBase):
    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 3e-4,
        torch.half: 4e-3,
        torch.bfloat16: 2e-2,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 2e-5,
        torch.half: 4e-4,
        torch.bfloat16: 5e-3,
    }

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        raise NotImplementedError()


class AttentionBwOpBase(AttentionOpBase):
    # NOTE on tolerances: These are tested for `scales => (1/32)**0.5`
    # In the BW pass, imprecisions accumulate in the Q@K.T recalculation
    # These imprecisions are multiplied by the `scale` and then exponentiated
    # So if the scale is too high, we get a lot of errors

    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 9e-4,
        torch.half: 0.2,
        torch.bfloat16: 0.9,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 1e-4,
        torch.half: 2e-2,
        torch.bfloat16: 0.1,
    }
    SUPPORTS_ATTN_BIAS_GRAD = False
    SUPPORTS_PARTIAL = True

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
        if (
            isinstance(d.attn_bias, torch.Tensor)
            and d.attn_bias.requires_grad
            and not cls.SUPPORTS_ATTN_BIAS_GRAD
        ):
            reasons.append(
                "Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
            )

        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        raise NotImplementedError()


AttentionOp = Tuple[
    Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
]

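The per-dtype ERROR_ATOL/ERROR_RTOL tables above are meant to be fed into a comparison against a reference implementation. A hedged sketch using torch.testing.assert_close with stand-in tensors (not from the wheel, same module-path assumption):

# Hedged sketch: use a forward op's per-dtype tolerances when comparing
# its output against a reference result.
import torch

from mslk.attention.fmha.common import AttentionFwOpBase

dtype = torch.half
out = torch.randn(2, 16, 8, 64, dtype=dtype)       # stand-in for the op's output
ref = out + 1e-4 * torch.randn_like(out)           # stand-in for a reference result

torch.testing.assert_close(
    out.float(),
    ref.float(),
    atol=AttentionFwOpBase.ERROR_ATOL[dtype],      # 4e-3 for float16
    rtol=AttentionFwOpBase.ERROR_RTOL[dtype],      # 4e-4 for float16
)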
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
    if tensor.ndim == 4:
        return tensor
    return tensor.reshape(
        [tensor.shape[0] // num_heads, num_heads, tensor.shape[1], tensor.shape[2]]
    ).permute((0, 2, 1, 3))


def check_lastdim_alignment_stride1(
    reasons: List[str], name: str, x: torch.Tensor, alignment: int
) -> None:
    if x.shape[-1] % alignment != 0:
        reasons.append(f"{name}.shape[-1] % {alignment} != 0")
    elif x.stride(-2) % alignment != 0:
        reasons.append(
            f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
        )
    # We can have stride=0 sometimes if dimension=1
    if x.stride(-1) > 1:
        reasons.append(
            f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
        )
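A closing sketch (illustrative only, same module-path assumption) of the two helpers above: bmk2bmhk folds a (batch * heads, seqlen, K) tensor into BMHK, and check_lastdim_alignment_stride1 appends human-readable reasons instead of raising.

# Illustrative sketch (not part of the wheel).
import torch

from mslk.attention.fmha.common import bmk2bmhk, check_lastdim_alignment_stride1

# bmk2bmhk: a (batch * heads, seqlen, K) tensor becomes (batch, seqlen, heads, K).
bmk = torch.randn(2 * 8, 16, 64)
bmhk = bmk2bmhk(bmk, num_heads=8)
print(bmhk.shape)                     # torch.Size([2, 16, 8, 64])

# check_lastdim_alignment_stride1 reports a non-contiguous last dim
# (stride > 1) so that dispatch can skip the op instead of crashing.
reasons = []
strided = torch.randn(2, 16, 8, 128)[..., ::2]   # last-dim stride is 2
check_lastdim_alignment_stride1(reasons, "query", strided, alignment=8)
print(reasons)   # ["query.stride(-1) > 1 (...) - you should call `.contiguous()` on the input"]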