mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-unsafe
|
|
7
|
+
|
|
8
|
+
from typing import Any, Iterable, List, Optional, Tuple
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
|
|
13
|
+
from .common import AttentionFwOpBase, check_lastdim_alignment_stride1, Context, Inputs
|
|
14
|
+
from .utils.op_common import get_operator, register_operator
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@register_operator
class FwOp(AttentionFwOpBase):
    """CK split-K decoder forward attention operator.

    Dispatches to ``xformers::efficient_attention_forward_decoder_splitk_ck``.
    The key sequence is processed in ``split_k`` chunks, where ``split_k`` is
    either fixed by the ``SPLIT_K`` class attribute (see the ``FwOp_S*``
    subclasses) or chosen by :meth:`get_split_k`. Forward only: :meth:`apply`
    never returns a backward context.
    """

    OPERATOR = get_operator("xformers", "efficient_attention_forward_decoder_splitk_ck")
    SUPPORTED_DEVICES = {"cuda"}
    SUPPORTED_DTYPES = {
        torch.half,
        torch.bfloat16,
        torch.float,
    }  # Those are dtypes of Q. In the quantized case K/V has dtype int32
    SUPPORTED_MAX_K = 256
    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
        type(None),
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
    )
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_BMGHK = True
    NAME = "ck_splitKF"

    # Fixed number of key splits; None means "use the get_split_k heuristic".
    SPLIT_K: Optional[int] = None
    # NOTE(review): BLOCK_M / BLOCK_N are not read anywhere in this class;
    # presumably consumed by external tooling/benchmarks — confirm.
    BLOCK_M = 16
    BLOCK_N = 64

    NUM_GROUPS = 1  # Default quantization is row-wise

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Return reasons why the given shapes are unsupported.

        This operator adds no shape restrictions beyond the base class.
        """
        return super().shape_not_supported_reasons(Mq, Mkv, K, Kv)

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return human-readable reasons why ``d`` cannot run on this operator.

        An empty list means the inputs are supported.

        Args:
            d: the attention inputs to validate.

        Returns:
            List of reason strings (empty if supported).
        """
        reasons = super().not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        # int32 K/V is the quantized layout; the 8-element alignment check
        # only applies to non-quantized K/V.
        if d.key.dtype != torch.int32:
            check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
            check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
        if cls.OPERATOR is None:
            reasons.append(
                "efficient_attention_forward_decoder_splitk_ck operator is not available"
            )
        if d.device.type == "cuda":
            # Has only been tested on 8.0 / 9.0, but devices from 7.0 are accepted.
            if torch.cuda.get_device_capability(d.device) < (7, 0):
                reasons.append(
                    "requires GPU with sm70 minimum compute capability, e.g., V100/A100/H100"
                )

        q_len = d.query.shape[1]
        if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
            seqinfo = d.attn_bias.q_seqinfo
            if q_len != seqinfo.seqstart_py[-1]:
                reasons.append(
                    f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
                )
            # From here on, q_len is the per-sequence query length.
            q_len = seqinfo.min_seqlen
            if q_len != seqinfo.max_seqlen:
                reasons.append(
                    "Variable query len is not supported in the presence of causal mask."
                )

        # A zero stride on the head axis (-2) of both K and V means the heads
        # are broadcast (multiquery).
        if d.key.ndim in [4, 5] and d.key.shape[-2] != 1:
            if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1:
                reasons.append("multiquery is only supported with query seqlen=1")

        if d.attn_bias is not None and q_len > 1:
            reasons.append(
                "query with seqlen > 1 is not supported in the presence of causal mask"
            )
        return reasons

    @classmethod
    def get_split_k(cls, B: int, H: int, Mk: int) -> int:
        """Heuristic for the number of key-sequence splits.

        Args:
            B: batch size.
            H: number of query heads.
            Mk: key sequence length (the padded length when a bias is used).

        Returns:
            The number of splits, clamped to ``[1, 64]``.
        """
        bh = max(B * H, 1)  # NOTE: Handle B*h=0 case
        split_k = max(Mk, 1024) // bh
        max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
        # Halve the split count until each chunk covers at least
        # max_chunk_size keys (or the count reaches 0).
        while split_k > 0 and Mk / split_k < max_chunk_size:
            split_k = split_k // 2
        return max(min(split_k, 64), 1)

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the split-K decoder kernel on ``inp``.

        Args:
            inp: attention inputs; Q/K/V are read in BMGHK layout.
            needs_gradient: ignored; this operator is forward-only.

        Returns:
            ``(out, None)``: the kernel output and no backward context.
        """
        attn_bias = inp.attn_bias
        q, k, v = inp.get_qkv_in_bmghk()

        if attn_bias is not None:
            assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
            attn_bias.k_seqinfo.to(k.device)
            attn_bias.q_seqinfo.to(q.device)
            padding = attn_bias.k_seqinfo.padding
            seq_positions_gpu = attn_bias.k_seqinfo.seqlen

            # key: (1, B * padding, G, 1 if multiquery else Hkv, D)
            # value: like key
            # query: (1, B * q_seqlen, G, Hq, D)
            # Drop the leading 1 and unflatten the sequence axis per batch.
            multiquery = k.stride(3) == 0
            if multiquery:
                # Heads are broadcast (stride 0); pass a single head to the kernel.
                key = k[0, :, :, :1].unflatten(0, (-1, padding))
                value = v[0, :, :, :1].unflatten(0, (-1, padding))
            else:
                key = k[0].unflatten(0, (-1, padding))
                value = v[0].unflatten(0, (-1, padding))
            query = q[0].unflatten(0, (key.shape[0], -1))
        else:
            # key: (B, padding, G, 1 if multiquery else Hkv, D)
            # value: like key
            # query: (B, q_seqlen, G, Hq, D)
            seq_positions_gpu = None
            key, query, value = k, q, v

        B, _, _, H, _ = query.shape
        _, Mk, _, _, _ = key.shape

        if cls.SPLIT_K is not None:
            split_k = cls.SPLIT_K
        else:
            # Use heuristics
            split_k = cls.get_split_k(B, H, Mk)

        if inp.scale is not None:
            qk_scale = inp.scale
        else:
            # Default scale 1/sqrt(head_dim), computed in float32.
            qk_scale = torch.rsqrt(
                torch.tensor(k.shape[-1], dtype=torch.float32)
            ).item()

        out = cls.OPERATOR(
            query=query,
            key=key,
            value=value,
            seq_positions=seq_positions_gpu,
            scale=qk_scale,
            split_k=split_k,
        )

        return out, None
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class FwOp_S1(FwOp):
    """Split-K decoder operator pinned to a split count of 1."""

    NAME = "ck_splitK1"
    SPLIT_K = 1
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class FwOp_S2(FwOp):
    """Split-K decoder operator pinned to a split count of 2."""

    NAME = "ck_splitK2"
    SPLIT_K = 2
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class FwOp_S4(FwOp):
    """Split-K decoder operator pinned to a split count of 4."""

    NAME = "ck_splitK4"
    SPLIT_K = 4
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class FwOp_S8(FwOp):
    """Split-K decoder operator pinned to a split count of 8."""

    NAME = "ck_splitK8"
    SPLIT_K = 8
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class FwOp_S16(FwOp):
    """Split-K decoder operator pinned to a split count of 16."""

    NAME = "ck_splitK16"
    SPLIT_K = 16
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class FwOp_S32(FwOp):
    """Split-K decoder operator pinned to a split count of 32."""

    NAME = "ck_splitK32"
    SPLIT_K = 32
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class FwOp_S64(FwOp):
    """Split-K decoder operator pinned to a split count of 64."""

    NAME = "ck_splitK64"
    SPLIT_K = 64
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class FwOp_S128(FwOp):
    """Split-K decoder operator pinned to a split count of 128.

    Note: this exceeds the 64 cap applied by the ``get_split_k`` heuristic.
    """

    NAME = "ck_splitK128"
    SPLIT_K = 128
|