mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/flash_mtia.py
@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+
+
+from typing import Any, Iterable, Mapping, Optional, Set, Tuple
+
+import torch
+
+from .attn_bias import (
+    BlockDiagonalCausalFromBottomRightMask,
+    BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+    BlockDiagonalCausalLocalAttentionMask,
+    BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+    BlockDiagonalCausalMask,
+    BlockDiagonalCausalWithOffsetGappyKeysMask,
+    BlockDiagonalCausalWithOffsetPaddedKeysMask,
+    BlockDiagonalGappyKeysMask,
+    BlockDiagonalLocalAttentionPaddedKeysMask,
+    BlockDiagonalMask,
+    BlockDiagonalPaddedKeysMask,
+    LocalAttentionFromBottomRightMask,
+    LowerTriangularFromBottomRightLocalAttentionMask,
+    LowerTriangularFromBottomRightMask,
+    LowerTriangularMask,
+)
+from .flash import BwOp as BwOpCUDA, FwOp as FwOpCUDA
+from .utils.op_common import get_operator, register_operator
+
+try:
+    import mtia.host_runtime.torch_mtia.dynamic_library  # noqa  # type: ignore
+
+    @torch.library.custom_op(
+        "mslk_flash_mtia::flash_fwd",
+        mutates_args=(),
+        device_types=["mtia"],
+    )
+    def _flash_fwd(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor],
+        cu_seqlens_k: Optional[torch.Tensor],
+        seqused_k: Optional[torch.Tensor],
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        p: float,
+        softmax_scale: float,
+        is_causal: bool,
+        window_left: int,
+        window_right: int,
+        return_softmax: bool,
+        block_tables: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        ret = torch.ops.aten._flash_attention_forward(
+            query,
+            key,
+            value,
+            cu_seqlens_q,  # cum_seq_q
+            cu_seqlens_k,  # cum_seq_k
+            max_seqlen_q,  # max_q
+            max_seqlen_k,  # max_k
+            p,  # dropout_p
+            is_causal,
+            return_debug_mask=False,
+            scale=softmax_scale,
+            window_size_left=window_left,
+            window_size_right=window_right,
+            seqused_k=seqused_k,
+            alibi_slopes=None,  # alibi_slopes
+        )
+        attention, logsumexp, rng_state, _, _ = ret
+        return attention, logsumexp, rng_state
+
+    @torch.library.register_fake("mslk_flash_mtia::flash_fwd")
+    def _flash_fwd_abstract(
+        query,
+        key,
+        value,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        seqused_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        p,
+        softmax_scale,
+        is_causal,
+        window_left,
+        window_right,
+        return_softmax,
+        block_tables,
+    ):
+        out = torch.empty_like(query)
+        if cu_seqlens_q is None:
+            B, M, H, K = query.shape
+            lse_shape = [B, H, M]  # XXXX ?
+        else:
+            M, H, K = query.shape
+            B = cu_seqlens_q.shape[0] - 1
+            lse_shape = [H, M]
+        softmax_lse = torch.empty(lse_shape, device=query.device, dtype=torch.float32)
+        rng_state = torch.empty([2], device=query.device, dtype=torch.int64)
+        return out, softmax_lse, rng_state
+
+    @torch.library.custom_op(
+        "mslk_flash_mtia::flash_bwd",
+        mutates_args=(),
+        device_types=["mtia"],
+    )
+    def _flash_bwd(
+        grads_share_storage: bool,
+        grad: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+        lse: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        p: float,
+        softmax_scale: float,
+        is_causal: bool,
+        window_left: int,
+        window_right: int,
+        rng_state: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        rng_state0 = rng_state1 = rng_state
+        dq, dk, dv = torch.ops.aten._flash_attention_backward(
+            grad,
+            query,
+            key,
+            value,
+            out,
+            lse,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            p,
+            is_causal,
+            rng_state0,
+            rng_state1,
+            scale=softmax_scale,
+            window_size_left=window_left,
+            window_size_right=window_right,
+        )
+        return dq, dk, dv
+
+    @torch.library.register_fake("mslk_flash_mtia::flash_bwd")
+    def _flash_bwd_abstract(
+        grads_share_storage,
+        grad,
+        query,
+        key,
+        value,
+        *args,
+        **kwargs,
+    ):
+        return _create_dq_dk_dv(grads_share_storage, query, key, value)
+
+except (ImportError, OSError):
+    pass
+
+
+def _create_dq_dk_dv(
+    grads_share_storage: bool, query, key, value
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    # Create dq,dk,dv
+    # If Q/K/V come from a single QKV tensor, let's put the gradient in the
+    # right strides, so we can avoid a `cat`
+    if grads_share_storage:
+        chunk = torch.empty(
+            (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
+            dtype=query.dtype,
+            device=query.device,
+        )
+        return chunk.select(-3, 0), chunk.select(-3, 1), chunk.select(-3, 2)
+    return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+
+
+FLASH_VERSION = torch.nn.attention._get_flash_version()  # noqa  # type: ignore
+
+
+@register_operator
+class FwOp(FwOpCUDA):
+    """Operator that computes memory-efficient attention using MTIA devices"""
+
+    OPERATOR = get_operator("mslk_flash_mtia", "flash_fwd")
+    SUPPORTED_DEVICES: Set[str] = {"mtia"}
+    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+        type(None),
+        LowerTriangularMask,
+        LowerTriangularFromBottomRightMask,
+        LowerTriangularFromBottomRightLocalAttentionMask,
+        BlockDiagonalMask,
+        BlockDiagonalCausalMask,
+        BlockDiagonalCausalLocalAttentionMask,
+        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+        BlockDiagonalLocalAttentionPaddedKeysMask,
+        BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+        BlockDiagonalCausalFromBottomRightMask,
+        BlockDiagonalCausalWithOffsetGappyKeysMask,
+        BlockDiagonalCausalWithOffsetPaddedKeysMask,
+        BlockDiagonalGappyKeysMask,
+        BlockDiagonalPaddedKeysMask,
+        LocalAttentionFromBottomRightMask,
+    )
+    NAME = f"fa2F@{FLASH_VERSION}-mtia"
+
+    ERROR_ATOL: Mapping[torch.dtype, float] = {
+        torch.float: 3e-4,
+        torch.half: 7e-3,
+        torch.bfloat16: 2e-2,
+    }
+    ERROR_RTOL: Mapping[torch.dtype, float] = {
+        torch.float: 2e-5,
+        torch.half: 4e-4,
+        torch.bfloat16: 5e-3,
+    }
+
+
+@register_operator
+class BwOp(BwOpCUDA):
+    """Operator that computes memory-efficient attention using MTIA devices"""
+
+    OPERATOR = get_operator("mslk_flash_mtia", "flash_bwd")
+    SUPPORTED_DEVICES: Set[str] = {"mtia"}
+    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+        type(None),
+        LowerTriangularMask,
+        LowerTriangularFromBottomRightMask,
+        LowerTriangularFromBottomRightLocalAttentionMask,
+        BlockDiagonalMask,
+        BlockDiagonalCausalMask,
+        BlockDiagonalCausalLocalAttentionMask,
+        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+        BlockDiagonalCausalFromBottomRightMask,
+        LocalAttentionFromBottomRightMask,
+    )
+    NAME = f"fa2B@{FLASH_VERSION}-mtia"
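For context, a minimal usage sketch (not part of the package diff): it assumes that `mslk.attention.fmha` exposes an xformers-style `memory_efficient_attention` entry point and that an MTIA device is available; the op pair and the bias class are the ones defined in flash_mtia.py above.

```python
# Hypothetical sketch only -- assumes an xformers-style
# fmha.memory_efficient_attention entry point and an available MTIA device.
import torch

from mslk.attention import fmha
from mslk.attention.fmha.attn_bias import LowerTriangularMask
from mslk.attention.fmha.flash_mtia import BwOp, FwOp

# BMHK layout: batch=2, seqlen=1024, heads=8, head_dim=64.
q, k, v = (
    torch.randn(2, 1024, 8, 64, device="mtia", dtype=torch.bfloat16)
    for _ in range(3)
)

# Dispatch explicitly to the MTIA flash ops registered above.
out = fmha.memory_efficient_attention(
    q, k, v, attn_bias=LowerTriangularMask(), op=(FwOp, BwOp)
)
```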
mslk/attention/fmha/merge_training.py
@@ -0,0 +1,192 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+from typing import Callable, Optional, Tuple, Type, Union
+
+import torch
+
+from .. import fmha
+
+
+"""
+Friendly wrapper around merge_attentions which works with autograd.
+
+Use as follows
+
+```
+partial1 = memory_efficient_attention_partial_autograd(q, k1, v1, ...)
+partial2 = memory_efficient_attention_partial_autograd(q, k2, v2, ...)
+attn_out = merge_attentions_autograd(partial1, partial2)
+```
+
+merge_attentions_autograd() can take any number of inputs. Note that
+partial1 and partial2 are not tensors, but rather objects of type
+`Partial`.
+
+If you have partial1 and you changed your mind and don't
+want to merge it with anything, you can do
+```
+attn_out = merge_attentions_autograd(partial1)
+```
+
+"""
+
+
+class _PartialFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore[override]
+        ctx: torch.autograd.function.FunctionCtx,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_bias: Optional[Union[torch.Tensor, fmha.AttentionBias]],
+        p: float = 0.0,
+        scale: Optional[float] = None,
+        op: Optional[Union[fmha.AttentionOp, Type[fmha.AttentionFwOpBase]]] = None,
+        output_dtype: Optional[torch.dtype] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        ctx.bias = attn_bias  # type: ignore
+        ctx.save_for_backward(query, key, value)
+        ctx.p = p  # type: ignore
+        ctx.scale = scale  # type: ignore
+        ctx.op = op[1] if isinstance(op, tuple) else None  # type: ignore
+        attn, lse = fmha.memory_efficient_attention_partial(
+            query, key, value, attn_bias, p, scale, op=op, output_dtype=output_dtype
+        )
+        placeholder = torch.empty_like(attn)
+        return attn, lse, placeholder
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: torch.autograd.function.FunctionCtx,
+        grad_attn: torch.Tensor,
+        lse: torch.Tensor,
+        out: torch.Tensor,
+    ) -> Tuple[Optional[torch.Tensor], ...]:
+        query, key, value = ctx.saved_tensors  # type: ignore
+        grad_q, grad_k, grad_v = fmha.memory_efficient_attention_backward(
+            grad_attn,
+            out,
+            lse.contiguous(),
+            query,
+            key,
+            value,
+            ctx.bias,  # type: ignore
+            ctx.p,  # type: ignore
+            ctx.scale,  # type: ignore
+            op=ctx.op,  # type: ignore
+        )
+        return grad_q, grad_k, grad_v, None, None, None, None, None
+
+
+class _MergeFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore[override]
+        ctx: torch.autograd.function.FunctionCtx,
+        *inputs: torch.Tensor,
+    ) -> torch.Tensor:
+        ctx.length = len(inputs) // 3  # type: ignore
+        if ctx.length > 1:  # type: ignore
+            attns = inputs[::3]
+            lses = inputs[1::3]
+            out, lse = fmha.merge_attentions(attns, lses)
+            assert lse is not None
+        else:
+            out, lse = inputs[:2]
+        ctx.save_for_backward(out, lse)
+        return out
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], ...]:
+        out, lse = ctx.saved_tensors  # type: ignore
+        return (grad_out, lse, out) * ctx.length  # type: ignore
+
+
+class Partial:
+    """
+    This class is used to represent a partial attention output, which is
+    returned by `memory_efficient_attention_partial_autograd`.
+
+    Attributes: (Do not access them directly, use the methods instead.)
+
+    _attn: torch.Tensor
+    _lse: torch.Tensor. (Its grad is the full LSE to be used by
+        the individual backward passes.)
+    _placeholder: torch.Tensor, whose grad is used for passing the full attention
+        output to the individual backward passes.
+    """
+
+    def __init__(
+        self,
+        attn: torch.Tensor,
+        lse: torch.Tensor,
+        placeholder: torch.Tensor,
+    ) -> None:
+        """
+        Internal use only
+        """
+        self._attn = attn
+        self._lse = lse
+        self._placeholder = placeholder
+
+    def is_bmghk(self) -> bool:
+        return self._attn.ndim == 5
+
+    def apply(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> "Partial":
+        """
+        Applies fn to the attention output, as if we were a tensor.
+        fn must expect tensors of shape BMGHK or BMHK, but cannot actually
+        manipulate the K dimension (because the LSE doesn't have one).
+
+        For example, to slice on the sequence dimension, you might apply
+        `lambda x: x[:, start:end]`.
+        """
+        attn = fn(self._attn)
+        if self.is_bmghk():
+            rearranged = self._lse.permute(0, 3, 1, 2).unsqueeze(-1)
+            lse = fn(rearranged).squeeze(-1).permute(0, 2, 3, 1)
+        else:
+            rearranged = self._lse.permute(0, 2, 1).unsqueeze(-1)
+            lse = fn(rearranged).squeeze(-1).permute(0, 2, 1)
+        placeholder = fn(self._placeholder)
+        return self.__class__(attn, lse, placeholder)
+
+    def _tuple(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return self._attn, self._lse, self._placeholder
+
+
+def memory_efficient_attention_partial_autograd(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[Union[torch.Tensor, fmha.AttentionBias]] = None,
+    p: float = 0.0,
+    scale: Optional[float] = None,
+    *,
+    op: Optional[Union[fmha.AttentionOp, Type[fmha.AttentionFwOpBase]]] = None,
+    output_dtype: Optional[torch.dtype] = None,
+) -> Partial:
+    """
+    Wrapper around `memory_efficient_attention_partial` which works with autograd.
+    Arguments are the same as for `memory_efficient_attention_partial`.
+    """
+    return Partial(
+        *_PartialFunc.apply(query, key, value, attn_bias, p, scale, op, output_dtype)
+    )
+
+
+def merge_attentions_autograd(
+    *partials: Partial,
+) -> torch.Tensor:
+    """
+    Wrapper around merge_attentions which works with autograd.
+    """
+    args = [i for part in partials for i in part._tuple()]
+    if len(args) == 0:
+        raise ValueError("No partials to merge")
+    return _MergeFunc.apply(*args)
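To round out the module docstring example above, a hedged sketch of how the two wrappers compose under autograd (illustrative only; the tensor shapes, the CUDA device, and the absence of an attention bias are assumptions, not taken from the package):

```python
# Hypothetical sketch only -- attend to two key/value shards separately,
# merge the partial results, and backpropagate through the merge.
import torch

from mslk.attention.fmha.merge_training import (
    memory_efficient_attention_partial_autograd,
    merge_attentions_autograd,
)

B, M, Mk, H, K = 2, 128, 256, 8, 64  # BMHK layout
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k1, v1, k2, v2 = (
    torch.randn(B, Mk, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(4)
)

# Each call returns a Partial (attn, lse, placeholder), not a plain tensor.
partial1 = memory_efficient_attention_partial_autograd(q, k1, v1)
partial2 = memory_efficient_attention_partial_autograd(q, k2, v2)

# Merging rescales the partial outputs by their log-sum-exp values; gradients
# flow back into q, k1, v1, k2, v2 through _MergeFunc and _PartialFunc.
attn_out = merge_attentions_autograd(partial1, partial2)
attn_out.sum().backward()
```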