fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic; see the package's registry page for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,4908 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
# pyre-ignore-all-errors[13,56]
|
|
10
|
+
|
|
11
|
+
import contextlib
|
|
12
|
+
import functools
|
|
13
|
+
import itertools
|
|
14
|
+
import logging
|
|
15
|
+
import math
|
|
16
|
+
import os
|
|
17
|
+
import threading
|
|
18
|
+
import time
|
|
19
|
+
from functools import cached_property
|
|
20
|
+
from math import floor, log2
|
|
21
|
+
from typing import Any, Callable, ClassVar, Optional, Union
|
|
22
|
+
import torch # usort:skip
|
|
23
|
+
import weakref
|
|
24
|
+
|
|
25
|
+
# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
|
|
26
|
+
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
|
|
27
|
+
from fbgemm_gpu.runtime_monitor import (
|
|
28
|
+
AsyncSeriesTimer,
|
|
29
|
+
TBEStatsReporter,
|
|
30
|
+
TBEStatsReporterConfig,
|
|
31
|
+
)
|
|
32
|
+
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
|
|
33
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
|
34
|
+
BackendType,
|
|
35
|
+
BoundsCheckMode,
|
|
36
|
+
CacheAlgorithm,
|
|
37
|
+
EmbeddingLocation,
|
|
38
|
+
EvictionPolicy,
|
|
39
|
+
get_bounds_check_version_for_platform,
|
|
40
|
+
KVZCHParams,
|
|
41
|
+
PoolingMode,
|
|
42
|
+
SplitState,
|
|
43
|
+
)
|
|
44
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
|
|
45
|
+
apply_split_helper,
|
|
46
|
+
CounterBasedRegularizationDefinition,
|
|
47
|
+
CowClipDefinition,
|
|
48
|
+
RESParams,
|
|
49
|
+
UVMCacheStatsIndex,
|
|
50
|
+
WeightDecayMode,
|
|
51
|
+
)
|
|
52
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
|
|
53
|
+
generate_vbe_metadata,
|
|
54
|
+
is_torchdynamo_compiling,
|
|
55
|
+
)
|
|
56
|
+
from torch import distributed as dist, nn, Tensor # usort:skip
|
|
57
|
+
import sys
|
|
58
|
+
from dataclasses import dataclass
|
|
59
|
+
|
|
60
|
+
from torch.autograd.profiler import record_function
|
|
61
|
+
|
|
62
|
+
from ..cache import get_unique_indices_v2
|
|
63
|
+
from .common import ASSOC, pad4, tensor_pad4
|
|
64
|
+
from .utils.partially_materialized_tensor import PartiallyMaterializedTensor
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass
class IterData:
    """Per-iteration record of the tensors produced/consumed by the SSD TBE.

    This is a plain data holder: it only declares fields and defaults. The
    code that builds and reads these records is not visible in this part of
    the file, so the per-field notes below are hedged.
    """

    # Lookup inputs for the iteration.
    # NOTE(review): presumably TBE-style CSR indices/offsets -- confirm.
    indices: Tensor
    offsets: Tensor

    # Cache-related tensors for this iteration.
    # NOTE(review): names suggest L1 (lxu) cache locations/pointers and
    # per-set bookkeeping -- confirm against the producer.
    lxu_cache_locations: Tensor
    lxu_cache_ptrs: Tensor
    actions_count_gpu: Tensor
    cache_set_inverse_indices: Tensor

    # Optional per-batch metadata; defaults to None / -1 when not supplied.
    # NOTE(review): -1 looks like an "unset" sentinel even though the
    # annotation allows None as well -- confirm which the consumers expect.
    B_offsets: Optional[Tensor] = None
    max_B: Optional[int] = -1
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass
class KVZCHCachedData:
    """Staging area for KV-ZCH checkpoint tensors that are loaded out of order.

    Checkpoint loading does not guarantee the order in which the id, weight
    and optimizer tensors arrive, so they are held here until all of them
    are present and can be applied to the backend together.
    """

    # One inner list per table, holding that table's optimizer-state tensors.
    cached_optimizer_states_per_table: list[list[Tensor]]
    # One weight tensor per table.
    cached_weight_tensor_per_table: list[Tensor]
    # One id (key) tensor per table, matching the cached weights.
    cached_id_tensor_per_table: list[Tensor]
    # NOTE(review): presumably per-table bucket split metadata -- confirm.
    cached_bucket_splits: list[Tensor]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class SSDTableBatchedEmbeddingBags(nn.Module):
|
|
88
|
+
D_offsets: Tensor
|
|
89
|
+
lxu_cache_weights: Tensor
|
|
90
|
+
lru_state: Tensor
|
|
91
|
+
lxu_cache_weights: Tensor
|
|
92
|
+
lxu_cache_state: Tensor
|
|
93
|
+
momentum1_dev: Tensor
|
|
94
|
+
momentum1_uvm: Tensor
|
|
95
|
+
momentum1_host: Tensor
|
|
96
|
+
momentum1_placements: Tensor
|
|
97
|
+
momentum1_offsets: Tensor
|
|
98
|
+
weights_dev: Tensor
|
|
99
|
+
weights_uvm: Tensor
|
|
100
|
+
weights_host: Tensor
|
|
101
|
+
weights_placements: Tensor
|
|
102
|
+
weights_offsets: Tensor
|
|
103
|
+
_local_instance_index: int = -1
|
|
104
|
+
res_params: RESParams
|
|
105
|
+
table_names: list[str]
|
|
106
|
+
_all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet()
|
|
107
|
+
_first_instance_ref: ClassVar[weakref.ref] = None
|
|
108
|
+
_eviction_triggered: ClassVar[bool] = False
|
|
109
|
+
|
|
110
|
+
def __init__(
|
|
111
|
+
self,
|
|
112
|
+
embedding_specs: list[tuple[int, int]], # tuple of (rows, dims)
|
|
113
|
+
feature_table_map: Optional[list[int]], # [T]
|
|
114
|
+
cache_sets: int,
|
|
115
|
+
# A comma-separated string, e.g. "/data00_nvidia0,/data01_nvidia0/", db shards
|
|
116
|
+
# will be placed in these paths round-robin.
|
|
117
|
+
ssd_storage_directory: str,
|
|
118
|
+
ssd_rocksdb_shards: int = 1,
|
|
119
|
+
ssd_memtable_flush_period: int = -1,
|
|
120
|
+
ssd_memtable_flush_offset: int = -1,
|
|
121
|
+
ssd_l0_files_per_compact: int = 4,
|
|
122
|
+
ssd_rate_limit_mbps: int = 0,
|
|
123
|
+
ssd_size_ratio: int = 10,
|
|
124
|
+
ssd_compaction_trigger: int = 8,
|
|
125
|
+
ssd_rocksdb_write_buffer_size: int = 2 * 1024 * 1024 * 1024,
|
|
126
|
+
ssd_max_write_buffer_num: int = 4,
|
|
127
|
+
ssd_cache_location: EmbeddingLocation = EmbeddingLocation.MANAGED,
|
|
128
|
+
ssd_uniform_init_lower: float = -0.01,
|
|
129
|
+
ssd_uniform_init_upper: float = 0.01,
|
|
130
|
+
ssd_block_cache_size_per_tbe: int = 0,
|
|
131
|
+
weights_precision: SparseType = SparseType.FP32,
|
|
132
|
+
output_dtype: SparseType = SparseType.FP32,
|
|
133
|
+
optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD,
|
|
134
|
+
# General Optimizer args
|
|
135
|
+
stochastic_rounding: bool = True,
|
|
136
|
+
gradient_clipping: bool = False,
|
|
137
|
+
max_gradient: float = 1.0,
|
|
138
|
+
max_norm: float = 0.0,
|
|
139
|
+
learning_rate: float = 0.01,
|
|
140
|
+
eps: float = 1.0e-8, # used by Adagrad, LAMB, and Adam
|
|
141
|
+
momentum: float = 0.9, # used by LARS-SGD
|
|
142
|
+
weight_decay: float = 0.0, # used by LARS-SGD, LAMB, ADAM, and Rowwise Adagrad
|
|
143
|
+
weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, # used by Rowwise Adagrad
|
|
144
|
+
eta: float = 0.001, # used by LARS-SGD,
|
|
145
|
+
beta1: float = 0.9, # used by LAMB and ADAM
|
|
146
|
+
beta2: float = 0.999, # used by LAMB and ADAM
|
|
147
|
+
counter_based_regularization: Optional[
|
|
148
|
+
CounterBasedRegularizationDefinition
|
|
149
|
+
] = None, # used by Rowwise Adagrad
|
|
150
|
+
cowclip_regularization: Optional[
|
|
151
|
+
CowClipDefinition
|
|
152
|
+
] = None, # used by Rowwise Adagrad
|
|
153
|
+
pooling_mode: PoolingMode = PoolingMode.SUM,
|
|
154
|
+
bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
|
|
155
|
+
# Parameter Server Configs
|
|
156
|
+
ps_hosts: Optional[tuple[tuple[str, int]]] = None,
|
|
157
|
+
ps_max_key_per_request: Optional[int] = None,
|
|
158
|
+
ps_client_thread_num: Optional[int] = None,
|
|
159
|
+
ps_max_local_index_length: Optional[int] = None,
|
|
160
|
+
tbe_unique_id: int = -1,
|
|
161
|
+
# If set to True, will use `ssd_storage_directory` as the ssd paths.
|
|
162
|
+
# If set to False, will use the default ssd paths.
|
|
163
|
+
# In local test we need to use the pass in path for rocksdb creation
|
|
164
|
+
# fn production we could either use the default ssd mount points or explicity specify ssd
|
|
165
|
+
# mount points using `ssd_storage_directory`.
|
|
166
|
+
use_passed_in_path: int = True,
|
|
167
|
+
gather_ssd_cache_stats: Optional[bool] = False,
|
|
168
|
+
stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
|
|
169
|
+
l2_cache_size: int = 0,
|
|
170
|
+
# Set to True to enable pipeline prefetching
|
|
171
|
+
prefetch_pipeline: bool = False,
|
|
172
|
+
# Set to True to alloc a UVM tensor using malloc+cudaHostRegister.
|
|
173
|
+
# Set to False to use cudaMallocManaged
|
|
174
|
+
uvm_host_mapped: bool = False,
|
|
175
|
+
enable_async_update: bool = True, # whether enable L2/rocksdb write to async background thread
|
|
176
|
+
# if > 0, insert all kv pairs to rocksdb at init time, in chunks of *bulk_init_chunk_size* bytes
|
|
177
|
+
# number of rows will be decided by bulk_init_chunk_size / size_of_each_row
|
|
178
|
+
bulk_init_chunk_size: int = 0,
|
|
179
|
+
lazy_bulk_init_enabled: bool = False,
|
|
180
|
+
backend_type: BackendType = BackendType.SSD,
|
|
181
|
+
kv_zch_params: Optional[KVZCHParams] = None,
|
|
182
|
+
enable_raw_embedding_streaming: bool = False, # whether enable raw embedding streaming
|
|
183
|
+
res_params: Optional[RESParams] = None, # raw embedding streaming sharding info
|
|
184
|
+
flushing_block_size: int = 2_000_000_000, # 2GB
|
|
185
|
+
table_names: Optional[list[str]] = None,
|
|
186
|
+
use_rowwise_bias_correction: bool = False, # For Adam use
|
|
187
|
+
optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006
|
|
188
|
+
pg: Optional[dist.ProcessGroup] = None,
|
|
189
|
+
) -> None:
|
|
190
|
+
super(SSDTableBatchedEmbeddingBags, self).__init__()
|
|
191
|
+
|
|
192
|
+
# Set the optimizer
|
|
193
|
+
assert optimizer in (
|
|
194
|
+
OptimType.EXACT_ROWWISE_ADAGRAD,
|
|
195
|
+
OptimType.PARTIAL_ROWWISE_ADAM,
|
|
196
|
+
OptimType.ADAM,
|
|
197
|
+
), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
|
|
198
|
+
self.optimizer = optimizer
|
|
199
|
+
|
|
200
|
+
# Set the table weight and output dtypes
|
|
201
|
+
assert weights_precision in (SparseType.FP32, SparseType.FP16)
|
|
202
|
+
self.weights_precision = weights_precision
|
|
203
|
+
self.output_dtype: int = output_dtype.as_int()
|
|
204
|
+
|
|
205
|
+
if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
|
|
206
|
+
# Adagrad currently only supports FP32 for momentum1
|
|
207
|
+
self.optimizer_state_dtypes: dict[str, SparseType] = {
|
|
208
|
+
"momentum1": SparseType.FP32,
|
|
209
|
+
}
|
|
210
|
+
else:
|
|
211
|
+
self.optimizer_state_dtypes: dict[str, SparseType] = optimizer_state_dtypes
|
|
212
|
+
|
|
213
|
+
# Zero collision TBE configurations
|
|
214
|
+
self.kv_zch_params = kv_zch_params
|
|
215
|
+
self.backend_type = backend_type
|
|
216
|
+
self.enable_optimizer_offloading: bool = False
|
|
217
|
+
self.backend_return_whole_row: bool = False
|
|
218
|
+
self._embedding_cache_mode: bool = False
|
|
219
|
+
self.load_ckpt_without_opt: bool = False
|
|
220
|
+
if self.kv_zch_params:
|
|
221
|
+
self.kv_zch_params.validate()
|
|
222
|
+
self.load_ckpt_without_opt = (
|
|
223
|
+
# pyre-ignore [16]
|
|
224
|
+
self.kv_zch_params.load_ckpt_without_opt
|
|
225
|
+
)
|
|
226
|
+
self.enable_optimizer_offloading = (
|
|
227
|
+
# pyre-ignore [16]
|
|
228
|
+
self.kv_zch_params.enable_optimizer_offloading
|
|
229
|
+
)
|
|
230
|
+
self.backend_return_whole_row = (
|
|
231
|
+
# pyre-ignore [16]
|
|
232
|
+
self.kv_zch_params.backend_return_whole_row
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
if self.enable_optimizer_offloading:
|
|
236
|
+
logging.info("Optimizer state offloading is enabled")
|
|
237
|
+
if self.backend_return_whole_row:
|
|
238
|
+
assert (
|
|
239
|
+
self.backend_type == BackendType.DRAM
|
|
240
|
+
), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
|
|
241
|
+
logging.info(
|
|
242
|
+
"Backend will return whole row including metaheader, weight and optimizer for checkpoint"
|
|
243
|
+
)
|
|
244
|
+
# pyre-ignore [16]
|
|
245
|
+
self._embedding_cache_mode = self.kv_zch_params.embedding_cache_mode
|
|
246
|
+
if self._embedding_cache_mode:
|
|
247
|
+
logging.info("KVZCH is in embedding_cache_mode")
|
|
248
|
+
assert self.optimizer in [
|
|
249
|
+
OptimType.EXACT_ROWWISE_ADAGRAD
|
|
250
|
+
], f"only EXACT_ROWWISE_ADAGRAD supports embedding cache mode, but got {self.optimizer}"
|
|
251
|
+
if self.load_ckpt_without_opt:
|
|
252
|
+
if (
|
|
253
|
+
# pyre-ignore [16]
|
|
254
|
+
self.kv_zch_params.optimizer_type_for_st
|
|
255
|
+
== OptimType.PARTIAL_ROWWISE_ADAM.value
|
|
256
|
+
):
|
|
257
|
+
self.optimizer = OptimType.PARTIAL_ROWWISE_ADAM
|
|
258
|
+
logging.info(
|
|
259
|
+
f"Override optimizer type with {self.optimizer=} for st publish"
|
|
260
|
+
)
|
|
261
|
+
if (
|
|
262
|
+
# pyre-ignore [16]
|
|
263
|
+
self.kv_zch_params.optimizer_state_dtypes_for_st
|
|
264
|
+
is not None
|
|
265
|
+
):
|
|
266
|
+
optimizer_state_dtypes = {}
|
|
267
|
+
for k, v in dict(
|
|
268
|
+
self.kv_zch_params.optimizer_state_dtypes_for_st
|
|
269
|
+
).items():
|
|
270
|
+
optimizer_state_dtypes[k] = SparseType.from_int(v)
|
|
271
|
+
self.optimizer_state_dtypes = optimizer_state_dtypes
|
|
272
|
+
logging.info(
|
|
273
|
+
f"Override optimizer_state_dtypes with {self.optimizer_state_dtypes=} for st publish"
|
|
274
|
+
)
|
|
275
|
+
|
|
276
|
+
self.pooling_mode = pooling_mode
|
|
277
|
+
self.bounds_check_mode_int: int = bounds_check_mode.value
|
|
278
|
+
self.embedding_specs = embedding_specs
|
|
279
|
+
self.table_names = table_names if table_names is not None else []
|
|
280
|
+
(rows, dims) = zip(*embedding_specs)
|
|
281
|
+
T_ = len(self.embedding_specs)
|
|
282
|
+
assert T_ > 0
|
|
283
|
+
# pyre-fixme[8]: Attribute has type `device`; used as `int`.
|
|
284
|
+
self.current_device: torch.device = torch.cuda.current_device()
|
|
285
|
+
|
|
286
|
+
self.enable_raw_embedding_streaming = enable_raw_embedding_streaming
|
|
287
|
+
# initialize the raw embedding streaming related variables
|
|
288
|
+
self.res_params: RESParams = res_params or RESParams()
|
|
289
|
+
if self.enable_raw_embedding_streaming:
|
|
290
|
+
self.res_params.table_sizes = [0] + list(itertools.accumulate(rows))
|
|
291
|
+
res_port_from_env = os.getenv("LOCAL_RES_PORT")
|
|
292
|
+
self.res_params.res_server_port = (
|
|
293
|
+
int(res_port_from_env) if res_port_from_env else 0
|
|
294
|
+
)
|
|
295
|
+
logging.info(
|
|
296
|
+
f"get env {self.res_params.res_server_port=}, at rank {dist.get_rank()}, with {self.res_params=}"
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
self.feature_table_map: list[int] = (
|
|
300
|
+
feature_table_map if feature_table_map is not None else list(range(T_))
|
|
301
|
+
)
|
|
302
|
+
T = len(self.feature_table_map)
|
|
303
|
+
assert T_ <= T
|
|
304
|
+
table_has_feature = [False] * T_
|
|
305
|
+
for t in self.feature_table_map:
|
|
306
|
+
table_has_feature[t] = True
|
|
307
|
+
assert all(table_has_feature), "Each table must have at least one feature!"
|
|
308
|
+
|
|
309
|
+
feature_dims = [dims[t] for t in self.feature_table_map]
|
|
310
|
+
D_offsets = [dims[t] for t in self.feature_table_map]
|
|
311
|
+
D_offsets = [0] + list(itertools.accumulate(D_offsets))
|
|
312
|
+
|
|
313
|
+
# Sum of row length of all tables
|
|
314
|
+
self.total_D: int = D_offsets[-1]
|
|
315
|
+
|
|
316
|
+
# Max number of elements required to store a row in the cache
|
|
317
|
+
self.max_D: int = max(dims)
|
|
318
|
+
self.register_buffer(
|
|
319
|
+
"D_offsets",
|
|
320
|
+
torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
|
|
321
|
+
)
|
|
322
|
+
assert self.D_offsets.numel() == T + 1
|
|
323
|
+
hash_size_cumsum = [0] + list(itertools.accumulate(rows))
|
|
324
|
+
if hash_size_cumsum[-1] == 0:
|
|
325
|
+
self.total_hash_size_bits: int = 0
|
|
326
|
+
else:
|
|
327
|
+
self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
|
|
328
|
+
self.register_buffer(
|
|
329
|
+
"table_hash_size_cumsum",
|
|
330
|
+
torch.tensor(
|
|
331
|
+
hash_size_cumsum, device=self.current_device, dtype=torch.int64
|
|
332
|
+
),
|
|
333
|
+
)
|
|
334
|
+
# The last element is to easily access # of rows of each table by
|
|
335
|
+
self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
|
|
336
|
+
self.total_hash_size: int = hash_size_cumsum[-1]
|
|
337
|
+
# The last element is to easily access # of rows of each table by
|
|
338
|
+
# hash_size_cumsum[t + 1] - hash_size_cumsum[t]
|
|
339
|
+
hash_size_cumsum = [hash_size_cumsum[t] for t in self.feature_table_map] + [
|
|
340
|
+
hash_size_cumsum[-1]
|
|
341
|
+
]
|
|
342
|
+
self.register_buffer(
|
|
343
|
+
"hash_size_cumsum",
|
|
344
|
+
torch.tensor(
|
|
345
|
+
hash_size_cumsum, device=self.current_device, dtype=torch.int64
|
|
346
|
+
),
|
|
347
|
+
)
|
|
348
|
+
|
|
349
|
+
self.uvm_host_mapped = uvm_host_mapped
|
|
350
|
+
logging.info(
|
|
351
|
+
f"TBE will allocate a UVM buffer with is_host_mapped={uvm_host_mapped}"
|
|
352
|
+
)
|
|
353
|
+
self.bulk_init_chunk_size = bulk_init_chunk_size
|
|
354
|
+
self.lazy_init_thread: threading.Thread | None = None
|
|
355
|
+
|
|
356
|
+
# Buffers for bounds check
|
|
357
|
+
self.register_buffer(
|
|
358
|
+
"rows_per_table",
|
|
359
|
+
torch.tensor(
|
|
360
|
+
[rows[t] for t in self.feature_table_map],
|
|
361
|
+
device=self.current_device,
|
|
362
|
+
dtype=torch.int64,
|
|
363
|
+
),
|
|
364
|
+
)
|
|
365
|
+
self.register_buffer(
|
|
366
|
+
"bounds_check_warning",
|
|
367
|
+
torch.tensor([0], device=self.current_device, dtype=torch.int64),
|
|
368
|
+
)
|
|
369
|
+
# Required for VBE
|
|
370
|
+
self.register_buffer(
|
|
371
|
+
"feature_dims",
|
|
372
|
+
torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
|
|
373
|
+
)
|
|
374
|
+
self.register_buffer(
|
|
375
|
+
"table_dims",
|
|
376
|
+
torch.tensor(dims, device="cpu", dtype=torch.int64),
|
|
377
|
+
)
|
|
378
|
+
|
|
379
|
+
(info_B_num_bits_, info_B_mask_) = torch.ops.fbgemm.get_infos_metadata(
|
|
380
|
+
self.D_offsets, # unused tensor
|
|
381
|
+
1, # max_B
|
|
382
|
+
T, # T
|
|
383
|
+
)
|
|
384
|
+
self.info_B_num_bits: int = info_B_num_bits_
|
|
385
|
+
self.info_B_mask: int = info_B_mask_
|
|
386
|
+
|
|
387
|
+
assert cache_sets > 0
|
|
388
|
+
element_size = weights_precision.bit_rate() // 8
|
|
389
|
+
assert (
|
|
390
|
+
element_size == 4 or element_size == 2
|
|
391
|
+
), f"Invalid element size {element_size}"
|
|
392
|
+
cache_size = cache_sets * ASSOC * element_size * self.cache_row_dim
|
|
393
|
+
logging.info(
|
|
394
|
+
f"Using cache for SSD with admission algorithm "
|
|
395
|
+
f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_rocksdb_shards} shards, "
|
|
396
|
+
f"SSD storage directory: {ssd_storage_directory}, "
|
|
397
|
+
f"Memtable Flush Period: {ssd_memtable_flush_period}, "
|
|
398
|
+
f"Memtable Flush Offset: {ssd_memtable_flush_offset}, "
|
|
399
|
+
f"Desired L0 files per compaction: {ssd_l0_files_per_compact}, "
|
|
400
|
+
f"Cache size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
|
|
401
|
+
f"weights precision: {weights_precision}, "
|
|
402
|
+
f"output dtype: {output_dtype}, "
|
|
403
|
+
f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
|
|
404
|
+
f"kv_zch_params: {kv_zch_params}, "
|
|
405
|
+
f"embedding spec: {embedding_specs}"
|
|
406
|
+
)
|
|
407
|
+
self.register_buffer(
|
|
408
|
+
"lxu_cache_state",
|
|
409
|
+
torch.zeros(
|
|
410
|
+
cache_sets, ASSOC, device=self.current_device, dtype=torch.int64
|
|
411
|
+
).fill_(-1),
|
|
412
|
+
)
|
|
413
|
+
self.register_buffer(
|
|
414
|
+
"lru_state",
|
|
415
|
+
torch.zeros(
|
|
416
|
+
cache_sets, ASSOC, device=self.current_device, dtype=torch.int64
|
|
417
|
+
),
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
self.step = 0
|
|
421
|
+
self.last_flush_step = -1
|
|
422
|
+
|
|
423
|
+
# Set prefetch pipeline
|
|
424
|
+
self.prefetch_pipeline: bool = prefetch_pipeline
|
|
425
|
+
self.prefetch_stream: Optional[torch.cuda.Stream] = None
|
|
426
|
+
|
|
427
|
+
# Cache locking counter for pipeline prefetching
|
|
428
|
+
if self.prefetch_pipeline:
|
|
429
|
+
self.register_buffer(
|
|
430
|
+
"lxu_cache_locking_counter",
|
|
431
|
+
torch.zeros(
|
|
432
|
+
cache_sets,
|
|
433
|
+
ASSOC,
|
|
434
|
+
device=self.current_device,
|
|
435
|
+
dtype=torch.int32,
|
|
436
|
+
),
|
|
437
|
+
persistent=True,
|
|
438
|
+
)
|
|
439
|
+
else:
|
|
440
|
+
self.register_buffer(
|
|
441
|
+
"lxu_cache_locking_counter",
|
|
442
|
+
torch.zeros([0, 0], dtype=torch.int32, device=self.current_device),
|
|
443
|
+
persistent=False,
|
|
444
|
+
)
|
|
445
|
+
|
|
446
|
+
assert ssd_cache_location in (
|
|
447
|
+
EmbeddingLocation.MANAGED,
|
|
448
|
+
EmbeddingLocation.DEVICE,
|
|
449
|
+
)
|
|
450
|
+
|
|
451
|
+
cache_dtype = weights_precision.as_dtype()
|
|
452
|
+
if ssd_cache_location == EmbeddingLocation.MANAGED:
|
|
453
|
+
self.register_buffer(
|
|
454
|
+
"lxu_cache_weights",
|
|
455
|
+
torch.ops.fbgemm.new_unified_tensor(
|
|
456
|
+
torch.zeros(
|
|
457
|
+
1,
|
|
458
|
+
device=self.current_device,
|
|
459
|
+
dtype=cache_dtype,
|
|
460
|
+
),
|
|
461
|
+
[cache_sets * ASSOC, self.cache_row_dim],
|
|
462
|
+
is_host_mapped=self.uvm_host_mapped,
|
|
463
|
+
),
|
|
464
|
+
)
|
|
465
|
+
else:
|
|
466
|
+
self.register_buffer(
|
|
467
|
+
"lxu_cache_weights",
|
|
468
|
+
torch.zeros(
|
|
469
|
+
cache_sets * ASSOC,
|
|
470
|
+
self.cache_row_dim,
|
|
471
|
+
device=self.current_device,
|
|
472
|
+
dtype=cache_dtype,
|
|
473
|
+
),
|
|
474
|
+
)
|
|
475
|
+
assert (
|
|
476
|
+
cache_size
|
|
477
|
+
== self.lxu_cache_weights.numel()
|
|
478
|
+
* self.lxu_cache_weights.element_size()
|
|
479
|
+
), "The precomputed cache_size does not match the actual cache size"
|
|
480
|
+
|
|
481
|
+
# Buffers for cache eviction
|
|
482
|
+
# For storing weights to evict
|
|
483
|
+
# The max number of rows to be evicted is limited by the number of
|
|
484
|
+
# slots in the cache. Thus, we allocate `lxu_cache_evicted_weights` to
|
|
485
|
+
# be the same shape as the L1 cache (lxu_cache_weights)
|
|
486
|
+
self.register_buffer(
|
|
487
|
+
"lxu_cache_evicted_weights",
|
|
488
|
+
torch.ops.fbgemm.new_unified_tensor(
|
|
489
|
+
torch.zeros(
|
|
490
|
+
1,
|
|
491
|
+
device=self.current_device,
|
|
492
|
+
dtype=cache_dtype,
|
|
493
|
+
),
|
|
494
|
+
self.lxu_cache_weights.shape,
|
|
495
|
+
is_host_mapped=self.uvm_host_mapped,
|
|
496
|
+
),
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
# For storing embedding indices to evict to
|
|
500
|
+
self.register_buffer(
|
|
501
|
+
"lxu_cache_evicted_indices",
|
|
502
|
+
torch.ops.fbgemm.new_unified_tensor(
|
|
503
|
+
torch.zeros(
|
|
504
|
+
1,
|
|
505
|
+
device=self.current_device,
|
|
506
|
+
dtype=torch.long,
|
|
507
|
+
),
|
|
508
|
+
(self.lxu_cache_weights.shape[0],),
|
|
509
|
+
is_host_mapped=self.uvm_host_mapped,
|
|
510
|
+
),
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
# For storing cache slots to evict
|
|
514
|
+
self.register_buffer(
|
|
515
|
+
"lxu_cache_evicted_slots",
|
|
516
|
+
torch.ops.fbgemm.new_unified_tensor(
|
|
517
|
+
torch.zeros(
|
|
518
|
+
1,
|
|
519
|
+
device=self.current_device,
|
|
520
|
+
dtype=torch.int,
|
|
521
|
+
),
|
|
522
|
+
(self.lxu_cache_weights.shape[0],),
|
|
523
|
+
is_host_mapped=self.uvm_host_mapped,
|
|
524
|
+
),
|
|
525
|
+
)
|
|
526
|
+
|
|
527
|
+
# For storing the number of evicted rows
|
|
528
|
+
self.register_buffer(
|
|
529
|
+
"lxu_cache_evicted_count",
|
|
530
|
+
torch.ops.fbgemm.new_unified_tensor(
|
|
531
|
+
torch.zeros(
|
|
532
|
+
1,
|
|
533
|
+
device=self.current_device,
|
|
534
|
+
dtype=torch.int,
|
|
535
|
+
),
|
|
536
|
+
(1,),
|
|
537
|
+
is_host_mapped=self.uvm_host_mapped,
|
|
538
|
+
),
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
self.timestep = 0
|
|
542
|
+
|
|
543
|
+
# Store the iteration number on GPU and CPU (used for certain optimizers)
|
|
544
|
+
persistent_iter_ = optimizer in (OptimType.PARTIAL_ROWWISE_ADAM,)
|
|
545
|
+
self.register_buffer(
|
|
546
|
+
"iter",
|
|
547
|
+
torch.zeros(1, dtype=torch.int64, device=self.current_device),
|
|
548
|
+
persistent=persistent_iter_,
|
|
549
|
+
)
|
|
550
|
+
self.iter_cpu: torch.Tensor = torch.zeros(1, dtype=torch.int64, device="cpu")
|
|
551
|
+
|
|
552
|
+
# Dummy profile configuration for measuring the SSD get/set time
|
|
553
|
+
# get and set are executed by another thread which (for some reason) is
|
|
554
|
+
# not traceable by PyTorch's Kineto. We workaround this problem by
|
|
555
|
+
# injecting a dummy kernel into the GPU stream to make it traceable
|
|
556
|
+
#
|
|
557
|
+
# This function can be enabled by setting an environment variable
|
|
558
|
+
# FBGEMM_SSD_TBE_USE_DUMMY_PROFILE=1
|
|
559
|
+
self.dummy_profile_tensor: Tensor = torch.as_tensor(
|
|
560
|
+
[0], device=self.current_device, dtype=torch.int
|
|
561
|
+
)
|
|
562
|
+
set_dummy_profile = os.environ.get("FBGEMM_SSD_TBE_USE_DUMMY_PROFILE")
|
|
563
|
+
use_dummy_profile = False
|
|
564
|
+
if set_dummy_profile is not None:
|
|
565
|
+
use_dummy_profile = int(set_dummy_profile) == 1
|
|
566
|
+
logging.info(
|
|
567
|
+
f"FBGEMM_SSD_TBE_USE_DUMMY_PROFILE is set to {set_dummy_profile}; "
|
|
568
|
+
f"Use dummy profile: {use_dummy_profile}"
|
|
569
|
+
)
|
|
570
|
+
|
|
571
|
+
self.record_function_via_dummy_profile: Callable[..., Any] = (
|
|
572
|
+
self.record_function_via_dummy_profile_factory(use_dummy_profile)
|
|
573
|
+
)
|
|
574
|
+
|
|
575
|
+
if use_passed_in_path:
|
|
576
|
+
ssd_dir_list = ssd_storage_directory.split(",")
|
|
577
|
+
for ssd_dir in ssd_dir_list:
|
|
578
|
+
os.makedirs(ssd_dir, exist_ok=True)
|
|
579
|
+
|
|
580
|
+
ssd_directory = ssd_storage_directory
|
|
581
|
+
# logging.info("DEBUG: weights_precision {}".format(weights_precision))
|
|
582
|
+
|
|
583
|
+
"""
|
|
584
|
+
##################### for ZCH v.Next loading checkpoints Short Term Solution #######################
|
|
585
|
+
weight_id tensor is the weight and optimizer keys, to load from checkpoint, weight_id tensor
|
|
586
|
+
needs to be loaded first, then we can load the weight and optimizer tensors.
|
|
587
|
+
However, the stateful checkpoint loading does not guarantee the tensor loading order, so we need
|
|
588
|
+
to cache the weight_id, weight and optimizer tensors untils all data are loaded, then we can apply
|
|
589
|
+
them to backend.
|
|
590
|
+
Currently, we'll cache the weight_id, weight and optimizer tensors in the KVZCHCachedData class,
|
|
591
|
+
and apply them to backend when all data are loaded. The downside of this solution is that we'll
|
|
592
|
+
have to duplicate a whole tensor memory to backend before we can release the python tensor memory,
|
|
593
|
+
which is not ideal.
|
|
594
|
+
The longer term solution is to support the caching from the backend side, and allow streaming based
|
|
595
|
+
data move from cached weight and optimizer to key/value format without duplicate one whole tensor's
|
|
596
|
+
memory.
|
|
597
|
+
"""
|
|
598
|
+
self._cached_kvzch_data: Optional[KVZCHCachedData] = None
|
|
599
|
+
# initial embedding rows on this rank per table, this is used for loading checkpoint
|
|
600
|
+
self.local_weight_counts: list[int] = [0] * T_
|
|
601
|
+
# groundtruth global id on this rank per table, this is used for loading checkpoint
|
|
602
|
+
self.global_id_per_rank: list[torch.Tensor] = [torch.zeros(0)] * T_
|
|
603
|
+
# loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend
|
|
604
|
+
self.load_state_dict: bool = False
|
|
605
|
+
|
|
606
|
+
SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self)
|
|
607
|
+
if SSDTableBatchedEmbeddingBags._first_instance_ref is None:
|
|
608
|
+
SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self)
|
|
609
|
+
|
|
610
|
+
# create tbe unique id using rank index | local tbe idx
|
|
611
|
+
if tbe_unique_id == -1:
|
|
612
|
+
SSDTableBatchedEmbeddingBags._local_instance_index += 1
|
|
613
|
+
if dist.is_initialized():
|
|
614
|
+
assert (
|
|
615
|
+
SSDTableBatchedEmbeddingBags._local_instance_index < 1024
|
|
616
|
+
), f"{SSDTableBatchedEmbeddingBags._local_instance_index}, more than 1024 TBE instance is created in one rank, the tbe unique id won't be unique in this case."
|
|
617
|
+
tbe_unique_id = (
|
|
618
|
+
dist.get_rank() << 10
|
|
619
|
+
| SSDTableBatchedEmbeddingBags._local_instance_index
|
|
620
|
+
)
|
|
621
|
+
else:
|
|
622
|
+
logging.warning("dist is not initialized, treating as single gpu cases")
|
|
623
|
+
tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
|
|
624
|
+
self.tbe_unique_id = tbe_unique_id
|
|
625
|
+
self.l2_cache_size = l2_cache_size
|
|
626
|
+
logging.info(f"tbe_unique_id: {tbe_unique_id}")
|
|
627
|
+
self.enable_free_mem_trigger_eviction: bool = False
|
|
628
|
+
if self.backend_type == BackendType.SSD:
|
|
629
|
+
logging.info(
|
|
630
|
+
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
|
|
631
|
+
f"enable_async_update:{enable_async_update}, passed_in_path={ssd_directory}, "
|
|
632
|
+
f"num_shards={ssd_rocksdb_shards}, num_threads={ssd_rocksdb_shards}, "
|
|
633
|
+
f"memtable_flush_period={ssd_memtable_flush_period}, memtable_flush_offset={ssd_memtable_flush_offset}, "
|
|
634
|
+
f"l0_files_per_compact={ssd_l0_files_per_compact}, max_D={self.max_D}, "
|
|
635
|
+
f"cache_row_size={self.cache_row_dim}, rate_limit_mbps={ssd_rate_limit_mbps}, "
|
|
636
|
+
f"size_ratio={ssd_size_ratio}, compaction_trigger={ssd_compaction_trigger}, "
|
|
637
|
+
f"lazy_bulk_init_enabled={lazy_bulk_init_enabled}, write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size}, "
|
|
638
|
+
f"max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num}, "
|
|
639
|
+
f"uniform_init_lower={ssd_uniform_init_lower}, uniform_init_upper={ssd_uniform_init_upper}, "
|
|
640
|
+
f"row_storage_bitwidth={weights_precision.bit_rate()}, block_cache_size_per_tbe={ssd_block_cache_size_per_tbe}, "
|
|
641
|
+
f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB, "
|
|
642
|
+
f"enable_raw_embedding_streaming:{self.enable_raw_embedding_streaming}, flushing_block_size:{flushing_block_size}"
|
|
643
|
+
)
|
|
644
|
+
# pyre-fixme[4]: Attribute must be annotated.
|
|
645
|
+
self._ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
|
|
646
|
+
ssd_directory,
|
|
647
|
+
ssd_rocksdb_shards,
|
|
648
|
+
ssd_rocksdb_shards,
|
|
649
|
+
ssd_memtable_flush_period,
|
|
650
|
+
ssd_memtable_flush_offset,
|
|
651
|
+
ssd_l0_files_per_compact,
|
|
652
|
+
self.cache_row_dim,
|
|
653
|
+
ssd_rate_limit_mbps,
|
|
654
|
+
ssd_size_ratio,
|
|
655
|
+
ssd_compaction_trigger,
|
|
656
|
+
ssd_rocksdb_write_buffer_size,
|
|
657
|
+
ssd_max_write_buffer_num,
|
|
658
|
+
ssd_uniform_init_lower,
|
|
659
|
+
ssd_uniform_init_upper,
|
|
660
|
+
weights_precision.bit_rate(), # row_storage_bitwidth
|
|
661
|
+
ssd_block_cache_size_per_tbe,
|
|
662
|
+
use_passed_in_path,
|
|
663
|
+
tbe_unique_id,
|
|
664
|
+
l2_cache_size,
|
|
665
|
+
enable_async_update,
|
|
666
|
+
self.enable_raw_embedding_streaming,
|
|
667
|
+
self.res_params.res_store_shards,
|
|
668
|
+
self.res_params.res_server_port,
|
|
669
|
+
self.res_params.table_names,
|
|
670
|
+
self.res_params.table_offsets,
|
|
671
|
+
self.res_params.table_sizes,
|
|
672
|
+
(
|
|
673
|
+
tensor_pad4(self.table_dims)
|
|
674
|
+
if self.enable_optimizer_offloading
|
|
675
|
+
else None
|
|
676
|
+
),
|
|
677
|
+
(
|
|
678
|
+
self.table_hash_size_cumsum.cpu()
|
|
679
|
+
if self.enable_optimizer_offloading
|
|
680
|
+
else None
|
|
681
|
+
),
|
|
682
|
+
flushing_block_size,
|
|
683
|
+
self._embedding_cache_mode, # disable_random_init
|
|
684
|
+
)
|
|
685
|
+
if self.bulk_init_chunk_size > 0:
|
|
686
|
+
self.ssd_uniform_init_lower: float = ssd_uniform_init_lower
|
|
687
|
+
self.ssd_uniform_init_upper: float = ssd_uniform_init_upper
|
|
688
|
+
if lazy_bulk_init_enabled:
|
|
689
|
+
self._lazy_initialize_ssd_tbe()
|
|
690
|
+
else:
|
|
691
|
+
self._insert_all_kv()
|
|
692
|
+
elif self.backend_type == BackendType.PS:
|
|
693
|
+
self._ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
|
|
694
|
+
[host[0] for host in ps_hosts], # pyre-ignore
|
|
695
|
+
[host[1] for host in ps_hosts],
|
|
696
|
+
tbe_unique_id,
|
|
697
|
+
(
|
|
698
|
+
ps_max_local_index_length
|
|
699
|
+
if ps_max_local_index_length is not None
|
|
700
|
+
else 54
|
|
701
|
+
),
|
|
702
|
+
ps_client_thread_num if ps_client_thread_num is not None else 32,
|
|
703
|
+
ps_max_key_per_request if ps_max_key_per_request is not None else 500,
|
|
704
|
+
l2_cache_size,
|
|
705
|
+
self.cache_row_dim,
|
|
706
|
+
)
|
|
707
|
+
elif self.backend_type == BackendType.DRAM:
|
|
708
|
+
logging.info(
|
|
709
|
+
f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
|
|
710
|
+
f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
|
|
711
|
+
f"max_D={self.max_D},"
|
|
712
|
+
f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
|
|
713
|
+
f"row_storage_bitwidth={weights_precision.bit_rate()},"
|
|
714
|
+
f"self.cache_row_dim={self.cache_row_dim},"
|
|
715
|
+
f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
|
|
716
|
+
f"feature_dims={self.feature_dims},"
|
|
717
|
+
f"hash_size_cumsum={self.hash_size_cumsum},"
|
|
718
|
+
f"backend_return_whole_row={self.backend_return_whole_row}"
|
|
719
|
+
)
|
|
720
|
+
table_dims = (
|
|
721
|
+
tensor_pad4(self.table_dims)
|
|
722
|
+
if self.enable_optimizer_offloading
|
|
723
|
+
else None
|
|
724
|
+
) # table_dims
|
|
725
|
+
eviction_config = None
|
|
726
|
+
if self.kv_zch_params and self.kv_zch_params.eviction_policy:
|
|
727
|
+
eviction_mem_threshold_gb = (
|
|
728
|
+
self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
|
|
729
|
+
if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
|
|
730
|
+
else self.l2_cache_size
|
|
731
|
+
)
|
|
732
|
+
kv_zch_params = self.kv_zch_params
|
|
733
|
+
eviction_policy = self.kv_zch_params.eviction_policy
|
|
734
|
+
if eviction_policy.eviction_trigger_mode == 5:
|
|
735
|
+
# If trigger mode is free_mem(5), populate config
|
|
736
|
+
self.set_free_mem_eviction_trigger_config(eviction_policy)
|
|
737
|
+
|
|
738
|
+
enable_eviction_for_feature_score_eviction_policy = ( # pytorch api in c++ doesn't support vertor<bool>, convert to int here, 0: no eviction 1: eviction
|
|
739
|
+
[
|
|
740
|
+
int(x)
|
|
741
|
+
for x in eviction_policy.enable_eviction_for_feature_score_eviction_policy
|
|
742
|
+
]
|
|
743
|
+
if eviction_policy.enable_eviction_for_feature_score_eviction_policy
|
|
744
|
+
is not None
|
|
745
|
+
else None
|
|
746
|
+
)
|
|
747
|
+
# Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
|
|
748
|
+
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
|
|
749
|
+
eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
|
|
750
|
+
eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
|
|
751
|
+
eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
|
|
752
|
+
eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
|
|
753
|
+
eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
|
|
754
|
+
eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
|
|
755
|
+
eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
|
|
756
|
+
eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
|
|
757
|
+
eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
|
|
758
|
+
eviction_policy.training_id_keep_count, # training_id_keep_count for each table
|
|
759
|
+
enable_eviction_for_feature_score_eviction_policy, # no eviction setting for feature score eviction policy
|
|
760
|
+
eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
|
|
761
|
+
table_dims.tolist() if table_dims is not None else None,
|
|
762
|
+
eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
|
|
763
|
+
eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
|
|
764
|
+
eviction_policy.interval_for_insufficient_eviction_s,
|
|
765
|
+
eviction_policy.interval_for_sufficient_eviction_s,
|
|
766
|
+
eviction_policy.interval_for_feature_statistics_decay_s,
|
|
767
|
+
)
|
|
768
|
+
self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
|
|
769
|
+
self.cache_row_dim,
|
|
770
|
+
ssd_uniform_init_lower,
|
|
771
|
+
ssd_uniform_init_upper,
|
|
772
|
+
eviction_config,
|
|
773
|
+
ssd_rocksdb_shards, # num_shards
|
|
774
|
+
ssd_rocksdb_shards, # num_threads
|
|
775
|
+
weights_precision.bit_rate(), # row_storage_bitwidth
|
|
776
|
+
table_dims,
|
|
777
|
+
(
|
|
778
|
+
self.table_hash_size_cumsum.cpu()
|
|
779
|
+
if self.enable_optimizer_offloading
|
|
780
|
+
else None
|
|
781
|
+
), # hash_size_cumsum
|
|
782
|
+
self.backend_return_whole_row, # backend_return_whole_row
|
|
783
|
+
False, # enable_async_update
|
|
784
|
+
self._embedding_cache_mode, # disable_random_init
|
|
785
|
+
)
|
|
786
|
+
else:
|
|
787
|
+
raise AssertionError(f"Invalid backend type {self.backend_type}")
|
|
788
|
+
|
|
789
|
+
# pyre-fixme[20]: Argument `self` expected.
|
|
790
|
+
(low_priority, high_priority) = torch.cuda.Stream.priority_range()
|
|
791
|
+
# GPU stream for SSD cache eviction
|
|
792
|
+
self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
|
|
793
|
+
# GPU stream for SSD memory copy (also reused for feature score D2H)
|
|
794
|
+
self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
|
|
795
|
+
# GPU stream for async metadata operation
|
|
796
|
+
self.feature_score_stream = torch.cuda.Stream(priority=low_priority)
|
|
797
|
+
|
|
798
|
+
# SSD get completion event
|
|
799
|
+
self.ssd_event_get = torch.cuda.Event()
|
|
800
|
+
# SSD scratch pad eviction completion event
|
|
801
|
+
self.ssd_event_sp_evict = torch.cuda.Event()
|
|
802
|
+
# SSD cache eviction completion event
|
|
803
|
+
self.ssd_event_cache_evict = torch.cuda.Event()
|
|
804
|
+
# SSD backward completion event
|
|
805
|
+
self.ssd_event_backward = torch.cuda.Event()
|
|
806
|
+
# SSD get's input copy completion event
|
|
807
|
+
self.ssd_event_get_inputs_cpy = torch.cuda.Event()
|
|
808
|
+
if self._embedding_cache_mode:
|
|
809
|
+
# Direct write embedding completion event
|
|
810
|
+
self.direct_write_l1_complete_event: torch.cuda.streams.Event = (
|
|
811
|
+
torch.cuda.Event()
|
|
812
|
+
)
|
|
813
|
+
self.direct_write_sp_complete_event: torch.cuda.streams.Event = (
|
|
814
|
+
torch.cuda.Event()
|
|
815
|
+
)
|
|
816
|
+
# Prefetch operation completion event
|
|
817
|
+
self.prefetch_complete_event = torch.cuda.Event()
|
|
818
|
+
|
|
819
|
+
if self.prefetch_pipeline:
|
|
820
|
+
# SSD scratch pad index queue insert completion event
|
|
821
|
+
self.ssd_event_sp_idxq_insert: torch.cuda.streams.Event = torch.cuda.Event()
|
|
822
|
+
# SSD scratch pad index queue lookup completion event
|
|
823
|
+
self.ssd_event_sp_idxq_lookup: torch.cuda.streams.Event = torch.cuda.Event()
|
|
824
|
+
|
|
825
|
+
if self.enable_raw_embedding_streaming:
|
|
826
|
+
# RES reuse the eviction stream
|
|
827
|
+
self.ssd_event_cache_streamed: torch.cuda.streams.Event = torch.cuda.Event()
|
|
828
|
+
self.ssd_event_cache_streaming_synced: torch.cuda.streams.Event = (
|
|
829
|
+
torch.cuda.Event()
|
|
830
|
+
)
|
|
831
|
+
self.ssd_event_cache_streaming_computed: torch.cuda.streams.Event = (
|
|
832
|
+
torch.cuda.Event()
|
|
833
|
+
)
|
|
834
|
+
self.ssd_event_sp_streamed: torch.cuda.streams.Event = torch.cuda.Event()
|
|
835
|
+
|
|
836
|
+
# Updated buffers
|
|
837
|
+
self.register_buffer(
|
|
838
|
+
"lxu_cache_updated_weights",
|
|
839
|
+
torch.ops.fbgemm.new_unified_tensor(
|
|
840
|
+
torch.zeros(
|
|
841
|
+
1,
|
|
842
|
+
device=self.current_device,
|
|
843
|
+
dtype=cache_dtype,
|
|
844
|
+
),
|
|
845
|
+
self.lxu_cache_weights.shape,
|
|
846
|
+
is_host_mapped=self.uvm_host_mapped,
|
|
847
|
+
),
|
|
848
|
+
)
|
|
849
|
+
|
|
850
|
+
# For storing embedding indices to update to
|
|
851
|
+
self.register_buffer(
|
|
852
|
+
"lxu_cache_updated_indices",
|
|
853
|
+
torch.ops.fbgemm.new_unified_tensor(
|
|
854
|
+
torch.zeros(
|
|
855
|
+
1,
|
|
856
|
+
device=self.current_device,
|
|
857
|
+
dtype=torch.long,
|
|
858
|
+
),
|
|
859
|
+
(self.lxu_cache_weights.shape[0],),
|
|
860
|
+
is_host_mapped=self.uvm_host_mapped,
|
|
861
|
+
),
|
|
862
|
+
)
|
|
863
|
+
|
|
864
|
+
# For storing the number of updated rows
|
|
865
|
+
self.register_buffer(
|
|
866
|
+
"lxu_cache_updated_count",
|
|
867
|
+
torch.ops.fbgemm.new_unified_tensor(
|
|
868
|
+
torch.zeros(
|
|
869
|
+
1,
|
|
870
|
+
device=self.current_device,
|
|
871
|
+
dtype=torch.int,
|
|
872
|
+
),
|
|
873
|
+
(1,),
|
|
874
|
+
is_host_mapped=self.uvm_host_mapped,
|
|
875
|
+
),
|
|
876
|
+
)
|
|
877
|
+
|
|
878
|
+
# (Indices, Count)
|
|
879
|
+
self.prefetched_info: list[tuple[Tensor, Tensor]] = []
|
|
880
|
+
|
|
881
|
+
self.timesteps_prefetched: list[int] = []
|
|
882
|
+
# TODO: add type annotation
|
|
883
|
+
# pyre-fixme[4]: Attribute must be annotated.
|
|
884
|
+
self.ssd_prefetch_data = []
|
|
885
|
+
|
|
886
|
+
# Scratch pad eviction data queue
|
|
887
|
+
self.ssd_scratch_pad_eviction_data: list[
|
|
888
|
+
tuple[Tensor, Tensor, Tensor, bool]
|
|
889
|
+
] = []
|
|
890
|
+
self.ssd_location_update_data: list[tuple[Tensor, Tensor]] = []
|
|
891
|
+
|
|
892
|
+
if self.prefetch_pipeline:
|
|
893
|
+
# Scratch pad value queue
|
|
894
|
+
self.ssd_scratch_pads: list[tuple[Tensor, Tensor, Tensor]] = []
|
|
895
|
+
|
|
896
|
+
# pyre-ignore[4]
|
|
897
|
+
# Scratch pad index queue
|
|
898
|
+
self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue(
|
|
899
|
+
-1
|
|
900
|
+
)
|
|
901
|
+
|
|
902
|
+
if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization:
|
|
903
|
+
raise AssertionError(
|
|
904
|
+
"weight_decay_mode = WeightDecayMode.COUNTER is not supported for SSD TBE."
|
|
905
|
+
)
|
|
906
|
+
counter_based_regularization = CounterBasedRegularizationDefinition()
|
|
907
|
+
|
|
908
|
+
if weight_decay_mode == WeightDecayMode.COWCLIP or cowclip_regularization:
|
|
909
|
+
raise AssertionError(
|
|
910
|
+
"weight_decay_mode = WeightDecayMode.COWCLIP is not supported for SSD TBE."
|
|
911
|
+
)
|
|
912
|
+
cowclip_regularization = CowClipDefinition()
|
|
913
|
+
|
|
914
|
+
self.learning_rate_tensor: torch.Tensor = torch.tensor(
|
|
915
|
+
learning_rate, device=torch.device("cpu"), dtype=torch.float32
|
|
916
|
+
)
|
|
917
|
+
|
|
918
|
+
self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs(
|
|
919
|
+
stochastic_rounding=stochastic_rounding,
|
|
920
|
+
gradient_clipping=gradient_clipping,
|
|
921
|
+
max_gradient=max_gradient,
|
|
922
|
+
max_norm=max_norm,
|
|
923
|
+
eps=eps,
|
|
924
|
+
beta1=beta1,
|
|
925
|
+
beta2=beta2,
|
|
926
|
+
weight_decay=weight_decay,
|
|
927
|
+
weight_decay_mode=weight_decay_mode.value,
|
|
928
|
+
eta=eta,
|
|
929
|
+
momentum=momentum,
|
|
930
|
+
counter_halflife=counter_based_regularization.counter_halflife,
|
|
931
|
+
adjustment_iter=counter_based_regularization.adjustment_iter,
|
|
932
|
+
adjustment_ub=counter_based_regularization.adjustment_ub,
|
|
933
|
+
learning_rate_mode=counter_based_regularization.learning_rate_mode.value,
|
|
934
|
+
grad_sum_decay=counter_based_regularization.grad_sum_decay.value,
|
|
935
|
+
tail_id_threshold=counter_based_regularization.tail_id_threshold.val,
|
|
936
|
+
is_tail_id_thresh_ratio=int(
|
|
937
|
+
counter_based_regularization.tail_id_threshold.is_ratio
|
|
938
|
+
),
|
|
939
|
+
total_hash_size=-1, # Unused
|
|
940
|
+
weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
|
|
941
|
+
lower_bound=cowclip_regularization.lower_bound,
|
|
942
|
+
regularization_mode=weight_decay_mode.value,
|
|
943
|
+
use_rowwise_bias_correction=use_rowwise_bias_correction, # Used in Adam optimizer
|
|
944
|
+
)
|
|
945
|
+
|
|
946
|
+
table_embedding_dtype = weights_precision.as_dtype()
|
|
947
|
+
|
|
948
|
+
self._apply_split(
|
|
949
|
+
SplitState(
|
|
950
|
+
dev_size=0,
|
|
951
|
+
host_size=0,
|
|
952
|
+
uvm_size=0,
|
|
953
|
+
placements=[EmbeddingLocation.MANAGED_CACHING for _ in range(T_)],
|
|
954
|
+
offsets=[0] * (len(rows)),
|
|
955
|
+
),
|
|
956
|
+
"weights",
|
|
957
|
+
# pyre-fixme[6]: For 3rd argument expected `Type[dtype]` but got `dtype`.
|
|
958
|
+
dtype=table_embedding_dtype,
|
|
959
|
+
)
|
|
960
|
+
|
|
961
|
+
# Create the optimizer state tensors
|
|
962
|
+
for template in self.optimizer.ssd_state_splits(
|
|
963
|
+
self.embedding_specs,
|
|
964
|
+
self.optimizer_state_dtypes,
|
|
965
|
+
self.enable_optimizer_offloading,
|
|
966
|
+
):
|
|
967
|
+
# pyre-fixme[6]: For 3rd argument expected `Type[dtype]` but got `dtype`.
|
|
968
|
+
self._apply_split(*template)
|
|
969
|
+
|
|
970
|
+
# For storing current iteration data
|
|
971
|
+
self.current_iter_data: Optional[IterData] = None
|
|
972
|
+
|
|
973
|
+
# add placeholder require_grad param to enable autograd without nn.parameter
|
|
974
|
+
# this is needed to enable int8 embedding weights for SplitTableBatchedEmbedding
|
|
975
|
+
self.placeholder_autograd_tensor = nn.Parameter(
|
|
976
|
+
torch.zeros(0, device=self.current_device, dtype=torch.float)
|
|
977
|
+
)
|
|
978
|
+
|
|
979
|
+
# Register backward hook for evicting rows from a scratch pad to SSD
|
|
980
|
+
# post backward
|
|
981
|
+
self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad)
|
|
982
|
+
|
|
983
|
+
if self.prefetch_pipeline:
|
|
984
|
+
self.register_full_backward_pre_hook(
|
|
985
|
+
self._update_cache_counter_and_pointers
|
|
986
|
+
)
|
|
987
|
+
|
|
988
|
+
# stats reporter
|
|
989
|
+
self.gather_ssd_cache_stats = gather_ssd_cache_stats
|
|
990
|
+
self.stats_reporter: Optional[TBEStatsReporter] = (
|
|
991
|
+
stats_reporter_config.create_reporter() if stats_reporter_config else None
|
|
992
|
+
)
|
|
993
|
+
self.ssd_cache_stats_size = 6
|
|
994
|
+
# 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
|
|
995
|
+
# 4: N_conflict_unique_misses, 5: N_conflict_misses
|
|
996
|
+
self.last_reported_ssd_stats: list[float] = []
|
|
997
|
+
self.last_reported_step = 0
|
|
998
|
+
|
|
999
|
+
self.register_buffer(
|
|
1000
|
+
"ssd_cache_stats",
|
|
1001
|
+
torch.zeros(
|
|
1002
|
+
size=(self.ssd_cache_stats_size,),
|
|
1003
|
+
device=self.current_device,
|
|
1004
|
+
dtype=torch.int64,
|
|
1005
|
+
),
|
|
1006
|
+
)
|
|
1007
|
+
|
|
1008
|
+
self.register_buffer(
|
|
1009
|
+
"local_ssd_cache_stats",
|
|
1010
|
+
torch.zeros(
|
|
1011
|
+
self.ssd_cache_stats_size,
|
|
1012
|
+
device=self.current_device,
|
|
1013
|
+
dtype=torch.int32,
|
|
1014
|
+
),
|
|
1015
|
+
)
|
|
1016
|
+
logging.info(
|
|
1017
|
+
f"logging stats reporter setup, {self.gather_ssd_cache_stats=}, "
|
|
1018
|
+
f"stats_reporter:{self.stats_reporter if self.stats_reporter else 'none'}"
|
|
1019
|
+
)
|
|
1020
|
+
|
|
1021
|
+
# prefetch launch a series of kernels, we use AsyncSeriesTimer to track the kernel time
|
|
1022
|
+
self.ssd_prefetch_read_timer: Optional[AsyncSeriesTimer] = None
|
|
1023
|
+
self.ssd_prefetch_evict_timer: Optional[AsyncSeriesTimer] = None
|
|
1024
|
+
self.prefetch_parallel_stream_cnt: int = 2
|
|
1025
|
+
# tuple of iteration, prefetch parallel stream cnt, reported duration
|
|
1026
|
+
# since there are 2 stream in parallel in prefetch, we want to count the longest one
|
|
1027
|
+
self.prefetch_duration_us: tuple[int, int, float] = (
|
|
1028
|
+
-1,
|
|
1029
|
+
self.prefetch_parallel_stream_cnt,
|
|
1030
|
+
0,
|
|
1031
|
+
)
|
|
1032
|
+
self.l2_num_cache_misses_stats_name: str = (
|
|
1033
|
+
f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_cache_misses"
|
|
1034
|
+
)
|
|
1035
|
+
self.l2_num_cache_lookups_stats_name: str = (
|
|
1036
|
+
f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_lookups"
|
|
1037
|
+
)
|
|
1038
|
+
self.l2_num_cache_evictions_stats_name: str = (
|
|
1039
|
+
f"l2_cache.perf.tbe_id{tbe_unique_id}.num_l2_cache_evictions"
|
|
1040
|
+
)
|
|
1041
|
+
self.l2_cache_free_mem_stats_name: str = (
|
|
1042
|
+
f"l2_cache.mem.tbe_id{tbe_unique_id}.free_mem_bytes"
|
|
1043
|
+
)
|
|
1044
|
+
self.l2_cache_capacity_stats_name: str = (
|
|
1045
|
+
f"l2_cache.mem.tbe_id{tbe_unique_id}.capacity_bytes"
|
|
1046
|
+
)
|
|
1047
|
+
self.dram_kv_actual_used_chunk_bytes_stats_name: str = (
|
|
1048
|
+
f"dram_kv.mem.tbe_id{tbe_unique_id}.actual_used_chunk_bytes"
|
|
1049
|
+
)
|
|
1050
|
+
self.dram_kv_allocated_bytes_stats_name: str = (
|
|
1051
|
+
f"dram_kv.mem.tbe_id{tbe_unique_id}.allocated_bytes"
|
|
1052
|
+
)
|
|
1053
|
+
self.dram_kv_mem_num_rows_stats_name: str = (
|
|
1054
|
+
f"dram_kv.mem.tbe_id{tbe_unique_id}.num_rows"
|
|
1055
|
+
)
|
|
1056
|
+
|
|
1057
|
+
self.eviction_sum_evicted_counts_stats_name: str = (
|
|
1058
|
+
f"eviction.tbe_id.{tbe_unique_id}.sum_evicted_counts"
|
|
1059
|
+
)
|
|
1060
|
+
self.eviction_sum_processed_counts_stats_name: str = (
|
|
1061
|
+
f"eviction.tbe_id.{tbe_unique_id}.sum_processed_counts"
|
|
1062
|
+
)
|
|
1063
|
+
self.eviction_evict_rate_stats_name: str = (
|
|
1064
|
+
f"eviction.tbe_id.{tbe_unique_id}.evict_rate"
|
|
1065
|
+
)
|
|
1066
|
+
|
|
1067
|
+
if self.stats_reporter:
|
|
1068
|
+
self.ssd_prefetch_read_timer = AsyncSeriesTimer(
|
|
1069
|
+
functools.partial(
|
|
1070
|
+
SSDTableBatchedEmbeddingBags._report_duration,
|
|
1071
|
+
self,
|
|
1072
|
+
event_name="tbe.prefetch_duration_us",
|
|
1073
|
+
time_unit="us",
|
|
1074
|
+
)
|
|
1075
|
+
)
|
|
1076
|
+
self.ssd_prefetch_evict_timer = AsyncSeriesTimer(
|
|
1077
|
+
functools.partial(
|
|
1078
|
+
SSDTableBatchedEmbeddingBags._report_duration,
|
|
1079
|
+
self,
|
|
1080
|
+
event_name="tbe.prefetch_duration_us",
|
|
1081
|
+
time_unit="us",
|
|
1082
|
+
)
|
|
1083
|
+
)
|
|
1084
|
+
# pyre-ignore
|
|
1085
|
+
self.stats_reporter.register_stats(self.l2_num_cache_misses_stats_name)
|
|
1086
|
+
self.stats_reporter.register_stats(self.l2_num_cache_lookups_stats_name)
|
|
1087
|
+
self.stats_reporter.register_stats(self.l2_num_cache_evictions_stats_name)
|
|
1088
|
+
self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name)
|
|
1089
|
+
self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name)
|
|
1090
|
+
self.stats_reporter.register_stats(self.dram_kv_allocated_bytes_stats_name)
|
|
1091
|
+
self.stats_reporter.register_stats(
|
|
1092
|
+
self.dram_kv_actual_used_chunk_bytes_stats_name
|
|
1093
|
+
)
|
|
1094
|
+
self.stats_reporter.register_stats(self.dram_kv_mem_num_rows_stats_name)
|
|
1095
|
+
self.stats_reporter.register_stats(
|
|
1096
|
+
self.eviction_sum_evicted_counts_stats_name
|
|
1097
|
+
)
|
|
1098
|
+
self.stats_reporter.register_stats(
|
|
1099
|
+
self.eviction_sum_processed_counts_stats_name
|
|
1100
|
+
)
|
|
1101
|
+
self.stats_reporter.register_stats(self.eviction_evict_rate_stats_name)
|
|
1102
|
+
for t in self.feature_table_map:
|
|
1103
|
+
self.stats_reporter.register_stats(
|
|
1104
|
+
f"eviction.feature_table.{t}.evicted_counts"
|
|
1105
|
+
)
|
|
1106
|
+
self.stats_reporter.register_stats(
|
|
1107
|
+
f"eviction.feature_table.{t}.processed_counts"
|
|
1108
|
+
)
|
|
1109
|
+
self.stats_reporter.register_stats(
|
|
1110
|
+
f"eviction.feature_table.{t}.evict_rate"
|
|
1111
|
+
)
|
|
1112
|
+
self.stats_reporter.register_stats(
|
|
1113
|
+
"eviction.feature_table.full_duration_ms"
|
|
1114
|
+
)
|
|
1115
|
+
self.stats_reporter.register_stats(
|
|
1116
|
+
"eviction.feature_table.exec_duration_ms"
|
|
1117
|
+
)
|
|
1118
|
+
self.stats_reporter.register_stats(
|
|
1119
|
+
"eviction.feature_table.dry_run_exec_duration_ms"
|
|
1120
|
+
)
|
|
1121
|
+
self.stats_reporter.register_stats(
|
|
1122
|
+
"eviction.feature_table.exec_div_full_duration_rate"
|
|
1123
|
+
)
|
|
1124
|
+
|
|
1125
|
+
self.bounds_check_version: int = get_bounds_check_version_for_platform()
|
|
1126
|
+
|
|
1127
|
+
self._pg = pg
|
|
1128
|
+
|
|
1129
|
+
@cached_property
def cache_row_dim(self) -> int:
    """
    Physical width of one cache row, in elements of the weights dtype.

    A row always holds the embedding itself (``max_D`` elements). When the
    optimizer state is offloaded alongside the weights and the checkpoint is
    not being loaded without optimizer state, the optimizer state is appended
    after the embedding, padded via ``pad4`` (the original note describes this
    as padding to the nearest 4 elements).
    """
    row_dim = self.max_D
    # For st publish / bulk eval (load_ckpt_without_opt) only the weights are
    # needed, so the optimizer tail is not reserved in that case.
    keep_optimizer_tail = (
        self.enable_optimizer_offloading and not self.load_ckpt_without_opt
    )
    if keep_optimizer_tail:
        # Number of cache_dtype elements needed to store the optimizer state.
        row_dim += pad4(self.optimizer_state_dim)
    return row_dim
|
|
1146
|
+
|
|
1147
|
+
@cached_property
|
|
1148
|
+
def optimizer_state_dim(self) -> int:
|
|
1149
|
+
return int(
|
|
1150
|
+
math.ceil(
|
|
1151
|
+
self.optimizer.state_size_nbytes(
|
|
1152
|
+
self.max_D, self.optimizer_state_dtypes
|
|
1153
|
+
)
|
|
1154
|
+
/ self.weights_precision.as_dtype().itemsize
|
|
1155
|
+
)
|
|
1156
|
+
)
|
|
1157
|
+
|
|
1158
|
+
@property
# pyre-ignore
def ssd_db(self):
    """Return the backend store, first waiting for any pending lazy init.

    Random weights may be written by a background thread kept in
    ``self.lazy_init_thread``. If such a thread is outstanding it is joined
    (and cleared) before the store is handed out, so callers never see a
    partially initialized backend."""
    pending_init = self.lazy_init_thread
    if pending_init is not None:
        pending_init.join()
        self.lazy_init_thread = None
        logging.info("lazy ssd tbe initialization completed, weights are ready")
    return self._ssd_db
|
|
1169
|
+
|
|
1170
|
+
@ssd_db.setter
# pyre-ignore
def ssd_db(self, value):
    """Replace the backend store with ``value``.

    If a lazy-initialization thread is still outstanding, it is joined (and
    cleared) first, so it finishes writing into the current store before that
    store is swapped out."""
    pending_init = self.lazy_init_thread
    if pending_init is not None:
        # Effectively a copy assignment: let the in-flight initialization
        # finish against the existing store, then install the new one.
        pending_init.join()
        self.lazy_init_thread = None
        logging.info(
            "lazy ssd tbe initialization completed, ssd_db will now get overridden"
        )
    self._ssd_db = value
|
|
1185
|
+
|
|
1186
|
+
def _lazy_initialize_ssd_tbe(self) -> None:
|
|
1187
|
+
"""
|
|
1188
|
+
Initialize the SSD TBE with random weights. This function should only be
|
|
1189
|
+
called once at initialization time.
|
|
1190
|
+
"""
|
|
1191
|
+
if self.bulk_init_chunk_size > 0:
|
|
1192
|
+
self.lazy_init_thread = threading.Thread(target=self._insert_all_kv)
|
|
1193
|
+
# pyre-ignore
|
|
1194
|
+
self.lazy_init_thread.start()
|
|
1195
|
+
logging.info(
|
|
1196
|
+
f"lazy ssd tbe initialization started since bulk_init_chunk_size is set to {self.bulk_init_chunk_size}"
|
|
1197
|
+
)
|
|
1198
|
+
else:
|
|
1199
|
+
logging.debug(
|
|
1200
|
+
"bulk_init_chunk_size is not set, skipping lazy initialization"
|
|
1201
|
+
)
|
|
1202
|
+
|
|
1203
|
+
@torch.jit.ignore
|
|
1204
|
+
def _insert_all_kv(self) -> None:
|
|
1205
|
+
"""
|
|
1206
|
+
Populate all rows in the ssd TBE with random weights. Existing keys will
|
|
1207
|
+
be effectively overwritten. This function should only be called once at
|
|
1208
|
+
initailization time.
|
|
1209
|
+
"""
|
|
1210
|
+
self._ssd_db.toggle_compaction(False)
|
|
1211
|
+
row_offset = 0
|
|
1212
|
+
row_count = floor(
|
|
1213
|
+
self.bulk_init_chunk_size
|
|
1214
|
+
/ (self.cache_row_dim * self.weights_precision.as_dtype().itemsize)
|
|
1215
|
+
)
|
|
1216
|
+
total_dim0 = 0
|
|
1217
|
+
for dim0, _ in self.embedding_specs:
|
|
1218
|
+
total_dim0 += dim0
|
|
1219
|
+
|
|
1220
|
+
start_ts = time.time()
|
|
1221
|
+
# TODO: do we have case for non-kvzch ssd with bulk init enabled + optimizer offloading? probably not?
|
|
1222
|
+
# if we have such cases, we should only init the emb dim not the optimizer dim
|
|
1223
|
+
chunk_tensor = torch.empty(
|
|
1224
|
+
row_count,
|
|
1225
|
+
self.cache_row_dim,
|
|
1226
|
+
dtype=self.weights_precision.as_dtype(),
|
|
1227
|
+
device="cuda",
|
|
1228
|
+
)
|
|
1229
|
+
cpu_tensor = torch.empty_like(chunk_tensor, device="cpu")
|
|
1230
|
+
for row_offset in range(0, total_dim0, row_count):
|
|
1231
|
+
actual_dim0 = min(total_dim0 - row_offset, row_count)
|
|
1232
|
+
chunk_tensor.uniform_(
|
|
1233
|
+
self.ssd_uniform_init_lower, self.ssd_uniform_init_upper
|
|
1234
|
+
)
|
|
1235
|
+
cpu_tensor.copy_(chunk_tensor, non_blocking=False)
|
|
1236
|
+
rand_val = cpu_tensor[:actual_dim0, :]
|
|
1237
|
+
# This code is intentionally not calling through the getter property
|
|
1238
|
+
# to avoid the lazy initialization thread from joining with itself.
|
|
1239
|
+
self._ssd_db.set_range_to_storage(rand_val, row_offset, actual_dim0)
|
|
1240
|
+
end_ts = time.time()
|
|
1241
|
+
elapsed = int((end_ts - start_ts) * 1e6)
|
|
1242
|
+
logging.info(
|
|
1243
|
+
f"TBE bulk initialization took {elapsed:_} us, bulk_init_chunk_size={self.bulk_init_chunk_size}, each batch of {row_count} rows, total rows of {total_dim0}"
|
|
1244
|
+
)
|
|
1245
|
+
self._ssd_db.toggle_compaction(True)
|
|
1246
|
+
|
|
1247
|
+
@torch.jit.ignore
|
|
1248
|
+
def _report_duration(
|
|
1249
|
+
self,
|
|
1250
|
+
it_step: int,
|
|
1251
|
+
dur_ms: float,
|
|
1252
|
+
event_name: str,
|
|
1253
|
+
time_unit: str,
|
|
1254
|
+
) -> None:
|
|
1255
|
+
"""
|
|
1256
|
+
Callback function passed into AsyncSeriesTimer, which will be
|
|
1257
|
+
invoked when the last kernel in AsyncSeriesTimer scope is done.
|
|
1258
|
+
Currently this is only used to trace prefetch duration, in which
|
|
1259
|
+
there are 2 streams involved, main stream and eviction stream.
|
|
1260
|
+
This will report the duration of the longer stream to ODS
|
|
1261
|
+
|
|
1262
|
+
Function is not thread safe
|
|
1263
|
+
|
|
1264
|
+
Args:
|
|
1265
|
+
it_step (int): The reporting iteration step
|
|
1266
|
+
dur_ms (float): The duration of the all the kernels within the
|
|
1267
|
+
AsyncSeriesTimer scope in milliseconds
|
|
1268
|
+
event_name (str): The name of the event
|
|
1269
|
+
time_unit (str): The unit of the duration(us or ms)
|
|
1270
|
+
"""
|
|
1271
|
+
recorded_itr, stream_cnt, report_val = self.prefetch_duration_us
|
|
1272
|
+
duration = dur_ms
|
|
1273
|
+
if time_unit == "us":
|
|
1274
|
+
duration = dur_ms * 1000
|
|
1275
|
+
if it_step == recorded_itr:
|
|
1276
|
+
report_val = max(report_val, duration)
|
|
1277
|
+
stream_cnt -= 1
|
|
1278
|
+
else:
|
|
1279
|
+
# reset
|
|
1280
|
+
recorded_itr = it_step
|
|
1281
|
+
report_val = duration
|
|
1282
|
+
stream_cnt = self.prefetch_parallel_stream_cnt
|
|
1283
|
+
self.prefetch_duration_us = (recorded_itr, stream_cnt, report_val)
|
|
1284
|
+
|
|
1285
|
+
if stream_cnt == 1:
|
|
1286
|
+
# this is the last stream, handling ods report
|
|
1287
|
+
# pyre-ignore
|
|
1288
|
+
self.stats_reporter.report_duration(
|
|
1289
|
+
it_step, event_name, report_val, time_unit=time_unit
|
|
1290
|
+
)
|
|
1291
|
+
|
|
1292
|
+
def record_function_via_dummy_profile_factory(
    self,
    use_dummy_profile: bool,
) -> Callable[..., Any]:
    """
    Build the invoker used to call a function, optionally under profiling.

    The returned callable has the shape ``(name, fn, *args, **kwargs)``. It
    always calls ``fn(*args, **kwargs)`` and discards the result (returns
    ``None``).

    With ``use_dummy_profile`` set, the call is wrapped in a
    ``record_function(name)`` scope, and ``self.dummy_profile_tensor`` is
    incremented in place right before and right after ``fn`` runs, giving the
    profiler a small marker op on each side of ``fn``. Without it, ``name``
    is ignored and ``fn`` is simply invoked.

    Args:
        use_dummy_profile (bool): Whether the returned invoker adds the
            ``record_function`` scope and dummy-tensor markers.

    Returns:
        Callable[..., Any]: The invoker described above.
    """

    def _call_plain(
        name: str,
        fn: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        fn(*args, **kwargs)

    if not use_dummy_profile:
        return _call_plain

    def _call_profiled(
        name: str,
        fn: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        with record_function(name):
            # Marker ops bracketing fn so it is easy to spot in a trace.
            self.dummy_profile_tensor.add_(1)
            fn(*args, **kwargs)
            self.dummy_profile_tensor.add_(1)

    return _call_profiled
|
|
1333
|
+
|
|
1334
|
+
def _apply_split(
    self,
    split: SplitState,
    prefix: str,
    dtype: type[torch.dtype],
    enforce_hbm: bool = False,
    make_dev_param: bool = False,
    dev_reshape: Optional[tuple[int, ...]] = None,
) -> None:
    """
    Materialize the tensors described by ``split`` on this module.

    Thin wrapper around ``apply_split_helper``. It supplies two registration
    callbacks: this module's ``register_buffer``, and ``setattr`` bound to
    ``self``. It also passes the current device, ``use_cpu=False`` and
    ``self.feature_table_map``, and forwards the remaining arguments
    positionally and unchanged.

    Which tensors are created, and under which attribute names, is decided
    by ``apply_split_helper``. Presumably the names are derived from
    ``prefix``; confirm against the helper.

    Args:
        split (SplitState): Placement/size description forwarded to the helper.
        prefix (str): Name prefix forwarded to the helper.
        dtype: Element type forwarded to the helper.
        enforce_hbm (bool): Forwarded to the helper.
        make_dev_param (bool): Forwarded to the helper.
        dev_reshape (Optional[tuple[int, ...]]): Forwarded to the helper.
    """
    register_fn = self.register_buffer
    assign_attr = functools.partial(setattr, self)
    use_cpu = False  # always False for this module
    apply_split_helper(
        register_fn,
        assign_attr,
        self.current_device,
        use_cpu,
        self.feature_table_map,
        split,
        prefix,
        dtype,
        enforce_hbm,
        make_dev_param,
        dev_reshape,
    )
|
|
1356
|
+
|
|
1357
|
+
def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:
    """
    Return a page-locked host copy of ``t``.

    A pinned CPU tensor with the same shape and dtype as ``t`` is allocated,
    and a ``non_blocking=True`` copy from ``t`` into it is issued. For a CUDA
    source the copy is asynchronous, so the result must not be read until the
    issuing stream has completed it.

    Args:
        t (torch.Tensor): Source tensor (typically on the GPU).

    Returns:
        torch.Tensor: The pinned host tensor the copy was issued into.
    """
    pinned = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    return pinned.copy_(t, non_blocking=True)
|
|
1361
|
+
|
|
1362
|
+
def to_pinned_cpu_on_stream_wait_on_another_stream(
    self,
    tensors: list[Tensor],
    stream: torch.cuda.Stream,
    stream_to_wait_on: torch.cuda.Stream,
    post_event: Optional[torch.cuda.Event] = None,
) -> list[Tensor]:
    """
    Copy device tensors into freshly pinned host tensors on a side stream.

    The following is issued under ``stream``, in order:

    1. ``stream`` is made to wait for all work currently queued on
       ``stream_to_wait_on``.
    2. For each input tensor in turn, ``record_stream(stream)`` is called on
       it and ``self.to_pinned_cpu`` is issued for it.
    3. If ``post_event`` is given, it is recorded on ``stream`` after every
       copy has been issued, so consumers can wait on it before touching the
       host tensors.

    Args:
        tensors (list[Tensor]): Tensors to transfer.
        stream (Stream): Stream the memory copies run on.
        stream_to_wait_on (Stream): Stream whose queued work must finish first.
        post_event (Optional[Event]): Event recorded once the copies are queued.

    Returns:
        list[Tensor]: Pinned host tensors, one per input, in input order.
    """
    with torch.cuda.stream(stream):
        stream.wait_stream(stream_to_wait_on)

        def _stage(src: Tensor) -> Tensor:
            # Tell the caching allocator that `stream` uses `src`, so its
            # memory is not recycled before the copy on `stream` finishes.
            src.record_stream(stream)
            return self.to_pinned_cpu(src)

        pinned = [_stage(src) for src in tensors]

        if post_event is not None:
            stream.record_event(post_event)
    return pinned
|
|
1394
|
+
|
|
1395
|
+
def evict(
    self,
    rows: Tensor,
    indices_cpu: Tensor,
    actions_count_cpu: Tensor,
    stream: torch.cuda.Stream,
    pre_event: Optional[torch.cuda.Event],
    post_event: Optional[torch.cuda.Event],
    is_rows_uvm: bool,
    name: Optional[str] = "",
    is_bwd: bool = True,
) -> None:
    """
    Write rows back to the SSD store through ``self.ssd_db.set_cuda``.

    Nothing happens when the module is not in training mode, because the
    embedding is treated as frozen. Otherwise the following is done on
    ``stream``, in order:

    1. Wait for ``pre_event``, if given.
    2. Obtain a host-accessible version of ``rows``: ``rows`` itself when it
       is UVM, otherwise a pinned host copy.
    3. Mark ``rows`` as used by ``stream``.
    4. Enqueue ``set_cuda`` via ``self.record_function_via_dummy_profile``.
    5. Record ``post_event``, if given.

    Args:
        rows (Tensor): 2D tensor holding the rows to evict.
        indices_cpu (Tensor): 1D CPU tensor with the row index each row is
            written to.
        actions_count_cpu (Tensor): Scalar tensor with the number of rows
            ``set_cuda`` has to process.
        stream (Stream): Stream on which the D->H copy and ``set_cuda`` are
            issued. The original docs describe ``set_cuda`` as a host function
            synchronized with this stream via cudaStreamAddCallback.
        pre_event (Optional[Event]): Event the stream waits on first.
        post_event (Optional[Event]): Event recorded on the stream once the
            eviction has been enqueued.
        is_rows_uvm (bool): True when ``rows`` is a UVM tensor (accessible from
            both host and device), in which case no extra host copy is made.
        name (Optional[str]): Suffix used in the profiler labels.
        is_bwd (bool): Whether the eviction happens during backward; forwarded
            to ``set_cuda``.

    Returns:
        None
    """
    if not self.training:
        # Frozen embedding: nothing is written back outside training.
        return

    with record_function(f"## ssd_evict_{name} ##"), torch.cuda.stream(stream):
        if pre_event is not None:
            stream.wait_event(pre_event)

        if is_rows_uvm:
            rows_cpu = rows
        else:
            rows_cpu = self.to_pinned_cpu(rows)

        # Keep `rows` alive until the work queued on `stream` has used it.
        rows.record_stream(stream)

        self.record_function_via_dummy_profile(
            f"## ssd_set_{name} ##",
            self.ssd_db.set_cuda,
            indices_cpu,
            rows_cpu,
            actions_count_cpu,
            self.timestep,
            is_bwd,
        )

        if post_event is not None:
            stream.record_event(post_event)
|
|
1453
|
+
|
|
1454
|
+
def raw_embedding_stream_sync(
    self,
    stream: torch.cuda.Stream,
    pre_event: Optional[torch.cuda.Event],
    post_event: Optional[torch.cuda.Event],
    name: Optional[str] = "",
) -> None:
    """
    Enqueue ``self.ssd_db.stream_sync_cuda`` on ``stream``.

    Per the original documentation, this is a blocking wait for the copies
    of tensors queued for streaming, ensuring their buffers are not
    overwritten.

    On ``stream``, it optionally waits for ``pre_event``, then invokes
    ``stream_sync_cuda`` (no arguments) through
    ``self.record_function_via_dummy_profile``, and finally records
    ``post_event`` if given.

    Args:
        stream (Stream): Stream the host-side sync is synchronized with.
            According to the original docs this uses cudaStreamAddCallback.
        pre_event (Optional[Event]): Event the stream waits on first.
        post_event (Optional[Event]): Event recorded after the sync is enqueued.
        name (Optional[str]): Suffix used in the profiler labels.

    Returns:
        None
    """
    outer_label = f"## ssd_stream_{name} ##"
    inner_label = f"## ssd_stream_sync_{name} ##"
    with record_function(outer_label), torch.cuda.stream(stream):
        if pre_event is not None:
            stream.wait_event(pre_event)

        self.record_function_via_dummy_profile(
            inner_label,
            self.ssd_db.stream_sync_cuda,
        )

        if post_event is not None:
            stream.record_event(post_event)
|
|
1487
|
+
|
|
1488
|
+
def raw_embedding_stream(
    self,
    rows: Tensor,
    indices_cpu: Tensor,
    actions_count_cpu: Tensor,
    stream: torch.cuda.Stream,
    pre_event: Optional[torch.cuda.Event],
    post_event: Optional[torch.cuda.Event],
    is_rows_uvm: bool,
    blocking_tensor_copy: bool = True,
    name: Optional[str] = "",
) -> None:
    """
    Hand rows to ``self.ssd_db.stream_cuda`` for streaming to a remote service.

    The following is done on ``stream``, in order:

    1. Wait for ``pre_event``, if given.
    2. Obtain a host-accessible version of ``rows``: ``rows`` itself when it
       is UVM, otherwise a pinned host copy.
    3. Mark ``rows`` as used by ``stream``.
    4. Enqueue ``stream_cuda`` via ``self.record_function_via_dummy_profile``.
    5. Record ``post_event``, if given.

    Unlike ``evict``, this does not check ``self.training``.

    Args:
        rows (Tensor): 2D tensor holding the rows to stream.
        indices_cpu (Tensor): 1D CPU tensor with the row index of each row.
        actions_count_cpu (Tensor): Scalar tensor with the number of rows
            ``stream_cuda`` has to process.
        stream (Stream): Stream on which the D->H copy and ``stream_cuda`` are
            issued. The original docs describe a host function synchronized
            with this stream via cudaStreamAddCallback.
        pre_event (Optional[Event]): Event the stream waits on first.
        post_event (Optional[Event]): Event recorded after streaming is enqueued.
        is_rows_uvm (bool): True when ``rows`` is a UVM tensor (accessible from
            both host and device), in which case no extra host copy is made.
        blocking_tensor_copy (bool): Forwarded to ``stream_cuda``. Presumably
            it selects whether the tensor copy there is blocking; confirm
            against ``ssd_db``.
        name (Optional[str]): Suffix used in the profiler labels.

    Returns:
        None
    """
    label = f"## ssd_stream_{name} ##"
    with record_function(label), torch.cuda.stream(stream):
        if pre_event is not None:
            stream.wait_event(pre_event)

        if is_rows_uvm:
            rows_cpu = rows
        else:
            rows_cpu = self.to_pinned_cpu(rows)

        # Keep `rows` alive until the work queued on `stream` has used it.
        rows.record_stream(stream)

        self.record_function_via_dummy_profile(
            label,
            self.ssd_db.stream_cuda,
            indices_cpu,
            rows_cpu,
            actions_count_cpu,
            blocking_tensor_copy,
        )

        if post_event is not None:
            stream.record_event(post_event)
|
|
1542
|
+
|
|
1543
|
+
def _evict_from_scratch_pad(self, grad: Tensor) -> None:
|
|
1544
|
+
"""
|
|
1545
|
+
Evict conflict missed rows from a scratch pad
|
|
1546
|
+
(`inserted_rows`) on the `ssd_eviction_stream`. This is a hook
|
|
1547
|
+
that is invoked right after TBE backward.
|
|
1548
|
+
|
|
1549
|
+
Conflict missed indices are specified in
|
|
1550
|
+
`post_bwd_evicted_indices_cpu`. Indices that are not -1 and
|
|
1551
|
+
their positions < `actions_count_cpu` (i.e., rows
|
|
1552
|
+
`post_bwd_evicted_indices_cpu[:actions_count_cpu] != -1` in
|
|
1553
|
+
post_bwd_evicted_indices_cpu) will be evicted.
|
|
1554
|
+
|
|
1555
|
+
Args:
|
|
1556
|
+
grad (Tensor): Unused gradient tensor
|
|
1557
|
+
|
|
1558
|
+
Returns:
|
|
1559
|
+
None
|
|
1560
|
+
"""
|
|
1561
|
+
with record_function("## ssd_evict_from_scratch_pad_pipeline ##"):
|
|
1562
|
+
current_stream = torch.cuda.current_stream()
|
|
1563
|
+
current_stream.record_event(self.ssd_event_backward)
|
|
1564
|
+
|
|
1565
|
+
assert (
|
|
1566
|
+
len(self.ssd_scratch_pad_eviction_data) > 0
|
|
1567
|
+
), "There must be at least one scratch pad"
|
|
1568
|
+
|
|
1569
|
+
(
|
|
1570
|
+
inserted_rows,
|
|
1571
|
+
post_bwd_evicted_indices_cpu,
|
|
1572
|
+
actions_count_cpu,
|
|
1573
|
+
do_evict,
|
|
1574
|
+
) = self.ssd_scratch_pad_eviction_data.pop(0)
|
|
1575
|
+
|
|
1576
|
+
if not do_evict:
|
|
1577
|
+
return
|
|
1578
|
+
|
|
1579
|
+
if self.enable_raw_embedding_streaming:
|
|
1580
|
+
self.raw_embedding_stream(
|
|
1581
|
+
rows=inserted_rows,
|
|
1582
|
+
indices_cpu=post_bwd_evicted_indices_cpu,
|
|
1583
|
+
actions_count_cpu=actions_count_cpu,
|
|
1584
|
+
stream=self.ssd_eviction_stream,
|
|
1585
|
+
pre_event=self.ssd_event_backward,
|
|
1586
|
+
post_event=self.ssd_event_sp_streamed,
|
|
1587
|
+
is_rows_uvm=True,
|
|
1588
|
+
blocking_tensor_copy=True,
|
|
1589
|
+
name="scratch_pad",
|
|
1590
|
+
)
|
|
1591
|
+
self.evict(
|
|
1592
|
+
rows=inserted_rows,
|
|
1593
|
+
indices_cpu=post_bwd_evicted_indices_cpu,
|
|
1594
|
+
actions_count_cpu=actions_count_cpu,
|
|
1595
|
+
stream=self.ssd_eviction_stream,
|
|
1596
|
+
pre_event=self.ssd_event_backward,
|
|
1597
|
+
post_event=self.ssd_event_sp_evict,
|
|
1598
|
+
is_rows_uvm=True,
|
|
1599
|
+
name="scratch_pad",
|
|
1600
|
+
)
|
|
1601
|
+
|
|
1602
|
+
if self.prefetch_stream:
|
|
1603
|
+
self.prefetch_stream.wait_stream(current_stream)
|
|
1604
|
+
|
|
1605
|
+
def _update_cache_counter_and_pointers(
|
|
1606
|
+
self,
|
|
1607
|
+
module: nn.Module,
|
|
1608
|
+
grad_input: Union[tuple[Tensor, ...], Tensor],
|
|
1609
|
+
) -> None:
|
|
1610
|
+
"""
|
|
1611
|
+
Update cache line locking counter and pointers before backward
|
|
1612
|
+
TBE. This is a hook that is called before the backward of TBE
|
|
1613
|
+
|
|
1614
|
+
Update cache line counter:
|
|
1615
|
+
|
|
1616
|
+
We ensure that cache prefetching does not execute concurrently
|
|
1617
|
+
with the backward TBE. Therefore, it is safe to unlock the
|
|
1618
|
+
cache lines used in current iteration before backward TBE.
|
|
1619
|
+
|
|
1620
|
+
Update pointers:
|
|
1621
|
+
|
|
1622
|
+
Now some rows that are used in both the current iteration and
|
|
1623
|
+
the next iteration are moved (1) from the current iteration's
|
|
1624
|
+
scratch pad into the next iteration's scratch pad or (2) from
|
|
1625
|
+
the current iteration's scratch pad into the L1 cache
|
|
1626
|
+
|
|
1627
|
+
To ensure that the TBE backward kernel accesses valid data,
|
|
1628
|
+
here we update the pointers of these rows in the current
|
|
1629
|
+
iteration's `lxu_cache_ptrs` to point to either L1 cache or
|
|
1630
|
+
the next iteration scratch pad
|
|
1631
|
+
|
|
1632
|
+
Args:
|
|
1633
|
+
module (nn.Module): Unused
|
|
1634
|
+
grad_input (Union[Tuple[Tensor, ...], Tensor]): Unused
|
|
1635
|
+
|
|
1636
|
+
Returns:
|
|
1637
|
+
None
|
|
1638
|
+
"""
|
|
1639
|
+
if self.prefetch_stream:
|
|
1640
|
+
# Ensure that prefetch is done
|
|
1641
|
+
torch.cuda.current_stream().wait_stream(self.prefetch_stream)
|
|
1642
|
+
|
|
1643
|
+
assert self.current_iter_data is not None, "current_iter_data must be set"
|
|
1644
|
+
|
|
1645
|
+
curr_data: IterData = self.current_iter_data
|
|
1646
|
+
|
|
1647
|
+
if curr_data.lxu_cache_locations.numel() == 0:
|
|
1648
|
+
return
|
|
1649
|
+
|
|
1650
|
+
with record_function("## ssd_update_cache_counter_and_pointers ##"):
|
|
1651
|
+
# Unlock the cache lines
|
|
1652
|
+
torch.ops.fbgemm.lxu_cache_locking_counter_decrement(
|
|
1653
|
+
self.lxu_cache_locking_counter,
|
|
1654
|
+
curr_data.lxu_cache_locations,
|
|
1655
|
+
)
|
|
1656
|
+
|
|
1657
|
+
# Recompute linear_cache_indices to save memory
|
|
1658
|
+
linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
|
|
1659
|
+
self.hash_size_cumsum,
|
|
1660
|
+
curr_data.indices,
|
|
1661
|
+
curr_data.offsets,
|
|
1662
|
+
curr_data.B_offsets,
|
|
1663
|
+
curr_data.max_B,
|
|
1664
|
+
)
|
|
1665
|
+
(
|
|
1666
|
+
linear_unique_indices,
|
|
1667
|
+
linear_unique_indices_length,
|
|
1668
|
+
unique_indices_count,
|
|
1669
|
+
linear_index_inverse_indices,
|
|
1670
|
+
) = get_unique_indices_v2(
|
|
1671
|
+
linear_cache_indices,
|
|
1672
|
+
self.total_hash_size,
|
|
1673
|
+
compute_count=True,
|
|
1674
|
+
compute_inverse_indices=True,
|
|
1675
|
+
)
|
|
1676
|
+
unique_indices_count_cumsum = torch.ops.fbgemm.asynchronous_complete_cumsum(
|
|
1677
|
+
unique_indices_count
|
|
1678
|
+
)
|
|
1679
|
+
|
|
1680
|
+
# Look up the cache to check which indices in the scratch
|
|
1681
|
+
# pad are moved to L1
|
|
1682
|
+
torch.ops.fbgemm.lxu_cache_lookup(
|
|
1683
|
+
linear_cache_indices,
|
|
1684
|
+
self.lxu_cache_state,
|
|
1685
|
+
self.total_hash_size,
|
|
1686
|
+
gather_cache_stats=False, # not collecting cache stats
|
|
1687
|
+
lxu_cache_locations_output=curr_data.lxu_cache_locations,
|
|
1688
|
+
)
|
|
1689
|
+
|
|
1690
|
+
if len(self.ssd_location_update_data) == 0:
|
|
1691
|
+
return
|
|
1692
|
+
|
|
1693
|
+
(sp_curr_next_map, inserted_rows_next) = self.ssd_location_update_data.pop(
|
|
1694
|
+
0
|
|
1695
|
+
)
|
|
1696
|
+
|
|
1697
|
+
# Update poitners
|
|
1698
|
+
torch.ops.fbgemm.ssd_update_row_addrs(
|
|
1699
|
+
ssd_row_addrs_curr=curr_data.lxu_cache_ptrs,
|
|
1700
|
+
inserted_ssd_weights_curr_next_map=sp_curr_next_map,
|
|
1701
|
+
lxu_cache_locations_curr=curr_data.lxu_cache_locations,
|
|
1702
|
+
linear_index_inverse_indices_curr=linear_index_inverse_indices,
|
|
1703
|
+
unique_indices_count_cumsum_curr=unique_indices_count_cumsum,
|
|
1704
|
+
cache_set_inverse_indices_curr=curr_data.cache_set_inverse_indices,
|
|
1705
|
+
lxu_cache_weights=self.lxu_cache_weights,
|
|
1706
|
+
inserted_ssd_weights_next=inserted_rows_next,
|
|
1707
|
+
unique_indices_length_curr=curr_data.actions_count_gpu,
|
|
1708
|
+
)
|
|
1709
|
+
|
|
1710
|
+
def _update_feature_score_metadata(
|
|
1711
|
+
self,
|
|
1712
|
+
linear_cache_indices: Tensor,
|
|
1713
|
+
weights: Tensor,
|
|
1714
|
+
d2h_stream: torch.cuda.Stream,
|
|
1715
|
+
write_stream: torch.cuda.Stream,
|
|
1716
|
+
pre_event_for_write: torch.cuda.Event,
|
|
1717
|
+
post_event: Optional[torch.cuda.Event] = None,
|
|
1718
|
+
) -> None:
|
|
1719
|
+
"""
|
|
1720
|
+
Write feature score metadata to DRAM
|
|
1721
|
+
|
|
1722
|
+
This method performs D2H copy on d2h_stream, then writes to DRAM on write_stream.
|
|
1723
|
+
The caller is responsible for ensuring d2h_stream doesn't compete with other D2H operations.
|
|
1724
|
+
|
|
1725
|
+
Args:
|
|
1726
|
+
linear_cache_indices: GPU tensor containing cache indices
|
|
1727
|
+
weights: GPU tensor containing feature scores
|
|
1728
|
+
d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
|
|
1729
|
+
write_stream: Stream for metadata write operation
|
|
1730
|
+
pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
|
|
1731
|
+
post_event: Event to record when the operation is done
|
|
1732
|
+
"""
|
|
1733
|
+
# Start D2H copy on d2h_stream
|
|
1734
|
+
with torch.cuda.stream(d2h_stream):
|
|
1735
|
+
# Record streams to prevent premature deallocation
|
|
1736
|
+
linear_cache_indices.record_stream(d2h_stream)
|
|
1737
|
+
weights.record_stream(d2h_stream)
|
|
1738
|
+
# Do the D2H copy
|
|
1739
|
+
linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
|
|
1740
|
+
score_weights_cpu = self.to_pinned_cpu(weights)
|
|
1741
|
+
|
|
1742
|
+
# Write feature score metadata to DRAM
|
|
1743
|
+
with record_function("## ssd_write_feature_score_metadata ##"):
|
|
1744
|
+
with torch.cuda.stream(write_stream):
|
|
1745
|
+
write_stream.wait_event(pre_event_for_write)
|
|
1746
|
+
write_stream.wait_stream(d2h_stream)
|
|
1747
|
+
self.record_function_via_dummy_profile(
|
|
1748
|
+
"## ssd_write_feature_score_metadata ##",
|
|
1749
|
+
self.ssd_db.set_feature_score_metadata_cuda,
|
|
1750
|
+
linear_cache_indices_cpu,
|
|
1751
|
+
torch.tensor(
|
|
1752
|
+
[score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
|
|
1753
|
+
),
|
|
1754
|
+
score_weights_cpu,
|
|
1755
|
+
)
|
|
1756
|
+
|
|
1757
|
+
if post_event is not None:
|
|
1758
|
+
write_stream.record_event(post_event)
|
|
1759
|
+
|
|
1760
|
+
def prefetch(
    self,
    indices: Tensor,
    offsets: Tensor,
    weights: Optional[Tensor] = None,  # todo: need to update caller
    forward_stream: Optional[torch.cuda.Stream] = None,
    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
) -> None:
    """
    Public prefetch entry point: normalize the inputs and run ``_prefetch``.

    If no prefetch stream has been chosen yet and ``forward_stream`` is
    given, the current CUDA stream is adopted as ``self.prefetch_stream``.
    It is asserted to differ from ``forward_stream``.

    ``indices`` and ``offsets`` are then marked as used by the current
    stream and passed through ``self.prepare_inputs`` (with
    ``per_sample_weights=None``). The prepared indices and offsets, the VBE
    metadata, ``weights`` and ``forward_stream`` are then handed to
    ``self._prefetch``.

    Args:
        indices (Tensor): Lookup indices.
        offsets (Tensor): Offsets into ``indices``.
        weights (Optional[Tensor]): Forwarded unchanged to ``_prefetch``.
        forward_stream (Optional[Stream]): Stream the forward pass runs on.
        batch_size_per_feature_per_rank (Optional[list[list[int]]]):
            Forwarded to ``prepare_inputs`` (variable-batch-size metadata).
    """
    if self.prefetch_stream is None and forward_stream is not None:
        # Adopt the stream prefetch is being issued on as the prefetch stream.
        self.prefetch_stream = torch.cuda.current_stream()
        assert (
            self.prefetch_stream != forward_stream
        ), "prefetch_stream and forward_stream should not be the same stream"

    issuing_stream = torch.cuda.current_stream()
    # Record tensors on the current stream
    for src in (indices, offsets):
        src.record_stream(issuing_stream)

    prepared_indices, prepared_offsets, _, vbe_metadata = self.prepare_inputs(
        indices,
        offsets,
        per_sample_weights=None,
        batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
    )

    self._prefetch(
        prepared_indices,
        prepared_offsets,
        weights,
        vbe_metadata,
        forward_stream,
    )
|
|
1794
|
+
|
|
1795
|
+
def _prefetch( # noqa C901
|
|
1796
|
+
self,
|
|
1797
|
+
indices: Tensor,
|
|
1798
|
+
offsets: Tensor,
|
|
1799
|
+
weights: Optional[Tensor] = None,
|
|
1800
|
+
vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
|
|
1801
|
+
forward_stream: Optional[torch.cuda.Stream] = None,
|
|
1802
|
+
) -> None:
|
|
1803
|
+
# Wait for any ongoing direct_write_embedding operations to complete
|
|
1804
|
+
# Moving this from forward() to _prefetch() is more logical as direct_write
|
|
1805
|
+
# operations affect the same cache structures that prefetch interacts with
|
|
1806
|
+
current_stream = torch.cuda.current_stream()
|
|
1807
|
+
if self._embedding_cache_mode:
|
|
1808
|
+
current_stream.wait_event(self.direct_write_l1_complete_event)
|
|
1809
|
+
current_stream.wait_event(self.direct_write_sp_complete_event)
|
|
1810
|
+
|
|
1811
|
+
B_offsets = None
|
|
1812
|
+
max_B = -1
|
|
1813
|
+
if vbe_metadata is not None:
|
|
1814
|
+
B_offsets = vbe_metadata.B_offsets
|
|
1815
|
+
max_B = vbe_metadata.max_B
|
|
1816
|
+
|
|
1817
|
+
with record_function("## ssd_prefetch {} ##".format(self.timestep)):
|
|
1818
|
+
if self.gather_ssd_cache_stats:
|
|
1819
|
+
self.local_ssd_cache_stats.zero_()
|
|
1820
|
+
|
|
1821
|
+
# Linearize indices
|
|
1822
|
+
linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
|
|
1823
|
+
self.hash_size_cumsum,
|
|
1824
|
+
indices,
|
|
1825
|
+
offsets,
|
|
1826
|
+
B_offsets,
|
|
1827
|
+
max_B,
|
|
1828
|
+
)
|
|
1829
|
+
|
|
1830
|
+
self.timestep += 1
|
|
1831
|
+
self.timesteps_prefetched.append(self.timestep)
|
|
1832
|
+
|
|
1833
|
+
# Lookup and virtually insert indices into L1. After this operator,
|
|
1834
|
+
# we know:
|
|
1835
|
+
# (1) which cache lines can be evicted
|
|
1836
|
+
# (2) which rows are already in cache (hit)
|
|
1837
|
+
# (3) which rows are missed and can be inserted later (missed, but
|
|
1838
|
+
# not conflict missed)
|
|
1839
|
+
# (4) which rows are missed but CANNOT be inserted later (conflict
|
|
1840
|
+
# missed)
|
|
1841
|
+
(
|
|
1842
|
+
inserted_indices,
|
|
1843
|
+
evicted_indices,
|
|
1844
|
+
assigned_cache_slots,
|
|
1845
|
+
actions_count_gpu,
|
|
1846
|
+
linear_index_inverse_indices,
|
|
1847
|
+
unique_indices_count_cumsum,
|
|
1848
|
+
cache_set_inverse_indices,
|
|
1849
|
+
unique_indices_length,
|
|
1850
|
+
) = torch.ops.fbgemm.ssd_cache_populate_actions(
|
|
1851
|
+
linear_cache_indices,
|
|
1852
|
+
self.total_hash_size,
|
|
1853
|
+
self.lxu_cache_state,
|
|
1854
|
+
self.timestep,
|
|
1855
|
+
1, # for now assume prefetch_dist == 1
|
|
1856
|
+
self.lru_state,
|
|
1857
|
+
self.gather_ssd_cache_stats,
|
|
1858
|
+
self.local_ssd_cache_stats,
|
|
1859
|
+
lock_cache_line=self.prefetch_pipeline,
|
|
1860
|
+
lxu_cache_locking_counter=self.lxu_cache_locking_counter,
|
|
1861
|
+
)
|
|
1862
|
+
|
|
1863
|
+
# Compute cache locations (rows that are hit are missed but can be
|
|
1864
|
+
# inserted will have cache locations != -1)
|
|
1865
|
+
with record_function("## ssd_tbe_lxu_cache_lookup ##"):
|
|
1866
|
+
lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
|
|
1867
|
+
linear_cache_indices,
|
|
1868
|
+
self.lxu_cache_state,
|
|
1869
|
+
self.total_hash_size,
|
|
1870
|
+
self.gather_ssd_cache_stats,
|
|
1871
|
+
self.local_ssd_cache_stats,
|
|
1872
|
+
)
|
|
1873
|
+
|
|
1874
|
+
# Defrag indices based on evicted_indices (removing -1 and making
|
|
1875
|
+
# the non -1 elements contiguous). We need to do this because the
|
|
1876
|
+
# number of rows in `lxu_cache_evicted_weights` might be smaller
|
|
1877
|
+
# than the number of elements in `evicted_indices`. Without this
|
|
1878
|
+
# step, we can run into the index out of bound issue
|
|
1879
|
+
current_stream.wait_event(self.ssd_event_cache_evict)
|
|
1880
|
+
torch.ops.fbgemm.compact_indices(
|
|
1881
|
+
compact_indices=[
|
|
1882
|
+
self.lxu_cache_evicted_indices,
|
|
1883
|
+
self.lxu_cache_evicted_slots,
|
|
1884
|
+
],
|
|
1885
|
+
compact_count=self.lxu_cache_evicted_count,
|
|
1886
|
+
indices=[evicted_indices, assigned_cache_slots],
|
|
1887
|
+
masks=torch.where(evicted_indices != -1, 1, 0),
|
|
1888
|
+
count=actions_count_gpu,
|
|
1889
|
+
)
|
|
1890
|
+
has_raw_embedding_streaming = False
|
|
1891
|
+
if self.enable_raw_embedding_streaming:
|
|
1892
|
+
# when pipelining is enabled
|
|
1893
|
+
# prefetch in iter i happens before the backward sparse in iter i - 1
|
|
1894
|
+
# so embeddings for iter i - 1's changed ids are not updated.
|
|
1895
|
+
# so we can only fetch the indices from the iter i - 2
|
|
1896
|
+
# when pipelining is disabled
|
|
1897
|
+
# prefetch in iter i happens before forward iter i
|
|
1898
|
+
# so we can get the iter i - 1's changed ids safely.
|
|
1899
|
+
target_prev_iter = 1
|
|
1900
|
+
if self.prefetch_pipeline:
|
|
1901
|
+
target_prev_iter = 2
|
|
1902
|
+
if len(self.prefetched_info) > (target_prev_iter - 1):
|
|
1903
|
+
with record_function(
|
|
1904
|
+
"## ssd_lookup_prefetched_rows {} {} ##".format(
|
|
1905
|
+
self.timestep, self.tbe_unique_id
|
|
1906
|
+
)
|
|
1907
|
+
):
|
|
1908
|
+
# wait for the copy to finish before overwriting the buffer
|
|
1909
|
+
self.raw_embedding_stream_sync(
|
|
1910
|
+
stream=self.ssd_eviction_stream,
|
|
1911
|
+
pre_event=self.ssd_event_cache_streamed,
|
|
1912
|
+
post_event=self.ssd_event_cache_streaming_synced,
|
|
1913
|
+
name="cache_update",
|
|
1914
|
+
)
|
|
1915
|
+
current_stream.wait_event(self.ssd_event_cache_streaming_synced)
|
|
1916
|
+
(updated_indices, updated_counts_gpu) = (
|
|
1917
|
+
self.prefetched_info.pop(0)
|
|
1918
|
+
)
|
|
1919
|
+
self.lxu_cache_updated_indices[: updated_indices.size(0)].copy_(
|
|
1920
|
+
updated_indices,
|
|
1921
|
+
non_blocking=True,
|
|
1922
|
+
)
|
|
1923
|
+
self.lxu_cache_updated_count[:1].copy_(
|
|
1924
|
+
updated_counts_gpu, non_blocking=True
|
|
1925
|
+
)
|
|
1926
|
+
has_raw_embedding_streaming = True
|
|
1927
|
+
|
|
1928
|
+
with record_function(
|
|
1929
|
+
"## ssd_save_prefetched_rows {} {} ##".format(
|
|
1930
|
+
self.timestep, self.tbe_unique_id
|
|
1931
|
+
)
|
|
1932
|
+
):
|
|
1933
|
+
masked_updated_indices = torch.where(
|
|
1934
|
+
torch.where(lxu_cache_locations != -1, True, False),
|
|
1935
|
+
linear_cache_indices,
|
|
1936
|
+
-1,
|
|
1937
|
+
)
|
|
1938
|
+
|
|
1939
|
+
(
|
|
1940
|
+
uni_updated_indices,
|
|
1941
|
+
uni_updated_indices_length,
|
|
1942
|
+
) = get_unique_indices_v2(
|
|
1943
|
+
masked_updated_indices,
|
|
1944
|
+
self.total_hash_size,
|
|
1945
|
+
compute_count=False,
|
|
1946
|
+
compute_inverse_indices=False,
|
|
1947
|
+
)
|
|
1948
|
+
assert uni_updated_indices is not None
|
|
1949
|
+
assert uni_updated_indices_length is not None
|
|
1950
|
+
# The unique indices has 1 more -1 element than needed,
|
|
1951
|
+
# which might make the tensor length go out of range
|
|
1952
|
+
# compared to the pre-allocated buffer.
|
|
1953
|
+
unique_len = min(
|
|
1954
|
+
self.lxu_cache_weights.size(0),
|
|
1955
|
+
uni_updated_indices.size(0),
|
|
1956
|
+
)
|
|
1957
|
+
self.prefetched_info.append(
|
|
1958
|
+
(
|
|
1959
|
+
uni_updated_indices.narrow(0, 0, unique_len),
|
|
1960
|
+
uni_updated_indices_length.clamp(max=unique_len),
|
|
1961
|
+
)
|
|
1962
|
+
)
|
|
1963
|
+
|
|
1964
|
+
with record_function("## ssd_d2h_inserted_indices ##"):
|
|
1965
|
+
# Transfer actions_count and insert_indices right away to
|
|
1966
|
+
# incrase an overlap opportunity
|
|
1967
|
+
actions_count_cpu, inserted_indices_cpu = (
|
|
1968
|
+
self.to_pinned_cpu_on_stream_wait_on_another_stream(
|
|
1969
|
+
tensors=[
|
|
1970
|
+
actions_count_gpu,
|
|
1971
|
+
inserted_indices,
|
|
1972
|
+
],
|
|
1973
|
+
stream=self.ssd_memcpy_stream,
|
|
1974
|
+
stream_to_wait_on=current_stream,
|
|
1975
|
+
post_event=self.ssd_event_get_inputs_cpy,
|
|
1976
|
+
)
|
|
1977
|
+
)
|
|
1978
|
+
|
|
1979
|
+
# Copy rows to be evicted into a separate buffer (will be evicted
|
|
1980
|
+
# later in the prefetch step)
|
|
1981
|
+
with record_function("## ssd_compute_evicted_rows ##"):
|
|
1982
|
+
torch.ops.fbgemm.masked_index_select(
|
|
1983
|
+
self.lxu_cache_evicted_weights,
|
|
1984
|
+
self.lxu_cache_evicted_slots,
|
|
1985
|
+
self.lxu_cache_weights,
|
|
1986
|
+
self.lxu_cache_evicted_count,
|
|
1987
|
+
)
|
|
1988
|
+
|
|
1989
|
+
# Allocation a scratch pad for the current iteration. The scratch
|
|
1990
|
+
# pad is a UVA tensor
|
|
1991
|
+
inserted_rows_shape = (assigned_cache_slots.numel(), self.cache_row_dim)
|
|
1992
|
+
if linear_cache_indices.numel() > 0:
|
|
1993
|
+
inserted_rows = torch.ops.fbgemm.new_unified_tensor(
|
|
1994
|
+
torch.zeros(
|
|
1995
|
+
1,
|
|
1996
|
+
device=self.current_device,
|
|
1997
|
+
dtype=self.lxu_cache_weights.dtype,
|
|
1998
|
+
),
|
|
1999
|
+
inserted_rows_shape,
|
|
2000
|
+
is_host_mapped=self.uvm_host_mapped,
|
|
2001
|
+
)
|
|
2002
|
+
else:
|
|
2003
|
+
inserted_rows = torch.empty(
|
|
2004
|
+
inserted_rows_shape,
|
|
2005
|
+
dtype=self.lxu_cache_weights.dtype,
|
|
2006
|
+
device=self.current_device,
|
|
2007
|
+
)
|
|
2008
|
+
|
|
2009
|
+
if self.prefetch_pipeline and len(self.ssd_scratch_pads) > 0:
|
|
2010
|
+
# Look up all missed indices from the previous iteration's
|
|
2011
|
+
# scratch pad (do this only if pipeline prefetching is being
|
|
2012
|
+
# used)
|
|
2013
|
+
with record_function("## ssd_lookup_scratch_pad ##"):
|
|
2014
|
+
# Get the previous scratch pad
|
|
2015
|
+
(
|
|
2016
|
+
inserted_rows_prev,
|
|
2017
|
+
post_bwd_evicted_indices_cpu_prev,
|
|
2018
|
+
actions_count_cpu_prev,
|
|
2019
|
+
) = self.ssd_scratch_pads.pop(0)
|
|
2020
|
+
|
|
2021
|
+
# Inserted indices that are found in the scratch pad
|
|
2022
|
+
# from the previous iteration
|
|
2023
|
+
sp_prev_curr_map_cpu = torch.empty(
|
|
2024
|
+
inserted_indices_cpu.shape,
|
|
2025
|
+
dtype=inserted_indices_cpu.dtype,
|
|
2026
|
+
pin_memory=True,
|
|
2027
|
+
)
|
|
2028
|
+
|
|
2029
|
+
# Conflict missed indices from the previous iteration that
|
|
2030
|
+
# overlap with the current iterations's inserted indices
|
|
2031
|
+
sp_curr_prev_map_cpu = torch.empty(
|
|
2032
|
+
post_bwd_evicted_indices_cpu_prev.shape,
|
|
2033
|
+
dtype=torch.int,
|
|
2034
|
+
pin_memory=True,
|
|
2035
|
+
).fill_(-1)
|
|
2036
|
+
|
|
2037
|
+
# Ensure that the necessary D2H transfers are done
|
|
2038
|
+
current_stream.wait_event(self.ssd_event_get_inputs_cpy)
|
|
2039
|
+
# Ensure that the previous iteration's scratch pad indices
|
|
2040
|
+
# insertion is complete
|
|
2041
|
+
current_stream.wait_event(self.ssd_event_sp_idxq_insert)
|
|
2042
|
+
|
|
2043
|
+
# Before entering this function: inserted_indices_cpu
|
|
2044
|
+
# contains all linear indices that are missed from the
|
|
2045
|
+
# L1 cache
|
|
2046
|
+
#
|
|
2047
|
+
# After this function: inserted indices that are found
|
|
2048
|
+
# in the scratch pad from the previous iteration are
|
|
2049
|
+
# stored in sp_prev_curr_map_cpu, while the rests are
|
|
2050
|
+
# stored in inserted_indices_cpu
|
|
2051
|
+
#
|
|
2052
|
+
# An invalid index is -1 or its position >
|
|
2053
|
+
# actions_count_cpu
|
|
2054
|
+
self.record_function_via_dummy_profile(
|
|
2055
|
+
"## ssd_lookup_mask_and_pop_front ##",
|
|
2056
|
+
self.scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda,
|
|
2057
|
+
sp_prev_curr_map_cpu, # scratch_pad_prev_curr_map
|
|
2058
|
+
sp_curr_prev_map_cpu, # scratch_pad_curr_prev_map
|
|
2059
|
+
post_bwd_evicted_indices_cpu_prev, # scratch_pad_indices_prev
|
|
2060
|
+
inserted_indices_cpu, # inserted_indices_curr
|
|
2061
|
+
actions_count_cpu, # count_curr
|
|
2062
|
+
)
|
|
2063
|
+
|
|
2064
|
+
# Mark scratch pad index queue lookup completion
|
|
2065
|
+
current_stream.record_event(self.ssd_event_sp_idxq_lookup)
|
|
2066
|
+
|
|
2067
|
+
# Transfer sp_prev_curr_map_cpu to GPU
|
|
2068
|
+
sp_prev_curr_map_gpu = sp_prev_curr_map_cpu.cuda(non_blocking=True)
|
|
2069
|
+
# Transfer sp_curr_prev_map_cpu to GPU
|
|
2070
|
+
sp_curr_prev_map_gpu = sp_curr_prev_map_cpu.cuda(non_blocking=True)
|
|
2071
|
+
|
|
2072
|
+
# Previously actions_count_gpu was recorded on another
|
|
2073
|
+
# stream. Thus, we need to record it on this stream
|
|
2074
|
+
actions_count_gpu.record_stream(current_stream)
|
|
2075
|
+
|
|
2076
|
+
# Copy data from the previous iteration's scratch pad to
|
|
2077
|
+
# the current iteration's scratch pad
|
|
2078
|
+
torch.ops.fbgemm.masked_index_select(
|
|
2079
|
+
inserted_rows,
|
|
2080
|
+
sp_prev_curr_map_gpu,
|
|
2081
|
+
inserted_rows_prev,
|
|
2082
|
+
actions_count_gpu,
|
|
2083
|
+
use_pipeline=self.prefetch_pipeline,
|
|
2084
|
+
)
|
|
2085
|
+
|
|
2086
|
+
# Record the tensors that will be pushed into a queue
|
|
2087
|
+
# on the forward stream
|
|
2088
|
+
if forward_stream:
|
|
2089
|
+
sp_curr_prev_map_gpu.record_stream(forward_stream)
|
|
2090
|
+
|
|
2091
|
+
# Store info for evicting the previous iteration's
|
|
2092
|
+
# scratch pad after the corresponding backward pass is
|
|
2093
|
+
# done
|
|
2094
|
+
if self.training:
|
|
2095
|
+
self.ssd_location_update_data.append(
|
|
2096
|
+
(
|
|
2097
|
+
sp_curr_prev_map_gpu,
|
|
2098
|
+
inserted_rows,
|
|
2099
|
+
)
|
|
2100
|
+
)
|
|
2101
|
+
|
|
2102
|
+
# Ensure the previous iterations eviction is complete
|
|
2103
|
+
current_stream.wait_event(self.ssd_event_sp_evict)
|
|
2104
|
+
# Ensure that D2H is done
|
|
2105
|
+
current_stream.wait_event(self.ssd_event_get_inputs_cpy)
|
|
2106
|
+
|
|
2107
|
+
if self.enable_raw_embedding_streaming and has_raw_embedding_streaming:
|
|
2108
|
+
current_stream.wait_event(self.ssd_event_sp_streamed)
|
|
2109
|
+
with record_function(
|
|
2110
|
+
"## ssd_compute_updated_rows {} {} ##".format(
|
|
2111
|
+
self.timestep, self.tbe_unique_id
|
|
2112
|
+
)
|
|
2113
|
+
):
|
|
2114
|
+
# cache rows that are changed in the previous iteration
|
|
2115
|
+
updated_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
|
|
2116
|
+
self.lxu_cache_updated_indices,
|
|
2117
|
+
self.lxu_cache_state,
|
|
2118
|
+
self.total_hash_size,
|
|
2119
|
+
self.gather_ssd_cache_stats,
|
|
2120
|
+
self.local_ssd_cache_stats,
|
|
2121
|
+
)
|
|
2122
|
+
torch.ops.fbgemm.masked_index_select(
|
|
2123
|
+
self.lxu_cache_updated_weights,
|
|
2124
|
+
updated_cache_locations,
|
|
2125
|
+
self.lxu_cache_weights,
|
|
2126
|
+
self.lxu_cache_updated_count,
|
|
2127
|
+
)
|
|
2128
|
+
current_stream.record_event(self.ssd_event_cache_streaming_computed)
|
|
2129
|
+
|
|
2130
|
+
self.raw_embedding_stream(
|
|
2131
|
+
rows=self.lxu_cache_updated_weights,
|
|
2132
|
+
indices_cpu=self.lxu_cache_updated_indices,
|
|
2133
|
+
actions_count_cpu=self.lxu_cache_updated_count,
|
|
2134
|
+
stream=self.ssd_eviction_stream,
|
|
2135
|
+
pre_event=self.ssd_event_cache_streaming_computed,
|
|
2136
|
+
post_event=self.ssd_event_cache_streamed,
|
|
2137
|
+
is_rows_uvm=True,
|
|
2138
|
+
blocking_tensor_copy=False,
|
|
2139
|
+
name="cache_update",
|
|
2140
|
+
)
|
|
2141
|
+
|
|
2142
|
+
if self.gather_ssd_cache_stats:
|
|
2143
|
+
# call to collect past SSD IO dur right before next rocksdb IO
|
|
2144
|
+
|
|
2145
|
+
self.ssd_cache_stats = torch.add(
|
|
2146
|
+
self.ssd_cache_stats, self.local_ssd_cache_stats
|
|
2147
|
+
)
|
|
2148
|
+
# only report metrics from rank0 to avoid flooded logging
|
|
2149
|
+
if dist.get_rank() == 0:
|
|
2150
|
+
self._report_kv_backend_stats()
|
|
2151
|
+
|
|
2152
|
+
# May trigger eviction if free mem trigger mode enabled before get cuda
|
|
2153
|
+
self.may_trigger_eviction()
|
|
2154
|
+
|
|
2155
|
+
# Fetch data from SSD
|
|
2156
|
+
if linear_cache_indices.numel() > 0:
|
|
2157
|
+
self.record_function_via_dummy_profile(
|
|
2158
|
+
"## ssd_get ##",
|
|
2159
|
+
self.ssd_db.get_cuda,
|
|
2160
|
+
inserted_indices_cpu,
|
|
2161
|
+
inserted_rows,
|
|
2162
|
+
actions_count_cpu,
|
|
2163
|
+
)
|
|
2164
|
+
|
|
2165
|
+
# Record an event to mark the completion of `get_cuda`
|
|
2166
|
+
current_stream.record_event(self.ssd_event_get)
|
|
2167
|
+
|
|
2168
|
+
# Copy rows from the current iteration's scratch pad to L1
|
|
2169
|
+
torch.ops.fbgemm.masked_index_put(
|
|
2170
|
+
self.lxu_cache_weights,
|
|
2171
|
+
assigned_cache_slots,
|
|
2172
|
+
inserted_rows,
|
|
2173
|
+
actions_count_gpu,
|
|
2174
|
+
use_pipeline=self.prefetch_pipeline,
|
|
2175
|
+
)
|
|
2176
|
+
|
|
2177
|
+
if self.training:
|
|
2178
|
+
if linear_cache_indices.numel() > 0:
|
|
2179
|
+
# Evict rows from cache to SSD
|
|
2180
|
+
self.evict(
|
|
2181
|
+
rows=self.lxu_cache_evicted_weights,
|
|
2182
|
+
indices_cpu=self.lxu_cache_evicted_indices,
|
|
2183
|
+
actions_count_cpu=self.lxu_cache_evicted_count,
|
|
2184
|
+
stream=self.ssd_eviction_stream,
|
|
2185
|
+
pre_event=self.ssd_event_get,
|
|
2186
|
+
# Record completion event after scratch pad eviction
|
|
2187
|
+
# instead since that happens after L1 eviction
|
|
2188
|
+
post_event=self.ssd_event_cache_evict,
|
|
2189
|
+
is_rows_uvm=True,
|
|
2190
|
+
name="cache",
|
|
2191
|
+
is_bwd=False,
|
|
2192
|
+
)
|
|
2193
|
+
if (
|
|
2194
|
+
self.backend_type == BackendType.DRAM
|
|
2195
|
+
and weights is not None
|
|
2196
|
+
and linear_cache_indices.numel() > 0
|
|
2197
|
+
):
|
|
2198
|
+
# Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
|
|
2199
|
+
self._update_feature_score_metadata(
|
|
2200
|
+
linear_cache_indices=linear_cache_indices,
|
|
2201
|
+
weights=weights,
|
|
2202
|
+
d2h_stream=self.ssd_memcpy_stream,
|
|
2203
|
+
write_stream=self.feature_score_stream,
|
|
2204
|
+
pre_event_for_write=self.ssd_event_cache_evict,
|
|
2205
|
+
)
|
|
2206
|
+
|
|
2207
|
+
# Generate row addresses (pointing to either L1 or the current
|
|
2208
|
+
# iteration's scratch pad)
|
|
2209
|
+
with record_function("## ssd_generate_row_addrs ##"):
|
|
2210
|
+
lxu_cache_ptrs, post_bwd_evicted_indices = (
|
|
2211
|
+
torch.ops.fbgemm.ssd_generate_row_addrs(
|
|
2212
|
+
lxu_cache_locations,
|
|
2213
|
+
assigned_cache_slots,
|
|
2214
|
+
linear_index_inverse_indices,
|
|
2215
|
+
unique_indices_count_cumsum,
|
|
2216
|
+
cache_set_inverse_indices,
|
|
2217
|
+
self.lxu_cache_weights,
|
|
2218
|
+
inserted_rows,
|
|
2219
|
+
unique_indices_length,
|
|
2220
|
+
inserted_indices,
|
|
2221
|
+
)
|
|
2222
|
+
)
|
|
2223
|
+
|
|
2224
|
+
with record_function("## ssd_d2h_post_bwd_evicted_indices ##"):
|
|
2225
|
+
# Transfer post_bwd_evicted_indices from GPU to CPU right away to
|
|
2226
|
+
# increase a chance of overlapping with compute in the default stream
|
|
2227
|
+
(post_bwd_evicted_indices_cpu,) = (
|
|
2228
|
+
self.to_pinned_cpu_on_stream_wait_on_another_stream(
|
|
2229
|
+
tensors=[post_bwd_evicted_indices],
|
|
2230
|
+
stream=self.ssd_eviction_stream,
|
|
2231
|
+
stream_to_wait_on=current_stream,
|
|
2232
|
+
post_event=None,
|
|
2233
|
+
)
|
|
2234
|
+
)
|
|
2235
|
+
|
|
2236
|
+
if self.prefetch_pipeline:
|
|
2237
|
+
# Insert the current iteration's conflict miss indices in the index
|
|
2238
|
+
# queue for future lookup.
|
|
2239
|
+
#
|
|
2240
|
+
# post_bwd_evicted_indices_cpu is transferred on the
|
|
2241
|
+
# ssd_eviction_stream stream so it does not need stream
|
|
2242
|
+
# synchronization
|
|
2243
|
+
#
|
|
2244
|
+
# actions_count_cpu is transferred on the ssd_memcpy_stream stream.
|
|
2245
|
+
# Thus, we have to explicitly sync the stream
|
|
2246
|
+
with torch.cuda.stream(self.ssd_eviction_stream):
|
|
2247
|
+
# Ensure that actions_count_cpu transfer is done
|
|
2248
|
+
self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy)
|
|
2249
|
+
# Ensure that the scratch pad index queue look up is complete
|
|
2250
|
+
self.ssd_eviction_stream.wait_event(self.ssd_event_sp_idxq_lookup)
|
|
2251
|
+
self.record_function_via_dummy_profile(
|
|
2252
|
+
"## ssd_scratch_pad_idx_queue_insert ##",
|
|
2253
|
+
self.scratch_pad_idx_queue.insert_cuda,
|
|
2254
|
+
post_bwd_evicted_indices_cpu,
|
|
2255
|
+
actions_count_cpu,
|
|
2256
|
+
)
|
|
2257
|
+
# Mark the completion of scratch pad index insertion
|
|
2258
|
+
self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert)
|
|
2259
|
+
|
|
2260
|
+
prefetch_data = (
|
|
2261
|
+
lxu_cache_ptrs,
|
|
2262
|
+
inserted_rows,
|
|
2263
|
+
post_bwd_evicted_indices_cpu,
|
|
2264
|
+
actions_count_cpu,
|
|
2265
|
+
actions_count_gpu,
|
|
2266
|
+
lxu_cache_locations,
|
|
2267
|
+
cache_set_inverse_indices,
|
|
2268
|
+
)
|
|
2269
|
+
|
|
2270
|
+
# Record tensors on the forward stream
|
|
2271
|
+
if forward_stream is not None:
|
|
2272
|
+
for t in prefetch_data:
|
|
2273
|
+
if t.is_cuda:
|
|
2274
|
+
t.record_stream(forward_stream)
|
|
2275
|
+
|
|
2276
|
+
if self.prefetch_pipeline:
|
|
2277
|
+
# Store scratch pad info for the lookup in the next iteration
|
|
2278
|
+
# prefetch
|
|
2279
|
+
self.ssd_scratch_pads.append(
|
|
2280
|
+
(
|
|
2281
|
+
inserted_rows,
|
|
2282
|
+
post_bwd_evicted_indices_cpu,
|
|
2283
|
+
actions_count_cpu,
|
|
2284
|
+
)
|
|
2285
|
+
)
|
|
2286
|
+
|
|
2287
|
+
# Store scratch pad info for post backward eviction only for training
|
|
2288
|
+
# for eval job, no backward pass, so no need to store this info
|
|
2289
|
+
if self.training:
|
|
2290
|
+
self.ssd_scratch_pad_eviction_data.append(
|
|
2291
|
+
(
|
|
2292
|
+
inserted_rows,
|
|
2293
|
+
post_bwd_evicted_indices_cpu,
|
|
2294
|
+
actions_count_cpu,
|
|
2295
|
+
linear_cache_indices.numel() > 0,
|
|
2296
|
+
)
|
|
2297
|
+
)
|
|
2298
|
+
|
|
2299
|
+
# Store data for forward
|
|
2300
|
+
self.ssd_prefetch_data.append(prefetch_data)
|
|
2301
|
+
|
|
2302
|
+
# Record an event to mark the completion of prefetch operations
|
|
2303
|
+
# This will be used by direct_write_embedding to ensure it doesn't run concurrently with prefetch
|
|
2304
|
+
current_stream.record_event(self.prefetch_complete_event)
|
|
2305
|
+
|
|
2306
|
+
@torch.jit.ignore
def _generate_vbe_metadata(
    self,
    offsets: Tensor,
    batch_size_per_feature_per_rank: Optional[list[list[int]]],
) -> invokers.lookup_args.VBEMetadata:
    """
    Build the variable-batch-size (VBE) metadata for the current batch.

    Args:
        offsets: The batch offsets, forwarded to ``generate_vbe_metadata``.
        batch_size_per_feature_per_rank: Per-feature, per-rank batch sizes.
            ``None`` means VBE is not in use.

    Returns:
        The result of ``generate_vbe_metadata`` for these inputs.

    Raises:
        AssertionError: VBE is requested (``batch_size_per_feature_per_rank``
            is not ``None``) with an optimizer other than
            ``EXACT_ROWWISE_ADAGRAD`` or ``EXACT_SGD``.
    """
    # Blocking D2H copy, but only runs at first call (``.cpu()`` returns the
    # tensor itself once it already lives on the host).
    self.feature_dims = self.feature_dims.cpu()
    if batch_size_per_feature_per_rank is not None:
        # The message names exactly the optimizers accepted by the check.
        assert self.optimizer in (
            OptimType.EXACT_ROWWISE_ADAGRAD,
            OptimType.EXACT_SGD,
        ), (
            "Variable batch size TBE support is enabled for "
            "OptimType.EXACT_ROWWISE_ADAGRAD and "
            f"OptimType.EXACT_SGD only, got {self.optimizer}"
        )
    return generate_vbe_metadata(
        offsets,
        batch_size_per_feature_per_rank,
        self.pooling_mode,
        self.feature_dims,
        self.current_device,
    )
|
|
2330
|
+
|
|
2331
|
+
def _increment_iteration(self) -> int:
    """
    Advance the iteration counter by one and return the new value.

    Two counters are advanced in lock-step:
    - ``self.iter_cpu`` is read on the host to produce the returned int.
    - ``self.iter`` is kept for checkpointing.

    If the host counter is still zero, it is first seeded from ``self.iter``.
    This lets a value restored into ``self.iter`` be picked up.

    Returns:
        The incremented iteration number as a Python ``int``.
    """
    # Although created on the CPU, the host counter may have been moved to
    # another device by the user, so pull it back explicitly. ``.cpu()`` is a
    # no-op once it is already on the host.
    host_counter = self.iter_cpu.cpu()
    self.iter_cpu = host_counter

    # Seeding needs ``.item()`` host reads. That is skipped while PT2 (dynamo)
    # is compiling; short-circuiting keeps ``.item()`` from running then.
    must_seed = not is_torchdynamo_compiling() and host_counter.item() == 0
    if must_seed:
        host_counter.fill_(self.iter.cpu().item())

    # Host copy drives local computation; device copy is for checkpointing.
    next_iteration = int(host_counter.add_(1).item())
    self.iter.add_(1)
    return next_iteration
|
|
2350
|
+
|
|
2351
|
+
def forward(
    self,
    indices: Tensor,
    offsets: Tensor,
    weights: Optional[Tensor] = None,
    per_sample_weights: Optional[Tensor] = None,
    feature_requires_grad: Optional[Tensor] = None,
    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
) -> Tensor:
    """
    Run the SSD-backed embedding lookup for one batch.

    Steps:
    1. Normalize inputs via ``self.prepare_inputs``.
    2. If no prefetch is outstanding (``self.timesteps_prefetched`` is
       empty), run ``self._prefetch`` inline. Two timers record it: one on
       the current CUDA stream and one on ``self.ssd_eviction_stream``.
    3. Consume the oldest entry of ``self.ssd_prefetch_data``.
    4. Record the iteration data, then advance the step and iteration
       counters.
    5. Dispatch to the optimizer-specific lookup invoker.

    Args:
        indices, offsets, per_sample_weights, batch_size_per_feature_per_rank:
            Lookup inputs, normalized by ``self.prepare_inputs``.
        weights: Forwarded to ``self._prefetch`` when prefetching inline.
        feature_requires_grad: Forwarded to the lookup ``CommonArgs``.

    Returns:
        The output of the optimizer-specific ``invoke`` call.

    Raises:
        AssertionError: No prefetched data is available.
        NotImplementedError: ``self.optimizer`` is not one of
            ``EXACT_ROWWISE_ADAGRAD``, ``PARTIAL_ROWWISE_ADAM`` or ``ADAM``.
            Previously this case silently returned ``None``.
    """
    self.clear_cache()
    indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
        indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
    )

    if len(self.timesteps_prefetched) == 0:
        # No prefetch was issued ahead of this forward: prefetch inline.
        with self._recording_to_timer(
            self.ssd_prefetch_read_timer,
            context=self.step,
            stream=torch.cuda.current_stream(),
        ), self._recording_to_timer(
            self.ssd_prefetch_evict_timer,
            context=self.step,
            stream=self.ssd_eviction_stream,
        ):
            self._prefetch(indices, offsets, weights, vbe_metadata)

    assert len(self.ssd_prefetch_data) > 0

    (
        lxu_cache_ptrs,
        inserted_rows,
        post_bwd_evicted_indices_cpu,
        actions_count_cpu,
        actions_count_gpu,
        lxu_cache_locations,
        cache_set_inverse_indices,
    ) = self.ssd_prefetch_data.pop(0)

    # Storing current iteration data for future use
    self.current_iter_data = IterData(
        indices,
        offsets,
        lxu_cache_locations,
        lxu_cache_ptrs,
        actions_count_gpu,
        cache_set_inverse_indices,
        vbe_metadata.B_offsets,
        vbe_metadata.max_B,
    )

    common_args = invokers.lookup_args_ssd.CommonArgs(
        placeholder_autograd_tensor=self.placeholder_autograd_tensor,
        output_dtype=self.output_dtype,
        dev_weights=self.weights_dev,
        host_weights=self.weights_host,
        uvm_weights=self.weights_uvm,
        lxu_cache_weights=self.lxu_cache_weights,
        weights_placements=self.weights_placements,
        weights_offsets=self.weights_offsets,
        D_offsets=self.D_offsets,
        total_D=self.total_D,
        max_D=self.max_D,
        hash_size_cumsum=self.hash_size_cumsum,
        total_hash_size_bits=self.total_hash_size_bits,
        indices=indices,
        offsets=offsets,
        pooling_mode=self.pooling_mode,
        indice_weights=per_sample_weights,
        feature_requires_grad=feature_requires_grad,
        lxu_cache_locations=lxu_cache_ptrs,
        uvm_cache_stats=None,
        # Unused arguments
        is_experimental=False,
        use_uniq_cache_locations_bwd=False,
        use_homogeneous_placements=True,
        # The keys for ssd_tensors are controlled by ssd_tensors in
        # codegen/genscript/optimizer_args.py
        ssd_tensors={
            "row_addrs": lxu_cache_ptrs,
            "inserted_rows": inserted_rows,
            "post_bwd_evicted_indices": post_bwd_evicted_indices_cpu,
            "actions_count": actions_count_cpu,
        },
        enable_optimizer_offloading=self.enable_optimizer_offloading,
        # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
        vbe_metadata=vbe_metadata,
        learning_rate_tensor=self.learning_rate_tensor,
        info_B_num_bits=self.info_B_num_bits,
        info_B_mask=self.info_B_mask,
    )

    self.timesteps_prefetched.pop(0)
    self.step += 1

    # Increment the iteration (value is used for certain optimizers)
    iter_int = self._increment_iteration()

    if self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
        momentum2 = invokers.lookup_args_ssd.Momentum(
            # pyre-ignore[6]
            dev=self.momentum2_dev,
            # pyre-ignore[6]
            host=self.momentum2_host,
            # pyre-ignore[6]
            uvm=self.momentum2_uvm,
            # pyre-ignore[6]
            offsets=self.momentum2_offsets,
            # pyre-ignore[6]
            placements=self.momentum2_placements,
        )

    momentum1 = invokers.lookup_args_ssd.Momentum(
        dev=self.momentum1_dev,
        host=self.momentum1_host,
        uvm=self.momentum1_uvm,
        offsets=self.momentum1_offsets,
        placements=self.momentum1_placements,
    )

    if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
        return invokers.lookup_rowwise_adagrad_ssd.invoke(
            common_args, self.optimizer_args, momentum1
        )

    elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
        return invokers.lookup_partial_rowwise_adam_ssd.invoke(
            common_args,
            self.optimizer_args,
            momentum1,
            # pyre-ignore[61]
            momentum2,
            iter_int,
        )

    elif self.optimizer == OptimType.ADAM:
        row_counter = invokers.lookup_args_ssd.Momentum(
            # pyre-fixme[6]
            dev=self.row_counter_dev,
            # pyre-fixme[6]
            host=self.row_counter_host,
            # pyre-fixme[6]
            uvm=self.row_counter_uvm,
            # pyre-fixme[6]
            offsets=self.row_counter_offsets,
            # pyre-fixme[6]
            placements=self.row_counter_placements,
        )

        return invokers.lookup_adam_ssd.invoke(
            common_args,
            self.optimizer_args,
            momentum1,
            # pyre-ignore[61]
            momentum2,
            iter_int,
            row_counter=row_counter,
        )

    # Fail loudly instead of implicitly returning None for an optimizer that
    # has no SSD lookup invoker wired up here.
    raise NotImplementedError(
        f"SSD TBE forward is not supported for optimizer {self.optimizer}"
    )
|
|
2510
|
+
|
|
2511
|
+
@torch.jit.ignore
def _split_optimizer_states_non_kv_zch(
    self,
) -> list[list[torch.Tensor]]:
    """
    Returns a list of optimizer states (view), split by table.

    Returns:
        A list of list of states. Shape = (the number of tables, the number
        of states).

        The following shows the list of states (in the returned order) for
        each optimizer:

        (1) `EXACT_ROWWISE_ADAGRAD`: `momentum1` (rowwise)

        (2) `PARTIAL_ROWWISE_ADAM`: `momentum1` (elementwise), `momentum2`
            (rowwise)

        (3) `ADAM`: `momentum1`, `momentum2` (both elementwise)

        Each state is a view into the corresponding ``momentum*_dev``
        tensor (no copy):
        - a rowwise state is a 1D view with one element per table row;
        - an elementwise state is a 2D view of shape (rows, dim).

    Raises:
        NotImplementedError: for any other optimizer.
    """
    # Row and dimension counts per table
    (rows, dims) = zip(*self.embedding_specs)
    # Cumulative row counts per table for rowwise states
    row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
    # Cumulative element counts per table for elementwise states
    elem_count_cumsum: list[int] = [0] + list(
        itertools.accumulate([r * d for r, d in self.embedding_specs])
    )

    # pyre-ignore[53]
    def _slice(tensor: Tensor, t: int, rowwise: bool) -> Tensor:
        if rowwise:
            # Row-wise state: fetch the table's element range and view it
            # as 1D (one element per row)
            return tensor.detach()[
                row_count_cumsum[t] : row_count_cumsum[t + 1]
            ].view(rows[t])
        # Element-wise state: compute the table offset and view the slice
        # as a 2D tensor of (rows, dim)
        return tensor.detach()[
            elem_count_cumsum[t] : elem_count_cumsum[t + 1]
        ].view(-1, dims[t])

    table_ids = range(len(rows))

    if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
        return [[_slice(self.momentum1_dev, t, rowwise=True)] for t in table_ids]

    elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
        return [
            [
                _slice(self.momentum1_dev, t, rowwise=False),
                # pyre-ignore[6]
                _slice(self.momentum2_dev, t, rowwise=True),
            ]
            for t in table_ids
        ]

    elif self.optimizer == OptimType.ADAM:
        return [
            [
                _slice(self.momentum1_dev, t, rowwise=False),
                # pyre-ignore[6]
                _slice(self.momentum2_dev, t, rowwise=False),
            ]
            for t in table_ids
        ]

    else:
        raise NotImplementedError(
            f"Getting optimizer states is not supported for {self.optimizer}"
        )
|
|
2586
|
+
|
|
2587
|
+
@torch.jit.ignore
def _split_optimizer_states_kv_zch_no_offloading(
    self,
    sorted_ids: torch.Tensor,
) -> list[list[torch.Tensor]]:
    """
    Gather per-table optimizer states for the ids in ``sorted_ids``.

    The states are read from the in-module ``momentum*_dev`` tensors and
    returned on the CPU. Ids are translated to table-local positions using
    the table's first bucket id and bucket size from ``self.kv_zch_params``.

    Args:
        sorted_ids: Per-table ids, indexed as ``sorted_ids[t]``. Presumably
            a list of per-table id tensors despite the annotation
            (TODO confirm against callers). ``None``, or an empty entry,
            yields an empty state for that table. The empty state has the
            state's dtype from ``self.optimizer_state_dtypes`` (FP32 by
            default).

    Returns:
        One list of states per table. States are in the same order as for
        the non-KV-ZCH path:
        - rowwise states are 1D;
        - elementwise states are 2D of shape (num_ids, dim).

    Raises:
        NotImplementedError: for optimizers other than
            ``EXACT_ROWWISE_ADAGRAD``, ``PARTIAL_ROWWISE_ADAM`` and ``ADAM``.
    """

    # Row count per table
    (rows, dims) = zip(*self.embedding_specs)
    # Cumulative row counts per table for rowwise states
    row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
    # Cumulative element counts per table for elementwise states
    elem_count_cumsum: list[int] = [0] + list(
        itertools.accumulate([r * d for r, d in self.embedding_specs])
    )

    # Host copy of each optimizer state, created lazily and reused across
    # tables. Previously the whole state tensor was copied to the host once
    # per table.
    host_copies: dict[str, Tensor] = {}

    def _host_copy(state_name: str, tensor: Tensor) -> Tensor:
        cached = host_copies.get(state_name)
        if cached is None:
            cached = tensor.detach().cpu()
            host_copies[state_name] = cached
        return cached

    # pyre-ignore[53]
    def _slice(state_name: str, tensor: Tensor, t: int, rowwise: bool) -> Tensor:
        d: int = dims[t]

        # pyre-ignore[16]
        bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
        # pyre-ignore[16]
        bucket_size = self.kv_zch_params.bucket_sizes[t]

        if sorted_ids is None or sorted_ids[t].numel() == 0:
            # Empty optimizer state for module initialization
            return torch.empty(
                0,
                dtype=(
                    self.optimizer_state_dtypes.get(
                        state_name, SparseType.FP32
                    ).as_dtype()
                ),
                device="cpu",
            )

        elif not rowwise:
            # Optimizer state is element-wise - materialize the local ids
            # based on the sorted_ids compute the table offset for the
            # table, view the slice as 2D tensor of e x d, then fetch the
            # sub-slice by local ids
            #
            # local_ids is [N, 1], flatten it to N to keep the returned tensor 2D
            local_ids = (sorted_ids[t] - bucket_id_start * bucket_size).view(-1)
            return _host_copy(state_name, tensor)[
                elem_count_cumsum[t] : elem_count_cumsum[t + 1]
            ].view(-1, d)[local_ids]

        else:
            # Optimizer state is row-wise - materialize the local ids based
            # on the sorted_ids and table offset (i.e. row count cumsum),
            # then fetch by local ids
            linearized_local_ids = (
                sorted_ids[t] - bucket_id_start * bucket_size + row_count_cumsum[t]
            )
            return _host_copy(state_name, tensor)[linearized_local_ids].view(-1)

    table_ids = range(len(rows))

    if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
        return [
            [_slice("momentum1", self.momentum1_dev, t, rowwise=True)]
            for t in table_ids
        ]

    elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
        return [
            [
                _slice("momentum1", self.momentum1_dev, t, rowwise=False),
                # pyre-ignore[6]
                _slice("momentum2", self.momentum2_dev, t, rowwise=True),
            ]
            for t in table_ids
        ]

    elif self.optimizer == OptimType.ADAM:
        return [
            [
                _slice("momentum1", self.momentum1_dev, t, rowwise=False),
                # pyre-ignore[6]
                _slice("momentum2", self.momentum2_dev, t, rowwise=False),
            ]
            for t in table_ids
        ]

    else:
        raise NotImplementedError(
            f"Getting optimizer states is not supported for {self.optimizer}"
        )
|
|
2676
|
+
|
|
2677
|
+
@torch.jit.ignore
def _split_optimizer_states_kv_zch_w_offloading(
    self,
    sorted_ids: torch.Tensor,
    no_snapshot: bool = True,
    should_flush: bool = False,
) -> list[list[torch.Tensor]]:
    """
    Fetch per-table optimizer states from the offloaded KV backend.

    Backend rows pack the weights and the optimizer states together. Reading
    whole rows could cause OOM, so the states are read through a
    ``KVTensorWrapper`` in chunks. The wrapper is bound to the SSD (RocksDB)
    or DRAM backend depending on ``self.backend_type``.

    Args:
        sorted_ids: Per-table ids to fetch, indexed as ``sorted_ids[t]``.
            Presumably a list of per-table id tensors despite the annotation
            (TODO confirm against callers). ``None``, or an empty entry,
            yields ``self.optimizer.empty_states(...)`` for that table.
        no_snapshot: Forwarded to
            ``self._may_create_snapshot_for_state_dict``.
        should_flush: Forwarded to
            ``self._may_create_snapshot_for_state_dict``.

    Returns:
        For each table, the optimizer states in
        ``self.optimizer.state_names()`` order. Fetched states are views into
        a per-table CPU buffer. A state whose entry in
        ``self.optimizer.state_size_table(d)`` is not greater than 1 is
        flattened to 1D.
    """
    dtype = self.weights_precision.as_dtype()
    # Row count per table
    (rows_, dims_) = zip(*self.embedding_specs)
    # Cumulative row counts per table for rowwise states
    row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))

    snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
        no_snapshot=no_snapshot,
        should_flush=should_flush,
    )

    # pyre-ignore[53]
    def _fetch_offloaded_optimizer_states(
        t: int,
    ) -> list[Tensor]:
        e: int = rows_[t]
        d: int = dims_[t]

        # pyre-ignore[16]
        bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
        # pyre-ignore[16]
        bucket_size = self.kv_zch_params.bucket_sizes[t]

        row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)
        # Count of rows to fetch
        rows_to_fetch = sorted_ids[t].numel()

        # Lookup the byte offsets for each optimizer state
        raw_byte_offsets = self.optimizer.byte_offsets_along_row(
            d, self.weights_precision, self.optimizer_state_dtypes
        )
        # The fetched buffer starts at the first optimizer state, so make
        # every (start, end) pair relative to the smallest start offset
        base_offset = min(start for start, _ in raw_byte_offsets.values())
        optimizer_state_byte_offsets = {
            name: (start - base_offset, end - base_offset)
            for name, (start, end) in raw_byte_offsets.items()
        }

        # Since the backend returns cache rows that pack the weights and
        # optimizer states together, reading the whole tensor could cause OOM,
        # so we use the KVTensorWrapper abstraction to query the backend and
        # fetch the data in chunks instead.
        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
            shape=[
                e,
                # Dim is terms of **weights** dtype
                self.optimizer_state_dim,
            ],
            dtype=dtype,
            row_offset=row_offset,
            snapshot_handle=snapshot_handle,
            sorted_indices=sorted_ids[t],
            width_offset=pad4(d),
        )
        if self.backend_type == BackendType.SSD:
            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
        else:
            tensor_wrapper.set_dram_db_wrapper(self.ssd_db)

        # Fetch the state size table for the given weights dimension
        state_size_table = self.optimizer.state_size_table(d)

        # Create a 2D output buffer of [rows x optimizer state dim] with the
        # weights type as the type. For optimizers with multiple states (e.g.
        # momentum1 and momentum2), this tensor will include data from all
        # states, hence self.optimizer_state_dim as the row size.
        optimizer_states_buffer = torch.empty(
            (rows_to_fetch, self.optimizer_state_dim), dtype=dtype, device="cpu"
        )

        # 10M rows => 260(max_D) * 2(ele_bytes) * 10M => 5.2GB mem spike
        chunk_size = 10_000_000
        logging.info("split optimizer chunk rows: %s", chunk_size)

        # Chunk the fetching by chunk_size
        for i in range(0, rows_to_fetch, chunk_size):
            length = min(chunk_size, rows_to_fetch - i)
            # Fetch from backend and copy to the output buffer
            optimizer_states_buffer.narrow(0, i, length).copy_(
                tensor_wrapper.narrow(0, i, length).view(dtype)
            )

        # Byte view of the buffer so each state can be cut out by its byte
        # offsets (a view, no copy)
        buffer_bytes = optimizer_states_buffer.view(dtype=torch.uint8)

        # Now split up the buffer into N views, N for each optimizer state
        optimizer_states: list[Tensor] = []
        for state_name in self.optimizer.state_names():
            start, end = optimizer_state_byte_offsets[state_name]
            # Re-view the state's bytes in the state's target dtype
            state = buffer_bytes[:, start:end].view(
                self.optimizer_state_dtypes.get(
                    state_name, SparseType.FP32
                ).as_dtype()
            )
            # If the state is rowwise (i.e. just one element per row),
            # then re-view as 1D tensor
            if state_size_table[state_name] > 1:
                optimizer_states.append(state)
            else:
                optimizer_states.append(state.view(-1))

        # Return the views
        return optimizer_states

    return [
        (
            # Return a set of empty states if sorted_ids[t] is empty
            self.optimizer.empty_states([0], [d], self.optimizer_state_dtypes)[0]
            if sorted_ids is None or sorted_ids[t].numel() == 0
            # Else fetch the list of optimizer states for the table
            else _fetch_offloaded_optimizer_states(t)
        )
        for t, d in enumerate(dims_)
    ]
|
|
2813
|
+
|
|
2814
|
+
@torch.jit.ignore
def _split_optimizer_states_kv_zch_whole_row(
    self,
    # Indexed per table (sorted_ids[t]) and assigned into, so this is a list
    # of per-table id tensors, as passed by split_optimizer_states.
    sorted_ids: Optional[list[torch.Tensor]],
    no_snapshot: bool = True,
    should_flush: bool = False,
) -> list[list[torch.Tensor]]:
    """
    Build per-table optimizer-state views for the KV ZCH + optimizer
    offloading case where the backend stores the whole row (metaheader,
    weights and optimizer states together).

    For every table, one read-only PartiallyMaterializedTensor is created per
    optimizer state (in ``self.optimizer.state_names()`` order). Each is
    backed by a KVTensorWrapper that reads the optimizer-state columns of the
    same backend rows as the weights, at a width offset past the metaheader.

    Args:
        sorted_ids: per-table id tensors used as the wrapper's sorted indices.
            NOTE(review): although callers may pass None, ``sorted_ids[t]``
            is dereferenced unconditionally below, so None is not actually
            supported here.
            NOTE: when ``sorted_ids[t]`` is empty but the table has a positive
            ``local_weight_counts[t]``, the entry in the CALLER's list is
            replaced in place by a zero id tensor of that length.
        no_snapshot: forwarded to ``_may_create_snapshot_for_state_dict``.
        should_flush: forwarded to ``_may_create_snapshot_for_state_dict``.

    Returns:
        One list of PartiallyMaterializedTensor per table.
    """
    dtype = self.weights_precision.as_dtype()

    # Row and dimension counts per table
    # rows_ is only used here to compute the virtual table offsets
    (rows_, dims_) = zip(*self.embedding_specs)

    # Cumulative row counts per (virtual) table for rowwise states
    row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows_))

    snapshot_handle, _ = self._may_create_snapshot_for_state_dict(
        no_snapshot=no_snapshot,
        should_flush=should_flush,
    )

    # pyre-ignore[53]
    def _fetch_offloaded_optimizer_states(
        t: int,
    ) -> list[Tensor]:
        # Build the optimizer-state views for table t.
        d: int = dims_[t]

        # pyre-ignore[16]
        bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
        # pyre-ignore[16]
        bucket_size = self.kv_zch_params.bucket_sizes[t]
        # Offset that maps the table's global ids onto backend row ids
        row_offset = row_count_cumsum[t] - (bucket_id_start * bucket_size)

        # When backend returns whole row, the optimizer will be returned as
        # PMT directly
        if sorted_ids[t].size(0) == 0 and self.local_weight_counts[t] > 0:
            logging.info(
                f"Before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
            )
            # Mutates the caller's list: replace the empty ids with a
            # zero-filled [local_weight_counts[t], 1] int64 tensor.
            sorted_ids[t] = torch.zeros(
                (self.local_weight_counts[t], 1),
                device=torch.device("cpu"),
                dtype=torch.int64,
            )

        # Lookup the byte offsets for each optimizer state relative to the
        # start of the weights
        optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
            d, self.weights_precision, self.optimizer_state_dtypes
        )
        # Get the number of elements (of the optimizer state dtype) per state
        optimizer_state_size_table = self.optimizer.state_size_table(d)

        # Get metaheader dimensions in number of elements of weight dtype
        metaheader_dim = (
            # pyre-ignore[16]
            self.kv_zch_params.eviction_policy.meta_header_lens[t]
        )

        # Now split up the buffer into N views, N for each optimizer state
        optimizer_states: list[PartiallyMaterializedTensor] = []
        for state_name in self.optimizer.state_names():
            state_dtype = self.optimizer_state_dtypes.get(
                state_name, SparseType.FP32
            ).as_dtype()

            # Size of the state expressed in elements of the **weights**
            # dtype (rounded up when the byte size is not a multiple)
            state_size = math.ceil(
                optimizer_state_size_table[state_name]
                * state_dtype.itemsize
                / dtype.itemsize
            )

            # Extract the offsets relative to the start of the weights (in
            # num bytes)
            (start, _) = optimizer_state_byte_offsets[state_name]

            # Convert the start to number of elements in terms of the
            # **weights** dtype, then add the metaheader dim offset
            start = metaheader_dim + start // dtype.itemsize

            shape = [
                (
                    sorted_ids[t].size(0)
                    if sorted_ids is not None and sorted_ids[t].size(0) > 0
                    else self.local_weight_counts[t]
                ),
                (
                    # Dim is in terms of the **weights** dtype
                    state_size
                ),
            ]

            # NOTE: We have to view using the **weights** dtype, as
            # there is currently a bug with KVTensorWrapper where using
            # a different dtype does not result in the same bytes being
            # returned, e.g.
            #
            # KVTensorWrapper(dtype=fp32, width_offset=130, shape=[N, 1])
            #
            # is NOT the same as
            #
            # KVTensorWrapper(dtype=fp16, width_offset=260, shape=[N, 2]).view(-1).view(fp32)
            #
            # TODO: Fix KVTensorWrapper to support viewing data under different dtypes
            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                shape=shape,
                dtype=(
                    # NOTE: Use the *weights* dtype
                    dtype
                ),
                row_offset=row_offset,
                snapshot_handle=snapshot_handle,
                sorted_indices=sorted_ids[t],
                width_offset=(
                    # NOTE: Width offset is in terms of **weights** dtype
                    start
                ),
                # Optimizer written to DB with weights, so skip write here
                read_only=True,
            )
            # Attach the backend matching the configured backend type
            (
                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
                if self.backend_type == BackendType.SSD
                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
            )

            optimizer_states.append(
                PartiallyMaterializedTensor(tensor_wrapper, True)
            )

        # pyre-ignore [7]
        return optimizer_states

    return [_fetch_offloaded_optimizer_states(t) for t, _ in enumerate(dims_)]
|
|
2950
|
+
|
|
2951
|
+
@torch.jit.export
def split_optimizer_states(
    self,
    sorted_id_tensor: Optional[list[torch.Tensor]] = None,
    no_snapshot: bool = True,
    should_flush: bool = False,
) -> list[list[torch.Tensor]]:
    """
    Return the optimizer states split by table: one list of states per
    table, in ``self.optimizer.state_names()`` order.

    The implementation is chosen by configuration:
      * no KV ZCH params -> ``_split_optimizer_states_non_kv_zch``;
      * load_state_dict mode -> the cached per-table optimizer states;
      * KV ZCH without optimizer offloading;
      * KV ZCH with optimizer offloading, weights-only backend rows;
      * KV ZCH with optimizer offloading, whole-row backend.

    Args:
        sorted_id_tensor (Optional[List[torch.Tensor]]): per-table sorted id
            tensors used to query optimizer state from the backend. Callers
            should reuse the id tensors generated for the weight state_dict,
            to guarantee id consistency between weights and optimizer states.
        no_snapshot (bool): forwarded to the offloading implementations.
        should_flush (bool): forwarded to the offloading implementations.
    """
    # Tables without KV ZCH have a dedicated implementation.
    if not self.kv_zch_params:
        return self._split_optimizer_states_non_kv_zch()

    # While loading a checkpoint, the states live in the cache.
    if self.load_state_dict:
        assert (
            self._cached_kvzch_data is not None
            and self._cached_kvzch_data.cached_optimizer_states_per_table
        ), "optimizer state is not initialized for load checkpointing"
        return self._cached_kvzch_data.cached_optimizer_states_per_table

    logging.info(
        f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
    )
    began_at = time.time()

    if self.enable_optimizer_offloading:
        offloaded_impl = (
            self._split_optimizer_states_kv_zch_w_offloading
            if not self.backend_return_whole_row
            else self._split_optimizer_states_kv_zch_whole_row
        )
        result = offloaded_impl(sorted_id_tensor, no_snapshot, should_flush)
    else:
        result = self._split_optimizer_states_kv_zch_no_offloading(
            sorted_id_tensor
        )

    elapsed_ms = (time.time() - began_at) * 1000
    id_counts = (
        [ids.numel() for ids in sorted_id_tensor] if sorted_id_tensor else None
    )
    logging.info(
        f"KV ZCH tables split_optimizer_states query latency: {elapsed_ms} ms, "
        # pyre-ignore[16]
        f"num ids list: {id_counts}"
    )

    return result
|
|
3018
|
+
|
|
3019
|
+
@torch.jit.export
def get_optimizer_state(
    self,
    sorted_id_tensor: Optional[list[torch.Tensor]],
    no_snapshot: bool = True,
    should_flush: bool = False,
) -> list[dict[str, torch.Tensor]]:
    """
    Return the optimizer states split by table, each table's states keyed
    by optimizer state name.

    Thin wrapper around ``split_optimizer_states``: every per-table list of
    states is zipped with ``self.optimizer.state_names()``.
    """
    per_table = self.split_optimizer_states(
        sorted_id_tensor=sorted_id_tensor,
        no_snapshot=no_snapshot,
        should_flush=should_flush,
    )
    names = self.optimizer.state_names()
    return [
        {name: state for name, state in zip(names, table_states)}
        for table_states in per_table
    ]
|
|
3036
|
+
|
|
3037
|
+
@torch.jit.export
def debug_split_embedding_weights(self) -> list[torch.Tensor]:
    """
    Read every table's weights back from the backend, one tensor per table.

    Testing only, very slow: each table is fetched into a CPU buffer that is
    ``self.max_D`` wide, then trimmed to the table's own dimension.
    """
    heights, _ = zip(*self.embedding_specs)
    first_rows = [0, *itertools.accumulate(heights)]
    fetched: list[torch.Tensor] = []
    done_event = torch.cuda.Event()

    for t, height in enumerate(heights):
        buffer = torch.empty(
            (height, self.max_D), dtype=self.weights_precision.as_dtype()
        )
        row_ids = torch.arange(first_rows[t], first_rows[t + 1]).to(torch.int64)
        self.ssd_db.get_cuda(row_ids, buffer, torch.as_tensor([height]))
        fetched.append(buffer)

    # get_cuda fills the buffers from a callback; recording an event on the
    # default stream and waiting on it is used as a stand-in for syncing the
    # callback CPU thread with this Python thread (no direct mechanism exists).
    done_event.record()
    # Block this CPU thread until the event completes.
    done_event.synchronize()

    # Trim each buffer to its table's width (expensive, debugging only).
    trimmed: list[torch.Tensor] = []
    for (height, dim), buffer in zip(self.embedding_specs, fetched):
        table = buffer[:, :dim].contiguous()
        assert table.shape == (height, dim), "Shapes mismatch"
        trimmed.append(table)
    return trimmed
|
|
3081
|
+
|
|
3082
|
+
def clear_cache(self) -> None:
    """Forget the KV ZCH data cached for checkpoint loading/saving."""
    # Dropping this object's reference to the cached weight/id/optimizer data.
    setattr(self, "_cached_kvzch_data", None)
|
|
3085
|
+
|
|
3086
|
+
@torch.jit.ignore
# pyre-ignore [3] - do not define snapshot class EmbeddingSnapshotHandleWrapper to avoid import dependency in other production code
def _may_create_snapshot_for_state_dict(
    self,
    no_snapshot: bool = True,
    should_flush: bool = False,
):
    """
    Create a rocksdb snapshot if needed.

    Always synchronizes the CUDA device first.
      * SSD backend with ``no_snapshot=False``: flush caches, create a
        snapshot and fetch the active checkpoint uuid for ``self.step``.
      * DRAM backend: wait for ongoing eviction, then flush caches.
      * Otherwise nothing else is done.

    Returns:
        (snapshot_handle, checkpoint_handle); both are None unless a
        snapshot was created.
    """
    began_at = time.time()
    # Force device synchronize for now
    torch.cuda.synchronize()

    if self.backend_type == BackendType.SSD:
        if no_snapshot:
            return None, None
        # Flush L1 and L2 caches
        self.flush(force=should_flush)
        logging.info(
            f"flush latency for weight states: {(time.time() - began_at) * 1000} ms"
        )
        snapshot = self.ssd_db.create_snapshot()
        checkpoint = self.ssd_db.get_active_checkpoint_uuid(self.step)
        logging.info(
            f"created snapshot for weight states: {snapshot}, latency: {(time.time() - began_at) * 1000} ms"
        )
        return snapshot, checkpoint

    if self.backend_type == BackendType.DRAM:
        # Wait for any ongoing eviction so state_dict observes a consistent
        # model state before/after the call.
        eviction_wait_began_at = time.time()
        self.ssd_db.wait_until_eviction_done()
        logging.info(
            f"state_dict wait for ongoing eviction: {time.time() - eviction_wait_began_at} s"
        )
        self.flush(force=should_flush)

    return None, None
|
|
3124
|
+
|
|
3125
|
+
def get_embedding_dim_for_kvt(
    self, metaheader_dim: int, emb_dim: int, is_loading_checkpoint: bool
) -> int:
    """
    Return the row width to give the KVTensorWrapper backing a table's
    weights.

    * ``load_ckpt_without_opt`` and not loading a checkpoint: the backend only
      holds metaheader + weights (optimizer states are not loaded into the
      backend, e.g. to save host memory when publishing), so the width is
      ``metaheader_dim + pad4(emb_dim)``.
    * Otherwise, with ``backend_return_whole_row`` the full row is exposed:
      ``metaheader_dim + pad4(emb_dim) + pad4(self.optimizer_state_dim)``.
      This also covers the first load under ``load_ckpt_without_opt``, where
      the checkpoint rows include optimizer state and a narrower width would
      trigger a size-mismatch error.
    * Otherwise just ``emb_dim``.
    """
    if self.load_ckpt_without_opt and not is_loading_checkpoint:
        return metaheader_dim + pad4(emb_dim)

    if not self.backend_return_whole_row:
        return emb_dim

    # metaheader_dim is already padded
    return metaheader_dim + pad4(emb_dim) + pad4(self.optimizer_state_dim)
|
|
3156
|
+
|
|
3157
|
+
@torch.jit.export
def split_embedding_weights(
    self,
    no_snapshot: bool = True,
    should_flush: bool = False,
) -> tuple[  # TODO: make this a NamedTuple for readability
    Union[list[PartiallyMaterializedTensor], list[torch.Tensor]],
    Optional[list[torch.Tensor]],
    Optional[list[torch.Tensor]],
    Optional[list[torch.Tensor]],
]:
    """
    This method is intended to be used by the checkpointing engine
    only.

    Since we cannot materialize SSD backed tensors fully in CPU memory,
    we create a KVTensorWrapper (ssd_split_table_batched_embeddings.cpp)
    for each table (shard), which allows the caller to read the weights
    using `narrow()` in a rolling-window manner.

    Args:
        no_snapshot (bool): Forwarded to ``_may_create_snapshot_for_state_dict``
            (for the SSD backend a snapshot is only created when False).
        should_flush (bool): Flush caches if True. Note: this is an expensive
            operation, only set to True when necessary.

    Returns:
        A tuple of 4 lists, each element corresponds to a logical table:
        1st: partially materialized tensors, each representing a table
        2nd: input ids sorted in bucket id ascending order (None without KV ZCH)
        3rd: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
             where for the i th element, we have i + bucket_id_start = global bucket id
             (None without KV ZCH)
        4th: kvzch eviction metadata for each input id sorted in bucket id ascending order
             (None without KV ZCH, and also None if any table's id tensor was
             reset for checkpoint loading)

        In load_state_dict mode with KV ZCH, the cached weight, id and bucket
        data are returned instead, with an empty metadata list.
    """
    # NOTE: the snapshot/flush happens before the load_state_dict check below
    snapshot_handle, checkpoint_handle = self._may_create_snapshot_for_state_dict(
        no_snapshot=no_snapshot,
        should_flush=should_flush,
    )

    dtype = self.weights_precision.as_dtype()
    if self.load_state_dict and self.kv_zch_params:
        # init for checkpointing loading
        assert (
            self._cached_kvzch_data is not None
        ), "weight id and bucket state are not initialized for load checkpointing"
        return (
            self._cached_kvzch_data.cached_weight_tensor_per_table,
            self._cached_kvzch_data.cached_id_tensor_per_table,
            self._cached_kvzch_data.cached_bucket_splits,
            [],  # metadata tensor is not needed for checkpointing loading
        )
    start_time = time.time()
    pmt_splits = []
    bucket_sorted_id_splits = [] if self.kv_zch_params else None
    active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
    metadata_splits = [] if self.kv_zch_params else None
    # Once set (by any table's id reset), metadata is dropped for all tables
    skip_metadata = False

    # Running row offset of the current table in the linearized backend key space
    table_offset = 0
    for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
        is_loading_checkpoint = False
        bucket_ascending_id_tensor = None
        bucket_t = None
        metadata_tensor = None
        row_offset = table_offset
        metaheader_dim = 0
        if self.kv_zch_params:
            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
            # pyre-ignore
            bucket_size = self.kv_zch_params.bucket_sizes[i]
            metaheader_dim = (
                # pyre-ignore[16]
                self.kv_zch_params.eviction_policy.meta_header_lens[i]
            )

            # linearize with table offset
            table_input_id_start = table_offset
            table_input_id_end = table_offset + emb_height
            # 1. get all keys from backend for one table
            unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                table_input_id_start,
                table_input_id_end,
                table_offset,
                snapshot_handle,
            )
            # 2. sorting keys in bucket ascending order
            bucket_ascending_id_tensor, bucket_t = (
                torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                    unordered_id_tensor,
                    0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
                    0,  # local bucket offset
                    bucket_id_end - bucket_id_start,  # local bucket num
                    bucket_size,
                )
            )
            # Eviction metadata is looked up with the table-linearized ids,
            # before they are converted to global ids below
            metadata_tensor = self._ssd_db.get_kv_zch_eviction_metadata_by_snapshot(
                bucket_ascending_id_tensor + table_offset,
                torch.as_tensor(bucket_ascending_id_tensor.size(0)),
                snapshot_handle,
            ).view(-1, 1)

            # 3. convert local id back to global id
            bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)

            # No keys in the backend but rows are expected locally: treat this
            # as checkpoint loading and substitute an id tensor of the expected
            # length
            if (
                bucket_ascending_id_tensor.size(0) == 0
                and self.local_weight_counts[i] > 0
            ):
                logging.info(
                    f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
                )
                if self.global_id_per_rank[i].numel() != 0:
                    assert (
                        self.local_weight_counts[i]
                        == self.global_id_per_rank[i].numel()
                    ), f"local weight count and global id per rank size mismatch, with {self.local_weight_counts[i]} and {self.global_id_per_rank[i].numel()}"
                    bucket_ascending_id_tensor = self.global_id_per_rank[i].to(
                        device=torch.device("cpu"), dtype=torch.int64
                    )
                else:
                    bucket_ascending_id_tensor = torch.zeros(
                        (self.local_weight_counts[i], 1),
                        device=torch.device("cpu"),
                        dtype=torch.int64,
                    )
                skip_metadata = True
                is_loading_checkpoint = True

            # self.local_weight_counts[i] = 0 # Reset the count

            # pyre-ignore [16] bucket_sorted_id_splits is not None
            bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
            active_id_cnt_per_bucket_split.append(bucket_t)
            if skip_metadata:
                metadata_splits = None
            else:
                metadata_splits.append(metadata_tensor)

            # for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing
            # but in backend, local id is used during training, so the KVTensorWrapper need to convert global id to local id
            # first, then linearize the local id with table offset, the formula is x + table_offset - local_shard_offset
            # to achieve this, the row_offset will be set to (table_offset - local_shard_offset)
            row_offset = table_offset - (bucket_id_start * bucket_size)

        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
            shape=[
                (
                    bucket_ascending_id_tensor.size(0)
                    if bucket_ascending_id_tensor is not None
                    else emb_height
                ),
                self.get_embedding_dim_for_kvt(
                    metaheader_dim, emb_dim, is_loading_checkpoint
                ),
            ],
            dtype=dtype,
            row_offset=row_offset,
            snapshot_handle=snapshot_handle,
            # set bucket_ascending_id_tensor to kvt wrapper, so narrow will follow the id order to return
            # embedding weights.
            sorted_indices=(
                bucket_ascending_id_tensor if self.kv_zch_params else None
            ),
            checkpoint_handle=checkpoint_handle,
            only_load_weight=(
                True
                if self.load_ckpt_without_opt and is_loading_checkpoint
                else False
            ),
        )
        # Attach the backend matching the configured backend type
        (
            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
            if self.backend_type == BackendType.SSD
            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
        )
        table_offset += emb_height
        pmt_splits.append(
            PartiallyMaterializedTensor(
                tensor_wrapper,
                True if self.kv_zch_params else False,
            )
        )
    logging.info(
        f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms, "
    )
    if self.kv_zch_params is not None:
        logging.info(
            # pyre-ignore [16]
            f"num ids list: {[ids.numel() for ids in bucket_sorted_id_splits]}"
        )

    return (
        pmt_splits,
        bucket_sorted_id_splits,
        active_id_cnt_per_bucket_split,
        metadata_splits,
    )
|
|
3351
|
+
|
|
3352
|
+
@torch.jit.ignore
|
|
3353
|
+
def _apply_state_dict_w_offloading(self) -> None:
|
|
3354
|
+
# Row count per table
|
|
3355
|
+
(rows, _) = zip(*self.embedding_specs)
|
|
3356
|
+
# Cumulative row counts per table for rowwise states
|
|
3357
|
+
row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))
|
|
3358
|
+
|
|
3359
|
+
for t, _ in enumerate(self.embedding_specs):
|
|
3360
|
+
# pyre-ignore [16]
|
|
3361
|
+
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
|
|
3362
|
+
# pyre-ignore [16]
|
|
3363
|
+
bucket_size = self.kv_zch_params.bucket_sizes[t]
|
|
3364
|
+
row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size
|
|
3365
|
+
|
|
3366
|
+
# pyre-ignore [16]
|
|
3367
|
+
weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
|
|
3368
|
+
# pyre-ignore [16]
|
|
3369
|
+
opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]
|
|
3370
|
+
|
|
3371
|
+
self.streaming_write_weight_and_id_per_table(
|
|
3372
|
+
weight_state,
|
|
3373
|
+
opt_states,
|
|
3374
|
+
# pyre-ignore [16]
|
|
3375
|
+
self._cached_kvzch_data.cached_id_tensor_per_table[t],
|
|
3376
|
+
row_offset,
|
|
3377
|
+
)
|
|
3378
|
+
self._cached_kvzch_data.cached_weight_tensor_per_table[t] = None
|
|
3379
|
+
self._cached_kvzch_data.cached_optimizer_states_per_table[t] = None
|
|
3380
|
+
|
|
3381
|
+
@torch.jit.ignore
def _apply_state_dict_no_offloading(self) -> None:
    """
    Apply the cached checkpoint data to the backend when optimizer states
    are not offloaded.

    For every table, the cached ids are shifted by the table's row offset to
    obtain backend-local ids. The cached optimizer states are scattered into
    ``self.momentum1_dev`` (and ``self.momentum2_dev`` for the Adam variants)
    at those ids. The cached weights are then written with
    ``self.ssd_db.set_cuda``.
    """
    # Row count per table
    (rows, _) = zip(*self.embedding_specs)
    # Cumulative row counts per table for rowwise states
    row_count_cumsum: list[int] = [0] + list(itertools.accumulate(rows))

    def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
        # Scatter the rows of `src` into `dst` at `indices`, moving both
        # operands to dst's device first.
        device = dst.device
        dst.index_put_(
            indices=(
                # indices is expected to be a tuple of Tensors, not Tensor
                indices.to(device).view(-1),
            ),
            values=src.to(device),
        )

    for t, _ in enumerate(rows):
        # pyre-ignore [16]
        bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
        # pyre-ignore [16]
        bucket_size = self.kv_zch_params.bucket_sizes[t]
        # Converts the table's global ids to table-linearized local ids
        row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size

        # pyre-ignore [16]
        weights = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
        # pyre-ignore [16]
        ids = self._cached_kvzch_data.cached_id_tensor_per_table[t]
        local_ids = ids + row_offset

        logging.info(
            f"applying sd for table {t} without optimizer offloading, local_ids is {local_ids}"
        )
        # pyre-ignore [16]
        opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]

        # Set up the plan for copying optimizer states over, as (src, dst) pairs
        if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
            mapping = [(opt_states[0], self.momentum1_dev)]
        elif self.optimizer in [OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM]:
            mapping = [
                (opt_states[0], self.momentum1_dev),
                (opt_states[1], self.momentum2_dev),
            ]
        else:
            # NOTE(review): cached optimizer states of other optimizers are
            # not restored by this path.
            mapping = []

        # Execute the plan and copy the optimizer states over. A plain loop
        # is used because the copies are pure side effects.
        for src, dst in mapping:
            # pyre-ignore [6]
            copy_optimizer_state_(dst, src, local_ids)

        self.ssd_db.set_cuda(
            local_ids.view(-1),
            weights,
            torch.as_tensor(local_ids.size(0)),
            1,
            False,
        )
|
|
3439
|
+
|
|
3440
|
+
@torch.jit.ignore
def apply_state_dict(self) -> None:
    """
    Push the state cached during checkpoint loading into the backend.

    After checkpoint loading, ``_cached_kvzch_data`` holds the loaded
    weights, ids and optimizer states, and callers invoke this method to
    apply them. Nothing happens when the backend returns whole rows or when
    ``self.load_state_dict is False``. Otherwise load_state_dict mode is
    turned off, the per-table data is written (with or without optimizer
    offloading), and the cache is cleared.
    """
    if self.backend_return_whole_row:
        logging.info(
            "backend_return_whole_row is enabled, no need to apply_state_dict"
        )
        return
    if self.load_state_dict is False:
        return

    self.load_state_dict = False

    assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
    cached = self._cached_kvzch_data
    assert (
        cached is not None
        and cached.cached_optimizer_states_per_table is not None
    ), "optimizer state is not initialized for load checkpointing"
    assert (
        cached.cached_weight_tensor_per_table is not None
        and cached.cached_id_tensor_per_table is not None
    ), "weight and id state is not initialized for load checkpointing"

    # Apply weight and optimizer state per table
    write_tables = (
        self._apply_state_dict_w_offloading
        if self.enable_optimizer_offloading
        else self._apply_state_dict_no_offloading
    )
    write_tables()

    self.clear_cache()
|
|
3472
|
+
|
|
3473
|
+
@torch.jit.ignore
def streaming_write_weight_and_id_per_table(
    self,
    weight_state: torch.Tensor,
    opt_states: list[torch.Tensor],
    id_tensor: torch.Tensor,
    row_offset: int,
) -> None:
    """
    Write one table's weights, optimizer states and ids to the backend
    through a KVTensorWrapper.

    To avoid over-using memory, rows are written in a rolling window: each
    chunk of up to ``chunk_size`` rows is assembled into a CPU buffer of
    width ``self.cache_row_dim`` (weights dtype), holding the weights in the
    first ``D`` columns and each optimizer state at its byte range from
    ``self.optimizer.byte_offsets_along_row``.

    Args:
        weight_state (torch.Tensor): The weight state tensor to be written,
            indexed as ``[rows, D]``.
        opt_states (list[torch.Tensor]): The optimizer state tensor(s) to be
            written, in ``self.optimizer.state_names()`` order. Each is
            assumed to already have dtype
            ``self.optimizer_state_dtypes[state_name]``.
        id_tensor (torch.Tensor): The id tensor to be written; sliced as
            ``[rows, :]`` (2D) and flattened per chunk.
        row_offset (int): Row offset handed to the KVTensorWrapper, used to
            map ids to backend rows.
    """
    D = weight_state.size(1)
    dtype = self.weights_precision.as_dtype()

    optimizer_state_byte_offsets = self.optimizer.byte_offsets_along_row(
        D, self.weights_precision, self.optimizer_state_dtypes
    )
    optimizer_state_size_table = self.optimizer.state_size_table(D)

    kvt = torch.classes.fbgemm.KVTensorWrapper(
        shape=[weight_state.size(0), self.cache_row_dim],
        dtype=dtype,
        row_offset=row_offset,
        snapshot_handle=None,
        sorted_indices=id_tensor,
    )
    (
        kvt.set_embedding_rocks_dp_wrapper(self.ssd_db)
        if self.backend_type == BackendType.SSD
        else kvt.set_dram_db_wrapper(self.ssd_db)
    )

    # Loop-invariant: resolve each optimizer state's byte range and byte view
    # once, instead of once per chunk.
    state_names = self.optimizer.state_names()
    opt_state_byte_slots = []
    for o, opt_state in enumerate(opt_states):
        state_name = state_names[o]
        (start, end) = optimizer_state_byte_offsets[state_name]
        # Force a 2D table whose row size matches the optimizer state size,
        # then reinterpret it as raw bytes.
        opt_state_byteview = opt_state.view(
            -1, optimizer_state_size_table[state_name]
        ).view(dtype=torch.uint8)
        opt_state_byte_slots.append((start, end, opt_state_byteview))

    # TODO: make chunk_size configurable or dynamic
    chunk_size = 10000
    row = weight_state.size(0)

    for i in range(0, row, chunk_size):
        length = min(chunk_size, row - i)
        # Zero-initialized (rather than torch.empty) so that any bytes not
        # covered by the weights or an optimizer state, e.g. alignment
        # padding, are not written to the backend as uninitialized memory.
        chunk_buffer = torch.zeros(
            length,
            self.cache_row_dim,
            dtype=dtype,
            device="cpu",
        )

        # Copy the weight state over to the chunk buffer
        chunk_buffer[:, :D] = weight_state[i : i + length, :]

        # Copy each optimizer state into its byte range of the chunk buffer
        chunk_bytes = chunk_buffer.view(dtype=torch.uint8)
        for start, end, opt_state_byteview in opt_state_byte_slots:
            chunk_bytes[:, start:end] = opt_state_byteview[i : i + length, :]

        # Write chunk to KVTensor
        kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))
|
|
3557
|
+
|
|
3558
|
+
@torch.jit.ignore
def enable_load_state_dict_mode(self) -> None:
    """
    Switch into load_state_dict mode and pre-allocate CPU staging buffers.

    For every table ``i`` (with ``emb_dim`` taken from ``embedding_specs``) this
    allocates on CPU:
      * an uninitialized weight tensor of shape
        ``(local_weight_counts[i], emb_dim)`` in the weights precision,
      * a zero-filled int64 id tensor of shape ``(local_weight_counts[i], 1)``,
      * an uninitialized int64 bucket-split tensor of shape
        ``(bucket_id_end - bucket_id_start, 1)`` taken from
        ``kv_zch_params.bucket_offsets[i]``.
    Optimizer states for all tables are allocated once with
    ``self.optimizer.empty_states``. Everything is stored on a fresh
    ``KVZCHCachedData`` assigned to ``self._cached_kvzch_data``.

    No-op when ``backend_return_whole_row`` is set or when the mode is already
    enabled.

    Raises:
        AssertionError: if any ``local_weight_counts`` entry is not positive.
            Validation happens before any state is touched, and
            ``self.load_state_dict`` is only set once all buffers exist, so a
            failed call can simply be retried.
    """
    if self.backend_return_whole_row:
        logging.info(
            "backend_return_whole_row is enabled, no need to enable load_state_dict mode"
        )
        return
    # Enable load state dict mode before loading checkpoint
    if self.load_state_dict:
        return

    # For checkpoint loading we need to store the weight and id tensors
    # temporarily in memory. Check that local_weight_counts are set BEFORE
    # mutating any state: flipping the flag first would leave the module marked
    # as enabled (with empty cached data) after a failed check, and every later
    # call would then return early without allocating anything.
    for i, _ in enumerate(self.embedding_specs):
        assert (
            self.local_weight_counts[i] > 0
        ), f"local_weight_counts for table {i} is not set"

    dtype = self.weights_precision.as_dtype()
    (_, dims) = zip(*self.embedding_specs)

    self._cached_kvzch_data = KVZCHCachedData([], [], [], [])

    # pyre-ignore [16]
    self._cached_kvzch_data.cached_optimizer_states_per_table = (
        self.optimizer.empty_states(
            self.local_weight_counts,
            dims,
            self.optimizer_state_dtypes,
        )
    )

    for i, (_, emb_dim) in enumerate(self.embedding_specs):
        # pyre-ignore [16]
        bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
        rows = self.local_weight_counts[i]
        weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
        # pyre-ignore [16]
        self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
        logging.info(
            f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}"
        )
        id_tensor = torch.zeros((rows, 1), dtype=torch.int64, device="cpu")
        # pyre-ignore [16]
        self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
        # pyre-ignore [16]
        self._cached_kvzch_data.cached_bucket_splits.append(
            torch.empty(
                (bucket_id_end - bucket_id_start, 1),
                dtype=torch.int64,
                device="cpu",
            )
        )

    # Only mark the mode as enabled once every staging buffer exists.
    self.load_state_dict = True
|
|
3613
|
+
|
|
3614
|
+
@torch.jit.export
def set_learning_rate(self, lr: float) -> None:
    """
    Update the learning rate used by this module.

    Args:
        lr (float): New learning rate value.
    """
    # The work happens in a TorchScript-ignored helper that rebuilds the tensor.
    self._set_learning_rate(lr)
|
|
3623
|
+
|
|
3624
|
+
def get_learning_rate(self) -> float:
    """
    Return the current learning rate as a Python number.
    """
    lr_tensor = self.learning_rate_tensor
    return lr_tensor.item()
|
|
3629
|
+
|
|
3630
|
+
@torch.jit.ignore
|
|
3631
|
+
def _set_learning_rate(self, lr: float) -> float:
|
|
3632
|
+
"""
|
|
3633
|
+
Helper function to script `set_learning_rate`.
|
|
3634
|
+
Note that returning None does not work.
|
|
3635
|
+
"""
|
|
3636
|
+
self.learning_rate_tensor = torch.tensor(
|
|
3637
|
+
lr, device=torch.device("cpu"), dtype=torch.float32
|
|
3638
|
+
)
|
|
3639
|
+
return 0.0
|
|
3640
|
+
|
|
3641
|
+
def flush(self, force: bool = False) -> None:
    """
    Write every occupied L1 cache slot back to the KV backend, then flush it.

    Slots whose ``lxu_cache_state`` entry is not -1 are gathered together with
    their ids, copied to CPU, and written with ``ssd_db.set``. Then
    ``ssd_db.flush()`` is called and ``last_flush_step`` is set to the current
    step.

    Args:
        force (bool): flush even if a flush already happened at the current
            step.

    Does nothing in eval mode (``self.training`` is False). Unless ``force`` is
    set, it also does nothing when ``self.step == self.last_flush_step``. The
    call is logged as expensive.
    """
    # allow force flush from split_embedding_weights to cover edge cases, e.g. checkpointing
    # after trained 0 batches
    if not self.training:
        # for eval mode, we should not write anything to embedding
        return

    if self.step == self.last_flush_step and not force:
        logging.info(
            f"SSD TBE has been flushed at {self.last_flush_step=} already for tbe:{self.tbe_unique_id}"
        )
        return
    logging.info(
        f"SSD TBE flush at {self.step=}, it is an expensive call please be cautious"
    )
    # A cache-state entry of -1 marks an empty slot; only occupied slots are written back.
    active_slots_mask = self.lxu_cache_state != -1

    # Select the cache rows of occupied slots, reshaped to (num_active, cache_row_dim).
    active_weights_gpu = self.lxu_cache_weights[active_slots_mask.view(-1)].view(
        -1, self.cache_row_dim
    )
    # The ids stored in lxu_cache_state for the same occupied slots, in the same order.
    active_ids_gpu = self.lxu_cache_state.view(-1)[active_slots_mask.view(-1)]

    # NOTE(review): these device-to-host copies are issued before waiting on
    # ssd_eviction_stream below; confirm that stream cannot still be writing to
    # lxu_cache_weights / lxu_cache_state at this point.
    active_weights_cpu = active_weights_gpu.cpu()
    active_ids_cpu = active_ids_gpu.cpu()

    # Make the current stream wait for pending eviction-stream work.
    torch.cuda.current_stream().wait_stream(self.ssd_eviction_stream)

    # Block the host until all device work finishes before touching the backend.
    torch.cuda.synchronize()
    self.ssd_db.set(
        active_ids_cpu,
        active_weights_cpu,
        # Number of rows being written, passed as a 1-element tensor.
        torch.tensor([active_ids_cpu.numel()]),
    )
    self.ssd_db.flush()
    self.last_flush_step = self.step
|
|
3676
|
+
|
|
3677
|
+
def create_rocksdb_hard_link_snapshot(self) -> None:
    """
    Create a RocksDB hard-link snapshot at the current step.

    The snapshot gives other processes access to the underlying data. It is
    only available with the SSD backend; for any other backend a warning is
    logged and nothing is created.
    """
    if self.backend_type != BackendType.SSD:
        logging.warning(
            "create_rocksdb_hard_link_snapshot is only supported for SSD backend"
        )
        return
    self.ssd_db.create_rocksdb_hard_link_snapshot(self.step)
|
|
3687
|
+
|
|
3688
|
+
def prepare_inputs(
    self,
    indices: Tensor,
    offsets: Tensor,
    per_sample_weights: Optional[Tensor] = None,
    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
    """
    Normalize TBE lookup inputs and optionally bounds-check them.

    The steps run in this order:
      1. build VBE metadata from the original (not yet cast) ``offsets`` and
         ``batch_size_per_feature_per_rank``;
      2. cast ``indices`` and ``offsets`` to int64, and ``per_sample_weights``
         (when given) to float32;
      3. unless ``bounds_check_mode_int`` is ``BoundsCheckMode.NONE``, call
         ``torch.ops.fbgemm.bounds_check_indices`` on the cast tensors.

    Returns:
        ``(indices, offsets, per_sample_weights, vbe_metadata)``.
    """
    vbe_metadata = self._generate_vbe_metadata(
        offsets, batch_size_per_feature_per_rank
    )

    indices = indices.long()
    offsets = offsets.long()
    if per_sample_weights is not None:
        per_sample_weights = per_sample_weights.float()

    bounds_check_enabled = self.bounds_check_mode_int != BoundsCheckMode.NONE.value
    if bounds_check_enabled:
        torch.ops.fbgemm.bounds_check_indices(
            self.rows_per_table,
            indices,
            offsets,
            self.bounds_check_mode_int,
            self.bounds_check_warning,
            per_sample_weights,
            B_offsets=vbe_metadata.B_offsets,
            max_B=vbe_metadata.max_B,
            bounds_check_version=self.bounds_check_version,
        )

    return indices, offsets, per_sample_weights, vbe_metadata
|
|
3724
|
+
|
|
3725
|
+
@torch.jit.ignore
|
|
3726
|
+
def _report_kv_backend_stats(self) -> None:
|
|
3727
|
+
"""
|
|
3728
|
+
All ssd stats report function entrance
|
|
3729
|
+
"""
|
|
3730
|
+
if self.stats_reporter is None:
|
|
3731
|
+
return
|
|
3732
|
+
|
|
3733
|
+
if not self.stats_reporter.should_report(self.step):
|
|
3734
|
+
return
|
|
3735
|
+
self._report_ssd_l1_cache_stats()
|
|
3736
|
+
|
|
3737
|
+
if self.backend_type == BackendType.SSD:
|
|
3738
|
+
self._report_ssd_io_stats()
|
|
3739
|
+
self._report_ssd_mem_usage()
|
|
3740
|
+
self._report_l2_cache_perf_stats()
|
|
3741
|
+
if self.backend_type == BackendType.DRAM:
|
|
3742
|
+
self._report_dram_kv_perf_stats()
|
|
3743
|
+
if self.kv_zch_params and self.kv_zch_params.eviction_policy:
|
|
3744
|
+
self._report_eviction_stats()
|
|
3745
|
+
|
|
3746
|
+
@torch.jit.ignore
|
|
3747
|
+
def _report_ssd_l1_cache_stats(self) -> None:
|
|
3748
|
+
"""
|
|
3749
|
+
Each iteration we will record cache stats about L1 SSD cache in ssd_cache_stats tensor
|
|
3750
|
+
this function extract those stats and report it with stats_reporter
|
|
3751
|
+
"""
|
|
3752
|
+
passed_steps = self.step - self.last_reported_step
|
|
3753
|
+
if passed_steps == 0:
|
|
3754
|
+
return
|
|
3755
|
+
|
|
3756
|
+
# ssd hbm cache stats
|
|
3757
|
+
|
|
3758
|
+
ssd_cache_stats = self.ssd_cache_stats.tolist()
|
|
3759
|
+
if len(self.last_reported_ssd_stats) == 0:
|
|
3760
|
+
self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
|
|
3761
|
+
ssd_cache_stats_delta: list[float] = [0.0] * len(ssd_cache_stats)
|
|
3762
|
+
for i in range(len(ssd_cache_stats)):
|
|
3763
|
+
ssd_cache_stats_delta[i] = (
|
|
3764
|
+
ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
|
|
3765
|
+
)
|
|
3766
|
+
self.last_reported_step = self.step
|
|
3767
|
+
self.last_reported_ssd_stats = ssd_cache_stats
|
|
3768
|
+
element_size = self.lxu_cache_weights.element_size()
|
|
3769
|
+
|
|
3770
|
+
for stat_index in UVMCacheStatsIndex:
|
|
3771
|
+
# pyre-ignore
|
|
3772
|
+
self.stats_reporter.report_data_amount(
|
|
3773
|
+
iteration_step=self.step,
|
|
3774
|
+
event_name=f"ssd_tbe.prefetch.cache_stats_by_data_size.{stat_index.name.lower()}",
|
|
3775
|
+
data_bytes=int(
|
|
3776
|
+
ssd_cache_stats_delta[stat_index.value]
|
|
3777
|
+
* element_size
|
|
3778
|
+
* self.cache_row_dim
|
|
3779
|
+
/ passed_steps
|
|
3780
|
+
),
|
|
3781
|
+
)
|
|
3782
|
+
|
|
3783
|
+
self.stats_reporter.report_data_amount(
|
|
3784
|
+
iteration_step=self.step,
|
|
3785
|
+
event_name=f"ssd_tbe.prefetch.cache_stats.{stat_index.name.lower()}",
|
|
3786
|
+
data_bytes=int(ssd_cache_stats_delta[stat_index.value] / passed_steps),
|
|
3787
|
+
)
|
|
3788
|
+
|
|
3789
|
+
@torch.jit.ignore
|
|
3790
|
+
def _report_ssd_io_stats(self) -> None:
|
|
3791
|
+
"""
|
|
3792
|
+
EmbeddingRocksDB will hold stats for total read/write duration in fwd/bwd
|
|
3793
|
+
this function fetch the stats from EmbeddingRocksDB and report it with stats_reporter
|
|
3794
|
+
"""
|
|
3795
|
+
ssd_io_duration = self.ssd_db.get_rocksdb_io_duration(
|
|
3796
|
+
self.step, self.stats_reporter.report_interval # pyre-ignore
|
|
3797
|
+
)
|
|
3798
|
+
|
|
3799
|
+
if len(ssd_io_duration) != 5:
|
|
3800
|
+
logging.error("ssd io duration should have 5 elements")
|
|
3801
|
+
return
|
|
3802
|
+
|
|
3803
|
+
ssd_read_dur_us = ssd_io_duration[0]
|
|
3804
|
+
fwd_rocksdb_read_dur = ssd_io_duration[1]
|
|
3805
|
+
fwd_l1_eviction_dur = ssd_io_duration[2]
|
|
3806
|
+
bwd_l1_cnflct_miss_write_back_dur = ssd_io_duration[3]
|
|
3807
|
+
flush_write_dur = ssd_io_duration[4]
|
|
3808
|
+
|
|
3809
|
+
# pyre-ignore [16]
|
|
3810
|
+
self.stats_reporter.report_duration(
|
|
3811
|
+
iteration_step=self.step,
|
|
3812
|
+
event_name="ssd.io_duration.read_us",
|
|
3813
|
+
duration_ms=ssd_read_dur_us,
|
|
3814
|
+
time_unit="us",
|
|
3815
|
+
)
|
|
3816
|
+
|
|
3817
|
+
self.stats_reporter.report_duration(
|
|
3818
|
+
iteration_step=self.step,
|
|
3819
|
+
event_name="ssd.io_duration.write.fwd_rocksdb_read_us",
|
|
3820
|
+
duration_ms=fwd_rocksdb_read_dur,
|
|
3821
|
+
time_unit="us",
|
|
3822
|
+
)
|
|
3823
|
+
|
|
3824
|
+
self.stats_reporter.report_duration(
|
|
3825
|
+
iteration_step=self.step,
|
|
3826
|
+
event_name="ssd.io_duration.write.fwd_l1_eviction_us",
|
|
3827
|
+
duration_ms=fwd_l1_eviction_dur,
|
|
3828
|
+
time_unit="us",
|
|
3829
|
+
)
|
|
3830
|
+
|
|
3831
|
+
self.stats_reporter.report_duration(
|
|
3832
|
+
iteration_step=self.step,
|
|
3833
|
+
event_name="ssd.io_duration.write.bwd_l1_cnflct_miss_write_back_us",
|
|
3834
|
+
duration_ms=bwd_l1_cnflct_miss_write_back_dur,
|
|
3835
|
+
time_unit="us",
|
|
3836
|
+
)
|
|
3837
|
+
|
|
3838
|
+
self.stats_reporter.report_duration(
|
|
3839
|
+
iteration_step=self.step,
|
|
3840
|
+
event_name="ssd.io_duration.write.flush_write_us",
|
|
3841
|
+
duration_ms=flush_write_dur,
|
|
3842
|
+
time_unit="us",
|
|
3843
|
+
)
|
|
3844
|
+
|
|
3845
|
+
@torch.jit.ignore
|
|
3846
|
+
def _report_ssd_mem_usage(
|
|
3847
|
+
self,
|
|
3848
|
+
) -> None:
|
|
3849
|
+
"""
|
|
3850
|
+
rocskdb has internal stats for dram mem usage, here we call EmbeddingRocksDB to
|
|
3851
|
+
extract those stats out and report it with stats_reporter
|
|
3852
|
+
"""
|
|
3853
|
+
mem_usage_list = self.ssd_db.get_mem_usage()
|
|
3854
|
+
block_cache_usage = mem_usage_list[0]
|
|
3855
|
+
estimate_table_reader_usage = mem_usage_list[1]
|
|
3856
|
+
memtable_usage = mem_usage_list[2]
|
|
3857
|
+
block_cache_pinned_usage = mem_usage_list[3]
|
|
3858
|
+
|
|
3859
|
+
# pyre-ignore [16]
|
|
3860
|
+
self.stats_reporter.report_data_amount(
|
|
3861
|
+
iteration_step=self.step,
|
|
3862
|
+
event_name="ssd.mem_usage.block_cache",
|
|
3863
|
+
data_bytes=block_cache_usage,
|
|
3864
|
+
)
|
|
3865
|
+
|
|
3866
|
+
self.stats_reporter.report_data_amount(
|
|
3867
|
+
iteration_step=self.step,
|
|
3868
|
+
event_name="ssd.mem_usage.estimate_table_reader",
|
|
3869
|
+
data_bytes=estimate_table_reader_usage,
|
|
3870
|
+
)
|
|
3871
|
+
|
|
3872
|
+
self.stats_reporter.report_data_amount(
|
|
3873
|
+
iteration_step=self.step,
|
|
3874
|
+
event_name="ssd.mem_usage.memtable",
|
|
3875
|
+
data_bytes=memtable_usage,
|
|
3876
|
+
)
|
|
3877
|
+
|
|
3878
|
+
self.stats_reporter.report_data_amount(
|
|
3879
|
+
iteration_step=self.step,
|
|
3880
|
+
event_name="ssd.mem_usage.block_cache_pinned",
|
|
3881
|
+
data_bytes=block_cache_pinned_usage,
|
|
3882
|
+
)
|
|
3883
|
+
|
|
3884
|
+
@torch.jit.ignore
|
|
3885
|
+
def _report_l2_cache_perf_stats(self) -> None:
|
|
3886
|
+
"""
|
|
3887
|
+
EmbeddingKVDB will hold stats for L2+SSD performance in fwd/bwd
|
|
3888
|
+
this function fetch the stats from EmbeddingKVDB and report it with stats_reporter
|
|
3889
|
+
"""
|
|
3890
|
+
if self.stats_reporter is None:
|
|
3891
|
+
return
|
|
3892
|
+
|
|
3893
|
+
stats_reporter: TBEStatsReporter = self.stats_reporter
|
|
3894
|
+
if not stats_reporter.should_report(self.step):
|
|
3895
|
+
return
|
|
3896
|
+
|
|
3897
|
+
l2_cache_perf_stats = self.ssd_db.get_l2cache_perf(
|
|
3898
|
+
self.step, stats_reporter.report_interval # pyre-ignore
|
|
3899
|
+
)
|
|
3900
|
+
|
|
3901
|
+
if len(l2_cache_perf_stats) != 15:
|
|
3902
|
+
logging.error("l2 perf stats should have 15 elements")
|
|
3903
|
+
return
|
|
3904
|
+
|
|
3905
|
+
num_cache_misses = l2_cache_perf_stats[0]
|
|
3906
|
+
num_lookups = l2_cache_perf_stats[1]
|
|
3907
|
+
get_total_duration = l2_cache_perf_stats[2]
|
|
3908
|
+
get_cache_lookup_total_duration = l2_cache_perf_stats[3]
|
|
3909
|
+
get_cache_lookup_wait_filling_thread_duration = l2_cache_perf_stats[4]
|
|
3910
|
+
get_weights_fillup_total_duration = l2_cache_perf_stats[5]
|
|
3911
|
+
get_cache_memcpy_duration = l2_cache_perf_stats[6]
|
|
3912
|
+
total_cache_update_duration = l2_cache_perf_stats[7]
|
|
3913
|
+
get_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[8]
|
|
3914
|
+
set_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[9]
|
|
3915
|
+
num_l2_evictions = l2_cache_perf_stats[10]
|
|
3916
|
+
|
|
3917
|
+
l2_cache_free_bytes = l2_cache_perf_stats[11]
|
|
3918
|
+
l2_cache_capacity = l2_cache_perf_stats[12]
|
|
3919
|
+
|
|
3920
|
+
set_cache_lock_wait_duration = l2_cache_perf_stats[13]
|
|
3921
|
+
get_cache_lock_wait_duration = l2_cache_perf_stats[14]
|
|
3922
|
+
|
|
3923
|
+
stats_reporter.report_data_amount(
|
|
3924
|
+
iteration_step=self.step,
|
|
3925
|
+
event_name=self.l2_num_cache_misses_stats_name,
|
|
3926
|
+
data_bytes=num_cache_misses,
|
|
3927
|
+
)
|
|
3928
|
+
stats_reporter.report_data_amount(
|
|
3929
|
+
iteration_step=self.step,
|
|
3930
|
+
event_name=self.l2_num_cache_lookups_stats_name,
|
|
3931
|
+
data_bytes=num_lookups,
|
|
3932
|
+
)
|
|
3933
|
+
stats_reporter.report_data_amount(
|
|
3934
|
+
iteration_step=self.step,
|
|
3935
|
+
event_name=self.l2_num_cache_evictions_stats_name,
|
|
3936
|
+
data_bytes=num_l2_evictions,
|
|
3937
|
+
)
|
|
3938
|
+
stats_reporter.report_data_amount(
|
|
3939
|
+
iteration_step=self.step,
|
|
3940
|
+
event_name=self.l2_cache_capacity_stats_name,
|
|
3941
|
+
data_bytes=l2_cache_capacity,
|
|
3942
|
+
)
|
|
3943
|
+
stats_reporter.report_data_amount(
|
|
3944
|
+
iteration_step=self.step,
|
|
3945
|
+
event_name=self.l2_cache_free_mem_stats_name,
|
|
3946
|
+
data_bytes=l2_cache_free_bytes,
|
|
3947
|
+
)
|
|
3948
|
+
|
|
3949
|
+
stats_reporter.report_duration(
|
|
3950
|
+
iteration_step=self.step,
|
|
3951
|
+
event_name="l2_cache.perf.get.total_duration_us",
|
|
3952
|
+
duration_ms=get_total_duration,
|
|
3953
|
+
time_unit="us",
|
|
3954
|
+
)
|
|
3955
|
+
stats_reporter.report_duration(
|
|
3956
|
+
iteration_step=self.step,
|
|
3957
|
+
event_name="l2_cache.perf.get.cache_lookup_duration_us",
|
|
3958
|
+
duration_ms=get_cache_lookup_total_duration,
|
|
3959
|
+
time_unit="us",
|
|
3960
|
+
)
|
|
3961
|
+
stats_reporter.report_duration(
|
|
3962
|
+
iteration_step=self.step,
|
|
3963
|
+
event_name="l2_cache.perf.get.cache_lookup_wait_filling_thread_duration_us",
|
|
3964
|
+
duration_ms=get_cache_lookup_wait_filling_thread_duration,
|
|
3965
|
+
time_unit="us",
|
|
3966
|
+
)
|
|
3967
|
+
stats_reporter.report_duration(
|
|
3968
|
+
iteration_step=self.step,
|
|
3969
|
+
event_name="l2_cache.perf.get.weights_fillup_duration_us",
|
|
3970
|
+
duration_ms=get_weights_fillup_total_duration,
|
|
3971
|
+
time_unit="us",
|
|
3972
|
+
)
|
|
3973
|
+
stats_reporter.report_duration(
|
|
3974
|
+
iteration_step=self.step,
|
|
3975
|
+
event_name="l2_cache.perf.get.cache_memcpy_duration_us",
|
|
3976
|
+
duration_ms=get_cache_memcpy_duration,
|
|
3977
|
+
time_unit="us",
|
|
3978
|
+
)
|
|
3979
|
+
stats_reporter.report_duration(
|
|
3980
|
+
iteration_step=self.step,
|
|
3981
|
+
event_name="l2_cache.perf.total.cache_update_duration_us",
|
|
3982
|
+
duration_ms=total_cache_update_duration,
|
|
3983
|
+
time_unit="us",
|
|
3984
|
+
)
|
|
3985
|
+
stats_reporter.report_duration(
|
|
3986
|
+
iteration_step=self.step,
|
|
3987
|
+
event_name="l2_cache.perf.get.tensor_copy_for_cache_update_duration_us",
|
|
3988
|
+
duration_ms=get_tensor_copy_for_cache_update_duration,
|
|
3989
|
+
time_unit="us",
|
|
3990
|
+
)
|
|
3991
|
+
stats_reporter.report_duration(
|
|
3992
|
+
iteration_step=self.step,
|
|
3993
|
+
event_name="l2_cache.perf.set.tensor_copy_for_cache_update_duration_us",
|
|
3994
|
+
duration_ms=set_tensor_copy_for_cache_update_duration,
|
|
3995
|
+
time_unit="us",
|
|
3996
|
+
)
|
|
3997
|
+
|
|
3998
|
+
stats_reporter.report_duration(
|
|
3999
|
+
iteration_step=self.step,
|
|
4000
|
+
event_name="l2_cache.perf.get.cache_lock_wait_duration_us",
|
|
4001
|
+
duration_ms=get_cache_lock_wait_duration,
|
|
4002
|
+
time_unit="us",
|
|
4003
|
+
)
|
|
4004
|
+
stats_reporter.report_duration(
|
|
4005
|
+
iteration_step=self.step,
|
|
4006
|
+
event_name="l2_cache.perf.set.cache_lock_wait_duration_us",
|
|
4007
|
+
duration_ms=set_cache_lock_wait_duration,
|
|
4008
|
+
time_unit="us",
|
|
4009
|
+
)
|
|
4010
|
+
|
|
4011
|
+
@torch.jit.ignore
|
|
4012
|
+
def _report_eviction_stats(self) -> None:
|
|
4013
|
+
if self.stats_reporter is None:
|
|
4014
|
+
return
|
|
4015
|
+
|
|
4016
|
+
stats_reporter: TBEStatsReporter = self.stats_reporter
|
|
4017
|
+
if not stats_reporter.should_report(self.step):
|
|
4018
|
+
return
|
|
4019
|
+
|
|
4020
|
+
# skip metrics reporting when evicting disabled
|
|
4021
|
+
if self.kv_zch_params.eviction_policy.eviction_trigger_mode == 0:
|
|
4022
|
+
return
|
|
4023
|
+
|
|
4024
|
+
T = len(set(self.feature_table_map))
|
|
4025
|
+
evicted_counts = torch.zeros(T, dtype=torch.int64)
|
|
4026
|
+
processed_counts = torch.zeros(T, dtype=torch.int64)
|
|
4027
|
+
eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float)
|
|
4028
|
+
full_duration_ms = torch.tensor(0, dtype=torch.int64)
|
|
4029
|
+
exec_duration_ms = torch.tensor(0, dtype=torch.int64)
|
|
4030
|
+
self.ssd_db.get_feature_evict_metric(
|
|
4031
|
+
evicted_counts,
|
|
4032
|
+
processed_counts,
|
|
4033
|
+
eviction_threshold_with_dry_run,
|
|
4034
|
+
full_duration_ms,
|
|
4035
|
+
exec_duration_ms,
|
|
4036
|
+
)
|
|
4037
|
+
|
|
4038
|
+
stats_reporter.report_data_amount(
|
|
4039
|
+
iteration_step=self.step,
|
|
4040
|
+
event_name=self.eviction_sum_evicted_counts_stats_name,
|
|
4041
|
+
data_bytes=int(evicted_counts.sum().item()),
|
|
4042
|
+
enable_tb_metrics=True,
|
|
4043
|
+
)
|
|
4044
|
+
stats_reporter.report_data_amount(
|
|
4045
|
+
iteration_step=self.step,
|
|
4046
|
+
event_name=self.eviction_sum_processed_counts_stats_name,
|
|
4047
|
+
data_bytes=int(processed_counts.sum().item()),
|
|
4048
|
+
enable_tb_metrics=True,
|
|
4049
|
+
)
|
|
4050
|
+
if processed_counts.sum().item() != 0:
|
|
4051
|
+
stats_reporter.report_data_amount(
|
|
4052
|
+
iteration_step=self.step,
|
|
4053
|
+
event_name=self.eviction_evict_rate_stats_name,
|
|
4054
|
+
data_bytes=int(
|
|
4055
|
+
evicted_counts.sum().item() * 100 / processed_counts.sum().item()
|
|
4056
|
+
),
|
|
4057
|
+
enable_tb_metrics=True,
|
|
4058
|
+
)
|
|
4059
|
+
for t in self.feature_table_map:
|
|
4060
|
+
stats_reporter.report_data_amount(
|
|
4061
|
+
iteration_step=self.step,
|
|
4062
|
+
event_name=f"eviction.feature_table.{t}.evicted_counts",
|
|
4063
|
+
data_bytes=int(evicted_counts[t].item()),
|
|
4064
|
+
enable_tb_metrics=True,
|
|
4065
|
+
)
|
|
4066
|
+
stats_reporter.report_data_amount(
|
|
4067
|
+
iteration_step=self.step,
|
|
4068
|
+
event_name=f"eviction.feature_table.{t}.processed_counts",
|
|
4069
|
+
data_bytes=int(processed_counts[t].item()),
|
|
4070
|
+
enable_tb_metrics=True,
|
|
4071
|
+
)
|
|
4072
|
+
if processed_counts[t].item() != 0:
|
|
4073
|
+
stats_reporter.report_data_amount(
|
|
4074
|
+
iteration_step=self.step,
|
|
4075
|
+
event_name=f"eviction.feature_table.{t}.evict_rate",
|
|
4076
|
+
data_bytes=int(
|
|
4077
|
+
evicted_counts[t].item() * 100 / processed_counts[t].item()
|
|
4078
|
+
),
|
|
4079
|
+
enable_tb_metrics=True,
|
|
4080
|
+
)
|
|
4081
|
+
stats_reporter.report_duration(
|
|
4082
|
+
iteration_step=self.step,
|
|
4083
|
+
event_name="eviction.feature_table.full_duration_ms",
|
|
4084
|
+
duration_ms=full_duration_ms.item(),
|
|
4085
|
+
time_unit="ms",
|
|
4086
|
+
enable_tb_metrics=True,
|
|
4087
|
+
)
|
|
4088
|
+
stats_reporter.report_duration(
|
|
4089
|
+
iteration_step=self.step,
|
|
4090
|
+
event_name="eviction.feature_table.exec_duration_ms",
|
|
4091
|
+
duration_ms=exec_duration_ms.item(),
|
|
4092
|
+
time_unit="ms",
|
|
4093
|
+
enable_tb_metrics=True,
|
|
4094
|
+
)
|
|
4095
|
+
if full_duration_ms.item() != 0:
|
|
4096
|
+
stats_reporter.report_data_amount(
|
|
4097
|
+
iteration_step=self.step,
|
|
4098
|
+
event_name="eviction.feature_table.exec_div_full_duration_rate",
|
|
4099
|
+
data_bytes=int(exec_duration_ms.item() * 100 / full_duration_ms.item()),
|
|
4100
|
+
enable_tb_metrics=True,
|
|
4101
|
+
)
|
|
4102
|
+
|
|
4103
|
+
@torch.jit.ignore
|
|
4104
|
+
def _report_dram_kv_perf_stats(self) -> None:
|
|
4105
|
+
"""
|
|
4106
|
+
EmbeddingKVDB will hold stats for DRAM cache performance in fwd/bwd
|
|
4107
|
+
this function fetch the stats from EmbeddingKVDB and report it with stats_reporter
|
|
4108
|
+
"""
|
|
4109
|
+
if self.stats_reporter is None:
|
|
4110
|
+
return
|
|
4111
|
+
|
|
4112
|
+
stats_reporter: TBEStatsReporter = self.stats_reporter
|
|
4113
|
+
if not stats_reporter.should_report(self.step):
|
|
4114
|
+
return
|
|
4115
|
+
|
|
4116
|
+
dram_kv_perf_stats = self.ssd_db.get_dram_kv_perf(
|
|
4117
|
+
self.step, stats_reporter.report_interval # pyre-ignore
|
|
4118
|
+
)
|
|
4119
|
+
|
|
4120
|
+
if len(dram_kv_perf_stats) != 36:
|
|
4121
|
+
logging.error("dram cache perf stats should have 36 elements")
|
|
4122
|
+
return
|
|
4123
|
+
|
|
4124
|
+
dram_read_duration = dram_kv_perf_stats[0]
|
|
4125
|
+
dram_read_sharding_duration = dram_kv_perf_stats[1]
|
|
4126
|
+
dram_read_cache_hit_copy_duration = dram_kv_perf_stats[2]
|
|
4127
|
+
dram_read_fill_row_storage_duration = dram_kv_perf_stats[3]
|
|
4128
|
+
dram_read_lookup_cache_duration = dram_kv_perf_stats[4]
|
|
4129
|
+
dram_read_acquire_lock_duration = dram_kv_perf_stats[5]
|
|
4130
|
+
dram_read_missing_load = dram_kv_perf_stats[6]
|
|
4131
|
+
dram_write_sharing_duration = dram_kv_perf_stats[7]
|
|
4132
|
+
|
|
4133
|
+
dram_fwd_l1_eviction_write_duration = dram_kv_perf_stats[8]
|
|
4134
|
+
dram_fwd_l1_eviction_write_allocate_duration = dram_kv_perf_stats[9]
|
|
4135
|
+
dram_fwd_l1_eviction_write_cache_copy_duration = dram_kv_perf_stats[10]
|
|
4136
|
+
dram_fwd_l1_eviction_write_lookup_cache_duration = dram_kv_perf_stats[11]
|
|
4137
|
+
dram_fwd_l1_eviction_write_acquire_lock_duration = dram_kv_perf_stats[12]
|
|
4138
|
+
dram_fwd_l1_eviction_write_missing_load = dram_kv_perf_stats[13]
|
|
4139
|
+
|
|
4140
|
+
dram_bwd_l1_cnflct_miss_write_duration = dram_kv_perf_stats[14]
|
|
4141
|
+
dram_bwd_l1_cnflct_miss_write_allocate_duration = dram_kv_perf_stats[15]
|
|
4142
|
+
dram_bwd_l1_cnflct_miss_write_cache_copy_duration = dram_kv_perf_stats[16]
|
|
4143
|
+
dram_bwd_l1_cnflct_miss_write_lookup_cache_duration = dram_kv_perf_stats[17]
|
|
4144
|
+
dram_bwd_l1_cnflct_miss_write_acquire_lock_duration = dram_kv_perf_stats[18]
|
|
4145
|
+
dram_bwd_l1_cnflct_miss_write_missing_load = dram_kv_perf_stats[19]
|
|
4146
|
+
|
|
4147
|
+
dram_kv_allocated_bytes = dram_kv_perf_stats[20]
|
|
4148
|
+
dram_kv_actual_used_chunk_bytes = dram_kv_perf_stats[21]
|
|
4149
|
+
dram_kv_num_rows = dram_kv_perf_stats[22]
|
|
4150
|
+
dram_kv_read_counts = dram_kv_perf_stats[23]
|
|
4151
|
+
dram_metadata_write_sharding_total_duration = dram_kv_perf_stats[24]
|
|
4152
|
+
dram_metadata_write_total_duration = dram_kv_perf_stats[25]
|
|
4153
|
+
dram_metadata_write_allocate_avg_duration = dram_kv_perf_stats[26]
|
|
4154
|
+
dram_metadata_write_lookup_cache_avg_duration = dram_kv_perf_stats[27]
|
|
4155
|
+
dram_metadata_write_acquire_lock_avg_duration = dram_kv_perf_stats[28]
|
|
4156
|
+
dram_metadata_write_cache_miss_avg_count = dram_kv_perf_stats[29]
|
|
4157
|
+
|
|
4158
|
+
dram_read_metadata_total_duration = dram_kv_perf_stats[30]
|
|
4159
|
+
dram_read_metadata_sharding_total_duration = dram_kv_perf_stats[31]
|
|
4160
|
+
dram_read_metadata_cache_hit_copy_avg_duration = dram_kv_perf_stats[32]
|
|
4161
|
+
dram_read_metadata_lookup_cache_total_avg_duration = dram_kv_perf_stats[33]
|
|
4162
|
+
dram_read_metadata_acquire_lock_avg_duration = dram_kv_perf_stats[34]
|
|
4163
|
+
dram_read_read_metadata_load_size = dram_kv_perf_stats[35]
|
|
4164
|
+
|
|
4165
|
+
stats_reporter.report_duration(
|
|
4166
|
+
iteration_step=self.step,
|
|
4167
|
+
event_name="dram_kv.perf.get.dram_read_duration_us",
|
|
4168
|
+
duration_ms=dram_read_duration,
|
|
4169
|
+
enable_tb_metrics=True,
|
|
4170
|
+
time_unit="us",
|
|
4171
|
+
)
|
|
4172
|
+
stats_reporter.report_duration(
|
|
4173
|
+
iteration_step=self.step,
|
|
4174
|
+
event_name="dram_kv.perf.get.dram_read_sharding_duration_us",
|
|
4175
|
+
duration_ms=dram_read_sharding_duration,
|
|
4176
|
+
enable_tb_metrics=True,
|
|
4177
|
+
time_unit="us",
|
|
4178
|
+
)
|
|
4179
|
+
stats_reporter.report_duration(
|
|
4180
|
+
iteration_step=self.step,
|
|
4181
|
+
event_name="dram_kv.perf.get.dram_read_cache_hit_copy_duration_us",
|
|
4182
|
+
duration_ms=dram_read_cache_hit_copy_duration,
|
|
4183
|
+
enable_tb_metrics=True,
|
|
4184
|
+
time_unit="us",
|
|
4185
|
+
)
|
|
4186
|
+
stats_reporter.report_duration(
|
|
4187
|
+
iteration_step=self.step,
|
|
4188
|
+
event_name="dram_kv.perf.get.dram_read_fill_row_storage_duration_us",
|
|
4189
|
+
duration_ms=dram_read_fill_row_storage_duration,
|
|
4190
|
+
enable_tb_metrics=True,
|
|
4191
|
+
time_unit="us",
|
|
4192
|
+
)
|
|
4193
|
+
stats_reporter.report_duration(
|
|
4194
|
+
iteration_step=self.step,
|
|
4195
|
+
event_name="dram_kv.perf.get.dram_read_lookup_cache_duration_us",
|
|
4196
|
+
duration_ms=dram_read_lookup_cache_duration,
|
|
4197
|
+
enable_tb_metrics=True,
|
|
4198
|
+
time_unit="us",
|
|
4199
|
+
)
|
|
4200
|
+
stats_reporter.report_duration(
|
|
4201
|
+
iteration_step=self.step,
|
|
4202
|
+
event_name="dram_kv.perf.get.dram_read_acquire_lock_duration_us",
|
|
4203
|
+
duration_ms=dram_read_acquire_lock_duration,
|
|
4204
|
+
enable_tb_metrics=True,
|
|
4205
|
+
time_unit="us",
|
|
4206
|
+
)
|
|
4207
|
+
stats_reporter.report_data_amount(
|
|
4208
|
+
iteration_step=self.step,
|
|
4209
|
+
event_name="dram_kv.perf.get.dram_read_missing_load",
|
|
4210
|
+
enable_tb_metrics=True,
|
|
4211
|
+
data_bytes=dram_read_missing_load,
|
|
4212
|
+
)
|
|
4213
|
+
stats_reporter.report_duration(
|
|
4214
|
+
iteration_step=self.step,
|
|
4215
|
+
event_name="dram_kv.perf.set.dram_write_sharing_duration_us",
|
|
4216
|
+
duration_ms=dram_write_sharing_duration,
|
|
4217
|
+
enable_tb_metrics=True,
|
|
4218
|
+
time_unit="us",
|
|
4219
|
+
)
|
|
4220
|
+
|
|
4221
|
+
stats_reporter.report_duration(
|
|
4222
|
+
iteration_step=self.step,
|
|
4223
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_duration_us",
|
|
4224
|
+
duration_ms=dram_fwd_l1_eviction_write_duration,
|
|
4225
|
+
enable_tb_metrics=True,
|
|
4226
|
+
time_unit="us",
|
|
4227
|
+
)
|
|
4228
|
+
stats_reporter.report_duration(
|
|
4229
|
+
iteration_step=self.step,
|
|
4230
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_allocate_duration_us",
|
|
4231
|
+
duration_ms=dram_fwd_l1_eviction_write_allocate_duration,
|
|
4232
|
+
enable_tb_metrics=True,
|
|
4233
|
+
time_unit="us",
|
|
4234
|
+
)
|
|
4235
|
+
stats_reporter.report_duration(
|
|
4236
|
+
iteration_step=self.step,
|
|
4237
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_cache_copy_duration_us",
|
|
4238
|
+
duration_ms=dram_fwd_l1_eviction_write_cache_copy_duration,
|
|
4239
|
+
enable_tb_metrics=True,
|
|
4240
|
+
time_unit="us",
|
|
4241
|
+
)
|
|
4242
|
+
stats_reporter.report_duration(
|
|
4243
|
+
iteration_step=self.step,
|
|
4244
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_lookup_cache_duration_us",
|
|
4245
|
+
duration_ms=dram_fwd_l1_eviction_write_lookup_cache_duration,
|
|
4246
|
+
enable_tb_metrics=True,
|
|
4247
|
+
time_unit="us",
|
|
4248
|
+
)
|
|
4249
|
+
stats_reporter.report_duration(
|
|
4250
|
+
iteration_step=self.step,
|
|
4251
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_acquire_lock_duration_us",
|
|
4252
|
+
duration_ms=dram_fwd_l1_eviction_write_acquire_lock_duration,
|
|
4253
|
+
enable_tb_metrics=True,
|
|
4254
|
+
time_unit="us",
|
|
4255
|
+
)
|
|
4256
|
+
stats_reporter.report_data_amount(
|
|
4257
|
+
iteration_step=self.step,
|
|
4258
|
+
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load",
|
|
4259
|
+
data_bytes=dram_fwd_l1_eviction_write_missing_load,
|
|
4260
|
+
enable_tb_metrics=True,
|
|
4261
|
+
)
|
|
4262
|
+
|
|
4263
|
+
stats_reporter.report_duration(
|
|
4264
|
+
iteration_step=self.step,
|
|
4265
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_duration_us",
|
|
4266
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_duration,
|
|
4267
|
+
enable_tb_metrics=True,
|
|
4268
|
+
time_unit="us",
|
|
4269
|
+
)
|
|
4270
|
+
stats_reporter.report_duration(
|
|
4271
|
+
iteration_step=self.step,
|
|
4272
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_allocate_duration_us",
|
|
4273
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_allocate_duration,
|
|
4274
|
+
enable_tb_metrics=True,
|
|
4275
|
+
time_unit="us",
|
|
4276
|
+
)
|
|
4277
|
+
stats_reporter.report_duration(
|
|
4278
|
+
iteration_step=self.step,
|
|
4279
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_cache_copy_duration_us",
|
|
4280
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_cache_copy_duration,
|
|
4281
|
+
enable_tb_metrics=True,
|
|
4282
|
+
time_unit="us",
|
|
4283
|
+
)
|
|
4284
|
+
stats_reporter.report_duration(
|
|
4285
|
+
iteration_step=self.step,
|
|
4286
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_lookup_cache_duration_us",
|
|
4287
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_lookup_cache_duration,
|
|
4288
|
+
enable_tb_metrics=True,
|
|
4289
|
+
time_unit="us",
|
|
4290
|
+
)
|
|
4291
|
+
stats_reporter.report_duration(
|
|
4292
|
+
iteration_step=self.step,
|
|
4293
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_acquire_lock_duration_us",
|
|
4294
|
+
duration_ms=dram_bwd_l1_cnflct_miss_write_acquire_lock_duration,
|
|
4295
|
+
enable_tb_metrics=True,
|
|
4296
|
+
time_unit="us",
|
|
4297
|
+
)
|
|
4298
|
+
stats_reporter.report_data_amount(
|
|
4299
|
+
iteration_step=self.step,
|
|
4300
|
+
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load",
|
|
4301
|
+
data_bytes=dram_bwd_l1_cnflct_miss_write_missing_load,
|
|
4302
|
+
enable_tb_metrics=True,
|
|
4303
|
+
)
|
|
4304
|
+
|
|
4305
|
+
stats_reporter.report_data_amount(
|
|
4306
|
+
iteration_step=self.step,
|
|
4307
|
+
event_name="dram_kv.perf.get.dram_kv_read_counts",
|
|
4308
|
+
data_bytes=dram_kv_read_counts,
|
|
4309
|
+
enable_tb_metrics=True,
|
|
4310
|
+
)
|
|
4311
|
+
|
|
4312
|
+
stats_reporter.report_data_amount(
|
|
4313
|
+
iteration_step=self.step,
|
|
4314
|
+
event_name=self.dram_kv_allocated_bytes_stats_name,
|
|
4315
|
+
data_bytes=dram_kv_allocated_bytes,
|
|
4316
|
+
enable_tb_metrics=True,
|
|
4317
|
+
)
|
|
4318
|
+
stats_reporter.report_data_amount(
|
|
4319
|
+
iteration_step=self.step,
|
|
4320
|
+
event_name=self.dram_kv_actual_used_chunk_bytes_stats_name,
|
|
4321
|
+
data_bytes=dram_kv_actual_used_chunk_bytes,
|
|
4322
|
+
enable_tb_metrics=True,
|
|
4323
|
+
)
|
|
4324
|
+
stats_reporter.report_data_amount(
|
|
4325
|
+
iteration_step=self.step,
|
|
4326
|
+
event_name=self.dram_kv_mem_num_rows_stats_name,
|
|
4327
|
+
data_bytes=dram_kv_num_rows,
|
|
4328
|
+
enable_tb_metrics=True,
|
|
4329
|
+
)
|
|
4330
|
+
stats_reporter.report_duration(
|
|
4331
|
+
iteration_step=self.step,
|
|
4332
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_sharding_total_duration_us",
|
|
4333
|
+
duration_ms=dram_metadata_write_sharding_total_duration,
|
|
4334
|
+
enable_tb_metrics=True,
|
|
4335
|
+
time_unit="us",
|
|
4336
|
+
)
|
|
4337
|
+
stats_reporter.report_duration(
|
|
4338
|
+
iteration_step=self.step,
|
|
4339
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_total_duration_us",
|
|
4340
|
+
duration_ms=dram_metadata_write_total_duration,
|
|
4341
|
+
enable_tb_metrics=True,
|
|
4342
|
+
time_unit="us",
|
|
4343
|
+
)
|
|
4344
|
+
stats_reporter.report_duration(
|
|
4345
|
+
iteration_step=self.step,
|
|
4346
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_allocate_avg_duration_us",
|
|
4347
|
+
duration_ms=dram_metadata_write_allocate_avg_duration,
|
|
4348
|
+
enable_tb_metrics=True,
|
|
4349
|
+
time_unit="us",
|
|
4350
|
+
)
|
|
4351
|
+
stats_reporter.report_duration(
|
|
4352
|
+
iteration_step=self.step,
|
|
4353
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_lookup_cache_avg_duration_us",
|
|
4354
|
+
duration_ms=dram_metadata_write_lookup_cache_avg_duration,
|
|
4355
|
+
enable_tb_metrics=True,
|
|
4356
|
+
time_unit="us",
|
|
4357
|
+
)
|
|
4358
|
+
stats_reporter.report_duration(
|
|
4359
|
+
iteration_step=self.step,
|
|
4360
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_acquire_lock_avg_duration_us",
|
|
4361
|
+
duration_ms=dram_metadata_write_acquire_lock_avg_duration,
|
|
4362
|
+
enable_tb_metrics=True,
|
|
4363
|
+
time_unit="us",
|
|
4364
|
+
)
|
|
4365
|
+
stats_reporter.report_data_amount(
|
|
4366
|
+
iteration_step=self.step,
|
|
4367
|
+
event_name="dram_kv.perf.set.dram_eviction_score_write_cache_miss_avg_count",
|
|
4368
|
+
data_bytes=dram_metadata_write_cache_miss_avg_count,
|
|
4369
|
+
enable_tb_metrics=True,
|
|
4370
|
+
)
|
|
4371
|
+
stats_reporter.report_duration(
|
|
4372
|
+
iteration_step=self.step,
|
|
4373
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_total_duration_us",
|
|
4374
|
+
duration_ms=dram_read_metadata_total_duration,
|
|
4375
|
+
enable_tb_metrics=True,
|
|
4376
|
+
time_unit="us",
|
|
4377
|
+
)
|
|
4378
|
+
stats_reporter.report_duration(
|
|
4379
|
+
iteration_step=self.step,
|
|
4380
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_sharding_total_duration_us",
|
|
4381
|
+
duration_ms=dram_read_metadata_sharding_total_duration,
|
|
4382
|
+
enable_tb_metrics=True,
|
|
4383
|
+
time_unit="us",
|
|
4384
|
+
)
|
|
4385
|
+
stats_reporter.report_duration(
|
|
4386
|
+
iteration_step=self.step,
|
|
4387
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_cache_hit_copy_avg_duration_us",
|
|
4388
|
+
duration_ms=dram_read_metadata_cache_hit_copy_avg_duration,
|
|
4389
|
+
enable_tb_metrics=True,
|
|
4390
|
+
time_unit="us",
|
|
4391
|
+
)
|
|
4392
|
+
stats_reporter.report_duration(
|
|
4393
|
+
iteration_step=self.step,
|
|
4394
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_lookup_cache_total_avg_duration_us",
|
|
4395
|
+
duration_ms=dram_read_metadata_lookup_cache_total_avg_duration,
|
|
4396
|
+
enable_tb_metrics=True,
|
|
4397
|
+
time_unit="us",
|
|
4398
|
+
)
|
|
4399
|
+
stats_reporter.report_duration(
|
|
4400
|
+
iteration_step=self.step,
|
|
4401
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_acquire_lock_avg_duration_us",
|
|
4402
|
+
duration_ms=dram_read_metadata_acquire_lock_avg_duration,
|
|
4403
|
+
enable_tb_metrics=True,
|
|
4404
|
+
time_unit="us",
|
|
4405
|
+
)
|
|
4406
|
+
stats_reporter.report_data_amount(
|
|
4407
|
+
iteration_step=self.step,
|
|
4408
|
+
event_name="dram_kv.perf.get.dram_eviction_score_read_load_size",
|
|
4409
|
+
data_bytes=dram_read_read_metadata_load_size,
|
|
4410
|
+
enable_tb_metrics=True,
|
|
4411
|
+
)
|
|
4412
|
+
|
|
4413
|
+
def _recording_to_timer(
|
|
4414
|
+
self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
|
|
4415
|
+
) -> Any:
|
|
4416
|
+
"""
|
|
4417
|
+
helper function to call AsyncSeriesTimer, wrap it inside the kernels we want to record
|
|
4418
|
+
"""
|
|
4419
|
+
if self.stats_reporter is not None and self.stats_reporter.should_report(
|
|
4420
|
+
self.step
|
|
4421
|
+
):
|
|
4422
|
+
assert (
|
|
4423
|
+
timer
|
|
4424
|
+
), "We shouldn't be here, async timer must have been initiated if reporter is present."
|
|
4425
|
+
return timer.recording(**kwargs)
|
|
4426
|
+
# No-Op context manager
|
|
4427
|
+
return contextlib.nullcontext()
|
|
4428
|
+
|
|
4429
|
+
def fetch_from_l1_sp_w_row_ids(
|
|
4430
|
+
self, row_ids: torch.Tensor, only_get_optimizer_states: bool = False
|
|
4431
|
+
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
|
4432
|
+
"""
|
|
4433
|
+
Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
|
|
4434
|
+
@return: updated_weights/optimizer_states, mask of which rows are filled
|
|
4435
|
+
"""
|
|
4436
|
+
if not self.enable_optimizer_offloading and only_get_optimizer_states:
|
|
4437
|
+
raise RuntimeError(
|
|
4438
|
+
"Optimizer states are not offloaded, while only_get_optimizer_states is True"
|
|
4439
|
+
)
|
|
4440
|
+
|
|
4441
|
+
# NOTE: Remove this once there is support for fetching multiple
|
|
4442
|
+
# optimizer states in fetch_from_l1_sp_w_row_ids
|
|
4443
|
+
if only_get_optimizer_states and self.optimizer not in [
|
|
4444
|
+
OptimType.EXACT_ROWWISE_ADAGRAD,
|
|
4445
|
+
OptimType.PARTIAL_ROWWISE_ADAM,
|
|
4446
|
+
]:
|
|
4447
|
+
raise RuntimeError(
|
|
4448
|
+
f"Fetching optimizer states using fetch_from_l1_sp_w_row_ids() is not yet supported for {self.optimizer}"
|
|
4449
|
+
)
|
|
4450
|
+
|
|
4451
|
+
def split_results_by_opt_states(
|
|
4452
|
+
updated_weights: torch.Tensor, cache_location_mask: torch.Tensor
|
|
4453
|
+
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
|
4454
|
+
if not only_get_optimizer_states:
|
|
4455
|
+
return [updated_weights], cache_location_mask
|
|
4456
|
+
# TODO: support mixed dimension case
|
|
4457
|
+
# currently only supports tables with the same max_D dimension
|
|
4458
|
+
opt_to_dim = self.optimizer.byte_offsets_along_row(
|
|
4459
|
+
self.max_D, self.weights_precision, self.optimizer_state_dtypes
|
|
4460
|
+
)
|
|
4461
|
+
updated_opt_states = []
|
|
4462
|
+
for opt_name, dim in opt_to_dim.items():
|
|
4463
|
+
opt_dtype = self.optimizer._extract_dtype(
|
|
4464
|
+
self.optimizer_state_dtypes, opt_name
|
|
4465
|
+
)
|
|
4466
|
+
updated_opt_states.append(
|
|
4467
|
+
updated_weights.view(dtype=torch.uint8)[:, dim[0] : dim[1]].view(
|
|
4468
|
+
dtype=opt_dtype
|
|
4469
|
+
)
|
|
4470
|
+
)
|
|
4471
|
+
return updated_opt_states, cache_location_mask
|
|
4472
|
+
|
|
4473
|
+
with torch.no_grad():
|
|
4474
|
+
weights_dtype = self.weights_precision.as_dtype()
|
|
4475
|
+
step = self.step
|
|
4476
|
+
with record_function(f"## fetch_from_l1_{step}_{self.tbe_unique_id} ##"):
|
|
4477
|
+
lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
|
|
4478
|
+
row_ids,
|
|
4479
|
+
self.lxu_cache_state,
|
|
4480
|
+
self.total_hash_size,
|
|
4481
|
+
)
|
|
4482
|
+
updated_weights = torch.empty(
|
|
4483
|
+
row_ids.numel(),
|
|
4484
|
+
self.cache_row_dim,
|
|
4485
|
+
device=self.current_device,
|
|
4486
|
+
dtype=weights_dtype,
|
|
4487
|
+
)
|
|
4488
|
+
|
|
4489
|
+
# D2D copy cache
|
|
4490
|
+
cache_location_mask = lxu_cache_locations >= 0
|
|
4491
|
+
torch.ops.fbgemm.masked_index_select(
|
|
4492
|
+
updated_weights,
|
|
4493
|
+
lxu_cache_locations,
|
|
4494
|
+
self.lxu_cache_weights,
|
|
4495
|
+
torch.tensor(
|
|
4496
|
+
[row_ids.numel()],
|
|
4497
|
+
device=self.current_device,
|
|
4498
|
+
dtype=torch.int32,
|
|
4499
|
+
),
|
|
4500
|
+
)
|
|
4501
|
+
|
|
4502
|
+
with record_function(f"## fetch_from_sp_{step}_{self.tbe_unique_id} ##"):
|
|
4503
|
+
if len(self.ssd_scratch_pad_eviction_data) > 0:
|
|
4504
|
+
sp = self.ssd_scratch_pad_eviction_data[0][0]
|
|
4505
|
+
sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
|
|
4506
|
+
self.current_device
|
|
4507
|
+
)
|
|
4508
|
+
actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
|
|
4509
|
+
if actions_count_gpu.item() == 0:
|
|
4510
|
+
# no action to take
|
|
4511
|
+
return split_results_by_opt_states(
|
|
4512
|
+
updated_weights, cache_location_mask
|
|
4513
|
+
)
|
|
4514
|
+
|
|
4515
|
+
sp_idx = sp_idx[:actions_count_gpu]
|
|
4516
|
+
|
|
4517
|
+
# -1 in lxu_cache_locations means the row is not in L1 cache and in SP
|
|
4518
|
+
# fill the row_ids in L1 with -2, >0 values means in SP
|
|
4519
|
+
# @eg. updated_row_ids_in_sp= [1, 100, 1, 2, -2, 3, 4, 5, 10]
|
|
4520
|
+
updated_row_ids_in_sp = row_ids.masked_fill(
|
|
4521
|
+
lxu_cache_locations != -1, -2
|
|
4522
|
+
)
|
|
4523
|
+
# sort the sp_idx for binary search
|
|
4524
|
+
# should already be sorted
|
|
4525
|
+
# sp_idx_inverse_indices is the indices before sorting which is same to the location in SP.
|
|
4526
|
+
# @eg. sp_idx = [4, 2, 1, 3, 10]
|
|
4527
|
+
# @eg sorted_sp_idx = [ 1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
|
|
4528
|
+
sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
|
|
4529
|
+
# search rows id in sp against the SP indexes to find location of the rows in SP
|
|
4530
|
+
# @eg: updated_ids_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
|
|
4531
|
+
# @eg: 5 is OOB
|
|
4532
|
+
updated_ids_in_sp_idx = torch.searchsorted(
|
|
4533
|
+
sorted_sp_idx, updated_row_ids_in_sp
|
|
4534
|
+
)
|
|
4535
|
+
# does not found in SP will Out of Bound
|
|
4536
|
+
oob_sp_idx = updated_ids_in_sp_idx >= sp_idx.numel()
|
|
4537
|
+
# make the oob items in bound
|
|
4538
|
+
# @eg updated_ids_in_sp_idx=[0, 0, 0, 1, 0, 2, 3, 4, 4]
|
|
4539
|
+
updated_ids_in_sp_idx[oob_sp_idx] = 0
|
|
4540
|
+
|
|
4541
|
+
# -1s locations will be filtered out in masked_index_select
|
|
4542
|
+
sp_locations_in_updated_weights = torch.full_like(
|
|
4543
|
+
updated_row_ids_in_sp, -1
|
|
4544
|
+
)
|
|
4545
|
+
# torch.searchsorted is not exact match,
|
|
4546
|
+
# we only take exact matched rows, where the id is found in SP.
|
|
4547
|
+
# @eg 5 in updated_row_ids_in_sp is not in sp_idx, but has 4 in updated_ids_in_sp_idx
|
|
4548
|
+
# @eg sorted_sp_idx[updated_ids_in_sp_idx]=[ 1, 1, 1, 2, 1, 3, 4, 10, 10]
|
|
4549
|
+
# @eg exact_match_mask=[ True, False, True, True, False, True, True, False, True]
|
|
4550
|
+
exact_match_mask = (
|
|
4551
|
+
sorted_sp_idx[updated_ids_in_sp_idx] == updated_row_ids_in_sp
|
|
4552
|
+
)
|
|
4553
|
+
# Get the location of the row ids found in SP.
|
|
4554
|
+
# @eg: sp_locations_found=[2, 2, 1, 3, 0, 4]
|
|
4555
|
+
sp_locations_found = sp_idx_inverse_indices[
|
|
4556
|
+
updated_ids_in_sp_idx[exact_match_mask]
|
|
4557
|
+
]
|
|
4558
|
+
# @eg: sp_locations_in_updated_weights=[ 2, -1, 2, 1, -1, 3, 0, -1, 4]
|
|
4559
|
+
sp_locations_in_updated_weights[exact_match_mask] = (
|
|
4560
|
+
sp_locations_found
|
|
4561
|
+
)
|
|
4562
|
+
|
|
4563
|
+
# D2D copy SP
|
|
4564
|
+
torch.ops.fbgemm.masked_index_select(
|
|
4565
|
+
updated_weights,
|
|
4566
|
+
sp_locations_in_updated_weights,
|
|
4567
|
+
sp,
|
|
4568
|
+
torch.tensor(
|
|
4569
|
+
[row_ids.numel()],
|
|
4570
|
+
device=self.current_device,
|
|
4571
|
+
dtype=torch.int32,
|
|
4572
|
+
),
|
|
4573
|
+
)
|
|
4574
|
+
# cache_location_mask is the mask of rows in L1
|
|
4575
|
+
# exact_match_mask is the mask of rows in SP
|
|
4576
|
+
cache_location_mask = torch.logical_or(
|
|
4577
|
+
cache_location_mask, exact_match_mask
|
|
4578
|
+
)
|
|
4579
|
+
|
|
4580
|
+
return split_results_by_opt_states(updated_weights, cache_location_mask)
|
|
4581
|
+
|
|
4582
|
+
def register_backward_hook_before_eviction(
|
|
4583
|
+
self, backward_hook: Callable[[torch.Tensor], None]
|
|
4584
|
+
) -> None:
|
|
4585
|
+
"""
|
|
4586
|
+
Register a backward hook to the TBE module.
|
|
4587
|
+
And make sure this is called before the sp eviction hook.
|
|
4588
|
+
"""
|
|
4589
|
+
# make sure this hook is the first one to be executed
|
|
4590
|
+
hooks = []
|
|
4591
|
+
backward_hooks = self.placeholder_autograd_tensor._backward_hooks
|
|
4592
|
+
if backward_hooks is not None:
|
|
4593
|
+
for _handle_id, hook in backward_hooks.items():
|
|
4594
|
+
hooks.append(hook)
|
|
4595
|
+
backward_hooks.clear()
|
|
4596
|
+
|
|
4597
|
+
self.placeholder_autograd_tensor.register_hook(backward_hook)
|
|
4598
|
+
for hook in hooks:
|
|
4599
|
+
self.placeholder_autograd_tensor.register_hook(hook)
|
|
4600
|
+
|
|
4601
|
+
def set_local_weight_counts_for_table(
|
|
4602
|
+
self, table_idx: int, weight_count: int
|
|
4603
|
+
) -> None:
|
|
4604
|
+
self.local_weight_counts[table_idx] = weight_count
|
|
4605
|
+
|
|
4606
|
+
def set_global_id_per_rank_for_table(
|
|
4607
|
+
self, table_idx: int, global_id: torch.Tensor
|
|
4608
|
+
) -> None:
|
|
4609
|
+
self.global_id_per_rank[table_idx] = global_id
|
|
4610
|
+
|
|
4611
|
+
def direct_write_embedding(
|
|
4612
|
+
self,
|
|
4613
|
+
indices: torch.Tensor,
|
|
4614
|
+
offsets: torch.Tensor,
|
|
4615
|
+
weights: torch.Tensor,
|
|
4616
|
+
) -> None:
|
|
4617
|
+
"""
|
|
4618
|
+
Directly write the weights to L1, SP and backend without relying on auto-gradient for embedding cache.
|
|
4619
|
+
Please refer to design doc for more details: https://docs.google.com/document/d/1TJHKvO1m3-5tYAKZGhacXnGk7iCNAzz7wQlrFbX_LDI/edit?tab=t.0
|
|
4620
|
+
"""
|
|
4621
|
+
assert (
|
|
4622
|
+
self._embedding_cache_mode
|
|
4623
|
+
), "Must be in embedding_cache_mode to support direct_write_embedding method."
|
|
4624
|
+
|
|
4625
|
+
B_offsets = None
|
|
4626
|
+
max_B = -1
|
|
4627
|
+
|
|
4628
|
+
with torch.no_grad():
|
|
4629
|
+
# Wait for any ongoing prefetch operations to complete before starting direct_write
|
|
4630
|
+
current_stream = torch.cuda.current_stream()
|
|
4631
|
+
current_stream.wait_event(self.prefetch_complete_event)
|
|
4632
|
+
|
|
4633
|
+
# Create local step events for internal sequential execution
|
|
4634
|
+
weights_dtype = self.weights_precision.as_dtype()
|
|
4635
|
+
assert (
|
|
4636
|
+
weights_dtype == weights.dtype
|
|
4637
|
+
), f"Expected embedding table dtype {weights_dtype} is same with input weight dtype, but got {weights.dtype}"
|
|
4638
|
+
|
|
4639
|
+
# Pad the weights to match self.max_D width if necessary
|
|
4640
|
+
if weights.size(1) < self.cache_row_dim:
|
|
4641
|
+
weights = torch.nn.functional.pad(
|
|
4642
|
+
weights, (0, self.cache_row_dim - weights.size(1))
|
|
4643
|
+
)
|
|
4644
|
+
|
|
4645
|
+
step = self.step
|
|
4646
|
+
|
|
4647
|
+
# step 0: run backward hook for prefetch if prefetch pipeline is enabled before writing to L1 and SP
|
|
4648
|
+
if self.prefetch_pipeline:
|
|
4649
|
+
self._update_cache_counter_and_pointers(nn.Module(), torch.empty(0))
|
|
4650
|
+
|
|
4651
|
+
# step 1: lookup and write to l1 cache
|
|
4652
|
+
with record_function(
|
|
4653
|
+
f"## direct_write_to_l1_{step}_{self.tbe_unique_id} ##"
|
|
4654
|
+
):
|
|
4655
|
+
if self.gather_ssd_cache_stats:
|
|
4656
|
+
self.local_ssd_cache_stats.zero_()
|
|
4657
|
+
|
|
4658
|
+
# Linearize indices
|
|
4659
|
+
linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
|
|
4660
|
+
self.hash_size_cumsum,
|
|
4661
|
+
indices,
|
|
4662
|
+
offsets,
|
|
4663
|
+
B_offsets,
|
|
4664
|
+
max_B,
|
|
4665
|
+
)
|
|
4666
|
+
|
|
4667
|
+
lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
|
|
4668
|
+
linear_cache_indices,
|
|
4669
|
+
self.lxu_cache_state,
|
|
4670
|
+
self.total_hash_size,
|
|
4671
|
+
)
|
|
4672
|
+
cache_location_mask = lxu_cache_locations >= 0
|
|
4673
|
+
|
|
4674
|
+
# Get the cache locations for the row_ids that are already in the cache
|
|
4675
|
+
cache_locations = lxu_cache_locations[cache_location_mask]
|
|
4676
|
+
|
|
4677
|
+
# Get the corresponding input weights for these row_ids
|
|
4678
|
+
cache_weights = weights[cache_location_mask]
|
|
4679
|
+
|
|
4680
|
+
# Update the cache with these input weights
|
|
4681
|
+
if cache_locations.numel() > 0:
|
|
4682
|
+
self.lxu_cache_weights.index_put_(
|
|
4683
|
+
(cache_locations,), cache_weights, accumulate=False
|
|
4684
|
+
)
|
|
4685
|
+
|
|
4686
|
+
# Record completion of step 1
|
|
4687
|
+
current_stream.record_event(self.direct_write_l1_complete_event)
|
|
4688
|
+
|
|
4689
|
+
# step 2: pop the current scratch pad and write to next batch scratch pad if exists
|
|
4690
|
+
# Wait for step 1 to complete
|
|
4691
|
+
with record_function(
|
|
4692
|
+
f"## direct_write_to_sp_{step}_{self.tbe_unique_id} ##"
|
|
4693
|
+
):
|
|
4694
|
+
if len(self.ssd_scratch_pad_eviction_data) > 0:
|
|
4695
|
+
self.ssd_scratch_pad_eviction_data.pop(0)
|
|
4696
|
+
if len(self.ssd_scratch_pad_eviction_data) > 0:
|
|
4697
|
+
# Wait for any pending backend reads to the next scratch pad
|
|
4698
|
+
# to complete before we write to it. Otherwise, stale backend data
|
|
4699
|
+
# will overwrite our direct_write updates.
|
|
4700
|
+
# The ssd_event_get marks completion of backend fetch operations.
|
|
4701
|
+
current_stream.wait_event(self.ssd_event_get)
|
|
4702
|
+
|
|
4703
|
+
# if scratch pad exists, write to next batch scratch pad
|
|
4704
|
+
sp = self.ssd_scratch_pad_eviction_data[0][0]
|
|
4705
|
+
sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
|
|
4706
|
+
self.current_device
|
|
4707
|
+
)
|
|
4708
|
+
actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
|
|
4709
|
+
if actions_count_gpu.item() != 0:
|
|
4710
|
+
# when no actional_count_gpu, no need to write to SP
|
|
4711
|
+
sp_idx = sp_idx[:actions_count_gpu]
|
|
4712
|
+
|
|
4713
|
+
# -1 in lxu_cache_locations means the row is not in L1 cache and in SP
|
|
4714
|
+
# fill the row_ids in L1 with -2, >0 values means in SP or backend
|
|
4715
|
+
# @eg. updated_indices_in_sp= [1, 100, 1, 2, -2, 3, 4, 5, 10]
|
|
4716
|
+
updated_indices_in_sp = linear_cache_indices.masked_fill(
|
|
4717
|
+
lxu_cache_locations != -1, -2
|
|
4718
|
+
)
|
|
4719
|
+
# sort the sp_idx for binary search
|
|
4720
|
+
# should already be sorted
|
|
4721
|
+
# sp_idx_inverse_indices is the indices before sorting which is same to the location in SP.
|
|
4722
|
+
# @eg. sp_idx = [4, 2, 1, 3, 10]
|
|
4723
|
+
# @eg sorted_sp_idx = [ 1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
|
|
4724
|
+
sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
|
|
4725
|
+
# search rows id in sp against the SP indexes to find location of the rows in SP
|
|
4726
|
+
# @eg: updated_indices_in_sp = [0, 5, 0, 1, 0, 2, 3, 4, 4]
|
|
4727
|
+
# @eg: 5 is OOB
|
|
4728
|
+
updated_indices_in_sp_idx = torch.searchsorted(
|
|
4729
|
+
sorted_sp_idx, updated_indices_in_sp
|
|
4730
|
+
)
|
|
4731
|
+
# does not found in SP will Out of Bound
|
|
4732
|
+
oob_sp_idx = updated_indices_in_sp_idx >= sp_idx.numel()
|
|
4733
|
+
# make the oob items in bound
|
|
4734
|
+
# @eg updated_indices_in_sp=[0, 0, 0, 1, 0, 2, 3, 4, 4]
|
|
4735
|
+
updated_indices_in_sp_idx[oob_sp_idx] = 0
|
|
4736
|
+
|
|
4737
|
+
# torch.searchsorted is not exact match,
|
|
4738
|
+
# we only take exact matched rows, where the id is found in SP.
|
|
4739
|
+
# @eg 5 in updated_indices_in_sp is not in sp_idx, but has 4 in updated_indices_in_sp
|
|
4740
|
+
# @eg sorted_sp_idx[updated_indices_in_sp]=[ 1, 1, 1, 2, 1, 3, 4, 10, 10]
|
|
4741
|
+
# @eg exact_match_mask=[ True, False, True, True, False, True, True, False, True]
|
|
4742
|
+
exact_match_mask = (
|
|
4743
|
+
sorted_sp_idx[updated_indices_in_sp_idx]
|
|
4744
|
+
== updated_indices_in_sp
|
|
4745
|
+
)
|
|
4746
|
+
# Get the location of the row ids found in SP.
|
|
4747
|
+
# @eg: sp_locations_found=[2, 2, 1, 3, 0, 4]
|
|
4748
|
+
sp_locations_found = sp_idx_inverse_indices[
|
|
4749
|
+
updated_indices_in_sp[exact_match_mask]
|
|
4750
|
+
]
|
|
4751
|
+
# Get the corresponding weights for the matched indices
|
|
4752
|
+
matched_weights = weights[exact_match_mask]
|
|
4753
|
+
|
|
4754
|
+
# Write the weights to the sparse tensor at the found locations
|
|
4755
|
+
if sp_locations_found.numel() > 0:
|
|
4756
|
+
sp.index_put_(
|
|
4757
|
+
(sp_locations_found,),
|
|
4758
|
+
matched_weights,
|
|
4759
|
+
accumulate=False,
|
|
4760
|
+
)
|
|
4761
|
+
current_stream.record_event(self.direct_write_sp_complete_event)
|
|
4762
|
+
|
|
4763
|
+
# step 3: write l1 cache missing rows to backend
|
|
4764
|
+
# Wait for step 2 to complete
|
|
4765
|
+
with record_function(
|
|
4766
|
+
f"## direct_write_to_backend_{step}_{self.tbe_unique_id} ##"
|
|
4767
|
+
):
|
|
4768
|
+
# Use the existing ssd_eviction_stream for all backend write operations
|
|
4769
|
+
# This stream is already created with low priority during initialization
|
|
4770
|
+
with torch.cuda.stream(self.ssd_eviction_stream):
|
|
4771
|
+
# Create a mask for indices not in L1 cache
|
|
4772
|
+
non_cache_mask = ~cache_location_mask
|
|
4773
|
+
|
|
4774
|
+
# Calculate the count of valid indices (those not in L1 cache)
|
|
4775
|
+
valid_count = non_cache_mask.sum().to(torch.int64).cpu()
|
|
4776
|
+
|
|
4777
|
+
if valid_count.item() > 0:
|
|
4778
|
+
# Extract only the indices and weights that are not in L1 cache
|
|
4779
|
+
non_cache_indices = linear_cache_indices[non_cache_mask]
|
|
4780
|
+
non_cache_weights = weights[non_cache_mask]
|
|
4781
|
+
|
|
4782
|
+
# Move tensors to CPU for set_cuda
|
|
4783
|
+
cpu_indices = non_cache_indices.cpu()
|
|
4784
|
+
cpu_weights = non_cache_weights.cpu()
|
|
4785
|
+
|
|
4786
|
+
# Write to backend - only sending the non-cache indices and weights
|
|
4787
|
+
self.record_function_via_dummy_profile(
|
|
4788
|
+
f"## ssd_write_{step}_set_cuda_{self.tbe_unique_id} ##",
|
|
4789
|
+
self.ssd_db.set_cuda,
|
|
4790
|
+
cpu_indices,
|
|
4791
|
+
cpu_weights,
|
|
4792
|
+
valid_count,
|
|
4793
|
+
self.timestep,
|
|
4794
|
+
is_bwd=False,
|
|
4795
|
+
)
|
|
4796
|
+
|
|
4797
|
+
# Return control to the main stream without waiting for the backend operation to complete
|
|
4798
|
+
|
|
4799
|
+
def get_free_cpu_memory_gb(self) -> float:
|
|
4800
|
+
def _get_mem_available() -> float:
|
|
4801
|
+
if sys.platform.startswith("linux"):
|
|
4802
|
+
info = {}
|
|
4803
|
+
with open("/proc/meminfo") as f:
|
|
4804
|
+
for line in f:
|
|
4805
|
+
p = line.split()
|
|
4806
|
+
info[p[0].strip(":").lower()] = int(p[1]) * 1024
|
|
4807
|
+
if "memavailable" in info:
|
|
4808
|
+
# Linux >= 3.14
|
|
4809
|
+
return info["memavailable"]
|
|
4810
|
+
else:
|
|
4811
|
+
return info["memfree"] + info["cached"]
|
|
4812
|
+
else:
|
|
4813
|
+
raise RuntimeError(
|
|
4814
|
+
"Unsupported platform for free memory eviction, pls use ID count eviction tirgger mode"
|
|
4815
|
+
)
|
|
4816
|
+
|
|
4817
|
+
mem = _get_mem_available()
|
|
4818
|
+
return mem / (1024**3)
|
|
4819
|
+
|
|
4820
|
+
@classmethod
|
|
4821
|
+
def trigger_evict_in_all_tbes(cls) -> None:
|
|
4822
|
+
for tbe in cls._all_tbe_instances:
|
|
4823
|
+
tbe.ssd_db.trigger_feature_evict()
|
|
4824
|
+
|
|
4825
|
+
@classmethod
|
|
4826
|
+
def tbe_has_ongoing_eviction(cls) -> bool:
|
|
4827
|
+
for tbe in cls._all_tbe_instances:
|
|
4828
|
+
if tbe.ssd_db.is_evicting():
|
|
4829
|
+
return True
|
|
4830
|
+
return False
|
|
4831
|
+
|
|
4832
|
+
def set_free_mem_eviction_trigger_config(
|
|
4833
|
+
self, eviction_policy: EvictionPolicy
|
|
4834
|
+
) -> None:
|
|
4835
|
+
self.enable_free_mem_trigger_eviction = True
|
|
4836
|
+
self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode
|
|
4837
|
+
assert (
|
|
4838
|
+
eviction_policy.eviction_free_mem_check_interval_batch is not None
|
|
4839
|
+
), "eviction_free_mem_check_interval_batch is unexpected none for free_mem eviction trigger mode"
|
|
4840
|
+
self.eviction_free_mem_check_interval_batch: int = (
|
|
4841
|
+
eviction_policy.eviction_free_mem_check_interval_batch
|
|
4842
|
+
)
|
|
4843
|
+
assert (
|
|
4844
|
+
eviction_policy.eviction_free_mem_threshold_gb is not None
|
|
4845
|
+
), "eviction_policy.eviction_free_mem_threshold_gb is unexpected none for free_mem eviction trigger mode"
|
|
4846
|
+
self.eviction_free_mem_threshold_gb: int = (
|
|
4847
|
+
eviction_policy.eviction_free_mem_threshold_gb
|
|
4848
|
+
)
|
|
4849
|
+
logging.info(
|
|
4850
|
+
f"[FREE_MEM Eviction] eviction config, trigger model: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}"
|
|
4851
|
+
)
|
|
4852
|
+
|
|
4853
|
+
def may_trigger_eviction(self) -> None:
|
|
4854
|
+
def is_first_tbe() -> bool:
|
|
4855
|
+
first = SSDTableBatchedEmbeddingBags._first_instance_ref
|
|
4856
|
+
return first is not None and first() is self
|
|
4857
|
+
|
|
4858
|
+
# We assume that the eviction time is less than free mem check interval time
|
|
4859
|
+
# So every time we reach this check, all evictions in all tbes should be finished.
|
|
4860
|
+
# We only need to check the first tbe because all tbes share the same free mem,
|
|
4861
|
+
# once the first tbe detect need to trigger eviction, it will call trigger func
|
|
4862
|
+
# in all tbes from _all_tbe_instances
|
|
4863
|
+
if (
|
|
4864
|
+
self.enable_free_mem_trigger_eviction
|
|
4865
|
+
and self.step % self.eviction_free_mem_check_interval_batch == 0
|
|
4866
|
+
and self.training
|
|
4867
|
+
and is_first_tbe()
|
|
4868
|
+
):
|
|
4869
|
+
if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction():
|
|
4870
|
+
SSDTableBatchedEmbeddingBags._eviction_triggered = False
|
|
4871
|
+
|
|
4872
|
+
free_cpu_mem_gb = self.get_free_cpu_memory_gb()
|
|
4873
|
+
local_evict_trigger = int(
|
|
4874
|
+
free_cpu_mem_gb < self.eviction_free_mem_threshold_gb
|
|
4875
|
+
)
|
|
4876
|
+
tensor_flag = torch.tensor(
|
|
4877
|
+
local_evict_trigger,
|
|
4878
|
+
device=self.current_device,
|
|
4879
|
+
dtype=torch.int,
|
|
4880
|
+
)
|
|
4881
|
+
world_size = dist.get_world_size(self._pg)
|
|
4882
|
+
if world_size > 1:
|
|
4883
|
+
dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg)
|
|
4884
|
+
global_evict_trigger = tensor_flag.item()
|
|
4885
|
+
else:
|
|
4886
|
+
global_evict_trigger = local_evict_trigger
|
|
4887
|
+
if (
|
|
4888
|
+
global_evict_trigger >= 1
|
|
4889
|
+
and SSDTableBatchedEmbeddingBags._eviction_triggered
|
|
4890
|
+
):
|
|
4891
|
+
logging.warning(
|
|
4892
|
+
f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true"
|
|
4893
|
+
)
|
|
4894
|
+
if (
|
|
4895
|
+
global_evict_trigger >= 1
|
|
4896
|
+
and not SSDTableBatchedEmbeddingBags._eviction_triggered
|
|
4897
|
+
):
|
|
4898
|
+
SSDTableBatchedEmbeddingBags._eviction_triggered = True
|
|
4899
|
+
SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes()
|
|
4900
|
+
logging.info(
|
|
4901
|
+
f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction"
|
|
4902
|
+
)
|
|
4903
|
+
|
|
4904
|
+
def reset_inference_mode(self) -> None:
|
|
4905
|
+
"""
|
|
4906
|
+
Reset the inference mode
|
|
4907
|
+
"""
|
|
4908
|
+
self.eval()
|