mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,177 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ from typing import Callable, Dict, List
10
+
11
+ import click
12
+ import pandas as pd
13
+ import torch
14
+ import triton # @manual
15
+ from mslk.gemm.triton.grouped_gemm import grouped_gemm
16
+
17
+
18
def triton_fused_bench(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    bias: torch.Tensor,
    token_weights: torch.Tensor,
) -> Callable[[], torch.Tensor]:
    """Build a zero-argument closure for the fused Triton path.

    The closure issues a single ``grouped_gemm`` call with ``bias`` and
    ``token_weights`` forwarded as keyword arguments. That is the fused
    variant this benchmark compares against unfused alternatives.
    """
    # Bind the epilogue inputs once; the benchmark harness only times run().
    fused_kwargs = {"bias": bias, "token_weights": token_weights}

    def run() -> torch.Tensor:
        return grouped_gemm(x, w, m_sizes, **fused_kwargs)

    return run
31
+
32
+
33
@torch.compile(mode="reduce-overhead")
def _torch_bmm_bias_scale(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
    token_weights: torch.Tensor,
    G: int,
    M_per_group: int,
) -> torch.Tensor:
    """Compiled torch reference: per-group matmul, then bias add and row scaling.

    Layout implied by the views below:
      - ``x`` is viewed as ``(G, M_per_group, K)``.
      - ``w`` stacks G weight blocks and is viewed as ``(G, N, K)``.
      - ``bias`` is unsqueezed at dim 1, so the caller is expected to pass
        ``(G, N)``.
      - ``token_weights`` is viewed as ``(G, M_per_group, 1)``.

    Returns the result flattened back to ``(G * M_per_group, N)``.
    """
    in_dim = w.shape[1]
    out_dim = w.shape[0] // G
    # Reinterpret the flat group-major inputs as batched 3-D tensors.
    x_batched = x.view(G, M_per_group, in_dim)
    w_batched = w.view(G, out_dim, in_dim)
    # (G, M, K) @ (G, K, N) -> (G, M, N)
    acc = torch.bmm(x_batched, w_batched.transpose(-1, -2))
    acc = (acc + bias.unsqueeze(1)) * token_weights.view(G, M_per_group, 1)
    return acc.view(-1, out_dim)
51
+
52
+
53
def torch_baseline_bench(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
    token_weights: torch.Tensor,
    G: int,
    M_per_group: int,
) -> Callable[[], torch.Tensor]:
    """Return a closure that runs the torch.compile'd bmm + bias + scale baseline."""
    # Freeze the argument tuple so each timed call only dispatches the kernel.
    call_args = (x, w, bias, token_weights, G, M_per_group)

    def run() -> torch.Tensor:
        return _torch_bmm_bias_scale(*call_args)

    return run
67
+
68
+
69
def triton_gemm_torch_bias_scale_bench(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    bias: torch.Tensor,
    token_weights: torch.Tensor,
    G: int,
    M_per_group: int,
) -> Callable[[], torch.Tensor]:
    """Return a closure for the unfused path.

    The closure runs the Triton grouped GEMM first, then adds ``bias`` and
    applies ``token_weights`` with eager torch ops.
    """

    def run() -> torch.Tensor:
        gemm_out = grouped_gemm(x, w, m_sizes)
        out_dim = gemm_out.shape[-1]
        # View as (G, M_per_group, N) so bias broadcasts per group and the
        # token weights broadcast per row.
        grouped = gemm_out.view(G, M_per_group, -1)
        scaled = (grouped + bias.unsqueeze(1)) * token_weights.view(G, M_per_group, 1)
        return scaled.view(-1, out_dim)

    return run
88
+
89
+
90
@click.command()
@click.option(
    "--warmup",
    type=int,
    default=25,
    help="Warmup time per benchmark in ms (passed to triton.testing.do_bench)",
)
@click.option(
    "--rep",
    type=int,
    default=25,
    help="Measurement time per benchmark in ms (passed to triton.testing.do_bench)",
)
def bench(warmup: int, rep: int) -> None:
    """Benchmark grouped_gemm_bias_scale vs torch baseline.

    Three variants are timed for every configuration:
      - the fused Triton ``grouped_gemm`` with bias and token weights;
      - a Triton GEMM followed by eager torch bias/scaling;
      - a ``torch.compile``'d bmm baseline.

    Each configuration splits its ``M`` tokens evenly across ``G`` groups.

    Args:
        warmup: Warmup budget in milliseconds for ``triton.testing.do_bench``
            (a time budget, not an iteration count).
        rep: Measurement budget in milliseconds for ``triton.testing.do_bench``.
    """
    device = torch.accelerator.current_accelerator()
    dtype = torch.bfloat16

    # G: Number of experts/groups in the MoE layer
    # M: Total number of tokens across all groups
    # N: Output dimension (hidden size of expert output)
    # K: Input dimension (hidden size of expert input)
    configs = [
        {"G": 4, "M": 512, "N": 256, "K": 256, "name": "Small"},
        {"G": 16, "M": 4096, "N": 512, "K": 512, "name": "Medium"},
        {"G": 64, "M": 16384, "N": 512, "K": 512, "name": "Large"},
    ]

    # Print configuration table
    config_df = pd.DataFrame(configs).rename(
        columns={
            "name": "Config",
            "G": "G (experts)",
            "M": "M (tokens)",
            "N": "N (out_dim)",
            "K": "K (in_dim)",
        }
    )[["Config", "G (experts)", "M (tokens)", "N (out_dim)", "K (in_dim)"]]
    print("\nBenchmark Configurations:")
    print(config_df.to_string(index=False))
    print()

    results: List[Dict[str, str]] = []

    for idx, cfg in enumerate(configs):
        G: int = cfg["G"]  # pyre-ignore[9]
        M: int = cfg["M"]  # pyre-ignore[9]
        N: int = cfg["N"]  # pyre-ignore[9]
        K: int = cfg["K"]  # pyre-ignore[9]
        name: str = cfg["name"]  # pyre-ignore[9]
        # The configs above are chosen so M is divisible by G; the views in
        # the baselines rely on equal-sized groups.
        M_per_group = M // G

        print(f"Processing config {idx + 1}/{len(configs)}: {name}...")

        # Create tensors
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(G * N, K, dtype=dtype, device=device)
        bias = torch.randn(G, N, dtype=dtype, device=device)
        token_weights = torch.rand(M, dtype=dtype, device=device) + 0.5
        m_sizes = torch.full((G,), M_per_group, dtype=torch.int32, device=device)

        # Create benchmark functions
        triton_fn = triton_fused_bench(x, w, m_sizes, bias, token_weights)
        triton_torch_fn = triton_gemm_torch_bias_scale_bench(
            x, w, m_sizes, bias, token_weights, G, M_per_group
        )
        torch_fn = torch_baseline_bench(x, w, bias, token_weights, G, M_per_group)

        # Warm up torch.compile so compilation happens outside the timed region.
        for _ in range(3):
            torch_fn()
        # Synchronize through the generic accelerator API, matching the device
        # the tensors were allocated on (rather than assuming CUDA).
        torch.accelerator.synchronize()

        # Benchmark
        fused_ms = triton.testing.do_bench(triton_fn, warmup=warmup, rep=rep)
        triton_torch_ms = triton.testing.do_bench(
            triton_torch_fn, warmup=warmup, rep=rep
        )
        torch_ms = triton.testing.do_bench(torch_fn, warmup=warmup, rep=rep)

        results.append(
            {
                "Config": name,
                "fused (ms)": f"{fused_ms:.3f}",
                "triton+torch (ms)": f"{triton_torch_ms:.3f}",
                "torch (ms)": f"{torch_ms:.3f}",
                "Speedup vs torch": f"{torch_ms / fused_ms:.2f}x",
                "Speedup vs triton+torch": f"{triton_torch_ms / fused_ms:.2f}x",
            }
        )

    print("\nBenchmark Results:")
    print(pd.DataFrame(results).to_string(index=False))
    print()
174
+
175
+
176
if __name__ == "__main__":
    # Script entry point: click parses sys.argv and runs the benchmark.
    bench()
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
@@ -0,0 +1,356 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import functools
8
+ import itertools
9
+ from typing import Optional
10
+
11
+ import click
12
+ import torch
13
+ import triton # noqa: F401
14
+ from mslk.moe import (
15
+ combine_shuffling,
16
+ gather_scale_dense_tokens,
17
+ gather_scale_quant_dense_tokens,
18
+ scatter_add_dense_tokens,
19
+ split_shuffling,
20
+ )
21
+ from triton.testing import do_bench, do_bench_cudagraph
22
+
23
+
24
+ index_shuffling = None
25
+ gather_along_first_dim = None
26
+ scatter_add_along_first_dim = None
27
+
28
+ if torch.cuda.is_available():
29
+ index_shuffling = torch.ops.mslk.index_shuffling # noqa F401
30
+ if not torch.version.hip:
31
+ # SM90 support
32
+ gather_along_first_dim = torch.ops.mslk.gather_along_first_dim # noqa F401
33
+ scatter_add_along_first_dim = torch.ops.mslk.scatter_add_along_first_dim # noqa F401
34
+
35
+
36
+ _ACCELERATOR_TAG = torch.accelerator.current_accelerator()
37
+
38
+
39
def bench_gather_along_first_dim(M: int, N: int, K: int) -> None:
    """Time MSLK gather_along_first_dim against torch.index_select and print both.

    ``N`` rows of width ``K`` are gathered from an ``(M, K)`` bf16 source.
    When ``M == N`` the indices are a random permutation; otherwise they are
    drawn with replacement from ``[0, M)``.
    """
    source = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
    index = (
        torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int32)
        if M == N
        else torch.randint(0, M, [N], device=_ACCELERATOR_TAG, dtype=torch.int32)
    )

    def mslk_gather():
        return torch.ops.mslk.gather_along_first_dim(source, index)

    def torch_gather():
        return torch.index_select(source, 0, index)

    # N rows loaded + N rows stored, 2 bytes per bf16 element.
    gigabytes_moved = N * K * 2 * 2 / 1e9

    def measure(kernel):
        # do_bench reports milliseconds; convert to microseconds.
        micros = triton.testing.do_bench(kernel) * 1e3
        return micros, gigabytes_moved / (micros / 1e6)

    time_in_us, gigabytes_per_second = measure(mslk_gather)
    ref_time_in_us, ref_gigabytes_per_second = measure(torch_gather)

    print(
        f"Benchmark gather_along_first_dim: {M=:5d}, {N=:5d}, {K=:5d}, "
        f"MSLK time: {time_in_us:10.3f} us. Bandwidth: {gigabytes_per_second:10.3f} GB/s, "
        f"Torch time: {ref_time_in_us:10.3f} us. Bandwidth: {ref_gigabytes_per_second:10.3f} GB/s"
    )
68
+
69
+
70
def bench_scatter_add_along_first_dim_(op, M: int, N: int, K: int) -> None:
    """Time an in-place scatter-add op against ``Tensor.scatter_add_`` and print both.

    ``op`` is called as ``op(dst, src, indices)``. The torch reference adds
    row ``i`` of the ``(M, K)`` bf16 ``src`` into row ``indices[i]`` of the
    ``(N, K)`` ``dst``; ``op`` is presumably expected to match that — confirm
    against the op's documentation. When ``M == N`` the indices are a random
    permutation of ``[0, N)``; otherwise ``M`` indices are drawn with
    replacement from ``[0, N)``.

    Both destinations accumulate across repeated benchmark calls; results are
    not compared.
    """
    src = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
    dst = torch.randn([N, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
    if M == N:
        indices_1d = torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int64)
    else:
        indices_1d = torch.randint(
            0, N, [M], device=_ACCELERATOR_TAG, dtype=torch.int64
        )

    # scatter_add_ needs an index tensor with the same shape as src.
    indices_2d = indices_1d.unsqueeze(1).expand(-1, K)

    test_dst = dst.clone()
    ref_dst = dst.clone()

    def fn():
        op(test_dst, src, indices_1d)

    def ref_fn():
        ref_dst.scatter_add_(0, indices_2d, src)

    # One entry per src row (len(indices_1d) == M): load src row, load dst
    # row, store dst row -> x3, bf16 = 2 bytes. Using M (not N) keeps the
    # figure correct when M != N.
    data_size_in_gigabytes = M * K * 2 * 3 / 1e9

    time_in_us = triton.testing.do_bench(fn) * 1e3
    time_in_second = time_in_us / 1e6
    gigabytes_per_second = data_size_in_gigabytes / time_in_second

    ref_time_in_us = triton.testing.do_bench(ref_fn) * 1e3
    ref_time_in_second = ref_time_in_us / 1e6
    ref_gigabytes_per_second = data_size_in_gigabytes / ref_time_in_second

    print(
        f"Benchmark {op.__name__}: {M=:5d}, {N=:5d}, {K=:5d}, "
        f"MSLK time: {time_in_us:10.3f} us. Bandwidth: {gigabytes_per_second:10.3f} GB/s, "
        f"Torch time: {ref_time_in_us:10.3f} us. Bandwidth: {ref_gigabytes_per_second:10.3f} GB/s"
    )
107
+
108
+
109
# Pre-bind the op under test as the first positional argument; the resulting
# callables take (M, N, K). If an op handle is None, main() never calls its
# partial because should_bench_kernel rejects None.
bench_scatter_add_along_first_dim = functools.partial(
    bench_scatter_add_along_first_dim_,
    scatter_add_along_first_dim,
)
bench_scatter_add_dense_tokens = functools.partial(
    bench_scatter_add_along_first_dim_,
    scatter_add_dense_tokens,
)
116
+
117
+
118
def bench_gather_scale_dense_tokens(E: int, T: int, D: int, quantize: bool):
    """Compare gather_scale[_quant]_dense_tokens with an eager torch reference.

    Inputs:
      - ``T`` bf16 tokens of width ``D``;
      - a random expert id in ``[0, E)`` per token;
      - a random token permutation;
      - an ``(E, T)`` score matrix.

    Prints time (us) and effective bandwidth for the MSLK and torch paths.
    """
    tokens = torch.randn((T, D), dtype=torch.bfloat16, device=_ACCELERATOR_TAG).abs()
    expert_ids = torch.randint(0, E, (T,), device=_ACCELERATOR_TAG)
    perm = torch.randperm(T, device=_ACCELERATOR_TAG)
    scores = torch.rand((E, T), dtype=torch.bfloat16, device=_ACCELERATOR_TAG)

    def torch_fn():
        # Permute the tokens and the score columns.
        permuted_tokens = torch.index_select(tokens, dim=0, index=perm)
        permuted_scores = torch.index_select(scores, dim=1, index=perm)
        # Pick one score per column (row given by expert_ids), then scale the
        # permuted token rows by it.
        picked = torch.gather(permuted_scores, dim=0, index=expert_ids.view(1, T))
        return permuted_tokens * picked.view(-1, 1)

    torch_fn()

    # The MSLK kernels are given the scores transposed to (T, E).
    scores_TE = scores.transpose(0, 1).contiguous()
    kernel = gather_scale_quant_dense_tokens if quantize else gather_scale_dense_tokens

    def triton_fn():
        return kernel(tokens, perm, expert_ids, scores_TE)

    triton_fn()

    # Bytes moved per element: 2 read (bf16) plus 1 written when quantizing
    # (presumably an 8-bit output — confirm), or 2 written otherwise.
    bytes_per_element = 3 if quantize else 4
    data_size_in_gigabytes = T * D * bytes_per_element / 1e9

    mslk_time = do_bench(triton_fn, rep=1000) * 1e3
    mslk_bw = data_size_in_gigabytes / (mslk_time / 1e6)

    torch_time = do_bench(torch_fn, rep=1000) * 1e3
    torch_bw = data_size_in_gigabytes / (torch_time / 1e6)
    print(
        f"Benchmark gather_scale_dense_tokens({quantize=}), {E=:3d}, {T=:5d}, {D=:5d}, "
        f"MSLK time: {mslk_time:10.3f} us. Bandwidth: {mslk_bw:10.3f} GB/s, "
        f"Torch time: {torch_time:10.3f} us. Bandwidth: {torch_bw:10.3f} GB/s"
    )
161
+
162
+
163
def bench_topk_index_shuffling(T: int, E: int, K: int) -> None:
    """Time MSLK ``index_shuffling`` against a torch top-k/sort/count reference.

    Both paths iterate over the same list of ``(T, E)`` bf16 score buffers.
    The reported times are microseconds per buffer.
    """
    torch.manual_seed(0)

    # Number of distinct score buffers: enough to total ~1 GiB (bf16 = 2 bytes),
    # clamped to [2, 1000]; presumably so repeated runs don't reuse cached inputs.
    num_rotating_buffers = min(max(2, triton.cdiv(1024 * 1024 * 1024, T * E * 2)), 1000)
    scores_list: list[torch.Tensor] = [
        torch.randn(T, E, device=_ACCELERATOR_TAG, dtype=torch.bfloat16)
        for _ in range(num_rotating_buffers)
    ]

    def fn() -> None:
        for scores in scores_list:
            index_shuffling(scores, top_k=K)

    def ref_fn() -> None:
        for scores in scores_list:
            # Top-K expert ids per token, flattened and stably sorted.
            _, selected_expert_indices = torch.topk(scores, K, dim=1)
            expert_indices, _ = torch.sort(
                selected_expert_indices.flatten(), dim=0, stable=True
            )
            # Per-expert counts via a broadcasted equality compare.
            _ = (
                expert_indices[:, None]
                == torch.arange(E, device=expert_indices.device)[None, :]
            ).sum(dim=0)

    # do_bench_cudagraph returns ms for one call covering all buffers;
    # convert to us per buffer.
    mslk_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
    torch_time = do_bench_cudagraph(ref_fn) * 1e3 / num_rotating_buffers
    print(
        f"Benchmark index_shuffling, num_tokens={T:4}, num_experts={E:4}, top_k={K:4}, "
        f"mslk_time={mslk_time:7.3f}us, torch_time={torch_time:7.3f}us"
    )
193
+
194
+
195
def bench_combine_or_split_shuffling(
    T: int,
    D: int,
    E: int,
    EP: int,
    is_padded: bool,
    is_balanced: bool,
    is_combine_shuffling: bool,
):
    """Benchmark MSLK ``combine_shuffling`` or ``split_shuffling`` and print the result.

    Args:
        T: Tokens per rank.
        D: Hidden dimension of each token.
        E: Total number of experts; must be divisible by ``EP``.
        EP: Expert-parallel degree (number of ranks).
        is_padded: True models the allgather ("graph") layout. There are
            ``EP * T`` input tokens over all ``E`` experts, and the benchmarked
            expert range is ``[1, 1 + E // EP)``. False models the all2all
            ("eager") layout: ``T`` tokens over ``E // EP`` experts, with range
            ``[0, E // EP)``.
        is_balanced: When False, each rank's first expert in the range is
            emptied and its tokens are added to the last expert in the range.
        is_combine_shuffling: Run ``combine_shuffling`` if True, else
            ``split_shuffling``.

    Returns:
        None. Configurations whose input tokens cannot be split evenly over
        all ``EP * input_num_experts`` (rank, expert) slots are skipped
        without printing.
    """
    torch.manual_seed(0)

    assert E % EP == 0
    if is_padded:
        # graph. allgather
        input_num_tokens: int = EP * T
        input_num_experts: int = E
        output_num_experts: int = E // EP
        start_expert_index: int = 1
        end_expert_index: int = 1 + output_num_experts
    else:
        # eager. all2all
        input_num_tokens = T
        input_num_experts = E // EP
        output_num_experts = E // EP
        start_expert_index = 0
        end_expert_index = output_num_experts

    # Every (rank, expert) slot gets the same token count, so the tokens must
    # divide evenly or the sum check below fails. (The previous chained
    # comparison `a < b != 0` only caught a < b, not non-divisible counts.)
    num_slots = EP * input_num_experts
    if num_slots == 0 or input_num_tokens % num_slots != 0:
        return

    tokens = torch.randn(
        input_num_tokens, D, device=_ACCELERATOR_TAG, dtype=torch.bfloat16
    )

    input_num_tokens_per_expert: int = input_num_tokens // num_slots
    token_counts: torch.Tensor = (
        torch.ones(
            [EP, input_num_experts],
            dtype=torch.int32,
            device=_ACCELERATOR_TAG,
        )
        * input_num_tokens_per_expert
    )
    if not is_balanced:
        for i in range(EP):
            token_counts[i, start_expert_index] -= input_num_tokens_per_expert
            token_counts[i, end_expert_index - 1] += input_num_tokens_per_expert

    assert token_counts.sum().item() == input_num_tokens

    # Enough independent copies to total ~1 GiB of bf16 tokens.
    num_rotating_buffers = triton.cdiv(1024 * 1024 * 1024, tokens.numel() * 2)
    token_list: list[torch.Tensor] = [
        tokens.clone() for _ in range(num_rotating_buffers)
    ]
    token_count_list: list[torch.Tensor] = [
        token_counts.clone() for _ in range(num_rotating_buffers)
    ]

    shuffle_op = combine_shuffling if is_combine_shuffling else split_shuffling

    def fn() -> None:
        for buf_tokens, buf_counts in zip(token_list, token_count_list):
            shuffle_op(
                buf_tokens,
                buf_counts,
                expert_start=start_expert_index,
                expert_end=end_expert_index,
                is_balanced=is_balanced,
            )

    fn()

    # Tokens routed to experts in [start_expert_index, end_expert_index),
    # summed over all ranks.
    output_num_tokens = sum(
        sum(per_rank_counts[start_expert_index:end_expert_index])
        for per_rank_counts in token_counts.tolist()
    )

    # Each output token row is read and written once, bf16 = 2 bytes.
    mem_bytes = output_num_tokens * D * 2 * 2
    mslk_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
    mslk_bw = mem_bytes * 1e-9 / (mslk_time * 1e-6)

    print(
        f"Benchmark {'combine_shuffling' if is_combine_shuffling else 'split_shuffling'}, "
        f"num_tokens={T:4}, dim={D:4}, num_experts={E:4}, expert_parallelism={EP:4}, output_num_tokens={output_num_tokens:4}, "
        f"{is_balanced=}, {is_padded=}, "
        f"mslk_time={mslk_time:7.3f}us, mslk_bw={mslk_bw:8.3f}GBytes/s."
    )
290
+
291
+
292
@click.command()
@click.option(
    "--kernels",
    default=None,
    help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
)
def main(kernels: Optional[str]) -> None:
    """Run the gather/scatter and shuffling benchmark sweeps.

    ``--kernels`` filters by each op's ``__name__``. Whitespace around the
    comma-separated names is ignored. Ops that are unavailable in this build
    (handle is None) are always skipped.
    """
    # Parse into a separate set instead of rebinding the Optional[str] param.
    selected: Optional[set[str]] = (
        None if kernels is None else {name.strip() for name in kernels.split(",")}
    )

    def should_bench_kernel(fn) -> bool:
        return (fn is not None) and (selected is None or fn.__name__ in selected)

    Es = [16, 128]
    Ts = [1, 128, 2048, 4096, 8192, 16384]
    Ds = [5120]

    # Gather/Scatter
    if should_bench_kernel(gather_scale_dense_tokens):
        for E, T, D in itertools.product(Es, Ts, Ds):
            bench_gather_scale_dense_tokens(E, T, D, quantize=False)

    if should_bench_kernel(gather_scale_quant_dense_tokens):
        for E, T, D in itertools.product(Es, Ts, Ds):
            bench_gather_scale_dense_tokens(E, T, D, quantize=True)

    if should_bench_kernel(gather_along_first_dim):
        for T, D in itertools.product(Ts, Ds):
            bench_gather_along_first_dim(T, T, D)

    if should_bench_kernel(scatter_add_along_first_dim):
        for T, D in itertools.product(Ts, Ds):
            bench_scatter_add_along_first_dim(T, T, D)

    if should_bench_kernel(scatter_add_dense_tokens):
        for T, D in itertools.product(Ts, Ds):
            bench_scatter_add_dense_tokens(T, T, D)

    # Shuffling: index_shuffling reuses the Ts sweep above with more experts.
    Ks = [1, 2, 4]
    Es = [16, 32, 128, 320]
    if should_bench_kernel(index_shuffling):
        for T, E, K in itertools.product(Ts, Es, Ks):
            bench_topk_index_shuffling(T, E, K)

    EPs = [2, 16]
    Ts = [32, 128, 2048, 4096, 8192, 16384]
    padded = [True, False]
    balanced = [True, False]

    if should_bench_kernel(combine_shuffling):
        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
            bench_combine_or_split_shuffling(
                T, D, E, EP, p, b, is_combine_shuffling=True
            )

    if should_bench_kernel(split_shuffling):
        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
            bench_combine_or_split_shuffling(
                T, D, E, EP, p, b, is_combine_shuffling=False
            )
353
+
354
+
355
if __name__ == "__main__":
    # Script entry point: click parses sys.argv (e.g. --kernels) and runs main.
    main()