fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -33,7 +33,7 @@ class PartiallyMaterializedTensor:
33
33
  or use `full_tensor()` to get the full tensor (this could OOM).
34
34
  """
35
35
 
36
- def __init__(self, wrapped) -> None:
36
+ def __init__(self, wrapped, is_virtual: bool = False) -> None:
37
37
  """
38
38
  Ensure caller loads the module before creating this object.
39
39
 
@@ -48,6 +48,7 @@ class PartiallyMaterializedTensor:
48
48
  wrapped: torch.classes.fbgemm.KVTensorWrapper
49
49
  """
50
50
  self._wrapped = wrapped
51
+ self._is_virtual = is_virtual
51
52
  self._requires_grad = False
52
53
 
53
54
  @property
@@ -57,6 +58,17 @@ class PartiallyMaterializedTensor:
57
58
  """
58
59
  return self._wrapped
59
60
 
61
@property
def is_virtual(self) -> bool:
    """
    Indicate whether this PMT wraps a virtual tensor.

    Checkpoint and publish flows need this flag to tell a PMT backing a
    kvzch table apart from one backing a normal embedding table: for kvzch,
    checkpoint and publish must run an all-gather to recalculate the correct
    metadata of the ShardedTensor.

    Returns:
        bool: The ``is_virtual`` value passed to the constructor.
    """
    return self._is_virtual
71
+
60
72
  @classmethod
61
73
  def __torch_function__(cls, func, types, args=(), kwargs=None):
62
74
  if kwargs is None:
@@ -75,6 +87,18 @@ class PartiallyMaterializedTensor:
75
87
  """
76
88
  return self._wrapped.narrow(dim, start, length)
77
89
 
90
def set_weights_and_ids(self, weights: torch.Tensor, ids: torch.Tensor) -> None:
    """
    Forward ``weights`` and ``ids`` to the wrapped object's
    ``set_weights_and_ids``; whatever that call does is the whole effect.
    """
    wrapped = self._wrapped
    wrapped.set_weights_and_ids(weights, ids)
92
+
93
def get_weights_by_ids(self, ids: torch.Tensor) -> torch.Tensor:
    """
    Look up weights for ``ids`` by delegating to the wrapped object's
    ``get_weights_by_ids`` and returning its result unchanged.
    """
    lookup = self._wrapped.get_weights_by_ids
    return lookup(ids)
95
+
96
def __reduce__(self):
    """
    Pickle support: rebuild the PMT from its wrapped object.

    ``is_virtual`` is passed along too; otherwise an unpickled PMT would
    fall back to the constructor default (``False``) and checkpoint/publish
    code would treat a virtual (kvzch) table as a normal one.
    """
    return (
        PartiallyMaterializedTensor,
        (self._wrapped, self._is_virtual),
    )
101
+
78
102
  def full_tensor(self) -> torch.Tensor:
79
103
  """
80
104
  This loads the full tensor into memory (may OOM).
@@ -141,6 +165,8 @@ class PartiallyMaterializedTensor:
141
165
 
142
166
  @property
143
167
  def dtype(self) -> torch.dtype:
168
+ if isinstance(self._wrapped, torch.Tensor):
169
+ return self._wrapped.dtype
144
170
  mapping = {"c10::Half": "half"}
145
171
  dtype_str: str = self._wrapped.dtype_str
146
172
  dtype_str = mapping.get(dtype_str, dtype_str)
@@ -151,6 +177,8 @@ class PartiallyMaterializedTensor:
151
177
 
152
178
  @property
153
179
  def device(self) -> torch.device:
180
+ if isinstance(self._wrapped, torch.Tensor):
181
+ return self._wrapped.device
154
182
  device_str: str = self._wrapped.device_str
155
183
  device = torch.device(device_str)
156
184
  assert isinstance(device, torch.device)
@@ -158,11 +186,11 @@ class PartiallyMaterializedTensor:
158
186
 
159
187
  @property
160
188
  def layout(self) -> torch.layout:
161
- pass
189
+ if isinstance(self._wrapped, torch.Tensor):
190
+ return self._wrapped.layout
162
191
  layout_str_mapping = {
163
192
  "SparseCsr": "sparse_csr",
164
193
  "Strided": "strided",
165
- "SparseCsr": "sparse_csr",
166
194
  "SparseCsc": "sparse_csc",
167
195
  "Jagged": "jagged",
168
196
  }
@@ -220,6 +248,9 @@ class PartiallyMaterializedTensor:
220
248
 
221
249
  return torch.equal(tensor1.full_tensor(), tensor2.full_tensor())
222
250
 
251
def get_kvtensor_serializable_metadata(self) -> list[str]:
    """
    Return the wrapped object's ``get_kvtensor_serializable_metadata()``
    result unchanged.
    """
    metadata = self._wrapped.get_kvtensor_serializable_metadata()
    return metadata
253
+
223
254
  def __hash__(self):
224
255
  return id(self)
225
256
 
@@ -0,0 +1,10 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+
10
+ from .bench_params_reporter import TBEBenchmarkParamsReporter # noqa F401
@@ -0,0 +1,349 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+
10
+ import io
11
+ import json
12
+ import logging
13
+ import os
14
+ from typing import List, Optional, Tuple
15
+
16
+ import fbgemm_gpu # noqa F401
17
+ import torch # usort:skip
18
+
19
+ from fbgemm_gpu.tbe.bench.tbe_data_config import (
20
+ BatchParams,
21
+ IndicesParams,
22
+ PoolingParams,
23
+ TBEDataConfig,
24
+ )
25
+
26
# Whether this is the open-source build; internal builds lack the attribute.
# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    from fbgemm_gpu.utils import FileStore

else:
    try:
        from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore

        torch.ops.load_library(
            "//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_estimator"
        )
    except Exception as e:
        # Best effort: keep the module importable even when the internal
        # FileStore or the indices-estimator op is unavailable, but make the
        # failure visible -- TBEBenchmarkParamsReporter needs both and would
        # otherwise fail later with an unexplained NameError / missing op.
        logging.getLogger(__name__).warning(
            "TBEBenchmarkParamsReporter dependencies failed to load: %s", e
        )
42
+
43
+
44
class TBEBenchmarkParamsReporter:
    """
    TBEBenchmarkParamsReporter is responsible for extracting and reporting the configuration data of TBE processes.

    On iterations selected by ``report_interval`` / ``report_iter_start`` /
    ``report_iter_end``, ``report_stats`` summarizes the table shapes and the
    indices/offsets of a TBE call into a ``TBEDataConfig`` and writes it as a
    JSON file to a ``FileStore``.
    """

    def __init__(
        self,
        report_interval: int,
        report_iter_start: int = 0,
        report_iter_end: int = -1,
        bucket: Optional[str] = None,
        path_prefix: Optional[str] = None,
    ) -> None:
        """
        Initializes the TBEBenchmarkParamsReporter with the specified parameters.

        Args:
            report_interval (int): Report every ``report_interval`` iterations,
                counted from ``report_iter_start``. Must be > 0.
            report_iter_start (int): The start of the iteration range to capture. Defaults to 0.
            report_iter_end (int): The inclusive end of the iteration range to capture.
                Defaults to -1 (no upper bound).
            bucket (Optional[str], optional): The storage bucket for reports. When None or
                empty, the ``FBGEMM_TBE_REPORTING_BUCKET`` environment variable is used,
                falling back to "/tmp" (open source) or "tlparse_reports". Defaults to None.
            path_prefix (Optional[str], optional): Directory inside the bucket for report
                files. Trailing "/" are stripped; None or "" means the bucket root.
                Defaults to None.

        Raises:
            AssertionError: If the interval / iteration-range arguments are invalid.
        """

        assert report_interval > 0, "report_interval must be greater than 0"
        assert (
            report_iter_start >= 0
        ), "report_iter_start must be greater than or equal to 0"
        assert (
            report_iter_end >= -1
        ), "report_iter_end must be greater than or equal to -1"
        assert (
            report_iter_end == -1 or report_iter_start <= report_iter_end
        ), "report_iter_start must be less than or equal to report_iter_end"

        self.report_interval = report_interval
        self.report_iter_start = report_iter_start
        self.report_iter_end = report_iter_end

        # Normalize the prefix so report paths never start with "/" (empty
        # prefix) or "None/" (missing prefix); see report_stats.
        if path_prefix is not None:
            path_prefix = path_prefix.rstrip("/") or None
        self.path_prefix = path_prefix

        default_bucket = "/tmp" if open_source else "tlparse_reports"
        # An empty bucket (e.g. an unset env var forwarded by `create`) must
        # not become FileStore(""); fall back to the configured default.
        if not bucket:
            bucket = (
                os.environ.get("FBGEMM_TBE_REPORTING_BUCKET") or default_bucket
            )
        self.filestore = FileStore(bucket)

        if self.path_prefix is not None and not self.filestore.exists(self.path_prefix):
            self.filestore.create_directory(self.path_prefix)

        self.logger: logging.Logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

    @classmethod
    def create(cls) -> "TBEBenchmarkParamsReporter":
        """
        This method returns an instance of TBEBenchmarkParamsReporter based on environment variables.

        The created instance reports input parameters (TBEDataConfig) and writes
        them as JSON files. The following environment variables are read:
        - `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`: Reporting interval (default 1,
          i.e. every iteration); must be greater than 0.
        - `FBGEMM_REPORT_INPUT_PARAMS_ITER_START`: Start of the iteration range to capture (default 0).
        - `FBGEMM_REPORT_INPUT_PARAMS_ITER_END`: Inclusive end of the iteration range (default -1, unbounded).
        - `FBGEMM_REPORT_INPUT_PARAMS_BUCKET`: Bucket for reporting; empty/unset uses the default bucket.
        - `FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX`: Path prefix for reporting; empty/unset means the bucket root.

        Returns:
            TBEBenchmarkParamsReporter: An instance configured based on the environment variables.

        Raises:
            ValueError: If an integer environment variable is not a valid int.
            AssertionError: If the resulting interval / range is invalid.
        """
        report_interval = int(
            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "1")
        )
        report_iter_start = int(
            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0")
        )
        report_iter_end = int(
            os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1")
        )
        bucket = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_BUCKET", "")
        path_prefix = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX", "")

        # Empty strings are normalized to the defaults by __init__.
        return cls(
            report_interval=report_interval,
            report_iter_start=report_iter_start,
            report_iter_end=report_iter_end,
            bucket=bucket,
            path_prefix=path_prefix,
        )

    @staticmethod
    def _ceil_mean(values: torch.Tensor) -> int:
        """Return ceil(mean(values)) as a Python int."""
        return int(torch.ceil(torch.mean(values.float())).item())

    @staticmethod
    def _ceil_std(values: torch.Tensor) -> int:
        """Return ceil(std(values)) as a Python int; 0 when fewer than two values (std is undefined)."""
        if values.numel() < 2:
            return 0
        return int(torch.ceil(torch.std(values.float())).item())

    def extract_Ls(
        self,
        bag_sizes: List[int],
        Bs: List[int],
    ) -> List[float]:
        """
        Compute the average bag size (pooling factor) of each feature.

        Args:
            bag_sizes (List[int]): Flat list of per-bag lengths; the next ``Bs[i]``
                entries belong to feature ``i``.
            Bs (List[int]): Batch size of each feature.

        Returns:
            List[float]: One average per feature; 0.0 for features with batch size 0.
        """
        Ls: List[float] = []
        start = 0
        for b in Bs:
            end = start + b
            Ls.append(sum(bag_sizes[start:end]) / b if b > 0 else 0.0)
            start = end
        return Ls

    def extract_params(
        self,
        feature_rows: torch.Tensor,
        feature_dims: torch.Tensor,
        indices: torch.Tensor,
        offsets: torch.Tensor,
        per_sample_weights: Optional[torch.Tensor] = None,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
        Es: Optional[List[int]] = None,
        Ds: Optional[List[int]] = None,
        embedding_specs: Optional[List[Tuple[int, int]]] = None,
        feature_table_map: Optional[List[int]] = None,
    ) -> "TBEDataConfig":
        """
        Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig.

        Args:
            feature_rows (torch.Tensor): Number of rows in each feature.
            feature_dims (torch.Tensor): Number of dimensions in each feature.
            indices (torch.Tensor): The input indices tensor.
            offsets (torch.Tensor): The input offsets tensor.
            per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
            batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes,
                outer list per feature, inner list per rank. Defaults to None (no VBE).
            Es (Optional[List[int]], optional): Ignored; always recomputed from ``feature_rows``.
            Ds (Optional[List[int]], optional): Ignored; always recomputed from ``feature_dims``.
            embedding_specs (Optional[List[Tuple[int, int]]], optional): Passed through to the config.
            feature_table_map (Optional[List[int]], optional): Passed through to the config.

        Returns:
            TBEDataConfig: The configuration data for TBE benchmarking.

        Raises:
            AssertionError: If feature_rows/feature_dims lengths differ, or the
                non-VBE batch size is inconsistent with the offsets.
        """
        # Es/Ds are re-derived so they always match feature_rows/feature_dims;
        # the arguments are kept only for interface compatibility.
        Es = feature_rows.tolist()
        Ds = feature_dims.tolist()

        assert len(Es) == len(
            Ds
        ), "feature_rows and feature_dims must have the same length"

        # Transfer indices back to CPU for EEG analysis
        indices_cpu = indices.cpu()

        # Set T to be the number of features we are looking at
        T = len(Ds)
        # E / D: the common value, or the (ceil of the) mean when features
        # differ, to avoid biasing towards any single table. Always an int.
        E = Es[0] if len(set(Es)) == 1 else self._ceil_mean(feature_rows)
        mixed_dim = len(set(Ds)) > 1
        D = self._ceil_mean(feature_dims) if mixed_dim else Ds[0]

        # Compute indices distribution parameters
        heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution(
            indices_cpu
        )
        indices_params = IndicesParams(
            heavy_hitters, q, s, indices.dtype, offsets.dtype
        )

        # Compute batch parameters
        B = int((offsets.numel() - 1) // T)
        if batch_size_per_feature_per_rank:
            Bs = [sum(b_per_rank) for b_per_rank in batch_size_per_feature_per_rank]
            all_batch_sizes = torch.tensor(
                [b for bs in batch_size_per_feature_per_rank for b in bs]
            )
            batch_params = BatchParams(
                B=B,
                sigma_B=self._ceil_std(all_batch_sizes),
                vbe_distribution="normal",
                # The outer list is per feature; the number of ranks is the
                # length of each inner (per-rank) list.
                vbe_num_ranks=len(batch_size_per_feature_per_rank[0]),
                Bs=Bs,
            )
        else:
            Bs = [B] * T
            batch_params = BatchParams(
                B=B,
                sigma_B=None,
                vbe_distribution=None,
                vbe_num_ranks=None,
                Bs=Bs,
            )

        # Compute pooling parameters
        bag_sizes = offsets[1:] - offsets[:-1]
        if not batch_size_per_feature_per_rank:
            _B = int(bag_sizes.numel() // T)
            assert _B == Bs[0], f"Expected constant batch size {Bs[0]} but got {_B}"
        # Compare values with torch.unique: a Python set of 0-d tensors hashes
        # by identity and would report every multi-bag input as "mixed".
        mixed_bag_sizes = torch.unique(bag_sizes).numel() > 1
        pooling_params = PoolingParams(
            L=self._ceil_mean(bag_sizes) if mixed_bag_sizes else int(bag_sizes[0]),
            sigma_L=self._ceil_std(bag_sizes) if mixed_bag_sizes else None,
            length_distribution="normal" if mixed_bag_sizes else None,
            Ls=self.extract_Ls(bag_sizes.tolist(), Bs),
        )

        return TBEDataConfig(
            T=T,
            E=E,
            D=D,
            mixed_dim=mixed_dim,
            weighted=(per_sample_weights is not None),
            batch_params=batch_params,
            indices_params=indices_params,
            pooling_params=pooling_params,
            use_cpu=(not torch.cuda.is_available()),
            Es=Es,
            Ds=Ds,
            embedding_specs=embedding_specs,
            feature_table_map=feature_table_map,
        )

    def report_stats(
        self,
        feature_rows: torch.Tensor,
        feature_dims: torch.Tensor,
        iteration: int,
        indices: torch.Tensor,
        offsets: torch.Tensor,
        op_id: str = "",
        per_sample_weights: Optional[torch.Tensor] = None,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
        embedding_specs: Optional[List[Tuple[int, int]]] = None,
        feature_table_map: Optional[List[int]] = None,
    ) -> None:
        """
        Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore.

        Nothing is written when ``iteration`` is outside the configured range,
        not on the reporting interval, or when ``indices`` is empty.

        Args:
            feature_rows (torch.Tensor): Number of rows in each feature.
            feature_dims (torch.Tensor): Number of dimensions in each feature.
            iteration (int): The current iteration number.
            indices (torch.Tensor): The input indices tensor.
            offsets (torch.Tensor): The input offsets tensor.
            op_id (str, optional): The operation identifier. Defaults to an empty string.
            per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
            batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None.
            embedding_specs (Optional[List[Tuple[int, int]]]): Embedding specs. Defaults to None.
            feature_table_map (Optional[List[int]], optional): Feature table map. Defaults to None.
        """
        in_range = iteration >= self.report_iter_start and (
            self.report_iter_end == -1 or iteration <= self.report_iter_end
        )
        on_interval = (iteration - self.report_iter_start) % self.report_interval == 0
        if not (in_range and on_interval):
            return

        # If indices tensor is empty (indices.numel() == 0), skip reporting
        # TODO: Remove this once we have a better way to handle empty indices tensors
        if indices.numel() == 0:
            return

        # Extract TBE config
        config = self.extract_params(
            feature_rows=feature_rows,
            feature_dims=feature_dims,
            indices=indices,
            offsets=offsets,
            per_sample_weights=per_sample_weights,
            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
            Es=feature_rows.tolist(),
            Ds=feature_dims.tolist(),
            embedding_specs=embedding_specs,
            feature_table_map=feature_table_map,
        )

        # Write the TBE config to FileStore; without a prefix the file goes to
        # the bucket root (never "None/..." or "/...").
        file_name = f"tbe-{op_id}-config-estimation-{iteration}.json"
        path = f"{self.path_prefix}/{file_name}" if self.path_prefix else file_name
        self.filestore.write(
            path,
            io.BytesIO(json.dumps(config.dict(), indent=2).encode()),
        )
@@ -6,7 +6,7 @@
6
6
 
7
7
  # pyre-strict
8
8
 
9
- from typing import Callable, Optional, Tuple
9
+ from typing import Callable, Optional
10
10
 
11
11
  import numpy as np
12
12
  import torch
@@ -21,9 +21,9 @@ def get_table_batched_offsets_from_dense(
21
21
  L: Optional[int] = None,
22
22
  total_B: Optional[int] = None,
23
23
  use_cpu: bool = False,
24
- ) -> Tuple[torch.Tensor, torch.Tensor]:
24
+ ) -> tuple[torch.Tensor, torch.Tensor]:
25
25
  if L is None and total_B is None:
26
- (T, B, L) = merged_indices.size()
26
+ T, B, L = merged_indices.size()
27
27
  total_B = T * B
28
28
  # pyre-fixme[6]: For 1st argument expected `Union[Sequence[SupportsIndex],
29
29
  # SupportsIndex]` but got `Optional[int]`.
@@ -37,8 +37,8 @@ def get_table_batched_offsets_from_dense(
37
37
  )
38
38
 
39
39
 
40
- def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
41
- (B, L) = indices.size()
40
+ def get_offsets_from_dense(indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
41
+ B, L = indices.size()
42
42
  return (
43
43
  indices.contiguous().view(-1),
44
44
  torch.tensor(
@@ -54,7 +54,7 @@ def b_indices(
54
54
  use_cpu: bool = False,
55
55
  do_pooling: bool = True,
56
56
  ) -> torch.Tensor:
57
- (indices, offsets) = get_offsets_from_dense(x)
57
+ indices, offsets = get_offsets_from_dense(x)
58
58
  if do_pooling:
59
59
  return b(
60
60
  to_device(indices, use_cpu),
@@ -7,7 +7,7 @@
7
7
  # pyre-strict
8
8
  # pyre-ignore-all-errors[61]
9
9
 
10
- from typing import Optional, Tuple
10
+ from typing import Optional
11
11
 
12
12
  import torch
13
13
 
@@ -22,7 +22,7 @@ def quantize_embs(
22
22
  weight: torch.Tensor,
23
23
  weight_ty: SparseType,
24
24
  fp8_config: Optional[FP8QuantizationConfig] = None,
25
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
25
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
26
26
  weight = weight.detach()
27
27
  if weight_ty == SparseType.FP32:
28
28
  q_weight = weight.float()
@@ -91,7 +91,7 @@ def dequantize_embs(
91
91
  th_scale_shift: torch.Tensor = scale_shift.view(torch.float16).to(torch.float32)
92
92
 
93
93
  if weight_ty == SparseType.INT4:
94
- (E, D_2) = th_weights.shape
94
+ E, D_2 = th_weights.shape
95
95
  D = D_2 * 2
96
96
 
97
97
  def comp(i: int) -> torch.Tensor:
@@ -109,7 +109,7 @@ def dequantize_embs(
109
109
  return to_device(torch.tensor(comps), use_cpu)
110
110
 
111
111
  elif weight_ty == SparseType.INT2:
112
- (E, D_4) = th_weights.shape
112
+ E, D_4 = th_weights.shape
113
113
  D = D_4 * 4
114
114
 
115
115
  # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
@@ -129,7 +129,7 @@ def dequantize_embs(
129
129
  return to_device(torch.tensor(comps), use_cpu)
130
130
 
131
131
  elif weight_ty == SparseType.INT8:
132
- (E, D) = th_weights.shape
132
+ E, D = th_weights.shape
133
133
  comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
134
134
  torch.float32
135
135
  ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
@@ -177,7 +177,7 @@ def fake_quantize_embs(
177
177
  )
178
178
 
179
179
  if weight_ty == SparseType.INT4:
180
- (E, D_2) = th_weights.shape
180
+ E, D_2 = th_weights.shape
181
181
  D = D_2 * 2
182
182
 
183
183
  def comp(i: int) -> torch.Tensor:
@@ -195,7 +195,7 @@ def fake_quantize_embs(
195
195
  dequant_weights.copy_(to_device(comps, use_cpu))
196
196
 
197
197
  elif weight_ty == SparseType.INT2:
198
- (E, D_4) = th_weights.shape
198
+ E, D_4 = th_weights.shape
199
199
  D = D_4 * 4
200
200
 
201
201
  # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
@@ -215,7 +215,7 @@ def fake_quantize_embs(
215
215
  dequant_weights.copy_(to_device(comps, use_cpu))
216
216
 
217
217
  elif weight_ty == SparseType.INT8:
218
- (E, D) = th_weights.shape
218
+ E, D = th_weights.shape
219
219
  comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
220
220
  torch.float32
221
221
  ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)