fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
|
@@ -11,12 +11,11 @@
|
|
|
11
11
|
|
|
12
12
|
import enum
|
|
13
13
|
from dataclasses import dataclass
|
|
14
|
-
from typing import
|
|
14
|
+
from typing import FrozenSet, NamedTuple, Optional, Tuple
|
|
15
15
|
|
|
16
16
|
import torch
|
|
17
17
|
from torch import Tensor
|
|
18
18
|
|
|
19
|
-
|
|
20
19
|
# Maximum number of times prefetch() can be called without
|
|
21
20
|
# a corresponding forward() call
|
|
22
21
|
MAX_PREFETCH_DEPTH = 100
|
|
@@ -33,6 +32,17 @@ class EmbeddingLocation(enum.IntEnum):
|
|
|
33
32
|
HOST = 3
|
|
34
33
|
MTIA = 4
|
|
35
34
|
|
|
35
|
+
@classmethod
|
|
36
|
+
# pyre-ignore[3]
|
|
37
|
+
def str_values(cls):
|
|
38
|
+
return [
|
|
39
|
+
"device",
|
|
40
|
+
"managed",
|
|
41
|
+
"managed_caching",
|
|
42
|
+
"host",
|
|
43
|
+
"mtia",
|
|
44
|
+
]
|
|
45
|
+
|
|
36
46
|
@classmethod
|
|
37
47
|
# pyre-ignore[3]
|
|
38
48
|
def from_str(cls, key: str):
|
|
@@ -49,6 +59,246 @@ class EmbeddingLocation(enum.IntEnum):
|
|
|
49
59
|
raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
|
|
50
60
|
|
|
51
61
|
|
|
62
|
+
class EvictionPolicy(NamedTuple):
|
|
63
|
+
eviction_trigger_mode: int = (
|
|
64
|
+
0 # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual 4: id count
|
|
65
|
+
)
|
|
66
|
+
eviction_strategy: int = (
|
|
67
|
+
0 # 0: timestamp, 1: counter , 2: counter + timestamp, 3: feature l2 norm 4: timestamp threshold 5: feature score
|
|
68
|
+
)
|
|
69
|
+
eviction_step_intervals: Optional[int] = (
|
|
70
|
+
None # trigger_step_interval if trigger mode is iteration
|
|
71
|
+
)
|
|
72
|
+
eviction_mem_threshold_gb: Optional[int] = (
|
|
73
|
+
None # eviction trigger condition if trigger mode is mem_util
|
|
74
|
+
)
|
|
75
|
+
counter_thresholds: Optional[list[int]] = (
|
|
76
|
+
None # count_thresholds for each table if eviction strategy is counter
|
|
77
|
+
)
|
|
78
|
+
ttls_in_mins: Optional[list[int]] = (
|
|
79
|
+
None # ttls_in_mins for each table if eviction strategy is timestamp
|
|
80
|
+
)
|
|
81
|
+
counter_decay_rates: Optional[list[float]] = (
|
|
82
|
+
None # count_decay_rates for each table if eviction strategy is counter
|
|
83
|
+
)
|
|
84
|
+
feature_score_counter_decay_rates: Optional[list[float]] = (
|
|
85
|
+
None # feature_score_counter_decay_rates for each table if eviction strategy is feature score
|
|
86
|
+
)
|
|
87
|
+
training_id_eviction_trigger_count: Optional[list[int]] = (
|
|
88
|
+
None # Number of training IDs that, when exceeded, will trigger eviction for each table.
|
|
89
|
+
)
|
|
90
|
+
training_id_keep_count: Optional[list[int]] = (
|
|
91
|
+
None # Target number of training IDs to retain in each table after eviction.
|
|
92
|
+
)
|
|
93
|
+
l2_weight_thresholds: Optional[list[float]] = (
|
|
94
|
+
None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
|
|
95
|
+
)
|
|
96
|
+
threshold_calculation_bucket_stride: Optional[float] = (
|
|
97
|
+
0.2 # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
|
|
98
|
+
)
|
|
99
|
+
threshold_calculation_bucket_num: Optional[int] = (
|
|
100
|
+
1000000 # 1M, Total number of feature score buckets used for threshold calculation in feature score-based eviction.
|
|
101
|
+
)
|
|
102
|
+
interval_for_insufficient_eviction_s: int = (
|
|
103
|
+
# wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
|
|
104
|
+
# insufficient means we didn't evict enough rows, so we want to wait longer time to
|
|
105
|
+
# avoid another insufficient eviction
|
|
106
|
+
600
|
|
107
|
+
)
|
|
108
|
+
interval_for_sufficient_eviction_s: int = (
|
|
109
|
+
# wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
|
|
110
|
+
60
|
|
111
|
+
)
|
|
112
|
+
interval_for_feature_statistics_decay_s: int = (
|
|
113
|
+
24 * 3600 # 1 day, interval for feature statistics decay
|
|
114
|
+
)
|
|
115
|
+
meta_header_lens: Optional[list[int]] = None # metaheader length for each table
|
|
116
|
+
eviction_free_mem_threshold_gb: Optional[int] = (
|
|
117
|
+
None # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
|
|
118
|
+
)
|
|
119
|
+
eviction_free_mem_check_interval_batch: Optional[int] = (
|
|
120
|
+
None # Number of batches between checks for free memory threshold when using free_mem trigger mode.
|
|
121
|
+
)
|
|
122
|
+
enable_eviction_for_feature_score_eviction_policy: Optional[list[bool]] = (
|
|
123
|
+
None # enable eviction if eviction policy is feature score, false means no eviction
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
def validate(self) -> None:
|
|
127
|
+
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
|
|
128
|
+
"eviction_trigger_mode must be 0, 1, 2, 3, 4, 5"
|
|
129
|
+
f"actual {self.eviction_trigger_mode}"
|
|
130
|
+
)
|
|
131
|
+
if self.eviction_trigger_mode == 0:
|
|
132
|
+
return
|
|
133
|
+
|
|
134
|
+
assert self.eviction_strategy in [0, 1, 2, 3, 4, 5], (
|
|
135
|
+
"eviction_strategy must be 0, 1, 2, 3, 4 or 5, "
|
|
136
|
+
f"actual {self.eviction_strategy}"
|
|
137
|
+
)
|
|
138
|
+
if self.eviction_trigger_mode == 1:
|
|
139
|
+
assert (
|
|
140
|
+
self.eviction_step_intervals is not None
|
|
141
|
+
and self.eviction_step_intervals > 0
|
|
142
|
+
), (
|
|
143
|
+
"eviction_step_intervals must be positive if eviction_trigger_mode is 1, "
|
|
144
|
+
f"actual {self.eviction_step_intervals}"
|
|
145
|
+
)
|
|
146
|
+
elif self.eviction_trigger_mode == 2:
|
|
147
|
+
assert (
|
|
148
|
+
self.eviction_mem_threshold_gb is not None
|
|
149
|
+
), "eviction_mem_threshold_gb must be set if eviction_trigger_mode is 2"
|
|
150
|
+
elif self.eviction_trigger_mode == 4:
|
|
151
|
+
assert (
|
|
152
|
+
self.training_id_eviction_trigger_count is not None
|
|
153
|
+
), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4"
|
|
154
|
+
elif self.eviction_trigger_mode == 5:
|
|
155
|
+
assert (
|
|
156
|
+
self.eviction_free_mem_threshold_gb is not None
|
|
157
|
+
), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5"
|
|
158
|
+
assert (
|
|
159
|
+
self.eviction_free_mem_check_interval_batch is not None
|
|
160
|
+
), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5"
|
|
161
|
+
|
|
162
|
+
if self.eviction_strategy == 0:
|
|
163
|
+
assert self.ttls_in_mins is not None, (
|
|
164
|
+
"ttls_in_mins must be set if eviction_strategy is 0, "
|
|
165
|
+
f"actual {self.ttls_in_mins}"
|
|
166
|
+
)
|
|
167
|
+
elif self.eviction_strategy == 1:
|
|
168
|
+
assert self.counter_thresholds is not None, (
|
|
169
|
+
"counter_thresholds must be set if eviction_strategy is 1, "
|
|
170
|
+
f"actual {self.counter_thresholds}"
|
|
171
|
+
)
|
|
172
|
+
assert self.counter_decay_rates is not None, (
|
|
173
|
+
"counter_decay_rates must be set if eviction_strategy is 1, "
|
|
174
|
+
f"actual {self.counter_decay_rates}"
|
|
175
|
+
)
|
|
176
|
+
assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
|
|
177
|
+
"counter_thresholds and counter_decay_rates must have the same length, "
|
|
178
|
+
f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
|
|
179
|
+
)
|
|
180
|
+
elif self.eviction_strategy == 2:
|
|
181
|
+
assert self.counter_thresholds is not None, (
|
|
182
|
+
"counter_thresholds must be set if eviction_strategy is 2, "
|
|
183
|
+
f"actual {self.counter_thresholds}"
|
|
184
|
+
)
|
|
185
|
+
assert self.counter_decay_rates is not None, (
|
|
186
|
+
"counter_decay_rates must be set if eviction_strategy is 2, "
|
|
187
|
+
f"actual {self.counter_decay_rates}"
|
|
188
|
+
)
|
|
189
|
+
assert self.ttls_in_mins is not None, (
|
|
190
|
+
"ttls_in_mins must be set if eviction_strategy is 2, "
|
|
191
|
+
f"actual {self.ttls_in_mins}"
|
|
192
|
+
)
|
|
193
|
+
assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
|
|
194
|
+
"counter_thresholds and counter_decay_rates must have the same length, "
|
|
195
|
+
f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
|
|
196
|
+
)
|
|
197
|
+
assert len(self.counter_thresholds) == len(self.ttls_in_mins), (
|
|
198
|
+
"counter_thresholds and ttls_in_mins must have the same length, "
|
|
199
|
+
f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
|
|
200
|
+
)
|
|
201
|
+
elif self.eviction_strategy == 5:
|
|
202
|
+
assert self.feature_score_counter_decay_rates is not None, (
|
|
203
|
+
"feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
|
|
204
|
+
f"actual {self.feature_score_counter_decay_rates}"
|
|
205
|
+
)
|
|
206
|
+
assert self.training_id_eviction_trigger_count is not None, (
|
|
207
|
+
"training_id_eviction_trigger_count must be set if eviction_strategy is 5,"
|
|
208
|
+
f"actual {self.training_id_eviction_trigger_count}"
|
|
209
|
+
)
|
|
210
|
+
assert self.training_id_keep_count is not None, (
|
|
211
|
+
"training_id_keep_count must be set if eviction_strategy is 5,"
|
|
212
|
+
f"actual {self.training_id_keep_count}"
|
|
213
|
+
)
|
|
214
|
+
assert self.threshold_calculation_bucket_stride is not None, (
|
|
215
|
+
"threshold_calculation_bucket_stride must be set if eviction_strategy is 5,"
|
|
216
|
+
f"actual {self.threshold_calculation_bucket_stride}"
|
|
217
|
+
)
|
|
218
|
+
assert self.threshold_calculation_bucket_num is not None, (
|
|
219
|
+
"threshold_calculation_bucket_num must be set if eviction_strategy is 5,"
|
|
220
|
+
f"actual {self.threshold_calculation_bucket_num}"
|
|
221
|
+
)
|
|
222
|
+
assert self.enable_eviction_for_feature_score_eviction_policy is not None, (
|
|
223
|
+
"enable_eviction_for_feature_score_eviction_policy must be set if eviction_strategy is 5,"
|
|
224
|
+
f"actual {self.enable_eviction_for_feature_score_eviction_policy}"
|
|
225
|
+
)
|
|
226
|
+
assert (
|
|
227
|
+
len(self.enable_eviction_for_feature_score_eviction_policy)
|
|
228
|
+
== len(self.training_id_keep_count)
|
|
229
|
+
== len(self.feature_score_counter_decay_rates)
|
|
230
|
+
), (
|
|
231
|
+
"feature_score_thresholds, enable_eviction_for_feature_score_eviction_policy, and training_id_keep_count must have the same length, "
|
|
232
|
+
f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.enable_eviction_for_feature_score_eviction_policy}"
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class KVZCHParams(NamedTuple):
|
|
237
|
+
# global bucket id start and global bucket id end offsets for each logical table,
|
|
238
|
+
# where start offset is inclusive and end offset is exclusive
|
|
239
|
+
bucket_offsets: list[tuple[int, int]] = []
|
|
240
|
+
# bucket size for each logical table
|
|
241
|
+
# the value indicates corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets
|
|
242
|
+
bucket_sizes: list[int] = []
|
|
243
|
+
# enable optimizer offloading or not
|
|
244
|
+
enable_optimizer_offloading: bool = False
|
|
245
|
+
# when enabled, backend will return whole row(metaheader + weight + optimizer) instead of weight only
|
|
246
|
+
# can only be enabled when enable_optimizer_offloading is enabled
|
|
247
|
+
backend_return_whole_row: bool = False
|
|
248
|
+
eviction_policy: EvictionPolicy = EvictionPolicy()
|
|
249
|
+
embedding_cache_mode: bool = False
|
|
250
|
+
load_ckpt_without_opt: bool = False
|
|
251
|
+
optimizer_type_for_st: Optional[str] = None
|
|
252
|
+
optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
|
|
253
|
+
|
|
254
|
+
def validate(self) -> None:
|
|
255
|
+
assert len(self.bucket_offsets) == len(self.bucket_sizes), (
|
|
256
|
+
"bucket_offsets and bucket_sizes must have the same length, "
|
|
257
|
+
f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
|
|
258
|
+
)
|
|
259
|
+
self.eviction_policy.validate()
|
|
260
|
+
assert (
|
|
261
|
+
not self.backend_return_whole_row or self.enable_optimizer_offloading
|
|
262
|
+
), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class KVZCHTBEConfig(NamedTuple):
|
|
266
|
+
# Eviction trigger model for kvzch table: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
|
|
267
|
+
kvzch_eviction_trigger_mode: int = 2 # mem_util
|
|
268
|
+
# Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
|
|
269
|
+
eviction_free_mem_threshold_gb: int = 200 # 200GB
|
|
270
|
+
# Number of batches between checks for free memory threshold when using free_mem trigger mode.
|
|
271
|
+
eviction_free_mem_check_interval_batch: int = 1000
|
|
272
|
+
# The width of each feature score bucket used for threshold calculation in feature score-based eviction.
|
|
273
|
+
threshold_calculation_bucket_stride: float = 0.2
|
|
274
|
+
# Total number of feature score buckets used for threshold calculation in feature score-based eviction.
|
|
275
|
+
threshold_calculation_bucket_num: Optional[int] = 1000000 # 1M
|
|
276
|
+
# When true, we only save weight to kvzch backend and not optimizer state.
|
|
277
|
+
load_ckpt_without_opt: bool = False
|
|
278
|
+
# [DO NOT USE] This is for st publish only, do not set it in your config
|
|
279
|
+
optimizer_type_for_st: Optional[str] = None
|
|
280
|
+
# [DO NOT USE] This is for st publish only, do not set it in your config
|
|
281
|
+
optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
class BackendType(enum.IntEnum):
|
|
285
|
+
SSD = 0
|
|
286
|
+
DRAM = 1
|
|
287
|
+
PS = 2
|
|
288
|
+
|
|
289
|
+
@classmethod
|
|
290
|
+
# pyre-ignore[3]
|
|
291
|
+
def from_str(cls, key: str):
|
|
292
|
+
lookup = {
|
|
293
|
+
"ssd": BackendType.SSD,
|
|
294
|
+
"dram": BackendType.DRAM,
|
|
295
|
+
}
|
|
296
|
+
if key in lookup:
|
|
297
|
+
return lookup[key]
|
|
298
|
+
else:
|
|
299
|
+
raise ValueError(f"Cannot parse value into BackendType: {key}")
|
|
300
|
+
|
|
301
|
+
|
|
52
302
|
class CacheAlgorithm(enum.Enum):
|
|
53
303
|
LRU = 0
|
|
54
304
|
LFU = 1
|
|
@@ -106,6 +356,12 @@ class BoundsCheckMode(enum.IntEnum):
|
|
|
106
356
|
V2_FATAL = 6
|
|
107
357
|
|
|
108
358
|
|
|
359
|
+
class ComputeDevice(enum.IntEnum):
|
|
360
|
+
CPU = 0
|
|
361
|
+
CUDA = 1
|
|
362
|
+
MTIA = 2
|
|
363
|
+
|
|
364
|
+
|
|
109
365
|
class EmbeddingSpecInfo(enum.IntEnum):
|
|
110
366
|
feature_names = 0
|
|
111
367
|
rows = 1
|
|
@@ -125,8 +381,8 @@ SplitState: NamedTuple = NamedTuple(
|
|
|
125
381
|
("dev_size", int),
|
|
126
382
|
("host_size", int),
|
|
127
383
|
("uvm_size", int),
|
|
128
|
-
("placements",
|
|
129
|
-
("offsets",
|
|
384
|
+
("placements", list[EmbeddingLocation]),
|
|
385
|
+
("offsets", list[int]),
|
|
130
386
|
],
|
|
131
387
|
)
|
|
132
388
|
|
|
@@ -134,15 +390,15 @@ SplitState: NamedTuple = NamedTuple(
|
|
|
134
390
|
@dataclass
|
|
135
391
|
class CacheState:
|
|
136
392
|
# T + 1 elements and cache_hash_size_cumsum[-1] == total_cache_hash_size
|
|
137
|
-
cache_hash_size_cumsum:
|
|
138
|
-
cache_index_table_map:
|
|
393
|
+
cache_hash_size_cumsum: list[int]
|
|
394
|
+
cache_index_table_map: list[int]
|
|
139
395
|
total_cache_hash_size: int
|
|
140
396
|
|
|
141
397
|
|
|
142
398
|
def construct_cache_state(
|
|
143
|
-
row_list:
|
|
144
|
-
location_list:
|
|
145
|
-
feature_table_map:
|
|
399
|
+
row_list: list[int],
|
|
400
|
+
location_list: list[EmbeddingLocation],
|
|
401
|
+
feature_table_map: list[int],
|
|
146
402
|
) -> CacheState:
|
|
147
403
|
_cache_hash_size_cumsum = [0]
|
|
148
404
|
total_cache_hash_size = 0
|
|
@@ -215,3 +471,13 @@ def get_new_embedding_location(
|
|
|
215
471
|
# UVM caching
|
|
216
472
|
else:
|
|
217
473
|
return EmbeddingLocation.MANAGED_CACHING
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def get_bounds_check_version_for_platform() -> int:
|
|
477
|
+
# NOTE: Use bounds_check_indices v2 on ROCm because ROCm has a
|
|
478
|
+
# constraint that the gridDim * blockDim has to be smaller than
|
|
479
|
+
# 2^32. The v1 kernel can be launched with gridDim * blockDim >
|
|
480
|
+
# 2^32 while the v2 kernel limits the gridDim size to 64 * # of
|
|
481
|
+
# SMs. Thus, its gridDim * blockDim is guaranteed to be smaller
|
|
482
|
+
# than 2^32
|
|
483
|
+
return 2 if (torch.cuda.is_available() and torch.version.hip) else 1
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
import logging
|
|
13
13
|
import uuid
|
|
14
14
|
from itertools import accumulate
|
|
15
|
-
from typing import
|
|
15
|
+
from typing import Optional, Union
|
|
16
16
|
|
|
17
17
|
import fbgemm_gpu # noqa: F401
|
|
18
18
|
import torch # usort:skip
|
|
@@ -28,6 +28,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
|
|
28
28
|
DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
|
|
29
29
|
EmbeddingLocation,
|
|
30
30
|
EmbeddingSpecInfo,
|
|
31
|
+
get_bounds_check_version_for_platform,
|
|
31
32
|
get_new_embedding_location,
|
|
32
33
|
MAX_PREFETCH_DEPTH,
|
|
33
34
|
PoolingMode,
|
|
@@ -91,14 +92,14 @@ def align_to_cacheline(a: int) -> int:
|
|
|
91
92
|
|
|
92
93
|
|
|
93
94
|
def nbit_construct_split_state(
|
|
94
|
-
embedding_specs:
|
|
95
|
+
embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]],
|
|
95
96
|
cacheable: bool,
|
|
96
97
|
row_alignment: int,
|
|
97
98
|
scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
|
|
98
99
|
cacheline_alignment: bool = True,
|
|
99
100
|
) -> SplitState:
|
|
100
|
-
placements = torch.jit.annotate(
|
|
101
|
-
offsets = torch.jit.annotate(
|
|
101
|
+
placements = torch.jit.annotate(list[EmbeddingLocation], [])
|
|
102
|
+
offsets = torch.jit.annotate(list[int], [])
|
|
102
103
|
dev_size = 0
|
|
103
104
|
host_size = 0
|
|
104
105
|
uvm_size = 0
|
|
@@ -164,7 +165,7 @@ def inputs_to_device(
|
|
|
164
165
|
offsets: torch.Tensor,
|
|
165
166
|
per_sample_weights: Optional[torch.Tensor],
|
|
166
167
|
bounds_check_warning: torch.Tensor,
|
|
167
|
-
) ->
|
|
168
|
+
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
|
168
169
|
if bounds_check_warning.device.type == "meta":
|
|
169
170
|
return indices, offsets, per_sample_weights
|
|
170
171
|
|
|
@@ -330,7 +331,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
330
331
|
Options are `torch.int32` and `torch.int64`.
|
|
331
332
|
"""
|
|
332
333
|
|
|
333
|
-
embedding_specs:
|
|
334
|
+
embedding_specs: list[tuple[str, int, int, SparseType, EmbeddingLocation]]
|
|
334
335
|
record_cache_metrics: RecordCacheMetrics
|
|
335
336
|
# pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
|
|
336
337
|
cache_miss_counter: torch.Tensor
|
|
@@ -345,15 +346,15 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
345
346
|
|
|
346
347
|
def __init__( # noqa C901
|
|
347
348
|
self,
|
|
348
|
-
embedding_specs:
|
|
349
|
-
|
|
349
|
+
embedding_specs: list[
|
|
350
|
+
tuple[str, int, int, SparseType, EmbeddingLocation]
|
|
350
351
|
], # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
|
|
351
|
-
feature_table_map: Optional[
|
|
352
|
-
index_remapping: Optional[
|
|
352
|
+
feature_table_map: Optional[list[int]] = None, # [T]
|
|
353
|
+
index_remapping: Optional[list[Tensor]] = None,
|
|
353
354
|
pooling_mode: PoolingMode = PoolingMode.SUM,
|
|
354
355
|
device: Optional[Union[str, int, torch.device]] = None,
|
|
355
356
|
bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
|
|
356
|
-
weight_lists: Optional[
|
|
357
|
+
weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
|
|
357
358
|
pruning_hash_load_factor: float = 0.5,
|
|
358
359
|
use_array_for_index_remapping: bool = True,
|
|
359
360
|
output_dtype: SparseType = SparseType.FP16,
|
|
@@ -372,7 +373,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
372
373
|
cacheline_alignment: bool = True,
|
|
373
374
|
uvm_host_mapped: bool = False, # True to use cudaHostAlloc; False to use cudaMallocManaged.
|
|
374
375
|
reverse_qparam: bool = False, # True to load qparams at end of each row; False to load qparam at begnning of each row.
|
|
375
|
-
feature_names_per_table: Optional[
|
|
376
|
+
feature_names_per_table: Optional[list[list[str]]] = None,
|
|
376
377
|
indices_dtype: torch.dtype = torch.int32, # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
|
|
377
378
|
) -> None: # noqa C901 # tuple of (rows, dims,)
|
|
378
379
|
super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()
|
|
@@ -405,14 +406,14 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
405
406
|
self.indices_dtype = indices_dtype
|
|
406
407
|
# (feature_names, rows, dims, weights_tys, locations) = zip(*embedding_specs)
|
|
407
408
|
# Pyre workaround
|
|
408
|
-
self.feature_names:
|
|
409
|
+
self.feature_names: list[str] = [e[0] for e in embedding_specs]
|
|
409
410
|
self.cache_load_factor: float = cache_load_factor
|
|
410
411
|
self.cache_sets: int = cache_sets
|
|
411
412
|
self.cache_reserved_memory: float = cache_reserved_memory
|
|
412
|
-
rows:
|
|
413
|
-
dims:
|
|
414
|
-
weights_tys:
|
|
415
|
-
locations:
|
|
413
|
+
rows: list[int] = [e[1] for e in embedding_specs]
|
|
414
|
+
dims: list[int] = [e[2] for e in embedding_specs]
|
|
415
|
+
weights_tys: list[SparseType] = [e[3] for e in embedding_specs]
|
|
416
|
+
locations: list[EmbeddingLocation] = [e[4] for e in embedding_specs]
|
|
416
417
|
# if target device is meta then we set use_cpu based on the embedding location
|
|
417
418
|
# information in embedding_specs.
|
|
418
419
|
if self.current_device.type == "meta":
|
|
@@ -452,7 +453,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
452
453
|
T_ = len(self.embedding_specs)
|
|
453
454
|
assert T_ > 0
|
|
454
455
|
|
|
455
|
-
self.feature_table_map:
|
|
456
|
+
self.feature_table_map: list[int] = (
|
|
456
457
|
feature_table_map if feature_table_map is not None else list(range(T_))
|
|
457
458
|
)
|
|
458
459
|
T = len(self.feature_table_map)
|
|
@@ -635,6 +636,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
635
636
|
self.fp8_exponent_bits = -1
|
|
636
637
|
self.fp8_exponent_bias = -1
|
|
637
638
|
|
|
639
|
+
self.bounds_check_version: int = get_bounds_check_version_for_platform()
|
|
640
|
+
|
|
638
641
|
@torch.jit.ignore
|
|
639
642
|
def log(self, msg: str) -> None:
|
|
640
643
|
"""
|
|
@@ -673,7 +676,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
673
676
|
return self.table_wise_cache_miss
|
|
674
677
|
|
|
675
678
|
@torch.jit.export
|
|
676
|
-
def get_feature_num_per_table(self) ->
|
|
679
|
+
def get_feature_num_per_table(self) -> list[int]:
|
|
677
680
|
if self.feature_names_per_table is None:
|
|
678
681
|
return []
|
|
679
682
|
return [len(feature_names) for feature_names in self.feature_names_per_table]
|
|
@@ -975,6 +978,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
975
978
|
self.bounds_check_mode_int,
|
|
976
979
|
self.bounds_check_warning,
|
|
977
980
|
per_sample_weights,
|
|
981
|
+
bounds_check_version=self.bounds_check_version,
|
|
978
982
|
)
|
|
979
983
|
|
|
980
984
|
# Index remapping changes input indices, and some of them becomes -1 (prunned rows).
|
|
@@ -1017,6 +1021,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1017
1021
|
self.bounds_check_mode_int,
|
|
1018
1022
|
self.bounds_check_warning,
|
|
1019
1023
|
per_sample_weights,
|
|
1024
|
+
bounds_check_version=self.bounds_check_version,
|
|
1020
1025
|
)
|
|
1021
1026
|
# Note: CPU and CUDA ops use the same interface to facilitate JIT IR
|
|
1022
1027
|
# generation for CUDA/CPU. For CPU op, we don't need weights_uvm and
|
|
@@ -1206,8 +1211,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1206
1211
|
dev_size: int,
|
|
1207
1212
|
host_size: int,
|
|
1208
1213
|
uvm_size: int,
|
|
1209
|
-
placements:
|
|
1210
|
-
offsets:
|
|
1214
|
+
placements: list[int],
|
|
1215
|
+
offsets: list[int],
|
|
1211
1216
|
enforce_hbm: bool,
|
|
1212
1217
|
) -> None:
|
|
1213
1218
|
assert not self.weight_initialized, "Weights have already been initialized."
|
|
@@ -1516,6 +1521,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1516
1521
|
for i, weight in enumerate(weights):
|
|
1517
1522
|
weights[i] = (
|
|
1518
1523
|
weight[0].to(device),
|
|
1524
|
+
# pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `to`.
|
|
1519
1525
|
weight[1].to(device) if weight[1] is not None else None,
|
|
1520
1526
|
)
|
|
1521
1527
|
(
|
|
@@ -1596,7 +1602,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1596
1602
|
@torch.jit.export
|
|
1597
1603
|
def split_embedding_weights_with_scale_bias(
|
|
1598
1604
|
self, split_scale_bias_mode: int = 1
|
|
1599
|
-
) ->
|
|
1605
|
+
) -> list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
|
|
1600
1606
|
"""
|
|
1601
1607
|
Returns a list of weights, split by table
|
|
1602
1608
|
split_scale_bias_mode:
|
|
@@ -1605,7 +1611,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1605
1611
|
2: return weights, scale, bias.
|
|
1606
1612
|
"""
|
|
1607
1613
|
assert self.weight_initialized
|
|
1608
|
-
splits:
|
|
1614
|
+
splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
|
|
1609
1615
|
for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
|
|
1610
1616
|
placement = self.weights_physical_placements[t]
|
|
1611
1617
|
if (
|
|
@@ -1730,12 +1736,12 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1730
1736
|
# the second with scale_bias.
|
|
1731
1737
|
# This should've been named as split_scale_bias.
|
|
1732
1738
|
# Keep as is for backward compatibility.
|
|
1733
|
-
) ->
|
|
1739
|
+
) -> list[tuple[Tensor, Optional[Tensor]]]:
|
|
1734
1740
|
"""
|
|
1735
1741
|
Returns a list of weights, split by table
|
|
1736
1742
|
"""
|
|
1737
1743
|
# fmt: off
|
|
1738
|
-
splits:
|
|
1744
|
+
splits: list[tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
|
|
1739
1745
|
self.split_embedding_weights_with_scale_bias(
|
|
1740
1746
|
split_scale_bias_mode=(1 if split_scale_shifts else 0)
|
|
1741
1747
|
)
|
|
@@ -1773,7 +1779,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1773
1779
|
)
|
|
1774
1780
|
|
|
1775
1781
|
def assign_embedding_weights(
|
|
1776
|
-
self, q_weight_list:
|
|
1782
|
+
self, q_weight_list: list[tuple[Tensor, Optional[Tensor]]]
|
|
1777
1783
|
) -> None:
|
|
1778
1784
|
"""
|
|
1779
1785
|
Assigns self.split_embedding_weights() with values from the input list of weights and scale_shifts.
|
|
@@ -1785,6 +1791,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1785
1791
|
dest_weight[0].copy_(input_weight[0])
|
|
1786
1792
|
if input_weight[1] is not None:
|
|
1787
1793
|
assert dest_weight[1] is not None
|
|
1794
|
+
# pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `copy_`.
|
|
1788
1795
|
dest_weight[1].copy_(input_weight[1])
|
|
1789
1796
|
else:
|
|
1790
1797
|
assert dest_weight[1] is None
|
|
@@ -1792,11 +1799,11 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1792
1799
|
@torch.jit.export
|
|
1793
1800
|
def set_index_remappings_array(
|
|
1794
1801
|
self,
|
|
1795
|
-
index_remapping:
|
|
1802
|
+
index_remapping: list[Tensor],
|
|
1796
1803
|
) -> None:
|
|
1797
|
-
rows:
|
|
1804
|
+
rows: list[int] = [e[1] for e in self.embedding_specs]
|
|
1798
1805
|
index_remappings_array_offsets = [0]
|
|
1799
|
-
original_feature_rows = torch.jit.annotate(
|
|
1806
|
+
original_feature_rows = torch.jit.annotate(list[int], [])
|
|
1800
1807
|
last_offset = 0
|
|
1801
1808
|
for t, mapping in enumerate(index_remapping):
|
|
1802
1809
|
if mapping is not None:
|
|
@@ -1835,11 +1842,11 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1835
1842
|
|
|
1836
1843
|
def set_index_remappings(
|
|
1837
1844
|
self,
|
|
1838
|
-
index_remapping:
|
|
1845
|
+
index_remapping: list[Tensor],
|
|
1839
1846
|
pruning_hash_load_factor: float = 0.5,
|
|
1840
1847
|
use_array_for_index_remapping: bool = True,
|
|
1841
1848
|
) -> None:
|
|
1842
|
-
rows:
|
|
1849
|
+
rows: list[int] = [e[1] for e in self.embedding_specs]
|
|
1843
1850
|
T = len(self.embedding_specs)
|
|
1844
1851
|
# Hash mapping pruning
|
|
1845
1852
|
if not use_array_for_index_remapping:
|
|
@@ -1909,7 +1916,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1909
1916
|
def _embedding_inplace_update_per_table(
|
|
1910
1917
|
self,
|
|
1911
1918
|
update_table_idx: int,
|
|
1912
|
-
update_row_indices:
|
|
1919
|
+
update_row_indices: list[int],
|
|
1913
1920
|
update_weights: Tensor,
|
|
1914
1921
|
) -> None:
|
|
1915
1922
|
row_size = len(update_row_indices)
|
|
@@ -1934,9 +1941,9 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1934
1941
|
@torch.jit.export
|
|
1935
1942
|
def embedding_inplace_update(
|
|
1936
1943
|
self,
|
|
1937
|
-
update_table_indices:
|
|
1938
|
-
update_row_indices:
|
|
1939
|
-
update_weights:
|
|
1944
|
+
update_table_indices: list[int],
|
|
1945
|
+
update_row_indices: list[list[int]],
|
|
1946
|
+
update_weights: list[Tensor],
|
|
1940
1947
|
) -> None:
|
|
1941
1948
|
for i in range(len(update_table_indices)):
|
|
1942
1949
|
self._embedding_inplace_update_per_table(
|
|
@@ -1947,8 +1954,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
|
|
|
1947
1954
|
|
|
1948
1955
|
def embedding_inplace_update_internal(
|
|
1949
1956
|
self,
|
|
1950
|
-
update_table_indices:
|
|
1951
|
-
update_row_indices:
|
|
1957
|
+
update_table_indices: list[int],
|
|
1958
|
+
update_row_indices: list[int],
|
|
1952
1959
|
update_weights: Tensor,
|
|
1953
1960
|
) -> None:
|
|
1954
1961
|
assert len(update_table_indices) == len(update_row_indices)
|