fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/sparse_ops.py CHANGED
@@ -7,10 +7,12 @@
  # pyre-strict
 
  import math
- from typing import Callable, List, Optional, Sequence, Tuple
+ from collections.abc import Sequence
+ from typing import Callable, Optional
 
  import torch
 
+ # fmt:skip
  from fbgemm_gpu.split_embedding_configs import SparseType
  from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
  from fbgemm_gpu.utils.loader import load_torch_module
@@ -48,8 +50,7 @@ except Exception:
 
  import torch.utils._pytree as pytree
  from torch import SymInt, Tensor
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
+ from torch.fx.experimental.symbolic_shapes import guard_or_true
 
  if hasattr(torch.library, "register_fake"):
  # pyre-ignore[9]
@@ -74,7 +75,7 @@ def permute_2D_sparse_data_input1D_meta(
  stride: int,
  weights: Optional[Tensor] = None,
  permuted_lengths_sum: Optional[int] = None,
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
  torch._check(
  lengths.dim() == 1, lambda: f"expected lengths.dim() == 1, got {lengths.dim()}"
  )
@@ -111,7 +112,7 @@ def permute_2D_sparse_data_input1D_backward(
  grad_lengths: torch.Tensor,
  grad_values: torch.Tensor,
  grad_weights: torch.Tensor,
- ) -> Tuple[None, Tensor, Tensor, None, Tensor, None]:
+ ) -> tuple[None, Tensor, Tensor, None, Tensor, None]:
  inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute)
  permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = (
  torch.ops.fbgemm.permute_2D_sparse_data_input1D(
@@ -139,7 +140,7 @@ def permute_2D_sparse_data_meta(
  values: Tensor,
  weights: Optional[Tensor] = None,
  permuted_lengths_sum: Optional[int] = None,
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
  torch._check(
  lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}"
  )
@@ -166,6 +167,89 @@ def invert_permute_abstract(permute: Tensor) -> Tensor:
  return torch.empty_like(permute)
 
 
+ def get_source_mask_meta(
+ num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
+ ) -> Tensor:
+ if output_size is None:
+ ctx = torch.library.get_ctx()
+ output_size = ctx.new_dynamic_size()
+ return torch.empty([output_size], dtype=torch.bool)
+
+
+ def get_source_mask(
+ num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
+ ) -> Tensor:
+ """
+ Generate a boolean mask indicating which elements are from sources vs targets.
+
+ This is a Python wrapper that computes output_size when not provided,
+ enabling the operation to work with meta tensors for compilation.
+
+ Args:
+ num_sources: 1D tensor of source counts per batch element
+ num_targets: 1D tensor of target counts per batch element
+ output_size: Optional pre-computed output size.
+
+ Returns:
+ A 1D boolean tensor where True indicates source elements and False
+ indicates target elements
+
+ Example:
+ >>> num_sources = torch.tensor([2, 3])
+ >>> num_targets = torch.tensor([1, 2])
+ >>> get_source_mask(num_sources, num_targets)
+ tensor([True, True, False, True, True, True, False, False])
+ """
+ # Compute output_size if not provided and tensors are regular (not meta/fake)
+ if output_size is None:
+ combined = num_sources + num_targets
+ output_size = int(combined.sum().item())
+
+ return torch.ops.fbgemm.get_source_mask(num_sources, num_targets, output_size)
+
+
+ def repeat_arange_meta(lengths: Tensor) -> Tensor:
+ """Meta implementation for repeat_arange."""
+ # Output size is data-dependent (sum of lengths).
+ # For FakeTensors (used in torch.compile), we use dynamic sizing.
+ # For actual meta tensors, we cannot determine the size so return empty.
+ if lengths.device.type == "meta":
+ # Actual meta tensors: return a zero-sized tensor as placeholder
+ # since we cannot compute the data-dependent output size
+ return torch.empty([0], dtype=lengths.dtype, device=lengths.device)
+ else:
+ # FakeTensor context: use dynamic sizing for proper shape tracking
+ ctx = torch.library.get_ctx()
+ output_size = ctx.new_dynamic_size()
+ return torch.empty([output_size], dtype=lengths.dtype, device=lengths.device)
+
+
+ def repeat_arange(lengths: Tensor) -> Tensor:
+ """
+ Creates a concatenated tensor of aranges based on a lengths tensor.
+
+ This is a high-performance CUDA kernel that replaces the inefficient PyTorch
+ implementation which uses 4+ separate kernels (cumsum, arange, repeat_interleave, sub).
+
+ Args:
+ lengths: 1D tensor of lengths for each arange sequence
+
+ Returns:
+ A 1D tensor containing concatenated arange sequences
+
+ Example:
+ >>> lengths = torch.tensor([3, 5, 2])
+ >>> repeat_arange(lengths)
+ tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])
+
+ Performance:
+ - PyTorch implementation: 4+ kernel launches + intermediate allocations
+ - CUDA implementation: 1 fused kernel, no intermediate allocations
+ - Typical speedup: 3-5x on realistic workloads
+ """
+ return torch.ops.fbgemm.repeat_arange(lengths)
+
+
  # pyre-ignore
  def permute_2D_sparse_data_setup_context(ctx, inputs, output):
  permute, lengths, values, weights, permuted_lengths_sum = inputs
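
Note: the docstrings above pin down the semantics of the two new ops well enough to sketch plain-PyTorch equivalents. The snippet below is an illustration only; the real ops are the registered fbgemm kernels, and the *_reference names are ours, not part of the package.

    import torch

    def get_source_mask_reference(num_sources, num_targets):
        # Per batch element i: num_sources[i] True values followed by
        # num_targets[i] False values, concatenated into one 1D mask.
        return torch.cat([
            torch.cat([
                torch.ones(int(s), dtype=torch.bool),
                torch.zeros(int(t), dtype=torch.bool),
            ])
            for s, t in zip(num_sources.tolist(), num_targets.tolist())
        ])

    def repeat_arange_reference(lengths):
        # The multi-kernel path the fused CUDA kernel replaces:
        # cumsum -> arange -> repeat_interleave -> sub.
        starts = torch.cumsum(lengths, dim=0) - lengths              # segment start offsets
        positions = torch.arange(int(lengths.sum().item()))          # 0 .. total-1
        return positions - torch.repeat_interleave(starts, lengths)  # restart per segment

    # get_source_mask_reference(torch.tensor([2, 3]), torch.tensor([1, 2]))
    #   -> tensor([ True,  True, False,  True,  True,  True, False, False])
    # repeat_arange_reference(torch.tensor([3, 5, 2]))
    #   -> tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])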
@@ -197,7 +281,7 @@ def permute_1D_sparse_data_meta(
  values: Tensor,
  weights: Optional[Tensor] = None,
  permuted_lengths_sum: Optional[int] = None,
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
  indices = values
  permuted_lengths_size = permute.numel()
  permuted_lengths = lengths.new_empty([permuted_lengths_size])
@@ -218,7 +302,7 @@
 
  def masked_select_jagged_1d(
  values: Tensor, lengths: Tensor, mask: Tensor
- ) -> Tuple[Tensor, Tensor]:
+ ) -> tuple[Tensor, Tensor]:
  torch._check(values.dim() == 1)
  torch._check(lengths.dim() == 1)
  torch._check(values.device == lengths.device)
@@ -231,11 +315,11 @@
 
 
  def tbe_input_combine_abstract(
- indices_list: List[Tensor],
- offsets_list: List[Tensor],
- per_sample_weights: List[Tensor],
+ indices_list: list[Tensor],
+ offsets_list: list[Tensor],
+ per_sample_weights: list[Tensor],
  include_last_offsets: Tensor,
- ) -> Tuple[Tensor, Tensor, Tensor]:
+ ) -> tuple[Tensor, Tensor, Tensor]:
  torch._check(len(indices_list) > 0)
  torch._check(len(indices_list) == len(offsets_list))
  torch._check(len(indices_list) == len(per_sample_weights))
@@ -250,7 +334,7 @@ def tbe_input_combine_abstract(
  torch._check(index.is_contiguous())
  torch._check(offset.is_contiguous())
  total_indices = total_indices + index.numel()
- if guard_size_oblivious(weight.numel() > 0):
+ if guard_or_true(weight.numel() > 0):
  torch._check(weight.dim() == 1)
  torch._check(weight.numel() == index.numel())
  torch._check(weight.is_contiguous())
@@ -268,10 +352,10 @@
 
 
  def tbe_input_combine_with_length_abstract(
- indices_list: List[Tensor],
- offsets_list: List[Tensor],
- per_sample_weights: List[Tensor],
- ) -> Tuple[Tensor, Tensor, Tensor]:
+ indices_list: list[Tensor],
+ offsets_list: list[Tensor],
+ per_sample_weights: list[Tensor],
+ ) -> tuple[Tensor, Tensor, Tensor]:
  torch._check(len(indices_list) > 0)
  torch._check(len(indices_list) == len(offsets_list))
  torch._check(len(indices_list) == len(per_sample_weights))
@@ -287,7 +371,7 @@ def tbe_input_combine_with_length_abstract(
  torch._check(offset.is_contiguous())
  total_indices = total_indices + index.numel()
  total_offsets = total_offsets + offset.numel()
- if guard_size_oblivious(weight.numel() > 0):
+ if guard_or_true(weight.numel() > 0):
  torch._check(weight.dim() == 1)
  torch._check(weight.numel() == index.numel())
  torch._check(weight.is_contiguous())
@@ -339,7 +423,7 @@ def expand_into_jagged_permute_meta(
  permute: Tensor,
  input_offsets: Tensor,
  output_offsets: Tensor,
- output_size: Tuple[int, ...],
+ output_size: tuple[int, ...],
  ) -> Tensor:
  torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0")
  torch._check(
@@ -465,7 +549,7 @@ def block_bucketize_sparse_features_meta(
  keep_orig_idx: bool = False,
  total_num_blocks: Optional[torch.Tensor] = None,
  keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
- ) -> Tuple[
+ ) -> tuple[
  torch.Tensor,
  torch.Tensor,
  Optional[torch.Tensor],
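
Note: the guard_size_oblivious to guard_or_true swap in the hunks above changes how data-dependent size checks behave during fake-tensor tracing. A rough illustration, assuming guard_or_true's documented behavior of evaluating the condition when it is statically decidable and otherwise assuming True rather than raising (the helper function name below is ours):

    import torch
    from torch.fx.experimental.symbolic_shapes import guard_or_true

    def check_optional_weight(weight, index):
        # With a concrete tensor this behaves like a plain bool test; under
        # tracing with an unbacked symbolic numel, guard_or_true falls back to
        # True instead of raising a data-dependent guard error, so the
        # per-weight checks below are still applied on the optimistic path.
        if guard_or_true(weight.numel() > 0):
            torch._check(weight.dim() == 1)
            torch._check(weight.numel() == index.numel())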
@@ -485,8 +569,43 @@
  )
 
 
+ def block_bucketize_sparse_features_2d_weights_meta(
+ lengths: torch.Tensor,
+ indices: torch.Tensor,
+ bucketize_pos: bool,
+ sequence: bool,
+ block_sizes: torch.Tensor,
+ my_size: int,
+ weights: torch.Tensor,
+ weights_dim: int = 1,
+ batch_size_per_feature: Optional[torch.Tensor] = None,
+ max_B: int = -1,
+ block_bucketize_pos: Optional[torch.Tensor] = None,
+ keep_orig_idx: bool = False,
+ total_num_blocks: Optional[torch.Tensor] = None,
+ keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
+ ) -> tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ ]:
+ # Output: lengths, indices, weights", pos?, unbucketize_permute?
+ num_buckets = my_size
+ num_features = lengths.size(0)
+ num_values = indices.size(0)
+ return (
+ lengths.new_empty([num_buckets * num_features]),
+ indices.new_empty([num_values]),
+ weights.new_empty([num_values, weights_dim]),
+ indices.new_empty([num_values]) if bucketize_pos else None,
+ indices.new_empty([num_values]),
+ )
+
+
  def merge_pooled_embeddings(
- pooled_embeddings: List[torch.Tensor],
+ pooled_embeddings: list[torch.Tensor],
  uncat_dim_size: int,
  target_device: torch.device,
  cat_dim: int = 1,
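
Note: a small shape illustration of what the new block_bucketize_sparse_features_2d_weights meta function promises. The sizes below are made-up placeholders; only the shape relationships come from the function body above.

    import torch

    # num_features = 4, num_values = 100, weights_dim = 2, my_size (buckets) = 8
    lengths = torch.empty(4, dtype=torch.int64, device="meta")
    indices = torch.empty(100, dtype=torch.int64, device="meta")
    weights = torch.empty(100, 2, device="meta")

    out = block_bucketize_sparse_features_2d_weights_meta(
        lengths, indices, bucketize_pos=True, sequence=False,
        block_sizes=torch.empty(4, dtype=torch.int64, device="meta"),
        my_size=8, weights=weights, weights_dim=2,
    )
    # out[0].shape == (32,)     bucketized lengths: my_size * num_features
    # out[1].shape == (100,)    bucketized indices: one per input value
    # out[2].shape == (100, 2)  bucketized 2-D weights: [num_values, weights_dim]
    # out[3].shape == (100,)    positions (present because bucketize_pos=True)
    # out[4].shape == (100,)    unbucketize permutation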
@@ -517,7 +636,7 @@
 
  def permute_sparse_features_abstract(
  permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
  torch._check(lengths.dtype == indices.dtype)
  torch._check(permute.device == lengths.device)
  torch._check(permute.device == indices.device)
@@ -548,7 +667,7 @@ def segment_sum_csr_abstract(
 
  def dense_to_jagged_forward(
  dense: torch.Tensor,
- offsets: List[torch.Tensor],
+ offsets: list[torch.Tensor],
  total_L: Optional[torch.SymInt] = None,
  ) -> torch.Tensor:
  if total_L is None:
@@ -563,9 +682,9 @@
 
  def dense_to_jagged(
  dense: torch.Tensor,
- offsets: List[torch.Tensor],
+ offsets: list[torch.Tensor],
  total_L: Optional[torch.SymInt] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
  if total_L is None:
  total_L = torch.library.get_ctx().new_dynamic_size()
  return (dense_to_jagged_forward(dense, offsets, total_L), offsets)
@@ -574,9 +693,9 @@ def dense_to_jagged(
  def batch_index_select_dim0_abstract(
  inputs: torch.Tensor,
  indices: torch.Tensor,
- input_num_indices: List[int],
- input_rows: List[int],
- input_columns: List[int],
+ input_num_indices: list[int],
+ input_rows: list[int],
+ input_columns: list[int],
  permute_output_dim_0_1: bool,
  ) -> torch.Tensor:
  """
@@ -618,11 +737,11 @@ def batch_index_select_dim0_tensor_abstract(
  def batch_index_select_dim0_forward_cuda_impl_abstract(
  inputs: torch.Tensor,
  indices: torch.Tensor,
- input_num_indices: List[int],
- input_rows: List[int],
- input_columns: List[int],
+ input_num_indices: list[int],
+ input_rows: list[int],
+ input_columns: list[int],
  permute_output_dim_0_1: bool,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
  num_inputs = len(input_rows)
  torch._check(len(input_num_indices) == len(input_rows))
  torch._check(len(input_num_indices) == len(input_columns))
@@ -659,7 +778,7 @@ def batch_index_select_dim0_tensor_forward_cuda_impl_abstract(
  input_rows: torch.Tensor,
  input_columns: torch.Tensor,
  permute_output_dim_0_1: bool,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
  num_inputs: int = input_rows.size(0)
  torch._check(input_num_indices.size(0) == input_rows.size(0))
  torch._check(input_num_indices.size(0) == input_columns.size(0))
@@ -704,7 +823,7 @@ def keyed_jagged_index_select_dim1_abstract(
  batch_size: torch.SymInt,
  weights: Optional[torch.Tensor] = None,
  selected_lengths_sum: Optional[torch.SymInt] = None,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
  """
  This meta function is used to calculate the shape of output tensors
  from the original function `fbgemm::keyed_jagged_index_select_dim1` without the actual data.
@@ -729,7 +848,7 @@
  torch.index_select(lengths, 0, length_indices).sum().item()
  )
 
- ret: List[torch.Tensor] = [
+ ret: list[torch.Tensor] = [
  # pyre-ignore
  values.new_empty([selected_lengths_sum]),
  lengths.new_empty([indices.shape[0] * num_batches]),
@@ -761,17 +880,17 @@ def batch_index_select_dim0_backward_cuda_impl_abstract(
 
  def batch_index_select_dim0_forward_cpu_impl_abstract(
  inputs: torch.Tensor,
  indices: torch.Tensor,
- input_num_indices: List[int],
- input_rows: List[int],
- input_columns: List[int],
+ input_num_indices: list[int],
+ input_rows: list[int],
+ input_columns: list[int],
  permute_output_dim_0_1: bool,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
  # input lists must have the same length
  num_inputs = len(input_num_indices)
  torch._check(num_inputs == len(input_rows))
  torch._check(num_inputs == len(input_columns))
 
- if permute_output_dim_0_1 and guard_size_oblivious(len(input_num_indices) > 0):
+ if permute_output_dim_0_1 and guard_or_true(len(input_num_indices) > 0):
  # All num_indices must be the same if permute_output_dim_0_1 is True
  for x in input_num_indices:
  torch._check(x == input_num_indices[0])
@@ -795,7 +914,7 @@ def batch_index_select_dim0_tensor_forward_cpu_impl_abstract(
  input_rows: torch.Tensor,
  input_columns: torch.Tensor,
  permute_output_dim_0_1: bool,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
  # input lists must have the same length
  num_inputs = len(input_num_indices)
  torch._check(num_inputs == len(input_rows))
@@ -845,8 +964,8 @@ def bounds_check_indices_abstract(
 
 
  def group_index_select_dim0_gpu_impl_abstract(
- inputs: List[torch.Tensor], group_size: int
- ) -> List[torch.Tensor]:
+ inputs: list[torch.Tensor], group_size: int
+ ) -> list[torch.Tensor]:
  """
  Calculate output shapes for group_index_select_dim0_gpu_impl
  without the actual data.
@@ -876,8 +995,8 @@ def group_index_select_dim0_gpu_impl_abstract(
 
 
  def group_index_select_dim0_gpu_backward_abstract(
- all_inputs: List[torch.Tensor], output_shape_group_ref: List[torch.SymInt]
- ) -> List[torch.Tensor]:
+ all_inputs: list[torch.Tensor], output_shape_group_ref: list[torch.SymInt]
+ ) -> list[torch.Tensor]:
  """
  Calculate output shapes for group_index_select_dim0_gpu_backward
  without the actual data.
@@ -910,7 +1029,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
  batch_size: torch.SymInt,
  weights: Optional[torch.Tensor] = None,
  selected_lengths_sum: Optional[torch.SymInt] = None,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
  num_batches = lengths.size(0) // batch_size
  torch._check(lengths.size(0) + 1 == offsets.size(0))
  # pyre-ignore
@@ -924,7 +1043,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
  selected_lengths_sum = torch.library.get_ctx().new_dynamic_size()
 
  torch._check_is_size(selected_lengths_sum)
- vlw: List[torch.Tensor] = [
+ vlw: list[torch.Tensor] = [
  values.new_empty([selected_lengths_sum]), # output
  lengths.new_empty([indices.shape[0] * num_batches]), # output_lengths
  ]
@@ -967,7 +1086,7 @@ def histogram_binning_calibration_abstract(
  upper_bound: float,
  bin_ctr_in_use_after: int,
  bin_ctr_weight_value: float,
- ) -> Tuple[Tensor, Tensor]:
+ ) -> tuple[Tensor, Tensor]:
  return torch.empty_like(logit), torch.empty([logit.numel()], dtype=torch.int64)
 
 
@@ -1118,7 +1237,7 @@ def generic_histogram_binning_calibration_by_feature(
  positive_weight: float,
  bin_ctr_in_use_after: int,
  bin_ctr_weight_value: float,
- ) -> Tuple[Tensor, Tensor]:
+ ) -> tuple[Tensor, Tensor]:
  torch._check(bin_num_examples.numel() == bin_num_positives.numel())
  torch._check(
  bin_num_examples.numel() == (num_segments + 1) * (bin_boundaries.numel() + 1)
@@ -1129,13 +1248,13 @@
  )
 
  def permute_multi_embedding_function_impl_abstract(
- pooled_embs: List[Tensor],
+ pooled_embs: list[Tensor],
  permutes: Tensor,
  in_shapes: Tensor,
  out_shapes: Tensor,
- out_lengths: List[int],
+ out_lengths: list[int],
  reverse: bool = False,
- ) -> List[Tensor]:
+ ) -> list[Tensor]:
  out_dtype = pooled_embs[0].dtype
  bs = pooled_embs[0].shape[0]
  torch._check(permutes.shape[1] == 6, lambda: "permutes must have 6 columns")
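
Note: several of the fake/meta functions above (dense_to_jagged, keyed_jagged_index_select_dim1, and the new get_source_mask/repeat_arange metas) share one pattern for data-dependent output sizes. A minimal sketch of that pattern outside any specific op; the op name below is hypothetical:

    import torch

    def some_op_meta(values: torch.Tensor) -> torch.Tensor:
        # Only valid inside a registered fake/meta implementation call.
        ctx = torch.library.get_ctx()
        output_size = ctx.new_dynamic_size()  # fresh unbacked SymInt for a data-dependent size
        torch._check_is_size(output_size)     # constrain it to be a valid (non-negative) size
        return values.new_empty([output_size])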
@@ -1161,15 +1280,25 @@ def lengths_range_abstract(
 
 
  def all_to_one_device(
- input_tensors: List[Tensor],
+ input_tensors: list[Tensor],
  target_device: torch.device,
- ) -> List[Tensor]:
+ ) -> list[Tensor]:
  return [
  torch.empty_like(input_tensor, device=torch.device("meta"))
  for input_tensor in input_tensors
  ]
 
 
+ def sum_reduce_to_one(
+ input_tensors: list[Tensor],
+ target_device: torch.device,
+ ) -> Tensor:
+ torch._check(len(input_tensors) > 0, lambda: "reducing no tensor is undefined")
+ # All tensors should have the same shape
+ first_tensor = input_tensors[0]
+ return torch.empty_like(first_tensor, device=torch.device("meta"))
+
+
  def _setup() -> None:
  # pyre-ignore[16]
  _setup.done = getattr(_setup, "done", False)
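
Note: sum_reduce_to_one above is only the fake/meta shape function; the output shares the first input's shape. Presumably the real operator moves the inputs to target_device and sums them elementwise; a hedged reference sketch based only on that reading, with a name of our own:

    import torch

    def sum_reduce_to_one_reference(input_tensors, target_device):
        # Assumes every input shares one shape, per the meta function's comment.
        assert len(input_tensors) > 0, "reducing no tensor is undefined"
        out = input_tensors[0].to(target_device).clone()
        for t in input_tensors[1:]:
            out = out + t.to(target_device)
        return out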
@@ -1202,6 +1331,8 @@ def _setup() -> None:
  )
 
  impl_abstract("fbgemm::permute_2D_sparse_data", permute_2D_sparse_data_meta)
+ impl_abstract("fbgemm::get_source_mask", get_source_mask_meta)
+ impl_abstract("fbgemm::repeat_arange", repeat_arange_meta)
  impl_abstract(
  "fbgemm::permute_2D_sparse_data_input1D",
  permute_2D_sparse_data_input1D_meta,
@@ -1234,6 +1365,10 @@ def _setup() -> None:
  "fbgemm::block_bucketize_sparse_features",
  block_bucketize_sparse_features_meta,
  )
+ impl_abstract(
+ "fbgemm::block_bucketize_sparse_features_2d_weights",
+ block_bucketize_sparse_features_2d_weights_meta,
+ )
  impl_abstract("fbgemm::merge_pooled_embeddings", merge_pooled_embeddings)
  impl_abstract(
  "fbgemm::permute_sparse_features", permute_sparse_features_abstract
@@ -1241,6 +1376,7 @@ def _setup() -> None:
  impl_abstract("fbgemm::segment_sum_csr", segment_sum_csr_abstract)
  impl_abstract("fbgemm::dense_to_jagged_forward", dense_to_jagged_forward)
  impl_abstract("fbgemm::all_to_one_device", all_to_one_device)
+ impl_abstract("fbgemm::sum_reduce_to_one", sum_reduce_to_one)
  impl_abstract(
  "fbgemm::batch_index_select_dim0", batch_index_select_dim0_abstract
  )
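
Note: impl_abstract in the _setup() registrations is the name sparse_ops.py binds in the earlier hunk (torch.library.register_fake is preferred when present). A condensed sketch of that dispatch, with the new ops registered the same way as the existing ones; the older fallback is an assumption:

    import torch

    impl_abstract = (
        torch.library.register_fake              # newer PyTorch
        if hasattr(torch.library, "register_fake")
        else torch.library.impl_abstract         # assumed older fallback
    )
    impl_abstract("fbgemm::get_source_mask", get_source_mask_meta)
    impl_abstract("fbgemm::repeat_arange", repeat_arange_meta)
    impl_abstract("fbgemm::sum_reduce_to_one", sum_reduce_to_one)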
fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py CHANGED
@@ -152,6 +152,18 @@ except:
  DeprecationWarning,
  )
 
+ try:
+ # Import is placed under a try-except bc the op is experimental and can be
+ # removed/updated in the future
+ import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam_ssd as lookup_adam_ssd # noqa: F401
+ except:
+ warnings.warn(
+ f"""\033[93m
+ Failed to import: fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam_ssd
+ \033[0m""",
+ DeprecationWarning,
+ )
+
  try:
  # Import is placed under a try-except bc the op is experimental and can be
  # removed/updated in the future
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py CHANGED
@@ -56,14 +56,15 @@ def invoke(
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
  "lxu_cache_locations": common_args.lxu_cache_locations,
  "uvm_cache_stats": common_args.uvm_cache_stats,
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
  }
 
  dict_aux_int: Dict[str, int] = {
- "iter": iter,
- "info_B_num_bits": common_args.info_B_num_bits,
+ "iter": iter,
+ "info_B_num_bits": common_args.info_B_num_bits,
  "info_B_mask": common_args.info_B_mask,
  }
-
+
 
  dict_aux_float: Dict[str, float] = {
  "gwd_lower_bound": gwd_lower_bound,
@@ -81,7 +82,7 @@
 
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
-
+
 
  # optimizer_args # if optimizer == none
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -132,6 +133,11 @@ def invoke(
  "Please check the frontend and backend version. "
  )
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+ assert "vbe_output_offsets" in dict_aux_tensor, (
+ "vbe_output_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
 
  aux_int: List[int] = []
  assert "iter" in dict_aux_int, (
@@ -204,7 +210,7 @@
  # ['momentum1', 'learning_rate_tensor', 'optim_float']
  optim_float: List[float] = []
  optim_float.append(dict_optim_float["eps"])
- # optim_bool
+ # optim_bool
 
  return torch.ops.fbgemm.split_embedding_codegen_lookup_adagrad_function_pt2(
  # common_args
@@ -226,6 +232,7 @@ def invoke(
  max_B=vbe_metadata.max_B,
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
  vbe_output_size=vbe_metadata.output_size,
+ vbe_output=vbe_metadata.vbe_output,
  # aux_tensor
  aux_tensor=aux_tensor,
  # aux_int
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py CHANGED
@@ -67,14 +67,15 @@ def invoke(
  "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
  "lxu_cache_locations": common_args.lxu_cache_locations,
  "uvm_cache_stats": common_args.uvm_cache_stats,
+ "vbe_output_offsets" : vbe_metadata.vbe_output_offsets,
  }
 
  dict_aux_int: Dict[str, int] = {
- "iter": iter,
- "info_B_num_bits": common_args.info_B_num_bits,
+ "iter": iter,
+ "info_B_num_bits": common_args.info_B_num_bits,
  "info_B_mask": common_args.info_B_mask,
  }
-
+
 
  dict_aux_float: Dict[str, float] = {
  "gwd_lower_bound": gwd_lower_bound,
@@ -92,7 +93,7 @@
 
  # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
  dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
-
+
 
  # optimizer_args # if optimizer == none
  dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
@@ -125,13 +126,13 @@ def invoke(
  momentum2.placements,
  momentum2.offsets,
  ] if momentum2 is not None else None
-
+
  if optimizer_args.use_rowwise_bias_correction and row_counter is not None:
  row_counter_host = None # not supported on CPU
  row_counter_dev = row_counter.dev
  row_counter_uvm = row_counter.uvm
  row_counter_offsets = row_counter.offsets
- row_counter_placements = row_counter.placements
+ row_counter_placements = row_counter.placements
  elif optimizer_args.use_rowwise_bias_correction:
  assert False, "`use_rowwise_bias_correction` is set, `row_counter` cannot be None"
  else:
@@ -173,6 +174,11 @@ def invoke(
  "Please check the frontend and backend version. "
  )
  aux_tensor.append(dict_aux_tensor["prev_iter_dev"])
+ assert "vbe_output_offsets" in dict_aux_tensor, (
+ "vbe_output_offsets must be in dict_aux_tensor. "
+ "Please check the frontend and backend version. "
+ )
+ aux_tensor.append(dict_aux_tensor["vbe_output_offsets"])
 
  aux_int: List[int] = []
  assert "iter" in dict_aux_int, (
@@ -271,7 +277,7 @@ def invoke(
  optim_float.append(dict_optim_float["weight_decay"])
  # optim_bool
  optim_bool: List[bool] = []
- optim_bool.append(dict_optim_bool["use_rowwise_bias_correction"])
+ optim_bool.append(dict_optim_bool["use_rowwise_bias_correction"])
 
  return torch.ops.fbgemm.split_embedding_codegen_lookup_adam_function_pt2(
  # common_args
@@ -293,6 +299,7 @@ def invoke(
  max_B=vbe_metadata.max_B,
  max_B_feature_rank=vbe_metadata.max_B_feature_rank,
  vbe_output_size=vbe_metadata.output_size,
+ vbe_output=vbe_metadata.vbe_output,
  # aux_tensor
  aux_tensor=aux_tensor,
  # aux_int
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py CHANGED
@@ -27,6 +27,8 @@ class VBEMetadata(NamedTuple):
  max_B_feature_rank: int = -1
  max_B: int = -1
  output_size: int = -1
+ vbe_output: Optional[torch.Tensor] = None
+ vbe_output_offsets: Optional[torch.Tensor] = None
 
 
  class CommonArgs(NamedTuple):
fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py CHANGED
@@ -27,6 +27,8 @@ class VBEMetadata(NamedTuple):
  max_B_feature_rank: int = -1
  max_B: int = -1
  output_size: int = -1
+ vbe_output: Optional[torch.Tensor] = None
+ vbe_output_offsets: Optional[torch.Tensor] = None
 
 
  class CommonArgs(NamedTuple):
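
Note: both lookup_args.py and lookup_args_ssd.py gain the same two optional VBEMetadata fields. Because they default to None, existing call sites are unaffected; a caller opting in could attach a pre-allocated VBE output roughly as sketched below. The tensor values and variable names are placeholders, not package API; only the field names come from the hunks above.

    import torch

    preallocated_output = torch.empty(1024, dtype=torch.float32)   # placeholder size
    output_offsets = torch.tensor([0, 256, 512, 768, 1024])        # placeholder offsets

    vbe_metadata = vbe_metadata._replace(   # standard NamedTuple copy-with-update
        vbe_output=preallocated_output,
        vbe_output_offsets=output_offsets,
    )
    # The lookup_adagrad/lookup_adam invokers above then forward these as
    # vbe_output=vbe_metadata.vbe_output and the "vbe_output_offsets" aux tensor.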