mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,533 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ from enum import IntEnum
+ from typing import Any, Optional
+
+ import mslk.attention.cutlass_blackwell_fmha  # noqa: F401
+ import torch
+
+
+ class GenKernelType(IntEnum):
+     UMMA_I = 0
+     UMMA_P = 1
+
+
+ def get_splitk_heuristic(
+     batch: int,
+     seqlen_kv: int,
+     kv_heads: int = 1,
+     tile_n: int = 256,
+     sm_count: int | None = None,
+ ) -> int:
+     """
+     Compute optimal split-K size for Shape<64, 256, 128> tile configuration.
+
+     Targets full GPU utilization by distributing work across all SMs.
+     First calculates SMs per batch, then per kv_head, then divides seqlen_kv by that number.
+     Rounds the split size up to a multiple of tile_n so every CTA processes whole tiles.
+     Returns 0 (no split) when the split would equal seqlen_kv (only 1 split).
+
+     Args:
+         batch: Batch size
+         seqlen_kv: Maximum sequence length for K/V
+         kv_heads: Number of KV heads (default 1 for MQA)
+         tile_n: TileN dimension (default 256 for Shape<64, 256, 128>)
+         sm_count: Number of SMs on the GPU. If None, queries the current device.
+
+     Returns:
+         Optimal split size along the K/V sequence dimension, or 0 to disable split-K
+     """
+     # Get SM count from current device if not provided
+     if sm_count is None:
+         sm_count = torch.cuda.get_device_properties(
+             torch.cuda.current_device()
+         ).multi_processor_count
+
+     # Calculate number of SMs available per batch element
+     sms_per_batch = max(1, sm_count // batch)
+     # Further divide by kv_heads for multi-head KV
+     sms_per_head_batch = max(1, sms_per_batch // kv_heads)
+
+     # Each (batch, kv_head) element should have sms_per_head_batch splits
+     # So split size = seqlen_kv / sms_per_head_batch
+     ideal_split = seqlen_kv // sms_per_head_batch
+
+     # Round up to multiple of tile_n
+     split = ((ideal_split + tile_n - 1) // tile_n) * tile_n
+
+     # Clamp to valid range: [tile_n, seqlen_kv]
+     split = max(split, tile_n)
+     split = min(split, seqlen_kv)
+
+     # If split equals seqlen_kv, there's only 1 split - disable split-K
+     if split == seqlen_kv:
+         split = 0
+
+     return split
+
+
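With the function above in scope, a few worked values of the heuristic (sm_count=148 is passed explicitly here purely for illustration, so the results do not depend on the local device):

    assert get_splitk_heuristic(batch=2, seqlen_kv=8192, sm_count=148) == 256     # 74 SMs per batch -> many small splits
    assert get_splitk_heuristic(batch=64, seqlen_kv=4096, sm_count=148) == 2048   # 2 SMs per batch -> 2 splits each
    assert get_splitk_heuristic(batch=256, seqlen_kv=4096, sm_count=148) == 0     # split would equal seqlen_kv -> disabled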
+ def maybe_contiguous(x: torch.Tensor) -> torch.Tensor:
+     """
+     Only the head dim needs to be contiguous (and the second-to-last stride a
+     multiple of 8); otherwise copy the tensor to a contiguous layout.
+     """
+     return (
+         x.contiguous()
+         if x is not None and (x.stride(-1) != 1 or x.stride(-2) % 8 != 0)
+         else x
+     )
+
+
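A small behavioral sketch of the stride check above (assumes maybe_contiguous is in scope):

    import torch

    x = torch.empty(2, 4, 128, 64)       # stride(-1) == 1, stride(-2) == 64 (a multiple of 8)
    assert maybe_contiguous(x) is x      # returned untouched
    xt = x.transpose(-1, -2)             # stride(-1) == 64, so the head dim is no longer contiguous
    assert maybe_contiguous(xt).is_contiguous()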
+ def _cutlass_blackwell_fmha_forward(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     cu_seqlens_q: torch.Tensor | None = None,
+     cu_seqlens_k: torch.Tensor | None = None,
+     max_seq_len_q: int | None = None,
+     max_seq_len_k: int | None = None,
+     softmax_scale: float | None = None,
+     causal: bool = False,
+     seqlen_kv: torch.Tensor | None = None,
+     page_table: torch.Tensor | None = None,
+     seqlen_k: int | None = None,
+     window_left: int = -1,
+     window_right: int = -1,
+     bottom_right: bool = True,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     q = maybe_contiguous(q)
+     k = maybe_contiguous(k)
+     v = maybe_contiguous(v)
+     return torch.ops.mslk.fmha_fwd(
+         q,
+         k,
+         v,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seq_len_q=max_seq_len_q,
+         max_seq_len_k=max_seq_len_k,
+         softmax_scale=softmax_scale,
+         causal=causal,
+         seqlen_kv=seqlen_kv,
+         page_table=page_table,
+         seqlen_k=seqlen_k,
+         window_size_left=window_left,
+         window_size_right=window_right,
+         bottom_right=bottom_right,
+     )
+
+
+ def _cutlass_blackwell_fmha_backward(
+     dout: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     cu_seqlens_q: torch.Tensor | None = None,
+     cu_seqlens_k: torch.Tensor | None = None,
+     max_seq_len_q: int | None = None,
+     max_seq_len_k: int | None = None,
+     softmax_scale: float | None = None,
+     causal: bool = False,
+     window_left: int = -1,
+     window_right: int = -1,
+     bottom_right: bool = True,
+     deterministic: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     deterministic = deterministic or torch.are_deterministic_algorithms_enabled()
+     dout = maybe_contiguous(dout)
+     q = maybe_contiguous(q)
+     k = maybe_contiguous(k)
+     v = maybe_contiguous(v)
+     out = maybe_contiguous(out)
+     return torch.ops.mslk.fmha_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seq_len_q=max_seq_len_q,
+         max_seq_len_k=max_seq_len_k,
+         softmax_scale=softmax_scale,
+         causal=causal,
+         window_size_left=window_left,
+         window_size_right=window_right,
+         bottom_right=bottom_right,
+         deterministic=deterministic,
+     )
+
+
+ def _validate_and_adjust_split_k_size(split_k_size: int) -> int:
+     """
+     Validate and adjust the split_k_size parameter for optimal performance.
+
+     Args:
+         split_k_size: The requested split size along the K/V sequence dimension.
+
+     Returns:
+         Adjusted split_k_size that is valid for the kernel.
+
+     Valid values:
+     - split_k_size <= 0: Disable split-K (no splitting)
+     - split_k_size > 0: Enable split-K with the specified split size
+     """
+     if not isinstance(split_k_size, int):
+         raise TypeError(
+             f"split_k_size must be an integer, got {type(split_k_size).__name__}"
+         )
+
+     # If split-K is disabled, return as-is
+     if split_k_size <= 0:
+         return split_k_size
+
+     # Constants
+     MIN_RECOMMENDED_SPLIT_SIZE = 256
+     TILE_SIZE = 128
+
+     # Adjust if split_k_size is too small
+     if split_k_size < MIN_RECOMMENDED_SPLIT_SIZE:
+         split_k_size = MIN_RECOMMENDED_SPLIT_SIZE
+
+     # Check if split_k_size is a power of 2
+     is_power_of_2 = (split_k_size & (split_k_size - 1)) == 0
+
+     # If not a power of 2, round up to the next multiple of the tile size (128)
+     if not is_power_of_2:
+         split_k_size = ((split_k_size + TILE_SIZE - 1) // TILE_SIZE) * TILE_SIZE
+
+     return split_k_size
+
+
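A few input/output pairs for the adjustment above (pure integer arithmetic, reproducible directly from the code):

    assert _validate_and_adjust_split_k_size(0) == 0      # <= 0: split-K stays disabled
    assert _validate_and_adjust_split_k_size(100) == 256  # below the recommended minimum -> bumped to 256
    assert _validate_and_adjust_split_k_size(512) == 512  # power of two -> kept as-is
    assert _validate_and_adjust_split_k_size(300) == 384  # not a power of two -> rounded up to a multiple of 128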
+ def _validate_decode_inputs(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     seqlen_kv: torch.Tensor | None,
+ ) -> None:
+     assert seqlen_kv is not None, "seqlen_kv must be provided for decode"
+     tensors = {"q": q, "k": k, "v": v, "seqlen_kv": seqlen_kv}
+
+     for name, tensor in tensors.items():
+         # assert tensor.is_contiguous(), f"{name} is not contiguous"
+         assert tensor.is_cuda, f"{name} must be on GPU"
+
+
+ def _prepare_decode_inputs(
+     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool, tuple[int, ...]]:
+     """
+     Prepare inputs for the decode kernel by handling both varlen and batch formats.
+
+     Returns:
+         - Reshaped q, k, v tensors in batch format [B, 1, H, D]
+         - batch_size
+         - needs_reshape_output flag
+         - original_shape of q
+     """
+     original_shape = tuple(q.shape)
+     needs_reshape_output = False
+     batch_size = q.shape[0]
+
+     if q.dim() == 3:
+         # Varlen format: [total_queries, num_heads, head_dim]
+         q = q.view(batch_size, 1, q.shape[1], q.shape[2])
+         needs_reshape_output = True
+
+     if q.dim() != 4:
+         raise ValueError(
+             f"Invalid query shape: {q.shape}. Expected [B, 1, H, D] or [total_queries, H, D]"
+         )
+     assert q.shape[1] == 1, "Kernel requires sq=1"
+
+     k = k.view(batch_size, -1, k.shape[1], k.shape[2]) if k.dim() == 3 else k
+     v = v.view(batch_size, -1, v.shape[1], v.shape[2]) if v.dim() == 3 else v
+
+     return q, k, v, batch_size, needs_reshape_output, original_shape
+
+
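A shape-only illustration of the varlen-to-batch conversion (values are irrelevant; assumes the helper above is in scope):

    import torch

    q = torch.empty(8, 16, 128)        # varlen decode query: [total_queries == B, H, D]
    k = torch.empty(8, 4096, 2, 128)   # already batched:     [B, Sk, H_kv, D]
    v = torch.empty(8, 4096, 2, 128)
    q4, k4, v4, bsz, reshaped, orig = _prepare_decode_inputs(q, k, v)
    assert q4.shape == (8, 1, 16, 128) and reshaped and orig == (8, 16, 128)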
+ def cutlass_blackwell_fmha_decode_forward(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     seqlen_kv: torch.Tensor | None = None,
+     cu_seqlens_q: torch.Tensor | None = None,
+     cu_seqlens_k: torch.Tensor | None = None,
+     max_seq_len_q: int | None = None,
+     max_seq_len_k: int | None = None,
+     softmax_scale: float | None = None,
+     causal: bool = False,
+     window_left: int = -1,
+     window_right: int = -1,
+     bottom_right: bool = True,
+     split_k_size: int = 0,
+     use_heuristic: bool = True,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Decode-optimized forward pass using the gen kernel.
+     This is a wrapper around the gen kernel, which is optimized
+     for decode (query length = 1).
+
+     This function is called externally by xformers ops.
+
+     Accepts inputs in two formats:
+     - Varlen format: [total_queries, num_heads, head_dim] (3D)
+     - Batch format: [batch_size, 1, num_heads, head_dim] (4D)
+
+     Args:
+         q: Query tensor in varlen [total_queries, H, D] or batch [B, 1, H, D] format
+         k: Key tensor [B, Sk, H_kv, D]
+         v: Value tensor [B, Sk, H_kv, D]
+         seqlen_kv: Per-batch sequence lengths [B] (required)
+         split_k_size: Size of each split along the K/V sequence dimension.
+             - split_k_size <= 0 with use_heuristic=True: auto-compute using the heuristic
+             - split_k_size <= 0 with use_heuristic=False: disable split-K
+             - split_k_size > 0: use the provided split size directly
+             Values below 256 are adjusted to 256. Non-power-of-2 values
+             are rounded up to the next multiple of 128.
+         use_heuristic: If True and split_k_size <= 0, automatically compute the optimal
+             split size using the heuristic. Default is True.
+
+     Returns:
+         Kernel output with a Q dimension added:
+         - out: [B, 1, H, num_splits, D] (num_splits=1 when split-K is disabled)
+         - lse: [B, num_splits, H, 1]
+     """
+     _validate_decode_inputs(q, k, v, seqlen_kv)
+
+     # Prepare inputs and handle format conversion
+     q, k, v, batch_size, _, original_shape = _prepare_decode_inputs(q, k, v)
+
+     # Determine effective split_k_size
+     if split_k_size <= 0 and use_heuristic:
+         # Auto-compute using heuristic
+         max_seqlen_kv = k.shape[1]
+         kv_heads = k.shape[2]  # K shape is [B, Sk, H_kv, D]
+         split_k_size = get_splitk_heuristic(batch_size, max_seqlen_kv, kv_heads)
+
+     # Validate and adjust split_k_size
+     split_k_size = _validate_and_adjust_split_k_size(split_k_size)
+
+     # Validate window_right: decode kernel only supports causal attention (window_right <= 0)
+     if window_right > 0:
+         raise ValueError(
+             f"window_right={window_right} is not supported for decode attention. "
+             "The decode kernel only supports causal attention with window_right <= 0. "
+             "Use window_right=0 (causal, current position only)."
+         )
+
+     # Call the gen kernel (optimized for decode)
+     # Note: window_left specifies how many tokens to look back (exclusive)
+     # The kernel will attend to positions [seqlen_kv - window_left, seqlen_kv)
+     out, lse = torch.ops.mslk.fmha_gen_fwd(
+         q,
+         k,
+         v,
+         seqlen_kv,
+         None,
+         kernel_type=GenKernelType.UMMA_I,
+         window_left=window_left,
+         window_right=0,
+         split_k_size=split_k_size,
+     )
+
+     # Kernel returns: out [B, H, num_splits, D], lse [B, num_splits, H]
+     # Reshape to a consistent format with the Q dimension:
+     # out: [B, H, num_splits, D] -> [B, 1, H, num_splits, D]
+     # lse: [B, num_splits, H] -> [B, num_splits, H, 1]
+     out = out.unsqueeze(1)  # [B, 1, H, num_splits, D]
+     lse = lse.unsqueeze(-1)  # [B, num_splits, H, 1]
+     return out, lse
+
+
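Putting the pieces together, a minimal decode call might look like the sketch below. It assumes the mslk extension is loaded on a supported (Blackwell) GPU; the bfloat16/int32 dtypes are illustrative assumptions rather than requirements stated in this file.

    import torch

    B, H, H_kv, Sk, D = 4, 8, 1, 8192, 128
    q = torch.randn(B, 1, H, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, Sk, H_kv, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, Sk, H_kv, D, device="cuda", dtype=torch.bfloat16)
    seqlen_kv = torch.full((B,), Sk, device="cuda", dtype=torch.int32)

    out, lse = cutlass_blackwell_fmha_decode_forward(q, k, v, seqlen_kv=seqlen_kv)
    # Per the docstring: out is [B, 1, H, num_splits, D], lse is [B, num_splits, H, 1]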
+ class CutlassBlackwellFmhaFunc(torch.autograd.Function):
+     @staticmethod
+     def forward(  # type: ignore
+         ctx,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         softmax_scale: float | None = None,
+         causal: bool = False,
+         cu_seqlens_q: Optional[torch.Tensor] = None,
+         cu_seqlens_k: Optional[torch.Tensor] = None,
+         max_seq_len_q: Optional[int] = None,
+         max_seq_len_k: Optional[int] = None,
+         seqlen_kv: Optional[torch.Tensor] = None,
+         page_table: Optional[torch.Tensor] = None,
+         seqlen_k: Optional[int] = None,
+         window_size: tuple[int, int] = (-1, -1),
+         bottom_right: bool = True,
+         deterministic: bool = False,
+     ) -> torch.Tensor:
+         window_left, window_right = window_size
+         # Check if this is generation phase (sq = 1)
+         sq = q.shape[1]
+         if q.dim() == 4 and sq == 1:
+             # For gen case, we don't need to save tensors for backward
+             ctx.is_gen = True
+             out, _ = cutlass_blackwell_fmha_decode_forward(
+                 q,
+                 k,
+                 v,
+                 seqlen_kv,
+                 cu_seqlens_q,
+                 cu_seqlens_k,
+                 max_seq_len_q,
+                 max_seq_len_k,
+                 softmax_scale,
+                 causal,
+                 window_left,
+                 window_right,
+                 bottom_right,
+             )
+             return out
+
+         ctx.is_gen = False
+         # Only check dtype if cu_seqlens_q and cu_seqlens_k are provided
+         if cu_seqlens_q is not None and cu_seqlens_k is not None:
+             assert (
+                 cu_seqlens_q.dtype == torch.int32
+                 and cu_seqlens_q.dtype == cu_seqlens_k.dtype
+             ), "cu_seqlens_q and cu_seqlens_k must be int32"
+
+         # handle window_size
+         if causal and window_left >= 0:
+             window_right = 0
+         # Use regular FMHA for non-generation case
+         out, softmax_lse = _cutlass_blackwell_fmha_forward(
+             q,
+             k,
+             v,
+             cu_seqlens_q,
+             cu_seqlens_k,
+             max_seq_len_q,
+             max_seq_len_k,
+             softmax_scale,
+             causal,
+             seqlen_kv,
+             page_table,
+             seqlen_k,
+             window_left,
+             window_right,
+             bottom_right,
+         )
+         ctx.save_for_backward(q, k, v, out, softmax_lse)
+         ctx.softmax_scale = softmax_scale
+         ctx.causal = causal
+         ctx.window_size = window_size
+         ctx.max_seq_len_q = max_seq_len_q
+         ctx.max_seq_len_k = max_seq_len_k
+         ctx.cu_seqlens_q = cu_seqlens_q
+         ctx.cu_seqlens_k = cu_seqlens_k
+         ctx.bottom_right = bottom_right
+         ctx.deterministic = deterministic
+         return out
+
+     @staticmethod
+     def backward(
+         ctx, dout: torch.Tensor, *args: Any
+     ) -> tuple[  # type: ignore
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+     ]:
+         if ctx.is_gen:
+             # For gen case, no backward pass is needed (generation is inference only)
+             raise RuntimeError(
+                 "Backward pass is not supported for generation phase (sq=1)"
+             )
+
+         q, k, v, out, softmax_lse = ctx.saved_tensors
+         window_left, window_right = ctx.window_size
+         dq, dk, dv = _cutlass_blackwell_fmha_backward(
+             dout,
+             q,
+             k,
+             v,
+             out,
+             softmax_lse,
+             ctx.cu_seqlens_q,
+             ctx.cu_seqlens_k,
+             ctx.max_seq_len_q,
+             ctx.max_seq_len_k,
+             ctx.softmax_scale,
+             ctx.causal,
+             window_left,
+             window_right,
+             bottom_right=ctx.bottom_right,
+             deterministic=ctx.deterministic,
+         )
+         return (
+             dq,
+             dk,
+             dv,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+         )
+
+
+ def cutlass_blackwell_fmha_func(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     softmax_scale: float | None = None,
+     causal: bool = False,
+     cu_seqlens_q: torch.Tensor | None = None,
+     cu_seqlens_k: torch.Tensor | None = None,
+     max_seq_len_q: int | None = None,
+     max_seq_len_k: int | None = None,
+     seqlen_kv: torch.Tensor | None = None,
+     page_table: torch.Tensor | None = None,
+     seqlen_k: int | None = None,
+     window_size: tuple[int, int] | None = (-1, -1),
+     bottom_right: bool = True,
+     deterministic: bool = False,
+ ):
+     return CutlassBlackwellFmhaFunc.apply(
+         q,
+         k,
+         v,
+         softmax_scale,
+         causal,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         max_seq_len_q,
+         max_seq_len_k,
+         seqlen_kv,
+         page_table,
+         seqlen_k,
+         window_size,
+         bottom_right,
+         deterministic,
+     )
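cutlass_blackwell_fmha_func is the user-facing entry point wrapping the autograd.Function above. A hedged end-to-end sketch (the import path follows the file list at the top of this diff; a Blackwell GPU and bfloat16 inputs are assumptions for illustration):

    import torch
    from mslk.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import (
        cutlass_blackwell_fmha_func,
    )

    B, S, H, D = 2, 1024, 16, 128
    q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    out = cutlass_blackwell_fmha_func(q, k, v, causal=True)  # forward via torch.ops.mslk.fmha_fwd
    out.sum().backward()                                     # gradients via torch.ops.mslk.fmha_bwd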
@@ -0,0 +1,22 @@
+ # @nolint # fbcode
+ """Flash Attention CUTE (CUDA Template Engine) implementation."""
+
+ __version__ = "0.1.0"
+
+ import cutlass.cute as cute
+
+ from .interface import (
+     flash_attn_func,
+     flash_attn_varlen_func,
+ )
+
+ from mslk.attention.flash_attn.cute_dsl_utils import cute_compile_patched
+
+ # Patch cute.compile to optionally dump SASS
+ cute.compile = cute_compile_patched
+
+
+ __all__ = [
+     "flash_attn_func",
+     "flash_attn_varlen_func",
+ ]
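Because the patch is applied at import time, importing the subpackage is enough to swap cute.compile for cute_compile_patched; the two public entry points are then used as plain functions:

    from mslk.attention.flash_attn import flash_attn_func, flash_attn_varlen_func  # applies the cute.compile patch as a side effect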
@@ -0,0 +1,104 @@
+ # @nolint # fbcode
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Type, Callable, Optional
+
+ import cutlass
+ import cutlass.cute as cute
+
+
+ def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
+     dtype_byte = cutlass.const_expr(dtype.width // 8)
+     bytes_per_row = cutlass.const_expr(k_dim * dtype_byte)
+     smem_k_block_size = (
+         cutlass.const_expr(
+             128
+             if bytes_per_row % 128 == 0
+             else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))
+         )
+         // dtype_byte
+     )
+     swizzle_bits = (
+         4
+         if smem_k_block_size == 128
+         else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1))
+     )
+     swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4)
+     return cute.make_composed_layout(
+         cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base),
+         0,
+         cute.make_ordered_layout(
+             (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)
+         ),
+     )
+
+
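For concreteness: with bf16 (2-byte) elements and k_dim=128, bytes_per_row is 256, so the 128-byte branch is taken, smem_k_block_size = 128 // 2 = 64 elements, and swizzle_bits = 3. A plain-Python mirror of that selection (illustration only; the real function returns cute layout objects):

    def smem_atom_params(dtype_bytes: int, k_dim: int) -> tuple[int, int]:
        # Same block-size / swizzle-bit selection as get_smem_layout_atom above.
        bytes_per_row = k_dim * dtype_bytes
        block_bytes = (
            128 if bytes_per_row % 128 == 0
            else 64 if bytes_per_row % 64 == 0
            else 32 if bytes_per_row % 32 == 0
            else 16
        )
        smem_k_block_size = block_bytes // dtype_bytes
        swizzle_bits = {128: 4, 64: 3, 32: 2}.get(smem_k_block_size, 1)
        return smem_k_block_size, swizzle_bits

    assert smem_atom_params(2, 128) == (64, 3)    # bf16/fp16, head dim 128
    assert smem_atom_params(1, 128) == (128, 4)   # fp8, head dim 128
    assert smem_atom_params(4, 32) == (32, 2)     # fp32, k_dim 32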
+ @cute.jit
+ def gemm(
+     tiled_mma: cute.TiledMma,
+     acc: cute.Tensor,
+     tCrA: cute.Tensor,
+     tCrB: cute.Tensor,
+     tCsA: cute.Tensor,
+     tCsB: cute.Tensor,
+     smem_thr_copy_A: cute.TiledCopy,
+     smem_thr_copy_B: cute.TiledCopy,
+     hook_fn: Optional[Callable] = None,
+     A_in_regs: cutlass.Constexpr[bool] = False,
+     B_in_regs: cutlass.Constexpr[bool] = False,
+     swap_AB: cutlass.Constexpr[bool] = False,
+ ) -> None:
+     if cutlass.const_expr(swap_AB):
+         gemm(
+             tiled_mma,
+             acc,
+             tCrB,
+             tCrA,
+             tCsB,
+             tCsA,
+             smem_thr_copy_B,
+             smem_thr_copy_A,
+             hook_fn,
+             A_in_regs=B_in_regs,
+             B_in_regs=A_in_regs,
+             swap_AB=False,
+         )
+     else:
+         tCrA_copy_view = smem_thr_copy_A.retile(tCrA)
+         tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
+         if cutlass.const_expr(not A_in_regs):
+             cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0])
+         if cutlass.const_expr(not B_in_regs):
+             cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
+         for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])):
+             if k < cute.size(tCsA.shape[2]) - 1:
+                 if cutlass.const_expr(not A_in_regs):
+                     cute.copy(
+                         smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]
+                     )
+                 if cutlass.const_expr(not B_in_regs):
+                     cute.copy(
+                         smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]
+                     )
+             cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
+             if cutlass.const_expr(k == 0 and hook_fn is not None):
+                 hook_fn()
+
+
+ @cute.jit
+ def gemm_rs(
+     tiled_mma: cute.TiledMma,
+     acc: cute.Tensor,
+     tCrA: cute.Tensor,
+     tCrB: cute.Tensor,
+     tCsB: cute.Tensor,
+     smem_thr_copy_B: cute.TiledCopy,
+     hook_fn: Optional[Callable] = None,
+ ) -> None:
+     tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
+     cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
+     for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
+         if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1):
+             cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1])
+         cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
+         if cutlass.const_expr(k == 0 and hook_fn is not None):
+             hook_fn()
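Both loops above follow the same software-pipelining pattern: the shared-memory-to-register copy for slice k+1 is issued before the MMA on slice k, so the copy latency is hidden behind the math. A schematic NumPy analogue of that k-loop (shape-level only; it does not model the real smem/register dataflow or the hook_fn call point):

    import numpy as np

    def gemm_k_loop_reference(sA: np.ndarray, sB: np.ndarray, acc: np.ndarray) -> np.ndarray:
        # sA: (M, Kin, num_k), sB: (Kin, N, num_k), acc: (M, N)
        num_k = sA.shape[-1]
        rA = [None] * num_k
        rB = [None] * num_k
        rA[0], rB[0] = sA[..., 0], sB[..., 0]        # prologue: stage slice 0
        for k in range(num_k):
            if k < num_k - 1:                        # prefetch slice k + 1 ...
                rA[k + 1], rB[k + 1] = sA[..., k + 1], sB[..., k + 1]
            acc += rA[k] @ rB[k]                     # ... while accumulating slice k
        return acc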