fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -14,13 +14,13 @@ import itertools
14
14
  import logging
15
15
  import math
16
16
  import os
17
- import tempfile
18
17
  import threading
19
18
  import time
20
19
  from functools import cached_property
21
- from math import ceil, floor, log2
22
- from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
20
+ from math import floor, log2
21
+ from typing import Any, Callable, ClassVar, Optional, Union
23
22
  import torch # usort:skip
23
+ import weakref
24
24
 
25
25
  # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
26
26
  import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -35,6 +35,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
35
35
  BoundsCheckMode,
36
36
  CacheAlgorithm,
37
37
  EmbeddingLocation,
38
+ EvictionPolicy,
38
39
  get_bounds_check_version_for_platform,
39
40
  KVZCHParams,
40
41
  PoolingMode,
@@ -49,10 +50,12 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
49
50
  WeightDecayMode,
50
51
  )
51
52
  from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
53
+ check_allocated_vbe_output,
52
54
  generate_vbe_metadata,
53
55
  is_torchdynamo_compiling,
54
56
  )
55
57
  from torch import distributed as dist, nn, Tensor # usort:skip
58
+ import sys
56
59
  from dataclasses import dataclass
57
60
 
58
61
  from torch.autograd.profiler import record_function
@@ -76,10 +79,10 @@ class IterData:
76
79
 
77
80
  @dataclass
78
81
  class KVZCHCachedData:
79
- cached_optimizer_states_per_table: List[List[torch.Tensor]]
80
- cached_weight_tensor_per_table: List[torch.Tensor]
81
- cached_id_tensor_per_table: List[torch.Tensor]
82
- cached_bucket_splits: List[torch.Tensor]
82
+ cached_optimizer_states_per_table: list[list[torch.Tensor]]
83
+ cached_weight_tensor_per_table: list[torch.Tensor]
84
+ cached_id_tensor_per_table: list[torch.Tensor]
85
+ cached_bucket_splits: list[torch.Tensor]
83
86
 
84
87
 
85
88
  class SSDTableBatchedEmbeddingBags(nn.Module):
@@ -100,13 +103,18 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
100
103
  weights_offsets: Tensor
101
104
  _local_instance_index: int = -1
102
105
  res_params: RESParams
103
- table_names: List[str]
106
+ table_names: list[str]
107
+ _all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet()
108
+ _first_instance_ref: ClassVar[weakref.ref] = None
109
+ _eviction_triggered: ClassVar[bool] = False
104
110
 
105
111
  def __init__(
106
112
  self,
107
- embedding_specs: List[Tuple[int, int]], # tuple of (rows, dims)
108
- feature_table_map: Optional[List[int]], # [T]
113
+ embedding_specs: list[tuple[int, int]], # tuple of (rows, dims)
114
+ feature_table_map: Optional[list[int]], # [T]
109
115
  cache_sets: int,
116
+ # A comma-separated string, e.g. "/data00_nvidia0,/data01_nvidia0/", db shards
117
+ # will be placed in these paths round-robin.
110
118
  ssd_storage_directory: str,
111
119
  ssd_rocksdb_shards: int = 1,
112
120
  ssd_memtable_flush_period: int = -1,
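The new comment on `ssd_storage_directory` above describes it as a comma-separated list of paths, with the db shards placed across those paths round-robin. The actual placement is done by the backend; the snippet below is only a minimal sketch of that mapping, with `shard_paths` being a hypothetical helper rather than part of this package.

    def shard_paths(ssd_storage_directory: str, ssd_rocksdb_shards: int) -> list[str]:
        # Split the comma-separated directory string and assign shards round-robin
        dirs = [d for d in ssd_storage_directory.split(",") if d]
        return [dirs[shard % len(dirs)] for shard in range(ssd_rocksdb_shards)]

    # shard_paths("/data00_nvidia0,/data01_nvidia0", 4)
    # -> ["/data00_nvidia0", "/data01_nvidia0", "/data00_nvidia0", "/data01_nvidia0"]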
@@ -146,13 +154,16 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
146
154
  pooling_mode: PoolingMode = PoolingMode.SUM,
147
155
  bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
148
156
  # Parameter Server Configs
149
- ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
157
+ ps_hosts: Optional[tuple[tuple[str, int]]] = None,
150
158
  ps_max_key_per_request: Optional[int] = None,
151
159
  ps_client_thread_num: Optional[int] = None,
152
160
  ps_max_local_index_length: Optional[int] = None,
153
161
  tbe_unique_id: int = -1,
154
- # in local test we need to use the pass in path for rocksdb creation
155
- # in production we need to do it inside SSD mount path which will ignores the passed in path
162
+ # If set to True, will use `ssd_storage_directory` as the ssd paths.
163
+ # If set to False, will use the default ssd paths.
164
+ # In local test we need to use the passed-in path for rocksdb creation
165
+ # In production we could either use the default ssd mount points or explicitly specify ssd
166
+ # mount points using `ssd_storage_directory`.
156
167
  use_passed_in_path: int = True,
157
168
  gather_ssd_cache_stats: Optional[bool] = False,
158
169
  stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
@@ -172,14 +183,18 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
172
183
  enable_raw_embedding_streaming: bool = False, # whether enable raw embedding streaming
173
184
  res_params: Optional[RESParams] = None, # raw embedding streaming sharding info
174
185
  flushing_block_size: int = 2_000_000_000, # 2GB
175
- table_names: Optional[List[str]] = None,
176
- optimizer_state_dtypes: Dict[str, SparseType] = {}, # noqa: B006
186
+ table_names: Optional[list[str]] = None,
187
+ use_rowwise_bias_correction: bool = False, # For Adam use
188
+ optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006
189
+ pg: Optional[dist.ProcessGroup] = None,
177
190
  ) -> None:
178
191
  super(SSDTableBatchedEmbeddingBags, self).__init__()
179
192
 
180
193
  # Set the optimizer
181
194
  assert optimizer in (
182
195
  OptimType.EXACT_ROWWISE_ADAGRAD,
196
+ OptimType.PARTIAL_ROWWISE_ADAM,
197
+ OptimType.ADAM,
183
198
  ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
184
199
  self.optimizer = optimizer
185
200
 
@@ -187,15 +202,28 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
187
202
  assert weights_precision in (SparseType.FP32, SparseType.FP16)
188
203
  self.weights_precision = weights_precision
189
204
  self.output_dtype: int = output_dtype.as_int()
190
- self.optimizer_state_dtypes: Dict[str, SparseType] = optimizer_state_dtypes
205
+
206
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
207
+ # Adagrad currently only supports FP32 for momentum1
208
+ self.optimizer_state_dtypes: dict[str, SparseType] = {
209
+ "momentum1": SparseType.FP32,
210
+ }
211
+ else:
212
+ self.optimizer_state_dtypes: dict[str, SparseType] = optimizer_state_dtypes
191
213
 
192
214
  # Zero collision TBE configurations
193
215
  self.kv_zch_params = kv_zch_params
194
216
  self.backend_type = backend_type
195
217
  self.enable_optimizer_offloading: bool = False
196
218
  self.backend_return_whole_row: bool = False
219
+ self._embedding_cache_mode: bool = False
220
+ self.load_ckpt_without_opt: bool = False
197
221
  if self.kv_zch_params:
198
222
  self.kv_zch_params.validate()
223
+ self.load_ckpt_without_opt = (
224
+ # pyre-ignore [16]
225
+ self.kv_zch_params.load_ckpt_without_opt
226
+ )
199
227
  self.enable_optimizer_offloading = (
200
228
  # pyre-ignore [16]
201
229
  self.kv_zch_params.enable_optimizer_offloading
@@ -214,12 +242,43 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
214
242
  logging.info(
215
243
  "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
216
244
  )
245
+ # pyre-ignore [16]
246
+ self._embedding_cache_mode = self.kv_zch_params.embedding_cache_mode
247
+ if self._embedding_cache_mode:
248
+ logging.info("KVZCH is in embedding_cache_mode")
249
+ assert self.optimizer in [
250
+ OptimType.EXACT_ROWWISE_ADAGRAD
251
+ ], f"only EXACT_ROWWISE_ADAGRAD supports embedding cache mode, but got {self.optimizer}"
252
+ if self.load_ckpt_without_opt:
253
+ if (
254
+ # pyre-ignore [16]
255
+ self.kv_zch_params.optimizer_type_for_st
256
+ == OptimType.PARTIAL_ROWWISE_ADAM.value
257
+ ):
258
+ self.optimizer = OptimType.PARTIAL_ROWWISE_ADAM
259
+ logging.info(
260
+ f"Override optimizer type with {self.optimizer=} for st publish"
261
+ )
262
+ if (
263
+ # pyre-ignore [16]
264
+ self.kv_zch_params.optimizer_state_dtypes_for_st
265
+ is not None
266
+ ):
267
+ optimizer_state_dtypes = {}
268
+ for k, v in dict(
269
+ self.kv_zch_params.optimizer_state_dtypes_for_st
270
+ ).items():
271
+ optimizer_state_dtypes[k] = SparseType.from_int(v)
272
+ self.optimizer_state_dtypes = optimizer_state_dtypes
273
+ logging.info(
274
+ f"Override optimizer_state_dtypes with {self.optimizer_state_dtypes=} for st publish"
275
+ )
217
276
 
218
277
  self.pooling_mode = pooling_mode
219
278
  self.bounds_check_mode_int: int = bounds_check_mode.value
220
279
  self.embedding_specs = embedding_specs
221
280
  self.table_names = table_names if table_names is not None else []
222
- (rows, dims) = zip(*embedding_specs)
281
+ rows, dims = zip(*embedding_specs)
223
282
  T_ = len(self.embedding_specs)
224
283
  assert T_ > 0
225
284
  # pyre-fixme[8]: Attribute has type `device`; used as `int`.
@@ -238,7 +297,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
238
297
  f"get env {self.res_params.res_server_port=}, at rank {dist.get_rank()}, with {self.res_params=}"
239
298
  )
240
299
 
241
- self.feature_table_map: List[int] = (
300
+ self.feature_table_map: list[int] = (
242
301
  feature_table_map if feature_table_map is not None else list(range(T_))
243
302
  )
244
303
  T = len(self.feature_table_map)
@@ -318,7 +377,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
318
377
  torch.tensor(dims, device="cpu", dtype=torch.int64),
319
378
  )
320
379
 
321
- (info_B_num_bits_, info_B_mask_) = torch.ops.fbgemm.get_infos_metadata(
380
+ info_B_num_bits_, info_B_mask_ = torch.ops.fbgemm.get_infos_metadata(
322
381
  self.D_offsets, # unused tensor
323
382
  1, # max_B
324
383
  T, # T
@@ -514,11 +573,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
514
573
  self.record_function_via_dummy_profile_factory(use_dummy_profile)
515
574
  )
516
575
 
517
- os.makedirs(ssd_storage_directory, exist_ok=True)
576
+ if use_passed_in_path:
577
+ ssd_dir_list = ssd_storage_directory.split(",")
578
+ for ssd_dir in ssd_dir_list:
579
+ os.makedirs(ssd_dir, exist_ok=True)
518
580
 
519
- ssd_directory = tempfile.mkdtemp(
520
- prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
521
- )
581
+ ssd_directory = ssd_storage_directory
522
582
  # logging.info("DEBUG: weights_precision {}".format(weights_precision))
523
583
 
524
584
  """
@@ -538,10 +598,16 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
538
598
  """
539
599
  self._cached_kvzch_data: Optional[KVZCHCachedData] = None
540
600
  # initial embedding rows on this rank per table, this is used for loading checkpoint
541
- self.local_weight_counts: List[int] = [0] * T_
601
+ self.local_weight_counts: list[int] = [0] * T_
602
+ # ground-truth global ids on this rank per table; this is used for loading checkpoints
603
+ self.global_id_per_rank: list[torch.Tensor] = [torch.zeros(0)] * T_
542
604
  # loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend
543
605
  self.load_state_dict: bool = False
544
606
 
607
+ SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self)
608
+ if SSDTableBatchedEmbeddingBags._first_instance_ref is None:
609
+ SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self)
610
+
545
611
  # create tbe unique id using rank index | local tbe idx
546
612
  if tbe_unique_id == -1:
547
613
  SSDTableBatchedEmbeddingBags._local_instance_index += 1
@@ -559,6 +625,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
559
625
  self.tbe_unique_id = tbe_unique_id
560
626
  self.l2_cache_size = l2_cache_size
561
627
  logging.info(f"tbe_unique_id: {tbe_unique_id}")
628
+ self.enable_free_mem_trigger_eviction: bool = False
562
629
  if self.backend_type == BackendType.SSD:
563
630
  logging.info(
564
631
  f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
@@ -614,6 +681,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
614
681
  else None
615
682
  ),
616
683
  flushing_block_size,
684
+ self._embedding_cache_mode, # disable_random_init
617
685
  )
618
686
  if self.bulk_init_chunk_size > 0:
619
687
  self.ssd_uniform_init_lower: float = ssd_uniform_init_lower
@@ -662,18 +730,41 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
662
730
  if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
663
731
  else self.l2_cache_size
664
732
  )
733
+ kv_zch_params = self.kv_zch_params
734
+ eviction_policy = self.kv_zch_params.eviction_policy
735
+ if eviction_policy.eviction_trigger_mode == 5:
736
+ # If trigger mode is free_mem(5), populate config
737
+ self.set_free_mem_eviction_trigger_config(eviction_policy)
738
+
739
+ enable_eviction_for_feature_score_eviction_policy = ( # pytorch api in c++ doesn't support vector<bool>, convert to int here, 0: no eviction, 1: eviction
740
+ [
741
+ int(x)
742
+ for x in eviction_policy.enable_eviction_for_feature_score_eviction_policy
743
+ ]
744
+ if eviction_policy.enable_eviction_for_feature_score_eviction_policy
745
+ is not None
746
+ else None
747
+ )
748
+ # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
665
749
  eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
666
- self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
667
- self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
668
- self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
750
+ eviction_policy.eviction_trigger_mode, # eviction trigger mode, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
751
+ eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold, 5: feature score
752
+ eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
669
753
  eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
670
- self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
671
- self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is feature score
672
- self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is feature score
673
- self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
754
+ eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
755
+ eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
756
+ eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
757
+ eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
758
+ eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
759
+ eviction_policy.training_id_keep_count, # training_id_keep_count for each table
760
+ enable_eviction_for_feature_score_eviction_policy, # no eviction setting for feature score eviction policy
761
+ eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
674
762
  table_dims.tolist() if table_dims is not None else None,
675
- self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s,
676
- self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s,
763
+ eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
764
+ eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
765
+ eviction_policy.interval_for_insufficient_eviction_s,
766
+ eviction_policy.interval_for_sufficient_eviction_s,
767
+ eviction_policy.interval_for_feature_statistics_decay_s,
677
768
  )
678
769
  self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
679
770
  self.cache_row_dim,
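The `FeatureEvictConfig` arguments above are plain integers whose meanings are given only in the inline comments. Purely as a reading aid (these enums are not an API of the package), the codes spelled out in those comments can be written as:

    from enum import IntEnum

    class EvictionTriggerMode(IntEnum):  # per the inline comments above
        DISABLED = 0
        ITERATION = 1
        MEM_UTIL = 2
        MANUAL = 3
        ID_COUNT = 4
        FREE_MEM = 5

    class EvictionStrategy(IntEnum):  # per the inline comments above
        TIMESTAMP = 0
        COUNTER = 1
        COUNTER_AND_TIMESTAMP = 2
        FEATURE_L2_NORM = 3
        TIMESTAMP_THRESHOLD = 4
        FEATURE_SCORE = 5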
@@ -690,16 +781,20 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
690
781
  else None
691
782
  ), # hash_size_cumsum
692
783
  self.backend_return_whole_row, # backend_return_whole_row
784
+ False, # enable_async_update
785
+ self._embedding_cache_mode, # disable_random_init
693
786
  )
694
787
  else:
695
788
  raise AssertionError(f"Invalid backend type {self.backend_type}")
696
789
 
697
790
  # pyre-fixme[20]: Argument `self` expected.
698
- (low_priority, high_priority) = torch.cuda.Stream.priority_range()
791
+ low_priority, high_priority = torch.cuda.Stream.priority_range()
699
792
  # GPU stream for SSD cache eviction
700
793
  self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
701
- # GPU stream for SSD memory copy
794
+ # GPU stream for SSD memory copy (also reused for feature score D2H)
702
795
  self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
796
+ # GPU stream for async metadata operation
797
+ self.feature_score_stream = torch.cuda.Stream(priority=low_priority)
703
798
 
704
799
  # SSD get completion event
705
800
  self.ssd_event_get = torch.cuda.Event()
@@ -711,6 +806,17 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
711
806
  self.ssd_event_backward = torch.cuda.Event()
712
807
  # SSD get's input copy completion event
713
808
  self.ssd_event_get_inputs_cpy = torch.cuda.Event()
809
+ if self._embedding_cache_mode:
810
+ # Direct write embedding completion event
811
+ self.direct_write_l1_complete_event: torch.cuda.streams.Event = (
812
+ torch.cuda.Event()
813
+ )
814
+ self.direct_write_sp_complete_event: torch.cuda.streams.Event = (
815
+ torch.cuda.Event()
816
+ )
817
+ # Prefetch operation completion event
818
+ self.prefetch_complete_event = torch.cuda.Event()
819
+
714
820
  if self.prefetch_pipeline:
715
821
  # SSD scratch pad index queue insert completion event
716
822
  self.ssd_event_sp_idxq_insert: torch.cuda.streams.Event = torch.cuda.Event()
@@ -771,22 +877,22 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
771
877
  )
772
878
 
773
879
  # (Indices, Count)
774
- self.prefetched_info: List[Tuple[Tensor, Tensor]] = []
880
+ self.prefetched_info: list[tuple[Tensor, Tensor]] = []
775
881
 
776
- self.timesteps_prefetched: List[int] = []
882
+ self.timesteps_prefetched: list[int] = []
777
883
  # TODO: add type annotation
778
884
  # pyre-fixme[4]: Attribute must be annotated.
779
885
  self.ssd_prefetch_data = []
780
886
 
781
887
  # Scratch pad eviction data queue
782
- self.ssd_scratch_pad_eviction_data: List[
783
- Tuple[Tensor, Tensor, Tensor, bool]
888
+ self.ssd_scratch_pad_eviction_data: list[
889
+ tuple[Tensor, Tensor, Tensor, bool]
784
890
  ] = []
785
- self.ssd_location_update_data: List[Tuple[Tensor, Tensor]] = []
891
+ self.ssd_location_update_data: list[tuple[Tensor, Tensor]] = []
786
892
 
787
893
  if self.prefetch_pipeline:
788
894
  # Scratch pad value queue
789
- self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = []
895
+ self.ssd_scratch_pads: list[tuple[Tensor, Tensor, Tensor]] = []
790
896
 
791
897
  # pyre-ignore[4]
792
898
  # Scratch pad index queue
@@ -835,7 +941,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
835
941
  weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
836
942
  lower_bound=cowclip_regularization.lower_bound,
837
943
  regularization_mode=weight_decay_mode.value,
838
- use_rowwise_bias_correction=False, # Unused, this is used in TBE's Adam
944
+ use_rowwise_bias_correction=use_rowwise_bias_correction, # Used in Adam optimizer
839
945
  )
840
946
 
841
947
  table_embedding_dtype = weights_precision.as_dtype()
@@ -888,7 +994,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
888
994
  self.ssd_cache_stats_size = 6
889
995
  # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
890
996
  # 4: N_conflict_unique_misses, 5: N_conflict_misses
891
- self.last_reported_ssd_stats: List[float] = []
997
+ self.last_reported_ssd_stats: list[float] = []
892
998
  self.last_reported_step = 0
893
999
 
894
1000
  self.register_buffer(
@@ -919,7 +1025,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
919
1025
  self.prefetch_parallel_stream_cnt: int = 2
920
1026
  # tuple of iteration, prefetch parallel stream cnt, reported duration
921
1027
  # since there are 2 stream in parallel in prefetch, we want to count the longest one
922
- self.prefetch_duration_us: Tuple[int, int, float] = (
1028
+ self.prefetch_duration_us: tuple[int, int, float] = (
923
1029
  -1,
924
1030
  self.prefetch_parallel_stream_cnt,
925
1031
  0,
@@ -945,6 +1051,20 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
945
1051
  self.dram_kv_allocated_bytes_stats_name: str = (
946
1052
  f"dram_kv.mem.tbe_id{tbe_unique_id}.allocated_bytes"
947
1053
  )
1054
+ self.dram_kv_mem_num_rows_stats_name: str = (
1055
+ f"dram_kv.mem.tbe_id{tbe_unique_id}.num_rows"
1056
+ )
1057
+
1058
+ self.eviction_sum_evicted_counts_stats_name: str = (
1059
+ f"eviction.tbe_id.{tbe_unique_id}.sum_evicted_counts"
1060
+ )
1061
+ self.eviction_sum_processed_counts_stats_name: str = (
1062
+ f"eviction.tbe_id.{tbe_unique_id}.sum_processed_counts"
1063
+ )
1064
+ self.eviction_evict_rate_stats_name: str = (
1065
+ f"eviction.tbe_id.{tbe_unique_id}.evict_rate"
1066
+ )
1067
+
948
1068
  if self.stats_reporter:
949
1069
  self.ssd_prefetch_read_timer = AsyncSeriesTimer(
950
1070
  functools.partial(
@@ -972,9 +1092,41 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
972
1092
  self.stats_reporter.register_stats(
973
1093
  self.dram_kv_actual_used_chunk_bytes_stats_name
974
1094
  )
1095
+ self.stats_reporter.register_stats(self.dram_kv_mem_num_rows_stats_name)
1096
+ self.stats_reporter.register_stats(
1097
+ self.eviction_sum_evicted_counts_stats_name
1098
+ )
1099
+ self.stats_reporter.register_stats(
1100
+ self.eviction_sum_processed_counts_stats_name
1101
+ )
1102
+ self.stats_reporter.register_stats(self.eviction_evict_rate_stats_name)
1103
+ for t in self.feature_table_map:
1104
+ self.stats_reporter.register_stats(
1105
+ f"eviction.feature_table.{t}.evicted_counts"
1106
+ )
1107
+ self.stats_reporter.register_stats(
1108
+ f"eviction.feature_table.{t}.processed_counts"
1109
+ )
1110
+ self.stats_reporter.register_stats(
1111
+ f"eviction.feature_table.{t}.evict_rate"
1112
+ )
1113
+ self.stats_reporter.register_stats(
1114
+ "eviction.feature_table.full_duration_ms"
1115
+ )
1116
+ self.stats_reporter.register_stats(
1117
+ "eviction.feature_table.exec_duration_ms"
1118
+ )
1119
+ self.stats_reporter.register_stats(
1120
+ "eviction.feature_table.dry_run_exec_duration_ms"
1121
+ )
1122
+ self.stats_reporter.register_stats(
1123
+ "eviction.feature_table.exec_div_full_duration_rate"
1124
+ )
975
1125
 
976
1126
  self.bounds_check_version: int = get_bounds_check_version_for_platform()
977
1127
 
1128
+ self._pg = pg
1129
+
978
1130
  @cached_property
979
1131
  def cache_row_dim(self) -> int:
980
1132
  """
@@ -982,7 +1134,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
982
1134
  padding to the nearest 4 elements and the optimizer state appended to
983
1135
  the back of the row
984
1136
  """
985
- if self.enable_optimizer_offloading:
1137
+
1138
+ # For st publish, we only need to load weight for publishing and bulk eval
1139
+ if self.enable_optimizer_offloading and not self.load_ckpt_without_opt:
986
1140
  return self.max_D + pad4(
987
1141
  # Compute the number of elements of cache_dtype needed to store
988
1142
  # the optimizer state
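The change above makes `cache_row_dim` skip the appended optimizer state when `load_ckpt_without_opt` is set. A minimal sketch of the sizing logic, assuming `pad4` rounds an element count up to the nearest multiple of four (the element count itself comes from the optimizer):

    def pad4(n: int) -> int:
        # Round up to the nearest multiple of 4 elements (assumed behavior)
        return (n + 3) // 4 * 4

    def cache_row_dim(max_D: int, opt_state_elems: int,
                      offload: bool, load_ckpt_without_opt: bool) -> int:
        if offload and not load_ckpt_without_opt:
            # Weights plus the padded optimizer state appended to the row
            return max_D + pad4(opt_state_elems)
        return max_D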
@@ -1182,10 +1336,10 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1182
1336
  self,
1183
1337
  split: SplitState,
1184
1338
  prefix: str,
1185
- dtype: Type[torch.dtype],
1339
+ dtype: type[torch.dtype],
1186
1340
  enforce_hbm: bool = False,
1187
1341
  make_dev_param: bool = False,
1188
- dev_reshape: Optional[Tuple[int, ...]] = None,
1342
+ dev_reshape: Optional[tuple[int, ...]] = None,
1189
1343
  ) -> None:
1190
1344
  apply_split_helper(
1191
1345
  self.register_buffer,
@@ -1208,11 +1362,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1208
1362
 
1209
1363
  def to_pinned_cpu_on_stream_wait_on_another_stream(
1210
1364
  self,
1211
- tensors: List[Tensor],
1365
+ tensors: list[Tensor],
1212
1366
  stream: torch.cuda.Stream,
1213
1367
  stream_to_wait_on: torch.cuda.Stream,
1214
1368
  post_event: Optional[torch.cuda.Event] = None,
1215
- ) -> List[Tensor]:
1369
+ ) -> list[Tensor]:
1216
1370
  """
1217
1371
  Transfer input tensors from GPU to CPU using a pinned host
1218
1372
  buffer. The transfer is carried out on the given stream
@@ -1274,6 +1428,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1274
1428
  Returns:
1275
1429
  None
1276
1430
  """
1431
+ if not self.training: # if not training, freeze the embedding
1432
+ return
1277
1433
  with record_function(f"## ssd_evict_{name} ##"):
1278
1434
  with torch.cuda.stream(stream):
1279
1435
  if pre_event is not None:
@@ -1286,7 +1442,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1286
1442
  self.record_function_via_dummy_profile(
1287
1443
  f"## ssd_set_{name} ##",
1288
1444
  self.ssd_db.set_cuda,
1289
- indices_cpu.cpu(),
1445
+ indices_cpu,
1290
1446
  rows_cpu,
1291
1447
  actions_count_cpu,
1292
1448
  self.timestep,
@@ -1450,7 +1606,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1450
1606
  def _update_cache_counter_and_pointers(
1451
1607
  self,
1452
1608
  module: nn.Module,
1453
- grad_input: Union[Tuple[Tensor, ...], Tensor],
1609
+ grad_input: Union[tuple[Tensor, ...], Tensor],
1454
1610
  ) -> None:
1455
1611
  """
1456
1612
  Update cache line locking counter and pointers before backward
@@ -1535,9 +1691,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1535
1691
  if len(self.ssd_location_update_data) == 0:
1536
1692
  return
1537
1693
 
1538
- (sp_curr_next_map, inserted_rows_next) = self.ssd_location_update_data.pop(
1539
- 0
1540
- )
1694
+ sp_curr_next_map, inserted_rows_next = self.ssd_location_update_data.pop(0)
1541
1695
 
1542
1696
  # Update pointers
1543
1697
  torch.ops.fbgemm.ssd_update_row_addrs(
@@ -1552,12 +1706,63 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1552
1706
  unique_indices_length_curr=curr_data.actions_count_gpu,
1553
1707
  )
1554
1708
 
1709
+ def _update_feature_score_metadata(
1710
+ self,
1711
+ linear_cache_indices: Tensor,
1712
+ weights: Tensor,
1713
+ d2h_stream: torch.cuda.Stream,
1714
+ write_stream: torch.cuda.Stream,
1715
+ pre_event_for_write: torch.cuda.Event,
1716
+ post_event: Optional[torch.cuda.Event] = None,
1717
+ ) -> None:
1718
+ """
1719
+ Write feature score metadata to DRAM
1720
+
1721
+ This method performs D2H copy on d2h_stream, then writes to DRAM on write_stream.
1722
+ The caller is responsible for ensuring d2h_stream doesn't compete with other D2H operations.
1723
+
1724
+ Args:
1725
+ linear_cache_indices: GPU tensor containing cache indices
1726
+ weights: GPU tensor containing feature scores
1727
+ d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
1728
+ write_stream: Stream for metadata write operation
1729
+ pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
1730
+ post_event: Event to record when the operation is done
1731
+ """
1732
+ # Start D2H copy on d2h_stream
1733
+ with torch.cuda.stream(d2h_stream):
1734
+ # Record streams to prevent premature deallocation
1735
+ linear_cache_indices.record_stream(d2h_stream)
1736
+ weights.record_stream(d2h_stream)
1737
+ # Do the D2H copy
1738
+ linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
1739
+ score_weights_cpu = self.to_pinned_cpu(weights)
1740
+
1741
+ # Write feature score metadata to DRAM
1742
+ with record_function("## ssd_write_feature_score_metadata ##"):
1743
+ with torch.cuda.stream(write_stream):
1744
+ write_stream.wait_event(pre_event_for_write)
1745
+ write_stream.wait_stream(d2h_stream)
1746
+ self.record_function_via_dummy_profile(
1747
+ "## ssd_write_feature_score_metadata ##",
1748
+ self.ssd_db.set_feature_score_metadata_cuda,
1749
+ linear_cache_indices_cpu,
1750
+ torch.tensor(
1751
+ [score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
1752
+ ),
1753
+ score_weights_cpu,
1754
+ )
1755
+
1756
+ if post_event is not None:
1757
+ write_stream.record_event(post_event)
1758
+
1555
1759
  def prefetch(
1556
1760
  self,
1557
1761
  indices: Tensor,
1558
1762
  offsets: Tensor,
1763
+ weights: Optional[Tensor] = None, # todo: need to update caller
1559
1764
  forward_stream: Optional[torch.cuda.Stream] = None,
1560
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
1765
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
1561
1766
  ) -> None:
1562
1767
  if self.prefetch_stream is None and forward_stream is not None:
1563
1768
  # Set the prefetch stream to the current stream
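The new `_update_feature_score_metadata` above splits the work across two streams: the D2H copy runs on `d2h_stream`, and the DRAM write on `write_stream` starts only after both a pre-event and the copy have completed. A stripped-down sketch of that ordering, with `write_fn` standing in for the backend call:

    import torch

    def ordered_metadata_write(indices_gpu, scores_gpu, d2h_stream, write_stream,
                               pre_event, write_fn):
        with torch.cuda.stream(d2h_stream):
            # Keep the GPU tensors alive for the duration of the async copies
            indices_gpu.record_stream(d2h_stream)
            scores_gpu.record_stream(d2h_stream)
            indices_cpu = indices_gpu.to("cpu", non_blocking=True)
            scores_cpu = scores_gpu.to("cpu", non_blocking=True)
        with torch.cuda.stream(write_stream):
            write_stream.wait_event(pre_event)    # e.g. wait for cache eviction
            write_stream.wait_stream(d2h_stream)  # wait for the copies to land
            write_fn(indices_cpu, scores_cpu)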
@@ -1581,6 +1786,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1581
1786
  self._prefetch(
1582
1787
  indices,
1583
1788
  offsets,
1789
+ weights,
1584
1790
  vbe_metadata,
1585
1791
  forward_stream,
1586
1792
  )
@@ -1589,11 +1795,17 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1589
1795
  self,
1590
1796
  indices: Tensor,
1591
1797
  offsets: Tensor,
1798
+ weights: Optional[Tensor] = None,
1592
1799
  vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
1593
1800
  forward_stream: Optional[torch.cuda.Stream] = None,
1594
1801
  ) -> None:
1595
- # TODO: Refactor prefetch
1802
+ # Wait for any ongoing direct_write_embedding operations to complete
1803
+ # Moving this from forward() to _prefetch() is more logical as direct_write
1804
+ # operations affect the same cache structures that prefetch interacts with
1596
1805
  current_stream = torch.cuda.current_stream()
1806
+ if self._embedding_cache_mode:
1807
+ current_stream.wait_event(self.direct_write_l1_complete_event)
1808
+ current_stream.wait_event(self.direct_write_sp_complete_event)
1597
1809
 
1598
1810
  B_offsets = None
1599
1811
  max_B = -1
@@ -1700,8 +1912,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1700
1912
  name="cache_update",
1701
1913
  )
1702
1914
  current_stream.wait_event(self.ssd_event_cache_streaming_synced)
1703
- (updated_indices, updated_counts_gpu) = (
1704
- self.prefetched_info.pop(0)
1915
+ updated_indices, updated_counts_gpu = self.prefetched_info.pop(
1916
+ 0
1705
1917
  )
1706
1918
  self.lxu_cache_updated_indices[: updated_indices.size(0)].copy_(
1707
1919
  updated_indices,
@@ -1878,12 +2090,13 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1878
2090
  # Store info for evicting the previous iteration's
1879
2091
  # scratch pad after the corresponding backward pass is
1880
2092
  # done
1881
- self.ssd_location_update_data.append(
1882
- (
1883
- sp_curr_prev_map_gpu,
1884
- inserted_rows,
2093
+ if self.training:
2094
+ self.ssd_location_update_data.append(
2095
+ (
2096
+ sp_curr_prev_map_gpu,
2097
+ inserted_rows,
2098
+ )
1885
2099
  )
1886
- )
1887
2100
 
1888
2101
  # Ensure the previous iterations eviction is complete
1889
2102
  current_stream.wait_event(self.ssd_event_sp_evict)
@@ -1931,7 +2144,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1931
2144
  self.ssd_cache_stats = torch.add(
1932
2145
  self.ssd_cache_stats, self.local_ssd_cache_stats
1933
2146
  )
1934
- self._report_kv_backend_stats()
2147
+ # only report metrics from rank0 to avoid flooded logging
2148
+ if dist.get_rank() == 0:
2149
+ self._report_kv_backend_stats()
2150
+
2151
+ # May trigger eviction before the SSD get if the free-mem trigger mode is enabled
2152
+ self.may_trigger_eviction()
1935
2153
 
1936
2154
  # Fetch data from SSD
1937
2155
  if linear_cache_indices.numel() > 0:
@@ -1955,21 +2173,35 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1955
2173
  use_pipeline=self.prefetch_pipeline,
1956
2174
  )
1957
2175
 
1958
- if linear_cache_indices.numel() > 0:
1959
- # Evict rows from cache to SSD
1960
- self.evict(
1961
- rows=self.lxu_cache_evicted_weights,
1962
- indices_cpu=self.lxu_cache_evicted_indices,
1963
- actions_count_cpu=self.lxu_cache_evicted_count,
1964
- stream=self.ssd_eviction_stream,
1965
- pre_event=self.ssd_event_get,
1966
- # Record completion event after scratch pad eviction
1967
- # instead since that happens after L1 eviction
1968
- post_event=self.ssd_event_cache_evict,
1969
- is_rows_uvm=True,
1970
- name="cache",
1971
- is_bwd=False,
1972
- )
2176
+ if self.training:
2177
+ if linear_cache_indices.numel() > 0:
2178
+ # Evict rows from cache to SSD
2179
+ self.evict(
2180
+ rows=self.lxu_cache_evicted_weights,
2181
+ indices_cpu=self.lxu_cache_evicted_indices,
2182
+ actions_count_cpu=self.lxu_cache_evicted_count,
2183
+ stream=self.ssd_eviction_stream,
2184
+ pre_event=self.ssd_event_get,
2185
+ # Record completion event after scratch pad eviction
2186
+ # instead since that happens after L1 eviction
2187
+ post_event=self.ssd_event_cache_evict,
2188
+ is_rows_uvm=True,
2189
+ name="cache",
2190
+ is_bwd=False,
2191
+ )
2192
+ if (
2193
+ self.backend_type == BackendType.DRAM
2194
+ and weights is not None
2195
+ and linear_cache_indices.numel() > 0
2196
+ ):
2197
+ # Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
2198
+ self._update_feature_score_metadata(
2199
+ linear_cache_indices=linear_cache_indices,
2200
+ weights=weights,
2201
+ d2h_stream=self.ssd_memcpy_stream,
2202
+ write_stream=self.feature_score_stream,
2203
+ pre_event_for_write=self.ssd_event_cache_evict,
2204
+ )
1973
2205
 
1974
2206
  # Generate row addresses (pointing to either L1 or the current
1975
2207
  # iteration's scratch pad)
@@ -2051,24 +2283,32 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2051
2283
  )
2052
2284
  )
2053
2285
 
2054
- # Store scratch pad info for post backward eviction
2055
- self.ssd_scratch_pad_eviction_data.append(
2056
- (
2057
- inserted_rows,
2058
- post_bwd_evicted_indices_cpu,
2059
- actions_count_cpu,
2060
- linear_cache_indices.numel() > 0,
2286
+ # Store scratch pad info for post backward eviction only for training
2287
+ # for eval job, no backward pass, so no need to store this info
2288
+ if self.training:
2289
+ self.ssd_scratch_pad_eviction_data.append(
2290
+ (
2291
+ inserted_rows,
2292
+ post_bwd_evicted_indices_cpu,
2293
+ actions_count_cpu,
2294
+ linear_cache_indices.numel() > 0,
2295
+ )
2061
2296
  )
2062
- )
2063
2297
 
2064
2298
  # Store data for forward
2065
2299
  self.ssd_prefetch_data.append(prefetch_data)
2066
2300
 
2301
+ # Record an event to mark the completion of prefetch operations
2302
+ # This will be used by direct_write_embedding to ensure it doesn't run concurrently with prefetch
2303
+ current_stream.record_event(self.prefetch_complete_event)
2304
+
2067
2305
  @torch.jit.ignore
2068
2306
  def _generate_vbe_metadata(
2069
2307
  self,
2070
2308
  offsets: Tensor,
2071
- batch_size_per_feature_per_rank: Optional[List[List[int]]],
2309
+ batch_size_per_feature_per_rank: Optional[list[list[int]]],
2310
+ vbe_output: Optional[Tensor] = None,
2311
+ vbe_output_offsets: Optional[Tensor] = None,
2072
2312
  ) -> invokers.lookup_args.VBEMetadata:
2073
2313
  # Blocking D2H copy, but only runs at first call
2074
2314
  self.feature_dims = self.feature_dims.cpu()
@@ -2087,6 +2327,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2087
2327
  self.pooling_mode,
2088
2328
  self.feature_dims,
2089
2329
  self.current_device,
2330
+ vbe_output,
2331
+ vbe_output_offsets,
2090
2332
  )
2091
2333
 
2092
2334
  def _increment_iteration(self) -> int:
@@ -2113,14 +2355,30 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2113
2355
  self,
2114
2356
  indices: Tensor,
2115
2357
  offsets: Tensor,
2358
+ weights: Optional[Tensor] = None,
2116
2359
  per_sample_weights: Optional[Tensor] = None,
2117
2360
  feature_requires_grad: Optional[Tensor] = None,
2118
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
2361
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
2362
+ vbe_output: Optional[Tensor] = None,
2363
+ vbe_output_offsets: Optional[Tensor] = None,
2119
2364
  # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
2120
2365
  ) -> Tensor:
2121
2366
  self.clear_cache()
2367
+ if vbe_output is not None or vbe_output_offsets is not None:
2368
+ # CPU is not supported in SSD TBE
2369
+ check_allocated_vbe_output(
2370
+ self.output_dtype,
2371
+ batch_size_per_feature_per_rank,
2372
+ vbe_output,
2373
+ vbe_output_offsets,
2374
+ )
2122
2375
  indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
2123
- indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
2376
+ indices,
2377
+ offsets,
2378
+ per_sample_weights,
2379
+ batch_size_per_feature_per_rank,
2380
+ vbe_output=vbe_output,
2381
+ vbe_output_offsets=vbe_output_offsets,
2124
2382
  )
2125
2383
 
2126
2384
  if len(self.timesteps_prefetched) == 0:
@@ -2134,7 +2392,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2134
2392
  context=self.step,
2135
2393
  stream=self.ssd_eviction_stream,
2136
2394
  ):
2137
- self._prefetch(indices, offsets, vbe_metadata)
2395
+ self._prefetch(indices, offsets, weights, vbe_metadata)
2138
2396
 
2139
2397
  assert len(self.ssd_prefetch_data) > 0
2140
2398
 
@@ -2205,13 +2463,21 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2205
2463
  self.step += 1
2206
2464
 
2207
2465
  # Increment the iteration (value is used for certain optimizers)
2208
- self._increment_iteration()
2209
-
2210
- if self.optimizer == OptimType.EXACT_SGD:
2211
- raise AssertionError(
2212
- "SSDTableBatchedEmbeddingBags currently does not support SGD"
2466
+ iter_int = self._increment_iteration()
2467
+
2468
+ if self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
2469
+ momentum2 = invokers.lookup_args_ssd.Momentum(
2470
+ # pyre-ignore[6]
2471
+ dev=self.momentum2_dev,
2472
+ # pyre-ignore[6]
2473
+ host=self.momentum2_host,
2474
+ # pyre-ignore[6]
2475
+ uvm=self.momentum2_uvm,
2476
+ # pyre-ignore[6]
2477
+ offsets=self.momentum2_offsets,
2478
+ # pyre-ignore[6]
2479
+ placements=self.momentum2_placements,
2213
2480
  )
2214
- return invokers.lookup_sgd_ssd.invoke(common_args, self.optimizer_args)
2215
2481
 
2216
2482
  momentum1 = invokers.lookup_args_ssd.Momentum(
2217
2483
  dev=self.momentum1_dev,
@@ -2226,10 +2492,44 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2226
2492
  common_args, self.optimizer_args, momentum1
2227
2493
  )
2228
2494
 
2495
+ elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
2496
+ return invokers.lookup_partial_rowwise_adam_ssd.invoke(
2497
+ common_args,
2498
+ self.optimizer_args,
2499
+ momentum1,
2500
+ # pyre-ignore[61]
2501
+ momentum2,
2502
+ iter_int,
2503
+ )
2504
+
2505
+ elif self.optimizer == OptimType.ADAM:
2506
+ row_counter = invokers.lookup_args_ssd.Momentum(
2507
+ # pyre-fixme[6]
2508
+ dev=self.row_counter_dev,
2509
+ # pyre-fixme[6]
2510
+ host=self.row_counter_host,
2511
+ # pyre-fixme[6]
2512
+ uvm=self.row_counter_uvm,
2513
+ # pyre-fixme[6]
2514
+ offsets=self.row_counter_offsets,
2515
+ # pyre-fixme[6]
2516
+ placements=self.row_counter_placements,
2517
+ )
2518
+
2519
+ return invokers.lookup_adam_ssd.invoke(
2520
+ common_args,
2521
+ self.optimizer_args,
2522
+ momentum1,
2523
+ # pyre-ignore[61]
2524
+ momentum2,
2525
+ iter_int,
2526
+ row_counter=row_counter,
2527
+ )
2528
+
2229
2529
  @torch.jit.ignore
2230
2530
  def _split_optimizer_states_non_kv_zch(
2231
2531
  self,
2232
- ) -> List[List[torch.Tensor]]:
2532
+ ) -> list[list[torch.Tensor]]:
2233
2533
  """
2234
2534
  Returns a list of optimizer states (view), split by table.
2235
2535
 
@@ -2246,11 +2546,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2246
2546
  """
2247
2547
 
2248
2548
  # Row count per table
2249
- (rows, dims) = zip(*self.embedding_specs)
2549
+ rows, dims = zip(*self.embedding_specs)
2250
2550
  # Cumulative row counts per table for rowwise states
2251
- row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
2551
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
2252
2552
  # Cumulative element counts per table for elementwise states
2253
- elem_count_cumsum: List[int] = [0] + list(
2553
+ elem_count_cumsum: list[int] = [0] + list(
2254
2554
  itertools.accumulate([r * d for r, d in self.embedding_specs])
2255
2555
  )
2256
2556
 
@@ -2286,6 +2586,17 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2286
2586
  ]
2287
2587
  for t, _ in enumerate(rows)
2288
2588
  ]
2589
+
2590
+ elif self.optimizer == OptimType.ADAM:
2591
+ return [
2592
+ [
2593
+ _slice(self.momentum1_dev, t, rowwise=False),
2594
+ # pyre-ignore[6]
2595
+ _slice(self.momentum2_dev, t, rowwise=False),
2596
+ ]
2597
+ for t, _ in enumerate(rows)
2598
+ ]
2599
+
2289
2600
  else:
2290
2601
  raise NotImplementedError(
2291
2602
  f"Getting optimizer states is not supported for {self.optimizer}"
@@ -2295,14 +2606,14 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2295
2606
  def _split_optimizer_states_kv_zch_no_offloading(
2296
2607
  self,
2297
2608
  sorted_ids: torch.Tensor,
2298
- ) -> List[List[torch.Tensor]]:
2609
+ ) -> list[list[torch.Tensor]]:
2299
2610
 
2300
2611
  # Row count per table
2301
- (rows, dims) = zip(*self.embedding_specs)
2612
+ rows, dims = zip(*self.embedding_specs)
2302
2613
  # Cumulative row counts per table for rowwise states
2303
- row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
2614
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
2304
2615
  # Cumulative element counts per table for elementwise states
2305
- elem_count_cumsum: List[int] = [0] + list(
2616
+ elem_count_cumsum: list[int] = [0] + list(
2306
2617
  itertools.accumulate([r * d for r, d in self.embedding_specs])
2307
2618
  )
2308
2619
 
@@ -2332,7 +2643,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2332
2643
  # based on the sorted_ids compute the table offset for the
2333
2644
  # table, view the slice as 2D tensor of e x d, then fetch the
2334
2645
  # sub-slice by local ids
2335
- local_ids = sorted_ids[t] - bucket_id_start * bucket_size
2646
+ #
2647
+ # local_ids is [N, 1], flatten it to N to keep the returned tensor 2D
2648
+ local_ids = (sorted_ids[t] - bucket_id_start * bucket_size).view(-1)
2336
2649
  return (
2337
2650
  tensor.detach()
2338
2651
  .cpu()[elem_count_cumsum[t] : elem_count_cumsum[t + 1]]
@@ -2364,6 +2677,16 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2364
2677
  for t, _ in enumerate(rows)
2365
2678
  ]
2366
2679
 
2680
+ elif self.optimizer == OptimType.ADAM:
2681
+ return [
2682
+ [
2683
+ _slice("momentum1", self.momentum1_dev, t, rowwise=False),
2684
+ # pyre-ignore[6]
2685
+ _slice("momentum2", self.momentum2_dev, t, rowwise=False),
2686
+ ]
2687
+ for t, _ in enumerate(rows)
2688
+ ]
2689
+
2367
2690
  else:
2368
2691
  raise NotImplementedError(
2369
2692
  f"Getting optimizer states is not supported for {self.optimizer}"
@@ -2375,12 +2698,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2375
2698
  sorted_ids: torch.Tensor,
2376
2699
  no_snapshot: bool = True,
2377
2700
  should_flush: bool = False,
2378
- ) -> List[List[torch.Tensor]]:
2701
+ ) -> list[list[torch.Tensor]]:
2379
2702
  dtype = self.weights_precision.as_dtype()
2380
2703
  # Row count per table
2381
- (rows_, dims_) = zip(*self.embedding_specs)
2704
+ rows_, dims_ = zip(*self.embedding_specs)
2382
2705
  # Cumulative row counts per table for rowwise states
2383
- row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows_))
2706
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))
2384
2707
 
2385
2708
  snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
2386
2709
  no_snapshot=no_snapshot,
@@ -2390,7 +2713,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2390
2713
  # pyre-ignore[53]
2391
2714
  def _fetch_offloaded_optimizer_states(
2392
2715
  t: int,
2393
- ) -> List[Tensor]:
2716
+ ) -> list[Tensor]:
2394
2717
  e: int = rows_[t]
2395
2718
  d: int = dims_[t]
2396
2719
 
@@ -2403,12 +2726,31 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2403
2726
  # Count of rows to fetch
2404
2727
  rows_to_fetch = sorted_ids[t].numel()
2405
2728
 
2729
+ # Lookup the byte offsets for each optimizer state
2730
+ optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
2731
+ d, self.weights_precision, self.optimizer_state_dtypes
2732
+ )
2733
+ # Find the minimum start of all the start/end pairs - we have to
2734
+ # offset the start/end pairs by this value to get the correct start/end
2735
+ offset_ = min(
2736
+ [start for _, (start, _) in optimizer_state_byte_offsets.items()]
2737
+ )
2738
+ # Update the start/end pairs to be relative to offset_
2739
+ optimizer_state_byte_offsets = dict(
2740
+ (k, (v1 - offset_, v2 - offset_))
2741
+ for k, (v1, v2) in optimizer_state_byte_offsets.items()
2742
+ )
2743
+
2406
2744
  # Since the backend returns cache rows that pack the weights and
2407
2745
  # optimizer states together, reading the whole tensor could cause OOM,
2408
2746
  # so we use the KVTensorWrapper abstraction to query the backend and
2409
2747
  # fetch the data in chunks instead.
2410
2748
  tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
2411
- shape=[e, self.optimizer_state_dim],
2749
+ shape=[
2750
+ e,
2751
+ # Dim is in terms of the **weights** dtype
2752
+ self.optimizer_state_dim,
2753
+ ],
2412
2754
  dtype=dtype,
2413
2755
  row_offset=row_offset,
2414
2756
  snapshot_handle=snapshot_handle,
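The offset handling added above rebases every (start, end) byte pair returned by `byte_offsets_along_row` onto the smallest start, since the fetched buffer contains only the optimizer states rather than the full cache row. A worked example with made-up offsets:

    offsets = {"momentum1": (256, 512), "momentum2": (512, 516)}  # illustrative values
    base = min(start for start, _ in offsets.values())
    relative = {k: (s - base, e - base) for k, (s, e) in offsets.items()}
    # relative == {"momentum1": (0, 256), "momentum2": (256, 260)}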
@@ -2421,19 +2763,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2421
2763
  else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
2422
2764
  )
2423
2765
 
2424
- # Lookup the byte offsets for each optimizer state
2425
- optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
2426
- d, self.weights_precision, self.optimizer_state_dtypes
2427
- )
2428
- # Since we will be working with buffer rows that contain the
2429
- # optimizer states only, we need to offset the byte offsets by
2430
- # D * dtype.itemsize
2431
- offset_ = d * dtype.itemsize
2432
- optimizer_state_byte_offsets = dict(
2433
- (k, (v1 - offset_, v2 - offset_))
2434
- for k, (v1, v2) in optimizer_state_byte_offsets.items()
2435
- )
2436
-
2437
2766
  # Fetch the state size table for the given weights domension
2438
2767
  state_size_table = self.optimizer.state_size_table(d)
2439
2768
 
@@ -2462,10 +2791,10 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2462
2791
  )
2463
2792
 
2464
2793
  # Now split up the buffer into N views, N for each optimizer state
2465
- optimizer_states: List[Tensor] = []
2794
+ optimizer_states: list[Tensor] = []
2466
2795
  for state_name in self.optimizer.state_names():
2467
2796
  # Extract the offsets
2468
- (start, end) = optimizer_state_byte_offsets[state_name]
2797
+ start, end = optimizer_state_byte_offsets[state_name]
2469
2798
 
2470
2799
  state = optimizer_states_buffer.view(
2471
2800
  # Force tensor to byte view
@@ -2500,13 +2829,150 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2500
2829
  for t, d in enumerate(dims_)
2501
2830
  ]
2502
2831
 
2832
+ @torch.jit.ignore
2833
+ def _split_optimizer_states_kv_zch_whole_row(
2834
+ self,
2835
+ sorted_ids: torch.Tensor,
2836
+ no_snapshot: bool = True,
2837
+ should_flush: bool = False,
2838
+ ) -> list[list[torch.Tensor]]:
2839
+ dtype = self.weights_precision.as_dtype()
2840
+
2841
+ # Row and dimension counts per table
2842
+ # rows_ is only used here to compute the virtual table offsets
2843
+ rows_, dims_ = zip(*self.embedding_specs)
2844
+
2845
+ # Cumulative row counts per (virtual) table for rowwise states
2846
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))
2847
+
2848
+ snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
2849
+ no_snapshot=no_snapshot,
2850
+ should_flush=should_flush,
2851
+ )
2852
+
2853
+ # pyre-ignore[53]
2854
+ def _fetch_offloaded_optimizer_states(
2855
+ t: int,
2856
+ ) -> list[Tensor]:
2857
+ d: int = dims_[t]
2858
+
2859
+ # pyre-ignore[16]
2860
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
2861
+ # pyre-ignore[16]
2862
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
2863
+ row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)
2864
+
2865
+ # When backend returns whole row, the optimizer will be returned as
2866
+ # PMT directly
2867
+ if sorted_ids[t].size(0) == 0 and self.local_weight_counts[t] > 0:
2868
+ logging.info(
2869
+ f"Before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
2870
+ )
2871
+ sorted_ids[t] = torch.zeros(
2872
+ (self.local_weight_counts[t], 1),
2873
+ device=torch.device("cpu"),
2874
+ dtype=torch.int64,
2875
+ )
2876
+
2877
+ # Lookup the byte offsets for each optimizer state relative to the
2878
+ # start of the weights
2879
+ optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
2880
+ d, self.weights_precision, self.optimizer_state_dtypes
2881
+ )
2882
+ # Get the number of elements (of the optimizer state dtype) per state
2883
+ optimizer_state_size_table = self.optimizer.state_size_table(d)
2884
+
2885
+ # Get metaheader dimensions in number of elements of weight dtype
2886
+ metaheader_dim = (
2887
+ # pyre-ignore[16]
2888
+ self.kv_zch_params.eviction_policy.meta_header_lens[t]
2889
+ )
2890
+
2891
+ # Now split up the buffer into N views, N for each optimizer state
2892
+ optimizer_states: list[PartiallyMaterializedTensor] = []
2893
+ for state_name in self.optimizer.state_names():
2894
+ state_dtype = self.optimizer_state_dtypes.get(
2895
+ state_name, SparseType.FP32
2896
+ ).as_dtype()
2897
+
2898
+ # Get the size of the state in elements of the optimizer state,
2899
+ # in terms of the **weights** dtype
2900
+ state_size = math.ceil(
2901
+ optimizer_state_size_table[state_name]
2902
+ * state_dtype.itemsize
2903
+ / dtype.itemsize
2904
+ )
2905
+
2906
+ # Extract the offsets relative to the start of the weights (in
2907
+ # num bytes)
2908
+ start, _ = optimizer_state_byte_offsets[state_name]
2909
+
2910
+ # Convert the start to number of elements in terms of the
2911
+ # **weights** dtype, then add the mmetaheader dim offset
2912
+ start = metaheader_dim + start // dtype.itemsize
2913
+
2914
+ shape = [
2915
+ (
2916
+ sorted_ids[t].size(0)
2917
+ if sorted_ids is not None and sorted_ids[t].size(0) > 0
2918
+ else self.local_weight_counts[t]
2919
+ ),
2920
+ (
2921
+ # Dim is in terms of the **weights** dtype
2922
+ state_size
2923
+ ),
2924
+ ]
2925
+
2926
+ # NOTE: We have to view using the **weights** dtype, as
2927
+ # there is currently a bug with KVTensorWrapper where using
2928
+ # a different dtype does not result in the same bytes being
2929
+ # returned, e.g.
2930
+ #
2931
+ # KVTensorWrapper(dtype=fp32, width_offset=130, shape=[N, 1])
2932
+ #
2933
+ # is NOT the same as
2934
+ #
2935
+ # KVTensorWrapper(dtype=fp16, width_offset=260, shape=[N, 2]).view(-1).view(fp32)
2936
+ #
2937
+ # TODO: Fix KVTensorWrapper to support viewing data under different dtypes
2938
+ tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
2939
+ shape=shape,
2940
+ dtype=(
2941
+ # NOTE: Use the *weights* dtype
2942
+ dtype
2943
+ ),
2944
+ row_offset=row_offset,
2945
+ snapshot_handle=snapshot_handle,
2946
+ sorted_indices=sorted_ids[t],
2947
+ width_offset=(
2948
+ # NOTE: Width offset is in terms of **weights** dtype
2949
+ start
2950
+ ),
2951
+ # Optimizer written to DB with weights, so skip write here
2952
+ read_only=True,
2953
+ )
2954
+ (
2955
+ tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2956
+ if self.backend_type == BackendType.SSD
2957
+ else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
2958
+ )
2959
+
2960
+ optimizer_states.append(
2961
+ PartiallyMaterializedTensor(tensor_wrapper, True)
2962
+ )
2963
+
2964
+ # pyre-ignore [7]
2965
+ return optimizer_states
2966
+
2967
+ return [_fetch_offloaded_optimizer_states(t) for t, _ in enumerate(dims_)]
2968
+
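A minimal standalone sketch of the offset arithmetic above, assuming FP16 weights, a single-element FP32 optimizer state, and a hypothetical metaheader of 8 weight elements (the real values come from byte_offsets_along_row, state_size_table, and meta_header_lens):

import math
import torch

weights_dtype = torch.float16   # assumed weights precision
state_dtype = torch.float32     # assumed dtype of a single optimizer state
metaheader_dim = 8              # hypothetical metaheader width, in weight elements

byte_start = 256                # hypothetical byte offset of the state within a row
state_elem_count = 1            # hypothetical element count of the state

# Size of the state expressed in elements of the *weights* dtype
state_size = math.ceil(state_elem_count * state_dtype.itemsize / weights_dtype.itemsize)

# Start column of the state, also in elements of the *weights* dtype, past the metaheader
start = metaheader_dim + byte_start // weights_dtype.itemsize

print(state_size, start)  # 2 136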
2503
2969
  @torch.jit.export
2504
2970
  def split_optimizer_states(
2505
2971
  self,
2506
- sorted_id_tensor: Optional[List[torch.Tensor]] = None,
2972
+ sorted_id_tensor: Optional[list[torch.Tensor]] = None,
2507
2973
  no_snapshot: bool = True,
2508
2974
  should_flush: bool = False,
2509
- ) -> List[List[torch.Tensor]]:
2975
+ ) -> list[list[torch.Tensor]]:
2510
2976
  """
2511
2977
  Returns a list of optimizer states split by table.
2512
2978
 
@@ -2555,75 +3021,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2555
3021
  )
2556
3022
 
2557
3023
  else:
2558
- snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
2559
- no_snapshot=no_snapshot,
2560
- should_flush=should_flush,
3024
+ # Handle the KVZCH with-optimizer-offloading backend-whole-row case
3025
+ optimizer_states = self._split_optimizer_states_kv_zch_whole_row(
3026
+ sorted_id_tensor, no_snapshot, should_flush
2561
3027
  )
2562
3028
 
2563
- optimizer_states = []
2564
- table_offset = 0
2565
-
2566
- for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
2567
- # pyre-ignore
2568
- bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
2569
- # pyre-ignore
2570
- bucket_size = self.kv_zch_params.bucket_sizes[t]
2571
- row_offset = table_offset - (bucket_id_start * bucket_size)
2572
-
2573
- # When backend returns whole row, the optimizer will be returned as PMT directly
2574
- # pyre-ignore [16]
2575
- if sorted_id_tensor[t].size(0) == 0 and self.local_weight_counts[t] > 0:
2576
- logging.info(
2577
- f"before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
2578
- )
2579
- # pyre-ignore [16]
2580
- sorted_id_tensor[t] = torch.zeros(
2581
- (self.local_weight_counts[t], 1),
2582
- device=torch.device("cpu"),
2583
- dtype=torch.int64,
2584
- )
2585
-
2586
- metaheader_dim = (
2587
- # pyre-ignore[16]
2588
- self.kv_zch_params.eviction_policy.meta_header_lens[t]
2589
- )
2590
- tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
2591
- shape=[
2592
- (
2593
- sorted_id_tensor[t].size(0)
2594
- if sorted_id_tensor is not None
2595
- and sorted_id_tensor[t].size(0) > 0
2596
- else emb_height
2597
- ),
2598
- self.optimizer_state_dim,
2599
- ],
2600
- dtype=self.weights_precision.as_dtype(),
2601
- row_offset=row_offset,
2602
- snapshot_handle=snapshot_handle,
2603
- sorted_indices=sorted_id_tensor[t],
2604
- width_offset=(
2605
- metaheader_dim # metaheader is already padded so no need for pad4
2606
- + pad4(emb_dim)
2607
- ),
2608
- read_only=True, # optimizer written to DB with weights, so skip write here
2609
- )
2610
- (
2611
- tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2612
- if self.backend_type == BackendType.SSD
2613
- else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
2614
- )
2615
-
2616
- optimizer_states.append(
2617
- [
2618
- PartiallyMaterializedTensor(
2619
- tensor_wrapper,
2620
- True if self.kv_zch_params else False,
2621
- )
2622
- ]
2623
- )
2624
-
2625
- table_offset += emb_height
2626
-
2627
3029
  logging.info(
2628
3030
  f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
2629
3031
  # pyre-ignore[16]
@@ -2635,14 +3037,14 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2635
3037
  @torch.jit.export
2636
3038
  def get_optimizer_state(
2637
3039
  self,
2638
- sorted_id_tensor: Optional[List[torch.Tensor]],
3040
+ sorted_id_tensor: Optional[list[torch.Tensor]],
2639
3041
  no_snapshot: bool = True,
2640
3042
  should_flush: bool = False,
2641
- ) -> List[Dict[str, torch.Tensor]]:
3043
+ ) -> list[dict[str, torch.Tensor]]:
2642
3044
  """
2643
3045
  Returns a list of dictionaries of optimizer states split by table.
2644
3046
  """
2645
- states_list: List[List[Tensor]] = self.split_optimizer_states(
3047
+ states_list: list[list[Tensor]] = self.split_optimizer_states(
2646
3048
  sorted_id_tensor=sorted_id_tensor,
2647
3049
  no_snapshot=no_snapshot,
2648
3050
  should_flush=should_flush,
@@ -2651,13 +3053,13 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2651
3053
  return [dict(zip(state_names, states)) for states in states_list]
2652
3054
 
2653
3055
  @torch.jit.export
2654
- def debug_split_embedding_weights(self) -> List[torch.Tensor]:
3056
+ def debug_split_embedding_weights(self) -> list[torch.Tensor]:
2655
3057
  """
2656
3058
  Returns a list of weights, split by table.
2657
3059
 
2658
3060
  Testing only, very slow.
2659
3061
  """
2660
- (rows, _) = zip(*self.embedding_specs)
3062
+ rows, _ = zip(*self.embedding_specs)
2661
3063
 
2662
3064
  rows_cumsum = [0] + list(itertools.accumulate(rows))
2663
3065
  splits = []
@@ -2738,15 +3140,48 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2738
3140
  self.flush(force=should_flush)
2739
3141
  return snapshot_handle, checkpoint_handle
2740
3142
 
3143
+ def get_embedding_dim_for_kvt(
3144
+ self, metaheader_dim: int, emb_dim: int, is_loading_checkpoint: bool
3145
+ ) -> int:
3146
+ if self.load_ckpt_without_opt:
3147
+ # For silvertorch publish, we don't want to load opt into backend due to limited cpu memory in publish host.
3148
+ # So we need to load the whole row into state dict which loading the checkpoint in st publish, then only save weight into backend, after that
3149
+ # backend will only have metaheader + weight.
3150
+ # For the first loading, we need to set dim with metaheader_dim + emb_dim + optimizer_state_dim, otherwise the checkpoint loadding will throw size mismatch error
3151
+ # after the first loading, we only need to get metaheader+weight from backend for state dict, so we can set dim with metaheader_dim + emb
3152
+ if is_loading_checkpoint:
3153
+ return (
3154
+ (
3155
+ metaheader_dim # metaheader is already padded
3156
+ + pad4(emb_dim)
3157
+ + pad4(self.optimizer_state_dim)
3158
+ )
3159
+ if self.backend_return_whole_row
3160
+ else emb_dim
3161
+ )
3162
+ else:
3163
+ return metaheader_dim + pad4(emb_dim)
3164
+ else:
3165
+ return (
3166
+ (
3167
+ metaheader_dim # metaheader is already padded
3168
+ + pad4(emb_dim)
3169
+ + pad4(self.optimizer_state_dim)
3170
+ )
3171
+ if self.backend_return_whole_row
3172
+ else emb_dim
3173
+ )
3174
+
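A small worked sketch of the widths get_embedding_dim_for_kvt can produce, using a hypothetical table with emb_dim=127, a metaheader of 8 elements, and optimizer_state_dim=1; pad4 is assumed to round up to a multiple of 4:

def pad4(v: int) -> int:
    # assumed behavior of the pad4 helper: round up to a multiple of 4
    return (v + 3) // 4 * 4

metaheader_dim, emb_dim, optimizer_state_dim = 8, 127, 1

whole_row = metaheader_dim + pad4(emb_dim) + pad4(optimizer_state_dim)  # 140, backend_return_whole_row
weight_only = metaheader_dim + pad4(emb_dim)                            # 136, load_ckpt_without_opt after the first load
plain = emb_dim                                                         # 127, backend does not return whole rows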
2741
3175
  @torch.jit.export
2742
3176
  def split_embedding_weights(
2743
3177
  self,
2744
3178
  no_snapshot: bool = True,
2745
3179
  should_flush: bool = False,
2746
- ) -> Tuple[ # TODO: make this a NamedTuple for readability
2747
- Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
2748
- Optional[List[torch.Tensor]],
2749
- Optional[List[torch.Tensor]],
3180
+ ) -> tuple[ # TODO: make this a NamedTuple for readability
3181
+ Union[list[PartiallyMaterializedTensor], list[torch.Tensor]],
3182
+ Optional[list[torch.Tensor]],
3183
+ Optional[list[torch.Tensor]],
3184
+ Optional[list[torch.Tensor]],
2750
3185
  ]:
2751
3186
  """
2752
3187
  This method is intended to be used by the checkpointing engine
@@ -2766,6 +3201,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2766
3201
  2nd arg: input id sorted in bucket id ascending order
2767
3202
  3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
2768
3203
  where for the i th element, we have i + bucket_id_start = global bucket id
3204
+ 4th arg: kvzch eviction metadata for each input id sorted in bucket id ascending order
2769
3205
  """
2770
3206
  snapshot_handle, checkpoint_handle = self._may_create_snapshot_for_state_dict(
2771
3207
  no_snapshot=no_snapshot,
@@ -2782,16 +3218,21 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2782
3218
  self._cached_kvzch_data.cached_weight_tensor_per_table,
2783
3219
  self._cached_kvzch_data.cached_id_tensor_per_table,
2784
3220
  self._cached_kvzch_data.cached_bucket_splits,
3221
+ [], # metadata tensor is not needed for checkpoint loading
2785
3222
  )
2786
3223
  start_time = time.time()
2787
3224
  pmt_splits = []
2788
3225
  bucket_sorted_id_splits = [] if self.kv_zch_params else None
2789
3226
  active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
3227
+ metadata_splits = [] if self.kv_zch_params else None
3228
+ skip_metadata = False
2790
3229
 
2791
3230
  table_offset = 0
2792
3231
  for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
3232
+ is_loading_checkpoint = False
2793
3233
  bucket_ascending_id_tensor = None
2794
3234
  bucket_t = None
3235
+ metadata_tensor = None
2795
3236
  row_offset = table_offset
2796
3237
  metaheader_dim = 0
2797
3238
  if self.kv_zch_params:
@@ -2823,6 +3264,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2823
3264
  bucket_size,
2824
3265
  )
2825
3266
  )
3267
+ metadata_tensor = self._ssd_db.get_kv_zch_eviction_metadata_by_snapshot(
3268
+ bucket_ascending_id_tensor + table_offset,
3269
+ torch.as_tensor(bucket_ascending_id_tensor.size(0)),
3270
+ snapshot_handle,
3271
+ ).view(-1, 1)
3272
+
2826
3273
  # 3. convert local id back to global id
2827
3274
  bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
2828
3275
 
@@ -2833,16 +3280,32 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2833
3280
  logging.info(
2834
3281
  f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
2835
3282
  )
2836
- bucket_ascending_id_tensor = torch.zeros(
2837
- (self.local_weight_counts[i], 1),
2838
- device=torch.device("cpu"),
2839
- dtype=torch.int64,
2840
- )
3283
+ if self.global_id_per_rank[i].numel() != 0:
3284
+ assert (
3285
+ self.local_weight_counts[i]
3286
+ == self.global_id_per_rank[i].numel()
3287
+ ), f"local weight count and global id per rank size mismatch, with {self.local_weight_counts[i]} and {self.global_id_per_rank[i].numel()}"
3288
+ bucket_ascending_id_tensor = self.global_id_per_rank[i].to(
3289
+ device=torch.device("cpu"), dtype=torch.int64
3290
+ )
3291
+ else:
3292
+ bucket_ascending_id_tensor = torch.zeros(
3293
+ (self.local_weight_counts[i], 1),
3294
+ device=torch.device("cpu"),
3295
+ dtype=torch.int64,
3296
+ )
3297
+ skip_metadata = True
3298
+ is_loading_checkpoint = True
3299
+
2841
3300
  # self.local_weight_counts[i] = 0 # Reset the count
2842
3301
 
2843
3302
  # pyre-ignore [16] bucket_sorted_id_splits is not None
2844
3303
  bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
2845
3304
  active_id_cnt_per_bucket_split.append(bucket_t)
3305
+ if skip_metadata:
3306
+ metadata_splits = None
3307
+ else:
3308
+ metadata_splits.append(metadata_tensor)
2846
3309
 
2847
3310
  # for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing
2848
3311
  # but in backend, local id is used during training, so the KVTensorWrapper need to convert global id to local id
@@ -2857,14 +3320,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2857
3320
  if bucket_ascending_id_tensor is not None
2858
3321
  else emb_height
2859
3322
  ),
2860
- (
2861
- (
2862
- metaheader_dim # metaheader is already padded
2863
- + pad4(emb_dim)
2864
- + pad4(self.optimizer_state_dim)
2865
- )
2866
- if self.backend_return_whole_row
2867
- else emb_dim
3323
+ self.get_embedding_dim_for_kvt(
3324
+ metaheader_dim, emb_dim, is_loading_checkpoint
2868
3325
  ),
2869
3326
  ],
2870
3327
  dtype=dtype,
@@ -2876,6 +3333,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2876
3333
  bucket_ascending_id_tensor if self.kv_zch_params else None
2877
3334
  ),
2878
3335
  checkpoint_handle=checkpoint_handle,
3336
+ only_load_weight=(
3337
+ True
3338
+ if self.load_ckpt_without_opt and is_loading_checkpoint
3339
+ else False
3340
+ ),
2879
3341
  )
2880
3342
  (
2881
3343
  tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
@@ -2898,14 +3360,19 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2898
3360
  f"num ids list: {[ids.numel() for ids in bucket_sorted_id_splits]}"
2899
3361
  )
2900
3362
 
2901
- return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
3363
+ return (
3364
+ pmt_splits,
3365
+ bucket_sorted_id_splits,
3366
+ active_id_cnt_per_bucket_split,
3367
+ metadata_splits,
3368
+ )
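With the change above, split_embedding_weights returns a 4-tuple; a hedged sketch of the caller-side unpacking, where tbe stands for an already-constructed SSDTableBatchedEmbeddingBags instance:

# pmts:       per-table weights (PartiallyMaterializedTensor or plain tensors)
# bucket_ids: input ids sorted in bucket-id ascending order (KV ZCH only, else None)
# id_counts:  active id count per bucket (KV ZCH only, else None)
# metadata:   per-id eviction metadata (KV ZCH only; may be None during checkpoint loading)
pmts, bucket_ids, id_counts, metadata = tbe.split_embedding_weights(
    no_snapshot=False,
    should_flush=True,
)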
2902
3369
 
2903
3370
  @torch.jit.ignore
2904
3371
  def _apply_state_dict_w_offloading(self) -> None:
2905
3372
  # Row count per table
2906
- (rows, _) = zip(*self.embedding_specs)
3373
+ rows, _ = zip(*self.embedding_specs)
2907
3374
  # Cumulative row counts per table for rowwise states
2908
- row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
3375
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
2909
3376
 
2910
3377
  for t, _ in enumerate(self.embedding_specs):
2911
3378
  # pyre-ignore [16]
@@ -2932,9 +3399,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2932
3399
  @torch.jit.ignore
2933
3400
  def _apply_state_dict_no_offloading(self) -> None:
2934
3401
  # Row count per table
2935
- (rows, _) = zip(*self.embedding_specs)
3402
+ rows, _ = zip(*self.embedding_specs)
2936
3403
  # Cumulative row counts per table for rowwise states
2937
- row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
3404
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
2938
3405
 
2939
3406
  def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
2940
3407
  device = dst.device
@@ -2968,7 +3435,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2968
3435
  # Set up the plan for copying optimizer states over
2969
3436
  if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
2970
3437
  mapping = [(opt_states[0], self.momentum1_dev)]
2971
- elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
3438
+ elif self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
2972
3439
  mapping = [
2973
3440
  (opt_states[0], self.momentum1_dev),
2974
3441
  (opt_states[1], self.momentum2_dev),
@@ -3025,7 +3492,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3025
3492
  def streaming_write_weight_and_id_per_table(
3026
3493
  self,
3027
3494
  weight_state: torch.Tensor,
3028
- opt_states: List[torch.Tensor],
3495
+ opt_states: list[torch.Tensor],
3029
3496
  id_tensor: torch.Tensor,
3030
3497
  row_offset: int,
3031
3498
  ) -> None:
@@ -3082,7 +3549,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3082
3549
  state_name = self.optimizer.state_names()[o]
3083
3550
 
3084
3551
  # Fetch the byte offsets for the optimizer state by its name
3085
- (start, end) = optimizer_state_byte_offsets[state_name]
3552
+ start, end = optimizer_state_byte_offsets[state_name]
3086
3553
 
3087
3554
  # Assume that the opt_state passed in already has dtype matching
3088
3555
  # self.optimizer_state_dtypes[state_name]
@@ -3119,7 +3586,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3119
3586
  self.load_state_dict = True
3120
3587
 
3121
3588
  dtype = self.weights_precision.as_dtype()
3122
- (_, dims) = zip(*self.embedding_specs)
3589
+ _, dims = zip(*self.embedding_specs)
3123
3590
 
3124
3591
  self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
3125
3592
 
@@ -3192,6 +3659,10 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3192
3659
  def flush(self, force: bool = False) -> None:
3193
3660
  # allow force flush from split_embedding_weights to cover edge cases, e.g. checkpointing
3194
3661
  # after trained 0 batches
3662
+ if not self.training:
3663
+ # in eval mode, we should not write anything to the embedding backend
3664
+ return
3665
+
3195
3666
  if self.step == self.last_flush_step and not force:
3196
3667
  logging.info(
3197
3668
  f"SSD TBE has been flushed at {self.last_flush_step=} already for tbe:{self.tbe_unique_id}"
@@ -3237,18 +3708,20 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3237
3708
  indices: Tensor,
3238
3709
  offsets: Tensor,
3239
3710
  per_sample_weights: Optional[Tensor] = None,
3240
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
3241
- ) -> Tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
3711
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
3712
+ vbe_output: Optional[Tensor] = None,
3713
+ vbe_output_offsets: Optional[Tensor] = None,
3714
+ ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
3242
3715
  """
3243
3716
  Prepare TBE inputs
3244
3717
  """
3245
3718
  # Generate VBE metadata
3246
3719
  vbe_metadata = self._generate_vbe_metadata(
3247
- offsets, batch_size_per_feature_per_rank
3720
+ offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
3248
3721
  )
3249
3722
 
3250
3723
  # Force casting indices and offsets to long
3251
- (indices, offsets) = indices.long(), offsets.long()
3724
+ indices, offsets = indices.long(), offsets.long()
3252
3725
 
3253
3726
  # Force casting per_sample_weights to float
3254
3727
  if per_sample_weights is not None:
@@ -3287,6 +3760,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3287
3760
  self._report_l2_cache_perf_stats()
3288
3761
  if self.backend_type == BackendType.DRAM:
3289
3762
  self._report_dram_kv_perf_stats()
3763
+ if self.kv_zch_params and self.kv_zch_params.eviction_policy:
3764
+ self._report_eviction_stats()
3290
3765
 
3291
3766
  @torch.jit.ignore
3292
3767
  def _report_ssd_l1_cache_stats(self) -> None:
@@ -3303,7 +3778,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3303
3778
  ssd_cache_stats = self.ssd_cache_stats.tolist()
3304
3779
  if len(self.last_reported_ssd_stats) == 0:
3305
3780
  self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
3306
- ssd_cache_stats_delta: List[float] = [0.0] * len(ssd_cache_stats)
3781
+ ssd_cache_stats_delta: list[float] = [0.0] * len(ssd_cache_stats)
3307
3782
  for i in range(len(ssd_cache_stats)):
3308
3783
  ssd_cache_stats_delta[i] = (
3309
3784
  ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
@@ -3553,6 +4028,98 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3553
4028
  time_unit="us",
3554
4029
  )
3555
4030
 
4031
+ @torch.jit.ignore
4032
+ def _report_eviction_stats(self) -> None:
4033
+ if self.stats_reporter is None:
4034
+ return
4035
+
4036
+ stats_reporter: TBEStatsReporter = self.stats_reporter
4037
+ if not stats_reporter.should_report(self.step):
4038
+ return
4039
+
4040
+ # skip metrics reporting when evicting disabled
4041
+ if self.kv_zch_params.eviction_policy.eviction_trigger_mode == 0:
4042
+ return
4043
+
4044
+ T = len(set(self.feature_table_map))
4045
+ evicted_counts = torch.zeros(T, dtype=torch.int64)
4046
+ processed_counts = torch.zeros(T, dtype=torch.int64)
4047
+ eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float)
4048
+ full_duration_ms = torch.tensor(0, dtype=torch.int64)
4049
+ exec_duration_ms = torch.tensor(0, dtype=torch.int64)
4050
+ self.ssd_db.get_feature_evict_metric(
4051
+ evicted_counts,
4052
+ processed_counts,
4053
+ eviction_threshold_with_dry_run,
4054
+ full_duration_ms,
4055
+ exec_duration_ms,
4056
+ )
4057
+
4058
+ stats_reporter.report_data_amount(
4059
+ iteration_step=self.step,
4060
+ event_name=self.eviction_sum_evicted_counts_stats_name,
4061
+ data_bytes=int(evicted_counts.sum().item()),
4062
+ enable_tb_metrics=True,
4063
+ )
4064
+ stats_reporter.report_data_amount(
4065
+ iteration_step=self.step,
4066
+ event_name=self.eviction_sum_processed_counts_stats_name,
4067
+ data_bytes=int(processed_counts.sum().item()),
4068
+ enable_tb_metrics=True,
4069
+ )
4070
+ if processed_counts.sum().item() != 0:
4071
+ stats_reporter.report_data_amount(
4072
+ iteration_step=self.step,
4073
+ event_name=self.eviction_evict_rate_stats_name,
4074
+ data_bytes=int(
4075
+ evicted_counts.sum().item() * 100 / processed_counts.sum().item()
4076
+ ),
4077
+ enable_tb_metrics=True,
4078
+ )
4079
+ for t in self.feature_table_map:
4080
+ stats_reporter.report_data_amount(
4081
+ iteration_step=self.step,
4082
+ event_name=f"eviction.feature_table.{t}.evicted_counts",
4083
+ data_bytes=int(evicted_counts[t].item()),
4084
+ enable_tb_metrics=True,
4085
+ )
4086
+ stats_reporter.report_data_amount(
4087
+ iteration_step=self.step,
4088
+ event_name=f"eviction.feature_table.{t}.processed_counts",
4089
+ data_bytes=int(processed_counts[t].item()),
4090
+ enable_tb_metrics=True,
4091
+ )
4092
+ if processed_counts[t].item() != 0:
4093
+ stats_reporter.report_data_amount(
4094
+ iteration_step=self.step,
4095
+ event_name=f"eviction.feature_table.{t}.evict_rate",
4096
+ data_bytes=int(
4097
+ evicted_counts[t].item() * 100 / processed_counts[t].item()
4098
+ ),
4099
+ enable_tb_metrics=True,
4100
+ )
4101
+ stats_reporter.report_duration(
4102
+ iteration_step=self.step,
4103
+ event_name="eviction.feature_table.full_duration_ms",
4104
+ duration_ms=full_duration_ms.item(),
4105
+ time_unit="ms",
4106
+ enable_tb_metrics=True,
4107
+ )
4108
+ stats_reporter.report_duration(
4109
+ iteration_step=self.step,
4110
+ event_name="eviction.feature_table.exec_duration_ms",
4111
+ duration_ms=exec_duration_ms.item(),
4112
+ time_unit="ms",
4113
+ enable_tb_metrics=True,
4114
+ )
4115
+ if full_duration_ms.item() != 0:
4116
+ stats_reporter.report_data_amount(
4117
+ iteration_step=self.step,
4118
+ event_name="eviction.feature_table.exec_div_full_duration_rate",
4119
+ data_bytes=int(exec_duration_ms.item() * 100 / full_duration_ms.item()),
4120
+ enable_tb_metrics=True,
4121
+ )
4122
+
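The evict-rate values above are reported as integer percentages; a tiny illustration of that arithmetic with hypothetical counts:

import torch

evicted_counts = torch.tensor([120, 0, 30], dtype=torch.int64)       # hypothetical per-table counts
processed_counts = torch.tensor([1000, 500, 300], dtype=torch.int64)

if processed_counts.sum().item() != 0:  # same zero guard as in the reporter above
    overall_rate = int(evicted_counts.sum().item() * 100 / processed_counts.sum().item())  # 8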
3556
4123
  @torch.jit.ignore
3557
4124
  def _report_dram_kv_perf_stats(self) -> None:
3558
4125
  """
@@ -3570,8 +4137,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3570
4137
  self.step, stats_reporter.report_interval # pyre-ignore
3571
4138
  )
3572
4139
 
3573
- if len(dram_kv_perf_stats) != 22:
3574
- logging.error("dram cache perf stats should have 22 elements")
4140
+ if len(dram_kv_perf_stats) != 36:
4141
+ logging.error("dram cache perf stats should have 36 elements")
3575
4142
  return
3576
4143
 
3577
4144
  dram_read_duration = dram_kv_perf_stats[0]
@@ -3599,52 +4166,75 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3599
4166
 
3600
4167
  dram_kv_allocated_bytes = dram_kv_perf_stats[20]
3601
4168
  dram_kv_actual_used_chunk_bytes = dram_kv_perf_stats[21]
4169
+ dram_kv_num_rows = dram_kv_perf_stats[22]
4170
+ dram_kv_read_counts = dram_kv_perf_stats[23]
4171
+ dram_metadata_write_sharding_total_duration = dram_kv_perf_stats[24]
4172
+ dram_metadata_write_total_duration = dram_kv_perf_stats[25]
4173
+ dram_metadata_write_allocate_avg_duration = dram_kv_perf_stats[26]
4174
+ dram_metadata_write_lookup_cache_avg_duration = dram_kv_perf_stats[27]
4175
+ dram_metadata_write_acquire_lock_avg_duration = dram_kv_perf_stats[28]
4176
+ dram_metadata_write_cache_miss_avg_count = dram_kv_perf_stats[29]
4177
+
4178
+ dram_read_metadata_total_duration = dram_kv_perf_stats[30]
4179
+ dram_read_metadata_sharding_total_duration = dram_kv_perf_stats[31]
4180
+ dram_read_metadata_cache_hit_copy_avg_duration = dram_kv_perf_stats[32]
4181
+ dram_read_metadata_lookup_cache_total_avg_duration = dram_kv_perf_stats[33]
4182
+ dram_read_metadata_acquire_lock_avg_duration = dram_kv_perf_stats[34]
4183
+ dram_read_read_metadata_load_size = dram_kv_perf_stats[35]
3602
4184
 
3603
4185
  stats_reporter.report_duration(
3604
4186
  iteration_step=self.step,
3605
4187
  event_name="dram_kv.perf.get.dram_read_duration_us",
3606
4188
  duration_ms=dram_read_duration,
4189
+ enable_tb_metrics=True,
3607
4190
  time_unit="us",
3608
4191
  )
3609
4192
  stats_reporter.report_duration(
3610
4193
  iteration_step=self.step,
3611
4194
  event_name="dram_kv.perf.get.dram_read_sharding_duration_us",
3612
4195
  duration_ms=dram_read_sharding_duration,
4196
+ enable_tb_metrics=True,
3613
4197
  time_unit="us",
3614
4198
  )
3615
4199
  stats_reporter.report_duration(
3616
4200
  iteration_step=self.step,
3617
4201
  event_name="dram_kv.perf.get.dram_read_cache_hit_copy_duration_us",
3618
4202
  duration_ms=dram_read_cache_hit_copy_duration,
4203
+ enable_tb_metrics=True,
3619
4204
  time_unit="us",
3620
4205
  )
3621
4206
  stats_reporter.report_duration(
3622
4207
  iteration_step=self.step,
3623
4208
  event_name="dram_kv.perf.get.dram_read_fill_row_storage_duration_us",
3624
4209
  duration_ms=dram_read_fill_row_storage_duration,
4210
+ enable_tb_metrics=True,
3625
4211
  time_unit="us",
3626
4212
  )
3627
4213
  stats_reporter.report_duration(
3628
4214
  iteration_step=self.step,
3629
4215
  event_name="dram_kv.perf.get.dram_read_lookup_cache_duration_us",
3630
4216
  duration_ms=dram_read_lookup_cache_duration,
4217
+ enable_tb_metrics=True,
3631
4218
  time_unit="us",
3632
4219
  )
3633
4220
  stats_reporter.report_duration(
3634
4221
  iteration_step=self.step,
3635
4222
  event_name="dram_kv.perf.get.dram_read_acquire_lock_duration_us",
3636
4223
  duration_ms=dram_read_acquire_lock_duration,
4224
+ enable_tb_metrics=True,
3637
4225
  time_unit="us",
3638
4226
  )
3639
4227
  stats_reporter.report_data_amount(
3640
4228
  iteration_step=self.step,
3641
4229
  event_name="dram_kv.perf.get.dram_read_missing_load",
4230
+ enable_tb_metrics=True,
3642
4231
  data_bytes=dram_read_missing_load,
3643
4232
  )
3644
4233
  stats_reporter.report_duration(
3645
4234
  iteration_step=self.step,
3646
4235
  event_name="dram_kv.perf.set.dram_write_sharing_duration_us",
3647
4236
  duration_ms=dram_write_sharing_duration,
4237
+ enable_tb_metrics=True,
3648
4238
  time_unit="us",
3649
4239
  )
3650
4240
 
@@ -3652,83 +4242,192 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3652
4242
  iteration_step=self.step,
3653
4243
  event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_duration_us",
3654
4244
  duration_ms=dram_fwd_l1_eviction_write_duration,
4245
+ enable_tb_metrics=True,
3655
4246
  time_unit="us",
3656
4247
  )
3657
4248
  stats_reporter.report_duration(
3658
4249
  iteration_step=self.step,
3659
4250
  event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_allocate_duration_us",
3660
4251
  duration_ms=dram_fwd_l1_eviction_write_allocate_duration,
4252
+ enable_tb_metrics=True,
3661
4253
  time_unit="us",
3662
4254
  )
3663
4255
  stats_reporter.report_duration(
3664
4256
  iteration_step=self.step,
3665
4257
  event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_cache_copy_duration_us",
3666
4258
  duration_ms=dram_fwd_l1_eviction_write_cache_copy_duration,
4259
+ enable_tb_metrics=True,
3667
4260
  time_unit="us",
3668
4261
  )
3669
4262
  stats_reporter.report_duration(
3670
4263
  iteration_step=self.step,
3671
4264
  event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_lookup_cache_duration_us",
3672
4265
  duration_ms=dram_fwd_l1_eviction_write_lookup_cache_duration,
4266
+ enable_tb_metrics=True,
3673
4267
  time_unit="us",
3674
4268
  )
3675
4269
  stats_reporter.report_duration(
3676
4270
  iteration_step=self.step,
3677
4271
  event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_acquire_lock_duration_us",
3678
4272
  duration_ms=dram_fwd_l1_eviction_write_acquire_lock_duration,
4273
+ enable_tb_metrics=True,
3679
4274
  time_unit="us",
3680
4275
  )
3681
4276
  stats_reporter.report_data_amount(
3682
4277
  iteration_step=self.step,
3683
4278
  event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load",
3684
4279
  data_bytes=dram_fwd_l1_eviction_write_missing_load,
4280
+ enable_tb_metrics=True,
3685
4281
  )
3686
4282
 
3687
4283
  stats_reporter.report_duration(
3688
4284
  iteration_step=self.step,
3689
4285
  event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_duration_us",
3690
4286
  duration_ms=dram_bwd_l1_cnflct_miss_write_duration,
4287
+ enable_tb_metrics=True,
3691
4288
  time_unit="us",
3692
4289
  )
3693
4290
  stats_reporter.report_duration(
3694
4291
  iteration_step=self.step,
3695
4292
  event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_allocate_duration_us",
3696
4293
  duration_ms=dram_bwd_l1_cnflct_miss_write_allocate_duration,
4294
+ enable_tb_metrics=True,
3697
4295
  time_unit="us",
3698
4296
  )
3699
4297
  stats_reporter.report_duration(
3700
4298
  iteration_step=self.step,
3701
4299
  event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_cache_copy_duration_us",
3702
4300
  duration_ms=dram_bwd_l1_cnflct_miss_write_cache_copy_duration,
4301
+ enable_tb_metrics=True,
3703
4302
  time_unit="us",
3704
4303
  )
3705
4304
  stats_reporter.report_duration(
3706
4305
  iteration_step=self.step,
3707
4306
  event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_lookup_cache_duration_us",
3708
4307
  duration_ms=dram_bwd_l1_cnflct_miss_write_lookup_cache_duration,
4308
+ enable_tb_metrics=True,
3709
4309
  time_unit="us",
3710
4310
  )
3711
4311
  stats_reporter.report_duration(
3712
4312
  iteration_step=self.step,
3713
4313
  event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_acquire_lock_duration_us",
3714
4314
  duration_ms=dram_bwd_l1_cnflct_miss_write_acquire_lock_duration,
4315
+ enable_tb_metrics=True,
3715
4316
  time_unit="us",
3716
4317
  )
3717
4318
  stats_reporter.report_data_amount(
3718
4319
  iteration_step=self.step,
3719
4320
  event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load",
3720
4321
  data_bytes=dram_bwd_l1_cnflct_miss_write_missing_load,
4322
+ enable_tb_metrics=True,
4323
+ )
4324
+
4325
+ stats_reporter.report_data_amount(
4326
+ iteration_step=self.step,
4327
+ event_name="dram_kv.perf.get.dram_kv_read_counts",
4328
+ data_bytes=dram_kv_read_counts,
4329
+ enable_tb_metrics=True,
3721
4330
  )
3722
4331
 
3723
4332
  stats_reporter.report_data_amount(
3724
4333
  iteration_step=self.step,
3725
4334
  event_name=self.dram_kv_allocated_bytes_stats_name,
3726
4335
  data_bytes=dram_kv_allocated_bytes,
4336
+ enable_tb_metrics=True,
3727
4337
  )
3728
4338
  stats_reporter.report_data_amount(
3729
4339
  iteration_step=self.step,
3730
4340
  event_name=self.dram_kv_actual_used_chunk_bytes_stats_name,
3731
4341
  data_bytes=dram_kv_actual_used_chunk_bytes,
4342
+ enable_tb_metrics=True,
4343
+ )
4344
+ stats_reporter.report_data_amount(
4345
+ iteration_step=self.step,
4346
+ event_name=self.dram_kv_mem_num_rows_stats_name,
4347
+ data_bytes=dram_kv_num_rows,
4348
+ enable_tb_metrics=True,
4349
+ )
4350
+ stats_reporter.report_duration(
4351
+ iteration_step=self.step,
4352
+ event_name="dram_kv.perf.set.dram_eviction_score_write_sharding_total_duration_us",
4353
+ duration_ms=dram_metadata_write_sharding_total_duration,
4354
+ enable_tb_metrics=True,
4355
+ time_unit="us",
4356
+ )
4357
+ stats_reporter.report_duration(
4358
+ iteration_step=self.step,
4359
+ event_name="dram_kv.perf.set.dram_eviction_score_write_total_duration_us",
4360
+ duration_ms=dram_metadata_write_total_duration,
4361
+ enable_tb_metrics=True,
4362
+ time_unit="us",
4363
+ )
4364
+ stats_reporter.report_duration(
4365
+ iteration_step=self.step,
4366
+ event_name="dram_kv.perf.set.dram_eviction_score_write_allocate_avg_duration_us",
4367
+ duration_ms=dram_metadata_write_allocate_avg_duration,
4368
+ enable_tb_metrics=True,
4369
+ time_unit="us",
4370
+ )
4371
+ stats_reporter.report_duration(
4372
+ iteration_step=self.step,
4373
+ event_name="dram_kv.perf.set.dram_eviction_score_write_lookup_cache_avg_duration_us",
4374
+ duration_ms=dram_metadata_write_lookup_cache_avg_duration,
4375
+ enable_tb_metrics=True,
4376
+ time_unit="us",
4377
+ )
4378
+ stats_reporter.report_duration(
4379
+ iteration_step=self.step,
4380
+ event_name="dram_kv.perf.set.dram_eviction_score_write_acquire_lock_avg_duration_us",
4381
+ duration_ms=dram_metadata_write_acquire_lock_avg_duration,
4382
+ enable_tb_metrics=True,
4383
+ time_unit="us",
4384
+ )
4385
+ stats_reporter.report_data_amount(
4386
+ iteration_step=self.step,
4387
+ event_name="dram_kv.perf.set.dram_eviction_score_write_cache_miss_avg_count",
4388
+ data_bytes=dram_metadata_write_cache_miss_avg_count,
4389
+ enable_tb_metrics=True,
4390
+ )
4391
+ stats_reporter.report_duration(
4392
+ iteration_step=self.step,
4393
+ event_name="dram_kv.perf.get.dram_eviction_score_read_total_duration_us",
4394
+ duration_ms=dram_read_metadata_total_duration,
4395
+ enable_tb_metrics=True,
4396
+ time_unit="us",
4397
+ )
4398
+ stats_reporter.report_duration(
4399
+ iteration_step=self.step,
4400
+ event_name="dram_kv.perf.get.dram_eviction_score_read_sharding_total_duration_us",
4401
+ duration_ms=dram_read_metadata_sharding_total_duration,
4402
+ enable_tb_metrics=True,
4403
+ time_unit="us",
4404
+ )
4405
+ stats_reporter.report_duration(
4406
+ iteration_step=self.step,
4407
+ event_name="dram_kv.perf.get.dram_eviction_score_read_cache_hit_copy_avg_duration_us",
4408
+ duration_ms=dram_read_metadata_cache_hit_copy_avg_duration,
4409
+ enable_tb_metrics=True,
4410
+ time_unit="us",
4411
+ )
4412
+ stats_reporter.report_duration(
4413
+ iteration_step=self.step,
4414
+ event_name="dram_kv.perf.get.dram_eviction_score_read_lookup_cache_total_avg_duration_us",
4415
+ duration_ms=dram_read_metadata_lookup_cache_total_avg_duration,
4416
+ enable_tb_metrics=True,
4417
+ time_unit="us",
4418
+ )
4419
+ stats_reporter.report_duration(
4420
+ iteration_step=self.step,
4421
+ event_name="dram_kv.perf.get.dram_eviction_score_read_acquire_lock_avg_duration_us",
4422
+ duration_ms=dram_read_metadata_acquire_lock_avg_duration,
4423
+ enable_tb_metrics=True,
4424
+ time_unit="us",
4425
+ )
4426
+ stats_reporter.report_data_amount(
4427
+ iteration_step=self.step,
4428
+ event_name="dram_kv.perf.get.dram_eviction_score_read_load_size",
4429
+ data_bytes=dram_read_read_metadata_load_size,
4430
+ enable_tb_metrics=True,
3732
4431
  )
3733
4432
 
3734
4433
  def _recording_to_timer(
@@ -3749,7 +4448,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3749
4448
 
3750
4449
  def fetch_from_l1_sp_w_row_ids(
3751
4450
  self, row_ids: torch.Tensor, only_get_optimizer_states: bool = False
3752
- ) -> Tuple[torch.Tensor, torch.Tensor]:
4451
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
3753
4452
  """
3754
4453
  Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
3755
4454
  @return: updated_weights/optimizer_states, mask of which rows are filled
@@ -3762,36 +4461,38 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3762
4461
  # NOTE: Remove this once there is support for fetching multiple
3763
4462
  # optimizer states in fetch_from_l1_sp_w_row_ids
3764
4463
  if only_get_optimizer_states and self.optimizer not in [
3765
- OptimType.EXACT_ROWWISE_ADAGRAD
4464
+ OptimType.EXACT_ROWWISE_ADAGRAD,
4465
+ OptimType.PARTIAL_ROWWISE_ADAM,
3766
4466
  ]:
3767
4467
  raise RuntimeError(
3768
4468
  f"Fetching optimizer states using fetch_from_l1_sp_w_row_ids() is not yet supported for {self.optimizer}"
3769
4469
  )
3770
4470
 
3771
- with torch.no_grad():
3772
- weights_dtype = self.weights_precision.as_dtype()
3773
- step = self.step
3774
-
3775
- if only_get_optimizer_states:
3776
- start_pos = pad4(self.max_D)
3777
- # NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids working
3778
- # until it is upgraded to support optimizers with multiple states
3779
- # and dtypes
3780
- row_dim = int(
3781
- math.ceil(torch.float32.itemsize / weights_dtype.itemsize)
4471
+ def split_results_by_opt_states(
4472
+ updated_weights: torch.Tensor, cache_location_mask: torch.Tensor
4473
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
4474
+ if not only_get_optimizer_states:
4475
+ return [updated_weights], cache_location_mask
4476
+ # TODO: support mixed dimension case
4477
+ # currently only supports tables with the same max_D dimension
4478
+ opt_to_dim = self.optimizer.byte_offsets_along_row(
4479
+ self.max_D, self.weights_precision, self.optimizer_state_dtypes
4480
+ )
4481
+ updated_opt_states = []
4482
+ for opt_name, dim in opt_to_dim.items():
4483
+ opt_dtype = self.optimizer._extract_dtype(
4484
+ self.optimizer_state_dtypes, opt_name
3782
4485
  )
3783
- result_dtype = torch.float32
3784
- result_dim = int(
3785
- ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))
4486
+ updated_opt_states.append(
4487
+ updated_weights.view(dtype=torch.uint8)[:, dim[0] : dim[1]].view(
4488
+ dtype=opt_dtype
4489
+ )
3786
4490
  )
4491
+ return updated_opt_states, cache_location_mask
3787
4492
 
3788
- else:
3789
- start_pos = 0
3790
- # get the whole row
3791
- row_dim = self.cache_row_dim
3792
- result_dim = row_dim
3793
- result_dtype = weights_dtype
3794
-
4493
+ with torch.no_grad():
4494
+ weights_dtype = self.weights_precision.as_dtype()
4495
+ step = self.step
3795
4496
  with record_function(f"## fetch_from_l1_{step}_{self.tbe_unique_id} ##"):
3796
4497
  lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
3797
4498
  row_ids,
@@ -3800,17 +4501,23 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3800
4501
  )
3801
4502
  updated_weights = torch.empty(
3802
4503
  row_ids.numel(),
3803
- result_dim,
4504
+ self.cache_row_dim,
3804
4505
  device=self.current_device,
3805
- dtype=result_dtype,
4506
+ dtype=weights_dtype,
3806
4507
  )
3807
4508
 
3808
4509
  # D2D copy cache
3809
4510
  cache_location_mask = lxu_cache_locations >= 0
3810
- updated_weights[cache_location_mask] = self.lxu_cache_weights[
3811
- lxu_cache_locations[cache_location_mask],
3812
- start_pos : start_pos + row_dim,
3813
- ].view(result_dtype)
4511
+ torch.ops.fbgemm.masked_index_select(
4512
+ updated_weights,
4513
+ lxu_cache_locations,
4514
+ self.lxu_cache_weights,
4515
+ torch.tensor(
4516
+ [row_ids.numel()],
4517
+ device=self.current_device,
4518
+ dtype=torch.int32,
4519
+ ),
4520
+ )
3814
4521
 
3815
4522
  with record_function(f"## fetch_from_sp_{step}_{self.tbe_unique_id} ##"):
3816
4523
  if len(self.ssd_scratch_pad_eviction_data) > 0:
@@ -3821,7 +4528,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3821
4528
  actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
3822
4529
  if actions_count_gpu.item() == 0:
3823
4530
  # no action to take
3824
- return (updated_weights, cache_location_mask)
4531
+ return split_results_by_opt_states(
4532
+ updated_weights, cache_location_mask
4533
+ )
3825
4534
 
3826
4535
  sp_idx = sp_idx[:actions_count_gpu]
3827
4536
 
@@ -3872,16 +4581,23 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3872
4581
  )
3873
4582
 
3874
4583
  # D2D copy SP
3875
- updated_weights[exact_match_mask] = sp[
3876
- sp_locations_found, start_pos : start_pos + row_dim
3877
- ].view(result_dtype)
4584
+ torch.ops.fbgemm.masked_index_select(
4585
+ updated_weights,
4586
+ sp_locations_in_updated_weights,
4587
+ sp,
4588
+ torch.tensor(
4589
+ [row_ids.numel()],
4590
+ device=self.current_device,
4591
+ dtype=torch.int32,
4592
+ ),
4593
+ )
3878
4594
  # cache_location_mask is the mask of rows in L1
3879
4595
  # exact_match_mask is the mask of rows in SP
3880
4596
  cache_location_mask = torch.logical_or(
3881
4597
  cache_location_mask, exact_match_mask
3882
4598
  )
3883
4599
 
3884
- return (updated_weights, cache_location_mask)
4600
+ return split_results_by_opt_states(updated_weights, cache_location_mask)
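split_results_by_opt_states slices the packed row buffer by byte range and reinterprets each slice in the state's own dtype; a standalone sketch of that uint8-view trick, assuming FP16 rows of 4 elements that carry one FP32 state at bytes [4, 8):

import torch

rows = torch.zeros(3, 4, dtype=torch.float16)   # hypothetical packed rows, 8 bytes each
byte_start, byte_end = 4, 8                      # hypothetical byte range of one optimizer state

state = rows.view(dtype=torch.uint8)[:, byte_start:byte_end].view(dtype=torch.float32)
print(state.shape)  # torch.Size([3, 1])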
3885
4601
 
3886
4602
  def register_backward_hook_before_eviction(
3887
4603
  self, backward_hook: Callable[[torch.Tensor], None]
@@ -3901,3 +4617,312 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
3901
4617
  self.placeholder_autograd_tensor.register_hook(backward_hook)
3902
4618
  for hook in hooks:
3903
4619
  self.placeholder_autograd_tensor.register_hook(hook)
4620
+
4621
+ def set_local_weight_counts_for_table(
4622
+ self, table_idx: int, weight_count: int
4623
+ ) -> None:
4624
+ self.local_weight_counts[table_idx] = weight_count
4625
+
4626
+ def set_global_id_per_rank_for_table(
4627
+ self, table_idx: int, global_id: torch.Tensor
4628
+ ) -> None:
4629
+ self.global_id_per_rank[table_idx] = global_id
4630
+
4631
+ def direct_write_embedding(
4632
+ self,
4633
+ indices: torch.Tensor,
4634
+ offsets: torch.Tensor,
4635
+ weights: torch.Tensor,
4636
+ ) -> None:
4637
+ """
4638
+ Directly write the weights to L1, SP and backend without relying on auto-gradient for embedding cache.
4639
+ Please refer to design doc for more details: https://docs.google.com/document/d/1TJHKvO1m3-5tYAKZGhacXnGk7iCNAzz7wQlrFbX_LDI/edit?tab=t.0
4640
+ """
4641
+ assert (
4642
+ self._embedding_cache_mode
4643
+ ), "Must be in embedding_cache_mode to support direct_write_embedding method."
4644
+
4645
+ B_offsets = None
4646
+ max_B = -1
4647
+
4648
+ with torch.no_grad():
4649
+ # Wait for any ongoing prefetch operations to complete before starting direct_write
4650
+ current_stream = torch.cuda.current_stream()
4651
+ current_stream.wait_event(self.prefetch_complete_event)
4652
+
4653
+ # Create local step events for internal sequential execution
4654
+ weights_dtype = self.weights_precision.as_dtype()
4655
+ assert (
4656
+ weights_dtype == weights.dtype
4657
+ ), f"Expected embedding table dtype {weights_dtype} is same with input weight dtype, but got {weights.dtype}"
4658
+
4659
+ # Pad the weights to match self.max_D width if necessary
4660
+ if weights.size(1) < self.cache_row_dim:
4661
+ weights = torch.nn.functional.pad(
4662
+ weights, (0, self.cache_row_dim - weights.size(1))
4663
+ )
4664
+
4665
+ step = self.step
4666
+
4667
+ # step 0: run backward hook for prefetch if prefetch pipeline is enabled before writing to L1 and SP
4668
+ if self.prefetch_pipeline:
4669
+ self._update_cache_counter_and_pointers(nn.Module(), torch.empty(0))
4670
+
4671
+ # step 1: lookup and write to l1 cache
4672
+ with record_function(
4673
+ f"## direct_write_to_l1_{step}_{self.tbe_unique_id} ##"
4674
+ ):
4675
+ if self.gather_ssd_cache_stats:
4676
+ self.local_ssd_cache_stats.zero_()
4677
+
4678
+ # Linearize indices
4679
+ linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
4680
+ self.hash_size_cumsum,
4681
+ indices,
4682
+ offsets,
4683
+ B_offsets,
4684
+ max_B,
4685
+ )
4686
+
4687
+ lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
4688
+ linear_cache_indices,
4689
+ self.lxu_cache_state,
4690
+ self.total_hash_size,
4691
+ )
4692
+ cache_location_mask = lxu_cache_locations >= 0
4693
+
4694
+ # Get the cache locations for the row_ids that are already in the cache
4695
+ cache_locations = lxu_cache_locations[cache_location_mask]
4696
+
4697
+ # Get the corresponding input weights for these row_ids
4698
+ cache_weights = weights[cache_location_mask]
4699
+
4700
+ # Update the cache with these input weights
4701
+ if cache_locations.numel() > 0:
4702
+ self.lxu_cache_weights.index_put_(
4703
+ (cache_locations,), cache_weights, accumulate=False
4704
+ )
4705
+
4706
+ # Record completion of step 1
4707
+ current_stream.record_event(self.direct_write_l1_complete_event)
4708
+
4709
+ # step 2: pop the current scratch pad and write to next batch scratch pad if exists
4710
+ # Wait for step 1 to complete
4711
+ with record_function(
4712
+ f"## direct_write_to_sp_{step}_{self.tbe_unique_id} ##"
4713
+ ):
4714
+ if len(self.ssd_scratch_pad_eviction_data) > 0:
4715
+ self.ssd_scratch_pad_eviction_data.pop(0)
4716
+ if len(self.ssd_scratch_pad_eviction_data) > 0:
4717
+ # Wait for any pending backend reads to the next scratch pad
4718
+ # to complete before we write to it. Otherwise, stale backend data
4719
+ # will overwrite our direct_write updates.
4720
+ # The ssd_event_get marks completion of backend fetch operations.
4721
+ current_stream.wait_event(self.ssd_event_get)
4722
+
4723
+ # if scratch pad exists, write to next batch scratch pad
4724
+ sp = self.ssd_scratch_pad_eviction_data[0][0]
4725
+ sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
4726
+ self.current_device
4727
+ )
4728
+ actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
4729
+ if actions_count_gpu.item() != 0:
4730
+ # when no actional_count_gpu, no need to write to SP
4731
+ sp_idx = sp_idx[:actions_count_gpu]
4732
+
4733
+ # -1 in lxu_cache_locations means the row is not in L1 cache and in SP
4734
+ # fill the row_ids in L1 with -2, >0 values means in SP or backend
4735
+ # @eg. updated_indices_in_sp= [1, 100, 1, 2, -2, 3, 4, 5, 10]
4736
+ updated_indices_in_sp = linear_cache_indices.masked_fill(
4737
+ lxu_cache_locations != -1, -2
4738
+ )
4739
+ # sort the sp_idx for binary search
4740
+ # should already be sorted
4741
+ # sp_idx_inverse_indices is the indices before sorting which is same to the location in SP.
4742
+ # @eg. sp_idx = [4, 2, 1, 3, 10]
4743
+ # @eg sorted_sp_idx = [ 1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
4744
+ sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
4745
+ # search rows id in sp against the SP indexes to find location of the rows in SP
4746
+ # @eg: updated_indices_in_sp = [0, 5, 0, 1, 0, 2, 3, 4, 4]
4747
+ # @eg: 5 is OOB
4748
+ updated_indices_in_sp_idx = torch.searchsorted(
4749
+ sorted_sp_idx, updated_indices_in_sp
4750
+ )
4751
+ # does not found in SP will Out of Bound
4752
+ oob_sp_idx = updated_indices_in_sp_idx >= sp_idx.numel()
4753
+ # make the oob items in bound
4754
+ # @eg updated_indices_in_sp=[0, 0, 0, 1, 0, 2, 3, 4, 4]
4755
+ updated_indices_in_sp_idx[oob_sp_idx] = 0
4756
+
4757
+ # torch.searchsorted is not exact match,
4758
+ # we only take exact matched rows, where the id is found in SP.
4759
+ # @eg 5 in updated_indices_in_sp is not in sp_idx, but has 4 in updated_indices_in_sp
4760
+ # @eg sorted_sp_idx[updated_indices_in_sp]=[ 1, 1, 1, 2, 1, 3, 4, 10, 10]
4761
+ # @eg exact_match_mask=[ True, False, True, True, False, True, True, False, True]
4762
+ exact_match_mask = (
4763
+ sorted_sp_idx[updated_indices_in_sp_idx]
4764
+ == updated_indices_in_sp
4765
+ )
4766
+ # Get the location of the row ids found in SP.
4767
+ # @eg: sp_locations_found=[2, 2, 1, 3, 0, 4]
4768
+ sp_locations_found = sp_idx_inverse_indices[
4769
+ updated_indices_in_sp[exact_match_mask]
4770
+ ]
4771
+ # Get the corresponding weights for the matched indices
4772
+ matched_weights = weights[exact_match_mask]
4773
+
4774
+ # Write the weights to the sparse tensor at the found locations
4775
+ if sp_locations_found.numel() > 0:
4776
+ sp.index_put_(
4777
+ (sp_locations_found,),
4778
+ matched_weights,
4779
+ accumulate=False,
4780
+ )
4781
+ current_stream.record_event(self.direct_write_sp_complete_event)
4782
+
4783
+ # step 3: write l1 cache missing rows to backend
4784
+ # Wait for step 2 to complete
4785
+ with record_function(
4786
+ f"## direct_write_to_backend_{step}_{self.tbe_unique_id} ##"
4787
+ ):
4788
+ # Use the existing ssd_eviction_stream for all backend write operations
4789
+ # This stream is already created with low priority during initialization
4790
+ with torch.cuda.stream(self.ssd_eviction_stream):
4791
+ # Create a mask for indices not in L1 cache
4792
+ non_cache_mask = ~cache_location_mask
4793
+
4794
+ # Calculate the count of valid indices (those not in L1 cache)
4795
+ valid_count = non_cache_mask.sum().to(torch.int64).cpu()
4796
+
4797
+ if valid_count.item() > 0:
4798
+ # Extract only the indices and weights that are not in L1 cache
4799
+ non_cache_indices = linear_cache_indices[non_cache_mask]
4800
+ non_cache_weights = weights[non_cache_mask]
4801
+
4802
+ # Move tensors to CPU for set_cuda
4803
+ cpu_indices = non_cache_indices.cpu()
4804
+ cpu_weights = non_cache_weights.cpu()
4805
+
4806
+ # Write to backend - only sending the non-cache indices and weights
4807
+ self.record_function_via_dummy_profile(
4808
+ f"## ssd_write_{step}_set_cuda_{self.tbe_unique_id} ##",
4809
+ self.ssd_db.set_cuda,
4810
+ cpu_indices,
4811
+ cpu_weights,
4812
+ valid_count,
4813
+ self.timestep,
4814
+ is_bwd=False,
4815
+ )
4816
+
4817
+ # Return control to the main stream without waiting for the backend operation to complete
4818
+
4819
+ def get_free_cpu_memory_gb(self) -> float:
4820
+ def _get_mem_available() -> float:
4821
+ if sys.platform.startswith("linux"):
4822
+ info = {}
4823
+ with open("/proc/meminfo") as f:
4824
+ for line in f:
4825
+ p = line.split()
4826
+ info[p[0].strip(":").lower()] = int(p[1]) * 1024
4827
+ if "memavailable" in info:
4828
+ # Linux >= 3.14
4829
+ return info["memavailable"]
4830
+ else:
4831
+ return info["memfree"] + info["cached"]
4832
+ else:
4833
+ raise RuntimeError(
4834
+ "Unsupported platform for free memory eviction, pls use ID count eviction tirgger mode"
4835
+ )
4836
+
4837
+ mem = _get_mem_available()
4838
+ return mem / (1024**3)
4839
+
4840
+ @classmethod
4841
+ def trigger_evict_in_all_tbes(cls) -> None:
4842
+ for tbe in cls._all_tbe_instances:
4843
+ tbe.ssd_db.trigger_feature_evict()
4844
+
4845
+ @classmethod
4846
+ def tbe_has_ongoing_eviction(cls) -> bool:
4847
+ for tbe in cls._all_tbe_instances:
4848
+ if tbe.ssd_db.is_evicting():
4849
+ return True
4850
+ return False
4851
+
4852
+ def set_free_mem_eviction_trigger_config(
4853
+ self, eviction_policy: EvictionPolicy
4854
+ ) -> None:
4855
+ self.enable_free_mem_trigger_eviction = True
4856
+ self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode
4857
+ assert (
4858
+ eviction_policy.eviction_free_mem_check_interval_batch is not None
4859
+ ), "eviction_free_mem_check_interval_batch is unexpected none for free_mem eviction trigger mode"
4860
+ self.eviction_free_mem_check_interval_batch: int = (
4861
+ eviction_policy.eviction_free_mem_check_interval_batch
4862
+ )
4863
+ assert (
4864
+ eviction_policy.eviction_free_mem_threshold_gb is not None
4865
+ ), "eviction_policy.eviction_free_mem_threshold_gb is unexpected none for free_mem eviction trigger mode"
4866
+ self.eviction_free_mem_threshold_gb: int = (
4867
+ eviction_policy.eviction_free_mem_threshold_gb
4868
+ )
4869
+ logging.info(
4870
+ f"[FREE_MEM Eviction] eviction config, trigger model: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}"
4871
+ )
4872
+
4873
+ def may_trigger_eviction(self) -> None:
4874
+ def is_first_tbe() -> bool:
4875
+ first = SSDTableBatchedEmbeddingBags._first_instance_ref
4876
+ return first is not None and first() is self
4877
+
4878
+ # We assume that the eviction time is less than free mem check interval time
4879
+ # So every time we reach this check, all evictions in all tbes should be finished.
4880
+ # We only need to check the first tbe because all tbes share the same free mem,
4881
+ # once the first tbe detect need to trigger eviction, it will call trigger func
4882
+ # in all tbes from _all_tbe_instances
4883
+ if (
4884
+ self.enable_free_mem_trigger_eviction
4885
+ and self.step % self.eviction_free_mem_check_interval_batch == 0
4886
+ and self.training
4887
+ and is_first_tbe()
4888
+ ):
4889
+ if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction():
4890
+ SSDTableBatchedEmbeddingBags._eviction_triggered = False
4891
+
4892
+ free_cpu_mem_gb = self.get_free_cpu_memory_gb()
4893
+ local_evict_trigger = int(
4894
+ free_cpu_mem_gb < self.eviction_free_mem_threshold_gb
4895
+ )
4896
+ tensor_flag = torch.tensor(
4897
+ local_evict_trigger,
4898
+ device=self.current_device,
4899
+ dtype=torch.int,
4900
+ )
4901
+ world_size = dist.get_world_size(self._pg)
4902
+ if world_size > 1:
4903
+ dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg)
4904
+ global_evict_trigger = tensor_flag.item()
4905
+ else:
4906
+ global_evict_trigger = local_evict_trigger
4907
+ if (
4908
+ global_evict_trigger >= 1
4909
+ and SSDTableBatchedEmbeddingBags._eviction_triggered
4910
+ ):
4911
+ logging.warning(
4912
+ f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true"
4913
+ )
4914
+ if (
4915
+ global_evict_trigger >= 1
4916
+ and not SSDTableBatchedEmbeddingBags._eviction_triggered
4917
+ ):
4918
+ SSDTableBatchedEmbeddingBags._eviction_triggered = True
4919
+ SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes()
4920
+ logging.info(
4921
+ f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction"
4922
+ )
4923
+
4924
+ def reset_inference_mode(self) -> None:
4925
+ """
4926
+ Reset the inference mode
4927
+ """
4928
+ self.eval()