fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py

@@ -4,6 +4,8 @@
  ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
  ################################################################################

+ __template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
  #!/usr/bin/env python3

  # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@

  import torch
  from .lookup_args import *
+
+
+
+
+
+
  def invoke(
  common_args: CommonArgs,
  optimizer_args: OptimizerArgs,
@@ -25,18 +33,208 @@ def invoke(
  # only pass prev_iter_dev since prev_iter is never created on UVM
  prev_iter_dev: Optional[torch.Tensor] = None,
  gwd_lower_bound: float = 0.0,
+ mixed_D: bool = True,
  ) -> torch.Tensor:
-
+ # host_weights is only used for CPU training
+ use_cpu = common_args.host_weights.numel() > 0
  vbe_metadata = common_args.vbe_metadata

- return torch.ops.fbgemm.split_embedding_codegen_lookup_partial_rowwise_lamb_function(
+ # pack weights
+ weights = [
+ common_args.dev_weights,
+ common_args.uvm_weights,
+ common_args.weights_placements,
+ common_args.weights_offsets,
+ common_args.lxu_cache_weights,
+ ] if not use_cpu else [
+ common_args.host_weights,
+ common_args.weights_placements,
+ common_args.weights_offsets,
+ ]
+ dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+ "B_offsets": vbe_metadata.B_offsets,
+ "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+ "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+ "lxu_cache_locations": common_args.lxu_cache_locations,
+ "uvm_cache_stats": common_args.uvm_cache_stats,
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+ }
+
+ dict_aux_int: Dict[str, int] = {
+ "iter": iter,
+ "info_B_num_bits": common_args.info_B_num_bits,
+ "info_B_mask": common_args.info_B_mask,
+ }
+
+ dict_aux_float: Dict[str, float] = {
+ "gwd_lower_bound": gwd_lower_bound,
+ }
+
+ dict_aux_bool: Dict[str, bool] = {
+ "is_experimental_tbe": common_args.is_experimental,
+ "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+ "use_homogeneous_placements": common_args.use_homogeneous_placements,
+ "apply_global_weight_decay": apply_global_weight_decay,
+ "mixed_D": mixed_D,
+ }
+ dict_optim_int: Dict[str, int] = {}
+ dict_optim_float: Dict[str, float] = {}
+ dict_optim_bool: Dict[str, bool] = {}
+
+ # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+ dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+ # optimizer_args # if optimizer == none
+ dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+ dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+ dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+ dict_optim_float["eps"] = optimizer_args.eps
+ dict_optim_float["beta1"] = optimizer_args.beta1
+ dict_optim_float["beta2"] = optimizer_args.beta2
+ dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+
+ momentum1_list = [
+ momentum1.dev,
+ momentum1.uvm,
+ momentum1.placements,
+ momentum1.offsets,
+ ] if not use_cpu else [
+ momentum1.host,
+ momentum1.placements,
+ momentum1.offsets,
+ ] if momentum1 is not None else None
+
+ momentum2_list = [
+ momentum2.dev,
+ momentum2.uvm,
+ momentum2.placements,
+ momentum2.offsets,
+ ] if not use_cpu else [
+ momentum2.host,
+ momentum2.placements,
+ momentum2.offsets,
+ ] if momentum2 is not None else None
+
+
+ aux_tensor: List[Optional[torch.Tensor]] = []
+ assert "B_offsets" in dict_aux_tensor, (
+ "B_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["B_offsets"])
+ assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+ "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+ assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+ "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+ assert "lxu_cache_locations" in dict_aux_tensor, (
+ "lxu_cache_locations must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+ assert "uvm_cache_stats" in dict_aux_tensor, (
+ "uvm_cache_stats must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+ assert "prev_iter_dev" in dict_aux_tensor, (
+ "prev_iter_dev must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+ assert "vbe_output_offsets" in dict_aux_tensor, (
+ "vbe_output_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+ aux_int: List[int] = []
+ assert "iter" in dict_aux_int, (
+ "iter must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["iter"])
+ assert "info_B_num_bits" in dict_aux_int, (
+ "info_B_num_bits must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["info_B_num_bits"])
+ assert "info_B_mask" in dict_aux_int, (
+ "info_B_mask must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["info_B_mask"])
+
+ aux_float: List[float] = []
+ assert "gwd_lower_bound" in dict_aux_float, (
+ "gwd_lower_bound must be in dict_aux_float. "
+ "Please check the frontend and backend version. "
+ )
+ aux_float.append(dict_aux_float["gwd_lower_bound"])
+ assert "max_gradient" in dict_aux_float, (
+ "max_gradient must be in dict_aux_float. "
+ "Please check the frontend and backend version. "
+ )
+ aux_float.append(dict_aux_float["max_gradient"])
+
+ aux_bool: List[bool] = []
+ assert "is_experimental_tbe" in dict_aux_bool, (
+ "is_experimental_tbe must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+ assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+ "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+ assert "use_homogeneous_placements" in dict_aux_bool, (
+ "use_homogeneous_placements must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+ assert "apply_global_weight_decay" in dict_aux_bool, (
+ "apply_global_weight_decay must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+ assert "gradient_clipping" in dict_aux_bool, (
+ "gradient_clipping must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["gradient_clipping"])
+ assert "stochastic_rounding" in dict_aux_bool, (
+ "stochastic_rounding must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["stochastic_rounding"])
+ assert "mixed_D" in dict_aux_bool, (
+ "mixed_D must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["mixed_D"])
+
+ # optim_int
+ # optim_float
+ # ['momentum1', 'momentum2', 'learning_rate_tensor', 'optim_float']
+ optim_float: List[float] = []
+ optim_float.append(dict_optim_float["eps"])
+ optim_float.append(dict_optim_float["beta1"])
+ optim_float.append(dict_optim_float["beta2"])
+ optim_float.append(dict_optim_float["weight_decay"])
+ # optim_bool
+
+ return torch.ops.fbgemm.split_embedding_codegen_lookup_partial_rowwise_lamb_function_pt2(
  # common_args
  placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
- dev_weights=common_args.dev_weights,
- uvm_weights=common_args.uvm_weights,
- lxu_cache_weights=common_args.lxu_cache_weights,
- weights_placements=common_args.weights_placements,
- weights_offsets=common_args.weights_offsets,
+ # weights
+ weights=weights,
  D_offsets=common_args.D_offsets,
  total_D=common_args.total_D,
  max_D=common_args.max_D,
@@ -47,47 +245,33 @@
  pooling_mode=common_args.pooling_mode,
  indice_weights=common_args.indice_weights,
  feature_requires_grad=common_args.feature_requires_grad,
- lxu_cache_locations=common_args.lxu_cache_locations,
- uvm_cache_stats=common_args.uvm_cache_stats,
+ output_dtype=common_args.output_dtype,
  # VBE metadata
- B_offsets=vbe_metadata.B_offsets,
- vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
- vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
  max_B=vbe_metadata.max_B,
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
  vbe_output_size=vbe_metadata.output_size,
- # optimizer_args
- gradient_clipping = optimizer_args.gradient_clipping,
- max_gradient=optimizer_args.max_gradient,
- stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
- # V1 interface still accepts learning_rate as float
- learning_rate=optimizer_args.learning_rate,
- eps=optimizer_args.eps,
- beta1=optimizer_args.beta1,
- beta2=optimizer_args.beta2,
- weight_decay=optimizer_args.weight_decay,
+ vbe_output=vbe_metadata.vbe_output,
+ # aux_tensor
+ aux_tensor=aux_tensor,
+ # aux_int
+ aux_int=aux_int,
+ # aux_float
+ aux_float=aux_float,
+ # aux_bool
+ aux_bool=aux_bool,
+ learning_rate_tensor=common_args.learning_rate_tensor,
+
  # momentum1
- momentum1_dev=momentum1.dev,
- momentum1_uvm=momentum1.uvm,
- momentum1_offsets=momentum1.offsets,
- momentum1_placements=momentum1.placements,
+ momentum1 = momentum1_list,
  # momentum2
- momentum2_dev=momentum2.dev,
- momentum2_uvm=momentum2.uvm,
- momentum2_offsets=momentum2.offsets,
- momentum2_placements=momentum2.placements,
+ momentum2=momentum2_list,
  # prev_iter
-
- prev_iter_dev=prev_iter_dev,
  # row_counter
- # iter
- iter=iter,
- # max counter
+ # optim_tensor
+ # optim_int
+ # optim_float
+ optim_float=optim_float,
+ # optim_bool
+ # optim symint args
  # total_unique_indices
- output_dtype=common_args.output_dtype,
- is_experimental=common_args.is_experimental,
- use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
- use_homogeneous_placements=common_args.use_homogeneous_placements,
- apply_global_weight_decay=apply_global_weight_decay,
- gwd_lower_bound=gwd_lower_bound,
  )
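
Note: both invoker files in this diff migrate from the V1 operators, which took every tensor as its own keyword argument, to the *_pt2 operators, which take packed positional lists (weights, aux_tensor, aux_int, aux_float, aux_bool, optim_*). The sketch below is illustrative only and is not part of the wheel; it shows the CPU-vs-GPU weights packing rule from the hunk above using hypothetical standalone tensors in place of the real CommonArgs fields.

from typing import List

import torch


def pack_weights(
    host_weights: torch.Tensor,
    dev_weights: torch.Tensor,
    uvm_weights: torch.Tensor,
    weights_placements: torch.Tensor,
    weights_offsets: torch.Tensor,
    lxu_cache_weights: torch.Tensor,
) -> List[torch.Tensor]:
    # CPU training is detected by a non-empty host_weights tensor, mirroring
    # `use_cpu = common_args.host_weights.numel() > 0` in the generated invokers.
    use_cpu = host_weights.numel() > 0
    if use_cpu:
        return [host_weights, weights_placements, weights_offsets]
    return [
        dev_weights,
        uvm_weights,
        weights_placements,
        weights_offsets,
        lxu_cache_weights,
    ]


# An empty host_weights tensor selects the five-tensor GPU packing.
packed = pack_weights(
    host_weights=torch.empty(0),
    dev_weights=torch.zeros(8),
    uvm_weights=torch.empty(0),
    weights_placements=torch.zeros(1, dtype=torch.int32),
    weights_offsets=torch.zeros(1, dtype=torch.int64),
    lxu_cache_weights=torch.empty(0),
)
assert len(packed) == 5
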
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py

@@ -4,6 +4,8 @@
  ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
  ################################################################################

+ __template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
  #!/usr/bin/env python3

  # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -15,6 +17,12 @@

  import torch
  from .lookup_args import *
+
+
+
+
+
+
  def invoke(
  common_args: CommonArgs,
  optimizer_args: OptimizerArgs,
@@ -24,91 +32,198 @@ def invoke(
  # only pass prev_iter_dev since prev_iter is never created on UVM
  prev_iter_dev: Optional[torch.Tensor] = None,
  gwd_lower_bound: float = 0.0,
+ mixed_D: bool = True,
  ) -> torch.Tensor:
-
+ # host_weights is only used for CPU training
+ use_cpu = common_args.host_weights.numel() > 0
  vbe_metadata = common_args.vbe_metadata
- if (common_args.host_weights.numel() > 0):
- T = common_args.D_offsets.numel() - 1
- vbe: bool = vbe_metadata.B_offsets is not None
- if vbe:
- # create offsets with fixed batch size max_B
- # not efficient but for now we just need a functional implementation for CPU
- max_B = vbe_metadata.max_B
- offsets = torch.empty([T * max_B + 1], dtype=common_args.offsets.dtype, device=common_args.offsets.device)
- for t in range(T):
- B_offsets = vbe_metadata.B_offsets
- assert isinstance(B_offsets, torch.Tensor)
- begin = B_offsets[t]
- end = B_offsets[t + 1]
- offsets[t * max_B : t * max_B + end - begin] = common_args.offsets[begin : end]
- offsets[t * max_B + end - begin : (t + 1) * max_B] = common_args.offsets[end]
- offsets[-1] = common_args.offsets[-1]
- else:
- offsets = common_args.offsets
- output = torch.ops.fbgemm.split_embedding_codegen_lookup_rowwise_adagrad_function_cpu(
- # common_args
- host_weights=common_args.host_weights,
- weights_placements=common_args.weights_placements,
- weights_offsets=common_args.weights_offsets,
- D_offsets=common_args.D_offsets,
- total_D=common_args.total_D,
- max_D=common_args.max_D,
- hash_size_cumsum=common_args.hash_size_cumsum,
- total_hash_size_bits=common_args.total_hash_size_bits,
- indices=common_args.indices,
- offsets=offsets,
- pooling_mode=common_args.pooling_mode,
- indice_weights=common_args.indice_weights,
- feature_requires_grad=common_args.feature_requires_grad,
- # optimizer_args
- gradient_clipping = optimizer_args.gradient_clipping,
- max_gradient=optimizer_args.max_gradient,
- stochastic_rounding=optimizer_args.stochastic_rounding,
- learning_rate=optimizer_args.learning_rate,
- eps=optimizer_args.eps,
- weight_decay=optimizer_args.weight_decay,
- weight_decay_mode=optimizer_args.weight_decay_mode,
- max_norm=optimizer_args.max_norm,
- # momentum1
- momentum1_host=momentum1.host,
- momentum1_offsets=momentum1.offsets,
- momentum1_placements=momentum1.placements,
- # momentum2
- # prev_iter
- # row_counter
- # iter
- # max counter
- )
- if vbe:
- output_new = torch.empty([vbe_metadata.output_size], dtype=output.dtype, device=output.device)
- B_offsets_rank_per_feature = vbe_metadata.B_offsets_rank_per_feature
- assert isinstance(B_offsets_rank_per_feature, torch.Tensor)
- output_offsets_feature_rank = vbe_metadata.output_offsets_feature_rank
- assert isinstance(output_offsets_feature_rank, torch.Tensor)
- R = B_offsets_rank_per_feature.size(1) - 1
- for r in range(R):
- D_offset = 0
- for t in range(T):
- o_begin = output_offsets_feature_rank[r * T + t].item()
- o_end = output_offsets_feature_rank[r * T + t + 1].item()
- D = common_args.D_offsets[t + 1].item() - common_args.D_offsets[t].item()
- b_begin = B_offsets_rank_per_feature[t][r].item()
- b_end = B_offsets_rank_per_feature[t][r + 1].item()
- assert o_end - o_begin == (b_end - b_begin) * D
- output_new[o_begin : o_end] = output[b_begin : b_end, D_offset : D_offset + D].flatten()
- D_offset += D
- return output_new
- else:
- return output
-
- return torch.ops.fbgemm.split_embedding_codegen_lookup_rowwise_adagrad_function(
+
+ # pack weights
+ weights = [
+ common_args.dev_weights,
+ common_args.uvm_weights,
+ common_args.weights_placements,
+ common_args.weights_offsets,
+ common_args.lxu_cache_weights,
+ ] if not use_cpu else [
+ common_args.host_weights,
+ common_args.weights_placements,
+ common_args.weights_offsets,
+ ]
+ dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+ "B_offsets": vbe_metadata.B_offsets,
+ "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+ "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+ "lxu_cache_locations": common_args.lxu_cache_locations,
+ "uvm_cache_stats": common_args.uvm_cache_stats,
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+ }
+
+ dict_aux_int: Dict[str, int] = {
+ "iter": iter,
+ "info_B_num_bits": common_args.info_B_num_bits,
+ "info_B_mask": common_args.info_B_mask,
+ }
+
+ dict_aux_float: Dict[str, float] = {
+ "gwd_lower_bound": gwd_lower_bound,
+ }
+
+ dict_aux_bool: Dict[str, bool] = {
+ "is_experimental_tbe": common_args.is_experimental,
+ "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+ "use_homogeneous_placements": common_args.use_homogeneous_placements,
+ "apply_global_weight_decay": apply_global_weight_decay,
+ "mixed_D": mixed_D,
+ }
+ dict_optim_int: Dict[str, int] = {}
+ dict_optim_float: Dict[str, float] = {}
+ dict_optim_bool: Dict[str, bool] = {}
+
+ # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+ dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+ # optimizer_args # if optimizer == none
+ dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+ dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+ dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+ dict_optim_float["eps"] = optimizer_args.eps
+ dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+ dict_optim_int["weight_decay_mode"] = optimizer_args.weight_decay_mode
+ dict_optim_float["max_norm"] = optimizer_args.max_norm
+
+ momentum1_list = [
+ momentum1.dev,
+ momentum1.uvm,
+ momentum1.placements,
+ momentum1.offsets,
+ ] if not use_cpu else [
+ momentum1.host,
+ momentum1.placements,
+ momentum1.offsets,
+ ] if momentum1 is not None else None
+
+
+ aux_tensor: List[Optional[torch.Tensor]] = []
+ assert "B_offsets" in dict_aux_tensor, (
+ "B_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["B_offsets"])
+ assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+ "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+ assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+ "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+ assert "lxu_cache_locations" in dict_aux_tensor, (
+ "lxu_cache_locations must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+ assert "uvm_cache_stats" in dict_aux_tensor, (
+ "uvm_cache_stats must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+ assert "prev_iter_dev" in dict_aux_tensor, (
+ "prev_iter_dev must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+ assert "vbe_output_offsets" in dict_aux_tensor, (
+ "vbe_output_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+ aux_int: List[int] = []
+ assert "iter" in dict_aux_int, (
+ "iter must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["iter"])
+ assert "info_B_num_bits" in dict_aux_int, (
+ "info_B_num_bits must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["info_B_num_bits"])
+ assert "info_B_mask" in dict_aux_int, (
+ "info_B_mask must be in dict_aux_int. "
+ "Please check the frontend and backend version. "
+ )
+ aux_int.append(dict_aux_int["info_B_mask"])
+
+ aux_float: List[float] = []
+ assert "gwd_lower_bound" in dict_aux_float, (
+ "gwd_lower_bound must be in dict_aux_float. "
+ "Please check the frontend and backend version. "
+ )
+ aux_float.append(dict_aux_float["gwd_lower_bound"])
+ assert "max_gradient" in dict_aux_float, (
+ "max_gradient must be in dict_aux_float. "
+ "Please check the frontend and backend version. "
+ )
+ aux_float.append(dict_aux_float["max_gradient"])
+
+ aux_bool: List[bool] = []
+ assert "is_experimental_tbe" in dict_aux_bool, (
+ "is_experimental_tbe must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+ assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+ "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+ assert "use_homogeneous_placements" in dict_aux_bool, (
+ "use_homogeneous_placements must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+ assert "apply_global_weight_decay" in dict_aux_bool, (
+ "apply_global_weight_decay must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+ assert "gradient_clipping" in dict_aux_bool, (
+ "gradient_clipping must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["gradient_clipping"])
+ assert "stochastic_rounding" in dict_aux_bool, (
+ "stochastic_rounding must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["stochastic_rounding"])
+ assert "mixed_D" in dict_aux_bool, (
+ "mixed_D must be in dict_aux_bool. "
+ "Please check the frontend and backend version. "
+ )
+ aux_bool.append(dict_aux_bool["mixed_D"])
+
+ # optim_int
+ optim_int: List[int] = []
+ optim_int.append(dict_optim_int["weight_decay_mode"])
+ # optim_float
+ # ['momentum1', 'learning_rate_tensor', 'optim_int', 'optim_float']
+ optim_float: List[float] = []
+ optim_float.append(dict_optim_float["eps"])
+ optim_float.append(dict_optim_float["weight_decay"])
+ optim_float.append(dict_optim_float["max_norm"])
+ # optim_bool
+
+ return torch.ops.fbgemm.split_embedding_codegen_lookup_rowwise_adagrad_function_pt2(
  # common_args
  placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
- dev_weights=common_args.dev_weights,
- uvm_weights=common_args.uvm_weights,
- lxu_cache_weights=common_args.lxu_cache_weights,
- weights_placements=common_args.weights_placements,
- weights_offsets=common_args.weights_offsets,
+ # weights
+ weights=weights,
  D_offsets=common_args.D_offsets,
  total_D=common_args.total_D,
  max_D=common_args.max_D,
@@ -119,43 +234,33 @@ def invoke(
  pooling_mode=common_args.pooling_mode,
  indice_weights=common_args.indice_weights,
  feature_requires_grad=common_args.feature_requires_grad,
- lxu_cache_locations=common_args.lxu_cache_locations,
- uvm_cache_stats=common_args.uvm_cache_stats,
+ output_dtype=common_args.output_dtype,
  # VBE metadata
- B_offsets=vbe_metadata.B_offsets,
- vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
- vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
  max_B=vbe_metadata.max_B,
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
  vbe_output_size=vbe_metadata.output_size,
- # optimizer_args
- gradient_clipping = optimizer_args.gradient_clipping,
- max_gradient=optimizer_args.max_gradient,
- stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
- # V1 interface still accepts learning_rate as float
- learning_rate=optimizer_args.learning_rate,
- eps=optimizer_args.eps,
- weight_decay=optimizer_args.weight_decay,
- weight_decay_mode=optimizer_args.weight_decay_mode,
- max_norm=optimizer_args.max_norm,
+ vbe_output=vbe_metadata.vbe_output,
+ # aux_tensor
+ aux_tensor=aux_tensor,
+ # aux_int
+ aux_int=aux_int,
+ # aux_float
+ aux_float=aux_float,
+ # aux_bool
+ aux_bool=aux_bool,
+ learning_rate_tensor=common_args.learning_rate_tensor,
+
  # momentum1
- momentum1_dev=momentum1.dev,
- momentum1_uvm=momentum1.uvm,
- momentum1_offsets=momentum1.offsets,
- momentum1_placements=momentum1.placements,
+ momentum1 = momentum1_list,
  # momentum2
  # prev_iter
-
- prev_iter_dev=prev_iter_dev,
  # row_counter
- # iter
- iter=iter,
- # max counter
+ # optim_tensor
+ # optim_int
+ optim_int=optim_int,
+ # optim_float
+ optim_float=optim_float,
+ # optim_bool
+ # optim symint args
  # total_unique_indices
- output_dtype=common_args.output_dtype,
- is_experimental=common_args.is_experimental,
- use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
- use_homogeneous_placements=common_args.use_homogeneous_placements,
- apply_global_weight_decay=apply_global_weight_decay,
- gwd_lower_bound=gwd_lower_bound,
  )
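
Note: the aux_* lists passed to the *_pt2 operators are order-sensitive. Each generated invoker first collects values into a dict, then asserts that every expected key is present and appends the values in a fixed, codegen-determined order. The sketch below is illustrative only (the AUX_BOOL_ORDER constant is hypothetical; the generated code inlines one assert/append pair per key), but the key order matches the aux_bool assembly shown in the diffs above.

from typing import Dict, List

# Values collected from CommonArgs/OptimizerArgs in the real invokers; dummy values here.
dict_aux_bool: Dict[str, bool] = {
    "is_experimental_tbe": False,
    "use_uniq_cache_locations_bwd": False,
    "use_homogeneous_placements": True,
    "apply_global_weight_decay": False,
    "gradient_clipping": False,
    "stochastic_rounding": True,
    "mixed_D": True,
}

# Hypothetical ordering constant; a missing key indicates a frontend/backend version mismatch.
AUX_BOOL_ORDER: List[str] = [
    "is_experimental_tbe",
    "use_uniq_cache_locations_bwd",
    "use_homogeneous_placements",
    "apply_global_weight_decay",
    "gradient_clipping",
    "stochastic_rounding",
    "mixed_D",
]

aux_bool: List[bool] = []
for key in AUX_BOOL_ORDER:
    assert key in dict_aux_bool, (
        f"{key} must be in dict_aux_bool. "
        "Please check the frontend and backend version. "
    )
    aux_bool.append(dict_aux_bool[key])
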