fbgemm-gpu-nightly-cpu 2025.3.27-cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29-cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
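The file-level diff shown below appears to come from fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py (entry 57, +192 -96), judging by the split_embedding_codegen_lookup_sgd_function_cpu and split_embedding_codegen_lookup_sgd_function_pt2 op names in its hunks. It is representative of entries 45–57: each generated optimizer invoker drops the hand-written CPU/VBE path and the v1 per-argument lookup op in favor of a single unified pt2 op.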
@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################
 
+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@
 
 import torch
 from .lookup_args import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -23,84 +31,177 @@ def invoke(
     # only pass prev_iter_dev since prev_iter is never created on UVM
     prev_iter_dev: Optional[torch.Tensor] = None,
     gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
 ) -> torch.Tensor:
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
-    if (common_args.host_weights.numel() > 0):
-        T = common_args.D_offsets.numel() - 1
-        vbe: bool = vbe_metadata.B_offsets is not None
-        if vbe:
-            # create offsets with fixed batch size max_B
-            # not efficient but for now we just need a functional implementation for CPU
-            max_B = vbe_metadata.max_B
-            offsets = torch.empty([T * max_B + 1], dtype=common_args.offsets.dtype, device=common_args.offsets.device)
-            for t in range(T):
-                B_offsets = vbe_metadata.B_offsets
-                assert isinstance(B_offsets, torch.Tensor)
-                begin = B_offsets[t]
-                end = B_offsets[t + 1]
-                offsets[t * max_B : t * max_B + end - begin] = common_args.offsets[begin : end]
-                offsets[t * max_B + end - begin : (t + 1) * max_B] = common_args.offsets[end]
-            offsets[-1] = common_args.offsets[-1]
-        else:
-            offsets = common_args.offsets
-        output = torch.ops.fbgemm.split_embedding_codegen_lookup_sgd_function_cpu(
-            # common_args
-            host_weights=common_args.host_weights,
-            weights_placements=common_args.weights_placements,
-            weights_offsets=common_args.weights_offsets,
-            D_offsets=common_args.D_offsets,
-            total_D=common_args.total_D,
-            max_D=common_args.max_D,
-            hash_size_cumsum=common_args.hash_size_cumsum,
-            total_hash_size_bits=common_args.total_hash_size_bits,
-            indices=common_args.indices,
-            offsets=offsets,
-            pooling_mode=common_args.pooling_mode,
-            indice_weights=common_args.indice_weights,
-            feature_requires_grad=common_args.feature_requires_grad,
-            # optimizer_args
-            gradient_clipping = optimizer_args.gradient_clipping,
-            max_gradient=optimizer_args.max_gradient,
-            stochastic_rounding=optimizer_args.stochastic_rounding,
-            learning_rate=optimizer_args.learning_rate,
-            # momentum1
-            # momentum2
-            # prev_iter
-            # row_counter
-            # iter
-            # max counter
-        )
-        if vbe:
-            output_new = torch.empty([vbe_metadata.output_size], dtype=output.dtype, device=output.device)
-            B_offsets_rank_per_feature = vbe_metadata.B_offsets_rank_per_feature
-            assert isinstance(B_offsets_rank_per_feature, torch.Tensor)
-            output_offsets_feature_rank = vbe_metadata.output_offsets_feature_rank
-            assert isinstance(output_offsets_feature_rank, torch.Tensor)
-            R = B_offsets_rank_per_feature.size(1) - 1
-            for r in range(R):
-                D_offset = 0
-                for t in range(T):
-                    o_begin = output_offsets_feature_rank[r * T + t].item()
-                    o_end = output_offsets_feature_rank[r * T + t + 1].item()
-                    D = common_args.D_offsets[t + 1].item() - common_args.D_offsets[t].item()
-                    b_begin = B_offsets_rank_per_feature[t][r].item()
-                    b_end = B_offsets_rank_per_feature[t][r + 1].item()
-                    assert o_end - o_begin == (b_end - b_begin) * D
-                    output_new[o_begin : o_end] = output[b_begin : b_end, D_offset : D_offset + D].flatten()
-                    D_offset += D
-            return output_new
-        else:
-            return output
-
-    return torch.ops.fbgemm.split_embedding_codegen_lookup_sgd_function(
+
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+
+    # optim_int
+    # optim_float
+    # ['learning_rate_tensor']
+    # optim_bool
+
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_sgd_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-        dev_weights=common_args.dev_weights,
-        uvm_weights=common_args.uvm_weights,
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -111,35 +212,30 @@
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-        lxu_cache_locations=common_args.lxu_cache_locations,
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-        # optimizer_args
-        gradient_clipping = optimizer_args.gradient_clipping,
-        max_gradient=optimizer_args.max_gradient,
-        stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
-        # V1 interface still accepts learning_rate as float
-        learning_rate=optimizer_args.learning_rate,
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
         # momentum2
         # prev_iter
-
-        prev_iter_dev=prev_iter_dev,
         # row_counter
-        # iter
-        iter=iter,
-        # max counter
+        # optim_tensor
+        # optim_int
+        # optim_float
+        # optim_bool
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )
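
The pt2 calling convention above replaces dozens of named keyword arguments with a few positional lists: weights packs the GPU (dev/uvm/cache) or CPU (host) weight tensors, while aux_tensor, aux_int, aux_float, and aux_bool carry the remaining arguments in a fixed order that the Python frontend and the compiled backend must agree on; the long assert-and-append runs enforce that order one key at a time. Below is a minimal, self-contained sketch of the same packing pattern for aux_bool. It is illustrative only, not fbgemm code: the helper pack_aux_bool and the example values are invented, and only the key names and their order are taken from the diff above.

from typing import Dict, List

# Element order is the schema contract between frontend and backend;
# these key names and their order are copied from the diff above.
AUX_BOOL_ORDER: List[str] = [
    "is_experimental_tbe",
    "use_uniq_cache_locations_bwd",
    "use_homogeneous_placements",
    "apply_global_weight_decay",
    "gradient_clipping",
    "stochastic_rounding",
    "mixed_D",
]

def pack_aux_bool(dict_aux_bool: Dict[str, bool]) -> List[bool]:
    # Mirrors the generated asserts: a missing key means the Python frontend
    # and the compiled backend were generated from different schema versions.
    packed: List[bool] = []
    for key in AUX_BOOL_ORDER:
        assert key in dict_aux_bool, (
            f"{key} must be in dict_aux_bool. "
            "Please check the frontend and backend version."
        )
        packed.append(dict_aux_bool[key])
    return packed

if __name__ == "__main__":
    aux_bool = pack_aux_bool({
        "is_experimental_tbe": False,
        "use_uniq_cache_locations_bwd": False,
        "use_homogeneous_placements": True,
        "apply_global_weight_decay": False,
        "gradient_clipping": False,
        "stochastic_rounding": True,
        "mixed_D": True,
    })
    assert aux_bool == [False, False, True, False, False, True, True]

Expressing the order as a single list keeps the frontend/backend contract auditable in one place; the generated invoker instead unrolls each assert/append pair, which keeps the codegen template simple at the cost of repetition.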