fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -8,14 +8,21 @@
 # pyre-strict

 import dataclasses
+import logging
+import re
 from enum import Enum

 import click
 import torch
 import yaml

-from .tbe_data_config import TBEDataConfig
-from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
+# fmt:skip
+from fbgemm_gpu.tbe.bench.tbe_data_config import (
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+    TBEDataConfig,
+)


 @dataclasses.dataclass(frozen=True)
@@ -40,12 +47,16 @@ class TBEDataConfigHelperText(Enum):
     TBE_INDICES_HITTERS = "Heavy hitters for indices (comma-delimited list of floats)"
     TBE_INDICES_ZIPF = "Zipf distribution parameters for indices generation (q, s)"
     TBE_INDICES_DTYPE = "The dtype of the table indices (choices: '32', '64')"
-    TBE_OFFSETS_DTYPE = "The dtype of the table indices (choices: '32', '64')"
+    TBE_OFFSETS_DTYPE = "The dtype of the table offsets (choices: '32', '64')"

     # Pooling Parameters
     TBE_POOLING_SIZE = "Bag size / pooling factor (L)"
-    TBE_POOLING_VL_SIGMA = "Standard deviation of B for VBE"
-    TBE_POOLING_VL_DIST = "VBE distribution (choices: 'uniform', 'normal')"
+    TBE_POOLING_VL_SIGMA = "Standard deviation of L for variable bag size"
+    TBE_POOLING_VL_DIST = (
+        "Variable bag size distribution (choices: 'uniform', 'normal')"
+    )
+    TBE_EMBEDDING_SPECS = "Embedding Specs which is List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]"
+    TBE_FEATURE_TABLE_MAP = "Mapping of feature-table"


 class TBEDataConfigLoader:
@@ -73,12 +84,26 @@ class TBEDataConfigLoader:
                 default=int(1e5),
                 help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value,
             ),
+            click.option(
+                "--tbe-num-embeddings-list",
+                type=str,
+                required=False,
+                default=None,
+                help="Comma-separated list of number of embeddings (Es)",
+            ),
             click.option(
                 "--tbe-embedding-dim",
                 type=int,
                 default=128,
                 help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value,
             ),
+            click.option(
+                "--tbe-embedding-dim-list",
+                type=str,
+                required=False,
+                default=None,
+                help="Comma-separated list of number of Embedding dimensions (D)",
+            ),
             click.option(
                 "--tbe-mixed-dim",
                 is_flag=True,
@@ -91,6 +116,13 @@ class TBEDataConfigLoader:
                 default=False,
                 help=TBEDataConfigHelperText.TBE_WEIGHTED.value,
             ),
+            click.option(
+                "--tbe-max-indices",
+                type=int,
+                required=False,
+                default=None,
+                help="(Optional) Maximum number of indices, will be calculated if not provided",
+            ),
             # Batch Parameters
             click.option(
                 "--tbe-batch-size",
@@ -98,6 +130,13 @@ class TBEDataConfigLoader:
                 default=512,
                 help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value,
             ),
+            click.option(
+                "--tbe-batch-sizes-list",
+                type=str,
+                required=False,
+                default=None,
+                help="List Batch sizes per feature (Bs)",
+            ),
             click.option(
                 "--tbe-batch-vbe-sigma",
                 type=int,
@@ -160,6 +199,18 @@ class TBEDataConfigLoader:
                 required=False,
                 help=TBEDataConfigHelperText.TBE_POOLING_VL_DIST.value,
            ),
+            click.option(
+                "--tbe-embedding-specs",
+                type=str,
+                required=False,
+                help=TBEDataConfigHelperText.TBE_EMBEDDING_SPECS.value,
+            ),
+            click.option(
+                "--tbe-feature-table-map",
+                type=str,
+                required=False,
+                help=TBEDataConfigHelperText.TBE_FEATURE_TABLE_MAP.value,
+            ),
         ]

         for option in reversed(options):
@@ -180,18 +231,62 @@ class TBEDataConfigLoader:
         params = context.params

         # Read table parameters
-        T = params["tbe_num_tables"]
-        E = params["tbe_num_embeddings"]
+        T = params["tbe_num_tables"]  # number of features
+        E = params["tbe_num_embeddings"]  # feature_rows
+        if params["tbe_num_embeddings_list"] is not None:
+            Es = [int(x) for x in params["tbe_num_embeddings_list"].split(",")]
+            T = len(Es)
+            E = sum(Es) // T  # average E
+        else:
+            Es = None
         D = params["tbe_embedding_dim"]
+        if params["tbe_embedding_dim_list"] is not None:
+            Ds = [int(x) for x in params["tbe_embedding_dim_list"].split(",")]
+            assert (
+                len(Ds) == T
+            ), f"Expected tbe_embedding_dim_list to have {T} elements, but got {len(Ds)}"
+            D = sum(Ds) // T  # average D
+        else:
+            Ds = None
+
         mixed_dim = params["tbe_mixed_dim"]
         weighted = params["tbe_weighted"]
+        if params["tbe_max_indices"] is not None:
+            max_indices = params["tbe_max_indices"]
+        else:
+            max_indices = None

         # Read batch parameters
         B = params["tbe_batch_size"]
         sigma_B = params["tbe_batch_vbe_sigma"]
         vbe_distribution = params["tbe_batch_vbe_dist"]
         vbe_num_ranks = params["tbe_batch_vbe_ranks"]
-        batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks)
+        if params["tbe_batch_sizes_list"] is not None:
+            Bs = [int(x) for x in params["tbe_batch_sizes_list"].split(",")]
+            B = sum(Bs) // T  # average B
+        else:
+            B = params["tbe_batch_size"]
+            Bs = None
+        batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks, Bs)
+
+        # Parse embedding_specs: "(E,D),(E,D),..." or "(E,D,loc,dev),(E,D,loc,dev),..."
+        # Only the first two values (E, D) are extracted.
+        embedding_specs = None
+        feature_table_map = None
+        if params["tbe_embedding_specs"] is not None:
+            try:
+                tuples = re.findall(r"\(([^)]+)\)", params["tbe_embedding_specs"])
+                if tuples:
+                    embedding_specs = [
+                        (int(t.split(",")[0].strip()), int(t.split(",")[1].strip()))
+                        for t in tuples
+                    ]
+            except (ValueError, IndexError):
+                logging.warning("Failed to parse embedding_specs. Setting to None.")
+        if params["tbe_feature_table_map"] is not None:
+            feature_table_map = [
+                int(x) for x in params["tbe_feature_table_map"].split(",")
+            ]

         # Read indices parameters
         heavy_hitters = (
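Note: the list-style options above override the scalar T/E/D/B values when provided. A standalone sketch of how those strings are interpreted, mirroring the parsing in this hunk (the example values are illustrative, not taken from the package):

import re

# --tbe-num-embeddings-list "1000,2000,4000" overrides T and the average E
Es = [int(x) for x in "1000,2000,4000".split(",")]
T = len(Es)       # 3 tables
E = sum(Es) // T  # 2333, the average row count

# --tbe-embedding-specs "(1000,128),(2000,256,loc,dev)" keeps only the first
# two values of each tuple, exactly like the regex + int() conversion above
spec_str = "(1000,128),(2000,256,loc,dev)"
tuples = re.findall(r"\(([^)]+)\)", spec_str)
embedding_specs = [
    (int(t.split(",")[0].strip()), int(t.split(",")[1].strip())) for t in tuples
]
# embedding_specs == [(1000, 128), (2000, 256)]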
@@ -226,6 +321,11 @@ class TBEDataConfigLoader:
             indices_params,
             pooling_params,
             not torch.cuda.is_available(),
+            Es,
+            Ds,
+            max_indices,
+            embedding_specs,
+            feature_table_map,
         ).validate()

     @classmethod
@@ -9,7 +9,7 @@

 import dataclasses
 import json
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 import torch

@@ -40,7 +40,7 @@ class IndicesParams:

     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data: Dict[str, Any]):
+    def from_dict(cls, data: dict[str, Any]):
         if not isinstance(data["heavy_hitters"], torch.Tensor):
             data["heavy_hitters"] = torch.tensor(
                 data["heavy_hitters"], dtype=torch.float32
@@ -54,7 +54,7 @@ class IndicesParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))

-    def dict(self) -> Dict[str, Any]:
+    def dict(self) -> dict[str, Any]:
         # https://stackoverflow.com/questions/73735974/convert-dataclass-of-dataclass-to-json-string
         tmp = dataclasses.asdict(self)
         # Convert tensor to list for JSON serialization
@@ -98,10 +98,12 @@ class BatchParams:
     vbe_distribution: Optional[str] = "normal"
     # Number of ranks for variable batch size generation
     vbe_num_ranks: Optional[int] = None
+    # List of target batch sizes, i.e. number of batch lookups per feature
+    Bs: Optional[list[int]] = None

     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data: Dict[str, Any]):
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)

     @classmethod
@@ -109,7 +111,7 @@ class BatchParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))

-    def dict(self) -> Dict[str, Any]:
+    def dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self)

     def json(self, format: bool = False) -> str:
@@ -117,7 +119,10 @@ class BatchParams:

     # pyre-ignore [3]
     def validate(self):
-        assert self.B > 0, "B must be positive"
+        if self.Bs is not None:
+            assert all(b > 0 for b in self.Bs), "All elements in Bs must be positive"
+        else:
+            assert self.B > 0, "B must be positive"
         assert not self.sigma_B or self.sigma_B > 0, "sigma_B must be positive"
         assert (
             self.vbe_num_ranks is None or self.vbe_num_ranks > 0
@@ -137,10 +142,12 @@ class PoolingParams:
     sigma_L: Optional[int] = None
     # [Optional] Distribution of embedding sequence lengths (normal, uniform)
     length_distribution: Optional[str] = "normal"
+    # [Optional] List of target bag sizes, i.e. pooling factors per batch
+    Ls: Optional[list[float]] = None

     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data: Dict[str, Any]):
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)

     @classmethod
@@ -148,7 +155,7 @@ class PoolingParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))

-    def dict(self) -> Dict[str, Any]:
+    def dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self)

     def json(self, format: bool = False) -> str:
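A minimal sketch of how the new Bs field changes BatchParams validation, using the field names shown in the hunks above and the import path from the loader's import block (example values are illustrative):

from fbgemm_gpu.tbe.bench.tbe_data_config import BatchParams

Bs = [256, 512, 1024]          # per-feature batch sizes
params = BatchParams(
    B=sum(Bs) // len(Bs),      # 597, kept as the average, as the loader does
    sigma_B=None,
    vbe_distribution="normal",
    vbe_num_ranks=2,
    Bs=Bs,
)
params.validate()              # with Bs set, every entry must be positive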
@@ -6,15 +6,14 @@

 # pyre-strict

-import logging
+from typing import List, Tuple

 import numpy as np
 import torch

+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import SparseType

-logging.basicConfig(level=logging.DEBUG)
-

 def fill_random_scale_bias(
     emb: torch.nn.Module,
@@ -23,9 +22,9 @@ def fill_random_scale_bias(
 ) -> None:
     for t in range(T):
         # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
-        (weights, scale_shift) = emb.split_embedding_weights()[t]
+        weights, scale_shift = emb.split_embedding_weights()[t]
         if scale_shift is not None:
-            (E, R) = scale_shift.shape
+            E, R = scale_shift.shape
             assert R == 4
             scales = None
             shifts = None
@@ -46,3 +45,128 @@
                     device=scale_shift.device,
                 )
             )
+
+
+def check_oom(
+    data_size: int,
+) -> Tuple[bool, str]:
+    free_memory, total_memory = torch.cuda.mem_get_info()
+    if data_size > free_memory:
+        warning = f"Expect to allocate {round(data_size / (1024 ** 3), 2)} GB, but available memory is {round(free_memory / (1024 ** 3), 2)} GB from {round(total_memory / (1024 ** 3), 2)} GB."
+        return (True, warning)
+    return (False, "")
+
+
+def generate_batch_size_per_feature_per_rank(
+    Bs: List[int], num_ranks: int
+) -> List[List[int]]:
+    """
+    Generate batch size per feature per rank for VBE, assuming the batch size
+    is evenly distributed across ranks.
+    Args:
+        Bs (List[int]): batch size per feature
+        num_ranks (int): number of ranks
+    Returns:
+        List[List[int]]: batch size per feature per rank
+    """
+    b_per_feature_per_rank = []
+    for B in Bs:
+        b_per_feature = []
+        for i in range(num_ranks):
+            if i != num_ranks - 1:
+                b_per_feature.append(int(B / num_ranks))
+            else:
+                b_per_feature.append(B - sum(b_per_feature))
+        b_per_feature_per_rank.append(b_per_feature)
+    return b_per_feature_per_rank
+
+
+def generate_merged_output_and_offsets(
+    Ds: List[int],
+    Bs: List[int],
+    output_dtype: torch.dtype,
+    device: torch.device,
+    num_ranks: int = 2,
+    num_tbe_ops: int = 2,
+) -> Tuple[List[List[int]], torch.Tensor, torch.Tensor]:
+    """
+    Generate merged vbe_output and vbe_output_offsets tensors for VBE.
+    The vbe_output is a tensor that will contain forward output from all VBE TBE ops.
+    The vbe_output_offsets is a tensor that will contain start offsets for the output to be written to.
+
+    Args:
+        Ds (List[int]): embedding dimension per feature
+        Bs (List[int]): batch size per feature
+        num_ranks (int): number of ranks
+        num_tbe_ops (int): number of TBE ops
+    Returns:
+        Tuple[List[List[int]], torch.Tensor, torch.Tensor]: batch_size_per_feature_per_rank, merged vbe_output and vbe_output_offsets tensors
+    """
+    # The first embedding ops is the embedding op created in the benchmark
+    emb_op = {}
+    emb_op[0] = {}
+    emb_op[0]["dim"] = Ds
+    emb_op[0]["Bs"] = Bs
+    emb_op[0]["output_size"] = sum([b * d for b, d in zip(Bs, Ds)])
+    emb_op[0]["batch_size_per_feature_per_rank"] = (
+        generate_batch_size_per_feature_per_rank(Bs, num_ranks)
+    )
+    num_features = len(Bs)
+    # create other embedding ops to allocate output and offsets tensors
+    # Using representative values for additional TBE ops in multi-op scenarios:
+    # - batch_size=32000: typical large batch size for production workloads
+    # - dim=512: common embedding dimension for large models
+    for i in range(1, num_tbe_ops):
+        emb_op[i] = {}
+        emb_op[i]["batch_size_per_feature_per_rank"] = (
+            generate_batch_size_per_feature_per_rank([32000], num_ranks)
+        )
+        emb_op[i]["Bs"] = [sum(B) for B in emb_op[i]["batch_size_per_feature_per_rank"]]
+        emb_op[i]["dim"] = [512]
+        emb_op[i]["output_size"] = sum(
+            [b * d for b, d in zip(emb_op[i]["Bs"], emb_op[i]["dim"])]
+        )
+    total_output = 0
+    ranks = [[] for _ in range(num_ranks)]
+    for e in emb_op.values():
+        b_per_rank_per_feature = list(zip(*e["batch_size_per_feature_per_rank"]))
+        assert len(b_per_rank_per_feature) == num_ranks
+        dims = e["dim"]
+        for r, b_r in enumerate(b_per_rank_per_feature):
+            for f, b in enumerate(b_r):
+                output_size_per_batch = b * dims[f]
+                ranks[r].append(output_size_per_batch)
+                total_output += output_size_per_batch
+    ranks[0].insert(0, 0)
+    offsets_ranks: List[List[int]] = [[] for _ in range(num_ranks)]
+    total_output_offsets = []
+    start = 0
+    for r in range(num_ranks):
+        offsets_ranks[r] = [
+            start + sum(ranks[r][: i + 1]) for i in range(len(ranks[r]))
+        ]
+        start = offsets_ranks[r][-1]
+        total_output_offsets.extend(offsets_ranks[r])
+    check_total_output_size = sum([e["output_size"] for e in emb_op.values()])
+    assert (
+        total_output == check_total_output_size
+    ), f"{total_output} != {check_total_output_size}{[e['output_size'] for e in emb_op.values()]}"
+    assert (
+        total_output == total_output_offsets[-1]
+    ), f"{total_output} != {total_output_offsets[-1]}"
+    out = torch.empty(total_output, dtype=output_dtype, device=device)
+    offsets = []
+    offsets.append(offsets_ranks[0][:num_features])
+    for r in range(1, num_ranks):
+        start = [offsets_ranks[r - 1][-1]]
+        the_rest = offsets_ranks[r][: num_features - 1] if num_features > 1 else []
+        start.extend(the_rest)
+        offsets.append(start)
+
+    out_offsets = torch.tensor(
+        offsets,
+        dtype=torch.int64,
+        device=device,
+    )
+    batch_size_per_feature_per_rank = emb_op[0]["batch_size_per_feature_per_rank"]
+    return (batch_size_per_feature_per_rank, out, out_offsets)
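A small worked example of the even-split helper added above; the last rank absorbs the remainder. The module path is inferred from the file listing (fbgemm_gpu/tbe/bench/utils.py):

from fbgemm_gpu.tbe.bench.utils import generate_batch_size_per_feature_per_rank

# Two features with batch sizes 100 and 7, split across 4 ranks
print(generate_batch_size_per_feature_per_rank([100, 7], num_ranks=4))
# [[25, 25, 25, 25], [1, 1, 1, 4]]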
@@ -10,7 +10,7 @@
 # pyre-ignore-all-errors[56]


-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union

 import torch  # usort:skip
 from torch import Tensor  # usort:skip
@@ -47,15 +47,15 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):

     def __init__(  # noqa C901
         self,
-        embedding_specs: List[
-            Tuple[str, int, int, SparseType, EmbeddingLocation]
+        embedding_specs: list[
+            tuple[str, int, int, SparseType, EmbeddingLocation]
         ],  # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
-        feature_table_map: Optional[List[int]] = None,  # [T]
-        index_remapping: Optional[List[Tensor]] = None,
+        feature_table_map: Optional[list[int]] = None,  # [T]
+        index_remapping: Optional[list[Tensor]] = None,
         pooling_mode: PoolingMode = PoolingMode.SUM,
         device: Optional[Union[str, int, torch.device]] = None,
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
-        weight_lists: Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None,
+        weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
         pruning_hash_load_factor: float = 0.5,
         use_array_for_index_remapping: bool = True,
         output_dtype: SparseType = SparseType.FP16,
@@ -74,8 +74,9 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
         cacheline_alignment: bool = True,
         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
         reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparam at begnning of each row.
-        feature_names_per_table: Optional[List[List[str]]] = None,
+        feature_names_per_table: Optional[list[list[str]]] = None,
         indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
+        embedding_cache_mode: bool = False,  # True for zero initialization, False for randomized initialization
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(KVEmbeddingInference, self).__init__(
             embedding_specs=embedding_specs,
@@ -114,17 +115,21 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
         num_shards = 32
         uniform_init_lower: float = -0.01
         uniform_init_upper: float = 0.01
+
         # pyre-fixme[4]: Attribute must be annotated.
         self.kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
-            num_shards, uniform_init_lower, uniform_init_upper
+            num_shards,
+            uniform_init_lower,
+            uniform_init_upper,
+            embedding_cache_mode,  # in embedding_cache_mode, we disable random init
         )

-        self.specs: List[Tuple[int, int, int]] = [
+        self.specs: list[tuple[int, int, int]] = [
             (rows, dims, sparse_type.as_int())
             for (_, rows, dims, sparse_type, _) in self.embedding_specs
         ]
         # table shard offset if inference sharding is enabled, otherwise, should be all zeros
-        self.table_sharding_offset: List[int] = [0] * len(self.embedding_specs)
+        self.table_sharding_offset: list[int] = [0] * len(self.embedding_specs)
         self.kv_embedding_cache_initialized = False
         self.hash_size_cumsum: torch.Tensor = torch.zeros(
             0,
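A hedged sketch of constructing the inference op with the new flag. The spec tuple layout is taken from the __init__ signature above; the concrete SparseType/EmbeddingLocation values and import paths are assumptions and not verified against the package:

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference

tbe = KVEmbeddingInference(
    # (feature_name, rows, dims, SparseType, EmbeddingLocation)
    embedding_specs=[("t0", 10_000, 128, SparseType.INT8, EmbeddingLocation.HOST)],
    embedding_cache_mode=True,  # zero-initialize the DRAM KV cache instead of random init
)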
@@ -137,7 +142,7 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
             dtype=torch.int64,
         )

-    def construct_hash_size_cumsum(self) -> List[int]:
+    def construct_hash_size_cumsum(self) -> list[int]:
         hash_size_cumsum = [0]
         for spec in self.embedding_specs:
             rows = spec[1]
@@ -146,7 +151,7 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):

     def calculate_indices_and_weights_offsets(
         self, indices: Tensor, offsets: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor]:
         if self.pooling_mode is not PoolingMode.NONE:
             T = self.weights_offsets.numel()
         else:
@@ -280,7 +285,7 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
         self.weight_initialized = True

     @torch.jit.export
-    def init_tbe_config(self, table_sharding_offset: List[int]) -> None:
+    def init_tbe_config(self, table_sharding_offset: list[int]) -> None:
         """
         Initialize the dynamic TBE table configs, e.g. sharded table offsets, etc.
         Should be called before loading weights.
@@ -290,9 +295,9 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
     @torch.jit.export
     def embedding_inplace_update(
         self,
-        update_table_indices: List[int],
-        update_row_indices: List[List[int]],
-        update_weights: List[Tensor],
+        update_table_indices: list[int],
+        update_row_indices: list[list[int]],
+        update_weights: list[Tensor],
     ) -> None:
         # function is not used for now on the inference side
         for i in range(len(update_table_indices)):
@@ -355,9 +360,7 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
         if not self.kv_embedding_cache_initialized:
             self.initialize_logical_weights_placements_and_offsets()

-            self.row_alignment = (
-                8 if self.use_cpu else self.row_alignment
-            )  # in order to use mempool implementation for kv embedding it needs to be divisible by 8
+            self.row_alignment = 8  # in order to use mempool implementation for kv embedding it needs to be divisible by 8

             hash_size_cumsum = self.construct_hash_size_cumsum()
             self.hash_size_cumsum = torch.tensor(
@@ -6,7 +6,7 @@

 # pyre-unsafe

-from typing import Optional, Tuple, Union
+from typing import Optional, Union

 import torch

@@ -17,13 +17,13 @@ def get_unique_indices_v2(
     compute_count: bool = False,
     compute_inverse_indices: bool = False,
 ) -> Union[
-    Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
-    Tuple[
+    tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
+    tuple[
         torch.Tensor,
         torch.Tensor,
         Optional[torch.Tensor],
     ],
-    Tuple[torch.Tensor, torch.Tensor],
+    tuple[torch.Tensor, torch.Tensor],
 ]:
     """
     A wrapper for get_unique_indices for overloading the return type
@@ -10,6 +10,7 @@

 import torch

+# fmt:skip
 from fbgemm_gpu.utils.loader import load_torch_module

 try:
@@ -13,7 +13,7 @@ import logging
 import os
 import tempfile
 from math import log2
-from typing import List, Optional, Tuple
+from typing import Optional

 import torch  # usort:skip

@@ -42,15 +42,15 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
     Inference version, with FP32/FP16/FP8/INT8/INT4/INT2 supports
     """

-    embedding_specs: List[Tuple[str, int, int, SparseType]]
+    embedding_specs: list[tuple[str, int, int, SparseType]]
     _local_instance_index: int = -1

     def __init__(
         self,
-        embedding_specs: List[
-            Tuple[str, int, int, SparseType]
+        embedding_specs: list[
+            tuple[str, int, int, SparseType]
         ],  # tuple of (feature_names, rows, dims, SparseType)
-        feature_table_map: Optional[List[int]] = None,  # [T]
+        feature_table_map: Optional[list[int]] = None,  # [T]
         pooling_mode: PoolingMode = PoolingMode.SUM,
         output_dtype: SparseType = SparseType.FP16,
         row_alignment: Optional[int] = None,
@@ -73,7 +73,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
         ssd_uniform_init_lower: float = -0.01,
         ssd_uniform_init_upper: float = 0.01,
         # Parameter Server Configs
-        ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
+        ps_hosts: Optional[tuple[tuple[str, int]]] = None,
         ps_max_key_per_request: Optional[int] = None,
         ps_client_thread_num: Optional[int] = None,
         ps_max_local_index_length: Optional[int] = None,
@@ -99,7 +99,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
         self.current_device = torch.device(device)
         self.use_cpu: bool = self.current_device.type == "cpu"

-        self.feature_table_map: List[int] = (
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
         T = len(self.feature_table_map)
@@ -112,9 +112,9 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
         self.output_dtype: int = output_dtype.as_int()
         # (feature_names, rows, dims, weights_tys) = zip(*embedding_specs)
         # Pyre workaround
-        rows: List[int] = [e[1] for e in embedding_specs]
-        dims: List[int] = [e[2] for e in embedding_specs]
-        weights_tys: List[SparseType] = [e[3] for e in embedding_specs]
+        rows: list[int] = [e[1] for e in embedding_specs]
+        dims: list[int] = [e[2] for e in embedding_specs]
+        weights_tys: list[SparseType] = [e[3] for e in embedding_specs]

         D_offsets = [dims[t] for t in self.feature_table_map]
         D_offsets = [0] + list(itertools.accumulate(D_offsets))
@@ -169,7 +169,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
             offsets.append(uvm_size)
             uvm_size += state_size

-        self.weights_physical_offsets: List[int] = offsets
+        self.weights_physical_offsets: list[int] = offsets

         weights_tys_int = [weights_tys[t].as_int() for t in self.feature_table_map]
         self.register_buffer(
@@ -306,7 +306,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
         )

         # pyre-fixme[20]: Argument `self` expected.
-        (low_priority, high_priority) = torch.cuda.Stream.priority_range()
+        low_priority, high_priority = torch.cuda.Stream.priority_range()
         self.ssd_stream = torch.cuda.Stream(priority=low_priority)
         self.ssd_set_start = torch.cuda.Event()
         self.ssd_set_end = torch.cuda.Event()
@@ -369,7 +369,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):

     @torch.jit.export
     def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor:
-        (indices, offsets) = indices.long(), offsets.long()
+        indices, offsets = indices.long(), offsets.long()
         linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
             self.hash_size_cumsum,
             indices,
@@ -517,13 +517,13 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
     @torch.jit.export
     def split_embedding_weights(
         self, split_scale_shifts: bool = True
-    ) -> List[Tuple[Tensor, Optional[Tensor]]]:
+    ) -> list[tuple[Tensor, Optional[Tensor]]]:
         """
         Returns a list of weights, split by table.

         Testing only, very slow.
         """
-        splits: List[Tuple[Tensor, Optional[Tensor]]] = []
+        splits: list[tuple[Tensor, Optional[Tensor]]] = []
         rows_cumsum = 0
         for _, row, dim, weight_ty in self.embedding_specs:
             weights = torch.empty(