fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,709 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import statistics
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
from subprocess import Popen
|
|
15
|
+
from typing import Callable, Optional
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from fbgemm_gpu.tbe.utils import b_indices, TBERequest
|
|
20
|
+
from fbgemm_gpu.tbe.utils.common import get_device
|
|
21
|
+
|
|
22
|
+
# Import-time side effect: configures the root logger at DEBUG level for any
# process that imports this module.
logging.basicConfig(level=logging.DEBUG)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def bench_warmup(
    request: TBERequest,
    warmup_ms: int,
    warmup_runs: int,
    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
    bwd_only: bool = False,
    grad: Optional[torch.Tensor] = None,
) -> None:
    """
    Run warm-up iterations of `func` on a single request.

    Args:
        request (TBERequest): The request whose `unpack_3()` result is fed to `func`.
        warmup_ms (int): When truthy, keep iterating until this many milliseconds
            have elapsed. Takes priority over `warmup_runs`.
        warmup_runs (int): Number of iterations to run when `warmup_ms` is falsy.
        func (Callable): Called as `func(indices, offsets, weights)`.
        bwd_only (bool, optional): When True, `backward(grad)` is also called on
            the value returned by `func` in every iteration. Defaults to False.
        grad (Optional[torch.Tensor], optional): Gradient passed to `backward`.

    Returns:
        None
    """
    indices, offsets, weights = request.unpack_3()

    def _run_once() -> None:
        out = func(indices, offsets, weights)
        if bwd_only:
            out.backward(grad)

    if warmup_ms:
        # perf_counter() is monotonic, so a wall-clock adjustment during the
        # warm-up cannot shorten or prolong the loop.
        start_time = time.perf_counter()
        while (time.perf_counter() - start_time) * 1000 < warmup_ms:
            _run_once()
    else:
        for _ in range(warmup_runs):
            _run_once()
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def bench_warmup_with_spec(
    request: TBERequest,
    warmup_ms: int,
    warmup_runs: int,
    func: Callable[
        [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]],
        torch.Tensor,
    ],
    bwd_only: bool = False,
    grad: Optional[torch.Tensor] = None,
) -> None:
    """
    Run warm-up iterations of `func` on a single request, including the
    per-feature-per-rank batch sizes.

    Args:
        request (TBERequest): The request whose `unpack_4()` result is fed to `func`.
        warmup_ms (int): When truthy, keep iterating until this many milliseconds
            have elapsed. Takes priority over `warmup_runs`.
        warmup_runs (int): Number of iterations to run when `warmup_ms` is falsy.
        func (Callable): Called as
            `func(indices, offsets, weights, batch_size_per_feature_per_rank)`.
        bwd_only (bool, optional): When True, `backward(grad)` is also called on
            the value returned by `func` in every iteration. Defaults to False.
        grad (Optional[torch.Tensor], optional): Gradient passed to `backward`.

    Returns:
        None
    """
    indices, offsets, weights, batch_size_per_feature_per_rank = request.unpack_4()

    def _run_once() -> None:
        out = func(indices, offsets, weights, batch_size_per_feature_per_rank)
        if bwd_only:
            out.backward(grad)

    if warmup_ms:
        # perf_counter() is monotonic, so a wall-clock adjustment during the
        # warm-up cannot shorten or prolong the loop.
        start_time = time.perf_counter()
        while (time.perf_counter() - start_time) * 1000 < warmup_ms:
            _run_once()
    else:
        for _ in range(warmup_runs):
            _run_once()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class BMBarrier:
    """
    Holder for an optional, replaceable barrier object.

    Until `create_barrier()` has been called, `wait()` is a no-op.
    """

    def __init__(self) -> None:
        # Annotated as threading.Barrier; create_barrier() stores the object
        # returned by torch.multiprocessing.Barrier.
        self.bar: Optional[threading.Barrier] = None

    def create_barrier(self, party_size: int) -> None:
        """Reset and drop any existing barrier, then create one for `party_size` parties."""
        previous = self.bar
        if previous is not None:
            previous.reset()
            self.bar = None
        self.bar = torch.multiprocessing.Barrier(party_size)

    def wait(self) -> None:
        """Wait on the current barrier; return immediately if none exists."""
        barrier = self.bar
        if barrier is None:
            return
        barrier.wait()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# This barrier ensures all CPU TBE workers start the embedding workload
|
|
89
|
+
# together so that we get the most accurate measurement. This needs to be
|
|
90
|
+
# a global variable because it will be shared among worker processes.
|
|
91
|
+
# Holds no underlying barrier at import time; benchmark_cpu_requests_mp() calls
# create_barrier() on it before creating its worker pool.
cpu_bm_barrier = BMBarrier()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def cpu_tbe_worker(
    requests_: list[TBERequest],
    func_: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
    use_barrier: bool = False,
) -> float:
    """
    Worker function to process CPU TBE workload.

    Args:
        requests_ (List[TBERequest]): A list of TBERequest objects to be processed. Namely, the dataset.
        func_ (Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]):
            The function to process each request, usually the `.forward()` method
            in the embedding module instance.
        use_barrier (bool, optional): Whether to use a barrier to synchronize the
            start of embedding workload. Defaults to False.

    Returns:
        float: The average runtime per iteration in seconds.
    """
    # Optionally wait on the module-level barrier before the timed region.
    if use_barrier:
        cpu_bm_barrier.wait()

    t_begin = time.perf_counter()
    for request in requests_:
        func_(*request.unpack_3())
    t_finish = time.perf_counter()

    elapsed = t_finish - t_begin
    return elapsed / len(requests_)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def benchmark_cpu_requests_mp(
    requests: list[TBERequest],
    emb_module: torch.nn.Module,
    num_warmups: int = 0,
    num_copies: int = 1,
    start_script: str = "",
    end_script: str = "",
) -> float:
    """
    CPU benchmark request handler with multi-processing support

    Args:
        requests (List[TBERequest]): A list of TBERequest objects to be processed.
        emb_module (torch.nn.Module): The embedding module to be used for processing requests,
            for example, an instance of `IntNBitTableBatchedEmbeddingBagsCodegen` module.
        num_warmups (int, optional): Number of warm-up iterations to perform before benchmarking. Defaults to 0.
            Each worker runs the first request this many times.
        num_copies (int, optional): Number of parallel copies of the workloads. By `copies`,
            we mean the number of parallel processes working on the same dataset described in `requests`.
            Defaults to 1 (a single worker process). Increasing this will enable the benchmark to use
            more CPU cores and push higher memory bandwidth.
        start_script (str, optional): Path to a script to be executed before starting the benchmark.
            Defaults to empty (not running anything). This can be used to collect perf counters.
            The script will be terminated upon benchmark finishing.
        end_script (str, optional): Path to a script to be executed after completing the benchmark.
            Defaults to empty (not running anything). This can be used to post-process perf counters.

    Returns:
        float: The average runtime per iteration in seconds.

    """
    import os

    # Honor an explicitly requested tensor sharing strategy.
    strategy = os.environ.get("PYTORCH_SHARE_STRATEGY")
    current_strategy = torch.multiprocessing.get_sharing_strategy()
    if strategy is not None and current_strategy != strategy:
        torch.multiprocessing.set_sharing_strategy(strategy)

    cpu_bm_barrier.create_barrier(num_copies)
    worker_pool = torch.multiprocessing.Pool(num_copies)
    p_start = None

    try:
        if num_warmups > 0:
            # cpu_tbe_worker takes exactly (requests, func, use_barrier); the
            # warm-up count is expressed by repeating the first request.
            warmup_requests = [requests[0]] * num_warmups
            warmup_results = [
                worker_pool.apply_async(
                    cpu_tbe_worker,
                    args=(warmup_requests, emb_module.forward, False),
                )
                for _ in range(num_copies)
            ]
            for res in warmup_results:
                # get() re-raises worker exceptions instead of dropping them.
                res.get()

        if start_script:
            p_start = Popen([start_script, str(num_copies)])

        asyncres = [
            worker_pool.apply_async(
                cpu_tbe_worker,
                args=(requests, emb_module.forward, True),
            )
            for _ in range(num_copies)
        ]
        runtime_per_iter = 0.0
        for res in asyncres:
            runtime_per_iter += res.get()

        worker_pool.close()
        worker_pool.join()
    finally:
        # Always tear down the pool and the helper script, even on failure.
        worker_pool.terminate()
        if p_start is not None:
            p_start.terminate()

    if end_script:
        p_end = Popen([end_script, str(num_copies)])
        p_end.wait()

    return runtime_per_iter / num_copies
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def benchmark_cpu_requests(
    requests: list[TBERequest],
    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
    num_warmups: int = 0,
) -> float:
    """
    Time `func` over every request in the calling process.

    Args:
        requests (List[TBERequest]): Requests to process; each is expanded with `unpack_3()`.
        func (Callable): Called with the unpacked contents of each request.
        num_warmups (int, optional): Number of untimed calls on the first request
            before the measurement. Defaults to 0.

    Returns:
        float: Elapsed `time.perf_counter()` seconds divided by the number of requests.
    """
    # Untimed warm-up on the first request only.
    warmup_count = num_warmups if num_warmups > 0 else 0
    for _ in range(warmup_count):
        func(*requests[0].unpack_3())

    t_begin = time.perf_counter()
    for request in requests:
        func(*request.unpack_3())
    t_finish = time.perf_counter()

    elapsed = t_finish - t_begin
    return elapsed / len(requests)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def benchmark_requests(  # noqa: C901
    requests: list[TBERequest],
    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
    flush_gpu_cache_size_mb: int = 0,
    check_median: bool = False,
    num_warmups: int = 0,
    bwd_only: bool = False,
    grad: Optional[torch.Tensor] = None,
    # Used to label benchmark iterations differently in nsys profile result
    # so that we can compare performance of two different models for example.
    # If empty string is provided, it won't have any effect.
    nvtx_range: str = "",
    # Can be used to clear model's stats after warmup for example.
    callback_after_warmup: Optional[Callable[[], None]] = None,
    periodic_logs: bool = False,
    warmup_ms: Optional[int] = None,
    iters: int = -1,
) -> float:
    """
    Benchmark `func` over `requests` and return the time per iteration in seconds.

    Args:
        requests (List[TBERequest]): Requests to cycle through; each is expanded
            with `unpack_3()`. The first request is used for warm-up.
        func (Callable): Called as `func(indices, offsets, weights)`.
        flush_gpu_cache_size_mb (int, optional): When non-zero and CUDA is available,
            a random float tensor of this many MB is allocated before each timed
            iteration. Defaults to 0.
        check_median (bool, optional): Return the median instead of the mean.
        num_warmups (int, optional): Warm-up iterations; used only when `warmup_ms`
            is None, in which case one extra iteration is always added.
        bwd_only (bool, optional): Time only `backward(grad)`; the forward call
            runs before the timed region.
        grad (Optional[torch.Tensor], optional): Gradient passed to `backward`.
        nvtx_range (str, optional): When non-empty, each iteration is wrapped in
            an NVTX range named `f"{nvtx_range}-{it}"`.
        callback_after_warmup (Optional[Callable], optional): Invoked once after warm-up.
        periodic_logs (bool, optional): Log timing summaries for every 100 iterations.
        warmup_ms (Optional[int], optional): Time-based warm-up in milliseconds;
            prioritized over `num_warmups`.
        iters (int, optional): Number of timed iterations; -1 means `len(requests)`.

    Returns:
        float: Median (if `check_median`) or mean iteration time in seconds.
    """
    times = []
    # When no time-based warm-up is requested, run at least one warm-up
    # iteration to avoid the long cudaLaunchKernel time for the first kernel.
    # warmup_ms is prioritized over num_warmups.
    if warmup_ms is None:
        num_warmups = num_warmups + 1 if num_warmups >= 0 else 1

    # warm-up the GPU before profiling
    bench_warmup(
        requests[0],
        # pyre-ignore[6]
        warmup_ms,
        num_warmups,
        lambda indices, offsets, per_sample_weights: func(
            indices,
            offsets,
            per_sample_weights,
        ),
        bwd_only=bwd_only,
        grad=grad,
    )

    if callback_after_warmup is not None:
        callback_after_warmup()

    num_reqs = len(requests)
    iters = num_reqs if iters == -1 else iters

    # Queried once; the answer is reused for every iteration.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    else:
        start_events = []
        end_events = []

    for it in range(iters):
        req = requests[it % num_reqs]

        indices, offsets, weights = req.unpack_3()
        if bwd_only:
            # Run forward before profiling if does backward only
            out = func(indices, offsets, weights)
        # perf_counter() is used for the CPU path because it is monotonic.
        start_time = time.perf_counter()
        if use_cuda:
            if flush_gpu_cache_size_mb:
                _ = torch.rand(
                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                    dtype=torch.float,
                    device=get_device(),
                )
            start_events[it].record()

        if nvtx_range:
            torch.cuda.nvtx.range_push(f"{nvtx_range}-{it}")

        if bwd_only:
            out.backward(grad)
        else:
            func(indices, offsets, weights)

        if nvtx_range:
            torch.cuda.nvtx.range_pop()

        if use_cuda:
            end_events[it].record()
        else:
            times.append(time.perf_counter() - start_time)

    if use_cuda:
        torch.cuda.synchronize()
        # Event.elapsed_time() reports milliseconds; convert to seconds.
        times = [
            start.elapsed_time(end) * 1.0e-3
            for start, end in zip(start_events, end_events)
        ]

    if periodic_logs:
        for it in range(100, iters + 1, 100):
            times_ = times[0:it]
            avg_time = sum(times_) / len(times_) * 1.0e6
            last_100_avg = sum(times_[-100:]) / 100 * 1.0e6
            # The total is the number of timed iterations, which differs from
            # len(requests) whenever `iters` is given explicitly.
            logging.info(
                f"Iteration [{it}/{iters}]: Last 100: {last_100_avg:.2f} us, Running avg: {avg_time:.2f} us"
            )

    avg_time = sum(times) / iters
    median_time = statistics.median(times)
    return median_time if check_median else avg_time
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def benchmark_requests_with_spec(  # noqa: C901
    requests: list[TBERequest],
    func: Callable[
        [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]],
        torch.Tensor,
    ],
    flush_gpu_cache_size_mb: int = 0,
    check_median: bool = False,
    num_warmups: int = 0,
    bwd_only: bool = False,
    grad: Optional[torch.Tensor] = None,
    # Used to label benchmark iterations differently in nsys profile result
    # so that we can compare performance of two different models for example.
    # If empty string is provided, it won't have any effect.
    nvtx_range: str = "",
    # Can be used to clear model's stats after warmup for example.
    callback_after_warmup: Optional[Callable[[], None]] = None,
    periodic_logs: bool = False,
    warmup_ms: Optional[int] = None,
    iters: int = -1,
) -> float:
    """
    Benchmark `func` over `requests`, passing the per-feature-per-rank batch
    sizes as a fourth argument, and return the time per iteration in seconds.

    Args:
        requests (List[TBERequest]): Requests to cycle through; each is expanded
            with `unpack_4()`. The first request is used for warm-up.
        func (Callable): Called as
            `func(indices, offsets, weights, batch_size_per_feature_per_rank)`.
        flush_gpu_cache_size_mb (int, optional): When non-zero and CUDA is available,
            a random float tensor of this many MB is allocated before each timed
            iteration. Defaults to 0.
        check_median (bool, optional): Return the median instead of the mean.
        num_warmups (int, optional): Warm-up iterations; used only when `warmup_ms`
            is None, in which case one extra iteration is always added.
        bwd_only (bool, optional): Time only `backward(grad)`; the forward call
            runs before the timed region.
        grad (Optional[torch.Tensor], optional): Gradient passed to `backward`.
        nvtx_range (str, optional): When non-empty, each iteration is wrapped in
            an NVTX range named `f"{nvtx_range}-{it}"`.
        callback_after_warmup (Optional[Callable], optional): Invoked once after warm-up.
        periodic_logs (bool, optional): Log timing summaries for every 100 iterations.
        warmup_ms (Optional[int], optional): Time-based warm-up in milliseconds;
            prioritized over `num_warmups`.
        iters (int, optional): Number of timed iterations; -1 means `len(requests)`.

    Returns:
        float: Median (if `check_median`) or mean iteration time in seconds.
    """
    times = []
    # When no time-based warm-up is requested, run at least one warm-up
    # iteration to avoid the long cudaLaunchKernel time for the first kernel.
    # warmup_ms is prioritized over num_warmups.
    if warmup_ms is None:
        num_warmups = num_warmups + 1 if num_warmups >= 0 else 1

    # warm-up the GPU before profiling
    bench_warmup_with_spec(
        requests[0],
        # pyre-ignore[6]
        warmup_ms,
        num_warmups,
        lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: func(
            indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
        ),
        bwd_only=bwd_only,
        grad=grad,
    )

    if callback_after_warmup is not None:
        callback_after_warmup()

    num_reqs = len(requests)
    iters = num_reqs if iters == -1 else iters

    # Queried once; the answer is reused for every iteration.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    else:
        start_events = []
        end_events = []

    for it in range(iters):
        req = requests[it % num_reqs]

        indices, offsets, weights, batch_size_per_feature_per_rank = req.unpack_4()

        if bwd_only:
            # Run forward before profiling if does backward only
            out = func(indices, offsets, weights, batch_size_per_feature_per_rank)
        # perf_counter() is used for the CPU path because it is monotonic.
        start_time = time.perf_counter()
        if use_cuda:
            if flush_gpu_cache_size_mb:
                _ = torch.rand(
                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                    dtype=torch.float,
                    device=get_device(),
                )
            start_events[it].record()

        if nvtx_range:
            torch.cuda.nvtx.range_push(f"{nvtx_range}-{it}")

        if bwd_only:
            out.backward(grad)
        else:
            func(indices, offsets, weights, batch_size_per_feature_per_rank)

        if nvtx_range:
            torch.cuda.nvtx.range_pop()

        if use_cuda:
            end_events[it].record()
        else:
            times.append(time.perf_counter() - start_time)

    if use_cuda:
        torch.cuda.synchronize()
        # Event.elapsed_time() reports milliseconds; convert to seconds.
        times = [
            start.elapsed_time(end) * 1.0e-3
            for start, end in zip(start_events, end_events)
        ]

    if periodic_logs:
        for it in range(100, iters + 1, 100):
            times_ = times[0:it]
            avg_time = sum(times_) / len(times_) * 1.0e6
            last_100_avg = sum(times_[-100:]) / 100 * 1.0e6
            # The total is the number of timed iterations, which differs from
            # len(requests) whenever `iters` is given explicitly.
            logging.info(
                f"Iteration [{it}/{iters}]: Last 100: {last_100_avg:.2f} us, Running avg: {avg_time:.2f} us"
            )

    avg_time = sum(times) / iters
    median_time = statistics.median(times)
    return median_time if check_median else avg_time
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def benchmark_requests_refer(
    requests: list[TBERequest],
    T: int,
    B: int,
    L: int,
    E: int,
    D: int,
    pooling_mode: str,
    weighted: bool,
    flush_gpu_cache_size_mb: int = 0,
    check_median: bool = False,
) -> float:
    """
    Time a reference implementation built from stock `torch.nn` embedding modules.

    Args:
        requests (List[TBERequest]): Requests to process; each is expanded with `unpack_3()`.
        T (int), B (int), L (int): Indices (and weights) are viewed as `(T, B, L)`
            and split along the first dimension.
        E (int), D (int): Passed as the first two constructor arguments of
            `torch.nn.EmbeddingBag` / `torch.nn.Embedding`.
        pooling_mode (str): "sum" or "mean" selects `EmbeddingBag`; anything else
            selects `Embedding` without pooling.
        weighted (bool): Pass per-sample weights to `b_indices`; the request's
            weights must then not be None.
        flush_gpu_cache_size_mb (int, optional): When non-zero and CUDA is available,
            a random float tensor of this many MB is allocated before each timed
            iteration. Defaults to 0.
        check_median (bool, optional): Return the median instead of the mean.

    Returns:
        float: Median (if `check_median`) or mean iteration time in seconds.
    """
    do_pooling = pooling_mode in ["sum", "mean"]

    # The module is moved to CUDA unconditionally.
    if do_pooling:
        reference_module = torch.nn.EmbeddingBag(
            E, D, mode=pooling_mode, sparse=True
        ).cuda()
    else:
        reference_module = torch.nn.Embedding(E, D, sparse=True).cuda()
    # NOTE: list repetition yields T references to the same module object.
    nn_embedding_list = [reference_module] * T

    times = []
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

    for req in requests:
        indices, _, weights = req.unpack_3()
        indices_list = indices.view(T, B, L).split(1)

        if weighted:
            assert weights is not None
            weights_list = weights.view(T, B, L).split(1)

        start_time = time.time()
        if use_cuda:
            if flush_gpu_cache_size_mb:
                _ = torch.rand(
                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                    dtype=torch.float,
                    device=get_device(),
                )
                torch.cuda.synchronize()
            # pyre-fixme[61]: `start_event` is undefined, or not always defined.
            start_event.record()

        if weighted:
            nn_embedding_output = [
                b_indices(
                    nn_embedding,
                    x,
                    per_sample_weights=xw.view(-1),
                    use_cpu=False,
                    do_pooling=do_pooling,
                )
                for (nn_embedding, x, xw) in zip(
                    nn_embedding_list,
                    indices_list,
                    # pyre-fixme[61]: `weights_list` is undefined, or not always
                    #  defined.
                    weights_list,
                )
            ]
        else:
            nn_embedding_output = [
                b_indices(nn_embedding, x, use_cpu=False, do_pooling=do_pooling)
                for (nn_embedding, x) in zip(nn_embedding_list, indices_list)
            ]

        # The concatenation is part of the timed work; its result is unused.
        if do_pooling:
            final_output = torch.cat(
                [f.view(B, -1) for f in nn_embedding_output], dim=1
            )
        else:
            final_output = torch.cat(nn_embedding_output, dim=0).view(  # noqa: F841
                -1, D
            )

        if use_cuda:
            # pyre-fixme[61]: `end_event` is undefined, or not always defined.
            end_event.record()
            torch.cuda.synchronize()
            # Event.elapsed_time() reports milliseconds; convert to seconds.
            times.append(start_event.elapsed_time(end_event) * 1.0e-3)
        else:
            times.append(time.time() - start_time)

    avg_time = sum(times) / len(requests)
    median_time = statistics.median(times)
    return median_time if check_median else avg_time
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def benchmark_pipelined_requests(
    requests: list[TBERequest],
    func1: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
    func2: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
    flush_gpu_cache_size_mb: int = 0,
    check_median: bool = False,
) -> tuple[float, float]:
    """
    Time two functions that are invoked back to back on every request, using
    CUDA events.

    Args:
        requests (List[TBERequest]): Requests to process; each is expanded with `unpack_3()`.
        func1 (Callable): First function, called as `func1(indices, offsets, weights)`.
        func2 (Callable): Second function, called with the same arguments right after `func1`.
        flush_gpu_cache_size_mb (int, optional): When non-zero, a random float
            tensor of this many MB is allocated, followed by a device
            synchronization, before each request. Defaults to 0.
        check_median (bool, optional): Return medians instead of means.

    Returns:
        Tuple[float, float]: Per-request time in seconds of `func1` and `func2`
        (median if `check_median`, otherwise mean).
    """
    torch.cuda.synchronize()

    # One (func1, func2) pair of events per request, for both start and end.
    start_events = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
        for _ in requests
    ]
    end_events = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
        for _ in requests
    ]

    for req, begin_pair, finish_pair in zip(requests, start_events, end_events):
        indices, offsets, indices_weights = req.unpack_3()
        if flush_gpu_cache_size_mb:
            _ = torch.rand(
                flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                dtype=torch.float,
                device=get_device(),
            )
            torch.cuda.synchronize()
        begin_pair[0].record()
        func1(indices, offsets, indices_weights)
        finish_pair[0].record()
        begin_pair[1].record()
        func2(indices, offsets, indices_weights)
        finish_pair[1].record()
    torch.cuda.synchronize()

    # Event.elapsed_time() reports milliseconds; convert to seconds.
    func1_times = [
        begin_pair[0].elapsed_time(finish_pair[0]) * 1.0e-3
        for begin_pair, finish_pair in zip(start_events, end_events)
    ]
    func2_times = [
        begin_pair[1].elapsed_time(finish_pair[1]) * 1.0e-3
        for begin_pair, finish_pair in zip(start_events, end_events)
    ]

    num_requests = len(requests)
    avg_time = (
        sum(func1_times) / num_requests,
        sum(func2_times) / num_requests,
    )
    median_time = (
        statistics.median(func1_times),
        statistics.median(func2_times),
    )
    return median_time if check_median else avg_time
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
def benchmark_vbe(
    requests: list[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
    num_warmups: int = 0,
) -> tuple[float, float]:
    """
    A benchmark function to return the median execution time in seconds of
    forward and backward of VBE kernels.

    Args:
        requests (List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]):
            A list of requests. Each request is a tuple
            of indices, offsets and weights.

        func (Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]):
            A function that takes in indices, offsets, and weights
            and returns the output of the VBE kernel.

        num_warmups (int):
            The number of warm-up iterations before measuring performance.

    Returns:
        Tuple[float, float]:
            A tuple of median execution time in seconds of forward and
            backward of VBE kernels.
    """

    use_cuda = torch.cuda.is_available()

    # Warm-ups.
    for _ in range(num_warmups):
        # Warm-up using the first request as done in benchmark_requests
        indices, offsets, weights = requests[0]
        out = func(indices, offsets, weights)
        grad = torch.rand_like(out)
        out.backward(grad)

    iters = len(requests)

    # Every timing container is defined on both paths so that no name is
    # conditionally bound.
    fwd_times_sec: list[float] = []
    bwd_times_sec: list[float] = []
    fwd_start_events = []
    fwd_end_events = []
    bwd_start_events = []
    bwd_end_events = []
    if use_cuda:
        fwd_start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
        fwd_end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
        bwd_start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
        bwd_end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
        torch.cuda.synchronize()

    start_time = 0.0
    for i, (indices, offsets, weights) in enumerate(requests):
        # forward
        if use_cuda:
            fwd_start_events[i].record()
        else:
            # perf_counter() is monotonic, unlike wall-clock time.
            start_time = time.perf_counter()

        out = func(indices, offsets, weights)
        if use_cuda:
            fwd_end_events[i].record()
        else:
            fwd_times_sec.append(time.perf_counter() - start_time)

        # The gradient is generated outside both timed regions.
        grad = torch.rand_like(out)

        if use_cuda:
            bwd_start_events[i].record()
        else:
            start_time = time.perf_counter()
        # backward
        out.backward(grad)
        if use_cuda:
            bwd_end_events[i].record()
        else:
            bwd_times_sec.append(time.perf_counter() - start_time)

    if use_cuda:
        torch.cuda.synchronize()
        # Event.elapsed_time() reports milliseconds; convert to seconds.
        fwd_times_sec = [
            start_event.elapsed_time(end_event) * 1.0e-3
            for start_event, end_event in zip(fwd_start_events, fwd_end_events)
        ]
        bwd_times_sec = [
            start_event.elapsed_time(end_event) * 1.0e-3
            for start_event, end_event in zip(bwd_start_events, bwd_end_events)
        ]

    fwd_time_sec = statistics.median(fwd_times_sec)
    bwd_time_sec = statistics.median(bwd_times_sec)

    return fwd_time_sec, bwd_time_sec
|