fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################
 
+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -16,6 +18,12 @@
 import torch
 import warnings
 from .lookup_args import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -26,6 +34,7 @@ def invoke(
     # only pass prev_iter_dev since prev_iter is never created on UVM
     prev_iter_dev: Optional[torch.Tensor] = None,
     gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
 ) -> torch.Tensor:
     # By design, the warning only shows up once
     warnings.warn(
@@ -35,17 +44,206 @@ def invoke(
         times and build sizes!
         \033[0m"""
     )
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
 
-    return torch.ops.fbgemm.split_embedding_codegen_lookup_lamb_function(
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+    dict_optim_float["eps"] = optimizer_args.eps
+    dict_optim_float["beta1"] = optimizer_args.beta1
+    dict_optim_float["beta2"] = optimizer_args.beta2
+    dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+
+    momentum1_list = [
+        momentum1.dev,
+        momentum1.uvm,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if not use_cpu else [
+        momentum1.host,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if momentum1 is not None else None
+
+    momentum2_list = [
+        momentum2.dev,
+        momentum2.uvm,
+        momentum2.placements,
+        momentum2.offsets,
+    ] if not use_cpu else [
+        momentum2.host,
+        momentum2.placements,
+        momentum2.offsets,
+    ] if momentum2 is not None else None
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+
+    # optim_int
+    # optim_float
+    # ['momentum1', 'momentum2', 'learning_rate_tensor', 'optim_float']
+    optim_float: List[float] = []
+    optim_float.append(dict_optim_float["eps"])
+    optim_float.append(dict_optim_float["beta1"])
+    optim_float.append(dict_optim_float["beta2"])
+    optim_float.append(dict_optim_float["weight_decay"])
+    # optim_bool
+
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_lamb_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-        dev_weights=common_args.dev_weights,
-        uvm_weights=common_args.uvm_weights,
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -56,47 +254,33 @@ def invoke(
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-        lxu_cache_locations=common_args.lxu_cache_locations,
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-        # optimizer_args
-        gradient_clipping = optimizer_args.gradient_clipping,
-        max_gradient=optimizer_args.max_gradient,
-        stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
-        # V1 interface still accepts learning_rate as float
-        learning_rate=optimizer_args.learning_rate,
-        eps=optimizer_args.eps,
-        beta1=optimizer_args.beta1,
-        beta2=optimizer_args.beta2,
-        weight_decay=optimizer_args.weight_decay,
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
-        momentum1_dev=momentum1.dev,
-        momentum1_uvm=momentum1.uvm,
-        momentum1_offsets=momentum1.offsets,
-        momentum1_placements=momentum1.placements,
+        momentum1 = momentum1_list,
         # momentum2
-        momentum2_dev=momentum2.dev,
-        momentum2_uvm=momentum2.uvm,
-        momentum2_offsets=momentum2.offsets,
-        momentum2_placements=momentum2.placements,
+        momentum2=momentum2_list,
         # prev_iter
-
-        prev_iter_dev=prev_iter_dev,
         # row_counter
-        # iter
-        iter=iter,
-        # max counter
+        # optim_tensor
+        # optim_int
+        # optim_float
+        optim_float=optim_float,
+        # optim_bool
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )
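The hunks above replace the flat keyword-argument call to split_embedding_codegen_lookup_lamb_function with the packed-list call to split_embedding_codegen_lookup_lamb_function_pt2: weights and optimizer state are first gathered into Python lists (whose layout depends on whether host_weights is populated), and auxiliary values are grouped into aux_tensor/aux_int/aux_float/aux_bool. A minimal standalone sketch of the weights-packing convention is shown below; the pack_weights helper and the placeholder tensors are illustrative only and are not part of the package.

# Illustrative sketch (assumed helper, not shipped in the wheel) of the
# weights-packing convention used by the *_pt2 invokers above. In the real
# invoker the tensors come from the CommonArgs dataclass.
from typing import List
import torch

def pack_weights(
    dev_weights: torch.Tensor,
    uvm_weights: torch.Tensor,
    host_weights: torch.Tensor,
    weights_placements: torch.Tensor,
    weights_offsets: torch.Tensor,
    lxu_cache_weights: torch.Tensor,
) -> List[torch.Tensor]:
    # host_weights is only populated for CPU training, so its size picks the layout
    use_cpu = host_weights.numel() > 0
    if use_cpu:
        return [host_weights, weights_placements, weights_offsets]
    return [
        dev_weights,
        uvm_weights,
        weights_placements,
        weights_offsets,
        lxu_cache_weights,
    ]

# Example: GPU-style packing (empty host_weights) yields the 5-element list
weights = pack_weights(
    dev_weights=torch.zeros(10),
    uvm_weights=torch.zeros(0),
    host_weights=torch.zeros(0),
    weights_placements=torch.zeros(1, dtype=torch.int32),
    weights_offsets=torch.zeros(1, dtype=torch.int64),
    lxu_cache_weights=torch.zeros(0),
)
assert len(weights) == 5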
@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################
 
+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -16,6 +18,12 @@
 import torch
 import warnings
 from .lookup_args import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -25,6 +33,7 @@ def invoke(
     # only pass prev_iter_dev since prev_iter is never created on UVM
     prev_iter_dev: Optional[torch.Tensor] = None,
    gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
 ) -> torch.Tensor:
     # By design, the warning only shows up once
     warnings.warn(
@@ -34,17 +43,193 @@ def invoke(
         times and build sizes!
         \033[0m"""
     )
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
 
-    return torch.ops.fbgemm.split_embedding_codegen_lookup_lars_sgd_function(
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+    dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+    dict_optim_float["eta"] = optimizer_args.eta
+    dict_optim_float["momentum"] = optimizer_args.momentum
+
+    momentum1_list = [
+        momentum1.dev,
+        momentum1.uvm,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if not use_cpu else [
+        momentum1.host,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if momentum1 is not None else None
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+
+    # optim_int
+    # optim_float
+    # ['momentum1', 'learning_rate_tensor', 'optim_float']
+    optim_float: List[float] = []
+    optim_float.append(dict_optim_float["eta"])
+    optim_float.append(dict_optim_float["momentum"])
+    optim_float.append(dict_optim_float["weight_decay"])
+    # optim_bool
+
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_lars_sgd_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-        dev_weights=common_args.dev_weights,
-        uvm_weights=common_args.uvm_weights,
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -55,42 +240,32 @@ def invoke(
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-        lxu_cache_locations=common_args.lxu_cache_locations,
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-        # optimizer_args
-        gradient_clipping = optimizer_args.gradient_clipping,
-        max_gradient=optimizer_args.max_gradient,
-        stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
-        # V1 interface still accepts learning_rate as float
-        learning_rate=optimizer_args.learning_rate,
-        weight_decay=optimizer_args.weight_decay,
-        eta=optimizer_args.eta,
-        momentum=optimizer_args.momentum,
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
-        momentum1_dev=momentum1.dev,
-        momentum1_uvm=momentum1.uvm,
-        momentum1_offsets=momentum1.offsets,
-        momentum1_placements=momentum1.placements,
+        momentum1 = momentum1_list,
         # momentum2
         # prev_iter
-
-        prev_iter_dev=prev_iter_dev,
         # row_counter
-        # iter
-        iter=iter,
-        # max counter
+        # optim_tensor
+        # optim_int
+        # optim_float
+        optim_float=optim_float,
+        # optim_bool
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )