fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,4600 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
# pyre-ignore-all-errors[56]
|
|
10
|
+
|
|
11
|
+
import contextlib
|
|
12
|
+
import enum
|
|
13
|
+
import functools
|
|
14
|
+
import logging
|
|
15
|
+
import math
|
|
16
|
+
import os
|
|
17
|
+
import uuid
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from itertools import accumulate
|
|
20
|
+
from math import log2
|
|
21
|
+
from typing import Any, Callable, Optional, Union
|
|
22
|
+
|
|
23
|
+
import torch # usort:skip
|
|
24
|
+
from torch import nn, Tensor # usort:skip
|
|
25
|
+
from torch.autograd.profiler import record_function # usort:skip
|
|
26
|
+
|
|
27
|
+
# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
|
|
28
|
+
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
|
|
29
|
+
|
|
30
|
+
from fbgemm_gpu.config import FeatureGate, FeatureGateName
|
|
31
|
+
from fbgemm_gpu.runtime_monitor import (
|
|
32
|
+
AsyncSeriesTimer,
|
|
33
|
+
TBEStatsReporter,
|
|
34
|
+
TBEStatsReporterConfig,
|
|
35
|
+
)
|
|
36
|
+
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
|
|
37
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
|
38
|
+
BoundsCheckMode,
|
|
39
|
+
CacheAlgorithm,
|
|
40
|
+
CacheState,
|
|
41
|
+
ComputeDevice,
|
|
42
|
+
construct_cache_state,
|
|
43
|
+
EmbeddingLocation,
|
|
44
|
+
get_bounds_check_version_for_platform,
|
|
45
|
+
MAX_PREFETCH_DEPTH,
|
|
46
|
+
MultiPassPrefetchConfig,
|
|
47
|
+
PoolingMode,
|
|
48
|
+
RecordCacheMetrics,
|
|
49
|
+
SplitState,
|
|
50
|
+
)
|
|
51
|
+
from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
|
|
52
|
+
generate_vbe_metadata,
|
|
53
|
+
is_torchdynamo_compiling,
|
|
54
|
+
)
|
|
55
|
+
from fbgemm_gpu.tbe_input_multiplexer import (
|
|
56
|
+
TBEInfo,
|
|
57
|
+
TBEInputInfo,
|
|
58
|
+
TBEInputMultiplexer,
|
|
59
|
+
TBEInputMultiplexerConfig,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
|
|
63
|
+
|
|
64
|
+
# Best-effort registration of the training embedding operators through the
# project's loader helpers. Any exception raised by either loader is swallowed
# so that importing this module never fails at this point; note that a failure
# in the first call also skips the second one.
# NOTE(review): presumably the ops get registered through some other path when
# these build labels cannot be resolved in the current build — confirm.
with contextlib.suppress(Exception):
    load_torch_module(
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_gpu",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training",
    )
    load_torch_module_bc(
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training",
    )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Default cache associativity, chosen per GPU backend: 32 on CUDA builds and
# 64 on ROCm/HIP builds (where ``torch.version.hip`` is set).
# NOTE(review): presumably the number of ways per cache set — confirm.
DEFAULT_ASSOC: int
if torch.version.hip is None:
    DEFAULT_ASSOC = 32
else:
    DEFAULT_ASSOC = 64

# Extra per-row elements added to ``embedding_dim`` when sizing INT8 state in
# ``construct_split_state`` (presumably room for per-row quantization
# parameters — confirm).
INT8_EMB_ROW_DIM_OFFSET: int = 8
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class DoesNotHavePrefix(Exception):
    """
    Signals that an expected name prefix was not found.

    NOTE(review): no code raising this is visible in this part of the file;
    the meaning above is inferred from the class name — confirm at call sites.
    """
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class WeightDecayMode(enum.IntEnum):
    """
    Selects how the fused optimizer applies weight decay.

    The integer values are explicit; presumably they are forwarded to compiled
    optimizer code, so existing members should not be renumbered.
    """

    # No weight decay.
    NONE = 0
    # L2 weight decay.
    L2 = 1
    # Decoupled weight decay.
    DECOUPLE = 2
    # Presumably configured through CounterBasedRegularizationDefinition.
    COUNTER = 3
    # Presumably configured through CowClipDefinition.
    COWCLIP = 4
    # Presumably configured through GlobalWeightDecayDefinition.
    DECOUPLE_GLOBAL = 5
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class CounterWeightDecayMode(enum.IntEnum):
    """
    Weight decay mode used by the counter-based configurations.

    Appears as the ``counter_weight_decay_mode`` field of both
    ``CounterBasedRegularizationDefinition`` and ``CowClipDefinition``.
    """

    NONE = 0
    L2 = 1
    DECOUPLE = 2
    # NOTE(review): presumably an AdaGrad-W style decay — confirm in kernels.
    ADAGRADW = 3
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class StepMode(enum.IntEnum):
    """
    Step-counting mode, used as ``EnsembleModeDefinition.step_mode``.

    NOTE(review): from the names, USE_COUNTER presumably counts per-row
    updates and USE_ITER global iterations — confirm.
    """

    NONE = 0
    USE_COUNTER = 1
    USE_ITER = 2
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class LearningRateMode(enum.IntEnum):
    """
    Learning-rate adjustment mode for counter-based regularization.

    Used as ``CounterBasedRegularizationDefinition.learning_rate_mode``;
    ``EQUAL`` (-1) is the default there.
    """

    # Default: no tail-ID-specific adjustment (inferred from the name).
    EQUAL = -1
    TAIL_ID_LR_INCREASE = 0
    TAIL_ID_LR_DECREASE = 1
    COUNTER_SGD = 2
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class GradSumDecay(enum.IntEnum):
    """
    Gradient-sum decay mode for counter-based regularization.

    Used as ``CounterBasedRegularizationDefinition.grad_sum_decay``;
    ``NO_DECAY`` (-1) is the default there.
    """

    NO_DECAY = -1
    CTR_DECAY = 0
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@dataclass(frozen=True)
class TailIdThreshold:
    """
    Threshold used by counter-based regularization.

    Used as ``CounterBasedRegularizationDefinition.tail_id_threshold``.

    Attributes:
        val: The threshold value (defaults to ``0``).
        is_ratio: Whether ``val`` is to be read as a ratio rather than an
            absolute value (inferred from the name — confirm in kernels).
    """

    val: float = 0
    is_ratio: bool = False
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@dataclass(frozen=True)
class CounterBasedRegularizationDefinition:
    """
    Configuration for counter-based regularization.

    The TBE class docstring lists this as used by Rowwise Adagrad. Field
    order is part of the positional constructor signature and must not
    change.

    Attributes:
        counter_weight_decay_mode: Weight decay mode to apply.
        counter_halflife: Counter half-life (``-1`` by default; presumably
            "disabled" — confirm).
        adjustment_iter: Adjustment iteration (``-1`` by default).
        adjustment_ub: Adjustment upper bound (``1.0`` by default).
        learning_rate_mode: Learning-rate adjustment mode.
        grad_sum_decay: Gradient-sum decay mode.
        tail_id_threshold: Threshold config; a fresh ``TailIdThreshold()``
            is created per instance.
        max_counter_update_freq: Maximum counter update frequency
            (``1000`` by default).
    """

    counter_weight_decay_mode: CounterWeightDecayMode = CounterWeightDecayMode.NONE
    counter_halflife: int = -1
    adjustment_iter: int = -1
    adjustment_ub: float = 1.0
    learning_rate_mode: LearningRateMode = LearningRateMode.EQUAL
    grad_sum_decay: GradSumDecay = GradSumDecay.NO_DECAY
    tail_id_threshold: TailIdThreshold = field(default_factory=TailIdThreshold)
    max_counter_update_freq: int = 1000
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@dataclass(frozen=True)
class CowClipDefinition:
    """
    Configuration for CowClip regularization.

    The TBE class docstring lists this as used by Rowwise Adagrad.

    Attributes:
        counter_weight_decay_mode: Weight decay mode to apply.
        counter_halflife: Counter half-life (``-1`` by default).
        weight_norm_coefficient: Weight-norm coefficient (``0.0`` default).
        lower_bound: Lower bound (``0.0`` default; presumably on the clipping
            threshold — confirm).
    """

    counter_weight_decay_mode: CounterWeightDecayMode = CounterWeightDecayMode.NONE
    counter_halflife: int = -1
    weight_norm_coefficient: float = 0.0
    lower_bound: float = 0.0
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@dataclass(frozen=True)
class GlobalWeightDecayDefinition:
    """
    Configuration for global weight decay.

    Attributes:
        start_iter: Iteration at which the decay starts (``0`` by default;
            semantics inferred from the name).
        lower_bound: Lower bound used by the decay (``0.0`` by default).
    """

    start_iter: int = 0
    lower_bound: float = 0.0
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@dataclass(frozen=True)
class UserEnabledConfigDefinition:
    """
    Opt-in switches for optimizer modes that are disabled by default.
    """

    # Enables rowwise bias correction in Adam, driven by `row_counter`.
    # More details can be found in D64848802.
    use_rowwise_bias_correction: bool = False
    # NOTE(review): semantics not visible here; presumably installs a
    # write-back hook before backward — confirm where it is consumed.
    use_writeback_bwd_prehook: bool = False
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@dataclass(frozen=True)
class EnsembleModeDefinition:
    """
    Configuration for Ensemble Rowwise Adagrad.

    Attributes:
        step_ema: EMA step setting (``10000`` by default).
        step_swap: Swap step setting (``10000`` by default).
        step_start: Starting step (``0`` by default).
        step_ema_coef: EMA coefficient (``0.6`` by default).
        step_mode: How steps are counted (``StepMode.USE_ITER`` by default).
    """

    step_ema: float = 10000
    step_swap: float = 10000
    step_start: float = 0
    step_ema_coef: float = 0.6
    step_mode: StepMode = StepMode.USE_ITER
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@dataclass(frozen=True)
class EmainplaceModeDefinition:
    """
    Configuration for EMA in-place Rowwise Adagrad.

    Attributes:
        step_ema: EMA step setting (``10`` by default).
        step_start: Starting step (``0`` by default).
        step_ema_coef: EMA coefficient (``0.6`` by default).
    """

    step_ema: float = 10
    step_start: float = 0
    step_ema_coef: float = 0.6
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# Keep in sync with fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
class UVMCacheStatsIndex(enum.IntEnum):
    """
    Slot positions of the UVM cache statistics counters.

    The values must match the native header named above. Presumably they
    index a per-module stats tensor — confirm.
    """

    num_calls = 0
    num_requested_indices = 1
    num_unique_indices = 2
    num_unique_misses = 3
    num_conflict_unique_misses = 4
    num_conflict_misses = 5
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@dataclass
class RESParams:
    """
    Parameters describing the raw-embedding store this TBE talks to.

    NOTE(review): "RES" is not expanded anywhere visible; the field comments
    below are taken from the original inline comments.
    """

    # The port of the RES server.
    res_server_port: int = 0
    # The number of shards used to store the raw embeddings.
    res_store_shards: int = 1
    # Names of the tables this TBE holds.
    table_names: list[str] = field(default_factory=list)
    # Offsets of the global rows this TBE holds, per table.
    table_offsets: list[int] = field(default_factory=list)
    # Sizes of the global rows this TBE holds, per table.
    table_sizes: list[int] = field(default_factory=list)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class PrefetchedInfo:
    """
    Container for prefetched cache information.

    This class is explicitly defined (not using @dataclass) to be compatible with
    TorchScript's inspect.getsource() requirements.

    Args:
        linear_unique_indices: Unique linearized indices (stored as-is).
        linear_unique_cache_indices: Cache indices matching
            ``linear_unique_indices`` (stored as-is).
        linear_unique_indices_length: Length tensor for the unique indices
            (stored as-is).
        hash_zch_identities: Optional hash-ZCH identities, or ``None``.
        hash_zch_runtime_meta: Optional hash-ZCH runtime metadata, or ``None``.
    """

    def __init__(
        self,
        linear_unique_indices: torch.Tensor,
        linear_unique_cache_indices: torch.Tensor,
        linear_unique_indices_length: torch.Tensor,
        hash_zch_identities: Optional[torch.Tensor],
        hash_zch_runtime_meta: Optional[torch.Tensor],
    ) -> None:
        # Required prefetch results: references are kept, nothing is copied.
        self.linear_unique_indices = linear_unique_indices
        self.linear_unique_cache_indices = linear_unique_cache_indices
        self.linear_unique_indices_length = linear_unique_indices_length

        # Optional hash-ZCH metadata; either may be None.
        self.hash_zch_identities = hash_zch_identities
        self.hash_zch_runtime_meta = hash_zch_runtime_meta
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def construct_split_state(
    embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]],
    rowwise: bool,
    cacheable: bool,
    precision: SparseType = SparseType.FP32,
    int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
    placement: Optional[EmbeddingLocation] = None,
) -> SplitState:
    """
    Assign every table's state to a memory pool and compute pool offsets.

    Tables are visited in ``embedding_specs`` order. Each one is placed in one
    of three flat pools (device, host or UVM), receives the current size of
    that pool as its starting offset, and then grows the pool by its state
    size (``num_embeddings`` when ``rowwise`` else
    ``num_embeddings * embedding_dim``).

    Placement rules, using ``placement`` instead of the spec's own location
    when it is given:

    * ``HOST`` -> host pool, placement ``HOST``.
    * ``DEVICE``, or any other location when ``rowwise`` is True -> device
      pool, placement ``DEVICE``.
    * Otherwise -> UVM pool, placement ``MANAGED_CACHING`` when ``cacheable``
      and the location is ``MANAGED_CACHING``, else ``MANAGED``.

    Args:
        embedding_specs: One ``(num_embeddings, embedding_dim, location,
            compute_device)`` tuple per table; ``compute_device`` is unused.
        rowwise: Size the state as one element per row.
        cacheable: Whether a ``MANAGED_CACHING`` location is kept as such.
        precision: When ``SparseType.INT8``, ``int8_emb_row_dim_offset`` is
            added to each row's dimension before sizing.
        int8_emb_row_dim_offset: Extra per-row elements for INT8 state.
        placement: Optional location overriding every spec's location.

    Returns:
        A ``SplitState`` holding the three pool sizes plus the per-table
        placements and offsets.

    Raises:
        AssertionError: If an ``embedding_dim`` is not a multiple of 4.
    """
    placements: list[EmbeddingLocation] = []
    offsets: list[int] = []
    dev_size: int = 0
    host_size: int = 0
    uvm_size: int = 0

    for num_embeddings, embedding_dim, location, _ in embedding_specs:
        assert (
            embedding_dim % 4 == 0
        ), f"embedding_dim must be a multiple of 4, but got {embedding_dim}"

        row_dim = (
            embedding_dim + int8_emb_row_dim_offset
            if precision == SparseType.INT8
            else embedding_dim
        )
        state_size = num_embeddings if rowwise else num_embeddings * row_dim
        target = location if placement is None else placement

        if target == EmbeddingLocation.HOST:
            placements.append(EmbeddingLocation.HOST)
            offsets.append(host_size)
            host_size += state_size
        # If the table is on device, its optimizer state is on device too.
        # If the table is managed, rowwise optimizer state still goes to the
        # device; otherwise the optimizer state is managed as well.
        elif target == EmbeddingLocation.DEVICE or rowwise:
            placements.append(EmbeddingLocation.DEVICE)
            offsets.append(dev_size)
            dev_size += state_size
        else:
            keep_caching = cacheable and target == EmbeddingLocation.MANAGED_CACHING
            placements.append(
                EmbeddingLocation.MANAGED_CACHING
                if keep_caching
                else EmbeddingLocation.MANAGED
            )
            offsets.append(uvm_size)
            uvm_size += state_size

    assert len(placements) == len(offsets)
    return SplitState(
        dev_size=dev_size,
        host_size=host_size,
        uvm_size=uvm_size,
        placements=placements,
        offsets=offsets,
    )
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def apply_split_helper(
    persistent_state_fn: Callable[[str, Tensor], None],
    set_attr_fn: Callable[
        [str, Union[Tensor, list[int], list[EmbeddingLocation]]], None
    ],
    current_device: torch.device,
    use_cpu: bool,
    feature_table_map: list[int],
    split: SplitState,
    prefix: str,
    dtype: type[torch.dtype],
    enforce_hbm: bool = False,
    make_dev_param: bool = False,
    dev_reshape: Optional[tuple[int, ...]] = None,
    uvm_tensors_log: Optional[list[str]] = None,
    uvm_host_mapped: bool = False,
) -> None:
    """
    Allocate the storage described by ``split`` and attach it via callbacks.

    Everything is registered under names derived from ``prefix``:

    * ``{prefix}_physical_placements`` / ``{prefix}_physical_offsets``: the
      per-physical-table lists from ``split``, set through ``set_attr_fn``.
    * ``{prefix}_offsets`` (int64) / ``{prefix}_placements`` (int32): the same
      data gathered per feature through ``feature_table_map``, as tensors on
      ``current_device``, stored through ``persistent_state_fn``.
    * ``{prefix}_dev``: ``split.dev_size`` zeros (viewed as ``dev_reshape``
      when given and non-empty), or an empty tensor when the size is 0.
      Stored as an ``nn.Parameter`` through ``set_attr_fn`` if
      ``make_dev_param``, otherwise through ``persistent_state_fn``.
    * ``{prefix}_host``: ``split.host_size`` zeros. These are persistent state
      when ``dtype`` is ``torch.uint8`` and an ``nn.Parameter`` otherwise; an
      empty tensor is stored as persistent state when the size is 0.
    * ``{prefix}_uvm``: ``split.uvm_size`` zeros. When ``enforce_hbm`` is set
      they are a regular ``current_device`` tensor; otherwise the tensor is
      allocated through ``torch.ops.fbgemm.new_unified_tensor``. An empty
      tensor is stored when the size is 0.

    Args:
        persistent_state_fn: Callback storing a tensor under a name as
            non-parameter state (presumably a registered buffer).
        set_attr_fn: Callback setting a plain attribute or ``nn.Parameter``.
        current_device: Device on which every tensor is created.
        use_cpu: Must be False whenever ``split.uvm_size > 0``.
        feature_table_map: Physical table index for each feature.
        split: Pool sizes, per-table placements and offsets.
        prefix: Name prefix for every registered attribute.
        dtype: Element type of the dev/host/uvm storage.
        enforce_hbm: Allocate the UVM pool as a plain device tensor instead.
        make_dev_param: Register ``{prefix}_dev`` as an ``nn.Parameter``.
        dev_reshape: Optional shape used to view a non-empty device pool.
        uvm_tensors_log: If given, receives the names of the non-empty host
            and UVM tensors that were allocated.
        uvm_host_mapped: Forwarded as ``is_host_mapped`` to
            ``new_unified_tensor``.

    Raises:
        AssertionError: If ``split.uvm_size > 0`` while ``use_cpu`` is True.
    """
    # Per-physical-table layout, kept as plain Python lists.
    set_attr_fn(f"{prefix}_physical_placements", split.placements)
    set_attr_fn(f"{prefix}_physical_offsets", split.offsets)

    # Per-feature layout: each feature inherits its physical table's entry.
    feature_offsets = [split.offsets[table] for table in feature_table_map]
    feature_placements = [split.placements[table] for table in feature_table_map]
    persistent_state_fn(
        f"{prefix}_offsets",
        torch.tensor(feature_offsets, device=current_device, dtype=torch.int64),
    )
    persistent_state_fn(
        f"{prefix}_placements",
        torch.tensor(feature_placements, device=current_device, dtype=torch.int32),
    )

    # ---- Device pool ----
    dev_name = f"{prefix}_dev"
    if split.dev_size <= 0:
        # pyre-fixme[6]
        dev_storage = torch.empty(0, device=current_device, dtype=dtype)
    else:
        dev_storage = torch.zeros(
            split.dev_size,
            device=current_device,
            # pyre-fixme[6]
            dtype=dtype,
        )
        if dev_reshape is not None:
            dev_storage = dev_storage.view(*dev_reshape)
    if make_dev_param:
        set_attr_fn(dev_name, nn.Parameter(dev_storage))
    else:
        persistent_state_fn(dev_name, dev_storage)

    # ---- Host pool ----
    host_name = f"{prefix}_host"
    if split.host_size <= 0:
        persistent_state_fn(
            host_name,
            # pyre-fixme[6]: For 3rd param expected `dtype` but got `Type[dtype]`.
            torch.empty(0, device=current_device, dtype=dtype),
        )
    else:
        host_storage = torch.zeros(
            split.host_size,
            device=current_device,
            # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]` for
            # 3rd param but got `Type[Type[torch._dtype]]`.
            dtype=dtype,
        )
        if dtype == torch.uint8:
            # uint8 storage is kept as plain state rather than nn.Parameter
            # (presumably because a Parameter requiring grad needs a
            # floating-point dtype).
            persistent_state_fn(host_name, host_storage)
        else:
            set_attr_fn(host_name, nn.Parameter(host_storage))
        if uvm_tensors_log is not None:
            uvm_tensors_log.append(host_name)

    # ---- UVM pool ----
    uvm_name = f"{prefix}_uvm"
    if split.uvm_size <= 0:
        persistent_state_fn(
            uvm_name,
            # pyre-fixme[6]: For 3rd param expected `dtype` but got `Type[dtype]`.
            torch.empty(0, device=current_device, dtype=dtype),
        )
        return

    assert not use_cpu
    if enforce_hbm:
        logging.info("Enforce hbm for the cache location")
        uvm_storage = torch.zeros(
            split.uvm_size,
            device=current_device,
            # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]` for
            # 3rd param but got `Type[Type[torch._dtype]]`.
            dtype=dtype,
        )
    else:
        # Zero-fill a freshly allocated unified tensor in place via `out=`.
        uvm_storage = torch.zeros(
            split.uvm_size,
            device=current_device,
            out=torch.ops.fbgemm.new_unified_tensor(
                # pyre-fixme[6]: Expected `Optional[Type[torch._dtype]]`
                # for 3rd param but got `Type[Type[torch._dtype]]`.
                torch.zeros(1, device=current_device, dtype=dtype),
                [split.uvm_size],
                is_host_mapped=uvm_host_mapped,
            ),
        )
    persistent_state_fn(uvm_name, uvm_storage)
    if uvm_tensors_log is not None:
        uvm_tensors_log.append(uvm_name)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def get_available_compute_device() -> ComputeDevice:
    """
    Return the preferred compute device visible to this process.

    Preference order: CUDA first, then MTIA, falling back to CPU.
    """
    if torch.cuda.is_available():
        return ComputeDevice.CUDA
    if torch.mtia.is_available():
        return ComputeDevice.MTIA
    return ComputeDevice.CPU
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
|
|
403
|
+
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
|
|
404
|
+
class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
405
|
+
"""
|
|
406
|
+
Table Batched Embedding (TBE) operator. Looks up one or more embedding
|
|
407
|
+
tables. The module is application for training. The backward operator is
|
|
408
|
+
fused with optimizer. Thus, the embedding tables are updated during
|
|
409
|
+
backward.
|
|
410
|
+
|
|
411
|
+
Args:
|
|
412
|
+
embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]):
|
|
413
|
+
A list of embedding specifications. Each spec describes a
|
|
414
|
+
specification of a physical embedding table. Each one is a tuple of
|
|
415
|
+
number of embedding rows, embedding dimension (must be a multiple of
|
|
416
|
+
4), table placement (`EmbeddingLocation`), and compute device
|
|
417
|
+
(`ComputeDevice`).
|
|
418
|
+
|
|
419
|
+
Available `EmbeddingLocation` options are
|
|
420
|
+
|
|
421
|
+
(1) `DEVICE` = placing an embedding table in the GPU global memory
|
|
422
|
+
(HBM)
|
|
423
|
+
|
|
424
|
+
(2) `MANAGED` = placing an embedding in the unified virtual memory
|
|
425
|
+
(accessible from both GPU and CPU)
|
|
426
|
+
|
|
427
|
+
(3) `MANAGED_CACHING` = placing an embedding table in the unified
|
|
428
|
+
virtual memory and using the GPU global memory (HBM) as a cache
|
|
429
|
+
|
|
430
|
+
(4) `HOST` = placing an embedding table in the CPU memory (DRAM)
|
|
431
|
+
|
|
432
|
+
(5) `MTIA` = placing an embedding table in the MTIA memory
|
|
433
|
+
|
|
434
|
+
Available `ComputeDevice` options are
|
|
435
|
+
|
|
436
|
+
(1) `CPU` = performing table lookup on CPU
|
|
437
|
+
|
|
438
|
+
(2) `CUDA` = performing table lookup on GPU
|
|
439
|
+
|
|
440
|
+
(3) `MTIA` = performing table lookup on MTIA
|
|
441
|
+
|
|
442
|
+
feature_table_map (Optional[List[int]] = None): An optional list that
|
|
443
|
+
specifies feature-table mapping. feature_table_map[i] indicates the
|
|
444
|
+
physical embedding table that feature i maps to.
|
|
445
|
+
|
|
446
|
+
cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU): The cache
|
|
447
|
+
algorithm (used when `EmbeddingLocation` is set to
|
|
448
|
+
`MANAGED_CACHING`). Options are
|
|
449
|
+
|
|
450
|
+
(1) `LRU` = least recently used
|
|
451
|
+
|
|
452
|
+
(2) `LFU` = least frequently used
|
|
453
|
+
|
|
454
|
+
cache_load_factor (float = 0.2): A factor used for determining the
|
|
455
|
+
cache capacity when `EmbeddingLocation.MANAGED_CACHING` is used.
|
|
456
|
+
The cache capacity is `cache_load_factor` * the total number of
|
|
457
|
+
rows in all embedding tables
|
|
458
|
+
|
|
459
|
+
cache_sets (int = 0): The number of cache sets (used when
|
|
460
|
+
`EmbeddingLocation` is set to `MANAGED_CACHING`)
|
|
461
|
+
|
|
462
|
+
cache_reserved_memory (float = 0.0): The amount of memory reserved in
|
|
463
|
+
HBM for non-cache purpose (used when `EmbeddingLocation` is set to
|
|
464
|
+
`MANAGED_CACHING`).
|
|
465
|
+
|
|
466
|
+
cache_precision (SparseType = SparseType.FP32): The data type of the
|
|
467
|
+
cache (used when `EmbeddingLocation` is set to `MANAGED_CACHING`).
|
|
468
|
+
Options are `SparseType.FP32` and `SparseType.FP16`
|
|
469
|
+
|
|
470
|
+
weights_precision (SparseType = SparseType.FP32): The data type of
|
|
471
|
+
embedding tables (also known as weights). Options are
|
|
472
|
+
`SparseType.FP32` and `SparseType.FP16`
|
|
473
|
+
|
|
474
|
+
output_dtype (SparseType = SparseType.FP32): The data type of an output
|
|
475
|
+
tensor. Options are `SparseType.FP32` and `SparseType.FP16`
|
|
476
|
+
|
|
477
|
+
enforce_hbm (bool = False): If True, place all weights/momentums in HBM
|
|
478
|
+
when using `EmbeddingLocation.MANAGED_CACHING`
|
|
479
|
+
|
|
480
|
+
optimizer (OptimType = OptimType.EXACT_SGD): An optimizer to use for
|
|
481
|
+
embedding table update in the backward pass. Available `OptimType`
|
|
482
|
+
options are
|
|
483
|
+
|
|
484
|
+
(1) `ADAM` = Adam
|
|
485
|
+
|
|
486
|
+
(2) `EXACT_ADAGRAD` = Adagrad
|
|
487
|
+
|
|
488
|
+
(3) `EXACT_ROWWISE_ADAGRAD` = Rowwise-Aadagrad
|
|
489
|
+
|
|
490
|
+
(4) `EXACT_SGD` = SGD
|
|
491
|
+
|
|
492
|
+
(5) `LAMB` = Lamb
|
|
493
|
+
|
|
494
|
+
(6) `LARS_SGD` = LARS-SGD
|
|
495
|
+
|
|
496
|
+
(7) `PARTIAL_ROWWISE_ADAM` = Partial rowwise-Adam
|
|
497
|
+
|
|
498
|
+
(8) `PARTIAL_ROWWISE_LAMB` = Partial rowwise-Lamb
|
|
499
|
+
|
|
500
|
+
(9) `ENSEMBLE_ROWWISE_ADAGRAD` = Ensemble rowwise-Adagrad
|
|
501
|
+
|
|
502
|
+
(10) `EMAINPLACE_ROWWISE_ADAGRAD` = Ema inplace rowwise-Adagrad
|
|
503
|
+
|
|
504
|
+
(11) `NONE` = Not applying an optimizer update in the backward pass
|
|
505
|
+
and outputting a sparse weight gradient
|
|
506
|
+
|
|
507
|
+
record_cache_metrics (Optional[RecordCacheMetrics] = None): Record
|
|
508
|
+
a number of hits, a number of requests, etc if
|
|
509
|
+
`RecordCacheMetrics.record_cache_miss_counter` is True and record
|
|
510
|
+
the similar metrics table-wise if
|
|
511
|
+
`RecordCacheMetrics.record_tablewise_cache_miss is True`
|
|
512
|
+
|
|
513
|
+
gather_uvm_cache_stats (Optional[bool] = False): If True, collect the
|
|
514
|
+
cache statistics when `EmbeddingLocation` is set to
|
|
515
|
+
`MANAGED_CACHING`
|
|
516
|
+
|
|
517
|
+
stochastic_rounding (bool = True): If True, apply stochastic rounding
|
|
518
|
+
for weight type that is not `SparseType.FP32`
|
|
519
|
+
|
|
520
|
+
gradient_clipping (bool = False): If True, apply gradient clipping
|
|
521
|
+
|
|
522
|
+
max_gradient (float = 1.0): The value for gradient clipping
|
|
523
|
+
|
|
524
|
+
max_norm (float = 0.0): The max norm value
|
|
525
|
+
|
|
526
|
+
learning_rate (float = 0.01): The learning rate
|
|
527
|
+
|
|
528
|
+
eps (float = 1.0e-8): The epsilon value used by Adagrad, LAMB, and
|
|
529
|
+
Adam. Note that default is different from torch.nn.optim.Adagrad
|
|
530
|
+
default of 1e-10
|
|
531
|
+
|
|
532
|
+
momentum (float = 0.9): Momentum used by LARS-SGD
|
|
533
|
+
|
|
534
|
+
weight_decay (float = 0.0): Weight decay used by LARS-SGD, LAMB, ADAM,
|
|
535
|
+
and rowwise-Adagrad.
|
|
536
|
+
|
|
537
|
+
(1) EXACT_ADAGRAD, SGD, EXACT_SGD do not support weight decay
|
|
538
|
+
|
|
539
|
+
(2) LAMB, ADAM, PARTIAL_ROWWISE_ADAM, PARTIAL_ROWWISE_LAMB, LARS_SGD
|
|
540
|
+
support decoupled weight decay
|
|
541
|
+
|
|
542
|
+
(3) EXACT_ROWWISE_ADAGRAD support both L2 and decoupled weight decay
|
|
543
|
+
(via weight_decay_mode)
|
|
544
|
+
|
|
545
|
+
weight_decay_mode (WeightDecayMode = WeightDecayMode.NONE): Weight decay
|
|
546
|
+
mode. Options are `WeightDecayMode.NONE`, `WeightDecayMode.L2`,
|
|
547
|
+
and `WeightDecayMode.DECOUPLE`
|
|
548
|
+
|
|
549
|
+
eta (float = 0.001): The eta value used by LARS-SGD
|
|
550
|
+
|
|
551
|
+
beta1 (float = 0.9): The beta1 value used by LAMB and ADAM
|
|
552
|
+
|
|
553
|
+
beta2 (float = 0.999): The beta2 value used by LAMB and ADAM
|
|
554
|
+
|
|
555
|
+
ensemble_mode (Optional[EnsembleModeDefinition] = None):
|
|
556
|
+
Used by Ensemble Rowwise Adagrad
|
|
557
|
+
|
|
558
|
+
emainplace_mode (Optional[EmainplaceModeDefinition] = None):
|
|
559
|
+
Used by EMA in-place Rowwise Adagrad
|
|
560
|
+
|
|
561
|
+
counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None):
|
|
562
|
+
Used by Rowwise Adagrad
|
|
563
|
+
|
|
564
|
+
cowclip_regularization (Optional[CowClipDefinition] = None): Used by
|
|
565
|
+
Rowwise Adagrad
|
|
566
|
+
|
|
567
|
+
pooling_mode (PoolingMode = PoolingMode.SUM): Pooling mode. Available
|
|
568
|
+
`PoolingMode` options are
|
|
569
|
+
|
|
570
|
+
(1) `SUM` = Sum pooling
|
|
571
|
+
|
|
572
|
+
(2) `MEAN` = Mean pooling
|
|
573
|
+
|
|
574
|
+
(3) `NONE` = No pooling (sequence embedding)
|
|
575
|
+
|
|
576
|
+
device (Optional[Union[str, int, torch.device]] = None): The current
|
|
577
|
+
device to place tensors on
|
|
578
|
+
|
|
579
|
+
bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING): Input
|
|
580
|
+
checking mode. Available `BoundsCheckMode` options are
|
|
581
|
+
|
|
582
|
+
(1) `NONE` = skip bounds check
|
|
583
|
+
|
|
584
|
+
(2) `FATAL` = throw an error when encountering an invalid
|
|
585
|
+
index/offset
|
|
586
|
+
|
|
587
|
+
(3) `WARNING` = print a warning message when encountering an
|
|
588
|
+
invalid index/offset and fix it (setting an invalid index to
|
|
589
|
+
zero and adjusting an invalid offset to be within the bound)
|
|
590
|
+
|
|
591
|
+
(4) `IGNORE` = silently fix an invalid index/offset (setting an
|
|
592
|
+
invalid index to zero and adjusting an invalid offset to be
|
|
593
|
+
within the bound)
|
|
594
|
+
|
|
595
|
+
uvm_non_rowwise_momentum (bool = False): If True, place non-rowwise
|
|
596
|
+
momentum on the unified virtual memory
|
|
597
|
+
|
|
598
|
+
use_experimental_tbe (bool = False): If True, use an optimized TBE
|
|
599
|
+
implementation (TBE v2). Note that this is supported only on NVIDIA
|
|
600
|
+
GPUs.
|
|
601
|
+
|
|
602
|
+
prefetch_pipeline (bool = False): If True, enable cache prefetch
|
|
603
|
+
pipeline when using `EmbeddingLocation.MANAGED_CACHING`. Currently
|
|
604
|
+
only supports the LRU cache policy. If a separate stream is used
|
|
605
|
+
for prefetch, the optional `forward_stream` arg of prefetch
|
|
606
|
+
function must be set.
|
|
607
|
+
|
|
608
|
+
stats_reporter_config (Optional[TBEStatsReporterConfig] = None):
|
|
609
|
+
A config for TBE stats reporter
|
|
610
|
+
|
|
611
|
+
table_names (Optional[List[str]] = None): A list of embedding table
|
|
612
|
+
names in this TBE
|
|
613
|
+
|
|
614
|
+
optimizer_state_dtypes (Optional[Dict[str, SparseType]] = None): An
|
|
615
|
+
optimizer state data types dict. Keys are the optimizer state names
|
|
616
|
+
and values are their corresponding types
|
|
617
|
+
|
|
618
|
+
multipass_prefetch_config (Optional[MultiPassPrefetchConfig] = None):
|
|
619
|
+
A config for multipass cache prefetching (when
|
|
620
|
+
`EmbeddingLocation.MANAGED_CACHING` is used)
|
|
621
|
+
|
|
622
|
+
global_weight_decay (Optional[GlobalWeightDecayDefinition] = None):
|
|
623
|
+
A config for global weight decay
|
|
624
|
+
|
|
625
|
+
uvm_host_mapped (bool = False): If True, allocate every UVM tensor
|
|
626
|
+
using `malloc` + `cudaHostRegister`. Otherwise use
|
|
627
|
+
`cudaMallocManaged`
|
|
628
|
+
|
|
629
|
+
extra_optimizer_config (Optional[UserEnabledConfigDefinition] = None):
|
|
630
|
+
An extra config to enable certain optimizer modes. These modes
|
|
631
|
+
are not enabled by default.
|
|
632
|
+
- `use_rowwise_bias_correction` is used in Adam to enable rowwise
|
|
633
|
+
bias correction computation
|
|
634
|
+
|
|
635
|
+
embedding_table_index_type (torch.dtype = torch.int64): The data type of
|
|
636
|
+
the embedding table index tensor. Options are `torch.int32` and
|
|
637
|
+
`torch.int64`
|
|
638
|
+
|
|
639
|
+
embedding_table_offset_type (torch.dtype = torch.int64): The data type of
|
|
640
|
+
the embedding table offset tensor. Options are `torch.int32` and
|
|
641
|
+
`torch.int64`
|
|
642
|
+
|
|
643
|
+
embedding_shard_info (Optional[List[Tuple[int, int, int, int]]] = None): The
|
|
644
|
+
information about shard position and pre-sharded table size. If not set,
|
|
645
|
+
the table is not sharded.
|
|
646
|
+
Each entry is a tuple of (preshard_table_height, preshard_table_dim, height_offset, dim_offset).
|
|
647
|
+
"""
|
|
648
|
+
|
|
649
|
+
embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]]
|
|
650
|
+
optimizer_args: invokers.lookup_args.OptimizerArgs
|
|
651
|
+
lxu_cache_locations_list: list[Tensor]
|
|
652
|
+
lxu_cache_locations_empty: Tensor
|
|
653
|
+
timesteps_prefetched: list[int]
|
|
654
|
+
prefetched_info_list: list[PrefetchedInfo]
|
|
655
|
+
record_cache_metrics: RecordCacheMetrics
|
|
656
|
+
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
|
|
657
|
+
uvm_cache_stats: torch.Tensor
|
|
658
|
+
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
|
|
659
|
+
local_uvm_cache_stats: torch.Tensor
|
|
660
|
+
uuid: str
|
|
661
|
+
# pyre-fixme[13]: Attribute `last_uvm_cache_print_state` is never initialized.
|
|
662
|
+
last_uvm_cache_print_state: torch.Tensor
|
|
663
|
+
_vbe_B_offsets: Optional[torch.Tensor]
|
|
664
|
+
_vbe_max_B: int
|
|
665
|
+
|
|
666
|
+
def __init__( # noqa C901
|
|
667
|
+
self,
|
|
668
|
+
embedding_specs: list[
|
|
669
|
+
tuple[int, int, EmbeddingLocation, ComputeDevice]
|
|
670
|
+
], # tuple of (rows, dims, placements, compute_devices)
|
|
671
|
+
feature_table_map: Optional[list[int]] = None, # [T]
|
|
672
|
+
cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU,
|
|
673
|
+
cache_load_factor: float = 0.2,
|
|
674
|
+
cache_sets: int = 0,
|
|
675
|
+
cache_reserved_memory: float = 0.0,
|
|
676
|
+
cache_precision: Optional[SparseType] = None,
|
|
677
|
+
weights_precision: SparseType = SparseType.FP32,
|
|
678
|
+
output_dtype: SparseType = SparseType.FP32,
|
|
679
|
+
enforce_hbm: bool = False,
|
|
680
|
+
optimizer: OptimType = OptimType.EXACT_SGD,
|
|
681
|
+
record_cache_metrics: Optional[RecordCacheMetrics] = None,
|
|
682
|
+
gather_uvm_cache_stats: Optional[bool] = False,
|
|
683
|
+
# General Optimizer args
|
|
684
|
+
stochastic_rounding: bool = True,
|
|
685
|
+
gradient_clipping: bool = False,
|
|
686
|
+
max_gradient: float = 1.0,
|
|
687
|
+
max_norm: float = 0.0,
|
|
688
|
+
learning_rate: float = 0.01,
|
|
689
|
+
eps: float = 1.0e-8,
|
|
690
|
+
momentum: float = 0.9,
|
|
691
|
+
weight_decay: float = 0.0,
|
|
692
|
+
weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE,
|
|
693
|
+
eta: float = 0.001,
|
|
694
|
+
beta1: float = 0.9,
|
|
695
|
+
beta2: float = 0.999,
|
|
696
|
+
ensemble_mode: Optional[EnsembleModeDefinition] = None,
|
|
697
|
+
emainplace_mode: Optional[EmainplaceModeDefinition] = None,
|
|
698
|
+
counter_based_regularization: Optional[
|
|
699
|
+
CounterBasedRegularizationDefinition
|
|
700
|
+
] = None,
|
|
701
|
+
cowclip_regularization: Optional[CowClipDefinition] = None,
|
|
702
|
+
pooling_mode: PoolingMode = PoolingMode.SUM,
|
|
703
|
+
device: Optional[Union[str, int, torch.device]] = None,
|
|
704
|
+
bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
|
|
705
|
+
uvm_non_rowwise_momentum: bool = False,
|
|
706
|
+
use_experimental_tbe: bool = False,
|
|
707
|
+
prefetch_pipeline: bool = False,
|
|
708
|
+
stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
|
|
709
|
+
table_names: Optional[list[str]] = None,
|
|
710
|
+
optimizer_state_dtypes: Optional[dict[str, SparseType]] = None,
|
|
711
|
+
multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
|
|
712
|
+
global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
|
|
713
|
+
uvm_host_mapped: bool = False,
|
|
714
|
+
extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
|
|
715
|
+
tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
|
|
716
|
+
embedding_table_index_type: torch.dtype = torch.int64,
|
|
717
|
+
embedding_table_offset_type: torch.dtype = torch.int64,
|
|
718
|
+
embedding_shard_info: Optional[list[tuple[int, int, int, int]]] = None,
|
|
719
|
+
enable_raw_embedding_streaming: bool = False,
|
|
720
|
+
res_params: Optional[RESParams] = None,
|
|
721
|
+
) -> None:
|
|
722
|
+
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
|
|
723
|
+
self.uuid = str(uuid.uuid4())
|
|
724
|
+
self.log("SplitTableBatchedEmbeddingBagsCodegen API: V2")
|
|
725
|
+
self.log(f"SplitTableBatchedEmbeddingBagsCodegen Arguments: {locals()}")
|
|
726
|
+
self.log(
|
|
727
|
+
f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
|
|
728
|
+
)
|
|
729
|
+
|
|
730
|
+
self.table_names: Optional[list[str]] = table_names
|
|
731
|
+
self.logging_table_name: str = self.get_table_name_for_logging(table_names)
|
|
732
|
+
self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
|
|
733
|
+
self.pooling_mode = pooling_mode
|
|
734
|
+
self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
|
|
735
|
+
|
|
736
|
+
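# Use bounds check version 2 if the BOUNDS_CHECK_INDICES_V2 feature gate is enabled (else the platform default); the FBGEMM_TBE_BOUNDS_CHECK_MODE environment variable, if set, overrides bounds_check_mode.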
|
|
737
|
+
self.bounds_check_version: int = (
|
|
738
|
+
2
|
|
739
|
+
if self._feature_is_enabled(FeatureGateName.BOUNDS_CHECK_INDICES_V2)
|
|
740
|
+
else get_bounds_check_version_for_platform()
|
|
741
|
+
)
|
|
742
|
+
self.bounds_check_mode_int: int = int(
|
|
743
|
+
os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
|
|
744
|
+
)
|
|
745
|
+
# V2_* bounds check modes force bounds check version 2 and are mapped to their non-V2 equivalents
|
|
746
|
+
bounds_check_mode = BoundsCheckMode(self.bounds_check_mode_int)
|
|
747
|
+
if bounds_check_mode.name.startswith("V2_"):
|
|
748
|
+
self.bounds_check_version = 2
|
|
749
|
+
if bounds_check_mode == BoundsCheckMode.V2_IGNORE:
|
|
750
|
+
bounds_check_mode = BoundsCheckMode.IGNORE
|
|
751
|
+
elif bounds_check_mode == BoundsCheckMode.V2_WARNING:
|
|
752
|
+
bounds_check_mode = BoundsCheckMode.WARNING
|
|
753
|
+
elif bounds_check_mode == BoundsCheckMode.V2_FATAL:
|
|
754
|
+
bounds_check_mode = BoundsCheckMode.FATAL
|
|
755
|
+
|
|
756
|
+
if bounds_check_mode not in (
|
|
757
|
+
BoundsCheckMode.IGNORE,
|
|
758
|
+
BoundsCheckMode.WARNING,
|
|
759
|
+
BoundsCheckMode.FATAL,
|
|
760
|
+
BoundsCheckMode.NONE,
|
|
761
|
+
):
|
|
762
|
+
raise NotImplementedError(
|
|
763
|
+
f"SplitTableBatchedEmbeddingBagsCodegen bounds_check_mode={bounds_check_mode} is not supported"
|
|
764
|
+
)
|
|
765
|
+
|
|
766
|
+
self.bounds_check_mode_int = bounds_check_mode.value
|
|
767
|
+
|
|
768
|
+
self.log(
|
|
769
|
+
f"SplitTableBatchedEmbeddingBagsCodegen bounds_check_mode={bounds_check_mode} bounds_check_version={self.bounds_check_version}"
|
|
770
|
+
)
|
|
771
|
+
|
|
772
|
+
self.weights_precision = weights_precision
|
|
773
|
+
|
|
774
|
+
if torch.cuda.is_available() and torch.version.hip:
|
|
775
|
+
# NOTE: It was discovered that FP16 cache precision caused a 500x
|
|
776
|
+
# slowdown in performance of split_embedding_nobag_backward_codegen_rowwise_adagrad_unweighted_kernel_warp_per_row_1
|
|
777
|
+
# kernel on ROCm, so to work around this, we fix cache precision to
|
|
778
|
+
# be FP32 always for the ROCm environment case.
|
|
779
|
+
#
|
|
780
|
+
# See:
|
|
781
|
+
# https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/
|
|
782
|
+
cache_precision = SparseType.FP32
|
|
783
|
+
self.log("Override cache_precision=SparseType.FP32 on ROCm")
|
|
784
|
+
else:
|
|
785
|
+
# NOTE: The changes from D65865527 are retained here until we can
|
|
786
|
+
# test that the hack also works for non-ROCm environments.
|
|
787
|
+
cache_precision = (
|
|
788
|
+
weights_precision if cache_precision is None else cache_precision
|
|
789
|
+
)
|
|
790
|
+
|
|
791
|
+
self.output_dtype: int = output_dtype.as_int()
|
|
792
|
+
assert (
|
|
793
|
+
not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU
|
|
794
|
+
), "Only LRU cache policy supports prefetch_pipeline."
|
|
795
|
+
self.prefetch_pipeline: bool = prefetch_pipeline
|
|
796
|
+
self.lock_cache_line: bool = self.prefetch_pipeline
|
|
797
|
+
self.use_uniq_cache_locations_bwd: bool = self.prefetch_pipeline
|
|
798
|
+
self.multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = (
|
|
799
|
+
multipass_prefetch_config
|
|
800
|
+
)
|
|
801
|
+
|
|
802
|
+
if record_cache_metrics is not None:
|
|
803
|
+
self.record_cache_metrics = record_cache_metrics
|
|
804
|
+
else:
|
|
805
|
+
self.record_cache_metrics = RecordCacheMetrics(False, False)
|
|
806
|
+
|
|
807
|
+
if multipass_prefetch_config:
|
|
808
|
+
assert (
|
|
809
|
+
prefetch_pipeline
|
|
810
|
+
), "Multipass prefetch makes no sense in non-prefetch mode."
|
|
811
|
+
assert (
|
|
812
|
+
cache_algorithm == CacheAlgorithm.LRU
|
|
813
|
+
), "Multipass prefetch is only supported in LRU cache."
|
|
814
|
+
assert (
|
|
815
|
+
multipass_prefetch_config.num_passes > 0
|
|
816
|
+
), f"num_passes must be positive, get {multipass_prefetch_config.num_passes}"
|
|
817
|
+
assert (
|
|
818
|
+
multipass_prefetch_config.min_splitable_pass_size > 0
|
|
819
|
+
), f"min_splitable_pass_size must be positive, get {multipass_prefetch_config.min_splitable_pass_size}"
|
|
820
|
+
assert (
|
|
821
|
+
not self.record_cache_metrics.record_cache_miss_counter
|
|
822
|
+
and not self.record_cache_metrics.record_tablewise_cache_miss
|
|
823
|
+
), "Unique cache miss counters are not accurate in multipass prefetch and therefore not supported"
|
|
824
|
+
|
|
825
|
+
self.embedding_specs = embedding_specs
|
|
826
|
+
(rows, dims, locations, compute_devices) = zip(*embedding_specs)
|
|
827
|
+
T_ = len(self.embedding_specs)
|
|
828
|
+
self.dims: list[int] = dims
|
|
829
|
+
assert T_ > 0
|
|
830
|
+
# mixed D is not supported by no bag kernels
|
|
831
|
+
mixed_D = False
|
|
832
|
+
D = self.dims[0]
|
|
833
|
+
for d in self.dims:
|
|
834
|
+
if d != D:
|
|
835
|
+
mixed_D = True
|
|
836
|
+
break
|
|
837
|
+
if mixed_D:
|
|
838
|
+
assert (
|
|
839
|
+
self.pooling_mode != PoolingMode.NONE
|
|
840
|
+
), "Mixed dimension tables only supported for pooling tables."
|
|
841
|
+
self.mixed_D: bool = mixed_D
|
|
842
|
+
assert all(
|
|
843
|
+
cd == compute_devices[0] for cd in compute_devices
|
|
844
|
+
), "Heterogenous compute_devices are NOT supported!"
|
|
845
|
+
# Split TBE has different function schemas for CUDA and CPU.
|
|
846
|
+
# For MTIA device type, it uses the CPU one.
|
|
847
|
+
self.use_cpu: bool = (
|
|
848
|
+
compute_devices[0] == ComputeDevice.CPU
|
|
849
|
+
or compute_devices[0] == ComputeDevice.MTIA
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
assert not self.use_cpu or all(
|
|
853
|
+
loc == EmbeddingLocation.HOST for loc in locations
|
|
854
|
+
), "ComputeDevice.CPU is only for EmbeddingLocation.HOST!"
|
|
855
|
+
assert self.use_cpu or all(
|
|
856
|
+
loc != EmbeddingLocation.HOST for loc in locations
|
|
857
|
+
), "EmbeddingLocation.HOST doesn't work for CUDA device!"
|
|
858
|
+
if self.use_cpu or self.pooling_mode == PoolingMode.NONE:
|
|
859
|
+
assert output_dtype in [
|
|
860
|
+
SparseType.FP32,
|
|
861
|
+
SparseType.FP16,
|
|
862
|
+
SparseType.BF16,
|
|
863
|
+
], "Fused pooled embedding quantization only supported for cuda."
|
|
864
|
+
|
|
865
|
+
if optimizer == OptimType.NONE:
|
|
866
|
+
assert all(
|
|
867
|
+
loc == EmbeddingLocation.DEVICE for loc in locations
|
|
868
|
+
), "OptimType.NONE supports only EmbeddingLocation.DEVICE"
|
|
869
|
+
assert all(
|
|
870
|
+
cd == ComputeDevice.CUDA for cd in compute_devices
|
|
871
|
+
), "OptimType.NONE supports only ComputeDevice.CUDA"
|
|
872
|
+
assert (
|
|
873
|
+
not mixed_D
|
|
874
|
+
), "OptimType.NONE does not support mixed embedding dimension"
|
|
875
|
+
|
|
876
|
+
if device is None:
|
|
877
|
+
self.current_device: torch.device = (
|
|
878
|
+
torch.device("cpu")
|
|
879
|
+
if self.use_cpu
|
|
880
|
+
else torch.device(torch.cuda.current_device())
|
|
881
|
+
)
|
|
882
|
+
elif isinstance(device, torch.device):
|
|
883
|
+
self.current_device = device
|
|
884
|
+
else:
|
|
885
|
+
self.current_device = torch.device(device)
|
|
886
|
+
|
|
887
|
+
# Add a placeholder requires_grad parameter tensor to enable autograd with int8 weights
|
|
888
|
+
self.placeholder_autograd_tensor = nn.Parameter(
|
|
889
|
+
torch.zeros(0, device=self.current_device, dtype=torch.float)
|
|
890
|
+
)
|
|
891
|
+
|
|
892
|
+
self.gather_uvm_cache_stats = gather_uvm_cache_stats
|
|
893
|
+
# Define the size of uvm cache stats as an attribute
|
|
894
|
+
# to make it work with torch jit script.
|
|
895
|
+
self.uvm_cache_stats_size = 6
|
|
896
|
+
# 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
|
|
897
|
+
# 4: N_conflict_unique_misses, 5: N_conflict_misses
|
|
898
|
+
|
|
899
|
+
# Reporter to collect runtime performance stats bottom-up. Reporter may
|
|
900
|
+
# do aggregation across TBEs and publish results per training batch.
|
|
901
|
+
# Example of stats include UVM cache hit rate, table I/O size, etc.
|
|
902
|
+
self.stats_reporter: Optional[TBEStatsReporter] = (
|
|
903
|
+
stats_reporter_config.create_reporter() if stats_reporter_config else None
|
|
904
|
+
)
|
|
905
|
+
self._uvm_tensors_log: list[str] = []
|
|
906
|
+
|
|
907
|
+
self.bwd_wait_prefetch_timer: Optional[AsyncSeriesTimer] = None
|
|
908
|
+
self.prefetch_duration_timer: Optional[AsyncSeriesTimer] = None
|
|
909
|
+
if self.stats_reporter:
|
|
910
|
+
# When stats_reporter is present, we set up async series timer to
|
|
911
|
+
# measure the GPU time per tracked event accordingly. Each of them
|
|
912
|
+
# is attached to custom callback report function to report collected
|
|
913
|
+
# duration with the corresponding event name.
|
|
914
|
+
self.bwd_wait_prefetch_timer = AsyncSeriesTimer(
|
|
915
|
+
functools.partial(
|
|
916
|
+
SplitTableBatchedEmbeddingBagsCodegen._report_duration,
|
|
917
|
+
self,
|
|
918
|
+
event_name="bwd_wait_for_prefetch",
|
|
919
|
+
)
|
|
920
|
+
)
|
|
921
|
+
|
|
922
|
+
self.prefetch_duration_timer = AsyncSeriesTimer(
|
|
923
|
+
functools.partial(
|
|
924
|
+
SplitTableBatchedEmbeddingBagsCodegen._report_duration,
|
|
925
|
+
self,
|
|
926
|
+
event_name="total_prefetch_duration",
|
|
927
|
+
)
|
|
928
|
+
)
|
|
929
|
+
|
|
930
|
+
self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET
|
|
931
|
+
|
|
932
|
+
self.feature_table_map: list[int] = (
|
|
933
|
+
feature_table_map if feature_table_map is not None else list(range(T_))
|
|
934
|
+
)
|
|
935
|
+
|
|
936
|
+
if embedding_shard_info:
|
|
937
|
+
(full_table_heights, full_table_dims, row_offset, col_offset) = zip(
|
|
938
|
+
*embedding_shard_info
|
|
939
|
+
)
|
|
940
|
+
else:
|
|
941
|
+
# Just assume the table is unsharded
|
|
942
|
+
full_table_heights = rows
|
|
943
|
+
full_table_dims = dims
|
|
944
|
+
row_offset = [0] * len(rows)
|
|
945
|
+
col_offset = [0] * len(rows)
|
|
946
|
+
self.tbe_input_multiplexer: Optional[TBEInputMultiplexer] = (
|
|
947
|
+
tbe_input_multiplexer_config.create_tbe_input_multiplexer(
|
|
948
|
+
tbe_info=TBEInfo(
|
|
949
|
+
table_names=(
|
|
950
|
+
table_names
|
|
951
|
+
if table_names
|
|
952
|
+
else [f"table-{i}" for i in range(len(embedding_specs))]
|
|
953
|
+
),
|
|
954
|
+
table_heights=rows,
|
|
955
|
+
tbe_uuid=self.uuid,
|
|
956
|
+
feature_table_map=self.feature_table_map,
|
|
957
|
+
table_dims=dims,
|
|
958
|
+
full_table_heights=full_table_heights,
|
|
959
|
+
full_table_dims=full_table_dims,
|
|
960
|
+
row_offset=row_offset,
|
|
961
|
+
col_offset=col_offset,
|
|
962
|
+
)
|
|
963
|
+
)
|
|
964
|
+
if tbe_input_multiplexer_config is not None
|
|
965
|
+
else None
|
|
966
|
+
)
|
|
967
|
+
T = len(self.feature_table_map)
|
|
968
|
+
assert T_ <= T
|
|
969
|
+
table_has_feature = [False] * T_
|
|
970
|
+
for t in self.feature_table_map:
|
|
971
|
+
table_has_feature[t] = True
|
|
972
|
+
assert all(table_has_feature), "Each table must have at least one feature!"
|
|
973
|
+
|
|
974
|
+
feature_dims = [dims[t] for t in self.feature_table_map]
|
|
975
|
+
D_offsets = [0] + list(accumulate(feature_dims))
|
|
976
|
+
self.total_D: int = D_offsets[-1]
|
|
977
|
+
self.max_D: int = max(dims)
|
|
978
|
+
cached_dims = [
|
|
979
|
+
embedding_spec[1]
|
|
980
|
+
for embedding_spec in embedding_specs
|
|
981
|
+
if embedding_spec[2] == EmbeddingLocation.MANAGED_CACHING
|
|
982
|
+
]
|
|
983
|
+
self.max_D_cache: int = max(cached_dims) if len(cached_dims) > 0 else 0
|
|
984
|
+
|
|
985
|
+
self.register_buffer(
|
|
986
|
+
"D_offsets",
|
|
987
|
+
torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
|
|
988
|
+
)
|
|
989
|
+
hash_size_cumsum = [0] + list(accumulate(rows))
|
|
990
|
+
self.total_hash_size: int = int(hash_size_cumsum[-1])
|
|
991
|
+
if self.total_hash_size == 0:
|
|
992
|
+
self.total_hash_size_bits: int = 0
|
|
993
|
+
else:
|
|
994
|
+
self.total_hash_size_bits: int = int(log2(float(self.total_hash_size)) + 1)
|
|
995
|
+
# The last element is to easily access # of rows of each table by
|
|
996
|
+
# hash_size_cumsum[t + 1] - hash_size_cumsum[t]
|
|
997
|
+
hash_size_cumsum = [hash_size_cumsum[t] for t in self.feature_table_map] + [
|
|
998
|
+
self.total_hash_size
|
|
999
|
+
]
|
|
1000
|
+
self.register_buffer(
|
|
1001
|
+
"hash_size_cumsum",
|
|
1002
|
+
torch.tensor(
|
|
1003
|
+
hash_size_cumsum, device=self.current_device, dtype=torch.int64
|
|
1004
|
+
),
|
|
1005
|
+
)
|
|
1006
|
+
|
|
1007
|
+
self.register_buffer(
|
|
1008
|
+
"rows_per_table",
|
|
1009
|
+
torch.tensor(
|
|
1010
|
+
[rows[t] for t in self.feature_table_map],
|
|
1011
|
+
device=self.current_device,
|
|
1012
|
+
dtype=torch.int64,
|
|
1013
|
+
),
|
|
1014
|
+
)
|
|
1015
|
+
self.register_buffer(
|
|
1016
|
+
"bounds_check_warning",
|
|
1017
|
+
torch.tensor([0], device=self.current_device, dtype=torch.int64),
|
|
1018
|
+
)
|
|
1019
|
+
# Required for VBE
|
|
1020
|
+
self.register_buffer(
|
|
1021
|
+
"feature_dims",
|
|
1022
|
+
torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
|
|
1023
|
+
)
|
|
1024
|
+
(_info_B_num_bits, _info_B_mask) = torch.ops.fbgemm.get_infos_metadata(
|
|
1025
|
+
self.D_offsets, # unused tensor
|
|
1026
|
+
1, # max_B
|
|
1027
|
+
T, # T
|
|
1028
|
+
)
|
|
1029
|
+
self.info_B_num_bits: int = _info_B_num_bits
|
|
1030
|
+
self.info_B_mask: int = _info_B_mask
|
|
1031
|
+
|
|
1032
|
+
# A flag for indicating whether all embedding tables are placed in the
|
|
1033
|
+
# same locations
|
|
1034
|
+
self.use_homogeneous_placements: bool = all(
|
|
1035
|
+
loc == locations[0] for loc in locations
|
|
1036
|
+
)
|
|
1037
|
+
|
|
1038
|
+
self.uvm_host_mapped = uvm_host_mapped
|
|
1039
|
+
|
|
1040
|
+
weight_split = construct_split_state(
|
|
1041
|
+
embedding_specs,
|
|
1042
|
+
rowwise=False,
|
|
1043
|
+
cacheable=True,
|
|
1044
|
+
precision=weights_precision,
|
|
1045
|
+
)
|
|
1046
|
+
table_embedding_dtype = weights_precision.as_dtype()
|
|
1047
|
+
|
|
1048
|
+
self._apply_split(
|
|
1049
|
+
weight_split,
|
|
1050
|
+
prefix="weights",
|
|
1051
|
+
# pyre-fixme[6]: For 3rd param expected `Type[Type[_dtype]]` but got
|
|
1052
|
+
# `Type[_dtype]`.
|
|
1053
|
+
dtype=table_embedding_dtype,
|
|
1054
|
+
enforce_hbm=enforce_hbm,
|
|
1055
|
+
make_dev_param=optimizer == OptimType.NONE,
|
|
1056
|
+
dev_reshape=(-1, self.max_D) if optimizer == OptimType.NONE else None,
|
|
1057
|
+
uvm_host_mapped=self.uvm_host_mapped,
|
|
1058
|
+
)
|
|
1059
|
+
|
|
1060
|
+
assert optimizer not in (
|
|
1061
|
+
OptimType.SGD,
|
|
1062
|
+
OptimType.ROWWISE_ADAGRAD,
|
|
1063
|
+
), f"Optimizer {optimizer} is deprecated in the CPU + GPU modes."
|
|
1064
|
+
|
|
1065
|
+
if self.use_cpu:
|
|
1066
|
+
# Construct optimizer states
|
|
1067
|
+
assert optimizer in (
|
|
1068
|
+
OptimType.EXACT_ADAGRAD,
|
|
1069
|
+
OptimType.EXACT_ROWWISE_ADAGRAD,
|
|
1070
|
+
OptimType.EXACT_SGD,
|
|
1071
|
+
OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
|
|
1072
|
+
), f"Optimizer {optimizer} is not supported in CPU mode."
|
|
1073
|
+
else:
|
|
1074
|
+
assert optimizer in (
|
|
1075
|
+
OptimType.ADAM,
|
|
1076
|
+
OptimType.EXACT_ADAGRAD,
|
|
1077
|
+
OptimType.EXACT_ROWWISE_ADAGRAD,
|
|
1078
|
+
OptimType.EXACT_SGD,
|
|
1079
|
+
OptimType.LAMB,
|
|
1080
|
+
OptimType.LARS_SGD,
|
|
1081
|
+
OptimType.PARTIAL_ROWWISE_ADAM,
|
|
1082
|
+
OptimType.PARTIAL_ROWWISE_LAMB,
|
|
1083
|
+
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
|
|
1084
|
+
OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
|
|
1085
|
+
OptimType.NONE,
|
|
1086
|
+
), f"Optimizer {optimizer} is not supported."
|
|
1087
|
+
|
|
1088
|
+
self.stochastic_rounding = stochastic_rounding
|
|
1089
|
+
self.optimizer = optimizer
|
|
1090
|
+
|
|
1091
|
+
self.weight_decay_mode = weight_decay_mode
|
|
1092
|
+
if (weight_decay_mode == WeightDecayMode.COUNTER) != (
|
|
1093
|
+
counter_based_regularization is not None
|
|
1094
|
+
):
|
|
1095
|
+
raise AssertionError(
|
|
1096
|
+
"Need to set weight_decay_mode=WeightDecayMode.COUNTER together with valid counter_based_regularization"
|
|
1097
|
+
)
|
|
1098
|
+
if (weight_decay_mode == WeightDecayMode.COWCLIP) != (
|
|
1099
|
+
cowclip_regularization is not None
|
|
1100
|
+
):
|
|
1101
|
+
raise AssertionError(
|
|
1102
|
+
"Need to set weight_decay_mode=WeightDecayMode.COWCLIP together with valid cowclip_regularization"
|
|
1103
|
+
)
|
|
1104
|
+
|
|
1105
|
+
self._used_rowwise_adagrad_with_counter: bool = (
|
|
1106
|
+
optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
|
|
1107
|
+
and (
|
|
1108
|
+
weight_decay_mode in (WeightDecayMode.COUNTER, WeightDecayMode.COWCLIP)
|
|
1109
|
+
)
|
|
1110
|
+
)
|
|
1111
|
+
|
|
1112
|
+
if weight_decay_mode == WeightDecayMode.DECOUPLE_GLOBAL and (
|
|
1113
|
+
not optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
|
|
1114
|
+
or global_weight_decay is None
|
|
1115
|
+
):
|
|
1116
|
+
raise AssertionError(
|
|
1117
|
+
"""weight_decay_mode=WeightDecayMode.DECOUPLE_GLOBAL is supported for
|
|
1118
|
+
optimizer=OptimType.EXACT_ROWWISE_ADAGRAD and global_weight_decay cannot be None.
|
|
1119
|
+
"""
|
|
1120
|
+
)
|
|
1121
|
+
|
|
1122
|
+
self._used_rowwise_adagrad_with_global_weight_decay: bool = (
|
|
1123
|
+
optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
|
|
1124
|
+
and (weight_decay_mode == WeightDecayMode.DECOUPLE_GLOBAL)
|
|
1125
|
+
)
|
|
1126
|
+
self.log(
|
|
1127
|
+
f"Using global weight decay = {self._used_rowwise_adagrad_with_global_weight_decay}"
|
|
1128
|
+
)
|
|
1129
|
+
# Declare GWD params here to avoid torch.jit.script error
|
|
1130
|
+
if global_weight_decay is None:
|
|
1131
|
+
global_weight_decay = GlobalWeightDecayDefinition()
|
|
1132
|
+
|
|
1133
|
+
self.gwd_start_iter: int = global_weight_decay.start_iter
|
|
1134
|
+
self.gwd_lower_bound: float = global_weight_decay.lower_bound
|
|
1135
|
+
|
|
1136
|
+
if ensemble_mode is None:
|
|
1137
|
+
ensemble_mode = EnsembleModeDefinition()
|
|
1138
|
+
self._ensemble_mode: dict[str, float] = {
|
|
1139
|
+
key: float(fval) for key, fval in ensemble_mode.__dict__.items()
|
|
1140
|
+
}
|
|
1141
|
+
|
|
1142
|
+
if emainplace_mode is None:
|
|
1143
|
+
emainplace_mode = EmainplaceModeDefinition()
|
|
1144
|
+
self._emainplace_mode: dict[str, float] = {
|
|
1145
|
+
key: float(fval) for key, fval in emainplace_mode.__dict__.items()
|
|
1146
|
+
}
|
|
1147
|
+
|
|
1148
|
+
if counter_based_regularization is None:
|
|
1149
|
+
counter_based_regularization = CounterBasedRegularizationDefinition()
|
|
1150
|
+
if cowclip_regularization is None:
|
|
1151
|
+
cowclip_regularization = CowClipDefinition()
|
|
1152
|
+
self._max_counter_update_freq: int = -1
|
|
1153
|
+
# Extract parameters from CounterBasedRegularizationDefinition or CowClipDefinition
|
|
1154
|
+
# which are passed as entries for OptimizerArgs
|
|
1155
|
+
if self._used_rowwise_adagrad_with_counter:
|
|
1156
|
+
if self.weight_decay_mode == WeightDecayMode.COUNTER:
|
|
1157
|
+
self._max_counter_update_freq = (
|
|
1158
|
+
counter_based_regularization.max_counter_update_freq
|
|
1159
|
+
)
|
|
1160
|
+
opt_arg_weight_decay_mode = (
|
|
1161
|
+
counter_based_regularization.counter_weight_decay_mode
|
|
1162
|
+
)
|
|
1163
|
+
counter_halflife = counter_based_regularization.counter_halflife
|
|
1164
|
+
else:
|
|
1165
|
+
opt_arg_weight_decay_mode = (
|
|
1166
|
+
cowclip_regularization.counter_weight_decay_mode
|
|
1167
|
+
)
|
|
1168
|
+
counter_halflife = cowclip_regularization.counter_halflife
|
|
1169
|
+
else:
|
|
1170
|
+
opt_arg_weight_decay_mode = weight_decay_mode
|
|
1171
|
+
# Default: -1, no decay applied, as a placeholder for OptimizerArgs
|
|
1172
|
+
# which should not be effective when CounterBasedRegularizationDefinition
|
|
1173
|
+
# and CowClipDefinition are not used
|
|
1174
|
+
counter_halflife = -1
|
|
1175
|
+
|
|
1176
|
+
if extra_optimizer_config is None:
|
|
1177
|
+
extra_optimizer_config = UserEnabledConfigDefinition()
|
|
1178
|
+
self.use_rowwise_bias_correction: bool = (
|
|
1179
|
+
extra_optimizer_config.use_rowwise_bias_correction
|
|
1180
|
+
)
|
|
1181
|
+
self.use_writeback_bwd_prehook: bool = (
|
|
1182
|
+
extra_optimizer_config.use_writeback_bwd_prehook
|
|
1183
|
+
)
|
|
1184
|
+
self.log(f"self.extra_optimizer_config is {extra_optimizer_config}")
|
|
1185
|
+
if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM:
|
|
1186
|
+
raise AssertionError(
|
|
1187
|
+
"`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
|
|
1188
|
+
)
|
|
1189
|
+
if self.use_writeback_bwd_prehook and not self.optimizer == OptimType.EXACT_SGD:
|
|
1190
|
+
raise AssertionError(
|
|
1191
|
+
"`use_writeback_bwd_prehook` is only supported for OptimType.EXACT_SGD",
|
|
1192
|
+
)
|
|
1193
|
+
|
|
1194
|
+
self.learning_rate_tensor: torch.Tensor = torch.tensor(
|
|
1195
|
+
learning_rate, device=torch.device("cpu"), dtype=torch.float32
|
|
1196
|
+
)
|
|
1197
|
+
|
|
1198
|
+
self.optimizer_args = invokers.lookup_args.OptimizerArgs(
|
|
1199
|
+
stochastic_rounding=stochastic_rounding,
|
|
1200
|
+
gradient_clipping=gradient_clipping,
|
|
1201
|
+
max_gradient=max_gradient,
|
|
1202
|
+
max_norm=max_norm,
|
|
1203
|
+
eps=eps,
|
|
1204
|
+
beta1=beta1,
|
|
1205
|
+
beta2=beta2,
|
|
1206
|
+
weight_decay=weight_decay,
|
|
1207
|
+
weight_decay_mode=opt_arg_weight_decay_mode.value,
|
|
1208
|
+
eta=eta,
|
|
1209
|
+
momentum=momentum,
|
|
1210
|
+
counter_halflife=counter_halflife,
|
|
1211
|
+
adjustment_iter=counter_based_regularization.adjustment_iter,
|
|
1212
|
+
adjustment_ub=counter_based_regularization.adjustment_ub,
|
|
1213
|
+
learning_rate_mode=counter_based_regularization.learning_rate_mode.value,
|
|
1214
|
+
grad_sum_decay=counter_based_regularization.grad_sum_decay.value,
|
|
1215
|
+
tail_id_threshold=counter_based_regularization.tail_id_threshold.val,
|
|
1216
|
+
is_tail_id_thresh_ratio=int(
|
|
1217
|
+
counter_based_regularization.tail_id_threshold.is_ratio
|
|
1218
|
+
),
|
|
1219
|
+
total_hash_size=self.total_hash_size,
|
|
1220
|
+
weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
|
|
1221
|
+
lower_bound=cowclip_regularization.lower_bound,
|
|
1222
|
+
regularization_mode=weight_decay_mode.value,
|
|
1223
|
+
use_rowwise_bias_correction=self.use_rowwise_bias_correction,
|
|
1224
|
+
)
|
|
1225
|
+
|
|
1226
|
+
if optimizer != OptimType.NONE:
|
|
1227
|
+
assert (
|
|
1228
|
+
optimizer
|
|
1229
|
+
in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ENSEMBLE_ROWWISE_ADAGRAD)
|
|
1230
|
+
or optimizer_state_dtypes is None
|
|
1231
|
+
), "optimizer_state_dtypes option is only supported for OptimType.PARTIAL_ROWWISE_ADAM and OptimType.ENSEMBLE_ROWWISE_ADAGRAD"
|
|
1232
|
+
if optimizer in (OptimType.EXACT_SGD,):
|
|
1233
|
+
# NOTE: make TorchScript work!
|
|
1234
|
+
self._register_nonpersistent_buffers("momentum1")
|
|
1235
|
+
else:
|
|
1236
|
+
momentum1_dtype = (
|
|
1237
|
+
torch.float32
|
|
1238
|
+
if (
|
|
1239
|
+
optimizer_state_dtypes is None
|
|
1240
|
+
or "momentum1" not in optimizer_state_dtypes
|
|
1241
|
+
or optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD
|
|
1242
|
+
)
|
|
1243
|
+
else optimizer_state_dtypes["momentum1"].as_dtype()
|
|
1244
|
+
)
|
|
1245
|
+
rowwise = optimizer in [
|
|
1246
|
+
OptimType.EXACT_ROWWISE_ADAGRAD,
|
|
1247
|
+
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
|
|
1248
|
+
OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
|
|
1249
|
+
]
|
|
1250
|
+
self._apply_split(
|
|
1251
|
+
construct_split_state(
|
|
1252
|
+
embedding_specs,
|
|
1253
|
+
rowwise=rowwise,
|
|
1254
|
+
cacheable=False,
|
|
1255
|
+
placement=(
|
|
1256
|
+
EmbeddingLocation.MANAGED
|
|
1257
|
+
if ((not rowwise) and uvm_non_rowwise_momentum)
|
|
1258
|
+
else None
|
|
1259
|
+
),
|
|
1260
|
+
),
|
|
1261
|
+
prefix="momentum1",
|
|
1262
|
+
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
|
|
1263
|
+
# but got `Type[torch.float32]`.
|
|
1264
|
+
dtype=momentum1_dtype,
|
|
1265
|
+
enforce_hbm=enforce_hbm,
|
|
1266
|
+
uvm_host_mapped=self.uvm_host_mapped,
|
|
1267
|
+
)
|
|
1268
|
+
if optimizer in (
|
|
1269
|
+
OptimType.ADAM,
|
|
1270
|
+
OptimType.PARTIAL_ROWWISE_ADAM,
|
|
1271
|
+
OptimType.LAMB,
|
|
1272
|
+
OptimType.PARTIAL_ROWWISE_LAMB,
|
|
1273
|
+
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
|
|
1274
|
+
):
|
|
1275
|
+
rowwise = optimizer in (
|
|
1276
|
+
OptimType.PARTIAL_ROWWISE_ADAM,
|
|
1277
|
+
OptimType.PARTIAL_ROWWISE_LAMB,
|
|
1278
|
+
)
|
|
1279
|
+
momentum2_dtype = (
|
|
1280
|
+
torch.float32
|
|
1281
|
+
if (
|
|
1282
|
+
optimizer_state_dtypes is None
|
|
1283
|
+
or "momentum2" not in optimizer_state_dtypes
|
|
1284
|
+
)
|
|
1285
|
+
else optimizer_state_dtypes["momentum2"].as_dtype()
|
|
1286
|
+
)
|
|
1287
|
+
self._apply_split(
|
|
1288
|
+
construct_split_state(
|
|
1289
|
+
embedding_specs,
|
|
1290
|
+
rowwise=rowwise,
|
|
1291
|
+
cacheable=False,
|
|
1292
|
+
placement=(
|
|
1293
|
+
EmbeddingLocation.MANAGED
|
|
1294
|
+
if ((not rowwise) and uvm_non_rowwise_momentum)
|
|
1295
|
+
else None
|
|
1296
|
+
),
|
|
1297
|
+
),
|
|
1298
|
+
prefix="momentum2",
|
|
1299
|
+
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
|
|
1300
|
+
# but got `Type[torch.float32]`.
|
|
1301
|
+
dtype=momentum2_dtype,
|
|
1302
|
+
uvm_host_mapped=self.uvm_host_mapped,
|
|
1303
|
+
)
|
|
1304
|
+
else:
|
|
1305
|
+
# NOTE: make TorchScript work!
|
|
1306
|
+
self._register_nonpersistent_buffers("momentum2")
|
|
1307
|
+
if self._used_rowwise_adagrad_with_counter:
|
|
1308
|
+
self._apply_split(
|
|
1309
|
+
construct_split_state(
|
|
1310
|
+
embedding_specs,
|
|
1311
|
+
rowwise=True,
|
|
1312
|
+
cacheable=False,
|
|
1313
|
+
),
|
|
1314
|
+
prefix="prev_iter",
|
|
1315
|
+
# TODO: ideally we should use int64 to track iter but it failed to compile.
|
|
1316
|
+
# It may be related to low precision training code. Currently using float32
|
|
1317
|
+
# as a workaround while investigating the issue.
|
|
1318
|
+
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
|
|
1319
|
+
# but got `Type[torch.float32]`.
|
|
1320
|
+
dtype=torch.float32,
|
|
1321
|
+
uvm_host_mapped=self.uvm_host_mapped,
|
|
1322
|
+
)
|
|
1323
|
+
self._apply_split(
|
|
1324
|
+
construct_split_state(
|
|
1325
|
+
embedding_specs,
|
|
1326
|
+
rowwise=True,
|
|
1327
|
+
cacheable=False,
|
|
1328
|
+
),
|
|
1329
|
+
prefix="row_counter",
|
|
1330
|
+
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
|
|
1331
|
+
# but got `Type[torch.float32]`.
|
|
1332
|
+
dtype=torch.float32,
|
|
1333
|
+
uvm_host_mapped=self.uvm_host_mapped,
|
|
1334
|
+
)
|
|
1335
|
+
self.register_buffer(
|
|
1336
|
+
"max_counter", torch.tensor([1], dtype=torch.float32)
|
|
1337
|
+
)
|
|
1338
|
+
elif self._used_rowwise_adagrad_with_global_weight_decay:
|
|
1339
|
+
self._apply_split(
|
|
1340
|
+
construct_split_state(
|
|
1341
|
+
embedding_specs,
|
|
1342
|
+
rowwise=True,
|
|
1343
|
+
cacheable=False,
|
|
1344
|
+
),
|
|
1345
|
+
prefix="prev_iter",
|
|
1346
|
+
# TODO: ideally we should use int64 to track iter but it failed to compile.
|
|
1347
|
+
# It may be related to low precision training code. Currently using float32
|
|
1348
|
+
# as a workaround while investigating the issue.
|
|
1349
|
+
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
|
|
1350
|
+
# but got `Type[torch.float32]`.
|
|
1351
|
+
dtype=torch.float32,
|
|
1352
|
+
uvm_host_mapped=self.uvm_host_mapped,
|
|
1353
|
+
)
|
|
1354
|
+
self._register_nonpersistent_buffers("row_counter")
|
|
1355
|
+
self.register_buffer(
|
|
1356
|
+
"max_counter",
|
|
1357
|
+
torch.ones(1, dtype=torch.float32, device=self.current_device),
|
|
1358
|
+
persistent=False,
|
|
1359
|
+
)
|
|
1360
|
+
elif optimizer == OptimType.ADAM and self.use_rowwise_bias_correction:
|
|
1361
|
+
self._apply_split(
|
|
1362
|
+
construct_split_state(
|
|
1363
|
+
embedding_specs,
|
|
1364
|
+
rowwise=True,
|
|
1365
|
+
cacheable=False,
|
|
1366
|
+
),
|
|
1367
|
+
prefix="row_counter",
|
|
1368
|
+
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
|
|
1369
|
+
# but got `Type[torch.float32]`.
|
|
1370
|
+
dtype=torch.float32,
|
|
1371
|
+
uvm_host_mapped=self.uvm_host_mapped,
|
|
1372
|
+
)
|
|
1373
|
+
else:
|
|
1374
|
+
self._register_nonpersistent_buffers("prev_iter")
|
|
1375
|
+
self._register_nonpersistent_buffers("row_counter")
|
|
1376
|
+
self.register_buffer(
|
|
1377
|
+
"max_counter",
|
|
1378
|
+
torch.ones(1, dtype=torch.float32, device=self.current_device),
|
|
1379
|
+
persistent=False,
|
|
1380
|
+
)
|
|
1381
|
+
if (
|
|
1382
|
+
optimizer
|
|
1383
|
+
in (
|
|
1384
|
+
OptimType.ADAM,
|
|
1385
|
+
OptimType.LAMB,
|
|
1386
|
+
OptimType.PARTIAL_ROWWISE_ADAM,
|
|
1387
|
+
OptimType.PARTIAL_ROWWISE_LAMB,
|
|
1388
|
+
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
|
|
1389
|
+
OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
|
|
1390
|
+
)
|
|
1391
|
+
or self._used_rowwise_adagrad_with_global_weight_decay
|
|
1392
|
+
):
|
|
1393
|
+
self.register_buffer(
|
|
1394
|
+
"iter",
|
|
1395
|
+
torch.zeros(1, dtype=torch.int64, device=self.current_device),
|
|
1396
|
+
)
|
|
1397
|
+
else:
|
|
1398
|
+
self.register_buffer(
|
|
1399
|
+
"iter",
|
|
1400
|
+
torch.zeros(1, dtype=torch.int64, device=self.current_device),
|
|
1401
|
+
persistent=False,
|
|
1402
|
+
)
|
|
1403
|
+
|
|
1404
|
+
self.iter_cpu: torch.Tensor = torch.zeros(1, dtype=torch.int64, device="cpu")
|
|
1405
|
+
|
|
1406
|
+
cache_state = construct_cache_state(rows, locations, self.feature_table_map)
|
|
1407
|
+
|
|
1408
|
+
# Add table-wise cache miss counter
|
|
1409
|
+
if self.record_cache_metrics.record_tablewise_cache_miss:
|
|
1410
|
+
num_tables = len(cache_state.cache_hash_size_cumsum) - 1
|
|
1411
|
+
self.register_buffer(
|
|
1412
|
+
"table_wise_cache_miss",
|
|
1413
|
+
torch.zeros(
|
|
1414
|
+
num_tables,
|
|
1415
|
+
device=self.current_device,
|
|
1416
|
+
dtype=torch.int64,
|
|
1417
|
+
),
|
|
1418
|
+
)
|
|
1419
|
+
# NOTE: make TorchScript work!
|
|
1420
|
+
else:
|
|
1421
|
+
self.register_buffer(
|
|
1422
|
+
"table_wise_cache_miss",
|
|
1423
|
+
torch.zeros(
|
|
1424
|
+
0,
|
|
1425
|
+
device=self.current_device,
|
|
1426
|
+
dtype=torch.int64,
|
|
1427
|
+
),
|
|
1428
|
+
)
|
|
1429
|
+
|
|
1430
|
+
self._apply_cache_state(
|
|
1431
|
+
cache_state,
|
|
1432
|
+
cache_algorithm,
|
|
1433
|
+
cache_load_factor,
|
|
1434
|
+
cache_sets,
|
|
1435
|
+
cache_reserved_memory,
|
|
1436
|
+
cache_precision,
|
|
1437
|
+
)
|
|
1438
|
+
|
|
1439
|
+
self.log(f"Contents: {table_names}")
|
|
1440
|
+
self.log(
|
|
1441
|
+
f"Using fused {optimizer} with optimizer_args={self.optimizer_args if optimizer != OptimType.NONE else None}"
|
|
1442
|
+
)
|
|
1443
|
+
self.log(
|
|
1444
|
+
f"Using rowwise_adagrad_with_counter={self._used_rowwise_adagrad_with_counter}"
|
|
1445
|
+
)
|
|
1446
|
+
|
|
1447
|
+
self.step = 0
|
|
1448
|
+
self.last_reported_step = 0
|
|
1449
|
+
self.last_reported_uvm_stats: list[float] = []
|
|
1450
|
+
# Track number of times detailed memory breakdown has been reported
|
|
1451
|
+
self.detailed_mem_breakdown_report_count = 0
|
|
1452
|
+
# Set max number of reports for detailed memory breakdown
|
|
1453
|
+
self.max_detailed_mem_breakdown_reports = 10
|
|
1454
|
+
|
|
1455
|
+
# Check whether to use TBE v2
|
|
1456
|
+
is_experimental = False
|
|
1457
|
+
if use_experimental_tbe:
|
|
1458
|
+
is_experimental = True
|
|
1459
|
+
self.log("use_experimental_tbe is set to True; Using experimental TBE")
|
|
1460
|
+
|
|
1461
|
+
elif int(os.environ.get("FBGEMM_EXPERIMENTAL_TBE", "0")) == 1:
|
|
1462
|
+
# Keep the old feature enablement mechanism to ensure no negative impact on models that have already adopted TBE v2
|
|
1463
|
+
is_experimental = True
|
|
1464
|
+
self.log("FBGEMM_EXPERIMENTAL_TBE is set to True; Using experimental TBE")
|
|
1465
|
+
|
|
1466
|
+
# NOTE: Keep this disabled for now until the backend lands into Pyper
|
|
1467
|
+
# elif FeatureGateName.TBE_V2.is_enabled():
|
|
1468
|
+
# is_experimental = True
|
|
1469
|
+
# self.log("TBE_V2 Knob is set to True; Using experimental TBE")
|
|
1470
|
+
|
|
1471
|
+
self.is_experimental: bool = is_experimental
|
|
1472
|
+
|
|
1473
|
+
# Get a debug function pointer
|
|
1474
|
+
self._debug_print_input_stats: Callable[..., None] = (
|
|
1475
|
+
self._debug_print_input_stats_factory()
|
|
1476
|
+
)
|
|
1477
|
+
|
|
1478
|
+
# Get a reporter function pointer
|
|
1479
|
+
self._report_input_params: Callable[..., None] = (
|
|
1480
|
+
self.__report_input_params_factory()
|
|
1481
|
+
)
|
|
1482
|
+
|
|
1483
|
+
if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
|
|
1484
|
+
# Register writeback hook for Exact_SGD optimizer
|
|
1485
|
+
self.log(
|
|
1486
|
+
"SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled."
|
|
1487
|
+
)
|
|
1488
|
+
# pyre-fixme[6]: Expected `typing.Callable[[Module, Union[Tensor, typing.Tuple[Tensor, ...]]], Union[None, Tensor, typing.Tuple[Tensor, ...]]]`
|
|
1489
|
+
self.register_full_backward_pre_hook(self.writeback_hook)
|
|
1490
|
+
|
|
1491
|
+
if embedding_table_index_type not in [torch.int32, torch.int64]:
|
|
1492
|
+
raise ValueError(
|
|
1493
|
+
f"embedding_table_index_type must be torch.int32 or torch.int64, but got {embedding_table_index_type}"
|
|
1494
|
+
)
|
|
1495
|
+
self.embedding_table_index_type: torch.dtype = embedding_table_index_type
|
|
1496
|
+
if embedding_table_offset_type not in [torch.int32, torch.int64]:
|
|
1497
|
+
raise ValueError(
|
|
1498
|
+
f"embedding_table_offset_type must be torch.int32 or torch.int64, but got {embedding_table_offset_type}"
|
|
1499
|
+
)
|
|
1500
|
+
self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type
|
|
1501
|
+
|
|
1502
|
+
self.prefetched_info_list: list[PrefetchedInfo] = torch.jit.annotate(
|
|
1503
|
+
list[PrefetchedInfo], []
|
|
1504
|
+
)
|
|
1505
|
+
if self.enable_raw_embedding_streaming:
|
|
1506
|
+
self.res_params: RESParams = res_params or RESParams()
|
|
1507
|
+
self.res_params.table_sizes = [0] + list(accumulate(rows))
|
|
1508
|
+
res_port_from_env = os.getenv("LOCAL_RES_PORT")
|
|
1509
|
+
self.res_params.res_server_port = (
|
|
1510
|
+
int(res_port_from_env) if res_port_from_env else 0
|
|
1511
|
+
)
|
|
1512
|
+
# pyre-fixme[4]: Attribute must be annotated.
|
|
1513
|
+
self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
|
|
1514
|
+
self.uuid,
|
|
1515
|
+
self.enable_raw_embedding_streaming,
|
|
1516
|
+
self.res_params.res_store_shards,
|
|
1517
|
+
self.res_params.res_server_port,
|
|
1518
|
+
self.res_params.table_names,
|
|
1519
|
+
self.res_params.table_offsets,
|
|
1520
|
+
self.res_params.table_sizes,
|
|
1521
|
+
)
|
|
1522
|
+
logging.info(
|
|
1523
|
+
f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
|
|
1524
|
+
)
|
|
1525
|
+
|
|
1526
|
+
@torch.jit.ignore
def log(self, msg: str) -> None:
    """
    Emit ``msg`` through ``logging.info``, prefixed with this TBE's ``uuid``
    so that output from several TBE instances in one process can be told
    apart.

    Args:
        msg (str): The message to log

    Returns:
        None
    """
    prefixed_msg = f"[TBE={self.uuid}] {msg}"
    logging.info(prefixed_msg)
|
|
1539
|
+
|
|
1540
|
+
def _register_nonpersistent_buffers(self, prefix: str) -> None:
|
|
1541
|
+
# NOTE: make TorchScript work!
|
|
1542
|
+
self.register_buffer(
|
|
1543
|
+
f"{prefix}_dev",
|
|
1544
|
+
torch.zeros(1, dtype=torch.int64, device=self.current_device),
|
|
1545
|
+
persistent=False,
|
|
1546
|
+
)
|
|
1547
|
+
self.register_buffer(
|
|
1548
|
+
f"{prefix}_host",
|
|
1549
|
+
torch.zeros(1, dtype=torch.int64, device=self.current_device),
|
|
1550
|
+
persistent=False,
|
|
1551
|
+
)
|
|
1552
|
+
self.register_buffer(
|
|
1553
|
+
f"{prefix}_uvm",
|
|
1554
|
+
torch.zeros(1, dtype=torch.int64, device=self.current_device),
|
|
1555
|
+
persistent=False,
|
|
1556
|
+
)
|
|
1557
|
+
self.register_buffer(
|
|
1558
|
+
f"{prefix}_placements",
|
|
1559
|
+
torch.zeros(1, dtype=torch.int64, device=self.current_device),
|
|
1560
|
+
persistent=False,
|
|
1561
|
+
)
|
|
1562
|
+
self.register_buffer(
|
|
1563
|
+
f"{prefix}_offsets",
|
|
1564
|
+
torch.zeros(1, dtype=torch.int64, device=self.current_device),
|
|
1565
|
+
persistent=False,
|
|
1566
|
+
)
|
|
1567
|
+
|
|
1568
|
+
@staticmethod
def get_table_name_for_logging(table_names: Optional[list[str]]) -> str:
    """
    Build a short string that identifies the tables of a TBE in log output.

    Duplicate names are collapsed first. A single distinct name is returned
    as-is. Otherwise the number of distinct tables is reported together with
    their sorted names.

    Args:
        table_names (Optional[List[str]]): The table names in the TBE, or
            ``None`` if unknown

    Returns:
        ``"<Unknown>"`` when ``table_names`` is ``None``; otherwise the
        single table name, or ``"<N tables>: [...]"``
    """
    if table_names is None:
        return "<Unknown>"
    # Several shards of one table may appear in the same TBE, so dedupe
    # (and sort for a stable representation) before counting.
    distinct_names = sorted(set(table_names))
    if len(distinct_names) != 1:
        return f"<{len(distinct_names)} tables>: {distinct_names}"
    return distinct_names[0]
|
|
1589
|
+
|
|
1590
|
+
@staticmethod
def get_prefetch_passes(
    multipass_prefetch_config: Optional[MultiPassPrefetchConfig],
    input_tensor: Tensor,
    output_tensor: Tensor,
) -> list[tuple[Tensor, Tensor, int]]:
    """
    Split the prefetch input/output tensors into per-pass chunks along dim 0.

    Returns ``(input[start:end], output[start:end], start)`` tuples whose
    slices cover ``input_tensor`` completely and without overlap. The caller
    must guarantee that dim 0 is non-empty.

    A single tuple holding the unsplit tensors (with start 0) is returned
    when ``multipass_prefetch_config`` is ``None``, when it requests exactly
    one pass, or when there are no more rows than passes.

    Args:
        multipass_prefetch_config (Optional[MultiPassPrefetchConfig]): The
            multi-pass prefetch config; ``None`` disables multi-pass prefetch
        input_tensor (Tensor): The tensor partitioned along dim 0
        output_tensor (Tensor): The tensor partitioned alongside
            ``input_tensor`` using the same pass size

    Returns:
        List[Tuple[Tensor, Tensor, int]]: The chunks and their start offsets
    """
    if multipass_prefetch_config is None:
        return [(input_tensor, output_tensor, 0)]
    config: MultiPassPrefetchConfig = multipass_prefetch_config

    num_rows = input_tensor.size(0)
    if num_rows <= config.num_passes or config.num_passes == 1:
        # At most one row per pass (or a single pass): do not split.
        return [(input_tensor, output_tensor, 0)]

    # Ceil-divide the rows across the passes, but never go below the
    # configured minimum pass size.
    rows_per_pass: int = max(
        (num_rows + config.num_passes - 1) // config.num_passes,
        config.min_splitable_pass_size,
    )

    return list(
        zip(
            torch.split(input_tensor, rows_per_pass),
            torch.split(output_tensor, rows_per_pass),
            range(0, num_rows, rows_per_pass),
        )
    )
|
|
1642
|
+
|
|
1643
|
+
def get_states(self, prefix: str) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """
    Return the storage of the state named ``prefix``.

    Args:
        prefix (str): The state prefix (e.g. ``"weights"``, ``"momentum1"``)

    Returns:
        A 5-tuple of

        (1) the GPU (``{prefix}_dev``) state tensor

        (2) the CPU (``{prefix}_host``) state tensor

        (3) the UVM (``{prefix}_uvm``) state tensor

        (4) a ``torch.int32`` tensor built from
            ``{prefix}_physical_placements`` holding the placement of each
            embedding table (0 = DEVICE, 1 = MANAGED, 2 = MANAGED_CACHING,
            3 = HOST, 4 = MTIA)

        (5) a ``torch.int64`` tensor built from ``{prefix}_physical_offsets``
            holding the relative position of each embedding table inside
            its state tensor

    Raises:
        DoesNotHavePrefix: If ``{prefix}_physical_placements`` is not an
            attribute of this module.
    """
    if not hasattr(self, f"{prefix}_physical_placements"):
        raise DoesNotHavePrefix()
    dev_param, host_param, uvm_param, placements, offsets = (
        getattr(self, f"{prefix}_{suffix}")
        for suffix in (
            "dev",
            "host",
            "uvm",
            "physical_placements",
            "physical_offsets",
        )
    )
    placement_tensor = torch.tensor(placements, dtype=torch.int32)
    offset_tensor = torch.tensor(offsets, dtype=torch.int64)
    return (dev_param, host_param, uvm_param, placement_tensor, offset_tensor)
|
|
1681
|
+
|
|
1682
|
+
def get_all_states(self) -> list[tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
    """
    Collect every state this TBE has, out of ``weights``, ``momentum1``,
    ``momentum2``, ``prev_iter`` and ``row_counter`` (in that order).

    Prefixes for which ``get_states`` raises ``DoesNotHavePrefix`` are
    silently left out.

    Returns:
        A list of ``get_states`` results, each a tuple of (GPU state tensor,
        CPU state tensor, UVM state tensor, placement tensor, offset tensor)
    """
    collected = []
    for state_name in ("weights", "momentum1", "momentum2", "prev_iter", "row_counter"):
        try:
            state = self.get_states(state_name)
        except DoesNotHavePrefix:
            continue
        collected.append(state)
    return collected
|
|
1699
|
+
|
|
1700
|
+
@torch.jit.export
def get_cache_miss_counter(self) -> Tensor:
    """
    Return the ``cache_miss_counter`` buffer. It holds two items:

    (1) ``cache_miss_forward_count``: the total number of forwards that had
        at least one cache miss

    (2) ``unique_cache_miss_count``: the total number of unique (deduped)
        cache misses

    Returns:
        The cache miss counter tensor
    """
    counter = self.cache_miss_counter
    # pyre-fixme[7]: Expected `Tensor` but got `Union[Module, Tensor]`.
    return counter
|
|
1716
|
+
|
|
1717
|
+
@torch.jit.export
def get_table_wise_cache_miss(self) -> Tensor:
    """
    Return the ``table_wise_cache_miss`` buffer, which holds one cache miss
    count per table in this embedding object.

    Returns:
        The table-wise cache miss tensor
    """
    per_table_misses = self.table_wise_cache_miss
    return per_table_misses
|
|
1727
|
+
|
|
1728
|
+
# The callback function for AsyncTimer to record duration to different event
|
|
1729
|
+
def _report_duration(
|
|
1730
|
+
self,
|
|
1731
|
+
it_step: int,
|
|
1732
|
+
dur_ms: float,
|
|
1733
|
+
event_name: str,
|
|
1734
|
+
) -> None:
|
|
1735
|
+
assert (
|
|
1736
|
+
self.stats_reporter
|
|
1737
|
+
), "We should not be here. AsyncTimer only happens with reporter present."
|
|
1738
|
+
self.stats_reporter.report_duration(
|
|
1739
|
+
iteration_step=it_step,
|
|
1740
|
+
event_name=event_name,
|
|
1741
|
+
duration_ms=dur_ms,
|
|
1742
|
+
embedding_id=self.logging_table_name,
|
|
1743
|
+
tbe_id=self.uuid,
|
|
1744
|
+
)
|
|
1745
|
+
|
|
1746
|
+
def _get_tensor_memory(self, tensor_name: str) -> int:
|
|
1747
|
+
"""Get memory usage of a tensor in bytes."""
|
|
1748
|
+
if not hasattr(self, tensor_name):
|
|
1749
|
+
self.log(f"Tensor '{tensor_name}' not found, using 0 bytes")
|
|
1750
|
+
return 0
|
|
1751
|
+
tensor = getattr(self, tensor_name)
|
|
1752
|
+
return tensor.numel() * tensor.element_size()
|
|
1753
|
+
|
|
1754
|
+
def _categorize_memory_by_location(
|
|
1755
|
+
self, tensor_names: list[str]
|
|
1756
|
+
) -> tuple[int, int]:
|
|
1757
|
+
"""Categorize memory into HBM and UVM for given tensors.
|
|
1758
|
+
|
|
1759
|
+
Returns:
|
|
1760
|
+
(hbm_bytes, uvm_bytes)
|
|
1761
|
+
"""
|
|
1762
|
+
uvm_set = set(self._uvm_tensors_log)
|
|
1763
|
+
hbm_bytes = 0
|
|
1764
|
+
uvm_bytes = 0
|
|
1765
|
+
|
|
1766
|
+
for name in tensor_names:
|
|
1767
|
+
size = self._get_tensor_memory(name)
|
|
1768
|
+
if name in uvm_set:
|
|
1769
|
+
uvm_bytes += size
|
|
1770
|
+
else:
|
|
1771
|
+
hbm_bytes += size
|
|
1772
|
+
|
|
1773
|
+
return hbm_bytes, uvm_bytes
|
|
1774
|
+
|
|
1775
|
+
def _report_hbm_breakdown(
|
|
1776
|
+
self,
|
|
1777
|
+
stats_reporter: TBEStatsReporter,
|
|
1778
|
+
embeddings: int,
|
|
1779
|
+
optimizer_states: int,
|
|
1780
|
+
cache: int,
|
|
1781
|
+
total_static_sparse: int,
|
|
1782
|
+
ephemeral: int,
|
|
1783
|
+
) -> None:
|
|
1784
|
+
"""Report HBM memory breakdown to stats reporter."""
|
|
1785
|
+
stats_reporter.report_data_amount(
|
|
1786
|
+
iteration_step=self.step,
|
|
1787
|
+
event_name="tbe.hbm.embeddings",
|
|
1788
|
+
data_bytes=embeddings,
|
|
1789
|
+
embedding_id=self.logging_table_name,
|
|
1790
|
+
tbe_id=self.uuid,
|
|
1791
|
+
)
|
|
1792
|
+
stats_reporter.report_data_amount(
|
|
1793
|
+
iteration_step=self.step,
|
|
1794
|
+
event_name="tbe.hbm.optimizer_states",
|
|
1795
|
+
data_bytes=optimizer_states,
|
|
1796
|
+
embedding_id=self.logging_table_name,
|
|
1797
|
+
tbe_id=self.uuid,
|
|
1798
|
+
)
|
|
1799
|
+
stats_reporter.report_data_amount(
|
|
1800
|
+
iteration_step=self.step,
|
|
1801
|
+
event_name="tbe.hbm.cache",
|
|
1802
|
+
data_bytes=cache,
|
|
1803
|
+
embedding_id=self.logging_table_name,
|
|
1804
|
+
tbe_id=self.uuid,
|
|
1805
|
+
)
|
|
1806
|
+
stats_reporter.report_data_amount(
|
|
1807
|
+
iteration_step=self.step,
|
|
1808
|
+
event_name="tbe.hbm.total_static_sparse",
|
|
1809
|
+
data_bytes=total_static_sparse,
|
|
1810
|
+
embedding_id=self.logging_table_name,
|
|
1811
|
+
tbe_id=self.uuid,
|
|
1812
|
+
)
|
|
1813
|
+
stats_reporter.report_data_amount(
|
|
1814
|
+
iteration_step=self.step,
|
|
1815
|
+
event_name="tbe.hbm.ephemeral",
|
|
1816
|
+
data_bytes=ephemeral,
|
|
1817
|
+
embedding_id=self.logging_table_name,
|
|
1818
|
+
tbe_id=self.uuid,
|
|
1819
|
+
)
|
|
1820
|
+
|
|
1821
|
+
def _report_uvm_breakdown(
|
|
1822
|
+
self,
|
|
1823
|
+
stats_reporter: TBEStatsReporter,
|
|
1824
|
+
embeddings: int,
|
|
1825
|
+
optimizer_states: int,
|
|
1826
|
+
cache: int,
|
|
1827
|
+
total_static_sparse: int,
|
|
1828
|
+
ephemeral: int,
|
|
1829
|
+
) -> None:
|
|
1830
|
+
"""Report UVM memory breakdown to stats reporter."""
|
|
1831
|
+
stats_reporter.report_data_amount(
|
|
1832
|
+
iteration_step=self.step,
|
|
1833
|
+
event_name="tbe.uvm.embeddings",
|
|
1834
|
+
data_bytes=embeddings,
|
|
1835
|
+
embedding_id=self.logging_table_name,
|
|
1836
|
+
tbe_id=self.uuid,
|
|
1837
|
+
)
|
|
1838
|
+
stats_reporter.report_data_amount(
|
|
1839
|
+
iteration_step=self.step,
|
|
1840
|
+
event_name="tbe.uvm.optimizer_states",
|
|
1841
|
+
data_bytes=optimizer_states,
|
|
1842
|
+
embedding_id=self.logging_table_name,
|
|
1843
|
+
tbe_id=self.uuid,
|
|
1844
|
+
)
|
|
1845
|
+
stats_reporter.report_data_amount(
|
|
1846
|
+
iteration_step=self.step,
|
|
1847
|
+
event_name="tbe.uvm.cache",
|
|
1848
|
+
data_bytes=cache,
|
|
1849
|
+
embedding_id=self.logging_table_name,
|
|
1850
|
+
tbe_id=self.uuid,
|
|
1851
|
+
)
|
|
1852
|
+
stats_reporter.report_data_amount(
|
|
1853
|
+
iteration_step=self.step,
|
|
1854
|
+
event_name="tbe.uvm.total_static_sparse",
|
|
1855
|
+
data_bytes=total_static_sparse,
|
|
1856
|
+
embedding_id=self.logging_table_name,
|
|
1857
|
+
tbe_id=self.uuid,
|
|
1858
|
+
)
|
|
1859
|
+
stats_reporter.report_data_amount(
|
|
1860
|
+
iteration_step=self.step,
|
|
1861
|
+
event_name="tbe.uvm.ephemeral",
|
|
1862
|
+
data_bytes=ephemeral,
|
|
1863
|
+
embedding_id=self.logging_table_name,
|
|
1864
|
+
tbe_id=self.uuid,
|
|
1865
|
+
)
|
|
1866
|
+
|
|
1867
|
+
@torch.jit.ignore
|
|
1868
|
+
def _report_tbe_mem_usage(self) -> None:
|
|
1869
|
+
if self.stats_reporter is None:
|
|
1870
|
+
return
|
|
1871
|
+
|
|
1872
|
+
stats_reporter: TBEStatsReporter = self.stats_reporter
|
|
1873
|
+
if not stats_reporter.should_report(self.step):
|
|
1874
|
+
return
|
|
1875
|
+
|
|
1876
|
+
# Calculate total memory from all parameters and buffers (always needed)
|
|
1877
|
+
total_mem_usage = sum(
|
|
1878
|
+
p.numel() * p.element_size() for p in self.parameters()
|
|
1879
|
+
) + sum(b.numel() * b.element_size() for b in self.buffers())
|
|
1880
|
+
|
|
1881
|
+
# Calculate total HBM and UVM usage (always needed)
|
|
1882
|
+
if self.use_cpu:
|
|
1883
|
+
total_hbm_usage = 0
|
|
1884
|
+
total_uvm_usage = total_mem_usage
|
|
1885
|
+
else:
|
|
1886
|
+
total_uvm_usage = sum(
|
|
1887
|
+
self._get_tensor_memory(name)
|
|
1888
|
+
for name in self._uvm_tensors_log
|
|
1889
|
+
if hasattr(self, name)
|
|
1890
|
+
)
|
|
1891
|
+
total_hbm_usage = total_mem_usage - total_uvm_usage
|
|
1892
|
+
|
|
1893
|
+
# Report total memory usage metrics (always reported for backward compatibility)
|
|
1894
|
+
stats_reporter.report_data_amount(
|
|
1895
|
+
iteration_step=self.step,
|
|
1896
|
+
event_name="tbe.total_hbm_usage",
|
|
1897
|
+
data_bytes=total_hbm_usage,
|
|
1898
|
+
embedding_id=self.logging_table_name,
|
|
1899
|
+
tbe_id=self.uuid,
|
|
1900
|
+
)
|
|
1901
|
+
stats_reporter.report_data_amount(
|
|
1902
|
+
iteration_step=self.step,
|
|
1903
|
+
event_name="tbe.total_uvm_usage",
|
|
1904
|
+
data_bytes=total_uvm_usage,
|
|
1905
|
+
embedding_id=self.logging_table_name,
|
|
1906
|
+
tbe_id=self.uuid,
|
|
1907
|
+
)
|
|
1908
|
+
|
|
1909
|
+
# Only report detailed breakdown for the first max_detailed_mem_breakdown_reports reportable
|
|
1910
|
+
# steps since static sparse memory (weights, optimizer states, cache) is constant
|
|
1911
|
+
if (
|
|
1912
|
+
self.detailed_mem_breakdown_report_count
|
|
1913
|
+
>= self.max_detailed_mem_breakdown_reports
|
|
1914
|
+
):
|
|
1915
|
+
return
|
|
1916
|
+
self.detailed_mem_breakdown_report_count += 1
|
|
1917
|
+
|
|
1918
|
+
# Tensor groups for sparse memory categorization
|
|
1919
|
+
weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]
|
|
1920
|
+
optimizer_tensors = [
|
|
1921
|
+
"momentum1_dev",
|
|
1922
|
+
"momentum1_host",
|
|
1923
|
+
"momentum1_uvm",
|
|
1924
|
+
"momentum2_dev",
|
|
1925
|
+
"momentum2_host",
|
|
1926
|
+
"momentum2_uvm",
|
|
1927
|
+
]
|
|
1928
|
+
cache_tensors = [
|
|
1929
|
+
"lxu_cache_weights",
|
|
1930
|
+
"lxu_cache_state",
|
|
1931
|
+
"lxu_state",
|
|
1932
|
+
"cache_hash_size_cumsum",
|
|
1933
|
+
"cache_index_table_map",
|
|
1934
|
+
"cache_miss_counter",
|
|
1935
|
+
"lxu_cache_locking_counter",
|
|
1936
|
+
]
|
|
1937
|
+
|
|
1938
|
+
# Calculate total memory for each component
|
|
1939
|
+
weights_total = sum(self._get_tensor_memory(t) for t in weight_tensors)
|
|
1940
|
+
optimizer_total = sum(self._get_tensor_memory(t) for t in optimizer_tensors)
|
|
1941
|
+
cache_total = sum(self._get_tensor_memory(t) for t in cache_tensors)
|
|
1942
|
+
|
|
1943
|
+
# Categorize memory by location (HBM vs UVM)
|
|
1944
|
+
if self.use_cpu:
|
|
1945
|
+
weights_hbm, weights_uvm = 0, weights_total
|
|
1946
|
+
opt_hbm, opt_uvm = 0, optimizer_total
|
|
1947
|
+
cache_hbm, cache_uvm = 0, cache_total
|
|
1948
|
+
else:
|
|
1949
|
+
weights_hbm, weights_uvm = self._categorize_memory_by_location(
|
|
1950
|
+
weight_tensors
|
|
1951
|
+
)
|
|
1952
|
+
opt_hbm, opt_uvm = self._categorize_memory_by_location(optimizer_tensors)
|
|
1953
|
+
cache_hbm, cache_uvm = self._categorize_memory_by_location(cache_tensors)
|
|
1954
|
+
|
|
1955
|
+
# Calculate ephemeral memory split between HBM and UVM
|
|
1956
|
+
static_sparse_hbm = weights_hbm + opt_hbm + cache_hbm
|
|
1957
|
+
static_sparse_uvm = weights_uvm + opt_uvm + cache_uvm
|
|
1958
|
+
ephemeral_hbm = total_hbm_usage - static_sparse_hbm
|
|
1959
|
+
ephemeral_uvm = total_uvm_usage - static_sparse_uvm
|
|
1960
|
+
|
|
1961
|
+
# Report granular memory breakdowns
|
|
1962
|
+
self._report_hbm_breakdown(
|
|
1963
|
+
stats_reporter,
|
|
1964
|
+
weights_hbm,
|
|
1965
|
+
opt_hbm,
|
|
1966
|
+
cache_hbm,
|
|
1967
|
+
static_sparse_hbm,
|
|
1968
|
+
ephemeral_hbm,
|
|
1969
|
+
)
|
|
1970
|
+
self._report_uvm_breakdown(
|
|
1971
|
+
stats_reporter,
|
|
1972
|
+
weights_uvm,
|
|
1973
|
+
opt_uvm,
|
|
1974
|
+
cache_uvm,
|
|
1975
|
+
static_sparse_uvm,
|
|
1976
|
+
ephemeral_uvm,
|
|
1977
|
+
)
|
|
1978
|
+
|
|
1979
|
+
@torch.jit.ignore
|
|
1980
|
+
def _report_io_size_count(self, event: str, data: Tensor) -> Tensor:
|
|
1981
|
+
if self.stats_reporter is None:
|
|
1982
|
+
return data
|
|
1983
|
+
stats_reporter: TBEStatsReporter = self.stats_reporter
|
|
1984
|
+
if stats_reporter.should_report(self.step):
|
|
1985
|
+
stats_reporter.report_data_amount(
|
|
1986
|
+
iteration_step=self.step,
|
|
1987
|
+
event_name=f"tbe.{event}_size",
|
|
1988
|
+
data_bytes=data.element_size() * data.numel(),
|
|
1989
|
+
embedding_id=self.logging_table_name,
|
|
1990
|
+
tbe_id=self.uuid,
|
|
1991
|
+
)
|
|
1992
|
+
stats_reporter.report_data_amount(
|
|
1993
|
+
iteration_step=self.step,
|
|
1994
|
+
event_name=f"tbe.{event}_count",
|
|
1995
|
+
data_bytes=data.numel(),
|
|
1996
|
+
embedding_id=self.logging_table_name,
|
|
1997
|
+
tbe_id=self.uuid,
|
|
1998
|
+
)
|
|
1999
|
+
return data
|
|
2000
|
+
|
|
2001
|
+
@torch.jit.ignore
def _generate_vbe_metadata(
    self,
    offsets: Tensor,
    batch_size_per_feature_per_rank: Optional[list[list[int]]],
) -> invokers.lookup_args.VBEMetadata:
    """
    Build the variable-batch-size embedding (VBE) metadata for a forward call.

    Moves ``self.feature_dims`` to CPU, checks that the optimizer supports
    VBE when VBE is requested, and delegates to ``generate_vbe_metadata``.

    Args:
        offsets (Tensor): The bag offsets passed to ``forward``
        batch_size_per_feature_per_rank (Optional[List[List[int]]]): Batch
            sizes per feature per rank; ``None`` when VBE is not used

    Returns:
        The ``VBEMetadata`` produced by ``generate_vbe_metadata``

    Raises:
        AssertionError: If VBE is requested with an optimizer outside the
            supported set.
    """
    # Blocking D2H copy, but only runs at first call
    self.feature_dims = self.feature_dims.cpu()
    if batch_size_per_feature_per_rank is not None:
        # Keep the message in sync with the tuple of supported optimizers.
        assert self.optimizer in (
            OptimType.EXACT_ROWWISE_ADAGRAD,
            OptimType.EXACT_SGD,
            OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
            OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
            OptimType.NONE,
            OptimType.ADAM,
        ), (
            "Variable batch size TBE support is enabled for "
            "OptimType.EXACT_ROWWISE_ADAGRAD, EXACT_SGD, "
            "ENSEMBLE_ROWWISE_ADAGRAD, EMAINPLACE_ROWWISE_ADAGRAD, NONE, "
            "and ADAM only"
        )
    return generate_vbe_metadata(
        offsets,
        batch_size_per_feature_per_rank,
        self.pooling_mode,
        self.feature_dims,
        self.current_device,
    )
|
|
2029
|
+
|
|
2030
|
+
@torch.jit.ignore
def _feature_is_enabled(self, feature: FeatureGateName) -> bool:
    """
    Return ``FeatureGate.is_enabled(feature)``.

    This thin proxy exists so that the check can be marked with
    ``@torch.jit.ignore``, which lets models that use this class compile
    correctly.
    """
    enabled = FeatureGate.is_enabled(feature)
    return enabled
|
|
2035
|
+
|
|
2036
|
+
def writeback_update_gradient(
    self, indices: torch.Tensor, offsets: torch.Tensor, grad: Tensor
) -> Tensor:
    """
    Keep a single gradient row per unique (table, index) pair.

    Used by the writeback backward pre-hook. Only ``grad[0]`` is used. The
    result has the same shape as ``grad[0]``. Among the bags that look up
    the same index in the same table, only the first such bag keeps its
    gradient row; every other row is zeroed. Lookups of the same index in
    different tables are not deduplicated against each other.

    NOTE(review): assumes every non-empty bag holds exactly one index, so
    that ``indices`` lines up one-to-one with the non-empty bags, and that
    ``offsets`` is laid out table-major with ``B * T + 1`` entries -- TODO
    confirm against callers.

    Args:
        indices (Tensor): 1D tensor of looked-up indices
        offsets (Tensor): 1D tensor of bag offsets
        grad (Tensor): Tuple of gradients handed to the hook; ``grad[0]``
            has one row per bag

    Returns:
        A tensor shaped like ``grad[0]`` with duplicate rows zeroed
    """
    if indices.numel() == 0:
        return grad[0]
    num_of_tables = len(set(self.feature_table_map))
    max_index = int(indices.max())
    # The largest key encoded below is
    # (num_of_tables - 1) * (max_index + 1) + max_index.
    assert num_of_tables * (max_index + 1) - 1 <= torch.iinfo(indices.dtype).max
    batch_size = offsets.shape[0] // num_of_tables
    non_empty_index = (offsets[1:] - offsets[:-1]).nonzero().flatten()
    # Disable dedup across different tables by encoding each index together
    # with the table of the bag it came from. Bag ``i`` belongs to table
    # ``i // batch_size``. An offset *value* is a position in ``indices`` and
    # only equals the bag number when no bag is empty, so it must not be
    # used to derive the table.
    table_ids = non_empty_index // batch_size
    keys = table_ids * (max_index + 1) + indices
    grad = grad[0]
    _, idx, counts = torch.unique(
        keys, dim=0, sorted=True, return_inverse=True, return_counts=True
    )
    # A stable sort groups positions by unique key while keeping their
    # original order, so the head of each group is its first occurrence.
    _, ind_sorted = torch.sort(idx, stable=True)
    cum_sum = counts.cumsum(0)
    cum_sum = torch.cat((torch.tensor([0]).to(indices.device), cum_sum[:-1]))
    first_indicies = ind_sorted[cum_sum]
    mask = torch.zeros_like(grad, device=grad.device)
    original_index = non_empty_index[first_indicies]

    mask[original_index] = grad[original_index]
    return mask
|
|
2063
|
+
|
|
2064
|
+
# pyre-fixme[2]: For 1st argument expected not ANY
def writeback_hook(self, module: Any, grad: Tensor) -> tuple[Tensor]:
    """
    Full backward pre-hook that replaces the incoming gradient with
    ``self.writeback_update_gradient(self._indices, self._offsets, grad)``.

    Args:
        module (Any): The module the hook is registered on (unused)
        grad (Tensor): The gradient tuple handed to the hook

    Returns:
        A one-element tuple holding the updated gradient
    """
    updated = self.writeback_update_gradient(self._indices, self._offsets, grad)
    return (updated,)
|
|
2070
|
+
|
|
2071
|
+
def forward(  # noqa: C901
    self,
    indices: Tensor,
    offsets: Tensor,
    per_sample_weights: Optional[Tensor] = None,
    feature_requires_grad: Optional[Tensor] = None,
    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
    total_unique_indices: Optional[int] = None,
    hash_zch_identities: Optional[Tensor] = None,
    hash_zch_runtime_meta: Optional[Tensor] = None,
) -> Tensor:
    """
    The forward pass function that

    (1) Performs input bound checking

    (2) Generates necessary variable batch size embedding (VBE) metadata (if
        VBE is used)

    (3) Prefetches data from UVM to cache (if
        `EmbeddingLocation.MANAGED_CACHING` is used and the user has not
        explicitly prefetched data)

    (4) Performs the embedding table lookup by invoking a corresponding
        Autograd function (based on the chosen optimizer)

    Side effects: `self.step` is incremented on every call. For every
    optimizer other than `OptimType.NONE`, `EXACT_SGD`, `LARS_SGD` and
    `EXACT_ADAGRAD`, the iteration counters `self.iter_cpu` and `self.iter`
    are incremented as well. When not compiling with dynamo, the prepared
    `indices` / `offsets` and VBE metadata are cached on the module.

    Args:
        indices (Tensor): A 1D-tensor that contains indices to be looked up
            from all embedding tables

        offsets (Tensor): A 1D-tensor that contains offsets of indices.
            Shape `(B * T + 1)` where `B` = batch size and `T` = the number
            of features. `offsets[t * B + b + 1] - offsets[t * B + b]` is
            the length of bag `b` of feature `t`

        per_sample_weights (Optional[Tensor]): An optional 1D-float-tensor that
            contains per sample weights. If None, **unweighted** embedding
            lookup will be performed. Otherwise, **weighted** lookup will be
            used. The length of this tensor must be the same as the length
            of the `indices` tensor. The value of `per_sample_weights[i]`
            will be used to multiply with every element in the looked up row
            `indices[i]`, where `0 <= i < len(per_sample_weights)`.

        feature_requires_grad (Optional[Tensor]): An optional 1D-tensor for
            indicating if `per_sample_weights` requires gradient. The
            length of the tensor must be equal to the number of features

        batch_size_per_feature_per_rank (Optional[List[List[int]]]): An
            optional 2D list (list of lists) that contains batch sizes for
            every feature and every rank. If None, TBE assumes that **every
            feature has the same batch size** and computes the batch size
            from the `offsets` shape. Otherwise, TBE assumes that different
            features can have different batch sizes and uses the **variable
            batch size embedding look up mode (VBE)**. Shape (number of
            features, number of ranks). `batch_size_per_feature_per_rank[f][r]`
            represents the batch size of feature `f` and rank `r`

        total_unique_indices (Optional[int]): An optional integer that
            represents the total number of unique indices. This value must
            be set when using `OptimType.NONE`. This is because TBE
            requires this information for allocating the weight gradient
            tensor in the backward pass.

        hash_zch_identities (Optional[Tensor]): The original raw IDs before
            remapping to ZCH (Zero-Collision Hash) table slots. This tensor is
            populated when using Multi-Probe Zero Collision Hash (MPZCH) modules
            and is required for Raw Embedding Streaming (RES) to maintain
            consistency between training and inference. It is only forwarded
            to `_prefetch` when this call performs the prefetch itself, i.e.
            when `self.timesteps_prefetched` is empty.

        hash_zch_runtime_meta (Optional[Tensor]): Optional runtime metadata
            forwarded to `_prefetch` together with `hash_zch_identities`,
            under the same condition. NOTE(review): its format is defined by
            the ZCH / RES consumers; confirm before relying on it.

    Returns:
        A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` =
        batch size and `total_D` = the sum of all embedding dimensions in the
        table

    Raises:
        AssertionError: If `OptimType.NONE` is used and `total_unique_indices`
            is missing or larger than `indices.numel()`. Also raised if
            `OptimType.ENSEMBLE_ROWWISE_ADAGRAD` is used while its feature
            gate is disabled.
        ValueError: If the configured optimizer is not supported.

    Example:

        >>> import torch
        >>>
        >>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
        ...     EmbeddingLocation,
        ... )
        >>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
        ...     SplitTableBatchedEmbeddingBagsCodegen,
        ...     ComputeDevice,
        ... )
        >>>
        >>> # Two tables
        >>> embedding_specs = [
        ...     (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
        ...     (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA)
        ... ]
        >>>
        >>> tbe = SplitTableBatchedEmbeddingBagsCodegen(embedding_specs)
        >>> tbe.init_embedding_weights_uniform(-1, 1)
        >>>
        >>> print(tbe.split_embedding_weights())
        [tensor([[-0.9426,  0.7046,  0.4214, -0.0419,  0.1331, -0.7856, -0.8124, -0.2021],
                [-0.5771,  0.5911, -0.7792, -0.1068, -0.6203,  0.4813, -0.1677,  0.4790],
                [-0.5587, -0.0941,  0.5754,  0.3475, -0.8952, -0.1964,  0.0810, -0.4174]],
               device='cuda:0'), tensor([[-0.2513, -0.4039, -0.3775,  0.3273],
                [-0.5399, -0.0229, -0.1455, -0.8770],
                [-0.9520,  0.4593, -0.7169,  0.6307],
                [-0.1765,  0.8757,  0.8614,  0.2051],
                [-0.0603, -0.9980, -0.7958, -0.5826]], device='cuda:0')]


        >>> # Batch size = 3
        >>> indices = torch.tensor([0, 1, 2, 0, 1, 2, 0, 3, 1, 4, 2, 0, 0],
        ...                        device="cuda",
        ...                        dtype=torch.long)
        >>> offsets = torch.tensor([0, 2, 5, 7, 9, 12, 13],
        ...                        device="cuda",
        ...                        dtype=torch.long)
        >>>
        >>> output = tbe(indices, offsets)
        >>>
        >>> # Batch size = 3, total embedding dimension = 12
        >>> print(output.shape)
        torch.Size([3, 12])

        >>> print(output)
        tensor([[-1.5197,  1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801,  0.2769,
                 -0.7164,  0.8528,  0.7159, -0.6719],
                [-2.0784,  1.2016,  0.2176,  0.1988, -1.3825, -0.5008, -0.8991, -0.1405,
                 -1.2637, -0.9427, -1.8902,  0.3754],
                [-1.5013,  0.6105,  0.9968,  0.3057, -0.7621, -0.9821, -0.7314, -0.6195,
                 -0.2513, -0.4039, -0.3775,  0.3273]], device='cuda:0',
               grad_fn=<CppNode<SplitLookupFunction_sgd_Op>>)

    """
    # `prepare_inputs` returns the (possibly transformed) indices, offsets and
    # per-sample weights, plus the VBE metadata used by the lookup below.
    (
        indices,
        offsets,
        per_sample_weights,
        vbe_metadata,
    ) = self.prepare_inputs(
        indices,
        offsets,
        per_sample_weights,
        batch_size_per_feature_per_rank,
        force_cast_input_types=True,
        prefetch_pipeline=False,
    )

    # Print input stats if enabled (for debugging purposes only)
    self._debug_print_input_stats(indices, offsets, per_sample_weights)

    # Extract and write input stats if enabled
    if self._report_input_params is not None:
        self._report_input_params(
            feature_rows=self.rows_per_table,
            feature_dims=self.feature_dims,
            iteration=self.iter_cpu.item() if hasattr(self, "iter_cpu") else 0,
            indices=indices,
            offsets=offsets,
            op_id=self.uuid,
            per_sample_weights=per_sample_weights,
            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
        )

    if not is_torchdynamo_compiling():
        # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time

        # Storing tensors for linear_cache_indices recomputation
        self._indices = indices
        self._offsets = offsets
        self._vbe_B_offsets = vbe_metadata.B_offsets
        self._vbe_max_B = vbe_metadata.max_B

    self.step += 1
    self._report_io_size_count("fwd_input", indices)
    self._report_tbe_mem_usage()

    if self.tbe_input_multiplexer is not None:
        tbe_input_multiplexer: TBEInputMultiplexer = self.tbe_input_multiplexer
        if tbe_input_multiplexer.should_run(self.step):
            tbe_input_multiplexer.run(
                tbe_input_info=TBEInputInfo(
                    indices, offsets, batch_size_per_feature_per_rank
                )
            )

    # `_prefetch` appends to `timesteps_prefetched`. An empty queue means no
    # explicit `prefetch` call preceded this forward, so prefetch inline.
    if len(self.timesteps_prefetched) == 0:
        # In forward, we don't enable multi-pass prefetch as we want the process
        # to be as fast as possible and memory usage doesn't matter (will be recycled
        # by dense fwd/bwd)
        self._prefetch(
            indices,
            offsets,
            vbe_metadata,
            multipass_prefetch_config=None,
            hash_zch_identities=hash_zch_identities,
            hash_zch_runtime_meta=hash_zch_runtime_meta,
        )

    # Consume the prefetch timestep that belongs to this forward call.
    if len(self.timesteps_prefetched) > 0:
        self.timesteps_prefetched.pop(0)

    # Take the oldest queued cache-location tensor (FIFO) if there is one;
    # otherwise fall back to the empty placeholder.
    self.lxu_cache_locations = (
        self.lxu_cache_locations_empty
        if len(self.lxu_cache_locations_list) == 0
        else self.lxu_cache_locations_list.pop(0)
    )
    common_args = invokers.lookup_args.CommonArgs(
        placeholder_autograd_tensor=self.placeholder_autograd_tensor,
        # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        dev_weights=self.weights_dev,
        # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        host_weights=self.weights_host,
        # pyre-fixme[6]: For 4th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        uvm_weights=self.weights_uvm,
        # pyre-fixme[6]: For 5th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        lxu_cache_weights=self.lxu_cache_weights,
        # pyre-fixme[6]: For 6th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        weights_placements=self.weights_placements,
        # pyre-fixme[6]: For 7th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        weights_offsets=self.weights_offsets,
        D_offsets=self.D_offsets,
        total_D=self.total_D,
        max_D=self.max_D,
        hash_size_cumsum=self.hash_size_cumsum,
        total_hash_size_bits=self.total_hash_size_bits,
        indices=indices,
        offsets=offsets,
        pooling_mode=self.pooling_mode,
        indice_weights=per_sample_weights,
        feature_requires_grad=feature_requires_grad,
        lxu_cache_locations=self.lxu_cache_locations,
        # Pass the local_uvm_cache_stats bc only that information is
        # relevant for the current iteration
        uvm_cache_stats=(
            self.local_uvm_cache_stats
            if (
                self.gather_uvm_cache_stats
                # Unique conflict misses are only collected when using CacheAlgorithm.LRU
                and self.cache_algorithm == CacheAlgorithm.LRU
            )
            else None
        ),
        output_dtype=self.output_dtype,
        vbe_metadata=vbe_metadata,
        is_experimental=self.is_experimental,
        use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
        use_homogeneous_placements=self.use_homogeneous_placements,
        learning_rate_tensor=self.learning_rate_tensor,
        info_B_num_bits=self.info_B_num_bits,
        info_B_mask=self.info_B_mask,
    )

    # Optimizer dispatch. Each branch invokes the matching lookup and reports
    # the output size. Optimizer state wrappers are built just before the
    # first branch that needs them.
    if self.optimizer == OptimType.NONE:
        assert (
            total_unique_indices is not None
            and total_unique_indices <= indices.numel()
        ), f"OptimType.NONE requires total_unique_indices. Please pass it or check the value (total_unique_indices = {total_unique_indices})"
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_none.invoke(
                common_args, self.optimizer_args, total_unique_indices
            ),
        )
    elif self.optimizer == OptimType.EXACT_SGD:
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_sgd.invoke(common_args, self.optimizer_args),
        )

    momentum1 = invokers.lookup_args.Momentum(
        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        dev=self.momentum1_dev,
        # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        host=self.momentum1_host,
        # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        uvm=self.momentum1_uvm,
        # pyre-fixme[6]: For 4th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        offsets=self.momentum1_offsets,
        # pyre-fixme[6]: For 5th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        placements=self.momentum1_placements,
    )

    if self.optimizer == OptimType.LARS_SGD:
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_lars_sgd.invoke(
                common_args, self.optimizer_args, momentum1
            ),
        )
    if self.optimizer == OptimType.EXACT_ADAGRAD:
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_adagrad.invoke(
                common_args, self.optimizer_args, momentum1
            ),
        )

    momentum2 = invokers.lookup_args.Momentum(
        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        dev=self.momentum2_dev,
        # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        host=self.momentum2_host,
        # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        uvm=self.momentum2_uvm,
        # pyre-fixme[6]: For 4th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        offsets=self.momentum2_offsets,
        # pyre-fixme[6]: For 5th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        placements=self.momentum2_placements,
    )

    # NOTE: the optimizers handled above (NONE, EXACT_SGD, LARS_SGD,
    # EXACT_ADAGRAD) have already returned, so the iteration counters below
    # are only advanced for the remaining optimizers.
    # Although self.iter_cpu is created on CPU, it might be transferred to
    # GPU by the user, so it is explicitly moved back to CPU here. This line
    # runs on every call; `.cpu()` returns the same tensor when it is already
    # on CPU.
    self.iter_cpu = self.iter_cpu.cpu()

    # Sync with loaded state
    if (
        not is_torchdynamo_compiling()
    ):  # wrap to make it compatible with PT2 compile
        if self.iter_cpu.item() == 0:
            self.iter_cpu.fill_(self.iter.cpu().item())
    # Increment the iteration counter
    iter_int = int(self.iter_cpu.add_(1).item())  # used for local computation
    self.iter.add_(1)  # used for checkpointing

    row_counter = invokers.lookup_args.Momentum(
        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        dev=self.row_counter_dev,
        # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        host=self.row_counter_host,
        # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        uvm=self.row_counter_uvm,
        # pyre-fixme[6]: For 4th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        offsets=self.row_counter_offsets,
        # pyre-fixme[6]: For 5th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        placements=self.row_counter_placements,
    )

    if self.optimizer == OptimType.ADAM:
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_adam.invoke(
                common_args,
                self.optimizer_args,
                momentum1,
                momentum2,
                iter_int,
                # Row counters are only passed when rowwise bias correction is on.
                row_counter=(
                    row_counter if self.use_rowwise_bias_correction else None
                ),
            ),
        )
    if self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_partial_rowwise_adam.invoke(
                common_args,
                self.optimizer_args,
                momentum1,
                momentum2,
                iter_int,
            ),
        )
    if self.optimizer == OptimType.LAMB:
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_lamb.invoke(
                common_args,
                self.optimizer_args,
                momentum1,
                momentum2,
                iter_int,
            ),
        )
    if self.optimizer == OptimType.PARTIAL_ROWWISE_LAMB:
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_partial_rowwise_lamb.invoke(
                common_args,
                self.optimizer_args,
                momentum1,
                momentum2,
                iter_int,
            ),
        )

    prev_iter = invokers.lookup_args.Momentum(
        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        dev=self.prev_iter_dev,
        # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        host=self.prev_iter_host,
        # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        uvm=self.prev_iter_uvm,
        # pyre-fixme[6]: For 4th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        offsets=self.prev_iter_offsets,
        # pyre-fixme[6]: For 5th argument expected `Tensor` but got
        #  `Union[Module, Tensor]`.
        placements=self.prev_iter_placements,
    )

    if self.optimizer == OptimType.EMAINPLACE_ROWWISE_ADAGRAD:
        # The in-place EMA update runs without autograd and only in training
        # mode. The lookup itself runs outside the no_grad block.
        with torch.no_grad():
            if self.training:
                self.ema_inplace(self._emainplace_mode)
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_rowwise_adagrad.invoke(
                common_args,
                self.optimizer_args,
                momentum1,
            ),
        )

    if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
        assert self._feature_is_enabled(
            FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD
        ), "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!"
        with torch.no_grad():
            if self.training:
                self.ensemble_and_swap(self._ensemble_mode)
        return self._report_io_size_count(
            "fwd_output",
            invokers.lookup_rowwise_adagrad.invoke(
                common_args,
                self.optimizer_args,
                momentum1,
            ),
        )

    # Periodically refresh `max_counter` to one more than the largest row
    # counter (or 1 when there are no row counters).
    if self._used_rowwise_adagrad_with_counter:
        if (
            self._max_counter_update_freq > 0
            and iter_int % self._max_counter_update_freq == 0
        ):
            row_counter_dev = self.row_counter_dev.detach()
            if row_counter_dev.numel() > 0:
                self.max_counter[0] = torch.max(row_counter_dev).cpu().item() + 1
            else:
                self.max_counter[0] = 1

    if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
        if self._used_rowwise_adagrad_with_counter:
            return self._report_io_size_count(
                "fwd_output",
                invokers.lookup_rowwise_adagrad_with_counter.invoke(
                    common_args,
                    self.optimizer_args,
                    momentum1,
                    prev_iter,
                    row_counter,
                    iter_int,
                    self.max_counter.item(),
                    mixed_D=self.mixed_D,
                ),
            )
        elif self._used_rowwise_adagrad_with_global_weight_decay:
            # Global weight decay only applies in training mode, from
            # `gwd_start_iter` onwards.
            apply_global_weight_decay = (
                iter_int >= self.gwd_start_iter and self.training
            )
            return self._report_io_size_count(
                "fwd_output",
                invokers.lookup_rowwise_adagrad.invoke(
                    common_args,
                    self.optimizer_args,
                    momentum1,
                    iter=iter_int,
                    apply_global_weight_decay=apply_global_weight_decay,
                    # pyre-fixme[6]: For 6th argument expected
                    #  `Optional[Tensor]` but got `Union[Module, Tensor]`.
                    prev_iter_dev=self.prev_iter_dev,
                    gwd_lower_bound=self.gwd_lower_bound,
                    mixed_D=self.mixed_D,
                ),
            )
        else:
            return self._report_io_size_count(
                "fwd_output",
                invokers.lookup_rowwise_adagrad.invoke(
                    common_args,
                    self.optimizer_args,
                    momentum1,
                    mixed_D=self.mixed_D,
                ),
            )

    raise ValueError(f"Invalid OptimType: {self.optimizer}")
|
|
2579
|
+
|
|
2580
|
+
def ema_inplace(self, emainplace_mode: dict[str, float]) -> None:
    """
    Apply an in-place exponential moving average between the two halves of
    every embedding table.

    Each table row is split into two equal halves:

        Emb_table:
        -------------------------------------------------
        -                       --                      -
        -     Fast part         --      Slow part       -
        -  (RL) main part       --     target part      -
        -                       --                      -
        -------------------------------------------------

    The update runs only when the current iteration is a multiple of
    ``step_ema`` and is at least ``step_start``. Then every table is updated
    as ``slow_part += step_ema_coef * (fast_part - slow_part)``.

    Args:
        emainplace_mode: mapping providing ``"step_ema"``, ``"step_start"``
            and ``"step_ema_coef"``.

    Raises:
        AssertionError: if a table has an odd embedding dimension.
    """
    current_iter = int(self.iter_cpu.item())
    # Guard clauses keep the original short-circuit order: "step_start" is
    # only read when the EMA period matches.
    if current_iter % int(emainplace_mode["step_ema"]) != 0:
        return
    if current_iter < int(emainplace_mode["step_start"]):
        return

    tables = self.split_embedding_weights()
    for table_idx, (_, dim, _, _) in enumerate(self.embedding_specs):
        # make sure that the dimension is even
        assert (
            dim & 1 == 0
        ), f"table dimension {dim} is odd, not supported for ema_inplace_rowwise_adagrad"
        half = dim // 2
        table = tables[table_idx]
        fast_half = table[:, :half].data
        slow_half = table[:, half:].data
        # lerp_: slow += coef * (fast - slow), written through the view.
        slow_half.lerp_(fast_half, emainplace_mode["step_ema_coef"])
|
|
2609
|
+
|
|
2610
|
+
def ensemble_and_swap(self, ensemble_mode: dict[str, float]) -> None:
    """
    Perform ensemble (EMA) and swap operations on the full sparse embedding
    tables. Weights and optimizer states are modified in place.

    For every table ``i``, the ensemble copy is ``split_optimizer_states()[i][1]``.
    Let ``it`` be the current iteration. Then:

    * EMA step (``it % step_ema == 0``): the state moves towards the weights,
      ``state = lerp(state, weights, 1 - coef)``. ``coef`` is 0 while
      ``it <= step_start`` (the state is replaced by the weights) and
      ``step_ema_coef`` afterwards.
    * Swap step (``it % step_swap == 0``): the weights are overwritten with
      the state.
    * Post-processing (EMA steps only): when ``step_mode == 0`` the state is
      multiplied by zero after the swap. Any other mode leaves it as is.

    Args:
        ensemble_mode: mapping providing ``"step_ema"`` and ``"step_swap"``.
            ``"step_start"``, ``"step_ema_coef"`` and ``"step_mode"`` are
            read only when needed.

    Returns:
        None.
    """
    iter_int = int(self.iter_cpu.item())
    should_ema = iter_int % int(ensemble_mode["step_ema"]) == 0
    should_swap = iter_int % int(ensemble_mode["step_swap"]) == 0
    if not (should_ema or should_swap):
        return

    weights = self.split_embedding_weights()
    states = self.split_optimizer_states()
    coef_ema = (
        0.0
        if iter_int <= int(ensemble_mode["step_start"])
        else ensemble_mode["step_ema_coef"]
    )
    for i in range(len(self.embedding_specs)):
        ensemble_state = states[i][1]
        # 1) EMA step. Weights are converted to the state's dtype/device
        # (e.g. copied from GPU to CPU) only when the EMA actually runs.
        # Swap-only iterations never used that copy.
        if should_ema:
            weights_converted = weights[i].to(
                dtype=ensemble_state.dtype, device=ensemble_state.device
            )
            ensemble_state.lerp_(weights_converted, 1.0 - coef_ema)
        # 2) swap step
        if should_swap:
            weights[i].copy_(ensemble_state, non_blocking=True)
        # 3) post-processing step
        if should_ema:
            if int(ensemble_mode["step_mode"]) == 0:  # embedding scaling
                ensemble_state.mul_(0.0)
            # elif int(ensemble_mode["step_mode"]) == 2: pure ema
|
|
2644
|
+
|
|
2645
|
+
def reset_uvm_cache_stats(self) -> None:
    """
    Zero both the cumulative and the local (per-prefetch) UVM cache stats.

    Raises:
        AssertionError: if ``gather_uvm_cache_stats`` is disabled.
    """
    assert self.gather_uvm_cache_stats, (
        "gather_uvm_cache_stats should be set to true to access uvm cache stats."
    )
    # Cumulative counters first, then the ones that `_prefetch` also resets
    # before each prefetch.
    self.uvm_cache_stats.zero_()
    self.local_uvm_cache_stats.zero_()
|
|
2651
|
+
|
|
2652
|
+
def get_uvm_cache_stats(self, use_local_cache: bool = False) -> Tensor:
    """
    Return the UVM cache stats tensor.

    Args:
        use_local_cache: if True, return ``local_uvm_cache_stats`` (reset at
            the start of every prefetch). Otherwise return the cumulative
            ``uvm_cache_stats``.

    Raises:
        AssertionError: if ``gather_uvm_cache_stats`` is disabled.
    """
    assert self.gather_uvm_cache_stats, (
        "gather_uvm_cache_stats should be set to true to access uvm cache stats."
    )
    if use_local_cache:
        return self.local_uvm_cache_stats
    return self.uvm_cache_stats
|
|
2657
|
+
|
|
2658
|
+
def _get_uvm_cache_print_state(self, use_local_cache: bool = False) -> list[float]:
    """
    Return UVM cache stats as a plain list, ready for printing.

    Local stats are returned as they are. Cumulative stats are returned as the
    difference from the previous call. In that case the stored baseline
    ``last_uvm_cache_print_state`` is advanced to a copy of the current values.
    """
    current = self.get_uvm_cache_stats(use_local_cache)
    if not use_local_cache:
        # Stats accumulate over many steps: report only the change since the
        # last print and remember the current snapshot as the new baseline.
        delta = current - self.last_uvm_cache_print_state
        self.last_uvm_cache_print_state = current.clone()
        return delta.tolist()
    return current.tolist()
|
|
2667
|
+
|
|
2668
|
+
@torch.jit.ignore
def print_uvm_cache_stats(self, use_local_cache: bool = False) -> None:
    """
    Log a summary of the UVM cache stats through ``self.log``.

    Every counter is divided by the number of calls, clamped to at least 1.
    When at least one index was requested, two more ratios are added: unique
    indices and unique misses, each relative to the requested indices.

    Args:
        use_local_cache: if True, summarize the local (per-prefetch) stats.
            Otherwise summarize the change in the cumulative stats since the
            previous print.
    """
    # TODO: Create a separate reporter class to unify the stdlog reporting
    uvm_cache_stats: list[float] = self._get_uvm_cache_print_state(use_local_cache)
    # Use the named indices everywhere, not raw 0 / 1. This way the divisor
    # and the zero-division guard always refer to the same counters as the
    # values they normalize.
    num_calls = uvm_cache_stats[UVMCacheStatsIndex.num_calls]
    num_requested = uvm_cache_stats[UVMCacheStatsIndex.num_requested_indices]
    N = max(1, num_calls)
    m = {
        "N_called": num_calls,
        "requested_indices": num_requested / N,
        "unique_indices": uvm_cache_stats[UVMCacheStatsIndex.num_unique_indices]
        / N,
        "unique_misses": uvm_cache_stats[UVMCacheStatsIndex.num_unique_misses] / N,
        "conflict_unique_misses": uvm_cache_stats[
            UVMCacheStatsIndex.num_conflict_unique_misses
        ]
        / N,
        "conflict_misses": uvm_cache_stats[UVMCacheStatsIndex.num_conflict_misses]
        / N,
    }
    if num_requested:
        m.update(
            {
                "unique indices / requested indices": uvm_cache_stats[
                    UVMCacheStatsIndex.num_unique_indices
                ]
                / num_requested,
                "unique misses / requested indices": uvm_cache_stats[
                    UVMCacheStatsIndex.num_unique_misses
                ]
                / num_requested,
            }
        )
    self.log(f"uvm_cache_stats={m}")
|
|
2703
|
+
|
|
2704
|
+
@torch.jit.ignore
def _report_uvm_cache_stats(self) -> None:
    """
    Report the UVM cache stats, converted to bytes, to ``self.stats_reporter``.

    Nothing is reported when any of these holds:

    * no reporter is configured;
    * no step has passed since the last report;
    * the reporter declines the current step.

    Otherwise each cumulative counter's change since the previous report is
    converted to bytes (count * cache element size * ``max_D_cache``) and
    averaged over the elapsed steps. The result is reported as
    ``tbe.prefetch.cache_stats_by_data_size.<stat name>``.
    """
    if self.stats_reporter is None:
        return
    stats_reporter: TBEStatsReporter = self.stats_reporter
    steps_elapsed = self.step - self.last_reported_step
    if steps_elapsed == 0 or not stats_reporter.should_report(self.step):
        return

    current: list[float] = self.get_uvm_cache_stats(
        use_local_cache=False
    ).tolist()
    self.last_reported_step = self.step

    # Use an all-zero baseline on the very first report.
    baseline = (
        self.last_reported_uvm_stats
        if len(self.last_reported_uvm_stats) != 0
        else [0.0] * len(current)
    )
    deltas: list[float] = [current[k] - baseline[k] for k in range(len(current))]
    self.last_reported_uvm_stats = current

    # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
    #  a function.
    bytes_per_element = self.lxu_cache_weights.element_size()
    for stat in UVMCacheStatsIndex:
        # Keep the multiplication order (delta * size * D / steps) so the
        # float result is bit-identical.
        averaged_bytes = (
            deltas[stat.value] * bytes_per_element * self.max_D_cache / steps_elapsed
        )
        stats_reporter.report_data_amount(
            iteration_step=self.step,
            event_name=f"tbe.prefetch.cache_stats_by_data_size.{stat.name.lower()}",
            data_bytes=int(averaged_bytes),
            embedding_id=self.logging_table_name,
            tbe_id=self.uuid,
        )
|
|
2745
|
+
|
|
2746
|
+
def prefetch(
    self,
    indices: Tensor,
    offsets: Tensor,
    forward_stream: Optional[torch.cuda.Stream] = None,
    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
) -> None:
    """
    Prefetch the rows needed by an upcoming forward pass into the cache.

    The inputs are prepared with ``prepare_inputs`` (without forcing input
    type casts) and passed to ``_prefetch`` with the configured
    ``multipass_prefetch_config``. The call is timed by
    ``prefetch_duration_timer`` on the current CUDA stream.

    Args:
        indices: indices of the upcoming forward call.
        offsets: offsets of the upcoming forward call.
        forward_stream: optional stream the forward pass will run on.
            - On the first call that provides it, the current stream is
              stored as ``prefetch_stream`` and must differ from
              ``forward_stream``.
            - When given, the prefetched tensors are recorded on
              ``forward_stream`` afterwards.
        batch_size_per_feature_per_rank: optional per-feature, per-rank
            batch sizes for variable batch size embedding (VBE).
    """
    needs_prefetch_stream = self.prefetch_stream is None and forward_stream is not None
    if needs_prefetch_stream:
        self.prefetch_stream = torch.cuda.current_stream()
        assert (
            self.prefetch_stream != forward_stream
        ), "prefetch_stream and forward_stream should not be the same stream"

    prepared_indices, prepared_offsets, _, vbe_metadata = self.prepare_inputs(
        indices,
        offsets,
        per_sample_weights=None,
        batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
        force_cast_input_types=False,
        prefetch_pipeline=self.prefetch_pipeline,
    )

    timer_ctx = self._recording_to_timer(
        self.prefetch_duration_timer,
        context=self.step,
        stream=torch.cuda.current_stream(),
    )
    with timer_ctx:
        self._prefetch(
            prepared_indices,
            prepared_offsets,
            vbe_metadata,
            multipass_prefetch_config=self.multipass_prefetch_config,
        )

    if forward_stream is not None:
        self._prefetch_tensors_record_stream(forward_stream)
|
|
2782
|
+
|
|
2783
|
+
def _prefetch(
    self,
    indices: Tensor,
    offsets: Tensor,
    vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
    multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
    hash_zch_identities: Optional[Tensor] = None,
    hash_zch_runtime_meta: Optional[Tensor] = None,
) -> None:
    """
    Populate the software cache for one batch and queue its cache locations.

    Steps, in order:
    - Bump ``timestep``; this is skipped while dynamo is compiling.
    - Return early if no cache is allocated.
    - For each pass yielded by ``get_prefetch_passes``:
        - linearize the indices;
        - optionally record cache-miss metrics;
        - run the LRU/LFU populate op;
        - look up the final cache locations into that pass's output tensor.
    - Append the full locations tensor to ``lxu_cache_locations_list``.
    - Accumulate UVM cache stats if enabled.
    - Hand the inputs to ``_store_prefetched_tensors``.

    Args:
        indices: Lookup indices. When called from ``prefetch``, they have
            already gone through ``prepare_inputs``.
        offsets: Offsets into ``indices``. Also passed to linearization and to
            the tablewise-miss metric.
        vbe_metadata: If given, its ``B_offsets`` and ``max_B`` are passed to
            ``linearize_cache_indices``. Otherwise ``None`` / ``-1`` are used.
        multipass_prefetch_config: Forwarded to ``get_prefetch_passes``.
        hash_zch_identities: Forwarded unchanged to
            ``_store_prefetched_tensors``.
        hash_zch_runtime_meta: Forwarded unchanged to
            ``_store_prefetched_tensors``.

    Raises:
        AssertionError: If ``lxu_cache_locations_list`` already holds
            ``max_prefetch_depth`` or more entries.
    """
    if not is_torchdynamo_compiling():
        # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
        self.timestep += 1
        self.timesteps_prefetched.append(self.timestep)

    # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
    # a function.
    # An empty cache-weights tensor means no cache is configured: nothing to prefetch.
    if not self.lxu_cache_weights.numel():
        return

    # Clear the local_uvm_cache_stats before the prefetch instead of after
    # the prefetch step, since it will be used in the CommonArgs in the
    # forward step
    if self.gather_uvm_cache_stats:
        self.local_uvm_cache_stats.zero_()
    self._report_io_size_count("prefetch_input", indices)

    # streaming before updating the cache
    self.raw_embedding_stream()

    # Filled pass by pass: each pass writes into its own
    # partial_lxu_cache_locations, which presumably is a view of this tensor
    # (it is given to get_prefetch_passes). NOTE(review): confirm.
    final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
    # Concatenation of every pass's linearized indices; stored for later use.
    linear_cache_indices_merged = torch.zeros(
        0, dtype=indices.dtype, device=indices.device
    )
    for (
        partial_indices,
        partial_lxu_cache_locations,
        base_offset,
    ) in self.get_prefetch_passes(
        multipass_prefetch_config, indices, final_lxu_cache_locations
    ):
        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
            self.cache_hash_size_cumsum,
            partial_indices,
            offsets,
            vbe_metadata.B_offsets if vbe_metadata is not None else None,
            vbe_metadata.max_B if vbe_metadata is not None else -1,
            base_offset,
        )
        linear_cache_indices_merged = torch.cat(
            [linear_cache_indices_merged, linear_cache_indices]
        )

        # Optional miss metrics. The lookup happens *before* populate, so it
        # sees the pre-prefetch cache contents.
        if (
            self.record_cache_metrics.record_cache_miss_counter
            or self.record_cache_metrics.record_tablewise_cache_miss
        ):
            lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
                linear_cache_indices,
                self.lxu_cache_state,
                self.total_cache_hash_size,
                self.gather_uvm_cache_stats,
                self.local_uvm_cache_stats,
            )
            if self.record_cache_metrics.record_cache_miss_counter:
                self._update_cache_miss_counter(
                    lxu_cache_locations, linear_cache_indices
                )
            if self.record_cache_metrics.record_tablewise_cache_miss:
                # NOTE(review): with multipass prefetch, linear_cache_indices
                # covers only this pass while offsets spans the whole batch,
                # so the per-table slices may be misaligned; verify.
                self._update_tablewise_cache_miss(
                    lxu_cache_locations, linear_cache_indices, offsets
                )

        # Populate the cache with the selected eviction policy. Any other
        # cache_algorithm value skips populate here.
        if self.cache_algorithm == CacheAlgorithm.LRU:
            torch.ops.fbgemm.lru_cache_populate(
                self.weights_uvm,
                self.cache_hash_size_cumsum,
                self.total_cache_hash_size,
                self.cache_index_table_map,
                self.weights_offsets,
                self.D_offsets,
                linear_cache_indices,
                self.lxu_cache_state,
                self.lxu_cache_weights,
                self.timestep,
                self.lxu_state,
                self.stochastic_rounding,
                self.gather_uvm_cache_stats,
                self.local_uvm_cache_stats,
                self.lock_cache_line,
                self.lxu_cache_locking_counter,
            )
        elif self.cache_algorithm == CacheAlgorithm.LFU:
            torch.ops.fbgemm.lfu_cache_populate(
                self.weights_uvm,
                self.cache_hash_size_cumsum,
                self.total_cache_hash_size,
                self.cache_index_table_map,
                self.weights_offsets,
                self.D_offsets,
                linear_cache_indices,
                self.lxu_cache_state,
                self.lxu_cache_weights,
                self.lxu_state,
                self.stochastic_rounding,
            )

        # Post-populate lookup: writes this pass's final cache locations into
        # partial_lxu_cache_locations instead of allocating a new tensor.
        torch.ops.fbgemm.lxu_cache_lookup(
            linear_cache_indices,
            self.lxu_cache_state,
            self.total_cache_hash_size,
            self.gather_uvm_cache_stats,
            self.local_uvm_cache_stats,
            lxu_cache_locations_output=partial_lxu_cache_locations,
        )

    # Guard against prefetch() running ahead of forward() without bound.
    assert (
        len(self.lxu_cache_locations_list) < self.max_prefetch_depth
    ), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"
    self.lxu_cache_locations_list.append(final_lxu_cache_locations)

    if self.gather_uvm_cache_stats:
        # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
        # We may want to do this accumulation atomically, but as it's only
        # for monitoring, slightly inaccurate result may be acceptable.
        self.uvm_cache_stats = torch.add(
            self.uvm_cache_stats, self.local_uvm_cache_stats
        )
        self._report_uvm_cache_stats()
        if self.should_log():
            self.print_uvm_cache_stats(use_local_cache=False)

    self._store_prefetched_tensors(
        indices,
        offsets,
        vbe_metadata,
        linear_cache_indices_merged,
        final_lxu_cache_locations,
        hash_zch_identities,
        hash_zch_runtime_meta,
    )
|
|
2923
|
+
|
|
2924
|
+
def should_log(self) -> bool:
    """
    Decide whether the current step should emit logs.

    The logging frequency decays exponentially. The 1-based step count must be
    at least 100 and a multiple of the largest power of ten not exceeding it.

    Logs for steps: 100 200 ... 1,000 2,000 ... 10,000 20,000 ... 100,000 200,000 ...
    """
    one_based_step = self.step + 1  # self.step counts from 0
    if one_based_step < 100:
        return False
    leading_power = 10 ** int(math.log10(one_based_step))
    return one_based_step % leading_power == 0
|
|
2931
|
+
|
|
2932
|
+
def _prefetch_tensors_record_stream(
    self, forward_stream: torch.cuda.Stream
) -> None:
    """
    Mark every queued cache-location tensor as in use on ``forward_stream``.

    The tensors are created on the prefetch stream but consumed by the forward
    and backward passes. In PyTorch each backward CUDA op runs on the same
    stream as its forward op, so recording on the forward stream covers both.

    Args:
        forward_stream: Stream the forward pass runs on.
    """
    pending_locations = self.lxu_cache_locations_list
    for locations in pending_locations:
        locations.record_stream(forward_stream)
|
|
2941
|
+
|
|
2942
|
+
def _update_cache_miss_counter(
    self,
    lxu_cache_locations: Tensor,
    linear_cache_indices: Tensor,
) -> None:
    """
    Accumulate batch-level cache-miss statistics into ``cache_miss_counter``.

    A location of -1 marks a miss. Hits are replaced by the sentinel -2, so
    that after ``torch.unique`` only distinct missed indices (plus possibly
    the sentinel) remain.

    Updates:
        cache_miss_counter[0]: incremented by 1 if the batch had any miss.
        cache_miss_counter[1]: incremented by the number of distinct missed
            linear indices.
    """
    miss_marker = -1
    hit_sentinel = -2

    missed_ids = torch.where(
        lxu_cache_locations == miss_marker, linear_cache_indices, hit_sentinel
    )
    distinct_missed = torch.unique(missed_ids)
    num_distinct_misses = (distinct_missed != hit_sentinel).sum()

    # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
    # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
    self.cache_miss_counter[0] += (num_distinct_misses > 0).to(torch.int64)

    # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
    # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
    self.cache_miss_counter[1] += num_distinct_misses
|
|
2965
|
+
|
|
2966
|
+
def _update_tablewise_cache_miss(
    self,
    lxu_cache_locations: Tensor,
    linear_cache_indices: Tensor,
    offsets: Tensor,
) -> None:
    """
    Add each table's count of distinct missed indices to ``table_wise_cache_miss``.

    ``offsets`` is split into ``num_tables`` equal groups of
    ``(len(offsets) - 1) // num_tables`` entries. Table ``i`` covers
    ``offsets[i * n]`` up to ``offsets[(i + 1) * n]``. This assumes a fixed
    batch size per table (TODO confirm for variable-batch inputs).

    A location of -1 marks a miss. Hits are mapped to the sentinel -2 before
    deduplication so that they are not counted.
    """
    miss_marker = -1
    hit_sentinel = -2

    # pyre-fixme[6]: For 1st argument expected
    # `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Module, Tensor]`.
    table_count = len(self.cache_hash_size_cumsum) - 1
    offsets_per_table = (len(offsets) - 1) // table_count
    missed_ids = torch.where(
        lxu_cache_locations == miss_marker, linear_cache_indices, hit_sentinel
    )

    for table_idx in range(table_count):
        lo = offsets[table_idx * offsets_per_table]
        hi = offsets[(table_idx + 1) * offsets_per_table]
        distinct_in_table = torch.unique(missed_ids[lo:hi])
        self.table_wise_cache_miss[table_idx] += (
            distinct_in_table != hit_sentinel
        ).sum()
|
|
2994
|
+
|
|
2995
|
+
def init_embedding_weights_uniform(self, min_val: float, max_val: float) -> None:
    """
    Re-initialize every table's weights in place from U(min_val, max_val).

    - INT8 weights: values are sampled in float32 over the first
      ``emb.shape[1] - int8_emb_row_dim_offset`` columns, then converted with
      ``FloatToFused8BitRowwiseQuantized`` and copied into the table.
    - NFP8 weights: values are sampled in float32, rounded to the platform
      fp8 dtype (``float8_e4m3fnuz`` on ROCm, ``float8_e4m3fn`` otherwise),
      and copied into the table.
    - Any other precision: ``uniform_`` is applied directly to each table view.

    Args:
        min_val: Lower bound of the uniform distribution.
        max_val: Upper bound of the uniform distribution.

    Raises:
        AssertionError: INT8 with a non-2D table, or NFP8 on a non-CUDA
            device or with a non-adagrad optimizer.
    """
    splits = self.split_embedding_weights()
    if self.weights_precision == SparseType.INT8:
        # TODO: add in-place FloatToFused8BitRowwiseQuantized conversion
        for emb in splits:
            assert (
                len(emb.shape) == 2
            ), "Int8 embedding only supported for 2D weight tensors."
            shape = [emb.shape[0], emb.shape[1] - self.int8_emb_row_dim_offset]
            tmp_emb = torch.zeros(shape, device=self.current_device)
            tmp_emb.uniform_(min_val, max_val)
            tmp_emb_i8 = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(tmp_emb)
            emb.data.copy_(tmp_emb_i8)
    elif self.weights_precision == SparseType.NFP8:
        # Torch doesnt implement direct fp8 distribution functions, so we need to start in higher precision.
        assert (
            self.current_device.type == "cuda"
        ), "NFP8 is currently only supportd on GPU."
        assert self.optimizer in [
            OptimType.EXACT_ADAGRAD,
            OptimType.ROWWISE_ADAGRAD,
            OptimType.EXACT_ROWWISE_ADAGRAD,
            OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
            OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
        ], "NFP8 is currently only supportd with adagrad optimizers."
        # ROCm builds use the "fnuz" fp8 variant. The choice is loop-invariant.
        fp8_dtype = (
            torch.float8_e4m3fnuz
            if torch.version.hip is not None
            else torch.float8_e4m3fn
        )
        for param in splits:
            tmp_param = torch.zeros(param.shape, device=self.current_device)
            # Create initialized weights and cast to fp8. Tensor.to() is not
            # in place, so its result must be kept for the rounding to happen.
            tmp_param = tmp_param.uniform_(min_val, max_val).to(fp8_dtype)
            param.data.copy_(tmp_param)
    else:
        for param in splits:
            param.uniform_(min_val, max_val)
|
|
3033
|
+
|
|
3034
|
+
@torch.jit.ignore
def split_embedding_weights(self) -> list[Tensor]:
    """
    Return one detached view per table into the physical weight storage.

    For table ``t``, the backing storage is chosen from
    ``weights_physical_placements[t]``: ``weights_dev`` for DEVICE,
    ``weights_host`` for HOST, and ``weights_uvm`` otherwise. The storage is
    flattened if it is 2-D. The view starts at ``weights_physical_offsets[t]``
    and is reshaped to ``(rows, dim)``. For INT8 weights, ``dim`` is widened
    by ``int8_emb_row_dim_offset``.

    Returns:
        A list of weights, one entry per element of ``self.embedding_specs``.
    """
    table_views = []
    for table_idx, (num_rows, row_width, _, _) in enumerate(self.embedding_specs):
        if self.weights_precision == SparseType.INT8:
            row_width += self.int8_emb_row_dim_offset
        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
        location = self.weights_physical_placements[table_idx]
        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
        start = self.weights_physical_offsets[table_idx]

        if location == EmbeddingLocation.DEVICE.value:
            storage = self.weights_dev
        elif location == EmbeddingLocation.HOST.value:
            storage = self.weights_host
        else:
            storage = self.weights_uvm

        # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is
        # not a function.
        if storage.dim() == 2:
            storage = storage.flatten()

        stop = start + num_rows * row_width
        table_views.append(storage.detach()[start:stop].view(num_rows, row_width))
    return table_views
|
|
3064
|
+
|
|
3065
|
+
@torch.jit.ignore
def get_optimizer_buffer(self, state: str) -> torch.Tensor:
    """
    Look up a registered buffer by its exact name.

    Args:
        state: Buffer name, as reported by ``self.named_buffers()``.

    Returns:
        The first registered buffer whose name equals ``state``.

    Raises:
        NotImplementedError: If the optimizer is ``OptimType.NONE``.
        ValueError: If no buffer has that name.
    """
    if self.optimizer == OptimType.NONE:
        raise NotImplementedError(
            f"Getting optimizer buffer is not supported for {self.optimizer}"
        )
    # named_buffers() never yields None-valued buffers, so None means "absent".
    found = next(
        (buf for buf_name, buf in self.named_buffers() if buf_name == state),
        None,
    )
    if found is None:
        raise ValueError(f"Optimizer buffer {state} not found")
    return found
|
|
3075
|
+
|
|
3076
|
+
@torch.jit.export
def get_optimizer_state(self) -> list[dict[str, torch.Tensor]]:
    r"""
    Get the optimizer state dict that matches the OSS Pytorch optims
    TODO: populate the supported list of optimizers

    Returns one dict per table, built from ``split_optimizer_states()``. The
    per-table state order of that method determines the ``states[i]`` indices
    used below. Keys by optimizer:

    - EXACT_ROWWISE_ADAGRAD / EXACT_ADAGRAD / EMAINPLACE_ROWWISE_ADAGRAD:
        - with ``_used_rowwise_adagrad_with_counter``: ``sum``, ``prev_iter``,
          ``row_counter``, plus ``iter`` when regularization mode is COUNTER
          and the weight-decay mode is ADAGRADW;
        - otherwise with ``_used_rowwise_adagrad_with_global_weight_decay``:
          ``sum``, ``prev_iter``, ``iter``;
        - otherwise just ``sum``.
    - SGD / EXACT_SGD: ``momentum_buffer``. For EXACT_SGD,
      ``split_optimizer_states()`` returns no states, so the list is empty.
    - ADAM with ``use_rowwise_bias_correction``: ``exp_avg``, ``exp_avg_sq``,
      ``row_counter``.
    - ADAM / PARTIAL_ROWWISE_ADAM / LAMB / PARTIAL_ROWWISE_LAMB: ``exp_avg``,
      ``exp_avg_sq``.
    - ENSEMBLE_ROWWISE_ADAGRAD: ``sum``, ``sparse_ema``.

    ``iter`` entries reference ``self.iter`` itself (shared by every table),
    not a copy.

    Raises:
        NotImplementedError: For any optimizer not listed above.
            ``OptimType.NONE`` is already rejected inside
            ``split_optimizer_states``.
    """
    split_optimizer_states = self.split_optimizer_states()
    if (
        self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
        or self.optimizer == OptimType.EXACT_ADAGRAD
        or self.optimizer == OptimType.EMAINPLACE_ROWWISE_ADAGRAD
    ):
        # Nested conditional per table:
        #   counter mode -> (ADAGRADW ? with "iter" : without "iter")
        #   else         -> (global weight decay ? sum/prev_iter/iter : sum)
        list_of_state_dict = [
            (
                (
                    {
                        "sum": states[0],
                        "prev_iter": states[1],
                        "row_counter": states[2],
                        "iter": self.iter,
                    }
                    if self.optimizer_args.regularization_mode
                    == WeightDecayMode.COUNTER.value
                    and self.optimizer_args.weight_decay_mode
                    == CounterWeightDecayMode.ADAGRADW.value
                    else {
                        "sum": states[0],
                        "prev_iter": states[1],
                        "row_counter": states[2],
                    }
                )
                if self._used_rowwise_adagrad_with_counter
                else (
                    {
                        "sum": states[0],
                        "prev_iter": states[1],
                        "iter": self.iter,
                    }
                    if self._used_rowwise_adagrad_with_global_weight_decay
                    else {"sum": states[0]}
                )
            )
            for states in split_optimizer_states
        ]
    elif self.optimizer == OptimType.SGD or self.optimizer == OptimType.EXACT_SGD:
        list_of_state_dict = [
            {"momentum_buffer": states[0]} for states in split_optimizer_states
        ]
    # Must precede the generic ADAM branch below, which would otherwise match first.
    elif self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction:
        list_of_state_dict = [
            {
                "exp_avg": states[0],
                "exp_avg_sq": states[1],
                "row_counter": states[2],
            }
            for states in split_optimizer_states
        ]
    elif (
        self.optimizer == OptimType.ADAM
        or self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
        or self.optimizer == OptimType.LAMB
        or self.optimizer == OptimType.PARTIAL_ROWWISE_LAMB
    ):
        list_of_state_dict = [
            {"exp_avg": states[0], "exp_avg_sq": states[1]}
            for states in split_optimizer_states
        ]
    elif self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
        list_of_state_dict = [
            {
                "sum": states[0],
                "sparse_ema": states[1],
            }
            for states in split_optimizer_states
        ]
    else:
        raise NotImplementedError(
            f"Getting optimizer state {self.optimizer} is not implmeneted"
        )

    return list_of_state_dict
|
|
3157
|
+
|
|
3158
|
+
@torch.jit.ignore
def split_optimizer_states(
    self,
) -> list[list[torch.Tensor]]:
    """
    Returns a list of optimizer states (view), split by table

    Returns:
        A list of list of states. Shape = (the number of tables, the number
        of states).

    The following shows the list of states (in the returned order) for
    each optimizer:

    (1) `ADAM`: `momentum1`, `momentum2`, plus `row_counter` (rowwise) when
        `use_rowwise_bias_correction` is set

    (2) `EXACT_ADAGRAD`: `momentum1`

    (3) `EXACT_ROWWISE_ADAGRAD` / `EMAINPLACE_ROWWISE_ADAGRAD`: `momentum1`
        (rowwise), `prev_iter` (rowwise; only when
        `_used_rowwise_adagrad_with_counter` or
        `_used_rowwise_adagrad_with_global_weight_decay`), `row_counter`
        (rowwise; only when `_used_rowwise_adagrad_with_counter`)

    (4) `EXACT_SGD`: no states

    (5) `SGD`: `momentum1`

    (6) `LAMB`: `momentum1`, `momentum2`

    (7) `LARS_SGD`: `momentum1`

    (8) `PARTIAL_ROWWISE_ADAM`: `momentum1`, `momentum2` (rowwise)

    (9) `PARTIAL_ROWWISE_LAMB`: `momentum1`, `momentum2` (rowwise)

    (10) `ENSEMBLE_ROWWISE_ADAGRAD`: `momentum1` (rowwise), `momentum2`

    (11) `NONE`: no states (throwing an error)

    Raises:
        NotImplementedError: If the optimizer is ``OptimType.NONE``.
    """
    if self.optimizer == OptimType.NONE:
        raise NotImplementedError(
            f"Getting optimizer states is not supported for {self.optimizer}"
        )

    def get_optimizer_states(
        state_dev: Tensor,
        state_host: Tensor,
        state_uvm: Tensor,
        state_offsets: Tensor,
        state_placements: Tensor,
        rowwise: bool,
    ) -> list[torch.Tensor]:
        # Slice one per-table view out of the storage the table is placed in.
        splits = []
        for t, (rows, dim, _, _) in enumerate(self.embedding_specs):
            offset = state_offsets[t]
            placement = state_placements[t]
            # Placements are stored as integers in a tensor. Compare against
            # the raw enum value, as split_embedding_weights does, so the test
            # does not depend on how the enum member compares to a tensor.
            if placement == EmbeddingLocation.DEVICE.value:
                state = state_dev
            elif placement == EmbeddingLocation.HOST.value:
                state = state_host
            else:
                state = state_uvm
            if rowwise:
                splits.append(state.detach()[offset : offset + rows].view(rows))
            else:
                splits.append(
                    state.detach()[offset : offset + rows * dim].view(rows, dim)
                )
        return splits

    def states_for(prefix: str, rowwise: bool) -> list[torch.Tensor]:
        # Resolve the <prefix>_{dev,host,uvm,physical_offsets,physical_placements}
        # attributes and split them per table.
        return get_optimizer_states(
            getattr(self, f"{prefix}_dev"),
            getattr(self, f"{prefix}_host"),
            getattr(self, f"{prefix}_uvm"),
            getattr(self, f"{prefix}_physical_offsets"),
            getattr(self, f"{prefix}_physical_placements"),
            rowwise=rowwise,
        )

    states: list[list[torch.Tensor]] = []
    if self.optimizer not in (OptimType.EXACT_SGD,):
        states.append(
            states_for(
                "momentum1",
                rowwise=self.optimizer
                in [
                    OptimType.EXACT_ROWWISE_ADAGRAD,
                    OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
                    OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
                ],
            )
        )
    if self.optimizer in (
        OptimType.ADAM,
        OptimType.PARTIAL_ROWWISE_ADAM,
        OptimType.LAMB,
        OptimType.PARTIAL_ROWWISE_LAMB,
        OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
    ):
        states.append(
            states_for(
                "momentum2",
                rowwise=self.optimizer
                in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.PARTIAL_ROWWISE_LAMB),
            )
        )
    if (
        self._used_rowwise_adagrad_with_counter
        or self._used_rowwise_adagrad_with_global_weight_decay
    ):
        states.append(states_for("prev_iter", rowwise=True))
    if self._used_rowwise_adagrad_with_counter or (
        self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction
    ):
        states.append(states_for("row_counter", rowwise=True))

    # Transpose [state][table] into [table][state]. No states gives [].
    return_states = [list(s) for s in zip(*states)]
    return return_states
|
|
3332
|
+
|
|
3333
|
+
@torch.jit.export
def set_learning_rate(self, lr: float) -> None:
    """
    Overwrite the stored learning rate (delegates to ``_set_learning_rate``).

    Args:
        lr (float): New learning-rate value.

    Raises:
        NotImplementedError: If the optimizer is ``OptimType.NONE``.
    """
    if self.optimizer == OptimType.NONE:
        raise NotImplementedError(
            f"Setting learning rate is not supported for {self.optimizer}"
        )
    else:
        self._set_learning_rate(lr)
|
|
3346
|
+
|
|
3347
|
+
def get_learning_rate(self) -> float:
    """
    Return the current learning rate as a Python float.

    The value is read with ``.item()`` from ``learning_rate_tensor``, which
    therefore must hold exactly one element.
    """
    lr_value = self.learning_rate_tensor.item()
    return lr_value
|
|
3352
|
+
|
|
3353
|
+
@torch.jit.ignore
def update_hyper_parameters(self, params_dict: dict[str, float]) -> None:
    """
    Sets hyper-parameters from external control flow.

    Supported names:
        - ``lr``: routed through ``_set_learning_rate``.
        - ``eps``, ``beta1``, ``beta2``, ``weight_decay``: replaced on
          ``optimizer_args`` via ``_replace``.
        - ``lower_bound``: stored as ``gwd_lower_bound``.

    Entries are applied in iteration order. An unsupported name raises after
    any earlier entries have already been applied.

    Args:
        params_dict (Dict[str, float]): The dict that contains the
            hyper-parameter names and their values

    Raises:
        NotImplementedError: If the optimizer is ``OptimType.NONE`` or a name
            is not supported.
    """
    if self.optimizer == OptimType.NONE:
        # The message previously said "learning rate" (copied from
        # set_learning_rate) even though this sets arbitrary hyper-parameters.
        raise NotImplementedError(
            f"Setting hyper-parameters is not supported for {self.optimizer}"
        )
    # Names that map one-to-one onto optimizer_args fields of the same name.
    optimizer_arg_fields = ("eps", "beta1", "beta2", "weight_decay")
    for parameter_name, value in params_dict.items():
        if parameter_name == "lr":
            self._set_learning_rate(value)
        elif parameter_name in optimizer_arg_fields:
            self.optimizer_args = self.optimizer_args._replace(
                **{parameter_name: value}
            )
        elif parameter_name == "lower_bound":
            self.gwd_lower_bound = value
        else:
            raise NotImplementedError(
                f"Setting hyper-parameter {parameter_name} is not supported"
            )
|
|
3383
|
+
|
|
3384
|
+
@torch.jit.ignore
def _set_learning_rate(self, lr: float) -> float:
    """
    Fill ``learning_rate_tensor`` with ``lr`` in place.

    This is the scripting helper behind ``set_learning_rate``. It returns a
    dummy ``0.0`` because returning None does not work for this scripted call.
    """
    lr_buffer = self.learning_rate_tensor
    lr_buffer.fill_(lr)
    return 0.0
|
|
3392
|
+
|
|
3393
|
+
@torch.jit.ignore
def set_optimizer_step(self, step: int) -> None:
    """
    Overwrite the optimizer iteration counter ``self.iter[0]``.

    The old and new values are logged first, even when the call then fails.

    Args:
        step (int): New iteration value.

    Raises:
        NotImplementedError: If the optimizer is ``OptimType.NONE``.
    """
    self.log(f"set_optimizer_step from {self.iter[0]=} to {step=}")
    if self.optimizer != OptimType.NONE:
        self.iter[0] = step
        return
    raise NotImplementedError(
        f"Setting optimizer step is not supported for {self.optimizer}"
    )
|
|
3407
|
+
|
|
3408
|
+
@torch.jit.export
def flush(self) -> None:
    """
    Run ``torch.ops.fbgemm.lxu_cache_flush`` over the cache state.

    This presumably writes cached rows back to ``weights_uvm``; confirm
    against the op's documentation. It is a no-op when ``lxu_cache_weights``
    is empty, i.e. no cache is allocated.
    """
    # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
    # a function.
    cache_is_empty = self.lxu_cache_weights.numel() == 0
    if cache_is_empty:
        return
    torch.ops.fbgemm.lxu_cache_flush(
        self.weights_uvm,
        self.cache_hash_size_cumsum,
        self.cache_index_table_map,
        self.weights_offsets,
        self.D_offsets,
        self.total_D,
        self.lxu_cache_state,
        self.lxu_cache_weights,
        self.stochastic_rounding,
    )
|
|
3425
|
+
|
|
3426
|
+
def _apply_split(
    self,
    split: SplitState,
    prefix: str,
    dtype: type[torch.dtype],
    enforce_hbm: bool = False,
    make_dev_param: bool = False,
    dev_reshape: Optional[tuple[int, ...]] = None,
    uvm_host_mapped: bool = False,
) -> None:
    """
    Forward ``split`` to ``apply_split_helper`` using this module's context.

    The helper receives:
        - ``self.register_buffer``;
        - an attribute setter bound to ``self``;
        - the current device and the ``use_cpu`` flag;
        - ``feature_table_map``;
        - the caller's split/prefix/dtype/HBM/dev-param/reshape options;
        - ``self._uvm_tensors_log``;
        - ``uvm_host_mapped``, passed by keyword.

    What the helper allocates or registers is defined elsewhere.
    """
    set_attr_on_self = functools.partial(setattr, self)
    apply_split_helper(
        self.register_buffer,
        set_attr_on_self,
        self.current_device,
        self.use_cpu,
        self.feature_table_map,
        split,
        prefix,
        dtype,
        enforce_hbm,
        make_dev_param,
        dev_reshape,
        self._uvm_tensors_log,
        uvm_host_mapped=uvm_host_mapped,
    )
|
|
3451
|
+
|
|
3452
|
+
def _apply_cache_state(
    self,
    cache_state: CacheState,
    cache_algorithm: CacheAlgorithm,
    cache_load_factor: float,
    cache_sets: int,
    cache_reserved_memory: float,
    cache_precision: SparseType,
) -> None:
    """
    Initialize the LXU (LRU/LFU) software cache of this module.

    Steps:
    - Reset the prefetch bookkeeping attributes and create the UVM cache
      stats buffers.
    - Pick the cache dtype from ``cache_precision``.
    - If there is nothing to cache (``cache_state.total_cache_hash_size == 0``)
      or the module runs on CPU, register only placeholder buffers so that
      TorchScript can resolve every attribute, then stop.
    - Otherwise validate the algorithm and size the cache. When
      ``cache_sets <= 0``, the size is derived from free device memory.
    - Allocate the cache buffers on ``self.current_device``.
    - If ``self.prefetch_pipeline`` is set, install the backward hooks that
      keep the prefetch and backward streams in sync.

    Args:
        cache_state (CacheState): Hash-size layout of the cached tables.
        cache_algorithm (CacheAlgorithm): Must be LRU or LFU.
        cache_load_factor (float): Fraction of the total cached rows to
            reserve when ``cache_sets`` is not given explicitly. Must be > 0.
        cache_sets (int): Number of cache sets; ``<= 0`` means auto-size.
        cache_reserved_memory (float): Bytes of device memory to leave
            untouched when auto-sizing.
        cache_precision (SparseType): FP32, FP16, or NFP8 (cached as FP16).

    Raises:
        AssertionError: If ``cache_precision`` is unsupported on a non-CPU
            module, or a sizing invariant does not hold.
        ValueError: If a cache is needed and ``cache_algorithm`` is neither
            LRU nor LFU.
    """
    self.cache_algorithm = cache_algorithm
    self.timestep = 1
    self.timesteps_prefetched = []

    self.max_prefetch_depth = MAX_PREFETCH_DEPTH
    self.lxu_cache_locations_list = []
    self.lxu_cache_locations_empty = torch.empty(
        0, device=self.current_device, dtype=torch.int32
    ).fill_(-1)
    self.lxu_cache_locations = self.lxu_cache_locations_empty
    self._indices = self.lxu_cache_locations_empty
    self._offsets = self.lxu_cache_locations_empty
    self._vbe_B_offsets = self.lxu_cache_locations_empty
    self._vbe_max_B = -1
    self.prefetch_stream: Optional[torch.cuda.Stream] = None

    self._init_uvm_cache_stats()

    if cache_precision == SparseType.FP32:
        dtype = torch.float32
    elif cache_precision == SparseType.FP16:
        dtype = torch.float16
    elif cache_precision == SparseType.NFP8:
        # NFP8 weights use floating point cache.
        dtype = torch.float16
    else:
        # Other precisions are only accepted in CPU mode, where no cache is
        # allocated; the dtype is irrelevant but keeps `dtype` always bound.
        dtype = torch.float32
        if not self.use_cpu:
            raise AssertionError(
                f"cache_precision {cache_precision} not supported!"
            )

    # NOTE: no cache for CPU mode!
    if cache_state.total_cache_hash_size == 0 or self.use_cpu:
        self.register_buffer(
            "lxu_cache_weights",
            torch.zeros(0, 0, device=self.current_device, dtype=dtype),
        )
        # NOTE: make TorchScript work! Non-persistent placeholders so that
        # every cache attribute exists even without a cache.
        for placeholder_name in (
            "cache_hash_size_cumsum",
            "total_cache_hash_size",
            "cache_index_table_map",
            "lxu_cache_state",
            "lxu_state",
        ):
            self.register_buffer(
                placeholder_name,
                torch.zeros(1, dtype=torch.int64, device=self.current_device),
                persistent=False,
            )
        self.register_buffer(
            "cache_miss_counter",
            # Same device as every other buffer on both paths.
            torch.tensor([0, 0], device=self.current_device, dtype=torch.int64),
            persistent=False,
        )
        self._init_uvm_cache_counter(cache_sets, persistent=False)
        return

    # Validate before allocating buffers or installing hooks, so a bad
    # argument does not leave the module half-initialized.
    if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU):
        raise ValueError(
            f"cache_algorithm must be {CacheAlgorithm.LRU} "
            f"or {CacheAlgorithm.LFU}"
        )

    assert cache_load_factor > 0
    element_size = 2 if dtype == torch.float16 else 4
    if cache_sets <= 0:
        # Auto-size from the load factor, capped by free device memory.
        total_memory = torch.cuda.get_device_properties(
            self.current_device
        ).total_memory
        free_memory = (
            total_memory
            - torch.cuda.memory_reserved(self.current_device)
            - int(cache_reserved_memory)
        )
        assert free_memory > 0
        cache_sets = (
            int(cache_state.total_cache_hash_size * cache_load_factor)
            + DEFAULT_ASSOC
            - 1
        ) // DEFAULT_ASSOC
        cache_sets = 1 if cache_sets == 0 else cache_sets
        cache_size = cache_sets * DEFAULT_ASSOC * element_size * self.max_D_cache
        if cache_size > free_memory:
            # Shrink to roughly the free memory (rounded up to whole sets).
            cache_sets = (
                int(1.0 * free_memory / self.max_D_cache / element_size)
                + DEFAULT_ASSOC
                - 1
            ) // DEFAULT_ASSOC
    # Effective load factor implied by the final number of sets (logged below).
    cache_load_factor = (
        1.0 * cache_sets * DEFAULT_ASSOC / int(cache_state.total_cache_hash_size)
    )
    assert cache_sets > 0
    if cache_algorithm == CacheAlgorithm.LFU:
        assert cache_sets < 2**24 - 1
    cache_size = cache_sets * DEFAULT_ASSOC * element_size * self.max_D_cache
    self.log(
        "Using on-device cache with admission algorithm "
        f"{cache_algorithm}, {cache_sets} sets, "
        f"load_factor: {cache_load_factor : .3f}, "
        f"cache_size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
        f"cache_precision: {dtype}, "
        f"weights_precision: {self.weights_precision}"
    )

    self.total_cache_hash_size = cache_state.total_cache_hash_size
    # 8x of # tables, trivial size
    self.register_buffer(
        "cache_hash_size_cumsum",
        torch.tensor(
            cache_state.cache_hash_size_cumsum,
            device=self.current_device,
            dtype=torch.int64,
        ),
    )
    # 4x total embedding hash size with uvm cache
    self.register_buffer(
        "cache_index_table_map",
        torch.tensor(
            cache_state.cache_index_table_map,
            device=self.current_device,
            dtype=torch.int32,
        ),
    )
    # 8x of total cache slots (embedding hash size * clf)
    self.register_buffer(
        "lxu_cache_state",
        torch.zeros(
            cache_sets, DEFAULT_ASSOC, device=self.current_device, dtype=torch.int64
        ).fill_(-1),
    )
    # Cache itself, not auxiliary size
    self.register_buffer(
        "lxu_cache_weights",
        torch.zeros(
            cache_sets * DEFAULT_ASSOC,
            self.max_D_cache,
            device=self.current_device,
            dtype=dtype,
        ),
    )
    # LRU: 8x of total cache slots (embedding hash size * clf)
    # LFU: 8x of total embedding hash size with uvm cache
    self.register_buffer(
        "lxu_state",
        torch.zeros(
            size=(
                (self.total_cache_hash_size + 1,)
                if cache_algorithm == CacheAlgorithm.LFU
                else (cache_sets, DEFAULT_ASSOC)
            ),
            device=self.current_device,
            dtype=torch.int64,
        ),
    )
    self.register_buffer(
        "cache_miss_counter",
        torch.tensor([0, 0], device=self.current_device, dtype=torch.int64),
    )
    self._init_uvm_cache_counter(cache_sets, persistent=True)
    if self.prefetch_pipeline:
        # using the placeholder_autograd_tensor to make sure
        # the hook is executed after the backward pass
        # not using register_module_full_backward_hook
        # due to https://github.com/pytorch/pytorch/issues/100528
        self.placeholder_autograd_tensor.register_hook(
            self._sync_stream_post_backward
        )
        self.register_full_backward_pre_hook(
            self._update_cache_counter_and_locations
        )
|
|
3646
|
+
|
|
3647
|
+
# pyre-ignore
def _recording_to_timer(
    self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
) -> Any:
    """
    Return a context manager that records into ``timer`` when reporting is due.

    When a stats reporter is configured and ``should_report(self.step)`` is
    truthy, ``timer`` must be set, and ``timer.recording(**kwargs)`` is
    returned. In every other case a no-op ``contextlib.nullcontext()`` is
    returned.
    """
    reporter = self.stats_reporter
    reporting_now = reporter is not None and reporter.should_report(self.step)
    if not reporting_now:
        # No-Op context manager
        return contextlib.nullcontext()
    assert (
        timer
    ), "We shouldn't be here, async timer must have been initiated if reporter is present."
    return timer.recording(**kwargs)
|
|
3660
|
+
|
|
3661
|
+
def _sync_stream_post_backward(
|
|
3662
|
+
self,
|
|
3663
|
+
grad: Tensor,
|
|
3664
|
+
) -> None:
|
|
3665
|
+
"""
|
|
3666
|
+
backward hook function when prefetch_pipeline is enabled.
|
|
3667
|
+
|
|
3668
|
+
With the pipeline, prefetch(batch_{i+2}) may overlap with backward(batch_{i}).
|
|
3669
|
+
There is race condition that backward(batch_i) writes to UVM memory and
|
|
3670
|
+
at the same time prefetch(batch_{i+2}) loads UVM memory to cache. This stream sync forces
|
|
3671
|
+
backward(batch_i) to finish before prefetch(batch_{i+2}).
|
|
3672
|
+
"""
|
|
3673
|
+
if self.prefetch_stream is not None:
|
|
3674
|
+
self.prefetch_stream.wait_stream(torch.cuda.current_stream())
|
|
3675
|
+
|
|
3676
|
+
def _update_cache_counter_and_locations(
    self,
    module: nn.Module,
    grad_input: Union[tuple[Tensor, ...], Tensor],
) -> None:
    """
    Module backward pre-hook installed when ``prefetch_pipeline`` is enabled.

    ``module`` and ``grad_input`` are required by the hook signature and are
    not used. Three steps run, in this order:

    1. If a prefetch stream exists, the current (backward) stream waits for
       it. Otherwise prefetch(batch_{i+1}) could overlap backward(batch_i).
       If an index missed the cache in batch_i but is inserted in
       batch_{i+1}, backward(batch_i) would write UVM memory while
       prefetch(batch_{i+1}) loads the same UVM memory into the cache.

    2. ``lxu_cache_locking_counter`` is decremented for the current batch's
       ``lxu_cache_locations``, marking the batch as finished. The counter
       is updated both in prefetch and in TBE backward. Since prefetch and
       backward no longer overlap, decrementing before or after backward is
       equivalent; doing it before the locations are refreshed is preferred.

    3. ``lxu_cache_locations`` is refreshed in place against the current
       cache state. This fixes the following inconsistency:

           idx is in batch_i and batch_{i+1}
           prefetch(batch_i)     -> idx not inserted, location -1 (miss)
           forward(batch_i)
           prefetch(batch_{i+1}) -> idx inserted, loaded from host memory
           backward(batch_i)     -> location still -1, host memory updated
           forward(batch_{i+1})  -> WRONG: cached row lacks batch_i's update

       With refreshed locations, the TBE backward of batch_i updates the
       cached copy instead.
    """
    if self.prefetch_stream is not None:
        # need to wait for the prefetch of next batch,
        # so that cache states are valid
        with self._recording_to_timer(
            self.bwd_wait_prefetch_timer,
            context=self.step,
            stream=torch.cuda.current_stream(),
        ):
            torch.cuda.current_stream().wait_stream(self.prefetch_stream)

    # Step 2: decrement the per-slot locking counter for this batch.
    torch.ops.fbgemm.lxu_cache_locking_counter_decrement(
        self.lxu_cache_locking_counter,
        self.lxu_cache_locations,
    )

    # Step 3: re-linearize the batch's indices (self._indices / self._offsets,
    # presumably stashed by forward), dedupe them, then look them up again.
    batch_linear_indices = torch.ops.fbgemm.linearize_cache_indices(
        self.cache_hash_size_cumsum,
        self._indices,
        self._offsets,
        self._vbe_B_offsets,
        self._vbe_max_B,
    )
    uniq_indices, uniq_count, _unused_counts = torch.ops.fbgemm.get_unique_indices(
        batch_linear_indices,
        self.total_cache_hash_size,
        compute_count=False,
    )
    # Writes the refreshed locations into self.lxu_cache_locations in place.
    torch.ops.fbgemm.lxu_cache_lookup(
        uniq_indices,
        self.lxu_cache_state,
        self.total_cache_hash_size,
        gather_cache_stats=False,  # not collecting cache stats
        num_uniq_cache_indices=uniq_count,
        lxu_cache_locations_output=self.lxu_cache_locations,
    )
|
|
3757
|
+
|
|
3758
|
+
def _init_uvm_cache_counter(self, cache_sets: int, persistent: bool) -> None:
|
|
3759
|
+
if self.prefetch_pipeline and persistent:
|
|
3760
|
+
self.register_buffer(
|
|
3761
|
+
"lxu_cache_locking_counter",
|
|
3762
|
+
torch.zeros(
|
|
3763
|
+
cache_sets,
|
|
3764
|
+
DEFAULT_ASSOC,
|
|
3765
|
+
device=self.current_device,
|
|
3766
|
+
dtype=torch.int32,
|
|
3767
|
+
),
|
|
3768
|
+
)
|
|
3769
|
+
else:
|
|
3770
|
+
self.register_buffer(
|
|
3771
|
+
"lxu_cache_locking_counter",
|
|
3772
|
+
torch.zeros([0, 0], dtype=torch.int32, device=self.current_device),
|
|
3773
|
+
persistent=persistent,
|
|
3774
|
+
)
|
|
3775
|
+
|
|
3776
|
+
def _init_uvm_cache_stats(self) -> None:
|
|
3777
|
+
if not self.gather_uvm_cache_stats:
|
|
3778
|
+
# If uvm_cache_stats is not enabled, register stub entries via buffer to state_dict for TorchScript to JIT properly.
|
|
3779
|
+
# Since we're not using these variables, we can choose minimize tensor size to keep state_dict size small.
|
|
3780
|
+
self.register_buffer(
|
|
3781
|
+
"uvm_cache_stats",
|
|
3782
|
+
torch.zeros(
|
|
3783
|
+
1,
|
|
3784
|
+
device=self.current_device,
|
|
3785
|
+
dtype=torch.int64,
|
|
3786
|
+
),
|
|
3787
|
+
persistent=False,
|
|
3788
|
+
)
|
|
3789
|
+
self.register_buffer(
|
|
3790
|
+
"local_uvm_cache_stats",
|
|
3791
|
+
torch.zeros(
|
|
3792
|
+
1,
|
|
3793
|
+
device=self.current_device,
|
|
3794
|
+
dtype=torch.int32,
|
|
3795
|
+
),
|
|
3796
|
+
persistent=False,
|
|
3797
|
+
)
|
|
3798
|
+
else:
|
|
3799
|
+
self.register_buffer(
|
|
3800
|
+
"uvm_cache_stats",
|
|
3801
|
+
torch.zeros(
|
|
3802
|
+
size=(self.uvm_cache_stats_size,),
|
|
3803
|
+
device=self.current_device,
|
|
3804
|
+
dtype=torch.int64,
|
|
3805
|
+
),
|
|
3806
|
+
)
|
|
3807
|
+
self.register_buffer(
|
|
3808
|
+
"local_uvm_cache_stats",
|
|
3809
|
+
torch.zeros(
|
|
3810
|
+
size=(self.uvm_cache_stats_size,),
|
|
3811
|
+
device=self.current_device,
|
|
3812
|
+
dtype=torch.int32,
|
|
3813
|
+
),
|
|
3814
|
+
)
|
|
3815
|
+
self.reset_uvm_cache_stats()
|
|
3816
|
+
self.last_uvm_cache_print_state = torch.zeros_like(self.uvm_cache_stats)
|
|
3817
|
+
|
|
3818
|
+
def reset_cache_states(self) -> None:
    """
    Invalidate every LXU cache slot and restart the cache timestep.

    Does nothing when no cache weights are allocated. Otherwise:
    - ``lxu_cache_state`` is filled with -1 (presumably the "empty slot"
      marker);
    - ``lxu_state`` is zeroed;
    - ``timestep`` goes back to 1.

    The cached weight values are neither cleared nor flushed.
    """
    # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
    # a function.
    if self.lxu_cache_weights.numel() > 0:
        self.lxu_cache_state.fill_(-1)
        self.lxu_state.fill_(0)
        self.timestep = 1
|
|
3826
|
+
|
|
3827
|
+
def reset_embedding_weight_momentum(
    self,
    pruned_indices: Tensor,
    pruned_indices_offsets: Tensor,
    logical_table_ids: Tensor,
    buffer_ids: Tensor,
) -> None:
    """
    Forward pruned rows to ``torch.ops.fbgemm.reset_weight_momentum``.

    Only ``OptimType.EXACT_ROWWISE_ADAGRAD`` is handled. The op receives:
    - the weights (device, UVM and cache copies);
    - the ``momentum1`` buffers;
    - the cache state;
    - the pruned rows, moved to ``self.current_device``.

    Presumably the kernel resets the optimizer state of those rows; confirm
    in the kernel. For any other optimizer except ``OptimType.NONE`` this
    method silently does nothing.

    Args:
        pruned_indices (Tensor): Rows to reset.
        pruned_indices_offsets (Tensor): Offsets into ``pruned_indices``
            (presumably one segment per logical table; confirm).
        logical_table_ids (Tensor): Forwarded to the op.
        buffer_ids (Tensor): Forwarded to the op.

    Raises:
        NotImplementedError: If ``self.optimizer`` is ``OptimType.NONE``.
    """
    if self.optimizer == OptimType.NONE:
        raise NotImplementedError(
            f"Resetting embedding weight momentum is not supported for {self.optimizer}"
        )

    # `total_cache_hash_size` is a placeholder buffer (Tensor) when no cache
    # was allocated and a plain number otherwise; pass a Python number.
    hash_size = self.total_cache_hash_size
    if isinstance(hash_size, Tensor):
        hash_size = hash_size.item()

    if self.optimizer not in [OptimType.EXACT_ROWWISE_ADAGRAD]:
        return

    target_device = self.current_device
    torch.ops.fbgemm.reset_weight_momentum(
        dev_weights=self.weights_dev,
        uvm_weights=self.weights_uvm,
        lxu_cache_weights=self.lxu_cache_weights,
        weights_placements=self.weights_placements,
        weights_offsets=self.weights_offsets,
        momentum1_dev=self.momentum1_dev,
        momentum1_uvm=self.momentum1_uvm,
        momentum1_placements=self.momentum1_placements,
        momentum1_offsets=self.momentum1_offsets,
        D_offsets=self.D_offsets,
        pruned_indices=pruned_indices.to(device=target_device),
        pruned_indices_offsets=pruned_indices_offsets.to(device=target_device),
        logical_table_ids=logical_table_ids.to(device=target_device),
        buffer_ids=buffer_ids.to(device=target_device),
        cache_hash_size_cumsum=self.cache_hash_size_cumsum,
        lxu_cache_state=self.lxu_cache_state,
        total_cache_hash_size=hash_size,
    )
|
|
3869
|
+
|
|
3870
|
+
def prepare_inputs(
    self,
    indices: Tensor,
    offsets: Tensor,
    per_sample_weights: Optional[Tensor] = None,
    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
    force_cast_input_types: bool = True,
    prefetch_pipeline: bool = False,
) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
    """
    Prepare TBE inputs as follows:

    (1) Create VBE metadata
    (2) Convert input types if `force_cast_input_types=True`, or if
        `indices` and `offsets` have different dtypes
    (3) Run `bounds_check_indices` if `bounds_check_mode` is not
        BoundsCheckMode.NONE

    Args:
        indices (Tensor): Input indices
        offsets (Tensor): Input offsets
        per_sample_weights (Optional[Tensor]): Input per sample
            weights
        batch_size_per_feature_per_rank
            (Optional[List[List[int]]]): A nested list holding the batch
            size for each feature and rank. Shape = (number of
            features, number of ranks)
        force_cast_input_types (bool): A flag to force convert
            input types if set to True
        prefetch_pipeline (bool): Forwarded to `bounds_check_indices`.
            When True, bounds check v2 is used regardless of
            `self.bounds_check_version`.

    Returns:
        A tuple of indices, offsets, per_sample_weights, and VBE
        metadata
    """

    # Generate VBE metadata
    vbe_metadata = self._generate_vbe_metadata(
        offsets, batch_size_per_feature_per_rank
    )

    vbe = vbe_metadata.B_offsets is not None
    # NOTE: `max_B <= self.info_B_mask` is already checked on the C++ side.
    # TODO: also check it in Python. A plain `assert` breaks PT2 compile with
    # dynamic shapes. `torch._check` would work there, but it needs a
    # constant (not f-string) message or a lambda. Lambdas fail TorchScript,
    # and `torch._check` itself is not supported by TorchScript either.

    # TODO: remove this and add an assert after updating
    # bounds_check_indices to support different indices type and offset
    # type
    force_cast_input_types = (
        indices.dtype != offsets.dtype or force_cast_input_types
    )

    if force_cast_input_types:
        # NOTE: Force offsets to have the same dtype as indices since the
        # kernels assume same dtype. We might need to revisit the assumption
        # of same dtypes in the future.
        if self.embedding_table_index_type == torch.int32:
            self.log(
                "Casting indices to int32 based on embedding_table_index_type input."
            )
            indices = indices.to(torch.int32)
        if self.embedding_table_index_type != self.embedding_table_offset_type:
            self.log(
                f"Force casting offsets to {self.embedding_table_index_type} so that it is the same as the indices type."
            )
        offsets = offsets.to(dtype=indices.dtype)

        # Force casting per_sample_weights to float
        if per_sample_weights is not None:
            per_sample_weights = per_sample_weights.float()

    if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
        # Override the bounds check version based on prefetch_pipeline
        use_bounds_check_v2 = self.bounds_check_version == 2 or prefetch_pipeline
        bounds_check_version = (
            2 if use_bounds_check_v2 else self.bounds_check_version
        )

        # Compute B info and VBE metadata for bounds_check_indices only if
        # VBE and bounds check indices v2 are used
        if vbe and use_bounds_check_v2:
            B_offsets = vbe_metadata.B_offsets
            B_offsets_rank_per_feature = vbe_metadata.B_offsets_rank_per_feature
            output_offsets_feature_rank = vbe_metadata.output_offsets_feature_rank
            assert isinstance(B_offsets, Tensor), "B_offsets must be tensor"
            assert isinstance(
                B_offsets_rank_per_feature, Tensor
            ), "B_offsets_rank_per_feature must be tensor"
            assert isinstance(
                output_offsets_feature_rank, Tensor
            ), "output_offsets_feature_rank must be tensor"

            row_output_offsets, b_t_map = torch.ops.fbgemm.generate_vbe_metadata(
                B_offsets,
                B_offsets_rank_per_feature,
                output_offsets_feature_rank,
                self.D_offsets,
                self.max_D,
                self.is_nobag,
                vbe_metadata.max_B_feature_rank,
                self.info_B_num_bits,
                offsets.numel() - 1,  # total_B
            )
        else:
            b_t_map = None

        torch.ops.fbgemm.bounds_check_indices(
            self.rows_per_table,
            indices,
            offsets,
            self.bounds_check_mode_int,
            self.bounds_check_warning,
            per_sample_weights,
            B_offsets=vbe_metadata.B_offsets,
            max_B=vbe_metadata.max_B,
            b_t_map=b_t_map,
            info_B_num_bits=self.info_B_num_bits,
            info_B_mask=self.info_B_mask,
            bounds_check_version=bounds_check_version,
            prefetch_pipeline=prefetch_pipeline,
        )

    return indices, offsets, per_sample_weights, vbe_metadata
|
|
4001
|
+
|
|
4002
|
+
def _debug_print_input_stats_factory(self) -> Callable[..., None]:
|
|
4003
|
+
"""
|
|
4004
|
+
If the environment variable FBGEMM_DEBUG_PRINT_INPUT_STATS=1,
|
|
4005
|
+
return a function pointer of a function that prints input
|
|
4006
|
+
stats including weighted/unweighted, number of features,
|
|
4007
|
+
batch size, average pooling factor, total number of indices,
|
|
4008
|
+
number of unique indices, and number of indices that goes
|
|
4009
|
+
through the different backward functions. Otherwise, return
|
|
4010
|
+
a dummy function pointer.
|
|
4011
|
+
"""
|
|
4012
|
+
|
|
4013
|
+
@torch.jit.ignore
|
|
4014
|
+
def _debug_print_input_stats_factory_impl(
|
|
4015
|
+
indices: Tensor,
|
|
4016
|
+
offsets: Tensor,
|
|
4017
|
+
per_sample_weights: Optional[Tensor] = None,
|
|
4018
|
+
) -> None:
|
|
4019
|
+
"""
|
|
4020
|
+
Print input stats (for debugging purpose only)
|
|
4021
|
+
|
|
4022
|
+
Args:
|
|
4023
|
+
indices (Tensor): Input indices
|
|
4024
|
+
offsets (Tensor): Input offsets
|
|
4025
|
+
per_sample_weights (Optional[Tensor]): Input per
|
|
4026
|
+
sample weights
|
|
4027
|
+
"""
|
|
4028
|
+
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
|
4029
|
+
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
|
|
4030
|
+
if self.debug_step % 100 == 0:
|
|
4031
|
+
# Get number of features (T) and batch size (B)
|
|
4032
|
+
T = len(self.feature_table_map)
|
|
4033
|
+
B = (offsets.numel() - 1) // T
|
|
4034
|
+
|
|
4035
|
+
# Transfer hash_size_cumsum, indices and offsets to CPU
|
|
4036
|
+
hash_size_cumsum_cpu = self.hash_size_cumsum.cpu()
|
|
4037
|
+
indices_cpu = indices.cpu()
|
|
4038
|
+
offsets_cpu = offsets.cpu()
|
|
4039
|
+
|
|
4040
|
+
# Compute linear indices
|
|
4041
|
+
for t in range(T):
|
|
4042
|
+
start = offsets_cpu[B * t].item()
|
|
4043
|
+
end = offsets_cpu[B * (t + 1)].item()
|
|
4044
|
+
indices_cpu[start:end] += hash_size_cumsum_cpu[t]
|
|
4045
|
+
|
|
4046
|
+
# Compute unique indices
|
|
4047
|
+
uniq_indices_cpu, counts = indices_cpu.unique(return_counts=True)
|
|
4048
|
+
|
|
4049
|
+
# Compute num unique indices
|
|
4050
|
+
num_uniq_indices = uniq_indices_cpu.numel()
|
|
4051
|
+
|
|
4052
|
+
# The warp_per_row kernel handles indices that their
|
|
4053
|
+
# segment lengths <= 32
|
|
4054
|
+
#
|
|
4055
|
+
# The cta_per_row kernel handles indices that their
|
|
4056
|
+
# segment lengths > 32. A single thread block is used
|
|
4057
|
+
# if segment lengths <= 1024. Otherwise, multiple
|
|
4058
|
+
# thread blocks are used.
|
|
4059
|
+
#
|
|
4060
|
+
# Counts of indices that segment lengths <= 32
|
|
4061
|
+
counts_warp_per_row = counts[counts <= 32]
|
|
4062
|
+
counts_cta_per_row = counts[counts > 32]
|
|
4063
|
+
# Counts of indices that segment lengths > 32 and <= 1024
|
|
4064
|
+
counts_cta_per_row_sth = counts_cta_per_row[counts_cta_per_row <= 1024]
|
|
4065
|
+
# Counts of indices that segment lengths > 1024
|
|
4066
|
+
counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024]
|
|
4067
|
+
|
|
4068
|
+
def compute_numel_and_avg(counts: Tensor) -> tuple[int, float]:
|
|
4069
|
+
numel = counts.numel()
|
|
4070
|
+
avg = (counts.sum().item() / numel) if numel != 0 else -1.0
|
|
4071
|
+
return numel, avg
|
|
4072
|
+
|
|
4073
|
+
# warp_per_row stats
|
|
4074
|
+
num_warp_per_row, avg_seglen_warp_per_row = compute_numel_and_avg(
|
|
4075
|
+
counts_warp_per_row
|
|
4076
|
+
)
|
|
4077
|
+
# cta_per_row using a single thread block stats
|
|
4078
|
+
num_cta_per_row_sth, avg_seglen_cta_per_row_sth = compute_numel_and_avg(
|
|
4079
|
+
counts_cta_per_row_sth
|
|
4080
|
+
)
|
|
4081
|
+
# cta_per_row using multiple thread block stats
|
|
4082
|
+
num_cta_per_row_mth, avg_seglen_cta_per_row_mth = compute_numel_and_avg(
|
|
4083
|
+
counts_cta_per_row_mth
|
|
4084
|
+
)
|
|
4085
|
+
|
|
4086
|
+
assert num_uniq_indices == (
|
|
4087
|
+
num_warp_per_row + num_cta_per_row_sth + num_cta_per_row_mth
|
|
4088
|
+
)
|
|
4089
|
+
|
|
4090
|
+
self.log(
|
|
4091
|
+
"TBE_DEBUG: "
|
|
4092
|
+
"weighted {} "
|
|
4093
|
+
"num features {} "
|
|
4094
|
+
"batch size {} "
|
|
4095
|
+
"avg pooling factor {:.2f} "
|
|
4096
|
+
"total num indices {} "
|
|
4097
|
+
"num unique indices {} "
|
|
4098
|
+
"num warp_per_row {} (avg segment length {:.2f}) "
|
|
4099
|
+
"num cta_per_row single thread block (avg segment length) {} ({:.2f}) "
|
|
4100
|
+
"num cta_per_row multiple thread blocks (avg segment length) {} ({:.2f})".format(
|
|
4101
|
+
per_sample_weights is not None,
|
|
4102
|
+
T,
|
|
4103
|
+
B,
|
|
4104
|
+
indices.numel() / (B * T),
|
|
4105
|
+
indices.numel(),
|
|
4106
|
+
num_uniq_indices,
|
|
4107
|
+
num_warp_per_row,
|
|
4108
|
+
avg_seglen_warp_per_row,
|
|
4109
|
+
num_cta_per_row_sth,
|
|
4110
|
+
avg_seglen_cta_per_row_sth,
|
|
4111
|
+
num_cta_per_row_mth,
|
|
4112
|
+
avg_seglen_cta_per_row_mth,
|
|
4113
|
+
)
|
|
4114
|
+
)
|
|
4115
|
+
# pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no
|
|
4116
|
+
# attribute `debug_step`.
|
|
4117
|
+
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
|
4118
|
+
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
|
|
4119
|
+
self.debug_step += 1
|
|
4120
|
+
|
|
4121
|
+
@torch.jit.ignore
|
|
4122
|
+
def _debug_print_input_stats_factory_null(
|
|
4123
|
+
indices: Tensor,
|
|
4124
|
+
offsets: Tensor,
|
|
4125
|
+
per_sample_weights: Optional[Tensor] = None,
|
|
4126
|
+
) -> None:
|
|
4127
|
+
pass
|
|
4128
|
+
|
|
4129
|
+
if int(os.environ.get("FBGEMM_DEBUG_PRINT_INPUT_STATS", "0")) == 1:
|
|
4130
|
+
# pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no
|
|
4131
|
+
# attribute `debug_step`.
|
|
4132
|
+
self.debug_step = 0
|
|
4133
|
+
return _debug_print_input_stats_factory_impl
|
|
4134
|
+
return _debug_print_input_stats_factory_null
|
|
4135
|
+
|
|
4136
|
+
@torch.jit.ignore
def raw_embedding_stream(self) -> None:
    """
    Stream cached embedding rows of an earlier iteration's prefetched ids.

    No-op unless ``enable_raw_embedding_streaming`` is set and enough
    prefetched iterations are queued. Otherwise:
    - pop the oldest ``prefetched_info_list`` entry;
    - look its unique cache indices up in the current cache state and
      gather the matching rows of ``lxu_cache_weights``;
    - keep only cache hits (location != -1);
    - hand the ids, weights, optional ZCH identities / runtime metadata and
      the unique count, all on CPU, to ``self._raw_embedding_streamer.stream``.
    """
    if not self.enable_raw_embedding_streaming:
        return None
    # With pipelining, prefetch of iter i runs before the sparse backward of
    # iter i - 1, so iter i - 1's changed ids are not updated yet; only iter
    # i - 2 is safe. Without pipelining, prefetch of iter i runs before
    # forward of iter i, so iter i - 1's changed ids are already safe.
    required_lag = 2 if self.prefetch_pipeline else 1
    if len(self.prefetched_info_list) < required_lag:
        return None
    with record_function(
        "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
    ):
        info = self.prefetched_info_list.pop(0)
        unique_count = info.linear_unique_indices_length
        cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
            info.linear_unique_cache_indices,
            self.lxu_cache_state,
            self.total_cache_hash_size,
            gather_cache_stats=False,  # not collecting cache stats
            num_uniq_cache_indices=unique_count,
        )
        cached_rows = torch.empty(
            [
                info.linear_unique_cache_indices.size()[0],
                self.max_D_cache,
            ],
            # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
            dtype=self.lxu_cache_weights.dtype,
            # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
            device=self.lxu_cache_weights.device,
        )
        torch.ops.fbgemm.masked_index_select(
            cached_rows,
            cache_locations,
            self.lxu_cache_weights,
            unique_count,
        )
        # TODO: this statement triggers a sync
        # added here to make this diff self-contained
        # will remove in later change
        hit_positions = (
            cache_locations.narrow(0, 0, unique_count.item())
            .not_equal(-1)
            .nonzero()
            .flatten()
        )
        cpu_device = torch.device("cpu")

        def _hits_to_cpu(source: torch.Tensor) -> torch.Tensor:
            # Keep only the cache-hit rows and move them to the host.
            return source.index_select(dim=0, index=hit_positions).to(
                device=cpu_device
            )

        # stream weights
        self._raw_embedding_streamer.stream(
            _hits_to_cpu(info.linear_unique_indices),
            _hits_to_cpu(cached_rows),
            (
                _hits_to_cpu(info.hash_zch_identities)
                if info.hash_zch_identities is not None
                else None
            ),
            (
                _hits_to_cpu(info.hash_zch_runtime_meta)
                if info.hash_zch_runtime_meta is not None
                else None
            ),
            unique_count.to(device=cpu_device),
            False,  # require_tensor_copy
            False,  # blocking_tensor_copy
        )
|
|
4218
|
+
|
|
4219
|
+
@staticmethod
@torch.jit.ignore
def _get_prefetched_info(
    linear_indices: torch.Tensor,
    linear_cache_indices_merged: torch.Tensor,
    total_cache_hash_size: int,
    hash_zch_identities: Optional[torch.Tensor],
    hash_zch_runtime_meta: Optional[torch.Tensor],
    max_indices_length: int,
) -> PrefetchedInfo:
    """
    Deduplicate ``linear_cache_indices_merged`` and gather, for every unique
    entry, the values found at its first occurrence.

    Args:
        linear_indices: gathered at the first-occurrence position of each
            unique cache index. Presumably aligned element-wise with
            ``linear_cache_indices_merged`` (the in-file caller builds both
            from the same lookup) -- confirm for any other caller.
        linear_cache_indices_merged: cache indices to deduplicate.
        total_cache_hash_size: forwarded to
            ``torch.ops.fbgemm.get_unique_indices_with_inverse``.
        hash_zch_identities: optional; gathered with the same positions as
            ``linear_indices``.
        hash_zch_runtime_meta: optional; gathered likewise.
        max_indices_length: upper bound on how many unique entries are kept.

    Returns:
        ``PrefetchedInfo(unique linear indices, unique cache indices,
        unique-length tensor, identities or None, runtime meta or None)``.
        Slots beyond the value held in the unique-length tensor are padding
        and are meant to be sliced away by consumers using that length.
    """
    (
        uniq_cache_idx,
        uniq_len,
        uniq_cnt,
        inverse_idx,
    ) = torch.ops.fbgemm.get_unique_indices_with_inverse(
        linear_cache_indices_merged,
        total_cache_hash_size,
        compute_count=True,
        compute_inverse_indices=True,
    )

    # Cap the unique set so it cannot exceed the weights buffer. Both operands
    # are host-side sizes, so no device sync is needed here.
    n_uniq_alloc = uniq_cache_idx.size(0)
    keep = min(max_indices_length, n_uniq_alloc)
    if keep != n_uniq_alloc:
        uniq_len.clamp_(max=keep)
        # Deduplicated (and sorted) cache indices, truncated to the cap.
        uniq_cache_idx = uniq_cache_idx.narrow(0, 0, keep)

    # Cumulative sum of the per-unique counts (one element longer than the
    # counts); its leading entries locate the first occurrence of each unique
    # value. Only the first ``keep`` of them are needed.
    first_occ = torch.ops.fbgemm.asynchronous_complete_cumsum(uniq_cnt).narrow(
        0, 0, keep
    )
    # Slots past the real unique length hold uninitialized data. Clamping them
    # into range prevents out-of-bounds reads without reading ``uniq_len`` on
    # the host (which would force a sync); those slots are discarded later.
    first_occ.clamp_(min=0, max=inverse_idx.size(0) - 1)

    # Positions (into the inputs) of those first occurrences, clamped for the
    # same reason as above.
    src_pos = inverse_idx.index_select(dim=0, index=first_occ)
    src_pos.clamp_(min=0, max=linear_indices.size(0) - 1)

    uniq_linear_idx = linear_indices.index_select(dim=0, index=src_pos)

    # Optional per-lookup metadata, mapped onto the unique entries.
    identities = (
        None
        if hash_zch_identities is None
        else hash_zch_identities.index_select(dim=0, index=src_pos)
    )
    runtime_meta = (
        None
        if hash_zch_runtime_meta is None
        else hash_zch_runtime_meta.index_select(dim=0, index=src_pos)
    )

    return PrefetchedInfo(
        uniq_linear_idx,
        uniq_cache_idx,
        uniq_len,
        identities,
        runtime_meta,
    )
|
|
4291
|
+
|
|
4292
|
+
@torch.jit.ignore
def _store_prefetched_tensors(
    self,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    vbe_metadata: Optional[invokers.lookup_args.VBEMetadata],
    linear_cache_indices_merged: torch.Tensor,
    final_lxu_cache_locations: torch.Tensor,
    hash_zch_identities: Optional[torch.Tensor],
    hash_zch_runtime_meta: Optional[torch.Tensor],
) -> None:
    """
    Record the rows prefetched into the cache so raw embedding streaming can
    later push them out; a no-op unless ``enable_raw_embedding_streaming``.

    NOTE: kept as a ``torch.jit.ignore`` method because the identities tensor
    is optional (conditional).

    The resulting ``PrefetchedInfo`` is appended to
    ``self.prefetched_info_list``.
    """
    if not self.enable_raw_embedding_streaming:
        return

    with record_function(
        f"## uvm_save_prefetched_rows {self.timestep} {self.uuid} ##"
    ):
        in_cache = final_lxu_cache_locations != -1

        # Only lookups that hit the cache are kept; the rest (e.g. from tables
        # without UVM_CACHE) are replaced by ``total_cache_hash_size``.
        cache_idx_masked = torch.where(
            in_cache,
            linear_cache_indices_merged,
            self.total_cache_hash_size,
        )

        if vbe_metadata is None:
            b_offsets, max_b = None, -1
        else:
            b_offsets, max_b = vbe_metadata.B_offsets, vbe_metadata.max_B
        linear_idx = torch.ops.fbgemm.linearize_cache_indices(
            self.hash_size_cumsum,
            indices,
            offsets,
            b_offsets,
            max_b,
        )
        # raw_embedding_streamer ignores -1 entries.
        linear_idx_masked = torch.where(in_cache, linear_idx, -1)

        self.prefetched_info_list.append(
            self._get_prefetched_info(
                linear_idx_masked,
                cache_idx_masked,
                self.total_cache_hash_size,
                hash_zch_identities,
                hash_zch_runtime_meta,
                # cap the unique rows at the number of cache rows
                self.lxu_cache_weights.size(0),
            )
        )
|
|
4345
|
+
|
|
4346
|
+
@torch.jit.ignore
def __report_input_params_factory(
    self,
) -> Optional[Callable[..., None]]:
    """
    Return the callback used to report TBE input parameters, or ``None``.

    When the ``TBE_REPORT_INPUT_PARAMS`` feature gate is enabled, this creates
    a ``TBEBenchmarkParamsReporter`` and returns its bound ``report_stats``
    method, which reports the input parameters (TBEDataConfig). How often and
    where it reports is decided by the reporter itself (presumably configured
    via ``FBGEMM_REPORT_INPUT_PARAMS_INTERVAL``; not visible here).

    Returns ``None`` (not a no-op callable) when the gate is disabled, or when
    checking the gate or creating the reporter raises. Failures are logged
    rather than swallowed silently, but reporting stays best-effort and never
    propagates an exception to the caller.
    """
    try:
        if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
            from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter

            reporter = TBEBenchmarkParamsReporter.create()
            return reporter.report_stats
    except Exception:
        # Best-effort feature: keep the module usable, but leave a trace of
        # why reporting is disabled instead of failing silently.
        logging.warning(
            "Failed to set up TBE input params reporting; reporting is disabled",
            exc_info=True,
        )
        return None

    return None
|
|
4369
|
+
|
|
4370
|
+
|
|
4371
|
+
class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
    """
    Table-batched version of nn.EmbeddingBag(sparse=False)

    The rows of every table are stored back to back in a single flat dense
    ``weights`` parameter, and ``forward`` looks up all features with one
    call to ``torch.ops.fbgemm.dense_embedding_codegen_lookup_function``.
    Several features may share one table via ``feature_table_map``.
    """

    # Flat 1-D parameter holding every table's rows back to back.
    weights: Tensor
    # Per-feature start offset (in elements) of its table inside ``weights``.
    weights_offsets: Tensor
    # Prefix sum of per-feature embedding dims; length is (#features + 1).
    D_offsets: Tensor
    # Sum of all per-feature embedding dims.
    total_D: int
    # Largest embedding dim across the tables.
    max_D: int
    # Per-feature starting row of its table, followed by the total row count.
    hash_size_cumsum: Tensor
    # int(log2(total rows)) + 1, or 0 when the tables have no rows.
    total_hash_size_bits: int
    # (rows, dims) for each physical table, as passed to __init__.
    embedding_specs: list[tuple[int, int]]
|
|
4384
|
+
|
|
4385
|
+
def __init__(
    self,
    embedding_specs: list[tuple[int, int]],  # tuple of (rows, dims)
    feature_table_map: Optional[list[int]] = None,  # [T]
    weights_precision: SparseType = SparseType.FP32,
    pooling_mode: PoolingMode = PoolingMode.SUM,
    use_cpu: bool = False,
    output_dtype: SparseType = SparseType.FP32,
    use_mtia: bool = False,
) -> None:  # noqa C901 # tuple of (rows, dims,)
    """
    Allocate the flat weights parameter and the per-feature offset buffers.

    Args:
        embedding_specs: one ``(rows, dims)`` tuple per physical table; must
            be non-empty.
        feature_table_map: for each feature, the index into
            ``embedding_specs`` of the table it looks up. Defaults to the
            identity mapping (one feature per table).
        weights_precision: precision of the weights; its ``as_dtype()`` is
            the dtype of the ``weights`` parameter.
        pooling_mode: pooling passed to the lookup op.
        use_cpu: allocate everything on CPU.
        output_dtype: output dtype passed to the lookup op (as an int). On
            CPU, or with ``PoolingMode.NONE``, it must be FP32, FP16 or BF16.
        use_mtia: allocate on the current MTIA device. Mutually exclusive
            with ``use_cpu``; when neither is set, the current CUDA device
            is used.

    Raises:
        AssertionError: on any of the validation checks below.
    """
    super(DenseTableBatchedEmbeddingBagsCodegen, self).__init__()
    # Per-instance id; self.log() prefixes every message with it.
    self.uuid = str(uuid.uuid4())

    self.log(
        f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
    )

    self.pooling_mode = pooling_mode
    self.weights_precision = weights_precision
    self.output_dtype: int = output_dtype.as_int()
    table_embedding_dtype = weights_precision.as_dtype()

    self.use_cpu: bool = use_cpu
    self.use_mtia: bool = use_mtia

    assert not (use_cpu and use_mtia), "Cannot use CPU and MTIA at the same time"

    # On CPU, or without pooling, only float output dtypes are accepted (per
    # the message, fused quantization of the pooled output is CUDA-only).
    if self.use_cpu or self.pooling_mode == PoolingMode.NONE:
        assert output_dtype in [
            SparseType.FP32,
            SparseType.FP16,
            SparseType.BF16,
        ], "Fused pooled embedding quantization only supported for cuda."

    # CPU, the current MTIA device, or the current CUDA device. For CUDA,
    # torch.cuda.current_device() yields an int index rather than a
    # torch.device, which is what the pyre-fixme below refers to.
    # pyre-fixme[8]: Attribute has type `device`; used as `Union[int, device]`.
    self.current_device: torch.device = (
        torch.device("cpu")
        if self.use_cpu
        else (
            torch.device(f"mtia:{torch.mtia.current_device()}")
            if self.use_mtia
            else torch.cuda.current_device()
        )
    )

    self.embedding_specs = embedding_specs
    (rows, dims) = zip(*embedding_specs)
    T_ = len(self.embedding_specs)  # number of physical tables
    assert T_ > 0

    feature_table_map = (
        feature_table_map if feature_table_map is not None else list(range(T_))
    )
    T = len(feature_table_map)  # number of features
    assert T_ <= T

    # Embedding dim of each feature (taken from its table) and their prefix
    # sum; presumably feature f occupies columns [D_offsets[f], D_offsets[f+1])
    # of the pooled output -- confirm against the lookup op.
    feature_dims = [dims[t] for t in feature_table_map]
    D_offsets = [0] + list(accumulate(feature_dims))
    self.total_D = D_offsets[-1]
    self.max_D = max(dims)
    self.register_buffer(
        "D_offsets",
        torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
    )
    assert self.D_offsets.numel() == T + 1

    # Required for VBE
    # Created on CPU; _generate_vbe_metadata re-materializes it on CPU before
    # each use in case the module was moved to another device.
    self.register_buffer(
        "feature_dims",
        torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
    )

    # Prefix sum of table row counts (per physical table at this point).
    hash_size_cumsum = [0] + list(accumulate(rows))
    if hash_size_cumsum[-1] == 0:
        self.total_hash_size_bits: int = 0
    else:
        self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
    # Re-index per feature: entry f is the starting row of feature f's table,
    # and the appended last element is the total row count.
    # The last element is to easily access # of rows of each table by
    # hash_size_cumsum[t + 1] - hash_size_cumsum[t]
    # NOTE(review): after this per-feature re-indexing, that difference equals
    # a table's row count only when consecutive features map to consecutive
    # tables; shared or reordered tables break it -- confirm consumers.
    hash_size_cumsum = [hash_size_cumsum[t] for t in feature_table_map] + [
        hash_size_cumsum[-1]
    ]
    self.register_buffer(
        "hash_size_cumsum",
        torch.tensor(
            hash_size_cumsum, device=self.current_device, dtype=torch.int64
        ),
    )
    # Element offset of each physical table inside the flat weights.
    weights_offsets = [0] + list(
        accumulate([row * dim for (row, dim) in embedding_specs])
    )
    # All tables in one flat 1-D parameter, initialized with torch.randn.
    self.weights = nn.Parameter(
        torch.randn(
            weights_offsets[-1],
            device=self.current_device,
            dtype=table_embedding_dtype,
        )
    )
    # Sanity checks per feature: its table slice has rows * dim elements and
    # its hash_size_cumsum entry equals the rows of all preceding tables.
    for feature in range(T):
        t = feature_table_map[feature]
        row, dim = embedding_specs[t]
        if (
            self.weights[weights_offsets[t] : weights_offsets[t + 1]].numel()
            != row * dim
        ):
            self.log(
                f"row {row} dim {dim} feature {feature} t {t} {self.weights[weights_offsets[t] : weights_offsets[t + 1]].numel()}"
            )
        assert (
            self.weights[weights_offsets[t] : weights_offsets[t + 1]].numel()
            == row * dim
        )
        assert self.hash_size_cumsum[feature] == sum(
            row for (row, _) in embedding_specs[:t]
        )

    # Per-table offsets are kept for split_embedding_weights(); the buffer
    # below holds per-feature offsets (features sharing a table share one).
    self.weights_physical_offsets: list[int] = weights_offsets
    weights_offsets = [weights_offsets[t] for t in feature_table_map]
    self.register_buffer(
        "weights_offsets",
        torch.tensor(
            weights_offsets, device=self.current_device, dtype=torch.int64
        ),
    )
|
|
4509
|
+
|
|
4510
|
+
@torch.jit.ignore
def log(self, msg: str) -> None:
    """
    Write ``msg`` to the root logger at INFO level, tagged with this
    instance's uuid.

    The ``[TBE=<uuid>]`` prefix keeps messages from several TBE modules in
    the same process apart.

    Args:
        msg (str): Text to log.

    Returns:
        None
    """
    tagged = "[TBE={}] {}".format(self.uuid, msg)
    logging.info(tagged)
|
|
4523
|
+
|
|
4524
|
+
@torch.jit.ignore
def _generate_vbe_metadata(
    self,
    offsets: Tensor,
    batch_size_per_feature_per_rank: Optional[list[list[int]]],
) -> invokers.lookup_args.VBEMetadata:
    """
    Build the variable-batch-size (VBE) metadata for one forward call.

    ``feature_dims`` is first brought back to CPU and stored. If the module
    was moved to a device this is a blocking device-to-host copy, but only
    the first time: afterwards the buffer already lives on CPU and ``.cpu()``
    returns it unchanged.
    """
    host_feature_dims = self.feature_dims.cpu()
    self.feature_dims = host_feature_dims
    return generate_vbe_metadata(
        offsets,
        batch_size_per_feature_per_rank,
        self.pooling_mode,
        host_feature_dims,
        self.current_device,
    )
|
|
4539
|
+
|
|
4540
|
+
def forward(
    self,
    indices: Tensor,
    offsets: Tensor,
    per_sample_weights: Optional[Tensor] = None,
    feature_requires_grad: Optional[Tensor] = None,
    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
) -> Tensor:
    """
    Look up (and pool, per ``pooling_mode``) the embeddings of all features
    with a single fused op call.

    Args:
        indices: row indices for all bags of all features.
        offsets: bag boundaries into ``indices``; cast to ``indices.dtype``
            before the lookup.
        per_sample_weights: optional per-index weights, cast to float32.
        feature_requires_grad: forwarded unchanged to the lookup op.
        batch_size_per_feature_per_rank: optional per-feature, per-rank batch
            sizes used to build the VBE metadata.

    Returns:
        The tensor produced by
        ``torch.ops.fbgemm.dense_embedding_codegen_lookup_function``.
    """
    # Built from the offsets exactly as the caller passed them (pre-cast).
    vbe = self._generate_vbe_metadata(offsets, batch_size_per_feature_per_rank)

    # NOTE: the kernels assume offsets and indices share a dtype; that
    # assumption may need to be revisited in the future.
    offsets = offsets.to(dtype=indices.dtype)

    # The op is always handed float32 per-sample weights.
    if per_sample_weights is not None:
        per_sample_weights = per_sample_weights.float()

    return torch.ops.fbgemm.dense_embedding_codegen_lookup_function(
        # table layout
        dev_weights=self.weights,
        weights_offsets=self.weights_offsets,
        D_offsets=self.D_offsets,
        total_D=self.total_D,
        max_D=self.max_D,
        hash_size_cumsum=self.hash_size_cumsum,
        total_hash_size_bits=self.total_hash_size_bits,
        # lookup inputs
        indices=indices,
        offsets=offsets,
        pooling_mode=self.pooling_mode,
        indice_weights=per_sample_weights,
        feature_requires_grad=feature_requires_grad,
        output_dtype=self.output_dtype,
        # variable-batch-size metadata
        B_offsets=vbe.B_offsets,
        vbe_output_offsets_feature_rank=vbe.output_offsets_feature_rank,
        vbe_B_offsets_rank_per_feature=vbe.B_offsets_rank_per_feature,
        max_B=vbe.max_B,
        max_B_feature_rank=vbe.max_B_feature_rank,
        vbe_output_size=vbe.output_size,
    )
|
|
4583
|
+
|
|
4584
|
+
@torch.jit.export
def split_embedding_weights(self) -> list[Tensor]:
    """
    Return one ``(rows, dim)`` view per table into the flat weights.

    The views come from ``self.weights.detach()``. They share storage with
    the parameter, so in-place writes update it, but they do not record
    autograd history.
    """
    flat = self.weights.detach()
    tables: list[Tensor] = []
    for t, (num_rows, emb_dim) in enumerate(self.embedding_specs):
        start = self.weights_physical_offsets[t]
        stop = start + num_rows * emb_dim
        tables.append(flat[start:stop].view(num_rows, emb_dim))
    return tables
|
|
4596
|
+
|
|
4597
|
+
def init_embedding_weights_uniform(self, min_val: float, max_val: float) -> None:
    """
    Re-initialize every table in place with values drawn uniformly from
    ``[min_val, max_val)``.

    Writes go through the storage-sharing views returned by
    ``split_embedding_weights``, so ``self.weights`` is updated directly.
    """
    for table_view in self.split_embedding_weights():
        table_view.uniform_(min_val, max_val)
|