fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
|
@@ -4,6 +4,8 @@
|
|
|
4
4
|
## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
|
|
5
5
|
################################################################################
|
|
6
6
|
|
|
7
|
+
__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
|
|
8
|
+
|
|
7
9
|
#!/usr/bin/env python3
|
|
8
10
|
|
|
9
11
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
@@ -16,6 +18,12 @@
|
|
|
16
18
|
import torch
|
|
17
19
|
import warnings
|
|
18
20
|
from .lookup_args import *
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
19
27
|
def invoke(
|
|
20
28
|
common_args: CommonArgs,
|
|
21
29
|
optimizer_args: OptimizerArgs,
|
|
@@ -26,6 +34,7 @@ def invoke(
|
|
|
26
34
|
# only pass prev_iter_dev since prev_iter is never created on UVM
|
|
27
35
|
prev_iter_dev: Optional[torch.Tensor] = None,
|
|
28
36
|
gwd_lower_bound: float = 0.0,
|
|
37
|
+
mixed_D: bool = True,
|
|
29
38
|
row_counter: Optional[Momentum] = None,
|
|
30
39
|
) -> torch.Tensor:
|
|
31
40
|
# By design, the warning only shows up once
|
|
@@ -36,29 +45,245 @@ def invoke(
|
|
|
36
45
|
times and build sizes!
|
|
37
46
|
\033[0m"""
|
|
38
47
|
)
|
|
39
|
-
|
|
48
|
+
# host_weights is only used for CPU training
|
|
49
|
+
use_cpu = common_args.host_weights.numel() > 0
|
|
40
50
|
vbe_metadata = common_args.vbe_metadata
|
|
41
|
-
|
|
51
|
+
|
|
52
|
+
# pack weights
|
|
53
|
+
weights = [
|
|
54
|
+
common_args.dev_weights,
|
|
55
|
+
common_args.uvm_weights,
|
|
56
|
+
common_args.weights_placements,
|
|
57
|
+
common_args.weights_offsets,
|
|
58
|
+
common_args.lxu_cache_weights,
|
|
59
|
+
] if not use_cpu else [
|
|
60
|
+
common_args.host_weights,
|
|
61
|
+
common_args.weights_placements,
|
|
62
|
+
common_args.weights_offsets,
|
|
63
|
+
]
|
|
64
|
+
dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
|
|
65
|
+
"B_offsets": vbe_metadata.B_offsets,
|
|
66
|
+
"vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
|
|
67
|
+
"vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
|
|
68
|
+
"lxu_cache_locations": common_args.lxu_cache_locations,
|
|
69
|
+
"uvm_cache_stats": common_args.uvm_cache_stats,
|
|
70
|
+
"vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
dict_aux_int: Dict[str, int] = {
|
|
74
|
+
"iter": iter,
|
|
75
|
+
"info_B_num_bits": common_args.info_B_num_bits,
|
|
76
|
+
"info_B_mask": common_args.info_B_mask,
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
dict_aux_float: Dict[str, float] = {
|
|
80
|
+
"gwd_lower_bound": gwd_lower_bound,
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
dict_aux_bool: Dict[str, bool] = {
|
|
84
|
+
"is_experimental_tbe": common_args.is_experimental,
|
|
85
|
+
"use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
|
|
86
|
+
"use_homogeneous_placements": common_args.use_homogeneous_placements,
|
|
87
|
+
"apply_global_weight_decay": apply_global_weight_decay,
|
|
88
|
+
"mixed_D": mixed_D,
|
|
89
|
+
}
|
|
90
|
+
dict_optim_int: Dict[str, int] = {}
|
|
91
|
+
dict_optim_float: Dict[str, float] = {}
|
|
92
|
+
dict_optim_bool: Dict[str, bool] = {}
|
|
93
|
+
|
|
94
|
+
# Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
|
|
95
|
+
dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# optimizer_args # if optimizer == none
|
|
99
|
+
dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
|
|
100
|
+
dict_aux_float["max_gradient"] = optimizer_args.max_gradient
|
|
101
|
+
dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
|
|
102
|
+
dict_optim_float["eps"] = optimizer_args.eps
|
|
103
|
+
dict_optim_float["beta1"] = optimizer_args.beta1
|
|
104
|
+
dict_optim_float["beta2"] = optimizer_args.beta2
|
|
105
|
+
dict_optim_float["weight_decay"] = optimizer_args.weight_decay
|
|
106
|
+
dict_optim_bool["use_rowwise_bias_correction"] = optimizer_args.use_rowwise_bias_correction
|
|
107
|
+
|
|
108
|
+
momentum1_list = [
|
|
109
|
+
momentum1.dev,
|
|
110
|
+
momentum1.uvm,
|
|
111
|
+
momentum1.placements,
|
|
112
|
+
momentum1.offsets,
|
|
113
|
+
] if not use_cpu else [
|
|
114
|
+
momentum1.host,
|
|
115
|
+
momentum1.placements,
|
|
116
|
+
momentum1.offsets,
|
|
117
|
+
] if momentum1 is not None else None
|
|
118
|
+
|
|
119
|
+
momentum2_list = [
|
|
120
|
+
momentum2.dev,
|
|
121
|
+
momentum2.uvm,
|
|
122
|
+
momentum2.placements,
|
|
123
|
+
momentum2.offsets,
|
|
124
|
+
] if not use_cpu else [
|
|
125
|
+
momentum2.host,
|
|
126
|
+
momentum2.placements,
|
|
127
|
+
momentum2.offsets,
|
|
128
|
+
] if momentum2 is not None else None
|
|
129
|
+
|
|
130
|
+
if optimizer_args.use_rowwise_bias_correction and row_counter is not None:
|
|
131
|
+
row_counter_host = None # not supported on CPU
|
|
132
|
+
row_counter_dev = row_counter.dev
|
|
133
|
+
row_counter_uvm = row_counter.uvm
|
|
134
|
+
row_counter_offsets = row_counter.offsets
|
|
135
|
+
row_counter_placements = row_counter.placements
|
|
136
|
+
elif optimizer_args.use_rowwise_bias_correction:
|
|
137
|
+
assert False, "`use_rowwise_bias_correction` is set, `row_counter` cannot be None"
|
|
138
|
+
else:
|
|
139
|
+
row_counter_host = None
|
|
42
140
|
row_counter_dev = None
|
|
43
141
|
row_counter_uvm = None
|
|
44
142
|
row_counter_offsets = None
|
|
45
143
|
row_counter_placements = None
|
|
46
|
-
|
|
47
|
-
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
aux_tensor: List[Optional[torch.Tensor]] = []
|
|
147
|
+
assert "B_offsets" in dict_aux_tensor, (
|
|
148
|
+
"B_offsets must be in dict_aux_tensor. "
|
|
149
|
+
"Please check the frontend and backend version. "
|
|
150
|
+
)
|
|
151
|
+
aux_tensor.append(dict_aux_tensor["B_offsets"])
|
|
152
|
+
assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
|
|
153
|
+
"vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
|
|
154
|
+
"Please check the frontend and backend version. "
|
|
155
|
+
)
|
|
156
|
+
aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
|
|
157
|
+
assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
|
|
158
|
+
"vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
|
|
159
|
+
"Please check the frontend and backend version. "
|
|
160
|
+
)
|
|
161
|
+
aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
|
|
162
|
+
assert "lxu_cache_locations" in dict_aux_tensor, (
|
|
163
|
+
"lxu_cache_locations must be in dict_aux_tensor. "
|
|
164
|
+
"Please check the frontend and backend version. "
|
|
165
|
+
)
|
|
166
|
+
aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
|
|
167
|
+
assert "uvm_cache_stats" in dict_aux_tensor, (
|
|
168
|
+
"uvm_cache_stats must be in dict_aux_tensor. "
|
|
169
|
+
"Please check the frontend and backend version. "
|
|
170
|
+
)
|
|
171
|
+
aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
|
|
172
|
+
assert "prev_iter_dev" in dict_aux_tensor, (
|
|
173
|
+
"prev_iter_dev must be in dict_aux_tensor. "
|
|
174
|
+
"Please check the frontend and backend version. "
|
|
175
|
+
)
|
|
176
|
+
aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
|
|
177
|
+
assert "vbe_output_offsets" in dict_aux_tensor, (
|
|
178
|
+
"vbe_output_offsets must be in dict_aux_tensor. "
|
|
179
|
+
"Please check the frontend and backend version. "
|
|
180
|
+
)
|
|
181
|
+
aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
|
|
182
|
+
|
|
183
|
+
aux_int: List[int] = []
|
|
184
|
+
assert "iter" in dict_aux_int, (
|
|
185
|
+
"iter must be in dict_aux_int. "
|
|
186
|
+
"Please check the frontend and backend version. "
|
|
187
|
+
)
|
|
188
|
+
aux_int.append(dict_aux_int["iter"])
|
|
189
|
+
assert "info_B_num_bits" in dict_aux_int, (
|
|
190
|
+
"info_B_num_bits must be in dict_aux_int. "
|
|
191
|
+
"Please check the frontend and backend version. "
|
|
192
|
+
)
|
|
193
|
+
aux_int.append(dict_aux_int["info_B_num_bits"])
|
|
194
|
+
assert "info_B_mask" in dict_aux_int, (
|
|
195
|
+
"info_B_mask must be in dict_aux_int. "
|
|
196
|
+
"Please check the frontend and backend version. "
|
|
197
|
+
)
|
|
198
|
+
aux_int.append(dict_aux_int["info_B_mask"])
|
|
199
|
+
|
|
200
|
+
aux_float: List[float] = []
|
|
201
|
+
assert "gwd_lower_bound" in dict_aux_float, (
|
|
202
|
+
"gwd_lower_bound must be in dict_aux_float. "
|
|
203
|
+
"Please check the frontend and backend version. "
|
|
204
|
+
)
|
|
205
|
+
aux_float.append(dict_aux_float["gwd_lower_bound"])
|
|
206
|
+
assert "max_gradient" in dict_aux_float, (
|
|
207
|
+
"max_gradient must be in dict_aux_float. "
|
|
208
|
+
"Please check the frontend and backend version. "
|
|
209
|
+
)
|
|
210
|
+
aux_float.append(dict_aux_float["max_gradient"])
|
|
211
|
+
|
|
212
|
+
aux_bool: List[bool] = []
|
|
213
|
+
assert "is_experimental_tbe" in dict_aux_bool, (
|
|
214
|
+
"is_experimental_tbe must be in dict_aux_bool. "
|
|
215
|
+
"Please check the frontend and backend version. "
|
|
216
|
+
)
|
|
217
|
+
aux_bool.append(dict_aux_bool["is_experimental_tbe"])
|
|
218
|
+
assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
|
|
219
|
+
"use_uniq_cache_locations_bwd must be in dict_aux_bool. "
|
|
220
|
+
"Please check the frontend and backend version. "
|
|
221
|
+
)
|
|
222
|
+
aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
|
|
223
|
+
assert "use_homogeneous_placements" in dict_aux_bool, (
|
|
224
|
+
"use_homogeneous_placements must be in dict_aux_bool. "
|
|
225
|
+
"Please check the frontend and backend version. "
|
|
226
|
+
)
|
|
227
|
+
aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
|
|
228
|
+
assert "apply_global_weight_decay" in dict_aux_bool, (
|
|
229
|
+
"apply_global_weight_decay must be in dict_aux_bool. "
|
|
230
|
+
"Please check the frontend and backend version. "
|
|
231
|
+
)
|
|
232
|
+
aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
|
|
233
|
+
assert "gradient_clipping" in dict_aux_bool, (
|
|
234
|
+
"gradient_clipping must be in dict_aux_bool. "
|
|
235
|
+
"Please check the frontend and backend version. "
|
|
236
|
+
)
|
|
237
|
+
aux_bool.append(dict_aux_bool["gradient_clipping"])
|
|
238
|
+
assert "stochastic_rounding" in dict_aux_bool, (
|
|
239
|
+
"stochastic_rounding must be in dict_aux_bool. "
|
|
240
|
+
"Please check the frontend and backend version. "
|
|
241
|
+
)
|
|
242
|
+
aux_bool.append(dict_aux_bool["stochastic_rounding"])
|
|
243
|
+
assert "mixed_D" in dict_aux_bool, (
|
|
244
|
+
"mixed_D must be in dict_aux_bool. "
|
|
245
|
+
"Please check the frontend and backend version. "
|
|
246
|
+
)
|
|
247
|
+
aux_bool.append(dict_aux_bool["mixed_D"])
|
|
248
|
+
optim_tensor: List[Optional[torch.Tensor]] = []
|
|
249
|
+
# We cannot do list of optional tensorlist (optional tensorlist is Tensor?[]).
|
|
250
|
+
# we need to pack optimizer optional tensors in a flatten manner.
|
|
251
|
+
# We pack unified args (i.e., 5 items) since it's very confusing to pack/unpack per device (i.e, 3 for cpu and 4 for cuda)
|
|
252
|
+
# e.g., if we have optim optional tensors x and y, the optim_tensor will look like
|
|
253
|
+
# [x_host, x_dev, x_uvm, x_placements, x_offsets, y_host, y_dev, y_uvm, y_placements, y_offsets]
|
|
254
|
+
# ['row_counter']
|
|
255
|
+
|
|
256
|
+
# using .extend fails torch script
|
|
257
|
+
if row_counter is None:
|
|
258
|
+
optim_tensor.append(None)
|
|
259
|
+
optim_tensor.append(None)
|
|
260
|
+
optim_tensor.append(None)
|
|
261
|
+
optim_tensor.append(None)
|
|
262
|
+
optim_tensor.append(None)
|
|
48
263
|
else:
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
264
|
+
optim_tensor.append(row_counter.host)
|
|
265
|
+
optim_tensor.append(row_counter.dev)
|
|
266
|
+
optim_tensor.append(row_counter.uvm)
|
|
267
|
+
optim_tensor.append(row_counter.placements)
|
|
268
|
+
optim_tensor.append(row_counter.offsets)
|
|
269
|
+
|
|
270
|
+
# optim_int
|
|
271
|
+
# optim_float
|
|
272
|
+
# ['momentum1', 'momentum2', 'learning_rate_tensor', 'optim_tensor', 'optim_float', 'optim_bool']
|
|
273
|
+
optim_float: List[float] = []
|
|
274
|
+
optim_float.append(dict_optim_float["eps"])
|
|
275
|
+
optim_float.append(dict_optim_float["beta1"])
|
|
276
|
+
optim_float.append(dict_optim_float["beta2"])
|
|
277
|
+
optim_float.append(dict_optim_float["weight_decay"])
|
|
278
|
+
# optim_bool
|
|
279
|
+
optim_bool: List[bool] = []
|
|
280
|
+
optim_bool.append(dict_optim_bool["use_rowwise_bias_correction"])
|
|
53
281
|
|
|
54
|
-
return torch.ops.fbgemm.
|
|
282
|
+
return torch.ops.fbgemm.split_embedding_codegen_lookup_adam_function_pt2(
|
|
55
283
|
# common_args
|
|
56
284
|
placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
lxu_cache_weights=common_args.lxu_cache_weights,
|
|
60
|
-
weights_placements=common_args.weights_placements,
|
|
61
|
-
weights_offsets=common_args.weights_offsets,
|
|
285
|
+
# weights
|
|
286
|
+
weights=weights,
|
|
62
287
|
D_offsets=common_args.D_offsets,
|
|
63
288
|
total_D=common_args.total_D,
|
|
64
289
|
max_D=common_args.max_D,
|
|
@@ -69,52 +294,35 @@ def invoke(
|
|
|
69
294
|
pooling_mode=common_args.pooling_mode,
|
|
70
295
|
indice_weights=common_args.indice_weights,
|
|
71
296
|
feature_requires_grad=common_args.feature_requires_grad,
|
|
72
|
-
|
|
73
|
-
uvm_cache_stats=common_args.uvm_cache_stats,
|
|
297
|
+
output_dtype=common_args.output_dtype,
|
|
74
298
|
# VBE metadata
|
|
75
|
-
B_offsets=vbe_metadata.B_offsets,
|
|
76
|
-
vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
|
|
77
|
-
vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
|
|
78
299
|
max_B=vbe_metadata.max_B,
|
|
79
300
|
max_B_feature_rank=vbe_metadata.max_B_feature_rank,
|
|
80
301
|
vbe_output_size=vbe_metadata.output_size,
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
302
|
+
vbe_output=vbe_metadata.vbe_output,
|
|
303
|
+
# aux_tensor
|
|
304
|
+
aux_tensor=aux_tensor,
|
|
305
|
+
# aux_int
|
|
306
|
+
aux_int=aux_int,
|
|
307
|
+
# aux_float
|
|
308
|
+
aux_float=aux_float,
|
|
309
|
+
# aux_bool
|
|
310
|
+
aux_bool=aux_bool,
|
|
311
|
+
learning_rate_tensor=common_args.learning_rate_tensor,
|
|
312
|
+
|
|
91
313
|
# momentum1
|
|
92
|
-
|
|
93
|
-
momentum1_uvm=momentum1.uvm,
|
|
94
|
-
momentum1_offsets=momentum1.offsets,
|
|
95
|
-
momentum1_placements=momentum1.placements,
|
|
314
|
+
momentum1 = momentum1_list,
|
|
96
315
|
# momentum2
|
|
97
|
-
|
|
98
|
-
momentum2_uvm=momentum2.uvm,
|
|
99
|
-
momentum2_offsets=momentum2.offsets,
|
|
100
|
-
momentum2_placements=momentum2.placements,
|
|
316
|
+
momentum2=momentum2_list,
|
|
101
317
|
# prev_iter
|
|
102
|
-
|
|
103
|
-
prev_iter_dev=prev_iter_dev,
|
|
104
318
|
# row_counter
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
#
|
|
111
|
-
|
|
112
|
-
#
|
|
319
|
+
# optim_tensor
|
|
320
|
+
optim_tensor=optim_tensor,
|
|
321
|
+
# optim_int
|
|
322
|
+
# optim_float
|
|
323
|
+
optim_float=optim_float,
|
|
324
|
+
# optim_bool
|
|
325
|
+
optim_bool=optim_bool,
|
|
326
|
+
# optim symint args
|
|
113
327
|
# total_unique_indices
|
|
114
|
-
output_dtype=common_args.output_dtype,
|
|
115
|
-
is_experimental=common_args.is_experimental,
|
|
116
|
-
use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
|
|
117
|
-
use_homogeneous_placements=common_args.use_homogeneous_placements,
|
|
118
|
-
apply_global_weight_decay=apply_global_weight_decay,
|
|
119
|
-
gwd_lower_bound=gwd_lower_bound,
|
|
120
328
|
)
|
|
@@ -4,6 +4,8 @@
|
|
|
4
4
|
## Template Source: training/python/lookup_args.template
|
|
5
5
|
################################################################################
|
|
6
6
|
|
|
7
|
+
__template_source_file__ = "training/python/lookup_args.template"
|
|
8
|
+
|
|
7
9
|
#!/usr/bin/env python3
|
|
8
10
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
9
11
|
# All rights reserved.
|
|
@@ -25,6 +27,8 @@ class VBEMetadata(NamedTuple):
|
|
|
25
27
|
max_B_feature_rank: int = -1
|
|
26
28
|
max_B: int = -1
|
|
27
29
|
output_size: int = -1
|
|
30
|
+
vbe_output: Optional[torch.Tensor] = None
|
|
31
|
+
vbe_output_offsets: Optional[torch.Tensor] = None
|
|
28
32
|
|
|
29
33
|
|
|
30
34
|
class CommonArgs(NamedTuple):
|
|
@@ -52,71 +56,18 @@ class CommonArgs(NamedTuple):
|
|
|
52
56
|
is_experimental: bool
|
|
53
57
|
use_uniq_cache_locations_bwd: bool
|
|
54
58
|
use_homogeneous_placements: bool
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
class OptimizerArgs(NamedTuple):
|
|
58
|
-
stochastic_rounding: bool
|
|
59
|
-
gradient_clipping: bool
|
|
60
|
-
max_gradient: float
|
|
61
|
-
max_norm: float
|
|
62
|
-
learning_rate: float
|
|
63
|
-
eps: float
|
|
64
|
-
beta1: float
|
|
65
|
-
beta2: float
|
|
66
|
-
weight_decay: float
|
|
67
|
-
weight_decay_mode: int
|
|
68
|
-
eta: float
|
|
69
|
-
momentum: float
|
|
70
|
-
counter_halflife: int
|
|
71
|
-
adjustment_iter: int
|
|
72
|
-
adjustment_ub: float
|
|
73
|
-
learning_rate_mode: int
|
|
74
|
-
grad_sum_decay: int
|
|
75
|
-
tail_id_threshold: float
|
|
76
|
-
is_tail_id_thresh_ratio: int
|
|
77
|
-
total_hash_size: int # Required for OptimType.NONE
|
|
78
|
-
weight_norm_coefficient: float
|
|
79
|
-
lower_bound: float
|
|
80
|
-
regularization_mode: int
|
|
81
|
-
use_rowwise_bias_correction: bool # Used for OptimType.ADAM
|
|
82
|
-
|
|
83
|
-
class CommonArgsPT2(NamedTuple):
|
|
84
|
-
placeholder_autograd_tensor: torch.Tensor
|
|
85
|
-
dev_weights: torch.Tensor
|
|
86
|
-
host_weights: torch.Tensor
|
|
87
|
-
uvm_weights: torch.Tensor
|
|
88
|
-
lxu_cache_weights: torch.Tensor
|
|
89
|
-
weights_placements: torch.Tensor
|
|
90
|
-
weights_offsets: torch.Tensor
|
|
91
|
-
D_offsets: torch.Tensor
|
|
92
|
-
total_D: int
|
|
93
|
-
max_D: int
|
|
94
|
-
hash_size_cumsum: torch.Tensor
|
|
95
|
-
total_hash_size_bits: int
|
|
96
|
-
indices: torch.Tensor
|
|
97
|
-
offsets: torch.Tensor
|
|
98
|
-
pooling_mode: int
|
|
99
|
-
indice_weights: Optional[torch.Tensor]
|
|
100
|
-
feature_requires_grad: Optional[torch.Tensor]
|
|
101
|
-
lxu_cache_locations: torch.Tensor
|
|
102
|
-
uvm_cache_stats: Optional[torch.Tensor]
|
|
103
|
-
output_dtype: int
|
|
104
|
-
vbe_metadata: VBEMetadata
|
|
105
|
-
is_experimental: bool
|
|
106
|
-
use_uniq_cache_locations_bwd: bool
|
|
107
|
-
use_homogeneous_placements: bool
|
|
59
|
+
learning_rate_tensor: torch.Tensor
|
|
108
60
|
info_B_num_bits: int
|
|
109
61
|
info_B_mask: int
|
|
110
62
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
63
|
+
|
|
64
|
+
# Do not add a parameter of Type tensor here. It will cause JIT script error due to a bug in PyTorch.
|
|
65
|
+
# See more detail in D71010630.
|
|
66
|
+
class OptimizerArgs(NamedTuple):
|
|
115
67
|
stochastic_rounding: bool
|
|
116
68
|
gradient_clipping: bool
|
|
117
69
|
max_gradient: float
|
|
118
70
|
max_norm: float
|
|
119
|
-
learning_rate_tensor: torch.Tensor
|
|
120
71
|
eps: float
|
|
121
72
|
beta1: float
|
|
122
73
|
beta2: float
|
|
@@ -4,6 +4,8 @@
|
|
|
4
4
|
## Template Source: training/python/lookup_args.template
|
|
5
5
|
################################################################################
|
|
6
6
|
|
|
7
|
+
__template_source_file__ = "training/python/lookup_args.template"
|
|
8
|
+
|
|
7
9
|
#!/usr/bin/env python3
|
|
8
10
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
9
11
|
# All rights reserved.
|
|
@@ -25,6 +27,8 @@ class VBEMetadata(NamedTuple):
|
|
|
25
27
|
max_B_feature_rank: int = -1
|
|
26
28
|
max_B: int = -1
|
|
27
29
|
output_size: int = -1
|
|
30
|
+
vbe_output: Optional[torch.Tensor] = None
|
|
31
|
+
vbe_output_offsets: Optional[torch.Tensor] = None
|
|
28
32
|
|
|
29
33
|
|
|
30
34
|
class CommonArgs(NamedTuple):
|
|
@@ -53,72 +57,19 @@ class CommonArgs(NamedTuple):
|
|
|
53
57
|
use_uniq_cache_locations_bwd: bool
|
|
54
58
|
use_homogeneous_placements: bool
|
|
55
59
|
ssd_tensors: Dict[str, torch.Tensor]
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
class OptimizerArgs(NamedTuple):
|
|
59
|
-
stochastic_rounding: bool
|
|
60
|
-
gradient_clipping: bool
|
|
61
|
-
max_gradient: float
|
|
62
|
-
max_norm: float
|
|
63
|
-
learning_rate: float
|
|
64
|
-
eps: float
|
|
65
|
-
beta1: float
|
|
66
|
-
beta2: float
|
|
67
|
-
weight_decay: float
|
|
68
|
-
weight_decay_mode: int
|
|
69
|
-
eta: float
|
|
70
|
-
momentum: float
|
|
71
|
-
counter_halflife: int
|
|
72
|
-
adjustment_iter: int
|
|
73
|
-
adjustment_ub: float
|
|
74
|
-
learning_rate_mode: int
|
|
75
|
-
grad_sum_decay: int
|
|
76
|
-
tail_id_threshold: float
|
|
77
|
-
is_tail_id_thresh_ratio: int
|
|
78
|
-
total_hash_size: int # Required for OptimType.NONE
|
|
79
|
-
weight_norm_coefficient: float
|
|
80
|
-
lower_bound: float
|
|
81
|
-
regularization_mode: int
|
|
82
|
-
use_rowwise_bias_correction: bool # Used for OptimType.ADAM
|
|
83
|
-
|
|
84
|
-
class CommonArgsPT2(NamedTuple):
|
|
85
|
-
placeholder_autograd_tensor: torch.Tensor
|
|
86
|
-
dev_weights: torch.Tensor
|
|
87
|
-
host_weights: torch.Tensor
|
|
88
|
-
uvm_weights: torch.Tensor
|
|
89
|
-
lxu_cache_weights: torch.Tensor
|
|
90
|
-
weights_placements: torch.Tensor
|
|
91
|
-
weights_offsets: torch.Tensor
|
|
92
|
-
D_offsets: torch.Tensor
|
|
93
|
-
total_D: int
|
|
94
|
-
max_D: int
|
|
95
|
-
hash_size_cumsum: torch.Tensor
|
|
96
|
-
total_hash_size_bits: int
|
|
97
|
-
indices: torch.Tensor
|
|
98
|
-
offsets: torch.Tensor
|
|
99
|
-
pooling_mode: int
|
|
100
|
-
indice_weights: Optional[torch.Tensor]
|
|
101
|
-
feature_requires_grad: Optional[torch.Tensor]
|
|
102
|
-
lxu_cache_locations: torch.Tensor
|
|
103
|
-
uvm_cache_stats: Optional[torch.Tensor]
|
|
104
|
-
output_dtype: int
|
|
105
|
-
vbe_metadata: VBEMetadata
|
|
106
|
-
is_experimental: bool
|
|
107
|
-
use_uniq_cache_locations_bwd: bool
|
|
108
|
-
use_homogeneous_placements: bool
|
|
60
|
+
enable_optimizer_offloading: bool
|
|
61
|
+
learning_rate_tensor: torch.Tensor
|
|
109
62
|
info_B_num_bits: int
|
|
110
63
|
info_B_mask: int
|
|
111
|
-
ssd_tensors: Dict[str, torch.Tensor]
|
|
112
64
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
65
|
+
|
|
66
|
+
# Do not add a parameter of Type tensor here. It will cause JIT script error due to a bug in PyTorch.
|
|
67
|
+
# See more detail in D71010630.
|
|
68
|
+
class OptimizerArgs(NamedTuple):
|
|
117
69
|
stochastic_rounding: bool
|
|
118
70
|
gradient_clipping: bool
|
|
119
71
|
max_gradient: float
|
|
120
72
|
max_norm: float
|
|
121
|
-
learning_rate_tensor: torch.Tensor
|
|
122
73
|
eps: float
|
|
123
74
|
beta1: float
|
|
124
75
|
beta2: float
|