mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,329 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-unsafe
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+
+ from .attn_bias import (
+     _GappySeqInfo,
+     _PaddedSeqLenInfo,
+     _SeqLenInfo,
+     AttentionBias,
+     BlockDiagonalCausalWithOffsetGappyKeysMask,
+     BlockDiagonalCausalWithOffsetPaddedKeysMask,
+     BlockDiagonalGappyKeysMask,
+     BlockDiagonalPaddedKeysMask,
+     PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
+     PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+     PagedBlockDiagonalGappyKeysMask,
+     PagedBlockDiagonalPaddedKeysMask,
+ )
+
+
+ def split_blocks_for_decoding_gpu_part(
+     input_bias: Union[
+         BlockDiagonalPaddedKeysMask, BlockDiagonalCausalWithOffsetPaddedKeysMask
+     ],
+     batchify_len: Optional[int],
+     block_tables: Optional[torch.Tensor] = None,
+     page_size: Optional[int] = None,
+ ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
+     """
+     This is the GPU part of split_blocks_for_decoding,
+     which can be called in advance.
+     """
+     if batchify_len is None:
+         return None
+     assert batchify_len > 0
+     assert input_bias.q_seqinfo.min_seqlen == input_bias.q_seqinfo.max_seqlen
+
+     seqstart = input_bias.k_seqinfo.seqstart  # (B+1,)
+     seqlen = input_bias.k_seqinfo.seqlen  # (B,)
+
+     # compute raw block boundaries
+     k_ends = seqstart[:-1] + seqlen  # (B,)
+     # For non-speculative decoding, we have a causal bias here,
+     # which will always be from-bottom-right style.
+     # Q and K are aligned so that their last tokens are at the same position.
+     # If seqlen == batchify_len, the first token of the query is at position batchify_len - 1,
+     # and it can attend to all keys from the previous iRoPE chunk.
+     # The diagram shows that when seqlen == batchify_len == N and the bias is causal,
+     # Q can still attend to K from the previous chunk.
+     # -----------iRoPE chunk 0---------|---------iRoPE chunk 1---------------
+     #                             Q[0] |
+     # K[0] K[1] K[2] ... K[N-2] K[N-1] |
+
+     # For speculative decoding, we use this function for the prefix bias only.
+     # We are called with a non-causal bias.
+     # The query is positioned after the keys, and so when seqlen == batchify_len,
+     # the first token of the query is at position batchify_len.
+     # So it can't attend to any key from the previous chunk,
+     # so we want k_starts == k_ends => k_lens == 0.
+     # The diagram shows that when seqlen == batchify_len == N and the bias is non-causal,
+     # Q is located entirely in the next iRoPE chunk and can't attend to K[0] ... K[N-1].
+     # ------------iRoPE chunk 0---------------|---------iRoPE chunk 1---------
+     #                                         | Q[0] Q[1] Q[2]
+     # K[0] K[1] K[2] ... K[N-3] K[N-2] K[N-1] |
+
+     shift = int(isinstance(input_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask))
+     k_starts = (k_ends - shift) // batchify_len * batchify_len
+     k_starts = torch.where(seqlen == 0, k_ends, k_starts)
+     k_lens = k_ends - k_starts
+
+     if block_tables is None:
+         k_seqstarts = torch.cat([k_starts, seqstart[-1:]])
+     else:
+         k_seqstarts = (k_starts - seqstart[:-1]).clamp(min=0)
+         k_lens = k_lens + k_seqstarts
+
+     return k_seqstarts, k_lens
+
+
+ def split_blocks_for_decoding(
+     input_bias: Union[
+         BlockDiagonalPaddedKeysMask, BlockDiagonalCausalWithOffsetPaddedKeysMask
+     ],
+     batchify_len: Optional[int],
+     block_tables: Optional[torch.Tensor] = None,
+     page_size: Optional[int] = None,
+     gpu_data: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Optional[Union[BlockDiagonalGappyKeysMask, PagedBlockDiagonalGappyKeysMask]]:
+     """
+     For decoding, when the query length is 1, we can represent an iRoPE-batchified bias as a gappy bias.
+     This function can also be applied to speculative decoding, when the query length is > 1
+     but the same across all batch elements. In this case we assume that the query (draft) lies entirely
+     in one block/subsequence and does not cross the boundary. Cases where the query crosses the boundary
+     need to be handled separately by the caller.
+     """
+     if batchify_len is None:
+         return None
+     assert batchify_len > 0
+     assert input_bias.q_seqinfo.min_seqlen == input_bias.q_seqinfo.max_seqlen
+
+     if gpu_data is None:
+         gpu_data = split_blocks_for_decoding_gpu_part(
+             input_bias, batchify_len, block_tables, page_size
+         )
+     assert gpu_data is not None
+     k_seqstarts, k_lens = gpu_data
+
+     k_seqstarts_list = []
+     k_seqlens_list = []
+     k_seqlens_list_actual = []
+     B = len(input_bias.k_seqinfo.seqlen_py)
+     # About the shift, see the comment in split_blocks_for_decoding_gpu_part.
+     shift = int(isinstance(input_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask))
+     for i in range(B):
+         input_k_start_ = input_bias.k_seqinfo.seqstart_py[i]
+         input_k_len_ = input_bias.k_seqinfo.seqlen_py[i]
+         input_k_end_ = input_k_start_ + input_k_len_
+         k_seqstart = (input_k_end_ - shift) // batchify_len * batchify_len
+         if input_k_len_ == 0:
+             k_seqstart = input_k_end_
+         k_seqend = min(k_seqstart + batchify_len, input_k_end_)
+         k_len = k_seqend - k_seqstart
+         # NOTE: With chunked, `k_len` cannot exceed the original length `input_k_len_`, so we clamp it here.
+         k_len = min(k_len, input_k_len_)
+
+         if k_seqstart < 0:
+             k_len = k_seqstart = 0
+         k_seqstart = (
+             k_seqstart if block_tables is None else max(k_seqstart - input_k_start_, 0)
+         )
+         k_seqstarts_list.append(k_seqstart)
+         k_seqlens_list_actual.append(k_len)
+         k_seqlens_list.append(k_len if block_tables is None else k_len + k_seqstart)
+
+     OutBiasType = (
+         BlockDiagonalCausalWithOffsetGappyKeysMask
+         if isinstance(input_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+         else BlockDiagonalGappyKeysMask
+     )
+     PagedOutBiasType = (
+         PagedBlockDiagonalCausalWithOffsetGappyKeysMask
+         if isinstance(input_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+         else PagedBlockDiagonalGappyKeysMask
+     )
+     if block_tables is None:
+         k_seqstarts_list.append(input_bias.k_seqinfo.seqstart_py[-1])
+         return OutBiasType(
+             q_seqinfo=input_bias.q_seqinfo,
+             k_seqinfo=_GappySeqInfo(
+                 seqstart_py=k_seqstarts_list,
+                 seqstart=k_seqstarts,
+                 seqlen=k_lens,
+                 seqlen_py=k_seqlens_list,
+                 min_seqlen=min(k_seqlens_list),
+                 max_seqlen=max(k_seqlens_list),
+             ),
+         )
+     assert page_size is not None
+     return PagedOutBiasType(
+         q_seqinfo=input_bias.q_seqinfo,
+         k_seqinfo=_GappySeqInfo(
+             seqstart_py=k_seqstarts_list,
+             seqstart=k_seqstarts,
+             seqlen=k_lens,
+             seqlen_py=k_seqlens_list,
+             min_seqlen=min(k_seqlens_list_actual),
+             max_seqlen=max(k_seqlens_list_actual),
+         ),
+         block_tables=block_tables,
+         page_size=page_size,
+     )
+
+
+ def split_blocks_for_prefill(
+     input_bias: BlockDiagonalPaddedKeysMask, batchify_len: Optional[int]
+ ) -> Optional[BlockDiagonalPaddedKeysMask]:
+     """
+     From
+     https://github.com/fairinternal/llm_inference/blob/11bbb2/llm_inference/models/disagg_transformer.py#L1955
+     """
+     if batchify_len is None:
+         return None
+     padding = input_bias.k_seqinfo.padding
+     assert padding % batchify_len == 0, f"{padding} % {batchify_len} != 0"
+     split_factor = padding // batchify_len
+     batch_size = len(input_bias.q_seqinfo.seqstart_py) - 1
+     new_batch_size = batch_size * split_factor
+     k_seqlen = input_bias.k_seqinfo.seqlen
+     q_seqlen = input_bias.q_seqinfo.seqstart[1:] - input_bias.q_seqinfo.seqstart[:-1]
+     k_seqlen_each = k_seqlen.repeat_interleave(split_factor, output_size=new_batch_size)
+     q_seqlen_each = q_seqlen.repeat_interleave(split_factor, output_size=new_batch_size)
+     res_seqlen_each = k_seqlen_each - q_seqlen_each
+     seqpos = torch.arange(
+         0, padding, batchify_len, device=k_seqlen.device, dtype=k_seqlen.dtype
+     )
+     seqpos_start = seqpos.repeat(batch_size)
+     k_lengths = (k_seqlen_each - seqpos_start).clamp(min=0, max=batchify_len)
+     res_lengths = (res_seqlen_each - seqpos_start).clamp(min=0, max=batchify_len)
+
+     k_seqstart = torch.arange(
+         0,
+         new_batch_size * batchify_len + 1,
+         batchify_len,
+         device=k_seqlen.device,
+         dtype=k_seqlen.dtype,
+     )
+     k_seqstart_py = list(range(0, new_batch_size * batchify_len + 1, batchify_len))
+     q_seqstart = torch.zeros_like(k_seqstart)
+     torch.cumsum(k_lengths - res_lengths, 0, out=q_seqstart[1:])
+
+     # start at 2 to avoid reshaping issues with
+     # https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp#L602
+     max_q_len = 2
+     min_q_len = 2
+     max_k_len = 0
+     q_seqstart_list: List[int] = [0]
+     k_seqlen_list: List[int] = []
+     for i in range(len(input_bias.k_seqinfo.seqlen)):
+         q_seqlen_ = (
+             input_bias.q_seqinfo.seqstart_py[i + 1]
+             - input_bias.q_seqinfo.seqstart_py[i]
+         )
+         k_seqlen_ = input_bias.k_seqinfo.seqlen_py[i]
+         res_seqlen_ = k_seqlen_ - q_seqlen_
+         for seqpos_ in range(0, padding, batchify_len):
+             k_chunk_size = max(min(k_seqlen_ - seqpos_, batchify_len), 0)
+             res_chunk_size = max(min(res_seqlen_ - seqpos_, batchify_len), 0)
+             q_chunk_size = k_chunk_size - res_chunk_size
+
+             q_seqstart_list.append(q_seqstart_list[-1] + q_chunk_size)
+             k_seqlen_list.append(k_chunk_size)
+             if q_chunk_size > max_q_len:
+                 max_q_len = q_chunk_size
+             if q_chunk_size < min_q_len:
+                 min_q_len = q_chunk_size
+             if k_chunk_size > max_k_len:
+                 max_k_len = k_chunk_size
+
+     batchify_attn_bias = input_bias.__class__(
+         q_seqinfo=_SeqLenInfo(
+             seqstart=q_seqstart,
+             max_seqlen=max_q_len,
+             min_seqlen=min_q_len,
+             seqstart_py=q_seqstart_list,
+         ),
+         k_seqinfo=_PaddedSeqLenInfo(
+             seqstart=k_seqstart,
+             seqlen_py=k_seqlen_list,
+             seqlen=k_lengths,
+             padding=batchify_len,
+             seqstart_py=k_seqstart_py,
+             min_seqlen=0,
+             max_seqlen=max_k_len,
+         ),
+     )
+     return batchify_attn_bias
+
+
+ def maybe_make_paged(
+     attn_bias: Optional[
+         Union[
+             BlockDiagonalPaddedKeysMask,
+             BlockDiagonalGappyKeysMask,
+         ]
+     ],
+     block_tables: Optional[torch.Tensor],
+     page_size: int,
+     notional_padding: Optional[int],
+ ) -> Optional[AttentionBias]:
+     """
+     Convert an attention bias into its paged version if block_tables is not None.
+     Args:
+         attn_bias: input attention bias.
+         block_tables: table of shape [batch_size, max_pages_per_lane]
+             redirecting from logical to physical pages.
+         page_size: number of tokens per page.
+         notional_padding: if the input attention bias is gappy, it has
+             no notion of padding; sequence starts are arbitrary.
+             However, we need to know how to divide the logical sequence space
+             into lanes corresponding to each row of the block tables.
+             In other words, where is the 0th block in the i-th row of the block table
+             located in the logical space?
+             This function assumes that it's located at i * notional_padding.
+             The value of notional_padding needs to be consistent with the
+             padding used when block_tables was created.
+             For example, if a gappy bias was created from a padded bias
+             using the split_blocks* functions, the notional padding
+             should be equal to the padding of the original bias.
+     Returns:
+         Paged version of the original attention bias.
+     """
+     if attn_bias is None:
+         return None
+     if block_tables is None:
+         return attn_bias
+
+     attn_batch_size = len(attn_bias.k_seqinfo.seqlen)
+     if attn_batch_size != block_tables.shape[0]:
+         # In case of iRoPE each batch lane has been split into smaller chunks,
+         # so we need to reshape the block tables accordingly.
+         block_tables = block_tables.view(attn_batch_size, -1)
+     if isinstance(attn_bias, BlockDiagonalGappyKeysMask):
+         assert notional_padding is not None, (
+             "Notional padding must be specified to create gappy paged biases."
+         )
+         return attn_bias.make_paged(
+             block_tables=block_tables,
+             page_size=page_size,
+             notional_padding=notional_padding,
+             paged_type=PagedBlockDiagonalGappyKeysMask,
+         )
+     if isinstance(attn_bias, PagedBlockDiagonalGappyKeysMask):
+         return attn_bias
+     paged_type = (
+         PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
+         if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+         else PagedBlockDiagonalPaddedKeysMask
+     )
+     assert isinstance(attn_bias, BlockDiagonalPaddedKeysMask)
+     return attn_bias.make_paged(
+         block_tables=block_tables, page_size=page_size, paged_type=paged_type
+     )
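
The +329 hunk above is mslk/attention/fmha/split_blocks_fairinternal.py (entry 57 in the file list). For orientation, here is a minimal usage sketch of its decoding-path helpers. The module paths, the from_seqlens construction, and all sizes below are assumptions for illustration (the attn_bias classes appear to mirror an xformers-style API); only split_blocks_for_decoding and maybe_make_paged come from the file shown above.

import torch

from mslk.attention.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
from mslk.attention.fmha.split_blocks_fairinternal import (
    maybe_make_paged,
    split_blocks_for_decoding,
)

# Assumed setup: two decode queries (one token each) over padded K/V lanes.
padding, batchify_len, page_size = 256, 128, 64
bias = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1, 1], kv_padding=padding, kv_seqlen=[37, 200]
)

# Re-express the padded bias as a gappy bias that keeps only the current
# iRoPE chunk of keys for each batch lane.
gappy = split_blocks_for_decoding(bias, batchify_len)

# Optionally convert to a paged bias once a [batch, pages_per_lane] block
# table exists; notional_padding must match the padding of the original bias.
block_tables = torch.arange(2 * padding // page_size, dtype=torch.int32).view(2, -1)
paged = maybe_make_paged(
    gappy, block_tables=block_tables, page_size=page_size, notional_padding=padding
)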
@@ -0,0 +1,154 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-strict
+
+ from typing import Optional
+
+ import torch
+ from torch._C import parse_schema
+
+
+ def is_pt_cutlass_compatible(force: bool = False) -> bool:
+     if torch.version.hip is not None:
+         if force:
+             raise ImportError("CUTLASS is not supported on ROCm")
+         return False
+     compatible = True
+
+     fwd_schema_str = (
+         "aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, "
+         "Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, "
+         "SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, "
+         "float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> "
+         "(Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, "
+         "SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)"
+     )
+     expected_fwd_schema = parse_schema(fwd_schema_str)
+
+     current_schema = torch.ops.aten._efficient_attention_forward.default._schema
+     if not current_schema.is_backward_compatible_with(expected_fwd_schema):
+         compatible = False
+
+         if force:
+             raise ImportError(
+                 f"Current Torch CUTLASS doesn't have a compatible aten::_efficient_attention_forward schema\n"
+                 f"EXPECTED:\n{expected_fwd_schema}\n"
+                 f"but GOT:\n{current_schema}"
+             )
+
+     bwd_schema_str = (
+         "aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, "
+         "Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, "
+         "SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, "
+         "int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, "
+         "int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor)"
+     )
+
+     expected_bwd_schema = parse_schema(bwd_schema_str)
+
+     current_schema = torch.ops.aten._efficient_attention_backward.default._schema
+     if not current_schema.is_backward_compatible_with(expected_bwd_schema):
+         compatible = False
+
+         if force:
+             raise ImportError(
+                 f"Current Torch CUTLASS doesn't have a compatible aten::_efficient_attention_backward schema\n"
+                 f"EXPECTED:\n{expected_bwd_schema}\n"
+                 f"but GOT:\n{current_schema}"
+             )
+
+     return compatible
+
+
+ def is_pt_flash_old(force: bool) -> Optional[bool]:
+     """
+     Returns True if the current PyTorch version has the old Flash-Attention
+     ops instead of the new ones.
+     If it has none at all, raises an ImportError or returns None.
+     """
+     if not torch.backends.cuda.is_flash_attention_available():
+         if force:
+             raise ImportError("Flash SDP backend is disabled")
+         return None
+
+     if not hasattr(torch.nn, "attention") or not hasattr(
+         torch.nn.attention, "_get_flash_version"
+     ):
+         if force:
+             raise ImportError(
+                 f"Current Torch {torch.__version__} doesn't implement "
+                 "torch.nn.attention._get_flash_version()"
+             )
+         return None
+
+     FLASH_VERSION = torch.nn.attention._get_flash_version()
+
+     compatible = True
+
+     # old = before 25/2/2025
+     # https://github.com/pytorch/pytorch/commit/3ecfe6be256c585bcadf4c845d7119545444a222
+     old_fwd_schema_str = (
+         "aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, "
+         "Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, "
+         "bool is_causal, bool return_debug_mask, *, float? scale=None, "
+         "SymInt? window_size_left=None, SymInt? window_size_right=None, "
+         "Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, "
+         "Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)"
+     )
+     fwd_schema_str = (
+         "aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, "
+         "Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, "
+         "bool is_causal, bool return_debug_mask, *, float? scale=None, "
+         "SymInt? window_size_left=None, SymInt? window_size_right=None, "
+         "Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, "
+         "Tensor rng_state, Tensor unused, Tensor debug_attn_mask)"
+     )
+     expected_fwd_schema = parse_schema(fwd_schema_str)
+     expected_old_fwd_schema = parse_schema(old_fwd_schema_str)
+
+     current_schema = torch.ops.aten._flash_attention_forward.default._schema
+     old = current_schema.is_backward_compatible_with(expected_old_fwd_schema)
+     if not old and not current_schema.is_backward_compatible_with(expected_fwd_schema):
+         compatible = False
+
+         if force:
+             raise ImportError(
+                 f"Current Torch with Flash-Attention {FLASH_VERSION} doesn't have "
+                 "a compatible aten::_flash_attention_forward schema\n"
+                 f"EXPECTED:\n{expected_old_fwd_schema}\n"
+                 f"or:\n{expected_fwd_schema}\n"
+                 f"but GOT:\n{current_schema}"
+             )
+
+     bwd_schema_old_str = (
+         "aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, "
+         "Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, "
+         "float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, "
+         "SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)"
+     )
+     bwd_schema_str = (
+         "aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, "
+         "Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, "
+         "float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, "
+         "SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)"
+     )
+     expected_bwd_schema = parse_schema(bwd_schema_old_str if old else bwd_schema_str)
+
+     current_schema = torch.ops.aten._flash_attention_backward.default._schema
+     if not current_schema.is_backward_compatible_with(expected_bwd_schema):
+         compatible = False
+
+         if force:
+             raise ImportError(
+                 f"Current Torch with Flash-Attention {FLASH_VERSION} doesn't have "
+                 "a compatible aten::_flash_attention_backward schema\n"
+                 f"EXPECTED:\n{expected_bwd_schema}\n"
+                 f"but GOT:\n{current_schema}"
+             )
+
+     if not compatible:
+         return None
+     return old
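
The +154 hunk above is mslk/attention/fmha/torch_attention_compat.py (entry 58 in the file list). Its two probes are plain predicates over the ATen operator schemas, so a caller can gate backend selection on them. Below is a minimal sketch under that assumption; the pick_backend wrapper and its return labels are illustrative and not part of the package.

from mslk.attention.fmha.torch_attention_compat import (
    is_pt_cutlass_compatible,
    is_pt_flash_old,
)


def pick_backend() -> str:
    # force=True would raise ImportError with the expected vs. actual schemas,
    # which surfaces version skew early; force=False only reports availability.
    if is_pt_cutlass_compatible(force=False):
        return "cutlass"
    flash_is_old = is_pt_flash_old(force=False)
    if flash_is_old is None:
        # Neither the old nor the new Flash-Attention schema was recognized.
        return "fallback"
    return "flash_old" if flash_is_old else "flash_new"


backend = pick_backend()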