mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,967 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-unsafe
|
|
7
|
+
|
|
8
|
+
from typing import Any, cast, List, Optional, Sequence, Tuple, Type, Union
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from . import (
|
|
13
|
+
attn_bias,
|
|
14
|
+
ck,
|
|
15
|
+
ck_decoder,
|
|
16
|
+
ck_splitk,
|
|
17
|
+
cutlass,
|
|
18
|
+
cutlass_blackwell,
|
|
19
|
+
flash,
|
|
20
|
+
flash3,
|
|
21
|
+
flash_mtia,
|
|
22
|
+
triton_splitk,
|
|
23
|
+
)
|
|
24
|
+
from .attn_bias import (
|
|
25
|
+
AttentionBias,
|
|
26
|
+
BlockDiagonalMask,
|
|
27
|
+
LowerTriangularMask,
|
|
28
|
+
VARLEN_BIASES,
|
|
29
|
+
)
|
|
30
|
+
from .common import (
|
|
31
|
+
AttentionBwOpBase,
|
|
32
|
+
AttentionFwOpBase,
|
|
33
|
+
AttentionOp,
|
|
34
|
+
AttentionOpBase,
|
|
35
|
+
bmk2bmhk,
|
|
36
|
+
Context,
|
|
37
|
+
Gradients,
|
|
38
|
+
Inputs,
|
|
39
|
+
)
|
|
40
|
+
from .dispatch import (
|
|
41
|
+
_dispatch_bw,
|
|
42
|
+
_dispatch_fw,
|
|
43
|
+
_ensure_op_supports_or_raise,
|
|
44
|
+
_get_use_fa3,
|
|
45
|
+
_set_use_fa3,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
# Ready-made ``(forward_op, backward_op)`` pairs that can be passed as the
# ``op=`` argument of ``memory_efficient_attention``.
MemoryEfficientAttentionCutlassOp = (
    cutlass.FwOp,
    cutlass.BwOp,
)
MemoryEfficientAttentionCutlassBlackwellOp = (
    cutlass_blackwell.FwOp,
    cutlass_blackwell.BwOp,
)
# CUTLASS forward paired with the flash backward.
MemoryEfficientAttentionCutlassFwdFlashBwOp = (
    cutlass.FwOp,
    flash.BwOp,
)
MemoryEfficientAttentionFlashAttentionOp = (
    flash.FwOp,
    flash.BwOp,
)
MemoryEfficientAttentionFlashMtiaAttentionOp = (
    flash_mtia.FwOp,
    flash_mtia.BwOp,
)
MemoryEfficientAttentionCkOp = (
    ck.FwOp,
    ck.BwOp,
)
# The CK decoder and CK split-K forward ops reuse the regular CK backward op.
MemoryEfficientAttentionCkDecoderOp = (
    ck_decoder.FwOp,
    ck.BwOp,
)
MemoryEfficientAttentionSplitKCkOp = (
    ck_splitk.FwOp,
    ck.BwOp,
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _deserialize_bias(attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]) -> Any:
|
|
62
|
+
if attn_bias_tensor is None:
|
|
63
|
+
return attn_bias_ctx
|
|
64
|
+
return attn_bias_tensor
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# `torch.compile` only lets custom autograd functions take a subset of
# argument types, so the operator classes listed here are handed to `_fMHA`
# as their `NAME` string and turned back into classes inside it.
# See also: https://github.com/pytorch/pytorch/issues/118395
_OPS_LOOKUP = {
    serializable_op.NAME: serializable_op
    for serializable_op in (
        flash.FwOp,
        flash.BwOp,
        flash_mtia.FwOp,
        flash_mtia.BwOp,
    )
}
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _serialize_op(op):
|
|
80
|
+
if op is not None and op.NAME in _OPS_LOOKUP:
|
|
81
|
+
return op.NAME
|
|
82
|
+
return op
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _unserialize_op(op):
|
|
86
|
+
if isinstance(op, str):
|
|
87
|
+
return _OPS_LOOKUP[op]
|
|
88
|
+
return op
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class _fMHA(torch.autograd.Function):
    """Autograd ``Function`` pairing an fMHA forward operator with a backward operator.

    ``forward(ctx, op_fw, op_bw, *args)``:
      * ``op_fw`` / ``op_bw``: operator classes, ``None`` (dispatch
        automatically), or operator names produced by ``_serialize_op``.
      * ``args``: passed positionally to ``Inputs``;
        ``_memory_efficient_attention`` passes
        ``(query, key, value, attn_bias, p, scale)``.
    Returns ``(out, lse)``.
    """

    @staticmethod
    # type: ignore
    def forward(ctx, op_fw, op_bw, *args: Any) -> Any:
        inp = Inputs(*args)

        # Ops may arrive as `NAME` strings (see `_OPS_LOOKUP`).
        op_fw = _unserialize_op(op_fw)
        op_bw = _unserialize_op(op_bw)

        out, op_ctx = _memory_efficient_attention_forward_requires_grad(
            inp=inp, op=op_fw
        )

        # Split attn_bias into its tensor part and its non-tensor part so the
        # bias can be rebuilt in backward via `_deserialize_bias`.
        # NOTE(review): the tensor part is kept as a plain `ctx` attribute,
        # not passed through `save_for_backward`.
        if isinstance(inp.attn_bias, torch.Tensor):
            attn_bias_tensor = inp.attn_bias
            attn_bias_ctx = None
        else:
            attn_bias_tensor = None
            attn_bias_ctx = inp.attn_bias

        ctx.save_for_backward(
            inp.query,
            inp.key,
            inp.value,
            op_ctx.out,
            op_ctx.lse,
        )
        ctx.rng_state = op_ctx.rng_state
        ctx.attn_bias_tensor = attn_bias_tensor
        # The forward op may dictate which backward op has to be used.
        if op_ctx.op_bw is not None:
            if op_bw is not None and op_bw is not op_ctx.op_bw:
                raise ValueError(
                    f"Specified op_bw={op_bw.NAME}, but forward op "
                    f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
                )
            op_bw = op_ctx.op_bw
        # With more than one variable-length sequence, the FW and BW ops must
        # agree on the logsumexp layout (`VARLEN_LSE_PACKED`).
        if (
            op_fw is not None
            and op_bw is not None
            and isinstance(inp.attn_bias, VARLEN_BIASES)
            and inp.attn_bias.q_seqinfo.seqstart.shape[0] > 2
            and op_bw.VARLEN_LSE_PACKED != op_fw.VARLEN_LSE_PACKED
        ):
            raise ValueError(
                f"Specified op_bw={op_bw.NAME} is not compatible with the "
                f"op_fw={op_fw.NAME}, because they use different format of logsumexp. "
                "NOTE: This is new with xFormers 0.0.28"
            )
        if op_bw is None and (
            inp.query.requires_grad or inp.key.requires_grad or inp.value.requires_grad
        ):
            varlen_lse_packed = _detect_lse_packed_or_raise(op_ctx.lse, inp)
            if varlen_lse_packed is not None and op_fw is not None:
                assert op_fw.VARLEN_LSE_PACKED == varlen_lse_packed, (
                    f"{op_fw.NAME}: wrong value for `VARLEN_LSE_PACKED` ?"
                )
            # NOTE: We need to check tensor strides to decide which operator we run in the BW pass.
            # Unfortunately, PyTorch only allows to call this function during the FW pass, so
            # we decide the operator to use now.
            op_bw = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
        ctx.op_fw = op_fw
        ctx.op_bw = op_bw
        ctx.p = inp.p
        # This allows to create gradients from a single storage,
        # to avoid a "cat" in the BW pass.
        # The heuristic is approximative, but:
        # (1) It's not a big issue to create a shared storage
        # (2) The heuristic needs to pass `torch.compile`
        # (this is also why we run it in the FW pass, the BW pass is stricter)
        ctx.qkv_share_storage = (
            inp.query.shape[0] == inp.key.shape[0]
            and inp.query.shape[-1] == inp.value.shape[-1]
            and inp.query.stride(-2)
            == (inp.key.shape[-1] + inp.query.shape[-1] + inp.value.shape[-1])
        )

        ctx.scale = inp.scale
        ctx.attn_bias_ctx = attn_bias_ctx
        ctx.n_args = len(args)
        return out, op_ctx.lse

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad, grad_lse):
        # `grad_lse` is not used: only the gradient of `out` is propagated.
        # Re-create the `Inputs` / `Context` seen by the forward pass.
        query, key, value, out, lse = ctx.saved_tensors
        attn_bias_tensor = ctx.attn_bias_tensor
        rng_state = ctx.rng_state
        inp = Inputs(
            query=query,
            key=key,
            value=value,
            attn_bias=_deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
            p=ctx.p,
            scale=ctx.scale,
        )
        op_ctx = Context(
            lse=lse,
            out=out,
            rng_state=rng_state,
            qkv_share_storage=ctx.qkv_share_storage,
        )
        grads = _memory_efficient_attention_backward(
            ctx=op_ctx,
            inp=inp,
            grad=grad,
            op=ctx.op_bw,
            _skip_op_checks=True,
        )
        # Autograd expects one entry per `forward` input: `op_fw`, `op_bw`, then
        # the `ctx.n_args` positional args. The first four args
        # (query, key, value, attn_bias) get gradients; the rest get None.
        return (None, None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
            ctx.n_args - 4
        )
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def memory_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[AttentionOp] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Implements the memory-efficient attention mechanism following
    `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.

    :Inputs shape:

    - Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
        the sequence length, H the number of heads, and K the embedding size per head

    - If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``

    - Inputs can also be of dimension 5 with GQA - see note below

    - Inputs can be non-contiguous - we only require the last dimension's stride to be 1


    :Equivalent pytorch code:

    .. code-block:: python

        scale = 1.0 / query.shape[-1] ** 0.5
        query = query * scale
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        attn = query @ key.transpose(-2, -1)
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = attn.softmax(-1)
        attn = F.dropout(attn, p)
        attn = attn @ value
        return attn.transpose(1, 2).contiguous()

    :Examples:

    .. code-block:: python

        import xformers.ops as xops

        # Compute regular attention
        y = xops.memory_efficient_attention(q, k, v)

        # With a dropout of 0.2
        y = xops.memory_efficient_attention(q, k, v, p=0.2)

        # Causal attention
        y = xops.memory_efficient_attention(
            q, k, v,
            attn_bias=xops.LowerTriangularMask()
        )

    :Supported hardware:

        NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.

    :EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):

        MQA/GQA is an experimental feature supported only for the forward pass.
        If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
        in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
        ``H`` is the number of heads per group (8 in the example).

        Please note that xFormers will not automatically broadcast the inputs, so you will need
        to broadcast it manually before calling `memory_efficient_attention`.

    :GQA/MQA example:

    .. code-block:: python

        import torch
        import xformers.ops as xops

        B, M, K = 3, 32, 128
        kwargs = dict(device="cuda", dtype=torch.float16)
        q = torch.randn([B, M, 8, K], **kwargs)
        k = torch.randn([B, M, 2, K], **kwargs)
        v = torch.randn([B, M, 2, K], **kwargs)
        out_gqa = xops.memory_efficient_attention(
            q.reshape([B, M, 2, 4, K]),
            k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
            v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
        )

    Raises:
        NotImplementedError: if there is no operator available to compute the MHA
        ValueError: if inputs are invalid

    :parameter query: Tensor of shape ``[B, Mq, H, K]``
    :parameter key: Tensor of shape ``[B, Mkv, H, K]``
    :parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
    :parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
        For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
        This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
    :parameter p: Dropout probability. Disabled if set to ``0.0``
    :parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
        scale (q.shape[-1]**-0.5) will be used.
    :parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
        If set to ``None`` (recommended), xFormers \
        will dispatch to the best available operator, depending on the inputs \
        and options.
    :parameter output_dtype: Passed through to the operator via ``Inputs``; \
        presumably the dtype of the returned tensor (``None`` keeps the default).
    :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
    """
    # NOTE(review): when gradients are required, `_memory_efficient_attention`
    # does not pass `output_dtype` on to `_fMHA` — confirm this is intended.
    inputs = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
        output_dtype=output_dtype,
    )
    return _memory_efficient_attention(inputs, op=op)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# Declare the schema of the `mslk::memory_efficient_attention_forward` custom op.
# The bias `b` is optional here. The "Meta" and "CUDA" implementations are
# registered separately with `torch.library.impl`.
torch.library.define(
    "mslk::memory_efficient_attention_forward",
    "(Tensor q, Tensor k, Tensor v, Tensor? b = None, float? p = 0.0, float? scale = None) -> Tensor",
)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def _memory_efficient_attention_forward_torch_wrapper_meta(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
):
    """Meta (shape-only) kernel for ``mslk::memory_efficient_attention_forward``.

    Allocates an uninitialized output shaped like the real op's result: the
    query shape with its last dimension replaced by the value's last
    dimension (``[B, Mq, H, Kv]`` for BMHK inputs). No attention is computed.

    :param query: query tensor; its leading dims and dtype/device are reused.
    :param key: unused (the output shape does not depend on it).
    :param value: value tensor; only ``value.shape[-1]`` (Kv) is read.
    :param attn_bias: unused.
    :param p: unused.
    :param scale: unused.
    :return: an empty tensor of shape ``query.shape[:-1] + (Kv,)``.
    """
    if value.shape[-1] == query.shape[-1]:
        # K == Kv: the output has exactly the query's shape; keep the
        # original allocation (including its memory layout).
        return torch.empty_like(query)
    # K != Kv: the head dimension of the output follows `value`, not `query`.
    return query.new_empty((*query.shape[:-1], value.shape[-1]))
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
# Register `_memory_efficient_attention_forward_torch_wrapper_meta` as the
# implementation of the op for the "Meta" dispatch key. It only allocates an
# output tensor; no attention kernel runs.
torch.library.impl(
    "mslk::memory_efficient_attention_forward",
    "Meta",
    _memory_efficient_attention_forward_torch_wrapper_meta,
)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
# torch.compile struggles to trace through operator dispatch and
# `_ensure_op_supports_or_raise`, so this wrapper is registered as the CUDA
# implementation of a custom torch library op instead.
def _memory_efficient_attention_forward_torch_wrapper(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Torch-compilable entry point to ``memory_efficient_attention_forward``,
    backing the ``mslk::memory_efficient_attention_forward`` custom op.

    Not supported through this wrapper:
    - choosing ``op`` explicitly (the operator is always auto-dispatched)
    - ``output_dtype``
    - K != Kv
    - certain ``attn_bias`` types (unconfirmed)
    """
    return memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=p,
        scale=scale,
    )
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# Register the real implementation for the "CUDA" dispatch key. It forwards
# to `memory_efficient_attention_forward` with automatic operator dispatch.
torch.library.impl(
    "mslk::memory_efficient_attention_forward",
    "CUDA",
    _memory_efficient_attention_forward_torch_wrapper,
)
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
# Declare the schema of the `mslk::memory_efficient_attention_forward_with_bias`
# custom op. It is the same as `mslk::memory_efficient_attention_forward`
# except that the bias tensor `b` is mandatory.
torch.library.define(
    "mslk::memory_efficient_attention_forward_with_bias",
    "(Tensor q, Tensor k, Tensor v, Tensor b, float? p = 0.0, float? scale = None) -> Tensor",
)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def _memory_efficient_attention_forward_torch_wrapper_with_bias_meta(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Union[torch.Tensor, AttentionBias],
    p: float = 0.0,
    scale: Optional[float] = None,
):
    """Meta (shape-only) kernel for ``mslk::memory_efficient_attention_forward_with_bias``.

    Allocates an uninitialized output shaped like the real op's result: the
    query shape with its last dimension replaced by the value's last
    dimension (``[B, Mq, H, Kv]`` for BMHK inputs). No attention is computed.

    :param query: query tensor; its leading dims and dtype/device are reused.
    :param key: unused (the output shape does not depend on it).
    :param value: value tensor; only ``value.shape[-1]`` (Kv) is read.
    :param attn_bias: unused.
    :param p: unused.
    :param scale: unused.
    :return: an empty tensor of shape ``query.shape[:-1] + (Kv,)``.
    """
    if value.shape[-1] == query.shape[-1]:
        # K == Kv: the output has exactly the query's shape; keep the
        # original allocation (including its memory layout).
        return torch.empty_like(query)
    # K != Kv: the head dimension of the output follows `value`, not `query`.
    return query.new_empty((*query.shape[:-1], value.shape[-1]))
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
# Register the shape-only kernel for the "Meta" dispatch key of the
# `_with_bias` variant. It only allocates an output tensor.
torch.library.impl(
    "mslk::memory_efficient_attention_forward_with_bias",
    "Meta",
    _memory_efficient_attention_forward_torch_wrapper_with_bias_meta,
)
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# torch.compile struggles to trace through operator dispatch and
# `_ensure_op_supports_or_raise`, so this wrapper is registered as the CUDA
# implementation of a custom torch library op instead.


def _memory_efficient_attention_forward_torch_wrapper_with_bias(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Union[torch.Tensor, AttentionBias],
    p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Torch-compilable entry point to ``memory_efficient_attention_forward``
    with a mandatory bias, backing the
    ``mslk::memory_efficient_attention_forward_with_bias`` custom op.

    Not supported through this wrapper:
    - choosing ``op`` explicitly (the operator is always auto-dispatched)
    - ``output_dtype``
    - K != Kv
    - certain ``attn_bias`` types (unconfirmed)
    """
    return memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=p,
        scale=scale,
    )
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
# Register the real implementation of the `_with_bias` variant for the
# "CUDA" dispatch key. It forwards to `memory_efficient_attention_forward`.
torch.library.impl(
    "mslk::memory_efficient_attention_forward_with_bias",
    "CUDA",
    _memory_efficient_attention_forward_torch_wrapper_with_bias,
)
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def memory_efficient_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[Type[AttentionFwOpBase]] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Forward-only variant of ``memory_efficient_attention``.

    Arguments have the same meaning as for ``memory_efficient_attention``,
    except that ``op`` is a single forward operator class (or ``None`` to
    dispatch automatically). No autograd graph is built through ``_fMHA``.
    """
    inputs = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
        output_dtype=output_dtype,
    )
    return _memory_efficient_attention_forward(inputs, op=op)
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def memory_efficient_attention_forward_requires_grad(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[Type[AttentionFwOpBase]] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run the forward pass and also return the log-sum-exp (LSE).

    Returns ``(output, lse)``; ``lse`` can later be fed to
    ``memory_efficient_attention_backward``. Arguments match
    ``memory_efficient_attention``.

    Raises:
        NotImplementedError: if ``p != 0.0`` (dropout is only available via
            ``memory_efficient_attention``).
    """
    if p != 0.0:
        raise NotImplementedError(
            "dropout is not supported on the non-autograd API."
            " If you want to use dropout, please call `memory_efficient_attention` directly"
        )
    inputs = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
        output_dtype=output_dtype,
    )
    output, op_ctx = _memory_efficient_attention_forward_requires_grad(inputs, op=op)
    return output, op_ctx.lse
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def memory_efficient_attention_backward(
    grad: torch.Tensor,
    output: torch.Tensor,
    lse: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[Type[AttentionBwOpBase]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the attention gradients ``(dq, dk, dv)``.

    ``output`` and ``lse`` are the pair returned by
    ``memory_efficient_attention_forward_requires_grad``; the remaining
    arguments match ``memory_efficient_attention``. ``op`` is a single
    backward operator class (or ``None`` to dispatch automatically).

    Raises:
        NotImplementedError: if ``p != 0.0``.
    """
    if p != 0.0:
        raise NotImplementedError(
            "dropout is not supported on the non-autograd API."
            " If you want to use dropout, please call `memory_efficient_attention` directly"
        )
    forward_ctx = Context(out=output, lse=lse)
    inputs = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
    )
    grads = _memory_efficient_attention_backward(forward_ctx, inputs, grad, op=op)
    return grads.dq, grads.dk, grads.dv
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
def _memory_efficient_attention(
    inp: Inputs, op: Optional[AttentionOp] = None
) -> torch.Tensor:
    """Run attention on ``inp``, going through autograd only when needed.

    ``op`` is an optional ``(forward_op, backward_op)`` pair.

    * No gradient needed: only the forward operator runs, so no logsumexp
      has to be kept for a backward pass.
    * Otherwise the inputs are normalized and routed through ``_fMHA``.
    """
    if not any(t.requires_grad for t in [inp.query, inp.key, inp.value]):
        # Fast path: nothing to differentiate.
        return _memory_efficient_attention_forward(
            inp, op=None if op is None else op[0]
        )

    output_shape = inp.normalize_bmhk()

    # `_fMHA` receives ops as `NAME` strings when possible (torch.compile only
    # accepts a subset of argument types for custom autograd functions).
    fw_op_arg = _serialize_op(None if op is None else op[0])
    bw_op_arg = _serialize_op(None if op is None else op[1])
    # NOTE(review): the `output_dtype` carried by `inp` is not among the
    # arguments passed to `_fMHA` — confirm it is meant to be ignored here.
    outputs = _fMHA.apply(
        fw_op_arg,
        bw_op_arg,
        inp.query,
        inp.key,
        inp.value,
        inp.attn_bias,
        inp.p,
        inp.scale,
    )
    return outputs[0].reshape(output_shape)
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
def _memory_efficient_attention_forward(
    inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> torch.Tensor:
    """Validate/normalize ``inp`` and run a forward operator without gradients.

    If ``op`` is given, it is checked against ``inp`` (``ValueError`` if
    unsupported); otherwise one is chosen via ``_dispatch_fw``. The result is
    reshaped to the shape returned by ``inp.normalize_bmhk()``.
    """
    inp.validate_inputs()
    output_shape = inp.normalize_bmhk()
    if op is not None:
        _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
    else:
        op = _dispatch_fw(inp, False)

    result = op.apply(inp, needs_gradient=False)
    return result[0].reshape(output_shape)
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def _memory_efficient_attention_forward_requires_grad(
    inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> Tuple[torch.Tensor, Context]:
    """Like ``_memory_efficient_attention_forward`` but also returns the op ``Context``.

    The operator is run with ``needs_gradient=True`` and must produce a
    non-``None`` context (it carries ``out``, ``lse``, ``rng_state``, ...).
    """
    inp.validate_inputs()
    output_shape = inp.normalize_bmhk()
    if op is not None:
        _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
    else:
        op = _dispatch_fw(inp, True)
    output, op_ctx = op.apply(inp, needs_gradient=True)
    assert op_ctx is not None
    return output.reshape(output_shape), op_ctx
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def _detect_lse_packed_or_raise(lse: torch.Tensor, inp: Inputs) -> Optional[bool]:
    """
    Work out the logsumexp (LSE) layout used for variable-length inputs.

    Returns:
        ``True`` for the packed layout, ``False`` for the padded layout, and
        ``None`` when the layout is irrelevant (non-varlen bias) or ambiguous
        (both layouts fit the shape).

    Raises:
        ValueError: if ``lse`` matches none of the accepted shapes.
    """
    shape_mismatch_err = (
        "Input tensors have incompatible shapes.\n"
        f" lse.shape : {lse.shape}\n"
        f" query.shape : {inp.query.shape}\n"
        f" attn_bias : {type(inp.attn_bias)}"
    )
    q_shape = inp.query.shape

    # 1. Rank and head dims: LSE drops the query's last (K) dim and keeps its
    #    head dims (*GH) right after the leading dim.
    rank_ok = lse.ndim == inp.query.ndim - 1
    if not rank_ok or lse.shape[1:-1] != q_shape[2:-1]:
        raise ValueError(shape_mismatch_err)

    lse_rows, lse_cols = lse.shape[0], lse.shape[-1]
    # Packed / batched layout for query [B, M, *GH, K]: LSE is [B, *GH, >= M].
    lse_packed = lse_rows == q_shape[0] and lse_cols >= q_shape[1]

    # 2. Varlen biases (query is [1, M, *GH, K]) accept either
    #    [1, *GH, >= M] (packed) or [num_seq, *GH, >= max_q] (padded).
    if isinstance(inp.attn_bias, VARLEN_BIASES):
        seqinfo = inp.attn_bias.q_seqinfo
        num_seqs = seqinfo.seqstart.shape[0] - 1
        max_q = seqinfo.max_seqlen
        lse_padded = lse_rows == num_seqs and lse_cols >= max_q
        if lse_packed and lse_padded:
            return None  # ambiguous: both layouts fit
        if lse_packed:
            return True
        if lse_padded:
            return False
        raise ValueError(shape_mismatch_err)

    # 3. Non-varlen: only the [B, *GH, >= M] layout is valid.
    if not lse_packed:
        raise ValueError(shape_mismatch_err)
    return None
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
def _memory_efficient_attention_backward(
    ctx: Context,
    inp: Inputs,
    grad: torch.Tensor,
    op: Optional[Type[AttentionBwOpBase]],
    *,
    _skip_op_checks: bool = False,
) -> Gradients:
    """Compute attention gradients for ``inp`` given the upstream ``grad``.

    Warning: ``grad`` and ``ctx.out`` may be in BMK format. They are converted
    to BMHK here, and ``ctx.out`` is overwritten with the converted tensor.

    If ``op`` is ``None`` a backward operator is dispatched from the inputs
    and the detected LSE layout. Otherwise it is checked against ``inp``,
    unless ``_skip_op_checks`` is set. ``dq``/``dk``/``dv`` are returned in
    the caller's original query/key/value shapes.

    Raises:
        ValueError: on mismatched ranks, an unsupported ``op``, or an LSE
            layout the chosen ``op`` cannot read.
    """
    inp.validate_inputs()
    if len({grad.ndim, inp.query.ndim, ctx.out.ndim}) != 1:
        raise ValueError(
            "All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
            f"grad.shape : {grad.shape} \n"
            f"out.shape : {ctx.out.shape} \n"
            f"query.shape: {inp.query.shape}"
        )
    # Remember the caller-facing shapes before `inp` is normalized.
    query_shape = inp.query.shape
    key_shape = inp.key.shape
    value_shape = inp.value.shape
    inp.normalize_bmhk()
    varlen_lse_packed = _detect_lse_packed_or_raise(ctx.lse, inp)
    grad = bmk2bmhk(grad, 1)
    ctx.out = bmk2bmhk(ctx.out, 1)

    if op is None:
        op = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
    elif not _skip_op_checks:
        _ensure_op_supports_or_raise(
            ValueError, "memory_efficient_attention_backward", op, inp
        )
    # The LSE layout produced in the forward pass must match what `op` reads.
    if varlen_lse_packed is not None and varlen_lse_packed != op.VARLEN_LSE_PACKED:
        raise ValueError(
            f"Wrong LSE format for {op.NAME} in variable seqlen case. "
            f"Double-check that the BW operator {op.NAME} is compatible "
            "with the operator used in the FW pass."
        )

    result = op.apply(ctx, inp, grad)
    result.dq = result.dq.reshape(query_shape)
    result.dk = result.dk.reshape(key_shape)
    result.dv = result.dv.reshape(value_shape)
    return result
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
def memory_efficient_attention_partial(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[Union[AttentionOp, Type[AttentionFwOpBase]]] = None,
    output_dtype: Optional[torch.dtype] = None,
    _allow_backward: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute attention over one slice of keys/values and return ``(output, lse)``.

    ``output`` has the same layout as the result of ``memory_efficient_attention``,
    and ``lse`` is the accompanying log-sum-exp. Several such partial results
    that share the same queries but use disjoint key/value sets can be combined
    with ``merge_attentions``. The combination is the attention of the queries
    over the union of those keys and values.

    By default this function provides no backward pass. Passing
    ``_allow_backward=True`` enables a restricted backward pass. Only the
    gradient flowing into ``output`` is used; the gradient of ``lse`` is not.
    This makes it very easy to end up with wrong gradients accidentally.

    Raises:
        NotImplementedError: if ``p`` (dropout) is non-zero.
        ValueError: if a backward pass is requested for 5D inputs.
    """
    if p != 0.0:
        raise NotImplementedError("dropout is not supported.")

    # ``op`` may be a (forward, backward) pair; the forward-only path needs
    # just the forward half.
    op_is_pair = isinstance(op, tuple)
    fwop: Optional[Type[AttentionFwOpBase]] = op[0] if op_is_pair else op
    inp = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
        output_dtype=output_dtype,
        is_partial=True,
    )

    needs_autograd = (
        _allow_backward
        and torch.is_grad_enabled()
        and any(t.requires_grad for t in (query, key, value))
    )
    if not needs_autograd:
        out, ctx = _memory_efficient_attention_forward_requires_grad(inp, op=fwop)
        return out, ctx.lse

    if query.ndim == 5:
        raise ValueError("gradients not supported for 5D tensors")

    # Serialize the requested operators for the autograd Function; a bare
    # forward op leaves backward dispatch open.
    if op is None:
        op_fw = op_bw = None
    elif op_is_pair:
        op_fw, op_bw = _serialize_op(op[0]), _serialize_op(op[1])
    else:
        op_fw, op_bw = _serialize_op(op), None

    return _fMHA.apply(
        op_fw,
        op_bw,
        inp.query,
        inp.key,
        inp.value,
        inp.attn_bias,
        inp.p,
        inp.scale,
        inp.output_dtype,
        inp.is_partial,
    )
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
def merge_attentions(  # noqa: C901
    attn_split: Union[torch.Tensor, Sequence[torch.Tensor]],
    lse_split: Union[torch.Tensor, Sequence[torch.Tensor]],
    write_lse: bool = True,
    output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Combine attention output computed on different parts of K/V for the same
    query to get attention on the whole K/V. See https://arxiv.org/abs/2402.05099
    The result is equal to
        Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...)
        LSE_full = log(exp(LSE1) + exp(LSE2) + ...)

    Args:
        attn_split: attention outputs for chunks,
            either as a list of tensors of shapes [B, M, G, H, Kq] or [B, M, H, Kq]
            or as a single tensor of shape [num_chunks, B, M, G, H, Kq]
            or [num_chunks, B, M, H, Kq]
        lse_split: LSE for chunks,
            either as a list of tensors of shapes [B, G, H, M] or [B, H, M]
            or as a single tensor of shape [num_chunks, B, G, H, M] or [num_chunks, B, H, M]
        write_lse: whether to output LSE
        output_dtype: dtype of attn_out

    Returns:
        attn_out: [B, M, G, H, Kq] or [B, M, H, Kq]
        lse_out: [B, G, H, M] or [B, H, M] if write_lse
            or None otherwise

    Raises:
        ValueError: if ``write_lse`` is False while an input requires
            gradients, if no chunks are given, or if the attention and LSE
            shapes (number of chunks, B, G, H, M) are inconsistent.
    """

    attn_is_concat = isinstance(attn_split, torch.Tensor)
    lse_is_concat = isinstance(lse_split, torch.Tensor)

    attn_requires_grad = (
        attn_split.requires_grad  # type: ignore
        if attn_is_concat
        else any(x.requires_grad for x in attn_split)
    )
    lse_requires_grad = (
        lse_split.requires_grad  # type: ignore
        if lse_is_concat
        else any(x.requires_grad for x in lse_split)
    )
    requires_grad = torch.is_grad_enabled() and (
        attn_requires_grad or lse_requires_grad
    )
    if requires_grad and not write_lse:
        raise ValueError("write_lse should be true if inputs require gradients.")

    # The preallocated-output kernel is used only when both inputs are already
    # stacked along a leading chunk dimension and no autograd graph is needed.
    # Every other case goes through ``merge_attentions_varargs``.
    concat_path = attn_is_concat and lse_is_concat and not requires_grad
    if concat_path:
        attn_split = cast(torch.Tensor, attn_split)
        lse_split = cast(torch.Tensor, lse_split)
        if attn_split.ndim != lse_split.ndim + 1:
            raise ValueError(
                f"Incompatible input shapes: {attn_split.shape=}, {lse_split.shape=}"
            )

        # 5D stacked attention has no G dimension. Insert a singleton G so
        # that both layouts share the code below; it is removed again at the end.
        is_bmhk = attn_split.ndim == 5
        if is_bmhk:
            attn_split = attn_split.unsqueeze(3)
            lse_split = lse_split.unsqueeze(2)

        num_chunks, B, M, G, H, Kq = attn_split.shape
        num_chunks1, B1, G1, H1, M1 = lse_split.shape
        if B != B1 or G != G1 or H != H1 or num_chunks != num_chunks1 or M != M1:
            raise ValueError(
                f"Incompatible input shapes: {attn_split.shape=} {lse_split.shape=} "
                f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {num_chunks}/{num_chunks1}, {M}/{M1}"
            )

        # Reorder the dimensions. Attention becomes [B, G, H, num_chunks, M, Kq]
        # and LSE becomes [B, G, H, num_chunks, M].
        attn_split = attn_split.permute(1, 3, 4, 0, 2, 5)
        lse_split = lse_split.permute(1, 2, 3, 0, 4)

        device = attn_split.device
        attn_dtype = attn_split.dtype
        lse_dtype = lse_split.dtype
    else:
        if attn_is_concat:
            attn_split = attn_split.unbind(0)  # type: ignore
        if lse_is_concat:
            lse_split = lse_split.unbind(0)  # type: ignore
        num_chunks = len(attn_split)
        if len(lse_split) != num_chunks:
            raise ValueError(
                f"Incompatible number of LSE and attention chunks: {len(attn_split)=}, {len(lse_split)=}"
            )
        # Without at least one chunk there is no reference shape to read below.
        if num_chunks == 0:
            raise ValueError("merge_attentions requires at least one chunk, got 0")

        attn_unsqueezed = []
        lse_unsqueezed = []
        is_bmhk = False
        for i in range(num_chunks):
            if attn_split[i].ndim != lse_split[i].ndim + 1:
                raise ValueError(
                    f"Incompatible input shapes for chunk {i}: {attn_split[i].shape=}, {lse_split[i].shape=}"
                )

            # NOTE(review): is_bmhk ends up reflecting the last chunk's rank.
            # Mixing 4D and 5D chunks is not rejected when G == 1.
            is_bmhk = attn_split[i].ndim == 4
            if is_bmhk:
                attn_unsqueezed.append(attn_split[i].unsqueeze(2))
                lse_unsqueezed.append(lse_split[i].unsqueeze(1))
            else:
                attn_unsqueezed.append(attn_split[i])
                lse_unsqueezed.append(lse_split[i])
        attn_split, lse_split = attn_unsqueezed, lse_unsqueezed

        B, M, G, H, Kq = attn_split[0].shape
        B1, G1, H1, M1 = lse_split[0].shape
        if B != B1 or G != G1 or H != H1 or M != M1:
            raise ValueError(
                f"Incompatible input shapes: {attn_split[0].shape=}, {lse_split[0].shape=} "
                f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {M}/{M1}"
            )

        for i in range(num_chunks):
            if attn_split[i].shape != (B, M, G, H, Kq):
                raise ValueError(
                    f"Incompatible input shapes for attention chunk {i}: "
                    f"{attn_split[i].shape=}, {(B, M, G, H, Kq)=}"
                )
            if lse_split[i].shape != (B, G, H, M):
                raise ValueError(
                    f"Incompatible input shapes for LSE chunk {i}: "
                    f"{lse_split[i].shape=}, {(B, G, H, M)=}"
                )

            attn_split[i] = attn_split[i].permute(0, 2, 3, 1, 4)  # to (B, G, H, M, Kq)

        device = attn_split[0].device
        attn_dtype = attn_split[0].dtype
        lse_dtype = lse_split[0].dtype

    if concat_path:
        attn_out = torch.empty(
            B,
            M,
            G,
            H,
            Kq,
            device=device,
            dtype=output_dtype or attn_dtype,
        )
        if write_lse:
            lse_out = torch.empty(
                B,
                G,
                H,
                M,
                device=device,
                dtype=lse_dtype,
            )
        else:
            lse_out = None
        triton_splitk.merge_attentions(attn_out, lse_out, attn_split, lse_split)  # type: ignore
    else:
        outs = triton_splitk.merge_attentions_varargs(
            attn_split, lse_split, write_lse, output_dtype, B, M, G, H, Kq
        )  # type: ignore
        attn_out = outs[0]
        lse_out = outs[1] if write_lse else None

    # Drop the singleton G dimension that was inserted for 4D/5D-stacked inputs.
    if is_bmhk:
        attn_out = attn_out[:, :, 0]
        if lse_out is not None:
            lse_out = lse_out[:, 0]

    return attn_out, lse_out
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
# Forward attention operators registered by this module. The first slot is
# the CUTLASS kernel when ``torch.version.cuda`` is set; otherwise it is the CK
# kernel (presumably the ROCm build — confirm).
ALL_FW_OPS: List[Type[AttentionFwOpBase]] = [
    cutlass.FwOp if torch.version.cuda else ck.FwOp,
    cutlass_blackwell.FwOp,
    flash.FwOp,
    flash_mtia.FwOp,
    flash3.FwOp,
    triton_splitk.FwOp,
]

# Backward attention operators registered by this module. The first slot is
# chosen the same way as in ALL_FW_OPS. triton_splitk has no entry here: only
# its forward op is listed above.
ALL_BW_OPS: List[Type[AttentionBwOpBase]] = [
    cutlass.BwOp if torch.version.cuda else ck.BwOp,
    cutlass_blackwell.BwOp,
    flash.BwOp,
    flash_mtia.BwOp,
    flash3.BwOp,
]
|
|
948
|
+
|
|
949
|
+
# Names exported by ``from <this module> import *``. Note that the
# underscore-prefixed FA3 toggles are exported explicitly as well.
__all__ = [
    "AttentionBias",
    "AttentionOp",
    "AttentionOpBase",
    "LowerTriangularMask",
    "MemoryEfficientAttentionCutlassFwdFlashBwOp",
    "MemoryEfficientAttentionCutlassOp",
    "MemoryEfficientAttentionFlashAttentionOp",
    "MemoryEfficientAttentionFlashMtiaAttentionOp",
    "memory_efficient_attention",
    "MemoryEfficientAttentionCkOp",
    "MemoryEfficientAttentionCkDecoderOp",
    "ALL_FW_OPS",
    "ALL_BW_OPS",
    "attn_bias",
    "_get_use_fa3",
    "_set_use_fa3",
    "BlockDiagonalMask",
]
|