fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

import os
import tempfile
import uuid
from functools import lru_cache
from pprint import pprint

import fbgemm_gpu.experimental.gen_ai  # noqa: F401
import pandas as pd

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.distributed.launcher.api import elastic_launch, LaunchConfig


@lru_cache(None)
def get_symm_buffer(group):
    inp = symm_mem.empty(
        16 * 1024 * 1024, device="cuda", dtype=torch.bfloat16
    )  # .normal_()
    symm_mem.rendezvous(inp, group=group)
    return inp, group.group_name


def _setup(path: str) -> tuple[int, int]:
    rank = int(os.environ["LOCAL_RANK"])
    W = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"

    torch.ops.fbgemm.nccl_init(rank, W, os.path.join(path, "rdvz"))
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=f"file://{os.path.join(path, 'gloo_rdvz')}",
        world_size=W,
        rank=rank,
    )

    buffer = torch.ops.fbgemm.car_tensor()
    barrier = torch.ops.fbgemm.car_tensor()
    barrier.zero_()

    buffer_handle = torch.ops.fbgemm.car_ipc_handle(buffer)
    all_buffer_handles = [torch.empty_like(buffer_handle) for _ in range(W)]
    torch.distributed.all_gather(all_buffer_handles, buffer_handle)

    barrier_handle = torch.ops.fbgemm.car_ipc_handle(barrier)
    all_barrier_handles = [torch.empty_like(barrier_handle) for _ in range(W)]
    torch.distributed.all_gather(all_barrier_handles, barrier_handle)
    torch.ops.fbgemm.car_init(
        rank, W, barrier, all_barrier_handles, buffer, all_buffer_handles
    )
    torch.cuda.synchronize()
    torch.distributed.barrier()
    group = dist.group.WORLD
    _ = get_symm_buffer(group)
    return rank, W


def symm_one_shot_allreduce(dst_tensor, src_tensor, bias=None, comm_idx=None):
    # get_symm_buffer should be called for the first time during model init,
    # and now return cached values. Make sure group is the same as during init
    symm_buffer, group_name = get_symm_buffer(dist.group.WORLD)
    symm_buffer = symm_buffer[: src_tensor.numel()].view_as(src_tensor)
    torch.ops.symm_mem.one_shot_all_reduce_copy_out(
        symm_buffer, src_tensor, "sum", group_name, dst_tensor
    )
    if bias is not None:
        dst_tensor.add_(bias)


def symm_two_shot_allreduce(dst_tensor, src_tensor, bias=None, comm_idx=None):
    # get_symm_buffer should be called for the first time during model init,
    # and now return cached values. Make sure group is the same as during init
    symm_buffer, group_name = get_symm_buffer(dist.group.WORLD)
    # car is also doing explicit copy
    symm_buffer = symm_buffer[: src_tensor.numel()].view_as(src_tensor)
    symm_buffer.copy_(src_tensor)
    torch.ops.symm_mem.two_shot_all_reduce_out(
        symm_buffer, "sum", group_name, dst_tensor
    )
    if bias is not None:
        dst_tensor.add_(bias)


def symm_reduce_scatter(dst_tensor, src_tensor, comm_idx=None):
    symm_buffer, group_name = get_symm_buffer(dist.group.WORLD)
    symm_buffer = symm_buffer[: src_tensor.numel()].view_as(src_tensor)
    symm_buffer.copy_(src_tensor)
    torch.ops.symm_mem.reduce_scatter_out(symm_buffer, group_name, False, dst_tensor)


def run_one_algo(fn, out, inp, num_iters, num_warmup_iters):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    for _ in range(num_warmup_iters):
        fn(out, inp)
    start_event.record()
    for _ in range(num_iters):
        fn(out, inp)
    end_event.record()
    torch.cuda.synchronize()
    time = start_event.elapsed_time(end_event) / num_iters
    return time


def run_benchmark(args, path):
    rank, W = _setup(path)
    if rank == 0:
        print(f"Running benchmark with {W} ranks")
    # benchmark_results = defaultdict(defaultdict)
    benchmark_results = []
    # with torch.profiler.profile() as p:
    for N in torch.logspace(
        args.min_size, args.max_size, steps=args.size_steps, base=2
    ).tolist():

        def round_up(a: int, b: int) -> int:
            return ((a + b - 1) // b) * b

        N_even_divisor = 8 * 64 if torch.version.hip else 8 * 32
        N = round_up(int(N), N_even_divisor)
        inp = torch.rand(N, dtype=torch.bfloat16, device="cuda")
        results = {"N": N}
        if args.op == "allreduce":
            out = torch.full_like(inp, -1)
            fns = (
                torch.ops.fbgemm.one_shot_car_allreduce,
                symm_one_shot_allreduce,
                torch.ops.fbgemm.two_shot_car_allreduce,
                symm_two_shot_allreduce,
                torch.ops.fbgemm.nccl_allreduce,
            )
            labels = (
                "fbgemm_1shot",
                "symm_1shot",
                "fbgemm_2shot",
                "symm_2shot",
                "nccl",
            )
            for fn, label in zip(fns, labels):
                time = run_one_algo(
                    fn,
                    out,
                    inp,
                    args.num_iters,
                    args.num_warmup_iters,
                )
                results[f"{label}_time"] = time
                results[f"{label}_bwidth"] = (
                    N * inp.element_size() / (time * 1e-3) / 1e9
                )
        else:
            out = torch.full(
                (inp.shape[0] // W,), -1, dtype=inp.dtype, device=inp.device
            )
            fns = (
                torch.ops.fbgemm.car_reducescatter,
                symm_reduce_scatter,
                torch.ops.fbgemm.nccl_reducescatter,
            )
            labels = ("fbgemm_rs", "symm_rs", "nccl_rs")
            for fn, label in zip(fns, labels):
                time = run_one_algo(
                    fn,
                    out,
                    inp,
                    args.num_iters,
                    args.num_warmup_iters,
                )
                results[f"{label}_time"] = time
                results[f"{label}_bwidth"] = (
                    N * inp.element_size() / (time * 1e-3) / 1e9
                )

        benchmark_results.append(results)

    if rank == 0:
        pprint(benchmark_results)
        if args.export_csv:
            csv_file = os.path.join(args.output_dir, "comm_ops_benchmark.csv")
            # Export results to a CSV file.
            df = pd.DataFrame(benchmark_results)
            df.to_csv(csv_file, index=False)


def main(args, path):
    if args.export_csv:
        os.makedirs(args.output_dir, exist_ok=True)
        print("csv and images will be saved to " + args.output_dir)

    lc = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=args.num_ranks,
        run_id=str(uuid.uuid4()),
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:0",
        max_restarts=0,
        monitor_interval=1,
    )
    elastic_launch(lc, entrypoint=run_benchmark)(args, path)


def invoke_main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir", default="/tmp", help="Directory to save plots and csvs to"
    )
    parser.add_argument(
        "--export_csv",
        action="store_true",
        help="Export results to a CSV file.",
    )
    parser.add_argument("--num_ranks", type=int, default=8)
    parser.add_argument("--num_iters", type=int, default=20)
    parser.add_argument("--num_warmup_iters", type=int, default=10)
    parser.add_argument(
        "--min_size",
        type=int,
        default=10,
        help="minimum size will be set to 2**min_size",
    )
    parser.add_argument(
        "--max_size",
        type=int,
        default=24,
        help="maximum size will be set to 2**max_size",
    )
    parser.add_argument(
        "--size_steps", type=int, default=20, help="number of size steps to run"
    )
    parser.add_argument(
        "--op",
        type=str,
        default="allreduce",
        choices=["allreduce", "reduce_scatter"],
        help="op to benchmark, allreduce or reduce_scatter",
    )
    args = parser.parse_args()

    with tempfile.TemporaryDirectory() as path:
        main(args, path)


if __name__ == "__main__":
    invoke_main()
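
Note: comm_bench.py is a single-node, multi-GPU benchmark driver. invoke_main() parses the flags defined above, and main() uses torchelastic's elastic_launch to spawn --num_ranks local processes, each running run_benchmark. The recorded {label}_time values are per-iteration CUDA-event milliseconds, and {label}_bwidth is N * element_size / time expressed in GB/s, as computed in run_benchmark. A usage sketch (assumptions: the wheel is installed and the module is runnable via python -m; flag names come from the argparse definitions above):

    python -m fbgemm_gpu.experimental.gen_ai.bench.comm_bench --op allreduce --num_ranks 8 --export_csv --output_dir /tmp/comm_bench

With --export_csv, rank 0 writes comm_ops_benchmark.csv into --output_dir.
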
fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py
@@ -0,0 +1,348 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import itertools
from typing import Optional

import click
import torch
import triton  # noqa: F401
from fbgemm_gpu.experimental.gen_ai.moe import (
    combine_shuffling,
    gather_along_first_dim,
    gather_scale_dense_tokens,
    gather_scale_quant_dense_tokens,
    index_shuffling,
    scatter_add_along_first_dim,
    scatter_add_dense_tokens,
    split_shuffling,
)
from triton.testing import do_bench, do_bench_cudagraph

_ACCELERATOR_TAG = torch.accelerator.current_accelerator()


def bench_gather_along_first_dim(M: int, N: int, K: int) -> None:
    src = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
    if M == N:
        indices = torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int32)
    else:
        indices = torch.randint(0, M, [N], device=_ACCELERATOR_TAG, dtype=torch.int32)

    def fn():
        return gather_along_first_dim(src, indices)

    def ref_fn():
        return torch.index_select(src, 0, indices)

    # Load src, store dst. x2.
    data_size_in_gigabytes = N * K * 2 * 2 / 1e9

    time_in_us = triton.testing.do_bench(fn) * 1e3
    time_in_second = time_in_us / 1e6
    gigabytes_per_second = data_size_in_gigabytes / time_in_second

    ref_time_in_us = triton.testing.do_bench(ref_fn) * 1e3
    ref_time_in_second = ref_time_in_us / 1e6
    ref_gigabytes_per_second = data_size_in_gigabytes / ref_time_in_second

    print(
        f"Benchmark gather_along_first_dim: {M=:5d}, {N=:5d}, {K=:5d}, "
        f"FBGEMM time: {time_in_us:10.3f} us. Bandwidth: {gigabytes_per_second:10.3f} GB/s, "
        f"Torch time: {ref_time_in_us:10.3f} us. Bandwidth: {ref_gigabytes_per_second:10.3f} GB/s"
    )


def bench_scatter_add_along_first_dim_(op, M: int, N: int, K: int) -> None:
    src = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
    dst = torch.randn([N, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
    if M == N:
        indices_1d = torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int64)
    else:
        indices_1d = torch.randint(
            0, N, [M], device=_ACCELERATOR_TAG, dtype=torch.int64
        )

    indices_2d = indices_1d.to(torch.int64).unsqueeze(1).expand(-1, K)

    test_dst = dst.clone()
    ref_dst = dst.clone()

    def fn():
        op(test_dst, src, indices_1d)

    def ref_fn():
        ref_dst.scatter_add_(0, indices_2d, src)

    # Load src, load dst, store dst. x3.
    data_size_in_gigabytes = N * K * 2 * 3 / 1e9

    time_in_us = triton.testing.do_bench(fn) * 1e3
    time_in_second = time_in_us / 1e6
    gigabytes_per_second = data_size_in_gigabytes / time_in_second

    ref_time_in_us = triton.testing.do_bench(ref_fn) * 1e3
    ref_time_in_second = ref_time_in_us / 1e6
    ref_gigabytes_per_second = data_size_in_gigabytes / ref_time_in_second

    print(
        f"Benchmark {op.__name__}: {M=:5d}, {N=:5d}, {K=:5d}, "
        f"FBGEMM time: {time_in_us:10.3f} us. Bandwidth: {gigabytes_per_second:10.3f} GB/s, "
        f"Torch time: {ref_time_in_us:10.3f} us. Bandwidth: {ref_gigabytes_per_second:10.3f} GB/s"
    )


bench_scatter_add_along_first_dim = functools.partial(
    bench_scatter_add_along_first_dim_, scatter_add_along_first_dim
)

bench_scatter_add_dense_tokens = functools.partial(
    bench_scatter_add_along_first_dim_, scatter_add_dense_tokens
)


def bench_gather_scale_dense_tokens(E: int, T: int, D: int, quantize: bool):
    x = torch.randn((T, D), dtype=torch.bfloat16, device=_ACCELERATOR_TAG).abs()
    expert_indices = torch.randint(0, E, (T,), device=_ACCELERATOR_TAG)
    token_indices = torch.randperm(T, device=_ACCELERATOR_TAG)
    scores = torch.rand((E, T), dtype=torch.bfloat16, device=_ACCELERATOR_TAG)

    def torch_fn():
        shuffled_x = torch.index_select(x, dim=0, index=token_indices)
        shuffled_scores = torch.index_select(scores, dim=1, index=token_indices)
        shuffled_selected_scores = torch.gather(
            shuffled_scores, dim=0, index=expert_indices.view(1, T)
        )
        ref_output = shuffled_x * shuffled_selected_scores.view(-1, 1)
        return ref_output

    torch_fn()

    scores_TE = scores.transpose(0, 1).contiguous()

    fbgemm_fn = (
        gather_scale_quant_dense_tokens if quantize else gather_scale_dense_tokens
    )

    def triton_fn():
        test_output = fbgemm_fn(x, token_indices, expert_indices, scores_TE)
        return test_output

    triton_fn()

    # Run benchmark
    if quantize:
        data_size_in_gigabytes = T * D * 3 / 1e9
    else:
        data_size_in_gigabytes = T * D * 4 / 1e9

    fbgemm_time = do_bench(triton_fn, rep=1000) * 1e3
    fbgemm_bw = data_size_in_gigabytes / (fbgemm_time / 1e6)

    torch_time = do_bench(torch_fn, rep=1000) * 1e3
    torch_bw = data_size_in_gigabytes / (torch_time / 1e6)
    print(
        f"Benchmark gather_scale_dense_tokens({quantize=}), {E=:3d}, {T=:5d}, {D=:5d}, "
        f"FBGEMM time: {fbgemm_time:10.3f} us. Bandwidth: {fbgemm_bw:10.3f} GB/s, "
        f"Torch time: {torch_time:10.3f} us. Bandwidth: {torch_bw:10.3f} GB/s"
    )


def bench_topk_index_shuffling(T: int, E: int, K: int) -> None:
    torch.manual_seed(0)

    num_rotating_buffers = min(max(2, triton.cdiv(1024 * 1024 * 1024, T * E * 2)), 1000)
    scores_list: list[torch.Tensor] = [
        torch.randn(T, E, device=_ACCELERATOR_TAG, dtype=torch.bfloat16)
        for i in range(num_rotating_buffers)
    ]

    def fn() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        for scores in scores_list:
            index_shuffling(scores, top_k=K)

    def ref_fn() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        for scores in scores_list:
            _, selected_expert_indices = torch.topk(scores, K, dim=1)
            expert_indices, _ = torch.sort(
                selected_expert_indices.flatten(), dim=0, stable=True
            )
            _ = (
                expert_indices[:, None]
                == torch.arange(E, device=expert_indices.device)[None, :]
            ).sum(dim=0)

    fbgemm_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
    torch_time = do_bench_cudagraph(ref_fn) * 1e3 / num_rotating_buffers
    print(
        f"Benchmark index_shuffling, num_tokens={T:4}, num_experts={E:4}, top_k={K:4}, "
        f"fbgemm_time={fbgemm_time:7.3f}us, torch_time={torch_time:7.3f}us"
    )


def bench_combine_or_split_shuffling(
    T: int,
    D: int,
    E: int,
    EP: int,
    is_padded: bool,
    is_balanced: bool,
    is_combine_shuffling: bool,
):
    torch.manual_seed(0)

    assert E % EP == 0
    if is_padded:
        # graph. allgather
        input_num_tokens: int = EP * T
        input_num_experts: int = E
        output_num_experts: int = E // EP
        start_expert_index: int = 1
        end_expert_index: int = 1 + output_num_experts
    else:
        # eager. all2all
        input_num_tokens: int = T
        input_num_experts: int = E // EP
        output_num_experts: int = E // EP
        start_expert_index: int = 0
        end_expert_index: int = output_num_experts

    tokens = torch.randn(
        input_num_tokens, D, device=_ACCELERATOR_TAG, dtype=torch.bfloat16
    )

    if input_num_tokens < (EP * input_num_experts) != 0:
        return

    input_num_tokens_per_expert: int = input_num_tokens // (EP * input_num_experts)
    token_counts: torch.Tensor = (
        torch.ones(
            [EP, input_num_experts],
            dtype=torch.int32,
            device=_ACCELERATOR_TAG,
        )
        * input_num_tokens_per_expert
    )
    if not is_balanced:
        for i in range(EP):
            token_counts[i, start_expert_index] -= input_num_tokens_per_expert
            token_counts[i, end_expert_index - 1] += input_num_tokens_per_expert

    assert token_counts.sum().item() == input_num_tokens

    num_rotating_buffers = triton.cdiv(1024 * 1024 * 1024, tokens.numel() * 2)
    token_list: list[torch.Tensor] = [
        tokens.clone() for _ in range(num_rotating_buffers)
    ]
    token_count_list: list[torch.Tensor] = [
        token_counts.clone() for _ in range(num_rotating_buffers)
    ]

    def fn() -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        for tokens, token_counts in zip(token_list, token_count_list):
            if is_combine_shuffling:
                combine_shuffling(
                    tokens,
                    token_counts,
                    expert_start=start_expert_index,
                    expert_end=end_expert_index,
                    is_balanced=is_balanced,
                )
            else:
                split_shuffling(
                    tokens,
                    token_counts,
                    expert_start=start_expert_index,
                    expert_end=end_expert_index,
                    is_balanced=is_balanced,
                )

    fn()

    output_num_tokens = 0
    for per_rank_counts in token_counts.tolist():
        for expert_index, per_expert_counts in enumerate(per_rank_counts):
            if expert_index >= start_expert_index and expert_index < end_expert_index:
                output_num_tokens += per_expert_counts

    mem_bytes = output_num_tokens * D * 2 * 2
    fbgemm_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
    fbgemm_bw = mem_bytes * 1e-9 / (fbgemm_time * 1e-6)

    print(
        f"Benchmark {'combine_shuffling' if is_combine_shuffling else 'split_shuffling'}, "
        f"num_tokens={T:4}, dim={D:4}, num_experts={E:4}, expert_parallelism={EP:4}, output_num_tokens={output_num_tokens:4}, "
        f"{is_balanced=}, {is_padded=}, "
        f"fbgemm_time={fbgemm_time:7.3f}us, fbgemm_bw={fbgemm_bw:8.3f}GBytes/s."
    )


@click.command()
@click.option(
    "--kernels",
    default=None,
    help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
)
def main(kernels: Optional[str]):
    if kernels is not None:
        kernels = kernels.split(",")

    def should_bench_kernel(fn):
        return (fn is not None) and (kernels is None or fn.__name__ in kernels)

    Es = [16, 128]
    Ts = [1, 128, 2048, 4096, 8192, 16384]
    Ds = [5120]

    # Gather/Scatter
    if should_bench_kernel(gather_scale_dense_tokens):
        for E, T, D in itertools.product(Es, Ts, Ds):
            bench_gather_scale_dense_tokens(E, T, D, quantize=False)

    if should_bench_kernel(gather_scale_quant_dense_tokens):
        for E, T, D in itertools.product(Es, Ts, Ds):
            bench_gather_scale_dense_tokens(E, T, D, quantize=True)

    if should_bench_kernel(gather_along_first_dim):
        for T, D in itertools.product(Ts, Ds):
            bench_gather_along_first_dim(T, T, D)

    if should_bench_kernel(scatter_add_along_first_dim):
        for T, D in itertools.product(Ts, Ds):
            bench_scatter_add_along_first_dim(T, T, D)

    if should_bench_kernel(scatter_add_dense_tokens):
        for T, D in itertools.product(Ts, Ds):
            bench_scatter_add_dense_tokens(T, T, D)

    Ks = [1, 2, 4]
    Es = [16, 32, 128, 320]
    # Shuffling
    if should_bench_kernel(index_shuffling):
        for T, E, K in itertools.product(Ts, Es, Ks):
            bench_topk_index_shuffling(T, E, K)

    EPs = [2, 16]
    Ts = [32, 128, 2048, 4096, 8192, 16384]
    padded = [True, False]
    balanced = [True, False]

    if should_bench_kernel(combine_shuffling):
        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
            bench_combine_or_split_shuffling(
                T, D, E, EP, p, b, is_combine_shuffling=True
            )

    if should_bench_kernel(split_shuffling):
        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
            bench_combine_or_split_shuffling(
                T, D, E, EP, p, b, is_combine_shuffling=False
            )


if __name__ == "__main__":
    main()
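
Note: gather_scatter_bench.py exposes a click CLI whose --kernels option takes a comma-separated list of kernel function names (matched against fn.__name__ in should_bench_kernel); omitting it sweeps every kernel. A usage sketch (assumption: the wheel is installed and the module is runnable via python -m):

    python -m fbgemm_gpu.experimental.gen_ai.bench.gather_scatter_bench --kernels index_shuffling,combine_shuffling

Valid names are the ops imported from fbgemm_gpu.experimental.gen_ai.moe: gather_scale_dense_tokens, gather_scale_quant_dense_tokens, gather_along_first_dim, scatter_add_along_first_dim, scatter_add_dense_tokens, index_shuffling, combine_shuffling, split_shuffling.
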