mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/triton_splitk.py
@@ -0,0 +1,1378 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-unsafe
7
+ import functools
8
+ import sys
9
+ from dataclasses import dataclass
10
+ from typing import (
11
+ Any,
12
+ cast,
13
+ Dict,
14
+ Iterable,
15
+ List,
16
+ Optional,
17
+ Sequence,
18
+ Tuple,
19
+ Type,
20
+ TYPE_CHECKING,
21
+ Union,
22
+ )
23
+
24
+ import torch
25
+
26
+ from ._triton.available import is_triton_available
27
+ from .attn_bias import (
28
+ BlockDiagonalCausalLocalAttentionPaddedKeysMask,
29
+ BlockDiagonalCausalWithOffsetGappyKeysMask,
30
+ BlockDiagonalCausalWithOffsetPaddedKeysMask,
31
+ BlockDiagonalGappyKeysMask,
32
+ BlockDiagonalLocalAttentionPaddedKeysMask,
33
+ BlockDiagonalPaddedKeysMask,
34
+ PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
35
+ PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
36
+ PagedBlockDiagonalGappyKeysMask,
37
+ PagedBlockDiagonalPaddedKeysMask,
38
+ )
39
+ from .common import AttentionFwOpBase, check_lastdim_alignment_stride1, Context, Inputs
40
+ from .utils.op_common import register_operator
41
+
42
+
43
+ def _strides(x: Optional[torch.Tensor], *stride_names: str):
44
+ if x is None:
45
+ return {f"stride_{name}": None for name in stride_names}
46
+ assert x.ndim == len(stride_names)
47
+ return {f"stride_{name}": s for name, s in zip(stride_names, x.stride())}
48
+
49
+
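As a quick illustration of what the _strides helper above produces (a minimal sketch; the tensor shape is arbitrary and the snippet assumes it runs in this module's namespace):

    import torch

    x = torch.empty(2, 3, 4)           # contiguous, so x.stride() == (12, 4, 1)
    _strides(x, "qz", "qm", "qk")      # {"stride_qz": 12, "stride_qm": 4, "stride_qk": 1}
    _strides(None, "qz", "qm", "qk")   # {"stride_qz": None, "stride_qm": None, "stride_qk": None}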
50
+ def _is_supported_causal_bias(attn_bias: Any) -> bool:
51
+ return isinstance(
52
+ attn_bias,
53
+ (
54
+ BlockDiagonalCausalWithOffsetPaddedKeysMask,
55
+ BlockDiagonalCausalWithOffsetGappyKeysMask,
56
+ BlockDiagonalCausalLocalAttentionPaddedKeysMask,
57
+ PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
58
+ PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
59
+ ),
60
+ )
61
+
62
+
63
+ def _is_supported_local_bias(attn_bias: Any) -> bool:
64
+ return isinstance(
65
+ attn_bias,
66
+ (
67
+ BlockDiagonalCausalLocalAttentionPaddedKeysMask,
68
+ BlockDiagonalLocalAttentionPaddedKeysMask,
69
+ ),
70
+ )
71
+
72
+
73
+ def _is_supported_gappy_bias(attn_bias: Any) -> bool:
74
+ return isinstance(
75
+ attn_bias,
76
+ (
77
+ BlockDiagonalGappyKeysMask,
78
+ PagedBlockDiagonalGappyKeysMask,
79
+ ),
80
+ )
81
+
82
+
83
+ def _is_supported_paged_bias(attn_bias: Any) -> bool:
84
+ return isinstance(
85
+ attn_bias,
86
+ (
87
+ PagedBlockDiagonalGappyKeysMask,
88
+ PagedBlockDiagonalPaddedKeysMask,
89
+ ),
90
+ )
91
+
92
+
93
+ @dataclass
94
+ class InputsFp8(Inputs):
95
+ """
96
+ Each of k/v_fp8_scale_shift is an int32 tensor of shape (1, B * Mkv, Hq),
97
+ or (1, page_size * max_pages_per_lane, Hq) in the paged case.
98
+ Each int32 element contains two packed fp16 numbers
99
+ - the scale and shift for row-wise FP8 quantization.
100
+ """
101
+
102
+ k_fp8_scale_shift: Optional[torch.Tensor] = None
103
+ v_fp8_scale_shift: Optional[torch.Tensor] = None
104
+ q_fp8_scale_shift: Optional[torch.Tensor] = None
105
+ quantize_pv_to_fp8: bool = False
106
+ quantize_qk_to_fp8: bool = False
107
+
108
+ @property
109
+ def nbytes(self) -> int:
110
+ """
111
+ Number of bytes in the input, not counting the attention bias.
112
+ """
113
+ return (
114
+ super(InputsFp8, self).nbytes
115
+ + (
116
+ self.k_fp8_scale_shift.untyped_storage().nbytes()
117
+ if self.k_fp8_scale_shift is not None
118
+ else 0
119
+ )
120
+ + (
121
+ self.v_fp8_scale_shift.untyped_storage().nbytes()
122
+ if self.v_fp8_scale_shift is not None
123
+ else 0
124
+ )
125
+ )
126
+
127
+
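To make the packed-coefficient format above concrete, here is a minimal sketch of reinterpreting a (scale, shift) fp16 pair as one int32 element; the ordering of scale vs. shift inside the pair is an assumption for illustration, not something this file specifies:

    import torch

    scale = torch.tensor([0.02], dtype=torch.float16)
    shift = torch.tensor([1.50], dtype=torch.float16)
    packed = torch.stack([scale, shift], dim=-1).view(torch.int32)   # shape (1, 1), dtype int32
    unpacked = packed.view(torch.float16)                            # shape (1, 2): back to the fp16 pair
    assert unpacked[0, 0] == scale[0] and unpacked[0, 1] == shift[0]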
128
+ if TYPE_CHECKING or is_triton_available():
129
+ from ._triton.splitk_kernels import _fwd_kernel_splitK, _splitK_reduce
130
+ else:
131
+ _fwd_kernel_splitK = None
132
+ _splitK_reduce = None
133
+
134
+
135
+ def _is_cuda() -> bool:
136
+ return torch.version.cuda is not None
137
+
138
+
139
+ def _is_cuda_at_least_sm80(device: torch.device) -> bool:
140
+ return _is_cuda() and torch.cuda.get_device_capability(device) >= (
141
+ 8,
142
+ 0,
143
+ )
144
+
145
+
146
+ @register_operator
147
+ class FwOp(AttentionFwOpBase):
148
+ """Flash-Attention with Split-K. Supports fused int4 and fp8 K/V quantization.
149
+ The quantized path is taken if the input K/V have dtype int32.
150
+
151
+ Int4 quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along
152
+ the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported.
153
+ Quantization coefficients (scale and shift) are represented as two
154
+ float16 constants per group, packed into int32. Quantization coefficients of
155
+ all groups are placed at the beginning of the row. So, if unquantized K/V have head
156
+ dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS
157
+ and dtype int32.
158
+ Pseudocode for dequantizing one row can look like:
159
+ group_size = D // 8
160
+ for i in range(NUM_GROUPS):
161
+ group_start = NUM_GROUPS + i * group_size
162
+ group_quant = K[..., group_start: group_start + group_size]
163
+ scale, shift = unpack_int32_into_float16x2(group_quant[0])
164
+ group_dequant = group_quant[..., 1:] * scale + shift
165
+ ...
166
+
167
+ For fp8, only row-wise quantization is supported. To use it, provide input of type
168
+ xformers.ops.fmha.triton_splitk.InputsFp8 (instead of the usual xformers.ops.fmha.Inputs) to
169
+ xformers.ops.fmha.triton_splitk.FwOp.apply or xformers.ops.fmha._memory_efficient_attention_forward.
170
+
171
+ This op uses Paged Attention when bias is one of the Paged* classes.
172
+ In this case bias has additional fields:
173
+ - block_tables of shape [batch_size, max_num_pages]
174
+ - K/V of shape [1, max_num_pages * page_size, num_heads, head_dim]
175
+ or [1, max_num_pages * page_size, num_groups, num_heads, head_dim]
176
+
177
+ The shape in which the kernel takes the queries and the output
178
+ is quite different from the user interface. There are three
179
+ types of input: (a) no bias / tensor bias, (b) variable q_len
180
+ (which is only for the non-causal case) and (c) other bias objects.
181
+ From the interface to the kernel the following changes happen.
182
+
183
+ (0) In all cases, a group dimension may need to be added.
184
+
185
+ (1) For (c), a batch dimension is created, reshaping from (1, B*Mq, G, Hq, K)
186
+ to (B, Mq, G, Hq, K)
187
+
188
+ (2) For (a) and (c), in the case of multiquery (i.e. the head dimension
189
+ of keys and values is expanded), the head-swapping trick is applied,
190
+ reshaping from (B, Mq, G, Hq, K) to (B, M=Hq*Mq, G, H=1, K)
191
+
192
+ (3) For (b), in the case of multiquery, the head-swapping trick is
193
+ applied as well, reshaping from (1, Mq, G, Hq, K) to (1, Mq*Hq, G, H=1, K)
194
+ Note here that Mq is a single long dimension which spans all the queries
195
+ in the batch, unlike in case (c). Note also that Hq has to run faster than
196
+ Mq so that the queries in a batch element remain evenly spaced.
197
+
198
+ In all cases, the shape as seen by the kernel is called (Bqq, Mqq, G, H, K).
199
+ The kernel operates on B batch elements and M queries per batch element.
200
+ """
201
+
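As a worked example of the int4 layout described in the docstring above (the numbers follow directly from the stated D // 8 + NUM_GROUPS rule):

    # Unquantized head dim D = 128, row-wise quantization (NUM_GROUPS = 1):
    D, NUM_GROUPS = 128, 1
    packed_values = D // 8                              # 16 int32s, eight int4 values per int32
    coefficients = NUM_GROUPS                           # 1 int32 holding a packed (scale, shift) fp16 pair
    quantized_head_dim = packed_values + coefficients   # 17, and the tensor dtype is int32
    # With NUM_GROUPS = 4 the same row would carry 4 coefficient int32s, giving 16 + 4 = 20.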
202
+ OPERATOR = True
203
+ SUPPORTED_DEVICES = {"cuda"}
204
+ CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
205
+ SUPPORTED_DTYPES = {
206
+ torch.half,
207
+ torch.bfloat16,
208
+ torch.float8_e4m3fn,
209
+ } # Those are dtypes of Q. In the quantized case K/V has dtype int32
210
+ SUPPORTED_MAX_K = 512
211
+ SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
212
+ type(None),
213
+ torch.Tensor,
214
+ BlockDiagonalCausalWithOffsetPaddedKeysMask,
215
+ BlockDiagonalCausalLocalAttentionPaddedKeysMask,
216
+ BlockDiagonalLocalAttentionPaddedKeysMask,
217
+ BlockDiagonalGappyKeysMask,
218
+ BlockDiagonalCausalWithOffsetGappyKeysMask,
219
+ BlockDiagonalPaddedKeysMask,
220
+ PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
221
+ PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
222
+ PagedBlockDiagonalGappyKeysMask,
223
+ PagedBlockDiagonalPaddedKeysMask,
224
+ )
225
+ SUPPORTS_DROPOUT = False
226
+ SUPPORTS_CUSTOM_SCALE = True
227
+ SUPPORTS_BMGHK = True
228
+ SUPPORTS_OUTPUT_DTYPE = True
229
+ SUPPORTS_PARTIAL = True
230
+ NAME = "triton_splitKF"
231
+
232
+ SPLIT_K: Optional[int] = None
233
+ MAX_BLOCK_M = 32
234
+
235
+ # Whether blocks attending to no part of a variable sequence length
236
+ # should exit early. This requires extra kernels to run beforehand
237
+ # to initialise the outputs.
238
+ # TODO: avoid these by making the reduce kernel work out it doesn't need
239
+ # to look at the irrelevant places.
240
+ SPLIT_K_EARLY_EXIT: bool = False
241
+
242
+ # Perform kernel-level Triton autotune
243
+ AUTOTUNE = False
244
+
245
+ NUM_GROUPS = 1 # Default quantization is row-wise
246
+ NUM_GROUPS_VALUES = [1, 2, 4, 8]
247
+
248
+ # Values below are used when autotune=False.
249
+ # Note that under certain conditions different values might be used, see the code just before the kernel launch.
250
+ BLOCK_M: int = 16 # When M > 1, different BLOCK_M can be used.
251
+ BLOCK_N: int = 64
252
+ # On AMD or for M > 1 different NUM_STAGES and NUM_WARPS can be used.
253
+ NUM_STAGES: int = 1
254
+ NUM_WARPS: int = 2
255
+
256
+ @classmethod
257
+ def shape_not_supported_reasons(
258
+ cls, Mq: int, Mkv: int, K: int, Kv: int
259
+ ) -> List[str]:
260
+ reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
261
+ if K not in {16, 32, 64, 128, 256, 512}:
262
+ reasons.append(f"Embed dim {K} not supported")
263
+ return reasons
264
+
265
+ @classmethod
266
+ def not_supported_reasons(cls, d: Inputs) -> List[str]: # noqa: C901
267
+ reasons = super(FwOp, cls).not_supported_reasons(d)
268
+ if (sys.version_info.major, sys.version_info.minor) < (3, 9):
269
+ reasons.append("triton_splitk requires python 3.9 or above!")
270
+ check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
271
+ if d.key.dtype != torch.int32:
272
+ check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
273
+ check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
274
+ if cls.OPERATOR is None:
275
+ reasons.append("triton is not available")
276
+ if d.device.type == "cuda":
277
+ # Has only been tested on 8.0 / 9.0.
278
+ if _is_cuda() and not _is_cuda_at_least_sm80(d.device):
279
+ reasons.append(
280
+ "requires NVidia GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
281
+ )
282
+ # TODO: AMD GPU support matrix needs to be figured out. MI300X is tested to work.
283
+
284
+ q_len = d.query.shape[1]
285
+ is_block_diagonal = isinstance(
286
+ d.attn_bias, (BlockDiagonalPaddedKeysMask, BlockDiagonalGappyKeysMask)
287
+ )
288
+ is_paged = _is_supported_paged_bias(d.attn_bias)
289
+ is_causal = _is_supported_causal_bias(d.attn_bias)
290
+ is_local = _is_supported_local_bias(d.attn_bias)
291
+ if is_block_diagonal or is_paged:
292
+ seqinfo = d.attn_bias.q_seqinfo # type: ignore
293
+ if q_len != seqinfo.seqstart_py[-1]:
294
+ reasons.append(
295
+ f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
296
+ )
297
+ q_len = seqinfo.max_seqlen
298
+ if is_causal and q_len != seqinfo.min_seqlen:
299
+ reasons.append(
300
+ f"Variable query len is not supported for causal masks: got {seqinfo.max_seqlen=} {seqinfo.min_seqlen=}."
301
+ )
302
+ elif is_local and q_len != seqinfo.min_seqlen:
303
+ reasons.append(
304
+ f"Variable query len is not supported for local masks: got {seqinfo.max_seqlen=} {seqinfo.min_seqlen=}."
305
+ )
306
+ if q_len > 16 and (is_causal or is_local):
307
+ # 16 is the minimum BLOCK_M which gets used
308
+ # XXX I don't really understand why this is needed.
309
+ reasons.append(
310
+ "Query length should not be larger than 16 for causal or local attention biases"
311
+ )
312
+
313
+ if is_paged:
314
+ page_size = d.attn_bias.page_size # type: ignore
315
+ if d.key.shape[1] % page_size:
316
+ reasons.append(
317
+ "For paged attention, key.shape[1] should be divisible "
318
+ "by the page size, "
319
+ f"but got {d.key.shape[1]=}, {page_size=}."
320
+ )
321
+ if page_size % cls.BLOCK_N:
322
+ reasons.append(
323
+ "For paged attention, page size should be divisible "
324
+ "by the block size, "
325
+ f"but got {page_size=}, {cls.BLOCK_N=}."
326
+ )
327
+
328
+ if isinstance(d.attn_bias, torch.Tensor):
329
+ if d.attn_bias.ndim not in (4, 5):
330
+ reasons.append(
331
+ "Additive attention bias has to have shape (B, G, H, Mq, Mkv) "
332
+ f"or (B, H, Mq, Mkv), but got {d.attn_bias.shape}."
333
+ )
334
+ if cls.SPLIT_K is not None and cls.SPLIT_K > 1:
335
+ reasons.append(
336
+ "Additive attention bias is not supported with split-k > 1."
337
+ )
338
+
339
+ return reasons
340
+
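A hypothetical pre-flight check built on the method above, in the spirit of how a dispatcher would consume these reasons (inp stands for an already-constructed Inputs instance):

    reasons = FwOp.not_supported_reasons(inp)
    if reasons:
        raise ValueError(f"triton_splitk cannot handle these inputs: {reasons}")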
341
+ @classmethod
342
+ def get_split_k(
343
+ cls, B: int, G: int, H: int, Mk: int, Mq: int, page_size: int, is_paged=False
344
+ ) -> int:
345
+ """Heuristic for the number of splits"""
346
+ bh = max(B * H, 1) # NOTE: Handle B*h=0 case
347
+ if torch.version.hip:
348
+ split_k = max(Mk + bh - 1, 1024) // bh
349
+ max_chunk_size = 64
350
+ split_k_stop_val = max(1024 / (B * G * H), 1)
351
+ while split_k > 1 and Mk / (split_k - 1) < max_chunk_size:
352
+ split_k = split_k - 1
353
+
354
+ while split_k > split_k_stop_val:
355
+ split_k = split_k // 2
356
+
357
+ split_size = (Mk + split_k - 1) // max(split_k, 1)
358
+
359
+ chunk_size = split_size // max_chunk_size * max_chunk_size
360
+ if chunk_size < split_size:
361
+ split_k += 1
362
+
363
+ split_k_upper_bound = 512
364
+ else:
365
+ if Mq > 1 and B * G * H > 64:
366
+ return 1
367
+ split_k = max(Mk, 1024) // bh
368
+ max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
369
+ split_k_stop_val = Mk / max_chunk_size
370
+ split_k_upper_bound = 64
371
+
372
+ while split_k > split_k_stop_val:
373
+ split_k = split_k // 2
374
+
375
+ split_k = min(split_k, split_k_upper_bound)
376
+ split_k = max(split_k, 1)
377
+
378
+ # It makes no sense for split_size to be larger than page_size
379
+ if is_paged and torch.version.hip:
380
+ split_size = (Mk + split_k - 1) // split_k
381
+ if split_size > page_size:
382
+ split_size = page_size
383
+ split_k = (Mk + split_size - 1) // split_size
384
+
385
+ return split_k
386
+
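For a rough sense of what the heuristic above picks, it can be queried directly; the shapes below are illustrative only, and the result differs between the CUDA and HIP branches:

    # e.g. decode-like shape: batch 8, 1 group, 8 KV heads, 8192 keys, 1 query, no paging
    split_k = FwOp.get_split_k(B=8, G=1, H=8, Mk=8192, Mq=1, page_size=0, is_paged=False)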
387
+ @classmethod
388
+ def get_kernel(cls):
389
+ from ._triton.splitk_kernels import (
390
+ _fwd_kernel_splitK_autotune,
391
+ _get_splitk_kernel,
392
+ )
393
+
394
+ if cls.AUTOTUNE:
395
+ return _fwd_kernel_splitK_autotune[cls.NUM_GROUPS]
396
+ else:
397
+ return _get_splitk_kernel(cls.NUM_GROUPS)
398
+
399
+ @classmethod
400
+ def get_fp8_scale_shift(
401
+ cls, inp: Inputs
402
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
403
+ if not hasattr(inp, "k_fp8_scale_shift"):
404
+ return None, None, None
405
+ inp_ = cast(InputsFp8, inp)
406
+ k_fp8_scale_shift = inp_.k_fp8_scale_shift
407
+ v_fp8_scale_shift = inp_.v_fp8_scale_shift
408
+ q_fp8_scale_shift = inp_.q_fp8_scale_shift
409
+
410
+ assert k_fp8_scale_shift is not None
411
+ assert v_fp8_scale_shift is not None
412
+ if k_fp8_scale_shift.ndim == 3:
413
+ k_fp8 = k_fp8_scale_shift.unsqueeze(2)
414
+ v_fp8 = v_fp8_scale_shift.unsqueeze(2)
415
+ q_fp8 = (
416
+ None if q_fp8_scale_shift is None else q_fp8_scale_shift.unsqueeze(2)
417
+ )
418
+ return k_fp8, v_fp8, q_fp8
419
+ if k_fp8_scale_shift.ndim == 4:
420
+ return k_fp8_scale_shift, v_fp8_scale_shift, q_fp8_scale_shift
421
+ raise ValueError(
422
+ "FP8 scales have to be provided in BMH or BMGH format, "
423
+ f"but got {k_fp8_scale_shift.shape=}"
424
+ )
425
+
426
+ @classmethod
427
+ def get_extra_args( # noqa: C901
428
+ cls,
429
+ *,
430
+ is_paged: bool,
431
+ B: int,
432
+ M: int,
433
+ Kkv: int,
434
+ Kq: int,
435
+ Mq: int,
436
+ split_k: int,
437
+ attn_bias: Any,
438
+ k_fp8_scale_shift: Any,
439
+ ) -> Dict[str, Any]:
440
+ BLOCK_M = cls.BLOCK_M
441
+ BLOCK_N = cls.BLOCK_N
442
+ if cls.AUTOTUNE:
443
+ extra_args = {}
444
+ else:
445
+ # TODO: remove this when autotuning on AMD is working
446
+ num_warps = cls.NUM_WARPS
447
+ num_stages = cls.NUM_STAGES
448
+ if torch.version.hip and attn_bias is not None:
449
+ # TODO: Double check paged.
450
+ mkv = attn_bias.k_seqinfo.max_seqlen
451
+ # TODO: Determine heuristics for paged attention
452
+ use_fp8_path = k_fp8_scale_shift is not None
453
+ if B == 1:
454
+ if use_fp8_path:
455
+ # Use specialized configs for FP8
456
+ if mkv <= 256:
457
+ BLOCK_N = 16
458
+ num_warps = 4
459
+ num_stages = 1
460
+ elif mkv <= 2048:
461
+ BLOCK_N = 32
462
+ num_warps = 4
463
+ num_stages = 1
464
+ elif mkv <= 16384:
465
+ BLOCK_N = 64
466
+ num_warps = 4
467
+ num_stages = 1
468
+ elif mkv >= 131072:
469
+ BLOCK_N = 128
470
+ num_warps = 2
471
+ num_stages = 1
472
+ else:
473
+ # Note: We don't have data on when changing num_warps works well
474
+ BLOCK_N = 64
475
+ num_warps = 4
476
+ num_stages = 1
477
+ else:
478
+ num_warps = 4
479
+ num_stages = 1 # TODO num_stages = 0 gives better perf on AMD, but sometimes produces NaNs
480
+ BLOCK_N = 32
481
+ elif B <= 4 and split_k <= 128:
482
+ num_warps = 2
483
+ num_stages = 1
484
+ BLOCK_N = 32
485
+ elif B <= 16:
486
+ if use_fp8_path:
487
+ if mkv <= 256:
488
+ BLOCK_N = 16
489
+ num_warps = 4
490
+ num_stages = 1
491
+ elif mkv <= 4096:
492
+ BLOCK_N = 32
493
+ num_warps = 4
494
+ num_stages = 1
495
+ elif mkv <= 8192:
496
+ BLOCK_N = 16
497
+ num_warps = 2
498
+ num_stages = 1
499
+ elif mkv < 131072:
500
+ # Note: This isn't benchmarked, but fp8 seems to scale well.
501
+ BLOCK_N = 64
502
+ num_warps = 1
503
+ num_stages = 1
504
+ else:
505
+ BLOCK_N = 128
506
+ num_warps = 1
507
+ num_stages = 1
508
+ else:
509
+ if M < 16:
510
+ num_warps = 2
511
+ num_stages = 1
512
+ else:
513
+ num_warps = 1
514
+ num_stages = 1
515
+ BLOCK_N = 32
516
+ elif B <= 64 and use_fp8_path:
517
+ if is_paged:
518
+ num_stages = 1
519
+ if mkv <= 256:
520
+ BLOCK_N = 64
521
+ num_warps = 8
522
+ elif mkv <= 8192:
523
+ BLOCK_N = 64
524
+ num_warps = 1
525
+ elif mkv <= 16384:
526
+ BLOCK_N = 128
527
+ num_warps = 2
528
+ else:
529
+ # Note: This isn't benchmarked, but fp8 seems to scale well.
530
+ BLOCK_N = 128
531
+ num_warps = 1
532
+ else:
533
+ if mkv <= 256:
534
+ BLOCK_N = 16
535
+ num_warps = 4
536
+ num_stages = 1
537
+ elif mkv < 131072:
538
+ # Note: This isn't benchmarked, but fp8 seems to scale well.
539
+ BLOCK_N = 64
540
+ num_warps = 1
541
+ num_stages = 1
542
+ else:
543
+ BLOCK_N = 128
544
+ num_warps = 1
545
+ num_stages = 1
546
+ elif B <= 128 and use_fp8_path:
547
+ num_stages = 1
548
+ if is_paged:
549
+ if mkv <= 256:
550
+ num_warps = 4
551
+ BLOCK_N = 16
552
+ elif mkv <= 2048:
553
+ num_warps = 1
554
+ BLOCK_N = 64
555
+ elif mkv < 131072:
556
+ num_warps = 2
557
+ BLOCK_N = 128
558
+ else:
559
+ # Note: This isn't benchmarked, but fp8 seems to scale well.
560
+ num_warps = 1
561
+ BLOCK_N = 128
562
+ else:
563
+ if mkv <= 128:
564
+ num_warps = 4
565
+ BLOCK_N = 16
566
+ else:
567
+ num_warps = 1
568
+ BLOCK_N = 64
569
+ elif B <= 256 and use_fp8_path:
570
+ num_stages = 1
571
+ if is_paged:
572
+ if mkv <= 2048:
573
+ num_warps = 1
574
+ BLOCK_N = 64
575
+ elif mkv < 131072:
576
+ num_warps = 2
577
+ BLOCK_N = 128
578
+ else:
579
+ # Note: This isn't benchmarked, but fp8 seems to scale well.
580
+ num_warps = 1
581
+ BLOCK_N = 128
582
+ else:
583
+ if mkv <= 256:
584
+ num_warps = 2
585
+ BLOCK_N = 32
586
+ else:
587
+ num_warps = 1
588
+ BLOCK_N = 64
589
+ else:
590
+ num_warps = 1
591
+ num_stages = 1
592
+ BLOCK_N = 64
593
+ else:
594
+ should_modify_warp_and_block = (
595
+ Kkv == 128
596
+ and Kq == 128
597
+ and torch.cuda.get_device_capability() >= (8, 9)
598
+ )
599
+ if should_modify_warp_and_block:
600
+ if Mq > 1:
601
+ num_warps = 4
602
+ # Choose minimal round block size which covers M.
603
+ if M > 16:
604
+ BLOCK_M = 32
605
+ if M > 32:
606
+ BLOCK_M = 64
607
+ if M > 64:
608
+ BLOCK_M = 128
609
+ extra_args = {
610
+ "BLOCK_M": BLOCK_M,
611
+ "BLOCK_N": BLOCK_N,
612
+ "num_warps": num_warps,
613
+ "num_stages": num_stages,
614
+ }
615
+ return extra_args
616
+
617
+ @classmethod
618
+ def apply( # noqa: C901
619
+ cls,
620
+ inp: Inputs,
621
+ needs_gradient: bool,
622
+ ) -> Tuple[torch.Tensor, Optional[Context]]:
623
+ """
624
+ Note that inp can be of type InputsFp8, in which case K/V are assumed to be row-wise FP8-quantized.
625
+ This is different from int4 quantization, where coefficients are kept together with the quantized
626
+ values at the beginning of each row, and inp has type Inputs.
627
+ """
628
+
629
+ output_dtype = inp.get_output_dtype()
630
+ # LSE may need higher precision than output
631
+ output_f64_lse = output_dtype in (torch.float32, torch.float64)
632
+ lse_dtype = torch.float64 if output_f64_lse else torch.float32
633
+
634
+ if inp.query.numel() == 0 or inp.key.numel() == 0:
635
+ out = torch.zeros_like(inp.query)
636
+ if needs_gradient:
637
+ lse_out = torch.full(
638
+ (inp.query.shape[0],)
639
+ + inp.query.shape[2:-1]
640
+ + (inp.query.shape[1],),
641
+ float("-inf"),
642
+ device=inp.query.device,
643
+ dtype=lse_dtype,
644
+ )
645
+ return out, Context(out=out, lse=lse_out)
646
+ return out, None
647
+
648
+ # Assert that if quantize_qk_to_fp8 is True, q_fp8_scale_shift must be provided
649
+ if hasattr(inp, "quantize_qk_to_fp8") and getattr(
650
+ inp, "quantize_qk_to_fp8", False
651
+ ):
652
+ assert (
653
+ hasattr(inp, "q_fp8_scale_shift") and inp.q_fp8_scale_shift is not None # type: ignore
654
+ ), "q_fp8_scale_shift must be provided when quantize_qk_to_fp8 is True"
655
+
656
+ k_fp8_scale_shift, v_fp8_scale_shift, q_fp8_scale_shift = (
657
+ cls.get_fp8_scale_shift(inp)
658
+ )
659
+
660
+ if not isinstance(inp.attn_bias, torch.Tensor):
661
+ attn_bias_tensor = None
662
+ attn_bias = cast(
663
+ Optional[
664
+ Union[
665
+ BlockDiagonalCausalWithOffsetPaddedKeysMask,
666
+ BlockDiagonalCausalLocalAttentionPaddedKeysMask,
667
+ BlockDiagonalLocalAttentionPaddedKeysMask,
668
+ BlockDiagonalGappyKeysMask,
669
+ BlockDiagonalCausalWithOffsetGappyKeysMask,
670
+ BlockDiagonalPaddedKeysMask,
671
+ PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
672
+ PagedBlockDiagonalCausalWithOffsetGappyKeysMask,
673
+ PagedBlockDiagonalGappyKeysMask,
674
+ PagedBlockDiagonalPaddedKeysMask,
675
+ ]
676
+ ],
677
+ inp.attn_bias,
678
+ )
679
+ else:
680
+ attn_bias_tensor = inp.attn_bias
681
+ attn_bias = None
682
+
683
+ seq_len = None
684
+ seq_starts_k = None
685
+ seq_starts_q = None
686
+ seq_starts_q_multiplier = None
687
+ q, k, v = inp.get_qkv_in_bmghk()
688
+ IS_CAUSAL = False
689
+ IS_LOCAL = False
690
+ NUM_QUERIES_CAUSAL = 1
691
+ variable_q = False
692
+ window_left = -1
693
+ window_right = -1
694
+
695
+ is_block_diagonal = isinstance(attn_bias, BlockDiagonalPaddedKeysMask)
696
+ is_gappy = _is_supported_gappy_bias(attn_bias)
697
+ is_paged = _is_supported_paged_bias(attn_bias)
698
+ if attn_bias is not None:
699
+ assert is_paged or is_block_diagonal or is_gappy
700
+ assert attn_bias.k_seqinfo.seqlen.device == inp.query.device
701
+ seq_len = attn_bias.k_seqinfo.seqlen
702
+ assert seq_len.stride(0) == 1
703
+ if is_gappy:
704
+ seq_starts_k = attn_bias.k_seqinfo.seqstart
705
+ assert seq_starts_k.stride(0) == 1
706
+ assert q.shape[0] == 1
707
+ B = len(seq_len)
708
+ G, Hq, Kq = q.shape[-3:]
709
+ # force a bool because triton cannot take np.bool_
710
+ multiple_q = bool(attn_bias.q_seqinfo.max_seqlen > 1)
711
+ IS_CAUSAL = multiple_q and _is_supported_causal_bias(attn_bias)
712
+ IS_LOCAL = _is_supported_local_bias(attn_bias)
713
+ variable_q = multiple_q and not IS_CAUSAL
714
+ Kkv = v.shape[-1]
715
+ if isinstance(attn_bias, BlockDiagonalLocalAttentionPaddedKeysMask):
716
+ window_left = attn_bias.window_left
717
+ window_right = attn_bias.window_right
718
+ elif isinstance(attn_bias, BlockDiagonalCausalLocalAttentionPaddedKeysMask):
719
+ window_left = attn_bias._window_size - 1
720
+
721
+ if variable_q:
722
+ seq_starts_q = attn_bias.q_seqinfo.seqstart
723
+ seq_starts_q_multiplier = 1
724
+ assert seq_starts_q.stride(0) == 1
725
+ else:
726
+ q = q.view(B, -1, G, Hq, Kq)
727
+ if q_fp8_scale_shift is not None:
728
+ q_fp8_scale_shift = q_fp8_scale_shift.view(B, -1, G, Hq)
729
+
730
+ kv_shape = (1 if is_paged or is_gappy else B, -1, G, Hq, Kkv)
731
+ k = k.view(kv_shape)
732
+ v = v.view(kv_shape)
733
+ if k_fp8_scale_shift is not None and v_fp8_scale_shift is not None:
734
+ k_fp8_scale_shift = k_fp8_scale_shift.view(kv_shape[:-1])
735
+ v_fp8_scale_shift = v_fp8_scale_shift.view(kv_shape[:-1])
736
+
737
+ Mq = q.shape[1]
738
+ NUM_QUERIES_CAUSAL = Mq
739
+ else:
740
+ B, Mq, G, Hq, Kq = q.shape
741
+
742
+ if attn_bias_tensor is not None and attn_bias_tensor.ndim == 4:
743
+ # (B, H, Mq, Mkv) -> (B, G, H, Mq, Mkv)
744
+ attn_bias_tensor = attn_bias_tensor.unsqueeze(1)
745
+
746
+ # In the case of MQA/GQA, we make q have sequence length (H * Mq) and only one "head".
747
+ mqa_swap_seqlen_head = False
748
+ if (
749
+ k.shape[3] > 1
750
+ and k.stride(3) == 0
751
+ and v.stride(3) == 0
752
+ and attn_bias_tensor is None
753
+ ):
754
+ mqa_swap_seqlen_head = True
755
+ if q_fp8_scale_shift is not None:
756
+ assert q_fp8_scale_shift.shape == q.shape[:-1], (
757
+ f"{q.shape=}, {q_fp8_scale_shift.shape=}"
758
+ )
759
+ if variable_q:
760
+ q_fp8_scale_shift = q_fp8_scale_shift.permute(0, 1, 3, 2).reshape(
761
+ 1, -1, G, 1
762
+ )
763
+ else:
764
+ q_fp8_scale_shift = q_fp8_scale_shift.permute(0, 3, 1, 2).reshape(
765
+ q.shape[0], -1, G, 1
766
+ )
767
+ if variable_q:
768
+ seq_starts_q_multiplier = Hq
769
+ assert q.shape[0] == 1
770
+ # The idea is Hq,Mq are reshaped to (M=Mq*Hq, H=1)
771
+ q = q.permute(0, 1, 3, 2, 4).reshape(1, -1, G, 1, Kq)
772
+ else:
773
+ # This is a copy iff Mq, G and H are all > 1.
774
+ # The idea is Hq,Mq are reshaped to (M=Hq*Mq, H=1)
775
+ q = q.permute(0, 3, 1, 2, 4).reshape(q.shape[0], -1, G, 1, Kq)
776
+ k = k[:, :, :, :1]
777
+ v = v[:, :, :, :1]
778
+ if k_fp8_scale_shift is not None and v_fp8_scale_shift is not None:
779
+ k_fp8_scale_shift = k_fp8_scale_shift[:, :, :, :1]
780
+ v_fp8_scale_shift = v_fp8_scale_shift[:, :, :, :1]
781
+
782
+ if k.dtype == torch.int32:
783
+ if k_fp8_scale_shift is not None:
784
+ Lk = k.shape[-1] * 4
785
+ PACKED_PER_VAL = 4
786
+ else:
787
+ # Quantized K/V
788
+ PACKED_PER_VAL = 8
789
+ Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8
790
+ else:
791
+ Lk = k.shape[-1]
792
+ PACKED_PER_VAL = 1
793
+ assert cls.NUM_GROUPS == 1, f"{cls.NUM_GROUPS=}"
794
+
795
+ _, Mk, G, H, Kkv = k.shape
796
+ Bqq, Mqq, G, H, Kq = q.shape
797
+ assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}"
798
+ if variable_q:
799
+ assert attn_bias is not None
800
+ assert seq_starts_q_multiplier is not None
801
+ M = attn_bias.q_seqinfo.max_seqlen * seq_starts_q_multiplier
802
+ else:
803
+ M = Mqq
804
+ page_size = inp.attn_bias.page_size if is_paged else 0 # type: ignore
805
+ block_tables = None
806
+ kv_cache_blocks_per_row = 0
807
+ if is_paged:
808
+ block_tables = inp.attn_bias.block_tables # type: ignore
809
+ kv_cache_blocks_per_row = block_tables.shape[1]
810
+ Mk = block_tables.shape[1] * page_size
811
+ elif attn_bias is not None:
812
+ Mk = min(Mk, attn_bias.k_seqinfo.max_seqlen)
813
+
814
+ if cls.SPLIT_K is not None:
815
+ split_k = cls.SPLIT_K
816
+ else:
817
+ # Use heuristics
818
+ split_k = (
819
+ cls.get_split_k(B, G, H, Mk, Mq, page_size, is_paged)
820
+ if attn_bias_tensor is None
821
+ else 1
822
+ )
823
+
824
+ # M_ceil = Mqq rounded up to a multiple of MAX_BLOCK_M
825
+ M_ceil = (Mqq + cls.MAX_BLOCK_M - 1) // cls.MAX_BLOCK_M * cls.MAX_BLOCK_M
826
+ IS_SPLITK = split_k > 1 # or cls.autotune?
827
+ output_shape = (Bqq, Mq, G, Hq, Kq)
828
+ if IS_SPLITK:
829
+ o_splitk_dtype = (
830
+ torch.float64 if output_dtype == torch.float64 else torch.float32
831
+ )
832
+ if cls.SPLIT_K_EARLY_EXIT:
833
+ o_splitk = torch.zeros(
834
+ [Bqq, G, H, split_k, M_ceil, Kq],
835
+ dtype=o_splitk_dtype,
836
+ device=q.device,
837
+ )
838
+ else:
839
+ o_splitk = torch.empty(
840
+ [Bqq, G, H, split_k, M_ceil, Kq],
841
+ dtype=o_splitk_dtype,
842
+ device=q.device,
843
+ )
844
+ else:
845
+ o_splitk = torch.empty(
846
+ [Bqq, split_k, Mqq, G, H, Kq],
847
+ dtype=output_dtype,
848
+ device=q.device,
849
+ ).permute(0, 3, 4, 1, 2, 5)
850
+ lse, lse_splitk = None, None
851
+ if IS_SPLITK or needs_gradient:
852
+ if IS_SPLITK or output_f64_lse:
853
+ lse_splitk_dtype = torch.float64
854
+ else:
855
+ lse_splitk_dtype = torch.float32
856
+ if cls.SPLIT_K_EARLY_EXIT:
857
+ lse_splitk = torch.full(
858
+ [Bqq, G, H, split_k, Mqq],
859
+ -float("inf"),
860
+ dtype=lse_splitk_dtype,
861
+ device=q.device,
862
+ )
863
+ else:
864
+ lse_splitk = torch.empty(
865
+ [Bqq, G, H, split_k, Mqq],
866
+ dtype=lse_splitk_dtype,
867
+ device=q.device,
868
+ )
869
+
870
+ def grid(META):
871
+ import triton
872
+
873
+ return triton.cdiv(M, META["BLOCK_M"]), B * G * H, split_k
874
+
875
+ split_size = (Mk + split_k - 1) // split_k
876
+ use_seq_len = seq_len is not None
877
+
878
+ kernel = cls.get_kernel()
879
+ extra_args = cls.get_extra_args(
880
+ is_paged=is_paged,
881
+ B=B,
882
+ M=M,
883
+ Kkv=Kkv,
884
+ Kq=Kq,
885
+ Mq=Mq,
886
+ split_k=split_k,
887
+ attn_bias=attn_bias,
888
+ k_fp8_scale_shift=k_fp8_scale_shift,
889
+ )
890
+
891
+ IS_HIP = torch.version.hip is not None
892
+
893
+ if inp.quantize_pv_to_fp8:
894
+ v = v.view(torch.int8)
895
+ v = v.view(torch.float8_e4m3fn)
896
+
897
+ kernel[grid](
898
+ Q=q,
899
+ K=k,
900
+ V=v,
901
+ sm_scale=inp.scale_float,
902
+ Out_splitK=o_splitk,
903
+ LSE_splitk=lse_splitk,
904
+ block_tables=block_tables,
905
+ Seq_len=seq_len,
906
+ Seq_starts_k=seq_starts_k,
907
+ Seq_starts_q=seq_starts_q,
908
+ Seq_starts_q_multiplier=seq_starts_q_multiplier,
909
+ additive_bias=attn_bias_tensor,
910
+ K_fp8_scale_shift=k_fp8_scale_shift,
911
+ V_fp8_scale_shift=v_fp8_scale_shift,
912
+ q_fp8_scale_shift=q_fp8_scale_shift,
913
+ **_strides(q, "qz", "qm", "qg", "qh", "qk"),
914
+ **_strides(k, "kz", "kn", "kg", "kh", "kk"),
915
+ **_strides(v, "vz", "vn", "vg", "vh", "vk"),
916
+ **_strides(o_splitk, "osk_z", "osk_g", "osk_h", "osk_s", "osk_m", "osk_k"),
917
+ **_strides(lse_splitk, "lsek_z", "lsek_g", "lsek_h", "lsek_s", "lsek_m"),
918
+ **_strides(block_tables, "blocktablesz", "blocktablesl"),
919
+ **_strides(
920
+ attn_bias_tensor, "bias_b", "bias_g", "bias_h", "bias_qm", "bias_km"
921
+ ),
922
+ **_strides(
923
+ k_fp8_scale_shift,
924
+ "k_fp8_scale_shift_z",
925
+ "k_fp8_scale_shift_n",
926
+ "k_fp8_scale_shift_g",
927
+ "k_fp8_scale_shift_h",
928
+ ),
929
+ **_strides(
930
+ v_fp8_scale_shift,
931
+ "v_fp8_scale_shift_z",
932
+ "v_fp8_scale_shift_n",
933
+ "v_fp8_scale_shift_g",
934
+ "v_fp8_scale_shift_h",
935
+ ),
936
+ **_strides(
937
+ q_fp8_scale_shift,
938
+ "q_fp8_scale_shift_z",
939
+ "q_fp8_scale_shift_m",
940
+ "q_fp8_scale_shift_g",
941
+ "q_fp8_scale_shift_h",
942
+ ),
943
+ kv_cache_blocks_per_row=kv_cache_blocks_per_row,
944
+ Z=B,
945
+ H=H,
946
+ G=G,
947
+ N_CTX_Q=M,
948
+ N_CTX_K=Mk,
949
+ BLOCK_N_PER_SPLIT=split_size,
950
+ BLOCK_DMODEL=Lk,
951
+ USE_SEQ_LEN=use_seq_len,
952
+ PACKED_PER_VAL=PACKED_PER_VAL,
953
+ N_GROUPS=cls.NUM_GROUPS,
954
+ IS_CAUSAL=IS_CAUSAL,
955
+ IS_LOCAL=IS_LOCAL,
956
+ NUM_QUERIES_CAUSAL=NUM_QUERIES_CAUSAL,
957
+ IS_SPLITK=IS_SPLITK,
958
+ SPLIT_K_EARLY_EXIT=cls.SPLIT_K_EARLY_EXIT,
959
+ USE_PAGED_ATTENTION=is_paged,
960
+ PAGE_SIZE=page_size,
961
+ WINDOW_LEFT=window_left,
962
+ WINDOW_RIGHT=window_right,
963
+ WRITE_LSE=IS_SPLITK or needs_gradient,
964
+ HAS_ADDITIVE_BIAS=attn_bias_tensor is not None,
965
+ NUM_PROGRAMS_DIM2_CONST=split_k,
966
+ IS_HIP=IS_HIP,
967
+ QUANTIZE_PV_TO_FP8=inp.quantize_pv_to_fp8,
968
+ QUANTIZE_QK_TO_FP8=inp.quantize_qk_to_fp8,
969
+ USE_FP32_SCALES=inp.use_fp32_scales,
970
+ **extra_args,
971
+ )
972
+ if not IS_SPLITK:
973
+ out = o_splitk[:, :, :, 0] # Bqq, G, H, Mqq, Kq
974
+ if variable_q and mqa_swap_seqlen_head:
975
+ out = out.view(1, G, Mq, Hq, Kq).permute(0, 2, 1, 3, 4).contiguous()
976
+ else:
977
+ out = out.view(Bqq, G, Hq, Mq, Kq)
978
+ # This is a copy iff mqa_swap_seqlen_head and Mq, G and Hq are all > 1.
979
+ out = out.permute(0, 3, 1, 2, 4).contiguous()
980
+ if needs_gradient:
981
+ assert lse_splitk is not None
982
+ lse = lse_splitk[:, :, :, 0] # Bqq, G, H, Mqq
983
+ if variable_q and mqa_swap_seqlen_head:
984
+ lse = lse.view(1, G, Mq, Hq).permute(0, 1, 3, 2)
985
+ else:
986
+ lse = lse.view(Bqq, G, Hq, Mq)
987
+ if attn_bias is not None and not variable_q:
988
+ lse = lse.permute(1, 2, 0, 3).reshape(1, G, Hq, B * Mq)
989
+ else:
990
+ lse = None
991
+
992
+ if inp.query.ndim == 4:
993
+ # BMGHK -> BMHK
994
+ assert G == 1
995
+ if lse is not None:
996
+ lse = lse[:, 0]
997
+ out = out[:, :, 0]
998
+
999
+ if lse is None:
1000
+ return out, None
1001
+ return out, Context(out=out, lse=lse)
1002
+
1003
+ out = torch.empty(output_shape, device=q.device, dtype=output_dtype)
1004
+
1005
+ # Merge attention and LSE outputs from different split-k chunks
1006
+ assert lse_splitk is not None
1007
+ output_lse = None
1008
+ if needs_gradient:
1009
+ if attn_bias is None or variable_q:
1010
+ output_lse = torch.empty(
1011
+ (Bqq, G, Hq, Mq), device=q.device, dtype=lse_dtype
1012
+ )
1013
+ lse = output_lse
1014
+ else:
1015
+ output_lse = torch.empty(
1016
+ (1, G, Hq, B * Mq), device=q.device, dtype=lse_dtype
1017
+ )
1018
+ lse = output_lse.view(G, Hq, B, Mq).permute(2, 0, 1, 3)
1019
+
1020
+ o_splitk = o_splitk[:, :, :, :, :Mqq]
1021
+
1022
+ if mqa_swap_seqlen_head:
1023
+ if variable_q:
1024
+ o_splitk = o_splitk.view(Bqq, G, split_k, Mq, Hq, Kq).permute(
1025
+ 0, 1, 4, 2, 3, 5
1026
+ )
1027
+ lse_splitk = lse_splitk.view(Bqq, G, split_k, Mq, Hq).permute(
1028
+ 0, 1, 4, 2, 3
1029
+ )
1030
+ else:
1031
+ o_splitk = o_splitk.view(Bqq, G, split_k, Hq, Mq, Kq).permute(
1032
+ 0, 1, 3, 2, 4, 5
1033
+ )
1034
+ lse_splitk = lse_splitk.view(Bqq, G, split_k, Hq, Mq).permute(
1035
+ 0, 1, 3, 2, 4
1036
+ )
1037
+
1038
+ merge_attentions(out, lse, o_splitk, lse_splitk)
1039
+
1040
+ if inp.query.ndim == 4:
1041
+ # BMGHK -> BMHK
1042
+ assert G == 1
1043
+ out = out[:, :, 0]
1044
+ if output_lse is not None:
1045
+ output_lse = output_lse[:, 0]
1046
+ if Mk == 0:
1047
+ out.zero_()
1048
+
1049
+ if attn_bias is not None and not variable_q:
1050
+ out = out.view(1, B * Mq, G, Hq, Kq)
1051
+
1052
+ if output_lse is None:
1053
+ return out, None
1054
+
1055
+ return out, Context(out=out, lse=output_lse)
1056
+
1057
+ @classmethod
1058
+ @functools.lru_cache
1059
+ def get_operator(
1060
+ cls,
1061
+ splitk: int,
1062
+ *,
1063
+ block_m: Optional[int] = None,
1064
+ block_n: Optional[int] = None,
1065
+ num_warps: Optional[int] = None,
1066
+ num_stages: Optional[int] = None,
1067
+ split_k_early_exit: Optional[bool] = None,
1068
+ ) -> Type[AttentionFwOpBase]:
1069
+ kwargs = {
1070
+ "NAME": f"triton_splitK{splitk}",
1071
+ "SPLIT_K": splitk,
1072
+ }
1073
+ if block_m is not None:
1074
+ kwargs["BLOCK_M"] = block_m
1075
+ if block_n is not None:
1076
+ kwargs["BLOCK_N"] = block_n
1077
+ if num_warps is not None:
1078
+ kwargs["NUM_WARPS"] = num_warps
1079
+ if num_stages is not None:
1080
+ kwargs["NUM_STAGES"] = num_stages
1081
+ if split_k_early_exit is not None:
1082
+ kwargs["SPLIT_K_EARLY_EXIT"] = split_k_early_exit
1083
+ return type(
1084
+ f"FwOp_S{splitk}",
1085
+ (cls,),
1086
+ kwargs,
1087
+ )
1088
+
1089
+
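The FwOp_S* aliases at the bottom of this file are produced exactly this way; a custom variant can be built the same way (MyOp is just an illustrative name):

    MyOp = FwOp.get_operator(16, block_n=128, num_warps=4)
    assert MyOp.NAME == "triton_splitK16" and MyOp.SPLIT_K == 16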
1090
+ def merge_attentions(
1091
+ attn_out: torch.Tensor,
1092
+ lse_out: Optional[torch.Tensor],
1093
+ attn_split: torch.Tensor,
1094
+ lse_split: torch.Tensor,
1095
+ ):
1096
+ import triton
1097
+
1098
+ from ._triton.splitk_kernels import _splitK_reduce
1099
+
1100
+ B, M, G, H, Kq = attn_out.shape
1101
+ B1, G1, H1, split_k, M1, Kq1 = attn_split.shape
1102
+ B2, G2, H2, split_k1, M2 = lse_split.shape
1103
+
1104
+ assert (
1105
+ B == B1 == B2
1106
+ and G == G1 == G2
1107
+ and H == H1 == H2
1108
+ and M == M1 == M2
1109
+ and Kq == Kq1
1110
+ ), (
1111
+ f"Incompatible shapes: {attn_out.shape=}, {attn_split.shape=}, {lse_split.shape=}"
1112
+ )
1113
+ assert split_k == split_k1, (
1114
+ f"Incompatible shapes: {attn_split.shape=}, {lse_split.shape=}"
1115
+ )
1116
+ if lse_out is not None:
1117
+ B3, G3, H3, M3 = lse_out.shape
1118
+ assert B == B3 and G == G3 and H == H3 and M == M3, (
1119
+ f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}"
1120
+ )
1121
+
1122
+ num_warps = 4 if B * G * H < 32 or torch.version.hip else 2
1123
+ splitK_pow2 = triton.next_power_of_2(split_k)
1124
+ head_dim = attn_out.shape[-1]
1125
+ grid = (M, B * G * H, 1)
1126
+ # pyre-ignore[28]
1127
+ _splitK_reduce[grid](
1128
+ attn_split,
1129
+ lse_split,
1130
+ attn_out,
1131
+ lse_out,
1132
+ split_k=split_k,
1133
+ splitK_pow2=splitK_pow2,
1134
+ **_strides(attn_split, "osk_z", "osk_g", "osk_h", "osk_s", "osk_m", "osk_k"),
1135
+ **_strides(lse_split, "lsek_z", "lsek_g", "lsek_h", "lsek_s", "lsek_m"),
1136
+ **_strides(attn_out, "oz", "om", "og", "oh", "ok"),
1137
+ **_strides(lse_out, "lse_z", "lse_g", "lse_h", "lse_m"),
1138
+ head_dim=head_dim,
1139
+ head_dim_pow_2=triton.next_power_of_2(head_dim),
1140
+ G=G,
1141
+ H=H,
1142
+ WRITE_LSE=lse_out is not None,
1143
+ num_warps=num_warps,
1144
+ )
1145
+
1146
+
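A shape-level sketch of calling merge_attentions directly, with illustrative random data on a CUDA device with triton installed; the shapes follow the assertions inside the function (attn_split is [B, G, H, split_k, M, Kq], lse_split is [B, G, H, split_k, M]):

    B, M, G, H, Kq, split_k = 2, 16, 1, 4, 128, 4
    attn_split = torch.randn(B, G, H, split_k, M, Kq, device="cuda", dtype=torch.float32)
    lse_split = torch.randn(B, G, H, split_k, M, device="cuda", dtype=torch.float64)
    attn_out = torch.empty(B, M, G, H, Kq, device="cuda", dtype=torch.bfloat16)
    lse_out = torch.empty(B, G, H, M, device="cuda", dtype=torch.float32)
    merge_attentions(attn_out, lse_out, attn_split, lse_split)   # fills attn_out / lse_out in place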
1147
+ @torch.library.custom_op(
1148
+ "mslk::fmha_merge_attentions_varargs",
1149
+ mutates_args=(),
1150
+ device_types=["cuda"],
1151
+ )
1152
+ def merge_attentions_varargs(
1153
+ attn_split: Sequence[torch.Tensor],
1154
+ lse_split: Sequence[torch.Tensor],
1155
+ write_lse: bool,
1156
+ output_dtype: Optional[torch.dtype],
1157
+ B: int,
1158
+ M: int,
1159
+ G: int,
1160
+ H: int,
1161
+ Kq: int,
1162
+ ) -> List[torch.Tensor]:
1163
+ import triton
1164
+
1165
+ from ._triton.splitk_kernels import _splitK_reduce_varargs
1166
+ from ._triton.vararg_kernel import unroll_varargs
1167
+
1168
+ attn_out = torch.empty(
1169
+ (B, M, G, H, Kq),
1170
+ device=attn_split[0].device,
1171
+ dtype=output_dtype or attn_split[0].dtype,
1172
+ )
1173
+ if write_lse:
1174
+ lse_out = torch.empty(
1175
+ (B, G, H, M),
1176
+ device=attn_split[0].device,
1177
+ dtype=lse_split[0].dtype,
1178
+ )
1179
+ else:
1180
+ lse_out = None
1181
+ kernel_args, grid = _prepare_reduce_kernel_params(
1182
+ attn_out, lse_out, attn_split, lse_split
1183
+ )
1184
+ reduce_kernel = unroll_varargs(_splitK_reduce_varargs, N=len(attn_split))
1185
+ head_dim = attn_out.shape[-1]
1186
+ reduce_kernel[grid](
1187
+ *attn_split,
1188
+ *lse_split,
1189
+ Out=attn_out,
1190
+ LSE=lse_out,
1191
+ **kernel_args,
1192
+ head_dim=head_dim,
1193
+ head_dim_pow_2=triton.next_power_of_2(head_dim),
1194
+ WRITE_LSE=lse_out is not None,
1195
+ )
1196
+ if write_lse:
1197
+ assert lse_out is not None
1198
+ return [attn_out, lse_out]
1199
+ return [attn_out]
1200
+
1201
+
1202
+ @torch.library.register_fake("mslk::fmha_merge_attentions_varargs")
1203
+ def merge_attentions_varargs_fake(
1204
+ attn_split: Sequence[torch.Tensor],
1205
+ lse_split: Sequence[torch.Tensor],
1206
+ write_lse: bool,
1207
+ output_dtype: Optional[torch.dtype],
1208
+ B: int,
1209
+ M: int,
1210
+ G: int,
1211
+ H: int,
1212
+ Kq: int,
1213
+ ) -> List[torch.Tensor]:
1214
+ attn_out = torch.empty(
1215
+ (B, M, G, H, Kq),
1216
+ device=attn_split[0].device,
1217
+ dtype=output_dtype or attn_split[0].dtype,
1218
+ )
1219
+ if write_lse:
1220
+ lse_out = torch.empty(
1221
+ (B, G, H, M),
1222
+ device=attn_split[0].device,
1223
+ dtype=lse_split[0].dtype,
1224
+ )
1225
+ return [attn_out, lse_out]
1226
+ return [attn_out]
1227
+
1228
+
1229
+ def _merge_attentions_backward(
1230
+ ctx: torch.autograd.function.FunctionCtx,
1231
+ grad: List[torch.Tensor],
1232
+ ) -> Tuple[None, ...]:
1233
+ raise NotImplementedError(
1234
+ "Backward pass is not implemented for merge_attentions. "
1235
+ "If it was, it would be easy to get wrong attention gradients, "
1236
+ "because the gradients of the LSEs "
1237
+ "don't get propagated by attention backward."
1238
+ )
1239
+
1240
+
1241
+ merge_attentions_varargs.register_autograd(_merge_attentions_backward)
1242
+
1243
+
1244
+ @torch.library.custom_op(
1245
+ "mslk::merge_attentions_varargs_backward",
1246
+ mutates_args=(),
1247
+ device_types=["cuda"],
1248
+ )
1249
+ def merge_attentions_varargs_backward(
1250
+ attn_split: List[torch.Tensor],
1251
+ lse_split: List[torch.Tensor],
1252
+ attn_out: torch.Tensor,
1253
+ lse_out: torch.Tensor,
1254
+ grad_attn: torch.Tensor,
1255
+ grad_lse: torch.Tensor,
1256
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
1257
+ from ._triton.splitk_kernels import _splitK_reduce_varargs_backward
1258
+ from ._triton.vararg_kernel import unroll_varargs
1259
+
1260
+ dattn_splitk = [torch.empty_like(x) for x in attn_split]
1261
+ dlse_splitk = [torch.empty_like(x) for x in lse_split]
1262
+
1263
+ kernel_args, grid = _prepare_reduce_kernel_params(
1264
+ attn_out, lse_out, attn_split, lse_split, grad_attn, grad_lse
1265
+ )
1266
+
1267
+ reduce_kernel_backward = unroll_varargs(
1268
+ _splitK_reduce_varargs_backward, N=len(attn_split)
1269
+ )
1270
+ reduce_kernel_backward[grid](
1271
+ *attn_split,
1272
+ *lse_split,
1273
+ *dattn_splitk,
1274
+ *dlse_splitk,
1275
+ Out=attn_out,
1276
+ LSE=lse_out,
1277
+ DOut=grad_attn,
1278
+ DLSE=grad_lse,
1279
+ **kernel_args,
1280
+ BLOCK_SIZE=attn_out.shape[-1],
1281
+ )
1282
+
1283
+ return dattn_splitk, dlse_splitk
1284
+
1285
+
1286
+ @torch.library.register_fake("mslk::merge_attentions_varargs_backward")
1287
+ def merge_attentions_varargs_backward_fake(
1288
+ attn_split: List[torch.Tensor],
1289
+ lse_split: List[torch.Tensor],
1290
+ attn_out: torch.Tensor,
1291
+ lse_out: torch.Tensor,
1292
+ grad_attn: torch.Tensor,
1293
+ grad_lse: torch.Tensor,
1294
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
1295
+ dattn_splitk = [torch.empty_like(x) for x in attn_split]
1296
+ dlse_splitk = [torch.empty_like(x) for x in lse_split]
1297
+ return dattn_splitk, dlse_splitk
1298
+
1299
+
1300
+ def _prepare_reduce_kernel_params(
1301
+ attn_out: torch.Tensor,
1302
+ lse_out: Optional[torch.Tensor],
1303
+ attn_split: Sequence[torch.Tensor],
1304
+ lse_split: Sequence[torch.Tensor],
1305
+ grad_attn: Optional[torch.Tensor] = None,
1306
+ grad_lse: Optional[torch.Tensor] = None,
1307
+ ) -> Tuple[Dict[str, int], Tuple[int, int, int]]:
1308
+ B, M, G, H, Kq = attn_out.shape
1309
+ B1, G1, H1, M1, Kq1 = attn_split[0].shape
1310
+ B2, G2, H2, M2 = lse_split[0].shape
1311
+
1312
+ assert (
1313
+ B == B1 == B2
1314
+ and G == G1 == G2
1315
+ and H == H1 == H2
1316
+ and M == M1 == M2
1317
+ and Kq == Kq1
1318
+ ), (
1319
+ f"Incompatible shapes: {attn_out.shape=}, {attn_split[0].shape=}, {lse_split[0].shape=}"
1320
+ )
1321
+ if lse_out is not None:
1322
+ B3, G3, H3, M3 = lse_out.shape
1323
+ assert B == B3 and G == G3 and H == H3 and M == M3, (
1324
+ f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}"
1325
+ )
1326
+
1327
+ attn_split_strides = {}
1328
+ lse_split_strides = {}
1329
+ for i in range(len(attn_split)):
1330
+ attn_split_strides.update(
1331
+ _strides(
1332
+ attn_split[i],
1333
+ "osk_z" + str(i),
1334
+ "osk_g" + str(i),
1335
+ "osk_h" + str(i),
1336
+ "osk_m" + str(i),
1337
+ "osk_k" + str(i),
1338
+ )
1339
+ )
1340
+ lse_split_strides.update(
1341
+ _strides(
1342
+ lse_split[i],
1343
+ "lsek_z" + str(i),
1344
+ "lsek_g" + str(i),
1345
+ "lsek_h" + str(i),
1346
+ "lsek_m" + str(i),
1347
+ )
1348
+ )
1349
+
1350
+ num_warps = 4 if B * G * H < 32 or torch.version.hip else 2
1351
+ grid = (M, B * G * H, 1)
1352
+
1353
+ kernel_args = {
1354
+ "G": G,
1355
+ "H": H,
1356
+ "num_warps": num_warps,
1357
+ **attn_split_strides,
1358
+ **lse_split_strides,
1359
+ }
1360
+ kernel_args.update(_strides(attn_out, "oz", "om", "og", "oh", "ok"))
1361
+ kernel_args.update(_strides(lse_out, "lse_z", "lse_g", "lse_h", "lse_m"))
1362
+ if grad_attn is not None:
1363
+ kernel_args.update(_strides(grad_attn, "doz", "dom", "dog", "doh", "dok"))
1364
+ kernel_args.update(_strides(grad_lse, "dlse_z", "dlse_g", "dlse_h", "dlse_m"))
1365
+ return kernel_args, grid
1366
+
1367
+
1368
+ FwOp_Map = {
1369
+ k: FwOp.get_operator(k) for k in [1, 2, 4, 8, 16, 32, 48, 64, 72, 80, 96, 112, 128]
1370
+ }
1371
+ FwOp_S1 = FwOp_Map[1]
1372
+ FwOp_S2 = FwOp_Map[2]
1373
+ FwOp_S4 = FwOp_Map[4]
1374
+ FwOp_S8 = FwOp_Map[8]
1375
+ FwOp_S16 = FwOp_Map[16]
1376
+ FwOp_S32 = FwOp_Map[32]
1377
+ FwOp_S64 = FwOp_Map[64]
1378
+ FwOp_S128 = FwOp_Map[128]
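
Finally, a minimal end-to-end sketch of invoking one of these operators directly. It assumes this file lands as mslk/attention/fmha/triton_splitk.py (matching the +1378 entry in the file list above), that Inputs from the common module accepts query/key/value keyword arguments, and that a CUDA GPU with sm80+ and triton are available:

    import torch
    from mslk.attention.fmha.common import Inputs
    from mslk.attention.fmha.triton_splitk import FwOp_S2

    B, Mq, Mkv, H, K = 4, 1, 4096, 8, 128   # decode-like: one query against a long KV cache
    q = torch.randn(B, Mq, H, K, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, Mkv, H, K, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, Mkv, H, K, device="cuda", dtype=torch.bfloat16)

    out, _ = FwOp_S2.apply(Inputs(query=q, key=k, value=v), needs_gradient=False)
    print(out.shape)   # expected: (4, 1, 8, 128), BMHK like the query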