mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/ck.py
@@ -0,0 +1,508 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-unsafe
+
+
+ from enum import Enum
+ from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple, Union
+
+ import torch
+
+ from . import attn_bias
+ from .attn_bias import (
+     AttentionBias,
+     BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+     BlockDiagonalCausalLocalAttentionMask,
+     BlockDiagonalCausalMask,
+     BlockDiagonalCausalWithOffsetGappyKeysMask,
+     BlockDiagonalCausalWithOffsetPaddedKeysMask,
+     BlockDiagonalGappyKeysMask,
+     BlockDiagonalMask,
+     BlockDiagonalPaddedKeysMask,
+     LowerTriangularFromBottomRightLocalAttentionMask,
+     LowerTriangularFromBottomRightMask,
+     LowerTriangularMask,
+     LowerTriangularMaskWithTensorBias,
+     PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+     PagedBlockDiagonalGappyKeysMask,
+     PagedBlockDiagonalPaddedKeysMask,
+ )
+ from .common import (
+     AttentionBwOpBase,
+     AttentionFwOpBase,
+     check_lastdim_alignment_stride1,
+     Context,
+     Gradients,
+     Inputs,
+ )
+ from .utils.op_common import get_operator, register_operator
+
+
+ def _minimum_gemm_alignment(inp: Inputs) -> int:
+     return 1
+
+
+ def _get_seqlen_info(
+     inp: Inputs,
+ ) -> Tuple[
+     Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], int, int
+ ]:
+     attn_bias = inp.attn_bias
+     if isinstance(
+         attn_bias,
+         (
+             BlockDiagonalMask,
+             BlockDiagonalGappyKeysMask,
+             BlockDiagonalPaddedKeysMask,
+             BlockDiagonalCausalWithOffsetGappyKeysMask,
+             BlockDiagonalCausalWithOffsetPaddedKeysMask,
+             PagedBlockDiagonalPaddedKeysMask,
+             PagedBlockDiagonalGappyKeysMask,
+         ),
+     ):
+         attn_bias.k_seqinfo.to(inp.query.device)
+         attn_bias.q_seqinfo.to(inp.query.device)
+         seqstart_k = attn_bias.k_seqinfo.seqstart
+         seqstart_q = attn_bias.q_seqinfo.seqstart
+         max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
+         max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
+         seqlen = (
+             None
+             if isinstance(attn_bias, BlockDiagonalMask)
+             else attn_bias.k_seqinfo.seqlen
+         )
+     else:
+         seqstart_k = None
+         seqstart_q = None
+         max_seqlen_q = -1
+         max_seqlen_k = -1
+         seqlen = None
+
+     if isinstance(attn_bias, PagedBlockDiagonalGappyKeysMask):
+         assert seqstart_k is not None
+         assert seqlen is not None
+         seqstart_k = seqstart_k[:-1]
+         seqlen = seqlen - seqstart_k
+
+     return seqstart_k, seqstart_q, seqlen, max_seqlen_q, max_seqlen_k
+
+
+ def _get_tensor_bias(
+     attn_bias: Optional[Union[torch.Tensor, AttentionBias]],
+ ) -> Optional[torch.Tensor]:
+     if isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
+         return attn_bias._bias
+     if isinstance(attn_bias, torch.Tensor):
+         return attn_bias
+     return None
+
+
+ def _check_bias_alignment(
+     reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
+ ) -> None:
+     attn_bias_tensor = _get_tensor_bias(attn_bias)
+     if attn_bias_tensor is not None:
+         alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits
+         show_padding_hint = False
+         for d in range(attn_bias_tensor.ndim - 1):
+             if attn_bias_tensor.stride(d) % alignment != 0:
+                 reasons.append(
+                     f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})"
+                 )
+                 show_padding_hint = True
+         if show_padding_hint:
+             reasons.append(
+                 """\
+ HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \
+ you need to ensure memory is aligned by slicing a bigger tensor. \
+ Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`"""
+             )
+         # We can have stride=0 sometimes if dimension=1
+         if attn_bias_tensor.stride(-1) > 1:
+             reasons.append(
+                 f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - "
+                 "you should call `.contiguous()` on the bias"
+             )
+
+
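For reference, the alignment rule enforced by `_check_bias_alignment` above can be restated in a few lines of plain PyTorch. This is a standalone sketch, not part of the packaged module; it only reproduces the stride check and the padding trick described in the HINT string:

```python
import torch

def bias_stride_ok(bias: torch.Tensor) -> bool:
    # Mirrors the check above: every stride except the last must be a
    # multiple of 128 bits / element size (8 elements for fp16/bf16).
    alignment = 128 // torch.finfo(bias.dtype).bits
    return all(bias.stride(d) % alignment == 0 for d in range(bias.ndim - 1))

# An unpadded 5x5 bias: stride(-2) == 5, not a multiple of 8 -> rejected.
bad = torch.zeros([1, 1, 5, 5], dtype=torch.half)
# The padding trick from the HINT: allocate 5x8, then slice back to 5x5.
good = torch.zeros([1, 1, 5, 8], dtype=torch.half)[:, :, :, :5]

print(bias_stride_ok(bad))   # False
print(bias_stride_ok(good))  # True
```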
+ class _CustomMaskType(int, Enum):
+     """
+     (Matches CustomMaskType in C++.)
+     """
+
+     NoCustomMask = 0
+     CausalFromTopLeft = 1
+     CausalFromBottomRight = 2
+
+
+ def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
+     if isinstance(
+         bias,
+         (
+             LowerTriangularMask,
+             BlockDiagonalCausalMask,
+             BlockDiagonalCausalLocalAttentionMask,
+         ),
+     ):
+         return int(_CustomMaskType.CausalFromTopLeft)
+     if isinstance(
+         bias,
+         (
+             LowerTriangularFromBottomRightMask,
+             LowerTriangularFromBottomRightLocalAttentionMask,
+             attn_bias.BlockDiagonalCausalFromBottomRightMask,
+             BlockDiagonalCausalWithOffsetGappyKeysMask,
+             BlockDiagonalCausalWithOffsetPaddedKeysMask,
+             BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+             PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+         ),
+     ):
+         return int(_CustomMaskType.CausalFromBottomRight)
+     return int(_CustomMaskType.NoCustomMask)
+
+
+ @register_operator
+ class FwOp(AttentionFwOpBase):
+     """xFormers' MHA kernel based on Composable Kernel."""
+
+     OPERATOR = get_operator("xformers", "efficient_attention_forward_ck")
+     SUPPORTED_DEVICES: Set[str] = {"cuda"}
+     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
+     SUPPORTED_MAX_K = 512
+
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         torch.Tensor,
+         LowerTriangularMask,
+         LowerTriangularFromBottomRightMask,
+         LowerTriangularFromBottomRightLocalAttentionMask,
+         LowerTriangularMaskWithTensorBias,
+         BlockDiagonalMask,
+         BlockDiagonalCausalMask,
+         BlockDiagonalCausalWithOffsetGappyKeysMask,
+         BlockDiagonalCausalWithOffsetPaddedKeysMask,
+         BlockDiagonalGappyKeysMask,
+         BlockDiagonalPaddedKeysMask,
+         attn_bias.BlockDiagonalCausalFromBottomRightMask,
+         attn_bias.BlockDiagonalCausalLocalAttentionMask,
+         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+         PagedBlockDiagonalPaddedKeysMask,
+         PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
+         PagedBlockDiagonalGappyKeysMask,
+     )
+
+     SUPPORTS_DROPOUT = True
+     SUPPORTS_CUSTOM_SCALE = True
+     SUPPORTS_DIFFERENT_VALUE_EMBED = True
+     SUPPORTS_PARTIAL = True
+     SUPPORTS_BMGHK = True
+     VARLEN_LSE_PACKED = True
+     NAME = "ckF"
+
+     ERROR_ATOL: Mapping[torch.dtype, float] = {
+         torch.float: 3e-4,
+         torch.half: 6e-3,
+         torch.bfloat16: 2.8e-2,
+     }
+     ERROR_RTOL: Mapping[torch.dtype, float] = {
+         torch.float: 2e-5,
+         torch.half: 3e-3,
+         torch.bfloat16: 2e-2,
+     }
+
+     _TEST_K: List[int] = [
+         32,  # 64x64 kernel
+         96,
+         128,  # 64x128 kernel
+         256,  # 64x128 with accumulation in gmem
+         512,
+     ]
+
+     @classmethod
+     def apply(
+         cls, inp: Inputs, needs_gradient: bool
+     ) -> Tuple[torch.Tensor, Optional[Context]]:
+         if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
+             raise NotImplementedError("Unsupported attn_bias type")
+         if inp.query.ndim in [1, 2, 3]:
+             raise NotImplementedError("Unsupported number of dimensions")
+         if inp.query.ndim in [4]:
+             return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
+         assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"
+         ctx: Optional[Context] = None
+
+         # when the input is expanded 5-D, the group dimension has zero stride
+         if inp.key.stride()[3] == 0:
+             assert inp.value.stride()[3] == 0, (
+                 "key and value should be expanded in the same way"
+             )
+             k_shape = inp.key.size()
+             k_stride = inp.key.stride()
+             key = inp.key.as_strided(
+                 (k_shape[0], k_shape[1], k_shape[2], k_shape[4]),
+                 (k_stride[0], k_stride[1], k_stride[2], k_stride[4]),
+             )
+             v_shape = inp.value.size()
+             v_stride = inp.value.stride()
+             value = inp.value.as_strided(
+                 (v_shape[0], v_shape[1], v_shape[2], v_shape[4]),
+                 (v_stride[0], v_stride[1], v_stride[2], v_stride[4]),
+             )
+         else:
+             key = inp.key.flatten(2, 3)
+             value = inp.value.flatten(2, 3)
+
+         [_, _, G, Hq, _] = inp.query.shape
+         attn_bias_replace = inp.attn_bias
+         if isinstance(inp.attn_bias, LowerTriangularMaskWithTensorBias):
+             bias_tensor = _get_tensor_bias(inp.attn_bias)
+             if bias_tensor is not None and bias_tensor.ndim == 5:
+                 attn_bias_replace = LowerTriangularMaskWithTensorBias(
+                     bias_tensor.flatten(1, 2)
+                 )
+         elif isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.ndim == 5:
+             attn_bias_replace = inp.attn_bias.flatten(1, 2)
+         inp = Inputs(
+             query=inp.query.flatten(2, 3),
+             key=key,
+             value=value,
+             attn_bias=attn_bias_replace,
+             p=inp.p,
+             scale=inp.scale,
+             output_dtype=inp.output_dtype,
+             is_partial=inp.is_partial,
+         )
+         out, ctx = cls.apply_bmhk(inp, needs_gradient=needs_gradient)
+         out = out.unflatten(2, (G, Hq))
+         if ctx is not None:
+             lse = ctx.lse.unflatten(1, (G, Hq))
+             ctx = Context(
+                 lse=lse,
+                 out=out,
+                 op_bw=ctx.op_bw,
+                 rng_state=ctx.rng_state,
+                 qkv_share_storage=ctx.qkv_share_storage,
+             )
+
+         return out, ctx
+
+     @classmethod
+     def apply_bmhk(
+         cls, inp: Inputs, needs_gradient: bool
+     ) -> Tuple[torch.Tensor, Optional[Context]]:
+         if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
+             raise NotImplementedError("Unsupported attn_bias type")
+         seqstart_k, seqstart_q, seqlen_k, max_seqlen_q, _ = _get_seqlen_info(inp)
+         out, lse, rng_seed, rng_offset = cls.OPERATOR(
+             query=inp.query,
+             key=inp.key,
+             value=inp.value,
+             attn_bias=_get_tensor_bias(inp.attn_bias),
+             seqstart_q=seqstart_q,
+             seqstart_k=seqstart_k,
+             max_seqlen_q=max_seqlen_q,
+             dropout_p=inp.p,
+             compute_logsumexp=needs_gradient,
+             custom_mask_type=_custom_mask_type(inp.attn_bias),
+             scale=inp.scale,
+             seqlen_k=seqlen_k,
+             window_size=(
+                 inp.attn_bias._window_size
+                 if isinstance(
+                     inp.attn_bias,
+                     (
+                         BlockDiagonalCausalLocalAttentionMask,
+                         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+                         LowerTriangularFromBottomRightLocalAttentionMask,
+                     ),
+                 )
+                 else None
+             ),
+             block_tables=(
+                 inp.attn_bias.block_tables
+                 if isinstance(
+                     inp.attn_bias,
+                     (
+                         PagedBlockDiagonalPaddedKeysMask,
+                         PagedBlockDiagonalGappyKeysMask,
+                     ),
+                 )
+                 else None
+             ),
+             page_size=(
+                 inp.attn_bias.page_size
+                 if isinstance(
+                     inp.attn_bias,
+                     (
+                         PagedBlockDiagonalPaddedKeysMask,
+                         PagedBlockDiagonalGappyKeysMask,
+                     ),
+                 )
+                 else None
+             ),
+         )
+
+         ctx: Optional[Context] = None
+         if needs_gradient:
+             ctx = Context(
+                 out=out,
+                 # lse=_post_process_lse(lse, inp, tuple(original_query_shape)),
+                 lse=lse,
+                 # cutlass forward is only compatible with cutlass backward if
+                 # dropout is used (because of the way RNG states are passed and the
+                 # way random numbers are generated during backward)
+                 op_bw=BwOp if inp.p != 0 else None,
+             )
+             if inp.p != 0:
+                 ctx.rng_state = torch.tensor(
+                     [rng_seed, rng_offset], dtype=torch.int64, device="cpu"
+                 )
+         return out, ctx
+
+     @classmethod
+     def not_supported_reasons(cls, d: Inputs) -> List[str]:
+         reasons = super(FwOp, cls).not_supported_reasons(d)
+         matmul_alignment_mn = _minimum_gemm_alignment(d)
+         check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
+         check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
+         _check_bias_alignment(reasons, d.attn_bias)
+         return reasons
+
+
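The zero-stride branch of `FwOp.apply` above relies on how `expand` represents broadcast dimensions. The following is a minimal standalone illustration (plain PyTorch, with made-up sizes, not code from the package) of why `as_strided` can drop the expanded head dim without copying, whereas `flatten(2, 3)` on such a tensor would have to materialize the broadcast data:

```python
import torch

# Hypothetical BMGHK sizes: batch B, seqlen M, G KV groups, Hq heads per
# group, head dim K.  For GQA/MQA the key is stored once per group and
# broadcast across the per-group head dim, so that dim gets stride 0.
B, M, G, Hq, K = 2, 16, 2, 4, 64
k = torch.randn(B, M, G, 1, K).expand(B, M, G, Hq, K)
print(k.stride()[3])  # 0 -> broadcast, not materialized

# Same trick as apply() above: drop the zero-stride dim with as_strided
# instead of flatten(2, 3), which would force a copy of the broadcast data.
shape, stride = k.size(), k.stride()
k_view = k.as_strided(
    (shape[0], shape[1], shape[2], shape[4]),
    (stride[0], stride[1], stride[2], stride[4]),
)
print(k_view.shape)                       # torch.Size([2, 16, 2, 64])
print(k_view.data_ptr() == k.data_ptr())  # True: still a view, no copy
```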
+ @register_operator
+ class BwOp(AttentionBwOpBase):
+     __doc__ = FwOp.__doc__
+
+     OPERATOR = get_operator("xformers", "efficient_attention_backward_ck")
+     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
+     SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
+     SUPPORTED_MAX_K = 256
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         torch.Tensor,
+         LowerTriangularMask,
+         LowerTriangularFromBottomRightMask,
+         LowerTriangularFromBottomRightLocalAttentionMask,
+         # TODO: Fix handling of gradient through the fMHA autograd function
+         # LowerTriangularMaskWithTensorBias,
+         BlockDiagonalMask,
+         BlockDiagonalCausalMask,
+         attn_bias.BlockDiagonalCausalFromBottomRightMask,
+         attn_bias.BlockDiagonalCausalLocalAttentionMask,
+     )
+     SUPPORTS_ATTN_BIAS_GRAD = True
+     SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
+     SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
+     SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
+     SUPPORTS_UNPADDED_LSE = True
+     NAME = "ckB"
+
+     _TEST_K: List[int] = [
+         32,  # 64x64 kernel
+         64,
+         96,
+         128,  # 64x128/128x128 kernel
+         256,
+     ]
+
+     @classmethod
+     def not_supported_reasons(cls, d: Inputs) -> List[str]:
+         reasons = super(BwOp, cls).not_supported_reasons(d)
+         matmul_alignment_mn = _minimum_gemm_alignment(d)
+
+         check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
+         check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn)
+         check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
+         _check_bias_alignment(reasons, d.attn_bias)
+         attn_bias_tensor = _get_tensor_bias(d.attn_bias)
+
+         # Backprop of gradient through broadcasted bias is not supported
+         if attn_bias_tensor is not None and attn_bias_tensor.requires_grad:
+             # Don't forget that inputs are either in BMK or BMHK!
+             if d.query.ndim == 3 and attn_bias_tensor.ndim == 3:
+                 expected_bias_shape = (*d.query.shape[:2], d.key.shape[1])
+             else:
+                 # bias is B H Mq Mk
+                 expected_bias_shape = (
+                     d.query.shape[0],
+                     d.query.shape[2] if d.query.ndim == 4 else 1,
+                     d.query.shape[1],
+                     d.key.shape[1],
+                 )
+             if tuple(attn_bias_tensor.shape) != expected_bias_shape:
+                 reasons.append(
+                     "Broadcasting the `attn_bias` tensor is not supported "
+                     f"(shape: {tuple(attn_bias_tensor.shape)}"
+                     f"/ expected: {expected_bias_shape})"
+                 )
+
+         return reasons
+
+     @classmethod
+     def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
+         if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES:
+             raise NotImplementedError("Unsupported attn_bias type")
+
+         seqstart_k, seqstart_q, seqlen_k, max_seqlen_q, max_seqlen_k = _get_seqlen_info(
+             inp
+         )
+         dtype = inp.query.dtype
+
+         rng_seed = rng_offset = 0
+         if inp.p != 0.0:
+             if (
+                 ctx.rng_state is None
+                 or ctx.rng_state.dtype != torch.int64
+                 or ctx.rng_state.device.type != "cpu"
+                 or ctx.rng_state.shape != (2,)
+             ):
+                 raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
+             rng_seed, rng_offset = ctx.rng_state.tolist()
+
+         (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR(
+             grad.to(dtype),
+             inp.query,
+             inp.key,
+             inp.value,
+             attn_bias=_get_tensor_bias(inp.attn_bias),
+             seqstart_q=seqstart_q,
+             seqstart_k=seqstart_k,
+             max_seqlen_q=max_seqlen_q,
+             max_seqlen_k=max_seqlen_k,
+             seqlen_k=seqlen_k,
+             logsumexp=ctx.lse,
+             output=ctx.out.to(dtype),
+             dropout_p=inp.p,
+             # if not using dropout, seed and offset are irrelevant but still expected
+             # in function signature so just pass 0
+             # seed and offset could be None if a different FW op other than cutlass
+             # was used.
+             rng_seed=rng_seed,
+             rng_offset=rng_offset,
+             custom_mask_type=_custom_mask_type(inp.attn_bias),
+             scale=inp.scale,
+             window_size=(
+                 inp.attn_bias._window_size
+                 if isinstance(
+                     inp.attn_bias,
+                     (
+                         BlockDiagonalCausalLocalAttentionMask,
+                         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+                         LowerTriangularFromBottomRightLocalAttentionMask,
+                     ),
+                 )
+                 else None
+             ),
+         )
+
+         # c++/CUDA implementation returns an uninitialized tensor if bias doesn't
+         # require grad
+         if not (
+             isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad
+         ):
+             grad_bias = None
+
+         return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)
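The forward/backward pair above hands the dropout RNG state from `FwOp` to `BwOp` as a 2-element CPU int64 tensor. A small sketch of that round trip, using placeholder seed/offset values in place of what the kernel actually returns:

```python
import torch

# Hypothetical values standing in for what the forward kernel returns.
rng_seed, rng_offset = 42, 1024

# Forward side (FwOp.apply): pack the dropout RNG state as above.
rng_state = torch.tensor([rng_seed, rng_offset], dtype=torch.int64, device="cpu")

# Backward side (BwOp.apply): the same validation before unpacking.
assert rng_state.dtype == torch.int64
assert rng_state.device.type == "cpu"
assert rng_state.shape == (2,)
seed, offset = rng_state.tolist()
print(seed, offset)  # 42 1024
```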
mslk/attention/fmha/ck_decoder.py
@@ -0,0 +1,141 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-unsafe
+
+ from typing import Any, Iterable, List, Optional, Set, Tuple
+
+ import torch
+
+ from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
+ from .common import AttentionFwOpBase, Context, Inputs
+ from .utils.op_common import get_operator, register_operator
+
+
+ @register_operator
+ class FwOp(AttentionFwOpBase):
+     """
+     An operator optimized for K=256 (so the contiguous dim fits into registers).
+     Tested to work on MI250x.
+     """
+
+     OPERATOR = get_operator("xformers", "efficient_attention_forward_decoder_ck")
+     SUPPORTED_DEVICES: Set[str] = {"cuda"}
+     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16, torch.float}
+     SUPPORTED_MAX_K: int = 256
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         BlockDiagonalCausalWithOffsetPaddedKeysMask,
+     )
+     SUPPORTS_DROPOUT = False
+     SUPPORTS_CUSTOM_SCALE = True
+     SUPPORTS_BMGHK = True
+     NAME = "ck_decoderF"
+
+     @classmethod
+     def not_supported_reasons(cls, d: Inputs) -> List[str]:  # noqa: C901
+         reasons = super(FwOp, cls).not_supported_reasons(d)
+
+         attn_bias = d.attn_bias
+         if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
+             if d.query.shape[0] != 1:
+                 reasons.append(
+                     f"One formal batch element expected; got {d.query.shape[0]}"
+                 )
+
+             if d.query.shape[-1] > cls.SUPPORTED_MAX_K:
+                 reasons.append(
+                     f"Got head_dim={d.query.shape[-1]}; only head_dim<={cls.SUPPORTED_MAX_K} is supported for now."
+                 )
+
+             threads_per_warp = 64  # TODO: ideally query the platform here
+             required_alignment = 0
+             head_dim = d.query.shape[-1]
+             for vec_size in (4, 2, 1):
+                 if head_dim <= vec_size * threads_per_warp:
+                     required_alignment = vec_size
+
+             if not required_alignment:
+                 reasons.append(f"Got head_dim={head_dim} which is too large")
+
+             if head_dim % required_alignment != 0:
+                 reasons.append(
+                     f"Got head_dim={head_dim}; it needs to be divisible by {required_alignment}"
+                 )
+
+             if d.key.stride(-1) != 1:
+                 reasons.append("expect keys to have last dim contiguous")
+
+             if d.value.stride(-1) != 1:
+                 reasons.append("expect values to have last dim contiguous")
+
+             q_starts = attn_bias.q_seqinfo.seqstart_py
+             padding = attn_bias.k_seqinfo.padding
+             bsz = d.key.shape[1] // padding
+             num_queries = d.query.shape[1] // bsz
+
+             if q_starts != list(range(0, 1 + bsz, num_queries)):
+                 reasons.append("expect to have same num_queries in each batch")
+             if bsz != len(q_starts) - 1:
+                 reasons.append("empty lanes not supported yet")
+
+             if attn_bias.k_seqinfo.padding > 8192:
+                 reasons.append("key padding exceeds 8192")
+
+         return reasons
+
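The head-dim check in `not_supported_reasons` above boils down to a small piece of arithmetic over the wavefront size. Restated as a standalone helper for illustration (`required_alignment` below is not an API of the package, just a re-derivation of the loop above with threads_per_warp=64):

```python
# Same rule as in not_supported_reasons: pick the smallest vector width
# whose warp-wide footprint still covers head_dim; 0 means the head dim
# is too large for this kernel.
threads_per_warp = 64

def required_alignment(head_dim: int) -> int:
    alignment = 0
    for vec_size in (4, 2, 1):
        if head_dim <= vec_size * threads_per_warp:
            alignment = vec_size
    return alignment

for d in (64, 128, 192, 256, 300):
    print(d, required_alignment(d))
# 64 -> 1, 128 -> 2, 192 -> 4, 256 -> 4, 300 -> 0 (rejected)
```

In other words, head_dim must not exceed 256 and must be divisible by the alignment the loop selects.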
+     @classmethod
+     def apply(
+         cls, inp: Inputs, needs_gradient: bool
+     ) -> Tuple[torch.Tensor, Optional[Context]]:
+         if needs_gradient:
+             raise NotImplementedError("backward pass is not supported")
+         attn_bias = inp.attn_bias
+         q, k, v = inp.get_qkv_in_bmghk()
+         if attn_bias is not None:
+             assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+             attn_bias.k_seqinfo.to(k.device)
+             attn_bias.q_seqinfo.to(q.device)
+             padding = attn_bias.k_seqinfo.padding
+             seq_positions_gpu = attn_bias.k_seqinfo.seqlen
+         else:
+             padding = k.shape[1]
+             seq_positions_gpu = None
+
+         if attn_bias is not None:
+             # key: (1, B * padding, G, 1 if multiquery else Hkv, D)
+             # value: like key
+             # query: (1, B * q_seqlen, G, Hq, D)
+             multiquery = k.stride(3) == 0
+             if multiquery:
+                 key = k[0, :, :, :1].unflatten(0, (-1, padding))
+                 value = v[0, :, :, :1].unflatten(0, (-1, padding))
+             else:
+                 key = k[0].unflatten(0, (-1, padding))
+                 value = v[0].unflatten(0, (-1, padding))
+             query = q[0].unflatten(0, (key.shape[0], -1))
+         else:
+             # key: (B, padding, G, 1 if multiquery else Hkv, D)
+             # value: like key
+             # query: (B, q_seqlen, G, Hq, D)
+             key = k
+             query = q
+             value = v
+
+         if inp.scale is not None:
+             qk_scale = inp.scale
+         else:
+             qk_scale = torch.rsqrt(
+                 torch.tensor(key.shape[-1], dtype=torch.float32)
+             ).item()
+
+         out = cls.OPERATOR(
+             query=query,
+             key=key,
+             value=value,
+             seq_positions=seq_positions_gpu,
+             scale=qk_scale,
+         )
+         return out, None
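Two details of the `apply` method above, reproduced as a standalone sketch with made-up shapes (plain PyTorch, not code from the package): the default softmax scale of 1/sqrt(head_dim) used when `inp.scale` is None, and the `unflatten` that splits the packed `(1, B * padding, G, H, D)` key layout back into per-sequence pages:

```python
import math

import torch

# Default softmax scale: 1 / sqrt(head_dim), computed via torch.rsqrt on an
# fp32 scalar, exactly as in the else-branch above.
head_dim = 256  # hypothetical; this op supports head_dim <= 256
qk_scale = torch.rsqrt(torch.tensor(head_dim, dtype=torch.float32)).item()
assert math.isclose(qk_scale, 1.0 / math.sqrt(head_dim), rel_tol=1e-6)
print(qk_scale)   # 0.0625

# Packed key layout (1, B * padding, G, H, D) split back into per-sequence
# pages with unflatten, as in the non-multiquery branch above.
B, padding, G, H, D = 3, 8, 1, 2, 16
k = torch.randn(1, B * padding, G, H, D)
key = k[0].unflatten(0, (-1, padding))
print(key.shape)  # torch.Size([3, 8, 1, 2, 16])
```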