fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py

@@ -9,7 +9,7 @@
  from __future__ import annotations

  import functools
- from typing import List, Optional, Union
+ from typing import Optional, Union

  import torch

@@ -191,7 +191,6 @@ class PartiallyMaterializedTensor:
          layout_str_mapping = {
              "SparseCsr": "sparse_csr",
              "Strided": "strided",
-             "SparseCsr": "sparse_csr",
              "SparseCsc": "sparse_csc",
              "Jagged": "jagged",
          }
@@ -249,7 +248,7 @@ class PartiallyMaterializedTensor:

          return torch.equal(tensor1.full_tensor(), tensor2.full_tensor())

-     def get_kvtensor_serializable_metadata(self) -> List[str]:
+     def get_kvtensor_serializable_metadata(self) -> list[str]:
          return self._wrapped.get_kvtensor_serializable_metadata()

      def __hash__(self):
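Note on the second hunk above: the removed line was a duplicate dict key. Python dict literals silently keep only the last occurrence of a repeated key, so dropping the extra "SparseCsr" entry is dead-code cleanup rather than a behavior change. A minimal illustration (not FBGEMM code):

    # Duplicate keys in a dict literal collapse silently; the last one wins.
    layout = {"SparseCsr": "a", "Strided": "b", "SparseCsr": "a"}
    assert len(layout) == 2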
fbgemm_gpu/tbe/stats/bench_params_reporter.py

@@ -8,31 +8,37 @@
  # pyre-strict

  import io
+ import json
  import logging
  import os
- from typing import List, Optional
+ from typing import List, Optional, Tuple

  import fbgemm_gpu  # noqa F401
- import numpy as np  # usort:skip
  import torch  # usort:skip

- from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
-     SplitTableBatchedEmbeddingBagsCodegen,
- )
- from fbgemm_gpu.tbe.bench import (
+ from fbgemm_gpu.tbe.bench.tbe_data_config import (
      BatchParams,
      IndicesParams,
      PoolingParams,
      TBEDataConfig,
  )

- # pyre-ignore[16]
+ open_source: bool = False
+ # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
  open_source: bool = getattr(fbgemm_gpu, "open_source", False)

  if open_source:
      from fbgemm_gpu.utils import FileStore
+
  else:
-     from fbgemm_gpu.fb.utils import FileStore
+     try:
+         from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore
+
+         torch.ops.load_library(
+             "//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_estimator"
+         )
+     except Exception:
+         pass


  class TBEBenchmarkParamsReporter:
@@ -43,7 +49,8 @@ class TBEBenchmarkParamsReporter:
      def __init__(
          self,
          report_interval: int,
-         report_once: bool = False,
+         report_iter_start: int = 0,
+         report_iter_end: int = -1,
          bucket: Optional[str] = None,
          path_prefix: Optional[str] = None,
      ) -> None:
@@ -52,13 +59,31 @@

          Args:
              report_interval (int): The interval at which reports are generated.
-             report_once (bool, optional): If True, reporting occurs only once. Defaults to False.
+             report_iter_start (int): The start of the iteration range to capture. Defaults to 0.
+             report_iter_end (int): The end of the iteration range to capture. Defaults to -1 (last iteration).
              bucket (Optional[str], optional): The storage bucket for reports. Defaults to None.
              path_prefix (Optional[str], optional): The path prefix for report storage. Defaults to None.
          """
+
+         assert report_interval > 0, "report_interval must be greater than 0"
+         assert (
+             report_iter_start >= 0
+         ), "report_iter_start must be greater than or equal to 0"
+         assert (
+             report_iter_end >= -1
+         ), "report_iter_end must be greater than or equal to -1"
+         assert (
+             report_iter_end == -1 or report_iter_start <= report_iter_end
+         ), "report_iter_start must be less than or equal to report_iter_end"
+
          self.report_interval = report_interval
-         self.report_once = report_once
-         self.has_reported = False
+         self.report_iter_start = report_iter_start
+         self.report_iter_end = report_iter_end
+
+         if path_prefix is not None and path_prefix.endswith("/"):
+             path_prefix = path_prefix[:-1]
+
+         self.path_prefix = path_prefix

          default_bucket = "/tmp" if open_source else "tlparse_reports"
          bucket = (
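A minimal usage sketch of the reworked constructor; the argument values below are illustrative, not defaults:

    # Capture every 100th iteration between iterations 1000 and 5000 (inclusive).
    reporter = TBEBenchmarkParamsReporter(
        report_interval=100,
        report_iter_start=1000,
        report_iter_end=5000,
    )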
@@ -68,22 +93,83 @@
          )
          self.filestore = FileStore(bucket)

+         if self.path_prefix is not None and not self.filestore.exists(self.path_prefix):
+             self.filestore.create_directory(self.path_prefix)
+
          self.logger: logging.Logger = logging.getLogger(__name__)
          self.logger.setLevel(logging.INFO)

+     @classmethod
+     def create(cls) -> "TBEBenchmarkParamsReporter":
+         """
+         This method returns an instance of TBEBenchmarkParamsReporter based on environment variables.
+
+         If the `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` environment variable is set to a value greater than 0, it creates an instance that:
+         - Reports input parameters (TBEDataConfig).
+         - Writes the output as a JSON file.
+
+         Additionally, the following environment variables are considered:
+         - `FBGEMM_REPORT_INPUT_PARAMS_ITER_START`: Specifies the start of the iteration range to capture.
+         - `FBGEMM_REPORT_INPUT_PARAMS_ITER_END`: Specifies the end of the iteration range to capture.
+         - `FBGEMM_REPORT_INPUT_PARAMS_BUCKET`: Specifies the bucket for reporting.
+         - `FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX`: Specifies the path prefix for reporting.
+
+         Returns:
+             TBEBenchmarkParamsReporter: An instance configured based on the environment variables.
+         """
+         report_interval = int(
+             os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "1")
+         )
+         report_iter_start = int(
+             os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0")
+         )
+         report_iter_end = int(
+             os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1")
+         )
+         bucket = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_BUCKET", "")
+         path_prefix = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX", "")
+
+         return cls(
+             report_interval=report_interval,
+             report_iter_start=report_iter_start,
+             report_iter_end=report_iter_end,
+             bucket=bucket,
+             path_prefix=path_prefix,
+         )
+
+     def extract_Ls(
+         self,
+         bag_sizes: List[int],
+         Bs: List[int],
+     ) -> List[float]:
+         Ls = []
+         start = 0
+         for b in Bs:
+             end = start + b
+             avg_L = sum(bag_sizes[start:end]) / b if b > 0 else 0
+             start = end
+             Ls.append(avg_L)
+         return Ls
+
      def extract_params(
          self,
-         embedding_op: SplitTableBatchedEmbeddingBagsCodegen,
+         feature_rows: torch.Tensor,
+         feature_dims: torch.Tensor,
          indices: torch.Tensor,
          offsets: torch.Tensor,
          per_sample_weights: Optional[torch.Tensor] = None,
          batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+         Es: Optional[List[int]] = None,
+         Ds: Optional[List[int]] = None,
+         embedding_specs: Optional[List[Tuple[int, int]]] = None,
+         feature_table_map: Optional[List[int]] = None,
      ) -> TBEDataConfig:
          """
-         Extracts parameters from the embedding operation, input indices and offsets to create a TBEDataConfig.
+         Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig.

          Args:
-             embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation.
+             feature_rows (torch.Tensor): Number of rows in each feature.
+             feature_dims (torch.Tensor): Number of dimensions in each feature.
              indices (torch.Tensor): The input indices tensor.
              offsets (torch.Tensor): The input offsets tensor.
              per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
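A hedged sketch of the new create() factory driven by the environment variables named in its docstring, plus what extract_Ls computes; all values are illustrative:

    import os

    # Report every 50th iteration from iteration 200 onward (-1 = no upper bound).
    os.environ["FBGEMM_REPORT_INPUT_PARAMS_INTERVAL"] = "50"
    os.environ["FBGEMM_REPORT_INPUT_PARAMS_ITER_START"] = "200"
    os.environ["FBGEMM_REPORT_INPUT_PARAMS_ITER_END"] = "-1"
    reporter = TBEBenchmarkParamsReporter.create()

    # extract_Ls averages bag sizes per feature: with Bs=[2, 3], the first two
    # bag sizes belong to feature 0 and the remaining three to feature 1.
    reporter.extract_Ls(bag_sizes=[4, 6, 1, 2, 3], Bs=[2, 3])  # -> [5.0, 2.0]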
@@ -92,24 +178,37 @@
          Returns:
              TBEDataConfig: The configuration data for TBE benchmarking.
          """
+
+         Es = feature_rows.tolist()
+         Ds = feature_dims.tolist()
+
+         assert len(Es) == len(
+             Ds
+         ), "feature_rows and feature_dims must have the same length"
+
          # Transfer indices back to CPU for EEG analysis
          indices_cpu = indices.cpu()

-         # Extract embedding table specs
-         embedding_specs = [
-             embedding_op.embedding_specs[t] for t in embedding_op.feature_table_map
-         ]
-         rowcounts = [embedding_spec[0] for embedding_spec in embedding_specs]
-         dims = [embedding_spec[1] for embedding_spec in embedding_specs]
-
          # Set T to be the number of features we are looking at
-         T = len(embedding_op.feature_table_map)
+         T = len(Ds)
          # Set E to be the mean of the rowcounts to avoid biasing
-         E = rowcounts[0] if len(set(rowcounts)) == 1 else np.ceil((np.mean(rowcounts)))
+         E = (
+             Es[0]
+             if len(set(Es)) == 1
+             else torch.ceil(
+                 torch.mean(torch.tensor(feature_rows, dtype=torch.float))
+             ).item()
+         )
          # Set mixed_dim to be True if there are multiple dims
-         mixed_dim = len(set(dims)) > 1
+         mixed_dim = len(set(Ds)) > 1
          # Set D to be the mean of the dims to avoid biasing
-         D = dims[0] if not mixed_dim else np.ceil((np.mean(dims)))
+         D = (
+             Ds[0]
+             if not mixed_dim
+             else torch.ceil(
+                 torch.mean(torch.tensor(feature_dims, dtype=torch.float))
+             ).item()
+         )

          # Compute indices distribution parameters
          heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution(
@@ -120,11 +219,27 @@
          )

          # Compute batch parameters
+         B = int((offsets.numel() - 1) // T)
+         Bs = (
+             [sum(b_per_rank) for b_per_rank in batch_size_per_feature_per_rank]
+             if batch_size_per_feature_per_rank
+             else [B] * T
+         )
          batch_params = BatchParams(
-             B=((offsets.numel() - 1) // T),
+             B=B,
              sigma_B=(
-                 np.ceil(
-                     np.std([b for bs in batch_size_per_feature_per_rank for b in bs])
+                 int(
+                     torch.ceil(
+                         torch.std(
+                             torch.tensor(
+                                 [
+                                     b
+                                     for bs in batch_size_per_feature_per_rank
+                                     for b in bs
+                                 ]
+                             ).float()
+                         )
+                     )
                  )
                  if batch_size_per_feature_per_rank
                  else None
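For reference, the new Bs value collapses the per-rank batch sizes into one total per feature. A worked example with illustrative numbers:

    # Two features, each split across two ranks.
    batch_size_per_feature_per_rank = [[2, 3], [4, 1]]
    Bs = [sum(b_per_rank) for b_per_rank in batch_size_per_feature_per_rank]
    # Bs == [5, 5]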
@@ -135,15 +250,28 @@
                  if batch_size_per_feature_per_rank
                  else None
              ),
+             Bs=Bs,
          )

          # Compute pooling parameters
-         bag_sizes = (offsets[1:] - offsets[:-1]).tolist()
+         bag_sizes = offsets[1:] - offsets[:-1]
+         if batch_size_per_feature_per_rank is None:
+             _B = int(bag_sizes.numel() // T)
+             assert _B == Bs[0], f"Expected constant batch size {Bs[0]} but got {_B}"
          mixed_bag_sizes = len(set(bag_sizes)) > 1
          pooling_params = PoolingParams(
-             L=np.ceil(np.mean(bag_sizes)) if mixed_bag_sizes else bag_sizes[0],
-             sigma_L=(np.ceil(np.std(bag_sizes)) if mixed_bag_sizes else None),
+             L=(
+                 int(torch.ceil(torch.mean(bag_sizes.float())))
+                 if mixed_bag_sizes
+                 else int(bag_sizes[0])
+             ),
+             sigma_L=(
+                 int(torch.ceil(torch.std(bag_sizes.float())))
+                 if mixed_bag_sizes
+                 else None
+             ),
              length_distribution=("normal" if mixed_bag_sizes else None),
+             Ls=self.extract_Ls(bag_sizes.tolist(), Bs),
          )

          return TBEDataConfig(
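The numpy-to-torch migration preserves the statistics; note that torch.std computes the unbiased sample standard deviation by default. A quick check of the mixed-bag-size branch:

    import torch

    bag_sizes = torch.tensor([2, 4, 6, 8])
    L = int(torch.ceil(torch.mean(bag_sizes.float())))       # ceil(5.0)   -> 5
    sigma_L = int(torch.ceil(torch.std(bag_sizes.float())))  # ceil(2.582) -> 3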
@@ -156,38 +284,66 @@
              indices_params=indices_params,
              pooling_params=pooling_params,
              use_cpu=(not torch.cuda.is_available()),
+             Es=Es,
+             Ds=Ds,
+             embedding_specs=embedding_specs,
+             feature_table_map=feature_table_map,
          )

      def report_stats(
          self,
-         embedding_op: SplitTableBatchedEmbeddingBagsCodegen,
+         feature_rows: torch.Tensor,
+         feature_dims: torch.Tensor,
+         iteration: int,
          indices: torch.Tensor,
          offsets: torch.Tensor,
+         op_id: str = "",
          per_sample_weights: Optional[torch.Tensor] = None,
          batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+         embedding_specs: Optional[List[Tuple[int, int]]] = None,
+         feature_table_map: Optional[List[int]] = None,
      ) -> None:
          """
-         Reports the configuration of the embedding operation and input data then writes the TBE configuration to the filestore.
+         Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore.

          Args:
-             embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation.
+             feature_rows (torch.Tensor): Number of rows in each feature.
+             feature_dims (torch.Tensor): Number of dimensions in each feature.
+             iteration (int): The current iteration number.
              indices (torch.Tensor): The input indices tensor.
              offsets (torch.Tensor): The input offsets tensor.
+             op_id (str, optional): The operation identifier. Defaults to an empty string.
              per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
              batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None.
+             embedding_specs (Optional[List[Tuple[int, int]]]): Embedding specs. Defaults to None.
+             feature_table_map (Optional[List[int]], optional): Feature table map. Defaults to None.
          """
-         if embedding_op.iter.item() % self.report_interval == 0 and (
-             not self.report_once or (self.report_once and not self.has_reported)
+         if (
+             (iteration - self.report_iter_start) % self.report_interval == 0
+             and (iteration >= self.report_iter_start)
+             and (self.report_iter_end == -1 or iteration <= self.report_iter_end)
          ):
+             # If indices tensor is empty (indices.numel() == 0), skip reporting
+             # TODO: Remove this once we have a better way to handle empty indices tensors
+             if indices.numel() == 0:
+                 return
+
              # Extract TBE config
              config = self.extract_params(
-                 embedding_op, indices, offsets, per_sample_weights
+                 feature_rows=feature_rows,
+                 feature_dims=feature_dims,
+                 indices=indices,
+                 offsets=offsets,
+                 per_sample_weights=per_sample_weights,
+                 batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                 Es=feature_rows.tolist(),
+                 Ds=feature_dims.tolist(),
+                 embedding_specs=embedding_specs,
+                 feature_table_map=feature_table_map,
              )

              # Write the TBE config to FileStore
              self.filestore.write(
-                 f"tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter.item()}.json",
-                 io.BytesIO(config.json(format=True).encode()),
+                 f"{self.path_prefix}/tbe-{op_id}-config-estimation-{iteration}.json",
+                 io.BytesIO(json.dumps(config.dict(), indent=2).encode()),
              )
-
-         self.has_reported = True
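The report_once/has_reported flags are replaced by a gate that is a pure function of the iteration number. A standalone sketch of the new condition (the helper name is hypothetical):

    def should_report(iteration: int, interval: int, start: int, end: int) -> bool:
        # Mirrors the condition in report_stats: aligned to `start`, every
        # `interval` iterations, up to `end` (-1 means no upper bound).
        return (
            (iteration - start) % interval == 0
            and iteration >= start
            and (end == -1 or iteration <= end)
        )

    # interval=100, start=200, end=-1 reports iterations 200, 300, 400, ...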
fbgemm_gpu/tbe/utils/offsets.py

@@ -6,7 +6,7 @@

  # pyre-strict

- from typing import Callable, Optional, Tuple
+ from typing import Callable, Optional

  import numpy as np
  import torch
@@ -21,9 +21,9 @@ def get_table_batched_offsets_from_dense(
      L: Optional[int] = None,
      total_B: Optional[int] = None,
      use_cpu: bool = False,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor]:
      if L is None and total_B is None:
-         (T, B, L) = merged_indices.size()
+         T, B, L = merged_indices.size()
          total_B = T * B
      # pyre-fixme[6]: For 1st argument expected `Union[Sequence[SupportsIndex],
      #  SupportsIndex]` but got `Optional[int]`.
@@ -37,8 +37,8 @@
      )


- def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-     (B, L) = indices.size()
+ def get_offsets_from_dense(indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     B, L = indices.size()
      return (
          indices.contiguous().view(-1),
          torch.tensor(
@@ -54,7 +54,7 @@ def b_indices(
      use_cpu: bool = False,
      do_pooling: bool = True,
  ) -> torch.Tensor:
-     (indices, offsets) = get_offsets_from_dense(x)
+     indices, offsets = get_offsets_from_dense(x)
      if do_pooling:
          return b(
              to_device(indices, use_cpu),
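These hunks (and the similar ones in the files below) replace typing.Tuple/typing.List with the PEP 585 builtin generics available since Python 3.9, e.g.:

    import torch

    # Before: def shape_of(t: torch.Tensor) -> Tuple[int, ...]
    def shape_of(t: torch.Tensor) -> tuple[int, ...]:
        return tuple(t.shape)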
fbgemm_gpu/tbe/utils/quantize.py

@@ -7,7 +7,7 @@
  # pyre-strict
  # pyre-ignore-all-errors[61]

- from typing import Optional, Tuple
+ from typing import Optional

  import torch

@@ -22,7 +22,7 @@ def quantize_embs(
      weight: torch.Tensor,
      weight_ty: SparseType,
      fp8_config: Optional[FP8QuantizationConfig] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
      weight = weight.detach()
      if weight_ty == SparseType.FP32:
          q_weight = weight.float()
@@ -91,7 +91,7 @@ def dequantize_embs(
      th_scale_shift: torch.Tensor = scale_shift.view(torch.float16).to(torch.float32)

      if weight_ty == SparseType.INT4:
-         (E, D_2) = th_weights.shape
+         E, D_2 = th_weights.shape
          D = D_2 * 2

          def comp(i: int) -> torch.Tensor:
@@ -109,7 +109,7 @@
          return to_device(torch.tensor(comps), use_cpu)

      elif weight_ty == SparseType.INT2:
-         (E, D_4) = th_weights.shape
+         E, D_4 = th_weights.shape
          D = D_4 * 4

          # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
@@ -129,7 +129,7 @@
          return to_device(torch.tensor(comps), use_cpu)

      elif weight_ty == SparseType.INT8:
-         (E, D) = th_weights.shape
+         E, D = th_weights.shape
          comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
              torch.float32
          ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
@@ -177,7 +177,7 @@ def fake_quantize_embs(
      )

      if weight_ty == SparseType.INT4:
-         (E, D_2) = th_weights.shape
+         E, D_2 = th_weights.shape
          D = D_2 * 2

          def comp(i: int) -> torch.Tensor:
@@ -195,7 +195,7 @@
          dequant_weights.copy_(to_device(comps, use_cpu))

      elif weight_ty == SparseType.INT2:
-         (E, D_4) = th_weights.shape
+         E, D_4 = th_weights.shape
          D = D_4 * 4

          # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
@@ -215,7 +215,7 @@
          dequant_weights.copy_(to_device(comps, use_cpu))

      elif weight_ty == SparseType.INT8:
-         (E, D) = th_weights.shape
+         E, D = th_weights.shape
          comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
              torch.float32
          ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
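For context on the row-wise scale/shift pattern these hunks touch, a minimal INT8 dequantization sketch; shapes and values are illustrative, not FBGEMM's API:

    import torch

    E, D = 4, 8
    th_weights = torch.randint(0, 256, (E, D), dtype=torch.uint8)
    th_scale_shift = torch.rand(E, 2)  # per-row (scale, shift) pairs

    # Each row dequantizes as weight * scale + shift, broadcast across D columns.
    comps = (
        th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1)
        + th_scale_shift[:, 1].reshape(-1, 1)
    )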
fbgemm_gpu/tbe/utils/requests.py

@@ -8,7 +8,7 @@

  import logging
  from dataclasses import dataclass
- from typing import List, Optional, Tuple
+ from typing import Optional

  import numpy as np
  import numpy.typing as npt
@@ -32,20 +32,20 @@ class TBERequest:
      indices: torch.Tensor
      offsets: torch.Tensor
      per_sample_weights: Optional[torch.Tensor] = None
-     Bs_per_feature_per_rank: Optional[List[List[int]]] = None
+     Bs_per_feature_per_rank: Optional[list[list[int]]] = None

-     def unpack_2(self) -> Tuple[torch.Tensor, torch.Tensor]:
+     def unpack_2(self) -> tuple[torch.Tensor, torch.Tensor]:
          return (self.indices, self.offsets)

      def unpack_3(
          self,
-     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
          return (self.indices, self.offsets, self.per_sample_weights)

      def unpack_4(
          self,
-     ) -> Tuple[
-         torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[List[List[int]]]
+     ) -> tuple[
+         torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]
      ]:
          return (
              self.indices,
@@ -68,7 +68,7 @@ def generate_requests_from_data_file(
      tables: Optional[str] = None,
      index_dtype: Optional[torch.dtype] = None,
      offset_dtype: Optional[torch.dtype] = None,
- ) -> List[TBERequest]:
+ ) -> list[TBERequest]:
      """
      Generate TBE requests from the input data file. If `requests_data_file` is provided,
      `indices_file` and `offsets_file` should not be provided. If either `indices_file`
@@ -178,12 +178,12 @@ def generate_int_data_from_stats(

  def generate_pooling_factors_from_stats(
      iters: int,
-     Bs: List[int],
+     Bs: list[int],
      L: int,
      sigma_L: int,
      # distribution of pooling factors
      length_dist: str,
- ) -> Tuple[int, torch.Tensor]:
+ ) -> tuple[int, torch.Tensor]:
      """
      Generate pooling factors for the TBE requests from the given stats
      """
@@ -211,7 +211,7 @@ def generate_batch_sizes_from_stats(
      vbe_num_ranks: int,
      # Distribution of batch sizes
      batch_size_dist: str,
- ) -> Tuple[List[int], List[List[int]]]:
+ ) -> tuple[list[int], list[list[int]]]:
      """
      Generate batch sizes for features from the given stats
      """
@@ -234,7 +234,7 @@

  def generate_indices_uniform(
      iters: int,
-     Bs: List[int],
+     Bs: list[int],
      L: int,
      E: int,
      use_variable_L: bool,
@@ -252,7 +252,7 @@
          dtype=torch.int32,
      )
      # each bag is usually sorted
-     (indices, _) = torch.sort(indices)
+     indices, _ = torch.sort(indices)
      if use_variable_L:
          # 1D layout, where row offsets are determined by L_offsets
          indices = torch.ops.fbgemm.bottom_k_per_row(
@@ -267,7 +267,7 @@

  def generate_indices_zipf(
      iters: int,
-     Bs: List[int],
+     Bs: list[int],
      L: int,
      E: int,
      alpha: float,
@@ -324,7 +324,7 @@

  def update_indices_with_random_reuse(
      iters: int,
-     Bs: List[int],
+     Bs: list[int],
      L: int,
      reuse: float,
      indices: torch.Tensor,
@@ -411,7 +411,7 @@ def generate_requests(  # noqa C901
      vbe_num_ranks: Optional[int] = None,
      index_dtype: Optional[torch.dtype] = None,
      offset_dtype: Optional[torch.dtype] = None,
- ) -> List[TBERequest]:
+ ) -> list[TBERequest]:
      # TODO: refactor and split into helper functions to separate load from file,
      # generate from distribution, and other future methods of generating data
      if (
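Usage of the TBERequest unpack helpers is unchanged by the annotation updates; a small sketch with illustrative tensors:

    import torch

    req = TBERequest(
        indices=torch.tensor([1, 5, 2]),
        offsets=torch.tensor([0, 2, 3]),
    )
    indices, offsets = req.unpack_2()
    indices, offsets, per_sample_weights = req.unpack_3()  # weights is None here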
fbgemm_gpu/tbe_input_multiplexer.py

@@ -8,9 +8,8 @@
  # pyre-unsafe

  import abc
-
  from dataclasses import dataclass
- from typing import List, Optional
+ from typing import Optional

  from torch import Tensor

@@ -32,15 +31,15 @@ class TBEInfo:
          col_offset: the shard offset of the current rank on column (dim)
      """

-     table_names: List[str]
-     table_heights: List[int]
+     table_names: list[str]
+     table_heights: list[int]
      tbe_uuid: str
-     feature_table_map: List[int]
-     table_dims: List[int]
-     full_table_heights: List[int]
-     full_table_dims: List[int]
-     row_offset: List[int]
-     col_offset: List[int]
+     feature_table_map: list[int]
+     table_dims: list[int]
+     full_table_heights: list[int]
+     full_table_dims: list[int]
+     row_offset: list[int]
+     col_offset: list[int]


  @dataclass(frozen=True)
@@ -55,7 +54,7 @@ class TBEInputInfo:

      indices: Tensor
      offsets: Tensor
-     batch_size_per_feature_per_rank: Optional[List[List[int]]] = None
+     batch_size_per_feature_per_rank: Optional[list[list[int]]] = None


  class TBEInputMultiplexer(abc.ABC):
fbgemm_gpu/triton/common.py

@@ -10,7 +10,6 @@ from enum import IntEnum

  import torch

-
  # We keep LUTs persistent to minimize the number of device copies required.
  E2M1_LUT = torch.tensor(
      [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6],
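The 16 LUT entries are exactly the values representable in the FP4 E2M1 format (1 sign bit, 2 exponent bits, 1 mantissa bit), so dequantization reduces to a table lookup by the 4-bit code. A minimal sketch:

    import torch

    E2M1_LUT = torch.tensor(
        [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
    )
    codes = torch.tensor([0, 5, 9, 15])  # unpacked 4-bit codes
    values = E2M1_LUT[codes]             # tensor([ 0.0000, 3.0000, -0.5000, -6.0000])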