fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py

@@ -31,15 +31,18 @@ except Exception:
 
 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
+from fbgemm_gpu.split_embedding_configs import sparse_type_int_to_dtype
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
 
 
 def generate_vbe_metadata(
     offsets: Tensor,
-    batch_size_per_feature_per_rank: Optional[List[List[int]]],
+    batch_size_per_feature_per_rank: Optional[list[list[int]]],
     pooling_mode: PoolingMode,
     feature_dims_cpu: Tensor,
     device: torch.device,
+    vbe_output: Optional[Tensor] = None,
+    vbe_output_offsets: Optional[Tensor] = None,
 ) -> invokers.lookup_args.VBEMetadata:
     """
     Generate VBE metadata based on batch_size_per_feature_per_rank.

@@ -133,6 +136,8 @@ def generate_vbe_metadata(
             max_B_feature_rank=max_B_feature_rank,
             # pyre-ignore
             output_size=output_size,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )
     else:
         vbe_metadata = invokers.lookup_args.VBEMetadata(

@@ -142,5 +147,43 @@ def generate_vbe_metadata(
             max_B=-1,
             max_B_feature_rank=-1,
             output_size=-1,
+            vbe_output=None,
+            vbe_output_offsets=None,
         )
     return vbe_metadata
+
+
+def check_allocated_vbe_output(
+    output_dtype: int,
+    batch_size_per_feature_per_rank: Optional[List[List[int]]],
+    vbe_output: Optional[Tensor] = None,
+    vbe_output_offsets: Optional[Tensor] = None,
+) -> None:
+    assert (
+        batch_size_per_feature_per_rank is not None
+    ), "[Merged_VBE] vbe_output is passed, batch_size_per_feature_per_rank cannot be None"
+    assert (
+        vbe_output is not None
+    ), "[Merged_VBE] vbe_output_offsets is not None, vbe_output cannot be None"
+    assert (
+        vbe_output_offsets is not None
+    ), "[Merged_VBE] vbe_output is not None, vbe_output_offsets cannot be None"
+    num_features = len(batch_size_per_feature_per_rank)
+    num_ranks = len(batch_size_per_feature_per_rank[0])
+    assert vbe_output_offsets.shape == torch.Size(
+        [num_ranks, num_features]
+    ), f"[Merged_VBE] Mismatched vbe_output_offsets shape. batch_size_per_feature_per_rank={batch_size_per_feature_per_rank}. Expected: {torch.Size([num_ranks, num_features])}, Actual: {vbe_output_offsets.shape}"
+    assert (
+        vbe_output.dim() == 1
+    ), f"[Merged_VBE] vbe_output must have 1 dimension, but got {vbe_output.dim()}. vbe_output shape is {vbe_output.shape}"
+    assert (
+        vbe_output_offsets.device == vbe_output.device
+    ), "[Merged_VBE] vbe_output_offsets and vbe_output must be on the same device"
+    _output_dtype = sparse_type_int_to_dtype(output_dtype)
+    assert (
+        vbe_output.dtype == _output_dtype
+    ), f"[Merged_VBE] vbe_output dtype must match TBE output dtype {_output_dtype} (SparseType {output_dtype}), but got {vbe_output.dtype}"
+    assert (
+        vbe_output_offsets.is_contiguous()
+    ), "[Merged_VBE] vbe_output_offsets needs to be contiguous"
+    assert vbe_output.is_contiguous(), "[Merged_VBE] vbe_output needs to be contiguous"
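The new check_allocated_vbe_output helper only validates caller-provided buffers: the offsets tensor must have shape [num_ranks, num_features], the output must be a 1-D contiguous tensor in the TBE output dtype, and both must share a device. Below is a minimal sketch of an allocation that satisfies those assertions; the batch sizes, buffer length, and FP32 dtype are illustrative assumptions, not values taken from the package.

```python
import torch

# Illustrative only: buffers shaped to pass the checks in check_allocated_vbe_output.
batch_size_per_feature_per_rank = [[2, 3], [4, 1]]   # 2 features x 2 ranks (made up)
num_features = len(batch_size_per_feature_per_rank)  # 2
num_ranks = len(batch_size_per_feature_per_rank[0])  # 2

# 1-D, contiguous output buffer in the (assumed) FP32 TBE output dtype.
vbe_output = torch.zeros(64, dtype=torch.float32)

# Per-(rank, feature) offsets into vbe_output; the contents are placeholders here,
# since the checker only inspects shape, device, dtype, and contiguity.
vbe_output_offsets = torch.zeros(num_ranks, num_features, dtype=torch.long)

# Same properties the helper asserts:
assert vbe_output_offsets.shape == torch.Size([num_ranks, num_features])
assert vbe_output.dim() == 1 and vbe_output.is_contiguous()
assert vbe_output_offsets.is_contiguous()
assert vbe_output_offsets.device == vbe_output.device
```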
fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py

@@ -16,7 +16,6 @@ from fbgemm_gpu.tbe.ssd import ( # noqa: F401
     SSDTableBatchedEmbeddingBags,  # noqa: F401
 )
 
-
 warnings.warn(  # noqa: B028
     f"""\033[93m
     The Python module {__name__} is now DEPRECATED and will be removed in the
fbgemm_gpu/tbe/bench/__init__.py

@@ -21,6 +21,7 @@ from .bench_runs import ( # noqa F401
     benchmark_pipelined_requests,
     benchmark_requests,
     benchmark_requests_refer,
+    benchmark_requests_with_spec,
     benchmark_vbe,
 )
 from .benchmark_click_interface import TbeBenchClickInterface  # noqa F401

@@ -40,7 +41,11 @@ from .tbe_data_config_param_models import ( # noqa F401
     IndicesParams,
     PoolingParams,
 )
-from .utils import fill_random_scale_bias  # noqa F401
+from .utils import (  # noqa F401
+    check_oom,
+    fill_random_scale_bias,
+    generate_merged_output_and_offsets,
+)
 
 try:
     torch.ops.load_library(
fbgemm_gpu/tbe/bench/bench_config.py

@@ -10,7 +10,7 @@
 import dataclasses
 import json
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import click
 

@@ -29,10 +29,12 @@ class TBEBenchmarkingConfig:
     export_trace: bool
     # The path for exporting the trace
     trace_url: Optional[str]
+    # If set and export_trace is true, the benchmark will upload performance data from the trace to Scuba
+    upload_perf_data: bool
 
     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data: Dict[str, Any]):
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)
 
     @classmethod

@@ -40,7 +42,7 @@ class TBEBenchmarkingConfig:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))
 
-    def dict(self) -> Dict[str, Any]:
+    def dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self)
 
     def json(self, format: bool = False) -> str:

@@ -71,6 +73,7 @@ class TBEBenchmarkingHelperText(Enum):
         "If set, trace will be exported to the path specified in trace url"
     )
     BENCH_TRACE_URL = "The path for exporting the trace"
+    BENCH_UPLOAD_PERF_DATA = "If set and export_trace is true, the benchmark will upload performance data from the trace to Scuba"
 
 
 class TBEBenchmarkingConfigLoader:

@@ -115,6 +118,12 @@ class TBEBenchmarkingConfigLoader:
             default="{emb_op_type}_tbe_{phase}_trace_{ospid}.json",
             help=TBEBenchmarkingHelperText.BENCH_TRACE_URL.value,
         ),
+        click.option(
+            "--upload-perf-data",
+            is_flag=True,
+            default=False,
+            help=TBEBenchmarkingHelperText.BENCH_UPLOAD_PERF_DATA.value,
+        ),
     ]
 
     for option in reversed(options):

@@ -131,6 +140,7 @@ class TBEBenchmarkingConfigLoader:
         flush_gpu_cache_size = params["bench_flush_gpu_cache_size"]
         export_trace = params["bench_export_trace"]
         trace_url = params["bench_trace_url"]
+        upload_perf_data = params["upload_perf_data"]
 
         # Default the number of TBE requests to number of iterations specified
         num_requests = iterations if num_requests == -1 else num_requests

@@ -142,4 +152,5 @@ class TBEBenchmarkingConfigLoader:
             flush_gpu_cache_size,
             export_trace,
             trace_url,
+            upload_perf_data,
         ).validate()
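The new upload_perf_data field is surfaced as a boolean click flag that only has an effect when trace export is also enabled. The snippet below is a standalone illustration of that flag pattern, not the package's actual CLI; the command name and option wiring are assumptions made for the example.

```python
import click


@click.command()
@click.option("--export-trace", is_flag=True, default=False)
@click.option("--upload-perf-data", is_flag=True, default=False)
def bench(export_trace: bool, upload_perf_data: bool) -> None:
    # Mirrors the loader behavior sketched above: uploading perf data is only
    # meaningful when a trace is actually exported.
    effective_upload = export_trace and upload_perf_data
    click.echo(f"export_trace={export_trace}, upload_perf_data={effective_upload}")


if __name__ == "__main__":
    bench()
```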
fbgemm_gpu/tbe/bench/bench_runs.py

@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #

@@ -11,12 +12,13 @@ import statistics
 import threading
 import time
 from subprocess import Popen
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Optional
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.tbe.utils import b_indices, TBERequest
-
+from fbgemm_gpu.tbe.utils.common import get_device
 
 logging.basicConfig(level=logging.DEBUG)
 

@@ -43,6 +45,31 @@ def bench_warmup(
             out.backward(grad)
 
 
+def bench_warmup_with_spec(
+    request: TBERequest,
+    warmup_ms: int,
+    warmup_runs: int,
+    func: Callable[
+        [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]],
+        torch.Tensor,
+    ],
+    bwd_only: bool = False,
+    grad: Optional[torch.Tensor] = None,
+) -> None:
+    indices, offsets, weights, batch_size_per_feature_per_rank = request.unpack_4()
+    if warmup_ms:
+        start_time_ms = time.time() * 1000
+        while time.time() * 1000 - start_time_ms < warmup_ms:
+            out = func(indices, offsets, weights, batch_size_per_feature_per_rank)
+            if bwd_only:
+                out.backward(grad)
+    else:
+        for _ in range(warmup_runs):
+            out = func(indices, offsets, weights, batch_size_per_feature_per_rank)
+            if bwd_only:
+                out.backward(grad)
+
+
 class BMBarrier:
 
     def __init__(self) -> None:

@@ -66,7 +93,7 @@ cpu_bm_barrier = BMBarrier()
 
 
 def cpu_tbe_worker(
-    requests_: List[TBERequest],
+    requests_: list[TBERequest],
     func_: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
     use_barrier: bool = False,
 ) -> float:

@@ -98,7 +125,7 @@ def cpu_tbe_worker(
 
 
 def benchmark_cpu_requests_mp(
-    requests: List[TBERequest],
+    requests: list[TBERequest],
     emb_module: torch.nn.Module,
     num_warmups: int = 0,
     num_copies: int = 1,

@@ -127,6 +154,13 @@ def benchmark_cpu_requests_mp(
         float: The average runtime per iteration in seconds.
 
     """
+    import os
+
+    strategy = os.environ.get("PYTORCH_SHARE_STRATEGY")
+    current_strategy = torch.multiprocessing.get_sharing_strategy()
+    if strategy is not None and current_strategy != strategy:
+        torch.multiprocessing.set_sharing_strategy(strategy)
+
     cpu_bm_barrier.create_barrier(num_copies)
     worker_pool = torch.multiprocessing.Pool(num_copies)
 
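The new block at the top of benchmark_cpu_requests_mp lets the PYTORCH_SHARE_STRATEGY environment variable override torch.multiprocessing's tensor-sharing strategy before the worker pool is created. A minimal standalone sketch of the same hook follows; the env-var name comes from the diff, and "file_system" is one of the strategies PyTorch supports on Linux.

```python
import os

import torch

# Illustrative: ask for the alternate sharing strategy before spawning CPU workers.
os.environ["PYTORCH_SHARE_STRATEGY"] = "file_system"

strategy = os.environ.get("PYTORCH_SHARE_STRATEGY")
if strategy is not None and strategy != torch.multiprocessing.get_sharing_strategy():
    torch.multiprocessing.set_sharing_strategy(strategy)

print(torch.multiprocessing.get_sharing_strategy())  # file_system
```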
@@ -181,7 +215,7 @@
 
 
 def benchmark_cpu_requests(
-    requests: List[TBERequest],
+    requests: list[TBERequest],
     func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
     num_warmups: int = 0,
 ) -> float:

@@ -199,7 +233,7 @@
 
 
 def benchmark_requests(  # noqa: C901
-    requests: List[TBERequest],
+    requests: list[TBERequest],
     func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
     flush_gpu_cache_size_mb: int = 0,
     check_median: bool = False,

@@ -266,7 +300,7 @@ def benchmark_requests( # noqa: C901
                 _ = torch.rand(
                     flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                     dtype=torch.float,
-                    device="cuda",
+                    device=get_device(),
                 )
             start_events[it].record()
 
@@ -308,8 +342,123 @@ def benchmark_requests( # noqa: C901
     return median_time if check_median else avg_time
 
 
+def benchmark_requests_with_spec(  # noqa: C901
+    requests: list[TBERequest],
+    func: Callable[
+        [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]],
+        torch.Tensor,
+    ],
+    flush_gpu_cache_size_mb: int = 0,
+    check_median: bool = False,
+    num_warmups: int = 0,
+    bwd_only: bool = False,
+    grad: Optional[torch.Tensor] = None,
+    # Used to label benchmark iterations differently in nsys profile result
+    # so that we can compare performance of two different models for example.
+    # If empty string is provided, it won't have any effect.
+    nvtx_range: str = "",
+    # Can be used to clear model's stats after warmup for example.
+    callback_after_warmup: Optional[Callable[[], None]] = None,
+    periodic_logs: bool = False,
+    warmup_ms: Optional[int] = None,
+    iters: int = -1,
+) -> float:
+    times = []
+    # Run at least one warmup iteration to avoid the long cudaLaunchKernel time
+    # for the first kernel if warmup_ms > 0
+    # warmup_ms is prioritized over num_warmups
+
+    if warmup_ms is None:
+        num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
+
+    # warm-up the GPU before profiling
+    bench_warmup_with_spec(
+        requests[0],
+        # pyre-ignore[6]
+        warmup_ms,
+        num_warmups,
+        lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: func(
+            indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
+        ),
+        bwd_only=bwd_only,
+        grad=grad,
+    )
+
+    if callback_after_warmup is not None:
+        callback_after_warmup()
+
+    num_reqs = len(requests)
+    iters = num_reqs if iters == -1 else iters
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+    else:
+        start_events = []
+        end_events = []
+
+    for it in range(iters):
+        req = requests[it % num_reqs]
+
+        indices, offsets, weights, batch_size_per_feature_per_rank = req.unpack_4()
+        # logging.info(
+        #     f"[Benchmark Request] batch_size_per_feature_per_rank {batch_size_per_feature_per_rank} {indices.device}"
+        # )
+
+        if bwd_only:
+            # Run forward before profiling if does backward only
+            out = func(indices, offsets, weights, batch_size_per_feature_per_rank)
+        start_time = time.time()
+        if torch.cuda.is_available():
+            if flush_gpu_cache_size_mb:
+                _ = torch.rand(
+                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
+                    dtype=torch.float,
+                    device=get_device(),
+                )
+            start_events[it].record()
+
+        if nvtx_range:
+            torch.cuda.nvtx.range_push(f"{nvtx_range}-{it}")
+
+        if bwd_only:
+            out.backward(grad)
+        else:
+            func(indices, offsets, weights, batch_size_per_feature_per_rank)
+
+        if nvtx_range:
+            torch.cuda.nvtx.range_pop()
+
+        if torch.cuda.is_available():
+            end_events[it].record()
+        else:
+            it_time = time.time() - start_time
+            times.append(it_time)
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        times = [
+            start.elapsed_time(end) * 1.0e-3
+            for start, end in zip(start_events, end_events)
+        ]
+
+    if periodic_logs:
+        for it in range(100, iters + 1, 100):
+            times_ = times[0:it]
+            avg_time = sum(times_) / len(times_) * 1.0e6
+            last_100_avg = sum(times_[-100:]) / 100 * 1.0e6
+            logging.info(
+                f"Iteration [{it}/{len(requests)}]: Last 100: {last_100_avg:.2f} us, Running avg: {avg_time:.2f} us"
+            )
+
+    avg_time = sum(times) / iters
+    median_time = statistics.median(times)
+    return median_time if check_median else avg_time
+
+
 def benchmark_requests_refer(
-    requests: List[TBERequest],
+    requests: list[TBERequest],
     T: int,
     B: int,
     L: int,
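benchmark_requests_with_spec mirrors benchmark_requests but threads batch_size_per_feature_per_rank through to the benchmarked callable via TBERequest.unpack_4(). Below is a toy, CPU-only sketch of the calling convention; the dummy tensors, the positional TBERequest constructor, and the stand-in callable are assumptions for illustration (a real caller would pass a TBE module's forward).

```python
import torch

# Import paths assumed from the hunks above: __init__.py re-exports the function
# from fbgemm_gpu.tbe.bench, and TBERequest lives in fbgemm_gpu.tbe.utils.
from fbgemm_gpu.tbe.bench import benchmark_requests_with_spec
from fbgemm_gpu.tbe.utils import TBERequest

# Two tiny requests; fields beyond indices/offsets are left at their defaults.
requests = [
    TBERequest(torch.randint(0, 100, (32,)), torch.arange(0, 33, 4))
    for _ in range(2)
]


# Stand-in for a TBE forward: accepts the extra batch-size spec and returns a tensor.
def toy_func(indices, offsets, per_sample_weights, batch_size_per_feature_per_rank):
    return indices.float().sum().unsqueeze(0)


avg_s = benchmark_requests_with_spec(requests, toy_func, num_warmups=1, iters=4)
print(f"avg iteration time: {avg_s * 1e6:.1f} us")
```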
@@ -348,7 +497,7 @@ def benchmark_requests_refer(
             _ = torch.rand(
                 flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                 dtype=torch.float,
-                device="cuda",
+                device=get_device(),
             )
         torch.cuda.synchronize()
         start_event.record()

@@ -401,12 +550,12 @@ def benchmark_requests_refer(
 
 
 def benchmark_pipelined_requests(
-    requests: List[TBERequest],
+    requests: list[TBERequest],
     func1: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
     func2: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
     flush_gpu_cache_size_mb: int = 0,
     check_median: bool = False,
-) -> Tuple[float, float]:
+) -> tuple[float, float]:
     torch.cuda.synchronize()
     start_events = [
         (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))

@@ -422,7 +571,7 @@ def benchmark_pipelined_requests(
             _ = torch.rand(
                 flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                 dtype=torch.float,
-                device="cuda",
+                device=get_device(),
             )
         torch.cuda.synchronize()
         start_event[0].record()

@@ -458,10 +607,10 @@ def benchmark_pipelined_requests(
 
 
 def benchmark_vbe(
-    requests: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
+    requests: list[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
     func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
     num_warmups: int = 0,
-) -> Tuple[float, float]:
+) -> tuple[float, float]:
     """
     A benchmark function to return the average execution time in seconds of
     forward and backward of VBE kernels.
fbgemm_gpu/tbe/bench/benchmark_click_interface.py

@@ -8,11 +8,14 @@
 # pyre-strict
 
 import click
+
+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import BoundsCheckMode
 
-from .bench_config import TBEBenchmarkingHelperText
-from .tbe_data_config_loader import TBEDataConfigHelperText
+# fmt:skip
+from .bench_config import TBEBenchmarkingHelperText  # usort:skip
+from .tbe_data_config_loader import TBEDataConfigHelperText  # usort:skip
 
 
 class TbeBenchClickInterface:
fbgemm_gpu/tbe/bench/eeg_cli.py

@@ -6,11 +6,11 @@
 
 # pyre-strict
 
-from typing import List, Tuple
 
 import click
 import torch
 
+# fmt:skip
 from fbgemm_gpu.tbe.bench import IndicesParams
 
 

@@ -82,7 +82,7 @@ def estimate(indices: str) -> None:
 )
 def generate(
     hitters: str,
-    zipf: Tuple[float, float],
+    zipf: tuple[float, float],
     max_index: int,
     num_indices: int,
     output: str,

@@ -114,7 +114,7 @@ def generate(
     assert output != "", "Output file path must be provided"
 
     try:
-        _hitters: List[float] = (
+        _hitters: list[float] = (
             [float(x) for x in hitters.split(",")] if hitters else []
         )
     except Exception as e:
fbgemm_gpu/tbe/bench/embedding_ops_common_config.py

@@ -8,11 +8,12 @@
 # pyre-strict
 
 import dataclasses
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import click
 import torch
 
+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,

@@ -44,7 +45,7 @@ class EmbeddingOpsCommonConfig:
     def validate(self):
         return self
 
-    def split_args(self) -> Dict[str, Any]:
+    def split_args(self) -> dict[str, Any]:
         return {
             "weights_precision": self.weights_dtype,
             "stochastic_rounding": self.stochastic_rounding,
fbgemm_gpu/tbe/bench/eval_compression.py

@@ -10,7 +10,7 @@
 import logging
 import statistics
 from dataclasses import dataclass
-from typing import Callable, List, Tuple
+from typing import Callable
 
 import torch
 

@@ -29,8 +29,8 @@ class EvalCompressionBenchmarkOutput:
 
 
 def benchmark_eval_compression(
-    baseline_requests: List[Tuple[torch.Tensor, torch.Tensor]],
-    compressed_requests: List[Tuple[torch.Tensor, torch.Tensor]],
+    baseline_requests: list[tuple[torch.Tensor, torch.Tensor]],
+    compressed_requests: list[tuple[torch.Tensor, torch.Tensor]],
     baseline_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
     compressed_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
     reindex: torch.Tensor,