mslk_cuda_nightly-2026.1.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
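The file list above doubles as the importable module layout of the wheel. A minimal sketch of poking at it (assumptions: `mslk.version` exposes a `__version__` string and the bundled native extension `mslk/mslk.so` loads on import; neither is confirmed by the file list alone). The operator classes referenced here are the ones defined in the `cutlass_blackwell.py` hunk shown below.

```python
# Hedged sketch, not part of the wheel: inspect the package and the Blackwell
# FMHA operators whose definitions appear in the diff below.
import mslk.version
from mslk.attention.fmha import cutlass_blackwell

print(getattr(mslk.version, "__version__", "unknown"))

# NAME and CUDA_MINIMUM_COMPUTE_CAPABILITY are class attributes set in the diff below.
for op in (cutlass_blackwell.FwOp, cutlass_blackwell.FwOpDecode, cutlass_blackwell.BwOp):
    print(op.NAME, op.CUDA_MINIMUM_COMPUTE_CAPABILITY)
```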
mslk/attention/fmha/cutlass_blackwell.py
@@ -0,0 +1,560 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-unsafe
+ from typing import Any, Iterable, List, Optional, Set, Tuple, Union
+
+ import torch
+
+ from .attn_bias import (
+     AttentionBias,
+     BlockDiagonalCausalFromBottomRightMask,
+     BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+     BlockDiagonalCausalLocalAttentionMask,
+     BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+     BlockDiagonalCausalMask,
+     BlockDiagonalCausalWithOffsetGappyKeysMask,
+     BlockDiagonalCausalWithOffsetPaddedKeysMask,
+     BlockDiagonalGappyKeysMask,
+     BlockDiagonalLocalAttentionPaddedKeysMask,
+     BlockDiagonalMask,
+     BlockDiagonalPaddedKeysMask,
+     LocalAttentionFromBottomRightMask,
+     LowerTriangularFromBottomRightLocalAttentionMask,
+     LowerTriangularFromBottomRightMask,
+     LowerTriangularMask,
+ )
+ from .common import AttentionBwOpBase, AttentionFwOpBase, Context, Gradients, Inputs
+ from .utils.op_common import register_operator
+
+
+ def _get_operator(name: str):
+     def no_such_operator(*args, **kwargs):
+         raise RuntimeError(
+             "No such operator "
+             f"mslk.attention.cutlass_blackwell_fmha.{name} "
+             "- did you forget to build xformers with `python setup.py develop`?"
+         )
+
+     try:
+         # type: ignore # pyre-ignore
+         from mslk.attention.cutlass_blackwell_fmha import (
+             cutlass_blackwell_fmha_interface as fmha,
+         )
+
+         return getattr(fmha, name) # type: ignore # pyre-ignore
+     except (RuntimeError, ModuleNotFoundError):
+         return no_such_operator
+
+
+ def _convert_input_format(
+     inp: Inputs,
+ ) -> Tuple[
+     Inputs,
+     Optional[torch.Tensor],
+     Optional[int],
+     Optional[torch.Tensor],
+     Optional[int],
+     Optional[torch.Tensor],
+ ]:
+     assert inp.query.ndim in (4, 5)
+     query, key, value = inp.query, inp.key, inp.value
+
+     attn_bias = inp.attn_bias
+     if isinstance(attn_bias, BlockDiagonalMask):
+         assert attn_bias.k_seqinfo.seqstart.device == inp.query.device
+         cu_seqlen_k = attn_bias.k_seqinfo.seqstart
+         cu_seqlen_q = attn_bias.q_seqinfo.seqstart
+         max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
+         max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
+         seqused_k = None
+     elif isinstance(
+         attn_bias,
+         (
+             BlockDiagonalPaddedKeysMask,
+             BlockDiagonalCausalWithOffsetPaddedKeysMask,
+             BlockDiagonalGappyKeysMask,
+             BlockDiagonalCausalWithOffsetGappyKeysMask,
+             BlockDiagonalLocalAttentionPaddedKeysMask,
+             BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+         ),
+     ):
+         assert attn_bias.k_seqinfo.seqstart.device == inp.query.device
+         cu_seqlen_k = attn_bias.k_seqinfo.seqstart
+         cu_seqlen_q = attn_bias.q_seqinfo.seqstart
+         max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
+         max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
+         # All these mask types inherit from classes that have seqlen attribute
+         seqused_k = attn_bias.k_seqinfo.seqlen
+         assert seqused_k is not None
+     else:
+         cu_seqlen_k = None
+         cu_seqlen_q = None
+         seqused_k = None
+         max_seqlen_q = None
+         max_seqlen_k = None
+
+     if query.ndim == 5: # GQA
+         # Fold the group/head_in_group dimensions together
+         def fold(x):
+             # Either the head is replicated
+             if x.stride(3) == 0:
+                 return x[:, :, :, 0]
+
+             # Or we reshape
+             return x.reshape(
+                 [
+                     x.shape[0],
+                     x.shape[1],
+                     -1,
+                     x.shape[4],
+                 ]
+             )
+
+         query = fold(query)
+         key = fold(key)
+         value = fold(value)
+
+     if cu_seqlen_k is not None and query.ndim == 4:
+         # Fold to 3D when using varlen
+         def fold(x):
+             assert x.shape[0] == 1
+             x = x.squeeze(0)
+             assert x.ndim == 3
+             if x.stride(1) == 0:
+                 # BMHK for MQA with kv_head = 1
+                 return x[:, 0, :].unsqueeze(1)
+             return x
+
+         query = fold(query)
+         key = fold(key)
+         value = fold(value)
+
+     new_inp = Inputs(
+         query=query,
+         key=key,
+         value=value,
+         attn_bias=attn_bias,
+         p=inp.p,
+         scale=inp.scale,
+         output_dtype=inp.output_dtype,
+         is_partial=inp.is_partial,
+     )
+     return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k, seqused_k
+
+
+ def _is_seqlen_q_le_seqlen_k(
+     cu_seqlens_q_py: List[int], cu_seqlens_k_py: List[int]
+ ) -> bool:
+     if len(cu_seqlens_q_py) < 2 or len(cu_seqlens_k_py) < 2:
+         # The seqlens q and k info does not exist on CPU
+         return True
+     cu_seqlens_q = torch.as_tensor(cu_seqlens_q_py, dtype=torch.int, device="cpu")
+     cu_seqlens_k = torch.as_tensor(cu_seqlens_k_py, dtype=torch.int, device="cpu")
+     seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+     seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
+     return bool(torch.all(seqlens_k >= seqlens_q))
+
+
+ def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
+     return isinstance(
+         attn_bias,
+         (
+             LowerTriangularMask,
+             BlockDiagonalCausalMask,
+             LowerTriangularFromBottomRightMask,
+             BlockDiagonalCausalFromBottomRightMask,
+             LowerTriangularFromBottomRightLocalAttentionMask,
+             BlockDiagonalCausalLocalAttentionMask,
+             BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+             BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+             BlockDiagonalCausalWithOffsetGappyKeysMask,
+             BlockDiagonalCausalWithOffsetPaddedKeysMask,
+         ),
+     )
+
+
+ def _is_bottom_right(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
+     return isinstance(
+         attn_bias,
+         (
+             LowerTriangularFromBottomRightMask,
+             BlockDiagonalCausalFromBottomRightMask,
+             LocalAttentionFromBottomRightMask,
+             BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+             BlockDiagonalCausalWithOffsetPaddedKeysMask,
+             BlockDiagonalLocalAttentionPaddedKeysMask,
+             BlockDiagonalCausalWithOffsetGappyKeysMask,
+             BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+         ),
+     )
+
+
+ def _window_size(
+     attn_bias: Optional[Union[torch.Tensor, AttentionBias]],
+ ) -> Tuple[int, int]:
+     win_left = -1
+     win_right = -1
+     if isinstance(
+         attn_bias,
+         (
+             BlockDiagonalCausalLocalAttentionMask,
+             BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+             LowerTriangularFromBottomRightLocalAttentionMask,
+             BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+         ),
+     ):
+         win_left = attn_bias._window_size - 1
+     if isinstance(
+         attn_bias,
+         (
+             BlockDiagonalLocalAttentionPaddedKeysMask,
+             LocalAttentionFromBottomRightMask,
+         ),
+     ):
+         win_left = attn_bias.window_left
+         win_right = attn_bias.window_right
+     return (win_left, win_right)
+
+
+ @register_operator
+ class FwOp(AttentionFwOpBase):
+     OPERATOR = _get_operator("_cutlass_blackwell_fmha_forward")
+     SUPPORTED_DEVICES: Set[str] = {"cuda"}
+     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.bfloat16, torch.float16}
+     SUPPORTED_MAX_K = 128
+     SUPPORTED_MIN_K = 64
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         LowerTriangularMask,
+         LowerTriangularFromBottomRightMask,
+         BlockDiagonalCausalFromBottomRightMask,
+         BlockDiagonalMask,
+         BlockDiagonalCausalMask,
+         BlockDiagonalPaddedKeysMask,
+         BlockDiagonalCausalWithOffsetPaddedKeysMask,
+         BlockDiagonalGappyKeysMask,
+         BlockDiagonalCausalWithOffsetGappyKeysMask,
+         BlockDiagonalLocalAttentionPaddedKeysMask,
+         BlockDiagonalCausalLocalAttentionPaddedKeysMask,
+         LocalAttentionFromBottomRightMask,
+         LowerTriangularFromBottomRightLocalAttentionMask,
+         BlockDiagonalCausalLocalAttentionMask,
+         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+     )
+     SUPPORTS_DROPOUT = False
+     SUPPORTS_CUSTOM_SCALE = True
+     SUPPORTS_DIFFERENT_VALUE_EMBED = False
+     SUPPORTS_BMGHK = True
+     VARLEN_LSE_PACKED = True
+     SUPPORTS_PARTIAL = False
+     CUDA_MINIMUM_COMPUTE_CAPABILITY = (10, 0)
+     NAME = "cutlassF-blackwell"
+
+     _TEST_K: List[int] = [64, 128]
+
+     @classmethod
+     def not_supported_reasons(cls, d: Inputs) -> List[str]:
+         reasons = super(FwOp, cls).not_supported_reasons(d)
+         attn_bias = d.attn_bias
+         if isinstance(attn_bias, BlockDiagonalCausalMask):
+             (
+                 _,
+                 cu_seqlens_q,
+                 _,
+                 cu_seqlens_k,
+                 _,
+                 _,
+             ) = _convert_input_format(d)
+             if not _is_seqlen_q_le_seqlen_k(
+                 attn_bias.q_seqinfo.seqstart_py,
+                 attn_bias.k_seqinfo.seqstart_py,
+             ):
+                 reasons.append("seqlens_k must be >= seqlens_q")
+
+         if d.query.ndim < 4 or d.key.ndim < 4 or d.value.ndim < 4:
+             reasons.append("Only supports BMHK or BMGHK")
+
+         return reasons
+
+     @classmethod
+     def shape_not_supported_reasons(
+         cls, Mq: int, Mkv: int, K: int, Kv: int
+     ) -> List[str]:
+         reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
+         if K not in [64, 128] or Kv not in [64, 128]:
+             reasons.append(f"Embed dim {K} not supported")
+         elif Mkv != 0 and Mq > Mkv:
+             reasons.append(f"Only support Mq ({Mq}) <= Mk ({Mkv})")
+         return reasons
+
+     @classmethod
+     def apply(
+         cls, inp: Inputs, needs_gradient: bool
+     ) -> Tuple[torch.Tensor, Optional[Context]]:
+         q_shape = inp.query.shape
+         (
+             inp,
+             cu_seqlens_q,
+             max_seq_len_q,
+             cu_seqlens_k,
+             max_seq_len_k,
+             seqused_k,
+         ) = _convert_input_format(inp)
+
+         window_left, window_right = _window_size(inp.attn_bias)
+
+         if inp.query.numel() > 0 and inp.key.numel() > 0:
+             out, lse = cls.OPERATOR(
+                 q=inp.query,
+                 k=inp.key,
+                 v=inp.value,
+                 cu_seqlens_q=cu_seqlens_q,
+                 cu_seqlens_k=cu_seqlens_k,
+                 seqlen_kv=seqused_k,
+                 max_seq_len_q=max_seq_len_q,
+                 max_seq_len_k=max_seq_len_k,
+                 softmax_scale=inp.scale,
+                 causal=_is_causal(inp.attn_bias),
+                 window_left=window_left,
+                 window_right=window_right,
+                 bottom_right=_is_bottom_right(inp.attn_bias),
+             )
+         else:
+             out = torch.zeros_like(inp.query)
+             if cu_seqlens_q is None:
+                 assert inp.query.ndim == 4
+                 B, M, H, K = inp.query.shape
+                 lse_shape = [B, H, M]
+             else:
+                 assert inp.query.ndim == 3
+                 M, H, K = inp.query.shape
+                 lse_shape = [1, H, M]
+             lse = torch.zeros(*lse_shape, dtype=torch.float, device=out.device)
+         out = out.reshape(q_shape)
+         if not needs_gradient:
+             return out, None
+         return out, Context(out=out, lse=lse)
+
+
+ @register_operator
+ class FwOpDecode(AttentionFwOpBase):
+     """CUTLASS Blackwell decode kernel optimized for inference with sequence length 1.
+
+     This operator is specifically designed for the decode phase of autoregressive generation
+     where query length is 1.
+     """
+
+     OPERATOR = _get_operator("cutlass_blackwell_fmha_decode_forward")
+     SUPPORTED_DEVICES: Set[str] = {"cuda"}
+     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.bfloat16}
+     SUPPORTED_MAX_K = 128
+     SUPPORTED_MIN_K = 64
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         BlockDiagonalCausalWithOffsetPaddedKeysMask,
+     )
+     SUPPORTS_DROPOUT = False
+     SUPPORTS_CUSTOM_SCALE = True
+     SUPPORTS_DIFFERENT_VALUE_EMBED = False
+     SUPPORTS_BMGHK = True
+     VARLEN_LSE_PACKED = True
+     SUPPORTS_PARTIAL = False
+     CUDA_MINIMUM_COMPUTE_CAPABILITY = (10, 0)
+     NAME = "cutlassF-blackwell-decode"
+
+     _TEST_K: List[int] = [64, 128]
+
+     @classmethod
+     def not_supported_reasons(cls, d: Inputs) -> List[str]:
+         reasons = super(FwOpDecode, cls).not_supported_reasons(d)
+         q_shape = d.query.shape
+         if q_shape[-2] > 16:
+             reasons.append(f"Max qHeads ({q_shape[-2]}) per KV head is > 16")
+         return reasons
+
+     @classmethod
+     def shape_not_supported_reasons(
+         cls, Mq: int, Mkv: int, K: int, Kv: int
+     ) -> List[str]:
+         reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
+         if K not in [64, 128]:
+             reasons.append(f"Embed dim {K} not supported")
+         return reasons
+
+     @classmethod
+     def apply(
+         cls, inp: Inputs, needs_gradient: bool
+     ) -> Tuple[torch.Tensor, Optional[Context]]:
+         q_shape = inp.query.shape
+         (
+             inp,
+             cu_seqlens_q,
+             max_seq_len_q,
+             cu_seqlens_k,
+             max_seq_len_k,
+             seqused_k,
+         ) = _convert_input_format(inp)
+
+         window_left, window_right = _window_size(inp.attn_bias)
+
+         if inp.query.numel() > 0 and inp.key.numel() > 0:
+             out, lse = cls.OPERATOR(
+                 q=inp.query,
+                 k=inp.key,
+                 v=inp.value,
+                 cu_seqlens_q=cu_seqlens_q, # not used
+                 cu_seqlens_k=cu_seqlens_k, # not used
+                 seqlen_kv=seqused_k,
+                 max_seq_len_q=max_seq_len_q, # not used
+                 max_seq_len_k=max_seq_len_k, # not used
+                 softmax_scale=inp.scale, # not used
+                 causal=_is_causal(inp.attn_bias),
+                 window_left=window_left,
+                 window_right=window_right,
+                 bottom_right=_is_bottom_right(inp.attn_bias), # not used
+             )
+         else:
+             out = torch.zeros_like(inp.query)
+             if cu_seqlens_q is None:
+                 assert inp.query.ndim == 4
+                 B, M, H, K = inp.query.shape
+                 # lse_shape = [B, H, M]
+             else:
+                 assert inp.query.ndim == 3
+                 M, H, K = inp.query.shape
+                 # lse_shape = [1, H, M]
+             # lse = torch.zeros(*lse_shape, dtype=torch.float, device=out.device)
+         out = out.reshape(q_shape)
+         assert not needs_gradient, "FwOpDecode does not support gradient computation"
+         return out, None
+
+
+ @register_operator
+ class BwOp(AttentionBwOpBase):
+     __doc__ = FwOp.__doc__
+
+     OPERATOR = _get_operator("_cutlass_blackwell_fmha_backward")
+
+     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
+     SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
+     SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
+     SUPPORTED_MIN_K = FwOp.SUPPORTED_MIN_K
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         LowerTriangularMask,
+         LowerTriangularFromBottomRightMask,
+         BlockDiagonalCausalFromBottomRightMask,
+         BlockDiagonalMask,
+         BlockDiagonalCausalMask,
+         LocalAttentionFromBottomRightMask,
+         LowerTriangularFromBottomRightLocalAttentionMask,
+         BlockDiagonalCausalLocalAttentionMask,
+         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+     )
+     SUPPORTS_ATTN_BIAS_GRAD = False
+     SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
+     SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
+     SUPPORTS_DIFFERENT_VALUE_EMBED = False
+     SUPPORTS_BMGHK = False
+     VARLEN_LSE_PACKED = True
+     SUPPORTS_PARTIAL = False
+     CUDA_MINIMUM_COMPUTE_CAPABILITY = (10, 0)
+     NAME = "cutlassB-blackwell"
+
+     @classmethod
+     def not_supported_reasons(cls, d: Inputs) -> List[str]:
+         reasons = super(BwOp, cls).not_supported_reasons(d)
+         attn_bias = d.attn_bias
+         if isinstance(attn_bias, BlockDiagonalCausalMask):
+             (
+                 _,
+                 cu_seqlens_q,
+                 _,
+                 cu_seqlens_k,
+                 _,
+                 _,
+             ) = _convert_input_format(d)
+             if not _is_seqlen_q_le_seqlen_k(
+                 attn_bias.q_seqinfo.seqstart_py,
+                 attn_bias.k_seqinfo.seqstart_py,
+             ):
+                 reasons.append("seqlens_k must be >= seqlens_q")
+
+         if d.query.ndim != 4 or d.key.ndim != 4 or d.value.ndim != 4:
+             reasons.append("Only supports BMHK format")
+
+         return reasons
+
+     @classmethod
+     def shape_not_supported_reasons(
+         cls, Mq: int, Mkv: int, K: int, Kv: int
+     ) -> List[str]:
+         reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
+         if K not in [64, 128]:
+             reasons.append(f"Embed dim {K} not supported")
+         elif Mkv != 0 and Mq > Mkv:
+             reasons.append(f"Only support Mq ({Mq}) <= Mk ({Mkv})")
+         elif Mq < 8:
+             reasons.append(f"Only support Mq ({Mq}) >= 8")
+         return reasons
+
+     @classmethod
+     def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
+         assert inp.query.ndim == 4
+         dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
+         (
+             inp,
+             cu_seqlens_q,
+             max_seq_len_q,
+             cu_seqlens_k,
+             max_seq_len_k,
+             _,
+         ) = _convert_input_format(inp)
+
+         window_left, window_right = _window_size(inp.attn_bias)
+
+         is_varlen = cu_seqlens_q is not None
+         if is_varlen:
+
+             def fold(x):
+                 assert x.shape[0] == 1
+                 x = x.squeeze(0)
+                 assert x.ndim == 3
+                 return x
+
+             grad = fold(grad)
+             ctx.out = fold(ctx.out)
+
+         if inp.query.numel() and inp.key.numel():
+             grads = Gradients(
+                 *cls.OPERATOR(
+                     dout=grad,
+                     q=inp.query,
+                     k=inp.key,
+                     v=inp.value,
+                     out=ctx.out,
+                     softmax_lse=ctx.lse,
+                     cu_seqlens_q=cu_seqlens_q,
+                     cu_seqlens_k=cu_seqlens_k,
+                     max_seq_len_q=max_seq_len_q,
+                     max_seq_len_k=max_seq_len_k,
+                     causal=_is_causal(inp.attn_bias),
+                     window_left=window_left,
+                     window_right=window_right,
+                     bottom_right=_is_bottom_right(inp.attn_bias),
+                 )
+             )
+         else:
+             grads = Gradients(
+                 dq=torch.zeros_like(inp.query),
+                 dk=torch.zeros_like(inp.key),
+                 dv=torch.zeros_like(inp.value),
+             )
+
+         grads.dq = grads.dq.reshape(dq_shape)
+         grads.dk = grads.dk.reshape(dk_shape)
+         grads.dv = grads.dv.reshape(dv_shape)
+         return grads
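For orientation, a hedged usage sketch (not part of the wheel) of exercising the forward operator defined above directly. It assumes a Blackwell-class GPU (compute capability 10.0+, per CUDA_MINIMUM_COMPUTE_CAPABILITY) and that the remaining Inputs fields have defaults, as in the xformers-style fmha interface this module mirrors; only classes and methods that appear in the diff are used.

```python
# Hedged sketch: drive FwOp via its classmethod API shown in the diff above.
import torch
from mslk.attention.fmha import cutlass_blackwell
from mslk.attention.fmha.attn_bias import LowerTriangularMask
from mslk.attention.fmha.common import Inputs

# BMHK layout, K=128 and bf16 are within FwOp's supported dtypes and head dims.
q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)

# Assumption: Inputs defaults the fields not passed here (p, scale, etc.).
inp = Inputs(query=q, key=k, value=v, attn_bias=LowerTriangularMask())

# Support check defined on FwOp in the diff above.
reasons = cutlass_blackwell.FwOp.not_supported_reasons(inp)
assert not reasons, reasons

out, ctx = cutlass_blackwell.FwOp.apply(inp, needs_gradient=False)
print(out.shape)  # matches the query shape, torch.Size([2, 1024, 8, 128])
```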