fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py
@@ -4,6 +4,8 @@
  ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
  ################################################################################

+ __template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
  #!/usr/bin/env python3

  # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@

  import torch
  from .lookup_args import *
+
+
+
+
+
+
  def invoke(
  common_args: CommonArgs,
  optimizer_args: OptimizerArgs,
@@ -24,18 +32,180 @@ def invoke(
  # only pass prev_iter_dev since prev_iter is never created on UVM
  prev_iter_dev: Optional[torch.Tensor] = None,
  gwd_lower_bound: float = 0.0,
+ mixed_D: bool = True,
  ) -> torch.Tensor:
-
+ # host_weights is only used for CPU training
+ use_cpu = common_args.host_weights.numel() > 0
  vbe_metadata = common_args.vbe_metadata

- return torch.ops.fbgemm.split_embedding_codegen_lookup_none_function(
+ # pack weights
+ weights = [
+ common_args.dev_weights,
+ common_args.uvm_weights,
+ common_args.weights_placements,
+ common_args.weights_offsets,
+ common_args.lxu_cache_weights,
+ ] if not use_cpu else [
+ common_args.host_weights,
+ common_args.weights_placements,
+ common_args.weights_offsets,
+ ]
+ dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+ "B_offsets": vbe_metadata.B_offsets,
+ "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+ "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+ "lxu_cache_locations": common_args.lxu_cache_locations,
+ "uvm_cache_stats": common_args.uvm_cache_stats,
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+ }
+
+ dict_aux_int: Dict[str, int] = {
+ "iter": iter,
+ "info_B_num_bits": common_args.info_B_num_bits,
+ "info_B_mask": common_args.info_B_mask,
+ }
+
+ dict_aux_float: Dict[str, float] = {
+ "gwd_lower_bound": gwd_lower_bound,
+ }
+
+ dict_aux_bool: Dict[str, bool] = {
+ "is_experimental_tbe": common_args.is_experimental,
+ "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+ "use_homogeneous_placements": common_args.use_homogeneous_placements,
+ "apply_global_weight_decay": apply_global_weight_decay,
+ "mixed_D": mixed_D,
+ }
+ dict_optim_int: Dict[str, int] = {}
+ dict_optim_float: Dict[str, float] = {}
+ dict_optim_bool: Dict[str, bool] = {}
+
+ # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+ dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+ # optimizer_args
+ dict_optim_int["total_hash_size"] = optimizer_args.total_hash_size # if optimizer == none
+ dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+ dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+ dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+
+
+ aux_tensor: List[Optional[torch.Tensor]] = []
+ assert "B_offsets" in dict_aux_tensor, (
+ "B_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["B_offsets"])
+ assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+ "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+ assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+ "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+ assert "lxu_cache_locations" in dict_aux_tensor, (
+ "lxu_cache_locations must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+ assert "uvm_cache_stats" in dict_aux_tensor, (
+ "uvm_cache_stats must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+ assert "prev_iter_dev" in dict_aux_tensor, (
+ "prev_iter_dev must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+ assert "vbe_output_offsets" in dict_aux_tensor, (
+ "vbe_output_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+ aux_int: List[int] = []
+ assert "iter" in dict_aux_int, (
+ "iter must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["iter"])
+ assert "info_B_num_bits" in dict_aux_int, (
+ "info_B_num_bits must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["info_B_num_bits"])
+ assert "info_B_mask" in dict_aux_int, (
+ "info_B_mask must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["info_B_mask"])
+
+ aux_float: List[float] = []
+ assert "gwd_lower_bound" in dict_aux_float, (
+ "gwd_lower_bound must be in dict_aux_float. "
+ "Please check the frontend and backend version. "
+ )
+ aux_float.append(dict_aux_float["gwd_lower_bound"])
+ assert "max_gradient" in dict_aux_float, (
+ "max_gradient must be in dict_aux_float. "
+ "Please check the frontend and backend version. "
+ )
+ aux_float.append(dict_aux_float["max_gradient"])
+
+ aux_bool: List[bool] = []
+ assert "is_experimental_tbe" in dict_aux_bool, (
+ "is_experimental_tbe must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+ assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+ "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+ assert "use_homogeneous_placements" in dict_aux_bool, (
+ "use_homogeneous_placements must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+ assert "apply_global_weight_decay" in dict_aux_bool, (
+ "apply_global_weight_decay must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+ assert "gradient_clipping" in dict_aux_bool, (
+ "gradient_clipping must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["gradient_clipping"])
+ assert "stochastic_rounding" in dict_aux_bool, (
+ "stochastic_rounding must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["stochastic_rounding"])
+ assert "mixed_D" in dict_aux_bool, (
+ "mixed_D must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["mixed_D"])
+
+ # optim_int
+ optim_int: List[int] = []
+ optim_int.append(dict_optim_int["total_hash_size"])
+ # optim_float
+ # ['optim_int', 'total_unique_indices']
+ # optim_bool
+
+ return torch.ops.fbgemm.split_embedding_codegen_lookup_none_function_pt2(
  # common_args
  placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
- dev_weights=common_args.dev_weights,
- uvm_weights=common_args.uvm_weights,
- lxu_cache_weights=common_args.lxu_cache_weights,
- weights_placements=common_args.weights_placements,
- weights_offsets=common_args.weights_offsets,
+ # weights
+ weights=weights,
  D_offsets=common_args.D_offsets,
  total_D=common_args.total_D,
  max_D=common_args.max_D,
@@ -46,32 +216,31 @@ def invoke(
  pooling_mode=common_args.pooling_mode,
  indice_weights=common_args.indice_weights,
  feature_requires_grad=common_args.feature_requires_grad,
- lxu_cache_locations=common_args.lxu_cache_locations,
- uvm_cache_stats=common_args.uvm_cache_stats,
+ output_dtype=common_args.output_dtype,
  # VBE metadata
- B_offsets=vbe_metadata.B_offsets,
- vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
- vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
  max_B=vbe_metadata.max_B,
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
  vbe_output_size=vbe_metadata.output_size,
- # optimizer_args
- total_hash_size = optimizer_args.total_hash_size, # if optimizer == none
+ vbe_output=vbe_metadata.vbe_output,
+ # aux_tensor
+ aux_tensor=aux_tensor,
+ # aux_int
+ aux_int=aux_int,
+ # aux_float
+ aux_float=aux_float,
+ # aux_bool
+ aux_bool=aux_bool,
+
  # momentum1
  # momentum2
  # prev_iter
-
- prev_iter_dev=prev_iter_dev,
  # row_counter
- # iter
- iter=iter,
- # max counter
+ # optim_tensor
+ # optim_int
+ optim_int=optim_int,
+ # optim_float
+ # optim_bool
+ # optim symint args
  # total_unique_indices
- total_unique_indices = total_unique_indices,
- output_dtype=common_args.output_dtype,
- is_experimental=common_args.is_experimental,
- use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
- use_homogeneous_placements=common_args.use_homogeneous_placements,
- apply_global_weight_decay=apply_global_weight_decay,
- gwd_lower_bound=gwd_lower_bound,
+ total_unique_indices=total_unique_indices,
  )
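
Both invoker diffs in this section follow the same pattern: the old flat keyword arguments of the v1 lookup op are regrouped into packed lists (weights, aux_tensor, aux_int, aux_float, aux_bool, optim_*) before calling the new *_pt2 op, and each list is built by repeating an assert-then-append over a name-keyed dict. A minimal sketch of that packing pattern is shown below; the helper name is hypothetical (the generated code above inlines the loop), and the key order is taken from the lookup_none diff.

    from typing import Dict, List, Optional

    import torch

    def pack_in_order(
        aux: Dict[str, Optional[torch.Tensor]],
        expected_keys: List[str],
    ) -> List[Optional[torch.Tensor]]:
        # Flatten a name-keyed dict into a positional list in a fixed key order,
        # failing loudly when the frontend and backend disagree on the schema.
        packed: List[Optional[torch.Tensor]] = []
        for key in expected_keys:
            assert key in aux, (
                f"{key} must be in the aux dict. "
                "Please check the frontend and backend version."
            )
            packed.append(aux[key])
        return packed

    # Key order assumed from the lookup_none diff above:
    # aux_tensor = pack_in_order(dict_aux_tensor, [
    #     "B_offsets", "vbe_output_offsets_feature_rank",
    #     "vbe_B_offsets_rank_per_feature", "lxu_cache_locations",
    #     "uvm_cache_stats", "prev_iter_dev", "vbe_output_offsets",
    # ])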
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py
@@ -4,6 +4,8 @@
  ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
  ################################################################################

+ __template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
  #!/usr/bin/env python3

  # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@

  import torch
  from .lookup_args import *
+
+
+
+
+
+
  def invoke(
  common_args: CommonArgs,
  optimizer_args: OptimizerArgs,
@@ -25,18 +33,208 @@ def invoke(
  # only pass prev_iter_dev since prev_iter is never created on UVM
  prev_iter_dev: Optional[torch.Tensor] = None,
  gwd_lower_bound: float = 0.0,
+ mixed_D: bool = True,
  ) -> torch.Tensor:
-
+ # host_weights is only used for CPU training
+ use_cpu = common_args.host_weights.numel() > 0
  vbe_metadata = common_args.vbe_metadata

- return torch.ops.fbgemm.split_embedding_codegen_lookup_partial_rowwise_adam_function(
+ # pack weights
+ weights = [
+ common_args.dev_weights,
+ common_args.uvm_weights,
+ common_args.weights_placements,
+ common_args.weights_offsets,
+ common_args.lxu_cache_weights,
+ ] if not use_cpu else [
+ common_args.host_weights,
+ common_args.weights_placements,
+ common_args.weights_offsets,
+ ]
+ dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+ "B_offsets": vbe_metadata.B_offsets,
+ "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+ "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+ "lxu_cache_locations": common_args.lxu_cache_locations,
+ "uvm_cache_stats": common_args.uvm_cache_stats,
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+ }
+
+ dict_aux_int: Dict[str, int] = {
+ "iter": iter,
+ "info_B_num_bits": common_args.info_B_num_bits,
+ "info_B_mask": common_args.info_B_mask,
+ }
+
+ dict_aux_float: Dict[str, float] = {
+ "gwd_lower_bound": gwd_lower_bound,
+ }
+
+ dict_aux_bool: Dict[str, bool] = {
+ "is_experimental_tbe": common_args.is_experimental,
+ "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+ "use_homogeneous_placements": common_args.use_homogeneous_placements,
+ "apply_global_weight_decay": apply_global_weight_decay,
+ "mixed_D": mixed_D,
+ }
+ dict_optim_int: Dict[str, int] = {}
+ dict_optim_float: Dict[str, float] = {}
+ dict_optim_bool: Dict[str, bool] = {}
+
+ # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+ dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+ # optimizer_args # if optimizer == none
+ dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+ dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+ dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+ dict_optim_float["eps"] = optimizer_args.eps
+ dict_optim_float["beta1"] = optimizer_args.beta1
+ dict_optim_float["beta2"] = optimizer_args.beta2
+ dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+
+ momentum1_list = [
+ momentum1.dev,
+ momentum1.uvm,
+ momentum1.placements,
+ momentum1.offsets,
+ ] if not use_cpu else [
+ momentum1.host,
+ momentum1.placements,
+ momentum1.offsets,
+ ] if momentum1 is not None else None
+
+ momentum2_list = [
+ momentum2.dev,
+ momentum2.uvm,
+ momentum2.placements,
+ momentum2.offsets,
+ ] if not use_cpu else [
+ momentum2.host,
+ momentum2.placements,
+ momentum2.offsets,
+ ] if momentum2 is not None else None
+
+
+ aux_tensor: List[Optional[torch.Tensor]] = []
+ assert "B_offsets" in dict_aux_tensor, (
+ "B_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["B_offsets"])
+ assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+ "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+ assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+ "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+ assert "lxu_cache_locations" in dict_aux_tensor, (
+ "lxu_cache_locations must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+ assert "uvm_cache_stats" in dict_aux_tensor, (
+ "uvm_cache_stats must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+ assert "prev_iter_dev" in dict_aux_tensor, (
+ "prev_iter_dev must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+ assert "vbe_output_offsets" in dict_aux_tensor, (
+ "vbe_output_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+ aux_int: List[int] = []
+ assert "iter" in dict_aux_int, (
+ "iter must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["iter"])
+ assert "info_B_num_bits" in dict_aux_int, (
+ "info_B_num_bits must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["info_B_num_bits"])
+ assert "info_B_mask" in dict_aux_int, (
+ "info_B_mask must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["info_B_mask"])
+
+ aux_float: List[float] = []
+ assert "gwd_lower_bound" in dict_aux_float, (
+ "gwd_lower_bound must be in dict_aux_float. "
+ "Please check the frontend and backend version. "
+ )
+ aux_float.append(dict_aux_float["gwd_lower_bound"])
+ assert "max_gradient" in dict_aux_float, (
+ "max_gradient must be in dict_aux_float. "
+ "Please check the frontend and backend version. "
+ )
+ aux_float.append(dict_aux_float["max_gradient"])
+
+ aux_bool: List[bool] = []
+ assert "is_experimental_tbe" in dict_aux_bool, (
+ "is_experimental_tbe must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+ assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+ "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+ assert "use_homogeneous_placements" in dict_aux_bool, (
+ "use_homogeneous_placements must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+ assert "apply_global_weight_decay" in dict_aux_bool, (
+ "apply_global_weight_decay must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+ assert "gradient_clipping" in dict_aux_bool, (
+ "gradient_clipping must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["gradient_clipping"])
+ assert "stochastic_rounding" in dict_aux_bool, (
+ "stochastic_rounding must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["stochastic_rounding"])
+ assert "mixed_D" in dict_aux_bool, (
+ "mixed_D must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["mixed_D"])
+
+ # optim_int
+ # optim_float
+ # ['momentum1', 'momentum2', 'learning_rate_tensor', 'optim_float']
+ optim_float: List[float] = []
+ optim_float.append(dict_optim_float["eps"])
+ optim_float.append(dict_optim_float["beta1"])
+ optim_float.append(dict_optim_float["beta2"])
+ optim_float.append(dict_optim_float["weight_decay"])
+ # optim_bool
+
+ return torch.ops.fbgemm.split_embedding_codegen_lookup_partial_rowwise_adam_function_pt2(
  # common_args
  placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
- dev_weights=common_args.dev_weights,
- uvm_weights=common_args.uvm_weights,
- lxu_cache_weights=common_args.lxu_cache_weights,
- weights_placements=common_args.weights_placements,
- weights_offsets=common_args.weights_offsets,
+ # weights
+ weights=weights,
  D_offsets=common_args.D_offsets,
  total_D=common_args.total_D,
  max_D=common_args.max_D,
@@ -47,47 +245,33 @@ def invoke(
  pooling_mode=common_args.pooling_mode,
  indice_weights=common_args.indice_weights,
  feature_requires_grad=common_args.feature_requires_grad,
- lxu_cache_locations=common_args.lxu_cache_locations,
- uvm_cache_stats=common_args.uvm_cache_stats,
+ output_dtype=common_args.output_dtype,
  # VBE metadata
- B_offsets=vbe_metadata.B_offsets,
- vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
- vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
  max_B=vbe_metadata.max_B,
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
  vbe_output_size=vbe_metadata.output_size,
- # optimizer_args
- gradient_clipping = optimizer_args.gradient_clipping,
- max_gradient=optimizer_args.max_gradient,
- stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
- # V1 interface still accepts learning_rate as float
- learning_rate=optimizer_args.learning_rate,
- eps=optimizer_args.eps,
- beta1=optimizer_args.beta1,
- beta2=optimizer_args.beta2,
- weight_decay=optimizer_args.weight_decay,
+ vbe_output=vbe_metadata.vbe_output,
+ # aux_tensor
+ aux_tensor=aux_tensor,
+ # aux_int
+ aux_int=aux_int,
+ # aux_float
+ aux_float=aux_float,
+ # aux_bool
+ aux_bool=aux_bool,
+ learning_rate_tensor=common_args.learning_rate_tensor,
+
  # momentum1
- momentum1_dev=momentum1.dev,
- momentum1_uvm=momentum1.uvm,
- momentum1_offsets=momentum1.offsets,
- momentum1_placements=momentum1.placements,
+ momentum1 = momentum1_list,
  # momentum2
- momentum2_dev=momentum2.dev,
- momentum2_uvm=momentum2.uvm,
- momentum2_offsets=momentum2.offsets,
- momentum2_placements=momentum2.placements,
+ momentum2=momentum2_list,
  # prev_iter
-
- prev_iter_dev=prev_iter_dev,
  # row_counter
- # iter
- iter=iter,
- # max counter
+ # optim_tensor
+ # optim_int
+ # optim_float
+ optim_float=optim_float,
+ # optim_bool
+ # optim symint args
  # total_unique_indices
- output_dtype=common_args.output_dtype,
- is_experimental=common_args.is_experimental,
- use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
- use_homogeneous_placements=common_args.use_homogeneous_placements,
- apply_global_weight_decay=apply_global_weight_decay,
- gwd_lower_bound=gwd_lower_bound,
  )
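
The partial_rowwise_adam invoker additionally packs the optimizer state with the same GPU/CPU split used for the weights: the GPU path passes (dev, uvm, placements, offsets) tensors, the CPU path (host, placements, offsets). A minimal sketch of that convention follows; the helper name is hypothetical, and the attribute names are taken from the momentum1_list/momentum2_list packing in the diff above.

    from typing import List, Optional

    import torch

    def pack_state(state, use_cpu: bool) -> Optional[List[torch.Tensor]]:
        # Mirror the momentum packing shown in the diff: CPU training uses
        # host tensors; GPU training uses dev/uvm tensors plus placements/offsets.
        if state is None:
            return None
        if use_cpu:
            return [state.host, state.placements, state.offsets]
        return [state.dev, state.uvm, state.placements, state.offsets]

Grouping weights and optimizer state into lists like this appears intended to keep the *_pt2 op signature stable across optimizers; only the contents of the optim_* lists and the state lists change from one generated invoker to the next.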