fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
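The bulk of this diff is fbgemm_gpu/tbe/ssd/training.py, shown below. As a rough, hypothetical sketch of the constructor surface after this change (class, enum, and argument names are taken from the hunks that follow; the import path is inferred from the file location, and the table shapes and values are illustrative only — the module itself still targets CUDA devices at runtime):

import torch

from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import BackendType, PoolingMode
from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags

# Two tables of (rows, dims); the numbers here are made up for illustration.
tbe = SSDTableBatchedEmbeddingBags(
    embedding_specs=[(1_000_000, 128), (500_000, 64)],
    feature_table_map=[0, 1],
    cache_sets=4096,
    ssd_storage_directory="/tmp/ssd_tbe",  # may now be a comma-separated list of paths
    ssd_rocksdb_shards=8,
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    weights_precision=SparseType.FP16,
    output_dtype=SparseType.FP32,
    pooling_mode=PoolingMode.SUM,
    backend_type=BackendType.SSD,          # new in this release: SSD, PS, or DRAM
    kv_zch_params=None,                    # new: zero-collision (KV ZCH) configuration
    enable_raw_embedding_streaming=False,  # new: raw embedding streaming toggle
    learning_rate=0.01,
)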
@@ -12,13 +12,15 @@ import contextlib
  import functools
  import itertools
  import logging
+ import math
  import os
- import tempfile
  import threading
  import time
+ from functools import cached_property
  from math import floor, log2
- from typing import Any, Callable, List, Optional, Tuple, Type, Union
+ from typing import Any, Callable, ClassVar, Optional, Union
  import torch # usort:skip
+ import weakref

  # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
  import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -29,9 +31,13 @@ from fbgemm_gpu.runtime_monitor import (
  )
  from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
  from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+ BackendType,
  BoundsCheckMode,
  CacheAlgorithm,
  EmbeddingLocation,
+ EvictionPolicy,
+ get_bounds_check_version_for_platform,
+ KVZCHParams,
  PoolingMode,
  SplitState,
  )
@@ -39,21 +45,23 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
  apply_split_helper,
  CounterBasedRegularizationDefinition,
  CowClipDefinition,
+ RESParams,
  UVMCacheStatsIndex,
  WeightDecayMode,
  )
  from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+ check_allocated_vbe_output,
  generate_vbe_metadata,
+ is_torchdynamo_compiling,
  )
-
  from torch import distributed as dist, nn, Tensor # usort:skip
+ import sys
  from dataclasses import dataclass

  from torch.autograd.profiler import record_function

  from ..cache import get_unique_indices_v2
-
- from .common import ASSOC
+ from .common import ASSOC, pad4, tensor_pad4
  from .utils.partially_materialized_tensor import PartiallyMaterializedTensor


@@ -69,6 +77,14 @@ class IterData:
  max_B: Optional[int] = -1


+ @dataclass
+ class KVZCHCachedData:
+ cached_optimizer_states_per_table: list[list[torch.Tensor]]
+ cached_weight_tensor_per_table: list[torch.Tensor]
+ cached_id_tensor_per_table: list[torch.Tensor]
+ cached_bucket_splits: list[torch.Tensor]
+
+
  class SSDTableBatchedEmbeddingBags(nn.Module):
  D_offsets: Tensor
  lxu_cache_weights: Tensor
@@ -86,12 +102,19 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  weights_placements: Tensor
  weights_offsets: Tensor
  _local_instance_index: int = -1
+ res_params: RESParams
+ table_names: list[str]
+ _all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet()
+ _first_instance_ref: ClassVar[weakref.ref] = None
+ _eviction_triggered: ClassVar[bool] = False

  def __init__(
  self,
- embedding_specs: List[Tuple[int, int]], # tuple of (rows, dims)
- feature_table_map: Optional[List[int]], # [T]
+ embedding_specs: list[tuple[int, int]], # tuple of (rows, dims)
+ feature_table_map: Optional[list[int]], # [T]
  cache_sets: int,
+ # A comma-separated string, e.g. "/data00_nvidia0,/data01_nvidia0/", db shards
+ # will be placed in these paths round-robin.
  ssd_storage_directory: str,
  ssd_rocksdb_shards: int = 1,
  ssd_memtable_flush_period: int = -1,
@@ -131,13 +154,16 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  pooling_mode: PoolingMode = PoolingMode.SUM,
  bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
  # Parameter Server Configs
- ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
+ ps_hosts: Optional[tuple[tuple[str, int]]] = None,
  ps_max_key_per_request: Optional[int] = None,
  ps_client_thread_num: Optional[int] = None,
  ps_max_local_index_length: Optional[int] = None,
  tbe_unique_id: int = -1,
- # in local test we need to use the pass in path for rocksdb creation
- # in production we need to do it inside SSD mount path which will ignores the passed in path
+ # If set to True, will use `ssd_storage_directory` as the ssd paths.
+ # If set to False, will use the default ssd paths.
+ # In local tests we need to use the passed-in path for rocksdb creation.
+ # In production we could either use the default ssd mount points or explicitly specify ssd
+ # mount points using `ssd_storage_directory`.
  use_passed_in_path: int = True,
  gather_ssd_cache_stats: Optional[bool] = False,
  stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
@@ -152,19 +178,126 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  # number of rows will be decided by bulk_init_chunk_size / size_of_each_row
  bulk_init_chunk_size: int = 0,
  lazy_bulk_init_enabled: bool = False,
+ backend_type: BackendType = BackendType.SSD,
+ kv_zch_params: Optional[KVZCHParams] = None,
+ enable_raw_embedding_streaming: bool = False, # whether enable raw embedding streaming
+ res_params: Optional[RESParams] = None, # raw embedding streaming sharding info
+ flushing_block_size: int = 2_000_000_000, # 2GB
+ table_names: Optional[list[str]] = None,
+ use_rowwise_bias_correction: bool = False, # For Adam use
+ optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006
+ pg: Optional[dist.ProcessGroup] = None,
  ) -> None:
  super(SSDTableBatchedEmbeddingBags, self).__init__()

+ # Set the optimizer
+ assert optimizer in (
+ OptimType.EXACT_ROWWISE_ADAGRAD,
+ OptimType.PARTIAL_ROWWISE_ADAM,
+ OptimType.ADAM,
+ ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
+ self.optimizer = optimizer
+
+ # Set the table weight and output dtypes
+ assert weights_precision in (SparseType.FP32, SparseType.FP16)
+ self.weights_precision = weights_precision
+ self.output_dtype: int = output_dtype.as_int()
+
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
+ # Adagrad currently only supports FP32 for momentum1
+ self.optimizer_state_dtypes: dict[str, SparseType] = {
+ "momentum1": SparseType.FP32,
+ }
+ else:
+ self.optimizer_state_dtypes: dict[str, SparseType] = optimizer_state_dtypes
+
+ # Zero collision TBE configurations
+ self.kv_zch_params = kv_zch_params
+ self.backend_type = backend_type
+ self.enable_optimizer_offloading: bool = False
+ self.backend_return_whole_row: bool = False
+ self._embedding_cache_mode: bool = False
+ self.load_ckpt_without_opt: bool = False
+ if self.kv_zch_params:
+ self.kv_zch_params.validate()
+ self.load_ckpt_without_opt = (
+ # pyre-ignore [16]
+ self.kv_zch_params.load_ckpt_without_opt
+ )
+ self.enable_optimizer_offloading = (
+ # pyre-ignore [16]
+ self.kv_zch_params.enable_optimizer_offloading
+ )
+ self.backend_return_whole_row = (
+ # pyre-ignore [16]
+ self.kv_zch_params.backend_return_whole_row
+ )
+
+ if self.enable_optimizer_offloading:
+ logging.info("Optimizer state offloading is enabled")
+ if self.backend_return_whole_row:
+ assert (
+ self.backend_type == BackendType.DRAM
+ ), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
+ logging.info(
+ "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
+ )
+ # pyre-ignore [16]
+ self._embedding_cache_mode = self.kv_zch_params.embedding_cache_mode
+ if self._embedding_cache_mode:
+ logging.info("KVZCH is in embedding_cache_mode")
+ assert self.optimizer in [
+ OptimType.EXACT_ROWWISE_ADAGRAD
+ ], f"only EXACT_ROWWISE_ADAGRAD supports embedding cache mode, but got {self.optimizer}"
+ if self.load_ckpt_without_opt:
+ if (
+ # pyre-ignore [16]
+ self.kv_zch_params.optimizer_type_for_st
+ == OptimType.PARTIAL_ROWWISE_ADAM.value
+ ):
+ self.optimizer = OptimType.PARTIAL_ROWWISE_ADAM
+ logging.info(
+ f"Override optimizer type with {self.optimizer=} for st publish"
+ )
+ if (
+ # pyre-ignore [16]
+ self.kv_zch_params.optimizer_state_dtypes_for_st
+ is not None
+ ):
+ optimizer_state_dtypes = {}
+ for k, v in dict(
+ self.kv_zch_params.optimizer_state_dtypes_for_st
+ ).items():
+ optimizer_state_dtypes[k] = SparseType.from_int(v)
+ self.optimizer_state_dtypes = optimizer_state_dtypes
+ logging.info(
+ f"Override optimizer_state_dtypes with {self.optimizer_state_dtypes=} for st publish"
+ )
+
  self.pooling_mode = pooling_mode
  self.bounds_check_mode_int: int = bounds_check_mode.value
  self.embedding_specs = embedding_specs
- (rows, dims) = zip(*embedding_specs)
+ self.table_names = table_names if table_names is not None else []
+ rows, dims = zip(*embedding_specs)
  T_ = len(self.embedding_specs)
  assert T_ > 0
  # pyre-fixme[8]: Attribute has type `device`; used as `int`.
  self.current_device: torch.device = torch.cuda.current_device()

- self.feature_table_map: List[int] = (
+ self.enable_raw_embedding_streaming = enable_raw_embedding_streaming
+ # initialize the raw embedding streaming related variables
+ self.res_params: RESParams = res_params or RESParams()
+ if self.enable_raw_embedding_streaming:
+ self.res_params.table_sizes = [0] + list(itertools.accumulate(rows))
+ res_port_from_env = os.getenv("LOCAL_RES_PORT")
+ self.res_params.res_server_port = (
+ int(res_port_from_env) if res_port_from_env else 0
+ )
+ logging.info(
+ f"get env {self.res_params.res_server_port=}, at rank {dist.get_rank()}, with {self.res_params=}"
+ )
+
+ self.feature_table_map: list[int] = (
  feature_table_map if feature_table_map is not None else list(range(T_))
  )
  T = len(self.feature_table_map)
@@ -177,7 +310,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  feature_dims = [dims[t] for t in self.feature_table_map]
  D_offsets = [dims[t] for t in self.feature_table_map]
  D_offsets = [0] + list(itertools.accumulate(D_offsets))
+
+ # Sum of row length of all tables
  self.total_D: int = D_offsets[-1]
+
+ # Max number of elements required to store a row in the cache
  self.max_D: int = max(dims)
  self.register_buffer(
  "D_offsets",
@@ -189,6 +326,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  self.total_hash_size_bits: int = 0
  else:
  self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
+ self.register_buffer(
+ "table_hash_size_cumsum",
+ torch.tensor(
+ hash_size_cumsum, device=self.current_device, dtype=torch.int64
+ ),
+ )
  # The last element is to easily access # of rows of each table by
  self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
  self.total_hash_size: int = hash_size_cumsum[-1]
@@ -229,13 +372,25 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  "feature_dims",
  torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
  )
+ self.register_buffer(
+ "table_dims",
+ torch.tensor(dims, device="cpu", dtype=torch.int64),
+ )
+
+ info_B_num_bits_, info_B_mask_ = torch.ops.fbgemm.get_infos_metadata(
+ self.D_offsets, # unused tensor
+ 1, # max_B
+ T, # T
+ )
+ self.info_B_num_bits: int = info_B_num_bits_
+ self.info_B_mask: int = info_B_mask_

  assert cache_sets > 0
  element_size = weights_precision.bit_rate() // 8
  assert (
  element_size == 4 or element_size == 2
  ), f"Invalid element size {element_size}"
- cache_size = cache_sets * ASSOC * element_size * self.max_D
+ cache_size = cache_sets * ASSOC * element_size * self.cache_row_dim
  logging.info(
  f"Using cache for SSD with admission algorithm "
  f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_rocksdb_shards} shards, "
@@ -243,10 +398,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  f"Memtable Flush Period: {ssd_memtable_flush_period}, "
  f"Memtable Flush Offset: {ssd_memtable_flush_offset}, "
  f"Desired L0 files per compaction: {ssd_l0_files_per_compact}, "
- f"{cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
+ f"Cache size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
  f"weights precision: {weights_precision}, "
  f"output dtype: {output_dtype}, "
- f"chunk size in bulk init: {bulk_init_chunk_size} bytes"
+ f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
+ f"kv_zch_params: {kv_zch_params}, "
+ f"embedding spec: {embedding_specs}"
  )
  self.register_buffer(
  "lxu_cache_state",
@@ -262,6 +419,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  )

  self.step = 0
+ self.last_flush_step = -1

  # Set prefetch pipeline
  self.prefetch_pipeline: bool = prefetch_pipeline
@@ -291,10 +449,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  EmbeddingLocation.DEVICE,
  )

- assert weights_precision in (SparseType.FP32, SparseType.FP16)
- self.weights_precision = weights_precision
- self.output_dtype: int = output_dtype.as_int()
-
  cache_dtype = weights_precision.as_dtype()
  if ssd_cache_location == EmbeddingLocation.MANAGED:
  self.register_buffer(
@@ -305,7 +459,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  device=self.current_device,
  dtype=cache_dtype,
  ),
- [cache_sets * ASSOC, self.max_D],
+ [cache_sets * ASSOC, self.cache_row_dim],
  is_host_mapped=self.uvm_host_mapped,
  ),
  )
@@ -314,7 +468,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  "lxu_cache_weights",
  torch.zeros(
  cache_sets * ASSOC,
- self.max_D,
+ self.cache_row_dim,
  device=self.current_device,
  dtype=cache_dtype,
  ),
@@ -387,6 +541,15 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

  self.timestep = 0

+ # Store the iteration number on GPU and CPU (used for certain optimizers)
+ persistent_iter_ = optimizer in (OptimType.PARTIAL_ROWWISE_ADAM,)
+ self.register_buffer(
+ "iter",
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
+ persistent=persistent_iter_,
+ )
+ self.iter_cpu: torch.Tensor = torch.zeros(1, dtype=torch.int64, device="cpu")
+
  # Dummy profile configuration for measuring the SSD get/set time
  # get and set are executed by another thread which (for some reason) is
  # not traceable by PyTorch's Kineto. We workaround this problem by
@@ -405,18 +568,46 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  f"FBGEMM_SSD_TBE_USE_DUMMY_PROFILE is set to {set_dummy_profile}; "
  f"Use dummy profile: {use_dummy_profile}"
  )
- # pyre-ignore[4]
+
  self.record_function_via_dummy_profile: Callable[..., Any] = (
  self.record_function_via_dummy_profile_factory(use_dummy_profile)
  )

- os.makedirs(ssd_storage_directory, exist_ok=True)
+ if use_passed_in_path:
+ ssd_dir_list = ssd_storage_directory.split(",")
+ for ssd_dir in ssd_dir_list:
+ os.makedirs(ssd_dir, exist_ok=True)

- ssd_directory = tempfile.mkdtemp(
- prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
- )
+ ssd_directory = ssd_storage_directory
  # logging.info("DEBUG: weights_precision {}".format(weights_precision))

+ """
+ ##################### for ZCH v.Next loading checkpoints Short Term Solution #######################
+ weight_id tensor is the weight and optimizer keys, to load from checkpoint, weight_id tensor
+ needs to be loaded first, then we can load the weight and optimizer tensors.
+ However, the stateful checkpoint loading does not guarantee the tensor loading order, so we need
+ to cache the weight_id, weight and optimizer tensors until all data are loaded, then we can apply
+ them to backend.
+ Currently, we'll cache the weight_id, weight and optimizer tensors in the KVZCHCachedData class,
+ and apply them to backend when all data are loaded. The downside of this solution is that we'll
+ have to duplicate a whole tensor memory to backend before we can release the python tensor memory,
+ which is not ideal.
+ The longer term solution is to support the caching from the backend side, and allow streaming-based
+ data movement from cached weight and optimizer to key/value format without duplicating one whole
+ tensor's memory.
+ """
+ self._cached_kvzch_data: Optional[KVZCHCachedData] = None
+ # initial embedding rows on this rank per table, this is used for loading checkpoint
+ self.local_weight_counts: list[int] = [0] * T_
+ # groundtruth global id on this rank per table, this is used for loading checkpoint
+ self.global_id_per_rank: list[torch.Tensor] = [torch.zeros(0)] * T_
+ # loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend
+ self.load_state_dict: bool = False
+
+ SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self)
+ if SSDTableBatchedEmbeddingBags._first_instance_ref is None:
+ SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self)
+
  # create tbe unique id using rank index | local tbe idx
  if tbe_unique_id == -1:
  SSDTableBatchedEmbeddingBags._local_instance_index += 1
@@ -432,21 +623,26 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  logging.warning("dist is not initialized, treating as single gpu cases")
  tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
  self.tbe_unique_id = tbe_unique_id
+ self.l2_cache_size = l2_cache_size
  logging.info(f"tbe_unique_id: {tbe_unique_id}")
- if not ps_hosts:
+ self.enable_free_mem_trigger_eviction: bool = False
+ if self.backend_type == BackendType.SSD:
  logging.info(
- f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, enable_async_update:{enable_async_update}"
- f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
- f"memtable_flush_period={ssd_memtable_flush_period},memtable_flush_offset={ssd_memtable_flush_offset},"
- f"l0_files_per_compact={ssd_l0_files_per_compact},max_D={self.max_D},rate_limit_mbps={ssd_rate_limit_mbps},"
- f"size_ratio={ssd_size_ratio},compaction_trigger={ssd_compaction_trigger}, lazy_bulk_init_enabled={lazy_bulk_init_enabled},"
- f"write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size},max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num},"
- f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
- f"row_storage_bitwidth={weights_precision.bit_rate()},block_cache_size_per_tbe={ssd_block_cache_size_per_tbe},"
- f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB"
+ f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
+ f"enable_async_update:{enable_async_update}, passed_in_path={ssd_directory}, "
+ f"num_shards={ssd_rocksdb_shards}, num_threads={ssd_rocksdb_shards}, "
+ f"memtable_flush_period={ssd_memtable_flush_period}, memtable_flush_offset={ssd_memtable_flush_offset}, "
+ f"l0_files_per_compact={ssd_l0_files_per_compact}, max_D={self.max_D}, "
+ f"cache_row_size={self.cache_row_dim}, rate_limit_mbps={ssd_rate_limit_mbps}, "
+ f"size_ratio={ssd_size_ratio}, compaction_trigger={ssd_compaction_trigger}, "
+ f"lazy_bulk_init_enabled={lazy_bulk_init_enabled}, write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size}, "
+ f"max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num}, "
+ f"uniform_init_lower={ssd_uniform_init_lower}, uniform_init_upper={ssd_uniform_init_upper}, "
+ f"row_storage_bitwidth={weights_precision.bit_rate()}, block_cache_size_per_tbe={ssd_block_cache_size_per_tbe}, "
+ f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB, "
+ f"enable_raw_embedding_streaming:{self.enable_raw_embedding_streaming}, flushing_block_size:{flushing_block_size}"
  )
  # pyre-fixme[4]: Attribute must be annotated.
- # pyre-ignore[16]
  self._ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
  ssd_directory,
  ssd_rocksdb_shards,
@@ -454,7 +650,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  ssd_memtable_flush_period,
  ssd_memtable_flush_offset,
  ssd_l0_files_per_compact,
- self.max_D,
+ self.cache_row_dim,
  ssd_rate_limit_mbps,
  ssd_size_ratio,
  ssd_compaction_trigger,
@@ -468,6 +664,24 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  tbe_unique_id,
  l2_cache_size,
  enable_async_update,
+ self.enable_raw_embedding_streaming,
+ self.res_params.res_store_shards,
+ self.res_params.res_server_port,
+ self.res_params.table_names,
+ self.res_params.table_offsets,
+ self.res_params.table_sizes,
+ (
+ tensor_pad4(self.table_dims)
+ if self.enable_optimizer_offloading
+ else None
+ ),
+ (
+ self.table_hash_size_cumsum.cpu()
+ if self.enable_optimizer_offloading
+ else None
+ ),
+ flushing_block_size,
+ self._embedding_cache_mode, # disable_random_init
  )
  if self.bulk_init_chunk_size > 0:
  self.ssd_uniform_init_lower: float = ssd_uniform_init_lower
@@ -476,11 +690,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  self._lazy_initialize_ssd_tbe()
  else:
  self._insert_all_kv()
- else:
- # pyre-fixme[4]: Attribute must be annotated.
- # pyre-ignore[16]
+ elif self.backend_type == BackendType.PS:
  self._ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
- [host[0] for host in ps_hosts],
+ [host[0] for host in ps_hosts], # pyre-ignore
  [host[1] for host in ps_hosts],
  tbe_unique_id,
  (
@@ -491,14 +703,98 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  ps_client_thread_num if ps_client_thread_num is not None else 32,
  ps_max_key_per_request if ps_max_key_per_request is not None else 500,
  l2_cache_size,
- self.max_D,
+ self.cache_row_dim,
+ )
+ elif self.backend_type == BackendType.DRAM:
+ logging.info(
+ f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
+ f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
+ f"max_D={self.max_D},"
+ f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
+ f"row_storage_bitwidth={weights_precision.bit_rate()},"
+ f"self.cache_row_dim={self.cache_row_dim},"
+ f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
+ f"feature_dims={self.feature_dims},"
+ f"hash_size_cumsum={self.hash_size_cumsum},"
+ f"backend_return_whole_row={self.backend_return_whole_row}"
  )
+ table_dims = (
+ tensor_pad4(self.table_dims)
+ if self.enable_optimizer_offloading
+ else None
+ ) # table_dims
+ eviction_config = None
+ if self.kv_zch_params and self.kv_zch_params.eviction_policy:
+ eviction_mem_threshold_gb = (
+ self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
+ if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
+ else self.l2_cache_size
+ )
+ kv_zch_params = self.kv_zch_params
+ eviction_policy = self.kv_zch_params.eviction_policy
+ if eviction_policy.eviction_trigger_mode == 5:
+ # If trigger mode is free_mem(5), populate config
+ self.set_free_mem_eviction_trigger_config(eviction_policy)
+
+ enable_eviction_for_feature_score_eviction_policy = ( # pytorch api in c++ doesn't support vector<bool>, convert to int here, 0: no eviction 1: eviction
+ [
+ int(x)
+ for x in eviction_policy.enable_eviction_for_feature_score_eviction_policy
+ ]
+ if eviction_policy.enable_eviction_for_feature_score_eviction_policy
+ is not None
+ else None
+ )
+ # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
+ eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
+ eviction_policy.eviction_trigger_mode, # eviction_trigger_mode, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
+ eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
+ eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
+ eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
+ eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
+ eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
+ eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
+ eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
+ eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
+ eviction_policy.training_id_keep_count, # training_id_keep_count for each table
+ enable_eviction_for_feature_score_eviction_policy, # no eviction setting for feature score eviction policy
+ eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
+ table_dims.tolist() if table_dims is not None else None,
+ eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
+ eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
+ eviction_policy.interval_for_insufficient_eviction_s,
+ eviction_policy.interval_for_sufficient_eviction_s,
+ eviction_policy.interval_for_feature_statistics_decay_s,
+ )
+ self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
+ self.cache_row_dim,
+ ssd_uniform_init_lower,
+ ssd_uniform_init_upper,
+ eviction_config,
+ ssd_rocksdb_shards, # num_shards
+ ssd_rocksdb_shards, # num_threads
+ weights_precision.bit_rate(), # row_storage_bitwidth
+ table_dims,
+ (
+ self.table_hash_size_cumsum.cpu()
+ if self.enable_optimizer_offloading
+ else None
+ ), # hash_size_cumsum
+ self.backend_return_whole_row, # backend_return_whole_row
+ False, # enable_async_update
+ self._embedding_cache_mode, # disable_random_init
+ )
+ else:
+ raise AssertionError(f"Invalid backend type {self.backend_type}")
+
  # pyre-fixme[20]: Argument `self` expected.
- (low_priority, high_priority) = torch.cuda.Stream.priority_range()
+ low_priority, high_priority = torch.cuda.Stream.priority_range()
  # GPU stream for SSD cache eviction
  self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
- # GPU stream for SSD memory copy
+ # GPU stream for SSD memory copy (also reused for feature score D2H)
  self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
+ # GPU stream for async metadata operation
+ self.feature_score_stream = torch.cuda.Stream(priority=low_priority)

  # SSD get completion event
  self.ssd_event_get = torch.cuda.Event()
@@ -510,26 +806,93 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  self.ssd_event_backward = torch.cuda.Event()
  # SSD get's input copy completion event
  self.ssd_event_get_inputs_cpy = torch.cuda.Event()
+ if self._embedding_cache_mode:
+ # Direct write embedding completion event
+ self.direct_write_l1_complete_event: torch.cuda.streams.Event = (
+ torch.cuda.Event()
+ )
+ self.direct_write_sp_complete_event: torch.cuda.streams.Event = (
+ torch.cuda.Event()
+ )
+ # Prefetch operation completion event
+ self.prefetch_complete_event = torch.cuda.Event()
+
  if self.prefetch_pipeline:
  # SSD scratch pad index queue insert completion event
  self.ssd_event_sp_idxq_insert: torch.cuda.streams.Event = torch.cuda.Event()
  # SSD scratch pad index queue lookup completion event
  self.ssd_event_sp_idxq_lookup: torch.cuda.streams.Event = torch.cuda.Event()

- self.timesteps_prefetched: List[int] = []
+ if self.enable_raw_embedding_streaming:
+ # RES reuse the eviction stream
+ self.ssd_event_cache_streamed: torch.cuda.streams.Event = torch.cuda.Event()
+ self.ssd_event_cache_streaming_synced: torch.cuda.streams.Event = (
+ torch.cuda.Event()
+ )
+ self.ssd_event_cache_streaming_computed: torch.cuda.streams.Event = (
+ torch.cuda.Event()
+ )
+ self.ssd_event_sp_streamed: torch.cuda.streams.Event = torch.cuda.Event()
+
+ # Updated buffers
+ self.register_buffer(
+ "lxu_cache_updated_weights",
+ torch.ops.fbgemm.new_unified_tensor(
+ torch.zeros(
+ 1,
+ device=self.current_device,
+ dtype=cache_dtype,
+ ),
+ self.lxu_cache_weights.shape,
+ is_host_mapped=self.uvm_host_mapped,
+ ),
+ )
+
+ # For storing embedding indices to update to
+ self.register_buffer(
+ "lxu_cache_updated_indices",
+ torch.ops.fbgemm.new_unified_tensor(
+ torch.zeros(
+ 1,
+ device=self.current_device,
+ dtype=torch.long,
+ ),
+ (self.lxu_cache_weights.shape[0],),
+ is_host_mapped=self.uvm_host_mapped,
+ ),
+ )
+
+ # For storing the number of updated rows
+ self.register_buffer(
+ "lxu_cache_updated_count",
+ torch.ops.fbgemm.new_unified_tensor(
+ torch.zeros(
+ 1,
+ device=self.current_device,
+ dtype=torch.int,
+ ),
+ (1,),
+ is_host_mapped=self.uvm_host_mapped,
+ ),
+ )
+
+ # (Indices, Count)
+ self.prefetched_info: list[tuple[Tensor, Tensor]] = []
+
+ self.timesteps_prefetched: list[int] = []
  # TODO: add type annotation
  # pyre-fixme[4]: Attribute must be annotated.
  self.ssd_prefetch_data = []

  # Scratch pad eviction data queue
- self.ssd_scratch_pad_eviction_data: List[
- Tuple[Tensor, Tensor, Tensor, bool]
+ self.ssd_scratch_pad_eviction_data: list[
+ tuple[Tensor, Tensor, Tensor, bool]
  ] = []
- self.ssd_location_update_data: List[Tuple[Tensor, Tensor]] = []
+ self.ssd_location_update_data: list[tuple[Tensor, Tensor]] = []

  if self.prefetch_pipeline:
  # Scratch pad value queue
- self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = []
+ self.ssd_scratch_pads: list[tuple[Tensor, Tensor, Tensor]] = []

  # pyre-ignore[4]
  # Scratch pad index queue
@@ -549,12 +912,15 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  )
  cowclip_regularization = CowClipDefinition()

+ self.learning_rate_tensor: torch.Tensor = torch.tensor(
+ learning_rate, device=torch.device("cpu"), dtype=torch.float32
+ )
+
  self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs(
  stochastic_rounding=stochastic_rounding,
  gradient_clipping=gradient_clipping,
  max_gradient=max_gradient,
  max_norm=max_norm,
- learning_rate=learning_rate,
  eps=eps,
  beta1=beta1,
  beta2=beta2,
@@ -575,7 +941,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
  lower_bound=cowclip_regularization.lower_bound,
  regularization_mode=weight_decay_mode.value,
- use_rowwise_bias_correction=False, # Unused, this is used in TBE's Adam
+ use_rowwise_bias_correction=use_rowwise_bias_correction, # Used in Adam optimizer
  )

  table_embedding_dtype = weights_precision.as_dtype()
@@ -593,19 +959,14 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  dtype=table_embedding_dtype,
  )

- momentum1_offsets = [0] + list(itertools.accumulate(rows))
- self._apply_split(
- SplitState(
- dev_size=self.total_hash_size,
- host_size=0,
- uvm_size=0,
- placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
- offsets=momentum1_offsets[:-1],
- ),
- "momentum1",
+ # Create the optimizer state tensors
+ for template in self.optimizer.ssd_state_splits(
+ self.embedding_specs,
+ self.optimizer_state_dtypes,
+ self.enable_optimizer_offloading,
+ ):
  # pyre-fixme[6]: For 3rd argument expected `Type[dtype]` but got `dtype`.
- dtype=torch.float32,
- )
+ self._apply_split(*template)

  # For storing current iteration data
  self.current_iter_data: Optional[IterData] = None
@@ -625,11 +986,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  self._update_cache_counter_and_pointers
  )

- assert optimizer in (
- OptimType.EXACT_ROWWISE_ADAGRAD,
- ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
- self.optimizer = optimizer
-
  # stats reporter
  self.gather_ssd_cache_stats = gather_ssd_cache_stats
  self.stats_reporter: Optional[TBEStatsReporter] = (
@@ -638,7 +994,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  self.ssd_cache_stats_size = 6
  # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
  # 4: N_conflict_unique_misses, 5: N_conflict_misses
- self.last_reported_ssd_stats: List[float] = []
+ self.last_reported_ssd_stats: list[float] = []
  self.last_reported_step = 0

  self.register_buffer(
@@ -669,7 +1025,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  self.prefetch_parallel_stream_cnt: int = 2
  # tuple of iteration, prefetch parallel stream cnt, reported duration
  # since there are 2 stream in parallel in prefetch, we want to count the longest one
- self.prefetch_duration_us: Tuple[int, int, float] = (
+ self.prefetch_duration_us: tuple[int, int, float] = (
  -1,
  self.prefetch_parallel_stream_cnt,
  0,
@@ -689,6 +1045,26 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  self.l2_cache_capacity_stats_name: str = (
  f"l2_cache.mem.tbe_id{tbe_unique_id}.capacity_bytes"
  )
+ self.dram_kv_actual_used_chunk_bytes_stats_name: str = (
+ f"dram_kv.mem.tbe_id{tbe_unique_id}.actual_used_chunk_bytes"
+ )
+ self.dram_kv_allocated_bytes_stats_name: str = (
+ f"dram_kv.mem.tbe_id{tbe_unique_id}.allocated_bytes"
+ )
+ self.dram_kv_mem_num_rows_stats_name: str = (
+ f"dram_kv.mem.tbe_id{tbe_unique_id}.num_rows"
+ )
+
+ self.eviction_sum_evicted_counts_stats_name: str = (
+ f"eviction.tbe_id.{tbe_unique_id}.sum_evicted_counts"
+ )
+ self.eviction_sum_processed_counts_stats_name: str = (
+ f"eviction.tbe_id.{tbe_unique_id}.sum_processed_counts"
+ )
+ self.eviction_evict_rate_stats_name: str = (
+ f"eviction.tbe_id.{tbe_unique_id}.evict_rate"
+ )
+
  if self.stats_reporter:
  self.ssd_prefetch_read_timer = AsyncSeriesTimer(
  functools.partial(
@@ -708,11 +1084,77 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  )
  # pyre-ignore
  self.stats_reporter.register_stats(self.l2_num_cache_misses_stats_name)
- # pyre-ignore
  self.stats_reporter.register_stats(self.l2_num_cache_lookups_stats_name)
  self.stats_reporter.register_stats(self.l2_num_cache_evictions_stats_name)
  self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name)
  self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name)
+ self.stats_reporter.register_stats(self.dram_kv_allocated_bytes_stats_name)
+ self.stats_reporter.register_stats(
+ self.dram_kv_actual_used_chunk_bytes_stats_name
+ )
+ self.stats_reporter.register_stats(self.dram_kv_mem_num_rows_stats_name)
+ self.stats_reporter.register_stats(
+ self.eviction_sum_evicted_counts_stats_name
+ )
+ self.stats_reporter.register_stats(
+ self.eviction_sum_processed_counts_stats_name
+ )
+ self.stats_reporter.register_stats(self.eviction_evict_rate_stats_name)
+ for t in self.feature_table_map:
+ self.stats_reporter.register_stats(
+ f"eviction.feature_table.{t}.evicted_counts"
+ )
+ self.stats_reporter.register_stats(
+ f"eviction.feature_table.{t}.processed_counts"
+ )
+ self.stats_reporter.register_stats(
+ f"eviction.feature_table.{t}.evict_rate"
+ )
+ self.stats_reporter.register_stats(
+ "eviction.feature_table.full_duration_ms"
+ )
+ self.stats_reporter.register_stats(
+ "eviction.feature_table.exec_duration_ms"
+ )
+ self.stats_reporter.register_stats(
+ "eviction.feature_table.dry_run_exec_duration_ms"
+ )
+ self.stats_reporter.register_stats(
+ "eviction.feature_table.exec_div_full_duration_rate"
+ )
+
+ self.bounds_check_version: int = get_bounds_check_version_for_platform()
+
+ self._pg = pg
+
+ @cached_property
+ def cache_row_dim(self) -> int:
+ """
+ Compute the effective physical cache row size taking into account
+ padding to the nearest 4 elements and the optimizer state appended to
+ the back of the row
+ """
+
+ # For st publish, we only need to load weight for publishing and bulk eval
+ if self.enable_optimizer_offloading and not self.load_ckpt_without_opt:
+ return self.max_D + pad4(
+ # Compute the number of elements of cache_dtype needed to store
+ # the optimizer state
+ self.optimizer_state_dim
+ )
+ else:
+ return self.max_D
+
+ @cached_property
+ def optimizer_state_dim(self) -> int:
+ return int(
+ math.ceil(
+ self.optimizer.state_size_nbytes(
+ self.max_D, self.optimizer_state_dtypes
+ )
+ / self.weights_precision.as_dtype().itemsize
+ )
+ )

  @property
  # pyre-ignore
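The cache_row_dim and optimizer_state_dim properties introduced above determine the physical row width used by the RocksDB/DRAM backends when optimizer state is offloaded alongside the weights. A small worked example, under the assumptions that pad4 rounds an element count up to the next multiple of 4 (as the name and the "padding to the nearest 4 elements" docstring suggest) and that rowwise Adagrad keeps a single FP32 momentum1 value per row:

import math

# Hypothetical numbers, not taken from the package.
itemsize = 2                    # FP16 weight element size in bytes
optimizer_state_nbytes = 4      # one FP32 momentum1 scalar per row
optimizer_state_dim = math.ceil(optimizer_state_nbytes / itemsize)  # -> 2 elements

def pad4(n: int) -> int:
    # assumed behavior: round up to the next multiple of 4
    return (n + 3) // 4 * 4

max_D = 128
cache_row_dim = max_D + pad4(optimizer_state_dim)  # 128 + 4 = 132 elements per cached row
print(cache_row_dim)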
@@ -766,19 +1208,22 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  be effectively overwritten. This function should only be called once at
  initialization time.
  """
+ self._ssd_db.toggle_compaction(False)
  row_offset = 0
  row_count = floor(
  self.bulk_init_chunk_size
- / (self.max_D * self.weights_precision.as_dtype().itemsize)
+ / (self.cache_row_dim * self.weights_precision.as_dtype().itemsize)
  )
  total_dim0 = 0
  for dim0, _ in self.embedding_specs:
  total_dim0 += dim0

  start_ts = time.time()
+ # TODO: do we have case for non-kvzch ssd with bulk init enabled + optimizer offloading? probably not?
+ # if we have such cases, we should only init the emb dim not the optimizer dim
  chunk_tensor = torch.empty(
  row_count,
- self.max_D,
+ self.cache_row_dim,
  dtype=self.weights_precision.as_dtype(),
  device="cuda",
  )
@@ -793,12 +1238,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  # This code is intentionally not calling through the getter property
  # to avoid the lazy initialization thread from joining with itself.
  self._ssd_db.set_range_to_storage(rand_val, row_offset, actual_dim0)
- self.ssd_db.toggle_compaction(True)
  end_ts = time.time()
  elapsed = int((end_ts - start_ts) * 1e6)
  logging.info(
  f"TBE bulk initialization took {elapsed:_} us, bulk_init_chunk_size={self.bulk_init_chunk_size}, each batch of {row_count} rows, total rows of {total_dim0}"
  )
+ self._ssd_db.toggle_compaction(True)

  @torch.jit.ignore
  def _report_duration(
@@ -826,7 +1271,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  """
  recorded_itr, stream_cnt, report_val = self.prefetch_duration_us
  duration = dur_ms
- if time_unit == "us": # pyre-ignore
+ if time_unit == "us":
  duration = dur_ms * 1000
  if it_step == recorded_itr:
  report_val = max(report_val, duration)
@@ -845,7 +1290,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  it_step, event_name, report_val, time_unit=time_unit
  )

- # pyre-ignore[3]
  def record_function_via_dummy_profile_factory(
  self,
  use_dummy_profile: bool,
@@ -867,7 +1311,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

  def func(
  name: str,
- # pyre-ignore[2]
  fn: Callable[..., Any],
  *args: Any,
  **kwargs: Any,
@@ -881,7 +1324,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

  def func(
  name: str,
- # pyre-ignore[2]
  fn: Callable[..., Any],
  *args: Any,
  **kwargs: Any,
@@ -894,10 +1336,10 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  self,
  split: SplitState,
  prefix: str,
- dtype: Type[torch.dtype],
+ dtype: type[torch.dtype],
  enforce_hbm: bool = False,
  make_dev_param: bool = False,
- dev_reshape: Optional[Tuple[int, ...]] = None,
+ dev_reshape: Optional[tuple[int, ...]] = None,
  ) -> None:
  apply_split_helper(
  self.register_buffer,
@@ -920,11 +1362,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

  def to_pinned_cpu_on_stream_wait_on_another_stream(
  self,
- tensors: List[Tensor],
+ tensors: list[Tensor],
  stream: torch.cuda.Stream,
  stream_to_wait_on: torch.cuda.Stream,
  post_event: Optional[torch.cuda.Event] = None,
- ) -> List[Tensor]:
+ ) -> list[Tensor]:
  """
  Transfer input tensors from GPU to CPU using a pinned host
  buffer. The transfer is carried out on the given stream
@@ -982,9 +1424,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
  tensor (which is accessible on both host and
  device)
+ is_bwd (bool): A flag to indicate if the eviction is during backward
  Returns:
  None
  """
+ if not self.training: # if not training, freeze the embedding
+ return
  with record_function(f"## ssd_evict_{name} ##"):
  with torch.cuda.stream(stream):
  if pre_event is not None:
@@ -1007,6 +1452,95 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  if post_event is not None:
  stream.record_event(post_event)

+ def raw_embedding_stream_sync(
+ self,
+ stream: torch.cuda.Stream,
+ pre_event: Optional[torch.cuda.Event],
+ post_event: Optional[torch.cuda.Event],
+ name: Optional[str] = "",
+ ) -> None:
+ """
+ Blocking wait for the copy operation of the tensors to be streamed,
+ to make sure they are not overwritten
+ Args:
+ stream (Stream): The CUDA stream that cudaStreamAddCallback will
+ synchronize the host function with. Moreover, the
+ asynchronous D->H memory copies will operate on
+ this stream
+ pre_event (Event): The CUDA event that the stream has to wait on
+ post_event (Event): The CUDA event that the current stream will record on
+ when the eviction is done
+ Returns:
+ None
+ """
+ with record_function(f"## ssd_stream_{name} ##"):
+ with torch.cuda.stream(stream):
+ if pre_event is not None:
+ stream.wait_event(pre_event)
+
+ self.record_function_via_dummy_profile(
+ f"## ssd_stream_sync_{name} ##",
+ self.ssd_db.stream_sync_cuda,
+ )
+
+ if post_event is not None:
+ stream.record_event(post_event)
+
+ def raw_embedding_stream(
+ self,
+ rows: Tensor,
+ indices_cpu: Tensor,
+ actions_count_cpu: Tensor,
+ stream: torch.cuda.Stream,
+ pre_event: Optional[torch.cuda.Event],
+ post_event: Optional[torch.cuda.Event],
+ is_rows_uvm: bool,
+ blocking_tensor_copy: bool = True,
+ name: Optional[str] = "",
+ ) -> None:
+ """
+ Stream data from the given input tensors to a remote service
+ Args:
+ rows (Tensor): The 2D tensor that contains rows to evict
+ indices_cpu (Tensor): The 1D CPU tensor that contains the row
+ indices that the rows will be evicted to
+ actions_count_cpu (Tensor): A scalar tensor that contains the
+ number of rows that the evict function
+ has to process
+ stream (Stream): The CUDA stream that cudaStreamAddCallback will
+ synchronize the host function with. Moreover, the
+ asynchronous D->H memory copies will operate on
+ this stream
+ pre_event (Event): The CUDA event that the stream has to wait on
+ post_event (Event): The CUDA event that the current stream will record on
+ when the eviction is done
+ is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
+ tensor (which is accessible on both host and
+ device)
+ Returns:
+ None
+ """
+ with record_function(f"## ssd_stream_{name} ##"):
+ with torch.cuda.stream(stream):
+ if pre_event is not None:
+ stream.wait_event(pre_event)
+
+ rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)
+
+ rows.record_stream(stream)
+
+ self.record_function_via_dummy_profile(
+ f"## ssd_stream_{name} ##",
+ self.ssd_db.stream_cuda,
+ indices_cpu,
+ rows_cpu,
+ actions_count_cpu,
+ blocking_tensor_copy,
+ )
+
+ if post_event is not None:
+ stream.record_event(post_event)
+
  def _evict_from_scratch_pad(self, grad: Tensor) -> None:
  """
  Evict conflict missed rows from a scratch pad
@@ -1043,6 +1577,18 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  if not do_evict:
  return

+ if self.enable_raw_embedding_streaming:
+ self.raw_embedding_stream(
+ rows=inserted_rows,
+ indices_cpu=post_bwd_evicted_indices_cpu,
+ actions_count_cpu=actions_count_cpu,
+ stream=self.ssd_eviction_stream,
+ pre_event=self.ssd_event_backward,
+ post_event=self.ssd_event_sp_streamed,
+ is_rows_uvm=True,
+ blocking_tensor_copy=True,
+ name="scratch_pad",
+ )
  self.evict(
  rows=inserted_rows,
  indices_cpu=post_bwd_evicted_indices_cpu,
@@ -1060,7 +1606,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  def _update_cache_counter_and_pointers(
  self,
  module: nn.Module,
- grad_input: Union[Tuple[Tensor, ...], Tensor],
+ grad_input: Union[tuple[Tensor, ...], Tensor],
  ) -> None:
  """
  Update cache line locking counter and pointers before backward
@@ -1145,9 +1691,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
  if len(self.ssd_location_update_data) == 0:
  return

- (sp_curr_next_map, inserted_rows_next) = self.ssd_location_update_data.pop(
- 0
- )
+ sp_curr_next_map, inserted_rows_next = self.ssd_location_update_data.pop(0)

  # Update pointers
  torch.ops.fbgemm.ssd_update_row_addrs(
@@ -1162,12 +1706,63 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1162
1706
  unique_indices_length_curr=curr_data.actions_count_gpu,
1163
1707
  )
1164
1708
 
1709
+ def _update_feature_score_metadata(
1710
+ self,
1711
+ linear_cache_indices: Tensor,
1712
+ weights: Tensor,
1713
+ d2h_stream: torch.cuda.Stream,
1714
+ write_stream: torch.cuda.Stream,
1715
+ pre_event_for_write: torch.cuda.Event,
1716
+ post_event: Optional[torch.cuda.Event] = None,
1717
+ ) -> None:
1718
+ """
1719
+ Write feature score metadata to DRAM
1720
+
1721
+ This method performs D2H copy on d2h_stream, then writes to DRAM on write_stream.
1722
+ The caller is responsible for ensuring d2h_stream doesn't compete with other D2H operations.
1723
+
1724
+ Args:
1725
+ linear_cache_indices: GPU tensor containing cache indices
1726
+ weights: GPU tensor containing feature scores
1727
+ d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
1728
+ write_stream: Stream for metadata write operation
1729
+ pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
1730
+ post_event: Event to record when the operation is done
1731
+ """
1732
+ # Start D2H copy on d2h_stream
1733
+ with torch.cuda.stream(d2h_stream):
1734
+ # Record streams to prevent premature deallocation
1735
+ linear_cache_indices.record_stream(d2h_stream)
1736
+ weights.record_stream(d2h_stream)
1737
+ # Do the D2H copy
1738
+ linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
1739
+ score_weights_cpu = self.to_pinned_cpu(weights)
1740
+
1741
+ # Write feature score metadata to DRAM
1742
+ with record_function("## ssd_write_feature_score_metadata ##"):
1743
+ with torch.cuda.stream(write_stream):
1744
+ write_stream.wait_event(pre_event_for_write)
1745
+ write_stream.wait_stream(d2h_stream)
1746
+ self.record_function_via_dummy_profile(
1747
+ "## ssd_write_feature_score_metadata ##",
1748
+ self.ssd_db.set_feature_score_metadata_cuda,
1749
+ linear_cache_indices_cpu,
1750
+ torch.tensor(
1751
+ [score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
1752
+ ),
1753
+ score_weights_cpu,
1754
+ )
1755
+
1756
+ if post_event is not None:
1757
+ write_stream.record_event(post_event)
1758
+
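Note the two different ordering primitives used above: wait_event orders the write stream after one specific recorded event (here, the cache eviction), while wait_stream orders it after everything already queued on the D2H stream. A small sketch of the distinction, assuming a CUDA device is available:

    import torch

    if torch.cuda.is_available():
        d2h_stream = torch.cuda.Stream()
        write_stream = torch.cuda.Stream()
        evict_done = torch.cuda.Event()

        src = torch.randn(1024, device="cuda")
        dst = torch.empty(1024, pin_memory=True)

        # Pretend the eviction finished on the default stream
        torch.cuda.current_stream().record_event(evict_done)

        # D2H copy on its own stream
        with torch.cuda.stream(d2h_stream):
            src.record_stream(d2h_stream)         # keep src alive for this stream
            dst.copy_(src, non_blocking=True)

        # The writer must follow both the eviction and the D2H copy
        with torch.cuda.stream(write_stream):
            write_stream.wait_event(evict_done)   # wait on one specific event
            write_stream.wait_stream(d2h_stream)  # wait on all work queued so far
            # ... the metadata write would be issued here ...

        torch.cuda.synchronize()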
1165
1759
  def prefetch(
1166
1760
  self,
1167
1761
  indices: Tensor,
1168
1762
  offsets: Tensor,
1763
+ weights: Optional[Tensor] = None, # todo: need to update caller
1169
1764
  forward_stream: Optional[torch.cuda.Stream] = None,
1170
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
1765
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
1171
1766
  ) -> None:
1172
1767
  if self.prefetch_stream is None and forward_stream is not None:
1173
1768
  # Set the prefetch stream to the current stream
@@ -1191,6 +1786,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1191
1786
  self._prefetch(
1192
1787
  indices,
1193
1788
  offsets,
1789
+ weights,
1194
1790
  vbe_metadata,
1195
1791
  forward_stream,
1196
1792
  )
@@ -1199,11 +1795,17 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1199
1795
  self,
1200
1796
  indices: Tensor,
1201
1797
  offsets: Tensor,
1798
+ weights: Optional[Tensor] = None,
1202
1799
  vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
1203
1800
  forward_stream: Optional[torch.cuda.Stream] = None,
1204
1801
  ) -> None:
1205
- # TODO: Refactor prefetch
1802
+ # Wait for any ongoing direct_write_embedding operations to complete
1803
+ # Moving this from forward() to _prefetch() is more logical as direct_write
1804
+ # operations affect the same cache structures that prefetch interacts with
1206
1805
  current_stream = torch.cuda.current_stream()
1806
+ if self._embedding_cache_mode:
1807
+ current_stream.wait_event(self.direct_write_l1_complete_event)
1808
+ current_stream.wait_event(self.direct_write_sp_complete_event)
1207
1809
 
1208
1810
  B_offsets = None
1209
1811
  max_B = -1
@@ -1284,10 +1886,83 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1284
1886
  masks=torch.where(evicted_indices != -1, 1, 0),
1285
1887
  count=actions_count_gpu,
1286
1888
  )
1287
-
1288
- with record_function("## ssd_d2h_inserted_indices ##"):
1289
- # Transfer actions_count and insert_indices right away to
1290
- # incrase an overlap opportunity
1889
+ has_raw_embedding_streaming = False
1890
+ if self.enable_raw_embedding_streaming:
1891
+ # When pipelining is enabled,
1892
+ # prefetch in iter i happens before the backward sparse in iter i - 1,
1893
+ # so embeddings for iter i - 1's changed ids are not yet updated
1894
+ # and we can only fetch the indices from iter i - 2.
1895
+ # When pipelining is disabled,
1896
+ # prefetch in iter i happens before forward in iter i,
1897
+ # so we can safely get iter i - 1's changed ids.
1898
+ target_prev_iter = 1
1899
+ if self.prefetch_pipeline:
1900
+ target_prev_iter = 2
1901
+ if len(self.prefetched_info) > (target_prev_iter - 1):
1902
+ with record_function(
1903
+ "## ssd_lookup_prefetched_rows {} {} ##".format(
1904
+ self.timestep, self.tbe_unique_id
1905
+ )
1906
+ ):
1907
+ # wait for the copy to finish before overwriting the buffer
1908
+ self.raw_embedding_stream_sync(
1909
+ stream=self.ssd_eviction_stream,
1910
+ pre_event=self.ssd_event_cache_streamed,
1911
+ post_event=self.ssd_event_cache_streaming_synced,
1912
+ name="cache_update",
1913
+ )
1914
+ current_stream.wait_event(self.ssd_event_cache_streaming_synced)
1915
+ updated_indices, updated_counts_gpu = self.prefetched_info.pop(
1916
+ 0
1917
+ )
1918
+ self.lxu_cache_updated_indices[: updated_indices.size(0)].copy_(
1919
+ updated_indices,
1920
+ non_blocking=True,
1921
+ )
1922
+ self.lxu_cache_updated_count[:1].copy_(
1923
+ updated_counts_gpu, non_blocking=True
1924
+ )
1925
+ has_raw_embedding_streaming = True
1926
+
1927
+ with record_function(
1928
+ "## ssd_save_prefetched_rows {} {} ##".format(
1929
+ self.timestep, self.tbe_unique_id
1930
+ )
1931
+ ):
1932
+ masked_updated_indices = torch.where(
1933
+ torch.where(lxu_cache_locations != -1, True, False),
1934
+ linear_cache_indices,
1935
+ -1,
1936
+ )
1937
+
1938
+ (
1939
+ uni_updated_indices,
1940
+ uni_updated_indices_length,
1941
+ ) = get_unique_indices_v2(
1942
+ masked_updated_indices,
1943
+ self.total_hash_size,
1944
+ compute_count=False,
1945
+ compute_inverse_indices=False,
1946
+ )
1947
+ assert uni_updated_indices is not None
1948
+ assert uni_updated_indices_length is not None
1949
+ # The unique indices tensor has one more -1 element than needed,
1950
+ # which might make the tensor length go out of range
1951
+ # compared to the pre-allocated buffer.
1952
+ unique_len = min(
1953
+ self.lxu_cache_weights.size(0),
1954
+ uni_updated_indices.size(0),
1955
+ )
1956
+ self.prefetched_info.append(
1957
+ (
1958
+ uni_updated_indices.narrow(0, 0, unique_len),
1959
+ uni_updated_indices_length.clamp(max=unique_len),
1960
+ )
1961
+ )
1962
+
1963
+ with record_function("## ssd_d2h_inserted_indices ##"):
1964
+ # Transfer actions_count and insert_indices right away to
1965
+ # increase an overlap opportunity
1291
1966
  actions_count_cpu, inserted_indices_cpu = (
1292
1967
  self.to_pinned_cpu_on_stream_wait_on_another_stream(
1293
1968
  tensors=[
@@ -1312,7 +1987,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1312
1987
 
1313
1988
  # Allocation a scratch pad for the current iteration. The scratch
1314
1989
  # pad is a UVA tensor
1315
- inserted_rows_shape = (assigned_cache_slots.numel(), self.max_D)
1990
+ inserted_rows_shape = (assigned_cache_slots.numel(), self.cache_row_dim)
1316
1991
  if linear_cache_indices.numel() > 0:
1317
1992
  inserted_rows = torch.ops.fbgemm.new_unified_tensor(
1318
1993
  torch.zeros(
@@ -1415,25 +2090,66 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1415
2090
  # Store info for evicting the previous iteration's
1416
2091
  # scratch pad after the corresponding backward pass is
1417
2092
  # done
1418
- self.ssd_location_update_data.append(
1419
- (
1420
- sp_curr_prev_map_gpu,
1421
- inserted_rows,
2093
+ if self.training:
2094
+ self.ssd_location_update_data.append(
2095
+ (
2096
+ sp_curr_prev_map_gpu,
2097
+ inserted_rows,
2098
+ )
1422
2099
  )
1423
- )
1424
2100
 
1425
2101
  # Ensure the previous iterations eviction is complete
1426
2102
  current_stream.wait_event(self.ssd_event_sp_evict)
1427
2103
  # Ensure that D2H is done
1428
2104
  current_stream.wait_event(self.ssd_event_get_inputs_cpy)
1429
2105
 
2106
+ if self.enable_raw_embedding_streaming and has_raw_embedding_streaming:
2107
+ current_stream.wait_event(self.ssd_event_sp_streamed)
2108
+ with record_function(
2109
+ "## ssd_compute_updated_rows {} {} ##".format(
2110
+ self.timestep, self.tbe_unique_id
2111
+ )
2112
+ ):
2113
+ # cache rows that are changed in the previous iteration
2114
+ updated_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
2115
+ self.lxu_cache_updated_indices,
2116
+ self.lxu_cache_state,
2117
+ self.total_hash_size,
2118
+ self.gather_ssd_cache_stats,
2119
+ self.local_ssd_cache_stats,
2120
+ )
2121
+ torch.ops.fbgemm.masked_index_select(
2122
+ self.lxu_cache_updated_weights,
2123
+ updated_cache_locations,
2124
+ self.lxu_cache_weights,
2125
+ self.lxu_cache_updated_count,
2126
+ )
2127
+ current_stream.record_event(self.ssd_event_cache_streaming_computed)
2128
+
2129
+ self.raw_embedding_stream(
2130
+ rows=self.lxu_cache_updated_weights,
2131
+ indices_cpu=self.lxu_cache_updated_indices,
2132
+ actions_count_cpu=self.lxu_cache_updated_count,
2133
+ stream=self.ssd_eviction_stream,
2134
+ pre_event=self.ssd_event_cache_streaming_computed,
2135
+ post_event=self.ssd_event_cache_streamed,
2136
+ is_rows_uvm=True,
2137
+ blocking_tensor_copy=False,
2138
+ name="cache_update",
2139
+ )
2140
+
1430
2141
  if self.gather_ssd_cache_stats:
1431
2142
  # call to collect past SSD IO dur right before next rocksdb IO
1432
2143
 
1433
2144
  self.ssd_cache_stats = torch.add(
1434
2145
  self.ssd_cache_stats, self.local_ssd_cache_stats
1435
2146
  )
1436
- self._report_ssd_stats()
2147
+ # only report metrics from rank0 to avoid flooded logging
2148
+ if dist.get_rank() == 0:
2149
+ self._report_kv_backend_stats()
2150
+
2151
+ # May trigger eviction if free mem trigger mode enabled before get cuda
2152
+ self.may_trigger_eviction()
1437
2153
 
1438
2154
  # Fetch data from SSD
1439
2155
  if linear_cache_indices.numel() > 0:
@@ -1457,21 +2173,35 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1457
2173
  use_pipeline=self.prefetch_pipeline,
1458
2174
  )
1459
2175
 
1460
- if linear_cache_indices.numel() > 0:
1461
- # Evict rows from cache to SSD
1462
- self.evict(
1463
- rows=self.lxu_cache_evicted_weights,
1464
- indices_cpu=self.lxu_cache_evicted_indices,
1465
- actions_count_cpu=self.lxu_cache_evicted_count,
1466
- stream=self.ssd_eviction_stream,
1467
- pre_event=self.ssd_event_get,
1468
- # Record completion event after scratch pad eviction
1469
- # instead since that happens after L1 eviction
1470
- post_event=self.ssd_event_cache_evict,
1471
- is_rows_uvm=True,
1472
- name="cache",
1473
- is_bwd=False,
1474
- )
2176
+ if self.training:
2177
+ if linear_cache_indices.numel() > 0:
2178
+ # Evict rows from cache to SSD
2179
+ self.evict(
2180
+ rows=self.lxu_cache_evicted_weights,
2181
+ indices_cpu=self.lxu_cache_evicted_indices,
2182
+ actions_count_cpu=self.lxu_cache_evicted_count,
2183
+ stream=self.ssd_eviction_stream,
2184
+ pre_event=self.ssd_event_get,
2185
+ # Record completion event after scratch pad eviction
2186
+ # instead since that happens after L1 eviction
2187
+ post_event=self.ssd_event_cache_evict,
2188
+ is_rows_uvm=True,
2189
+ name="cache",
2190
+ is_bwd=False,
2191
+ )
2192
+ if (
2193
+ self.backend_type == BackendType.DRAM
2194
+ and weights is not None
2195
+ and linear_cache_indices.numel() > 0
2196
+ ):
2197
+ # Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
2198
+ self._update_feature_score_metadata(
2199
+ linear_cache_indices=linear_cache_indices,
2200
+ weights=weights,
2201
+ d2h_stream=self.ssd_memcpy_stream,
2202
+ write_stream=self.feature_score_stream,
2203
+ pre_event_for_write=self.ssd_event_cache_evict,
2204
+ )
1475
2205
 
1476
2206
  # Generate row addresses (pointing to either L1 or the current
1477
2207
  # iteration's scratch pad)
@@ -1553,24 +2283,32 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1553
2283
  )
1554
2284
  )
1555
2285
 
1556
- # Store scratch pad info for post backward eviction
1557
- self.ssd_scratch_pad_eviction_data.append(
1558
- (
1559
- inserted_rows,
1560
- post_bwd_evicted_indices_cpu,
1561
- actions_count_cpu,
1562
- linear_cache_indices.numel() > 0,
2286
+ # Store scratch pad info for post backward eviction only for training
2287
+ # for eval job, no backward pass, so no need to store this info
2288
+ if self.training:
2289
+ self.ssd_scratch_pad_eviction_data.append(
2290
+ (
2291
+ inserted_rows,
2292
+ post_bwd_evicted_indices_cpu,
2293
+ actions_count_cpu,
2294
+ linear_cache_indices.numel() > 0,
2295
+ )
1563
2296
  )
1564
- )
1565
2297
 
1566
2298
  # Store data for forward
1567
2299
  self.ssd_prefetch_data.append(prefetch_data)
1568
2300
 
2301
+ # Record an event to mark the completion of prefetch operations
2302
+ # This will be used by direct_write_embedding to ensure it doesn't run concurrently with prefetch
2303
+ current_stream.record_event(self.prefetch_complete_event)
2304
+
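The pipelining comment near the top of this method ("we can only fetch the indices from iter i - 2" when the prefetch pipeline is on) reduces to a small queue whose safe depth depends on the pipeline mode. A stripped-down sketch of that bookkeeping, with the FBGEMM payloads replaced by placeholder strings:

    from collections import deque

    class PrefetchedInfoQueue:
        """Holds per-iteration prefetch info until it is safe to consume it."""

        def __init__(self, prefetch_pipeline: bool) -> None:
            # With pipelining, prefetch(i) runs before backward(i - 1), so the
            # newest safe entry is from iteration i - 2; otherwise i - 1 is safe.
            self.target_prev_iter = 2 if prefetch_pipeline else 1
            self.queue: deque = deque()

        def push(self, info) -> None:
            self.queue.append(info)

        def pop_if_ready(self):
            # Mirrors `len(self.prefetched_info) > (target_prev_iter - 1)`
            if len(self.queue) > self.target_prev_iter - 1:
                return self.queue.popleft()
            return None

    q = PrefetchedInfoQueue(prefetch_pipeline=True)
    q.push("iter-0 ids")
    assert q.pop_if_ready() is None       # queue not deep enough yet
    q.push("iter-1 ids")
    assert q.pop_if_ready() == "iter-0 ids"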
1569
2305
  @torch.jit.ignore
1570
2306
  def _generate_vbe_metadata(
1571
2307
  self,
1572
2308
  offsets: Tensor,
1573
- batch_size_per_feature_per_rank: Optional[List[List[int]]],
2309
+ batch_size_per_feature_per_rank: Optional[list[list[int]]],
2310
+ vbe_output: Optional[Tensor] = None,
2311
+ vbe_output_offsets: Optional[Tensor] = None,
1574
2312
  ) -> invokers.lookup_args.VBEMetadata:
1575
2313
  # Blocking D2H copy, but only runs at first call
1576
2314
  self.feature_dims = self.feature_dims.cpu()
@@ -1589,19 +2327,58 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1589
2327
  self.pooling_mode,
1590
2328
  self.feature_dims,
1591
2329
  self.current_device,
2330
+ vbe_output,
2331
+ vbe_output_offsets,
1592
2332
  )
1593
2333
 
2334
+ def _increment_iteration(self) -> int:
2335
+ # Although self.iter_cpu is created on CPU. It might be transferred to
2336
+ # GPU by the user. So, we need to transfer it to CPU explicitly. This
2337
+ # should be done only once.
2338
+ self.iter_cpu = self.iter_cpu.cpu()
2339
+
2340
+ # Sync with loaded state
2341
+ # Wrap to make it compatible with PT2 compile
2342
+ if not is_torchdynamo_compiling():
2343
+ if self.iter_cpu.item() == 0:
2344
+ self.iter_cpu.fill_(self.iter.cpu().item())
2345
+
2346
+ # Increment the iteration counter
2347
+ # The CPU counterpart is used for local computation
2348
+ iter_int = int(self.iter_cpu.add_(1).item())
2349
+ # The GPU counterpart is used for checkpointing
2350
+ self.iter.add_(1)
2351
+
2352
+ return iter_int
2353
+
1594
2354
  def forward(
1595
2355
  self,
1596
2356
  indices: Tensor,
1597
2357
  offsets: Tensor,
2358
+ weights: Optional[Tensor] = None,
1598
2359
  per_sample_weights: Optional[Tensor] = None,
1599
2360
  feature_requires_grad: Optional[Tensor] = None,
1600
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
2361
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
2362
+ vbe_output: Optional[Tensor] = None,
2363
+ vbe_output_offsets: Optional[Tensor] = None,
1601
2364
  # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
1602
2365
  ) -> Tensor:
2366
+ self.clear_cache()
2367
+ if vbe_output is not None or vbe_output_offsets is not None:
2368
+ # CPU is not supported in SSD TBE
2369
+ check_allocated_vbe_output(
2370
+ self.output_dtype,
2371
+ batch_size_per_feature_per_rank,
2372
+ vbe_output,
2373
+ vbe_output_offsets,
2374
+ )
1603
2375
  indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
1604
- indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
2376
+ indices,
2377
+ offsets,
2378
+ per_sample_weights,
2379
+ batch_size_per_feature_per_rank,
2380
+ vbe_output=vbe_output,
2381
+ vbe_output_offsets=vbe_output_offsets,
1605
2382
  )
1606
2383
 
1607
2384
  if len(self.timesteps_prefetched) == 0:
@@ -1615,7 +2392,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1615
2392
  context=self.step,
1616
2393
  stream=self.ssd_eviction_stream,
1617
2394
  ):
1618
- self._prefetch(indices, offsets, vbe_metadata)
2395
+ self._prefetch(indices, offsets, weights, vbe_metadata)
1619
2396
 
1620
2397
  assert len(self.ssd_prefetch_data) > 0
1621
2398
 
@@ -1674,18 +2451,33 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1674
2451
  "post_bwd_evicted_indices": post_bwd_evicted_indices_cpu,
1675
2452
  "actions_count": actions_count_cpu,
1676
2453
  },
2454
+ enable_optimizer_offloading=self.enable_optimizer_offloading,
1677
2455
  # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
1678
2456
  vbe_metadata=vbe_metadata,
2457
+ learning_rate_tensor=self.learning_rate_tensor,
2458
+ info_B_num_bits=self.info_B_num_bits,
2459
+ info_B_mask=self.info_B_mask,
1679
2460
  )
1680
2461
 
1681
2462
  self.timesteps_prefetched.pop(0)
1682
2463
  self.step += 1
1683
2464
 
1684
- if self.optimizer == OptimType.EXACT_SGD:
1685
- raise AssertionError(
1686
- "SSDTableBatchedEmbeddingBags currently does not support SGD"
2465
+ # Increment the iteration (value is used for certain optimizers)
2466
+ iter_int = self._increment_iteration()
2467
+
2468
+ if self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
2469
+ momentum2 = invokers.lookup_args_ssd.Momentum(
2470
+ # pyre-ignore[6]
2471
+ dev=self.momentum2_dev,
2472
+ # pyre-ignore[6]
2473
+ host=self.momentum2_host,
2474
+ # pyre-ignore[6]
2475
+ uvm=self.momentum2_uvm,
2476
+ # pyre-ignore[6]
2477
+ offsets=self.momentum2_offsets,
2478
+ # pyre-ignore[6]
2479
+ placements=self.momentum2_placements,
1687
2480
  )
1688
- return invokers.lookup_sgd_ssd.invoke(common_args, self.optimizer_args)
1689
2481
 
1690
2482
  momentum1 = invokers.lookup_args_ssd.Momentum(
1691
2483
  dev=self.momentum1_dev,
@@ -1696,44 +2488,584 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1696
2488
  )
1697
2489
 
1698
2490
  if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
1699
- # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
1700
2491
  return invokers.lookup_rowwise_adagrad_ssd.invoke(
1701
2492
  common_args, self.optimizer_args, momentum1
1702
2493
  )
1703
2494
 
2495
+ elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
2496
+ return invokers.lookup_partial_rowwise_adam_ssd.invoke(
2497
+ common_args,
2498
+ self.optimizer_args,
2499
+ momentum1,
2500
+ # pyre-ignore[61]
2501
+ momentum2,
2502
+ iter_int,
2503
+ )
2504
+
2505
+ elif self.optimizer == OptimType.ADAM:
2506
+ row_counter = invokers.lookup_args_ssd.Momentum(
2507
+ # pyre-fixme[6]
2508
+ dev=self.row_counter_dev,
2509
+ # pyre-fixme[6]
2510
+ host=self.row_counter_host,
2511
+ # pyre-fixme[6]
2512
+ uvm=self.row_counter_uvm,
2513
+ # pyre-fixme[6]
2514
+ offsets=self.row_counter_offsets,
2515
+ # pyre-fixme[6]
2516
+ placements=self.row_counter_placements,
2517
+ )
2518
+
2519
+ return invokers.lookup_adam_ssd.invoke(
2520
+ common_args,
2521
+ self.optimizer_args,
2522
+ momentum1,
2523
+ # pyre-ignore[61]
2524
+ momentum2,
2525
+ iter_int,
2526
+ row_counter=row_counter,
2527
+ )
2528
+
1704
2529
  @torch.jit.ignore
1705
- def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor]]:
2530
+ def _split_optimizer_states_non_kv_zch(
2531
+ self,
2532
+ ) -> list[list[torch.Tensor]]:
1706
2533
  """
1707
- Returns a list of states, split by table
1708
- Testing only
2534
+ Returns a list of optimizer state views, split by table.
2535
+
2536
+ Returns:
2537
+ A list of lists of states. Shape = (the number of tables, the number
2538
+ of states).
2539
+
2540
+ The following shows the list of states (in the returned order) for
2541
+ each optimizer:
2542
+
2543
+ (1) `EXACT_ROWWISE_ADAGRAD`: `momentum1` (rowwise)
2544
+
2545
+ (2) `PARTIAL_ROWWISE_ADAM`: `momentum1`, `momentum2` (rowwise)
+
+ (3) `ADAM`: `momentum1`, `momentum2`
1709
2546
  """
1710
- (rows, _) = zip(*self.embedding_specs)
1711
2547
 
1712
- rows_cumsum = [0] + list(itertools.accumulate(rows))
2548
+ # Row count per table
2549
+ rows, dims = zip(*self.embedding_specs)
2550
+ # Cumulative row counts per table for rowwise states
2551
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
2552
+ # Cumulative element counts per table for elementwise states
2553
+ elem_count_cumsum: list[int] = [0] + list(
2554
+ itertools.accumulate([r * d for r, d in self.embedding_specs])
2555
+ )
2556
+
2557
+ # pyre-ignore[53]
2558
+ def _slice(tensor: Tensor, t: int, rowwise: bool) -> Tensor:
2559
+ d: int = dims[t]
2560
+ e: int = rows[t]
2561
+
2562
+ if not rowwise:
2563
+ # Optimizer state is element-wise - compute the table offset for
2564
+ # the table, view the slice as 2D tensor
2565
+ return tensor.detach()[
2566
+ elem_count_cumsum[t] : elem_count_cumsum[t + 1]
2567
+ ].view(-1, d)
2568
+ else:
2569
+ # Optimizer state is row-wise - fetch elements in range and view
2570
+ # slice as 1D
2571
+ return tensor.detach()[
2572
+ row_count_cumsum[t] : row_count_cumsum[t + 1]
2573
+ ].view(e)
2574
+
2575
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
2576
+ return [
2577
+ [_slice(self.momentum1_dev, t, rowwise=True)]
2578
+ for t, _ in enumerate(rows)
2579
+ ]
2580
+ elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
2581
+ return [
2582
+ [
2583
+ _slice(self.momentum1_dev, t, rowwise=False),
2584
+ # pyre-ignore[6]
2585
+ _slice(self.momentum2_dev, t, rowwise=True),
2586
+ ]
2587
+ for t, _ in enumerate(rows)
2588
+ ]
2589
+
2590
+ elif self.optimizer == OptimType.ADAM:
2591
+ return [
2592
+ [
2593
+ _slice(self.momentum1_dev, t, rowwise=False),
2594
+ # pyre-ignore[6]
2595
+ _slice(self.momentum2_dev, t, rowwise=False),
2596
+ ]
2597
+ for t, _ in enumerate(rows)
2598
+ ]
2599
+
2600
+ else:
2601
+ raise NotImplementedError(
2602
+ f"Getting optimizer states is not supported for {self.optimizer}"
2603
+ )
2604
+
2605
+ @torch.jit.ignore
2606
+ def _split_optimizer_states_kv_zch_no_offloading(
2607
+ self,
2608
+ sorted_ids: torch.Tensor,
2609
+ ) -> list[list[torch.Tensor]]:
2610
+
2611
+ # Row count per table
2612
+ rows, dims = zip(*self.embedding_specs)
2613
+ # Cumulative row counts per table for rowwise states
2614
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
2615
+ # Cumulative element counts per table for elementwise states
2616
+ elem_count_cumsum: list[int] = [0] + list(
2617
+ itertools.accumulate([r * d for r, d in self.embedding_specs])
2618
+ )
2619
+
2620
+ # pyre-ignore[53]
2621
+ def _slice(state_name: str, tensor: Tensor, t: int, rowwise: bool) -> Tensor:
2622
+ d: int = dims[t]
2623
+
2624
+ # pyre-ignore[16]
2625
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
2626
+ # pyre-ignore[16]
2627
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
2628
+
2629
+ if sorted_ids is None or sorted_ids[t].numel() == 0:
2630
+ # Empty optimizer state for module initialization
2631
+ return torch.empty(
2632
+ 0,
2633
+ dtype=(
2634
+ self.optimizer_state_dtypes.get(
2635
+ state_name, SparseType.FP32
2636
+ ).as_dtype()
2637
+ ),
2638
+ device="cpu",
2639
+ )
2640
+
2641
+ elif not rowwise:
2642
+ # Optimizer state is element-wise - materialize the local ids
2643
+ # based on the sorted_ids compute the table offset for the
2644
+ # table, view the slice as 2D tensor of e x d, then fetch the
2645
+ # sub-slice by local ids
2646
+ #
2647
+ # local_ids is [N, 1], flatten it to N to keep the returned tensor 2D
2648
+ local_ids = (sorted_ids[t] - bucket_id_start * bucket_size).view(-1)
2649
+ return (
2650
+ tensor.detach()
2651
+ .cpu()[elem_count_cumsum[t] : elem_count_cumsum[t + 1]]
2652
+ .view(-1, d)[local_ids]
2653
+ )
2654
+
2655
+ else:
2656
+ # Optimizer state is row-wise - materialize the local ids based
2657
+ # on the sorted_ids and table offset (i.e. row count cumsum),
2658
+ # then fetch by local ids
2659
+ linearized_local_ids = (
2660
+ sorted_ids[t] - bucket_id_start * bucket_size + row_count_cumsum[t]
2661
+ )
2662
+ return tensor.detach().cpu()[linearized_local_ids].view(-1)
2663
+
2664
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
2665
+ return [
2666
+ [_slice("momentum1", self.momentum1_dev, t, rowwise=True)]
2667
+ for t, _ in enumerate(rows)
2668
+ ]
2669
+
2670
+ elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
2671
+ return [
2672
+ [
2673
+ _slice("momentum1", self.momentum1_dev, t, rowwise=False),
2674
+ # pyre-ignore[6]
2675
+ _slice("momentum2", self.momentum2_dev, t, rowwise=True),
2676
+ ]
2677
+ for t, _ in enumerate(rows)
2678
+ ]
2679
+
2680
+ elif self.optimizer == OptimType.ADAM:
2681
+ return [
2682
+ [
2683
+ _slice("momentum1", self.momentum1_dev, t, rowwise=False),
2684
+ # pyre-ignore[6]
2685
+ _slice("momentum2", self.momentum2_dev, t, rowwise=False),
2686
+ ]
2687
+ for t, _ in enumerate(rows)
2688
+ ]
2689
+
2690
+ else:
2691
+ raise NotImplementedError(
2692
+ f"Getting optimizer states is not supported for {self.optimizer}"
2693
+ )
2694
+
2695
+ @torch.jit.ignore
2696
+ def _split_optimizer_states_kv_zch_w_offloading(
2697
+ self,
2698
+ sorted_ids: torch.Tensor,
2699
+ no_snapshot: bool = True,
2700
+ should_flush: bool = False,
2701
+ ) -> list[list[torch.Tensor]]:
2702
+ dtype = self.weights_precision.as_dtype()
2703
+ # Row count per table
2704
+ rows_, dims_ = zip(*self.embedding_specs)
2705
+ # Cumulative row counts per table for rowwise states
2706
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))
2707
+
2708
+ snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
2709
+ no_snapshot=no_snapshot,
2710
+ should_flush=should_flush,
2711
+ )
2712
+
2713
+ # pyre-ignore[53]
2714
+ def _fetch_offloaded_optimizer_states(
2715
+ t: int,
2716
+ ) -> list[Tensor]:
2717
+ e: int = rows_[t]
2718
+ d: int = dims_[t]
2719
+
2720
+ # pyre-ignore[16]
2721
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
2722
+ # pyre-ignore[16]
2723
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
2724
+
2725
+ row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)
2726
+ # Count of rows to fetch
2727
+ rows_to_fetch = sorted_ids[t].numel()
2728
+
2729
+ # Lookup the byte offsets for each optimizer state
2730
+ optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
2731
+ d, self.weights_precision, self.optimizer_state_dtypes
2732
+ )
2733
+ # Find the minimum start of all the start/end pairs - we have to
2734
+ # offset the start/end pairs by this value to get the correct start/end
2735
+ offset_ = min(
2736
+ [start for _, (start, _) in optimizer_state_byte_offsets.items()]
2737
+ )
2738
+ # Update the start/end pairs to be relative to offset_
2739
+ optimizer_state_byte_offsets = dict(
2740
+ (k, (v1 - offset_, v2 - offset_))
2741
+ for k, (v1, v2) in optimizer_state_byte_offsets.items()
2742
+ )
2743
+
2744
+ # Since the backend returns cache rows that pack the weights and
2745
+ # optimizer states together, reading the whole tensor could cause OOM,
2746
+ # so we use the KVTensorWrapper abstraction to query the backend and
2747
+ # fetch the data in chunks instead.
2748
+ tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
2749
+ shape=[
2750
+ e,
2751
+ # Dim is in terms of the **weights** dtype
2752
+ self.optimizer_state_dim,
2753
+ ],
2754
+ dtype=dtype,
2755
+ row_offset=row_offset,
2756
+ snapshot_handle=snapshot_handle,
2757
+ sorted_indices=sorted_ids[t],
2758
+ width_offset=pad4(d),
2759
+ )
2760
+ (
2761
+ tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2762
+ if self.backend_type == BackendType.SSD
2763
+ else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
2764
+ )
2765
+
2766
+ # Fetch the state size table for the given weights dimension
2767
+ state_size_table = self.optimizer.state_size_table(d)
2768
+
2769
+ # Create a 2D output buffer of [rows x optimizer state dim] with the
2770
+ # weights type as the type. For optimizers with multiple states (e.g.
2771
+ # momentum1 and momentum2), this tensor will include data from all
2772
+ # states, hence self.optimizer_state_dim as the row size.
2773
+ optimizer_states_buffer = torch.empty(
2774
+ (rows_to_fetch, self.optimizer_state_dim), dtype=dtype, device="cpu"
2775
+ )
2776
+
2777
+ # Set the chunk size for fetching
2778
+ chunk_size = (
2779
+ # 10M rows => 260(max_D)* 2(ele_bytes) * 10M => 5.2GB mem spike
2780
+ 10_000_000
2781
+ )
2782
+ logging.info(f"split optimizer chunk rows: {chunk_size}")
2783
+
2784
+ # Chunk the fetching by chunk_size
2785
+ for i in range(0, rows_to_fetch, chunk_size):
2786
+ length = min(chunk_size, rows_to_fetch - i)
2787
+
2788
+ # Fetch from backend and copy to the output buffer
2789
+ optimizer_states_buffer.narrow(0, i, length).copy_(
2790
+ tensor_wrapper.narrow(0, i, length).view(dtype)
2791
+ )
2792
+
2793
+ # Now split up the buffer into N views, N for each optimizer state
2794
+ optimizer_states: list[Tensor] = []
2795
+ for state_name in self.optimizer.state_names():
2796
+ # Extract the offsets
2797
+ start, end = optimizer_state_byte_offsets[state_name]
2798
+
2799
+ state = optimizer_states_buffer.view(
2800
+ # Force tensor to byte view
2801
+ dtype=torch.uint8
2802
+ # Copy by byte offsets
2803
+ )[:, start:end].view(
2804
+ # Re-view in the state's target dtype
2805
+ self.optimizer_state_dtypes.get(
2806
+ state_name, SparseType.FP32
2807
+ ).as_dtype()
2808
+ )
2809
+
2810
+ optimizer_states.append(
2811
+ # If the state is rowwise (i.e. just one element per row),
2812
+ # then re-view as 1D tensor
2813
+ state
2814
+ if state_size_table[state_name] > 1
2815
+ else state.view(-1)
2816
+ )
2817
+
2818
+ # Return the views
2819
+ return optimizer_states
1713
2820
 
1714
2821
  return [
1715
2822
  (
1716
- self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(
1717
- row
1718
- ),
2823
+ self.optimizer.empty_states([0], [d], self.optimizer_state_dtypes)[0]
2824
+ # Return a set of empty states if sorted_ids[t] is empty
2825
+ if sorted_ids is None or sorted_ids[t].numel() == 0
2826
+ # Else fetch the list of optimizer states for the table
2827
+ else _fetch_offloaded_optimizer_states(t)
1719
2828
  )
1720
- for t, row in enumerate(rows)
2829
+ for t, d in enumerate(dims_)
1721
2830
  ]
1722
2831
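With optimizer offloading, each backend row packs the padded weights followed by the optimizer states, and the per-state views are recovered by slicing a uint8 view at byte offsets and re-viewing in the state's dtype, as the loop above does. A self-contained sketch of that byte-slicing trick with hypothetical offsets (fp16 fetch buffer, two fp32 rowwise states):

    import torch

    num_rows = 4
    opt_region_dim = 4                  # in weight (fp16) elements -> 8 bytes per row

    # Pretend this is the optimizer region fetched from the backend, in weight dtype
    buf = torch.zeros(num_rows, opt_region_dim, dtype=torch.float16)

    # Hypothetical byte offsets within the fetched region:
    # momentum1 occupies bytes [0, 4), momentum2 occupies bytes [4, 8)
    offsets = {"momentum1": (0, 4), "momentum2": (4, 8)}

    byte_view = buf.view(torch.uint8)                       # [num_rows, 8] bytes
    states = {
        name: byte_view[:, start:end].view(torch.float32)   # re-view in the state dtype
        for name, (start, end) in offsets.items()
    }

    states["momentum1"].fill_(1.0)
    states["momentum2"].fill_(2.0)
    assert states["momentum1"].shape == (num_rows, 1)
    assert states["momentum2"].shape == (num_rows, 1)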
 
2832
+ @torch.jit.ignore
2833
+ def _split_optimizer_states_kv_zch_whole_row(
2834
+ self,
2835
+ sorted_ids: torch.Tensor,
2836
+ no_snapshot: bool = True,
2837
+ should_flush: bool = False,
2838
+ ) -> list[list[torch.Tensor]]:
2839
+ dtype = self.weights_precision.as_dtype()
2840
+
2841
+ # Row and dimension counts per table
2842
+ # rows_ is only used here to compute the virtual table offsets
2843
+ rows_, dims_ = zip(*self.embedding_specs)
2844
+
2845
+ # Cumulative row counts per (virtual) table for rowwise states
2846
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))
2847
+
2848
+ snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
2849
+ no_snapshot=no_snapshot,
2850
+ should_flush=should_flush,
2851
+ )
2852
+
2853
+ # pyre-ignore[53]
2854
+ def _fetch_offloaded_optimizer_states(
2855
+ t: int,
2856
+ ) -> list[Tensor]:
2857
+ d: int = dims_[t]
2858
+
2859
+ # pyre-ignore[16]
2860
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
2861
+ # pyre-ignore[16]
2862
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
2863
+ row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)
2864
+
2865
+ # When backend returns whole row, the optimizer will be returned as
2866
+ # PMT directly
2867
+ if sorted_ids[t].size(0) == 0 and self.local_weight_counts[t] > 0:
2868
+ logging.info(
2869
+ f"Before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
2870
+ )
2871
+ sorted_ids[t] = torch.zeros(
2872
+ (self.local_weight_counts[t], 1),
2873
+ device=torch.device("cpu"),
2874
+ dtype=torch.int64,
2875
+ )
2876
+
2877
+ # Lookup the byte offsets for each optimizer state relative to the
2878
+ # start of the weights
2879
+ optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
2880
+ d, self.weights_precision, self.optimizer_state_dtypes
2881
+ )
2882
+ # Get the number of elements (of the optimizer state dtype) per state
2883
+ optimizer_state_size_table = self.optimizer.state_size_table(d)
2884
+
2885
+ # Get metaheader dimensions in number of elements of weight dtype
2886
+ metaheader_dim = (
2887
+ # pyre-ignore[16]
2888
+ self.kv_zch_params.eviction_policy.meta_header_lens[t]
2889
+ )
2890
+
2891
+ # Now split up the buffer into N views, N for each optimizer state
2892
+ optimizer_states: list[PartiallyMaterializedTensor] = []
2893
+ for state_name in self.optimizer.state_names():
2894
+ state_dtype = self.optimizer_state_dtypes.get(
2895
+ state_name, SparseType.FP32
2896
+ ).as_dtype()
2897
+
2898
+ # Get the size of the state, converted from elements of the state's
2899
+ # dtype into elements of the **weights** dtype
2900
+ state_size = math.ceil(
2901
+ optimizer_state_size_table[state_name]
2902
+ * state_dtype.itemsize
2903
+ / dtype.itemsize
2904
+ )
2905
+
2906
+ # Extract the offsets relative to the start of the weights (in
2907
+ # num bytes)
2908
+ start, _ = optimizer_state_byte_offsets[state_name]
2909
+
2910
+ # Convert the start to number of elements in terms of the
2911
+ # **weights** dtype, then add the mmetaheader dim offset
2912
+ start = metaheader_dim + start // dtype.itemsize
2913
+
2914
+ shape = [
2915
+ (
2916
+ sorted_ids[t].size(0)
2917
+ if sorted_ids is not None and sorted_ids[t].size(0) > 0
2918
+ else self.local_weight_counts[t]
2919
+ ),
2920
+ (
2921
+ # Dim is in terms of the **weights** dtype
2922
+ state_size
2923
+ ),
2924
+ ]
2925
+
2926
+ # NOTE: We have to view using the **weights** dtype, as
2927
+ # there is currently a bug with KVTensorWrapper where using
2928
+ # a different dtype does not result in the same bytes being
2929
+ # returned, e.g.
2930
+ #
2931
+ # KVTensorWrapper(dtype=fp32, width_offset=130, shape=[N, 1])
2932
+ #
2933
+ # is NOT the same as
2934
+ #
2935
+ # KVTensorWrapper(dtype=fp16, width_offset=260, shape=[N, 2]).view(-1).view(fp32)
2936
+ #
2937
+ # TODO: Fix KVTensorWrapper to support viewing data under different dtypes
2938
+ tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
2939
+ shape=shape,
2940
+ dtype=(
2941
+ # NOTE: Use the *weights* dtype
2942
+ dtype
2943
+ ),
2944
+ row_offset=row_offset,
2945
+ snapshot_handle=snapshot_handle,
2946
+ sorted_indices=sorted_ids[t],
2947
+ width_offset=(
2948
+ # NOTE: Width offset is in terms of **weights** dtype
2949
+ start
2950
+ ),
2951
+ # Optimizer written to DB with weights, so skip write here
2952
+ read_only=True,
2953
+ )
2954
+ (
2955
+ tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2956
+ if self.backend_type == BackendType.SSD
2957
+ else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
2958
+ )
2959
+
2960
+ optimizer_states.append(
2961
+ PartiallyMaterializedTensor(tensor_wrapper, True)
2962
+ )
2963
+
2964
+ # pyre-ignore [7]
2965
+ return optimizer_states
2966
+
2967
+ return [_fetch_offloaded_optimizer_states(t) for t, _ in enumerate(dims_)]
2968
+
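The state_size computation above converts a state's element count from its own dtype into units of the weights dtype; for example, one fp32 rowwise value stored alongside fp16 weights occupies ceil(1 * 4 / 2) = 2 weight-dtype slots. A tiny sketch of that conversion:

    import math
    import torch

    def state_size_in_weight_elems(num_state_elems: int,
                                   state_dtype: torch.dtype,
                                   weight_dtype: torch.dtype) -> int:
        # Number of weight-dtype elements needed to hold the state's bytes
        return math.ceil(num_state_elems * state_dtype.itemsize / weight_dtype.itemsize)

    # One rowwise fp32 value next to fp16 weights -> 2 fp16 slots
    assert state_size_in_weight_elems(1, torch.float32, torch.float16) == 2
    # A 128-wide fp16 state next to fp16 weights -> 128 slots
    assert state_size_in_weight_elems(128, torch.float16, torch.float16) == 128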
2969
+ @torch.jit.export
2970
+ def split_optimizer_states(
2971
+ self,
2972
+ sorted_id_tensor: Optional[list[torch.Tensor]] = None,
2973
+ no_snapshot: bool = True,
2974
+ should_flush: bool = False,
2975
+ ) -> list[list[torch.Tensor]]:
2976
+ """
2977
+ Returns a list of optimizer states split by table.
2978
+
2979
+ Since EXACT_ROWWISE_ADAGRAD has small optimizer states, we would generate
2980
+ a full tensor for each table (shard). When other optimizer types are supported,
2981
+ we should integrate with KVTensorWrapper (ssd_split_table_batched_embeddings.cpp)
2982
+ to allow caller to read the optimizer states using `narrow()` in a rolling-window manner.
2983
+
2984
+ Args:
2985
+ sorted_id_tensor (Optional[List[torch.Tensor]]): sorted id tensor by table, used to query optimizer
2986
+ state from backend. Callers should reuse the generated id tensor from the weight state_dict, to guarantee
2987
+ id consistency between weight and optimizer states.
2988
+
2989
+ """
2990
+
2991
+ # Handle the non-KVZCH case
2992
+ if not self.kv_zch_params:
2993
+ # If not in KV ZCH mode, use the non-KV ZCH path
2994
+ return self._split_optimizer_states_non_kv_zch()
2995
+
2996
+ # Handle the loading-from-state-dict case
2997
+ if self.load_state_dict:
2998
+ # Initialize for checkpoint loading
2999
+ assert (
3000
+ self._cached_kvzch_data is not None
3001
+ and self._cached_kvzch_data.cached_optimizer_states_per_table
3002
+ ), "optimizer state is not initialized for load checkpointing"
3003
+
3004
+ return self._cached_kvzch_data.cached_optimizer_states_per_table
3005
+
3006
+ logging.info(
3007
+ f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
3008
+ )
3009
+ start_time = time.time()
3010
+
3011
+ if not self.enable_optimizer_offloading:
3012
+ # Handle the KVZCH non-optimizer-offloading case
3013
+ optimizer_states = self._split_optimizer_states_kv_zch_no_offloading(
3014
+ sorted_id_tensor
3015
+ )
3016
+
3017
+ elif not self.backend_return_whole_row:
3018
+ # Handle the KVZCH with-optimizer-offloading case
3019
+ optimizer_states = self._split_optimizer_states_kv_zch_w_offloading(
3020
+ sorted_id_tensor, no_snapshot, should_flush
3021
+ )
3022
+
3023
+ else:
3024
+ # Handle the KVZCH with-optimizer-offloading backend-whole-row case
3025
+ optimizer_states = self._split_optimizer_states_kv_zch_whole_row(
3026
+ sorted_id_tensor, no_snapshot, should_flush
3027
+ )
3028
+
3029
+ logging.info(
3030
+ f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
3031
+ # pyre-ignore[16]
3032
+ f"num ids list: {None if not sorted_id_tensor else [ids.numel() for ids in sorted_id_tensor]}"
3033
+ )
3034
+
3035
+ return optimizer_states
3036
+
3037
+ @torch.jit.export
3038
+ def get_optimizer_state(
3039
+ self,
3040
+ sorted_id_tensor: Optional[list[torch.Tensor]],
3041
+ no_snapshot: bool = True,
3042
+ should_flush: bool = False,
3043
+ ) -> list[dict[str, torch.Tensor]]:
3044
+ """
3045
+ Returns a list of dictionaries of optimizer states split by table.
3046
+ """
3047
+ states_list: list[list[Tensor]] = self.split_optimizer_states(
3048
+ sorted_id_tensor=sorted_id_tensor,
3049
+ no_snapshot=no_snapshot,
3050
+ should_flush=should_flush,
3051
+ )
3052
+ state_names = self.optimizer.state_names()
3053
+ return [dict(zip(state_names, states)) for states in states_list]
3054
+
1723
3055
  @torch.jit.export
1724
- def debug_split_embedding_weights(self) -> List[torch.Tensor]:
3056
+ def debug_split_embedding_weights(self) -> list[torch.Tensor]:
1725
3057
  """
1726
3058
  Returns a list of weights, split by table.
1727
3059
 
1728
3060
  Testing only, very slow.
1729
3061
  """
1730
- (rows, _) = zip(*self.embedding_specs)
3062
+ rows, _ = zip(*self.embedding_specs)
1731
3063
 
1732
3064
  rows_cumsum = [0] + list(itertools.accumulate(rows))
1733
3065
  splits = []
1734
3066
  get_event = torch.cuda.Event()
1735
3067
 
1736
- for t, (row, dim) in enumerate(self.embedding_specs):
3068
+ for t, (row, _) in enumerate(self.embedding_specs):
1737
3069
  weights = torch.empty(
1738
3070
  (row, self.max_D), dtype=self.weights_precision.as_dtype()
1739
3071
  )
@@ -1765,12 +3097,92 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1765
3097
 
1766
3098
  return splits
1767
3099
 
3100
+ def clear_cache(self) -> None:
3101
+ # clear KV ZCH cache for checkpointing
3102
+ self._cached_kvzch_data = None
3103
+
3104
+ @torch.jit.ignore
3105
+ # pyre-ignore [3] - do not define snapshot class EmbeddingSnapshotHandleWrapper to avoid import dependency in other production code
3106
+ def _may_create_snapshot_for_state_dict(
3107
+ self,
3108
+ no_snapshot: bool = True,
3109
+ should_flush: bool = False,
3110
+ ):
3111
+ """
3112
+ Create a rocksdb snapshot if needed.
3113
+ """
3114
+ start_time = time.time()
3115
+ # Force device synchronize for now
3116
+ torch.cuda.synchronize()
3117
+ snapshot_handle = None
3118
+ checkpoint_handle = None
3119
+ if self.backend_type == BackendType.SSD:
3120
+ # Create a rocksdb snapshot
3121
+ if not no_snapshot:
3122
+ # Flush L1 and L2 caches
3123
+ self.flush(force=should_flush)
3124
+ logging.info(
3125
+ f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
3126
+ )
3127
+ snapshot_handle = self.ssd_db.create_snapshot()
3128
+ checkpoint_handle = self.ssd_db.get_active_checkpoint_uuid(self.step)
3129
+ logging.info(
3130
+ f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
3131
+ )
3132
+ elif self.backend_type == BackendType.DRAM:
3133
+ # If there is any ongoing eviction, wait until it is finished before state_dict
3134
+ # so that we reach a consistent model state before/after state_dict
3135
+ evict_wait_start_time = time.time()
3136
+ self.ssd_db.wait_until_eviction_done()
3137
+ logging.info(
3138
+ f"state_dict wait for ongoing eviction: {time.time() - evict_wait_start_time} s"
3139
+ )
3140
+ self.flush(force=should_flush)
3141
+ return snapshot_handle, checkpoint_handle
3142
+
3143
+ def get_embedding_dim_for_kvt(
3144
+ self, metaheader_dim: int, emb_dim: int, is_loading_checkpoint: bool
3145
+ ) -> int:
3146
+ if self.load_ckpt_without_opt:
3147
+ # For silvertorch publish, we don't want to load the optimizer into the backend due to limited CPU memory on the publish host.
3148
+ # So we need to load the whole row into the state dict when loading the checkpoint in st publish, then only save the weight into the backend; after that,
3149
+ # the backend will only have metaheader + weight.
3150
+ # For the first loading, we need to set the dim to metaheader_dim + emb_dim + optimizer_state_dim, otherwise the checkpoint loading will throw a size mismatch error;
3151
+ # after the first loading, we only need to get metaheader + weight from the backend for the state dict, so we can set the dim to metaheader_dim + emb
3152
+ if is_loading_checkpoint:
3153
+ return (
3154
+ (
3155
+ metaheader_dim # metaheader is already padded
3156
+ + pad4(emb_dim)
3157
+ + pad4(self.optimizer_state_dim)
3158
+ )
3159
+ if self.backend_return_whole_row
3160
+ else emb_dim
3161
+ )
3162
+ else:
3163
+ return metaheader_dim + pad4(emb_dim)
3164
+ else:
3165
+ return (
3166
+ (
3167
+ metaheader_dim # metaheader is already padded
3168
+ + pad4(emb_dim)
3169
+ + pad4(self.optimizer_state_dim)
3170
+ )
3171
+ if self.backend_return_whole_row
3172
+ else emb_dim
3173
+ )
3174
+
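Concretely, in the whole-row case the KVT width is the (already padded) metaheader plus the padded embedding dim plus the padded optimizer-state dim. A worked sketch, assuming pad4 rounds up to the next multiple of 4 and using made-up sizes (metaheader_dim = 8, emb_dim = 130, optimizer_state_dim = 2):

    def pad4(n: int) -> int:
        # Assumed semantics of the helper used above: round up to a multiple of 4
        return (n + 3) // 4 * 4

    metaheader_dim = 8          # already padded
    emb_dim = 130
    optimizer_state_dim = 2

    whole_row_dim = metaheader_dim + pad4(emb_dim) + pad4(optimizer_state_dim)
    weight_only_dim = metaheader_dim + pad4(emb_dim)

    assert pad4(130) == 132
    assert whole_row_dim == 144     # 8 + 132 + 4
    assert weight_only_dim == 140   # 8 + 132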
1768
3175
  @torch.jit.export
1769
3176
  def split_embedding_weights(
1770
3177
  self,
1771
3178
  no_snapshot: bool = True,
1772
3179
  should_flush: bool = False,
1773
- ) -> List[PartiallyMaterializedTensor]:
3180
+ ) -> tuple[ # TODO: make this a NamedTuple for readability
3181
+ Union[list[PartiallyMaterializedTensor], list[torch.Tensor]],
3182
+ Optional[list[torch.Tensor]],
3183
+ Optional[list[torch.Tensor]],
3184
+ Optional[list[torch.Tensor]],
3185
+ ]:
1774
3186
  """
1775
3187
  This method is intended to be used by the checkpointing engine
1776
3188
  only.
@@ -1784,50 +3196,454 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1784
3196
  operation, only set to True when necessary.
1785
3197
 
1786
3198
  Returns:
1787
- a list of partially materialized tensors, each representing a table
3199
+ a tuple of 4 lists, where each element corresponds to a logical table
3200
+ 1st arg: partially materialized tensors, each representing a table
3201
+ 2nd arg: input id sorted in bucket id ascending order
3202
+ 3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
3203
+ where for the i th element, we have i + bucket_id_start = global bucket id
3204
+ 4th arg: kvzch eviction metadata for each input id sorted in bucket id ascending order
1788
3205
  """
1789
- # Force device synchronize for now
1790
- torch.cuda.synchronize()
1791
- # Create a snapshot
1792
- if no_snapshot:
1793
- snapshot_handle = None
1794
- else:
1795
- if should_flush:
1796
- # Flush L1 and L2 caches
1797
- self.flush()
1798
- snapshot_handle = self.ssd_db.create_snapshot()
3206
+ snapshot_handle, checkpoint_handle = self._may_create_snapshot_for_state_dict(
3207
+ no_snapshot=no_snapshot,
3208
+ should_flush=should_flush,
3209
+ )
1799
3210
 
1800
3211
  dtype = self.weights_precision.as_dtype()
1801
- splits = []
3212
+ if self.load_state_dict and self.kv_zch_params:
3213
+ # init for checkpoint loading
3214
+ assert (
3215
+ self._cached_kvzch_data is not None
3216
+ ), "weight id and bucket state are not initialized for load checkpointing"
3217
+ return (
3218
+ self._cached_kvzch_data.cached_weight_tensor_per_table,
3219
+ self._cached_kvzch_data.cached_id_tensor_per_table,
3220
+ self._cached_kvzch_data.cached_bucket_splits,
3221
+ [], # metadata tensor is not needed for checkpointing loading
3222
+ )
3223
+ start_time = time.time()
3224
+ pmt_splits = []
3225
+ bucket_sorted_id_splits = [] if self.kv_zch_params else None
3226
+ active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
3227
+ metadata_splits = [] if self.kv_zch_params else None
3228
+ skip_metadata = False
3229
+
3230
+ table_offset = 0
3231
+ for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
3232
+ is_loading_checkpoint = False
3233
+ bucket_ascending_id_tensor = None
3234
+ bucket_t = None
3235
+ metadata_tensor = None
3236
+ row_offset = table_offset
3237
+ metaheader_dim = 0
3238
+ if self.kv_zch_params:
3239
+ bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
3240
+ # pyre-ignore
3241
+ bucket_size = self.kv_zch_params.bucket_sizes[i]
3242
+ metaheader_dim = (
3243
+ # pyre-ignore[16]
3244
+ self.kv_zch_params.eviction_policy.meta_header_lens[i]
3245
+ )
1802
3246
 
1803
- row_offset = 0
1804
- for emb_height, emb_dim in self.embedding_specs:
1805
- tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
1806
- db=self.ssd_db,
1807
- shape=[emb_height, emb_dim],
1808
- dtype=dtype,
3247
+ # linearize with table offset
3248
+ table_input_id_start = table_offset
3249
+ table_input_id_end = table_offset + emb_height
3250
+ # 1. get all keys from backend for one table
3251
+ unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
3252
+ table_input_id_start,
3253
+ table_input_id_end,
3254
+ table_offset,
3255
+ snapshot_handle,
3256
+ )
3257
+ # 2. sorting keys in bucket ascending order
3258
+ bucket_ascending_id_tensor, bucket_t = (
3259
+ torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
3260
+ unordered_id_tensor,
3261
+ 0, # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
3262
+ 0, # local bucket offset
3263
+ bucket_id_end - bucket_id_start, # local bucket num
3264
+ bucket_size,
3265
+ )
3266
+ )
3267
+ metadata_tensor = self._ssd_db.get_kv_zch_eviction_metadata_by_snapshot(
3268
+ bucket_ascending_id_tensor + table_offset,
3269
+ torch.as_tensor(bucket_ascending_id_tensor.size(0)),
3270
+ snapshot_handle,
3271
+ ).view(-1, 1)
3272
+
3273
+ # 3. convert local id back to global id
3274
+ bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
3275
+
3276
+ if (
3277
+ bucket_ascending_id_tensor.size(0) == 0
3278
+ and self.local_weight_counts[i] > 0
3279
+ ):
3280
+ logging.info(
3281
+ f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
3282
+ )
3283
+ if self.global_id_per_rank[i].numel() != 0:
3284
+ assert (
3285
+ self.local_weight_counts[i]
3286
+ == self.global_id_per_rank[i].numel()
3287
+ ), f"local weight count and global id per rank size mismatch, with {self.local_weight_counts[i]} and {self.global_id_per_rank[i].numel()}"
3288
+ bucket_ascending_id_tensor = self.global_id_per_rank[i].to(
3289
+ device=torch.device("cpu"), dtype=torch.int64
3290
+ )
3291
+ else:
3292
+ bucket_ascending_id_tensor = torch.zeros(
3293
+ (self.local_weight_counts[i], 1),
3294
+ device=torch.device("cpu"),
3295
+ dtype=torch.int64,
3296
+ )
3297
+ skip_metadata = True
3298
+ is_loading_checkpoint = True
3299
+
3300
+ # self.local_weight_counts[i] = 0 # Reset the count
3301
+
3302
+ # pyre-ignore [16] bucket_sorted_id_splits is not None
3303
+ bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
3304
+ active_id_cnt_per_bucket_split.append(bucket_t)
3305
+ if skip_metadata:
3306
+ metadata_splits = None
3307
+ else:
3308
+ metadata_splits.append(metadata_tensor)
3309
+
3310
+ # For KV ZCH TBE, sorted_indices holds global ids for checkpointing and publishing,
3311
+ # but the backend uses local ids during training, so the KVTensorWrapper needs to convert each global id to a local id
3312
+ # first and then linearize the local id with the table offset; the formula is x + table_offset - local_shard_offset.
3313
+ # To achieve this, the row_offset is set to (table_offset - local_shard_offset)
3314
+ row_offset = table_offset - (bucket_id_start * bucket_size)
3315
+
3316
+ tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
3317
+ shape=[
3318
+ (
3319
+ bucket_ascending_id_tensor.size(0)
3320
+ if bucket_ascending_id_tensor is not None
3321
+ else emb_height
3322
+ ),
3323
+ self.get_embedding_dim_for_kvt(
3324
+ metaheader_dim, emb_dim, is_loading_checkpoint
3325
+ ),
3326
+ ],
3327
+ dtype=dtype,
1809
3328
  row_offset=row_offset,
1810
3329
  snapshot_handle=snapshot_handle,
3330
+ # set bucket_ascending_id_tensor to kvt wrapper, so narrow will follow the id order to return
3331
+ # embedding weights.
3332
+ sorted_indices=(
3333
+ bucket_ascending_id_tensor if self.kv_zch_params else None
3334
+ ),
3335
+ checkpoint_handle=checkpoint_handle,
3336
+ only_load_weight=(
3337
+ True
3338
+ if self.load_ckpt_without_opt and is_loading_checkpoint
3339
+ else False
3340
+ ),
3341
+ )
3342
+ (
3343
+ tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
3344
+ if self.backend_type == BackendType.SSD
3345
+ else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
3346
+ )
3347
+ table_offset += emb_height
3348
+ pmt_splits.append(
3349
+ PartiallyMaterializedTensor(
3350
+ tensor_wrapper,
3351
+ True if self.kv_zch_params else False,
3352
+ )
3353
+ )
3354
+ logging.info(
3355
+ f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms, "
3356
+ )
3357
+ if self.kv_zch_params is not None:
3358
+ logging.info(
3359
+ # pyre-ignore [16]
3360
+ f"num ids list: {[ids.numel() for ids in bucket_sorted_id_splits]}"
3361
+ )
3362
+
3363
+ return (
3364
+ pmt_splits,
3365
+ bucket_sorted_id_splits,
3366
+ active_id_cnt_per_bucket_split,
3367
+ metadata_splits,
3368
+ )
3369
+
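The row-offset comment inside the loop above spells out the id translation the wrapper relies on: the backend stores rows by linearized local id, so setting row_offset = table_offset - bucket_id_start * bucket_size lets a global id x be mapped with a single addition. A small numeric sketch with hypothetical shard parameters:

    # Hypothetical table: starts at linearized row 20_000, and this rank owns
    # buckets [50, 60) of size 1_000, i.e. global ids [50_000, 60_000).
    table_offset = 20_000
    bucket_id_start, bucket_size = 50, 1_000

    row_offset = table_offset - bucket_id_start * bucket_size   # 20_000 - 50_000 = -30_000

    for global_id in (50_000, 50_007, 59_999):
        local_id = global_id - bucket_id_start * bucket_size     # 0, 7, 9_999
        # The wrapper's single addition reproduces the linearized local row
        assert global_id + row_offset == table_offset + local_id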
3370
+ @torch.jit.ignore
3371
+ def _apply_state_dict_w_offloading(self) -> None:
3372
+ # Row count per table
3373
+ rows, _ = zip(*self.embedding_specs)
3374
+ # Cumulative row counts per table for rowwise states
3375
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
3376
+
3377
+ for t, _ in enumerate(self.embedding_specs):
3378
+ # pyre-ignore [16]
3379
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
3380
+ # pyre-ignore [16]
3381
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
3382
+ row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size
3383
+
3384
+ # pyre-ignore [16]
3385
+ weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
3386
+ # pyre-ignore [16]
3387
+ opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]
3388
+
3389
+ self.streaming_write_weight_and_id_per_table(
3390
+ weight_state,
3391
+ opt_states,
3392
+ # pyre-ignore [16]
3393
+ self._cached_kvzch_data.cached_id_tensor_per_table[t],
3394
+ row_offset,
3395
+ )
3396
+ self._cached_kvzch_data.cached_weight_tensor_per_table[t] = None
3397
+ self._cached_kvzch_data.cached_optimizer_states_per_table[t] = None
3398
+
3399
+ @torch.jit.ignore
3400
+ def _apply_state_dict_no_offloading(self) -> None:
3401
+ # Row count per table
3402
+ rows, _ = zip(*self.embedding_specs)
3403
+ # Cumulative row counts per table for rowwise states
3404
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
3405
+
3406
+ def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
3407
+ device = dst.device
3408
+ dst.index_put_(
3409
+ indices=(
3410
+ # indices is expected to be a tuple of Tensors, not Tensor
3411
+ indices.to(device).view(-1),
3412
+ ),
3413
+ values=src.to(device),
3414
+ )
3415
+
3416
+ for t, _ in enumerate(rows):
3417
+ # pyre-ignore [16]
3418
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
3419
+ # pyre-ignore [16]
3420
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
3421
+ row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size
3422
+
3423
+ # pyre-ignore [16]
3424
+ weights = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
3425
+ # pyre-ignore [16]
3426
+ ids = self._cached_kvzch_data.cached_id_tensor_per_table[t]
3427
+ local_ids = ids + row_offset
3428
+
3429
+ logging.info(
3430
+ f"applying sd for table {t} without optimizer offloading, local_ids is {local_ids}"
3431
+ )
3432
+ # pyre-ignore [16]
3433
+ opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]
3434
+
3435
+ # Set up the plan for copying optimizer states over
3436
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
3437
+ mapping = [(opt_states[0], self.momentum1_dev)]
3438
+ elif self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
3439
+ mapping = [
3440
+ (opt_states[0], self.momentum1_dev),
3441
+ (opt_states[1], self.momentum2_dev),
3442
+ ]
3443
+ else:
3444
+ mapping = []
3445
+
3446
+ # Execute the plan and copy the optimizer states over
3447
+ # pyre-ignore [6]
3448
+ [copy_optimizer_state_(dst, src, local_ids) for (src, dst) in mapping]
3449
+
3450
+ self.ssd_db.set_cuda(
3451
+ local_ids.view(-1),
3452
+ weights,
3453
+ torch.as_tensor(local_ids.size(0)),
3454
+ 1,
3455
+ False,
3456
+ )
3457
+
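copy_optimizer_state_ above relies on Tensor.index_put_, which expects a tuple of index tensors (one per indexed dimension) rather than a single tensor, hence the (indices.view(-1),) wrapping. A minimal sketch:

    import torch

    dst = torch.zeros(6)
    src = torch.tensor([1.0, 2.0, 3.0])
    ids = torch.tensor([[4], [0], [2]])   # shaped like the cached id tensors, [N, 1]

    # index_put_ takes a tuple of index tensors, one per indexed dimension
    dst.index_put_(indices=(ids.view(-1),), values=src)

    assert dst.tolist() == [2.0, 0.0, 3.0, 0.0, 1.0, 0.0]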
3458
+ @torch.jit.ignore
3459
+ def apply_state_dict(self) -> None:
3460
+ if self.backend_return_whole_row:
3461
+ logging.info(
3462
+ "backend_return_whole_row is enabled, no need to apply_state_dict"
3463
+ )
3464
+ return
3465
+ # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
3466
+ # Caller should call this function to apply the cached states to backend.
3467
+ if self.load_state_dict is False:
3468
+ return
3469
+ self.load_state_dict = False
3470
+ assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
3471
+ assert (
3472
+ self._cached_kvzch_data is not None
3473
+ and self._cached_kvzch_data.cached_optimizer_states_per_table is not None
3474
+ ), "optimizer state is not initialized for load checkpointing"
3475
+ assert (
3476
+ self._cached_kvzch_data.cached_weight_tensor_per_table is not None
3477
+ and self._cached_kvzch_data.cached_id_tensor_per_table is not None
3478
+ ), "weight and id state is not initialized for load checkpointing"
3479
+
3480
+ # Compute the number of elements of cache_dtype needed to store the
3481
+ # optimizer state, round to the nearest 4
3482
+ # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
3483
+ # apply weight and optimizer state per table
3484
+ if self.enable_optimizer_offloading:
3485
+ self._apply_state_dict_w_offloading()
3486
+ else:
3487
+ self._apply_state_dict_no_offloading()
3488
+
3489
+ self.clear_cache()
3490
+
3491
+ @torch.jit.ignore
3492
+ def streaming_write_weight_and_id_per_table(
3493
+ self,
3494
+ weight_state: torch.Tensor,
3495
+ opt_states: list[torch.Tensor],
3496
+ id_tensor: torch.Tensor,
3497
+ row_offset: int,
3498
+ ) -> None:
3499
+ """
3500
+ This function is used to write weight, optimizer and id to the backend using kvt wrapper.
3501
+ To avoid excessive memory use, we write the weights and ids to the backend in a rolling-window manner.
3502
+
3503
+ Args:
3504
+ weight_state (torch.tensor): The weight state tensor to be written.
3505
+ opt_states (torch.tensor): The optimizer state tensor(s) to be written.
3506
+ id_tensor (torch.tensor): The id tensor to be written.
3507
+ """
3508
+ D = weight_state.size(1)
3509
+ dtype = self.weights_precision.as_dtype()
3510
+
3511
+ optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
3512
+ D, self.weights_precision, self.optimizer_state_dtypes
3513
+ )
3514
+ optimizer_state_size_table = self.optimizer.state_size_table(D)
3515
+
3516
+ kvt = torch.classes.fbgemm.KVTensorWrapper(
3517
+ shape=[weight_state.size(0), self.cache_row_dim],
3518
+ dtype=dtype,
3519
+ row_offset=row_offset,
3520
+ snapshot_handle=None,
3521
+ sorted_indices=id_tensor,
3522
+ )
3523
+ (
3524
+ kvt.set_embedding_rocks_dp_wrapper(self.ssd_db)
3525
+ if self.backend_type == BackendType.SSD
3526
+ else kvt.set_dram_db_wrapper(self.ssd_db)
3527
+ )
3528
+
3529
+ # TODO: make chunk_size configurable or dynamic
3530
+ chunk_size = 10000
3531
+ row = weight_state.size(0)
3532
+
3533
+ for i in range(0, row, chunk_size):
3534
+ # Construct the chunk buffer, using the weights precision as the dtype
3535
+ length = min(chunk_size, row - i)
3536
+ chunk_buffer = torch.empty(
3537
+ length,
3538
+ self.cache_row_dim,
3539
+ dtype=dtype,
3540
+ device="cpu",
3541
+ )
3542
+
3543
+ # Copy the weight state over to the chunk buffer
3544
+ chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
3545
+
3546
+ # Copy the optimizer state(s) over to the chunk buffer
3547
+ for o, opt_state in enumerate(opt_states):
3548
+ # Fetch the state name based on the index
3549
+ state_name = self.optimizer.state_names()[o]
3550
+
3551
+ # Fetch the byte offsets for the optimizer state by its name
3552
+ start, end = optimizer_state_byte_offsets[state_name]
3553
+
3554
+ # Assume that the opt_state passed in already has dtype matching
3555
+ # self.optimizer_state_dtypes[state_name]
3556
+ opt_state_byteview = opt_state.view(
3557
+ # Force it to be 2D table, with row size matching the
3558
+ # optimizer state size
3559
+ -1,
3560
+ optimizer_state_size_table[state_name],
3561
+ ).view(
3562
+ # Then force tensor to byte view
3563
+ dtype=torch.uint8
3564
+ )
3565
+
3566
+ # Convert the chunk buffer and optimizer state to byte views
3567
+ # Then use the start and end offsets to narrow the chunk buffer
3568
+ # and copy opt_state over
3569
+ chunk_buffer.view(dtype=torch.uint8)[:, start:end] = opt_state_byteview[
3570
+ i : i + length, :
3571
+ ]
3572
+
3573
+ # Write chunk to KVTensor
3574
+ kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))
3575
+
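streaming_write_weight_and_id_per_table packs each row's weight bytes and its optimizer-state bytes into one chunk-buffer row before handing the chunk to the KVTensorWrapper. Below is a minimal, standalone sketch of that byte-offset packing with assumed sizes (4 fp16 weight elements followed by a single fp32 state, i.e. a 12-byte row); the real offsets come from optimizer.byte_offsets_along_row and are not hard-coded like this:

    import torch

    N, D = 3, 4
    weights = torch.randn(N, D, dtype=torch.float16)        # 8 bytes of weights per row
    momentum = torch.randn(N, 1, dtype=torch.float32)       # 4 bytes of optimizer state per row

    row_buffer = torch.empty(N, 6, dtype=torch.float16)     # 6 fp16 elements = 12 bytes per row
    row_buffer[:, :D] = weights                              # weight bytes occupy [0, 8)
    row_buffer.view(dtype=torch.uint8)[:, 8:12] = momentum.view(dtype=torch.uint8)  # state bytes occupy [8, 12)

    # Unpacking reverses the byte-offset slicing.
    recovered = row_buffer.view(dtype=torch.uint8)[:, 8:12].view(dtype=torch.float32)
    assert torch.equal(recovered, momentum)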
3576
+ @torch.jit.ignore
3577
+ def enable_load_state_dict_mode(self) -> None:
3578
+ if self.backend_return_whole_row:
3579
+ logging.info(
3580
+ "backend_return_whole_row is enabled, no need to enable load_state_dict mode"
3581
+ )
3582
+ return
3583
+ # Enable load state dict mode before loading checkpoint
3584
+ if self.load_state_dict:
3585
+ return
3586
+ self.load_state_dict = True
3587
+
3588
+ dtype = self.weights_precision.as_dtype()
3589
+ _, dims = zip(*self.embedding_specs)
3590
+
3591
+ self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
3592
+
3593
+ for i, _ in enumerate(self.embedding_specs):
3594
+ # For checkpoint loading, we need to store the weight and id
3595
+ # tensor temporarily in memory. First check that the local_weight_counts
3596
+ # are properly set before even initializing the optimizer states
3597
+ assert (
3598
+ self.local_weight_counts[i] > 0
3599
+ ), f"local_weight_counts for table {i} is not set"
3600
+
3601
+ # pyre-ignore [16]
3602
+ self._cached_kvzch_data.cached_optimizer_states_per_table = (
3603
+ self.optimizer.empty_states(
3604
+ self.local_weight_counts,
3605
+ dims,
3606
+ self.optimizer_state_dtypes,
3607
+ )
3608
+ )
3609
+
3610
+ for i, (_, emb_dim) in enumerate(self.embedding_specs):
3611
+ # pyre-ignore [16]
3612
+ bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
3613
+ rows = self.local_weight_counts[i]
3614
+ weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
3615
+ # pyre-ignore [16]
3616
+ self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
3617
+ logging.info(
3618
+ f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}"
3619
+ )
3620
+ id_tensor = torch.zeros((rows, 1), dtype=torch.int64, device="cpu")
3621
+ # pyre-ignore [16]
3622
+ self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
3623
+ # pyre-ignore [16]
3624
+ self._cached_kvzch_data.cached_bucket_splits.append(
3625
+ torch.empty(
3626
+ (bucket_id_end - bucket_id_start, 1),
3627
+ dtype=torch.int64,
3628
+ device="cpu",
3629
+ )
1811
3630
  )
1812
- row_offset += emb_height
1813
- splits.append(PartiallyMaterializedTensor(tensor_wrapper))
1814
- return splits
1815
3631
 
1816
3632
  @torch.jit.export
1817
3633
  def set_learning_rate(self, lr: float) -> None:
1818
3634
  """
1819
3635
  Sets the learning rate.
3636
+
3637
+ Args:
3638
+ lr (float): The learning rate value to set to
1820
3639
  """
1821
3640
  self._set_learning_rate(lr)
1822
3641
 
1823
3642
  def get_learning_rate(self) -> float:
1824
3643
  """
1825
- Sets the learning rate.
1826
-
1827
- Args:
1828
- lr (float): The learning rate value to set to
3644
+ Return the current learning rate.
1829
3645
  """
1830
- return self.optimizer_args.learning_rate
3646
+ return self.learning_rate_tensor.item()
1831
3647
 
1832
3648
  @torch.jit.ignore
1833
3649
  def _set_learning_rate(self, lr: float) -> float:
@@ -1835,14 +3651,30 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1835
3651
  Helper function to script `set_learning_rate`.
1836
3652
  Note that returning None does not work.
1837
3653
  """
1838
- self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
3654
+ self.learning_rate_tensor = torch.tensor(
3655
+ lr, device=torch.device("cpu"), dtype=torch.float32
3656
+ )
1839
3657
  return 0.0
1840
3658
 
1841
- def flush(self) -> None:
3659
+ def flush(self, force: bool = False) -> None:
3660
+ # allow force flush from split_embedding_weights to cover edge cases, e.g. checkpointing
3661
+ # after training 0 batches
3662
+ if not self.training:
3663
+ # in eval mode, we should not write anything back to the embedding storage
3664
+ return
3665
+
3666
+ if self.step == self.last_flush_step and not force:
3667
+ logging.info(
3668
+ f"SSD TBE has been flushed at {self.last_flush_step=} already for tbe:{self.tbe_unique_id}"
3669
+ )
3670
+ return
3671
+ logging.info(
3672
+ f"SSD TBE flush at {self.step=}, it is an expensive call please be cautious"
3673
+ )
1842
3674
  active_slots_mask = self.lxu_cache_state != -1
1843
3675
 
1844
3676
  active_weights_gpu = self.lxu_cache_weights[active_slots_mask.view(-1)].view(
1845
- -1, self.max_D
3677
+ -1, self.cache_row_dim
1846
3678
  )
1847
3679
  active_ids_gpu = self.lxu_cache_state.view(-1)[active_slots_mask.view(-1)]
1848
3680
 
@@ -1858,24 +3690,38 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1858
3690
  torch.tensor([active_ids_cpu.numel()]),
1859
3691
  )
1860
3692
  self.ssd_db.flush()
3693
+ self.last_flush_step = self.step
3694
+
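The new flush signature adds two guards: it is a no-op in eval mode, and it skips the write-back when the step has not advanced since the last flush unless force=True (the checkpoint-after-zero-trained-batches edge case noted in the comment above). A condensed, hypothetical sketch of just that guard logic:

    def should_flush(training: bool, step: int, last_flush_step: int, force: bool = False) -> bool:
        if not training:
            return False                  # never write back in eval mode
        if step == last_flush_step and not force:
            return False                  # already flushed at this step
        return True                       # proceed with the expensive write-back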
3695
+ def create_rocksdb_hard_link_snapshot(self) -> None:
3696
+ """
3697
+ Create a rocksdb hard link snapshot to provide cross-process access to the underlying data
3698
+ """
3699
+ if self.backend_type == BackendType.SSD:
3700
+ self.ssd_db.create_rocksdb_hard_link_snapshot(self.step)
3701
+ else:
3702
+ logging.warning(
3703
+ "create_rocksdb_hard_link_snapshot is only supported for SSD backend"
3704
+ )
1861
3705
 
1862
3706
  def prepare_inputs(
1863
3707
  self,
1864
3708
  indices: Tensor,
1865
3709
  offsets: Tensor,
1866
3710
  per_sample_weights: Optional[Tensor] = None,
1867
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
1868
- ) -> Tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
3711
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
3712
+ vbe_output: Optional[Tensor] = None,
3713
+ vbe_output_offsets: Optional[Tensor] = None,
3714
+ ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
1869
3715
  """
1870
3716
  Prepare TBE inputs
1871
3717
  """
1872
3718
  # Generate VBE metadata
1873
3719
  vbe_metadata = self._generate_vbe_metadata(
1874
- offsets, batch_size_per_feature_per_rank
3720
+ offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
1875
3721
  )
1876
3722
 
1877
3723
  # Force casting indices and offsets to long
1878
- (indices, offsets) = indices.long(), offsets.long()
3724
+ indices, offsets = indices.long(), offsets.long()
1879
3725
 
1880
3726
  # Force casting per_sample_weights to float
1881
3727
  if per_sample_weights is not None:
@@ -1891,12 +3737,13 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1891
3737
  per_sample_weights,
1892
3738
  B_offsets=vbe_metadata.B_offsets,
1893
3739
  max_B=vbe_metadata.max_B,
3740
+ bounds_check_version=self.bounds_check_version,
1894
3741
  )
1895
3742
 
1896
3743
  return indices, offsets, per_sample_weights, vbe_metadata
1897
3744
 
1898
3745
  @torch.jit.ignore
1899
- def _report_ssd_stats(self) -> None:
3746
+ def _report_kv_backend_stats(self) -> None:
1900
3747
  """
1901
3748
  All ssd stats report function entrance
1902
3749
  """
@@ -1906,9 +3753,15 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1906
3753
  if not self.stats_reporter.should_report(self.step):
1907
3754
  return
1908
3755
  self._report_ssd_l1_cache_stats()
1909
- self._report_ssd_io_stats()
1910
- self._report_ssd_mem_usage()
1911
- self._report_l2_cache_perf_stats()
3756
+
3757
+ if self.backend_type == BackendType.SSD:
3758
+ self._report_ssd_io_stats()
3759
+ self._report_ssd_mem_usage()
3760
+ self._report_l2_cache_perf_stats()
3761
+ if self.backend_type == BackendType.DRAM:
3762
+ self._report_dram_kv_perf_stats()
3763
+ if self.kv_zch_params and self.kv_zch_params.eviction_policy:
3764
+ self._report_eviction_stats()
1912
3765
 
1913
3766
  @torch.jit.ignore
1914
3767
  def _report_ssd_l1_cache_stats(self) -> None:
@@ -1925,7 +3778,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1925
3778
  ssd_cache_stats = self.ssd_cache_stats.tolist()
1926
3779
  if len(self.last_reported_ssd_stats) == 0:
1927
3780
  self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
1928
- ssd_cache_stats_delta: List[float] = [0.0] * len(ssd_cache_stats)
3781
+ ssd_cache_stats_delta: list[float] = [0.0] * len(ssd_cache_stats)
1929
3782
  for i in range(len(ssd_cache_stats)):
1930
3783
  ssd_cache_stats_delta[i] = (
1931
3784
  ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
@@ -1942,11 +3795,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1942
3795
  data_bytes=int(
1943
3796
  ssd_cache_stats_delta[stat_index.value]
1944
3797
  * element_size
1945
- * self.max_D
3798
+ * self.cache_row_dim
1946
3799
  / passed_steps
1947
3800
  ),
1948
3801
  )
1949
- # pyre-ignore
3802
+
1950
3803
  self.stats_reporter.report_data_amount(
1951
3804
  iteration_step=self.step,
1952
3805
  event_name=f"ssd_tbe.prefetch.cache_stats.{stat_index.name.lower()}",
@@ -1973,35 +3826,35 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
1973
3826
  bwd_l1_cnflct_miss_write_back_dur = ssd_io_duration[3]
1974
3827
  flush_write_dur = ssd_io_duration[4]
1975
3828
 
1976
- # pyre-ignore
3829
+ # pyre-ignore [16]
1977
3830
  self.stats_reporter.report_duration(
1978
3831
  iteration_step=self.step,
1979
3832
  event_name="ssd.io_duration.read_us",
1980
3833
  duration_ms=ssd_read_dur_us,
1981
3834
  time_unit="us",
1982
3835
  )
1983
- # pyre-ignore
3836
+
1984
3837
  self.stats_reporter.report_duration(
1985
3838
  iteration_step=self.step,
1986
3839
  event_name="ssd.io_duration.write.fwd_rocksdb_read_us",
1987
3840
  duration_ms=fwd_rocksdb_read_dur,
1988
3841
  time_unit="us",
1989
3842
  )
1990
- # pyre-ignore
3843
+
1991
3844
  self.stats_reporter.report_duration(
1992
3845
  iteration_step=self.step,
1993
3846
  event_name="ssd.io_duration.write.fwd_l1_eviction_us",
1994
3847
  duration_ms=fwd_l1_eviction_dur,
1995
3848
  time_unit="us",
1996
3849
  )
1997
- # pyre-ignore
3850
+
1998
3851
  self.stats_reporter.report_duration(
1999
3852
  iteration_step=self.step,
2000
3853
  event_name="ssd.io_duration.write.bwd_l1_cnflct_miss_write_back_us",
2001
3854
  duration_ms=bwd_l1_cnflct_miss_write_back_dur,
2002
3855
  time_unit="us",
2003
3856
  )
2004
- # pyre-ignore
3857
+
2005
3858
  self.stats_reporter.report_duration(
2006
3859
  iteration_step=self.step,
2007
3860
  event_name="ssd.io_duration.write.flush_write_us",
@@ -2023,25 +3876,25 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2023
3876
  memtable_usage = mem_usage_list[2]
2024
3877
  block_cache_pinned_usage = mem_usage_list[3]
2025
3878
 
2026
- # pyre-ignore
3879
+ # pyre-ignore [16]
2027
3880
  self.stats_reporter.report_data_amount(
2028
3881
  iteration_step=self.step,
2029
3882
  event_name="ssd.mem_usage.block_cache",
2030
3883
  data_bytes=block_cache_usage,
2031
3884
  )
2032
- # pyre-ignore
3885
+
2033
3886
  self.stats_reporter.report_data_amount(
2034
3887
  iteration_step=self.step,
2035
3888
  event_name="ssd.mem_usage.estimate_table_reader",
2036
3889
  data_bytes=estimate_table_reader_usage,
2037
3890
  )
2038
- # pyre-ignore
3891
+
2039
3892
  self.stats_reporter.report_data_amount(
2040
3893
  iteration_step=self.step,
2041
3894
  event_name="ssd.mem_usage.memtable",
2042
3895
  data_bytes=memtable_usage,
2043
3896
  )
2044
- # pyre-ignore
3897
+
2045
3898
  self.stats_reporter.report_data_amount(
2046
3899
  iteration_step=self.step,
2047
3900
  event_name="ssd.mem_usage.block_cache_pinned",
@@ -2175,7 +4028,408 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2175
4028
  time_unit="us",
2176
4029
  )
2177
4030
 
2178
- # pyre-ignore
4031
+ @torch.jit.ignore
4032
+ def _report_eviction_stats(self) -> None:
4033
+ if self.stats_reporter is None:
4034
+ return
4035
+
4036
+ stats_reporter: TBEStatsReporter = self.stats_reporter
4037
+ if not stats_reporter.should_report(self.step):
4038
+ return
4039
+
4040
+ # skip metrics reporting when eviction is disabled
4041
+ if self.kv_zch_params.eviction_policy.eviction_trigger_mode == 0:
4042
+ return
4043
+
4044
+ T = len(set(self.feature_table_map))
4045
+ evicted_counts = torch.zeros(T, dtype=torch.int64)
4046
+ processed_counts = torch.zeros(T, dtype=torch.int64)
4047
+ eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float)
4048
+ full_duration_ms = torch.tensor(0, dtype=torch.int64)
4049
+ exec_duration_ms = torch.tensor(0, dtype=torch.int64)
4050
+ self.ssd_db.get_feature_evict_metric(
4051
+ evicted_counts,
4052
+ processed_counts,
4053
+ eviction_threshold_with_dry_run,
4054
+ full_duration_ms,
4055
+ exec_duration_ms,
4056
+ )
4057
+
4058
+ stats_reporter.report_data_amount(
4059
+ iteration_step=self.step,
4060
+ event_name=self.eviction_sum_evicted_counts_stats_name,
4061
+ data_bytes=int(evicted_counts.sum().item()),
4062
+ enable_tb_metrics=True,
4063
+ )
4064
+ stats_reporter.report_data_amount(
4065
+ iteration_step=self.step,
4066
+ event_name=self.eviction_sum_processed_counts_stats_name,
4067
+ data_bytes=int(processed_counts.sum().item()),
4068
+ enable_tb_metrics=True,
4069
+ )
4070
+ if processed_counts.sum().item() != 0:
4071
+ stats_reporter.report_data_amount(
4072
+ iteration_step=self.step,
4073
+ event_name=self.eviction_evict_rate_stats_name,
4074
+ data_bytes=int(
4075
+ evicted_counts.sum().item() * 100 / processed_counts.sum().item()
4076
+ ),
4077
+ enable_tb_metrics=True,
4078
+ )
4079
+ for t in self.feature_table_map:
4080
+ stats_reporter.report_data_amount(
4081
+ iteration_step=self.step,
4082
+ event_name=f"eviction.feature_table.{t}.evicted_counts",
4083
+ data_bytes=int(evicted_counts[t].item()),
4084
+ enable_tb_metrics=True,
4085
+ )
4086
+ stats_reporter.report_data_amount(
4087
+ iteration_step=self.step,
4088
+ event_name=f"eviction.feature_table.{t}.processed_counts",
4089
+ data_bytes=int(processed_counts[t].item()),
4090
+ enable_tb_metrics=True,
4091
+ )
4092
+ if processed_counts[t].item() != 0:
4093
+ stats_reporter.report_data_amount(
4094
+ iteration_step=self.step,
4095
+ event_name=f"eviction.feature_table.{t}.evict_rate",
4096
+ data_bytes=int(
4097
+ evicted_counts[t].item() * 100 / processed_counts[t].item()
4098
+ ),
4099
+ enable_tb_metrics=True,
4100
+ )
4101
+ stats_reporter.report_duration(
4102
+ iteration_step=self.step,
4103
+ event_name="eviction.feature_table.full_duration_ms",
4104
+ duration_ms=full_duration_ms.item(),
4105
+ time_unit="ms",
4106
+ enable_tb_metrics=True,
4107
+ )
4108
+ stats_reporter.report_duration(
4109
+ iteration_step=self.step,
4110
+ event_name="eviction.feature_table.exec_duration_ms",
4111
+ duration_ms=exec_duration_ms.item(),
4112
+ time_unit="ms",
4113
+ enable_tb_metrics=True,
4114
+ )
4115
+ if full_duration_ms.item() != 0:
4116
+ stats_reporter.report_data_amount(
4117
+ iteration_step=self.step,
4118
+ event_name="eviction.feature_table.exec_div_full_duration_rate",
4119
+ data_bytes=int(exec_duration_ms.item() * 100 / full_duration_ms.item()),
4120
+ enable_tb_metrics=True,
4121
+ )
4122
+
4123
+ @torch.jit.ignore
4124
+ def _report_dram_kv_perf_stats(self) -> None:
4125
+ """
4126
+ EmbeddingKVDB will hold stats for DRAM cache performance in fwd/bwd
4127
+ this function fetches the stats from EmbeddingKVDB and reports them via stats_reporter
4128
+ """
4129
+ if self.stats_reporter is None:
4130
+ return
4131
+
4132
+ stats_reporter: TBEStatsReporter = self.stats_reporter
4133
+ if not stats_reporter.should_report(self.step):
4134
+ return
4135
+
4136
+ dram_kv_perf_stats = self.ssd_db.get_dram_kv_perf(
4137
+ self.step, stats_reporter.report_interval # pyre-ignore
4138
+ )
4139
+
4140
+ if len(dram_kv_perf_stats) != 36:
4141
+ logging.error("dram cache perf stats should have 36 elements")
4142
+ return
4143
+
4144
+ dram_read_duration = dram_kv_perf_stats[0]
4145
+ dram_read_sharding_duration = dram_kv_perf_stats[1]
4146
+ dram_read_cache_hit_copy_duration = dram_kv_perf_stats[2]
4147
+ dram_read_fill_row_storage_duration = dram_kv_perf_stats[3]
4148
+ dram_read_lookup_cache_duration = dram_kv_perf_stats[4]
4149
+ dram_read_acquire_lock_duration = dram_kv_perf_stats[5]
4150
+ dram_read_missing_load = dram_kv_perf_stats[6]
4151
+ dram_write_sharing_duration = dram_kv_perf_stats[7]
4152
+
4153
+ dram_fwd_l1_eviction_write_duration = dram_kv_perf_stats[8]
4154
+ dram_fwd_l1_eviction_write_allocate_duration = dram_kv_perf_stats[9]
4155
+ dram_fwd_l1_eviction_write_cache_copy_duration = dram_kv_perf_stats[10]
4156
+ dram_fwd_l1_eviction_write_lookup_cache_duration = dram_kv_perf_stats[11]
4157
+ dram_fwd_l1_eviction_write_acquire_lock_duration = dram_kv_perf_stats[12]
4158
+ dram_fwd_l1_eviction_write_missing_load = dram_kv_perf_stats[13]
4159
+
4160
+ dram_bwd_l1_cnflct_miss_write_duration = dram_kv_perf_stats[14]
4161
+ dram_bwd_l1_cnflct_miss_write_allocate_duration = dram_kv_perf_stats[15]
4162
+ dram_bwd_l1_cnflct_miss_write_cache_copy_duration = dram_kv_perf_stats[16]
4163
+ dram_bwd_l1_cnflct_miss_write_lookup_cache_duration = dram_kv_perf_stats[17]
4164
+ dram_bwd_l1_cnflct_miss_write_acquire_lock_duration = dram_kv_perf_stats[18]
4165
+ dram_bwd_l1_cnflct_miss_write_missing_load = dram_kv_perf_stats[19]
4166
+
4167
+ dram_kv_allocated_bytes = dram_kv_perf_stats[20]
4168
+ dram_kv_actual_used_chunk_bytes = dram_kv_perf_stats[21]
4169
+ dram_kv_num_rows = dram_kv_perf_stats[22]
4170
+ dram_kv_read_counts = dram_kv_perf_stats[23]
4171
+ dram_metadata_write_sharding_total_duration = dram_kv_perf_stats[24]
4172
+ dram_metadata_write_total_duration = dram_kv_perf_stats[25]
4173
+ dram_metadata_write_allocate_avg_duration = dram_kv_perf_stats[26]
4174
+ dram_metadata_write_lookup_cache_avg_duration = dram_kv_perf_stats[27]
4175
+ dram_metadata_write_acquire_lock_avg_duration = dram_kv_perf_stats[28]
4176
+ dram_metadata_write_cache_miss_avg_count = dram_kv_perf_stats[29]
4177
+
4178
+ dram_read_metadata_total_duration = dram_kv_perf_stats[30]
4179
+ dram_read_metadata_sharding_total_duration = dram_kv_perf_stats[31]
4180
+ dram_read_metadata_cache_hit_copy_avg_duration = dram_kv_perf_stats[32]
4181
+ dram_read_metadata_lookup_cache_total_avg_duration = dram_kv_perf_stats[33]
4182
+ dram_read_metadata_acquire_lock_avg_duration = dram_kv_perf_stats[34]
4183
+ dram_read_read_metadata_load_size = dram_kv_perf_stats[35]
4184
+
4185
+ stats_reporter.report_duration(
4186
+ iteration_step=self.step,
4187
+ event_name="dram_kv.perf.get.dram_read_duration_us",
4188
+ duration_ms=dram_read_duration,
4189
+ enable_tb_metrics=True,
4190
+ time_unit="us",
4191
+ )
4192
+ stats_reporter.report_duration(
4193
+ iteration_step=self.step,
4194
+ event_name="dram_kv.perf.get.dram_read_sharding_duration_us",
4195
+ duration_ms=dram_read_sharding_duration,
4196
+ enable_tb_metrics=True,
4197
+ time_unit="us",
4198
+ )
4199
+ stats_reporter.report_duration(
4200
+ iteration_step=self.step,
4201
+ event_name="dram_kv.perf.get.dram_read_cache_hit_copy_duration_us",
4202
+ duration_ms=dram_read_cache_hit_copy_duration,
4203
+ enable_tb_metrics=True,
4204
+ time_unit="us",
4205
+ )
4206
+ stats_reporter.report_duration(
4207
+ iteration_step=self.step,
4208
+ event_name="dram_kv.perf.get.dram_read_fill_row_storage_duration_us",
4209
+ duration_ms=dram_read_fill_row_storage_duration,
4210
+ enable_tb_metrics=True,
4211
+ time_unit="us",
4212
+ )
4213
+ stats_reporter.report_duration(
4214
+ iteration_step=self.step,
4215
+ event_name="dram_kv.perf.get.dram_read_lookup_cache_duration_us",
4216
+ duration_ms=dram_read_lookup_cache_duration,
4217
+ enable_tb_metrics=True,
4218
+ time_unit="us",
4219
+ )
4220
+ stats_reporter.report_duration(
4221
+ iteration_step=self.step,
4222
+ event_name="dram_kv.perf.get.dram_read_acquire_lock_duration_us",
4223
+ duration_ms=dram_read_acquire_lock_duration,
4224
+ enable_tb_metrics=True,
4225
+ time_unit="us",
4226
+ )
4227
+ stats_reporter.report_data_amount(
4228
+ iteration_step=self.step,
4229
+ event_name="dram_kv.perf.get.dram_read_missing_load",
4230
+ enable_tb_metrics=True,
4231
+ data_bytes=dram_read_missing_load,
4232
+ )
4233
+ stats_reporter.report_duration(
4234
+ iteration_step=self.step,
4235
+ event_name="dram_kv.perf.set.dram_write_sharing_duration_us",
4236
+ duration_ms=dram_write_sharing_duration,
4237
+ enable_tb_metrics=True,
4238
+ time_unit="us",
4239
+ )
4240
+
4241
+ stats_reporter.report_duration(
4242
+ iteration_step=self.step,
4243
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_duration_us",
4244
+ duration_ms=dram_fwd_l1_eviction_write_duration,
4245
+ enable_tb_metrics=True,
4246
+ time_unit="us",
4247
+ )
4248
+ stats_reporter.report_duration(
4249
+ iteration_step=self.step,
4250
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_allocate_duration_us",
4251
+ duration_ms=dram_fwd_l1_eviction_write_allocate_duration,
4252
+ enable_tb_metrics=True,
4253
+ time_unit="us",
4254
+ )
4255
+ stats_reporter.report_duration(
4256
+ iteration_step=self.step,
4257
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_cache_copy_duration_us",
4258
+ duration_ms=dram_fwd_l1_eviction_write_cache_copy_duration,
4259
+ enable_tb_metrics=True,
4260
+ time_unit="us",
4261
+ )
4262
+ stats_reporter.report_duration(
4263
+ iteration_step=self.step,
4264
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_lookup_cache_duration_us",
4265
+ duration_ms=dram_fwd_l1_eviction_write_lookup_cache_duration,
4266
+ enable_tb_metrics=True,
4267
+ time_unit="us",
4268
+ )
4269
+ stats_reporter.report_duration(
4270
+ iteration_step=self.step,
4271
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_acquire_lock_duration_us",
4272
+ duration_ms=dram_fwd_l1_eviction_write_acquire_lock_duration,
4273
+ enable_tb_metrics=True,
4274
+ time_unit="us",
4275
+ )
4276
+ stats_reporter.report_data_amount(
4277
+ iteration_step=self.step,
4278
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load",
4279
+ data_bytes=dram_fwd_l1_eviction_write_missing_load,
4280
+ enable_tb_metrics=True,
4281
+ )
4282
+
4283
+ stats_reporter.report_duration(
4284
+ iteration_step=self.step,
4285
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_duration_us",
4286
+ duration_ms=dram_bwd_l1_cnflct_miss_write_duration,
4287
+ enable_tb_metrics=True,
4288
+ time_unit="us",
4289
+ )
4290
+ stats_reporter.report_duration(
4291
+ iteration_step=self.step,
4292
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_allocate_duration_us",
4293
+ duration_ms=dram_bwd_l1_cnflct_miss_write_allocate_duration,
4294
+ enable_tb_metrics=True,
4295
+ time_unit="us",
4296
+ )
4297
+ stats_reporter.report_duration(
4298
+ iteration_step=self.step,
4299
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_cache_copy_duration_us",
4300
+ duration_ms=dram_bwd_l1_cnflct_miss_write_cache_copy_duration,
4301
+ enable_tb_metrics=True,
4302
+ time_unit="us",
4303
+ )
4304
+ stats_reporter.report_duration(
4305
+ iteration_step=self.step,
4306
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_lookup_cache_duration_us",
4307
+ duration_ms=dram_bwd_l1_cnflct_miss_write_lookup_cache_duration,
4308
+ enable_tb_metrics=True,
4309
+ time_unit="us",
4310
+ )
4311
+ stats_reporter.report_duration(
4312
+ iteration_step=self.step,
4313
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_acquire_lock_duration_us",
4314
+ duration_ms=dram_bwd_l1_cnflct_miss_write_acquire_lock_duration,
4315
+ enable_tb_metrics=True,
4316
+ time_unit="us",
4317
+ )
4318
+ stats_reporter.report_data_amount(
4319
+ iteration_step=self.step,
4320
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load",
4321
+ data_bytes=dram_bwd_l1_cnflct_miss_write_missing_load,
4322
+ enable_tb_metrics=True,
4323
+ )
4324
+
4325
+ stats_reporter.report_data_amount(
4326
+ iteration_step=self.step,
4327
+ event_name="dram_kv.perf.get.dram_kv_read_counts",
4328
+ data_bytes=dram_kv_read_counts,
4329
+ enable_tb_metrics=True,
4330
+ )
4331
+
4332
+ stats_reporter.report_data_amount(
4333
+ iteration_step=self.step,
4334
+ event_name=self.dram_kv_allocated_bytes_stats_name,
4335
+ data_bytes=dram_kv_allocated_bytes,
4336
+ enable_tb_metrics=True,
4337
+ )
4338
+ stats_reporter.report_data_amount(
4339
+ iteration_step=self.step,
4340
+ event_name=self.dram_kv_actual_used_chunk_bytes_stats_name,
4341
+ data_bytes=dram_kv_actual_used_chunk_bytes,
4342
+ enable_tb_metrics=True,
4343
+ )
4344
+ stats_reporter.report_data_amount(
4345
+ iteration_step=self.step,
4346
+ event_name=self.dram_kv_mem_num_rows_stats_name,
4347
+ data_bytes=dram_kv_num_rows,
4348
+ enable_tb_metrics=True,
4349
+ )
4350
+ stats_reporter.report_duration(
4351
+ iteration_step=self.step,
4352
+ event_name="dram_kv.perf.set.dram_eviction_score_write_sharding_total_duration_us",
4353
+ duration_ms=dram_metadata_write_sharding_total_duration,
4354
+ enable_tb_metrics=True,
4355
+ time_unit="us",
4356
+ )
4357
+ stats_reporter.report_duration(
4358
+ iteration_step=self.step,
4359
+ event_name="dram_kv.perf.set.dram_eviction_score_write_total_duration_us",
4360
+ duration_ms=dram_metadata_write_total_duration,
4361
+ enable_tb_metrics=True,
4362
+ time_unit="us",
4363
+ )
4364
+ stats_reporter.report_duration(
4365
+ iteration_step=self.step,
4366
+ event_name="dram_kv.perf.set.dram_eviction_score_write_allocate_avg_duration_us",
4367
+ duration_ms=dram_metadata_write_allocate_avg_duration,
4368
+ enable_tb_metrics=True,
4369
+ time_unit="us",
4370
+ )
4371
+ stats_reporter.report_duration(
4372
+ iteration_step=self.step,
4373
+ event_name="dram_kv.perf.set.dram_eviction_score_write_lookup_cache_avg_duration_us",
4374
+ duration_ms=dram_metadata_write_lookup_cache_avg_duration,
4375
+ enable_tb_metrics=True,
4376
+ time_unit="us",
4377
+ )
4378
+ stats_reporter.report_duration(
4379
+ iteration_step=self.step,
4380
+ event_name="dram_kv.perf.set.dram_eviction_score_write_acquire_lock_avg_duration_us",
4381
+ duration_ms=dram_metadata_write_acquire_lock_avg_duration,
4382
+ enable_tb_metrics=True,
4383
+ time_unit="us",
4384
+ )
4385
+ stats_reporter.report_data_amount(
4386
+ iteration_step=self.step,
4387
+ event_name="dram_kv.perf.set.dram_eviction_score_write_cache_miss_avg_count",
4388
+ data_bytes=dram_metadata_write_cache_miss_avg_count,
4389
+ enable_tb_metrics=True,
4390
+ )
4391
+ stats_reporter.report_duration(
4392
+ iteration_step=self.step,
4393
+ event_name="dram_kv.perf.get.dram_eviction_score_read_total_duration_us",
4394
+ duration_ms=dram_read_metadata_total_duration,
4395
+ enable_tb_metrics=True,
4396
+ time_unit="us",
4397
+ )
4398
+ stats_reporter.report_duration(
4399
+ iteration_step=self.step,
4400
+ event_name="dram_kv.perf.get.dram_eviction_score_read_sharding_total_duration_us",
4401
+ duration_ms=dram_read_metadata_sharding_total_duration,
4402
+ enable_tb_metrics=True,
4403
+ time_unit="us",
4404
+ )
4405
+ stats_reporter.report_duration(
4406
+ iteration_step=self.step,
4407
+ event_name="dram_kv.perf.get.dram_eviction_score_read_cache_hit_copy_avg_duration_us",
4408
+ duration_ms=dram_read_metadata_cache_hit_copy_avg_duration,
4409
+ enable_tb_metrics=True,
4410
+ time_unit="us",
4411
+ )
4412
+ stats_reporter.report_duration(
4413
+ iteration_step=self.step,
4414
+ event_name="dram_kv.perf.get.dram_eviction_score_read_lookup_cache_total_avg_duration_us",
4415
+ duration_ms=dram_read_metadata_lookup_cache_total_avg_duration,
4416
+ enable_tb_metrics=True,
4417
+ time_unit="us",
4418
+ )
4419
+ stats_reporter.report_duration(
4420
+ iteration_step=self.step,
4421
+ event_name="dram_kv.perf.get.dram_eviction_score_read_acquire_lock_avg_duration_us",
4422
+ duration_ms=dram_read_metadata_acquire_lock_avg_duration,
4423
+ enable_tb_metrics=True,
4424
+ time_unit="us",
4425
+ )
4426
+ stats_reporter.report_data_amount(
4427
+ iteration_step=self.step,
4428
+ event_name="dram_kv.perf.get.dram_eviction_score_read_load_size",
4429
+ data_bytes=dram_read_read_metadata_load_size,
4430
+ enable_tb_metrics=True,
4431
+ )
4432
+
2179
4433
  def _recording_to_timer(
2180
4434
  self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
2181
4435
  ) -> Any:
@@ -2191,3 +4445,484 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
2191
4445
  return timer.recording(**kwargs)
2192
4446
  # No-Op context manager
2193
4447
  return contextlib.nullcontext()
4448
+
4449
+ def fetch_from_l1_sp_w_row_ids(
4450
+ self, row_ids: torch.Tensor, only_get_optimizer_states: bool = False
4451
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
4452
+ """
4453
+ Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
4454
+ @return: updated_weights/optimizer_states, mask of which rows are filled
4455
+ """
4456
+ if not self.enable_optimizer_offloading and only_get_optimizer_states:
4457
+ raise RuntimeError(
4458
+ "Optimizer states are not offloaded, while only_get_optimizer_states is True"
4459
+ )
4460
+
4461
+ # NOTE: Remove this once there is support for fetching multiple
4462
+ # optimizer states in fetch_from_l1_sp_w_row_ids
4463
+ if only_get_optimizer_states and self.optimizer not in [
4464
+ OptimType.EXACT_ROWWISE_ADAGRAD,
4465
+ OptimType.PARTIAL_ROWWISE_ADAM,
4466
+ ]:
4467
+ raise RuntimeError(
4468
+ f"Fetching optimizer states using fetch_from_l1_sp_w_row_ids() is not yet supported for {self.optimizer}"
4469
+ )
4470
+
4471
+ def split_results_by_opt_states(
4472
+ updated_weights: torch.Tensor, cache_location_mask: torch.Tensor
4473
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
4474
+ if not only_get_optimizer_states:
4475
+ return [updated_weights], cache_location_mask
4476
+ # TODO: support mixed dimension case
4477
+ # currently only supports tables with the same max_D dimension
4478
+ opt_to_dim = self.optimizer.byte_offsets_along_row(
4479
+ self.max_D, self.weights_precision, self.optimizer_state_dtypes
4480
+ )
4481
+ updated_opt_states = []
4482
+ for opt_name, dim in opt_to_dim.items():
4483
+ opt_dtype = self.optimizer._extract_dtype(
4484
+ self.optimizer_state_dtypes, opt_name
4485
+ )
4486
+ updated_opt_states.append(
4487
+ updated_weights.view(dtype=torch.uint8)[:, dim[0] : dim[1]].view(
4488
+ dtype=opt_dtype
4489
+ )
4490
+ )
4491
+ return updated_opt_states, cache_location_mask
4492
+
4493
+ with torch.no_grad():
4494
+ weights_dtype = self.weights_precision.as_dtype()
4495
+ step = self.step
4496
+ with record_function(f"## fetch_from_l1_{step}_{self.tbe_unique_id} ##"):
4497
+ lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
4498
+ row_ids,
4499
+ self.lxu_cache_state,
4500
+ self.total_hash_size,
4501
+ )
4502
+ updated_weights = torch.empty(
4503
+ row_ids.numel(),
4504
+ self.cache_row_dim,
4505
+ device=self.current_device,
4506
+ dtype=weights_dtype,
4507
+ )
4508
+
4509
+ # D2D copy cache
4510
+ cache_location_mask = lxu_cache_locations >= 0
4511
+ torch.ops.fbgemm.masked_index_select(
4512
+ updated_weights,
4513
+ lxu_cache_locations,
4514
+ self.lxu_cache_weights,
4515
+ torch.tensor(
4516
+ [row_ids.numel()],
4517
+ device=self.current_device,
4518
+ dtype=torch.int32,
4519
+ ),
4520
+ )
4521
+
4522
+ with record_function(f"## fetch_from_sp_{step}_{self.tbe_unique_id} ##"):
4523
+ if len(self.ssd_scratch_pad_eviction_data) > 0:
4524
+ sp = self.ssd_scratch_pad_eviction_data[0][0]
4525
+ sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
4526
+ self.current_device
4527
+ )
4528
+ actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
4529
+ if actions_count_gpu.item() == 0:
4530
+ # no action to take
4531
+ return split_results_by_opt_states(
4532
+ updated_weights, cache_location_mask
4533
+ )
4534
+
4535
+ sp_idx = sp_idx[:actions_count_gpu]
4536
+
4537
+ # -1 in lxu_cache_locations means the row is not in the L1 cache and may instead be in the SP
4538
+ # fill the row_ids that are already in L1 with -2 so they cannot match; the remaining ids are then looked up in the SP
4539
+ # @eg. updated_row_ids_in_sp= [1, 100, 1, 2, -2, 3, 4, 5, 10]
4540
+ updated_row_ids_in_sp = row_ids.masked_fill(
4541
+ lxu_cache_locations != -1, -2
4542
+ )
4543
+ # sort the sp_idx for binary search
4544
+ # should already be sorted
4545
+ # sp_idx_inverse_indices are the pre-sort indices, which are the same as the locations in the SP.
4546
+ # @eg. sp_idx = [4, 2, 1, 3, 10]
4547
+ # @eg sorted_sp_idx = [ 1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
4548
+ sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
4549
+ # search the row ids against the SP indices to find the locations of the rows in the SP
4550
+ # @eg: updated_ids_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
4551
+ # @eg: 5 is OOB
4552
+ updated_ids_in_sp_idx = torch.searchsorted(
4553
+ sorted_sp_idx, updated_row_ids_in_sp
4554
+ )
4555
+ # ids not found in the SP end up out of bounds (OOB)
4556
+ oob_sp_idx = updated_ids_in_sp_idx >= sp_idx.numel()
4557
+ # clamp the OOB items back in bounds
4558
+ # @eg updated_ids_in_sp_idx=[0, 0, 0, 1, 0, 2, 3, 4, 4]
4559
+ updated_ids_in_sp_idx[oob_sp_idx] = 0
4560
+
4561
+ # -1 locations are filtered out by masked_index_select
4562
+ sp_locations_in_updated_weights = torch.full_like(
4563
+ updated_row_ids_in_sp, -1
4564
+ )
4565
+ # torch.searchsorted does not require an exact match,
4566
+ # so we only keep exactly matched rows, i.e. those whose id is found in the SP.
4567
+ # @eg 5 in updated_row_ids_in_sp is not in sp_idx, but has 4 in updated_ids_in_sp_idx
4568
+ # @eg sorted_sp_idx[updated_ids_in_sp_idx]=[ 1, 1, 1, 2, 1, 3, 4, 10, 10]
4569
+ # @eg exact_match_mask=[ True, False, True, True, False, True, True, False, True]
4570
+ exact_match_mask = (
4571
+ sorted_sp_idx[updated_ids_in_sp_idx] == updated_row_ids_in_sp
4572
+ )
4573
+ # Get the location of the row ids found in SP.
4574
+ # @eg: sp_locations_found=[2, 2, 1, 3, 0, 4]
4575
+ sp_locations_found = sp_idx_inverse_indices[
4576
+ updated_ids_in_sp_idx[exact_match_mask]
4577
+ ]
4578
+ # @eg: sp_locations_in_updated_weights=[ 2, -1, 2, 1, -1, 3, 0, -1, 4]
4579
+ sp_locations_in_updated_weights[exact_match_mask] = (
4580
+ sp_locations_found
4581
+ )
4582
+
4583
+ # D2D copy SP
4584
+ torch.ops.fbgemm.masked_index_select(
4585
+ updated_weights,
4586
+ sp_locations_in_updated_weights,
4587
+ sp,
4588
+ torch.tensor(
4589
+ [row_ids.numel()],
4590
+ device=self.current_device,
4591
+ dtype=torch.int32,
4592
+ ),
4593
+ )
4594
+ # cache_location_mask is the mask of rows in L1
4595
+ # exact_match_mask is the mask of rows in SP
4596
+ cache_location_mask = torch.logical_or(
4597
+ cache_location_mask, exact_match_mask
4598
+ )
4599
+
4600
+ return split_results_by_opt_states(updated_weights, cache_location_mask)
4601
+
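The scratch-pad lookup above leans on torch.searchsorted plus an exact-match mask, and the inline "@eg" comments carry a worked example. The snippet below reproduces exactly those example values in isolation so the masking logic can be sanity-checked on its own (plain torch only, no TBE state):

    import torch

    sp_idx = torch.tensor([4, 2, 1, 3, 10])
    row_ids_in_sp = torch.tensor([1, 100, 1, 2, -2, 3, 4, 5, 10])   # -2 marks rows already served by L1

    sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)      # [1, 2, 3, 4, 10], [2, 1, 3, 0, 4]
    pos = torch.searchsorted(sorted_sp_idx, row_ids_in_sp)          # [0, 5, 0, 1, 0, 2, 3, 4, 4]
    pos[pos >= sp_idx.numel()] = 0                                  # clamp out-of-bound positions
    exact_match_mask = sorted_sp_idx[pos] == row_ids_in_sp          # [T, F, T, T, F, T, T, F, T]

    sp_locations = torch.full_like(row_ids_in_sp, -1)               # -1 rows are skipped by masked_index_select
    sp_locations[exact_match_mask] = sp_idx_inverse_indices[pos[exact_match_mask]]
    print(sp_locations.tolist())                                    # [2, -1, 2, 1, -1, 3, 0, -1, 4]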
4602
+ def register_backward_hook_before_eviction(
4603
+ self, backward_hook: Callable[[torch.Tensor], None]
4604
+ ) -> None:
4605
+ """
4606
+ Register a backward hook on the TBE module.
4607
+ The hook is inserted first so that it runs before the SP eviction hook.
4608
+ """
4609
+ # make sure this hook is the first one to be executed
4610
+ hooks = []
4611
+ backward_hooks = self.placeholder_autograd_tensor._backward_hooks
4612
+ if backward_hooks is not None:
4613
+ for _handle_id, hook in backward_hooks.items():
4614
+ hooks.append(hook)
4615
+ backward_hooks.clear()
4616
+
4617
+ self.placeholder_autograd_tensor.register_hook(backward_hook)
4618
+ for hook in hooks:
4619
+ self.placeholder_autograd_tensor.register_hook(hook)
4620
+
4621
+ def set_local_weight_counts_for_table(
4622
+ self, table_idx: int, weight_count: int
4623
+ ) -> None:
4624
+ self.local_weight_counts[table_idx] = weight_count
4625
+
4626
+ def set_global_id_per_rank_for_table(
4627
+ self, table_idx: int, global_id: torch.Tensor
4628
+ ) -> None:
4629
+ self.global_id_per_rank[table_idx] = global_id
4630
+
4631
+ def direct_write_embedding(
4632
+ self,
4633
+ indices: torch.Tensor,
4634
+ offsets: torch.Tensor,
4635
+ weights: torch.Tensor,
4636
+ ) -> None:
4637
+ """
4638
+ Directly write the weights to L1, SP, and the backend without relying on autograd for the embedding cache.
4639
+ Please refer to design doc for more details: https://docs.google.com/document/d/1TJHKvO1m3-5tYAKZGhacXnGk7iCNAzz7wQlrFbX_LDI/edit?tab=t.0
4640
+ """
4641
+ assert (
4642
+ self._embedding_cache_mode
4643
+ ), "Must be in embedding_cache_mode to support direct_write_embedding method."
4644
+
4645
+ B_offsets = None
4646
+ max_B = -1
4647
+
4648
+ with torch.no_grad():
4649
+ # Wait for any ongoing prefetch operations to complete before starting direct_write
4650
+ current_stream = torch.cuda.current_stream()
4651
+ current_stream.wait_event(self.prefetch_complete_event)
4652
+
4653
+ # Create local step events for internal sequential execution
4654
+ weights_dtype = self.weights_precision.as_dtype()
4655
+ assert (
4656
+ weights_dtype == weights.dtype
4657
+ ), f"Expected embedding table dtype {weights_dtype} is same with input weight dtype, but got {weights.dtype}"
4658
+
4659
+ # Pad the weights to match self.cache_row_dim width if necessary
4660
+ if weights.size(1) < self.cache_row_dim:
4661
+ weights = torch.nn.functional.pad(
4662
+ weights, (0, self.cache_row_dim - weights.size(1))
4663
+ )
4664
+
4665
+ step = self.step
4666
+
4667
+ # step 0: if the prefetch pipeline is enabled, run the prefetch backward hook before writing to L1 and SP
4668
+ if self.prefetch_pipeline:
4669
+ self._update_cache_counter_and_pointers(nn.Module(), torch.empty(0))
4670
+
4671
+ # step 1: lookup and write to l1 cache
4672
+ with record_function(
4673
+ f"## direct_write_to_l1_{step}_{self.tbe_unique_id} ##"
4674
+ ):
4675
+ if self.gather_ssd_cache_stats:
4676
+ self.local_ssd_cache_stats.zero_()
4677
+
4678
+ # Linearize indices
4679
+ linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
4680
+ self.hash_size_cumsum,
4681
+ indices,
4682
+ offsets,
4683
+ B_offsets,
4684
+ max_B,
4685
+ )
4686
+
4687
+ lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
4688
+ linear_cache_indices,
4689
+ self.lxu_cache_state,
4690
+ self.total_hash_size,
4691
+ )
4692
+ cache_location_mask = lxu_cache_locations >= 0
4693
+
4694
+ # Get the cache locations for the row_ids that are already in the cache
4695
+ cache_locations = lxu_cache_locations[cache_location_mask]
4696
+
4697
+ # Get the corresponding input weights for these row_ids
4698
+ cache_weights = weights[cache_location_mask]
4699
+
4700
+ # Update the cache with these input weights
4701
+ if cache_locations.numel() > 0:
4702
+ self.lxu_cache_weights.index_put_(
4703
+ (cache_locations,), cache_weights, accumulate=False
4704
+ )
4705
+
4706
+ # Record completion of step 1
4707
+ current_stream.record_event(self.direct_write_l1_complete_event)
4708
+
4709
+ # step 2: pop the current scratch pad and write to the next batch's scratch pad if it exists
4710
+ # Wait for step 1 to complete
4711
+ with record_function(
4712
+ f"## direct_write_to_sp_{step}_{self.tbe_unique_id} ##"
4713
+ ):
4714
+ if len(self.ssd_scratch_pad_eviction_data) > 0:
4715
+ self.ssd_scratch_pad_eviction_data.pop(0)
4716
+ if len(self.ssd_scratch_pad_eviction_data) > 0:
4717
+ # Wait for any pending backend reads to the next scratch pad
4718
+ # to complete before we write to it. Otherwise, stale backend data
4719
+ # will overwrite our direct_write updates.
4720
+ # The ssd_event_get marks completion of backend fetch operations.
4721
+ current_stream.wait_event(self.ssd_event_get)
4722
+
4723
+ # if scratch pad exists, write to next batch scratch pad
4724
+ sp = self.ssd_scratch_pad_eviction_data[0][0]
4725
+ sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
4726
+ self.current_device
4727
+ )
4728
+ actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
4729
+ if actions_count_gpu.item() != 0:
4730
+ # when actions_count_gpu is zero, there is nothing to write to the SP
4731
+ sp_idx = sp_idx[:actions_count_gpu]
4732
+
4733
+ # -1 in lxu_cache_locations means the row is not in the L1 cache and may instead be in the SP or backend
4734
+ # fill the row_ids that are already in L1 with -2 so they cannot match; the remaining ids are then looked up in the SP or backend
4735
+ # @eg. updated_indices_in_sp= [1, 100, 1, 2, -2, 3, 4, 5, 10]
4736
+ updated_indices_in_sp = linear_cache_indices.masked_fill(
4737
+ lxu_cache_locations != -1, -2
4738
+ )
4739
+ # sort the sp_idx for binary search
4740
+ # should already be sorted
4741
+ # sp_idx_inverse_indices are the pre-sort indices, which are the same as the locations in the SP.
4742
+ # @eg. sp_idx = [4, 2, 1, 3, 10]
4743
+ # @eg sorted_sp_idx = [ 1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
4744
+ sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
4745
+ # search the row ids against the SP indices to find the locations of the rows in the SP
4746
+ # @eg: updated_indices_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
4747
+ # @eg: 5 is OOB
4748
+ updated_indices_in_sp_idx = torch.searchsorted(
4749
+ sorted_sp_idx, updated_indices_in_sp
4750
+ )
4751
+ # ids not found in the SP end up out of bounds (OOB)
4752
+ oob_sp_idx = updated_indices_in_sp_idx >= sp_idx.numel()
4753
+ # clamp the OOB items back in bounds
4754
+ # @eg updated_indices_in_sp_idx=[0, 0, 0, 1, 0, 2, 3, 4, 4]
4755
+ updated_indices_in_sp_idx[oob_sp_idx] = 0
4756
+
4757
+ # torch.searchsorted does not require an exact match,
4758
+ # so we only keep exactly matched rows, i.e. those whose id is found in the SP.
4759
+ # @eg 5 in updated_indices_in_sp is not in sp_idx, but has 4 in updated_indices_in_sp_idx
4760
+ # @eg sorted_sp_idx[updated_indices_in_sp_idx]=[ 1, 1, 1, 2, 1, 3, 4, 10, 10]
4761
+ # @eg exact_match_mask=[ True, False, True, True, False, True, True, False, True]
4762
+ exact_match_mask = (
4763
+ sorted_sp_idx[updated_indices_in_sp_idx]
4764
+ == updated_indices_in_sp
4765
+ )
4766
+ # Get the location of the row ids found in SP.
4767
+ # @eg: sp_locations_found=[2, 2, 1, 3, 0, 4]
4768
+ sp_locations_found = sp_idx_inverse_indices[
4769
+ updated_indices_in_sp_idx[exact_match_mask]
4770
+ ]
4771
+ # Get the corresponding weights for the matched indices
4772
+ matched_weights = weights[exact_match_mask]
4773
+
4774
+ # Write the weights to the scratch pad at the found locations
4775
+ if sp_locations_found.numel() > 0:
4776
+ sp.index_put_(
4777
+ (sp_locations_found,),
4778
+ matched_weights,
4779
+ accumulate=False,
4780
+ )
4781
+ current_stream.record_event(self.direct_write_sp_complete_event)
4782
+
4783
+ # step 3: write l1 cache missing rows to backend
4784
+ # Wait for step 2 to complete
4785
+ with record_function(
4786
+ f"## direct_write_to_backend_{step}_{self.tbe_unique_id} ##"
4787
+ ):
4788
+ # Use the existing ssd_eviction_stream for all backend write operations
4789
+ # This stream is already created with low priority during initialization
4790
+ with torch.cuda.stream(self.ssd_eviction_stream):
4791
+ # Create a mask for indices not in L1 cache
4792
+ non_cache_mask = ~cache_location_mask
4793
+
4794
+ # Calculate the count of valid indices (those not in L1 cache)
4795
+ valid_count = non_cache_mask.sum().to(torch.int64).cpu()
4796
+
4797
+ if valid_count.item() > 0:
4798
+ # Extract only the indices and weights that are not in L1 cache
4799
+ non_cache_indices = linear_cache_indices[non_cache_mask]
4800
+ non_cache_weights = weights[non_cache_mask]
4801
+
4802
+ # Move tensors to CPU for set_cuda
4803
+ cpu_indices = non_cache_indices.cpu()
4804
+ cpu_weights = non_cache_weights.cpu()
4805
+
4806
+ # Write to backend - only sending the non-cache indices and weights
4807
+ self.record_function_via_dummy_profile(
4808
+ f"## ssd_write_{step}_set_cuda_{self.tbe_unique_id} ##",
4809
+ self.ssd_db.set_cuda,
4810
+ cpu_indices,
4811
+ cpu_weights,
4812
+ valid_count,
4813
+ self.timestep,
4814
+ is_bwd=False,
4815
+ )
4816
+
4817
+ # Return control to the main stream without waiting for the backend operation to complete
4818
+
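Step 1 above splits incoming rows by whether lxu_cache_lookup found them in L1: hits are scattered in place with index_put_, misses continue to the scratch-pad and backend paths. A toy, self-contained illustration of that split (all tensor values here are made up, not taken from the module):

    import torch

    lxu_cache_locations = torch.tensor([3, -1, 0, -1])       # assumed lookup result; -1 = L1 miss
    weights = torch.arange(8, dtype=torch.float32).view(4, 2)
    l1_cache = torch.zeros(5, 2)                              # stand-in for lxu_cache_weights

    hit_mask = lxu_cache_locations >= 0
    # L1 hits are written directly into their cache slots.
    l1_cache.index_put_((lxu_cache_locations[hit_mask],), weights[hit_mask], accumulate=False)
    # L1 misses would go on to the scratch-pad / backend writes (steps 2 and 3 above).
    miss_weights = weights[~hit_mask]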
4819
+ def get_free_cpu_memory_gb(self) -> float:
4820
+ def _get_mem_available() -> float:
4821
+ if sys.platform.startswith("linux"):
4822
+ info = {}
4823
+ with open("/proc/meminfo") as f:
4824
+ for line in f:
4825
+ p = line.split()
4826
+ info[p[0].strip(":").lower()] = int(p[1]) * 1024
4827
+ if "memavailable" in info:
4828
+ # Linux >= 3.14
4829
+ return info["memavailable"]
4830
+ else:
4831
+ return info["memfree"] + info["cached"]
4832
+ else:
4833
+ raise RuntimeError(
4834
+ "Unsupported platform for free memory eviction, pls use ID count eviction tirgger mode"
4835
+ )
4836
+
4837
+ mem = _get_mem_available()
4838
+ return mem / (1024**3)
4839
+
4840
+ @classmethod
4841
+ def trigger_evict_in_all_tbes(cls) -> None:
4842
+ for tbe in cls._all_tbe_instances:
4843
+ tbe.ssd_db.trigger_feature_evict()
4844
+
4845
+ @classmethod
4846
+ def tbe_has_ongoing_eviction(cls) -> bool:
4847
+ for tbe in cls._all_tbe_instances:
4848
+ if tbe.ssd_db.is_evicting():
4849
+ return True
4850
+ return False
4851
+
4852
+ def set_free_mem_eviction_trigger_config(
4853
+ self, eviction_policy: EvictionPolicy
4854
+ ) -> None:
4855
+ self.enable_free_mem_trigger_eviction = True
4856
+ self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode
4857
+ assert (
4858
+ eviction_policy.eviction_free_mem_check_interval_batch is not None
4859
+ ), "eviction_free_mem_check_interval_batch is unexpected none for free_mem eviction trigger mode"
4860
+ self.eviction_free_mem_check_interval_batch: int = (
4861
+ eviction_policy.eviction_free_mem_check_interval_batch
4862
+ )
4863
+ assert (
4864
+ eviction_policy.eviction_free_mem_threshold_gb is not None
4865
+ ), "eviction_policy.eviction_free_mem_threshold_gb is unexpected none for free_mem eviction trigger mode"
4866
+ self.eviction_free_mem_threshold_gb: int = (
4867
+ eviction_policy.eviction_free_mem_threshold_gb
4868
+ )
4869
+ logging.info(
4870
+ f"[FREE_MEM Eviction] eviction config, trigger model: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}"
4871
+ )
4872
+
4873
+ def may_trigger_eviction(self) -> None:
4874
+ def is_first_tbe() -> bool:
4875
+ first = SSDTableBatchedEmbeddingBags._first_instance_ref
4876
+ return first is not None and first() is self
4877
+
4878
+ # We assume that the eviction time is shorter than the free-mem check interval,
4879
+ # so every time we reach this check, all evictions in all TBEs should have finished.
4880
+ # We only need to check the first TBE because all TBEs share the same free memory;
4881
+ # once the first TBE detects that eviction is needed, it calls the trigger function
4882
+ # on all TBEs in _all_tbe_instances
4883
+ if (
4884
+ self.enable_free_mem_trigger_eviction
4885
+ and self.step % self.eviction_free_mem_check_interval_batch == 0
4886
+ and self.training
4887
+ and is_first_tbe()
4888
+ ):
4889
+ if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction():
4890
+ SSDTableBatchedEmbeddingBags._eviction_triggered = False
4891
+
4892
+ free_cpu_mem_gb = self.get_free_cpu_memory_gb()
4893
+ local_evict_trigger = int(
4894
+ free_cpu_mem_gb < self.eviction_free_mem_threshold_gb
4895
+ )
4896
+ tensor_flag = torch.tensor(
4897
+ local_evict_trigger,
4898
+ device=self.current_device,
4899
+ dtype=torch.int,
4900
+ )
4901
+ world_size = dist.get_world_size(self._pg)
4902
+ if world_size > 1:
4903
+ dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg)
4904
+ global_evict_trigger = tensor_flag.item()
4905
+ else:
4906
+ global_evict_trigger = local_evict_trigger
4907
+ if (
4908
+ global_evict_trigger >= 1
4909
+ and SSDTableBatchedEmbeddingBags._eviction_triggered
4910
+ ):
4911
+ logging.warning(
4912
+ f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true"
4913
+ )
4914
+ if (
4915
+ global_evict_trigger >= 1
4916
+ and not SSDTableBatchedEmbeddingBags._eviction_triggered
4917
+ ):
4918
+ SSDTableBatchedEmbeddingBags._eviction_triggered = True
4919
+ SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes()
4920
+ logging.info(
4921
+ f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction"
4922
+ )
4923
+
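may_trigger_eviction reaches cross-rank agreement by summing a per-rank low-memory vote with an all_reduce: if any rank crossed its threshold, the sum is nonzero on every rank, so all TBEs evict together. A minimal sketch of just that agreement step (a generic helper, not this module's API; assumes the default process group is already initialized whenever more than one rank is involved):

    import torch
    import torch.distributed as dist

    def any_rank_wants_eviction(local_low_memory: bool, device: torch.device, pg=None) -> bool:
        # Each rank votes 0 or 1; a SUM all_reduce >= 1 means at least one rank
        # crossed its free-memory threshold, so every rank triggers eviction.
        flag = torch.tensor(int(local_low_memory), device=device, dtype=torch.int)
        if dist.is_initialized() and dist.get_world_size(pg) > 1:
            dist.all_reduce(flag, op=dist.ReduceOp.SUM, group=pg)
        return bool(flag.item() >= 1)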
4924
+ def reset_inference_mode(self) -> None:
4925
+ """
4926
+ Reset the inference mode
4927
+ """
4928
+ self.eval()