mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/__init__.py
@@ -0,0 +1,967 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-unsafe
7
+
8
+ from typing import Any, cast, List, Optional, Sequence, Tuple, Type, Union
9
+
10
+ import torch
11
+
12
+ from . import (
13
+ attn_bias,
14
+ ck,
15
+ ck_decoder,
16
+ ck_splitk,
17
+ cutlass,
18
+ cutlass_blackwell,
19
+ flash,
20
+ flash3,
21
+ flash_mtia,
22
+ triton_splitk,
23
+ )
24
+ from .attn_bias import (
25
+ AttentionBias,
26
+ BlockDiagonalMask,
27
+ LowerTriangularMask,
28
+ VARLEN_BIASES,
29
+ )
30
+ from .common import (
31
+ AttentionBwOpBase,
32
+ AttentionFwOpBase,
33
+ AttentionOp,
34
+ AttentionOpBase,
35
+ bmk2bmhk,
36
+ Context,
37
+ Gradients,
38
+ Inputs,
39
+ )
40
+ from .dispatch import (
41
+ _dispatch_bw,
42
+ _dispatch_fw,
43
+ _ensure_op_supports_or_raise,
44
+ _get_use_fa3,
45
+ _set_use_fa3,
46
+ )
47
+
48
+ MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
49
+ MemoryEfficientAttentionCutlassBlackwellOp = (
50
+ cutlass_blackwell.FwOp,
51
+ cutlass_blackwell.BwOp,
52
+ )
53
+ MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
54
+ MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
55
+ MemoryEfficientAttentionFlashMtiaAttentionOp = (flash_mtia.FwOp, flash_mtia.BwOp)
56
+ MemoryEfficientAttentionCkOp = (ck.FwOp, ck.BwOp)
57
+ MemoryEfficientAttentionCkDecoderOp = (ck_decoder.FwOp, ck.BwOp)
58
+ MemoryEfficientAttentionSplitKCkOp = (ck_splitk.FwOp, ck.BwOp)
59
+
60
+
61
+ def _deserialize_bias(attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]) -> Any:
62
+ if attn_bias_tensor is None:
63
+ return attn_bias_ctx
64
+ return attn_bias_tensor
65
+
66
+
67
+ # Note: `torch.compile` only allows custom autograd functions
68
+ # to accept a subset of types. Therefore we serialize `op` objects
69
+ # to `str` before entering the function, and unserialize them inside.
70
+ # See also: https://github.com/pytorch/pytorch/issues/118395
71
+ _OPS_LOOKUP = {
72
+ flash.FwOp.NAME: flash.FwOp,
73
+ flash.BwOp.NAME: flash.BwOp,
74
+ flash_mtia.FwOp.NAME: flash_mtia.FwOp,
75
+ flash_mtia.BwOp.NAME: flash_mtia.BwOp,
76
+ }
77
+
78
+
79
+ def _serialize_op(op):
80
+ if op is not None and op.NAME in _OPS_LOOKUP:
81
+ return op.NAME
82
+ return op
83
+
84
+
85
+ def _unserialize_op(op):
86
+ if isinstance(op, str):
87
+ return _OPS_LOOKUP[op]
88
+ return op
89
+
90
+
91
+ class _fMHA(torch.autograd.Function):
92
+ @staticmethod
93
+ # type: ignore
94
+ def forward(ctx, op_fw, op_bw, *args: Any) -> Any:
95
+ inp = Inputs(*args)
96
+
97
+ op_fw = _unserialize_op(op_fw)
98
+ op_bw = _unserialize_op(op_bw)
99
+
100
+ out, op_ctx = _memory_efficient_attention_forward_requires_grad(
101
+ inp=inp, op=op_fw
102
+ )
103
+
104
+ # Saving attn_bias is a bit complicated, as the
105
+ # torch part should go in `save_for_backward`
106
+ if isinstance(inp.attn_bias, torch.Tensor):
107
+ attn_bias_tensor = inp.attn_bias
108
+ attn_bias_ctx = None
109
+ else:
110
+ attn_bias_tensor = None
111
+ attn_bias_ctx = inp.attn_bias
112
+
113
+ ctx.save_for_backward(
114
+ inp.query,
115
+ inp.key,
116
+ inp.value,
117
+ op_ctx.out,
118
+ op_ctx.lse,
119
+ )
120
+ ctx.rng_state = op_ctx.rng_state
121
+ ctx.attn_bias_tensor = attn_bias_tensor
122
+ if op_ctx.op_bw is not None:
123
+ if op_bw is not None and op_bw is not op_ctx.op_bw:
124
+ raise ValueError(
125
+ f"Specified op_bw={op_bw.NAME}, but forward op "
126
+ f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
127
+ )
128
+ op_bw = op_ctx.op_bw
129
+ if (
130
+ op_fw is not None
131
+ and op_bw is not None
132
+ and isinstance(inp.attn_bias, VARLEN_BIASES)
133
+ and inp.attn_bias.q_seqinfo.seqstart.shape[0] > 2
134
+ and op_bw.VARLEN_LSE_PACKED != op_fw.VARLEN_LSE_PACKED
135
+ ):
136
+ raise ValueError(
137
+ f"Specified op_bw={op_bw.NAME} is not compatible with the "
138
+ f"op_fw={op_fw.NAME}, because they use different format of logsumexp. "
139
+ f"NOTE: This is new with xFormers 0.0.28"
140
+ )
141
+ if op_bw is None and (
142
+ inp.query.requires_grad or inp.key.requires_grad or inp.value.requires_grad
143
+ ):
144
+ varlen_lse_packed = _detect_lse_packed_or_raise(op_ctx.lse, inp)
145
+ if varlen_lse_packed is not None and op_fw is not None:
146
+ assert op_fw.VARLEN_LSE_PACKED == varlen_lse_packed, (
147
+ f"{op_fw.NAME}: wrong value for `VARLEN_LSE_PACKED` ?"
148
+ )
149
+ # NOTE: We need to check tensor strides to decide which operator we run in the BW pass.
150
+ # Unfortunately, PyTorch only allows us to call this function during the FW pass, so
151
+ # we decide the operator to use now.
152
+ op_bw = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
153
+ ctx.op_fw = op_fw
154
+ ctx.op_bw = op_bw
155
+ ctx.p = inp.p
156
+ # This allows gradients to be created from a single storage,
157
+ # to avoid a "cat" in the BW pass.
158
+ # The heuristic is approximate, but:
159
+ # (1) It's not a big issue to create a shared storage
160
+ # (2) The heuristic needs to pass `torch.compile`
161
+ # (this is also why we run it in the FW pass, the BW pass is stricter)
162
+ ctx.qkv_share_storage = (
163
+ inp.query.shape[0] == inp.key.shape[0]
164
+ and inp.query.shape[-1] == inp.value.shape[-1]
165
+ and inp.query.stride(-2)
166
+ == (inp.key.shape[-1] + inp.query.shape[-1] + inp.value.shape[-1])
167
+ )
168
+
169
+ ctx.scale = inp.scale
170
+ ctx.attn_bias_ctx = attn_bias_ctx
171
+ ctx.n_args = len(args)
172
+ return out, op_ctx.lse
173
+
174
+ @staticmethod
175
+ @torch.autograd.function.once_differentiable
176
+ def backward(ctx, grad, grad_lse):
177
+ # Re-create context
178
+ query, key, value, out, lse = ctx.saved_tensors
179
+ attn_bias_tensor = ctx.attn_bias_tensor
180
+ rng_state = ctx.rng_state
181
+ inp = Inputs(
182
+ query=query,
183
+ key=key,
184
+ value=value,
185
+ attn_bias=_deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
186
+ p=ctx.p,
187
+ scale=ctx.scale,
188
+ )
189
+ op_ctx = Context(
190
+ lse=lse,
191
+ out=out,
192
+ rng_state=rng_state,
193
+ qkv_share_storage=ctx.qkv_share_storage,
194
+ )
195
+ grads = _memory_efficient_attention_backward(
196
+ ctx=op_ctx,
197
+ inp=inp,
198
+ grad=grad,
199
+ op=ctx.op_bw,
200
+ _skip_op_checks=True,
201
+ )
202
+ return (None, None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
203
+ ctx.n_args - 2
204
+ )
205
+
206
+
207
+ def memory_efficient_attention(
208
+ query: torch.Tensor,
209
+ key: torch.Tensor,
210
+ value: torch.Tensor,
211
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
212
+ p: float = 0.0,
213
+ scale: Optional[float] = None,
214
+ *,
215
+ op: Optional[AttentionOp] = None,
216
+ output_dtype: Optional[torch.dtype] = None,
217
+ ) -> torch.Tensor:
218
+ """Implements the memory-efficient attention mechanism following
219
+ `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
220
+
221
+ :Inputs shape:
222
+
223
+ - Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
224
+ the sequence length, H the number of heads, and K the embedding size per head
225
+
226
+ - If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``
227
+
228
+ - Inputs can also be of dimension 5 with GQA - see note below
229
+
230
+ - Inputs can be non-contiguous - we only require the last dimension's stride to be 1
231
+
232
+
233
+ :Equivalent pytorch code:
234
+
235
+ .. code-block:: python
236
+
237
+ scale = 1.0 / query.shape[-1] ** 0.5
238
+ query = query * scale
239
+ query = query.transpose(1, 2)
240
+ key = key.transpose(1, 2)
241
+ value = value.transpose(1, 2)
242
+ attn = query @ key.transpose(-2, -1)
243
+ if attn_bias is not None:
244
+ attn = attn + attn_bias
245
+ attn = attn.softmax(-1)
246
+ attn = F.dropout(attn, p)
247
+ attn = attn @ value
248
+ return attn.transpose(1, 2).contiguous()
249
+
250
+ :Examples:
251
+
252
+ .. code-block:: python
253
+
254
+ import xformers.ops as xops
255
+
256
+ # Compute regular attention
257
+ y = xops.memory_efficient_attention(q, k, v)
258
+
259
+ # With a dropout of 0.2
260
+ y = xops.memory_efficient_attention(q, k, v, p=0.2)
261
+
262
+ # Causal attention
263
+ y = xops.memory_efficient_attention(
264
+ q, k, v,
265
+ attn_bias=xops.LowerTriangularMask()
266
+ )
267
+
268
+ :Supported hardware:
269
+
270
+ NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.
271
+
272
+ :EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):
273
+
274
+ MQA/GQA is an experimental feature supported only for the forward pass.
275
+ If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
276
+ in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
277
+ ``H`` is the number of heads per group (8 in the example).
278
+
279
+ Please note that xFormers will not automatically broadcast the inputs, so you will need
280
+ to broadcast them manually before calling `memory_efficient_attention`.
281
+
282
+ :GQA/MQA example:
283
+
284
+ .. code-block:: python
285
+
286
+ import torch
287
+ import xformers.ops as xops
288
+
289
+ B, M, K = 3, 32, 128
290
+ kwargs = dict(device="cuda", dtype=torch.float16)
291
+ q = torch.randn([B, M, 8, K], **kwargs)
292
+ k = torch.randn([B, M, 2, K], **kwargs)
293
+ v = torch.randn([B, M, 2, K], **kwargs)
294
+ out_gqa = xops.memory_efficient_attention(
295
+ q.reshape([B, M, 2, 4, K]),
296
+ k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
297
+ v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
298
+ )
299
+
300
+ Raises:
301
+ NotImplementedError: if there is no operator available to compute the MHA
302
+ ValueError: if inputs are invalid
303
+
304
+ :parameter query: Tensor of shape ``[B, Mq, H, K]``
305
+ :parameter key: Tensor of shape ``[B, Mkv, H, K]``
306
+ :parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
307
+ :parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
308
+ For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
309
+ This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
310
+ :parameter p: Dropout probability. Disabled if set to ``0.0``
311
+ :parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
312
+ scale (q.shape[-1]**-0.5) will be used.
313
+ :parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
314
+ If set to ``None`` (recommended), xFormers \
315
+ will dispatch to the best available operator, depending on the inputs \
316
+ and options.
317
+ :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
318
+ """
319
+ return _memory_efficient_attention(
320
+ Inputs(
321
+ query=query,
322
+ key=key,
323
+ value=value,
324
+ p=p,
325
+ attn_bias=attn_bias,
326
+ scale=scale,
327
+ output_dtype=output_dtype,
328
+ ),
329
+ op=op,
330
+ )
331
+
332
+
333
+ torch.library.define(
334
+ "mslk::memory_efficient_attention_forward",
335
+ "(Tensor q, Tensor k, Tensor v, Tensor? b = None, float? p = 0.0, float? scale = None) -> Tensor",
336
+ )
337
+
338
+
339
+ def _memory_efficient_attention_forward_torch_wrapper_meta(
340
+ query: torch.Tensor,
341
+ key: torch.Tensor,
342
+ value: torch.Tensor,
343
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
344
+ p: float = 0.0,
345
+ scale: Optional[float] = None,
346
+ ):
347
+ return torch.empty_like(query)
348
+
349
+
350
+ torch.library.impl(
351
+ "mslk::memory_efficient_attention_forward",
352
+ "Meta",
353
+ _memory_efficient_attention_forward_torch_wrapper_meta,
354
+ )
355
+
356
+
357
+ # torch.compile has issues when tracing through op dispatch and ensure_op_support,
358
+ # so we provide a wrapper to register it as a custom torch library op.
359
+ def _memory_efficient_attention_forward_torch_wrapper(
360
+ query: torch.Tensor,
361
+ key: torch.Tensor,
362
+ value: torch.Tensor,
363
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
364
+ p: float = 0.0,
365
+ scale: Optional[float] = None,
366
+ ) -> torch.Tensor:
367
+ """
368
+ This provides a torch-compilable wrapper op to
369
+ memory_efficient_attention_forward in certain special cases.
370
+
371
+ Note that the following are not supported
372
+ - `op` input (?)
373
+ - certain attn_bias types (?)
374
+ - output_dtype
375
+ - K != Kv
376
+ """
377
+ return memory_efficient_attention_forward(
378
+ query,
379
+ key,
380
+ value,
381
+ attn_bias,
382
+ p,
383
+ scale,
384
+ )
385
+
386
+
387
+ torch.library.impl(
388
+ "mslk::memory_efficient_attention_forward",
389
+ "CUDA",
390
+ _memory_efficient_attention_forward_torch_wrapper,
391
+ )
392
+
393
+
394
+ torch.library.define(
395
+ "mslk::memory_efficient_attention_forward_with_bias",
396
+ "(Tensor q, Tensor k, Tensor v, Tensor b, float? p = 0.0, float? scale = None) -> Tensor",
397
+ )
398
+
399
+
400
+ def _memory_efficient_attention_forward_torch_wrapper_with_bias_meta(
401
+ query: torch.Tensor,
402
+ key: torch.Tensor,
403
+ value: torch.Tensor,
404
+ attn_bias: Union[torch.Tensor, AttentionBias],
405
+ p: float = 0.0,
406
+ scale: Optional[float] = None,
407
+ ):
408
+ return torch.empty_like(query)
409
+
410
+
411
+ torch.library.impl(
412
+ "mslk::memory_efficient_attention_forward_with_bias",
413
+ "Meta",
414
+ _memory_efficient_attention_forward_torch_wrapper_with_bias_meta,
415
+ )
416
+
417
+
418
+ # torch.compile has issues when tracing through op dispatch and ensure_op_support,
419
+ # so we provide a wrapper to register it as a custom torch library op.
420
+
421
+
422
+ def _memory_efficient_attention_forward_torch_wrapper_with_bias(
423
+ query: torch.Tensor,
424
+ key: torch.Tensor,
425
+ value: torch.Tensor,
426
+ attn_bias: Union[torch.Tensor, AttentionBias],
427
+ p: float = 0.0,
428
+ scale: Optional[float] = None,
429
+ ) -> torch.Tensor:
430
+ """
431
+ This provides a torch-compilable wrapper op to
432
+ memory_efficient_attention_forward in certain special cases.
433
+
434
+ Note that the following are not supported
435
+ - `op` input (?)
436
+ - certain attn_bias types (?)
437
+ - output_dtype
438
+ - K != Kv
439
+ """
440
+ return memory_efficient_attention_forward(
441
+ query,
442
+ key,
443
+ value,
444
+ attn_bias,
445
+ p,
446
+ scale,
447
+ )
448
+
449
+
450
+ torch.library.impl(
451
+ "mslk::memory_efficient_attention_forward_with_bias",
452
+ "CUDA",
453
+ _memory_efficient_attention_forward_torch_wrapper_with_bias,
454
+ )
455
+
456
+
457
+ def memory_efficient_attention_forward(
458
+ query: torch.Tensor,
459
+ key: torch.Tensor,
460
+ value: torch.Tensor,
461
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
462
+ p: float = 0.0,
463
+ scale: Optional[float] = None,
464
+ *,
465
+ op: Optional[Type[AttentionFwOpBase]] = None,
466
+ output_dtype: Optional[torch.dtype] = None,
467
+ ) -> torch.Tensor:
468
+ """
469
+ Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
470
+ """
471
+ return _memory_efficient_attention_forward(
472
+ Inputs(
473
+ query=query,
474
+ key=key,
475
+ value=value,
476
+ p=p,
477
+ attn_bias=attn_bias,
478
+ scale=scale,
479
+ output_dtype=output_dtype,
480
+ ),
481
+ op=op,
482
+ )
483
+
484
+
485
+ def memory_efficient_attention_forward_requires_grad(
486
+ query: torch.Tensor,
487
+ key: torch.Tensor,
488
+ value: torch.Tensor,
489
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
490
+ p: float = 0.0,
491
+ scale: Optional[float] = None,
492
+ *,
493
+ op: Optional[Type[AttentionFwOpBase]] = None,
494
+ output_dtype: Optional[torch.dtype] = None,
495
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
496
+ """
497
+ Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
498
+ See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments
499
+ See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass
500
+ """
501
+ if p != 0.0:
502
+ raise NotImplementedError(
503
+ "dropout is not supported on the non-autograd API."
504
+ " If you want to use dropout, please call `memory_efficient_attention` directly"
505
+ )
506
+ out, ctx = _memory_efficient_attention_forward_requires_grad(
507
+ Inputs(
508
+ query=query,
509
+ key=key,
510
+ value=value,
511
+ p=p,
512
+ attn_bias=attn_bias,
513
+ scale=scale,
514
+ output_dtype=output_dtype,
515
+ ),
516
+ op=op,
517
+ )
518
+ return out, ctx.lse
519
+
520
+
521
+ def memory_efficient_attention_backward(
522
+ grad: torch.Tensor,
523
+ output: torch.Tensor,
524
+ lse: torch.Tensor,
525
+ query: torch.Tensor,
526
+ key: torch.Tensor,
527
+ value: torch.Tensor,
528
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
529
+ p: float = 0.0,
530
+ scale: Optional[float] = None,
531
+ *,
532
+ op: Optional[Type[AttentionBwOpBase]] = None,
533
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
534
+ """
535
+ Computes the gradient of the attention.
536
+ Returns a tuple (dq, dk, dv)
537
+ See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
538
+ `lse` is the tensor returned by
539
+ :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`
540
+ """
541
+ if p != 0.0:
542
+ raise NotImplementedError(
543
+ "dropout is not supported on the non-autograd API."
544
+ " If you want to use dropout, please call `memory_efficient_attention` directly"
545
+ )
546
+ gradients = _memory_efficient_attention_backward(
547
+ Context(out=output, lse=lse),
548
+ Inputs(
549
+ query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
550
+ ),
551
+ grad,
552
+ op=op,
553
+ )
554
+ return (gradients.dq, gradients.dk, gradients.dv)
555
+
556
+
557
+ def _memory_efficient_attention(
558
+ inp: Inputs, op: Optional[AttentionOp] = None
559
+ ) -> torch.Tensor:
560
+ # fast-path that doesn't require computing the logsumexp for backward computation
561
+ if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
562
+ return _memory_efficient_attention_forward(
563
+ inp, op=op[0] if op is not None else None
564
+ )
565
+
566
+ output_shape = inp.normalize_bmhk()
567
+
568
+ op_fw = _serialize_op(op[0] if op is not None else None)
569
+ op_bw = _serialize_op(op[1] if op is not None else None)
570
+ return _fMHA.apply(
571
+ op_fw, op_bw, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
572
+ )[0].reshape(output_shape)
573
+
574
+
575
+ def _memory_efficient_attention_forward(
576
+ inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
577
+ ) -> torch.Tensor:
578
+ inp.validate_inputs()
579
+ output_shape = inp.normalize_bmhk()
580
+ if op is None:
581
+ op = _dispatch_fw(inp, False)
582
+ else:
583
+ _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
584
+
585
+ out, *_ = op.apply(inp, needs_gradient=False)
586
+ return out.reshape(output_shape)
587
+
588
+
589
+ def _memory_efficient_attention_forward_requires_grad(
590
+ inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
591
+ ) -> Tuple[torch.Tensor, Context]:
592
+ inp.validate_inputs()
593
+ output_shape = inp.normalize_bmhk()
594
+ if op is None:
595
+ op = _dispatch_fw(inp, True)
596
+ else:
597
+ _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
598
+ out, ctx = op.apply(inp, needs_gradient=True)
599
+ assert ctx is not None
600
+ return (out.reshape(output_shape), ctx)
601
+
602
+
603
+ def _detect_lse_packed_or_raise(lse: torch.Tensor, inp: Inputs) -> Optional[bool]:
604
+ """
605
+ Detects the LSE format if we're in a varlen case.
606
+ Returns `None` if the format is not relevant (e.g. not varlen)
607
+ Raises an exception if the `lse` has the wrong shape
608
+ """
609
+ shape_mismatch_err = (
610
+ "Input tensors have incompatible shapes.\n"
611
+ f" lse.shape : {lse.shape}\n"
612
+ f" query.shape : {inp.query.shape}\n"
613
+ f" attn_bias : {type(inp.attn_bias)}"
614
+ )
615
+ # 1. Check ndim & head dimensions
616
+ # In any case, LSE should be [*, *GH]
617
+ if lse.ndim != (inp.query.ndim - 1) or lse.shape[1:-1] != inp.query.shape[2:-1]:
618
+ raise ValueError(shape_mismatch_err)
619
+ lse_bm = [lse.shape[0], lse.shape[-1]]
620
+ lse_packed_shape = [inp.query.shape[0], inp.query.shape[1]]
621
+ lse_packed = lse_bm[0] == lse_packed_shape[0] and lse_bm >= lse_packed_shape
622
+ # 2. Check correctness for varlen biases with query.shape = [1, M, *GH, K]
623
+ # Either [1, *GH, M] (packed)
624
+ # Or [num_seq, *GH, Mq] .. with `Mq >= max_q` (padded)
625
+ if isinstance(inp.attn_bias, VARLEN_BIASES):
626
+ si = inp.attn_bias.q_seqinfo
627
+ lse_padded_shape = [si.seqstart.shape[0] - 1, si.max_seqlen]
628
+ lse_padded = lse_bm[0] == lse_padded_shape[0] and lse_bm >= lse_padded_shape
629
+ if lse_packed and lse_padded:
630
+ return None
631
+ elif lse_packed:
632
+ return True
633
+ elif lse_padded:
634
+ return False
635
+ raise ValueError(shape_mismatch_err)
636
+ # 3. For non-varlen, shape must be [B, *GH] with query.shape=[B, M, *GH, K]
637
+ if not lse_packed:
638
+ raise ValueError(shape_mismatch_err)
639
+ return None
640
+
641
+
642
+ def _memory_efficient_attention_backward(
643
+ ctx: Context,
644
+ inp: Inputs,
645
+ grad: torch.Tensor,
646
+ op: Optional[Type[AttentionBwOpBase]],
647
+ *,
648
+ _skip_op_checks: bool = False,
649
+ ) -> Gradients:
650
+ """Warning: grad/ctx.out is potentially in BMK format"""
651
+ inp.validate_inputs()
652
+ if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
653
+ raise ValueError(
654
+ "All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
655
+ f"grad.shape : {grad.shape} \n"
656
+ f"out.shape : {ctx.out.shape} \n"
657
+ f"query.shape: {inp.query.shape}"
658
+ )
659
+ shape_dq, shape_dk, shape_dv = tuple(
660
+ x.shape for x in (inp.query, inp.key, inp.value)
661
+ )
662
+ inp.normalize_bmhk()
663
+ varlen_lse_packed = _detect_lse_packed_or_raise(ctx.lse, inp)
664
+ grad = bmk2bmhk(grad, 1)
665
+ ctx.out = bmk2bmhk(ctx.out, 1)
666
+
667
+ if op is None:
668
+ op = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
669
+ elif not _skip_op_checks:
670
+ _ensure_op_supports_or_raise(
671
+ ValueError, "memory_efficient_attention_backward", op, inp
672
+ )
673
+ if varlen_lse_packed is not None and varlen_lse_packed != op.VARLEN_LSE_PACKED:
674
+ raise ValueError(
675
+ f"Wrong LSE format for {op.NAME} in variable seqlen case. "
676
+ f"Double-check that the BW operator {op.NAME} is compatible "
677
+ f"with the operator used in the FW pass."
678
+ )
679
+
680
+ grads = op.apply(ctx, inp, grad)
681
+ grads.dq = grads.dq.reshape(shape_dq)
682
+ grads.dk = grads.dk.reshape(shape_dk)
683
+ grads.dv = grads.dv.reshape(shape_dv)
684
+ return grads
685
+
686
+
687
+ def memory_efficient_attention_partial(
688
+ query: torch.Tensor,
689
+ key: torch.Tensor,
690
+ value: torch.Tensor,
691
+ attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
692
+ p: float = 0.0,
693
+ scale: Optional[float] = None,
694
+ *,
695
+ op: Optional[Union[AttentionOp, Type[AttentionFwOpBase]]] = None,
696
+ output_dtype: Optional[torch.dtype] = None,
697
+ _allow_backward: bool = False,
698
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
699
+ """
700
+ Returns a tuple (output, lse), where `output` is the attention in the style of
701
+ memory_efficient_attention, and `lse` is extra data, a log-sum-exp.
702
+ The outputs of calls to this with the same query and separate keys and values
703
+ can be merged with merge_attentions to obtain the attention of the queries
704
+ against the disjoint union of the keys and values.
705
+
706
+ This function doesn't have a backward pass.
707
+
708
+ If _allow_backward is set to True, then a backward pass is allowed,
709
+ but it is restricted: only the gradient of the output, not the gradient of
710
+ the LSE, is used.
711
+ Note that this makes it very easy to accidentally get wrong gradients.
712
+ """
713
+ if p != 0.0:
714
+ raise NotImplementedError("dropout is not supported.")
715
+ fwop: Optional[Type[AttentionFwOpBase]] = op[0] if isinstance(op, tuple) else op
716
+ inp = Inputs(
717
+ query=query,
718
+ key=key,
719
+ value=value,
720
+ p=p,
721
+ attn_bias=attn_bias,
722
+ scale=scale,
723
+ output_dtype=output_dtype,
724
+ is_partial=True,
725
+ )
726
+ is_grad = (
727
+ _allow_backward
728
+ and torch.is_grad_enabled()
729
+ and any(x.requires_grad for x in [query, key, value])
730
+ )
731
+
732
+ if not is_grad:
733
+ out, ctx = _memory_efficient_attention_forward_requires_grad(
734
+ inp,
735
+ op=fwop,
736
+ )
737
+ return out, ctx.lse
738
+
739
+ if query.ndim == 5:
740
+ raise ValueError("gradients not supported for 5D tensors")
741
+ if isinstance(op, tuple):
742
+ op_fw = _serialize_op(op[0])
743
+ op_bw = _serialize_op(op[1])
744
+ elif op is None:
745
+ op_fw = op_bw = None
746
+ else:
747
+ op_fw = _serialize_op(op)
748
+ op_bw = None
749
+ return _fMHA.apply(
750
+ op_fw,
751
+ op_bw,
752
+ inp.query,
753
+ inp.key,
754
+ inp.value,
755
+ inp.attn_bias,
756
+ inp.p,
757
+ inp.scale,
758
+ inp.output_dtype,
759
+ inp.is_partial,
760
+ )
761
+
762
+
763
+ def merge_attentions( # noqa: C901
764
+ attn_split: Union[torch.Tensor, Sequence[torch.Tensor]],
765
+ lse_split: Union[torch.Tensor, Sequence[torch.Tensor]],
766
+ write_lse: bool = True,
767
+ output_dtype: Optional[torch.dtype] = None,
768
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
769
+ """
770
+ Combine attention output computed on different parts of K/V for the same
771
+ query to get attention on the whole K/V. See https://arxiv.org/abs/2402.05099
772
+ The result is equal to
773
+ Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...)
774
+ LSE_full = log(exp(LSE1) + exp(LSE2) + ...)
775
+
776
+ Args:
777
+ attn_split: attention outputs for chunks,
778
+ either as a list of tensors of shapes [B, M, G, H, Kq] or [B, M, H, Kq]
779
+ or as a single tensor of shape [num_chunks, B, M, G, H, Kq]
780
+ or [num_chunks, B, M, H, Kq]
781
+ lse_split: LSE for chunks,
782
+ either as a list of tensors of shapes [B, G, H, M] or [B, H, M]
783
+ or as a single tensor of shape [num_chunks, B, G, H, M] or [num_chunks, B, H, M]
784
+ write_lse: whether to output LSE
785
+ output_dtype: dtype of attn_out
786
+
787
+ Returns:
788
+ attn_out: [B, M, G, H, Kq] or [B, M, H, Kq]
789
+ lse_out: [B, G, H, M] or [B, H, M] if write_lse
790
+ or None otherwise
791
+ """
792
+
793
+ attn_is_concat = isinstance(attn_split, torch.Tensor)
794
+ lse_is_concat = isinstance(lse_split, torch.Tensor)
795
+
796
+ attn_requires_grad = (
797
+ attn_split.requires_grad # type: ignore
798
+ if attn_is_concat
799
+ else any(x.requires_grad for x in attn_split)
800
+ )
801
+ lse_requires_grad = (
802
+ lse_split.requires_grad # type: ignore
803
+ if lse_is_concat
804
+ else any(x.requires_grad for x in lse_split)
805
+ )
806
+ requires_grad = torch.is_grad_enabled() and (
807
+ attn_requires_grad or lse_requires_grad
808
+ )
809
+ if requires_grad and not write_lse:
810
+ raise ValueError("write_lse should be true if inputs require gradients.")
811
+
812
+ concat_path = attn_is_concat and lse_is_concat and not requires_grad
813
+ if concat_path:
814
+ attn_split = cast(torch.Tensor, attn_split)
815
+ lse_split = cast(torch.Tensor, lse_split)
816
+ if attn_split.ndim != lse_split.ndim + 1:
817
+ raise ValueError(
818
+ f"Incompatible input shapes: {attn_split.shape=}, {lse_split.shape=}"
819
+ )
820
+
821
+ is_bmhk = attn_split.ndim == 5
822
+ if is_bmhk:
823
+ attn_split = attn_split.unsqueeze(3)
824
+ lse_split = lse_split.unsqueeze(2)
825
+
826
+ num_chunks, B, M, G, H, Kq = attn_split.shape
827
+ num_chunks1, B1, G1, H1, M1 = lse_split.shape
828
+ if B != B1 or G != G1 or H != H1 or num_chunks != num_chunks1 or M != M1:
829
+ raise ValueError(
830
+ f"Incompatible input shapes: {attn_split.shape=} {lse_split.shape=} "
831
+ f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {num_chunks}/{num_chunks1}, {M}/{M}"
832
+ )
833
+
834
+ attn_split = attn_split.permute(1, 3, 4, 0, 2, 5)
835
+ lse_split = lse_split.permute(1, 2, 3, 0, 4)
836
+
837
+ device = attn_split.device
838
+ attn_dtype = attn_split.dtype
839
+ lse_dtype = lse_split.dtype
840
+ else:
841
+ if attn_is_concat:
842
+ attn_split = attn_split.unbind(0) # type: ignore
843
+ if lse_is_concat:
844
+ lse_split = lse_split.unbind(0) # type: ignore
845
+ num_chunks = len(attn_split)
846
+ if len(lse_split) != num_chunks:
847
+ raise ValueError(
848
+ f"Incompatible number of LSE and attention chunks: {len(attn_split)=}, {len(lse_split)=}"
849
+ )
850
+
851
+ attn_unsqueezed = []
852
+ lse_unsqueezed = []
853
+ is_bmhk = False
854
+ for i in range(num_chunks):
855
+ if attn_split[i].ndim != lse_split[i].ndim + 1:
856
+ raise ValueError(
857
+ f"Incompatible input shapes for chunk {i}: {attn_split[i].shape=}, {lse_split[i].shape=}"
858
+ )
859
+
860
+ is_bmhk = attn_split[i].ndim == 4
861
+ if is_bmhk:
862
+ attn_unsqueezed.append(attn_split[i].unsqueeze(2))
863
+ lse_unsqueezed.append(lse_split[i].unsqueeze(1))
864
+ else:
865
+ attn_unsqueezed.append(attn_split[i])
866
+ lse_unsqueezed.append(lse_split[i])
867
+ attn_split, lse_split = attn_unsqueezed, lse_unsqueezed
868
+
869
+ B, M, G, H, Kq = attn_split[0].shape
870
+ B1, G1, H1, M1 = lse_split[0].shape
871
+ if B != B1 or G != G1 or H != H1 or M != M1:
872
+ raise ValueError(
873
+ f"Incompatible input shapes: {attn_split[0].shape=}, {lse_split[0].shape=} "
874
+ f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {M}/{M}"
875
+ )
876
+
877
+ for i in range(num_chunks):
878
+ if attn_split[i].shape != (B, M, G, H, Kq):
879
+ raise ValueError(
880
+ f"Incompatible input shapes for attention chunk {i}: "
881
+ f"{attn_split[i].shape=}, {(B, M, G, H, Kq)=}"
882
+ )
883
+ if lse_split[i].shape != (B, G, H, M):
884
+ raise ValueError(
885
+ f"Incompatible input shapes for LSE chunk {i}: "
886
+ f"{lse_split[i].shape=}, {(B, G, H, M)=}"
887
+ )
888
+
889
+ attn_split[i] = attn_split[i].permute(0, 2, 3, 1, 4) # to (B, G, H, M, Kq)
890
+
891
+ device = attn_split[0].device
892
+ attn_dtype = attn_split[0].dtype
893
+ lse_dtype = lse_split[0].dtype
894
+
895
+ if concat_path:
896
+ attn_out = torch.empty(
897
+ B,
898
+ M,
899
+ G,
900
+ H,
901
+ Kq,
902
+ device=device,
903
+ dtype=output_dtype or attn_dtype,
904
+ )
905
+ if write_lse:
906
+ lse_out = torch.empty(
907
+ B,
908
+ G,
909
+ H,
910
+ M,
911
+ device=device,
912
+ dtype=lse_dtype,
913
+ )
914
+ else:
915
+ lse_out = None
916
+ triton_splitk.merge_attentions(attn_out, lse_out, attn_split, lse_split) # type: ignore
917
+ else:
918
+ outs = triton_splitk.merge_attentions_varargs(
919
+ attn_split, lse_split, write_lse, output_dtype, B, M, G, H, Kq
920
+ ) # type: ignore
921
+ attn_out = outs[0]
922
+ lse_out = outs[1] if write_lse else None
923
+
924
+ if is_bmhk:
925
+ attn_out = attn_out[:, :, 0]
926
+ if lse_out is not None:
927
+ lse_out = lse_out[:, 0]
928
+
929
+ return attn_out, lse_out
930
+
931
+
932
+ ALL_FW_OPS: List[Type[AttentionFwOpBase]] = [
933
+ cutlass.FwOp if torch.version.cuda else ck.FwOp,
934
+ cutlass_blackwell.FwOp,
935
+ flash.FwOp,
936
+ flash_mtia.FwOp,
937
+ flash3.FwOp,
938
+ triton_splitk.FwOp,
939
+ ]
940
+
941
+ ALL_BW_OPS: List[Type[AttentionBwOpBase]] = [
942
+ cutlass.BwOp if torch.version.cuda else ck.BwOp,
943
+ cutlass_blackwell.BwOp,
944
+ flash.BwOp,
945
+ flash_mtia.BwOp,
946
+ flash3.BwOp,
947
+ ]
948
+
949
+ __all__ = [
950
+ "AttentionBias",
951
+ "AttentionOp",
952
+ "AttentionOpBase",
953
+ "LowerTriangularMask",
954
+ "MemoryEfficientAttentionCutlassFwdFlashBwOp",
955
+ "MemoryEfficientAttentionCutlassOp",
956
+ "MemoryEfficientAttentionFlashAttentionOp",
957
+ "MemoryEfficientAttentionFlashMtiaAttentionOp",
958
+ "memory_efficient_attention",
959
+ "MemoryEfficientAttentionCkOp",
960
+ "MemoryEfficientAttentionCkDecoderOp",
961
+ "ALL_FW_OPS",
962
+ "ALL_BW_OPS",
963
+ "attn_bias",
964
+ "_get_use_fa3",
965
+ "_set_use_fa3",
966
+ "BlockDiagonalMask",
967
+ ]
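
The file above exposes a split-K/V workflow: attention is computed over disjoint chunks of K/V with memory_efficient_attention_partial and the partial results are combined with merge_attentions. The following is a minimal usage sketch based only on the signatures and docstrings shown in this diff; the import path mirrors the file layout above, while the shapes, device, and dtype are illustrative assumptions (a CUDA device with a supported backend is assumed).

    import torch

    from mslk.attention.fmha import (
        memory_efficient_attention,
        memory_efficient_attention_partial,
        merge_attentions,
    )

    # Illustrative sizes: batch, query length, heads, head dim (assumptions).
    B, M, H, K = 2, 1024, 8, 128
    kwargs = dict(device="cuda", dtype=torch.float16)
    q = torch.randn([B, M, H, K], **kwargs)
    k = torch.randn([B, 2 * M, H, K], **kwargs)
    v = torch.randn([B, 2 * M, H, K], **kwargs)

    # Attention of q against each half of K/V; each call returns (output, lse).
    out1, lse1 = memory_efficient_attention_partial(q, k[:, :M], v[:, :M])
    out2, lse2 = memory_efficient_attention_partial(q, k[:, M:], v[:, M:])

    # Merge the partial results, per the merge_attentions docstring:
    #   Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2)) / (exp(LSE1) + exp(LSE2))
    out_merged, lse_merged = merge_attentions([out1, out2], [lse1, lse2])

    # Reference: attention against the full K/V in a single call.
    out_full = memory_efficient_attention(q, k, v)
    print((out_merged - out_full).abs().max())  # should be small (fp16 rounding)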