mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/interface.py
@@ -0,0 +1,1771 @@
1
+ # @nolint # fbcode
2
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need to install nvidia-cutlass-dsl==4.2.0.
4
+
5
+ # Supported features:
6
+ # - BF16 & FP16 dtype
7
+ # - noncausal & causal attention
8
+ # - MHA, GQA, MQA
9
+ # - hdim 64, 96, 128.
10
+ # - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape)
11
+ # - varlen
12
+ # - sliding window
13
+ # - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow)
14
+
15
+ # Features not supported yet:
16
+ # - split (i.e. FlashDecoding)
17
+ # - tuned block sizes
18
+ # - paged KV
19
+ # - append KV to existing KV cache
20
+ # - FP8
21
+ # - bwd pass optimized for Hopper/Blackwell
22
+
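# A minimal usage sketch for the features listed above (assuming a CUDA device with a
# Hopper/Blackwell GPU; the shapes below are arbitrary, not taken from this module). The
# public entry point `flash_attn_func` is defined at the bottom of this file:
#
#   import torch
#   from mslk.attention.flash_attn.interface import flash_attn_func
#
#   q = torch.randn(2, 4096, 16, 128, dtype=torch.bfloat16, device="cuda")
#   k = torch.randn(2, 4096, 4, 128, dtype=torch.bfloat16, device="cuda")  # GQA: 16 q-heads, 4 kv-heads
#   v = torch.randn(2, 4096, 4, 128, dtype=torch.bfloat16, device="cuda")
#   # sliding window: attend only to the previous 256 tokens, no lookahead
#   out, lse = flash_attn_func(q, k, v, window_size=(256, 0))
#   # out has shape (2, 4096, 16, 128); lse is None unless one of q/k/v requires grad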
23
+ import math
24
+ from functools import lru_cache
25
+ from typing import Optional, Tuple, Callable
26
+
27
+ import torch
28
+
29
+
30
+ import cuda.bindings.driver as cuda
31
+
32
+ import cutlass
33
+ import cutlass.cute as cute
34
+
35
+ from mslk.attention.flash_attn import utils
36
+ from mslk.attention.flash_attn.cute_dsl_utils import to_cute_tensor
37
+ from mslk.attention.flash_attn.flash_fwd import FlashAttentionForwardSm90
38
+ from mslk.attention.flash_attn.flash_fwd_sm100 import FlashAttentionForwardSm100
39
+ from mslk.attention.flash_attn.flash_bwd_preprocess import FlashAttentionBackwardPreprocess
40
+ from mslk.attention.flash_attn.flash_bwd import FlashAttentionBackwardSm80
41
+ from mslk.attention.flash_attn.flash_bwd_sm90 import FlashAttentionBackwardSm90
42
+ from mslk.attention.flash_attn.flash_bwd_sm100 import FlashAttentionBackwardSm100
43
+ from mslk.attention.flash_attn.flash_bwd_postprocess import FlashAttentionBackwardPostprocess
44
+ from mslk.attention.flash_attn.flash_fwd_combine import FlashAttentionForwardCombine
45
+
46
+ from mslk.attention.flash_attn.block_sparsity import (
47
+ BlockSparseTensorsTorch,
48
+ to_cute_block_sparse_tensors,
49
+ normalize_block_sparse_tensors,
50
+ get_block_sparse_expected_shapes,
51
+ get_block_sparse_expected_shapes_bwd,
52
+ )
53
+
54
+ @lru_cache(maxsize=None)
55
+ def _get_device_capability():
56
+ """Cached device capability check."""
57
+ return torch.cuda.get_device_capability()[0]
58
+
59
+ def maybe_contiguous(x):
60
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
61
+
62
+
63
+ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):
64
+ assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}"
65
+ assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}"
66
+ assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
67
+ assert t.is_cuda, f"{name} must be on CUDA"
68
+
69
+
70
+ torch2cute_dtype_map = {
71
+ torch.float16: cutlass.Float16,
72
+ torch.bfloat16: cutlass.BFloat16,
73
+ torch.float32: cutlass.Float32,
74
+ }
75
+
76
+
77
+ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
78
+ # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
79
+ if num_n_blocks <= 4:
80
+ return 1
81
+
82
+ # NOTE: We should revisit this heuristic after persistence is supported for split KV.
83
+ # Sometimes, it's ideal to over-schedule splits for better efficiency.
84
+ return min(num_SMs // total_mblocks, max_splits, num_n_blocks)
85
+
86
+
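# A worked example of the heuristic above (a sketch; the numbers are illustrative, with 132 SMs
# corresponding to an H100-class device): 4 m-blocks scheduled over 64 KV blocks gives
# min(132 // 4, 128, 64) = 33 splits, while 4 or fewer KV blocks never split.
assert num_splits_heuristic(total_mblocks=4, num_SMs=132, num_n_blocks=64, max_splits=128) == 33
assert num_splits_heuristic(total_mblocks=4, num_SMs=132, num_n_blocks=4, max_splits=128) == 1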
87
+ def _flash_attn_fwd(
88
+ q: torch.Tensor,
89
+ k: torch.Tensor,
90
+ v: torch.Tensor,
91
+ cu_seqlens_q: Optional[torch.Tensor] = None,
92
+ cu_seqlens_k: Optional[torch.Tensor] = None,
93
+ seqused_q: Optional[torch.Tensor] = None,
94
+ seqused_k: Optional[torch.Tensor] = None,
95
+ max_seqlen_q: Optional[int] = None,
96
+ max_seqlen_k: Optional[int] = None,
97
+ page_table: Optional[torch.Tensor] = None,
98
+ softmax_scale: Optional[float] = None,
99
+ causal: bool = False,
100
+ softcap: Optional[float] = None,
101
+ window_size_left: Optional[int] = None,
102
+ window_size_right: Optional[int] = None,
103
+ learnable_sink: Optional[torch.Tensor] = None,
104
+ # m_block_size: int = 128,
105
+ # n_block_size: int = 64,
106
+ # num_threads: int = 128,
107
+ m_block_size: int = 128,
108
+ n_block_size: int = 128,
109
+ num_threads: int = 384,
110
+ num_splits: int = 1,
111
+ pack_gqa: Optional[bool] = None,
112
+ _compute_capability: Optional[int] = None,
113
+ score_mod: Optional[Callable] = None,
114
+ mask_mod: Optional[Callable] = None,
115
+ block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
116
+ return_lse: bool = False,
117
+ out: Optional[torch.Tensor] = None,
118
+ lse: Optional[torch.Tensor] = None,
119
+ aux_tensors: Optional[list[torch.Tensor]] = None,
120
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
121
+ """Forward pass for FlashAttention.
122
+
123
+ Args:
124
+ ...
125
+ score_mod: A callable that takes the attention scores and applies a modification.
126
+ mask_mod: A callable that takes token position information and selectively masks out attention scores.
127
+ block_sparse_tensors: A tuple of tensors used for block sparsity.
128
+ return_lse: Whether to return the log-sum-exp of the attention scores. If set to True, the LSE is always computed, even when no input requires grad.
129
+ out: Optional pre-allocated output tensor. If None, will be allocated internally.
130
+ lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
131
+ aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel.
132
+ """
133
+ q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
134
+ num_head, head_dim = q.shape[-2:]
135
+ if cu_seqlens_q is None:
136
+ batch_size, seqlen_q = q.shape[:2]
137
+ total_q = batch_size * seqlen_q
138
+ else:
139
+ batch_size = cu_seqlens_q.shape[0] - 1
140
+ seqlen_q = None
141
+ total_q = q.shape[0]
142
+ if page_table is not None:
143
+ assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k"
144
+ assert page_table.dtype == torch.int32, "page_table must be int32"
145
+ assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension"
146
+ max_num_pages_per_seq = page_table.shape[1]
147
+ assert page_table.shape == (batch_size, max_num_pages_per_seq)
148
+ num_pages, page_size = k.shape[:2]
149
+ seqlen_k = num_pages * page_size
150
+ else:
151
+ num_pages, page_size = None, None
152
+ seqlen_k = k.shape[-3]
153
+ num_head_kv = k.shape[-2]
154
+ head_dim_v = v.shape[-1]
155
+ if cu_seqlens_k is None:
156
+ if page_table is None:
157
+ assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
158
+ assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
159
+ else:
160
+ assert k.shape == (num_pages, page_size, num_head_kv, head_dim)
161
+ assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)
162
+ else:
163
+ assert k.shape == (seqlen_k, num_head_kv, head_dim)
164
+ assert v.shape == (seqlen_k, num_head_kv, head_dim_v)
165
+ assert cu_seqlens_k.shape == (batch_size + 1,), (
166
+ "cu_seqlens_k must have shape (batch_size + 1,)"
167
+ )
168
+
169
+ if cu_seqlens_q is not None:
170
+ assert cu_seqlens_q.shape == (batch_size + 1,), (
171
+ "cu_seqlens_q must have shape (batch_size + 1,)"
172
+ )
173
+ assert seqused_q is None or seqused_q.shape == (batch_size,), (
174
+ "seqused_q must have shape (batch_size,)"
175
+ )
176
+ assert seqused_k is None or seqused_k.shape == (batch_size,), (
177
+ "seqused_k must have shape (batch_size,)"
178
+ )
179
+ assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
180
+ assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
181
+ for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:
182
+ if t is not None:
183
+ assert t.dtype == torch.int32, (
184
+ "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
185
+ )
186
+ assert t.stride(0) == 1, (
187
+ "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous"
188
+ )
189
+ if learnable_sink is not None:
190
+ assert learnable_sink.shape == (num_head,)
191
+ assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
192
+
193
+ assert all(
194
+ t is None or t.is_cuda
195
+ for t in (
196
+ q,
197
+ k,
198
+ v,
199
+ cu_seqlens_q,
200
+ cu_seqlens_k,
201
+ seqused_q,
202
+ seqused_k,
203
+ page_table,
204
+ learnable_sink,
205
+ )
206
+ ), "inputs must be on CUDA device"
207
+ assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
208
+ assert head_dim <= 256, "head_dim must be less than or equal to 256"
209
+ alignment = 16 // q.element_size()
210
+ assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
211
+ assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
212
+ if softmax_scale is None:
213
+ softmax_scale = 1.0 / math.sqrt(head_dim)
214
+ if softcap == 0.0:
215
+ softcap = None
216
+ qhead_per_kvhead = num_head // num_head_kv
217
+ if pack_gqa is None:
218
+ pack_gqa = qhead_per_kvhead > 1
219
+
220
+ out_torch_dtype = q.dtype
221
+ device = q.device
222
+ q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
223
+ lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q)
224
+ requires_grad = q.requires_grad or k.requires_grad or v.requires_grad
225
+
226
+ if out is None:
227
+ out = torch.empty(
228
+ *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device
229
+ )
230
+ else:
231
+ _validate_tensor(out, "out", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device)
232
+
233
+ if lse is None:
234
+ lse = (
235
+ torch.empty(lse_shape, dtype=torch.float32, device=device)
236
+ if requires_grad or return_lse
237
+ else None
238
+ )
239
+ else:
240
+ _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
241
+
242
+ dtype = torch2cute_dtype_map[q.dtype]
243
+ compute_capability = (
244
+ _get_device_capability()
245
+ if _compute_capability is None
246
+ else _compute_capability
247
+ )
248
+
249
+ assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
250
+
251
+ use_block_sparsity = block_sparse_tensors is not None
252
+
253
+ if mask_mod is None:
254
+ if causal:
255
+ window_size_right = 0
256
+ local = window_size_left is not None or window_size_right is not None
257
+ if window_size_left is not None or window_size_right is not None:
258
+ if window_size_left is None and window_size_right == 0:
259
+ causal, local = True, False
260
+ window_size_right = None
261
+ else:
262
+ causal, local = False, True
263
+ else:
264
+ causal, local = False, False
265
+
266
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
267
+
268
+ if compute_capability == 9: # TODO: tune block size according to hdim.
269
+ if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity:
270
+ n_block_size = 192
271
+
272
+ if compute_capability in [10, 11]:
273
+ if (
274
+ pack_gqa
275
+ and (128 % qhead_per_kvhead != 0)
276
+ ):
277
+ pack_gqa = False
278
+ # TODO: fix GQA + SplitKV + non-varlen
279
+ if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
280
+ pack_gqa = False
281
+
282
+ if max_seqlen_q is None:
283
+ max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
284
+ if max_seqlen_k is None:
285
+ max_seqlen_k = seqlen_k
286
+ seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
287
+ if compute_capability == 10:
288
+ q_stage = 2 if seqlen_q_packgqa > m_block_size else 1
289
+ else:
290
+ q_stage = 1
291
+
292
+ if num_splits < 1:
293
+ m_block_size_effective = q_stage * m_block_size
294
+ seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size))
295
+ num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
296
+ num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective
297
+ total_mblocks = batch_size * num_head_kv * num_m_blocks
298
+ num_splits = num_splits_heuristic(
299
+ total_mblocks,
300
+ torch.cuda.get_device_properties(device).multi_processor_count,
301
+ num_n_blocks,
302
+ 128,
303
+ )
304
+
305
+ is_split_kv = num_splits > 1
306
+ if is_split_kv:
307
+ out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
308
+ lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)
309
+
310
+ # hash score and mask mods for compile cache
311
+ score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
312
+ mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False
313
+
314
+ if softcap is not None:
315
+ assert score_mod is None, "softcap and score_mod cannot be used together"
316
+ score_mod = utils.create_softcap_scoremod(softcap)
317
+
318
+ is_varlen = (
319
+ cu_seqlens_q is not None
320
+ or cu_seqlens_k is not None
321
+ or seqused_q is not None
322
+ or seqused_k is not None
323
+ )
324
+
325
+ if mask_mod is not None:
326
+ if is_varlen:
327
+ raise NotImplementedError(
328
+ "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR."
329
+ )
330
+
331
+ if use_block_sparsity:
332
+ if is_varlen:
333
+ raise NotImplementedError(
334
+ "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR."
335
+ )
336
+ # NB: pack_gqa requires block sparse head dim == 1 (broadcasted)
337
+ if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1:
338
+ pack_gqa = False
339
+ if is_split_kv:
340
+ raise NotImplementedError(
341
+ "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
342
+ )
343
+
344
+ compile_key = (
345
+ dtype,
346
+ head_dim,
347
+ head_dim_v,
348
+ qhead_per_kvhead,
349
+ causal,
350
+ score_mod_hash,
351
+ mask_mod_hash,
352
+ use_block_sparsity,
353
+ len(aux_tensors) if aux_tensors is not None else 0,
354
+ lse is None,
355
+ cu_seqlens_q is None,
356
+ cu_seqlens_k is None,
357
+ seqused_q is None,
358
+ seqused_k is None,
359
+ page_table is not None,
360
+ window_size_left is not None,
361
+ window_size_right is not None,
362
+ learnable_sink is not None,
363
+ m_block_size,
364
+ n_block_size,
365
+ q_stage,
366
+ num_threads,
367
+ is_split_kv,
368
+ pack_gqa,
369
+ compute_capability,
370
+ page_size not in [None, 128], # paged KV non-TMA
371
+ )
372
+ if compile_key not in _flash_attn_fwd.compile_cache:
373
+ (
374
+ cu_seqlens_q_tensor,
375
+ cu_seqlens_k_tensor,
376
+ seqused_q_tensor,
377
+ seqused_k_tensor,
378
+ learnable_sink_tensor,
379
+ ) = [
380
+ to_cute_tensor(t, assumed_align=4, leading_dim=0)
381
+ if t is not None
382
+ else None
383
+ for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
384
+ ]
385
+ page_table_tensor = (
386
+ to_cute_tensor(page_table, assumed_align=4, leading_dim=1)
387
+ if page_table is not None
388
+ else None
389
+ )
390
+ q_tensor, k_tensor, v_tensor, o_tensor = [
391
+ to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial)
392
+ ]
393
+ if is_split_kv:
394
+ lse_tensor = to_cute_tensor(lse_partial, assumed_align=4)
395
+ elif lse is not None:
396
+ lse_tensor = to_cute_tensor(lse, assumed_align=4)
397
+ else:
398
+ lse_tensor = None
399
+
400
+ sparse_tensors = None
401
+ if block_sparse_tensors is not None:
402
+ if seqlen_q is None:
403
+ raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
404
+ expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
405
+ batch_size, num_head, seqlen_q, seqlen_k,
406
+ m_block_size, n_block_size, q_stage,
407
+ )
408
+ compile_time_normalized = normalize_block_sparse_tensors(
409
+ block_sparse_tensors,
410
+ expected_count_shape=expected_count_shape,
411
+ expected_index_shape=expected_index_shape,
412
+ )
413
+ sparse_tensors = to_cute_block_sparse_tensors(compile_time_normalized)
414
+
415
+ cute_aux_tensors = None
416
+ if aux_tensors is not None:
417
+ cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]
418
+
419
+ if compute_capability == 9:
420
+ assert page_table is None, "paged KV not supported on SM 9.0"
421
+ assert not is_split_kv, "SplitKV not supported on SM 9.0"
422
+ # fa_fwd = FlashAttentionForwardSm80(
423
+ fa_fwd = FlashAttentionForwardSm90(
424
+ dtype,
425
+ head_dim,
426
+ head_dim_v,
427
+ qhead_per_kvhead,
428
+ is_causal=causal,
429
+ is_local=local,
430
+ pack_gqa=pack_gqa,
431
+ tile_m=m_block_size,
432
+ tile_n=n_block_size,
433
+ # num_stages=1,
434
+ num_stages=2,
435
+ num_threads=num_threads,
436
+ Q_in_regs=False,
437
+ intra_wg_overlap=True,
438
+ mma_pv_is_rs=True,
439
+ mask_mod=mask_mod,
440
+ score_mod=score_mod,
441
+ has_aux_tensors=aux_tensors is not None,
442
+ )
443
+ elif compute_capability in [10, 11]:
444
+ fa_fwd = FlashAttentionForwardSm100(
445
+ head_dim,
446
+ head_dim_v,
447
+ qhead_per_kvhead=qhead_per_kvhead,
448
+ is_causal=causal,
449
+ is_local=local,
450
+ is_split_kv=is_split_kv,
451
+ pack_gqa=pack_gqa,
452
+ m_block_size=m_block_size,
453
+ n_block_size=n_block_size,
454
+ q_stage=q_stage,
455
+ is_persistent=not causal
456
+ and not local
457
+ and cu_seqlens_q is None
458
+ and seqused_q is None
459
+ and not is_split_kv,
460
+ score_mod=score_mod,
461
+ mask_mod=mask_mod,
462
+ has_aux_tensors=aux_tensors is not None,
463
+ paged_kv_non_tma=page_size not in [None, 128],
464
+ is_varlen_q=cu_seqlens_q is not None
465
+ or seqused_q is not None,
466
+ )
467
+ else:
468
+ raise ValueError(
469
+ f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x, 11.x"
470
+ )
471
+ # TODO: check @can_implement
472
+ _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
473
+ fa_fwd,
474
+ q_tensor,
475
+ k_tensor,
476
+ v_tensor,
477
+ o_tensor,
478
+ lse_tensor,
479
+ softmax_scale,
480
+ current_stream,
481
+ cu_seqlens_q_tensor,
482
+ cu_seqlens_k_tensor,
483
+ seqused_q_tensor,
484
+ seqused_k_tensor,
485
+ page_table_tensor,
486
+ window_size_left,
487
+ window_size_right,
488
+ learnable_sink_tensor,
489
+ sparse_tensors,
490
+ cute_aux_tensors,
491
+ options="--enable-tvm-ffi",
492
+ )
493
+
494
+ # Expand block sparse tensors to match actual head count (may be broadcast from 1)
495
+ normalized_block_sparse_tensors = None
496
+ if block_sparse_tensors is not None:
497
+ expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
498
+ batch_size, num_head, seqlen_q, seqlen_k,
499
+ m_block_size, n_block_size, q_stage,
500
+ )
501
+ normalized_block_sparse_tensors = normalize_block_sparse_tensors(
502
+ block_sparse_tensors,
503
+ expected_count_shape=expected_count_shape,
504
+ expected_index_shape=expected_index_shape,
505
+ )
506
+ _flash_attn_fwd.compile_cache[compile_key](
507
+ q,
508
+ k,
509
+ v,
510
+ out if not is_split_kv else out_partial,
511
+ lse_partial if is_split_kv else lse,
512
+ softmax_scale,
513
+ current_stream,
514
+ cu_seqlens_q,
515
+ cu_seqlens_k,
516
+ seqused_q,
517
+ seqused_k,
518
+ page_table,
519
+ window_size_left,
520
+ window_size_right,
521
+ learnable_sink,
522
+ normalized_block_sparse_tensors,
523
+ aux_tensors,
524
+ )
525
+ if is_split_kv:
526
+ _flash_attn_fwd_combine(
527
+ out_partial,
528
+ lse_partial.transpose(-1, -2),
529
+ out,
530
+ lse.transpose(-1, -2) if lse is not None else None,
531
+ cu_seqlens_q,
532
+ seqused_q,
533
+ )
534
+ return out, lse
535
+
536
+
537
+ _flash_attn_fwd.compile_cache = {}
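# How the cache above behaves (an illustrative sketch, not code from this module):
# cute.compile() runs at most once per compile_key, which encodes dtype, head dims, the
# causal/local flags, block sizes, and which optional inputs are present -- but not the batch
# size or sequence length. Assuming a CUDA device and no prior calls:
#
#   q = torch.randn(2, 512, 8, 128, dtype=torch.bfloat16, device="cuda")
#   k, v = torch.randn_like(q), torch.randn_like(q)
#   _flash_attn_fwd(q, k, v, causal=True)       # JIT-compiles and caches the kernel
#   q2 = torch.randn(1, 256, 8, 128, dtype=torch.bfloat16, device="cuda")
#   k2, v2 = torch.randn_like(q2), torch.randn_like(q2)
#   _flash_attn_fwd(q2, k2, v2, causal=True)    # same compile_key -> reuses the cached kernel
#   assert len(_flash_attn_fwd.compile_cache) == 1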
538
+
539
+
540
+ def _flash_attn_bwd(
541
+ q: torch.Tensor,
542
+ k: torch.Tensor,
543
+ v: torch.Tensor,
544
+ out: torch.Tensor,
545
+ dout: torch.Tensor,
546
+ lse: torch.Tensor,
547
+ softmax_scale: Optional[float] = None,
548
+ causal: bool = False,
549
+ softcap: float = 0.0,
550
+ window_size_left: Optional[int] = None,
551
+ window_size_right: Optional[int] = None,
552
+ m_block_size: int = 64,
553
+ n_block_size: int = 128,
554
+ num_threads: int = 256,
555
+ pack_gqa: bool = False,
556
+ num_stages_Q: int = 2,
557
+ num_stages_dO: int = 2,
558
+ SdP_swapAB: bool = False,
559
+ dKV_swapAB: bool = False,
560
+ dQ_swapAB: bool = False,
561
+ AtomLayoutMSdP: int = 2,
562
+ AtomLayoutNdKV: int = 2,
563
+ AtomLayoutMdQ: int = 2,
564
+ V_in_regs: bool = False,
565
+ cu_seqlens_q: Optional[torch.Tensor] = None,
566
+ cu_seqlens_k: Optional[torch.Tensor] = None,
567
+ seqused_q: Optional[torch.Tensor] = None,
568
+ seqused_k: Optional[torch.Tensor] = None,
569
+ max_seqlen_q: Optional[int] = None,
570
+ max_seqlen_k: Optional[int] = None,
571
+ deterministic: bool = False,
572
+ dq: Optional[torch.Tensor] = None,
573
+ dk: Optional[torch.Tensor] = None,
574
+ dv: Optional[torch.Tensor] = None,
575
+ score_mod: Optional[Callable] = None,
576
+ score_mod_bwd: Optional[Callable] = None,
577
+ mask_mod: Optional[Callable] = None,
578
+ aux_tensors: Optional[list[torch.Tensor]] = None,
579
+ block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
580
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
581
+ compute_capability = _get_device_capability()
582
+ assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
583
+
584
+ if compute_capability == 9:
585
+ m_block_size = 80 if not causal else 64
586
+ n_block_size = 128
587
+ num_stages_Q = 2
588
+ num_stages_dO = 2
589
+ num_stages_PdS = 2
590
+ SdP_swapAB = True
591
+ dKV_swapAB = False
592
+ dQ_swapAB = not causal
593
+ AtomLayoutMSdP = 1
594
+ AtomLayoutNdKV = 2
595
+ AtomLayoutMdQ = 1
596
+ cluster_size = 1
597
+ assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
598
+ else:
599
+ m_block_size = 128
600
+ n_block_size = 128
601
+ dQ_swapAB = False
602
+ dKV_swapAB = False
603
+ AtomLayoutMdQ = 1
604
+ AtomLayoutNdKV = 1
605
+ # TODO: support cluster size 2
606
+ cluster_size = 1
607
+ q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
608
+ maybe_contiguous(t)
609
+ for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
610
+ ]
611
+ num_head, head_dim = q.shape[-2:]
612
+ if cu_seqlens_q is None:
613
+ batch_size, seqlen_q = q.shape[:2]
614
+ total_q = batch_size * seqlen_q
615
+ else:
616
+ batch_size = cu_seqlens_q.shape[0] - 1
617
+ total_q = q.shape[0]
618
+ seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q
619
+
620
+ if cu_seqlens_k is None:
621
+ batch_size, seqlen_k = k.shape[:2]
622
+ total_k = batch_size * seqlen_k
623
+ else:
624
+ batch_size = cu_seqlens_k.shape[0] - 1
625
+ total_k = k.shape[0]
626
+ seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
627
+
628
+ num_head_kv = k.shape[-2]
629
+ head_dim_v = v.shape[-1]
630
+
631
+ if causal:
632
+ window_size_right = 0
633
+ local = window_size_left is not None or window_size_right is not None
634
+ if local:
635
+ if window_size_left is None and window_size_right == 0:
636
+ causal, local = True, False
637
+ window_size_right = None
638
+ else:
639
+ causal, local = False, True
640
+
641
+ use_block_sparsity = block_sparse_tensors is not None
642
+
643
+ # SM90 block-sparse backward: tile_m=64 is the GCD of an m_block_size that fits,
644
+ # the base block_m of 128 from forward, and block-sparse size for subtiling.
645
+ if compute_capability == 9 and use_block_sparsity:
646
+ m_block_size = 64
647
+ # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case)
648
+ dQ_swapAB = False
649
+
650
+ # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
651
+ subtile_factor = 2
652
+ sparse_block_size_q = subtile_factor * m_block_size
653
+
654
+ seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
655
+ seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
656
+
657
+ if cu_seqlens_k is None:
658
+ assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
659
+ assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
660
+ else:
661
+ assert k.shape == (total_k, num_head_kv, head_dim)
662
+ assert v.shape == (total_k, num_head_kv, head_dim_v)
663
+ assert cu_seqlens_k.shape == (batch_size + 1,), (
664
+ "cu_seqlens_k must have shape (batch_size + 1,)"
665
+ )
666
+
667
+ if cu_seqlens_q is not None:
668
+ assert cu_seqlens_q.shape == (batch_size + 1,), (
669
+ "cu_seqlens_q must have shape (batch_size + 1,)"
670
+ )
671
+
672
+ assert out.shape == (total_q, num_head, head_dim_v)
673
+ assert dout.shape == (total_q, num_head, head_dim_v)
674
+ assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)"
675
+ else:
676
+ assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v)
677
+ assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v)
678
+ assert lse.shape == (batch_size, num_head, seqlen_q), (
679
+ "lse must have shape (batch_size, num_head, seqlen_q)"
680
+ )
681
+
682
+ assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
683
+ assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, (
684
+ "inputs must have the same dtype"
685
+ )
686
+ for t in [cu_seqlens_q, cu_seqlens_k]:
687
+ if t is not None:
688
+ assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32"
689
+ assert lse.dtype == torch.float32, "lse must be float32"
690
+ assert all(
691
+ t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
692
+ ), "inputs must be on CUDA device"
693
+ assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
694
+ assert head_dim <= 256, "head_dim must be less than or equal to 256"
695
+ alignment = 16 // q.element_size()
696
+ assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
697
+ assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
698
+ if softmax_scale is None:
699
+ softmax_scale = 1.0 / math.sqrt(head_dim)
700
+ qhead_per_kvhead = num_head // num_head_kv
701
+ if pack_gqa is None:
702
+ pack_gqa = qhead_per_kvhead > 1
703
+ # pack_gqa not yet supported in the backward pass
704
+ pack_gqa = False
705
+ if compute_capability not in [10, 11]:
706
+ assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now"
707
+
708
+ if score_mod is not None:
709
+ assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
710
+ assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
711
+ assert cu_seqlens_q is None and cu_seqlens_k is None, (
712
+ "varlen + score_mod not supported in bwd yet"
713
+ )
714
+
715
+ device = q.device
716
+ out_torch_dtype = q.dtype
717
+
718
+ if dq is None:
719
+ dq = torch.empty_like(q)
720
+ else:
721
+ _validate_tensor(dq, "dq", q.shape, out_torch_dtype, device)
722
+
723
+ if dk is None:
724
+ dk = torch.empty_like(k)
725
+ else:
726
+ _validate_tensor(dk, "dk", k.shape, out_torch_dtype, device)
727
+
728
+ if dv is None:
729
+ dv = torch.empty_like(v)
730
+ else:
731
+ _validate_tensor(dv, "dv", v.shape, out_torch_dtype, device)
732
+
733
+ head_dim_rounded = (head_dim + 32 - 1) // 32 * 32
734
+
735
+ if cu_seqlens_q is None:
736
+ dq_accum = torch.empty(
737
+ batch_size,
738
+ num_head,
739
+ seqlen_q_rounded * head_dim_rounded,
740
+ dtype=torch.float32,
741
+ device=device,
742
+ )
743
+ dpsum = torch.empty(
744
+ batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
745
+ )
746
+ lse_log2 = torch.empty(
747
+ batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
748
+ )
749
+ else:
750
+ total_q_rounded_padded = (
751
+ (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size
752
+ )
753
+ dq_accum = torch.empty(
754
+ num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device
755
+ )
756
+ dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
757
+ lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
758
+
759
+ dKV_postprocess = qhead_per_kvhead > 1
760
+ if dKV_postprocess:
761
+ head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
762
+ if cu_seqlens_k is None:
763
+ num_n_blocks = seqlen_k_rounded // n_block_size
764
+ if cluster_size == 2 and num_n_blocks % cluster_size != 0:
765
+ seqlen_k_rounded = seqlen_k_rounded + n_block_size
766
+ dk_accum = torch.zeros(
767
+ batch_size,
768
+ num_head_kv,
769
+ seqlen_k_rounded * head_dim_rounded,
770
+ dtype=torch.float32,
771
+ device=device,
772
+ )
773
+ dv_accum = torch.zeros(
774
+ batch_size,
775
+ num_head_kv,
776
+ seqlen_k_rounded * head_dim_v_rounded,
777
+ dtype=torch.float32,
778
+ device=device,
779
+ )
780
+ else:
781
+ total_k_rounded_padded = (
782
+ (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size
783
+ )
784
+ num_n_blocks = total_k_rounded_padded // n_block_size
785
+ if cluster_size == 2 and num_n_blocks % cluster_size != 0:
786
+ total_k_rounded_padded = total_k_rounded_padded + n_block_size
787
+ dk_accum = torch.zeros(
788
+ num_head_kv,
789
+ total_k_rounded_padded * head_dim_rounded,
790
+ dtype=torch.float32,
791
+ device=device,
792
+ )
793
+ dv_accum = torch.zeros(
794
+ num_head_kv,
795
+ total_k_rounded_padded * head_dim_v_rounded,
796
+ dtype=torch.float32,
797
+ device=device,
798
+ )
799
+
800
+ dtype = torch2cute_dtype_map[q.dtype]
801
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
802
+
803
+ if deterministic:
804
+ dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device=device)
805
+ else:
806
+ dQ_semaphore = None
807
+
808
+ if deterministic and qhead_per_kvhead > 1:
809
+ dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device)
810
+ dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device)
811
+ else:
812
+ dK_semaphore = None
813
+ dV_semaphore = None
814
+
815
+ # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
816
+ compile_key_pre = (
817
+ compute_capability,
818
+ dtype,
819
+ head_dim_v,
820
+ m_block_size,
821
+ num_threads,
822
+ cu_seqlens_q is None,
823
+ seqused_q is None,
824
+ )
825
+ if compile_key_pre not in _flash_attn_bwd.compile_cache_pre:
826
+ o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)]
827
+ dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
828
+ to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
829
+ ]
830
+ lse_tensor = to_cute_tensor(lse, assumed_align=4)
831
+ cu_seqlens_q_tensor, seqused_q_tensor = [
832
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
833
+ for t in (cu_seqlens_q, seqused_q)
834
+ ]
835
+ arch = compute_capability * 10
836
+ fa_bwd_pre = FlashAttentionBackwardPreprocess(
837
+ dtype,
838
+ head_dim_v,
839
+ arch,
840
+ m_block_size,
841
+ num_threads=num_threads,
842
+ )
843
+ # TODO: check @can_implement
844
+ _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile(
845
+ fa_bwd_pre,
846
+ o_tensor,
847
+ do_tensor,
848
+ dpsum_tensor,
849
+ lse_tensor,
850
+ lse_log2_tensor,
851
+ dq_accum_tensor,
852
+ cu_seqlens_q_tensor,
853
+ seqused_q_tensor,
854
+ current_stream,
855
+ options="--enable-tvm-ffi",
856
+ )
857
+ _flash_attn_bwd.compile_cache_pre[compile_key_pre](
858
+ out,
859
+ dout,
860
+ dpsum,
861
+ lse,
862
+ lse_log2,
863
+ dq_accum,
864
+ cu_seqlens_q,
865
+ seqused_q,
866
+ current_stream,
867
+ )
868
+
869
+ # NB: how num_threads is applied across the 3 kernels
870
+ # There are pre-, main, and post-processing kernels; currently num_threads is only actually
871
+ # used for the preprocessing kernel. We then hard-code 384 for the main and post-processing kernels,
872
+ # and we do this before generating the cache key.
873
+ num_threads = 384
874
+
875
+ # Backward kernel: compute dk, dv, dq_accum.
876
+ score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
877
+ score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False
878
+ mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False
879
+ num_aux_tensors = len(aux_tensors) if aux_tensors else 0
880
+ cute_aux_tensors = None
881
+ if aux_tensors is not None:
882
+ cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]
883
+
884
+ if compute_capability == 9:
885
+ compile_key = (
886
+ compute_capability,
887
+ dtype,
888
+ head_dim,
889
+ head_dim_v,
890
+ qhead_per_kvhead,
891
+ causal,
892
+ softcap != 0.0,
893
+ m_block_size,
894
+ n_block_size,
895
+ num_threads,
896
+ pack_gqa,
897
+ num_stages_Q,
898
+ num_stages_dO,
899
+ SdP_swapAB,
900
+ dKV_swapAB,
901
+ dQ_swapAB,
902
+ AtomLayoutMSdP,
903
+ AtomLayoutNdKV,
904
+ AtomLayoutMdQ,
905
+ V_in_regs,
906
+ cu_seqlens_q is None,
907
+ cu_seqlens_k is None,
908
+ seqused_q is None,
909
+ seqused_k is None,
910
+ score_mod_hash,
911
+ score_mod_bwd_hash,
912
+ mask_mod_hash,
913
+ num_aux_tensors,
914
+ use_block_sparsity,
915
+ )
916
+ else:
917
+ compile_key = (
918
+ compute_capability,
919
+ dtype,
920
+ head_dim,
921
+ head_dim_v,
922
+ qhead_per_kvhead,
923
+ causal,
924
+ window_size_left is not None,
925
+ window_size_right is not None,
926
+ softcap != 0.0,
927
+ m_block_size,
928
+ n_block_size,
929
+ num_threads,
930
+ pack_gqa,
931
+ cluster_size,
932
+ deterministic,
933
+ score_mod_hash,
934
+ score_mod_bwd_hash,
935
+ mask_mod_hash,
936
+ num_aux_tensors,
937
+ use_block_sparsity,
938
+ cu_seqlens_q is None,
939
+ cu_seqlens_k is None,
940
+ seqused_q is None,
941
+ seqused_k is None,
942
+ )
943
+ if compile_key not in _flash_attn_bwd.compile_cache:
944
+ q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
945
+ to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv)
946
+ ]
947
+ dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
948
+ to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
949
+ ]
950
+ if dKV_postprocess:
951
+ dk_accum_tensor, dv_accum_tensor = [
952
+ to_cute_tensor(t) for t in (dk_accum, dv_accum)
953
+ ]
954
+ cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [
955
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
956
+ for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
957
+ ]
958
+ dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [
959
+ utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order())
960
+ if t is not None else None
961
+ for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
962
+ ]
963
+ fa_bwd_sm80 = FlashAttentionBackwardSm80(
964
+ dtype,
965
+ head_dim,
966
+ head_dim_v,
967
+ qhead_per_kvhead,
968
+ m_block_size,
969
+ n_block_size,
970
+ num_stages_Q,
971
+ num_stages_dO,
972
+ num_threads,
973
+ pack_gqa,
974
+ causal,
975
+ SdP_swapAB,
976
+ dKV_swapAB,
977
+ dQ_swapAB,
978
+ AtomLayoutMSdP,
979
+ AtomLayoutNdKV,
980
+ AtomLayoutMdQ,
981
+ V_in_regs=V_in_regs,
982
+ )
983
+ if compute_capability == 9:
984
+ fa_bwd_obj = FlashAttentionBackwardSm90(
985
+ dtype,
986
+ head_dim,
987
+ head_dim_v,
988
+ qhead_per_kvhead,
989
+ causal,
990
+ m_block_size,
991
+ n_block_size,
992
+ num_stages_Q,
993
+ num_stages_dO,
994
+ num_stages_PdS,
995
+ SdP_swapAB,
996
+ dKV_swapAB,
997
+ dQ_swapAB,
998
+ AtomLayoutMSdP,
999
+ AtomLayoutNdKV,
1000
+ AtomLayoutMdQ,
1001
+ num_threads,
1002
+ V_in_regs=V_in_regs,
1003
+ score_mod=score_mod,
1004
+ score_mod_bwd=score_mod_bwd,
1005
+ mask_mod=mask_mod,
1006
+ has_aux_tensors=aux_tensors is not None,
1007
+ subtile_factor=subtile_factor,
1008
+ )
1009
+ else:
1010
+ fa_bwd_obj = FlashAttentionBackwardSm100(
1011
+ head_dim,
1012
+ head_dim_v,
1013
+ is_causal=causal,
1014
+ is_local=local,
1015
+ qhead_per_kvhead=qhead_per_kvhead,
1016
+ # tile_m=m_block_size,
1017
+ # tile_n=n_block_size,
1018
+ cluster_size=cluster_size,
1019
+ # cluster_size=1,
1020
+ deterministic=deterministic,
1021
+ score_mod=score_mod,
1022
+ score_mod_bwd=score_mod_bwd,
1023
+ mask_mod=mask_mod,
1024
+ has_aux_tensors=aux_tensors is not None,
1025
+ subtile_factor=subtile_factor,
1026
+ )
1027
+
1028
+ # Block sparse tensors for backward use Q-direction indexing (transposed from forward).
1029
+ # sparse_block_size_q = subtile_factor * tile_m matches BlockMask granularity.
1030
+ sparse_tensors_compile = None
1031
+ if block_sparse_tensors is not None:
1032
+ expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
1033
+ batch_size, num_head, seqlen_q, seqlen_k,
1034
+ m_block_size, n_block_size, subtile_factor,
1035
+ )
1036
+ compile_time_normalized = normalize_block_sparse_tensors(
1037
+ block_sparse_tensors,
1038
+ expected_count_shape=expected_count_shape,
1039
+ expected_index_shape=expected_index_shape,
1040
+ context="_flash_attn_bwd",
1041
+ hint=lambda: (
1042
+ f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
1043
+ f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
1044
+ f"(sparse_block_size_q={sparse_block_size_q})."
1045
+ ),
1046
+ )
1047
+ sparse_tensors_compile = to_cute_block_sparse_tensors(compile_time_normalized)
1048
+
1049
+ # TODO: check @can_implement
1050
+ _flash_attn_bwd.compile_cache[compile_key] = cute.compile(
1051
+ fa_bwd_obj,
1052
+ q_tensor,
1053
+ k_tensor,
1054
+ v_tensor,
1055
+ do_tensor,
1056
+ lse_log2_tensor,
1057
+ dpsum_tensor,
1058
+ dq_accum_tensor,
1059
+ dk_tensor if not dKV_postprocess else dk_accum_tensor,
1060
+ dv_tensor if not dKV_postprocess else dv_accum_tensor,
1061
+ softmax_scale,
1062
+ current_stream,
1063
+ cu_seqlens_q_tensor,
1064
+ cu_seqlens_k_tensor,
1065
+ seqused_q_tensor,
1066
+ seqused_k_tensor,
1067
+ None, # softcap - not yet supported in backward
1068
+ window_size_left,
1069
+ window_size_right,
1070
+ dQ_semaphore_tensor,
1071
+ dK_semaphore_tensor,
1072
+ dV_semaphore_tensor,
1073
+ cute_aux_tensors,
1074
+ sparse_tensors_compile,
1075
+ options="--enable-tvm-ffi",
1076
+ )
1077
+ # Runtime normalization of block sparse tensors for both SM90 and SM100
1078
+ normalized_block_sparse_tensors = None
1079
+ if block_sparse_tensors is not None:
1080
+ expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
1081
+ batch_size, num_head, seqlen_q, seqlen_k,
1082
+ m_block_size, n_block_size, subtile_factor,
1083
+ )
1084
+ normalized_block_sparse_tensors = normalize_block_sparse_tensors(
1085
+ block_sparse_tensors,
1086
+ expected_count_shape=expected_count_shape,
1087
+ expected_index_shape=expected_index_shape,
1088
+ context="_flash_attn_bwd",
1089
+ hint=lambda: (
1090
+ f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
1091
+ f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
1092
+ f"(sparse_block_size_q={sparse_block_size_q})."
1093
+ ),
1094
+ )
1095
+
1096
+ _flash_attn_bwd.compile_cache[compile_key](
1097
+ q,
1098
+ k,
1099
+ v,
1100
+ dout,
1101
+ lse_log2,
1102
+ dpsum,
1103
+ dq_accum,
1104
+ dk if not dKV_postprocess else dk_accum,
1105
+ dv if not dKV_postprocess else dv_accum,
1106
+ softmax_scale,
1107
+ current_stream,
1108
+ cu_seqlens_q,
1109
+ cu_seqlens_k,
1110
+ seqused_q,
1111
+ seqused_k,
1112
+ None, # softcap - not yet supported in backward
1113
+ window_size_left,
1114
+ window_size_right,
1115
+ dQ_semaphore,
1116
+ dK_semaphore,
1117
+ dV_semaphore,
1118
+ aux_tensors,
1119
+ normalized_block_sparse_tensors,
1120
+ )
1121
+
1122
+ num_threads = 256 if compute_capability == 9 else 128
1123
+ arch = compute_capability * 10
1124
+ # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16
1125
+ compile_key_post = (
1126
+ compute_capability,
1127
+ dtype,
1128
+ head_dim,
1129
+ m_block_size,
1130
+ num_threads,
1131
+ AtomLayoutMdQ,
1132
+ dQ_swapAB,
1133
+ cu_seqlens_q is None,
1134
+ seqused_q is None,
1135
+ )
1136
+ if compile_key_post not in _flash_attn_bwd.compile_cache_post:
1137
+ dq_accum_tensor = to_cute_tensor(dq_accum)
1138
+ dq_tensor = to_cute_tensor(dq)
1139
+ cu_seqlens_q_tensor, seqused_q_tensor = [
1140
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
1141
+ for t in (cu_seqlens_q, seqused_q)
1142
+ ]
1143
+ fa_bwd_post = FlashAttentionBackwardPostprocess(
1144
+ dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB
1145
+ )
1146
+ # TODO: check @can_implement
1147
+ _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
1148
+ fa_bwd_post,
1149
+ dq_accum_tensor,
1150
+ dq_tensor,
1151
+ softmax_scale,
1152
+ cu_seqlens_q_tensor,
1153
+ seqused_q_tensor,
1154
+ current_stream,
1155
+ options="--enable-tvm-ffi",
1156
+ )
1157
+ _flash_attn_bwd.compile_cache_post[compile_key_post](
1158
+ dq_accum,
1159
+ dq,
1160
+ softmax_scale,
1161
+ cu_seqlens_q,
1162
+ seqused_q,
1163
+ current_stream,
1164
+ )
1165
+
1166
+ if dKV_postprocess:
1167
+ # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16
1168
+ compile_key_post = (
1169
+ compute_capability,
1170
+ dtype,
1171
+ head_dim,
1172
+ n_block_size,
1173
+ num_threads,
1174
+ AtomLayoutNdKV,
1175
+ dKV_swapAB,
1176
+ cu_seqlens_k is None,
1177
+ seqused_k is None,
1178
+ )
1179
+ if compile_key_post not in _flash_attn_bwd.compile_cache_post:
1180
+ dk_accum_tensor = to_cute_tensor(dk_accum)
1181
+ dk_tensor = to_cute_tensor(dk)
1182
+ cu_seqlens_k_tensor, seqused_k_tensor = [
1183
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
1184
+ for t in (cu_seqlens_k, seqused_k)
1185
+ ]
1186
+ arch = compute_capability * 10
1187
+ fa_bwd_post = FlashAttentionBackwardPostprocess(
1188
+ dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
1189
+ )
1190
+ # TODO: check @can_implement
1191
+ _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
1192
+ fa_bwd_post,
1193
+ dk_accum_tensor,
1194
+ dk_tensor,
1195
+ softmax_scale,
1196
+ cu_seqlens_k_tensor,
1197
+ seqused_k_tensor,
1198
+ current_stream,
1199
+ options="--enable-tvm-ffi",
1200
+ )
1201
+ _flash_attn_bwd.compile_cache_post[compile_key_post](
1202
+ dk_accum,
1203
+ dk,
1204
+ softmax_scale,
1205
+ cu_seqlens_k,
1206
+ seqused_k,
1207
+ current_stream,
1208
+ )
1209
+ compile_key_post = (
1210
+ compute_capability,
1211
+ dtype,
1212
+ head_dim_v,
1213
+ n_block_size,
1214
+ num_threads,
1215
+ AtomLayoutNdKV,
1216
+ dKV_swapAB,
1217
+ cu_seqlens_k is None,
1218
+ seqused_k is None,
1219
+ )
1220
+ if compile_key_post not in _flash_attn_bwd.compile_cache_post:
1221
+ dv_accum_tensor = to_cute_tensor(dv_accum)
1222
+ dv_tensor = to_cute_tensor(dv)
1223
+ cu_seqlens_k_tensor, seqused_k_tensor = [
1224
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
1225
+ for t in (cu_seqlens_k, seqused_k)
1226
+ ]
1227
+ arch = compute_capability * 10
1228
+ fa_bwd_post = FlashAttentionBackwardPostprocess(
1229
+ dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
1230
+ )
1231
+ # TODO: check @can_implement
1232
+ _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
1233
+ fa_bwd_post,
1234
+ dv_accum_tensor,
1235
+ dv_tensor,
1236
+ cutlass.Float32(1.0),
1237
+ cu_seqlens_k_tensor,
1238
+ seqused_k_tensor,
1239
+ current_stream,
1240
+ options="--enable-tvm-ffi",
1241
+ )
1242
+ _flash_attn_bwd.compile_cache_post[compile_key_post](
1243
+ dv_accum,
1244
+ dv,
1245
+ 1.0,
1246
+ cu_seqlens_k,
1247
+ seqused_k,
1248
+ current_stream,
1249
+ )
1250
+
1251
+ return dq, dk, dv
1252
+
1253
+
1254
+ _flash_attn_bwd.compile_cache_pre = {}
1255
+ _flash_attn_bwd.compile_cache = {}
1256
+ _flash_attn_bwd.compile_cache_post = {}
1257
+
1258
+
1259
+ class FlashAttnFunc(torch.autograd.Function):
1260
+ @staticmethod
1261
+ def forward(
1262
+ ctx,
1263
+ q: torch.Tensor,
1264
+ k: torch.Tensor,
1265
+ v: torch.Tensor,
1266
+ softmax_scale: Optional[float] = None,
1267
+ causal: bool = False,
1268
+ window_size: Tuple[Optional[int], Optional[int]] = (None, None),
1269
+ learnable_sink: Optional[torch.Tensor] = None,
1270
+ softcap: float = 0.0,
1271
+ num_splits: int = 1,
1272
+ pack_gqa: Optional[bool] = None,
1273
+ deterministic: bool = False,
1274
+ mask_mod: Optional[Callable] = None,
1275
+ full_block_cnt: Optional[torch.Tensor] = None,
1276
+ full_block_idx: Optional[torch.Tensor] = None,
1277
+ mask_block_cnt: Optional[torch.Tensor] = None,
1278
+ mask_block_idx: Optional[torch.Tensor] = None,
1279
+ ):
1280
+ # Only create block sparse tensors if at least one block sparse parameter is provided
1281
+ block_sparse_tensors = None
1282
+ if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]):
1283
+ block_sparse_tensors = BlockSparseTensorsTorch(
1284
+ full_block_cnt=full_block_cnt,
1285
+ full_block_idx=full_block_idx,
1286
+ mask_block_cnt=mask_block_cnt,
1287
+ mask_block_idx=mask_block_idx,
1288
+ )
1289
+ out, lse = _flash_attn_fwd(
1290
+ q,
1291
+ k,
1292
+ v,
1293
+ softmax_scale=softmax_scale,
1294
+ causal=causal,
1295
+ window_size_left=window_size[0],
1296
+ window_size_right=window_size[1],
1297
+ learnable_sink=learnable_sink,
1298
+ softcap=softcap,
1299
+ num_splits=num_splits,
1300
+ pack_gqa=pack_gqa,
1301
+ mask_mod=mask_mod,
1302
+ block_sparse_tensors=block_sparse_tensors
1303
+ )
1304
+ ctx.save_for_backward(q, k, v, out, lse)
1305
+ ctx.softmax_scale = softmax_scale
1306
+ ctx.causal = causal
1307
+ ctx.window_size = window_size
1308
+ ctx.softcap = softcap
1309
+ ctx.deterministic = deterministic
1310
+ return out, lse
1311
+
1312
+ @staticmethod
1313
+ def backward(ctx, dout, *args):
1314
+ q, k, v, out, lse = ctx.saved_tensors
1315
+ dq, dk, dv = _flash_attn_bwd(
1316
+ q,
1317
+ k,
1318
+ v,
1319
+ out,
1320
+ dout,
1321
+ lse,
1322
+ ctx.softmax_scale,
1323
+ ctx.causal,
1324
+ ctx.softcap,
1325
+ window_size_left=ctx.window_size[0],
1326
+ window_size_right=ctx.window_size[1],
1327
+ deterministic=ctx.deterministic,
1328
+ )
1329
+ return dq, dk, dv, *((None,) * 20) # Extra Nones is fine
1330
+
1331
+
1332
+ class FlashAttnVarlenFunc(torch.autograd.Function):
1333
+ @staticmethod
1334
+ def forward(
1335
+ ctx,
1336
+ q: torch.Tensor,
1337
+ k: torch.Tensor,
1338
+ v: torch.Tensor,
1339
+ cu_seqlens_q: Optional[torch.Tensor],
1340
+ cu_seqlens_k: Optional[torch.Tensor],
1341
+ seqused_q: Optional[torch.Tensor] = None,
1342
+ seqused_k: Optional[torch.Tensor] = None,
1343
+ max_seqlen_q: Optional[int] = None,
1344
+ max_seqlen_k: Optional[int] = None,
1345
+ page_table: Optional[torch.Tensor] = None,
1346
+ softmax_scale: Optional[float] = None,
1347
+ causal: bool = False,
1348
+ window_size: Tuple[Optional[int], Optional[int]] = (None, None),
1349
+ learnable_sink: Optional[torch.Tensor] = None,
1350
+ softcap: float = 0.0,
1351
+ num_splits: int = 1,
1352
+ pack_gqa: Optional[bool] = None,
1353
+ deterministic: bool = False,
1354
+ score_mod: Optional[Callable] = None,
1355
+ aux_tensors: Optional[list] = None,
1356
+ ):
1357
+ out, lse = _flash_attn_fwd(
1358
+ q,
1359
+ k,
1360
+ v,
1361
+ cu_seqlens_q,
1362
+ cu_seqlens_k,
1363
+ seqused_q,
1364
+ seqused_k,
1365
+ max_seqlen_q=max_seqlen_q,
1366
+ max_seqlen_k=max_seqlen_k,
1367
+ page_table=page_table,
1368
+ softmax_scale=softmax_scale,
1369
+ causal=causal,
1370
+ window_size_left=window_size[0],
1371
+ window_size_right=window_size[1],
1372
+ learnable_sink=learnable_sink,
1373
+ softcap=softcap,
1374
+ num_splits=num_splits,
1375
+ pack_gqa=pack_gqa,
1376
+ score_mod=score_mod,
1377
+ aux_tensors=aux_tensors,
1378
+ )
1379
+ ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
1380
+ ctx.softmax_scale = softmax_scale
1381
+ ctx.causal = causal
1382
+ ctx.window_size = window_size
1383
+ ctx.softcap = softcap
1384
+ ctx.deterministic = deterministic
1385
+ ctx.max_seqlen_q = max_seqlen_q
1386
+ ctx.max_seqlen_k = max_seqlen_k
1387
+ return out, lse
1388
+
1389
+ @staticmethod
1390
+ def backward(ctx, dout, *args):
1391
+ q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
1392
+ assert ctx.softcap == 0.0
1393
+ dq, dk, dv = _flash_attn_bwd(
1394
+ q,
1395
+ k,
1396
+ v,
1397
+ out,
1398
+ dout,
1399
+ lse,
1400
+ ctx.softmax_scale,
1401
+ ctx.causal,
1402
+ ctx.softcap,
1403
+ window_size_left=ctx.window_size[0],
1404
+ window_size_right=ctx.window_size[1],
1405
+ cu_seqlens_q=cu_seqlens_q,
1406
+ cu_seqlens_k=cu_seqlens_k,
1407
+ seqused_q=seqused_q,
1408
+ seqused_k=seqused_k,
1409
+ max_seqlen_q=ctx.max_seqlen_q,
1410
+ max_seqlen_k=ctx.max_seqlen_k,
1411
+ deterministic=ctx.deterministic,
1412
+ )
1413
+
1414
+ return dq, dk, dv, *((None,) * 20)
1415
+
1416
+
1417
+ def flash_attn_func(
1418
+ q: torch.Tensor,
1419
+ k: torch.Tensor,
1420
+ v: torch.Tensor,
1421
+ softmax_scale: Optional[float] = None,
1422
+ causal: bool = False,
1423
+ window_size: Tuple[Optional[int], Optional[int]] = (None, None),
1424
+ learnable_sink: Optional[torch.Tensor] = None,
1425
+ softcap: float = 0.0,
1426
+ num_splits: int = 1,
1427
+ pack_gqa: Optional[bool] = None,
1428
+ deterministic: bool = False,
1429
+ mask_mod: Optional[Callable] = None,
1430
+ full_block_cnt: Optional[torch.Tensor] = None,
1431
+ full_block_idx: Optional[torch.Tensor] = None,
1432
+ mask_block_cnt: Optional[torch.Tensor] = None,
1433
+ mask_block_idx: Optional[torch.Tensor] = None,
1434
+ ):
1435
+ return FlashAttnFunc.apply(
1436
+ q,
1437
+ k,
1438
+ v,
1439
+ softmax_scale,
1440
+ causal,
1441
+ window_size,
1442
+ learnable_sink,
1443
+ softcap,
1444
+ num_splits,
1445
+ pack_gqa,
1446
+ deterministic,
1447
+ mask_mod,
1448
+ full_block_cnt,
1449
+ full_block_idx,
1450
+ mask_block_cnt,
1451
+ mask_block_idx,
1452
+ )
1453
+
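# A training-style sketch (assumes a CUDA device; shapes are arbitrary). Because flash_attn_func
# routes through torch.autograd.Function, the backward kernels above populate q.grad/k.grad/v.grad:
q = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
out, lse = flash_attn_func(q, k, v, causal=True)
out.sum().backward()  # dispatches to _flash_attn_bwd via FlashAttnFunc.backward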
1454
+
+ def flash_attn_varlen_func(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     cu_seqlens_q: Optional[torch.Tensor] = None,
+     cu_seqlens_k: Optional[torch.Tensor] = None,
+     max_seqlen_q: Optional[int] = None,
+     max_seqlen_k: Optional[int] = None,
+     seqused_q: Optional[torch.Tensor] = None,
+     seqused_k: Optional[torch.Tensor] = None,
+     page_table: Optional[torch.Tensor] = None,
+     softmax_scale: Optional[float] = None,
+     causal: bool = False,
+     window_size: Tuple[Optional[int], Optional[int]] = (None, None),
+     learnable_sink: Optional[torch.Tensor] = None,
+     softcap: float = 0.0,
+     num_splits: int = 1,
+     pack_gqa: Optional[bool] = None,
+     deterministic: bool = False,
+     score_mod: Optional[Callable] = None,
+     aux_tensors: Optional[list] = None,
+ ):
+     return FlashAttnVarlenFunc.apply(
+         q,
+         k,
+         v,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         seqused_q,
+         seqused_k,
+         max_seqlen_q,
+         max_seqlen_k,
+         page_table,
+         softmax_scale,
+         causal,
+         window_size,
+         learnable_sink,
+         softcap,
+         num_splits,
+         pack_gqa,
+         deterministic,
+         score_mod,
+         aux_tensors,
+     )
+
+
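The varlen entry point takes packed sequences plus int32 cumulative sequence lengths. A minimal sketch (illustrative only, assuming the packed (total_tokens, nheads, headdim) layout and a supported CUDA device): two sequences of lengths 5 and 7 packed into one tensor.

import torch

nheads, headdim = 8, 128
seqlens = [5, 7]
total = sum(seqlens)

q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Prefix sums of the per-sequence lengths: [0, 5, 12].
cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")

out, lse = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
)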
+ def _flash_attn_fwd_combine(
+     out_partial: torch.Tensor,
+     lse_partial: torch.Tensor,
+     out: torch.Tensor,
+     lse: Optional[torch.Tensor] = None,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     seqused: Optional[torch.Tensor] = None,
+     num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
+     semaphore_to_reset: Optional[torch.Tensor] = None,
+ ) -> None:
+     """Forward combine kernel for split attention computation.
+
+     Combines partial outputs and log-sum-exp values from multiple splits
+     of the attention computation into final outputs.
+
+     Args:
+         out_partial: Partial outputs tensor, (num_splits, batch, seqlen, nheads, headdim), or
+             (num_splits, total_q, nheads, headdim) if cu_seqlens is given
+         lse_partial: Partial LSE tensor, (num_splits, batch, seqlen, nheads), or
+             (num_splits, total_q, nheads) if cu_seqlens is given
+         out: Output tensor, (batch, seqlen, nheads, headdim), or (total_q, nheads, headdim) if cu_seqlens is given
+         lse: Output LSE tensor, (batch, seqlen, nheads), or (total_q, nheads) if cu_seqlens is given
+         cu_seqlens: Cumulative sequence lengths for variable-length sequences
+         seqused: Used sequence length for each batch entry
+         num_splits_dynamic_ptr: Dynamic number of splits per batch entry
+         semaphore_to_reset: Semaphore for synchronization
+
+     Returns:
+         None
+     """
+     # Input validation
+     assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
+     assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
+     assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
+         "out_partial must be fp16, bf16, or fp32"
+     )
+     assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
+     assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
+     assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension"
+     assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension"
+     assert lse_partial.shape == out_partial.shape[:-1]
+
+     # Determine if this is variable length based on dimensions
+     is_varlen = out_partial.dim() == 4
+
+     # Validate output tensor shapes and types
+     assert out.shape == out_partial.shape[1:], "out shape mismatch"
+     if lse is not None:
+         assert lse.shape == lse_partial.shape[1:], "lse shape mismatch"
+         assert lse.dtype == torch.float32, "lse must be fp32"
+
+     # Validate optional tensors
+     for t, name in [
+         (cu_seqlens, "cu_seqlens"),
+         (seqused, "seqused"),
+         (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
+     ]:
+         if t is not None:
+             assert t.dtype == torch.int32, f"{name} must be int32"
+             assert t.is_cuda, f"{name} must be on CUDA device"
+             assert t.is_contiguous(), f"{name} must be contiguous"
+
+     head_dim = out_partial.shape[-1]
+     num_splits = out_partial.shape[0]
+     assert num_splits <= 256
+     # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
+     # so that kBlockM is smaller and we have more parallelism.
+     k_block_size = 64 if head_dim <= 64 else 128
+     # We want kBlockM to be as small as possible to maximize parallelism.
+     # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
+     m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
+     log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
+     if m_block_size == 8:
+         # If kBlockM == 8 then the minimum number of splits is 32.
+         # TODO: we can deal with this by using 128 threads instead
+         log_max_splits = max(log_max_splits, 5)
+
+     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+     # Create combine kernel configuration
+     dtype = torch2cute_dtype_map[out.dtype]
+     dtype_partial = torch2cute_dtype_map[out_partial.dtype]
+
+     compile_key = (
+         dtype,
+         dtype_partial,
+         head_dim,
+         m_block_size,
+         k_block_size,
+         log_max_splits,
+         cu_seqlens is not None,
+         seqused is not None,
+         lse is not None,
+     )
+
+     if compile_key not in _flash_attn_fwd_combine.compile_cache:
+         out_partial_tensor = to_cute_tensor(
+             out_partial, leading_dim=4 if not is_varlen else 3
+         )
+         lse_partial_tensor = to_cute_tensor(
+             lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2
+         )
+         out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2)
+         lse_tensor = (
+             to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2)
+             if lse is not None
+             else None
+         )
+
+         optional_tensors = [
+             to_cute_tensor(t, assumed_align=4, leading_dim=0)
+             if t is not None
+             else None
+             for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
+         ]
+         cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = (
+             optional_tensors
+         )
+         fa_combine = FlashAttentionForwardCombine(
+             dtype=dtype,
+             dtype_partial=dtype_partial,
+             head_dim=head_dim,
+             m_block_size=m_block_size,
+             k_block_size=k_block_size,
+             log_max_splits=log_max_splits,
+         )
+
+         # Check if implementation is supported
+         if not fa_combine.can_implement(
+             dtype,
+             dtype_partial,
+             head_dim,
+             m_block_size,
+             k_block_size,
+             log_max_splits,
+             num_threads=256,
+         ):
+             raise RuntimeError(
+                 "FlashAttention combine kernel cannot be implemented with given parameters"
+             )
+
+         _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile(
+             fa_combine,
+             out_partial_tensor,
+             lse_partial_tensor,
+             out_tensor,
+             lse_tensor,
+             cu_seqlens_tensor,
+             seqused_tensor,
+             num_splits_dynamic_tensor,
+             semaphore_tensor,
+             current_stream,
+             options="--enable-tvm-ffi",
+         )
+     _flash_attn_fwd_combine.compile_cache[compile_key](
+         out_partial,
+         lse_partial,
+         out,
+         lse,
+         cu_seqlens,
+         seqused,
+         num_splits_dynamic_ptr,
+         semaphore_to_reset,
+         current_stream,
+     )
+
+
+ _flash_attn_fwd_combine.compile_cache = {}
+
+
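To make the tile-size heuristic inside _flash_attn_fwd_combine easier to follow, here is a standalone restatement of the same arithmetic with two worked examples (illustrative only; the package computes these values inline, and the compiled kernel is memoized per compile_key of dtypes, head_dim, block sizes, and optional-tensor flags, as shown above).

import math

def combine_tile_params(head_dim: int, num_splits: int):
    # Same arithmetic as the inline heuristic above, restated for illustration.
    k_block_size = 64 if head_dim <= 64 else 128
    m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
    log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
    if m_block_size == 8:
        log_max_splits = max(log_max_splits, 5)  # kBlockM == 8 needs room for >= 32 splits
    return k_block_size, m_block_size, log_max_splits

print(combine_tile_params(64, 4))   # (64, 16, 4): hdim 64 -> kBlockK 64, kBlockM 16
print(combine_tile_params(128, 4))  # (128, 8, 5): kBlockM 8 forces log_max_splits up to 5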
+ def flash_attn_combine(
+     out_partial: torch.Tensor,
+     lse_partial: torch.Tensor,
+     out: Optional[torch.Tensor] = None,
+     out_dtype: Optional[torch.dtype] = None,
+     cu_seqlens: Optional[torch.Tensor] = None,
+     seqused: Optional[torch.Tensor] = None,
+     return_lse: bool = True,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+     """Flash Attention combine function for split attention computation.
+
+     Combines partial outputs and log-sum-exp values from multiple splits
+     of the attention computation into final outputs. This is the main user-facing
+     interface for the combine kernel.
+
+     Args:
+         out_partial: Partial outputs tensor with shape:
+             - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input
+             - (num_splits, total_q, num_heads, head_size) for variable-length input
+         lse_partial: Partial LSE tensor with shape:
+             - (num_splits, batch_size, seqlen, num_heads) for regular batched input
+             - (num_splits, total_q, num_heads) for variable-length input
+         out: Optional output tensor. If None, it is allocated automatically.
+         out_dtype: Optional output dtype. If None, the output keeps out_partial's dtype.
+         cu_seqlens: Cumulative sequence lengths for variable-length sequences
+         seqused: Used sequence length for each batch entry
+         return_lse: Whether to return the combined LSE tensor. Default is True.
+
+     Returns:
+         Tuple of (out, lse) where:
+             - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size),
+               or (total_q, num_heads, head_size) for varlen
+             - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads),
+               or (total_q, num_heads) for varlen. None if return_lse=False
+
+     Note:
+         This function expects the input tensors to be in the format produced by
+         split attention computation, where the first dimension is num_splits.
+         The permutation from the user-facing layout to the kernel layout is handled inside the kernel.
+     """
+     # Input validation
+     assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
+     assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
+     assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)"
+     assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
+
+     # Determine if this is variable length based on dimensions
+     is_varlen = out_partial.dim() == 4
+
+     if is_varlen:
+         # Variable length: (num_splits, total_q, num_heads, head_size)
+         num_splits, total_q, num_heads, head_size = out_partial.shape
+         assert lse_partial.shape == (num_splits, total_q, num_heads), (
+             "lse_partial shape mismatch for varlen"
+         )
+         batch_size = 1  # Treat as single batch for varlen
+         seqlen = total_q
+     else:
+         # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size)
+         num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape
+         assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), (
+             "lse_partial shape mismatch"
+         )
+
+     # Determine output dtype
+     if out_dtype is None:
+         out_dtype = out_partial.dtype
+
+     # Create output if not provided
+     device = out_partial.device
+     if out is None:
+         if is_varlen:
+             out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device)
+         else:
+             out = torch.empty(
+                 batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device
+             )
+
+     # Create lse output only if requested
+     if return_lse:
+         if is_varlen:
+             lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose(
+                 0, 1
+             )
+         else:
+             lse = torch.empty(
+                 batch_size, num_heads, seqlen, dtype=torch.float32, device=device
+             ).transpose(1, 2)
+     else:
+         lse = None
+
+     _flash_attn_fwd_combine(
+         out_partial,
+         lse_partial,
+         out,
+         lse,
+         cu_seqlens,
+         seqused,
+     )
+     return out, lse
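To illustrate what flash_attn_combine computes, here is a standalone sketch (illustrative only, assuming a supported CUDA device and that the function returns (out, lse) as documented above). The reference follows the standard split-attention combination: the final LSE is a logsumexp over splits, and each partial output is reweighted by exp(lse_i - lse) before summation.

import torch

num_splits, batch, seqlen, nheads, headdim = 4, 2, 256, 8, 128
out_partial = torch.randn(num_splits, batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float32)
lse_partial = torch.randn(num_splits, batch, seqlen, nheads, device="cuda", dtype=torch.float32)

out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=torch.float32)

# Pure-PyTorch reference for the same combination.
lse_ref = torch.logsumexp(lse_partial, dim=0)             # (batch, seqlen, nheads)
weights = torch.exp(lse_partial - lse_ref).unsqueeze(-1)  # (num_splits, batch, seqlen, nheads, 1)
out_ref = (weights * out_partial).sum(dim=0)              # (batch, seqlen, nheads, headdim)

print(torch.allclose(out, out_ref, atol=1e-3), torch.allclose(lse, lse_ref, atol=1e-3))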