fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py
@@ -4,6 +4,8 @@
 ## Template Source: training/python/split_embedding_codegen_lookup_invoker.template
 ################################################################################

+__template_source_file__ = "training/python/split_embedding_codegen_lookup_invoker.template"
+
 #!/usr/bin/env python3

 # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -16,6 +18,12 @@
 import torch
 import warnings
 from .lookup_args import *
+
+
+
+
+
+
 def invoke(
     common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
@@ -26,6 +34,7 @@ def invoke(
     # only pass prev_iter_dev since prev_iter is never created on UVM
     prev_iter_dev: Optional[torch.Tensor] = None,
     gwd_lower_bound: float = 0.0,
+    mixed_D: bool = True,
     row_counter: Optional[Momentum] = None,
 ) -> torch.Tensor:
     # By design, the warning only shows up once
@@ -36,29 +45,245 @@ def invoke(
         times and build sizes!
         \033[0m"""
     )
-
+    # host_weights is only used for CPU training
+    use_cpu = common_args.host_weights.numel() > 0
     vbe_metadata = common_args.vbe_metadata
-    if not optimizer_args.use_rowwise_bias_correction or row_counter is None:
+
+    # pack weights
+    weights = [
+        common_args.dev_weights,
+        common_args.uvm_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+        common_args.lxu_cache_weights,
+    ] if not use_cpu else [
+        common_args.host_weights,
+        common_args.weights_placements,
+        common_args.weights_offsets,
+    ]
+    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
+        "B_offsets": vbe_metadata.B_offsets,
+        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
+        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
+        "lxu_cache_locations": common_args.lxu_cache_locations,
+        "uvm_cache_stats": common_args.uvm_cache_stats,
+        "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
+    }
+
+    dict_aux_int: Dict[str, int] = {
+        "iter": iter,
+        "info_B_num_bits": common_args.info_B_num_bits,
+        "info_B_mask": common_args.info_B_mask,
+    }
+
+    dict_aux_float: Dict[str, float] = {
+        "gwd_lower_bound": gwd_lower_bound,
+    }
+
+    dict_aux_bool: Dict[str, bool] = {
+        "is_experimental_tbe": common_args.is_experimental,
+        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
+        "use_homogeneous_placements": common_args.use_homogeneous_placements,
+        "apply_global_weight_decay": apply_global_weight_decay,
+        "mixed_D": mixed_D,
+    }
+    dict_optim_int: Dict[str, int] = {}
+    dict_optim_float: Dict[str, float] = {}
+    dict_optim_bool: Dict[str, bool] = {}
+
+    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
+    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
+
+
+    # optimizer_args # if optimizer == none
+    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
+    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
+    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
+    dict_optim_float["eps"] = optimizer_args.eps
+    dict_optim_float["beta1"] = optimizer_args.beta1
+    dict_optim_float["beta2"] = optimizer_args.beta2
+    dict_optim_float["weight_decay"] = optimizer_args.weight_decay
+    dict_optim_bool["use_rowwise_bias_correction"] = optimizer_args.use_rowwise_bias_correction
+
+    momentum1_list = [
+        momentum1.dev,
+        momentum1.uvm,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if not use_cpu else [
+        momentum1.host,
+        momentum1.placements,
+        momentum1.offsets,
+    ] if momentum1 is not None else None
+
+    momentum2_list = [
+        momentum2.dev,
+        momentum2.uvm,
+        momentum2.placements,
+        momentum2.offsets,
+    ] if not use_cpu else [
+        momentum2.host,
+        momentum2.placements,
+        momentum2.offsets,
+    ] if momentum2 is not None else None
+
+    if optimizer_args.use_rowwise_bias_correction and row_counter is not None:
+        row_counter_host = None # not supported on CPU
+        row_counter_dev = row_counter.dev
+        row_counter_uvm = row_counter.uvm
+        row_counter_offsets = row_counter.offsets
+        row_counter_placements = row_counter.placements
+    elif optimizer_args.use_rowwise_bias_correction:
+        assert False, "`use_rowwise_bias_correction` is set, `row_counter` cannot be None"
+    else:
+        row_counter_host = None
         row_counter_dev = None
         row_counter_uvm = None
         row_counter_offsets = None
         row_counter_placements = None
-    elif optimizer_args.use_rowwise_bias_correction and row_counter is None:
-        assert False, "use_rowwise_bias_correction is set but row_counter cannot be None"
+
+
+    aux_tensor: List[Optional[torch.Tensor]] = []
+    assert "B_offsets" in dict_aux_tensor, (
+        "B_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["B_offsets"])
+    assert "vbe_output_offsets_feature_rank" in dict_aux_tensor, (
+        "vbe_output_offsets_feature_rank must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets_feature_rank"])
+    assert "vbe_B_offsets_rank_per_feature" in dict_aux_tensor, (
+        "vbe_B_offsets_rank_per_feature must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_B_offsets_rank_per_feature"])
+    assert "lxu_cache_locations" in dict_aux_tensor, (
+        "lxu_cache_locations must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["lxu_cache_locations"])
+    assert "uvm_cache_stats" in dict_aux_tensor, (
+        "uvm_cache_stats must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["uvm_cache_stats"])
+    assert "prev_iter_dev" in dict_aux_tensor, (
+        "prev_iter_dev must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+    assert "vbe_output_offsets" in dict_aux_tensor, (
+        "vbe_output_offsets must be in dict_aux_tensor. "
+        "Please check the frontend and backend version. "
+    )
+    aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
+
+    aux_int: List[int] = []
+    assert "iter" in dict_aux_int, (
+        "iter must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["iter"])
+    assert "info_B_num_bits" in dict_aux_int, (
+        "info_B_num_bits must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_num_bits"])
+    assert "info_B_mask" in dict_aux_int, (
+        "info_B_mask must be in dict_aux_int. "
+        "Please check the frontend and backend version. "
+    )
+    aux_int.append(dict_aux_int["info_B_mask"])
+
+    aux_float: List[float] = []
+    assert "gwd_lower_bound" in dict_aux_float, (
+        "gwd_lower_bound must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["gwd_lower_bound"])
+    assert "max_gradient" in dict_aux_float, (
+        "max_gradient must be in dict_aux_float. "
+        "Please check the frontend and backend version. "
+    )
+    aux_float.append(dict_aux_float["max_gradient"])
+
+    aux_bool: List[bool] = []
+    assert "is_experimental_tbe" in dict_aux_bool, (
+        "is_experimental_tbe must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["is_experimental_tbe"])
+    assert "use_uniq_cache_locations_bwd" in dict_aux_bool, (
+        "use_uniq_cache_locations_bwd must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_uniq_cache_locations_bwd"])
+    assert "use_homogeneous_placements" in dict_aux_bool, (
+        "use_homogeneous_placements must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["use_homogeneous_placements"])
+    assert "apply_global_weight_decay" in dict_aux_bool, (
+        "apply_global_weight_decay must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["apply_global_weight_decay"])
+    assert "gradient_clipping" in dict_aux_bool, (
+        "gradient_clipping must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["gradient_clipping"])
+    assert "stochastic_rounding" in dict_aux_bool, (
+        "stochastic_rounding must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["stochastic_rounding"])
+    assert "mixed_D" in dict_aux_bool, (
+        "mixed_D must be in dict_aux_bool. "
+        "Please check the frontend and backend version. "
+    )
+    aux_bool.append(dict_aux_bool["mixed_D"])
+    optim_tensor: List[Optional[torch.Tensor]] = []
+    # We cannot do list of optional tensorlist (optional tensorlist is Tensor?[]).
+    # we need to pack optimizer optional tensors in a flatten manner.
+    # We pack unified args (i.e., 5 items) since it's very confusing to pack/unpack per device (i.e, 3 for cpu and 4 for cuda)
+    # e.g., if we have optim optional tensors x and y, the optim_tensor will look like
+    # [x_host, x_dev, x_uvm, x_placements, x_offsets, y_host, y_dev, y_uvm, y_placements, y_offsets]
+    # ['row_counter']
+
+    # using .extend fails torch script
+    if row_counter is None:
+        optim_tensor.append(None)
+        optim_tensor.append(None)
+        optim_tensor.append(None)
+        optim_tensor.append(None)
+        optim_tensor.append(None)
     else:
-        row_counter_dev = row_counter.dev
-        row_counter_uvm = row_counter.uvm
-        row_counter_offsets = row_counter.offsets
-        row_counter_placements = row_counter.placements
+        optim_tensor.append(row_counter.host)
+        optim_tensor.append(row_counter.dev)
+        optim_tensor.append(row_counter.uvm)
+        optim_tensor.append(row_counter.placements)
+        optim_tensor.append(row_counter.offsets)
+
+    # optim_int
+    # optim_float
+    # ['momentum1', 'momentum2', 'learning_rate_tensor', 'optim_tensor', 'optim_float', 'optim_bool']
+    optim_float: List[float] = []
+    optim_float.append(dict_optim_float["eps"])
+    optim_float.append(dict_optim_float["beta1"])
+    optim_float.append(dict_optim_float["beta2"])
+    optim_float.append(dict_optim_float["weight_decay"])
+    # optim_bool
+    optim_bool: List[bool] = []
+    optim_bool.append(dict_optim_bool["use_rowwise_bias_correction"])

-    return torch.ops.fbgemm.split_embedding_codegen_lookup_adam_function(
+    return torch.ops.fbgemm.split_embedding_codegen_lookup_adam_function_pt2(
         # common_args
         placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
-        dev_weights=common_args.dev_weights,
-        uvm_weights=common_args.uvm_weights,
-        lxu_cache_weights=common_args.lxu_cache_weights,
-        weights_placements=common_args.weights_placements,
-        weights_offsets=common_args.weights_offsets,
+        # weights
+        weights=weights,
         D_offsets=common_args.D_offsets,
         total_D=common_args.total_D,
         max_D=common_args.max_D,
@@ -69,52 +294,35 @@ def invoke(
         pooling_mode=common_args.pooling_mode,
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
-        lxu_cache_locations=common_args.lxu_cache_locations,
-        uvm_cache_stats=common_args.uvm_cache_stats,
+        output_dtype=common_args.output_dtype,
         # VBE metadata
-        B_offsets=vbe_metadata.B_offsets,
-        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
-        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
         max_B=vbe_metadata.max_B,
         max_B_feature_rank=vbe_metadata.max_B_feature_rank,
         vbe_output_size=vbe_metadata.output_size,
-        # optimizer_args
-        gradient_clipping = optimizer_args.gradient_clipping,
-        max_gradient=optimizer_args.max_gradient,
-        stochastic_rounding=optimizer_args.stochastic_rounding, # if optimizer == none
-        # V1 interface still accepts learning_rate as float
-        learning_rate=optimizer_args.learning_rate,
-        eps=optimizer_args.eps,
-        beta1=optimizer_args.beta1,
-        beta2=optimizer_args.beta2,
-        weight_decay=optimizer_args.weight_decay,
+        vbe_output=vbe_metadata.vbe_output,
+        # aux_tensor
+        aux_tensor=aux_tensor,
+        # aux_int
+        aux_int=aux_int,
+        # aux_float
+        aux_float=aux_float,
+        # aux_bool
+        aux_bool=aux_bool,
+        learning_rate_tensor=common_args.learning_rate_tensor,
+
         # momentum1
-        momentum1_dev=momentum1.dev,
-        momentum1_uvm=momentum1.uvm,
-        momentum1_offsets=momentum1.offsets,
-        momentum1_placements=momentum1.placements,
+        momentum1 = momentum1_list,
         # momentum2
-        momentum2_dev=momentum2.dev,
-        momentum2_uvm=momentum2.uvm,
-        momentum2_offsets=momentum2.offsets,
-        momentum2_placements=momentum2.placements,
+        momentum2=momentum2_list,
         # prev_iter
-
-        prev_iter_dev=prev_iter_dev,
         # row_counter
-        row_counter_dev=row_counter_dev,
-        row_counter_uvm=row_counter_uvm,
-        row_counter_offsets=row_counter_offsets,
-        row_counter_placements=row_counter_placements,
-        use_rowwise_bias_correction=optimizer_args.use_rowwise_bias_correction,
-        # iter
-        iter=iter,
-        # max counter
+        # optim_tensor
+        optim_tensor=optim_tensor,
+        # optim_int
+        # optim_float
+        optim_float=optim_float,
+        # optim_bool
+        optim_bool=optim_bool,
+        # optim symint args
         # total_unique_indices
-        output_dtype=common_args.output_dtype,
-        is_experimental=common_args.is_experimental,
-        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
-        use_homogeneous_placements=common_args.use_homogeneous_placements,
-        apply_global_weight_decay=apply_global_weight_decay,
-        gwd_lower_bound=gwd_lower_bound,
     )
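
The comments added to the new invoker above describe a flattened packing convention for optional optimizer tensors: each optional state contributes exactly five slots (host, dev, uvm, placements, offsets), so the backend can slice the flat list by index. Below is a minimal sketch of that layout, assuming a hypothetical unpack_optim_tensor helper that is not part of the FBGEMM API:

    # Illustrative only: the 5-slot layout comes from the comment in the diff above;
    # the helper name and variable names are hypothetical.
    from typing import List, Optional

    import torch


    def unpack_optim_tensor(
        optim_tensor: List[Optional[torch.Tensor]], num_states: int
    ) -> List[List[Optional[torch.Tensor]]]:
        # Each optional optimizer state occupies 5 consecutive slots:
        # [host, dev, uvm, placements, offsets]
        assert len(optim_tensor) == 5 * num_states
        return [optim_tensor[i * 5 : (i + 1) * 5] for i in range(num_states)]


    # The Adam invoker packs a single optional state, row_counter; when it is
    # absent, five Nones are appended (as in the generated code above).
    packed: List[Optional[torch.Tensor]] = [None] * 5
    (row_counter_group,) = unpack_optim_tensor(packed, num_states=1)
    rc_host, rc_dev, rc_uvm, rc_placements, rc_offsets = row_counter_group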
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py
@@ -4,6 +4,8 @@
 ## Template Source: training/python/lookup_args.template
 ################################################################################

+__template_source_file__ = "training/python/lookup_args.template"
+
 #!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
@@ -25,6 +27,8 @@ class VBEMetadata(NamedTuple):
     max_B_feature_rank: int = -1
     max_B: int = -1
     output_size: int = -1
+    vbe_output: Optional[torch.Tensor] = None
+    vbe_output_offsets: Optional[torch.Tensor] = None


 class CommonArgs(NamedTuple):
@@ -52,71 +56,18 @@ class CommonArgs(NamedTuple):
     is_experimental: bool
     use_uniq_cache_locations_bwd: bool
     use_homogeneous_placements: bool
-
-
-class OptimizerArgs(NamedTuple):
-    stochastic_rounding: bool
-    gradient_clipping: bool
-    max_gradient: float
-    max_norm: float
-    learning_rate: float
-    eps: float
-    beta1: float
-    beta2: float
-    weight_decay: float
-    weight_decay_mode: int
-    eta: float
-    momentum: float
-    counter_halflife: int
-    adjustment_iter: int
-    adjustment_ub: float
-    learning_rate_mode: int
-    grad_sum_decay: int
-    tail_id_threshold: float
-    is_tail_id_thresh_ratio: int
-    total_hash_size: int # Required for OptimType.NONE
-    weight_norm_coefficient: float
-    lower_bound: float
-    regularization_mode: int
-    use_rowwise_bias_correction: bool # Used for OptimType.ADAM
-
-class CommonArgsPT2(NamedTuple):
-    placeholder_autograd_tensor: torch.Tensor
-    dev_weights: torch.Tensor
-    host_weights: torch.Tensor
-    uvm_weights: torch.Tensor
-    lxu_cache_weights: torch.Tensor
-    weights_placements: torch.Tensor
-    weights_offsets: torch.Tensor
-    D_offsets: torch.Tensor
-    total_D: int
-    max_D: int
-    hash_size_cumsum: torch.Tensor
-    total_hash_size_bits: int
-    indices: torch.Tensor
-    offsets: torch.Tensor
-    pooling_mode: int
-    indice_weights: Optional[torch.Tensor]
-    feature_requires_grad: Optional[torch.Tensor]
-    lxu_cache_locations: torch.Tensor
-    uvm_cache_stats: Optional[torch.Tensor]
-    output_dtype: int
-    vbe_metadata: VBEMetadata
-    is_experimental: bool
-    use_uniq_cache_locations_bwd: bool
-    use_homogeneous_placements: bool
+    learning_rate_tensor: torch.Tensor
     info_B_num_bits: int
     info_B_mask: int

-class OptimizerArgsPT2(NamedTuple):
-    """
-    Optimizer arguments for PT2 interface
-    """
+
+# Do not add a parameter of Type tensor here. It will cause JIT script error due to a bug in PyTorch.
+# See more detail in D71010630.
+class OptimizerArgs(NamedTuple):
     stochastic_rounding: bool
     gradient_clipping: bool
     max_gradient: float
     max_norm: float
-    learning_rate_tensor: torch.Tensor
     eps: float
     beta1: float
     beta2: float
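
In lookup_args.py the VBE metadata gains two optional tensor fields, and the learning rate moves from a float on OptimizerArgs to a learning_rate_tensor on CommonArgs. Below is a minimal sketch of why the new defaults keep existing call sites working, using a trimmed stand-in for the real NamedTuple (only the fields visible in the hunk above; not the actual class):

    from typing import NamedTuple, Optional

    import torch


    class VBEMetadataSketch(NamedTuple):
        # Stand-in for fbgemm_gpu's VBEMetadata, trimmed to the fields shown above.
        max_B_feature_rank: int = -1
        max_B: int = -1
        output_size: int = -1
        vbe_output: Optional[torch.Tensor] = None          # new in 2026.1.29
        vbe_output_offsets: Optional[torch.Tensor] = None  # new in 2026.1.29


    # Old-style construction keeps working because the new fields default to None.
    meta = VBEMetadataSketch(max_B=128)
    # Callers that provide a preallocated VBE output opt in explicitly.
    meta = meta._replace(vbe_output=torch.empty(0))

    # The learning rate is now carried as a tensor on CommonArgs
    # (learning_rate_tensor) rather than a float on OptimizerArgs; wrapping a
    # float LR might look like this (dtype choice is an assumption):
    learning_rate_tensor = torch.tensor(0.01, dtype=torch.float32)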
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py
@@ -4,6 +4,8 @@
 ## Template Source: training/python/lookup_args.template
 ################################################################################

+__template_source_file__ = "training/python/lookup_args.template"
+
 #!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
@@ -25,6 +27,8 @@ class VBEMetadata(NamedTuple):
     max_B_feature_rank: int = -1
     max_B: int = -1
     output_size: int = -1
+    vbe_output: Optional[torch.Tensor] = None
+    vbe_output_offsets: Optional[torch.Tensor] = None


 class CommonArgs(NamedTuple):
@@ -53,72 +57,19 @@ class CommonArgs(NamedTuple):
     use_uniq_cache_locations_bwd: bool
     use_homogeneous_placements: bool
     ssd_tensors: Dict[str, torch.Tensor]
-
-
-class OptimizerArgs(NamedTuple):
-    stochastic_rounding: bool
-    gradient_clipping: bool
-    max_gradient: float
-    max_norm: float
-    learning_rate: float
-    eps: float
-    beta1: float
-    beta2: float
-    weight_decay: float
-    weight_decay_mode: int
-    eta: float
-    momentum: float
-    counter_halflife: int
-    adjustment_iter: int
-    adjustment_ub: float
-    learning_rate_mode: int
-    grad_sum_decay: int
-    tail_id_threshold: float
-    is_tail_id_thresh_ratio: int
-    total_hash_size: int # Required for OptimType.NONE
-    weight_norm_coefficient: float
-    lower_bound: float
-    regularization_mode: int
-    use_rowwise_bias_correction: bool # Used for OptimType.ADAM
-
-class CommonArgsPT2(NamedTuple):
-    placeholder_autograd_tensor: torch.Tensor
-    dev_weights: torch.Tensor
-    host_weights: torch.Tensor
-    uvm_weights: torch.Tensor
-    lxu_cache_weights: torch.Tensor
-    weights_placements: torch.Tensor
-    weights_offsets: torch.Tensor
-    D_offsets: torch.Tensor
-    total_D: int
-    max_D: int
-    hash_size_cumsum: torch.Tensor
-    total_hash_size_bits: int
-    indices: torch.Tensor
-    offsets: torch.Tensor
-    pooling_mode: int
-    indice_weights: Optional[torch.Tensor]
-    feature_requires_grad: Optional[torch.Tensor]
-    lxu_cache_locations: torch.Tensor
-    uvm_cache_stats: Optional[torch.Tensor]
-    output_dtype: int
-    vbe_metadata: VBEMetadata
-    is_experimental: bool
-    use_uniq_cache_locations_bwd: bool
-    use_homogeneous_placements: bool
+    enable_optimizer_offloading: bool
+    learning_rate_tensor: torch.Tensor
     info_B_num_bits: int
     info_B_mask: int
-    ssd_tensors: Dict[str, torch.Tensor]

-class OptimizerArgsPT2(NamedTuple):
-    """
-    Optimizer arguments for PT2 interface
-    """
+
+# Do not add a parameter of Type tensor here. It will cause JIT script error due to a bug in PyTorch.
+# See more detail in D71010630.
+class OptimizerArgs(NamedTuple):
     stochastic_rounding: bool
     gradient_clipping: bool
     max_gradient: float
     max_norm: float
-    learning_rate_tensor: torch.Tensor
     eps: float
     beta1: float
     beta2: float