mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/flash3.py
@@ -0,0 +1,858 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+
+
+import os
+from typing import Any, Iterable, List, Optional, Sequence, Set, Tuple, TypeVar
+
+import torch
+from torch.utils.flop_counter import (
+    _unpack_flash_attention_nested_shapes,
+    register_flop_formula,
+)
+
+from .attn_bias import (
+    BlockDiagonalCausalFromBottomRightMask,
+    BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+    BlockDiagonalCausalLocalAttentionMask,
+    BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+    BlockDiagonalCausalMask,
+    BlockDiagonalCausalWithOffsetGappyKeysMask,
+    BlockDiagonalCausalWithOffsetPaddedKeysMask,
+    BlockDiagonalGappyKeysMask,
+    BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask,
+    BlockDiagonalLocalAttentionPaddedKeysMask,
+    BlockDiagonalMask,
+    BlockDiagonalPaddedKeysMask,
+    LocalAttentionFromBottomRightMask,
+    LowerTriangularFromBottomRightLocalAttentionMask,
+    LowerTriangularFromBottomRightMask,
+    LowerTriangularMask,
+    PagedBlockDiagonalCausalLocalPaddedKeysMask,
+    PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
+    PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+    PagedBlockDiagonalGappyKeysMask,
+    PagedBlockDiagonalPaddedKeysMask,
+    VARLEN_BIASES,
+)
+from .common import (
+    AttentionBwOpBase,
+    AttentionFwOpBase,
+    check_lastdim_alignment_stride1,
+    Context,
+    Gradients,
+    Inputs,
+    ScaledTensor,
+)
+from .flash import (
+    _check_needs_no_topleft,
+    _convert_input_format,
+    _is_causal,
+    _post_process_lse,
+    _window_size,
+)
+from .utils.op_common import get_operator, register_operator
+
+FLASH_VERSION = "0.0.0"
+
+T = TypeVar("T")
+
+
+def maybe_contiguous(x: T) -> T:
+    return x.contiguous() if x is not None and x.stride(-1) != 1 else x  # type: ignore[attr-defined]
+
+
+try:
+    from xformers import _C_flashattention3  # type: ignore[attr-defined]
+
+    try:
+        from xformers._cpp_lib import _build_metadata  # type: ignore[attr-defined]
+
+        if _build_metadata is not None:
+            FLASH_VERSION = _build_metadata.flash_version
+    except ImportError:
+        FLASH_VERSION = "unknown"
+except ImportError:
+    try:
+        # type: ignore
+        from ai_codesign.gen_ai.flash_attention_v2.hopper.flash_attn_interface import (
+            flashattn_hopper_cuda as _C_flashattention3,
+        )
+    except ImportError:
+        # We end up here if the arch is not 90a
+        _C_flashattention3 = None
+
+
+def _heuristic_kvsplit(
+    inp: Inputs,
+    enable_kvsplit_attn: bool,
+) -> bool:
+    atten_bias = inp.attn_bias
+
+    # make sure Q doesn't have varlen
+    # pyre-ignore Undefined attribute [16]
+    if atten_bias.q_seqinfo.min_seqlen != atten_bias.q_seqinfo.max_seqlen:  # type: ignore[union-attr]
+        return False
+
+    # filter out prefill case
+    # pyre-ignore Undefined attribute [16]
+    if atten_bias.q_seqinfo.max_seqlen == atten_bias.k_seqinfo.max_seqlen:  # type: ignore[union-attr]
+        return False
+
+    return enable_kvsplit_attn
+
+
+def mask_non_zeros(s_q: int, s_k: int, window_left: int, window_right: int) -> int:
+    # Exact formula for easy cases
+    if window_left < 0 and window_right < 0:  # full
+        return s_q * s_k
+    if window_left < 0 and window_right == 0:  # causal
+        # (from bottom right)
+        return (s_q * (s_q + 1)) // 2 + s_q * max(0, s_k - s_q)
+
+    # NOTE: Flops calculations here assume `s_q == s_k`
+    # otherwise the local attention computations are too involved
+    # See also https://docs.google.com/spreadsheets/d/1u1ItCZcHLArcqXLj7mwR4H1pI3lMKU1zlxCYi8JCYgk/edit?usp=sharing
+    if window_left < 0:
+        window_left = s_k
+    if window_right < 0:
+        window_right = s_k
+
+    # below the diagonal
+    # ┌───────┐
+    # │ ╲     │
+    # │  ╲    │  <- Upper triangle ("ut")
+    # │┄┄┄╲   │  <--- `lastq_ut`
+    # │╲   ╲  │
+    # │ ╲   ╲ │  <- Lower part
+    # │  ╲   ╲│
+    # └───────┘
+    mask_nz = min(s_q, s_k)  # diagonal
+    # Below diagonal (with `window_left`)
+    lastq_ut = min(window_left, s_q)
+    mask_nz += ((lastq_ut - 1) * lastq_ut) // 2  # upper triangle
+    mask_nz += (s_q - lastq_ut) * window_left  # lower part
+    # Above diagonal (with `window_right`)
+    # (counting rows from the bottom for symmetry)
+    firstq_bt = min(window_right + 1, s_q)
+    mask_nz += ((firstq_bt - 1) * firstq_bt) // 2  # bottom triangle
+    mask_nz += (s_q - firstq_bt) * window_right
+
+    return mask_nz
+
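For intuition, the closed form above can be cross-checked against a brute-force count of the kept (query, key) pairs. The helper below is illustrative only and is not part of the wheel; it assumes the same bottom-right-aligned window convention (negative values mean "unbounded") and, per the NOTE above, is only expected to agree with `mask_non_zeros` when `s_q == s_k`.

    def ref_mask_non_zeros(s_q: int, s_k: int, window_left: int, window_right: int) -> int:
        # Brute-force count of mask non-zeros for a bottom-right-aligned local window.
        shift = s_k - s_q  # query row q lines up with key column q + shift
        count = 0
        for q in range(s_q):
            for k in range(s_k):
                dist = k - (q + shift)
                keep_left = window_left < 0 or dist >= -window_left
                keep_right = window_right < 0 or dist <= window_right
                if keep_left and keep_right:
                    count += 1
        return count

    # e.g. s_q = s_k = 8, window_left = 2, window_right = 0:
    # closed form gives diagonal (8) + upper triangle (1) + lower part (6 * 2) = 21,
    # which matches the brute-force count.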
+
+# Copied from PyTorch, modified to support MQA/GQA and local attention
+# No need to take care of this for the bwd because we don't "unexpand" the keys
+# and values (in the fwd we expand to help with the seqlen/headdim swap trick).
+def sdpa_flop_count(
+    query_shape, key_shape, value_shape, window_left: int, window_right: int
+):
+    """
+    Count flops for self-attention.
+
+    NB: We can assume that value_shape == key_shape
+    """
+    b, h_q, s_q, d_q = query_shape
+    _b2, h_kv, s_k, _d2 = key_shape
+    _b3, _h2, _s3, d_v = value_shape
+    assert b == _b2 == _b3
+    assert h_kv == _h2
+    assert d_q == _d2
+    assert s_k == _s3
+    assert d_q == _d2
+    assert h_q % h_kv == 0
+    # How many values are computed in the attention?
+    mask_nz = mask_non_zeros(s_q, s_k, window_left, window_right)
+
+    # q@k.T
+    total_flops = 2 * b * h_q * d_q * mask_nz
+    # attn@v
+    total_flops += 2 * b * h_q * d_v * mask_nz
+    return total_flops
+
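As a sanity check on the formula: with no local window, mask_nz = s_q * s_k, so the estimate reduces to the familiar dense figure 2 * b * h_q * d_q * s_q * s_k + 2 * b * h_q * d_v * s_q * s_k, i.e. 4 * b * h_q * s_q * s_k * d when d_q == d_v. For b = 1, h_q = 16, s_q = s_k = 4096, d = 128 that is roughly 1.4e11 FLOPs per call; causal masking roughly halves mask_nz and therefore the estimate.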
+
+if _C_flashattention3 is not None:  # noqa: C901
+    # Compatibility check for FAv3 APIs
+    EXPECTED_NUM_OF_ARGS = [
+        ("fwd", 33),
+        ("bwd", 22),
+    ]
+
+    import re
+
+    def count_args_from_doc(docstring) -> int:
+        # Use a regular expression to find the argument list inside parentheses
+        match = re.search(r"\((.*?)\)", docstring)
+        if match:
+            # Extract the argument list and split by commas
+            args_list = match.group(1).split(",")
+            # Count the number of arguments
+            return len(args_list)
+        else:
+            raise ValueError("No valid argument list found in the docstring.")
+
+    for name, num_of_args in EXPECTED_NUM_OF_ARGS:
+        num_of_args_from_doc = count_args_from_doc(
+            getattr(_C_flashattention3, name).__doc__
+        )
+        assert num_of_args_from_doc == num_of_args, (
+            f"Found func signature mismatch for {name}. Expected {num_of_args}, "
+            f"actual: {num_of_args_from_doc}. Please update the version of Flash Attention3."
+        )
+
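The arity check above only inspects the first parenthesized group of the extension's generated docstring. A self-contained illustration (the docstring below is made up and is not the real FAv3 signature):

    import re

    def count_args_from_doc(docstring) -> int:
        match = re.search(r"\((.*?)\)", docstring)
        if match:
            return len(match.group(1).split(","))
        raise ValueError("No valid argument list found in the docstring.")

    # Hypothetical docstring of a compiled op:
    doc = "fwd(q, k, v, out, softmax_scale, is_causal) -> Tuple[Tensor, Tensor]"
    assert count_args_from_doc(doc) == 6  # six comma-separated names inside the parentheses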
+    # returns: out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p
+    @torch.library.custom_op(
+        "mslk_flash3::flash_fwd", mutates_args=(), device_types=["cuda"]
+    )
+    def mha_fwd(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor],
+        cu_seqlens_k: Optional[torch.Tensor],
+        seqused_k: Optional[torch.Tensor],
+        leftpad_k: Optional[torch.Tensor],
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        p: float,
+        softmax_scale: float,
+        is_causal: bool,
+        descale_q: Optional[torch.Tensor] = None,
+        descale_k: Optional[torch.Tensor] = None,
+        descale_v: Optional[torch.Tensor] = None,
+        block_table: Optional[torch.Tensor] = None,
+        use_kvsplit: bool = False,
+        window_left: int = -1,
+        window_right: int = -1,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query, key = [maybe_contiguous(x) for x in (query, key)]
+        value = (
+            value.contiguous()
+            if value.stride(-1) != 1 and value.stride(-3) != 1
+            else value
+        )
+        cu_seqlens_q, cu_seqlens_k, seqused_k = [
+            maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, seqused_k)
+        ]
+        block_table = maybe_contiguous(block_table)
+
+        def _get_batch():
+            if cu_seqlens_q is not None:
+                return cu_seqlens_q.shape[0] - 1
+            return query.shape[0]
+
+        is_paged = block_table is not None
+        bs = _get_batch()
+        orig_query_shape = query.shape
+
+        pack_gqa = None
+        if use_kvsplit:
+            # For KV split, we need to make sure query is in shape [batch, seqlen, num_heads, head_dim_q]
+            # to be compatible with the `pack_gqa` feature
+            query = query.view(bs, -1, query.shape[-2], query.shape[-1])
+            cu_seqlens_q = None
+
+            # Auto-detect if we should use GQA parallel mode
+            if query.shape[1] <= 64 and query.shape[2] != key.shape[2]:
+                pack_gqa = True
+
+        assert _C_flashattention3 is not None
+        out, softmax_lse, *rest = _C_flashattention3.fwd(
+            query,
+            key,
+            value,
+            None,
+            None,  # k_new, v_new
+            None,  # qv
+            None,  # out
+            cu_seqlens_q,
+            cu_seqlens_k if not is_paged else None,
+            None,  # cu_seqlens_k_new
+            None,  # seqused_q
+            seqused_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            block_table,
+            None,  # kv_batch_idx
+            leftpad_k,
+            None,  # rotary_cos
+            None,  # rotary_sin
+            None,  # seqlens_rotary
+            descale_q,
+            descale_k,
+            descale_v,
+            softmax_scale,
+            is_causal,
+            window_left,
+            window_right,
+            0.0,  # softcap
+            not use_kvsplit,  # rotary_interleaved
+            None,  # scheduler_metadata
+            1 if not use_kvsplit else 0,  # num_splits
+            pack_gqa,  # pack_gqa
+            0,  # sm_margin
+        )
+
+        if query.shape != orig_query_shape:
+            # Reshape softmax_lse to match expected output format
+            num_heads_q = query.shape[-2]
+            orig_lse_shape = softmax_lse.shape
+            softmax_lse = softmax_lse.view(
+                orig_lse_shape[0], num_heads_q, -1, orig_lse_shape[2]
+            )
+            softmax_lse = softmax_lse.permute(1, 0, 2, 3).reshape(num_heads_q, -1)
+
+        return out, softmax_lse
+
+    @torch.library.register_fake("mslk_flash3::flash_fwd")
+    def mha_fwd_fake(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor],
+        cu_seqlens_k: Optional[torch.Tensor],
+        seqused_k: Optional[torch.Tensor],
+        leftpad_k: Optional[torch.Tensor],
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        p: float,
+        softmax_scale: float,
+        is_causal: bool,
+        descale_q: Optional[torch.Tensor] = None,
+        descale_k: Optional[torch.Tensor] = None,
+        descale_v: Optional[torch.Tensor] = None,
+        block_table: Optional[torch.Tensor] = None,
+        use_kvsplit: bool = False,
+        window_left: int = -1,
+        window_right: int = -1,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query_shape = query.shape
+        out_shape = (*query_shape[:-1], value.shape[-1])
+        if query.dtype == torch.float8_e4m3fn or query.dtype == torch.float8_e5m2:
+            out = query.new_empty(out_shape, dtype=torch.bfloat16)
+        else:
+            out = query.new_empty(out_shape)
+        # Query is (B, M, H, K) or (total_M, H, K)
+        # LSE is (B, H, M) or (H, total_M)
+        lse_shape = (
+            (query_shape[0], query_shape[2], query_shape[1])
+            if cu_seqlens_q is None
+            else (query_shape[1], query_shape[0])
+        )
+        lse = query.new_empty(lse_shape, dtype=torch.float32)
+        return out, lse
+
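The pair of registrations above follows the standard torch.library pattern: `custom_op` wraps the opaque CUDA kernel, and `register_fake` supplies a shape/dtype-only implementation so the op can be traced with meta/fake tensors (e.g. under torch.compile). A minimal self-contained sketch of the same pattern with a toy op (names are made up; assumes PyTorch 2.4+, not the mslk kernels):

    import torch

    @torch.library.custom_op("demo_lib::scale", mutates_args=(), device_types=["cpu"])
    def scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
        # "Real" kernel: runs only on the registered device types.
        return x * alpha

    @torch.library.register_fake("demo_lib::scale")
    def scale_fake(x: torch.Tensor, alpha: float) -> torch.Tensor:
        # Fake kernel: allocates an output of the right shape/dtype without computing.
        return torch.empty_like(x)

    print(torch.ops.demo_lib.scale(torch.ones(4), 2.0))  # real result: tensor([2., 2., 2., 2.])
    meta_out = torch.ops.demo_lib.scale(torch.ones(4, device="meta"), 2.0)
    print(meta_out.shape)  # fake kernel used, no data computed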
+    @register_flop_formula(torch.ops.mslk_flash3.flash_fwd, get_raw=True)
+    def mha_fwd_flops(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor],
+        cu_seqlens_k: Optional[torch.Tensor],
+        seqused_k: Optional[torch.Tensor],
+        leftpad_k: Optional[torch.Tensor],
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        p: float,
+        softmax_scale: float,
+        is_causal: bool,
+        descale_q: Optional[torch.Tensor] = None,
+        descale_k: Optional[torch.Tensor] = None,
+        descale_v: Optional[torch.Tensor] = None,
+        block_table: Optional[torch.Tensor] = None,
+        use_kvsplit: bool = False,
+        window_left: int = -1,
+        window_right: int = -1,
+        # The FLOPs counter might pass more args (out_val, out_shape, ...)
+        *args,
+        **kwargs,
+    ):
+        assert 3 <= query.ndim <= 4
+        assert 3 <= key.ndim <= 4
+        assert 3 <= value.ndim <= 4
+        # This FLOP formula is used by torch.compile's partitioner "automatic
+        # activation checkpointing" (AutoAC) to decide which ops to preserve
+        # for backward or to recompute. However, this formula is data-dependent!
+        # This makes all invocations reuse the choices made based on the first
+        # inputs, which may be sub-optimal but also lead to inconsistent
+        # behavior across runs. In the presence of tensor parallelism it might
+        # also lead to deadlocks if AutoAC recomputes different collectives
+        # on different ranks. For distributed jobs it seems more robust to have
+        # all ranks always use the "worst case" FLOP estimate. Ranks are in
+        # lockstep anyways and will be going as fast as the slowest one.
+        if os.environ.get("XFORMERS_FLOP_FORMULA_WORST_CASE", "0") == "1":
+            cu_seqlens_q = cu_seqlens_k = max_seqlen_q = max_seqlen_k = None  # type: ignore[assignment]
+            query = query.unsqueeze(0) if query.ndim == 3 else query
+            key = key.unsqueeze(0) if key.ndim == 3 else key
+            value = value.unsqueeze(0) if value.ndim == 3 else value
+        sizes = _unpack_flash_attention_nested_shapes(
+            query=query.transpose(-2, -3) if query.ndim == 4 else query,
+            key=key.transpose(-2, -3) if key.ndim == 4 else key,
+            value=value.transpose(-2, -3) if value.ndim == 4 else value,
+            cum_seq_q=cu_seqlens_q,
+            cum_seq_k=cu_seqlens_k,
+            max_q=max_seqlen_q,
+            max_k=max_seqlen_k,
+        )
+        if is_causal:
+            window_right = 0
+        res = sum(
+            sdpa_flop_count(
+                query_shape,
+                key_shape,
+                value_shape,
+                window_left=window_left,
+                window_right=window_right,
+            )
+            for query_shape, key_shape, value_shape, _ in sizes
+        )
+        return res
+
+    def _create_dq_dk_dv(
+        grads_share_storage: bool, query, key, value
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Create dq,dk,dv
+        # If Q/K/V come from a single QKV tensor, let's put the gradient in the
+        # right strides, so we can avoid a `cat`
+        if grads_share_storage:
+            chunk = torch.empty(
+                (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
+                dtype=query.dtype,
+                device=query.device,
+            )
+            return chunk.select(-3, 0), chunk.select(-3, 1), chunk.select(-3, 2)
+        return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+
+    @torch.library.custom_op(
+        "mslk_flash3::flash_bwd", mutates_args=(), device_types=["cuda"]
+    )
+    def mha_bwd(
+        grads_share_storage: bool,
+        dout: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+        softmax_lse: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        softmax_scale: float,
+        is_causal: bool,
+        window_left: int,
+        window_right: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value)
+        is_deterministic = False
+        if cu_seqlens_q is None:
+            assert cu_seqlens_k is None
+
+        assert _C_flashattention3 is not None
+        dq, dk, dv, softmax_d, *rest = _C_flashattention3.bwd(
+            dout,
+            query,
+            key,
+            value,
+            out,
+            softmax_lse,
+            dq,
+            dk,
+            dv,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            None,  # seqused_q
+            None,  # seqused_k
+            max_seqlen_q,
+            max_seqlen_k,
+            softmax_scale,
+            is_causal,
+            window_left,
+            window_right,
+            0.0,  # not used, softcap
+            is_deterministic,
+            0,  # not used, sm_margin
+        )
+        return dq, dk, dv
+
+    @torch.library.register_fake("mslk_flash3::flash_bwd")
+    def mha_bwd_fake(
+        grads_share_storage: bool,
+        dout: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+        softmax_lse: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        softmax_scale: float,
+        is_causal: bool,
+        window_left: int,
+        window_right: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return _create_dq_dk_dv(grads_share_storage, query, key, value)
+
+    @register_flop_formula(torch.ops.mslk_flash3.flash_bwd, get_raw=True)
+    def mha_bwd_flops(
+        grads_share_storage: bool,
+        dout: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+        softmax_lse: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        softmax_scale: float,
+        is_causal: bool,
+        window_left: int,
+        window_right: int,
+        # The FLOPs counter might pass more args (out_val, out_shape, ...)
+        *args,
+        **kwargs,
+    ):
+        return (
+            5
+            * mha_fwd_flops(
+                query,
+                key,
+                value,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                seqused_k=None,
+                leftpad_k=None,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                p=0.0,
+                softmax_scale=1.0,
+                is_causal=is_causal,
+                descale_q=None,
+                descale_k=None,
+                descale_v=None,
+                block_table=None,
+                use_kvsplit=False,
+                window_left=window_left,
+                window_right=window_right,
+            )
+            // 2
+        )
+
+
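The backward formula above simply rescales the forward count. The forward pass performs two matmuls (Q@K^T and attn@V), both already included in mha_fwd_flops, while a FlashAttention-style backward performs roughly five matmuls of the same shape: recomputing Q@K^T, then dV = P^T@dO, dP = dO@V^T, dQ = dS@K and dK = dS^T@Q. Multiplying the forward estimate by 5 and dividing by 2 therefore approximates the backward cost at about 2.5x the forward, which is the conventional estimate for attention backward FLOPs.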
+@register_operator
+class FwOp(AttentionFwOpBase):
+    """Operator that computes memory-efficient attention using \
+    `Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
+    implementation.
+    """
+
+    OPERATOR = get_operator("mslk_flash3", "flash_fwd")
+    SUPPORTED_DEVICES: Set[str] = {"cuda"}
+    CUDA_MINIMUM_COMPUTE_CAPABILITY = (9, 0)
+    CUDA_MAXIMUM_COMPUTE_CAPABILITY = (9, 0)
+    SUPPORTED_DTYPES: Set[torch.dtype] = {
+        torch.half,
+        torch.bfloat16,
+        torch.float8_e4m3fn,
+    }
+    SUPPORTED_MAX_K = 256
+    SUPPORTED_MIN_K = 64
+    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+        type(None),
+        LowerTriangularMask,
+        LowerTriangularFromBottomRightMask,
+        LowerTriangularFromBottomRightLocalAttentionMask,
+        BlockDiagonalMask,
+        BlockDiagonalCausalMask,
+        BlockDiagonalCausalLocalAttentionMask,
+        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+        BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+        BlockDiagonalCausalFromBottomRightMask,
+        BlockDiagonalCausalWithOffsetGappyKeysMask,
+        BlockDiagonalCausalWithOffsetPaddedKeysMask,
+        BlockDiagonalLocalAttentionPaddedKeysMask,
+        BlockDiagonalGappyKeysMask,
+        BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask,
+        BlockDiagonalPaddedKeysMask,
+        LocalAttentionFromBottomRightMask,
+        PagedBlockDiagonalCausalLocalPaddedKeysMask,
+        PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
+        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+        PagedBlockDiagonalGappyKeysMask,
+        PagedBlockDiagonalPaddedKeysMask,
+    )
+
+    SUPPORTS_DROPOUT = False
+    SUPPORTS_CUSTOM_SCALE = True
+    SUPPORTS_DIFFERENT_VALUE_EMBED = False
+    SUPPORTS_BMGHK = True
+    SUPPORTS_PARTIAL = True
+    UNPADDED_LSE = True
+    NAME = f"fa3F@{FLASH_VERSION}"
+    VERSION = FLASH_VERSION
+
+    @classmethod
+    def not_supported_reasons(cls, d: Inputs) -> List[str]:
+        reasons = super(FwOp, cls).not_supported_reasons(d)
+        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
+        if d.query.shape[-1] not in [64, 128, 192, 256]:
+            reasons.append("only head-dim 64, 128, 192 or 256 is supported")
+
+        _check_needs_no_topleft(d, reasons)
+
+        return reasons
+
+    @classmethod
+    def apply(
+        cls,
+        inp: Inputs,
+        needs_gradient: bool,
+        use_kvsplit: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Context]]:
+        original_query_shape = inp.query.shape
+        out_shape = [
+            *inp.query.shape[:-1],
+            inp.value.shape[-1],
+        ]
+
+        def unpack_func(x) -> Tuple[torch.Tensor, Any]:
+            return x.unpack() if isinstance(x, ScaledTensor) else (x, None)
+
+        inp.query, descale_q = unpack_func(inp.query)
+        inp.key, descale_k = unpack_func(inp.key)
+        inp.value, descale_v = unpack_func(inp.value)
+        (
+            inp,
+            cu_seqlens_q,
+            max_seqlen_q,
+            cu_seqlens_k,
+            max_seqlen_k,
+            seqused_k,
+        ) = _convert_input_format(inp, supports_mqa=True, use_kvsplit=use_kvsplit)
+
+        q = inp.query
+        k = inp.key
+        v = inp.value
+
+        if inp.query.numel() > 0 and inp.key.numel() > 0:
+            win_left, win_right = _window_size(inp.attn_bias)
+            block_tables = (
+                inp.attn_bias.block_tables
+                if isinstance(
+                    inp.attn_bias,
+                    (PagedBlockDiagonalGappyKeysMask, PagedBlockDiagonalPaddedKeysMask),
+                )
+                else None
+            )
+            leftpad_k = None
+            if isinstance(inp.attn_bias, PagedBlockDiagonalGappyKeysMask):
+                assert cu_seqlens_q is not None
+                assert cu_seqlens_k is not None
+                if len(cu_seqlens_q) == len(cu_seqlens_k):
+                    # case #1: len(cu_seqlens_k) = batch_size + 1
+                    leftpad_k = cu_seqlens_k[:-1]
+                else:
+                    # case #2: len(cu_seqlens_k) = batch_size
+                    assert len(cu_seqlens_q) - len(cu_seqlens_k) == 1, (
+                        f"{len(cu_seqlens_q)=} {len(cu_seqlens_k)=}"
+                    )
+                    leftpad_k = cu_seqlens_k
+            out, softmax_lse = cls.OPERATOR(
+                q,
+                k,
+                v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                seqused_k=seqused_k,
+                leftpad_k=leftpad_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                p=inp.p,
+                softmax_scale=inp.scale_float,
+                is_causal=_is_causal(inp.attn_bias),
+                descale_q=descale_q,
+                descale_k=descale_k,
+                descale_v=descale_v,
+                block_table=block_tables,
+                use_kvsplit=use_kvsplit,
+                window_left=win_left,
+                window_right=win_right,
+            )
+            out = out.reshape(out_shape)
+        else:
+            out = torch.zeros(
+                inp.query.shape, device=inp.query.device, dtype=inp.query.dtype
+            )
+            if inp.is_partial:
+                softmax_lse = torch.full(
+                    [inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]],
+                    float("-inf"),
+                    device=inp.query.device,
+                    dtype=torch.float32,
+                )
+            else:
+                softmax_lse = torch.empty(
+                    [inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]],
+                    device=inp.query.device,
+                    dtype=torch.float32,
+                )
+
+        ctx = Context(
+            out=out,
+            lse=softmax_lse,
+        )
+
+        if not needs_gradient:
+            return out, None
+        ctx = Context(
+            out=out,
+            lse=_post_process_lse(softmax_lse, inp, tuple(original_query_shape)),
+        )
+        return (out, ctx)
+
+
+@register_operator
+class BwOp(AttentionBwOpBase):
+    __doc__ = FwOp.__doc__
+
+    OPERATOR = get_operator("mslk_flash3", "flash_bwd")
+    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
+    CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
+    CUDA_MAXIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MAXIMUM_COMPUTE_CAPABILITY
+    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
+    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
+    SUPPORTED_MIN_K = FwOp.SUPPORTED_MIN_K
+    SUPPORTED_ATTN_BIAS_TYPES = (
+        # Exclude padded or gappy masks, since seqused_k is not supported by the kernel.
+        type(None),
+        LowerTriangularMask,
+        LowerTriangularFromBottomRightMask,
+        LowerTriangularFromBottomRightLocalAttentionMask,
+        BlockDiagonalMask,
+        BlockDiagonalCausalMask,
+        BlockDiagonalCausalLocalAttentionMask,
+        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+        BlockDiagonalCausalFromBottomRightMask,
+        LocalAttentionFromBottomRightMask,
+    )
+
+    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
+    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
+    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
+    IS_DETERMINISTIC = False
+    SUPPORTS_BMGHK = False
+    SUPPORTS_LSE_FORMATS: Sequence[str] = ["", "varlen_flat"]
+    NAME = f"fa3B@{FLASH_VERSION}"
+    VERSION = FLASH_VERSION
+
+    @classmethod
+    def not_supported_reasons(cls, d: Inputs) -> List[str]:
+        reasons = super(BwOp, cls).not_supported_reasons(d)
+        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
+        _check_needs_no_topleft(d, reasons)
+        if d.query.shape[-1] not in [64, 128, 192, 256]:
+            reasons.append("only head-dim 64, 128, 192 or 256 is supported")
+
+        _check_needs_no_topleft(d, reasons)
+        return reasons
+
+    @classmethod
+    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
+        dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
+        (
+            inp,
+            cu_seqlens_q,
+            max_seqlen_q,
+            cu_seqlens_k,
+            max_seqlen_k,
+            _,  # seqused_k,
+        ) = _convert_input_format(inp, supports_mqa=False)
+        ctx_lse = ctx.lse
+
+        if isinstance(inp.attn_bias, VARLEN_BIASES):
+            assert ctx_lse.shape[0] == 1
+            ctx_lse = ctx_lse[0]
+        else:
+            # NOTE: cutlass pads the last dimension, we need to slice it
+            assert ctx_lse.shape[2] >= max_seqlen_q
+            ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
+
+        kernel_out_shape = [
+            *inp.query.shape[:-1],
+            inp.value.shape[-1],
+        ]
+        assert grad.dtype in cls.SUPPORTED_DTYPES
+
+        if inp.query.numel() and inp.key.numel():
+            win_left, win_right = _window_size(inp.attn_bias)
+            dq, dk, dv = cls.OPERATOR(
+                ctx.qkv_share_storage,
+                grad.reshape(kernel_out_shape).contiguous(),
+                inp.query,
+                inp.key,
+                inp.value,
+                ctx.out.reshape(kernel_out_shape),
+                ctx.lse,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                window_left=win_left,
+                window_right=win_right,
+                softmax_scale=inp.scale_float,
+                is_causal=_is_causal(inp.attn_bias),
+            )
+            grads = Gradients(dq, dk, dv)
+        else:
+            grads = Gradients(
+                dq=torch.zeros_like(inp.query),
+                dk=torch.zeros_like(inp.key),
+                dv=torch.zeros_like(inp.value),
+            )
+
+        grads.dq = grads.dq.reshape(dq_shape)
+        grads.dk = grads.dk.reshape(dk_shape)
+        grads.dv = grads.dv.reshape(dv_shape)
+        return grads
+
+
+@register_operator
+class FwOp_KVSplit(FwOp):
+    """Operator that computes memory-efficient attention using \
+    `Flash-Attention3 <https://github.com/Dao-AILab/flash-attention/tree/main/hopper>`_ \
+    implementation, with heuristic rules to dispatch decoding shapes to KVSplit attention.
+    """
+
+    NAME = f"fa3F_splitKV@{FLASH_VERSION}"
+    enable_kvsplit_attn: bool = True
+
+    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+        type(None),
+        BlockDiagonalCausalWithOffsetPaddedKeysMask,
+        BlockDiagonalPaddedKeysMask,
+        BlockDiagonalCausalWithOffsetGappyKeysMask,
+        BlockDiagonalGappyKeysMask,
+        BlockDiagonalLocalAttentionPaddedKeysMask,
+        PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
+        PagedBlockDiagonalGappyKeysMask,
+        PagedBlockDiagonalPaddedKeysMask,
+        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+    )
+
+    @classmethod
+    def apply(  # type: ignore[override]
+        cls,
+        inp: Inputs,
+        needs_gradient: bool,
+    ) -> Tuple[torch.Tensor, Optional[Context]]:
+        use_kvsplit = _heuristic_kvsplit(inp, cls.enable_kvsplit_attn)
+
+        return super().apply(inp, needs_gradient, use_kvsplit)
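For orientation, the three operators registered in this file are meant to be selected through the fmha dispatcher rather than called directly. A minimal usage sketch, assuming the wheel is installed on an SM90 (Hopper) GPU and that mslk.attention.fmha exposes an xformers-style memory_efficient_attention entry point accepting an explicit op=(forward, backward) pair (both assumptions, not confirmed by this diff):

    import torch
    from mslk.attention import fmha
    from mslk.attention.fmha import flash3
    from mslk.attention.fmha.attn_bias import LowerTriangularMask

    q = torch.randn(2, 4096, 16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn(2, 4096, 16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    v = torch.randn(2, 4096, 16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Causal attention, explicitly selecting the FA3 forward/backward ops defined above.
    out = fmha.memory_efficient_attention(
        q, k, v, attn_bias=LowerTriangularMask(), op=(flash3.FwOp, flash3.BwOp)
    )
    out.sum().backward()  # dispatches to mslk_flash3::flash_bwd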