fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic; see the package's registry page for more details.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,257 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import argparse
9
+
10
+ import os
11
+ import tempfile
12
+ import uuid
13
+ from functools import lru_cache
14
+ from pprint import pprint
15
+
16
+ import fbgemm_gpu.experimental.gen_ai # noqa: F401
17
+ import pandas as pd
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ import torch.distributed._symmetric_memory as symm_mem
22
+ from torch.distributed.launcher.api import elastic_launch, LaunchConfig
23
+
24
+
25
@lru_cache(maxsize=None)
def get_symm_buffer(group):
    """Return the cached symmetric-memory scratch buffer for ``group``.

    On the first call for a given ``group`` this allocates 16 Mi bfloat16
    elements with ``symm_mem.empty`` on the ``"cuda"`` device and rendezvouses
    the allocation across ``group``.  ``lru_cache`` makes every later call with
    the same group return that same buffer instead of allocating again.

    Returns:
        A ``(buffer, group_name)`` tuple.
    """
    capacity = 16 * 1024 * 1024  # number of bfloat16 elements
    scratch = symm_mem.empty(capacity, device="cuda", dtype=torch.bfloat16)
    symm_mem.rendezvous(scratch, group=group)
    return scratch, group.group_name
32
+
33
+
34
def _setup(path: str) -> tuple[int, int]:
    """Initialize per-rank communication state for the benchmark.

    Binds this process to ``cuda:{LOCAL_RANK}``, initializes FBGEMM's NCCL
    communicator and a torch.distributed process group (gloo for CPU tensors,
    NCCL for CUDA tensors) through file-based rendezvous under ``path``,
    exchanges IPC handles for the FBGEMM custom-allreduce (CAR) buffer and
    barrier tensors, and eagerly creates the cached symmetric-memory buffer.

    Args:
        path: Directory visible to all ranks; rendezvous files are created in it.

    Returns:
        ``(rank, world_size)``. ``rank`` is read from ``LOCAL_RANK`` and also
        used as the global rank, which is only correct for a single-node launch.
    """
    rank = int(os.environ["LOCAL_RANK"])
    W = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    # NOTE(review): presumably disables NCCL async error handling so a slow or
    # stuck collective does not abort the benchmark — confirm intent.
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"

    # FBGEMM's own NCCL communicator (used by the fbgemm nccl_* ops).
    torch.ops.fbgemm.nccl_init(rank, W, os.path.join(path, "rdvz"))
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=f"file://{os.path.join(path, 'gloo_rdvz')}",
        world_size=W,
        rank=rank,
    )

    # CAR data buffer and barrier; the barrier must start zeroed.
    buffer = torch.ops.fbgemm.car_tensor()
    barrier = torch.ops.fbgemm.car_tensor()
    barrier.zero_()

    # Gather every rank's IPC handles so each rank can presumably map its
    # peers' buffer and barrier tensors.
    buffer_handle = torch.ops.fbgemm.car_ipc_handle(buffer)
    all_buffer_handles = [torch.empty_like(buffer_handle) for _ in range(W)]
    torch.distributed.all_gather(all_buffer_handles, buffer_handle)

    barrier_handle = torch.ops.fbgemm.car_ipc_handle(barrier)
    all_barrier_handles = [torch.empty_like(barrier_handle) for _ in range(W)]
    torch.distributed.all_gather(all_barrier_handles, barrier_handle)
    torch.ops.fbgemm.car_init(
        rank, W, barrier, all_barrier_handles, buffer, all_buffer_handles
    )
    # Make sure CAR init has completed on every rank before benchmarking.
    torch.cuda.synchronize()
    torch.distributed.barrier()
    # Populate get_symm_buffer's lru_cache now, so the symm_* benchmark
    # functions only ever hit the cached buffer.
    group = dist.group.WORLD
    _ = get_symm_buffer(group)
    return rank, W
68
+
69
+
70
def symm_one_shot_allreduce(dst_tensor, src_tensor, bias=None, comm_idx=None):
    """Sum-all-reduce ``src_tensor`` into ``dst_tensor`` with the one-shot
    symmetric-memory kernel, then add ``bias`` in place if one is given.

    ``comm_idx`` is unused here.
    """
    # The WORLD buffer is expected to already be cached (created in _setup);
    # the group passed here must match the one used at that point.
    full_buffer, name = get_symm_buffer(dist.group.WORLD)
    staging = full_buffer[: src_tensor.numel()].view_as(src_tensor)
    torch.ops.symm_mem.one_shot_all_reduce_copy_out(
        staging, src_tensor, "sum", name, dst_tensor
    )
    if bias is None:
        return
    dst_tensor.add_(bias)
80
+
81
+
82
def symm_two_shot_allreduce(dst_tensor, src_tensor, bias=None, comm_idx=None):
    """Sum-all-reduce ``src_tensor`` into ``dst_tensor`` with the two-shot
    symmetric-memory kernel, then add ``bias`` in place if one is given.

    ``comm_idx`` is unused here.
    """
    # Cached WORLD buffer (created in _setup); the group must match.
    full_buffer, name = get_symm_buffer(dist.group.WORLD)
    staging = full_buffer[: src_tensor.numel()].view_as(src_tensor)
    # Explicitly stage the input, mirroring the copy the CAR path performs.
    staging.copy_(src_tensor)
    torch.ops.symm_mem.two_shot_all_reduce_out(staging, "sum", name, dst_tensor)
    if bias is None:
        return
    dst_tensor.add_(bias)
94
+
95
+
96
def symm_reduce_scatter(dst_tensor, src_tensor, comm_idx=None):
    """Reduce-scatter ``src_tensor`` into ``dst_tensor`` through the cached
    symmetric-memory buffer.  ``comm_idx`` is unused here.
    """
    full_buffer, name = get_symm_buffer(dist.group.WORLD)
    count = src_tensor.numel()
    staging = full_buffer[:count].view_as(src_tensor)
    staging.copy_(src_tensor)
    # NOTE(review): the meaning of the positional ``False`` flag should be
    # confirmed against the symm_mem reduce_scatter_out schema.
    torch.ops.symm_mem.reduce_scatter_out(staging, name, False, dst_tensor)
101
+
102
+
103
def run_one_algo(fn, out, inp, num_iters, num_warmup_iters):
    """Return the mean time, in milliseconds, of one ``fn(out, inp)`` call.

    ``num_warmup_iters`` untimed calls are made first; then ``num_iters``
    calls are issued back-to-back between two timing CUDA events, and the
    elapsed event time is divided by ``num_iters``.
    """
    for _ in range(num_warmup_iters):
        fn(out, inp)

    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)
    begin.record()
    for _ in range(num_iters):
        fn(out, inp)
    finish.record()
    # elapsed_time requires both events to have completed.
    torch.cuda.synchronize()
    return begin.elapsed_time(finish) / num_iters
115
+
116
+
117
def _round_up(a: int, b: int) -> int:
    """Round ``a`` up to the nearest multiple of ``b``."""
    return ((a + b - 1) // b) * b


def _time_algos(results, algos, out, inp, args):
    """Time each ``(label, fn)`` in ``algos`` and record its stats in ``results``.

    For every pair, in order, adds ``{label}_time`` (mean ms per call, from
    ``run_one_algo``) and ``{label}_bwidth`` (input bytes / time, in GB/s).
    """
    input_bytes = inp.numel() * inp.element_size()
    for label, fn in algos:
        time_ms = run_one_algo(fn, out, inp, args.num_iters, args.num_warmup_iters)
        results[f"{label}_time"] = time_ms
        results[f"{label}_bwidth"] = input_bytes / (time_ms * 1e-3) / 1e9


def run_benchmark(args, path):
    """Per-rank benchmark body launched by ``elastic_launch``.

    Sets up communication with ``_setup``. It then sweeps ``args.size_steps``
    sizes, log-spaced base 2 from ``2**args.min_size`` to ``2**args.max_size``.
    Each size is rounded up to a multiple of 256 elements (512 on ROCm) and
    every implementation of ``args.op`` ("allreduce" or reduce-scatter) is
    timed on a random bfloat16 input.

    Rank 0 prints the per-size results. If ``args.export_csv`` is set, it also
    writes them to ``<args.output_dir>/comm_ops_benchmark.csv``.
    """
    rank, W = _setup(path)
    if rank == 0:
        print(f"Running benchmark with {W} ranks")

    # NOTE(review): the alignment is presumably chosen so every kernel can split
    # the input evenly across ranks / vector lanes — confirm.
    N_even_divisor = 8 * 64 if torch.version.hip else 8 * 32
    sizes = torch.logspace(
        args.min_size, args.max_size, steps=args.size_steps, base=2
    ).tolist()

    benchmark_results = []
    for size in sizes:
        N = _round_up(int(size), N_even_divisor)
        inp = torch.rand(N, dtype=torch.bfloat16, device="cuda")
        results = {"N": N}
        if args.op == "allreduce":
            out = torch.full_like(inp, -1)
            algos = (
                ("fbgemm_1shot", torch.ops.fbgemm.one_shot_car_allreduce),
                ("symm_1shot", symm_one_shot_allreduce),
                ("fbgemm_2shot", torch.ops.fbgemm.two_shot_car_allreduce),
                ("symm_2shot", symm_two_shot_allreduce),
                ("nccl", torch.ops.fbgemm.nccl_allreduce),
            )
        else:
            # Reduce-scatter output holds one N // W shard per rank.
            out = torch.full(
                (inp.shape[0] // W,), -1, dtype=inp.dtype, device=inp.device
            )
            algos = (
                ("fbgemm_rs", torch.ops.fbgemm.car_reducescatter),
                ("symm_rs", symm_reduce_scatter),
                ("nccl_rs", torch.ops.fbgemm.nccl_reducescatter),
            )
        _time_algos(results, algos, out, inp, args)
        benchmark_results.append(results)

    if rank == 0:
        pprint(benchmark_results)
        if args.export_csv:
            csv_file = os.path.join(args.output_dir, "comm_ops_benchmark.csv")
            pd.DataFrame(benchmark_results).to_csv(csv_file, index=False)
195
+
196
+
197
def main(args, path):
    """Launch ``args.num_ranks`` single-node worker processes running
    ``run_benchmark(args, path)``.

    If ``args.export_csv`` is set, ``args.output_dir`` is created up front.
    """
    if args.export_csv:
        os.makedirs(args.output_dir, exist_ok=True)
        print("csv and images will be saved to " + args.output_dir)

    launch_config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=args.num_ranks,
        run_id=str(uuid.uuid4()),
        rdzv_backend="c10d",
        # NOTE(review): port 0 presumably lets the c10d rendezvous pick a free port.
        rdzv_endpoint="localhost:0",
        max_restarts=0,
        monitor_interval=1,
    )
    launcher = elastic_launch(launch_config, entrypoint=run_benchmark)
    launcher(args, path)
213
+
214
+
215
def invoke_main():
    """Parse CLI options and run ``main`` with a temporary rendezvous directory
    that is deleted once the benchmark finishes."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--output_dir", default="/tmp", help="Directory to save plots and csvs to")
    add("--export_csv", action="store_true", help="Export results to a CSV file.")
    add("--num_ranks", type=int, default=8)
    add("--num_iters", type=int, default=20)
    add("--num_warmup_iters", type=int, default=10)
    add(
        "--min_size",
        type=int,
        default=10,
        help="minimum size will be set to 2**min_size",
    )
    add(
        "--max_size",
        type=int,
        default=24,
        help="maximum size will be set to 2**max_size",
    )
    add("--size_steps", type=int, default=20, help="number of size steps to run")
    add(
        "--op",
        type=str,
        default="allreduce",
        choices=["allreduce", "reduce_scatter"],
        help="op to benchmark, allreduce or reduce_scatter",
    )
    cli_args = parser.parse_args()

    with tempfile.TemporaryDirectory() as rdzv_dir:
        main(cli_args, rdzv_dir)
254
+
255
+
256
# Script entry point: parse CLI options and launch the multi-rank benchmark.
if __name__ == "__main__":
    invoke_main()
@@ -0,0 +1,348 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import functools
8
+ import itertools
9
+ from typing import Optional
10
+
11
+ import click
12
+ import torch
13
+ import triton # noqa: F401
14
+ from fbgemm_gpu.experimental.gen_ai.moe import (
15
+ combine_shuffling,
16
+ gather_along_first_dim,
17
+ gather_scale_dense_tokens,
18
+ gather_scale_quant_dense_tokens,
19
+ index_shuffling,
20
+ scatter_add_along_first_dim,
21
+ scatter_add_dense_tokens,
22
+ split_shuffling,
23
+ )
24
+ from triton.testing import do_bench, do_bench_cudagraph
25
+
26
# Device on which every benchmark tensor in this module is allocated.
# NOTE(review): torch.accelerator.current_accelerator() can return None when no
# accelerator is available, in which case tensors would silently be created on
# the CPU — confirm this script is only run on GPU hosts.
_ACCELERATOR_TAG = torch.accelerator.current_accelerator()
27
+
28
+
29
def bench_gather_along_first_dim(M: int, N: int, K: int) -> None:
    """Compare FBGEMM ``gather_along_first_dim`` with ``torch.index_select``.

    Gathers ``N`` rows from an ``[M, K]`` bfloat16 source. When ``M == N`` the
    int32 indices are a random permutation; otherwise they are drawn uniformly
    from ``[0, M)``. Prints latency (us) and effective bandwidth (GB/s) for
    both implementations.
    """
    src = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
    indices = (
        torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int32)
        if M == N
        else torch.randint(0, M, [N], device=_ACCELERATOR_TAG, dtype=torch.int32)
    )

    # Bytes moved: load N*K bf16 source values and store N*K outputs (2 B each).
    gigabytes = N * K * 2 * 2 / 1e9

    def _measure(bench_fn):
        us = triton.testing.do_bench(bench_fn) * 1e3
        return us, gigabytes / (us / 1e6)

    fbgemm_us, fbgemm_gbps = _measure(lambda: gather_along_first_dim(src, indices))
    torch_us, torch_gbps = _measure(lambda: torch.index_select(src, 0, indices))

    print(
        f"Benchmark gather_along_first_dim: {M=:5d}, {N=:5d}, {K=:5d}, "
        f"FBGEMM time: {fbgemm_us:10.3f} us. Bandwidth: {fbgemm_gbps:10.3f} GB/s, "
        f"Torch time: {torch_us:10.3f} us. Bandwidth: {torch_gbps:10.3f} GB/s"
    )
58
+
59
+
60
def bench_scatter_add_along_first_dim_(op, M: int, N: int, K: int) -> None:
    """Compare an FBGEMM scatter-add ``op`` with ``Tensor.scatter_add_``.

    Scatters ``M`` rows of an ``[M, K]`` bfloat16 source into an ``[N, K]``
    destination. When ``M == N`` the int64 row indices are a permutation;
    otherwise they are drawn uniformly from ``[0, N)``. Each implementation
    accumulates into its own clone of the destination. Prints latency (us)
    and effective bandwidth (GB/s).
    """
    src = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
    dst = torch.randn([N, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
    if M == N:
        row_index = torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int64)
    else:
        row_index = torch.randint(
            0, N, [M], device=_ACCELERATOR_TAG, dtype=torch.int64
        )
    # scatter_add_ takes an index of src's [M, K] shape, so broadcast rows.
    row_index_2d = row_index.unsqueeze(1).expand(-1, K)

    fbgemm_dst = dst.clone()
    torch_dst = dst.clone()

    def fbgemm_fn():
        op(fbgemm_dst, src, row_index)

    def torch_fn():
        torch_dst.scatter_add_(0, row_index_2d, src)

    # Bytes moved: load src, load dst, store dst (2 B per bf16 value).
    gigabytes = N * K * 2 * 3 / 1e9

    timings = []
    for bench_fn in (fbgemm_fn, torch_fn):
        us = triton.testing.do_bench(bench_fn) * 1e3
        timings.append((us, gigabytes / (us / 1e6)))
    (fbgemm_us, fbgemm_gbps), (torch_us, torch_gbps) = timings

    print(
        f"Benchmark {op.__name__}: {M=:5d}, {N=:5d}, {K=:5d}, "
        f"FBGEMM time: {fbgemm_us:10.3f} us. Bandwidth: {fbgemm_gbps:10.3f} GB/s, "
        f"Torch time: {torch_us:10.3f} us. Bandwidth: {torch_gbps:10.3f} GB/s"
    )
97
+
98
+
99
# Benchmarks for the two FBGEMM scatter-add ops. The op is bound as the first
# positional argument, so callers pass (M, N, K) positionally.
bench_scatter_add_along_first_dim = functools.partial(
    bench_scatter_add_along_first_dim_, scatter_add_along_first_dim
)

bench_scatter_add_dense_tokens = functools.partial(
    bench_scatter_add_along_first_dim_, scatter_add_dense_tokens
)
106
+
107
+
108
def bench_gather_scale_dense_tokens(E: int, T: int, D: int, quantize: bool):
    """Benchmark FBGEMM's gather-and-scale of routed tokens against torch.

    Permutes the ``T`` rows of a ``[T, D]`` bfloat16 input by ``token_indices``.
    Each row is scaled by its selected expert's score (scores are ``[E, T]``).
    The FBGEMM side uses ``gather_scale_quant_dense_tokens`` when ``quantize``
    is set, else ``gather_scale_dense_tokens``. Prints latency (us) and
    bandwidth (GB/s) for both.
    """
    x = torch.randn((T, D), dtype=torch.bfloat16, device=_ACCELERATOR_TAG).abs()
    expert_indices = torch.randint(0, E, (T,), device=_ACCELERATOR_TAG)
    token_indices = torch.randperm(T, device=_ACCELERATOR_TAG)
    scores = torch.rand((E, T), dtype=torch.bfloat16, device=_ACCELERATOR_TAG)

    def torch_fn():
        gathered = torch.index_select(x, dim=0, index=token_indices)
        permuted_scores = torch.index_select(scores, dim=1, index=token_indices)
        picked = torch.gather(
            permuted_scores, dim=0, index=expert_indices.view(1, T)
        )
        return gathered * picked.view(-1, 1)

    torch_fn()  # warm-up

    # The FBGEMM kernels take the scores transposed, as a contiguous [T, E].
    scores_TE = scores.transpose(0, 1).contiguous()
    kernel = gather_scale_quant_dense_tokens if quantize else gather_scale_dense_tokens

    def triton_fn():
        return kernel(x, token_indices, expert_indices, scores_TE)

    triton_fn()  # warm-up

    # Bytes per element: 2 read, plus 2 written (or presumably 1 when quantized).
    bytes_per_elem = 3 if quantize else 4
    data_size_in_gigabytes = T * D * bytes_per_elem / 1e9

    fbgemm_time = do_bench(triton_fn, rep=1000) * 1e3
    fbgemm_bw = data_size_in_gigabytes / (fbgemm_time / 1e6)

    torch_time = do_bench(torch_fn, rep=1000) * 1e3
    torch_bw = data_size_in_gigabytes / (torch_time / 1e6)
    print(
        f"Benchmark gather_scale_dense_tokens({quantize=}), {E=:3d}, {T=:5d}, {D=:5d}, "
        f"FBGEMM time: {fbgemm_time:10.3f} us. Bandwidth: {fbgemm_bw:10.3f} GB/s, "
        f"Torch time: {torch_time:10.3f} us. Bandwidth: {torch_bw:10.3f} GB/s"
    )
153
+
154
+
155
def bench_topk_index_shuffling(T: int, E: int, K: int) -> None:
    """Benchmark FBGEMM ``index_shuffling`` (top-k routing) against torch.

    The torch reference does a top-k, a stable sort of the selected expert ids
    and a per-expert count. Both sides run over a list of rotating random
    ``[T, E]`` bfloat16 score tensors under CUDA-graph timing. The reported
    time is per score tensor, in microseconds.
    """
    torch.manual_seed(0)

    # Roughly 1 GiB of score tensors in total (each is T*E*2 bytes), clamped to
    # [2, 1000] buffers; presumably so inputs are not cache-resident.
    num_rotating_buffers = min(max(2, triton.cdiv(1024 * 1024 * 1024, T * E * 2)), 1000)
    scores_list: list[torch.Tensor] = [
        torch.randn(T, E, device=_ACCELERATOR_TAG, dtype=torch.bfloat16)
        for _ in range(num_rotating_buffers)
    ]

    def fn() -> None:
        for scores in scores_list:
            index_shuffling(scores, top_k=K)

    def ref_fn() -> None:
        for scores in scores_list:
            _, selected_expert_indices = torch.topk(scores, K, dim=1)
            expert_indices, _ = torch.sort(
                selected_expert_indices.flatten(), dim=0, stable=True
            )
            # Per-expert token counts (result discarded; only timed).
            _ = (
                expert_indices[:, None]
                == torch.arange(E, device=expert_indices.device)[None, :]
            ).sum(dim=0)

    # do_bench_cudagraph reports ms for one fn() call covering all buffers.
    fbgemm_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
    torch_time = do_bench_cudagraph(ref_fn) * 1e3 / num_rotating_buffers
    print(
        f"Benchmark index_shuffling, num_tokens={T:4}, num_experts={E:4}, top_k={K:4}, "
        f"fbgemm_time={fbgemm_time:7.3f}us, torch_time={torch_time:7.3f}us"
    )
185
+
186
+
187
def bench_combine_or_split_shuffling(
    T: int,
    D: int,
    E: int,
    EP: int,
    is_padded: bool,
    is_balanced: bool,
    is_combine_shuffling: bool,
):
    """Benchmark FBGEMM ``combine_shuffling`` or ``split_shuffling``.

    Args:
        T: Tokens per rank.
        D: Hidden dimension of each bfloat16 token.
        E: Total number of experts; must be divisible by ``EP``.
        EP: Expert-parallel degree.
        is_padded: If True, use the allgather layout: ``EP * T`` input tokens,
            all ``E`` experts, expert range ``[1, 1 + E // EP)``. Otherwise use
            the all2all layout: ``T`` tokens, ``E // EP`` experts, expert range
            ``[0, E // EP)``.
        is_balanced: If False, each rank's tokens for the first expert in the
            range are moved to the last expert in the range.
        is_combine_shuffling: Benchmark ``combine_shuffling`` if True, else
            ``split_shuffling``.

    Returns None. Configurations whose token count cannot be split evenly over
    the ``EP * input_num_experts`` (rank, expert) buckets are skipped before
    any tensor is allocated.
    """
    torch.manual_seed(0)

    assert E % EP == 0
    if is_padded:
        # graph. allgather
        input_num_tokens: int = EP * T
        input_num_experts: int = E
        output_num_experts: int = E // EP
        start_expert_index: int = 1
        end_expert_index: int = 1 + output_num_experts
    else:
        # eager. all2all
        input_num_tokens: int = T
        input_num_experts: int = E // EP
        output_num_experts: int = E // EP
        start_expert_index: int = 0
        end_expert_index: int = output_num_experts

    # The token counts below are an even split, so the total must be a multiple
    # of the bucket count; this also rejects configs with fewer tokens than
    # buckets. (A chained comparison `a < b != 0` would not check this and let
    # e.g. T=2048, E=320, EP=2 reach the sum assertion below and fail.)
    num_buckets = EP * input_num_experts
    if input_num_tokens % num_buckets != 0:
        return

    tokens = torch.randn(
        input_num_tokens, D, device=_ACCELERATOR_TAG, dtype=torch.bfloat16
    )

    input_num_tokens_per_expert: int = input_num_tokens // num_buckets
    token_counts: torch.Tensor = (
        torch.ones(
            [EP, input_num_experts],
            dtype=torch.int32,
            device=_ACCELERATOR_TAG,
        )
        * input_num_tokens_per_expert
    )
    if not is_balanced:
        for i in range(EP):
            token_counts[i, start_expert_index] -= input_num_tokens_per_expert
            token_counts[i, end_expert_index - 1] += input_num_tokens_per_expert

    assert token_counts.sum().item() == input_num_tokens

    # Enough clones to cover roughly 1 GiB of token data (2 bytes per element).
    num_rotating_buffers = triton.cdiv(1024 * 1024 * 1024, tokens.numel() * 2)
    token_list: list[torch.Tensor] = [
        tokens.clone() for _ in range(num_rotating_buffers)
    ]
    token_count_list: list[torch.Tensor] = [
        token_counts.clone() for _ in range(num_rotating_buffers)
    ]

    shuffle_op = combine_shuffling if is_combine_shuffling else split_shuffling

    def fn() -> None:
        for tokens, token_counts in zip(token_list, token_count_list):
            shuffle_op(
                tokens,
                token_counts,
                expert_start=start_expert_index,
                expert_end=end_expert_index,
                is_balanced=is_balanced,
            )

    fn()  # warm-up

    # Tokens routed to experts in [start_expert_index, end_expert_index).
    output_num_tokens = sum(
        sum(per_rank_counts[start_expert_index:end_expert_index])
        for per_rank_counts in token_counts.tolist()
    )

    # Each output token is read once and written once (2 bytes per bf16 value).
    mem_bytes = output_num_tokens * D * 2 * 2
    fbgemm_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
    fbgemm_bw = mem_bytes * 1e-9 / (fbgemm_time * 1e-6)

    print(
        f"Benchmark {'combine_shuffling' if is_combine_shuffling else 'split_shuffling'}, "
        f"num_tokens={T:4}, dim={D:4}, num_experts={E:4}, expert_parallelism={EP:4}, output_num_tokens={output_num_tokens:4}, "
        f"{is_balanced=}, {is_padded=}, "
        f"fbgemm_time={fbgemm_time:7.3f}us, fbgemm_bw={fbgemm_bw:8.3f}GBytes/s."
    )
282
+
283
+
284
@click.command()
@click.option(
    "--kernels",
    default=None,
    help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
)
def main(kernels: Optional[str]):
    """Run the gather/scatter and shuffling benchmarks.

    Args:
        kernels: Optional comma-separated kernel names, matched against the
            FBGEMM op's ``__name__``. Whitespace around names and empty entries
            are ignored. ``None`` benchmarks every kernel.
    """
    selected: Optional[set[str]] = None
    if kernels is not None:
        # Strip so "a, b" selects both kernels instead of silently missing " b".
        selected = {name.strip() for name in kernels.split(",") if name.strip()}

    def should_bench_kernel(fn):
        return (fn is not None) and (selected is None or fn.__name__ in selected)

    Es = [16, 128]
    Ts = [1, 128, 2048, 4096, 8192, 16384]
    Ds = [5120]

    # Gather/Scatter
    if should_bench_kernel(gather_scale_dense_tokens):
        for E, T, D in itertools.product(Es, Ts, Ds):
            bench_gather_scale_dense_tokens(E, T, D, quantize=False)

    if should_bench_kernel(gather_scale_quant_dense_tokens):
        for E, T, D in itertools.product(Es, Ts, Ds):
            bench_gather_scale_dense_tokens(E, T, D, quantize=True)

    if should_bench_kernel(gather_along_first_dim):
        for T, D in itertools.product(Ts, Ds):
            bench_gather_along_first_dim(T, T, D)

    if should_bench_kernel(scatter_add_along_first_dim):
        for T, D in itertools.product(Ts, Ds):
            bench_scatter_add_along_first_dim(T, T, D)

    if should_bench_kernel(scatter_add_dense_tokens):
        for T, D in itertools.product(Ts, Ds):
            bench_scatter_add_dense_tokens(T, T, D)

    # Shuffling: wider expert counts and a top-k sweep.
    Ks = [1, 2, 4]
    Es = [16, 32, 128, 320]
    if should_bench_kernel(index_shuffling):
        for T, E, K in itertools.product(Ts, Es, Ks):
            bench_topk_index_shuffling(T, E, K)

    EPs = [2, 16]
    Ts = [32, 128, 2048, 4096, 8192, 16384]
    padded = [True, False]
    balanced = [True, False]

    if should_bench_kernel(combine_shuffling):
        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
            bench_combine_or_split_shuffling(
                T, D, E, EP, p, b, is_combine_shuffling=True
            )

    if should_bench_kernel(split_shuffling):
        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
            bench_combine_or_split_shuffling(
                T, D, E, EP, p, b, is_combine_shuffling=False
            )
345
+
346
+
347
# Script entry point: click parses --kernels and runs the selected benchmarks.
if __name__ == "__main__":
    main()