fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py
@@ -0,0 +1,385 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ # pyre-ignore-all-errors[56]
+
+
+ from typing import Optional, Union
+
+ import torch  # usort:skip
+ from torch import Tensor  # usort:skip
+ from fbgemm_gpu.split_embedding_configs import SparseType
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+     BoundsCheckMode,
+     CacheAlgorithm,
+     DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
+     EmbeddingLocation,
+     PoolingMode,
+     RecordCacheMetrics,
+ )
+ from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
+     inputs_to_device,
+     IntNBitTableBatchedEmbeddingBagsCodegen,
+     random_quant_scaled_tensor,
+     rounded_row_size_in_bytes,
+ )
+ from fbgemm_gpu.utils.loader import load_torch_module
+
+ try:
+     load_torch_module(
+         "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference",
+     )
+ except Exception:
+     pass
+
+
+ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
+     """
+     KV Table-batched version of nn.EmbeddingBag(sparse=False)
+     Inference version, with support for FP32/FP16/FP8/INT8/INT4/INT2 weights
+     """
+
+     def __init__(  # noqa C901
+         self,
+         embedding_specs: list[
+             tuple[str, int, int, SparseType, EmbeddingLocation]
+         ],  # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
+         feature_table_map: Optional[list[int]] = None,  # [T]
+         index_remapping: Optional[list[Tensor]] = None,
+         pooling_mode: PoolingMode = PoolingMode.SUM,
+         device: Optional[Union[str, int, torch.device]] = None,
+         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
+         weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
+         pruning_hash_load_factor: float = 0.5,
+         use_array_for_index_remapping: bool = True,
+         output_dtype: SparseType = SparseType.FP16,
+         cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
+         cache_load_factor: float = 0.2,
+         cache_sets: int = 0,
+         cache_reserved_memory: float = 0.0,
+         enforce_hbm: bool = False,  # place all weights/momentums in HBM when using cache
+         record_cache_metrics: Optional[RecordCacheMetrics] = None,
+         gather_uvm_cache_stats: Optional[bool] = False,
+         row_alignment: Optional[int] = None,
+         fp8_exponent_bits: Optional[int] = None,
+         fp8_exponent_bias: Optional[int] = None,
+         cache_assoc: int = 32,
+         scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
+         cacheline_alignment: bool = True,
+         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
+         reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparam at beginning of each row.
+         feature_names_per_table: Optional[list[list[str]]] = None,
+         indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
+         embedding_cache_mode: bool = False,  # True for zero initialization, False for randomized initialization
+     ) -> None:  # noqa C901  # tuple of (rows, dims,)
+         super(KVEmbeddingInference, self).__init__(
+             embedding_specs=embedding_specs,
+             feature_table_map=feature_table_map,
+             index_remapping=index_remapping,
+             pooling_mode=pooling_mode,
+             device=device,
+             bounds_check_mode=bounds_check_mode,
+             weight_lists=weight_lists,
+             pruning_hash_load_factor=pruning_hash_load_factor,
+             use_array_for_index_remapping=use_array_for_index_remapping,
+             output_dtype=output_dtype,
+             cache_algorithm=cache_algorithm,
+             cache_load_factor=cache_load_factor,
+             cache_sets=cache_sets,
+             cache_reserved_memory=cache_reserved_memory,
+             enforce_hbm=enforce_hbm,
+             record_cache_metrics=record_cache_metrics,
+             gather_uvm_cache_stats=gather_uvm_cache_stats,
+             row_alignment=row_alignment,
+             fp8_exponent_bits=fp8_exponent_bits,
+             fp8_exponent_bias=fp8_exponent_bias,
+             cache_assoc=cache_assoc,
+             scale_bias_size_in_bytes=scale_bias_size_in_bytes,
+             cacheline_alignment=cacheline_alignment,
+             uvm_host_mapped=uvm_host_mapped,
+             reverse_qparam=reverse_qparam,
+             feature_names_per_table=feature_names_per_table,
+             indices_dtype=indices_dtype,
+         )
+         self.register_buffer(
+             "weights_ids",
+             torch.tensor(0, device=self.current_device, dtype=torch.int64),
+         )
+
+         num_shards = 32
+         uniform_init_lower: float = -0.01
+         uniform_init_upper: float = 0.01
+
+         # pyre-fixme[4]: Attribute must be annotated.
+         self.kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
+             num_shards,
+             uniform_init_lower,
+             uniform_init_upper,
+             embedding_cache_mode,  # in embedding_cache_mode, we disable random init
+         )
+
+         self.specs: list[tuple[int, int, int]] = [
+             (rows, dims, sparse_type.as_int())
+             for (_, rows, dims, sparse_type, _) in self.embedding_specs
+         ]
+         # table shard offset if inference sharding is enabled, otherwise, should be all zeros
+         self.table_sharding_offset: list[int] = [0] * len(self.embedding_specs)
+         self.kv_embedding_cache_initialized = False
+         self.hash_size_cumsum: torch.Tensor = torch.zeros(
+             0,
+             device=self.current_device,
+             dtype=torch.int64,
+         )
+         self.feature_hash_size_cumsum: torch.Tensor = torch.zeros(
+             0,
+             device=self.current_device,
+             dtype=torch.int64,
+         )
+
+     def construct_hash_size_cumsum(self) -> list[int]:
+         hash_size_cumsum = [0]
+         for spec in self.embedding_specs:
+             rows = spec[1]
+             hash_size_cumsum.append(hash_size_cumsum[-1] + rows)
+         return hash_size_cumsum
+
+     def calculate_indices_and_weights_offsets(
+         self, indices: Tensor, offsets: Tensor
+     ) -> tuple[Tensor, Tensor]:
+         if self.pooling_mode is not PoolingMode.NONE:
+             T = self.weights_offsets.numel()
+         else:
+             T = self.D_offsets.numel() - 1
+         B = int((offsets.size(0) - 1) / T)
+
+         total_bytes_added = 0
+         new_indices = torch.tensor(
+             [0] * indices.size(0), device=self.current_device, dtype=indices.dtype
+         )
+         new_weights_offsets = torch.tensor(
+             [0] * T, device=self.current_device, dtype=self.weights_offsets.dtype
+         )
+         for t in range(T):
+             new_weights_offsets[t] = total_bytes_added
+             start, end = int(offsets[t * B]), int(offsets[(t + 1) * B])
+             index_size = end - start
+             new_indices[start:end] = torch.arange(index_size)
+             table_id = self.feature_table_map[t]
+             total_bytes_added += index_size * rounded_row_size_in_bytes(
+                 self.embedding_specs[table_id][2],  # dim
+                 self.embedding_specs[table_id][3],  # weight_ty
+                 self.row_alignment,
+                 self.scale_bias_size_in_bytes,
+             )
+         return new_indices, new_weights_offsets
+
+     def linearize_cache_indices(
+         self,
+         indices: torch.Tensor,
+         offsets: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Linearize cache indices for KV cache.
+         """
+         linearized_indices = torch.zeros(
+             indices.numel(),
+             device=indices.device,
+             dtype=torch.int64,
+         )
+
+         T = self.feature_hash_size_cumsum.numel() - 1
+         B = int((offsets.size(0) - 1) / T)
+
+         for t in range(T):
+             start, end = int(offsets[t * B]), int(offsets[(t + 1) * B])
+             linearized_indices[start:end] = (
+                 indices[start:end] + self.feature_hash_size_cumsum[t]
+             )
+
+         return linearized_indices
+
+     def forward(
+         self,
+         indices: Tensor,
+         offsets: Tensor,
+         per_sample_weights: Optional[Tensor] = None,
+     ) -> Tensor:
+         assert (
+             self.weight_initialized
+         ), "weight needs to be initialized before forward function"
+
+         indices, offsets, per_sample_weights = inputs_to_device(
+             indices, offsets, per_sample_weights, self.bounds_check_warning
+         )
+
+         lxu_cache_locations = self.lxu_cache_locations_list.pop()
+
+         weights_offsets = self.weights_offsets
+         weights = self.weights_host if self.host_size > 0 else self.weights_dev
+
+         if self.kv_embedding_cache_initialized:
+             indices = self.linearize_cache_indices(
+                 indices,
+                 offsets,
+             )
+
+             weights = self.kv_embedding_cache.get_embeddings(indices)
+
+             indices, weights_offsets = self.calculate_indices_and_weights_offsets(
+                 indices, offsets
+             )
+
+         return torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(
+             dev_weights=weights,
+             uvm_weights=self.weights_uvm,
+             weights_placements=self.weights_placements,
+             weights_offsets=weights_offsets,
+             weights_tys=self.weights_tys,
+             D_offsets=self.D_offsets,
+             total_D=self.total_D,
+             max_int2_D=self.max_int2_D,
+             max_int4_D=self.max_int4_D,
+             max_int8_D=self.max_int8_D,
+             max_float16_D=self.max_float16_D,
+             max_float32_D=self.max_float32_D,
+             indices=indices,
+             offsets=offsets,
+             pooling_mode=int(self.pooling_mode),
+             indice_weights=per_sample_weights,
+             output_dtype=self.output_dtype,
+             lxu_cache_weights=self.lxu_cache_weights,
+             lxu_cache_locations=lxu_cache_locations,
+             row_alignment=self.row_alignment,
+             max_float8_D=self.max_float8_D,
+             fp8_exponent_bits=self.fp8_exponent_bits,
+             fp8_exponent_bias=self.fp8_exponent_bias,
+         )
+
+     def fill_random_weights(self) -> None:
+         """
+         Fill the buffer with random weights, table by table
+         """
+         self.initialize_kv_embedding_cache()
+         for i, (_, num_embeddings, embedding_dim, weight_ty, _) in enumerate(
+             self.embedding_specs
+         ):
+             embedding_dim = rounded_row_size_in_bytes(
+                 embedding_dim, weight_ty, self.row_alignment
+             )
+             indices = torch.range(0, num_embeddings - 1, dtype=torch.int64)
+             weights = random_quant_scaled_tensor(
+                 shape=torch.Size([num_embeddings, embedding_dim]),
+                 device=self.current_device,
+             )
+             self.embedding_inplace_update_per_table(
+                 i,
+                 indices,
+                 weights,
+             )
+         self.weight_initialized = True
+
+     @torch.jit.export
+     def init_tbe_config(self, table_sharding_offset: list[int]) -> None:
+         """
+         Initialize the dynamic TBE table configs, e.g. sharded table offsets, etc.
+         Should be called before loading weights.
+         """
+         self.table_sharding_offset = table_sharding_offset
+
+     @torch.jit.export
+     def embedding_inplace_update(
+         self,
+         update_table_indices: list[int],
+         update_row_indices: list[list[int]],
+         update_weights: list[Tensor],
+     ) -> None:
+         # function is not used for now on the inference side
+         for i in range(len(update_table_indices)):
+             self.embedding_inplace_update_per_table(
+                 update_table_indices[i],
+                 torch.tensor(
+                     update_row_indices[i], device=self.current_device, dtype=torch.int64
+                 ),
+                 update_weights[i],
+                 None,
+             )
+
+     @torch.jit.export
+     def embedding_inplace_update_per_table(
+         self,
+         table_id: int,
+         update_row_indices: Tensor,
+         update_weights: Tensor,
+         inplace_update_ts_sec: Optional[int] = None,
+     ) -> None:
+         assert table_id < len(
+             self.embedding_specs
+         ), f"table index {table_id} is out of range {len(self.embedding_specs)}"
+         # pyre-ignore [29]
+         table_offset = self.hash_size_cumsum[table_id]
+         sharding_offset = self.table_sharding_offset[table_id]
+
+         row_size = update_row_indices.numel()
+         if row_size == 0:
+             return
+
+         # convert global weight index to fused local weight index
+         row_indices = update_row_indices + table_offset - sharding_offset
+         # set weight by id
+         self.kv_embedding_cache.set_embeddings(
+             row_indices, update_weights, inplace_update_ts_sec
+         )
+
+     @torch.jit.export
+     def log_inplace_update_stats(
+         self,
+     ) -> None:
+         self.kv_embedding_cache.log_inplace_update_stats()
+
+     @torch.jit.export
+     def embedding_trigger_evict(
+         self,
+         inplace_update_ts_sec: int,
+     ) -> None:
+         self.kv_embedding_cache.trigger_evict(inplace_update_ts_sec)
+
+     @torch.jit.export
+     def embedding_wait_evict_completion(
+         self,
+     ) -> None:
+         self.kv_embedding_cache.wait_evict_completion()
+
+     @torch.jit.export
+     def initialize_kv_embedding_cache(self) -> None:
+         if not self.kv_embedding_cache_initialized:
+             self.initialize_logical_weights_placements_and_offsets()
+
+             self.row_alignment = 8  # in order to use mempool implementation for kv embedding it needs to be divisible by 8
+
+             hash_size_cumsum = self.construct_hash_size_cumsum()
+             self.hash_size_cumsum = torch.tensor(
+                 hash_size_cumsum,
+                 dtype=torch.int64,
+                 device=self.current_device,
+             )
+
+             self.feature_hash_size_cumsum = torch.tensor(
+                 [hash_size_cumsum[t] for t in self.feature_table_map]
+                 + [hash_size_cumsum[-1]],
+                 dtype=torch.int64,
+                 device=self.current_device,
+             )
+
+             self.kv_embedding_cache.init(
+                 self.specs,
+                 self.row_alignment,
+                 self.scale_bias_size_in_bytes,
+                 self.hash_size_cumsum,
+             )
+             self.kv_embedding_cache_initialized = True
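The class above wires the DRAM-backed KV cache into the existing int-nbit TBE inference path. A minimal usage sketch follows. It is illustrative only: the table specs, batch layout, and device are invented here, and it assumes the optional dram_kv_embedding_inference extension loaded successfully in the try/except above.

    import torch

    from fbgemm_gpu.split_embedding_configs import SparseType
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
        EmbeddingLocation,
        PoolingMode,
    )
    from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference

    # Two hypothetical tables: (feature_name, rows, dim, weight type, placement).
    tbe = KVEmbeddingInference(
        embedding_specs=[
            ("table_0", 1000, 64, SparseType.INT8, EmbeddingLocation.HOST),
            ("table_1", 2000, 32, SparseType.INT4, EmbeddingLocation.HOST),
        ],
        pooling_mode=PoolingMode.SUM,
        device="cpu",
    )
    # Fills the DRAM KV cache with random quantized rows and marks weights initialized.
    tbe.fill_random_weights()

    # Batched lookup: offsets has T * B + 1 entries (here T = 2 features, B = 1 sample).
    indices = torch.tensor([1, 5, 7, 3], dtype=torch.int32)
    offsets = torch.tensor([0, 2, 4], dtype=torch.int32)
    pooled = tbe(indices, offsets)  # shape [B, 64 + 32] with SUM pooling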
fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py
@@ -6,7 +6,7 @@
  
  # pyre-unsafe
  
- from typing import Optional, Tuple, Union
+ from typing import Optional, Union
  
  import torch
  
@@ -17,13 +17,13 @@ def get_unique_indices_v2(
      compute_count: bool = False,
      compute_inverse_indices: bool = False,
  ) -> Union[
-     Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
-     Tuple[
+     tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
+     tuple[
          torch.Tensor,
          torch.Tensor,
          Optional[torch.Tensor],
-         Tuple[torch.Tensor, torch.Tensor],
      ],
+     tuple[torch.Tensor, torch.Tensor],
  ]:
      """
      A wrapper for get_unique_indices for overloading the return type
@@ -43,7 +43,6 @@ def get_unique_indices_v2(
          return ret[:-1]
      if compute_inverse_indices:
          # Return (unique_indices, length, inverse_indices)
-         # pyre-fixme[7]: The arity of this return is wrong (3 vs 4)
          return ret[0], ret[1], ret[3]
      # Return (unique_indices, length)
      return ret[:-2]
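The hunks above modernize the overload annotation: depending on the compute_count and compute_inverse_indices flags the wrapper returns a 2-, 3-, or 4-tuple. A caller-side sketch of that, with placeholder arguments since the leading positional parameters of get_unique_indices_v2 are not visible in these hunks:

    # linear_indices and total_hash_size are hypothetical placeholders for the
    # arguments not shown in this diff; the unpacking arities follow the return
    # annotation and comments above.
    unique, lengths = get_unique_indices_v2(linear_indices, total_hash_size)
    unique, lengths, counts = get_unique_indices_v2(
        linear_indices, total_hash_size, compute_count=True
    )
    unique, lengths, inverse = get_unique_indices_v2(
        linear_indices, total_hash_size, compute_inverse_indices=True
    )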
fbgemm_gpu/tbe/ssd/common.py
@@ -8,6 +8,9 @@
  # pyre-strict
  # pyre-ignore-all-errors[56]
  
+ import torch
+
+ # fmt:skip
  from fbgemm_gpu.utils.loader import load_torch_module
  
  try:
@@ -18,3 +21,27 @@ except Exception:
      pass
  
  ASSOC = 32
+
+
+ def pad4(value: int) -> int:
+     """
+     Compute the smallest multiple of 4 that is greater than or equal to the given value.
+
+     Parameters:
+         value (int): The integer to align (must be non-negative).
+
+     Returns:
+         int: The aligned value.
+
+     Raises:
+         ValueError: If the input is negative.
+         TypeError: If the input is not an integer.
+     """
+     return (int(value) + 3) & ~3
+
+
+ def tensor_pad4(value: torch.Tensor) -> torch.Tensor:
+     """
+     The equivalent of pad4 for tensors.
+     """
+     return (value + 3) & ~3
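A quick illustration of the new alignment helpers, assuming they are importable from fbgemm_gpu.tbe.ssd.common as this hunk suggests: pad4 rounds an integer up to the next multiple of 4 via (value + 3) & ~3, and tensor_pad4 applies the same trick elementwise.

    import torch

    from fbgemm_gpu.tbe.ssd.common import pad4, tensor_pad4

    assert pad4(0) == 0
    assert pad4(5) == 8   # 5 rounds up to the next multiple of 4
    assert pad4(8) == 8   # multiples of 4 are unchanged
    assert torch.equal(
        tensor_pad4(torch.tensor([0, 1, 4, 5])),
        torch.tensor([0, 4, 4, 8]),
    )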
fbgemm_gpu/tbe/ssd/inference.py
@@ -13,7 +13,7 @@ import logging
  import os
  import tempfile
  from math import log2
- from typing import List, Optional, Tuple
+ from typing import Optional
  
  import torch  # usort:skip
  
@@ -42,15 +42,15 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
      Inference version, with FP32/FP16/FP8/INT8/INT4/INT2 supports
      """
  
-     embedding_specs: List[Tuple[str, int, int, SparseType]]
+     embedding_specs: list[tuple[str, int, int, SparseType]]
      _local_instance_index: int = -1
  
      def __init__(
          self,
-         embedding_specs: List[
-             Tuple[str, int, int, SparseType]
+         embedding_specs: list[
+             tuple[str, int, int, SparseType]
          ],  # tuple of (feature_names, rows, dims, SparseType)
-         feature_table_map: Optional[List[int]] = None,  # [T]
+         feature_table_map: Optional[list[int]] = None,  # [T]
          pooling_mode: PoolingMode = PoolingMode.SUM,
          output_dtype: SparseType = SparseType.FP16,
          row_alignment: Optional[int] = None,
@@ -73,7 +73,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
          ssd_uniform_init_lower: float = -0.01,
          ssd_uniform_init_upper: float = 0.01,
          # Parameter Server Configs
-         ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
+         ps_hosts: Optional[tuple[tuple[str, int]]] = None,
          ps_max_key_per_request: Optional[int] = None,
          ps_client_thread_num: Optional[int] = None,
          ps_max_local_index_length: Optional[int] = None,
@@ -99,7 +99,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
          self.current_device = torch.device(device)
          self.use_cpu: bool = self.current_device.type == "cpu"
  
-         self.feature_table_map: List[int] = (
+         self.feature_table_map: list[int] = (
              feature_table_map if feature_table_map is not None else list(range(T_))
          )
          T = len(self.feature_table_map)
@@ -112,9 +112,9 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
          self.output_dtype: int = output_dtype.as_int()
          # (feature_names, rows, dims, weights_tys) = zip(*embedding_specs)
          # Pyre workaround
-         rows: List[int] = [e[1] for e in embedding_specs]
-         dims: List[int] = [e[2] for e in embedding_specs]
-         weights_tys: List[SparseType] = [e[3] for e in embedding_specs]
+         rows: list[int] = [e[1] for e in embedding_specs]
+         dims: list[int] = [e[2] for e in embedding_specs]
+         weights_tys: list[SparseType] = [e[3] for e in embedding_specs]
  
          D_offsets = [dims[t] for t in self.feature_table_map]
          D_offsets = [0] + list(itertools.accumulate(D_offsets))
@@ -169,7 +169,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
              offsets.append(uvm_size)
              uvm_size += state_size
  
-         self.weights_physical_offsets: List[int] = offsets
+         self.weights_physical_offsets: list[int] = offsets
  
          weights_tys_int = [weights_tys[t].as_int() for t in self.feature_table_map]
          self.register_buffer(
@@ -306,7 +306,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
          )
  
          # pyre-fixme[20]: Argument `self` expected.
-         (low_priority, high_priority) = torch.cuda.Stream.priority_range()
+         low_priority, high_priority = torch.cuda.Stream.priority_range()
          self.ssd_stream = torch.cuda.Stream(priority=low_priority)
          self.ssd_set_start = torch.cuda.Event()
          self.ssd_set_end = torch.cuda.Event()
@@ -369,7 +369,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
  
      @torch.jit.export
      def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor:
-         (indices, offsets) = indices.long(), offsets.long()
+         indices, offsets = indices.long(), offsets.long()
          linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
              self.hash_size_cumsum,
              indices,
@@ -517,13 +517,13 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
      @torch.jit.export
      def split_embedding_weights(
          self, split_scale_shifts: bool = True
-     ) -> List[Tuple[Tensor, Optional[Tensor]]]:
+     ) -> list[tuple[Tensor, Optional[Tensor]]]:
          """
          Returns a list of weights, split by table.
  
          Testing only, very slow.
          """
-         splits: List[Tuple[Tensor, Optional[Tensor]]] = []
+         splits: list[tuple[Tensor, Optional[Tensor]]] = []
          rows_cumsum = 0
          for _, row, dim, weight_ty in self.embedding_specs:
              weights = torch.empty(
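The hunks in this file are annotation-only: typing.List and typing.Tuple are swapped for the built-in list and tuple generics (PEP 585, fully supported on the Python 3.11 interpreter this cp311 wheel targets), plus two cosmetic removals of parentheses around tuple assignments. A before/after sketch of the pattern, using a made-up embedding_specs value just so the snippet runs:

    from typing import List  # the old annotations needed this import

    embedding_specs = [("f0", 100, 64, "INT8"), ("f1", 200, 32, "INT4")]  # hypothetical

    rows_old: List[int] = [e[1] for e in embedding_specs]  # before: typing.List
    rows_new: list[int] = [e[1] for e in embedding_specs]  # after: built-in generic
    assert rows_old == rows_new == [100, 200]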