fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -9,21 +9,18 @@
9
9
 
10
10
  import dataclasses
11
11
  import json
12
- from typing import Any, Dict, List, Optional, Tuple
12
+ import logging
13
+ from typing import Any, List, Optional, Tuple
13
14
 
14
- import numpy as np
15
15
  import torch
16
16
 
17
- from fbgemm_gpu.tbe.utils.common import get_device, round_up
18
- from fbgemm_gpu.tbe.utils.requests import (
19
- generate_batch_sizes_from_stats,
20
- generate_pooling_factors_from_stats,
21
- get_table_batched_offsets_from_dense,
22
- maybe_to_dtype,
23
- TBERequest,
24
- )
25
-
26
- from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
17
+ # fmt:skip
18
+ from fbgemm_gpu.tbe.utils.common import get_device
19
+ from .tbe_data_config_param_models import (
20
+ BatchParams,
21
+ IndicesParams,
22
+ PoolingParams,
23
+ ) # usort:skip
27
24
 
28
25
  try:
29
26
  torch.ops.load_library(
@@ -35,27 +32,87 @@ except Exception:
35
32
 
36
33
  @dataclasses.dataclass(frozen=True)
37
34
  class TBEDataConfig:
38
- # Number of tables
39
35
  T: int
40
- # Number of rows in the embedding table
41
36
  E: int
42
- # Target embedding dimension for a table (number of columns)
43
37
  D: int
44
- # Generate mixed dimensions if true
45
38
  mixed_dim: bool
46
- # Whether the table is weighted or not
47
39
  weighted: bool
48
- # Batch parameters
49
40
  batch_params: BatchParams
50
- # Indices parameters
51
41
  indices_params: IndicesParams
52
- # Pooling parameters
53
42
  pooling_params: PoolingParams
54
- # Force generated tensors to be on CPU
55
43
  use_cpu: bool = False
44
+ Es: Optional[list[int]] = None
45
+ Ds: Optional[list[int]] = None
46
+ max_indices: Optional[int] = None
47
+ embedding_specs: Optional[List[Tuple[int, int]]] = None
48
+ feature_table_map: Optional[List[int]] = None
49
+ """
50
+ Configuration for TBE (Table Batched Embedding) benchmark data collection and generation.
51
+
52
+ This dataclass holds parameters required to generate synthetic data for
53
+ TBE benchmarking, including table specifications, batch parameters, indices
54
+ distribution parameters, and pooling parameters.
55
+
56
+ Args:
57
+ T (int): Number of embedding tables (features). Must be positive.
58
+ E (int): Number of rows in the embedding table (feature). If T > 1, this
59
+ represents the averaged number of rows across all features.
60
+ D (int): Target embedding dimension for a table (feature), i.e., number of
61
+ columns. If T > 1, this represents the averaged dimension across
62
+ all features.
63
+ mixed_dim (bool): If True, generate embeddings with mixed dimensions
64
+ across tables (features). This is automatically set to True if D is provided
65
+ as a list with non-uniform values.
66
+ weighted (bool): If True, the lookup rows are weighted (per-sample
67
+ weights). The weights will be generated as FP32 tensors.
68
+ batch_params (BatchParams): Parameters controlling batch generation.
69
+ Contains:
70
+ (1) `B` = target batch size (number of batch lookups per features)
71
+ (2) `sigma_B` = optional standard deviation for variable batch size
72
+ (3) `vbe_distribution` = distribution type ("normal" or "uniform")
73
+ (4) `vbe_num_ranks` = number of ranks for variable batch size
74
+ (5) `Bs` = per-feature batch sizes
75
+ indices_params (IndicesParams): Parameters controlling index generation
76
+ following a Zipf distribution. Contains:
77
+ (1) `heavy_hitters` = probability density map for hot indices
78
+ (2) `zipf_q` = q parameter in Zipf distribution (x+q)^{-s}
79
+ (3) `zipf_s` = s parameter (alpha) in Zipf distribution
80
+ (4) `index_dtype` = optional dtype for indices tensor
81
+ (5) `offset_dtype` = optional dtype for offsets tensor
82
+ pooling_params (PoolingParams): Parameters controlling pooling behavior.
83
+ Contains:
84
+ (1) `L` = target bag size (pooling factor, indices per lookup)
85
+ (2) `sigma_L` = optional standard deviation for variable bag size
86
+ (3) `length_distribution` = distribution type ("normal" or "uniform")
87
+ (4) `Ls` = per-feature bag sizes
88
+ use_cpu (bool = False): If True, force generated tensors to be placed
89
+ on CPU instead of the default compute device.
90
+ Es (Optional[List[int]] = None): Number of embeddings (rows) for each
91
+ individual embedding feature. If provided, must have length equal
92
+ to T. All elements must be positive.
93
+ Ds (Optional[List[int]] = None): Target embedding dimension (columns)
94
+ for each individual feature. If provided, must have length equal
95
+ to T. All elements must be positive.
96
+ max_indices (Optional[int] = None): Maximum number of indices for
97
+ bounds checking. If Es is provided as a list and max_indices is
98
+ None, it is automatically computed as sum(Es) - 1.
99
+ embedding_specs (Optional[List[Tuple[int, int]]] = None): A list of
100
+ embedding specs consisting of a list of tuples of (num_rows, embedding_dim).
101
+ See https://fburl.com/tbe_embedding_specs for details.
102
+ feature_table_map (Optional[List[int]] = None): An optional list that
103
+ specifies feature-table mapping. feature_table_map[i] indicates the
104
+ physical embedding table that feature i maps to.
105
+ """
106
+
107
+ def __post_init__(self) -> None:
108
+ if isinstance(self.D, list):
109
+ object.__setattr__(self, "mixed_dim", len(set(self.D)) > 1)
110
+ if isinstance(self.E, list) and self.max_indices is None:
111
+ object.__setattr__(self, "max_indices", sum(self.E) - 1)
112
+ self.validate()
56
113
 
57
114
  @staticmethod
58
- def complex_fields() -> Dict[str, Any]:
115
+ def complex_fields() -> dict[str, Any]:
59
116
  return {
60
117
  "batch_params": BatchParams,
61
118
  "indices_params": IndicesParams,
@@ -64,7 +121,7 @@ class TBEDataConfig:
64
121
 
65
122
  @classmethod
66
123
  # pyre-ignore [3]
67
- def from_dict(cls, data: Dict[str, Any]):
124
+ def from_dict(cls, data: dict[str, Any]):
68
125
  for field, Type in cls.complex_fields().items():
69
126
  if not isinstance(data[field], Type):
70
127
  data[field] = Type.from_dict(data[field])
@@ -73,9 +130,22 @@ class TBEDataConfig:
73
130
  @classmethod
74
131
  # pyre-ignore [3]
75
132
  def from_json(cls, data: str):
76
- return cls.from_dict(json.loads(data))
133
+ raw = json.loads(data)
134
+ allowed = {f.name for f in dataclasses.fields(cls)}
135
+ existing_fields = {k: v for k, v in raw.items() if k in allowed}
136
+ missing_fields = allowed - set(existing_fields.keys())
137
+ unknown_fields = set(raw.keys()) - allowed
138
+ if missing_fields:
139
+ logging.warning(
140
+ f"TBEDataConfig.from_json: Missing expected fields not loaded: {sorted(missing_fields)}"
141
+ )
142
+ if unknown_fields:
143
+ logging.info(
144
+ f"TBEDataConfig.from_json: Ignored unknown fields from input: {sorted(unknown_fields)}"
145
+ )
146
+ return cls.from_dict(existing_fields)
77
147
 
78
- def dict(self) -> Dict[str, Any]:
148
+ def dict(self) -> dict[str, Any]:
79
149
  tmp = dataclasses.asdict(self)
80
150
  for field in TBEDataConfig.complex_fields().keys():
81
151
  tmp[field] = self.__dict__[field].dict()
@@ -89,10 +159,30 @@ class TBEDataConfig:
89
159
  # NOTE: Add validation logic here
90
160
  assert self.T > 0, "T must be positive"
91
161
  assert self.E > 0, "E must be positive"
162
+ if self.Es is not None:
163
+ assert all(e > 0 for e in self.Es), "All elements in Es must be positive"
92
164
  assert self.D > 0, "D must be positive"
165
+ if self.Ds is not None:
166
+ assert all(d > 0 for d in self.Ds), "All elements in Ds must be positive"
167
+ if isinstance(self.Es, list) and isinstance(self.Ds, list):
168
+ assert (
169
+ len(self.Es) == len(self.Ds) == self.T
170
+ ), "Lengths of Es, Lengths of Ds, and T must be equal"
171
+ if self.max_indices is not None:
172
+ assert self.max_indices == (
173
+ sum(self.Es) - 1
174
+ ), "max_indices must be equal to sum(Es) - 1"
93
175
  self.batch_params.validate()
176
+ if self.batch_params.Bs is not None:
177
+ assert (
178
+ len(self.batch_params.Bs) == self.T
179
+ ), f"Length of Bs must be equal to T. Expected: {self.T}, but got: {len(self.batch_params.Bs)}"
94
180
  self.indices_params.validate()
95
181
  self.pooling_params.validate()
182
+ if self.pooling_params.Ls is not None:
183
+ assert (
184
+ len(self.pooling_params.Ls) == self.T
185
+ ), f"Length of Ls must be equal to T. Expected: {self.T}, but got: {len(self.pooling_params.Ls)}"
96
186
  return self
97
187
 
98
188
  def variable_B(self) -> bool:
@@ -102,177 +192,5 @@ class TBEDataConfig:
102
192
  return self.pooling_params.sigma_L is not None
103
193
 
104
194
  def _new_weights(self, size: int) -> Optional[torch.Tensor]:
105
- # per sample weights will always be FP32
195
+ # Per-sample weights will always be FP32
106
196
  return None if not self.weighted else torch.randn(size, device=get_device())
107
-
108
- def _generate_batch_sizes(self) -> Tuple[List[int], Optional[List[List[int]]]]:
109
- if self.variable_B():
110
- assert (
111
- self.batch_params.vbe_num_ranks is not None
112
- ), "vbe_num_ranks must be set for varaible batch size generation"
113
- return generate_batch_sizes_from_stats(
114
- self.batch_params.B,
115
- self.T,
116
- # pyre-ignore [6]
117
- self.batch_params.sigma_B,
118
- self.batch_params.vbe_num_ranks,
119
- # pyre-ignore [6]
120
- self.batch_params.vbe_distribution,
121
- )
122
-
123
- else:
124
- return ([self.batch_params.B] * self.T, None)
125
-
126
- def _generate_pooling_info(self, iters: int, Bs: List[int]) -> torch.Tensor:
127
- if self.variable_L():
128
- # Generate L from stats
129
- _, L_offsets = generate_pooling_factors_from_stats(
130
- iters,
131
- Bs,
132
- self.pooling_params.L,
133
- # pyre-ignore [6]
134
- self.pooling_params.sigma_L,
135
- # pyre-ignore [6]
136
- self.pooling_params.length_distribution,
137
- )
138
-
139
- else:
140
- Ls = [self.pooling_params.L] * (sum(Bs) * iters)
141
- L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
142
-
143
- return L_offsets
144
-
145
- def _generate_indices(
146
- self,
147
- iters: int,
148
- Bs: List[int],
149
- L_offsets: torch.Tensor,
150
- ) -> torch.Tensor:
151
- total_B = sum(Bs)
152
- L_offsets_list = L_offsets.tolist()
153
- indices_list = []
154
- for it in range(iters):
155
- # L_offsets is defined over the entire set of batches for a single iteration
156
- start_offset = L_offsets_list[it * total_B]
157
- end_offset = L_offsets_list[(it + 1) * total_B]
158
-
159
- indices_list.append(
160
- torch.ops.fbgemm.tbe_generate_indices_from_distribution(
161
- self.indices_params.heavy_hitters,
162
- self.indices_params.zipf_q,
163
- self.indices_params.zipf_s,
164
- # max_index = dimensions of the embedding table
165
- self.E,
166
- # num_indices = number of indices to generate
167
- end_offset - start_offset,
168
- )
169
- )
170
-
171
- return torch.cat(indices_list)
172
-
173
- def _build_requests_jagged(
174
- self,
175
- iters: int,
176
- Bs: List[int],
177
- Bs_feature_rank: Optional[List[List[int]]],
178
- L_offsets: torch.Tensor,
179
- all_indices: torch.Tensor,
180
- ) -> List[TBERequest]:
181
- total_B = sum(Bs)
182
- all_indices = all_indices.flatten()
183
- requests = []
184
- for it in range(iters):
185
- start_offset = L_offsets[it * total_B]
186
- it_L_offsets = torch.concat(
187
- [
188
- torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
189
- L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
190
- ]
191
- )
192
- requests.append(
193
- TBERequest(
194
- maybe_to_dtype(
195
- all_indices[start_offset : L_offsets[(it + 1) * total_B]],
196
- self.indices_params.index_dtype,
197
- ),
198
- maybe_to_dtype(
199
- it_L_offsets.to(get_device()), self.indices_params.offset_dtype
200
- ),
201
- self._new_weights(int(it_L_offsets[-1].item())),
202
- Bs_feature_rank if self.variable_B() else None,
203
- )
204
- )
205
- return requests
206
-
207
- def _build_requests_dense(
208
- self, iters: int, all_indices: torch.Tensor
209
- ) -> List[TBERequest]:
210
- # NOTE: We're using existing code from requests.py to build the
211
- # requests, and since the existing code requires 2D view of all_indices,
212
- # the existing all_indices must be reshaped
213
- all_indices = all_indices.reshape(iters, -1)
214
-
215
- requests = []
216
- for it in range(iters):
217
- indices, offsets = get_table_batched_offsets_from_dense(
218
- all_indices[it].view(
219
- self.T, self.batch_params.B, self.pooling_params.L
220
- ),
221
- use_cpu=self.use_cpu,
222
- )
223
- requests.append(
224
- TBERequest(
225
- maybe_to_dtype(indices, self.indices_params.index_dtype),
226
- maybe_to_dtype(offsets, self.indices_params.offset_dtype),
227
- self._new_weights(
228
- self.T * self.batch_params.B * self.pooling_params.L
229
- ),
230
- )
231
- )
232
- return requests
233
-
234
- def generate_requests(
235
- self,
236
- iters: int = 1,
237
- ) -> List[TBERequest]:
238
- # Generate batch sizes
239
- Bs, Bs_feature_rank = self._generate_batch_sizes()
240
-
241
- # Generate pooling info
242
- L_offsets = self._generate_pooling_info(iters, Bs)
243
-
244
- # Generate indices
245
- all_indices = self._generate_indices(iters, Bs, L_offsets)
246
-
247
- # Build TBE requests
248
- if self.variable_B() or self.variable_L():
249
- return self._build_requests_jagged(
250
- iters, Bs, Bs_feature_rank, L_offsets, all_indices
251
- )
252
- else:
253
- return self._build_requests_dense(iters, all_indices)
254
-
255
- def generate_embedding_dims(self) -> Tuple[int, List[int]]:
256
- if self.mixed_dim:
257
- Ds = [
258
- round_up(
259
- np.random.randint(low=int(0.5 * self.D), high=int(1.5 * self.D)), 4
260
- )
261
- for _ in range(self.T)
262
- ]
263
- return (int(np.average(Ds)), Ds)
264
- else:
265
- return (self.D, [self.D] * self.T)
266
-
267
- def generate_feature_requires_grad(self, size: int) -> torch.Tensor:
268
- assert size <= self.T, "size of feature_requires_grad must be less than T"
269
- weighted_requires_grad_tables = np.random.choice(
270
- self.T, replace=False, size=(size,)
271
- ).tolist()
272
- return (
273
- torch.tensor(
274
- [1 if t in weighted_requires_grad_tables else 0 for t in range(self.T)]
275
- )
276
- .to(get_device())
277
- .int()
278
- )
@@ -0,0 +1,332 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+
10
+ import logging
11
+ from typing import Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+
16
+ # fmt:skip
17
+ from fbgemm_gpu.tbe.bench.tbe_data_config import TBEDataConfig
18
+ from fbgemm_gpu.tbe.utils.common import get_device, round_up
19
+ from fbgemm_gpu.tbe.utils.requests import (
20
+ generate_batch_sizes_from_stats,
21
+ generate_pooling_factors_from_stats,
22
+ get_table_batched_offsets_from_dense,
23
+ maybe_to_dtype,
24
+ TBERequest,
25
+ )
26
+
27
+ try:
28
+ # pyre-ignore[21]
29
+ from fbgemm_gpu import open_source # noqa: F401
30
+ except Exception:
31
+ torch.ops.load_library(
32
+ "//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_generator"
33
+ )
34
+
35
+
36
+ def _generate_batch_sizes(
37
+ tbe_data_config: TBEDataConfig,
38
+ ) -> tuple[list[int], Optional[list[list[int]]]]:
39
+ logging.info(
40
+ f"DEBUG_TBE: [_generate_batch_sizes] VBE tbe_data_config.variable_B()={tbe_data_config.variable_B()}"
41
+ )
42
+ if tbe_data_config.variable_B():
43
+ assert (
44
+ tbe_data_config.batch_params.vbe_num_ranks is not None
45
+ ), "vbe_num_ranks must be set for varaible batch size generation"
46
+ return generate_batch_sizes_from_stats(
47
+ tbe_data_config.batch_params.B,
48
+ tbe_data_config.T,
49
+ # pyre-ignore [6]
50
+ tbe_data_config.batch_params.sigma_B,
51
+ tbe_data_config.batch_params.vbe_num_ranks,
52
+ # pyre-ignore [6]
53
+ tbe_data_config.batch_params.vbe_distribution,
54
+ )
55
+ else:
56
+ return ([tbe_data_config.batch_params.B] * tbe_data_config.T, None)
57
+
58
+
59
+ def _generate_pooling_info(
60
+ tbe_data_config: TBEDataConfig, iters: int, Bs: list[int]
61
+ ) -> torch.Tensor:
62
+ if tbe_data_config.variable_L():
63
+ # Generate L from stats
64
+ _, L_offsets = generate_pooling_factors_from_stats(
65
+ iters,
66
+ Bs,
67
+ tbe_data_config.pooling_params.L,
68
+ # pyre-ignore [6]
69
+ tbe_data_config.pooling_params.sigma_L,
70
+ # pyre-ignore [6]
71
+ tbe_data_config.pooling_params.length_distribution,
72
+ )
73
+ else:
74
+ Ls = [tbe_data_config.pooling_params.L] * (sum(Bs) * iters)
75
+ L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
76
+
77
+ return L_offsets
78
+
79
+
80
+ def _generate_indices(
81
+ tbe_data_config: TBEDataConfig,
82
+ iters: int,
83
+ Bs: list[int],
84
+ L_offsets: torch.Tensor,
85
+ ) -> torch.Tensor:
86
+
87
+ total_B = sum(Bs)
88
+ L_offsets_list = L_offsets.tolist()
89
+ indices_list = []
90
+ for it in range(iters):
91
+ # L_offsets is defined over the entire set of batches for a single iteration
92
+ start_offset = L_offsets_list[it * total_B]
93
+ end_offset = L_offsets_list[(it + 1) * total_B]
94
+
95
+ logging.info(f"DEBUG_TBE: _generate_indices E = {tbe_data_config.E=}")
96
+
97
+ indices_list.append(
98
+ torch.ops.fbgemm.tbe_generate_indices_from_distribution(
99
+ tbe_data_config.indices_params.heavy_hitters,
100
+ tbe_data_config.indices_params.zipf_q,
101
+ tbe_data_config.indices_params.zipf_s,
102
+ # max_index = dimensions of the embedding table
103
+ int(tbe_data_config.E),
104
+ # num_indices = number of indices to generate
105
+ end_offset - start_offset,
106
+ )
107
+ )
108
+
109
+ return torch.cat(indices_list)
110
+
111
+
112
+ def _build_requests_jagged(
113
+ tbe_data_config: TBEDataConfig,
114
+ iters: int,
115
+ Bs: list[int],
116
+ Bs_feature_rank: Optional[list[list[int]]],
117
+ L_offsets: torch.Tensor,
118
+ all_indices: torch.Tensor,
119
+ ) -> list[TBERequest]:
120
+ total_B = sum(Bs)
121
+ all_indices = all_indices.flatten()
122
+ requests = []
123
+ for it in range(iters):
124
+ start_offset = L_offsets[it * total_B]
125
+ it_L_offsets = torch.concat(
126
+ [
127
+ torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
128
+ L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
129
+ ]
130
+ )
131
+ requests.append(
132
+ TBERequest(
133
+ maybe_to_dtype(
134
+ all_indices[start_offset : L_offsets[(it + 1) * total_B]],
135
+ tbe_data_config.indices_params.index_dtype,
136
+ ),
137
+ maybe_to_dtype(
138
+ it_L_offsets.to(get_device()),
139
+ tbe_data_config.indices_params.offset_dtype,
140
+ ),
141
+ tbe_data_config._new_weights(int(it_L_offsets[-1].item())),
142
+ Bs_feature_rank if tbe_data_config.variable_B() else None,
143
+ )
144
+ )
145
+ return requests
146
+
147
+
148
+ def _build_requests_dense(
149
+ tbe_data_config: TBEDataConfig, iters: int, all_indices: torch.Tensor
150
+ ) -> list[TBERequest]:
151
+ # NOTE: We're using existing code from requests.py to build the
152
+ # requests, and since the existing code requires 2D view of all_indices,
153
+ # the existing all_indices must be reshaped
154
+ all_indices = all_indices.reshape(iters, -1)
155
+
156
+ requests = []
157
+ for it in range(iters):
158
+ indices, offsets = get_table_batched_offsets_from_dense(
159
+ all_indices[it].view(
160
+ tbe_data_config.T,
161
+ tbe_data_config.batch_params.B,
162
+ tbe_data_config.pooling_params.L,
163
+ ),
164
+ use_cpu=tbe_data_config.use_cpu,
165
+ )
166
+ requests.append(
167
+ TBERequest(
168
+ maybe_to_dtype(indices, tbe_data_config.indices_params.index_dtype),
169
+ maybe_to_dtype(offsets, tbe_data_config.indices_params.offset_dtype),
170
+ tbe_data_config._new_weights(
171
+ tbe_data_config.T
172
+ * tbe_data_config.batch_params.B
173
+ * tbe_data_config.pooling_params.L
174
+ ),
175
+ )
176
+ )
177
+ return requests
178
+
179
+
180
+ def generate_requests(
181
+ tbe_data_config: TBEDataConfig,
182
+ iters: int = 1,
183
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
184
+ ) -> list[TBERequest]:
185
+
186
+ # Generate batch sizes
187
+ if batch_size_per_feature_per_rank:
188
+ Bs = tbe_data_config.batch_params.Bs
189
+ else:
190
+ Bs, _ = _generate_batch_sizes(tbe_data_config)
191
+
192
+ logging.info(
193
+ f"DEBUG_TBE: VBE [generate_requests] batch_size_per_feature_per_rank={batch_size_per_feature_per_rank} Bs={Bs}"
194
+ )
195
+
196
+ assert Bs is not None, "Batch sizes (Bs) must be set"
197
+
198
+ # Generate pooling info
199
+ L_offsets = _generate_pooling_info(tbe_data_config, iters, Bs)
200
+
201
+ # Generate indices
202
+ all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets)
203
+ all_indices = all_indices.to(get_device())
204
+
205
+ # Build TBE requests
206
+ if tbe_data_config.variable_B() or tbe_data_config.variable_L():
207
+ if batch_size_per_feature_per_rank:
208
+ return _build_requests_jagged(
209
+ tbe_data_config,
210
+ iters,
211
+ Bs,
212
+ batch_size_per_feature_per_rank,
213
+ L_offsets,
214
+ all_indices,
215
+ )
216
+ else:
217
+ return _build_requests_jagged(
218
+ tbe_data_config,
219
+ iters,
220
+ Bs,
221
+ batch_size_per_feature_per_rank,
222
+ L_offsets,
223
+ all_indices,
224
+ )
225
+ else:
226
+ return _build_requests_dense(tbe_data_config, iters, all_indices)
227
+
228
+
229
+ def generate_requests_with_Llist(
230
+ tbe_data_config: TBEDataConfig,
231
+ L_list: torch.Tensor,
232
+ iters: int = 1,
233
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
234
+ ) -> list[TBERequest]:
235
+ """
236
+ Generate a list of TBERequest objects based on the provided TBE data configuration and L_list
237
+ This function generates batch sizes and pooling information from the input L_list,
238
+ simulates L distributions with Gaussian noise, and creates indices for embedding lookups.
239
+ It supports both variable batch sizes and sequence lengths, building either jagged or dense requests accordingly.
240
+ Args:
241
+ tbe_data_config (TBEDataConfig): Configuration object containing batch parameters and pooling parameters.
242
+ L_list (torch.Tensor): Tensor of base sequence lengths for each batch.
243
+ iters (int, optional): Number of iterations to repeat the generated requests. Defaults to 1.
244
+ batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Optional batch size specification per feature per rank. Defaults to None.
245
+ Returns:
246
+ List[TBERequest]: A list of TBERequest objects constructed according to the configuration and input parameters.
247
+ Raises:
248
+ AssertionError: If batch sizes (Bs) are not set in the tbe_data_config.
249
+ Example:
250
+ >>> requests = generate_requests_with_Llist(tbe_data_config, L_list=torch.tensor([10, 20]), iters=2)
251
+ >>> len(requests)
252
+ 2
253
+ """
254
+
255
+ # Generate batch sizes
256
+ Bs = tbe_data_config.batch_params.Bs
257
+ assert (
258
+ Bs is not None
259
+ ), "Batch sizes (Bs) must be set for generate_requests_with_Llist"
260
+
261
+ # Generate pooling info from L list
262
+ Ls_list = []
263
+ for i in range(len(Bs)):
264
+ L = L_list[i]
265
+ B = Bs[i]
266
+ Ls_iter = np.random.normal(
267
+ loc=L, scale=tbe_data_config.pooling_params.sigma_L, size=B
268
+ ).astype(int)
269
+ Ls_list.append(Ls_iter)
270
+ Ls = np.concatenate(Ls_list)
271
+ Ls[Ls < 0] = 0
272
+ # Use the same L distribution across iters
273
+ Ls = np.tile(Ls, iters)
274
+ L = Ls.max()
275
+ # Make it exclusive cumsum
276
+ L_offsets = torch.from_numpy(np.insert(Ls.cumsum(), 0, 0)).to(torch.long)
277
+
278
+ # Generate indices
279
+ all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets)
280
+ all_indices = all_indices.to(get_device())
281
+
282
+ # Build TBE requests
283
+ if tbe_data_config.variable_B() or tbe_data_config.variable_L():
284
+ return _build_requests_jagged(
285
+ tbe_data_config,
286
+ iters,
287
+ Bs,
288
+ batch_size_per_feature_per_rank,
289
+ L_offsets,
290
+ all_indices,
291
+ )
292
+ else:
293
+ return _build_requests_dense(tbe_data_config, iters, all_indices)
294
+
295
+
296
+ def generate_embedding_dims(tbe_data_config: TBEDataConfig) -> tuple[int, list[int]]:
297
+ if tbe_data_config.mixed_dim:
298
+ Ds = [
299
+ round_up(
300
+ int(
301
+ torch.randint(
302
+ low=int(0.5 * tbe_data_config.D),
303
+ high=int(1.5 * tbe_data_config.D),
304
+ size=(1,),
305
+ ).item()
306
+ ),
307
+ 4,
308
+ )
309
+ for _ in range(tbe_data_config.T)
310
+ ]
311
+ return (sum(Ds) // len(Ds), Ds)
312
+ else:
313
+ return (tbe_data_config.D, [tbe_data_config.D] * tbe_data_config.T)
314
+
315
+
316
+ def generate_feature_requires_grad(
317
+ tbe_data_config: TBEDataConfig, size: int
318
+ ) -> torch.Tensor:
319
+ assert (
320
+ size <= tbe_data_config.T
321
+ ), "size of feature_requires_grad must be less than T"
322
+ weighted_requires_grad_tables = torch.randperm(tbe_data_config.T)[:size].tolist()
323
+ return (
324
+ torch.tensor(
325
+ [
326
+ 1 if t in weighted_requires_grad_tables else 0
327
+ for t in range(tbe_data_config.T)
328
+ ]
329
+ )
330
+ .to(get_device())
331
+ .int()
332
+ )