fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -18,14 +18,14 @@ import uuid
 from dataclasses import dataclass, field
 from itertools import accumulate
 from math import log2
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Optional, Union
 
 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip
+from torch.autograd.profiler import record_function  # usort:skip
 
 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
-
 from fbgemm_gpu.config import FeatureGate, FeatureGateName
 from fbgemm_gpu.runtime_monitor import (
     AsyncSeriesTimer,
@@ -48,6 +48,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     SplitState,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    check_allocated_vbe_output,
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
@@ -57,8 +58,8 @@ from fbgemm_gpu.tbe_input_multiplexer import (
     TBEInputMultiplexer,
     TBEInputMultiplexerConfig,
 )
-
 from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
+from fbgemm_gpu.utils.writeback_util import writeback_gradient
 
 try:
     load_torch_module(
@@ -158,6 +159,7 @@ class UserEnabledConfigDefinition:
     # More details can be found in D64848802.
     use_rowwise_bias_correction: bool = False
     use_writeback_bwd_prehook: bool = False
+    writeback_first_feature_only: bool = False
 
 
 @dataclass(frozen=True)
@@ -190,25 +192,48 @@ class UVMCacheStatsIndex(enum.IntEnum):
 class RESParams:
     res_server_port: int = 0  # the port of the res server
     res_store_shards: int = 1  # the number of shards to store the raw embeddings
-    table_names: List[str] = field(default_factory=list)  # table names the TBE holds
-    table_offsets: List[int] = field(
+    table_names: list[str] = field(default_factory=list)  # table names the TBE holds
+    table_offsets: list[int] = field(
         default_factory=list
     )  # table offsets for the global rows the TBE holds
-    table_sizes: List[int] = field(
+    table_sizes: list[int] = field(
         default_factory=list
     )  # table sizes for the global rows the TBE holds
 
 
+class PrefetchedInfo:
+    """
+    Container for prefetched cache information.
+
+    This class is explicitly defined (not using @dataclass) to be compatible with
+    TorchScript's inspect.getsource() requirements.
+    """
+
+    def __init__(
+        self,
+        linear_unique_indices: torch.Tensor,
+        linear_unique_cache_indices: torch.Tensor,
+        linear_unique_indices_length: torch.Tensor,
+        hash_zch_identities: Optional[torch.Tensor],
+        hash_zch_runtime_meta: Optional[torch.Tensor],
+    ) -> None:
+        self.linear_unique_indices = linear_unique_indices
+        self.linear_unique_cache_indices = linear_unique_cache_indices
+        self.linear_unique_indices_length = linear_unique_indices_length
+        self.hash_zch_identities = hash_zch_identities
+        self.hash_zch_runtime_meta = hash_zch_runtime_meta
+
+
 def construct_split_state(
-    embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]],
+    embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]],
     rowwise: bool,
     cacheable: bool,
     precision: SparseType = SparseType.FP32,
     int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
     placement: Optional[EmbeddingLocation] = None,
 ) -> SplitState:
-    placements: List[EmbeddingLocation] = []
-    offsets: List[int] = []
+    placements: list[EmbeddingLocation] = []
+    offsets: list[int] = []
     dev_size: int = 0
     host_size: int = 0
     uvm_size: int = 0
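The PrefetchedInfo entries added above are produced during prefetch and consumed later in this diff by raw_embedding_stream(), which pops them in FIFO order. A minimal illustrative sketch of that flow, with made-up tensors (the real values come from the prefetch path, not from this snippet):

import torch

from fbgemm_gpu.split_table_batched_embeddings_ops_training import PrefetchedInfo  # class added in this release

# Hypothetical stand-ins for the tensors produced during one prefetch call.
info = PrefetchedInfo(
    linear_unique_indices=torch.tensor([3, 7, 11], dtype=torch.int64),
    linear_unique_cache_indices=torch.tensor([3, 7, 11], dtype=torch.int64),
    linear_unique_indices_length=torch.tensor([3], dtype=torch.int64),
    hash_zch_identities=None,
    hash_zch_runtime_meta=None,
)

# The TBE keeps a FIFO of these entries: prefetch appends, streaming pops the oldest.
prefetched_info_list: list[PrefetchedInfo] = []
prefetched_info_list.append(info)
oldest = prefetched_info_list.pop(0)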
@@ -250,18 +275,18 @@ def construct_split_state(
 def apply_split_helper(
     persistent_state_fn: Callable[[str, Tensor], None],
     set_attr_fn: Callable[
-        [str, Union[Tensor, List[int], List[EmbeddingLocation]]], None
+        [str, Union[Tensor, list[int], list[EmbeddingLocation]]], None
     ],
     current_device: torch.device,
     use_cpu: bool,
-    feature_table_map: List[int],
+    feature_table_map: list[int],
     split: SplitState,
     prefix: str,
-    dtype: Type[torch.dtype],
+    dtype: type[torch.dtype],
     enforce_hbm: bool = False,
     make_dev_param: bool = False,
-    dev_reshape: Optional[Tuple[int, ...]] = None,
-    uvm_tensors_log: Optional[List[str]] = None,
+    dev_reshape: Optional[tuple[int, ...]] = None,
+    uvm_tensors_log: Optional[list[str]] = None,
     uvm_host_mapped: bool = False,
 ) -> None:
     set_attr_fn(f"{prefix}_physical_placements", split.placements)
@@ -346,6 +371,7 @@ def apply_split_helper(
             f"{prefix}_uvm",
             torch.zeros(
                 split.uvm_size,
+                device=current_device,
                 out=torch.ops.fbgemm.new_unified_tensor(
                     # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]`
                     # for 3rd param but got `Type[Type[torch._dtype]]`.
@@ -621,11 +647,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
            (preshard_table_height, preshard_table_dim, height_offset, dim_offset)
     """
 
-    embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
+    embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]]
     optimizer_args: invokers.lookup_args.OptimizerArgs
-    lxu_cache_locations_list: List[Tensor]
+    lxu_cache_locations_list: list[Tensor]
     lxu_cache_locations_empty: Tensor
-    timesteps_prefetched: List[int]
+    timesteps_prefetched: list[int]
+    prefetched_info_list: list[PrefetchedInfo]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
     uvm_cache_stats: torch.Tensor
@@ -639,10 +666,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
     def __init__(  # noqa C901
         self,
-        embedding_specs: List[
-            Tuple[int, int, EmbeddingLocation, ComputeDevice]
+        embedding_specs: list[
+            tuple[int, int, EmbeddingLocation, ComputeDevice]
         ],  # tuple of (rows, dims, placements, compute_devices)
-        feature_table_map: Optional[List[int]] = None,  # [T]
+        feature_table_map: Optional[list[int]] = None,  # [T]
         cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
         cache_load_factor: float = 0.2,
         cache_sets: int = 0,
@@ -680,8 +707,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         use_experimental_tbe: bool = False,
         prefetch_pipeline: bool = False,
         stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
-        table_names: Optional[List[str]] = None,
-        optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
+        table_names: Optional[list[str]] = None,
+        optimizer_state_dtypes: Optional[dict[str, SparseType]] = None,
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
         global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
         uvm_host_mapped: bool = False,
@@ -689,7 +716,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
         embedding_table_index_type: torch.dtype = torch.int64,
         embedding_table_offset_type: torch.dtype = torch.int64,
-        embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None,
+        embedding_shard_info: Optional[list[tuple[int, int, int, int]]] = None,
+        enable_raw_embedding_streaming: bool = False,
+        res_params: Optional[RESParams] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
         self.uuid = str(uuid.uuid4())
@@ -699,7 +728,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
         )
 
+        self.table_names: Optional[list[str]] = table_names
         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
+        self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
         self.pooling_mode = pooling_mode
         self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
 
@@ -793,9 +824,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         ), "Unique cache miss counters are not accurate in multipass prefetch and therefore not supported"
 
         self.embedding_specs = embedding_specs
-        (rows, dims, locations, compute_devices) = zip(*embedding_specs)
+        rows, dims, locations, compute_devices = zip(*embedding_specs)
         T_ = len(self.embedding_specs)
-        self.dims: List[int] = dims
+        self.dims: list[int] = dims
         assert T_ > 0
         # mixed D is not supported by no bag kernels
         mixed_D = False
@@ -808,7 +839,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             assert (
                 self.pooling_mode != PoolingMode.NONE
             ), "Mixed dimension tables only supported for pooling tables."
-
+        self.mixed_D: bool = mixed_D
         assert all(
             cd == compute_devices[0] for cd in compute_devices
         ), "Heterogenous compute_devices are NOT supported!"
@@ -872,7 +903,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self.stats_reporter: Optional[TBEStatsReporter] = (
             stats_reporter_config.create_reporter() if stats_reporter_config else None
         )
-        self._uvm_tensors_log: List[str] = []
+        self._uvm_tensors_log: list[str] = []
 
         self.bwd_wait_prefetch_timer: Optional[AsyncSeriesTimer] = None
         self.prefetch_duration_timer: Optional[AsyncSeriesTimer] = None
@@ -899,12 +930,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
         self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET
 
-        self.feature_table_map: List[int] = (
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
 
         if embedding_shard_info:
-            (full_table_heights, full_table_dims, row_offset, col_offset) = zip(
+            full_table_heights, full_table_dims, row_offset, col_offset = zip(
                 *embedding_shard_info
             )
         else:
@@ -939,7 +970,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         table_has_feature = [False] * T_
         for t in self.feature_table_map:
             table_has_feature[t] = True
-        assert all(table_has_feature), "Each table must have at least one feature!"
+        assert all(table_has_feature), (
+            "Each table must have at least one feature!"
+            + f"{[(i, x) for i, x in enumerate(table_has_feature)]}"
+        )
 
         feature_dims = [dims[t] for t in self.feature_table_map]
         D_offsets = [0] + list(accumulate(feature_dims))
@@ -991,7 +1025,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             "feature_dims",
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
-        (_info_B_num_bits, _info_B_mask) = torch.ops.fbgemm.get_infos_metadata(
+        _info_B_num_bits, _info_B_mask = torch.ops.fbgemm.get_infos_metadata(
             self.D_offsets,  # unused tensor
             1,  # max_B
             T,  # T
@@ -1105,13 +1139,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
         if ensemble_mode is None:
             ensemble_mode = EnsembleModeDefinition()
-        self._ensemble_mode: Dict[str, float] = {
+        self._ensemble_mode: dict[str, float] = {
             key: float(fval) for key, fval in ensemble_mode.__dict__.items()
         }
 
         if emainplace_mode is None:
             emainplace_mode = EmainplaceModeDefinition()
-        self._emainplace_mode: Dict[str, float] = {
+        self._emainplace_mode: dict[str, float] = {
             key: float(fval) for key, fval in emainplace_mode.__dict__.items()
         }
 
@@ -1151,6 +1185,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self.use_writeback_bwd_prehook: bool = (
             extra_optimizer_config.use_writeback_bwd_prehook
         )
+
+        writeback_first_feature_only: bool = (
+            extra_optimizer_config.writeback_first_feature_only
+        )
         self.log(f"self.extra_optimizer_config is {extra_optimizer_config}")
         if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM:
             raise AssertionError(
@@ -1416,7 +1454,11 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
 
         self.step = 0
         self.last_reported_step = 0
-        self.last_reported_uvm_stats: List[float] = []
+        self.last_reported_uvm_stats: list[float] = []
+        # Track number of times detailed memory breakdown has been reported
+        self.detailed_mem_breakdown_report_count = 0
+        # Set max number of reports for detailed memory breakdown
+        self.max_detailed_mem_breakdown_reports = 10
 
         # Check whether to use TBE v2
         is_experimental = False
@@ -1435,16 +1477,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             # self.log("TBE_V2 Knob is set to True; Using experimental TBE")
 
         self.is_experimental: bool = is_experimental
+        self._writeback_first_feature_only: bool = writeback_first_feature_only
 
         # Get a debug function pointer
         self._debug_print_input_stats: Callable[..., None] = (
             self._debug_print_input_stats_factory()
         )
 
+        # Get a reporter function pointer
+        self._report_input_params: Callable[..., None] = (
+            self.__report_input_params_factory()
+        )
+
         if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
             # Register writeback hook for Exact_SGD optimizer
             self.log(
-                "SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled."
+                f"SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled with first feature only={self._writeback_first_feature_only}"
             )
             # pyre-fixme[6]: Expected `typing.Callable[[Module, Union[Tensor, typing.Tuple[Tensor, ...]]], Union[None, Tensor, typing.Tuple[Tensor, ...]]]`
             self.register_full_backward_pre_hook(self.writeback_hook)
@@ -1460,6 +1508,30 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         )
         self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type
 
+        self.prefetched_info_list: list[PrefetchedInfo] = torch.jit.annotate(
+            list[PrefetchedInfo], []
+        )
+        if self.enable_raw_embedding_streaming:
+            self.res_params: RESParams = res_params or RESParams()
+            self.res_params.table_sizes = [0] + list(accumulate(rows))
+            res_port_from_env = os.getenv("LOCAL_RES_PORT")
+            self.res_params.res_server_port = (
+                int(res_port_from_env) if res_port_from_env else 0
+            )
+            # pyre-fixme[4]: Attribute must be annotated.
+            self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+                self.uuid,
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
+            )
+            logging.info(
+                f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
+            )
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
@@ -1503,7 +1575,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         )
 
     @staticmethod
-    def get_table_name_for_logging(table_names: Optional[List[str]]) -> str:
+    def get_table_name_for_logging(table_names: Optional[list[str]]) -> str:
         """
         Given a list of all table names in the TBE, generate a string to
         represent them in logging. If there is more than one table, this method
@@ -1519,17 +1591,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             return "<Unknown>"
         # Do this because sometimes multiple shards of the same table could appear
         # in one TBE.
-        table_name_set = set(table_names)
+        table_name_set = sorted(set(table_names))
         if len(table_name_set) == 1:
             return next(iter(table_name_set))
-        return f"<{len(table_name_set)} tables>"
+        return f"<{len(table_name_set)} tables>: {table_name_set}"
 
     @staticmethod
     def get_prefetch_passes(
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig],
         input_tensor: Tensor,
         output_tensor: Tensor,
-    ) -> List[Tuple[Tensor, Tensor, int]]:
+    ) -> list[tuple[Tensor, Tensor, int]]:
         """
         Given inputs (the indices to forward), partition the input and output
         into smaller chunks and return them as a list of tuples
@@ -1577,7 +1649,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 )
             )
 
-    def get_states(self, prefix: str) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+    def get_states(self, prefix: str) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
         """
         Get a state of a given tensor (`prefix`)
 
@@ -1616,7 +1688,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             torch.tensor(offsets, dtype=torch.int64),
         )
 
-    def get_all_states(self) -> List[Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
+    def get_all_states(self) -> list[tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
         """
         Get all states in the TBE (`weights`, `momentum1`, `momentum2`,
         `prev_iter`, and `row_counter`)
@@ -1680,10 +1752,161 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             tbe_id=self.uuid,
         )
 
-    @torch.jit.ignore
-    def _report_tbe_mem_usage(
+    def _get_tensor_memory(self, tensor_name: str) -> int:
+        """Get memory usage of a tensor in bytes."""
+        if not hasattr(self, tensor_name):
+            self.log(f"Tensor '{tensor_name}' not found, using 0 bytes")
+            return 0
+        tensor = getattr(self, tensor_name)
+        return tensor.numel() * tensor.element_size()
+
+    def _categorize_memory_by_location(
+        self, tensor_names: list[str]
+    ) -> tuple[int, int]:
+        """Categorize memory into HBM and UVM for given tensors.
+
+        Returns:
+            (hbm_bytes, uvm_bytes)
+        """
+        uvm_set = set(self._uvm_tensors_log)
+        hbm_bytes = 0
+        uvm_bytes = 0
+
+        for name in tensor_names:
+            size = self._get_tensor_memory(name)
+            if name in uvm_set:
+                uvm_bytes += size
+            else:
+                hbm_bytes += size
+
+        return hbm_bytes, uvm_bytes
+
+    def _report_hbm_breakdown(
+        self,
+        stats_reporter: TBEStatsReporter,
+        embeddings: int,
+        optimizer_states: int,
+        cache: int,
+        total_static_sparse: int,
+        ephemeral: int,
+        cache_weights: int = 0,
+        cache_aux: int = 0,
+    ) -> None:
+        """Report HBM memory breakdown to stats reporter."""
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.embeddings",
+            data_bytes=embeddings,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.optimizer_states",
+            data_bytes=optimizer_states,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache",
+            data_bytes=cache,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache_weights",
+            data_bytes=cache_weights,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache_aux",
+            data_bytes=cache_aux,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.total_static_sparse",
+            data_bytes=total_static_sparse,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.ephemeral",
+            data_bytes=ephemeral,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+
+    def _report_uvm_breakdown(
         self,
+        stats_reporter: TBEStatsReporter,
+        embeddings: int,
+        optimizer_states: int,
+        cache: int,
+        total_static_sparse: int,
+        ephemeral: int,
+        cache_weights: int = 0,
+        cache_aux: int = 0,
     ) -> None:
+        """Report UVM memory breakdown to stats reporter."""
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.embeddings",
+            data_bytes=embeddings,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.optimizer_states",
+            data_bytes=optimizer_states,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache",
+            data_bytes=cache,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache_weights",
+            data_bytes=cache_weights,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache_aux",
+            data_bytes=cache_aux,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.total_static_sparse",
+            data_bytes=total_static_sparse,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.ephemeral",
+            data_bytes=ephemeral,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+
+    @torch.jit.ignore
+    def _report_tbe_mem_usage(self) -> None:
         if self.stats_reporter is None:
             return
 
@@ -1691,22 +1914,24 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if not stats_reporter.should_report(self.step):
             return
 
+        # Calculate total memory from all parameters and buffers (always needed)
         total_mem_usage = sum(
-            param.numel() * param.element_size() for param in self.parameters()
-        ) + sum(buffer.numel() * buffer.element_size() for buffer in self.buffers())
+            p.numel() * p.element_size() for p in self.parameters()
+        ) + sum(b.numel() * b.element_size() for b in self.buffers())
+
+        # Calculate total HBM and UVM usage (always needed)
         if self.use_cpu:
             total_hbm_usage = 0
             total_uvm_usage = total_mem_usage
         else:
-            # hbm usage is total usage minus uvm usage
             total_uvm_usage = sum(
-                getattr(self, tensor_name).numel()
-                * getattr(self, tensor_name).element_size()
-                for tensor_name in self._uvm_tensors_log
-                if hasattr(self, tensor_name)
+                self._get_tensor_memory(name)
+                for name in self._uvm_tensors_log
+                if hasattr(self, name)
            )
            total_hbm_usage = total_mem_usage - total_uvm_usage
 
+        # Report total memory usage metrics (always reported for backward compatibility)
         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="tbe.total_hbm_usage",
@@ -1722,6 +1947,96 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             tbe_id=self.uuid,
         )
 
+        # Only report detailed breakdown for the first max_detailed_mem_breakdown_reports reportable
+        # steps since static sparse memory (weights, optimizer states, cache) is constant
+        if (
+            self.detailed_mem_breakdown_report_count
+            >= self.max_detailed_mem_breakdown_reports
+        ):
+            return
+        self.detailed_mem_breakdown_report_count += 1
+
+        # Tensor groups for sparse memory categorization
+        weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]
+        optimizer_tensors = [
+            "momentum1_dev",
+            "momentum1_host",
+            "momentum1_uvm",
+            "momentum2_dev",
+            "momentum2_host",
+            "momentum2_uvm",
+        ]
+        # Cache weights tensor (the actual cached embeddings in HBM)
+        cache_weight_tensors = [
+            "lxu_cache_weights",
+        ]
+        # Cache auxiliary state tensors (metadata for cache management, excluding weights)
+        # Sizes scale with hash_size or cache_slots (hash_size × clf)
+        # Excludes constant-size tensors: cache_hash_size_cumsum, cache_miss_counter, etc.
+        cache_aux_tensors = [
+            "cache_index_table_map",  # int32, 4B × hash_size
+            "lxu_cache_state",  # int64, 8B × cache_slots
+            "lxu_state",  # int64, 8B × cache_slots (LRU) or hash_size (LFU)
+            "lxu_cache_locking_counter",  # int32, 4B × cache_slots (only if prefetch_pipeline)
+        ]
+
+        # Calculate total memory for each component
+        weights_total = sum(self._get_tensor_memory(t) for t in weight_tensors)
+        optimizer_total = sum(self._get_tensor_memory(t) for t in optimizer_tensors)
+        cache_weights_total = sum(
+            self._get_tensor_memory(t) for t in cache_weight_tensors
+        )
+        cache_aux_total = sum(self._get_tensor_memory(t) for t in cache_aux_tensors)
+
+        # Categorize memory by location (HBM vs UVM)
+        if self.use_cpu:
+            weights_hbm, weights_uvm = 0, weights_total
+            opt_hbm, opt_uvm = 0, optimizer_total
+            cache_weights_hbm, cache_weights_uvm = 0, cache_weights_total
+            cache_aux_hbm, cache_aux_uvm = 0, cache_aux_total
+        else:
+            weights_hbm, weights_uvm = self._categorize_memory_by_location(
+                weight_tensors
+            )
+            opt_hbm, opt_uvm = self._categorize_memory_by_location(optimizer_tensors)
+            cache_weights_hbm, cache_weights_uvm = self._categorize_memory_by_location(
+                cache_weight_tensors
+            )
+            cache_aux_hbm, cache_aux_uvm = self._categorize_memory_by_location(
+                cache_aux_tensors
+            )
+
+        # Calculate ephemeral memory split between HBM and UVM
+        # Total cache = cache weights + cache auxiliary state
+        cache_hbm = cache_weights_hbm + cache_aux_hbm
+        cache_uvm = cache_weights_uvm + cache_aux_uvm
+        static_sparse_hbm = weights_hbm + opt_hbm + cache_hbm
+        static_sparse_uvm = weights_uvm + opt_uvm + cache_uvm
+        ephemeral_hbm = total_hbm_usage - static_sparse_hbm
+        ephemeral_uvm = total_uvm_usage - static_sparse_uvm
+
+        # Report granular memory breakdowns
+        self._report_hbm_breakdown(
+            stats_reporter,
+            weights_hbm,
+            opt_hbm,
+            cache_hbm,
+            static_sparse_hbm,
+            ephemeral_hbm,
+            cache_weights_hbm,
+            cache_aux_hbm,
+        )
+        self._report_uvm_breakdown(
+            stats_reporter,
+            weights_uvm,
+            opt_uvm,
+            cache_uvm,
+            static_sparse_uvm,
+            ephemeral_uvm,
+            cache_weights_uvm,
+            cache_aux_uvm,
+        )
+
     @torch.jit.ignore
     def _report_io_size_count(self, event: str, data: Tensor) -> Tensor:
         if self.stats_reporter is None:
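To make the new breakdown concrete, here is a small self-contained sketch of the same accounting (illustrative byte counts only; tensor names follow the groups above). Membership in _uvm_tensors_log decides UVM vs. HBM, and ephemeral memory is whatever remains after subtracting the static sparse state from the totals:

# Illustrative numbers only (bytes); in the TBE these come from tensor.numel() * element_size().
tensor_bytes = {
    "weights_dev": 4_000_000,
    "weights_uvm": 16_000_000,
    "momentum1_dev": 1_000_000,
    "lxu_cache_weights": 2_000_000,
}
uvm_tensors_log = {"weights_uvm"}

def categorize(names):
    # Tensors logged as UVM count toward UVM; everything else counts toward HBM.
    hbm = sum(tensor_bytes.get(n, 0) for n in names if n not in uvm_tensors_log)
    uvm = sum(tensor_bytes.get(n, 0) for n in names if n in uvm_tensors_log)
    return hbm, uvm

weights_hbm, weights_uvm = categorize(["weights_dev", "weights_host", "weights_uvm"])
opt_hbm, opt_uvm = categorize(["momentum1_dev", "momentum1_host", "momentum1_uvm"])
cache_hbm, cache_uvm = categorize(["lxu_cache_weights", "lxu_cache_state"])

static_hbm = weights_hbm + opt_hbm + cache_hbm
static_uvm = weights_uvm + opt_uvm + cache_uvm

# Totals would come from summing parameters() and buffers(); assumed here.
total_hbm, total_uvm = 8_000_000, 16_000_000
ephemeral_hbm = total_hbm - static_hbm
ephemeral_uvm = total_uvm - static_uvm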
@@ -1748,7 +2063,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     def _generate_vbe_metadata(
         self,
         offsets: Tensor,
-        batch_size_per_feature_per_rank: Optional[List[List[int]]],
+        batch_size_per_feature_per_rank: Optional[list[list[int]]],
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
@@ -1771,6 +2088,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.pooling_mode,
             self.feature_dims,
             self.current_device,
+            vbe_output,
+            vbe_output_offsets,
         )
 
     @torch.jit.ignore
@@ -1779,40 +2098,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         # This allows models using this class to compile correctly
         return FeatureGate.is_enabled(feature)
 
-    def writeback_update_gradient(
-        self, indices: torch.Tensor, offsets: torch.Tensor, grad: Tensor
-    ) -> Tensor:
-        if indices.numel() == 0:
-            return grad[0]
-        num_of_tables = len(set(self.feature_table_map))
-        assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
-        batch_size = offsets.shape[0] // num_of_tables
-        max_indices = indices.max()
-        non_empty_index = (offsets[1:] - offsets[:-1]).nonzero().flatten()
-        # disable dedup across different table
-        indices = ((offsets[non_empty_index]) // batch_size) * (
-            1 + max_indices
-        ) + indices
-        grad = grad[0]
-        _, idx, counts = torch.unique(
-            indices, dim=0, sorted=True, return_inverse=True, return_counts=True
-        )
-        _, ind_sorted = torch.sort(idx, stable=True)
-        cum_sum = counts.cumsum(0)
-        cum_sum = torch.cat((torch.tensor([0]).to(indices.device), cum_sum[:-1]))
-        first_indicies = ind_sorted[cum_sum]
-        mask = torch.zeros_like(grad, device=grad.device)
-        original_index = non_empty_index[first_indicies]
-
-        mask[original_index] = grad[original_index]
-        return mask
-
     # pyre-fixme[2]: For 1st argument expected not ANY
-    def writeback_hook(self, module: Any, grad: Tensor) -> Tuple[Tensor]:
+    def writeback_hook(self, module: Any, grad: Tensor) -> tuple[Tensor]:
         indices = self._indices
         offsets = self._offsets
-
-        return (self.writeback_update_gradient(indices, offsets, grad),)
+        return writeback_gradient(
+            grad,
+            indices,
+            offsets,
+            self.feature_table_map,
+            self._writeback_first_feature_only,
+        )
 
     def forward(  # noqa: C901
         self,
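The deleted writeback_update_gradient body above captures the logic that now lives in fbgemm_gpu.utils.writeback_util.writeback_gradient: when an index appears more than once in the batch, only its first occurrence keeps a gradient row and the duplicates are zeroed. A standalone sketch of that dedup step (illustrative only, not the helper's actual implementation):

import torch

def keep_first_occurrence_grad(indices: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # grad has one row per lookup; zero out rows whose index already appeared earlier.
    _, inverse, counts = torch.unique(
        indices, sorted=True, return_inverse=True, return_counts=True
    )
    _, order = torch.sort(inverse, stable=True)
    # Position of the first occurrence of each unique index in the original order.
    first_pos = order[torch.cat((torch.zeros(1, dtype=torch.long), counts.cumsum(0)[:-1]))]
    mask = torch.zeros_like(grad)
    mask[first_pos] = grad[first_pos]
    return mask

grads = torch.ones(4, 2)
print(keep_first_occurrence_grad(torch.tensor([5, 5, 9, 5]), grads))
# Rows 0 (first 5) and 2 (first 9) keep their gradients; rows 1 and 3 are zeroed.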
@@ -1820,8 +2116,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
-        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
         total_unique_indices: Optional[int] = None,
+        hash_zch_identities: Optional[Tensor] = None,
+        hash_zch_runtime_meta: Optional[Tensor] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> Tensor:
         """
         The forward pass function that
@@ -1874,7 +2174,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 be set when using `OptimType.NONE`. This is because TBE
                 requires this information for allocating the weight gradient
                 tensor in the backward pass.
-
+            hash_zch_identities (Optional[Tensor]): The original raw IDs before
+                remapping to ZCH (Zero-Collision Hash) table slots. This tensor is
+                populated when using Multi-Probe Zero Collision Hash (MPZCH) modules
+                and is required for Raw Embedding Streaming (RES) to maintain
+                consistency between training and inference.
+            vbe_output (Optional[Tensor]): An optional 2-D tensor of size that
+                contains output for TBE VBE. The shape of the tensor is
+                [1, total_vbe_output_size] where total_vbe_output_size is the
+                output size across all ranks and all embedding tables.
+                If this tensor is not None, the TBE VBE forward output is written
+                to this tensor at the locations specified by `vbe_output_offsets`.
+            vbe_output_offsets (Optional[Tensor]): An optional 2-D tensor that
+                contains VBE output offsets to `vbe_output`. The shape of the
+                tensor is [num_ranks, num_features].
+                vbe_output_offsets[r][f] represents the starting offset for rank `r`
+                and feature `f`.
         Returns:
             A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` =
             batch size and `total_D` = the sum of all embedding dimensions in the
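A minimal sketch of how a caller might use the pre-allocated VBE output path documented above. The shapes follow the docstring; `tbe` is assumed to be an already-constructed SplitTableBatchedEmbeddingBagsCodegen configured for VBE on GPU, the indices/offsets are toy values, and real offsets would come from the caller's sharding logic:

import torch

# One rank, two features, FP32 output; per-feature batch sizes 3 and 2, both dims 8.
batch_size_per_feature_per_rank = [[3], [2]]
total_vbe_output_size = 3 * 8 + 2 * 8
vbe_output = torch.empty(1, total_vbe_output_size, dtype=torch.float32, device="cuda")
# vbe_output_offsets[r][f]: where rank r / feature f starts inside vbe_output.
vbe_output_offsets = torch.tensor([[0, 3 * 8]], dtype=torch.int64, device="cuda")

indices = torch.tensor([1, 4, 2, 7, 9], dtype=torch.int64, device="cuda")
offsets = torch.tensor([0, 2, 3, 3, 4, 5], dtype=torch.int64, device="cuda")

out = tbe(
    indices,
    offsets,
    batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
    vbe_output=vbe_output,
    vbe_output_offsets=vbe_output_offsets,
)
# With a pre-allocated buffer, the forward result is written into vbe_output.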
@@ -1948,11 +2263,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             batch_size_per_feature_per_rank,
             force_cast_input_types=True,
             prefetch_pipeline=False,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )
 
+        # Only enable VBE if batch_size_per_feature_per_rank is not None
+        assert not (
+            batch_size_per_feature_per_rank is not None
+            and self.use_writeback_bwd_prehook
+        ), "VBE is not supported with writeback_bwd_prehook"
+
         # Print input stats if enable (for debugging purpose only)
         self._debug_print_input_stats(indices, offsets, per_sample_weights)
 
+        # Extract and Write input stats if enable
+        if self._report_input_params is not None:
+            self._report_input_params(
+                feature_rows=self.rows_per_table,
+                feature_dims=self.feature_dims,
+                iteration=self.iter_cpu.item() if hasattr(self, "iter_cpu") else 0,
+                indices=indices,
+                offsets=offsets,
+                op_id=self.uuid,
+                per_sample_weights=per_sample_weights,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                embedding_specs=[(s[0], s[1]) for s in self.embedding_specs],
+                feature_table_map=self.feature_table_map,
+            )
+
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
 
@@ -1980,7 +2318,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             # to be as fast as possible and memory usage doesn't matter (will be recycled
             # by dense fwd/bwd)
             self._prefetch(
-                indices, offsets, vbe_metadata, multipass_prefetch_config=None
+                indices,
+                offsets,
+                vbe_metadata,
+                multipass_prefetch_config=None,
+                hash_zch_identities=hash_zch_identities,
+                hash_zch_runtime_meta=hash_zch_runtime_meta,
             )
 
         if len(self.timesteps_prefetched) > 0:
@@ -2262,6 +2605,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     row_counter,
                     iter_int,
                     self.max_counter.item(),
+                    mixed_D=self.mixed_D,
                 ),
             )
         elif self._used_rowwise_adagrad_with_global_weight_decay:
@@ -2280,6 +2624,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     # `Optional[Tensor]` but got `Union[Module, Tensor]`.
                     prev_iter_dev=self.prev_iter_dev,
                     gwd_lower_bound=self.gwd_lower_bound,
+                    mixed_D=self.mixed_D,
                 ),
             )
         else:
@@ -2289,12 +2634,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                     common_args,
                     self.optimizer_args,
                     momentum1,
+                    mixed_D=self.mixed_D,
                 ),
             )
 
         raise ValueError(f"Invalid OptimType: {self.optimizer}")
 
-    def ema_inplace(self, emainplace_mode: Dict[str, float]) -> None:
+    def ema_inplace(self, emainplace_mode: dict[str, float]) -> None:
         """
         Perform ema operations on the full sparse embedding tables.
         We organize the sparse table, in the following way.
@@ -2324,7 +2670,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 emainplace_mode["step_ema_coef"],
             )
 
-    def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
+    def ensemble_and_swap(self, ensemble_mode: dict[str, float]) -> None:
         """
         Perform ensemble and swap operations on the full sparse embedding tables.
 
@@ -2372,7 +2718,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
         return self.local_uvm_cache_stats if use_local_cache else self.uvm_cache_stats
 
-    def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> List[float]:
+    def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> list[float]:
         snapshot = self.get_uvm_cache_stats(use_local_cache)
         if use_local_cache:
             return snapshot.tolist()
@@ -2385,7 +2731,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.ignore
     def print_uvm_cache_stats(self, use_local_cache: bool = False) -> None:
         # TODO: Create a separate reporter class to unify the stdlog reporting
-        uvm_cache_stats: List[float] = self._get_uvm_cache_print_state(use_local_cache)
+        uvm_cache_stats: list[float] = self._get_uvm_cache_print_state(use_local_cache)
         N = max(1, uvm_cache_stats[0])
         m = {
             "N_called": uvm_cache_stats[UVMCacheStatsIndex.num_calls],
@@ -2429,14 +2775,14 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if not stats_reporter.should_report(self.step):
             return
 
-        uvm_cache_stats: List[float] = self.get_uvm_cache_stats(
+        uvm_cache_stats: list[float] = self.get_uvm_cache_stats(
             use_local_cache=False
         ).tolist()
         self.last_reported_step = self.step
 
         if len(self.last_reported_uvm_stats) == 0:
             self.last_reported_uvm_stats = [0.0] * len(uvm_cache_stats)
-        uvm_cache_stats_delta: List[float] = [0.0] * len(uvm_cache_stats)
+        uvm_cache_stats_delta: list[float] = [0.0] * len(uvm_cache_stats)
         for i in range(len(uvm_cache_stats)):
             uvm_cache_stats_delta[i] = (
                 uvm_cache_stats[i] - self.last_reported_uvm_stats[i]
@@ -2465,7 +2811,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         indices: Tensor,
         offsets: Tensor,
         forward_stream: Optional[torch.cuda.Stream] = None,
-        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
     ) -> None:
         if self.prefetch_stream is None and forward_stream is not None:
             self.prefetch_stream = torch.cuda.current_stream()
@@ -2473,20 +2819,21 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.prefetch_stream != forward_stream
         ), "prefetch_stream and forward_stream should not be the same stream"
 
-        indices, offsets, _, vbe_metadata = self.prepare_inputs(
-            indices,
-            offsets,
-            per_sample_weights=None,
-            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
-            force_cast_input_types=False,
-            prefetch_pipeline=self.prefetch_pipeline,
-        )
-
         with self._recording_to_timer(
             self.prefetch_duration_timer,
             context=self.step,
             stream=torch.cuda.current_stream(),
         ):
+
+            indices, offsets, _, vbe_metadata = self.prepare_inputs(
+                indices,
+                offsets,
+                per_sample_weights=None,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                force_cast_input_types=False,
+                prefetch_pipeline=self.prefetch_pipeline,
+            )
+
             self._prefetch(
                 indices,
                 offsets,
@@ -2503,6 +2850,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         offsets: Tensor,
         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
+        hash_zch_identities: Optional[Tensor] = None,
+        hash_zch_runtime_meta: Optional[Tensor] = None,
     ) -> None:
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
@@ -2521,7 +2870,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.local_uvm_cache_stats.zero_()
         self._report_io_size_count("prefetch_input", indices)
 
+        # streaming before updating the cache
+        self.raw_embedding_stream()
+
         final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
+        linear_cache_indices_merged = torch.zeros(
+            0, dtype=indices.dtype, device=indices.device
+        )
         for (
             partial_indices,
             partial_lxu_cache_locations,
@@ -2537,6 +2892,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 vbe_metadata.max_B if vbe_metadata is not None else -1,
                 base_offset,
             )
+            linear_cache_indices_merged = torch.cat(
+                [linear_cache_indices_merged, linear_cache_indices]
+            )
 
             if (
                 self.record_cache_metrics.record_cache_miss_counter
@@ -2617,6 +2975,16 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if self.should_log():
             self.print_uvm_cache_stats(use_local_cache=False)
 
+        self._store_prefetched_tensors(
+            indices,
+            offsets,
+            vbe_metadata,
+            linear_cache_indices_merged,
+            final_lxu_cache_locations,
+            hash_zch_identities,
+            hash_zch_runtime_meta,
+        )
+
     def should_log(self) -> bool:
         """Determines if we should log for this step, using exponentially decreasing frequency.
 
@@ -2701,12 +3069,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 tmp_emb.uniform_(min_val, max_val)
                 tmp_emb_i8 = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(tmp_emb)
                 emb.data.copy_(tmp_emb_i8)
+        # Torch doesnt implement direct fp8 distribution functions, so we need to start in higher precision.
+        elif self.weights_precision == SparseType.NFP8:
+            assert (
+                self.current_device.type == "cuda"
+            ), "NFP8 is currently only supportd on GPU."
+            assert self.optimizer in [
+                OptimType.EXACT_ADAGRAD,
+                OptimType.ROWWISE_ADAGRAD,
+                OptimType.EXACT_ROWWISE_ADAGRAD,
+                OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+                OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
+            ], "NFP8 is currently only supportd with adagrad optimizers."
+            for param in splits:
+                tmp_param = torch.zeros(param.shape, device=self.current_device)
+                # Create initialized weights and cast to fp8.
+                fp8_dtype = (
+                    torch.float8_e4m3fnuz
+                    if torch.version.hip is not None
+                    else torch.float8_e4m3fn
+                )
+                tmp_param.uniform_(min_val, max_val).to(fp8_dtype)
+                param.data.copy_(tmp_param)
         else:
             for param in splits:
                 param.uniform_(min_val, max_val)
 
     @torch.jit.ignore
-    def split_embedding_weights(self) -> List[Tensor]:
+    def split_embedding_weights(self) -> list[Tensor]:
         """
         Returns a list of embedding weights (view), split by table
 
@@ -2748,7 +3138,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             raise ValueError(f"Optimizer buffer {state} not found")
 
     @torch.jit.export
-    def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
+    def get_optimizer_state(self) -> list[dict[str, torch.Tensor]]:
         r"""
         Get the optimizer state dict that matches the OSS Pytorch optims
         TODO: populate the supported list of optimizers
@@ -2832,7 +3222,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     @torch.jit.ignore
     def split_optimizer_states(
         self,
-    ) -> List[List[torch.Tensor]]:
+    ) -> list[list[torch.Tensor]]:
         """
         Returns a list of optimizer states (view), split by table
 
@@ -2880,7 +3270,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             state_offsets: Tensor,
             state_placements: Tensor,
             rowwise: bool,
-        ) -> List[torch.Tensor]:
+        ) -> list[torch.Tensor]:
             splits = []
             for t, (rows, dim, _, _) in enumerate(self.embedding_specs):
                 offset = state_offsets[t]
@@ -2899,7 +3289,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 splits.append(state.detach()[offset : offset + rows].view(rows))
             return splits
 
-        states: List[List[torch.Tensor]] = []
+        states: list[list[torch.Tensor]] = []
         if self.optimizer not in (OptimType.EXACT_SGD,):
             states.append(
                 get_optimizer_states(
@@ -3025,7 +3415,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         return self.learning_rate_tensor.item()
 
     @torch.jit.ignore
-    def update_hyper_parameters(self, params_dict: Dict[str, float]) -> None:
+    def update_hyper_parameters(self, params_dict: dict[str, float]) -> None:
         """
         Sets hyper-parameters from external control flow.
 
@@ -3101,10 +3491,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self,
         split: SplitState,
         prefix: str,
-        dtype: Type[torch.dtype],
+        dtype: type[torch.dtype],
         enforce_hbm: bool = False,
         make_dev_param: bool = False,
-        dev_reshape: Optional[Tuple[int, ...]] = None,
+        dev_reshape: Optional[tuple[int, ...]] = None,
         uvm_host_mapped: bool = False,
     ) -> None:
         apply_split_helper(
@@ -3154,6 +3544,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             dtype = torch.float32
         elif cache_precision == SparseType.FP16:
             dtype = torch.float16
+        elif cache_precision == SparseType.NFP8:
+            # NFP8 weights use floating point cache.
+            dtype = torch.float16
         else:
             dtype = torch.float32  # not relevant, but setting it to keep linter happy
         if not self.use_cpu > 0:
@@ -3347,7 +3740,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     def _update_cache_counter_and_locations(
         self,
         module: nn.Module,
-        grad_input: Union[Tuple[Tensor, ...], Tensor],
+        grad_input: Union[tuple[Tensor, ...], Tensor],
     ) -> None:
         """
         Backward prehook function when prefetch_pipeline is enabled.
@@ -3543,10 +3936,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         indices: Tensor,
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
-        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
         force_cast_input_types: bool = True,
         prefetch_pipeline: bool = False,
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
+    ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
         """
         Prepare TBE inputs as follows:
 
@@ -3572,9 +3967,20 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             metadata
         """
 
+        if vbe_output is not None or vbe_output_offsets is not None:
+            assert (
+                not self.use_cpu
+            ), "[TBE API v2] Using pre-allocated vbe_output is not supported on CPU"
+            check_allocated_vbe_output(
+                self.output_dtype,
+                batch_size_per_feature_per_rank,
+                vbe_output,
+                vbe_output_offsets,
+            )
+
         # Generate VBE metadata
         vbe_metadata = self._generate_vbe_metadata(
-            offsets, batch_size_per_feature_per_rank
+            offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
         )
 
         vbe = vbe_metadata.B_offsets is not None
@@ -3647,7 +4053,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 self.is_nobag,
                 vbe_metadata.max_B_feature_rank,
                 self.info_B_num_bits,
-                offsets.numel() - 1,  # total_B
+                offsets.numel() - 1,  # total_B,
+                vbe_output_offsets,
             )
         else:
             b_t_map = None
@@ -3736,7 +4143,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             # Counts of indices that segment lengths > 1024
             counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024]
 
-            def compute_numel_and_avg(counts: Tensor) -> Tuple[int, float]:
+            def compute_numel_and_avg(counts: Tensor) -> tuple[int, float]:
                 numel = counts.numel()
                 avg = (counts.sum().item() / numel) if numel != 0 else -1.0
                 return numel, avg
@@ -3804,6 +4211,240 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  return _debug_print_input_stats_factory_impl
  return _debug_print_input_stats_factory_null

+ @torch.jit.ignore
+ def raw_embedding_stream(self) -> None:
+ if not self.enable_raw_embedding_streaming:
+ return None
+ # When pipelining is enabled, the prefetch in iter i happens before the
+ # sparse backward of iter i - 1, so embeddings for iter i - 1's changed
+ # ids are not yet updated and we can only fetch the indices from iter i - 2.
+ # When pipelining is disabled, the prefetch in iter i happens before the
+ # forward of iter i, so iter i - 1's changed ids can be fetched safely.
+ target_prev_iter = 1
+ if self.prefetch_pipeline:
+ target_prev_iter = 2
+ if not len(self.prefetched_info_list) > (target_prev_iter - 1):
+ return None
+ with record_function(
+ "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+ ):
+ prefetched_info = self.prefetched_info_list.pop(0)
+ updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
+ prefetched_info.linear_unique_cache_indices,
+ self.lxu_cache_state,
+ self.total_cache_hash_size,
+ gather_cache_stats=False, # not collecting cache stats
+ num_uniq_cache_indices=prefetched_info.linear_unique_indices_length,
+ )
+ updated_weights = torch.empty(
+ [
+ prefetched_info.linear_unique_cache_indices.size()[0],
+ self.max_D_cache,
+ ],
+ # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
+ dtype=self.lxu_cache_weights.dtype,
+ # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
+ device=self.lxu_cache_weights.device,
+ )
+ torch.ops.fbgemm.masked_index_select(
+ updated_weights,
+ updated_locations,
+ self.lxu_cache_weights,
+ prefetched_info.linear_unique_indices_length,
+ )
+ # TODO: this statement triggers a sync; added here to make this diff
+ # self-contained, and will be removed in a later change.
+ cache_hit_mask_index = (
+ updated_locations.narrow(
+ 0, 0, prefetched_info.linear_unique_indices_length.item()
+ )
+ .not_equal(-1)
+ .nonzero()
+ .flatten()
+ )
+ # stream weights
+ self._raw_embedding_streamer.stream(
+ prefetched_info.linear_unique_indices.index_select(
+ dim=0, index=cache_hit_mask_index
+ ).to(device=torch.device("cpu")),
+ updated_weights.index_select(dim=0, index=cache_hit_mask_index).to(
+ device=torch.device("cpu")
+ ),
+ (
+ prefetched_info.hash_zch_identities.index_select(
+ dim=0, index=cache_hit_mask_index
+ ).to(device=torch.device("cpu"))
+ if prefetched_info.hash_zch_identities is not None
+ else None
+ ),
+ (
+ prefetched_info.hash_zch_runtime_meta.index_select(
+ dim=0, index=cache_hit_mask_index
+ ).to(device=torch.device("cpu"))
+ if prefetched_info.hash_zch_runtime_meta is not None
+ else None
+ ),
+ prefetched_info.linear_unique_indices_length.to(
+ device=torch.device("cpu")
+ ),
+ False, # require_tensor_copy
+ False, # blocking_tensor_copy
+ )
+
+ @staticmethod
+ @torch.jit.ignore
+ def _get_prefetched_info(
+ linear_indices: torch.Tensor,
+ linear_cache_indices_merged: torch.Tensor,
+ total_cache_hash_size: int,
+ hash_zch_identities: Optional[torch.Tensor],
+ hash_zch_runtime_meta: Optional[torch.Tensor],
+ max_indices_length: int,
+ ) -> PrefetchedInfo:
+ (
+ linear_unique_cache_indices,
+ linear_unique_cache_indices_length,
+ linear_unique_cache_indices_cnt,
+ linear_unique_cache_inverse_indices,
+ ) = torch.ops.fbgemm.get_unique_indices_with_inverse(
+ linear_cache_indices_merged,
+ total_cache_hash_size,
+ compute_count=True,
+ compute_inverse_indices=True,
+ )
+ # Pure CPU op, no sync needed; caps the length so the indices cannot
+ # exceed the size of the weights buffer.
+ max_len = min(
+ max_indices_length,
+ linear_unique_cache_indices.size(0),
+ )
+ if max_len < linear_unique_cache_indices.size(0):
+ linear_unique_cache_indices_length.clamp_(max=max_len)
+ # linear_unique_indices is the result after deduplication and sorting
+ linear_unique_cache_indices = linear_unique_cache_indices.narrow(
+ 0, 0, max_len
+ )
+ # Compute cumulative sum as indices for selecting unique elements to
+ # map hash_zch_identities and hash_zch_runtime_meta to linear_unique_indices
+ count_cum_sum = torch.ops.fbgemm.asynchronous_complete_cumsum(
+ linear_unique_cache_indices_cnt
+ )
+ # count_cum_sum has one more element than linear_unique_cache_indices_cnt
+ count_cum_sum = count_cum_sum.narrow(0, 0, max_len)
+ # Clamp the uninitialized elements to avoid out-of-bound access; they are
+ # sliced out later by linear_unique_cache_indices_length. Using
+ # linear_unique_cache_indices_length directly here would require a sync.
+ count_cum_sum.clamp_(min=0, max=linear_unique_cache_inverse_indices.size(0) - 1)
+
+ # Select indices corresponding to the first occurrence of each unique element
+ linear_unique_inverse_indices = (
+ linear_unique_cache_inverse_indices.index_select(dim=0, index=count_cum_sum)
+ )
+ # same clamp as above
+ linear_unique_inverse_indices.clamp_(min=0, max=linear_indices.size(0) - 1)
+ linear_unique_indices = linear_indices.index_select(
+ dim=0, index=linear_unique_inverse_indices
+ )
+ if hash_zch_identities is not None:
+ # Map hash_zch_identities to unique indices
+ hash_zch_identities = hash_zch_identities.index_select(
+ dim=0, index=linear_unique_inverse_indices
+ )
+
+ if hash_zch_runtime_meta is not None:
+ # Map hash_zch_runtime_meta to unique indices
+ hash_zch_runtime_meta = hash_zch_runtime_meta.index_select(
+ dim=0, index=linear_unique_inverse_indices
+ )
+
+ return PrefetchedInfo(
+ linear_unique_indices,
+ linear_unique_cache_indices,
+ linear_unique_cache_indices_length,
+ hash_zch_identities,
+ hash_zch_runtime_meta,
+ )
+
+ @torch.jit.ignore
+ def _store_prefetched_tensors(
+ self,
+ indices: torch.Tensor,
+ offsets: torch.Tensor,
+ vbe_metadata: Optional[invokers.lookup_args.VBEMetadata],
+ linear_cache_indices_merged: torch.Tensor,
+ final_lxu_cache_locations: torch.Tensor,
+ hash_zch_identities: Optional[torch.Tensor],
+ hash_zch_runtime_meta: Optional[torch.Tensor],
+ ) -> None:
+ """
+ NOTE: this needs to be a method with jit.ignore because the identities tensor is conditional.
+ This function stores the prefetched tensors for raw embedding streaming.
+ """
+ if not self.enable_raw_embedding_streaming:
+ return
+
+ with record_function(
+ "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+ ):
+ found_in_cache_mask = final_lxu_cache_locations != -1
+ # Only process the indices that are found in the cache; this filters out
+ # indices from tables that don't have UVM_CACHE enabled.
+ linear_cache_indices_merged_masked = torch.where(
+ found_in_cache_mask,
+ linear_cache_indices_merged,
+ self.total_cache_hash_size,
+ )
+ linearize_indices = torch.ops.fbgemm.linearize_cache_indices(
+ self.hash_size_cumsum,
+ indices,
+ offsets,
+ vbe_metadata.B_offsets if vbe_metadata is not None else None,
+ vbe_metadata.max_B if vbe_metadata is not None else -1,
+ )
+ # -1 indices are ignored in raw_embedding_streamer.
+ linearize_indices_masked = torch.where(
+ found_in_cache_mask,
+ linearize_indices,
+ -1,
+ )
+ # Process hash_zch_identities using the helper function
+ prefetched_info = self._get_prefetched_info(
+ linearize_indices_masked,
+ linear_cache_indices_merged_masked,
+ self.total_cache_hash_size,
+ hash_zch_identities,
+ hash_zch_runtime_meta,
+ self.lxu_cache_weights.size(0),
+ )
+
+ self.prefetched_info_list.append(prefetched_info)
+
+ @torch.jit.ignore
+ def __report_input_params_factory(
+ self,
+ ) -> Optional[Callable[..., None]]:
+ """
+ This function returns a callable based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.
+
+ If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a callable that:
+ - Reports input parameters (TBEDataConfig).
+ - Writes the output as a JSON file.
+
+ If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy callable that performs no action.
+ """
+ try:
+ if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
+ from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
+
+ reporter = TBEBenchmarkParamsReporter.create()
+ return reporter.report_stats
+ except Exception:
+ return None
+
+ return None
+

  class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
  """
@@ -3817,12 +4458,12 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
  max_D: int
  hash_size_cumsum: Tensor
  total_hash_size_bits: int
- embedding_specs: List[Tuple[int, int]]
+ embedding_specs: list[tuple[int, int]]

  def __init__(
  self,
- embedding_specs: List[Tuple[int, int]], # tuple of (rows, dims)
- feature_table_map: Optional[List[int]] = None, # [T]
+ embedding_specs: list[tuple[int, int]], # tuple of (rows, dims)
+ feature_table_map: Optional[list[int]] = None, # [T]
  weights_precision: SparseType = SparseType.FP32,
  pooling_mode: PoolingMode = PoolingMode.SUM,
  use_cpu: bool = False,
@@ -3865,7 +4506,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
  )

  self.embedding_specs = embedding_specs
- (rows, dims) = zip(*embedding_specs)
+ rows, dims = zip(*embedding_specs)
  T_ = len(self.embedding_specs)
  assert T_ > 0

@@ -3935,7 +4576,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
  row for (row, _) in embedding_specs[:t]
  )

- self.weights_physical_offsets: List[int] = weights_offsets
+ self.weights_physical_offsets: list[int] = weights_offsets
  weights_offsets = [weights_offsets[t] for t in feature_table_map]
  self.register_buffer(
  "weights_offsets",
@@ -3962,7 +4603,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
  def _generate_vbe_metadata(
  self,
  offsets: Tensor,
- batch_size_per_feature_per_rank: Optional[List[List[int]]],
+ batch_size_per_feature_per_rank: Optional[list[list[int]]],
  ) -> invokers.lookup_args.VBEMetadata:
  # Blocking D2H copy, but only runs at first call
  self.feature_dims = self.feature_dims.cpu()
@@ -3980,7 +4621,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
  offsets: Tensor,
  per_sample_weights: Optional[Tensor] = None,
  feature_requires_grad: Optional[Tensor] = None,
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
  ) -> Tensor:
  # Generate VBE metadata
  vbe_metadata = self._generate_vbe_metadata(
@@ -4019,7 +4660,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
  )

  @torch.jit.export
- def split_embedding_weights(self) -> List[Tensor]:
+ def split_embedding_weights(self) -> list[Tensor]:
  """
  Returns a list of weights, split by table
  """