fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py
@@ -0,0 +1,385 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+# pyre-ignore-all-errors[56]
+
+
+from typing import Optional, Union
+
+import torch  # usort:skip
+from torch import Tensor  # usort:skip
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    BoundsCheckMode,
+    CacheAlgorithm,
+    DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
+    EmbeddingLocation,
+    PoolingMode,
+    RecordCacheMetrics,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
+    inputs_to_device,
+    IntNBitTableBatchedEmbeddingBagsCodegen,
+    random_quant_scaled_tensor,
+    rounded_row_size_in_bytes,
+)
+from fbgemm_gpu.utils.loader import load_torch_module
+
+try:
+    load_torch_module(
+        "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference",
+    )
+except Exception:
+    pass
+
+
+class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
+    """
+    KV Table-batched version of nn.EmbeddingBag(sparse=False)
+    Inference version, with support for FP32/FP16/FP8/INT8/INT4/INT2 weights
+    """
+
+    def __init__(  # noqa C901
+        self,
+        embedding_specs: list[
+            tuple[str, int, int, SparseType, EmbeddingLocation]
+        ],  # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
+        feature_table_map: Optional[list[int]] = None,  # [T]
+        index_remapping: Optional[list[Tensor]] = None,
+        pooling_mode: PoolingMode = PoolingMode.SUM,
+        device: Optional[Union[str, int, torch.device]] = None,
+        bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
+        weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
+        pruning_hash_load_factor: float = 0.5,
+        use_array_for_index_remapping: bool = True,
+        output_dtype: SparseType = SparseType.FP16,
+        cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
+        cache_load_factor: float = 0.2,
+        cache_sets: int = 0,
+        cache_reserved_memory: float = 0.0,
+        enforce_hbm: bool = False,  # place all weights/momentums in HBM when using cache
+        record_cache_metrics: Optional[RecordCacheMetrics] = None,
+        gather_uvm_cache_stats: Optional[bool] = False,
+        row_alignment: Optional[int] = None,
+        fp8_exponent_bits: Optional[int] = None,
+        fp8_exponent_bias: Optional[int] = None,
+        cache_assoc: int = 32,
+        scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
+        cacheline_alignment: bool = True,
+        uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
+        reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparams at beginning of each row.
+        feature_names_per_table: Optional[list[list[str]]] = None,
+        indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
+        embedding_cache_mode: bool = False,  # True for zero initialization, False for randomized initialization
+    ) -> None:  # noqa C901  # tuple of (rows, dims,)
+        super(KVEmbeddingInference, self).__init__(
+            embedding_specs=embedding_specs,
+            feature_table_map=feature_table_map,
+            index_remapping=index_remapping,
+            pooling_mode=pooling_mode,
+            device=device,
+            bounds_check_mode=bounds_check_mode,
+            weight_lists=weight_lists,
+            pruning_hash_load_factor=pruning_hash_load_factor,
+            use_array_for_index_remapping=use_array_for_index_remapping,
+            output_dtype=output_dtype,
+            cache_algorithm=cache_algorithm,
+            cache_load_factor=cache_load_factor,
+            cache_sets=cache_sets,
+            cache_reserved_memory=cache_reserved_memory,
+            enforce_hbm=enforce_hbm,
+            record_cache_metrics=record_cache_metrics,
+            gather_uvm_cache_stats=gather_uvm_cache_stats,
+            row_alignment=row_alignment,
+            fp8_exponent_bits=fp8_exponent_bits,
+            fp8_exponent_bias=fp8_exponent_bias,
+            cache_assoc=cache_assoc,
+            scale_bias_size_in_bytes=scale_bias_size_in_bytes,
+            cacheline_alignment=cacheline_alignment,
+            uvm_host_mapped=uvm_host_mapped,
+            reverse_qparam=reverse_qparam,
+            feature_names_per_table=feature_names_per_table,
+            indices_dtype=indices_dtype,
+        )
+        self.register_buffer(
+            "weights_ids",
+            torch.tensor(0, device=self.current_device, dtype=torch.int64),
+        )
+
+        num_shards = 32
+        uniform_init_lower: float = -0.01
+        uniform_init_upper: float = 0.01
+
+        # pyre-fixme[4]: Attribute must be annotated.
+        self.kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
+            num_shards,
+            uniform_init_lower,
+            uniform_init_upper,
+            embedding_cache_mode,  # in embedding_cache_mode, we disable random init
+        )
+
+        self.specs: list[tuple[int, int, int]] = [
+            (rows, dims, sparse_type.as_int())
+            for (_, rows, dims, sparse_type, _) in self.embedding_specs
+        ]
+        # table shard offset if inference sharding is enabled; otherwise, should be all zeros
+        self.table_sharding_offset: list[int] = [0] * len(self.embedding_specs)
+        self.kv_embedding_cache_initialized = False
+        self.hash_size_cumsum: torch.Tensor = torch.zeros(
+            0,
+            device=self.current_device,
+            dtype=torch.int64,
+        )
+        self.feature_hash_size_cumsum: torch.Tensor = torch.zeros(
+            0,
+            device=self.current_device,
+            dtype=torch.int64,
+        )
+
+    def construct_hash_size_cumsum(self) -> list[int]:
+        hash_size_cumsum = [0]
+        for spec in self.embedding_specs:
+            rows = spec[1]
+            hash_size_cumsum.append(hash_size_cumsum[-1] + rows)
+        return hash_size_cumsum
+
+    def calculate_indices_and_weights_offsets(
+        self, indices: Tensor, offsets: Tensor
+    ) -> tuple[Tensor, Tensor]:
+        if self.pooling_mode is not PoolingMode.NONE:
+            T = self.weights_offsets.numel()
+        else:
+            T = self.D_offsets.numel() - 1
+        B = int((offsets.size(0) - 1) / T)
+
+        total_bytes_added = 0
+        new_indices = torch.tensor(
+            [0] * indices.size(0), device=self.current_device, dtype=indices.dtype
+        )
+        new_weights_offsets = torch.tensor(
+            [0] * T, device=self.current_device, dtype=self.weights_offsets.dtype
+        )
+        for t in range(T):
+            new_weights_offsets[t] = total_bytes_added
+            start, end = int(offsets[t * B]), int(offsets[(t + 1) * B])
+            index_size = end - start
+            new_indices[start:end] = torch.arange(index_size)
+            table_id = self.feature_table_map[t]
+            total_bytes_added += index_size * rounded_row_size_in_bytes(
+                self.embedding_specs[table_id][2],  # dim
+                self.embedding_specs[table_id][3],  # weight_ty
+                self.row_alignment,
+                self.scale_bias_size_in_bytes,
+            )
+        return new_indices, new_weights_offsets
+
+    def linearize_cache_indices(
+        self,
+        indices: torch.Tensor,
+        offsets: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Linearize cache indices for KV cache.
+        """
+        linearized_indices = torch.zeros(
+            indices.numel(),
+            device=indices.device,
+            dtype=torch.int64,
+        )
+
+        T = self.feature_hash_size_cumsum.numel() - 1
+        B = int((offsets.size(0) - 1) / T)
+
+        for t in range(T):
+            start, end = int(offsets[t * B]), int(offsets[(t + 1) * B])
+            linearized_indices[start:end] = (
+                indices[start:end] + self.feature_hash_size_cumsum[t]
+            )
+
+        return linearized_indices
+
+    def forward(
+        self,
+        indices: Tensor,
+        offsets: Tensor,
+        per_sample_weights: Optional[Tensor] = None,
+    ) -> Tensor:
+        assert (
+            self.weight_initialized
+        ), "weight needs to be initialized before forward function"
+
+        indices, offsets, per_sample_weights = inputs_to_device(
+            indices, offsets, per_sample_weights, self.bounds_check_warning
+        )
+
+        lxu_cache_locations = self.lxu_cache_locations_list.pop()
+
+        weights_offsets = self.weights_offsets
+        weights = self.weights_host if self.host_size > 0 else self.weights_dev
+
+        if self.kv_embedding_cache_initialized:
+            indices = self.linearize_cache_indices(
+                indices,
+                offsets,
+            )
+
+            weights = self.kv_embedding_cache.get_embeddings(indices)
+
+            indices, weights_offsets = self.calculate_indices_and_weights_offsets(
+                indices, offsets
+            )
+
+        return torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(
+            dev_weights=weights,
+            uvm_weights=self.weights_uvm,
+            weights_placements=self.weights_placements,
+            weights_offsets=weights_offsets,
+            weights_tys=self.weights_tys,
+            D_offsets=self.D_offsets,
+            total_D=self.total_D,
+            max_int2_D=self.max_int2_D,
+            max_int4_D=self.max_int4_D,
+            max_int8_D=self.max_int8_D,
+            max_float16_D=self.max_float16_D,
+            max_float32_D=self.max_float32_D,
+            indices=indices,
+            offsets=offsets,
+            pooling_mode=int(self.pooling_mode),
+            indice_weights=per_sample_weights,
+            output_dtype=self.output_dtype,
+            lxu_cache_weights=self.lxu_cache_weights,
+            lxu_cache_locations=lxu_cache_locations,
+            row_alignment=self.row_alignment,
+            max_float8_D=self.max_float8_D,
+            fp8_exponent_bits=self.fp8_exponent_bits,
+            fp8_exponent_bias=self.fp8_exponent_bias,
+        )
+
+    def fill_random_weights(self) -> None:
+        """
+        Fill the buffer with random weights, table by table
+        """
+        self.initialize_kv_embedding_cache()
+        for i, (_, num_embeddings, embedding_dim, weight_ty, _) in enumerate(
+            self.embedding_specs
+        ):
+            embedding_dim = rounded_row_size_in_bytes(
+                embedding_dim, weight_ty, self.row_alignment
+            )
+            indices = torch.range(0, num_embeddings - 1, dtype=torch.int64)
+            weights = random_quant_scaled_tensor(
+                shape=torch.Size([num_embeddings, embedding_dim]),
+                device=self.current_device,
+            )
+            self.embedding_inplace_update_per_table(
+                i,
+                indices,
+                weights,
+            )
+        self.weight_initialized = True
+
+    @torch.jit.export
+    def init_tbe_config(self, table_sharding_offset: list[int]) -> None:
+        """
+        Initialize the dynamic TBE table configs, e.g. sharded table offsets, etc.
+        Should be called before loading weights.
+        """
+        self.table_sharding_offset = table_sharding_offset
+
+    @torch.jit.export
+    def embedding_inplace_update(
+        self,
+        update_table_indices: list[int],
+        update_row_indices: list[list[int]],
+        update_weights: list[Tensor],
+    ) -> None:
+        # function is not used for now on the inference side
+        for i in range(len(update_table_indices)):
+            self.embedding_inplace_update_per_table(
+                update_table_indices[i],
+                torch.tensor(
+                    update_row_indices[i], device=self.current_device, dtype=torch.int64
+                ),
+                update_weights[i],
+                None,
+            )
+
+    @torch.jit.export
+    def embedding_inplace_update_per_table(
+        self,
+        table_id: int,
+        update_row_indices: Tensor,
+        update_weights: Tensor,
+        inplace_update_ts_sec: Optional[int] = None,
+    ) -> None:
+        assert table_id < len(
+            self.embedding_specs
+        ), f"table index {table_id} is out of range {len(self.embedding_specs)}"
+        # pyre-ignore [29]
+        table_offset = self.hash_size_cumsum[table_id]
+        sharding_offset = self.table_sharding_offset[table_id]
+
+        row_size = update_row_indices.numel()
+        if row_size == 0:
+            return
+
+        # convert global weight index to fused local weight index
+        row_indices = update_row_indices + table_offset - sharding_offset
+        # set weight by id
+        self.kv_embedding_cache.set_embeddings(
+            row_indices, update_weights, inplace_update_ts_sec
+        )
+
+    @torch.jit.export
+    def log_inplace_update_stats(
+        self,
+    ) -> None:
+        self.kv_embedding_cache.log_inplace_update_stats()
+
+    @torch.jit.export
+    def embedding_trigger_evict(
+        self,
+        inplace_update_ts_sec: int,
+    ) -> None:
+        self.kv_embedding_cache.trigger_evict(inplace_update_ts_sec)
+
+    @torch.jit.export
+    def embedding_wait_evict_completion(
+        self,
+    ) -> None:
+        self.kv_embedding_cache.wait_evict_completion()
+
+    @torch.jit.export
+    def initialize_kv_embedding_cache(self) -> None:
+        if not self.kv_embedding_cache_initialized:
+            self.initialize_logical_weights_placements_and_offsets()
+
+            self.row_alignment = 8  # to use the mempool implementation for kv embedding, alignment needs to be divisible by 8
+
+            hash_size_cumsum = self.construct_hash_size_cumsum()
+            self.hash_size_cumsum = torch.tensor(
+                hash_size_cumsum,
+                dtype=torch.int64,
+                device=self.current_device,
+            )
+
+            self.feature_hash_size_cumsum = torch.tensor(
+                [hash_size_cumsum[t] for t in self.feature_table_map]
+                + [hash_size_cumsum[-1]],
+                dtype=torch.int64,
+                device=self.current_device,
+            )
+
+            self.kv_embedding_cache.init(
+                self.specs,
+                self.row_alignment,
+                self.scale_bias_size_in_bytes,
+                self.hash_size_cumsum,
+            )
+            self.kv_embedding_cache_initialized = True
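
Editorial note, not part of the released file: the interesting piece in forward() above is the KV path, which linearizes per-feature indices into a single global key space before querying the DRAM KV cache by shifting each feature's indices by the cumulative row count (hash size) of the preceding tables. A minimal, self-contained sketch of that computation in plain PyTorch (no fbgemm extension needed; the table sizes and inputs are made up for illustration):

    import torch

    # Hypothetical setup: three features backed by tables of 10, 20, and 30 rows.
    # feature_hash_size_cumsum mirrors what initialize_kv_embedding_cache builds:
    # a running sum of row counts per feature, plus a terminal total.
    feature_hash_size_cumsum = torch.tensor([0, 10, 30, 60], dtype=torch.int64)

    # TBE-style inputs: T=3 features, batch size B=2, one index per bag.
    indices = torch.tensor([3, 7, 0, 19, 5, 29], dtype=torch.int64)
    offsets = torch.tensor([0, 1, 2, 3, 4, 5, 6], dtype=torch.int64)

    T = feature_hash_size_cumsum.numel() - 1
    B = (offsets.numel() - 1) // T
    linearized = torch.zeros_like(indices)
    for t in range(T):
        start, end = int(offsets[t * B]), int(offsets[(t + 1) * B])
        # Shift feature t's indices into its table's slice of the global key space.
        linearized[start:end] = indices[start:end] + feature_hash_size_cumsum[t]

    print(linearized.tolist())  # [3, 7, 10, 29, 35, 59]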
fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Optional, Union
+
+import torch
+
+
+def get_unique_indices_v2(
+    linear_indices: torch.Tensor,
+    max_indices: int,
+    compute_count: bool = False,
+    compute_inverse_indices: bool = False,
+) -> Union[
+    tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
+    tuple[
+        torch.Tensor,
+        torch.Tensor,
+        Optional[torch.Tensor],
+    ],
+    tuple[torch.Tensor, torch.Tensor],
+]:
+    """
+    A wrapper around get_unique_indices that overloads the return type
+    based on the inputs
+    """
+    ret = torch.ops.fbgemm.get_unique_indices_with_inverse(
+        linear_indices,
+        max_indices,
+        compute_count,
+        compute_inverse_indices,
+    )
+    if compute_count and compute_inverse_indices:
+        # Return all tensors
+        return ret
+    if compute_count:
+        # Return (unique_indices, length, count)
+        return ret[:-1]
+    if compute_inverse_indices:
+        # Return (unique_indices, length, inverse_indices)
+        return ret[0], ret[1], ret[3]
+    # Return (unique_indices, length)
+    return ret[:-2]
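
Editorial note, not part of the released file: since the underlying torch.ops.fbgemm.get_unique_indices_with_inverse op always returns four values (unique_indices, length, count, inverse_indices), the wrapper simply trims the tuple to match the requested flags. A usage sketch, assuming the fbgemm_gpu extension is installed, a CUDA device is available, and that max_indices follows the op's contract (the value here is illustrative):

    import torch
    import fbgemm_gpu  # noqa: F401  # importing the package loads the torch.ops.fbgemm operators

    from fbgemm_gpu.tbe.cache.split_embeddings_cache_ops import get_unique_indices_v2

    linear_indices = torch.tensor([5, 3, 5, 1, 3], dtype=torch.int64, device="cuda")

    # Both flags set: all four tensors come back.
    unique, length, count, inverse = get_unique_indices_v2(
        linear_indices, max_indices=8, compute_count=True, compute_inverse_indices=True
    )

    # No flags: just (unique_indices, length).
    unique, length = get_unique_indices_v2(linear_indices, max_indices=8)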
fbgemm_gpu/tbe/ssd/__init__.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+# Load the prelude
+from .common import ASSOC  # noqa: F401
+
+# Load the inference and training ops
+from .inference import SSDIntNBitTableBatchedEmbeddingBags  # noqa: F401
+from .training import SSDTableBatchedEmbeddingBags  # noqa: F401
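
Editorial note, not part of the released file: these re-exports define the public import surface of the subpackage. A quick sketch of what they enable (assumes an fbgemm_gpu build with SSD support; the native-module load in .common may silently no-op on other platforms):

    from fbgemm_gpu.tbe.ssd import (
        ASSOC,
        SSDIntNBitTableBatchedEmbeddingBags,
        SSDTableBatchedEmbeddingBags,
    )

    print(ASSOC)  # 32, the cache associativity constant defined in .common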
fbgemm_gpu/tbe/ssd/common.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+# pyre-ignore-all-errors[56]
+
+import torch
+
+from fbgemm_gpu.utils.loader import load_torch_module
+
+try:
+    load_torch_module(
+        "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings"
+    )
+except Exception:
+    pass
+
+ASSOC = 32
+
+
+def pad4(value: int) -> int:
+    """
+    Compute the smallest multiple of 4 that is greater than or equal to the given value.
+
+    Parameters:
+        value (int): The integer to align (must be non-negative).
+
+    Returns:
+        int: The aligned value.
+
+    Raises:
+        ValueError: If the input is negative.
+        TypeError: If the input is not an integer.
+    """
+    return (int(value) + 3) & ~3
+
+
+def tensor_pad4(value: torch.Tensor) -> torch.Tensor:
+    """
+    The equivalent of pad4 for tensors.
+    """
+    return (value + 3) & ~3
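
Editorial note, not part of the released file: the (value + 3) & ~3 idiom rounds up to the next multiple of 4. Adding 3 pushes any non-multiple past the next boundary, and clearing the two low bits snaps the result back down. A self-contained check of the trick, mirroring pad4 above:

    def pad4(value: int) -> int:
        # Add 3, then clear the low two bits to round up to a multiple of 4.
        return (int(value) + 3) & ~3

    assert [pad4(v) for v in (0, 1, 4, 5, 7, 8)] == [0, 4, 4, 8, 8, 8]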