mslk-cuda-nightly 2026.1.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,102 @@
+# @nolint # fbcode
+# Copyright (c) 2025, Tri Dao.
+from typing import Type, Union, Optional
+import cutlass
+import cutlass.cute as cute
+from cutlass import Int32, Float32, Boolean, const_expr
+from cutlass.cute.nvgpu import warpgroup
+from cutlass.cutlass_dsl import Numeric, dsl_user_op
+from cutlass.utils import LayoutEnum
+import cutlass.utils.hopper_helpers as sm90_utils_og
+
+
+@cute.jit
+def gemm(
+    tiled_mma: cute.TiledMma,
+    acc: cute.Tensor,
+    tCrA: cute.Tensor,
+    tCrB: cute.Tensor,
+    zero_init: cutlass.Constexpr[bool] = False,
+    wg_wait: cutlass.Constexpr[int] = 0,
+    # A_in_regs: cutlass.Constexpr[bool] = False,
+    swap_AB: cutlass.Constexpr[bool] = False,
+) -> None:
+    if const_expr(swap_AB):
+        gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False)
+    else:
+        warpgroup.fence()
+        # We make a new mma_atom since we'll be modifying its attribute (accumulate).
+        # Otherwise the compiler complains "operand #0 does not dominate this use"
+        mma_atom = cute.make_mma_atom(tiled_mma.op)
+        mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init)
+        for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
+            cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
+            mma_atom.set(warpgroup.Field.ACCUMULATE, True)
+        warpgroup.commit_group()
+        if const_expr(wg_wait >= 0):
+            warpgroup.wait_group(wg_wait)
+
+
+def gemm_zero_init(
+    tiled_mma: cute.TiledMma,
+    shape: cute.Shape,
+    tCrA: cute.Tensor,
+    tCrB: cute.Tensor,
+    A_idx: Optional[Int32] = None,
+    B_idx: Optional[Int32] = None,
+    wg_wait: int = -1,
+    swap_AB: bool = False,
+) -> cute.Tensor:
+    if const_expr(swap_AB):
+        return gemm_zero_init(
+            tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
+        )
+    else:
+        acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
+        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
+        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
+        gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
+        return acc
+
+
+def gemm_w_idx(
+    tiled_mma: cute.TiledMma,
+    acc: cute.Tensor,
+    tCrA: cute.Tensor,
+    tCrB: cute.Tensor,
+    zero_init: Boolean,
+    A_idx: Optional[Int32] = None,
+    B_idx: Optional[Int32] = None,
+    wg_wait: int = -1,
+    swap_AB: bool = False,
+) -> None:
+    if const_expr(swap_AB):
+        gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False)
+    else:
+        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
+        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
+        gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
+
+
+@dsl_user_op
+def make_smem_layout(
+    dtype: Type[Numeric],
+    layout: LayoutEnum,
+    shape: cute.Shape,
+    stage: Optional[int] = None,
+    *,
+    loc=None,
+    ip=None,
+) -> Union[cute.Layout, cute.ComposedLayout]:
+    major_mode_size = shape[1] if layout.is_n_major_c() else shape[0]
+    smem_layout_atom = warpgroup.make_smem_layout_atom(
+        sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
+        dtype,
+    )
+    order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
+    smem_layout_staged = cute.tile_to_shape(
+        smem_layout_atom,
+        cute.append(shape, stage) if const_expr(stage is not None) else shape,
+        order=order if const_expr(stage is not None) else order[:2],
+    )
+    return smem_layout_staged