fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/tbe/ssd/training.py
CHANGED
@@ -12,13 +12,15 @@ import contextlib
 import functools
 import itertools
 import logging
+import math
 import os
-import tempfile
 import threading
 import time
+from functools import cached_property
 from math import floor, log2
-from typing import Any, Callable,
+from typing import Any, Callable, ClassVar, Optional, Union
 import torch  # usort:skip
+import weakref

 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -29,9 +31,13 @@ from fbgemm_gpu.runtime_monitor import (
 )
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    BackendType,
     BoundsCheckMode,
     CacheAlgorithm,
     EmbeddingLocation,
+    EvictionPolicy,
+    get_bounds_check_version_for_platform,
+    KVZCHParams,
     PoolingMode,
     SplitState,
 )
@@ -39,21 +45,23 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     apply_split_helper,
     CounterBasedRegularizationDefinition,
     CowClipDefinition,
+    RESParams,
     UVMCacheStatsIndex,
     WeightDecayMode,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    check_allocated_vbe_output,
     generate_vbe_metadata,
+    is_torchdynamo_compiling,
 )
-
 from torch import distributed as dist, nn, Tensor  # usort:skip
+import sys
 from dataclasses import dataclass

 from torch.autograd.profiler import record_function

 from ..cache import get_unique_indices_v2
-
-from .common import ASSOC
+from .common import ASSOC, pad4, tensor_pad4
 from .utils.partially_materialized_tensor import PartiallyMaterializedTensor


@@ -69,6 +77,14 @@ class IterData:
     max_B: Optional[int] = -1


+@dataclass
+class KVZCHCachedData:
+    cached_optimizer_states_per_table: list[list[torch.Tensor]]
+    cached_weight_tensor_per_table: list[torch.Tensor]
+    cached_id_tensor_per_table: list[torch.Tensor]
+    cached_bucket_splits: list[torch.Tensor]
+
+
 class SSDTableBatchedEmbeddingBags(nn.Module):
     D_offsets: Tensor
     lxu_cache_weights: Tensor
@@ -86,12 +102,19 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     weights_placements: Tensor
     weights_offsets: Tensor
     _local_instance_index: int = -1
+    res_params: RESParams
+    table_names: list[str]
+    _all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet()
+    _first_instance_ref: ClassVar[weakref.ref] = None
+    _eviction_triggered: ClassVar[bool] = False

     def __init__(
         self,
-        embedding_specs:
-        feature_table_map: Optional[
+        embedding_specs: list[tuple[int, int]],  # tuple of (rows, dims)
+        feature_table_map: Optional[list[int]],  # [T]
         cache_sets: int,
+        # A comma-separated string, e.g. "/data00_nvidia0,/data01_nvidia0/", db shards
+        # will be placed in these paths round-robin.
         ssd_storage_directory: str,
         ssd_rocksdb_shards: int = 1,
         ssd_memtable_flush_period: int = -1,
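The new class attributes `_all_tbe_instances` (a `weakref.WeakSet`) and `_first_instance_ref` (a `weakref.ref`) let the class track every live TBE instance without pinning them in memory. A minimal, self-contained sketch of that pattern (the `Registry` class below is illustrative only, not part of fbgemm_gpu):

import gc
import weakref


class Registry:
    # Weak references: entries vanish once an instance is garbage collected,
    # so the class-level registry never keeps instances alive on its own.
    _instances: "weakref.WeakSet" = weakref.WeakSet()
    _first_ref = None  # weakref.ref to the first instance ever created

    def __init__(self, name: str) -> None:
        self.name = name
        Registry._instances.add(self)
        if Registry._first_ref is None:
            Registry._first_ref = weakref.ref(self)


a = Registry("a")
b = Registry("b")
print(len(Registry._instances))   # 2
del b
gc.collect()
print(len(Registry._instances))   # 1 on CPython: the dead instance dropped out
first = Registry._first_ref() if Registry._first_ref else None
print(first is a)                 # True while `a` is still alive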
@@ -131,13 +154,16 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         pooling_mode: PoolingMode = PoolingMode.SUM,
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
         # Parameter Server Configs
-        ps_hosts: Optional[
+        ps_hosts: Optional[tuple[tuple[str, int]]] = None,
         ps_max_key_per_request: Optional[int] = None,
         ps_client_thread_num: Optional[int] = None,
         ps_max_local_index_length: Optional[int] = None,
         tbe_unique_id: int = -1,
-        #
-        #
+        # If set to True, will use `ssd_storage_directory` as the ssd paths.
+        # If set to False, will use the default ssd paths.
+        # In local test we need to use the pass in path for rocksdb creation
+        # fn production we could either use the default ssd mount points or explicity specify ssd
+        # mount points using `ssd_storage_directory`.
         use_passed_in_path: int = True,
         gather_ssd_cache_stats: Optional[bool] = False,
         stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
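Per the new parameter comments, `ssd_storage_directory` may be a comma-separated list of mount points, and RocksDB shards are placed across them round-robin when `use_passed_in_path` is set. A small sketch of that placement idea (illustrative only; the real shard-to-path mapping happens inside `EmbeddingRocksDBWrapper`):

import os


def assign_shard_paths(ssd_storage_directory: str, num_shards: int) -> list[str]:
    """Split a comma-separated path list and assign shards round-robin."""
    paths = [p for p in ssd_storage_directory.split(",") if p]
    for p in paths:
        os.makedirs(p, exist_ok=True)  # mirrors the use_passed_in_path branch
    return [paths[shard % len(paths)] for shard in range(num_shards)]


# Example with relative directories (production would pass NVMe mount points):
print(assign_shard_paths("data00_nvidia0,data01_nvidia0", 8))
# ['data00_nvidia0', 'data01_nvidia0', 'data00_nvidia0', ...]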
@@ -152,19 +178,126 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         # number of rows will be decided by bulk_init_chunk_size / size_of_each_row
         bulk_init_chunk_size: int = 0,
         lazy_bulk_init_enabled: bool = False,
+        backend_type: BackendType = BackendType.SSD,
+        kv_zch_params: Optional[KVZCHParams] = None,
+        enable_raw_embedding_streaming: bool = False,  # whether enable raw embedding streaming
+        res_params: Optional[RESParams] = None,  # raw embedding streaming sharding info
+        flushing_block_size: int = 2_000_000_000,  # 2GB
+        table_names: Optional[list[str]] = None,
+        use_rowwise_bias_correction: bool = False,  # For Adam use
+        optimizer_state_dtypes: dict[str, SparseType] = {},  # noqa: B006
+        pg: Optional[dist.ProcessGroup] = None,
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()

+        # Set the optimizer
+        assert optimizer in (
+            OptimType.EXACT_ROWWISE_ADAGRAD,
+            OptimType.PARTIAL_ROWWISE_ADAM,
+            OptimType.ADAM,
+        ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
+        self.optimizer = optimizer
+
+        # Set the table weight and output dtypes
+        assert weights_precision in (SparseType.FP32, SparseType.FP16)
+        self.weights_precision = weights_precision
+        self.output_dtype: int = output_dtype.as_int()
+
+        if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
+            # Adagrad currently only supports FP32 for momentum1
+            self.optimizer_state_dtypes: dict[str, SparseType] = {
+                "momentum1": SparseType.FP32,
+            }
+        else:
+            self.optimizer_state_dtypes: dict[str, SparseType] = optimizer_state_dtypes
+
+        # Zero collision TBE configurations
+        self.kv_zch_params = kv_zch_params
+        self.backend_type = backend_type
+        self.enable_optimizer_offloading: bool = False
+        self.backend_return_whole_row: bool = False
+        self._embedding_cache_mode: bool = False
+        self.load_ckpt_without_opt: bool = False
+        if self.kv_zch_params:
+            self.kv_zch_params.validate()
+            self.load_ckpt_without_opt = (
+                # pyre-ignore [16]
+                self.kv_zch_params.load_ckpt_without_opt
+            )
+            self.enable_optimizer_offloading = (
+                # pyre-ignore [16]
+                self.kv_zch_params.enable_optimizer_offloading
+            )
+            self.backend_return_whole_row = (
+                # pyre-ignore [16]
+                self.kv_zch_params.backend_return_whole_row
+            )
+
+            if self.enable_optimizer_offloading:
+                logging.info("Optimizer state offloading is enabled")
+            if self.backend_return_whole_row:
+                assert (
+                    self.backend_type == BackendType.DRAM
+                ), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
+                logging.info(
+                    "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
+                )
+            # pyre-ignore [16]
+            self._embedding_cache_mode = self.kv_zch_params.embedding_cache_mode
+            if self._embedding_cache_mode:
+                logging.info("KVZCH is in embedding_cache_mode")
+                assert self.optimizer in [
+                    OptimType.EXACT_ROWWISE_ADAGRAD
+                ], f"only EXACT_ROWWISE_ADAGRAD supports embedding cache mode, but got {self.optimizer}"
+            if self.load_ckpt_without_opt:
+                if (
+                    # pyre-ignore [16]
+                    self.kv_zch_params.optimizer_type_for_st
+                    == OptimType.PARTIAL_ROWWISE_ADAM.value
+                ):
+                    self.optimizer = OptimType.PARTIAL_ROWWISE_ADAM
+                    logging.info(
+                        f"Override optimizer type with {self.optimizer=} for st publish"
+                    )
+                if (
+                    # pyre-ignore [16]
+                    self.kv_zch_params.optimizer_state_dtypes_for_st
+                    is not None
+                ):
+                    optimizer_state_dtypes = {}
+                    for k, v in dict(
+                        self.kv_zch_params.optimizer_state_dtypes_for_st
+                    ).items():
+                        optimizer_state_dtypes[k] = SparseType.from_int(v)
+                    self.optimizer_state_dtypes = optimizer_state_dtypes
+                    logging.info(
+                        f"Override optimizer_state_dtypes with {self.optimizer_state_dtypes=} for st publish"
+                    )
+
         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self.embedding_specs = embedding_specs
-
+        self.table_names = table_names if table_names is not None else []
+        rows, dims = zip(*embedding_specs)
         T_ = len(self.embedding_specs)
         assert T_ > 0
         # pyre-fixme[8]: Attribute has type `device`; used as `int`.
         self.current_device: torch.device = torch.cuda.current_device()

-        self.
+        self.enable_raw_embedding_streaming = enable_raw_embedding_streaming
+        # initialize the raw embedding streaming related variables
+        self.res_params: RESParams = res_params or RESParams()
+        if self.enable_raw_embedding_streaming:
+            self.res_params.table_sizes = [0] + list(itertools.accumulate(rows))
+            res_port_from_env = os.getenv("LOCAL_RES_PORT")
+            self.res_params.res_server_port = (
+                int(res_port_from_env) if res_port_from_env else 0
+            )
+            logging.info(
+                f"get env {self.res_params.res_server_port=}, at rank {dist.get_rank()}, with {self.res_params=}"
+            )
+
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
         T = len(self.feature_table_map)
@@ -177,7 +310,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         feature_dims = [dims[t] for t in self.feature_table_map]
         D_offsets = [dims[t] for t in self.feature_table_map]
         D_offsets = [0] + list(itertools.accumulate(D_offsets))
+
+        # Sum of row length of all tables
         self.total_D: int = D_offsets[-1]
+
+        # Max number of elements required to store a row in the cache
         self.max_D: int = max(dims)
         self.register_buffer(
             "D_offsets",
@@ -189,6 +326,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self.total_hash_size_bits: int = 0
         else:
             self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
+        self.register_buffer(
+            "table_hash_size_cumsum",
+            torch.tensor(
+                hash_size_cumsum, device=self.current_device, dtype=torch.int64
+            ),
+        )
         # The last element is to easily access # of rows of each table by
         self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
         self.total_hash_size: int = hash_size_cumsum[-1]
@@ -229,13 +372,25 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             "feature_dims",
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
+        self.register_buffer(
+            "table_dims",
+            torch.tensor(dims, device="cpu", dtype=torch.int64),
+        )
+
+        info_B_num_bits_, info_B_mask_ = torch.ops.fbgemm.get_infos_metadata(
+            self.D_offsets,  # unused tensor
+            1,  # max_B
+            T,  # T
+        )
+        self.info_B_num_bits: int = info_B_num_bits_
+        self.info_B_mask: int = info_B_mask_

         assert cache_sets > 0
         element_size = weights_precision.bit_rate() // 8
         assert (
             element_size == 4 or element_size == 2
         ), f"Invalid element size {element_size}"
-        cache_size = cache_sets * ASSOC * element_size * self.
+        cache_size = cache_sets * ASSOC * element_size * self.cache_row_dim
         logging.info(
             f"Using cache for SSD with admission algorithm "
             f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_rocksdb_shards} shards, "
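The L1 cache footprint logged here is `cache_sets * ASSOC * element_size * cache_row_dim`. A back-of-the-envelope check, assuming a 32-way set associativity for `ASSOC` (an assumption, not stated in this diff) and FP16 weights:

ASSOC = 32              # assumed cache associativity (ways per set)
cache_sets = 1024
element_size = 2        # FP16 weights -> 2 bytes per element
cache_row_dim = 128     # padded row width in elements

cache_size = cache_sets * ASSOC * element_size * cache_row_dim
print(f"{cache_size / 1024**3:.2f} GB")  # 0.01 GB (8 MiB) for this toy config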
@@ -243,10 +398,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             f"Memtable Flush Period: {ssd_memtable_flush_period}, "
             f"Memtable Flush Offset: {ssd_memtable_flush_offset}, "
             f"Desired L0 files per compaction: {ssd_l0_files_per_compact}, "
-            f"{cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
+            f"Cache size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
-            f"chunk size in bulk init: {bulk_init_chunk_size} bytes"
+            f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
+            f"kv_zch_params: {kv_zch_params}, "
+            f"embedding spec: {embedding_specs}"
         )
         self.register_buffer(
             "lxu_cache_state",
@@ -262,6 +419,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         )

         self.step = 0
+        self.last_flush_step = -1

         # Set prefetch pipeline
         self.prefetch_pipeline: bool = prefetch_pipeline
@@ -291,10 +449,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             EmbeddingLocation.DEVICE,
         )

-        assert weights_precision in (SparseType.FP32, SparseType.FP16)
-        self.weights_precision = weights_precision
-        self.output_dtype: int = output_dtype.as_int()
-
         cache_dtype = weights_precision.as_dtype()
         if ssd_cache_location == EmbeddingLocation.MANAGED:
             self.register_buffer(
@@ -305,7 +459,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                         device=self.current_device,
                         dtype=cache_dtype,
                     ),
-                    [cache_sets * ASSOC, self.
+                    [cache_sets * ASSOC, self.cache_row_dim],
                     is_host_mapped=self.uvm_host_mapped,
                 ),
             )
@@ -314,7 +468,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 "lxu_cache_weights",
                 torch.zeros(
                     cache_sets * ASSOC,
-                    self.
+                    self.cache_row_dim,
                     device=self.current_device,
                     dtype=cache_dtype,
                 ),
@@ -387,6 +541,15 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

         self.timestep = 0

+        # Store the iteration number on GPU and CPU (used for certain optimizers)
+        persistent_iter_ = optimizer in (OptimType.PARTIAL_ROWWISE_ADAM,)
+        self.register_buffer(
+            "iter",
+            torch.zeros(1, dtype=torch.int64, device=self.current_device),
+            persistent=persistent_iter_,
+        )
+        self.iter_cpu: torch.Tensor = torch.zeros(1, dtype=torch.int64, device="cpu")
+
         # Dummy profile configuration for measuring the SSD get/set time
         # get and set are executed by another thread which (for some reason) is
         # not traceable by PyTorch's Kineto. We workaround this problem by
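The new `iter` buffer is registered with `persistent=persistent_iter_`, so the iteration counter only lands in the module's `state_dict` for optimizers that actually need it across checkpoints (here `PARTIAL_ROWWISE_ADAM`). A small standalone example of how the `persistent` flag behaves in PyTorch:

import torch
from torch import nn


class Counter(nn.Module):
    def __init__(self, persist: bool) -> None:
        super().__init__()
        # persistent=False keeps the buffer on the module but out of state_dict
        self.register_buffer(
            "iter", torch.zeros(1, dtype=torch.int64), persistent=persist
        )


print("iter" in Counter(persist=True).state_dict())   # True  -> saved/restored
print("iter" in Counter(persist=False).state_dict())  # False -> recreated at init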
@@ -405,18 +568,46 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             f"FBGEMM_SSD_TBE_USE_DUMMY_PROFILE is set to {set_dummy_profile}; "
             f"Use dummy profile: {use_dummy_profile}"
         )
-
+
         self.record_function_via_dummy_profile: Callable[..., Any] = (
             self.record_function_via_dummy_profile_factory(use_dummy_profile)
         )

-
+        if use_passed_in_path:
+            ssd_dir_list = ssd_storage_directory.split(",")
+            for ssd_dir in ssd_dir_list:
+                os.makedirs(ssd_dir, exist_ok=True)

-        ssd_directory =
-            prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
-        )
+        ssd_directory = ssd_storage_directory
         # logging.info("DEBUG: weights_precision {}".format(weights_precision))

+        """
+        ##################### for ZCH v.Next loading checkpoints Short Term Solution #######################
+        weight_id tensor is the weight and optimizer keys, to load from checkpoint, weight_id tensor
+        needs to be loaded first, then we can load the weight and optimizer tensors.
+        However, the stateful checkpoint loading does not guarantee the tensor loading order, so we need
+        to cache the weight_id, weight and optimizer tensors untils all data are loaded, then we can apply
+        them to backend.
+        Currently, we'll cache the weight_id, weight and optimizer tensors in the KVZCHCachedData class,
+        and apply them to backend when all data are loaded. The downside of this solution is that we'll
+        have to duplicate a whole tensor memory to backend before we can release the python tensor memory,
+        which is not ideal.
+        The longer term solution is to support the caching from the backend side, and allow streaming based
+        data move from cached weight and optimizer to key/value format without duplicate one whole tensor's
+        memory.
+        """
+        self._cached_kvzch_data: Optional[KVZCHCachedData] = None
+        # initial embedding rows on this rank per table, this is used for loading checkpoint
+        self.local_weight_counts: list[int] = [0] * T_
+        # groundtruth global id on this rank per table, this is used for loading checkpoint
+        self.global_id_per_rank: list[torch.Tensor] = [torch.zeros(0)] * T_
+        # loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend
+        self.load_state_dict: bool = False
+
+        SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self)
+        if SSDTableBatchedEmbeddingBags._first_instance_ref is None:
+            SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self)
+
         # create tbe unique id using rank index | local tbe idx
         if tbe_unique_id == -1:
             SSDTableBatchedEmbeddingBags._local_instance_index += 1
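The block comment above describes the short-term checkpoint-loading flow for KV ZCH: per-table id, weight, and optimizer tensors are buffered in `KVZCHCachedData` because stateful checkpoint loading gives no ordering guarantee, and everything is pushed to the backend only once all pieces are present. A rough sketch of that idea; the `backend.set_kv` call and the `CachedCheckpointShard` type are hypothetical stand-ins, not APIs from this package:

from dataclasses import dataclass, field

import torch


@dataclass
class CachedCheckpointShard:
    ids: torch.Tensor | None = None
    weights: torch.Tensor | None = None
    optimizer_states: list[torch.Tensor] = field(default_factory=list)

    def complete(self) -> bool:
        return self.ids is not None and self.weights is not None


def maybe_apply(shards: dict[int, CachedCheckpointShard], backend) -> None:
    # Only flush once every table has both its id and weight tensors loaded,
    # since the checkpoint loader may deliver them in any order.
    if all(s.complete() for s in shards.values()):
        for table, shard in shards.items():
            # backend.set_kv(...) is a stand-in for the real backend write path.
            backend.set_kv(table, shard.ids, shard.weights, shard.optimizer_states)
        shards.clear()  # release the duplicated host memory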
@@ -432,21 +623,26 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             logging.warning("dist is not initialized, treating as single gpu cases")
             tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
         self.tbe_unique_id = tbe_unique_id
+        self.l2_cache_size = l2_cache_size
         logging.info(f"tbe_unique_id: {tbe_unique_id}")
-
+        self.enable_free_mem_trigger_eviction: bool = False
+        if self.backend_type == BackendType.SSD:
             logging.info(
-                f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,
-                f"
-                f"
-                f"
-                f"
-                f"
-                f"
-                f"
-                f"
+                f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
+                f"enable_async_update:{enable_async_update}, passed_in_path={ssd_directory}, "
+                f"num_shards={ssd_rocksdb_shards}, num_threads={ssd_rocksdb_shards}, "
+                f"memtable_flush_period={ssd_memtable_flush_period}, memtable_flush_offset={ssd_memtable_flush_offset}, "
+                f"l0_files_per_compact={ssd_l0_files_per_compact}, max_D={self.max_D}, "
+                f"cache_row_size={self.cache_row_dim}, rate_limit_mbps={ssd_rate_limit_mbps}, "
+                f"size_ratio={ssd_size_ratio}, compaction_trigger={ssd_compaction_trigger}, "
+                f"lazy_bulk_init_enabled={lazy_bulk_init_enabled}, write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size}, "
+                f"max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num}, "
+                f"uniform_init_lower={ssd_uniform_init_lower}, uniform_init_upper={ssd_uniform_init_upper}, "
+                f"row_storage_bitwidth={weights_precision.bit_rate()}, block_cache_size_per_tbe={ssd_block_cache_size_per_tbe}, "
+                f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB, "
+                f"enable_raw_embedding_streaming:{self.enable_raw_embedding_streaming}, flushing_block_size:{flushing_block_size}"
             )
             # pyre-fixme[4]: Attribute must be annotated.
-            # pyre-ignore[16]
             self._ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
                 ssd_directory,
                 ssd_rocksdb_shards,
@@ -454,7 +650,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 ssd_memtable_flush_period,
                 ssd_memtable_flush_offset,
                 ssd_l0_files_per_compact,
-                self.
+                self.cache_row_dim,
                 ssd_rate_limit_mbps,
                 ssd_size_ratio,
                 ssd_compaction_trigger,
@@ -468,6 +664,24 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 tbe_unique_id,
                 l2_cache_size,
                 enable_async_update,
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
+                (
+                    tensor_pad4(self.table_dims)
+                    if self.enable_optimizer_offloading
+                    else None
+                ),
+                (
+                    self.table_hash_size_cumsum.cpu()
+                    if self.enable_optimizer_offloading
+                    else None
+                ),
+                flushing_block_size,
+                self._embedding_cache_mode,  # disable_random_init
             )
             if self.bulk_init_chunk_size > 0:
                 self.ssd_uniform_init_lower: float = ssd_uniform_init_lower
@@ -476,11 +690,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                     self._lazy_initialize_ssd_tbe()
                 else:
                     self._insert_all_kv()
-
-        # pyre-fixme[4]: Attribute must be annotated.
-        # pyre-ignore[16]
+        elif self.backend_type == BackendType.PS:
             self._ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
-                [host[0] for host in ps_hosts],
+                [host[0] for host in ps_hosts],  # pyre-ignore
                 [host[1] for host in ps_hosts],
                 tbe_unique_id,
                 (
@@ -491,14 +703,98 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 ps_client_thread_num if ps_client_thread_num is not None else 32,
                 ps_max_key_per_request if ps_max_key_per_request is not None else 500,
                 l2_cache_size,
-                self.
+                self.cache_row_dim,
+            )
+        elif self.backend_type == BackendType.DRAM:
+            logging.info(
+                f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
+                f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
+                f"max_D={self.max_D},"
+                f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
+                f"row_storage_bitwidth={weights_precision.bit_rate()},"
+                f"self.cache_row_dim={self.cache_row_dim},"
+                f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
+                f"feature_dims={self.feature_dims},"
+                f"hash_size_cumsum={self.hash_size_cumsum},"
+                f"backend_return_whole_row={self.backend_return_whole_row}"
             )
+            table_dims = (
+                tensor_pad4(self.table_dims)
+                if self.enable_optimizer_offloading
+                else None
+            )  # table_dims
+            eviction_config = None
+            if self.kv_zch_params and self.kv_zch_params.eviction_policy:
+                eviction_mem_threshold_gb = (
+                    self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
+                    if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
+                    else self.l2_cache_size
+                )
+                kv_zch_params = self.kv_zch_params
+                eviction_policy = self.kv_zch_params.eviction_policy
+                if eviction_policy.eviction_trigger_mode == 5:
+                    # If trigger mode is free_mem(5), populate config
+                    self.set_free_mem_eviction_trigger_config(eviction_policy)
+
+                enable_eviction_for_feature_score_eviction_policy = (  # pytorch api in c++ doesn't support vertor<bool>, convert to int here, 0: no eviction 1: eviction
+                    [
+                        int(x)
+                        for x in eviction_policy.enable_eviction_for_feature_score_eviction_policy
+                    ]
+                    if eviction_policy.enable_eviction_for_feature_score_eviction_policy
+                    is not None
+                    else None
+                )
+                # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
+                eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
+                    eviction_policy.eviction_trigger_mode,  # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
+                    eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
+                    eviction_policy.eviction_step_intervals,  # trigger_step_interval if trigger mode is iteration
+                    eviction_mem_threshold_gb,  # mem_util_threshold_in_GB if trigger mode is mem_util
+                    eviction_policy.ttls_in_mins,  # ttls_in_mins for each table if eviction strategy is timestamp
+                    eviction_policy.counter_thresholds,  # counter_thresholds for each table if eviction strategy is counter
+                    eviction_policy.counter_decay_rates,  # counter_decay_rates for each table if eviction strategy is counter
+                    eviction_policy.feature_score_counter_decay_rates,  # feature_score_counter_decay_rates for each table if eviction strategy is feature score
+                    eviction_policy.training_id_eviction_trigger_count,  # training_id_eviction_trigger_count for each table
+                    eviction_policy.training_id_keep_count,  # training_id_keep_count for each table
+                    enable_eviction_for_feature_score_eviction_policy,  # no eviction setting for feature score eviction policy
+                    eviction_policy.l2_weight_thresholds,  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
+                    table_dims.tolist() if table_dims is not None else None,
+                    eviction_policy.threshold_calculation_bucket_stride,  # threshold_calculation_bucket_stride if eviction strategy is feature score
+                    eviction_policy.threshold_calculation_bucket_num,  # threshold_calculation_bucket_num if eviction strategy is feature score
+                    eviction_policy.interval_for_insufficient_eviction_s,
+                    eviction_policy.interval_for_sufficient_eviction_s,
+                    eviction_policy.interval_for_feature_statistics_decay_s,
+                )
+            self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
+                self.cache_row_dim,
+                ssd_uniform_init_lower,
+                ssd_uniform_init_upper,
+                eviction_config,
+                ssd_rocksdb_shards,  # num_shards
+                ssd_rocksdb_shards,  # num_threads
+                weights_precision.bit_rate(),  # row_storage_bitwidth
+                table_dims,
+                (
+                    self.table_hash_size_cumsum.cpu()
+                    if self.enable_optimizer_offloading
+                    else None
+                ),  # hash_size_cumsum
+                self.backend_return_whole_row,  # backend_return_whole_row
+                False,  # enable_async_update
+                self._embedding_cache_mode,  # disable_random_init
+            )
+        else:
+            raise AssertionError(f"Invalid backend type {self.backend_type}")

         # pyre-fixme[20]: Argument `self` expected.
-
+        low_priority, high_priority = torch.cuda.Stream.priority_range()
         # GPU stream for SSD cache eviction
         self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
-        # GPU stream for SSD memory copy
+        # GPU stream for SSD memory copy (also reused for feature score D2H)
         self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
+        # GPU stream for async metadata operation
+        self.feature_score_stream = torch.cuda.Stream(priority=low_priority)

         # SSD get completion event
         self.ssd_event_get = torch.cuda.Event()
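The inline comments on `FeatureEvictConfig` enumerate the trigger modes (0 disabled, 1 iteration, 2 mem_util, 3 manual, 4 id count, 5 free_mem) and strategies (0 timestamp, 1 counter, 2 counter + timestamp, 3 feature L2 norm, 4 timestamp threshold, 5 feature score) as raw integers. A small readability aid, with names chosen here for illustration only; the wrapper itself takes plain ints:

from enum import IntEnum


class EvictionTriggerMode(IntEnum):
    DISABLED = 0
    ITERATION = 1
    MEM_UTIL = 2
    MANUAL = 3
    ID_COUNT = 4
    FREE_MEM = 5


class EvictionStrategy(IntEnum):
    TIMESTAMP = 0
    COUNTER = 1
    COUNTER_AND_TIMESTAMP = 2
    FEATURE_L2_NORM = 3
    TIMESTAMP_THRESHOLD = 4
    FEATURE_SCORE = 5


# e.g. the free_mem special case handled above:
assert EvictionTriggerMode.FREE_MEM == 5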
@@ -510,26 +806,93 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.ssd_event_backward = torch.cuda.Event()
         # SSD get's input copy completion event
         self.ssd_event_get_inputs_cpy = torch.cuda.Event()
+        if self._embedding_cache_mode:
+            # Direct write embedding completion event
+            self.direct_write_l1_complete_event: torch.cuda.streams.Event = (
+                torch.cuda.Event()
+            )
+            self.direct_write_sp_complete_event: torch.cuda.streams.Event = (
+                torch.cuda.Event()
+            )
+            # Prefetch operation completion event
+            self.prefetch_complete_event = torch.cuda.Event()
+
         if self.prefetch_pipeline:
             # SSD scratch pad index queue insert completion event
             self.ssd_event_sp_idxq_insert: torch.cuda.streams.Event = torch.cuda.Event()
             # SSD scratch pad index queue lookup completion event
             self.ssd_event_sp_idxq_lookup: torch.cuda.streams.Event = torch.cuda.Event()

-        self.
+        if self.enable_raw_embedding_streaming:
+            # RES reuse the eviction stream
+            self.ssd_event_cache_streamed: torch.cuda.streams.Event = torch.cuda.Event()
+            self.ssd_event_cache_streaming_synced: torch.cuda.streams.Event = (
+                torch.cuda.Event()
+            )
+            self.ssd_event_cache_streaming_computed: torch.cuda.streams.Event = (
+                torch.cuda.Event()
+            )
+            self.ssd_event_sp_streamed: torch.cuda.streams.Event = torch.cuda.Event()
+
+            # Updated buffers
+            self.register_buffer(
+                "lxu_cache_updated_weights",
+                torch.ops.fbgemm.new_unified_tensor(
+                    torch.zeros(
+                        1,
+                        device=self.current_device,
+                        dtype=cache_dtype,
+                    ),
+                    self.lxu_cache_weights.shape,
+                    is_host_mapped=self.uvm_host_mapped,
+                ),
+            )
+
+            # For storing embedding indices to update to
+            self.register_buffer(
+                "lxu_cache_updated_indices",
+                torch.ops.fbgemm.new_unified_tensor(
+                    torch.zeros(
+                        1,
+                        device=self.current_device,
+                        dtype=torch.long,
+                    ),
+                    (self.lxu_cache_weights.shape[0],),
+                    is_host_mapped=self.uvm_host_mapped,
+                ),
+            )
+
+            # For storing the number of updated rows
+            self.register_buffer(
+                "lxu_cache_updated_count",
+                torch.ops.fbgemm.new_unified_tensor(
+                    torch.zeros(
+                        1,
+                        device=self.current_device,
+                        dtype=torch.int,
+                    ),
+                    (1,),
+                    is_host_mapped=self.uvm_host_mapped,
+                ),
+            )
+
+            # (Indices, Count)
+            self.prefetched_info: list[tuple[Tensor, Tensor]] = []
+
+        self.timesteps_prefetched: list[int] = []
         # TODO: add type annotation
         # pyre-fixme[4]: Attribute must be annotated.
         self.ssd_prefetch_data = []

         # Scratch pad eviction data queue
-        self.ssd_scratch_pad_eviction_data:
-
+        self.ssd_scratch_pad_eviction_data: list[
+            tuple[Tensor, Tensor, Tensor, bool]
         ] = []
-        self.ssd_location_update_data:
+        self.ssd_location_update_data: list[tuple[Tensor, Tensor]] = []

         if self.prefetch_pipeline:
             # Scratch pad value queue
-            self.ssd_scratch_pads:
+            self.ssd_scratch_pads: list[tuple[Tensor, Tensor, Tensor]] = []

             # pyre-ignore[4]
             # Scratch pad index queue
@@ -549,12 +912,15 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             )
             cowclip_regularization = CowClipDefinition()

+        self.learning_rate_tensor: torch.Tensor = torch.tensor(
+            learning_rate, device=torch.device("cpu"), dtype=torch.float32
+        )
+
         self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs(
             stochastic_rounding=stochastic_rounding,
             gradient_clipping=gradient_clipping,
             max_gradient=max_gradient,
             max_norm=max_norm,
-            learning_rate=learning_rate,
             eps=eps,
             beta1=beta1,
             beta2=beta2,
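The learning rate is no longer baked into `OptimizerArgs`; it now lives in `self.learning_rate_tensor`, a CPU float32 scalar tensor. Keeping it in a tensor means the value can be updated in place between iterations without rebuilding the optimizer args. A minimal illustration of the in-place update (whether the module exposes a dedicated setter is not shown in this diff):

import torch

learning_rate_tensor = torch.tensor(0.01, device="cpu", dtype=torch.float32)

# An LR schedule can mutate the same tensor object every step; any code
# holding a reference to it sees the new value on its next read.
for step in range(1, 4):
    learning_rate_tensor.fill_(0.01 * 0.5**step)
    print(step, learning_rate_tensor.item())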
@@ -575,7 +941,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
             lower_bound=cowclip_regularization.lower_bound,
             regularization_mode=weight_decay_mode.value,
-            use_rowwise_bias_correction=
+            use_rowwise_bias_correction=use_rowwise_bias_correction,  # Used in Adam optimizer
         )

         table_embedding_dtype = weights_precision.as_dtype()
@@ -593,19 +959,14 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             dtype=table_embedding_dtype,
         )

-
-        self.
-
-
-
-
-            placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
-            offsets=momentum1_offsets[:-1],
-        ),
-        "momentum1",
+        # Create the optimizer state tensors
+        for template in self.optimizer.ssd_state_splits(
+            self.embedding_specs,
+            self.optimizer_state_dtypes,
+            self.enable_optimizer_offloading,
+        ):
             # pyre-fixme[6]: For 3rd argument expected `Type[dtype]` but got `dtype`.
-
-        )
+            self._apply_split(*template)

         # For storing current iteration data
         self.current_iter_data: Optional[IterData] = None
@@ -625,11 +986,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self._update_cache_counter_and_pointers
         )

-        assert optimizer in (
-            OptimType.EXACT_ROWWISE_ADAGRAD,
-        ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
-        self.optimizer = optimizer
-
         # stats reporter
         self.gather_ssd_cache_stats = gather_ssd_cache_stats
         self.stats_reporter: Optional[TBEStatsReporter] = (
@@ -638,7 +994,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.ssd_cache_stats_size = 6
         # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
         # 4: N_conflict_unique_misses, 5: N_conflict_misses
-        self.last_reported_ssd_stats:
+        self.last_reported_ssd_stats: list[float] = []
         self.last_reported_step = 0

         self.register_buffer(
@@ -669,7 +1025,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.prefetch_parallel_stream_cnt: int = 2
         # tuple of iteration, prefetch parallel stream cnt, reported duration
         # since there are 2 stream in parallel in prefetch, we want to count the longest one
-        self.prefetch_duration_us:
+        self.prefetch_duration_us: tuple[int, int, float] = (
            -1,
            self.prefetch_parallel_stream_cnt,
            0,
@@ -689,6 +1045,26 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.l2_cache_capacity_stats_name: str = (
             f"l2_cache.mem.tbe_id{tbe_unique_id}.capacity_bytes"
         )
+        self.dram_kv_actual_used_chunk_bytes_stats_name: str = (
+            f"dram_kv.mem.tbe_id{tbe_unique_id}.actual_used_chunk_bytes"
+        )
+        self.dram_kv_allocated_bytes_stats_name: str = (
+            f"dram_kv.mem.tbe_id{tbe_unique_id}.allocated_bytes"
+        )
+        self.dram_kv_mem_num_rows_stats_name: str = (
+            f"dram_kv.mem.tbe_id{tbe_unique_id}.num_rows"
+        )
+
+        self.eviction_sum_evicted_counts_stats_name: str = (
+            f"eviction.tbe_id.{tbe_unique_id}.sum_evicted_counts"
+        )
+        self.eviction_sum_processed_counts_stats_name: str = (
+            f"eviction.tbe_id.{tbe_unique_id}.sum_processed_counts"
+        )
+        self.eviction_evict_rate_stats_name: str = (
+            f"eviction.tbe_id.{tbe_unique_id}.evict_rate"
+        )
+
         if self.stats_reporter:
             self.ssd_prefetch_read_timer = AsyncSeriesTimer(
                 functools.partial(
@@ -708,11 +1084,77 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             )
             # pyre-ignore
             self.stats_reporter.register_stats(self.l2_num_cache_misses_stats_name)
-            # pyre-ignore
             self.stats_reporter.register_stats(self.l2_num_cache_lookups_stats_name)
             self.stats_reporter.register_stats(self.l2_num_cache_evictions_stats_name)
             self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name)
             self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name)
+            self.stats_reporter.register_stats(self.dram_kv_allocated_bytes_stats_name)
+            self.stats_reporter.register_stats(
+                self.dram_kv_actual_used_chunk_bytes_stats_name
+            )
+            self.stats_reporter.register_stats(self.dram_kv_mem_num_rows_stats_name)
+            self.stats_reporter.register_stats(
+                self.eviction_sum_evicted_counts_stats_name
+            )
+            self.stats_reporter.register_stats(
+                self.eviction_sum_processed_counts_stats_name
+            )
+            self.stats_reporter.register_stats(self.eviction_evict_rate_stats_name)
+            for t in self.feature_table_map:
+                self.stats_reporter.register_stats(
+                    f"eviction.feature_table.{t}.evicted_counts"
+                )
+                self.stats_reporter.register_stats(
+                    f"eviction.feature_table.{t}.processed_counts"
+                )
+                self.stats_reporter.register_stats(
+                    f"eviction.feature_table.{t}.evict_rate"
+                )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.full_duration_ms"
+            )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.exec_duration_ms"
+            )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.dry_run_exec_duration_ms"
+            )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.exec_div_full_duration_rate"
+            )
+
+        self.bounds_check_version: int = get_bounds_check_version_for_platform()
+
+        self._pg = pg
+
+    @cached_property
+    def cache_row_dim(self) -> int:
+        """
+        Compute the effective physical cache row size taking into account
+        padding to the nearest 4 elements and the optimizer state appended to
+        the back of the row
+        """
+
+        # For st publish, we only need to load weight for publishing and bulk eval
+        if self.enable_optimizer_offloading and not self.load_ckpt_without_opt:
+            return self.max_D + pad4(
+                # Compute the number of elements of cache_dtype needed to store
+                # the optimizer state
+                self.optimizer_state_dim
+            )
+        else:
+            return self.max_D
+
+    @cached_property
+    def optimizer_state_dim(self) -> int:
+        return int(
+            math.ceil(
+                self.optimizer.state_size_nbytes(
+                    self.max_D, self.optimizer_state_dtypes
+                )
+                / self.weights_precision.as_dtype().itemsize
+            )
+        )

     @property
     # pyre-ignore
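`cache_row_dim` is the physical row width: the widest embedding dim plus, when optimizer offloading is on, the optimizer state rounded up to a multiple of 4 elements. A worked example, assuming `pad4` rounds up to the next multiple of 4 and that rowwise Adagrad keeps a single FP32 `momentum1` value per row (4 bytes); both are assumptions for illustration:

import math


def pad4(n: int) -> int:
    # assumption: round up to the next multiple of 4 elements
    return (n + 3) // 4 * 4


max_D = 128              # widest table dim
weight_itemsize = 2      # FP16 weights
state_nbytes = 4         # rowwise Adagrad: one FP32 momentum1 value per row

optimizer_state_dim = math.ceil(state_nbytes / weight_itemsize)  # 2 elements
cache_row_dim = max_D + pad4(optimizer_state_dim)                # 128 + 4 = 132
print(optimizer_state_dim, cache_row_dim)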
@@ -766,19 +1208,22 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         be effectively overwritten. This function should only be called once at
         initailization time.
         """
+        self._ssd_db.toggle_compaction(False)
         row_offset = 0
         row_count = floor(
             self.bulk_init_chunk_size
-            / (self.
+            / (self.cache_row_dim * self.weights_precision.as_dtype().itemsize)
         )
         total_dim0 = 0
         for dim0, _ in self.embedding_specs:
             total_dim0 += dim0

         start_ts = time.time()
+        # TODO: do we have case for non-kvzch ssd with bulk init enabled + optimizer offloading? probably not?
+        # if we have such cases, we should only init the emb dim not the optimizer dim
         chunk_tensor = torch.empty(
             row_count,
-            self.
+            self.cache_row_dim,
             dtype=self.weights_precision.as_dtype(),
             device="cuda",
         )
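With compaction paused, rows are bulk-initialized in chunks of `row_count = floor(bulk_init_chunk_size / (cache_row_dim * itemsize))`. A quick sanity check of that arithmetic for an FP16 table with the 132-element row from the sketch above (values chosen for illustration):

from math import floor

bulk_init_chunk_size = 1_000_000_000   # 1 GB budget per chunk
cache_row_dim = 132
itemsize = 2                           # FP16

row_count = floor(bulk_init_chunk_size / (cache_row_dim * itemsize))
print(row_count)                       # 3787878 rows per chunk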
@@ -793,12 +1238,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             # This code is intentionally not calling through the getter property
             # to avoid the lazy initialization thread from joining with itself.
             self._ssd_db.set_range_to_storage(rand_val, row_offset, actual_dim0)
-        self.ssd_db.toggle_compaction(True)
         end_ts = time.time()
         elapsed = int((end_ts - start_ts) * 1e6)
         logging.info(
             f"TBE bulk initialization took {elapsed:_} us, bulk_init_chunk_size={self.bulk_init_chunk_size}, each batch of {row_count} rows, total rows of {total_dim0}"
         )
+        self._ssd_db.toggle_compaction(True)

     @torch.jit.ignore
     def _report_duration(
@@ -826,7 +1271,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         """
         recorded_itr, stream_cnt, report_val = self.prefetch_duration_us
         duration = dur_ms
-        if time_unit == "us":
+        if time_unit == "us":
             duration = dur_ms * 1000
         if it_step == recorded_itr:
             report_val = max(report_val, duration)
@@ -845,7 +1290,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 it_step, event_name, report_val, time_unit=time_unit
             )

-    # pyre-ignore[3]
     def record_function_via_dummy_profile_factory(
         self,
         use_dummy_profile: bool,
@@ -867,7 +1311,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

             def func(
                 name: str,
-                # pyre-ignore[2]
                 fn: Callable[..., Any],
                 *args: Any,
                 **kwargs: Any,
@@ -881,7 +1324,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

         def func(
             name: str,
-            # pyre-ignore[2]
             fn: Callable[..., Any],
             *args: Any,
             **kwargs: Any,
@@ -894,10 +1336,10 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self,
         split: SplitState,
         prefix: str,
-        dtype:
+        dtype: type[torch.dtype],
         enforce_hbm: bool = False,
         make_dev_param: bool = False,
-        dev_reshape: Optional[
+        dev_reshape: Optional[tuple[int, ...]] = None,
     ) -> None:
         apply_split_helper(
             self.register_buffer,
@@ -920,11 +1362,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

     def to_pinned_cpu_on_stream_wait_on_another_stream(
         self,
-        tensors:
+        tensors: list[Tensor],
         stream: torch.cuda.Stream,
         stream_to_wait_on: torch.cuda.Stream,
         post_event: Optional[torch.cuda.Event] = None,
-    ) ->
+    ) -> list[Tensor]:
         """
         Transfer input tensors from GPU to CPU using a pinned host
         buffer. The transfer is carried out on the given stream
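The helper's docstring describes a device-to-host transfer into pinned host memory on one stream after waiting for another. A generic PyTorch sketch of that pattern (not the FBGEMM implementation itself):

import torch


def d2h_on_stream(
    tensors: list[torch.Tensor],
    copy_stream: torch.cuda.Stream,
    stream_to_wait_on: torch.cuda.Stream,
) -> list[torch.Tensor]:
    outputs = []
    with torch.cuda.stream(copy_stream):
        # Do not start copying until the producer stream has finished writing.
        copy_stream.wait_stream(stream_to_wait_on)
        for t in tensors:
            host = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
            host.copy_(t, non_blocking=True)   # async D2H into the pinned buffer
            t.record_stream(copy_stream)       # keep t alive until the copy is done
            outputs.append(host)
    return outputs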
@@ -982,9 +1424,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
982
1424
|
is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
|
|
983
1425
|
tensor (which is accessible on both host and
|
|
984
1426
|
device)
|
|
1427
|
+
is_bwd (bool): A flag to indicate if the eviction is during backward
|
|
985
1428
|
Returns:
|
|
986
1429
|
None
|
|
987
1430
|
"""
|
|
1431
|
+
if not self.training: # if not training, freeze the embedding
|
|
1432
|
+
return
|
|
988
1433
|
with record_function(f"## ssd_evict_{name} ##"):
|
|
989
1434
|
with torch.cuda.stream(stream):
|
|
990
1435
|
if pre_event is not None:
|
|
@@ -1007,6 +1452,95 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 if post_event is not None:
                     stream.record_event(post_event)

+    def raw_embedding_stream_sync(
+        self,
+        stream: torch.cuda.Stream,
+        pre_event: Optional[torch.cuda.Event],
+        post_event: Optional[torch.cuda.Event],
+        name: Optional[str] = "",
+    ) -> None:
+        """
+        Blocking wait the copy operation of the tensors to be streamed,
+        to make sure they are not overwritten
+        Args:
+            stream (Stream): The CUDA stream that cudaStreamAddCallback will
+                             synchronize the host function with. Moreover, the
+                             asynchronous D->H memory copies will operate on
+                             this stream
+            pre_event (Event): The CUDA event that the stream has to wait on
+            post_event (Event): The CUDA event that the current will record on
+                                when the eviction is done
+        Returns:
+            None
+        """
+        with record_function(f"## ssd_stream_{name} ##"):
+            with torch.cuda.stream(stream):
+                if pre_event is not None:
+                    stream.wait_event(pre_event)
+
+                self.record_function_via_dummy_profile(
+                    f"## ssd_stream_sync_{name} ##",
+                    self.ssd_db.stream_sync_cuda,
+                )
+
+                if post_event is not None:
+                    stream.record_event(post_event)
+
+    def raw_embedding_stream(
+        self,
+        rows: Tensor,
+        indices_cpu: Tensor,
+        actions_count_cpu: Tensor,
+        stream: torch.cuda.Stream,
+        pre_event: Optional[torch.cuda.Event],
+        post_event: Optional[torch.cuda.Event],
+        is_rows_uvm: bool,
+        blocking_tensor_copy: bool = True,
+        name: Optional[str] = "",
+    ) -> None:
+        """
+        Stream data from the given input tensors to a remote service
+        Args:
+            rows (Tensor): The 2D tensor that contains rows to evict
+            indices_cpu (Tensor): The 1D CPU tensor that contains the row
+                                  indices that the rows will be evicted to
+            actions_count_cpu (Tensor): A scalar tensor that contains the
+                                        number of rows that the evict function
+                                        has to process
+            stream (Stream): The CUDA stream that cudaStreamAddCallback will
+                             synchronize the host function with. Moreover, the
+                             asynchronous D->H memory copies will operate on
+                             this stream
+            pre_event (Event): The CUDA event that the stream has to wait on
+            post_event (Event): The CUDA event that the current will record on
+                                when the eviction is done
+            is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
+                                tensor (which is accessible on both host and
+                                device)
+        Returns:
+            None
+        """
+        with record_function(f"## ssd_stream_{name} ##"):
+            with torch.cuda.stream(stream):
+                if pre_event is not None:
+                    stream.wait_event(pre_event)
+
+                rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)
+
+                rows.record_stream(stream)
+
+                self.record_function_via_dummy_profile(
+                    f"## ssd_stream_{name} ##",
+                    self.ssd_db.stream_cuda,
+                    indices_cpu,
+                    rows_cpu,
+                    actions_count_cpu,
+                    blocking_tensor_copy,
+                )
+
+                if post_event is not None:
+                    stream.record_event(post_event)
+
     def _evict_from_scratch_pad(self, grad: Tensor) -> None:
         """
         Evict conflict missed rows from a scratch pad
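Note: both helpers added above follow one ordering pattern: run backend I/O on a dedicated CUDA stream, gate it on an optional `pre_event`, and publish completion through an optional `post_event`. The sketch below reproduces only that pattern in plain PyTorch; the pinned-host copy is a stand-in for the `ssd_db.stream_cuda` backend call, and none of the helper names here are FBGEMM APIs.

```python
# Minimal sketch (assumptions noted above) of the pre_event -> side-stream work
# -> post_event ordering used by raw_embedding_stream()/raw_embedding_stream_sync().
import torch


def run_on_side_stream(rows: torch.Tensor) -> torch.Tensor:
    assert torch.cuda.is_available(), "sketch requires a CUDA device"
    main_stream = torch.cuda.current_stream()
    side_stream = torch.cuda.Stream()
    produced = torch.cuda.Event()
    done = torch.cuda.Event()

    rows = rows.cuda()
    produced.record(main_stream)          # rows are ready on the main stream

    with torch.cuda.stream(side_stream):
        side_stream.wait_event(produced)  # pre_event: wait for the producer
        rows.record_stream(side_stream)   # keep `rows` alive for this stream
        # Stand-in for the backend call: async D->H copy into pinned memory
        rows_cpu = torch.empty(rows.shape, dtype=rows.dtype,
                               device="cpu", pin_memory=True)
        rows_cpu.copy_(rows, non_blocking=True)
        side_stream.record_event(done)    # post_event: signal completion

    main_stream.wait_event(done)          # device-side consumers order against `done`
    done.synchronize()                    # host-side wait before reading the pinned copy
    return rows_cpu


if __name__ == "__main__":
    if torch.cuda.is_available():
        print(run_on_side_stream(torch.randn(4, 8)).shape)
```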
@@ -1043,6 +1577,18 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         if not do_evict:
             return

+        if self.enable_raw_embedding_streaming:
+            self.raw_embedding_stream(
+                rows=inserted_rows,
+                indices_cpu=post_bwd_evicted_indices_cpu,
+                actions_count_cpu=actions_count_cpu,
+                stream=self.ssd_eviction_stream,
+                pre_event=self.ssd_event_backward,
+                post_event=self.ssd_event_sp_streamed,
+                is_rows_uvm=True,
+                blocking_tensor_copy=True,
+                name="scratch_pad",
+            )
         self.evict(
             rows=inserted_rows,
             indices_cpu=post_bwd_evicted_indices_cpu,
@@ -1060,7 +1606,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     def _update_cache_counter_and_pointers(
         self,
         module: nn.Module,
-        grad_input: Union[
+        grad_input: Union[tuple[Tensor, ...], Tensor],
     ) -> None:
         """
         Update cache line locking counter and pointers before backward
@@ -1145,9 +1691,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         if len(self.ssd_location_update_data) == 0:
             return

-
-            0
-        )
+        sp_curr_next_map, inserted_rows_next = self.ssd_location_update_data.pop(0)

         # Update poitners
         torch.ops.fbgemm.ssd_update_row_addrs(
@@ -1162,12 +1706,63 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             unique_indices_length_curr=curr_data.actions_count_gpu,
         )

+    def _update_feature_score_metadata(
+        self,
+        linear_cache_indices: Tensor,
+        weights: Tensor,
+        d2h_stream: torch.cuda.Stream,
+        write_stream: torch.cuda.Stream,
+        pre_event_for_write: torch.cuda.Event,
+        post_event: Optional[torch.cuda.Event] = None,
+    ) -> None:
+        """
+        Write feature score metadata to DRAM
+
+        This method performs D2H copy on d2h_stream, then writes to DRAM on write_stream.
+        The caller is responsible for ensuring d2h_stream doesn't compete with other D2H operations.
+
+        Args:
+            linear_cache_indices: GPU tensor containing cache indices
+            weights: GPU tensor containing feature scores
+            d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
+            write_stream: Stream for metadata write operation
+            pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
+            post_event: Event to record when the operation is done
+        """
+        # Start D2H copy on d2h_stream
+        with torch.cuda.stream(d2h_stream):
+            # Record streams to prevent premature deallocation
+            linear_cache_indices.record_stream(d2h_stream)
+            weights.record_stream(d2h_stream)
+            # Do the D2H copy
+            linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
+            score_weights_cpu = self.to_pinned_cpu(weights)
+
+        # Write feature score metadata to DRAM
+        with record_function("## ssd_write_feature_score_metadata ##"):
+            with torch.cuda.stream(write_stream):
+                write_stream.wait_event(pre_event_for_write)
+                write_stream.wait_stream(d2h_stream)
+                self.record_function_via_dummy_profile(
+                    "## ssd_write_feature_score_metadata ##",
+                    self.ssd_db.set_feature_score_metadata_cuda,
+                    linear_cache_indices_cpu,
+                    torch.tensor(
+                        [score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
+                    ),
+                    score_weights_cpu,
+                )
+
+                if post_event is not None:
+                    write_stream.record_event(post_event)
+
     def prefetch(
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,  # todo: need to update caller
         forward_stream: Optional[torch.cuda.Stream] = None,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
     ) -> None:
         if self.prefetch_stream is None and forward_stream is not None:
             # Set the prefetch stream to the current stream
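Note: `_update_feature_score_metadata` splits the work across two streams: the D2H copy lands in pinned host memory on `d2h_stream`, and the backend write on `write_stream` is ordered after both a prerequisite event and the D2H stream. Below is a minimal sketch of that handoff; `write_fn` is a hypothetical stand-in for `ssd_db.set_feature_score_metadata_cuda`, and the host-side `synchronize()` replaces the stream-callback mechanism the real code relies on.

```python
# Minimal sketch (stand-in names, not the FBGEMM API) of the D2H-stream /
# write-stream handoff used by _update_feature_score_metadata().
import torch


def d2h_then_write(indices: torch.Tensor, scores: torch.Tensor,
                   pre_event: torch.cuda.Event, write_fn) -> None:
    d2h_stream = torch.cuda.Stream()
    write_stream = torch.cuda.Stream()

    # Async D2H copies into pinned host buffers on the dedicated D2H stream
    with torch.cuda.stream(d2h_stream):
        indices.record_stream(d2h_stream)
        scores.record_stream(d2h_stream)
        indices_cpu = torch.empty(indices.shape, dtype=indices.dtype,
                                  device="cpu", pin_memory=True)
        scores_cpu = torch.empty(scores.shape, dtype=scores.dtype,
                                 device="cpu", pin_memory=True)
        indices_cpu.copy_(indices, non_blocking=True)
        scores_cpu.copy_(scores, non_blocking=True)

    # Order the write stream after the prerequisite event and the D2H stream
    write_stream.wait_event(pre_event)
    write_stream.wait_stream(d2h_stream)
    # The real code enqueues the backend write on write_stream; this sketch
    # simply blocks the host on that ordering before calling the writer.
    write_stream.synchronize()
    write_fn(indices_cpu,
             torch.tensor([scores_cpu.shape[0]], dtype=torch.long),
             scores_cpu)
```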
@@ -1191,6 +1786,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self._prefetch(
             indices,
             offsets,
+            weights,
             vbe_metadata,
             forward_stream,
         )
@@ -1199,11 +1795,17 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,
         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
         forward_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
-        #
+        # Wait for any ongoing direct_write_embedding operations to complete
+        # Moving this from forward() to _prefetch() is more logical as direct_write
+        # operations affect the same cache structures that prefetch interacts with
         current_stream = torch.cuda.current_stream()
+        if self._embedding_cache_mode:
+            current_stream.wait_event(self.direct_write_l1_complete_event)
+            current_stream.wait_event(self.direct_write_sp_complete_event)

         B_offsets = None
         max_B = -1
@@ -1284,10 +1886,83 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 masks=torch.where(evicted_indices != -1, 1, 0),
                 count=actions_count_gpu,
             )
-
-
-            #
-            #
+            has_raw_embedding_streaming = False
+            if self.enable_raw_embedding_streaming:
+                # when pipelining is enabled
+                # prefetch in iter i happens before the backward sparse in iter i - 1
+                # so embeddings for iter i - 1's changed ids are not updated.
+                # so we can only fetch the indices from the iter i - 2
+                # when pipelining is disabled
+                # prefetch in iter i happens before forward iter i
+                # so we can get the iter i - 1's changed ids safely.
+                target_prev_iter = 1
+                if self.prefetch_pipeline:
+                    target_prev_iter = 2
+                if len(self.prefetched_info) > (target_prev_iter - 1):
+                    with record_function(
+                        "## ssd_lookup_prefetched_rows {} {} ##".format(
+                            self.timestep, self.tbe_unique_id
+                        )
+                    ):
+                        # wait for the copy to finish before overwriting the buffer
+                        self.raw_embedding_stream_sync(
+                            stream=self.ssd_eviction_stream,
+                            pre_event=self.ssd_event_cache_streamed,
+                            post_event=self.ssd_event_cache_streaming_synced,
+                            name="cache_update",
+                        )
+                        current_stream.wait_event(self.ssd_event_cache_streaming_synced)
+                        updated_indices, updated_counts_gpu = self.prefetched_info.pop(
+                            0
+                        )
+                        self.lxu_cache_updated_indices[: updated_indices.size(0)].copy_(
+                            updated_indices,
+                            non_blocking=True,
+                        )
+                        self.lxu_cache_updated_count[:1].copy_(
+                            updated_counts_gpu, non_blocking=True
+                        )
+                        has_raw_embedding_streaming = True
+
+                with record_function(
+                    "## ssd_save_prefetched_rows {} {} ##".format(
+                        self.timestep, self.tbe_unique_id
+                    )
+                ):
+                    masked_updated_indices = torch.where(
+                        torch.where(lxu_cache_locations != -1, True, False),
+                        linear_cache_indices,
+                        -1,
+                    )
+
+                    (
+                        uni_updated_indices,
+                        uni_updated_indices_length,
+                    ) = get_unique_indices_v2(
+                        masked_updated_indices,
+                        self.total_hash_size,
+                        compute_count=False,
+                        compute_inverse_indices=False,
+                    )
+                    assert uni_updated_indices is not None
+                    assert uni_updated_indices_length is not None
+                    # The unique indices has 1 more -1 element than needed,
+                    # which might make the tensor length go out of range
+                    # compared to the pre-allocated buffer.
+                    unique_len = min(
+                        self.lxu_cache_weights.size(0),
+                        uni_updated_indices.size(0),
+                    )
+                    self.prefetched_info.append(
+                        (
+                            uni_updated_indices.narrow(0, 0, unique_len),
+                            uni_updated_indices_length.clamp(max=unique_len),
+                        )
+                    )
+
+            with record_function("## ssd_d2h_inserted_indices ##"):
+                # Transfer actions_count and insert_indices right away to
+                # incrase an overlap opportunity
             actions_count_cpu, inserted_indices_cpu = (
                 self.to_pinned_cpu_on_stream_wait_on_another_stream(
                     tensors=[
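Note: the `ssd_save_prefetched_rows` block above masks out cache misses, deduplicates the remaining ids, and clamps the result to the preallocated staging buffer. Below is a minimal CPU-only sketch of that step, using `torch.unique` as a stand-in for `get_unique_indices_v2`.

```python
# Sketch of "dedupe the changed ids, then clamp to the preallocated buffer".
import torch


def dedup_and_clamp(masked_indices: torch.Tensor, buffer_rows: int):
    # -1 marks ids that were not resident in L1; keep them out of the update set
    uniq = torch.unique(masked_indices)
    uniq = uniq[uniq >= 0]
    # Never exceed the preallocated staging buffer
    unique_len = min(buffer_rows, uniq.numel())
    return uniq.narrow(0, 0, unique_len), torch.tensor([unique_len])


if __name__ == "__main__":
    ids = torch.tensor([7, -1, 3, 7, 12, 3, -1])
    print(dedup_and_clamp(ids, buffer_rows=3))
```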
@@ -1312,7 +1987,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

             # Allocation a scratch pad for the current iteration. The scratch
             # pad is a UVA tensor
-            inserted_rows_shape = (assigned_cache_slots.numel(), self.
+            inserted_rows_shape = (assigned_cache_slots.numel(), self.cache_row_dim)
             if linear_cache_indices.numel() > 0:
                 inserted_rows = torch.ops.fbgemm.new_unified_tensor(
                     torch.zeros(
@@ -1415,25 +2090,66 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             # Store info for evicting the previous iteration's
             # scratch pad after the corresponding backward pass is
             # done
-            self.
-            (
-
-
+            if self.training:
+                self.ssd_location_update_data.append(
+                    (
+                        sp_curr_prev_map_gpu,
+                        inserted_rows,
+                    )
                 )
-            )

             # Ensure the previous iterations eviction is complete
             current_stream.wait_event(self.ssd_event_sp_evict)
             # Ensure that D2H is done
             current_stream.wait_event(self.ssd_event_get_inputs_cpy)

+            if self.enable_raw_embedding_streaming and has_raw_embedding_streaming:
+                current_stream.wait_event(self.ssd_event_sp_streamed)
+                with record_function(
+                    "## ssd_compute_updated_rows {} {} ##".format(
+                        self.timestep, self.tbe_unique_id
+                    )
+                ):
+                    # cache rows that are changed in the previous iteration
+                    updated_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
+                        self.lxu_cache_updated_indices,
+                        self.lxu_cache_state,
+                        self.total_hash_size,
+                        self.gather_ssd_cache_stats,
+                        self.local_ssd_cache_stats,
+                    )
+                    torch.ops.fbgemm.masked_index_select(
+                        self.lxu_cache_updated_weights,
+                        updated_cache_locations,
+                        self.lxu_cache_weights,
+                        self.lxu_cache_updated_count,
+                    )
+                current_stream.record_event(self.ssd_event_cache_streaming_computed)
+
+                self.raw_embedding_stream(
+                    rows=self.lxu_cache_updated_weights,
+                    indices_cpu=self.lxu_cache_updated_indices,
+                    actions_count_cpu=self.lxu_cache_updated_count,
+                    stream=self.ssd_eviction_stream,
+                    pre_event=self.ssd_event_cache_streaming_computed,
+                    post_event=self.ssd_event_cache_streamed,
+                    is_rows_uvm=True,
+                    blocking_tensor_copy=False,
+                    name="cache_update",
+                )
+
             if self.gather_ssd_cache_stats:
                 # call to collect past SSD IO dur right before next rocksdb IO

                 self.ssd_cache_stats = torch.add(
                     self.ssd_cache_stats, self.local_ssd_cache_stats
                 )
-
+                # only report metrics from rank0 to avoid flooded logging
+                if dist.get_rank() == 0:
+                    self._report_kv_backend_stats()
+
+            # May trigger eviction if free mem trigger mode enabled before get cuda
+            self.may_trigger_eviction()

             # Fetch data from SSD
             if linear_cache_indices.numel() > 0:
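Note: the `ssd_compute_updated_rows` block above looks up the current L1 locations of the previously streamed ids and gathers only the resident rows into a staging buffer via `torch.ops.fbgemm.masked_index_select`. The sketch below is a plain-PyTorch stand-in that shows the general idea (gather rows whose locations are valid), not the exact contract of the FBGEMM op.

```python
# Simplified stand-in for "gather changed cache rows into a staging buffer".
import torch


def gather_updated_rows(cache_weights: torch.Tensor,
                        locations: torch.Tensor,
                        out: torch.Tensor) -> int:
    hits = locations[locations >= 0]          # -1 means the id is not in L1
    n = min(hits.numel(), out.size(0))        # respect the staging buffer size
    out[:n] = cache_weights[hits[:n]]
    return n


if __name__ == "__main__":
    cache = torch.arange(20.0).view(5, 4)
    locs = torch.tensor([3, -1, 0])
    staging = torch.zeros(4, 4)
    print(gather_updated_rows(cache, locs, staging), staging)
```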
@@ -1457,21 +2173,35 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 use_pipeline=self.prefetch_pipeline,
             )

-            if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if self.training:
+                if linear_cache_indices.numel() > 0:
+                    # Evict rows from cache to SSD
+                    self.evict(
+                        rows=self.lxu_cache_evicted_weights,
+                        indices_cpu=self.lxu_cache_evicted_indices,
+                        actions_count_cpu=self.lxu_cache_evicted_count,
+                        stream=self.ssd_eviction_stream,
+                        pre_event=self.ssd_event_get,
+                        # Record completion event after scratch pad eviction
+                        # instead since that happens after L1 eviction
+                        post_event=self.ssd_event_cache_evict,
+                        is_rows_uvm=True,
+                        name="cache",
+                        is_bwd=False,
+                    )
+            if (
+                self.backend_type == BackendType.DRAM
+                and weights is not None
+                and linear_cache_indices.numel() > 0
+            ):
+                # Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
+                self._update_feature_score_metadata(
+                    linear_cache_indices=linear_cache_indices,
+                    weights=weights,
+                    d2h_stream=self.ssd_memcpy_stream,
+                    write_stream=self.feature_score_stream,
+                    pre_event_for_write=self.ssd_event_cache_evict,
+                )

             # Generate row addresses (pointing to either L1 or the current
             # iteration's scratch pad)
@@ -1553,24 +2283,32 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 )
             )

-            # Store scratch pad info for post backward eviction
-
-
-
-
-
-
+            # Store scratch pad info for post backward eviction only for training
+            # for eval job, no backward pass, so no need to store this info
+            if self.training:
+                self.ssd_scratch_pad_eviction_data.append(
+                    (
+                        inserted_rows,
+                        post_bwd_evicted_indices_cpu,
+                        actions_count_cpu,
+                        linear_cache_indices.numel() > 0,
+                    )
                 )
-            )

             # Store data for forward
             self.ssd_prefetch_data.append(prefetch_data)

+            # Record an event to mark the completion of prefetch operations
+            # This will be used by direct_write_embedding to ensure it doesn't run concurrently with prefetch
+            current_stream.record_event(self.prefetch_complete_event)
+
     @torch.jit.ignore
     def _generate_vbe_metadata(
         self,
         offsets: Tensor,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]],
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
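Note: `_generate_vbe_metadata` now also threads through optional preallocated VBE outputs. As background, VBE (variable batch size embedding) metadata is built from `batch_size_per_feature_per_rank`; the sketch below only illustrates the kind of bookkeeping involved (summing per-rank batch sizes and taking cumulative offsets) and is not the actual metadata format produced here.

```python
# Illustrative-only sketch of VBE-style batch-size bookkeeping.
import itertools

batch_size_per_feature_per_rank = [
    [2, 3],   # feature 0: batch sizes on rank 0 and rank 1
    [4, 1],   # feature 1
]
B_per_feature = [sum(per_rank) for per_rank in batch_size_per_feature_per_rank]
B_offsets = [0] + list(itertools.accumulate(B_per_feature))
max_B = max(itertools.chain.from_iterable(batch_size_per_feature_per_rank))
print(B_per_feature, B_offsets, max_B)   # [5, 5] [0, 5, 10] 4
```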
@@ -1589,19 +2327,58 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self.pooling_mode,
             self.feature_dims,
             self.current_device,
+            vbe_output,
+            vbe_output_offsets,
         )

+    def _increment_iteration(self) -> int:
+        # Although self.iter_cpu is created on CPU. It might be transferred to
+        # GPU by the user. So, we need to transfer it to CPU explicitly. This
+        # should be done only once.
+        self.iter_cpu = self.iter_cpu.cpu()
+
+        # Sync with loaded state
+        # Wrap to make it compatible with PT2 compile
+        if not is_torchdynamo_compiling():
+            if self.iter_cpu.item() == 0:
+                self.iter_cpu.fill_(self.iter.cpu().item())
+
+        # Increment the iteration counter
+        # The CPU counterpart is used for local computation
+        iter_int = int(self.iter_cpu.add_(1).item())
+        # The GPU counterpart is used for checkpointing
+        self.iter.add_(1)
+
+        return iter_int
+
     def forward(
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
     ) -> Tensor:
+        self.clear_cache()
+        if vbe_output is not None or vbe_output_offsets is not None:
+            # CPU is not supported in SSD TBE
+            check_allocated_vbe_output(
+                self.output_dtype,
+                batch_size_per_feature_per_rank,
+                vbe_output,
+                vbe_output_offsets,
+            )
         indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
-            indices,
+            indices,
+            offsets,
+            per_sample_weights,
+            batch_size_per_feature_per_rank,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )

         if len(self.timesteps_prefetched) == 0:
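Note: `_increment_iteration` keeps two counters in sync: a CPU tensor for cheap host-side reads and a device tensor that travels with checkpoints, re-seeding the CPU copy from the checkpointed value on first use. A minimal sketch of that bookkeeping:

```python
# Sketch of the CPU/GPU iteration bookkeeping pattern used above.
import torch


class IterCounter:
    def __init__(self, device: str = "cpu") -> None:
        self.iter = torch.zeros(1, dtype=torch.int64, device=device)  # checkpointed copy
        self.iter_cpu = torch.zeros(1, dtype=torch.int64)             # host-side copy

    def increment(self) -> int:
        self.iter_cpu = self.iter_cpu.cpu()       # in case a user moved it to GPU
        if self.iter_cpu.item() == 0:             # sync with loaded state, once
            self.iter_cpu.fill_(self.iter.cpu().item())
        iter_int = int(self.iter_cpu.add_(1).item())
        self.iter.add_(1)
        return iter_int


if __name__ == "__main__":
    c = IterCounter()
    print(c.increment(), c.increment())  # 1 2
```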
@@ -1615,7 +2392,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             context=self.step,
             stream=self.ssd_eviction_stream,
         ):
-            self._prefetch(indices, offsets, vbe_metadata)
+            self._prefetch(indices, offsets, weights, vbe_metadata)

         assert len(self.ssd_prefetch_data) > 0

@@ -1674,18 +2451,33 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 "post_bwd_evicted_indices": post_bwd_evicted_indices_cpu,
                 "actions_count": actions_count_cpu,
             },
+            enable_optimizer_offloading=self.enable_optimizer_offloading,
            # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
            vbe_metadata=vbe_metadata,
+            learning_rate_tensor=self.learning_rate_tensor,
+            info_B_num_bits=self.info_B_num_bits,
+            info_B_mask=self.info_B_mask,
         )

         self.timesteps_prefetched.pop(0)
         self.step += 1

-
-
-
+        # Increment the iteration (value is used for certain optimizers)
+        iter_int = self._increment_iteration()
+
+        if self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
+            momentum2 = invokers.lookup_args_ssd.Momentum(
+                # pyre-ignore[6]
+                dev=self.momentum2_dev,
+                # pyre-ignore[6]
+                host=self.momentum2_host,
+                # pyre-ignore[6]
+                uvm=self.momentum2_uvm,
+                # pyre-ignore[6]
+                offsets=self.momentum2_offsets,
+                # pyre-ignore[6]
+                placements=self.momentum2_placements,
             )
-        return invokers.lookup_sgd_ssd.invoke(common_args, self.optimizer_args)

         momentum1 = invokers.lookup_args_ssd.Momentum(
             dev=self.momentum1_dev,
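Note: the comment above says the iteration value is "used for certain optimizers"; in general, Adam-family updates need the step count for bias correction, which is why `iter_int` is threaded through to the Adam/partial-rowwise-Adam invokers below. The snippet shows the standard textbook bias correction, not the FBGEMM kernel math:

```python
# Standard Adam bias correction (general formula), illustrating why the step
# count t must be passed alongside momentum1/momentum2.
def bias_corrected_lr(lr: float, beta1: float, beta2: float, t: int) -> float:
    return lr * (1.0 - beta2 ** t) ** 0.5 / (1.0 - beta1 ** t)


print(round(bias_corrected_lr(1e-3, 0.9, 0.999, 1), 6))      # 0.000316 at t=1
print(round(bias_corrected_lr(1e-3, 0.9, 0.999, 10000), 6))  # ~0.001 once t is large
```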
@@ -1696,44 +2488,584 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         )

         if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
-            # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
             return invokers.lookup_rowwise_adagrad_ssd.invoke(
                 common_args, self.optimizer_args, momentum1
             )

+        elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
+            return invokers.lookup_partial_rowwise_adam_ssd.invoke(
+                common_args,
+                self.optimizer_args,
+                momentum1,
+                # pyre-ignore[61]
+                momentum2,
+                iter_int,
+            )
+
+        elif self.optimizer == OptimType.ADAM:
+            row_counter = invokers.lookup_args_ssd.Momentum(
+                # pyre-fixme[6]
+                dev=self.row_counter_dev,
+                # pyre-fixme[6]
+                host=self.row_counter_host,
+                # pyre-fixme[6]
+                uvm=self.row_counter_uvm,
+                # pyre-fixme[6]
+                offsets=self.row_counter_offsets,
+                # pyre-fixme[6]
+                placements=self.row_counter_placements,
+            )
+
+            return invokers.lookup_adam_ssd.invoke(
+                common_args,
+                self.optimizer_args,
+                momentum1,
+                # pyre-ignore[61]
+                momentum2,
+                iter_int,
+                row_counter=row_counter,
+            )
+
     @torch.jit.ignore
-    def
+    def _split_optimizer_states_non_kv_zch(
+        self,
+    ) -> list[list[torch.Tensor]]:
         """
-        Returns a list of states, split by table
-
+        Returns a list of optimizer states (view), split by table.
+
+        Returns:
+            A list of list of states. Shape = (the number of tables, the number
+            of states).
+
+            The following shows the list of states (in the returned order) for
+            each optimizer:
+
+            (1) `EXACT_ROWWISE_ADAGRAD`: `momentum1` (rowwise)
+
+            (1) `PARTIAL_ROWWISE_ADAM`: `momentum1`, `momentum2` (rowwise)
         """
-        (rows, _) = zip(*self.embedding_specs)

-
+        # Row count per table
+        rows, dims = zip(*self.embedding_specs)
+        # Cumulative row counts per table for rowwise states
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
+        # Cumulative element counts per table for elementwise states
+        elem_count_cumsum: list[int] = [0] + list(
+            itertools.accumulate([r * d for r, d in self.embedding_specs])
+        )
+
+        # pyre-ignore[53]
+        def _slice(tensor: Tensor, t: int, rowwise: bool) -> Tensor:
+            d: int = dims[t]
+            e: int = rows[t]
+
+            if not rowwise:
+                # Optimizer state is element-wise - compute the table offset for
+                # the table, view the slice as 2D tensor
+                return tensor.detach()[
+                    elem_count_cumsum[t] : elem_count_cumsum[t + 1]
+                ].view(-1, d)
+            else:
+                # Optimizer state is row-wise - fetch elements in range and view
+                # slice as 1D
+                return tensor.detach()[
+                    row_count_cumsum[t] : row_count_cumsum[t + 1]
+                ].view(e)
+
+        if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
+            return [
+                [_slice(self.momentum1_dev, t, rowwise=True)]
+                for t, _ in enumerate(rows)
+            ]
+        elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
+            return [
+                [
+                    _slice(self.momentum1_dev, t, rowwise=False),
+                    # pyre-ignore[6]
+                    _slice(self.momentum2_dev, t, rowwise=True),
+                ]
+                for t, _ in enumerate(rows)
+            ]
+
+        elif self.optimizer == OptimType.ADAM:
+            return [
+                [
+                    _slice(self.momentum1_dev, t, rowwise=False),
+                    # pyre-ignore[6]
+                    _slice(self.momentum2_dev, t, rowwise=False),
+                ]
+                for t, _ in enumerate(rows)
+            ]
+
+        else:
+            raise NotImplementedError(
+                f"Getting optimizer states is not supported for {self.optimizer}"
+            )
+
+    @torch.jit.ignore
+    def _split_optimizer_states_kv_zch_no_offloading(
+        self,
+        sorted_ids: torch.Tensor,
+    ) -> list[list[torch.Tensor]]:
+
+        # Row count per table
+        rows, dims = zip(*self.embedding_specs)
+        # Cumulative row counts per table for rowwise states
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
+        # Cumulative element counts per table for elementwise states
+        elem_count_cumsum: list[int] = [0] + list(
+            itertools.accumulate([r * d for r, d in self.embedding_specs])
+        )
+
+        # pyre-ignore[53]
+        def _slice(state_name: str, tensor: Tensor, t: int, rowwise: bool) -> Tensor:
+            d: int = dims[t]
+
+            # pyre-ignore[16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore[16]
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+
+            if sorted_ids is None or sorted_ids[t].numel() == 0:
+                # Empty optimizer state for module initialization
+                return torch.empty(
+                    0,
+                    dtype=(
+                        self.optimizer_state_dtypes.get(
+                            state_name, SparseType.FP32
+                        ).as_dtype()
+                    ),
+                    device="cpu",
+                )
+
+            elif not rowwise:
+                # Optimizer state is element-wise - materialize the local ids
+                # based on the sorted_ids compute the table offset for the
+                # table, view the slice as 2D tensor of e x d, then fetch the
+                # sub-slice by local ids
+                #
+                # local_ids is [N, 1], flatten it to N to keep the returned tensor 2D
+                local_ids = (sorted_ids[t] - bucket_id_start * bucket_size).view(-1)
+                return (
+                    tensor.detach()
+                    .cpu()[elem_count_cumsum[t] : elem_count_cumsum[t + 1]]
+                    .view(-1, d)[local_ids]
+                )
+
+            else:
+                # Optimizer state is row-wise - materialize the local ids based
+                # on the sorted_ids and table offset (i.e. row count cumsum),
+                # then fetch by local ids
+                linearized_local_ids = (
+                    sorted_ids[t] - bucket_id_start * bucket_size + row_count_cumsum[t]
+                )
+                return tensor.detach().cpu()[linearized_local_ids].view(-1)
+
+        if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
+            return [
+                [_slice("momentum1", self.momentum1_dev, t, rowwise=True)]
+                for t, _ in enumerate(rows)
+            ]
+
+        elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
+            return [
+                [
+                    _slice("momentum1", self.momentum1_dev, t, rowwise=False),
+                    # pyre-ignore[6]
+                    _slice("momentum2", self.momentum2_dev, t, rowwise=True),
+                ]
+                for t, _ in enumerate(rows)
+            ]
+
+        elif self.optimizer == OptimType.ADAM:
+            return [
+                [
+                    _slice("momentum1", self.momentum1_dev, t, rowwise=False),
+                    # pyre-ignore[6]
+                    _slice("momentum2", self.momentum2_dev, t, rowwise=False),
+                ]
+                for t, _ in enumerate(rows)
+            ]
+
+        else:
+            raise NotImplementedError(
+                f"Getting optimizer states is not supported for {self.optimizer}"
+            )
+
+    @torch.jit.ignore
+    def _split_optimizer_states_kv_zch_w_offloading(
+        self,
+        sorted_ids: torch.Tensor,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ) -> list[list[torch.Tensor]]:
+        dtype = self.weights_precision.as_dtype()
+        # Row count per table
+        rows_, dims_ = zip(*self.embedding_specs)
+        # Cumulative row counts per table for rowwise states
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))
+
+        snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
+        )
+
+        # pyre-ignore[53]
+        def _fetch_offloaded_optimizer_states(
+            t: int,
+        ) -> list[Tensor]:
+            e: int = rows_[t]
+            d: int = dims_[t]
+
+            # pyre-ignore[16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore[16]
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+
+            row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)
+            # Count of rows to fetch
+            rows_to_fetch = sorted_ids[t].numel()
+
+            # Lookup the byte offsets for each optimizer state
+            optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
+                d, self.weights_precision, self.optimizer_state_dtypes
+            )
+            # Find the minimum start of all the start/end pairs - we have to
+            # offset the start/end pairs by this value to get the correct start/end
+            offset_ = min(
+                [start for _, (start, _) in optimizer_state_byte_offsets.items()]
+            )
+            # Update the start/end pairs to be relative to offset_
+            optimizer_state_byte_offsets = dict(
+                (k, (v1 - offset_, v2 - offset_))
+                for k, (v1, v2) in optimizer_state_byte_offsets.items()
+            )
+
+            # Since the backend returns cache rows that pack the weights and
+            # optimizer states together, reading the whole tensor could cause OOM,
+            # so we use the KVTensorWrapper abstraction to query the backend and
+            # fetch the data in chunks instead.
+            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                shape=[
+                    e,
+                    # Dim is terms of **weights** dtype
+                    self.optimizer_state_dim,
+                ],
+                dtype=dtype,
+                row_offset=row_offset,
+                snapshot_handle=snapshot_handle,
+                sorted_indices=sorted_ids[t],
+                width_offset=pad4(d),
+            )
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
+
+            # Fetch the state size table for the given weights domension
+            state_size_table = self.optimizer.state_size_table(d)
+
+            # Create a 2D output buffer of [rows x optimizer state dim] with the
+            # weights type as the type. For optimizers with multiple states (e.g.
+            # momentum1 and momentum2), this tensor will include data from all
+            # states, hence self.optimizer_state_dim as the row size.
+            optimizer_states_buffer = torch.empty(
+                (rows_to_fetch, self.optimizer_state_dim), dtype=dtype, device="cpu"
+            )
+
+            # Set the chunk size for fetching
+            chunk_size = (
+                # 10M rows => 260(max_D)* 2(ele_bytes) * 10M => 5.2GB mem spike
+                10_000_000
+            )
+            logging.info(f"split optimizer chunk rows: {chunk_size}")
+
+            # Chunk the fetching by chunk_size
+            for i in range(0, rows_to_fetch, chunk_size):
+                length = min(chunk_size, rows_to_fetch - i)
+
+                # Fetch from backend and copy to the output buffer
+                optimizer_states_buffer.narrow(0, i, length).copy_(
+                    tensor_wrapper.narrow(0, i, length).view(dtype)
+                )
+
+            # Now split up the buffer into N views, N for each optimizer state
+            optimizer_states: list[Tensor] = []
+            for state_name in self.optimizer.state_names():
+                # Extract the offsets
+                start, end = optimizer_state_byte_offsets[state_name]
+
+                state = optimizer_states_buffer.view(
+                    # Force tensor to byte view
+                    dtype=torch.uint8
+                    # Copy by byte offsets
+                )[:, start:end].view(
+                    # Re-view in the state's target dtype
+                    self.optimizer_state_dtypes.get(
+                        state_name, SparseType.FP32
+                    ).as_dtype()
+                )
+
+                optimizer_states.append(
+                    # If the state is rowwise (i.e. just one element per row),
+                    # then re-view as 1D tensor
+                    state
+                    if state_size_table[state_name] > 1
+                    else state.view(-1)
+                )
+
+            # Return the views
+            return optimizer_states

         return [
             (
-                self.
-
-            )
+                self.optimizer.empty_states([0], [d], self.optimizer_state_dtypes)[0]
+                # Return a set of empty states if sorted_ids[t] is empty
+                if sorted_ids is None or sorted_ids[t].numel() == 0
+                # Else fetch the list of optimizer states for the table
+                else _fetch_offloaded_optimizer_states(t)
             )
-            for t,
+            for t, d in enumerate(dims_)
         ]

+    @torch.jit.ignore
+    def _split_optimizer_states_kv_zch_whole_row(
+        self,
+        sorted_ids: torch.Tensor,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ) -> list[list[torch.Tensor]]:
+        dtype = self.weights_precision.as_dtype()
+
+        # Row and dimension counts per table
+        # rows_ is only used here to compute the virtual table offsets
+        rows_, dims_ = zip(*self.embedding_specs)
+
+        # Cumulative row counts per (virtual) table for rowwise states
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))
+
+        snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
+        )
+
+        # pyre-ignore[53]
+        def _fetch_offloaded_optimizer_states(
+            t: int,
+        ) -> list[Tensor]:
+            d: int = dims_[t]
+
+            # pyre-ignore[16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore[16]
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)
+
+            # When backend returns whole row, the optimizer will be returned as
+            # PMT directly
+            if sorted_ids[t].size(0) == 0 and self.local_weight_counts[t] > 0:
+                logging.info(
+                    f"Before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
+                )
+                sorted_ids[t] = torch.zeros(
+                    (self.local_weight_counts[t], 1),
+                    device=torch.device("cpu"),
+                    dtype=torch.int64,
+                )
+
+            # Lookup the byte offsets for each optimizer state relative to the
+            # start of the weights
+            optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
+                d, self.weights_precision, self.optimizer_state_dtypes
+            )
+            # Get the number of elements (of the optimizer state dtype) per state
+            optimizer_state_size_table = self.optimizer.state_size_table(d)
+
+            # Get metaheader dimensions in number of elements of weight dtype
+            metaheader_dim = (
+                # pyre-ignore[16]
+                self.kv_zch_params.eviction_policy.meta_header_lens[t]
+            )
+
+            # Now split up the buffer into N views, N for each optimizer state
+            optimizer_states: list[PartiallyMaterializedTensor] = []
+            for state_name in self.optimizer.state_names():
+                state_dtype = self.optimizer_state_dtypes.get(
+                    state_name, SparseType.FP32
+                ).as_dtype()
+
+                # Get the size of the state in elements of the optimizer state,
+                # in terms of the **weights** dtype
+                state_size = math.ceil(
+                    optimizer_state_size_table[state_name]
+                    * state_dtype.itemsize
+                    / dtype.itemsize
+                )
+
+                # Extract the offsets relative to the start of the weights (in
+                # num bytes)
+                start, _ = optimizer_state_byte_offsets[state_name]
+
+                # Convert the start to number of elements in terms of the
+                # **weights** dtype, then add the mmetaheader dim offset
+                start = metaheader_dim + start // dtype.itemsize
+
+                shape = [
+                    (
+                        sorted_ids[t].size(0)
+                        if sorted_ids is not None and sorted_ids[t].size(0) > 0
+                        else self.local_weight_counts[t]
+                    ),
+                    (
+                        # Dim is in terms of the **weights** dtype
+                        state_size
+                    ),
+                ]
+
+                # NOTE: We have to view using the **weights** dtype, as
+                # there is currently a bug with KVTensorWrapper where using
+                # a different dtype does not result in the same bytes being
+                # returned, e.g.
+                #
+                # KVTensorWrapper(dtype=fp32, width_offset=130, shape=[N, 1])
+                #
+                # is NOT the same as
+                #
+                # KVTensorWrapper(dtype=fp16, width_offset=260, shape=[N, 2]).view(-1).view(fp32)
+                #
+                # TODO: Fix KVTensorWrapper to support viewing data under different dtypes
+                tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                    shape=shape,
+                    dtype=(
+                        # NOTE: Use the *weights* dtype
+                        dtype
+                    ),
+                    row_offset=row_offset,
+                    snapshot_handle=snapshot_handle,
+                    sorted_indices=sorted_ids[t],
+                    width_offset=(
+                        # NOTE: Width offset is in terms of **weights** dtype
+                        start
+                    ),
+                    # Optimizer written to DB with weights, so skip write here
+                    read_only=True,
+                )
+                (
+                    tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                    if self.backend_type == BackendType.SSD
+                    else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                )
+
+                optimizer_states.append(
+                    PartiallyMaterializedTensor(tensor_wrapper, True)
+                )
+
+            # pyre-ignore [7]
+            return optimizer_states
+
+        return [_fetch_offloaded_optimizer_states(t) for t, _ in enumerate(dims_)]
+
+    @torch.jit.export
+    def split_optimizer_states(
+        self,
+        sorted_id_tensor: Optional[list[torch.Tensor]] = None,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ) -> list[list[torch.Tensor]]:
+        """
+        Returns a list of optimizer states split by table.
+
+        Since EXACT_ROWWISE_ADAGRAD has small optimizer states, we would generate
+        a full tensor for each table (shard). When other optimizer types are supported,
+        we should integrate with KVTensorWrapper (ssd_split_table_batched_embeddings.cpp)
+        to allow caller to read the optimizer states using `narrow()` in a rolling-window manner.
+
+        Args:
+            sorted_id_tensor (Optional[List[torch.Tensor]]): sorted id tensor by table, used to query optimizer
+                state from backend. Call should reuse the generated id tensor from weight state_dict, to guarantee
+                id consistency between weight and optimizer states.
+
+        """
+
+        # Handle the non-KVZCH case
+        if not self.kv_zch_params:
+            # If not in KV
+            return self._split_optimizer_states_non_kv_zch()
+
+        # Handle the loading-from-state-dict case
+        if self.load_state_dict:
+            # Initialize for checkpointing loading
+            assert (
+                self._cached_kvzch_data is not None
+                and self._cached_kvzch_data.cached_optimizer_states_per_table
+            ), "optimizer state is not initialized for load checkpointing"
+
+            return self._cached_kvzch_data.cached_optimizer_states_per_table
+
+        logging.info(
+            f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
+        )
+        start_time = time.time()
+
+        if not self.enable_optimizer_offloading:
+            # Handle the KVZCH non-optimizer-offloading case
+            optimizer_states = self._split_optimizer_states_kv_zch_no_offloading(
+                sorted_id_tensor
+            )
+
+        elif not self.backend_return_whole_row:
+            # Handle the KVZCH with-optimizer-offloading case
+            optimizer_states = self._split_optimizer_states_kv_zch_w_offloading(
+                sorted_id_tensor, no_snapshot, should_flush
+            )
+
+        else:
+            # Handle the KVZCH with-optimizer-offloading backend-whole-row case
+            optimizer_states = self._split_optimizer_states_kv_zch_whole_row(
+                sorted_id_tensor, no_snapshot, should_flush
+            )
+
+        logging.info(
+            f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
+            # pyre-ignore[16]
+            f"num ids list: {None if not sorted_id_tensor else [ids.numel() for ids in sorted_id_tensor]}"
+        )
+
+        return optimizer_states
+
+    @torch.jit.export
+    def get_optimizer_state(
+        self,
+        sorted_id_tensor: Optional[list[torch.Tensor]],
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ) -> list[dict[str, torch.Tensor]]:
+        """
+        Returns a list of dictionaries of optimizer states split by table.
+        """
+        states_list: list[list[Tensor]] = self.split_optimizer_states(
+            sorted_id_tensor=sorted_id_tensor,
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
+        )
+        state_names = self.optimizer.state_names()
+        return [dict(zip(state_names, states)) for states in states_list]
+
     @torch.jit.export
-    def debug_split_embedding_weights(self) ->
+    def debug_split_embedding_weights(self) -> list[torch.Tensor]:
         """
         Returns a list of weights, split by table.

         Testing only, very slow.
         """
-
+        rows, _ = zip(*self.embedding_specs)

         rows_cumsum = [0] + list(itertools.accumulate(rows))
         splits = []
         get_event = torch.cuda.Event()

-        for t, (row,
+        for t, (row, _) in enumerate(self.embedding_specs):
             weights = torch.empty(
                 (row, self.max_D), dtype=self.weights_precision.as_dtype()
             )
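Note: `_split_optimizer_states_non_kv_zch` above returns views into flat state tensors, using a cumulative row count for rowwise states and a cumulative element count for elementwise states. A small worked example of that slicing, with made-up table shapes:

```python
# Worked example of cumulative-sum slicing of flat optimizer-state tensors.
import itertools
import torch

embedding_specs = [(4, 2), (3, 5)]          # (rows, dim) per table
rows, dims = zip(*embedding_specs)
row_cumsum = [0] + list(itertools.accumulate(rows))                                 # [0, 4, 7]
elem_cumsum = [0] + list(itertools.accumulate(r * d for r, d in embedding_specs))   # [0, 8, 23]

rowwise_state = torch.arange(row_cumsum[-1], dtype=torch.float32)       # one value per row
elementwise_state = torch.arange(elem_cumsum[-1], dtype=torch.float32)  # one value per element

for t, (r, d) in enumerate(embedding_specs):
    per_row = rowwise_state[row_cumsum[t]:row_cumsum[t + 1]].view(r)
    per_elem = elementwise_state[elem_cumsum[t]:elem_cumsum[t + 1]].view(r, d)
    print(t, per_row.shape, per_elem.shape)   # table 0: (4,), (4, 2); table 1: (3,), (3, 5)
```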
@@ -1765,12 +3097,92 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

         return splits

+    def clear_cache(self) -> None:
+        # clear KV ZCH cache for checkpointing
+        self._cached_kvzch_data = None
+
+    @torch.jit.ignore
+    # pyre-ignore [3] - do not definte snapshot class EmbeddingSnapshotHandleWrapper to avoid import dependency in other production code
+    def _may_create_snapshot_for_state_dict(
+        self,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ):
+        """
+        Create a rocksdb snapshot if needed.
+        """
+        start_time = time.time()
+        # Force device synchronize for now
+        torch.cuda.synchronize()
+        snapshot_handle = None
+        checkpoint_handle = None
+        if self.backend_type == BackendType.SSD:
+            # Create a rocksdb snapshot
+            if not no_snapshot:
+                # Flush L1 and L2 caches
+                self.flush(force=should_flush)
+                logging.info(
+                    f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+                )
+                snapshot_handle = self.ssd_db.create_snapshot()
+                checkpoint_handle = self.ssd_db.get_active_checkpoint_uuid(self.step)
+                logging.info(
+                    f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+                )
+        elif self.backend_type == BackendType.DRAM:
+            # if there is any ongoing eviction, lets wait until eviction is finished before state_dict
+            # so that we can reach consistent model state before/after state_dict
+            evict_wait_start_time = time.time()
+            self.ssd_db.wait_until_eviction_done()
+            logging.info(
+                f"state_dict wait for ongoing eviction: {time.time() - evict_wait_start_time} s"
+            )
+            self.flush(force=should_flush)
+        return snapshot_handle, checkpoint_handle
+
+    def get_embedding_dim_for_kvt(
+        self, metaheader_dim: int, emb_dim: int, is_loading_checkpoint: bool
+    ) -> int:
+        if self.load_ckpt_without_opt:
+            # For silvertorch publish, we don't want to load opt into backend due to limited cpu memory in publish host.
+            # So we need to load the whole row into state dict which loading the checkpoint in st publish, then only save weight into backend, after that
+            # backend will only have metaheader + weight.
+            # For the first loading, we need to set dim with metaheader_dim + emb_dim + optimizer_state_dim, otherwise the checkpoint loadding will throw size mismatch error
+            # after the first loading, we only need to get metaheader+weight from backend for state dict, so we can set dim with metaheader_dim + emb
+            if is_loading_checkpoint:
+                return (
+                    (
+                        metaheader_dim  # metaheader is already padded
+                        + pad4(emb_dim)
+                        + pad4(self.optimizer_state_dim)
+                    )
+                    if self.backend_return_whole_row
+                    else emb_dim
+                )
+            else:
+                return metaheader_dim + pad4(emb_dim)
+        else:
+            return (
+                (
+                    metaheader_dim  # metaheader is already padded
+                    + pad4(emb_dim)
+                    + pad4(self.optimizer_state_dim)
+                )
+                if self.backend_return_whole_row
+                else emb_dim
+            )
+
     @torch.jit.export
     def split_embedding_weights(
         self,
         no_snapshot: bool = True,
         should_flush: bool = False,
-    ) ->
+    ) -> tuple[  # TODO: make this a NamedTuple for readability
+        Union[list[PartiallyMaterializedTensor], list[torch.Tensor]],
+        Optional[list[torch.Tensor]],
+        Optional[list[torch.Tensor]],
+        Optional[list[torch.Tensor]],
+    ]:
         """
         This method is intended to be used by the checkpointing engine
         only.
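Note: `get_embedding_dim_for_kvt` computes the KVTensorWrapper row width from the (already padded) metaheader, the padded embedding dim, and optionally the padded optimizer-state dim. Assuming `pad4` rounds a dimension up to the next multiple of 4 elements (an inference from its use here, not confirmed by this diff), the arithmetic looks like this:

```python
# Worked example of the padded row-width arithmetic, under the assumption that
# pad4() rounds up to the next multiple of 4 elements.
def pad4(n: int) -> int:
    return (n + 3) // 4 * 4

metaheader_dim = 8          # already padded by the backend
emb_dim = 130
optimizer_state_dim = 3

whole_row_dim = metaheader_dim + pad4(emb_dim) + pad4(optimizer_state_dim)
weight_only_dim = metaheader_dim + pad4(emb_dim)
print(whole_row_dim, weight_only_dim)   # 144 140
```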
@@ -1784,50 +3196,454 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
1784
3196
|
operation, only set to True when necessary.
|
|
1785
3197
|
|
|
1786
3198
|
Returns:
|
|
1787
|
-
|
|
3199
|
+
tuples of 3 lists, each element corresponds to a logical table
|
|
3200
|
+
1st arg: partially materialized tensors, each representing a table
|
|
3201
|
+
2nd arg: input id sorted in bucket id ascending order
|
|
3202
|
+
3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
|
|
3203
|
+
where for the i th element, we have i + bucket_id_start = global bucket id
|
|
3204
|
+
4th arg: kvzch eviction metadata for each input id sorted in bucket id ascending order
|
|
1788
3205
|
"""
|
|
1789
|
-
|
|
1790
|
-
|
|
1791
|
-
|
|
1792
|
-
|
|
1793
|
-
snapshot_handle = None
|
|
1794
|
-
else:
|
|
1795
|
-
if should_flush:
|
|
1796
|
-
# Flush L1 and L2 caches
|
|
1797
|
-
self.flush()
|
|
1798
|
-
snapshot_handle = self.ssd_db.create_snapshot()
|
|
3206
|
+
snapshot_handle, checkpoint_handle = self._may_create_snapshot_for_state_dict(
|
|
3207
|
+
no_snapshot=no_snapshot,
|
|
3208
|
+
should_flush=should_flush,
|
|
3209
|
+
)
|
|
1799
3210
|
|
|
1800
3211
|
dtype = self.weights_precision.as_dtype()
|
|
1801
|
-
|
|
3212
|
+
if self.load_state_dict and self.kv_zch_params:
|
|
3213
|
+
# init for checkpointing loading
|
|
3214
|
+
assert (
|
|
3215
|
+
self._cached_kvzch_data is not None
|
|
3216
|
+
), "weight id and bucket state are not initialized for load checkpointing"
|
|
3217
|
+
return (
|
|
3218
|
+
self._cached_kvzch_data.cached_weight_tensor_per_table,
|
|
3219
|
+
self._cached_kvzch_data.cached_id_tensor_per_table,
|
|
3220
|
+
self._cached_kvzch_data.cached_bucket_splits,
|
|
3221
|
+
[], # metadata tensor is not needed for checkpointing loading
|
|
3222
|
+
)
|
|
3223
|
+
start_time = time.time()
|
|
3224
|
+
pmt_splits = []
|
|
3225
|
+
bucket_sorted_id_splits = [] if self.kv_zch_params else None
|
|
3226
|
+
active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
|
|
3227
|
+
metadata_splits = [] if self.kv_zch_params else None
|
|
3228
|
+
skip_metadata = False
|
|
3229
|
+
|
|
3230
|
+
table_offset = 0
|
|
3231
|
+
for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
|
|
3232
|
+
is_loading_checkpoint = False
|
|
3233
|
+
bucket_ascending_id_tensor = None
|
|
3234
|
+
bucket_t = None
|
|
3235
|
+
metadata_tensor = None
|
|
3236
|
+
row_offset = table_offset
|
|
3237
|
+
metaheader_dim = 0
|
|
3238
|
+
if self.kv_zch_params:
|
|
3239
|
+
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
|
|
3240
|
+
# pyre-ignore
|
|
3241
|
+
bucket_size = self.kv_zch_params.bucket_sizes[i]
|
|
3242
|
+
metaheader_dim = (
|
|
3243
|
+
# pyre-ignore[16]
|
|
3244
|
+
self.kv_zch_params.eviction_policy.meta_header_lens[i]
|
|
3245
|
+
)
|
|
1802
3246
|
|
|
1803
|
-
|
|
1804
|
-
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
3247
|
+
# linearize with table offset
|
|
3248
|
+
table_input_id_start = table_offset
|
|
3249
|
+
table_input_id_end = table_offset + emb_height
|
|
3250
|
+
# 1. get all keys from backend for one table
|
|
3251
|
+
unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
|
|
3252
|
+
table_input_id_start,
|
|
3253
|
+
table_input_id_end,
|
|
3254
|
+
table_offset,
|
|
3255
|
+
snapshot_handle,
|
|
3256
|
+
)
|
|
3257
|
+
# 2. sorting keys in bucket ascending order
|
|
3258
|
+
bucket_ascending_id_tensor, bucket_t = (
|
|
3259
|
+
torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
|
|
3260
|
+
unordered_id_tensor,
|
|
3261
|
+
0, # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
|
|
3262
|
+
0, # local bucket offset
|
|
3263
|
+
bucket_id_end - bucket_id_start, # local bucket num
|
|
3264
|
+
bucket_size,
|
|
3265
|
+
)
|
|
3266
|
+
)
|
|
3267
|
+
metadata_tensor = self._ssd_db.get_kv_zch_eviction_metadata_by_snapshot(
|
|
3268
|
+
bucket_ascending_id_tensor + table_offset,
|
|
3269
|
+
torch.as_tensor(bucket_ascending_id_tensor.size(0)),
|
|
3270
|
+
snapshot_handle,
|
|
3271
|
+
).view(-1, 1)
|
|
3272
|
+
|
|
3273
|
+
# 3. convert local id back to global id
|
|
3274
|
+
bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
|
|
3275
|
+
|
|
3276
|
+
if (
|
|
3277
|
+
bucket_ascending_id_tensor.size(0) == 0
|
|
3278
|
+
and self.local_weight_counts[i] > 0
|
|
3279
|
+
):
|
|
3280
|
+
logging.info(
|
|
3281
|
+
f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
|
|
3282
|
+
)
|
|
3283
|
+
if self.global_id_per_rank[i].numel() != 0:
|
|
3284
|
+
assert (
|
|
3285
|
+
self.local_weight_counts[i]
|
|
3286
|
+
== self.global_id_per_rank[i].numel()
|
|
3287
|
+
), f"local weight count and global id per rank size mismatch, with {self.local_weight_counts[i]} and {self.global_id_per_rank[i].numel()}"
|
|
3288
|
+
bucket_ascending_id_tensor = self.global_id_per_rank[i].to(
|
|
3289
|
+
device=torch.device("cpu"), dtype=torch.int64
|
|
3290
|
+
)
|
|
3291
|
+
else:
|
|
3292
|
+
bucket_ascending_id_tensor = torch.zeros(
|
|
3293
|
+
(self.local_weight_counts[i], 1),
|
|
3294
|
+
device=torch.device("cpu"),
|
|
3295
|
+
dtype=torch.int64,
|
|
3296
|
+
)
|
|
3297
|
+
skip_metadata = True
|
|
3298
|
+
is_loading_checkpoint = True
|
|
3299
|
+
|
|
3300
|
+
# self.local_weight_counts[i] = 0 # Reset the count
|
|
3301
|
+
|
|
3302
|
+
# pyre-ignore [16] bucket_sorted_id_splits is not None
|
|
3303
|
+
bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
|
|
3304
|
+
active_id_cnt_per_bucket_split.append(bucket_t)
|
|
3305
|
+
if skip_metadata:
|
|
3306
|
+
metadata_splits = None
|
|
3307
|
+
else:
|
|
3308
|
+
metadata_splits.append(metadata_tensor)
|
|
3309
|
+
|
|
3310
|
+
# for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing
|
|
3311
|
+
# but in backend, local id is used during training, so the KVTensorWrapper needs to convert global id to local id
|
|
3312
|
+
# first, then linearize the local id with the table offset; the formula is x + table_offset - local_shard_offset
|
|
3313
|
+
# to achieve this, the row_offset will be set to (table_offset - local_shard_offset)
|
|
3314
|
+
row_offset = table_offset - (bucket_id_start * bucket_size)
|
|
3315
|
+
|
|
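A minimal numeric sketch of the id linearization described in the comments above, using made-up values for table_offset, bucket_id_start and bucket_size (none of these numbers come from the diff):

    # hypothetical numbers illustrating x + table_offset - local_shard_offset
    table_offset = 1000                                   # start row of this table in the linearized space
    bucket_id_start, bucket_size = 3, 100                 # buckets owned by this rank
    local_shard_offset = bucket_id_start * bucket_size    # 300
    row_offset = table_offset - local_shard_offset        # 700, as computed above
    global_id = 350                                       # a global id owned by this rank
    local_id = global_id - local_shard_offset             # 50, what the backend uses during training
    backend_row = global_id + row_offset                  # 1050 == local_id + table_offset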
3316
|
+
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
|
|
3317
|
+
shape=[
|
|
3318
|
+
(
|
|
3319
|
+
bucket_ascending_id_tensor.size(0)
|
|
3320
|
+
if bucket_ascending_id_tensor is not None
|
|
3321
|
+
else emb_height
|
|
3322
|
+
),
|
|
3323
|
+
self.get_embedding_dim_for_kvt(
|
|
3324
|
+
metaheader_dim, emb_dim, is_loading_checkpoint
|
|
3325
|
+
),
|
|
3326
|
+
],
|
|
3327
|
+
dtype=dtype,
|
|
1809
3328
|
row_offset=row_offset,
|
|
1810
3329
|
snapshot_handle=snapshot_handle,
|
|
3330
|
+
# set bucket_ascending_id_tensor on the kvt wrapper, so narrow will follow the id order to return
|
|
3331
|
+
# embedding weights.
|
|
3332
|
+
sorted_indices=(
|
|
3333
|
+
bucket_ascending_id_tensor if self.kv_zch_params else None
|
|
3334
|
+
),
|
|
3335
|
+
checkpoint_handle=checkpoint_handle,
|
|
3336
|
+
only_load_weight=(
|
|
3337
|
+
True
|
|
3338
|
+
if self.load_ckpt_without_opt and is_loading_checkpoint
|
|
3339
|
+
else False
|
|
3340
|
+
),
|
|
3341
|
+
)
|
|
3342
|
+
(
|
|
3343
|
+
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
|
|
3344
|
+
if self.backend_type == BackendType.SSD
|
|
3345
|
+
else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
|
|
3346
|
+
)
|
|
3347
|
+
table_offset += emb_height
|
|
3348
|
+
pmt_splits.append(
|
|
3349
|
+
PartiallyMaterializedTensor(
|
|
3350
|
+
tensor_wrapper,
|
|
3351
|
+
True if self.kv_zch_params else False,
|
|
3352
|
+
)
|
|
3353
|
+
)
|
|
3354
|
+
logging.info(
|
|
3355
|
+
f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms, "
|
|
3356
|
+
)
|
|
3357
|
+
if self.kv_zch_params is not None:
|
|
3358
|
+
logging.info(
|
|
3359
|
+
# pyre-ignore [16]
|
|
3360
|
+
f"num ids list: {[ids.numel() for ids in bucket_sorted_id_splits]}"
|
|
3361
|
+
)
|
|
3362
|
+
|
|
3363
|
+
return (
|
|
3364
|
+
pmt_splits,
|
|
3365
|
+
bucket_sorted_id_splits,
|
|
3366
|
+
active_id_cnt_per_bucket_split,
|
|
3367
|
+
metadata_splits,
|
|
3368
|
+
)
|
|
3369
|
+
|
|
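A rough caller-side sketch of consuming the 4-tuple returned above, assuming the enclosing method is the split_embedding_weights accessor named in the latency log; the tbe handle is a hypothetical instance:

    # hypothetical consumer of the (pmts, ids, id_counts, metadata) tuple returned above
    pmts, ids_per_table, id_counts_per_bucket, metadata = tbe.split_embedding_weights()
    for t, ids in enumerate(ids_per_table):
        num_ids = ids.numel() if ids is not None else 0
        print(f"table {t}: {num_ids} checkpointed ids")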
3370
|
+
@torch.jit.ignore
|
|
3371
|
+
def _apply_state_dict_w_offloading(self) -> None:
|
|
3372
|
+
# Row count per table
|
|
3373
|
+
rows, _ = zip(*self.embedding_specs)
|
|
3374
|
+
# Cumulative row counts per table for rowwise states
|
|
3375
|
+
row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
|
|
3376
|
+
|
|
3377
|
+
for t, _ in enumerate(self.embedding_specs):
|
|
3378
|
+
# pyre-ignore [16]
|
|
3379
|
+
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
|
|
3380
|
+
# pyre-ignore [16]
|
|
3381
|
+
bucket_size = self.kv_zch_params.bucket_sizes[t]
|
|
3382
|
+
row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size
|
|
3383
|
+
|
|
3384
|
+
# pyre-ignore [16]
|
|
3385
|
+
weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
|
|
3386
|
+
# pyre-ignore [16]
|
|
3387
|
+
opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]
|
|
3388
|
+
|
|
3389
|
+
self.streaming_write_weight_and_id_per_table(
|
|
3390
|
+
weight_state,
|
|
3391
|
+
opt_states,
|
|
3392
|
+
# pyre-ignore [16]
|
|
3393
|
+
self._cached_kvzch_data.cached_id_tensor_per_table[t],
|
|
3394
|
+
row_offset,
|
|
3395
|
+
)
|
|
3396
|
+
self._cached_kvzch_data.cached_weight_tensor_per_table[t] = None
|
|
3397
|
+
self._cached_kvzch_data.cached_optimizer_states_per_table[t] = None
|
|
3398
|
+
|
|
3399
|
+
@torch.jit.ignore
|
|
3400
|
+
def _apply_state_dict_no_offloading(self) -> None:
|
|
3401
|
+
# Row count per table
|
|
3402
|
+
rows, _ = zip(*self.embedding_specs)
|
|
3403
|
+
# Cumulative row counts per table for rowwise states
|
|
3404
|
+
row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
|
|
3405
|
+
|
|
3406
|
+
def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
|
|
3407
|
+
device = dst.device
|
|
3408
|
+
dst.index_put_(
|
|
3409
|
+
indices=(
|
|
3410
|
+
# indices is expected to be a tuple of Tensors, not Tensor
|
|
3411
|
+
indices.to(device).view(-1),
|
|
3412
|
+
),
|
|
3413
|
+
values=src.to(device),
|
|
3414
|
+
)
|
|
3415
|
+
|
|
3416
|
+
for t, _ in enumerate(rows):
|
|
3417
|
+
# pyre-ignore [16]
|
|
3418
|
+
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
|
|
3419
|
+
# pyre-ignore [16]
|
|
3420
|
+
bucket_size = self.kv_zch_params.bucket_sizes[t]
|
|
3421
|
+
row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size
|
|
3422
|
+
|
|
3423
|
+
# pyre-ignore [16]
|
|
3424
|
+
weights = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
|
|
3425
|
+
# pyre-ignore [16]
|
|
3426
|
+
ids = self._cached_kvzch_data.cached_id_tensor_per_table[t]
|
|
3427
|
+
local_ids = ids + row_offset
|
|
3428
|
+
|
|
3429
|
+
logging.info(
|
|
3430
|
+
f"applying sd for table {t} without optimizer offloading, local_ids is {local_ids}"
|
|
3431
|
+
)
|
|
3432
|
+
# pyre-ignore [16]
|
|
3433
|
+
opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]
|
|
3434
|
+
|
|
3435
|
+
# Set up the plan for copying optimizer states over
|
|
3436
|
+
if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
|
|
3437
|
+
mapping = [(opt_states[0], self.momentum1_dev)]
|
|
3438
|
+
elif self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
|
|
3439
|
+
mapping = [
|
|
3440
|
+
(opt_states[0], self.momentum1_dev),
|
|
3441
|
+
(opt_states[1], self.momentum2_dev),
|
|
3442
|
+
]
|
|
3443
|
+
else:
|
|
3444
|
+
mapping = []
|
|
3445
|
+
|
|
3446
|
+
# Execute the plan and copy the optimizer states over
|
|
3447
|
+
# pyre-ignore [6]
|
|
3448
|
+
[copy_optimizer_state_(dst, src, local_ids) for (src, dst) in mapping]
|
|
3449
|
+
|
|
3450
|
+
self.ssd_db.set_cuda(
|
|
3451
|
+
local_ids.view(-1),
|
|
3452
|
+
weights,
|
|
3453
|
+
torch.as_tensor(local_ids.size(0)),
|
|
3454
|
+
1,
|
|
3455
|
+
False,
|
|
3456
|
+
)
|
|
3457
|
+
|
|
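A standalone toy of the rowwise scatter pattern used by copy_optimizer_state_ above (tensor sizes are made up):

    import torch

    dst = torch.zeros(6)                        # stand-in for a rowwise optimizer state buffer
    src = torch.tensor([0.5, 0.7, 0.9])         # cached per-row states from the checkpoint
    local_ids = torch.tensor([[1], [3], [4]])   # 2D id tensor, as in the cached id tensor above
    dst.index_put_((local_ids.view(-1),), src)  # indices must be a tuple of Tensors
    # dst -> tensor([0.0000, 0.5000, 0.0000, 0.7000, 0.9000, 0.0000])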
3458
|
+
@torch.jit.ignore
|
|
3459
|
+
def apply_state_dict(self) -> None:
|
|
3460
|
+
if self.backend_return_whole_row:
|
|
3461
|
+
logging.info(
|
|
3462
|
+
"backend_return_whole_row is enabled, no need to apply_state_dict"
|
|
3463
|
+
)
|
|
3464
|
+
return
|
|
3465
|
+
# After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
|
|
3466
|
+
# Caller should call this function to apply the cached states to backend.
|
|
3467
|
+
if self.load_state_dict is False:
|
|
3468
|
+
return
|
|
3469
|
+
self.load_state_dict = False
|
|
3470
|
+
assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
|
|
3471
|
+
assert (
|
|
3472
|
+
self._cached_kvzch_data is not None
|
|
3473
|
+
and self._cached_kvzch_data.cached_optimizer_states_per_table is not None
|
|
3474
|
+
), "optimizer state is not initialized for load checkpointing"
|
|
3475
|
+
assert (
|
|
3476
|
+
self._cached_kvzch_data.cached_weight_tensor_per_table is not None
|
|
3477
|
+
and self._cached_kvzch_data.cached_id_tensor_per_table is not None
|
|
3478
|
+
), "weight and id state is not initialized for load checkpointing"
|
|
3479
|
+
|
|
3480
|
+
# Compute the number of elements of cache_dtype needed to store the
|
|
3481
|
+
# optimizer state, round to the nearest 4
|
|
3482
|
+
# optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
|
|
3483
|
+
# apply weight and optimizer state per table
|
|
3484
|
+
if self.enable_optimizer_offloading:
|
|
3485
|
+
self._apply_state_dict_w_offloading()
|
|
3486
|
+
else:
|
|
3487
|
+
self._apply_state_dict_no_offloading()
|
|
3488
|
+
|
|
3489
|
+
self.clear_cache()
|
|
3490
|
+
|
|
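A rough caller-side sketch of the checkpoint-load flow implied by the comments in apply_state_dict and enable_load_state_dict_mode; the constructor and checkpoint loader below are hypothetical placeholders, and only the method names on tbe come from this module:

    # hypothetical flow; build_ssd_tbe, weight_counts and load_checkpoint_into are placeholders
    tbe = build_ssd_tbe(...)
    for t, count in enumerate(weight_counts):
        tbe.set_local_weight_counts_for_table(t, count)   # required before enabling load mode
    tbe.enable_load_state_dict_mode()                     # allocates the _cached_kvzch_data buffers
    load_checkpoint_into(tbe)                             # fills cached weight/id/optimizer tensors
    tbe.apply_state_dict()                                # writes cached state to the backend, then clears it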
3491
|
+
@torch.jit.ignore
|
|
3492
|
+
def streaming_write_weight_and_id_per_table(
|
|
3493
|
+
self,
|
|
3494
|
+
weight_state: torch.Tensor,
|
|
3495
|
+
opt_states: list[torch.Tensor],
|
|
3496
|
+
id_tensor: torch.Tensor,
|
|
3497
|
+
row_offset: int,
|
|
3498
|
+
) -> None:
|
|
3499
|
+
"""
|
|
3500
|
+
This function writes the weight, optimizer state, and id to the backend using the kvt wrapper.
|
|
3501
|
+
To avoid overusing memory, the weight and id are written to the backend in a rolling-window manner.
|
|
3502
|
+
|
|
3503
|
+
Args:
|
|
3504
|
+
weight_state (torch.tensor): The weight state tensor to be written.
|
|
3505
|
+
opt_states (torch.tensor): The optimizer state tensor(s) to be written.
|
|
3506
|
+
id_tensor (torch.tensor): The id tensor to be written.
|
|
3507
|
+
"""
|
|
3508
|
+
D = weight_state.size(1)
|
|
3509
|
+
dtype = self.weights_precision.as_dtype()
|
|
3510
|
+
|
|
3511
|
+
optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
|
|
3512
|
+
D, self.weights_precision, self.optimizer_state_dtypes
|
|
3513
|
+
)
|
|
3514
|
+
optimizer_state_size_table = self.optimizer.state_size_table(D)
|
|
3515
|
+
|
|
3516
|
+
kvt = torch.classes.fbgemm.KVTensorWrapper(
|
|
3517
|
+
shape=[weight_state.size(0), self.cache_row_dim],
|
|
3518
|
+
dtype=dtype,
|
|
3519
|
+
row_offset=row_offset,
|
|
3520
|
+
snapshot_handle=None,
|
|
3521
|
+
sorted_indices=id_tensor,
|
|
3522
|
+
)
|
|
3523
|
+
(
|
|
3524
|
+
kvt.set_embedding_rocks_dp_wrapper(self.ssd_db)
|
|
3525
|
+
if self.backend_type == BackendType.SSD
|
|
3526
|
+
else kvt.set_dram_db_wrapper(self.ssd_db)
|
|
3527
|
+
)
|
|
3528
|
+
|
|
3529
|
+
# TODO: make chunk_size configurable or dynamic
|
|
3530
|
+
chunk_size = 10000
|
|
3531
|
+
row = weight_state.size(0)
|
|
3532
|
+
|
|
3533
|
+
for i in range(0, row, chunk_size):
|
|
3534
|
+
# Construct the chunk buffer, using the weights precision as the dtype
|
|
3535
|
+
length = min(chunk_size, row - i)
|
|
3536
|
+
chunk_buffer = torch.empty(
|
|
3537
|
+
length,
|
|
3538
|
+
self.cache_row_dim,
|
|
3539
|
+
dtype=dtype,
|
|
3540
|
+
device="cpu",
|
|
3541
|
+
)
|
|
3542
|
+
|
|
3543
|
+
# Copy the weight state over to the chunk buffer
|
|
3544
|
+
chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
|
|
3545
|
+
|
|
3546
|
+
# Copy the optimizer state(s) over to the chunk buffer
|
|
3547
|
+
for o, opt_state in enumerate(opt_states):
|
|
3548
|
+
# Fetch the state name based on the index
|
|
3549
|
+
state_name = self.optimizer.state_names()[o]
|
|
3550
|
+
|
|
3551
|
+
# Fetch the byte offsets for the optimizer state by its name
|
|
3552
|
+
start, end = optimizer_state_byte_offsets[state_name]
|
|
3553
|
+
|
|
3554
|
+
# Assume that the opt_state passed in already has dtype matching
|
|
3555
|
+
# self.optimizer_state_dtypes[state_name]
|
|
3556
|
+
opt_state_byteview = opt_state.view(
|
|
3557
|
+
# Force it to be 2D table, with row size matching the
|
|
3558
|
+
# optimizer state size
|
|
3559
|
+
-1,
|
|
3560
|
+
optimizer_state_size_table[state_name],
|
|
3561
|
+
).view(
|
|
3562
|
+
# Then force tensor to byte view
|
|
3563
|
+
dtype=torch.uint8
|
|
3564
|
+
)
|
|
3565
|
+
|
|
3566
|
+
# Convert the chunk buffer and optimizer state to byte views
|
|
3567
|
+
# Then use the start and end offsets to narrow the chunk buffer
|
|
3568
|
+
# and copy opt_state over
|
|
3569
|
+
chunk_buffer.view(dtype=torch.uint8)[:, start:end] = opt_state_byteview[
|
|
3570
|
+
i : i + length, :
|
|
3571
|
+
]
|
|
3572
|
+
|
|
3573
|
+
# Write chunk to KVTensor
|
|
3574
|
+
kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))
|
|
3575
|
+
|
|
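A small sketch of the byte-view packing performed per chunk above; the offsets here are illustrative and not the actual layout computed by byte_offsets_along_row:

    import torch

    D = 4                                                      # embedding dim (illustrative)
    row = torch.empty(1, D + 2, dtype=torch.float16)           # weight + 4 bytes of optimizer state
    weight = torch.randn(1, D, dtype=torch.float16)
    momentum = torch.randn(1, 1, dtype=torch.float32)          # one fp32 rowwise state
    row[:, :D] = weight                                        # copy the weight portion directly
    start, end = 2 * D, 2 * D + 4                              # byte offsets of the fp32 state (assumed)
    row.view(dtype=torch.uint8)[:, start:end] = momentum.view(dtype=torch.uint8)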
3576
|
+
@torch.jit.ignore
|
|
3577
|
+
def enable_load_state_dict_mode(self) -> None:
|
|
3578
|
+
if self.backend_return_whole_row:
|
|
3579
|
+
logging.info(
|
|
3580
|
+
"backend_return_whole_row is enabled, no need to enable load_state_dict mode"
|
|
3581
|
+
)
|
|
3582
|
+
return
|
|
3583
|
+
# Enable load state dict mode before loading checkpoint
|
|
3584
|
+
if self.load_state_dict:
|
|
3585
|
+
return
|
|
3586
|
+
self.load_state_dict = True
|
|
3587
|
+
|
|
3588
|
+
dtype = self.weights_precision.as_dtype()
|
|
3589
|
+
_, dims = zip(*self.embedding_specs)
|
|
3590
|
+
|
|
3591
|
+
self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
|
|
3592
|
+
|
|
3593
|
+
for i, _ in enumerate(self.embedding_specs):
|
|
3594
|
+
# For checkpointing loading, we need to store the weight and id
|
|
3595
|
+
# tensor temporarily in memory. First check that the local_weight_counts
|
|
3596
|
+
# are properly set before even initializing the optimizer states
|
|
3597
|
+
assert (
|
|
3598
|
+
self.local_weight_counts[i] > 0
|
|
3599
|
+
), f"local_weight_counts for table {i} is not set"
|
|
3600
|
+
|
|
3601
|
+
# pyre-ignore [16]
|
|
3602
|
+
self._cached_kvzch_data.cached_optimizer_states_per_table = (
|
|
3603
|
+
self.optimizer.empty_states(
|
|
3604
|
+
self.local_weight_counts,
|
|
3605
|
+
dims,
|
|
3606
|
+
self.optimizer_state_dtypes,
|
|
3607
|
+
)
|
|
3608
|
+
)
|
|
3609
|
+
|
|
3610
|
+
for i, (_, emb_dim) in enumerate(self.embedding_specs):
|
|
3611
|
+
# pyre-ignore [16]
|
|
3612
|
+
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
|
|
3613
|
+
rows = self.local_weight_counts[i]
|
|
3614
|
+
weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
|
|
3615
|
+
# pyre-ignore [16]
|
|
3616
|
+
self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
|
|
3617
|
+
logging.info(
|
|
3618
|
+
f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}"
|
|
3619
|
+
)
|
|
3620
|
+
id_tensor = torch.zeros((rows, 1), dtype=torch.int64, device="cpu")
|
|
3621
|
+
# pyre-ignore [16]
|
|
3622
|
+
self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
|
|
3623
|
+
# pyre-ignore [16]
|
|
3624
|
+
self._cached_kvzch_data.cached_bucket_splits.append(
|
|
3625
|
+
torch.empty(
|
|
3626
|
+
(bucket_id_end - bucket_id_start, 1),
|
|
3627
|
+
dtype=torch.int64,
|
|
3628
|
+
device="cpu",
|
|
3629
|
+
)
|
|
1811
3630
|
)
|
|
1812
|
-
row_offset += emb_height
|
|
1813
|
-
splits.append(PartiallyMaterializedTensor(tensor_wrapper))
|
|
1814
|
-
return splits
|
|
1815
3631
|
|
|
1816
3632
|
@torch.jit.export
|
|
1817
3633
|
def set_learning_rate(self, lr: float) -> None:
|
|
1818
3634
|
"""
|
|
1819
3635
|
Sets the learning rate.
|
|
3636
|
+
|
|
3637
|
+
Args:
|
|
3638
|
+
lr (float): The learning rate value to set to
|
|
1820
3639
|
"""
|
|
1821
3640
|
self._set_learning_rate(lr)
|
|
1822
3641
|
|
|
1823
3642
|
def get_learning_rate(self) -> float:
|
|
1824
3643
|
"""
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
Args:
|
|
1828
|
-
lr (float): The learning rate value to set to
|
|
3644
|
+
Get and return the learning rate.
|
|
1829
3645
|
"""
|
|
1830
|
-
return self.
|
|
3646
|
+
return self.learning_rate_tensor.item()
|
|
1831
3647
|
|
|
1832
3648
|
@torch.jit.ignore
|
|
1833
3649
|
def _set_learning_rate(self, lr: float) -> float:
|
|
@@ -1835,14 +3651,30 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
1835
3651
|
Helper function to script `set_learning_rate`.
|
|
1836
3652
|
Note that returning None does not work.
|
|
1837
3653
|
"""
|
|
1838
|
-
self.
|
|
3654
|
+
self.learning_rate_tensor = torch.tensor(
|
|
3655
|
+
lr, device=torch.device("cpu"), dtype=torch.float32
|
|
3656
|
+
)
|
|
1839
3657
|
return 0.0
|
|
1840
3658
|
|
|
1841
|
-
def flush(self) -> None:
|
|
3659
|
+
def flush(self, force: bool = False) -> None:
|
|
3660
|
+
# allow force flush from split_embedding_weights to cover edge cases, e.g. checkpointing
|
|
3661
|
+
# after training 0 batches
|
|
3662
|
+
if not self.training:
|
|
3663
|
+
# for eval mode, we should not write anything to embedding
|
|
3664
|
+
return
|
|
3665
|
+
|
|
3666
|
+
if self.step == self.last_flush_step and not force:
|
|
3667
|
+
logging.info(
|
|
3668
|
+
f"SSD TBE has been flushed at {self.last_flush_step=} already for tbe:{self.tbe_unique_id}"
|
|
3669
|
+
)
|
|
3670
|
+
return
|
|
3671
|
+
logging.info(
|
|
3672
|
+
f"SSD TBE flush at {self.step=}, it is an expensive call please be cautious"
|
|
3673
|
+
)
|
|
1842
3674
|
active_slots_mask = self.lxu_cache_state != -1
|
|
1843
3675
|
|
|
1844
3676
|
active_weights_gpu = self.lxu_cache_weights[active_slots_mask.view(-1)].view(
|
|
1845
|
-
-1, self.
|
|
3677
|
+
-1, self.cache_row_dim
|
|
1846
3678
|
)
|
|
1847
3679
|
active_ids_gpu = self.lxu_cache_state.view(-1)[active_slots_mask.view(-1)]
|
|
1848
3680
|
|
|
@@ -1858,24 +3690,38 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
1858
3690
|
torch.tensor([active_ids_cpu.numel()]),
|
|
1859
3691
|
)
|
|
1860
3692
|
self.ssd_db.flush()
|
|
3693
|
+
self.last_flush_step = self.step
|
|
3694
|
+
|
|
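A toy, module-free illustration of the step-based flush deduplication above; _FlushGate is a made-up class that only mirrors the guard logic:

    class _FlushGate:
        def __init__(self) -> None:
            self.step = 0
            self.last_flush_step = -1

        def flush(self, force: bool = False) -> bool:
            # skip when this step was already flushed, unless force is set
            if self.step == self.last_flush_step and not force:
                return False
            self.last_flush_step = self.step
            return True

    gate = _FlushGate()
    assert gate.flush() is True            # first flush at step 0 goes through
    assert gate.flush() is False           # same step: deduplicated
    assert gate.flush(force=True) is True  # force bypasses the guard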
3695
|
+
def create_rocksdb_hard_link_snapshot(self) -> None:
|
|
3696
|
+
"""
|
|
3697
|
+
Create a rocksdb hard link snapshot to provide cross-process access to the underlying data
|
|
3698
|
+
"""
|
|
3699
|
+
if self.backend_type == BackendType.SSD:
|
|
3700
|
+
self.ssd_db.create_rocksdb_hard_link_snapshot(self.step)
|
|
3701
|
+
else:
|
|
3702
|
+
logging.warning(
|
|
3703
|
+
"create_rocksdb_hard_link_snapshot is only supported for SSD backend"
|
|
3704
|
+
)
|
|
1861
3705
|
|
|
1862
3706
|
def prepare_inputs(
|
|
1863
3707
|
self,
|
|
1864
3708
|
indices: Tensor,
|
|
1865
3709
|
offsets: Tensor,
|
|
1866
3710
|
per_sample_weights: Optional[Tensor] = None,
|
|
1867
|
-
batch_size_per_feature_per_rank: Optional[
|
|
1868
|
-
|
|
3711
|
+
batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
|
|
3712
|
+
vbe_output: Optional[Tensor] = None,
|
|
3713
|
+
vbe_output_offsets: Optional[Tensor] = None,
|
|
3714
|
+
) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
|
|
1869
3715
|
"""
|
|
1870
3716
|
Prepare TBE inputs
|
|
1871
3717
|
"""
|
|
1872
3718
|
# Generate VBE metadata
|
|
1873
3719
|
vbe_metadata = self._generate_vbe_metadata(
|
|
1874
|
-
offsets, batch_size_per_feature_per_rank
|
|
3720
|
+
offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
|
|
1875
3721
|
)
|
|
1876
3722
|
|
|
1877
3723
|
# Force casting indices and offsets to long
|
|
1878
|
-
|
|
3724
|
+
indices, offsets = indices.long(), offsets.long()
|
|
1879
3725
|
|
|
1880
3726
|
# Force casting per_sample_weights to float
|
|
1881
3727
|
if per_sample_weights is not None:
|
|
@@ -1891,12 +3737,13 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
1891
3737
|
per_sample_weights,
|
|
1892
3738
|
B_offsets=vbe_metadata.B_offsets,
|
|
1893
3739
|
max_B=vbe_metadata.max_B,
|
|
3740
|
+
bounds_check_version=self.bounds_check_version,
|
|
1894
3741
|
)
|
|
1895
3742
|
|
|
1896
3743
|
return indices, offsets, per_sample_weights, vbe_metadata
|
|
1897
3744
|
|
|
1898
3745
|
@torch.jit.ignore
|
|
1899
|
-
def
|
|
3746
|
+
def _report_kv_backend_stats(self) -> None:
|
|
1900
3747
|
"""
|
|
1901
3748
|
All ssd stats report function entrance
|
|
1902
3749
|
"""
|
|
@@ -1906,9 +3753,15 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
1906
3753
|
if not self.stats_reporter.should_report(self.step):
|
|
1907
3754
|
return
|
|
1908
3755
|
self._report_ssd_l1_cache_stats()
|
|
1909
|
-
|
|
1910
|
-
self.
|
|
1911
|
-
|
|
3756
|
+
|
|
3757
|
+
if self.backend_type == BackendType.SSD:
|
|
3758
|
+
self._report_ssd_io_stats()
|
|
3759
|
+
self._report_ssd_mem_usage()
|
|
3760
|
+
self._report_l2_cache_perf_stats()
|
|
3761
|
+
if self.backend_type == BackendType.DRAM:
|
|
3762
|
+
self._report_dram_kv_perf_stats()
|
|
3763
|
+
if self.kv_zch_params and self.kv_zch_params.eviction_policy:
|
|
3764
|
+
self._report_eviction_stats()
|
|
1912
3765
|
|
|
1913
3766
|
@torch.jit.ignore
|
|
1914
3767
|
def _report_ssd_l1_cache_stats(self) -> None:
|
|
@@ -1925,7 +3778,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
1925
3778
|
ssd_cache_stats = self.ssd_cache_stats.tolist()
|
|
1926
3779
|
if len(self.last_reported_ssd_stats) == 0:
|
|
1927
3780
|
self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
|
|
1928
|
-
ssd_cache_stats_delta:
|
|
3781
|
+
ssd_cache_stats_delta: list[float] = [0.0] * len(ssd_cache_stats)
|
|
1929
3782
|
for i in range(len(ssd_cache_stats)):
|
|
1930
3783
|
ssd_cache_stats_delta[i] = (
|
|
1931
3784
|
ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
|
|
@@ -1942,11 +3795,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
1942
3795
|
data_bytes=int(
|
|
1943
3796
|
ssd_cache_stats_delta[stat_index.value]
|
|
1944
3797
|
* element_size
|
|
1945
|
-
* self.
|
|
3798
|
+
* self.cache_row_dim
|
|
1946
3799
|
/ passed_steps
|
|
1947
3800
|
),
|
|
1948
3801
|
)
|
|
1949
|
-
|
|
3802
|
+
|
|
1950
3803
|
self.stats_reporter.report_data_amount(
|
|
1951
3804
|
iteration_step=self.step,
|
|
1952
3805
|
event_name=f"ssd_tbe.prefetch.cache_stats.{stat_index.name.lower()}",
|
|
@@ -1973,35 +3826,35 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
1973
3826
|
bwd_l1_cnflct_miss_write_back_dur = ssd_io_duration[3]
|
|
1974
3827
|
flush_write_dur = ssd_io_duration[4]
|
|
1975
3828
|
|
|
1976
|
-
# pyre-ignore
|
|
3829
|
+
# pyre-ignore [16]
|
|
1977
3830
|
self.stats_reporter.report_duration(
|
|
1978
3831
|
iteration_step=self.step,
|
|
1979
3832
|
event_name="ssd.io_duration.read_us",
|
|
1980
3833
|
duration_ms=ssd_read_dur_us,
|
|
1981
3834
|
time_unit="us",
|
|
1982
3835
|
)
|
|
1983
|
-
|
|
3836
|
+
|
|
1984
3837
|
self.stats_reporter.report_duration(
|
|
1985
3838
|
iteration_step=self.step,
|
|
1986
3839
|
event_name="ssd.io_duration.write.fwd_rocksdb_read_us",
|
|
1987
3840
|
duration_ms=fwd_rocksdb_read_dur,
|
|
1988
3841
|
time_unit="us",
|
|
1989
3842
|
)
|
|
1990
|
-
|
|
3843
|
+
|
|
1991
3844
|
self.stats_reporter.report_duration(
|
|
1992
3845
|
iteration_step=self.step,
|
|
1993
3846
|
event_name="ssd.io_duration.write.fwd_l1_eviction_us",
|
|
1994
3847
|
duration_ms=fwd_l1_eviction_dur,
|
|
1995
3848
|
time_unit="us",
|
|
1996
3849
|
)
|
|
1997
|
-
|
|
3850
|
+
|
|
1998
3851
|
self.stats_reporter.report_duration(
|
|
1999
3852
|
iteration_step=self.step,
|
|
2000
3853
|
event_name="ssd.io_duration.write.bwd_l1_cnflct_miss_write_back_us",
|
|
2001
3854
|
duration_ms=bwd_l1_cnflct_miss_write_back_dur,
|
|
2002
3855
|
time_unit="us",
|
|
2003
3856
|
)
|
|
2004
|
-
|
|
3857
|
+
|
|
2005
3858
|
self.stats_reporter.report_duration(
|
|
2006
3859
|
iteration_step=self.step,
|
|
2007
3860
|
event_name="ssd.io_duration.write.flush_write_us",
|
|
@@ -2023,25 +3876,25 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
2023
3876
|
memtable_usage = mem_usage_list[2]
|
|
2024
3877
|
block_cache_pinned_usage = mem_usage_list[3]
|
|
2025
3878
|
|
|
2026
|
-
# pyre-ignore
|
|
3879
|
+
# pyre-ignore [16]
|
|
2027
3880
|
self.stats_reporter.report_data_amount(
|
|
2028
3881
|
iteration_step=self.step,
|
|
2029
3882
|
event_name="ssd.mem_usage.block_cache",
|
|
2030
3883
|
data_bytes=block_cache_usage,
|
|
2031
3884
|
)
|
|
2032
|
-
|
|
3885
|
+
|
|
2033
3886
|
self.stats_reporter.report_data_amount(
|
|
2034
3887
|
iteration_step=self.step,
|
|
2035
3888
|
event_name="ssd.mem_usage.estimate_table_reader",
|
|
2036
3889
|
data_bytes=estimate_table_reader_usage,
|
|
2037
3890
|
)
|
|
2038
|
-
|
|
3891
|
+
|
|
2039
3892
|
self.stats_reporter.report_data_amount(
|
|
2040
3893
|
iteration_step=self.step,
|
|
2041
3894
|
event_name="ssd.mem_usage.memtable",
|
|
2042
3895
|
data_bytes=memtable_usage,
|
|
2043
3896
|
)
|
|
2044
|
-
|
|
3897
|
+
|
|
2045
3898
|
self.stats_reporter.report_data_amount(
|
|
2046
3899
|
iteration_step=self.step,
|
|
2047
3900
|
event_name="ssd.mem_usage.block_cache_pinned",
|
|
@@ -2175,7 +4028,408 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
2175
4028
|
time_unit="us",
|
|
2176
4029
|
)
|
|
2177
4030
|
|
|
2178
|
-
|
|
4031
|
+
@torch.jit.ignore
|
|
4032
|
+
def _report_eviction_stats(self) -> None:
|
|
4033
|
+
if self.stats_reporter is None:
|
|
4034
|
+
return
|
|
4035
|
+
|
|
4036
|
+
stats_reporter: TBEStatsReporter = self.stats_reporter
|
|
4037
|
+
if not stats_reporter.should_report(self.step):
|
|
4038
|
+
return
|
|
4039
|
+
|
|
4040
|
+
# skip metrics reporting when eviction is disabled
|
|
4041
|
+
if self.kv_zch_params.eviction_policy.eviction_trigger_mode == 0:
|
|
4042
|
+
return
|
|
4043
|
+
|
|
4044
|
+
T = len(set(self.feature_table_map))
|
|
4045
|
+
evicted_counts = torch.zeros(T, dtype=torch.int64)
|
|
4046
|
+
processed_counts = torch.zeros(T, dtype=torch.int64)
|
|
4047
|
+
eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float)
|
|
4048
|
+
full_duration_ms = torch.tensor(0, dtype=torch.int64)
|
|
4049
|
+
exec_duration_ms = torch.tensor(0, dtype=torch.int64)
|
|
4050
|
+
self.ssd_db.get_feature_evict_metric(
|
|
4051
|
+
evicted_counts,
|
|
4052
|
+
processed_counts,
|
|
4053
|
+
eviction_threshold_with_dry_run,
|
|
4054
|
+
full_duration_ms,
|
|
4055
|
+
exec_duration_ms,
|
|
4056
|
+
)
|
|
4057
|
+
|
|
4058
|
+
stats_reporter.report_data_amount(
|
|
4059
|
+
iteration_step=self.step,
|
|
4060
|
+
event_name=self.eviction_sum_evicted_counts_stats_name,
|
|
4061
|
+
data_bytes=int(evicted_counts.sum().item()),
|
|
4062
|
+
enable_tb_metrics=True,
|
|
4063
|
+
)
|
|
4064
|
+
stats_reporter.report_data_amount(
|
|
4065
|
+
iteration_step=self.step,
|
|
4066
|
+
event_name=self.eviction_sum_processed_counts_stats_name,
|
|
4067
|
+
data_bytes=int(processed_counts.sum().item()),
|
|
4068
|
+
enable_tb_metrics=True,
|
|
4069
|
+
)
|
|
4070
|
+
if processed_counts.sum().item() != 0:
|
|
4071
|
+
stats_reporter.report_data_amount(
|
|
4072
|
+
iteration_step=self.step,
|
|
4073
|
+
event_name=self.eviction_evict_rate_stats_name,
|
|
4074
|
+
data_bytes=int(
|
|
4075
|
+
evicted_counts.sum().item() * 100 / processed_counts.sum().item()
|
|
4076
|
+
),
|
|
4077
|
+
enable_tb_metrics=True,
|
|
4078
|
+
)
|
|
4079
|
+
for t in self.feature_table_map:
|
|
4080
|
+
stats_reporter.report_data_amount(
|
|
4081
|
+
iteration_step=self.step,
|
|
4082
|
+
event_name=f"eviction.feature_table.{t}.evicted_counts",
|
|
4083
|
+
data_bytes=int(evicted_counts[t].item()),
|
|
4084
|
+
enable_tb_metrics=True,
|
|
4085
|
+
)
|
|
4086
|
+
stats_reporter.report_data_amount(
|
|
4087
|
+
iteration_step=self.step,
|
|
4088
|
+
event_name=f"eviction.feature_table.{t}.processed_counts",
|
|
4089
|
+
data_bytes=int(processed_counts[t].item()),
|
|
4090
|
+
enable_tb_metrics=True,
|
|
4091
|
+
)
|
|
4092
|
+
if processed_counts[t].item() != 0:
|
|
4093
|
+
stats_reporter.report_data_amount(
|
|
4094
|
+
iteration_step=self.step,
|
|
4095
|
+
event_name=f"eviction.feature_table.{t}.evict_rate",
|
|
4096
|
+
data_bytes=int(
|
|
4097
|
+
evicted_counts[t].item() * 100 / processed_counts[t].item()
|
|
4098
|
+
),
|
|
4099
|
+
enable_tb_metrics=True,
|
|
4100
|
+
)
|
|
4101
|
+
stats_reporter.report_duration(
|
|
4102
|
+
iteration_step=self.step,
|
|
4103
|
+
event_name="eviction.feature_table.full_duration_ms",
|
|
4104
|
+
duration_ms=full_duration_ms.item(),
|
|
4105
|
+
time_unit="ms",
|
|
4106
|
+
enable_tb_metrics=True,
|
|
4107
|
+
)
|
|
4108
|
+
stats_reporter.report_duration(
|
|
4109
|
+
iteration_step=self.step,
|
|
4110
|
+
event_name="eviction.feature_table.exec_duration_ms",
|
|
4111
|
+
duration_ms=exec_duration_ms.item(),
|
|
4112
|
+
time_unit="ms",
|
|
4113
|
+
enable_tb_metrics=True,
|
|
4114
|
+
)
|
|
4115
|
+
if full_duration_ms.item() != 0:
|
|
4116
|
+
stats_reporter.report_data_amount(
|
|
4117
|
+
iteration_step=self.step,
|
|
4118
|
+
event_name="eviction.feature_table.exec_div_full_duration_rate",
|
|
4119
|
+
data_bytes=int(exec_duration_ms.item() * 100 / full_duration_ms.item()),
|
|
4120
|
+
enable_tb_metrics=True,
|
|
4121
|
+
)
|
|
4122
|
+
|
|
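A quick check of the percentage arithmetic behind the evict-rate metrics reported above (all numbers are made up):

    evicted, processed = 25, 1000
    evict_rate_pct = int(evicted * 100 / processed)   # 2, passed as data_bytes to report_data_amount
    exec_ms, full_ms = 40, 200
    exec_div_full_pct = int(exec_ms * 100 / full_ms)  # 20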
4123
|
+
@torch.jit.ignore
|
|
4124
|
+
def _report_dram_kv_perf_stats(self) -> None:
|
|
4125
|
+
"""
|
|
4126
|
+
EmbeddingKVDB will hold stats for DRAM cache performance in fwd/bwd
|
|
4127
|
+
this function fetches the stats from EmbeddingKVDB and reports them with stats_reporter
|
|
4128
|
+
"""
|
|
4129
|
+
if self.stats_reporter is None:
|
|
4130
|
+
return
|
|
4131
|
+
|
|
4132
|
+
stats_reporter: TBEStatsReporter = self.stats_reporter
|
|
4133
|
+
if not stats_reporter.should_report(self.step):
|
|
4134
|
+
return
|
|
4135
|
+
|
|
4136
|
+
dram_kv_perf_stats = self.ssd_db.get_dram_kv_perf(
|
|
4137
|
+
self.step, stats_reporter.report_interval # pyre-ignore
|
|
4138
|
+
)
|
|
4139
|
+
|
|
4140
|
+
if len(dram_kv_perf_stats) != 36:
|
|
4141
|
+
logging.error("dram cache perf stats should have 36 elements")
|
|
4142
|
+
return
|
|
4143
|
+
|
|
4144
|
+
dram_read_duration = dram_kv_perf_stats[0]
|
|
4145
|
+
dram_read_sharding_duration = dram_kv_perf_stats[1]
|
|
4146
|
+
dram_read_cache_hit_copy_duration = dram_kv_perf_stats[2]
|
|
4147
|
+
dram_read_fill_row_storage_duration = dram_kv_perf_stats[3]
|
|
4148
|
+
dram_read_lookup_cache_duration = dram_kv_perf_stats[4]
|
|
4149
|
+
dram_read_acquire_lock_duration = dram_kv_perf_stats[5]
|
|
4150
|
+
dram_read_missing_load = dram_kv_perf_stats[6]
|
|
4151
|
+
dram_write_sharing_duration = dram_kv_perf_stats[7]
|
|
4152
|
+
|
|
4153
|
+
dram_fwd_l1_eviction_write_duration = dram_kv_perf_stats[8]
|
|
4154
|
+
dram_fwd_l1_eviction_write_allocate_duration = dram_kv_perf_stats[9]
|
|
4155
|
+
dram_fwd_l1_eviction_write_cache_copy_duration = dram_kv_perf_stats[10]
|
|
4156
|
+
dram_fwd_l1_eviction_write_lookup_cache_duration = dram_kv_perf_stats[11]
|
|
4157
|
+
dram_fwd_l1_eviction_write_acquire_lock_duration = dram_kv_perf_stats[12]
|
|
4158
|
+
dram_fwd_l1_eviction_write_missing_load = dram_kv_perf_stats[13]
|
|
4159
|
+
|
|
4160
|
+
dram_bwd_l1_cnflct_miss_write_duration = dram_kv_perf_stats[14]
|
|
4161
|
+
dram_bwd_l1_cnflct_miss_write_allocate_duration = dram_kv_perf_stats[15]
|
|
4162
|
+
dram_bwd_l1_cnflct_miss_write_cache_copy_duration = dram_kv_perf_stats[16]
|
|
4163
|
+
dram_bwd_l1_cnflct_miss_write_lookup_cache_duration = dram_kv_perf_stats[17]
|
|
4164
|
+
dram_bwd_l1_cnflct_miss_write_acquire_lock_duration = dram_kv_perf_stats[18]
|
|
4165
|
+
dram_bwd_l1_cnflct_miss_write_missing_load = dram_kv_perf_stats[19]
|
|
4166
|
+
|
|
4167
|
+
dram_kv_allocated_bytes = dram_kv_perf_stats[20]
|
|
4168
|
+
dram_kv_actual_used_chunk_bytes = dram_kv_perf_stats[21]
|
|
4169
|
+
dram_kv_num_rows = dram_kv_perf_stats[22]
|
|
4170
|
+
dram_kv_read_counts = dram_kv_perf_stats[23]
|
|
4171
|
+
dram_metadata_write_sharding_total_duration = dram_kv_perf_stats[24]
|
|
4172
|
+
dram_metadata_write_total_duration = dram_kv_perf_stats[25]
|
|
4173
|
+
dram_metadata_write_allocate_avg_duration = dram_kv_perf_stats[26]
|
|
4174
|
+
dram_metadata_write_lookup_cache_avg_duration = dram_kv_perf_stats[27]
|
|
4175
|
+
dram_metadata_write_acquire_lock_avg_duration = dram_kv_perf_stats[28]
|
|
4176
|
+
dram_metadata_write_cache_miss_avg_count = dram_kv_perf_stats[29]
|
|
4177
|
+
|
|
4178
|
+
dram_read_metadata_total_duration = dram_kv_perf_stats[30]
|
|
4179
|
+
dram_read_metadata_sharding_total_duration = dram_kv_perf_stats[31]
|
|
4180
|
+
dram_read_metadata_cache_hit_copy_avg_duration = dram_kv_perf_stats[32]
|
|
4181
|
+
dram_read_metadata_lookup_cache_total_avg_duration = dram_kv_perf_stats[33]
|
|
4182
|
+
dram_read_metadata_acquire_lock_avg_duration = dram_kv_perf_stats[34]
|
|
4183
|
+
dram_read_read_metadata_load_size = dram_kv_perf_stats[35]
|
|
4184
|
+
|
|
4185
|
+
stats_reporter.report_duration(
|
|
4186
|
+
iteration_step=self.step,
|
|
4187
|
+
event_name="dram_kv.perf.get.dram_read_duration_us",
|
|
4188
|
+
duration_ms=dram_read_duration,
|
|
4189
|
+
enable_tb_metrics=True,
|
|
4190
|
+
time_unit="us",
|
|
4191
|
+
)
|
|
4192
|
+
stats_reporter.report_duration(
|
|
4193
|
+
iteration_step=self.step,
|
|
4194
|
+
event_name="dram_kv.perf.get.dram_read_sharding_duration_us",
|
|
4195
|
+
duration_ms=dram_read_sharding_duration,
|
|
4196
|
+
enable_tb_metrics=True,
|
|
4197
|
+
time_unit="us",
|
|
4198
|
+
)
|
|
4199
|
+
stats_reporter.report_duration(
|
|
4200
|
+
iteration_step=self.step,
|
|
4201
|
+
event_name="dram_kv.perf.get.dram_read_cache_hit_copy_duration_us",
|
|
4202
|
+
duration_ms=dram_read_cache_hit_copy_duration,
|
|
4203
|
+
enable_tb_metrics=True,
|
|
4204
|
+
time_unit="us",
|
|
4205
|
+
)
|
|
4206
|
+
stats_reporter.report_duration(
|
|
4207
|
+
iteration_step=self.step,
|
|
4208
|
+
event_name="dram_kv.perf.get.dram_read_fill_row_storage_duration_us",
|
|
4209
|
+
duration_ms=dram_read_fill_row_storage_duration,
|
|
4210
|
+
enable_tb_metrics=True,
|
|
4211
|
+
time_unit="us",
|
|
4212
|
+
)
|
|
4213
|
+
stats_reporter.report_duration(
|
|
4214
|
+
iteration_step=self.step,
|
|
4215
|
+
event_name="dram_kv.perf.get.dram_read_lookup_cache_duration_us",
|
|
4216
|
+
duration_ms=dram_read_lookup_cache_duration,
|
|
4217
|
+
enable_tb_metrics=True,
|
|
4218
|
+
time_unit="us",
|
|
4219
|
+
)
|
|
4220
|
+
stats_reporter.report_duration(
|
|
4221
|
+
iteration_step=self.step,
|
|
4222
|
+
event_name="dram_kv.perf.get.dram_read_acquire_lock_duration_us",
|
|
4223
|
+
duration_ms=dram_read_acquire_lock_duration,
|
|
4224
|
+
enable_tb_metrics=True,
|
|
4225
|
+
time_unit="us",
|
|
4226
|
+
)
|
|
4227
|
+
stats_reporter.report_data_amount(
|
|
4228
|
+
iteration_step=self.step,
|
|
4229
|
+
event_name="dram_kv.perf.get.dram_read_missing_load",
|
|
4230
|
+
enable_tb_metrics=True,
|
|
4231
|
+
data_bytes=dram_read_missing_load,
|
|
4232
|
+
)
|
|
4233
|
+
stats_reporter.report_duration(
|
|
4234
|
+
iteration_step=self.step,
|
|
4235
|
+
event_name="dram_kv.perf.set.dram_write_sharing_duration_us",
|
|
4236
|
+
duration_ms=dram_write_sharing_duration,
|
|
4237
|
+
enable_tb_metrics=True,
|
|
4238
|
+
time_unit="us",
|
|
4239
|
+
)
|
|
4240
|
+
|
|
4241
|
+
stats_reporter.report_duration(
|
|
4242
|
+
iteration_step=self.step,
|
|
4243
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_duration_us",
|
|
4244
|
+
duration_ms=dram_fwd_l1_eviction_write_duration,
|
|
4245
|
+
enable_tb_metrics=True,
|
|
4246
|
+
time_unit="us",
|
|
4247
|
+
)
|
|
4248
|
+
stats_reporter.report_duration(
|
|
4249
|
+
iteration_step=self.step,
|
|
4250
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_allocate_duration_us",
|
|
4251
|
+
duration_ms=dram_fwd_l1_eviction_write_allocate_duration,
|
|
4252
|
+
enable_tb_metrics=True,
|
|
4253
|
+
time_unit="us",
|
|
4254
|
+
)
|
|
4255
|
+
stats_reporter.report_duration(
|
|
4256
|
+
iteration_step=self.step,
|
|
4257
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_cache_copy_duration_us",
|
|
4258
|
+
duration_ms=dram_fwd_l1_eviction_write_cache_copy_duration,
|
|
4259
|
+
enable_tb_metrics=True,
|
|
4260
|
+
time_unit="us",
|
|
4261
|
+
)
|
|
4262
|
+
stats_reporter.report_duration(
|
|
4263
|
+
iteration_step=self.step,
|
|
4264
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_lookup_cache_duration_us",
|
|
4265
|
+
duration_ms=dram_fwd_l1_eviction_write_lookup_cache_duration,
|
|
4266
|
+
enable_tb_metrics=True,
|
|
4267
|
+
time_unit="us",
|
|
4268
|
+
)
|
|
4269
|
+
stats_reporter.report_duration(
|
|
4270
|
+
iteration_step=self.step,
|
|
4271
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_acquire_lock_duration_us",
|
|
4272
|
+
duration_ms=dram_fwd_l1_eviction_write_acquire_lock_duration,
|
|
4273
|
+
enable_tb_metrics=True,
|
|
4274
|
+
time_unit="us",
|
|
4275
|
+
)
|
|
4276
|
+
stats_reporter.report_data_amount(
|
|
4277
|
+
iteration_step=self.step,
|
|
4278
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load",
|
|
4279
|
+
data_bytes=dram_fwd_l1_eviction_write_missing_load,
|
|
4280
|
+
enable_tb_metrics=True,
|
|
4281
|
+
)
|
|
4282
|
+
|
|
4283
|
+
stats_reporter.report_duration(
|
|
4284
|
+
iteration_step=self.step,
|
|
4285
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_duration_us",
|
|
4286
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_duration,
|
|
4287
|
+
enable_tb_metrics=True,
|
|
4288
|
+
time_unit="us",
|
|
4289
|
+
)
|
|
4290
|
+
stats_reporter.report_duration(
|
|
4291
|
+
iteration_step=self.step,
|
|
4292
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_allocate_duration_us",
|
|
4293
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_allocate_duration,
|
|
4294
|
+
enable_tb_metrics=True,
|
|
4295
|
+
time_unit="us",
|
|
4296
|
+
)
|
|
4297
|
+
stats_reporter.report_duration(
|
|
4298
|
+
iteration_step=self.step,
|
|
4299
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_cache_copy_duration_us",
|
|
4300
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_cache_copy_duration,
|
|
4301
|
+
enable_tb_metrics=True,
|
|
4302
|
+
time_unit="us",
|
|
4303
|
+
)
|
|
4304
|
+
stats_reporter.report_duration(
|
|
4305
|
+
iteration_step=self.step,
|
|
4306
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_lookup_cache_duration_us",
|
|
4307
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_lookup_cache_duration,
|
|
4308
|
+
enable_tb_metrics=True,
|
|
4309
|
+
time_unit="us",
|
|
4310
|
+
)
|
|
4311
|
+
stats_reporter.report_duration(
|
|
4312
|
+
iteration_step=self.step,
|
|
4313
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_acquire_lock_duration_us",
|
|
4314
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_acquire_lock_duration,
|
|
4315
|
+
enable_tb_metrics=True,
|
|
4316
|
+
time_unit="us",
|
|
4317
|
+
)
|
|
4318
|
+
stats_reporter.report_data_amount(
|
|
4319
|
+
iteration_step=self.step,
|
|
4320
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load",
|
|
4321
|
+
data_bytes=dram_bwd_l1_cnflct_miss_write_missing_load,
|
|
4322
|
+
enable_tb_metrics=True,
|
|
4323
|
+
)
|
|
4324
|
+
|
|
4325
|
+
stats_reporter.report_data_amount(
|
|
4326
|
+
iteration_step=self.step,
|
|
4327
|
+
event_name="dram_kv.perf.get.dram_kv_read_counts",
|
|
4328
|
+
data_bytes=dram_kv_read_counts,
|
|
4329
|
+
enable_tb_metrics=True,
|
|
4330
|
+
)
|
|
4331
|
+
|
|
4332
|
+
stats_reporter.report_data_amount(
|
|
4333
|
+
iteration_step=self.step,
|
|
4334
|
+
event_name=self.dram_kv_allocated_bytes_stats_name,
|
|
4335
|
+
data_bytes=dram_kv_allocated_bytes,
|
|
4336
|
+
enable_tb_metrics=True,
|
|
4337
|
+
)
|
|
4338
|
+
stats_reporter.report_data_amount(
|
|
4339
|
+
iteration_step=self.step,
|
|
4340
|
+
event_name=self.dram_kv_actual_used_chunk_bytes_stats_name,
|
|
4341
|
+
data_bytes=dram_kv_actual_used_chunk_bytes,
|
|
4342
|
+
enable_tb_metrics=True,
|
|
4343
|
+
)
|
|
4344
|
+
stats_reporter.report_data_amount(
|
|
4345
|
+
iteration_step=self.step,
|
|
4346
|
+
event_name=self.dram_kv_mem_num_rows_stats_name,
|
|
4347
|
+
data_bytes=dram_kv_num_rows,
|
|
4348
|
+
enable_tb_metrics=True,
|
|
4349
|
+
)
|
|
4350
|
+
stats_reporter.report_duration(
|
|
4351
|
+
iteration_step=self.step,
|
|
4352
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_sharding_total_duration_us",
|
|
4353
|
+
duration_ms=dram_metadata_write_sharding_total_duration,
|
|
4354
|
+
enable_tb_metrics=True,
|
|
4355
|
+
time_unit="us",
|
|
4356
|
+
)
|
|
4357
|
+
stats_reporter.report_duration(
|
|
4358
|
+
iteration_step=self.step,
|
|
4359
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_total_duration_us",
|
|
4360
|
+
duration_ms=dram_metadata_write_total_duration,
|
|
4361
|
+
enable_tb_metrics=True,
|
|
4362
|
+
time_unit="us",
|
|
4363
|
+
)
|
|
4364
|
+
stats_reporter.report_duration(
|
|
4365
|
+
iteration_step=self.step,
|
|
4366
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_allocate_avg_duration_us",
|
|
4367
|
+
duration_ms=dram_metadata_write_allocate_avg_duration,
|
|
4368
|
+
enable_tb_metrics=True,
|
|
4369
|
+
time_unit="us",
|
|
4370
|
+
)
|
|
4371
|
+
stats_reporter.report_duration(
|
|
4372
|
+
iteration_step=self.step,
|
|
4373
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_lookup_cache_avg_duration_us",
|
|
4374
|
+
duration_ms=dram_metadata_write_lookup_cache_avg_duration,
|
|
4375
|
+
enable_tb_metrics=True,
|
|
4376
|
+
time_unit="us",
|
|
4377
|
+
)
|
|
4378
|
+
stats_reporter.report_duration(
|
|
4379
|
+
iteration_step=self.step,
|
|
4380
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_acquire_lock_avg_duration_us",
|
|
4381
|
+
duration_ms=dram_metadata_write_acquire_lock_avg_duration,
|
|
4382
|
+
enable_tb_metrics=True,
|
|
4383
|
+
time_unit="us",
|
|
4384
|
+
)
|
|
4385
|
+
stats_reporter.report_data_amount(
|
|
4386
|
+
iteration_step=self.step,
|
|
4387
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_cache_miss_avg_count",
|
|
4388
|
+
data_bytes=dram_metadata_write_cache_miss_avg_count,
|
|
4389
|
+
enable_tb_metrics=True,
|
|
4390
|
+
)
|
|
4391
|
+
stats_reporter.report_duration(
|
|
4392
|
+
iteration_step=self.step,
|
|
4393
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_total_duration_us",
|
|
4394
|
+
duration_ms=dram_read_metadata_total_duration,
|
|
4395
|
+
enable_tb_metrics=True,
|
|
4396
|
+
time_unit="us",
|
|
4397
|
+
)
|
|
4398
|
+
stats_reporter.report_duration(
|
|
4399
|
+
iteration_step=self.step,
|
|
4400
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_sharding_total_duration_us",
|
|
4401
|
+
duration_ms=dram_read_metadata_sharding_total_duration,
|
|
4402
|
+
enable_tb_metrics=True,
|
|
4403
|
+
time_unit="us",
|
|
4404
|
+
)
|
|
4405
|
+
stats_reporter.report_duration(
|
|
4406
|
+
iteration_step=self.step,
|
|
4407
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_cache_hit_copy_avg_duration_us",
|
|
4408
|
+
duration_ms=dram_read_metadata_cache_hit_copy_avg_duration,
|
|
4409
|
+
enable_tb_metrics=True,
|
|
4410
|
+
time_unit="us",
|
|
4411
|
+
)
|
|
4412
|
+
stats_reporter.report_duration(
|
|
4413
|
+
iteration_step=self.step,
|
|
4414
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_lookup_cache_total_avg_duration_us",
|
|
4415
|
+
duration_ms=dram_read_metadata_lookup_cache_total_avg_duration,
|
|
4416
|
+
enable_tb_metrics=True,
|
|
4417
|
+
time_unit="us",
|
|
4418
|
+
)
|
|
4419
|
+
stats_reporter.report_duration(
|
|
4420
|
+
iteration_step=self.step,
|
|
4421
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_acquire_lock_avg_duration_us",
|
|
4422
|
+
duration_ms=dram_read_metadata_acquire_lock_avg_duration,
|
|
4423
|
+
enable_tb_metrics=True,
|
|
4424
|
+
time_unit="us",
|
|
4425
|
+
)
|
|
4426
|
+
stats_reporter.report_data_amount(
|
|
4427
|
+
iteration_step=self.step,
|
|
4428
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_load_size",
|
|
4429
|
+
data_bytes=dram_read_read_metadata_load_size,
|
|
4430
|
+
enable_tb_metrics=True,
|
|
4431
|
+
)
|
|
4432
|
+
|
|
2179
4433
|
def _recording_to_timer(
|
|
2180
4434
|
self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
|
|
2181
4435
|
) -> Any:
|
|
@@ -2191,3 +4445,484 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
2191
4445
|
return timer.recording(**kwargs)
|
|
2192
4446
|
# No-Op context manager
|
|
2193
4447
|
return contextlib.nullcontext()
|
|
4448
|
+
|
|
4449
|
+
def fetch_from_l1_sp_w_row_ids(
|
|
4450
|
+
self, row_ids: torch.Tensor, only_get_optimizer_states: bool = False
|
|
4451
|
+
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
|
4452
|
+
"""
|
|
4453
|
+
Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
|
|
4454
|
+
@return: updated_weights/optimizer_states, mask of which rows are filled
|
|
4455
|
+
"""
|
|
4456
|
+
if not self.enable_optimizer_offloading and only_get_optimizer_states:
|
|
4457
|
+
raise RuntimeError(
|
|
4458
|
+
"Optimizer states are not offloaded, while only_get_optimizer_states is True"
|
|
4459
|
+
)
|
|
4460
|
+
|
|
4461
|
+
# NOTE: Remove this once there is support for fetching multiple
|
|
4462
|
+
# optimizer states in fetch_from_l1_sp_w_row_ids
|
|
4463
|
+
if only_get_optimizer_states and self.optimizer not in [
|
|
4464
|
+
OptimType.EXACT_ROWWISE_ADAGRAD,
|
|
4465
|
+
OptimType.PARTIAL_ROWWISE_ADAM,
|
|
4466
|
+
]:
|
|
4467
|
+
raise RuntimeError(
|
|
4468
|
+
f"Fetching optimizer states using fetch_from_l1_sp_w_row_ids() is not yet supported for {self.optimizer}"
|
|
4469
|
+
)
|
|
4470
|
+
|
|
4471
|
+
def split_results_by_opt_states(
|
|
4472
|
+
updated_weights: torch.Tensor, cache_location_mask: torch.Tensor
|
|
4473
|
+
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
|
4474
|
+
if not only_get_optimizer_states:
|
|
4475
|
+
return [updated_weights], cache_location_mask
|
|
4476
|
+
# TODO: support mixed dimension case
|
|
4477
|
+
# currently only supports tables with the same max_D dimension
|
|
4478
|
+
opt_to_dim = self.optimizer.byte_offsets_along_row(
|
|
4479
|
+
self.max_D, self.weights_precision, self.optimizer_state_dtypes
|
|
4480
|
+
)
|
|
4481
|
+
updated_opt_states = []
|
|
4482
|
+
for opt_name, dim in opt_to_dim.items():
|
|
4483
|
+
opt_dtype = self.optimizer._extract_dtype(
|
|
4484
|
+
self.optimizer_state_dtypes, opt_name
|
|
4485
|
+
)
|
|
4486
|
+
updated_opt_states.append(
|
|
4487
|
+
updated_weights.view(dtype=torch.uint8)[:, dim[0] : dim[1]].view(
|
|
4488
|
+
dtype=opt_dtype
|
|
4489
|
+
)
|
|
4490
|
+
)
|
|
4491
|
+
return updated_opt_states, cache_location_mask
|
|
4492
|
+
|
|
4493
|
+
with torch.no_grad():
|
|
4494
|
+
weights_dtype = self.weights_precision.as_dtype()
|
|
4495
|
+
step = self.step
|
|
4496
|
+
with record_function(f"## fetch_from_l1_{step}_{self.tbe_unique_id} ##"):
|
|
4497
|
+
lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
|
|
4498
|
+
row_ids,
|
|
4499
|
+
self.lxu_cache_state,
|
|
4500
|
+
self.total_hash_size,
|
|
4501
|
+
)
|
|
4502
|
+
updated_weights = torch.empty(
|
|
4503
|
+
row_ids.numel(),
|
|
4504
|
+
self.cache_row_dim,
|
|
4505
|
+
device=self.current_device,
|
|
4506
|
+
dtype=weights_dtype,
|
|
4507
|
+
)
|
|
4508
|
+
|
|
4509
|
+
# D2D copy cache
|
|
4510
|
+
cache_location_mask = lxu_cache_locations >= 0
|
|
4511
|
+
torch.ops.fbgemm.masked_index_select(
|
|
4512
|
+
updated_weights,
|
|
4513
|
+
lxu_cache_locations,
|
|
4514
|
+
self.lxu_cache_weights,
|
|
4515
|
+
torch.tensor(
|
|
4516
|
+
[row_ids.numel()],
|
|
4517
|
+
device=self.current_device,
|
|
4518
|
+
dtype=torch.int32,
|
|
4519
|
+
),
|
|
4520
|
+
)
|
|
4521
|
+
|
|
4522
|
+
with record_function(f"## fetch_from_sp_{step}_{self.tbe_unique_id} ##"):
|
|
4523
|
+
if len(self.ssd_scratch_pad_eviction_data) > 0:
|
|
4524
|
+
sp = self.ssd_scratch_pad_eviction_data[0][0]
|
|
4525
|
+
sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
|
|
4526
|
+
self.current_device
|
|
4527
|
+
)
|
|
4528
|
+
actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
|
|
4529
|
+
if actions_count_gpu.item() == 0:
|
|
4530
|
+
# no action to take
|
|
4531
|
+
return split_results_by_opt_states(
|
|
4532
|
+
updated_weights, cache_location_mask
|
|
4533
|
+
)
|
|
4534
|
+
|
|
4535
|
+
sp_idx = sp_idx[:actions_count_gpu]
|
|
4536
|
+
|
|
4537
|
+
# -1 in lxu_cache_locations means the row is not in the L1 cache and is in SP
|
|
4538
|
+
# fill the row_ids that are in L1 with -2; >0 values mean the row is in SP
|
|
4539
|
+
# @eg. updated_row_ids_in_sp= [1, 100, 1, 2, -2, 3, 4, 5, 10]
|
|
4540
|
+
updated_row_ids_in_sp = row_ids.masked_fill(
|
|
4541
|
+
lxu_cache_locations != -1, -2
|
|
4542
|
+
)
|
|
4543
|
+
# sort the sp_idx for binary search
|
|
4544
|
+
# should already be sorted
|
|
4545
|
+
# sp_idx_inverse_indices holds the pre-sort indices, which are the same as the locations in SP.
|
|
4546
|
+
# @eg. sp_idx = [4, 2, 1, 3, 10]
|
|
4547
|
+
# @eg sorted_sp_idx = [ 1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
|
|
4548
|
+
sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
|
|
4549
|
+
# search the row ids in SP against the SP indexes to find the location of the rows in SP
|
|
4550
|
+
# @eg: updated_ids_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
|
|
4551
|
+
# @eg: 5 is OOB
|
|
4552
|
+
updated_ids_in_sp_idx = torch.searchsorted(
|
|
4553
|
+
sorted_sp_idx, updated_row_ids_in_sp
|
|
4554
|
+
)
|
|
4555
|
+
# ids not found in SP will be out of bounds
|
|
4556
|
+
oob_sp_idx = updated_ids_in_sp_idx >= sp_idx.numel()
|
|
4557
|
+
# bring the out-of-bound items back in bounds
|
|
4558
|
+
# @eg updated_ids_in_sp_idx=[0, 0, 0, 1, 0, 2, 3, 4, 4]
|
|
4559
|
+
updated_ids_in_sp_idx[oob_sp_idx] = 0
|
|
4560
|
+
|
|
4561
|
+
# -1s locations will be filtered out in masked_index_select
|
|
4562
|
+
sp_locations_in_updated_weights = torch.full_like(
|
|
4563
|
+
updated_row_ids_in_sp, -1
|
|
4564
|
+
)
|
|
4565
|
+
# torch.searchsorted is not exact match,
|
|
4566
|
+
# we only take exact matched rows, where the id is found in SP.
|
|
4567
|
+
# @eg 5 in updated_row_ids_in_sp is not in sp_idx, but has 4 in updated_ids_in_sp_idx
|
|
4568
|
+
# @eg sorted_sp_idx[updated_ids_in_sp_idx]=[ 1, 1, 1, 2, 1, 3, 4, 10, 10]
|
|
4569
|
+
# @eg exact_match_mask=[ True, False, True, True, False, True, True, False, True]
|
|
4570
|
+
exact_match_mask = (
|
|
4571
|
+
sorted_sp_idx[updated_ids_in_sp_idx] == updated_row_ids_in_sp
|
|
4572
|
+
)
|
|
4573
|
+
# Get the location of the row ids found in SP.
|
|
4574
|
+
# @eg: sp_locations_found=[2, 2, 1, 3, 0, 4]
|
|
4575
|
+
sp_locations_found = sp_idx_inverse_indices[
|
|
4576
|
+
updated_ids_in_sp_idx[exact_match_mask]
|
|
4577
|
+
]
|
|
4578
|
+
# @eg: sp_locations_in_updated_weights=[ 2, -1, 2, 1, -1, 3, 0, -1, 4]
|
|
4579
|
+
sp_locations_in_updated_weights[exact_match_mask] = (
|
|
4580
|
+
sp_locations_found
|
|
4581
|
+
)
|
|
4582
|
+
|
|
4583
|
+
# D2D copy SP
|
|
4584
|
+
torch.ops.fbgemm.masked_index_select(
|
|
4585
|
+
updated_weights,
|
|
4586
|
+
sp_locations_in_updated_weights,
|
|
4587
|
+
sp,
|
|
4588
|
+
torch.tensor(
|
|
4589
|
+
[row_ids.numel()],
|
|
4590
|
+
device=self.current_device,
|
|
4591
|
+
dtype=torch.int32,
|
|
4592
|
+
),
|
|
4593
|
+
)
|
|
4594
|
+
# cache_location_mask is the mask of rows in L1
|
|
4595
|
+
# exact_match_mask is the mask of rows in SP
|
|
4596
|
+
cache_location_mask = torch.logical_or(
|
|
4597
|
+
cache_location_mask, exact_match_mask
|
|
4598
|
+
)
|
|
4599
|
+
|
|
4600
|
+
return split_results_by_opt_states(updated_weights, cache_location_mask)
|
|
4601
|
+
|
|
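A standalone sketch of the searchsorted exact-match trick used above, reusing the toy numbers from the comments:

    import torch

    sp_idx = torch.tensor([4, 2, 1, 3, 10])
    row_ids_in_sp = torch.tensor([1, 100, 1, 2, -2, 3, 4, 5, 10])   # -2 marks rows already in L1
    sorted_sp_idx, inverse = torch.sort(sp_idx)                     # [1,2,3,4,10], [2,1,3,0,4]
    pos = torch.searchsorted(sorted_sp_idx, row_ids_in_sp)          # [0,5,0,1,0,2,3,4,4]
    pos[pos >= sp_idx.numel()] = 0                                  # clamp out-of-bound positions
    exact = sorted_sp_idx[pos] == row_ids_in_sp                     # True only where the id exists in SP
    locations = torch.full_like(row_ids_in_sp, -1)
    locations[exact] = inverse[pos[exact]]                          # [2,-1,2,1,-1,3,0,-1,4]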
4602
|
+
def register_backward_hook_before_eviction(
|
|
4603
|
+
self, backward_hook: Callable[[torch.Tensor], None]
|
|
4604
|
+
) -> None:
|
|
4605
|
+
"""
|
|
4606
|
+
Register a backward hook to the TBE module.
|
|
4607
|
+
And make sure this is called before the sp eviction hook.
|
|
4608
|
+
"""
|
|
4609
|
+
# make sure this hook is the first one to be executed
|
|
4610
|
+
hooks = []
|
|
4611
|
+
backward_hooks = self.placeholder_autograd_tensor._backward_hooks
|
|
4612
|
+
if backward_hooks is not None:
|
|
4613
|
+
for _handle_id, hook in backward_hooks.items():
|
|
4614
|
+
hooks.append(hook)
|
|
4615
|
+
backward_hooks.clear()
|
|
4616
|
+
|
|
4617
|
+
self.placeholder_autograd_tensor.register_hook(backward_hook)
|
|
4618
|
+
for hook in hooks:
|
|
4619
|
+
self.placeholder_autograd_tensor.register_hook(hook)
|
|
4620
|
+
|
|
4621
|
+
def set_local_weight_counts_for_table(
|
|
4622
|
+
self, table_idx: int, weight_count: int
|
|
4623
|
+
) -> None:
|
|
4624
|
+
self.local_weight_counts[table_idx] = weight_count
|
|
4625
|
+
|
|
4626
|
+
def set_global_id_per_rank_for_table(
|
|
4627
|
+
self, table_idx: int, global_id: torch.Tensor
|
|
4628
|
+
) -> None:
|
|
4629
|
+
self.global_id_per_rank[table_idx] = global_id
|
|
4630
|
+
|
|
4631
|
+
def direct_write_embedding(
|
|
4632
|
+
self,
|
|
4633
|
+
indices: torch.Tensor,
|
|
4634
|
+
offsets: torch.Tensor,
|
|
4635
|
+
weights: torch.Tensor,
|
|
4636
|
+
) -> None:
|
|
4637
|
+
"""
|
|
4638
|
+
Directly write the weights to L1, SP and the backend without relying on autograd for the embedding cache.
|
|
4639
|
+
Please refer to design doc for more details: https://docs.google.com/document/d/1TJHKvO1m3-5tYAKZGhacXnGk7iCNAzz7wQlrFbX_LDI/edit?tab=t.0
|
|
4640
|
+
"""
|
|
4641
|
+
assert (
|
|
4642
|
+
self._embedding_cache_mode
|
|
4643
|
+
), "Must be in embedding_cache_mode to support direct_write_embedding method."
|
|
4644
|
+
|
|
4645
|
+
B_offsets = None
|
|
4646
|
+
max_B = -1
|
|
4647
|
+
|
|
4648
|
+
with torch.no_grad():
|
|
4649
|
+
# Wait for any ongoing prefetch operations to complete before starting direct_write
|
|
4650
|
+
current_stream = torch.cuda.current_stream()
|
|
4651
|
+
current_stream.wait_event(self.prefetch_complete_event)
|
|
4652
|
+
|
|
4653
|
+
# Create local step events for internal sequential execution
|
|
4654
|
+
weights_dtype = self.weights_precision.as_dtype()
|
|
4655
|
+
assert (
|
|
4656
|
+
weights_dtype == weights.dtype
|
|
4657
|
+
), f"Expected embedding table dtype {weights_dtype} is same with input weight dtype, but got {weights.dtype}"
|
|
4658
|
+
|
|
4659
|
+
# Pad the weights to match self.max_D width if necessary
|
|
4660
|
+
if weights.size(1) < self.cache_row_dim:
|
|
4661
|
+
weights = torch.nn.functional.pad(
|
|
4662
|
+
weights, (0, self.cache_row_dim - weights.size(1))
|
|
4663
|
+
)
|
|
4664
|
+
|
|
4665
|
+
step = self.step
|
|
4666
|
+
|
|
4667
|
+
# step 0: run backward hook for prefetch if prefetch pipeline is enabled before writing to L1 and SP
|
|
4668
|
+
if self.prefetch_pipeline:
|
|
4669
|
+
self._update_cache_counter_and_pointers(nn.Module(), torch.empty(0))
|
|
4670
|
+
|
|
4671
|
+
# step 1: lookup and write to l1 cache
|
|
4672
|
+
with record_function(
|
|
4673
|
+
f"## direct_write_to_l1_{step}_{self.tbe_unique_id} ##"
|
|
4674
|
+
):
|
|
4675
|
+
if self.gather_ssd_cache_stats:
|
|
4676
|
+
self.local_ssd_cache_stats.zero_()
|
|
4677
|
+
|
|
4678
|
+
# Linearize indices
|
|
4679
|
+
linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
|
|
4680
|
+
self.hash_size_cumsum,
|
|
4681
|
+
indices,
|
|
4682
|
+
offsets,
|
|
4683
|
+
B_offsets,
|
|
4684
|
+
max_B,
|
|
4685
|
+
)
|
|
4686
|
+
|
|
4687
|
+
lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
|
|
4688
|
+
linear_cache_indices,
|
|
4689
|
+
self.lxu_cache_state,
|
|
4690
|
+
self.total_hash_size,
|
|
4691
|
+
)
|
|
4692
|
+
cache_location_mask = lxu_cache_locations >= 0
|
|
4693
|
+
|
|
4694
|
+
# Get the cache locations for the row_ids that are already in the cache
|
|
4695
|
+
cache_locations = lxu_cache_locations[cache_location_mask]
|
|
4696
|
+
|
|
4697
|
+
# Get the corresponding input weights for these row_ids
|
|
4698
|
+
cache_weights = weights[cache_location_mask]
|
|
4699
|
+
|
|
4700
|
+
# Update the cache with these input weights
|
|
4701
|
+
if cache_locations.numel() > 0:
|
|
4702
|
+
self.lxu_cache_weights.index_put_(
|
|
4703
|
+
(cache_locations,), cache_weights, accumulate=False
|
|
4704
|
+
)
|
|
4705
|
+
|
|
4706
|
+
# Record completion of step 1
|
|
4707
|
+
current_stream.record_event(self.direct_write_l1_complete_event)
|
|
4708
|
+
|
|
4709
|
+
# step 2: pop the current scratch pad and write to next batch scratch pad if exists
|
|
4710
|
+
# Wait for step 1 to complete
|
|
4711
|
+
with record_function(
|
|
4712
|
+
f"## direct_write_to_sp_{step}_{self.tbe_unique_id} ##"
|
|
4713
|
+
):
|
|
4714
|
+
if len(self.ssd_scratch_pad_eviction_data) > 0:
|
|
4715
|
+
self.ssd_scratch_pad_eviction_data.pop(0)
|
|
4716
|
+
if len(self.ssd_scratch_pad_eviction_data) > 0:
|
|
4717
|
+
# Wait for any pending backend reads to the next scratch pad
|
|
4718
|
+
# to complete before we write to it. Otherwise, stale backend data
|
|
4719
|
+
# will overwrite our direct_write updates.
|
|
4720
|
+
# The ssd_event_get marks completion of backend fetch operations.
|
|
4721
|
+
current_stream.wait_event(self.ssd_event_get)
|
|
4722
|
+
|
|
4723
|
+
# if scratch pad exists, write to next batch scratch pad
|
|
4724
|
+
sp = self.ssd_scratch_pad_eviction_data[0][0]
|
|
4725
|
+
sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
|
|
4726
|
+
self.current_device
|
|
4727
|
+
)
|
|
4728
|
+
actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
|
|
4729
|
+
if actions_count_gpu.item() != 0:
|
|
4730
|
+
# when actions_count_gpu is zero, there is no need to write to SP
|
|
4731
|
+
sp_idx = sp_idx[:actions_count_gpu]
|
|
4732
|
+
|
|
4733
|
+
# -1 in lxu_cache_locations means the row is not in the L1 cache and is in SP
|
|
4734
|
+
# fill the row_ids that are in L1 with -2; >0 values mean the row is in SP or the backend
|
|
4735
|
+
# @eg. updated_indices_in_sp= [1, 100, 1, 2, -2, 3, 4, 5, 10]
|
|
4736
|
+
updated_indices_in_sp = linear_cache_indices.masked_fill(
|
|
4737
|
+
lxu_cache_locations != -1, -2
|
|
4738
|
+
)
|
|
4739
|
+
# sort the sp_idx for binary search
|
|
4740
|
+
# should already be sorted
|
|
4741
|
+
# sp_idx_inverse_indices holds the pre-sort indices, which are the same as the locations in SP.
|
|
4742
|
+
# @eg. sp_idx = [4, 2, 1, 3, 10]
|
|
4743
|
+
# @eg sorted_sp_idx = [ 1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
|
|
4744
|
+
sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
|
|
4745
|
+
# search the row ids in SP against the SP indexes to find the location of the rows in SP
|
|
4746
|
+
# @eg: updated_indices_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
|
|
4747
|
+
# @eg: 5 is OOB
|
|
4748
|
+
updated_indices_in_sp_idx = torch.searchsorted(
|
|
4749
|
+
sorted_sp_idx, updated_indices_in_sp
|
|
4750
|
+
)
|
|
4751
|
+
# does not found in SP will Out of Bound
|
|
4752
|
+
oob_sp_idx = updated_indices_in_sp_idx >= sp_idx.numel()
|
|
4753
|
+
# make the oob items in bound
|
|
4754
|
+
# @eg updated_indices_in_sp=[0, 0, 0, 1, 0, 2, 3, 4, 4]
|
|
4755
|
+
updated_indices_in_sp_idx[oob_sp_idx] = 0
|
|
4756
|
+
|
|
4757
|
+
# torch.searchsorted is not exact match,
|
|
4758
|
+
# we only take exact matched rows, where the id is found in SP.
|
|
4759
|
+
# @eg 5 in updated_indices_in_sp is not in sp_idx, but has 4 in updated_indices_in_sp
|
|
4760
|
+
# @eg sorted_sp_idx[updated_indices_in_sp]=[ 1, 1, 1, 2, 1, 3, 4, 10, 10]
|
|
4761
|
+
# @eg exact_match_mask=[ True, False, True, True, False, True, True, False, True]
|
|
4762
|
+
exact_match_mask = (
|
|
4763
|
+
sorted_sp_idx[updated_indices_in_sp_idx]
|
|
4764
|
+
== updated_indices_in_sp
|
|
4765
|
+
)
|
|
4766
|
+
# Get the location of the row ids found in SP.
|
|
4767
|
+
# @eg: sp_locations_found=[2, 2, 1, 3, 0, 4]
|
|
4768
|
+
sp_locations_found = sp_idx_inverse_indices[
|
|
4769
|
+
updated_indices_in_sp[exact_match_mask]
|
|
4770
|
+
]
|
|
4771
|
+
# Get the corresponding weights for the matched indices
|
|
4772
|
+
matched_weights = weights[exact_match_mask]
|
|
4773
|
+
|
|
4774
|
+
# Write the weights to the sparse tensor at the found locations
|
|
4775
|
+
if sp_locations_found.numel() > 0:
|
|
4776
|
+
sp.index_put_(
|
|
4777
|
+
(sp_locations_found,),
|
|
4778
|
+
matched_weights,
|
|
4779
|
+
accumulate=False,
|
|
4780
|
+
)
|
|
4781
|
+
current_stream.record_event(self.direct_write_sp_complete_event)
|
|
4782
|
+
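The scratch-pad write above translates row ids into scratch-pad slots with a sort + searchsorted + exact-match-mask pattern. A minimal reproduction using the values from the `@eg` comments; `row_ids` stands in for `updated_indices_in_sp` and is illustrative only.

```python
import torch

# Values taken from the @eg comments in the scratch-pad write above.
sp_idx = torch.tensor([4, 2, 1, 3, 10])
row_ids = torch.tensor([1, 100, 1, 2, -2, 3, 4, 5, 10])  # -2 = row already in L1

sorted_sp_idx, inverse = torch.sort(sp_idx)          # [1, 2, 3, 4, 10], [2, 1, 3, 0, 4]
pos = torch.searchsorted(sorted_sp_idx, row_ids)     # [0, 5, 0, 1, 0, 2, 3, 4, 4]
pos[pos >= sp_idx.numel()] = 0                       # clamp out-of-bounds positions
exact = sorted_sp_idx[pos] == row_ids                # [T, F, T, T, F, T, T, F, T]
sp_locations = inverse[pos[exact]]                   # scratch-pad slots of matched rows

assert sp_locations.tolist() == [2, 2, 1, 3, 0, 4]
```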
+        # step 3: write l1 cache missing rows to backend
+        # Wait for step 2 to complete
+        with record_function(
+            f"## direct_write_to_backend_{step}_{self.tbe_unique_id} ##"
+        ):
+            # Use the existing ssd_eviction_stream for all backend write operations
+            # This stream is already created with low priority during initialization
+            with torch.cuda.stream(self.ssd_eviction_stream):
+                # Create a mask for indices not in L1 cache
+                non_cache_mask = ~cache_location_mask
+
+                # Calculate the count of valid indices (those not in L1 cache)
+                valid_count = non_cache_mask.sum().to(torch.int64).cpu()
+
+                if valid_count.item() > 0:
+                    # Extract only the indices and weights that are not in L1 cache
+                    non_cache_indices = linear_cache_indices[non_cache_mask]
+                    non_cache_weights = weights[non_cache_mask]
+
+                    # Move tensors to CPU for set_cuda
+                    cpu_indices = non_cache_indices.cpu()
+                    cpu_weights = non_cache_weights.cpu()
+
+                    # Write to backend - only sending the non-cache indices and weights
+                    self.record_function_via_dummy_profile(
+                        f"## ssd_write_{step}_set_cuda_{self.tbe_unique_id} ##",
+                        self.ssd_db.set_cuda,
+                        cpu_indices,
+                        cpu_weights,
+                        valid_count,
+                        self.timestep,
+                        is_bwd=False,
+                    )
+
+        # Return control to the main stream without waiting for the backend operation to complete
+
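Step 3 takes the complement of the L1 hit mask, gathers the missing rows, and hands CPU copies to the key-value backend. A toy sketch of that mask/gather/offload flow, with a plain dict standing in for `ssd_db`; all names below are illustrative.

```python
import torch

# Hypothetical stand-ins for linear_cache_indices / weights / cache_location_mask.
all_indices = torch.tensor([7, 3, 9, 1])
all_weights = torch.randn(4, 4)
in_l1 = torch.tensor([True, False, True, False])

miss_mask = ~in_l1                                  # rows that must go to the backend
miss_count = miss_mask.sum().to(torch.int64)

if miss_count.item() > 0:
    miss_indices = all_indices[miss_mask].cpu()     # the backend write takes CPU tensors
    miss_weights = all_weights[miss_mask].cpu()
    # A real implementation would hand (miss_indices, miss_weights, miss_count)
    # to the key-value backend; a dict is used here as a stand-in.
    backend = {int(i): w.clone() for i, w in zip(miss_indices, miss_weights)}
    assert set(backend) == {3, 1}
```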
+    def get_free_cpu_memory_gb(self) -> float:
+        def _get_mem_available() -> float:
+            if sys.platform.startswith("linux"):
+                info = {}
+                with open("/proc/meminfo") as f:
+                    for line in f:
+                        p = line.split()
+                        info[p[0].strip(":").lower()] = int(p[1]) * 1024
+                if "memavailable" in info:
+                    # Linux >= 3.14
+                    return info["memavailable"]
+                else:
+                    return info["memfree"] + info["cached"]
+            else:
+                raise RuntimeError(
+                    "Unsupported platform for free-memory-based eviction, please use the ID count eviction trigger mode"
+                )
+
+        mem = _get_mem_available()
+        return mem / (1024**3)
+
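`get_free_cpu_memory_gb` parses `/proc/meminfo` directly, preferring `MemAvailable` and falling back to `MemFree + Cached` on older kernels. For reference, an equivalent cross-platform check using `psutil`; this is only an assumed alternative and `psutil` is not a dependency of this package.

```python
import psutil

def free_cpu_memory_gb() -> float:
    # psutil's "available" mirrors /proc/meminfo MemAvailable on Linux and also
    # works on macOS/Windows, unlike the /proc-based helper in the diff above.
    return psutil.virtual_memory().available / (1024 ** 3)

print(f"{free_cpu_memory_gb():.2f} GB available")
```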
+    @classmethod
+    def trigger_evict_in_all_tbes(cls) -> None:
+        for tbe in cls._all_tbe_instances:
+            tbe.ssd_db.trigger_feature_evict()
+
+    @classmethod
+    def tbe_has_ongoing_eviction(cls) -> bool:
+        for tbe in cls._all_tbe_instances:
+            if tbe.ssd_db.is_evicting():
+                return True
+        return False
+
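The two classmethods above fan an eviction call out to every live TBE through the class-level `_all_tbe_instances` registry. Below is a small sketch of that registry pattern; the `Worker` class and the use of a `WeakSet` are hypothetical stand-ins, since the diff does not show how `_all_tbe_instances` is populated.

```python
import weakref

class Worker:
    # Class-level registry so classmethods can fan a call out to every live
    # instance, mirroring the _all_tbe_instances pattern referenced above.
    _instances: "weakref.WeakSet[Worker]" = weakref.WeakSet()

    def __init__(self, name: str) -> None:
        self.name = name
        self.evicting = False
        Worker._instances.add(self)

    def trigger_evict(self) -> None:
        self.evicting = True

    @classmethod
    def trigger_all(cls) -> None:
        for w in cls._instances:
            w.trigger_evict()

    @classmethod
    def any_ongoing(cls) -> bool:
        return any(w.evicting for w in cls._instances)

a, b = Worker("a"), Worker("b")
Worker.trigger_all()
assert Worker.any_ongoing()
```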
+    def set_free_mem_eviction_trigger_config(
+        self, eviction_policy: EvictionPolicy
+    ) -> None:
+        self.enable_free_mem_trigger_eviction = True
+        self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode
+        assert (
+            eviction_policy.eviction_free_mem_check_interval_batch is not None
+        ), "eviction_free_mem_check_interval_batch is unexpectedly None for the free_mem eviction trigger mode"
+        self.eviction_free_mem_check_interval_batch: int = (
+            eviction_policy.eviction_free_mem_check_interval_batch
+        )
+        assert (
+            eviction_policy.eviction_free_mem_threshold_gb is not None
+        ), "eviction_policy.eviction_free_mem_threshold_gb is unexpectedly None for the free_mem eviction trigger mode"
+        self.eviction_free_mem_threshold_gb: int = (
+            eviction_policy.eviction_free_mem_threshold_gb
+        )
+        logging.info(
+            f"[FREE_MEM Eviction] eviction config, trigger mode: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}"
+        )
+
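The setter above only validates the two fields the free-mem trigger needs. A hypothetical minimal config object with the same validation, shown purely to illustrate the expected shape of the `EvictionPolicy` argument (the real class lives elsewhere in the package and is not reproduced here).

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FreeMemEvictionConfig:
    # Hypothetical stand-in for the EvictionPolicy fields consumed above.
    eviction_trigger_mode: int
    eviction_free_mem_check_interval_batch: Optional[int] = None
    eviction_free_mem_threshold_gb: Optional[int] = None

    def validate(self) -> None:
        assert self.eviction_free_mem_check_interval_batch is not None, (
            "eviction_free_mem_check_interval_batch is required for the free_mem trigger mode"
        )
        assert self.eviction_free_mem_threshold_gb is not None, (
            "eviction_free_mem_threshold_gb is required for the free_mem trigger mode"
        )

cfg = FreeMemEvictionConfig(
    eviction_trigger_mode=3,
    eviction_free_mem_check_interval_batch=100,
    eviction_free_mem_threshold_gb=50,
)
cfg.validate()
```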
+    def may_trigger_eviction(self) -> None:
+        def is_first_tbe() -> bool:
+            first = SSDTableBatchedEmbeddingBags._first_instance_ref
+            return first is not None and first() is self
+
+        # We assume that the eviction time is shorter than the free-mem check interval,
+        # so every time we reach this check, all evictions in all TBEs should have finished.
+        # We only need to check the first TBE because all TBEs share the same free memory;
+        # once the first TBE detects that eviction is needed, it calls the trigger function
+        # on all TBEs via _all_tbe_instances.
+        if (
+            self.enable_free_mem_trigger_eviction
+            and self.step % self.eviction_free_mem_check_interval_batch == 0
+            and self.training
+            and is_first_tbe()
+        ):
+            if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction():
+                SSDTableBatchedEmbeddingBags._eviction_triggered = False
+
+            free_cpu_mem_gb = self.get_free_cpu_memory_gb()
+            local_evict_trigger = int(
+                free_cpu_mem_gb < self.eviction_free_mem_threshold_gb
+            )
+            tensor_flag = torch.tensor(
+                local_evict_trigger,
+                device=self.current_device,
+                dtype=torch.int,
+            )
+            world_size = dist.get_world_size(self._pg)
+            if world_size > 1:
+                dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg)
+                global_evict_trigger = tensor_flag.item()
+            else:
+                global_evict_trigger = local_evict_trigger
+            if (
+                global_evict_trigger >= 1
+                and SSDTableBatchedEmbeddingBags._eviction_triggered
+            ):
+                logging.warning(
+                    f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is already True"
+                )
+            if (
+                global_evict_trigger >= 1
+                and not SSDTableBatchedEmbeddingBags._eviction_triggered
+            ):
+                SSDTableBatchedEmbeddingBags._eviction_triggered = True
+                SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes()
+                logging.info(
+                    f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction"
+                )
+
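`may_trigger_eviction` turns the local free-memory check into a cluster-wide decision: each rank votes with an integer flag, the flags are summed with `all_reduce`, and eviction fires if any rank voted. A standalone sketch of that voting pattern, assuming `torch.distributed` may or may not be initialized; the helper name is illustrative, and the real code uses a device tensor and an explicit process group.

```python
import torch
import torch.distributed as dist

def any_rank_wants_eviction(free_gb: float, threshold_gb: float) -> bool:
    # Each rank votes 1 if its free CPU memory is below the threshold; the votes
    # are summed across ranks, and eviction fires if any rank voted.
    local_vote = int(free_gb < threshold_gb)
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        flag = torch.tensor(local_vote, dtype=torch.int)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return flag.item() >= 1
    return local_vote >= 1

# Single-process usage (no process group initialized):
assert any_rank_wants_eviction(free_gb=12.0, threshold_gb=50.0)
```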
+    def reset_inference_mode(self) -> None:
+        """
+        Reset the inference mode
+        """
+        self.eval()