fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
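
The wheel installs the fbgemm_gpu package (plus the prebuilt asmjit.so, fbgemm.so, and experimental gen_ai shared objects listed above). A minimal import sketch, assuming the wheel and a CUDA build of PyTorch are installed; the SSD TBE import path is inferred from the file list and may differ:

import torch
import fbgemm_gpu  # loading the package registers the bundled torch.ops.fbgemm operators
from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags  # module added in the diff below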
fbgemm_gpu/tbe/ssd/training.py
@@ -0,0 +1,4908 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+ # pyre-ignore-all-errors[13,56]
10
+
11
+ import contextlib
12
+ import functools
13
+ import itertools
14
+ import logging
15
+ import math
16
+ import os
17
+ import threading
18
+ import time
19
+ from functools import cached_property
20
+ from math import floor, log2
21
+ from typing import Any, Callable, ClassVar, Optional, Union
22
+ import torch # usort:skip
23
+ import weakref
24
+
25
+ # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
26
+ import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
27
+ from fbgemm_gpu.runtime_monitor import (
28
+ AsyncSeriesTimer,
29
+ TBEStatsReporter,
30
+ TBEStatsReporterConfig,
31
+ )
32
+ from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
33
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
34
+ BackendType,
35
+ BoundsCheckMode,
36
+ CacheAlgorithm,
37
+ EmbeddingLocation,
38
+ EvictionPolicy,
39
+ get_bounds_check_version_for_platform,
40
+ KVZCHParams,
41
+ PoolingMode,
42
+ SplitState,
43
+ )
44
+ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
45
+ apply_split_helper,
46
+ CounterBasedRegularizationDefinition,
47
+ CowClipDefinition,
48
+ RESParams,
49
+ UVMCacheStatsIndex,
50
+ WeightDecayMode,
51
+ )
52
+ from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
53
+ generate_vbe_metadata,
54
+ is_torchdynamo_compiling,
55
+ )
56
+ from torch import distributed as dist, nn, Tensor # usort:skip
57
+ import sys
58
+ from dataclasses import dataclass
59
+
60
+ from torch.autograd.profiler import record_function
61
+
62
+ from ..cache import get_unique_indices_v2
63
+ from .common import ASSOC, pad4, tensor_pad4
64
+ from .utils.partially_materialized_tensor import PartiallyMaterializedTensor
65
+
66
+
67
+ @dataclass
68
+ class IterData:
69
+ indices: Tensor
70
+ offsets: Tensor
71
+ lxu_cache_locations: Tensor
72
+ lxu_cache_ptrs: Tensor
73
+ actions_count_gpu: Tensor
74
+ cache_set_inverse_indices: Tensor
75
+ B_offsets: Optional[Tensor] = None
76
+ max_B: Optional[int] = -1
77
+
78
+
79
+ @dataclass
80
+ class KVZCHCachedData:
81
+ cached_optimizer_states_per_table: list[list[torch.Tensor]]
82
+ cached_weight_tensor_per_table: list[torch.Tensor]
83
+ cached_id_tensor_per_table: list[torch.Tensor]
84
+ cached_bucket_splits: list[torch.Tensor]
85
+
86
+
87
+ class SSDTableBatchedEmbeddingBags(nn.Module):
88
+ D_offsets: Tensor
89
+ lxu_cache_weights: Tensor
90
+ lru_state: Tensor
91
+ lxu_cache_weights: Tensor
92
+ lxu_cache_state: Tensor
93
+ momentum1_dev: Tensor
94
+ momentum1_uvm: Tensor
95
+ momentum1_host: Tensor
96
+ momentum1_placements: Tensor
97
+ momentum1_offsets: Tensor
98
+ weights_dev: Tensor
99
+ weights_uvm: Tensor
100
+ weights_host: Tensor
101
+ weights_placements: Tensor
102
+ weights_offsets: Tensor
103
+ _local_instance_index: int = -1
104
+ res_params: RESParams
105
+ table_names: list[str]
106
+ _all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet()
107
+ _first_instance_ref: ClassVar[weakref.ref] = None
108
+ _eviction_triggered: ClassVar[bool] = False
109
+
110
+ def __init__(
111
+ self,
112
+ embedding_specs: list[tuple[int, int]], # tuple of (rows, dims)
113
+ feature_table_map: Optional[list[int]], # [T]
114
+ cache_sets: int,
115
+ # A comma-separated string, e.g. "/data00_nvidia0,/data01_nvidia0/", db shards
116
+ # will be placed in these paths round-robin.
117
+ ssd_storage_directory: str,
118
+ ssd_rocksdb_shards: int = 1,
119
+ ssd_memtable_flush_period: int = -1,
120
+ ssd_memtable_flush_offset: int = -1,
121
+ ssd_l0_files_per_compact: int = 4,
122
+ ssd_rate_limit_mbps: int = 0,
123
+ ssd_size_ratio: int = 10,
124
+ ssd_compaction_trigger: int = 8,
125
+ ssd_rocksdb_write_buffer_size: int = 2 * 1024 * 1024 * 1024,
126
+ ssd_max_write_buffer_num: int = 4,
127
+ ssd_cache_location: EmbeddingLocation = EmbeddingLocation.MANAGED,
128
+ ssd_uniform_init_lower: float = -0.01,
129
+ ssd_uniform_init_upper: float = 0.01,
130
+ ssd_block_cache_size_per_tbe: int = 0,
131
+ weights_precision: SparseType = SparseType.FP32,
132
+ output_dtype: SparseType = SparseType.FP32,
133
+ optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD,
134
+ # General Optimizer args
135
+ stochastic_rounding: bool = True,
136
+ gradient_clipping: bool = False,
137
+ max_gradient: float = 1.0,
138
+ max_norm: float = 0.0,
139
+ learning_rate: float = 0.01,
140
+ eps: float = 1.0e-8, # used by Adagrad, LAMB, and Adam
141
+ momentum: float = 0.9, # used by LARS-SGD
142
+ weight_decay: float = 0.0, # used by LARS-SGD, LAMB, ADAM, and Rowwise Adagrad
143
+ weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, # used by Rowwise Adagrad
144
+ eta: float = 0.001, # used by LARS-SGD,
145
+ beta1: float = 0.9, # used by LAMB and ADAM
146
+ beta2: float = 0.999, # used by LAMB and ADAM
147
+ counter_based_regularization: Optional[
148
+ CounterBasedRegularizationDefinition
149
+ ] = None, # used by Rowwise Adagrad
150
+ cowclip_regularization: Optional[
151
+ CowClipDefinition
152
+ ] = None, # used by Rowwise Adagrad
153
+ pooling_mode: PoolingMode = PoolingMode.SUM,
154
+ bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
155
+ # Parameter Server Configs
156
+ ps_hosts: Optional[tuple[tuple[str, int]]] = None,
157
+ ps_max_key_per_request: Optional[int] = None,
158
+ ps_client_thread_num: Optional[int] = None,
159
+ ps_max_local_index_length: Optional[int] = None,
160
+ tbe_unique_id: int = -1,
161
+ # If set to True, will use `ssd_storage_directory` as the ssd paths.
162
+ # If set to False, will use the default ssd paths.
163
+ # In local testing we need to use the passed-in path for rocksdb creation.
164
+ # In production we could either use the default ssd mount points or explicitly specify ssd
165
+ # mount points using `ssd_storage_directory`.
166
+ use_passed_in_path: int = True,
167
+ gather_ssd_cache_stats: Optional[bool] = False,
168
+ stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
169
+ l2_cache_size: int = 0,
170
+ # Set to True to enable pipeline prefetching
171
+ prefetch_pipeline: bool = False,
172
+ # Set to True to alloc a UVM tensor using malloc+cudaHostRegister.
173
+ # Set to False to use cudaMallocManaged
174
+ uvm_host_mapped: bool = False,
175
+ enable_async_update: bool = True, # whether to enable async L2/rocksdb writes via a background thread
176
+ # if > 0, insert all kv pairs to rocksdb at init time, in chunks of *bulk_init_chunk_size* bytes
177
+ # number of rows will be decided by bulk_init_chunk_size / size_of_each_row
178
+ bulk_init_chunk_size: int = 0,
179
+ lazy_bulk_init_enabled: bool = False,
180
+ backend_type: BackendType = BackendType.SSD,
181
+ kv_zch_params: Optional[KVZCHParams] = None,
182
+ enable_raw_embedding_streaming: bool = False, # whether enable raw embedding streaming
183
+ res_params: Optional[RESParams] = None, # raw embedding streaming sharding info
184
+ flushing_block_size: int = 2_000_000_000, # 2GB
185
+ table_names: Optional[list[str]] = None,
186
+ use_rowwise_bias_correction: bool = False, # For Adam use
187
+ optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006
188
+ pg: Optional[dist.ProcessGroup] = None,
189
+ ) -> None:
190
+ super(SSDTableBatchedEmbeddingBags, self).__init__()
191
+
192
+ # Set the optimizer
193
+ assert optimizer in (
194
+ OptimType.EXACT_ROWWISE_ADAGRAD,
195
+ OptimType.PARTIAL_ROWWISE_ADAM,
196
+ OptimType.ADAM,
197
+ ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
198
+ self.optimizer = optimizer
199
+
200
+ # Set the table weight and output dtypes
201
+ assert weights_precision in (SparseType.FP32, SparseType.FP16)
202
+ self.weights_precision = weights_precision
203
+ self.output_dtype: int = output_dtype.as_int()
204
+
205
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
206
+ # Adagrad currently only supports FP32 for momentum1
207
+ self.optimizer_state_dtypes: dict[str, SparseType] = {
208
+ "momentum1": SparseType.FP32,
209
+ }
210
+ else:
211
+ self.optimizer_state_dtypes: dict[str, SparseType] = optimizer_state_dtypes
212
+
213
+ # Zero collision TBE configurations
214
+ self.kv_zch_params = kv_zch_params
215
+ self.backend_type = backend_type
216
+ self.enable_optimizer_offloading: bool = False
217
+ self.backend_return_whole_row: bool = False
218
+ self._embedding_cache_mode: bool = False
219
+ self.load_ckpt_without_opt: bool = False
220
+ if self.kv_zch_params:
221
+ self.kv_zch_params.validate()
222
+ self.load_ckpt_without_opt = (
223
+ # pyre-ignore [16]
224
+ self.kv_zch_params.load_ckpt_without_opt
225
+ )
226
+ self.enable_optimizer_offloading = (
227
+ # pyre-ignore [16]
228
+ self.kv_zch_params.enable_optimizer_offloading
229
+ )
230
+ self.backend_return_whole_row = (
231
+ # pyre-ignore [16]
232
+ self.kv_zch_params.backend_return_whole_row
233
+ )
234
+
235
+ if self.enable_optimizer_offloading:
236
+ logging.info("Optimizer state offloading is enabled")
237
+ if self.backend_return_whole_row:
238
+ assert (
239
+ self.backend_type == BackendType.DRAM
240
+ ), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
241
+ logging.info(
242
+ "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
243
+ )
244
+ # pyre-ignore [16]
245
+ self._embedding_cache_mode = self.kv_zch_params.embedding_cache_mode
246
+ if self._embedding_cache_mode:
247
+ logging.info("KVZCH is in embedding_cache_mode")
248
+ assert self.optimizer in [
249
+ OptimType.EXACT_ROWWISE_ADAGRAD
250
+ ], f"only EXACT_ROWWISE_ADAGRAD supports embedding cache mode, but got {self.optimizer}"
251
+ if self.load_ckpt_without_opt:
252
+ if (
253
+ # pyre-ignore [16]
254
+ self.kv_zch_params.optimizer_type_for_st
255
+ == OptimType.PARTIAL_ROWWISE_ADAM.value
256
+ ):
257
+ self.optimizer = OptimType.PARTIAL_ROWWISE_ADAM
258
+ logging.info(
259
+ f"Override optimizer type with {self.optimizer=} for st publish"
260
+ )
261
+ if (
262
+ # pyre-ignore [16]
263
+ self.kv_zch_params.optimizer_state_dtypes_for_st
264
+ is not None
265
+ ):
266
+ optimizer_state_dtypes = {}
267
+ for k, v in dict(
268
+ self.kv_zch_params.optimizer_state_dtypes_for_st
269
+ ).items():
270
+ optimizer_state_dtypes[k] = SparseType.from_int(v)
271
+ self.optimizer_state_dtypes = optimizer_state_dtypes
272
+ logging.info(
273
+ f"Override optimizer_state_dtypes with {self.optimizer_state_dtypes=} for st publish"
274
+ )
275
+
276
+ self.pooling_mode = pooling_mode
277
+ self.bounds_check_mode_int: int = bounds_check_mode.value
278
+ self.embedding_specs = embedding_specs
279
+ self.table_names = table_names if table_names is not None else []
280
+ (rows, dims) = zip(*embedding_specs)
281
+ T_ = len(self.embedding_specs)
282
+ assert T_ > 0
283
+ # pyre-fixme[8]: Attribute has type `device`; used as `int`.
284
+ self.current_device: torch.device = torch.cuda.current_device()
285
+
286
+ self.enable_raw_embedding_streaming = enable_raw_embedding_streaming
287
+ # initialize the raw embedding streaming related variables
288
+ self.res_params: RESParams = res_params or RESParams()
289
+ if self.enable_raw_embedding_streaming:
290
+ self.res_params.table_sizes = [0] + list(itertools.accumulate(rows))
291
+ res_port_from_env = os.getenv("LOCAL_RES_PORT")
292
+ self.res_params.res_server_port = (
293
+ int(res_port_from_env) if res_port_from_env else 0
294
+ )
295
+ logging.info(
296
+ f"get env {self.res_params.res_server_port=}, at rank {dist.get_rank()}, with {self.res_params=}"
297
+ )
298
+
299
+ self.feature_table_map: list[int] = (
300
+ feature_table_map if feature_table_map is not None else list(range(T_))
301
+ )
302
+ T = len(self.feature_table_map)
303
+ assert T_ <= T
304
+ table_has_feature = [False] * T_
305
+ for t in self.feature_table_map:
306
+ table_has_feature[t] = True
307
+ assert all(table_has_feature), "Each table must have at least one feature!"
308
+
309
+ feature_dims = [dims[t] for t in self.feature_table_map]
310
+ D_offsets = [dims[t] for t in self.feature_table_map]
311
+ D_offsets = [0] + list(itertools.accumulate(D_offsets))
312
+
313
+ # Sum of row length of all tables
314
+ self.total_D: int = D_offsets[-1]
315
+
316
+ # Max number of elements required to store a row in the cache
317
+ self.max_D: int = max(dims)
318
+ self.register_buffer(
319
+ "D_offsets",
320
+ torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
321
+ )
322
+ assert self.D_offsets.numel() == T + 1
323
+ hash_size_cumsum = [0] + list(itertools.accumulate(rows))
324
+ if hash_size_cumsum[-1] == 0:
325
+ self.total_hash_size_bits: int = 0
326
+ else:
327
+ self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
328
+ self.register_buffer(
329
+ "table_hash_size_cumsum",
330
+ torch.tensor(
331
+ hash_size_cumsum, device=self.current_device, dtype=torch.int64
332
+ ),
333
+ )
334
+ # The last element is to easily access # of rows of each table by
335
+ self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
336
+ self.total_hash_size: int = hash_size_cumsum[-1]
337
+ # The last element is to easily access # of rows of each table by
338
+ # hash_size_cumsum[t + 1] - hash_size_cumsum[t]
339
+ hash_size_cumsum = [hash_size_cumsum[t] for t in self.feature_table_map] + [
340
+ hash_size_cumsum[-1]
341
+ ]
342
+ self.register_buffer(
343
+ "hash_size_cumsum",
344
+ torch.tensor(
345
+ hash_size_cumsum, device=self.current_device, dtype=torch.int64
346
+ ),
347
+ )
348
+
349
+ self.uvm_host_mapped = uvm_host_mapped
350
+ logging.info(
351
+ f"TBE will allocate a UVM buffer with is_host_mapped={uvm_host_mapped}"
352
+ )
353
+ self.bulk_init_chunk_size = bulk_init_chunk_size
354
+ self.lazy_init_thread: threading.Thread | None = None
355
+
356
+ # Buffers for bounds check
357
+ self.register_buffer(
358
+ "rows_per_table",
359
+ torch.tensor(
360
+ [rows[t] for t in self.feature_table_map],
361
+ device=self.current_device,
362
+ dtype=torch.int64,
363
+ ),
364
+ )
365
+ self.register_buffer(
366
+ "bounds_check_warning",
367
+ torch.tensor([0], device=self.current_device, dtype=torch.int64),
368
+ )
369
+ # Required for VBE
370
+ self.register_buffer(
371
+ "feature_dims",
372
+ torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
373
+ )
374
+ self.register_buffer(
375
+ "table_dims",
376
+ torch.tensor(dims, device="cpu", dtype=torch.int64),
377
+ )
378
+
379
+ (info_B_num_bits_, info_B_mask_) = torch.ops.fbgemm.get_infos_metadata(
380
+ self.D_offsets, # unused tensor
381
+ 1, # max_B
382
+ T, # T
383
+ )
384
+ self.info_B_num_bits: int = info_B_num_bits_
385
+ self.info_B_mask: int = info_B_mask_
386
+
387
+ assert cache_sets > 0
388
+ element_size = weights_precision.bit_rate() // 8
389
+ assert (
390
+ element_size == 4 or element_size == 2
391
+ ), f"Invalid element size {element_size}"
392
+ cache_size = cache_sets * ASSOC * element_size * self.cache_row_dim
393
+ logging.info(
394
+ f"Using cache for SSD with admission algorithm "
395
+ f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_rocksdb_shards} shards, "
396
+ f"SSD storage directory: {ssd_storage_directory}, "
397
+ f"Memtable Flush Period: {ssd_memtable_flush_period}, "
398
+ f"Memtable Flush Offset: {ssd_memtable_flush_offset}, "
399
+ f"Desired L0 files per compaction: {ssd_l0_files_per_compact}, "
400
+ f"Cache size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
401
+ f"weights precision: {weights_precision}, "
402
+ f"output dtype: {output_dtype}, "
403
+ f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
404
+ f"kv_zch_params: {kv_zch_params}, "
405
+ f"embedding spec: {embedding_specs}"
406
+ )
407
+ self.register_buffer(
408
+ "lxu_cache_state",
409
+ torch.zeros(
410
+ cache_sets, ASSOC, device=self.current_device, dtype=torch.int64
411
+ ).fill_(-1),
412
+ )
413
+ self.register_buffer(
414
+ "lru_state",
415
+ torch.zeros(
416
+ cache_sets, ASSOC, device=self.current_device, dtype=torch.int64
417
+ ),
418
+ )
419
+
420
+ self.step = 0
421
+ self.last_flush_step = -1
422
+
423
+ # Set prefetch pipeline
424
+ self.prefetch_pipeline: bool = prefetch_pipeline
425
+ self.prefetch_stream: Optional[torch.cuda.Stream] = None
426
+
427
+ # Cache locking counter for pipeline prefetching
428
+ if self.prefetch_pipeline:
429
+ self.register_buffer(
430
+ "lxu_cache_locking_counter",
431
+ torch.zeros(
432
+ cache_sets,
433
+ ASSOC,
434
+ device=self.current_device,
435
+ dtype=torch.int32,
436
+ ),
437
+ persistent=True,
438
+ )
439
+ else:
440
+ self.register_buffer(
441
+ "lxu_cache_locking_counter",
442
+ torch.zeros([0, 0], dtype=torch.int32, device=self.current_device),
443
+ persistent=False,
444
+ )
445
+
446
+ assert ssd_cache_location in (
447
+ EmbeddingLocation.MANAGED,
448
+ EmbeddingLocation.DEVICE,
449
+ )
450
+
451
+ cache_dtype = weights_precision.as_dtype()
452
+ if ssd_cache_location == EmbeddingLocation.MANAGED:
453
+ self.register_buffer(
454
+ "lxu_cache_weights",
455
+ torch.ops.fbgemm.new_unified_tensor(
456
+ torch.zeros(
457
+ 1,
458
+ device=self.current_device,
459
+ dtype=cache_dtype,
460
+ ),
461
+ [cache_sets * ASSOC, self.cache_row_dim],
462
+ is_host_mapped=self.uvm_host_mapped,
463
+ ),
464
+ )
465
+ else:
466
+ self.register_buffer(
467
+ "lxu_cache_weights",
468
+ torch.zeros(
469
+ cache_sets * ASSOC,
470
+ self.cache_row_dim,
471
+ device=self.current_device,
472
+ dtype=cache_dtype,
473
+ ),
474
+ )
475
+ assert (
476
+ cache_size
477
+ == self.lxu_cache_weights.numel()
478
+ * self.lxu_cache_weights.element_size()
479
+ ), "The precomputed cache_size does not match the actual cache size"
480
+
481
+ # Buffers for cache eviction
482
+ # For storing weights to evict
483
+ # The max number of rows to be evicted is limited by the number of
484
+ # slots in the cache. Thus, we allocate `lxu_cache_evicted_weights` to
485
+ # be the same shape as the L1 cache (lxu_cache_weights)
486
+ self.register_buffer(
487
+ "lxu_cache_evicted_weights",
488
+ torch.ops.fbgemm.new_unified_tensor(
489
+ torch.zeros(
490
+ 1,
491
+ device=self.current_device,
492
+ dtype=cache_dtype,
493
+ ),
494
+ self.lxu_cache_weights.shape,
495
+ is_host_mapped=self.uvm_host_mapped,
496
+ ),
497
+ )
498
+
499
+ # For storing embedding indices to evict to
500
+ self.register_buffer(
501
+ "lxu_cache_evicted_indices",
502
+ torch.ops.fbgemm.new_unified_tensor(
503
+ torch.zeros(
504
+ 1,
505
+ device=self.current_device,
506
+ dtype=torch.long,
507
+ ),
508
+ (self.lxu_cache_weights.shape[0],),
509
+ is_host_mapped=self.uvm_host_mapped,
510
+ ),
511
+ )
512
+
513
+ # For storing cache slots to evict
514
+ self.register_buffer(
515
+ "lxu_cache_evicted_slots",
516
+ torch.ops.fbgemm.new_unified_tensor(
517
+ torch.zeros(
518
+ 1,
519
+ device=self.current_device,
520
+ dtype=torch.int,
521
+ ),
522
+ (self.lxu_cache_weights.shape[0],),
523
+ is_host_mapped=self.uvm_host_mapped,
524
+ ),
525
+ )
526
+
527
+ # For storing the number of evicted rows
528
+ self.register_buffer(
529
+ "lxu_cache_evicted_count",
530
+ torch.ops.fbgemm.new_unified_tensor(
531
+ torch.zeros(
532
+ 1,
533
+ device=self.current_device,
534
+ dtype=torch.int,
535
+ ),
536
+ (1,),
537
+ is_host_mapped=self.uvm_host_mapped,
538
+ ),
539
+ )
540
+
541
+ self.timestep = 0
542
+
543
+ # Store the iteration number on GPU and CPU (used for certain optimizers)
544
+ persistent_iter_ = optimizer in (OptimType.PARTIAL_ROWWISE_ADAM,)
545
+ self.register_buffer(
546
+ "iter",
547
+ torch.zeros(1, dtype=torch.int64, device=self.current_device),
548
+ persistent=persistent_iter_,
549
+ )
550
+ self.iter_cpu: torch.Tensor = torch.zeros(1, dtype=torch.int64, device="cpu")
551
+
552
+ # Dummy profile configuration for measuring the SSD get/set time
553
+ # get and set are executed by another thread which (for some reason) is
554
+ # not traceable by PyTorch's Kineto. We work around this problem by
555
+ # injecting a dummy kernel into the GPU stream to make it traceable
556
+ #
557
+ # This function can be enabled by setting an environment variable
558
+ # FBGEMM_SSD_TBE_USE_DUMMY_PROFILE=1
559
+ self.dummy_profile_tensor: Tensor = torch.as_tensor(
560
+ [0], device=self.current_device, dtype=torch.int
561
+ )
562
+ set_dummy_profile = os.environ.get("FBGEMM_SSD_TBE_USE_DUMMY_PROFILE")
563
+ use_dummy_profile = False
564
+ if set_dummy_profile is not None:
565
+ use_dummy_profile = int(set_dummy_profile) == 1
566
+ logging.info(
567
+ f"FBGEMM_SSD_TBE_USE_DUMMY_PROFILE is set to {set_dummy_profile}; "
568
+ f"Use dummy profile: {use_dummy_profile}"
569
+ )
570
+
571
+ self.record_function_via_dummy_profile: Callable[..., Any] = (
572
+ self.record_function_via_dummy_profile_factory(use_dummy_profile)
573
+ )
574
+
575
+ if use_passed_in_path:
576
+ ssd_dir_list = ssd_storage_directory.split(",")
577
+ for ssd_dir in ssd_dir_list:
578
+ os.makedirs(ssd_dir, exist_ok=True)
579
+
580
+ ssd_directory = ssd_storage_directory
581
+ # logging.info("DEBUG: weights_precision {}".format(weights_precision))
582
+
583
+ """
584
+ ##################### for ZCH v.Next loading checkpoints Short Term Solution #######################
585
+ weight_id tensor is the weight and optimizer keys, to load from checkpoint, weight_id tensor
586
+ needs to be loaded first, then we can load the weight and optimizer tensors.
587
+ However, the stateful checkpoint loading does not guarantee the tensor loading order, so we need
588
+ to cache the weight_id, weight and optimizer tensors until all data are loaded, then we can apply
589
+ them to backend.
590
+ Currently, we'll cache the weight_id, weight and optimizer tensors in the KVZCHCachedData class,
591
+ and apply them to backend when all data are loaded. The downside of this solution is that we'll
592
+ have to duplicate a whole tensor memory to backend before we can release the python tensor memory,
593
+ which is not ideal.
594
+ The longer term solution is to support the caching from the backend side, and allow streaming based
595
+ data move from cached weight and optimizer to key/value format without duplicate one whole tensor's
596
+ memory.
597
+ """
598
+ self._cached_kvzch_data: Optional[KVZCHCachedData] = None
599
+ # initial embedding rows on this rank per table, this is used for loading checkpoint
600
+ self.local_weight_counts: list[int] = [0] * T_
601
+ # groundtruth global id on this rank per table, this is used for loading checkpoint
602
+ self.global_id_per_rank: list[torch.Tensor] = [torch.zeros(0)] * T_
603
+ # loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend
604
+ self.load_state_dict: bool = False
605
+
606
+ SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self)
607
+ if SSDTableBatchedEmbeddingBags._first_instance_ref is None:
608
+ SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self)
609
+
610
+ # create tbe unique id using rank index | local tbe idx
611
+ if tbe_unique_id == -1:
612
+ SSDTableBatchedEmbeddingBags._local_instance_index += 1
613
+ if dist.is_initialized():
614
+ assert (
615
+ SSDTableBatchedEmbeddingBags._local_instance_index < 1024
616
+ ), f"{SSDTableBatchedEmbeddingBags._local_instance_index}, more than 1024 TBE instance is created in one rank, the tbe unique id won't be unique in this case."
617
+ tbe_unique_id = (
618
+ dist.get_rank() << 10
619
+ | SSDTableBatchedEmbeddingBags._local_instance_index
620
+ )
621
+ else:
622
+ logging.warning("dist is not initialized, treating as single gpu cases")
623
+ tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
624
+ self.tbe_unique_id = tbe_unique_id
625
+ self.l2_cache_size = l2_cache_size
626
+ logging.info(f"tbe_unique_id: {tbe_unique_id}")
627
+ self.enable_free_mem_trigger_eviction: bool = False
628
+ if self.backend_type == BackendType.SSD:
629
+ logging.info(
630
+ f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
631
+ f"enable_async_update:{enable_async_update}, passed_in_path={ssd_directory}, "
632
+ f"num_shards={ssd_rocksdb_shards}, num_threads={ssd_rocksdb_shards}, "
633
+ f"memtable_flush_period={ssd_memtable_flush_period}, memtable_flush_offset={ssd_memtable_flush_offset}, "
634
+ f"l0_files_per_compact={ssd_l0_files_per_compact}, max_D={self.max_D}, "
635
+ f"cache_row_size={self.cache_row_dim}, rate_limit_mbps={ssd_rate_limit_mbps}, "
636
+ f"size_ratio={ssd_size_ratio}, compaction_trigger={ssd_compaction_trigger}, "
637
+ f"lazy_bulk_init_enabled={lazy_bulk_init_enabled}, write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size}, "
638
+ f"max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num}, "
639
+ f"uniform_init_lower={ssd_uniform_init_lower}, uniform_init_upper={ssd_uniform_init_upper}, "
640
+ f"row_storage_bitwidth={weights_precision.bit_rate()}, block_cache_size_per_tbe={ssd_block_cache_size_per_tbe}, "
641
+ f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB, "
642
+ f"enable_raw_embedding_streaming:{self.enable_raw_embedding_streaming}, flushing_block_size:{flushing_block_size}"
643
+ )
644
+ # pyre-fixme[4]: Attribute must be annotated.
645
+ self._ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
646
+ ssd_directory,
647
+ ssd_rocksdb_shards,
648
+ ssd_rocksdb_shards,
649
+ ssd_memtable_flush_period,
650
+ ssd_memtable_flush_offset,
651
+ ssd_l0_files_per_compact,
652
+ self.cache_row_dim,
653
+ ssd_rate_limit_mbps,
654
+ ssd_size_ratio,
655
+ ssd_compaction_trigger,
656
+ ssd_rocksdb_write_buffer_size,
657
+ ssd_max_write_buffer_num,
658
+ ssd_uniform_init_lower,
659
+ ssd_uniform_init_upper,
660
+ weights_precision.bit_rate(), # row_storage_bitwidth
661
+ ssd_block_cache_size_per_tbe,
662
+ use_passed_in_path,
663
+ tbe_unique_id,
664
+ l2_cache_size,
665
+ enable_async_update,
666
+ self.enable_raw_embedding_streaming,
667
+ self.res_params.res_store_shards,
668
+ self.res_params.res_server_port,
669
+ self.res_params.table_names,
670
+ self.res_params.table_offsets,
671
+ self.res_params.table_sizes,
672
+ (
673
+ tensor_pad4(self.table_dims)
674
+ if self.enable_optimizer_offloading
675
+ else None
676
+ ),
677
+ (
678
+ self.table_hash_size_cumsum.cpu()
679
+ if self.enable_optimizer_offloading
680
+ else None
681
+ ),
682
+ flushing_block_size,
683
+ self._embedding_cache_mode, # disable_random_init
684
+ )
685
+ if self.bulk_init_chunk_size > 0:
686
+ self.ssd_uniform_init_lower: float = ssd_uniform_init_lower
687
+ self.ssd_uniform_init_upper: float = ssd_uniform_init_upper
688
+ if lazy_bulk_init_enabled:
689
+ self._lazy_initialize_ssd_tbe()
690
+ else:
691
+ self._insert_all_kv()
692
+ elif self.backend_type == BackendType.PS:
693
+ self._ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
694
+ [host[0] for host in ps_hosts], # pyre-ignore
695
+ [host[1] for host in ps_hosts],
696
+ tbe_unique_id,
697
+ (
698
+ ps_max_local_index_length
699
+ if ps_max_local_index_length is not None
700
+ else 54
701
+ ),
702
+ ps_client_thread_num if ps_client_thread_num is not None else 32,
703
+ ps_max_key_per_request if ps_max_key_per_request is not None else 500,
704
+ l2_cache_size,
705
+ self.cache_row_dim,
706
+ )
707
+ elif self.backend_type == BackendType.DRAM:
708
+ logging.info(
709
+ f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
710
+ f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
711
+ f"max_D={self.max_D},"
712
+ f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
713
+ f"row_storage_bitwidth={weights_precision.bit_rate()},"
714
+ f"self.cache_row_dim={self.cache_row_dim},"
715
+ f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
716
+ f"feature_dims={self.feature_dims},"
717
+ f"hash_size_cumsum={self.hash_size_cumsum},"
718
+ f"backend_return_whole_row={self.backend_return_whole_row}"
719
+ )
720
+ table_dims = (
721
+ tensor_pad4(self.table_dims)
722
+ if self.enable_optimizer_offloading
723
+ else None
724
+ ) # table_dims
725
+ eviction_config = None
726
+ if self.kv_zch_params and self.kv_zch_params.eviction_policy:
727
+ eviction_mem_threshold_gb = (
728
+ self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
729
+ if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
730
+ else self.l2_cache_size
731
+ )
732
+ kv_zch_params = self.kv_zch_params
733
+ eviction_policy = self.kv_zch_params.eviction_policy
734
+ if eviction_policy.eviction_trigger_mode == 5:
735
+ # If trigger mode is free_mem(5), populate config
736
+ self.set_free_mem_eviction_trigger_config(eviction_policy)
737
+
738
+ enable_eviction_for_feature_score_eviction_policy = ( # the pytorch C++ API doesn't support vector<bool>, so convert to int here; 0: no eviction, 1: eviction
739
+ [
740
+ int(x)
741
+ for x in eviction_policy.enable_eviction_for_feature_score_eviction_policy
742
+ ]
743
+ if eviction_policy.enable_eviction_for_feature_score_eviction_policy
744
+ is not None
745
+ else None
746
+ )
747
+ # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
748
+ eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
749
+ eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
750
+ eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
751
+ eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
752
+ eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
753
+ eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
754
+ eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
755
+ eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
756
+ eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
757
+ eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
758
+ eviction_policy.training_id_keep_count, # training_id_keep_count for each table
759
+ enable_eviction_for_feature_score_eviction_policy, # no eviction setting for feature score eviction policy
760
+ eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
761
+ table_dims.tolist() if table_dims is not None else None,
762
+ eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
763
+ eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
764
+ eviction_policy.interval_for_insufficient_eviction_s,
765
+ eviction_policy.interval_for_sufficient_eviction_s,
766
+ eviction_policy.interval_for_feature_statistics_decay_s,
767
+ )
768
+ self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
769
+ self.cache_row_dim,
770
+ ssd_uniform_init_lower,
771
+ ssd_uniform_init_upper,
772
+ eviction_config,
773
+ ssd_rocksdb_shards, # num_shards
774
+ ssd_rocksdb_shards, # num_threads
775
+ weights_precision.bit_rate(), # row_storage_bitwidth
776
+ table_dims,
777
+ (
778
+ self.table_hash_size_cumsum.cpu()
779
+ if self.enable_optimizer_offloading
780
+ else None
781
+ ), # hash_size_cumsum
782
+ self.backend_return_whole_row, # backend_return_whole_row
783
+ False, # enable_async_update
784
+ self._embedding_cache_mode, # disable_random_init
785
+ )
786
+ else:
787
+ raise AssertionError(f"Invalid backend type {self.backend_type}")
788
+
789
+ # pyre-fixme[20]: Argument `self` expected.
790
+ (low_priority, high_priority) = torch.cuda.Stream.priority_range()
791
+ # GPU stream for SSD cache eviction
792
+ self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
793
+ # GPU stream for SSD memory copy (also reused for feature score D2H)
794
+ self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
795
+ # GPU stream for async metadata operation
796
+ self.feature_score_stream = torch.cuda.Stream(priority=low_priority)
797
+
798
+ # SSD get completion event
799
+ self.ssd_event_get = torch.cuda.Event()
800
+ # SSD scratch pad eviction completion event
801
+ self.ssd_event_sp_evict = torch.cuda.Event()
802
+ # SSD cache eviction completion event
803
+ self.ssd_event_cache_evict = torch.cuda.Event()
804
+ # SSD backward completion event
805
+ self.ssd_event_backward = torch.cuda.Event()
806
+ # SSD get's input copy completion event
807
+ self.ssd_event_get_inputs_cpy = torch.cuda.Event()
808
+ if self._embedding_cache_mode:
809
+ # Direct write embedding completion event
810
+ self.direct_write_l1_complete_event: torch.cuda.streams.Event = (
811
+ torch.cuda.Event()
812
+ )
813
+ self.direct_write_sp_complete_event: torch.cuda.streams.Event = (
814
+ torch.cuda.Event()
815
+ )
816
+ # Prefetch operation completion event
817
+ self.prefetch_complete_event = torch.cuda.Event()
818
+
819
+ if self.prefetch_pipeline:
820
+ # SSD scratch pad index queue insert completion event
821
+ self.ssd_event_sp_idxq_insert: torch.cuda.streams.Event = torch.cuda.Event()
822
+ # SSD scratch pad index queue lookup completion event
823
+ self.ssd_event_sp_idxq_lookup: torch.cuda.streams.Event = torch.cuda.Event()
824
+
825
+ if self.enable_raw_embedding_streaming:
826
+ # RES reuse the eviction stream
827
+ self.ssd_event_cache_streamed: torch.cuda.streams.Event = torch.cuda.Event()
828
+ self.ssd_event_cache_streaming_synced: torch.cuda.streams.Event = (
829
+ torch.cuda.Event()
830
+ )
831
+ self.ssd_event_cache_streaming_computed: torch.cuda.streams.Event = (
832
+ torch.cuda.Event()
833
+ )
834
+ self.ssd_event_sp_streamed: torch.cuda.streams.Event = torch.cuda.Event()
835
+
836
+ # Updated buffers
837
+ self.register_buffer(
838
+ "lxu_cache_updated_weights",
839
+ torch.ops.fbgemm.new_unified_tensor(
840
+ torch.zeros(
841
+ 1,
842
+ device=self.current_device,
843
+ dtype=cache_dtype,
844
+ ),
845
+ self.lxu_cache_weights.shape,
846
+ is_host_mapped=self.uvm_host_mapped,
847
+ ),
848
+ )
849
+
850
+ # For storing embedding indices to update to
851
+ self.register_buffer(
852
+ "lxu_cache_updated_indices",
853
+ torch.ops.fbgemm.new_unified_tensor(
854
+ torch.zeros(
855
+ 1,
856
+ device=self.current_device,
857
+ dtype=torch.long,
858
+ ),
859
+ (self.lxu_cache_weights.shape[0],),
860
+ is_host_mapped=self.uvm_host_mapped,
861
+ ),
862
+ )
863
+
864
+ # For storing the number of updated rows
865
+ self.register_buffer(
866
+ "lxu_cache_updated_count",
867
+ torch.ops.fbgemm.new_unified_tensor(
868
+ torch.zeros(
869
+ 1,
870
+ device=self.current_device,
871
+ dtype=torch.int,
872
+ ),
873
+ (1,),
874
+ is_host_mapped=self.uvm_host_mapped,
875
+ ),
876
+ )
877
+
878
+ # (Indices, Count)
879
+ self.prefetched_info: list[tuple[Tensor, Tensor]] = []
880
+
881
+ self.timesteps_prefetched: list[int] = []
882
+ # TODO: add type annotation
883
+ # pyre-fixme[4]: Attribute must be annotated.
884
+ self.ssd_prefetch_data = []
885
+
886
+ # Scratch pad eviction data queue
887
+ self.ssd_scratch_pad_eviction_data: list[
888
+ tuple[Tensor, Tensor, Tensor, bool]
889
+ ] = []
890
+ self.ssd_location_update_data: list[tuple[Tensor, Tensor]] = []
891
+
892
+ if self.prefetch_pipeline:
893
+ # Scratch pad value queue
894
+ self.ssd_scratch_pads: list[tuple[Tensor, Tensor, Tensor]] = []
895
+
896
+ # pyre-ignore[4]
897
+ # Scratch pad index queue
898
+ self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue(
899
+ -1
900
+ )
901
+
902
+ if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization:
903
+ raise AssertionError(
904
+ "weight_decay_mode = WeightDecayMode.COUNTER is not supported for SSD TBE."
905
+ )
906
+ counter_based_regularization = CounterBasedRegularizationDefinition()
907
+
908
+ if weight_decay_mode == WeightDecayMode.COWCLIP or cowclip_regularization:
909
+ raise AssertionError(
910
+ "weight_decay_mode = WeightDecayMode.COWCLIP is not supported for SSD TBE."
911
+ )
912
+ cowclip_regularization = CowClipDefinition()
913
+
914
+ self.learning_rate_tensor: torch.Tensor = torch.tensor(
915
+ learning_rate, device=torch.device("cpu"), dtype=torch.float32
916
+ )
917
+
918
+ self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs(
919
+ stochastic_rounding=stochastic_rounding,
920
+ gradient_clipping=gradient_clipping,
921
+ max_gradient=max_gradient,
922
+ max_norm=max_norm,
923
+ eps=eps,
924
+ beta1=beta1,
925
+ beta2=beta2,
926
+ weight_decay=weight_decay,
927
+ weight_decay_mode=weight_decay_mode.value,
928
+ eta=eta,
929
+ momentum=momentum,
930
+ counter_halflife=counter_based_regularization.counter_halflife,
931
+ adjustment_iter=counter_based_regularization.adjustment_iter,
932
+ adjustment_ub=counter_based_regularization.adjustment_ub,
933
+ learning_rate_mode=counter_based_regularization.learning_rate_mode.value,
934
+ grad_sum_decay=counter_based_regularization.grad_sum_decay.value,
935
+ tail_id_threshold=counter_based_regularization.tail_id_threshold.val,
936
+ is_tail_id_thresh_ratio=int(
937
+ counter_based_regularization.tail_id_threshold.is_ratio
938
+ ),
939
+ total_hash_size=-1, # Unused
940
+ weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
941
+ lower_bound=cowclip_regularization.lower_bound,
942
+ regularization_mode=weight_decay_mode.value,
943
+ use_rowwise_bias_correction=use_rowwise_bias_correction, # Used in Adam optimizer
944
+ )
945
+
946
+ table_embedding_dtype = weights_precision.as_dtype()
947
+
948
+ self._apply_split(
949
+ SplitState(
950
+ dev_size=0,
951
+ host_size=0,
952
+ uvm_size=0,
953
+ placements=[EmbeddingLocation.MANAGED_CACHING for _ in range(T_)],
954
+ offsets=[0] * (len(rows)),
955
+ ),
956
+ "weights",
957
+ # pyre-fixme[6]: For 3rd argument expected `Type[dtype]` but got `dtype`.
958
+ dtype=table_embedding_dtype,
959
+ )
960
+
961
+ # Create the optimizer state tensors
962
+ for template in self.optimizer.ssd_state_splits(
963
+ self.embedding_specs,
964
+ self.optimizer_state_dtypes,
965
+ self.enable_optimizer_offloading,
966
+ ):
967
+ # pyre-fixme[6]: For 3rd argument expected `Type[dtype]` but got `dtype`.
968
+ self._apply_split(*template)
969
+
970
+ # For storing current iteration data
971
+ self.current_iter_data: Optional[IterData] = None
972
+
973
+ # add a placeholder requires_grad param to enable autograd without nn.Parameter
974
+ # this is needed to enable int8 embedding weights for SplitTableBatchedEmbedding
975
+ self.placeholder_autograd_tensor = nn.Parameter(
976
+ torch.zeros(0, device=self.current_device, dtype=torch.float)
977
+ )
978
+
979
+ # Register backward hook for evicting rows from a scratch pad to SSD
980
+ # post backward
981
+ self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad)
982
+
983
+ if self.prefetch_pipeline:
984
+ self.register_full_backward_pre_hook(
985
+ self._update_cache_counter_and_pointers
986
+ )
987
+
988
+ # stats reporter
989
+ self.gather_ssd_cache_stats = gather_ssd_cache_stats
990
+ self.stats_reporter: Optional[TBEStatsReporter] = (
991
+ stats_reporter_config.create_reporter() if stats_reporter_config else None
992
+ )
993
+ self.ssd_cache_stats_size = 6
994
+ # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
995
+ # 4: N_conflict_unique_misses, 5: N_conflict_misses
996
+ self.last_reported_ssd_stats: list[float] = []
997
+ self.last_reported_step = 0
998
+
999
+ self.register_buffer(
1000
+ "ssd_cache_stats",
1001
+ torch.zeros(
1002
+ size=(self.ssd_cache_stats_size,),
1003
+ device=self.current_device,
1004
+ dtype=torch.int64,
1005
+ ),
1006
+ )
1007
+
1008
+ self.register_buffer(
1009
+ "local_ssd_cache_stats",
1010
+ torch.zeros(
1011
+ self.ssd_cache_stats_size,
1012
+ device=self.current_device,
1013
+ dtype=torch.int32,
1014
+ ),
1015
+ )
1016
+ logging.info(
1017
+ f"logging stats reporter setup, {self.gather_ssd_cache_stats=}, "
1018
+ f"stats_reporter:{self.stats_reporter if self.stats_reporter else 'none'}"
1019
+ )
1020
+
1021
+ # prefetch launches a series of kernels; we use AsyncSeriesTimer to track the kernel time
1022
+ self.ssd_prefetch_read_timer: Optional[AsyncSeriesTimer] = None
1023
+ self.ssd_prefetch_evict_timer: Optional[AsyncSeriesTimer] = None
1024
+ self.prefetch_parallel_stream_cnt: int = 2
1025
+ # tuple of iteration, prefetch parallel stream cnt, reported duration
1026
+ # since there are 2 streams running in parallel in prefetch, we want to count the longest one
1027
+ self.prefetch_duration_us: tuple[int, int, float] = (
1028
+ -1,
1029
+ self.prefetch_parallel_stream_cnt,
1030
+ 0,
1031
+ )
1032
+ self.l2_num_cache_misses_stats_name: str = (
1033
+ f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_cache_misses"
1034
+ )
1035
+ self.l2_num_cache_lookups_stats_name: str = (
1036
+ f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_lookups"
1037
+ )
1038
+ self.l2_num_cache_evictions_stats_name: str = (
1039
+ f"l2_cache.perf.tbe_id{tbe_unique_id}.num_l2_cache_evictions"
1040
+ )
1041
+ self.l2_cache_free_mem_stats_name: str = (
1042
+ f"l2_cache.mem.tbe_id{tbe_unique_id}.free_mem_bytes"
1043
+ )
1044
+ self.l2_cache_capacity_stats_name: str = (
1045
+ f"l2_cache.mem.tbe_id{tbe_unique_id}.capacity_bytes"
1046
+ )
1047
+ self.dram_kv_actual_used_chunk_bytes_stats_name: str = (
1048
+ f"dram_kv.mem.tbe_id{tbe_unique_id}.actual_used_chunk_bytes"
1049
+ )
1050
+ self.dram_kv_allocated_bytes_stats_name: str = (
1051
+ f"dram_kv.mem.tbe_id{tbe_unique_id}.allocated_bytes"
1052
+ )
1053
+ self.dram_kv_mem_num_rows_stats_name: str = (
1054
+ f"dram_kv.mem.tbe_id{tbe_unique_id}.num_rows"
1055
+ )
1056
+
1057
+ self.eviction_sum_evicted_counts_stats_name: str = (
1058
+ f"eviction.tbe_id.{tbe_unique_id}.sum_evicted_counts"
1059
+ )
1060
+ self.eviction_sum_processed_counts_stats_name: str = (
1061
+ f"eviction.tbe_id.{tbe_unique_id}.sum_processed_counts"
1062
+ )
1063
+ self.eviction_evict_rate_stats_name: str = (
1064
+ f"eviction.tbe_id.{tbe_unique_id}.evict_rate"
1065
+ )
1066
+
1067
+ if self.stats_reporter:
1068
+ self.ssd_prefetch_read_timer = AsyncSeriesTimer(
1069
+ functools.partial(
1070
+ SSDTableBatchedEmbeddingBags._report_duration,
1071
+ self,
1072
+ event_name="tbe.prefetch_duration_us",
1073
+ time_unit="us",
1074
+ )
1075
+ )
1076
+ self.ssd_prefetch_evict_timer = AsyncSeriesTimer(
1077
+ functools.partial(
1078
+ SSDTableBatchedEmbeddingBags._report_duration,
1079
+ self,
1080
+ event_name="tbe.prefetch_duration_us",
1081
+ time_unit="us",
1082
+ )
1083
+ )
1084
+ # pyre-ignore
1085
+ self.stats_reporter.register_stats(self.l2_num_cache_misses_stats_name)
1086
+ self.stats_reporter.register_stats(self.l2_num_cache_lookups_stats_name)
1087
+ self.stats_reporter.register_stats(self.l2_num_cache_evictions_stats_name)
1088
+ self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name)
1089
+ self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name)
1090
+ self.stats_reporter.register_stats(self.dram_kv_allocated_bytes_stats_name)
1091
+ self.stats_reporter.register_stats(
1092
+ self.dram_kv_actual_used_chunk_bytes_stats_name
1093
+ )
1094
+ self.stats_reporter.register_stats(self.dram_kv_mem_num_rows_stats_name)
1095
+ self.stats_reporter.register_stats(
1096
+ self.eviction_sum_evicted_counts_stats_name
1097
+ )
1098
+ self.stats_reporter.register_stats(
1099
+ self.eviction_sum_processed_counts_stats_name
1100
+ )
1101
+ self.stats_reporter.register_stats(self.eviction_evict_rate_stats_name)
1102
+ for t in self.feature_table_map:
1103
+ self.stats_reporter.register_stats(
1104
+ f"eviction.feature_table.{t}.evicted_counts"
1105
+ )
1106
+ self.stats_reporter.register_stats(
1107
+ f"eviction.feature_table.{t}.processed_counts"
1108
+ )
1109
+ self.stats_reporter.register_stats(
1110
+ f"eviction.feature_table.{t}.evict_rate"
1111
+ )
1112
+ self.stats_reporter.register_stats(
1113
+ "eviction.feature_table.full_duration_ms"
1114
+ )
1115
+ self.stats_reporter.register_stats(
1116
+ "eviction.feature_table.exec_duration_ms"
1117
+ )
1118
+ self.stats_reporter.register_stats(
1119
+ "eviction.feature_table.dry_run_exec_duration_ms"
1120
+ )
1121
+ self.stats_reporter.register_stats(
1122
+ "eviction.feature_table.exec_div_full_duration_rate"
1123
+ )
1124
+
1125
+ self.bounds_check_version: int = get_bounds_check_version_for_platform()
1126
+
1127
+ self._pg = pg
1128
+
1129
+ @cached_property
1130
+ def cache_row_dim(self) -> int:
1131
+ """
1132
+ Compute the effective physical cache row size taking into account
1133
+ padding to the nearest 4 elements and the optimizer state appended to
1134
+ the back of the row
1135
+ """
1136
+
1137
+ # For st publish, we only need to load weight for publishing and bulk eval
1138
+ if self.enable_optimizer_offloading and not self.load_ckpt_without_opt:
1139
+ return self.max_D + pad4(
1140
+ # Compute the number of elements of cache_dtype needed to store
1141
+ # the optimizer state
1142
+ self.optimizer_state_dim
1143
+ )
1144
+ else:
1145
+ return self.max_D
1146
+
1147
+ @cached_property
1148
+ def optimizer_state_dim(self) -> int:
1149
+ return int(
1150
+ math.ceil(
1151
+ self.optimizer.state_size_nbytes(
1152
+ self.max_D, self.optimizer_state_dtypes
1153
+ )
1154
+ / self.weights_precision.as_dtype().itemsize
1155
+ )
1156
+ )
1157
+
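# A worked example of the two properties above, using assumed values (not from this file):
# with max_D = 128, weights_precision = FP16 (itemsize 2 bytes), optimizer offloading
# enabled, and a rowwise optimizer state of 4 bytes per row:
#   optimizer_state_dim = ceil(4 / 2) = 2 FP16 elements
#   cache_row_dim       = 128 + pad4(2) = 128 + 4 = 132 elements per cached row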
1158
+ @property
1159
+ # pyre-ignore
1160
+ def ssd_db(self):
1161
+ """Intercept the ssd_db property to make sure it is fully initialized before use.
1162
+ This is needed because random weights are initialized in a separate thread"""
1163
+ if self.lazy_init_thread is not None:
1164
+ self.lazy_init_thread.join()
1165
+ self.lazy_init_thread = None
1166
+ logging.info("lazy ssd tbe initialization completed, weights are ready")
1167
+
1168
+ return self._ssd_db
1169
+
1170
+ @ssd_db.setter
1171
+ # pyre-ignore
1172
+ def ssd_db(self, value):
1173
+ """Setter for ssd_db property."""
1174
+ if self.lazy_init_thread is not None:
1175
+ # This is essentially a copy assignment operation, since the thread is
1176
+ # already existing, and we are assigning a new ssd_db to it. Complete
1177
+ # the initialization first, then assign the new value to it.
1178
+ self.lazy_init_thread.join()
1179
+ self.lazy_init_thread = None
1180
+ logging.info(
1181
+ "lazy ssd tbe initialization completed, ssd_db will now get overridden"
1182
+ )
1183
+
1184
+ self._ssd_db = value
1185
+
1186
+ def _lazy_initialize_ssd_tbe(self) -> None:
1187
+ """
1188
+ Initialize the SSD TBE with random weights. This function should only be
1189
+ called once at initialization time.
1190
+ """
1191
+ if self.bulk_init_chunk_size > 0:
1192
+ self.lazy_init_thread = threading.Thread(target=self._insert_all_kv)
1193
+ # pyre-ignore
1194
+ self.lazy_init_thread.start()
1195
+ logging.info(
1196
+ f"lazy ssd tbe initialization started since bulk_init_chunk_size is set to {self.bulk_init_chunk_size}"
1197
+ )
1198
+ else:
1199
+ logging.debug(
1200
+ "bulk_init_chunk_size is not set, skipping lazy initialization"
1201
+ )
1202
+
1203
+ @torch.jit.ignore
1204
+ def _insert_all_kv(self) -> None:
1205
+ """
1206
+ Populate all rows in the ssd TBE with random weights. Existing keys will
1207
+ be effectively overwritten. This function should only be called once at
1208
+ initialization time.
1209
+ """
1210
+ self._ssd_db.toggle_compaction(False)
1211
+ row_offset = 0
1212
+ row_count = floor(
1213
+ self.bulk_init_chunk_size
1214
+ / (self.cache_row_dim * self.weights_precision.as_dtype().itemsize)
1215
+ )
1216
+ total_dim0 = 0
1217
+ for dim0, _ in self.embedding_specs:
1218
+ total_dim0 += dim0
1219
+
1220
+ start_ts = time.time()
1221
+ # TODO: do we have case for non-kvzch ssd with bulk init enabled + optimizer offloading? probably not?
1222
+ # if we have such cases, we should only init the emb dim not the optimizer dim
1223
+ chunk_tensor = torch.empty(
1224
+ row_count,
1225
+ self.cache_row_dim,
1226
+ dtype=self.weights_precision.as_dtype(),
1227
+ device="cuda",
1228
+ )
1229
+ cpu_tensor = torch.empty_like(chunk_tensor, device="cpu")
1230
+ for row_offset in range(0, total_dim0, row_count):
1231
+ actual_dim0 = min(total_dim0 - row_offset, row_count)
1232
+ chunk_tensor.uniform_(
1233
+ self.ssd_uniform_init_lower, self.ssd_uniform_init_upper
1234
+ )
1235
+ cpu_tensor.copy_(chunk_tensor, non_blocking=False)
1236
+ rand_val = cpu_tensor[:actual_dim0, :]
1237
+ # This code is intentionally not calling through the getter property
1238
+ # to avoid having the lazy initialization thread join with itself.
1239
+ self._ssd_db.set_range_to_storage(rand_val, row_offset, actual_dim0)
1240
+ end_ts = time.time()
1241
+ elapsed = int((end_ts - start_ts) * 1e6)
1242
+ logging.info(
1243
+ f"TBE bulk initialization took {elapsed:_} us, bulk_init_chunk_size={self.bulk_init_chunk_size}, each batch of {row_count} rows, total rows of {total_dim0}"
1244
+ )
1245
+ self._ssd_db.toggle_compaction(True)
1246
+
1247
+ @torch.jit.ignore
1248
+ def _report_duration(
1249
+ self,
1250
+ it_step: int,
1251
+ dur_ms: float,
1252
+ event_name: str,
1253
+ time_unit: str,
1254
+ ) -> None:
1255
+ """
1256
+ Callback function passed into AsyncSeriesTimer, which will be
1257
+ invoked when the last kernel in AsyncSeriesTimer scope is done.
1258
+ Currently this is only used to trace prefetch duration, in which
1259
+ there are 2 streams involved, main stream and eviction stream.
1260
+ This will report the duration of the longer stream to ODS
1261
+
1262
+ This function is not thread-safe.
1263
+
1264
+ Args:
1265
+ it_step (int): The reporting iteration step
1266
+ dur_ms (float): The duration of all the kernels within the
1267
+ AsyncSeriesTimer scope in milliseconds
1268
+ event_name (str): The name of the event
1269
+ time_unit (str): The unit of the duration (us or ms)
1270
+ """
1271
+ recorded_itr, stream_cnt, report_val = self.prefetch_duration_us
1272
+ duration = dur_ms
1273
+ if time_unit == "us":
1274
+ duration = dur_ms * 1000
1275
+ if it_step == recorded_itr:
1276
+ report_val = max(report_val, duration)
1277
+ stream_cnt -= 1
1278
+ else:
1279
+ # reset
1280
+ recorded_itr = it_step
1281
+ report_val = duration
1282
+ stream_cnt = self.prefetch_parallel_stream_cnt
1283
+ self.prefetch_duration_us = (recorded_itr, stream_cnt, report_val)
1284
+
1285
+ if stream_cnt == 1:
1286
+ # this is the last stream, handling ods report
1287
+ # pyre-ignore
1288
+ self.stats_reporter.report_duration(
1289
+ it_step, event_name, report_val, time_unit=time_unit
1290
+ )
1291
+
1292
+ def record_function_via_dummy_profile_factory(
1293
+ self,
1294
+ use_dummy_profile: bool,
1295
+ ) -> Callable[..., Any]:
1296
+ """
1297
+ Generate the record_function_via_dummy_profile based on the
1298
+ use_dummy_profile flag.
1299
+
1300
+ If use_dummy_profile is True, inject a dummy kernel before and after
1301
+ the function call and record function via `record_function`
1302
+
1303
+ Otherwise, just execute the function
1304
+
1305
+ Args:
1306
+ use_dummy_profile (bool): A flag for enabling/disabling
1307
+ record_function_via_dummy_profile
1308
+ """
1309
+ if use_dummy_profile:
1310
+
1311
+ def func(
1312
+ name: str,
1313
+ fn: Callable[..., Any],
1314
+ *args: Any,
1315
+ **kwargs: Any,
1316
+ ) -> None:
1317
+ with record_function(name):
1318
+ self.dummy_profile_tensor.add_(1)
1319
+ fn(*args, **kwargs)
1320
+ self.dummy_profile_tensor.add_(1)
1321
+
1322
+ return func
1323
+
1324
+ def func(
1325
+ name: str,
1326
+ fn: Callable[..., Any],
1327
+ *args: Any,
1328
+ **kwargs: Any,
1329
+ ) -> None:
1330
+ fn(*args, **kwargs)
1331
+
1332
+ return func
1333
+
1334
+ def _apply_split(
1335
+ self,
1336
+ split: SplitState,
1337
+ prefix: str,
1338
+ dtype: type[torch.dtype],
1339
+ enforce_hbm: bool = False,
1340
+ make_dev_param: bool = False,
1341
+ dev_reshape: Optional[tuple[int, ...]] = None,
1342
+ ) -> None:
1343
+ apply_split_helper(
1344
+ self.register_buffer,
1345
+ functools.partial(setattr, self),
1346
+ self.current_device,
1347
+ False, # use_cpu
1348
+ self.feature_table_map,
1349
+ split,
1350
+ prefix,
1351
+ dtype,
1352
+ enforce_hbm,
1353
+ make_dev_param,
1354
+ dev_reshape,
1355
+ )
1356
+
1357
+ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:
1358
+ t_cpu = torch.empty(t.shape, pin_memory=True, dtype=t.dtype)
1359
+ t_cpu.copy_(t, non_blocking=True)
1360
+ return t_cpu
1361
+
1362
+ def to_pinned_cpu_on_stream_wait_on_another_stream(
1363
+ self,
1364
+ tensors: list[Tensor],
1365
+ stream: torch.cuda.Stream,
1366
+ stream_to_wait_on: torch.cuda.Stream,
1367
+ post_event: Optional[torch.cuda.Event] = None,
1368
+ ) -> list[Tensor]:
1369
+ """
1370
+ Transfer input tensors from GPU to CPU using a pinned host
1371
+ buffer. The transfer is carried out on the given stream
1372
+ (`stream`) after all the kernels in the other stream
1373
+ (`stream_to_wait_on`) are complete.
1374
+
1375
+ Args:
1376
+ tensors (List[Tensor]): The list of tensors to be
1377
+ transferred
1378
+ stream (Stream): The stream to run memory copy
1379
+ stream_to_wait_on (Stream): The stream to wait on
1380
+ post_event (Event): The post completion event
1381
+
1382
+ Returns:
1383
+ The list of pinned CPU tensors
1384
+ """
1385
+ with torch.cuda.stream(stream):
1386
+ stream.wait_stream(stream_to_wait_on)
1387
+ cpu_tensors = []
1388
+ for t in tensors:
1389
+ t.record_stream(stream)
1390
+ cpu_tensors.append(self.to_pinned_cpu(t))
1391
+ if post_event is not None:
1392
+ stream.record_event(post_event)
1393
+ return cpu_tensors
1394
+
1395
+ def evict(
1396
+ self,
1397
+ rows: Tensor,
1398
+ indices_cpu: Tensor,
1399
+ actions_count_cpu: Tensor,
1400
+ stream: torch.cuda.Stream,
1401
+ pre_event: Optional[torch.cuda.Event],
1402
+ post_event: Optional[torch.cuda.Event],
1403
+ is_rows_uvm: bool,
1404
+ name: Optional[str] = "",
1405
+ is_bwd: bool = True,
1406
+ ) -> None:
1407
+ """
1408
+ Evict data from the given input tensors to SSD via RocksDB
1409
+ Args:
1410
+ rows (Tensor): The 2D tensor that contains rows to evict
1411
+ indices_cpu (Tensor): The 1D CPU tensor that contains the row
1412
+ indices that the rows will be evicted to
1413
+ actions_count_cpu (Tensor): A scalar tensor that contains the
1414
+ number of rows that the evict function
1415
+ has to process
1416
+ stream (Stream): The CUDA stream that cudaStreamAddCallback will
1417
+ synchronize the host function with. Moreover, the
1418
+ asynchronous D->H memory copies will operate on
1419
+ this stream
1420
+ pre_event (Event): The CUDA event that the stream has to wait on
1421
+ post_event (Event): The CUDA event that the current stream will record on
1422
+ when the eviction is done
1423
+ is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
1424
+ tensor (which is accessible on both host and
1425
+ device)
1426
+ is_bwd (bool): A flag to indicate if the eviction is during backward
1427
+ Returns:
1428
+ None
1429
+ """
1430
+ if not self.training: # if not training, freeze the embedding
1431
+ return
1432
+ with record_function(f"## ssd_evict_{name} ##"):
1433
+ with torch.cuda.stream(stream):
1434
+ if pre_event is not None:
1435
+ stream.wait_event(pre_event)
1436
+
1437
+ rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)
1438
+
1439
+ rows.record_stream(stream)
1440
+
1441
+ self.record_function_via_dummy_profile(
1442
+ f"## ssd_set_{name} ##",
1443
+ self.ssd_db.set_cuda,
1444
+ indices_cpu,
1445
+ rows_cpu,
1446
+ actions_count_cpu,
1447
+ self.timestep,
1448
+ is_bwd,
1449
+ )
1450
+
1451
+ if post_event is not None:
1452
+ stream.record_event(post_event)
1453
+
1454
+ def raw_embedding_stream_sync(
1455
+ self,
1456
+ stream: torch.cuda.Stream,
1457
+ pre_event: Optional[torch.cuda.Event],
1458
+ post_event: Optional[torch.cuda.Event],
1459
+ name: Optional[str] = "",
1460
+ ) -> None:
1461
+ """
1462
+ Block until the copy of the tensors to be streamed is complete,
1463
+ to make sure they are not overwritten
1464
+ Args:
1465
+ stream (Stream): The CUDA stream that cudaStreamAddCallback will
1466
+ synchronize the host function with. Moreover, the
1467
+ asynchronous D->H memory copies will operate on
1468
+ this stream
1469
+ pre_event (Event): The CUDA event that the stream has to wait on
1470
+ post_event (Event): The CUDA event that the current stream will record on
1471
+ when the eviction is done
1472
+ Returns:
1473
+ None
1474
+ """
1475
+ with record_function(f"## ssd_stream_{name} ##"):
1476
+ with torch.cuda.stream(stream):
1477
+ if pre_event is not None:
1478
+ stream.wait_event(pre_event)
1479
+
1480
+ self.record_function_via_dummy_profile(
1481
+ f"## ssd_stream_sync_{name} ##",
1482
+ self.ssd_db.stream_sync_cuda,
1483
+ )
1484
+
1485
+ if post_event is not None:
1486
+ stream.record_event(post_event)
1487
+
1488
+ def raw_embedding_stream(
1489
+ self,
1490
+ rows: Tensor,
1491
+ indices_cpu: Tensor,
1492
+ actions_count_cpu: Tensor,
1493
+ stream: torch.cuda.Stream,
1494
+ pre_event: Optional[torch.cuda.Event],
1495
+ post_event: Optional[torch.cuda.Event],
1496
+ is_rows_uvm: bool,
1497
+ blocking_tensor_copy: bool = True,
1498
+ name: Optional[str] = "",
1499
+ ) -> None:
1500
+ """
1501
+ Stream data from the given input tensors to a remote service
1502
+ Args:
1503
+ rows (Tensor): The 2D tensor that contains rows to evict
1504
+ indices_cpu (Tensor): The 1D CPU tensor that contains the row
1505
+ indices that the rows will be evicted to
1506
+ actions_count_cpu (Tensor): A scalar tensor that contains the
1507
+ number of rows that the evict function
1508
+ has to process
1509
+ stream (Stream): The CUDA stream that cudaStreamAddCallback will
1510
+ synchronize the host function with. Moreover, the
1511
+ asynchronous D->H memory copies will operate on
1512
+ this stream
1513
+ pre_event (Event): The CUDA event that the stream has to wait on
1514
+ post_event (Event): The CUDA event that the current stream will record on
1515
+ when the eviction is done
1516
+ is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
1517
+ tensor (which is accessible on both host and
1518
+ device)
1519
+ Returns:
1520
+ None
1521
+ """
1522
+ with record_function(f"## ssd_stream_{name} ##"):
1523
+ with torch.cuda.stream(stream):
1524
+ if pre_event is not None:
1525
+ stream.wait_event(pre_event)
1526
+
1527
+ rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)
1528
+
1529
+ rows.record_stream(stream)
1530
+
1531
+ self.record_function_via_dummy_profile(
1532
+ f"## ssd_stream_{name} ##",
1533
+ self.ssd_db.stream_cuda,
1534
+ indices_cpu,
1535
+ rows_cpu,
1536
+ actions_count_cpu,
1537
+ blocking_tensor_copy,
1538
+ )
1539
+
1540
+ if post_event is not None:
1541
+ stream.record_event(post_event)
1542
+
1543
+ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
1544
+ """
1545
+ Evict conflict missed rows from a scratch pad
1546
+ (`inserted_rows`) on the `ssd_eviction_stream`. This is a hook
1547
+ that is invoked right after TBE backward.
1548
+
1549
+ Conflict missed indices are specified in
1550
+ `post_bwd_evicted_indices_cpu`. Indices that are not -1 and
1551
+ whose positions are < `actions_count_cpu` (i.e., rows where
1552
+ `post_bwd_evicted_indices_cpu[:actions_count_cpu] != -1`)
1553
+ will be evicted.
1554
+
1555
+ Args:
1556
+ grad (Tensor): Unused gradient tensor
1557
+
1558
+ Returns:
1559
+ None
1560
+ """
1561
+ with record_function("## ssd_evict_from_scratch_pad_pipeline ##"):
1562
+ current_stream = torch.cuda.current_stream()
1563
+ current_stream.record_event(self.ssd_event_backward)
1564
+
1565
+ assert (
1566
+ len(self.ssd_scratch_pad_eviction_data) > 0
1567
+ ), "There must be at least one scratch pad"
1568
+
1569
+ (
1570
+ inserted_rows,
1571
+ post_bwd_evicted_indices_cpu,
1572
+ actions_count_cpu,
1573
+ do_evict,
1574
+ ) = self.ssd_scratch_pad_eviction_data.pop(0)
1575
+
1576
+ if not do_evict:
1577
+ return
1578
+
1579
+ if self.enable_raw_embedding_streaming:
1580
+ self.raw_embedding_stream(
1581
+ rows=inserted_rows,
1582
+ indices_cpu=post_bwd_evicted_indices_cpu,
1583
+ actions_count_cpu=actions_count_cpu,
1584
+ stream=self.ssd_eviction_stream,
1585
+ pre_event=self.ssd_event_backward,
1586
+ post_event=self.ssd_event_sp_streamed,
1587
+ is_rows_uvm=True,
1588
+ blocking_tensor_copy=True,
1589
+ name="scratch_pad",
1590
+ )
1591
+ self.evict(
1592
+ rows=inserted_rows,
1593
+ indices_cpu=post_bwd_evicted_indices_cpu,
1594
+ actions_count_cpu=actions_count_cpu,
1595
+ stream=self.ssd_eviction_stream,
1596
+ pre_event=self.ssd_event_backward,
1597
+ post_event=self.ssd_event_sp_evict,
1598
+ is_rows_uvm=True,
1599
+ name="scratch_pad",
1600
+ )
1601
+
1602
+ if self.prefetch_stream:
1603
+ self.prefetch_stream.wait_stream(current_stream)
1604
+
1605
+ def _update_cache_counter_and_pointers(
1606
+ self,
1607
+ module: nn.Module,
1608
+ grad_input: Union[tuple[Tensor, ...], Tensor],
1609
+ ) -> None:
1610
+ """
1611
+ Update cache line locking counter and pointers before backward
1612
+ TBE. This is a hook that is called before the backward of TBE
1613
+
1614
+ Update cache line counter:
1615
+
1616
+ We ensure that cache prefetching does not execute concurrently
1617
+ with the backward TBE. Therefore, it is safe to unlock the
1618
+ cache lines used in current iteration before backward TBE.
1619
+
1620
+ Update pointers:
1621
+
1622
+ Now some rows that are used in both the current iteration and
1623
+ the next iteration are moved (1) from the current iteration's
1624
+ scratch pad into the next iteration's scratch pad or (2) from
1625
+ the current iteration's scratch pad into the L1 cache
1626
+
1627
+ To ensure that the TBE backward kernel accesses valid data,
1628
+ here we update the pointers of these rows in the current
1629
+ iteration's `lxu_cache_ptrs` to point to either L1 cache or
1630
+ the next iteration's scratch pad
1631
+
1632
+ Args:
1633
+ module (nn.Module): Unused
1634
+ grad_input (Union[Tuple[Tensor, ...], Tensor]): Unused
1635
+
1636
+ Returns:
1637
+ None
1638
+ """
1639
+ if self.prefetch_stream:
1640
+ # Ensure that prefetch is done
1641
+ torch.cuda.current_stream().wait_stream(self.prefetch_stream)
1642
+
1643
+ assert self.current_iter_data is not None, "current_iter_data must be set"
1644
+
1645
+ curr_data: IterData = self.current_iter_data
1646
+
1647
+ if curr_data.lxu_cache_locations.numel() == 0:
1648
+ return
1649
+
1650
+ with record_function("## ssd_update_cache_counter_and_pointers ##"):
1651
+ # Unlock the cache lines
1652
+ torch.ops.fbgemm.lxu_cache_locking_counter_decrement(
1653
+ self.lxu_cache_locking_counter,
1654
+ curr_data.lxu_cache_locations,
1655
+ )
1656
+
1657
+ # Recompute linear_cache_indices to save memory
1658
+ linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
1659
+ self.hash_size_cumsum,
1660
+ curr_data.indices,
1661
+ curr_data.offsets,
1662
+ curr_data.B_offsets,
1663
+ curr_data.max_B,
1664
+ )
1665
+ (
1666
+ linear_unique_indices,
1667
+ linear_unique_indices_length,
1668
+ unique_indices_count,
1669
+ linear_index_inverse_indices,
1670
+ ) = get_unique_indices_v2(
1671
+ linear_cache_indices,
1672
+ self.total_hash_size,
1673
+ compute_count=True,
1674
+ compute_inverse_indices=True,
1675
+ )
1676
+ unique_indices_count_cumsum = torch.ops.fbgemm.asynchronous_complete_cumsum(
1677
+ unique_indices_count
1678
+ )
1679
+
1680
+ # Look up the cache to check which indices in the scratch
1681
+ # pad are moved to L1
1682
+ torch.ops.fbgemm.lxu_cache_lookup(
1683
+ linear_cache_indices,
1684
+ self.lxu_cache_state,
1685
+ self.total_hash_size,
1686
+ gather_cache_stats=False, # not collecting cache stats
1687
+ lxu_cache_locations_output=curr_data.lxu_cache_locations,
1688
+ )
1689
+
1690
+ if len(self.ssd_location_update_data) == 0:
1691
+ return
1692
+
1693
+ (sp_curr_next_map, inserted_rows_next) = self.ssd_location_update_data.pop(
1694
+ 0
1695
+ )
1696
+
1697
+ # Update pointers
1698
+ torch.ops.fbgemm.ssd_update_row_addrs(
1699
+ ssd_row_addrs_curr=curr_data.lxu_cache_ptrs,
1700
+ inserted_ssd_weights_curr_next_map=sp_curr_next_map,
1701
+ lxu_cache_locations_curr=curr_data.lxu_cache_locations,
1702
+ linear_index_inverse_indices_curr=linear_index_inverse_indices,
1703
+ unique_indices_count_cumsum_curr=unique_indices_count_cumsum,
1704
+ cache_set_inverse_indices_curr=curr_data.cache_set_inverse_indices,
1705
+ lxu_cache_weights=self.lxu_cache_weights,
1706
+ inserted_ssd_weights_next=inserted_rows_next,
1707
+ unique_indices_length_curr=curr_data.actions_count_gpu,
1708
+ )
1709
+
1710
+ def _update_feature_score_metadata(
1711
+ self,
1712
+ linear_cache_indices: Tensor,
1713
+ weights: Tensor,
1714
+ d2h_stream: torch.cuda.Stream,
1715
+ write_stream: torch.cuda.Stream,
1716
+ pre_event_for_write: torch.cuda.Event,
1717
+ post_event: Optional[torch.cuda.Event] = None,
1718
+ ) -> None:
1719
+ """
1720
+ Write feature score metadata to DRAM
1721
+
1722
+ This method performs D2H copy on d2h_stream, then writes to DRAM on write_stream.
1723
+ The caller is responsible for ensuring d2h_stream doesn't compete with other D2H operations.
1724
+
1725
+ Args:
1726
+ linear_cache_indices: GPU tensor containing cache indices
1727
+ weights: GPU tensor containing feature scores
1728
+ d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
1729
+ write_stream: Stream for metadata write operation
1730
+ pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
1731
+ post_event: Event to record when the operation is done
1732
+ """
1733
+ # Start D2H copy on d2h_stream
1734
+ with torch.cuda.stream(d2h_stream):
1735
+ # Record streams to prevent premature deallocation
1736
+ linear_cache_indices.record_stream(d2h_stream)
1737
+ weights.record_stream(d2h_stream)
1738
+ # Do the D2H copy
1739
+ linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
1740
+ score_weights_cpu = self.to_pinned_cpu(weights)
1741
+
1742
+ # Write feature score metadata to DRAM
1743
+ with record_function("## ssd_write_feature_score_metadata ##"):
1744
+ with torch.cuda.stream(write_stream):
1745
+ write_stream.wait_event(pre_event_for_write)
1746
+ write_stream.wait_stream(d2h_stream)
1747
+ self.record_function_via_dummy_profile(
1748
+ "## ssd_write_feature_score_metadata ##",
1749
+ self.ssd_db.set_feature_score_metadata_cuda,
1750
+ linear_cache_indices_cpu,
1751
+ torch.tensor(
1752
+ [score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
1753
+ ),
1754
+ score_weights_cpu,
1755
+ )
1756
+
1757
+ if post_event is not None:
1758
+ write_stream.record_event(post_event)
1759
+
1760
+ def prefetch(
1761
+ self,
1762
+ indices: Tensor,
1763
+ offsets: Tensor,
1764
+ weights: Optional[Tensor] = None, # todo: need to update caller
1765
+ forward_stream: Optional[torch.cuda.Stream] = None,
1766
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
1767
+ ) -> None:
1768
+ if self.prefetch_stream is None and forward_stream is not None:
1769
+ # Set the prefetch stream to the current stream
1770
+ self.prefetch_stream = torch.cuda.current_stream()
1771
+ assert (
1772
+ self.prefetch_stream != forward_stream
1773
+ ), "prefetch_stream and forward_stream should not be the same stream"
1774
+
1775
+ current_stream = torch.cuda.current_stream()
1776
+ # Record tensors on the current stream
1777
+ indices.record_stream(current_stream)
1778
+ offsets.record_stream(current_stream)
1779
+
1780
+ indices, offsets, _, vbe_metadata = self.prepare_inputs(
1781
+ indices,
1782
+ offsets,
1783
+ per_sample_weights=None,
1784
+ batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
1785
+ )
1786
+
1787
+ self._prefetch(
1788
+ indices,
1789
+ offsets,
1790
+ weights,
1791
+ vbe_metadata,
1792
+ forward_stream,
1793
+ )
1794
+
1795
+ def _prefetch( # noqa C901
1796
+ self,
1797
+ indices: Tensor,
1798
+ offsets: Tensor,
1799
+ weights: Optional[Tensor] = None,
1800
+ vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
1801
+ forward_stream: Optional[torch.cuda.Stream] = None,
1802
+ ) -> None:
1803
+ # Wait for any ongoing direct_write_embedding operations to complete
1804
+ # Moving this from forward() to _prefetch() is more logical as direct_write
1805
+ # operations affect the same cache structures that prefetch interacts with
1806
+ current_stream = torch.cuda.current_stream()
1807
+ if self._embedding_cache_mode:
1808
+ current_stream.wait_event(self.direct_write_l1_complete_event)
1809
+ current_stream.wait_event(self.direct_write_sp_complete_event)
1810
+
1811
+ B_offsets = None
1812
+ max_B = -1
1813
+ if vbe_metadata is not None:
1814
+ B_offsets = vbe_metadata.B_offsets
1815
+ max_B = vbe_metadata.max_B
1816
+
1817
+ with record_function("## ssd_prefetch {} ##".format(self.timestep)):
1818
+ if self.gather_ssd_cache_stats:
1819
+ self.local_ssd_cache_stats.zero_()
1820
+
1821
+ # Linearize indices
1822
+ linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
1823
+ self.hash_size_cumsum,
1824
+ indices,
1825
+ offsets,
1826
+ B_offsets,
1827
+ max_B,
1828
+ )
1829
+
1830
+ self.timestep += 1
1831
+ self.timesteps_prefetched.append(self.timestep)
1832
+
1833
+ # Lookup and virtually insert indices into L1. After this operator,
1834
+ # we know:
1835
+ # (1) which cache lines can be evicted
1836
+ # (2) which rows are already in cache (hit)
1837
+ # (3) which rows are missed and can be inserted later (missed, but
1838
+ # not conflict missed)
1839
+ # (4) which rows are missed but CANNOT be inserted later (conflict
1840
+ # missed)
1841
+ (
1842
+ inserted_indices,
1843
+ evicted_indices,
1844
+ assigned_cache_slots,
1845
+ actions_count_gpu,
1846
+ linear_index_inverse_indices,
1847
+ unique_indices_count_cumsum,
1848
+ cache_set_inverse_indices,
1849
+ unique_indices_length,
1850
+ ) = torch.ops.fbgemm.ssd_cache_populate_actions(
1851
+ linear_cache_indices,
1852
+ self.total_hash_size,
1853
+ self.lxu_cache_state,
1854
+ self.timestep,
1855
+ 1, # for now assume prefetch_dist == 1
1856
+ self.lru_state,
1857
+ self.gather_ssd_cache_stats,
1858
+ self.local_ssd_cache_stats,
1859
+ lock_cache_line=self.prefetch_pipeline,
1860
+ lxu_cache_locking_counter=self.lxu_cache_locking_counter,
1861
+ )
1862
+
1863
+ # Compute cache locations (rows that are hit, or missed but can be
1864
+ # inserted, will have cache locations != -1)
1865
+ with record_function("## ssd_tbe_lxu_cache_lookup ##"):
1866
+ lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
1867
+ linear_cache_indices,
1868
+ self.lxu_cache_state,
1869
+ self.total_hash_size,
1870
+ self.gather_ssd_cache_stats,
1871
+ self.local_ssd_cache_stats,
1872
+ )
1873
+
1874
+ # Defrag indices based on evicted_indices (removing -1 and making
1875
+ # the non -1 elements contiguous). We need to do this because the
1876
+ # number of rows in `lxu_cache_evicted_weights` might be smaller
1877
+ # than the number of elements in `evicted_indices`. Without this
1878
+ # step, we could run into an index-out-of-bounds issue
1879
+ current_stream.wait_event(self.ssd_event_cache_evict)
1880
+ torch.ops.fbgemm.compact_indices(
1881
+ compact_indices=[
1882
+ self.lxu_cache_evicted_indices,
1883
+ self.lxu_cache_evicted_slots,
1884
+ ],
1885
+ compact_count=self.lxu_cache_evicted_count,
1886
+ indices=[evicted_indices, assigned_cache_slots],
1887
+ masks=torch.where(evicted_indices != -1, 1, 0),
1888
+ count=actions_count_gpu,
1889
+ )
1890
+ has_raw_embedding_streaming = False
1891
+ if self.enable_raw_embedding_streaming:
1892
+ # When pipelining is enabled:
1893
+ # the prefetch in iter i happens before the backward sparse in iter i - 1,
1894
+ # so embeddings for iter i - 1's changed ids are not yet updated;
1895
+ # we can only fetch the indices from iter i - 2.
1896
+ # When pipelining is disabled:
1897
+ # the prefetch in iter i happens before the forward of iter i,
1898
+ # so we can safely get iter i - 1's changed ids.
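+ # Timeline sketch (illustrative, prefetch_pipeline=True):
+ #   ... prefetch(i) -> backward(i - 1) -> forward(i) -> prefetch(i + 1) -> backward(i) ...
+ # hence only iter i - 2's ids are guaranteed to be fully written back here.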
1899
+ target_prev_iter = 1
1900
+ if self.prefetch_pipeline:
1901
+ target_prev_iter = 2
1902
+ if len(self.prefetched_info) > (target_prev_iter - 1):
1903
+ with record_function(
1904
+ "## ssd_lookup_prefetched_rows {} {} ##".format(
1905
+ self.timestep, self.tbe_unique_id
1906
+ )
1907
+ ):
1908
+ # wait for the copy to finish before overwriting the buffer
1909
+ self.raw_embedding_stream_sync(
1910
+ stream=self.ssd_eviction_stream,
1911
+ pre_event=self.ssd_event_cache_streamed,
1912
+ post_event=self.ssd_event_cache_streaming_synced,
1913
+ name="cache_update",
1914
+ )
1915
+ current_stream.wait_event(self.ssd_event_cache_streaming_synced)
1916
+ (updated_indices, updated_counts_gpu) = (
1917
+ self.prefetched_info.pop(0)
1918
+ )
1919
+ self.lxu_cache_updated_indices[: updated_indices.size(0)].copy_(
1920
+ updated_indices,
1921
+ non_blocking=True,
1922
+ )
1923
+ self.lxu_cache_updated_count[:1].copy_(
1924
+ updated_counts_gpu, non_blocking=True
1925
+ )
1926
+ has_raw_embedding_streaming = True
1927
+
1928
+ with record_function(
1929
+ "## ssd_save_prefetched_rows {} {} ##".format(
1930
+ self.timestep, self.tbe_unique_id
1931
+ )
1932
+ ):
1933
+ masked_updated_indices = torch.where(
1934
+ torch.where(lxu_cache_locations != -1, True, False),
1935
+ linear_cache_indices,
1936
+ -1,
1937
+ )
1938
+
1939
+ (
1940
+ uni_updated_indices,
1941
+ uni_updated_indices_length,
1942
+ ) = get_unique_indices_v2(
1943
+ masked_updated_indices,
1944
+ self.total_hash_size,
1945
+ compute_count=False,
1946
+ compute_inverse_indices=False,
1947
+ )
1948
+ assert uni_updated_indices is not None
1949
+ assert uni_updated_indices_length is not None
1950
+ # The unique indices have one more -1 element than needed,
1951
+ # which might make the tensor length go out of range
1952
+ # compared to the pre-allocated buffer.
1953
+ unique_len = min(
1954
+ self.lxu_cache_weights.size(0),
1955
+ uni_updated_indices.size(0),
1956
+ )
1957
+ self.prefetched_info.append(
1958
+ (
1959
+ uni_updated_indices.narrow(0, 0, unique_len),
1960
+ uni_updated_indices_length.clamp(max=unique_len),
1961
+ )
1962
+ )
1963
+
1964
+ with record_function("## ssd_d2h_inserted_indices ##"):
1965
+ # Transfer actions_count and insert_indices right away to
1966
+ # increase the overlap opportunity
1967
+ actions_count_cpu, inserted_indices_cpu = (
1968
+ self.to_pinned_cpu_on_stream_wait_on_another_stream(
1969
+ tensors=[
1970
+ actions_count_gpu,
1971
+ inserted_indices,
1972
+ ],
1973
+ stream=self.ssd_memcpy_stream,
1974
+ stream_to_wait_on=current_stream,
1975
+ post_event=self.ssd_event_get_inputs_cpy,
1976
+ )
1977
+ )
1978
+
1979
+ # Copy rows to be evicted into a separate buffer (will be evicted
1980
+ # later in the prefetch step)
1981
+ with record_function("## ssd_compute_evicted_rows ##"):
1982
+ torch.ops.fbgemm.masked_index_select(
1983
+ self.lxu_cache_evicted_weights,
1984
+ self.lxu_cache_evicted_slots,
1985
+ self.lxu_cache_weights,
1986
+ self.lxu_cache_evicted_count,
1987
+ )
1988
+
1989
+ # Allocate a scratch pad for the current iteration. The scratch
1990
+ # pad is a UVA tensor
1991
+ inserted_rows_shape = (assigned_cache_slots.numel(), self.cache_row_dim)
1992
+ if linear_cache_indices.numel() > 0:
1993
+ inserted_rows = torch.ops.fbgemm.new_unified_tensor(
1994
+ torch.zeros(
1995
+ 1,
1996
+ device=self.current_device,
1997
+ dtype=self.lxu_cache_weights.dtype,
1998
+ ),
1999
+ inserted_rows_shape,
2000
+ is_host_mapped=self.uvm_host_mapped,
2001
+ )
2002
+ else:
2003
+ inserted_rows = torch.empty(
2004
+ inserted_rows_shape,
2005
+ dtype=self.lxu_cache_weights.dtype,
2006
+ device=self.current_device,
2007
+ )
2008
+
2009
+ if self.prefetch_pipeline and len(self.ssd_scratch_pads) > 0:
2010
+ # Look up all missed indices from the previous iteration's
2011
+ # scratch pad (do this only if pipeline prefetching is being
2012
+ # used)
2013
+ with record_function("## ssd_lookup_scratch_pad ##"):
2014
+ # Get the previous scratch pad
2015
+ (
2016
+ inserted_rows_prev,
2017
+ post_bwd_evicted_indices_cpu_prev,
2018
+ actions_count_cpu_prev,
2019
+ ) = self.ssd_scratch_pads.pop(0)
2020
+
2021
+ # Inserted indices that are found in the scratch pad
2022
+ # from the previous iteration
2023
+ sp_prev_curr_map_cpu = torch.empty(
2024
+ inserted_indices_cpu.shape,
2025
+ dtype=inserted_indices_cpu.dtype,
2026
+ pin_memory=True,
2027
+ )
2028
+
2029
+ # Conflict missed indices from the previous iteration that
2030
+ # overlap with the current iteration's inserted indices
2031
+ sp_curr_prev_map_cpu = torch.empty(
2032
+ post_bwd_evicted_indices_cpu_prev.shape,
2033
+ dtype=torch.int,
2034
+ pin_memory=True,
2035
+ ).fill_(-1)
2036
+
2037
+ # Ensure that the necessary D2H transfers are done
2038
+ current_stream.wait_event(self.ssd_event_get_inputs_cpy)
2039
+ # Ensure that the previous iteration's scratch pad indices
2040
+ # insertion is complete
2041
+ current_stream.wait_event(self.ssd_event_sp_idxq_insert)
2042
+
2043
+ # Before entering this function: inserted_indices_cpu
2044
+ # contains all linear indices that are missed from the
2045
+ # L1 cache
2046
+ #
2047
+ # After this function: inserted indices that are found
2048
+ # in the scratch pad from the previous iteration are
2049
+ # stored in sp_prev_curr_map_cpu, while the rest are
2050
+ # stored in inserted_indices_cpu
2051
+ #
2052
+ # An invalid index is -1 or its position >
2053
+ # actions_count_cpu
2054
+ self.record_function_via_dummy_profile(
2055
+ "## ssd_lookup_mask_and_pop_front ##",
2056
+ self.scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda,
2057
+ sp_prev_curr_map_cpu, # scratch_pad_prev_curr_map
2058
+ sp_curr_prev_map_cpu, # scratch_pad_curr_prev_map
2059
+ post_bwd_evicted_indices_cpu_prev, # scratch_pad_indices_prev
2060
+ inserted_indices_cpu, # inserted_indices_curr
2061
+ actions_count_cpu, # count_curr
2062
+ )
2063
+
2064
+ # Mark scratch pad index queue lookup completion
2065
+ current_stream.record_event(self.ssd_event_sp_idxq_lookup)
2066
+
2067
+ # Transfer sp_prev_curr_map_cpu to GPU
2068
+ sp_prev_curr_map_gpu = sp_prev_curr_map_cpu.cuda(non_blocking=True)
2069
+ # Transfer sp_curr_prev_map_cpu to GPU
2070
+ sp_curr_prev_map_gpu = sp_curr_prev_map_cpu.cuda(non_blocking=True)
2071
+
2072
+ # Previously actions_count_gpu was recorded on another
2073
+ # stream. Thus, we need to record it on this stream
2074
+ actions_count_gpu.record_stream(current_stream)
2075
+
2076
+ # Copy data from the previous iteration's scratch pad to
2077
+ # the current iteration's scratch pad
2078
+ torch.ops.fbgemm.masked_index_select(
2079
+ inserted_rows,
2080
+ sp_prev_curr_map_gpu,
2081
+ inserted_rows_prev,
2082
+ actions_count_gpu,
2083
+ use_pipeline=self.prefetch_pipeline,
2084
+ )
2085
+
2086
+ # Record the tensors that will be pushed into a queue
2087
+ # on the forward stream
2088
+ if forward_stream:
2089
+ sp_curr_prev_map_gpu.record_stream(forward_stream)
2090
+
2091
+ # Store info for evicting the previous iteration's
2092
+ # scratch pad after the corresponding backward pass is
2093
+ # done
2094
+ if self.training:
2095
+ self.ssd_location_update_data.append(
2096
+ (
2097
+ sp_curr_prev_map_gpu,
2098
+ inserted_rows,
2099
+ )
2100
+ )
2101
+
2102
+ # Ensure that the previous iteration's eviction is complete
2103
+ current_stream.wait_event(self.ssd_event_sp_evict)
2104
+ # Ensure that D2H is done
2105
+ current_stream.wait_event(self.ssd_event_get_inputs_cpy)
2106
+
2107
+ if self.enable_raw_embedding_streaming and has_raw_embedding_streaming:
2108
+ current_stream.wait_event(self.ssd_event_sp_streamed)
2109
+ with record_function(
2110
+ "## ssd_compute_updated_rows {} {} ##".format(
2111
+ self.timestep, self.tbe_unique_id
2112
+ )
2113
+ ):
2114
+ # cache rows that are changed in the previous iteration
2115
+ updated_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
2116
+ self.lxu_cache_updated_indices,
2117
+ self.lxu_cache_state,
2118
+ self.total_hash_size,
2119
+ self.gather_ssd_cache_stats,
2120
+ self.local_ssd_cache_stats,
2121
+ )
2122
+ torch.ops.fbgemm.masked_index_select(
2123
+ self.lxu_cache_updated_weights,
2124
+ updated_cache_locations,
2125
+ self.lxu_cache_weights,
2126
+ self.lxu_cache_updated_count,
2127
+ )
2128
+ current_stream.record_event(self.ssd_event_cache_streaming_computed)
2129
+
2130
+ self.raw_embedding_stream(
2131
+ rows=self.lxu_cache_updated_weights,
2132
+ indices_cpu=self.lxu_cache_updated_indices,
2133
+ actions_count_cpu=self.lxu_cache_updated_count,
2134
+ stream=self.ssd_eviction_stream,
2135
+ pre_event=self.ssd_event_cache_streaming_computed,
2136
+ post_event=self.ssd_event_cache_streamed,
2137
+ is_rows_uvm=True,
2138
+ blocking_tensor_copy=False,
2139
+ name="cache_update",
2140
+ )
2141
+
2142
+ if self.gather_ssd_cache_stats:
2143
+ # Collect the past SSD IO duration right before the next RocksDB IO
2144
+
2145
+ self.ssd_cache_stats = torch.add(
2146
+ self.ssd_cache_stats, self.local_ssd_cache_stats
2147
+ )
2148
+ # only report metrics from rank0 to avoid flooded logging
2149
+ if dist.get_rank() == 0:
2150
+ self._report_kv_backend_stats()
2151
+
2152
+ # May trigger eviction (if the free-memory trigger mode is enabled) before the get_cuda call
2153
+ self.may_trigger_eviction()
2154
+
2155
+ # Fetch data from SSD
2156
+ if linear_cache_indices.numel() > 0:
2157
+ self.record_function_via_dummy_profile(
2158
+ "## ssd_get ##",
2159
+ self.ssd_db.get_cuda,
2160
+ inserted_indices_cpu,
2161
+ inserted_rows,
2162
+ actions_count_cpu,
2163
+ )
2164
+
2165
+ # Record an event to mark the completion of `get_cuda`
2166
+ current_stream.record_event(self.ssd_event_get)
2167
+
2168
+ # Copy rows from the current iteration's scratch pad to L1
2169
+ torch.ops.fbgemm.masked_index_put(
2170
+ self.lxu_cache_weights,
2171
+ assigned_cache_slots,
2172
+ inserted_rows,
2173
+ actions_count_gpu,
2174
+ use_pipeline=self.prefetch_pipeline,
2175
+ )
2176
+
2177
+ if self.training:
2178
+ if linear_cache_indices.numel() > 0:
2179
+ # Evict rows from cache to SSD
2180
+ self.evict(
2181
+ rows=self.lxu_cache_evicted_weights,
2182
+ indices_cpu=self.lxu_cache_evicted_indices,
2183
+ actions_count_cpu=self.lxu_cache_evicted_count,
2184
+ stream=self.ssd_eviction_stream,
2185
+ pre_event=self.ssd_event_get,
2186
+ # Record completion event after scratch pad eviction
2187
+ # instead since that happens after L1 eviction
2188
+ post_event=self.ssd_event_cache_evict,
2189
+ is_rows_uvm=True,
2190
+ name="cache",
2191
+ is_bwd=False,
2192
+ )
2193
+ if (
2194
+ self.backend_type == BackendType.DRAM
2195
+ and weights is not None
2196
+ and linear_cache_indices.numel() > 0
2197
+ ):
2198
+ # Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
2199
+ self._update_feature_score_metadata(
2200
+ linear_cache_indices=linear_cache_indices,
2201
+ weights=weights,
2202
+ d2h_stream=self.ssd_memcpy_stream,
2203
+ write_stream=self.feature_score_stream,
2204
+ pre_event_for_write=self.ssd_event_cache_evict,
2205
+ )
2206
+
2207
+ # Generate row addresses (pointing to either L1 or the current
2208
+ # iteration's scratch pad)
2209
+ with record_function("## ssd_generate_row_addrs ##"):
2210
+ lxu_cache_ptrs, post_bwd_evicted_indices = (
2211
+ torch.ops.fbgemm.ssd_generate_row_addrs(
2212
+ lxu_cache_locations,
2213
+ assigned_cache_slots,
2214
+ linear_index_inverse_indices,
2215
+ unique_indices_count_cumsum,
2216
+ cache_set_inverse_indices,
2217
+ self.lxu_cache_weights,
2218
+ inserted_rows,
2219
+ unique_indices_length,
2220
+ inserted_indices,
2221
+ )
2222
+ )
2223
+
2224
+ with record_function("## ssd_d2h_post_bwd_evicted_indices ##"):
2225
+ # Transfer post_bwd_evicted_indices from GPU to CPU right away to
2226
+ # increase the chance of overlapping with compute in the default stream
2227
+ (post_bwd_evicted_indices_cpu,) = (
2228
+ self.to_pinned_cpu_on_stream_wait_on_another_stream(
2229
+ tensors=[post_bwd_evicted_indices],
2230
+ stream=self.ssd_eviction_stream,
2231
+ stream_to_wait_on=current_stream,
2232
+ post_event=None,
2233
+ )
2234
+ )
2235
+
2236
+ if self.prefetch_pipeline:
2237
+ # Insert the current iteration's conflict miss indices in the index
2238
+ # queue for future lookup.
2239
+ #
2240
+ # post_bwd_evicted_indices_cpu is transferred on the
2241
+ # ssd_eviction_stream stream so it does not need stream
2242
+ # synchronization
2243
+ #
2244
+ # actions_count_cpu is transferred on the ssd_memcpy_stream stream.
2245
+ # Thus, we have to explicitly sync the stream
2246
+ with torch.cuda.stream(self.ssd_eviction_stream):
2247
+ # Ensure that actions_count_cpu transfer is done
2248
+ self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy)
2249
+ # Ensure that the scratch pad index queue look up is complete
2250
+ self.ssd_eviction_stream.wait_event(self.ssd_event_sp_idxq_lookup)
2251
+ self.record_function_via_dummy_profile(
2252
+ "## ssd_scratch_pad_idx_queue_insert ##",
2253
+ self.scratch_pad_idx_queue.insert_cuda,
2254
+ post_bwd_evicted_indices_cpu,
2255
+ actions_count_cpu,
2256
+ )
2257
+ # Mark the completion of scratch pad index insertion
2258
+ self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert)
2259
+
2260
+ prefetch_data = (
2261
+ lxu_cache_ptrs,
2262
+ inserted_rows,
2263
+ post_bwd_evicted_indices_cpu,
2264
+ actions_count_cpu,
2265
+ actions_count_gpu,
2266
+ lxu_cache_locations,
2267
+ cache_set_inverse_indices,
2268
+ )
2269
+
2270
+ # Record tensors on the forward stream
2271
+ if forward_stream is not None:
2272
+ for t in prefetch_data:
2273
+ if t.is_cuda:
2274
+ t.record_stream(forward_stream)
2275
+
2276
+ if self.prefetch_pipeline:
2277
+ # Store scratch pad info for the lookup in the next iteration
2278
+ # prefetch
2279
+ self.ssd_scratch_pads.append(
2280
+ (
2281
+ inserted_rows,
2282
+ post_bwd_evicted_indices_cpu,
2283
+ actions_count_cpu,
2284
+ )
2285
+ )
2286
+
2287
+ # Store scratch pad info for post backward eviction only for training
2288
+ # for eval job, no backward pass, so no need to store this info
2289
+ if self.training:
2290
+ self.ssd_scratch_pad_eviction_data.append(
2291
+ (
2292
+ inserted_rows,
2293
+ post_bwd_evicted_indices_cpu,
2294
+ actions_count_cpu,
2295
+ linear_cache_indices.numel() > 0,
2296
+ )
2297
+ )
2298
+
2299
+ # Store data for forward
2300
+ self.ssd_prefetch_data.append(prefetch_data)
2301
+
2302
+ # Record an event to mark the completion of prefetch operations
2303
+ # This will be used by direct_write_embedding to ensure it doesn't run concurrently with prefetch
2304
+ current_stream.record_event(self.prefetch_complete_event)
2305
+
2306
+ @torch.jit.ignore
2307
+ def _generate_vbe_metadata(
2308
+ self,
2309
+ offsets: Tensor,
2310
+ batch_size_per_feature_per_rank: Optional[list[list[int]]],
2311
+ ) -> invokers.lookup_args.VBEMetadata:
2312
+ # Blocking D2H copy, but only runs at first call
2313
+ self.feature_dims = self.feature_dims.cpu()
2314
+ if batch_size_per_feature_per_rank is not None:
2315
+ assert self.optimizer in (
2316
+ OptimType.EXACT_ROWWISE_ADAGRAD,
2317
+ OptimType.EXACT_SGD,
2318
+ ), (
2319
+ "Variable batch size TBE support is enabled for "
2320
+ "OptimType.EXACT_ROWWISE_ADAGRAD and "
2321
+ "ENSEMBLE_ROWWISE_ADAGRAD only"
2322
+ )
2323
+ return generate_vbe_metadata(
2324
+ offsets,
2325
+ batch_size_per_feature_per_rank,
2326
+ self.pooling_mode,
2327
+ self.feature_dims,
2328
+ self.current_device,
2329
+ )
2330
+
2331
+ def _increment_iteration(self) -> int:
2332
+ # Although self.iter_cpu is created on the CPU, it might be transferred to
2333
+ # GPU by the user. So, we need to transfer it to CPU explicitly. This
2334
+ # should be done only once.
2335
+ self.iter_cpu = self.iter_cpu.cpu()
2336
+
2337
+ # Sync with loaded state
2338
+ # Wrap to make it compatible with PT2 compile
2339
+ if not is_torchdynamo_compiling():
2340
+ if self.iter_cpu.item() == 0:
2341
+ self.iter_cpu.fill_(self.iter.cpu().item())
2342
+
2343
+ # Increment the iteration counter
2344
+ # The CPU counterpart is used for local computation
2345
+ iter_int = int(self.iter_cpu.add_(1).item())
2346
+ # The GPU counterpart is used for checkpointing
2347
+ self.iter.add_(1)
2348
+
2349
+ return iter_int
2350
+
2351
+ def forward(
2352
+ self,
2353
+ indices: Tensor,
2354
+ offsets: Tensor,
2355
+ weights: Optional[Tensor] = None,
2356
+ per_sample_weights: Optional[Tensor] = None,
2357
+ feature_requires_grad: Optional[Tensor] = None,
2358
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
2359
+ # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
2360
+ ) -> Tensor:
2361
+ self.clear_cache()
2362
+ indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
2363
+ indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
2364
+ )
2365
+
2366
+ if len(self.timesteps_prefetched) == 0:
2367
+
2368
+ with self._recording_to_timer(
2369
+ self.ssd_prefetch_read_timer,
2370
+ context=self.step,
2371
+ stream=torch.cuda.current_stream(),
2372
+ ), self._recording_to_timer(
2373
+ self.ssd_prefetch_evict_timer,
2374
+ context=self.step,
2375
+ stream=self.ssd_eviction_stream,
2376
+ ):
2377
+ self._prefetch(indices, offsets, weights, vbe_metadata)
2378
+
2379
+ assert len(self.ssd_prefetch_data) > 0
2380
+
2381
+ (
2382
+ lxu_cache_ptrs,
2383
+ inserted_rows,
2384
+ post_bwd_evicted_indices_cpu,
2385
+ actions_count_cpu,
2386
+ actions_count_gpu,
2387
+ lxu_cache_locations,
2388
+ cache_set_inverse_indices,
2389
+ ) = self.ssd_prefetch_data.pop(0)
2390
+
2391
+ # Storing current iteration data for future use
2392
+ self.current_iter_data = IterData(
2393
+ indices,
2394
+ offsets,
2395
+ lxu_cache_locations,
2396
+ lxu_cache_ptrs,
2397
+ actions_count_gpu,
2398
+ cache_set_inverse_indices,
2399
+ vbe_metadata.B_offsets,
2400
+ vbe_metadata.max_B,
2401
+ )
2402
+
2403
+ common_args = invokers.lookup_args_ssd.CommonArgs(
2404
+ placeholder_autograd_tensor=self.placeholder_autograd_tensor,
2405
+ output_dtype=self.output_dtype,
2406
+ dev_weights=self.weights_dev,
2407
+ host_weights=self.weights_host,
2408
+ uvm_weights=self.weights_uvm,
2409
+ lxu_cache_weights=self.lxu_cache_weights,
2410
+ weights_placements=self.weights_placements,
2411
+ weights_offsets=self.weights_offsets,
2412
+ D_offsets=self.D_offsets,
2413
+ total_D=self.total_D,
2414
+ max_D=self.max_D,
2415
+ hash_size_cumsum=self.hash_size_cumsum,
2416
+ total_hash_size_bits=self.total_hash_size_bits,
2417
+ indices=indices,
2418
+ offsets=offsets,
2419
+ pooling_mode=self.pooling_mode,
2420
+ indice_weights=per_sample_weights,
2421
+ feature_requires_grad=feature_requires_grad,
2422
+ lxu_cache_locations=lxu_cache_ptrs,
2423
+ uvm_cache_stats=None,
2424
+ # Unused arguments
2425
+ is_experimental=False,
2426
+ use_uniq_cache_locations_bwd=False,
2427
+ use_homogeneous_placements=True,
2428
+ # The keys for ssd_tensors are controlled by ssd_tensors in
2429
+ # codegen/genscript/optimizer_args.py
2430
+ ssd_tensors={
2431
+ "row_addrs": lxu_cache_ptrs,
2432
+ "inserted_rows": inserted_rows,
2433
+ "post_bwd_evicted_indices": post_bwd_evicted_indices_cpu,
2434
+ "actions_count": actions_count_cpu,
2435
+ },
2436
+ enable_optimizer_offloading=self.enable_optimizer_offloading,
2437
+ # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
2438
+ vbe_metadata=vbe_metadata,
2439
+ learning_rate_tensor=self.learning_rate_tensor,
2440
+ info_B_num_bits=self.info_B_num_bits,
2441
+ info_B_mask=self.info_B_mask,
2442
+ )
2443
+
2444
+ self.timesteps_prefetched.pop(0)
2445
+ self.step += 1
2446
+
2447
+ # Increment the iteration (value is used for certain optimizers)
2448
+ iter_int = self._increment_iteration()
2449
+
2450
+ if self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
2451
+ momentum2 = invokers.lookup_args_ssd.Momentum(
2452
+ # pyre-ignore[6]
2453
+ dev=self.momentum2_dev,
2454
+ # pyre-ignore[6]
2455
+ host=self.momentum2_host,
2456
+ # pyre-ignore[6]
2457
+ uvm=self.momentum2_uvm,
2458
+ # pyre-ignore[6]
2459
+ offsets=self.momentum2_offsets,
2460
+ # pyre-ignore[6]
2461
+ placements=self.momentum2_placements,
2462
+ )
2463
+
2464
+ momentum1 = invokers.lookup_args_ssd.Momentum(
2465
+ dev=self.momentum1_dev,
2466
+ host=self.momentum1_host,
2467
+ uvm=self.momentum1_uvm,
2468
+ offsets=self.momentum1_offsets,
2469
+ placements=self.momentum1_placements,
2470
+ )
2471
+
2472
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
2473
+ return invokers.lookup_rowwise_adagrad_ssd.invoke(
2474
+ common_args, self.optimizer_args, momentum1
2475
+ )
2476
+
2477
+ elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
2478
+ return invokers.lookup_partial_rowwise_adam_ssd.invoke(
2479
+ common_args,
2480
+ self.optimizer_args,
2481
+ momentum1,
2482
+ # pyre-ignore[61]
2483
+ momentum2,
2484
+ iter_int,
2485
+ )
2486
+
2487
+ elif self.optimizer == OptimType.ADAM:
2488
+ row_counter = invokers.lookup_args_ssd.Momentum(
2489
+ # pyre-fixme[6]
2490
+ dev=self.row_counter_dev,
2491
+ # pyre-fixme[6]
2492
+ host=self.row_counter_host,
2493
+ # pyre-fixme[6]
2494
+ uvm=self.row_counter_uvm,
2495
+ # pyre-fixme[6]
2496
+ offsets=self.row_counter_offsets,
2497
+ # pyre-fixme[6]
2498
+ placements=self.row_counter_placements,
2499
+ )
2500
+
2501
+ return invokers.lookup_adam_ssd.invoke(
2502
+ common_args,
2503
+ self.optimizer_args,
2504
+ momentum1,
2505
+ # pyre-ignore[61]
2506
+ momentum2,
2507
+ iter_int,
2508
+ row_counter=row_counter,
2509
+ )
2510
+
2511
+ @torch.jit.ignore
2512
+ def _split_optimizer_states_non_kv_zch(
2513
+ self,
2514
+ ) -> list[list[torch.Tensor]]:
2515
+ """
2516
+ Returns a list of optimizer states (view), split by table.
2517
+
2518
+ Returns:
2519
+ A list of lists of states. Shape = (the number of tables, the number
2520
+ of states).
2521
+
2522
+ The following shows the list of states (in the returned order) for
2523
+ each optimizer:
2524
+
2525
+ (1) `EXACT_ROWWISE_ADAGRAD`: `momentum1` (rowwise)
2526
+
2527
+ (2) `PARTIAL_ROWWISE_ADAM`: `momentum1`, `momentum2` (rowwise)
2528
+ """
2529
+
2530
+ # Row count per table
2531
+ (rows, dims) = zip(*self.embedding_specs)
2532
+ # Cumulative row counts per table for rowwise states
2533
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
2534
+ # Cumulative element counts per table for elementwise states
2535
+ elem_count_cumsum: list[int] = [0] + list(
2536
+ itertools.accumulate([r * d for r, d in self.embedding_specs])
2537
+ )
2538
+
2539
+ # pyre-ignore[53]
2540
+ def _slice(tensor: Tensor, t: int, rowwise: bool) -> Tensor:
2541
+ d: int = dims[t]
2542
+ e: int = rows[t]
2543
+
2544
+ if not rowwise:
2545
+ # Optimizer state is element-wise - compute the table offset for
2546
+ # the table, view the slice as 2D tensor
2547
+ return tensor.detach()[
2548
+ elem_count_cumsum[t] : elem_count_cumsum[t + 1]
2549
+ ].view(-1, d)
2550
+ else:
2551
+ # Optimizer state is row-wise - fetch elements in range and view
2552
+ # slice as 1D
2553
+ return tensor.detach()[
2554
+ row_count_cumsum[t] : row_count_cumsum[t + 1]
2555
+ ].view(e)
2556
+
2557
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
2558
+ return [
2559
+ [_slice(self.momentum1_dev, t, rowwise=True)]
2560
+ for t, _ in enumerate(rows)
2561
+ ]
2562
+ elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
2563
+ return [
2564
+ [
2565
+ _slice(self.momentum1_dev, t, rowwise=False),
2566
+ # pyre-ignore[6]
2567
+ _slice(self.momentum2_dev, t, rowwise=True),
2568
+ ]
2569
+ for t, _ in enumerate(rows)
2570
+ ]
2571
+
2572
+ elif self.optimizer == OptimType.ADAM:
2573
+ return [
2574
+ [
2575
+ _slice(self.momentum1_dev, t, rowwise=False),
2576
+ # pyre-ignore[6]
2577
+ _slice(self.momentum2_dev, t, rowwise=False),
2578
+ ]
2579
+ for t, _ in enumerate(rows)
2580
+ ]
2581
+
2582
+ else:
2583
+ raise NotImplementedError(
2584
+ f"Getting optimizer states is not supported for {self.optimizer}"
2585
+ )
2586
+
2587
+ @torch.jit.ignore
2588
+ def _split_optimizer_states_kv_zch_no_offloading(
2589
+ self,
2590
+ sorted_ids: torch.Tensor,
2591
+ ) -> list[list[torch.Tensor]]:
2592
+
2593
+ # Row count per table
2594
+ (rows, dims) = zip(*self.embedding_specs)
2595
+ # Cumulative row counts per table for rowwise states
2596
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
2597
+ # Cumulative element counts per table for elementwise states
2598
+ elem_count_cumsum: list[int] = [0] + list(
2599
+ itertools.accumulate([r * d for r, d in self.embedding_specs])
2600
+ )
2601
+
2602
+ # pyre-ignore[53]
2603
+ def _slice(state_name: str, tensor: Tensor, t: int, rowwise: bool) -> Tensor:
2604
+ d: int = dims[t]
2605
+
2606
+ # pyre-ignore[16]
2607
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
2608
+ # pyre-ignore[16]
2609
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
2610
+
2611
+ if sorted_ids is None or sorted_ids[t].numel() == 0:
2612
+ # Empty optimizer state for module initialization
2613
+ return torch.empty(
2614
+ 0,
2615
+ dtype=(
2616
+ self.optimizer_state_dtypes.get(
2617
+ state_name, SparseType.FP32
2618
+ ).as_dtype()
2619
+ ),
2620
+ device="cpu",
2621
+ )
2622
+
2623
+ elif not rowwise:
2624
+ # Optimizer state is element-wise - materialize the local ids
2625
+ # based on the sorted_ids, compute the table offset for the
2626
+ # table, view the slice as 2D tensor of e x d, then fetch the
2627
+ # sub-slice by local ids
2628
+ #
2629
+ # local_ids is [N, 1], flatten it to N to keep the returned tensor 2D
2630
+ local_ids = (sorted_ids[t] - bucket_id_start * bucket_size).view(-1)
2631
+ return (
2632
+ tensor.detach()
2633
+ .cpu()[elem_count_cumsum[t] : elem_count_cumsum[t + 1]]
2634
+ .view(-1, d)[local_ids]
2635
+ )
2636
+
2637
+ else:
2638
+ # Optimizer state is row-wise - materialize the local ids based
2639
+ # on the sorted_ids and table offset (i.e. row count cumsum),
2640
+ # then fetch by local ids
2641
+ linearized_local_ids = (
2642
+ sorted_ids[t] - bucket_id_start * bucket_size + row_count_cumsum[t]
2643
+ )
2644
+ return tensor.detach().cpu()[linearized_local_ids].view(-1)
2645
+
2646
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
2647
+ return [
2648
+ [_slice("momentum1", self.momentum1_dev, t, rowwise=True)]
2649
+ for t, _ in enumerate(rows)
2650
+ ]
2651
+
2652
+ elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
2653
+ return [
2654
+ [
2655
+ _slice("momentum1", self.momentum1_dev, t, rowwise=False),
2656
+ # pyre-ignore[6]
2657
+ _slice("momentum2", self.momentum2_dev, t, rowwise=True),
2658
+ ]
2659
+ for t, _ in enumerate(rows)
2660
+ ]
2661
+
2662
+ elif self.optimizer == OptimType.ADAM:
2663
+ return [
2664
+ [
2665
+ _slice("momentum1", self.momentum1_dev, t, rowwise=False),
2666
+ # pyre-ignore[6]
2667
+ _slice("momentum2", self.momentum2_dev, t, rowwise=False),
2668
+ ]
2669
+ for t, _ in enumerate(rows)
2670
+ ]
2671
+
2672
+ else:
2673
+ raise NotImplementedError(
2674
+ f"Getting optimizer states is not supported for {self.optimizer}"
2675
+ )
2676
+
2677
+ @torch.jit.ignore
2678
+ def _split_optimizer_states_kv_zch_w_offloading(
2679
+ self,
2680
+ sorted_ids: torch.Tensor,
2681
+ no_snapshot: bool = True,
2682
+ should_flush: bool = False,
2683
+ ) -> list[list[torch.Tensor]]:
2684
+ dtype = self.weights_precision.as_dtype()
2685
+ # Row count per table
2686
+ (rows_, dims_) = zip(*self.embedding_specs)
2687
+ # Cumulative row counts per table for rowwise states
2688
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))
2689
+
2690
+ snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
2691
+ no_snapshot=no_snapshot,
2692
+ should_flush=should_flush,
2693
+ )
2694
+
2695
+ # pyre-ignore[53]
2696
+ def _fetch_offloaded_optimizer_states(
2697
+ t: int,
2698
+ ) -> list[Tensor]:
2699
+ e: int = rows_[t]
2700
+ d: int = dims_[t]
2701
+
2702
+ # pyre-ignore[16]
2703
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
2704
+ # pyre-ignore[16]
2705
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
2706
+
2707
+ row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)
2708
+ # Count of rows to fetch
2709
+ rows_to_fetch = sorted_ids[t].numel()
2710
+
2711
+ # Lookup the byte offsets for each optimizer state
2712
+ optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
2713
+ d, self.weights_precision, self.optimizer_state_dtypes
2714
+ )
2715
+ # Find the minimum start of all the start/end pairs - we have to
2716
+ # offset the start/end pairs by this value to get the correct start/end
2717
+ offset_ = min(
2718
+ [start for _, (start, _) in optimizer_state_byte_offsets.items()]
2719
+ )
2720
+ # Update the start/end pairs to be relative to offset_
2721
+ optimizer_state_byte_offsets = dict(
2722
+ (k, (v1 - offset_, v2 - offset_))
2723
+ for k, (v1, v2) in optimizer_state_byte_offsets.items()
2724
+ )
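+ # Illustrative example only (the real layout comes from
+ # optimizer.byte_offsets_along_row): for PARTIAL_ROWWISE_ADAM with d = 128
+ # and FP32 states, momentum1 might span bytes [0, 512) and the rowwise
+ # momentum2 bytes [512, 516) after this rebasing.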
2725
+
2726
+ # Since the backend returns cache rows that pack the weights and
2727
+ # optimizer states together, reading the whole tensor could cause OOM,
2728
+ # so we use the KVTensorWrapper abstraction to query the backend and
2729
+ # fetch the data in chunks instead.
2730
+ tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
2731
+ shape=[
2732
+ e,
2733
+ # Dim is in terms of the **weights** dtype
2734
+ self.optimizer_state_dim,
2735
+ ],
2736
+ dtype=dtype,
2737
+ row_offset=row_offset,
2738
+ snapshot_handle=snapshot_handle,
2739
+ sorted_indices=sorted_ids[t],
2740
+ width_offset=pad4(d),
2741
+ )
2742
+ (
2743
+ tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2744
+ if self.backend_type == BackendType.SSD
2745
+ else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
2746
+ )
2747
+
2748
+ # Fetch the state size table for the given weights dimension
2749
+ state_size_table = self.optimizer.state_size_table(d)
2750
+
2751
+ # Create a 2D output buffer of [rows x optimizer state dim] with the
2752
+ # weights type as the type. For optimizers with multiple states (e.g.
2753
+ # momentum1 and momentum2), this tensor will include data from all
2754
+ # states, hence self.optimizer_state_dim as the row size.
2755
+ optimizer_states_buffer = torch.empty(
2756
+ (rows_to_fetch, self.optimizer_state_dim), dtype=dtype, device="cpu"
2757
+ )
2758
+
2759
+ # Set the chunk size for fetching
2760
+ chunk_size = (
2761
+ # 10M rows => 260 (max_D) * 2 (element bytes) * 10M => 5.2 GB memory spike
2762
+ 10_000_000
2763
+ )
2764
+ logging.info(f"split optimizer chunk rows: {chunk_size}")
2765
+
2766
+ # Chunk the fetching by chunk_size
2767
+ for i in range(0, rows_to_fetch, chunk_size):
2768
+ length = min(chunk_size, rows_to_fetch - i)
2769
+
2770
+ # Fetch from backend and copy to the output buffer
2771
+ optimizer_states_buffer.narrow(0, i, length).copy_(
2772
+ tensor_wrapper.narrow(0, i, length).view(dtype)
2773
+ )
2774
+
2775
+ # Now split up the buffer into N views, N for each optimizer state
2776
+ optimizer_states: list[Tensor] = []
2777
+ for state_name in self.optimizer.state_names():
2778
+ # Extract the offsets
2779
+ (start, end) = optimizer_state_byte_offsets[state_name]
2780
+
2781
+ state = optimizer_states_buffer.view(
2782
+ # Force tensor to byte view
2783
+ dtype=torch.uint8
2784
+ # Copy by byte offsets
2785
+ )[:, start:end].view(
2786
+ # Re-view in the state's target dtype
2787
+ self.optimizer_state_dtypes.get(
2788
+ state_name, SparseType.FP32
2789
+ ).as_dtype()
2790
+ )
2791
+
2792
+ optimizer_states.append(
2793
+ # If the state is rowwise (i.e. just one element per row),
2794
+ # then re-view as 1D tensor
2795
+ state
2796
+ if state_size_table[state_name] > 1
2797
+ else state.view(-1)
2798
+ )
2799
+
2800
+ # Return the views
2801
+ return optimizer_states
2802
+
2803
+ return [
2804
+ (
2805
+ self.optimizer.empty_states([0], [d], self.optimizer_state_dtypes)[0]
2806
+ # Return a set of empty states if sorted_ids[t] is empty
2807
+ if sorted_ids is None or sorted_ids[t].numel() == 0
2808
+ # Else fetch the list of optimizer states for the table
2809
+ else _fetch_offloaded_optimizer_states(t)
2810
+ )
2811
+ for t, d in enumerate(dims_)
2812
+ ]
2813
+
2814
+ @torch.jit.ignore
2815
+ def _split_optimizer_states_kv_zch_whole_row(
2816
+ self,
2817
+ sorted_ids: torch.Tensor,
2818
+ no_snapshot: bool = True,
2819
+ should_flush: bool = False,
2820
+ ) -> list[list[torch.Tensor]]:
2821
+ dtype = self.weights_precision.as_dtype()
2822
+
2823
+ # Row and dimension counts per table
2824
+ # rows_ is only used here to compute the virtual table offsets
2825
+ (rows_, dims_) = zip(*self.embedding_specs)
2826
+
2827
+ # Cumulative row counts per (virtual) table for rowwise states
2828
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))
2829
+
2830
+ snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
2831
+ no_snapshot=no_snapshot,
2832
+ should_flush=should_flush,
2833
+ )
2834
+
2835
+ # pyre-ignore[53]
2836
+ def _fetch_offloaded_optimizer_states(
2837
+ t: int,
2838
+ ) -> list[Tensor]:
2839
+ d: int = dims_[t]
2840
+
2841
+ # pyre-ignore[16]
2842
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
2843
+ # pyre-ignore[16]
2844
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
2845
+ row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)
2846
+
2847
+ # When backend returns whole row, the optimizer will be returned as
2848
+ # PMT directly
2849
+ if sorted_ids[t].size(0) == 0 and self.local_weight_counts[t] > 0:
2850
+ logging.info(
2851
+ f"Before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
2852
+ )
2853
+ sorted_ids[t] = torch.zeros(
2854
+ (self.local_weight_counts[t], 1),
2855
+ device=torch.device("cpu"),
2856
+ dtype=torch.int64,
2857
+ )
2858
+
2859
+ # Lookup the byte offsets for each optimizer state relative to the
2860
+ # start of the weights
2861
+ optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
2862
+ d, self.weights_precision, self.optimizer_state_dtypes
2863
+ )
2864
+ # Get the number of elements (of the optimizer state dtype) per state
2865
+ optimizer_state_size_table = self.optimizer.state_size_table(d)
2866
+
2867
+ # Get metaheader dimensions in number of elements of weight dtype
2868
+ metaheader_dim = (
2869
+ # pyre-ignore[16]
2870
+ self.kv_zch_params.eviction_policy.meta_header_lens[t]
2871
+ )
2872
+
2873
+ # Now split up the buffer into N views, N for each optimizer state
2874
+ optimizer_states: list[PartiallyMaterializedTensor] = []
2875
+ for state_name in self.optimizer.state_names():
2876
+ state_dtype = self.optimizer_state_dtypes.get(
2877
+ state_name, SparseType.FP32
2878
+ ).as_dtype()
2879
+
2880
+ # Get the size of the state (given in elements of the optimizer state dtype),
2881
+ # converted to the number of elements of the **weights** dtype
2882
+ state_size = math.ceil(
2883
+ optimizer_state_size_table[state_name]
2884
+ * state_dtype.itemsize
2885
+ / dtype.itemsize
2886
+ )
2887
+
2888
+ # Extract the offsets relative to the start of the weights (in
2889
+ # num bytes)
2890
+ (start, _) = optimizer_state_byte_offsets[state_name]
2891
+
2892
+ # Convert the start to number of elements in terms of the
2893
+ # **weights** dtype, then add the metaheader dim offset
2894
+ start = metaheader_dim + start // dtype.itemsize
2895
+
2896
+ shape = [
2897
+ (
2898
+ sorted_ids[t].size(0)
2899
+ if sorted_ids is not None and sorted_ids[t].size(0) > 0
2900
+ else self.local_weight_counts[t]
2901
+ ),
2902
+ (
2903
+ # Dim is in terms of the **weights** dtype
2904
+ state_size
2905
+ ),
2906
+ ]
2907
+
2908
+ # NOTE: We have to view using the **weights** dtype, as
2909
+ # there is currently a bug with KVTensorWrapper where using
2910
+ # a different dtype does not result in the same bytes being
2911
+ # returned, e.g.
2912
+ #
2913
+ # KVTensorWrapper(dtype=fp32, width_offset=130, shape=[N, 1])
2914
+ #
2915
+ # is NOT the same as
2916
+ #
2917
+ # KVTensorWrapper(dtype=fp16, width_offset=260, shape=[N, 2]).view(-1).view(fp32)
2918
+ #
2919
+ # TODO: Fix KVTensorWrapper to support viewing data under different dtypes
2920
+ tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
2921
+ shape=shape,
2922
+ dtype=(
2923
+ # NOTE: Use the *weights* dtype
2924
+ dtype
2925
+ ),
2926
+ row_offset=row_offset,
2927
+ snapshot_handle=snapshot_handle,
2928
+ sorted_indices=sorted_ids[t],
2929
+ width_offset=(
2930
+ # NOTE: Width offset is in terms of **weights** dtype
2931
+ start
2932
+ ),
2933
+ # Optimizer written to DB with weights, so skip write here
2934
+ read_only=True,
2935
+ )
2936
+ (
2937
+ tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2938
+ if self.backend_type == BackendType.SSD
2939
+ else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
2940
+ )
2941
+
2942
+ optimizer_states.append(
2943
+ PartiallyMaterializedTensor(tensor_wrapper, True)
2944
+ )
2945
+
2946
+ # pyre-ignore [7]
2947
+ return optimizer_states
2948
+
2949
+ return [_fetch_offloaded_optimizer_states(t) for t, _ in enumerate(dims_)]
2950
+
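# Worked arithmetic (hypothetical numbers) for the whole-row width offsets above:
# offsets are converted from bytes (relative to the start of the weights) into
# elements of the weights dtype, and the metaheader width is added on top.
# Assumes fp16 weights, D = 128, a 16-element metaheader, and one fp32 momentum1
# value stored right after the weights.
import math

weight_itemsize = 2                      # fp16
metaheader_dim, D = 16, 128
momentum1_byte_start = D * weight_itemsize
momentum1_elems, momentum1_itemsize = 1, 4

width_offset = metaheader_dim + momentum1_byte_start // weight_itemsize          # 16 + 128 = 144
state_width = math.ceil(momentum1_elems * momentum1_itemsize / weight_itemsize)  # ceil(4 / 2) = 2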
2951
+ @torch.jit.export
2952
+ def split_optimizer_states(
2953
+ self,
2954
+ sorted_id_tensor: Optional[list[torch.Tensor]] = None,
2955
+ no_snapshot: bool = True,
2956
+ should_flush: bool = False,
2957
+ ) -> list[list[torch.Tensor]]:
2958
+ """
2959
+ Returns a list of optimizer states split by table.
2960
+
2961
+ Since EXACT_ROWWISE_ADAGRAD has small optimizer states, we would generate
2962
+ a full tensor for each table (shard). When other optimizer types are supported,
2963
+ we should integrate with KVTensorWrapper (ssd_split_table_batched_embeddings.cpp)
2964
+ to allow caller to read the optimizer states using `narrow()` in a rolling-window manner.
2965
+
2966
+ Args:
2967
+ sorted_id_tensor (Optional[List[torch.Tensor]]): sorted id tensor by table, used to query optimizer
2968
+ state from backend. Caller should reuse the generated id tensor from weight state_dict, to guarantee
2969
+ id consistency between weight and optimizer states.
2970
+
2971
+ """
2972
+
2973
+ # Handle the non-KVZCH case
2974
+ if not self.kv_zch_params:
2975
+ # If not in KV ZCH mode, use the non-KV-ZCH path
2976
+ return self._split_optimizer_states_non_kv_zch()
2977
+
2978
+ # Handle the loading-from-state-dict case
2979
+ if self.load_state_dict:
2980
+ # Initialize for checkpoint loading
2981
+ assert (
2982
+ self._cached_kvzch_data is not None
2983
+ and self._cached_kvzch_data.cached_optimizer_states_per_table
2984
+ ), "optimizer state is not initialized for load checkpointing"
2985
+
2986
+ return self._cached_kvzch_data.cached_optimizer_states_per_table
2987
+
2988
+ logging.info(
2989
+ f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
2990
+ )
2991
+ start_time = time.time()
2992
+
2993
+ if not self.enable_optimizer_offloading:
2994
+ # Handle the KVZCH non-optimizer-offloading case
2995
+ optimizer_states = self._split_optimizer_states_kv_zch_no_offloading(
2996
+ sorted_id_tensor
2997
+ )
2998
+
2999
+ elif not self.backend_return_whole_row:
3000
+ # Handle the KVZCH with-optimizer-offloading case
3001
+ optimizer_states = self._split_optimizer_states_kv_zch_w_offloading(
3002
+ sorted_id_tensor, no_snapshot, should_flush
3003
+ )
3004
+
3005
+ else:
3006
+ # Handle the KVZCH with-optimizer-offloading backend-whole-row case
3007
+ optimizer_states = self._split_optimizer_states_kv_zch_whole_row(
3008
+ sorted_id_tensor, no_snapshot, should_flush
3009
+ )
3010
+
3011
+ logging.info(
3012
+ f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
3013
+ # pyre-ignore[16]
3014
+ f"num ids list: {None if not sorted_id_tensor else [ids.numel() for ids in sorted_id_tensor]}"
3015
+ )
3016
+
3017
+ return optimizer_states
3018
+
3019
+ @torch.jit.export
3020
+ def get_optimizer_state(
3021
+ self,
3022
+ sorted_id_tensor: Optional[list[torch.Tensor]],
3023
+ no_snapshot: bool = True,
3024
+ should_flush: bool = False,
3025
+ ) -> list[dict[str, torch.Tensor]]:
3026
+ """
3027
+ Returns a list of dictionaries of optimizer states split by table.
3028
+ """
3029
+ states_list: list[list[Tensor]] = self.split_optimizer_states(
3030
+ sorted_id_tensor=sorted_id_tensor,
3031
+ no_snapshot=no_snapshot,
3032
+ should_flush=should_flush,
3033
+ )
3034
+ state_names = self.optimizer.state_names()
3035
+ return [dict(zip(state_names, states)) for states in states_list]
3036
+
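# Usage sketch (hypothetical data): get_optimizer_state() above simply zips each
# table's state list from split_optimizer_states() with the optimizer's state names.
import torch

state_names = ["momentum1", "momentum2"]              # hypothetical optimizer states
states_list = [[torch.zeros(3), torch.zeros(3, 8)]]   # one table, two states
as_dicts = [dict(zip(state_names, states)) for states in states_list]
assert as_dicts[0]["momentum1"].shape == (3,)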
3037
+ @torch.jit.export
3038
+ def debug_split_embedding_weights(self) -> list[torch.Tensor]:
3039
+ """
3040
+ Returns a list of weights, split by table.
3041
+
3042
+ Testing only, very slow.
3043
+ """
3044
+ (rows, _) = zip(*self.embedding_specs)
3045
+
3046
+ rows_cumsum = [0] + list(itertools.accumulate(rows))
3047
+ splits = []
3048
+ get_event = torch.cuda.Event()
3049
+
3050
+ for t, (row, _) in enumerate(self.embedding_specs):
3051
+ weights = torch.empty(
3052
+ (row, self.max_D), dtype=self.weights_precision.as_dtype()
3053
+ )
3054
+ self.ssd_db.get_cuda(
3055
+ torch.arange(rows_cumsum[t], rows_cumsum[t + 1]).to(torch.int64),
3056
+ weights,
3057
+ torch.as_tensor([row]),
3058
+ )
3059
+ splits.append(weights)
3060
+
3061
+ # Record the event to create a dependency between get_cuda's callback
3062
+ # function and the kernel on the GPU default stream (the intention is
3063
+ # actually to synchronize between the callback CPU thread and the
3064
+ # Python CPU thread but we do not have a mechanism to explicitly sync
3065
+ # between them)
3066
+ get_event.record()
3067
+
3068
+ # Synchronize to make sure that the callback function in get_cuda
3069
+ # completes (here the CPU thread is blocked until get_event is done)
3070
+ get_event.synchronize()
3071
+
3072
+ # Reshape the weight tensors (this can be expensive, however, this
3073
+ # function is for debugging only)
3074
+ for t, (row, dim) in enumerate(self.embedding_specs):
3075
+ weight = splits[t]
3076
+ weight = weight[:, :dim].contiguous()
3077
+ assert weight.shape == (row, dim), "Shapes mismatch"
3078
+ splits[t] = weight
3079
+
3080
+ return splits
3081
+
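# Minimal sketch of the event-based synchronization used above: recording a CUDA
# event after enqueuing work and blocking on it guarantees that the work (and any
# callbacks tied to the default stream) has completed before the CPU proceeds.
import torch

if torch.cuda.is_available():
    evt = torch.cuda.Event()
    x = torch.ones(4, device="cuda") * 2   # stand-in for the asynchronous get_cuda() call
    evt.record()                           # mark a point on the default stream
    evt.synchronize()                      # CPU blocks until all prior work is done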
3082
+ def clear_cache(self) -> None:
3083
+ # clear KV ZCH cache for checkpointing
3084
+ self._cached_kvzch_data = None
3085
+
3086
+ @torch.jit.ignore
3087
+ # pyre-ignore [3] - do not define snapshot class EmbeddingSnapshotHandleWrapper to avoid import dependency in other production code
3088
+ def _may_create_snapshot_for_state_dict(
3089
+ self,
3090
+ no_snapshot: bool = True,
3091
+ should_flush: bool = False,
3092
+ ):
3093
+ """
3094
+ Create a rocksdb snapshot if needed.
3095
+ """
3096
+ start_time = time.time()
3097
+ # Force device synchronize for now
3098
+ torch.cuda.synchronize()
3099
+ snapshot_handle = None
3100
+ checkpoint_handle = None
3101
+ if self.backend_type == BackendType.SSD:
3102
+ # Create a rocksdb snapshot
3103
+ if not no_snapshot:
3104
+ # Flush L1 and L2 caches
3105
+ self.flush(force=should_flush)
3106
+ logging.info(
3107
+ f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
3108
+ )
3109
+ snapshot_handle = self.ssd_db.create_snapshot()
3110
+ checkpoint_handle = self.ssd_db.get_active_checkpoint_uuid(self.step)
3111
+ logging.info(
3112
+ f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
3113
+ )
3114
+ elif self.backend_type == BackendType.DRAM:
3115
+ # if there is any ongoing eviction, wait until the eviction is finished before state_dict
3116
+ # so that we can reach consistent model state before/after state_dict
3117
+ evict_wait_start_time = time.time()
3118
+ self.ssd_db.wait_until_eviction_done()
3119
+ logging.info(
3120
+ f"state_dict wait for ongoing eviction: {time.time() - evict_wait_start_time} s"
3121
+ )
3122
+ self.flush(force=should_flush)
3123
+ return snapshot_handle, checkpoint_handle
3124
+
3125
+ def get_embedding_dim_for_kvt(
3126
+ self, metaheader_dim: int, emb_dim: int, is_loading_checkpoint: bool
3127
+ ) -> int:
3128
+ if self.load_ckpt_without_opt:
3129
+ # For silvertorch publish, we don't want to load opt into the backend due to limited cpu memory on the publish host.
3130
+ # So we need to load the whole row into the state dict when loading the checkpoint in st publish, then only save the weight into the backend; after that,
3131
+ # backend will only have metaheader + weight.
3132
+ # For the first loading, we need to set dim to metaheader_dim + emb_dim + optimizer_state_dim, otherwise checkpoint loading will throw a size mismatch error.
3133
+ # After the first loading, we only need to get metaheader + weight from the backend for the state dict, so we can set dim to metaheader_dim + emb_dim.
3134
+ if is_loading_checkpoint:
3135
+ return (
3136
+ (
3137
+ metaheader_dim # metaheader is already padded
3138
+ + pad4(emb_dim)
3139
+ + pad4(self.optimizer_state_dim)
3140
+ )
3141
+ if self.backend_return_whole_row
3142
+ else emb_dim
3143
+ )
3144
+ else:
3145
+ return metaheader_dim + pad4(emb_dim)
3146
+ else:
3147
+ return (
3148
+ (
3149
+ metaheader_dim # metaheader is already padded
3150
+ + pad4(emb_dim)
3151
+ + pad4(self.optimizer_state_dim)
3152
+ )
3153
+ if self.backend_return_whole_row
3154
+ else emb_dim
3155
+ )
3156
+
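# Worked example (hypothetical numbers) of the whole-row dim returned above when
# backend_return_whole_row is set: the metaheader is already padded, while the
# embedding dim and optimizer state dim are each padded up to a multiple of 4.
def pad4_example(n: int) -> int:   # assumed to mirror the pad4 helper used in this file
    return (n + 3) // 4 * 4

metaheader_dim, emb_dim, optimizer_state_dim = 8, 130, 2
whole_row_dim = metaheader_dim + pad4_example(emb_dim) + pad4_example(optimizer_state_dim)
assert whole_row_dim == 8 + 132 + 4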
3157
+ @torch.jit.export
3158
+ def split_embedding_weights(
3159
+ self,
3160
+ no_snapshot: bool = True,
3161
+ should_flush: bool = False,
3162
+ ) -> tuple[ # TODO: make this a NamedTuple for readability
3163
+ Union[list[PartiallyMaterializedTensor], list[torch.Tensor]],
3164
+ Optional[list[torch.Tensor]],
3165
+ Optional[list[torch.Tensor]],
3166
+ Optional[list[torch.Tensor]],
3167
+ ]:
3168
+ """
3169
+ This method is intended to be used by the checkpointing engine
3170
+ only.
3171
+
3172
+ Since we cannot materialize SSD backed tensors fully in CPU memory,
3173
+ we would create a KVTensorWrapper (ssd_split_table_batched_embeddings.cpp)
3174
+ for each table (shard), which allows caller to read the weights
3175
+ using `narrow()` in a rolling-window manner.
3176
+ Args:
3177
+ should_flush (bool): Flush caches if True. Note: this is an expensive
3178
+ operation, only set to True when necessary.
3179
+
3180
+ Returns:
3181
+ tuple of 4 lists, each element corresponds to a logical table
3182
+ 1st arg: partially materialized tensors, each representing a table
3183
+ 2nd arg: input id sorted in bucket id ascending order
3184
+ 3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
3185
+ where for the i th element, we have i + bucket_id_start = global bucket id
3186
+ 4th arg: kvzch eviction metadata for each input id sorted in bucket id ascending order
3187
+ """
3188
+ snapshot_handle, checkpoint_handle = self._may_create_snapshot_for_state_dict(
3189
+ no_snapshot=no_snapshot,
3190
+ should_flush=should_flush,
3191
+ )
3192
+
3193
+ dtype = self.weights_precision.as_dtype()
3194
+ if self.load_state_dict and self.kv_zch_params:
3195
+ # init for checkpoint loading
3196
+ assert (
3197
+ self._cached_kvzch_data is not None
3198
+ ), "weight id and bucket state are not initialized for load checkpointing"
3199
+ return (
3200
+ self._cached_kvzch_data.cached_weight_tensor_per_table,
3201
+ self._cached_kvzch_data.cached_id_tensor_per_table,
3202
+ self._cached_kvzch_data.cached_bucket_splits,
3203
+ [], # metadata tensor is not needed for checkpoint loading
3204
+ )
3205
+ start_time = time.time()
3206
+ pmt_splits = []
3207
+ bucket_sorted_id_splits = [] if self.kv_zch_params else None
3208
+ active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
3209
+ metadata_splits = [] if self.kv_zch_params else None
3210
+ skip_metadata = False
3211
+
3212
+ table_offset = 0
3213
+ for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
3214
+ is_loading_checkpoint = False
3215
+ bucket_ascending_id_tensor = None
3216
+ bucket_t = None
3217
+ metadata_tensor = None
3218
+ row_offset = table_offset
3219
+ metaheader_dim = 0
3220
+ if self.kv_zch_params:
3221
+ bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
3222
+ # pyre-ignore
3223
+ bucket_size = self.kv_zch_params.bucket_sizes[i]
3224
+ metaheader_dim = (
3225
+ # pyre-ignore[16]
3226
+ self.kv_zch_params.eviction_policy.meta_header_lens[i]
3227
+ )
3228
+
3229
+ # linearize with table offset
3230
+ table_input_id_start = table_offset
3231
+ table_input_id_end = table_offset + emb_height
3232
+ # 1. get all keys from backend for one table
3233
+ unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
3234
+ table_input_id_start,
3235
+ table_input_id_end,
3236
+ table_offset,
3237
+ snapshot_handle,
3238
+ )
3239
+ # 2. sorting keys in bucket ascending order
3240
+ bucket_ascending_id_tensor, bucket_t = (
3241
+ torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
3242
+ unordered_id_tensor,
3243
+ 0, # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
3244
+ 0, # local bucket offset
3245
+ bucket_id_end - bucket_id_start, # local bucket num
3246
+ bucket_size,
3247
+ )
3248
+ )
3249
+ metadata_tensor = self._ssd_db.get_kv_zch_eviction_metadata_by_snapshot(
3250
+ bucket_ascending_id_tensor + table_offset,
3251
+ torch.as_tensor(bucket_ascending_id_tensor.size(0)),
3252
+ snapshot_handle,
3253
+ ).view(-1, 1)
3254
+
3255
+ # 3. convert local id back to global id
3256
+ bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
3257
+
3258
+ if (
3259
+ bucket_ascending_id_tensor.size(0) == 0
3260
+ and self.local_weight_counts[i] > 0
3261
+ ):
3262
+ logging.info(
3263
+ f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
3264
+ )
3265
+ if self.global_id_per_rank[i].numel() != 0:
3266
+ assert (
3267
+ self.local_weight_counts[i]
3268
+ == self.global_id_per_rank[i].numel()
3269
+ ), f"local weight count and global id per rank size mismatch, with {self.local_weight_counts[i]} and {self.global_id_per_rank[i].numel()}"
3270
+ bucket_ascending_id_tensor = self.global_id_per_rank[i].to(
3271
+ device=torch.device("cpu"), dtype=torch.int64
3272
+ )
3273
+ else:
3274
+ bucket_ascending_id_tensor = torch.zeros(
3275
+ (self.local_weight_counts[i], 1),
3276
+ device=torch.device("cpu"),
3277
+ dtype=torch.int64,
3278
+ )
3279
+ skip_metadata = True
3280
+ is_loading_checkpoint = True
3281
+
3282
+ # self.local_weight_counts[i] = 0 # Reset the count
3283
+
3284
+ # pyre-ignore [16] bucket_sorted_id_splits is not None
3285
+ bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
3286
+ active_id_cnt_per_bucket_split.append(bucket_t)
3287
+ if skip_metadata:
3288
+ metadata_splits = None
3289
+ else:
3290
+ metadata_splits.append(metadata_tensor)
3291
+
3292
+ # for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing
3293
+ # but in the backend, the local id is used during training, so the KVTensorWrapper needs to convert the global id to a local id
3294
+ # first, then linearize the local id with the table offset; the formula is x + table_offset - local_shard_offset
3295
+ # to achieve this, the row_offset will be set to (table_offset - local_shard_offset)
3296
+ row_offset = table_offset - (bucket_id_start * bucket_size)
3297
+
3298
+ tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
3299
+ shape=[
3300
+ (
3301
+ bucket_ascending_id_tensor.size(0)
3302
+ if bucket_ascending_id_tensor is not None
3303
+ else emb_height
3304
+ ),
3305
+ self.get_embedding_dim_for_kvt(
3306
+ metaheader_dim, emb_dim, is_loading_checkpoint
3307
+ ),
3308
+ ],
3309
+ dtype=dtype,
3310
+ row_offset=row_offset,
3311
+ snapshot_handle=snapshot_handle,
3312
+ # pass bucket_ascending_id_tensor to the kvt wrapper, so narrow will follow the id order to return
3313
+ # embedding weights.
3314
+ sorted_indices=(
3315
+ bucket_ascending_id_tensor if self.kv_zch_params else None
3316
+ ),
3317
+ checkpoint_handle=checkpoint_handle,
3318
+ only_load_weight=(
3319
+ True
3320
+ if self.load_ckpt_without_opt and is_loading_checkpoint
3321
+ else False
3322
+ ),
3323
+ )
3324
+ (
3325
+ tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
3326
+ if self.backend_type == BackendType.SSD
3327
+ else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
3328
+ )
3329
+ table_offset += emb_height
3330
+ pmt_splits.append(
3331
+ PartiallyMaterializedTensor(
3332
+ tensor_wrapper,
3333
+ True if self.kv_zch_params else False,
3334
+ )
3335
+ )
3336
+ logging.info(
3337
+ f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms, "
3338
+ )
3339
+ if self.kv_zch_params is not None:
3340
+ logging.info(
3341
+ # pyre-ignore [16]
3342
+ f"num ids list: {[ids.numel() for ids in bucket_sorted_id_splits]}"
3343
+ )
3344
+
3345
+ return (
3346
+ pmt_splits,
3347
+ bucket_sorted_id_splits,
3348
+ active_id_cnt_per_bucket_split,
3349
+ metadata_splits,
3350
+ )
3351
+
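# Worked example (hypothetical numbers) of the KV ZCH row-offset arithmetic above:
# the wrapper stores a row under global_id + row_offset, which equals
# global_id + table_offset - local_shard_offset.
table_offset = 1_000                                  # cumulative rows of preceding tables
bucket_id_start, bucket_size = 3, 100
local_shard_offset = bucket_id_start * bucket_size    # 300
row_offset = table_offset - local_shard_offset        # 700

global_id = 345                                       # an id owned by this shard
assert global_id + row_offset == table_offset + (global_id - local_shard_offset)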
3352
+ @torch.jit.ignore
3353
+ def _apply_state_dict_w_offloading(self) -> None:
3354
+ # Row count per table
3355
+ (rows, _) = zip(*self.embedding_specs)
3356
+ # Cumulative row counts per table for rowwise states
3357
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
3358
+
3359
+ for t, _ in enumerate(self.embedding_specs):
3360
+ # pyre-ignore [16]
3361
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
3362
+ # pyre-ignore [16]
3363
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
3364
+ row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size
3365
+
3366
+ # pyre-ignore [16]
3367
+ weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
3368
+ # pyre-ignore [16]
3369
+ opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]
3370
+
3371
+ self.streaming_write_weight_and_id_per_table(
3372
+ weight_state,
3373
+ opt_states,
3374
+ # pyre-ignore [16]
3375
+ self._cached_kvzch_data.cached_id_tensor_per_table[t],
3376
+ row_offset,
3377
+ )
3378
+ self._cached_kvzch_data.cached_weight_tensor_per_table[t] = None
3379
+ self._cached_kvzch_data.cached_optimizer_states_per_table[t] = None
3380
+
3381
+ @torch.jit.ignore
3382
+ def _apply_state_dict_no_offloading(self) -> None:
3383
+ # Row count per table
3384
+ (rows, _) = zip(*self.embedding_specs)
3385
+ # Cumulative row counts per table for rowwise states
3386
+ row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
3387
+
3388
+ def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
3389
+ device = dst.device
3390
+ dst.index_put_(
3391
+ indices=(
3392
+ # indices is expected to be a tuple of Tensors, not Tensor
3393
+ indices.to(device).view(-1),
3394
+ ),
3395
+ values=src.to(device),
3396
+ )
3397
+
3398
+ for t, _ in enumerate(rows):
3399
+ # pyre-ignore [16]
3400
+ bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
3401
+ # pyre-ignore [16]
3402
+ bucket_size = self.kv_zch_params.bucket_sizes[t]
3403
+ row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size
3404
+
3405
+ # pyre-ignore [16]
3406
+ weights = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
3407
+ # pyre-ignore [16]
3408
+ ids = self._cached_kvzch_data.cached_id_tensor_per_table[t]
3409
+ local_ids = ids + row_offset
3410
+
3411
+ logging.info(
3412
+ f"applying sd for table {t} without optimizer offloading, local_ids is {local_ids}"
3413
+ )
3414
+ # pyre-ignore [16]
3415
+ opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]
3416
+
3417
+ # Set up the plan for copying optimizer states over
3418
+ if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
3419
+ mapping = [(opt_states[0], self.momentum1_dev)]
3420
+ elif self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
3421
+ mapping = [
3422
+ (opt_states[0], self.momentum1_dev),
3423
+ (opt_states[1], self.momentum2_dev),
3424
+ ]
3425
+ else:
3426
+ mapping = []
3427
+
3428
+ # Execute the plan and copy the optimizer states over
3429
+ # pyre-ignore [6]
3430
+ [copy_optimizer_state_(dst, src, local_ids) for (src, dst) in mapping]
3431
+
3432
+ self.ssd_db.set_cuda(
3433
+ local_ids.view(-1),
3434
+ weights,
3435
+ torch.as_tensor(local_ids.size(0)),
3436
+ 1,
3437
+ False,
3438
+ )
3439
+
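# Minimal sketch of the index_put_ scatter used by copy_optimizer_state_ above:
# rows of `src` land in `dst` at the positions named by `indices`.
import torch

dst = torch.zeros(10)
src = torch.tensor([1.0, 2.0, 3.0])
indices = torch.tensor([7, 2, 5])
dst.index_put_((indices,), src)        # dst[7] = 1, dst[2] = 2, dst[5] = 3
assert dst[7] == 1.0 and dst[2] == 2.0 and dst[5] == 3.0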
3440
+ @torch.jit.ignore
3441
+ def apply_state_dict(self) -> None:
3442
+ if self.backend_return_whole_row:
3443
+ logging.info(
3444
+ "backend_return_whole_row is enabled, no need to apply_state_dict"
3445
+ )
3446
+ return
3447
+ # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
3448
+ # Caller should call this function to apply the cached states to backend.
3449
+ if self.load_state_dict is False:
3450
+ return
3451
+ self.load_state_dict = False
3452
+ assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
3453
+ assert (
3454
+ self._cached_kvzch_data is not None
3455
+ and self._cached_kvzch_data.cached_optimizer_states_per_table is not None
3456
+ ), "optimizer state is not initialized for load checkpointing"
3457
+ assert (
3458
+ self._cached_kvzch_data.cached_weight_tensor_per_table is not None
3459
+ and self._cached_kvzch_data.cached_id_tensor_per_table is not None
3460
+ ), "weight and id state is not initialized for load checkpointing"
3461
+
3462
+ # Compute the number of elements of cache_dtype needed to store the
3463
+ # optimizer state, round to the nearest 4
3464
+ # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
3465
+ # apply weight and optimizer state per table
3466
+ if self.enable_optimizer_offloading:
3467
+ self._apply_state_dict_w_offloading()
3468
+ else:
3469
+ self._apply_state_dict_no_offloading()
3470
+
3471
+ self.clear_cache()
3472
+
3473
+ @torch.jit.ignore
3474
+ def streaming_write_weight_and_id_per_table(
3475
+ self,
3476
+ weight_state: torch.Tensor,
3477
+ opt_states: list[torch.Tensor],
3478
+ id_tensor: torch.Tensor,
3479
+ row_offset: int,
3480
+ ) -> None:
3481
+ """
3482
+ This function is used to write the weight, optimizer state, and id to the backend using the kvt wrapper.
3483
+ To avoid overusing memory, we write the weight and id to the backend in a rolling-window manner.
3484
+
3485
+ Args:
3486
+ weight_state (torch.Tensor): The weight state tensor to be written.
3487
+ opt_states (list[torch.Tensor]): The optimizer state tensor(s) to be written.
3488
+ id_tensor (torch.Tensor): The id tensor to be written.
3489
+ """
3490
+ D = weight_state.size(1)
3491
+ dtype = self.weights_precision.as_dtype()
3492
+
3493
+ optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
3494
+ D, self.weights_precision, self.optimizer_state_dtypes
3495
+ )
3496
+ optimizer_state_size_table = self.optimizer.state_size_table(D)
3497
+
3498
+ kvt = torch.classes.fbgemm.KVTensorWrapper(
3499
+ shape=[weight_state.size(0), self.cache_row_dim],
3500
+ dtype=dtype,
3501
+ row_offset=row_offset,
3502
+ snapshot_handle=None,
3503
+ sorted_indices=id_tensor,
3504
+ )
3505
+ (
3506
+ kvt.set_embedding_rocks_dp_wrapper(self.ssd_db)
3507
+ if self.backend_type == BackendType.SSD
3508
+ else kvt.set_dram_db_wrapper(self.ssd_db)
3509
+ )
3510
+
3511
+ # TODO: make chunk_size configurable or dynamic
3512
+ chunk_size = 10000
3513
+ row = weight_state.size(0)
3514
+
3515
+ for i in range(0, row, chunk_size):
3516
+ # Construct the chunk buffer, using the weights precision as the dtype
3517
+ length = min(chunk_size, row - i)
3518
+ chunk_buffer = torch.empty(
3519
+ length,
3520
+ self.cache_row_dim,
3521
+ dtype=dtype,
3522
+ device="cpu",
3523
+ )
3524
+
3525
+ # Copy the weight state over to the chunk buffer
3526
+ chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
3527
+
3528
+ # Copy the optimizer state(s) over to the chunk buffer
3529
+ for o, opt_state in enumerate(opt_states):
3530
+ # Fetch the state name based on the index
3531
+ state_name = self.optimizer.state_names()[o]
3532
+
3533
+ # Fetch the byte offsets for the optimizer state by its name
3534
+ (start, end) = optimizer_state_byte_offsets[state_name]
3535
+
3536
+ # Assume that the opt_state passed in already has dtype matching
3537
+ # self.optimizer_state_dtypes[state_name]
3538
+ opt_state_byteview = opt_state.view(
3539
+ # Force it to be 2D table, with row size matching the
3540
+ # optimizer state size
3541
+ -1,
3542
+ optimizer_state_size_table[state_name],
3543
+ ).view(
3544
+ # Then force tensor to byte view
3545
+ dtype=torch.uint8
3546
+ )
3547
+
3548
+ # Convert the chunk buffer and optimizer state to byte views
3549
+ # Then use the start and end offsets to narrow the chunk buffer
3550
+ # and copy opt_state over
3551
+ chunk_buffer.view(dtype=torch.uint8)[:, start:end] = opt_state_byteview[
3552
+ i : i + length, :
3553
+ ]
3554
+
3555
+ # Write chunk to KVTensor
3556
+ kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))
3557
+
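# Illustrative sketch (hypothetical sizes) of the per-chunk packing above: the first
# D columns of the chunk hold the weights, and each optimizer state is copied into
# its byte range of the same rows through uint8 views.
import torch

rows, D, cache_row_dim = 6, 8, 12        # cache_row_dim > D leaves room for the opt state
dtype = torch.float16
weights = torch.randn(rows, D, dtype=dtype)
momentum1 = torch.randn(rows, 1, dtype=torch.float32)

chunk = torch.zeros(rows, cache_row_dim, dtype=dtype)
chunk[:, :D] = weights
start, end = D * dtype.itemsize, D * dtype.itemsize + 4   # momentum1 right after the weights
chunk.view(torch.uint8)[:, start:end] = momentum1.view(torch.uint8)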
3558
+ @torch.jit.ignore
3559
+ def enable_load_state_dict_mode(self) -> None:
3560
+ if self.backend_return_whole_row:
3561
+ logging.info(
3562
+ "backend_return_whole_row is enabled, no need to enable load_state_dict mode"
3563
+ )
3564
+ return
3565
+ # Enable load state dict mode before loading checkpoint
3566
+ if self.load_state_dict:
3567
+ return
3568
+ self.load_state_dict = True
3569
+
3570
+ dtype = self.weights_precision.as_dtype()
3571
+ (_, dims) = zip(*self.embedding_specs)
3572
+
3573
+ self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
3574
+
3575
+ for i, _ in enumerate(self.embedding_specs):
3576
+ # For checkpoint loading, we need to store the weight and id
3577
+ # tensor temporarily in memory. First check that the local_weight_counts
3578
+ # are properly set before even initializing the optimizer states
3579
+ assert (
3580
+ self.local_weight_counts[i] > 0
3581
+ ), f"local_weight_counts for table {i} is not set"
3582
+
3583
+ # pyre-ignore [16]
3584
+ self._cached_kvzch_data.cached_optimizer_states_per_table = (
3585
+ self.optimizer.empty_states(
3586
+ self.local_weight_counts,
3587
+ dims,
3588
+ self.optimizer_state_dtypes,
3589
+ )
3590
+ )
3591
+
3592
+ for i, (_, emb_dim) in enumerate(self.embedding_specs):
3593
+ # pyre-ignore [16]
3594
+ bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
3595
+ rows = self.local_weight_counts[i]
3596
+ weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
3597
+ # pyre-ignore [16]
3598
+ self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
3599
+ logging.info(
3600
+ f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}"
3601
+ )
3602
+ id_tensor = torch.zeros((rows, 1), dtype=torch.int64, device="cpu")
3603
+ # pyre-ignore [16]
3604
+ self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
3605
+ # pyre-ignore [16]
3606
+ self._cached_kvzch_data.cached_bucket_splits.append(
3607
+ torch.empty(
3608
+ (bucket_id_end - bucket_id_start, 1),
3609
+ dtype=torch.int64,
3610
+ device="cpu",
3611
+ )
3612
+ )
3613
+
3614
+ @torch.jit.export
3615
+ def set_learning_rate(self, lr: float) -> None:
3616
+ """
3617
+ Sets the learning rate.
3618
+
3619
+ Args:
3620
+ lr (float): The learning rate value to set to
3621
+ """
3622
+ self._set_learning_rate(lr)
3623
+
3624
+ def get_learning_rate(self) -> float:
3625
+ """
3626
+ Get and return the learning rate.
3627
+ """
3628
+ return self.learning_rate_tensor.item()
3629
+
3630
+ @torch.jit.ignore
3631
+ def _set_learning_rate(self, lr: float) -> float:
3632
+ """
3633
+ Helper function to script `set_learning_rate`.
3634
+ Note that returning None does not work.
3635
+ """
3636
+ self.learning_rate_tensor = torch.tensor(
3637
+ lr, device=torch.device("cpu"), dtype=torch.float32
3638
+ )
3639
+ return 0.0
3640
+
3641
+ def flush(self, force: bool = False) -> None:
3642
+ # allow force flush from split_embedding_weights to cover edge cases, e.g. checkpointing
3643
+ # after training 0 batches
3644
+ if not self.training:
3645
+ # for eval mode, we should not write anything to embedding
3646
+ return
3647
+
3648
+ if self.step == self.last_flush_step and not force:
3649
+ logging.info(
3650
+ f"SSD TBE has been flushed at {self.last_flush_step=} already for tbe:{self.tbe_unique_id}"
3651
+ )
3652
+ return
3653
+ logging.info(
3654
+ f"SSD TBE flush at {self.step=}, it is an expensive call please be cautious"
3655
+ )
3656
+ active_slots_mask = self.lxu_cache_state != -1
3657
+
3658
+ active_weights_gpu = self.lxu_cache_weights[active_slots_mask.view(-1)].view(
3659
+ -1, self.cache_row_dim
3660
+ )
3661
+ active_ids_gpu = self.lxu_cache_state.view(-1)[active_slots_mask.view(-1)]
3662
+
3663
+ active_weights_cpu = active_weights_gpu.cpu()
3664
+ active_ids_cpu = active_ids_gpu.cpu()
3665
+
3666
+ torch.cuda.current_stream().wait_stream(self.ssd_eviction_stream)
3667
+
3668
+ torch.cuda.synchronize()
3669
+ self.ssd_db.set(
3670
+ active_ids_cpu,
3671
+ active_weights_cpu,
3672
+ torch.tensor([active_ids_cpu.numel()]),
3673
+ )
3674
+ self.ssd_db.flush()
3675
+ self.last_flush_step = self.step
3676
+
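# Minimal sketch (hypothetical cache contents) of the flush selection above: slots
# whose lxu_cache_state entry is not -1 are active, and only their ids/weights are
# written back to the backend.
import torch

lxu_cache_state = torch.tensor([5, -1, 9, -1])     # slot -> linearized id, -1 = empty
lxu_cache_weights = torch.arange(8.0).view(4, 2)   # one cached row per slot
mask = lxu_cache_state != -1
active_ids = lxu_cache_state[mask]                 # tensor([5, 9])
active_weights = lxu_cache_weights[mask]           # rows 0 and 2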
3677
+ def create_rocksdb_hard_link_snapshot(self) -> None:
3678
+ """
3679
+ Create a rocksdb hard link snapshot to provide cross procs access to the underlying data
3680
+ """
3681
+ if self.backend_type == BackendType.SSD:
3682
+ self.ssd_db.create_rocksdb_hard_link_snapshot(self.step)
3683
+ else:
3684
+ logging.warning(
3685
+ "create_rocksdb_hard_link_snapshot is only supported for SSD backend"
3686
+ )
3687
+
3688
+ def prepare_inputs(
3689
+ self,
3690
+ indices: Tensor,
3691
+ offsets: Tensor,
3692
+ per_sample_weights: Optional[Tensor] = None,
3693
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
3694
+ ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
3695
+ """
3696
+ Prepare TBE inputs
3697
+ """
3698
+ # Generate VBE metadata
3699
+ vbe_metadata = self._generate_vbe_metadata(
3700
+ offsets, batch_size_per_feature_per_rank
3701
+ )
3702
+
3703
+ # Force casting indices and offsets to long
3704
+ (indices, offsets) = indices.long(), offsets.long()
3705
+
3706
+ # Force casting per_sample_weights to float
3707
+ if per_sample_weights is not None:
3708
+ per_sample_weights = per_sample_weights.float()
3709
+
3710
+ if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
3711
+ torch.ops.fbgemm.bounds_check_indices(
3712
+ self.rows_per_table,
3713
+ indices,
3714
+ offsets,
3715
+ self.bounds_check_mode_int,
3716
+ self.bounds_check_warning,
3717
+ per_sample_weights,
3718
+ B_offsets=vbe_metadata.B_offsets,
3719
+ max_B=vbe_metadata.max_B,
3720
+ bounds_check_version=self.bounds_check_version,
3721
+ )
3722
+
3723
+ return indices, offsets, per_sample_weights, vbe_metadata
3724
+
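# Minimal sketch of the input normalization above: indices/offsets are forced to
# int64 and per-sample weights to float32 before the bounds check and the lookup.
import torch

indices = torch.tensor([1, 2, 3], dtype=torch.int32)
offsets = torch.tensor([0, 2, 3], dtype=torch.int32)
per_sample_weights = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float64)

indices, offsets = indices.long(), offsets.long()
per_sample_weights = per_sample_weights.float()
assert indices.dtype == torch.int64 and per_sample_weights.dtype == torch.float32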
3725
+ @torch.jit.ignore
3726
+ def _report_kv_backend_stats(self) -> None:
3727
+ """
3728
+ All ssd stats report function entrance
3729
+ """
3730
+ if self.stats_reporter is None:
3731
+ return
3732
+
3733
+ if not self.stats_reporter.should_report(self.step):
3734
+ return
3735
+ self._report_ssd_l1_cache_stats()
3736
+
3737
+ if self.backend_type == BackendType.SSD:
3738
+ self._report_ssd_io_stats()
3739
+ self._report_ssd_mem_usage()
3740
+ self._report_l2_cache_perf_stats()
3741
+ if self.backend_type == BackendType.DRAM:
3742
+ self._report_dram_kv_perf_stats()
3743
+ if self.kv_zch_params and self.kv_zch_params.eviction_policy:
3744
+ self._report_eviction_stats()
3745
+
3746
+ @torch.jit.ignore
3747
+ def _report_ssd_l1_cache_stats(self) -> None:
3748
+ """
3749
+ Each iteration we record cache stats about the L1 SSD cache in the ssd_cache_stats tensor;
3750
+ this function extracts those stats and reports them with stats_reporter
3751
+ """
3752
+ passed_steps = self.step - self.last_reported_step
3753
+ if passed_steps == 0:
3754
+ return
3755
+
3756
+ # ssd hbm cache stats
3757
+
3758
+ ssd_cache_stats = self.ssd_cache_stats.tolist()
3759
+ if len(self.last_reported_ssd_stats) == 0:
3760
+ self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
3761
+ ssd_cache_stats_delta: list[float] = [0.0] * len(ssd_cache_stats)
3762
+ for i in range(len(ssd_cache_stats)):
3763
+ ssd_cache_stats_delta[i] = (
3764
+ ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
3765
+ )
3766
+ self.last_reported_step = self.step
3767
+ self.last_reported_ssd_stats = ssd_cache_stats
3768
+ element_size = self.lxu_cache_weights.element_size()
3769
+
3770
+ for stat_index in UVMCacheStatsIndex:
3771
+ # pyre-ignore
3772
+ self.stats_reporter.report_data_amount(
3773
+ iteration_step=self.step,
3774
+ event_name=f"ssd_tbe.prefetch.cache_stats_by_data_size.{stat_index.name.lower()}",
3775
+ data_bytes=int(
3776
+ ssd_cache_stats_delta[stat_index.value]
3777
+ * element_size
3778
+ * self.cache_row_dim
3779
+ / passed_steps
3780
+ ),
3781
+ )
3782
+
3783
+ self.stats_reporter.report_data_amount(
3784
+ iteration_step=self.step,
3785
+ event_name=f"ssd_tbe.prefetch.cache_stats.{stat_index.name.lower()}",
3786
+ data_bytes=int(ssd_cache_stats_delta[stat_index.value] / passed_steps),
3787
+ )
3788
+
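# Minimal sketch of the per-step delta computed above: the cache stats counters are
# cumulative, so the report divides (current - last_reported) by the steps elapsed
# since the last report.
current = [120, 30]          # hypothetical cumulative counters
last_reported = [100, 10]
passed_steps = 4
delta_per_step = [(c - l) / passed_steps for c, l in zip(current, last_reported)]
assert delta_per_step == [5.0, 5.0]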
3789
+ @torch.jit.ignore
3790
+ def _report_ssd_io_stats(self) -> None:
3791
+ """
3792
+ EmbeddingRocksDB will hold stats for total read/write duration in fwd/bwd
3793
+ this function fetches the stats from EmbeddingRocksDB and reports them with stats_reporter
3794
+ """
3795
+ ssd_io_duration = self.ssd_db.get_rocksdb_io_duration(
3796
+ self.step, self.stats_reporter.report_interval # pyre-ignore
3797
+ )
3798
+
3799
+ if len(ssd_io_duration) != 5:
3800
+ logging.error("ssd io duration should have 5 elements")
3801
+ return
3802
+
3803
+ ssd_read_dur_us = ssd_io_duration[0]
3804
+ fwd_rocksdb_read_dur = ssd_io_duration[1]
3805
+ fwd_l1_eviction_dur = ssd_io_duration[2]
3806
+ bwd_l1_cnflct_miss_write_back_dur = ssd_io_duration[3]
3807
+ flush_write_dur = ssd_io_duration[4]
3808
+
3809
+ # pyre-ignore [16]
3810
+ self.stats_reporter.report_duration(
3811
+ iteration_step=self.step,
3812
+ event_name="ssd.io_duration.read_us",
3813
+ duration_ms=ssd_read_dur_us,
3814
+ time_unit="us",
3815
+ )
3816
+
3817
+ self.stats_reporter.report_duration(
3818
+ iteration_step=self.step,
3819
+ event_name="ssd.io_duration.write.fwd_rocksdb_read_us",
3820
+ duration_ms=fwd_rocksdb_read_dur,
3821
+ time_unit="us",
3822
+ )
3823
+
3824
+ self.stats_reporter.report_duration(
3825
+ iteration_step=self.step,
3826
+ event_name="ssd.io_duration.write.fwd_l1_eviction_us",
3827
+ duration_ms=fwd_l1_eviction_dur,
3828
+ time_unit="us",
3829
+ )
3830
+
3831
+ self.stats_reporter.report_duration(
3832
+ iteration_step=self.step,
3833
+ event_name="ssd.io_duration.write.bwd_l1_cnflct_miss_write_back_us",
3834
+ duration_ms=bwd_l1_cnflct_miss_write_back_dur,
3835
+ time_unit="us",
3836
+ )
3837
+
3838
+ self.stats_reporter.report_duration(
3839
+ iteration_step=self.step,
3840
+ event_name="ssd.io_duration.write.flush_write_us",
3841
+ duration_ms=flush_write_dur,
3842
+ time_unit="us",
3843
+ )
3844
+
3845
+ @torch.jit.ignore
3846
+ def _report_ssd_mem_usage(
3847
+ self,
3848
+ ) -> None:
3849
+ """
3850
+ RocksDB has internal stats for DRAM mem usage; here we call EmbeddingRocksDB to
3851
+ extract those stats and report them with stats_reporter
3852
+ """
3853
+ mem_usage_list = self.ssd_db.get_mem_usage()
3854
+ block_cache_usage = mem_usage_list[0]
3855
+ estimate_table_reader_usage = mem_usage_list[1]
3856
+ memtable_usage = mem_usage_list[2]
3857
+ block_cache_pinned_usage = mem_usage_list[3]
3858
+
3859
+ # pyre-ignore [16]
3860
+ self.stats_reporter.report_data_amount(
3861
+ iteration_step=self.step,
3862
+ event_name="ssd.mem_usage.block_cache",
3863
+ data_bytes=block_cache_usage,
3864
+ )
3865
+
3866
+ self.stats_reporter.report_data_amount(
3867
+ iteration_step=self.step,
3868
+ event_name="ssd.mem_usage.estimate_table_reader",
3869
+ data_bytes=estimate_table_reader_usage,
3870
+ )
3871
+
3872
+ self.stats_reporter.report_data_amount(
3873
+ iteration_step=self.step,
3874
+ event_name="ssd.mem_usage.memtable",
3875
+ data_bytes=memtable_usage,
3876
+ )
3877
+
3878
+ self.stats_reporter.report_data_amount(
3879
+ iteration_step=self.step,
3880
+ event_name="ssd.mem_usage.block_cache_pinned",
3881
+ data_bytes=block_cache_pinned_usage,
3882
+ )
3883
+
3884
+ @torch.jit.ignore
3885
+ def _report_l2_cache_perf_stats(self) -> None:
3886
+ """
3887
+ EmbeddingKVDB will hold stats for L2+SSD performance in fwd/bwd
3888
+ this function fetches the stats from EmbeddingKVDB and reports them with stats_reporter
3889
+ """
3890
+ if self.stats_reporter is None:
3891
+ return
3892
+
3893
+ stats_reporter: TBEStatsReporter = self.stats_reporter
3894
+ if not stats_reporter.should_report(self.step):
3895
+ return
3896
+
3897
+ l2_cache_perf_stats = self.ssd_db.get_l2cache_perf(
3898
+ self.step, stats_reporter.report_interval # pyre-ignore
3899
+ )
3900
+
3901
+ if len(l2_cache_perf_stats) != 15:
3902
+ logging.error("l2 perf stats should have 15 elements")
3903
+ return
3904
+
3905
+ num_cache_misses = l2_cache_perf_stats[0]
3906
+ num_lookups = l2_cache_perf_stats[1]
3907
+ get_total_duration = l2_cache_perf_stats[2]
3908
+ get_cache_lookup_total_duration = l2_cache_perf_stats[3]
3909
+ get_cache_lookup_wait_filling_thread_duration = l2_cache_perf_stats[4]
3910
+ get_weights_fillup_total_duration = l2_cache_perf_stats[5]
3911
+ get_cache_memcpy_duration = l2_cache_perf_stats[6]
3912
+ total_cache_update_duration = l2_cache_perf_stats[7]
3913
+ get_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[8]
3914
+ set_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[9]
3915
+ num_l2_evictions = l2_cache_perf_stats[10]
3916
+
3917
+ l2_cache_free_bytes = l2_cache_perf_stats[11]
3918
+ l2_cache_capacity = l2_cache_perf_stats[12]
3919
+
3920
+ set_cache_lock_wait_duration = l2_cache_perf_stats[13]
3921
+ get_cache_lock_wait_duration = l2_cache_perf_stats[14]
3922
+
3923
+ stats_reporter.report_data_amount(
3924
+ iteration_step=self.step,
3925
+ event_name=self.l2_num_cache_misses_stats_name,
3926
+ data_bytes=num_cache_misses,
3927
+ )
3928
+ stats_reporter.report_data_amount(
3929
+ iteration_step=self.step,
3930
+ event_name=self.l2_num_cache_lookups_stats_name,
3931
+ data_bytes=num_lookups,
3932
+ )
3933
+ stats_reporter.report_data_amount(
3934
+ iteration_step=self.step,
3935
+ event_name=self.l2_num_cache_evictions_stats_name,
3936
+ data_bytes=num_l2_evictions,
3937
+ )
3938
+ stats_reporter.report_data_amount(
3939
+ iteration_step=self.step,
3940
+ event_name=self.l2_cache_capacity_stats_name,
3941
+ data_bytes=l2_cache_capacity,
3942
+ )
3943
+ stats_reporter.report_data_amount(
3944
+ iteration_step=self.step,
3945
+ event_name=self.l2_cache_free_mem_stats_name,
3946
+ data_bytes=l2_cache_free_bytes,
3947
+ )
3948
+
3949
+ stats_reporter.report_duration(
3950
+ iteration_step=self.step,
3951
+ event_name="l2_cache.perf.get.total_duration_us",
3952
+ duration_ms=get_total_duration,
3953
+ time_unit="us",
3954
+ )
3955
+ stats_reporter.report_duration(
3956
+ iteration_step=self.step,
3957
+ event_name="l2_cache.perf.get.cache_lookup_duration_us",
3958
+ duration_ms=get_cache_lookup_total_duration,
3959
+ time_unit="us",
3960
+ )
3961
+ stats_reporter.report_duration(
3962
+ iteration_step=self.step,
3963
+ event_name="l2_cache.perf.get.cache_lookup_wait_filling_thread_duration_us",
3964
+ duration_ms=get_cache_lookup_wait_filling_thread_duration,
3965
+ time_unit="us",
3966
+ )
3967
+ stats_reporter.report_duration(
3968
+ iteration_step=self.step,
3969
+ event_name="l2_cache.perf.get.weights_fillup_duration_us",
3970
+ duration_ms=get_weights_fillup_total_duration,
3971
+ time_unit="us",
3972
+ )
3973
+ stats_reporter.report_duration(
3974
+ iteration_step=self.step,
3975
+ event_name="l2_cache.perf.get.cache_memcpy_duration_us",
3976
+ duration_ms=get_cache_memcpy_duration,
3977
+ time_unit="us",
3978
+ )
3979
+ stats_reporter.report_duration(
3980
+ iteration_step=self.step,
3981
+ event_name="l2_cache.perf.total.cache_update_duration_us",
3982
+ duration_ms=total_cache_update_duration,
3983
+ time_unit="us",
3984
+ )
3985
+ stats_reporter.report_duration(
3986
+ iteration_step=self.step,
3987
+ event_name="l2_cache.perf.get.tensor_copy_for_cache_update_duration_us",
3988
+ duration_ms=get_tensor_copy_for_cache_update_duration,
3989
+ time_unit="us",
3990
+ )
3991
+ stats_reporter.report_duration(
3992
+ iteration_step=self.step,
3993
+ event_name="l2_cache.perf.set.tensor_copy_for_cache_update_duration_us",
3994
+ duration_ms=set_tensor_copy_for_cache_update_duration,
3995
+ time_unit="us",
3996
+ )
3997
+
3998
+ stats_reporter.report_duration(
3999
+ iteration_step=self.step,
4000
+ event_name="l2_cache.perf.get.cache_lock_wait_duration_us",
4001
+ duration_ms=get_cache_lock_wait_duration,
4002
+ time_unit="us",
4003
+ )
4004
+ stats_reporter.report_duration(
4005
+ iteration_step=self.step,
4006
+ event_name="l2_cache.perf.set.cache_lock_wait_duration_us",
4007
+ duration_ms=set_cache_lock_wait_duration,
4008
+ time_unit="us",
4009
+ )
4010
+
4011
+ @torch.jit.ignore
4012
+ def _report_eviction_stats(self) -> None:
4013
+ if self.stats_reporter is None:
4014
+ return
4015
+
4016
+ stats_reporter: TBEStatsReporter = self.stats_reporter
4017
+ if not stats_reporter.should_report(self.step):
4018
+ return
4019
+
4020
+ # skip metrics reporting when eviction is disabled
4021
+ if self.kv_zch_params.eviction_policy.eviction_trigger_mode == 0:
4022
+ return
4023
+
4024
+ T = len(set(self.feature_table_map))
4025
+ evicted_counts = torch.zeros(T, dtype=torch.int64)
4026
+ processed_counts = torch.zeros(T, dtype=torch.int64)
4027
+ eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float)
4028
+ full_duration_ms = torch.tensor(0, dtype=torch.int64)
4029
+ exec_duration_ms = torch.tensor(0, dtype=torch.int64)
4030
+ self.ssd_db.get_feature_evict_metric(
4031
+ evicted_counts,
4032
+ processed_counts,
4033
+ eviction_threshold_with_dry_run,
4034
+ full_duration_ms,
4035
+ exec_duration_ms,
4036
+ )
4037
+
4038
+ stats_reporter.report_data_amount(
4039
+ iteration_step=self.step,
4040
+ event_name=self.eviction_sum_evicted_counts_stats_name,
4041
+ data_bytes=int(evicted_counts.sum().item()),
4042
+ enable_tb_metrics=True,
4043
+ )
4044
+ stats_reporter.report_data_amount(
4045
+ iteration_step=self.step,
4046
+ event_name=self.eviction_sum_processed_counts_stats_name,
4047
+ data_bytes=int(processed_counts.sum().item()),
4048
+ enable_tb_metrics=True,
4049
+ )
4050
+ if processed_counts.sum().item() != 0:
4051
+ stats_reporter.report_data_amount(
4052
+ iteration_step=self.step,
4053
+ event_name=self.eviction_evict_rate_stats_name,
4054
+ data_bytes=int(
4055
+ evicted_counts.sum().item() * 100 / processed_counts.sum().item()
4056
+ ),
4057
+ enable_tb_metrics=True,
4058
+ )
4059
+ for t in self.feature_table_map:
4060
+ stats_reporter.report_data_amount(
4061
+ iteration_step=self.step,
4062
+ event_name=f"eviction.feature_table.{t}.evicted_counts",
4063
+ data_bytes=int(evicted_counts[t].item()),
4064
+ enable_tb_metrics=True,
4065
+ )
4066
+ stats_reporter.report_data_amount(
4067
+ iteration_step=self.step,
4068
+ event_name=f"eviction.feature_table.{t}.processed_counts",
4069
+ data_bytes=int(processed_counts[t].item()),
4070
+ enable_tb_metrics=True,
4071
+ )
4072
+ if processed_counts[t].item() != 0:
4073
+ stats_reporter.report_data_amount(
4074
+ iteration_step=self.step,
4075
+ event_name=f"eviction.feature_table.{t}.evict_rate",
4076
+ data_bytes=int(
4077
+ evicted_counts[t].item() * 100 / processed_counts[t].item()
4078
+ ),
4079
+ enable_tb_metrics=True,
4080
+ )
4081
+ stats_reporter.report_duration(
4082
+ iteration_step=self.step,
4083
+ event_name="eviction.feature_table.full_duration_ms",
4084
+ duration_ms=full_duration_ms.item(),
4085
+ time_unit="ms",
4086
+ enable_tb_metrics=True,
4087
+ )
4088
+ stats_reporter.report_duration(
4089
+ iteration_step=self.step,
4090
+ event_name="eviction.feature_table.exec_duration_ms",
4091
+ duration_ms=exec_duration_ms.item(),
4092
+ time_unit="ms",
4093
+ enable_tb_metrics=True,
4094
+ )
4095
+ if full_duration_ms.item() != 0:
4096
+ stats_reporter.report_data_amount(
4097
+ iteration_step=self.step,
4098
+ event_name="eviction.feature_table.exec_div_full_duration_rate",
4099
+ data_bytes=int(exec_duration_ms.item() * 100 / full_duration_ms.item()),
4100
+ enable_tb_metrics=True,
4101
+ )
4102
+
4103
+ @torch.jit.ignore
4104
+ def _report_dram_kv_perf_stats(self) -> None:
4105
+ """
4106
+ EmbeddingKVDB will hold stats for DRAM cache performance in fwd/bwd
4107
+ this function fetches the stats from EmbeddingKVDB and reports them with stats_reporter
4108
+ """
4109
+ if self.stats_reporter is None:
4110
+ return
4111
+
4112
+ stats_reporter: TBEStatsReporter = self.stats_reporter
4113
+ if not stats_reporter.should_report(self.step):
4114
+ return
4115
+
4116
+ dram_kv_perf_stats = self.ssd_db.get_dram_kv_perf(
4117
+ self.step, stats_reporter.report_interval # pyre-ignore
4118
+ )
4119
+
4120
+ if len(dram_kv_perf_stats) != 36:
4121
+ logging.error("dram cache perf stats should have 36 elements")
4122
+ return
4123
+
4124
+ dram_read_duration = dram_kv_perf_stats[0]
4125
+ dram_read_sharding_duration = dram_kv_perf_stats[1]
4126
+ dram_read_cache_hit_copy_duration = dram_kv_perf_stats[2]
4127
+ dram_read_fill_row_storage_duration = dram_kv_perf_stats[3]
4128
+ dram_read_lookup_cache_duration = dram_kv_perf_stats[4]
4129
+ dram_read_acquire_lock_duration = dram_kv_perf_stats[5]
4130
+ dram_read_missing_load = dram_kv_perf_stats[6]
4131
+ dram_write_sharing_duration = dram_kv_perf_stats[7]
4132
+
4133
+ dram_fwd_l1_eviction_write_duration = dram_kv_perf_stats[8]
4134
+ dram_fwd_l1_eviction_write_allocate_duration = dram_kv_perf_stats[9]
4135
+ dram_fwd_l1_eviction_write_cache_copy_duration = dram_kv_perf_stats[10]
4136
+ dram_fwd_l1_eviction_write_lookup_cache_duration = dram_kv_perf_stats[11]
4137
+ dram_fwd_l1_eviction_write_acquire_lock_duration = dram_kv_perf_stats[12]
4138
+ dram_fwd_l1_eviction_write_missing_load = dram_kv_perf_stats[13]
4139
+
4140
+ dram_bwd_l1_cnflct_miss_write_duration = dram_kv_perf_stats[14]
4141
+ dram_bwd_l1_cnflct_miss_write_allocate_duration = dram_kv_perf_stats[15]
4142
+ dram_bwd_l1_cnflct_miss_write_cache_copy_duration = dram_kv_perf_stats[16]
4143
+ dram_bwd_l1_cnflct_miss_write_lookup_cache_duration = dram_kv_perf_stats[17]
4144
+ dram_bwd_l1_cnflct_miss_write_acquire_lock_duration = dram_kv_perf_stats[18]
4145
+ dram_bwd_l1_cnflct_miss_write_missing_load = dram_kv_perf_stats[19]
4146
+
4147
+ dram_kv_allocated_bytes = dram_kv_perf_stats[20]
4148
+ dram_kv_actual_used_chunk_bytes = dram_kv_perf_stats[21]
4149
+ dram_kv_num_rows = dram_kv_perf_stats[22]
4150
+ dram_kv_read_counts = dram_kv_perf_stats[23]
4151
+ dram_metadata_write_sharding_total_duration = dram_kv_perf_stats[24]
4152
+ dram_metadata_write_total_duration = dram_kv_perf_stats[25]
4153
+ dram_metadata_write_allocate_avg_duration = dram_kv_perf_stats[26]
4154
+ dram_metadata_write_lookup_cache_avg_duration = dram_kv_perf_stats[27]
4155
+ dram_metadata_write_acquire_lock_avg_duration = dram_kv_perf_stats[28]
4156
+ dram_metadata_write_cache_miss_avg_count = dram_kv_perf_stats[29]
4157
+
4158
+ dram_read_metadata_total_duration = dram_kv_perf_stats[30]
4159
+ dram_read_metadata_sharding_total_duration = dram_kv_perf_stats[31]
4160
+ dram_read_metadata_cache_hit_copy_avg_duration = dram_kv_perf_stats[32]
4161
+ dram_read_metadata_lookup_cache_total_avg_duration = dram_kv_perf_stats[33]
4162
+ dram_read_metadata_acquire_lock_avg_duration = dram_kv_perf_stats[34]
4163
+ dram_read_read_metadata_load_size = dram_kv_perf_stats[35]
4164
+
4165
+ stats_reporter.report_duration(
4166
+ iteration_step=self.step,
4167
+ event_name="dram_kv.perf.get.dram_read_duration_us",
4168
+ duration_ms=dram_read_duration,
4169
+ enable_tb_metrics=True,
4170
+ time_unit="us",
4171
+ )
4172
+ stats_reporter.report_duration(
4173
+ iteration_step=self.step,
4174
+ event_name="dram_kv.perf.get.dram_read_sharding_duration_us",
4175
+ duration_ms=dram_read_sharding_duration,
4176
+ enable_tb_metrics=True,
4177
+ time_unit="us",
4178
+ )
4179
+ stats_reporter.report_duration(
4180
+ iteration_step=self.step,
4181
+ event_name="dram_kv.perf.get.dram_read_cache_hit_copy_duration_us",
4182
+ duration_ms=dram_read_cache_hit_copy_duration,
4183
+ enable_tb_metrics=True,
4184
+ time_unit="us",
4185
+ )
4186
+ stats_reporter.report_duration(
4187
+ iteration_step=self.step,
4188
+ event_name="dram_kv.perf.get.dram_read_fill_row_storage_duration_us",
4189
+ duration_ms=dram_read_fill_row_storage_duration,
4190
+ enable_tb_metrics=True,
4191
+ time_unit="us",
4192
+ )
4193
+ stats_reporter.report_duration(
4194
+ iteration_step=self.step,
4195
+ event_name="dram_kv.perf.get.dram_read_lookup_cache_duration_us",
4196
+ duration_ms=dram_read_lookup_cache_duration,
4197
+ enable_tb_metrics=True,
4198
+ time_unit="us",
4199
+ )
4200
+ stats_reporter.report_duration(
4201
+ iteration_step=self.step,
4202
+ event_name="dram_kv.perf.get.dram_read_acquire_lock_duration_us",
4203
+ duration_ms=dram_read_acquire_lock_duration,
4204
+ enable_tb_metrics=True,
4205
+ time_unit="us",
4206
+ )
4207
+ stats_reporter.report_data_amount(
4208
+ iteration_step=self.step,
4209
+ event_name="dram_kv.perf.get.dram_read_missing_load",
4210
+ enable_tb_metrics=True,
4211
+ data_bytes=dram_read_missing_load,
4212
+ )
4213
+ stats_reporter.report_duration(
4214
+ iteration_step=self.step,
4215
+ event_name="dram_kv.perf.set.dram_write_sharing_duration_us",
4216
+ duration_ms=dram_write_sharing_duration,
4217
+ enable_tb_metrics=True,
4218
+ time_unit="us",
4219
+ )
4220
+
4221
+ stats_reporter.report_duration(
4222
+ iteration_step=self.step,
4223
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_duration_us",
4224
+ duration_ms=dram_fwd_l1_eviction_write_duration,
4225
+ enable_tb_metrics=True,
4226
+ time_unit="us",
4227
+ )
4228
+ stats_reporter.report_duration(
4229
+ iteration_step=self.step,
4230
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_allocate_duration_us",
4231
+ duration_ms=dram_fwd_l1_eviction_write_allocate_duration,
4232
+ enable_tb_metrics=True,
4233
+ time_unit="us",
4234
+ )
4235
+ stats_reporter.report_duration(
4236
+ iteration_step=self.step,
4237
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_cache_copy_duration_us",
4238
+ duration_ms=dram_fwd_l1_eviction_write_cache_copy_duration,
4239
+ enable_tb_metrics=True,
4240
+ time_unit="us",
4241
+ )
4242
+ stats_reporter.report_duration(
4243
+ iteration_step=self.step,
4244
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_lookup_cache_duration_us",
4245
+ duration_ms=dram_fwd_l1_eviction_write_lookup_cache_duration,
4246
+ enable_tb_metrics=True,
4247
+ time_unit="us",
4248
+ )
4249
+ stats_reporter.report_duration(
4250
+ iteration_step=self.step,
4251
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_acquire_lock_duration_us",
4252
+ duration_ms=dram_fwd_l1_eviction_write_acquire_lock_duration,
4253
+ enable_tb_metrics=True,
4254
+ time_unit="us",
4255
+ )
4256
+ stats_reporter.report_data_amount(
4257
+ iteration_step=self.step,
4258
+ event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load",
4259
+ data_bytes=dram_fwd_l1_eviction_write_missing_load,
4260
+ enable_tb_metrics=True,
4261
+ )
4262
+
4263
+ stats_reporter.report_duration(
4264
+ iteration_step=self.step,
4265
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_duration_us",
4266
+ duration_ms=dram_bwd_l1_cnflct_miss_write_duration,
4267
+ enable_tb_metrics=True,
4268
+ time_unit="us",
4269
+ )
4270
+ stats_reporter.report_duration(
4271
+ iteration_step=self.step,
4272
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_allocate_duration_us",
4273
+ duration_ms=dram_bwd_l1_cnflct_miss_write_allocate_duration,
4274
+ enable_tb_metrics=True,
4275
+ time_unit="us",
4276
+ )
4277
+ stats_reporter.report_duration(
4278
+ iteration_step=self.step,
4279
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_cache_copy_duration_us",
4280
+ duration_ms=dram_bwd_l1_cnflct_miss_write_cache_copy_duration,
4281
+ enable_tb_metrics=True,
4282
+ time_unit="us",
4283
+ )
4284
+ stats_reporter.report_duration(
4285
+ iteration_step=self.step,
4286
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_lookup_cache_duration_us",
4287
+ duration_ms=dram_bwd_l1_cnflct_miss_write_lookup_cache_duration,
4288
+ enable_tb_metrics=True,
4289
+ time_unit="us",
4290
+ )
4291
+ stats_reporter.report_duration(
4292
+ iteration_step=self.step,
4293
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_acquire_lock_duration_us",
4294
+ duration_ms=dram_bwd_l1_cnflct_miss_write_acquire_lock_duration,
4295
+ enable_tb_metrics=True,
4296
+ time_unit="us",
4297
+ )
4298
+ stats_reporter.report_data_amount(
4299
+ iteration_step=self.step,
4300
+ event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load",
4301
+ data_bytes=dram_bwd_l1_cnflct_miss_write_missing_load,
4302
+ enable_tb_metrics=True,
4303
+ )
4304
+
4305
+ stats_reporter.report_data_amount(
4306
+ iteration_step=self.step,
4307
+ event_name="dram_kv.perf.get.dram_kv_read_counts",
4308
+ data_bytes=dram_kv_read_counts,
4309
+ enable_tb_metrics=True,
4310
+ )
4311
+
4312
+ stats_reporter.report_data_amount(
4313
+ iteration_step=self.step,
4314
+ event_name=self.dram_kv_allocated_bytes_stats_name,
4315
+ data_bytes=dram_kv_allocated_bytes,
4316
+ enable_tb_metrics=True,
4317
+ )
4318
+ stats_reporter.report_data_amount(
4319
+ iteration_step=self.step,
4320
+ event_name=self.dram_kv_actual_used_chunk_bytes_stats_name,
4321
+ data_bytes=dram_kv_actual_used_chunk_bytes,
4322
+ enable_tb_metrics=True,
4323
+ )
4324
+ stats_reporter.report_data_amount(
4325
+ iteration_step=self.step,
4326
+ event_name=self.dram_kv_mem_num_rows_stats_name,
4327
+ data_bytes=dram_kv_num_rows,
4328
+ enable_tb_metrics=True,
4329
+ )
4330
+ stats_reporter.report_duration(
4331
+ iteration_step=self.step,
4332
+ event_name="dram_kv.perf.set.dram_eviction_score_write_sharding_total_duration_us",
4333
+ duration_ms=dram_metadata_write_sharding_total_duration,
4334
+ enable_tb_metrics=True,
4335
+ time_unit="us",
4336
+ )
4337
+ stats_reporter.report_duration(
4338
+ iteration_step=self.step,
4339
+ event_name="dram_kv.perf.set.dram_eviction_score_write_total_duration_us",
4340
+ duration_ms=dram_metadata_write_total_duration,
4341
+ enable_tb_metrics=True,
4342
+ time_unit="us",
4343
+ )
4344
+ stats_reporter.report_duration(
4345
+ iteration_step=self.step,
4346
+ event_name="dram_kv.perf.set.dram_eviction_score_write_allocate_avg_duration_us",
4347
+ duration_ms=dram_metadata_write_allocate_avg_duration,
4348
+ enable_tb_metrics=True,
4349
+ time_unit="us",
4350
+ )
4351
+ stats_reporter.report_duration(
4352
+ iteration_step=self.step,
4353
+ event_name="dram_kv.perf.set.dram_eviction_score_write_lookup_cache_avg_duration_us",
4354
+ duration_ms=dram_metadata_write_lookup_cache_avg_duration,
4355
+ enable_tb_metrics=True,
4356
+ time_unit="us",
4357
+ )
4358
+ stats_reporter.report_duration(
4359
+ iteration_step=self.step,
4360
+ event_name="dram_kv.perf.set.dram_eviction_score_write_acquire_lock_avg_duration_us",
4361
+ duration_ms=dram_metadata_write_acquire_lock_avg_duration,
4362
+ enable_tb_metrics=True,
4363
+ time_unit="us",
4364
+ )
4365
+ stats_reporter.report_data_amount(
4366
+ iteration_step=self.step,
4367
+ event_name="dram_kv.perf.set.dram_eviction_score_write_cache_miss_avg_count",
4368
+ data_bytes=dram_metadata_write_cache_miss_avg_count,
4369
+ enable_tb_metrics=True,
4370
+ )
4371
+ stats_reporter.report_duration(
4372
+ iteration_step=self.step,
4373
+ event_name="dram_kv.perf.get.dram_eviction_score_read_total_duration_us",
4374
+ duration_ms=dram_read_metadata_total_duration,
4375
+ enable_tb_metrics=True,
4376
+ time_unit="us",
4377
+ )
4378
+ stats_reporter.report_duration(
4379
+ iteration_step=self.step,
4380
+ event_name="dram_kv.perf.get.dram_eviction_score_read_sharding_total_duration_us",
4381
+ duration_ms=dram_read_metadata_sharding_total_duration,
4382
+ enable_tb_metrics=True,
4383
+ time_unit="us",
4384
+ )
4385
+ stats_reporter.report_duration(
4386
+ iteration_step=self.step,
4387
+ event_name="dram_kv.perf.get.dram_eviction_score_read_cache_hit_copy_avg_duration_us",
4388
+ duration_ms=dram_read_metadata_cache_hit_copy_avg_duration,
4389
+ enable_tb_metrics=True,
4390
+ time_unit="us",
4391
+ )
4392
+ stats_reporter.report_duration(
4393
+ iteration_step=self.step,
4394
+ event_name="dram_kv.perf.get.dram_eviction_score_read_lookup_cache_total_avg_duration_us",
4395
+ duration_ms=dram_read_metadata_lookup_cache_total_avg_duration,
4396
+ enable_tb_metrics=True,
4397
+ time_unit="us",
4398
+ )
4399
+ stats_reporter.report_duration(
4400
+ iteration_step=self.step,
4401
+ event_name="dram_kv.perf.get.dram_eviction_score_read_acquire_lock_avg_duration_us",
4402
+ duration_ms=dram_read_metadata_acquire_lock_avg_duration,
4403
+ enable_tb_metrics=True,
4404
+ time_unit="us",
4405
+ )
4406
+ stats_reporter.report_data_amount(
4407
+ iteration_step=self.step,
4408
+ event_name="dram_kv.perf.get.dram_eviction_score_read_load_size",
4409
+ data_bytes=dram_read_read_metadata_load_size,
4410
+ enable_tb_metrics=True,
4411
+ )
4412
+
4413
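The reporting block above funnels every per-step DRAM-KV counter through two calls, stats_reporter.report_duration(...) and stats_reporter.report_data_amount(...). Below is a minimal sketch of a reporter with that call shape, for illustration only; it assumes nothing about the real TBEStatsReporter beyond the keyword arguments used above, and the LoggingStatsReporter class and its logging backend are hypothetical.

import logging

class LoggingStatsReporter:
    """Illustrative stand-in exposing the call shape used by the reporting block above."""

    def __init__(self, report_interval: int = 100) -> None:
        self.report_interval = report_interval

    def should_report(self, iteration_step: int) -> bool:
        # Only emit stats every `report_interval` steps to keep overhead low.
        return iteration_step % self.report_interval == 0

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        enable_tb_metrics: bool = False,
        time_unit: str = "ms",
    ) -> None:
        # Note: the call sites above pass microsecond values through `duration_ms`
        # and label them with time_unit="us".
        logging.info("step=%d %s=%.1f%s", iteration_step, event_name, duration_ms, time_unit)

    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        enable_tb_metrics: bool = False,
    ) -> None:
        logging.info("step=%d %s=%d", iteration_step, event_name, data_bytes)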
+ def _recording_to_timer(
4414
+ self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
4415
+ ) -> Any:
4416
+ """
4417
+ Helper that calls AsyncSeriesTimer and wraps the kernels we want to record inside its recording context.
4418
+ """
4419
+ if self.stats_reporter is not None and self.stats_reporter.should_report(
4420
+ self.step
4421
+ ):
4422
+ assert (
4423
+ timer
4424
+ ), "We shouldn't be here, async timer must have been initiated if reporter is present."
4425
+ return timer.recording(**kwargs)
4426
+ # No-Op context manager
4427
+ return contextlib.nullcontext()
4428
+
4429
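_recording_to_timer returns an active recording context only when the reporter wants this step, and a contextlib.nullcontext() otherwise, so call sites can always use a plain `with` block. The sketch below mirrors that pattern with a toy timer; SimpleSeriesTimer and timed_region are stand-ins, not the AsyncSeriesTimer API.

import contextlib
import time

class SimpleSeriesTimer:
    """Toy stand-in for AsyncSeriesTimer: times a region and hands the result to a callback."""

    def __init__(self, callback):
        self._callback = callback

    @contextlib.contextmanager
    def recording(self, context=None):
        start = time.perf_counter()
        try:
            yield
        finally:
            # Report elapsed microseconds for the wrapped region.
            self._callback(context, (time.perf_counter() - start) * 1e6)

def timed_region(timer, should_report: bool, context):
    # Mirrors _recording_to_timer: a real recording context when reporting, a no-op otherwise.
    return timer.recording(context=context) if should_report else contextlib.nullcontext()

durations = []
timer = SimpleSeriesTimer(lambda ctx, us: durations.append((ctx, us)))
with timed_region(timer, should_report=True, context="l1_write"):
    sum(range(100000))  # stand-in for the kernels being timed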
+ def fetch_from_l1_sp_w_row_ids(
4430
+ self, row_ids: torch.Tensor, only_get_optimizer_states: bool = False
4431
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
4432
+ """
4433
+ Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
4434
+ @return: updated_weights/optimizer_states, mask of which rows are filled
4435
+ """
4436
+ if not self.enable_optimizer_offloading and only_get_optimizer_states:
4437
+ raise RuntimeError(
4438
+ "Optimizer states are not offloaded, while only_get_optimizer_states is True"
4439
+ )
4440
+
4441
+ # NOTE: Remove this once there is support for fetching multiple
4442
+ # optimizer states in fetch_from_l1_sp_w_row_ids
4443
+ if only_get_optimizer_states and self.optimizer not in [
4444
+ OptimType.EXACT_ROWWISE_ADAGRAD,
4445
+ OptimType.PARTIAL_ROWWISE_ADAM,
4446
+ ]:
4447
+ raise RuntimeError(
4448
+ f"Fetching optimizer states using fetch_from_l1_sp_w_row_ids() is not yet supported for {self.optimizer}"
4449
+ )
4450
+
4451
+ def split_results_by_opt_states(
4452
+ updated_weights: torch.Tensor, cache_location_mask: torch.Tensor
4453
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
4454
+ if not only_get_optimizer_states:
4455
+ return [updated_weights], cache_location_mask
4456
+ # TODO: support mixed dimension case
4457
+ # currently only supports tables with the same max_D dimension
4458
+ opt_to_dim = self.optimizer.byte_offsets_along_row(
4459
+ self.max_D, self.weights_precision, self.optimizer_state_dtypes
4460
+ )
4461
+ updated_opt_states = []
4462
+ for opt_name, dim in opt_to_dim.items():
4463
+ opt_dtype = self.optimizer._extract_dtype(
4464
+ self.optimizer_state_dtypes, opt_name
4465
+ )
4466
+ updated_opt_states.append(
4467
+ updated_weights.view(dtype=torch.uint8)[:, dim[0] : dim[1]].view(
4468
+ dtype=opt_dtype
4469
+ )
4470
+ )
4471
+ return updated_opt_states, cache_location_mask
4472
+
4473
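split_results_by_opt_states reinterprets each fetched row as raw bytes, slices the per-optimizer-state byte ranges, and re-views each slice with that state's dtype. The standalone sketch below shows the same view/slice/view trick on a toy buffer; the byte offsets and dtypes here are made up for illustration, whereas the real ones come from optimizer.byte_offsets_along_row.

# Standalone sketch of the byte-offset slicing used by split_results_by_opt_states.
# The offsets below are invented for illustration; the real layout comes from
# self.optimizer.byte_offsets_along_row(max_D, weights_precision, optimizer_state_dtypes).
import torch

rows = torch.randn(4, 16, dtype=torch.float32)               # 4 cache rows, 64 bytes each
opt_to_dim = {"momentum1": (48, 52), "momentum2": (52, 64)}   # byte ranges within a row (assumed)
opt_dtypes = {"momentum1": torch.float32, "momentum2": torch.float32}

opt_states = []
for name, (lo, hi) in opt_to_dim.items():
    # View the row buffer as bytes, slice the state's byte range, re-view with its dtype.
    opt_states.append(rows.view(dtype=torch.uint8)[:, lo:hi].view(dtype=opt_dtypes[name]))

assert opt_states[0].shape == (4, 1) and opt_states[1].shape == (4, 3)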
+ with torch.no_grad():
4474
+ weights_dtype = self.weights_precision.as_dtype()
4475
+ step = self.step
4476
+ with record_function(f"## fetch_from_l1_{step}_{self.tbe_unique_id} ##"):
4477
+ lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
4478
+ row_ids,
4479
+ self.lxu_cache_state,
4480
+ self.total_hash_size,
4481
+ )
4482
+ updated_weights = torch.empty(
4483
+ row_ids.numel(),
4484
+ self.cache_row_dim,
4485
+ device=self.current_device,
4486
+ dtype=weights_dtype,
4487
+ )
4488
+
4489
+ # D2D copy cache
4490
+ cache_location_mask = lxu_cache_locations >= 0
4491
+ torch.ops.fbgemm.masked_index_select(
4492
+ updated_weights,
4493
+ lxu_cache_locations,
4494
+ self.lxu_cache_weights,
4495
+ torch.tensor(
4496
+ [row_ids.numel()],
4497
+ device=self.current_device,
4498
+ dtype=torch.int32,
4499
+ ),
4500
+ )
4501
+
4502
+ with record_function(f"## fetch_from_sp_{step}_{self.tbe_unique_id} ##"):
4503
+ if len(self.ssd_scratch_pad_eviction_data) > 0:
4504
+ sp = self.ssd_scratch_pad_eviction_data[0][0]
4505
+ sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
4506
+ self.current_device
4507
+ )
4508
+ actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
4509
+ if actions_count_gpu.item() == 0:
4510
+ # no action to take
4511
+ return split_results_by_opt_states(
4512
+ updated_weights, cache_location_mask
4513
+ )
4514
+
4515
+ sp_idx = sp_idx[:actions_count_gpu]
4516
+
4517
+ # -1 in lxu_cache_locations means the row is not in the L1 cache and may be in the SP
4518
+ # mark the row_ids already in L1 with -2; non-negative values may be in the SP
4519
+ # @eg. updated_row_ids_in_sp= [1, 100, 1, 2, -2, 3, 4, 5, 10]
4520
+ updated_row_ids_in_sp = row_ids.masked_fill(
4521
+ lxu_cache_locations != -1, -2
4522
+ )
4523
+ # sort the sp_idx for binary search
4524
+ # should already be sorted
4525
+ # sp_idx_inverse_indices holds the pre-sort indices, which are the same as the locations in the SP.
4526
+ # @eg. sp_idx = [4, 2, 1, 3, 10]
4527
+ # @eg sorted_sp_idx = [ 1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
4528
+ sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
4529
+ # search the row ids against the sorted SP indices to find the location of each row in the SP
4530
+ # @eg: updated_ids_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
4531
+ # @eg: 5 is OOB
4532
+ updated_ids_in_sp_idx = torch.searchsorted(
4533
+ sorted_sp_idx, updated_row_ids_in_sp
4534
+ )
4535
+ # ids not found in the SP produce out-of-bound positions
4536
+ oob_sp_idx = updated_ids_in_sp_idx >= sp_idx.numel()
4537
+ # clamp the out-of-bound positions back in bounds
4538
+ # @eg updated_ids_in_sp_idx=[0, 0, 0, 1, 0, 2, 3, 4, 4]
4539
+ updated_ids_in_sp_idx[oob_sp_idx] = 0
4540
+
4541
+ # -1s locations will be filtered out in masked_index_select
4542
+ sp_locations_in_updated_weights = torch.full_like(
4543
+ updated_row_ids_in_sp, -1
4544
+ )
4545
+ # torch.searchsorted does not do exact matching,
4546
+ # we only take exact matched rows, where the id is found in SP.
4547
+ # @eg 5 in updated_row_ids_in_sp is not in sp_idx, but has 4 in updated_ids_in_sp_idx
4548
+ # @eg sorted_sp_idx[updated_ids_in_sp_idx]=[ 1, 1, 1, 2, 1, 3, 4, 10, 10]
4549
+ # @eg exact_match_mask=[ True, False, True, True, False, True, True, False, True]
4550
+ exact_match_mask = (
4551
+ sorted_sp_idx[updated_ids_in_sp_idx] == updated_row_ids_in_sp
4552
+ )
4553
+ # Get the location of the row ids found in SP.
4554
+ # @eg: sp_locations_found=[2, 2, 1, 3, 0, 4]
4555
+ sp_locations_found = sp_idx_inverse_indices[
4556
+ updated_ids_in_sp_idx[exact_match_mask]
4557
+ ]
4558
+ # @eg: sp_locations_in_updated_weights=[ 2, -1, 2, 1, -1, 3, 0, -1, 4]
4559
+ sp_locations_in_updated_weights[exact_match_mask] = (
4560
+ sp_locations_found
4561
+ )
4562
+
4563
+ # D2D copy SP
4564
+ torch.ops.fbgemm.masked_index_select(
4565
+ updated_weights,
4566
+ sp_locations_in_updated_weights,
4567
+ sp,
4568
+ torch.tensor(
4569
+ [row_ids.numel()],
4570
+ device=self.current_device,
4571
+ dtype=torch.int32,
4572
+ ),
4573
+ )
4574
+ # cache_location_mask is the mask of rows in L1
4575
+ # exact_match_mask is the mask of rows in SP
4576
+ cache_location_mask = torch.logical_or(
4577
+ cache_location_mask, exact_match_mask
4578
+ )
4579
+
4580
+ return split_results_by_opt_states(updated_weights, cache_location_mask)
4581
+
4582
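The scratch-pad lookup above (and the near-identical one in direct_write_embedding below) locates arbitrary row ids inside sp_idx by sorting once, running torch.searchsorted, clamping out-of-bounds positions, and keeping only the exact matches. The snippet below reproduces the @eg walkthrough from the comments with plain tensors, so the expected masks and locations can be checked without any TBE state.

# Reproduces the searchsorted-based exact-match lookup from the comments above,
# using the same example values (no TBE state required).
import torch

sp_idx = torch.tensor([4, 2, 1, 3, 10])                   # ids held in the scratch pad
row_ids = torch.tensor([1, 100, 1, 2, -2, 3, 4, 5, 10])   # -2 marks rows already in L1

sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
pos = torch.searchsorted(sorted_sp_idx, row_ids)
pos[pos >= sp_idx.numel()] = 0                             # clamp out-of-bounds hits
exact = sorted_sp_idx[pos] == row_ids                      # keep exact matches only

locations = torch.full_like(row_ids, -1)                   # -1 => not in the scratch pad
locations[exact] = sp_idx_inverse_indices[pos[exact]]      # position inside the scratch pad

print(exact.tolist())      # [True, False, True, True, False, True, True, False, True]
print(locations.tolist())  # [2, -1, 2, 1, -1, 3, 0, -1, 4]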
+ def register_backward_hook_before_eviction(
4583
+ self, backward_hook: Callable[[torch.Tensor], None]
4584
+ ) -> None:
4585
+ """
4586
+ Register a backward hook to the TBE module.
4587
+ And make sure this is called before the sp eviction hook.
4588
+ """
4589
+ # make sure this hook is the first one to be executed
4590
+ hooks = []
4591
+ backward_hooks = self.placeholder_autograd_tensor._backward_hooks
4592
+ if backward_hooks is not None:
4593
+ for _handle_id, hook in backward_hooks.items():
4594
+ hooks.append(hook)
4595
+ backward_hooks.clear()
4596
+
4597
+ self.placeholder_autograd_tensor.register_hook(backward_hook)
4598
+ for hook in hooks:
4599
+ self.placeholder_autograd_tensor.register_hook(hook)
4600
+
4601
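register_backward_hook_before_eviction drains any hooks already attached to placeholder_autograd_tensor, registers the new hook first, and re-registers the old ones so the new hook runs before the scratch-pad eviction hook. The sketch below shows the same trick on a plain tensor; it relies on the internal Tensor._backward_hooks dict just like the method above, so the ordering behavior is an assumption about the current PyTorch implementation.

# Standalone illustration of the "register my hook first" trick on a plain tensor.
import torch

t = torch.zeros(1, requires_grad=True)
order = []
t.register_hook(lambda g: order.append("eviction_hook") or g)

def register_first(tensor, hook):
    existing = list((tensor._backward_hooks or {}).values())
    if tensor._backward_hooks is not None:
        tensor._backward_hooks.clear()
    tensor.register_hook(hook)          # new hook is inserted first
    for h in existing:                  # re-register previous hooks after it
        tensor.register_hook(h)

register_first(t, lambda g: order.append("before_eviction") or g)
(t * 2).sum().backward()
print(order)  # expected: ['before_eviction', 'eviction_hook']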
+ def set_local_weight_counts_for_table(
4602
+ self, table_idx: int, weight_count: int
4603
+ ) -> None:
4604
+ self.local_weight_counts[table_idx] = weight_count
4605
+
4606
+ def set_global_id_per_rank_for_table(
4607
+ self, table_idx: int, global_id: torch.Tensor
4608
+ ) -> None:
4609
+ self.global_id_per_rank[table_idx] = global_id
4610
+
4611
+ def direct_write_embedding(
4612
+ self,
4613
+ indices: torch.Tensor,
4614
+ offsets: torch.Tensor,
4615
+ weights: torch.Tensor,
4616
+ ) -> None:
4617
+ """
4618
+ Directly write the weights to L1, SP and the backend without relying on autograd for the embedding cache.
4619
+ Please refer to design doc for more details: https://docs.google.com/document/d/1TJHKvO1m3-5tYAKZGhacXnGk7iCNAzz7wQlrFbX_LDI/edit?tab=t.0
4620
+ """
4621
+ assert (
4622
+ self._embedding_cache_mode
4623
+ ), "Must be in embedding_cache_mode to support direct_write_embedding method."
4624
+
4625
+ B_offsets = None
4626
+ max_B = -1
4627
+
4628
+ with torch.no_grad():
4629
+ # Wait for any ongoing prefetch operations to complete before starting direct_write
4630
+ current_stream = torch.cuda.current_stream()
4631
+ current_stream.wait_event(self.prefetch_complete_event)
4632
+
4633
+ # Create local step events for internal sequential execution
4634
+ weights_dtype = self.weights_precision.as_dtype()
4635
+ assert (
4636
+ weights_dtype == weights.dtype
4637
+ ), f"Expected embedding table dtype {weights_dtype} is same with input weight dtype, but got {weights.dtype}"
4638
+
4639
+ # Pad the weights to the self.cache_row_dim width if necessary
4640
+ if weights.size(1) < self.cache_row_dim:
4641
+ weights = torch.nn.functional.pad(
4642
+ weights, (0, self.cache_row_dim - weights.size(1))
4643
+ )
4644
+
4645
+ step = self.step
4646
+
4647
+ # step 0: run backward hook for prefetch if prefetch pipeline is enabled before writing to L1 and SP
4648
+ if self.prefetch_pipeline:
4649
+ self._update_cache_counter_and_pointers(nn.Module(), torch.empty(0))
4650
+
4651
+ # step 1: lookup and write to l1 cache
4652
+ with record_function(
4653
+ f"## direct_write_to_l1_{step}_{self.tbe_unique_id} ##"
4654
+ ):
4655
+ if self.gather_ssd_cache_stats:
4656
+ self.local_ssd_cache_stats.zero_()
4657
+
4658
+ # Linearize indices
4659
+ linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
4660
+ self.hash_size_cumsum,
4661
+ indices,
4662
+ offsets,
4663
+ B_offsets,
4664
+ max_B,
4665
+ )
4666
+
4667
+ lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
4668
+ linear_cache_indices,
4669
+ self.lxu_cache_state,
4670
+ self.total_hash_size,
4671
+ )
4672
+ cache_location_mask = lxu_cache_locations >= 0
4673
+
4674
+ # Get the cache locations for the row_ids that are already in the cache
4675
+ cache_locations = lxu_cache_locations[cache_location_mask]
4676
+
4677
+ # Get the corresponding input weights for these row_ids
4678
+ cache_weights = weights[cache_location_mask]
4679
+
4680
+ # Update the cache with these input weights
4681
+ if cache_locations.numel() > 0:
4682
+ self.lxu_cache_weights.index_put_(
4683
+ (cache_locations,), cache_weights, accumulate=False
4684
+ )
4685
+
4686
+ # Record completion of step 1
4687
+ current_stream.record_event(self.direct_write_l1_complete_event)
4688
+
4689
+ # step 2: pop the current scratch pad and write to next batch scratch pad if exists
4690
+ # Wait for step 1 to complete
4691
+ with record_function(
4692
+ f"## direct_write_to_sp_{step}_{self.tbe_unique_id} ##"
4693
+ ):
4694
+ if len(self.ssd_scratch_pad_eviction_data) > 0:
4695
+ self.ssd_scratch_pad_eviction_data.pop(0)
4696
+ if len(self.ssd_scratch_pad_eviction_data) > 0:
4697
+ # Wait for any pending backend reads to the next scratch pad
4698
+ # to complete before we write to it. Otherwise, stale backend data
4699
+ # will overwrite our direct_write updates.
4700
+ # The ssd_event_get marks completion of backend fetch operations.
4701
+ current_stream.wait_event(self.ssd_event_get)
4702
+
4703
+ # if scratch pad exists, write to next batch scratch pad
4704
+ sp = self.ssd_scratch_pad_eviction_data[0][0]
4705
+ sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
4706
+ self.current_device
4707
+ )
4708
+ actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
4709
+ if actions_count_gpu.item() != 0:
4710
+ # when actions_count_gpu is 0, there is nothing to write to the SP
4711
+ sp_idx = sp_idx[:actions_count_gpu]
4712
+
4713
+ # -1 in lxu_cache_locations means the row is not in the L1 cache and may be in the SP
4714
+ # mark the row_ids already in L1 with -2; non-negative values may be in the SP or the backend
4715
+ # @eg. updated_indices_in_sp= [1, 100, 1, 2, -2, 3, 4, 5, 10]
4716
+ updated_indices_in_sp = linear_cache_indices.masked_fill(
4717
+ lxu_cache_locations != -1, -2
4718
+ )
4719
+ # sort the sp_idx for binary search
4720
+ # should already be sorted
4721
+ # sp_idx_inverse_indices holds the pre-sort indices, which are the same as the locations in the SP.
4722
+ # @eg. sp_idx = [4, 2, 1, 3, 10]
4723
+ # @eg sorted_sp_idx = [ 1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
4724
+ sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
4725
+ # search the row ids against the sorted SP indices to find the location of each row in the SP
4726
+ # @eg: updated_indices_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
4727
+ # @eg: 5 is OOB
4728
+ updated_indices_in_sp_idx = torch.searchsorted(
4729
+ sorted_sp_idx, updated_indices_in_sp
4730
+ )
4731
+ # ids not found in the SP produce out-of-bound positions
4732
+ oob_sp_idx = updated_indices_in_sp_idx >= sp_idx.numel()
4733
+ # clamp the out-of-bound positions back in bounds
4734
+ # @eg updated_indices_in_sp_idx=[0, 0, 0, 1, 0, 2, 3, 4, 4]
4735
+ updated_indices_in_sp_idx[oob_sp_idx] = 0
4736
+
4737
+ # torch.searchsorted does not do exact matching,
4738
+ # we only take exact matched rows, where the id is found in SP.
4739
+ # @eg 5 in updated_indices_in_sp is not in sp_idx, but has 4 in updated_indices_in_sp_idx
4740
+ # @eg sorted_sp_idx[updated_indices_in_sp_idx]=[ 1, 1, 1, 2, 1, 3, 4, 10, 10]
4741
+ # @eg exact_match_mask=[ True, False, True, True, False, True, True, False, True]
4742
+ exact_match_mask = (
4743
+ sorted_sp_idx[updated_indices_in_sp_idx]
4744
+ == updated_indices_in_sp
4745
+ )
4746
+ # Get the location of the row ids found in SP.
4747
+ # @eg: sp_locations_found=[2, 2, 1, 3, 0, 4]
4748
+ sp_locations_found = sp_idx_inverse_indices[
4749
+ updated_indices_in_sp_idx[exact_match_mask]
4750
+ ]
4751
+ # Get the corresponding weights for the matched indices
4752
+ matched_weights = weights[exact_match_mask]
4753
+
4754
+ # Write the weights to the sparse tensor at the found locations
4755
+ if sp_locations_found.numel() > 0:
4756
+ sp.index_put_(
4757
+ (sp_locations_found,),
4758
+ matched_weights,
4759
+ accumulate=False,
4760
+ )
4761
+ current_stream.record_event(self.direct_write_sp_complete_event)
4762
+
4763
+ # step 3: write l1 cache missing rows to backend
4764
+ # Wait for step 2 to complete
4765
+ with record_function(
4766
+ f"## direct_write_to_backend_{step}_{self.tbe_unique_id} ##"
4767
+ ):
4768
+ # Use the existing ssd_eviction_stream for all backend write operations
4769
+ # This stream is already created with low priority during initialization
4770
+ with torch.cuda.stream(self.ssd_eviction_stream):
4771
+ # Create a mask for indices not in L1 cache
4772
+ non_cache_mask = ~cache_location_mask
4773
+
4774
+ # Calculate the count of valid indices (those not in L1 cache)
4775
+ valid_count = non_cache_mask.sum().to(torch.int64).cpu()
4776
+
4777
+ if valid_count.item() > 0:
4778
+ # Extract only the indices and weights that are not in L1 cache
4779
+ non_cache_indices = linear_cache_indices[non_cache_mask]
4780
+ non_cache_weights = weights[non_cache_mask]
4781
+
4782
+ # Move tensors to CPU for set_cuda
4783
+ cpu_indices = non_cache_indices.cpu()
4784
+ cpu_weights = non_cache_weights.cpu()
4785
+
4786
+ # Write to backend - only sending the non-cache indices and weights
4787
+ self.record_function_via_dummy_profile(
4788
+ f"## ssd_write_{step}_set_cuda_{self.tbe_unique_id} ##",
4789
+ self.ssd_db.set_cuda,
4790
+ cpu_indices,
4791
+ cpu_weights,
4792
+ valid_count,
4793
+ self.timestep,
4794
+ is_bwd=False,
4795
+ )
4796
+
4797
+ # Return control to the main stream without waiting for the backend operation to complete
4798
+
4799
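Step 1 of direct_write_embedding above updates only the rows already resident in L1: look up the cache slots, mask out the misses (location -1), and overwrite the hits with a single index_put_. The sketch below shows that masked-update pattern on a toy cache; the toy tensors stand in for torch.ops.fbgemm.lxu_cache_lookup and lxu_cache_weights, which are not exercised here.

# Minimal sketch of the masked L1 update in step 1, with a toy cache.
import torch

cache = torch.zeros(8, 4)                       # toy L1 cache: 8 slots, dim 4
locations = torch.tensor([3, -1, 5, -1])        # -1 => row not resident in L1
weights = torch.arange(16, dtype=torch.float32).view(4, 4)

hit_mask = locations >= 0
hit_slots = locations[hit_mask]
if hit_slots.numel() > 0:
    # Overwrite (not accumulate) the resident rows with the incoming weights.
    cache.index_put_((hit_slots,), weights[hit_mask], accumulate=False)

print(cache[3], cache[5])   # rows 0 and 2 of `weights`; missed slots stay untouched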
+ def get_free_cpu_memory_gb(self) -> float:
4800
+ def _get_mem_available() -> float:
4801
+ if sys.platform.startswith("linux"):
4802
+ info = {}
4803
+ with open("/proc/meminfo") as f:
4804
+ for line in f:
4805
+ p = line.split()
4806
+ info[p[0].strip(":").lower()] = int(p[1]) * 1024
4807
+ if "memavailable" in info:
4808
+ # Linux >= 3.14
4809
+ return info["memavailable"]
4810
+ else:
4811
+ return info["memfree"] + info["cached"]
4812
+ else:
4813
+ raise RuntimeError(
4814
+ "Unsupported platform for free memory eviction, pls use ID count eviction tirgger mode"
4815
+ )
4816
+
4817
+ mem = _get_mem_available()
4818
+ return mem / (1024**3)
4819
+
4820
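get_free_cpu_memory_gb parses /proc/meminfo, preferring MemAvailable (Linux >= 3.14) and otherwise summing MemFree and Cached, then converts bytes to GB. The snippet below applies the same parsing to a hard-coded sample so it runs on any platform; the numbers are purely illustrative.

# Parses a hard-coded /proc/meminfo sample the same way as _get_mem_available.
SAMPLE = """MemTotal:       65536000 kB
MemFree:         1024000 kB
MemAvailable:   32768000 kB
Cached:          8192000 kB
"""

info = {}
for line in SAMPLE.splitlines():
    parts = line.split()
    info[parts[0].strip(":").lower()] = int(parts[1]) * 1024  # kB -> bytes

available = info.get("memavailable", info.get("memfree", 0) + info.get("cached", 0))
print(available / (1024 ** 3))  # ~31.25 GB for this sample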
+ @classmethod
4821
+ def trigger_evict_in_all_tbes(cls) -> None:
4822
+ for tbe in cls._all_tbe_instances:
4823
+ tbe.ssd_db.trigger_feature_evict()
4824
+
4825
+ @classmethod
4826
+ def tbe_has_ongoing_eviction(cls) -> bool:
4827
+ for tbe in cls._all_tbe_instances:
4828
+ if tbe.ssd_db.is_evicting():
4829
+ return True
4830
+ return False
4831
+
4832
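trigger_evict_in_all_tbes and tbe_has_ongoing_eviction fan out over a class-level _all_tbe_instances registry, and may_trigger_eviction further below uses _first_instance_ref so only one instance performs the free-memory check. The sketch below shows one way such a registry can be wired, with stand-in classes; how fbgemm_gpu actually populates _all_tbe_instances and _first_instance_ref is not shown in this diff, so the wiring here is an assumption.

# Hedged sketch of the class-level instance registry implied by
# _all_tbe_instances / _first_instance_ref (stand-in classes, assumed wiring).
import weakref

class FakeBackend:
    def __init__(self):
        self._evicting = False
    def trigger_feature_evict(self):
        self._evicting = True
    def is_evicting(self):
        return self._evicting

class FakeTBE:
    _all_tbe_instances = []
    _first_instance_ref = None

    def __init__(self):
        self.ssd_db = FakeBackend()
        FakeTBE._all_tbe_instances.append(self)
        if FakeTBE._first_instance_ref is None:
            FakeTBE._first_instance_ref = weakref.ref(self)

    @classmethod
    def trigger_evict_in_all_tbes(cls):
        for tbe in cls._all_tbe_instances:
            tbe.ssd_db.trigger_feature_evict()

    @classmethod
    def tbe_has_ongoing_eviction(cls):
        return any(tbe.ssd_db.is_evicting() for tbe in cls._all_tbe_instances)

a, b = FakeTBE(), FakeTBE()
FakeTBE.trigger_evict_in_all_tbes()
print(FakeTBE.tbe_has_ongoing_eviction())  # True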
+ def set_free_mem_eviction_trigger_config(
4833
+ self, eviction_policy: EvictionPolicy
4834
+ ) -> None:
4835
+ self.enable_free_mem_trigger_eviction = True
4836
+ self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode
4837
+ assert (
4838
+ eviction_policy.eviction_free_mem_check_interval_batch is not None
4839
+ ), "eviction_free_mem_check_interval_batch is unexpected none for free_mem eviction trigger mode"
4840
+ self.eviction_free_mem_check_interval_batch: int = (
4841
+ eviction_policy.eviction_free_mem_check_interval_batch
4842
+ )
4843
+ assert (
4844
+ eviction_policy.eviction_free_mem_threshold_gb is not None
4845
+ ), "eviction_policy.eviction_free_mem_threshold_gb is unexpected none for free_mem eviction trigger mode"
4846
+ self.eviction_free_mem_threshold_gb: int = (
4847
+ eviction_policy.eviction_free_mem_threshold_gb
4848
+ )
4849
+ logging.info(
4850
+ f"[FREE_MEM Eviction] eviction config, trigger model: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}"
4851
+ )
4852
+
4853
+ def may_trigger_eviction(self) -> None:
4854
+ def is_first_tbe() -> bool:
4855
+ first = SSDTableBatchedEmbeddingBags._first_instance_ref
4856
+ return first is not None and first() is self
4857
+
4858
+ # We assume that the eviction time is less than free mem check interval time
4859
+ # So every time we reach this check, all evictions in all tbes should be finished.
4860
+ # We only need to check the first tbe because all tbes share the same free mem,
4861
+ # once the first tbe detects that eviction is needed, it will call the trigger func
4862
+ # in all tbes from _all_tbe_instances
4863
+ if (
4864
+ self.enable_free_mem_trigger_eviction
4865
+ and self.step % self.eviction_free_mem_check_interval_batch == 0
4866
+ and self.training
4867
+ and is_first_tbe()
4868
+ ):
4869
+ if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction():
4870
+ SSDTableBatchedEmbeddingBags._eviction_triggered = False
4871
+
4872
+ free_cpu_mem_gb = self.get_free_cpu_memory_gb()
4873
+ local_evict_trigger = int(
4874
+ free_cpu_mem_gb < self.eviction_free_mem_threshold_gb
4875
+ )
4876
+ tensor_flag = torch.tensor(
4877
+ local_evict_trigger,
4878
+ device=self.current_device,
4879
+ dtype=torch.int,
4880
+ )
4881
+ world_size = dist.get_world_size(self._pg)
4882
+ if world_size > 1:
4883
+ dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg)
4884
+ global_evict_trigger = tensor_flag.item()
4885
+ else:
4886
+ global_evict_trigger = local_evict_trigger
4887
+ if (
4888
+ global_evict_trigger >= 1
4889
+ and SSDTableBatchedEmbeddingBags._eviction_triggered
4890
+ ):
4891
+ logging.warning(
4892
+ f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true"
4893
+ )
4894
+ if (
4895
+ global_evict_trigger >= 1
4896
+ and not SSDTableBatchedEmbeddingBags._eviction_triggered
4897
+ ):
4898
+ SSDTableBatchedEmbeddingBags._eviction_triggered = True
4899
+ SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes()
4900
+ logging.info(
4901
+ f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction"
4902
+ )
4903
+
4904
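may_trigger_eviction turns each rank's free-memory check into a collective decision: every rank contributes a 0/1 flag, an all_reduce SUM shares the count, and eviction is triggered once if any rank is below the threshold. The single-process sketch below shows that consensus step with the gloo backend (world_size=1); the address/port setup and the memory numbers are illustrative only.

# Single-process illustration of the 0/1 all_reduce consensus used above.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29510")
dist.init_process_group("gloo", rank=0, world_size=1)

free_mem_gb, threshold_gb = 10.0, 32.0
local_flag = torch.tensor(int(free_mem_gb < threshold_gb))
if dist.get_world_size() > 1:
    dist.all_reduce(local_flag, op=dist.ReduceOp.SUM)  # count of ranks under threshold
ranks_under_threshold = int(local_flag.item())

if ranks_under_threshold >= 1:
    print(f"{ranks_under_threshold} rank(s) below threshold -> trigger eviction once")

dist.destroy_process_group()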
+ def reset_inference_mode(self) -> None:
4905
+ """
4906
+ Reset the inference mode
4907
+ """
4908
+ self.eval()