mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/__init__.py ADDED
@@ -0,0 +1,56 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+ import os
+
+ import torch
+
+ open_source: bool = True
+
+
+ def _load_library(filename: str, version: str, no_throw: bool = False) -> None:
+     """Load a shared library from the given filename."""
+     try:
+         library_path = os.path.join(os.path.dirname(__file__), filename)
+         torch.ops.load_library(library_path)
+         torch.classes.load_library(library_path)
+         logging.info(f"Successfully loaded: '{filename}'")
+
+     except Exception as error:
+         logging.error(f"Could not load the library '{filename}'!\n\n\n{error}\n\n\n")
+         if not no_throw:
+             raise error
+
+
+ try:
+     # Export the version string from the version file auto-generated by setup.py
+     from .version import __target__, __variant__, __version__  # noqa: F401, E402
+ except Exception:
+     __variant__: str = "INTERNAL"
+     __version__: str = "INTERNAL"
+     __target__: str = "default"
+
+ _default_libraries = [
+     "mslk",
+ ]
+
+ libraries_to_load = {
+     "default": _default_libraries,
+ }
+
+ for library in libraries_to_load.get(__target__, []):
+     # NOTE: In all cases, we want to throw an error if we cannot load the
+     # library. However, this appears to break the OSS documentation build,
+     # where the Python documentation doesn't show up in the generated docs.
+     #
+     # To work around this problem, we introduce a fake build variant called
+     # `docs` and we only throw a library load error when the variant is not
+     # `docs`. For more information, see:
+     #
+     #   https://github.com/pytorch/FBGEMM/pull/3477
+     #   https://github.com/pytorch/FBGEMM/pull/3717
+     _load_library(f"{library}.so", __version__, __variant__ == "docs")
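The package `__init__` above makes importing `mslk` the point at which the native `mslk.so` extension is loaded via `torch.ops.load_library`. A rough sketch of what that implies for users (purely illustrative; the operators actually registered by `mslk.so` are defined in C++ sources not visible in this diff):

import torch
import mslk  # import triggers _load_library("mslk.so", __version__, ...)

print(mslk.__version__)  # populated from mslk/version.py, auto-generated by setup.py
# Native operators registered by the extension are exposed under the
# torch.ops.mslk namespace; the specific operator names are not shown here.
ops_namespace = torch.ops.mslk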
mslk/attention/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
mslk/attention/cutlass_blackwell_fmha/__init__.py ADDED
@@ -0,0 +1,30 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ from mslk.utils.torch.library import load_library_buck
+
+ from . import cutlass_blackwell_fmha_custom_op  # noqa: F401
+ from .cutlass_blackwell_fmha_interface import (  # noqa: F401
+     _cutlass_blackwell_fmha_forward,
+     cutlass_blackwell_fmha_decode_forward,
+     cutlass_blackwell_fmha_func,
+ )
+
+ load_library_buck(
+     "//mslk/csrc/attention/cuda/cutlass_blackwell_fmha:blackwell_attention_ops_gpu"
+ )
+
+ # Note: _cutlass_blackwell_fmha_forward is an internal function (indicated by leading underscore)
+ # that is exported here specifically for testing purposes. It allows tests to access the LSE
+ # (log-sum-exp) values returned by the forward pass without modifying the public API.
+ # Production code should use cutlass_blackwell_fmha_func instead.
+ __all__ = [
+     "_cutlass_blackwell_fmha_forward",
+     "cutlass_blackwell_fmha_decode_forward",
+     "cutlass_blackwell_fmha_func",
+ ]
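For downstream code, the import surface defined by this `__init__` reduces to the sketch below (module path taken from the file listing above; the Blackwell extension must load successfully via `load_library_buck` for the import to succeed):

# Public entry points re-exported by mslk/attention/cutlass_blackwell_fmha/__init__.py:
from mslk.attention.cutlass_blackwell_fmha import (
    cutlass_blackwell_fmha_decode_forward,
    cutlass_blackwell_fmha_func,
)

# _cutlass_blackwell_fmha_forward is also importable, but it is re-exported only
# so tests can inspect the log-sum-exp (LSE) values returned by the forward pass;
# production code should prefer cutlass_blackwell_fmha_func.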
mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py ADDED
@@ -0,0 +1,332 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Optional
+
+ import torch
+ from torch.library import register_fake
+
+
+ torch.library.define(
+     "mslk::cutlass_blackwell_fmha_fwd",
+     "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv, Tensor? page_table, int seqlen_k=-1, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True) -> (Tensor, Tensor)",
+     tags=torch.Tag.pt2_compliant_tag,
+ )
+
+ torch.library.define(
+     "mslk::cutlass_blackwell_fmha_bwd",
+     "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True, bool deterministic=False) -> (Tensor, Tensor, Tensor)",
+     tags=torch.Tag.pt2_compliant_tag,
+ )
+
+
+ @torch.library.impl("mslk::cutlass_blackwell_fmha_fwd", "cuda")
+ def custom_op_fmha(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     cu_seqlens_q: Optional[torch.Tensor] = None,
+     cu_seqlens_k: Optional[torch.Tensor] = None,
+     max_seq_len_q: Optional[int] = None,
+     max_seq_len_k: Optional[int] = None,
+     softmax_scale: Optional[float] = None,
+     causal: bool = False,
+     seqlen_kv: Optional[torch.Tensor] = None,
+     page_table: Optional[torch.Tensor] = None,
+     seqlen_k: Optional[int] = None,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     bottom_right: bool = True,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     assert q.is_contiguous(), "q is not contiguous"
+     assert k.is_contiguous(), "k is not contiguous"
+     assert v.is_contiguous(), "v is not contiguous"
+     assert q.is_cuda, "q must be on GPU"
+     assert k.is_cuda, "k must be on GPU"
+     assert v.is_cuda, "v must be on GPU"
+
+     return torch.ops.mslk.fmha_fwd(
+         q,
+         k,
+         v,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seq_len_q=max_seq_len_q,
+         max_seq_len_k=max_seq_len_k,
+         softmax_scale=softmax_scale,
+         causal=causal,
+         seqlen_kv=seqlen_kv,
+         page_table=page_table,
+         seqlen_k=seqlen_k,
+         window_size_left=window_size_left,
+         window_size_right=window_size_right,
+         bottom_right=bottom_right,
+     )
+
+
+ @register_fake("mslk::cutlass_blackwell_fmha_fwd")
+ def fmha_fwd_meta(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     cu_seqlens_q: Optional[torch.Tensor] = None,
+     cu_seqlens_k: Optional[torch.Tensor] = None,
+     max_seq_len_q: Optional[int] = None,
+     max_seq_len_k: Optional[int] = None,
+     softmax_scale: Optional[float] = None,
+     causal: bool = False,
+     seqlen_kv: Optional[torch.Tensor] = None,
+     page_table: Optional[torch.Tensor] = None,
+     seqlen_k: Optional[int] = None,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     bottom_right: bool = True,
+ ):
+     if q.dtype == torch.float16:
+         out_dtype = torch.float16
+     elif q.dtype == torch.bfloat16:
+         out_dtype = torch.bfloat16
+     elif q.dtype == torch.float8_e4m3fn:
+         # Output is BF16 when input is FP8
+         out_dtype = torch.bfloat16
+     else:
+         raise RuntimeError(f"Unsupported dtype for q: {q.dtype}")
+
+     kIsVarlen = max_seq_len_q is not None
+     if kIsVarlen:
+         assert cu_seqlens_q is not None
+         SQ = q.shape[0]
+         H_Q = q.shape[1]
+         B = cu_seqlens_q.shape[0] - 1
+     else:
+         SQ = q.shape[1]
+         H_Q = q.shape[2]
+         B = q.shape[0]
+     device = q.device
+     options2 = {"dtype": torch.float32, "device": device}
+     if kIsVarlen:
+         assert max_seq_len_q is not None
+         out = torch.empty_like(q, dtype=out_dtype)
+         size = out.size()
+         stride = out.stride()
+         storage_offset = q.shape[-1] * max_seq_len_q * H_Q  # example scalar offset
+         out1 = torch.as_strided(
+             out, size=size, stride=stride, storage_offset=storage_offset
+         )
+     else:
+         out1 = torch.empty_like(q, dtype=out_dtype)
+
+     if kIsVarlen:
+         out2 = torch.empty((1, H_Q, SQ), **options2)  # type: ignore
+     else:
+         out2 = torch.empty((B, H_Q, SQ), **options2)  # type: ignore
+     return out1, out2
+
+
+ @torch.library.impl("mslk::cutlass_blackwell_fmha_bwd", "cuda")
+ def custom_op_fmha_bwd(
+     dOutput: torch.Tensor,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     output: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     cu_seqlens_q: Optional[torch.Tensor] = None,
+     cu_seqlens_k: Optional[torch.Tensor] = None,
+     max_seq_len_q: Optional[int] = None,
+     max_seq_len_k: Optional[int] = None,
+     softmax_scale: Optional[float] = None,
+     causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     bottom_right: bool = True,
+     deterministic: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     return torch.ops.mslk.fmha_bwd(
+         dOutput,
+         query,
+         key,
+         value,
+         output,
+         softmax_lse,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seq_len_q=max_seq_len_q,
+         max_seq_len_k=max_seq_len_k,
+         softmax_scale=softmax_scale,
+         causal=causal,
+         window_size_left=window_size_left,
+         window_size_right=window_size_right,
+         bottom_right=bottom_right,
+         deterministic=deterministic,
+     )
+
+
+ @register_fake("mslk::cutlass_blackwell_fmha_bwd")
+ def fmha_bwd_meta(
+     dOutput: torch.Tensor,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     output: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     cu_seqlens_q: Optional[torch.Tensor] = None,
+     cu_seqlens_k: Optional[torch.Tensor] = None,
+     max_seq_len_q: Optional[int] = None,
+     max_seq_len_k: Optional[int] = None,
+     softmax_scale: Optional[float] = None,
+     causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     bottom_right: bool = True,
+     deterministic: bool = False,
+ ):
+     return (
+         torch.empty_like(query),
+         torch.empty_like(key),
+         torch.empty_like(value),
+     )
+
+
+ def _backward(ctx, *grad):
+     if ctx.is_gen:
+         # For gen case, no backward pass is needed (generation is inference only)
+         raise RuntimeError("Backward pass is not supported for generation phase (sq=1)")
+     q, k, v, out, softmax_lse = ctx.saved_tensors
+     if not grad[0].is_contiguous():
+         grad0 = grad[0].contiguous()
+     else:
+         grad0 = grad[0]
+     if not softmax_lse.is_contiguous():
+         softmax_lse = softmax_lse.contiguous()
+     if not out.is_contiguous():
+         out = out.contiguous()
+     if not q.is_contiguous():
+         q = q.contiguous()
+     if not k.is_contiguous():
+         k = k.contiguous()
+
+     dq, dk, dv = torch.ops.mslk.cutlass_blackwell_fmha_bwd(
+         grad0,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         ctx.cu_seqlens_q,
+         ctx.cu_seqlens_k,
+         ctx.max_seq_len_q,
+         ctx.max_seq_len_k,
+         ctx.softmax_scale,
+         ctx.causal,
+         ctx.window_size_left,
+         ctx.window_size_right,
+         ctx.bottom_right,
+         ctx.deterministic,
+     )
+     return (
+         dq,
+         dk,
+         dv,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+     )
+
+
+ def _setup_context(ctx, inputs, output):
+     (
+         q,
+         k,
+         v,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         max_seq_len_q,
+         max_seq_len_k,
+         softmax_scale,
+         causal,
+         seqlen_kv,
+         page_table,
+         seqlen_k,
+         window_size_left,
+         window_size_right,
+         bottom_right,
+     ) = inputs
+     (out, softmax_lse) = output
+     ctx.save_for_backward(q, k, v, out, softmax_lse)
+     ctx.softmax_scale = softmax_scale
+     ctx.causal = causal
+     ctx.max_seq_len_q = max_seq_len_q
+     ctx.max_seq_len_k = max_seq_len_k
+     ctx.cu_seqlens_q = cu_seqlens_q
+     ctx.cu_seqlens_k = cu_seqlens_k
+     ctx.window_size_left = window_size_left
+     ctx.window_size_right = window_size_right
+     ctx.bottom_right = bottom_right
+     ctx.deterministic = False  # Set default value
+     ctx.is_gen = False
+
+
+ # This code adds training support for the operator. You must provide us
+ # the backward formula for the operator and a `setup_context` function
+ # to save values to be used in the backward.
+ torch.library.register_autograd(
+     "mslk::cutlass_blackwell_fmha_fwd", _backward, setup_context=_setup_context
+ )
+
+
+ def cutlass_blackwell_fmha_custom_op(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     softmax_scale: float | None = None,
+     causal: bool = False,
+     cu_seqlens_q: torch.Tensor | None = None,
+     cu_seqlens_k: torch.Tensor | None = None,
+     max_seq_len_q: int | None = None,
+     max_seq_len_k: int | None = None,
+     seqlen_kv: torch.Tensor | None = None,
+     page_table: torch.Tensor | None = None,
+     seqlen_k: int | None = -1,
+     window_size_left: int | None = -1,
+     window_size_right: int | None = -1,
+     bottom_right: bool | None = True,
+ ):
+     return torch.ops.mslk.cutlass_blackwell_fmha_fwd(
+         q=q,
+         k=k,
+         v=v,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seq_len_q=max_seq_len_q,
+         max_seq_len_k=max_seq_len_k,
+         softmax_scale=softmax_scale,
+         causal=causal,
+         seqlen_kv=seqlen_kv,
+         page_table=page_table,
+         seqlen_k=seqlen_k,
+         window_size_left=window_size_left,
+         window_size_right=window_size_right,
+         bottom_right=bottom_right,
+     )[0]
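Taken together, the schema definitions, CUDA implementations, fake (meta) kernels, and `register_autograd` call above make `mslk::cutlass_blackwell_fmha_fwd` usable under eager execution, autograd, and `torch.compile`. The sketch below assumes the dense (non-varlen) layout implied by the meta kernel, i.e. q/k/v of shape (B, S, H, D) and an LSE of shape (B, H, S), plus a CUDA build running on hardware the Blackwell kernels support; the tensor sizes are illustrative only:

import torch
import mslk.attention.cutlass_blackwell_fmha  # registers the mslk::* custom ops
from mslk.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_custom_op import (
    cutlass_blackwell_fmha_custom_op,
)

B, S, H, D = 2, 1024, 8, 128  # illustrative sizes
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Forward: the wrapper returns only the attention output; the LSE (shape (B, H, S))
# is dropped by the trailing [0].
out = cutlass_blackwell_fmha_custom_op(q, k, v, causal=True)

# Backward: routed through _backward via torch.library.register_autograd.
out.sum().backward()

# The op carries pt2_compliant_tag and fake kernels, so it can also be traced:
compiled = torch.compile(cutlass_blackwell_fmha_custom_op)
out2 = compiled(q, k, v, causal=True)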