fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -12,7 +12,7 @@
 import logging
 import uuid
 from itertools import accumulate
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import fbgemm_gpu  # noqa: F401
 import torch  # usort:skip
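Note: the recurring change throughout this file is the PEP 585 migration, replacing the typing aliases List/Tuple with the subscriptable builtins list/tuple (available since Python 3.9; this wheel targets cp311). A minimal sketch of the equivalence, with invented names, not code from the package:

from typing import List, Optional, Tuple  # pre-PEP-585 spellings

def old_style(pairs: List[Tuple[str, int]]) -> Optional[int]:
    return pairs[0][1] if pairs else None

def new_style(pairs: list[tuple[str, int]]) -> Optional[int]:
    # PEP 585: the builtins are generic, so only Optional/Union
    # still need the typing import.
    return pairs[0][1] if pairs else None

Only Optional and Union still come from typing, which is exactly why the import line above shrinks to "from typing import Optional, Union".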
@@ -92,14 +92,14 @@ def align_to_cacheline(a: int) -> int:
 
 
 def nbit_construct_split_state(
-    embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]],
+    embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]],
     cacheable: bool,
     row_alignment: int,
     scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
     cacheline_alignment: bool = True,
 ) -> SplitState:
-    placements = torch.jit.annotate(List[EmbeddingLocation], [])
-    offsets = torch.jit.annotate(List[int], [])
+    placements = torch.jit.annotate(list[EmbeddingLocation], [])
+    offsets = torch.jit.annotate(list[int], [])
     dev_size = 0
     host_size = 0
     uvm_size = 0
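torch.jit.annotate tells TorchScript the element type of an empty container, which the compiler cannot infer from the literal alone; the hunk above shows the PEP 585 spellings now being used there too. A hedged sketch, assuming a recent PyTorch whose scripting accepts builtin generics (the function itself is invented):

import torch

@torch.jit.script
def running_totals(xs: list[int]) -> list[int]:
    # Without the annotation, TorchScript types an empty [] as
    # list[Tensor], and the int append below would fail to compile.
    totals = torch.jit.annotate(list[int], [])
    acc = 0
    for x in xs:
        acc += x
        totals.append(acc)
    return totals

print(running_totals([1, 2, 3]))  # [1, 3, 6]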
@@ -165,7 +165,7 @@ def inputs_to_device(
     offsets: torch.Tensor,
     per_sample_weights: Optional[torch.Tensor],
     bounds_check_warning: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     if bounds_check_warning.device.type == "meta":
         return indices, offsets, per_sample_weights
 
@@ -331,7 +331,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         Options are `torch.int32` and `torch.int64`.
     """
 
-    embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]]
+    embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
     cache_miss_counter: torch.Tensor
@@ -346,15 +346,15 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
     def __init__(  # noqa C901
        self,
-        embedding_specs: List[
-            Tuple[str, int, int, SparseType, EmbeddingLocation]
+        embedding_specs: list[
+            tuple[str, int, int, SparseType, EmbeddingLocation]
         ],  # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
-        feature_table_map: Optional[List[int]] = None,  # [T]
-        index_remapping: Optional[List[Tensor]] = None,
+        feature_table_map: Optional[list[int]] = None,  # [T]
+        index_remapping: Optional[list[Tensor]] = None,
         pooling_mode: PoolingMode = PoolingMode.SUM,
         device: Optional[Union[str, int, torch.device]] = None,
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
-        weight_lists: Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None,
+        weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
         pruning_hash_load_factor: float = 0.5,
         use_array_for_index_remapping: bool = True,
         output_dtype: SparseType = SparseType.FP16,
@@ -373,7 +373,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
         cacheline_alignment: bool = True,
         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
         reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparam at begnning of each row.
-        feature_names_per_table: Optional[List[List[str]]] = None,
+        feature_names_per_table: Optional[list[list[str]]] = None,
         indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()
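For orientation, a sketch of how this constructor is typically called. The table names, row counts, and dims are invented; the import paths and the fill_random_weights helper match current fbgemm_gpu releases but should be treated as assumptions for this nightly:

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

# One (feature_name, rows, dim, weight dtype, placement) tuple per
# table, matching the embedding_specs layout in the signature above.
tbe = IntNBitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        ("user_id", 1_000_000, 128, SparseType.INT4, EmbeddingLocation.HOST),
        ("item_id", 500_000, 64, SparseType.INT8, EmbeddingLocation.HOST),
    ],
    output_dtype=SparseType.FP16,
    device="cpu",
)
tbe.fill_random_weights()  # allocate and randomize the packed weight buffers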
@@ -406,14 +406,14 @@
         self.indices_dtype = indices_dtype
         # (feature_names, rows, dims, weights_tys, locations) = zip(*embedding_specs)
         # Pyre workaround
-        self.feature_names: List[str] = [e[0] for e in embedding_specs]
+        self.feature_names: list[str] = [e[0] for e in embedding_specs]
         self.cache_load_factor: float = cache_load_factor
         self.cache_sets: int = cache_sets
         self.cache_reserved_memory: float = cache_reserved_memory
-        rows: List[int] = [e[1] for e in embedding_specs]
-        dims: List[int] = [e[2] for e in embedding_specs]
-        weights_tys: List[SparseType] = [e[3] for e in embedding_specs]
-        locations: List[EmbeddingLocation] = [e[4] for e in embedding_specs]
+        rows: list[int] = [e[1] for e in embedding_specs]
+        dims: list[int] = [e[2] for e in embedding_specs]
+        weights_tys: list[SparseType] = [e[3] for e in embedding_specs]
+        locations: list[EmbeddingLocation] = [e[4] for e in embedding_specs]
         # if target device is meta then we set use_cpu based on the embedding location
         # information in embedding_specs.
         if self.current_device.type == "meta":
@@ -453,7 +453,7 @@
         T_ = len(self.embedding_specs)
         assert T_ > 0
 
-        self.feature_table_map: List[int] = (
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
         T = len(self.feature_table_map)
@@ -676,7 +676,7 @@
         return self.table_wise_cache_miss
 
     @torch.jit.export
-    def get_feature_num_per_table(self) -> List[int]:
+    def get_feature_num_per_table(self) -> list[int]:
         if self.feature_names_per_table is None:
             return []
         return [len(feature_names) for feature_names in self.feature_names_per_table]
@@ -1211,8 +1211,8 @@
         dev_size: int,
         host_size: int,
         uvm_size: int,
-        placements: List[int],
-        offsets: List[int],
+        placements: list[int],
+        offsets: list[int],
         enforce_hbm: bool,
     ) -> None:
         assert not self.weight_initialized, "Weights have already been initialized."
@@ -1602,7 +1602,7 @@
     @torch.jit.export
     def split_embedding_weights_with_scale_bias(
         self, split_scale_bias_mode: int = 1
-    ) -> List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
+    ) -> list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
         """
         Returns a list of weights, split by table
         split_scale_bias_mode:
@@ -1611,7 +1611,7 @@
             2: return weights, scale, bias.
         """
         assert self.weight_initialized
-        splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
+        splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
         for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
             placement = self.weights_physical_placements[t]
             if (
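Given the modes in the docstring above, a brief usage sketch, continuing the hypothetical tbe from the earlier example; whether scale/bias come back as None for non-quantized tables is an assumption here:

# Mode 2 separates the packed payload from its quantization params:
# each entry is (weights, scale, bias); scale/bias are presumably
# None for tables whose dtype stores no qparams in the row.
for weights, scale, bias in tbe.split_embedding_weights_with_scale_bias(
    split_scale_bias_mode=2
):
    print(weights.shape, None if scale is None else scale.shape)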
@@ -1736,12 +1736,12 @@
         # the second with scale_bias.
         # This should've been named as split_scale_bias.
         # Keep as is for backward compatibility.
-    ) -> List[Tuple[Tensor, Optional[Tensor]]]:
+    ) -> list[tuple[Tensor, Optional[Tensor]]]:
         """
         Returns a list of weights, split by table
         """
         # fmt: off
-        splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
+        splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
             self.split_embedding_weights_with_scale_bias(
                 split_scale_bias_mode=(1 if split_scale_shifts else 0)
             )
@@ -1779,7 +1779,7 @@
         )
 
     def assign_embedding_weights(
-        self, q_weight_list: List[Tuple[Tensor, Optional[Tensor]]]
+        self, q_weight_list: list[tuple[Tensor, Optional[Tensor]]]
     ) -> None:
         """
         Assigns self.split_embedding_weights() with values from the input list of weights and scale_shifts.
@@ -1799,11 +1799,11 @@
     @torch.jit.export
     def set_index_remappings_array(
         self,
-        index_remapping: List[Tensor],
+        index_remapping: list[Tensor],
     ) -> None:
-        rows: List[int] = [e[1] for e in self.embedding_specs]
+        rows: list[int] = [e[1] for e in self.embedding_specs]
         index_remappings_array_offsets = [0]
-        original_feature_rows = torch.jit.annotate(List[int], [])
+        original_feature_rows = torch.jit.annotate(list[int], [])
         last_offset = 0
         for t, mapping in enumerate(index_remapping):
             if mapping is not None:
@@ -1842,11 +1842,11 @@
 
     def set_index_remappings(
         self,
-        index_remapping: List[Tensor],
+        index_remapping: list[Tensor],
         pruning_hash_load_factor: float = 0.5,
         use_array_for_index_remapping: bool = True,
     ) -> None:
-        rows: List[int] = [e[1] for e in self.embedding_specs]
+        rows: list[int] = [e[1] for e in self.embedding_specs]
         T = len(self.embedding_specs)
         # Hash mapping pruning
         if not use_array_for_index_remapping:
@@ -1916,7 +1916,7 @@
     def _embedding_inplace_update_per_table(
         self,
         update_table_idx: int,
-        update_row_indices: List[int],
+        update_row_indices: list[int],
         update_weights: Tensor,
     ) -> None:
         row_size = len(update_row_indices)
@@ -1941,9 +1941,9 @@
     @torch.jit.export
     def embedding_inplace_update(
         self,
-        update_table_indices: List[int],
-        update_row_indices: List[List[int]],
-        update_weights: List[Tensor],
+        update_table_indices: list[int],
+        update_row_indices: list[list[int]],
+        update_weights: list[Tensor],
     ) -> None:
         for i in range(len(update_table_indices)):
             self._embedding_inplace_update_per_table(
@@ -1954,8 +1954,8 @@
 
     def embedding_inplace_update_internal(
         self,
-        update_table_indices: List[int],
-        update_row_indices: List[int],
+        update_table_indices: list[int],
+        update_row_indices: list[int],
         update_weights: Tensor,
     ) -> None:
         assert len(update_table_indices) == len(update_row_indices)
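Finally, a hedged sketch of the inplace-update API whose signatures changed above, again continuing the hypothetical tbe. It round-trips existing rows so the packed format matches by construction; that split_scale_shifts=False returns full packed rows suitable for writing back is an assumption:

# Overwrite rows 3 and 7 of table 0. update_row_indices is a list of
# per-table row lists, parallel to update_table_indices; each weights
# tensor is assumed to hold one packed row per updated index.
weights, _ = tbe.split_embedding_weights(split_scale_shifts=False)[0]
tbe.embedding_inplace_update(
    update_table_indices=[0],
    update_row_indices=[[3, 7]],
    update_weights=[weights[[3, 7]].clone()],
)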