fbgemm-gpu-genai-nightly 2026.1.4__cp313-cp313-manylinux_2_28_x86_64.whl → 2026.1.9__cp313-cp313-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 
 {
-  "version": "2026.1.4",
+  "version": "2026.1.9",
   "target": "genai",
   "variant": "cuda"
 }
fbgemm_gpu/fbgemm.so CHANGED
Binary file
fbgemm_gpu/quantize_comm.py CHANGED
@@ -25,7 +25,7 @@ from fbgemm_gpu.quantize_utils import (
     fp32_to_hfp8_with_clamp,
     fp32_to_mx4,
     hfp8_to_fp32,
-    mx4_to_fp32,
+    mx4_to_float,
     RoundingMode,
 )
 
@@ -123,7 +123,7 @@ def _dequantize_tensor(
     comm_precision: SparseType,
     ctx: Optional[QuantizationContext] = None,
     is_fwd: bool = True,
-    fp8_output_dtype: Optional[SparseType] = None,
+    output_dtype: Optional[SparseType] = None,
 ) -> torch.Tensor:
     if comm_precision == SparseType.FP32:
         assert quantized_tensor.dtype == torch.float
@@ -138,10 +138,8 @@ def _dequantize_tensor(
         if ctx is not None and ctx.row_dim > 0:
             row_dim_quant = ctx.row_dim_quant
             quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
-            # use provided fp8_output_dtype or default to FP32 (0)
-            output_dtype_int = (
-                fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
-            )
+            # use provided output_dtype or default to FP32 (0)
+            output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
             dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
                 quantized_tensor_2d,
                 is_fwd,
@@ -161,7 +159,7 @@ def _dequantize_tensor(
             return dequant_tensor.view(-1)
     elif comm_precision == SparseType.MX4:
         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
-        return mx4_to_fp32(quantized_tensor, mx_group_size)
+        return mx4_to_float(quantized_tensor, mx_group_size, output_dtype=output_dtype)
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")
 
@@ -175,7 +173,7 @@ class QuantizedCommCodec:
         row_dim: Optional[int] = None,
         is_fwd: bool = True,
         rounding_mode: Optional[RoundingMode] = None,
-        fp8_output_dtype: Optional[SparseType] = None,
+        output_dtype: Optional[SparseType] = None,
     ) -> None:
         if loss_scale is not None:
             if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -193,7 +191,7 @@ class QuantizedCommCodec:
         self._is_fwd = is_fwd
         self._row_dim: int = -1 if row_dim is None else row_dim
         self._rounding_mode: Optional[RoundingMode] = rounding_mode
-        self._fp8_output_dtype: Optional[SparseType] = fp8_output_dtype
+        self._output_dtype: Optional[SparseType] = output_dtype
         if self._comm_precision == SparseType.MX4:
             self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
             self._rounding_mode = (
@@ -229,7 +227,7 @@ class QuantizedCommCodec:
             self._comm_precision,
             ctx,
             self._is_fwd,
-            fp8_output_dtype=self._fp8_output_dtype,
+            output_dtype=self._output_dtype,
         )
         return dequantized_tensor
 
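The quantize_comm.py hunks above rename the codec's fp8_output_dtype parameter to output_dtype and thread it through to dequantization, so MX4 payloads can now be decoded straight to BF16. A minimal usage sketch follows; apart from output_dtype, the constructor arguments shown here (comm_precision and the keyword style) are assumptions inferred from the hunks, not the full class definition:

# Hedged sketch: decode MX4 comm payloads directly to BF16.
# Parameter names other than output_dtype are inferred from the diff above.
from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType

codec = QuantizedCommCodec(
    comm_precision=SparseType.MX4,  # payloads are quantized to MX4 on encode()
    output_dtype=SparseType.BF16,   # decode() returns BF16 instead of FP32
)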
fbgemm_gpu/quantize_utils.py CHANGED
@@ -14,9 +14,15 @@ import torch  # isort:skip
 
 import fbgemm_gpu
 
-from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.triton.common import RoundingMode
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
 
+if torch.cuda.is_available():
+    from fbgemm_gpu.triton import quantize_mx4
+    from fbgemm_gpu.triton.quantize import triton_dequantize_mx4
+
+
 try:
     # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
     open_source = bool(getattr(fbgemm_gpu, "open_source", False))
@@ -126,25 +132,71 @@ def mx4_to_fp32(
 ) -> torch.Tensor:
     """Dequantize an MX4 tensor to FP32 with triton or native cuda impl.
 
+    This function is kept for backward compatibility and always returns FP32.
+    For BF16 output, use mx4_to_float() with output_dtype=SparseType.BF16.
+    """
+    return mx4_to_float(
+        tensor,
+        group_size,
+        use_triton,
+        ebits,
+        mbits,
+        output_dtype=None,  # None = FP32 default for backward compatibility
+    )
+
+
+def mx4_to_float(
+    tensor: torch.Tensor,
+    group_size: int = 32,
+    use_triton: bool = True,
+    ebits: int = 2,
+    mbits: int = 1,
+    output_dtype: Optional[SparseType] = None,
+) -> torch.Tensor:
+    """Dequantize an MX4 tensor to FP32 or BF16 with triton or native cuda impl.
+
     Args:
         tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
         group_size (int): Compute scale in chunks of group_size.
         use_triton (bool): If set, use triton quantization, otherwise cuda.
         ebits (int): Number of exponent bits in target mx4 format.
         mbits (int): Number of mantissa bits in target mx4 format.
+        output_dtype (Optional[SparseType]): Output dtype (FP32 or BF16).
+            Defaults to None (FP32) for backward compatibility.
 
     Return:
-        output: FP32 tensor with total elements (M).
+        output: Tensor with dtype matching output_dtype and total elements (M).
     """
+    # Validate output_dtype
+    supported_dtypes = {SparseType.FP32, SparseType.BF16}
+    if output_dtype is not None and output_dtype not in supported_dtypes:
+        raise ValueError(
+            f"output_dtype must be one of {supported_dtypes}, got {output_dtype}. "
+            f"FP16 is not supported due to potential overflow/underflow with MX4's wide exponent range. "
+            f"Use BF16 for memory savings with same dynamic range as FP32."
+        )
+
+    target_dtype = (
+        output_dtype.as_dtype() if output_dtype is not None else torch.float32
+    )
+
     # Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
     if not tensor.is_cuda and not tensor.is_mtia:
-        return py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        result = py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        return result.to(target_dtype) if output_dtype is not None else result
     if use_triton:
         if tensor.is_mtia:
-            return mtia_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
-        return dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+            return mtia_dequantize_mx4(
+                tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+            )
+        return triton_dequantize_mx4(
+            tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+        )
     else:
-        return torch.ops.fbgemm.dequantize_mx_cuda(tensor.flatten(), group_size)
+        output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
+        return torch.ops.fbgemm.dequantize_mx_cuda(
+            tensor.flatten(), group_size, output_dtype_int
+        )
 
 
 def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
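The quantize_utils.py hunk above keeps mx4_to_fp32() as a thin FP32-only wrapper and introduces mx4_to_float() with an optional SparseType output dtype (FP32 or BF16; FP16 is rejected). A short hedged sketch of the new entry point, assuming fp32_to_mx4() keeps its existing (tensor, group_size) calling convention:

# Hedged usage sketch of mx4_to_float(); runs on a CUDA device.
import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_float
from fbgemm_gpu.split_embedding_configs import SparseType

x = torch.randn(1024, device="cuda")          # M elements
packed = fp32_to_mx4(x, group_size=32)        # M / 2 + M / 32 packed elements
y_fp32 = mx4_to_float(packed, group_size=32)  # default output stays FP32
y_bf16 = mx4_to_float(packed, group_size=32, output_dtype=SparseType.BF16)
assert y_fp32.dtype == torch.float32 and y_bf16.dtype == torch.bfloat16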
fbgemm_gpu/triton/quantize.py CHANGED
@@ -575,7 +575,7 @@ def _kernel_dequantize_mx4(
         # Write final outputs.
         tl.store(
             out + output_offset,
-            scaled_fp32,
+            scaled_fp32.to(out.dtype.element_ty),
             # Mask values that are out of this chunk or the main array.
             mask=(output_offset < OUTPUT_SIZE)
             & (output_offset < OUTPUT_CHUNK_SIZE * (pid + 1)),
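The one-line kernel change above lets _kernel_dequantize_mx4 store into whichever dtype the output buffer was allocated with, by casting through out.dtype.element_ty at store time. A minimal, self-contained Triton sketch of the same pattern (an illustration, not the FBGEMM kernel):

# Hedged illustration of casting to the output pointer's element type in tl.store.
import torch
import triton
import triton.language as tl


@triton.jit
def _store_cast_kernel(src, out, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    vals = tl.load(src + offs, mask=mask).to(tl.float32)  # compute in fp32
    # Match the output buffer's dtype, so one kernel serves fp32 and bf16 outputs.
    tl.store(out + offs, vals.to(out.dtype.element_ty), mask=mask)


x = torch.randn(256, device="cuda")
y = torch.empty_like(x, dtype=torch.bfloat16)
_store_cast_kernel[(1,)](x, y, x.numel(), BLOCK=256)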
@@ -588,10 +588,14 @@
 
 
 def triton_dequantize_mx4(
-    a: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1
+    a: torch.Tensor,
+    group_size: int = 32,
+    ebits: int = 2,
+    mbits: int = 1,
+    output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """
-    Dequantize a tensor from mx4 format to fp32.
+    Dequantize a tensor from mx4 format to fp32 or bf16.
 
     Args:
         a (Tensor): [M / 2 + M / group_size] MX4 tensor packed into int8 values
@@ -599,13 +603,15 @@ def triton_dequantize_mx4(
         group_size (int): Size of chunks that use the same shared exponent.
         ebits (int): Number of bits to use for exponent in target mx4 format.
         mbits (int): Number of bits to use for mantissa in target mx4 format.
+        output_dtype (torch.dtype): Output dtype (FP32 or BF16).
+            Defaults to torch.float32 for backward compatibility.
 
     Returns:
-        torch.Tensor: [M, K] dequantized fp32 tensor.
+        torch.Tensor: [M, K] dequantized tensor in the specified dtype.
     """
     # If given an empty shape, return an empty tensor.
     if a.numel() == 0:
-        return torch.empty(a.shape, device=a.device, dtype=torch.float32)
+        return torch.empty(a.shape, device=a.device, dtype=output_dtype)
     # View a as 2D for simplicity.
     orig_shape = a.shape
     a = a.flatten()
@@ -622,9 +628,9 @@
     # Use a lookup table to convert
    mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, a.device)
 
-    # Create output tensor.
+    # Create output tensor in target dtype.
     output_elems = num_groups * group_size
-    out = torch.empty([output_elems], device=a.device, dtype=torch.float)
+    out = torch.empty([output_elems], device=a.device, dtype=output_dtype)
     # Check if we need to use int64 for indexing.
     use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1
     # Invoke triton dequantization kernel over rows.
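At this level the Triton wrapper takes a torch.dtype directly rather than a SparseType. A brief hedged call, assuming packed is an MX4 tensor on a CUDA device (e.g. produced by fp32_to_mx4 in the earlier sketch):

import torch
from fbgemm_gpu.triton.quantize import triton_dequantize_mx4

out = triton_dequantize_mx4(packed, group_size=32, output_dtype=torch.bfloat16)
# Omitting output_dtype keeps the previous behavior and returns torch.float32.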
fbgemm_gpu_genai_nightly-2026.1.9.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fbgemm_gpu_genai_nightly
-Version: 2026.1.4
+Version: 2026.1.9
 Home-page: https://github.com/pytorch/fbgemm
 Author: FBGEMM Team
 Author-email: packages@pytorch.org
fbgemm_gpu_genai_nightly-2026.1.9.dist-info/RECORD CHANGED
@@ -2,12 +2,12 @@ fbgemm_gpu/__init__.py,sha256=bL2dL7uYeXb1GvdjIDUTcLXLRGNfmnI4MQoE3-Gg5m8,6361
 fbgemm_gpu/asmjit.so,sha256=UxnhHlu9LgmoRXa8fZwSX56b5QKffBxfAOs0AZLxRfk,501728
 fbgemm_gpu/batched_unary_embeddings_ops.py,sha256=GYeJ9pg-Wc9FokXVci_npDsL6UV18-pJXID2xzrJ9O8,2904
 fbgemm_gpu/enums.py,sha256=37ewGSfO1x7sO31ZkRiqV1yKuklfHXT5qZIxzeeGogo,755
-fbgemm_gpu/fbgemm.so,sha256=U864UANx-CVyFYk5ADawCd0uWRfntHaVcyl6AVty_3Q,5642616
+fbgemm_gpu/fbgemm.so,sha256=HQUXhk9ikCtui6125NmRjOJvAMaWRgZXLlBcQHsh2Xo,5646712
 fbgemm_gpu/metrics.py,sha256=TsurFLJf0nJvPDN7urWb4LMQlf5RgdWPTTTDO7S4wtI,5663
 fbgemm_gpu/permute_pooled_embedding_modules.py,sha256=vOXMYclaGnwSt0St_SOAlAe18kz6WjMyTeHnC9jLhcE,5130
 fbgemm_gpu/permute_pooled_embedding_modules_split.py,sha256=f3VJvH_kw9Ltd_DXtaf_PJPHmlmEWrQgzQ7MDkhh5Nw,2746
-fbgemm_gpu/quantize_comm.py,sha256=ZfXtRHfqpVpV6k2PDL6oTUkKYzopqAV2M6vavp_RLSM,12022
-fbgemm_gpu/quantize_utils.py,sha256=q8Aokk6nlHbXF6HcDBbhBCAGSZV4klM8uPF-MUFFtAw,8324
+fbgemm_gpu/quantize_comm.py,sha256=j4-wBqWRtXjhtQBKi7IOAftNDzv8-AeX9YXlD8e682c,11983
+fbgemm_gpu/quantize_utils.py,sha256=fK4Dk9Qpjsu4qASCwAxkLjbiFRLI71Hd-AtHY4NyMZ8,10200
 fbgemm_gpu/runtime_monitor.py,sha256=YXRUv6nXCsoTgh5_RzailTGvCYzwoYDb-eR4rlGwtaw,7619
 fbgemm_gpu/sparse_ops.py,sha256=_EJC1pAbNnAnVQQ5JBg4DAV2TboIj-4XQkiKMmg1vXI,50417
 fbgemm_gpu/split_embedding_configs.py,sha256=EuVFKIDrgRQpRC5mmB4Du6WftK5GXJvDue9_ezt_eBI,16575
@@ -32,9 +32,9 @@ fbgemm_gpu/docs/merge_pooled_embedding_ops.py,sha256=oJLgSgZQmhsyGLbTmZTxNgQrk65
 fbgemm_gpu/docs/permute_pooled_embedding_ops.py,sha256=tZUqLVXlk5O6VAKKDA-OEMx2fCu5QPOOeoAPZA9_nLY,4454
 fbgemm_gpu/docs/quantize_ops.py,sha256=xTtOaVK1P02ymreE_i21YiyYDZCqhoZY9eWp_mEIRlo,1297
 fbgemm_gpu/docs/sparse_ops.py,sha256=gSLUFdnu8lle_6gLewFkM20wL3ek2jKLvDGMKR6POaY,27292
-fbgemm_gpu/docs/target.genai.json.py,sha256=5TMzQCJ6eRjDaUActAOucxjizI7IZg56rn512-ujiE4,77
+fbgemm_gpu/docs/target.genai.json.py,sha256=TVO8vYaBQPaEdT-bYeXlTdOGiTw4ceWeAWa9m9Wnerg,77
 fbgemm_gpu/experimental/example/__init__.py,sha256=OvJHZgWnycL1gWKyCXFJCTKuys3KAqx4iadjx3R-tBQ,723
-fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so,sha256=y0Z22D1LnOkH9vXtRVPYWJ5raZC27OTViPEtnqi8TyY,190656
+fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so,sha256=9NH_0L5RRD5NwIOILKiOGjKo25isXYYkmwFvbIlUGe0,190656
 fbgemm_gpu/experimental/example/utils.py,sha256=Je__VkMlBMLOhh7NXOocOdvaa2gz9kl9Dkqeu25tpFA,562
 fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py,sha256=1CqUfzlYyXTvU-BNaUq4RZpLV-2lKAVCAHeJzSIZFWw,419
 fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py,sha256=R4VNZdPSgmRmwDfTt2CShED2SGUF6dCXSUW2C4LISgE,215713
@@ -43,7 +43,7 @@ fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py,sha256=5ClZ-GDrx6q0uaqW
 fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py,sha256=SltbY_dsit5e7B8lDIB_VYPrEq0t9kckthj9mQaVNfA,7571
 fbgemm_gpu/experimental/gemm/triton_gemm/utils.py,sha256=rULXIpVaaRS3GKUZ1RHcWUrUyy0xMVREwS1SFShGgcw,4302
 fbgemm_gpu/experimental/gen_ai/__init__.py,sha256=r3NlNCXuIh0pfKwKU5v14y6AZkpoIkKWbtzxSprgeKA,1713
-fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so,sha256=2iHWrQDzhysRNMPbjFQpsxNdAkIRq__vTHy75sa4kJo,65238760
+fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so,sha256=_1HEs4E69AIPZ7zauPqcHB01egv5eNf2IKNZuihMOAA,65230568
 fbgemm_gpu/experimental/gen_ai/quantize.py,sha256=KAljWSdN-1_c5DWfT-3MDxWLMULK49Yu36t6TmQI9Tw,12599
 fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py,sha256=-R_LxyHpdXMILU9TNuYoRisBCkfK0_VLyixefaeZf4g,1463
 fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py,sha256=gbhNU3mDTKJb3yt3inIDbiUjX_SG1oZfzgDygtHvMpk,10101
@@ -111,7 +111,7 @@ fbgemm_gpu/tbe/utils/quantize.py,sha256=icN2MXnl5rNqtKhGKkjpelx5pYBMYUv-6CrghxeV
 fbgemm_gpu/tbe/utils/requests.py,sha256=rQkEoaUUWEYCQM-1K_Lxg1wPcyIVw8sbdaGFTpsaE5I,18040
 fbgemm_gpu/triton/__init__.py,sha256=kPn_Ye6J9DAzWtqi76KYGwfKSqw0IhqG3Bir5aUpkWM,658
 fbgemm_gpu/triton/common.py,sha256=wnkLd2a8fKpefymLL-LjNKEL4hDVSxFiF5g3aF8mzsw,2131
-fbgemm_gpu/triton/quantize.py,sha256=z3y74-DCbGcQDsO70b2jK_HQDIYC0UJ7IEG2vvMu0_Y,26816
+fbgemm_gpu/triton/quantize.py,sha256=bjMPgcUOcuG_d9I_EjSCpkU3Fr5f3FU7CIWIqzc_N3w,27074
 fbgemm_gpu/triton/quantize_ref.py,sha256=q4RBmFaqPVPELU52lbSgB0n26Aun7apeK7bRF2MWS80,11553
 fbgemm_gpu/triton/jagged/__init__.py,sha256=om0yhjuzKuE1UQakFMWHsXN4WNb8mvNkZtYofQ8hdn4,246
 fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py,sha256=F2eQWjkWMR5RWQ48oIr-8OU_CRZyLazDpT7DFrDWS6g,29871
@@ -122,7 +122,7 @@ fbgemm_gpu/utils/torch_library.py,sha256=ywsAHjbuwesj50LjEu99WkAH17FlaVgePZ9OmFg
 fbgemm_gpu/utils/writeback_util.py,sha256=PyVbHp1EuF-GKrJv_CTP6B50Z0oBblXKucf7Rhd6KKY,4614
 list_versions/__init__.py,sha256=UmTeqCk-UJWFtlZQWvZao3xvui2w9E3X_JdOXVjRaNw,315
 list_versions/cli_run.py,sha256=CChZoXQ-tiKaWboXAYlPVJ5w8K5zAKiKcncA087I1sc,4508
-fbgemm_gpu_genai_nightly-2026.1.4.dist-info/METADATA,sha256=MjhefCkWlccqGa-waygmSKkW1vaKWbpxX1U8VLRrMJ0,2655
-fbgemm_gpu_genai_nightly-2026.1.4.dist-info/WHEEL,sha256=Nkv8TSWVt7XcnRf1cdq5HOzycTl6Pjzlmn7gPSv4NiQ,108
-fbgemm_gpu_genai_nightly-2026.1.4.dist-info/top_level.txt,sha256=_2s1Aa08r_eDn0JP4FjOhzK09Q8bVlEI7q8pMep51UY,25
-fbgemm_gpu_genai_nightly-2026.1.4.dist-info/RECORD,,
+fbgemm_gpu_genai_nightly-2026.1.9.dist-info/METADATA,sha256=WQ7sQGvWWGQao3Wcjk39i_bFY2BvrmWfGHxHgWxEsug,2655
+fbgemm_gpu_genai_nightly-2026.1.9.dist-info/WHEEL,sha256=Nkv8TSWVt7XcnRf1cdq5HOzycTl6Pjzlmn7gPSv4NiQ,108
+fbgemm_gpu_genai_nightly-2026.1.9.dist-info/top_level.txt,sha256=_2s1Aa08r_eDn0JP4FjOhzK09Q8bVlEI7q8pMep51UY,25
+fbgemm_gpu_genai_nightly-2026.1.9.dist-info/RECORD,,