fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
```diff
--- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py (2025.3.27)
+++ fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py (2026.1.29)
@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################

+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3

 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@

 import torch
 from .lookup_args import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -23,84 +31,177 @@ def invoke(
     # only pass prev_iter_dev since prev_iter is never created on UVM
     prev_iter_dev: Optional[torch.Tensor] = None,
     gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
 ) -> torch.Tensor:
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
-[... 68 lines removed here (old lines 29-96); they are rendered blank in the source diff view ...]
+
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets": vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+
+    # optim_int
+    # optim_float
+    # ['learning_rate_tensor']
+    # optim_bool
+
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_sgd_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-
-
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -111,35 +212,30 @@ def invoke(
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-
-
-
-
-
-
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
         # momentum2
         # prev_iter
-
-        prev_iter_dev=prev_iter_dev,
         # row_counter
-        #
-
-        #
+        # optim_tensor
+        # optim_int
+        # optim_float
+        # optim_bool
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )
```
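The diff above shows the new generated invoker replacing the old flat keyword-argument call with a few ordered lists (`weights`, `aux_tensor`, `aux_int`, `aux_float`, `aux_bool`), each packed from a dict so that a frontend/backend schema mismatch fails with a named assertion instead of silently misaligning positional arguments. Below is a minimal, self-contained sketch of that dict-to-list packing pattern; the `pack` helper, the key order, and the sample values are hypothetical illustrations, not part of the fbgemm_gpu API:

```python
from typing import Dict, List


def pack(d: Dict[str, object], expected_keys: List[str]) -> List[object]:
    """Pack a dict into a positional list in a fixed key order.

    Mirrors the generated invoker above: every expected key must be
    present, so a schema mismatch between the Python frontend and the
    compiled backend raises a clear error rather than shifting the
    meaning of later positional slots.
    """
    out: List[object] = []
    for key in expected_keys:
        assert key in d, (
            f"{key} must be provided. "
            "Please check the frontend and backend version."
        )
        out.append(d[key])
    return out


# Illustrative only: the real key order is fixed by the codegen template.
aux_int = pack(
    {"iter": 0, "info_B_num_bits": 26, "info_B_mask": (1 << 26) - 1},
    ["iter", "info_B_num_bits", "info_B_mask"],
)
print(aux_int)  # [0, 26, 67108863]
```

Packing by name into a fixed positional order keeps the `*_pt2` op schema stable as new auxiliary arguments are added: new keys extend the list tail, and the assertions surface any version skew at call time.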