fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/tbe/bench/tbe_data_config_loader.py
@@ -7,12 +7,56 @@
 
 # pyre-strict
 
+import dataclasses
+import logging
+import re
+from enum import Enum
+
 import click
 import torch
 import yaml
 
-from .tbe_data_config import TBEDataConfig
-from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
+# fmt:skip
+from fbgemm_gpu.tbe.bench.tbe_data_config import (
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+    TBEDataConfig,
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class TBEDataConfigHelperText(Enum):
+    # Config File
+    TBE_CONFIG = "TBE data configuration filepath. If provided, all other `--tbe-*` options are ignored."
+
+    # Table Parameters
+    TBE_NUM_TABLES = "Number of tables (T)"
+    TBE_NUM_EMBEDDINGS = "Number of embeddings (E)"
+    TBE_EMBEDDING_DIM = "Embedding dimensions (D)"
+    TBE_MIXED_DIM = "Use mixed dimensions"
+    TBE_WEIGHTED = "Flag to indicate if the table is weighted"
+
+    # Batch Parameters
+    TBE_BATCH_SIZE = "Batch size (B)"
+    TBE_BATCH_VBE_SIGMA = "Standard deviation of B for VBE"
+    TBE_BATCH_VBE_DIST = "VBE distribution (choices: 'uniform', 'normal')"
+    TBE_BATCH_VBE_RANKS = "Number of ranks for VBE"
+
+    # Indices Parameters
+    TBE_INDICES_HITTERS = "Heavy hitters for indices (comma-delimited list of floats)"
+    TBE_INDICES_ZIPF = "Zipf distribution parameters for indices generation (q, s)"
+    TBE_INDICES_DTYPE = "The dtype of the table indices (choices: '32', '64')"
+    TBE_OFFSETS_DTYPE = "The dtype of the table offsets (choices: '32', '64')"
+
+    # Pooling Parameters
+    TBE_POOLING_SIZE = "Bag size / pooling factor (L)"
+    TBE_POOLING_VL_SIGMA = "Standard deviation of L for variable bag size"
+    TBE_POOLING_VL_DIST = (
+        "Variable bag size distribution (choices: 'uniform', 'normal')"
+    )
+    TBE_EMBEDDING_SPECS = "Embedding Specs which is List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]"
+    TBE_FEATURE_TABLE_MAP = "Mapping of feature-table"
 
 
 class TBEDataConfigLoader:
@@ -20,119 +64,152 @@ class TBEDataConfigLoader:
     # pyre-ignore [2]
     def options(cls, func) -> click.Command:
         options = [
-            ####################################################################
             # Config File
-            ####################################################################
             click.option(
                 "--tbe-config",
                 type=str,
                 required=False,
-                help="TBE data configuration filepath. If provided, all other `--tbe-*` options are ignored.",
+                help=TBEDataConfigHelperText.TBE_CONFIG.value,
             ),
-            ####################################################################
             # Table Parameters
-            ####################################################################
             click.option(
                 "--tbe-num-tables",
                 type=int,
                 default=32,
-                help="Number of tables (T)",
+                help=TBEDataConfigHelperText.TBE_NUM_TABLES.value,
             ),
             click.option(
                 "--tbe-num-embeddings",
                 type=int,
                 default=int(1e5),
-                help="Number of embeddings (E)",
+                help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value,
+            ),
+            click.option(
+                "--tbe-num-embeddings-list",
+                type=str,
+                required=False,
+                default=None,
+                help="Comma-separated list of number of embeddings (Es)",
             ),
             click.option(
                 "--tbe-embedding-dim",
                 type=int,
                 default=128,
-                help="Embedding dimensions (D)",
+                help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value,
+            ),
+            click.option(
+                "--tbe-embedding-dim-list",
+                type=str,
+                required=False,
+                default=None,
+                help="Comma-separated list of number of Embedding dimensions (D)",
             ),
             click.option(
                 "--tbe-mixed-dim",
                 is_flag=True,
                 default=False,
-                help="Use mixed dimensions",
+                help=TBEDataConfigHelperText.TBE_MIXED_DIM.value,
             ),
             click.option(
                 "--tbe-weighted",
                 is_flag=True,
                 default=False,
-                help="Whether the table is weighted or not",
+                help=TBEDataConfigHelperText.TBE_WEIGHTED.value,
+            ),
+            click.option(
+                "--tbe-max-indices",
+                type=int,
+                required=False,
+                default=None,
+                help="(Optional) Maximum number of indices, will be calculated if not provided",
             ),
-            ####################################################################
             # Batch Parameters
-            ####################################################################
             click.option(
-                "--tbe-batch-size", type=int, default=512, help="Batch size (B)"
+                "--tbe-batch-size",
+                type=int,
+                default=512,
+                help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value,
+            ),
+            click.option(
+                "--tbe-batch-sizes-list",
+                type=str,
+                required=False,
+                default=None,
+                help="List Batch sizes per feature (Bs)",
             ),
             click.option(
                 "--tbe-batch-vbe-sigma",
                 type=int,
                 required=False,
-                help="Standard deviation of B for VBE",
+                help=TBEDataConfigHelperText.TBE_BATCH_VBE_SIGMA.value,
             ),
             click.option(
                 "--tbe-batch-vbe-dist",
                 type=click.Choice(["uniform", "normal"]),
                 required=False,
-                help="VBE distribution",
+                help=TBEDataConfigHelperText.TBE_BATCH_VBE_DIST.value,
             ),
             click.option(
                 "--tbe-batch-vbe-ranks",
                 type=int,
                 required=False,
-                help="Number of ranks for VBE",
+                help=TBEDataConfigHelperText.TBE_BATCH_VBE_RANKS.value,
            ),
-            ####################################################################
             # Indices Parameters
-            ####################################################################
             click.option(
                 "--tbe-indices-hitters",
                 type=str,
                 default="",
-                help="TBE heavy hitter indices (comma-delimited list of floats)",
+                help=TBEDataConfigHelperText.TBE_INDICES_HITTERS.value,
             ),
             click.option(
                 "--tbe-indices-zipf",
                 type=(float, float),
                 default=(0.1, 0.1),
-                help="Zipf distribution parameters for indices generation (q, s)",
+                help=TBEDataConfigHelperText.TBE_INDICES_ZIPF.value,
             ),
             click.option(
                 "--tbe-indices-dtype",
                 type=click.Choice(["32", "64"]),
                 default="64",
-                help="The dtype of the table indices",
+                help=TBEDataConfigHelperText.TBE_INDICES_DTYPE.value,
             ),
             click.option(
                 "--tbe-offsets-dtype",
                 type=click.Choice(["32", "64"]),
                 default="64",
-                help="The dtype of the table indices offsets",
+                help=TBEDataConfigHelperText.TBE_OFFSETS_DTYPE.value,
             ),
-            ####################################################################
             # Pooling Parameters
-            ####################################################################
             click.option(
                 "--tbe-pooling-size",
                 type=int,
                 default=20,
-                help="Bag size / pooling factor (L)",
+                help=TBEDataConfigHelperText.TBE_POOLING_SIZE.value,
             ),
             click.option(
                 "--tbe-pooling-vl-sigma",
                 type=int,
                 required=False,
-                help="Standard deviation of B for VBE",
+                help=TBEDataConfigHelperText.TBE_POOLING_VL_SIGMA.value,
             ),
             click.option(
                 "--tbe-pooling-vl-dist",
                 type=click.Choice(["uniform", "normal"]),
                 required=False,
-                help="Pooling factor distribution",
+                help=TBEDataConfigHelperText.TBE_POOLING_VL_DIST.value,
+            ),
+            click.option(
+                "--tbe-embedding-specs",
+                type=str,
+                required=False,
+                help=TBEDataConfigHelperText.TBE_EMBEDDING_SPECS.value,
+            ),
+            click.option(
+                "--tbe-feature-table-map",
+                type=str,
+                required=False,
+                help=TBEDataConfigHelperText.TBE_FEATURE_TABLE_MAP.value,
             ),
         ]
 
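The options above are ordinary Click declarations, so the pattern the diff adopts (help text pulled from an Enum, list-valued flags passed as comma-separated strings) can be illustrated with a standalone toy command. The sketch below does not import anything from fbgemm_gpu and is not code from the package; it only mirrors the style of the hunk above.

import click
from click.testing import CliRunner
from enum import Enum


class HelpText(Enum):
    NUM_EMBEDDINGS_LIST = "Comma-separated list of number of embeddings (Es)"


@click.command()
@click.option(
    "--tbe-num-embeddings-list",
    type=str,
    default=None,
    help=HelpText.NUM_EMBEDDINGS_LIST.value,
)
def show(tbe_num_embeddings_list: str) -> None:
    # The flag arrives as a plain string; the loader splits it on "," later.
    print(tbe_num_embeddings_list)


print(CliRunner().invoke(show, ["--tbe-num-embeddings-list", "1000,2000,4000"]).output)
# -> 1000,2000,4000

The next hunk shows how the real loader turns these strings into per-feature lists.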
@@ -154,18 +231,62 @@ class TBEDataConfigLoader:
         params = context.params
 
         # Read table parameters
-        T = params["tbe_num_tables"]
-        E = params["tbe_num_embeddings"]
+        T = params["tbe_num_tables"]  # number of features
+        E = params["tbe_num_embeddings"]  # feature_rows
+        if params["tbe_num_embeddings_list"] is not None:
+            Es = [int(x) for x in params["tbe_num_embeddings_list"].split(",")]
+            T = len(Es)
+            E = sum(Es) // T  # average E
+        else:
+            Es = None
         D = params["tbe_embedding_dim"]
+        if params["tbe_embedding_dim_list"] is not None:
+            Ds = [int(x) for x in params["tbe_embedding_dim_list"].split(",")]
+            assert (
+                len(Ds) == T
+            ), f"Expected tbe_embedding_dim_list to have {T} elements, but got {len(Ds)}"
+            D = sum(Ds) // T  # average D
+        else:
+            Ds = None
+
         mixed_dim = params["tbe_mixed_dim"]
         weighted = params["tbe_weighted"]
+        if params["tbe_max_indices"] is not None:
+            max_indices = params["tbe_max_indices"]
+        else:
+            max_indices = None
 
         # Read batch parameters
         B = params["tbe_batch_size"]
         sigma_B = params["tbe_batch_vbe_sigma"]
         vbe_distribution = params["tbe_batch_vbe_dist"]
         vbe_num_ranks = params["tbe_batch_vbe_ranks"]
-        batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks)
+        if params["tbe_batch_sizes_list"] is not None:
+            Bs = [int(x) for x in params["tbe_batch_sizes_list"].split(",")]
+            B = sum(Bs) // T  # average B
+        else:
+            B = params["tbe_batch_size"]
+            Bs = None
+        batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks, Bs)
+
+        # Parse embedding_specs: "(E,D),(E,D),..." or "(E,D,loc,dev),(E,D,loc,dev),..."
+        # Only the first two values (E, D) are extracted.
+        embedding_specs = None
+        feature_table_map = None
+        if params["tbe_embedding_specs"] is not None:
+            try:
+                tuples = re.findall(r"\(([^)]+)\)", params["tbe_embedding_specs"])
+                if tuples:
+                    embedding_specs = [
+                        (int(t.split(",")[0].strip()), int(t.split(",")[1].strip()))
+                        for t in tuples
+                    ]
+            except (ValueError, IndexError):
+                logging.warning("Failed to parse embedding_specs. Setting to None.")
+        if params["tbe_feature_table_map"] is not None:
+            feature_table_map = [
+                int(x) for x in params["tbe_feature_table_map"].split(",")
+            ]
 
         # Read indices parameters
         heavy_hitters = (
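Worked through by hand, the new list and spec parsing above behaves as follows for sample flag values. This is a standalone restatement of the logic in the hunk, not code from the package.

import re

# --tbe-num-embeddings-list 1000,2000,4000
Es = [int(x) for x in "1000,2000,4000".split(",")]  # [1000, 2000, 4000]
T = len(Es)                                         # 3
E = sum(Es) // T                                    # 2333 (floor average)

# --tbe-embedding-specs "(1000,128),(2000,64)" -- only the first two values per tuple are kept
tuples = re.findall(r"\(([^)]+)\)", "(1000,128),(2000,64)")  # ['1000,128', '2000,64']
embedding_specs = [
    (int(t.split(",")[0].strip()), int(t.split(",")[1].strip())) for t in tuples
]
print(embedding_specs)  # [(1000, 128), (2000, 64)]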
@@ -200,6 +321,11 @@ class TBEDataConfigLoader:
             indices_params,
             pooling_params,
             not torch.cuda.is_available(),
+            Es,
+            Ds,
+            max_indices,
+            embedding_specs,
+            feature_table_map,
         ).validate()
 
     @classmethod
fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py
@@ -9,7 +9,7 @@
 
 import dataclasses
 import json
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import torch
 
@@ -31,6 +31,7 @@ class IndicesParams:
     heavy_hitters: torch.Tensor
     # zipf*: parameters for the Zipf distribution (x+q)^{-s}
     zipf_q: float
+    # zipf_s is synonymous with alpha in the literature
     zipf_s: float
     # [Optional] dtype for indices tensor
     index_dtype: Optional[torch.dtype] = None
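For reference, the (x+q)^{-s} form mentioned in the comment is the shifted Zipf law used for index generation, with zipf_q as the shift q and zipf_s as the exponent s (often written alpha):

P(X = x) \propto (x + q)^{-s}, \qquad x = 1, 2, 3, \dots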
@@ -39,7 +40,7 @@
 
     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data: Dict[str, Any]):
+    def from_dict(cls, data: dict[str, Any]):
         if not isinstance(data["heavy_hitters"], torch.Tensor):
             data["heavy_hitters"] = torch.tensor(
                 data["heavy_hitters"], dtype=torch.float32
@@ -53,7 +54,7 @@ class IndicesParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))
 
-    def dict(self) -> Dict[str, Any]:
+    def dict(self) -> dict[str, Any]:
         # https://stackoverflow.com/questions/73735974/convert-dataclass-of-dataclass-to-json-string
         tmp = dataclasses.asdict(self)
         # Convert tensor to list for JSON serialization
@@ -97,10 +98,12 @@ class BatchParams:
     vbe_distribution: Optional[str] = "normal"
     # Number of ranks for variable batch size generation
     vbe_num_ranks: Optional[int] = None
+    # List of target batch sizes, i.e. number of batch lookups per feature
+    Bs: Optional[list[int]] = None
 
     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data: Dict[str, Any]):
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)
 
     @classmethod
@@ -108,7 +111,7 @@ class BatchParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))
 
-    def dict(self) -> Dict[str, Any]:
+    def dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self)
 
     def json(self, format: bool = False) -> str:
@@ -116,7 +119,10 @@
 
     # pyre-ignore [3]
     def validate(self):
-        assert self.B > 0, "B must be positive"
+        if self.Bs is not None:
+            assert all(b > 0 for b in self.Bs), "All elements in Bs must be positive"
+        else:
+            assert self.B > 0, "B must be positive"
         assert not self.sigma_B or self.sigma_B > 0, "sigma_B must be positive"
         assert (
             self.vbe_num_ranks is None or self.vbe_num_ranks > 0
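A short usage sketch of the widened BatchParams, using only the field names visible in this diff (the import path follows the files-changed list; treat the snippet as illustrative rather than documented API): when per-feature sizes Bs are supplied, validate() checks every entry instead of the scalar B, and the new field round-trips through the existing dict()/json() helpers.

from fbgemm_gpu.tbe.bench.tbe_data_config_param_models import BatchParams

bp = BatchParams(
    B=512,                    # average batch size, still present
    sigma_B=None,
    vbe_distribution="normal",
    vbe_num_ranks=None,
    Bs=[128, 512, 896],       # new: per-feature batch sizes
)
bp.validate()                 # passes: every entry of Bs is positive

assert BatchParams.from_json(bp.json()).Bs == [128, 512, 896]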
@@ -136,10 +142,12 @@ class PoolingParams:
     sigma_L: Optional[int] = None
     # [Optional] Distribution of embedding sequence lengths (normal, uniform)
     length_distribution: Optional[str] = "normal"
+    # [Optional] List of target bag sizes, i.e. pooling factors per batch
+    Ls: Optional[list[float]] = None
 
     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data: Dict[str, Any]):
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)
 
     @classmethod
@@ -147,7 +155,7 @@
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))
 
-    def dict(self) -> Dict[str, Any]:
+    def dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self)
 
     def json(self, format: bool = False) -> str:
fbgemm_gpu/tbe/bench/utils.py
@@ -6,15 +6,14 @@
 
 # pyre-strict
 
-import logging
+from typing import List, Tuple
 
 import numpy as np
 import torch
 
+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import SparseType
 
-logging.basicConfig(level=logging.DEBUG)
-
 
 def fill_random_scale_bias(
     emb: torch.nn.Module,
@@ -23,9 +22,9 @@ def fill_random_scale_bias(
 ) -> None:
     for t in range(T):
         # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
-        (weights, scale_shift) = emb.split_embedding_weights()[t]
+        weights, scale_shift = emb.split_embedding_weights()[t]
         if scale_shift is not None:
-            (E, R) = scale_shift.shape
+            E, R = scale_shift.shape
             assert R == 4
             scales = None
             shifts = None
@@ -46,3 +45,128 @@
                     device=scale_shift.device,
                 )
             )
+
+
+def check_oom(
+    data_size: int,
+) -> Tuple[bool, str]:
+    free_memory, total_memory = torch.cuda.mem_get_info()
+    if data_size > free_memory:
+        warning = f"Expect to allocate {round(data_size / (1024 ** 3), 2)} GB, but available memory is {round(free_memory / (1024 ** 3), 2)} GB from {round(total_memory / (1024 ** 3), 2)} GB."
+        return (True, warning)
+    return (False, "")
+
+
+def generate_batch_size_per_feature_per_rank(
+    Bs: List[int], num_ranks: int
+) -> List[List[int]]:
+    """
+    Generate batch size per feature per rank for VBE, assuming the batch size
+    is evenly distributed across ranks.
+    Args:
+        Bs (List[int]): batch size per feature
+        num_ranks (int): number of ranks
+    Returns:
+        List[List[int]]: batch size per feature per rank
+    """
+    b_per_feature_per_rank = []
+    for B in Bs:
+        b_per_feature = []
+        for i in range(num_ranks):
+            if i != num_ranks - 1:
+                b_per_feature.append(int(B / num_ranks))
+            else:
+                b_per_feature.append(B - sum(b_per_feature))
+        b_per_feature_per_rank.append(b_per_feature)
+    return b_per_feature_per_rank
+
+
+def generate_merged_output_and_offsets(
+    Ds: List[int],
+    Bs: List[int],
+    output_dtype: torch.dtype,
+    device: torch.device,
+    num_ranks: int = 2,
+    num_tbe_ops: int = 2,
+) -> Tuple[List[List[int]], torch.Tensor, torch.Tensor]:
+    """
+    Generate merged vbe_output and vbe_output_offsets tensors for VBE.
+    The vbe_output is a tensor that will contain forward output from all VBE TBE ops.
+    The vbe_output_offsets is a tensor that will contain start offsets for the output to be written to.
+
+    Args:
+        Ds (List[int]): embedding dimension per feature
+        Bs (List[int]): batch size per feature
+        num_ranks (int): number of ranks
+        num_tbe_ops (int): number of TBE ops
+    Returns:
+        Tuple[List[List[int]], torch.Tensor, torch.Tensor]: batch_size_per_feature_per_rank, merged vbe_output and vbe_output_offsets tensors
+    """
+    # The first embedding ops is the embedding op created in the benchmark
+    emb_op = {}
+    emb_op[0] = {}
+    emb_op[0]["dim"] = Ds
+    emb_op[0]["Bs"] = Bs
+    emb_op[0]["output_size"] = sum([b * d for b, d in zip(Bs, Ds)])
+    emb_op[0]["batch_size_per_feature_per_rank"] = (
+        generate_batch_size_per_feature_per_rank(Bs, num_ranks)
+    )
+    num_features = len(Bs)
+    # create other embedding ops to allocate output and offsets tensors
+    # Using representative values for additional TBE ops in multi-op scenarios:
+    # - batch_size=32000: typical large batch size for production workloads
+    # - dim=512: common embedding dimension for large models
+    for i in range(1, num_tbe_ops):
+        emb_op[i] = {}
+        emb_op[i]["batch_size_per_feature_per_rank"] = (
+            generate_batch_size_per_feature_per_rank([32000], num_ranks)
+        )
+        emb_op[i]["Bs"] = [sum(B) for B in emb_op[i]["batch_size_per_feature_per_rank"]]
+        emb_op[i]["dim"] = [512]
+        emb_op[i]["output_size"] = sum(
+            [b * d for b, d in zip(emb_op[i]["Bs"], emb_op[i]["dim"])]
+        )
+    total_output = 0
+    ranks = [[] for _ in range(num_ranks)]
+    for e in emb_op.values():
+        b_per_rank_per_feature = list(zip(*e["batch_size_per_feature_per_rank"]))
+        assert len(b_per_rank_per_feature) == num_ranks
+        dims = e["dim"]
+        for r, b_r in enumerate(b_per_rank_per_feature):
+            for f, b in enumerate(b_r):
+                output_size_per_batch = b * dims[f]
+                ranks[r].append(output_size_per_batch)
+                total_output += output_size_per_batch
+    ranks[0].insert(0, 0)
+    offsets_ranks: List[List[int]] = [[] for _ in range(num_ranks)]
+    total_output_offsets = []
+    start = 0
+    for r in range(num_ranks):
+        offsets_ranks[r] = [
+            start + sum(ranks[r][: i + 1]) for i in range(len(ranks[r]))
+        ]
+        start = offsets_ranks[r][-1]
+        total_output_offsets.extend(offsets_ranks[r])
+    check_total_output_size = sum([e["output_size"] for e in emb_op.values()])
+    assert (
+        total_output == check_total_output_size
+    ), f"{total_output} != {check_total_output_size}{[e['output_size'] for e in emb_op.values()]}"
+    assert (
+        total_output == total_output_offsets[-1]
+    ), f"{total_output} != {total_output_offsets[-1]}"
+    out = torch.empty(total_output, dtype=output_dtype, device=device)
+    offsets = []
+    offsets.append(offsets_ranks[0][:num_features])
+    for r in range(1, num_ranks):
+        start = [offsets_ranks[r - 1][-1]]
+        the_rest = offsets_ranks[r][: num_features - 1] if num_features > 1 else []
+        start.extend(the_rest)
+        offsets.append(start)
+
+    out_offsets = torch.tensor(
+        offsets,
+        dtype=torch.int64,
+        device=device,
+    )
+    batch_size_per_feature_per_rank = emb_op[0]["batch_size_per_feature_per_rank"]
+    return (batch_size_per_feature_per_rank, out, out_offsets)
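generate_batch_size_per_feature_per_rank gives every rank floor(B / num_ranks) and assigns the remainder to the last rank, so each feature's total batch size is preserved exactly; generate_merged_output_and_offsets then builds on that split to pre-allocate one flat output tensor plus per-rank start offsets covering all TBE ops. A quick worked example of the split (the import path follows the files-changed list; illustrative usage only, not code from the package):

from fbgemm_gpu.tbe.bench.utils import generate_batch_size_per_feature_per_rank

# Two features with batch sizes 10 and 7, split across 3 ranks.
print(generate_batch_size_per_feature_per_rank([10, 7], num_ranks=3))
# [[3, 3, 4], [2, 2, 3]] -- each inner list sums back to the original B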
fbgemm_gpu/tbe/cache/__init__.py
@@ -7,4 +7,5 @@
 
 # pyre-unsafe
 
+from .kv_embedding_ops_inference import KVEmbeddingInference  # noqa: F401
 from .split_embeddings_cache_ops import get_unique_indices_v2  # noqa: F401