mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Callable, Dict, List
+
+import click
+import pandas as pd
+import torch
+import triton  # @manual
+from mslk.gemm.triton.grouped_gemm import grouped_gemm
+
+
+def triton_fused_bench(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_sizes: torch.Tensor,
+    bias: torch.Tensor,
+    token_weights: torch.Tensor,
+) -> Callable[[], torch.Tensor]:
+    """Factory for Triton fused grouped_gemm + bias + token_weights."""
+
+    def run() -> torch.Tensor:
+        return grouped_gemm(x, w, m_sizes, bias=bias, token_weights=token_weights)
+
+    return run
+
+
+@torch.compile(mode="reduce-overhead")
+def _torch_bmm_bias_scale(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    bias: torch.Tensor,
+    token_weights: torch.Tensor,
+    G: int,
+    M_per_group: int,
+) -> torch.Tensor:
+    """Compiled torch baseline: bmm + bias + scale."""
+    N = w.shape[0] // G
+    K = w.shape[1]
+    x_3d = x.view(G, M_per_group, K)
+    w_3d = w.view(G, N, K)
+    out = torch.bmm(x_3d, w_3d.transpose(-1, -2))
+    out = out + bias.unsqueeze(1)
+    out = out * token_weights.view(G, M_per_group, 1)
+    return out.view(-1, N)
+
+
+def torch_baseline_bench(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    bias: torch.Tensor,
+    token_weights: torch.Tensor,
+    G: int,
+    M_per_group: int,
+) -> Callable[[], torch.Tensor]:
+    """Factory for torch.compile'd batched matmul baseline."""
+
+    def run() -> torch.Tensor:
+        return _torch_bmm_bias_scale(x, w, bias, token_weights, G, M_per_group)
+
+    return run
+
+
+def triton_gemm_torch_bias_scale_bench(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_sizes: torch.Tensor,
+    bias: torch.Tensor,
+    token_weights: torch.Tensor,
+    G: int,
+    M_per_group: int,
+) -> Callable[[], torch.Tensor]:
+    """Factory for Triton grouped_gemm + torch bias + torch token_weights."""
+
+    def run() -> torch.Tensor:
+        out = grouped_gemm(x, w, m_sizes)
+        out_3d = out.view(G, M_per_group, -1)
+        out_3d = out_3d + bias.unsqueeze(1)
+        out_3d = out_3d * token_weights.view(G, M_per_group, 1)
+        return out_3d.view(-1, out.shape[-1])
+
+    return run
+
+
+@click.command()
+@click.option("--warmup", type=int, default=25, help="Warmup iterations")
+@click.option("--rep", type=int, default=25, help="Benchmark repetitions")
+def bench(warmup: int, rep: int) -> None:
+    """Benchmark grouped_gemm_bias_scale vs torch baseline."""
+    device = torch.accelerator.current_accelerator()
+    dtype = torch.bfloat16
+
+    # G: Number of experts/groups in the MoE layer
+    # M: Total number of tokens across all groups
+    # N: Output dimension (hidden size of expert output)
+    # K: Input dimension (hidden size of expert input)
+    configs = [
+        {"G": 4, "M": 512, "N": 256, "K": 256, "name": "Small"},
+        {"G": 16, "M": 4096, "N": 512, "K": 512, "name": "Medium"},
+        {"G": 64, "M": 16384, "N": 512, "K": 512, "name": "Large"},
+    ]
+
+    # Print configuration table
+    config_df = pd.DataFrame(configs).rename(
+        columns={
+            "name": "Config",
+            "G": "G (experts)",
+            "M": "M (tokens)",
+            "N": "N (out_dim)",
+            "K": "K (in_dim)",
+        }
+    )[["Config", "G (experts)", "M (tokens)", "N (out_dim)", "K (in_dim)"]]
+    print("\nBenchmark Configurations:")
+    print(config_df.to_string(index=False))
+    print()
+
+    results: List[Dict[str, str]] = []
+
+    for idx, cfg in enumerate(configs):
+        G: int = cfg["G"]  # pyre-ignore[9]
+        M: int = cfg["M"]  # pyre-ignore[9]
+        N: int = cfg["N"]  # pyre-ignore[9]
+        K: int = cfg["K"]  # pyre-ignore[9]
+        name: str = cfg["name"]  # pyre-ignore[9]
+        M_per_group = M // G
+
+        print(f"Processing config {idx + 1}/{len(configs)}: {name}...")
+
+        # Create tensors
+        x = torch.randn(M, K, dtype=dtype, device=device)
+        w = torch.randn(G * N, K, dtype=dtype, device=device)
+        bias = torch.randn(G, N, dtype=dtype, device=device)
+        token_weights = torch.rand(M, dtype=dtype, device=device) + 0.5
+        m_sizes = torch.full((G,), M_per_group, dtype=torch.int32, device=device)
+
+        # Create benchmark functions
+        triton_fn = triton_fused_bench(x, w, m_sizes, bias, token_weights)
+        triton_torch_fn = triton_gemm_torch_bias_scale_bench(
+            x, w, m_sizes, bias, token_weights, G, M_per_group
+        )
+        torch_fn = torch_baseline_bench(x, w, bias, token_weights, G, M_per_group)
+
+        # Warmup torch.compile
+        for _ in range(3):
+            torch_fn()
+        torch.cuda.synchronize()
+
+        # Benchmark
+        fused_ms = triton.testing.do_bench(triton_fn, warmup=warmup, rep=rep)
+        triton_torch_ms = triton.testing.do_bench(
+            triton_torch_fn, warmup=warmup, rep=rep
+        )
+        torch_ms = triton.testing.do_bench(torch_fn, warmup=warmup, rep=rep)

+        results.append(
+            {
+                "Config": name,
+                "fused (ms)": f"{fused_ms:.3f}",
+                "triton+torch (ms)": f"{triton_torch_ms:.3f}",
+                "torch (ms)": f"{torch_ms:.3f}",
+                "Speedup vs torch": f"{torch_ms / fused_ms:.2f}x",
+                "Speedup vs triton+torch": f"{triton_torch_ms / fused_ms:.2f}x",
+            }
+        )
+
+    print("\nBenchmark Results:")
+    print(pd.DataFrame(results).to_string(index=False))
+    print()
+
+
+if __name__ == "__main__":
+    bench()
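For context before the next file: the hunk above compares a fused Triton grouped GEMM (per-expert bias and per-token scaling folded into the epilogue) against a torch.compile'd bmm baseline. A minimal sketch of the fused call it measures, assuming the wheel is installed and a CUDA accelerator is present; shapes mirror the "Small" config, and every call below appears verbatim in the diff:

    import torch

    from mslk.gemm.triton.grouped_gemm import grouped_gemm

    G, M, N, K = 4, 512, 256, 256  # experts, total tokens, out_dim, in_dim
    device, dtype = "cuda", torch.bfloat16

    x = torch.randn(M, K, dtype=dtype, device=device)      # stacked tokens
    w = torch.randn(G * N, K, dtype=dtype, device=device)  # stacked expert weights
    bias = torch.randn(G, N, dtype=dtype, device=device)   # per-expert bias
    token_weights = torch.rand(M, dtype=dtype, device=device) + 0.5
    m_sizes = torch.full((G,), M // G, dtype=torch.int32, device=device)

    # Per group g: (x_g @ w_g.T + bias_g) * token_weights_g, in one kernel.
    out = grouped_gemm(x, w, m_sizes, bias=bias, token_weights=token_weights)
    assert out.shape == (M, N)

The "triton+torch" variant in the benchmark pays two extra elementwise passes over the output for the same result, which is the gap the speedup columns quantify.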
mslk/bench/moe/gather_scatter_bench.py
@@ -0,0 +1,356 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+import itertools
+from typing import Optional
+
+import click
+import torch
+import triton  # noqa: F401
+from mslk.moe import (
+    combine_shuffling,
+    gather_scale_dense_tokens,
+    gather_scale_quant_dense_tokens,
+    scatter_add_dense_tokens,
+    split_shuffling,
+)
+from triton.testing import do_bench, do_bench_cudagraph
+
+
+index_shuffling = None
+gather_along_first_dim = None
+scatter_add_along_first_dim = None
+
+if torch.cuda.is_available():
+    index_shuffling = torch.ops.mslk.index_shuffling  # noqa: F401
+    if not torch.version.hip:
+        # SM90 support
+        gather_along_first_dim = torch.ops.mslk.gather_along_first_dim  # noqa: F401
+        scatter_add_along_first_dim = torch.ops.mslk.scatter_add_along_first_dim  # noqa: F401
+
+
+_ACCELERATOR_TAG = torch.accelerator.current_accelerator()
+
+
+def bench_gather_along_first_dim(M: int, N: int, K: int) -> None:
+    src = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
+    if M == N:
+        indices = torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int32)
+    else:
+        indices = torch.randint(0, M, [N], device=_ACCELERATOR_TAG, dtype=torch.int32)
+
+    def fn():
+        return torch.ops.mslk.gather_along_first_dim(src, indices)
+
+    def ref_fn():
+        return torch.index_select(src, 0, indices)
+
+    # Load src, store dst. x2.
+    data_size_in_gigabytes = N * K * 2 * 2 / 1e9
+
+    time_in_us = triton.testing.do_bench(fn) * 1e3
+    time_in_second = time_in_us / 1e6
+    gigabytes_per_second = data_size_in_gigabytes / time_in_second
+
+    ref_time_in_us = triton.testing.do_bench(ref_fn) * 1e3
+    ref_time_in_second = ref_time_in_us / 1e6
+    ref_gigabytes_per_second = data_size_in_gigabytes / ref_time_in_second
+
+    print(
+        f"Benchmark gather_along_first_dim: {M=:5d}, {N=:5d}, {K=:5d}, "
+        f"MSLK time: {time_in_us:10.3f} us. Bandwidth: {gigabytes_per_second:10.3f} GB/s, "
+        f"Torch time: {ref_time_in_us:10.3f} us. Bandwidth: {ref_gigabytes_per_second:10.3f} GB/s"
+    )
+
+
+def bench_scatter_add_along_first_dim_(op, M: int, N: int, K: int) -> None:
+    src = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
+    dst = torch.randn([N, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
+    if M == N:
+        indices_1d = torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int64)
+    else:
+        indices_1d = torch.randint(
+            0, N, [M], device=_ACCELERATOR_TAG, dtype=torch.int64
+        )
+
+    indices_2d = indices_1d.to(torch.int64).unsqueeze(1).expand(-1, K)
+
+    test_dst = dst.clone()
+    ref_dst = dst.clone()
+
+    def fn():
+        op(test_dst, src, indices_1d)
+
+    def ref_fn():
+        ref_dst.scatter_add_(0, indices_2d, src)
+
+    # Load src, load dst, store dst. x3.
+    data_size_in_gigabytes = N * K * 2 * 3 / 1e9
+
+    time_in_us = triton.testing.do_bench(fn) * 1e3
+    time_in_second = time_in_us / 1e6
+    gigabytes_per_second = data_size_in_gigabytes / time_in_second
+
+    ref_time_in_us = triton.testing.do_bench(ref_fn) * 1e3
+    ref_time_in_second = ref_time_in_us / 1e6
+    ref_gigabytes_per_second = data_size_in_gigabytes / ref_time_in_second
+
+    print(
+        f"Benchmark {op.__name__}: {M=:5d}, {N=:5d}, {K=:5d}, "
+        f"MSLK time: {time_in_us:10.3f} us. Bandwidth: {gigabytes_per_second:10.3f} GB/s, "
+        f"Torch time: {ref_time_in_us:10.3f} us. Bandwidth: {ref_gigabytes_per_second:10.3f} GB/s"
+    )
+
+
+bench_scatter_add_along_first_dim = functools.partial(
+    bench_scatter_add_along_first_dim_, scatter_add_along_first_dim
+)
+
+bench_scatter_add_dense_tokens = functools.partial(
+    bench_scatter_add_along_first_dim_, scatter_add_dense_tokens
+)
+
+
+def bench_gather_scale_dense_tokens(E: int, T: int, D: int, quantize: bool) -> None:
+    x = torch.randn((T, D), dtype=torch.bfloat16, device=_ACCELERATOR_TAG).abs()
+    expert_indices = torch.randint(0, E, (T,), device=_ACCELERATOR_TAG)
+    token_indices = torch.randperm(T, device=_ACCELERATOR_TAG)
+    scores = torch.rand((E, T), dtype=torch.bfloat16, device=_ACCELERATOR_TAG)
+
+    def torch_fn():
+        shuffled_x = torch.index_select(x, dim=0, index=token_indices)
+        shuffled_scores = torch.index_select(scores, dim=1, index=token_indices)
+        shuffled_selected_scores = torch.gather(
+            shuffled_scores, dim=0, index=expert_indices.view(1, T)
+        )
+        ref_output = shuffled_x * shuffled_selected_scores.view(-1, 1)
+        return ref_output
+
+    torch_fn()
+
+    scores_TE = scores.transpose(0, 1).contiguous()
+
+    mslk_fn = gather_scale_quant_dense_tokens if quantize else gather_scale_dense_tokens
+
+    def triton_fn():
+        test_output = mslk_fn(x, token_indices, expert_indices, scores_TE)
+        return test_output
+
+    triton_fn()
+
+    # Run benchmark
+    if quantize:
+        data_size_in_gigabytes = T * D * 3 / 1e9
+    else:
+        data_size_in_gigabytes = T * D * 4 / 1e9
+
+    mslk_time = do_bench(triton_fn, rep=1000) * 1e3
+    mslk_bw = data_size_in_gigabytes / (mslk_time / 1e6)
+
+    torch_time = do_bench(torch_fn, rep=1000) * 1e3
+    torch_bw = data_size_in_gigabytes / (torch_time / 1e6)
+    print(
+        f"Benchmark gather_scale_dense_tokens({quantize=}), {E=:3d}, {T=:5d}, {D=:5d}, "
+        f"MSLK time: {mslk_time:10.3f} us. Bandwidth: {mslk_bw:10.3f} GB/s, "
+        f"Torch time: {torch_time:10.3f} us. Bandwidth: {torch_bw:10.3f} GB/s"
+    )
+
+
+def bench_topk_index_shuffling(T: int, E: int, K: int) -> None:
+    torch.manual_seed(0)
+
+    num_rotating_buffers = min(max(2, triton.cdiv(1024 * 1024 * 1024, T * E * 2)), 1000)
+    scores_list: list[torch.Tensor] = [
+        torch.randn(T, E, device=_ACCELERATOR_TAG, dtype=torch.bfloat16)
+        for _ in range(num_rotating_buffers)
+    ]
+
+    def fn() -> None:
+        for scores in scores_list:
+            index_shuffling(scores, top_k=K)
+
+    def ref_fn() -> None:
+        for scores in scores_list:
+            _, selected_expert_indices = torch.topk(scores, K, dim=1)
+            expert_indices, _ = torch.sort(
+                selected_expert_indices.flatten(), dim=0, stable=True
+            )
+            _ = (
+                expert_indices[:, None]
+                == torch.arange(E, device=expert_indices.device)[None, :]
+            ).sum(dim=0)
+
+    mslk_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
+    torch_time = do_bench_cudagraph(ref_fn) * 1e3 / num_rotating_buffers
+    print(
+        f"Benchmark index_shuffling, num_tokens={T:4}, num_experts={E:4}, top_k={K:4}, "
+        f"mslk_time={mslk_time:7.3f}us, torch_time={torch_time:7.3f}us"
+    )
+
+
+def bench_combine_or_split_shuffling(
+    T: int,
+    D: int,
+    E: int,
+    EP: int,
+    is_padded: bool,
+    is_balanced: bool,
+    is_combine_shuffling: bool,
+) -> None:
+    torch.manual_seed(0)
+
+    assert E % EP == 0
+    if is_padded:
+        # graph. allgather
+        input_num_tokens: int = EP * T
+        input_num_experts: int = E
+        output_num_experts: int = E // EP
+        start_expert_index: int = 1
+        end_expert_index: int = 1 + output_num_experts
+    else:
+        # eager. all2all
+        input_num_tokens: int = T
+        input_num_experts: int = E // EP
+        output_num_experts: int = E // EP
+        start_expert_index: int = 0
+        end_expert_index: int = output_num_experts
+
+    tokens = torch.randn(
+        input_num_tokens, D, device=_ACCELERATOR_TAG, dtype=torch.bfloat16
+    )
+
+    if input_num_tokens % (EP * input_num_experts) != 0:  # skip uneven splits
+        return
+
+    input_num_tokens_per_expert: int = input_num_tokens // (EP * input_num_experts)
+    token_counts: torch.Tensor = (
+        torch.ones(
+            [EP, input_num_experts],
+            dtype=torch.int32,
+            device=_ACCELERATOR_TAG,
+        )
+        * input_num_tokens_per_expert
+    )
+    if not is_balanced:
+        for i in range(EP):
+            token_counts[i, start_expert_index] -= input_num_tokens_per_expert
+            token_counts[i, end_expert_index - 1] += input_num_tokens_per_expert
+
+    assert token_counts.sum().item() == input_num_tokens
+
+    num_rotating_buffers = triton.cdiv(1024 * 1024 * 1024, tokens.numel() * 2)
+    token_list: list[torch.Tensor] = [
+        tokens.clone() for _ in range(num_rotating_buffers)
+    ]
+    token_count_list: list[torch.Tensor] = [
+        token_counts.clone() for _ in range(num_rotating_buffers)
+    ]
+
+    def fn() -> None:
+        for tokens, token_counts in zip(token_list, token_count_list):
+            if is_combine_shuffling:
+                combine_shuffling(
+                    tokens,
+                    token_counts,
+                    expert_start=start_expert_index,
+                    expert_end=end_expert_index,
+                    is_balanced=is_balanced,
+                )
+            else:
+                split_shuffling(
+                    tokens,
+                    token_counts,
+                    expert_start=start_expert_index,
+                    expert_end=end_expert_index,
+                    is_balanced=is_balanced,
+                )
+
+    fn()
+
+    output_num_tokens = 0
+    for per_rank_counts in token_counts.tolist():
+        for expert_index, per_expert_counts in enumerate(per_rank_counts):
+            if start_expert_index <= expert_index < end_expert_index:
+                output_num_tokens += per_expert_counts
+
+    mem_bytes = output_num_tokens * D * 2 * 2
+    mslk_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
+    mslk_bw = mem_bytes * 1e-9 / (mslk_time * 1e-6)
+
+    print(
+        f"Benchmark {'combine_shuffling' if is_combine_shuffling else 'split_shuffling'}, "
+        f"num_tokens={T:4}, dim={D:4}, num_experts={E:4}, expert_parallelism={EP:4}, output_num_tokens={output_num_tokens:4}, "
+        f"{is_balanced=}, {is_padded=}, "
+        f"mslk_time={mslk_time:7.3f}us, mslk_bw={mslk_bw:8.3f}GBytes/s."
+    )
+
+
+@click.command()
+@click.option(
+    "--kernels",
+    default=None,
+    help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
+)
+def main(kernels: Optional[str]) -> None:
+    if kernels is not None:
+        kernels = kernels.split(",")
+
+    def should_bench_kernel(fn):
+        return (fn is not None) and (kernels is None or fn.__name__ in kernels)
+
+    Es = [16, 128]
+    Ts = [1, 128, 2048, 4096, 8192, 16384]
+    Ds = [5120]
+
+    # Gather/Scatter
+    if should_bench_kernel(gather_scale_dense_tokens):
+        for E, T, D in itertools.product(Es, Ts, Ds):
+            bench_gather_scale_dense_tokens(E, T, D, quantize=False)
+
+    if should_bench_kernel(gather_scale_quant_dense_tokens):
+        for E, T, D in itertools.product(Es, Ts, Ds):
+            bench_gather_scale_dense_tokens(E, T, D, quantize=True)
+
+    if should_bench_kernel(gather_along_first_dim):
+        for T, D in itertools.product(Ts, Ds):
+            bench_gather_along_first_dim(T, T, D)
+
+    if should_bench_kernel(scatter_add_along_first_dim):
+        for T, D in itertools.product(Ts, Ds):
+            bench_scatter_add_along_first_dim(T, T, D)
+
+    if should_bench_kernel(scatter_add_dense_tokens):
+        for T, D in itertools.product(Ts, Ds):
+            bench_scatter_add_dense_tokens(T, T, D)
+
+    Ks = [1, 2, 4]
+    Es = [16, 32, 128, 320]
+    # Shuffling
+    if should_bench_kernel(index_shuffling):
+        for T, E, K in itertools.product(Ts, Es, Ks):
+            bench_topk_index_shuffling(T, E, K)
+
+    EPs = [2, 16]
+    Ts = [32, 128, 2048, 4096, 8192, 16384]
+    padded = [True, False]
+    balanced = [True, False]
+
+    if should_bench_kernel(combine_shuffling):
+        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
+            bench_combine_or_split_shuffling(
+                T, D, E, EP, p, b, is_combine_shuffling=True
+            )
+
+    if should_bench_kernel(split_shuffling):
+        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
+            bench_combine_or_split_shuffling(
+                T, D, E, EP, p, b, is_combine_shuffling=False
+            )
+
+
+if __name__ == "__main__":
+    main()
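As a reading aid for the hunk above, here is a sketch of the unfused torch reference that bench_gather_scale_dense_tokens measures, next to the fused MSLK call. It assumes a CUDA device; E/T/D follow the benchmark's smallest sweep point, and the final comparison line is an assumption that the fused kernel reproduces the benchmark's reference, not something the diff asserts:

    import torch

    from mslk.moe import gather_scale_dense_tokens

    E, T, D = 16, 128, 5120  # experts, tokens, hidden dim
    device = "cuda"

    x = torch.randn(T, D, dtype=torch.bfloat16, device=device).abs()
    expert_indices = torch.randint(0, E, (T,), device=device)
    token_indices = torch.randperm(T, device=device)
    scores = torch.rand(E, T, dtype=torch.bfloat16, device=device)

    # Unfused reference: gather token rows, gather each row's routing score,
    # then scale the row by that score.
    shuffled_x = torch.index_select(x, dim=0, index=token_indices)
    shuffled_scores = torch.index_select(scores, dim=1, index=token_indices)
    selected = torch.gather(shuffled_scores, dim=0, index=expert_indices.view(1, T))
    ref = shuffled_x * selected.view(-1, 1)

    # Fused kernel; it takes the scores transposed to [T, E].
    out = gather_scale_dense_tokens(
        x, token_indices, expert_indices, scores.transpose(0, 1).contiguous()
    )
    # Expected ~0 if the fused kernel matches the benchmark's reference path.
    print((out.float() - ref.float()).abs().max())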