fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
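
Most of the churn in the lookup invokers below reflects an interface change: the old generated code passed every weight, cache, and optimizer tensor to torch.ops.fbgemm.split_embedding_codegen_lookup_adagrad_function as its own keyword argument, while the new code packs them into ordered lists (weights, momentum1, aux_tensor, aux_int, aux_float, aux_bool, optim_float) and calls the _pt2 variant of the op. The following is a minimal, self-contained sketch of that dict-to-ordered-list packing pattern; pack_ordered and the example keys are illustrative only and are not part of the fbgemm_gpu API.

from typing import Dict, List, Optional

import torch


def pack_ordered(
    values: Dict[str, Optional[torch.Tensor]], expected_keys: List[str]
) -> List[Optional[torch.Tensor]]:
    # Hypothetical helper: the generated invokers inline this pattern, with one
    # assert per key, so that a frontend/backend schema mismatch fails loudly
    # instead of silently shifting positional arguments.
    packed: List[Optional[torch.Tensor]] = []
    for key in expected_keys:
        assert key in values, (
            f"{key} must be provided. "
            "Please check the frontend and backend version."
        )
        packed.append(values[key])
    return packed


# Example usage with made-up auxiliary tensors.
aux = {
    "B_offsets": torch.tensor([0, 4, 8]),
    "lxu_cache_locations": None,  # optional entries stay as None placeholders
}
aux_tensor = pack_ordered(aux, ["B_offsets", "lxu_cache_locations"])
print(aux_tensor)  # [tensor([0, 4, 8]), None]
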
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py
@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################
 
+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@
 
 import torch
 from .lookup_args import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -24,88 +32,191 @@ def invoke(
     # only pass prev_iter_dev since prev_iter is never created on UVM
     prev_iter_dev: Optional[torch.Tensor] = None,
     gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
 ) -> torch.Tensor:
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
-    if (common_args.host_weights.numel() > 0):
-        T = common_args.D_offsets.numel() - 1
-        vbe: bool = vbe_metadata.B_offsets is not None
-        if vbe:
-            # create offsets with fixed batch size max_B
-            # not efficient but for now we just need a functional implementation for CPU
-            max_B = vbe_metadata.max_B
-            offsets = torch.empty([T * max_B + 1], dtype=common_args.offsets.dtype, device=common_args.offsets.device)
-            for t in range(T):
-                B_offsets = vbe_metadata.B_offsets
-                assert isinstance(B_offsets, torch.Tensor)
-                begin = B_offsets[t]
-                end = B_offsets[t + 1]
-                offsets[t * max_B : t * max_B + end - begin] = common_args.offsets[begin : end]
-                offsets[t * max_B + end - begin : (t + 1) * max_B] = common_args.offsets[end]
-            offsets[-1] = common_args.offsets[-1]
-        else:
-            offsets = common_args.offsets
-        output = torch.ops.fbgemm.split_embedding_codegen_lookup_adagrad_function_cpu(
-            # common_args
-            host_weights=common_args.host_weights,
-            weights_placements=common_args.weights_placements,
-            weights_offsets=common_args.weights_offsets,
-            D_offsets=common_args.D_offsets,
-            total_D=common_args.total_D,
-            max_D=common_args.max_D,
-            hash_size_cumsum=common_args.hash_size_cumsum,
-            total_hash_size_bits=common_args.total_hash_size_bits,
-            indices=common_args.indices,
-            offsets=offsets,
-            pooling_mode=common_args.pooling_mode,
-            indice_weights=common_args.indice_weights,
-            feature_requires_grad=common_args.feature_requires_grad,
-            # optimizer_args
-            gradient_clipping = optimizer_args.gradient_clipping,
-            max_gradient=optimizer_args.max_gradient,
-            stochastic_rounding=optimizer_args.stochastic_rounding,
-            learning_rate=optimizer_args.learning_rate,
-            eps=optimizer_args.eps,
-            # momentum1
-            momentum1_host=momentum1.host,
-            momentum1_offsets=momentum1.offsets,
-            momentum1_placements=momentum1.placements,
-            # momentum2
-            # prev_iter
-            # row_counter
-            # iter
-            # max counter
-        )
-        if vbe:
-            output_new = torch.empty([vbe_metadata.output_size], dtype=output.dtype, device=output.device)
-            B_offsets_rank_per_feature = vbe_metadata.B_offsets_rank_per_feature
-            assert isinstance(B_offsets_rank_per_feature, torch.Tensor)
-            output_offsets_feature_rank = vbe_metadata.output_offsets_feature_rank
-            assert isinstance(output_offsets_feature_rank, torch.Tensor)
-            R = B_offsets_rank_per_feature.size(1) - 1
-            for r in range(R):
-                D_offset = 0
-                for t in range(T):
-                    o_begin = output_offsets_feature_rank[r * T + t].item()
-                    o_end = output_offsets_feature_rank[r * T + t + 1].item()
-                    D = common_args.D_offsets[t + 1].item() - common_args.D_offsets[t].item()
-                    b_begin = B_offsets_rank_per_feature[t][r].item()
-                    b_end = B_offsets_rank_per_feature[t][r + 1].item()
-                    assert o_end - o_begin == (b_end - b_begin) * D
-                    output_new[o_begin : o_end] = output[b_begin : b_end, D_offset : D_offset + D].flatten()
-                    D_offset += D
-            return output_new
-        else:
-            return output
-
-    return torch.ops.fbgemm.split_embedding_codegen_lookup_adagrad_function(
+
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+    dict_optim_float["eps"] = optimizer_args.eps
+
+    momentum1_list = [
+        momentum1.dev,
+        momentum1.uvm,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if not use_cpu else [
+        momentum1.host,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if momentum1 is not None else None
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+
+    # optim_int
+    # optim_float
+    # ['momentum1', 'learning_rate_tensor', 'optim_float']
+    optim_float: List[float] = []
+    optim_float.append(dict_optim_float["eps"])
+    # optim_bool
+
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_adagrad_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-        dev_weights=common_args.dev_weights,
-        uvm_weights=common_args.uvm_weights,
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -116,40 +227,32 @@ def invoke(
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-        lxu_cache_locations=common_args.lxu_cache_locations,
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-        # optimizer_args
-        gradient_clipping = optimizer_args.gradient_clipping,
-        max_gradient=optimizer_args.max_gradient,
-        stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
-        # V1 interface still accepts learning_rate as float
-        learning_rate=optimizer_args.learning_rate,
-        eps=optimizer_args.eps,
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
-        momentum1_dev=momentum1.dev,
-        momentum1_uvm=momentum1.uvm,
-        momentum1_offsets=momentum1.offsets,
-        momentum1_placements=momentum1.placements,
+        momentum1 = momentum1_list,
         # momentum2
         # prev_iter
-
-        prev_iter_dev=prev_iter_dev,
         # row_counter
-        # iter
-        iter=iter,
-        # max counter
+        # optim_tensor
+        # optim_int
+        # optim_float
+        optim_float=optim_float,
+        # optim_bool
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )
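
The removed CPU-only branch above handled variable batch size embeddings (VBE) by expanding the jagged per-feature offsets to a fixed per-feature batch size max_B before calling the dense CPU op. Below is a minimal standalone sketch of that expansion with illustrative shapes and values (not taken from fbgemm_gpu), assuming B_offsets is the prefix sum of per-feature batch sizes and offsets is the usual CSR bag-offsets tensor.

import torch

# Illustrative inputs: T = 2 features with per-feature batch sizes 2 and 3, max_B = 3.
T, max_B = 2, 3
B_offsets = torch.tensor([0, 2, 5])          # prefix sum of per-feature batch sizes
offsets = torch.tensor([0, 1, 3, 4, 6, 9])   # CSR offsets over all (feature, sample) bags

# Expand to a fixed batch size max_B per feature; padding slots repeat the last
# real offset of the feature, so the padded bags are empty.
expanded = torch.empty([T * max_B + 1], dtype=offsets.dtype)
for t in range(T):
    begin, end = int(B_offsets[t]), int(B_offsets[t + 1])
    expanded[t * max_B : t * max_B + end - begin] = offsets[begin:end]
    expanded[t * max_B + end - begin : (t + 1) * max_B] = offsets[end]
expanded[-1] = offsets[-1]
print(expanded)  # tensor([0, 1, 3, 3, 4, 6, 9])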