fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/tbe/utils/requests.py CHANGED
@@ -8,7 +8,7 @@
 
  import logging
  from dataclasses import dataclass
- from typing import List, Optional, Tuple
+ from typing import Optional
 
  import numpy as np
  import numpy.typing as npt
@@ -32,20 +32,20 @@ class TBERequest:
  indices: torch.Tensor
  offsets: torch.Tensor
  per_sample_weights: Optional[torch.Tensor] = None
- Bs_per_feature_per_rank: Optional[List[List[int]]] = None
+ Bs_per_feature_per_rank: Optional[list[list[int]]] = None
 
- def unpack_2(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ def unpack_2(self) -> tuple[torch.Tensor, torch.Tensor]:
  return (self.indices, self.offsets)
 
  def unpack_3(
  self,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  return (self.indices, self.offsets, self.per_sample_weights)
 
  def unpack_4(
  self,
- ) -> Tuple[
- torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[List[List[int]]]
+ ) -> tuple[
+ torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]
  ]:
  return (
  self.indices,
@@ -56,21 +56,36 @@ class TBERequest:
 
 
  def generate_requests_from_data_file(
- requests_data_file: str,
  iters: int,
  B: int,
  T: int,
  L: int,
  E: int,
  weighted: bool,
+ requests_data_file: Optional[str] = None,
+ indices_file: Optional[str] = None,
+ offsets_file: Optional[str] = None,
  tables: Optional[str] = None,
  index_dtype: Optional[torch.dtype] = None,
  offset_dtype: Optional[torch.dtype] = None,
- ) -> List[TBERequest]:
+ ) -> list[TBERequest]:
  """
- Generate TBE requests from the input data file (`requests_data_file`)
+ Generate TBE requests from the input data file. If `requests_data_file` is provided,
+ `indices_file` and `offsets_file` should not be provided. If either `indices_file`
+ or `offsets_file` is provided, both must be provided.
  """
- indices_tensor, offsets_tensor, lengths_tensor = torch.load(requests_data_file)
+ assert not (
+ requests_data_file and (indices_file or offsets_file)
+ ), "If requests_data_file is provided, indices_file and offsets_file cannot be provided."
+ assert (
+ indices_file and offsets_file
+ ), "Both indices_file and offsets_file must be provided if either is provided."
+
+ if requests_data_file:
+ indices_tensor, offsets_tensor, *rest = torch.load(requests_data_file)
+ else:
+ indices_tensor = torch.load(indices_file)
+ offsets_tensor = torch.load(offsets_file)
 
  average_L = 0
  if tables is not None:
@@ -104,7 +119,7 @@ def generate_requests_from_data_file(
  average_L = int((offsets_tensor[-1] - offsets_tensor[0]) / B)
  assert (np.prod(offsets_tensor.size()) - 1) == np.prod((T, B)), (
  f"Data file (indices = {indices_tensor.size()}, "
- f"offsets = {offsets_tensor.size()}, lengths = {lengths_tensor.size()}) "
+ f"offsets = {offsets_tensor.size()}, lengths = {offsets_tensor.size() - 1}) "
  f"does not conform to inputs (T, B) = ({T}, {B})."
  )
 
@@ -163,12 +178,12 @@ def generate_int_data_from_stats(
 
  def generate_pooling_factors_from_stats(
  iters: int,
- Bs: List[int],
+ Bs: list[int],
  L: int,
  sigma_L: int,
  # distribution of pooling factors
  length_dist: str,
- ) -> Tuple[int, torch.Tensor]:
+ ) -> tuple[int, torch.Tensor]:
  """
  Generate pooling factors for the TBE requests from the given stats
  """
@@ -196,7 +211,7 @@ def generate_batch_sizes_from_stats(
  vbe_num_ranks: int,
  # Distribution of batch sizes
  batch_size_dist: str,
- ) -> Tuple[List[int], List[List[int]]]:
+ ) -> tuple[list[int], list[list[int]]]:
  """
  Generate batch sizes for features from the given stats
  """
@@ -219,7 +234,7 @@ def generate_batch_sizes_from_stats(
 
  def generate_indices_uniform(
  iters: int,
- Bs: List[int],
+ Bs: list[int],
  L: int,
  E: int,
  use_variable_L: bool,
@@ -237,7 +252,7 @@ def generate_indices_uniform(
  dtype=torch.int32,
  )
  # each bag is usually sorted
- (indices, _) = torch.sort(indices)
+ indices, _ = torch.sort(indices)
  if use_variable_L:
  # 1D layout, where row offsets are determined by L_offsets
  indices = torch.ops.fbgemm.bottom_k_per_row(
@@ -252,7 +267,7 @@ def generate_indices_zipf(
 
  def generate_indices_zipf(
  iters: int,
- Bs: List[int],
+ Bs: list[int],
  L: int,
  E: int,
  alpha: float,
@@ -309,7 +324,7 @@ def generate_indices_zipf(
 
  def update_indices_with_random_reuse(
  iters: int,
- Bs: List[int],
+ Bs: list[int],
  L: int,
  reuse: float,
  indices: torch.Tensor,
@@ -371,6 +386,9 @@ def generate_requests( # noqa C901
  zipf_oversample_ratio: int = 3,
  weighted: bool = False,
  requests_data_file: Optional[str] = None,
+ # Path to file containing indices and offsets. If provided, this will be used
+ indices_file: Optional[str] = None,
+ offsets_file: Optional[str] = None,
  # Comma-separated list of table numbers
  tables: Optional[str] = None,
  # If sigma_L is not None, treat L as mu_L and generate Ls from sigma_L
@@ -393,21 +411,28 @@ def generate_requests( # noqa C901
  vbe_num_ranks: Optional[int] = None,
  index_dtype: Optional[torch.dtype] = None,
  offset_dtype: Optional[torch.dtype] = None,
- ) -> List[TBERequest]:
+ ) -> list[TBERequest]:
  # TODO: refactor and split into helper functions to separate load from file,
  # generate from distribution, and other future methods of generating data
- if requests_data_file is not None:
+ if (
+ requests_data_file is not None
+ or indices_file is not None
+ or offsets_file is not None
+ ):
+
  assert sigma_L is None, "Variable pooling factors is not supported"
  assert sigma_B is None, "Variable batch sizes is not supported"
  return generate_requests_from_data_file(
- requests_data_file,
- iters,
- B,
- T,
- L,
- E,
- weighted,
- tables,
+ iters=iters,
+ B=B,
+ T=T,
+ L=L,
+ E=E,
+ weighted=weighted,
+ requests_data_file=requests_data_file,
+ indices_file=indices_file,
+ offsets_file=offsets_file,
+ tables=tables,
  index_dtype=index_dtype,
  offset_dtype=offset_dtype,
  )
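
Usage note (not part of the diff): with the new `indices_file` / `offsets_file` arguments above, a benchmark can feed pre-dumped tensors directly instead of a combined `requests_data_file`. A minimal sketch, assuming `generate_requests` is importable from `fbgemm_gpu.tbe.utils.requests` and that the two `.pt` files were previously written with `torch.save`:

    import torch
    from fbgemm_gpu.tbe.utils.requests import generate_requests  # assumed import path

    # indices.pt / offsets.pt are hypothetical files saved earlier with torch.save()
    requests = generate_requests(
        iters=10, B=512, T=2, L=1, E=100_000,
        indices_file="indices.pt",
        offsets_file="offsets.pt",
    )
    for req in requests:
        indices, offsets, per_sample_weights = req.unpack_3()
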
fbgemm_gpu/tbe_input_multiplexer.py CHANGED
@@ -8,9 +8,8 @@
  # pyre-unsafe
 
  import abc
-
  from dataclasses import dataclass
- from typing import List, Optional
+ from typing import Optional
 
  from torch import Tensor
 
@@ -22,15 +21,25 @@ class TBEInfo:
 
  Args:
  table_names: table names within the tbe
- table_heights: table heights (hashsize)
+ table_heights: sharded table heights (hashsize)
  tbe_uuid: a unique identifier for the TBE
  feature_table_map: feature to table map
+ table_dims: sharded table dimensions
+ full_table_heights: table heights before sharding
+ full_table_dims: table dimensions before sharding
+ row_offset: the shard offset of the current rank on row (height)
+ col_offset: the shard offset of the current rank on column (dim)
  """
 
- table_names: List[str]
- table_heights: List[int]
+ table_names: list[str]
+ table_heights: list[int]
  tbe_uuid: str
- feature_table_map: List[int]
+ feature_table_map: list[int]
+ table_dims: list[int]
+ full_table_heights: list[int]
+ full_table_dims: list[int]
+ row_offset: list[int]
+ col_offset: list[int]
 
 
  @dataclass(frozen=True)
@@ -45,7 +54,7 @@ class TBEInputInfo:
 
  indices: Tensor
  offsets: Tensor
- batch_size_per_feature_per_rank: Optional[List[List[int]]] = None
+ batch_size_per_feature_per_rank: Optional[list[list[int]]] = None
 
 
  class TBEInputMultiplexer(abc.ABC):
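
For orientation (not part of the diff): the widened `TBEInfo` record now carries both sharded and pre-sharding table shapes plus the shard offsets of the current rank. A hypothetical instantiation, assuming `TBEInfo` remains a plain dataclass importable from `fbgemm_gpu.tbe_input_multiplexer`; all numbers are illustrative only:

    from fbgemm_gpu.tbe_input_multiplexer import TBEInfo  # assumed import path

    # Two tables, row-sharded across ranks; values below are made up for illustration.
    info = TBEInfo(
        table_names=["user_id", "item_id"],
        table_heights=[250_000, 125_000],         # heights of the local shards
        tbe_uuid="tbe-0",
        feature_table_map=[0, 0, 1],              # three features mapped onto two tables
        table_dims=[128, 64],                     # sharded dims
        full_table_heights=[1_000_000, 500_000],  # heights before sharding
        full_table_dims=[128, 64],                # dims before sharding
        row_offset=[250_000, 125_000],            # this rank's row shard offsets
        col_offset=[0, 0],                        # no column sharding in this example
    )
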
fbgemm_gpu/triton/common.py CHANGED
@@ -10,7 +10,6 @@ from enum import IntEnum
 
  import torch
 
-
  # We keep LUTs persistent to minimize the number of device copies required.
  E2M1_LUT = torch.tensor(
  [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6],
fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py CHANGED
@@ -9,7 +9,7 @@
 
  # pyre-ignore-all-errors[6]
 
- from typing import List, Optional, Tuple, Union
+ from typing import Optional, Union
 
  import torch
  import triton # @manual
@@ -472,7 +472,7 @@ def triton_jagged_to_dense_optimization_2d(
  # In FBGEMM it was computed by GPU but in triton currently has some compilation issue so we use CUP computation method as workaround
  # However in real-world case if we only dealing with 2d jagged tensor we don't need to use this function at all
  def _jagged_offsets_to_dense_indice(
- offsets: List[torch.Tensor], dense_strides: List[int], dense_sizes: List[int]
+ offsets: list[torch.Tensor], dense_strides: list[int], dense_sizes: list[int]
  ) -> torch.Tensor:
 
  output_offset = torch.zeros(len(offsets[-1]) - 1, device="cpu", dtype=torch.int32)
@@ -532,8 +532,8 @@ def _jagged_offsets_to_dense_indice(
  # not be affected at all
  def jagged_to_dense(
  jagged_values: torch.Tensor,
- jagged_offsets: List[torch.Tensor],
- jagged_max_lengths: List[int],
+ jagged_offsets: list[torch.Tensor],
+ jagged_max_lengths: list[int],
  padding_value: float = 0.0, # padding value currently use 0.0 as default value
  operation_function: Union[
  str, None
@@ -720,10 +720,10 @@ def triton_dense_to_jagged(
 
  def dense_to_jagged(
  dense: torch.Tensor,
- jagged_offsets: List[torch.Tensor],
+ jagged_offsets: list[torch.Tensor],
  operation_function: Union[str, None] = None,
  operation_jagged_values: Union[torch.Tensor, None] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
 
  thread_block_row_size = 32
  thread_block_col_size = 32
@@ -780,7 +780,7 @@ def dense_to_jagged(
  # jagged_tensor + dense -> dense
  def jagged_dense_elementwise_add_dense_output(
  jagged_values: Tensor,
- jagged_offsets: List[Tensor],
+ jagged_offsets: list[Tensor],
  # pyre-fixme[2]: Parameter must be annotated.
  dense,
  ) -> Tensor:
@@ -800,8 +800,8 @@ def jagged_dense_elementwise_add_dense_output(
 
  # jagged_tensor + dense -> jagged_tensor
  def jagged_dense_elementwise_add_jagged_output(
- jagged_values: Optional[Tensor], jagged_offsets: List[Tensor], dense: Tensor
- ) -> Tuple[Tensor, List[Tensor]]:
+ jagged_values: Optional[Tensor], jagged_offsets: list[Tensor], dense: Tensor
+ ) -> tuple[Tensor, list[Tensor]]:
 
  return dense_to_jagged(
  dense,
@@ -813,8 +813,8 @@ def jagged_dense_elementwise_add_jagged_output(
 
  # jagged_tensor * dense -> jagged_tensor
  def jagged_dense_elementwise_mul_jagged_output(
- jagged_values: Optional[Tensor], jagged_offsets: List[Tensor], dense: Tensor
- ) -> Tuple[Tensor, List[Tensor]]:
+ jagged_values: Optional[Tensor], jagged_offsets: list[Tensor], dense: Tensor
+ ) -> tuple[Tensor, list[Tensor]]:
 
  return dense_to_jagged(
  dense,
fbgemm_gpu/triton/quantize.py CHANGED
@@ -11,7 +11,6 @@ from typing import Union
 
  import torch
  import triton # @manual
-
  import triton.language as tl # @manual
 
  from .common import get_mx4_exp_bias, get_mx4_lookup_table, RoundingMode
@@ -238,7 +237,7 @@ def _kernel_quantize_mx4(
  # We readd fp32_exp_bias for compatibility with cuda dequant.
  tl.store(
  out + exp_offset,
- (group_exp + FP32_EXP_BIAS).to(tl.int8),
+ (group_exp + FP32_EXP_BIAS).to(tl.uint8),
  # Prevent writing outside this chunk or the main array.
  mask=(exp_offset < OUTPUT_SIZE)
  & (exp_offset < (OUTPUT_CHUNK_SIZE * (pid + 1))),
@@ -575,7 +574,7 @@ def _kernel_dequantize_mx4(
  # Write final outputs.
  tl.store(
  out + output_offset,
- scaled_fp32,
+ scaled_fp32.to(out.dtype.element_ty),
  # Mask values that are out of this chunk or the main array.
  mask=(output_offset < OUTPUT_SIZE)
  & (output_offset < OUTPUT_CHUNK_SIZE * (pid + 1)),
@@ -588,10 +587,14 @@ def _kernel_dequantize_mx4(
 
 
  def triton_dequantize_mx4(
- a: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1
+ a: torch.Tensor,
+ group_size: int = 32,
+ ebits: int = 2,
+ mbits: int = 1,
+ output_dtype: torch.dtype = torch.float32,
  ) -> torch.Tensor:
  """
- Dequantize a tensor from mx4 format to fp32.
+ Dequantize a tensor from mx4 format to fp32 or bf16.
 
  Args:
  a (Tensor): [M / 2 + M / group_size] MX4 tensor packed into int8 values
@@ -599,13 +602,15 @@ def triton_dequantize_mx4(
  group_size (int): Size of chunks that use the same shared exponent.
  ebits (int): Number of bits to use for exponent in target mx4 format.
  mbits (int): Number of bits to use for mantissa in target mx4 format.
+ output_dtype (torch.dtype): Output dtype (FP32 or BF16).
+ Defaults to torch.float32 for backward compatibility.
 
  Returns:
- torch.Tensor: [M, K] dequantized fp32 tensor.
+ torch.Tensor: [M, K] dequantized tensor in the specified dtype.
  """
  # If given an empty shape, return an empty tensor.
  if a.numel() == 0:
- return torch.empty(a.shape, device=a.device, dtype=torch.float32)
+ return torch.empty(a.shape, device=a.device, dtype=output_dtype)
  # View a as 2D for simplicity.
  orig_shape = a.shape
  a = a.flatten()
@@ -622,9 +627,9 @@
  # Use a lookup table to convert
  mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, a.device)
 
- # Create output tensor.
+ # Create output tensor in target dtype.
  output_elems = num_groups * group_size
- out = torch.empty([output_elems], device=a.device, dtype=torch.float)
+ out = torch.empty([output_elems], device=a.device, dtype=output_dtype)
  # Check if we need to use int64 for indexing.
  use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1
  # Invoke triton dequantization kernel over rows.
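
Usage note (not part of the diff): the new `output_dtype` argument lets callers dequantize straight to bf16 without a separate cast. A sketch, assuming a companion `triton_quantize_mx4` packer in the same module produces the `[M / 2 + M / group_size]` layout expected here; this runs Triton kernels and so needs a GPU build:

    import torch
    from fbgemm_gpu.triton.quantize import (  # assumed import path
        triton_dequantize_mx4,
        triton_quantize_mx4,  # assumed companion packer
    )

    x = torch.randn(4, 256, device="cuda", dtype=torch.float32)
    packed = triton_quantize_mx4(x, group_size=32)              # MX4-packed payload
    y_bf16 = triton_dequantize_mx4(
        packed, group_size=32, output_dtype=torch.bfloat16      # new in this release
    )
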
fbgemm_gpu/utils/filestore.py CHANGED
@@ -11,7 +11,6 @@
  import io
  import logging
  import os
- import shutil
  from dataclasses import dataclass
  from pathlib import Path
  from typing import BinaryIO, Union
@@ -36,8 +35,6 @@ class FileStore:
  bucket: str
 
  def __post_init__(self) -> None:
- # self.bucket = bucket
-
  if not os.path.isdir(self.bucket):
  raise ValueError(f"Directory {self.bucket} does not exist")
 
@@ -78,7 +75,12 @@ class FileStore:
  elif isinstance(raw_input, Path):
  if not os.path.exists(raw_input):
  raise FileNotFoundError(f"File {raw_input} does not exist")
- shutil.copyfile(raw_input, filepath)
+ # Open the source file and destination file, and copy the contents
+ with open(raw_input, "rb") as src_file, open(
+ filepath, "wb"
+ ) as dst_file:
+ while chunk := src_file.read(4096): # Read 4 KB at a time
+ dst_file.write(chunk)
 
  elif isinstance(raw_input, io.BytesIO) or isinstance(raw_input, BinaryIO):
  with open(filepath, "wb") as file:
@@ -157,4 +159,53 @@ class FileStore:
  True if file exists, False otherwise.
  """
  filepath = f"{self.bucket}/{path}"
- return os.path.isfile(filepath)
+ return os.path.exists(filepath)
+
+ def create_directory(self, path: str) -> "FileStore":
+ """
+ Creates a directory in the file store.
+
+ Args:
+ path (str): The path of the node or symlink to a directory (relative
+ to `self.bucket`) to be created.
+
+ Returns:
+ self. This allows for method-chaining.
+ """
+ filepath = f"{self.bucket}/{path}"
+ event = f"creating directory {filepath}"
+ logger.info(f"FileStore: {event}")
+
+ try:
+ if not os.path.exists(filepath):
+ os.makedirs(filepath, exist_ok=True)
+ except Exception as e:
+ logger.error(f"FileStore: exception occurred when {event}: {e}")
+ raise e
+
+ return self
+
+ def remove_directory(self, path: str) -> "FileStore":
+ """
+ Removes a directory from the file store.
+
+ Args:
+ path (str): The path of the node or symlink to a directory (relative
+ to `self.bucket`) to be removed.
+
+ Returns:
+ self. This allows for method-chaining.
+ """
+ filepath = f"{self.bucket}/{path}"
+ event = f"deleting {filepath}"
+ logger.info(f"FileStore: {event}")
+
+ try:
+ if os.path.isdir(filepath):
+ os.rmdir(filepath)
+
+ except Exception as e:
+ logger.error(f"Manifold: exception occurred when {event}: {e}")
+ raise e
+
+ return self
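
Illustration (not part of the diff): the new directory helpers return `self`, so they chain like the existing file operations. A small sketch with a hypothetical bucket path; the bucket directory itself must already exist:

    from fbgemm_gpu.utils.filestore import FileStore  # assumed import path

    store = FileStore("/tmp")                          # hypothetical bucket
    store.create_directory("tbe_bench").create_directory("tbe_bench/run_0")
    # ... write artifacts under /tmp/tbe_bench/run_0 ...
    store.remove_directory("tbe_bench/run_0").remove_directory("tbe_bench")
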
fbgemm_gpu/utils/torch_library.py CHANGED
@@ -8,7 +8,7 @@
  # pyre-strict
 
  import re
- from typing import Callable, Dict
+ from typing import Callable
 
  import torch
 
@@ -112,7 +112,7 @@ class TorchLibraryFragment:
  self.lib.impl(op_name, fn, dispatch_key)
 
  # pyre-ignore[24]
- def register(self, op_name: str, functors: Dict[str, Callable]) -> None:
+ def register(self, op_name: str, functors: dict[str, Callable]) -> None:
  """
  Registers a set of dispatches for a defined operator.
 
fbgemm_gpu/utils/writeback_util.py ADDED
@@ -0,0 +1,124 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+
+ def writeback_update_gradient(
+ indices: torch.Tensor,
+ offsets: torch.Tensor,
+ grad: torch.Tensor,
+ feature_table_map: list[int],
+ ) -> torch.Tensor:
+ """
+ Update gradient tensor by deduplicating indices across all features/tables.
+ For duplicate indices, only the first occurrence receives the gradient to achieve the assign purpose via gradient update
+
+ NOTE: This function is not supporting VBE yet
+
+ Args:
+ indices (torch.Tensor): Embedding indices tensor
+ offsets (torch.Tensor): Offsets tensor for batched embeddings
+ grad (torch.Tensor): Gradient tensor to be updated
+ feature_table_map (list[int]): Mapping from feature to table
+
+ Returns:
+ torch.Tensor: Updated gradient tensor with duplicates masked out
+ """
+ if indices.numel() == 0:
+ return grad[0]
+ # get num of feature to estimate batch size
+ num_of_tables = len(feature_table_map)
+ assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
+ batch_size = offsets.shape[0] // num_of_tables
+ max_indices = indices.max()
+ non_empty_index = (offsets[1:] - offsets[:-1]).nonzero().flatten()
+ # disable dedup across different table
+ indices = ((offsets[non_empty_index]) // batch_size) * (1 + max_indices) + indices
+ grad = grad[0]
+ _, idx, counts = torch.unique(
+ indices, dim=0, sorted=True, return_inverse=True, return_counts=True
+ )
+ _, ind_sorted = torch.sort(idx, stable=True)
+ cum_sum = counts.cumsum(0)
+ cum_sum = torch.cat((torch.tensor([0]).to(indices.device), cum_sum[:-1]))
+ first_indicies = ind_sorted[cum_sum]
+ mask = torch.zeros_like(grad, device=grad.device)
+ original_index = non_empty_index[first_indicies]
+
+ mask[original_index] = grad[original_index]
+ return mask
+
+
+ def writeback_update_gradient_first_feature_only(
+ indices: torch.Tensor,
+ offsets: torch.Tensor,
+ grad: torch.Tensor,
+ feature_table_map: list[int],
+ ) -> torch.Tensor:
+ """
+ Special case of writeback_update_gradient where gradient only needs to be updated for the first feature. Other features will be forward-only
+
+ NOTE: This function is not supporting VBE yet
+
+ Args:
+ indices (torch.Tensor): Embedding indices tensor
+ offsets (torch.Tensor): Offsets tensor for batched embeddings
+ grad (torch.Tensor): Gradient tensor to be updated
+ feature_table_map (list[int]): Mapping from feature to table
+
+ Returns:
+ torch.Tensor: Updated gradient tensor with duplicates masked out
+ """
+ num_of_tables = len(feature_table_map)
+ batch_size = (offsets.shape[0] - 1) // num_of_tables
+ shrink_indices = indices[: offsets[batch_size]]
+ if shrink_indices.numel() == 0 or indices.numel() == 0:
+ return grad[0]
+ assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
+
+ grad = grad[0]
+ _, idx, counts = torch.unique(
+ shrink_indices, dim=0, sorted=True, return_inverse=True, return_counts=True
+ )
+ _, ind_sorted = torch.sort(idx, stable=True)
+ cum_sum = counts.cumsum(0)
+ cum_sum = torch.cat((torch.tensor([0]).to(shrink_indices.device), cum_sum[:-1]))
+ first_indicies = ind_sorted[cum_sum]
+ mask = torch.zeros_like(grad, device=grad.device)
+
+ mask[first_indicies] = grad[first_indicies]
+ return mask
+
+
+ def writeback_gradient(
+ grad: torch.Tensor,
+ indices: torch.Tensor,
+ offsets: torch.Tensor,
+ feature_table_map: list[int],
+ writeback_first_feature_only: bool = False,
+ ) -> tuple[torch.Tensor]:
+ """
+ Compute deduplicated gradient for writeback operation.
+
+ Args:
+ grad (torch.Tensor): Gradient tensor to be updated
+ indices (torch.Tensor): Embedding indices tensor
+ offsets (torch.Tensor): Offsets tensor for batched embeddings
+ feature_table_map (list[int]): Mapping from feature to table
+ writeback_first_feature_only (bool): If True, only first feature will apply gradient update, other features will be read-only
+
+ Returns:
+ tuple[torch.Tensor]: Tuple containing the updated gradient tensor
+ """
+ if writeback_first_feature_only:
+ return (
+ writeback_update_gradient_first_feature_only(
+ indices, offsets, grad, feature_table_map
+ ),
+ )
+ else:
+ return (writeback_update_gradient(indices, offsets, grad, feature_table_map),)
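
Worked example (not part of the package): how the new writeback helper deduplicates gradients. The sketch assumes the non-VBE, one-index-per-bag layout the code above relies on, with `grad[0]` holding one gradient row per bag; all shapes and values are illustrative.

    import torch
    from fbgemm_gpu.utils.writeback_util import writeback_gradient  # assumed import path

    feature_table_map = [0, 1]               # two features, one table each
    offsets = torch.tensor([0, 1, 2, 3, 4])  # one index per bag
    indices = torch.tensor([3, 3, 7, 7])     # duplicate indices within each table
    grad = torch.randn(1, 4, 8)              # grad[0]: one 8-dim row per bag

    (dedup_grad,) = writeback_gradient(grad, indices, offsets, feature_table_map)
    # Only the first occurrence of each (table, index) pair keeps its gradient row;
    # rows 1 and 3 come back as zeros in this example.
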
fbgemm_gpu/uvm.py CHANGED
@@ -12,6 +12,7 @@ from typing import Optional
 
  import torch
 
+ # fmt:skip
  from fbgemm_gpu.enums import create_enums
 
  try:
@@ -21,6 +22,8 @@ except Exception:
  torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")
 
  # Import all uvm enums from c++ library
+ # pyre-fixme[6]: For 2nd argument expected `() -> List[Tuple[str, List[Tuple[str,
+ # int]]]]` but got `OpOverloadPacket`.
  create_enums(globals(), torch.ops.fbgemm.fbgemm_gpu_uvm_enum_query)
 
 
{fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: fbgemm_gpu_nightly-cpu
- Version: 2025.3.27
+ Version: 2026.1.29
  Home-page: https://github.com/pytorch/fbgemm
  Author: FBGEMM Team
  Author-email: packages@pytorch.org
@@ -12,11 +12,11 @@ Classifier: Intended Audience :: Science/Research
  Classifier: License :: OSI Approved :: BSD License
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Classifier: Programming Language :: Python :: 3
- Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
+ Classifier: Programming Language :: Python :: 3.14
  Description-Content-Type: text/markdown
  Requires-Dist: numpy
  Dynamic: author
@@ -40,9 +40,6 @@ PyTorch GPU operator libraries for training and inference. The library provides
  efficient table batched embedding bag, data layout transformation, and
  quantization supports.
 
- FBGEMM_GPU is currently tested with CUDA 12.4 and 11.8 in CI, and with PyTorch
- packages (2.1+) that are built against those CUDA versions.
-
  See the full [Documentation](https://pytorch.org/FBGEMM) for more information
  on building, installing, and developing with FBGEMM_GPU, as well as the most
  up-to-date support matrix for this library.