mslk_cuda_nightly-2026.1.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the contents of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/cutlass.py
@@ -0,0 +1,461 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-unsafe
+
+ from dataclasses import replace
+ from enum import Enum
+ from functools import partial
+ from typing import Any, Iterable, List, Optional, Set, Tuple, Union
+
+ import torch
+
+ from . import attn_bias
+ from .attn_bias import (
+     AttentionBias,
+     BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+     BlockDiagonalCausalLocalAttentionMask,
+     BlockDiagonalCausalMask,
+     BlockDiagonalCausalWithOffsetPaddedKeysMask,
+     BlockDiagonalMask,
+     LowerTriangularFromBottomRightLocalAttentionMask,
+     LowerTriangularFromBottomRightMask,
+     LowerTriangularMask,
+     LowerTriangularMaskWithTensorBias,
+ )
+ from .common import (
+     _attn_bias_apply,
+     AttentionBwOpBase,
+     AttentionFwOpBase,
+     check_lastdim_alignment_stride1,
+     Context,
+     Gradients,
+     Inputs,
+ )
+ from .torch_attention_compat import is_pt_cutlass_compatible
+ from .utils.op_common import get_operator, register_operator
+
+
+ def _uses_tensorcores(sm: int, is_half: bool) -> bool:
+     if sm >= 80:
+         return True
+     if sm >= 70:
+         return is_half
+     return False
+
+
+ def _minimum_gemm_alignment(inp: Inputs) -> int:
+     if inp.device.type != "cuda":
+         return 1
+     cap = torch.cuda.get_device_capability(inp.device)
+     sm = cap[0] * 10 + cap[1]
+     bits_per_scalar = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}[
+         inp.query.dtype
+     ]
+     uses_tensorcores = _uses_tensorcores(sm, bits_per_scalar == 16)
+     matmul_alignment_mn = 1
+     if sm >= 80:
+         matmul_alignment_mn = 4
+     if uses_tensorcores:
+         matmul_alignment_mn = max(matmul_alignment_mn, 128 // bits_per_scalar)
+     return matmul_alignment_mn
+
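To make the alignment rule above concrete, here is a minimal standalone sketch (plain Python, no mslk import) that mirrors the same arithmetic; the SM/dtype pairs are illustrative, since the real op reads them from Inputs.

# Standalone sketch mirroring _uses_tensorcores / _minimum_gemm_alignment above.
def uses_tensorcores(sm: int, is_half: bool) -> bool:
    if sm >= 80:
        return True
    if sm >= 70:
        return is_half
    return False

def minimum_gemm_alignment(sm: int, bits_per_scalar: int) -> int:
    align = 4 if sm >= 80 else 1
    if uses_tensorcores(sm, bits_per_scalar == 16):
        align = max(align, 128 // bits_per_scalar)
    return align

# SM80 (e.g. A100), fp16/bf16: 128 // 16 = 8 elements
assert minimum_gemm_alignment(80, 16) == 8
# SM80, fp32: tensor cores still apply, 128 // 32 = 4 elements
assert minimum_gemm_alignment(80, 32) == 4
# SM70 (e.g. V100), fp32: no tensor cores, alignment of 1
assert minimum_gemm_alignment(70, 32) == 1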
+
+ def _get_seqlen_info(
+     inp: Inputs,
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]:
+     attn_bias = inp.attn_bias
+     if isinstance(
+         attn_bias, (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+     ):
+         assert attn_bias.k_seqinfo.seqstart.device == inp.query.device
+         seqstart_k = attn_bias.k_seqinfo.seqstart
+         seqstart_q = attn_bias.q_seqinfo.seqstart
+         max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
+         max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
+     else:
+         seqstart_k = None
+         seqstart_q = None
+         max_seqlen_q = -1
+         max_seqlen_k = -1
+
+     return seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k
+
+
+ def _get_tensor_bias(
+     attn_bias: Optional[Union[torch.Tensor, AttentionBias]],
+ ) -> Optional[torch.Tensor]:
+     if isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
+         return attn_bias._bias
+     if isinstance(attn_bias, torch.Tensor):
+         return attn_bias
+     return None
+
+
+ def _check_bias_alignment(
+     reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
+ ) -> None:
+     attn_bias_tensor = _get_tensor_bias(attn_bias)
+     if attn_bias_tensor is not None:
+         alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits
+         show_padding_hint = False
+         for d in range(attn_bias_tensor.ndim - 1):
+             if attn_bias_tensor.stride(d) % alignment != 0:
+                 reasons.append(
+                     f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})"
+                 )
+                 show_padding_hint = True
+         if show_padding_hint:
+             reasons.append(
+                 """\
+ HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \
+ you need to ensure memory is aligned by slicing a bigger tensor. \
+ Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`"""
+             )
+         # We can have stride=0 sometimes if dimension=1
+         if attn_bias_tensor.stride(-1) > 1:
+             reasons.append(
+                 f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - "
+                 "you should call `.contiguous()` on the bias"
+             )
+
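The HINT above can be verified directly from strides. A small torch-only sketch (CPU is fine) showing why the sliced, padded bias passes the fp16 alignment check (128 // 16 = 8) while the naive allocation does not:

import torch

# fp16 bias: the check wants every leading stride to be a multiple of 128 // 16 = 8
naive = torch.zeros([1, 1, 5, 5], dtype=torch.half)
padded = torch.zeros([1, 1, 5, 8], dtype=torch.half)[:, :, :, :5]

print(naive.stride())   # (25, 25, 5, 1) -> 5 is not a multiple of 8: rejected
print(padded.stride())  # (40, 40, 8, 1) -> all leading strides divisible by 8: accepted
print(padded.shape)     # torch.Size([1, 1, 5, 5]) -- same logical shape as `naive`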
+
+ class _CustomMaskType(int, Enum):
+     """
+     (Matches CustomMaskType in C++.)
+     """
+
+     NoCustomMask = 0
+     CausalFromTopLeft = 1
+     CausalFromBottomRight = 2
+
+
+ def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
+     if isinstance(
+         bias,
+         (
+             LowerTriangularMask,
+             BlockDiagonalCausalMask,
+             BlockDiagonalCausalLocalAttentionMask,
+         ),
+     ):
+         return int(_CustomMaskType.CausalFromTopLeft)
+     if isinstance(
+         bias,
+         (
+             LowerTriangularFromBottomRightMask,
+             LowerTriangularFromBottomRightLocalAttentionMask,
+             attn_bias.BlockDiagonalCausalFromBottomRightMask,
+             BlockDiagonalCausalWithOffsetPaddedKeysMask,
+             BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+         ),
+     ):
+         return int(_CustomMaskType.CausalFromBottomRight)
+     return int(_CustomMaskType.NoCustomMask)
+
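Assuming the wheel exposes this file as mslk.attention.fmha.cutlass (it is the only +461-line module in the listing above) and that the listed masks keep their no-argument constructors, the mapping can be exercised as below; this snippet is an illustration, not part of the package.

import torch
# Assumed import paths, inferred from the file listing above.
from mslk.attention.fmha import attn_bias
from mslk.attention.fmha.cutlass import _custom_mask_type

# No bias or a plain tensor bias -> NoCustomMask (0)
assert _custom_mask_type(None) == 0
assert _custom_mask_type(torch.zeros(1, 1, 4, 4)) == 0
# Top-left-anchored causal masks -> CausalFromTopLeft (1)
assert _custom_mask_type(attn_bias.LowerTriangularMask()) == 1
# Bottom-right-anchored causal masks -> CausalFromBottomRight (2)
assert _custom_mask_type(attn_bias.LowerTriangularFromBottomRightMask()) == 2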
+
+ @register_operator
+ class FwOp(AttentionFwOpBase):
+     """xFormers' MHA kernel based on CUTLASS.
+     Supports a large number of settings (including without TensorCores, f32 ...)
+     and GPUs as old as P100 (Sm60)
+     """
+
+     OPERATOR = (
+         get_operator("aten", "_efficient_attention_forward")
+         if is_pt_cutlass_compatible()
+         else None
+     )
+     CUDA_MAXIMUM_COMPUTE_CAPABILITY = (9, 0)
+     SUPPORTED_DEVICES: Set[str] = {"cuda"}
+     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16}
+     SUPPORTED_MAX_K = 65536
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         torch.Tensor,
+         LowerTriangularMask,
+         LowerTriangularFromBottomRightMask,
+         LowerTriangularFromBottomRightLocalAttentionMask,
+         LowerTriangularMaskWithTensorBias,
+         BlockDiagonalMask,
+         BlockDiagonalCausalMask,
+         BlockDiagonalCausalWithOffsetPaddedKeysMask,
+         attn_bias.BlockDiagonalCausalFromBottomRightMask,
+         attn_bias.BlockDiagonalCausalLocalAttentionMask,
+         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+     )
+     SUPPORTS_DROPOUT = True
+     SUPPORTS_CUSTOM_SCALE = True
+     SUPPORTS_DIFFERENT_VALUE_EMBED = True
+     SUPPORTS_BMGHK = True
+     VARLEN_LSE_PACKED = False
+     NAME = "cutlassF-pt"
+
+     _TEST_K: List[int] = [
+         32,  # 64x64 kernel
+         128,  # 64x128 kernel
+         256,  # 64x128 with accumulation in gmem
+     ]
+
+     @classmethod
+     def apply(
+         cls, inp: Inputs, needs_gradient: bool
+     ) -> Tuple[torch.Tensor, Optional[Context]]:
+         if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
+             raise NotImplementedError("Unsupported attn_bias type")
+         if inp.query.ndim in [3, 4]:
+             return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
+         assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"
+         ctx: Optional[Context] = None
+         # XXX: Hackfix for BMGHK with H=1
+         # In that case we don't want to run G different streams because it adds
+         # some overhead
+         if inp.query.ndim == 5 and inp.query.shape[3] == 1:
+             slice_op = partial(torch.squeeze, dim=3)
+             inp = replace(
+                 inp,
+                 query=slice_op(inp.query),
+                 key=slice_op(inp.key),
+                 value=slice_op(inp.value),
+                 attn_bias=_attn_bias_apply(
+                     inp.attn_bias, partial(torch.squeeze, dim=2)
+                 ),
+             )
+             out, ctx = cls.apply_bmhk(inp, needs_gradient=needs_gradient)
+             out = out.unsqueeze(3)
+             if ctx is not None:
+                 ctx = replace(ctx, lse=ctx.lse.unsqueeze(1), out=out)
+             return out, ctx
+
+         # Workaround until this is properly implemented in C++
+         # run each head group in a different stream
+         n_groups = inp.key.shape[2]
+         main_stream = torch.cuda.current_stream()
+         streams = [main_stream] + [
+             torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1)
+         ]
+         outs = []
+         for group, stream in enumerate(streams):
+             stream.wait_stream(main_stream)
+             with torch.cuda.stream(stream):
+                 query = inp.query[:, :, group]
+                 key = inp.key[:, :, group]
+                 value = inp.value[:, :, group]
+                 bias = _attn_bias_apply(
+                     inp.attn_bias, partial(torch.select, dim=1, index=group)
+                 )
+                 outs.append(
+                     cls.apply_bmhk(
+                         replace(inp, query=query, key=key, value=value, attn_bias=bias),
+                         needs_gradient=needs_gradient,
+                     )
+                 )
+         for s in streams[1:]:
+             main_stream.wait_stream(s)
+         out = torch.stack([o[0] for o in outs], dim=2)
+         if needs_gradient:
+             ctx = Context(
+                 out=out,
+                 lse=torch.stack([o[1].lse for o in outs], dim=1),  # type: ignore
+                 op_bw=outs[0][1].op_bw,  # type: ignore
+             )
+         return out, ctx
+
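The head-group fan-out in apply() is a standard fork/join pattern over CUDA streams. Below is a self-contained sketch of the same structure with a dummy per-group workload (nothing mslk-specific; it only runs when a GPU is present):

import torch

def run_groups_in_streams(xs):
    # Fork: give every group after the first its own stream, and make each
    # stream wait on the work already queued on the main stream.
    main = torch.cuda.current_stream()
    streams = [main] + [torch.cuda.Stream() for _ in xs[1:]]
    outs = []
    for x, stream in zip(xs, streams):
        stream.wait_stream(main)
        with torch.cuda.stream(stream):
            outs.append(x * 2.0)  # stand-in for the per-group cls.apply_bmhk(...) call
    # Join: the main stream waits for the side streams before using the results.
    for s in streams[1:]:
        main.wait_stream(s)
    return torch.stack(outs, dim=0)

if torch.cuda.is_available():
    groups = [torch.randn(8, 8, device="cuda") for _ in range(3)]
    print(run_groups_in_streams(groups).shape)  # torch.Size([3, 8, 8])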
+     @classmethod
+     def apply_bmhk(
+         cls, inp: Inputs, needs_gradient: bool
+     ) -> Tuple[torch.Tensor, Optional[Context]]:
+         if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
+             raise NotImplementedError("Unsupported attn_bias type")
+         seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp)
+         out, lse, rng_seed, rng_offset, _, _ = cls.OPERATOR(
+             query=inp.query,
+             key=inp.key,
+             value=inp.value,
+             bias=_get_tensor_bias(inp.attn_bias),
+             cu_seqlens_q=seqstart_q,
+             cu_seqlens_k=seqstart_k,
+             max_seqlen_q=max_seqlen_q,
+             max_seqlen_k=max_seqlen_k,
+             dropout_p=inp.p,
+             compute_log_sumexp=needs_gradient,
+             custom_mask_type=_custom_mask_type(inp.attn_bias),
+             scale=inp.scale,
+             seqlen_k=(
+                 inp.attn_bias.k_seqinfo.seqlen
+                 if isinstance(
+                     inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask
+                 )
+                 else None
+             ),
+             window_size=(
+                 inp.attn_bias._window_size
+                 if isinstance(
+                     inp.attn_bias,
+                     (
+                         BlockDiagonalCausalLocalAttentionMask,
+                         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+                         LowerTriangularFromBottomRightLocalAttentionMask,
+                     ),
+                 )
+                 else None
+             ),
+         )
+         ctx: Optional[Context] = None
+         if needs_gradient:
+             ctx = Context(out=out, lse=lse)
+             if inp.p != 0:
+                 # cutlass forward is only compatible with cutlass backward if
+                 # dropout is used (because of the way RNG states are passed and the
+                 # way random numbers are generated during backward)
+                 ctx.rng_state = (rng_seed, rng_offset)
+                 ctx.op_bw = BwOp
+         return out, ctx
+
+     @classmethod
+     def not_supported_reasons(cls, d: Inputs) -> List[str]:
+         reasons = super(FwOp, cls).not_supported_reasons(d)
+         matmul_alignment_mn = _minimum_gemm_alignment(d)
+         check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
+         check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
+         _check_bias_alignment(reasons, d.attn_bias)
+         return reasons
+
+
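not_supported_reasons() rejects layouts whose innermost stride is not 1 or whose alignment falls short of the GEMM requirement; check_lastdim_alignment_stride1 itself lives in .common and is not shown in this diff. A rough torch-only illustration of the kind of layout those checks presumably guard against:

import torch

q = torch.randn(2, 128, 8, 64, dtype=torch.half)  # B, M, H, K
print(q.stride(-1))      # 1 -> last dim is contiguous: OK
print(q.stride(-2) % 8)  # 0 -> head stride is a multiple of 8 half-precision elements: OK

# Slicing the embedding dimension breaks last-dim contiguity:
q_bad = q[:, :, :, ::2]
print(q_bad.stride(-1))  # 2 -> stride(-1) != 1, would be reported as unsupported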
+ @register_operator
+ class BwOp(AttentionBwOpBase):
+     __doc__ = FwOp.__doc__
+
+     OPERATOR = (
+         get_operator("aten", "_efficient_attention_backward")
+         if is_pt_cutlass_compatible()
+         else None
+     )
+     CUDA_MAXIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MAXIMUM_COMPUTE_CAPABILITY
+     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
+     SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
+     SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
+     SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+         type(None),
+         torch.Tensor,
+         LowerTriangularMask,
+         LowerTriangularFromBottomRightMask,
+         # TODO: Still some infs/nans in the BW pass for
+         # local + causal
+         # LowerTriangularFromBottomRightLocalAttentionMask,
+         # TODO: Fix handling of gradient through the fMHA autograd function
+         # LowerTriangularMaskWithTensorBias,
+         BlockDiagonalMask,
+         BlockDiagonalCausalMask,
+         attn_bias.BlockDiagonalCausalFromBottomRightMask,
+         attn_bias.BlockDiagonalCausalLocalAttentionMask,
+     )
+     SUPPORTS_ATTN_BIAS_GRAD = True
+     SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
+     SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
+     SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
+     VARLEN_LSE_PACKED = False
+     NAME = "cutlassB-pt"
+
+     _TEST_K: List[int] = [
+         32,  # 64x64 kernel
+         128,  # 64x128/128x128 kernel
+         256,  # 64x128 with accumulation in gmem
+     ]
+
+     @classmethod
+     def not_supported_reasons(cls, d: Inputs) -> List[str]:
+         reasons = super(BwOp, cls).not_supported_reasons(d)
+         matmul_alignment_mn = _minimum_gemm_alignment(d)
+
+         check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
+         check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn)
+         check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
+         _check_bias_alignment(reasons, d.attn_bias)
+         attn_bias_tensor = _get_tensor_bias(d.attn_bias)
+
+         # Backprop of gradient through broadcasted bias is not supported
+         if attn_bias_tensor is not None and attn_bias_tensor.requires_grad:
+             # Don't forget that inputs are either in BMK or BMHK!
+             if d.query.ndim == 3 and attn_bias_tensor.ndim == 3:
+                 expected_bias_shape = (*d.query.shape[:2], d.key.shape[1])
+             else:
+                 # bias is B H Mq Mk
+                 expected_bias_shape = (
+                     d.query.shape[0],
+                     d.query.shape[2] if d.query.ndim == 4 else 1,
+                     d.query.shape[1],
+                     d.key.shape[1],
+                 )
+             if tuple(attn_bias_tensor.shape) != expected_bias_shape:
+                 reasons.append(
+                     "Broadcasting the `attn_bias` tensor is not supported "
+                     f"(shape: {tuple(attn_bias_tensor.shape)}"
+                     f"/ expected: {expected_bias_shape})"
+                 )
+         return reasons
+
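The expected_bias_shape rule above can be restated with concrete tensors; this sketch recomputes the same tuple for BMHK inputs (the shapes are illustrative):

import torch

# BMHK inputs: query is (B, Mq, H, K), key is (B, Mk, H, K)
B, Mq, Mk, H, K = 2, 128, 256, 8, 64
query = torch.randn(B, Mq, H, K)
key = torch.randn(B, Mk, H, K)

# A bias that requires grad must be exactly (B, H, Mq, Mk) -- no broadcasting
expected_bias_shape = (query.shape[0], query.shape[2], query.shape[1], key.shape[1])
print(expected_bias_shape)  # (2, 8, 128, 256)

bias_ok = torch.zeros(B, H, Mq, Mk, requires_grad=True)
bias_broadcast = torch.zeros(1, 1, Mq, Mk, requires_grad=True)  # would be rejected
print(tuple(bias_ok.shape) == expected_bias_shape)         # True
print(tuple(bias_broadcast.shape) == expected_bias_shape)  # False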
+     @classmethod
+     def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
+         if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES:
+             raise NotImplementedError("Unsupported attn_bias type")
+
+         seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp)
+         dtype = inp.query.dtype
+
+         rng_seed = rng_offset = torch.Tensor()
+         if inp.p != 0.0:
+             assert ctx.rng_state is not None
+             rng_seed, rng_offset = ctx.rng_state
+         tensor_bias = _get_tensor_bias(inp.attn_bias)
+
+         force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5)
+         (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR(
+             grad.to(dtype),
+             inp.query,
+             inp.key,
+             inp.value,
+             bias=tensor_bias,
+             bias_requires_grad=(
+                 tensor_bias.requires_grad if tensor_bias is not None else False
+             ),
+             cu_seqlens_q=seqstart_q,
+             cu_seqlens_k=seqstart_k,
+             max_seqlen_q=max_seqlen_q,
+             max_seqlen_k=max_seqlen_k,
+             logsumexp=ctx.get_padded_lse(32, force_pad_inf=force_pad_inf),
+             out=ctx.out.to(dtype),
+             dropout_p=inp.p,
+             # if not using dropout, seed and offset are irrelevant but still expected
+             # in function signature so just pass 0
+             # seed and offset could be None if a different FW op other than cutlass
+             # was used.
+             philox_seed=rng_seed,
+             philox_offset=rng_offset,
+             custom_mask_type=_custom_mask_type(inp.attn_bias),
+             scale=inp.scale,
+             num_splits_key=None,  # Let C++ determine it
+             window_size=(
+                 inp.attn_bias._window_size
+                 if isinstance(
+                     inp.attn_bias,
+                     (
+                         BlockDiagonalCausalLocalAttentionMask,
+                         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+                         LowerTriangularFromBottomRightLocalAttentionMask,
+                     ),
+                 )
+                 else None
+             ),
+         )
+
+         # c++/CUDA implementation returns an uninitialized tensor if bias doesn't
+         # require grad
+         if not (
+             isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad
+         ):
+             grad_bias = None
+
+         return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)