mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/dispatch.py
@@ -0,0 +1,224 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+
+
+import textwrap
+from collections import deque
+from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar
+
+import torch
+
+from . import attn_bias, ck, cutlass, flash, flash3, flash_mtia, triton_splitk
+from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
+
+T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase])
+
+
+try:
+    import mtia.host_runtime.torch_mtia.dynamic_library  # noqa  # type: ignore
+
+    # Use MTIA flash attention if the MTIA libraries are available
+    _USE_MTIA_FLASH_ATTENTION = True
+except (ImportError, OSError):
+    # Failed to load MTIA libraries, so don't use MTIA flash attention
+    _USE_MTIA_FLASH_ATTENTION = False
+
+
+_USE_FLASH_ATTENTION_3 = False
+
+
+def _set_use_fa3(use_flash_attention3: bool) -> None:
+    global _USE_FLASH_ATTENTION_3
+    _USE_FLASH_ATTENTION_3 = use_flash_attention3
+
+
+def _get_use_fa3() -> bool:
+    return _USE_FLASH_ATTENTION_3
+
+
+def fa3_available() -> bool:
+    has_valid_flash3 = flash3._C_flashattention3 is not None  # pyre-ignore[16]
+    is_90a = torch.version.cuda and torch.cuda.get_device_capability() >= (9, 0)
+    return has_valid_flash3 and is_90a
+
+
+def _format_inputs_description(inp: Inputs) -> str:
+    return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype})
+key : shape={tuple(inp.key.shape)} ({inp.key.dtype})
+value : shape={tuple(inp.value.shape)} ({inp.value.dtype})
+attn_bias : {type(inp.attn_bias)}
+p : {inp.p}"""
+
+
+def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None:
+    reasons = op.not_supported_reasons(inp)
+    if not reasons:
+        return
+    raise exc_type(
+        f"""Operator `{name}` does not support inputs:
+{textwrap.indent(_format_inputs_description(inp), " ")}
+{_format_not_supported_reasons(op, reasons)}"""
+    )
+
+
+def _format_not_supported_reasons(op, reasons: List[str]) -> str:
+    return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons)
+
+
+def _run_priority_list(
+    name: str,
+    priority_list: Sequence[T],
+    inp: Inputs,
+    extra_op_reasons: Optional[List[Tuple[Any, List[str]]]] = None,
+) -> T:
+    not_supported_reasons: List[List[str]] = []
+    for op in priority_list:
+        not_supported = op.not_supported_reasons(inp)
+        if not not_supported:
+            return op
+        not_supported_reasons.append(not_supported)
+
+    # Let's write a nice message explaining what we tried and why it's not supported
+    msg = f"""No operator found for `{name}` with inputs:
+{textwrap.indent(_format_inputs_description(inp), " ")}"""
+    for op, not_supported in zip(priority_list, not_supported_reasons):
+        msg += "\n" + _format_not_supported_reasons(op, not_supported)
+    if extra_op_reasons is not None:
+        for op, not_supported in extra_op_reasons:
+            msg += "\n" + _format_not_supported_reasons(op, not_supported)
+    raise NotImplementedError(msg)
+
+
+def _dispatch_fw_priority_list(
+    inp: Inputs, needs_gradient: bool
+) -> Sequence[Type[AttentionFwOpBase]]:
+    if torch.version.cuda:
+        flash3_op = [flash3.FwOp] if _get_use_fa3() else []
+        priority_list_ops = deque(
+            flash3_op
+            + [
+                flash.FwOp,
+                cutlass.FwOp,
+            ]
+        )
+    else:
+        priority_list_ops = deque(
+            [
+                ck.FwOp,
+            ]
+        )
+    priority_list_ops.append(triton_splitk.FwOp)
+    if not needs_gradient:
+        mqa_or_gqa = (
+            inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1
+        )
+        # Split-KV is useful with MQA
+        # for short Q-seqlen / long K-seqlen
+        if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256:
+            parallelism_BH = 0  # BMK
+            if inp.query.ndim == 3:
+                parallelism_BH = inp.query.shape[0]
+            elif inp.query.ndim == 4:  # BMHK
+                parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
+            elif inp.query.ndim == 5:  # BMGHK
+                parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
+            if (
+                parallelism_BH > 0
+                and parallelism_BH < 64
+                and not torch.mtia.is_available()
+            ):
+                # priority_list_ops.appendleft(ck_splitk.FwOp)
+                priority_list_ops.remove(triton_splitk.FwOp)
+                priority_list_ops.appendleft(triton_splitk.FwOp)
+    # Without variable seqlen flash is fastest
+    if torch.version.cuda and not isinstance(
+        inp.attn_bias, attn_bias.BlockDiagonalMask
+    ):
+        if _get_use_fa3():
+            priority_list_ops.remove(flash3.FwOp)
+        priority_list_ops.remove(flash.FwOp)
+        priority_list_ops.appendleft(flash.FwOp)
+
+    # torch.mtia.is_available() cannot be called here because it isn't supported
+    # when tracing with PT2, so we simply add flash_mtia to the end if the MTIA
+    # dynamic library can be loaded
+    if _USE_MTIA_FLASH_ATTENTION:
+        priority_list_ops.append(flash_mtia.FwOp)
+
+    return priority_list_ops
+
+
+def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
+    """Computes the best operator for forward
+
+    Raises:
+        NotImplementedError: if not operator was found
+
+    Returns:
+        AttentionOp: The best operator for the configuration
+    """
+    return _run_priority_list(
+        "memory_efficient_attention_forward",
+        _dispatch_fw_priority_list(inp, needs_gradient),
+        inp,
+    )
+
+
+def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
+    return False
+
+
+def _dispatch_bw(
+    inp: Inputs, varlen_lse_packed: Optional[bool]
+) -> Type[AttentionBwOpBase]:
+    if torch.version.cuda:
+        priority_list_ops: List[Type[AttentionBwOpBase]] = [
+            flash.BwOp,
+            cutlass.BwOp,
+        ]
+        if _get_use_fa3():
+            priority_list_ops = [flash3.BwOp] + priority_list_ops
+    else:
+        priority_list_ops: List[Type[AttentionBwOpBase]] = [
+            ck.BwOp,
+        ]
+
+    # NOTE: If we have a variable seqlen `attn_bias`, we need to get a BW pass
+    # that supports the LSE format
+    # *unless* we are in the case where both formats are the same (bs=1)
+    extra_op_reasons = []
+    if (
+        isinstance(inp.attn_bias, attn_bias.VARLEN_BIASES)
+        and inp.attn_bias.q_seqinfo.seqstart.shape[0] > 2
+    ):
+        assert varlen_lse_packed is not None
+        for op in priority_list_ops:
+            if op.VARLEN_LSE_PACKED != varlen_lse_packed:
+                extra_op_reasons.append(
+                    (
+                        op,
+                        [
+                            f"LSE is in {'packed' if varlen_lse_packed else 'padded'} format"
+                        ],
+                    )
+                )
+        priority_list_ops = [
+            op for op in priority_list_ops if op.VARLEN_LSE_PACKED == varlen_lse_packed
+        ]
+    if torch.version.cuda and _is_cutlassB_faster_than_flash(inp):
+        priority_list_ops.remove(cutlass.BwOp)
+        priority_list_ops.insert(0, cutlass.BwOp)
+
+    # torch.mtia.is_available() cannot be called here because it isn't supported
+    # when tracing with PT2, so we simply add flash_mtia to the end if the MTIA
+    # dynamic library can be loaded
+    if _USE_MTIA_FLASH_ATTENTION:
+        priority_list_ops.append(flash_mtia.BwOp)
+
+    return _run_priority_list(
+        "memory_efficient_attention_backward", priority_list_ops, inp
+    )
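
The diffed module above selects an attention operator by walking a priority list and returning the first op whose `not_supported_reasons(inp)` comes back empty. Below is a minimal, self-contained sketch of that selection pattern; the `DenseOp` / `SplitKOp` classes and the `SimpleInputs` container are hypothetical stand-ins for illustration only and are not part of the mslk package.

```python
# Toy illustration of the priority-list dispatch pattern in dispatch.py.
# Only the selection logic mirrors the diffed _run_priority_list; all names
# here are invented for the example.
from dataclasses import dataclass
from typing import List, Sequence, Type


@dataclass
class SimpleInputs:
    seqlen_q: int
    seqlen_k: int


class DenseOp:
    NAME = "dense"

    @classmethod
    def not_supported_reasons(cls, inp: SimpleInputs) -> List[str]:
        # Handles every shape in this toy example.
        return []


class SplitKOp:
    NAME = "splitk"

    @classmethod
    def not_supported_reasons(cls, inp: SimpleInputs) -> List[str]:
        reasons = []
        if inp.seqlen_q > 32:
            reasons.append("query sequence too long for split-K")
        return reasons


def run_priority_list(name: str, ops: Sequence[Type], inp: SimpleInputs) -> Type:
    # First operator with no objections wins; otherwise report why each failed.
    collected: List[List[str]] = []
    for op in ops:
        reasons = op.not_supported_reasons(inp)
        if not reasons:
            return op
        collected.append(reasons)
    msg = f"No operator found for `{name}`"
    for op, reasons in zip(ops, collected):
        msg += f"\n`{op.NAME}` is not supported because: " + "; ".join(reasons)
    raise NotImplementedError(msg)


if __name__ == "__main__":
    # Short-query decode shape: split-K is tried first and accepted.
    print(run_priority_list("fw", [SplitKOp, DenseOp], SimpleInputs(1, 4096)).NAME)
    # Long-query shape: split-K declines, so the dense op is selected instead.
    print(run_priority_list("fw", [SplitKOp, DenseOp], SimpleInputs(2048, 4096)).NAME)
```

The real `_dispatch_fw_priority_list` builds its candidate list the same way, then reorders it with `deque.appendleft` / `remove` based on heuristics such as MQA/GQA with short query length, before handing it to `_run_priority_list`.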