fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +112 -19
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +118 -0
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +190 -54
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
- fbgemm_gpu/split_embedding_configs.py +134 -37
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +6 -1
- fbgemm_gpu/tbe/bench/bench_config.py +14 -3
- fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
- fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
- fbgemm_gpu/tbe/ssd/common.py +1 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +1292 -267
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +15 -15
- fbgemm_gpu/tbe_input_multiplexer.py +10 -11
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +6 -2
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +1 -0
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/tbe/ssd/training.py
CHANGED

@@ -14,13 +14,13 @@ import itertools
 import logging
 import math
 import os
-import tempfile
 import threading
 import time
 from functools import cached_property
-from math import
-from typing import Any, Callable,
+from math import floor, log2
+from typing import Any, Callable, ClassVar, Optional, Union
 import torch  # usort:skip
+import weakref

 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -35,6 +35,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
     CacheAlgorithm,
     EmbeddingLocation,
+    EvictionPolicy,
     get_bounds_check_version_for_platform,
     KVZCHParams,
     PoolingMode,
@@ -49,10 +50,12 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     WeightDecayMode,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    check_allocated_vbe_output,
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
 from torch import distributed as dist, nn, Tensor  # usort:skip
+import sys
 from dataclasses import dataclass

 from torch.autograd.profiler import record_function
@@ -76,10 +79,10 @@ class IterData:

 @dataclass
 class KVZCHCachedData:
-    cached_optimizer_states_per_table:
-    cached_weight_tensor_per_table:
-    cached_id_tensor_per_table:
-    cached_bucket_splits:
+    cached_optimizer_states_per_table: list[list[torch.Tensor]]
+    cached_weight_tensor_per_table: list[torch.Tensor]
+    cached_id_tensor_per_table: list[torch.Tensor]
+    cached_bucket_splits: list[torch.Tensor]


 class SSDTableBatchedEmbeddingBags(nn.Module):
@@ -100,13 +103,18 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     weights_offsets: Tensor
     _local_instance_index: int = -1
     res_params: RESParams
-    table_names:
+    table_names: list[str]
+    _all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet()
+    _first_instance_ref: ClassVar[weakref.ref] = None
+    _eviction_triggered: ClassVar[bool] = False

     def __init__(
         self,
-        embedding_specs:
-        feature_table_map: Optional[
+        embedding_specs: list[tuple[int, int]],  # tuple of (rows, dims)
+        feature_table_map: Optional[list[int]],  # [T]
         cache_sets: int,
+        # A comma-separated string, e.g. "/data00_nvidia0,/data01_nvidia0/", db shards
+        # will be placed in these paths round-robin.
         ssd_storage_directory: str,
         ssd_rocksdb_shards: int = 1,
         ssd_memtable_flush_period: int = -1,
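Note: the new `_all_tbe_instances` / `_first_instance_ref` class attributes use `weakref` so that tracking every live TBE module never extends a module's lifetime. A minimal, self-contained sketch of that pattern (the `Tracked` class below is illustrative, not fbgemm API):

    import weakref

    class Tracked:
        # A WeakSet drops an entry automatically once the instance is
        # garbage-collected, so registration never keeps an instance alive.
        _all_instances: "weakref.WeakSet[Tracked]" = weakref.WeakSet()
        _first_instance_ref = None  # weakref.ref to the first instance, or None

        def __init__(self) -> None:
            Tracked._all_instances.add(self)
            if Tracked._first_instance_ref is None:
                Tracked._first_instance_ref = weakref.ref(self)

    a = Tracked()
    b = Tracked()
    assert Tracked._first_instance_ref() is a  # dereference the weak reference
    del b  # once collected, b silently disappears from _all_instances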
@@ -146,13 +154,16 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         pooling_mode: PoolingMode = PoolingMode.SUM,
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
         # Parameter Server Configs
-        ps_hosts: Optional[
+        ps_hosts: Optional[tuple[tuple[str, int]]] = None,
         ps_max_key_per_request: Optional[int] = None,
         ps_client_thread_num: Optional[int] = None,
         ps_max_local_index_length: Optional[int] = None,
         tbe_unique_id: int = -1,
-        #
-        #
+        # If set to True, will use `ssd_storage_directory` as the ssd paths.
+        # If set to False, will use the default ssd paths.
+        # In local tests we need to use the passed-in path for rocksdb creation.
+        # In production we could either use the default ssd mount points or explicitly
+        # specify ssd mount points using `ssd_storage_directory`.
         use_passed_in_path: int = True,
         gather_ssd_cache_stats: Optional[bool] = False,
         stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
@@ -172,14 +183,18 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         enable_raw_embedding_streaming: bool = False,  # whether enable raw embedding streaming
         res_params: Optional[RESParams] = None,  # raw embedding streaming sharding info
         flushing_block_size: int = 2_000_000_000,  # 2GB
-        table_names: Optional[
-
+        table_names: Optional[list[str]] = None,
+        use_rowwise_bias_correction: bool = False,  # For Adam use
+        optimizer_state_dtypes: dict[str, SparseType] = {},  # noqa: B006
+        pg: Optional[dist.ProcessGroup] = None,
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()

         # Set the optimizer
         assert optimizer in (
             OptimType.EXACT_ROWWISE_ADAGRAD,
+            OptimType.PARTIAL_ROWWISE_ADAM,
+            OptimType.ADAM,
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
         self.optimizer = optimizer

@@ -187,15 +202,28 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         assert weights_precision in (SparseType.FP32, SparseType.FP16)
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
-
+
+        if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
+            # Adagrad currently only supports FP32 for momentum1
+            self.optimizer_state_dtypes: dict[str, SparseType] = {
+                "momentum1": SparseType.FP32,
+            }
+        else:
+            self.optimizer_state_dtypes: dict[str, SparseType] = optimizer_state_dtypes

         # Zero collision TBE configurations
         self.kv_zch_params = kv_zch_params
         self.backend_type = backend_type
         self.enable_optimizer_offloading: bool = False
         self.backend_return_whole_row: bool = False
+        self._embedding_cache_mode: bool = False
+        self.load_ckpt_without_opt: bool = False
         if self.kv_zch_params:
             self.kv_zch_params.validate()
+            self.load_ckpt_without_opt = (
+                # pyre-ignore [16]
+                self.kv_zch_params.load_ckpt_without_opt
+            )
             self.enable_optimizer_offloading = (
                 # pyre-ignore [16]
                 self.kv_zch_params.enable_optimizer_offloading
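Note: `optimizer_state_dtypes` is only honored for the Adam variants; rowwise Adagrad pins `momentum1` to FP32. A small runnable sketch of that defaulting rule, with plain strings standing in for fbgemm's `OptimType`/`SparseType` enums:

    def resolve_optimizer_state_dtypes(optimizer: str, requested: dict[str, str]) -> dict[str, str]:
        # EXACT_ROWWISE_ADAGRAD currently supports only FP32 momentum1,
        # so any requested dtype for it is ignored.
        if optimizer == "EXACT_ROWWISE_ADAGRAD":
            return {"momentum1": "FP32"}
        return dict(requested)

    assert resolve_optimizer_state_dtypes("EXACT_ROWWISE_ADAGRAD", {"momentum1": "FP16"}) == {"momentum1": "FP32"}
    assert resolve_optimizer_state_dtypes("ADAM", {"momentum2": "FP16"}) == {"momentum2": "FP16"}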
@@ -214,12 +242,43 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             logging.info(
                 "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
             )
+            # pyre-ignore [16]
+            self._embedding_cache_mode = self.kv_zch_params.embedding_cache_mode
+            if self._embedding_cache_mode:
+                logging.info("KVZCH is in embedding_cache_mode")
+                assert self.optimizer in [
+                    OptimType.EXACT_ROWWISE_ADAGRAD
+                ], f"only EXACT_ROWWISE_ADAGRAD supports embedding cache mode, but got {self.optimizer}"
+            if self.load_ckpt_without_opt:
+                if (
+                    # pyre-ignore [16]
+                    self.kv_zch_params.optimizer_type_for_st
+                    == OptimType.PARTIAL_ROWWISE_ADAM.value
+                ):
+                    self.optimizer = OptimType.PARTIAL_ROWWISE_ADAM
+                    logging.info(
+                        f"Override optimizer type with {self.optimizer=} for st publish"
+                    )
+                if (
+                    # pyre-ignore [16]
+                    self.kv_zch_params.optimizer_state_dtypes_for_st
+                    is not None
+                ):
+                    optimizer_state_dtypes = {}
+                    for k, v in dict(
+                        self.kv_zch_params.optimizer_state_dtypes_for_st
+                    ).items():
+                        optimizer_state_dtypes[k] = SparseType.from_int(v)
+                    self.optimizer_state_dtypes = optimizer_state_dtypes
+                    logging.info(
+                        f"Override optimizer_state_dtypes with {self.optimizer_state_dtypes=} for st publish"
+                    )

         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self.embedding_specs = embedding_specs
         self.table_names = table_names if table_names is not None else []
-
+        rows, dims = zip(*embedding_specs)
         T_ = len(self.embedding_specs)
         assert T_ > 0
         # pyre-fixme[8]: Attribute has type `device`; used as `int`.
@@ -238,7 +297,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             f"get env {self.res_params.res_server_port=}, at rank {dist.get_rank()}, with {self.res_params=}"
         )

-        self.feature_table_map:
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
         T = len(self.feature_table_map)
@@ -318,7 +377,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             torch.tensor(dims, device="cpu", dtype=torch.int64),
         )

-
+        info_B_num_bits_, info_B_mask_ = torch.ops.fbgemm.get_infos_metadata(
             self.D_offsets,  # unused tensor
             1,  # max_B
             T,  # T
@@ -514,11 +573,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self.record_function_via_dummy_profile_factory(use_dummy_profile)
         )

-
+        if use_passed_in_path:
+            ssd_dir_list = ssd_storage_directory.split(",")
+            for ssd_dir in ssd_dir_list:
+                os.makedirs(ssd_dir, exist_ok=True)

-        ssd_directory =
-            prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
-        )
+        ssd_directory = ssd_storage_directory
         # logging.info("DEBUG: weights_precision {}".format(weights_precision))

         """
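Note: with `use_passed_in_path`, `ssd_storage_directory` is now treated as a comma-separated list of mount points rather than a `tempfile.mkdtemp` prefix (hence the dropped `import tempfile`). A runnable sketch of the path handling, with the documented round-robin shard placement (done by the C++ backend) mocked in Python:

    import os

    ssd_storage_directory = "/tmp/data00_nvidia0,/tmp/data01_nvidia0"
    ssd_rocksdb_shards = 4

    # Each comma-separated entry becomes a directory that must exist up front.
    ssd_dir_list = ssd_storage_directory.split(",")
    for ssd_dir in ssd_dir_list:
        os.makedirs(ssd_dir, exist_ok=True)

    # Hypothetical illustration of round-robin placement of rocksdb shards
    # over the provided paths (the real placement happens in the backend).
    placement = {s: ssd_dir_list[s % len(ssd_dir_list)] for s in range(ssd_rocksdb_shards)}
    print(placement)  # shards 0,2 -> data00; shards 1,3 -> data01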
@@ -538,10 +598,16 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         """
         self._cached_kvzch_data: Optional[KVZCHCachedData] = None
         # initial embedding rows on this rank per table, this is used for loading checkpoint
-        self.local_weight_counts:
+        self.local_weight_counts: list[int] = [0] * T_
+        # groundtruth global id on this rank per table, this is used for loading checkpoint
+        self.global_id_per_rank: list[torch.Tensor] = [torch.zeros(0)] * T_
         # loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend
         self.load_state_dict: bool = False

+        SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self)
+        if SSDTableBatchedEmbeddingBags._first_instance_ref is None:
+            SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self)
+
         # create tbe unique id using rank index | local tbe idx
         if tbe_unique_id == -1:
             SSDTableBatchedEmbeddingBags._local_instance_index += 1
@@ -559,6 +625,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.tbe_unique_id = tbe_unique_id
         self.l2_cache_size = l2_cache_size
         logging.info(f"tbe_unique_id: {tbe_unique_id}")
+        self.enable_free_mem_trigger_eviction: bool = False
         if self.backend_type == BackendType.SSD:
             logging.info(
                 f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
@@ -614,6 +681,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                     else None
                 ),
                 flushing_block_size,
+                self._embedding_cache_mode,  # disable_random_init
             )
             if self.bulk_init_chunk_size > 0:
                 self.ssd_uniform_init_lower: float = ssd_uniform_init_lower
@@ -662,18 +730,41 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
                 else self.l2_cache_size
             )
+            kv_zch_params = self.kv_zch_params
+            eviction_policy = self.kv_zch_params.eviction_policy
+            if eviction_policy.eviction_trigger_mode == 5:
+                # If trigger mode is free_mem(5), populate config
+                self.set_free_mem_eviction_trigger_config(eviction_policy)
+
+            enable_eviction_for_feature_score_eviction_policy = (  # pytorch api in c++ doesn't support vector<bool>, convert to int here, 0: no eviction 1: eviction
+                [
+                    int(x)
+                    for x in eviction_policy.enable_eviction_for_feature_score_eviction_policy
+                ]
+                if eviction_policy.enable_eviction_for_feature_score_eviction_policy
+                is not None
+                else None
+            )
+            # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
             eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
-
-
-
+                eviction_policy.eviction_trigger_mode,  # eviction_trigger_mode: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
+                eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
+                eviction_policy.eviction_step_intervals,  # trigger_step_interval if trigger mode is iteration
                 eviction_mem_threshold_gb,  # mem_util_threshold_in_GB if trigger mode is mem_util
-
-
-
+                eviction_policy.ttls_in_mins,  # ttls_in_mins for each table if eviction strategy is timestamp
+                eviction_policy.counter_thresholds,  # counter_thresholds for each table if eviction strategy is counter
+                eviction_policy.counter_decay_rates,  # counter_decay_rates for each table if eviction strategy is counter
+                eviction_policy.feature_score_counter_decay_rates,  # feature_score_counter_decay_rates for each table if eviction strategy is feature score
+                eviction_policy.training_id_eviction_trigger_count,  # training_id_eviction_trigger_count for each table
+                eviction_policy.training_id_keep_count,  # training_id_keep_count for each table
+                enable_eviction_for_feature_score_eviction_policy,  # no eviction setting for feature score eviction policy
+                eviction_policy.l2_weight_thresholds,  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
                 table_dims.tolist() if table_dims is not None else None,
-
-
+                eviction_policy.threshold_calculation_bucket_stride,  # threshold_calculation_bucket_stride if eviction strategy is feature score
+                eviction_policy.threshold_calculation_bucket_num,  # threshold_calculation_bucket_num if eviction strategy is feature score
+                eviction_policy.interval_for_insufficient_eviction_s,
+                eviction_policy.interval_for_sufficient_eviction_s,
+                eviction_policy.interval_for_feature_statistics_decay_s,
             )
             self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
                 self.cache_row_dim,
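Note: the positional comments on `FeatureEvictConfig` enumerate the trigger modes and eviction strategies as integers; restated as plain-Python enums for reference (values taken from those comments and the `eviction_trigger_mode == 5` branch above; the enum classes themselves are illustrative, not fbgemm API):

    from enum import IntEnum

    class EvictionTriggerMode(IntEnum):
        DISABLED = 0
        ITERATION = 1  # evict every eviction_step_intervals steps
        MEM_UTIL = 2   # evict past mem_util_threshold_in_GB
        MANUAL = 3
        ID_COUNT = 4   # evict at training_id_eviction_trigger_count
        FREE_MEM = 5   # new: see set_free_mem_eviction_trigger_config

    class EvictionStrategy(IntEnum):
        TIMESTAMP = 0
        COUNTER = 1
        COUNTER_AND_TIMESTAMP = 2
        FEATURE_L2_NORM = 3
        TIMESTAMP_THRESHOLD = 4
        FEATURE_SCORE = 5

    assert EvictionTriggerMode.FREE_MEM == 5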
@@ -690,16 +781,20 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                     else None
                 ),  # hash_size_cumsum
                 self.backend_return_whole_row,  # backend_return_whole_row
+                False,  # enable_async_update
+                self._embedding_cache_mode,  # disable_random_init
             )
         else:
             raise AssertionError(f"Invalid backend type {self.backend_type}")

         # pyre-fixme[20]: Argument `self` expected.
-
+        low_priority, high_priority = torch.cuda.Stream.priority_range()
         # GPU stream for SSD cache eviction
         self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
-        # GPU stream for SSD memory copy
+        # GPU stream for SSD memory copy (also reused for feature score D2H)
         self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
+        # GPU stream for async metadata operation
+        self.feature_score_stream = torch.cuda.Stream(priority=low_priority)

         # SSD get completion event
         self.ssd_event_get = torch.cuda.Event()
@@ -711,6 +806,17 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.ssd_event_backward = torch.cuda.Event()
         # SSD get's input copy completion event
         self.ssd_event_get_inputs_cpy = torch.cuda.Event()
+        if self._embedding_cache_mode:
+            # Direct write embedding completion event
+            self.direct_write_l1_complete_event: torch.cuda.streams.Event = (
+                torch.cuda.Event()
+            )
+            self.direct_write_sp_complete_event: torch.cuda.streams.Event = (
+                torch.cuda.Event()
+            )
+            # Prefetch operation completion event
+            self.prefetch_complete_event = torch.cuda.Event()
+
         if self.prefetch_pipeline:
             # SSD scratch pad index queue insert completion event
             self.ssd_event_sp_idxq_insert: torch.cuda.streams.Event = torch.cuda.Event()
@@ -771,22 +877,22 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         )

         # (Indices, Count)
-        self.prefetched_info:
+        self.prefetched_info: list[tuple[Tensor, Tensor]] = []

-        self.timesteps_prefetched:
+        self.timesteps_prefetched: list[int] = []
         # TODO: add type annotation
         # pyre-fixme[4]: Attribute must be annotated.
         self.ssd_prefetch_data = []

         # Scratch pad eviction data queue
-        self.ssd_scratch_pad_eviction_data:
-
+        self.ssd_scratch_pad_eviction_data: list[
+            tuple[Tensor, Tensor, Tensor, bool]
         ] = []
-        self.ssd_location_update_data:
+        self.ssd_location_update_data: list[tuple[Tensor, Tensor]] = []

         if self.prefetch_pipeline:
             # Scratch pad value queue
-            self.ssd_scratch_pads:
+            self.ssd_scratch_pads: list[tuple[Tensor, Tensor, Tensor]] = []

             # pyre-ignore[4]
             # Scratch pad index queue
@@ -835,7 +941,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
             lower_bound=cowclip_regularization.lower_bound,
             regularization_mode=weight_decay_mode.value,
-            use_rowwise_bias_correction=
+            use_rowwise_bias_correction=use_rowwise_bias_correction,  # Used in Adam optimizer
         )

         table_embedding_dtype = weights_precision.as_dtype()
@@ -888,7 +994,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.ssd_cache_stats_size = 6
         # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
         # 4: N_conflict_unique_misses, 5: N_conflict_misses
-        self.last_reported_ssd_stats:
+        self.last_reported_ssd_stats: list[float] = []
         self.last_reported_step = 0

         self.register_buffer(
@@ -919,7 +1025,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.prefetch_parallel_stream_cnt: int = 2
         # tuple of iteration, prefetch parallel stream cnt, reported duration
         # since there are 2 stream in parallel in prefetch, we want to count the longest one
-        self.prefetch_duration_us:
+        self.prefetch_duration_us: tuple[int, int, float] = (
             -1,
             self.prefetch_parallel_stream_cnt,
             0,
@@ -945,6 +1051,20 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.dram_kv_allocated_bytes_stats_name: str = (
             f"dram_kv.mem.tbe_id{tbe_unique_id}.allocated_bytes"
         )
+        self.dram_kv_mem_num_rows_stats_name: str = (
+            f"dram_kv.mem.tbe_id{tbe_unique_id}.num_rows"
+        )
+
+        self.eviction_sum_evicted_counts_stats_name: str = (
+            f"eviction.tbe_id.{tbe_unique_id}.sum_evicted_counts"
+        )
+        self.eviction_sum_processed_counts_stats_name: str = (
+            f"eviction.tbe_id.{tbe_unique_id}.sum_processed_counts"
+        )
+        self.eviction_evict_rate_stats_name: str = (
+            f"eviction.tbe_id.{tbe_unique_id}.evict_rate"
+        )
+
         if self.stats_reporter:
             self.ssd_prefetch_read_timer = AsyncSeriesTimer(
                 functools.partial(
@@ -972,9 +1092,41 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self.stats_reporter.register_stats(
                 self.dram_kv_actual_used_chunk_bytes_stats_name
             )
+            self.stats_reporter.register_stats(self.dram_kv_mem_num_rows_stats_name)
+            self.stats_reporter.register_stats(
+                self.eviction_sum_evicted_counts_stats_name
+            )
+            self.stats_reporter.register_stats(
+                self.eviction_sum_processed_counts_stats_name
+            )
+            self.stats_reporter.register_stats(self.eviction_evict_rate_stats_name)
+            for t in self.feature_table_map:
+                self.stats_reporter.register_stats(
+                    f"eviction.feature_table.{t}.evicted_counts"
+                )
+                self.stats_reporter.register_stats(
+                    f"eviction.feature_table.{t}.processed_counts"
+                )
+                self.stats_reporter.register_stats(
+                    f"eviction.feature_table.{t}.evict_rate"
+                )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.full_duration_ms"
+            )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.exec_duration_ms"
+            )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.dry_run_exec_duration_ms"
+            )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.exec_div_full_duration_rate"
+            )

         self.bounds_check_version: int = get_bounds_check_version_for_platform()

+        self._pg = pg
+
     @cached_property
     def cache_row_dim(self) -> int:
         """
@@ -982,7 +1134,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         padding to the nearest 4 elements and the optimizer state appended to
         the back of the row
         """
-
+
+        # For st publish, we only need to load weight for publishing and bulk eval
+        if self.enable_optimizer_offloading and not self.load_ckpt_without_opt:
             return self.max_D + pad4(
                 # Compute the number of elements of cache_dtype needed to store
                 # the optimizer state
@@ -1182,10 +1336,10 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self,
         split: SplitState,
         prefix: str,
-        dtype:
+        dtype: type[torch.dtype],
         enforce_hbm: bool = False,
         make_dev_param: bool = False,
-        dev_reshape: Optional[
+        dev_reshape: Optional[tuple[int, ...]] = None,
     ) -> None:
         apply_split_helper(
             self.register_buffer,
@@ -1208,11 +1362,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

     def to_pinned_cpu_on_stream_wait_on_another_stream(
         self,
-        tensors:
+        tensors: list[Tensor],
         stream: torch.cuda.Stream,
         stream_to_wait_on: torch.cuda.Stream,
         post_event: Optional[torch.cuda.Event] = None,
-    ) ->
+    ) -> list[Tensor]:
         """
         Transfer input tensors from GPU to CPU using a pinned host
         buffer. The transfer is carried out on the given stream
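Note: the helper's contract (copy on one stream only after another stream's work is visible, into pinned host memory, optionally recording a completion event) is a standard CUDA pattern. A minimal sketch of that pattern, assuming a CUDA device is available; this is not the library's implementation:

    import torch

    def pinned_d2h_on_stream(
        tensors: list[torch.Tensor],
        stream: torch.cuda.Stream,
        stream_to_wait_on: torch.cuda.Stream,
        post_event: torch.cuda.Event | None = None,
    ) -> list[torch.Tensor]:
        outs = []
        with torch.cuda.stream(stream):
            stream.wait_stream(stream_to_wait_on)  # order after the producer stream
            for t in tensors:
                t.record_stream(stream)  # keep t alive until this stream is done with it
                host = torch.empty_like(t, device="cpu", pin_memory=True)
                host.copy_(t, non_blocking=True)  # async D2H into pinned memory
                outs.append(host)
            if post_event is not None:
                stream.record_event(post_event)  # signal completion to consumers
        return outs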
@@ -1274,6 +1428,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         Returns:
             None
         """
+        if not self.training:  # if not training, freeze the embedding
+            return
         with record_function(f"## ssd_evict_{name} ##"):
             with torch.cuda.stream(stream):
                 if pre_event is not None:
@@ -1286,7 +1442,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 self.record_function_via_dummy_profile(
                     f"## ssd_set_{name} ##",
                     self.ssd_db.set_cuda,
-                    indices_cpu
+                    indices_cpu,
                     rows_cpu,
                     actions_count_cpu,
                     self.timestep,
@@ -1450,7 +1606,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     def _update_cache_counter_and_pointers(
         self,
         module: nn.Module,
-        grad_input: Union[
+        grad_input: Union[tuple[Tensor, ...], Tensor],
     ) -> None:
         """
         Update cache line locking counter and pointers before backward
@@ -1535,9 +1691,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         if len(self.ssd_location_update_data) == 0:
             return

-
-            0
-        )
+        sp_curr_next_map, inserted_rows_next = self.ssd_location_update_data.pop(0)

         # Update pointers
         torch.ops.fbgemm.ssd_update_row_addrs(
@@ -1552,12 +1706,63 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             unique_indices_length_curr=curr_data.actions_count_gpu,
         )

+    def _update_feature_score_metadata(
+        self,
+        linear_cache_indices: Tensor,
+        weights: Tensor,
+        d2h_stream: torch.cuda.Stream,
+        write_stream: torch.cuda.Stream,
+        pre_event_for_write: torch.cuda.Event,
+        post_event: Optional[torch.cuda.Event] = None,
+    ) -> None:
+        """
+        Write feature score metadata to DRAM
+
+        This method performs D2H copy on d2h_stream, then writes to DRAM on write_stream.
+        The caller is responsible for ensuring d2h_stream doesn't compete with other D2H operations.
+
+        Args:
+            linear_cache_indices: GPU tensor containing cache indices
+            weights: GPU tensor containing feature scores
+            d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
+            write_stream: Stream for metadata write operation
+            pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
+            post_event: Event to record when the operation is done
+        """
+        # Start D2H copy on d2h_stream
+        with torch.cuda.stream(d2h_stream):
+            # Record streams to prevent premature deallocation
+            linear_cache_indices.record_stream(d2h_stream)
+            weights.record_stream(d2h_stream)
+            # Do the D2H copy
+            linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
+            score_weights_cpu = self.to_pinned_cpu(weights)

+        # Write feature score metadata to DRAM
+        with record_function("## ssd_write_feature_score_metadata ##"):
+            with torch.cuda.stream(write_stream):
+                write_stream.wait_event(pre_event_for_write)
+                write_stream.wait_stream(d2h_stream)
+                self.record_function_via_dummy_profile(
+                    "## ssd_write_feature_score_metadata ##",
+                    self.ssd_db.set_feature_score_metadata_cuda,
+                    linear_cache_indices_cpu,
+                    torch.tensor(
+                        [score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
+                    ),
+                    score_weights_cpu,
+                )

+                if post_event is not None:
+                    write_stream.record_event(post_event)
+
     def prefetch(
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,  # todo: need to update caller
         forward_stream: Optional[torch.cuda.Stream] = None,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
     ) -> None:
         if self.prefetch_stream is None and forward_stream is not None:
             # Set the prefetch stream to the current stream
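Note: `_update_feature_score_metadata` hinges on three orderings: the write stream waits on `pre_event_for_write` (eviction) and on the D2H stream, and the GPU inputs are tied to the copy stream's lifetime via `record_stream`. A condensed, runnable skeleton of just that choreography (CUDA device required; the DRAM write itself is replaced by a comment):

    import torch

    d2h_stream, write_stream = torch.cuda.Stream(), torch.cuda.Stream()
    evict_done, write_done = torch.cuda.Event(), torch.cuda.Event()
    src = torch.rand(1024, device="cuda")

    with torch.cuda.stream(d2h_stream):
        src.record_stream(d2h_stream)  # don't free src before the copy finishes
        host = torch.empty_like(src, device="cpu", pin_memory=True)
        host.copy_(src, non_blocking=True)  # async D2H

    torch.cuda.current_stream().record_event(evict_done)  # stands in for the eviction work

    with torch.cuda.stream(write_stream):
        write_stream.wait_event(evict_done)   # wait for the eviction work
        write_stream.wait_stream(d2h_stream)  # and for the D2H copy
        # ssd_db.set_feature_score_metadata_cuda(...) would be issued here
        write_stream.record_event(write_done)  # signal downstream consumers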
@@ -1581,6 +1786,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self._prefetch(
             indices,
             offsets,
+            weights,
             vbe_metadata,
             forward_stream,
         )
@@ -1589,11 +1795,17 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,
         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
         forward_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
-        #
+        # Wait for any ongoing direct_write_embedding operations to complete
+        # Moving this from forward() to _prefetch() is more logical as direct_write
+        # operations affect the same cache structures that prefetch interacts with
         current_stream = torch.cuda.current_stream()
+        if self._embedding_cache_mode:
+            current_stream.wait_event(self.direct_write_l1_complete_event)
+            current_stream.wait_event(self.direct_write_sp_complete_event)

         B_offsets = None
         max_B = -1
@@ -1700,8 +1912,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                     name="cache_update",
                 )
                 current_stream.wait_event(self.ssd_event_cache_streaming_synced)
-
-
+                updated_indices, updated_counts_gpu = self.prefetched_info.pop(
+                    0
                 )
                 self.lxu_cache_updated_indices[: updated_indices.size(0)].copy_(
                     updated_indices,
@@ -1878,12 +2090,13 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 # Store info for evicting the previous iteration's
                 # scratch pad after the corresponding backward pass is
                 # done
-                self.
-                (
-
-
+                if self.training:
+                    self.ssd_location_update_data.append(
+                        (
+                            sp_curr_prev_map_gpu,
+                            inserted_rows,
+                        )
                     )
-                )

             # Ensure the previous iterations eviction is complete
             current_stream.wait_event(self.ssd_event_sp_evict)
@@ -1931,7 +2144,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 self.ssd_cache_stats = torch.add(
                     self.ssd_cache_stats, self.local_ssd_cache_stats
                 )
-
+                # only report metrics from rank0 to avoid flooded logging
+                if dist.get_rank() == 0:
+                    self._report_kv_backend_stats()
+
+        # May trigger eviction if free mem trigger mode enabled before get cuda
+        self.may_trigger_eviction()

         # Fetch data from SSD
         if linear_cache_indices.numel() > 0:
@@ -1955,21 +2173,35 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 use_pipeline=self.prefetch_pipeline,
             )

-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.training:
+            if linear_cache_indices.numel() > 0:
+                # Evict rows from cache to SSD
+                self.evict(
+                    rows=self.lxu_cache_evicted_weights,
+                    indices_cpu=self.lxu_cache_evicted_indices,
+                    actions_count_cpu=self.lxu_cache_evicted_count,
+                    stream=self.ssd_eviction_stream,
+                    pre_event=self.ssd_event_get,
+                    # Record completion event after scratch pad eviction
+                    # instead since that happens after L1 eviction
+                    post_event=self.ssd_event_cache_evict,
+                    is_rows_uvm=True,
+                    name="cache",
+                    is_bwd=False,
+                )
+            if (
+                self.backend_type == BackendType.DRAM
+                and weights is not None
+                and linear_cache_indices.numel() > 0
+            ):
+                # Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
+                self._update_feature_score_metadata(
+                    linear_cache_indices=linear_cache_indices,
+                    weights=weights,
+                    d2h_stream=self.ssd_memcpy_stream,
+                    write_stream=self.feature_score_stream,
+                    pre_event_for_write=self.ssd_event_cache_evict,
+                )

         # Generate row addresses (pointing to either L1 or the current
         # iteration's scratch pad)
@@ -2051,24 +2283,32 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 )
             )

-            # Store scratch pad info for post backward eviction
-
-
-
-
-
-
+            # Store scratch pad info for post backward eviction only for training
+            # for eval job, no backward pass, so no need to store this info
+            if self.training:
+                self.ssd_scratch_pad_eviction_data.append(
+                    (
+                        inserted_rows,
+                        post_bwd_evicted_indices_cpu,
+                        actions_count_cpu,
+                        linear_cache_indices.numel() > 0,
+                    )
                 )
-            )

         # Store data for forward
         self.ssd_prefetch_data.append(prefetch_data)

+        # Record an event to mark the completion of prefetch operations
+        # This will be used by direct_write_embedding to ensure it doesn't run concurrently with prefetch
+        current_stream.record_event(self.prefetch_complete_event)
+
     @torch.jit.ignore
     def _generate_vbe_metadata(
         self,
         offsets: Tensor,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]],
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
@@ -2087,6 +2327,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self.pooling_mode,
             self.feature_dims,
             self.current_device,
+            vbe_output,
+            vbe_output_offsets,
         )

     def _increment_iteration(self) -> int:
@@ -2113,14 +2355,30 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
-        batch_size_per_feature_per_rank: Optional[
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
     ) -> Tensor:
         self.clear_cache()
+        if vbe_output is not None or vbe_output_offsets is not None:
+            # CPU is not supported in SSD TBE
+            check_allocated_vbe_output(
+                self.output_dtype,
+                batch_size_per_feature_per_rank,
+                vbe_output,
+                vbe_output_offsets,
+            )
         indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
-            indices,
+            indices,
+            offsets,
+            per_sample_weights,
+            batch_size_per_feature_per_rank,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )

         if len(self.timesteps_prefetched) == 0:
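Note: `forward` can now take a caller-preallocated VBE output, validated by `check_allocated_vbe_output`. A plain-Python sketch of the kind of size/dtype check involved (the real check lives in `split_table_batched_embeddings_ops_training_common` and may differ in detail):

    import torch

    def validate_vbe_output(
        vbe_output: torch.Tensor,
        output_dtype: torch.dtype,
        batch_size_per_feature_per_rank: list[list[int]],
        feature_dims: list[int],
    ) -> None:
        # Total output elements: sum over features of (total batch size x dim).
        total = sum(
            sum(bs) * d for bs, d in zip(batch_size_per_feature_per_rank, feature_dims)
        )
        assert vbe_output.dtype == output_dtype, "preallocated output has wrong dtype"
        assert vbe_output.numel() >= total, f"need {total} elements, got {vbe_output.numel()}"

    # Two features (dims 4 and 8); rank batch sizes [2, 3] and [1, 1] respectively.
    validate_vbe_output(
        torch.empty(4 * (2 + 3) + 8 * (1 + 1), dtype=torch.float32),
        torch.float32,
        [[2, 3], [1, 1]],
        [4, 8],
    )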
@@ -2134,7 +2392,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 context=self.step,
                 stream=self.ssd_eviction_stream,
             ):
-                self._prefetch(indices, offsets, vbe_metadata)
+                self._prefetch(indices, offsets, weights, vbe_metadata)

         assert len(self.ssd_prefetch_data) > 0

@@ -2205,13 +2463,21 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.step += 1

         # Increment the iteration (value is used for certain optimizers)
-        self._increment_iteration()
-
-        if self.optimizer
-
-
+        iter_int = self._increment_iteration()
+
+        if self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
+            momentum2 = invokers.lookup_args_ssd.Momentum(
+                # pyre-ignore[6]
+                dev=self.momentum2_dev,
+                # pyre-ignore[6]
+                host=self.momentum2_host,
+                # pyre-ignore[6]
+                uvm=self.momentum2_uvm,
+                # pyre-ignore[6]
+                offsets=self.momentum2_offsets,
+                # pyre-ignore[6]
+                placements=self.momentum2_placements,
             )
-        return invokers.lookup_sgd_ssd.invoke(common_args, self.optimizer_args)

         momentum1 = invokers.lookup_args_ssd.Momentum(
             dev=self.momentum1_dev,
@@ -2226,10 +2492,44 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             common_args, self.optimizer_args, momentum1
         )

+        elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
+            return invokers.lookup_partial_rowwise_adam_ssd.invoke(
+                common_args,
+                self.optimizer_args,
+                momentum1,
+                # pyre-ignore[61]
+                momentum2,
+                iter_int,
+            )
+
+        elif self.optimizer == OptimType.ADAM:
+            row_counter = invokers.lookup_args_ssd.Momentum(
+                # pyre-fixme[6]
+                dev=self.row_counter_dev,
+                # pyre-fixme[6]
+                host=self.row_counter_host,
+                # pyre-fixme[6]
+                uvm=self.row_counter_uvm,
+                # pyre-fixme[6]
+                offsets=self.row_counter_offsets,
+                # pyre-fixme[6]
+                placements=self.row_counter_placements,
+            )
+
+            return invokers.lookup_adam_ssd.invoke(
+                common_args,
+                self.optimizer_args,
+                momentum1,
+                # pyre-ignore[61]
+                momentum2,
+                iter_int,
+                row_counter=row_counter,
+            )
+
     @torch.jit.ignore
     def _split_optimizer_states_non_kv_zch(
         self,
-    ) ->
+    ) -> list[list[torch.Tensor]]:
         """
         Returns a list of optimizer states (view), split by table.

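Note: PARTIAL_ROWWISE_ADAM keeps the elementwise first moment (`momentum1`) but a single second moment per row (`momentum2`), and both bias corrections consume the iteration counter (`iter_int`) threaded through above. Reference math for one step on one row (standard Adam formulas, not fbgemm's kernel):

    import torch

    def partial_rowwise_adam_step(w, g, m1, m2_row, it, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
        m1.mul_(b1).add_(g, alpha=1 - b1)                    # elementwise first moment
        m2_row.mul_(b2).add_(g.pow(2).mean(), alpha=1 - b2)  # one scalar per row
        m1_hat = m1 / (1 - b1**it)      # bias corrections use the
        m2_hat = m2_row / (1 - b2**it)  # iteration counter
        w -= lr * m1_hat / (m2_hat.sqrt() + eps)
        return w

    w = partial_rowwise_adam_step(
        torch.zeros(4), torch.ones(4), torch.zeros(4), torch.zeros(()), it=1
    )
    print(w)  # each element moves by about -lr on the first step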
@@ -2246,11 +2546,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         """

         # Row count per table
-
+        rows, dims = zip(*self.embedding_specs)
         # Cumulative row counts per table for rowwise states
-        row_count_cumsum:
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
         # Cumulative element counts per table for elementwise states
-        elem_count_cumsum:
+        elem_count_cumsum: list[int] = [0] + list(
             itertools.accumulate([r * d for r, d in self.embedding_specs])
         )

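Note: worked example of the cumulative-count bookkeeping for three (rows, dim) specs; rowwise states are sliced by cumulative rows, elementwise states by cumulative rows times dim:

    import itertools

    embedding_specs = [(100, 8), (50, 4), (200, 16)]
    rows, dims = zip(*embedding_specs)

    row_count_cumsum = [0] + list(itertools.accumulate(rows))
    elem_count_cumsum = [0] + list(itertools.accumulate(r * d for r, d in embedding_specs))

    print(row_count_cumsum)   # [0, 100, 150, 350]
    print(elem_count_cumsum)  # [0, 800, 1000, 4200]
    # Table t's rowwise slice is [row_count_cumsum[t] : row_count_cumsum[t + 1]].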
@@ -2286,6 +2586,17 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             ]
             for t, _ in enumerate(rows)
         ]
+
+        elif self.optimizer == OptimType.ADAM:
+            return [
+                [
+                    _slice(self.momentum1_dev, t, rowwise=False),
+                    # pyre-ignore[6]
+                    _slice(self.momentum2_dev, t, rowwise=False),
+                ]
+                for t, _ in enumerate(rows)
+            ]
+
         else:
             raise NotImplementedError(
                 f"Getting optimizer states is not supported for {self.optimizer}"
@@ -2295,14 +2606,14 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     def _split_optimizer_states_kv_zch_no_offloading(
         self,
         sorted_ids: torch.Tensor,
-    ) ->
+    ) -> list[list[torch.Tensor]]:

         # Row count per table
-
+        rows, dims = zip(*self.embedding_specs)
         # Cumulative row counts per table for rowwise states
-        row_count_cumsum:
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
         # Cumulative element counts per table for elementwise states
-        elem_count_cumsum:
+        elem_count_cumsum: list[int] = [0] + list(
             itertools.accumulate([r * d for r, d in self.embedding_specs])
         )

@@ -2332,7 +2643,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             # based on the sorted_ids compute the table offset for the
             # table, view the slice as 2D tensor of e x d, then fetch the
             # sub-slice by local ids
-
+            #
+            # local_ids is [N, 1], flatten it to N to keep the returned tensor 2D
+            local_ids = (sorted_ids[t] - bucket_id_start * bucket_size).view(-1)
             return (
                 tensor.detach()
                 .cpu()[elem_count_cumsum[t] : elem_count_cumsum[t + 1]]
@@ -2364,6 +2677,16 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             for t, _ in enumerate(rows)
         ]

+        elif self.optimizer == OptimType.ADAM:
+            return [
+                [
+                    _slice("momentum1", self.momentum1_dev, t, rowwise=False),
+                    # pyre-ignore[6]
+                    _slice("momentum2", self.momentum2_dev, t, rowwise=False),
+                ]
+                for t, _ in enumerate(rows)
+            ]
+
         else:
             raise NotImplementedError(
                 f"Getting optimizer states is not supported for {self.optimizer}"
@@ -2375,12 +2698,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         sorted_ids: torch.Tensor,
         no_snapshot: bool = True,
         should_flush: bool = False,
-    ) ->
+    ) -> list[list[torch.Tensor]]:
         dtype = self.weights_precision.as_dtype()
         # Row count per table
-
+        rows_, dims_ = zip(*self.embedding_specs)
         # Cumulative row counts per table for rowwise states
-        row_count_cumsum:
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))

         snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
             no_snapshot=no_snapshot,
@@ -2390,7 +2713,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         # pyre-ignore[53]
         def _fetch_offloaded_optimizer_states(
             t: int,
-        ) ->
+        ) -> list[Tensor]:
             e: int = rows_[t]
             d: int = dims_[t]

@@ -2403,12 +2726,31 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             # Count of rows to fetch
             rows_to_fetch = sorted_ids[t].numel()

+            # Lookup the byte offsets for each optimizer state
+            optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
+                d, self.weights_precision, self.optimizer_state_dtypes
+            )
+            # Find the minimum start of all the start/end pairs - we have to
+            # offset the start/end pairs by this value to get the correct start/end
+            offset_ = min(
+                [start for _, (start, _) in optimizer_state_byte_offsets.items()]
+            )
+            # Update the start/end pairs to be relative to offset_
+            optimizer_state_byte_offsets = dict(
+                (k, (v1 - offset_, v2 - offset_))
+                for k, (v1, v2) in optimizer_state_byte_offsets.items()
+            )
+
             # Since the backend returns cache rows that pack the weights and
             # optimizer states together, reading the whole tensor could cause OOM,
             # so we use the KVTensorWrapper abstraction to query the backend and
             # fetch the data in chunks instead.
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                shape=[
+                shape=[
+                    e,
+                    # Dim is in terms of **weights** dtype
+                    self.optimizer_state_dim,
+                ],
                 dtype=dtype,
                 row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
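Note: the rebasing above replaces the old fixed offset of `d * dtype.itemsize` (removed in the next hunk) with the minimum state start, which stays correct even if the optimizer region does not begin exactly at the end of the weights. Worked example with illustrative byte ranges:

    # (start, end) byte ranges within a full cache row, per optimizer state
    optimizer_state_byte_offsets = {
        "momentum1": (512, 768),
        "momentum2": (768, 772),
    }

    # Buffers fetched via KVTensorWrapper hold only the optimizer-state region,
    # so every range is rebased against the smallest start.
    offset_ = min(start for start, _ in optimizer_state_byte_offsets.values())
    rebased = {k: (s - offset_, e - offset_) for k, (s, e) in optimizer_state_byte_offsets.items()}
    print(rebased)  # {'momentum1': (0, 256), 'momentum2': (256, 260)}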
@@ -2421,19 +2763,6 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
2421
2763
|
else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
|
|
2422
2764
|
)
|
|
2423
2765
|
|
|
2424
|
-
# Lookup the byte offsets for each optimizer state
|
|
2425
|
-
optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
|
|
2426
|
-
d, self.weights_precision, self.optimizer_state_dtypes
|
|
2427
|
-
)
|
|
2428
|
-
# Since we will be working with buffer rows that contain the
|
|
2429
|
-
# optimizer states only, we need to offset the byte offsets by
|
|
2430
|
-
# D * dtype.itemsize
|
|
2431
|
-
offset_ = d * dtype.itemsize
|
|
2432
|
-
optimizer_state_byte_offsets = dict(
|
|
2433
|
-
(k, (v1 - offset_, v2 - offset_))
|
|
2434
|
-
for k, (v1, v2) in optimizer_state_byte_offsets.items()
|
|
2435
|
-
)
|
|
2436
|
-
|
|
2437
2766
|
# Fetch the state size table for the given weights domension
|
|
2438
2767
|
state_size_table = self.optimizer.state_size_table(d)
|
|
2439
2768
|
|
|
@@ -2462,10 +2791,10 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
|
2462
2791
|
)
|
|
2463
2792
|
|
|
2464
2793
|
# Now split up the buffer into N views, N for each optimizer state
|
|
2465
|
-
optimizer_states:
|
|
2794
|
+
optimizer_states: list[Tensor] = []
|
|
2466
2795
|
for state_name in self.optimizer.state_names():
|
|
2467
2796
|
# Extract the offsets
|
|
2468
|
-
|
|
2797
|
+
start, end = optimizer_state_byte_offsets[state_name]
|
|
2469
2798
|
|
|
2470
2799
|
state = optimizer_states_buffer.view(
|
|
2471
2800
|
# Force tensor to byte view
|
|
@@ -2500,13 +2829,150 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             for t, d in enumerate(dims_)
         ]

+    @torch.jit.ignore
+    def _split_optimizer_states_kv_zch_whole_row(
+        self,
+        sorted_ids: torch.Tensor,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ) -> list[list[torch.Tensor]]:
+        dtype = self.weights_precision.as_dtype()
+
+        # Row and dimension counts per table
+        # rows_ is only used here to compute the virtual table offsets
+        rows_, dims_ = zip(*self.embedding_specs)
+
+        # Cumulative row counts per (virtual) table for rowwise states
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))
+
+        snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
+        )
+
+        # pyre-ignore[53]
+        def _fetch_offloaded_optimizer_states(
+            t: int,
+        ) -> list[Tensor]:
+            d: int = dims_[t]
+
+            # pyre-ignore[16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore[16]
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)
+
+            # When the backend returns the whole row, the optimizer state will be
+            # returned as a PMT directly
+            if sorted_ids[t].size(0) == 0 and self.local_weight_counts[t] > 0:
+                logging.info(
+                    f"Before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
+                )
+                sorted_ids[t] = torch.zeros(
+                    (self.local_weight_counts[t], 1),
+                    device=torch.device("cpu"),
+                    dtype=torch.int64,
+                )
+
+            # Look up the byte offsets for each optimizer state relative to the
+            # start of the weights
+            optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
+                d, self.weights_precision, self.optimizer_state_dtypes
+            )
+            # Get the number of elements (of the optimizer state dtype) per state
+            optimizer_state_size_table = self.optimizer.state_size_table(d)
+
+            # Get the metaheader dimension in number of elements of the weights dtype
+            metaheader_dim = (
+                # pyre-ignore[16]
+                self.kv_zch_params.eviction_policy.meta_header_lens[t]
+            )
+
+            # Now split up the buffer into N views, N for each optimizer state
+            optimizer_states: list[PartiallyMaterializedTensor] = []
+            for state_name in self.optimizer.state_names():
+                state_dtype = self.optimizer_state_dtypes.get(
+                    state_name, SparseType.FP32
+                ).as_dtype()
+
+                # Convert the size of the state from elements of the state dtype
+                # to elements of the **weights** dtype
+                state_size = math.ceil(
+                    optimizer_state_size_table[state_name]
+                    * state_dtype.itemsize
+                    / dtype.itemsize
+                )
+
+                # Extract the offsets relative to the start of the weights (in
+                # num bytes)
+                start, _ = optimizer_state_byte_offsets[state_name]
+
+                # Convert the start to number of elements in terms of the
+                # **weights** dtype, then add the metaheader dim offset
+                start = metaheader_dim + start // dtype.itemsize
+
+                shape = [
+                    (
+                        sorted_ids[t].size(0)
+                        if sorted_ids is not None and sorted_ids[t].size(0) > 0
+                        else self.local_weight_counts[t]
+                    ),
+                    (
+                        # Dim is in terms of the **weights** dtype
+                        state_size
+                    ),
+                ]
+
+                # NOTE: We have to view using the **weights** dtype, as
+                # there is currently a bug with KVTensorWrapper where using
+                # a different dtype does not result in the same bytes being
+                # returned, e.g.
+                #
+                #   KVTensorWrapper(dtype=fp32, width_offset=130, shape=[N, 1])
+                #
+                # is NOT the same as
+                #
+                #   KVTensorWrapper(dtype=fp16, width_offset=260, shape=[N, 2]).view(-1).view(fp32)
+                #
+                # TODO: Fix KVTensorWrapper to support viewing data under different dtypes
+                tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                    shape=shape,
+                    dtype=(
+                        # NOTE: Use the *weights* dtype
+                        dtype
+                    ),
+                    row_offset=row_offset,
+                    snapshot_handle=snapshot_handle,
+                    sorted_indices=sorted_ids[t],
+                    width_offset=(
+                        # NOTE: Width offset is in terms of the **weights** dtype
+                        start
+                    ),
+                    # Optimizer is written to the DB with the weights, so skip the write here
+                    read_only=True,
+                )
+                (
+                    tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                    if self.backend_type == BackendType.SSD
+                    else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                )
+
+                optimizer_states.append(
+                    PartiallyMaterializedTensor(tensor_wrapper, True)
+                )
+
+            # pyre-ignore [7]
+            return optimizer_states
+
+        return [_fetch_offloaded_optimizer_states(t) for t, _ in enumerate(dims_)]
+
     @torch.jit.export
     def split_optimizer_states(
         self,
-        sorted_id_tensor: Optional[
+        sorted_id_tensor: Optional[list[torch.Tensor]] = None,
         no_snapshot: bool = True,
         should_flush: bool = False,
-    ) ->
+    ) -> list[list[torch.Tensor]]:
         """
         Returns a list of optimizer states split by table.

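The `state_size` computation in the hunk above converts a per-state element count into units of the weights dtype, rounding up so the view always covers the full state. A worked check of that arithmetic with illustrative values:

```python
import math
import torch

weights_dtype = torch.float16   # itemsize 2
state_dtype = torch.float32     # itemsize 4
state_elems = 3                 # optimizer-state elements per row (hypothetical)

# 3 fp32 elements occupy 12 bytes, i.e. 6 fp16-sized slots
state_size = math.ceil(state_elems * state_dtype.itemsize / weights_dtype.itemsize)
assert state_size == 6
```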
@@ -2555,75 +3021,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             )

         else:
-
-
-                should_flush
+            # Handle the KV ZCH with-optimizer-offloading, backend-whole-row case
+            optimizer_states = self._split_optimizer_states_kv_zch_whole_row(
+                sorted_id_tensor, no_snapshot, should_flush
             )

-            optimizer_states = []
-            table_offset = 0
-
-            for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
-                # pyre-ignore
-                bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
-                # pyre-ignore
-                bucket_size = self.kv_zch_params.bucket_sizes[t]
-                row_offset = table_offset - (bucket_id_start * bucket_size)
-
-                # When the backend returns the whole row, the optimizer will be returned as a PMT directly
-                # pyre-ignore [16]
-                if sorted_id_tensor[t].size(0) == 0 and self.local_weight_counts[t] > 0:
-                    logging.info(
-                        f"before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
-                    )
-                    # pyre-ignore [16]
-                    sorted_id_tensor[t] = torch.zeros(
-                        (self.local_weight_counts[t], 1),
-                        device=torch.device("cpu"),
-                        dtype=torch.int64,
-                    )
-
-                metaheader_dim = (
-                    # pyre-ignore[16]
-                    self.kv_zch_params.eviction_policy.meta_header_lens[t]
-                )
-                tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                    shape=[
-                        (
-                            sorted_id_tensor[t].size(0)
-                            if sorted_id_tensor is not None
-                            and sorted_id_tensor[t].size(0) > 0
-                            else emb_height
-                        ),
-                        self.optimizer_state_dim,
-                    ],
-                    dtype=self.weights_precision.as_dtype(),
-                    row_offset=row_offset,
-                    snapshot_handle=snapshot_handle,
-                    sorted_indices=sorted_id_tensor[t],
-                    width_offset=(
-                        metaheader_dim  # metaheader is already padded, so no need for pad4
-                        + pad4(emb_dim)
-                    ),
-                    read_only=True,  # optimizer written to DB with weights, so skip write here
-                )
-                (
-                    tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
-                    if self.backend_type == BackendType.SSD
-                    else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
-                )
-
-                optimizer_states.append(
-                    [
-                        PartiallyMaterializedTensor(
-                            tensor_wrapper,
-                            True if self.kv_zch_params else False,
-                        )
-                    ]
-                )
-
-                table_offset += emb_height
-
         logging.info(
             f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
             # pyre-ignore[16]
@@ -2635,14 +3037,14 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     @torch.jit.export
     def get_optimizer_state(
         self,
-        sorted_id_tensor: Optional[
+        sorted_id_tensor: Optional[list[torch.Tensor]],
         no_snapshot: bool = True,
         should_flush: bool = False,
-    ) ->
+    ) -> list[dict[str, torch.Tensor]]:
         """
         Returns a list of dictionaries of optimizer states split by table.
         """
-        states_list:
+        states_list: list[list[Tensor]] = self.split_optimizer_states(
             sorted_id_tensor=sorted_id_tensor,
             no_snapshot=no_snapshot,
             should_flush=should_flush,
@@ -2651,13 +3053,13 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         return [dict(zip(state_names, states)) for states in states_list]

     @torch.jit.export
-    def debug_split_embedding_weights(self) ->
+    def debug_split_embedding_weights(self) -> list[torch.Tensor]:
         """
         Returns a list of weights, split by table.

         Testing only, very slow.
         """
-
+        rows, _ = zip(*self.embedding_specs)

         rows_cumsum = [0] + list(itertools.accumulate(rows))
         splits = []
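`rows_cumsum` turns per-table row counts into cumulative offsets, so table `t` owns global rows `[rows_cumsum[t], rows_cumsum[t + 1])`. A quick illustration with made-up counts:

```python
import itertools

rows = (100, 40, 60)  # per-table row counts (illustrative)
rows_cumsum = [0] + list(itertools.accumulate(rows))
assert rows_cumsum == [0, 100, 140, 200]
# Table 1 owns global rows [100, 140)
start, end = rows_cumsum[1], rows_cumsum[2]
```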
@@ -2738,15 +3140,48 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.flush(force=should_flush)
         return snapshot_handle, checkpoint_handle

+    def get_embedding_dim_for_kvt(
+        self, metaheader_dim: int, emb_dim: int, is_loading_checkpoint: bool
+    ) -> int:
+        if self.load_ckpt_without_opt:
+            # For SilverTorch publish, we don't want to load the optimizer into the backend
+            # due to limited CPU memory on the publish host. So we load the whole row into
+            # the state dict while loading the checkpoint in ST publish, then save only the
+            # weight into the backend; after that the backend will only have metaheader + weight.
+            # For the first load, we need to set the dim to metaheader_dim + emb_dim +
+            # optimizer_state_dim, otherwise checkpoint loading will throw a size mismatch error.
+            # After the first load, we only need to get metaheader + weight from the backend
+            # for the state dict, so we can set the dim to metaheader_dim + emb_dim.
+            if is_loading_checkpoint:
+                return (
+                    (
+                        metaheader_dim  # metaheader is already padded
+                        + pad4(emb_dim)
+                        + pad4(self.optimizer_state_dim)
+                    )
+                    if self.backend_return_whole_row
+                    else emb_dim
+                )
+            else:
+                return metaheader_dim + pad4(emb_dim)
+        else:
+            return (
+                (
+                    metaheader_dim  # metaheader is already padded
+                    + pad4(emb_dim)
+                    + pad4(self.optimizer_state_dim)
+                )
+                if self.backend_return_whole_row
+                else emb_dim
+            )
+
     @torch.jit.export
     def split_embedding_weights(
         self,
         no_snapshot: bool = True,
         should_flush: bool = False,
-    ) ->
-    Union[
-        Optional[
-        Optional[
+    ) -> tuple[  # TODO: make this a NamedTuple for readability
+        Union[list[PartiallyMaterializedTensor], list[torch.Tensor]],
+        Optional[list[torch.Tensor]],
+        Optional[list[torch.Tensor]],
+        Optional[list[torch.Tensor]],
     ]:
         """
         This method is intended to be used by the checkpointing engine
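Both whole-row branches pad `emb_dim` and `optimizer_state_dim` with `pad4` before summing. Assuming `pad4` rounds an element count up to the next multiple of 4 (as its use alongside "metaheader is already padded" suggests; this is an assumption, not confirmed by the diff), a minimal sketch:

```python
def pad4(n: int) -> int:
    # Assumed semantics: round up to the next multiple of 4 elements
    return (n + 3) // 4 * 4

metaheader_dim, emb_dim, optimizer_state_dim = 8, 130, 5
whole_row_dim = metaheader_dim + pad4(emb_dim) + pad4(optimizer_state_dim)
assert whole_row_dim == 8 + 132 + 8
```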
@@ -2766,6 +3201,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         2nd arg: input id sorted in bucket id ascending order
         3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
         where for the i-th element, we have i + bucket_id_start = global bucket id
+        4th arg: KV ZCH eviction metadata for each input id, sorted in bucket id ascending order
         """
         snapshot_handle, checkpoint_handle = self._may_create_snapshot_for_state_dict(
             no_snapshot=no_snapshot,
@@ -2782,16 +3218,21 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 self._cached_kvzch_data.cached_weight_tensor_per_table,
                 self._cached_kvzch_data.cached_id_tensor_per_table,
                 self._cached_kvzch_data.cached_bucket_splits,
+                [],  # metadata tensor is not needed for checkpoint loading
             )
         start_time = time.time()
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
+        metadata_splits = [] if self.kv_zch_params else None
+        skip_metadata = False

         table_offset = 0
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            is_loading_checkpoint = False
             bucket_ascending_id_tensor = None
             bucket_t = None
+            metadata_tensor = None
             row_offset = table_offset
             metaheader_dim = 0
             if self.kv_zch_params:
@@ -2823,6 +3264,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                         bucket_size,
                     )
                 )
+                metadata_tensor = self._ssd_db.get_kv_zch_eviction_metadata_by_snapshot(
+                    bucket_ascending_id_tensor + table_offset,
+                    torch.as_tensor(bucket_ascending_id_tensor.size(0)),
+                    snapshot_handle,
+                ).view(-1, 1)
+
                 # 3. convert local id back to global id
                 bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)

@@ -2833,16 +3280,32 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 logging.info(
                     f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
                 )
-
-                (
-
-
-
+                if self.global_id_per_rank[i].numel() != 0:
+                    assert (
+                        self.local_weight_counts[i]
+                        == self.global_id_per_rank[i].numel()
+                    ), f"local weight count and global id per rank size mismatch, with {self.local_weight_counts[i]} and {self.global_id_per_rank[i].numel()}"
+                    bucket_ascending_id_tensor = self.global_id_per_rank[i].to(
+                        device=torch.device("cpu"), dtype=torch.int64
+                    )
+                else:
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                skip_metadata = True
+                is_loading_checkpoint = True
+
                 # self.local_weight_counts[i] = 0 # Reset the count

             # pyre-ignore [16] bucket_sorted_id_splits is not None
             bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
             active_id_cnt_per_bucket_split.append(bucket_t)
+            if skip_metadata:
+                metadata_splits = None
+            else:
+                metadata_splits.append(metadata_tensor)

             # for KV ZCH TBE, sorted_indices is the global id for checkpointing and publishing,
             # but the backend uses the local id during training, so the KVTensorWrapper needs to convert global id to local id
@@ -2857,14 +3320,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                     if bucket_ascending_id_tensor is not None
                     else emb_height
                 ),
-                (
-
-                    metaheader_dim  # metaheader is already padded
-                    + pad4(emb_dim)
-                    + pad4(self.optimizer_state_dim)
-                )
-                if self.backend_return_whole_row
-                else emb_dim
+                self.get_embedding_dim_for_kvt(
+                    metaheader_dim, emb_dim, is_loading_checkpoint
                 ),
             ],
             dtype=dtype,
@@ -2876,6 +3333,11 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 bucket_ascending_id_tensor if self.kv_zch_params else None
             ),
             checkpoint_handle=checkpoint_handle,
+            only_load_weight=(
+                True
+                if self.load_ckpt_without_opt and is_loading_checkpoint
+                else False
+            ),
         )
         (
             tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
@@ -2898,14 +3360,19 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             f"num ids list: {[ids.numel() for ids in bucket_sorted_id_splits]}"
         )

-        return (
+        return (
+            pmt_splits,
+            bucket_sorted_id_splits,
+            active_id_cnt_per_bucket_split,
+            metadata_splits,
+        )

     @torch.jit.ignore
     def _apply_state_dict_w_offloading(self) -> None:
         # Row count per table
-
+        rows, _ = zip(*self.embedding_specs)
         # Cumulative row counts per table for rowwise states
-        row_count_cumsum:
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))

         for t, _ in enumerate(self.embedding_specs):
             # pyre-ignore [16]
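With the fourth element added, callers of `split_embedding_weights` now unpack four values. A hedged usage sketch (variable names are illustrative, and `tbe` is assumed to be an `SSDTableBatchedEmbeddingBags` instance already in scope):

```python
weights, ids, id_counts, metadata = tbe.split_embedding_weights(
    no_snapshot=False, should_flush=True
)
# metadata is None for non-KV-ZCH tables, or when checkpoint loading skips it
```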
@@ -2932,9 +3399,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     @torch.jit.ignore
     def _apply_state_dict_no_offloading(self) -> None:
         # Row count per table
-
+        rows, _ = zip(*self.embedding_specs)
         # Cumulative row counts per table for rowwise states
-        row_count_cumsum:
+        row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))

         def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
             device = dst.device
@@ -2968,7 +3435,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         # Set up the plan for copying optimizer states over
         if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
             mapping = [(opt_states[0], self.momentum1_dev)]
-        elif self.optimizer
+        elif self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
             mapping = [
                 (opt_states[0], self.momentum1_dev),
                 (opt_states[1], self.momentum2_dev),
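The mapping pairs each split-out optimizer state with the module's device buffer: rowwise Adagrad carries one momentum buffer, while (partial rowwise) Adam carries two. A minimal sketch of applying such a mapping; `copy_optimizer_state_` here is a hypothetical stand-in for the copy helper defined just above in the diff:

```python
import torch

def copy_optimizer_state_(dst: torch.Tensor, src: torch.Tensor, indices: torch.Tensor) -> None:
    # Hypothetical helper: scatter rows of src into dst at the given indices
    dst[indices] = src.to(dst.device, dtype=dst.dtype)

momentum1_dev = torch.zeros(10, 4)
momentum2_dev = torch.zeros(10, 1)
opt_states = [torch.ones(3, 4), torch.ones(3, 1)]  # e.g. Adam: momentum1, momentum2
indices = torch.tensor([0, 4, 7])

for src, dst in [(opt_states[0], momentum1_dev), (opt_states[1], momentum2_dev)]:
    copy_optimizer_state_(dst, src, indices)
```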
@@ -3025,7 +3492,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     def streaming_write_weight_and_id_per_table(
         self,
         weight_state: torch.Tensor,
-        opt_states:
+        opt_states: list[torch.Tensor],
         id_tensor: torch.Tensor,
         row_offset: int,
     ) -> None:
@@ -3082,7 +3549,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             state_name = self.optimizer.state_names()[o]

             # Fetch the byte offsets for the optimizer state by its name
-
+            start, end = optimizer_state_byte_offsets[state_name]

             # Assume that the opt_state passed in already has dtype matching
             # self.optimizer_state_dtypes[state_name]
@@ -3119,7 +3586,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.load_state_dict = True

         dtype = self.weights_precision.as_dtype()
-
+        _, dims = zip(*self.embedding_specs)

         self._cached_kvzch_data = KVZCHCachedData([], [], [], [])

@@ -3192,6 +3659,10 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     def flush(self, force: bool = False) -> None:
         # allow force flush from split_embedding_weights to cover edge cases, e.g. checkpointing
         # after training 0 batches
+        if not self.training:
+            # in eval mode, we should not write anything to the embedding backend
+            return
+
         if self.step == self.last_flush_step and not force:
             logging.info(
                 f"SSD TBE has been flushed at {self.last_flush_step=} already for tbe:{self.tbe_unique_id}"
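Because `flush` now returns immediately in eval mode, checkpoint-time callers must keep the module in training mode if they expect a write-back. A short usage sketch (`tbe` is assumed in scope):

```python
tbe.eval()
tbe.flush(force=True)   # no-op: eval mode writes nothing to the backend
tbe.train()
tbe.flush(force=True)   # flushes as before
```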
@@ -3237,18 +3708,20 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         indices: Tensor,
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
-        batch_size_per_feature_per_rank: Optional[
-
+        batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
+    ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
         """
         Prepare TBE inputs
         """
         # Generate VBE metadata
         vbe_metadata = self._generate_vbe_metadata(
-            offsets, batch_size_per_feature_per_rank
+            offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
         )

         # Force casting indices and offsets to long
-
+        indices, offsets = indices.long(), offsets.long()

         # Force casting per_sample_weights to float
         if per_sample_weights is not None:
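The forced casts keep downstream kernels from seeing int32 indices or offsets. A small illustration of the same normalization:

```python
import torch

indices = torch.tensor([1, 2, 3], dtype=torch.int32)
offsets = torch.tensor([0, 2, 3], dtype=torch.int32)
indices, offsets = indices.long(), offsets.long()  # int64, as the TBE kernels expect
assert indices.dtype == torch.int64 and offsets.dtype == torch.int64
```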
@@ -3287,6 +3760,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self._report_l2_cache_perf_stats()
         if self.backend_type == BackendType.DRAM:
             self._report_dram_kv_perf_stats()
+        if self.kv_zch_params and self.kv_zch_params.eviction_policy:
+            self._report_eviction_stats()

     @torch.jit.ignore
     def _report_ssd_l1_cache_stats(self) -> None:
@@ -3303,7 +3778,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         ssd_cache_stats = self.ssd_cache_stats.tolist()
         if len(self.last_reported_ssd_stats) == 0:
             self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
-        ssd_cache_stats_delta:
+        ssd_cache_stats_delta: list[float] = [0.0] * len(ssd_cache_stats)
         for i in range(len(ssd_cache_stats)):
             ssd_cache_stats_delta[i] = (
                 ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
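`ssd_cache_stats` is cumulative, so per-interval values are reported as deltas against the last report. A sketch of the same bookkeeping:

```python
stats = [120.0, 30.0]            # cumulative counters
last_reported = [100.0, 10.0]    # snapshot taken at the previous report
delta = [s - l for s, l in zip(stats, last_reported)]
assert delta == [20.0, 20.0]
last_reported = stats            # roll the snapshot forward
```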
@@ -3553,6 +4028,98 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             time_unit="us",
         )

+    @torch.jit.ignore
+    def _report_eviction_stats(self) -> None:
+        if self.stats_reporter is None:
+            return
+
+        stats_reporter: TBEStatsReporter = self.stats_reporter
+        if not stats_reporter.should_report(self.step):
+            return
+
+        # skip metrics reporting when eviction is disabled
+        if self.kv_zch_params.eviction_policy.eviction_trigger_mode == 0:
+            return
+
+        T = len(set(self.feature_table_map))
+        evicted_counts = torch.zeros(T, dtype=torch.int64)
+        processed_counts = torch.zeros(T, dtype=torch.int64)
+        eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float)
+        full_duration_ms = torch.tensor(0, dtype=torch.int64)
+        exec_duration_ms = torch.tensor(0, dtype=torch.int64)
+        self.ssd_db.get_feature_evict_metric(
+            evicted_counts,
+            processed_counts,
+            eviction_threshold_with_dry_run,
+            full_duration_ms,
+            exec_duration_ms,
+        )
+
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name=self.eviction_sum_evicted_counts_stats_name,
+            data_bytes=int(evicted_counts.sum().item()),
+            enable_tb_metrics=True,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name=self.eviction_sum_processed_counts_stats_name,
+            data_bytes=int(processed_counts.sum().item()),
+            enable_tb_metrics=True,
+        )
+        if processed_counts.sum().item() != 0:
+            stats_reporter.report_data_amount(
+                iteration_step=self.step,
+                event_name=self.eviction_evict_rate_stats_name,
+                data_bytes=int(
+                    evicted_counts.sum().item() * 100 / processed_counts.sum().item()
+                ),
+                enable_tb_metrics=True,
+            )
+        for t in self.feature_table_map:
+            stats_reporter.report_data_amount(
+                iteration_step=self.step,
+                event_name=f"eviction.feature_table.{t}.evicted_counts",
+                data_bytes=int(evicted_counts[t].item()),
+                enable_tb_metrics=True,
+            )
+            stats_reporter.report_data_amount(
+                iteration_step=self.step,
+                event_name=f"eviction.feature_table.{t}.processed_counts",
+                data_bytes=int(processed_counts[t].item()),
+                enable_tb_metrics=True,
+            )
+            if processed_counts[t].item() != 0:
+                stats_reporter.report_data_amount(
+                    iteration_step=self.step,
+                    event_name=f"eviction.feature_table.{t}.evict_rate",
+                    data_bytes=int(
+                        evicted_counts[t].item() * 100 / processed_counts[t].item()
+                    ),
+                    enable_tb_metrics=True,
+                )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="eviction.feature_table.full_duration_ms",
+            duration_ms=full_duration_ms.item(),
+            time_unit="ms",
+            enable_tb_metrics=True,
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="eviction.feature_table.exec_duration_ms",
+            duration_ms=exec_duration_ms.item(),
+            time_unit="ms",
+            enable_tb_metrics=True,
+        )
+        if full_duration_ms.item() != 0:
+            stats_reporter.report_data_amount(
+                iteration_step=self.step,
+                event_name="eviction.feature_table.exec_div_full_duration_rate",
+                data_bytes=int(exec_duration_ms.item() * 100 / full_duration_ms.item()),
+                enable_tb_metrics=True,
+            )
+
     @torch.jit.ignore
     def _report_dram_kv_perf_stats(self) -> None:
         """
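The evict-rate metrics divide evicted by processed counts, guard against division by zero, and scale to a percentage because the reporter takes integer `data_bytes`. A compact check of that arithmetic:

```python
import torch

evicted = torch.tensor([5, 0, 20])
processed = torch.tensor([50, 0, 40])
if processed.sum().item() != 0:
    rate_pct = int(evicted.sum().item() * 100 / processed.sum().item())
    assert rate_pct == 27  # 25/90 is about 27.8%, truncated to int
```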
@@ -3570,8 +4137,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self.step, stats_reporter.report_interval  # pyre-ignore
         )

-        if len(dram_kv_perf_stats) !=
-            logging.error("dram cache perf stats should have
+        if len(dram_kv_perf_stats) != 36:
+            logging.error("dram cache perf stats should have 36 elements")
             return

         dram_read_duration = dram_kv_perf_stats[0]
@@ -3599,52 +4166,75 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

         dram_kv_allocated_bytes = dram_kv_perf_stats[20]
         dram_kv_actual_used_chunk_bytes = dram_kv_perf_stats[21]
+        dram_kv_num_rows = dram_kv_perf_stats[22]
+        dram_kv_read_counts = dram_kv_perf_stats[23]
+        dram_metadata_write_sharding_total_duration = dram_kv_perf_stats[24]
+        dram_metadata_write_total_duration = dram_kv_perf_stats[25]
+        dram_metadata_write_allocate_avg_duration = dram_kv_perf_stats[26]
+        dram_metadata_write_lookup_cache_avg_duration = dram_kv_perf_stats[27]
+        dram_metadata_write_acquire_lock_avg_duration = dram_kv_perf_stats[28]
+        dram_metadata_write_cache_miss_avg_count = dram_kv_perf_stats[29]
+
+        dram_read_metadata_total_duration = dram_kv_perf_stats[30]
+        dram_read_metadata_sharding_total_duration = dram_kv_perf_stats[31]
+        dram_read_metadata_cache_hit_copy_avg_duration = dram_kv_perf_stats[32]
+        dram_read_metadata_lookup_cache_total_avg_duration = dram_kv_perf_stats[33]
+        dram_read_metadata_acquire_lock_avg_duration = dram_kv_perf_stats[34]
+        dram_read_read_metadata_load_size = dram_kv_perf_stats[35]

         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.get.dram_read_duration_us",
             duration_ms=dram_read_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.get.dram_read_sharding_duration_us",
             duration_ms=dram_read_sharding_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.get.dram_read_cache_hit_copy_duration_us",
             duration_ms=dram_read_cache_hit_copy_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.get.dram_read_fill_row_storage_duration_us",
             duration_ms=dram_read_fill_row_storage_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.get.dram_read_lookup_cache_duration_us",
             duration_ms=dram_read_lookup_cache_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.get.dram_read_acquire_lock_duration_us",
             duration_ms=dram_read_acquire_lock_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="dram_kv.perf.get.dram_read_missing_load",
+            enable_tb_metrics=True,
             data_bytes=dram_read_missing_load,
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_write_sharing_duration_us",
             duration_ms=dram_write_sharing_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )

@@ -3652,83 +4242,192 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_duration_us",
             duration_ms=dram_fwd_l1_eviction_write_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_allocate_duration_us",
             duration_ms=dram_fwd_l1_eviction_write_allocate_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_cache_copy_duration_us",
             duration_ms=dram_fwd_l1_eviction_write_cache_copy_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_lookup_cache_duration_us",
             duration_ms=dram_fwd_l1_eviction_write_lookup_cache_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_acquire_lock_duration_us",
             duration_ms=dram_fwd_l1_eviction_write_acquire_lock_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load",
             data_bytes=dram_fwd_l1_eviction_write_missing_load,
+            enable_tb_metrics=True,
         )

         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_duration_us",
             duration_ms=dram_bwd_l1_cnflct_miss_write_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_allocate_duration_us",
             duration_ms=dram_bwd_l1_cnflct_miss_write_allocate_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_cache_copy_duration_us",
             duration_ms=dram_bwd_l1_cnflct_miss_write_cache_copy_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_lookup_cache_duration_us",
             duration_ms=dram_bwd_l1_cnflct_miss_write_lookup_cache_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_acquire_lock_duration_us",
             duration_ms=dram_bwd_l1_cnflct_miss_write_acquire_lock_duration,
+            enable_tb_metrics=True,
             time_unit="us",
         )
         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load",
             data_bytes=dram_bwd_l1_cnflct_miss_write_missing_load,
+            enable_tb_metrics=True,
+        )
+
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.get.dram_kv_read_counts",
+            data_bytes=dram_kv_read_counts,
+            enable_tb_metrics=True,
         )

         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name=self.dram_kv_allocated_bytes_stats_name,
             data_bytes=dram_kv_allocated_bytes,
+            enable_tb_metrics=True,
         )
         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name=self.dram_kv_actual_used_chunk_bytes_stats_name,
             data_bytes=dram_kv_actual_used_chunk_bytes,
+            enable_tb_metrics=True,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name=self.dram_kv_mem_num_rows_stats_name,
+            data_bytes=dram_kv_num_rows,
+            enable_tb_metrics=True,
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.set.dram_eviction_score_write_sharding_total_duration_us",
+            duration_ms=dram_metadata_write_sharding_total_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.set.dram_eviction_score_write_total_duration_us",
+            duration_ms=dram_metadata_write_total_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.set.dram_eviction_score_write_allocate_avg_duration_us",
+            duration_ms=dram_metadata_write_allocate_avg_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.set.dram_eviction_score_write_lookup_cache_avg_duration_us",
+            duration_ms=dram_metadata_write_lookup_cache_avg_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.set.dram_eviction_score_write_acquire_lock_avg_duration_us",
+            duration_ms=dram_metadata_write_acquire_lock_avg_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.set.dram_eviction_score_write_cache_miss_avg_count",
+            data_bytes=dram_metadata_write_cache_miss_avg_count,
+            enable_tb_metrics=True,
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.get.dram_eviction_score_read_total_duration_us",
+            duration_ms=dram_read_metadata_total_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.get.dram_eviction_score_read_sharding_total_duration_us",
+            duration_ms=dram_read_metadata_sharding_total_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.get.dram_eviction_score_read_cache_hit_copy_avg_duration_us",
+            duration_ms=dram_read_metadata_cache_hit_copy_avg_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.get.dram_eviction_score_read_lookup_cache_total_avg_duration_us",
+            duration_ms=dram_read_metadata_lookup_cache_total_avg_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.get.dram_eviction_score_read_acquire_lock_avg_duration_us",
+            duration_ms=dram_read_metadata_acquire_lock_avg_duration,
+            enable_tb_metrics=True,
+            time_unit="us",
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="dram_kv.perf.get.dram_eviction_score_read_load_size",
+            data_bytes=dram_read_read_metadata_load_size,
+            enable_tb_metrics=True,
         )

     def _recording_to_timer(
@@ -3749,7 +4448,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

     def fetch_from_l1_sp_w_row_ids(
         self, row_ids: torch.Tensor, only_get_optimizer_states: bool = False
-    ) ->
+    ) -> tuple[list[torch.Tensor], torch.Tensor]:
         """
         Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
         @return: updated_weights/optimizer_states, mask of which rows are filled
@@ -3762,36 +4461,38 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         # NOTE: Remove this once there is support for fetching multiple
         # optimizer states in fetch_from_l1_sp_w_row_ids
         if only_get_optimizer_states and self.optimizer not in [
-            OptimType.EXACT_ROWWISE_ADAGRAD
+            OptimType.EXACT_ROWWISE_ADAGRAD,
+            OptimType.PARTIAL_ROWWISE_ADAM,
         ]:
             raise RuntimeError(
                 f"Fetching optimizer states using fetch_from_l1_sp_w_row_ids() is not yet supported for {self.optimizer}"
             )

-
-
-
-
-
-
-
-
-
-
-
+        def split_results_by_opt_states(
+            updated_weights: torch.Tensor, cache_location_mask: torch.Tensor
+        ) -> tuple[list[torch.Tensor], torch.Tensor]:
+            if not only_get_optimizer_states:
+                return [updated_weights], cache_location_mask
+            # TODO: support the mixed-dimension case;
+            # currently this only supports tables with the same max_D dimension
+            opt_to_dim = self.optimizer.byte_offsets_along_row(
+                self.max_D, self.weights_precision, self.optimizer_state_dtypes
+            )
+            updated_opt_states = []
+            for opt_name, dim in opt_to_dim.items():
+                opt_dtype = self.optimizer._extract_dtype(
+                    self.optimizer_state_dtypes, opt_name
                 )
-
-
-
+                updated_opt_states.append(
+                    updated_weights.view(dtype=torch.uint8)[:, dim[0] : dim[1]].view(
+                        dtype=opt_dtype
+                    )
                 )
+            return updated_opt_states, cache_location_mask

-
-
-
-        row_dim = self.cache_row_dim
-        result_dim = row_dim
-        result_dtype = weights_dtype
-
+        with torch.no_grad():
+            weights_dtype = self.weights_precision.as_dtype()
+            step = self.step
             with record_function(f"## fetch_from_l1_{step}_{self.tbe_unique_id} ##"):
                 lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
                     row_ids,
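`split_results_by_opt_states` above reinterprets each fetched row as raw bytes, slices a state's byte span, and views the slice under the state's dtype. A minimal standalone sketch of that reinterpretation (sizes are illustrative):

```python
import torch

rows = torch.zeros(3, 4, dtype=torch.float16)   # fetched rows in the weights dtype
# Reinterpret each 8-byte row as bytes, slice a state's byte span, view as fp32
opt_bytes = rows.view(dtype=torch.uint8)[:, 4:8]
opt_state = opt_bytes.view(dtype=torch.float32)
assert opt_state.shape == (3, 1)
```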
@@ -3800,17 +4501,23 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                 )
                 updated_weights = torch.empty(
                     row_ids.numel(),
-
+                    self.cache_row_dim,
                     device=self.current_device,
-                    dtype=
+                    dtype=weights_dtype,
                 )

                 # D2D copy cache
                 cache_location_mask = lxu_cache_locations >= 0
-
-
-
-
+                torch.ops.fbgemm.masked_index_select(
+                    updated_weights,
+                    lxu_cache_locations,
+                    self.lxu_cache_weights,
+                    torch.tensor(
+                        [row_ids.numel()],
+                        device=self.current_device,
+                        dtype=torch.int32,
+                    ),
+                )

             with record_function(f"## fetch_from_sp_{step}_{self.tbe_unique_id} ##"):
                 if len(self.ssd_scratch_pad_eviction_data) > 0:
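As used here, `torch.ops.fbgemm.masked_index_select` fills `updated_weights[i]` from the cache row at `lxu_cache_locations[i]` wherever the location is non-negative. A behaviorally similar sketch in plain PyTorch (ignoring the count argument the fused op takes; this is an equivalence sketch, not the library's implementation):

```python
import torch

cache = torch.arange(12.0).reshape(4, 3)   # stand-in lxu_cache_weights
locations = torch.tensor([2, -1, 0])       # -1 means "not in L1"
out = torch.empty(3, 3)

mask = locations >= 0
out[mask] = cache[locations[mask]]         # copy only the cached rows
```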
@@ -3821,7 +4528,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                     actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
                     if actions_count_gpu.item() == 0:
                         # no action to take
-                        return (
+                        return split_results_by_opt_states(
+                            updated_weights, cache_location_mask
+                        )

                     sp_idx = sp_idx[:actions_count_gpu]

@@ -3872,16 +4581,23 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
                     )

                     # D2D copy SP
-
-
-
+                    torch.ops.fbgemm.masked_index_select(
+                        updated_weights,
+                        sp_locations_in_updated_weights,
+                        sp,
+                        torch.tensor(
+                            [row_ids.numel()],
+                            device=self.current_device,
+                            dtype=torch.int32,
+                        ),
+                    )
                     # cache_location_mask is the mask of rows in L1
                     # exact_match_mask is the mask of rows in SP
                     cache_location_mask = torch.logical_or(
                         cache_location_mask, exact_match_mask
                     )

-            return (updated_weights, cache_location_mask)
+            return split_results_by_opt_states(updated_weights, cache_location_mask)

     def register_backward_hook_before_eviction(
         self, backward_hook: Callable[[torch.Tensor], None]
@@ -3901,3 +4617,312 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self.placeholder_autograd_tensor.register_hook(backward_hook)
         for hook in hooks:
             self.placeholder_autograd_tensor.register_hook(hook)
+
+    def set_local_weight_counts_for_table(
+        self, table_idx: int, weight_count: int
+    ) -> None:
+        self.local_weight_counts[table_idx] = weight_count
+
+    def set_global_id_per_rank_for_table(
+        self, table_idx: int, global_id: torch.Tensor
+    ) -> None:
+        self.global_id_per_rank[table_idx] = global_id
+
+    def direct_write_embedding(
+        self,
+        indices: torch.Tensor,
+        offsets: torch.Tensor,
+        weights: torch.Tensor,
+    ) -> None:
+        """
+        Directly write the weights to L1, SP, and the backend without relying on autograd for the embedding cache.
+        Please refer to the design doc for more details: https://docs.google.com/document/d/1TJHKvO1m3-5tYAKZGhacXnGk7iCNAzz7wQlrFbX_LDI/edit?tab=t.0
+        """
+        assert (
+            self._embedding_cache_mode
+        ), "Must be in embedding_cache_mode to support direct_write_embedding method."
+
+        B_offsets = None
+        max_B = -1
+
+        with torch.no_grad():
+            # Wait for any ongoing prefetch operations to complete before starting direct_write
+            current_stream = torch.cuda.current_stream()
+            current_stream.wait_event(self.prefetch_complete_event)
+
+            # Create local step events for internal sequential execution
+            weights_dtype = self.weights_precision.as_dtype()
+            assert (
+                weights_dtype == weights.dtype
+            ), f"Expected input weight dtype to match the embedding table dtype {weights_dtype}, but got {weights.dtype}"
+
+            # Pad the weights to match self.max_D width if necessary
+            if weights.size(1) < self.cache_row_dim:
+                weights = torch.nn.functional.pad(
+                    weights, (0, self.cache_row_dim - weights.size(1))
+                )
+
+            step = self.step
+
+            # step 0: run the backward hook for prefetch, if the prefetch pipeline is enabled, before writing to L1 and SP
+            if self.prefetch_pipeline:
+                self._update_cache_counter_and_pointers(nn.Module(), torch.empty(0))
+
+            # step 1: look up and write to the L1 cache
+            with record_function(
+                f"## direct_write_to_l1_{step}_{self.tbe_unique_id} ##"
+            ):
+                if self.gather_ssd_cache_stats:
+                    self.local_ssd_cache_stats.zero_()
+
+                # Linearize indices
+                linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
+                    self.hash_size_cumsum,
+                    indices,
+                    offsets,
+                    B_offsets,
+                    max_B,
+                )
+
+                lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
+                    linear_cache_indices,
+                    self.lxu_cache_state,
+                    self.total_hash_size,
+                )
+                cache_location_mask = lxu_cache_locations >= 0
+
+                # Get the cache locations for the row_ids that are already in the cache
+                cache_locations = lxu_cache_locations[cache_location_mask]
+
+                # Get the corresponding input weights for these row_ids
+                cache_weights = weights[cache_location_mask]
+
+                # Update the cache with these input weights
+                if cache_locations.numel() > 0:
+                    self.lxu_cache_weights.index_put_(
+                        (cache_locations,), cache_weights, accumulate=False
+                    )
+
+                # Record completion of step 1
+                current_stream.record_event(self.direct_write_l1_complete_event)
+
+            # step 2: pop the current scratch pad and write to the next batch's scratch pad if it exists
+            # Wait for step 1 to complete
+            with record_function(
+                f"## direct_write_to_sp_{step}_{self.tbe_unique_id} ##"
+            ):
+                if len(self.ssd_scratch_pad_eviction_data) > 0:
+                    self.ssd_scratch_pad_eviction_data.pop(0)
+                    if len(self.ssd_scratch_pad_eviction_data) > 0:
+                        # Wait for any pending backend reads to the next scratch pad
+                        # to complete before we write to it. Otherwise, stale backend data
+                        # will overwrite our direct_write updates.
+                        # The ssd_event_get marks completion of backend fetch operations.
+                        current_stream.wait_event(self.ssd_event_get)
+
+                        # if a scratch pad exists, write to the next batch's scratch pad
+                        sp = self.ssd_scratch_pad_eviction_data[0][0]
+                        sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
+                            self.current_device
+                        )
+                        actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
+                        if actions_count_gpu.item() != 0:
+                            # when actions_count_gpu is zero, there is no need to write to SP
+                            sp_idx = sp_idx[:actions_count_gpu]
+
+                            # -1 in lxu_cache_locations means the row is not in the L1 cache but in SP;
+                            # fill the row_ids found in L1 with -2; values > 0 mean in SP or backend
+                            # e.g. updated_indices_in_sp = [1, 100, 1, 2, -2, 3, 4, 5, 10]
+                            updated_indices_in_sp = linear_cache_indices.masked_fill(
+                                lxu_cache_locations != -1, -2
+                            )
+                            # sort sp_idx for binary search (it should already be sorted);
+                            # sp_idx_inverse_indices are the pre-sort indices, which match the locations in SP
+                            # e.g. sp_idx = [4, 2, 1, 3, 10]
+                            # e.g. sorted_sp_idx = [1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
+                            sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
+                            # search the row ids against the SP indexes to find the location of the rows in SP
+                            # e.g. updated_indices_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
+                            # e.g. 5 is OOB
+                            updated_indices_in_sp_idx = torch.searchsorted(
+                                sorted_sp_idx, updated_indices_in_sp
+                            )
+                            # ids not found in SP will be out of bounds
+                            oob_sp_idx = updated_indices_in_sp_idx >= sp_idx.numel()
+                            # clamp the OOB items back in bounds
+                            # e.g. updated_indices_in_sp_idx = [0, 0, 0, 1, 0, 2, 3, 4, 4]
+                            updated_indices_in_sp_idx[oob_sp_idx] = 0
+
+                            # torch.searchsorted does not guarantee an exact match;
+                            # we only take exactly matched rows, where the id is found in SP.
+                            # e.g. 5 in updated_indices_in_sp is not in sp_idx, but still got index 4 in updated_indices_in_sp_idx
+                            # e.g. sorted_sp_idx[updated_indices_in_sp_idx] = [1, 1, 1, 2, 1, 3, 4, 10, 10]
+                            # e.g. exact_match_mask = [True, False, True, True, False, True, True, False, True]
+                            exact_match_mask = (
+                                sorted_sp_idx[updated_indices_in_sp_idx]
+                                == updated_indices_in_sp
+                            )
+                            # Get the locations of the row ids found in SP.
+                            # e.g. sp_locations_found = [2, 2, 1, 3, 0, 4]
+                            sp_locations_found = sp_idx_inverse_indices[
+                                updated_indices_in_sp_idx[exact_match_mask]
+                            ]
+                            # Get the corresponding weights for the matched indices
+                            matched_weights = weights[exact_match_mask]
+
+                            # Write the weights to the scratch pad tensor at the found locations
+                            if sp_locations_found.numel() > 0:
+                                sp.index_put_(
+                                    (sp_locations_found,),
+                                    matched_weights,
+                                    accumulate=False,
+                                )
+                current_stream.record_event(self.direct_write_sp_complete_event)
+
+            # step 3: write L1-cache-missing rows to the backend
+            # Wait for step 2 to complete
+            with record_function(
+                f"## direct_write_to_backend_{step}_{self.tbe_unique_id} ##"
+            ):
+                # Use the existing ssd_eviction_stream for all backend write operations
+                # This stream is already created with low priority during initialization
+                with torch.cuda.stream(self.ssd_eviction_stream):
+                    # Create a mask for indices not in the L1 cache
+                    non_cache_mask = ~cache_location_mask
+
+                    # Calculate the count of valid indices (those not in the L1 cache)
+                    valid_count = non_cache_mask.sum().to(torch.int64).cpu()
+
+                    if valid_count.item() > 0:
+                        # Extract only the indices and weights that are not in the L1 cache
+                        non_cache_indices = linear_cache_indices[non_cache_mask]
+                        non_cache_weights = weights[non_cache_mask]
+
+                        # Move tensors to CPU for set_cuda
+                        cpu_indices = non_cache_indices.cpu()
+                        cpu_weights = non_cache_weights.cpu()
+
+                        # Write to the backend - only sending the non-cache indices and weights
+                        self.record_function_via_dummy_profile(
+                            f"## ssd_write_{step}_set_cuda_{self.tbe_unique_id} ##",
+                            self.ssd_db.set_cuda,
+                            cpu_indices,
+                            cpu_weights,
+                            valid_count,
+                            self.timestep,
+                            is_bwd=False,
+                        )
+
+        # Return control to the main stream without waiting for the backend operation to complete
+
+    def get_free_cpu_memory_gb(self) -> float:
+        def _get_mem_available() -> float:
+            if sys.platform.startswith("linux"):
+                info = {}
+                with open("/proc/meminfo") as f:
+                    for line in f:
+                        p = line.split()
+                        info[p[0].strip(":").lower()] = int(p[1]) * 1024
+                if "memavailable" in info:
+                    # Linux >= 3.14
+                    return info["memavailable"]
+                else:
+                    return info["memfree"] + info["cached"]
+            else:
+                raise RuntimeError(
+                    "Unsupported platform for free-memory eviction; please use the ID-count eviction trigger mode"
+                )
+
+        mem = _get_mem_available()
+        return mem / (1024**3)
+
+    @classmethod
+    def trigger_evict_in_all_tbes(cls) -> None:
+        for tbe in cls._all_tbe_instances:
+            tbe.ssd_db.trigger_feature_evict()
+
+    @classmethod
+    def tbe_has_ongoing_eviction(cls) -> bool:
+        for tbe in cls._all_tbe_instances:
+            if tbe.ssd_db.is_evicting():
+                return True
+        return False
+
+    def set_free_mem_eviction_trigger_config(
+        self, eviction_policy: EvictionPolicy
+    ) -> None:
+        self.enable_free_mem_trigger_eviction = True
+        self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode
+        assert (
+            eviction_policy.eviction_free_mem_check_interval_batch is not None
+        ), "eviction_free_mem_check_interval_batch is unexpectedly None for the free_mem eviction trigger mode"
+        self.eviction_free_mem_check_interval_batch: int = (
+            eviction_policy.eviction_free_mem_check_interval_batch
+        )
+        assert (
+            eviction_policy.eviction_free_mem_threshold_gb is not None
+        ), "eviction_policy.eviction_free_mem_threshold_gb is unexpectedly None for the free_mem eviction trigger mode"
+        self.eviction_free_mem_threshold_gb: int = (
+            eviction_policy.eviction_free_mem_threshold_gb
+        )
+        logging.info(
+            f"[FREE_MEM Eviction] eviction config, trigger mode: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}"
+        )
+
+    def may_trigger_eviction(self) -> None:
+        def is_first_tbe() -> bool:
+            first = SSDTableBatchedEmbeddingBags._first_instance_ref
+            return first is not None and first() is self
+
+        # We assume that the eviction time is less than the free-mem check interval time,
+        # so every time we reach this check, all evictions in all TBEs should be finished.
+        # We only need to check the first TBE because all TBEs share the same free mem;
+        # once the first TBE detects the need to trigger eviction, it will call the trigger func
+        # in all TBEs from _all_tbe_instances
+        if (
+            self.enable_free_mem_trigger_eviction
+            and self.step % self.eviction_free_mem_check_interval_batch == 0
+            and self.training
+            and is_first_tbe()
+        ):
+            if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction():
+                SSDTableBatchedEmbeddingBags._eviction_triggered = False
+
+            free_cpu_mem_gb = self.get_free_cpu_memory_gb()
+            local_evict_trigger = int(
+                free_cpu_mem_gb < self.eviction_free_mem_threshold_gb
+            )
+            tensor_flag = torch.tensor(
+                local_evict_trigger,
+                device=self.current_device,
+                dtype=torch.int,
+            )
+            world_size = dist.get_world_size(self._pg)
+            if world_size > 1:
+                dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg)
+                global_evict_trigger = tensor_flag.item()
+            else:
+                global_evict_trigger = local_evict_trigger
+            if (
+                global_evict_trigger >= 1
+                and SSDTableBatchedEmbeddingBags._eviction_triggered
+            ):
+                logging.warning(
+                    f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true"
+                )
+            if (
+                global_evict_trigger >= 1
+                and not SSDTableBatchedEmbeddingBags._eviction_triggered
+            ):
+                SSDTableBatchedEmbeddingBags._eviction_triggered = True
+                SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes()
+                logging.info(
+                    f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction"
+                )
+
+    def reset_inference_mode(self) -> None:
+        """
+        Reset the inference mode
+        """
+        self.eval()