mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/block_info.py
@@ -0,0 +1,109 @@
+# @nolint # fbcode
+# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+from typing import Tuple, Optional
+from dataclasses import dataclass
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Int32, const_expr
+
+from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
+
+
+@dataclass(frozen=True)
+class BlockInfo:
+    tile_m: cutlass.Constexpr[int]
+    tile_n: cutlass.Constexpr[int]
+    is_causal: cutlass.Constexpr[bool]
+    is_local: cutlass.Constexpr[bool] = False
+    is_split_kv: cutlass.Constexpr[bool] = False
+    window_size_left: Optional[Int32] = None
+    window_size_right: Optional[Int32] = None
+    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
+
+    @cute.jit
+    def get_n_block_min_max(
+        self,
+        seqlen_info: SeqlenInfoQK,
+        m_block: Int32,
+        split_idx: cutlass.Int32 = 0,
+        num_splits: cutlass.Int32 = 1,
+    ) -> Tuple[Int32, Int32]:
+        n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
+        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
+            m_idx_max = (m_block + 1) * self.tile_m
+            if const_expr(self.qhead_per_kvhead_packgqa > 1):
+                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
+            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
+            n_idx_right = n_idx if const_expr(self.is_causal) else n_idx + self.window_size_right
+            n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n))
+        n_block_min = 0
+        if const_expr(self.is_local and self.window_size_left is not None):
+            m_idx_min = m_block * self.tile_m
+            if const_expr(self.qhead_per_kvhead_packgqa > 1):
+                m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa
+            n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q
+            n_idx_left = n_idx - self.window_size_left
+            n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)
+        if cutlass.const_expr(self.is_split_kv):
+            num_n_blocks_per_split = (
+                cutlass.Int32(0)
+                if n_block_max <= n_block_min
+                else (n_block_max - n_block_min + num_splits - 1) // num_splits
+            )
+            n_block_min = n_block_min + split_idx * num_n_blocks_per_split
+            n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max)
+        return n_block_min, n_block_max
+
+    @cute.jit
+    def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]:
+        m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m)
+        m_block_min = 0
+        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
+            n_idx_min = n_block * self.tile_n
+            m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k
+            m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right
+            m_block_min = max(m_block_min, m_idx_right // self.tile_m)
+        if const_expr(self.is_local and self.window_size_left is not None):
+            n_idx_max = (n_block + 1) * self.tile_n
+            m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k
+            m_idx_left = m_idx + self.window_size_left
+            m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m))
+        return m_block_min, m_block_max
+
+    @cute.jit
+    def get_n_block_min_causal_local_mask(
+        self,
+        seqlen_info: SeqlenInfoQK,
+        m_block: Int32,
+        n_block_min: Int32,
+    ) -> Int32:
+        """If we have separate iterations with causal or local masking at the start, where do we stop"""
+        m_idx_min = m_block * self.tile_m
+        if const_expr(self.qhead_per_kvhead_packgqa > 1):
+            m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa
+        n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q
+        n_idx_right = (
+            n_idx
+            if const_expr(not self.is_local or self.window_size_right is None)
+            else n_idx + self.window_size_right
+        )
+        return cutlass.max(n_block_min, n_idx_right // self.tile_n)
+
+    @cute.jit
+    def get_n_block_min_before_local_mask(
+        self,
+        seqlen_info: SeqlenInfoQK,
+        m_block: Int32,
+        n_block_min: Int32,
+    ) -> Int32:
+        """If we have separate iterations with local masking at the end, where do we stop the non-masked iterations"""
+        if const_expr(not self.is_local or self.window_size_left is None):
+            return n_block_min
+        else:
+            m_idx_max = (m_block + 1) * self.tile_m
+            if const_expr(self.qhead_per_kvhead_packgqa > 1):
+                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
+            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
+            n_idx_left = n_idx - self.window_size_left
+            return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n))
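
For orientation only, a minimal plain-Python sketch of the key-tile range arithmetic that BlockInfo.get_n_block_min_max performs in the diff above. This sketch is not part of the package: it assumes ordinary Python integers stand in for the cutlass/cute types, that ceil_div(a, b) == -(-a // b), and it ignores the pack-GQA and split-KV branches.

def n_block_range(seqlen_q, seqlen_k, tile_m, tile_n, m_block,
                  is_causal=True, window_size_left=None, window_size_right=None):
    def ceil_div(a, b):
        return -(-a // b)

    # Upper bound: by default, every key tile in the sequence.
    n_block_max = ceil_div(seqlen_k, tile_n)
    if is_causal or window_size_right is not None:
        # Last query row of this m_block, shifted into key-index space
        # (bottom-right aligned mask when seqlen_q != seqlen_k).
        n_idx = (m_block + 1) * tile_m + seqlen_k - seqlen_q
        n_idx_right = n_idx if is_causal else n_idx + window_size_right
        n_block_max = min(n_block_max, ceil_div(n_idx_right, tile_n))

    # Lower bound: only moves above 0 for local (sliding-window) attention.
    n_block_min = 0
    if window_size_left is not None:
        n_idx = m_block * tile_m + seqlen_k - seqlen_q
        n_block_min = max((n_idx - window_size_left) // tile_n, 0)
    return n_block_min, n_block_max

For example, with seqlen_q = seqlen_k = 1024, tile_m = tile_n = 128 and m_block = 2, a causal mask needs only key tiles 0 through 2, so the sketch returns the half-open range (0, 3).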