fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py

@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################
 
+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@
 
 import torch
 from .lookup_args import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -25,18 +33,208 @@ def invoke(
     # only pass prev_iter_dev since prev_iter is never created on UVM
     prev_iter_dev: Optional[torch.Tensor] = None,
     gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
 ) -> torch.Tensor:
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
 
-
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+    dict_optim_float["eps"] = optimizer_args.eps
+    dict_optim_float["beta1"] = optimizer_args.beta1
+    dict_optim_float["beta2"] = optimizer_args.beta2
+    dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+
+    momentum1_list = [
+        momentum1.dev,
+        momentum1.uvm,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if not use_cpu else [
+        momentum1.host,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if momentum1 is not None else None
+
+    momentum2_list = [
+        momentum2.dev,
+        momentum2.uvm,
+        momentum2.placements,
+        momentum2.offsets,
+    ] if not use_cpu else [
+        momentum2.host,
+        momentum2.placements,
+        momentum2.offsets,
+    ] if momentum2 is not None else None
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+
+    # optim_int
+    # optim_float
+    # ['momentum1', 'momentum2', 'learning_rate_tensor', 'optim_float']
+    optim_float: List[float] = []
+    optim_float.append(dict_optim_float["eps"])
+    optim_float.append(dict_optim_float["beta1"])
+    optim_float.append(dict_optim_float["beta2"])
+    optim_float.append(dict_optim_float["weight_decay"])
+    # optim_bool
+
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_partial_rowwise_lamb_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-
-
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -47,47 +245,33 @@ def invoke(
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-
-
-
-
-
-
-
-
-
-
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
-
-        momentum1_uvm=momentum1.uvm,
-        momentum1_offsets=momentum1.offsets,
-        momentum1_placements=momentum1.placements,
+        momentum1 = momentum1_list,
         # momentum2
-
-        momentum2_uvm=momentum2.uvm,
-        momentum2_offsets=momentum2.offsets,
-        momentum2_placements=momentum2.placements,
+        momentum2=momentum2_list,
         # prev_iter
-
-        prev_iter_dev=prev_iter_dev,
         # row_counter
-        #
-
-        #
+        # optim_tensor
+        # optim_int
+        # optim_float
+        optim_float=optim_float,
+        # optim_bool
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )
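Both rewritten invokers in this release follow the same pattern: auxiliary arguments are first collected into per-type dicts and then flattened into positional lists (aux_tensor, aux_int, aux_float, aux_bool) whose ordering the PT2 backend op relies on, with an assert on every key so that a frontend/backend version mismatch fails loudly. The standalone sketch below illustrates that dict-to-ordered-list packing with a hypothetical helper (pack_aux_list) and a reduced key schema; it is illustrative only and not code from the wheel.

from typing import Dict, List, Optional

import torch


def pack_aux_list(
    schema: List[str],
    values: Dict[str, Optional[torch.Tensor]],
) -> List[Optional[torch.Tensor]]:
    """Flatten a keyword dict into the positional list a backend op expects.

    The schema is the ordered key list that frontend and backend must agree
    on; a missing key indicates a version mismatch, which is why the
    generated invokers assert on every key before appending it.
    """
    packed: List[Optional[torch.Tensor]] = []
    for key in schema:
        assert key in values, (
            f"{key} must be provided. "
            "Please check the frontend and backend version."
        )
        packed.append(values[key])
    return packed


# Hypothetical usage with a reduced schema (the real invokers use keys such
# as B_offsets, lxu_cache_locations, uvm_cache_stats, prev_iter_dev).
aux_tensor = pack_aux_list(
    ["B_offsets", "lxu_cache_locations", "prev_iter_dev"],
    {
        "B_offsets": torch.tensor([0, 4, 8]),
        "lxu_cache_locations": None,
        "prev_iter_dev": None,
    },
)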
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py

@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################
 
+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@
 
 import torch
 from .lookup_args import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -24,91 +32,198 @@ def invoke(
     # only pass prev_iter_dev since prev_iter is never created on UVM
     prev_iter_dev: Optional[torch.Tensor] = None,
     gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
 ) -> torch.Tensor:
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+    dict_optim_float["eps"] = optimizer_args.eps
+    dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+    dict_optim_int["weight_decay_mode"] = optimizer_args.weight_decay_mode
+    dict_optim_float["max_norm"] = optimizer_args.max_norm
+
+    momentum1_list = [
+        momentum1.dev,
+        momentum1.uvm,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if not use_cpu else [
+        momentum1.host,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if momentum1 is not None else None
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+
+    # optim_int
+    optim_int: List[int] = []
+    optim_int.append(dict_optim_int["weight_decay_mode"])
+    # optim_float
+    # ['momentum1', 'learning_rate_tensor', 'optim_int', 'optim_float']
+    optim_float: List[float] = []
+    optim_float.append(dict_optim_float["eps"])
+    optim_float.append(dict_optim_float["weight_decay"])
+    optim_float.append(dict_optim_float["max_norm"])
+    # optim_bool
+
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_rowwise_adagrad_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-
-
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -119,43 +234,33 @@ def invoke(
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-
-
-
-
-
-
-
-
-
-
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
-
-        momentum1_uvm=momentum1.uvm,
-        momentum1_offsets=momentum1.offsets,
-        momentum1_placements=momentum1.placements,
+        momentum1 = momentum1_list,
         # momentum2
         # prev_iter
-
-        prev_iter_dev=prev_iter_dev,
         # row_counter
-        #
-
-
+        # optim_tensor
+        # optim_int
+        optim_int=optim_int,
+        # optim_float
+        optim_float=optim_float,
+        # optim_bool
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )
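The other recurring change in these invokers is the weights and optimizer-state packing: the flat dev/uvm/host keyword arguments of the old signature are replaced by a single list whose layout depends on whether training runs on CPU (host_weights is non-empty) or on device. A minimal sketch of that branching follows, using a hypothetical pack_weights helper and plain tensor stand-ins rather than the real CommonArgs fields; it is not code from the wheel.

from typing import List

import torch


def pack_weights(
    host_weights: torch.Tensor,
    dev_weights: torch.Tensor,
    uvm_weights: torch.Tensor,
    weights_placements: torch.Tensor,
    weights_offsets: torch.Tensor,
    lxu_cache_weights: torch.Tensor,
) -> List[torch.Tensor]:
    # Mirrors the branch in the generated invokers: host_weights is only
    # populated for CPU training, so a non-empty host_weights selects the
    # shorter CPU layout; otherwise the device/UVM layout is used.
    use_cpu = host_weights.numel() > 0
    if use_cpu:
        return [host_weights, weights_placements, weights_offsets]
    return [
        dev_weights,
        uvm_weights,
        weights_placements,
        weights_offsets,
        lxu_cache_weights,
    ]


# Hypothetical usage: an empty host_weights tensor selects the device layout.
weights = pack_weights(
    host_weights=torch.empty(0),
    dev_weights=torch.zeros(16),
    uvm_weights=torch.empty(0),
    weights_placements=torch.zeros(2, dtype=torch.int32),
    weights_offsets=torch.zeros(2, dtype=torch.int64),
    lxu_cache_weights=torch.empty(0),
)
assert len(weights) == 5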