fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py
@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################
 
+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@
 
 import torch
 from .lookup_args_ssd import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -24,8 +32,10 @@ def invoke(
     # only pass prev_iter_dev since prev_iter is never created on UVM
     prev_iter_dev: Optional[torch.Tensor] = None,
     gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
 ) -> torch.Tensor:
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
     ssd_tensors = []
     assert "row_addrs" in common_args.ssd_tensors, (
@@ -49,14 +59,198 @@ def invoke(
     )
     ssd_tensors.append(common_args.ssd_tensors["actions_count"])
 
-    return torch.ops.fbgemm.ssd_embedding_codegen_lookup_rowwise_adagrad_function(
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+        "enable_optimizer_offloading": common_args.enable_optimizer_offloading,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+    dict_optim_float["eps"] = optimizer_args.eps
+    dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+    dict_optim_int["weight_decay_mode"] = optimizer_args.weight_decay_mode
+    dict_optim_float["max_norm"] = optimizer_args.max_norm
+
+    momentum1_list = [
+        momentum1.dev,
+        momentum1.uvm,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if not use_cpu else [
+        momentum1.host,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if momentum1 is not None else None
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+    assert "enable_optimizer_offloading" in dict_aux_bool, (
+        "enable_optimizer_offloading must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["enable_optimizer_offloading"])
+
+    # optim_int
+    optim_int: List[int] = []
+    optim_int.append(dict_optim_int["weight_decay_mode"])
+    # optim_float
+    # ['momentum1', 'learning_rate_tensor', 'optim_int', 'optim_float']
+    optim_float: List[float] = []
+    optim_float.append(dict_optim_float["eps"])
+    optim_float.append(dict_optim_float["weight_decay"])
+    optim_float.append(dict_optim_float["max_norm"])
+    # optim_bool
+
+    return torch.ops.fbgemm.ssd_embedding_codegen_lookup_rowwise_adagrad_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-        dev_weights=common_args.dev_weights,
-        uvm_weights=common_args.uvm_weights,
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -67,44 +261,34 @@ def invoke(
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-        lxu_cache_locations=common_args.lxu_cache_locations,
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         ssd_tensors=ssd_tensors,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-        # optimizer_args
-        gradient_clipping = optimizer_args.gradient_clipping,
-        max_gradient=optimizer_args.max_gradient,
-        stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
-        # V1 interface still accepts learning_rate as float
-        learning_rate=optimizer_args.learning_rate,
-        eps=optimizer_args.eps,
-        weight_decay=optimizer_args.weight_decay,
-        weight_decay_mode=optimizer_args.weight_decay_mode,
-        max_norm=optimizer_args.max_norm,
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
-        momentum1_dev=momentum1.dev,
-        momentum1_uvm=momentum1.uvm,
-        momentum1_offsets=momentum1.offsets,
-        momentum1_placements=momentum1.placements,
+        momentum1 = momentum1_list,
         # momentum2
         # prev_iter
-
-        prev_iter_dev=prev_iter_dev,
         # row_counter
-        # iter
-        iter=iter,
-        # max counter
+        # optim_tensor
+        # optim_int
+        optim_int=optim_int,
+        # optim_float
+        optim_float=optim_float,
+        # optim_bool
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )
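
The hunks above swap the V1 lookup op (ssd_embedding_codegen_lookup_rowwise_adagrad_function), which took one keyword argument per tensor, for the PT2 variant (ssd_embedding_codegen_lookup_rowwise_adagrad_function_pt2) that takes packed lists: weights, aux_tensor/aux_int/aux_float/aux_bool, and optimizer-specific optim_int/optim_float. The sketch below condenses that packing pattern; it assumes only what the diff shows, and pack_weights / pack_state are hypothetical helper names, not FBGEMM APIs.

    # Minimal sketch of the packing pattern in the diff above (hypothetical helpers).
    from typing import List, Optional

    import torch

    def pack_weights(
        host_weights: torch.Tensor,
        dev_weights: torch.Tensor,
        uvm_weights: torch.Tensor,
        weights_placements: torch.Tensor,
        weights_offsets: torch.Tensor,
        lxu_cache_weights: torch.Tensor,
    ) -> List[torch.Tensor]:
        # CPU training is detected by a non-empty host_weights tensor; the packed
        # list has a different layout on the CPU and GPU paths, as in the diff.
        use_cpu = host_weights.numel() > 0
        if use_cpu:
            return [host_weights, weights_placements, weights_offsets]
        return [dev_weights, uvm_weights, weights_placements, weights_offsets, lxu_cache_weights]

    def pack_state(state, use_cpu: bool) -> Optional[List[torch.Tensor]]:
        # Simplified equivalent of the momentum1_list expression: an absent state
        # stays None, otherwise its tensors are flattened into an ordered list
        # whose layout again depends on the CPU/GPU path.
        if state is None:
            return None
        if use_cpu:
            return [state.host, state.placements, state.offsets]
        return [state.dev, state.uvm, state.placements, state.offsets]

Packing related tensors into ordered lists keeps the PT2 op signature small and stable; optimizer-specific values travel in optim_int/optim_float instead of dedicated keyword arguments such as eps or max_norm.
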
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py
@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################
 
+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@
 
 import torch
 from .lookup_args import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -25,18 +33,240 @@ def invoke(
     max_counter: float,
     apply_global_weight_decay: bool = False,
     gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
 ) -> torch.Tensor:
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
 
-    return torch.ops.fbgemm.split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function(
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter.dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+    dict_optim_float["eps"] = optimizer_args.eps
+    dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+    dict_optim_int["weight_decay_mode"] = optimizer_args.weight_decay_mode
+    dict_optim_int["counter_halflife"] = optimizer_args.counter_halflife
+    dict_optim_int["adjustment_iter"] = optimizer_args.adjustment_iter
+    dict_optim_float["adjustment_ub"] = optimizer_args.adjustment_ub
+    dict_optim_int["learning_rate_mode"] = optimizer_args.learning_rate_mode
+    dict_optim_int["grad_sum_decay"] = optimizer_args.grad_sum_decay
+    dict_optim_float["tail_id_threshold"] = optimizer_args.tail_id_threshold
+    dict_optim_int["is_tail_id_thresh_ratio"] = optimizer_args.is_tail_id_thresh_ratio
+    dict_optim_float["weight_norm_coefficient"] = optimizer_args.weight_norm_coefficient
+    dict_optim_float["lower_bound"] = optimizer_args.lower_bound
+    dict_optim_int["regularization_mode"] = optimizer_args.regularization_mode
+    dict_optim_float["max_counter"] = max_counter
+
+    momentum1_list = [
+        momentum1.dev,
+        momentum1.uvm,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if not use_cpu else [
+        momentum1.host,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if momentum1 is not None else None
+
+    prev_iter_list = [
+        prev_iter.dev,
+        prev_iter.uvm,
+        prev_iter.placements,
+        prev_iter.offsets,
+    ] if not use_cpu else [
+        prev_iter.host,
+        prev_iter.placements,
+        prev_iter.offsets,
+    ] if prev_iter is not None else None
+
+    row_counter_list = [
+        row_counter.dev,
+        row_counter.uvm,
+        row_counter.placements,
+        row_counter.offsets,
+    ] if not use_cpu else [
+        row_counter.host,
+        row_counter.placements,
+        row_counter.offsets,
+    ] if row_counter is not None else None
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+
+    # optim_int
+    optim_int: List[int] = []
+    optim_int.append(dict_optim_int["counter_halflife"])
+    optim_int.append(dict_optim_int["adjustment_iter"])
+    optim_int.append(dict_optim_int["learning_rate_mode"])
+    optim_int.append(dict_optim_int["weight_decay_mode"])
+    optim_int.append(dict_optim_int["grad_sum_decay"])
+    optim_int.append(dict_optim_int["is_tail_id_thresh_ratio"])
+    optim_int.append(dict_optim_int["regularization_mode"])
+    # optim_float
+    # ['momentum1', 'prev_iter', 'row_counter', 'learning_rate_tensor', 'optim_int', 'optim_float']
+    optim_float: List[float] = []
+    optim_float.append(dict_optim_float["eps"])
+    optim_float.append(dict_optim_float["weight_decay"])
+    optim_float.append(dict_optim_float["adjustment_ub"])
+    optim_float.append(dict_optim_float["max_counter"])
+    optim_float.append(dict_optim_float["tail_id_threshold"])
+    optim_float.append(dict_optim_float["weight_norm_coefficient"])
+    optim_float.append(dict_optim_float["lower_bound"])
+    # optim_bool
+
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-        dev_weights=common_args.dev_weights,
-        uvm_weights=common_args.uvm_weights,
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -47,59 +277,35 @@ def invoke(
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-        lxu_cache_locations=common_args.lxu_cache_locations,
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-        # optimizer_args
-        gradient_clipping = optimizer_args.gradient_clipping,
-        max_gradient=optimizer_args.max_gradient,
-        stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
-        # V1 interface still accepts learning_rate as float
-        learning_rate=optimizer_args.learning_rate,
-        eps=optimizer_args.eps,
-        weight_decay=optimizer_args.weight_decay,
-        weight_decay_mode=optimizer_args.weight_decay_mode,
-        counter_halflife=optimizer_args.counter_halflife,
-        adjustment_iter=optimizer_args.adjustment_iter,
-        adjustment_ub=optimizer_args.adjustment_ub,
-        learning_rate_mode=optimizer_args.learning_rate_mode,
-        grad_sum_decay=optimizer_args.grad_sum_decay,
-        tail_id_threshold=optimizer_args.tail_id_threshold,
-        is_tail_id_thresh_ratio=optimizer_args.is_tail_id_thresh_ratio,
-        weight_norm_coefficient=optimizer_args.weight_norm_coefficient,
-        lower_bound=optimizer_args.lower_bound,
-        regularization_mode=optimizer_args.regularization_mode,
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
-        momentum1_dev=momentum1.dev,
-        momentum1_uvm=momentum1.uvm,
-        momentum1_offsets=momentum1.offsets,
-        momentum1_placements=momentum1.placements,
+        momentum1 = momentum1_list,
         # momentum2
         # prev_iter
-        prev_iter_dev=prev_iter.dev,
-        prev_iter_uvm=prev_iter.uvm,
-        prev_iter_offsets=prev_iter.offsets,
-        prev_iter_placements=prev_iter.placements,
+        prev_iter=prev_iter_list,
         # row_counter
-        row_counter_dev=row_counter.dev,
-        row_counter_uvm=row_counter.uvm,
-        row_counter_offsets=row_counter.offsets,
-        row_counter_placements=row_counter.placements,
-        # iter
-        iter=iter,
-        # max counter
-        max_counter=max_counter,
+        row_counter=row_counter_list,
+        # optim_tensor
+        # optim_int
+        optim_int=optim_int,
+        # optim_float
+        optim_float=optim_float,
+        # optim_bool
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )
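
Both invokers build the aux_* and optim_* lists positionally and guard every append with an assert, so a mismatch between this frontend code and the backend op schema fails by name instead of silently shifting arguments. A hedged condensation of that guard follows; the key order is copied from the diff, but pack_in_order itself is a hypothetical helper, not an FBGEMM API.

    from typing import Dict, List, Optional, Sequence

    import torch

    # Ordering copied from the aux_tensor assembly in the hunks above.
    AUX_TENSOR_ORDER: Sequence[str] = (
        "B_offsets",
        "vbe_output_offsets_feature_rank",
        "vbe_B_offsets_rank_per_feature",
        "lxu_cache_locations",
        "uvm_cache_stats",
        "prev_iter_dev",
        "vbe_output_offsets",
    )

    def pack_in_order(
        values: Dict[str, Optional[torch.Tensor]],
        order: Sequence[str],
    ) -> List[Optional[torch.Tensor]]:
        out: List[Optional[torch.Tensor]] = []
        for key in order:
            # Assert before appending so a missing key (e.g. an older backend
            # schema) is reported by name rather than producing a shifted list.
            assert key in values, (
                f"{key} must be in dict_aux_tensor. "
                "Please check the frontend and backend version. "
            )
            out.append(values[key])
        return out

The optim_int/optim_float lists follow the same idea, but their order is optimizer-specific: the counter-based rowwise Adagrad above packs counter_halflife, adjustment_iter, learning_rate_mode, weight_decay_mode, grad_sum_decay, is_tail_id_thresh_ratio, and regularization_mode into optim_int, while the SSD rowwise Adagrad needs only weight_decay_mode.
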