fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +112 -19
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +118 -0
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +190 -54
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
- fbgemm_gpu/split_embedding_configs.py +134 -37
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +6 -1
- fbgemm_gpu/tbe/bench/bench_config.py +14 -3
- fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
- fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
- fbgemm_gpu/tbe/ssd/common.py +1 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +1292 -267
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +15 -15
- fbgemm_gpu/tbe_input_multiplexer.py +10 -11
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +6 -2
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +1 -0
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_table_batched_embeddings_ops_inference.py

@@ -12,7 +12,7 @@
 import logging
 import uuid
 from itertools import accumulate
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import fbgemm_gpu  # noqa: F401
 import torch  # usort:skip
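The bulk of this file's changes follow one mechanical migration: container annotations move from the typing aliases (List, Tuple) to the PEP 585 builtin generics (list, tuple), leaving only Optional and Union imported from typing. Builtin generics are valid as runtime annotations on Python 3.9+, and this wheel targets cp311. A minimal sketch of the pattern; the function names here are illustrative, not from the package:

# Before (typing aliases, as removed in this diff):
from typing import List, Optional, Tuple

def rows_per_table_old(specs: List[Tuple[str, int]]) -> Optional[List[int]]:
    return [rows for _, rows in specs]

# After (PEP 585 builtin generics; only Optional still comes from typing):
from typing import Optional

def rows_per_table_new(specs: list[tuple[str, int]]) -> Optional[list[int]]:
    return [rows for _, rows in specs]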
@@ -92,14 +92,14 @@ def align_to_cacheline(a: int) -> int:
 
 
 def nbit_construct_split_state(
-    embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]],
+    embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]],
     cacheable: bool,
     row_alignment: int,
     scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
     cacheline_alignment: bool = True,
 ) -> SplitState:
-    placements = torch.jit.annotate(List[EmbeddingLocation], [])
-    offsets = torch.jit.annotate(List[int], [])
+    placements = torch.jit.annotate(list[EmbeddingLocation], [])
+    offsets = torch.jit.annotate(list[int], [])
     dev_size = 0
     host_size = 0
     uvm_size = 0
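The same migration reaches into TorchScript-visible code: torch.jit.annotate, which pins the element type of an empty list so the scripting compiler can infer it, now receives builtin generics as well. A hedged sketch, assuming a PyTorch recent enough to accept PEP 585 generics under scripting (which this diff evidently relies on); the function is illustrative:

import torch

@torch.jit.script
def make_offsets(n: int) -> list[int]:
    # Without the annotate() call, TorchScript cannot type an empty literal list.
    offsets = torch.jit.annotate(list[int], [])
    for i in range(n):
        offsets.append(i * 8)
    return offsets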
@@ -165,7 +165,7 @@ def inputs_to_device(
     offsets: torch.Tensor,
     per_sample_weights: Optional[torch.Tensor],
     bounds_check_warning: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     if bounds_check_warning.device.type == "meta":
         return indices, offsets, per_sample_weights
 
@@ -331,7 +331,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         Options are `torch.int32` and `torch.int64`.
     """
 
-    embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]]
+    embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
     cache_miss_counter: torch.Tensor
@@ -346,15 +346,15 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
     def __init__(  # noqa C901
         self,
-        embedding_specs: List[
-            Tuple[str, int, int, SparseType, EmbeddingLocation]
+        embedding_specs: list[
+            tuple[str, int, int, SparseType, EmbeddingLocation]
         ],  # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
-        feature_table_map: Optional[List[int]] = None,  # [T]
-        index_remapping: Optional[List[Tensor]] = None,
+        feature_table_map: Optional[list[int]] = None,  # [T]
+        index_remapping: Optional[list[Tensor]] = None,
         pooling_mode: PoolingMode = PoolingMode.SUM,
         device: Optional[Union[str, int, torch.device]] = None,
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
-        weight_lists: Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None,
+        weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
         pruning_hash_load_factor: float = 0.5,
         use_array_for_index_remapping: bool = True,
         output_dtype: SparseType = SparseType.FP16,
@@ -373,7 +373,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         cacheline_alignment: bool = True,
         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
         reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparam at begnning of each row.
-        feature_names_per_table: Optional[List[List[str]]] = None,
+        feature_names_per_table: Optional[list[list[str]]] = None,
         indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()
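As the inline comment above notes, each entry of embedding_specs is a (feature_name, rows, dims, SparseType, EmbeddingLocation) tuple. A hypothetical construction sketch against this signature; the table names, sizes, and dtypes below are illustrative placeholders, not values from the package:

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

tbe = IntNBitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (feature_name, rows, dims, weight type, placement)
        ("table_0", 10_000, 128, SparseType.INT4, EmbeddingLocation.HOST),
        ("table_1", 50_000, 64, SparseType.INT8, EmbeddingLocation.HOST),
    ],
    device="cpu",
    output_dtype=SparseType.FP16,
)
tbe.fill_random_weights()  # allocate and randomize the quantized tables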
@@ -406,14 +406,14 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self.indices_dtype = indices_dtype
         # (feature_names, rows, dims, weights_tys, locations) = zip(*embedding_specs)
         # Pyre workaround
-        self.feature_names: List[str] = [e[0] for e in embedding_specs]
+        self.feature_names: list[str] = [e[0] for e in embedding_specs]
         self.cache_load_factor: float = cache_load_factor
         self.cache_sets: int = cache_sets
         self.cache_reserved_memory: float = cache_reserved_memory
-        rows: List[int] = [e[1] for e in embedding_specs]
-        dims: List[int] = [e[2] for e in embedding_specs]
-        weights_tys: List[SparseType] = [e[3] for e in embedding_specs]
-        locations: List[EmbeddingLocation] = [e[4] for e in embedding_specs]
+        rows: list[int] = [e[1] for e in embedding_specs]
+        dims: list[int] = [e[2] for e in embedding_specs]
+        weights_tys: list[SparseType] = [e[3] for e in embedding_specs]
+        locations: list[EmbeddingLocation] = [e[4] for e in embedding_specs]
         # if target device is meta then we set use_cpu based on the embedding location
         # information in embedding_specs.
         if self.current_device.type == "meta":
@@ -453,7 +453,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         T_ = len(self.embedding_specs)
         assert T_ > 0
 
-        self.feature_table_map: List[int] = (
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
         T = len(self.feature_table_map)
@@ -676,7 +676,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         return self.table_wise_cache_miss
 
     @torch.jit.export
-    def get_feature_num_per_table(self) -> List[int]:
+    def get_feature_num_per_table(self) -> list[int]:
         if self.feature_names_per_table is None:
             return []
         return [len(feature_names) for feature_names in self.feature_names_per_table]
@@ -1211,8 +1211,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         dev_size: int,
         host_size: int,
         uvm_size: int,
-        placements: List[int],
-        offsets: List[int],
+        placements: list[int],
+        offsets: list[int],
         enforce_hbm: bool,
     ) -> None:
         assert not self.weight_initialized, "Weights have already been initialized."
@@ -1602,7 +1602,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.export
     def split_embedding_weights_with_scale_bias(
         self, split_scale_bias_mode: int = 1
-    ) -> List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
+    ) -> list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
         """
         Returns a list of weights, split by table
         split_scale_bias_mode:
@@ -1611,7 +1611,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
             2: return weights, scale, bias.
         """
         assert self.weight_initialized
-        splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
+        splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
         for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
             placement = self.weights_physical_placements[t]
             if (
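Per the docstring, split_scale_bias_mode selects how much of each quantized row is unpacked. A hedged read-back sketch, reusing the hypothetical tbe from the earlier construction sketch:

# split_scale_bias_mode=2 returns a (weights, scale, bias) triple per table;
# scale and bias may be None depending on the table's weight type.
for weights, scale, bias in tbe.split_embedding_weights_with_scale_bias(
    split_scale_bias_mode=2
):
    print(weights.shape, None if scale is None else scale.shape)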
@@ -1736,12 +1736,12 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         # the second with scale_bias.
         # This should've been named as split_scale_bias.
         # Keep as is for backward compatibility.
-    ) -> List[Tuple[Tensor, Optional[Tensor]]]:
+    ) -> list[tuple[Tensor, Optional[Tensor]]]:
         """
         Returns a list of weights, split by table
         """
         # fmt: off
-        splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
+        splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
             self.split_embedding_weights_with_scale_bias(
                 split_scale_bias_mode=(1 if split_scale_shifts else 0)
             )
@@ -1779,7 +1779,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         )
 
     def assign_embedding_weights(
-        self, q_weight_list: List[Tuple[Tensor, Optional[Tensor]]]
+        self, q_weight_list: list[tuple[Tensor, Optional[Tensor]]]
     ) -> None:
         """
         Assigns self.split_embedding_weights() with values from the input list of weights and scale_shifts.
@@ -1799,11 +1799,11 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.export
     def set_index_remappings_array(
         self,
-        index_remapping: List[Tensor],
+        index_remapping: list[Tensor],
     ) -> None:
-        rows: List[int] = [e[1] for e in self.embedding_specs]
+        rows: list[int] = [e[1] for e in self.embedding_specs]
         index_remappings_array_offsets = [0]
-        original_feature_rows = torch.jit.annotate(List[int], [])
+        original_feature_rows = torch.jit.annotate(list[int], [])
         last_offset = 0
         for t, mapping in enumerate(index_remapping):
             if mapping is not None:
@@ -1842,11 +1842,11 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
     def set_index_remappings(
         self,
-        index_remapping: List[Tensor],
+        index_remapping: list[Tensor],
         pruning_hash_load_factor: float = 0.5,
         use_array_for_index_remapping: bool = True,
     ) -> None:
-        rows: List[int] = [e[1] for e in self.embedding_specs]
+        rows: list[int] = [e[1] for e in self.embedding_specs]
         T = len(self.embedding_specs)
         # Hash mapping pruning
         if not use_array_for_index_remapping:
@@ -1916,7 +1916,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
     def _embedding_inplace_update_per_table(
         self,
         update_table_idx: int,
-        update_row_indices: List[int],
+        update_row_indices: list[int],
         update_weights: Tensor,
     ) -> None:
         row_size = len(update_row_indices)
@@ -1941,9 +1941,9 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.export
     def embedding_inplace_update(
         self,
-        update_table_indices: List[int],
-        update_row_indices: List[List[int]],
-        update_weights: List[Tensor],
+        update_table_indices: list[int],
+        update_row_indices: list[list[int]],
+        update_weights: list[Tensor],
     ) -> None:
         for i in range(len(update_table_indices)):
             self._embedding_inplace_update_per_table(
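The new annotations make the call shape explicit: three parallel lists with one entry per table being updated, each carrying that table's row indices and a packed-weight tensor. A hedged usage sketch, reusing the hypothetical tbe from the earlier sketch; the row indices and the uint8 payload width below are illustrative, since the packed byte width per row depends on the table's dim, weight type, and scale/bias layout:

rows_to_update = [3, 7, 42]
tbe.embedding_inplace_update(
    update_table_indices=[0],             # only table 0 is touched
    update_row_indices=[rows_to_update],  # one list of rows per table
    update_weights=[
        # one packed uint8 row per updated index (width is illustrative)
        torch.empty((len(rows_to_update), 68), dtype=torch.uint8)
    ],
)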
@@ -1954,8 +1954,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
     def embedding_inplace_update_internal(
         self,
-        update_table_indices: List[int],
-        update_row_indices: List[int],
+        update_table_indices: list[int],
+        update_row_indices: list[int],
         update_weights: Tensor,
     ) -> None:
         assert len(update_table_indices) == len(update_row_indices)