fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/tbe/bench/bench_runs.py
@@ -0,0 +1,709 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import logging
+import statistics
+import threading
+import time
+from subprocess import Popen
+from typing import Callable, Optional
+
+import torch
+
+from fbgemm_gpu.tbe.utils import b_indices, TBERequest
+from fbgemm_gpu.tbe.utils.common import get_device
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+def bench_warmup(
+    request: TBERequest,
+    warmup_ms: int,
+    warmup_runs: int,
+    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
+    bwd_only: bool = False,
+    grad: Optional[torch.Tensor] = None,
+) -> None:
+    indices, offsets, weights = request.unpack_3()
+    if warmup_ms:
+        start_time_ms = time.time() * 1000
+        while time.time() * 1000 - start_time_ms < warmup_ms:
+            out = func(indices, offsets, weights)
+            if bwd_only:
+                out.backward(grad)
+    else:
+        for _ in range(warmup_runs):
+            out = func(indices, offsets, weights)
+            if bwd_only:
+                out.backward(grad)
+
+
+def bench_warmup_with_spec(
+    request: TBERequest,
+    warmup_ms: int,
+    warmup_runs: int,
+    func: Callable[
+        [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]],
+        torch.Tensor,
+    ],
+    bwd_only: bool = False,
+    grad: Optional[torch.Tensor] = None,
+) -> None:
+    indices, offsets, weights, batch_size_per_feature_per_rank = request.unpack_4()
+    if warmup_ms:
+        start_time_ms = time.time() * 1000
+        while time.time() * 1000 - start_time_ms < warmup_ms:
+            out = func(indices, offsets, weights, batch_size_per_feature_per_rank)
+            if bwd_only:
+                out.backward(grad)
+    else:
+        for _ in range(warmup_runs):
+            out = func(indices, offsets, weights, batch_size_per_feature_per_rank)
+            if bwd_only:
+                out.backward(grad)
+
+
+class BMBarrier:
+
+    def __init__(self) -> None:
+        self.bar: Optional[threading.Barrier] = None
+
+    def create_barrier(self, party_size: int) -> None:
+        if self.bar is not None:
+            self.bar.reset()
+            self.bar = None
+        self.bar = torch.multiprocessing.Barrier(party_size)
+
+    def wait(self) -> None:
+        if self.bar is not None:
+            self.bar.wait()
+
+
+# This barrier ensures all CPU TBE workers start the embedding workload
+# together so that we get the most accurate measurement. This needs to be
+# a global variable because it will be shared among worker processes.
+cpu_bm_barrier = BMBarrier()
+
+
+def cpu_tbe_worker(
+    requests_: list[TBERequest],
+    func_: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
+    use_barrier: bool = False,
+) -> float:
+    """
+    Worker function to process CPU TBE workload.
+
+    Args:
+        requests_ (List[TBERequest]): A list of TBERequest objects to be processed. Namely, the dataset.
+        func_ (Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]):
+            The function to process each request, usually the `.forward()` method
+            on the embedding module instance.
+        use_barrier (bool, optional): Whether to use a barrier to synchronize the
+            start of the embedding workload. Defaults to False.
+
+    Returns:
+        float: The average runtime per iteration in seconds.
+    """
+    import time
+
+    if use_barrier:
+        cpu_bm_barrier.wait()
+
+    start_time = time.perf_counter()
+    for req in requests_:
+        func_(*(req.unpack_3()))
+    end_time = time.perf_counter()
+
+    return (end_time - start_time) / len(requests_)
+
+
+def benchmark_cpu_requests_mp(
+    requests: list[TBERequest],
+    emb_module: torch.nn.Module,
+    num_warmups: int = 0,
+    num_copies: int = 1,
+    start_script: str = "",
+    end_script: str = "",
+) -> float:
+    """
+    CPU benchmark request handler with multi-processing support.
+
+    Args:
+        requests (List[TBERequest]): A list of TBERequest objects to be processed.
+        emb_module (torch.nn.Module): The embedding module to be used for processing requests,
+            for example, an instance of `IntNBitTableBatchedEmbeddingBagsCodegen` module.
+        num_warmups (int, optional): Number of warm-up iterations to perform before benchmarking. Defaults to 0.
+        num_copies (int, optional): Number of parallel copies of the workload. By `copies`,
+            we mean the number of parallel processes working on the same dataset described in `requests`.
+            Defaults to 1 (a single worker process). Increasing this enables the benchmark to use
+            more CPU cores and push higher memory bandwidth.
+        start_script (str, optional): Path to a script to be executed before starting the benchmark.
+            Defaults to empty (not running anything). This can be used to collect perf counters.
+            The script is terminated when the benchmark finishes.
+        end_script (str, optional): Path to a script to be executed after completing the benchmark.
+            Defaults to empty (not running anything). This can be used to post-process perf counters.
+
+    Returns:
+        float: The average runtime per iteration in seconds.
+
+    """
+    import os
+
+    strategy = os.environ.get("PYTORCH_SHARE_STRATEGY")
+    current_strategy = torch.multiprocessing.get_sharing_strategy()
+    if strategy is not None and current_strategy != strategy:
+        torch.multiprocessing.set_sharing_strategy(strategy)
+
+    cpu_bm_barrier.create_barrier(num_copies)
+    worker_pool = torch.multiprocessing.Pool(num_copies)
+
+    if num_warmups > 0:
+        asyncres = []
+        for _ in range(num_copies):
+            asyncres.append(
+                worker_pool.apply_async(
+                    cpu_tbe_worker,
+                    args=(
+                        [requests[0]] * num_warmups,
+                        emb_module.forward,
+                        False,
+                    ),
+                )
+            )
+        for res in asyncres:
+            res.wait()
+
+    if start_script:
+        p_start = Popen([start_script, str(num_copies)])
+
+    asyncres = []
+    for _ in range(num_copies):
+        asyncres.append(
+            worker_pool.apply_async(
+                cpu_tbe_worker,
+                args=(
+                    requests,
+                    emb_module.forward,
+                    True,
+                ),
+            )
+        )
+    runtime_per_iter = 0.0
+    for res in asyncres:
+        res.wait()
+        runtime_per_iter += res.get()
+    worker_pool.close()
+    worker_pool.join()
+    worker_pool.terminate()
+
+    if start_script:
+        p_start.terminate()
+
+    if end_script:
+        p_end = Popen([end_script, str(num_copies)])
+        p_end.wait()
+
+    return runtime_per_iter / num_copies
+
+
+def benchmark_cpu_requests(
+    requests: list[TBERequest],
+    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
+    num_warmups: int = 0,
+) -> float:
+    import time
+
+    if num_warmups > 0:
+        for _ in range(num_warmups):
+            func(*(requests[0].unpack_3()))
+
+    start_time = time.perf_counter()
+    for req in requests:
+        func(*(req.unpack_3()))
+    end_time = time.perf_counter()
+    return (end_time - start_time) / len(requests)
+
+
+def benchmark_requests(  # noqa: C901
+    requests: list[TBERequest],
+    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
+    flush_gpu_cache_size_mb: int = 0,
+    check_median: bool = False,
+    num_warmups: int = 0,
+    bwd_only: bool = False,
+    grad: Optional[torch.Tensor] = None,
+    # Used to label benchmark iterations differently in nsys profile result
+    # so that we can compare performance of two different models, for example.
+    # If an empty string is provided, it has no effect.
+    nvtx_range: str = "",
+    # Can be used to clear the model's stats after warmup, for example.
+    callback_after_warmup: Optional[Callable[[], None]] = None,
+    periodic_logs: bool = False,
+    warmup_ms: Optional[int] = None,
+    iters: int = -1,
+) -> float:
+    times = []
+    # Run at least one warmup iteration to avoid the long cudaLaunchKernel time
+    # for the first kernel if warmup_ms > 0
+    # warmup_ms is prioritized over num_warmups
+
+    if warmup_ms is None:
+        num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
+
+    # warm-up the GPU before profiling
+    bench_warmup(
+        requests[0],
+        # pyre-ignore[6]
+        warmup_ms,
+        num_warmups,
+        lambda indices, offsets, per_sample_weights: func(
+            indices,
+            offsets,
+            per_sample_weights,
+        ),
+        bwd_only=bwd_only,
+        grad=grad,
+    )
+
+    if callback_after_warmup is not None:
+        callback_after_warmup()
+
+    num_reqs = len(requests)
+    iters = num_reqs if iters == -1 else iters
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+    else:
+        start_events = []
+        end_events = []
+
+    for it in range(iters):
+        req = requests[it % num_reqs]
+
+        indices, offsets, weights = req.unpack_3()
+        if bwd_only:
+            # Run the forward pass before profiling if benchmarking backward only
+            out = func(indices, offsets, weights)
+        start_time = time.time()
+        if torch.cuda.is_available():
+            if flush_gpu_cache_size_mb:
+                _ = torch.rand(
+                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
+                    dtype=torch.float,
+                    device=get_device(),
+                )
+            start_events[it].record()
+
+        if nvtx_range:
+            torch.cuda.nvtx.range_push(f"{nvtx_range}-{it}")
+
+        if bwd_only:
+            out.backward(grad)
+        else:
+            func(indices, offsets, weights)
+
+        if nvtx_range:
+            torch.cuda.nvtx.range_pop()
+
+        if torch.cuda.is_available():
+            end_events[it].record()
+        else:
+            it_time = time.time() - start_time
+            times.append(it_time)
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        times = [
+            start.elapsed_time(end) * 1.0e-3
+            for start, end in zip(start_events, end_events)
+        ]
+
+    if periodic_logs:
+        for it in range(100, iters + 1, 100):
+            times_ = times[0:it]
+            avg_time = sum(times_) / len(times_) * 1.0e6
+            last_100_avg = sum(times_[-100:]) / 100 * 1.0e6
+            logging.info(
+                f"Iteration [{it}/{len(requests)}]: Last 100: {last_100_avg:.2f} us, Running avg: {avg_time:.2f} us"
+            )
+
+    avg_time = sum(times) / iters
+    median_time = statistics.median(times)
+    return median_time if check_median else avg_time
+
+
+def benchmark_requests_with_spec(  # noqa: C901
+    requests: list[TBERequest],
+    func: Callable[
+        [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]],
+        torch.Tensor,
+    ],
+    flush_gpu_cache_size_mb: int = 0,
+    check_median: bool = False,
+    num_warmups: int = 0,
+    bwd_only: bool = False,
+    grad: Optional[torch.Tensor] = None,
+    # Used to label benchmark iterations differently in nsys profile result
+    # so that we can compare performance of two different models, for example.
+    # If an empty string is provided, it has no effect.
+    nvtx_range: str = "",
+    # Can be used to clear the model's stats after warmup, for example.
+    callback_after_warmup: Optional[Callable[[], None]] = None,
+    periodic_logs: bool = False,
+    warmup_ms: Optional[int] = None,
+    iters: int = -1,
+) -> float:
+    times = []
+    # Run at least one warmup iteration to avoid the long cudaLaunchKernel time
+    # for the first kernel if warmup_ms > 0
+    # warmup_ms is prioritized over num_warmups
+
+    if warmup_ms is None:
+        num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
+
+    # warm-up the GPU before profiling
+    bench_warmup_with_spec(
+        requests[0],
+        # pyre-ignore[6]
+        warmup_ms,
+        num_warmups,
+        lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: func(
+            indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
+        ),
+        bwd_only=bwd_only,
+        grad=grad,
+    )
+
+    if callback_after_warmup is not None:
+        callback_after_warmup()
+
+    num_reqs = len(requests)
+    iters = num_reqs if iters == -1 else iters
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+    else:
+        start_events = []
+        end_events = []
+
+    for it in range(iters):
+        req = requests[it % num_reqs]
+
+        indices, offsets, weights, batch_size_per_feature_per_rank = req.unpack_4()
+        # logging.info(
+        #     f"[Benchmark Request] batch_size_per_feature_per_rank {batch_size_per_feature_per_rank} {indices.device}"
+        # )
+
+        if bwd_only:
+            # Run the forward pass before profiling if benchmarking backward only
+            out = func(indices, offsets, weights, batch_size_per_feature_per_rank)
+        start_time = time.time()
+        if torch.cuda.is_available():
+            if flush_gpu_cache_size_mb:
+                _ = torch.rand(
+                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
+                    dtype=torch.float,
+                    device=get_device(),
+                )
+            start_events[it].record()
+
+        if nvtx_range:
+            torch.cuda.nvtx.range_push(f"{nvtx_range}-{it}")
+
+        if bwd_only:
+            out.backward(grad)
+        else:
+            func(indices, offsets, weights, batch_size_per_feature_per_rank)
+
+        if nvtx_range:
+            torch.cuda.nvtx.range_pop()
+
+        if torch.cuda.is_available():
+            end_events[it].record()
+        else:
+            it_time = time.time() - start_time
+            times.append(it_time)
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        times = [
+            start.elapsed_time(end) * 1.0e-3
+            for start, end in zip(start_events, end_events)
+        ]
+
+    if periodic_logs:
+        for it in range(100, iters + 1, 100):
+            times_ = times[0:it]
+            avg_time = sum(times_) / len(times_) * 1.0e6
+            last_100_avg = sum(times_[-100:]) / 100 * 1.0e6
+            logging.info(
+                f"Iteration [{it}/{len(requests)}]: Last 100: {last_100_avg:.2f} us, Running avg: {avg_time:.2f} us"
+            )
+
+    avg_time = sum(times) / iters
+    median_time = statistics.median(times)
+    return median_time if check_median else avg_time
+
+
+def benchmark_requests_refer(
+    requests: list[TBERequest],
+    T: int,
+    B: int,
+    L: int,
+    E: int,
+    D: int,
+    pooling_mode: str,
+    weighted: bool,
+    flush_gpu_cache_size_mb: int = 0,
+    check_median: bool = False,
+) -> float:
+    do_pooling = pooling_mode in ["sum", "mean"]
+
+    if do_pooling:
+        nn_embedding_list = [
+            torch.nn.EmbeddingBag(E, D, mode=pooling_mode, sparse=True).cuda()
+        ] * T
+    else:
+        nn_embedding_list = [torch.nn.Embedding(E, D, sparse=True).cuda()] * T
+
+    times = []
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+    for req in requests:
+        indices, _, weights = req.unpack_3()
+        indices_list = indices.view(T, B, L).split(1)
+
+        if weighted:
+            assert weights is not None
+            weights_list = weights.view(T, B, L).split(1)
+
+        start_time = time.time()
+        if torch.cuda.is_available():
+            if flush_gpu_cache_size_mb:
+                _ = torch.rand(
+                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
+                    dtype=torch.float,
+                    device=get_device(),
+                )
+                torch.cuda.synchronize()
+            start_event.record()
+
+        nn_embedding_output = (
+            [
+                b_indices(nn_embedding, x, use_cpu=False, do_pooling=do_pooling)
+                for (nn_embedding, x) in zip(nn_embedding_list, indices_list)
+            ]
+            if not weighted
+            else [
+                b_indices(
+                    nn_embedding,
+                    x,
+                    per_sample_weights=xw.view(-1),
+                    use_cpu=False,
+                    do_pooling=do_pooling,
+                )
+                for (nn_embedding, x, xw) in zip(
+                    nn_embedding_list,
+                    indices_list,
+                    # pyre-fixme[61]: `weights_list` is undefined, or not always
+                    # defined.
+                    weights_list,
+                )
+            ]
+        )
+
+        if do_pooling:
+            final_output = torch.cat(
+                [f.view(B, -1) for f in nn_embedding_output], dim=1
+            )
+        else:
+            final_output = torch.cat(nn_embedding_output, dim=0).view(  # noqa: F841
+                -1, D
+            )
+
+        if torch.cuda.is_available():
+            end_event.record()
+            torch.cuda.synchronize()
+            # pyre-fixme[61]: `end_event` is undefined, or not always defined.
+            it_time = start_event.elapsed_time(end_event) * 1.0e-3
+            times.append(it_time)
+        else:
+            it_time = time.time() - start_time
+            times.append(it_time)
+    avg_time = sum(times) / len(requests)
+    median_time = statistics.median(times)
+    return median_time if check_median else avg_time
+
+
+def benchmark_pipelined_requests(
+    requests: list[TBERequest],
+    func1: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
+    func2: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
+    flush_gpu_cache_size_mb: int = 0,
+    check_median: bool = False,
+) -> tuple[float, float]:
+    torch.cuda.synchronize()
+    start_events = [
+        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
+        for _ in requests
+    ]
+    end_events = [
+        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
+        for _ in requests
+    ]
+    for req, start_event, end_event in zip(requests, start_events, end_events):
+        indices, offsets, indices_weights = req.unpack_3()
+        if flush_gpu_cache_size_mb:
+            _ = torch.rand(
+                flush_gpu_cache_size_mb * 1024 * 1024 // 4,
+                dtype=torch.float,
+                device=get_device(),
+            )
+            torch.cuda.synchronize()
+        start_event[0].record()
+        func1(indices, offsets, indices_weights)
+        end_event[0].record()
+        start_event[1].record()
+        func2(indices, offsets, indices_weights)
+        end_event[1].record()
+    torch.cuda.synchronize()
+    avg_time = (
+        sum(
+            start_event[0].elapsed_time(end_event[0]) * 1.0e-3
+            for start_event, end_event in zip(start_events, end_events)
+        )
+        / len(requests),
+        sum(
+            start_event[1].elapsed_time(end_event[1]) * 1.0e-3
+            for start_event, end_event in zip(start_events, end_events)
+        )
+        / len(requests),
+    )
+    median_time = (
+        statistics.median(
+            start_event[0].elapsed_time(end_event[0]) * 1.0e-3
+            for start_event, end_event in zip(start_events, end_events)
+        ),
+        statistics.median(
+            start_event[1].elapsed_time(end_event[1]) * 1.0e-3
+            for start_event, end_event in zip(start_events, end_events)
+        ),
+    )
+    return median_time if check_median else avg_time
+
+
+def benchmark_vbe(
+    requests: list[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
+    func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
+    num_warmups: int = 0,
+) -> tuple[float, float]:
+    """
+    A benchmark function that returns the median execution time in seconds of
+    forward and backward of VBE kernels.
+
+    Args:
+        requests (List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]):
+            A list of requests. Each request is a tuple
+            of indices, offsets and weights.
+
+        func (Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]):
+            A function that takes in indices, offsets, and weights
+            and returns the output of the VBE kernel.
+
+        num_warmups (int):
+            The number of warm-up iterations before measuring performance.
+
+    Returns:
+        Tuple[float, float]:
+            A tuple of median execution time in seconds of forward and
+            backward of VBE kernels.
+    """
+
+    use_cuda = torch.cuda.is_available()
+
+    # Warm-ups.
+    for _ in range(num_warmups):
+        # Warm-up using the first request as done in benchmark_requests
+        indices, offsets, weights = requests[0]
+        out = func(indices, offsets, weights)
+        grad = torch.rand_like(out)
+        out.backward(grad)
+
+    iters = len(requests)
+    if use_cuda:
+        fwd_start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+        fwd_end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+        bwd_start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+        bwd_end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+        torch.cuda.synchronize()
+    else:
+        # Actual measurement in seconds.
+        fwd_times_sec = []
+        bwd_times_sec = []
+
+    for i, (indices, offsets, weights) in enumerate(requests):
+        # forward
+        if use_cuda:
+            # pyre-ignore[61]
+            fwd_start_events[i].record()
+        else:
+            start_time = time.time()
+
+        out = func(indices, offsets, weights)
+        if use_cuda:
+            # pyre-ignore[61]
+            fwd_end_events[i].record()
+        else:
+            # pyre-ignore[61]
+            fwd_times_sec.append(time.time() - start_time)
+
+        grad = torch.rand_like(out)
+
+        if use_cuda:
+            # pyre-ignore[61]
+            bwd_start_events[i].record()
+        else:
+            start_time = time.time()
+        # backward
+        out.backward(grad)
+        if use_cuda:
+            # pyre-ignore[61]
+            bwd_end_events[i].record()
+        else:
+            # pyre-ignore[61]
+            bwd_times_sec.append(time.time() - start_time)
+
+    if use_cuda:
+        torch.cuda.synchronize()
+
+    if use_cuda:
+        fwd_times_sec = [
+            start_event.elapsed_time(end_event) * 1.0e-3
+            # pyre-ignore[61]
+            for start_event, end_event in zip(fwd_start_events, fwd_end_events)
+        ]
+        bwd_times_sec = [
+            start_event.elapsed_time(end_event) * 1.0e-3
+            # pyre-ignore[61]
+            for start_event, end_event in zip(bwd_start_events, bwd_end_events)
+        ]
+
+    # pyre-ignore[61]
+    fwd_time_sec = statistics.median(fwd_times_sec)
+    # pyre-ignore[61]
+    bwd_time_sec = statistics.median(bwd_times_sec)
+
+    return fwd_time_sec, bwd_time_sec
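
The diff above adds fbgemm_gpu/tbe/bench/bench_runs.py, whose helpers time TBE forward and backward passes with CUDA events when a GPU is available and with wall-clock timers on CPU otherwise. As a quick orientation, here is a minimal usage sketch that is not part of the wheel: it feeds benchmark_vbe a toy torch.nn.EmbeddingBag workload. The module import path, the shapes, and the run wrapper are assumptions made for illustration, and the sketch assumes the wheel is installed in an environment where fbgemm_gpu imports cleanly.

# Illustrative sketch only; not part of the packaged file above.
from typing import Optional

import torch

from fbgemm_gpu.tbe.bench.bench_runs import benchmark_vbe

E, D, B, L = 1000, 16, 32, 4  # rows, embedding dim, batch size, pooling factor (assumed toy sizes)
emb = torch.nn.EmbeddingBag(E, D, mode="sum", include_last_offset=True)

def run(indices: torch.Tensor, offsets: torch.Tensor, weights: Optional[torch.Tensor]) -> torch.Tensor:
    # Forward pass that benchmark_vbe times and then backpropagates through.
    return emb(indices, offsets, per_sample_weights=weights)

requests = [
    (
        torch.randint(0, E, (B * L,)),                    # indices
        torch.arange(0, B * L + 1, L, dtype=torch.long),  # offsets (B + 1 entries)
        None,                                             # per-sample weights
    )
    for _ in range(10)
]

fwd_sec, bwd_sec = benchmark_vbe(requests, run, num_warmups=2)
print(f"forward: {fwd_sec * 1e6:.1f} us, backward: {bwd_sec * 1e6:.1f} us")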