mslk-cuda-nightly 2026.1.19 (cp310-cp310-manylinux_2_28_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
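
For orientation before the diffs below, here is a minimal usage sketch assembled from the module paths in this listing and from calls made inside the package's own sources (specifically mslk/attention/fmha/merge_training.py). It is an assumption-laden sketch, not documented API; real signatures and defaults should be checked against the package code itself.

```python
# Sketch only: module path from the file listing above; the two fmha helpers
# are the ones called by mslk/attention/fmha/merge_training.py further down.
# Assumes the wheel is installed and a CUDA-capable fmha backend is available.
import torch

from mslk.attention import fmha  # mslk/attention/fmha/__init__.py

# BMHK layout (batch, seqlen, heads, head_dim), as referenced by the fmha code.
q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)

# Partial attention returns (attn, lse); merging combines partial results.
attn1, lse1 = fmha.memory_efficient_attention_partial(q, k, v)
attn2, lse2 = fmha.memory_efficient_attention_partial(q, k, v)
out, lse = fmha.merge_attentions([attn1, attn2], [lse1, lse2])
```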
mslk/attention/fmha/flash_mtia.py
@@ -0,0 +1,245 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-unsafe
+
+
+ from typing import Any, Iterable, Mapping, Optional, Set, Tuple
+
+ import torch
+
+ from .attn_bias import (
+     BlockDiagonalCausalFromBottomRightMask,
+     BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+     BlockDiagonalCausalLocalAttentionMask,
+     BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+     BlockDiagonalCausalMask,
+     BlockDiagonalCausalWithOffsetGappyKeysMask,
+     BlockDiagonalCausalWithOffsetPaddedKeysMask,
+     BlockDiagonalGappyKeysMask,
+     BlockDiagonalLocalAttentionPaddedKeysMask,
+     BlockDiagonalMask,
+     BlockDiagonalPaddedKeysMask,
+     LocalAttentionFromBottomRightMask,
+     LowerTriangularFromBottomRightLocalAttentionMask,
+     LowerTriangularFromBottomRightMask,
+     LowerTriangularMask,
+ )
+ from .flash import BwOp as BwOpCUDA, FwOp as FwOpCUDA
+ from .utils.op_common import get_operator, register_operator
+
+ try:
+     import mtia.host_runtime.torch_mtia.dynamic_library  # noqa # type: ignore
+
+     @torch.library.custom_op(
+         "mslk_flash_mtia::flash_fwd",
+         mutates_args=(),
+         device_types=["mtia"],
+     )
+     def _flash_fwd(
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         cu_seqlens_q: Optional[torch.Tensor],
+         cu_seqlens_k: Optional[torch.Tensor],
+         seqused_k: Optional[torch.Tensor],
+         max_seqlen_q: int,
+         max_seqlen_k: int,
+         p: float,
+         softmax_scale: float,
+         is_causal: bool,
+         window_left: int,
+         window_right: int,
+         return_softmax: bool,
+         block_tables: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         ret = torch.ops.aten._flash_attention_forward(
+             query,
+             key,
+             value,
+             cu_seqlens_q,  # cum_seq_q
+             cu_seqlens_k,  # cum_seq_k
+             max_seqlen_q,  # max_q
+             max_seqlen_k,  # max_k
+             p,  # dropout_p
+             is_causal,
+             return_debug_mask=False,
+             scale=softmax_scale,
+             window_size_left=window_left,
+             window_size_right=window_right,
+             seqused_k=seqused_k,
+             alibi_slopes=None,  # alibi_slopes
+         )
+         attention, logsumexp, rng_state, _, _ = ret
+         return attention, logsumexp, rng_state
+
+     @torch.library.register_fake("mslk_flash_mtia::flash_fwd")
+     def _flash_fwd_abstract(
+         query,
+         key,
+         value,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         seqused_k,
+         max_seqlen_q,
+         max_seqlen_k,
+         p,
+         softmax_scale,
+         is_causal,
+         window_left,
+         window_right,
+         return_softmax,
+         block_tables,
+     ):
+         out = torch.empty_like(query)
+         if cu_seqlens_q is None:
+             B, M, H, K = query.shape
+             lse_shape = [B, H, M]  # XXXX ?
+         else:
+             M, H, K = query.shape
+             B = cu_seqlens_q.shape[0] - 1
+             lse_shape = [H, M]
+         softmax_lse = torch.empty(lse_shape, device=query.device, dtype=torch.float32)
+         rng_state = torch.empty([2], device=query.device, dtype=torch.int64)
+         return out, softmax_lse, rng_state
+
+     @torch.library.custom_op(
+         "mslk_flash_mtia::flash_bwd",
+         mutates_args=(),
+         device_types=["mtia"],
+     )
+     def _flash_bwd(
+         grads_share_storage: bool,
+         grad: torch.Tensor,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         out: torch.Tensor,
+         lse: torch.Tensor,
+         cu_seqlens_q: torch.Tensor,
+         cu_seqlens_k: torch.Tensor,
+         max_seqlen_q: int,
+         max_seqlen_k: int,
+         p: float,
+         softmax_scale: float,
+         is_causal: bool,
+         window_left: int,
+         window_right: int,
+         rng_state: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         rng_state0 = rng_state1 = rng_state
+         dq, dk, dv = torch.ops.aten._flash_attention_backward(
+             grad,
+             query,
+             key,
+             value,
+             out,
+             lse,
+             cu_seqlens_q,
+             cu_seqlens_k,
+             max_seqlen_q,
+             max_seqlen_k,
+             p,
+             is_causal,
+             rng_state0,
+             rng_state1,
+             scale=softmax_scale,
+             window_size_left=window_left,
+             window_size_right=window_right,
+         )
+         return dq, dk, dv
+
+     @torch.library.register_fake("mslk_flash_mtia::flash_bwd")
+     def _flash_bwd_abstract(
+         grads_share_storage,
+         grad,
+         query,
+         key,
+         value,
+         *args,
+         **kwargs,
+     ):
+         return _create_dq_dk_dv(grads_share_storage, query, key, value)
+
+ except (ImportError, OSError):
+     pass
+
+
+ def _create_dq_dk_dv(
+     grads_share_storage: bool, query, key, value
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     # Create dq, dk, dv.
+     # If Q/K/V come from a single QKV tensor, let's put the gradient in the
+     # right strides, so we can avoid a `cat`
+     if grads_share_storage:
+         chunk = torch.empty(
+             (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
+             dtype=query.dtype,
+             device=query.device,
+         )
+         return chunk.select(-3, 0), chunk.select(-3, 1), chunk.select(-3, 2)
+     return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+
+
+ FLASH_VERSION = torch.nn.attention._get_flash_version()  # noqa # type: ignore
+
+
+ @register_operator
+ class FwOp(FwOpCUDA):
+     """Operator that computes memory-efficient attention using MTIA devices."""
+
+     OPERATOR = get_operator("mslk_flash_mtia", "flash_fwd")
+     SUPPORTED_DEVICES: Set[str] = {"mtia"}
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         LowerTriangularMask,
+         LowerTriangularFromBottomRightMask,
+         LowerTriangularFromBottomRightLocalAttentionMask,
+         BlockDiagonalMask,
+         BlockDiagonalCausalMask,
+         BlockDiagonalCausalLocalAttentionMask,
+         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+         BlockDiagonalLocalAttentionPaddedKeysMask,
+         BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+         BlockDiagonalCausalFromBottomRightMask,
+         BlockDiagonalCausalWithOffsetGappyKeysMask,
+         BlockDiagonalCausalWithOffsetPaddedKeysMask,
+         BlockDiagonalGappyKeysMask,
+         BlockDiagonalPaddedKeysMask,
+         LocalAttentionFromBottomRightMask,
+     )
+     NAME = f"fa2F@{FLASH_VERSION}-mtia"
+
+     ERROR_ATOL: Mapping[torch.dtype, float] = {
+         torch.float: 3e-4,
+         torch.half: 7e-3,
+         torch.bfloat16: 2e-2,
+     }
+     ERROR_RTOL: Mapping[torch.dtype, float] = {
+         torch.float: 2e-5,
+         torch.half: 4e-4,
+         torch.bfloat16: 5e-3,
+     }
+
+
+ @register_operator
+ class BwOp(BwOpCUDA):
+     """Operator that computes memory-efficient attention using MTIA devices."""
+
+     OPERATOR = get_operator("mslk_flash_mtia", "flash_bwd")
+     SUPPORTED_DEVICES: Set[str] = {"mtia"}
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         LowerTriangularMask,
+         LowerTriangularFromBottomRightMask,
+         LowerTriangularFromBottomRightLocalAttentionMask,
+         BlockDiagonalMask,
+         BlockDiagonalCausalMask,
+         BlockDiagonalCausalLocalAttentionMask,
+         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+         BlockDiagonalCausalFromBottomRightMask,
+         LocalAttentionFromBottomRightMask,
+     )
+     NAME = f"fa2B@{FLASH_VERSION}-mtia"
mslk/attention/fmha/merge_training.py
@@ -0,0 +1,192 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-unsafe
+ from typing import Callable, Optional, Tuple, Type, Union
+
+ import torch
+
+ from .. import fmha
+
+
+ """
+ Friendly wrapper around merge_attentions which works with autograd.
+
+ Use as follows:
+
+ ```
+ partial1 = memory_efficient_attention_partial_autograd(q, k1, v1, ...)
+ partial2 = memory_efficient_attention_partial_autograd(q, k2, v2, ...)
+ attn_out = merge_attentions_autograd(partial1, partial2)
+ ```
+
+ merge_attentions_autograd() can take any number of inputs. Note that
+ partial1 and partial2 are not tensors, but rather objects of type
+ `Partial`.
+
+ If you have partial1 and you changed your mind and don't
+ want to merge it with anything, you can do
+ ```
+ attn_out = merge_attentions_autograd(partial1)
+ ```
+
+ """
+
+
+ class _PartialFunc(torch.autograd.Function):
+     @staticmethod
+     def forward(  # type: ignore[override]
+         ctx: torch.autograd.function.FunctionCtx,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         attn_bias: Optional[Union[torch.Tensor, fmha.AttentionBias]],
+         p: float = 0.0,
+         scale: Optional[float] = None,
+         op: Optional[Union[fmha.AttentionOp, Type[fmha.AttentionFwOpBase]]] = None,
+         output_dtype: Optional[torch.dtype] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         ctx.bias = attn_bias  # type: ignore
+         ctx.save_for_backward(query, key, value)
+         ctx.p = p  # type: ignore
+         ctx.scale = scale  # type: ignore
+         ctx.op = op[1] if isinstance(op, tuple) else None  # type: ignore
+         attn, lse = fmha.memory_efficient_attention_partial(
+             query, key, value, attn_bias, p, scale, op=op, output_dtype=output_dtype
+         )
+         placeholder = torch.empty_like(attn)
+         return attn, lse, placeholder
+
+     @staticmethod
+     def backward(  # type: ignore[override]
+         ctx: torch.autograd.function.FunctionCtx,
+         grad_attn: torch.Tensor,
+         lse: torch.Tensor,
+         out: torch.Tensor,
+     ) -> Tuple[Optional[torch.Tensor], ...]:
+         query, key, value = ctx.saved_tensors  # type: ignore
+         grad_q, grad_k, grad_v = fmha.memory_efficient_attention_backward(
+             grad_attn,
+             out,
+             lse.contiguous(),
+             query,
+             key,
+             value,
+             ctx.bias,  # type: ignore
+             ctx.p,  # type: ignore
+             ctx.scale,  # type: ignore
+             op=ctx.op,  # type: ignore
+         )
+         return grad_q, grad_k, grad_v, None, None, None, None, None
+
+
+ class _MergeFunc(torch.autograd.Function):
+     @staticmethod
+     def forward(  # type: ignore[override]
+         ctx: torch.autograd.function.FunctionCtx,
+         *inputs: torch.Tensor,
+     ) -> torch.Tensor:
+         ctx.length = len(inputs) // 3  # type: ignore
+         if ctx.length > 1:  # type: ignore
+             attns = inputs[::3]
+             lses = inputs[1::3]
+             out, lse = fmha.merge_attentions(attns, lses)
+             assert lse is not None
+         else:
+             out, lse = inputs[:2]
+         ctx.save_for_backward(out, lse)
+         return out
+
+     @staticmethod
+     def backward(  # type: ignore[override]
+         ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor
+     ) -> Tuple[Optional[torch.Tensor], ...]:
+         out, lse = ctx.saved_tensors  # type: ignore
+         return (grad_out, lse, out) * ctx.length  # type: ignore
+
+
+ class Partial:
+     """
+     This class is used to represent a partial attention output, which is
+     returned by `memory_efficient_attention_partial_autograd`.
+
+     Attributes: (Do not access them directly, use the methods instead.)
+
+     _attn: torch.Tensor
+     _lse: torch.Tensor. (Its grad is the full LSE to be used by
+         the individual backward passes.)
+     _placeholder: torch.Tensor, whose grad is used for passing the full attention
+         output to the individual backward passes.
+     """
+
+     def __init__(
+         self,
+         attn: torch.Tensor,
+         lse: torch.Tensor,
+         placeholder: torch.Tensor,
+     ) -> None:
+         """
+         Internal use only
+         """
+         self._attn = attn
+         self._lse = lse
+         self._placeholder = placeholder
+
+     def is_bmghk(self) -> bool:
+         return self._attn.ndim == 5
+
+     def apply(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> "Partial":
+         """
+         Applies fn to the attention output, as if we were a tensor.
+         fn must expect tensors of shape BMGHK or BMHK, but cannot actually
+         manipulate the K dimension (because the LSE doesn't have one).
+
+         For example, to slice on the sequence dimension, you might apply
+         `lambda x: x[:, start:end]`.
+         """
+         attn = fn(self._attn)
+         if self.is_bmghk():
+             rearranged = self._lse.permute(0, 3, 1, 2).unsqueeze(-1)
+             lse = fn(rearranged).squeeze(-1).permute(0, 2, 3, 1)
+         else:
+             rearranged = self._lse.permute(0, 2, 1).unsqueeze(-1)
+             lse = fn(rearranged).squeeze(-1).permute(0, 2, 1)
+         placeholder = fn(self._placeholder)
+         return self.__class__(attn, lse, placeholder)
+
+     def _tuple(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         return self._attn, self._lse, self._placeholder
+
+
+ def memory_efficient_attention_partial_autograd(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_bias: Optional[Union[torch.Tensor, fmha.AttentionBias]] = None,
+     p: float = 0.0,
+     scale: Optional[float] = None,
+     *,
+     op: Optional[Union[fmha.AttentionOp, Type[fmha.AttentionFwOpBase]]] = None,
+     output_dtype: Optional[torch.dtype] = None,
+ ) -> Partial:
+     """
+     Wrapper around `memory_efficient_attention_partial` which works with autograd.
+     Arguments are the same as for `memory_efficient_attention_partial`.
+     """
+     return Partial(
+         *_PartialFunc.apply(query, key, value, attn_bias, p, scale, op, output_dtype)
+     )
+
+
+ def merge_attentions_autograd(
+     *partials: Partial,
+ ) -> torch.Tensor:
+     """
+     Wrapper around merge_attentions which works with autograd.
+     """
+     args = [i for part in partials for i in part._tuple()]
+     if len(args) == 0:
+         raise ValueError("No partials to merge")
+     return _MergeFunc.apply(*args)
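
To make the intent of merge_training.py concrete, here is a minimal training-style sketch. It assumes the wheel is installed and that an fmha backend accepting these BMHK bf16 tensors is available on the CUDA device; it uses only the two public helpers defined in the hunk above.

```python
# Sketch only: merges two partial attentions over a KV set split in two,
# then backpropagates through the merged result.
import torch

from mslk.attention.fmha.merge_training import (
    memory_efficient_attention_partial_autograd,
    merge_attentions_autograd,
)

B, M, H, K = 2, 512, 8, 128
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k1 = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v1 = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k2 = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v2 = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Each call returns a Partial (attn, lse, placeholder), not a plain tensor.
partial1 = memory_efficient_attention_partial_autograd(q, k1, v1)
partial2 = memory_efficient_attention_partial_autograd(q, k2, v2)

attn_out = merge_attentions_autograd(partial1, partial2)
attn_out.sum().backward()  # gradients flow to q, k1, v1, k2, v2
print(q.grad.shape)
```

Merging via the log-sum-exp values is what lets the two partials behave as if attention had been computed over the combined keys and values, while _MergeFunc hands the merged output and LSE back to each partial's backward pass through the lse and placeholder gradients.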