fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -66,14 +66,15 @@ def invoke(
66
66
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
67
67
  "lxu_cache_locations": common_args.lxu_cache_locations,
68
68
  "uvm_cache_stats": common_args.uvm_cache_stats,
69
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
69
70
  }
70
71
 
71
72
  dict_aux_int: Dict[str, int] = {
72
- "iter": iter,
73
- "info_B_num_bits": common_args.info_B_num_bits,
73
+ "iter": iter,
74
+ "info_B_num_bits": common_args.info_B_num_bits,
74
75
  "info_B_mask": common_args.info_B_mask,
75
76
  }
76
-
77
+
77
78
  dict_aux_float: Dict[str, float] = {
78
79
  "gwd_lower_bound": gwd_lower_bound,
79
80
  }
@@ -91,7 +92,7 @@ def invoke(
91
92
 
92
93
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
93
94
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
94
-
95
+
95
96
 
96
97
  # optimizer_args # if optimizer == none
97
98
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -156,6 +157,11 @@ def invoke(
156
157
  "Please check the frontend and backend version. "
157
158
  )
158
159
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
160
+ assert "vbe_output_offsets" in dict_aux_tensor, (
161
+ "vbe_output_offsets must be in dict_aux_tensor. "
162
+ "Please check the frontend and backend version. "
163
+ )
164
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
159
165
 
160
166
  aux_int: List[int] = []
161
167
  assert "iter" in dict_aux_int, (
@@ -231,7 +237,7 @@ def invoke(
231
237
  optim_float.append(dict_optim_float["beta1"])
232
238
  optim_float.append(dict_optim_float["beta2"])
233
239
  optim_float.append(dict_optim_float["weight_decay"])
234
- # optim_bool
240
+ # optim_bool
235
241
 
236
242
  return torch.ops.fbgemm.split_embedding_codegen_lookup_lamb_function_pt2(
237
243
  # common_args
@@ -253,6 +259,7 @@ def invoke(
253
259
  max_B=vbe_metadata.max_B,
254
260
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
255
261
  vbe_output_size=vbe_metadata.output_size,
262
+ vbe_output=vbe_metadata.vbe_output,
256
263
  # aux_tensor
257
264
  aux_tensor=aux_tensor,
258
265
  # aux_int
@@ -65,14 +65,15 @@ def invoke(
65
65
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
66
66
  "lxu_cache_locations": common_args.lxu_cache_locations,
67
67
  "uvm_cache_stats": common_args.uvm_cache_stats,
68
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
68
69
  }
69
70
 
70
71
  dict_aux_int: Dict[str, int] = {
71
- "iter": iter,
72
- "info_B_num_bits": common_args.info_B_num_bits,
72
+ "iter": iter,
73
+ "info_B_num_bits": common_args.info_B_num_bits,
73
74
  "info_B_mask": common_args.info_B_mask,
74
75
  }
75
-
76
+
76
77
  dict_aux_float: Dict[str, float] = {
77
78
  "gwd_lower_bound": gwd_lower_bound,
78
79
  }
@@ -90,7 +91,7 @@ def invoke(
90
91
 
91
92
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
92
93
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
93
-
94
+
94
95
 
95
96
  # optimizer_args # if optimizer == none
96
97
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -143,6 +144,11 @@ def invoke(
143
144
  "Please check the frontend and backend version. "
144
145
  )
145
146
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
147
+ assert "vbe_output_offsets" in dict_aux_tensor, (
148
+ "vbe_output_offsets must be in dict_aux_tensor. "
149
+ "Please check the frontend and backend version. "
150
+ )
151
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
146
152
 
147
153
  aux_int: List[int] = []
148
154
  assert "iter" in dict_aux_int, (
@@ -217,7 +223,7 @@ def invoke(
217
223
  optim_float.append(dict_optim_float["eta"])
218
224
  optim_float.append(dict_optim_float["momentum"])
219
225
  optim_float.append(dict_optim_float["weight_decay"])
220
- # optim_bool
226
+ # optim_bool
221
227
 
222
228
  return torch.ops.fbgemm.split_embedding_codegen_lookup_lars_sgd_function_pt2(
223
229
  # common_args
@@ -239,6 +245,7 @@ def invoke(
239
245
  max_B=vbe_metadata.max_B,
240
246
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
241
247
  vbe_output_size=vbe_metadata.output_size,
248
+ vbe_output=vbe_metadata.vbe_output,
242
249
  # aux_tensor
243
250
  aux_tensor=aux_tensor,
244
251
  # aux_int
@@ -56,14 +56,15 @@ def invoke(
56
56
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
57
57
  "lxu_cache_locations": common_args.lxu_cache_locations,
58
58
  "uvm_cache_stats": common_args.uvm_cache_stats,
59
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
59
60
  }
60
61
 
61
62
  dict_aux_int: Dict[str, int] = {
62
- "iter": iter,
63
- "info_B_num_bits": common_args.info_B_num_bits,
63
+ "iter": iter,
64
+ "info_B_num_bits": common_args.info_B_num_bits,
64
65
  "info_B_mask": common_args.info_B_mask,
65
66
  }
66
-
67
+
67
68
  dict_aux_float: Dict[str, float] = {
68
69
  "gwd_lower_bound": gwd_lower_bound,
69
70
  }
@@ -81,7 +82,7 @@ def invoke(
81
82
 
82
83
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
83
84
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
84
-
85
+
85
86
 
86
87
  # optimizer_args
87
88
  dict_optim_int["total_hash_size"] = optimizer_args.total_hash_size # if optimizer == none
@@ -121,6 +122,11 @@ def invoke(
121
122
  "Please check the frontend and backend version. "
122
123
  )
123
124
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
125
+ assert "vbe_output_offsets" in dict_aux_tensor, (
126
+ "vbe_output_offsets must be in dict_aux_tensor. "
127
+ "Please check the frontend and backend version. "
128
+ )
129
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
124
130
 
125
131
  aux_int: List[int] = []
126
132
  assert "iter" in dict_aux_int, (
@@ -193,7 +199,7 @@ def invoke(
193
199
  optim_int.append(dict_optim_int["total_hash_size"])
194
200
  # optim_float
195
201
  # ['optim_int', 'total_unique_indices']
196
- # optim_bool
202
+ # optim_bool
197
203
 
198
204
  return torch.ops.fbgemm.split_embedding_codegen_lookup_none_function_pt2(
199
205
  # common_args
@@ -215,6 +221,7 @@ def invoke(
215
221
  max_B=vbe_metadata.max_B,
216
222
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
217
223
  vbe_output_size=vbe_metadata.output_size,
224
+ vbe_output=vbe_metadata.vbe_output,
218
225
  # aux_tensor
219
226
  aux_tensor=aux_tensor,
220
227
  # aux_int
@@ -57,14 +57,15 @@ def invoke(
57
57
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
58
58
  "lxu_cache_locations": common_args.lxu_cache_locations,
59
59
  "uvm_cache_stats": common_args.uvm_cache_stats,
60
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
60
61
  }
61
62
 
62
63
  dict_aux_int: Dict[str, int] = {
63
- "iter": iter,
64
- "info_B_num_bits": common_args.info_B_num_bits,
64
+ "iter": iter,
65
+ "info_B_num_bits": common_args.info_B_num_bits,
65
66
  "info_B_mask": common_args.info_B_mask,
66
67
  }
67
-
68
+
68
69
  dict_aux_float: Dict[str, float] = {
69
70
  "gwd_lower_bound": gwd_lower_bound,
70
71
  }
@@ -82,7 +83,7 @@ def invoke(
82
83
 
83
84
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
84
85
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
85
-
86
+
86
87
 
87
88
  # optimizer_args # if optimizer == none
88
89
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -147,6 +148,11 @@ def invoke(
147
148
  "Please check the frontend and backend version. "
148
149
  )
149
150
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
151
+ assert "vbe_output_offsets" in dict_aux_tensor, (
152
+ "vbe_output_offsets must be in dict_aux_tensor. "
153
+ "Please check the frontend and backend version. "
154
+ )
155
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
150
156
 
151
157
  aux_int: List[int] = []
152
158
  assert "iter" in dict_aux_int, (
@@ -222,7 +228,7 @@ def invoke(
222
228
  optim_float.append(dict_optim_float["beta1"])
223
229
  optim_float.append(dict_optim_float["beta2"])
224
230
  optim_float.append(dict_optim_float["weight_decay"])
225
- # optim_bool
231
+ # optim_bool
226
232
 
227
233
  return torch.ops.fbgemm.split_embedding_codegen_lookup_partial_rowwise_adam_function_pt2(
228
234
  # common_args
@@ -244,6 +250,7 @@ def invoke(
244
250
  max_B=vbe_metadata.max_B,
245
251
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
246
252
  vbe_output_size=vbe_metadata.output_size,
253
+ vbe_output=vbe_metadata.vbe_output,
247
254
  # aux_tensor
248
255
  aux_tensor=aux_tensor,
249
256
  # aux_int
@@ -57,14 +57,15 @@ def invoke(
57
57
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
58
58
  "lxu_cache_locations": common_args.lxu_cache_locations,
59
59
  "uvm_cache_stats": common_args.uvm_cache_stats,
60
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
60
61
  }
61
62
 
62
63
  dict_aux_int: Dict[str, int] = {
63
- "iter": iter,
64
- "info_B_num_bits": common_args.info_B_num_bits,
64
+ "iter": iter,
65
+ "info_B_num_bits": common_args.info_B_num_bits,
65
66
  "info_B_mask": common_args.info_B_mask,
66
67
  }
67
-
68
+
68
69
  dict_aux_float: Dict[str, float] = {
69
70
  "gwd_lower_bound": gwd_lower_bound,
70
71
  }
@@ -82,7 +83,7 @@ def invoke(
82
83
 
83
84
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
84
85
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
85
-
86
+
86
87
 
87
88
  # optimizer_args # if optimizer == none
88
89
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -147,6 +148,11 @@ def invoke(
147
148
  "Please check the frontend and backend version. "
148
149
  )
149
150
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
151
+ assert "vbe_output_offsets" in dict_aux_tensor, (
152
+ "vbe_output_offsets must be in dict_aux_tensor. "
153
+ "Please check the frontend and backend version. "
154
+ )
155
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
150
156
 
151
157
  aux_int: List[int] = []
152
158
  assert "iter" in dict_aux_int, (
@@ -222,7 +228,7 @@ def invoke(
222
228
  optim_float.append(dict_optim_float["beta1"])
223
229
  optim_float.append(dict_optim_float["beta2"])
224
230
  optim_float.append(dict_optim_float["weight_decay"])
225
- # optim_bool
231
+ # optim_bool
226
232
 
227
233
  return torch.ops.fbgemm.split_embedding_codegen_lookup_partial_rowwise_lamb_function_pt2(
228
234
  # common_args
@@ -244,6 +250,7 @@ def invoke(
244
250
  max_B=vbe_metadata.max_B,
245
251
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
246
252
  vbe_output_size=vbe_metadata.output_size,
253
+ vbe_output=vbe_metadata.vbe_output,
247
254
  # aux_tensor
248
255
  aux_tensor=aux_tensor,
249
256
  # aux_int
@@ -56,14 +56,15 @@ def invoke(
56
56
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
57
57
  "lxu_cache_locations": common_args.lxu_cache_locations,
58
58
  "uvm_cache_stats": common_args.uvm_cache_stats,
59
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
59
60
  }
60
61
 
61
62
  dict_aux_int: Dict[str, int] = {
62
- "iter": iter,
63
- "info_B_num_bits": common_args.info_B_num_bits,
63
+ "iter": iter,
64
+ "info_B_num_bits": common_args.info_B_num_bits,
64
65
  "info_B_mask": common_args.info_B_mask,
65
66
  }
66
-
67
+
67
68
  dict_aux_float: Dict[str, float] = {
68
69
  "gwd_lower_bound": gwd_lower_bound,
69
70
  }
@@ -81,7 +82,7 @@ def invoke(
81
82
 
82
83
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
83
84
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
84
-
85
+
85
86
 
86
87
  # optimizer_args # if optimizer == none
87
88
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -135,6 +136,11 @@ def invoke(
135
136
  "Please check the frontend and backend version. "
136
137
  )
137
138
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
139
+ assert "vbe_output_offsets" in dict_aux_tensor, (
140
+ "vbe_output_offsets must be in dict_aux_tensor. "
141
+ "Please check the frontend and backend version. "
142
+ )
143
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
138
144
 
139
145
  aux_int: List[int] = []
140
146
  assert "iter" in dict_aux_int, (
@@ -211,7 +217,7 @@ def invoke(
211
217
  optim_float.append(dict_optim_float["eps"])
212
218
  optim_float.append(dict_optim_float["weight_decay"])
213
219
  optim_float.append(dict_optim_float["max_norm"])
214
- # optim_bool
220
+ # optim_bool
215
221
 
216
222
  return torch.ops.fbgemm.split_embedding_codegen_lookup_rowwise_adagrad_function_pt2(
217
223
  # common_args
@@ -233,6 +239,7 @@ def invoke(
233
239
  max_B=vbe_metadata.max_B,
234
240
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
235
241
  vbe_output_size=vbe_metadata.output_size,
242
+ vbe_output=vbe_metadata.vbe_output,
236
243
  # aux_tensor
237
244
  aux_tensor=aux_tensor,
238
245
  # aux_int
@@ -77,14 +77,15 @@ def invoke(
77
77
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
78
78
  "lxu_cache_locations": common_args.lxu_cache_locations,
79
79
  "uvm_cache_stats": common_args.uvm_cache_stats,
80
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
80
81
  }
81
82
 
82
83
  dict_aux_int: Dict[str, int] = {
83
- "iter": iter,
84
- "info_B_num_bits": common_args.info_B_num_bits,
84
+ "iter": iter,
85
+ "info_B_num_bits": common_args.info_B_num_bits,
85
86
  "info_B_mask": common_args.info_B_mask,
86
87
  }
87
-
88
+
88
89
  dict_aux_float: Dict[str, float] = {
89
90
  "gwd_lower_bound": gwd_lower_bound,
90
91
  }
@@ -103,7 +104,7 @@ def invoke(
103
104
 
104
105
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
105
106
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
106
-
107
+
107
108
 
108
109
  # optimizer_args # if optimizer == none
109
110
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -157,6 +158,11 @@ def invoke(
157
158
  "Please check the frontend and backend version. "
158
159
  )
159
160
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
161
+ assert "vbe_output_offsets" in dict_aux_tensor, (
162
+ "vbe_output_offsets must be in dict_aux_tensor. "
163
+ "Please check the frontend and backend version. "
164
+ )
165
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
160
166
 
161
167
  aux_int: List[int] = []
162
168
  assert "iter" in dict_aux_int, (
@@ -238,7 +244,7 @@ def invoke(
238
244
  optim_float.append(dict_optim_float["eps"])
239
245
  optim_float.append(dict_optim_float["weight_decay"])
240
246
  optim_float.append(dict_optim_float["max_norm"])
241
- # optim_bool
247
+ # optim_bool
242
248
 
243
249
  return torch.ops.fbgemm.ssd_embedding_codegen_lookup_rowwise_adagrad_function_pt2(
244
250
  # common_args
@@ -261,6 +267,7 @@ def invoke(
261
267
  max_B=vbe_metadata.max_B,
262
268
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
263
269
  vbe_output_size=vbe_metadata.output_size,
270
+ vbe_output=vbe_metadata.vbe_output,
264
271
  # aux_tensor
265
272
  aux_tensor=aux_tensor,
266
273
  # aux_int
@@ -57,14 +57,15 @@ def invoke(
57
57
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
58
58
  "lxu_cache_locations": common_args.lxu_cache_locations,
59
59
  "uvm_cache_stats": common_args.uvm_cache_stats,
60
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
60
61
  }
61
62
 
62
63
  dict_aux_int: Dict[str, int] = {
63
- "iter": iter,
64
- "info_B_num_bits": common_args.info_B_num_bits,
64
+ "iter": iter,
65
+ "info_B_num_bits": common_args.info_B_num_bits,
65
66
  "info_B_mask": common_args.info_B_mask,
66
67
  }
67
-
68
+
68
69
  dict_aux_float: Dict[str, float] = {
69
70
  "gwd_lower_bound": gwd_lower_bound,
70
71
  }
@@ -82,7 +83,7 @@ def invoke(
82
83
 
83
84
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
84
85
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
85
-
86
+
86
87
 
87
88
  # optimizer_args # if optimizer == none
88
89
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -168,6 +169,11 @@ def invoke(
168
169
  "Please check the frontend and backend version. "
169
170
  )
170
171
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
172
+ assert "vbe_output_offsets" in dict_aux_tensor, (
173
+ "vbe_output_offsets must be in dict_aux_tensor. "
174
+ "Please check the frontend and backend version. "
175
+ )
176
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
171
177
 
172
178
  aux_int: List[int] = []
173
179
  assert "iter" in dict_aux_int, (
@@ -254,7 +260,7 @@ def invoke(
254
260
  optim_float.append(dict_optim_float["tail_id_threshold"])
255
261
  optim_float.append(dict_optim_float["weight_norm_coefficient"])
256
262
  optim_float.append(dict_optim_float["lower_bound"])
257
- # optim_bool
263
+ # optim_bool
258
264
 
259
265
  return torch.ops.fbgemm.split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function_pt2(
260
266
  # common_args
@@ -276,6 +282,7 @@ def invoke(
276
282
  max_B=vbe_metadata.max_B,
277
283
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
278
284
  vbe_output_size=vbe_metadata.output_size,
285
+ vbe_output=vbe_metadata.vbe_output,
279
286
  # aux_tensor
280
287
  aux_tensor=aux_tensor,
281
288
  # aux_int
@@ -55,14 +55,15 @@ def invoke(
55
55
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
56
56
  "lxu_cache_locations": common_args.lxu_cache_locations,
57
57
  "uvm_cache_stats": common_args.uvm_cache_stats,
58
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
58
59
  }
59
60
 
60
61
  dict_aux_int: Dict[str, int] = {
61
- "iter": iter,
62
- "info_B_num_bits": common_args.info_B_num_bits,
62
+ "iter": iter,
63
+ "info_B_num_bits": common_args.info_B_num_bits,
63
64
  "info_B_mask": common_args.info_B_mask,
64
65
  }
65
-
66
+
66
67
  dict_aux_float: Dict[str, float] = {
67
68
  "gwd_lower_bound": gwd_lower_bound,
68
69
  }
@@ -80,7 +81,7 @@ def invoke(
80
81
 
81
82
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
82
83
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
83
-
84
+
84
85
 
85
86
  # optimizer_args # if optimizer == none
86
87
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -119,6 +120,11 @@ def invoke(
119
120
  "Please check the frontend and backend version. "
120
121
  )
121
122
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
123
+ assert "vbe_output_offsets" in dict_aux_tensor, (
124
+ "vbe_output_offsets must be in dict_aux_tensor. "
125
+ "Please check the frontend and backend version. "
126
+ )
127
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
122
128
 
123
129
  aux_int: List[int] = []
124
130
  assert "iter" in dict_aux_int, (
@@ -189,7 +195,7 @@ def invoke(
189
195
  # optim_int
190
196
  # optim_float
191
197
  # ['learning_rate_tensor']
192
- # optim_bool
198
+ # optim_bool
193
199
 
194
200
  return torch.ops.fbgemm.split_embedding_codegen_lookup_sgd_function_pt2(
195
201
  # common_args
@@ -211,6 +217,7 @@ def invoke(
211
217
  max_B=vbe_metadata.max_B,
212
218
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
213
219
  vbe_output_size=vbe_metadata.output_size,
220
+ vbe_output=vbe_metadata.vbe_output,
214
221
  # aux_tensor
215
222
  aux_tensor=aux_tensor,
216
223
  # aux_int