fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/__init__.py CHANGED
@@ -5,19 +5,108 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
 import logging
 import os
+import re
 
 import torch
 
+# Based on the FBGEMM-PyTorch compatibility table at
+# https://docs.pytorch.org/FBGEMM/general/Releases.html#fbgemm-releases-compatibility
+_fbgemm_torch_compat_table = {
+    "1.5": "2.10",
+    "1.4": "2.9",
+    "1.3": "2.8",
+    "1.2": "2.7",
+    "1.1": "2.6",
+    "1.0": "2.5",
+    "0.8": "2.4",
+    "0.7": "2.3",
+    "0.6": "2.2",
+    "0.5": "2.1",
+    "0.4": "2.0",
+}
+
+
+def _load_target_info(target: str) -> dict[str, str]:
+    try:
+        filepath = os.path.join(
+            os.path.dirname(__file__), "docs", f"target.{target}.json.py"
+        )
+        with open(filepath, "r") as file:
+            data = json.load(file)
+    except Exception:
+        data = {}
+
+    return data
 
-def _load_library(filename: str, no_throw: bool = False) -> None:
+
+def _load_library(filename: str, version: str, no_throw: bool = False) -> None:
     """Load a shared library from the given filename."""
+
+    # Check if the version of PyTorch is compatible with the version of FBGEMM
+    # that we are trying to load, and print a loud warning if not. This is
+    # useful for the OSS build, where we have a single FBGEMM library that is
+    # compatible with multiple versions of PyTorch.
+    #
+    # Based on: https://github.com/pytorch/ao/blob/main/torchao/__init__.py#L30
+
+    keys = [
+        key
+        for key in _fbgemm_torch_compat_table.keys()
+        if version.startswith(f"{key}.")
+    ]
+
+    if version == "INTERNAL" or "+git" in version:
+        # if FBGEMM version has "+git", assume it's locally built and we don't know
+        # anything about the PyTorch version used to build it
+        logging.info(
+            "FBGEMM version is INTERNAL or local, ignoring version compatibility check with PyTorch"
+        )
+
+    elif re.match(r"^\d{4}\.\d{1,2}\.\d{1,2}.*$", version):
+        # if FBGEMM version is a date, assume it's a nightly build and that we
+        # know what we're doing
+        logging.info(
+            "FBGEMM version is a nightly version, ignoring version compatibility check with PyTorch"
+        )
+
+    elif not keys:
+        # fmt: off
+        logging.warning(
+            f"""
+            \033[33m
+            _fbgemm_torch_compat_table has no entry for {version} of FBGEMM;
+            cannot determine compatibility with PyTorch {torch.__version__}
+            \033[0m
+            """
+        )
+        # fmt: on
+
+    elif not str(torch.__version__).startswith(_fbgemm_torch_compat_table[keys[0]]):
+        # fmt: off
+        logging.warning(
+            f"""
+            \033[31m
+            FBGEMM_GPU version is {version}, which is not guaranteed to be
+            compatible with PyTorch {torch.__version__}; library loading might
+            crash!
+
+            Please refer to
+            https://docs.pytorch.org/FBGEMM/general/Releases.html#fbgemm-releases-compatibility
+            for the FBGEMM-PyTorch compatibility table.
+            \033[0m
+            """
+        )
+        # fmt: on
+
     try:
         torch.ops.load_library(os.path.join(os.path.dirname(__file__), filename))
         logging.info(f"Successfully loaded: '{filename}'")
+
     except Exception as error:
-        logging.error(f"Could not load the library '{filename}': {error}")
+        logging.error(f"Could not load the library '{filename}'!\n\n\n{error}\n\n\n")
         if not no_throw:
             raise error
 
@@ -29,17 +118,25 @@ open_source: bool = True
 # Trigger the manual addition of docstrings to pybind11-generated operators
 import fbgemm_gpu.docs  # noqa: F401, E402
 
+__targets_infos__ = {
+    target: _load_target_info(target) for target in ["default", "genai", "hstu"]
+}
+__targets_infos__ = {k: v for (k, v) in __targets_infos__.items() if v}
+
 try:
-    # Export the version string from the version file auto-generated by setup.py
-    from fbgemm_gpu.docs.version import __variant__, __version__  # noqa: F401, E402
+    __target__, __info__ = next(iter(__targets_infos__.items()))
+    __variant__ = __info__["variant"]
+    __version__ = __info__["version"]
 except Exception:
     __variant__: str = "INTERNAL"
     __version__: str = "INTERNAL"
+    __target__: str = "INTERNAL"
 
 fbgemm_gpu_libraries = [
     "fbgemm_gpu_config",
     "fbgemm_gpu_tbe_utils",
     "fbgemm_gpu_tbe_index_select",
+    "fbgemm_gpu_tbe_cache",
     "fbgemm_gpu_tbe_optimizers",
     "fbgemm_gpu_tbe_inference",
     "fbgemm_gpu_tbe_training_forward",
@@ -52,7 +149,7 @@ fbgemm_gpu_libraries = [
     "fbgemm_gpu_py",
 ]
 
-fbgemm_gpu_genai_libraries = [
+fbgemm_genai_libraries = [
     "experimental/gen_ai/fbgemm_gpu_experimental_gen_ai",
 ]
 
@@ -64,28 +161,26 @@ fbgemm_gpu_genai_libraries = [
 # .SO file for the ROCm case, so that clients can import
 # fbgemm_gpu.experimental.gemm without triggering an error.
 if torch.cuda.is_available() and torch.version.hip:
-    fbgemm_gpu_genai_libraries = []
+    fbgemm_genai_libraries = []
 
 libraries_to_load = {
-    "cpu": fbgemm_gpu_libraries,
-    "docs": fbgemm_gpu_libraries,
-    "cuda": fbgemm_gpu_libraries + fbgemm_gpu_genai_libraries,
-    "genai": fbgemm_gpu_genai_libraries,
-    "rocm": fbgemm_gpu_libraries,
+    "default": fbgemm_gpu_libraries,
+    "genai": fbgemm_genai_libraries,
 }
 
-for library in libraries_to_load.get(__variant__, []):
-    # NOTE: In all cases, we want to throw an error if we cannot load the
-    # library. However, this appears to break the OSS documentation build,
-    # where the Python documentation doesn't show up in the generated docs.
-    #
-    # To work around this problem, we introduce a fake build variant called
-    # `docs` and we only throw a library load error when the variant is not
-    # `docs`. For more information, see:
-    #
-    # https://github.com/pytorch/FBGEMM/pull/3477
-    # https://github.com/pytorch/FBGEMM/pull/3717
-    _load_library(f"{library}.so", __variant__ == "docs")
+for target, info in __targets_infos__.items():
+    for library in libraries_to_load.get(target, []):
+        # NOTE: In all cases, we want to throw an error if we cannot load the
+        # library. However, this appears to break the OSS documentation build,
+        # where the Python documentation doesn't show up in the generated docs.
+        #
+        # To work around this problem, we introduce a fake build variant called
+        # `docs` and we only throw a library load error when the variant is not
+        # `docs`. For more information, see:
+        #
+        # https://github.com/pytorch/FBGEMM/pull/3477
+        # https://github.com/pytorch/FBGEMM/pull/3717
+        _load_library(f"{library}.so", info["version"], info["variant"] == "docs")
 
 try:
     # Trigger meta operator registrations
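Note: the compatibility check added to `_load_library` above is self-contained and can be exercised outside the package. The following is a standalone sketch that copies the matching rules and an abbreviated version of `_fbgemm_torch_compat_table` from this diff; the `check_compat` helper name is illustrative and not part of FBGEMM.

import re

# Abbreviated copy of _fbgemm_torch_compat_table from the diff above
_compat_table = {"1.5": "2.10", "1.4": "2.9", "1.3": "2.8"}

def check_compat(fbgemm_version: str, torch_version: str) -> str:
    # Internal or locally built FBGEMM: skip the check entirely
    if fbgemm_version == "INTERNAL" or "+git" in fbgemm_version:
        return "skip: internal or locally built FBGEMM"
    # Date-shaped versions (e.g. 2026.1.29) are nightlies and also skip the check
    if re.match(r"^\d{4}\.\d{1,2}\.\d{1,2}.*$", fbgemm_version):
        return "skip: nightly build"
    keys = [k for k in _compat_table if fbgemm_version.startswith(f"{k}.")]
    if not keys:
        return "warn: no compatibility entry for this FBGEMM version"
    if not torch_version.startswith(_compat_table[keys[0]]):
        return "warn: PyTorch version may be incompatible"
    return "ok"

print(check_compat("2026.1.29", "2.10.0"))  # skip: nightly build
print(check_compat("1.4.0", "2.8.1"))       # warn: PyTorch version may be incompatible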
fbgemm_gpu/asmjit.so CHANGED
Binary file
fbgemm_gpu/batched_unary_embeddings_ops.py CHANGED
@@ -9,10 +9,10 @@
 
 
 from math import sqrt
-from typing import List
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.utils.loader import load_torch_module
 
 try:
@@ -22,7 +22,7 @@ except Exception:
     load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
 
 
-def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> List[torch.Tensor]:
+def wrap_weight_to_parameter(weights: list[torch.Tensor]) -> list[torch.Tensor]:
     for i, v in enumerate(weights):
         if not isinstance(v, torch.nn.Parameter):
             weights[i] = torch.nn.Parameter(v)
@@ -31,7 +31,7 @@ def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> List[torch.Tensor]:
 
 class BatchedUnaryEmbeddingBag(torch.nn.Module):
     # pyre-fixme[3]: Return type must be annotated.
-    def __init__(self, num_tasks: int, hash_sizes: List[int], long_index: bool = False):
+    def __init__(self, num_tasks: int, hash_sizes: list[int], long_index: bool = False):
         super().__init__()
         self.num_tasks = num_tasks
         self.hash_sizes = hash_sizes
fbgemm_gpu/config/feature_list.py CHANGED
@@ -11,7 +11,7 @@ from enum import auto, Enum
 import torch
 
 try:
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:config_cpp")
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:config_cpp_torch_op")
 except Exception:
     import fbgemm_gpu  # noqa F401
 
@@ -60,6 +60,12 @@ class FeatureGateName(Enum):
     # Enable bounds_check_indices_v2
     BOUNDS_CHECK_INDICES_V2 = auto()
 
+    # Enable TBE input parameters extraction
+    TBE_REPORT_INPUT_PARAMS = auto()
+
+    # Enable tuned max segment length per CTA for B200
+    TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200 = auto()
+
     def is_enabled(self) -> bool:
         return FeatureGate.is_enabled(self)
 
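Note: the two new gates follow the existing `FeatureGateName` pattern and are queried through the `is_enabled()` accessor shown above. A minimal sketch, assuming `FeatureGateName` is re-exported from `fbgemm_gpu.config` as in current FBGEMM releases:

from fbgemm_gpu.config import FeatureGateName

for gate in (
    FeatureGateName.TBE_REPORT_INPUT_PARAMS,
    FeatureGateName.TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200,
):
    # Defaults come from the C++ side (config_cpp_torch_op) or the OSS fallback
    print(gate.name, gate.is_enabled())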
fbgemm_gpu/docs/jagged_tensor_ops.py CHANGED
@@ -9,7 +9,6 @@ import torch
 
 from .common import add_docs
 
-
 add_docs(
     torch.ops.fbgemm.jagged_2d_to_dense,
     """
fbgemm_gpu/docs/sparse_ops.py CHANGED
@@ -323,7 +323,7 @@ Returns:
 
 add_docs(
     torch.ops.fbgemm.block_bucketize_sparse_features,
     """
-block_bucketize_sparse_features(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights=None, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]
+block_bucketize_sparse_features(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights=None, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False, total_num_blocks=None, keep_orig_idx_per_feature=None) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]
 
 Preprocess sparse features by partitioning sparse features into multiple
 buckets. Every feature is split into the same number of buckets, but the bucket
@@ -387,6 +387,11 @@ Args:
         where we scale up/down the number of GPUs but want to maintain
         same numerical behavior.
 
+    keep_orig_idx_per_feature (Optional[Tensor] = None): An optional tensor that
+        contains whether to keep original indices for each feature. If not None,
+        the operator will use this tensor to determine whether to keep original
+        indices for each feature. if None, will fallback to `keep_orig_idx`
+
 Return:
     A tuple of tensors containing
 
@@ -448,6 +453,24 @@ Return:
             dtype=torch.int32),
     tensor([ 0, 1, 5, 2, 6, 7, 3, 8, 9, 10, 11, 4, 12, 13, 14],
             device='cuda:0', dtype=torch.int32))
+    >>> # Invoke with keep_orig_idx_per_feature
+    >>> keep_orig_idx_per_feature = torch.tensor([False, True, False, True], dtype=torch.bool)
+    >>> torch.ops.fbgemm.block_bucketize_sparse_features(
+    >>>     lengths,
+    >>>     indices,
+    >>>     bucketize_pos=False,
+    >>>     sequence=False,
+    >>>     block_sizes=block_sizes,
+    >>>     my_size=my_size,
+    >>>     keep_orig_idx=False,
+    >>>     keep_orig_idx_per_feature=keep_orig_idx_per_feature)
+    (tensor([0, 0, 0, 1, 1, 1, 2, 1, 0, 2, 1, 2, 1, 2, 1, 0], device='cuda:0',
+            dtype=torch.int32),
+    tensor([ 3, 4, 11, 1, 11, 15, 28, 29, 0, 1, 2, 3, 22, 20, 20],
+            device='cuda:0', dtype=torch.int32),
+    None,
+    None,
+    None)
     >>> # Invoke with block_bucketize_pos
     >>> block_bucketize_pos = [
     >>>     torch.tensor([0, 2, 8], dtype=torch.int),
@@ -473,3 +496,121 @@ Return:
     None)
 """,
 )
+
+add_docs(
+    torch.ops.fbgemm.block_bucketize_sparse_features_2d_weights,
+    """
+block_bucketize_sparse_features_2d_weights(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights, weights_dim=1, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False, total_num_blocks=None, keep_orig_idx_per_feature=None) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]
+
+Preprocess sparse features by partitioning sparse features into multiple
+buckets with support for 2D weights. Every feature is split into the same number of buckets, but the bucket
+sizes (widths) for the different features can be different. Moreover, the
+bucket sizes within each feature can be different.
+
+This function is similar to block_bucketize_sparse_features but supports 2D weights,
+where each index can have multiple weight values associated with it.
+
+Args:
+    lengths (Tensor): The lengths of the sparse features. The tensor contains
+        the lengths of each sample in a batch and each feature. Shape is `B *
+        T` where `B` is the batch size and `T` is the number of features
+
+    indices (Tensor): The sparse data. Only support integer types. Shape is the
+        sum of `lengths`
+
+    bucketize_pos (bool): If True, return the original relative indices within
+        a sample. For example, `indices = [9, 8, 2, 1, 0, 8, 9]` and `lengths =
+        [3, 4]`. The original relative indices within a sample for the indices
+        are `[0, 1, 2, 0, 1, 2, 3]`
+
+    sequence (bool): If True, return the new indices positions in the original
+        indices positions (the tensor is called `unbucketize_permute_data`).
+
+    block_sizes (Tensor): This tensor is used for the case where the bucket
+        size within a feature is uniform (i.e., when
+        `block_bucketize_pos=None`). The tensor contains bucket sizes (i.e.,
+        bucket widths) for each feature. `block_sizes[t]` represents the
+        bucket size of feature `t`. Shape is the number of features.
+
+    my_size (int): The number of buckets for each feature. Note that every
+        feature has the same number of buckets.
+
+    weights (Tensor): A float tensor that will be bucketized the same way as
+        `indices`. This tensor must have shape `[indices.size(0), weights_dim]`
+        where `weights_dim` is the dimension of the weight values for each index.
+
+    weights_dim (int = 1): The dimension of the weight values for each index.
+        This parameter is only used when `weights` is not None.
+
+    batch_size_per_feature (Optional[Tensor] = None): An optional tensor that
+        contains batch sizes for different features. If not None, batch sizes
+        are not uniform among features. Otherwise, the operator will assume
+        that the batch size is uniform and infer it from the `lengths` and
+        `block_sizes` tensors
+
+    max_B (int = -1): The max batch size. Must be set if
+        `batch_size_per_feature` is not None
+
+    block_bucketize_pos (Optional[List[Tensor]] = None): The input is used for
+        non-uniform bucket sizes within a feature. `block_bucketize_pos` is a
+        list of tensors. Each tensor contains the range offsets of buckets for
+        each feature. These range offsets are equivalent to the complete
+        cumulative sum of the bucket sizes. For example, `[0, 4, 20]` represents
+        two buckets. The first bucket size is `(4 - 0) = 4`, and the second
+        bucket size is `(20 - 4) = 16`. The length of `block_bucketize_pos`
+        must be equal to the number of features.
+
+    keep_orig_idx (bool = False): If True, return original indices instead of
+        the relative indices within each bucket
+
+    total_num_blocks (Optional[torch.Tensor] = None): An optional tensor that
+        contains then number of logical buckets (aka blocks) within a given
+        feature. This is useful for applications where the number of buckets
+        is more than the number of physical GPUs, which is common in cases
+        where we scale up/down the number of GPUs but want to maintain
+        same numerical behavior.
+
+    keep_orig_idx_per_feature (Optional[Tensor] = None): An optional tensor that
+        contains whether to keep original indices for each feature. If not None,
+        the operator will use this tensor to determine whether to keep original
+        indices for each feature. if None, will fallback to `keep_orig_idx`
+
+Return:
+    A tuple of tensors containing
+
+    (1) Bucketized lengths. Shape is `lengths.num() * my_size`.
+
+    (2) Bucketized indices. Same shape as `indices`.
+
+    (3) Bucketized weights or None if `weights` is None. Shape is
+        `[indices.size(0), weights_dim]`.
+
+    (4) Bucketized positions or None if `bucketize_pos=False`. Same shape as
+        `indices`.
+
+    (5) `unbucketize_permute` or None if `sequence=False`. Same shape as
+        `indices`
+
+**Example**:
+
+    >>> # Generate input example. Batch size = 2. Number of features = 4
+    >>> lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=torch.int, device="cuda")
+    >>> indices = torch.tensor([3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20], dtype=torch.int, device="cuda")
+    >>> block_sizes = torch.tensor([[5, 15, 10, 20]], dtype=torch.int, device="cuda")
+    >>> my_size = 2  # Number of buckets
+    >>> weights_dim = 3  # Dimension of weight values for each index
+    >>> weights = torch.randn(indices.size(0), weights_dim, dtype=torch.float, device="cuda")
+    >>> # Invoke with keep_orig_idx=False, bucketize_pos=False, and
+    >>> # sequence=False
+    >>> torch.ops.fbgemm.block_bucketize_sparse_features_2d_weights(
+    >>>     lengths,
+    >>>     indices,
+    >>>     bucketize_pos=False,
+    >>>     sequence=False,
+    >>>     block_sizes=block_sizes,
+    >>>     my_size=my_size,
+    >>>     weights=weights,
+    >>>     weights_dim=weights_dim,
+    >>>     keep_orig_idx=False)
+""",
+)
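Note: the docstring examples above run on CUDA. Since this wheel is CPU-only, here is a hedged CPU sketch of the new `keep_orig_idx_per_feature` argument, with inputs copied from the docstring; it assumes the CPU kernel of `block_bucketize_sparse_features` accepts the same keyword arguments, and the outputs will match the documented CUDA example only if the CPU path bucketizes identically.

import torch
import fbgemm_gpu  # noqa: F401  # loads the fbgemm operator library

lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=torch.int)
indices = torch.tensor([3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20], dtype=torch.int)
block_sizes = torch.tensor([[5, 15, 10, 20]], dtype=torch.int)
keep_orig_idx_per_feature = torch.tensor([False, True, False, True], dtype=torch.bool)

bucketized = torch.ops.fbgemm.block_bucketize_sparse_features(
    lengths,
    indices,
    bucketize_pos=False,
    sequence=False,
    block_sizes=block_sizes,
    my_size=2,
    keep_orig_idx=False,
    keep_orig_idx_per_feature=keep_orig_idx_per_feature,
)
print(bucketized[0])  # bucketized lengths
print(bucketized[1])  # bucketized indices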
fbgemm_gpu/docs/target.default.json.py ADDED
@@ -0,0 +1,6 @@
+
+{
+    "version": "2026.1.29",
+    "target": "default",
+    "variant": "cpu"
+}
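Note: this generated metadata file is what the new `_load_target_info()` in `__init__.py` reads to populate `__version__`, `__variant__`, and `__target__`. A small sketch of inspecting it from an installed wheel, using only the path and keys shown in this diff:

import json
import os

import fbgemm_gpu

path = os.path.join(os.path.dirname(fbgemm_gpu.__file__), "docs", "target.default.json.py")
with open(path) as f:
    info = json.load(f)

print(info["version"], info["target"], info["variant"])  # e.g. 2026.1.29 default cpu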
fbgemm_gpu/enums.py CHANGED
@@ -8,14 +8,13 @@
 # pyre-strict
 
 import enum
-import typing
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable
 
 
 # Create enums in given namespace with information from query_op
 def create_enums(
-    namespace: typing.Dict[str, Any],
-    query_op: Callable[[], List[Tuple[str, List[Tuple[str, int]]]]],
+    namespace: dict[str, Any],
+    query_op: Callable[[], list[tuple[str, list[tuple[str, int]]]]],
 ) -> None:
     for enum_name, items in query_op():
         # Create matching python enumeration
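Note: only the annotations changed here (typing.Dict/List/Tuple to the builtin generics). A hedged usage sketch of `create_enums`, assuming it installs one Python enum per `(name, items)` pair returned by the query op; the query op below is a stand-in for illustration, not a real FBGEMM operator.

from fbgemm_gpu.enums import create_enums

def _example_query_op() -> list[tuple[str, list[tuple[str, int]]]]:
    # One enum named "ExampleEnum" with two (member name, value) pairs
    return [("ExampleEnum", [("FIRST", 0), ("SECOND", 1)])]

create_enums(globals(), _example_query_op)
print(list(globals()["ExampleEnum"]))  # the new enum is injected into this namespace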
fbgemm_gpu/fbgemm.so CHANGED
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
fbgemm_gpu/permute_pooled_embedding_modules.py CHANGED
@@ -8,10 +8,11 @@
 # pyre-strict
 
 from itertools import accumulate
-from typing import List, Optional
+from typing import Optional
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.utils.loader import load_torch_module
 
 try:
@@ -93,8 +94,8 @@ class PermutePooledEmbeddings:
 
     def __init__(
         self,
-        embs_dims: List[int],
-        permute: List[int],
+        embs_dims: list[int],
+        permute: list[int],
         device: Optional[torch.device] = None,
     ) -> None:
         self._offset_dim_list: torch.Tensor = torch.tensor(
@@ -105,7 +106,7 @@ class PermutePooledEmbeddings:
             permute, device=device, dtype=torch.int64
         )
 
-        inv_permute: List[int] = [0] * len(permute)
+        inv_permute: list[int] = [0] * len(permute)
         for i, p in enumerate(permute):
             inv_permute[p] = i
 
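Note: the changes to `PermutePooledEmbeddings` are annotation-only (List to list). For context, a hedged usage sketch, assuming the object is called on a `[batch, sum(embs_dims)]` pooled tensor as in the FBGEMM documentation:

import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings

embs_dims = [2, 3, 1]        # per-group embedding widths, sum = 6
permute = [2, 0, 1]          # new group order
pooled = torch.arange(12.0).reshape(2, 6)  # batch of 2

permuted = PermutePooledEmbeddings(embs_dims, permute)(pooled)
print(permuted.shape)  # torch.Size([2, 6]); columns regrouped per `permute`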
fbgemm_gpu/permute_pooled_embedding_modules_split.py CHANGED
@@ -9,7 +9,7 @@
 
 import logging
 from itertools import accumulate
-from typing import List, Optional
+from typing import Optional
 
 import torch
 from torch import nn
@@ -34,8 +34,8 @@ def _fx_wrap_tensor_to_device(t: torch.Tensor, device: torch.device) -> torch.Te
 class PermutePooledEmbeddingsSplit(nn.Module):
     def __init__(
         self,
-        embs_dims: List[int],
-        permute: List[int],
+        embs_dims: list[int],
+        permute: list[int],
         device: Optional[torch.device] = None,
     ) -> None:
         super(PermutePooledEmbeddingsSplit, self).__init__()
@@ -51,7 +51,7 @@ class PermutePooledEmbeddingsSplit(nn.Module):
             "_permute", torch.tensor(permute, device=device, dtype=torch.int64)
         )
 
-        inv_permute: List[int] = [0] * len(permute)
+        inv_permute: list[int] = [0] * len(permute)
         for i, p in enumerate(permute):
             inv_permute[p] = i
 
fbgemm_gpu/quantize/__init__.py CHANGED
@@ -11,6 +11,7 @@ from fbgemm_gpu.utils import TorchLibraryFragment
 
 lib = TorchLibraryFragment("fbgemm")
 
+# fmt: off
 lib.define(
     """quantize_mx(
         Tensor input,
@@ -41,3 +42,4 @@ lib.register(
     "dequantize_mx",
     {"CUDA": dequantize_mx, "CPU": dequantize_mx},
 )
+# fmt: on
fbgemm_gpu/quantize/quantize_ops.py CHANGED
@@ -9,6 +9,7 @@ from typing import Union
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode
 