fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -18,14 +18,14 @@ import uuid
  from dataclasses import dataclass, field
  from itertools import accumulate
  from math import log2
- from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+ from typing import Any, Callable, Optional, Union

  import torch # usort:skip
  from torch import nn, Tensor # usort:skip
+ from torch.autograd.profiler import record_function # usort:skip

  # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
  import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
-
  from fbgemm_gpu.config import FeatureGate, FeatureGateName
  from fbgemm_gpu.runtime_monitor import (
  AsyncSeriesTimer,
@@ -37,8 +37,10 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
  BoundsCheckMode,
  CacheAlgorithm,
  CacheState,
+ ComputeDevice,
  construct_cache_state,
  EmbeddingLocation,
+ get_bounds_check_version_for_platform,
  MAX_PREFETCH_DEPTH,
  MultiPassPrefetchConfig,
  PoolingMode,
@@ -46,6 +48,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
  SplitState,
  )
  from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+ check_allocated_vbe_output,
  generate_vbe_metadata,
  is_torchdynamo_compiling,
  )
@@ -55,8 +58,8 @@ from fbgemm_gpu.tbe_input_multiplexer import (
  TBEInputMultiplexer,
  TBEInputMultiplexerConfig,
  )
-
  from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
+ from fbgemm_gpu.utils.writeback_util import writeback_gradient

  try:
  load_torch_module(
@@ -80,12 +83,6 @@ class DoesNotHavePrefix(Exception):
  pass


- class ComputeDevice(enum.IntEnum):
- CPU = 0
- CUDA = 1
- MTIA = 2
-
-
  class WeightDecayMode(enum.IntEnum):
  NONE = 0
  L2 = 1
@@ -99,6 +96,7 @@ class CounterWeightDecayMode(enum.IntEnum):
  NONE = 0
  L2 = 1
  DECOUPLE = 2
+ ADAGRADW = 3


  class StepMode(enum.IntEnum):
@@ -160,6 +158,8 @@ class UserEnabledConfigDefinition:
  # This is used in Adam to perform rowwise bias correction using `row_counter`
  # More details can be found in D64848802.
  use_rowwise_bias_correction: bool = False
+ use_writeback_bwd_prehook: bool = False
+ writeback_first_feature_only: bool = False


  @dataclass(frozen=True)
@@ -188,16 +188,52 @@ class UVMCacheStatsIndex(enum.IntEnum):
  num_conflict_misses = 5


+ @dataclass
+ class RESParams:
+ res_server_port: int = 0 # the port of the res server
+ res_store_shards: int = 1 # the number of shards to store the raw embeddings
+ table_names: list[str] = field(default_factory=list) # table names the TBE holds
+ table_offsets: list[int] = field(
+ default_factory=list
+ ) # table offsets for the global rows the TBE holds
+ table_sizes: list[int] = field(
+ default_factory=list
+ ) # table sizes for the global rows the TBE holds
+
+
+ class PrefetchedInfo:
+ """
+ Container for prefetched cache information.
+
+ This class is explicitly defined (not using @dataclass) to be compatible with
+ TorchScript's inspect.getsource() requirements.
+ """
+
+ def __init__(
+ self,
+ linear_unique_indices: torch.Tensor,
+ linear_unique_cache_indices: torch.Tensor,
+ linear_unique_indices_length: torch.Tensor,
+ hash_zch_identities: Optional[torch.Tensor],
+ hash_zch_runtime_meta: Optional[torch.Tensor],
+ ) -> None:
+ self.linear_unique_indices = linear_unique_indices
+ self.linear_unique_cache_indices = linear_unique_cache_indices
+ self.linear_unique_indices_length = linear_unique_indices_length
+ self.hash_zch_identities = hash_zch_identities
+ self.hash_zch_runtime_meta = hash_zch_runtime_meta
+
+
  def construct_split_state(
- embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]],
+ embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]],
  rowwise: bool,
  cacheable: bool,
  precision: SparseType = SparseType.FP32,
  int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
  placement: Optional[EmbeddingLocation] = None,
  ) -> SplitState:
- placements: List[EmbeddingLocation] = []
- offsets: List[int] = []
+ placements: list[EmbeddingLocation] = []
+ offsets: list[int] = []
  dev_size: int = 0
  host_size: int = 0
  uvm_size: int = 0
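
For illustration only, a minimal sketch of how the new RESParams container introduced in this hunk could be filled in. The import path is inferred from this module's name and all field values are hypothetical:

    # Hypothetical example; not part of the diff.
    from fbgemm_gpu.split_table_batched_embeddings_ops_training import RESParams

    res_params = RESParams(
        res_store_shards=4,                  # number of shards for the raw embedding stream
        table_names=["user_id", "item_id"],  # tables held by this TBE
        table_offsets=[0, 1_000_000],        # global row offsets per table
        table_sizes=[1_000_000, 500_000],    # later re-derived by __init__ from the table rows
    )
    # res_server_port is normally left at 0; __init__ overrides it from the
    # LOCAL_RES_PORT environment variable when enable_raw_embedding_streaming=True.
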
@@ -239,18 +275,18 @@ def construct_split_state(
  def apply_split_helper(
  persistent_state_fn: Callable[[str, Tensor], None],
  set_attr_fn: Callable[
- [str, Union[Tensor, List[int], List[EmbeddingLocation]]], None
+ [str, Union[Tensor, list[int], list[EmbeddingLocation]]], None
  ],
  current_device: torch.device,
  use_cpu: bool,
- feature_table_map: List[int],
+ feature_table_map: list[int],
  split: SplitState,
  prefix: str,
- dtype: Type[torch.dtype],
+ dtype: type[torch.dtype],
  enforce_hbm: bool = False,
  make_dev_param: bool = False,
- dev_reshape: Optional[Tuple[int, ...]] = None,
- uvm_tensors_log: Optional[List[str]] = None,
+ dev_reshape: Optional[tuple[int, ...]] = None,
+ uvm_tensors_log: Optional[list[str]] = None,
  uvm_host_mapped: bool = False,
  ) -> None:
  set_attr_fn(f"{prefix}_physical_placements", split.placements)
@@ -335,6 +371,7 @@ def apply_split_helper(
  f"{prefix}_uvm",
  torch.zeros(
  split.uvm_size,
+ device=current_device,
  out=torch.ops.fbgemm.new_unified_tensor(
  # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]`
  # for 3rd param but got `Type[Type[torch._dtype]]`.
@@ -603,13 +640,19 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  embedding_table_offset_type (torch.dtype = torch.int64): The data type of
  the embedding table offset tensor. Options are `torch.int32` and
  `torch.int64`
+
+ embedding_shard_info (Optional[List[Tuple[int, int, int, int]]] = None): the
+ information about shard position and pre-sharded table size. If not set,
+ the table is not sharded.
+ (preshard_table_height, preshard_table_dim, height_offset, dim_offset)
  """

- embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
+ embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]]
  optimizer_args: invokers.lookup_args.OptimizerArgs
- lxu_cache_locations_list: List[Tensor]
+ lxu_cache_locations_list: list[Tensor]
  lxu_cache_locations_empty: Tensor
- timesteps_prefetched: List[int]
+ timesteps_prefetched: list[int]
+ prefetched_info_list: list[PrefetchedInfo]
  record_cache_metrics: RecordCacheMetrics
  # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
  uvm_cache_stats: torch.Tensor
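
As a hedged illustration of the `embedding_shard_info` tuple format documented in the hunk above: assuming a single 1M-row, 128-dim table sharded by rows into two halves, the shard owning the second half would describe itself as follows (values hypothetical):

    # Format per the docstring: (preshard_table_height, preshard_table_dim,
    # height_offset, dim_offset), one tuple per local table.
    embedding_shard_info = [
        (1_000_000, 128, 500_000, 0),
    ]
    # Passed to SplitTableBatchedEmbeddingBagsCodegen alongside embedding_specs,
    # which still lists the local (post-shard) rows and dims.
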
@@ -623,10 +666,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

  def __init__( # noqa C901
  self,
- embedding_specs: List[
- Tuple[int, int, EmbeddingLocation, ComputeDevice]
+ embedding_specs: list[
+ tuple[int, int, EmbeddingLocation, ComputeDevice]
  ], # tuple of (rows, dims, placements, compute_devices)
- feature_table_map: Optional[List[int]] = None, # [T]
+ feature_table_map: Optional[list[int]] = None, # [T]
  cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
  cache_load_factor: float = 0.2,
  cache_sets: int = 0,
@@ -664,8 +707,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  use_experimental_tbe: bool = False,
  prefetch_pipeline: bool = False,
  stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
- table_names: Optional[List[str]] = None,
- optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
+ table_names: Optional[list[str]] = None,
+ optimizer_state_dtypes: Optional[dict[str, SparseType]] = None,
  multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
  global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
  uvm_host_mapped: bool = False,
@@ -673,23 +716,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
  embedding_table_index_type: torch.dtype = torch.int64,
  embedding_table_offset_type: torch.dtype = torch.int64,
+ embedding_shard_info: Optional[list[tuple[int, int, int, int]]] = None,
+ enable_raw_embedding_streaming: bool = False,
+ res_params: Optional[RESParams] = None,
  ) -> None:
  super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
  self.uuid = str(uuid.uuid4())
-
+ self.log("SplitTableBatchedEmbeddingBagsCodegen API: V2")
  self.log(f"SplitTableBatchedEmbeddingBagsCodegen Arguments: {locals()}")
  self.log(
  f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
  )

+ self.table_names: Optional[list[str]] = table_names
  self.logging_table_name: str = self.get_table_name_for_logging(table_names)
+ self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
  self.pooling_mode = pooling_mode
  self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
+
  # If environment variable is set, it overwrites the default bounds check mode.
- self.bounds_check_version: int = 1
+ self.bounds_check_version: int = (
+ 2
+ if self._feature_is_enabled(FeatureGateName.BOUNDS_CHECK_INDICES_V2)
+ else get_bounds_check_version_for_platform()
+ )
  self.bounds_check_mode_int: int = int(
  os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
  )
+ # Check if bounds_check_indices_v2 is enabled via the feature gate
  bounds_check_mode = BoundsCheckMode(self.bounds_check_mode_int)
  if bounds_check_mode.name.startswith("V2_"):
  self.bounds_check_version = 2
@@ -699,10 +753,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  bounds_check_mode = BoundsCheckMode.WARNING
  elif bounds_check_mode == BoundsCheckMode.V2_FATAL:
  bounds_check_mode = BoundsCheckMode.FATAL
- else:
- raise NotImplementedError(
- f"Did not recognize V2 bounds check mode: {bounds_check_mode}"
- )
+
+ if bounds_check_mode not in (
+ BoundsCheckMode.IGNORE,
+ BoundsCheckMode.WARNING,
+ BoundsCheckMode.FATAL,
+ BoundsCheckMode.NONE,
+ ):
+ raise NotImplementedError(
+ f"SplitTableBatchedEmbeddingBagsCodegen bounds_check_mode={bounds_check_mode} is not supported"
+ )
+
+ self.bounds_check_mode_int = bounds_check_mode.value
+
+ self.log(
+ f"SplitTableBatchedEmbeddingBagsCodegen bounds_check_mode={bounds_check_mode} bounds_check_version={self.bounds_check_version}"
+ )

  self.weights_precision = weights_precision

@@ -715,6 +781,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  # See:
  # https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/
  cache_precision = SparseType.FP32
+ self.log("Override cache_precision=SparseType.FP32 on ROCm")
  else:
  # NOTE: The changes from D65865527 are retained here until we can
  # test that the the hack also works for non-ROCm environments.
@@ -757,9 +824,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  ), "Unique cache miss counters are not accurate in multipass prefetch and therefore not supported"

  self.embedding_specs = embedding_specs
- (rows, dims, locations, compute_devices) = zip(*embedding_specs)
+ rows, dims, locations, compute_devices = zip(*embedding_specs)
  T_ = len(self.embedding_specs)
- self.dims: List[int] = dims
+ self.dims: list[int] = dims
  assert T_ > 0
  # mixed D is not supported by no bag kernels
  mixed_D = False
@@ -772,7 +839,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  assert (
  self.pooling_mode != PoolingMode.NONE
  ), "Mixed dimension tables only supported for pooling tables."
-
+ self.mixed_D: bool = mixed_D
  assert all(
  cd == compute_devices[0] for cd in compute_devices
  ), "Heterogenous compute_devices are NOT supported!"
@@ -836,7 +903,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  self.stats_reporter: Optional[TBEStatsReporter] = (
  stats_reporter_config.create_reporter() if stats_reporter_config else None
  )
- self._uvm_tensors_log: List[str] = []
+ self._uvm_tensors_log: list[str] = []

  self.bwd_wait_prefetch_timer: Optional[AsyncSeriesTimer] = None
  self.prefetch_duration_timer: Optional[AsyncSeriesTimer] = None
@@ -863,10 +930,20 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

  self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET

- self.feature_table_map: List[int] = (
+ self.feature_table_map: list[int] = (
  feature_table_map if feature_table_map is not None else list(range(T_))
  )

+ if embedding_shard_info:
+ full_table_heights, full_table_dims, row_offset, col_offset = zip(
+ *embedding_shard_info
+ )
+ else:
+ # Just assume the table is unsharded
+ full_table_heights = rows
+ full_table_dims = dims
+ row_offset = [0] * len(rows)
+ col_offset = [0] * len(rows)
  self.tbe_input_multiplexer: Optional[TBEInputMultiplexer] = (
  tbe_input_multiplexer_config.create_tbe_input_multiplexer(
  tbe_info=TBEInfo(
@@ -878,6 +955,11 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  table_heights=rows,
  tbe_uuid=self.uuid,
  feature_table_map=self.feature_table_map,
+ table_dims=dims,
+ full_table_heights=full_table_heights,
+ full_table_dims=full_table_dims,
+ row_offset=row_offset,
+ col_offset=col_offset,
  )
  )
  if tbe_input_multiplexer_config is not None
@@ -888,7 +970,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  table_has_feature = [False] * T_
  for t in self.feature_table_map:
  table_has_feature[t] = True
- assert all(table_has_feature), "Each table must have at least one feature!"
+ assert all(table_has_feature), (
+ "Each table must have at least one feature!"
+ + f"{[(i, x) for i, x in enumerate(table_has_feature)]}"
+ )

  feature_dims = [dims[t] for t in self.feature_table_map]
  D_offsets = [0] + list(accumulate(feature_dims))
@@ -940,6 +1025,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  "feature_dims",
  torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
  )
+ _info_B_num_bits, _info_B_mask = torch.ops.fbgemm.get_infos_metadata(
+ self.D_offsets, # unused tensor
+ 1, # max_B
+ T, # T
+ )
+ self.info_B_num_bits: int = _info_B_num_bits
+ self.info_B_mask: int = _info_B_mask

  # A flag for indicating whether all embedding tables are placed in the
  # same locations
@@ -1047,13 +1139,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

  if ensemble_mode is None:
  ensemble_mode = EnsembleModeDefinition()
- self._ensemble_mode: Dict[str, float] = {
+ self._ensemble_mode: dict[str, float] = {
  key: float(fval) for key, fval in ensemble_mode.__dict__.items()
  }

  if emainplace_mode is None:
  emainplace_mode = EmainplaceModeDefinition()
- self._emainplace_mode: Dict[str, float] = {
+ self._emainplace_mode: dict[str, float] = {
  key: float(fval) for key, fval in emainplace_mode.__dict__.items()
  }

@@ -1085,26 +1177,37 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  # and CowClipDefinition are not used
  counter_halflife = -1

- # TO DO: Enable this on the new interface
- # learning_rate_tensor = torch.tensor(
- # learning_rate, device=torch.device("cpu"), dtype=torch.float
- # )
  if extra_optimizer_config is None:
  extra_optimizer_config = UserEnabledConfigDefinition()
  self.use_rowwise_bias_correction: bool = (
  extra_optimizer_config.use_rowwise_bias_correction
  )
+ self.use_writeback_bwd_prehook: bool = (
+ extra_optimizer_config.use_writeback_bwd_prehook
+ )
+
+ writeback_first_feature_only: bool = (
+ extra_optimizer_config.writeback_first_feature_only
+ )
+ self.log(f"self.extra_optimizer_config is {extra_optimizer_config}")
  if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM:
  raise AssertionError(
  "`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
  )
+ if self.use_writeback_bwd_prehook and not self.optimizer == OptimType.EXACT_SGD:
+ raise AssertionError(
+ "`use_writeback_bwd_prehook` is only supported for OptimType.EXACT_SGD",
+ )
+
+ self.learning_rate_tensor: torch.Tensor = torch.tensor(
+ learning_rate, device=torch.device("cpu"), dtype=torch.float32
+ )

  self.optimizer_args = invokers.lookup_args.OptimizerArgs(
  stochastic_rounding=stochastic_rounding,
  gradient_clipping=gradient_clipping,
  max_gradient=max_gradient,
  max_norm=max_norm,
- learning_rate=learning_rate,
  eps=eps,
  beta1=beta1,
  beta2=beta2,
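
A minimal sketch of opting into the new writeback options shown in the hunk above via `extra_optimizer_config`, assuming `UserEnabledConfigDefinition` is constructed like the other config dataclasses in this module (the import path is inferred; not part of the diff):

    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
        UserEnabledConfigDefinition,
    )

    extra_config = UserEnabledConfigDefinition(
        use_writeback_bwd_prehook=True,     # __init__ asserts this only with OptimType.EXACT_SGD
        writeback_first_feature_only=True,  # restrict writeback to the first feature
    )

Note also that the learning rate is now carried in a CPU float32 tensor (`self.learning_rate_tensor`) rather than being baked into `OptimizerArgs`.
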
@@ -1351,7 +1454,11 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

  self.step = 0
  self.last_reported_step = 0
- self.last_reported_uvm_stats: List[float] = []
+ self.last_reported_uvm_stats: list[float] = []
+ # Track number of times detailed memory breakdown has been reported
+ self.detailed_mem_breakdown_report_count = 0
+ # Set max number of reports for detailed memory breakdown
+ self.max_detailed_mem_breakdown_reports = 10

  # Check whether to use TBE v2
  is_experimental = False
@@ -1370,18 +1477,25 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  # self.log("TBE_V2 Knob is set to True; Using experimental TBE")

  self.is_experimental: bool = is_experimental
+ self._writeback_first_feature_only: bool = writeback_first_feature_only

  # Get a debug function pointer
  self._debug_print_input_stats: Callable[..., None] = (
  self._debug_print_input_stats_factory()
  )

- # Check if bounds_check_indices_v2 is enabled via the feature gate
- self.use_bounds_check_v2: bool = self._feature_is_enabled(
- FeatureGateName.BOUNDS_CHECK_INDICES_V2
+ # Get a reporter function pointer
+ self._report_input_params: Callable[..., None] = (
+ self.__report_input_params_factory()
  )
- if self.bounds_check_version == 2:
- self.use_bounds_check_v2 = True
+
+ if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
+ # Register writeback hook for Exact_SGD optimizer
+ self.log(
+ f"SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled with first feature only={self._writeback_first_feature_only}"
+ )
+ # pyre-fixme[6]: Expected `typing.Callable[[Module, Union[Tensor, typing.Tuple[Tensor, ...]]], Union[None, Tensor, typing.Tuple[Tensor, ...]]]`
+ self.register_full_backward_pre_hook(self.writeback_hook)

  if embedding_table_index_type not in [torch.int32, torch.int64]:
  raise ValueError(
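
The registration above relies on PyTorch's full backward pre-hook mechanism; a self-contained sketch of that mechanism on a plain module (unrelated to TBE, for illustration only) is:

    import torch
    from torch import nn

    lin = nn.Linear(4, 4)

    def halve_grads(module: nn.Module, grad_output: tuple[torch.Tensor, ...]):
        # A full backward pre-hook runs before the module's backward and may return
        # replacement output gradients (returning None leaves them unchanged).
        return tuple(g * 0.5 for g in grad_output)

    lin.register_full_backward_pre_hook(halve_grads)
    lin(torch.randn(2, 4)).sum().backward()

In the TBE case, the registered `writeback_hook` (defined later in this diff) uses this slot to rewrite the incoming gradient via `writeback_gradient` before the embedding backward runs.
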
@@ -1394,6 +1508,30 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  )
  self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

+ self.prefetched_info_list: list[PrefetchedInfo] = torch.jit.annotate(
+ list[PrefetchedInfo], []
+ )
+ if self.enable_raw_embedding_streaming:
+ self.res_params: RESParams = res_params or RESParams()
+ self.res_params.table_sizes = [0] + list(accumulate(rows))
+ res_port_from_env = os.getenv("LOCAL_RES_PORT")
+ self.res_params.res_server_port = (
+ int(res_port_from_env) if res_port_from_env else 0
+ )
+ # pyre-fixme[4]: Attribute must be annotated.
+ self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+ self.uuid,
+ self.enable_raw_embedding_streaming,
+ self.res_params.res_store_shards,
+ self.res_params.res_server_port,
+ self.res_params.table_names,
+ self.res_params.table_offsets,
+ self.res_params.table_sizes,
+ )
+ logging.info(
+ f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
+ )
+
  @torch.jit.ignore
  def log(self, msg: str) -> None:
  """
@@ -1437,7 +1575,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  )

  @staticmethod
- def get_table_name_for_logging(table_names: Optional[List[str]]) -> str:
+ def get_table_name_for_logging(table_names: Optional[list[str]]) -> str:
  """
  Given a list of all table names in the TBE, generate a string to
  represent them in logging. If there is more than one table, this method
@@ -1453,17 +1591,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  return "<Unknown>"
  # Do this because sometimes multiple shards of the same table could appear
  # in one TBE.
- table_name_set = set(table_names)
+ table_name_set = sorted(set(table_names))
  if len(table_name_set) == 1:
  return next(iter(table_name_set))
- return f"<{len(table_name_set)} tables>"
+ return f"<{len(table_name_set)} tables>: {table_name_set}"

  @staticmethod
  def get_prefetch_passes(
  multipass_prefetch_config: Optional[MultiPassPrefetchConfig],
  input_tensor: Tensor,
  output_tensor: Tensor,
- ) -> List[Tuple[Tensor, Tensor, int]]:
+ ) -> list[tuple[Tensor, Tensor, int]]:
  """
  Given inputs (the indices to forward), partition the input and output
  into smaller chunks and return them as a list of tuples
@@ -1511,7 +1649,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  )
  )

- def get_states(self, prefix: str) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ def get_states(self, prefix: str) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
  """
  Get a state of a given tensor (`prefix`)

@@ -1550,7 +1688,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  torch.tensor(offsets, dtype=torch.int64),
  )

- def get_all_states(self) -> List[Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
+ def get_all_states(self) -> list[tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
  """
  Get all states in the TBE (`weights`, `momentum1`, `momentum2`,
  `prev_iter`, and `row_counter`)
@@ -1614,10 +1752,161 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  tbe_id=self.uuid,
  )

- @torch.jit.ignore
- def _report_tbe_mem_usage(
+ def _get_tensor_memory(self, tensor_name: str) -> int:
+ """Get memory usage of a tensor in bytes."""
+ if not hasattr(self, tensor_name):
+ self.log(f"Tensor '{tensor_name}' not found, using 0 bytes")
+ return 0
+ tensor = getattr(self, tensor_name)
+ return tensor.numel() * tensor.element_size()
+
+ def _categorize_memory_by_location(
+ self, tensor_names: list[str]
+ ) -> tuple[int, int]:
+ """Categorize memory into HBM and UVM for given tensors.
+
+ Returns:
+ (hbm_bytes, uvm_bytes)
+ """
+ uvm_set = set(self._uvm_tensors_log)
+ hbm_bytes = 0
+ uvm_bytes = 0
+
+ for name in tensor_names:
+ size = self._get_tensor_memory(name)
+ if name in uvm_set:
+ uvm_bytes += size
+ else:
+ hbm_bytes += size
+
+ return hbm_bytes, uvm_bytes
+
+ def _report_hbm_breakdown(
  self,
+ stats_reporter: TBEStatsReporter,
+ embeddings: int,
+ optimizer_states: int,
+ cache: int,
+ total_static_sparse: int,
+ ephemeral: int,
+ cache_weights: int = 0,
+ cache_aux: int = 0,
  ) -> None:
+ """Report HBM memory breakdown to stats reporter."""
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.hbm.embeddings",
+ data_bytes=embeddings,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.hbm.optimizer_states",
+ data_bytes=optimizer_states,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.hbm.cache",
+ data_bytes=cache,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.hbm.cache_weights",
+ data_bytes=cache_weights,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.hbm.cache_aux",
+ data_bytes=cache_aux,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.hbm.total_static_sparse",
+ data_bytes=total_static_sparse,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.hbm.ephemeral",
+ data_bytes=ephemeral,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+
+ def _report_uvm_breakdown(
+ self,
+ stats_reporter: TBEStatsReporter,
+ embeddings: int,
+ optimizer_states: int,
+ cache: int,
+ total_static_sparse: int,
+ ephemeral: int,
+ cache_weights: int = 0,
+ cache_aux: int = 0,
+ ) -> None:
+ """Report UVM memory breakdown to stats reporter."""
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.uvm.embeddings",
+ data_bytes=embeddings,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.uvm.optimizer_states",
+ data_bytes=optimizer_states,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.uvm.cache",
+ data_bytes=cache,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.uvm.cache_weights",
+ data_bytes=cache_weights,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.uvm.cache_aux",
+ data_bytes=cache_aux,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.uvm.total_static_sparse",
+ data_bytes=total_static_sparse,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+ stats_reporter.report_data_amount(
+ iteration_step=self.step,
+ event_name="tbe.uvm.ephemeral",
+ data_bytes=ephemeral,
+ embedding_id=self.logging_table_name,
+ tbe_id=self.uuid,
+ )
+
+ @torch.jit.ignore
+ def _report_tbe_mem_usage(self) -> None:
  if self.stats_reporter is None:
  return

@@ -1625,22 +1914,24 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  if not stats_reporter.should_report(self.step):
  return

+ # Calculate total memory from all parameters and buffers (always needed)
  total_mem_usage = sum(
- param.numel() * param.element_size() for param in self.parameters()
- ) + sum(buffer.numel() * buffer.element_size() for buffer in self.buffers())
+ p.numel() * p.element_size() for p in self.parameters()
+ ) + sum(b.numel() * b.element_size() for b in self.buffers())
+
+ # Calculate total HBM and UVM usage (always needed)
  if self.use_cpu:
  total_hbm_usage = 0
  total_uvm_usage = total_mem_usage
  else:
- # hbm usage is total usage minus uvm usage
  total_uvm_usage = sum(
- getattr(self, tensor_name).numel()
- * getattr(self, tensor_name).element_size()
- for tensor_name in self._uvm_tensors_log
- if hasattr(self, tensor_name)
+ self._get_tensor_memory(name)
+ for name in self._uvm_tensors_log
+ if hasattr(self, name)
  )
  total_hbm_usage = total_mem_usage - total_uvm_usage

+ # Report total memory usage metrics (always reported for backward compatibility)
  stats_reporter.report_data_amount(
  iteration_step=self.step,
  event_name="tbe.total_hbm_usage",
@@ -1656,6 +1947,96 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  tbe_id=self.uuid,
  )

+ # Only report detailed breakdown for the first max_detailed_mem_breakdown_reports reportable
+ # steps since static sparse memory (weights, optimizer states, cache) is constant
+ if (
+ self.detailed_mem_breakdown_report_count
+ >= self.max_detailed_mem_breakdown_reports
+ ):
+ return
+ self.detailed_mem_breakdown_report_count += 1
+
+ # Tensor groups for sparse memory categorization
+ weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]
+ optimizer_tensors = [
+ "momentum1_dev",
+ "momentum1_host",
+ "momentum1_uvm",
+ "momentum2_dev",
+ "momentum2_host",
+ "momentum2_uvm",
+ ]
+ # Cache weights tensor (the actual cached embeddings in HBM)
+ cache_weight_tensors = [
+ "lxu_cache_weights",
+ ]
+ # Cache auxiliary state tensors (metadata for cache management, excluding weights)
+ # Sizes scale with hash_size or cache_slots (hash_size × clf)
+ # Excludes constant-size tensors: cache_hash_size_cumsum, cache_miss_counter, etc.
+ cache_aux_tensors = [
+ "cache_index_table_map", # int32, 4B × hash_size
+ "lxu_cache_state", # int64, 8B × cache_slots
+ "lxu_state", # int64, 8B × cache_slots (LRU) or hash_size (LFU)
+ "lxu_cache_locking_counter", # int32, 4B × cache_slots (only if prefetch_pipeline)
+ ]
+
+ # Calculate total memory for each component
+ weights_total = sum(self._get_tensor_memory(t) for t in weight_tensors)
+ optimizer_total = sum(self._get_tensor_memory(t) for t in optimizer_tensors)
+ cache_weights_total = sum(
+ self._get_tensor_memory(t) for t in cache_weight_tensors
+ )
+ cache_aux_total = sum(self._get_tensor_memory(t) for t in cache_aux_tensors)
+
+ # Categorize memory by location (HBM vs UVM)
+ if self.use_cpu:
+ weights_hbm, weights_uvm = 0, weights_total
+ opt_hbm, opt_uvm = 0, optimizer_total
+ cache_weights_hbm, cache_weights_uvm = 0, cache_weights_total
+ cache_aux_hbm, cache_aux_uvm = 0, cache_aux_total
+ else:
+ weights_hbm, weights_uvm = self._categorize_memory_by_location(
+ weight_tensors
+ )
+ opt_hbm, opt_uvm = self._categorize_memory_by_location(optimizer_tensors)
+ cache_weights_hbm, cache_weights_uvm = self._categorize_memory_by_location(
+ cache_weight_tensors
+ )
+ cache_aux_hbm, cache_aux_uvm = self._categorize_memory_by_location(
+ cache_aux_tensors
+ )
+
+ # Calculate ephemeral memory split between HBM and UVM
+ # Total cache = cache weights + cache auxiliary state
+ cache_hbm = cache_weights_hbm + cache_aux_hbm
+ cache_uvm = cache_weights_uvm + cache_aux_uvm
+ static_sparse_hbm = weights_hbm + opt_hbm + cache_hbm
+ static_sparse_uvm = weights_uvm + opt_uvm + cache_uvm
+ ephemeral_hbm = total_hbm_usage - static_sparse_hbm
+ ephemeral_uvm = total_uvm_usage - static_sparse_uvm
+
+ # Report granular memory breakdowns
+ self._report_hbm_breakdown(
+ stats_reporter,
+ weights_hbm,
+ opt_hbm,
+ cache_hbm,
+ static_sparse_hbm,
+ ephemeral_hbm,
+ cache_weights_hbm,
+ cache_aux_hbm,
+ )
+ self._report_uvm_breakdown(
+ stats_reporter,
+ weights_uvm,
+ opt_uvm,
+ cache_uvm,
+ static_sparse_uvm,
+ ephemeral_uvm,
+ cache_weights_uvm,
+ cache_aux_uvm,
+ )
+
  @torch.jit.ignore
  def _report_io_size_count(self, event: str, data: Tensor) -> Tensor:
  if self.stats_reporter is None:
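
A standalone restatement of the HBM/UVM accounting added above (a sketch, not the module's code): byte counts are attributed to UVM when the tensor's name was logged in `_uvm_tensors_log`, otherwise to HBM, and whatever remains of the module totals is reported as ephemeral memory.

    def categorize(sizes: dict[str, int], uvm_names: set[str]) -> tuple[int, int]:
        # Mirrors _categorize_memory_by_location: split bytes by residency.
        hbm = sum(s for n, s in sizes.items() if n not in uvm_names)
        uvm = sum(s for n, s in sizes.items() if n in uvm_names)
        return hbm, uvm

    sizes = {"weights_dev": 4096, "weights_uvm": 8192, "lxu_cache_weights": 1024}
    hbm, uvm = categorize(sizes, uvm_names={"weights_uvm"})
    assert (hbm, uvm) == (5120, 8192)
    # ephemeral_hbm = total_hbm_usage - (weights_hbm + optimizer_hbm + cache_hbm),
    # and likewise for UVM, as computed in _report_tbe_mem_usage above.
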
@@ -1682,7 +2063,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  def _generate_vbe_metadata(
  self,
  offsets: Tensor,
- batch_size_per_feature_per_rank: Optional[List[List[int]]],
+ batch_size_per_feature_per_rank: Optional[list[list[int]]],
+ vbe_output: Optional[Tensor] = None,
+ vbe_output_offsets: Optional[Tensor] = None,
  ) -> invokers.lookup_args.VBEMetadata:
  # Blocking D2H copy, but only runs at first call
  self.feature_dims = self.feature_dims.cpu()
@@ -1705,6 +2088,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  self.pooling_mode,
  self.feature_dims,
  self.current_device,
+ vbe_output,
+ vbe_output_offsets,
  )

  @torch.jit.ignore
@@ -1713,14 +2098,30 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  # This allows models using this class to compile correctly
  return FeatureGate.is_enabled(feature)

+ # pyre-fixme[2]: For 1st argument expected not ANY
+ def writeback_hook(self, module: Any, grad: Tensor) -> tuple[Tensor]:
+ indices = self._indices
+ offsets = self._offsets
+ return writeback_gradient(
+ grad,
+ indices,
+ offsets,
+ self.feature_table_map,
+ self._writeback_first_feature_only,
+ )
+
  def forward( # noqa: C901
  self,
  indices: Tensor,
  offsets: Tensor,
  per_sample_weights: Optional[Tensor] = None,
  feature_requires_grad: Optional[Tensor] = None,
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
  total_unique_indices: Optional[int] = None,
+ hash_zch_identities: Optional[Tensor] = None,
+ hash_zch_runtime_meta: Optional[Tensor] = None,
+ vbe_output: Optional[Tensor] = None,
+ vbe_output_offsets: Optional[Tensor] = None,
  ) -> Tensor:
  """
@@ -1773,7 +2174,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  be set when using `OptimType.NONE`. This is because TBE
  requires this information for allocating the weight gradient
  tensor in the backward pass.
-
+ hash_zch_identities (Optional[Tensor]): The original raw IDs before
+ remapping to ZCH (Zero-Collision Hash) table slots. This tensor is
+ populated when using Multi-Probe Zero Collision Hash (MPZCH) modules
+ and is required for Raw Embedding Streaming (RES) to maintain
+ consistency between training and inference.
+ vbe_output (Optional[Tensor]): An optional 2-D tensor of size that
+ contains output for TBE VBE. The shape of the tensor is
+ [1, total_vbe_output_size] where total_vbe_output_size is the
+ output size across all ranks and all embedding tables.
+ If this tensor is not None, the TBE VBE forward output is written
+ to this tensor at the locations specified by `vbe_output_offsets`.
+ vbe_output_offsets (Optional[Tensor]): An optional 2-D tensor that
+ contains VBE output offsets to `vbe_output`. The shape of the
+ tensor is [num_ranks, num_features].
+ vbe_output_offsets[r][f] represents the starting offset for rank `r`
+ and feature `f`.
  Returns:
  A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` =
  batch size and `total_D` = the sum of all embedding dimensions in the
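
A hedged shape sketch of the pre-allocated VBE output described in the docstring above, assuming two ranks, two features with dims [16, 32], per-rank/per-feature batch sizes B[r][f], and a rank-major layout (the ordering inside `vbe_output` is an assumption here, not spelled out by the docstring):

    import torch

    dims = [16, 32]
    B = [[3, 2], [1, 4]]  # B[r][f]: hypothetical batch sizes per rank and feature
    sizes = [B[r][f] * dims[f] for r in range(2) for f in range(2)]
    starts = [sum(sizes[:i]) for i in range(len(sizes))]
    vbe_output_offsets = torch.tensor(starts).view(2, 2)  # [num_ranks, num_features]
    vbe_output = torch.empty(1, sum(sizes))               # [1, total_vbe_output_size]
    # forward(..., vbe_output=vbe_output, vbe_output_offsets=vbe_output_offsets) would
    # then write rank r / feature f output starting at vbe_output_offsets[r][f].
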
@@ -1846,11 +2262,35 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  per_sample_weights,
  batch_size_per_feature_per_rank,
  force_cast_input_types=True,
+ prefetch_pipeline=False,
+ vbe_output=vbe_output,
+ vbe_output_offsets=vbe_output_offsets,
  )

+ # Only enable VBE if batch_size_per_feature_per_rank is not None
+ assert not (
+ batch_size_per_feature_per_rank is not None
+ and self.use_writeback_bwd_prehook
+ ), "VBE is not supported with writeback_bwd_prehook"
+
  # Print input stats if enable (for debugging purpose only)
  self._debug_print_input_stats(indices, offsets, per_sample_weights)

+ # Extract and Write input stats if enable
+ if self._report_input_params is not None:
+ self._report_input_params(
+ feature_rows=self.rows_per_table,
+ feature_dims=self.feature_dims,
+ iteration=self.iter_cpu.item() if hasattr(self, "iter_cpu") else 0,
+ indices=indices,
+ offsets=offsets,
+ op_id=self.uuid,
+ per_sample_weights=per_sample_weights,
+ batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+ embedding_specs=[(s[0], s[1]) for s in self.embedding_specs],
+ feature_table_map=self.feature_table_map,
+ )
+
  if not is_torchdynamo_compiling():
  # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time

@@ -1878,7 +2318,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  # to be as fast as possible and memory usage doesn't matter (will be recycled
  # by dense fwd/bwd)
  self._prefetch(
- indices, offsets, vbe_metadata, multipass_prefetch_config=None
+ indices,
+ offsets,
+ vbe_metadata,
+ multipass_prefetch_config=None,
+ hash_zch_identities=hash_zch_identities,
+ hash_zch_runtime_meta=hash_zch_runtime_meta,
  )

  if len(self.timesteps_prefetched) > 0:
@@ -1936,6 +2381,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  is_experimental=self.is_experimental,
  use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
  use_homogeneous_placements=self.use_homogeneous_placements,
+ learning_rate_tensor=self.learning_rate_tensor,
+ info_B_num_bits=self.info_B_num_bits,
+ info_B_mask=self.info_B_mask,
  )

  if self.optimizer == OptimType.NONE:
@@ -2048,7 +2496,6 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  momentum1,
  momentum2,
  iter_int,
- self.use_rowwise_bias_correction,
  row_counter=(
  row_counter if self.use_rowwise_bias_correction else None
  ),
@@ -2158,6 +2605,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  row_counter,
  iter_int,
  self.max_counter.item(),
+ mixed_D=self.mixed_D,
  ),
  )
  elif self._used_rowwise_adagrad_with_global_weight_decay:
@@ -2176,6 +2624,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  # `Optional[Tensor]` but got `Union[Module, Tensor]`.
  prev_iter_dev=self.prev_iter_dev,
  gwd_lower_bound=self.gwd_lower_bound,
+ mixed_D=self.mixed_D,
  ),
  )
  else:
@@ -2185,12 +2634,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  common_args,
  self.optimizer_args,
  momentum1,
+ mixed_D=self.mixed_D,
  ),
  )

  raise ValueError(f"Invalid OptimType: {self.optimizer}")

- def ema_inplace(self, emainplace_mode: Dict[str, float]) -> None:
+ def ema_inplace(self, emainplace_mode: dict[str, float]) -> None:
  """
  Perform ema operations on the full sparse embedding tables.
  We organize the sparse table, in the following way.
@@ -2220,7 +2670,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  emainplace_mode["step_ema_coef"],
  )

- def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
+ def ensemble_and_swap(self, ensemble_mode: dict[str, float]) -> None:
  """
  Perform ensemble and swap operations on the full sparse embedding tables.

@@ -2268,7 +2718,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
  return self.local_uvm_cache_stats if use_local_cache else self.uvm_cache_stats

- def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> List[float]:
+ def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> list[float]:
  snapshot = self.get_uvm_cache_stats(use_local_cache)
  if use_local_cache:
  return snapshot.tolist()
@@ -2281,7 +2731,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  @torch.jit.ignore
  def print_uvm_cache_stats(self, use_local_cache: bool = False) -> None:
  # TODO: Create a separate reporter class to unify the stdlog reporting
- uvm_cache_stats: List[float] = self._get_uvm_cache_print_state(use_local_cache)
+ uvm_cache_stats: list[float] = self._get_uvm_cache_print_state(use_local_cache)
  N = max(1, uvm_cache_stats[0])
  m = {
  "N_called": uvm_cache_stats[UVMCacheStatsIndex.num_calls],
@@ -2325,14 +2775,14 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  if not stats_reporter.should_report(self.step):
  return

- uvm_cache_stats: List[float] = self.get_uvm_cache_stats(
+ uvm_cache_stats: list[float] = self.get_uvm_cache_stats(
  use_local_cache=False
  ).tolist()
  self.last_reported_step = self.step

  if len(self.last_reported_uvm_stats) == 0:
  self.last_reported_uvm_stats = [0.0] * len(uvm_cache_stats)
- uvm_cache_stats_delta: List[float] = [0.0] * len(uvm_cache_stats)
+ uvm_cache_stats_delta: list[float] = [0.0] * len(uvm_cache_stats)
  for i in range(len(uvm_cache_stats)):
  uvm_cache_stats_delta[i] = (
  uvm_cache_stats[i] - self.last_reported_uvm_stats[i]
@@ -2361,7 +2811,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  indices: Tensor,
  offsets: Tensor,
  forward_stream: Optional[torch.cuda.Stream] = None,
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
  ) -> None:
  if self.prefetch_stream is None and forward_stream is not None:
  self.prefetch_stream = torch.cuda.current_stream()
@@ -2369,19 +2819,21 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  self.prefetch_stream != forward_stream
  ), "prefetch_stream and forward_stream should not be the same stream"

- indices, offsets, _, vbe_metadata = self.prepare_inputs(
- indices,
- offsets,
- per_sample_weights=None,
- batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
- force_cast_input_types=False,
- )
-
  with self._recording_to_timer(
  self.prefetch_duration_timer,
  context=self.step,
  stream=torch.cuda.current_stream(),
  ):
+
+ indices, offsets, _, vbe_metadata = self.prepare_inputs(
+ indices,
+ offsets,
+ per_sample_weights=None,
+ batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+ force_cast_input_types=False,
+ prefetch_pipeline=self.prefetch_pipeline,
+ )
+
  self._prefetch(
  indices,
  offsets,
@@ -2398,6 +2850,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  offsets: Tensor,
  vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
  multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
+ hash_zch_identities: Optional[Tensor] = None,
+ hash_zch_runtime_meta: Optional[Tensor] = None,
  ) -> None:
  if not is_torchdynamo_compiling():
  # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
@@ -2416,7 +2870,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  self.local_uvm_cache_stats.zero_()
  self._report_io_size_count("prefetch_input", indices)

+ # streaming before updating the cache
+ self.raw_embedding_stream()
+
  final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
+ linear_cache_indices_merged = torch.zeros(
+ 0, dtype=indices.dtype, device=indices.device
+ )
  for (
  partial_indices,
  partial_lxu_cache_locations,
@@ -2432,6 +2892,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  vbe_metadata.max_B if vbe_metadata is not None else -1,
  base_offset,
  )
+ linear_cache_indices_merged = torch.cat(
+ [linear_cache_indices_merged, linear_cache_indices]
+ )

  if (
  self.record_cache_metrics.record_cache_miss_counter
@@ -2512,6 +2975,16 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  if self.should_log():
  self.print_uvm_cache_stats(use_local_cache=False)

+ self._store_prefetched_tensors(
+ indices,
+ offsets,
+ vbe_metadata,
+ linear_cache_indices_merged,
+ final_lxu_cache_locations,
+ hash_zch_identities,
+ hash_zch_runtime_meta,
+ )
+
  def should_log(self) -> bool:
  """Determines if we should log for this step, using exponentially decreasing frequency.

@@ -2596,12 +3069,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
  tmp_emb.uniform_(min_val, max_val)
  tmp_emb_i8 = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(tmp_emb)
  emb.data.copy_(tmp_emb_i8)
+ # Torch doesnt implement direct fp8 distribution functions, so we need to start in higher precision.
+ elif self.weights_precision == SparseType.NFP8:
+ assert (
+ self.current_device.type == "cuda"
+ ), "NFP8 is currently only supportd on GPU."
+ assert self.optimizer in [
+ OptimType.EXACT_ADAGRAD,
+ OptimType.ROWWISE_ADAGRAD,
+ OptimType.EXACT_ROWWISE_ADAGRAD,
+ OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+ OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
+ ], "NFP8 is currently only supportd with adagrad optimizers."
+ for param in splits:
+ tmp_param = torch.zeros(param.shape, device=self.current_device)
+ # Create initialized weights and cast to fp8.
+ fp8_dtype = (
+ torch.float8_e4m3fnuz
+ if torch.version.hip is not None
+ else torch.float8_e4m3fn
+ )
+ tmp_param.uniform_(min_val, max_val).to(fp8_dtype)
+ param.data.copy_(tmp_param)
  else:
  for param in splits:
  param.uniform_(min_val, max_val)

  @torch.jit.ignore
- def split_embedding_weights(self) -> List[Tensor]:
+ def split_embedding_weights(self) -> list[Tensor]:
  """
  Returns a list of embedding weights (view), split by table

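A small, standalone sketch of the FP8 initialization path added in this hunk; since PyTorch provides no `uniform_` for float8 dtypes, values are drawn at higher precision and then converted (the dtype choice follows the ROCm/CUDA split shown above):

    import torch

    fp8_dtype = (
        torch.float8_e4m3fnuz if torch.version.hip is not None else torch.float8_e4m3fn
    )
    tmp = torch.empty(8, 16).uniform_(-0.05, 0.05)  # draw in fp32 first
    weights_fp8 = tmp.to(fp8_dtype)                 # explicit cast to the fp8 storage dtype

Copying into an existing fp8 buffer with `copy_` performs the same conversion implicitly.
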
@@ -2643,7 +3138,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
2643
3138
  raise ValueError(f"Optimizer buffer {state} not found")
2644
3139
 
2645
3140
  @torch.jit.export
2646
- def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
3141
+ def get_optimizer_state(self) -> list[dict[str, torch.Tensor]]:
2647
3142
  r"""
2648
3143
  Get the optimizer state dict that matches the OSS Pytorch optims
2649
3144
  TODO: populate the supported list of optimizers
@@ -2656,7 +3151,23 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
2656
3151
  ):
2657
3152
  list_of_state_dict = [
2658
3153
  (
2659
- {"sum": states[0], "prev_iter": states[1], "row_counter": states[2]}
3154
+ (
3155
+ {
3156
+ "sum": states[0],
3157
+ "prev_iter": states[1],
3158
+ "row_counter": states[2],
3159
+ "iter": self.iter,
3160
+ }
3161
+ if self.optimizer_args.regularization_mode
3162
+ == WeightDecayMode.COUNTER.value
3163
+ and self.optimizer_args.weight_decay_mode
3164
+ == CounterWeightDecayMode.ADAGRADW.value
3165
+ else {
3166
+ "sum": states[0],
3167
+ "prev_iter": states[1],
3168
+ "row_counter": states[2],
3169
+ }
3170
+ )
2660
3171
  if self._used_rowwise_adagrad_with_counter
2661
3172
  else (
2662
3173
  {
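With the change above, the per-table state dict for rowwise Adagrad with counters gains an extra "iter" entry only when the COUNTER regularization mode is combined with the ADAGRADW counter weight-decay mode. A hedged usage sketch of how a checkpointing caller might consume get_optimizer_state() and tolerate the optional key (emb stands for an already constructed SplitTableBatchedEmbeddingBagsCodegen):

    # One dict per embedding table; the key set depends on the optimizer config.
    for table_id, state in enumerate(emb.get_optimizer_state()):
        print(f"table {table_id}: {sorted(state.keys())}")
        # "iter" is only present for counter-based rowwise Adagrad in ADAGRADW mode.
        step = state.get("iter")
        if step is not None:
            print(f"  optimizer step counter: {int(step)}")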
@@ -2711,7 +3222,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
2711
3222
  @torch.jit.ignore
2712
3223
  def split_optimizer_states(
2713
3224
  self,
2714
- ) -> List[List[torch.Tensor]]:
3225
+ ) -> list[list[torch.Tensor]]:
2715
3226
  """
2716
3227
  Returns a list of optimizer states (view), split by table
2717
3228
 
@@ -2759,7 +3270,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
2759
3270
  state_offsets: Tensor,
2760
3271
  state_placements: Tensor,
2761
3272
  rowwise: bool,
2762
- ) -> List[torch.Tensor]:
3273
+ ) -> list[torch.Tensor]:
2763
3274
  splits = []
2764
3275
  for t, (rows, dim, _, _) in enumerate(self.embedding_specs):
2765
3276
  offset = state_offsets[t]
@@ -2778,7 +3289,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
2778
3289
  splits.append(state.detach()[offset : offset + rows].view(rows))
2779
3290
  return splits
2780
3291
 
2781
- states: List[List[torch.Tensor]] = []
3292
+ states: list[list[torch.Tensor]] = []
2782
3293
  if self.optimizer not in (OptimType.EXACT_SGD,):
2783
3294
  states.append(
2784
3295
  get_optimizer_states(
@@ -2899,15 +3410,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
2899
3410
 
2900
3411
  def get_learning_rate(self) -> float:
2901
3412
  """
2902
- Sets the learning rate.
2903
-
2904
- Args:
2905
- lr (float): The learning rate value to set to
3413
+ Get and return the learning rate.
2906
3414
  """
2907
- return self.optimizer_args.learning_rate
3415
+ return self.learning_rate_tensor.item()
2908
3416
 
2909
3417
  @torch.jit.ignore
2910
- def update_hyper_parameters(self, params_dict: Dict[str, float]) -> None:
3418
+ def update_hyper_parameters(self, params_dict: dict[str, float]) -> None:
2911
3419
  """
2912
3420
  Sets hyper-parameters from external control flow.
2913
3421
 
@@ -2943,7 +3451,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
2943
3451
  Helper function to script `set_learning_rate`.
2944
3452
  Note that returning None does not work.
2945
3453
  """
2946
- self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
3454
+ self.learning_rate_tensor.fill_(lr)
2947
3455
  return 0.0
2948
3456
 
2949
3457
  @torch.jit.ignore
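The get/set learning-rate hunks above stop reading the value from the optimizer_args NamedTuple and instead mutate a learning_rate_tensor in place, which plays better with TorchScript and torch.compile than rebuilding the args struct. A minimal sketch of the same buffer-backed pattern on a plain nn.Module (only the buffer name mirrors the code above; the rest is illustrative):

    import torch
    from torch import nn

    class LRHolder(nn.Module):
        def __init__(self, lr: float) -> None:
            super().__init__()
            # Keep the scalar in a tensor buffer so it can be updated in place.
            self.register_buffer("learning_rate_tensor", torch.tensor(lr, dtype=torch.float32))

        @torch.jit.export
        def set_learning_rate(self, lr: float) -> float:
            # In-place fill instead of replacing an immutable NamedTuple field.
            self.learning_rate_tensor.fill_(lr)
            return 0.0  # scripted helpers return a float because returning None does not work

        def get_learning_rate(self) -> float:
            return self.learning_rate_tensor.item()

    m = LRHolder(0.01)
    m.set_learning_rate(0.005)
    print(m.get_learning_rate())  # ~0.005 (float32 storage)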
@@ -2983,10 +3491,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
2983
3491
  self,
2984
3492
  split: SplitState,
2985
3493
  prefix: str,
2986
- dtype: Type[torch.dtype],
3494
+ dtype: type[torch.dtype],
2987
3495
  enforce_hbm: bool = False,
2988
3496
  make_dev_param: bool = False,
2989
- dev_reshape: Optional[Tuple[int, ...]] = None,
3497
+ dev_reshape: Optional[tuple[int, ...]] = None,
2990
3498
  uvm_host_mapped: bool = False,
2991
3499
  ) -> None:
2992
3500
  apply_split_helper(
@@ -3036,6 +3544,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3036
3544
  dtype = torch.float32
3037
3545
  elif cache_precision == SparseType.FP16:
3038
3546
  dtype = torch.float16
3547
+ elif cache_precision == SparseType.NFP8:
3548
+ # NFP8 weights use floating point cache.
3549
+ dtype = torch.float16
3039
3550
  else:
3040
3551
  dtype = torch.float32 # not relevant, but setting it to keep linter happy
3041
3552
  if not self.use_cpu > 0:
@@ -3229,7 +3740,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3229
3740
  def _update_cache_counter_and_locations(
3230
3741
  self,
3231
3742
  module: nn.Module,
3232
- grad_input: Union[Tuple[Tensor, ...], Tensor],
3743
+ grad_input: Union[tuple[Tensor, ...], Tensor],
3233
3744
  ) -> None:
3234
3745
  """
3235
3746
  Backward prehook function when prefetch_pipeline is enabled.
@@ -3425,9 +3936,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3425
3936
  indices: Tensor,
3426
3937
  offsets: Tensor,
3427
3938
  per_sample_weights: Optional[Tensor] = None,
3428
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
3939
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
3429
3940
  force_cast_input_types: bool = True,
3430
- ) -> Tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
3941
+ prefetch_pipeline: bool = False,
3942
+ vbe_output: Optional[Tensor] = None,
3943
+ vbe_output_offsets: Optional[Tensor] = None,
3944
+ ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
3431
3945
  """
3432
3946
  Prepare TBE inputs as follows:
3433
3947
 
@@ -3453,11 +3967,34 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3453
3967
  metadata
3454
3968
  """
3455
3969
 
3970
+ if vbe_output is not None or vbe_output_offsets is not None:
3971
+ assert (
3972
+ not self.use_cpu
3973
+ ), "[TBE API v2] Using pre-allocated vbe_output is not supported on CPU"
3974
+ check_allocated_vbe_output(
3975
+ self.output_dtype,
3976
+ batch_size_per_feature_per_rank,
3977
+ vbe_output,
3978
+ vbe_output_offsets,
3979
+ )
3980
+
3456
3981
  # Generate VBE metadata
3457
3982
  vbe_metadata = self._generate_vbe_metadata(
3458
- offsets, batch_size_per_feature_per_rank
3983
+ offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
3459
3984
  )
3460
3985
 
3986
+ vbe = vbe_metadata.B_offsets is not None
3987
+ # Note this check has already been done on the C++ side
3988
+ # TODO: max_B <= self.info_B_mask in python
3989
+ # We cannot use assert as it breaks pt2 compile for dynamic shape
3990
+ # and must use torch._check for dynamic shapes; its message cannot be an f-string, so use a constant string.
3991
+ # torch._check(
3992
+ # max_B <= self.info_B_mask,
3993
+ # "Not enough infos bits to accommodate T and B.",
3994
+ # )
3995
+ # We cannot use lambda as it fails jit script.
3996
+ # torch._check is also not supported in jitscript
3997
+
3461
3998
  # TODO: remove this and add an assert after updating
3462
3999
  # bounds_check_indices to support different indices type and offset
3463
4000
  # type
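The commented-out check above is boxed in from two sides: a plain assert breaks torch.compile with dynamic shapes, while torch._check is not accepted by TorchScript (and its failure message cannot be built from symbolic values). Outside of jit.script, the documented torch._check form with a lazily evaluated, constant message could be used; a small sketch with stand-in values for max_B and info_B_mask:

    import torch

    max_B = 128
    info_B_mask = (1 << 26) - 1

    # torch._check participates in dynamic-shape reasoning under torch.compile;
    # the message is supplied lazily via a callable and kept as a constant string.
    torch._check(
        max_B <= info_B_mask,
        lambda: "Not enough info bits to accommodate T and B.",
    )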
@@ -3485,10 +4022,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3485
4022
  per_sample_weights = per_sample_weights.float()
3486
4023
 
3487
4024
  if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
4025
+ # Override the bounds check version based on prefetch_pipeline
4026
+ use_bounds_check_v2 = self.bounds_check_version == 2 or prefetch_pipeline
4027
+ bounds_check_version = (
4028
+ 2 if use_bounds_check_v2 else self.bounds_check_version
4029
+ )
4030
+
3488
4031
  vbe = vbe_metadata.B_offsets is not None
4032
+
3489
4033
  # Compute B info and VBE metadata for bounds_check_indices only if
3490
4034
  # VBE and bounds check indices v2 are used
3491
- if vbe and self.use_bounds_check_v2:
4035
+ if vbe and use_bounds_check_v2:
3492
4036
  B_offsets = vbe_metadata.B_offsets
3493
4037
  B_offsets_rank_per_feature = vbe_metadata.B_offsets_rank_per_feature
3494
4038
  output_offsets_feature_rank = vbe_metadata.output_offsets_feature_rank
@@ -3499,11 +4043,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3499
4043
  assert isinstance(
3500
4044
  output_offsets_feature_rank, Tensor
3501
4045
  ), "output_offsets_feature_rank must be tensor"
3502
- info_B_num_bits, info_B_mask = torch.ops.fbgemm.get_infos_metadata(
3503
- B_offsets, # unused tensor
3504
- vbe_metadata.max_B,
3505
- B_offsets.numel() - 1, # T
3506
- )
4046
+
3507
4047
  row_output_offsets, b_t_map = torch.ops.fbgemm.generate_vbe_metadata(
3508
4048
  B_offsets,
3509
4049
  B_offsets_rank_per_feature,
@@ -3512,13 +4052,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3512
4052
  self.max_D,
3513
4053
  self.is_nobag,
3514
4054
  vbe_metadata.max_B_feature_rank,
3515
- info_B_num_bits,
3516
- offsets.numel() - 1, # total_B
4055
+ self.info_B_num_bits,
4056
+ offsets.numel() - 1, # total_B,
4057
+ vbe_output_offsets,
3517
4058
  )
3518
4059
  else:
3519
4060
  b_t_map = None
3520
- info_B_num_bits = -1
3521
- info_B_mask = -1
3522
4061
 
3523
4062
  torch.ops.fbgemm.bounds_check_indices(
3524
4063
  self.rows_per_table,
@@ -3530,9 +4069,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3530
4069
  B_offsets=vbe_metadata.B_offsets,
3531
4070
  max_B=vbe_metadata.max_B,
3532
4071
  b_t_map=b_t_map,
3533
- info_B_num_bits=info_B_num_bits,
3534
- info_B_mask=info_B_mask,
3535
- bounds_check_version=self.bounds_check_version,
4072
+ info_B_num_bits=self.info_B_num_bits,
4073
+ info_B_mask=self.info_B_mask,
4074
+ bounds_check_version=bounds_check_version,
4075
+ prefetch_pipeline=prefetch_pipeline,
3536
4076
  )
3537
4077
 
3538
4078
  return indices, offsets, per_sample_weights, vbe_metadata
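The bounds-check hunks above force version 2 of bounds_check_indices whenever a prefetch pipeline is active, regardless of the configured bounds_check_version. The selection rule is small enough to restate as a standalone helper (a sketch; the names mirror the attributes used above):

    def effective_bounds_check_version(configured_version: int, prefetch_pipeline: bool) -> int:
        # The prefetch-pipeline path always opts into the v2 kernel.
        use_v2 = configured_version == 2 or prefetch_pipeline
        return 2 if use_v2 else configured_version

    assert effective_bounds_check_version(1, prefetch_pipeline=False) == 1
    assert effective_bounds_check_version(1, prefetch_pipeline=True) == 2
    assert effective_bounds_check_version(2, prefetch_pipeline=False) == 2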
@@ -3603,7 +4143,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3603
4143
  # Counts of indices that segment lengths > 1024
3604
4144
  counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024]
3605
4145
 
3606
- def compute_numel_and_avg(counts: Tensor) -> Tuple[int, float]:
4146
+ def compute_numel_and_avg(counts: Tensor) -> tuple[int, float]:
3607
4147
  numel = counts.numel()
3608
4148
  avg = (counts.sum().item() / numel) if numel != 0 else -1.0
3609
4149
  return numel, avg
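compute_numel_and_avg sidesteps a division by zero by returning -1.0 as the sentinel average for empty count tensors. A quick standalone check of that behaviour:

    import torch

    def compute_numel_and_avg(counts: torch.Tensor) -> tuple[int, float]:
        numel = counts.numel()
        avg = (counts.sum().item() / numel) if numel != 0 else -1.0
        return numel, avg

    print(compute_numel_and_avg(torch.tensor([2, 4, 6])))              # (3, 4.0)
    print(compute_numel_and_avg(torch.tensor([], dtype=torch.int64)))  # (0, -1.0)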
@@ -3671,6 +4211,240 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
3671
4211
  return _debug_print_input_stats_factory_impl
3672
4212
  return _debug_print_input_stats_factory_null
3673
4213
 
4214
+ @torch.jit.ignore
4215
+ def raw_embedding_stream(self) -> None:
4216
+ if not self.enable_raw_embedding_streaming:
4217
+ return None
4218
+ # when pipelining is enabled
4219
+ # prefetch in iter i happens before the backward sparse in iter i - 1
4220
+ # so embeddings for iter i - 1's changed ids are not updated.
4221
+ # so we can only fetch the indices from iter i - 2
4222
+ # when pipelining is disabled
4223
+ # prefetch in iter i happens before forward iter i
4224
+ # so we can get the iter i - 1's changed ids safely.
4225
+ target_prev_iter = 1
4226
+ if self.prefetch_pipeline:
4227
+ target_prev_iter = 2
4228
+ if not len(self.prefetched_info_list) > (target_prev_iter - 1):
4229
+ return None
4230
+ with record_function(
4231
+ "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
4232
+ ):
4233
+ prefetched_info = self.prefetched_info_list.pop(0)
4234
+ updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
4235
+ prefetched_info.linear_unique_cache_indices,
4236
+ self.lxu_cache_state,
4237
+ self.total_cache_hash_size,
4238
+ gather_cache_stats=False, # not collecting cache stats
4239
+ num_uniq_cache_indices=prefetched_info.linear_unique_indices_length,
4240
+ )
4241
+ updated_weights = torch.empty(
4242
+ [
4243
+ prefetched_info.linear_unique_cache_indices.size()[0],
4244
+ self.max_D_cache,
4245
+ ],
4246
+ # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
4247
+ dtype=self.lxu_cache_weights.dtype,
4248
+ # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
4249
+ device=self.lxu_cache_weights.device,
4250
+ )
4251
+ torch.ops.fbgemm.masked_index_select(
4252
+ updated_weights,
4253
+ updated_locations,
4254
+ self.lxu_cache_weights,
4255
+ prefetched_info.linear_unique_indices_length,
4256
+ )
4257
+ # TODO: this statement triggers a sync
4258
+ # added here to make this diff self-contained
4259
+ # will remove in later change
4260
+ cache_hit_mask_index = (
4261
+ updated_locations.narrow(
4262
+ 0, 0, prefetched_info.linear_unique_indices_length.item()
4263
+ )
4264
+ .not_equal(-1)
4265
+ .nonzero()
4266
+ .flatten()
4267
+ )
4268
+ # stream weights
4269
+ self._raw_embedding_streamer.stream(
4270
+ prefetched_info.linear_unique_indices.index_select(
4271
+ dim=0, index=cache_hit_mask_index
4272
+ ).to(device=torch.device("cpu")),
4273
+ updated_weights.index_select(dim=0, index=cache_hit_mask_index).to(
4274
+ device=torch.device("cpu")
4275
+ ),
4276
+ (
4277
+ prefetched_info.hash_zch_identities.index_select(
4278
+ dim=0, index=cache_hit_mask_index
4279
+ ).to(device=torch.device("cpu"))
4280
+ if prefetched_info.hash_zch_identities is not None
4281
+ else None
4282
+ ),
4283
+ (
4284
+ prefetched_info.hash_zch_runtime_meta.index_select(
4285
+ dim=0, index=cache_hit_mask_index
4286
+ ).to(device=torch.device("cpu"))
4287
+ if prefetched_info.hash_zch_runtime_meta is not None
4288
+ else None
4289
+ ),
4290
+ prefetched_info.linear_unique_indices_length.to(
4291
+ device=torch.device("cpu")
4292
+ ),
4293
+ False, # require_tensor_copy
4294
+ False, # blocking_tensor_copy
4295
+ )
4296
+
4297
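raw_embedding_stream above only streams rows that are actually resident in the LXU cache: it looks up the stored unique indices, keeps the positions whose cache location is not -1, and gathers ids and rows with index_select before handing them to the streamer. A pure-PyTorch sketch of that hit-filtering step (the tensors below are made-up stand-ins, not the FBGEMM lookup ops):

    import torch

    # Stand-in data: a location of -1 means "not found in the cache".
    updated_locations = torch.tensor([3, -1, 0, 7, -1])
    unique_ids = torch.tensor([10, 11, 12, 13, 14])
    cache_weights = torch.randn(8, 4)  # [cache_slots, embedding_dim]

    # Positions of the unique ids that were found in the cache.
    cache_hit_mask_index = updated_locations.not_equal(-1).nonzero().flatten()

    # Gather only the hit rows; these are the ids/rows that would be streamed out.
    hit_ids = unique_ids.index_select(0, cache_hit_mask_index)
    hit_locations = updated_locations.index_select(0, cache_hit_mask_index)
    hit_rows = cache_weights.index_select(0, hit_locations)
    print(hit_ids.tolist())  # [10, 12, 13]
    print(hit_rows.shape)    # torch.Size([3, 4])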
+ @staticmethod
4298
+ @torch.jit.ignore
4299
+ def _get_prefetched_info(
4300
+ linear_indices: torch.Tensor,
4301
+ linear_cache_indices_merged: torch.Tensor,
4302
+ total_cache_hash_size: int,
4303
+ hash_zch_identities: Optional[torch.Tensor],
4304
+ hash_zch_runtime_meta: Optional[torch.Tensor],
4305
+ max_indices_length: int,
4306
+ ) -> PrefetchedInfo:
4307
+ (
4308
+ linear_unique_cache_indices,
4309
+ linear_unique_cache_indices_length,
4310
+ linear_unique_cache_indices_cnt,
4311
+ linear_unique_cache_inverse_indices,
4312
+ ) = torch.ops.fbgemm.get_unique_indices_with_inverse(
4313
+ linear_cache_indices_merged,
4314
+ total_cache_hash_size,
4315
+ compute_count=True,
4316
+ compute_inverse_indices=True,
4317
+ )
4318
+ # Pure CPU op, no sync needed; cap the length so the indices cannot exceed the weights buffer
4319
+ max_len = min(
4320
+ max_indices_length,
4321
+ linear_unique_cache_indices.size(0),
4322
+ )
4323
+ if max_len < linear_unique_cache_indices.size(0):
4324
+ linear_unique_cache_indices_length.clamp_(max=max_len)
4325
+ # linear_unique_indices is the result after deduplication and sorting
4326
+ linear_unique_cache_indices = linear_unique_cache_indices.narrow(
4327
+ 0, 0, max_len
4328
+ )
4329
+ # Compute cumulative sum as indices for selecting unique elements to
4330
+ # map hash_zch_identities and hash_zch_runtime_meta to linear_unique_indices
4331
+ count_cum_sum = torch.ops.fbgemm.asynchronous_complete_cumsum(
4332
+ linear_unique_cache_indices_cnt
4333
+ )
4334
+ # count_cum_sum will be one more element than linear_unique_cache_indices_cnt
4335
+ count_cum_sum = count_cum_sum.narrow(0, 0, max_len)
4336
+ # clamp the uninitialized elements to avoid out of bound access
4337
+ # the uninitialized elements will be sliced out by linear_unique_cache_indices_length
4338
+ # directly using linear_unique_cache_indices_length requires a sync
4339
+ count_cum_sum.clamp_(min=0, max=linear_unique_cache_inverse_indices.size(0) - 1)
4340
+
4341
+ # Select indices corresponding to first occurrence of each unique element
4342
+ linear_unique_inverse_indices = (
4343
+ linear_unique_cache_inverse_indices.index_select(dim=0, index=count_cum_sum)
4344
+ )
4345
+ # same as above clamp
4346
+ linear_unique_inverse_indices.clamp_(min=0, max=linear_indices.size(0) - 1)
4347
+ linear_unique_indices = linear_indices.index_select(
4348
+ dim=0, index=linear_unique_inverse_indices
4349
+ )
4350
+ if hash_zch_identities is not None:
4351
+ # Map hash_zch_identities to unique indices
4352
+ hash_zch_identities = hash_zch_identities.index_select(
4353
+ dim=0, index=linear_unique_inverse_indices
4354
+ )
4355
+
4356
+ if hash_zch_runtime_meta is not None:
4357
+ # Map hash_zch_runtime_meta to unique indices
4358
+ hash_zch_runtime_meta = hash_zch_runtime_meta.index_select(
4359
+ dim=0, index=linear_unique_inverse_indices
4360
+ )
4361
+
4362
+ return PrefetchedInfo(
4363
+ linear_unique_indices,
4364
+ linear_unique_cache_indices,
4365
+ linear_unique_cache_indices_length,
4366
+ hash_zch_identities,
4367
+ hash_zch_runtime_meta,
4368
+ )
4369
+
4370
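_get_prefetched_info needs, for every unique cache index, the original position of one of its occurrences so that hash_zch_identities and hash_zch_runtime_meta can be re-aligned with the deduplicated ids; above this is done with get_unique_indices_with_inverse plus a cumulative sum over the per-value counts, clamped so no device sync is required. The same alignment can be expressed with stock PyTorch ops (a sketch of the idea only, not the kernels used here):

    import torch

    ids = torch.tensor([7, 3, 7, 9, 3, 3])
    meta = torch.tensor([[70], [30], [71], [90], [31], [32]])  # one metadata row per id

    unique_ids, inverse = torch.unique(ids, sorted=True, return_inverse=True)

    # For each unique id, find the position of its first occurrence in `ids`.
    positions = torch.arange(ids.size(0))
    first_occurrence = torch.full((unique_ids.size(0),), ids.size(0), dtype=torch.long)
    first_occurrence.scatter_reduce_(0, inverse, positions, reduce="amin")

    # Re-align the per-row metadata with the deduplicated ids.
    unique_meta = meta.index_select(0, first_occurrence)
    print(unique_ids.tolist())   # [3, 7, 9]
    print(unique_meta.tolist())  # [[30], [70], [90]]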
+ @torch.jit.ignore
4371
+ def _store_prefetched_tensors(
4372
+ self,
4373
+ indices: torch.Tensor,
4374
+ offsets: torch.Tensor,
4375
+ vbe_metadata: Optional[invokers.lookup_args.VBEMetadata],
4376
+ linear_cache_indices_merged: torch.Tensor,
4377
+ final_lxu_cache_locations: torch.Tensor,
4378
+ hash_zch_identities: Optional[torch.Tensor],
4379
+ hash_zch_runtime_meta: Optional[torch.Tensor],
4380
+ ) -> None:
4381
+ """
4382
+ NOTE: this needs to be a method with jit.ignore as the identities tensor is conditional.
4383
+ This function stores the prefetched tensors for the raw embedding streaming.
4384
+ """
4385
+ if not self.enable_raw_embedding_streaming:
4386
+ return
4387
+
4388
+ with record_function(
4389
+ "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
4390
+ ):
4391
+ found_in_cache_mask = final_lxu_cache_locations != -1
4392
+ # only process the indices that are found in the cache
4393
+ # this will filter out the indices from tables that don't have UVM_CACHE enabled
4394
+ linear_cache_indices_merged_masked = torch.where(
4395
+ found_in_cache_mask,
4396
+ linear_cache_indices_merged,
4397
+ self.total_cache_hash_size,
4398
+ )
4399
+ linearize_indices = torch.ops.fbgemm.linearize_cache_indices(
4400
+ self.hash_size_cumsum,
4401
+ indices,
4402
+ offsets,
4403
+ vbe_metadata.B_offsets if vbe_metadata is not None else None,
4404
+ vbe_metadata.max_B if vbe_metadata is not None else -1,
4405
+ )
4406
+ # -1 indices are ignored in raw_embedding_streamer.
4407
+ linearize_indices_masked = torch.where(
4408
+ found_in_cache_mask,
4409
+ linearize_indices,
4410
+ -1,
4411
+ )
4412
+ # Process hash_zch_identities using helper function
4413
+ prefetched_info = self._get_prefetched_info(
4414
+ linearize_indices_masked,
4415
+ linear_cache_indices_merged_masked,
4416
+ self.total_cache_hash_size,
4417
+ hash_zch_identities,
4418
+ hash_zch_runtime_meta,
4419
+ self.lxu_cache_weights.size(0),
4420
+ )
4421
+
4422
+ self.prefetched_info_list.append(prefetched_info)
4423
+
4424
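_store_prefetched_tensors above keeps only ids that were found in the cache by rewriting misses to sentinel values: total_cache_hash_size for the linearized cache indices (so they sort to the end and get sliced away) and -1 for the raw linearized indices (which the streamer ignores). A small torch.where sketch of that masking with illustrative values:

    import torch

    total_cache_hash_size = 1000
    lxu_cache_locations = torch.tensor([5, -1, 2, -1])     # -1 == cache miss
    linear_cache_indices = torch.tensor([12, 47, 3, 980])
    linear_indices = torch.tensor([112, 147, 103, 980])

    found_in_cache = lxu_cache_locations != -1

    # Misses get the sentinel bucket so later dedup/sort pushes them past the valid range ...
    cache_indices_masked = torch.where(found_in_cache, linear_cache_indices, total_cache_hash_size)
    # ... and the corresponding raw indices become -1, which the streamer skips.
    linear_indices_masked = torch.where(found_in_cache, linear_indices, -1)

    print(cache_indices_masked.tolist())   # [12, 1000, 3, 1000]
    print(linear_indices_masked.tolist())  # [112, -1, 103, -1]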
+ @torch.jit.ignore
4425
+ def __report_input_params_factory(
4426
+ self,
4427
+ ) -> Optional[Callable[..., None]]:
4428
+ """
4429
+ This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.
4430
+
4431
+ If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that:
4432
+ - Reports input parameters (TBEDataConfig).
4433
+ - Writes the output as a JSON file.
4434
+
4435
+ If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, no reporter is created and None is returned.
4436
+ """
4437
+ try:
4438
+ if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
4439
+ from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
4440
+
4441
+ reporter = TBEBenchmarkParamsReporter.create()
4442
+ return reporter.report_stats
4443
+ except Exception:
4444
+ return None
4445
+
4446
+ return None
4447
+
3674
4448
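__report_input_params_factory above resolves once, at construction time, to either a bound reporting callable or None, so per-iteration code only has to do a single None check. A generic sketch of that feature-gated factory pattern (feature_enabled and the report body are placeholders, not FBGEMM APIs):

    from typing import Callable, Optional

    def make_reporter(feature_enabled: bool) -> Optional[Callable[[str], None]]:
        # Build a real callable only when the feature gate is on; otherwise return
        # None so callers can skip reporting with a single `is not None` check.
        if not feature_enabled:
            return None

        def report(payload: str) -> None:
            print(f"report: {payload}")

        return report

    reporter = make_reporter(feature_enabled=True)
    if reporter is not None:
        reporter("tbe input params")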
 
3675
4449
  class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3676
4450
  """
@@ -3684,12 +4458,12 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3684
4458
  max_D: int
3685
4459
  hash_size_cumsum: Tensor
3686
4460
  total_hash_size_bits: int
3687
- embedding_specs: List[Tuple[int, int]]
4461
+ embedding_specs: list[tuple[int, int]]
3688
4462
 
3689
4463
  def __init__(
3690
4464
  self,
3691
- embedding_specs: List[Tuple[int, int]], # tuple of (rows, dims)
3692
- feature_table_map: Optional[List[int]] = None, # [T]
4465
+ embedding_specs: list[tuple[int, int]], # tuple of (rows, dims)
4466
+ feature_table_map: Optional[list[int]] = None, # [T]
3693
4467
  weights_precision: SparseType = SparseType.FP32,
3694
4468
  pooling_mode: PoolingMode = PoolingMode.SUM,
3695
4469
  use_cpu: bool = False,
@@ -3732,7 +4506,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3732
4506
  )
3733
4507
 
3734
4508
  self.embedding_specs = embedding_specs
3735
- (rows, dims) = zip(*embedding_specs)
4509
+ rows, dims = zip(*embedding_specs)
3736
4510
  T_ = len(self.embedding_specs)
3737
4511
  assert T_ > 0
3738
4512
 
@@ -3802,7 +4576,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3802
4576
  row for (row, _) in embedding_specs[:t]
3803
4577
  )
3804
4578
 
3805
- self.weights_physical_offsets: List[int] = weights_offsets
4579
+ self.weights_physical_offsets: list[int] = weights_offsets
3806
4580
  weights_offsets = [weights_offsets[t] for t in feature_table_map]
3807
4581
  self.register_buffer(
3808
4582
  "weights_offsets",
@@ -3829,7 +4603,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3829
4603
  def _generate_vbe_metadata(
3830
4604
  self,
3831
4605
  offsets: Tensor,
3832
- batch_size_per_feature_per_rank: Optional[List[List[int]]],
4606
+ batch_size_per_feature_per_rank: Optional[list[list[int]]],
3833
4607
  ) -> invokers.lookup_args.VBEMetadata:
3834
4608
  # Blocking D2H copy, but only runs at first call
3835
4609
  self.feature_dims = self.feature_dims.cpu()
@@ -3847,7 +4621,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3847
4621
  offsets: Tensor,
3848
4622
  per_sample_weights: Optional[Tensor] = None,
3849
4623
  feature_requires_grad: Optional[Tensor] = None,
3850
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
4624
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
3851
4625
  ) -> Tensor:
3852
4626
  # Generate VBE metadata
3853
4627
  vbe_metadata = self._generate_vbe_metadata(
@@ -3886,7 +4660,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
3886
4660
  )
3887
4661
 
3888
4662
  @torch.jit.export
3889
- def split_embedding_weights(self) -> List[Tensor]:
4663
+ def split_embedding_weights(self) -> list[Tensor]:
3890
4664
  """
3891
4665
  Returns a list of weights, split by table
3892
4666
  """