mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/common.py
@@ -0,0 +1,598 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)
+
+import torch
+
+from .attn_bias import (
+    AttentionBias,
+    BlockDiagonalGappyKeysMask,
+    BlockDiagonalMask,
+    BlockDiagonalPaddedKeysMask,
+    LowerTriangularMask,
+    LowerTriangularMaskWithTensorBias,
+    PagedBlockDiagonalGappyKeysMask,
+    PagedBlockDiagonalPaddedKeysMask,
+)
+from .utils.cpp_lib import _built_with_cuda
+from .utils.op_common import BaseOperator
+
+
+def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
+    # NoneType
+    if isinstance(None, attn_bias_type):
+        return True
+    if attn_bias_type in [LowerTriangularMask, torch.Tensor]:
+        return True
+    return False
+
+
+def _attn_bias_apply(
+    attn_bias: Optional[Union[torch.Tensor, AttentionBias]],
+    op: Callable[[torch.Tensor], torch.Tensor],
+) -> Optional[Union[torch.Tensor, AttentionBias]]:
+    if isinstance(attn_bias, torch.Tensor):
+        return op(attn_bias)
+    if isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
+        return LowerTriangularMaskWithTensorBias(op(attn_bias._bias))
+    return attn_bias
+
+
+class ScaledTensor(torch.Tensor):
+    __slots__ = ["scale", "dequant_func", "original_dtype"]
+
+    # Disabling custom torch function handling for this class
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    @staticmethod
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        scale: torch.Tensor,
+        dequant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        original_dtype: torch.dtype,
+        require_grad: bool = False,
+    ) -> "ScaledTensor":
+        """
+        Creates a new ScaledTensor subclass instance.
+
+        Parameters:
+        - data: The underlying quantized tensor (e.g., int8, int4).
+        - scale: The scale tensor or scalar to be used for dequantization.
+        - dequant_func: A callable that applies dequantization, which takes both the data and scale as input.
+        - original_dtype: The data type before quantization (e.g., float32, float16).
+        - require_grad: Whether or not to track gradients (default: False for inference use).
+        """
+        # Use _make_subclass to create a new ScaledTensor instance, which is a subclass of torch.Tensor.
+        instance = torch.Tensor._make_subclass(cls, data, require_grad)
+
+        # Store the dequantization scale and function as attributes.
+        instance.scale = scale  # type: ignore
+        instance.dequant_func = dequant_func  # type: ignore
+
+        # Store the original data type of the tensor, so we can cast it back after dequantization.
+        instance.original_dtype = original_dtype  # type: ignore
+
+        # Return the new instance of ScaledTensor.
+        return instance
+
+    def dequantize(self) -> torch.Tensor:
+        """
+        Applies the custom dequantization function provided at the tensor's creation.
+        After dequantization, the data is cast back to its original data type.
+        """
+        # Explicitly create a new torch.Tensor to ensure the return type is torch.Tensor, not ScaledTensor.
+        data = torch.Tensor(self.float())
+
+        # Call the dequantization function, passing in the data and the scale.
+        dequantized_data = self.dequant_func(data, self.scale)  # type: ignore
+
+        # Cast the dequantized data back to the original data type.
+        return dequantized_data.to(self.original_dtype)  # type: ignore
+
+    def unpack(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Unpacks the ScaledTensor by returning its data and scale as a tuple.
+        Returns:
+        - A tuple of (data, scale), both of which are torch.Tensor objects.
+        """
+        return self.data, self.scale  # type: ignore
+
+    def __repr__(self):
+        """
+        Custom string representation for ScaledTensor.
+        """
+        return f"ScaledTensor(data={self.data}, scale={self.scale}, original_dtype={self.original_dtype})"
+
+
+def pack_fp8_tensorwise_per_head(
+    x: torch.Tensor, scale: Union[torch.Tensor, float], original_dtype
+) -> ScaledTensor:
+    """
+    Pack a tensor into a tensorwise fp8 ScaledTensor.
+    """
+    if isinstance(scale, float):
+        scale = torch.tensor([scale], device=x.device)
+
+    def dequant_func(x, scale):
+        return x * scale[:, None, :, None]
+
+    return ScaledTensor(
+        data=x,
+        scale=scale,
+        dequant_func=dequant_func,
+        original_dtype=original_dtype,
+    )
+
+
+@dataclass
+class Inputs:
+    """
+    Stores inputs to the `memory_efficient_attention` operators
+    """
+
+    query: torch.Tensor
+    key: torch.Tensor
+    value: torch.Tensor
+    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
+    p: float = 0.0
+    scale: Optional[float] = None
+    output_dtype: Optional[torch.dtype] = None
+    is_partial: bool = False
+    quantize_pv_to_fp8: bool = False
+    quantize_qk_to_fp8: bool = False
+    use_fp32_scales: bool = False
+
+    @property
+    def device(self) -> torch.device:
+        return self.query.device
+
+    @property
+    def scale_float(self) -> float:
+        return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale
+
+    def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.query.ndim == 5:
+            return self.query, self.key, self.value
+        if self.query.ndim == 4:
+            return (
+                self.query.unsqueeze(2),
+                self.key.unsqueeze(2),
+                self.value.unsqueeze(2),
+            )
+        if self.value.ndim == 3:
+            return (
+                self.query[:, :, None, None],
+                self.key[:, :, None, None],
+                self.value[:, :, None, None],
+            )
+        raise AssertionError
+
+    def normalize_bmhk(self) -> Tuple[int, ...]:
+        if self.query.ndim not in [3, 4, 5]:
+            raise ValueError(
+                f"Invalid shape for query: {self.query.shape}. "
+                "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
+                ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
+            )
+        if self.value.dtype == torch.int32:
+            # Quantized K/V case, in which the last dims of Q and K are different.
+            # NB we currently don't have any implementations for quantized KV with
+            # SUPPORTS_DIFFERENT_VALUE_EMBED.
+            output_shape: Tuple[int, ...] = tuple(self.query.shape)
+        else:
+            output_shape = tuple(self.query.shape)[:-1] + (self.value.shape[-1],)
+        # Convert from legacy format
+        if self.query.ndim == 3:
+            self.query = self.query.unsqueeze(2)
+            self.key = self.key.unsqueeze(2)
+            self.value = self.value.unsqueeze(2)
+            self.attn_bias = _attn_bias_apply(
+                self.attn_bias, partial(torch.unsqueeze, dim=1)
+            )
+        return output_shape
+
+    def validate_inputs(self) -> None:  # noqa: C901
+        qkv = (self.query, self.key, self.value)
+        if self.query.ndim not in (3, 4, 5) or any(
+            x.ndim != self.query.ndim for x in qkv
+        ):
+            raise ValueError(
+                f"Query/Key/Value should all have BMGHK, BMHK or BMK shape.\n"
+                f" query.shape: {self.query.shape}\n"
+                f" key.shape : {self.key.shape}\n"
+                f" value.shape: {self.value.shape}"
+            )
+        if any(x.device != self.query.device for x in qkv):
+            raise ValueError("Query/Key/Value should all be on the same device")
+        if isinstance(
+            self.attn_bias,
+            (
+                BlockDiagonalMask,
+                BlockDiagonalPaddedKeysMask,
+                PagedBlockDiagonalPaddedKeysMask,
+                BlockDiagonalGappyKeysMask,
+                PagedBlockDiagonalGappyKeysMask,
+            ),
+        ):
+            bias_device = self.attn_bias.q_seqinfo.seqstart.device
+            if bias_device != self.query.device:
+                raise ValueError(
+                    f"Attention bias and Query/Key/Value should be on the same device\n"
+                    f" query.device: {self.query.device}\n"
+                    f" attn_bias : {bias_device}\n"
+                )
+
+        quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
+        non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
+        if not (quantized_dtypes or non_quantized_dtypes):
+            raise ValueError(
+                "Query/Key/Value should either all have the same dtype, or "
+                "(in the quantized case) Key/Value should have dtype torch.int32\n"
+                f" query.dtype: {self.query.dtype}\n"
+                f" key.dtype : {self.key.dtype}\n"
+                f" value.dtype: {self.value.dtype}"
+            )
+        # Biases with tensors attached are meant to be in BMHK format
+        # This would require to permute biases/gradients which can be expensive,
+        # so let's just forbid it - BMK is a legacy format anyway
+        if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
+            type(self.attn_bias)
+        ):
+            raise ValueError(
+                f"Please provide inputs in BMHK format rather "
+                f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
+            )
+        attn_bias_t: Optional[torch.Tensor] = None
+        if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
+            attn_bias_t = self.attn_bias._bias
+        elif isinstance(self.attn_bias, torch.Tensor):
+            attn_bias_t = self.attn_bias
+        if self.query.ndim == 4 and attn_bias_t is not None:
+            expected_shape = (
+                self.query.shape[0],
+                self.query.shape[2],
+                self.query.shape[1],
+                self.key.shape[1],
+            )
+            if attn_bias_t.shape != expected_shape:
+                raise ValueError(
+                    f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
+                    f" query.shape: {self.query.shape}\n"
+                    f" key.shape : {self.key.shape}\n"
+                    f" value.shape: {self.value.shape}"
+                )
+        if isinstance(self.attn_bias, BlockDiagonalMask):
+            if any(x.shape[0] != 1 for x in qkv):
+                raise ValueError(
+                    f"Expected batch_size=1 when using block-diagonal bias\n"
+                    f" query.shape: {self.query.shape}\n"
+                    f" key.shape : {self.key.shape}\n"
+                    f" value.shape: {self.value.shape}"
+                )
+        if self.p < 0.0 or self.p > 1.0:
+            raise ValueError(f"Invalid dropout probability: p={self.p}")
+        # Check that shapes match between inputs
+        B, Mq = self.query.shape[:2]
+        K = self.query.shape[-1]
+        B, Mkv = self.key.shape[:2]
+        Kv = self.value.shape[-1]
+        quantized_kv_cache = self.value.dtype == torch.int32
+        key_embed_dim = Kv if quantized_kv_cache else K
+
+        valid_shapes = True
+        if self.query.ndim == 3:  # BMK
+            valid_shapes = (
+                self.query.shape == (B, Mq, K)
+                and self.key.shape == (B, Mkv, K)
+                and self.value.shape == (B, Mkv, Kv)
+            )
+        H = self.query.shape[-2]
+        if self.query.ndim == 4:  # BMHK
+            valid_shapes = (
+                self.query.shape == (B, Mq, H, K)
+                and self.key.shape == (B, Mkv, H, key_embed_dim)
+                and self.value.shape == (B, Mkv, H, Kv)
+            )
+        G = self.query.shape[2]
+        if self.query.ndim == 5:  # BMNHK
+            valid_shapes = (
+                self.query.shape == (B, Mq, G, H, K)
+                and self.key.shape == (B, Mkv, G, H, key_embed_dim)
+                and self.value.shape == (B, Mkv, G, H, Kv)
+            )
+        if not valid_shapes:
+            raise ValueError(
+                f"Incompatible shapes for attention inputs:\n"
+                f" query.shape: {self.query.shape}\n"
+                f" key.shape : {self.key.shape}\n"
+                f" value.shape: {self.value.shape}\n"
+                "HINT: We don't support broadcasting, please use `expand` "
+                "yourself before calling `memory_efficient_attention` if you need to"
+            )
+
+    def get_output_dtype(self) -> torch.dtype:
+        if self.output_dtype is None:
+            if self.is_partial and self.query.dtype is not torch.float64:
+                return torch.float32
+            return self.query.dtype
+        return self.output_dtype
+
+    @property
+    def nbytes(self) -> int:
+        """
+        Number of bytes in the input, not counting the attention bias.
+        """
+        return sum(
+            x.untyped_storage().nbytes() for x in [self.query, self.key, self.value]
+        )
+
+
+@dataclass
+class Context:
+    lse: torch.Tensor
+    out: torch.Tensor
+    # NOTE: If `rng_state` is set, `op_bw` should be set as well
+    # as the randomness is backend-dependant
+    op_bw: Optional[Type["AttentionBwOpBase"]] = None
+    rng_state: Optional[Any] = None
+    qkv_share_storage: bool = False
+
+    def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
+        pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to
+        lse = self.lse
+        if pad_amount > 0:
+            if force_pad_inf:
+                lse = lse[:, :, : self.out.shape[1]]
+                pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to
+            lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
+        elif force_pad_inf and self.out.shape[1] != lse.shape[2]:
+            lse[:, :, self.out.shape[1] :].fill_(math.inf)
+        return lse
+
+
+@dataclass
+class Gradients:
+    dq: torch.Tensor
+    dk: torch.Tensor
+    dv: torch.Tensor
+    # bias gradient. None if there is no tensor bias or if it doesn't require grad
+    db: Optional[torch.Tensor] = None
+
+
+class AttentionOpBase(BaseOperator):
+    """Base class for any attention operator in xFormers
+
+    See:
+
+    - :attr:`xformers.ops.fmha.cutlass.FwOp`
+    - :attr:`xformers.ops.fmha.cutlass.BwOp`
+    - :attr:`xformers.ops.fmha.flash.FwOp`
+    - :attr:`xformers.ops.fmha.flash.BwOp`
+    - :attr:`xformers.ops.fmha.triton.FwOp`
+    - :attr:`xformers.ops.fmha.triton.BwOp`
+    """
+
+    OPERATOR: Any  # pyre-ignore[13]
+    SUPPORTED_DEVICES: Set[str]  # pyre-ignore[13]
+    CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
+    CUDA_MAXIMUM_COMPUTE_CAPABILITY: Optional[Tuple[int, int]] = None
+    SUPPORTED_DTYPES: Set[torch.dtype]  # pyre-ignore[13]
+    SUPPORTED_MAX_K: float  # pyre-ignore[13]
+    SUPPORTED_MIN_K: int = 0
+    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (type(None),)
+    SUPPORTS_DROPOUT: bool  # pyre-ignore[13]
+    SUPPORTS_CUSTOM_SCALE: bool = False
+    SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
+    SUPPORTS_OUTPUT_DTYPE: bool = False
+    SUPPORTS_PARTIAL: bool = False
+    IS_DETERMINISTIC: bool = True
+    SUPPORTS_BMGHK: bool = False
+    NAME: str  # pyre-ignore[13]
+    OPERATOR_CATEGORY = "memory_efficient_attention"
+    # Format for the LSE computed in the FW pass, and accepted in the BW pass,
+    # for BlockDiagonalMask and children.
+    # When using a varlen bias, both the FW and BW operators must have the
+    # same value for `VARLEN_LSE_PACKED`
+    VARLEN_LSE_PACKED: bool = True
+
+    _TEST_BATCH_SIZES: List[int] = [1, 300]
+    _TEST_K: List[int] = [32, 128]
+
+    @classmethod
+    def supports(cls, d: Inputs) -> bool:
+        return not cls.not_supported_reasons(d)
+
+    @classmethod
+    def shape_not_supported_reasons(
+        cls, Mq: int, Mkv: int, K: int, Kv: int
+    ) -> List[str]:
+        reasons = []
+        if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
+            reasons.append("query.shape[-1] != value.shape[-1]")
+        if max(K, Kv) > cls.SUPPORTED_MAX_K:
+            reasons.append(
+                f"max(query.shape[-1], value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
+            )
+        if min(K, Kv) < cls.SUPPORTED_MIN_K:
+            reasons.append(
+                f"min(query.shape[-1], value.shape[-1]) < {cls.SUPPORTED_MIN_K}"
+            )
+        return reasons
+
+    @classmethod
+    def not_supported_reasons(cls, d: Inputs) -> List[str]:  # noqa: C901
+        """
+        Returns a list of reasons why this is not supported.
+        The kernel can run these inputs only if the returned list is empty
+        """
+        query_shape = d.query.shape
+        reasons = cls.shape_not_supported_reasons(
+            Mq=query_shape[1],
+            Mkv=d.key.shape[1],
+            K=query_shape[-1],
+            Kv=query_shape[-1] if d.value.dtype == torch.int32 else d.value.shape[-1],
+        )
+        device_type = d.query.device.type
+        dtype = d.query.dtype
+        if device_type not in cls.SUPPORTED_DEVICES:
+            reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
+        if (
+            device_type == "cuda"
+            and not _built_with_cuda
+            and (torch.version.hip is None)
+        ):
+            reasons.append("xFormers wasn't build with CUDA support")
+        if device_type == "cuda" and (torch.version.hip is None):
+            device_capability = torch.cuda.get_device_capability(d.device)
+            if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
+                reasons.append(
+                    f"requires device with capability >= {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
+                    f"but your GPU has capability {device_capability} (too old)"
+                )
+            elif (
+                cls.CUDA_MAXIMUM_COMPUTE_CAPABILITY is not None
+                and device_capability > cls.CUDA_MAXIMUM_COMPUTE_CAPABILITY
+            ):
+                reasons.append(
+                    f"requires device with capability <= {cls.CUDA_MAXIMUM_COMPUTE_CAPABILITY} "
+                    f"but your GPU has capability {device_capability} (too new)"
+                )
+        if dtype not in cls.SUPPORTED_DTYPES:
+            reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
+        if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
+            reasons.append(f"attn_bias type is {type(d.attn_bias)}")
+        if not cls.SUPPORTS_OUTPUT_DTYPE:
+            if d.output_dtype is not None and d.output_dtype is not dtype:
+                reasons.append("Custom output dtype not supported")
+        if d.is_partial and not cls.SUPPORTS_PARTIAL:
+            reasons.append("Partial attention not supported")
+        if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
+            reasons.append("dropout > 0.0")
+        if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
+            reasons.append("has custom scale")
+        # bfloat16 is only supported on A100+ and MTIA
+        # ... although the kernels can still run and give the
+        # correct result
+        supports_bf16 = (
+            device_type.startswith("cuda")
+            and torch.cuda.get_device_capability(d.query.device)[0] >= 8
+        ) or device_type.startswith("mtia")
+        if dtype is torch.bfloat16 and not supports_bf16:
+            reasons.append("bf16 is only supported on A100+ GPUs and MTIA")
+        if not cls.is_available():
+            reasons.append(
+                "operator wasn't built - see `python -m xformers.info` for more info"
+            )
+        if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
+            reasons.append(
+                "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
+            )
+        if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
+            reasons.append("operator does not support BMGHK format")
+        return reasons
+
+
+class AttentionFwOpBase(AttentionOpBase):
+    ERROR_ATOL: Mapping[torch.dtype, float] = {
+        torch.float: 3e-4,
+        torch.half: 4e-3,
+        torch.bfloat16: 2e-2,
+    }
+    ERROR_RTOL: Mapping[torch.dtype, float] = {
+        torch.float: 2e-5,
+        torch.half: 4e-4,
+        torch.bfloat16: 5e-3,
+    }
+
+    @classmethod
+    def apply(
+        cls, inp: Inputs, needs_gradient: bool
+    ) -> Tuple[torch.Tensor, Optional[Context]]:
+        raise NotImplementedError()
+
+
+class AttentionBwOpBase(AttentionOpBase):
+    # NOTE on tolerances: These are tested for `scales => (1/32)**0.5`
+    # In the BW pass, imprecisions accumulate in the Q@K.T recalculation
+    # These imprecisions are multiplied by the `scale` and then exponentiated
+    # So if the scale is too high, we get a lot of errors
+
+    ERROR_ATOL: Mapping[torch.dtype, float] = {
+        torch.float: 9e-4,
+        torch.half: 0.2,
+        torch.bfloat16: 0.9,
+    }
+    ERROR_RTOL: Mapping[torch.dtype, float] = {
+        torch.float: 1e-4,
+        torch.half: 2e-2,
+        torch.bfloat16: 0.1,
+    }
+    SUPPORTS_ATTN_BIAS_GRAD = False
+    SUPPORTS_PARTIAL = True
+
+    @classmethod
+    def not_supported_reasons(cls, d: Inputs) -> List[str]:
+        reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
+        if (
+            isinstance(d.attn_bias, torch.Tensor)
+            and d.attn_bias.requires_grad
+            and not cls.SUPPORTS_ATTN_BIAS_GRAD
+        ):
+            reasons.append(
+                "Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
+            )
+
+        return reasons
+
+    @classmethod
+    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
+        raise NotImplementedError()
+
+
+AttentionOp = Tuple[
+    Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
+]
+
+
+def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
+    if tensor.ndim == 4:
+        return tensor
+    return tensor.reshape(
+        [tensor.shape[0] // num_heads, num_heads, tensor.shape[1], tensor.shape[2]]
+    ).permute((0, 2, 1, 3))
+
+
+def check_lastdim_alignment_stride1(
+    reasons: List[str], name: str, x: torch.Tensor, alignment: int
+) -> None:
+    if x.shape[-1] % alignment != 0:
+        reasons.append(f"{name}.shape[-1] % {alignment} != 0")
+    elif x.stride(-2) % alignment != 0:
+        reasons.append(
+            f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
+        )
+    # We can have stride=0 sometimes if dimension=1
+    if x.stride(-1) > 1:
+        reasons.append(
+            f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
+        )
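
Editorial note, not part of the wheel: the file above defines a ScaledTensor wrapper and a pack_fp8_tensorwise_per_head helper. The lines below are a minimal round-trip sketch, assuming the module is importable as mslk.attention.fmha.common and that the local PyTorch build exposes torch.float8_e4m3fn; the [batch, heads] scale layout is an assumption chosen to match the scale[:, None, :, None] broadcast in the helper's dequant_func.

    # Illustrative sketch only; names outside this file are assumptions.
    import torch

    from mslk.attention.fmha.common import pack_fp8_tensorwise_per_head

    # BMHK activations quantized to fp8 with one scale per (batch, head).
    B, M, H, K = 2, 16, 4, 32
    x = torch.randn(B, M, H, K, dtype=torch.bfloat16)

    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x.float().abs().amax(dim=(1, 3)) / fp8_max               # shape [B, H]
    x_fp8 = (x.float() / scale[:, None, :, None]).to(torch.float8_e4m3fn)

    packed = pack_fp8_tensorwise_per_head(x_fp8, scale, original_dtype=x.dtype)
    data, s = packed.unpack()        # raw fp8 payload and its per-head scale
    x_back = packed.dequantize()     # back to bfloat16, approximately equal to x
    print((x_back - x).abs().max())  # quantization error only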
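
A second illustrative sketch, under the same import-path assumption: how the Inputs dataclass validates BMHK operands before an operator is dispatched. Everything here runs on CPU.

    import torch

    from mslk.attention.fmha.common import Inputs

    q = torch.randn(1, 128, 8, 64, dtype=torch.float16)  # BMHK: batch, seqlen, heads, head_dim
    k = torch.randn(1, 256, 8, 64, dtype=torch.float16)
    v = torch.randn(1, 256, 8, 64, dtype=torch.float16)

    inp = Inputs(query=q, key=k, value=v)
    inp.validate_inputs()            # raises ValueError on shape/dtype/device mismatches
    print(inp.scale_float)           # default softmax scale: 64 ** -0.5
    print(inp.get_output_dtype())    # torch.float16, since output_dtype is unset
    print(inp.normalize_bmhk())      # output shape (1, 128, 8, 64); input is already BMHK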
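
Finally, a small hypothetical demonstration of check_lastdim_alignment_stride1, which operators use to reject inputs whose last dimension is misaligned or not contiguous.

    import torch

    from mslk.attention.fmha.common import check_lastdim_alignment_stride1

    reasons = []
    # Transposing the last two dims makes the last dim non-contiguous (stride 128, not 1).
    x = torch.randn(2, 64, 128, dtype=torch.float16).transpose(-1, -2)
    check_lastdim_alignment_stride1(reasons, "query", x, alignment=8)
    print(reasons)  # includes the hint to call .contiguous() on the input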