mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116):
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,862 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-unsafe
7
+
8
+
9
+ import os
10
+ from itertools import zip_longest
11
+ from typing import Any, Iterable, List, Optional, Set, Tuple, Union
12
+
13
+ import torch
14
+
15
+ from .attn_bias import (
16
+ AttentionBias,
17
+ BlockDiagonalCausalFromBottomRightMask,
18
+ BlockDiagonalCausalLocalAttentionFromBottomRightMask,
19
+ BlockDiagonalCausalLocalAttentionMask,
20
+ BlockDiagonalCausalLocalAttentionPaddedKeysMask,
21
+ BlockDiagonalCausalMask,
22
+ BlockDiagonalCausalWithOffsetGappyKeysMask,
23
+ BlockDiagonalCausalWithOffsetPaddedKeysMask,
24
+ BlockDiagonalGappyKeysMask,
25
+ BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask,
26
+ BlockDiagonalLocalAttentionPaddedKeysMask,
27
+ BlockDiagonalMask,
28
+ BlockDiagonalPaddedKeysMask,
29
+ LocalAttentionFromBottomRightMask,
30
+ LowerTriangularFromBottomRightLocalAttentionMask,
31
+ LowerTriangularFromBottomRightMask,
32
+ LowerTriangularMask,
33
+ PagedBlockDiagonalCausalLocalPaddedKeysMask,
34
+ PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
35
+ PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
36
+ PagedBlockDiagonalGappyKeysMask,
37
+ PagedBlockDiagonalPaddedKeysMask,
38
+ VARLEN_BIASES,
39
+ )
40
+ from .common import (
41
+ AttentionBwOpBase,
42
+ AttentionFwOpBase,
43
+ check_lastdim_alignment_stride1,
44
+ Context,
45
+ Gradients,
46
+ Inputs,
47
+ )
48
+ from .torch_attention_compat import is_pt_flash_old
49
+ from .utils.op_common import get_operator, register_operator
50
+
51
# Version string of the flash-attention backend selected below; it stays
# "0.0.0" if no backend assignment happens.
FLASH_VERSION = "0.0.0"
# True when the backend returns varlen logsumexp packed as [num_heads, total_q]
# instead of padded as [batch, num_heads, max_seqlen_q].
VARLEN_LSE_PACKED = False
# True when falling back to an older PyTorch flash kernel whose forward returns
# the philox seed and offset as separate tensors rather than one rng_state.
pt_flash_is_old = False
# PyTorch's built-in flash kernels are only tried as a fallback on non-HIP
# (non-ROCm) builds.
_TRY_PT_FLASH_ATTN = torch.version.hip is None
# Set when PyTorch's aten flash kernels end up being the backend.
_USE_PT_FLASH_ATTN = False
56
+
57
try:  # noqa: C901
    # Backend selection, in order of preference:
    #   1. xformers' bundled flash-attention extension,
    #   2. the standalone `flash_attn` package (version-checked),
    #   3. PyTorch's aten flash kernels (only when _TRY_PT_FLASH_ATTN).
    try:
        from xformers import _C_flashattention  # type: ignore[attr-defined]

        try:
            from xformers._cpp_lib import _build_metadata  # type: ignore[attr-defined]

            if _build_metadata is not None:
                FLASH_VERSION = _build_metadata.flash_version
        except ImportError:
            FLASH_VERSION = "unknown"

        VARLEN_LSE_PACKED = True
    except ImportError:
        try:
            import flash_attn
            import flash_attn.flash_attn_interface

            # The compiled extension module was renamed across flash_attn releases.
            if hasattr(flash_attn.flash_attn_interface, "flash_attn_cuda"):
                _C_flashattention = flash_attn.flash_attn_interface.flash_attn_cuda  # type: ignore[attr-defined]
            else:
                _C_flashattention = flash_attn.flash_attn_interface.flash_attn_gpu  # type: ignore[attr-defined]

            FLASH_VERSION = flash_attn.__version__
            FLASH_VER_MIN = (2, 6, 3)
            FLASH_VER_LAST = (2, 8, 3)  # last supported, inclusive
            flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
            # An out-of-range version is turned into an ImportError so that the
            # handler below falls back to PyTorch's kernels (when allowed).
            if (
                flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
            ) and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1":
                raise ImportError(
                    f"Requires Flash-Attention version >={'.'.join([str(i) for i in FLASH_VER_MIN])},"
                    f"<={'.'.join([str(i) for i in FLASH_VER_LAST])} "
                    f"but got {FLASH_VERSION}."
                )
            VARLEN_LSE_PACKED = True
        except ImportError as e:
            if not _TRY_PT_FLASH_ATTN:
                raise e
            pt_flash_is_old = is_pt_flash_old(force=True) is True
            FLASH_VERSION = torch.nn.attention._get_flash_version()  # type: ignore
            VARLEN_LSE_PACKED = not pt_flash_is_old
            _USE_PT_FLASH_ATTN = True

    @torch.library.custom_op(
        "mslk_flash::flash_fwd",
        mutates_args=(),
        device_types=["cuda"],
    )
    def _flash_fwd(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        seqused_k: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        p: float,
        softmax_scale: float,
        is_causal: bool,
        window_left: int,
        window_right: int,
        return_softmax: bool,
        block_tables: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Flash-attention forward pass; returns ``(out, softmax_lse, rng_state)``.

        Uses PyTorch's ``aten._flash_attention_forward`` when
        ``_USE_PT_FLASH_ATTN`` is set. Otherwise it calls
        ``_C_flashattention.fwd`` for dense inputs, or ``.varlen_fwd`` when
        ``cu_seqlens_q`` is given.
        """
        softcap = 0.0
        if _USE_PT_FLASH_ATTN:
            # NOTE(review): return_softmax and block_tables are not forwarded on
            # this path — confirm paged KV inputs never reach it.
            ret = torch.ops.aten._flash_attention_forward(
                query,
                key,
                value,
                cu_seqlens_q,  # cum_seq_q
                cu_seqlens_k,  # cum_seq_k
                max_seqlen_q,  # max_q
                max_seqlen_k,  # max_k
                p,  # dropout_p
                is_causal,
                return_debug_mask=False,
                scale=softmax_scale,
                window_size_left=window_left,
                window_size_right=window_right,
                seqused_k=seqused_k,
                alibi_slopes=None,  # alibi_slopes
            )
            if pt_flash_is_old:
                # Older PyTorch returns the philox seed/offset separately; pack
                # them into one tensor so callers always get a single rng_state.
                (
                    attention,
                    logsumexp,
                    philox_seed,
                    philox_offset,
                    _,
                ) = ret
                rng_state = torch.stack([philox_seed, philox_offset])
            else:
                attention, logsumexp, rng_state, _, _ = ret
            return attention, logsumexp, rng_state
        else:
            # The kernel's third output is unused here. It is bound to its own
            # name so it no longer shadows the dropout probability ``p``.
            if cu_seqlens_q is None:
                assert cu_seqlens_k is None
                assert seqused_k is None
                out, softmax_lse, _unused_softmax, rng_state = _C_flashattention.fwd(
                    query,
                    key,
                    value,
                    None,  # out
                    None,  # alibi_slopes
                    p,
                    softmax_scale,
                    is_causal,
                    window_left,  # window_size_left
                    window_right,  # window_size_right
                    softcap,
                    return_softmax,
                    None,  # rng
                )
            else:
                out, softmax_lse, _unused_softmax, rng_state = _C_flashattention.varlen_fwd(
                    query,
                    key,
                    value,
                    None,  # out
                    cu_seqlens_q,
                    cu_seqlens_k,
                    seqused_k,
                    None,  # leftpad_k_
                    block_tables,
                    None,  # alibi_slopes
                    max_seqlen_q,
                    max_seqlen_k,
                    p,
                    softmax_scale,
                    False,
                    is_causal,
                    window_left,
                    window_right,
                    softcap,
                    return_softmax,
                    None,  # gen
                )
            return out, softmax_lse, rng_state

    @torch.library.register_fake("mslk_flash::flash_fwd")
    def _flash_fwd_abstract(
        query,
        key,
        value,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        p,
        softmax_scale,
        is_causal,
        window_left,
        window_right,
        return_softmax,
        block_tables,
    ):
        """Shape-only (fake) implementation of ``mslk_flash::flash_fwd``."""
        out = torch.empty_like(query)
        if cu_seqlens_q is None:
            # Dense layout: query is [B, M, H, K]; LSE is [B, H, M].
            B, M, H, K = query.shape
            lse_shape = [B, H, M]
        else:
            # Varlen layout: query is [total_q, H, K].
            M, H, K = query.shape
            B = cu_seqlens_q.shape[0] - 1
            if VARLEN_LSE_PACKED:
                lse_shape = [H, M]
            else:
                lse_shape = [B, H, max_seqlen_q]
        softmax_lse = torch.empty(lse_shape, device=query.device, dtype=torch.float32)
        rng_state = torch.empty([2], device=query.device, dtype=torch.int64)
        return out, softmax_lse, rng_state

    @torch.library.custom_op(
        "mslk_flash::flash_bwd",
        mutates_args=(),
        device_types=["cuda"],
    )
    def _flash_bwd(
        grads_share_storage: bool,
        grad: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        lse: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        p: float,
        softmax_scale: float,
        is_causal: bool,
        window_left: int,
        window_right: int,
        rng_state: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Flash-attention backward pass; returns ``(dq, dk, dv)``.

        The cu_seqlens and rng_state arguments are Optional: the dense path
        passes ``None`` for the cu_seqlens, and callers pass ``None`` for
        ``rng_state`` when there is no dropout. ``custom_op`` derives the
        operator schema from these annotations, so a plain ``torch.Tensor``
        annotation would reject those ``None`` values at call time.
        """
        softcap = 0.0
        if _USE_PT_FLASH_ATTN:
            assert softcap == 0.0
            if rng_state is not None and pt_flash_is_old:
                # Older PyTorch expects seed and offset as separate tensors.
                rng_state0 = rng_state[0]
                rng_state1 = rng_state[1]
            else:
                rng_state0 = rng_state1 = rng_state
            dq, dk, dv = torch.ops.aten._flash_attention_backward(
                grad,
                query,
                key,
                value,
                out,
                lse,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                p,
                is_causal,
                rng_state0,
                rng_state1,
                scale=softmax_scale,
                window_size_left=window_left,
                window_size_right=window_right,
            )
        else:
            # The flash_attn kernels write the gradients into dq/dk/dv in place.
            dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value)
            if cu_seqlens_k is None:
                assert cu_seqlens_q is None
                _C_flashattention.bwd(
                    grad,
                    query,
                    key,
                    value,
                    out,
                    lse,
                    dq,
                    dk,
                    dv,
                    None,  # alibi_slopes
                    p,
                    softmax_scale,
                    is_causal,
                    window_left,
                    window_right,
                    softcap,
                    False,  # deterministic
                    None,
                    rng_state,
                )
            else:
                _C_flashattention.varlen_bwd(
                    grad,
                    query,
                    key,
                    value,
                    out,
                    lse,
                    dq,
                    dk,
                    dv,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    None,  # alibi_slopes
                    max_seqlen_q,
                    max_seqlen_k,
                    p,
                    softmax_scale,
                    False,  # zero_tensors
                    is_causal,
                    window_left,
                    window_right,
                    softcap,
                    False,  # deterministic
                    None,
                    rng_state,
                )
        return dq, dk, dv

    @torch.library.register_fake("mslk_flash::flash_bwd")
    def _flash_bwd_abstract(
        grads_share_storage,
        grad,
        query,
        key,
        value,
        *args,
        **kwargs,
    ):
        """Shape-only (fake) implementation of ``mslk_flash::flash_bwd``."""
        return _create_dq_dk_dv(grads_share_storage, query, key, value)

    def _create_dq_dk_dv(
        grads_share_storage: bool, query, key, value
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Allocate uninitialised gradient buffers for query, key and value.

        If Q/K/V come from a single QKV tensor (``grads_share_storage``), the
        three gradients are views into one ``[..., 3, H, K]`` buffer. This
        gives them the matching strides and avoids a later ``cat``. The
        shared-buffer path sizes all three from ``query``'s shape.
        """
        if grads_share_storage:
            chunk = torch.empty(
                (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
                dtype=query.dtype,
                device=query.device,
            )
            return chunk.select(-3, 0), chunk.select(-3, 1), chunk.select(-3, 2)
        return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

except ImportError:
    # No usable flash backend: the custom ops above are simply not defined.
    pass
366
+
367
+
368
def _convert_input_format(
    inp: Inputs,
    supports_mqa: bool,
    use_kvsplit: bool = False,
) -> Tuple[
    Inputs,
    Optional[torch.Tensor],
    int,
    Optional[torch.Tensor],
    int,
    Optional[torch.Tensor],
]:
    """Reshape ``inp`` into the tensor layout the flash kernels expect.

    * 5-D BMGHK inputs have their G/H dimensions folded into one head
      dimension. This requires ``supports_mqa``.
    * When ``supports_mqa`` is set and K/V have a broadcast (stride-0) head
      dimension, only one K/V head is kept (MQA).
    * For block-diagonal / padded / gappy masks, query, key and value are
      flattened to ``[batch * seqlen, heads, head_dim]``. The sequence
      boundaries are taken from the mask. Paged masks additionally view K/V
      as ``[num_pages, page_size, heads, head_dim]``.

    ``use_kvsplit`` is accepted for interface compatibility and is unused.

    Returns:
        ``(new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k,
        seqused_k)``. The ``cu_seqlen_*`` and ``seqused_k`` values are
        ``None`` for dense (non-varlen) inputs.
    """
    assert inp.query.ndim in [4, 5]
    query, key, value = inp.query, inp.key, inp.value
    batch = query.shape[0]
    seqlen_q = query.shape[1]
    seqlen_kv = key.shape[1]
    head_dim_q = query.shape[-1]
    head_dim_v = value.shape[-1]

    attn_bias = inp.attn_bias
    if isinstance(
        attn_bias,
        (
            BlockDiagonalMask,
            BlockDiagonalGappyKeysMask,
            BlockDiagonalPaddedKeysMask,
            PagedBlockDiagonalGappyKeysMask,
            PagedBlockDiagonalPaddedKeysMask,
        ),
    ):
        assert attn_bias.k_seqinfo.seqstart.device == inp.query.device
        cu_seqlen_k = attn_bias.k_seqinfo.seqstart
        cu_seqlen_q = attn_bias.q_seqinfo.seqstart
        max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
        max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
        # Only the padded/gappy key layouts pass per-sequence key counts
        # (k_seqinfo.seqlen) as seqused_k. BlockDiagonalMask is tested first,
        # matching the original branch order.
        seqused_k = (
            None
            if isinstance(attn_bias, BlockDiagonalMask)
            else attn_bias.k_seqinfo.seqlen
        )
    else:
        cu_seqlen_k = None
        cu_seqlen_q = None
        seqused_k = None
        max_seqlen_q = inp.query.shape[1]
        max_seqlen_k = inp.key.shape[1]

    if query.ndim == 5:  # BMGHK (GQA) input
        assert supports_mqa

        # Fold the group/head_in_group dimensions together
        def fold(x):
            # Either the head is replicated (stride 0): keep a single copy...
            if x.stride(3) == 0:
                return x[:, :, :, 0]
            # ...or merge G and H into one head dimension.
            return x.reshape(
                [
                    x.shape[0],
                    x.shape[1],
                    -1,
                    x.shape[4],
                ]
            )

        query = fold(query)
        key = fold(key)
        value = fold(value)
    # Optimize for MQA: if every K/V head is a broadcast of a single head,
    # hand the kernel just that one head.
    if supports_mqa and key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0:
        key = key[:, :, :1]
        value = value[:, :, :1]
    # Initially we have `query.shape = [batch, seqlen, num_heads, head_dim_q]`
    # We want format `[batch * seqlen, num_heads, head_dim_q]`
    if cu_seqlen_k is not None:
        query = query.reshape([batch * seqlen_q, -1, head_dim_q])
        key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
        value = value.reshape([batch * seqlen_kv, -1, head_dim_v])
        if isinstance(
            attn_bias,
            (PagedBlockDiagonalGappyKeysMask, PagedBlockDiagonalPaddedKeysMask),
        ):
            # Paged KV cache: split the flat key/value rows into pages.
            num_pages = value.shape[0] // attn_bias.page_size
            key = key.view(num_pages, attn_bias.page_size, *key.shape[1:])
            value = value.view(num_pages, attn_bias.page_size, *value.shape[1:])

    new_inp = Inputs(
        query=query,
        key=key,
        value=value,
        attn_bias=attn_bias,
        p=inp.p,
        scale=inp.scale,
        output_dtype=inp.output_dtype,
        is_partial=inp.is_partial,
    )
    return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k, seqused_k
468
+
469
+
470
def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
    """Tell whether ``attn_bias`` asks the kernel for causal masking.

    Only the mask classes listed below (and their subclasses) count as causal.
    ``None``, plain tensors and any other bias type give ``False``.
    """
    causal_bias_types = (
        LowerTriangularMask,
        LowerTriangularFromBottomRightMask,
        LowerTriangularFromBottomRightLocalAttentionMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        PagedBlockDiagonalCausalLocalPaddedKeysMask,
        BlockDiagonalCausalFromBottomRightMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        BlockDiagonalCausalLocalAttentionPaddedKeysMask,
        BlockDiagonalCausalWithOffsetGappyKeysMask,
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
        PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
    )
    return isinstance(attn_bias, causal_bias_types)
489
+
490
+
491
def _window_size(
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]],
) -> Tuple[int, int]:
    """Return the ``(window_left, window_right)`` pair passed to the kernels.

    ``-1`` on either side means "no limit". Causal local masks only set the
    left window, to ``_window_size - 1``. Two-sided local masks provide both
    sides directly. Both checks run in order, so the second one wins if a
    bias matches both.
    """
    causal_local_types = (
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        BlockDiagonalCausalLocalAttentionPaddedKeysMask,
        LowerTriangularFromBottomRightLocalAttentionMask,
        PagedBlockDiagonalCausalLocalPaddedKeysMask,
    )
    two_sided_local_types = (
        BlockDiagonalLocalAttentionPaddedKeysMask,
        LocalAttentionFromBottomRightMask,
        BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask,
    )

    left, right = -1, -1
    if isinstance(attn_bias, causal_local_types):
        left = attn_bias._window_size - 1
    if isinstance(attn_bias, two_sided_local_types):
        left, right = attn_bias.window_left, attn_bias.window_right
    return (left, right)
518
+
519
+
520
def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
    """Append a reason if ``d`` needs a top-left-aligned causal mask.

    Flash does not support TopLeft, so causal masks with TopLeft alignment
    are only accepted when each batch element has the same number of queries
    and keys. In that case top-left and bottom-right alignment coincide.

    Args:
        d: the attention inputs being checked.
        reasons: list of "not supported" messages; appended to in place.
    """
    attn_bias = d.attn_bias
    if isinstance(attn_bias, BlockDiagonalCausalMask):
        # zip_longest pads the shorter list with None, so a differing number
        # of query/key blocks also counts as a mismatch.
        block_starts = zip_longest(
            attn_bias.k_seqinfo.seqstart_py, attn_bias.q_seqinfo.seqstart_py
        )
        if any(k_start != q_start for k_start, q_start in block_starts):
            reasons.append(
                "Only support BlockDiagonalCausalMask if equal"
                " numbers of keys and queries"
            )
    elif isinstance(attn_bias, LowerTriangularMask):
        if d.query.shape[1] != d.key.shape[1]:
            reasons.append(
                "Only support LowerTriangularMask if equal number of keys and queries"
            )
541
+
542
+
543
def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
    """Record a reason if the G/H dims of a 5-D BMGHK tensor can't be folded.

    The G and H dimensions must be collapsible into one head dimension. This
    holds when there is a single group, a single head, a broadcast
    (stride-0) head axis, or when G's stride equals H's stride times the
    number of heads. Tensors that are not 5-D are accepted without checks.
    """
    if x.ndim != 5:
        return
    n_groups, n_heads = x.shape[2], x.shape[3]
    stride_g, stride_h = x.stride(2), x.stride(3)
    if n_groups == 1 or n_heads == 1 or stride_h == 0:
        return
    if stride_g == stride_h * n_heads:
        return
    reasons.append(
        f"GQA is only supported when the G/H dimensions are contiguous\n"
        f" {name}.stride: {x.stride()}\n"
        f" {name}.shape : {list(x.shape)}"
    )
559
+
560
+
561
def _post_process_lse(
    lse: torch.Tensor,
    inp: Inputs,
    original_query_shape: Tuple[int, ...],
) -> torch.Tensor:
    """Reshape the kernel's logsumexp back to the caller's layout.

    For 5-D (BMGHK) queries the folded G*H head dimension is split back into
    G and H. For varlen biases the result gains a leading batch dimension of
    1. In the padded, partial case it is also rearranged to
    ``[1, heads, B * max_seqlen]``. Non-partial padded varlen LSE is returned
    unchanged.
    """
    is_bmghk = len(original_query_shape) == 5
    group_head_dims = original_query_shape[2:4]

    if not isinstance(inp.attn_bias, VARLEN_BIASES):
        # [B, G*H, M] -> [B, G, H, M] for BMGHK inputs
        return lse.unflatten(1, group_head_dims) if is_bmghk else lse

    if VARLEN_LSE_PACKED:
        # Packed LSE only needs the batch dimension back:
        # (G*H, total_q) -> (1, G, H, total_q)  or  (H, total_q) -> (1, H, total_q)
        if is_bmghk:
            return lse.unflatten(0, group_head_dims).unsqueeze(0)
        return lse.unsqueeze(0)

    if not inp.is_partial:
        # Padded (B, H, M) layout is kept as-is.
        return lse

    # Padded (B, G*H, max_seqlen) -> (1, G*H, B*max_seqlen).
    # The flatten after the permute has to copy; it is not a view.
    lse_hkm = lse.permute(1, 0, 2).flatten(start_dim=1)[None]
    return lse_hkm.unflatten(1, group_head_dims) if is_bmghk else lse_hkm
591
+
592
+
593
@register_operator
class FwOp(AttentionFwOpBase):
    """Operator that computes memory-efficient attention using \
    `Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
    implementation.
    """

    OPERATOR = get_operator("mslk_flash", "flash_fwd")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
    SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 256
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
        type(None),
        LowerTriangularMask,
        LowerTriangularFromBottomRightMask,
        LowerTriangularFromBottomRightLocalAttentionMask,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        BlockDiagonalLocalAttentionPaddedKeysMask,
        BlockDiagonalCausalLocalAttentionPaddedKeysMask,
        BlockDiagonalCausalFromBottomRightMask,
        BlockDiagonalCausalWithOffsetGappyKeysMask,
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
        BlockDiagonalGappyKeysMask,
        BlockDiagonalPaddedKeysMask,
        LocalAttentionFromBottomRightMask,
        PagedBlockDiagonalCausalLocalPaddedKeysMask,
        PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
        PagedBlockDiagonalPaddedKeysMask,
    )

    SUPPORTS_DROPOUT = True
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_DIFFERENT_VALUE_EMBED = False
    SUPPORTS_BMGHK = True
    SUPPORTS_PARTIAL = True
    VARLEN_LSE_PACKED = VARLEN_LSE_PACKED
    NAME = f"fa2F@{FLASH_VERSION}-pt" if _USE_PT_FLASH_ATTN else f"fa2F@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Extend the base checks with flash-specific layout restrictions."""
        reasons = super(FwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        _check_strides_for_bmghk(d.query, "query", reasons)
        _check_strides_for_bmghk(d.key, "key", reasons)
        _check_strides_for_bmghk(d.value, "value", reasons)

        if (
            d.is_partial
            and not VARLEN_LSE_PACKED
            and isinstance(d.attn_bias, VARLEN_BIASES)
        ):
            q_seqinfo = d.attn_bias.q_seqinfo
            if q_seqinfo.min_seqlen != q_seqinfo.max_seqlen:
                # Flash provides padded LSE which we don't handle.
                reasons.append("partial attention with heterogeneous queries")

        if isinstance(
            d.attn_bias,
            (PagedBlockDiagonalGappyKeysMask, PagedBlockDiagonalPaddedKeysMask),
        ):
            if d.attn_bias.page_size % 256 != 0:
                reasons.append("Paged KV cache block size must be divisible by 256.")
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward kernel; returns ``(out, ctx)``.

        ``ctx`` is ``None`` when ``needs_gradient`` is false. When either
        query or key is empty, the kernel is skipped: the output is zeros and
        the LSE is ``-inf`` for partial attention, uninitialised otherwise.
        """
        return_softmax = False
        original_query_shape = inp.query.shape

        # The output keeps the caller's layout, with the value head dim last.
        out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
            seqused_k,
        ) = _convert_input_format(inp, supports_mqa=True)

        if inp.query.numel() > 0 and inp.key.numel() > 0:
            win_left, win_right = _window_size(inp.attn_bias)
            block_tables = (
                inp.attn_bias.block_tables
                if isinstance(inp.attn_bias, PagedBlockDiagonalPaddedKeysMask)
                else None
            )
            out, softmax_lse, rng_state = cls.OPERATOR(
                inp.query,
                inp.key,
                inp.value,
                cu_seqlens_q,
                cu_seqlens_k,
                seqused_k,
                max_seqlen_q,
                max_seqlen_k,
                inp.p,
                inp.scale_float,
                _is_causal(inp.attn_bias),
                window_left=win_left,
                window_right=win_right,
                return_softmax=return_softmax,
                block_tables=block_tables,
            )
            out = out.reshape(out_shape)
        else:
            out = torch.zeros(out_shape, device=inp.query.device, dtype=inp.query.dtype)
            rng_state = None
            # `inp` has already been converted, so its query is
            # [B, M, H, K] (dense) or [total_q, H, K] (varlen). The LSE shape
            # must match what the kernel would have produced (see the
            # flash_fwd fake implementation).
            if cu_seqlens_q is None:
                batch, seqlen_q, num_heads = inp.query.shape[:3]
                lse_shape = [batch, num_heads, seqlen_q]
            else:
                total_q, num_heads = inp.query.shape[:2]
                if VARLEN_LSE_PACKED:
                    lse_shape = [num_heads, total_q]
                else:
                    lse_shape = [cu_seqlens_q.shape[0] - 1, num_heads, max_seqlen_q]
            if inp.is_partial:
                # No keys means no contribution: -inf lets partial results merge.
                softmax_lse = torch.full(
                    lse_shape,
                    float("-inf"),
                    device=inp.query.device,
                    dtype=torch.float32,
                )
            else:
                softmax_lse = torch.empty(
                    lse_shape,
                    device=inp.query.device,
                    dtype=torch.float32,
                )

        if not needs_gradient:
            return out, None
        ctx = Context(
            out=out,
            lse=_post_process_lse(softmax_lse, inp, tuple(original_query_shape)),
        )
        if inp.p != 0.0:
            # Dropout must be replayed in backward with the same RNG state.
            ctx.op_bw = BwOp
            ctx.rng_state = rng_state
        return (out, ctx)
741
+
742
+
743
@register_operator
class BwOp(AttentionBwOpBase):
    # Backward counterpart of FwOp; it reuses FwOp's docstring and most of
    # its capability flags.
    __doc__ = FwOp.__doc__

    OPERATOR = get_operator("mslk_flash", "flash_bwd")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    # Same biases as the forward, minus the padded/gappy/paged key layouts
    # (apply() asserts that seqused_k is None).
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = tuple(
        set(FwOp.SUPPORTED_ATTN_BIAS_TYPES).difference(
            {
                BlockDiagonalCausalLocalAttentionPaddedKeysMask,
                BlockDiagonalCausalWithOffsetGappyKeysMask,
                BlockDiagonalCausalWithOffsetPaddedKeysMask,
                BlockDiagonalLocalAttentionPaddedKeysMask,
                BlockDiagonalGappyKeysMask,
                BlockDiagonalPaddedKeysMask,
                PagedBlockDiagonalCausalLocalPaddedKeysMask,
                PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
                PagedBlockDiagonalPaddedKeysMask,
            }
        )
    )
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    # The kernels are called with deterministic=False.
    IS_DETERMINISTIC = False
    SUPPORTS_BMGHK = False  # NOTE: Don't forget to update fmha doc when changing this!
    VARLEN_LSE_PACKED = VARLEN_LSE_PACKED
    NAME = f"fa2B@{FLASH_VERSION}-pt" if _USE_PT_FLASH_ATTN else f"fa2B@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    # Largest head dim allowed with dropout on GPUs other than SM80/SM90.
    MAX_HEADDIM_DROPOUT_SM8x = 224

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Extend the base checks with flash backward restrictions.

        Besides the alignment and top-left checks shared with FwOp, dropout
        with a head dim above MAX_HEADDIM_DROPOUT_SM8x is rejected on CUDA
        devices whose compute capability is not exactly (8, 0) or (9, 0).
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        if d.device.type == "cuda":
            # Due to limited shared-memory, some GPUs are limited in head dimension
            device_capability = torch.cuda.get_device_capability(d.device)
            is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
            # The check uses the larger of the key/query head dims, although
            # the message only mentions query.shape[-1].
            if (
                max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_DROPOUT_SM8x
                and not is_sm80_or_sm90
                and d.p != 0.0
            ):
                reasons.append(
                    "requires a GPU with compute capability 8.0 "
                    f"(A100) or 9.0 (H100) for dropout when 'query.shape[-1] > {cls.MAX_HEADDIM_DROPOUT_SM8x}'"
                )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Compute dq/dk/dv from the forward context and the output gradient.

        Returns gradients reshaped to the caller's original query/key/value
        shapes. Empty query or key short-circuits to all-zero gradients.
        """
        # Remember the caller's shapes to restore them on the gradients.
        dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
            seqused_k,
        ) = _convert_input_format(inp, supports_mqa=False)
        # assert ctx.lse.is_contiguous()
        # Padded/gappy key layouts are excluded from SUPPORTED_ATTN_BIAS_TYPES.
        assert seqused_k is None
        ctx_lse = ctx.lse
        if isinstance(inp.attn_bias, VARLEN_BIASES) and VARLEN_LSE_PACKED:
            # Packed varlen LSE carries a leading batch dim of 1
            # (_post_process_lse adds it with unsqueeze(0)); strip it.
            assert ctx_lse.shape[0] == 1
            ctx_lse = ctx_lse[0]
        else:
            # NOTE: cutlass pads the last dimension, we need to slice it
            assert ctx_lse.shape[2] >= max_seqlen_q
            ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
        kernel_out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]
        assert grad.dtype in cls.SUPPORTED_DTYPES

        if inp.query.numel() and inp.key.numel():
            win_left, win_right = _window_size(inp.attn_bias)
            grads = Gradients(
                *cls.OPERATOR(
                    ctx.qkv_share_storage,
                    grad.reshape(kernel_out_shape).contiguous(),
                    inp.query,
                    inp.key,
                    inp.value,
                    ctx.out.reshape(kernel_out_shape),
                    ctx_lse,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    max_seqlen_q,
                    max_seqlen_k,
                    inp.p,
                    inp.scale_float,
                    _is_causal(inp.attn_bias),
                    window_left=win_left,
                    window_right=win_right,
                    # FwOp.apply only stores rng_state on the context when p != 0.
                    rng_state=ctx.rng_state if inp.p > 0.0 else None,
                )
            )
        else:
            grads = Gradients(
                dq=torch.zeros_like(inp.query),
                dk=torch.zeros_like(inp.key),
                dv=torch.zeros_like(inp.value),
            )
        # Degenerate shapes: zero the gradients that could not receive any
        # contribution.
        if grads.dq.numel() == 0:
            grads.dk.zero_()
            grads.dv.zero_()
        if grads.dv.numel() == 0:
            grads.dq.zero_()
        grads.dq = grads.dq.reshape(dq_shape)
        grads.dk = grads.dk.reshape(dk_shape)
        grads.dv = grads.dv.reshape(dv_shape)
        return grads