fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/sparse_ops.py CHANGED
@@ -7,10 +7,12 @@
  # pyre-strict

  import math
- from typing import Callable, List, Optional, Sequence, Tuple
+ from collections.abc import Sequence
+ from typing import Callable, Optional

  import torch

+ # fmt:skip
  from fbgemm_gpu.split_embedding_configs import SparseType
  from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
  from fbgemm_gpu.utils.loader import load_torch_module
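Most of the churn in sparse_ops.py is a mechanical typing migration: the deprecated typing.List and typing.Tuple aliases become the PEP 585 builtin generics list and tuple, and Sequence now comes from its collections.abc home. A minimal before/after sketch (hypothetical function, not from the package):

    # Before: typing aliases, deprecated since Python 3.9
    from typing import List, Optional, Tuple

    def head_rest(xs: List[int]) -> Tuple[int, Optional[List[int]]]:
        return xs[0], (xs[1:] or None)

    # After: builtin generics, as adopted throughout this diff
    from typing import Optional

    def head_rest(xs: list[int]) -> tuple[int, Optional[list[int]]]:
        return xs[0], (xs[1:] or None)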
@@ -41,12 +43,14 @@ except Exception:
      torch.ops.load_library(
          "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu"
      )
+     torch.ops.load_library(
+         "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+     )


  import torch.utils._pytree as pytree
  from torch import SymInt, Tensor
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
+ from torch.fx.experimental.symbolic_shapes import guard_or_true

  if hasattr(torch.library, "register_fake"):
      # pyre-ignore[9]
@@ -71,7 +75,7 @@ def permute_2D_sparse_data_input1D_meta(
      stride: int,
      weights: Optional[Tensor] = None,
      permuted_lengths_sum: Optional[int] = None,
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
      torch._check(
          lengths.dim() == 1, lambda: f"expected lengths.dim() == 1, got {lengths.dim()}"
      )
@@ -108,7 +112,7 @@ def permute_2D_sparse_data_input1D_backward(
      grad_lengths: torch.Tensor,
      grad_values: torch.Tensor,
      grad_weights: torch.Tensor,
- ) -> Tuple[None, Tensor, Tensor, None, Tensor, None]:
+ ) -> tuple[None, Tensor, Tensor, None, Tensor, None]:
      inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute)
      permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = (
          torch.ops.fbgemm.permute_2D_sparse_data_input1D(
@@ -136,7 +140,7 @@ def permute_2D_sparse_data_meta(
      values: Tensor,
      weights: Optional[Tensor] = None,
      permuted_lengths_sum: Optional[int] = None,
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
      torch._check(
          lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}"
      )
@@ -163,6 +167,89 @@ def invert_permute_abstract(permute: Tensor) -> Tensor:
      return torch.empty_like(permute)


+ def get_source_mask_meta(
+     num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
+ ) -> Tensor:
+     if output_size is None:
+         ctx = torch.library.get_ctx()
+         output_size = ctx.new_dynamic_size()
+     return torch.empty([output_size], dtype=torch.bool)
+
+
+ def get_source_mask(
+     num_sources: Tensor, num_targets: Tensor, output_size: Optional[int] = None
+ ) -> Tensor:
+     """
+     Generate a boolean mask indicating which elements are from sources vs targets.
+
+     This is a Python wrapper that computes output_size when not provided,
+     enabling the operation to work with meta tensors for compilation.
+
+     Args:
+         num_sources: 1D tensor of source counts per batch element
+         num_targets: 1D tensor of target counts per batch element
+         output_size: Optional pre-computed output size.
+
+     Returns:
+         A 1D boolean tensor where True indicates source elements and False
+         indicates target elements
+
+     Example:
+         >>> num_sources = torch.tensor([2, 3])
+         >>> num_targets = torch.tensor([1, 2])
+         >>> get_source_mask(num_sources, num_targets)
+         tensor([True, True, False, True, True, True, False, False])
+     """
+     # Compute output_size if not provided and tensors are regular (not meta/fake)
+     if output_size is None:
+         combined = num_sources + num_targets
+         output_size = int(combined.sum().item())
+
+     return torch.ops.fbgemm.get_source_mask(num_sources, num_targets, output_size)
+
+
+ def repeat_arange_meta(lengths: Tensor) -> Tensor:
+     """Meta implementation for repeat_arange."""
+     # Output size is data-dependent (sum of lengths).
+     # For FakeTensors (used in torch.compile), we use dynamic sizing.
+     # For actual meta tensors, we cannot determine the size so return empty.
+     if lengths.device.type == "meta":
+         # Actual meta tensors: return a zero-sized tensor as placeholder
+         # since we cannot compute the data-dependent output size
+         return torch.empty([0], dtype=lengths.dtype, device=lengths.device)
+     else:
+         # FakeTensor context: use dynamic sizing for proper shape tracking
+         ctx = torch.library.get_ctx()
+         output_size = ctx.new_dynamic_size()
+         return torch.empty([output_size], dtype=lengths.dtype, device=lengths.device)
+
+
+ def repeat_arange(lengths: Tensor) -> Tensor:
+     """
+     Creates a concatenated tensor of aranges based on a lengths tensor.
+
+     This is a high-performance CUDA kernel that replaces the inefficient PyTorch
+     implementation which uses 4+ separate kernels (cumsum, arange, repeat_interleave, sub).
+
+     Args:
+         lengths: 1D tensor of lengths for each arange sequence
+
+     Returns:
+         A 1D tensor containing concatenated arange sequences
+
+     Example:
+         >>> lengths = torch.tensor([3, 5, 2])
+         >>> repeat_arange(lengths)
+         tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])
+
+     Performance:
+         - PyTorch implementation: 4+ kernel launches + intermediate allocations
+         - CUDA implementation: 1 fused kernel, no intermediate allocations
+         - Typical speedup: 3-5x on realistic workloads
+     """
+     return torch.ops.fbgemm.repeat_arange(lengths)
+
+
  # pyre-ignore
  def permute_2D_sparse_data_setup_context(ctx, inputs, output):
      permute, lengths, values, weights, permuted_lengths_sum = inputs
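For context on the repeat_arange docstring above: the multi-kernel eager-mode path it replaces can be sketched roughly as follows. This is a hedged reconstruction built from the kernels the docstring names (cumsum, arange, repeat_interleave, sub), not code shipped in the package:

    import torch

    def repeat_arange_reference(lengths: torch.Tensor) -> torch.Tensor:
        # Start offset of each sequence in the flat output
        offsets = torch.cumsum(lengths, dim=0) - lengths       # cumsum (+ sub)
        total = int(lengths.sum().item())
        flat = torch.arange(total, device=lengths.device)      # arange
        starts = torch.repeat_interleave(offsets, lengths)     # repeat_interleave
        return flat - starts                                   # sub

    # repeat_arange_reference(torch.tensor([3, 5, 2]))
    # -> tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])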
@@ -194,7 +281,7 @@ def permute_1D_sparse_data_meta(
      values: Tensor,
      weights: Optional[Tensor] = None,
      permuted_lengths_sum: Optional[int] = None,
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
      indices = values
      permuted_lengths_size = permute.numel()
      permuted_lengths = lengths.new_empty([permuted_lengths_size])
@@ -215,7 +302,7 @@ def permute_1D_sparse_data_meta(

  def masked_select_jagged_1d(
      values: Tensor, lengths: Tensor, mask: Tensor
- ) -> Tuple[Tensor, Tensor]:
+ ) -> tuple[Tensor, Tensor]:
      torch._check(values.dim() == 1)
      torch._check(lengths.dim() == 1)
      torch._check(values.device == lengths.device)
@@ -228,11 +315,11 @@ def masked_select_jagged_1d(


  def tbe_input_combine_abstract(
-     indices_list: List[Tensor],
-     offsets_list: List[Tensor],
-     per_sample_weights: List[Tensor],
+     indices_list: list[Tensor],
+     offsets_list: list[Tensor],
+     per_sample_weights: list[Tensor],
      include_last_offsets: Tensor,
- ) -> Tuple[Tensor, Tensor, Tensor]:
+ ) -> tuple[Tensor, Tensor, Tensor]:
      torch._check(len(indices_list) > 0)
      torch._check(len(indices_list) == len(offsets_list))
      torch._check(len(indices_list) == len(per_sample_weights))
@@ -247,7 +334,7 @@ def tbe_input_combine_abstract(
          torch._check(index.is_contiguous())
          torch._check(offset.is_contiguous())
          total_indices = total_indices + index.numel()
-         if guard_size_oblivious(weight.numel() > 0):
+         if guard_or_true(weight.numel() > 0):
              torch._check(weight.dim() == 1)
              torch._check(weight.numel() == index.numel())
              torch._check(weight.is_contiguous())
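The guard_size_oblivious → guard_or_true swap above tracks an API change in torch.fx.experimental.symbolic_shapes. As I read the newer API, guard_or_true(expr) returns the concrete answer when expr can be decided statically, and optimistically returns True when it is data-dependent (e.g. an unbacked SymInt from a fake tensor) instead of raising a data-dependent guard error. A hedged sketch with a hypothetical helper:

    import torch
    from torch.fx.experimental.symbolic_shapes import guard_or_true

    def check_optional_weight(weight: torch.Tensor, index: torch.Tensor) -> None:
        # Under a FakeTensor trace, weight.numel() may be an unbacked SymInt,
        # so a plain `weight.numel() > 0` branch could fail to evaluate.
        # guard_or_true takes the branch optimistically in that case, so the
        # torch._check validations still apply in the common non-empty case.
        if guard_or_true(weight.numel() > 0):
            torch._check(weight.dim() == 1)
            torch._check(weight.numel() == index.numel())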
@@ -265,10 +352,10 @@


  def tbe_input_combine_with_length_abstract(
-     indices_list: List[Tensor],
-     offsets_list: List[Tensor],
-     per_sample_weights: List[Tensor],
- ) -> Tuple[Tensor, Tensor, Tensor]:
+     indices_list: list[Tensor],
+     offsets_list: list[Tensor],
+     per_sample_weights: list[Tensor],
+ ) -> tuple[Tensor, Tensor, Tensor]:
      torch._check(len(indices_list) > 0)
      torch._check(len(indices_list) == len(offsets_list))
      torch._check(len(indices_list) == len(per_sample_weights))
@@ -284,7 +371,7 @@ def tbe_input_combine_with_length_abstract(
          torch._check(offset.is_contiguous())
          total_indices = total_indices + index.numel()
          total_offsets = total_offsets + offset.numel()
-         if guard_size_oblivious(weight.numel() > 0):
+         if guard_or_true(weight.numel() > 0):
              torch._check(weight.dim() == 1)
              torch._check(weight.numel() == index.numel())
              torch._check(weight.is_contiguous())
@@ -336,7 +423,7 @@ def expand_into_jagged_permute_meta(
      permute: Tensor,
      input_offsets: Tensor,
      output_offsets: Tensor,
-     output_size: Tuple[int, ...],
+     output_size: tuple[int, ...],
  ) -> Tensor:
      torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0")
      torch._check(
@@ -417,6 +504,7 @@ def int_nbit_split_embedding_codegen_lookup_function_meta(
      kINT8QparamsBytes = 8

      if pooling_mode == PoolingMode.NONE:
+         kINT8QparamsBytes = 4
          D = max(
              [
                  max_int2_D,
@@ -432,7 +520,7 @@ def int_nbit_split_embedding_codegen_lookup_function_meta(
          torch._check(D > 0)
          adjusted_D = D
          if SparseType.from_int(output_dtype_int) == SparseType.INT8:
-             adjusted_D += T * kINT8QparamsBytes
+             adjusted_D += kINT8QparamsBytes
          output = dev_weights.new_empty([total_L, adjusted_D], dtype=output_dtype)
          return output

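The two INT8 edits above change only the meta function's reported output width. A worked example with illustrative numbers (D and T below are arbitrary, not values from the package):

    # Row widths allocated by the meta function for INT8 output
    D, T = 128, 10               # max embedding dim, number of tables

    old_adjusted_D = D + T * 8   # 208: the old meta added qparams bytes once per table
    pooled_adjusted_D = D + 8    # 136: now one 8-byte qparams block per pooled row
    seq_adjusted_D = D + 4       # 132: PoolingMode.NONE rows carry 4 qparams bytes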
@@ -460,7 +548,8 @@ def block_bucketize_sparse_features_meta(
      block_bucketize_pos: Optional[torch.Tensor] = None,
      keep_orig_idx: bool = False,
      total_num_blocks: Optional[torch.Tensor] = None,
- ) -> Tuple[
+     keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
+ ) -> tuple[
      torch.Tensor,
      torch.Tensor,
      Optional[torch.Tensor],
@@ -480,8 +569,43 @@
      )


+ def block_bucketize_sparse_features_2d_weights_meta(
+     lengths: torch.Tensor,
+     indices: torch.Tensor,
+     bucketize_pos: bool,
+     sequence: bool,
+     block_sizes: torch.Tensor,
+     my_size: int,
+     weights: torch.Tensor,
+     weights_dim: int = 1,
+     batch_size_per_feature: Optional[torch.Tensor] = None,
+     max_B: int = -1,
+     block_bucketize_pos: Optional[torch.Tensor] = None,
+     keep_orig_idx: bool = False,
+     total_num_blocks: Optional[torch.Tensor] = None,
+     keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
+ ) -> tuple[
+     torch.Tensor,
+     torch.Tensor,
+     torch.Tensor,
+     Optional[torch.Tensor],
+     Optional[torch.Tensor],
+ ]:
+     # Output: lengths, indices, weights, pos?, unbucketize_permute?
+     num_buckets = my_size
+     num_features = lengths.size(0)
+     num_values = indices.size(0)
+     return (
+         lengths.new_empty([num_buckets * num_features]),
+         indices.new_empty([num_values]),
+         weights.new_empty([num_values, weights_dim]),
+         indices.new_empty([num_values]) if bucketize_pos else None,
+         indices.new_empty([num_values]),
+     )
+
+
  def merge_pooled_embeddings(
-     pooled_embeddings: List[torch.Tensor],
+     pooled_embeddings: list[torch.Tensor],
      uncat_dim_size: int,
      target_device: torch.device,
      cat_dim: int = 1,
@@ -512,7 +636,7 @@

  def permute_sparse_features_abstract(
      permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
      torch._check(lengths.dtype == indices.dtype)
      torch._check(permute.device == lengths.device)
      torch._check(permute.device == indices.device)
@@ -543,7 +667,7 @@ def segment_sum_csr_abstract(

  def dense_to_jagged_forward(
      dense: torch.Tensor,
-     offsets: List[torch.Tensor],
+     offsets: list[torch.Tensor],
      total_L: Optional[torch.SymInt] = None,
  ) -> torch.Tensor:
      if total_L is None:
@@ -558,9 +682,9 @@

  def dense_to_jagged(
      dense: torch.Tensor,
-     offsets: List[torch.Tensor],
+     offsets: list[torch.Tensor],
      total_L: Optional[torch.SymInt] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
      if total_L is None:
          total_L = torch.library.get_ctx().new_dynamic_size()
      return (dense_to_jagged_forward(dense, offsets, total_L), offsets)
@@ -569,9 +693,9 @@ def dense_to_jagged(
  def batch_index_select_dim0_abstract(
      inputs: torch.Tensor,
      indices: torch.Tensor,
-     input_num_indices: List[int],
-     input_rows: List[int],
-     input_columns: List[int],
+     input_num_indices: list[int],
+     input_rows: list[int],
+     input_columns: list[int],
      permute_output_dim_0_1: bool,
  ) -> torch.Tensor:
      """
@@ -613,11 +737,11 @@ def batch_index_select_dim0_tensor_abstract(
  def batch_index_select_dim0_forward_cuda_impl_abstract(
      inputs: torch.Tensor,
      indices: torch.Tensor,
-     input_num_indices: List[int],
-     input_rows: List[int],
-     input_columns: List[int],
+     input_num_indices: list[int],
+     input_rows: list[int],
+     input_columns: list[int],
      permute_output_dim_0_1: bool,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
      num_inputs = len(input_rows)
      torch._check(len(input_num_indices) == len(input_rows))
      torch._check(len(input_num_indices) == len(input_columns))
@@ -654,7 +778,7 @@ def batch_index_select_dim0_tensor_forward_cuda_impl_abstract(
      input_rows: torch.Tensor,
      input_columns: torch.Tensor,
      permute_output_dim_0_1: bool,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
      num_inputs: int = input_rows.size(0)
      torch._check(input_num_indices.size(0) == input_rows.size(0))
      torch._check(input_num_indices.size(0) == input_columns.size(0))
@@ -699,7 +823,7 @@ def keyed_jagged_index_select_dim1_abstract(
      batch_size: torch.SymInt,
      weights: Optional[torch.Tensor] = None,
      selected_lengths_sum: Optional[torch.SymInt] = None,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
      """
      This meta function is used to calculate the shape of output tensors
      from the original function `fbgemm::keyed_jagged_index_select_dim1` without the actual data.
@@ -724,7 +848,7 @@ def keyed_jagged_index_select_dim1_abstract(
              torch.index_select(lengths, 0, length_indices).sum().item()
          )

-     ret: List[torch.Tensor] = [
+     ret: list[torch.Tensor] = [
          # pyre-ignore
          values.new_empty([selected_lengths_sum]),
          lengths.new_empty([indices.shape[0] * num_batches]),
@@ -756,17 +880,17 @@ def batch_index_select_dim0_backward_cuda_impl_abstract(
  def batch_index_select_dim0_forward_cpu_impl_abstract(
      inputs: torch.Tensor,
      indices: torch.Tensor,
-     input_num_indices: List[int],
-     input_rows: List[int],
-     input_columns: List[int],
+     input_num_indices: list[int],
+     input_rows: list[int],
+     input_columns: list[int],
      permute_output_dim_0_1: bool,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
      # input lists must have the same length
      num_inputs = len(input_num_indices)
      torch._check(num_inputs == len(input_rows))
      torch._check(num_inputs == len(input_columns))

-     if permute_output_dim_0_1 and guard_size_oblivious(len(input_num_indices) > 0):
+     if permute_output_dim_0_1 and guard_or_true(len(input_num_indices) > 0):
          # All num_indices must be the same if permute_output_dim_0_1 is True
          for x in input_num_indices:
              torch._check(x == input_num_indices[0])
@@ -790,7 +914,7 @@ def batch_index_select_dim0_tensor_forward_cpu_impl_abstract(
      input_rows: torch.Tensor,
      input_columns: torch.Tensor,
      permute_output_dim_0_1: bool,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
      # input lists must have the same length
      num_inputs = len(input_num_indices)
      torch._check(num_inputs == len(input_rows))
@@ -829,6 +953,8 @@ def bounds_check_indices_abstract(
      b_t_map: Optional[torch.Tensor] = None,
      info_B_num_bits: int = -1,
      info_B_mask: int = -1,
+     bounds_check_version: int = 1,
+     prefetch_pipeline: bool = False,
  ) -> None:
      """
      This meta function is used to fake the bounds checking
@@ -838,8 +964,8 @@ def bounds_check_indices_abstract(


  def group_index_select_dim0_gpu_impl_abstract(
-     inputs: List[torch.Tensor], group_size: int
- ) -> List[torch.Tensor]:
+     inputs: list[torch.Tensor], group_size: int
+ ) -> list[torch.Tensor]:
      """
      Calculate output shapes for group_index_select_dim0_gpu_impl
      without the actual data.
@@ -869,8 +995,8 @@ def group_index_select_dim0_gpu_impl_abstract(


  def group_index_select_dim0_gpu_backward_abstract(
-     all_inputs: List[torch.Tensor], output_shape_group_ref: List[torch.SymInt]
- ) -> List[torch.Tensor]:
+     all_inputs: list[torch.Tensor], output_shape_group_ref: list[torch.SymInt]
+ ) -> list[torch.Tensor]:
      """
      Calculate output shapes for group_index_select_dim0_gpu_backward
      without the actual data.
@@ -903,7 +1029,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
      batch_size: torch.SymInt,
      weights: Optional[torch.Tensor] = None,
      selected_lengths_sum: Optional[torch.SymInt] = None,
- ) -> List[torch.Tensor]:
+ ) -> list[torch.Tensor]:
      num_batches = lengths.size(0) // batch_size
      torch._check(lengths.size(0) + 1 == offsets.size(0))
      # pyre-ignore
@@ -917,7 +1043,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
          selected_lengths_sum = torch.library.get_ctx().new_dynamic_size()

      torch._check_is_size(selected_lengths_sum)
-     vlw: List[torch.Tensor] = [
+     vlw: list[torch.Tensor] = [
          values.new_empty([selected_lengths_sum]),  # output
          lengths.new_empty([indices.shape[0] * num_batches]),  # output_lengths
      ]
@@ -960,7 +1086,7 @@ def histogram_binning_calibration_abstract(
      upper_bound: float,
      bin_ctr_in_use_after: int,
      bin_ctr_weight_value: float,
- ) -> Tuple[Tensor, Tensor]:
+ ) -> tuple[Tensor, Tensor]:
      return torch.empty_like(logit), torch.empty([logit.numel()], dtype=torch.int64)


@@ -1111,7 +1237,7 @@ def generic_histogram_binning_calibration_by_feature(
      positive_weight: float,
      bin_ctr_in_use_after: int,
      bin_ctr_weight_value: float,
- ) -> Tuple[Tensor, Tensor]:
+ ) -> tuple[Tensor, Tensor]:
      torch._check(bin_num_examples.numel() == bin_num_positives.numel())
      torch._check(
          bin_num_examples.numel() == (num_segments + 1) * (bin_boundaries.numel() + 1)
@@ -1121,6 +1247,58 @@
      )


+ def permute_multi_embedding_function_impl_abstract(
+     pooled_embs: list[Tensor],
+     permutes: Tensor,
+     in_shapes: Tensor,
+     out_shapes: Tensor,
+     out_lengths: list[int],
+     reverse: bool = False,
+ ) -> list[Tensor]:
+     out_dtype = pooled_embs[0].dtype
+     bs = pooled_embs[0].shape[0]
+     torch._check(permutes.shape[1] == 6, lambda: "permutes must have 6 columns")
+
+     output = []
+     for i in range(len(out_lengths)):
+         output.append(torch.empty([bs, out_lengths[i]], dtype=out_dtype))
+     return output
+
+
+ def lengths_range_abstract(
+     lengths: Tensor,
+     output_shape: Optional[Sequence[int]] = None,
+ ) -> Tensor:
+     torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
+     output_size = 0
+     if output_shape is not None:
+         output_size = math.prod(output_shape)
+     else:
+         ctx = torch.library.get_ctx()
+         output_size = ctx.new_dynamic_size()
+     return lengths.new_empty([output_size], dtype=lengths.dtype)
+
+
+ def all_to_one_device(
+     input_tensors: list[Tensor],
+     target_device: torch.device,
+ ) -> list[Tensor]:
+     return [
+         torch.empty_like(input_tensor, device=torch.device("meta"))
+         for input_tensor in input_tensors
+     ]
+
+
+ def sum_reduce_to_one(
+     input_tensors: list[Tensor],
+     target_device: torch.device,
+ ) -> Tensor:
+     torch._check(len(input_tensors) > 0, lambda: "reducing no tensor is undefined")
+     # All tensors should have the same shape
+     first_tensor = input_tensors[0]
+     return torch.empty_like(first_tensor, device=torch.device("meta"))
+
+
  def _setup() -> None:
      # pyre-ignore[16]
      _setup.done = getattr(_setup, "done", False)
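get_source_mask_meta, repeat_arange_meta, and lengths_range_abstract all handle data-dependent output lengths the same way: rather than reading tensor data (which a fake tensor does not have), they ask the fake-tensor context for an unbacked symbolic size. A minimal standalone sketch of the idiom, using a hypothetical meta function:

    import torch
    from torch import Tensor

    def my_op_meta(lengths: Tensor) -> Tensor:
        # Only valid while tracing with fake tensors: new_dynamic_size()
        # hands torch.compile an unbacked SymInt that stands in for the
        # data-dependent output length (here, the sum of `lengths`).
        ctx = torch.library.get_ctx()
        output_size = ctx.new_dynamic_size()
        return lengths.new_empty([output_size], dtype=lengths.dtype)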
@@ -1153,6 +1331,8 @@ def _setup() -> None:
          )

          impl_abstract("fbgemm::permute_2D_sparse_data", permute_2D_sparse_data_meta)
+         impl_abstract("fbgemm::get_source_mask", get_source_mask_meta)
+         impl_abstract("fbgemm::repeat_arange", repeat_arange_meta)
          impl_abstract(
              "fbgemm::permute_2D_sparse_data_input1D",
              permute_2D_sparse_data_input1D_meta,
@@ -1185,12 +1365,18 @@ def _setup() -> None:
              "fbgemm::block_bucketize_sparse_features",
              block_bucketize_sparse_features_meta,
          )
+         impl_abstract(
+             "fbgemm::block_bucketize_sparse_features_2d_weights",
+             block_bucketize_sparse_features_2d_weights_meta,
+         )
          impl_abstract("fbgemm::merge_pooled_embeddings", merge_pooled_embeddings)
          impl_abstract(
              "fbgemm::permute_sparse_features", permute_sparse_features_abstract
          )
          impl_abstract("fbgemm::segment_sum_csr", segment_sum_csr_abstract)
          impl_abstract("fbgemm::dense_to_jagged_forward", dense_to_jagged_forward)
+         impl_abstract("fbgemm::all_to_one_device", all_to_one_device)
+         impl_abstract("fbgemm::sum_reduce_to_one", sum_reduce_to_one)
          impl_abstract(
              "fbgemm::batch_index_select_dim0", batch_index_select_dim0_abstract
          )
@@ -1258,6 +1444,14 @@ def _setup() -> None:
              "fbgemm::generic_histogram_binning_calibration_by_feature",
              generic_histogram_binning_calibration_by_feature,
          )
+         impl_abstract(
+             "fbgemm::lengths_range",
+             lengths_range_abstract,
+         )
+         impl_abstract(
+             "fbgemm::permute_multi_embedding_function",
+             permute_multi_embedding_function_impl_abstract,
+         )
          impl_abstract(
              "fbgemm::FloatToHFP8Quantized",
              float_to_hfp8_quantized,
@@ -1302,29 +1496,3 @@


  _setup()
-
-
- @torch.library.register_fake("fbgemm::lengths_range")
- def lengths_range_abstract(
-     lengths: Tensor,
-     output_shape: Optional[Sequence[int]] = None,
- ) -> Tensor:
-     torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
-     output_size = 0
-     if output_shape is not None:
-         output_size = math.prod(output_shape)
-     else:
-         ctx = torch.library.get_ctx()
-         output_size = ctx.new_dynamic_size()
-     return lengths.new_empty([output_size], dtype=lengths.dtype)
-
-
- @torch.library.register_fake("fbgemm::all_to_one_device")
- def all_to_one_device(
-     input_tensors: List[Tensor],
-     target_device: torch.device,
- ) -> List[Tensor]:
-     return [
-         torch.empty_like(input_tensor, device=torch.device("meta"))
-         for input_tensor in input_tensors
-     ]
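The deletions above are the flip side of additions earlier in the diff: lengths_range_abstract and all_to_one_device previously self-registered at import time through @torch.library.register_fake decorators placed after the _setup() call; they now sit with the other meta functions and register through _setup()'s impl_abstract path (which the hasattr check near the top of the file appears to alias to register_fake on newer PyTorch). Both styles attach the same kind of fake kernel to an operator; a hedged sketch against a toy custom op (the demo namespace is hypothetical):

    import torch
    from torch import Tensor

    # A toy custom op so the fake registration below has a real target.
    @torch.library.custom_op("demo::double_it", mutates_args=())
    def double_it(x: Tensor) -> Tensor:
        return x * 2

    # Decorator style, as removed from the bottom of sparse_ops.py:
    @torch.library.register_fake("demo::double_it")
    def double_it_fake(x: Tensor) -> Tensor:
        # Fake kernels propagate shape/dtype/device only; no data access.
        return torch.empty_like(x)

    # Call style, which is what _setup() effectively does:
    #     torch.library.register_fake("demo::double_it")(double_it_fake)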
fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py CHANGED
@@ -4,6 +4,8 @@
  ## Template Source: training/python/__init__.template
  ################################################################################

+ __template_source_file__ = "training/python/__init__.template"
+
  #!/usr/bin/env python3

  # Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -150,6 +152,30 @@ except:
          DeprecationWarning,
      )

+ try:
+     # Import is placed under a try-except bc the op is experimental and can be
+     # removed/updated in the future
+     import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam_ssd as lookup_adam_ssd  # noqa: F401
+ except:
+     warnings.warn(
+         f"""\033[93m
+         Failed to import: fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam_ssd
+         \033[0m""",
+         DeprecationWarning,
+     )
+
+ try:
+     # Import is placed under a try-except bc the op is experimental and can be
+     # removed/updated in the future
+     import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_adam_ssd as lookup_partial_rowwise_adam_ssd  # noqa: F401
+ except:
+     warnings.warn(
+         f"""\033[93m
+         Failed to import: fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_adam_ssd
+         \033[0m""",
+         DeprecationWarning,
+     )
+
  try:
      # Import is placed under a try-except bc the op is experimental and can be
      # removed/updated in the future