fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/__init__.py CHANGED
@@ -5,17 +5,106 @@
  # This source code is licensed under the BSD-style license found in the
  # LICENSE file in the root directory of this source tree.
 
+ import json
  import logging
  import os
+ import re
 
  import torch
 
+ # Based on the FBGEMM-PyTorch compatibility table at
+ # https://docs.pytorch.org/FBGEMM/general/Releases.html#fbgemm-releases-compatibility
+ _fbgemm_torch_compat_table = {
+     "1.5": "2.10",
+     "1.4": "2.9",
+     "1.3": "2.8",
+     "1.2": "2.7",
+     "1.1": "2.6",
+     "1.0": "2.5",
+     "0.8": "2.4",
+     "0.7": "2.3",
+     "0.6": "2.2",
+     "0.5": "2.1",
+     "0.4": "2.0",
+ }
+
+
+ def _load_target_info(target: str) -> dict[str, str]:
+     try:
+         filepath = os.path.join(
+             os.path.dirname(__file__), "docs", f"target.{target}.json.py"
+         )
+         with open(filepath, "r") as file:
+             data = json.load(file)
+     except Exception:
+         data = {}
+
+     return data
 
- def _load_library(filename: str, no_throw: bool = False) -> None:
+
+ def _load_library(filename: str, version: str, no_throw: bool = False) -> None:
      """Load a shared library from the given filename."""
+
+     # Check if the version of PyTorch is compatible with the version of FBGEMM
+     # that we are trying to load, and print a loud warning if not. This is
+     # useful for the OSS build, where we have a single FBGEMM library that is
+     # compatible with multiple versions of PyTorch.
+     #
+     # Based on: https://github.com/pytorch/ao/blob/main/torchao/__init__.py#L30
+
+     keys = [
+         key
+         for key in _fbgemm_torch_compat_table.keys()
+         if version.startswith(f"{key}.")
+     ]
+
+     if version == "INTERNAL" or "+git" in version:
+         # if FBGEMM version has "+git", assume it's locally built and we don't know
+         # anything about the PyTorch version used to build it
+         logging.info(
+             "FBGEMM version is INTERNAL or local, ignoring version compatibility check with PyTorch"
+         )
+
+     elif re.match(r"^\d{4}\.\d{1,2}\.\d{1,2}.*$", version):
+         # if FBGEMM version is a date, assume it's a nightly build and that we
+         # know what we're doing
+         logging.info(
+             "FBGEMM version is a nightly version, ignoring version compatibility check with PyTorch"
+         )
+
+     elif not keys:
+         # fmt: off
+         logging.warning(
+             f"""
+             \033[33m
+             _fbgemm_torch_compat_table has no entry for {version} of FBGEMM;
+             cannot determine compatibility with PyTorch {torch.__version__}
+             \033[0m
+             """
+         )
+         # fmt: on
+
+     elif not str(torch.__version__).startswith(_fbgemm_torch_compat_table[keys[0]]):
+         # fmt: off
+         logging.warning(
+             f"""
+             \033[31m
+             FBGEMM_GPU version is {version}, which is not guaranteed to be
+             compatible with PyTorch {torch.__version__}; library loading might
+             crash!
+
+             Please refer to
+             https://docs.pytorch.org/FBGEMM/general/Releases.html#fbgemm-releases-compatibility
+             for the FBGEMM-PyTorch compatibility table.
+             \033[0m
+             """
+         )
+         # fmt: on
+
      try:
          torch.ops.load_library(os.path.join(os.path.dirname(__file__), filename))
          logging.info(f"Successfully loaded: '{filename}'")
+
      except Exception as error:
          logging.error(f"Could not load the library '{filename}'!\n\n\n{error}\n\n\n")
          if not no_throw:
@@ -29,13 +118,15 @@ open_source: bool = True
  # Trigger the manual addition of docstrings to pybind11-generated operators
  import fbgemm_gpu.docs # noqa: F401, E402
 
+ __targets_infos__ = {
+     target: _load_target_info(target) for target in ["default", "genai", "hstu"]
+ }
+ __targets_infos__ = {k: v for (k, v) in __targets_infos__.items() if v}
+
  try:
-     # Export the version string from the version file auto-generated by setup.py
-     from fbgemm_gpu.docs.version import ( # noqa: F401, E402
-         __target__,
-         __variant__,
-         __version__,
-     )
+     __target__, __info__ = next(iter(__targets_infos__.items()))
+     __variant__ = __info__["variant"]
+     __version__ = __info__["version"]
  except Exception:
      __variant__: str = "INTERNAL"
      __version__: str = "INTERNAL"
@@ -45,6 +136,7 @@ fbgemm_gpu_libraries = [
      "fbgemm_gpu_config",
      "fbgemm_gpu_tbe_utils",
      "fbgemm_gpu_tbe_index_select",
+     "fbgemm_gpu_tbe_cache",
      "fbgemm_gpu_tbe_optimizers",
      "fbgemm_gpu_tbe_inference",
      "fbgemm_gpu_tbe_training_forward",
@@ -76,18 +168,19 @@ libraries_to_load = {
      "genai": fbgemm_genai_libraries,
  }
 
- for library in libraries_to_load.get(__target__, []):
-     # NOTE: In all cases, we want to throw an error if we cannot load the
-     # library. However, this appears to break the OSS documentation build,
-     # where the Python documentation doesn't show up in the generated docs.
-     #
-     # To work around this problem, we introduce a fake build variant called
-     # `docs` and we only throw a library load error when the variant is not
-     # `docs`. For more information, see:
-     #
-     # https://github.com/pytorch/FBGEMM/pull/3477
-     # https://github.com/pytorch/FBGEMM/pull/3717
-     _load_library(f"{library}.so", __variant__ == "docs")
+ for target, info in __targets_infos__.items():
+     for library in libraries_to_load.get(target, []):
+         # NOTE: In all cases, we want to throw an error if we cannot load the
+         # library. However, this appears to break the OSS documentation build,
+         # where the Python documentation doesn't show up in the generated docs.
+         #
+         # To work around this problem, we introduce a fake build variant called
+         # `docs` and we only throw a library load error when the variant is not
+         # `docs`. For more information, see:
+         #
+         # https://github.com/pytorch/FBGEMM/pull/3477
+         # https://github.com/pytorch/FBGEMM/pull/3717
+         _load_library(f"{library}.so", info["version"], info["variant"] == "docs")
 
  try:
      # Trigger meta operator registrations
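For context, the compatibility check added above boils down to a prefix match of the FBGEMM release series against the table, with internal/local and nightly (date-based) builds skipped. A standalone sketch of the same logic, not the packaged function, with a deliberately truncated table:

import re

COMPAT = {"1.4": "2.9", "1.3": "2.8"}  # FBGEMM release series -> compatible PyTorch series

def check(fbgemm_version: str, torch_version: str) -> str:
    if fbgemm_version == "INTERNAL" or "+git" in fbgemm_version:
        return "skip: internal or locally built"
    if re.match(r"^\d{4}\.\d{1,2}\.\d{1,2}.*$", fbgemm_version):
        return "skip: nightly (date-based) version"
    series = [k for k in COMPAT if fbgemm_version.startswith(f"{k}.")]
    if not series:
        return "warn: unknown FBGEMM release series"
    if not torch_version.startswith(COMPAT[series[0]]):
        return "warn: possible FBGEMM/PyTorch mismatch"
    return "ok"

print(check("2026.1.29", "2.10.0"))  # skip: nightly (date-based) version
print(check("1.3.0", "2.9.0"))       # warn: possible FBGEMM/PyTorch mismatch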
fbgemm_gpu/asmjit.so CHANGED
Binary file
fbgemm_gpu/batched_unary_embeddings_ops.py CHANGED
@@ -9,10 +9,10 @@
 
 
  from math import sqrt
- from typing import List
 
  import torch
 
+ # fmt:skip
  from fbgemm_gpu.utils.loader import load_torch_module
 
  try:
@@ -22,7 +22,7 @@ except Exception:
      load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
 
 
- def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> List[torch.Tensor]:
+ def wrap_weight_to_parameter(weights: list[torch.Tensor]) -> list[torch.Tensor]:
      for i, v in enumerate(weights):
          if not isinstance(v, torch.nn.Parameter):
              weights[i] = torch.nn.Parameter(v)
@@ -31,7 +31,7 @@ def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> List[torch.Tensor]:
 
  class BatchedUnaryEmbeddingBag(torch.nn.Module):
      # pyre-fixme[3]: Return type must be annotated.
-     def __init__(self, num_tasks: int, hash_sizes: List[int], long_index: bool = False):
+     def __init__(self, num_tasks: int, hash_sizes: list[int], long_index: bool = False):
          super().__init__()
          self.num_tasks = num_tasks
          self.hash_sizes = hash_sizes
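The wrap_weight_to_parameter helper touched in this hunk only promotes plain tensors to nn.Parameter in place; a quick illustration:

import torch
from fbgemm_gpu.batched_unary_embeddings_ops import wrap_weight_to_parameter

weights = [torch.randn(4, 8), torch.nn.Parameter(torch.randn(4, 8))]
weights = wrap_weight_to_parameter(weights)
assert all(isinstance(w, torch.nn.Parameter) for w in weights)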
fbgemm_gpu/config/feature_list.py CHANGED
@@ -11,7 +11,7 @@ from enum import auto, Enum
  import torch
 
  try:
-     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:config_cpp")
+     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:config_cpp_torch_op")
  except Exception:
      import fbgemm_gpu # noqa F401
 
@@ -60,6 +60,12 @@ class FeatureGateName(Enum):
      # Enable bounds_check_indices_v2
      BOUNDS_CHECK_INDICES_V2 = auto()
 
+     # Enable TBE input parameters extraction
+     TBE_REPORT_INPUT_PARAMS = auto()
+
+     # Enable tuned max segment length per CTA for B200
+     TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200 = auto()
+
      def is_enabled(self) -> bool:
          return FeatureGate.is_enabled(self)
 
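The new gates are queried the same way as the existing ones, via the is_enabled() method shown above. A minimal sketch, assuming the usual re-export of FeatureGateName from fbgemm_gpu.config:

from fbgemm_gpu.config import FeatureGateName

# Gates introduced in this release
if FeatureGateName.TBE_REPORT_INPUT_PARAMS.is_enabled():
    print("TBE input-parameter reporting is enabled")
if FeatureGateName.TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200.is_enabled():
    print("Tuned max segment lengths per CTA for B200 are enabled")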
fbgemm_gpu/docs/jagged_tensor_ops.py CHANGED
@@ -9,7 +9,6 @@ import torch
 
  from .common import add_docs
 
-
  add_docs(
      torch.ops.fbgemm.jagged_2d_to_dense,
      """
fbgemm_gpu/docs/sparse_ops.py CHANGED
@@ -496,3 +496,121 @@
              None)
      """,
  )
+
+ add_docs(
+     torch.ops.fbgemm.block_bucketize_sparse_features_2d_weights,
+     """
+ block_bucketize_sparse_features_2d_weights(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights, weights_dim=1, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False, total_num_blocks=None, keep_orig_idx_per_feature=None) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]
+
+ Preprocess sparse features by partitioning sparse features into multiple
+ buckets with support for 2D weights. Every feature is split into the same number of buckets, but the bucket
+ sizes (widths) for the different features can be different. Moreover, the
+ bucket sizes within each feature can be different.
+
+ This function is similar to block_bucketize_sparse_features but supports 2D weights,
+ where each index can have multiple weight values associated with it.
+
+ Args:
+     lengths (Tensor): The lengths of the sparse features. The tensor contains
+         the lengths of each sample in a batch and each feature. Shape is `B *
+         T` where `B` is the batch size and `T` is the number of features
+
+     indices (Tensor): The sparse data. Only support integer types. Shape is the
+         sum of `lengths`
+
+     bucketize_pos (bool): If True, return the original relative indices within
+         a sample. For example, `indices = [9, 8, 2, 1, 0, 8, 9]` and `lengths =
+         [3, 4]`. The original relative indices within a sample for the indices
+         are `[0, 1, 2, 0, 1, 2, 3]`
+
+     sequence (bool): If True, return the new indices positions in the original
+         indices positions (the tensor is called `unbucketize_permute_data`).
+
+     block_sizes (Tensor): This tensor is used for the case where the bucket
+         size within a feature is uniform (i.e., when
+         `block_bucketize_pos=None`). The tensor contains bucket sizes (i.e.,
+         bucket widths) for each feature. `block_sizes[t]` represents the
+         bucket size of feature `t`. Shape is the number of features.
+
+     my_size (int): The number of buckets for each feature. Note that every
+         feature has the same number of buckets.
+
+     weights (Tensor): A float tensor that will be bucketized the same way as
+         `indices`. This tensor must have shape `[indices.size(0), weights_dim]`
+         where `weights_dim` is the dimension of the weight values for each index.
+
+     weights_dim (int = 1): The dimension of the weight values for each index.
+         This parameter is only used when `weights` is not None.
+
+     batch_size_per_feature (Optional[Tensor] = None): An optional tensor that
+         contains batch sizes for different features. If not None, batch sizes
+         are not uniform among features. Otherwise, the operator will assume
+         that the batch size is uniform and infer it from the `lengths` and
+         `block_sizes` tensors
+
+     max_B (int = -1): The max batch size. Must be set if
+         `batch_size_per_feature` is not None
+
+     block_bucketize_pos (Optional[List[Tensor]] = None): The input is used for
+         non-uniform bucket sizes within a feature. `block_bucketize_pos` is a
+         list of tensors. Each tensor contains the range offsets of buckets for
+         each feature. These range offsets are equivalent to the complete
+         cumulative sum of the bucket sizes. For example, `[0, 4, 20]` represents
+         two buckets. The first bucket size is `(4 - 0) = 4`, and the second
+         bucket size is `(20 - 4) = 16`. The length of `block_bucketize_pos`
+         must be equal to the number of features.
+
+     keep_orig_idx (bool = False): If True, return original indices instead of
+         the relative indices within each bucket
+
+     total_num_blocks (Optional[torch.Tensor] = None): An optional tensor that
+         contains then number of logical buckets (aka blocks) within a given
+         feature. This is useful for applications where the number of buckets
+         is more than the number of physical GPUs, which is common in cases
+         where we scale up/down the number of GPUs but want to maintain
+         same numerical behavior.
+
+     keep_orig_idx_per_feature (Optional[Tensor] = None): An optional tensor that
+         contains whether to keep original indices for each feature. If not None,
+         the operator will use this tensor to determine whether to keep original
+         indices for each feature. if None, will fallback to `keep_orig_idx`
+
+ Return:
+     A tuple of tensors containing
+
+     (1) Bucketized lengths. Shape is `lengths.num() * my_size`.
+
+     (2) Bucketized indices. Same shape as `indices`.
+
+     (3) Bucketized weights or None if `weights` is None. Shape is
+         `[indices.size(0), weights_dim]`.
+
+     (4) Bucketized positions or None if `bucketize_pos=False`. Same shape as
+         `indices`.
+
+     (5) `unbucketize_permute` or None if `sequence=False`. Same shape as
+         `indices`
+
+ **Example**:
+
+     >>> # Generate input example. Batch size = 2. Number of features = 4
+     >>> lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=torch.int, device="cuda")
+     >>> indices = torch.tensor([3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20], dtype=torch.int, device="cuda")
+     >>> block_sizes = torch.tensor([[5, 15, 10, 20]], dtype=torch.int, device="cuda")
+     >>> my_size = 2 # Number of buckets
+     >>> weights_dim = 3 # Dimension of weight values for each index
+     >>> weights = torch.randn(indices.size(0), weights_dim, dtype=torch.float, device="cuda")
+     >>> # Invoke with keep_orig_idx=False, bucketize_pos=False, and
+     >>> # sequence=False
+     >>> torch.ops.fbgemm.block_bucketize_sparse_features_2d_weights(
+     >>>     lengths,
+     >>>     indices,
+     >>>     bucketize_pos=False,
+     >>>     sequence=False,
+     >>>     block_sizes=block_sizes,
+     >>>     my_size=my_size,
+     >>>     weights=weights,
+     >>>     weights_dim=weights_dim,
+     >>>     keep_orig_idx=False)
+     """,
+ )
fbgemm_gpu/docs/target.default.json.py ADDED
@@ -0,0 +1,6 @@
+
+ {
+     "version": "2026.1.29",
+     "target": "default",
+     "variant": "cpu"
+ }
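This new file carries the metadata that _load_target_info() in __init__.py (first hunk above) reads at import time; the parsed values then surface as module attributes:

import fbgemm_gpu

print(fbgemm_gpu.__version__)  # "2026.1.29" for this wheel
print(fbgemm_gpu.__variant__)  # "cpu"
print(fbgemm_gpu.__target__)   # "default"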
fbgemm_gpu/enums.py CHANGED
@@ -8,14 +8,13 @@
  # pyre-strict
 
  import enum
- import typing
- from typing import Any, Callable, List, Tuple
+ from typing import Any, Callable
 
 
  # Create enums in given namespace with information from query_op
  def create_enums(
-     namespace: typing.Dict[str, Any],
-     query_op: Callable[[], List[Tuple[str, List[Tuple[str, int]]]]],
+     namespace: dict[str, Any],
+     query_op: Callable[[], list[tuple[str, list[tuple[str, int]]]]],
  ) -> None:
      for enum_name, items in query_op():
          # Create matching python enumeration
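create_enums() builds Python enums from whatever the query op reports; the updated annotations pin down the expected shape of that report. A sketch with a hypothetical query_op:

from fbgemm_gpu.enums import create_enums

def query_op() -> list[tuple[str, list[tuple[str, int]]]]:
    # Each entry: (enum class name, [(member name, member value), ...])
    return [("Color", [("RED", 0), ("GREEN", 1)])]

# Creates a matching `Color` enumeration in this module's namespace
create_enums(globals(), query_op)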
fbgemm_gpu/fbgemm.so CHANGED
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
fbgemm_gpu/permute_pooled_embedding_modules.py CHANGED
@@ -8,10 +8,11 @@
  # pyre-strict
 
  from itertools import accumulate
- from typing import List, Optional
+ from typing import Optional
 
  import torch
 
+ # fmt:skip
  from fbgemm_gpu.utils.loader import load_torch_module
 
  try:
@@ -93,8 +94,8 @@ class PermutePooledEmbeddings:
 
      def __init__(
          self,
-         embs_dims: List[int],
-         permute: List[int],
+         embs_dims: list[int],
+         permute: list[int],
          device: Optional[torch.device] = None,
      ) -> None:
          self._offset_dim_list: torch.Tensor = torch.tensor(
@@ -105,7 +106,7 @@ class PermutePooledEmbeddings:
              permute, device=device, dtype=torch.int64
          )
 
-         inv_permute: List[int] = [0] * len(permute)
+         inv_permute: list[int] = [0] * len(permute)
          for i, p in enumerate(permute):
              inv_permute[p] = i
 
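The inv_permute bookkeeping above is a plain inverse permutation, computed once in the constructor. For example:

permute = [2, 0, 1]        # output group i is taken from input group permute[i]
inv_permute = [0] * len(permute)
for i, p in enumerate(permute):
    inv_permute[p] = i
print(inv_permute)         # [1, 2, 0]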
fbgemm_gpu/permute_pooled_embedding_modules_split.py CHANGED
@@ -9,7 +9,7 @@
 
  import logging
  from itertools import accumulate
- from typing import List, Optional
+ from typing import Optional
 
  import torch
  from torch import nn
@@ -34,8 +34,8 @@ def _fx_wrap_tensor_to_device(t: torch.Tensor, device: torch.device) -> torch.Te
  class PermutePooledEmbeddingsSplit(nn.Module):
      def __init__(
          self,
-         embs_dims: List[int],
-         permute: List[int],
+         embs_dims: list[int],
+         permute: list[int],
          device: Optional[torch.device] = None,
      ) -> None:
          super(PermutePooledEmbeddingsSplit, self).__init__()
@@ -51,7 +51,7 @@ class PermutePooledEmbeddingsSplit(nn.Module):
              "_permute", torch.tensor(permute, device=device, dtype=torch.int64)
          )
 
-         inv_permute: List[int] = [0] * len(permute)
+         inv_permute: list[int] = [0] * len(permute)
          for i, p in enumerate(permute):
              inv_permute[p] = i
 
fbgemm_gpu/quantize/__init__.py CHANGED
@@ -11,6 +11,7 @@ from fbgemm_gpu.utils import TorchLibraryFragment
 
  lib = TorchLibraryFragment("fbgemm")
 
+ # fmt: off
  lib.define(
      """quantize_mx(
          Tensor input,
@@ -41,3 +42,4 @@ lib.register(
      "dequantize_mx",
      {"CUDA": dequantize_mx, "CPU": dequantize_mx},
  )
+ # fmt: on
fbgemm_gpu/quantize/quantize_ops.py CHANGED
@@ -9,6 +9,7 @@ from typing import Union
 
  import torch
 
+ # fmt:skip
  from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode
 
 
fbgemm_gpu/quantize_comm.py CHANGED
@@ -13,10 +13,11 @@
 
 
  import logging
- from typing import List, Optional, Tuple, TypeVar
+ from typing import Optional, TypeVar
 
  import torch
 
+ # fmt:skip
  from fbgemm_gpu.quantize_utils import (
      bf16_to_fp32,
      fp16_to_fp32,
@@ -25,12 +26,10 @@ from fbgemm_gpu.quantize_utils import (
      fp32_to_hfp8_with_clamp,
      fp32_to_mx4,
      hfp8_to_fp32,
-     mx4_to_fp32,
+     mx4_to_float,
      RoundingMode,
  )
-
  from fbgemm_gpu.split_embedding_configs import SparseType
-
  from torch.autograd.profiler import record_function # usort:skip
  from dataclasses import dataclass
 
@@ -66,8 +65,8 @@ class QuantizationContext:
      row_dim: int = ROW_DIM_DEFAULT
      row_dim_quant: int = -1
      mx_group_size: int = MX_GROUP_SIZE_DEFAULT
-     rounding_mode: RoundingMode = RoundingMode.even
-     padded_dim_sum_per_rank: Optional[List[int]] = None
+     rounding_mode: Optional[RoundingMode] = RoundingMode.even
+     padded_dim_sum_per_rank: Optional[list[int]] = None
 
 
  def _quantize_tensor(
@@ -123,6 +122,7 @@ def _dequantize_tensor(
      comm_precision: SparseType,
      ctx: Optional[QuantizationContext] = None,
      is_fwd: bool = True,
+     output_dtype: Optional[SparseType] = None,
  ) -> torch.Tensor:
      if comm_precision == SparseType.FP32:
          assert quantized_tensor.dtype == torch.float
@@ -137,8 +137,12 @@
          if ctx is not None and ctx.row_dim > 0:
              row_dim_quant = ctx.row_dim_quant
              quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+             # use provided output_dtype or default to FP32 (0)
+             output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
              dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
-                 quantized_tensor_2d, is_fwd
+                 quantized_tensor_2d,
+                 is_fwd,
+                 output_dtype_int,
              )
              return dequant_tensor.view(-1)
          else:
@@ -154,7 +158,7 @@
          return dequant_tensor.view(-1)
      elif comm_precision == SparseType.MX4:
          mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
-         return mx4_to_fp32(quantized_tensor, mx_group_size)
+         return mx4_to_float(quantized_tensor, mx_group_size, output_dtype=output_dtype)
      else:
          raise ValueError(f"comm_precision={comm_precision} is not supported")
 
@@ -167,6 +171,8 @@ class QuantizedCommCodec:
          loss_scale: Optional[float] = None,
          row_dim: Optional[int] = None,
          is_fwd: bool = True,
+         rounding_mode: Optional[RoundingMode] = None,
+         output_dtype: Optional[SparseType] = None,
      ) -> None:
          if loss_scale is not None:
              if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -183,8 +189,13 @@
          self._loss_scale = loss_scale
          self._is_fwd = is_fwd
          self._row_dim: int = -1 if row_dim is None else row_dim
+         self._rounding_mode: Optional[RoundingMode] = rounding_mode
+         self._output_dtype: Optional[SparseType] = output_dtype
          if self._comm_precision == SparseType.MX4:
              self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
+             self._rounding_mode = (
+                 RoundingMode.even if rounding_mode is None else rounding_mode
+             )
 
      def encode(
          self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
@@ -211,7 +222,11 @@
              f"## decoder {self._comm_precision} {self._loss_scale} ##"
          ):
              dequantized_tensor = _dequantize_tensor(
-                 input_tensor, self._comm_precision, ctx, self._is_fwd
+                 input_tensor,
+                 self._comm_precision,
+                 ctx,
+                 self._is_fwd,
+                 output_dtype=self._output_dtype,
              )
              return dequantized_tensor
 
@@ -258,7 +273,9 @@
              return QuantizationContext(self._row_dim)
          if self._comm_precision == SparseType.MX4:
              return QuantizationContext(
-                 row_dim=self._row_dim, mx_group_size=self._row_dim
+                 row_dim=self._row_dim,
+                 mx_group_size=self._row_dim,
+                 rounding_mode=self._rounding_mode,
              )
          # int8 rowwise is default
          return QuantizationContext()
@@ -266,10 +283,10 @@
      def padded_size(
          self,
          input_tensor: torch.Tensor,
-         dim_per_rank: List[int],
+         dim_per_rank: list[int],
          my_rank: int,
          qcomm_ctx: QuantizationContext,
-     ) -> Tuple[int, int]:
+     ) -> tuple[int, int]:
          if input_tensor.ndim == 1:
              return input_tensor.shape[0], 0
          # return padded size for the feature dimension (dim 1), 0 if no padding needed.
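Putting the new knobs together: rounding_mode is carried into the MX4 QuantizationContext, and output_dtype flows through decode() into _dequantize_tensor(), so decoded tensors no longer have to come back as FP32. A minimal sketch; the positional comm_precision argument and the decode()/create_context() method names are assumed from the existing codec API rather than shown in these hunks:

import torch
from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.quantize_utils import RoundingMode
from fbgemm_gpu.split_embedding_configs import SparseType

codec = QuantizedCommCodec(
    SparseType.MX4,                    # assumed first positional argument
    rounding_mode=RoundingMode.even,   # forwarded into the MX4 QuantizationContext
    output_dtype=SparseType.BF16,      # ask decode() for BF16 instead of FP32
)
ctx = codec.create_context()
encoded = codec.encode(torch.randn(1024), ctx)
decoded = codec.decode(encoded, ctx)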